use std::fmt;
use std::sync::{Arc, OnceLock};
use chrono::{DateTime, Utc};
use crate::common::{CompactArc, StringMap};
use super::error::{Error, Result};
use super::types::{DataType, ForeignKeyAction};
/// Describes one column of a table: its identity, data type and constraint flags.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SchemaColumn {
/// Numeric column identifier. `SchemaBuilder` assigns the insertion position
/// and `Schema::remove_column` renumbers ids so they match column positions.
pub id: usize,
/// Column name exactly as supplied by the caller.
pub name: String,
/// `name.to_lowercase()`, precomputed by the constructors; `Schema` keys its
/// case-insensitive lookup map on this value.
pub name_lower: String,
/// Declared data type of the column.
pub data_type: DataType,
/// When `false`, `Display` renders `NOT NULL` (unless the column is a primary key).
pub nullable: bool,
/// Marks the column as (part of) the primary key.
pub primary_key: bool,
/// Auto-increment flag; stored here but not interpreted in this file.
pub auto_increment: bool,
/// Optional default expression as text — presumably SQL source; confirm with the parser.
pub default_expr: Option<String>,
/// Optional default as a `Value` — presumably the evaluated form of
/// `default_expr`; confirm against the code that fills it in.
pub default_value: Option<super::Value>,
/// Optional check expression as text — presumably a SQL CHECK constraint;
/// stored here but not interpreted in this file.
pub check_expr: Option<String>,
/// Dimension count rendered as `VECTOR(n)` by `Display` when `data_type` is
/// `DataType::Vector`; the constructors initialise it to 0.
pub vector_dimensions: u16,
}
impl SchemaColumn {
    /// Creates a column with no auto-increment, default or check constraint
    /// and `vector_dimensions` set to 0.
    pub fn new(
        id: usize,
        name: impl Into<String>,
        data_type: DataType,
        nullable: bool,
        primary_key: bool,
    ) -> Self {
        Self::with_default_value(
            id,
            name,
            data_type,
            nullable,
            primary_key,
            false,
            None,
            None,
            None,
        )
    }

    /// Returns the column with `vector_dimensions` replaced by `dims`.
    pub fn with_vector_dimensions(self, dims: u16) -> Self {
        Self {
            vector_dimensions: dims,
            ..self
        }
    }

    /// Creates a column with explicit constraint settings and no
    /// `default_value`.
    #[allow(clippy::too_many_arguments)]
    pub fn with_constraints(
        id: usize,
        name: impl Into<String>,
        data_type: DataType,
        nullable: bool,
        primary_key: bool,
        auto_increment: bool,
        default_expr: Option<String>,
        check_expr: Option<String>,
    ) -> Self {
        Self::with_default_value(
            id,
            name,
            data_type,
            nullable,
            primary_key,
            auto_increment,
            default_expr,
            None,
            check_expr,
        )
    }

    /// Creates a column with every field supplied except
    /// `vector_dimensions`, which starts at 0.
    ///
    /// `name_lower` is derived from `name` via `to_lowercase()`.
    #[allow(clippy::too_many_arguments)]
    pub fn with_default_value(
        id: usize,
        name: impl Into<String>,
        data_type: DataType,
        nullable: bool,
        primary_key: bool,
        auto_increment: bool,
        default_expr: Option<String>,
        default_value: Option<super::Value>,
        check_expr: Option<String>,
    ) -> Self {
        let name: String = name.into();
        Self {
            id,
            // Computed before `name` is moved into the struct.
            name_lower: name.to_lowercase(),
            name,
            data_type,
            nullable,
            primary_key,
            auto_increment,
            default_expr,
            default_value,
            check_expr,
            vector_dimensions: 0,
        }
    }

    /// Shorthand for a non-nullable, non-primary-key column.
    pub fn simple(id: usize, name: impl Into<String>, data_type: DataType) -> Self {
        let (nullable, primary_key) = (false, false);
        Self::new(id, name, data_type, nullable, primary_key)
    }

    /// Shorthand for a nullable, non-primary-key column.
    pub fn nullable(id: usize, name: impl Into<String>, data_type: DataType) -> Self {
        let (nullable, primary_key) = (true, false);
        Self::new(id, name, data_type, nullable, primary_key)
    }

