Skip to main content

fhp_selector/xpath/
eval.rs

1//! XPath expression evaluator.
2//!
3//! Walks the DOM tree and collects nodes or text matching an [`XPathExpr`](crate::xpath::ast::XPathExpr).
4
5use fhp_core::tag::Tag;
6use fhp_tree::arena::Arena;
7use fhp_tree::node::{NodeFlags, NodeId};
8
9use super::ast::{PathStep, Predicate, XPathExpr, XPathResult};
10
11/// Evaluate an XPath expression against an arena starting from `root`.
12pub fn evaluate(expr: &XPathExpr, arena: &Arena, root: NodeId) -> XPathResult {
13    match expr {
14        XPathExpr::DescendantByTag(tag) => {
15            let nodes = find_descendants_by_tag(arena, root, *tag);
16            XPathResult::Nodes(nodes)
17        }
18
19        XPathExpr::DescendantByAttr { tag, attr, value } => {
20            let nodes = find_descendants_by_tag_and_attr(arena, root, *tag, attr, Some(value));
21            XPathResult::Nodes(nodes)
22        }
23
24        XPathExpr::DescendantByAttrExists { tag, attr } => {
25            let nodes = find_descendants_by_tag_and_attr_exists(arena, root, *tag, attr);
26            XPathResult::Nodes(nodes)
27        }
28
29        XPathExpr::ContainsPredicate { tag, attr, substr } => {
30            let nodes = find_descendants_by_tag_and_contains(arena, root, *tag, attr, substr);
31            XPathResult::Nodes(nodes)
32        }
33
34        XPathExpr::PositionPredicate { tag, pos } => {
35            let nodes = find_descendants_by_tag(arena, root, *tag);
36            if *pos >= 1 && *pos <= nodes.len() {
37                XPathResult::Nodes(vec![nodes[*pos - 1]])
38            } else {
39                XPathResult::Nodes(vec![])
40            }
41        }
42
43        XPathExpr::AbsolutePath(steps) => {
44            let nodes = evaluate_absolute_path(arena, root, steps);
45            XPathResult::Nodes(nodes)
46        }
47
48        XPathExpr::TextExtract(inner) => {
49            let inner_result = evaluate(inner, arena, root);
50            match inner_result {
51                XPathResult::Nodes(nodes) => {
52                    let texts: Vec<String> =
53                        nodes.iter().map(|&id| collect_text(arena, id)).collect();
54                    XPathResult::Strings(texts)
55                }
56                other => other,
57            }
58        }
59
60        XPathExpr::DescendantWildcard => {
61            let nodes = find_all_elements(arena, root);
62            XPathResult::Nodes(nodes)
63        }
64
65        XPathExpr::DescendantWildcardByAttr { attr, value } => {
66            let nodes = find_all_elements_by_attr(arena, root, attr, Some(value));
67            XPathResult::Nodes(nodes)
68        }
69
70        XPathExpr::DescendantWildcardByAttrExists { attr } => {
71            let nodes = find_all_elements_by_attr(arena, root, attr, None);
72            XPathResult::Nodes(nodes)
73        }
74
75        XPathExpr::Parent => {
76            // Parent is relative — from root's parent.
77            let n = arena.get(root);
78            if n.parent.is_null() {
79                XPathResult::Nodes(vec![])
80            } else {
81                XPathResult::Nodes(vec![n.parent])
82            }
83        }
84    }
85}
86
87// ---------------------------------------------------------------------------
88// Internal helpers
89// ---------------------------------------------------------------------------
90
91/// Returns `true` for element nodes (not text, comment, doctype).
92#[inline]
93fn is_element(n: &fhp_tree::node::Node) -> bool {
94    !n.flags.has(NodeFlags::IS_TEXT)
95        && !n.flags.has(NodeFlags::IS_COMMENT)
96        && !n.flags.has(NodeFlags::IS_DOCTYPE)
97}
98
99/// Generic DFS: collect descendant nodes that satisfy `predicate`.
100fn dfs_collect(
101    arena: &Arena,
102    node: NodeId,
103    predicate: &dyn Fn(&Arena, NodeId, &fhp_tree::node::Node) -> bool,
104    results: &mut Vec<NodeId>,
105) {
106    if node.is_null() {
107        return;
108    }
109    let n = arena.get(node);
110    if predicate(arena, node, n) {
111        results.push(node);
112    }
113    let mut child = n.first_child;
114    while !child.is_null() {
115        dfs_collect(arena, child, predicate, results);
116        child = arena.get(child).next_sibling;
117    }
118}
119
120/// DFS: find all descendant elements with a specific tag.
121fn find_descendants_by_tag(arena: &Arena, root: NodeId, tag: Tag) -> Vec<NodeId> {
122    let mut results = Vec::new();
123    dfs_collect(
124        arena,
125        root,
126        &|_, _, n| is_element(n) && n.tag == tag,
127        &mut results,
128    );
129    results
130}
131
132/// DFS: find descendants by tag with exact attribute match (single pass).
133fn find_descendants_by_tag_and_attr(
134    arena: &Arena,
135    root: NodeId,
136    tag: Tag,
137    attr: &str,
138    value: Option<&str>,
139) -> Vec<NodeId> {
140    let mut results = Vec::new();
141    dfs_collect(
142        arena,
143        root,
144        &|a, id, n| {
145            is_element(n)
146                && n.tag == tag
147                && a.attrs(id).iter().any(|at| {
148                    a.attr_name(at).eq_ignore_ascii_case(attr) && a.attr_value(at) == value
149                })
150        },
151        &mut results,
152    );
153    results
154}
155
156/// DFS: find descendants by tag with attribute existence (single pass).
157fn find_descendants_by_tag_and_attr_exists(
158    arena: &Arena,
159    root: NodeId,
160    tag: Tag,
161    attr: &str,
162) -> Vec<NodeId> {
163    let mut results = Vec::new();
164    dfs_collect(
165        arena,
166        root,
167        &|a, id, n| {
168            is_element(n)
169                && n.tag == tag
170                && a.attrs(id)
171                    .iter()
172                    .any(|at| a.attr_name(at).eq_ignore_ascii_case(attr))
173        },
174        &mut results,
175    );
176    results
177}
178
179/// DFS: find descendants by tag with contains predicate (single pass).
180fn find_descendants_by_tag_and_contains(
181    arena: &Arena,
182    root: NodeId,
183    tag: Tag,
184    attr: &str,
185    substr: &str,
186) -> Vec<NodeId> {
187    let mut results = Vec::new();
188    dfs_collect(
189        arena,
190        root,
191        &|a, id, n| {
192            is_element(n)
193                && n.tag == tag
194                && a.attrs(id).iter().any(|at| {
195                    a.attr_name(at).eq_ignore_ascii_case(attr)
196                        && a.attr_value(at).is_some_and(|v| v.contains(substr))
197                })
198        },
199        &mut results,
200    );
201    results
202}
203
204/// DFS: find all descendant elements with optional attribute filter (single pass).
205fn find_all_elements_by_attr(
206    arena: &Arena,
207    root: NodeId,
208    attr: &str,
209    value: Option<&str>,
210) -> Vec<NodeId> {
211    let mut results = Vec::new();
212    dfs_collect(
213        arena,
214        root,
215        &|a, id, n| {
216            if !is_element(n) || n.depth == 0 {
217                return false;
218            }
219            match value {
220                Some(val) => a.attrs(id).iter().any(|at| {
221                    a.attr_name(at).eq_ignore_ascii_case(attr) && a.attr_value(at) == Some(val)
222                }),
223                None => a
224                    .attrs(id)
225                    .iter()
226                    .any(|at| a.attr_name(at).eq_ignore_ascii_case(attr)),
227            }
228        },
229        &mut results,
230    );
231    results
232}
233
234/// DFS: find all descendant elements.
235fn find_all_elements(arena: &Arena, root: NodeId) -> Vec<NodeId> {
236    let mut results = Vec::new();
237    dfs_collect(
238        arena,
239        root,
240        &|_, _, n| is_element(n) && n.depth > 0,
241        &mut results,
242    );
243    results
244}
245
246/// Evaluate an absolute path from the root.
247///
248/// Uses a single reusable buffer to avoid per-step Vec allocations.
249/// Children are expanded in-place using a swap buffer.
250fn evaluate_absolute_path(arena: &Arena, root: NodeId, steps: &[PathStep]) -> Vec<NodeId> {
251    if steps.is_empty() {
252        return vec![];
253    }
254
255    // Start from the root's children (the root itself is a synthetic wrapper).
256    let mut current = Vec::new();
257    collect_element_children(arena, root, &mut current);
258
259    let mut next = Vec::new();
260    let last_idx = steps.len() - 1;
261
262    for (i, step) in steps.iter().enumerate() {
263        next.clear();
264        for &node_id in &current {
265            let n = arena.get(node_id);
266            if !is_element(n) || n.tag != step.tag {
267                continue;
268            }
269            if let Some(ref pred) = step.predicate {
270                if !matches_predicate(arena, node_id, pred) {
271                    continue;
272                }
273            }
274            next.push(node_id);
275        }
276
277        if next.is_empty() {
278            return vec![];
279        }
280
281        if i < last_idx {
282            // Expand to children of matched nodes for the next step.
283            current.clear();
284            for &nid in &next {
285                collect_element_children(arena, nid, &mut current);
286            }
287        } else {
288            std::mem::swap(&mut current, &mut next);
289        }
290    }
291
292    current
293}
294
295/// Collect direct element children of a node into `out` (no allocation).
296#[inline]
297fn collect_element_children(arena: &Arena, node: NodeId, out: &mut Vec<NodeId>) {
298    let n = arena.get(node);
299    let mut child = n.first_child;
300    while !child.is_null() {
301        let c = arena.get(child);
302        if is_element(c) {
303            out.push(child);
304        }
305        child = c.next_sibling;
306    }
307}
308
309/// Check if a node matches a predicate.
310fn matches_predicate(arena: &Arena, node: NodeId, pred: &Predicate) -> bool {
311    match pred {
312        Predicate::AttrEquals { attr, value } => arena.attrs(node).iter().any(|a| {
313            arena.attr_name(a).eq_ignore_ascii_case(attr) && arena.attr_value(a) == Some(value)
314        }),
315
316        Predicate::Contains { attr, substr } => arena.attrs(node).iter().any(|a| {
317            arena.attr_name(a).eq_ignore_ascii_case(attr)
318                && arena
319                    .attr_value(a)
320                    .is_some_and(|v| v.contains(substr.as_str()))
321        }),
322
323        Predicate::Position(pos) => {
324            // 1-based position among siblings of same type.
325            let n = arena.get(node);
326            if n.parent.is_null() {
327                return *pos == 1;
328            }
329            let parent = arena.get(n.parent);
330            let mut child = parent.first_child;
331            let mut idx = 0usize;
332            while !child.is_null() {
333                let c = arena.get(child);
334                if is_element(c) && c.tag == n.tag {
335                    idx += 1;
336                    if child == node {
337                        return idx == *pos;
338                    }
339                }
340                child = c.next_sibling;
341            }
342            false
343        }
344
345        Predicate::AttrExists { attr } => arena
346            .attrs(node)
347            .iter()
348            .any(|a| arena.attr_name(a).eq_ignore_ascii_case(attr)),
349    }
350}
351
352/// Recursively collect text content from a node.
353fn collect_text(arena: &Arena, node: NodeId) -> String {
354    let mut out = String::new();
355    collect_text_inner(arena, node, &mut out);
356    out
357}
358
359/// DFS text collection helper.
360fn collect_text_inner(arena: &Arena, node: NodeId, out: &mut String) {
361    let n = arena.get(node);
362    if n.flags.has(NodeFlags::IS_TEXT) {
363        out.push_str(arena.text(node));
364        return;
365    }
366    let mut child = n.first_child;
367    while !child.is_null() {
368        collect_text_inner(arena, child, out);
369        child = arena.get(child).next_sibling;
370    }
371}
372
#[cfg(test)]
mod tests {
    use super::*;
    use crate::xpath::parser::parse_xpath;

