1use crate::{Error, Result};
2use apache_avro::Schema;
3use async_trait::async_trait;
4use dashmap::DashMap;
5use serde::{Deserialize, Serialize};
6use std::sync::atomic::{AtomicI32, Ordering};
7use std::sync::Arc;
8
9#[async_trait]
16pub trait SchemaRegistry: Send + Sync + std::fmt::Debug {
17 async fn register(&self, subject: &str, schema_str: &str) -> Result<i32>;
20
21 async fn get_schema(&self, id: i32) -> Result<Arc<Schema>>;
23
24 async fn get_latest_schema_id(&self, subject: &str) -> Result<Option<i32>>;
26}
27
/// An event recorded in the `_schemas` log partition used by
/// [`EmbeddedSchemaRegistry`].
///
/// Events are encoded with `postcard`, which identifies enum variants by their
/// position. Append new variants at the end and never reorder existing ones, or
/// previously written logs will decode incorrectly.
#[derive(Serialize, Deserialize, Debug)]
pub enum SchemaLogEvent {
    /// A schema was registered under a subject.
    Register(SchemaRegistration),
}

/// Payload of [`SchemaLogEvent::Register`].
///
/// `postcard` encodes struct fields in declaration order, so the field order
/// below is part of the on-disk format.
#[derive(Serialize, Deserialize, Debug)]
pub struct SchemaRegistration {
    /// Registry-wide schema ID.
    pub id: i32,
    /// Subject the schema was registered under.
    pub subject: String,
    /// Schema text exactly as passed to `register`; it is re-parsed with
    /// `Schema::parse_str` when the log is replayed.
    pub schema: String,
    /// Version number of this schema within `subject`.
    pub version: i32,
}
40
/// Schema registry whose registrations are persisted to a `_schemas` log
/// partition.
///
/// Lookups are served from an in-memory [`MemorySchemaRegistry`], which is
/// rebuilt by replaying the partition when the registry is opened.
#[derive(Debug)]
pub struct EmbeddedSchemaRegistry {
    /// In-memory view of every registration replayed or written so far.
    cache: MemorySchemaRegistry,