    /// Shorthand for a non-nullable primary-key column.
    pub fn primary_key(id: usize, name: impl Into<String>, data_type: DataType) -> Self {
        let (nullable, primary_key) = (false, true);
        Self::new(id, name, data_type, nullable, primary_key)
    }
}
impl fmt::Display for SchemaColumn {
    /// Renders the column as `<name> <type>` followed by ` PRIMARY KEY` or
    /// ` NOT NULL` where applicable.
    ///
    /// A `DataType::Vector` column with a non-zero dimension count is shown
    /// as `VECTOR(n)` instead of the plain type name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match (self.data_type == DataType::Vector, self.vector_dimensions) {
            (true, dims) if dims > 0 => write!(f, "{} VECTOR({})", self.name, dims)?,
            _ => write!(f, "{} {}", self.name, self.data_type)?,
        }
        // A primary key never gets the extra NOT NULL suffix.
        if self.primary_key {
            f.write_str(" PRIMARY KEY")?;
        } else if !self.nullable {
            f.write_str(" NOT NULL")?;
        }
        Ok(())
    }
}
/// A single-column foreign-key constraint recorded on a `Schema`.
///
/// This file only stores these values. NOTE(review): `column_index` and
/// `column_name` are not adjusted by `Schema::remove_column` or
/// `Schema::rename_column` — confirm that callers keep them in sync.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ForeignKeyConstraint {
/// Index of the constrained column — presumably a position in
/// `Schema::columns`; confirm against the code that builds constraints.
pub column_index: usize,
/// Name of the constrained column.
pub column_name: String,
/// Name of the table the constraint refers to.
pub referenced_table: String,
/// Name of the column in `referenced_table` the constraint refers to.
pub referenced_column: String,
/// Action associated with deletes (semantics defined by `ForeignKeyAction`).
pub on_delete: ForeignKeyAction,
/// Action associated with updates (semantics defined by `ForeignKeyAction`).
pub on_update: ForeignKeyAction,
}
/// Table definition: name, ordered columns, foreign keys and timestamps,
/// plus lookup caches derived from `columns`.
#[derive(Debug)]
pub struct Schema {
/// Table name as supplied by the caller.
pub table_name: String,
/// `table_name.to_lowercase()`, computed by the constructors.
pub table_name_lower: String,
/// Columns in positional order.
pub columns: Vec<SchemaColumn>,
/// Foreign-key constraints attached to the table.
pub foreign_keys: Vec<ForeignKeyConstraint>,
/// Creation timestamp (UTC).
pub created_at: DateTime<Utc>,
/// Last-modification timestamp (UTC); refreshed by `mark_updated`.
pub updated_at: DateTime<Utc>,
// The caches below are derived from `columns`. The constructors fill them
// and the mutating methods rebuild them; code that edits the public
// `columns` field directly does not, so they can go stale in that case.
/// Column names in positional order.
column_names_cache: OnceLock<CompactArc<Vec<String>>>,
/// Position of the first primary-key column whose type is
/// `DataType::Integer`, if any.
pk_column_index_cache: OnceLock<Option<usize>>,
/// Lowercased column name -> column position.
column_index_map_cache: OnceLock<StringMap<usize>>,
/// Positions of all primary-key columns.
pk_indices_cache: OnceLock<Arc<Vec<usize>>>,
/// Lowercased column names in positional order.
column_names_lower_cache: OnceLock<CompactArc<Vec<String>>>,
}
impl Clone for Schema {
    /// Deep-copies the schema data and carries over every cache that is
    /// already populated; caches that are still empty stay empty in the copy.
    ///
    /// The name-list and primary-key-index caches are shared by reference
    /// count rather than copied.
    fn clone(&self) -> Self {
        Self {
            table_name: self.table_name.clone(),
            table_name_lower: self.table_name_lower.clone(),
            columns: self.columns.clone(),
            foreign_keys: self.foreign_keys.clone(),
            created_at: self.created_at,
            updated_at: self.updated_at,
            column_names_cache: self
                .column_names_cache
                .get()
                .map_or_else(OnceLock::new, |names| {
                    OnceLock::from(CompactArc::clone(names))
                }),
            pk_column_index_cache: self
                .pk_column_index_cache
                .get()
                .map_or_else(OnceLock::new, |pk_idx| OnceLock::from(*pk_idx)),
            column_index_map_cache: self
                .column_index_map_cache
                .get()
                .map_or_else(OnceLock::new, |map| OnceLock::from(map.clone())),
            pk_indices_cache: self
                .pk_indices_cache
                .get()
                .map_or_else(OnceLock::new, |indices| {
                    OnceLock::from(Arc::clone(indices))
                }),
            column_names_lower_cache: self
                .column_names_lower_cache
                .get()
                .map_or_else(OnceLock::new, |names| {
                    OnceLock::from(CompactArc::clone(names))
                }),
        }
    }
}
impl PartialEq for Schema {
    /// Two schemas are equal when table name, columns, foreign keys and both
    /// timestamps match. `table_name_lower` and the caches are not compared.
    fn eq(&self, other: &Self) -> bool {
        // Tuple equality compares element by element, left to right.
        let lhs = (
            &self.table_name,
            &self.columns,
            &self.foreign_keys,
            &self.created_at,
            &self.updated_at,
        );
        let rhs = (
            &other.table_name,
            &other.columns,
            &other.foreign_keys,
            &other.created_at,
            &other.updated_at,
        );
        lhs == rhs
    }
}

