nu_plugin_xpath/
xpath.rs

1use nu_errors::ShellError;
2use nu_protocol::{TaggedDictBuilder, UntaggedValue, Value};
3use nu_source::{Tag, Tagged};
4
5use bigdecimal::{BigDecimal, FromPrimitive};
6
7use sxd_document::parser;
8use sxd_xpath::{Context, Factory};
9
10pub struct Xpath {
11    pub query: String,
12    pub tag: Tag,
13}
14
15impl Xpath {
16    pub fn new() -> Xpath {
17        Xpath {
18            query: String::new(),
19            tag: Tag::unknown(),
20        }
21    }
22}
23
24impl Default for Xpath {
25    fn default() -> Self {
26        Self::new()
27    }
28}
29
30pub fn string_to_value(raw: String, query: Tagged<&str>) -> Result<Vec<Value>, ShellError> {
31    execute_xpath_query(raw, query.item.to_string(), query.tag())
32}
33
34fn execute_xpath_query(
35    input_string: String,
36    query_string: String,
37    tag: impl Into<Tag>,
38) -> Result<Vec<Value>, ShellError> {
39    let tag = tag.into();
40    let xpath = build_xpath(&query_string)?;
41
42    let package = parser::parse(&input_string);
43
44    if package.is_err() {
45        return Err(ShellError::labeled_error(
46            "invalid xml document",
47            "invalid xml document",
48            tag.span,
49        ));
50    }
51
52    let package = package.expect("invalid xml document");
53
54    let document = package.as_document();
55    let context = Context::new();
56
57    // leaving this here for augmentation at some point
58    // build_variables(&arguments, &mut context);
59    // build_namespaces(&arguments, &mut context);
60
61    let res = xpath.evaluate(&context, document.root());
62
63    // Some xpath statements can be long, so let's truncate it with ellipsis
64    let mut key = query_string.clone();
65    if query_string.len() >= 20 {
66        key.truncate(17);
67        key += "...";
68    } else {
69        key = query_string;
70    };
71
72    match res {
73        Ok(r) => {
74            let rows: Vec<Value> = match r {
75                sxd_xpath::Value::Nodeset(ns) => ns
76                    .into_iter()
77                    .map(|a| {
78                        let mut row = TaggedDictBuilder::new(Tag::unknown());
79                        row.insert_value(&key, UntaggedValue::string(a.string_value()));
80                        row.into_value()
81                    })
82                    .collect::<Vec<Value>>(),
83                sxd_xpath::Value::Boolean(b) => {
84                    let mut row = TaggedDictBuilder::new(Tag::unknown());
85                    row.insert_value(&key, UntaggedValue::boolean(b));
86                    vec![row.into_value()]
87                }
88                sxd_xpath::Value::Number(n) => {
89                    let mut row = TaggedDictBuilder::new(Tag::unknown());
90                    row.insert_value(
91                        &key,
92                        UntaggedValue::decimal(BigDecimal::from_f64(n).expect("error with f64"))
93                            .into_untagged_value(),
94                    );
95
96                    vec![row.into_value()]
97                }
98                sxd_xpath::Value::String(s) => {
99                    let mut row = TaggedDictBuilder::new(Tag::unknown());
100                    row.insert_value(&key, UntaggedValue::string(s));
101                    vec![row.into_value()]
102                }
103            };
104
105            Ok(rows)
106        }
107        Err(_) => Err(ShellError::labeled_error(
108            "xpath query error",
109            "xpath query error",
110            tag,
111        )),
112    }
113}
114
115fn build_xpath(xpath_str: &str) -> Result<sxd_xpath::XPath, ShellError> {
116    let factory = Factory::new();
117
118    match factory.build(xpath_str) {
119        Ok(xpath) => xpath.ok_or_else(|| ShellError::untagged_runtime_error("invalid xpath query")),
120        Err(_) => Err(ShellError::untagged_runtime_error(
121            "expected valid xpath query",
122        )),
123    }
124}
125
#[cfg(test)]
mod tests {
    use super::string_to_value as query;
    use nu_errors::ShellError;
    use nu_source::TaggedItem;
    use nu_test_support::value::{decimal_from_float, row};

    use indexmap::indexmap;

    /// `position()` inside a predicate selects the second child, so the
    /// count is 1 and the column name is the shortened query.
    #[test]
    fn position_function_in_predicate() -> Result<(), ShellError> {
        let xml = r#"<?xml version="1.0" encoding="UTF-8"?><a><b/><b/></a>"#.to_string();
        let expected = row(indexmap! {
            "count(//a/*[posit...".into() => decimal_from_float(1.0)
        });

        let rows = query(xml, "count(//a/*[position() = 2])".tagged_unknown())?;

        assert_eq!(rows[0], expected);
        Ok(())
    }

    /// The boolean argument to `contains` is coerced, so the `<a>true</a>`
    /// element matches and the count is 1.
    #[test]
    fn functions_implicitly_coerce_argument_types() -> Result<(), ShellError> {
        let xml = r#"<?xml version="1.0" encoding="UTF-8"?><a>true</a>"#.to_string();
        let expected = row(indexmap! {
            "count(//*[contain...".into() => decimal_from_float(1.0)
        });

        let rows = query(xml, "count(//*[contains(., true)])".tagged_unknown())?;

        assert_eq!(rows[0], expected);
        Ok(())
    }
}