Skip to main content

amql_selector/
matcher.rs

1//! Match parsed selector ASTs against structured node data.
2
3use crate::selector::{Combinator, CompoundSelector, SelectorAst};
4use crate::types::{AttrName, TagName};
5use rustc_hash::FxHashMap;
6use serde_json::Value as JsonValue;
7
8/// A matchable node — works for both annotations and code elements.
9pub trait Matchable {
10    /// The element's tag name (e.g. "function", "controller").
11    fn tag(&self) -> &TagName;
12    /// Key-value attributes on this node.
13    fn attrs(&self) -> &FxHashMap<AttrName, JsonValue>;
14    /// Parent node, if any, for combinator matching.
15    fn parent(&self) -> Option<&dyn Matchable>;
16}
17
18/// Test if a single node matches a compound selector (tag + attributes).
19pub fn matches_compound(node: &dyn Matchable, compound: &CompoundSelector) -> bool {
20    if let Some(ref tag) = compound.tag {
21        if *node.tag() != **tag {
22            return false;
23        }
24    }
25
26    for pred in &compound.attrs {
27        if !matches_attr_predicate(node, pred) {
28            return false;
29        }
30    }
31
32    true
33}
34
35/// Filter an array of flat nodes by a selector AST.
36/// Returns indices into the input slice for matched nodes.
37#[must_use]
38pub fn filter_by_selector(nodes: &[&dyn Matchable], selector: &SelectorAst) -> Vec<usize> {
39    if selector.compounds.is_empty() {
40        return vec![];
41    }
42
43    if selector.compounds.len() == 1 {
44        return nodes
45            .iter()
46            .enumerate()
47            .filter(|(_, n)| matches_compound(**n, &selector.compounds[0]))
48            .map(|(i, _)| i)
49            .collect();
50    }
51
52    // Multi-compound: the last compound is the target
53    let target = &selector.compounds[selector.compounds.len() - 1];
54    nodes
55        .iter()
56        .enumerate()
57        .filter(|(_, node)| {
58            if !matches_compound(**node, target) {
59                return false;
60            }
61            verify_chain(
62                **node,
63                &selector.compounds,
64                selector.compounds.len() as isize - 2,
65            )
66        })
67        .map(|(i, _)| i)
68        .collect()
69}
70
71/// Filter nodes using parent indices for combinator matching.
72/// `parent_indices[i]` is the index of node i's parent in `nodes`, or `None`.
73#[must_use]
74pub fn filter_by_selector_indexed(
75    nodes: &[&dyn Matchable],
76    parent_indices: &[Option<usize>],
77    selector: &SelectorAst,
78) -> Vec<usize> {
79    if selector.compounds.is_empty() {
80        return vec![];
81    }
82
83    if selector.compounds.len() == 1 {
84        return nodes
85            .iter()
86            .enumerate()
87            .filter(|(_, n)| matches_compound(**n, &selector.compounds[0]))
88            .map(|(i, _)| i)
89            .collect();
90    }
91
92    let target = &selector.compounds[selector.compounds.len() - 1];
93    nodes
94        .iter()
95        .enumerate()
96        .filter(|(i, node)| {
97            if !matches_compound(**node, target) {
98                return false;
99            }
100            verify_chain_indexed(
101                nodes,
102                parent_indices,
103                *i,
104                &selector.compounds,
105                selector.compounds.len() as isize - 2,
106            )
107        })
108        .map(|(i, _)| i)
109        .collect()
110}
111
112fn verify_chain_indexed(
113    nodes: &[&dyn Matchable],
114    parent_indices: &[Option<usize>],
115    node_idx: usize,
116    compounds: &[CompoundSelector],
117    index: isize,
118) -> bool {
119    if index < 0 {
120        return true;
121    }
122    let idx = index as usize;
123    let compound = &compounds[idx];
124    let combinator = compounds
125        .get(idx + 1)
126        .and_then(|c| c.combinator)
127        .unwrap_or(Combinator::Descendant);
128
129    match combinator {
130        Combinator::Child => match parent_indices[node_idx] {
131            Some(parent_idx) => {
132                if !matches_compound(nodes[parent_idx], compound) {
133                    return false;
134                }
135                verify_chain_indexed(nodes, parent_indices, parent_idx, compounds, index - 1)
136            }
137            None => false,
138        },
139        Combinator::Descendant => {
140            let mut current = parent_indices[node_idx];
141            let mut depth = 0;
142            while let Some(ancestor_idx) = current {
143                depth += 1;
144                // Defensive: stop if depth exceeds node count (cycle in parent_indices)
145                if depth > nodes.len() {
146                    return false;
147                }
148                if matches_compound(nodes[ancestor_idx], compound)
149                    && verify_chain_indexed(
150                        nodes,
151                        parent_indices,
152                        ancestor_idx,
153                        compounds,
154                        index - 1,
155                    )
156                {
157                    return true;
158                }
159                current = parent_indices[ancestor_idx];
160            }
161            false
162        }
163        Combinator::AdjacentSibling | Combinator::GeneralSibling => false,
164    }
165}
166
167/// Maximum ancestor depth before bailing out (defensive against cyclic parent refs).
168const MAX_ANCESTOR_DEPTH: usize = 1024;
169
170fn verify_chain(node: &dyn Matchable, compounds: &[CompoundSelector], index: isize) -> bool {
171    if index < 0 {
172        return true;
173    }
174    let idx = index as usize;
175    let compound = &compounds[idx];
176    let combinator = compounds
177        .get(idx + 1)
178        .and_then(|c| c.combinator)
179        .unwrap_or(Combinator::Descendant);
180
181    match combinator {
182        Combinator::Child => {
183            // Direct parent must match
184            match node.parent() {
185                Some(parent) => {
186                    if !matches_compound(parent, compound) {
187                        return false;
188                    }
189                    verify_chain(parent, compounds, index - 1)
190                }
191                None => false,
192            }
193        }
194        Combinator::Descendant => {
195            // Any ancestor must match; backtrack if chain fails at a higher compound
196            let mut ancestor = node.parent();
197            let mut depth = 0;
198            while let Some(anc) = ancestor {
199                depth += 1;
200                // Defensive: stop if depth exceeds limit (cycle in parent refs)
201                if depth > MAX_ANCESTOR_DEPTH {
202                    return false;
203                }
204                if matches_compound(anc, compound) && verify_chain(anc, compounds, index - 1) {
205                    return true;
206                }
207                ancestor = anc.parent();
208            }
209            false
210        }
211        // Sibling combinators are rejected at parse time; this arm is defensive.
212        Combinator::AdjacentSibling | Combinator::GeneralSibling => false,
213    }
214}
215
216fn matches_attr_predicate(node: &dyn Matchable, pred: &amql_predicates::AttrPredicate) -> bool {
217    let value = node.attrs().get(pred.name.as_str());
218
219    // Bare attribute: [async] → just check existence and truthiness
220    if pred.op.is_none() {
221        return match value {
222            None => false,
223            Some(JsonValue::Null) => false,
224            Some(JsonValue::Bool(false)) => false,
225            Some(_) => true,
226        };
227    }
228
229    // With operator: delegate to amql-predicates
230    let node_str = value.map(crate::json_value_to_string).unwrap_or_default();
231
232    match &pred.value {
233        Some(pv) => amql_predicates::eval_op_typed(pred.op.unwrap(), &node_str, pv),
234        None => amql_predicates::eval_op(pred.op.unwrap(), &node_str, ""),
235    }
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use crate::selector::parse_selector;

    /// Simple test node implementation that owns its parent chain.
    struct TestNode {
        tag: TagName,
        attrs: FxHashMap<AttrName, JsonValue>,
        parent: Option<Box<TestNode>>,
    }

    impl TestNode {
        /// Build a parentless node with the given tag and attributes.
        fn new(tag: &str, attrs: FxHashMap<AttrName, JsonValue>) -> Self {
            Self {
                tag: TagName::from(tag),
                attrs,
                parent: None,
            }
        }

        /// Attach `parent` as this node's direct parent.
        fn with_parent(mut self, parent: TestNode) -> Self {
            self.parent = Some(Box::new(parent));
            self
        }
    }

    impl Matchable for TestNode {
        fn tag(&self) -> &TagName {
            &self.tag
        }
        fn attrs(&self) -> &FxHashMap<AttrName, JsonValue> {
            &self.attrs
        }
        fn parent(&self) -> Option<&dyn Matchable> {
            self.parent.as_ref().map(|p| p.as_ref() as &dyn Matchable)
        }
    }

    /// Build an attribute map from `(name, value)` pairs.
    fn attrs(pairs: &[(&str, JsonValue)]) -> FxHashMap<AttrName, JsonValue> {
        pairs
            .iter()
            .map(|(k, v)| (AttrName::from(*k), v.clone()))
            .collect()
    }

    #[test]
    fn matches_and_rejects_by_tag() {
        // Arrange
        let matching = TestNode::new("controller", FxHashMap::default());
        let wrong = TestNode::new("function", FxHashMap::default());
        let selector = parse_selector("controller").unwrap();
        let compound = &selector.compounds[0];

        // Act
        let hit = matches_compound(&matching, compound);
        let miss = matches_compound(&wrong, compound);

        // Assert
        assert!(hit, "node with correct tag should match");
        assert!(!miss, "node with wrong tag should not match");
    }

    #[test]
    fn matches_and_rejects_attributes() {
        // Arrange
        let present = TestNode::new("function", attrs(&[("async", JsonValue::Bool(true))]));
        let absent = TestNode::new("function", FxHashMap::default());
        let post = TestNode::new(
            "controller",
            attrs(&[("method", JsonValue::String("POST".to_string()))]),
        );
        let get = TestNode::new(
            "controller",
            attrs(&[("method", JsonValue::String("GET".to_string()))]),
        );
        let handle = TestNode::new(
            "function",
            attrs(&[("name", JsonValue::String("handleClick".to_string()))]),
        );
        let create = TestNode::new(
            "function",
            attrs(&[("name", JsonValue::String("createUser".to_string()))]),
        );

        let sel_presence = parse_selector("function[async]").unwrap();
        let sel_value = parse_selector(r#"controller[method="POST"]"#).unwrap();
        let sel_starts = parse_selector(r#"[name^="handle"]"#).unwrap();
        let sel_contains = parse_selector(r#"[name*="User"]"#).unwrap();

        // Act
        let presence_hit = matches_compound(&present, &sel_presence.compounds[0]);
        let presence_miss = matches_compound(&absent, &sel_presence.compounds[0]);
        let value_hit = matches_compound(&post, &sel_value.compounds[0]);
        let value_miss = matches_compound(&get, &sel_value.compounds[0]);
        let starts_hit = matches_compound(&handle, &sel_starts.compounds[0]);
        let starts_miss = matches_compound(&create, &sel_starts.compounds[0]);
        let contains_hit = matches_compound(&create, &sel_contains.compounds[0]);
        let contains_miss = matches_compound(&handle, &sel_contains.compounds[0]);

        // Assert
        assert!(presence_hit, "attribute present should match");
        assert!(!presence_miss, "missing attribute should not match");
        assert!(value_hit, "exact attribute value should match");
        assert!(!value_miss, "wrong attribute value should not match");
        assert!(starts_hit, "starts-with operator should match");
        assert!(!starts_miss, "starts-with operator should reject other prefixes");
        assert!(contains_hit, "contains operator should match");
        assert!(!contains_miss, "contains operator should reject absent substring");
    }

    #[test]
    fn bare_attribute_requires_truthy_value() {
        // Arrange
        let explicit_false =
            TestNode::new("function", attrs(&[("async", JsonValue::Bool(false))]));
        let null = TestNode::new("function", attrs(&[("async", JsonValue::Null)]));
        let string = TestNode::new(
            "function",
            attrs(&[("async", JsonValue::String("yes".to_string()))]),
        );
        let selector = parse_selector("function[async]").unwrap();
        let compound = &selector.compounds[0];

        // Act
        let false_result = matches_compound(&explicit_false, compound);
        let null_result = matches_compound(&null, compound);
        let string_result = matches_compound(&string, compound);

        // Assert
        assert!(!false_result, "bare attribute set to false should not match");
        assert!(!null_result, "bare attribute set to null should not match");
        assert!(string_result, "bare attribute with a non-boolean value should match");
    }

    #[test]
    fn filters_by_selector() {
        // Arrange
        let n1 = TestNode::new(
            "controller",
            attrs(&[("method", JsonValue::String("POST".to_string()))]),
        );
        let n2 = TestNode::new("react-hook", FxHashMap::default());
        let n3 = TestNode::new(
            "controller",
            attrs(&[("method", JsonValue::String("GET".to_string()))]),
        );
        let n4 = TestNode::new(
            "controller",
            attrs(&[("owner", JsonValue::String("@backend".to_string()))]),
        );
        let n5 = TestNode::new(
            "react-hook",
            attrs(&[("owner", JsonValue::String("@frontend".to_string()))]),
        );
        let n6 = TestNode::new("function", FxHashMap::default());

        let parent = TestNode::new(
            "class",
            attrs(&[("name", JsonValue::String("UserService".to_string()))]),
        );
        let child =
            TestNode::new("method", attrs(&[("async", JsonValue::Bool(true))])).with_parent(parent);
        let orphan = TestNode::new("method", attrs(&[("async", JsonValue::Bool(true))]));

        let sel_tag = parse_selector("controller").unwrap();
        let sel_tag_attr = parse_selector(r#"controller[method="POST"]"#).unwrap();
        let sel_attr_only = parse_selector(r#"[owner="@backend"]"#).unwrap();
        let sel_child = parse_selector("class > method[async]").unwrap();

        // Act
        let by_tag: Vec<&dyn Matchable> = vec![&n1, &n2, &n3];
        let tag_indices = filter_by_selector(&by_tag, &sel_tag);

        let by_tag_attr: Vec<&dyn Matchable> = vec![&n1, &n3];
        let tag_attr_indices = filter_by_selector(&by_tag_attr, &sel_tag_attr);

        let by_attr: Vec<&dyn Matchable> = vec![&n4, &n5, &n6];
        let attr_indices = filter_by_selector(&by_attr, &sel_attr_only);

        let by_child: Vec<&dyn Matchable> = vec![&child, &orphan];
        let child_indices = filter_by_selector(&by_child, &sel_child);

        // Assert
        assert_eq!(tag_indices, vec![0, 2], "should filter by tag");
        assert_eq!(
            tag_attr_indices,
            vec![0],
            "should filter by tag + attribute"
        );
        assert_eq!(
            attr_indices,
            vec![0],
            "should filter by attribute-only selector"
        );
        assert_eq!(
            child_indices,
            vec![0],
            "child combinator should match only parented node"
        );
    }

    #[test]
    fn descendant_backtracks_past_first_match() {
        // Arrange: outer[name="a"] > inner > inner > method
        // Selector: outer[name="a"] inner method
        // The first `inner` ancestor of `method` matches `inner`,
        // but its parent is also `inner` (not `outer`).
        // The matcher must backtrack to the second `inner` whose parent IS `outer`.
        let outer = TestNode::new("outer", attrs(&[("name", JsonValue::String("a".into()))]));
        let mid = TestNode::new("inner", FxHashMap::default()).with_parent(outer);
        let inner = TestNode::new("inner", FxHashMap::default()).with_parent(mid);
        let leaf = TestNode::new("method", FxHashMap::default()).with_parent(inner);

        let sel = parse_selector(r#"outer[name="a"] inner method"#).unwrap();

        // Act
        let nodes: Vec<&dyn Matchable> = vec![&leaf];
        let result = filter_by_selector(&nodes, &sel);

        // Assert
        assert_eq!(
            result,
            vec![0],
            "descendant matcher should backtrack past first matching ancestor"
        );
    }

    #[test]
    fn descendant_backtracks_indexed() {
        // Arrange: same tree as above but using filter_by_selector_indexed
        // outer[name="a"](0) > inner(1) > inner(2) > method(3)
        // The parent links live in `parent_indices`, not in the nodes themselves.
        let outer = TestNode::new("outer", attrs(&[("name", JsonValue::String("a".into()))]));
        let mid = TestNode::new("inner", FxHashMap::default());
        let inner = TestNode::new("inner", FxHashMap::default());
        let leaf = TestNode::new("method", FxHashMap::default());

        let nodes: Vec<&dyn Matchable> = vec![&outer, &mid, &inner, &leaf];
        let parent_indices: Vec<Option<usize>> = vec![None, Some(0), Some(1), Some(2)];

        let sel = parse_selector(r#"outer[name="a"] inner method"#).unwrap();

        // Act
        let result = filter_by_selector_indexed(&nodes, &parent_indices, &sel);

        // Assert
        assert_eq!(
            result,
            vec![3],
            "indexed descendant matcher should backtrack past first matching ancestor"
        );
    }

    #[test]
    fn indexed_descendant_terminates_on_cyclic_parents() {
        // Arrange: two nodes that name each other as parent; no `outer` exists.
        let method = TestNode::new("method", FxHashMap::default());
        let inner = TestNode::new("inner", FxHashMap::default());

        let nodes: Vec<&dyn Matchable> = vec![&method, &inner];
        let parent_indices: Vec<Option<usize>> = vec![Some(1), Some(0)];

        let sel = parse_selector("outer method").unwrap();

        // Act
        let result = filter_by_selector_indexed(&nodes, &parent_indices, &sel);

        // Assert
        assert!(
            result.is_empty(),
            "cyclic parent indices should end the search without a match"
        );
    }
}
463}