1use crate::error::{AqlError, Result};
2use dashmap::DashMap;
3use std::sync::Arc;
4use std::sync::atomic::{AtomicU64, Ordering};
5use tokio::task_local;
6
/// Process-wide source of fresh transaction ids.
///
/// Starts at 1 and is advanced by one on every `TransactionId::new` call, so an
/// id of 0 is never produced by `new` (it can only come from `from_u64`).
static TRANSACTION_ID_COUNTER: AtomicU64 = AtomicU64::new(1);
8
task_local! {
    /// Tokio task-local slot holding a `TransactionId` for the current task.
    ///
    /// NOTE(review): nothing in this file sets or reads it; presumably callers
    /// run transactional work inside `ACTIVE_TRANSACTION_ID.scope(id, fut)` —
    /// confirm against the call sites.
    pub static ACTIVE_TRANSACTION_ID: TransactionId;
}
12
/// Identifier of a transaction; a thin, `Copy` newtype over `u64`.
///
/// Fresh ids come from `TransactionId::new`, which draws from a process-wide
/// atomic counter. `TransactionId::from_u64` wraps an arbitrary raw value, so
/// ids built that way carry no uniqueness guarantee. Displays as `tx:<n>`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TransactionId(u64);
15
16impl Default for TransactionId {
17 fn default() -> Self {
18 Self::new()
19 }
20}
21
22impl std::fmt::Display for TransactionId {
23 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24 write!(f, "tx:{}", self.0)
25 }
26}
27
28impl TransactionId {
29 pub fn new() -> Self {
30 Self(TRANSACTION_ID_COUNTER.fetch_add(1, Ordering::SeqCst))
31 }
32
33 pub fn from_u64(id: u64) -> Self {
34 Self(id)
35 }
36
37 pub fn as_u64(&self) -> u64 {
38 self.0
39 }
40}
41
/// Per-transaction staging area for writes, deletes and change events.
///
/// `write` and `delete` of the same key supersede each other (each removes the
/// key from the other map), and `read` reports a key with a staged delete as
/// `None`.
///
/// Cloning: the derived `Clone` clones `writes`/`deletes` through `DashMap`'s own
/// `Clone` impl (presumably an independent copy), but `events` is an `Arc`, so a
/// clone shares its event list with the original.
/// NOTE(review): confirm that this mixed sharing is intended.
#[derive(Debug, Clone)]
pub struct TransactionBuffer {
    /// Id of the owning transaction (read by the unit tests in this file).
    pub _sid: TransactionId,
    /// Staged writes: key -> new value.
    pub writes: DashMap<String, Vec<u8>>,
    /// Keys with a staged delete; used as a set (the `()` value carries nothing).
    pub deletes: DashMap<String, ()>,
    /// Change events for this transaction, behind an async (tokio) mutex.
    /// NOTE(review): nothing in this file pushes to or drains this list —
    /// presumably filled by callers and published on commit; confirm.
    pub events: Arc<tokio::sync::Mutex<Vec<crate::pubsub::ChangeEvent>>>,
}
49
50impl TransactionBuffer {
51 pub fn new(_sid: TransactionId) -> Self {
52 Self {
53 _sid,
54 writes: DashMap::new(),
55 deletes: DashMap::new(),
56 events: Arc::new(tokio::sync::Mutex::new(Vec::new())),
57 }
58 }
59
60 pub fn write(&self, key: String, value: Vec<u8>) {
61 self.deletes.remove(&key);
62 self.writes.insert(key, value);
63 }
64
65 pub fn read(&self, key: &str) -> Option<Vec<u8>> {
66 if self.deletes.contains_key(key) {
67 return None;
68 }
69 self.writes.get(key).map(|v| v.value().clone())
70 }
71
72 pub fn delete(&self, key: String) {
73 self.writes.remove(&key);
74 self.deletes.insert(key, ());
75 }
76
77 pub fn is_empty(&self) -> bool {
78 self.writes.is_empty() && self.deletes.is_empty()
79 }
80}
81
/// Registry of in-flight transactions, keyed by id.
///
/// The map sits behind an `Arc`, and `Clone` only clones that `Arc`, so every
/// clone of a manager sees the same set of active transactions.
pub struct TransactionManager {
    /// Buffers of transactions that have been begun but not yet committed or
    /// rolled back.
    pub active_transactions: Arc<DashMap<TransactionId, Arc<TransactionBuffer>>>,
}
85
86impl TransactionManager {
87 pub fn new() -> Self {
88 Self {
89 active_transactions: Arc::new(DashMap::new()),
90 }
91 }
92
93 pub fn begin(&self) -> Arc<TransactionBuffer> {
94 let tx_id = TransactionId::new();
95 let buffer = Arc::new(TransactionBuffer::new(tx_id));
96 self.active_transactions.insert(tx_id, Arc::clone(&buffer));
97 buffer
98 }
99
100 pub fn commit(&self, tx_id: TransactionId) -> Result<()> {
101 if !self.active_transactions.contains_key(&tx_id) {
102 return Err(AqlError::invalid_operation(
103 "Transaction not found or already committed",
104 ));
105 }
106
107 self.active_transactions.remove(&tx_id);
108 Ok(())
109 }
110
111 pub fn rollback(&self, tx_id: TransactionId) -> Result<()> {
112 if !self.active_transactions.contains_key(&tx_id) {
113 return Err(AqlError::invalid_operation(
114 "Transaction not found or already rolled back",
115 ));
116 }
117
118 self.active_transactions.remove(&tx_id);
119 Ok(())
120 }
121
122 pub fn is_active(&self, tx_id: TransactionId) -> bool {
123 self.active_transactions.contains_key(&tx_id)
124 }
125
126 pub fn active_count(&self) -> usize {
127 self.active_transactions.len()
128 }
129}
130
131impl Clone for TransactionManager {
132 fn clone(&self) -> Self {
133 Self {
134 active_transactions: Arc::clone(&self.active_transactions),
135 }
136 }
137}
138
139impl Default for TransactionManager {
140 fn default() -> Self {
141 Self::new()
142 }
143}
144
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_transaction_isolation() {
        let manager = TransactionManager::new();

        let tx1 = manager.begin();
        let tx2 = manager.begin();

        assert_ne!(tx1._sid, tx2._sid);
        assert_eq!(manager.active_count(), 2);

        tx1.write("key1".to_string(), b"value1".to_vec());
        tx2.write("key1".to_string(), b"value2".to_vec());

        assert_eq!(tx1.writes.get("key1").unwrap().as_slice(), b"value1");
        assert_eq!(tx2.writes.get("key1").unwrap().as_slice(), b"value2");

        manager.commit(tx1._sid).unwrap();
        assert_eq!(manager.active_count(), 1);

        manager.rollback(tx2._sid).unwrap();
        assert_eq!(manager.active_count(), 0);
    }

    /// A finished or never-begun transaction can be neither committed nor
    /// rolled back.
    #[test]
    fn test_commit_and_rollback_reject_inactive_ids() {
        let manager = TransactionManager::new();

        let tx = manager.begin();
        let id = tx._sid;
        assert!(manager.is_active(id));
        manager.commit(id).unwrap();
        assert!(!manager.is_active(id));
        assert!(manager.commit(id).is_err());
        assert!(manager.rollback(id).is_err());

        // The counter starts at 1, so `new()` never hands out 0.
        let unknown = TransactionId::from_u64(0);
        assert!(manager.commit(unknown).is_err());
        assert!(manager.rollback(unknown).is_err());
        assert_eq!(manager.active_count(), 0);
    }

    /// Writes and deletes of the same key supersede each other.
    #[test]
    fn test_buffer_write_delete_read() {
        let buffer = TransactionBuffer::new(TransactionId::new());
        assert!(buffer.is_empty());
        assert_eq!(buffer.read("k"), None);

        buffer.write("k".to_string(), b"v1".to_vec());
        assert_eq!(buffer.read("k"), Some(b"v1".to_vec()));

        buffer.delete("k".to_string());
        assert_eq!(buffer.read("k"), None);
        assert!(!buffer.writes.contains_key("k"));
        assert!(!buffer.is_empty());

        buffer.write("k".to_string(), b"v2".to_vec());
        assert_eq!(buffer.read("k"), Some(b"v2".to_vec()));
        assert!(!buffer.deletes.contains_key("k"));
    }

    #[test]
    fn test_transaction_id_display_and_uniqueness() {
        assert_eq!(TransactionId::from_u64(42).to_string(), "tx:42");
        assert_eq!(TransactionId::from_u64(42).as_u64(), 42);
        assert_ne!(TransactionId::new(), TransactionId::new());
    }
}