1use crate::engine::Database;
8use crate::error::DbxResult;
9use crate::sql::builder::{Execute, Query, QueryOne, QueryOptional, QueryScalar};
10use std::collections::HashMap;
11use std::marker::PhantomData;
12
/// Marker trait for the compile-time lifecycle state of a [`Transaction`].
///
/// The state lives only in the transaction's type parameter, so which methods
/// are callable depends on the marker: [`Active`] transactions can read, write,
/// commit and roll back, while [`Committed`] and [`RolledBack`] transactions
/// only expose a status query.
pub trait TxState {}

/// State of a transaction that is open and accepting reads and writes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct Active;

/// State of a transaction whose buffered operations have been applied.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct Committed;

/// State of a transaction whose buffered operations were discarded.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct RolledBack;

impl TxState for Active {}
impl TxState for Committed {}
impl TxState for RolledBack {}
28
/// One write recorded by an active transaction; replayed in issue order on commit.
#[derive(Clone, Debug)]
enum TxOp {
    /// Put a single pair into a table: `(table, key, value)`.
    Insert(String, Vec<u8>, Vec<u8>),
    /// Remove a single key from a table: `(table, key)`.
    Delete(String, Vec<u8>),
    /// Put many pairs into one table: `(table, rows)`, each row `(key, value)`.
    Batch(String, Vec<(Vec<u8>, Vec<u8>)>),
}
39
40pub struct Transaction<'a, S: TxState> {
45 db: &'a Database,
46 ops: Vec<TxOp>,
48 local_buffer: HashMap<String, HashMap<Vec<u8>, Option<Vec<u8>>>>,
50 _state: PhantomData<S>,
51}
52
53impl Database {
54 pub fn begin(&self) -> DbxResult<Transaction<'_, Active>> {
60 Ok(Transaction {
61 db: self,
62 ops: Vec::new(),
63 local_buffer: HashMap::new(),
64 _state: PhantomData,
65 })
66 }
67}
68
69impl<'a> Transaction<'a, Active> {
70 pub fn query<T: crate::api::FromRow>(&self, sql: impl Into<String>) -> Query<'_, T> {
76 self.db.query(sql)
77 }
78
79 pub fn query_one<T: crate::api::FromRow>(&self, sql: impl Into<String>) -> QueryOne<'_, T> {
81 self.db.query_one(sql)
82 }
83
84 pub fn query_optional<T: crate::api::FromRow>(
86 &self,
87 sql: impl Into<String>,
88 ) -> QueryOptional<'_, T> {
89 self.db.query_optional(sql)
90 }
91
92 pub fn query_scalar<T: crate::api::FromScalar>(
94 &self,
95 sql: impl Into<String>,
96 ) -> QueryScalar<'_, T> {
97 self.db.query_scalar(sql)
98 }
99
100 pub fn execute(&self, sql: impl Into<String>) -> Execute<'_> {
102 self.db.execute(sql)
103 }
104
105 pub fn insert(&mut self, table: &str, key: &[u8], value: &[u8]) -> DbxResult<()> {
111 self.ops.push(TxOp::Insert(
112 table.to_string(),
113 key.to_vec(),
114 value.to_vec(),
115 ));
116 self.local_buffer
118 .entry(table.to_string())
119 .or_default()
120 .insert(key.to_vec(), Some(value.to_vec()));
121 Ok(())
122 }
123
124 pub fn insert_batch(&mut self, table: &str, rows: Vec<(Vec<u8>, Vec<u8>)>) -> DbxResult<()> {
126 self.ops.push(TxOp::Batch(table.to_string(), rows.clone()));
127 let table_buf = self.local_buffer.entry(table.to_string()).or_default();
129 for (key, value) in rows {
130 table_buf.insert(key, Some(value));
131 }
132 Ok(())
133 }
134
135 pub fn delete(&mut self, table: &str, key: &[u8]) -> DbxResult<bool> {
137 self.ops.push(TxOp::Delete(table.to_string(), key.to_vec()));
138 self.local_buffer
140 .entry(table.to_string())
141 .or_default()
142 .insert(key.to_vec(), None);
143 Ok(true)
144 }
145
146 pub fn get(&self, table: &str, key: &[u8]) -> DbxResult<Option<Vec<u8>>> {
148 if let Some(table_buf) = self.local_buffer.get(table)
150 && let Some(value_opt) = table_buf.get(key)
151 {
152 return Ok(value_opt.clone()); }
154 self.db.get(table, key)
156 }
157
158 pub fn pending_ops(&self) -> usize {
160 self.ops.len()
161 }
162
163 pub fn commit(self) -> DbxResult<Transaction<'a, Committed>> {
187 let commit_ts = self.db.allocate_commit_ts();
189
190 for op in &self.ops {
192 match op {
193 TxOp::Insert(table, key, value) => {
194 self.db
196 .insert_versioned(table, key, Some(value), commit_ts)?;
197 self.db.insert(table, key, value)?;
199 }
200 TxOp::Delete(table, key) => {
201 self.db.insert_versioned(table, key, None, commit_ts)?;
203 self.db.delete(table, key)?;
205 }
206 TxOp::Batch(table, rows) => {
207 for (key, value) in rows {
209 self.db
210 .insert_versioned(table, key, Some(value), commit_ts)?;
211 self.db.insert(table, key, value)?;
213 }
214 }
215 }
216 }
217 Ok(Transaction {
218 db: self.db,
219 ops: Vec::new(),
220 local_buffer: HashMap::new(),
221 _state: PhantomData,
222 })
223 }
224
225 pub fn rollback(self) -> DbxResult<Transaction<'a, RolledBack>> {
227 Ok(Transaction {
229 db: self.db,
230 ops: Vec::new(),
231 local_buffer: HashMap::new(),
232 _state: PhantomData,
233 })
234 }
235}
236
237impl<'a> Transaction<'a, Committed> {
239 pub fn is_committed(&self) -> bool {
241 true
242 }
243}
244
245impl<'a> Transaction<'a, RolledBack> {
246 pub fn is_rolled_back(&self) -> bool {
248 true
249 }
250}
251
/// Exposes the inherent `Database::begin` through the `DatabaseTransaction` trait.
impl crate::traits::DatabaseTransaction for Database {
    /// Starts a transaction by delegating to the inherent `Database::begin`.
    ///
    /// The path-qualified call resolves to the inherent method, so this does not
    /// recurse. Keep the explicit path: a rewrite such as a trait-qualified
    /// `begin` call could recurse into this trait method.
    fn begin(&self) -> DbxResult<Transaction<'_, Active>> {
        Database::begin(self)
    }
}
262
#[cfg(test)]
mod tests {
    use crate::engine::Database;

