1use nu_errors::ShellError;
2use nu_protocol::{TaggedDictBuilder, UntaggedValue, Value};
3use nu_source::{Tag, Tagged};
4
5use bigdecimal::{BigDecimal, FromPrimitive};
6
7use sxd_document::parser;
8use sxd_xpath::{Context, Factory};
9
10pub struct Xpath {
11 pub query: String,
12 pub tag: Tag,
13}
14
15impl Xpath {
16 pub fn new() -> Xpath {
17 Xpath {
18 query: String::new(),
19 tag: Tag::unknown(),
20 }
21 }
22}
23
24impl Default for Xpath {
25 fn default() -> Self {
26 Self::new()
27 }
28}
29
30pub fn string_to_value(raw: String, query: Tagged<&str>) -> Result<Vec<Value>, ShellError> {
31 execute_xpath_query(raw, query.item.to_string(), query.tag())
32}
33
34fn execute_xpath_query(
35 input_string: String,
36 query_string: String,
37 tag: impl Into<Tag>,
38) -> Result<Vec<Value>, ShellError> {
39 let tag = tag.into();
40 let xpath = build_xpath(&query_string)?;
41
42 let package = parser::parse(&input_string);
43
44 if package.is_err() {
45 return Err(ShellError::labeled_error(
46 "invalid xml document",
47 "invalid xml document",
48 tag.span,
49 ));
50 }
51
52 let package = package.expect("invalid xml document");
53
54 let document = package.as_document();
55 let context = Context::new();
56
57 let res = xpath.evaluate(&context, document.root());
62
63 let mut key = query_string.clone();
65 if query_string.len() >= 20 {
66 key.truncate(17);
67 key += "...";
68 } else {
69 key = query_string;
70 };
71
72 match res {
73 Ok(r) => {
74 let rows: Vec<Value> = match r {
75 sxd_xpath::Value::Nodeset(ns) => ns
76 .into_iter()
77 .map(|a| {
78 let mut row = TaggedDictBuilder::new(Tag::unknown());
79 row.insert_value(&key, UntaggedValue::string(a.string_value()));
80 row.into_value()
81 })
82 .collect::<Vec<Value>>(),
83 sxd_xpath::Value::Boolean(b) => {
84 let mut row = TaggedDictBuilder::new(Tag::unknown());
85 row.insert_value(&key, UntaggedValue::boolean(b));
86 vec![row.into_value()]
87 }
88 sxd_xpath::Value::Number(n) => {
89 let mut row = TaggedDictBuilder::new(Tag::unknown());
90 row.insert_value(
91 &key,
92 UntaggedValue::decimal(BigDecimal::from_f64(n).expect("error with f64"))
93 .into_untagged_value(),
94 );
95
96 vec![row.into_value()]
97 }
98 sxd_xpath::Value::String(s) => {
99 let mut row = TaggedDictBuilder::new(Tag::unknown());
100 row.insert_value(&key, UntaggedValue::string(s));
101 vec![row.into_value()]
102 }
103 };
104
105 Ok(rows)
106 }
107 Err(_) => Err(ShellError::labeled_error(
108 "xpath query error",
109 "xpath query error",
110 tag,
111 )),
112 }
113}
114
115fn build_xpath(xpath_str: &str) -> Result<sxd_xpath::XPath, ShellError> {
116 let factory = Factory::new();
117
118 match factory.build(xpath_str) {
119 Ok(xpath) => xpath.ok_or_else(|| ShellError::untagged_runtime_error("invalid xpath query")),
120 Err(_) => Err(ShellError::untagged_runtime_error(
121 "expected valid xpath query",
122 )),
123 }
124}
125
#[cfg(test)]
mod tests {
    use super::string_to_value as query;
    use nu_errors::ShellError;
    use nu_source::TaggedItem;
    use nu_test_support::value::{decimal_from_float, row};

    use indexmap::indexmap;

    /// Prefixes `body` with the XML declaration shared by every test document.
    fn document(body: &str) -> String {
        format!(r#"<?xml version="1.0" encoding="UTF-8"?>{}"#, body)
    }

    #[test]
    fn position_function_in_predicate() -> Result<(), ShellError> {
        let rows = query(
            document("<a><b/><b/></a>"),
            "count(//a/*[position() = 2])".tagged_unknown(),
        )?;

        let expected = row(indexmap! { "count(//a/*[posit...".into() => decimal_from_float(1.0) });
        assert_eq!(rows[0], expected);

        Ok(())
    }

    #[test]
    fn functions_implicitly_coerce_argument_types() -> Result<(), ShellError> {
        let rows = query(
            document("<a>true</a>"),
            "count(//*[contains(., true)])".tagged_unknown(),
        )?;

        let expected = row(indexmap! { "count(//*[contain...".into() => decimal_from_float(1.0) });
        assert_eq!(rows[0], expected);

        Ok(())
    }
}