1use std::collections::hash_map::Entry;
15use std::collections::{HashMap, HashSet};
16
17use roxmltree::{Document, Node};
18
19use super::types::{NodeSet, TransformData, TransformError};
20
/// Attribute names that are always treated as element identifiers.
///
/// Covers the common spellings seen in practice, e.g. SAML's `ID`, XML-DSig's `Id`, and a
/// plain lowercase `id`. Callers can register more names via
/// `UriReferenceResolver::with_id_attrs`.
const DEFAULT_ID_ATTRS: &[&str] = &[
    "ID",
    "Id",
    "id",
];
28
29pub struct UriReferenceResolver<'a> {
51 doc: &'a Document<'a>,
52 id_map: HashMap<&'a str, Node<'a, 'a>>,
54}
55
56impl<'a> UriReferenceResolver<'a> {
57 pub fn new(doc: &'a Document<'a>) -> Self {
59 Self::with_id_attrs(doc, DEFAULT_ID_ATTRS)
60 }
61
62 pub fn with_id_attrs(doc: &'a Document<'a>, extra_attrs: &[&str]) -> Self {
73 let mut id_map = HashMap::new();
74 let mut duplicate_ids: HashSet<&'a str> = HashSet::new();
77
78 let mut attr_names: Vec<&str> = DEFAULT_ID_ATTRS.to_vec();
80 for name in extra_attrs {
81 if !attr_names.contains(name) {
82 attr_names.push(name);
83 }
84 }
85
86 for node in doc.descendants() {
88 if node.is_element() {
89 for attr_name in &attr_names {
90 if let Some(value) = node.attribute(*attr_name) {
91 if duplicate_ids.contains(value) {
93 continue;
94 }
95
96 match id_map.entry(value) {
101 Entry::Vacant(v) => {
102 v.insert(node);
103 }
104 Entry::Occupied(o) => {
105 if o.get().id() != node.id() {
110 o.remove();
111 duplicate_ids.insert(value);
112 }
113 }
114 }
115 }
116 }
117 }
118 }
119
120 Self { doc, id_map }
121 }
122
123 pub fn dereference(&self, uri: &str) -> Result<TransformData<'a>, TransformError> {
135 if uri.is_empty() {
136 Ok(TransformData::NodeSet(
140 NodeSet::entire_document_without_comments(self.doc),
141 ))
142 } else if let Some(fragment) = uri.strip_prefix('#') {
143 self.dereference_fragment(fragment)
148 } else {
149 Err(TransformError::UnsupportedUri(uri.to_string()))
150 }
151 }
152
153 fn dereference_fragment(&self, fragment: &str) -> Result<TransformData<'a>, TransformError> {
160 if fragment.is_empty() {
161 return Err(TransformError::UnsupportedUri("#".to_string()));
163 }
164
165 if fragment == "xpointer(/)" {
166 Ok(TransformData::NodeSet(
170 NodeSet::entire_document_with_comments(self.doc),
171 ))
172 } else if let Some(id) = parse_xpointer_id(fragment) {
173 if id.is_empty() {
176 return Err(TransformError::UnsupportedUri(format!("#{fragment}")));
177 }
178 self.resolve_id(id)
179 } else if fragment.starts_with("xpointer(") {
180 Err(TransformError::UnsupportedUri(format!("#{fragment}")))
182 } else {
183 self.resolve_id(fragment)
185 }
186 }
187
188 fn resolve_id(&self, id: &str) -> Result<TransformData<'a>, TransformError> {
190 match self.id_map.get(id) {
191 Some(&element) => Ok(TransformData::NodeSet(NodeSet::subtree(element))),
192 None => Err(TransformError::ElementNotFound(id.to_string())),
193 }
194 }
195
196 pub fn has_id(&self, id: &str) -> bool {
198 self.id_map.contains_key(id)
199 }
200
201 pub fn id_count(&self) -> usize {
203 self.id_map.len()
204 }
205}
206
/// Extracts the ID from a fragment of the form `xpointer(id('ID'))` or `xpointer(id("ID"))`.
///
/// The argument must be a single XPath string literal. It must start and end with the same
/// quote character and must not contain that character in between, because XPath literals have
/// no way to escape their own delimiter.
///
/// Returns `None` for anything else, so the caller can report the fragment as unsupported. An
/// empty ID (`xpointer(id(''))`) is returned as `Some("")`; rejecting it is left to the caller.
fn parse_xpointer_id(fragment: &str) -> Option<&str> {
    let inner = fragment.strip_prefix("xpointer(id(")?.strip_suffix("))")?;

    let quote = match inner.chars().next()? {
        q @ ('\'' | '"') => q,
        _ => return None,
    };
    let id = inner.strip_prefix(quote)?.strip_suffix(quote)?;

    // A second delimiter inside means this is not one well-formed XPath literal.
    if id.contains(quote) {
        return None;
    }
    Some(id)
}
222
#[cfg(test)]
#[allow(clippy::unwrap_used)] // unwrap is the idiomatic way to fail a test on an unexpected Err.
mod tests {
    use super::super::types::NodeSet;
    use super::*;

