use std::str;
use anyhow::Context as _;
use anyhow::Result;
use tracing::warn;
use tree_sitter::Node;
use tree_sitter::Parser;
use tree_sitter::Query;
use tree_sitter::QueryCursor;
use tree_sitter::StreamingIterator as _;
use tree_sitter::Tree;
use tree_sitter_bpf_c::LANGUAGE;
use crate::Point;
use crate::Range;
// Lint definitions, textually included from `lints.rs` inside the build's
// `OUT_DIR` (NOTE(review): presumably emitted by the build script -- confirm).
// The rest of this file reads `lints::LINTS` and iterates it as
// `(lint name, query source)` pairs.
mod lints {
include!(concat!(env!("OUT_DIR"), "/lints.rs"));
}
/// Conversion from tree-sitter's point type into the crate's own [`Point`].
impl From<tree_sitter::Point> for Point {
    /// Carries `row` over unchanged and stores tree-sitter's `column` in `col`.
    fn from(other: tree_sitter::Point) -> Self {
        Self {
            row: other.row,
            col: other.column,
        }
    }
}
/// Conversion from tree-sitter's range type into the crate's own [`Range`].
impl From<tree_sitter::Range> for Range {
    /// Folds the start/end byte offsets into a single half-open `bytes` range
    /// and converts both end points into the crate's [`Point`] type.
    fn from(other: tree_sitter::Range) -> Self {
        Self {
            bytes: other.start_byte..other.end_byte,
            start_point: other.start_point.into(),
            end_point: other.end_point.into(),
        }
    }
}
/// Meta data describing a lint, as yielded by [`builtin_lints`].
#[derive(Clone, Debug)]
pub struct LintMeta {
/// The name of the lint.
pub name: String,
// Unit-typed placeholder that is hidden from the generated documentation.
// NOTE(review): presumably present so that more fields can be added later
// without breaking users that construct or destructure this type -- confirm.
#[doc(hidden)]
pub _non_exhaustive: (),
}
/// Retrieve meta data for every lint contained in `lints::LINTS`.
///
/// Only the name of each entry is exposed; the query source that accompanies
/// it is ignored here.
pub fn builtin_lints() -> impl ExactSizeIterator<Item = LintMeta> + DoubleEndedIterator {
    let to_meta = |entry| {
        let (name, _query_src) = entry;
        LintMeta {
            name: name.to_string(),
            _non_exhaustive: (),
        }
    };
    lints::LINTS.iter().map(to_meta)
}
/// A single finding produced by running a lint over source code.
#[derive(Clone, Debug)]
pub struct LintMatch {
/// The name of the lint that produced the match.
pub lint_name: String,
/// The message attached to the lint, taken from the `message` property of
/// the lint's query.
pub message: String,
/// The source range covered by the reported capture.
pub range: Range,
}
/// Check whether `lint_name` is disabled for `node` by a `bpflint:` comment.
///
/// Starting at `node` and continuing with each of its ancestors, the sibling
/// immediately preceding the current node is inspected. If that sibling is a
/// comment whose text -- with comment markers and surrounding whitespace
/// removed -- reads `bpflint: disable=all` or `bpflint: disable=<lint_name>`,
/// the lint is considered disabled. Comments that are not valid UTF-8 are
/// reported via a warning and otherwise ignored.
///
/// `code` is the source the tree was parsed from; comment text is sliced out
/// of it using the comment node's byte offsets.
fn is_lint_disabled(lint_name: &str, node: Node, code: &[u8]) -> bool {
    // Does the raw text of a single comment disable the lint in question?
    let disables_lint = |raw: &str| -> bool {
        let text = raw
            .trim_start_matches("//")
            .trim_start_matches("/*")
            .trim_end_matches("*/")
            .trim();
        let target = text
            .strip_prefix("bpflint:")
            .and_then(|rest| rest.trim().strip_prefix("disable="));
        matches!(target, Some(target) if target == "all" || target == lint_name)
    };

    // Yields `node` itself first, then its parent, grandparent, ... up to the
    // root; `any` stops at the first level that carries a matching directive.
    std::iter::successors(Some(node), |current| current.parent()).any(|current| {
        match current.prev_sibling() {
            Some(sibling) if sibling.kind() == "comment" => {
                let (start, end) = (sibling.start_byte(), sibling.end_byte());
                match str::from_utf8(&code[start..end]) {
                    Ok(raw) => disables_lint(raw),
                    Err(_) => {
                        warn!(
                            "encountered invalid UTF-8 in code comment at bytes `{}..{}`",
                            start, end
                        );
                        false
                    },
                }
            },
            _ => false,
        }
    })
}
/// Run a single lint over an already parsed syntax tree.
///
/// `lint_src` is the tree-sitter query implementing the lint and `lint_name`
/// the name reported with each match. `code` is the source that `tree` was
/// parsed from.
///
/// Every capture of every query match results in one [`LintMatch`], except
/// for:
/// - captures whose name starts with `__`, which are internal to the query
///   and never reported
/// - captures for which the lint was disabled via a `bpflint:` comment (see
///   `is_lint_disabled`)
///
/// # Errors
/// Fails if the query does not compile, or if a capture is to be reported but
/// the matched pattern has no `message` property or the property carries no
/// value.
fn lint_impl(tree: &Tree, code: &[u8], lint_src: &str, lint_name: &str) -> Result<Vec<LintMatch>> {
    let query = Query::new(&LANGUAGE.into(), lint_src).context("failed to compile lint query")?;
    // The capture names are fixed per query; look them up once.
    let capture_names = query.capture_names();
    let mut query_cursor = QueryCursor::new();
    let mut results = Vec::new();
    let mut matches = query_cursor.matches(&query, tree.root_node(), code);

    // `matches` is a streaming iterator and can't be driven by a `for` loop.
    while let Some(m) = matches.next() {
        for capture in m.captures {
            // Check the capture name first: it is a constant time lookup,
            // whereas the "disabled" check below walks all ancestors of the
            // node and decodes comments along the way.
            if capture_names[capture.index as usize].starts_with("__") {
                continue;
            }

            if is_lint_disabled(lint_name, capture.node, code) {
                continue;
            }

            // The message is only required once there is something to
            // report, so a missing property is an error only at this point.
            let message = query
                .property_settings(m.pattern_index)
                .iter()
                .find(|prop| &*prop.key == "message")
                .with_context(|| format!("{lint_name}: failed to find `message` property"))?
                .value
                .as_ref()
                .with_context(|| format!("{lint_name}: `message` property has no value set"))?
                .to_string();

            let () = results.push(LintMatch {
                lint_name: lint_name.to_string(),
                message,
                range: Range::from(capture.node.range()),
            });
        }
    }

    if query_cursor.did_exceed_match_limit() {
        warn!("query exceeded maximum number of in-progress captures");
    }
    Ok(results)
}
/// Lint `code` using the provided set of lints.
///
/// `lints` is a list of `(lint name, query source)` pairs. The code is parsed
/// once and each lint is then run over the resulting tree. The reported
/// matches are sorted by the start point of their range, with the end point
/// acting as tie breaker.
///
/// # Errors
/// Fails if the parser cannot be set up, if `code` cannot be parsed, or if
/// any of the lints fails to run.
fn lint_multi(code: &[u8], lints: &[(&str, &str)]) -> Result<Vec<LintMatch>> {
    let mut parser = Parser::new();
    let () = parser
        .set_language(&LANGUAGE.into())
        .context("failed to load C parser")?;
    let tree = parser
        .parse(code, None)
        .context("failed to parse provided source code")?;

    let mut results = Vec::new();
    for &(lint_name, lint_src) in lints {
        let matches = lint_impl(&tree, code, lint_src, lint_name)?;
        let () = results.extend(matches);
    }

    // The sort is stable: matches covering the very same range stay in the
    // order in which their lints were provided.
    let () = results.sort_by(|match1, match2| {
        match1
            .range
            .start_point
            .cmp(&match2.range.start_point)
            .then_with(|| match1.range.end_point.cmp(&match2.range.end_point))
    });
    Ok(results)
}
/// Lint the provided source code with all lints contained in `lints::LINTS`.
///
/// Returns the matches reported by `lint_multi`, which sorts them by their
/// position in `code`; any error it produces is passed through.
pub fn lint(code: &[u8]) -> Result<Vec<LintMatch>> {
    let builtin = &lints::LINTS;
    lint_multi(code, builtin)
}
#[cfg(test)]
mod tests {
    use super::*;

