Skip to main content

regulus_db/types/
schema.rs

1use super::value::DbValue;
2use serde::{Serialize, Deserialize};
3
4/// 数据类型定义
5#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
6pub enum DataType {
7    Null,
8    Integer,
9    Real,
10    Text { max_length: Option<usize> },
11    Blob { max_length: Option<usize> },
12    Boolean,
13    Date,
14    Datetime,
15}
16
17impl DataType {
18    pub fn null() -> Self {
19        DataType::Null
20    }
21
22    pub fn integer() -> Self {
23        DataType::Integer
24    }
25
26    pub fn real() -> Self {
27        DataType::Real
28    }
29
30    pub fn text() -> Self {
31        DataType::Text { max_length: None }
32    }
33
34    pub fn text_with_max(max_length: usize) -> Self {
35        DataType::Text { max_length: Some(max_length) }
36    }
37
38    pub fn blob() -> Self {
39        DataType::Blob { max_length: None }
40    }
41
42    pub fn blob_with_max(max_length: usize) -> Self {
43        DataType::Blob { max_length: Some(max_length) }
44    }
45
46    pub fn boolean() -> Self {
47        DataType::Boolean
48    }
49
50    pub fn date() -> Self {
51        DataType::Date
52    }
53
54    pub fn datetime() -> Self {
55        DataType::Datetime
56    }
57
58    /// 验证值是否与该类型兼容
59    pub fn validate(&self, value: &DbValue) -> bool {
60        match (self, value) {
61            (DataType::Null, DbValue::Null) => true,
62            (DataType::Integer, DbValue::Integer(_)) => true,
63            (DataType::Real, DbValue::Real(_)) => true,
64            (DataType::Real, DbValue::Integer(_)) => true, // 整数可以隐式转换为实数
65            (DataType::Text { max_length }, DbValue::Text(s)) => {
66                max_length.map_or(true, |max| s.len() <= max)
67            }
68            (DataType::Blob { max_length }, DbValue::Blob(b)) => {
69                max_length.map_or(true, |max| b.len() <= max)
70            }
71            (DataType::Boolean, DbValue::Boolean(_)) => true,
72            (DataType::Date, DbValue::Date(_)) => true,
73            (DataType::Datetime, DbValue::Datetime(_)) => true,
74            _ => false,
75        }
76    }
77}
78
79impl std::fmt::Display for DataType {
80    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81        match self {
82            DataType::Null => write!(f, "NULL"),
83            DataType::Integer => write!(f, "INTEGER"),
84            DataType::Real => write!(f, "REAL"),
85            DataType::Text { max_length } => match max_length {
86                Some(max) => write!(f, "TEXT({})", max),
87                None => write!(f, "TEXT"),
88            },
89            DataType::Blob { max_length } => match max_length {
90                Some(max) => write!(f, "BLOB({})", max),
91                None => write!(f, "BLOB"),
92            },
93            DataType::Boolean => write!(f, "BOOLEAN"),
94            DataType::Date => write!(f, "DATE"),
95            DataType::Datetime => write!(f, "DATETIME"),
96        }
97    }
98}
99
/// Definition of a single table column.
///
/// Usually built with [`Column::new`] followed by the chained builder
/// methods (`not_null`, `primary_key`, `unique`, `default`, `auto_increment`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Column {
    /// Column name; schema lookups compare it with exact string equality.
    pub name: String,
    /// Declared type, used to validate values written to this column.
    pub data_type: DataType,
    /// Whether the column is nullable. `Column::new` starts with `true`.
    pub nullable: bool,
    /// Whether the column is (part of) the primary key.
    pub primary_key: bool,
    /// Whether the column is flagged unique.
    // NOTE(review): enforcement of this flag is not visible in this file.
    pub unique: bool,
    /// Value inserted by `TableSchema::fill_defaults` when a row lacks this column.
    pub default_value: Option<DbValue>,
    /// Whether the column is flagged as auto-increment.
    // NOTE(review): the increment logic itself lives outside this file.
    pub auto_increment: bool,
}
111
112impl Column {
113    pub fn new(name: impl Into<String>, data_type: DataType) -> Self {
114        Column {
115            name: name.into(),
116            data_type,
117            nullable: true,
118            primary_key: false,
119            unique: false,
120            default_value: None,
121            auto_increment: false,
122        }
123    }
124
125    pub fn not_null(mut self) -> Self {
126        self.nullable = false;
127        self
128    }
129
130    pub fn primary_key(mut self) -> Self {
131        self.primary_key = true;
132        self.nullable = false;
133        self
134    }
135
136    pub fn unique(mut self) -> Self {
137        self.unique = true;
138        self
139    }
140
141    pub fn default(mut self, value: DbValue) -> Self {
142        self.default_value = Some(value);
143        self
144    }
145
146    pub fn auto_increment(mut self) -> Self {
147        self.auto_increment = true;
148        self
149    }
150}
151
/// Definition of a table's structure: its name and its columns.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TableSchema {
    /// Table name; included in the `SchemaError` values produced by `validate`.
    pub name: String,
    /// Column definitions, in the order they were supplied to `TableSchema::new`.
    pub columns: Vec<Column>,
}
158
159impl TableSchema {
160    pub fn new(name: impl Into<String>, columns: Vec<Column>) -> Self {
161        TableSchema {
162            name: name.into(),
163            columns,
164        }
165    }
166
167    /// 根据列名获取列索引
168    pub fn column_index(&self, name: &str) -> Option<usize> {
169        self.columns.iter().position(|c| c.name == name)
170    }
171
172    /// 根据列名获取列定义
173    pub fn column(&self, name: &str) -> Option<&Column> {
174        self.columns.iter().find(|c| c.name == name)
175    }
176
177    /// 获取主键列
178    pub fn primary_key(&self) -> Option<&Column> {
179        self.columns.iter().find(|c| c.primary_key)
180    }
181
182    /// 获取自增列
183    pub fn auto_increment_column(&self) -> Option<&Column> {
184        self.columns.iter().find(|c| c.auto_increment)
185    }
186
187    /// 填充行的默认值
188    /// 对于行中缺少的列,如果该列有默认值,则自动填充
189    pub fn fill_defaults(&self, row: &mut crate::storage::Row) {
190        for column in &self.columns {
191            // 如果行中已经有该列的值,跳过
192            if row.contains_key(&column.name) {
193                continue;
194            }
195            // 如果该列有默认值,填充
196            if let Some(ref default_value) = column.default_value {
197                row.insert(column.name.clone(), default_value.clone());
198            }
199        }
200    }
201
202    /// 验证行值是否与 schema 匹配
203    pub fn validate(&self, values: &[(String, DbValue)]) -> Result<(), SchemaError> {
204        for (name, value) in values {
205            let column = self.column(name).ok_or_else(|| SchemaError::UnknownColumn {
206                table: self.name.clone(),
207                column: name.clone(),
208            })?;
209
210            if !column.data_type.validate(value) {
211                return Err(SchemaError::TypeMismatch {
212                    table: self.name.clone(),
213                    column: name.clone(),
214                    expected: column.data_type.clone(),
215                    actual: value.type_name().to_string(),
216                });
217            }
218
219            if !column.nullable && value.is_null() {
220                return Err(SchemaError::NotNullViolation {
221                    table: self.name.clone(),
222                    column: name.clone(),
223                });
224            }
225        }
226
227        Ok(())
228    }
229}
230
/// Errors describing a violation of a table schema.
///
/// Every variant carries the table and column names involved.
#[derive(Debug, thiserror::Error)]
pub enum SchemaError {
    /// The named column does not exist in the table's schema.
    #[error("Unknown column '{column}' in table '{table}'")]
    UnknownColumn { table: String, column: String },

