use std::collections::HashSet;
use std::fmt;
use crate::tree::{Document, NodeId, NodeKind};
/// Namespace URI identifying XInclude elements (W3C XInclude 1.0).
pub const XINCLUDE_NS: &str = "http://www.w3.org/2001/XInclude";
/// Local name of the element that gets replaced by included content.
const INCLUDE_ELEMENT: &str = "include";
/// Local name of the `xi:include` child whose contents are used when the referenced
/// resource cannot be retrieved or parsed.
const FALLBACK_ELEMENT: &str = "fallback";
/// Tuning knobs for [`process_xincludes`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct XIncludeOptions {
    /// Maximum nesting of included documents. An `xi:include` encountered at this depth
    /// or deeper is reported as an error instead of being expanded, so `0` rejects
    /// every inclusion.
    pub max_depth: usize,
}

impl XIncludeOptions {
    /// Nesting limit used by [`XIncludeOptions::default`].
    pub const DEFAULT_MAX_DEPTH: usize = 50;
}

impl Default for XIncludeOptions {
    fn default() -> Self {
        Self {
            max_depth: Self::DEFAULT_MAX_DEPTH,
        }
    }
}
/// A problem found while expanding one `xi:include` element.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct XIncludeError {
    /// Human-readable description of the problem.
    pub message: String,
    /// The `href` of the offending element, when it had one and it was known.
    pub href: Option<String>,
}

impl fmt::Display for XIncludeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match &self.href {
            Some(href) => write!(f, "XInclude error for '{href}': {}", self.message),
            None => write!(f, "XInclude error: {}", self.message),
        }
    }
}