impl Eq for Schema {}
impl Schema {
/// Creates a schema without foreign keys; both timestamps are set to the
/// current time by `with_foreign_keys`.
pub fn new(table_name: impl Into<String>, columns: Vec<SchemaColumn>) -> Self {
    let no_foreign_keys = Vec::new();
    Self::with_foreign_keys(table_name, columns, no_foreign_keys)
}
pub fn with_foreign_keys(
table_name: impl Into<String>,
columns: Vec<SchemaColumn>,
foreign_keys: Vec<ForeignKeyConstraint>,
) -> Self {
let now = Utc::now();
let name = table_name.into();
let name_lower = name.to_lowercase();
let column_names_cache = OnceLock::new();
let _ = column_names_cache.set(CompactArc::new(
columns.iter().map(|c| c.name.clone()).collect(),
));
let pk_column_index_cache = OnceLock::new();
let pk_idx = columns
.iter()
.enumerate()
.find(|(_, col)| col.primary_key && col.data_type == DataType::Integer)
.map(|(i, _)| i);
let _ = pk_column_index_cache.set(pk_idx);
let column_index_map_cache = OnceLock::new();
let _ = column_index_map_cache.set(
columns
.iter()
.enumerate()
.map(|(i, c)| (c.name_lower.clone(), i))
.collect(),
);
let pk_indices_cache = OnceLock::new();
let _ = pk_indices_cache.set(Arc::new(
columns
.iter()
.enumerate()
.filter(|(_, c)| c.primary_key)
.map(|(i, _)| i)
.collect(),
));
let column_names_lower_cache = OnceLock::new();
let _ = column_names_lower_cache.set(CompactArc::new(
columns.iter().map(|c| c.name_lower.clone()).collect(),
));
Self {
table_name: name,
table_name_lower: name_lower,
columns,
foreign_keys,
created_at: now,
updated_at: now,
column_names_cache,
pk_column_index_cache,
column_index_map_cache,
pk_indices_cache,
column_names_lower_cache,
}
}
pub fn with_timestamps(
table_name: impl Into<String>,
columns: Vec<SchemaColumn>,
created_at: DateTime<Utc>,
updated_at: DateTime<Utc>,
) -> Self {
Self::with_timestamps_and_foreign_keys(
table_name,
columns,
Vec::new(),
created_at,
updated_at,
)
}
/// Creates a schema from all of its parts, using caller-supplied timestamps.
///
/// `table_name_lower` is derived via `to_lowercase()`. All derived caches
/// are filled eagerly by `rebuild_caches`, which is the single place that
/// knows how each cache is computed.
pub fn with_timestamps_and_foreign_keys(
    table_name: impl Into<String>,
    columns: Vec<SchemaColumn>,
    foreign_keys: Vec<ForeignKeyConstraint>,
    created_at: DateTime<Utc>,
    updated_at: DateTime<Utc>,
) -> Self {
    let table_name: String = table_name.into();
    let mut schema = Self {
        // Computed before `table_name` is moved into the struct.
        table_name_lower: table_name.to_lowercase(),
        table_name,
        columns,
        foreign_keys,
        created_at,
        updated_at,
        column_names_cache: OnceLock::new(),
        pk_column_index_cache: OnceLock::new(),
        column_index_map_cache: OnceLock::new(),
        pk_indices_cache: OnceLock::new(),
        column_names_lower_cache: OnceLock::new(),
    };
    schema.rebuild_caches();
    schema
}
/// Returns the number of columns in the schema.
pub fn column_count(&self) -> usize {
self.columns.len()
}
/// Returns `true` when the schema has no columns.
pub fn is_empty(&self) -> bool {
self.columns.is_empty()
}
/// Looks up a column by name, ignoring case, and returns its position
/// together with a reference to it.
///
/// Returns `None` when no column has that name. It also returns `None`
/// (instead of panicking) if the cached position no longer fits `columns`,
/// which can happen when the public `columns` field is edited directly
/// without going through the mutating methods.
pub fn find_column(&self, name: &str) -> Option<(usize, &SchemaColumn)> {
    let key = name.to_lowercase();
    let idx = *self.column_index_map().get(&key)?;
    // Checked access: the map is a cache and may be out of date.
    self.columns.get(idx).map(|col| (idx, col))
}
/// Returns the column at position `index`, or `None` if it is out of range.
pub fn get_column(&self, index: usize) -> Option<&SchemaColumn> {
self.columns.get(index)
}
/// Returns the column called `name` (case-insensitive), if present.
pub fn get_column_by_name(&self, name: &str) -> Option<&SchemaColumn> {
    let (_, column) = self.find_column(name)?;
    Some(column)
}
/// Returns the position of the column called `name` (case-insensitive),
/// if present.
pub fn get_column_index(&self, name: &str) -> Option<usize> {
    let (position, _) = self.find_column(name)?;
    Some(position)
}
/// Returns the data type of the column called `name` (case-insensitive),
/// if present.
pub fn get_column_type(&self, name: &str) -> Option<DataType> {
    let column = self.get_column_by_name(name)?;
    Some(column.data_type)
}
/// Returns `true` when a column called `name` exists (case-insensitive).
pub fn has_column(&self, name: &str) -> bool {
    match self.find_column(name) {
        Some(_) => true,
        None => false,
    }
}
/// Returns the column names, in positional order, borrowed from the schema.
///
/// Builds a fresh `Vec` on every call; the `column_names_owned` and
/// `column_names_arc` accessors serve the cached list instead.
pub fn column_names(&self) -> Vec<&str> {
    let mut names = Vec::with_capacity(self.columns.len());
    for column in &self.columns {
        names.push(column.name.as_str());
    }
    names
}
#[inline]
pub fn column_names_owned(&self) -> &[String] {
self.column_names_cache
.get_or_init(|| CompactArc::new(self.columns.iter().map(|c| c.name.clone()).collect()))
}
/// Returns a shared handle to the cached list of column names, computing
/// the list from `columns` first if the cache is empty.
#[inline]
pub fn column_names_arc(&self) -> CompactArc<Vec<String>> {
    let shared = self.column_names_cache.get_or_init(|| {
        let names: Vec<String> = self.columns.iter().map(|col| col.name.clone()).collect();
        CompactArc::new(names)
    });
    CompactArc::clone(shared)
}
/// Returns a shared handle to the cached list of lowercased column names,
/// computing the list from `columns` first if the cache is empty.
#[inline]
pub fn column_names_lower_arc(&self) -> CompactArc<Vec<String>> {
    let shared = self.column_names_lower_cache.get_or_init(|| {
        let lowered: Vec<String> = self
            .columns
            .iter()
            .map(|col| col.name_lower.clone())
            .collect();
        CompactArc::new(lowered)
    });
    CompactArc::clone(shared)
}
/// Returns the cached map from lowercased column name to column position,
/// computing it from `columns` first if the cache is empty.
#[inline]
pub fn column_index_map(&self) -> &StringMap<usize> {
    let build = || -> StringMap<usize> {
        self.columns
            .iter()
            .enumerate()
            .map(|(position, col)| (col.name_lower.clone(), position))
            .collect()
    };
    self.column_index_map_cache.get_or_init(build)
}
/// Returns references to all columns flagged as primary key, in positional
/// order.
pub fn primary_key_columns(&self) -> Vec<&SchemaColumn> {
    let mut pk_columns = Vec::new();
    for column in &self.columns {
        if column.primary_key {
            pk_columns.push(column);
        }
    }
    pk_columns
}
/// Returns `true` when at least one column is flagged as primary key.
pub fn has_primary_key(&self) -> bool {
    for column in &self.columns {
        if column.primary_key {
            return true;
        }
    }
    false
}
/// Returns the cached positions of all primary-key columns, computing them
/// from `columns` first if the cache is empty.
#[inline]
pub fn primary_key_indices(&self) -> &[usize] {
    let cached: &Arc<Vec<usize>> = self.pk_indices_cache.get_or_init(|| {
        let mut positions = Vec::new();
        for (position, col) in self.columns.iter().enumerate() {
            if col.primary_key {
                positions.push(position);
            }
        }
        Arc::new(positions)
    });
    cached
}
/// Returns the cached position of the first column that is both a primary
/// key and of type `DataType::Integer`, or `None` when there is no such
/// column. The value is computed from `columns` if the cache is empty.
#[inline]
pub fn pk_column_index(&self) -> Option<usize> {
    *self.pk_column_index_cache.get_or_init(|| {
        self.columns
            .iter()
            .position(|col| col.primary_key && col.data_type == DataType::Integer)
    })
}
pub fn validate_column_count(&self, expected: usize) -> Result<()> {
if self.columns.len() != expected {
return Err(Error::table_columns_not_match(expected, self.columns.len()));
}
Ok(())
}
/// Sets `updated_at` to the current UTC time.
pub fn mark_updated(&mut self) {
self.updated_at = Utc::now();
}
fn rebuild_caches(&mut self) {
self.column_names_cache = OnceLock::new();
let _ = self.column_names_cache.set(CompactArc::new(
self.columns.iter().map(|c| c.name.clone()).collect(),
));
self.pk_column_index_cache = OnceLock::new();
let pk_idx = self
.columns
.iter()
.enumerate()
.find(|(_, col)| col.primary_key && col.data_type == DataType::Integer)
.map(|(i, _)| i);
let _ = self.pk_column_index_cache.set(pk_idx);
self.column_index_map_cache = OnceLock::new();
let _ = self.column_index_map_cache.set(
self.columns
.iter()
.enumerate()
.map(|(i, c)| (c.name_lower.clone(), i))
.collect(),
);
self.pk_indices_cache = OnceLock::new();
let _ = self.pk_indices_cache.set(Arc::new(
self.columns
.iter()
.enumerate()
.filter(|(_, c)| c.primary_key)
.map(|(i, _)| i)
.collect(),
));
self.column_names_lower_cache = OnceLock::new();
let _ = self.column_names_lower_cache.set(CompactArc::new(
self.columns.iter().map(|c| c.name_lower.clone()).collect(),
));
}
/// Appends `column` to the schema, refreshes `updated_at` and rebuilds the
/// caches.
///
/// The column's `id` is stored as given; it is not renumbered here.
///
/// # Errors
///
/// Returns `Error::DuplicateColumn` when a column with the same name
/// (case-insensitive) already exists.
pub fn add_column(&mut self, column: SchemaColumn) -> Result<()> {
    let already_present = self.has_column(&column.name);
    if already_present {
        return Err(Error::DuplicateColumn);
    }
    self.columns.push(column);
    self.rebuild_caches();
    self.mark_updated();
    Ok(())
}
pub fn remove_column(&mut self, name: &str) -> Result<SchemaColumn> {
let idx = self
.get_column_index(name)
.ok_or_else(|| Error::ColumnNotFound(name.to_string()))?;
let column = self.columns.remove(idx);
for (i, col) in self.columns.iter_mut().enumerate() {
col.id = i;
}
self.mark_updated();
self.rebuild_caches();
Ok(column)
}
pub fn rename_column(&mut self, old_name: &str, new_name: impl Into<String>) -> Result<()> {
let new_name = new_name.into();
if self.has_column(&new_name) {
return Err(Error::DuplicateColumn);
}
let idx = self
.get_column_index(old_name)
.ok_or_else(|| Error::ColumnNotFound(old_name.to_string()))?;
self.columns[idx].name_lower = new_name.to_lowercase();
self.columns[idx].name = new_name;
self.mark_updated();
self.rebuild_caches();
Ok(())
}
pub fn modify_column(
&mut self,
name: &str,
data_type: Option<DataType>,
nullable: Option<bool>,
) -> Result<()> {
let idx = self
.get_column_index(name)
.ok_or_else(|| Error::ColumnNotFound(name.to_string()))?;
if let Some(dt) = data_type {
self.columns[idx].data_type = dt;
}
if let Some(n) = nullable {
self.columns[idx].nullable = n;
}
self.mark_updated();
self.rebuild_caches();
Ok(())
}
}
impl Default for Schema {
fn default() -> Self {
Self::new("", Vec::new())
}
}
impl fmt::Display for Schema {
    /// Renders the schema as `CREATE TABLE <name> (<col>, <col>, ...)`, using
    /// each column's own `Display` output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "CREATE TABLE {} (", self.table_name)?;
        let mut remaining = self.columns.iter();
        // The first column is written bare; each later one is preceded by a
        // separator.
        if let Some(first) = remaining.next() {
            write!(f, "{}", first)?;
            for col in remaining {
                write!(f, ", ")?;
                write!(f, "{}", col)?;
            }
        }
        write!(f, ")")
    }
}
/// Fluent builder that accumulates columns and foreign keys and turns them
/// into a `Schema` via `build`.
pub struct SchemaBuilder {
/// Name of the table being built.
table_name: String,
/// Columns added so far, in insertion order.
columns: Vec<SchemaColumn>,
/// Foreign-key constraints added so far.
foreign_keys: Vec<ForeignKeyConstraint>,
}
impl SchemaBuilder {
    /// Starts a builder for `table_name` with no columns or foreign keys.
    pub fn new(table_name: impl Into<String>) -> Self {
        Self {
            table_name: table_name.into(),
            columns: Vec::new(),
            foreign_keys: Vec::new(),
        }
    }

