use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Declared data type of a table column.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ColumnType {
    /// Timestamp column.
    Timestamp,
    /// Text column.
    Text,
    /// Fixed-dimension tensor column; the `usize` is the required dimension.
    Tensor(usize),
    /// Integer column.
    Integer,
    /// Floating-point column.
    Float,
    /// Boolean column.
    Boolean,
    /// Spatial column.
    Spatial,
}
/// Definition of a single table column.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ColumnDef {
    /// Column name; the key used for name-based lookups in `TableSchema`.
    pub name: String,
    /// Declared type, checked by `TableSchema::validate_row`.
    pub col_type: ColumnType,
    /// Ordinal recorded in `TableSchema`'s name -> position lookup.
    // NOTE(review): `validate_row` pairs values with columns by their index in
    // `TableSchema::columns`, not by this field; presumably the two always agree — confirm.
    pub position: usize,
    /// Whether the column accepts nulls. `ColumnDef::new` sets this to `true`.
    /// When `false`, `TableSchema::validate_row` rejects empty `Text` values.
    pub nullable: bool,
    /// Whether the column auto-increments; `false` when missing from serialized data.
    #[serde(default)]
    pub auto_increment: bool,
    /// Optional auto-increment start value; `None` when missing from serialized data.
    #[serde(default)]
    pub auto_increment_start: Option<i64>,
}
/// Alias for [`ColumnDef`].
pub type Column = ColumnDef;
impl ColumnDef {
pub fn new(name: String, col_type: ColumnType, position: usize) -> Self {
Self {
name,
col_type,
position,
nullable: true,
auto_increment: false,
auto_increment_start: None,
}
}
pub fn not_null(mut self) -> Self {
self.nullable = false;
self
}
pub fn auto_increment(mut self) -> Self {
self.auto_increment = true;
self
}
pub fn auto_increment_with_start(mut self, start: i64) -> Self {
self.auto_increment = true;
self.auto_increment_start = Some(start);
self
}
}
/// Kind of index that can be declared on a column.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum IndexType {
    /// B-tree index.
    BTree,
    /// Full-text index.
    FullText,
    /// Vector index over values of the given `dimension`.
    Vector { dimension: usize },
    /// Spatial index.
    Spatial,
    /// Timestamp index.
    Timestamp,
}
/// Definition of an index on a single column of a table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexDef {
    /// Index name.
    pub name: String,
    /// Name of the table the index belongs to.
    pub table_name: String,
    /// Name of the indexed column.
    pub column_name: String,
    /// Kind of index.
    pub index_type: IndexType,
}
impl IndexDef {
pub fn new(
name: String,
table_name: String,
column_name: String,
index_type: IndexType,
) -> Self {
Self {
name,
table_name,
column_name,
index_type,
}
}
}
/// Schema of a single table: columns, indexes, and primary-key settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TableSchema {
    /// Table name.
    pub name: String,
    /// Column definitions, in the order `validate_row` expects row values.
    pub columns: Vec<ColumnDef>,
    /// Indexes declared on this table.
    pub indexes: Vec<IndexDef>,
    /// Name of the primary-key column, if any.
    pub primary_key_column: Option<String>,
    /// Whether the primary key auto-increments; `false` when missing from serialized data.
    #[serde(default)]
    pub primary_key_auto_increment: bool,
    /// Table-level auto-increment start value; `None` when missing from serialized data.
    #[serde(default)]
    pub auto_increment_start: Option<i64>,
    /// Cache mapping column name -> `ColumnDef::position`.
    ///
    /// Not serialized (`#[serde(skip)]`), so a deserialized schema starts with an
    /// empty map. Call `rebuild_column_map` after deserializing or after mutating
    /// `columns`.
    #[serde(skip)]
    column_map: HashMap<String, usize>,
}
impl TableSchema {
pub fn new(name: String, columns: Vec<ColumnDef>) -> Self {
let mut column_map = HashMap::new();
for col in &columns {
column_map.insert(col.name.clone(), col.position);
}
Self {
name,
columns,
indexes: Vec::new(),
primary_key_column: None,
primary_key_auto_increment: false,
auto_increment_start: None,
column_map,
}
}
pub fn with_primary_key(mut self, pk_column: String) -> Self {
self.primary_key_column = Some(pk_column);
self
}
pub fn with_auto_increment(mut self) -> Self {
self.primary_key_auto_increment = true;
if let Some(pk_col_name) = &self.primary_key_column {
if let Some(col) = self.columns.iter_mut().find(|c| &c.name == pk_col_name) {
col.auto_increment = true;
}
}
self
}
pub fn with_auto_increment_start(mut self, start: i64) -> Self {
self.primary_key_auto_increment = true;
self.auto_increment_start = Some(start);
if let Some(pk_col_name) = &self.primary_key_column {
if let Some(col) = self.columns.iter_mut().find(|c| &c.name == pk_col_name) {
col.auto_increment = true;
col.auto_increment_start = Some(start);
}
}
self
}
pub fn get_auto_increment_start(&self) -> i64 {
self.auto_increment_start.unwrap_or(1)
}
pub fn primary_key(&self) -> Option<&str> {
self.primary_key_column.as_deref()
}
pub fn is_primary_key_auto_increment(&self) -> bool {
self.primary_key_auto_increment
}
pub fn add_index(&mut self, index: IndexDef) {
self.indexes.push(index);
}
pub fn get_column(&self, name: &str) -> Option<&ColumnDef> {
self.columns.iter().find(|c| c.name == name)
}
pub fn get_column_position(&self, name: &str) -> Option<usize> {
self.column_map.get(name).copied()
}
pub fn column_count(&self) -> usize {
self.columns.len()
}
pub fn rebuild_column_map(&mut self) {
self.column_map.clear();
for col in &self.columns {
self.column_map.insert(col.name.clone(), col.position);
}
}
pub fn validate_row(&self, row: &[crate::types::Value]) -> Result<(), String> {
if row.len() != self.columns.len() {
return Err(format!(
"Column count mismatch: expected {}, got {}",
self.columns.len(),
row.len()
));
}
for (i, col) in self.columns.iter().enumerate() {
let value = &row[i];
if !col.nullable && matches!(value, crate::types::Value::Text(t) if t.is_empty()) {
return Err(format!("Column '{}' cannot be null", col.name));
}
let type_match = match (&col.col_type, value) {
(ColumnType::Integer, crate::types::Value::Integer(_)) => true,
(ColumnType::Float, crate::types::Value::Float(_)) => true,
(ColumnType::Float, crate::types::Value::Integer(_)) => true, (ColumnType::Boolean, crate::types::Value::Bool(_)) => true,
(ColumnType::Text, crate::types::Value::Text(_)) => true,
(ColumnType::Spatial, crate::types::Value::Spatial(_)) => true,
(ColumnType::Timestamp, crate::types::Value::Timestamp(_)) => true,
(ColumnType::Tensor(dim), crate::types::Value::Tensor(t)) => t.dimension() == *dim,
(ColumnType::Tensor(dim), crate::types::Value::Vector(v)) => v.len() == *dim,
(ColumnType::Integer, crate::types::Value::Timestamp(_)) => true,
(ColumnType::Float, crate::types::Value::Tensor(t)) if t.dimension() == 1 => true,
_ => false,
};
if !type_match {
return Err(format!(
"Type mismatch for column '{}': expected {:?}",
col.name, col.col_type
));
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{Timestamp, Value};

    /// `not_null` on a fresh column keeps its name/position and clears nullability.
    #[test]
    fn test_column_def() {
        let id_col = ColumnDef::new("id".into(), ColumnType::Integer, 0).not_null();
        assert_eq!(id_col.name, "id");
        assert_eq!(id_col.position, 0);
        assert!(!id_col.nullable);
    }

    /// Column count, name lookup, and index registration on a small schema.
    #[test]
    fn test_table_schema() {
        let user_columns = vec![
            ColumnDef::new("id".into(), ColumnType::Integer, 0).not_null(),
            ColumnDef::new("name".into(), ColumnType::Text, 1),
            ColumnDef::new("created_at".into(), ColumnType::Timestamp, 2),
        ];
        let mut schema = TableSchema::new("users".into(), user_columns);
        assert_eq!(schema.column_count(), 3);
        assert_eq!(schema.get_column_position("name"), Some(1));

        let name_index = IndexDef::new(
            "users_name_idx".into(),
            "users".into(),
            "name".into(),
            IndexType::FullText,
        );
        schema.add_index(name_index);
        assert_eq!(schema.indexes.len(), 1);
    }

    /// A correctly typed full row passes; a row with too few values fails.
    #[test]
    fn test_validate_row() {
        let schema = TableSchema::new(
            "test".into(),
            vec![
                ColumnDef::new("id".into(), ColumnType::Timestamp, 0),
                ColumnDef::new("name".into(), ColumnType::Text, 1),
            ],
        );
        let stamp = || Value::Timestamp(Timestamp::from_micros(123));

        let full_row = [stamp(), Value::Text("test".to_string())];
        assert!(schema.validate_row(&full_row).is_ok());

        let short_row = [stamp()];
        assert!(schema.validate_row(&short_row).is_err());
    }
}