use std::collections::HashMap;
use std::path::Path;
use crate::error::Result;
use crate::schema::fetcher::{FetchResult, SchemaFetcher};
/// Summary of a schema export performed by `export_schemas_from_xml`.
#[derive(Debug, Clone)]
pub struct ExportResult {
/// Number of exported schema entries, i.e. the number of entries in
/// `uri_to_filename` (each entry is written to its own file).
pub schema_count: usize,
/// Maps each exported schema URI to the filename it was written under,
/// relative to the output directory. A document reached under more than one
/// URI (requested vs. final URL reported by the fetcher) may appear once per URI.
pub uri_to_filename: HashMap<String, String>,
/// Filename of the first schema location from the XML document that could be
/// fetched, or `None` when nothing was exported.
pub entry_filename: Option<String>,
}
/// Exports every schema referenced by an XML document, plus everything those
/// schemas import or include, into `output_dir` as a self-contained set of files.
///
/// The schema locations are obtained from
/// `crate::parser::parse_schema_locations_from_reader` as (namespace, location)
/// pairs. Each location is fetched with `fetcher`, and the `schemaLocation`
/// references of the fetched schemas are followed transitively. Every schema is
/// then written under a unique local filename, with its `schemaLocation`
/// attributes rewritten to point at the local copies.
///
/// Fetching is best-effort: locations or imports that fail to fetch are skipped
/// instead of aborting the export.
///
/// Filenames are assigned in a fixed order, so repeated exports of the same
/// input produce the same names. The entry schema (the first location that was
/// fetched successfully) is named first so it keeps its natural name. The
/// remaining URIs follow in lexicographic order.
///
/// # Errors
/// Returns an error if the schema locations cannot be parsed from
/// `xml_content`, if `output_dir` cannot be created, or if rewriting or writing
/// a schema file fails.
pub fn export_schemas_from_xml<F: SchemaFetcher>(
    xml_content: &[u8],
    output_dir: &Path,
    fetcher: &F,
) -> Result<ExportResult> {
    let locations = crate::parser::parse_schema_locations_from_reader(xml_content)?;
    if locations.is_empty() {
        return Ok(ExportResult {
            schema_count: 0,
            uri_to_filename: HashMap::new(),
            entry_filename: None,
        });
    }
    std::fs::create_dir_all(output_dir)?;

    let mut schemas: HashMap<String, Vec<u8>> = HashMap::new();
    let mut entry_uri: Option<String> = None;
    for (_namespace, location) in &locations {
        // Best-effort: an unreachable location is skipped so the rest can still be exported.
        let result = match fetcher.fetch(location) {
            Ok(result) => result,
            Err(_) => continue,
        };
        if entry_uri.is_none() {
            entry_uri = Some(result.final_url.clone());
        }
        schemas.insert(result.final_url.clone(), result.content.clone());
        // Also best-effort: a partially resolved import graph is still worth writing.
        let _ = resolve_imports_recursive(&result.final_url, &result.content, fetcher, &mut schemas);
    }

    // HashMap iteration order is randomized per process, so fix the naming order
    // explicitly: entry schema first (keeps its natural name), then sorted URIs.
    let entry = entry_uri.as_deref();
    let mut ordered_uris: Vec<&str> = schemas.keys().map(String::as_str).collect();
    ordered_uris.sort_by_key(|uri| (Some(*uri) != entry, *uri));

    let mut uri_to_filename: HashMap<String, String> = HashMap::with_capacity(ordered_uris.len());
    let mut existing_filenames: std::collections::HashSet<String> =
        std::collections::HashSet::with_capacity(ordered_uris.len());
    for uri in &ordered_uris {
        let filename = uri_to_safe_filename(uri, &existing_filenames);
        existing_filenames.insert(filename.clone());
        uri_to_filename.insert(uri.to_string(), filename);
    }

    for uri in &ordered_uris {
        // Both lookups are infallible: every URI came from `schemas` and was named above.
        let content = &schemas[*uri];
        let filename = &uri_to_filename[*uri];
        let rewritten = rewrite_schema_locations(content, uri, &uri_to_filename)?;
        std::fs::write(output_dir.join(filename), rewritten)?;
    }

    let entry_filename = entry.and_then(|uri| uri_to_filename.get(uri).cloned());
    Ok(ExportResult {
        schema_count: uri_to_filename.len(),
        uri_to_filename,
        entry_filename,
    })
}
/// Transitively fetches every schema referenced through `schemaLocation`
/// attributes of `content` (the document located at `base_uri`) and stores each
/// one in `schemas`.
///
/// Each fetched document is stored under the `final_url` reported by the
/// fetcher. When that differs from the requested URI, it is also stored under
/// the requested URI, so references spelled either way map to a local file.
/// A URI already present in `schemas` is never fetched again. A document whose
/// final URL is already known is not scanned a second time. Together these
/// stop import cycles.
///
/// Fetch failures are skipped (best-effort).
///
/// # Errors
/// Propagates errors returned by `resolve_uri`.
fn resolve_imports_recursive<F: SchemaFetcher>(
    base_uri: &str,
    content: &[u8],
    fetcher: &F,
    schemas: &mut HashMap<String, Vec<u8>>,
) -> Result<()> {
    // Lossy decoding keeps ASCII `schemaLocation` values readable even when the
    // schema uses a non-UTF-8 encoding such as ISO-8859-1; strict decoding would
    // silently hide all of its imports.
    let text = String::from_utf8_lossy(content);
    for location in extract_schema_locations(&text) {
        let resolved_uri = resolve_uri(base_uri, &location)?;
        if schemas.contains_key(&resolved_uri) {
            continue;
        }
        let FetchResult {
            content: fetched_content,
            final_url,
            ..
        } = match fetcher.fetch(&resolved_uri) {
            Ok(fetched) => fetched,
            Err(_) => continue,
        };
        if schemas.contains_key(&final_url) {
            // Same document reached through another URI: record the alias only;
            // its references were already followed when it was first stored.
            schemas.insert(resolved_uri, fetched_content);
            continue;
        }
        if final_url != resolved_uri {
            schemas.insert(resolved_uri, fetched_content.clone());
        }
        // Insert before recursing so a cycle back to this document terminates.
        schemas.insert(final_url.clone(), fetched_content.clone());
        resolve_imports_recursive(&final_url, &fetched_content, fetcher, schemas)?;
    }
    Ok(())
}
/// Collects the values of all `schemaLocation="…"` and `schemaLocation='…'`
/// attributes in `content`.
///
/// This is a plain text scan, not an XML parse, so it also matches
/// `xsi:schemaLocation`. That attribute's value is a whitespace-separated list
/// of namespace/location pairs, so values containing any whitespace (spaces,
/// tabs or line breaks) are skipped, as are empty values. Double-quoted values
/// are collected before single-quoted ones, and duplicates are kept.
fn extract_schema_locations(content: &str) -> Vec<String> {
    let mut locations = Vec::new();
    for quote in ['"', '\''] {
        let pattern = format!("schemaLocation={}", quote);
        let mut remaining = content;
        while let Some(start) = remaining.find(pattern.as_str()) {
            let after_pattern = &remaining[start + pattern.len()..];
            let end = match after_pattern.find(quote) {
                Some(end) => end,
                // Unterminated value: nothing more can be read for this quote style.
                None => break,
            };
            let location = &after_pattern[..end];
            if !location.is_empty() && !location.chars().any(char::is_whitespace) {
                locations.push(location.to_string());
            }
            remaining = &after_pattern[end + quote.len_utf8()..];
        }
    }
    locations
}
fn rewrite_schema_locations(
content: &[u8],
base_uri: &str,
uri_to_filename: &HashMap<String, String>,
) -> Result<Vec<u8>> {
let content_str = std::str::from_utf8(content).unwrap_or("");
let mut result = content_str.to_string();
for (uri, filename) in uri_to_filename {
let old_double = format!(r#"schemaLocation="{}""#, uri);
let new_double = format!(r#"schemaLocation="{}""#, filename);
result = result.replace(&old_double, &new_double);
let old_single = format!(r#"schemaLocation='{}'"#, uri);
let new_single = format!(r#"schemaLocation='{}'"#, filename);
result = result.replace(&old_single, &new_single);
}
let locations = extract_schema_locations(&result);
for location in locations {
if !location.contains('/') && !location.contains('\\') {
continue;
}
if let Ok(resolved) = resolve_uri(base_uri, &location) {
if let Some(filename) = uri_to_filename.get(&resolved) {
let old_double = format!(r#"schemaLocation="{}""#, location);
let new_double = format!(r#"schemaLocation="{}""#, filename);
result = result.replace(&old_double, &new_double);
let old_single = format!(r#"schemaLocation='{}'"#, location);
let new_single = format!(r#"schemaLocation='{}'"#, filename);
result = result.replace(&old_single, &new_single);
}
}
}
Ok(result.into_bytes())
}
/// Derives a local `.xsd` filename for `uri` that is not in `existing_filenames`.
///
/// The last path segment of the URI (after stripping `http://`, `https://` or
/// `file://`) is the base name. It falls back to `schema_<hash>.xsd` when there
/// is none, and `.xsd` is appended if missing.
///
/// Characters other than letters, digits, `.`, `-` and `_` are replaced with
/// `_`. The name is written back into `schemaLocation` attributes, where
/// characters such as `?`, `#`, `%`, `&` or quotes would break the XML or be
/// read as URI syntax rather than as part of the filename.
///
/// If the name is already taken, an 8-hex-digit hash of the URI is appended.
/// If that is taken too, a counter is added, so the result is never a member of
/// `existing_filenames`.
fn uri_to_safe_filename(
    uri: &str,
    existing_filenames: &std::collections::HashSet<String>,
) -> String {
    let without_protocol = uri
        .strip_prefix("http://")
        .or_else(|| uri.strip_prefix("https://"))
        .or_else(|| uri.strip_prefix("file://"))
        .unwrap_or(uri);
    let raw_name = Path::new(without_protocol)
        .file_name()
        .and_then(|n| n.to_str())
        .map(|s| s.to_string())
        .unwrap_or_else(|| format!("schema_{:x}.xsd", hash_uri(uri)));
    let sanitized: String = raw_name
        .chars()
        .map(|c| {
            if c.is_alphanumeric() || matches!(c, '.' | '-' | '_') {
                c
            } else {
                '_'
            }
        })
        .collect();
    let base_filename = if sanitized.ends_with(".xsd") {
        sanitized
    } else {
        format!("{}.xsd", sanitized)
    };
    if !existing_filenames.contains(&base_filename) {
        return base_filename;
    }
    let stem = base_filename.strip_suffix(".xsd").unwrap_or(&base_filename);
    // Truncating to 32 bits is intentional: it keeps generated names short.
    let hash_suffix = format!("{:08x}", hash_uri(uri) as u32);
    let hashed = format!("{}_{}.xsd", stem, hash_suffix);
    if !existing_filenames.contains(&hashed) {
        return hashed;
    }
    // Very unlikely, but never hand back a name that would overwrite another file.
    let mut counter: u32 = 1;
    loop {
        let candidate = format!("{}_{}_{}.xsd", stem, hash_suffix, counter);
        if !existing_filenames.contains(&candidate) {
            return candidate;
        }
        counter += 1;
    }
}
/// Hashes `uri` with a freshly constructed `DefaultHasher`.
///
/// Every `DefaultHasher::new()` starts from the same state, so equal URIs hash
/// equally within a build (the algorithm is not guaranteed across Rust releases).
fn hash_uri(uri: &str) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut state = DefaultHasher::new();
    Hash::hash(uri, &mut state);
    state.finish()
}
/// Resolves the schema reference `relative` against `base`, the URI of the
/// document that contains the reference.
///
/// - Absolute `http://`, `https://` and `file://` references are returned as-is.
/// - Against a `file://` base, the reference is joined to the base file's
///   directory and canonicalized when that succeeds (i.e. the target exists);
///   otherwise the joined path is used.
/// - Against an `http(s)://` base, network-path (`//host/…`), absolute-path
///   (`/…`) and relative references are supported. The base's query and
///   fragment are ignored, and dot segments are removed from the result.
/// - Any other base is treated as a plain path to a document, and the
///   reference is resolved against that document's directory.
fn resolve_uri(base: &str, relative: &str) -> Result<String> {
    if relative.starts_with("http://")
        || relative.starts_with("https://")
        || relative.starts_with("file://")
    {
        return Ok(relative.to_string());
    }
    if let Some(base_path) = base.strip_prefix("file://") {
        let base_dir = Path::new(base_path).parent().unwrap_or(Path::new("."));
        let resolved = base_dir.join(relative);
        let canonical = resolved.canonicalize().unwrap_or(resolved);
        return Ok(format!("file://{}", canonical.display()));
    }
    for scheme in ["http://", "https://"] {
        if let Some(after_scheme) = base.strip_prefix(scheme) {
            // The authority (host[:port]) ends at the first '/', '?' or '#'.
            let authority_len = after_scheme
                .find(|c: char| c == '/' || c == '?' || c == '#')
                .unwrap_or(after_scheme.len());
            let origin_end = scheme.len() + authority_len;
            let combined = if let Some(network_path) = relative.strip_prefix("//") {
                format!("{}{}", scheme, network_path)
            } else if relative.starts_with('/') {
                format!("{}{}", &base[..origin_end], relative)
            } else {
                // Directory of the base document: its path up to the last '/'.
                let path_and_rest = &base[origin_end..];
                let path_len = path_and_rest
                    .find(|c: char| c == '?' || c == '#')
                    .unwrap_or(path_and_rest.len());
                match path_and_rest[..path_len].rfind('/') {
                    Some(last_slash) => {
                        format!("{}{}", &base[..origin_end + last_slash + 1], relative)
                    }
                    None => format!("{}/{}", &base[..origin_end], relative),
                }
            };
            return Ok(normalize_url_path(&combined));
        }
    }
    // `base` names a document, not a directory: resolve against its parent.
    match base.rfind(|c: char| c == '/' || c == '\\') {
        Some(last_separator) => Ok(format!("{}{}", &base[..=last_separator], relative)),
        None => Ok(relative.to_string()),
    }
}
/// Removes `.`, `..` and empty segments from the path of an absolute URL.
///
/// Only the path component is normalized:
/// - The scheme and authority are kept verbatim.
/// - Any query (`?…`) or fragment (`#…`) is appended unchanged, so slashes or
///   dot segments inside it are not mistaken for path structure.
/// - A trailing slash (or a final `.`/`..` segment) is kept as a trailing slash.
/// - `..` segments that would climb above the root are dropped.
///
/// Strings without `://`, or without a path after the authority, are returned
/// unchanged.
fn normalize_url_path(url: &str) -> String {
    let authority_start = match url.find("://") {
        Some(pos) => pos + 3,
        None => return url.to_string(),
    };
    let after_scheme = &url[authority_start..];
    let path_start = match after_scheme.find(|c: char| c == '/' || c == '?' || c == '#') {
        Some(offset) => authority_start + offset,
        None => return url.to_string(),
    };
    let (origin, rest) = url.split_at(path_start);
    if !rest.starts_with('/') {
        // Only a query or fragment follows the authority; there is no path.
        return url.to_string();
    }
    let path_len = rest
        .find(|c: char| c == '?' || c == '#')
        .unwrap_or(rest.len());
    let (path, suffix) = rest.split_at(path_len);

    let mut segments: Vec<&str> = Vec::new();
    for segment in path.split('/') {
        match segment {
            "" | "." => {}
            ".." => {
                segments.pop();
            }
            s => segments.push(s),
        }
    }
    let ends_in_directory =
        path.ends_with('/') || path.ends_with("/.") || path.ends_with("/..");

    let mut normalized = String::with_capacity(url.len());
    normalized.push_str(origin);
    normalized.push('/');
    normalized.push_str(&segments.join("/"));
    if ends_in_directory && !segments.is_empty() {
        normalized.push('/');
    }
    normalized.push_str(suffix);
    normalized
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Builds a URI → local filename mapping with exactly one entry.
    fn mapping_of(uri: &str, filename: &str) -> HashMap<String, String> {
        std::iter::once((uri.to_string(), filename.to_string())).collect()
    }

    /// Decodes rewritten schema bytes so they can be checked as text.
    fn as_text(bytes: &[u8]) -> &str {
        std::str::from_utf8(bytes).expect("rewritten schema should be UTF-8")
    }

    #[test]
    fn test_uri_to_safe_filename() {
        let nothing_taken: HashSet<String> = HashSet::new();
        let cases = [
            ("http://example.com/schemas/types.xsd", "types.xsd"),
            ("https://schemas.opengis.net/gml/3.2.1/gml.xsd", "gml.xsd"),
            ("file:///path/to/schema.xsd", "schema.xsd"),
        ];
        for (uri, expected) in cases {
            assert_eq!(uri_to_safe_filename(uri, &nothing_taken), expected, "uri: {}", uri);
        }
    }

    #[test]
    fn test_uri_to_safe_filename_uniqueness() {
        let taken: HashSet<String> = std::iter::once("types.xsd".to_string()).collect();
        let name = uri_to_safe_filename("http://example.com/other/types.xsd", &taken);
        assert_ne!(name, "types.xsd");
        assert!(name.starts_with("types_"), "got {}", name);
        assert!(name.ends_with(".xsd"), "got {}", name);
    }

    #[test]
    fn test_extract_schema_locations() {
        let xsd = r#"
<xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema">
<xs:import namespace="http://example.com" schemaLocation="types.xsd"/>
<xs:include schemaLocation='common.xsd'/>
</xs:schema>
"#;
        let found = extract_schema_locations(xsd);
        assert_eq!(found.len(), 2);
        for wanted in ["types.xsd", "common.xsd"] {
            assert!(found.iter().any(|loc| loc == wanted), "missing {}", wanted);
        }
    }

    #[test]
    fn test_extract_schema_locations_skips_xsi() {
        let instance = r#"
<root xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://example.com http://example.com/schema.xsd">
</root>
"#;
        assert!(extract_schema_locations(instance).is_empty());
    }

    #[test]
    fn test_resolve_uri_absolute() {
        let resolved =
            resolve_uri("http://base.com/path/", "http://other.com/schema.xsd").unwrap();
        assert_eq!(resolved, "http://other.com/schema.xsd");
    }

    #[test]
    fn test_resolve_uri_relative_http() {
        assert_eq!(
            resolve_uri("http://example.com/schemas/main.xsd", "types.xsd").unwrap(),
            "http://example.com/schemas/types.xsd"
        );
    }

    #[test]
    fn test_rewrite_schema_locations() {
        let mapping = mapping_of("http://example.com/types.xsd", "types.xsd");
        let rewritten = rewrite_schema_locations(
            br#"<xs:import schemaLocation="http://example.com/types.xsd"/>"#,
            "http://example.com/main.xsd",
            &mapping,
        )
        .unwrap();
        assert!(as_text(&rewritten).contains(r#"schemaLocation="types.xsd""#));
    }

    #[test]
    fn test_rewrite_schema_locations_relative_path() {
        let mapping = mapping_of("http://example.com/types/common.xsd", "common.xsd");
        let rewritten = rewrite_schema_locations(
            br#"<xs:import schemaLocation="../types/common.xsd"/>"#,
            "http://example.com/schemas/main.xsd",
            &mapping,
        )
        .unwrap();
        assert!(as_text(&rewritten).contains(r#"schemaLocation="common.xsd""#));
    }
}