1use std::collections::HashMap;
7
8use fsqlite_ast::TransactionMode;
9use fsqlite_error::{FrankenError, Result};
10use tracing::{debug, error, info};
11
/// Lock level recorded by a [`TransactionController`].
///
/// Variants are declared from weakest to strongest, so the derived ordering
/// satisfies `None < Shared < Reserved < Exclusive`.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub enum LockLevel {
    /// No lock is held.
    None,
    /// A shared (read) lock.
    Shared,
    /// A reserved lock (set for IMMEDIATE transactions and on first write).
    Reserved,
    /// An exclusive lock (set for EXCLUSIVE transactions).
    Exclusive,
}
28
/// Lifecycle state of a [`TransactionController`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TxnState {
    /// No transaction is open.
    Idle,
    /// A transaction is open and usable.
    Active,
    /// A transaction is open but has failed; COMMIT is refused until a
    /// ROLLBACK or ROLLBACK TO clears this state.
    Error,
}
43
/// One entry on a controller's savepoint stack.
///
/// Holds the savepoint's name together with a copy of the write set taken at
/// the moment the savepoint was created, so ROLLBACK TO can restore it.
#[derive(Clone, Debug)]
pub struct SavepointEntry {
    /// Savepoint name as supplied by the caller; lookups compare it
    /// ASCII-case-insensitively.
    pub name: String,
    /// Write set (page number -> page bytes) captured when the savepoint was created.
    write_set_snapshot: HashMap<u64, Vec<u8>>,
}
59
60#[derive(Debug)]
70pub struct TransactionController {
71 state: TxnState,
73 mode: Option<TransactionMode>,
75 lock_level: LockLevel,
77 savepoints: Vec<SavepointEntry>,
79 write_set: HashMap<u64, Vec<u8>>,
81 concurrent: bool,
83 implicit_txn: bool,
85}
86
87impl TransactionController {
88 #[must_use]
90 pub fn new() -> Self {
91 Self {
92 state: TxnState::Idle,
93 mode: None,
94 lock_level: LockLevel::None,
95 savepoints: Vec::new(),
96 write_set: HashMap::new(),
97 concurrent: false,
98 implicit_txn: false,
99 }
100 }
101
102 #[must_use]
104 pub const fn state(&self) -> TxnState {
105 self.state
106 }
107
108 #[must_use]
110 pub const fn lock_level(&self) -> LockLevel {
111 self.lock_level
112 }
113
114 #[must_use]
116 pub const fn mode(&self) -> Option<TransactionMode> {
117 self.mode
118 }
119
120 #[must_use]
122 pub const fn is_concurrent(&self) -> bool {
123 self.concurrent
124 }
125
126 #[must_use]
128 pub fn savepoint_depth(&self) -> usize {
129 self.savepoints.len()
130 }
131
132 pub fn begin(&mut self, mode: Option<TransactionMode>) -> Result<()> {
141 if self.state != TxnState::Idle {
142 error!(
143 begin_mode = ?mode,
144 "BEGIN failed: transaction already active"
145 );
146 return Err(FrankenError::Busy);
147 }
148
149 let resolved_mode = mode.unwrap_or(TransactionMode::Deferred);
150
151 let (lock, concurrent) = match resolved_mode {
153 TransactionMode::Deferred => {
154 (LockLevel::None, false)
156 }
157 TransactionMode::Immediate => {
158 (LockLevel::Reserved, false)
160 }
161 TransactionMode::Exclusive => {
162 (LockLevel::Exclusive, false)
164 }
165 TransactionMode::Concurrent => {
166 (LockLevel::Shared, true)
168 }
169 };
170
171 self.state = TxnState::Active;
172 self.mode = Some(resolved_mode);
173 self.lock_level = lock;
174 self.concurrent = concurrent;
175 self.write_set.clear();
176
177 info!(
178 begin_mode = ?resolved_mode,
179 lock_level = ?lock,
180 concurrent,
181 "transaction started"
182 );
183
184 Ok(())
185 }
186
187 pub fn commit(&mut self) -> Result<()> {
198 match self.state {
199 TxnState::Idle => {
200 return Err(FrankenError::NoActiveTransaction);
201 }
202 TxnState::Error => {
203 error!("COMMIT failed: transaction is in error state, must ROLLBACK");
204 return Err(FrankenError::Busy);
205 }
206 TxnState::Active => {}
207 }
208
209 info!(
210 mode = ?self.mode,
211 savepoint_depth = self.savepoints.len(),
212 "commit"
213 );
214
215 self.reset();
216 Ok(())
217 }
218
219 pub fn rollback(&mut self) -> Result<()> {
228 if self.state == TxnState::Idle {
229 return Err(FrankenError::NoActiveTransaction);
230 }
231
232 info!(
233 mode = ?self.mode,
234 savepoint_depth = self.savepoints.len(),
235 "rollback"
236 );
237
238 self.reset();
239 Ok(())
240 }
241
242 #[allow(clippy::needless_pass_by_value)]
251 pub fn savepoint(&mut self, name: String) -> Result<()> {
252 if self.state == TxnState::Idle {
253 self.begin(Some(TransactionMode::Deferred))?;
254 self.implicit_txn = true;
255 }
256
257 let entry = SavepointEntry {
258 name: name.clone(),
259 write_set_snapshot: self.write_set.clone(),
260 };
261 self.savepoints.push(entry);
262
263 debug!(
264 savepoint = %name,
265 depth = self.savepoints.len(),
266 "savepoint created"
267 );
268
269 Ok(())
270 }
271
272 pub fn release(&mut self, name: &str) -> Result<()> {
278 let pos = self.find_savepoint(name)?;
279
280 let removed = self.savepoints.len() - pos;
282 self.savepoints.truncate(pos);
283
284 debug!(
285 savepoint = %name,
286 removed,
287 remaining = self.savepoints.len(),
288 "savepoint released"
289 );
290
291 if self.savepoints.is_empty() && self.state == TxnState::Active && self.implicit_txn {
294 self.commit()?;
296 }
297
298 Ok(())
299 }
300
301 pub fn rollback_to(&mut self, name: &str) -> Result<()> {
307 let pos = self.find_savepoint(name)?;
308
309 self.savepoints.truncate(pos + 1);
311
312 let sp = &self.savepoints[pos];
314 self.write_set = sp.write_set_snapshot.clone();
315
316 if self.state == TxnState::Error {
318 self.state = TxnState::Active;
319 }
320
321 info!(
322 savepoint = %name,
323 depth = self.savepoints.len(),
324 "rollback to savepoint"
325 );
326
327 Ok(())
328 }
329
330 pub fn record_write(&mut self, page_number: u64, data: Vec<u8>) {
336 self.write_set.entry(page_number).or_insert(data);
338 }
339
340 pub fn promote_on_read(&mut self) {
343 if self.state == TxnState::Active && self.lock_level == LockLevel::None {
344 self.lock_level = LockLevel::Shared;
345 debug!("DEFERRED transaction promoted to SHARED on first read");
346 }
347 }
348
349 pub fn promote_on_write(&mut self) {
351 if self.state == TxnState::Active {
352 match self.lock_level {
353 LockLevel::None | LockLevel::Shared => {
354 if self.concurrent {
355 self.lock_level = LockLevel::Shared;
357 } else {
358 self.lock_level = LockLevel::Reserved;
359 }
360 debug!(
361 lock_level = ?self.lock_level,
362 concurrent = self.concurrent,
363 "transaction promoted on first write"
364 );
365 }
366 LockLevel::Reserved | LockLevel::Exclusive => {
367 }
369 }
370 }
371 }
372
373 pub fn set_error(&mut self) {
375 if self.state == TxnState::Active {
376 self.state = TxnState::Error;
377 error!("transaction entered error state");
378 }
379 }
380
381 fn find_savepoint(&self, name: &str) -> Result<usize> {
387 for (i, sp) in self.savepoints.iter().enumerate().rev() {
388 if sp.name.eq_ignore_ascii_case(name) {
389 return Ok(i);
390 }
391 }
392 Err(FrankenError::internal(format!("no such savepoint: {name}")))
393 }
394
395 fn reset(&mut self) {
397 self.state = TxnState::Idle;
398 self.mode = None;
399 self.lock_level = LockLevel::None;
400 self.savepoints.clear();
401 self.write_set.clear();
402 self.concurrent = false;
403 self.implicit_txn = false;
404 }
405}
406
407impl Default for TransactionController {
408 fn default() -> Self {
409 Self::new()
410 }
411}
412
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_begin_deferred() {
        let mut tc = TransactionController::new();
        tc.begin(Some(TransactionMode::Deferred)).unwrap();
        assert_eq!(tc.state(), TxnState::Active);
        assert_eq!(tc.lock_level(), LockLevel::None);
        assert!(matches!(tc.mode(), Some(TransactionMode::Deferred)));
        assert!(!tc.is_concurrent());
    }

    #[test]
    fn test_begin_immediate() {
        let mut tc = TransactionController::new();
        tc.begin(Some(TransactionMode::Immediate)).unwrap();
        assert_eq!(tc.state(), TxnState::Active);
        assert_eq!(tc.lock_level(), LockLevel::Reserved);
        assert!(matches!(tc.mode(), Some(TransactionMode::Immediate)));
    }

    #[test]
    fn test_begin_exclusive() {
        let mut tc = TransactionController::new();
        tc.begin(Some(TransactionMode::Exclusive)).unwrap();
        assert_eq!(tc.state(), TxnState::Active);
        assert_eq!(tc.lock_level(), LockLevel::Exclusive);
        assert!(matches!(tc.mode(), Some(TransactionMode::Exclusive)));
    }

    #[test]
    fn test_begin_concurrent() {
        let mut tc = TransactionController::new();
        tc.begin(Some(TransactionMode::Concurrent)).unwrap();
        assert_eq!(tc.state(), TxnState::Active);
        assert!(tc.is_concurrent());
        assert_eq!(tc.lock_level(), LockLevel::Shared);
        assert!(matches!(tc.mode(), Some(TransactionMode::Concurrent)));
    }

    #[test]
    fn test_concurrent_no_conflict() {
        let mut tc1 = TransactionController::new();
        let mut tc2 = TransactionController::new();

        tc1.begin(Some(TransactionMode::Concurrent)).unwrap();
        tc2.begin(Some(TransactionMode::Concurrent)).unwrap();

        // Concurrent writers stay at SHARED instead of escalating to RESERVED.
        tc1.promote_on_write();
        assert_eq!(tc1.lock_level(), LockLevel::Shared);
        tc1.record_write(1, vec![0xAA; 4096]);

        tc2.promote_on_write();
        assert_eq!(tc2.lock_level(), LockLevel::Shared);
        tc2.record_write(2, vec![0xBB; 4096]);

        tc1.commit().unwrap();
        tc2.commit().unwrap();
        assert_eq!(tc1.state(), TxnState::Idle);
        assert_eq!(tc2.state(), TxnState::Idle);
    }

    #[test]
    fn test_concurrent_page_conflict() {
        let mut tc1 = TransactionController::new();
        let mut tc2 = TransactionController::new();

        tc1.begin(Some(TransactionMode::Concurrent)).unwrap();
        tc2.begin(Some(TransactionMode::Concurrent)).unwrap();

        assert!(tc1.is_concurrent());
        assert!(tc2.is_concurrent());

        tc1.record_write(1, vec![0xAA; 4096]);
        tc2.record_write(1, vec![0xBB; 4096]);

        // The controller only tracks per-transaction state and performs no
        // cross-transaction conflict detection, so both commits succeed here.
        tc1.commit().unwrap();
        tc2.commit().unwrap();
        assert_eq!(tc1.state(), TxnState::Idle);
        assert_eq!(tc2.state(), TxnState::Idle);
    }

    #[test]
    fn test_commit_end_synonym() {
        let mut tc = TransactionController::new();
        tc.begin(None).unwrap();
        assert_eq!(tc.state(), TxnState::Active);
        // A bare BEGIN resolves to DEFERRED.
        assert!(matches!(tc.mode(), Some(TransactionMode::Deferred)));
        tc.commit().unwrap();
        assert_eq!(tc.state(), TxnState::Idle);
        assert!(tc.mode().is_none());
    }

    #[test]
    fn test_commit_and_rollback_require_transaction() {
        let mut tc = TransactionController::new();
        assert!(tc.commit().is_err());
        assert!(tc.rollback().is_err());
        assert_eq!(tc.state(), TxnState::Idle);
    }

    #[test]
    fn test_rollback() {
        let mut tc = TransactionController::new();
        tc.begin(Some(TransactionMode::Immediate)).unwrap();
        tc.record_write(1, vec![0xAA; 100]);
        tc.rollback().unwrap();
        assert_eq!(tc.state(), TxnState::Idle);
        assert_eq!(tc.lock_level(), LockLevel::None);
        assert!(tc.mode().is_none());
        assert!(tc.write_set.is_empty());
        assert_eq!(tc.savepoint_depth(), 0);
    }

    #[test]
    fn test_savepoint_basic() {
        let mut tc = TransactionController::new();
        tc.begin(Some(TransactionMode::Deferred)).unwrap();
        tc.savepoint("sp1".to_owned()).unwrap();
        assert_eq!(tc.savepoint_depth(), 1);
    }

    #[test]
    fn test_savepoint_release() {
        let mut tc = TransactionController::new();
        tc.begin(Some(TransactionMode::Immediate)).unwrap();
        tc.savepoint("sp1".to_owned()).unwrap();
        tc.record_write(1, vec![0xAA; 100]);
        tc.release("sp1").unwrap();
        assert_eq!(tc.savepoint_depth(), 0);
        // RELEASE keeps the writes and, for an explicit transaction, keeps it open.
        assert!(tc.write_set.contains_key(&1));
        assert_eq!(tc.state(), TxnState::Active);
    }

    #[test]
    fn test_savepoint_release_removes_later() {
        let mut tc = TransactionController::new();
        tc.begin(Some(TransactionMode::Immediate)).unwrap();
        tc.savepoint("sp1".to_owned()).unwrap();
        tc.savepoint("sp2".to_owned()).unwrap();
        tc.savepoint("sp3".to_owned()).unwrap();
        assert_eq!(tc.savepoint_depth(), 3);

        tc.release("sp1").unwrap();
        assert_eq!(tc.savepoint_depth(), 0);
    }

    #[test]
    fn test_savepoint_rollback_to() {
        let mut tc = TransactionController::new();
        tc.begin(Some(TransactionMode::Immediate)).unwrap();
        tc.savepoint("sp1".to_owned()).unwrap();
        tc.record_write(1, vec![0xAA; 100]);
        tc.rollback_to("sp1").unwrap();
        assert_eq!(tc.savepoint_depth(), 1);
        // The write made after sp1 must be discarded.
        assert!(tc.write_set.is_empty());
        assert_eq!(tc.state(), TxnState::Active);
    }

    #[test]
    fn test_savepoint_nested() {
        let mut tc = TransactionController::new();
        tc.begin(Some(TransactionMode::Immediate)).unwrap();
        tc.savepoint("sp1".to_owned()).unwrap();
        tc.record_write(1, vec![0x01; 16]);
        tc.savepoint("sp2".to_owned()).unwrap();
        tc.record_write(2, vec![0x02; 16]);
        tc.savepoint("sp3".to_owned()).unwrap();
        tc.record_write(3, vec![0x03; 16]);
        assert_eq!(tc.savepoint_depth(), 3);

        tc.rollback_to("sp2").unwrap();
        assert_eq!(tc.savepoint_depth(), 2);
        // Only the write made before sp2 survives.
        assert!(tc.write_set.contains_key(&1));
        assert!(!tc.write_set.contains_key(&2));
        assert!(!tc.write_set.contains_key(&3));
    }

    #[test]
    fn test_savepoint_rollback_then_continue() {
        let mut tc = TransactionController::new();
        tc.begin(Some(TransactionMode::Immediate)).unwrap();
        tc.savepoint("sp1".to_owned()).unwrap();
        tc.record_write(1, vec![0xAA; 100]);
        tc.rollback_to("sp1").unwrap();

        tc.record_write(2, vec![0xBB; 100]);
        assert!(!tc.write_set.contains_key(&1));
        assert!(tc.write_set.contains_key(&2));
        tc.commit().unwrap();
        assert_eq!(tc.state(), TxnState::Idle);
    }

    #[test]
    fn test_savepoint_names_case_insensitive() {
        let mut tc = TransactionController::new();
        tc.begin(None).unwrap();
        tc.savepoint("Sp1".to_owned()).unwrap();
        tc.rollback_to("sp1").unwrap();
        assert_eq!(tc.savepoint_depth(), 1);
        tc.release("SP1").unwrap();
        assert_eq!(tc.savepoint_depth(), 0);
        assert_eq!(tc.state(), TxnState::Active);
    }

    #[test]
    fn test_unknown_savepoint_is_error() {
        let mut idle = TransactionController::new();
        assert!(idle.release("nope").is_err());
        assert!(idle.rollback_to("nope").is_err());

        let mut tc = TransactionController::new();
        tc.begin(None).unwrap();
        tc.savepoint("sp1".to_owned()).unwrap();
        assert!(tc.release("nope").is_err());
        assert!(tc.rollback_to("nope").is_err());
        assert_eq!(tc.savepoint_depth(), 1);
        assert_eq!(tc.state(), TxnState::Active);
    }

    #[test]
    fn test_duplicate_savepoint_names_target_most_recent() {
        let mut tc = TransactionController::new();
        tc.begin(Some(TransactionMode::Immediate)).unwrap();
        tc.savepoint("a".to_owned()).unwrap();
        tc.record_write(1, vec![0x01; 16]);
        tc.savepoint("a".to_owned()).unwrap();
        tc.record_write(2, vec![0x02; 16]);

        tc.rollback_to("a").unwrap();
        assert_eq!(tc.savepoint_depth(), 2);
        assert!(tc.write_set.contains_key(&1));
        assert!(!tc.write_set.contains_key(&2));
    }

    #[test]
    fn test_deferred_lock_promotion() {
        let mut tc = TransactionController::new();
        tc.begin(Some(TransactionMode::Deferred)).unwrap();
        assert_eq!(tc.lock_level(), LockLevel::None);

        tc.promote_on_read();
        assert_eq!(tc.lock_level(), LockLevel::Shared);

        tc.promote_on_write();
        assert_eq!(tc.lock_level(), LockLevel::Reserved);

        // A later read must not downgrade the lock.
        tc.promote_on_read();
        assert_eq!(tc.lock_level(), LockLevel::Reserved);
    }

    #[test]
    fn test_error_state_requires_rollback() {
        let mut tc = TransactionController::new();
        tc.begin(None).unwrap();
        tc.set_error();
        assert_eq!(tc.state(), TxnState::Error);

        assert!(tc.commit().is_err());
        assert_eq!(tc.state(), TxnState::Error);

        tc.rollback().unwrap();
        assert_eq!(tc.state(), TxnState::Idle);
    }

    #[test]
    fn test_begin_within_transaction() {
        let mut tc = TransactionController::new();
        tc.begin(None).unwrap();
        assert!(tc.begin(None).is_err());
        // The failed BEGIN leaves the open transaction untouched.
        assert_eq!(tc.state(), TxnState::Active);
        assert_eq!(tc.lock_level(), LockLevel::None);
    }

    #[test]
    fn test_savepoint_starts_transaction() {
        let mut tc = TransactionController::new();
        assert_eq!(tc.state(), TxnState::Idle);
        tc.savepoint("sp1".to_owned()).unwrap();
        assert_eq!(tc.state(), TxnState::Active);
        assert_eq!(tc.savepoint_depth(), 1);
        tc.release("sp1").unwrap();
        assert_eq!(tc.state(), TxnState::Idle);
    }

    #[test]
    fn test_implicit_transaction_survives_rollback_to() {
        let mut tc = TransactionController::new();
        tc.savepoint("sp1".to_owned()).unwrap();
        tc.record_write(1, vec![0xAA; 16]);
        tc.rollback_to("sp1").unwrap();
        assert_eq!(tc.state(), TxnState::Active);
        assert_eq!(tc.savepoint_depth(), 1);
        assert!(tc.write_set.is_empty());

        // Releasing the outermost savepoint commits the implicit transaction.
        tc.release("sp1").unwrap();
        assert_eq!(tc.state(), TxnState::Idle);
    }

    #[test]
    fn test_savepoint_explicit_transaction_no_commit_on_release() {
        let mut tc = TransactionController::new();
        tc.begin(Some(TransactionMode::Deferred)).unwrap();
        tc.savepoint("sp1".to_owned()).unwrap();
        assert_eq!(tc.state(), TxnState::Active);
        tc.release("sp1").unwrap();
        assert_eq!(tc.state(), TxnState::Active);
        tc.commit().unwrap();
        assert_eq!(tc.state(), TxnState::Idle);
    }

    #[test]
    fn test_rollback_to_clears_error() {
        let mut tc = TransactionController::new();
        tc.begin(None).unwrap();
        tc.savepoint("sp1".to_owned()).unwrap();
        tc.set_error();
        assert_eq!(tc.state(), TxnState::Error);
        tc.rollback_to("sp1").unwrap();
        assert_eq!(tc.state(), TxnState::Active);
    }
}