Skip to main content

harn_rules/
evaluator.rs

1//! The relational + composite matching algebra (#2833).
2//!
3//! A [`crate::model::RuleNode`] compiles to a [`CompiledNode`] tree, which
4//! the evaluator walks against a parsed source tree. A node matches a
5//! tree-sitter node iff its **atomic** leaf matches *and* every
6//! **relational** (`inside` / `has` / `follows` / `precedes`) and
7//! **composite** (`all` / `any` / `not` / `matches`) part holds — every set
8//! key is ANDed.
9//!
10//! Candidates for the top rule are seeded cheaply (the atomic pattern query
11//! in one pass, or a kind/regex/whole-tree walk); each candidate is then
12//! checked in full. Relational sub-rules run their own atomic match rooted
13//! at the ancestor/descendant/sibling under test.
14
15use std::collections::BTreeMap;
16
17use harn_hostlib::ast::{api, Language};
18use regex::Regex;
19use streaming_iterator::StreamingIterator;
20use tree_sitter::{Node, Query, QueryCursor};
21
22use crate::engine::{Binding, Span};
23use crate::error::RulesError;
24use crate::model::{AtomicMatcher, RuleNode, StopBy, StopKeyword};
25use crate::pattern::{compile_pattern, ROOT_CAPTURE};
26
27/// Metavar bindings accumulated while matching a node.
28type Bindings = BTreeMap<String, Binding>;
29
30/// One node that matched the top rule, with its bindings.
31pub struct EvalMatch {
32    /// The matched node's span.
33    pub span: Span,
34    /// The matched node's text.
35    pub text: String,
36    /// Metavar bindings (captured + threaded from relational sub-matches).
37    pub bindings: Bindings,
38}
39
40/// A compiled rule tree: the top node plus the utility rules `matches`
41/// references.
42pub struct CompiledRuleTree {
43    top: CompiledNode,
44    utils: BTreeMap<String, CompiledNode>,
45}
46
47struct CompiledNode {
48    atomic: Option<CompiledAtomic>,
49    inside: Option<Box<CompiledRel>>,
50    has: Option<Box<CompiledRel>>,
51    follows: Option<Box<CompiledRel>>,
52    precedes: Option<Box<CompiledRel>>,
53    all: Vec<CompiledNode>,
54    any: Vec<CompiledNode>,
55    not: Option<Box<CompiledNode>>,
56    matches: Option<String>,
57}
58
59struct CompiledRel {
60    node: CompiledNode,
61    stop_by: CompiledStopBy,
62    field: Option<String>,
63}
64
65enum CompiledStopBy {
66    Neighbor,
67    End,
68    Rule(Box<CompiledNode>),
69}
70
71enum CompiledAtomic {
72    Query { query: Query, metavars: Vec<String> },
73    Kind(String),
74    Regex(Regex),
75}
76
77impl CompiledRuleTree {
78    /// Compile a rule's `[rule]` node and its `[utils]` into a runnable
79    /// tree.
80    pub fn compile(
81        rule_id: &str,
82        language: Language,
83        top: &RuleNode,
84        utils: &BTreeMap<String, RuleNode>,
85    ) -> Result<Self, RulesError> {
86        if top.is_empty() {
87            return Err(RulesError::PatternCompile {
88                rule: rule_id.to_string(),
89                message: "rule node is empty (no atomic / relational / composite key)".into(),
90            });
91        }
92        let compiled_utils = utils
93            .iter()
94            .map(|(id, node)| Ok((id.clone(), compile_node(rule_id, language, node)?)))
95            .collect::<Result<BTreeMap<_, _>, RulesError>>()?;
96        Ok(CompiledRuleTree {
97            top: compile_node(rule_id, language, top)?,
98            utils: compiled_utils,
99        })
100    }
101
102    /// Find every node matching the top rule, in document order.
103    pub fn find(
104        &self,
105        rule_id: &str,
106        language: Language,
107        source: &str,
108    ) -> Result<Vec<EvalMatch>, RulesError> {
109        let tree = api::parse_tree(source, language).map_err(|err| RulesError::SourceParse {
110            rule: rule_id.to_string(),
111            message: err.to_string(),
112        })?;
113        let ctx = Ctx {
114            source,
115            utils: &self.utils,
116        };
117        let root = tree.root_node();
118
119        let mut seen: BTreeMap<(usize, usize), EvalMatch> = BTreeMap::new();
120        for node in seed_candidates(&self.top, &ctx, root) {
121            if let Some(bindings) = node_satisfies(&self.top, node, &ctx) {
122                let key = (node.start_byte(), node.end_byte());
123                seen.entry(key).or_insert_with(|| EvalMatch {
124                    span: Span::of(node),
125                    text: ctx.text(node),
126                    bindings,
127                });
128            }
129        }
130        Ok(seen.into_values().collect())
131    }
132}
133
134/// Per-run evaluation context.
135struct Ctx<'a> {
136    source: &'a str,
137    utils: &'a BTreeMap<String, CompiledNode>,
138}
139
140impl Ctx<'_> {
141    fn text(&self, node: Node<'_>) -> String {
142        self.source[node.start_byte()..node.end_byte()].to_string()
143    }
144}
145
146// ---------------------------------------------------------------------------
147// Compilation
148// ---------------------------------------------------------------------------
149
150fn compile_node(
151    rule_id: &str,
152    language: Language,
153    node: &RuleNode,
154) -> Result<CompiledNode, RulesError> {
155    let mkerr = |message: String| RulesError::PatternCompile {
156        rule: rule_id.to_string(),
157        message,
158    };
159
160    let atomic = match node.atomic().map_err(mkerr)? {
161        None => None,
162        Some(AtomicMatcher::Pattern(snippet)) => {
163            let ts_language = language
164                .ts_language()
165                .ok_or_else(|| mkerr(format!("grammar for `{}` unavailable", language.name())))?;
166            let compiled =
167                compile_pattern(&snippet, language).map_err(|m| mkerr(format!("pattern: {m}")))?;
168            let query = Query::new(&ts_language, &compiled.query).map_err(|e| {
169                RulesError::QueryRejected {
170                    rule: rule_id.to_string(),
171                    message: e.to_string(),
172                    query: compiled.query.clone(),
173                }
174            })?;
175            Some(CompiledAtomic::Query {
176                query,
177                metavars: compiled.metavars,
178            })
179        }
180        Some(AtomicMatcher::Kind(kind)) => Some(CompiledAtomic::Kind(kind)),
181        Some(AtomicMatcher::Regex(re)) => Some(CompiledAtomic::Regex(
182            Regex::new(&re).map_err(|e| mkerr(format!("regex `{re}` invalid: {e}")))?,
183        )),
184    };
185
186    let rel = |sub: &Option<Box<RuleNode>>| -> Result<Option<Box<CompiledRel>>, RulesError> {
187        match sub {
188            None => Ok(None),
189            Some(n) => Ok(Some(Box::new(compile_rel(rule_id, language, n)?))),
190        }
191    };
192
193    let compile_list = |list: &Option<Vec<RuleNode>>| -> Result<Vec<CompiledNode>, RulesError> {
194        list.iter()
195            .flatten()
196            .map(|n| compile_node(rule_id, language, n))
197            .collect()
198    };
199
200    Ok(CompiledNode {
201        atomic,
202        inside: rel(&node.inside)?,
203        has: rel(&node.has)?,
204        follows: rel(&node.follows)?,
205        precedes: rel(&node.precedes)?,
206        all: compile_list(&node.all)?,
207        any: compile_list(&node.any)?,
208        not: match &node.not {
209            None => None,
210            Some(n) => Some(Box::new(compile_node(rule_id, language, n)?)),
211        },
212        matches: node.matches.clone(),
213    })
214}
215
216fn compile_rel(
217    rule_id: &str,
218    language: Language,
219    node: &RuleNode,
220) -> Result<CompiledRel, RulesError> {
221    let stop_by = match &node.stop_by {
222        None | Some(StopBy::Keyword(StopKeyword::Neighbor)) => CompiledStopBy::Neighbor,
223        Some(StopBy::Keyword(StopKeyword::End)) => CompiledStopBy::End,
224        Some(StopBy::Rule(r)) => {
225            CompiledStopBy::Rule(Box::new(compile_node(rule_id, language, r)?))
226        }
227    };
228    Ok(CompiledRel {
229        node: compile_node(rule_id, language, node)?,
230        stop_by,
231        field: node.field.clone(),
232    })
233}
234
235// ---------------------------------------------------------------------------
236// Evaluation
237// ---------------------------------------------------------------------------
238
239/// Seed candidate nodes for the top rule. The atomic pattern query runs in
240/// one pass; kind/regex/composite-only rules fall back to a tree walk.
241fn seed_candidates<'t>(top: &CompiledNode, ctx: &Ctx<'_>, root: Node<'t>) -> Vec<Node<'t>> {
242    match &top.atomic {
243        Some(CompiledAtomic::Query { query, .. }) => {
244            let mut out = Vec::new();
245            let mut seen = std::collections::HashSet::new();
246            let root_index = root_capture_index(query);
247            let mut cursor = QueryCursor::new();
248            let mut it = cursor.matches(query, root, ctx.source.as_bytes());
249            while let Some(m) = it.next() {
250                for cap in m.captures {
251                    if Some(cap.index) == root_index && seen.insert(cap.node.id()) {
252                        out.push(cap.node);
253                    }
254                }
255            }
256            out
257        }
258        Some(CompiledAtomic::Kind(kind)) => {
259            let mut out = Vec::new();
260            for_each_named_descendant(root, &mut |n| {
261                if n.kind() == kind {
262                    out.push(n);
263                }
264            });
265            out
266        }
267        Some(CompiledAtomic::Regex(_)) | None => {
268            let mut out = Vec::new();
269            for_each_named_descendant(root, &mut |n| out.push(n));
270            out
271        }
272    }
273}
274
275/// Full match check for a specific `node`: atomic + relational + composite.
276fn node_satisfies(cnode: &CompiledNode, node: Node<'_>, ctx: &Ctx<'_>) -> Option<Bindings> {
277    let mut bindings = Bindings::new();
278
279    if let Some(atomic) = &cnode.atomic {
280        merge(&mut bindings, atomic_match(atomic, node, ctx)?);
281    }
282    if let Some(rel) = &cnode.inside {
283        merge(&mut bindings, eval_inside(rel, node, ctx)?);
284    }
285    if let Some(rel) = &cnode.has {
286        merge(&mut bindings, eval_has(rel, node, ctx)?);
287    }
288    if let Some(rel) = &cnode.follows {
289        merge(&mut bindings, eval_sibling(rel, node, ctx, Dir::Before)?);
290    }
291    if let Some(rel) = &cnode.precedes {
292        merge(&mut bindings, eval_sibling(rel, node, ctx, Dir::After)?);
293    }
294    for sub in &cnode.all {
295        merge(&mut bindings, node_satisfies(sub, node, ctx)?);
296    }
297    if !cnode.any.is_empty() {
298        let matched = cnode
299            .any
300            .iter()
301            .find_map(|sub| node_satisfies(sub, node, ctx));
302        merge(&mut bindings, matched?);
303    }
304    if let Some(not) = &cnode.not {
305        if node_satisfies(not, node, ctx).is_some() {
306            return None;
307        }
308    }
309    if let Some(id) = &cnode.matches {
310        let util = ctx.utils.get(id)?;
311        merge(&mut bindings, node_satisfies(util, node, ctx)?);
312    }
313
314    Some(bindings)
315}
316
317fn atomic_match(atomic: &CompiledAtomic, node: Node<'_>, ctx: &Ctx<'_>) -> Option<Bindings> {
318    match atomic {
319        CompiledAtomic::Kind(kind) => (node.kind() == kind).then(Bindings::new),
320        CompiledAtomic::Regex(re) => re.is_match(&ctx.text(node)).then(Bindings::new),
321        CompiledAtomic::Query { query, metavars } => {
322            let root_index = root_capture_index(query);
323            let names: Vec<&str> = query.capture_names().to_vec();
324            let mut cursor = QueryCursor::new();
325            let mut it = cursor.matches(query, node, ctx.source.as_bytes());
326            while let Some(m) = it.next() {
327                // The pattern must match `node` itself, not a descendant.
328                let roots_here = m
329                    .captures
330                    .iter()
331                    .any(|c| Some(c.index) == root_index && c.node.id() == node.id());
332                if !roots_here {
333                    continue;
334                }
335                let mut bindings = Bindings::new();
336                for cap in m.captures {
337                    let name = names[cap.index as usize];
338                    if metavars.iter().any(|mv| mv == name) {
339                        bindings.entry(name.to_string()).or_insert_with(|| Binding {
340                            text: ctx.text(cap.node),
341                            span: Span::of(cap.node),
342                        });
343                    }
344                }
345                return Some(bindings);
346            }
347            None
348        }
349    }
350}
351
352fn eval_inside(rel: &CompiledRel, node: Node<'_>, ctx: &Ctx<'_>) -> Option<Bindings> {
353    let mut current = node.parent();
354    let mut child = node;
355    while let Some(ancestor) = current {
356        if let CompiledStopBy::Rule(stop) = &rel.stop_by {
357            if node_satisfies(stop, ancestor, ctx).is_some()
358                && node_satisfies(&rel.node, ancestor, ctx).is_none()
359            {
360                // Hit the stop boundary without matching.
361                return None;
362            }
363        }
364        if let Some(b) = node_satisfies(&rel.node, ancestor, ctx) {
365            if field_ok(rel.field.as_deref(), ancestor, child) {
366                return Some(b);
367            }
368        }
369        if matches!(rel.stop_by, CompiledStopBy::Neighbor) {
370            return None;
371        }
372        child = ancestor;
373        current = ancestor.parent();
374    }
375    None
376}
377
378fn eval_has(rel: &CompiledRel, node: Node<'_>, ctx: &Ctx<'_>) -> Option<Bindings> {
379    let neighbor = matches!(rel.stop_by, CompiledStopBy::Neighbor);
380    let mut found: Option<Bindings> = None;
381    let mut cursor = node.walk();
382    let children: Vec<Node<'_>> = node.named_children(&mut cursor).collect();
383    for child in children {
384        if let Some(b) = node_satisfies(&rel.node, child, ctx) {
385            if field_ok(rel.field.as_deref(), node, child) {
386                found = Some(b);
387                break;
388            }
389        }
390        if !neighbor {
391            if let Some(b) = eval_has(rel, child, ctx) {
392                found = Some(b);
393                break;
394            }
395        }
396    }
397    found
398}
399
/// Direction in which `eval_sibling` scans a node's named siblings.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Dir {
    /// Scan earlier siblings (used for `follows`).
    Before,
    /// Scan later siblings (used for `precedes`).
    After,
}
404
405fn eval_sibling(rel: &CompiledRel, node: Node<'_>, ctx: &Ctx<'_>, dir: Dir) -> Option<Bindings> {
406    let neighbor = matches!(rel.stop_by, CompiledStopBy::Neighbor);
407    let mut sib = match dir {
408        Dir::Before => node.prev_named_sibling(),
409        Dir::After => node.next_named_sibling(),
410    };
411    while let Some(s) = sib {
412        if let Some(b) = node_satisfies(&rel.node, s, ctx) {
413            return Some(b);
414        }
415        if neighbor {
416            return None;
417        }
418        sib = match dir {
419            Dir::Before => s.prev_named_sibling(),
420            Dir::After => s.next_named_sibling(),
421        };
422    }
423    None
424}
425
426/// When a relation names a `field`, the related child must sit in that
427/// field of the parent. `parent` is the ancestor; `child` is the node on
428/// the path toward the matched node.
429fn field_ok(field: Option<&str>, parent: Node<'_>, child: Node<'_>) -> bool {
430    match field {
431        None => true,
432        Some(name) => parent
433            .child_by_field_name(name)
434            .is_some_and(|f| f.id() == child.id()),
435    }
436}
437
438fn merge(into: &mut Bindings, from: Bindings) {
439    for (k, v) in from {
440        into.entry(k).or_insert(v);
441    }
442}
443
444fn root_capture_index(query: &Query) -> Option<u32> {
445    query
446        .capture_names()
447        .iter()
448        .position(|n| *n == ROOT_CAPTURE)
449        .map(|i| i as u32)
450}
451
452fn for_each_named_descendant<'t>(node: Node<'t>, f: &mut impl FnMut(Node<'t>)) {
453    let mut cursor = node.walk();
454    for child in node.named_children(&mut cursor) {
455        f(child);
456        for_each_named_descendant(child, f);
457    }
458}