1use crate::error::{EventError, Result};
7use crate::types::Event;
8use std::collections::HashMap;
9use std::sync::RwLock;
10
/// Describes one version of a typed event.
///
/// [`MemorySchemaRegistry`] stores schemas keyed by `(event_type, version)`.
#[derive(Debug, Clone)]
pub struct EventSchema {
    /// Event type this schema applies to (e.g. `"forex.rate_change"`).
    pub event_type: String,

    /// Schema version; `MemorySchemaRegistry::register` rejects `0`.
    pub version: u32,

    /// Keys that must be present in an event's JSON-object payload for the
    /// event to pass `validate`.
    pub required_fields: Vec<String>,

    /// Free-form, human-readable description of the event.
    pub description: String,
}
26
/// Rule a new schema version must satisfy relative to an earlier version.
///
/// The default is [`Compatibility::Backward`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum Compatibility {
    /// The new version may not introduce required fields that the earlier
    /// version did not require.
    #[default]
    Backward,
    /// The new version may not drop required fields that the earlier
    /// version required.
    Forward,
    /// The required fields may not change between versions.
    Full,
    /// No compatibility checking is performed.
    None,
}
40
41pub trait SchemaRegistry: Send + Sync {
46 fn register(&self, schema: EventSchema) -> Result<()>;
48
49 fn get(&self, event_type: &str, version: u32) -> Result<Option<EventSchema>>;
51
52 fn latest_version(&self, event_type: &str) -> Result<Option<u32>>;
54
55 fn list_types(&self) -> Result<Vec<String>>;
57
58 fn validate(&self, event: &Event) -> Result<()>;
62
63 fn check_compatibility(
65 &self,
66 event_type: &str,
67 new_version: u32,
68 mode: Compatibility,
69 ) -> Result<()>;
70}
71
72pub struct MemorySchemaRegistry {
77 schemas: RwLock<HashMap<(String, u32), EventSchema>>,
79}
80
81impl MemorySchemaRegistry {
82 pub fn new() -> Self {
84 Self {
85 schemas: RwLock::new(HashMap::new()),
86 }
87 }
88}
89
90impl Default for MemorySchemaRegistry {
91 fn default() -> Self {
92 Self::new()
93 }
94}
95
96impl SchemaRegistry for MemorySchemaRegistry {
97 fn register(&self, schema: EventSchema) -> Result<()> {
98 if schema.event_type.is_empty() {
99 return Err(EventError::Config(
100 "Event type cannot be empty".to_string(),
101 ));
102 }
103 if schema.version == 0 {
104 return Err(EventError::Config(
105 "Schema version must be >= 1".to_string(),
106 ));
107 }
108
109 let key = (schema.event_type.clone(), schema.version);
110 let mut schemas = self.schemas.write().map_err(|e| {
111 EventError::Provider(format!("Schema registry lock poisoned: {}", e))
112 })?;
113 schemas.insert(key, schema);
114 Ok(())
115 }
116
117 fn get(&self, event_type: &str, version: u32) -> Result<Option<EventSchema>> {
118 let schemas = self.schemas.read().map_err(|e| {
119 EventError::Provider(format!("Schema registry lock poisoned: {}", e))
120 })?;
121 Ok(schemas.get(&(event_type.to_string(), version)).cloned())
122 }
123
124 fn latest_version(&self, event_type: &str) -> Result<Option<u32>> {
125 let schemas = self.schemas.read().map_err(|e| {
126 EventError::Provider(format!("Schema registry lock poisoned: {}", e))
127 })?;
128 let max = schemas
129 .keys()
130 .filter(|(t, _)| t == event_type)
131 .map(|(_, v)| *v)
132 .max();
133 Ok(max)
134 }
135
136 fn list_types(&self) -> Result<Vec<String>> {
137 let schemas = self.schemas.read().map_err(|e| {
138 EventError::Provider(format!("Schema registry lock poisoned: {}", e))
139 })?;
140 let mut types: Vec<String> = schemas
141 .keys()
142 .map(|(t, _)| t.clone())
143 .collect::<std::collections::HashSet<_>>()
144 .into_iter()
145 .collect();
146 types.sort();
147 Ok(types)
148 }
149
150 fn validate(&self, event: &Event) -> Result<()> {
151 if event.event_type.is_empty() {
153 return Ok(());
154 }
155
156 let schemas = self.schemas.read().map_err(|e| {
157 EventError::Provider(format!("Schema registry lock poisoned: {}", e))
158 })?;
159
160 let key = (event.event_type.clone(), event.version);
161 let schema = match schemas.get(&key) {
162 Some(s) => s,
163 None => return Ok(()), };
165
166 if let serde_json::Value::Object(ref map) = event.payload {
168 for field in &schema.required_fields {
169 if !map.contains_key(field) {
170 return Err(EventError::SchemaValidation {
171 event_type: event.event_type.clone(),
172 version: event.version,
173 reason: format!("Missing required field '{}'", field),
174 });
175 }
176 }
177 } else if !schema.required_fields.is_empty() {
178 return Err(EventError::SchemaValidation {
179 event_type: event.event_type.clone(),
180 version: event.version,
181 reason: "Payload must be a JSON object when schema has required fields"
182 .to_string(),
183 });
184 }
185
186 Ok(())
187 }
188
189 fn check_compatibility(
190 &self,
191 event_type: &str,
192 new_version: u32,
193 mode: Compatibility,
194 ) -> Result<()> {
195 if mode == Compatibility::None || new_version <= 1 {
196 return Ok(());
197 }
198
199 let prev_version = new_version - 1;
200 let schemas = self.schemas.read().map_err(|e| {
201 EventError::Provider(format!("Schema registry lock poisoned: {}", e))
202 })?;
203
204 let prev = match schemas.get(&(event_type.to_string(), prev_version)) {
205 Some(s) => s,
206 None => return Ok(()), };
208
209 let new = match schemas.get(&(event_type.to_string(), new_version)) {
210 Some(s) => s,
211 None => return Ok(()), };
213
214 match mode {
215 Compatibility::Backward => {
216 for field in &new.required_fields {
219 if !prev.required_fields.contains(field) {
220 return Err(EventError::SchemaValidation {
221 event_type: event_type.to_string(),
222 version: new_version,
223 reason: format!(
224 "Backward incompatible: new required field '{}' \
225 not in v{}",
226 field, prev_version
227 ),
228 });
229 }
230 }
231 }
232 Compatibility::Forward => {
233 for field in &prev.required_fields {
235 if !new.required_fields.contains(field) {
236 return Err(EventError::SchemaValidation {
237 event_type: event_type.to_string(),
238 version: new_version,
239 reason: format!(
240 "Forward incompatible: required field '{}' from v{} \
241 removed in v{}",
242 field, prev_version, new_version
243 ),
244 });
245 }
246 }
247 }
248 Compatibility::Full => {
249 if prev.required_fields != new.required_fields {
251 return Err(EventError::SchemaValidation {
252 event_type: event_type.to_string(),
253 version: new_version,
254 reason: format!(
255 "Full incompatible: required fields differ between v{} and v{}",
256 prev_version, new_version
257 ),
258 });
259 }
260 }
261 Compatibility::None => {}
262 }
263
264 Ok(())
265 }
266}
267
#[cfg(test)]
mod tests {
    use super::*;