    /// Log partition (`_schemas`, partition 0) that registration events are
    /// appended to and replayed from.
    partition: Arc<crate::partition::Partition>,
}
52
53impl EmbeddedSchemaRegistry {
54 pub async fn new(config: &crate::Config) -> Result<Self> {
55 let partition = crate::partition::Partition::new(config, "_schemas", 0).await?;
57 let partition = Arc::new(partition);
58
59 let registry = Self {
60 cache: MemorySchemaRegistry::new(),
61 partition,
62 };
63
64 registry.recover().await?;
65
66 Ok(registry)
67 }
68
69 async fn recover(&self) -> Result<()> {
71 let mut offset = 0;
72 #[allow(clippy::while_let_loop)] loop {
76 match self.partition.read(offset, 100).await {
80 Ok(messages) => {
81 if messages.is_empty() {
82 break;
83 }
84 for msg in messages {
85 offset = msg.offset + 1; if let Ok(event) = postcard::from_bytes::<SchemaLogEvent>(&msg.value) {
88 match event {
89 SchemaLogEvent::Register(reg) => {
90 self.cache.inject_registration(
91 reg.id,
92 ®.subject,
93 ®.schema,
94 reg.version,
95 )?;
96 }
97 }
98 }
99 }
100 }
101 Err(_) => break, }
103 }
104
105 let max_id = self
107 .cache
108 .id_to_schema
109 .iter()
110 .map(|entry| *entry.key())
111 .max()
112 .unwrap_or(0);
113
114 self.cache.next_id.store(max_id + 1, Ordering::SeqCst);
115
116 Ok(())
117 }
118}
119
120#[async_trait]
121impl SchemaRegistry for EmbeddedSchemaRegistry {
122 async fn register(&self, subject: &str, schema_str: &str) -> Result<i32> {
123 if let Ok(id) = self.cache.check_existing(subject, schema_str).await {
125 return Ok(id);
126 }
127
128 let next_id = self.cache.next_id.load(Ordering::SeqCst);
130
131 let version = self
132 .cache
133 .subject_versions
134 .get(subject)
135 .map(|v| v.iter().map(|entry| *entry.key()).max().unwrap_or(0) + 1)
136 .unwrap_or(1);
137
138 let event = SchemaLogEvent::Register(SchemaRegistration {
140 id: next_id,
141 subject: subject.to_string(),
142 schema: schema_str.to_string(),
143 version,
144 });
145
146 let payload = postcard::to_allocvec(&event)
147 .map_err(|e| Error::Other(format!("Serialization error: {}", e)))?;
148
149 use crate::Message;
150
151 let msg = Message::with_key(
153 bytes::Bytes::from(subject.to_string()), bytes::Bytes::from(payload), );
156
157 self.partition.append(msg).await?;
159
160 self.cache
162 .inject_registration(next_id, subject, schema_str, version)?;
163
164 self.cache.next_id.fetch_add(1, Ordering::SeqCst);
166
167 Ok(next_id)
168 }
169
170 async fn get_schema(&self, id: i32) -> Result<Arc<Schema>> {
171 self.cache.get_schema(id).await
172 }
173
174 async fn get_latest_schema_id(&self, subject: &str) -> Result<Option<i32>> {
175 self.cache.get_latest_schema_id(subject).await
176 }
177}
178
/// Purely in-memory schema registry. It also serves as the cache behind
/// [`EmbeddedSchemaRegistry`].
///
/// Every field is behind an `Arc`, so clones share the same underlying state.
#[derive(Debug, Clone)]
pub struct MemorySchemaRegistry {
    /// Parsed schemas keyed by schema ID.
    id_to_schema: Arc<DashMap<i32, Arc<Schema>>>,

    /// Per-subject map of version number -> schema ID.
    subject_versions: Arc<DashMap<String, DashMap<i32, i32>>>,

