Skip to main content

rivven_core/
schema_registry.rs

1use crate::{Error, Result};
2use apache_avro::Schema;
3use async_trait::async_trait;
4use dashmap::DashMap;
5use serde::{Deserialize, Serialize};
6use std::sync::atomic::{AtomicI32, Ordering};
7use std::sync::Arc;
8
/// Abstract Schema Registry interface.
///
/// Supports multiple backends:
/// 1. `MemorySchemaRegistry`: For testing and development.
/// 2. `EmbeddedSchemaRegistry`: A persistent, log-backed registry (Kafka-style).
/// 3. `ExternalSchemaRegistry`: Adapters for AWS Glue, Confluent Cloud, etc.
// NOTE(review): `ExternalSchemaRegistry` is not defined in this file — confirm it
// exists elsewhere in the crate, or drop it from the list above.
#[async_trait]
pub trait SchemaRegistry: Send + Sync + std::fmt::Debug {
    /// Register a new schema under the given subject.
    ///
    /// Returns the global ID of the schema. In the implementations in this
    /// module, registering a schema whose Avro canonical form is already
    /// registered under `subject` returns the existing ID instead of creating
    /// a new registration.
    async fn register(&self, subject: &str, schema_str: &str) -> Result<i32>;

    /// Retrieve a schema by its global ID.
    ///
    /// The implementations in this module return an error for an unknown ID.
    async fn get_schema(&self, id: i32) -> Result<Arc<Schema>>;

    /// Get the latest schema ID for a subject.
    ///
    /// "Latest" is the ID registered under the subject's highest version
    /// number; `Ok(None)` means the subject has no registrations.
    async fn get_latest_schema_id(&self, subject: &str) -> Result<Option<i32>>;
}
27
/// Record stored, postcard-encoded, as the value of each message in the
/// `_schemas` partition and replayed by `EmbeddedSchemaRegistry` on startup.
///
/// postcard identifies enum variants by their index, not their name, so
/// existing variants must not be reordered or removed; append new variants at
/// the end (and handle them in the replay `match`).
#[derive(Serialize, Deserialize, Debug)]
pub enum SchemaLogEvent {
    /// A schema was registered under a subject.
    Register(SchemaRegistration),
}
32
/// Payload of `SchemaLogEvent::Register`: one registration as written to the log.
///
/// postcard encodes struct fields positionally in declaration order (no field
/// names), so these fields must not be reordered, removed, or retyped without a
/// migration for existing `_schemas` logs.
#[derive(Serialize, Deserialize, Debug)]
pub struct SchemaRegistration {
    /// Global schema ID.
    pub id: i32,
    /// Subject the schema is registered under (also used as the message key).
    pub subject: String,
    /// Schema text exactly as passed to `register` (not canonicalized).
    pub schema: String,
    /// Version number of this schema within its subject.
    pub version: i32,
}
40
/// Persistent implementation using a Rivven internal topic (`_schemas`).
///
/// Each new registration is appended to the partition as a postcard-encoded
/// `SchemaLogEvent` (keyed by subject) and only then applied to the in-memory
/// cache. On construction the partition is replayed from offset 0 to rebuild
/// the cache; all reads are served from the cache.
#[derive(Debug)]
pub struct EmbeddedSchemaRegistry {
    /// In-memory cache of the registry state.
    /// Access is delegated to this cache after writes are persisted.
    cache: MemorySchemaRegistry,