    use indoc::indoc;

    use crate::Point;

    /// A `(name, query)` lint flagging calls to a function named `foo`.
    static LINT_FOO: (&str, &str) = (
        "foo",
        r#"
(call_expression
function: (identifier) @function (#eq? @function "foo")
(#set! "message" "foo")
)
"#,
    );

    /// Check that a lint without a `message` property results in an error
    /// once it matches.
    #[test]
    fn missing_message_property() {
        let code = indoc! { r#"
            test_fn(/* doesn't matter */);
        "# };
        let lint = indoc! { r#"
            (call_expression
            function: (identifier) @function (#eq? @function "test_fn")
            )
        "# };
        let err = lint_multi(code.as_bytes(), &[("test_fn", lint)]).unwrap_err();
        assert_eq!(
            err.to_string(),
            "test_fn: failed to find `message` property",
            "{err}"
        );
    }

    /// Check that captures whose name starts with `__` are not reported.
    #[test]
    fn internal_capture_reporting() {
        let lint_bar = indoc! { r#"
            (call_expression
            function: (identifier) @__function (#eq? @__function "bar")
            (#set! "message" "bar")
            )
        "# };
        let code = indoc! { r#"
            bar();
        "# };
        let matches = lint_multi(code.as_bytes(), &[("bar", lint_bar)]).unwrap();
        assert!(matches.is_empty(), "{matches:?}");
    }

    /// Validate the queries of all builtin lints: each has to compile, consist
    /// of a single pattern, and carry a concise `message` property.
    #[test]
    fn validate_lint_queries() {
        for (name, code) in lints::LINTS {
            let query = Query::new(&LANGUAGE.into(), code).unwrap();
            assert_eq!(
                query.pattern_count(),
                1,
                "lint `{name}` has too many pattern matches: only a single one is supported currently"
            );
            let settings = query.property_settings(0);
            // `expect` takes a plain string and would print `{name}`
            // verbatim; `panic!` is needed for the name to be interpolated.
            let setting = settings
                .iter()
                .find(|prop| &*prop.key == "message")
                .unwrap_or_else(|| panic!("`message` property is missing for lint `{name}`"));
            let message = setting
                .value
                .as_ref()
                .unwrap_or_else(|| panic!("`message` property of lint `{name}` has no value set"))
                .as_ref();
            let last = message
                .chars()
                .last()
                .unwrap_or_else(|| panic!("`message` property of lint `{name}` is empty"));
            assert!(
                !['.', '!', '?'].contains(&last),
                "`message` property of lint `{name}` should be concise and not a fully blown sentence with punctuation"
            );
        }
    }

    /// Check that the builtin lints report a `bpf_probe_read` call, including
    /// the exact range of the match.
    #[test]
    fn basic_linting() {
        // The four space indentation of the function body is significant:
        // the column assertions below depend on it.
        let code = indoc! { r#"
            /* A handler for something */
            SEC("tp_btf/sched_switch")
            int handle__sched_switch(u64 *ctx)
            {
                struct task_struct *prev = (struct task_struct *)ctx[1];
                struct event event = {0};
                bpf_probe_read(event.comm, TASK_COMM_LEN, prev->comm);
                return 0;
            }
        "# };
        let matches = lint(code.as_bytes()).unwrap();
        assert_eq!(matches.len(), 1);
        let LintMatch {
            lint_name,
            message,
            range,
        } = &matches[0];
        assert_eq!(lint_name, "probe-read");
        assert!(
            message.starts_with("bpf_probe_read() is deprecated"),
            "{message}"
        );
        assert_eq!(&code[range.bytes.clone()], "bpf_probe_read");
        assert_eq!(range.start_point, Point { row: 6, col: 4 });
        assert_eq!(range.end_point, Point { row: 6, col: 18 });
    }

    /// Check that matches are reported in source order, irrespective of the
    /// order in which the lints were provided.
    #[test]
    fn sorted_match_reporting() {
        let lint_bar = indoc! { r#"
            (call_expression
            function: (identifier) @function (#eq? @function "bar")
            (#set! "message" "bar")
            )
        "# };
        let code = indoc! { r#"
            bar();
            foo();
        "# };
        let matches = lint_multi(code.as_bytes(), &[LINT_FOO, ("bar", lint_bar)]).unwrap();
        assert_eq!(matches.len(), 2);
        assert_eq!(matches[0].lint_name, "bar");
        assert_eq!(matches[1].lint_name, "foo");
    }

    /// Check that `bpflint: disable=` comments directly preceding a statement
    /// suppress matches, for block and line comments alike.
    #[test]
    fn lint_disabling() {
        let code = indoc! { r#"
            /* bpflint: disable=foo */
            foo();
            // bpflint: disable=foo
            foo();
            // bpflint: disable=all
            foo();
        "# };
        let matches = lint_multi(code.as_bytes(), &[LINT_FOO]).unwrap();
        assert_eq!(matches.len(), 0, "{matches:?}");
    }

    /// Check that a disabling comment preceding an enclosing block or function
    /// applies to everything nested inside of it.
    #[test]
    fn lint_disabling_recursive() {
        let code = indoc! { r#"
            /* bpflint: disable=foo */
            {
            {
            foo();
            }
            }
        "# };
        let matches = lint_multi(code.as_bytes(), &[LINT_FOO]).unwrap();
        assert_eq!(matches.len(), 0, "{matches:?}");

        let code = indoc! { r#"
            /* bpflint: disable=foo */
            void test_fn(void) {
            foo();
            }
        "# };
        let matches = lint_multi(code.as_bytes(), &[LINT_FOO]).unwrap();
        assert_eq!(matches.len(), 0, "{matches:?}");
    }

    /// Check that malformed, mismatched, or misplaced directives do not
    /// suppress any matches.
    #[test]
    fn lint_invalid_disabling() {
        let code = indoc! { r#"
            /* bpflint: disabled=foo */
            foo();
            /* disabled=foo */
            foo();
            // disabled=foo
            foo();
            // bpflint: foo
            foo();
            // bpflint: disable=bar
            foo();
            void test_fn(void) {
            /* bpflint: disable=foo */
            foobar();
            foo();
            }
        "# };
        let matches = lint_multi(code.as_bytes(), &[LINT_FOO]).unwrap();
        assert_eq!(matches.len(), 6, "{matches:?}");
    }
}