use crate::error::{AqlError, Result};
use dashmap::DashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::task_local;
/// Process-wide source of fresh [`TransactionId`]s.
///
/// `TransactionId::new` takes the current value and bumps it, so the first id
/// handed out is 1 and the value 0 is never allocated by `new`.
static TRANSACTION_ID_COUNTER: AtomicU64 = AtomicU64::new(1_u64);
task_local! {
    /// Task-local slot carrying the `TransactionId` associated with the current tokio task.
    ///
    /// Nothing in this module sets or reads it. Presumably callers bind it with
    /// `ACTIVE_TRANSACTION_ID.scope(id, fut)` and read it via `with`/`try_with`; confirm at call sites.
    pub static ACTIVE_TRANSACTION_ID: TransactionId;
}
/// Opaque identifier of a transaction.
///
/// Cheap to copy and usable as a hash-map key. It is rendered as `tx:<n>` by its
/// `Display` impl.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct TransactionId(u64);
impl Default for TransactionId {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Display for TransactionId {
    /// Formats the id as `tx:<n>` (e.g. `tx:42`). Width/fill flags are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("tx:")?;
        write!(f, "{}", self.0)
    }
}
impl TransactionId {
pub fn new() -> Self {
Self(TRANSACTION_ID_COUNTER.fetch_add(1, Ordering::SeqCst))
}
pub fn from_u64(id: u64) -> Self {
Self(id)
}
pub fn as_u64(&self) -> u64 {
self.0
}
}
/// Staging area for the uncommitted changes of a single transaction.
///
/// `TransactionBuffer::write` and `TransactionBuffer::delete` each remove the key
/// from the opposite map before inserting it into their own.
///
/// `events` sits behind an `Arc`, so a cloned buffer shares its event list with the
/// original.
#[derive(Clone, Debug)]
pub struct TransactionBuffer {
    /// Id of the transaction this buffer belongs to.
    pub _sid: TransactionId,
    /// Pending writes: key -> new value bytes.
    pub writes: DashMap<String, Vec<u8>>,
    /// Keys deleted in this transaction (used as a set; the value is always `()`).
    pub deletes: DashMap<String, ()>,
    /// Change events collected for this transaction. Presumably published on commit by
    /// the caller; confirm against the pubsub integration.
    pub events: std::sync::Arc<tokio::sync::Mutex<Vec<crate::pubsub::ChangeEvent>>>,
}
impl TransactionBuffer {
    /// Creates an empty buffer for transaction `_sid`: no writes, no deletes, no events.
    pub fn new(_sid: TransactionId) -> Self {
        let events = Arc::new(tokio::sync::Mutex::new(Vec::new()));
        Self {
            _sid,
            writes: DashMap::new(),
            deletes: DashMap::new(),
            events,
        }
    }

    /// Stages `value` under `key`, cancelling any delete staged earlier for that key.
    pub fn write(&self, key: String, value: Vec<u8>) {
        // Drop the tombstone first so `read` will not mask the new value.
        let _ = self.deletes.remove(&key);
        self.writes.insert(key, value);
    }

    /// Returns a copy of the value staged for `key` in this transaction.
    ///
    /// `None` is returned both when the key was deleted in this transaction and when
    /// it was never written here. The return value alone cannot tell the two apart;
    /// check `deletes` if that matters.
    pub fn read(&self, key: &str) -> Option<Vec<u8>> {
        if self.deletes.contains_key(key) {
            None
        } else {
            self.writes.get(key).map(|entry| entry.value().clone())
        }
    }

    /// Stages a delete of `key`, discarding any write staged earlier for that key.
    pub fn delete(&self, key: String) {
        // Forget the pending value before recording the tombstone.
        let _ = self.writes.remove(&key);
        self.deletes.insert(key, ());
    }

    /// Returns `true` when no writes and no deletes are staged.
    ///
    /// Collected `events` are not considered.
    pub fn is_empty(&self) -> bool {
        let no_writes = self.writes.is_empty();
        no_writes && self.deletes.is_empty()
    }
}
/// Registry of in-flight transactions, keyed by [`TransactionId`].
///
/// The map is held behind an `Arc`, and the manager's `Clone` impl copies only that
/// `Arc`, so every clone of a manager sees the same set of transactions.
pub struct TransactionManager {
    /// Buffers of transactions that have been begun and not yet committed or rolled back.
    pub active_transactions:
        std::sync::Arc<DashMap<TransactionId, std::sync::Arc<TransactionBuffer>>>,
}
impl TransactionManager {
    /// Creates a manager with no active transactions.
    pub fn new() -> Self {
        Self {
            active_transactions: Arc::new(DashMap::new()),
        }
    }

    /// Starts a new transaction and returns its empty buffer.
    ///
    /// The buffer is registered under a freshly allocated [`TransactionId`] (also
    /// available as `buffer._sid`). It stays registered until [`commit`](Self::commit)
    /// or [`rollback`](Self::rollback) is called for that id.
    pub fn begin(&self) -> Arc<TransactionBuffer> {
        let tx_id = TransactionId::new();
        let buffer = Arc::new(TransactionBuffer::new(tx_id));
        self.active_transactions.insert(tx_id, Arc::clone(&buffer));
        buffer
    }

