Skip to main content

normalize_languages/
query_predicates.rs

1//! Predicate evaluation for tree-sitter queries.
2//!
3//! Tree-sitter compiles `#match?`, `#eq?`, etc. into `QueryPredicate` structs but
4//! does **not** evaluate them at match time — the caller is responsible for
5//! filtering matches that fail their predicates.
6//!
7//! [`satisfies_predicates`] evaluates the standard tree-sitter predicates so that
8//! query authors can use them in `.scm` files and have them honoured at runtime.
9
10use tree_sitter::{Query, QueryMatch, QueryPredicateArg};
11
12/// Return `true` if all predicates on `m`'s pattern are satisfied, `false` otherwise.
13///
14/// Supported predicates:
15/// - `#match?` — captured text must match the regex
16/// - `#not-match?` — captured text must not match the regex
17/// - `#eq?` — two captures/strings must be equal
18/// - `#not-eq?` — two captures/strings must not be equal
19///
20/// Unknown predicates pass (return `true`) so future predicates don't break existing
21/// queries.
22pub fn satisfies_predicates(query: &Query, m: &QueryMatch, source: &[u8]) -> bool {
23    for predicate in query.general_predicates(m.pattern_index) {
24        let op = predicate.operator.as_ref();
25        match op {
26            "match?" | "not-match?" => {
27                let args = &predicate.args;
28                if args.len() != 2 {
29                    continue;
30                }
31                let capture_index = match &args[0] {
32                    QueryPredicateArg::Capture(idx) => *idx,
33                    _ => continue,
34                };
35                let pattern = match &args[1] {
36                    QueryPredicateArg::String(s) => s.as_ref(),
37                    _ => continue,
38                };
39                let text = capture_text(m, capture_index, source);
40                let matches = regex_matches(pattern, text);
41                let want_match = op == "match?";
42                if matches != want_match {
43                    return false;
44                }
45            }
46            "eq?" | "not-eq?" => {
47                let args = &predicate.args;
48                if args.len() != 2 {
49                    continue;
50                }
51                let lhs = resolve_arg(&args[0], m, source);
52                let rhs = resolve_arg(&args[1], m, source);
53                let equal = lhs == rhs;
54                let want_eq = op == "eq?";
55                if equal != want_eq {
56                    return false;
57                }
58            }
59            // Unknown predicates pass so future predicates don't break existing queries.
60            _ => {}
61        }
62    }
63    true
64}
65
66// ── Helpers ──────────────────────────────────────────────────────────────────
67
68fn capture_text<'a>(m: &QueryMatch, capture_index: u32, source: &'a [u8]) -> &'a str {
69    m.captures
70        .iter()
71        .find(|c| c.index == capture_index)
72        .and_then(|c| c.node.utf8_text(source).ok())
73        .unwrap_or("")
74}
75
76fn resolve_arg<'a>(arg: &QueryPredicateArg, m: &'a QueryMatch, source: &'a [u8]) -> &'a str {
77    match arg {
78        QueryPredicateArg::Capture(idx) => capture_text(m, *idx, source),
79        QueryPredicateArg::String(s) => {
80            // SAFETY: we extend the lifetime here — the string is borrowed from the
81            // predicate which lives as long as the Query, which outlives this call.
82            // Callers hold the Query for the duration of the loop so this is safe.
83            unsafe { std::mem::transmute::<&str, &'a str>(s.as_ref()) }
84        }
85    }
86}
87
88/// Test whether `text` matches `pattern` as a regex.
89///
90/// Errors (invalid regex) are treated as non-matching so a bad predicate doesn't panic.
91fn regex_matches(pattern: &str, text: &str) -> bool {
92    // Compile the regex on each call. In practice predicates are called on every
93    // match so a caching approach would be better for hot paths, but this is correct
94    // and avoids adding a HashMap dependency to this helper.
95    match regex::Regex::new(pattern) {
96        Ok(re) => re.is_match(text),
97        Err(_) => false,
98    }
99}