use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit};
use serde::{Deserialize, Serialize};
use crate::error::CoreError;
/// Checks that `name` is usable as a table or column identifier.
///
/// An identifier is accepted when it is non-empty and consists solely of
/// ASCII letters, ASCII digits and underscores.
///
/// # Errors
///
/// Returns [`CoreError::SchemaValidation`] when `name` is empty or contains
/// any character outside `[A-Za-z0-9_]`.
pub fn validate_identifier(name: &str) -> Result<(), CoreError> {
    let is_allowed = |ch: char| ch.is_ascii_alphanumeric() || ch == '_';

    // Work out what (if anything) is wrong first, then wrap it once.
    let problem = if name.is_empty() {
        Some("identifier must not be empty".to_string())
    } else if name.chars().any(|ch| !is_allowed(ch)) {
        Some(format!(
            "identifier '{}' contains invalid characters (only [A-Za-z0-9_] allowed)",
            name
        ))
    } else {
        None
    };

    match problem {
        Some(message) => Err(CoreError::SchemaValidation(message)),
        None => Ok(()),
    }
}
/// Description of a single table: its name, its Arrow column layout and the
/// columns that make up its primary key.
///
/// Constructing a value performs no checks; call `validate` to enforce the
/// naming and primary-key rules.
#[derive(Debug, Clone)]
pub struct TableSchema {
/// Table name; `validate` requires it to pass `validate_identifier`.
pub name: String,
/// Arrow schema holding the table's column definitions.
pub arrow_schema: SchemaRef,
/// Names of the primary-key columns; `validate` requires the list to be
/// non-empty and every entry to exist in `arrow_schema`.
pub primary_key: Vec<String>,
}
impl TableSchema {
    /// Builds a table schema from its parts without validating them.
    ///
    /// Call [`TableSchema::validate`] before relying on the result.
    pub fn new(name: impl Into<String>, schema: SchemaRef, primary_key: Vec<String>) -> Self {
        Self {
            name: name.into(),
            arrow_schema: schema,
            primary_key,
        }
    }

    /// Verifies that this schema is internally consistent.
    ///
    /// The checks run in this order:
    /// 1. the table name is a valid identifier;
    /// 2. every column name is a valid identifier;
    /// 3. the primary key is non-empty;
    /// 4. every primary-key column is a valid identifier and exists in the
    ///    Arrow schema;
    /// 5. column names are unique;
    /// 6. primary-key entries are unique.
    ///
    /// # Errors
    ///
    /// Returns [`CoreError::SchemaValidation`] describing the first violated
    /// rule.
    pub fn validate(&self) -> Result<(), CoreError> {
        validate_identifier(&self.name)?;

        for field in self.arrow_schema.fields() {
            validate_identifier(field.name()).map_err(|_| {
                CoreError::SchemaValidation(format!(
                    "column name '{}' in table '{}' contains invalid characters",
                    field.name(),
                    self.name
                ))
            })?;
        }

        if self.primary_key.is_empty() {
            return Err(CoreError::SchemaValidation(format!(
                "table '{}' must have at least one primary key column",
                self.name
            )));
        }

        for pk_col in &self.primary_key {
            validate_identifier(pk_col)?;
            if self.arrow_schema.field_with_name(pk_col).is_err() {
                return Err(CoreError::SchemaValidation(format!(
                    "primary key column '{}' not found in schema for table '{}'",
                    pk_col, self.name
                )));
            }
        }

        // Uniqueness checks run last so that any schema rejected by the
        // checks above still reports exactly the same error as before.
        //
        // Arrow itself does not forbid repeated field names, but name-based
        // lookups only ever see the first match, so a duplicate column would
        // be unreachable by name.
        let mut seen_columns = std::collections::HashSet::new();
        for field in self.arrow_schema.fields() {
            if !seen_columns.insert(field.name().as_str()) {
                return Err(CoreError::SchemaValidation(format!(
                    "duplicate column name '{}' in table '{}'",
                    field.name(),
                    self.name
                )));
            }
        }

        let mut seen_pk = std::collections::HashSet::new();
        for pk_col in &self.primary_key {
            if !seen_pk.insert(pk_col.as_str()) {
                return Err(CoreError::SchemaValidation(format!(
                    "duplicate primary key column '{}' in table '{}'",
                    pk_col, self.name
                )));
            }
        }

        Ok(())
    }
}
/// Registry of table schemas keyed by table name.
///
/// The map lives behind an `Arc<RwLock<..>>`, so the derived `Clone` copies
/// the `Arc` handle: clones of a registry share the same underlying map.
#[derive(Debug, Clone)]
pub struct SchemaRegistry {
/// Table name -> shared schema; each entry is held in an `Arc`, so lookups
/// hand out reference-counted handles.
tables: Arc<RwLock<HashMap<String, Arc<TableSchema>>>>,
}
impl SchemaRegistry {
    /// Creates a registry with no tables.
    pub fn new() -> Self {
        let empty = HashMap::new();
        Self {
            tables: Arc::new(RwLock::new(empty)),
        }
    }