    fn test_registry() -> MemorySchemaRegistry {
        MemorySchemaRegistry::new()
    }

    /// Builds a schema with the given required fields and an empty description.
    fn schema(event_type: &str, version: u32, required: &[&str]) -> EventSchema {
        EventSchema {
            event_type: event_type.to_string(),
            version,
            required_fields: required.iter().map(|&f| f.to_string()).collect(),
            description: String::new(),
        }
    }

    /// Registers a schema and fails the test if the registry rejects it.
    fn register_ok(
        reg: &MemorySchemaRegistry,
        event_type: &str,
        version: u32,
        required: &[&str],
    ) {
        reg.register(schema(event_type, version, required)).unwrap();
    }

    /// Registers `forex` v1 and v2 with the given required fields, then checks
    /// v2 under `mode`.
    fn check_forex_v2(v1_fields: &[&str], v2_fields: &[&str], mode: Compatibility) -> Result<()> {
        let reg = test_registry();
        register_ok(&reg, "forex", 1, v1_fields);
        register_ok(&reg, "forex", 2, v2_fields);
        reg.check_compatibility("forex", 2, mode)
    }

    /// Builds a typed `forex.rate_change` v1 event carrying `payload`.
    fn forex_event(payload: serde_json::Value) -> Event {
        Event::typed(
            "events.market.forex",
            "market",
            "forex.rate_change",
            1,
            "Rate change",
            "reuters",
            payload,
        )
    }