    /// A fresh in-memory database for each test.
    fn mem_db() -> Database {
        Database::open_in_memory().unwrap()
    }

    #[test]
    fn test_begin_commit() {
        let db = mem_db();
        let mut tx = db.begin().unwrap();
        tx.insert("users", b"u1", b"Alice").unwrap();
        tx.insert("users", b"u2", b"Bob").unwrap();

        // Buffered writes: hidden from the database, visible to the transaction.
        assert_eq!(db.get("users", b"u1").unwrap(), None);
        assert_eq!(tx.get("users", b"u1").unwrap(), Some(b"Alice".to_vec()));

        assert!(tx.commit().unwrap().is_committed());

        assert_eq!(db.get("users", b"u1").unwrap(), Some(b"Alice".to_vec()));
        assert_eq!(db.get("users", b"u2").unwrap(), Some(b"Bob".to_vec()));
    }

    #[test]
    fn test_begin_rollback() {
        let db = mem_db();
        let mut tx = db.begin().unwrap();
        tx.insert("users", b"u1", b"Alice").unwrap();
        tx.insert("users", b"u2", b"Bob").unwrap();

        assert!(tx.rollback().unwrap().is_rolled_back());

        // Nothing buffered may leak into the database after a rollback.
        assert_eq!(db.get("users", b"u1").unwrap(), None);
        assert_eq!(db.get("users", b"u2").unwrap(), None);
    }

    #[test]
    fn test_delete_in_transaction() {
        let db = mem_db();
        db.insert("users", b"u1", b"Alice").unwrap();
        assert_eq!(db.get("users", b"u1").unwrap(), Some(b"Alice".to_vec()));

        let mut tx = db.begin().unwrap();
        tx.delete("users", b"u1").unwrap();

        // The tombstone is local until commit.
        assert_eq!(tx.get("users", b"u1").unwrap(), None);
        assert_eq!(db.get("users", b"u1").unwrap(), Some(b"Alice".to_vec()));

        tx.commit().unwrap();
        assert_eq!(db.get("users", b"u1").unwrap(), None);
    }

    #[test]
    fn test_read_your_writes() {
        let db = mem_db();
        db.insert("t", b"k1", b"old").unwrap();

        let mut tx = db.begin().unwrap();
        tx.insert("t", b"k1", b"new").unwrap();
        assert_eq!(tx.get("t", b"k1").unwrap(), Some(b"new".to_vec()));

        // Keys the transaction never touched fall through to the live database.
        db.insert("t", b"k2", b"main_data").unwrap();
        assert_eq!(tx.get("t", b"k2").unwrap(), Some(b"main_data".to_vec()));

        tx.rollback().unwrap();
        assert_eq!(db.get("t", b"k1").unwrap(), Some(b"old".to_vec()));
    }

    #[test]
    fn test_pending_ops_count() {
        let db = mem_db();
        let mut tx = db.begin().unwrap();
        assert_eq!(tx.pending_ops(), 0);

        tx.insert("t", b"a", b"1").unwrap();
        assert_eq!(tx.pending_ops(), 1);

        // Deleting a key that never existed is still recorded as an operation.
        tx.delete("t", b"b").unwrap();
        assert_eq!(tx.pending_ops(), 2);

        tx.insert("t", b"c", b"3").unwrap();
        assert_eq!(tx.pending_ops(), 3);
    }

    #[test]
    fn test_empty_transaction_commit() {
        let db = mem_db();
        assert!(db.begin().unwrap().commit().unwrap().is_committed());
    }

    #[test]
    fn test_empty_transaction_rollback() {
        let db = mem_db();
        assert!(db.begin().unwrap().rollback().unwrap().is_rolled_back());
    }
}