use crate::error::{SchemaError, SchemaResult};
use crate::types::SchemaDefinition;
use crate::Schema;
use indexmap::IndexMap;
use std::collections::HashSet;
use std::fs;
use std::path::Path;
/// A collection of named `SchemaDefinition`s together with the configuration used
/// when registering and exporting them.
///
/// Definitions are kept in registration order (backed by `IndexMap`), and that is the
/// order in which lookups by iteration and the per-definition exporters visit them.
#[derive(Debug, Default)]
pub struct SchemaRegistry {
/// Registered definitions keyed by type name, in insertion order.
definitions: IndexMap<String, SchemaDefinition>,
/// Registry-wide settings (dependency handling flag, namespace, version, metadata).
config: RegistryConfig,
}
/// Settings that control how a `SchemaRegistry` registers types and which metadata
/// accompanies the schemas it exports.
#[derive(Debug, Clone)]
pub struct RegistryConfig {
    /// Intended to make type registration also register the type's dependencies.
    ///
    /// NOTE(review): `SchemaRegistry::register_type` does not currently act on this
    /// flag; dependencies must be registered explicitly.
    pub auto_register_deps: bool,
    /// Optional namespace for exported schemas (set by `SchemaRegistry::with_namespace`).
    pub namespace: Option<String>,
    /// Optional version string (set by `SchemaRegistry::with_version`).
    pub version: Option<String>,
    /// Whether deprecated items should be included in output.
    ///
    /// NOTE(review): not read in this module — presumably consumed by the format
    /// generators; confirm there.
    pub include_deprecated: bool,
    /// Optional human-readable title (set by `SchemaRegistry::with_title`).
    pub title: Option<String>,
    /// Optional human-readable description (set by `SchemaRegistry::with_description`).
    pub description: Option<String>,
}

impl Default for RegistryConfig {
    /// Both boolean flags start enabled; every optional metadata field starts unset.
    fn default() -> Self {
        let enabled = true;
        Self {
            auto_register_deps: enabled,
            include_deprecated: enabled,
            namespace: None,
            version: None,
            title: None,
            description: None,
        }
    }
}
impl SchemaRegistry {
    /// Creates an empty registry with the default [`RegistryConfig`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates an empty registry that uses `config`.
    pub fn with_config(config: RegistryConfig) -> Self {
        Self {
            definitions: IndexMap::new(),
            config,
        }
    }

    /// Returns the registry's configuration.
    pub fn config(&self) -> &RegistryConfig {
        &self.config
    }

    /// Returns the registry's configuration for in-place modification.
    pub fn config_mut(&mut self) -> &mut RegistryConfig {
        &mut self.config
    }

    /// Builder: sets the namespace stored in the configuration.
    pub fn with_namespace(mut self, namespace: impl Into<String>) -> Self {
        self.config.namespace = Some(namespace.into());
        self
    }

    /// Builder: sets the version stored in the configuration.
    pub fn with_version(mut self, version: impl Into<String>) -> Self {
        self.config.version = Some(version.into());
        self
    }

    /// Builder: sets the title stored in the configuration.
    pub fn with_title(mut self, title: impl Into<String>) -> Self {
        self.config.title = Some(title.into());
        self
    }

    /// Builder: sets the description stored in the configuration.
    pub fn with_description(mut self, description: impl Into<String>) -> Self {
        self.config.description = Some(description.into());
        self
    }

    /// Builder: registers the schema of `T`; see [`Self::register_type`].
    pub fn register<T: Schema>(mut self) -> Self {
        self.register_type::<T>();
        self
    }

    /// Registers the definition returned by `T::schema_definition()` under its name.
    ///
    /// The first registration of a name wins: if a definition with the same name is
    /// already present it is kept and this call has no effect. (This differs from
    /// [`Self::register_definition_mut`], which replaces.)
    ///
    /// Dependencies listed by the definition are not registered automatically; see the
    /// note in the body. Call [`Self::validate`] to detect unregistered dependencies.
    pub fn register_type<T: Schema>(&mut self) -> &mut Self {
        let definition = T::schema_definition();
        let name = definition.name.to_string();
        // NOTE(review): `config.auto_register_deps` is not acted upon here. Dependencies
        // are only known by name and a `SchemaDefinition` cannot be produced from a name
        // alone, so callers must register dependent types themselves.
        self.definitions.entry(name).or_insert(definition);
        self
    }

