Skip to main content

regulus_db/storage/
persisted_engine.rs

1use crate::storage::{MemoryEngine, StorageEngine, Row, RowId};
2use crate::persistence::PersistenceManager;
3use crate::persistence::wal::WalOperation;
4use crate::types::{TableSchema, DbResult};
5use std::path::Path;
6use std::sync::{Arc, RwLock};
7
/// Persistent storage engine.
///
/// Wraps a `MemoryEngine` together with a `PersistenceManager` to add write-ahead-log (WAL)
/// support on top of the purely in-memory engine.
pub struct PersistedEngine {
    // In-memory state. Kept behind `Arc<RwLock<..>>` so `inner_arc` can hand out shared handles.
    inner: Arc<RwLock<MemoryEngine>>,
    // Appends WAL records, takes checkpoints and restores state when the database is opened.
    persistence: PersistenceManager,
}
14
15impl PersistedEngine {
16    /// 打开持久化数据库
17    pub fn open(path: &Path) -> DbResult<Self> {
18        let mut persistence = PersistenceManager::new(path)?;
19        let engine = persistence.restore()?;
20
21        Ok(PersistedEngine {
22            inner: Arc::new(RwLock::new(engine)),
23            persistence,
24        })
25    }
26
27    /// 创建新的持久化数据库(如果已存在则覆盖)
28    pub fn create(path: &Path) -> DbResult<Self> {
29        // 删除现有的 WAL 和快照文件
30        let wal_path = path.join("wal.bin");
31        let snapshot_path = path.join("snapshot.bin");
32
33        let _ = std::fs::remove_file(&wal_path);
34        let _ = std::fs::remove_file(&snapshot_path);
35
36        let persistence = PersistenceManager::new(path)?;
37        let engine = MemoryEngine::new();
38
39        Ok(PersistedEngine {
40            inner: Arc::new(RwLock::new(engine)),
41            persistence,
42        })
43    }
44
45    /// 获取内部 MemoryEngine 的 Arc 克隆(用于 QueryBuilder 等)
46    pub fn inner_arc(&self) -> Arc<RwLock<MemoryEngine>> {
47        Arc::clone(&self.inner)
48    }
49
50    /// 手动触发检查点
51    pub fn checkpoint(&mut self) -> DbResult<()> {
52        if self.persistence.needs_checkpoint() {
53            let engine = self.inner.read().unwrap();
54            self.persistence.checkpoint(&engine)?;
55        }
56        Ok(())
57    }
58
59    /// 强制触发检查点(无论 WAL 大小)
60    pub fn force_checkpoint(&mut self) -> DbResult<()> {
61        let engine = self.inner.read().unwrap();
62        self.persistence.checkpoint(&engine)?;
63        Ok(())
64    }
65
66    /// 获取 WAL 文件大小
67    pub fn wal_size(&self) -> u64 {
68        self.persistence.wal_size()
69    }
70
71    // ========== 索引方法委托(委托给内部 MemoryEngine) ==========
72
73    /// 为表列创建索引(单列)
74    pub fn create_index(&mut self, table: &str, column: &str) -> DbResult<()> {
75        self.inner.write().unwrap().create_index(table, column)
76    }
77
78    /// 创建复合索引
79    pub fn create_composite_index(&mut self, table: &str, columns: &[&str]) -> DbResult<()> {
80        self.inner.write().unwrap().create_composite_index(table, columns)
81    }
82
83    /// 创建唯一复合索引
84    pub fn create_unique_index(&mut self, table: &str, columns: &[&str]) -> DbResult<()> {
85        self.inner.write().unwrap().create_unique_index(table, columns)
86    }
87
88    /// 删除索引
89    pub fn drop_index(&mut self, table: &str, column: &str) -> DbResult<bool> {
90        self.inner.write().unwrap().drop_index(table, column)
91    }
92
93    /// 删除复合索引
94    pub fn drop_composite_index(&mut self, table: &str, columns: &[&str]) -> DbResult<bool> {
95        self.inner.write().unwrap().drop_composite_index(table, columns)
96    }
97
98    /// 检查列是否有索引
99    pub fn has_index(&self, table: &str, column: &str) -> bool {
100        self.inner.read().unwrap().has_index(table, column)
101    }
102
103    /// 检查复合索引是否存在
104    pub fn has_composite_index(&self, table: &str, columns: &[&str]) -> bool {
105        self.inner.read().unwrap().has_composite_index(table, columns)
106    }
107}
108
109impl StorageEngine for PersistedEngine {
110    fn create_table(&mut self, schema: TableSchema) -> DbResult<()> {
111        // 1. 先写 WAL
112        let op = WalOperation::CreateTable { schema: schema.clone() };
113        self.persistence.log_operation(op)?;
114
115        // 2. 再写内存
116        self.inner.write().unwrap().create_table(schema)
117    }
118
119    fn drop_table(&mut self, name: &str) -> DbResult<()> {
120        // 1. 先写 WAL
121        let op = WalOperation::DropTable { name: name.to_string() };
122        self.persistence.log_operation(op)?;
123
124        // 2. 再写内存
125        self.inner.write().unwrap().drop_table(name)
126    }
127
128    fn has_table(&self, name: &str) -> bool {
129        self.inner.read().unwrap().has_table(name)
130    }
131
132    fn get_schema(&self, name: &str) -> DbResult<TableSchema> {
133        self.inner.read().unwrap().get_schema(name)
134    }
135
136    fn insert(&mut self, table: &str, row: Row) -> DbResult<RowId> {
137        // 1. 先写 WAL(在获取 row_id 之前记录原始数据)
138        let row_clone = row.clone();
139
140        // 2. 插入内存获取 row_id
141        let row_id = self.inner.write().unwrap().insert(table, row_clone.clone())?;
142
143        let op = WalOperation::Insert {
144            table: table.to_string(),
145            row_id: row_id.0,
146            row: row_clone,
147        };
148        self.persistence.log_operation(op)?;
149
150        // 检查是否需要自动检查点
151        if self.persistence.needs_checkpoint() {
152            let engine = self.inner.read().unwrap();
153            self.persistence.checkpoint(&engine)?;
154        }
155
156        Ok(row_id)
157    }
158
159    fn get(&self, table: &str, row_id: RowId) -> DbResult<Option<Row>> {
160        self.inner.read().unwrap().get(table, row_id)
161    }
162
163    fn update(&mut self, table: &str, row_id: RowId, values: Row) -> DbResult<()> {
164        // 先获取旧值(用于 WAL 记录)
165        let old_row = {
166            let inner = self.inner.read().unwrap();
167            inner.get(table, row_id)?
168        };
169
170        // 执行更新
171        self.inner.write().unwrap().update(table, row_id, values.clone())?;
172
173        // 根据是否有旧值决定写 Insert 还是 Update
174        let op = match old_row {
175            Some(_) => WalOperation::Update {
176                table: table.to_string(),
177                row_id: row_id.0,
178                row: values,
179            },
180            None => WalOperation::Insert {
181                table: table.to_string(),
182                row_id: row_id.0,
183                row: values,
184            },
185        };
186        self.persistence.log_operation(op)?;
187
188        // 检查是否需要自动检查点
189        if self.persistence.needs_checkpoint() {
190            let engine = self.inner.read().unwrap();
191            self.persistence.checkpoint(&engine)?;
192        }
193
194        Ok(())
195    }
196
197    fn delete(&mut self, table: &str, row_id: RowId) -> DbResult<Option<Row>> {
198        // 先获取要删除的行
199        let deleted_row = {
200            let inner = self.inner.read().unwrap();
201            inner.get(table, row_id)?
202        };
203
204        // 执行删除
205        let _ = self.inner.write().unwrap().delete(table, row_id)?;
206
207        // 写 WAL
208        let op = WalOperation::Delete {
209            table: table.to_string(),
210            row_id: row_id.0,
211        };
212        self.persistence.log_operation(op)?;
213
214        // 检查是否需要自动检查点
215        if self.persistence.needs_checkpoint() {
216            let engine = self.inner.read().unwrap();
217            self.persistence.checkpoint(&engine)?;
218        }
219
220        Ok(deleted_row)
221    }
222
223    fn scan(&self, table: &str) -> DbResult<Vec<(RowId, Row)>> {
224        self.inner.read().unwrap().scan(table)
225    }
226}
227
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{DataType, Column, DbValue};
    use tempfile::TempDir;

    /// Schema for a `users` table: `id` (integer primary key), `name` (text), `age` (integer).
    fn create_test_schema() -> TableSchema {
        TableSchema::new(
            "users",
            vec![
                Column::new("id", DataType::integer()).primary_key(),
                Column::new("name", DataType::text()),
                Column::new("age", DataType::integer()),
            ],
        )
    }

    /// Builds a `users` row from already-constructed column values.
    fn user_row(id: DbValue, name: DbValue, age: DbValue) -> Row {
        let mut row = Row::new();
        row.insert("id".to_string(), id);
        row.insert("name".to_string(), name);
        row.insert("age".to_string(), age);
        row
    }

    /// The default test row: id 1, "Alice", age 25.
    fn create_test_row() -> Row {
        user_row(DbValue::integer(1), DbValue::text("Alice"), DbValue::integer(25))
    }

    #[test]
    fn test_persisted_create_table() {
        let temp_dir = TempDir::new().unwrap();
        let mut engine = PersistedEngine::create(temp_dir.path()).unwrap();

        let schema = create_test_schema();
        assert!(engine.create_table(schema).is_ok());
        assert!(engine.has_table("users"));
    }

    #[test]
    fn test_persisted_insert_and_get() {
        let temp_dir = TempDir::new().unwrap();
        let mut engine = PersistedEngine::create(temp_dir.path()).unwrap();

        engine.create_table(create_test_schema()).unwrap();

        let row_id = engine.insert("users", create_test_row()).unwrap();

        let retrieved = engine.get("users", row_id).unwrap().unwrap();
        assert_eq!(retrieved.get("name").unwrap().as_text(), Some("Alice"));
    }

    #[test]
    fn test_persisted_checkpoint() {
        let temp_dir = TempDir::new().unwrap();
        let mut engine = PersistedEngine::create(temp_dir.path()).unwrap();

        engine.create_table(create_test_schema()).unwrap();

        for i in 0..5 {
            let row = user_row(
                DbValue::integer(i),
                DbValue::text(format!("User{}", i)),
                DbValue::integer(20 + i),
            );
            engine.insert("users", row).unwrap();
        }

        // Manual checkpoint.
        assert!(engine.force_checkpoint().is_ok());

        // The WAL should be empty afterwards.
        assert_eq!(engine.wal_size(), 0);
    }

    #[test]
    fn test_persisted_recovery() {
        let temp_dir = TempDir::new().unwrap();

        // 1. Create the database and insert data.
        {
            let mut engine = PersistedEngine::create(temp_dir.path()).unwrap();
            engine.create_table(create_test_schema()).unwrap();

            engine.insert("users", create_test_row()).unwrap();

            // Force a checkpoint so the first row lives in the snapshot.
            engine.force_checkpoint().unwrap();

            // Insert another row without a checkpoint, to exercise WAL replay.
            let bob = user_row(DbValue::integer(2), DbValue::text("Bob"), DbValue::integer(30));
            engine.insert("users", bob).unwrap();
        }

        // 2. Reopen the database (simulated recovery).
        let engine = PersistedEngine::open(temp_dir.path()).unwrap();

        // 3. Verify the data.
        let engine_arc = engine.inner_arc();
        let inner = engine_arc.read().unwrap();
        assert_eq!(inner.get_row_count("users").unwrap(), 2);

        let row1 = inner.get("users", RowId(0)).unwrap().unwrap();
        assert_eq!(row1.get("name").unwrap().as_text(), Some("Alice"));

        let row2 = inner.get("users", RowId(1)).unwrap().unwrap();
        assert_eq!(row2.get("name").unwrap().as_text(), Some("Bob"));
    }

    #[test]
    fn test_persisted_update_and_delete_recovery() {
        let temp_dir = TempDir::new().unwrap();

        // 1. Insert two rows, update one and delete the other, with no explicit checkpoint.
        let alice_id = {
            let mut engine = PersistedEngine::create(temp_dir.path()).unwrap();
            engine.create_table(create_test_schema()).unwrap();

            let alice_id = engine.insert("users", create_test_row()).unwrap();
            let bob = user_row(DbValue::integer(2), DbValue::text("Bob"), DbValue::integer(30));
            let bob_id = engine.insert("users", bob).unwrap();

            let renamed =
                user_row(DbValue::integer(1), DbValue::text("Alicia"), DbValue::integer(26));
            engine.update("users", alice_id, renamed).unwrap();

            let removed = engine.delete("users", bob_id).unwrap().unwrap();
            assert_eq!(removed.get("name").unwrap().as_text(), Some("Bob"));

            alice_id
        };

        // 2. Reopen and check that both the update and the delete were recovered.
        let engine = PersistedEngine::open(temp_dir.path()).unwrap();
        let engine_arc = engine.inner_arc();
        let inner = engine_arc.read().unwrap();
        assert_eq!(inner.get_row_count("users").unwrap(), 1);

        let alice = inner.get("users", alice_id).unwrap().unwrap();
        assert_eq!(alice.get("name").unwrap().as_text(), Some("Alicia"));
    }
}