1use crate::error::{LinkError, Result};
2use crate::protocol::{ComponentId, FieldId, FieldType};
3use ahash::AHashMap;
4use serde::{Deserialize, Serialize};
5use std::sync::{Arc, RwLock};
6
/// Version number attached to a [`ComponentSchema`].
///
/// [`SchemaRegistry::register`] only accepts a schema whose version is strictly
/// greater than the version already stored for the same component.
pub type SchemaVersion = u32;
8
/// Describes one component type: its identifier, schema version, and the
/// fields it carries.
///
/// Built with [`ComponentSchema::new`] plus the `with_*` builder methods and
/// stored in a [`SchemaRegistry`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComponentSchema {
    /// Identifier of the component; the key under which [`SchemaRegistry`]
    /// stores this schema.
    pub component_id: ComponentId,
    /// Version of this schema.
    pub version: SchemaVersion,
    /// Field definitions, in the order they were added.
    pub fields: Vec<FieldSchema>,
    /// Optional human-readable description.
    pub description: Option<String>,
}
16
17impl ComponentSchema {
18 pub fn new(component_id: ComponentId, version: SchemaVersion) -> Self {
19 Self {
20 component_id,
21 version,
22 fields: Vec::new(),
23 description: None,
24 }
25 }
26
27 pub fn with_field(mut self, field: FieldSchema) -> Self {
28 self.fields.push(field);
29 self
30 }
31
32 pub fn with_description(mut self, description: String) -> Self {
33 self.description = Some(description);
34 self
35 }
36
37 pub fn get_field(&self, field_id: &str) -> Option<&FieldSchema> {
38 self.fields.iter().find(|f| f.field_id == field_id)
39 }
40
41 pub fn validate_field(&self, field_id: &str, field_type: &FieldType) -> bool {
42 if let Some(schema) = self.get_field(field_id) {
43 &schema.field_type == field_type
44 } else {
45 false
46 }
47 }
48}
49
/// Describes a single field of a [`ComponentSchema`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FieldSchema {
    /// Name of the field; [`ComponentSchema::get_field`] compares against it.
    pub field_id: FieldId,
    /// Type the field's value is declared to have.
    pub field_type: FieldType,
    /// When `true`, [`SchemaValidator::validate_component`] accepts components
    /// that omit this field.
    pub optional: bool,
    /// Default value in string form. NOTE(review): nothing in this file reads it
    /// or defines its encoding — confirm with its consumers.
    pub default_value: Option<String>,
    /// Optional human-readable description.
    pub description: Option<String>,
}
58
59impl FieldSchema {
60 pub fn new(field_id: FieldId, field_type: FieldType) -> Self {
61 Self {
62 field_id,
63 field_type,
64 optional: false,
65 default_value: None,
66 description: None,
67 }
68 }
69
70 pub fn optional(mut self) -> Self {
71 self.optional = true;
72 self
73 }
74
75 pub fn with_default(mut self, default: String) -> Self {
76 self.default_value = Some(default);
77 self
78 }
79
80 pub fn with_description(mut self, description: String) -> Self {
81 self.description = Some(description);
82 self
83 }
84}
85
/// Registry of component schemas, keyed by component id.
///
/// Both maps live behind `Arc<RwLock<..>>` and cloning the registry shares
/// them, so every clone sees the same registrations. `current_version` is a
/// plain field and is copied on clone rather than shared.
pub struct SchemaRegistry {
    /// Latest registered schema per component.
    schemas: Arc<RwLock<AHashMap<ComponentId, ComponentSchema>>>,
    /// Every version accepted by `register` for each component, in
    /// registration order.
    version_history: Arc<RwLock<AHashMap<ComponentId, Vec<SchemaVersion>>>>,
    /// Registry-wide version number (starts at `1`). NOTE(review): not consulted
    /// by registration or lookup in this file.
    current_version: SchemaVersion,
}
91
92impl SchemaRegistry {
93 pub fn new() -> Self {
94 Self {
95 schemas: Arc::new(RwLock::new(AHashMap::new())),
96 version_history: Arc::new(RwLock::new(AHashMap::new())),
97 current_version: 1,
98 }
99 }
100
101 pub fn register(&self, schema: ComponentSchema) -> Result<()> {
102 let mut schemas = self.schemas.write()
103 .map_err(|e| LinkError::Unknown(format!("Lock poisoned: {}", e)))?;
104
105 let mut version_history = self.version_history.write()
106 .map_err(|e| LinkError::Unknown(format!("Lock poisoned: {}", e)))?;
107
108 let component_id = schema.component_id.clone();
109 let version = schema.version;
110
111 if let Some(existing) = schemas.get(&component_id) {
112 if existing.version >= version {
113 return Err(LinkError::Unknown(
114 format!("Schema version {} already exists or is newer for component {}", version, component_id)
115 ));
116 }
117 }
118
119 version_history.entry(component_id.clone())
120 .or_insert_with(Vec::new)
121 .push(version);
122
123 schemas.insert(component_id, schema);
124
125 Ok(())
126 }
127
128 pub fn get(&self, component_id: &str) -> Result<ComponentSchema> {
129 let schemas = self.schemas.read()
130 .map_err(|e| LinkError::Unknown(format!("Lock poisoned: {}", e)))?;
131
132 schemas.get(component_id)
133 .cloned()
134 .ok_or_else(|| LinkError::SchemaNotFound(component_id.to_string()))
135 }
136
137 pub fn get_version(&self, component_id: &str, version: SchemaVersion) -> Result<ComponentSchema> {
138 let schema = self.get(component_id)?;
139
140 if schema.version == version {
141 Ok(schema)
142 } else {
143 Err(LinkError::SchemaMismatch {
144 expected: version.to_string(),
145 actual: schema.version.to_string(),
146 })
147 }
148 }
149
150 pub fn has(&self, component_id: &str) -> bool {
151 self.schemas.read()
152 .map(|schemas| schemas.contains_key(component_id))
153 .unwrap_or(false)
154 }
155
156 pub fn get_all(&self) -> Result<Vec<ComponentSchema>> {
157 let schemas = self.schemas.read()
158 .map_err(|e| LinkError::Unknown(format!("Lock poisoned: {}", e)))?;
159
160 Ok(schemas.values().cloned().collect())
161 }
162
163 pub fn get_version_history(&self, component_id: &str) -> Result<Vec<SchemaVersion>> {
164 let history = self.version_history.read()
165 .map_err(|e| LinkError::Unknown(format!("Lock poisoned: {}", e)))?;
166
167 Ok(history.get(component_id)
168 .cloned()
169 .unwrap_or_default())
170 }
171
172 pub fn validate_compatibility(&self, old_version: SchemaVersion, new_version: SchemaVersion) -> bool {
173 new_version >= old_version
174 }
175
176 pub fn clear(&self) -> Result<()> {
177 let mut schemas = self.schemas.write()
178 .map_err(|e| LinkError::Unknown(format!("Lock poisoned: {}", e)))?;
179
180 let mut version_history = self.version_history.write()
181 .map_err(|e| LinkError::Unknown(format!("Lock poisoned: {}", e)))?;
182
183 schemas.clear();
184 version_history.clear();
185
186 Ok(())
187 }
188
189 pub fn get_current_version(&self) -> SchemaVersion {
190 self.current_version
191 }
192
193 pub fn set_current_version(&mut self, version: SchemaVersion) {
194 self.current_version = version;
195 }
196}
197
198impl Default for SchemaRegistry {
199 fn default() -> Self {
200 Self::new()
201 }
202}
203
204impl Clone for SchemaRegistry {
205 fn clone(&self) -> Self {
206 Self {
207 schemas: Arc::clone(&self.schemas),
208 version_history: Arc::clone(&self.version_history),
209 current_version: self.current_version,
210 }
211 }
212}
213
/// Checks a component's field set against the schema registered for it.
pub struct SchemaValidator {
    /// Registry the schemas are looked up in.
    registry: SchemaRegistry,
}
217
218impl SchemaValidator {
219 pub fn new(registry: SchemaRegistry) -> Self {
220 Self { registry }
221 }
222
223 pub fn validate_component(&self, component_id: &str, fields: &AHashMap<FieldId, FieldType>) -> Result<()> {
224 let schema = self.registry.get(component_id)?;
225
226 for field_schema in &schema.fields {
227 if !field_schema.optional {
228 if !fields.contains_key(&field_schema.field_id) {
229 return Err(LinkError::InvalidMessage(
230 format!("Required field '{}' missing in component '{}'", field_schema.field_id, component_id)
231 ));
232 }
233 }
234
235 if let Some(field_type) = fields.get(&field_schema.field_id) {
236 if field_type != &field_schema.field_type {
237 return Err(LinkError::InvalidMessage(
238 format!("Field '{}' has wrong type in component '{}'", field_schema.field_id, component_id)
239 ));
240 }
241 }
242 }
243
244 Ok(())
245 }
246
247 pub fn get_registry(&self) -> &SchemaRegistry {
248 &self.registry
249 }
250}
251
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the two-field (`x`, `y`) `Position` schema at `version`.
    fn position_schema(version: SchemaVersion) -> ComponentSchema {
        ComponentSchema::new("Position".to_string(), version)
            .with_field(FieldSchema::new("x".to_string(), FieldType::F64))
            .with_field(FieldSchema::new("y".to_string(), FieldType::F64))
    }

    /// Builds a field map containing `x` and `y`, both typed `F64`.
    fn xy_fields() -> AHashMap<FieldId, FieldType> {
        let mut fields = AHashMap::new();
        fields.insert("x".to_string(), FieldType::F64);
        fields.insert("y".to_string(), FieldType::F64);
        fields
    }

    #[test]
    fn test_schema_registry() {
        let registry = SchemaRegistry::new();

        let schema = ComponentSchema::new("Position".to_string(), 1)
            .with_field(FieldSchema::new("x".to_string(), FieldType::F64))
            .with_field(FieldSchema::new("y".to_string(), FieldType::F64))
            .with_description("2D position component".to_string());

        registry.register(schema.clone()).unwrap();

        let retrieved = registry.get("Position").unwrap();
        assert_eq!(retrieved.component_id, "Position");
        assert_eq!(retrieved.fields.len(), 2);
    }

    #[test]
    fn test_schema_versioning() {
        let registry = SchemaRegistry::new();

        let schema_v1 = ComponentSchema::new("Position".to_string(), 1)
            .with_field(FieldSchema::new("x".to_string(), FieldType::F64))
            .with_field(FieldSchema::new("y".to_string(), FieldType::F64));

        registry.register(schema_v1).unwrap();

        let schema_v2 = ComponentSchema::new("Position".to_string(), 2)
            .with_field(FieldSchema::new("x".to_string(), FieldType::F64))
            .with_field(FieldSchema::new("y".to_string(), FieldType::F64))
            .with_field(FieldSchema::new("z".to_string(), FieldType::F64).optional());

        registry.register(schema_v2).unwrap();

        let history = registry.get_version_history("Position").unwrap();
        assert_eq!(history.len(), 2);
        assert!(history.contains(&1));
        assert!(history.contains(&2));
    }

    #[test]
    fn test_schema_validation() {
        let registry = SchemaRegistry::new();

        let schema = ComponentSchema::new("Position".to_string(), 1)
            .with_field(FieldSchema::new("x".to_string(), FieldType::F64))
            .with_field(FieldSchema::new("y".to_string(), FieldType::F64));

        registry.register(schema).unwrap();

        let validator = SchemaValidator::new(registry);

        let mut fields = AHashMap::new();
        fields.insert("x".to_string(), FieldType::F64);
        fields.insert("y".to_string(), FieldType::F64);

        assert!(validator.validate_component("Position", &fields).is_ok());

        let mut invalid_fields = AHashMap::new();
        invalid_fields.insert("x".to_string(), FieldType::F64);

        assert!(validator.validate_component("Position", &invalid_fields).is_err());
    }

    #[test]
    fn test_register_rejects_same_or_older_version() {
        let registry = SchemaRegistry::new();
        registry.register(position_schema(2)).unwrap();

        // Re-registering the same version or an older one must fail...
        assert!(registry.register(position_schema(2)).is_err());
        assert!(registry.register(position_schema(1)).is_err());

        // ...and must leave the stored schema and its history untouched.
        assert_eq!(registry.get("Position").unwrap().version, 2);
        assert_eq!(registry.get_version_history("Position").unwrap(), vec![2]);
    }

    #[test]
    fn test_missing_schema() {
        let registry = SchemaRegistry::new();

        assert!(!registry.has("Velocity"));
        assert!(matches!(
            registry.get("Velocity"),
            Err(LinkError::SchemaNotFound(_))
        ));
        assert!(registry.get_version_history("Velocity").unwrap().is_empty());

        let validator = SchemaValidator::new(registry);
        assert!(matches!(
            validator.validate_component("Velocity", &xy_fields()),
            Err(LinkError::SchemaNotFound(_))
        ));
    }

    #[test]
    fn test_get_version_only_matches_latest() {
        let registry = SchemaRegistry::new();
        registry.register(position_schema(1)).unwrap();
        registry.register(position_schema(2)).unwrap();

        assert_eq!(registry.get_version("Position", 2).unwrap().version, 2);
        // Version 1 is in the history but no longer stored.
        assert!(matches!(
            registry.get_version("Position", 1),
            Err(LinkError::SchemaMismatch { .. })
        ));
    }

    #[test]
    fn test_optional_field_may_be_omitted() {
        let registry = SchemaRegistry::new();
        let schema = position_schema(1)
            .with_field(FieldSchema::new("z".to_string(), FieldType::F64).optional());
        registry.register(schema).unwrap();

        let validator = SchemaValidator::new(registry);
        assert!(validator.validate_component("Position", &xy_fields()).is_ok());
    }

    #[test]
    fn test_component_schema_field_lookup() {
        let schema = position_schema(1);

        assert!(schema.get_field("x").is_some());
        assert!(schema.get_field("z").is_none());
        assert!(schema.validate_field("x", &FieldType::F64));
        assert!(!schema.validate_field("z", &FieldType::F64));
    }

    #[test]
    fn test_clear_through_shared_clone() {
        let registry = SchemaRegistry::new();
        let handle = registry.clone();
        registry.register(position_schema(1)).unwrap();

        // Clones share the underlying maps.
        assert!(handle.has("Position"));

        handle.clear().unwrap();
        assert!(!registry.has("Position"));
        assert!(registry.get_all().unwrap().is_empty());
        assert!(registry.get_version_history("Position").unwrap().is_empty());
    }
}