    /// Builder: adds `definition`, replacing any existing definition with the same name.
    pub fn register_definition(mut self, definition: SchemaDefinition) -> Self {
        self.register_definition_mut(definition);
        self
    }

    /// Adds `definition`, replacing any existing definition with the same name.
    ///
    /// A replaced entry keeps its original position in iteration order.
    pub fn register_definition_mut(&mut self, definition: SchemaDefinition) -> &mut Self {
        self.definitions
            .insert(definition.name.to_string(), definition);
        self
    }

    /// Returns `true` if a definition named `name` is registered.
    pub fn contains(&self, name: &str) -> bool {
        self.definitions.contains_key(name)
    }

    /// Returns the number of registered definitions.
    pub fn len(&self) -> usize {
        self.definitions.len()
    }

    /// Returns `true` if no definitions are registered.
    pub fn is_empty(&self) -> bool {
        self.definitions.is_empty()
    }

    /// Iterates over registered type names in registration order.
    pub fn type_names(&self) -> impl Iterator<Item = &str> {
        self.definitions.keys().map(|s| s.as_str())
    }

    /// Iterates over `(name, definition)` pairs in registration order.
    pub fn definitions(&self) -> impl Iterator<Item = (&str, &SchemaDefinition)> {
        self.definitions.iter().map(|(k, v)| (k.as_str(), v))
    }

    /// Looks up a definition by name.
    pub fn get(&self, name: &str) -> Option<&SchemaDefinition> {
        self.definitions.get(name)
    }

    /// Looks up a definition by name for in-place modification.
    pub fn get_mut(&mut self, name: &str) -> Option<&mut SchemaDefinition> {
        self.definitions.get_mut(name)
    }

    /// Removes and returns the definition named `name`.
    ///
    /// The relative order of the remaining definitions is preserved (`shift_remove`).
    pub fn remove(&mut self, name: &str) -> Option<SchemaDefinition> {
        self.definitions.shift_remove(name)
    }

    /// Checks that the registry is self-consistent.
    ///
    /// # Errors
    ///
    /// Returns a custom error listing every `(type, dependency)` pair whose dependency is
    /// not registered. If all dependencies are present, it returns a circular-reference
    /// error when the dependency graph contains a cycle.
    pub fn validate(&self) -> SchemaResult<()> {
        let mut missing = Vec::new();
        for (name, def) in &self.definitions {
            for dep in &def.dependencies {
                if !self.definitions.contains_key(dep) {
                    missing.push((name.clone(), dep.clone()));
                }
            }
        }
        if !missing.is_empty() {
            let msg = missing
                .iter()
                .map(|(from, to)| format!("'{}' depends on unregistered type '{}'", from, to))
                .collect::<Vec<_>>()
                .join("; ");
            return Err(SchemaError::custom(format!(
                "Missing dependencies: {}",
                msg
            )));
        }
        self.check_circular_dependencies()?;
        Ok(())
    }

    /// Fails with a circular-reference error if any dependency cycle exists.
    fn check_circular_dependencies(&self) -> SchemaResult<()> {
        let mut visited = HashSet::new();
        let mut stack = Vec::new();
        for name in self.definitions.keys() {
            if !visited.contains(name) {
                self.dfs_cycle_check(name, &mut visited, &mut stack)?;
            }
        }
        Ok(())
    }

    /// Depth-first search from `name`.
    ///
    /// `visited` holds every node already explored. `stack` holds the current DFS path, so
    /// a dependency found on it closes a cycle. The reported cycle starts and ends at that
    /// dependency, e.g. `[A, B, A]`.
    fn dfs_cycle_check(
        &self,
        name: &str,
        visited: &mut HashSet<String>,
        stack: &mut Vec<String>,
    ) -> SchemaResult<()> {
        visited.insert(name.to_string());
        stack.push(name.to_string());
        if let Some(def) = self.definitions.get(name) {
            for dep in &def.dependencies {
                if let Some(cycle_start) = stack.iter().position(|s| s == dep) {
                    // `dep` is an ancestor on the current path: we have closed a cycle.
                    let mut cycle: Vec<_> = stack[cycle_start..].to_vec();
                    cycle.push(dep.clone());
                    return Err(SchemaError::circular_reference(cycle));
                }
                if !visited.contains(dep) {
                    self.dfs_cycle_check(dep, visited, stack)?;
                }
            }
        }
        stack.pop();
        Ok(())
    }