    /// Next schema ID to assign.
    next_id: Arc<AtomicI32>,
}
193
194impl Default for MemorySchemaRegistry {
195 fn default() -> Self {
196 Self::new()
197 }
198}
199
200impl MemorySchemaRegistry {
201 pub fn new() -> Self {
202 Self {
203 id_to_schema: Arc::new(DashMap::new()),
204 subject_versions: Arc::new(DashMap::new()),
205 next_id: Arc::new(AtomicI32::new(1)),
206 }
207 }
208
209 pub fn inject_registration(
210 &self,
211 id: i32,
212 subject: &str,
213 schema_str: &str,
214 version: i32,
215 ) -> Result<()> {
216 let schema = Schema::parse_str(schema_str)
217 .map_err(|e| Error::Other(format!("Invalid Avro schema: {}", e)))?;
218 let schema = Arc::new(schema);
219
220 self.id_to_schema.insert(id, schema);
222
223 let versions = self
224 .subject_versions
225 .entry(subject.to_string())
226 .or_default();
227 versions.insert(version, id);
228
229 Ok(())
230 }
231
232 pub async fn check_existing(
233 &self,
234 subject: &str,
235 schema_str: &str,
236 ) -> std::result::Result<i32, ()> {
237 let schema = Schema::parse_str(schema_str).map_err(|_| ())?;
238 let canonical_form = schema.canonical_form();
239
240 for entry in self.id_to_schema.iter() {
242 let existing_id = *entry.key();
243 let existing_schema = entry.value();
244
245 if existing_schema.canonical_form() == canonical_form {
246 if let Some(versions) = self.subject_versions.get(subject) {
247 for version_entry in versions.iter() {
248 if *version_entry.value() == existing_id {
249 return Ok(existing_id);
250 }
251 }
252 }
253 }
254 }
255 Err(())
256 }
257}
258
259#[async_trait]
260impl SchemaRegistry for MemorySchemaRegistry {
261 async fn register(&self, subject: &str, schema_str: &str) -> Result<i32> {
262 if let Ok(id) = self.check_existing(subject, schema_str).await {
263 return Ok(id);
264 }
265
266 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
268
269 let version = self
270 .subject_versions
271 .get(subject)
272 .map(|v| v.iter().map(|entry| *entry.key()).max().unwrap_or(0) + 1)
273 .unwrap_or(1);
274
275 self.inject_registration(id, subject, schema_str, version)?;
276 Ok(id)
277 }
278
279 async fn get_schema(&self, id: i32) -> Result<Arc<Schema>> {
280 self.id_to_schema
281 .get(&id)
282 .map(|entry| Arc::clone(entry.value()))
283 .ok_or_else(|| Error::Other(format!("Schema ID {} not found", id)))
284 }
285
286 async fn get_latest_schema_id(&self, subject: &str) -> Result<Option<i32>> {
287 if let Some(versions) = self.subject_versions.get(subject) {
288 let max_ver = versions.iter().map(|entry| *entry.key()).max();
289
290 if let Some(max_version) = max_ver {
291 return Ok(versions.get(&max_version).map(|e| *e.value()));
292 }
293 }
294 Ok(None)
295 }
296}
297
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Config;
    use std::fs;

    /// Deletes the wrapped directory on drop, so test data is cleaned up even when
    /// an assertion fails part-way through a test.
    struct RemoveDirOnDrop(String);

    impl Drop for RemoveDirOnDrop {
        fn drop(&mut self) {
            let _ = fs::remove_dir_all(&self.0);
        }
    }

    /// Returns a default config whose `data_dir` is a fresh, uniquely named
    /// directory under the platform temp dir. A hard-coded `/tmp` does not exist
    /// on every platform.
    fn get_test_config() -> Config {
        let data_dir = std::env::temp_dir()
            .join(format!("rivven-test-registry-{}", uuid::Uuid::new_v4()))
            .to_string_lossy()
            .into_owned();
        let config = Config {
            data_dir,
            ..Default::default()
        };
        let _ = fs::remove_dir_all(&config.data_dir);
        config
    }

    /// A registration survives reopening the registry, and ID and version
    /// assignment resume after the recovered state.
    #[tokio::test]
    async fn test_embedded_registry_persistence() {
        let config = get_test_config();
        let _cleanup = RemoveDirOnDrop(config.data_dir.clone());
        let subject = "user-value";
        let schema_str =
            r#"{"type":"record","name":"User","fields":[{"name":"name","type":"string"}]}"#;

        // First session: register a schema, which is written to the `_schemas` log.
        {
            let registry = EmbeddedSchemaRegistry::new(&config)
                .await
                .expect("Failed to create registry");
            let id = registry
                .register(subject, schema_str)
                .await
                .expect("Failed to register");
            assert_eq!(id, 1);

            let schema = registry.get_schema(1).await.expect("Failed to get schema");
            assert!(format!("{:?}", schema).contains("User"));
        }

        // Second session: the registry is rebuilt purely by replaying the log.
        {
            let registry = EmbeddedSchemaRegistry::new(&config)
                .await
                .expect("Failed to recover registry");

            let schema = registry
                .get_schema(1)
                .await
                .expect("Failed to get schema after recovery");
            assert!(format!("{:?}", schema).contains("User"));

            let latest = registry
                .get_latest_schema_id(subject)
                .await
                .expect("Failed to get latest")
                .unwrap();
            assert_eq!(latest, 1);

            let schema_str2 =
                r#"{"type":"record","name":"UserV2","fields":[{"name":"age","type":"int"}]}"#;
            let id2 = registry
                .register(subject, schema_str2)
                .await
                .expect("Failed to register 2");
            assert_eq!(id2, 2);

            // The new registration becomes the subject's latest version.
            let latest = registry
                .get_latest_schema_id(subject)
                .await
                .expect("Failed to get latest")
                .unwrap();
            assert_eq!(latest, 2);
        }
    }
}