    /// Parse `html` and `xpath`, then evaluate the expression from the
    /// document root.
    fn eval(html: &str, xpath: &str) -> XPathResult {
        let document = fhp_tree::parse(html).unwrap();
        let parsed = parse_xpath(xpath).unwrap();
        evaluate(&parsed, document.arena(), document.root_id())
    }

    /// Unwrap a node-list result, failing the test on any other variant.
    fn expect_nodes(result: XPathResult) -> Vec<NodeId> {
        match result {
            XPathResult::Nodes(nodes) => nodes,
            _ => panic!("expected Nodes"),
        }
    }

    /// Unwrap a string-list result, failing the test on any other variant.
    fn expect_strings(result: XPathResult) -> Vec<String> {
        match result {
            XPathResult::Strings(texts) => texts,
            _ => panic!("expected Strings"),
        }
    }

    #[test]
    fn eval_descendant_tag() {
        let nodes = expect_nodes(eval("<div><p>Hello</p><p>World</p></div>", "//p"));
        assert_eq!(nodes.len(), 2);
    }

    #[test]
    fn eval_descendant_attr() {
        let nodes = expect_nodes(eval(
            "<a href=\"x\">a</a><a href=\"y\">b</a>",
            "//a[@href='x']",
        ));
        assert_eq!(nodes.len(), 1);
    }