    #[test]
    fn empty_uri_returns_whole_document() {
        let xml = "<root><child>text</child></root>";
        let doc = Document::parse(xml).unwrap();
        let resolver = UriReferenceResolver::new(&doc);

        let data = resolver.dereference("").unwrap();
        let node_set = data.into_node_set().unwrap();

        let root = doc.root_element();
        assert!(node_set.contains(root));
        let child = root.first_child().unwrap();
        assert!(node_set.contains(child));
    }

    #[test]
    fn empty_uri_excludes_comments() {
        let xml = "<root><!-- comment --><child/></root>";
        let doc = Document::parse(xml).unwrap();
        let resolver = UriReferenceResolver::new(&doc);

        let data = resolver.dereference("").unwrap();
        let node_set = data.into_node_set().unwrap();

        for node in doc.descendants() {
            if node.is_comment() {
                assert!(
                    !node_set.contains(node),
                    "comment should be excluded for empty URI"
                );
            }
        }
        assert!(node_set.contains(doc.root_element()));
    }

    #[test]
    fn fragment_uri_resolves_by_id_attr() {
        let xml = r#"<root><item ID="abc">content</item><item ID="def">other</item></root>"#;
        let doc = Document::parse(xml).unwrap();
        let resolver = UriReferenceResolver::new(&doc);

        let data = resolver.dereference("#abc").unwrap();
        let node_set = data.into_node_set().unwrap();

        let abc_elem = doc
            .descendants()
            .find(|n| n.attribute("ID") == Some("abc"))
            .unwrap();
        assert!(node_set.contains(abc_elem));

        // Descendants of the referenced element are part of the subtree.
        let text_child = abc_elem.first_child().unwrap();
        assert!(node_set.contains(text_child));

        // Ancestors and siblings are not.
        assert!(!node_set.contains(doc.root_element()));

        let def_elem = doc
            .descendants()
            .find(|n| n.attribute("ID") == Some("def"))
            .unwrap();
        assert!(!node_set.contains(def_elem));
    }

    #[test]
    fn fragment_uri_resolves_lowercase_id() {
        let xml = r#"<root><item id="lower">text</item></root>"#;
        let doc = Document::parse(xml).unwrap();
        let resolver = UriReferenceResolver::new(&doc);

        let data = resolver.dereference("#lower").unwrap();
        let node_set = data.into_node_set().unwrap();

        let elem = doc
            .descendants()
            .find(|n| n.attribute("id") == Some("lower"))
            .unwrap();
        assert!(node_set.contains(elem));
    }

    #[test]
    fn fragment_uri_resolves_mixed_case_id() {
        let xml = r#"<root><ds:Signature Id="sig1" xmlns:ds="http://www.w3.org/2000/09/xmldsig#"/></root>"#;
        let doc = Document::parse(xml).unwrap();
        let resolver = UriReferenceResolver::new(&doc);

        assert!(resolver.has_id("sig1"));
        let data = resolver.dereference("#sig1").unwrap();
        assert!(data.into_node_set().is_ok());
    }

