use std::collections::hash_map::Entry;
use std::collections::{HashMap, HashSet};
use roxmltree::{Document, Node, NodeId};
use super::types::{NodeSet, TransformData, TransformError};
/// Attribute names that are always scanned as ID attributes when the
/// resolver builds its ID index; callers may add further names on top.
const DEFAULT_ID_ATTRS: &[&str] = &["ID", "Id", "id"];
/// Resolves same-document URI references (empty URI, `#id`, and a small set
/// of `#xpointer(...)` forms) against a parsed XML document.
pub struct UriReferenceResolver<'a> {
/// Document that references are resolved against.
doc: &'a Document<'a>,
/// Index from ID attribute value to the element carrying it. Values found
/// on more than one element are removed from the index during construction.
id_map: HashMap<&'a str, Node<'a, 'a>>,
}
impl<'a> UriReferenceResolver<'a> {
pub fn new(doc: &'a Document<'a>) -> Self {
Self::with_id_attrs(doc, DEFAULT_ID_ATTRS)
}
pub fn with_id_attrs(doc: &'a Document<'a>, extra_attrs: &[&str]) -> Self {
let mut id_map = HashMap::new();
let mut duplicate_ids: HashSet<&'a str> = HashSet::new();
let mut attr_names: Vec<&str> = DEFAULT_ID_ATTRS.to_vec();
for name in extra_attrs {
if !attr_names.contains(name) {
attr_names.push(name);
}
}
for node in doc.descendants() {
if node.is_element() {
for attr_name in &attr_names {
if let Some(value) = node.attribute(*attr_name) {
if duplicate_ids.contains(value) {
continue;
}
match id_map.entry(value) {
Entry::Vacant(v) => {
v.insert(node);
}
Entry::Occupied(o) => {
if o.get().id() != node.id() {
o.remove();
duplicate_ids.insert(value);
}
}
}
}
}
}
}
Self { doc, id_map }
}
pub fn dereference(&self, uri: &str) -> Result<TransformData<'a>, TransformError> {
if uri.is_empty() {
Ok(TransformData::NodeSet(
NodeSet::entire_document_without_comments(self.doc),
))
} else if let Some(fragment) = uri.strip_prefix('#') {
self.dereference_fragment(fragment)
} else {
Err(TransformError::UnsupportedUri(uri.to_string()))
}
}
fn dereference_fragment(&self, fragment: &str) -> Result<TransformData<'a>, TransformError> {
if fragment.is_empty() {
return Err(TransformError::UnsupportedUri("#".to_string()));
}
if fragment == "xpointer(/)" {
Ok(TransformData::NodeSet(
NodeSet::entire_document_with_comments(self.doc),
))
} else if let Some(id) = parse_xpointer_id_fragment(fragment) {
if id.is_empty() {
return Err(TransformError::UnsupportedUri(format!("#{fragment}")));
}
self.resolve_id(id)
} else if fragment.starts_with("xpointer(") {
Err(TransformError::UnsupportedUri(format!("#{fragment}")))
} else {
self.resolve_id(fragment)
}
}
fn resolve_id(&self, id: &str) -> Result<TransformData<'a>, TransformError> {
match self.id_map.get(id) {
Some(&element) => Ok(TransformData::NodeSet(NodeSet::subtree(element))),
None => Err(TransformError::ElementNotFound(id.to_string())),
}
}
pub fn has_id(&self, id: &str) -> bool {
self.id_map.contains_key(id)
}
pub(crate) fn node_id_for_id(&self, id: &str) -> Option<NodeId> {
self.id_map.get(id).map(|node| node.id())
}
pub fn id_count(&self) -> usize {
self.id_map.len()
}
}
/// Extracts the ID from an `xpointer(id('...'))` or `xpointer(id("..."))`
/// fragment (the text after the leading `#`).
///
/// Returns `None` when the fragment does not have that exact wrapper or when
/// the argument is not enclosed in a matching pair of single or double
/// quotes. An empty quoted argument yields `Some("")`; callers decide how to
/// treat it.
pub(crate) fn parse_xpointer_id_fragment(fragment: &str) -> Option<&str> {
    let argument = fragment
        .strip_prefix("xpointer(id(")
        .and_then(|rest| rest.strip_suffix("))"))?;

    // Single quotes are tried before double quotes; the first quote
    // character that wraps the whole argument wins.
    ['\'', '"']
        .iter()
        .find_map(|&quote| argument.strip_prefix(quote)?.strip_suffix(quote))
}
// Unit tests for `UriReferenceResolver` and `parse_xpointer_id_fragment`.
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
use super::super::types::NodeSet;
use super::*;
// An empty URI selects the root element and its children.
#[test]
fn empty_uri_returns_whole_document() {
let xml = "<root><child>text</child></root>";
let doc = Document::parse(xml).unwrap();
let resolver = UriReferenceResolver::new(&doc);
let data = resolver.dereference("").unwrap();
let node_set = data.into_node_set().unwrap();
let root = doc.root_element();
assert!(node_set.contains(root));
let child = root.first_child().unwrap();
assert!(node_set.contains(child));
}
// An empty URI must leave comment nodes out of the node set.
#[test]
fn empty_uri_excludes_comments() {
let xml = "<root><!-- comment --><child/></root>";
let doc = Document::parse(xml).unwrap();
let resolver = UriReferenceResolver::new(&doc);
let data = resolver.dereference("").unwrap();
let node_set = data.into_node_set().unwrap();
for node in doc.descendants() {
if node.is_comment() {
assert!(
!node_set.contains(node),
"comment should be excluded for empty URI"
);
}
}
assert!(node_set.contains(doc.root_element()));
}
// `#abc` selects the element with ID="abc" and its children, but neither
// its parent nor its sibling.
#[test]
fn fragment_uri_resolves_by_id_attr() {
let xml = r#"<root><item ID="abc">content</item><item ID="def">other</item></root>"#;
let doc = Document::parse(xml).unwrap();
let resolver = UriReferenceResolver::new(&doc);
let data = resolver.dereference("#abc").unwrap();
let node_set = data.into_node_set().unwrap();
let abc_elem = doc
.descendants()
.find(|n| n.attribute("ID") == Some("abc"))
.unwrap();
assert!(node_set.contains(abc_elem));
let text_child = abc_elem.first_child().unwrap();
assert!(node_set.contains(text_child));
assert!(!node_set.contains(doc.root_element()));
let def_elem = doc
.descendants()
.find(|n| n.attribute("ID") == Some("def"))
.unwrap();
assert!(!node_set.contains(def_elem));
}
// The lowercase `id` attribute is one of the default ID attribute names.
#[test]
fn fragment_uri_resolves_lowercase_id() {
let xml = r#"<root><item id="lower">text</item></root>"#;
let doc = Document::parse(xml).unwrap();
let resolver = UriReferenceResolver::new(&doc);
let data = resolver.dereference("#lower").unwrap();
let node_set = data.into_node_set().unwrap();
let elem = doc
.descendants()
.find(|n| n.attribute("id") == Some("lower"))
.unwrap();
assert!(node_set.contains(elem));
}
// The mixed-case `Id` attribute is one of the default ID attribute names.
#[test]
fn fragment_uri_resolves_mixed_case_id() {
let xml = r#"<root><ds:Signature Id="sig1" xmlns:ds="http://www.w3.org/2000/09/xmldsig#"/></root>"#;
let doc = Document::parse(xml).unwrap();
let resolver = UriReferenceResolver::new(&doc);
assert!(resolver.has_id("sig1"));
let data = resolver.dereference("#sig1").unwrap();
assert!(data.into_node_set().is_ok());
}
// An unknown ID yields `ElementNotFound` carrying the requested ID.
#[test]
fn fragment_uri_not_found() {
let xml = "<root><child>text</child></root>";
let doc = Document::parse(xml).unwrap();
let resolver = UriReferenceResolver::new(&doc);
let result = resolver.dereference("#nonexistent");
assert!(result.is_err());
match result.unwrap_err() {
TransformError::ElementNotFound(id) => assert_eq!(id, "nonexistent"),
other => panic!("expected ElementNotFound, got: {other:?}"),
}
}
// A URI that is neither empty nor a fragment yields `UnsupportedUri`
// carrying the URI unchanged.
#[test]
fn unsupported_external_uri() {
let xml = "<root/>";
let doc = Document::parse(xml).unwrap();
let resolver = UriReferenceResolver::new(&doc);
let result = resolver.dereference("http://example.com/doc.xml");
assert!(result.is_err());
match result.unwrap_err() {
TransformError::UnsupportedUri(uri) => {
assert_eq!(uri, "http://example.com/doc.xml")
}
other => panic!("expected UnsupportedUri, got: {other:?}"),
}
}
// xpointer expressions other than the supported forms are rejected rather
// than being treated as bare IDs.
#[test]
fn unsupported_xpointer_expression() {
let xml = "<root/>";
let doc = Document::parse(xml).unwrap();
let resolver = UriReferenceResolver::new(&doc);
let result = resolver.dereference("#xpointer(foo())");
assert!(result.is_err());
match result.unwrap_err() {
TransformError::UnsupportedUri(uri) => {
assert_eq!(uri, "#xpointer(foo())")
}
other => panic!("expected UnsupportedUri, got: {other:?}"),
}
let result = resolver.dereference("#xpointer(//element)");
assert!(result.is_err());
assert!(matches!(
result.unwrap_err(),
TransformError::UnsupportedUri(_)
));
}
// A lone `#` yields `UnsupportedUri("#")`.
#[test]
fn empty_fragment_rejected() {
let xml = "<root/>";
let doc = Document::parse(xml).unwrap();
let resolver = UriReferenceResolver::new(&doc);
let result = resolver.dereference("#");
assert!(result.is_err());
match result.unwrap_err() {
TransformError::UnsupportedUri(uri) => assert_eq!(uri, "#"),
other => panic!("expected UnsupportedUri, got: {other:?}"),
}
}
// A node set built from one document must not report a node from another
// document as a member.
#[test]
fn foreign_document_node_rejected() {
let xml1 = "<root><child/></root>";
let xml2 = "<other><item/></other>";
let doc1 = Document::parse(xml1).unwrap();
let doc2 = Document::parse(xml2).unwrap();
let node_set = NodeSet::entire_document_without_comments(&doc1);
let foreign_node = doc2.root_element();
assert!(
!node_set.contains(foreign_node),
"foreign document node should be rejected"
);
let own_node = doc1.root_element();
assert!(node_set.contains(own_node));
}
// A non-default attribute name is indexed only when passed to `with_id_attrs`.
#[test]
fn custom_id_attr_name() {
let xml = r#"<root><elem myid="custom1">data</elem></root>"#;
let doc = Document::parse(xml).unwrap();
let resolver_default = UriReferenceResolver::new(&doc);
assert!(!resolver_default.has_id("custom1"));
let resolver_custom = UriReferenceResolver::with_id_attrs(&doc, &["myid"]);
assert!(resolver_custom.has_id("custom1"));
let data = resolver_custom.dereference("#custom1").unwrap();
assert!(data.into_node_set().is_ok());
}
// Expects an ID attribute that carries a namespace prefix (`wsu:Id`) to be
// indexed under its value.
#[test]
fn namespaced_id_attr_found_by_local_name() {
let xml =
r#"<root><elem wsu:Id="ts1" xmlns:wsu="http://example.com/wsu">data</elem></root>"#;
let doc = Document::parse(xml).unwrap();
let resolver = UriReferenceResolver::new(&doc);
assert!(resolver.has_id("ts1"));
}
// Four distinct IDs spread across the three default attribute spellings.
#[test]
fn id_count_reports_unique_ids() {
let xml = r#"<root ID="r1"><a ID="a1"/><b Id="b1"/><c id="c1"/></root>"#;
let doc = Document::parse(xml).unwrap();
let resolver = UriReferenceResolver::new(&doc);
assert_eq!(resolver.id_count(), 4);
}
// An ID shared by two elements is dropped from the index and cannot be
// dereferenced.
#[test]
fn duplicate_ids_are_rejected() {
let xml = r#"<root><a ID="dup">first</a><b ID="dup">second</b></root>"#;
let doc = Document::parse(xml).unwrap();
let resolver = UriReferenceResolver::new(&doc);
assert!(!resolver.has_id("dup"));
let result = resolver.dereference("#dup");
assert!(result.is_err());
assert!(matches!(
result.unwrap_err(),
TransformError::ElementNotFound(_)
));
}
// A third occurrence must not re-register an ID already dropped as a duplicate.
#[test]
fn triple_duplicate_ids_stay_rejected() {
let xml = r#"<root><a ID="dup">1</a><b ID="dup">2</b><c ID="dup">3</c></root>"#;
let doc = Document::parse(xml).unwrap();
let resolver = UriReferenceResolver::new(&doc);
assert!(!resolver.has_id("dup"));
assert!(resolver.dereference("#dup").is_err());
}
// `exclude_subtree` removes the given element and its descendants while
// leaving siblings in the set.
#[test]
fn node_set_exclude_subtree() {
let xml = r#"<root><keep>yes</keep><remove><deep>no</deep></remove></root>"#;
let doc = Document::parse(xml).unwrap();
let resolver = UriReferenceResolver::new(&doc);
let data = resolver.dereference("").unwrap();
let mut node_set = data.into_node_set().unwrap();
let remove_elem = doc
.descendants()
.find(|n| n.is_element() && n.has_tag_name("remove"))
.unwrap();
node_set.exclude_subtree(remove_elem);
let keep_elem = doc
.descendants()
.find(|n| n.is_element() && n.has_tag_name("keep"))
.unwrap();
assert!(node_set.contains(keep_elem));
assert!(!node_set.contains(remove_elem));
let deep_elem = doc
.descendants()
.find(|n| n.is_element() && n.has_tag_name("deep"))
.unwrap();
assert!(!node_set.contains(deep_elem));
}
// A `#id` subtree keeps the comments inside the referenced element.
#[test]
fn subtree_includes_comments() {
let xml = r#"<root><item ID="x"><!-- comment --><child/></item></root>"#;
let doc = Document::parse(xml).unwrap();
let resolver = UriReferenceResolver::new(&doc);
let data = resolver.dereference("#x").unwrap();
let node_set = data.into_node_set().unwrap();
for node in doc.descendants() {
if node.is_comment() {
assert!(
node_set.contains(node),
"comment should be included in #id subtree"
);
}
}
}
// `#xpointer(/)` selects the whole document and, unlike the empty URI,
// keeps comments.
#[test]
fn xpointer_root_returns_whole_document_with_comments() {
let xml = "<root><!-- comment --><child/></root>";
let doc = Document::parse(xml).unwrap();
let resolver = UriReferenceResolver::new(&doc);
let data = resolver.dereference("#xpointer(/)").unwrap();
let node_set = data.into_node_set().unwrap();
for node in doc.descendants() {
if node.is_comment() {
assert!(
node_set.contains(node),
"comment should be included for #xpointer(/)"
);
}
}
assert!(node_set.contains(doc.root_element()));
}
// `#xpointer(id('...'))` with single quotes resolves like a bare ID.
#[test]
fn xpointer_id_single_quotes() {
let xml = r#"<root><item ID="abc">content</item></root>"#;
let doc = Document::parse(xml).unwrap();
let resolver = UriReferenceResolver::new(&doc);
let data = resolver.dereference("#xpointer(id('abc'))").unwrap();
let node_set = data.into_node_set().unwrap();
let elem = doc
.descendants()
.find(|n| n.attribute("ID") == Some("abc"))
.unwrap();
assert!(node_set.contains(elem));
}
// `#xpointer(id("..."))` with double quotes resolves like a bare ID.
#[test]
fn xpointer_id_double_quotes() {
let xml = r#"<root><item ID="xyz">content</item></root>"#;
let doc = Document::parse(xml).unwrap();
let resolver = UriReferenceResolver::new(&doc);
let data = resolver.dereference(r#"#xpointer(id("xyz"))"#).unwrap();
let node_set = data.into_node_set().unwrap();
let elem = doc
.descendants()
.find(|n| n.attribute("ID") == Some("xyz"))
.unwrap();
assert!(node_set.contains(elem));
}
// An unknown xpointer ID reports the bare ID, not the whole expression.
#[test]
fn xpointer_id_not_found() {
let xml = "<root/>";
let doc = Document::parse(xml).unwrap();
let resolver = UriReferenceResolver::new(&doc);
let result = resolver.dereference("#xpointer(id('missing'))");
assert!(result.is_err());
match result.unwrap_err() {
TransformError::ElementNotFound(id) => assert_eq!(id, "missing"),
other => panic!("expected ElementNotFound, got: {other:?}"),
}
}
// An empty quoted ID inside `xpointer(id(...))` is rejected as unsupported.
#[test]
fn xpointer_id_empty_value_rejected() {
let xml = "<root/>";
let doc = Document::parse(xml).unwrap();
let resolver = UriReferenceResolver::new(&doc);
let result = resolver.dereference("#xpointer(id(''))");
assert!(result.is_err());
assert!(matches!(
result.unwrap_err(),
TransformError::UnsupportedUri(_)
));
}
// Exercises the fragment parser directly: both quote styles, unquoted
// arguments, non-xpointer input, and a lone (unmatched) quote character.
#[test]
fn parse_xpointer_id_variants() {
assert_eq!(
super::parse_xpointer_id_fragment("xpointer(id('foo'))"),
Some("foo")
);
assert_eq!(
super::parse_xpointer_id_fragment(r#"xpointer(id("bar"))"#),
Some("bar")
);
assert_eq!(super::parse_xpointer_id_fragment("xpointer(/)"), None);
assert_eq!(super::parse_xpointer_id_fragment("xpointer(id(foo))"), None); assert_eq!(super::parse_xpointer_id_fragment("not-xpointer"), None);
assert_eq!(super::parse_xpointer_id_fragment(""), None);
assert_eq!(super::parse_xpointer_id_fragment("xpointer(id('))"), None);
assert_eq!(
super::parse_xpointer_id_fragment(r#"xpointer(id("))"#),
None
);
}
// The same value in two ID attributes of one element is not a duplicate.
#[test]
fn same_element_multiple_id_attrs_not_duplicate() {
let xml = r#"<root><item ID="x" Id="x">data</item></root>"#;
let doc = Document::parse(xml).unwrap();
let resolver = UriReferenceResolver::new(&doc);
assert!(resolver.has_id("x"));
assert!(resolver.dereference("#x").is_ok());
}
// End-to-end check on a SAML-shaped document: three IDs are indexed and
// `#_assert1` selects the assertion subtree without the enclosing response.
#[test]
fn saml_style_document() {
let xml = r#"<samlp:Response xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"
xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion"
ID="_resp1">
<saml:Assertion ID="_assert1">
<saml:Subject>user@example.com</saml:Subject>
</saml:Assertion>
<ds:Signature xmlns:ds="http://www.w3.org/2000/09/xmldsig#" Id="sig1">
<ds:SignedInfo/>
</ds:Signature>
</samlp:Response>"#;
let doc = Document::parse(xml).unwrap();
let resolver = UriReferenceResolver::new(&doc);
assert!(resolver.has_id("_resp1"));
assert!(resolver.has_id("_assert1"));
assert!(resolver.has_id("sig1"));
assert_eq!(resolver.id_count(), 3);
let data = resolver.dereference("#_assert1").unwrap();
let node_set = data.into_node_set().unwrap();
let assertion = doc
.descendants()
.find(|n| n.attribute("ID") == Some("_assert1"))
.unwrap();
assert!(node_set.contains(assertion));
let subject = assertion
.children()
.find(|n| n.is_element() && n.has_tag_name("Subject"))
.unwrap();
assert!(node_set.contains(subject));
assert!(!node_set.contains(doc.root_element()));
}
}