    /// Validates `schema` and stores it under its table name.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`TableSchema::validate`], or returns
    /// [`CoreError::TableAlreadyRegistered`] when a table with the same name
    /// is already present (the stored schema is left untouched).
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn register(&self, schema: TableSchema) -> Result<(), CoreError> {
        schema.validate()?;
        let mut guard = self.tables.write().unwrap();
        if let Entry::Vacant(slot) = guard.entry(schema.name.clone()) {
            slot.insert(Arc::new(schema));
            return Ok(());
        }
        Err(CoreError::TableAlreadyRegistered(schema.name))
    }

    /// Returns a shared handle to the schema registered as `table_name`.
    ///
    /// # Errors
    ///
    /// Returns [`CoreError::TableNotFound`] when no such table exists.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn get(&self, table_name: &str) -> Result<Arc<TableSchema>, CoreError> {
        let guard = self.tables.read().unwrap();
        match guard.get(table_name) {
            Some(found) => Ok(Arc::clone(found)),
            None => Err(CoreError::TableNotFound(table_name.to_string())),
        }
    }

    /// Returns the names of all registered tables, in the map's iteration
    /// order (which is unspecified).
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn table_names(&self) -> Vec<String> {
        let guard = self.tables.read().unwrap();
        let mut names = Vec::with_capacity(guard.len());
        names.extend(guard.keys().cloned());
        names
    }

    /// Removes `table_name` from the registry and returns its schema.
    ///
    /// # Errors
    ///
    /// Returns [`CoreError::TableNotFound`] when no such table exists.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn unregister(&self, table_name: &str) -> Result<Arc<TableSchema>, CoreError> {
        let mut guard = self.tables.write().unwrap();
        match guard.remove(table_name) {
            Some(removed) => Ok(removed),
            None => Err(CoreError::TableNotFound(table_name.to_string())),
        }
    }

    /// Validates `schema` and replaces the already-registered table of the
    /// same name with it.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`TableSchema::validate`], or returns
    /// [`CoreError::TableNotFound`] when the table has not been registered.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn update(&self, schema: TableSchema) -> Result<(), CoreError> {
        schema.validate()?;
        let mut guard = self.tables.write().unwrap();
        match guard.get_mut(&schema.name) {
            Some(slot) => {
                *slot = Arc::new(schema);
                Ok(())
            }
            None => Err(CoreError::TableNotFound(schema.name)),
        }
    }

    /// Appends a nullable column named `column_name` of type `data_type` to
    /// `table_name` and returns the resulting schema.
    ///
    /// # Errors
    ///
    /// Returns [`CoreError::SchemaValidation`] when `column_name` is not a
    /// valid identifier or already exists in the table, and
    /// [`CoreError::TableNotFound`] when the table is not registered.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn add_column(
        &self,
        table_name: &str,
        column_name: &str,
        data_type: DataType,
    ) -> Result<Arc<TableSchema>, CoreError> {
        validate_identifier(column_name)?;

        // The write lock is held for the whole read-modify-write sequence.
        let mut guard = self.tables.write().unwrap();
        let current = match guard.get(table_name) {
            Some(schema) => schema,
            None => return Err(CoreError::TableNotFound(table_name.to_string())),
        };

        let already_present = current.arrow_schema.field_with_name(column_name).is_ok();
        if already_present {
            return Err(CoreError::SchemaValidation(format!(
                "column '{}' already exists in table '{}'",
                column_name, table_name
            )));
        }

        let mut columns: Vec<Field> = Vec::new();
        for existing_field in current.arrow_schema.fields().iter() {
            columns.push(existing_field.as_ref().clone());
        }
        // New columns are always nullable.
        columns.push(Field::new(column_name, data_type, true));

        let key_columns = current.primary_key.clone();
        let updated = commit_schema(&mut guard, table_name, columns, key_columns);
        Ok(updated)
    }

    /// Removes the column `column_name` from `table_name` and returns the
    /// resulting schema.
    ///
    /// # Errors
    ///
    /// Returns [`CoreError::TableNotFound`] when the table is not registered,
    /// and [`CoreError::SchemaValidation`] when the column is part of the
    /// primary key or does not exist (checked in that order).
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn drop_column(
        &self,
        table_name: &str,
        column_name: &str,
    ) -> Result<Arc<TableSchema>, CoreError> {
        let mut guard = self.tables.write().unwrap();
        let current = match guard.get(table_name) {
            Some(schema) => schema,
            None => return Err(CoreError::TableNotFound(table_name.to_string())),
        };

        let is_key_column = current
            .primary_key
            .iter()
            .any(|pk| pk.as_str() == column_name);
        if is_key_column {
            return Err(CoreError::SchemaValidation(format!(
                "cannot drop primary key column '{}' from table '{}'",
                column_name, table_name
            )));
        }

        let is_missing = current.arrow_schema.field_with_name(column_name).is_err();
        if is_missing {
            return Err(CoreError::SchemaValidation(format!(
                "column '{}' not found in table '{}'",
                column_name, table_name
            )));
        }

        // Keep every field whose name differs from the dropped column.
        let mut columns: Vec<Field> = Vec::new();
        for existing_field in current.arrow_schema.fields().iter() {
            if existing_field.name().as_str() == column_name {
                continue;
            }
            columns.push(existing_field.as_ref().clone());
        }

        let key_columns = current.primary_key.clone();
        let updated = commit_schema(&mut guard, table_name, columns, key_columns);
        Ok(updated)
    }
}
fn commit_schema(
tables: &mut HashMap<String, Arc<TableSchema>>,
table_name: &str,
fields: Vec<Field>,
primary_key: Vec<String>,
) -> Arc<TableSchema> {
let schema = Arc::new(TableSchema {
name: table_name.to_string(),
arrow_schema: Arc::new(Schema::new(fields)),
primary_key,
});
tables.insert(table_name.to_string(), schema.clone());
schema
}
impl Default for SchemaRegistry {
fn default() -> Self {
Self::new()
}
}
/// Serde-serializable form of one table, used when the registry is written to
/// or read from disk as JSON.
#[derive(Debug, Serialize, Deserialize)]
struct PersistedSchema {
/// Table name.
name: String,
/// Names of the primary-key columns.
primary_key: Vec<String>,
/// One `(column name, type string, nullable)` tuple per column; the type
/// string is the encoding produced by `arrow_type_to_str` and parsed by
/// `arrow_type_from_str`.
fields: Vec<(String, String, bool)>,
}
/// Encodes an Arrow [`DataType`] as the string stored in the persisted
/// registry.
///
/// Primitive, string, binary, date and null types map to fixed lowercase
/// names. Timestamps are written as `timestamp_<unit>[<tz>]`, where `<unit>`
/// is `s`, `ms`, `us` or `ns` and `<tz>` is empty when no time zone is set.
/// Every other type is rendered as `unknown:` followed by its `Debug` output.
fn arrow_type_to_str(dt: &DataType) -> String {
    let fixed_name = match dt {
        DataType::Timestamp(unit, tz) => {
            let unit_suffix = match unit {
                TimeUnit::Second => "s",
                TimeUnit::Millisecond => "ms",
                TimeUnit::Microsecond => "us",
                TimeUnit::Nanosecond => "ns",
            };
            let zone = tz.as_deref().unwrap_or("");
            return format!("timestamp_{unit_suffix}[{zone}]");
        }
        DataType::Int8 => "int8",
        DataType::Int16 => "int16",
        DataType::Int32 => "int32",
        DataType::Int64 => "int64",
        DataType::UInt8 => "uint8",
        DataType::UInt16 => "uint16",
        DataType::UInt32 => "uint32",
        DataType::UInt64 => "uint64",
        DataType::Float16 => "float16",
        DataType::Float32 => "float32",
        DataType::Float64 => "float64",
        DataType::Boolean => "boolean",
        DataType::Utf8 => "utf8",
        DataType::LargeUtf8 => "large_utf8",
        DataType::Binary => "binary",
        DataType::LargeBinary => "large_binary",
        DataType::Date32 => "date32",
        DataType::Date64 => "date64",
        DataType::Null => "null",
        // No dedicated encoding: fall back to the Debug rendering.
        other => return format!("unknown:{other:?}"),
    };
    fixed_name.to_string()
}
/// Parses a type string from the persisted registry back into an Arrow
/// [`DataType`].
///
/// Timestamps are written `timestamp_<unit>` optionally followed by
/// `[<tz>]`; an empty or absent bracket section means "no time zone". The
/// remaining accepted strings are the fixed lowercase names of the primitive,
/// string, binary, date and null types.
///
/// # Errors
///
/// Returns [`CoreError::SchemaValidation`] when a timestamp has an opening
/// `[` without a trailing `]`, when the timestamp unit is not one of `s`,
/// `ms`, `us`, `ns`, or when the string is not a recognised type name.
fn arrow_type_from_str(s: &str) -> Result<DataType, CoreError> {
    if let Some(spec) = s.strip_prefix("timestamp_") {
        // Split at the first '[': everything before it is the unit, and
        // everything after it up to the final ']' is the time zone.
        let (unit_token, zone) = match spec.split_once('[') {
            None => (spec, None),
            Some((unit_token, tail)) => match tail.strip_suffix(']') {
                Some("") => (unit_token, None),
                Some(zone_name) => (unit_token, Some(Arc::<str>::from(zone_name))),
                None => {
                    return Err(CoreError::SchemaValidation(format!(
                        "malformed timestamp type string '{s}': missing closing ']'"
                    )))
                }
            },
        };

        let unit = match unit_token {
            "s" => TimeUnit::Second,
            "ms" => TimeUnit::Millisecond,
            "us" => TimeUnit::Microsecond,
            "ns" => TimeUnit::Nanosecond,
            other => {
                return Err(CoreError::SchemaValidation(format!(
                    "unknown timestamp unit '{other}'"
                )))
            }
        };
        return Ok(DataType::Timestamp(unit, zone));
    }

    let parsed = match s {
        "int8" => DataType::Int8,
        "int16" => DataType::Int16,
        "int32" => DataType::Int32,
        "int64" => DataType::Int64,
        "uint8" => DataType::UInt8,
        "uint16" => DataType::UInt16,
        "uint32" => DataType::UInt32,
        "uint64" => DataType::UInt64,
        "float16" => DataType::Float16,
        "float32" => DataType::Float32,
        "float64" => DataType::Float64,
        "boolean" => DataType::Boolean,
        "utf8" => DataType::Utf8,
        "large_utf8" => DataType::LargeUtf8,
        "binary" => DataType::Binary,
        "large_binary" => DataType::LargeBinary,
        "date32" => DataType::Date32,
        "date64" => DataType::Date64,
        "null" => DataType::Null,
        other => {
            return Err(CoreError::SchemaValidation(format!(
                "cannot deserialize unknown Arrow type string '{other}'"
            )))
        }
    };
    Ok(parsed)
}
impl SchemaRegistry {
pub fn save_to_disk(&self, path: &str) -> Result<(), CoreError> {
let tables = self.tables.read().unwrap();
let persisted: Vec<PersistedSchema> = tables
.values()
.map(|ts| PersistedSchema {
name: ts.name.clone(),
primary_key: ts.primary_key.clone(),
fields: ts
.arrow_schema
.fields()
.iter()
.map(|f| {
(
f.name().clone(),
arrow_type_to_str(f.data_type()),
f.is_nullable(),
)
})
.collect(),
})
.collect();
let json = serde_json::to_string_pretty(&persisted).map_err(|e| {
CoreError::SchemaValidation(format!("failed to serialize schema registry: {e}"))
})?;
let tmp_path = format!("{path}.tmp");
std::fs::write(&tmp_path, &json).map_err(|e| {
CoreError::SchemaValidation(format!(
"failed to write schema registry to '{tmp_path}': {e}"
))
})?;
std::fs::rename(&tmp_path, path).map_err(|e| {
CoreError::SchemaValidation(format!(
"failed to rename schema registry file '{tmp_path}' -> '{path}': {e}"
))
})?;
Ok(())
}
pub fn load_from_disk(path: &str) -> Result<SchemaRegistry, CoreError> {
if !std::path::Path::new(path).exists() {
return Ok(SchemaRegistry::new());
}
let json = std::fs::read_to_string(path).map_err(|e| {
CoreError::SchemaValidation(format!(
"failed to read schema registry from '{path}': {e}"
))
})?;
let persisted: Vec<PersistedSchema> = serde_json::from_str(&json).map_err(|e| {
CoreError::SchemaValidation(format!("failed to parse schema registry at '{path}': {e}"))
})?;
let registry = SchemaRegistry::new();
for ps in persisted {
let fields: Vec<Field> = ps
.fields
.iter()
.map(|(name, type_str, nullable)| {
arrow_type_from_str(type_str).map(|dt| Field::new(name.as_str(), dt, *nullable))
})
.collect::<Result<_, _>>()?;
let schema = Arc::new(Schema::new(fields));
let table_schema = TableSchema::new(ps.name, schema, ps.primary_key);
registry.register(table_schema)?;
}
Ok(registry)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::{DataType, Field, Schema};

    /// Builds table `t` with a non-nullable `id` primary key and a nullable
    /// `val` column of the given type.
    fn simple_schema(col_type: DataType) -> TableSchema {
        let columns = vec![
            Field::new("id", DataType::Int64, false),
            Field::new("val", col_type, true),
        ];
        let key = vec!["id".to_string()];
        TableSchema::new("t", Arc::new(Schema::new(columns)), key)
    }

    /// Registering the very same schema twice is rejected the second time.
    #[test]
    fn register_idempotent_matching_schema() {
        let registry = SchemaRegistry::new();
        let first = simple_schema(DataType::Utf8);
        let second = first.clone();
        registry.register(first).unwrap();

        let result = registry.register(second);
        let rejected = matches!(result, Err(CoreError::TableAlreadyRegistered(_)));
        assert!(rejected, "expected TableAlreadyRegistered, got {result:?}");
    }

    /// Registering a different schema under an existing name is rejected.
    #[test]
    fn register_detects_conflict() {
        let registry = SchemaRegistry::new();
        registry.register(simple_schema(DataType::Utf8)).unwrap();

        let conflicting = simple_schema(DataType::Int32);
        let result = registry.register(conflicting);
        let rejected = matches!(result, Err(CoreError::TableAlreadyRegistered(_)));
        assert!(
            rejected,
            "expected TableAlreadyRegistered for conflicting schema, got {result:?}"
        );
    }

    /// Well-formed timestamp strings parse to the expected unit and zone.
    #[test]
    fn arrow_type_from_str_valid_timestamp() {
        let expect_parse = |s: &str, expected: DataType| {
            let got = arrow_type_from_str(s).unwrap_or_else(|e| panic!("parse '{s}' failed: {e}"));
            assert_eq!(got, expected, "round-trip mismatch for '{s}'");
        };

        expect_parse("timestamp_s[]", DataType::Timestamp(TimeUnit::Second, None));
        expect_parse(
            "timestamp_ms[UTC]",
            DataType::Timestamp(TimeUnit::Millisecond, Some(Arc::from("UTC"))),
        );
        expect_parse(
            "timestamp_us[America/New_York]",
            DataType::Timestamp(TimeUnit::Microsecond, Some(Arc::from("America/New_York"))),
        );
        expect_parse(
            "timestamp_ns[]",
            DataType::Timestamp(TimeUnit::Nanosecond, None),
        );
    }

    /// A time zone section that is opened but never closed is an error.
    #[test]
    fn arrow_type_from_str_missing_close_bracket() {
        let result = arrow_type_from_str("timestamp_us[UTC");
        let reported = matches!(
            result,
            Err(CoreError::SchemaValidation(ref msg)) if msg.contains("missing closing ']'")
        );
        assert!(
            reported,
            "expected SchemaValidation error for missing ']', got {result:?}"
        );
    }

    /// A lone opening bracket with nothing after it is an error as well.
    #[test]
    fn arrow_type_from_str_empty_bracket_no_close() {
        let result = arrow_type_from_str("timestamp_us[");
        let reported = matches!(
            result,
            Err(CoreError::SchemaValidation(ref msg)) if msg.contains("missing closing ']'")
        );
        assert!(
            reported,
            "expected SchemaValidation error for 'timestamp_us[', got {result:?}"
        );
    }
}