    /// Appends a column. Its `id` is the number of columns added before it,
    /// i.e. its position.
    pub fn column(
        mut self,
        name: impl Into<String>,
        data_type: DataType,
        nullable: bool,
        primary_key: bool,
    ) -> Self {
        let id = self.columns.len();
        self.columns.push(SchemaColumn::new(
            id,
            name,
            data_type,
            nullable,
            primary_key,
        ));
        self
    }

    /// Appends a non-nullable, non-primary-key column.
    pub fn add(self, name: impl Into<String>, data_type: DataType) -> Self {
        self.column(name, data_type, false, false)
    }

    /// Appends a nullable, non-primary-key column.
    pub fn add_nullable(self, name: impl Into<String>, data_type: DataType) -> Self {
        self.column(name, data_type, true, false)
    }

    /// Appends a non-nullable primary-key column.
    pub fn add_primary_key(self, name: impl Into<String>, data_type: DataType) -> Self {
        self.column(name, data_type, false, true)
    }

    /// Appends a column with explicit constraint settings. Its `id` is the
    /// number of columns added before it.
    #[allow(clippy::too_many_arguments)]
    pub fn add_with_constraints(
        mut self,
        name: impl Into<String>,
        data_type: DataType,
        nullable: bool,
        primary_key: bool,
        auto_increment: bool,
        default_expr: Option<String>,
        check_expr: Option<String>,
    ) -> Self {
        let id = self.columns.len();
        self.columns.push(SchemaColumn::with_constraints(
            id,
            name,
            data_type,
            nullable,
            primary_key,
            auto_increment,
            default_expr,
            check_expr,
        ));
        self
    }

