use crate::errors::{Error, Result};
use crate::manifest::calculate_schema_checksum;
use crate::types::SchemaType;
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
/// A pluggable source of schema documents for a single [`SchemaType`].
///
/// Implementors must supply the schema type, the spec version, generation and
/// validation. Hashing, serialization, the endpoint and the content type all
/// come with default implementations that may be overridden.
#[async_trait]
pub trait SchemaProvider: Send + Sync {
    /// Returns the kind of schema this provider is responsible for.
    fn schema_type(&self) -> SchemaType;

    /// Returns the version of the specification the provider targets
    /// (the tests in this file use `"3.1.0"`).
    fn spec_version(&self) -> String;

    /// Asynchronously produces the schema document for `app` as a JSON value.
    async fn generate(&self, app: &dyn Application) -> Result<serde_json::Value>;

    /// Checks `schema`, returning an error when the implementor rejects it.
    fn validate(&self, schema: &serde_json::Value) -> Result<()>;

    /// Computes a checksum string for `schema`.
    ///
    /// The default delegates to `calculate_schema_checksum`.
    fn hash(&self, schema: &serde_json::Value) -> Result<String> {
        calculate_schema_checksum(schema)
    }

    /// Encodes `schema` as bytes.
    ///
    /// The default emits compact JSON via `serde_json::to_vec`.
    ///
    /// # Errors
    ///
    /// Returns the crate error converted from the `serde_json` failure.
    fn serialize(&self, schema: &serde_json::Value) -> Result<Vec<u8>> {
        let bytes = serde_json::to_vec(schema)?;
        Ok(bytes)
    }

    /// Returns the endpoint associated with the schema, if any.
    ///
    /// The default is `None`. Presumably this is the path the schema is
    /// served from (the tests use `"/openapi.json"`) — confirm with callers.
    fn endpoint(&self) -> Option<String> {
        None
    }

    /// Returns the content type of the serialized schema.
    ///
    /// The default is `"application/json"`.
    fn content_type(&self) -> String {
        String::from("application/json")
    }
}
/// The application a [`SchemaProvider`] generates a schema for.
///
/// The trait is object safe; providers receive it as `&dyn Application`.
pub trait Application: Send + Sync {
    /// Returns the application's name.
    fn name(&self) -> &str;

    /// Returns the application's version string.
    fn version(&self) -> &str;

    /// Returns the application's routes as a type-erased value.
    ///
    /// The concrete type behind the `Any` is not fixed by this trait;
    /// callers presumably downcast to the type they expect.
    fn routes(&self) -> Box<dyn std::any::Any + Send + Sync>;
}
/// Plain holder for the static metadata of a schema provider.
///
/// It stores values only; it does not implement `SchemaProvider` itself.
pub struct BaseSchemaProvider {
    /// Kind of schema the provider handles.
    schema_type: SchemaType,
    /// Version of the specification, e.g. `"3.1.0"`.
    spec_version: String,
    /// Content type of the serialized schema, e.g. `"application/json"`.
    content_type: String,
    /// Optional endpoint associated with the schema, e.g. `"/openapi.json"`.
    endpoint: Option<String>,
}
impl BaseSchemaProvider {
    /// Builds a metadata holder from its four components.
    ///
    /// `spec_version` and `content_type` accept anything convertible into a
    /// `String`; `endpoint` is stored as given.
    pub fn new(
        schema_type: SchemaType,
        spec_version: impl Into<String>,
        content_type: impl Into<String>,
        endpoint: Option<String>,
    ) -> Self {
        let spec_version: String = spec_version.into();
        let content_type: String = content_type.into();
        Self {
            schema_type,
            spec_version,
            content_type,
            endpoint,
        }
    }

    /// Returns the stored schema type (copied out).
    pub fn get_schema_type(&self) -> SchemaType {
        self.schema_type
    }

    /// Borrows the stored spec version.
    pub fn get_spec_version(&self) -> &str {
        self.spec_version.as_str()
    }

    /// Borrows the stored content type.
    pub fn get_content_type(&self) -> &str {
        self.content_type.as_str()
    }