    #[test]
    fn eval_descendant_attr_exists() {
        let nodes = expect_nodes(eval("<a href=\"x\">a</a><span>b</span>", "//a[@href]"));
        assert_eq!(nodes.len(), 1);
    }

    #[test]
    fn eval_contains() {
        let nodes = expect_nodes(eval(
            "<div class=\"nav-main\">a</div><div class=\"footer\">b</div>",
            "//div[contains(@class, 'nav')]",
        ));
        assert_eq!(nodes.len(), 1);
    }

    #[test]
    fn eval_position() {
        let nodes = expect_nodes(eval(
            "<ul><li>1</li><li>2</li><li>3</li></ul>",
            "//li[position()=2]",
        ));
        assert_eq!(nodes.len(), 1);
    }

    #[test]
    fn eval_position_shorthand() {
        let nodes = expect_nodes(eval("<ul><li>1</li><li>2</li><li>3</li></ul>", "//li[1]"));
        assert_eq!(nodes.len(), 1);
    }

    #[test]
    fn eval_text_extract() {
        let texts = expect_strings(eval("<div><p>Hello</p><p>World</p></div>", "//p/text()"));
        assert_eq!(texts.len(), 2);
        assert_eq!(texts[0], "Hello");
        assert_eq!(texts[1], "World");
    }