    /// Handle to the partition where schemas are stored.
    /// `new` always opens topic `_schemas`, partition 0.
    partition: Arc<crate::partition::Partition>,
}
52
53impl EmbeddedSchemaRegistry {
54    pub async fn new(config: &crate::Config) -> Result<Self> {
55        // Initialize the partition for storage
56        let partition = crate::partition::Partition::new(config, "_schemas", 0).await?;
57        let partition = Arc::new(partition);
58
59        let registry = Self {
60            cache: MemorySchemaRegistry::new(),
61            partition,
62        };
63
64        registry.recover().await?;
65
66        Ok(registry)
67    }
68
69    /// Read the log from the beginning and rebuild the in-memory state.
70    async fn recover(&self) -> Result<()> {
71        let mut offset = 0;
72        // Basic recovery loop - assumes partition read logic is available
73        // In a production system, we would iterate segments efficiently.
74        #[allow(clippy::while_let_loop)] // Complex loop with batch reads requires explicit control
75        loop {
76            // Read next batch.
77            // Partition::read takes (start_offset, max_messages).
78            // We'll read up to 100 messages at a time.
79            match self.partition.read(offset, 100).await {
80                Ok(messages) => {
81                    if messages.is_empty() {
82                        break;
83                    }
84                    for msg in messages {
85                        offset = msg.offset + 1; // Advance offset
86                                                 // Skip non-schema messages (or errors)
87                        if let Ok(event) = postcard::from_bytes::<SchemaLogEvent>(&msg.value) {
88                            match event {
89                                SchemaLogEvent::Register(reg) => {
90                                    self.cache.inject_registration(
91                                        reg.id,
92                                        &reg.subject,
93                                        &reg.schema,
94                                        reg.version,
95                                    )?;
96                                }
97                            }
98                        }
99                    }
100                }
101                Err(_) => break, // EOF or read error (e.g. OffsetOutOfBounds if end reached)
102            }
103        }
104
105        // Sync next_id with recovered max
106        let max_id = self
107            .cache
108            .id_to_schema
109            .iter()
110            .map(|entry| *entry.key())
111            .max()
112            .unwrap_or(0);
113
114        self.cache.next_id.store(max_id + 1, Ordering::SeqCst);
115
116        Ok(())
117    }
118}
119
120#[async_trait]
121impl SchemaRegistry for EmbeddedSchemaRegistry {
122    async fn register(&self, subject: &str, schema_str: &str) -> Result<i32> {
123        // 1. Check cache first
124        if let Ok(id) = self.cache.check_existing(subject, schema_str).await {
125            return Ok(id);
126        }
127
128        // 2. Allocate ID atomically
129        let next_id = self.cache.next_id.load(Ordering::SeqCst);
130
131        let version = self
132            .cache
133            .subject_versions
134            .get(subject)
135            .map(|v| v.iter().map(|entry| *entry.key()).max().unwrap_or(0) + 1)
136            .unwrap_or(1);
137
138        // 3. Serialize Event
139        let event = SchemaLogEvent::Register(SchemaRegistration {
140            id: next_id,
141            subject: subject.to_string(),
142            schema: schema_str.to_string(),
143            version,
144        });
145
146        let payload = postcard::to_allocvec(&event)
147            .map_err(|e| Error::Other(format!("Serialization error: {}", e)))?;
148
149        use crate::Message;
150
151        // 4. Create proper Message with key
152        let msg = Message::with_key(
153            bytes::Bytes::from(subject.to_string()), // Key
154            bytes::Bytes::from(payload),             // Value
155        );
156
157        // 5. Append to Log
158        self.partition.append(msg).await?;
159
160        // 6. Update Memory Cache
161        self.cache
162            .inject_registration(next_id, subject, schema_str, version)?;
163
164        // 7. Increment ID atomically
165        self.cache.next_id.fetch_add(1, Ordering::SeqCst);
166
167        Ok(next_id)
168    }
169
170    async fn get_schema(&self, id: i32) -> Result<Arc<Schema>> {
171        self.cache.get_schema(id).await
172    }
173
174    async fn get_latest_schema_id(&self, subject: &str) -> Result<Option<i32>> {
175        self.cache.get_latest_schema_id(subject).await
176    }
177}
178
/// In-Memory implementation of Schema Registry.
///
/// State is kept in `DashMap`s, which allow concurrent reads and writes through
/// `&self`. Internally they use sharded locks, so access is concurrent but not
/// strictly lock-free. Parsed schemas are stored as `Arc<Schema>`, so lookups
/// hand out a reference-counted handle instead of cloning the schema.
///
/// Every field sits behind an `Arc`, so `clone()` yields a handle that shares
/// state with the original rather than an independent copy.
#[derive(Debug, Clone)]
pub struct MemorySchemaRegistry {
    /// Global schema ID -> parsed schema.
    id_to_schema: Arc<DashMap<i32, Arc<Schema>>>,