    /// Borrows the stored endpoint, or `None` when none was supplied.
    pub fn get_endpoint(&self) -> Option<&str> {
        self.endpoint.as_deref()
    }
}
/// Registry mapping each [`SchemaType`] to the provider registered for it.
///
/// The map lives behind `Arc<RwLock<..>>`, so cloning a registry yields a
/// second handle onto the same underlying map rather than a copy of it.
#[derive(Clone)]
pub struct ProviderRegistry {
    /// Shared, lock-protected map from schema type to its provider.
    providers: Arc<RwLock<HashMap<SchemaType, Arc<dyn SchemaProvider>>>>,
}
impl ProviderRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            providers: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Acquires the map for reading, recovering the guard if the lock is poisoned.
    ///
    /// A poisoned lock only records that some thread panicked while holding
    /// it. Every critical section in this type is a single `HashMap` call, so
    /// the map is still structurally valid and it is safe to keep using it
    /// instead of turning one panic into a panic on every later call.
    fn read_guard(
        &self,
    ) -> std::sync::RwLockReadGuard<'_, HashMap<SchemaType, Arc<dyn SchemaProvider>>> {
        self.providers
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Acquires the map for writing, recovering the guard if the lock is poisoned.
    ///
    /// See [`Self::read_guard`] for why ignoring the poison flag is sound here.
    fn write_guard(
        &self,
    ) -> std::sync::RwLockWriteGuard<'_, HashMap<SchemaType, Arc<dyn SchemaProvider>>> {
        self.providers
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Registers `provider` under the type reported by its `schema_type()`.
    ///
    /// A provider previously registered for the same type is replaced.
    pub fn register(&self, provider: Arc<dyn SchemaProvider>) {
        // Ask the provider for its key before taking the lock so that
        // implementor code never runs inside the critical section.
        let schema_type = provider.schema_type();
        self.write_guard().insert(schema_type, provider);
    }

    /// Returns a shared handle to the provider registered for `schema_type`,
    /// or `None` when there is none.
    pub fn get(&self, schema_type: SchemaType) -> Option<Arc<dyn SchemaProvider>> {
        self.read_guard().get(&schema_type).cloned()
    }

    /// Reports whether a provider is registered for `schema_type`.
    pub fn has(&self, schema_type: SchemaType) -> bool {
        self.read_guard().contains_key(&schema_type)
    }

    /// Returns the schema types that currently have a provider.
    ///
    /// The order is that of `HashMap` iteration and therefore unspecified.
    pub fn list(&self) -> Vec<SchemaType> {
        self.read_guard().keys().copied().collect()
    }

    /// Removes the provider registered for `schema_type`.
    ///
    /// Returns `true` when a provider was removed and `false` when none was
    /// registered.
    pub fn unregister(&self, schema_type: SchemaType) -> bool {
        self.write_guard().remove(&schema_type).is_some()
    }

    /// Removes every registered provider.
    pub fn clear(&self) {
        self.write_guard().clear();
    }
}
impl Default for ProviderRegistry {
fn default() -> Self {
Self::new()
}
}
/// Process-wide provider registry backing the free functions below.
///
/// `once_cell::sync::Lazy` builds it with `ProviderRegistry::new` the first
/// time it is dereferenced.
static GLOBAL_REGISTRY: once_cell::sync::Lazy<ProviderRegistry> =
once_cell::sync::Lazy::new(ProviderRegistry::new);
pub fn register_provider(provider: Arc<dyn SchemaProvider>) {
GLOBAL_REGISTRY.register(provider);
}
pub fn get_provider(schema_type: SchemaType) -> Option<Arc<dyn SchemaProvider>> {
GLOBAL_REGISTRY.get(schema_type)
}
pub fn has_provider(schema_type: SchemaType) -> bool {
GLOBAL_REGISTRY.has(schema_type)
}
pub fn list_providers() -> Vec<SchemaType> {
GLOBAL_REGISTRY.list()
}
pub fn unregister_provider(schema_type: SchemaType) -> bool {
GLOBAL_REGISTRY.unregister(schema_type)
}
pub fn clear_providers() {
GLOBAL_REGISTRY.clear();
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal provider whose metadata comes from a `BaseSchemaProvider`.
    struct TestProvider {
        base: BaseSchemaProvider,
    }

    #[async_trait]
    impl SchemaProvider for TestProvider {
        fn schema_type(&self) -> SchemaType {
            self.base.get_schema_type()
        }

        async fn generate(&self, _app: &dyn Application) -> Result<serde_json::Value> {
            Ok(serde_json::json!({"test": "schema"}))
        }

        fn validate(&self, _schema: &serde_json::Value) -> Result<()> {
            Ok(())
        }

        fn spec_version(&self) -> String {
            self.base.get_spec_version().to_string()
        }
    }

    /// Builds a type-erased test provider for `schema_type` and `spec_version`.
    fn make_provider(schema_type: SchemaType, spec_version: &str) -> Arc<dyn SchemaProvider> {
        Arc::new(TestProvider {
            base: BaseSchemaProvider::new(schema_type, spec_version, "application/json", None),
        })
    }

    #[test]
    fn test_provider_registry() {
        let registry = ProviderRegistry::new();
        registry.register(make_provider(SchemaType::OpenAPI, "3.1.0"));

        assert!(registry.has(SchemaType::OpenAPI));
        assert!(!registry.has(SchemaType::AsyncAPI));

        // The lookup must hand back the provider that was registered.
        let retrieved = registry
            .get(SchemaType::OpenAPI)
            .expect("provider was registered for OpenAPI");
        assert_eq!(retrieved.schema_type(), SchemaType::OpenAPI);
        assert_eq!(retrieved.spec_version(), "3.1.0");
        assert!(registry.get(SchemaType::AsyncAPI).is_none());

        assert_eq!(registry.list(), vec![SchemaType::OpenAPI]);

        // `unregister` reports whether anything was actually removed.
        assert!(registry.unregister(SchemaType::OpenAPI));
        assert!(!registry.has(SchemaType::OpenAPI));
        assert!(!registry.unregister(SchemaType::OpenAPI));
    }

    #[test]
    fn test_register_replaces_existing_provider() {
        let registry = ProviderRegistry::new();
        registry.register(make_provider(SchemaType::OpenAPI, "3.0.3"));
        registry.register(make_provider(SchemaType::OpenAPI, "3.1.0"));

        assert_eq!(registry.list().len(), 1);
        let current = registry
            .get(SchemaType::OpenAPI)
            .expect("provider was registered for OpenAPI");
        assert_eq!(current.spec_version(), "3.1.0");
    }

    #[test]
    fn test_clear_removes_all_providers() {
        let registry = ProviderRegistry::new();
        registry.register(make_provider(SchemaType::OpenAPI, "3.1.0"));
        registry.register(make_provider(SchemaType::AsyncAPI, "2.6.0"));
        assert_eq!(registry.list().len(), 2);

        // A clone is a second handle onto the same map, so clearing through
        // it must empty the original as well.
        registry.clone().clear();
        assert!(registry.list().is_empty());
        assert!(!registry.has(SchemaType::OpenAPI));
        assert!(!registry.has(SchemaType::AsyncAPI));
    }

    #[test]
    fn test_base_provider() {
        let base = BaseSchemaProvider::new(
            SchemaType::OpenAPI,
            "3.1.0",
            "application/json",
            Some("/openapi.json".to_string()),
        );
        assert_eq!(base.get_schema_type(), SchemaType::OpenAPI);
        assert_eq!(base.get_spec_version(), "3.1.0");
        assert_eq!(base.get_content_type(), "application/json");
        assert_eq!(base.get_endpoint(), Some("/openapi.json"));

        // Without an endpoint the accessor yields `None`.
        let bare = BaseSchemaProvider::new(SchemaType::AsyncAPI, "2.6.0", "application/json", None);
        assert_eq!(bare.get_endpoint(), None);
    }
}