    #[test]
    fn test_register_and_get() {
        let reg = test_registry();
        reg.register(EventSchema {
            description: "Forex rate change event".to_string(),
            ..schema("forex.rate_change", 1, &["rate", "currency"])
        })
        .unwrap();

        let fetched = reg.get("forex.rate_change", 1).unwrap().unwrap();
        assert_eq!(fetched.event_type, "forex.rate_change");
        assert_eq!(fetched.version, 1);
        assert_eq!(fetched.required_fields, vec!["rate", "currency"]);
    }

    #[test]
    fn test_get_nonexistent() {
        assert!(test_registry().get("nonexistent", 1).unwrap().is_none());
    }

    #[test]
    fn test_register_empty_type_fails() {
        assert!(test_registry().register(schema("", 1, &[])).is_err());
    }

    #[test]
    fn test_register_zero_version_fails() {
        assert!(test_registry().register(schema("test", 0, &[])).is_err());
    }

    #[test]
    fn test_latest_version() {
        let reg = test_registry();
        (1..=3).for_each(|v| register_ok(&reg, "test.event", v, &[]));

        assert_eq!(reg.latest_version("test.event").unwrap(), Some(3));
        assert_eq!(reg.latest_version("nonexistent").unwrap(), None);
    }

    #[test]
    fn test_list_types() {
        let reg = test_registry();
        register_ok(&reg, "b.event", 1, &[]);
        register_ok(&reg, "a.event", 1, &[]);
        register_ok(&reg, "a.event", 2, &[]);

        assert_eq!(reg.list_types().unwrap(), vec!["a.event", "b.event"]);
    }

    #[test]
    fn test_validate_untyped_event_passes() {
        let event = Event::new(
            "events.test.a",
            "test",
            "Test",
            "test",
            serde_json::json!({}),
        );
        assert!(test_registry().validate(&event).is_ok());
    }

    #[test]
    fn test_validate_no_schema_registered_passes() {
        let event = Event::typed(
            "events.test.a",
            "test",
            "unknown.type",
            1,
            "Test",
            "test",
            serde_json::json!({}),
        );
        assert!(test_registry().validate(&event).is_ok());
    }

    #[test]
    fn test_validate_valid_event() {
        let reg = test_registry();
        register_ok(&reg, "forex.rate_change", 1, &["rate", "currency"]);

        let event = forex_event(serde_json::json!({"rate": 7.35, "currency": "USD/CNY"}));
        assert!(reg.validate(&event).is_ok());
    }

    #[test]
    fn test_validate_missing_required_field() {
        let reg = test_registry();
        register_ok(&reg, "forex.rate_change", 1, &["rate", "currency"]);

        // "currency" is deliberately left out of the payload.
        let event = forex_event(serde_json::json!({"rate": 7.35}));

        let msg = reg.validate(&event).unwrap_err().to_string();
        assert!(msg.contains("currency"), "Error should mention missing field: {}", msg);
    }

    #[test]
    fn test_validate_non_object_payload_with_required_fields() {
        let reg = test_registry();
        register_ok(&reg, "test.event", 1, &["field"]);

        let event = Event::typed(
            "events.test.a",
            "test",
            "test.event",
            1,
            "Test",
            "test",
            serde_json::json!("not an object"),
        );
        assert!(reg.validate(&event).is_err());
    }

    #[test]
    fn test_backward_compatibility_ok() {
        assert!(check_forex_v2(&["rate"], &["rate"], Compatibility::Backward).is_ok());
    }

    #[test]
    fn test_backward_compatibility_fail() {
        // v2 adds a required field that v1 did not have.
        let err = check_forex_v2(&["rate"], &["rate", "currency"], Compatibility::Backward)
            .unwrap_err();
        assert!(err.to_string().contains("currency"));
    }

    #[test]
    fn test_forward_compatibility_fail() {
        // v2 drops a required field that v1 had.
        let err = check_forex_v2(&["rate", "currency"], &["rate"], Compatibility::Forward)
            .unwrap_err();
        assert!(err.to_string().contains("currency"));
    }

    #[test]
    fn test_full_compatibility() {
        assert!(check_forex_v2(&["rate"], &["rate"], Compatibility::Full).is_ok());
    }

    #[test]
    fn test_no_compatibility_always_passes() {
        assert!(check_forex_v2(&["a"], &["b"], Compatibility::None).is_ok());
    }

    #[test]
    fn test_compatibility_no_previous_version() {
        let reg = test_registry();
        register_ok(&reg, "forex", 1, &["rate"]);

        // v1 has nothing earlier to be compared against.
        assert!(reg
            .check_compatibility("forex", 1, Compatibility::Full)
            .is_ok());
    }
}