Skip to main content

aurora_db/
transaction.rs

1use crate::error::{AqlError, Result};
2use dashmap::DashMap;
3use std::sync::Arc;
4use std::sync::atomic::{AtomicU64, Ordering};
5
/// Process-wide counter backing [`TransactionId::new`].
///
/// Starts at 1 and is advanced by one on every call to `new`.
static TRANSACTION_ID_COUNTER: AtomicU64 = AtomicU64::new(1);

/// Identifier of a single transaction, wrapping a raw `u64`.
///
/// Ids produced by [`TransactionId::new`] (and therefore by `Default`) are
/// drawn from a shared atomic counter; ids built with
/// [`TransactionId::from_u64`] bypass that counter entirely.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TransactionId(u64);

impl TransactionId {
    /// Allocates the next id from the process-wide counter.
    pub fn new() -> Self {
        let raw = TRANSACTION_ID_COUNTER.fetch_add(1, Ordering::SeqCst);
        TransactionId(raw)
    }

    /// Wraps an existing raw id without touching the counter.
    ///
    /// NOTE(review): nothing prevents the result from coinciding with an id
    /// that `new` has handed out (or will hand out).
    pub fn from_u64(id: u64) -> Self {
        TransactionId(id)
    }

    /// Returns the raw numeric value of this id.
    pub fn as_u64(&self) -> u64 {
        let TransactionId(raw) = *self;
        raw
    }
}

impl Default for TransactionId {
    /// Same as [`TransactionId::new`]: every default value is a fresh id.
    fn default() -> Self {
        TransactionId::new()
    }
}
30
31#[derive(Debug, Clone)]
32pub struct TransactionBuffer {
33    pub id: TransactionId,
34    pub writes: DashMap<String, Vec<u8>>,
35    pub deletes: DashMap<String, ()>,
36}
37
38impl TransactionBuffer {
39    pub fn new(id: TransactionId) -> Self {
40        Self {
41            id,
42            writes: DashMap::new(),
43            deletes: DashMap::new(),
44        }
45    }
46
47    pub fn write(&self, key: String, value: Vec<u8>) {
48        self.deletes.remove(&key);
49        self.writes.insert(key, value);
50    }
51
52    pub fn delete(&self, key: String) {
53        self.writes.remove(&key);
54        self.deletes.insert(key, ());
55    }
56
57    pub fn is_empty(&self) -> bool {
58        self.writes.is_empty() && self.deletes.is_empty()
59    }
60}
61
62pub struct TransactionManager {
63    pub active_transactions: Arc<DashMap<TransactionId, Arc<TransactionBuffer>>>,
64}
65
66impl TransactionManager {
67    pub fn new() -> Self {
68        Self {
69            active_transactions: Arc::new(DashMap::new()),
70        }
71    }
72
73    pub fn begin(&self) -> Arc<TransactionBuffer> {
74        let tx_id = TransactionId::new();
75        let buffer = Arc::new(TransactionBuffer::new(tx_id));
76        self.active_transactions.insert(tx_id, Arc::clone(&buffer));
77        buffer
78    }
79
80    pub fn commit(&self, tx_id: TransactionId) -> Result<()> {
81        if !self.active_transactions.contains_key(&tx_id) {
82            return Err(AqlError::invalid_operation(
83                "Transaction not found or already committed",
84            ));
85        }
86
87        self.active_transactions.remove(&tx_id);
88        Ok(())
89    }
90
91    pub fn rollback(&self, tx_id: TransactionId) -> Result<()> {
92        if !self.active_transactions.contains_key(&tx_id) {
93            return Err(AqlError::invalid_operation(
94                "Transaction not found or already rolled back",
95            ));
96        }
97
98        self.active_transactions.remove(&tx_id);
99        Ok(())
100    }
101
102    pub fn is_active(&self, tx_id: TransactionId) -> bool {
103        self.active_transactions.contains_key(&tx_id)
104    }
105
106    pub fn active_count(&self) -> usize {
107        self.active_transactions.len()
108    }
109}
110
111impl Clone for TransactionManager {
112    fn clone(&self) -> Self {
113        Self {
114            active_transactions: Arc::clone(&self.active_transactions),
115        }
116    }
117}
118
119impl Default for TransactionManager {
120    fn default() -> Self {
121        Self::new()
122    }
123}
124
#[cfg(test)]
mod tests {
    use super::*;

    /// Two concurrent transactions get distinct ids and independent buffers,
    /// and committing/rolling back shrinks the active set.
    #[test]
    fn test_transaction_isolation() {
        let manager = TransactionManager::new();

        let tx1 = manager.begin();
        let tx2 = manager.begin();

        assert_ne!(tx1.id, tx2.id);
        assert_eq!(manager.active_count(), 2);

        tx1.write("key1".to_string(), b"value1".to_vec());
        tx2.write("key1".to_string(), b"value2".to_vec());

        assert_eq!(tx1.writes.get("key1").unwrap().as_slice(), b"value1");
        assert_eq!(tx2.writes.get("key1").unwrap().as_slice(), b"value2");

        manager.commit(tx1.id).unwrap();
        assert_eq!(manager.active_count(), 1);

        manager.rollback(tx2.id).unwrap();
        assert_eq!(manager.active_count(), 0);
    }

    /// Ending a transaction twice, or ending an unknown one, is an error.
    #[test]
    fn test_commit_and_rollback_reject_inactive_transactions() {
        let manager = TransactionManager::new();
        let tx = manager.begin();
        assert!(manager.is_active(tx.id));

        manager.commit(tx.id).unwrap();
        assert!(!manager.is_active(tx.id));
        assert!(manager.commit(tx.id).is_err());
        assert!(manager.rollback(tx.id).is_err());

        // An id that was never registered with this manager.
        let unknown = TransactionId::from_u64(u64::MAX);
        assert!(manager.commit(unknown).is_err());
        assert!(manager.rollback(unknown).is_err());
    }

    /// The latest write/delete for a key wins, and the key is tracked in
    /// only one of the two maps.
    #[test]
    fn test_write_and_delete_supersede_each_other() {
        let manager = TransactionManager::new();
        let tx = manager.begin();
        assert!(tx.is_empty());

        tx.write("k".to_string(), b"v1".to_vec());
        tx.delete("k".to_string());
        assert!(!tx.writes.contains_key("k"));
        assert!(tx.deletes.contains_key("k"));

        tx.write("k".to_string(), b"v2".to_vec());
        assert!(!tx.deletes.contains_key("k"));
        assert_eq!(tx.writes.get("k").unwrap().as_slice(), b"v2");
        assert!(!tx.is_empty());
    }

    /// Clones of a manager share one registry.
    #[test]
    fn test_clones_share_registry() {
        let manager = TransactionManager::new();
        let handle = manager.clone();

        let tx = manager.begin();
        assert!(handle.is_active(tx.id));

        handle.rollback(tx.id).unwrap();
        assert_eq!(manager.active_count(), 0);
    }
}