/// Lets callers box the error as `dyn Error` or propagate it with `?`.
impl std::error::Error for XIncludeError {}
/// Outcome of a [`process_xincludes`] run.
#[derive(Debug, Clone, Default)]
pub struct XIncludeResult {
    /// Number of `xi:include` elements replaced by resolved content. Elements satisfied
    /// by their `xi:fallback` instead are not counted.
    pub inclusions: usize,
    /// Problems encountered during processing, in the order they were found.
    pub errors: Vec<XIncludeError>,
}
/// Expands every XInclude `include` element in `doc`, in place.
///
/// `resolver` maps an `href` (with any `#fragment` removed) to the resource's text. It
/// returns `None` when the resource cannot be retrieved, in which case the element's
/// `xi:fallback` child is used if it has one. Content included with the default
/// `parse="xml"` is scanned for further includes, up to `options.max_depth` levels of
/// nesting.
///
/// Problems never abort the walk. Each one is recorded in [`XIncludeResult::errors`] and
/// the offending `xi:include` element is removed. [`XIncludeResult::inclusions`] counts
/// the elements that were successfully replaced by resolved content.
pub fn process_xincludes<F>(
    doc: &mut Document,
    resolver: F,
    options: &XIncludeOptions,
) -> XIncludeResult
where
    F: Fn(&str) -> Option<String>,
{
    let root = doc.root();
    let mut state = ProcessingState {
        max_depth: options.max_depth,
        active_hrefs: HashSet::new(),
        errors: Vec::new(),
        inclusions: 0,
    };
    process_node(doc, root, &resolver, &mut state, 0);

    let ProcessingState {
        inclusions, errors, ..
    } = state;
    XIncludeResult { inclusions, errors }
}
/// Mutable bookkeeping shared by every step of one [`process_xincludes`] run.
struct ProcessingState {
    /// Nesting limit copied from [`XIncludeOptions::max_depth`].
    max_depth: usize,
    /// Resource hrefs (fragment stripped) whose content is currently being expanded.
    /// Meeting one of these again means the inclusion chain loops.
    active_hrefs: HashSet<String>,
    /// Every problem found so far, in discovery order.
    errors: Vec<XIncludeError>,
    /// Number of `xi:include` elements replaced by resolved content.
    inclusions: usize,
}
/// Walks the children of `node`. Every `xi:include` child is expanded and every other
/// child is descended into.
///
/// `node` itself is not examined. The child list is snapshotted up front because
/// expanding an include rewrites the tree while the walk is in progress.
fn process_node<F>(
    doc: &mut Document,
    node: NodeId,
    resolver: &F,
    state: &mut ProcessingState,
    depth: usize,
) where
    F: Fn(&str) -> Option<String>,
{
    let pending: Vec<NodeId> = doc.children(node).collect();
    for current in pending {
        if !is_xinclude_element(doc, current) {
            process_node(doc, current, resolver, state, depth);
            continue;
        }
        process_include_element(doc, current, resolver, state, depth);
    }
}
/// Returns `true` when `node` is an element with local name `include` in the
/// XInclude namespace ([`XINCLUDE_NS`]).
///
/// Matching uses the element's namespace URI rather than its prefix, so a plain
/// `<include>` outside that namespace is not treated as an inclusion.
fn is_xinclude_element(doc: &Document, node: NodeId) -> bool {
    match &doc.node(node).kind {
        NodeKind::Element {
            name, namespace, ..
        } => name == INCLUDE_ELEMENT && namespace.as_deref() == Some(XINCLUDE_NS),
        _ => false,
    }
}
/// Returns `true` when `node` is an element with local name `fallback` in the
/// XInclude namespace ([`XINCLUDE_NS`]).
fn is_fallback_element(doc: &Document, node: NodeId) -> bool {
    let NodeKind::Element {
        name, namespace, ..
    } = &doc.node(node).kind
    else {
        return false;
    };
    namespace.as_deref() == Some(XINCLUDE_NS) && name == FALLBACK_ELEMENT
}
/// Expands a single `xi:include` element in place.
///
/// The following validation failures are recorded in `state.errors`, and the element is
/// removed from the tree:
/// * a missing `href`;
/// * a `parse` value other than `xml`/`text`;
/// * nesting at or beyond `state.max_depth`;
/// * an `href` whose resource is already being expanded further up the inclusion chain.
///
/// Otherwise `resolver` is called with the `href` minus any `#fragment`:
///
/// * `Some(content)`: the content replaces the element. It is parsed as XML (the
///   default), or inserted verbatim as a text node when `parse="text"`.
///   `state.inclusions` is incremented when the replacement succeeds.
/// * `None`: the element's `xi:fallback` children are used instead. If there is no
///   fallback, an error is recorded and the element is removed.
///
/// The fragment part of `href` is currently ignored, so `doc.xml#part` includes the
/// whole of `doc.xml`.
fn process_include_element<F>(
    doc: &mut Document,
    include_node: NodeId,
    resolver: &F,
    state: &mut ProcessingState,
    depth: usize,
) where
    F: Fn(&str) -> Option<String>,
{
    /// How the included resource is interpreted (the `parse` attribute).
    enum ParseMode {
        Xml,
        Text,
    }

    let Some(href) = doc.attribute(include_node, "href").map(str::to_owned) else {
        state.errors.push(XIncludeError {
            message: "xi:include element is missing required 'href' attribute".to_string(),
            href: None,
        });
        doc.detach(include_node);
        return;
    };

    // Owned copy so the borrow of `doc` has ended before the error path mutates it.
    let parse_attr = doc.attribute(include_node, "parse").map(str::to_owned);
    let mode = match parse_attr.as_deref() {
        // `parse` defaults to "xml" when the attribute is absent.
        None | Some("xml") => ParseMode::Xml,
        Some("text") => ParseMode::Text,
        Some(other) => {
            state.errors.push(XIncludeError {
                message: format!(
                    "invalid parse attribute value '{other}'; expected 'xml' or 'text'"
                ),
                href: Some(href),
            });
            doc.detach(include_node);
            return;
        }
    };

    if depth >= state.max_depth {
        state.errors.push(XIncludeError {
            message: format!(
                "maximum XInclude nesting depth ({}) exceeded",
                state.max_depth
            ),
            href: Some(href),
        });
        doc.detach(include_node);
        return;
    }

    // Only the resource part is resolved and tracked for recursion detection.
    let (base_href, _fragment) = split_fragment(&href);
    if state.active_hrefs.contains(base_href) {
        state.errors.push(XIncludeError {
            message: "circular inclusion detected".to_string(),
            href: Some(href),
        });
        doc.detach(include_node);
        return;
    }

    let Some(content) = resolver(base_href) else {
        // The resource could not be retrieved, so fall back to `xi:fallback` if present.
        if !try_fallback(doc, include_node, resolver, state, depth) {
            state.errors.push(XIncludeError {
                message: "resource not found and no xi:fallback provided".to_string(),
                href: Some(href),
            });
            doc.detach(include_node);
        }
        return;
    };

    // Mark the resource as "in progress" while its content, and anything that content
    // includes, is expanded, so that a nested reference back to it is caught as circular.
    state.active_hrefs.insert(base_href.to_owned());
    let included = match mode {
        ParseMode::Xml => {
            process_xml_include(doc, include_node, &content, resolver, state, depth)
        }
        ParseMode::Text => process_text_include(doc, include_node, &content),
    };
    state.active_hrefs.remove(base_href);

    if included {
        state.inclusions += 1;
    }
}
fn process_xml_include<F>(
doc: &mut Document,
include_node: NodeId,
content: &str,
resolver: &F,
state: &mut ProcessingState,
depth: usize,
) -> bool
where
F: Fn(&str) -> Option<String>,
{
let included_doc = match Document::parse_str(content) {
Ok(d) => d,
Err(e) => {
if try_fallback(doc, include_node, resolver, state, depth) {
return false;
}
state.errors.push(XIncludeError {
message: format!("failed to parse included XML: {e}"),
href: None,
});
doc.detach(include_node);
return false;
}
};
let included_root = included_doc.root();
let included_children: Vec<NodeId> = included_doc.children(included_root).collect();
let parent = doc.parent(include_node);
let mut inserted_nodes = Vec::new();
for inc_child in &included_children {
let new_node = deep_copy_node(doc, &included_doc, *inc_child);
inserted_nodes.push(new_node);
}
for new_node in &inserted_nodes {
doc.insert_before(include_node, *new_node);
}
doc.detach(include_node);
if parent.is_some() {
for new_node in inserted_nodes {
process_node(doc, new_node, resolver, state, depth + 1);
}
}
true
}
/// Replaces `include_node` with one text node holding `content` verbatim.
///
/// Used for `parse="text"`, so the resource is never interpreted as markup.
/// Always returns `true`.
fn process_text_include(doc: &mut Document, include_node: NodeId, content: &str) -> bool {
    let text = NodeKind::Text {
        content: String::from(content),
    };
    let replacement = doc.create_node(text);
    doc.insert_before(include_node, replacement);
    doc.detach(include_node);
    true
}
fn try_fallback<F>(
doc: &mut Document,
include_node: NodeId,
resolver: &F,
state: &mut ProcessingState,
depth: usize,
) -> bool
where
F: Fn(&str) -> Option<String>,
{
let fallback_node = {
let children: Vec<NodeId> = doc.children(include_node).collect();
children
.into_iter()
.find(|&child| is_fallback_element(doc, child))
};
let Some(fallback) = fallback_node else {
return false;
};
let fallback_children: Vec<NodeId> = doc.children(fallback).collect();
let mut inserted_nodes = Vec::new();
for child in fallback_children {
doc.detach(child);
doc.insert_before(include_node, child);
inserted_nodes.push(child);
}
doc.detach(include_node);
for node in inserted_nodes {
process_node(doc, node, resolver, state, depth + 1);
}
true
}
/// Copies `source_id` and its entire subtree from `source` into `target`.
///
/// Nodes are created in `target` parent-first, in source child order. This function
/// does not attach the returned root anywhere; the caller decides where it goes.
/// Recursion depth equals the height of the copied subtree.
fn deep_copy_node(target: &mut Document, source: &Document, source_id: NodeId) -> NodeId {
    let copy = target.create_node(source.node(source_id).kind.clone());
    for source_child in source.children(source_id) {
        let child_copy = deep_copy_node(target, source, source_child);
        target.append_child(copy, child_copy);
    }
    copy
}
/// Splits `href` at its first `#` into `(resource, fragment)`.
///
/// The `#` itself is dropped. The fragment is `None` only when `href` contains no `#`;
/// a trailing `#` yields `Some("")`.
fn split_fragment(href: &str) -> (&str, Option<&str>) {
    href.split_once('#')
        .map_or((href, None), |(resource, fragment)| (resource, Some(fragment)))
}
#[cfg(test)]
// Tests unwrap freely: a panic is exactly the failure signal a test wants.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Parses `xml`, runs XInclude processing with default options and returns both the
    /// processed document and the processing result.
    fn process_with_resolver<F>(xml: &str, resolver: F) -> (Document, XIncludeResult)
    where
        F: Fn(&str) -> Option<String>,
    {
        let mut doc = Document::parse_str(xml).unwrap();
        let result = process_xincludes(&mut doc, resolver, &XIncludeOptions::default());
        (doc, result)
    }

    /// Text content of the document's root element.
    fn doc_text_content(doc: &Document) -> String {
        let root_elem = doc.root_element().unwrap();
        doc.text_content(root_elem)
    }

    #[test]
    fn test_basic_xml_include() {
        let xml =
            r#"<doc xmlns:xi="http://www.w3.org/2001/XInclude"><xi:include href="inc.xml"/></doc>"#;
        let (doc, result) = process_with_resolver(xml, |href| match href {
            "inc.xml" => Some("<greeting>hello</greeting>".to_string()),
            _ => None,
        });
        assert_eq!(result.inclusions, 1);
        assert!(result.errors.is_empty());
        let root = doc.root_element().unwrap();
        let children: Vec<NodeId> = doc.children(root).collect();
        assert_eq!(children.len(), 1);
        assert_eq!(doc.node_name(children[0]), Some("greeting"));
        assert_eq!(doc.text_content(children[0]), "hello");
    }

    #[test]
    fn test_basic_text_include() {
        let xml = r#"<doc xmlns:xi="http://www.w3.org/2001/XInclude"><xi:include href="msg.txt" parse="text"/></doc>"#;
        let (doc, result) = process_with_resolver(xml, |href| match href {
            "msg.txt" => Some("Hello, World!".to_string()),
            _ => None,
        });
        assert_eq!(result.inclusions, 1);
        assert!(result.errors.is_empty());
        assert_eq!(doc_text_content(&doc), "Hello, World!");
    }

    #[test]
    fn test_fallback_when_resource_not_found() {
        let xml = r#"<doc xmlns:xi="http://www.w3.org/2001/XInclude"><xi:include href="missing.xml"><xi:fallback><alt>fallback content</alt></xi:fallback></xi:include></doc>"#;
        let (doc, result) = process_with_resolver(xml, |_| None);
        assert_eq!(result.inclusions, 0);
        assert!(result.errors.is_empty());
        let root = doc.root_element().unwrap();
        let children: Vec<NodeId> = doc.children(root).collect();
        assert_eq!(children.len(), 1);
        assert_eq!(doc.node_name(children[0]), Some("alt"));
        assert_eq!(doc.text_content(children[0]), "fallback content");
    }

    #[test]
    fn test_fallback_with_text_content() {
        let xml = r#"<doc xmlns:xi="http://www.w3.org/2001/XInclude"><xi:include href="missing.xml"><xi:fallback>plain fallback</xi:fallback></xi:include></doc>"#;
        let (doc, result) = process_with_resolver(xml, |_| None);
        assert_eq!(result.inclusions, 0);
        assert!(result.errors.is_empty());
        assert_eq!(doc_text_content(&doc), "plain fallback");
    }

    #[test]
    fn test_missing_href_attribute() {
        let xml = r#"<doc xmlns:xi="http://www.w3.org/2001/XInclude"><xi:include/></doc>"#;
        let (_doc, result) = process_with_resolver(xml, |_| None);
        assert_eq!(result.inclusions, 0);
        assert_eq!(result.errors.len(), 1);
        assert!(result.errors[0].message.contains("missing required 'href'"));
        assert!(result.errors[0].href.is_none());
    }

    #[test]
    fn test_circular_inclusion_detection() {
        let xml =
            r#"<doc xmlns:xi="http://www.w3.org/2001/XInclude"><xi:include href="a.xml"/></doc>"#;
        let (_, result) = process_with_resolver(xml, |href| match href {
            "a.xml" => Some(
                r#"<a xmlns:xi="http://www.w3.org/2001/XInclude"><xi:include href="a.xml"/></a>"#
                    .to_string(),
            ),
            _ => None,
        });
        assert_eq!(result.inclusions, 1);
        assert_eq!(result.errors.len(), 1);
        assert!(result.errors[0].message.contains("circular inclusion"));
    }

    #[test]
    fn test_max_depth_exceeded() {
        let xml = r#"<doc xmlns:xi="http://www.w3.org/2001/XInclude"><xi:include href="deep.xml"/></doc>"#;
        let mut doc = Document::parse_str(xml).unwrap();
        let opts = XIncludeOptions { max_depth: 2 };
        let result = process_xincludes(
            &mut doc,
            |href| match href {
                "deep.xml" => Some(
                    r#"<level xmlns:xi="http://www.w3.org/2001/XInclude"><xi:include href="deeper.xml"/></level>"#
                        .to_string(),
                ),
                "deeper.xml" => Some(
                    r#"<level xmlns:xi="http://www.w3.org/2001/XInclude"><xi:include href="deepest.xml"/></level>"#
                        .to_string(),
                ),
                "deepest.xml" => Some("<leaf/>".to_string()),
                _ => None,
            },
            &opts,
        );
        assert!(result.errors.iter().any(|e| e.message.contains("depth")));
    }

    #[test]
    fn test_multiple_includes_in_same_document() {
        let xml = r#"<doc xmlns:xi="http://www.w3.org/2001/XInclude"><xi:include href="a.xml"/><xi:include href="b.xml"/></doc>"#;
        let (doc, result) = process_with_resolver(xml, |href| match href {
            "a.xml" => Some("<first/>".to_string()),
            "b.xml" => Some("<second/>".to_string()),
            _ => None,
        });
        assert_eq!(result.inclusions, 2);
        assert!(result.errors.is_empty());
        let root = doc.root_element().unwrap();
        let children: Vec<NodeId> = doc.children(root).collect();
        assert_eq!(children.len(), 2);
        assert_eq!(doc.node_name(children[0]), Some("first"));
        assert_eq!(doc.node_name(children[1]), Some("second"));
    }

    #[test]
    fn test_nested_includes() {
        let xml = r#"<doc xmlns:xi="http://www.w3.org/2001/XInclude"><xi:include href="outer.xml"/></doc>"#;
        let (doc, result) = process_with_resolver(xml, |href| match href {
            "outer.xml" => Some(
                r#"<outer xmlns:xi="http://www.w3.org/2001/XInclude"><xi:include href="inner.xml"/></outer>"#
                    .to_string(),
            ),
            "inner.xml" => Some("<inner>nested</inner>".to_string()),
            _ => None,
        });
        assert_eq!(result.inclusions, 2);
        assert!(result.errors.is_empty());
        let root = doc.root_element().unwrap();
        let outer: Vec<NodeId> = doc.children(root).collect();
        assert_eq!(doc.node_name(outer[0]), Some("outer"));
        let inner: Vec<NodeId> = doc.children(outer[0]).collect();
        assert_eq!(doc.node_name(inner[0]), Some("inner"));
        assert_eq!(doc.text_content(inner[0]), "nested");
    }

    #[test]
    fn test_default_parse_attribute_is_xml() {
        let xml = r#"<doc xmlns:xi="http://www.w3.org/2001/XInclude"><xi:include href="data.xml"/></doc>"#;
        let (doc, result) = process_with_resolver(xml, |href| match href {
            "data.xml" => Some("<item>value</item>".to_string()),
            _ => None,
        });
        assert_eq!(result.inclusions, 1);
        assert!(result.errors.is_empty());
        let root = doc.root_element().unwrap();
        let children: Vec<NodeId> = doc.children(root).collect();
        assert_eq!(doc.node_name(children[0]), Some("item"));
    }

    #[test]
    fn test_include_replaces_entire_xi_include_element() {
        let xml = r#"<doc xmlns:xi="http://www.w3.org/2001/XInclude"><before/><xi:include href="mid.xml"/><after/></doc>"#;
        let (doc, result) = process_with_resolver(xml, |href| match href {
            "mid.xml" => Some("<middle/>".to_string()),
            _ => None,
        });
        assert_eq!(result.inclusions, 1);
        let root = doc.root_element().unwrap();
        let names: Vec<Option<&str>> = doc.children(root).map(|c| doc.node_name(c)).collect();
        assert_eq!(names, vec![Some("before"), Some("middle"), Some("after")]);
    }

    #[test]
    fn test_text_include_preserves_whitespace() {
        let xml = r#"<doc xmlns:xi="http://www.w3.org/2001/XInclude"><xi:include href="ws.txt" parse="text"/></doc>"#;
        let content = " line1\n line2\n";
        let (doc, result) = process_with_resolver(xml, |href| match href {
            "ws.txt" => Some(content.to_string()),
            _ => None,
        });
        assert_eq!(result.inclusions, 1);
        assert_eq!(doc_text_content(&doc), content);
    }

    #[test]
    fn test_empty_include_content() {
        let xml = r#"<doc xmlns:xi="http://www.w3.org/2001/XInclude"><xi:include href="empty.txt" parse="text"/></doc>"#;
        let (doc, result) = process_with_resolver(xml, |href| match href {
            "empty.txt" => Some(String::new()),
            _ => None,
        });
        assert_eq!(result.inclusions, 1);
        assert!(result.errors.is_empty());
        assert_eq!(doc_text_content(&doc), "");
    }

    #[test]
    fn test_include_with_fragment_identifier() {
        let xml = r#"<doc xmlns:xi="http://www.w3.org/2001/XInclude"><xi:include href="data.xml#section1"/></doc>"#;
        let (doc, result) = process_with_resolver(xml, |href| match href {
            "data.xml" => Some("<section>content</section>".to_string()),
            _ => None,
        });
        assert_eq!(result.inclusions, 1);
        assert!(result.errors.is_empty());
        let root = doc.root_element().unwrap();
        let children: Vec<NodeId> = doc.children(root).collect();
        assert_eq!(doc.node_name(children[0]), Some("section"));
    }

    #[test]
    fn test_xinclude_namespace_detection() {
        let xml = r#"<doc><include href="should-ignore.xml"/></doc>"#;
        let (_, result) = process_with_resolver(xml, |_| {
            panic!("resolver should not be called for non-XInclude elements");
        });
        assert_eq!(result.inclusions, 0);
        assert!(result.errors.is_empty());
    }

    #[test]
    fn test_split_fragment() {
        assert_eq!(split_fragment("file.xml#sec"), ("file.xml", Some("sec")));
        assert_eq!(split_fragment("file.xml"), ("file.xml", None));
        assert_eq!(split_fragment("file.xml#"), ("file.xml", Some("")));
        assert_eq!(split_fragment("#frag"), ("", Some("frag")));
    }

    #[test]
    fn test_no_fallback_records_error() {
        let xml = r#"<doc xmlns:xi="http://www.w3.org/2001/XInclude"><xi:include href="nope.xml"/></doc>"#;
        let (_, result) = process_with_resolver(xml, |_| None);
        assert_eq!(result.inclusions, 0);
        assert_eq!(result.errors.len(), 1);
        assert!(result.errors[0].message.contains("resource not found"));
        assert_eq!(result.errors[0].href.as_deref(), Some("nope.xml"));
    }

    #[test]
    fn test_invalid_parse_attribute() {
        let xml = r#"<doc xmlns:xi="http://www.w3.org/2001/XInclude"><xi:include href="x.xml" parse="json"/></doc>"#;
        let (_, result) = process_with_resolver(xml, |_| None);
        assert_eq!(result.errors.len(), 1);
        assert!(result.errors[0].message.contains("invalid parse attribute"));
    }

    #[test]
    fn test_xml_include_with_wrapper_element() {
        let xml = r#"<doc xmlns:xi="http://www.w3.org/2001/XInclude"><xi:include href="multi.xml"/></doc>"#;
        let (doc, result) = process_with_resolver(xml, |href| match href {
            "multi.xml" => Some("<wrapper><first/><second/></wrapper>".to_string()),
            _ => None,
        });
        assert_eq!(result.inclusions, 1);
        assert!(result.errors.is_empty());
        let root = doc.root_element().unwrap();
        let children: Vec<NodeId> = doc.children(root).collect();
        assert_eq!(children.len(), 1);
        assert_eq!(doc.node_name(children[0]), Some("wrapper"));
        let wrapper_children: Vec<NodeId> = doc.children(children[0]).collect();
        assert_eq!(wrapper_children.len(), 2);
        assert_eq!(doc.node_name(wrapper_children[0]), Some("first"));
        assert_eq!(doc.node_name(wrapper_children[1]), Some("second"));
    }

    #[test]
    fn test_options_default() {
        let opts = XIncludeOptions::default();
        assert_eq!(opts.max_depth, 50);
    }

    #[test]
    fn test_error_display() {
        let err = XIncludeError {
            message: "resource not found".to_string(),
            href: Some("file.xml".to_string()),
        };
        assert_eq!(
            err.to_string(),
            "XInclude error for 'file.xml': resource not found"
        );
        let err_no_href = XIncludeError {
            message: "bad element".to_string(),
            href: None,
        };
        assert_eq!(err_no_href.to_string(), "XInclude error: bad element");
    }

    #[test]
    fn test_unparseable_xml_without_fallback_records_error() {
        let xml = r#"<doc xmlns:xi="http://www.w3.org/2001/XInclude"><xi:include href="bad.xml"/></doc>"#;
        let (doc, result) = process_with_resolver(xml, |href| match href {
            // Mismatched end tag: not well-formed XML.
            "bad.xml" => Some("<a></b>".to_string()),
            _ => None,
        });
        assert_eq!(result.inclusions, 0);
        assert_eq!(result.errors.len(), 1);
        assert!(result.errors[0]
            .message
            .contains("failed to parse included XML"));
        let root = doc.root_element().unwrap();
        let children: Vec<NodeId> = doc.children(root).collect();
        assert!(children.is_empty());
    }

    #[test]
    fn test_unparseable_xml_uses_fallback() {
        let xml = r#"<doc xmlns:xi="http://www.w3.org/2001/XInclude"><xi:include href="bad.xml"><xi:fallback><alt/></xi:fallback></xi:include></doc>"#;
        let (doc, result) = process_with_resolver(xml, |href| match href {
            "bad.xml" => Some("<a></b>".to_string()),
            _ => None,
        });
        assert_eq!(result.inclusions, 0);
        assert!(result.errors.is_empty());
        let root = doc.root_element().unwrap();
        let children: Vec<NodeId> = doc.children(root).collect();
        assert_eq!(children.len(), 1);
        assert_eq!(doc.node_name(children[0]), Some("alt"));
    }

    #[test]
    fn test_zero_max_depth_rejects_top_level_include() {
        let xml =
            r#"<doc xmlns:xi="http://www.w3.org/2001/XInclude"><xi:include href="a.xml"/></doc>"#;
        let mut doc = Document::parse_str(xml).unwrap();
        let opts = XIncludeOptions { max_depth: 0 };
        let result = process_xincludes(
            &mut doc,
            |_| panic!("resolver should not be called once the depth limit is reached"),
            &opts,
        );
        assert_eq!(result.inclusions, 0);
        assert_eq!(result.errors.len(), 1);
        assert!(result.errors[0].message.contains("depth"));
        assert_eq!(result.errors[0].href.as_deref(), Some("a.xml"));
    }

    #[test]
    fn test_text_include_is_not_parsed_as_markup() {
        let xml = r#"<doc xmlns:xi="http://www.w3.org/2001/XInclude"><xi:include href="raw.txt" parse="text"/></doc>"#;
        let (doc, result) = process_with_resolver(xml, |href| match href {
            "raw.txt" => Some("<b>not markup</b>".to_string()),
            _ => None,
        });
        assert_eq!(result.inclusions, 1);
        assert!(result.errors.is_empty());
        let root = doc.root_element().unwrap();
        let children: Vec<NodeId> = doc.children(root).collect();
        assert_eq!(children.len(), 1);
        assert_eq!(doc_text_content(&doc), "<b>not markup</b>");
    }

    #[test]
    fn test_split_fragment_uses_first_hash() {
        assert_eq!(split_fragment("a.xml#x#y"), ("a.xml", Some("x#y")));
        assert_eq!(split_fragment(""), ("", None));
    }
}