Skip to main content

clayers_xml/
query.rs

1//! Shared `XPath` 3.1 evaluation via xee-xpath.
2//!
3//! All `!Send` xee-xpath types are created and dropped within [`evaluate_xpath`],
4//! keeping them invisible to async callers.
5//!
//! Namespace prefixes used in the `XPath` expression are automatically
//! discovered from the namespace declarations on every element of the XML
//! document, not only the root. Callers may supply additional bindings that
//! override or supplement these.
9
10use xee_xpath::context::StaticContextBuilder;
11use xee_xpath::{Documents, Item, Queries, Query};
12
/// Output mode for `XPath` queries.
///
/// Selects which [`QueryResult`] variant a query produces.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QueryMode {
    /// Return the number of items in the result sequence (for a
    /// node-selecting expression, the number of matching nodes).
    Count,
    /// Return the text content of each matching node, or the string value of
    /// each non-node item.
    Text,
    /// Return the serialized XML of each matching node, or the string value
    /// of each non-node item.
    Xml,
}
23
/// Result of an `XPath` query.
///
/// Each variant corresponds to the [`QueryMode`] of the same name.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QueryResult {
    /// Number of items in the result sequence.
    Count(usize),
    /// Text content (or string value) of each result item, in sequence order.
    Text(Vec<String>),
    /// Serialized XML (or string value) of each result item, in sequence order.
    Xml(Vec<String>),
}
34
35/// Evaluate an `XPath` 3.1 expression against an XML string.
36///
37/// Namespace prefixes are discovered automatically from the XML document's
38/// root element. Additional `namespaces` bindings are merged on top
39/// (overriding any conflicting prefix from the document).
40///
41/// # Errors
42///
43/// Returns an error if the XML cannot be parsed, the `XPath` cannot be compiled,
44/// or execution fails.
45pub fn evaluate_xpath(
46    xml: &str,
47    xpath_expr: &str,
48    mode: QueryMode,
49    namespaces: &[(&str, &str)],
50) -> Result<QueryResult, crate::Error> {
51    let mut documents = Documents::new();
52    let doc = documents
53        .add_string_without_uri(xml)
54        .map_err(|e| crate::Error::Query(format!("XML parse error: {e}")))?;
55
56    // Discover namespace declarations from all elements in the document.
57    let doc_namespaces = {
58        let xot = documents.xot();
59        let doc_node = documents
60            .document_node(doc)
61            .ok_or_else(|| crate::Error::Query("missing document node".into()))?;
62        let mut ns_map = std::collections::HashMap::<String, String>::new();
63        collect_namespace_declarations(xot, doc_node, &mut ns_map);
64        ns_map.into_iter().collect::<Vec<_>>()
65    };
66
67    // Build the static context: document namespaces first, caller overrides on top.
68    let mut ctx = StaticContextBuilder::default();
69    for (prefix, uri) in &doc_namespaces {
70        if !prefix.is_empty() && !uri.is_empty() {
71            ctx.add_namespace(prefix, uri);
72        }
73    }
74    // Caller-provided namespaces override document ones.
75    ctx.namespaces(namespaces.iter().copied());
76
77    let queries = Queries::new(ctx);
78    let q = queries
79        .sequence(xpath_expr)
80        .map_err(|e| crate::Error::Query(format!("XPath compile error: {e}")))?;
81    let seq = q
82        .execute(&mut documents, doc)
83        .map_err(|e| crate::Error::Query(format!("XPath execution error: {e}")))?;
84
85    match mode {
86        QueryMode::Count => Ok(QueryResult::Count(seq.iter().count())),
87        QueryMode::Text => {
88            let xot = documents.xot();
89            let texts = seq
90                .iter()
91                .map(|item| match item {
92                    Item::Node(n) => Ok(collect_all_text(xot, n)),
93                    _ => item
94                        .string_value(xot)
95                        .map_err(|e| crate::Error::Query(format!("string value error: {e}"))),
96                })
97                .collect::<Result<Vec<_>, _>>()?;
98            Ok(QueryResult::Text(texts))
99        }
100        QueryMode::Xml => {
101            let xot = documents.xot();
102            let xmls = seq
103                .iter()
104                .map(|item| match item {
105                    Item::Node(n) => Ok(xot.to_string(n).unwrap_or_default()),
106                    _ => item
107                        .string_value(xot)
108                        .map_err(|e| crate::Error::Query(format!("string value error: {e}"))),
109                })
110                .collect::<Result<Vec<_>, _>>()?;
111            Ok(QueryResult::Xml(xmls))
112        }
113    }
114}
115
116/// Recursively collect all namespace declarations from a node and its descendants.
117///
118/// First declaration wins: if a prefix is declared on multiple elements,
119/// the one closest to the root is kept.
120fn collect_namespace_declarations(
121    xot: &xot::Xot,
122    node: xot::Node,
123    ns_map: &mut std::collections::HashMap<String, String>,
124) {
125    for (prefix_id, ns_id) in xot.namespaces(node).iter() {
126        let prefix = xot.prefix_str(prefix_id);
127        let uri = xot.namespace_str(*ns_id);
128        ns_map.entry(prefix.to_owned()).or_insert_with(|| uri.to_owned());
129    }
130    for ch in xot.children(node) {
131        collect_namespace_declarations(xot, ch, ns_map);
132    }
133}
134
135/// Collect all text content from a node and its descendants.
136fn collect_all_text(xot: &xot::Xot, node: xot::Node) -> String {
137    let mut text = String::new();
138    collect_text_recursive(xot, node, &mut text);
139    text.trim().to_string()
140}
141
142fn collect_text_recursive(xot: &xot::Xot, node: xot::Node, out: &mut String) {
143    if let Some(t) = xot.text_str(node) {
144        out.push_str(t);
145    }
146    for child in xot.children(node) {
147        collect_text_recursive(xot, child, out);
148    }
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    /// Run `xpath` in count mode and return the resulting count.
    fn count_of(xml: &str, xpath: &str, ns: &[(&str, &str)]) -> usize {
        match evaluate_xpath(xml, xpath, QueryMode::Count, ns).unwrap() {
            QueryResult::Count(n) => n,
            _ => panic!("expected Count"),
        }
    }

    /// Run `xpath` in text mode (no caller namespaces) and return the texts.
    fn texts_of(xml: &str, xpath: &str) -> Vec<String> {
        match evaluate_xpath(xml, xpath, QueryMode::Text, &[]).unwrap() {
            QueryResult::Text(texts) => texts,
            _ => panic!("expected Text"),
        }
    }

    #[test]
    fn auto_discovers_namespaces() {
        // The prefix is declared on the root only; the caller supplies nothing.
        let xml = r#"<root xmlns:app="urn:test:app"><app:item id="1">hello</app:item></root>"#;
        assert_eq!(count_of(xml, "//app:item", &[]), 1);
    }

    #[test]
    fn caller_namespace_overrides_document() {
        // The document binds x -> urn:a; the caller may rebind x -> urn:b.
        let xml = r#"<root xmlns:x="urn:a"><x:item>hello</x:item></root>"#;

        // Document binding in effect: the element is found.
        assert_eq!(count_of(xml, "//x:item", &[]), 1);

        // Caller binding in effect: nothing lives in urn:b.
        assert_eq!(count_of(xml, "//x:item", &[("x", "urn:b")]), 0);
    }

    #[test]
    fn default_namespace_via_caller() {
        // Elements are in a default (unprefixed) namespace; the caller names it.
        let xml = r#"<root xmlns="urn:example"><entry id="1">hello</entry></root>"#;
        assert_eq!(count_of(xml, "//ex:entry", &[("ex", "urn:example")]), 1);
    }

    #[test]
    fn text_mode() {
        let xml = r"<root><item>hello</item></root>";
        assert_eq!(texts_of(xml, "//item"), vec!["hello"]);
    }

    #[test]
    fn discovers_namespaces_from_nested_elements() {
        // The declaration sits on a child element rather than on the root.
        let xml = r#"<root><app:item xmlns:app="urn:test:app" id="1">hello</app:item></root>"#;
        assert_eq!(count_of(xml, "//app:item", &[]), 1);
    }

    #[test]
    fn count_function() {
        // A non-node result (a number) is rendered through its string value.
        let xml = r"<root><a/><a/><a/></root>";
        let texts = texts_of(xml, "count(//a)");
        assert_eq!(texts.len(), 1);
        assert_eq!(texts[0], "3");
    }
}