use std::collections::{HashMap, HashSet, VecDeque};
use crate::error::Result;
use crate::schema::fetcher::AsyncSchemaFetcher;
use super::super::parser::parse_xsd_ast;
use super::super::types::XsdSchema;
use super::common::resolve_uri;
/// Resolves a schema together with the schemas it imports and includes,
/// obtaining dependency documents through an [`AsyncSchemaFetcher`].
pub struct AsyncSchemaResolver<'a, F>
where
    F: AsyncSchemaFetcher,
{
    /// Borrowed fetcher used to download dependency documents.
    fetcher: &'a F,
    /// Parsed schemas, keyed by the URI they were loaded under.
    schemas: HashMap<String, XsdSchema>,
    /// URIs whose dependency lists are currently being processed.
    resolving: HashSet<String>,
}
impl<'a, F: AsyncSchemaFetcher> AsyncSchemaResolver<'a, F> {
pub fn new(fetcher: &'a F) -> Self {
Self {
fetcher,
schemas: HashMap::new(),
resolving: HashSet::new(),
}
}
pub async fn resolve_all(
&mut self,
entry_content: &[u8],
entry_uri: &str,
) -> Result<Vec<XsdSchema>> {
let entry_schema = parse_xsd_ast(entry_content)?;
self.schemas.insert(entry_uri.to_string(), entry_schema);
let mut queue: VecDeque<String> = VecDeque::new();
queue.push_back(entry_uri.to_string());
while let Some(current_uri) = queue.pop_front() {
if self.resolving.contains(¤t_uri) {
return Err(crate::schema::error::SchemaError::CircularDependency {
uri: current_uri,
}
.into());
}
self.resolving.insert(current_uri.clone());
let (imports, includes) = {
let schema = self.schemas.get(¤t_uri).ok_or_else(|| {
crate::schema::error::SchemaError::SchemaNotFound {
uri: current_uri.clone(),
}
})?;
(schema.imports.clone(), schema.includes.clone())
};
for import in imports {
if let Some(location) = &import.schema_location {
let resolved_uri = resolve_uri(¤t_uri, location)?;
if !self.schemas.contains_key(&resolved_uri) {
let content = self.fetch_schema(&resolved_uri).await?;
let schema = parse_xsd_ast(&content)?;
self.schemas.insert(resolved_uri.clone(), schema);
queue.push_back(resolved_uri);
}
}
}
for include in includes {
let resolved_uri = resolve_uri(¤t_uri, &include.schema_location)?;
if !self.schemas.contains_key(&resolved_uri) {
let content = self.fetch_schema(&resolved_uri).await?;
let schema = parse_xsd_ast(&content)?;
self.schemas.insert(resolved_uri.clone(), schema);
queue.push_back(resolved_uri);
}
}
self.resolving.remove(¤t_uri);
}
let mut result: Vec<XsdSchema> = Vec::new();
for (uri, schema) in &self.schemas {
if uri != entry_uri {
result.push(schema.clone());
}
}
if let Some(entry) = self.schemas.remove(entry_uri) {
result.push(entry);
}
Ok(result)
}
async fn fetch_schema(&self, uri: &str) -> Result<Vec<u8>> {
let result = self.fetcher.fetch(uri).await?;
Ok(result.content)
}
pub async fn resolve_entry(&mut self, entry_content: &[u8], entry_uri: &str) -> Result<()> {
if self.schemas.contains_key(entry_uri) {
return Ok(());
}
let entry_schema = parse_xsd_ast(entry_content)?;
self.schemas.insert(entry_uri.to_string(), entry_schema);
let mut queue: VecDeque<String> = VecDeque::new();
queue.push_back(entry_uri.to_string());
while let Some(current_uri) = queue.pop_front() {
if self.resolving.contains(¤t_uri) {
return Err(crate::schema::error::SchemaError::CircularDependency {
uri: current_uri,
}
.into());
}
self.resolving.insert(current_uri.clone());
let (imports, includes) = {
let schema = self.schemas.get(¤t_uri).ok_or_else(|| {
crate::schema::error::SchemaError::SchemaNotFound {
uri: current_uri.clone(),
}
})?;
(schema.imports.clone(), schema.includes.clone())
};
for import in imports {
if let Some(location) = &import.schema_location {
let resolved_uri = resolve_uri(¤t_uri, location)?;
if !self.schemas.contains_key(&resolved_uri) {
let content = self.fetch_schema(&resolved_uri).await?;
let schema = parse_xsd_ast(&content)?;
self.schemas.insert(resolved_uri.clone(), schema);
queue.push_back(resolved_uri);
}
}
}
for include in includes {
let resolved_uri = resolve_uri(¤t_uri, &include.schema_location)?;
if !self.schemas.contains_key(&resolved_uri) {
let content = self.fetch_schema(&resolved_uri).await?;
let schema = parse_xsd_ast(&content)?;
self.schemas.insert(resolved_uri.clone(), schema);
queue.push_back(resolved_uri);
}
}
self.resolving.remove(¤t_uri);
}
Ok(())
}
pub fn take_all_schemas(self) -> Vec<XsdSchema> {
self.schemas.into_values().collect()
}
pub fn into_schemas(self) -> HashMap<String, XsdSchema> {
self.schemas
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::fetcher::FetchResult;
    use parking_lot::RwLock;
    use std::collections::HashMap as StdHashMap;
    use std::sync::Arc;

    // In-memory fetcher: serves canned bodies by URL and fails for anything else.
    struct MockAsyncFetcher {
        responses: Arc<RwLock<StdHashMap<String, Vec<u8>>>>,
    }

    impl MockAsyncFetcher {
        // Starts with no canned responses.
        fn new() -> Self {
            let table = StdHashMap::new();
            Self {
                responses: Arc::new(RwLock::new(table)),
            }
        }

        // Registers `content` as the body served for `url`.
        fn add_response(&self, url: &str, content: &[u8]) {
            let mut table = self.responses.write();
            table.insert(url.to_string(), content.to_vec());
        }
    }

    #[async_trait::async_trait]
    impl AsyncSchemaFetcher for MockAsyncFetcher {
        async fn fetch(&self, url: &str) -> Result<FetchResult> {
            let responses = self.responses.read();
            match responses.get(url) {
                Some(body) => Ok(FetchResult {
                    content: body.clone(),
                    final_url: url.to_string(),
                    redirected: false,
                }),
                None => Err(crate::schema::fetcher::error::FetchError::RequestFailed {
                    url: url.to_string(),
                    message: "Not found".to_string(),
                }
                .into()),
            }
        }
    }

    // Runs `resolve_all` on a fresh resolver and unwraps the result.
    async fn resolve_with<F: AsyncSchemaFetcher>(
        fetcher: &F,
        entry_xsd: &str,
        entry_uri: &str,
    ) -> Vec<XsdSchema> {
        let mut resolver = AsyncSchemaResolver::new(fetcher);
        resolver
            .resolve_all(entry_xsd.as_bytes(), entry_uri)
            .await
            .unwrap()
    }

    // A schema without dependencies resolves to just itself.
    #[tokio::test]
    async fn test_async_resolve_simple() {
        let entry_doc = r#"<?xml version="1.0"?>
<xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema">
<xs:element name="test" type="xs:string"/>
</xs:schema>"#;

        let mock = MockAsyncFetcher::new();
        let resolved = resolve_with(&mock, entry_doc, "http://example.com/test.xsd").await;

        assert_eq!(resolved.len(), 1);
        assert_eq!(resolved[0].elements.len(), 1);
    }

    // An imported schema is fetched relative to the entry URI.
    #[tokio::test]
    async fn test_async_resolve_with_import() {
        let imported_doc = r#"<?xml version="1.0"?>
<xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema"
targetNamespace="http://example.com/types">
<xs:simpleType name="NameType">
<xs:restriction base="xs:string">
<xs:maxLength value="100"/>
</xs:restriction>
</xs:simpleType>
</xs:schema>"#;
        let entry_doc = r#"<?xml version="1.0"?>
<xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema"
xmlns:t="http://example.com/types"
targetNamespace="http://example.com/main">
<xs:import namespace="http://example.com/types" schemaLocation="types.xsd"/>
<xs:element name="person">
<xs:complexType>
<xs:sequence>
<xs:element name="name" type="t:NameType"/>
</xs:sequence>
</xs:complexType>
</xs:element>
</xs:schema>"#;

        let mock = MockAsyncFetcher::new();
        mock.add_response("http://example.com/types.xsd", imported_doc.as_bytes());
        let resolved = resolve_with(&mock, entry_doc, "http://example.com/main.xsd").await;

        assert_eq!(resolved.len(), 2);
    }

    // An included schema is fetched relative to the entry URI.
    #[tokio::test]
    async fn test_async_resolve_with_include() {
        let included_doc = r#"<?xml version="1.0"?>
<xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema">
<xs:simpleType name="IDType">
<xs:restriction base="xs:string">
<xs:pattern value="[A-Z]{2}[0-9]{4}"/>
</xs:restriction>
</xs:simpleType>
</xs:schema>"#;
        let entry_doc = r#"<?xml version="1.0"?>
<xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema">
<xs:include schemaLocation="common.xsd"/>
<xs:element name="item">
<xs:complexType>
<xs:sequence>
<xs:element name="id" type="IDType"/>
</xs:sequence>
</xs:complexType>
</xs:element>
</xs:schema>"#;

        let mock = MockAsyncFetcher::new();
        mock.add_response("http://example.com/common.xsd", included_doc.as_bytes());
        let resolved = resolve_with(&mock, entry_doc, "http://example.com/main.xsd").await;

        assert_eq!(resolved.len(), 2);
    }

    // The wrapped mock has no responses, so the import must come from the seeded cache.
    #[tokio::test]
    async fn test_async_resolve_uses_cache() {
        use crate::schema::fetcher::AsyncCachingFetcher;

        let seeded_doc = r#"<?xml version="1.0"?>
<xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema">
<xs:simpleType name="CachedType">
<xs:restriction base="xs:string"/>
</xs:simpleType>
</xs:schema>"#;
        let entry_doc = r#"<?xml version="1.0"?>
<xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema">
<xs:import schemaLocation="types.xsd"/>
<xs:element name="test" type="xs:string"/>
</xs:schema>"#;

        let caching = AsyncCachingFetcher::new(MockAsyncFetcher::new());
        caching.seed(
            "http://example.com/types.xsd",
            seeded_doc.as_bytes().to_vec(),
        );
        let resolved = resolve_with(&caching, entry_doc, "http://example.com/main.xsd").await;

        assert_eq!(resolved.len(), 2);
    }
}