Skip to main content

harn_rules/
engine.rs

1//! Compile a [`Rule`] into a runnable matcher and run it against source.
2//!
3//! The atomic tier supports three matcher forms, all reduced to a single
4//! [`RuleMatch`] stream:
5//!
6//! - `pattern` → compiled to a tree-sitter query via [`crate::pattern`].
7//! - `kind` → the trivial query `(<kind>) @__match`.
8//! - `regex` → a text regex over the source, yielding spans with no AST
9//!   metavar bindings.
10
11use std::collections::BTreeMap;
12
13use harn_hostlib::ast::{api, Language};
14use streaming_iterator::StreamingIterator;
15use tree_sitter::{Query, QueryCursor};
16
17use crate::error::RulesError;
18use crate::model::{AtomicMatcher, Rule};
19use crate::pattern::{compile_pattern, ROOT_CAPTURE};
20
21/// A byte + row/col span. Rows/cols are 0-based, matching the rest of the
22/// Harn AST wire format.
23#[derive(Debug, Clone, Copy, PartialEq, Eq)]
24pub struct Span {
25    /// Start byte offset.
26    pub start_byte: usize,
27    /// End byte offset (exclusive).
28    pub end_byte: usize,
29    /// 0-based start row.
30    pub start_row: usize,
31    /// 0-based start column.
32    pub start_col: usize,
33    /// 0-based end row.
34    pub end_row: usize,
35    /// 0-based end column.
36    pub end_col: usize,
37}
38
39impl Span {
40    fn of(node: tree_sitter::Node<'_>) -> Self {
41        let start = node.start_position();
42        let end = node.end_position();
43        Span {
44            start_byte: node.start_byte(),
45            end_byte: node.end_byte(),
46            start_row: start.row,
47            start_col: start.column,
48            end_row: end.row,
49            end_col: end.column,
50        }
51    }
52}
53
54/// A metavariable binding: the captured text plus where it lives.
55#[derive(Debug, Clone)]
56pub struct Binding {
57    /// The captured source text.
58    pub text: String,
59    /// The captured node's span.
60    pub span: Span,
61}
62
63/// One match of a rule against a file.
64#[derive(Debug, Clone)]
65pub struct RuleMatch {
66    /// The rule that produced this match.
67    pub rule_id: String,
68    /// The whole matched range (the pattern root, or the regex span).
69    pub span: Span,
70    /// The matched source text.
71    pub text: String,
72    /// Metavar bindings, keyed by name (without the leading `$`). Empty for
73    /// `kind` and `regex` matchers.
74    pub bindings: BTreeMap<String, Binding>,
75}
76
77/// A rule whose matcher has been compiled and is ready to run.
78pub struct CompiledRule {
79    rule_id: String,
80    language: Language,
81    matcher: CompiledMatcher,
82}
83
84enum CompiledMatcher {
85    /// A tree-sitter query plus the metavar names to extract. Covers both
86    /// `pattern` and `kind` forms.
87    Query { query: Query, metavars: Vec<String> },
88    /// A text regex over the whole source.
89    Regex(regex::Regex),
90}
91
92impl CompiledRule {
93    /// Resolve the rule's language and grammar, then compile its matcher.
94    pub fn compile(rule: &Rule) -> Result<Self, RulesError> {
95        let language =
96            Language::from_name(&rule.language).ok_or_else(|| RulesError::UnknownLanguage {
97                rule: rule.id.clone(),
98                language: rule.language.clone(),
99            })?;
100
101        let matcher = match rule
102            .rule
103            .resolve()
104            .map_err(|message| RulesError::PatternCompile {
105                rule: rule.id.clone(),
106                message,
107            })? {
108            AtomicMatcher::Pattern(snippet) => {
109                let ts_language =
110                    language
111                        .ts_language()
112                        .ok_or_else(|| RulesError::GrammarUnavailable {
113                            rule: rule.id.clone(),
114                            language: language.name().to_string(),
115                        })?;
116                let compiled = compile_pattern(&snippet, language).map_err(|message| {
117                    RulesError::PatternCompile {
118                        rule: rule.id.clone(),
119                        message,
120                    }
121                })?;
122                let query = Query::new(&ts_language, &compiled.query).map_err(|err| {
123                    RulesError::QueryRejected {
124                        rule: rule.id.clone(),
125                        message: err.to_string(),
126                        query: compiled.query.clone(),
127                    }
128                })?;
129                CompiledMatcher::Query {
130                    query,
131                    metavars: compiled.metavars,
132                }
133            }
134            AtomicMatcher::Kind(kind) => {
135                let ts_language =
136                    language
137                        .ts_language()
138                        .ok_or_else(|| RulesError::GrammarUnavailable {
139                            rule: rule.id.clone(),
140                            language: language.name().to_string(),
141                        })?;
142                let query_text = format!("({kind}) @{ROOT_CAPTURE}");
143                let query = Query::new(&ts_language, &query_text).map_err(|err| {
144                    RulesError::QueryRejected {
145                        rule: rule.id.clone(),
146                        message: err.to_string(),
147                        query: query_text.clone(),
148                    }
149                })?;
150                CompiledMatcher::Query {
151                    query,
152                    metavars: Vec::new(),
153                }
154            }
155            AtomicMatcher::Regex(pattern) => {
156                let regex =
157                    regex::Regex::new(&pattern).map_err(|err| RulesError::PatternCompile {
158                        rule: rule.id.clone(),
159                        message: format!("invalid regex `{pattern}`: {err}"),
160                    })?;
161                CompiledMatcher::Regex(regex)
162            }
163        };
164
165        Ok(CompiledRule {
166            rule_id: rule.id.clone(),
167            language,
168            matcher,
169        })
170    }
171
172    /// The language this rule targets.
173    pub fn language(&self) -> Language {
174        self.language
175    }
176
177    /// Run the compiled rule against `source`, returning matches in
178    /// document order.
179    pub fn run(&self, source: &str) -> Result<Vec<RuleMatch>, RulesError> {
180        match &self.matcher {
181            CompiledMatcher::Query { query, metavars } => self.run_query(query, metavars, source),
182            CompiledMatcher::Regex(regex) => Ok(self.run_regex(regex, source)),
183        }
184    }
185
186    fn run_query(
187        &self,
188        query: &Query,
189        metavars: &[String],
190        source: &str,
191    ) -> Result<Vec<RuleMatch>, RulesError> {
192        let tree =
193            api::parse_tree(source, self.language).map_err(|err| RulesError::SourceParse {
194                rule: self.rule_id.clone(),
195                message: err.to_string(),
196            })?;
197        let names: Vec<&str> = query.capture_names().to_vec();
198        let bytes = source.as_bytes();
199
200        let mut cursor = QueryCursor::new();
201        let mut it = cursor.matches(query, tree.root_node(), bytes);
202        let mut matches = Vec::new();
203        while let Some(m) = it.next() {
204            let mut root: Option<Span> = None;
205            let mut root_text = String::new();
206            let mut bindings: BTreeMap<String, Binding> = BTreeMap::new();
207            for cap in m.captures {
208                let name = names[cap.index as usize];
209                let span = Span::of(cap.node);
210                let text = source[cap.node.start_byte()..cap.node.end_byte()].to_string();
211                if name == ROOT_CAPTURE {
212                    root = Some(span);
213                    root_text = text;
214                } else if metavars.iter().any(|m| m == name) {
215                    // Canonical metavar capture; unification helpers carry a
216                    // `.` and never appear in `metavars`, so they are skipped.
217                    bindings
218                        .entry(name.to_string())
219                        .or_insert(Binding { text, span });
220                }
221            }
222            if let Some(span) = root {
223                matches.push(RuleMatch {
224                    rule_id: self.rule_id.clone(),
225                    span,
226                    text: root_text,
227                    bindings,
228                });
229            }
230        }
231        // Tree-sitter yields matches in query-eval order; sort to document
232        // order for a stable, intuitive result.
233        matches.sort_by_key(|m| (m.span.start_byte, m.span.end_byte));
234        Ok(matches)
235    }
236
237    fn run_regex(&self, regex: &regex::Regex, source: &str) -> Vec<RuleMatch> {
238        let mut matches = Vec::new();
239        for m in regex.find_iter(source) {
240            let span = byte_span(source, m.start(), m.end());
241            matches.push(RuleMatch {
242                rule_id: self.rule_id.clone(),
243                span,
244                text: m.as_str().to_string(),
245                bindings: BTreeMap::new(),
246            });
247        }
248        matches
249    }
250}
251
252/// Compute a [`Span`] for a byte range by counting rows/cols. Used by the
253/// regex matcher, which has no tree-sitter node to read positions from.
254fn byte_span(source: &str, start: usize, end: usize) -> Span {
255    let (start_row, start_col) = row_col(source, start);
256    let (end_row, end_col) = row_col(source, end);
257    Span {
258        start_byte: start,
259        end_byte: end,
260        start_row,
261        start_col,
262        end_row,
263        end_col,
264    }
265}
266
/// Convert a byte offset in `source` into a 0-based `(row, column)` pair.
///
/// Rows advance only on `\n`. The column is the *byte* distance from the
/// start of the line, the same unit tree-sitter uses for `Point::column`.
/// This keeps regex spans consistent with the AST spans built by `Span::of`,
/// even on lines that contain multi-byte characters. Offsets past the end
/// of `source` are clamped to `source.len()`.
fn row_col(source: &str, byte: usize) -> (usize, usize) {
    let prefix = &source.as_bytes()[..byte.min(source.len())];
    let row = prefix.iter().filter(|&&b| b == b'\n').count();
    let line_start = prefix
        .iter()
        .rposition(|&b| b == b'\n')
        .map_or(0, |newline| newline + 1);
    (row, prefix.len() - line_start)
}
283
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::Rule;

    /// Parse a TOML rule definition and compile it, panicking on any failure.
    fn compile_rule(toml: &str) -> CompiledRule {
        let parsed = Rule::from_toml_str(toml).expect("rule parses");
        CompiledRule::compile(&parsed).expect("rule compiles")
    }

    /// Parse a TOML rule definition (which must be valid) and return the raw
    /// compile result so tests can inspect the error variant.
    fn try_compile(toml: &str) -> Result<CompiledRule, RulesError> {
        let parsed = Rule::from_toml_str(toml).unwrap();
        CompiledRule::compile(&parsed)
    }

    #[test]
    fn pattern_rule_binds_metavars() {
        let source = "const a = cfg?.timeout ?? 30;\nconst b = opts?.retries ?? 3;\n";
        let compiled = compile_rule(
            r#"
            id = "destructure-default"
            language = "typescript"
            fix = "{ $KEY: $SRC }"
            [rule]
            pattern = "$SRC?.$KEY ?? $DEFAULT"
            "#,
        );
        let found = compiled.run(source).unwrap();
        assert_eq!(found.len(), 2);

        let (first, second) = (&found[0], &found[1]);
        assert_eq!(first.bindings["SRC"].text, "cfg");
        assert_eq!(first.bindings["KEY"].text, "timeout");
        assert_eq!(first.bindings["DEFAULT"].text, "30");
        assert_eq!(second.bindings["SRC"].text, "opts");
        // The whole expression is the match, not just one of its sub-nodes.
        assert_eq!(first.text, "cfg?.timeout ?? 30");
        assert_eq!(first.span.start_row, 0);
        assert_eq!(second.span.start_row, 1);
    }

    #[test]
    fn kind_rule_matches_node_kind() {
        let compiled = compile_rule(
            r#"
            id = "find-calls"
            language = "python"
            [rule]
            kind = "call"
            "#,
        );
        let found = compiled.run("print(x)\nlog(y)\n").unwrap();
        assert_eq!(found.len(), 2);

        let first = &found[0];
        assert_eq!(first.text, "print(x)");
        assert!(first.bindings.is_empty());
    }

    #[test]
    fn regex_rule_matches_text() {
        let source = "fn f() {\n    // TODO(ken) fix\n    // todo lower\n}\n";
        let compiled = compile_rule(
            r#"
            id = "todo"
            language = "rust"
            message = "Found a TODO"
            [rule]
            regex = "TODO\\(\\w+\\)"
            "#,
        );
        let found = compiled.run(source).unwrap();
        // Regex matching is case-sensitive, so the lowercase `todo` is ignored.
        assert_eq!(found.len(), 1);

        let only = &found[0];
        assert_eq!(only.text, "TODO(ken)");
        assert_eq!(only.span.start_row, 1);
    }

    #[test]
    fn unknown_language_is_an_error() {
        let outcome = try_compile(
            r#"
            id = "x"
            language = "cobol"
            [rule]
            kind = "foo"
            "#,
        );
        assert!(matches!(outcome, Err(RulesError::UnknownLanguage { .. })));
    }

    #[test]
    fn invalid_pattern_surfaces_compile_error() {
        let outcome = try_compile(
            r#"
            id = "x"
            language = "typescript"
            [rule]
            pattern = "foo($$$ARGS)"
            "#,
        );
        assert!(matches!(outcome, Err(RulesError::PatternCompile { .. })));
    }
}