1use std::collections::hash_map::Entry;
21use std::collections::HashMap;
22use std::sync::{Arc, RwLock};
23
24use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit};
25use serde::{Deserialize, Serialize};
26
27use crate::error::CoreError;
28
29pub fn validate_identifier(name: &str) -> Result<(), CoreError> {
40 if name.is_empty() {
41 return Err(CoreError::SchemaValidation(
42 "identifier must not be empty".to_string(),
43 ));
44 }
45 if !name.bytes().all(|b| b.is_ascii_alphanumeric() || b == b'_') {
46 return Err(CoreError::SchemaValidation(format!(
47 "identifier '{}' contains invalid characters (only [A-Za-z0-9_] allowed)",
48 name
49 )));
50 }
51 Ok(())
52}
53
54#[derive(Debug, Clone)]
56pub struct TableSchema {
57 pub name: String,
59 pub arrow_schema: SchemaRef,
61 pub primary_key: Vec<String>,
63}
64
65impl TableSchema {
66 pub fn new(name: impl Into<String>, schema: SchemaRef, primary_key: Vec<String>) -> Self {
76 Self {
77 name: name.into(),
78 arrow_schema: schema,
79 primary_key,
80 }
81 }
82
83 pub fn validate(&self) -> Result<(), CoreError> {
95 validate_identifier(&self.name)?;
97
98 for field in self.arrow_schema.fields() {
100 validate_identifier(field.name()).map_err(|_| {
101 CoreError::SchemaValidation(format!(
102 "column name '{}' in table '{}' contains invalid characters",
103 field.name(),
104 self.name
105 ))
106 })?;
107 }
108
109 if self.primary_key.is_empty() {
111 return Err(CoreError::SchemaValidation(format!(
112 "table '{}' must have at least one primary key column",
113 self.name
114 )));
115 }
116 for pk_col in &self.primary_key {
117 validate_identifier(pk_col)?;
118 if self.arrow_schema.field_with_name(pk_col).is_err() {
119 return Err(CoreError::SchemaValidation(format!(
120 "primary key column '{}' not found in schema for table '{}'",
121 pk_col, self.name
122 )));
123 }
124 }
125 Ok(())
126 }
127}
128
129#[derive(Debug, Clone)]
134pub struct SchemaRegistry {
135 tables: Arc<RwLock<HashMap<String, Arc<TableSchema>>>>,
136}
137
138impl SchemaRegistry {
139 pub fn new() -> Self {
141 Self {
142 tables: Arc::new(RwLock::new(HashMap::new())),
143 }
144 }
145
146 pub fn register(&self, schema: TableSchema) -> Result<(), CoreError> {
157 schema.validate()?;
158 let mut tables = self.tables.write().unwrap();
159 match tables.entry(schema.name.clone()) {
160 Entry::Occupied(_) => Err(CoreError::TableAlreadyRegistered(schema.name)),
161 Entry::Vacant(entry) => {
162 entry.insert(Arc::new(schema));
163 Ok(())
164 }
165 }
166 }
167
168 pub fn get(&self, table_name: &str) -> Result<Arc<TableSchema>, CoreError> {
177 let tables = self.tables.read().unwrap();
178 tables
179 .get(table_name)
180 .cloned()
181 .ok_or_else(|| CoreError::TableNotFound(table_name.to_string()))
182 }
183
184 pub fn table_names(&self) -> Vec<String> {
186 let tables = self.tables.read().unwrap();
187 tables.keys().cloned().collect()
188 }
189
190 pub fn unregister(&self, table_name: &str) -> Result<Arc<TableSchema>, CoreError> {
196 let mut tables = self.tables.write().unwrap();
197 tables
198 .remove(table_name)
199 .ok_or_else(|| CoreError::TableNotFound(table_name.to_string()))
200 }
201
202 pub fn update(&self, schema: TableSchema) -> Result<(), CoreError> {
213 schema.validate()?;
214 let mut tables = self.tables.write().unwrap();
215 match tables.entry(schema.name.clone()) {
216 Entry::Occupied(mut entry) => {
217 entry.insert(Arc::new(schema));
218 Ok(())
219 }
220 Entry::Vacant(_) => Err(CoreError::TableNotFound(schema.name)),
221 }
222 }
223
224 pub fn add_column(
238 &self,
239 table_name: &str,
240 column_name: &str,
241 data_type: DataType,
242 ) -> Result<Arc<TableSchema>, CoreError> {
243 validate_identifier(column_name)?;
244 let mut tables = self.tables.write().unwrap();
245 let existing = tables
246 .get(table_name)
247 .ok_or_else(|| CoreError::TableNotFound(table_name.to_string()))?;
248
249 if existing.arrow_schema.field_with_name(column_name).is_ok() {
251 return Err(CoreError::SchemaValidation(format!(
252 "column '{}' already exists in table '{}'",
253 column_name, table_name
254 )));
255 }
256
257 let primary_key = existing.primary_key.clone();
258 let mut fields: Vec<Field> = existing
259 .arrow_schema
260 .fields()
261 .iter()
262 .map(|f| f.as_ref().clone())
263 .collect();
264 fields.push(Field::new(column_name, data_type, true)); Ok(commit_schema(&mut tables, table_name, fields, primary_key))
267 }
268
269 pub fn drop_column(
281 &self,
282 table_name: &str,
283 column_name: &str,
284 ) -> Result<Arc<TableSchema>, CoreError> {
285 let mut tables = self.tables.write().unwrap();
286 let existing = tables
287 .get(table_name)
288 .ok_or_else(|| CoreError::TableNotFound(table_name.to_string()))?;
289
290 if existing.primary_key.contains(&column_name.to_string()) {
292 return Err(CoreError::SchemaValidation(format!(
293 "cannot drop primary key column '{}' from table '{}'",
294 column_name, table_name
295 )));
296 }
297
298 if existing.arrow_schema.field_with_name(column_name).is_err() {
300 return Err(CoreError::SchemaValidation(format!(
301 "column '{}' not found in table '{}'",
302 column_name, table_name
303 )));
304 }
305
306 let primary_key = existing.primary_key.clone();
307 let fields: Vec<Field> = existing
308 .arrow_schema
309 .fields()
310 .iter()
311 .filter(|f| f.name() != column_name)
312 .map(|f| f.as_ref().clone())
313 .collect();
314
315 Ok(commit_schema(&mut tables, table_name, fields, primary_key))
316 }
317}
318
319fn commit_schema(
323 tables: &mut HashMap<String, Arc<TableSchema>>,
324 table_name: &str,
325 fields: Vec<Field>,
326 primary_key: Vec<String>,
327) -> Arc<TableSchema> {
328 let schema = Arc::new(TableSchema {
329 name: table_name.to_string(),
330 arrow_schema: Arc::new(Schema::new(fields)),
331 primary_key,
332 });
333 tables.insert(table_name.to_string(), schema.clone());
334 schema
335}
336
337impl Default for SchemaRegistry {
338 fn default() -> Self {
339 Self::new()
340 }
341}
342
343#[derive(Debug, Serialize, Deserialize)]
352struct PersistedSchema {
353 name: String,
354 primary_key: Vec<String>,
355 fields: Vec<(String, String, bool)>,
357}
358
359fn arrow_type_to_str(dt: &DataType) -> String {
366 match dt {
367 DataType::Int8 => "int8".to_string(),
368 DataType::Int16 => "int16".to_string(),
369 DataType::Int32 => "int32".to_string(),
370 DataType::Int64 => "int64".to_string(),
371 DataType::UInt8 => "uint8".to_string(),
372 DataType::UInt16 => "uint16".to_string(),
373 DataType::UInt32 => "uint32".to_string(),
374 DataType::UInt64 => "uint64".to_string(),
375 DataType::Float16 => "float16".to_string(),
376 DataType::Float32 => "float32".to_string(),
377 DataType::Float64 => "float64".to_string(),
378 DataType::Boolean => "boolean".to_string(),
379 DataType::Utf8 => "utf8".to_string(),
380 DataType::LargeUtf8 => "large_utf8".to_string(),
381 DataType::Binary => "binary".to_string(),
382 DataType::LargeBinary => "large_binary".to_string(),
383 DataType::Date32 => "date32".to_string(),
384 DataType::Date64 => "date64".to_string(),
385 DataType::Timestamp(TimeUnit::Second, tz) => {
386 format!("timestamp_s[{}]", tz.as_deref().unwrap_or(""))
387 }
388 DataType::Timestamp(TimeUnit::Millisecond, tz) => {
389 format!("timestamp_ms[{}]", tz.as_deref().unwrap_or(""))
390 }
391 DataType::Timestamp(TimeUnit::Microsecond, tz) => {
392 format!("timestamp_us[{}]", tz.as_deref().unwrap_or(""))
393 }
394 DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
395 format!("timestamp_ns[{}]", tz.as_deref().unwrap_or(""))
396 }
397 DataType::Null => "null".to_string(),
398 other => format!("unknown:{other:?}"),
399 }
400}
401
402fn arrow_type_from_str(s: &str) -> Result<DataType, CoreError> {
404 if let Some(rest) = s.strip_prefix("timestamp_") {
406 let (unit_str, tz_part) = if let Some(idx) = rest.find('[') {
407 if !rest.ends_with(']') {
409 return Err(CoreError::SchemaValidation(format!(
410 "malformed timestamp type string '{s}': missing closing ']'"
411 )));
412 }
413 let unit = &rest[..idx];
414 let tz_raw = &rest[idx + 1..rest.len() - 1];
416 let tz: Option<Arc<str>> = if tz_raw.is_empty() {
417 None
418 } else {
419 Some(Arc::from(tz_raw))
420 };
421 (unit, tz)
422 } else {
423 (rest, None)
424 };
425 let unit = match unit_str {
426 "s" => TimeUnit::Second,
427 "ms" => TimeUnit::Millisecond,
428 "us" => TimeUnit::Microsecond,
429 "ns" => TimeUnit::Nanosecond,
430 other => {
431 return Err(CoreError::SchemaValidation(format!(
432 "unknown timestamp unit '{other}'"
433 )))
434 }
435 };
436 return Ok(DataType::Timestamp(unit, tz_part));
437 }
438
439 match s {
440 "int8" => Ok(DataType::Int8),
441 "int16" => Ok(DataType::Int16),
442 "int32" => Ok(DataType::Int32),
443 "int64" => Ok(DataType::Int64),
444 "uint8" => Ok(DataType::UInt8),
445 "uint16" => Ok(DataType::UInt16),
446 "uint32" => Ok(DataType::UInt32),
447 "uint64" => Ok(DataType::UInt64),
448 "float16" => Ok(DataType::Float16),
449 "float32" => Ok(DataType::Float32),
450 "float64" => Ok(DataType::Float64),
451 "boolean" => Ok(DataType::Boolean),
452 "utf8" => Ok(DataType::Utf8),
453 "large_utf8" => Ok(DataType::LargeUtf8),
454 "binary" => Ok(DataType::Binary),
455 "large_binary" => Ok(DataType::LargeBinary),
456 "date32" => Ok(DataType::Date32),
457 "date64" => Ok(DataType::Date64),
458 "null" => Ok(DataType::Null),
459 other => Err(CoreError::SchemaValidation(format!(
460 "cannot deserialize unknown Arrow type string '{other}'"
461 ))),
462 }
463}
464
465impl SchemaRegistry {
466 pub fn save_to_disk(&self, path: &str) -> Result<(), CoreError> {
476 let tables = self.tables.read().unwrap();
477 let persisted: Vec<PersistedSchema> = tables
478 .values()
479 .map(|ts| PersistedSchema {
480 name: ts.name.clone(),
481 primary_key: ts.primary_key.clone(),
482 fields: ts
483 .arrow_schema
484 .fields()
485 .iter()
486 .map(|f| {
487 (
488 f.name().clone(),
489 arrow_type_to_str(f.data_type()),
490 f.is_nullable(),
491 )
492 })
493 .collect(),
494 })
495 .collect();
496
497 let json = serde_json::to_string_pretty(&persisted).map_err(|e| {
498 CoreError::SchemaValidation(format!("failed to serialize schema registry: {e}"))
499 })?;
500
501 let tmp_path = format!("{path}.tmp");
503 std::fs::write(&tmp_path, &json).map_err(|e| {
504 CoreError::SchemaValidation(format!(
505 "failed to write schema registry to '{tmp_path}': {e}"
506 ))
507 })?;
508 std::fs::rename(&tmp_path, path).map_err(|e| {
509 CoreError::SchemaValidation(format!(
510 "failed to rename schema registry file '{tmp_path}' -> '{path}': {e}"
511 ))
512 })?;
513
514 Ok(())
515 }
516
517 pub fn load_from_disk(path: &str) -> Result<SchemaRegistry, CoreError> {
529 if !std::path::Path::new(path).exists() {
530 return Ok(SchemaRegistry::new());
531 }
532
533 let json = std::fs::read_to_string(path).map_err(|e| {
534 CoreError::SchemaValidation(format!(
535 "failed to read schema registry from '{path}': {e}"
536 ))
537 })?;
538
539 let persisted: Vec<PersistedSchema> = serde_json::from_str(&json).map_err(|e| {
540 CoreError::SchemaValidation(format!("failed to parse schema registry at '{path}': {e}"))
541 })?;
542
543 let registry = SchemaRegistry::new();
544 for ps in persisted {
545 let fields: Vec<Field> = ps
546 .fields
547 .iter()
548 .map(|(name, type_str, nullable)| {
549 arrow_type_from_str(type_str).map(|dt| Field::new(name.as_str(), dt, *nullable))
550 })
551 .collect::<Result<_, _>>()?;
552
553 let schema = Arc::new(Schema::new(fields));
554 let table_schema = TableSchema::new(ps.name, schema, ps.primary_key);
555 registry.register(table_schema)?;
558 }
559
560 Ok(registry)
561 }
562}
563
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::{DataType, Field, Schema};

    /// Builds table `t` with a non-nullable `id: Int64` primary key and a nullable `val`
    /// column of the given type.
    fn simple_schema(col_type: DataType) -> TableSchema {
        let columns = vec![
            Field::new("id", DataType::Int64, false),
            Field::new("val", col_type, true),
        ];
        TableSchema::new("t", Arc::new(Schema::new(columns)), vec!["id".to_string()])
    }

    /// Registering the very same schema twice is rejected: the second call reports
    /// `TableAlreadyRegistered` even though the schemas match.
    #[test]
    fn register_idempotent_matching_schema() {
        let registry = SchemaRegistry::new();
        let first = simple_schema(DataType::Utf8);
        let second = first.clone();
        registry.register(first).unwrap();

        let result = registry.register(second);
        assert!(
            matches!(result, Err(CoreError::TableAlreadyRegistered(_))),
            "expected TableAlreadyRegistered, got {result:?}"
        );
    }

    /// Registering a different schema under an existing name is rejected as well.
    #[test]
    fn register_detects_conflict() {
        let registry = SchemaRegistry::new();
        registry.register(simple_schema(DataType::Utf8)).unwrap();

        let result = registry.register(simple_schema(DataType::Int32));
        assert!(
            matches!(result, Err(CoreError::TableAlreadyRegistered(_))),
            "expected TableAlreadyRegistered for conflicting schema, got {result:?}"
        );
    }

    /// Well-formed timestamp strings parse to the expected unit and timezone.
    #[test]
    fn arrow_type_from_str_valid_timestamp() {
        let tz = |name: &str| -> Option<Arc<str>> { Some(Arc::from(name)) };
        let cases = vec![
            ("timestamp_s[]", DataType::Timestamp(TimeUnit::Second, None)),
            (
                "timestamp_ms[UTC]",
                DataType::Timestamp(TimeUnit::Millisecond, tz("UTC")),
            ),
            (
                "timestamp_us[America/New_York]",
                DataType::Timestamp(TimeUnit::Microsecond, tz("America/New_York")),
            ),
            (
                "timestamp_ns[]",
                DataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
        ];

        for (s, expected) in cases {
            let got = match arrow_type_from_str(s) {
                Ok(parsed) => parsed,
                Err(e) => panic!("parse '{s}' failed: {e}"),
            };
            assert_eq!(got, expected, "round-trip mismatch for '{s}'");
        }
    }

    /// A timezone section that is opened but never closed is reported as malformed.
    #[test]
    fn arrow_type_from_str_missing_close_bracket() {
        let result = arrow_type_from_str("timestamp_us[UTC");
        let rejected_as_expected = match &result {
            Err(CoreError::SchemaValidation(msg)) => msg.contains("missing closing ']'"),
            _ => false,
        };
        assert!(
            rejected_as_expected,
            "expected SchemaValidation error for missing ']', got {result:?}"
        );
    }

    /// A bare trailing `[` with nothing after it is reported the same way.
    #[test]
    fn arrow_type_from_str_empty_bracket_no_close() {
        let result = arrow_type_from_str("timestamp_us[");
        let rejected_as_expected = match &result {
            Err(CoreError::SchemaValidation(msg)) => msg.contains("missing closing ']'"),
            _ => false,
        };
        assert!(
            rejected_as_expected,
            "expected SchemaValidation error for 'timestamp_us[', got {result:?}"
        );
    }
}