    #[test]
    fn eval_absolute_path() {
        let nodes = expect_nodes(eval(
            "<html><body><div>content</div></body></html>",
            "/html/body/div",
        ));
        assert_eq!(nodes.len(), 1);
    }

    #[test]
    fn eval_absolute_path_text() {
        let texts = expect_strings(eval(
            "<html><body><p>text</p></body></html>",
            "/html/body/p/text()",
        ));
        assert_eq!(texts.len(), 1);
        assert_eq!(texts[0], "text");
    }

    #[test]
    fn eval_wildcard() {
        let nodes = expect_nodes(eval("<div><p>a</p><span>b</span></div>", "//*"));
        assert!(nodes.len() >= 3);
    }

    #[test]
    fn eval_wildcard_attr() {
        let nodes = expect_nodes(eval(
            "<div id=\"main\">a</div><span>b</span>",
            "//*[@id='main']",
        ));
        assert_eq!(nodes.len(), 1);
    }

    #[test]
    fn eval_empty_result() {
        let nodes = expect_nodes(eval("<div>text</div>", "//span"));
        assert!(nodes.is_empty());
    }

    #[test]
    fn eval_position_out_of_range() {
        let nodes = expect_nodes(eval("<ul><li>1</li></ul>", "//li[position()=5]"));
        assert!(nodes.is_empty());
    }
}
519}