    /// Subject -> (version number -> global schema ID).
    subject_versions: Arc<DashMap<String, DashMap<i32, i32>>>,

    /// Next global schema ID to hand out (initialized to 1 by `new`).
    next_id: Arc<AtomicI32>,
}
193
194impl Default for MemorySchemaRegistry {
195    fn default() -> Self {
196        Self::new()
197    }
198}
199
200impl MemorySchemaRegistry {
201    pub fn new() -> Self {
202        Self {
203            id_to_schema: Arc::new(DashMap::new()),
204            subject_versions: Arc::new(DashMap::new()),
205            next_id: Arc::new(AtomicI32::new(1)),
206        }
207    }
208
209    pub fn inject_registration(
210        &self,
211        id: i32,
212        subject: &str,
213        schema_str: &str,
214        version: i32,
215    ) -> Result<()> {
216        let schema = Schema::parse_str(schema_str)
217            .map_err(|e| Error::Other(format!("Invalid Avro schema: {}", e)))?;
218        let schema = Arc::new(schema);
219
220        // Lock-free insertions
221        self.id_to_schema.insert(id, schema);
222
223        let versions = self
224            .subject_versions
225            .entry(subject.to_string())
226            .or_default();
227        versions.insert(version, id);
228
229        Ok(())
230    }
231
232    pub async fn check_existing(
233        &self,
234        subject: &str,
235        schema_str: &str,
236    ) -> std::result::Result<i32, ()> {
237        let schema = Schema::parse_str(schema_str).map_err(|_| ())?;
238        let canonical_form = schema.canonical_form();
239
240        // Lock-free iteration over concurrent map
241        for entry in self.id_to_schema.iter() {
242            let existing_id = *entry.key();
243            let existing_schema = entry.value();
244
245            if existing_schema.canonical_form() == canonical_form {
246                if let Some(versions) = self.subject_versions.get(subject) {
247                    for version_entry in versions.iter() {
248                        if *version_entry.value() == existing_id {
249                            return Ok(existing_id);
250                        }
251                    }
252                }
253            }
254        }
255        Err(())
256    }
257}
258
259#[async_trait]
260impl SchemaRegistry for MemorySchemaRegistry {
261    async fn register(&self, subject: &str, schema_str: &str) -> Result<i32> {
262        if let Ok(id) = self.check_existing(subject, schema_str).await {
263            return Ok(id);
264        }
265
266        // Lock-free ID allocation
267        let id = self.next_id.fetch_add(1, Ordering::SeqCst);
268
269        let version = self
270            .subject_versions
271            .get(subject)
272            .map(|v| v.iter().map(|entry| *entry.key()).max().unwrap_or(0) + 1)
273            .unwrap_or(1);
274
275        self.inject_registration(id, subject, schema_str, version)?;
276        Ok(id)
277    }
278
279    async fn get_schema(&self, id: i32) -> Result<Arc<Schema>> {
280        self.id_to_schema
281            .get(&id)
282            .map(|entry| Arc::clone(entry.value()))
283            .ok_or_else(|| Error::Other(format!("Schema ID {} not found", id)))
284    }
285
286    async fn get_latest_schema_id(&self, subject: &str) -> Result<Option<i32>> {
287        if let Some(versions) = self.subject_versions.get(subject) {
288            let max_ver = versions.iter().map(|entry| *entry.key()).max();
289
290            if let Some(max_version) = max_ver {
291                return Ok(versions.get(&max_version).map(|e| *e.value()));
292            }
293        }
294        Ok(None)
295    }
296}
297
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Config;
    use std::fs;

    const USER_SCHEMA: &str =
        r#"{"type":"record","name":"User","fields":[{"name":"name","type":"string"}]}"#;

    /// Deletes the wrapped directory when dropped, so test data is removed
    /// even if an assertion panics part-way through a test.
    struct DirCleanup(String);

    impl Drop for DirCleanup {
        fn drop(&mut self) {
            let _ = fs::remove_dir_all(&self.0);
        }
    }

    /// Returns a `Config` whose `data_dir` is a fresh, uniquely named directory
    /// under the platform temp dir (`std::env::temp_dir`), rather than a
    /// hard-coded `/tmp` path that does not exist on every platform.
    fn get_test_config() -> Config {
        let data_dir = std::env::temp_dir()
            .join(format!("rivven-test-registry-{}", uuid::Uuid::new_v4()))
            .to_string_lossy()
            .into_owned();
        Config {
            data_dir,
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn test_embedded_registry_persistence() {
        let config = get_test_config();
        // Declared first so it is dropped last, after both registries below.
        let _cleanup = DirCleanup(config.data_dir.clone());
        let subject = "user-value";

        // 1. Start Registry and Register Schema
        {
            let registry = EmbeddedSchemaRegistry::new(&config)
                .await
                .expect("Failed to create registry");
            let id = registry
                .register(subject, USER_SCHEMA)
                .await
                .expect("Failed to register");
            assert_eq!(id, 1);

            // Check it exists
            let schema = registry.get_schema(1).await.expect("Failed to get schema");
            assert!(format!("{:?}", schema).contains("User"));
        }

        // 2. Restart (Simulate by creating new instance on same dir)
        {
            let registry = EmbeddedSchemaRegistry::new(&config)
                .await
                .expect("Failed to recover registry");

            // Should have id 1
            let schema = registry
                .get_schema(1)
                .await
                .expect("Failed to get schema after recovery");
            assert!(format!("{:?}", schema).contains("User"));

            // Check caching of subject version
            let latest = registry
                .get_latest_schema_id(subject)
                .await
                .expect("Failed to get latest")
                .unwrap();
            assert_eq!(latest, 1);

            // 3. Register next schema
            let schema_str2 =
                r#"{"type":"record","name":"UserV2","fields":[{"name":"age","type":"int"}]}"#;
            let id2 = registry
                .register(subject, schema_str2)
                .await
                .expect("Failed to register 2");
            assert_eq!(id2, 2);
        }
    }

    #[tokio::test]
    async fn test_memory_registry_register_and_lookup() {
        let registry = MemorySchemaRegistry::new();

        let id = registry
            .register("a-value", USER_SCHEMA)
            .await
            .expect("Failed to register");
        assert_eq!(id, 1);

        // Same schema with different whitespace: same canonical form, same ID.
        let spaced = r#"{"type": "record", "name": "User", "fields": [{"name": "name", "type": "string"}]}"#;
        let again = registry
            .register("a-value", spaced)
            .await
            .expect("Failed to re-register");
        assert_eq!(again, 1);

        // The same schema under another subject is a separate registration.
        let other = registry
            .register("b-value", USER_SCHEMA)
            .await
            .expect("Failed to register under second subject");
        assert_eq!(other, 2);

        assert_eq!(
            registry
                .get_latest_schema_id("a-value")
                .await
                .expect("Failed to get latest"),
            Some(1)
        );
        assert_eq!(
            registry
                .get_latest_schema_id("missing")
                .await
                .expect("Failed to get latest"),
            None
        );
        assert!(registry.get_schema(2).await.is_ok());
        assert!(registry.get_schema(99).await.is_err());
    }

    #[tokio::test]
    async fn test_memory_registry_rejects_invalid_schema() {
        let registry = MemorySchemaRegistry::new();

        assert!(registry
            .register("bad-value", "not an avro schema")
            .await
            .is_err());
        assert_eq!(
            registry
                .get_latest_schema_id("bad-value")
                .await
                .expect("Failed to get latest"),
            None
        );
    }
}
366
367        // Cleanup
368        let _ = fs::remove_dir_all(&config.data_dir);
369    }
370}