1use crate::storage::{StorageEngine, Row, RowId};
10use crate::types::{DbValue, DbResult, DbError};
11
/// One undo record: the information needed to reverse a single write made
/// through a [`Transaction`].
#[derive(Debug)]
pub enum WriteLog {
    /// A row was inserted into `table` under `row_id`; undone by deleting that row.
    Inserted { table: String, row_id: RowId },
    /// The row `row_id` in `table` was overwritten; `old_row` is its content
    /// before the write and is written back on undo.
    Updated { table: String, row_id: RowId, old_row: Row },
    /// The row `row_id` was removed from `table`; `row` is its full content
    /// before deletion and is re-inserted on undo.
    Deleted { table: String, row_id: RowId, row: Row },
}
22
23pub struct Transaction<'a> {
25 engine: &'a mut dyn StorageEngine,
26 write_log: Vec<WriteLog>,
27 is_committed: bool,
28 enable_rollback_log: bool,
30}
31
32impl<'a> Transaction<'a> {
33 pub fn new(engine: &'a mut dyn StorageEngine) -> Self {
35 Transaction {
36 engine,
37 write_log: Vec::new(),
38 is_committed: false,
39 enable_rollback_log: true,
40 }
41 }
42
43 pub fn with_rollback_log(engine: &'a mut dyn StorageEngine, enable_log: bool) -> Self {
45 Transaction {
46 engine,
47 write_log: Vec::new(),
48 is_committed: false,
49 enable_rollback_log: enable_log,
50 }
51 }
52
53 pub fn commit(&mut self) -> DbResult<()> {
55 if self.is_committed {
56 return Err(DbError::TransactionError(
57 "Transaction already committed".to_string()
58 ));
59 }
60 self.write_log.clear();
62 self.is_committed = true;
63 Ok(())
64 }
65
66 pub fn rollback(&mut self) -> DbResult<()> {
73 if self.is_committed {
74 return Err(DbError::TransactionError(
75 "Cannot rollback a committed transaction".to_string()
76 ));
77 }
78
79 if !self.enable_rollback_log {
80 self.write_log.clear();
81 return Ok(());
82 }
83
84 for log in self.write_log.drain(..).rev() {
86 match log {
87 WriteLog::Inserted { table, row_id } => {
88 let _ = self.engine.delete(&table, row_id);
90 }
91 WriteLog::Updated { table, row_id, old_row } => {
92 let _ = self.engine.update(&table, row_id, old_row);
94 }
95 WriteLog::Deleted { table, row_id: _, row } => {
96 let _ = self.engine.insert(&table, row);
98 }
99 }
100 }
101
102 Ok(())
103 }
104
105 pub fn insert(&mut self, table: &str, values: Vec<(&str, DbValue)>) -> DbResult<RowId> {
107 let schema = self.engine.get_schema(table)?.clone();
109
110 let mut row = Row::new();
112 for (name, value) in values {
113 row.insert(name.to_string(), value);
114 }
115
116 schema.fill_defaults(&mut row);
118
119 let row_id = self.engine.insert(table, row.clone())?;
121
122 if self.enable_rollback_log {
124 self.write_log.push(WriteLog::Inserted {
125 table: table.to_string(),
126 row_id,
127 });
128 }
129
130 Ok(row_id)
131 }
132
133 pub fn engine(&self) -> &dyn StorageEngine {
135 self.engine
136 }
137
138 pub fn engine_mut(&mut self) -> &mut dyn StorageEngine {
140 self.engine
141 }
142
143 pub fn query_all(&self, table: &str) -> DbResult<Vec<Row>> {
145 let rows = self.engine.scan(table)?;
146 Ok(rows.into_iter().map(|(_, row)| row.clone()).collect())
147 }
148
149 pub fn with_engine<F, R>(&self, f: F) -> R
152 where
153 F: FnOnce(&dyn StorageEngine) -> R,
154 {
155 f(self.engine)
156 }
157
158 pub fn with_engine_mut<F, R>(&mut self, f: F) -> R
160 where
161 F: FnOnce(&mut dyn StorageEngine) -> R,
162 {
163 f(self.engine)
164 }
165
166 pub fn update<F>(&mut self, table: &str, condition: F, updates: Vec<(&str, DbValue)>) -> DbResult<usize>
168 where
169 F: Fn(&Row) -> bool,
170 {
171 let schema = self.engine.get_schema(table)?;
173
174 let mut new_values = Row::new();
176 for (name, value) in updates {
177 new_values.insert(name.to_string(), value);
178 }
179 schema.validate(&new_values.iter().map(|(k, v)| (k.clone(), v.clone())).collect::<Vec<_>>())?;
180
181 let rows = self.engine.scan(table)?;
183 let matching_rows: Vec<(RowId, Row)> = rows
184 .into_iter()
185 .filter(|(_, row)| condition(row))
186 .map(|(row_id, row)| (row_id, row.clone()))
187 .collect();
188
189 let mut updated_count = 0;
190
191 for (row_id, old_row) in matching_rows {
192 if self.enable_rollback_log {
194 self.write_log.push(WriteLog::Updated {
195 table: table.to_string(),
196 row_id,
197 old_row: old_row.clone(),
198 });
199 }
200
201 let mut updated_row = old_row.clone();
203 for (key, value) in new_values.iter() {
204 updated_row.insert(key.clone(), value.clone());
205 }
206
207 self.engine.update(table, row_id, updated_row)?;
209 updated_count += 1;
210 }
211
212 Ok(updated_count)
213 }
214
215 pub fn delete<F>(&mut self, table: &str, condition: F) -> DbResult<usize>
217 where
218 F: Fn(&Row) -> bool,
219 {
220 let rows = self.engine.scan(table)?;
222 let rows_to_delete: Vec<(RowId, Row)> = rows
223 .into_iter()
224 .filter(|(_, row)| condition(row))
225 .map(|(row_id, row)| (row_id, row.clone()))
226 .collect();
227
228 let mut deleted_count = 0;
229
230 for (row_id, row) in rows_to_delete {
231 if self.enable_rollback_log {
233 self.write_log.push(WriteLog::Deleted {
234 table: table.to_string(),
235 row_id,
236 row,
237 });
238 }
239
240 self.engine.delete(table, row_id)?;
241 deleted_count += 1;
242 }
243
244 Ok(deleted_count)
245 }
246}
247
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::MemoryEngine;
    use crate::types::{Column, DataType, TableSchema};

    /// Schema for a two-column `users` table: integer primary key `id` and text `name`.
    fn create_test_schema() -> TableSchema {
        let columns = vec![
            Column::new("id", DataType::integer()).primary_key(),
            Column::new("name", DataType::text()),
        ];
        TableSchema::new("users", columns)
    }

    /// A fresh in-memory engine that already has the `users` table from `create_test_schema`.
    fn engine_with_users() -> MemoryEngine {
        let mut engine = MemoryEngine::new();
        engine.create_table(create_test_schema()).unwrap();
        engine
    }

    /// Writes `{id: 1, name: "Alice"}` directly into the engine, outside any transaction.
    fn seed_alice(engine: &mut MemoryEngine) {
        let mut alice = Row::new();
        alice.insert("id".to_string(), DbValue::integer(1));
        alice.insert("name".to_string(), DbValue::text("Alice"));
        engine.insert("users", alice).unwrap();
    }

    /// Predicate matching the row whose `id` column holds the integer 1.
    fn is_user_one(row: &Row) -> bool {
        row.get("id").and_then(|v| v.as_integer()) == Some(1)
    }

    #[test]
    fn test_transaction_insert() {
        let mut engine = engine_with_users();

        {
            let mut tx = Transaction::new(&mut engine);
            let inserted = tx
                .insert(
                    "users",
                    vec![("id", DbValue::integer(1)), ("name", DbValue::text("Alice"))],
                )
                .unwrap();
            assert_eq!(inserted.0, 0);
            tx.commit().unwrap();
        }

        let stored = engine.scan("users").unwrap();
        assert_eq!(stored.len(), 1);
    }

    #[test]
    fn test_transaction_rollback() {
        let mut engine = engine_with_users();

        {
            let mut tx = Transaction::new(&mut engine);
            tx.insert(
                "users",
                vec![("id", DbValue::integer(1)), ("name", DbValue::text("Alice"))],
            )
            .unwrap();
            tx.rollback().unwrap();
        }

        let stored = engine.scan("users").unwrap();
        assert!(stored.is_empty());
    }

    #[test]
    fn test_transaction_update_rollback() {
        let mut engine = engine_with_users();
        seed_alice(&mut engine);

        {
            let mut tx = Transaction::new(&mut engine);
            tx.update("users", is_user_one, vec![("name", DbValue::text("Bob"))])
                .unwrap();
            tx.rollback().unwrap();
        }

        let stored = engine.scan("users").unwrap();
        assert_eq!(stored[0].1.get("name").unwrap().as_text(), Some("Alice"));
    }

    #[test]
    fn test_transaction_delete_rollback() {
        let mut engine = engine_with_users();
        seed_alice(&mut engine);

        {
            let mut tx = Transaction::new(&mut engine);
            tx.delete("users", is_user_one).unwrap();
            tx.rollback().unwrap();
        }

        let stored = engine.scan("users").unwrap();
        assert_eq!(stored.len(), 1);
        assert_eq!(stored[0].1.get("name").unwrap().as_text(), Some("Alice"));
    }

    #[test]
    fn test_transaction_insert_with_default_values_rollback() {
        let mut engine = MemoryEngine::new();
        let columns = vec![
            Column::new("id", DataType::integer()).primary_key(),
            Column::new("name", DataType::text()).not_null(),
            Column::new("status", DataType::text()).default(DbValue::text("active")),
            Column::new("age", DataType::integer()).default(DbValue::integer(0)),
            Column::new("active", DataType::boolean()).default(DbValue::boolean(true)),
        ];
        engine.create_table(TableSchema::new("users", columns)).unwrap();

        {
            let mut tx = Transaction::new(&mut engine);
            // Only the required columns are supplied; the rest come from defaults.
            tx.insert(
                "users",
                vec![("id", DbValue::integer(1)), ("name", DbValue::text("Alice"))],
            )
            .unwrap();
            tx.rollback().unwrap();
        }

        let stored = engine.scan("users").unwrap();
        assert!(stored.is_empty());
    }

    #[test]
    fn test_transaction_insert_with_default_values_commit() {
        let mut engine = MemoryEngine::new();
        let columns = vec![
            Column::new("id", DataType::integer()).primary_key(),
            Column::new("name", DataType::text()).not_null(),
            Column::new("status", DataType::text()).default(DbValue::text("active")),
            Column::new("age", DataType::integer()).default(DbValue::integer(0)),
        ];
        engine.create_table(TableSchema::new("users", columns)).unwrap();

        {
            let mut tx = Transaction::new(&mut engine);
            // Only the required columns are supplied; the rest come from defaults.
            tx.insert(
                "users",
                vec![("id", DbValue::integer(1)), ("name", DbValue::text("Alice"))],
            )
            .unwrap();
            tx.commit().unwrap();
        }

        let stored = engine.scan("users").unwrap();
        assert_eq!(stored.len(), 1);
        let persisted = &stored[0].1;
        assert_eq!(persisted.get("status").unwrap().as_text(), Some("active"));
        assert_eq!(persisted.get("age").unwrap().as_integer(), Some(0));
    }
}