    #[test]
    fn fragment_uri_not_found() {
        let xml = "<root><child>text</child></root>";
        let doc = Document::parse(xml).unwrap();
        let resolver = UriReferenceResolver::new(&doc);

        let result = resolver.dereference("#nonexistent");
        assert!(result.is_err());
        match result.unwrap_err() {
            TransformError::ElementNotFound(id) => assert_eq!(id, "nonexistent"),
            other => panic!("expected ElementNotFound, got: {other:?}"),
        }
    }

    #[test]
    fn unsupported_external_uri() {
        let xml = "<root/>";
        let doc = Document::parse(xml).unwrap();
        let resolver = UriReferenceResolver::new(&doc);

        let result = resolver.dereference("http://example.com/doc.xml");
        assert!(result.is_err());
        match result.unwrap_err() {
            TransformError::UnsupportedUri(uri) => {
                assert_eq!(uri, "http://example.com/doc.xml")
            }
            other => panic!("expected UnsupportedUri, got: {other:?}"),
        }
    }

    #[test]
    fn unsupported_xpointer_expression() {
        let xml = "<root/>";
        let doc = Document::parse(xml).unwrap();
        let resolver = UriReferenceResolver::new(&doc);

        let result = resolver.dereference("#xpointer(foo())");
        assert!(result.is_err());
        match result.unwrap_err() {
            TransformError::UnsupportedUri(uri) => {
                assert_eq!(uri, "#xpointer(foo())")
            }
            other => panic!("expected UnsupportedUri, got: {other:?}"),
        }

        // General XPath expressions inside xpointer() are not supported either.
        let result = resolver.dereference("#xpointer(//element)");
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            TransformError::UnsupportedUri(_)
        ));
    }

    #[test]
    fn empty_fragment_rejected() {
        let xml = "<root/>";
        let doc = Document::parse(xml).unwrap();
        let resolver = UriReferenceResolver::new(&doc);

        let result = resolver.dereference("#");
        assert!(result.is_err());
        match result.unwrap_err() {
            TransformError::UnsupportedUri(uri) => assert_eq!(uri, "#"),
            other => panic!("expected UnsupportedUri, got: {other:?}"),
        }
    }

    #[test]
    fn foreign_document_node_rejected() {
        let xml1 = "<root><child/></root>";
        let xml2 = "<other><item/></other>";
        let doc1 = Document::parse(xml1).unwrap();
        let doc2 = Document::parse(xml2).unwrap();

        let node_set = NodeSet::entire_document_without_comments(&doc1);

        let foreign_node = doc2.root_element();
        assert!(
            !node_set.contains(foreign_node),
            "foreign document node should be rejected"
        );

        let own_node = doc1.root_element();
        assert!(node_set.contains(own_node));
    }

    #[test]
    fn custom_id_attr_name() {
        let xml = r#"<root><elem myid="custom1">data</elem></root>"#;
        let doc = Document::parse(xml).unwrap();

        let resolver_default = UriReferenceResolver::new(&doc);
        assert!(!resolver_default.has_id("custom1"));

        let resolver_custom = UriReferenceResolver::with_id_attrs(&doc, &["myid"]);
        assert!(resolver_custom.has_id("custom1"));

        let data = resolver_custom.dereference("#custom1").unwrap();
        assert!(data.into_node_set().is_ok());
    }

    #[test]
    fn namespaced_id_attr_found_by_local_name() {
        let xml =
            r#"<root><elem wsu:Id="ts1" xmlns:wsu="http://example.com/wsu">data</elem></root>"#;
        let doc = Document::parse(xml).unwrap();

        let resolver = UriReferenceResolver::new(&doc);
        assert!(resolver.has_id("ts1"));
    }

    #[test]
    fn id_count_reports_unique_ids() {
        let xml = r#"<root ID="r1"><a ID="a1"/><b Id="b1"/><c id="c1"/></root>"#;
        let doc = Document::parse(xml).unwrap();
        let resolver = UriReferenceResolver::new(&doc);

        assert_eq!(resolver.id_count(), 4);
    }

    #[test]
    fn duplicate_ids_are_rejected() {
        let xml = r#"<root><a ID="dup">first</a><b ID="dup">second</b></root>"#;
        let doc = Document::parse(xml).unwrap();
        let resolver = UriReferenceResolver::new(&doc);

        assert!(!resolver.has_id("dup"));
        let result = resolver.dereference("#dup");
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            TransformError::ElementNotFound(_)
        ));
    }

    #[test]
    fn duplicate_ids_across_attr_spellings_are_rejected() {
        // `ID` on one element and `id` on another still collide on the same value.
        let xml = r#"<root><a ID="dup">first</a><b id="dup">second</b></root>"#;
        let doc = Document::parse(xml).unwrap();
        let resolver = UriReferenceResolver::new(&doc);

        assert!(!resolver.has_id("dup"));
        assert!(matches!(
            resolver.dereference("#dup"),
            Err(TransformError::ElementNotFound(_))
        ));
    }

    #[test]
    fn triple_duplicate_ids_stay_rejected() {
        let xml = r#"<root><a ID="dup">1</a><b ID="dup">2</b><c ID="dup">3</c></root>"#;
        let doc = Document::parse(xml).unwrap();
        let resolver = UriReferenceResolver::new(&doc);

        assert!(!resolver.has_id("dup"));
        assert!(resolver.dereference("#dup").is_err());
    }

    #[test]
    fn node_set_exclude_subtree() {
        let xml = r#"<root><keep>yes</keep><remove><deep>no</deep></remove></root>"#;
        let doc = Document::parse(xml).unwrap();
        let resolver = UriReferenceResolver::new(&doc);

        let data = resolver.dereference("").unwrap();
        let mut node_set = data.into_node_set().unwrap();

        let remove_elem = doc
            .descendants()
            .find(|n| n.is_element() && n.has_tag_name("remove"))
            .unwrap();
        node_set.exclude_subtree(remove_elem);

        let keep_elem = doc
            .descendants()
            .find(|n| n.is_element() && n.has_tag_name("keep"))
            .unwrap();
        assert!(node_set.contains(keep_elem));

        // Both the excluded element and its descendants are gone.
        assert!(!node_set.contains(remove_elem));
        let deep_elem = doc
            .descendants()
            .find(|n| n.is_element() && n.has_tag_name("deep"))
            .unwrap();
        assert!(!node_set.contains(deep_elem));
    }

    #[test]
    fn subtree_includes_comments() {
        let xml = r#"<root><item ID="x"><!-- comment --><child/></item></root>"#;
        let doc = Document::parse(xml).unwrap();
        let resolver = UriReferenceResolver::new(&doc);

        let data = resolver.dereference("#x").unwrap();
        let node_set = data.into_node_set().unwrap();

        for node in doc.descendants() {
            if node.is_comment() {
                assert!(
                    node_set.contains(node),
                    "comment should be included in #id subtree"
                );
            }
        }
    }

    #[test]
    fn xpointer_root_returns_whole_document_with_comments() {
        let xml = "<root><!-- comment --><child/></root>";
        let doc = Document::parse(xml).unwrap();
        let resolver = UriReferenceResolver::new(&doc);

        let data = resolver.dereference("#xpointer(/)").unwrap();
        let node_set = data.into_node_set().unwrap();

        for node in doc.descendants() {
            if node.is_comment() {
                assert!(
                    node_set.contains(node),
                    "comment should be included for #xpointer(/)"
                );
            }
        }
        assert!(node_set.contains(doc.root_element()));
    }

    #[test]
    fn xpointer_id_single_quotes() {
        let xml = r#"<root><item ID="abc">content</item></root>"#;
        let doc = Document::parse(xml).unwrap();
        let resolver = UriReferenceResolver::new(&doc);

        let data = resolver.dereference("#xpointer(id('abc'))").unwrap();
        let node_set = data.into_node_set().unwrap();

        let elem = doc
            .descendants()
            .find(|n| n.attribute("ID") == Some("abc"))
            .unwrap();
        assert!(node_set.contains(elem));
    }

    #[test]
    fn xpointer_id_double_quotes() {
        let xml = r#"<root><item ID="xyz">content</item></root>"#;
        let doc = Document::parse(xml).unwrap();
        let resolver = UriReferenceResolver::new(&doc);

        let data = resolver.dereference(r#"#xpointer(id("xyz"))"#).unwrap();
        let node_set = data.into_node_set().unwrap();

        let elem = doc
            .descendants()
            .find(|n| n.attribute("ID") == Some("xyz"))
            .unwrap();
        assert!(node_set.contains(elem));
    }

    #[test]
    fn xpointer_id_not_found() {
        let xml = "<root/>";
        let doc = Document::parse(xml).unwrap();
        let resolver = UriReferenceResolver::new(&doc);

        let result = resolver.dereference("#xpointer(id('missing'))");
        assert!(result.is_err());
        match result.unwrap_err() {
            TransformError::ElementNotFound(id) => assert_eq!(id, "missing"),
            other => panic!("expected ElementNotFound, got: {other:?}"),
        }
    }

    #[test]
    fn xpointer_id_empty_value_rejected() {
        let xml = "<root/>";
        let doc = Document::parse(xml).unwrap();
        let resolver = UriReferenceResolver::new(&doc);

        let result = resolver.dereference("#xpointer(id(''))");
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            TransformError::UnsupportedUri(_)
        ));
    }

    #[test]
    fn parse_xpointer_id_variants() {
        assert_eq!(super::parse_xpointer_id("xpointer(id('foo'))"), Some("foo"));
        assert_eq!(
            super::parse_xpointer_id(r#"xpointer(id("bar"))"#),
            Some("bar")
        );

        // Not an id() XPointer, or an unquoted argument.
        assert_eq!(super::parse_xpointer_id("xpointer(/)"), None);
        assert_eq!(super::parse_xpointer_id("xpointer(id(foo))"), None);
        assert_eq!(super::parse_xpointer_id("not-xpointer"), None);
        assert_eq!(super::parse_xpointer_id(""), None);

        // A lone quote is not a complete literal.
        assert_eq!(super::parse_xpointer_id("xpointer(id('))"), None);
        assert_eq!(super::parse_xpointer_id(r#"xpointer(id("))"#), None);
    }

    #[test]
    fn same_element_multiple_id_attrs_not_duplicate() {
        let xml = r#"<root><item ID="x" Id="x">data</item></root>"#;
        let doc = Document::parse(xml).unwrap();
        let resolver = UriReferenceResolver::new(&doc);

        assert!(resolver.has_id("x"));
        assert!(resolver.dereference("#x").is_ok());
    }

    #[test]
    fn saml_style_document() {
        let xml = r#"<samlp:Response xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"
                xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion"
                ID="_resp1">
  <saml:Assertion ID="_assert1">
    <saml:Subject>user@example.com</saml:Subject>
  </saml:Assertion>
  <ds:Signature xmlns:ds="http://www.w3.org/2000/09/xmldsig#" Id="sig1">
    <ds:SignedInfo/>
  </ds:Signature>
</samlp:Response>"#;

        let doc = Document::parse(xml).unwrap();
        let resolver = UriReferenceResolver::new(&doc);

        assert!(resolver.has_id("_resp1"));
        assert!(resolver.has_id("_assert1"));
        assert!(resolver.has_id("sig1"));
        assert_eq!(resolver.id_count(), 3);

        let data = resolver.dereference("#_assert1").unwrap();
        let node_set = data.into_node_set().unwrap();

        let assertion = doc
            .descendants()
            .find(|n| n.attribute("ID") == Some("_assert1"))
            .unwrap();
        assert!(node_set.contains(assertion));

        let subject = assertion
            .children()
            .find(|n| n.is_element() && n.has_tag_name("Subject"))
            .unwrap();
        assert!(node_set.contains(subject));

        // The enclosing Response is outside the referenced subtree.
        assert!(!node_set.contains(doc.root_element()));
    }
}