    /// Returns the registered definitions ordered so that each appears after all of its
    /// registered dependencies.
    ///
    /// Roots are visited in registration order. Dependencies that are not registered are
    /// skipped (use [`Self::validate`] to detect them).
    ///
    /// # Errors
    ///
    /// Returns a circular-reference error carrying the full cycle path, e.g. `[A, B, A]`.
    pub fn topological_sort(&self) -> SchemaResult<Vec<&SchemaDefinition>> {
        let mut result = Vec::new();
        let mut visited = HashSet::new();
        let mut path = Vec::new();
        for name in self.definitions.keys() {
            if !visited.contains(name) {
                self.topo_visit(name, &mut visited, &mut path, &mut result)?;
            }
        }
        Ok(result)
    }

    /// Post-order DFS used by [`Self::topological_sort`].
    ///
    /// `visited` holds fully processed nodes. `path` is the chain of nodes currently being
    /// expanded; meeting one of them again means a cycle.
    fn topo_visit<'a>(
        &'a self,
        name: &str,
        visited: &mut HashSet<String>,
        path: &mut Vec<String>,
        result: &mut Vec<&'a SchemaDefinition>,
    ) -> SchemaResult<()> {
        if let Some(cycle_start) = path.iter().position(|s| s == name) {
            let mut cycle = path[cycle_start..].to_vec();
            cycle.push(name.to_string());
            return Err(SchemaError::circular_reference(cycle));
        }
        if visited.contains(name) {
            return Ok(());
        }
        path.push(name.to_string());
        if let Some(def) = self.definitions.get(name) {
            for dep in &def.dependencies {
                self.topo_visit(dep, visited, path, result)?;
            }
            // Emitted only after all of its dependencies (post-order).
            result.push(def);
        }
        path.pop();
        visited.insert(name.to_string());
        Ok(())
    }

    /// Writes JSON Schema output to `path`.
    ///
    /// If `path` has an extension it is treated as a file and receives the bundle from
    /// [`Self::generate_json_schema_bundle`]. Otherwise it is treated as a directory: it is
    /// created if needed and each definition is written to `<name>.json`.
    ///
    /// # Errors
    ///
    /// Propagates filesystem errors.
    #[cfg(feature = "json-schema")]
    pub fn export_json_schema(&self, path: impl AsRef<Path>) -> SchemaResult<()> {
        let path = path.as_ref();
        if path.extension().is_some() {
            let content = self.generate_json_schema_bundle();
            fs::write(path, content)?;
        } else {
            fs::create_dir_all(path)?;
            for (name, def) in &self.definitions {
                let content = crate::formats::json_schema::generate(def);
                let file_path = path.join(format!("{}.json", name));
                fs::write(file_path, content)?;
            }
        }
        Ok(())
    }

    /// Renders all definitions as a single JSON Schema bundle.
    #[cfg(feature = "json-schema")]
    pub fn generate_json_schema_bundle(&self) -> String {
        crate::formats::json_schema::generate_bundle(self)
    }

    /// Writes the output of [`Self::generate_openapi`] to `path`.
    ///
    /// # Errors
    ///
    /// Propagates filesystem errors. The parent directory must already exist.
    #[cfg(feature = "openapi")]
    pub fn export_openapi(&self, path: impl AsRef<Path>) -> SchemaResult<()> {
        let content = self.generate_openapi();
        fs::write(path.as_ref(), content)?;
        Ok(())
    }

    /// Renders all definitions as an OpenAPI document.
    #[cfg(feature = "openapi")]
    pub fn generate_openapi(&self) -> String {
        crate::formats::openapi::generate_bundle(self)
    }

    /// Writes the output of [`Self::generate_graphql`] to `path`.
    ///
    /// # Errors
    ///
    /// Propagates filesystem errors. The parent directory must already exist.
    #[cfg(feature = "graphql")]
    pub fn export_graphql(&self, path: impl AsRef<Path>) -> SchemaResult<()> {
        let content = self.generate_graphql();
        fs::write(path.as_ref(), content)?;
        Ok(())
    }

    /// Renders all definitions as a GraphQL schema.
    #[cfg(feature = "graphql")]
    pub fn generate_graphql(&self) -> String {
        crate::formats::graphql::generate_bundle(self)
    }

    /// Writes the output of [`Self::generate_proto`] to `path`.
    ///
    /// # Errors
    ///
    /// Propagates filesystem errors. The parent directory must already exist.
    #[cfg(feature = "protobuf")]
    pub fn export_proto(&self, path: impl AsRef<Path>) -> SchemaResult<()> {
        let content = self.generate_proto();
        fs::write(path.as_ref(), content)?;
        Ok(())
    }

    /// Renders all definitions as a Protocol Buffers schema.
    #[cfg(feature = "protobuf")]
    pub fn generate_proto(&self) -> String {
        crate::formats::protobuf::generate_bundle(self)
    }

    /// Writes the output of [`Self::generate_typescript`] to `path`.
    ///
    /// # Errors
    ///
    /// Propagates filesystem errors. The parent directory must already exist.
    #[cfg(feature = "typescript")]
    pub fn export_typescript(&self, path: impl AsRef<Path>) -> SchemaResult<()> {
        let content = self.generate_typescript();
        fs::write(path.as_ref(), content)?;
        Ok(())
    }

    /// Renders all definitions as TypeScript type declarations.
    #[cfg(feature = "typescript")]
    pub fn generate_typescript(&self) -> String {
        crate::formats::typescript::generate_bundle(self)
    }

    /// Writes Avro schema output to `path`.
    ///
    /// If `path` has an extension it is treated as a file and receives the bundle from
    /// [`Self::generate_avro_bundle`]. Otherwise it is treated as a directory: it is created
    /// if needed and each definition is written to `<name>.avsc`.
    ///
    /// # Errors
    ///
    /// Propagates filesystem errors.
    #[cfg(feature = "avro")]
    pub fn export_avro(&self, path: impl AsRef<Path>) -> SchemaResult<()> {
        let path = path.as_ref();
        if path.extension().is_some() {
            let content = self.generate_avro_bundle();
            fs::write(path, content)?;
        } else {
            fs::create_dir_all(path)?;
            for (name, def) in &self.definitions {
                let content = crate::formats::avro::generate(def);
                let file_path = path.join(format!("{}.avsc", name));
                fs::write(file_path, content)?;
            }
        }
        Ok(())
    }

    /// Renders all definitions as a single Avro schema bundle.
    #[cfg(feature = "avro")]
    pub fn generate_avro_bundle(&self) -> String {
        crate::formats::avro::generate_bundle(self)
    }

    /// Exports every enabled format under `base_path`, creating the directory if needed.
    ///
    /// The layout is:
    /// - `json-schema/` (directory)
    /// - `openapi.yaml`
    /// - `schema.graphql`
    /// - `schema.proto`
    /// - `types.ts`
    /// - `avro/` (directory)
    ///
    /// Each entry is written only when its cargo feature is enabled.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first filesystem error.
    pub fn export_all(&self, base_path: impl AsRef<Path>) -> SchemaResult<()> {
        let base = base_path.as_ref();
        fs::create_dir_all(base)?;
        #[cfg(feature = "json-schema")]
        {
            let json_path = base.join("json-schema");
            self.export_json_schema(&json_path)?;
        }
        #[cfg(feature = "openapi")]
        {
            let openapi_path = base.join("openapi.yaml");
            self.export_openapi(&openapi_path)?;
        }
        #[cfg(feature = "graphql")]
        {
            let graphql_path = base.join("schema.graphql");
            self.export_graphql(&graphql_path)?;
        }
        #[cfg(feature = "protobuf")]
        {
            let proto_path = base.join("schema.proto");
            self.export_proto(&proto_path)?;
        }
        #[cfg(feature = "typescript")]
        {
            let ts_path = base.join("types.ts");
            self.export_typescript(&ts_path)?;
        }
        #[cfg(feature = "avro")]
        {
            let avro_path = base.join("avro");
            self.export_avro(&avro_path)?;
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{PrimitiveType, SchemaType, StructDefinition, StructField};

    /// Builds a minimal struct definition named `name` with a single `id: u64` field.
    fn create_test_definition(name: &str) -> SchemaDefinition {
        SchemaDefinition::new(
            name.to_string(),
            SchemaType::Struct(
                StructDefinition::new().with_field(
                    "id",
                    StructField::new(SchemaType::Primitive(PrimitiveType::U64), "id"),
                ),
            ),
        )
    }

    #[test]
    fn test_registry_basic() {
        let mut registry = SchemaRegistry::new();
        let def = create_test_definition("User");
        registry.register_definition_mut(def);
        assert!(registry.contains("User"));
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_registry_with_namespace() {
        let registry = SchemaRegistry::new()
            .with_namespace("com.example")
            .with_version("1.0.0")
            .with_title("My API")
            .with_description("API schema definitions");
        assert_eq!(registry.config().namespace, Some("com.example".to_string()));
        assert_eq!(registry.config().version, Some("1.0.0".to_string()));
        assert_eq!(registry.config().title, Some("My API".to_string()));
        assert_eq!(
            registry.config().description,
            Some("API schema definitions".to_string())
        );
    }

    #[test]
    fn test_registry_validate_missing_dep() {
        let mut registry = SchemaRegistry::new();
        let def = create_test_definition("User").with_dependency("Status");
        registry.register_definition_mut(def);
        let result = registry.validate();
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Status"));
    }

    #[test]
    fn test_registry_validate_success() {
        let mut registry = SchemaRegistry::new();
        let user = create_test_definition("User");
        let status = create_test_definition("Status");
        registry.register_definition_mut(user);
        registry.register_definition_mut(status);
        let result = registry.validate();
        assert!(result.is_ok());
    }

    #[test]
    fn test_registry_topological_sort() {
        let mut registry = SchemaRegistry::new();
        let user = create_test_definition("User").with_dependency("Status");
        let status = create_test_definition("Status");
        registry.register_definition_mut(user);
        registry.register_definition_mut(status);
        let sorted = registry.topological_sort().unwrap();
        let names: Vec<_> = sorted.iter().map(|d| d.name.as_ref()).collect();
        let status_idx = names.iter().position(|&n| n == "Status").unwrap();
        let user_idx = names.iter().position(|&n| n == "User").unwrap();
        assert!(status_idx < user_idx);
    }

    /// Two definitions that depend on each other must be rejected by both `validate`
    /// and `topological_sort`.
    #[test]
    fn test_registry_detects_mutual_cycle() {
        let mut registry = SchemaRegistry::new();
        registry.register_definition_mut(create_test_definition("A").with_dependency("B"));
        registry.register_definition_mut(create_test_definition("B").with_dependency("A"));
        assert!(registry.validate().is_err());
        assert!(registry.topological_sort().is_err());
    }

    /// A definition depending on itself is the smallest possible cycle.
    #[test]
    fn test_registry_detects_self_cycle() {
        let mut registry = SchemaRegistry::new();
        registry.register_definition_mut(create_test_definition("Node").with_dependency("Node"));
        assert!(registry.validate().is_err());
        assert!(registry.topological_sort().is_err());
    }

    /// `register_definition_mut` replaces an existing definition with the same name.
    #[test]
    fn test_register_definition_replaces_existing() {
        let mut registry = SchemaRegistry::new();
        registry.register_definition_mut(create_test_definition("User"));
        registry.register_definition_mut(create_test_definition("User").with_dependency("Status"));
        assert_eq!(registry.len(), 1);
        // The replacement's unregistered dependency proves the second definition won.
        assert!(registry.validate().is_err());
    }

    /// `remove` returns the removed definition once and keeps the remaining order.
    #[test]
    fn test_remove_preserves_order() {
        let mut registry = SchemaRegistry::new();
        for name in ["A", "B", "C"] {
            registry.register_definition_mut(create_test_definition(name));
        }
        assert!(registry.remove("B").is_some());
        assert!(registry.remove("B").is_none());
        let names: Vec<_> = registry.type_names().collect();
        assert_eq!(names, ["A", "C"]);
        assert!(!registry.is_empty());
    }
}