// aurora_db/transaction.rs

1use crate::error::{AqlError, Result};
2use dashmap::DashMap;
3use std::sync::Arc;
4use std::sync::atomic::{AtomicU64, Ordering};
5use tokio::task_local;
6
/// Process-wide source of fresh values for [`TransactionId::new`].
///
/// Seeded with 1, so ids produced by `new` are never 0 until the counter wraps around.
static TRANSACTION_ID_COUNTER: AtomicU64 = AtomicU64::new(1);
8
9task_local! {
10    pub static ACTIVE_TRANSACTION_ID: TransactionId;
11}
12
13#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
14pub struct TransactionId(u64);
15
16impl Default for TransactionId {
17    fn default() -> Self {
18        Self::new()
19    }
20}
21
22impl std::fmt::Display for TransactionId {
23    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24        write!(f, "tx:{}", self.0)
25    }
26}
27
28impl TransactionId {
29    pub fn new() -> Self {
30        Self(TRANSACTION_ID_COUNTER.fetch_add(1, Ordering::SeqCst))
31    }
32
33    pub fn from_u64(id: u64) -> Self {
34        Self(id)
35    }
36
37    pub fn as_u64(&self) -> u64 {
38        self.0
39    }
40}
41
42#[derive(Debug, Clone)]
43pub struct TransactionBuffer {
44    pub id: TransactionId,
45    pub writes: DashMap<String, Vec<u8>>,
46    pub deletes: DashMap<String, ()>,
47    pub events: Arc<tokio::sync::Mutex<Vec<crate::pubsub::ChangeEvent>>>,
48}
49
50impl TransactionBuffer {
51    pub fn new(id: TransactionId) -> Self {
52        Self {
53            id,
54            writes: DashMap::new(),
55            deletes: DashMap::new(),
56            events: Arc::new(tokio::sync::Mutex::new(Vec::new())),
57        }
58    }
59
60    pub fn write(&self, key: String, value: Vec<u8>) {
61        self.deletes.remove(&key);
62        self.writes.insert(key, value);
63    }
64
65    pub fn read(&self, key: &str) -> Option<Vec<u8>> {
66        if self.deletes.contains_key(key) {
67            return None;
68        }
69        self.writes.get(key).map(|v| v.value().clone())
70    }
71
72    pub fn delete(&self, key: String) {
73        self.writes.remove(&key);
74        self.deletes.insert(key, ());
75    }
76
77    pub fn is_empty(&self) -> bool {
78        self.writes.is_empty() && self.deletes.is_empty()
79    }
80}
81
82pub struct TransactionManager {
83    pub active_transactions: Arc<DashMap<TransactionId, Arc<TransactionBuffer>>>,
84}
85
86impl TransactionManager {
87    pub fn new() -> Self {
88        Self {
89            active_transactions: Arc::new(DashMap::new()),
90        }
91    }
92
93    pub fn begin(&self) -> Arc<TransactionBuffer> {
94        let tx_id = TransactionId::new();
95        let buffer = Arc::new(TransactionBuffer::new(tx_id));
96        self.active_transactions.insert(tx_id, Arc::clone(&buffer));
97        buffer
98    }
99
100    pub fn commit(&self, tx_id: TransactionId) -> Result<()> {
101        if !self.active_transactions.contains_key(&tx_id) {
102            return Err(AqlError::invalid_operation(
103                "Transaction not found or already committed",
104            ));
105        }
106
107        self.active_transactions.remove(&tx_id);
108        Ok(())
109    }
110
111    pub fn rollback(&self, tx_id: TransactionId) -> Result<()> {
112        if !self.active_transactions.contains_key(&tx_id) {
113            return Err(AqlError::invalid_operation(
114                "Transaction not found or already rolled back",
115            ));
116        }
117
118        self.active_transactions.remove(&tx_id);
119        Ok(())
120    }
121
122    pub fn is_active(&self, tx_id: TransactionId) -> bool {
123        self.active_transactions.contains_key(&tx_id)
124    }
125
126    pub fn active_count(&self) -> usize {
127        self.active_transactions.len()
128    }
129}
130
131impl Clone for TransactionManager {
132    fn clone(&self) -> Self {
133        Self {
134            active_transactions: Arc::clone(&self.active_transactions),
135        }
136    }
137}
138
139impl Default for TransactionManager {
140    fn default() -> Self {
141        Self::new()
142    }
143}
144
#[cfg(test)]
mod tests {
    use super::*;

    /// Two transactions get distinct ids and independent buffers.
    #[test]
    fn test_transaction_isolation() {
        let manager = TransactionManager::new();

        let tx1 = manager.begin();
        let tx2 = manager.begin();

        assert_ne!(tx1.id, tx2.id);
        assert_eq!(manager.active_count(), 2);

        tx1.write("key1".to_string(), b"value1".to_vec());
        tx2.write("key1".to_string(), b"value2".to_vec());

        assert_eq!(tx1.writes.get("key1").unwrap().as_slice(), b"value1");
        assert_eq!(tx2.writes.get("key1").unwrap().as_slice(), b"value2");

        manager.commit(tx1.id).unwrap();
        assert_eq!(manager.active_count(), 1);

        manager.rollback(tx2.id).unwrap();
        assert_eq!(manager.active_count(), 0);
    }

    /// A delete hides an earlier write, and a later write clears the delete.
    #[test]
    fn test_buffer_write_delete_read() {
        let manager = TransactionManager::new();
        let tx = manager.begin();
        assert!(tx.is_empty());
        assert_eq!(tx.read("k"), None);

        tx.write("k".to_string(), b"v1".to_vec());
        assert_eq!(tx.read("k"), Some(b"v1".to_vec()));

        tx.delete("k".to_string());
        assert_eq!(tx.read("k"), None);
        assert!(!tx.writes.contains_key("k"));
        // A staged delete is still pending work.
        assert!(!tx.is_empty());

        tx.write("k".to_string(), b"v2".to_vec());
        assert_eq!(tx.read("k"), Some(b"v2".to_vec()));
        assert!(!tx.deletes.contains_key("k"));
    }

    /// Finishing a transaction twice, or one that was never begun, is an error.
    #[test]
    fn test_finishing_inactive_transaction_fails() {
        let manager = TransactionManager::new();
        let tx = manager.begin();
        assert!(manager.is_active(tx.id));

        manager.commit(tx.id).unwrap();
        assert!(!manager.is_active(tx.id));
        assert!(manager.commit(tx.id).is_err());
        assert!(manager.rollback(tx.id).is_err());

        // The id counter starts at 1, so 0 is never handed out by `TransactionId::new`.
        let unknown = TransactionId::from_u64(0);
        assert!(manager.commit(unknown).is_err());
        assert!(manager.rollback(unknown).is_err());
    }

    /// Clones of a manager share one registry.
    #[test]
    fn test_manager_clones_share_state() {
        let manager = TransactionManager::new();
        let handle = manager.clone();

        let tx = manager.begin();
        assert!(handle.is_active(tx.id));

        handle.rollback(tx.id).unwrap();
        assert_eq!(manager.active_count(), 0);
    }

    #[test]
    fn test_transaction_id_display() {
        assert_eq!(TransactionId::from_u64(7).to_string(), "tx:7");
    }
}