    /// A value is incompatible with the column's declared type.
    /// `expected` is the declared type; `actual` is the value's type name.
    #[error("Type mismatch for column '{column}' in table '{table}': expected {expected}, got {actual}")]
    TypeMismatch {
        table: String,
        column: String,
        expected: DataType,
        actual: String,
    },

    /// A null value was supplied for a non-nullable column.
    #[error("NOT NULL constraint failed for column '{column}' in table '{table}'")]
    NotNullViolation { table: String, column: String },

    /// A UNIQUE constraint was violated.
    // NOTE(review): not constructed in this file — presumably raised by the storage layer; confirm.
    #[error("UNIQUE constraint failed for column '{column}' in table '{table}'")]
    UniqueViolation { table: String, column: String },

    /// A PRIMARY KEY constraint was violated.
    // NOTE(review): not constructed in this file — presumably raised by the storage layer; confirm.
    #[error("PRIMARY KEY constraint failed for column '{column}' in table '{table}'")]
    PrimaryKeyViolation { table: String, column: String },
}
254
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::Row;

    /// The builder methods set the expected flags on the column.
    #[test]
    fn test_column_builder() {
        let id_column = Column::new("id", DataType::integer())
            .not_null()
            .primary_key();

        assert_eq!(id_column.name, "id");
        assert!(!id_column.nullable);
        assert!(id_column.primary_key);
    }

    /// A length-limited text type accepts short strings and rejects long ones.
    #[test]
    fn test_text_with_max_length() {
        let limited = DataType::text_with_max(50);

        let short = DbValue::text("hello");
        let too_long = DbValue::text("a".repeat(51));

        assert!(limited.validate(&short));
        assert!(!limited.validate(&too_long));
    }

    /// A row whose values all match the declared types validates cleanly.
    #[test]
    fn test_schema_validation() {
        let columns = vec![
            Column::new("id", DataType::integer()).primary_key(),
            Column::new("name", DataType::text()).not_null(),
            Column::new("age", DataType::integer()),
        ];
        let schema = TableSchema::new("users", columns);

        let valid_values = vec![
            (String::from("id"), DbValue::integer(1)),
            (String::from("name"), DbValue::text("Alice")),
            (String::from("age"), DbValue::integer(25)),
        ];

        let outcome = schema.validate(&valid_values);
        assert!(outcome.is_ok());
    }

    /// Missing columns receive their defaults; supplied values stay untouched.
    #[test]
    fn test_fill_defaults() {
        let columns = vec![
            Column::new("id", DataType::integer()).primary_key(),
            Column::new("name", DataType::text()).not_null(),
            Column::new("status", DataType::text()).default(DbValue::text("active")),
            Column::new("age", DataType::integer()).default(DbValue::integer(0)),
            Column::new("active", DataType::boolean()).default(DbValue::boolean(true)),
        ];
        let schema = TableSchema::new("users", columns);

        // Supply only some of the fields.
        let mut row = Row::new();
        row.insert(String::from("id"), DbValue::integer(1));
        row.insert(String::from("name"), DbValue::text("Alice"));

        schema.fill_defaults(&mut row);

        // Defaults were filled in for the missing columns.
        let status = row.get("status").unwrap();
        assert_eq!(status.as_text(), Some("active"));
        let age = row.get("age").unwrap();
        assert_eq!(age.as_integer(), Some(0));
        let active = row.get("active").unwrap();
        assert_eq!(active.as_boolean(), Some(true));

        // Values that were already present are unchanged.
        let id = row.get("id").unwrap();
        assert_eq!(id.as_integer(), Some(1));
        let name = row.get("name").unwrap();
        assert_eq!(name.as_text(), Some("Alice"));
    }

    /// An explicitly supplied value is not replaced by the column default.
    #[test]
    fn test_fill_defaults_with_explicit_value() {
        let columns = vec![
            Column::new("id", DataType::integer()).primary_key(),
            Column::new("name", DataType::text()).not_null(),
            Column::new("status", DataType::text()).default(DbValue::text("active")),
        ];
        let schema = TableSchema::new("users", columns);

        // Supply every field explicitly.
        let mut row = Row::new();
        row.insert(String::from("id"), DbValue::integer(1));
        row.insert(String::from("name"), DbValue::text("Bob"));
        row.insert(String::from("status"), DbValue::text("inactive"));

        schema.fill_defaults(&mut row);

        // The explicit value survives.
        let status = row.get("status").unwrap();
        assert_eq!(status.as_text(), Some("inactive"));
    }
}