1use xee_xpath::context::StaticContextBuilder;
11use xee_xpath::{Documents, Item, Queries, Query};
12
/// Selects how [`evaluate_xpath`] shapes the items produced by an expression.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QueryMode {
    /// Report only how many items the expression produced.
    Count,
    /// Report the text of each item: text gathered from nodes is trimmed of
    /// surrounding whitespace, atomic values use their XPath string value.
    Text,
    /// Report each node item serialized as XML; non-node items are reported
    /// as their XPath string value.
    Xml,
}
23
/// Output of [`evaluate_xpath`]; the variant corresponds to the
/// [`QueryMode`] that was requested.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QueryResult {
    /// Number of items in the result sequence.
    Count(usize),
    /// One string per result item, in sequence order.
    Text(Vec<String>),
    /// One serialized string per result item, in sequence order.
    Xml(Vec<String>),
}
34
35pub fn evaluate_xpath(
46 xml: &str,
47 xpath_expr: &str,
48 mode: QueryMode,
49 namespaces: &[(&str, &str)],
50) -> Result<QueryResult, crate::Error> {
51 let mut documents = Documents::new();
52 let doc = documents
53 .add_string_without_uri(xml)
54 .map_err(|e| crate::Error::Query(format!("XML parse error: {e}")))?;
55
56 let doc_namespaces = {
58 let xot = documents.xot();
59 let doc_node = documents
60 .document_node(doc)
61 .ok_or_else(|| crate::Error::Query("missing document node".into()))?;
62 let mut ns_map = std::collections::HashMap::<String, String>::new();
63 collect_namespace_declarations(xot, doc_node, &mut ns_map);
64 ns_map.into_iter().collect::<Vec<_>>()
65 };
66
67 let mut ctx = StaticContextBuilder::default();
69 for (prefix, uri) in &doc_namespaces {
70 if !prefix.is_empty() && !uri.is_empty() {
71 ctx.add_namespace(prefix, uri);
72 }
73 }
74 ctx.namespaces(namespaces.iter().copied());
76
77 let queries = Queries::new(ctx);
78 let q = queries
79 .sequence(xpath_expr)
80 .map_err(|e| crate::Error::Query(format!("XPath compile error: {e}")))?;
81 let seq = q
82 .execute(&mut documents, doc)
83 .map_err(|e| crate::Error::Query(format!("XPath execution error: {e}")))?;
84
85 match mode {
86 QueryMode::Count => Ok(QueryResult::Count(seq.iter().count())),
87 QueryMode::Text => {
88 let xot = documents.xot();
89 let texts = seq
90 .iter()
91 .map(|item| match item {
92 Item::Node(n) => Ok(collect_all_text(xot, n)),
93 _ => item
94 .string_value(xot)
95 .map_err(|e| crate::Error::Query(format!("string value error: {e}"))),
96 })
97 .collect::<Result<Vec<_>, _>>()?;
98 Ok(QueryResult::Text(texts))
99 }
100 QueryMode::Xml => {
101 let xot = documents.xot();
102 let xmls = seq
103 .iter()
104 .map(|item| match item {
105 Item::Node(n) => Ok(xot.to_string(n).unwrap_or_default()),
106 _ => item
107 .string_value(xot)
108 .map_err(|e| crate::Error::Query(format!("string value error: {e}"))),
109 })
110 .collect::<Result<Vec<_>, _>>()?;
111 Ok(QueryResult::Xml(xmls))
112 }
113 }
114}
115
116fn collect_namespace_declarations(
121 xot: &xot::Xot,
122 node: xot::Node,
123 ns_map: &mut std::collections::HashMap<String, String>,
124) {
125 for (prefix_id, ns_id) in xot.namespaces(node).iter() {
126 let prefix = xot.prefix_str(prefix_id);
127 let uri = xot.namespace_str(*ns_id);
128 ns_map.entry(prefix.to_owned()).or_insert_with(|| uri.to_owned());
129 }
130 for ch in xot.children(node) {
131 collect_namespace_declarations(xot, ch, ns_map);
132 }
133}
134
135fn collect_all_text(xot: &xot::Xot, node: xot::Node) -> String {
137 let mut text = String::new();
138 collect_text_recursive(xot, node, &mut text);
139 text.trim().to_string()
140}
141
142fn collect_text_recursive(xot: &xot::Xot, node: xot::Node, out: &mut String) {
143 if let Some(t) = xot.text_str(node) {
144 out.push_str(t);
145 }
146 for child in xot.children(node) {
147 collect_text_recursive(xot, child, out);
148 }
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    /// Extracts the payload of a `Count` result; any other variant fails the test.
    fn expect_count(result: QueryResult) -> usize {
        match result {
            QueryResult::Count(n) => n,
            _ => panic!("expected Count"),
        }
    }

    /// Extracts the payload of a `Text` result; any other variant fails the test.
    fn expect_text(result: QueryResult) -> Vec<String> {
        match result {
            QueryResult::Text(texts) => texts,
            _ => panic!("expected Text"),
        }
    }

    /// Runs `xpath` over `xml` in count mode with the given extra bindings.
    fn count_matches(xml: &str, xpath: &str, namespaces: &[(&str, &str)]) -> usize {
        expect_count(evaluate_xpath(xml, xpath, QueryMode::Count, namespaces).unwrap())
    }

    /// Runs `xpath` over `xml` in text mode with no extra bindings.
    fn texts_of(xml: &str, xpath: &str) -> Vec<String> {
        expect_text(evaluate_xpath(xml, xpath, QueryMode::Text, &[]).unwrap())
    }

    #[test]
    fn auto_discovers_namespaces() {
        let xml = r#"<root xmlns:app="urn:test:app"><app:item id="1">hello</app:item></root>"#;
        assert_eq!(count_matches(xml, "//app:item", &[]), 1);
    }

    #[test]
    fn caller_namespace_overrides_document() {
        let xml = r#"<root xmlns:x="urn:a"><x:item>hello</x:item></root>"#;
        // With only the document's binding, x maps to urn:a and the element matches.
        assert_eq!(count_matches(xml, "//x:item", &[]), 1);
        // Rebinding x to urn:b leaves nothing in the document that matches.
        assert_eq!(count_matches(xml, "//x:item", &[("x", "urn:b")]), 0);
    }

    #[test]
    fn default_namespace_via_caller() {
        let xml = r#"<root xmlns="urn:example"><entry id="1">hello</entry></root>"#;
        // The default namespace is not auto-bound, so the caller supplies a prefix.
        assert_eq!(
            count_matches(xml, "//ex:entry", &[("ex", "urn:example")]),
            1
        );
    }

    #[test]
    fn text_mode() {
        let texts = texts_of(r"<root><item>hello</item></root>", "//item");
        assert_eq!(texts, vec!["hello"]);
    }

    #[test]
    fn discovers_namespaces_from_nested_elements() {
        let xml = r#"<root><app:item xmlns:app="urn:test:app" id="1">hello</app:item></root>"#;
        assert_eq!(count_matches(xml, "//app:item", &[]), 1);
    }

    #[test]
    fn count_function() {
        // An atomic result is reported through its string value.
        let texts = texts_of(r"<root><a/><a/><a/></root>", "count(//a)");
        assert_eq!(texts, vec!["3"]);
    }
}