use crate::validation::SchemaValidator;
use serde_json::Value;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
/// Thread-safe cache of compiled [`SchemaValidator`]s, deduplicated per schema.
///
/// Validators are handed out as `Arc`s, so callers that request the same schema
/// share one compiled instance.
pub struct SchemaRegistry {
/// Compiled validators keyed by the `serde_json` string serialization of their schema.
schemas: RwLock<HashMap<String, Arc<SchemaValidator>>>,
}
impl SchemaRegistry {
    /// Creates an empty registry with no compiled schemas.
    #[must_use]
    pub fn new() -> Self {
        Self {
            schemas: RwLock::new(HashMap::new()),
        }
    }

    /// Returns the validator for `schema`, compiling and caching it on first use.
    ///
    /// The cache key is the `serde_json` string serialization of `schema`, so any
    /// two values that serialize identically share one [`SchemaValidator`] and the
    /// returned `Arc`s are pointer-equal.
    ///
    /// Compilation runs without holding the lock. If several threads miss the
    /// cache for the same key at once, each may compile its own validator, but
    /// only the first one inserted is kept and every caller receives that one.
    ///
    /// # Errors
    ///
    /// Returns `Err` if `schema` cannot be serialized into a cache key, or if
    /// [`SchemaValidator::new`] fails to compile it.
    pub fn get_or_compile(&self, schema: &Value) -> Result<Arc<SchemaValidator>, String> {
        let key = serde_json::to_string(schema)
            .map_err(|e| format!("Failed to serialize schema: {e}"))?;

        // Fast path: the read guard is a temporary, released at the end of this statement.
        let cached = self.read_map().get(&key).map(Arc::clone);
        if let Some(validator) = cached {
            return Ok(validator);
        }

        // Slow path: compile outside any lock so other callers are not blocked on it.
        let compiled = Arc::new(SchemaValidator::new(schema.clone())?);

        // A racing thread may have inserted this key while we were compiling.
        // `or_insert` keeps that existing entry (dropping ours) and performs a
        // single lookup, unlike a separate `get` followed by `insert`.
        let mut schemas = self.write_map();
        let canonical = schemas.entry(key).or_insert(compiled);
        Ok(Arc::clone(canonical))
    }

    /// Returns a snapshot of every cached validator, in unspecified order.
    #[must_use]
    pub fn all_schemas(&self) -> Vec<Arc<SchemaValidator>> {
        self.read_map().values().cloned().collect()
    }

    /// Returns the number of distinct schemas currently cached.
    #[must_use]
    pub fn schema_count(&self) -> usize {
        self.read_map().len()
    }

    /// Takes the shared read lock, recovering the guard if the lock is poisoned.
    ///
    /// Ignoring poisoning is sound here: the methods in this impl only mutate the
    /// map through one `entry(..).or_insert(..)` call, so a panicking lock holder
    /// cannot leave it partially updated.
    fn read_map(
        &self,
    ) -> std::sync::RwLockReadGuard<'_, HashMap<String, Arc<SchemaValidator>>> {
        self.schemas
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Takes the exclusive write lock, recovering the guard if the lock is poisoned
    /// (for the same reason as `read_map`).
    fn write_map(
        &self,
    ) -> std::sync::RwLockWriteGuard<'_, HashMap<String, Arc<SchemaValidator>>> {
        self.schemas
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}
impl Default for SchemaRegistry {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Structurally identical schemas must resolve to one shared validator.
    #[test]
    fn test_schema_deduplication() {
        let registry = SchemaRegistry::new();
        let schema1 = json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"}
            }
        });
        let schema2 = json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"}
            }
        });
        let validator1 = registry.get_or_compile(&schema1).unwrap();
        let validator2 = registry.get_or_compile(&schema2).unwrap();
        assert!(Arc::ptr_eq(&validator1, &validator2));
        assert_eq!(registry.schema_count(), 1);
    }

    /// Different schemas must get separate cache entries.
    #[test]
    fn test_different_schemas() {
        let registry = SchemaRegistry::new();
        let schema1 = json!({
            "type": "string"
        });
        let schema2 = json!({
            "type": "integer"
        });
        let validator1 = registry.get_or_compile(&schema1).unwrap();
        let validator2 = registry.get_or_compile(&schema2).unwrap();
        assert!(!Arc::ptr_eq(&validator1, &validator2));
        assert_eq!(registry.schema_count(), 2);
    }

    /// `all_schemas` returns one validator per distinct cached schema.
    #[test]
    fn test_all_schemas() {
        let registry = SchemaRegistry::new();
        let schema1 = json!({"type": "string"});
        let schema2 = json!({"type": "integer"});
        registry.get_or_compile(&schema1).unwrap();
        registry.get_or_compile(&schema2).unwrap();
        let all = registry.all_schemas();
        assert_eq!(all.len(), 2);
    }

    /// Many threads racing on the same schema must all observe one canonical validator.
    #[test]
    fn test_concurrent_access() {
        use std::sync::Barrier;
        use std::thread;

        const THREADS: usize = 10;
        let registry = Arc::new(SchemaRegistry::new());
        let barrier = Arc::new(Barrier::new(THREADS));
        let schema = json!({
            "type": "object",
            "properties": {
                "id": {"type": "integer"}
            }
        });

        // Spawn every thread before joining any. Chaining `.map(spawn).map(join)` on a
        // lazy iterator would join each thread before the next is spawned, so nothing
        // would actually run concurrently.
        let handles: Vec<_> = (0..THREADS)
            .map(|_| {
                let registry = Arc::clone(&registry);
                let barrier = Arc::clone(&barrier);
                let schema = schema.clone();
                thread::spawn(move || {
                    // Release all threads together to maximise contention on the cache.
                    barrier.wait();
                    registry.get_or_compile(&schema).unwrap()
                })
            })
            .collect();
        let validators: Vec<_> = handles
            .into_iter()
            .map(|handle| handle.join().unwrap())
            .collect();

        let first = &validators[0];
        assert!(validators.iter().all(|v| Arc::ptr_eq(first, v)));
        assert_eq!(registry.schema_count(), 1);
    }
}