    /// Commits transaction `tx_id` by removing it from the active set.
    ///
    /// This only deregisters the transaction. It neither applies the buffered writes nor
    /// publishes `events`; presumably the caller does that beforehand (confirm).
    ///
    /// # Errors
    /// Returns `AqlError::invalid_operation` if `tx_id` is not active, i.e. it is
    /// unknown or was already committed or rolled back.
    pub fn commit(&self, tx_id: TransactionId) -> Result<()> {
        self.remove_active(tx_id, "Transaction not found or already committed")
    }

    /// Rolls back transaction `tx_id` by removing it from the active set.
    ///
    /// The registry's reference to the buffer is dropped; its staged changes are not
    /// applied anywhere by this method.
    ///
    /// # Errors
    /// Returns `AqlError::invalid_operation` if `tx_id` is not active, i.e. it is
    /// unknown or was already committed or rolled back.
    pub fn rollback(&self, tx_id: TransactionId) -> Result<()> {
        self.remove_active(tx_id, "Transaction not found or already rolled back")
    }

    /// Removes `tx_id` from the active set, failing with `not_found` if it was absent.
    fn remove_active(&self, tx_id: TransactionId, not_found: &'static str) -> Result<()> {
        // A single `remove` is atomic with respect to other callers. A separate
        // `contains_key` check followed by `remove` would let two concurrent
        // commit/rollback calls for the same id both succeed.
        if self.active_transactions.remove(&tx_id).is_none() {
            return Err(AqlError::invalid_operation(not_found));
        }
        Ok(())
    }

    /// Returns `true` if `tx_id` is currently in the active set.
    pub fn is_active(&self, tx_id: TransactionId) -> bool {
        self.active_transactions.contains_key(&tx_id)
    }

    /// Returns the number of transactions currently in the active set.
    pub fn active_count(&self) -> usize {
        self.active_transactions.len()
    }
}
impl Clone for TransactionManager {
fn clone(&self) -> Self {
Self {
active_transactions: Arc::clone(&self.active_transactions),
}
}
}
impl Default for TransactionManager {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_transaction_isolation() {
        let manager = TransactionManager::new();
        let tx1 = manager.begin();
        let tx2 = manager.begin();
        assert_ne!(tx1._sid, tx2._sid);
        assert_eq!(manager.active_count(), 2);
        tx1.write("key1".to_string(), b"value1".to_vec());
        tx2.write("key1".to_string(), b"value2".to_vec());
        assert_eq!(tx1.writes.get("key1").unwrap().as_slice(), b"value1");
        assert_eq!(tx2.writes.get("key1").unwrap().as_slice(), b"value2");
        manager.commit(tx1._sid).unwrap();
        assert_eq!(manager.active_count(), 1);
        manager.rollback(tx2._sid).unwrap();
        assert_eq!(manager.active_count(), 0);
    }

    /// Finishing a transaction twice, or one that was never begun, must fail.
    #[test]
    fn test_commit_and_rollback_reject_inactive_ids() {
        let manager = TransactionManager::new();
        let tx = manager.begin();
        assert!(manager.is_active(tx._sid));
        manager.commit(tx._sid).unwrap();
        assert!(!manager.is_active(tx._sid));
        assert!(manager.commit(tx._sid).is_err());
        assert!(manager.rollback(tx._sid).is_err());
        // The counter starts at 1, so 0 is never handed out by `TransactionId::new`.
        assert!(manager.rollback(TransactionId::from_u64(0)).is_err());
        assert!(manager.commit(TransactionId::from_u64(0)).is_err());
    }

    /// Writes and deletes on the same key cancel each other; the latest one wins.
    #[test]
    fn test_buffer_write_delete_read() {
        let buffer = TransactionBuffer::new(TransactionId::new());
        assert!(buffer.is_empty());
        assert_eq!(buffer.read("k"), None);

        buffer.write("k".to_string(), b"v1".to_vec());
        assert_eq!(buffer.read("k"), Some(b"v1".to_vec()));

        buffer.delete("k".to_string());
        assert_eq!(buffer.read("k"), None);
        assert!(buffer.writes.is_empty());
        assert!(!buffer.is_empty());

        buffer.write("k".to_string(), b"v2".to_vec());
        assert_eq!(buffer.read("k"), Some(b"v2".to_vec()));
        assert!(buffer.deletes.is_empty());
    }

    /// Clones of a manager share one registry.
    #[test]
    fn test_clone_shares_registry() {
        let manager = TransactionManager::new();
        let handle = manager.clone();
        let tx = manager.begin();
        assert!(handle.is_active(tx._sid));
        handle.rollback(tx._sid).unwrap();
        assert_eq!(manager.active_count(), 0);
    }

    #[test]
    fn test_transaction_id_display_and_roundtrip() {
        let id = TransactionId::from_u64(42);
        assert_eq!(id.as_u64(), 42);
        assert_eq!(id.to_string(), "tx:42");
    }
}