    /// Sets `vector_dimensions` on the most recently added column; does
    /// nothing when no column has been added yet.
    pub fn set_last_vector_dimensions(mut self, dims: u16) -> Self {
        if let Some(col) = self.columns.last_mut() {
            col.vector_dimensions = dims;
        }
        self
    }

    /// Returns the position of the column called `name` (case-insensitive)
    /// among the columns added so far.
    pub fn column_index(&self, name: &str) -> Option<usize> {
        let lower = name.to_lowercase();
        // Every column stored here came from a `SchemaColumn` constructor,
        // which precomputes `name_lower`; comparing against it avoids
        // lowercasing (and allocating) each column name on every call.
        self.columns.iter().position(|c| c.name_lower == lower)
    }

    /// Returns whether the column at position `idx` is nullable; `false`
    /// when `idx` is out of range.
    pub fn is_column_nullable(&self, idx: usize) -> bool {
        self.columns.get(idx).is_some_and(|c| c.nullable)
    }

    /// Appends a foreign-key constraint.
    pub fn add_foreign_key(mut self, fk: ForeignKeyConstraint) -> Self {
        self.foreign_keys.push(fk);
        self
    }

    /// Consumes the builder and produces the `Schema`; timestamps are set by
    /// `Schema::with_foreign_keys`.
    pub fn build(self) -> Schema {
        Schema::with_foreign_keys(self.table_name, self.columns, self.foreign_keys)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the four-column `users` fixture shared by most tests:
    /// `id` (integer primary key), `name`, nullable `email` and `active`.
    fn create_test_schema() -> Schema {
        let builder = SchemaBuilder::new("users")
            .add_primary_key("id", DataType::Integer)
            .add("name", DataType::Text)
            .add_nullable("email", DataType::Text)
            .add("active", DataType::Boolean);
        builder.build()
    }

    #[test]
    fn test_schema_column_creation() {
        let column = SchemaColumn::new(0, "id", DataType::Integer, false, true);

        assert_eq!(column.id, 0);
        assert_eq!(column.name, "id");
        assert_eq!(column.data_type, DataType::Integer);
        assert!(!column.nullable);
        assert!(column.primary_key);
    }

    #[test]
    fn test_schema_column_helpers() {
        // `simple`: neither nullable nor primary key.
        let plain = SchemaColumn::simple(0, "name", DataType::Text);
        assert!(!plain.nullable);
        assert!(!plain.primary_key);

        // `nullable`: nullable only.
        let optional = SchemaColumn::nullable(1, "email", DataType::Text);
        assert!(optional.nullable);
        assert!(!optional.primary_key);

        // `primary_key`: primary key only.
        let key = SchemaColumn::primary_key(2, "id", DataType::Integer);
        assert!(!key.nullable);
        assert!(key.primary_key);
    }

    #[test]
    fn test_schema_creation() {
        let users = create_test_schema();

        assert_eq!(users.table_name, "users");
        assert_eq!(users.column_count(), 4);
        assert!(!users.is_empty());
    }

    #[test]
    fn test_schema_find_column() {
        let users = create_test_schema();

        let (position, column) = users.find_column("name").unwrap();
        assert_eq!(position, 1);
        assert_eq!(column.name, "name");

        // Lookup ignores case.
        let (position, _) = users.find_column("NAME").unwrap();
        assert_eq!(position, 1);

        assert!(users.find_column("nonexistent").is_none());
    }

    #[test]
    fn test_schema_get_column() {
        let users = create_test_schema();

        let first = users.get_column(0).unwrap();
        assert_eq!(first.name, "id");

        let email = users.get_column_by_name("email").unwrap();
        assert_eq!(email.data_type, DataType::Text);
        assert!(email.nullable);

        assert!(users.get_column(100).is_none());
    }

    #[test]
    fn test_schema_column_names() {
        let users = create_test_schema();

        assert_eq!(users.column_names(), ["id", "name", "email", "active"]);
    }

    #[test]
    fn test_schema_primary_key() {
        let users = create_test_schema();
        assert!(users.has_primary_key());

        let key_columns = users.primary_key_columns();
        assert_eq!(key_columns.len(), 1);
        assert_eq!(key_columns[0].name, "id");

        assert_eq!(users.primary_key_indices(), [0]);
    }

    #[test]
    fn test_schema_validate_column_count() {
        let users = create_test_schema();
        assert!(users.validate_column_count(4).is_ok());

        let mismatch = users.validate_column_count(3).unwrap_err();
        assert!(matches!(
            mismatch,
            Error::TableColumnsNotMatch {
                expected: 3,
                got: 4
            }
        ));
    }

    #[test]
    fn test_schema_add_column() {
        let mut users = create_test_schema();
        let count_before = users.column_count();

        let age = SchemaColumn::simple(count_before, "age", DataType::Integer);
        users.add_column(age).unwrap();
        assert_eq!(users.column_count(), count_before + 1);
        assert!(users.has_column("age"));

        // Adding the same name again is rejected.
        let duplicate = SchemaColumn::simple(0, "age", DataType::Integer);
        let err = users.add_column(duplicate).unwrap_err();
        assert!(matches!(err, Error::DuplicateColumn));
    }

    #[test]
    fn test_schema_remove_column() {
        let mut users = create_test_schema();

        let removed = users.remove_column("email").unwrap();
        assert_eq!(removed.name, "email");
        assert_eq!(users.column_count(), 3);
        assert!(!users.has_column("email"));

        // The column after the removed one was renumbered to its new position.
        assert_eq!(users.columns[2].id, 2);

        assert!(users.remove_column("nonexistent").is_err());
    }

    #[test]
    fn test_schema_rename_column() {
        let mut users = create_test_schema();

        users.rename_column("name", "full_name").unwrap();
        assert!(users.has_column("full_name"));
        assert!(!users.has_column("name"));

        // Renaming onto an existing column is rejected.
        let err = users.rename_column("full_name", "id").unwrap_err();
        assert!(matches!(err, Error::DuplicateColumn));

        assert!(users.rename_column("nonexistent", "new_name").is_err());
    }

    #[test]
    fn test_schema_modify_column() {
        let mut users = create_test_schema();

        users
            .modify_column("name", Some(DataType::Json), Some(true))
            .unwrap();
        let modified = users.get_column_by_name("name").unwrap();
        assert_eq!(modified.data_type, DataType::Json);
        assert!(modified.nullable);

        let missing = users.modify_column("nonexistent", None, Some(true));
        assert!(missing.is_err());
    }

    #[test]
    fn test_schema_column_display() {
        let key = SchemaColumn::new(0, "id", DataType::Integer, false, true);
        assert_eq!(key.to_string(), "id INTEGER PRIMARY KEY");

        let required = SchemaColumn::new(1, "name", DataType::Text, false, false);
        assert_eq!(required.to_string(), "name TEXT NOT NULL");

        let optional = SchemaColumn::new(2, "email", DataType::Text, true, false);
        assert_eq!(optional.to_string(), "email TEXT");
    }

    #[test]
    fn test_schema_display() {
        let users = SchemaBuilder::new("users")
            .add_primary_key("id", DataType::Integer)
            .add("name", DataType::Text)
            .build();

        assert_eq!(
            users.to_string(),
            "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT NOT NULL)"
        );
    }

    #[test]
    fn test_schema_builder() {
        let products = SchemaBuilder::new("products")
            .add_primary_key("id", DataType::Integer)
            .add("name", DataType::Text)
            .add_nullable("description", DataType::Text)
            .add("price", DataType::Float)
            .build();

        assert_eq!(products.table_name, "products");
        assert_eq!(products.column_count(), 4);
        assert!(products.get_column_by_name("id").unwrap().primary_key);
        assert!(products.get_column_by_name("description").unwrap().nullable);
    }

    #[test]
    fn test_schema_timestamps() {
        let earlier = Schema::new("test", vec![]);
        std::thread::sleep(std::time::Duration::from_millis(10));
        let later = Schema::new("test", vec![]);

        assert!(later.created_at >= earlier.created_at);
    }

    #[test]
    fn test_schema_get_column_type() {
        let users = create_test_schema();

        assert_eq!(users.get_column_type("id"), Some(DataType::Integer));
        assert_eq!(users.get_column_type("name"), Some(DataType::Text));
        assert_eq!(users.get_column_type("active"), Some(DataType::Boolean));
        assert_eq!(users.get_column_type("nonexistent"), None);
    }
}