Skip to main content

regulus_db/transaction/
mod.rs

//! 事务模块 - 泛型实现,支持 MemoryEngine 和 PersistedEngine
//!
//! 设计策略:
//! 1. 使用 trait object (&'a mut dyn StorageEngine) 实现代码复用
//! 2. 事务操作直接作用于引擎(简化实现)
//! 3. PersistedEngine 内部处理事务 WAL 标记
//! 4. 回滚功能:通过 WriteLog 记录已执行的写操作,rollback 时按逆序撤销

9use crate::storage::{StorageEngine, Row, RowId};
10use crate::types::{DbValue, DbResult, DbError};
11
12/// 写操作日志(用于回滚)
13#[derive(Debug)]
14pub enum WriteLog {
15    /// 插入前的状态(回滚时需要删除)
16    Inserted { table: String, row_id: RowId },
17    /// 更新前的状态(回滚时需要恢复)
18    Updated { table: String, row_id: RowId, old_row: Row },
19    /// 删除前的状态(回滚时需要重新插入)
20    Deleted { table: String, row_id: RowId, row: Row },
21}
22
23/// 事务包装器 - 使用 trait object 实现泛型
24pub struct Transaction<'a> {
25    engine: &'a mut dyn StorageEngine,
26    write_log: Vec<WriteLog>,
27    is_committed: bool,
28    /// 是否启用回滚日志(MemoryEngine 可以禁用,PersistedEngine 启用)
29    enable_rollback_log: bool,
30}
31
32impl<'a> Transaction<'a> {
33    /// 创建新事务(默认启用回滚日志)
34    pub fn new(engine: &'a mut dyn StorageEngine) -> Self {
35        Transaction {
36            engine,
37            write_log: Vec::new(),
38            is_committed: false,
39            enable_rollback_log: true,
40        }
41    }
42
43    /// 创建新事务(带回滚日志开关)
44    pub fn with_rollback_log(engine: &'a mut dyn StorageEngine, enable_log: bool) -> Self {
45        Transaction {
46            engine,
47            write_log: Vec::new(),
48            is_committed: false,
49            enable_rollback_log: enable_log,
50        }
51    }
52
53    /// 提交事务
54    pub fn commit(&mut self) -> DbResult<()> {
55        if self.is_committed {
56            return Err(DbError::TransactionError(
57                "Transaction already committed".to_string()
58            ));
59        }
60        // 清空日志(提交后不再需要回滚)
61        self.write_log.clear();
62        self.is_committed = true;
63        Ok(())
64    }
65
66    /// 回滚事务
67    ///
68    /// 反向应用所有写操作:
69    /// - Inserted: 删除该行
70    /// - Updated: 恢复旧值
71    /// - Deleted: 重新插入该行
72    pub fn rollback(&mut self) -> DbResult<()> {
73        if self.is_committed {
74            return Err(DbError::TransactionError(
75                "Cannot rollback a committed transaction".to_string()
76            ));
77        }
78
79        if !self.enable_rollback_log {
80            self.write_log.clear();
81            return Ok(());
82        }
83
84        // 反向遍历日志
85        for log in self.write_log.drain(..).rev() {
86            match log {
87                WriteLog::Inserted { table, row_id } => {
88                    // 删除插入的行
89                    let _ = self.engine.delete(&table, row_id);
90                }
91                WriteLog::Updated { table, row_id, old_row } => {
92                    // 恢复旧值
93                    let _ = self.engine.update(&table, row_id, old_row);
94                }
95                WriteLog::Deleted { table, row_id: _, row } => {
96                    // 重新插入删除的行
97                    let _ = self.engine.insert(&table, row);
98                }
99            }
100        }
101
102        Ok(())
103    }
104
105    /// 插入数据
106    pub fn insert(&mut self, table: &str, values: Vec<(&str, DbValue)>) -> DbResult<RowId> {
107        // 获取 schema
108        let schema = self.engine.get_schema(table)?.clone();
109
110        // 构建行
111        let mut row = Row::new();
112        for (name, value) in values {
113            row.insert(name.to_string(), value);
114        }
115
116        // 填充默认值
117        schema.fill_defaults(&mut row);
118
119        // 执行插入
120        let row_id = self.engine.insert(table, row.clone())?;
121
122        // 记录日志(用于回滚)
123        if self.enable_rollback_log {
124            self.write_log.push(WriteLog::Inserted {
125                table: table.to_string(),
126                row_id,
127            });
128        }
129
130        Ok(row_id)
131    }
132
133    /// 获取引擎的不可变引用(用于查询等操作)
134    pub fn engine(&self) -> &dyn StorageEngine {
135        self.engine
136    }
137
138    /// 获取引擎的可变引用
139    pub fn engine_mut(&mut self) -> &mut dyn StorageEngine {
140        self.engine
141    }
142
143    /// 查询表中所有行(简化查询接口)
144    pub fn query_all(&self, table: &str) -> DbResult<Vec<Row>> {
145        let rows = self.engine.scan(table)?;
146        Ok(rows.into_iter().map(|(_, row)| row.clone()).collect())
147    }
148
149    /// 直接访问底层引擎执行操作(不记录回滚日志)
150    /// 适用于查询等只读操作
151    pub fn with_engine<F, R>(&self, f: F) -> R
152    where
153        F: FnOnce(&dyn StorageEngine) -> R,
154    {
155        f(self.engine)
156    }
157
158    /// 直接访问底层引擎执行操作(可变)
159    pub fn with_engine_mut<F, R>(&mut self, f: F) -> R
160    where
161        F: FnOnce(&mut dyn StorageEngine) -> R,
162    {
163        f(self.engine)
164    }
165
166    /// 更新操作(带条件)
167    pub fn update<F>(&mut self, table: &str, condition: F, updates: Vec<(&str, DbValue)>) -> DbResult<usize>
168    where
169        F: Fn(&Row) -> bool,
170    {
171        // 获取 schema 验证列名
172        let schema = self.engine.get_schema(table)?;
173
174        // 构建更新值
175        let mut new_values = Row::new();
176        for (name, value) in updates {
177            new_values.insert(name.to_string(), value);
178        }
179        schema.validate(&new_values.iter().map(|(k, v)| (k.clone(), v.clone())).collect::<Vec<_>>())?;
180
181        // 扫描全表,找到匹配的行并克隆出来(避免借用冲突)
182        let rows = self.engine.scan(table)?;
183        let matching_rows: Vec<(RowId, Row)> = rows
184            .into_iter()
185            .filter(|(_, row)| condition(row))
186            .map(|(row_id, row)| (row_id, row.clone()))
187            .collect();
188
189        let mut updated_count = 0;
190
191        for (row_id, old_row) in matching_rows {
192            // 记录旧值(用于回滚)
193            if self.enable_rollback_log {
194                self.write_log.push(WriteLog::Updated {
195                    table: table.to_string(),
196                    row_id,
197                    old_row: old_row.clone(),
198                });
199            }
200
201            // 合并新旧值
202            let mut updated_row = old_row.clone();
203            for (key, value) in new_values.iter() {
204                updated_row.insert(key.clone(), value.clone());
205            }
206
207            // 执行更新
208            self.engine.update(table, row_id, updated_row)?;
209            updated_count += 1;
210        }
211
212        Ok(updated_count)
213    }
214
215    /// 删除操作(带条件)
216    pub fn delete<F>(&mut self, table: &str, condition: F) -> DbResult<usize>
217    where
218        F: Fn(&Row) -> bool,
219    {
220        // 扫描全表,找到匹配的行并克隆(避免借用冲突)
221        let rows = self.engine.scan(table)?;
222        let rows_to_delete: Vec<(RowId, Row)> = rows
223            .into_iter()
224            .filter(|(_, row)| condition(row))
225            .map(|(row_id, row)| (row_id, row.clone()))
226            .collect();
227
228        let mut deleted_count = 0;
229
230        for (row_id, row) in rows_to_delete {
231            // 记录旧值(用于回滚)- 保存完整的行数据
232            if self.enable_rollback_log {
233                self.write_log.push(WriteLog::Deleted {
234                    table: table.to_string(),
235                    row_id,
236                    row,
237                });
238            }
239
240            self.engine.delete(table, row_id)?;
241            deleted_count += 1;
242        }
243
244        Ok(deleted_count)
245    }
246}
247
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::MemoryEngine;
    use crate::types::{DataType, Column, TableSchema};

    /// `users` table with an integer primary key `id` and a text `name`.
    fn create_test_schema() -> TableSchema {
        let columns = vec![
            Column::new("id", DataType::integer()).primary_key(),
            Column::new("name", DataType::text()),
        ];
        TableSchema::new("users", columns)
    }

    /// `users` table whose optional columns carry defaults; `with_active` additionally
    /// adds a boolean `active` column defaulting to `true`.
    fn defaults_schema(with_active: bool) -> TableSchema {
        let mut columns = vec![
            Column::new("id", DataType::integer()).primary_key(),
            Column::new("name", DataType::text()).not_null(),
            Column::new("status", DataType::text()).default(DbValue::text("active")),
            Column::new("age", DataType::integer()).default(DbValue::integer(0)),
        ];
        if with_active {
            columns.push(
                Column::new("active", DataType::boolean()).default(DbValue::boolean(true)),
            );
        }
        TableSchema::new("users", columns)
    }

    /// Fresh in-memory engine with `schema` already created.
    fn engine_with(schema: TableSchema) -> MemoryEngine {
        let mut engine = MemoryEngine::new();
        engine.create_table(schema).unwrap();
        engine
    }

    /// Column values for the `Alice` row (id 1).
    fn alice_values() -> Vec<(&'static str, DbValue)> {
        vec![("id", DbValue::integer(1)), ("name", DbValue::text("Alice"))]
    }

    /// Inserts the `Alice` row directly into the engine, outside any transaction.
    fn seed_alice(engine: &mut MemoryEngine) {
        let mut row = Row::new();
        row.insert("id".to_string(), DbValue::integer(1));
        row.insert("name".to_string(), DbValue::text("Alice"));
        engine.insert("users", row).unwrap();
    }

    /// Row predicate matching `id == 1`.
    fn id_is_one(row: &Row) -> bool {
        row.get("id").and_then(|v| v.as_integer()) == Some(1)
    }

    #[test]
    fn test_transaction_insert() {
        let mut engine = engine_with(create_test_schema());

        let mut tx = Transaction::new(&mut engine);
        let row_id = tx.insert("users", alice_values()).unwrap();
        assert_eq!(row_id.0, 0);
        tx.commit().unwrap();

        // The committed row must be visible in the engine.
        let rows = engine.scan("users").unwrap();
        assert_eq!(rows.len(), 1);
    }

    #[test]
    fn test_transaction_rollback() {
        let mut engine = engine_with(create_test_schema());

        {
            let mut tx = Transaction::new(&mut engine);
            tx.insert("users", alice_values()).unwrap();
            tx.rollback().unwrap();
        }

        // The rolled-back insert must leave the table empty.
        let rows = engine.scan("users").unwrap();
        assert!(rows.is_empty());
    }

    #[test]
    fn test_transaction_update_rollback() {
        let mut engine = engine_with(create_test_schema());
        seed_alice(&mut engine);

        {
            let mut tx = Transaction::new(&mut engine);
            tx.update("users", id_is_one, vec![("name", DbValue::text("Bob"))])
                .unwrap();
            tx.rollback().unwrap();
        }

        // The original name must be restored.
        let rows = engine.scan("users").unwrap();
        assert_eq!(rows[0].1.get("name").unwrap().as_text(), Some("Alice"));
    }

    #[test]
    fn test_transaction_delete_rollback() {
        let mut engine = engine_with(create_test_schema());
        seed_alice(&mut engine);

        {
            let mut tx = Transaction::new(&mut engine);
            tx.delete("users", id_is_one).unwrap();
            tx.rollback().unwrap();
        }

        // The deleted row must be back, with its original content.
        let rows = engine.scan("users").unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].1.get("name").unwrap().as_text(), Some("Alice"));
    }

    #[test]
    fn test_transaction_insert_with_default_values_rollback() {
        let mut engine = engine_with(defaults_schema(true));

        {
            let mut tx = Transaction::new(&mut engine);
            // Only id and name are supplied; the remaining columns come from defaults.
            tx.insert("users", alice_values()).unwrap();
            tx.rollback().unwrap();
        }

        // Nothing may remain, including the default-filled columns.
        let rows = engine.scan("users").unwrap();
        assert!(rows.is_empty());
    }

    #[test]
    fn test_transaction_insert_with_default_values_commit() {
        let mut engine = engine_with(defaults_schema(false));

        {
            let mut tx = Transaction::new(&mut engine);
            // Only id and name are supplied; the remaining columns come from defaults.
            tx.insert("users", alice_values()).unwrap();
            tx.commit().unwrap();
        }

        // The committed row must carry the schema defaults.
        let rows = engine.scan("users").unwrap();
        assert_eq!(rows.len(), 1);
        let stored = &rows[0].1;
        assert_eq!(stored.get("status").unwrap().as_text(), Some("active"));
        assert_eq!(stored.get("age").unwrap().as_integer(), Some(0));
    }
}