use crate::{error::BuildError, FidelityOptions};
use serde::{Deserialize, Serialize};
use std::time::{Duration, Instant};
/// Round-trip and fidelity checker for XML documents, configured by a
/// [`FidelityOptions`] value supplied at construction time.
pub struct RoundTripTester {
/// Options read by the methods of this type: the canonicalization algorithm
/// and the comment / extension preservation flags.
fidelity_options: FidelityOptions,
}
impl RoundTripTester {
pub fn new(fidelity_options: FidelityOptions) -> Self {
Self { fidelity_options }
}
/// Runs a round-trip check on `original_xml` and reports the outcome.
///
/// The input is canonicalized with `canonicalize_for_comparison`. The
/// "rebuilt" document is currently a plain copy of that canonical form, so
/// `byte_identical` is always `true`, `differences` is always empty and
/// `success` is always `true`.
///
/// # Errors
///
/// Propagates any [`BuildError`] returned by canonicalization.
pub fn test_round_trip(&self, original_xml: &str) -> Result<RoundTripResult, BuildError> {
    let timer = Instant::now();

    let original_canonical = self.canonicalize_for_comparison(original_xml)?;
    // Placeholder rebuild: no parse/build cycle happens here, the canonical
    // original is simply duplicated and compared against itself.
    let rebuilt_canonical = original_canonical.clone();
    let byte_identical = original_canonical == rebuilt_canonical;

    // The reported time covers canonicalization and comparison only.
    let test_time = timer.elapsed();

    let result = RoundTripResult {
        success: true,
        original_xml: original_xml.to_string(),
        rebuilt_xml: rebuilt_canonical,
        byte_identical,
        differences: Vec::new(),
        test_time,
    };
    Ok(result)
}
/// Produces the form of `xml` used when comparing two documents.
///
/// Every canonicalization algorithm currently maps to the same treatment:
/// whitespace normalization via `normalize_whitespace`. The match is kept
/// exhaustive (no wildcard arm) so that adding a variant forces this
/// function to be revisited.
///
/// # Errors
///
/// The signature allows a [`BuildError`], but no path in this implementation
/// returns one.
fn canonicalize_for_comparison(&self, xml: &str) -> Result<String, BuildError> {
    match &self.fidelity_options.canonicalization {
        crate::CanonicalizationAlgorithm::None
        | crate::CanonicalizationAlgorithm::C14N
        | crate::CanonicalizationAlgorithm::C14N11
        | crate::CanonicalizationAlgorithm::DbC14N
        | crate::CanonicalizationAlgorithm::Custom(_) => Ok(self.normalize_whitespace(xml)),
    }
}
/// Trims leading and trailing whitespace from every line of `xml`, drops
/// lines that are empty after trimming, and joins the rest with `\n`.
fn normalize_whitespace(&self, xml: &str) -> String {
    let mut normalized = String::with_capacity(xml.len());
    for trimmed in xml.lines().map(str::trim) {
        if trimmed.is_empty() {
            continue;
        }
        // Only non-empty pieces are appended, so a non-empty buffer means at
        // least one line precedes this one and a separator is needed.
        if !normalized.is_empty() {
            normalized.push('\n');
        }
        normalized.push_str(trimmed);
    }
    normalized
}
/// Placeholder structural comparison: both arguments are ignored and the
/// result is always `true`.
fn _compare_structures(&self, _original: &str, _rebuilt: &str) -> bool {
true
}
/// Runs every individual analysis on `original_xml` and bundles the results.
///
/// The analyses run in this order: elements, attributes, comments,
/// extensions, namespaces. `analysis_time` is captured once they have all
/// finished, before the overall score is computed. The overall score is
/// derived from the element, attribute and comment analyses only.
///
/// # Errors
///
/// Returns the first [`BuildError`] produced by any of the analyses.
pub fn analyze_fidelity(&self, original_xml: &str) -> Result<FidelityAnalysis, BuildError> {
    let timer = Instant::now();

    let elements = self.analyze_elements(original_xml)?;
    let attributes = self.analyze_attributes(original_xml)?;
    let comments = self.analyze_comments(original_xml)?;
    let extensions = self.analyze_extensions(original_xml)?;
    let namespaces = self.analyze_namespaces(original_xml)?;

    let elapsed = timer.elapsed();
    let score = self.calculate_overall_score(&elements, &attributes, &comments);

    Ok(FidelityAnalysis {
        element_analysis: elements,
        attribute_analysis: attributes,
        comment_analysis: comments,
        extension_analysis: extensions,
        namespace_analysis: namespaces,
        analysis_time: elapsed,
        overall_score: score,
    })
}
/// Counts the elements of `xml`, in total and per element name.
///
/// Both start tags and empty-element tags count as one element each. Names
/// are decoded lossily from UTF-8. Every element is reported as preserved
/// and none as unknown.
///
/// # Errors
///
/// Returns [`BuildError::InvalidFormat`] when the XML reader reports an error.
fn analyze_elements(&self, xml: &str) -> Result<ElementAnalysis, BuildError> {
    let mut xml_reader = quick_xml::Reader::from_str(xml);
    let mut counts_by_name = std::collections::HashMap::new();
    let mut element_count = 0;

    loop {
        let event = xml_reader
            .read_event()
            .map_err(|parse_error| BuildError::InvalidFormat {
                field: "xml".to_string(),
                message: format!("XML parsing error: {}", parse_error),
            })?;

        match event {
            quick_xml::events::Event::Start(tag) | quick_xml::events::Event::Empty(tag) => {
                let tag_name = String::from_utf8_lossy(tag.name().as_ref()).into_owned();
                *counts_by_name.entry(tag_name).or_insert(0) += 1;
                element_count += 1;
            }
            quick_xml::events::Event::Eof => break,
            _ => {}
        }
    }

    Ok(ElementAnalysis {
        total_elements: element_count,
        elements_by_type: counts_by_name,
        unknown_elements: 0,
        preserved_elements: element_count,
    })
}
/// Counts the attributes of `xml`, in total and per element name.
///
/// Start tags and empty-element tags are both inspected. An element with no
/// attributes still gets an entry (with count zero) in the per-element map.
/// Every attribute is reported as preserved and none as unknown.
///
/// # Errors
///
/// Returns [`BuildError::InvalidFormat`] when the XML reader reports an error.
fn analyze_attributes(&self, xml: &str) -> Result<AttributeAnalysis, BuildError> {
    let mut xml_reader = quick_xml::Reader::from_str(xml);
    let mut counts_by_element = std::collections::HashMap::new();
    let mut attribute_total = 0;

    loop {
        let event = xml_reader
            .read_event()
            .map_err(|parse_error| BuildError::InvalidFormat {
                field: "xml".to_string(),
                message: format!("XML parsing error: {}", parse_error),
            })?;

        match event {
            quick_xml::events::Event::Start(tag) | quick_xml::events::Event::Empty(tag) => {
                let owner = String::from_utf8_lossy(tag.name().as_ref()).into_owned();
                let found = tag.attributes().count();
                attribute_total += found;
                *counts_by_element.entry(owner).or_insert(0) += found;
            }
            quick_xml::events::Event::Eof => break,
            _ => {}
        }
    }

    Ok(AttributeAnalysis {
        total_attributes: attribute_total,
        attributes_by_element: counts_by_element,
        unknown_attributes: 0,
        preserved_attributes: attribute_total,
    })
}
fn analyze_comments(&self, xml: &str) -> Result<CommentAnalysis, BuildError> {
let comments = if let Ok(comment_regex) = regex::Regex::new(r"<!--.*?-->") {
comment_regex.find_iter(xml).collect()
} else {
Vec::new()
};
Ok(CommentAnalysis {
total_comments: comments.len(),
document_level_comments: 0, element_level_comments: comments.len(), inline_comments: 0,
preserved_comments: if self.fidelity_options.preserve_comments {
comments.len()
} else {
0
},
})
}
fn analyze_extensions(&self, xml: &str) -> Result<ExtensionAnalysis, BuildError> {
let mut extension_namespaces = std::collections::HashMap::new();
if let Ok(namespace_regex) = regex::Regex::new(r#"xmlns:(\w+)=['"]([^'"]+)['"]"#) {
for caps in namespace_regex.captures_iter(xml) {
if let (Some(prefix_match), Some(uri_match)) = (caps.get(1), caps.get(2)) {
let prefix = prefix_match.as_str();
let uri = uri_match.as_str();
if !uri.contains("ddex.net") && !uri.contains("w3.org") {
extension_namespaces.insert(prefix.to_string(), uri.to_string());
}
}
}
}
let extension_count = extension_namespaces.len();
Ok(ExtensionAnalysis {
total_extensions: extension_count,
extension_namespaces,
known_extensions: 0, unknown_extensions: extension_count,
preserved_extensions: if self.fidelity_options.preserve_extensions {
extension_count
} else {
0
},
})
}
fn analyze_namespaces(&self, xml: &str) -> Result<NamespaceAnalysis, BuildError> {
let mut namespaces = std::collections::HashMap::new();
let mut default_namespace = None;
if let Ok(namespace_regex) = regex::Regex::new(r#"xmlns(?::(\w+))?=['"]([^'"]+)['"]"#) {
for caps in namespace_regex.captures_iter(xml) {
if let Some(prefix_match) = caps.get(1) {
if let Some(uri_match) = caps.get(2) {
let prefix = prefix_match.as_str();
let uri = uri_match.as_str();
namespaces.insert(prefix.to_string(), uri.to_string());
}
} else if let Some(uri_match) = caps.get(2) {
default_namespace = Some(uri_match.as_str().to_string());
}
}
}
let total_namespaces = namespaces.len() + if default_namespace.is_some() { 1 } else { 0 };
let preserved_namespaces = namespaces.len();
Ok(NamespaceAnalysis {
total_namespaces,
prefixed_namespaces: namespaces,
default_namespace,
preserved_namespaces,
})
}
/// Combines the element, attribute and comment analyses into one score.
///
/// Each category contributes its preserved/total ratio (or `1.0` when the
/// category is empty), weighted 0.5 for elements, 0.3 for attributes and
/// 0.2 for comments.
fn calculate_overall_score(
    &self,
    element_analysis: &ElementAnalysis,
    attribute_analysis: &AttributeAnalysis,
    comment_analysis: &CommentAnalysis,
) -> f64 {
    // An empty category counts as fully preserved.
    let ratio = |preserved: usize, total: usize| -> f64 {
        if total > 0 {
            preserved as f64 / total as f64
        } else {
            1.0
        }
    };

    let element_score = ratio(
        element_analysis.preserved_elements,
        element_analysis.total_elements,
    );
    let attribute_score = ratio(
        attribute_analysis.preserved_attributes,
        attribute_analysis.total_attributes,
    );
    let comment_score = ratio(
        comment_analysis.preserved_comments,
        comment_analysis.total_comments,
    );

    (element_score * 0.5) + (attribute_score * 0.3) + (comment_score * 0.2)
}
}
/// Outcome of a single `RoundTripTester::test_round_trip` run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoundTripResult {
/// Whether the round trip succeeded; `test_round_trip` currently always sets `true`.
pub success: bool,
/// The input XML exactly as it was passed in.
pub original_xml: String,
/// The rebuilt XML; `test_round_trip` stores the canonicalized form here.
pub rebuilt_xml: String,
/// Whether the canonicalized original and the canonicalized rebuilt XML compared equal.
pub byte_identical: bool,
/// Descriptions of detected differences; `test_round_trip` currently leaves this empty.
pub differences: Vec<String>,
/// Time spent canonicalizing and comparing.
pub test_time: Duration,
}
/// Combined report produced by `RoundTripTester::analyze_fidelity`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FidelityAnalysis {
/// Element counts.
pub element_analysis: ElementAnalysis,
/// Attribute counts.
pub attribute_analysis: AttributeAnalysis,
/// Comment counts.
pub comment_analysis: CommentAnalysis,
/// Extension namespaces that were detected.
pub extension_analysis: ExtensionAnalysis,
/// Namespace declarations that were detected.
pub namespace_analysis: NamespaceAnalysis,
/// Time taken by the five analyses (captured before the score is computed).
pub analysis_time: Duration,
/// Weighted score built from the element (0.5), attribute (0.3) and comment (0.2) preservation ratios.
pub overall_score: f64,
}
/// Element statistics produced by `RoundTripTester::analyze_elements`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ElementAnalysis {
/// Number of start tags plus empty-element tags encountered.
pub total_elements: usize,
/// Occurrence count per element name.
pub elements_by_type: std::collections::HashMap<String, usize>,
/// Number of unrecognized elements; `analyze_elements` currently always reports 0.
pub unknown_elements: usize,
/// Number of elements considered preserved; `analyze_elements` currently sets this to the total.
pub preserved_elements: usize,
}
/// Attribute statistics produced by `RoundTripTester::analyze_attributes`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttributeAnalysis {
/// Number of attributes across all start and empty-element tags.
pub total_attributes: usize,
/// Attribute count per element name (summed over all occurrences of that name).
pub attributes_by_element: std::collections::HashMap<String, usize>,
/// Number of unrecognized attributes; `analyze_attributes` currently always reports 0.
pub unknown_attributes: usize,
/// Number of attributes considered preserved; `analyze_attributes` currently sets this to the total.
pub preserved_attributes: usize,
}
/// Comment statistics produced by `RoundTripTester::analyze_comments`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommentAnalysis {
/// Number of comments found.
pub total_comments: usize,
/// Comments at document level; `analyze_comments` currently always reports 0.
pub document_level_comments: usize,
/// Comments at element level; `analyze_comments` currently sets this to the total.
pub element_level_comments: usize,
/// Inline comments; `analyze_comments` currently always reports 0.
pub inline_comments: usize,
/// The total when the `preserve_comments` option is enabled, otherwise 0.
pub preserved_comments: usize,
}
/// Extension statistics produced by `RoundTripTester::analyze_extensions`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtensionAnalysis {
/// Number of distinct extension namespace prefixes found.
pub total_extensions: usize,
/// Prefix → URI for prefixed namespace declarations whose URI contains neither `ddex.net` nor `w3.org`.
pub extension_namespaces: std::collections::HashMap<String, String>,
/// Number of recognized extensions; `analyze_extensions` currently always reports 0.
pub known_extensions: usize,
/// Number of unrecognized extensions; `analyze_extensions` currently sets this to the total.
pub unknown_extensions: usize,
/// The total when the `preserve_extensions` option is enabled, otherwise 0.
pub preserved_extensions: usize,
}
/// Namespace statistics produced by `RoundTripTester::analyze_namespaces`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NamespaceAnalysis {
/// Number of distinct prefixed namespaces, plus one if a default namespace is declared.
pub total_namespaces: usize,
/// Prefix → URI for every prefixed namespace declaration found.
pub prefixed_namespaces: std::collections::HashMap<String, String>,
/// URI of the unprefixed `xmlns` declaration, if any.
pub default_namespace: Option<String>,
/// Number of namespaces considered preserved; `analyze_namespaces` sets this to the
/// prefixed-namespace count, so the default namespace is not included.
pub preserved_namespaces: usize,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_round_trip_tester_creation() {
let fidelity_options = FidelityOptions::default();
let tester = RoundTripTester::new(fidelity_options);
assert_eq!(tester.fidelity_options.enable_perfect_fidelity, false);
}
#[test]
fn test_whitespace_normalization() {
let fidelity_options = FidelityOptions::default();
let tester = RoundTripTester::new(fidelity_options);
let xml = " <test> \n <inner>value</inner> \n </test> ";
let normalized = tester.normalize_whitespace(xml);
assert_eq!(normalized, "<test>\n<inner>value</inner>\n</test>");
}
#[test]
fn test_element_analysis() {
let fidelity_options = FidelityOptions::default();
let tester = RoundTripTester::new(fidelity_options);
let xml = r#"<root><element1/><element2><element3/></element2></root>"#;
let analysis = tester.analyze_elements(xml).unwrap();
assert_eq!(analysis.total_elements, 4);
assert!(analysis.elements_by_type.contains_key("root"));
assert!(analysis.elements_by_type.contains_key("element1"));
}
#[test]
fn test_comment_analysis() {
let fidelity_options = FidelityOptions::default();
let tester = RoundTripTester::new(fidelity_options);
let xml = r#"<root><!-- comment 1 --><element/><!-- comment 2 --></root>"#;
let analysis = tester.analyze_comments(xml).unwrap();
assert_eq!(analysis.total_comments, 2);
}
#[test]
fn test_extension_analysis() {
let fidelity_options = FidelityOptions::default();
let tester = RoundTripTester::new(fidelity_options);
let xml = r#"<root xmlns:spotify="http://spotify.com/ddex" xmlns:custom="http://example.com/custom">
<spotify:trackId>123</spotify:trackId>
</root>"#;
let analysis = tester.analyze_extensions(xml).unwrap();
assert!(analysis.extension_namespaces.contains_key("spotify"));
assert!(analysis.extension_namespaces.contains_key("custom"));
}
}