use std::collections::HashMap;
use std::sync::Mutex;
use uuid::Uuid;
use crate::error::{ClawDBError, ClawDBResult};
/// Bookkeeping for one active transaction tracked by `TransactionCoordinator`.
#[derive(Debug, Clone)]
struct TxRecord {
/// Session that registered the transaction; reported by `stale_transactions`.
session_id: Uuid,
/// Keys the transaction has written, appended by `extend_write_set` (not
/// deduplicated) and compared against other transactions' write sets in `prepare`.
write_set: Vec<String>,
/// Time of registration; `stale_transactions` compares its elapsed time to a threshold.
started_at: std::time::Instant,
}
/// Tracks active transactions and their write sets so that `prepare` can reject
/// a transaction whose writes overlap those of another active transaction.
pub struct TransactionCoordinator {
/// Active transactions keyed by transaction id. Guarded by a `Mutex` so every
/// method can operate through `&self`.
active: Mutex<HashMap<Uuid, TxRecord>>,
}
impl TransactionCoordinator {
pub fn new() -> Self {
Self {
active: Mutex::new(HashMap::new()),
}
}
pub fn register(&self, tx_id: Uuid, session_id: Uuid) {
let mut map = self.active.lock().expect("coordinator lock poisoned");
map.insert(
tx_id,
TxRecord {
session_id,
write_set: Vec::new(),
started_at: std::time::Instant::now(),
},
);
}
pub fn extend_write_set(&self, tx_id: Uuid, keys: impl IntoIterator<Item = String>) {
let mut map = self.active.lock().expect("coordinator lock poisoned");
if let Some(record) = map.get_mut(&tx_id) {
record.write_set.extend(keys);
}
}
pub fn prepare(&self, tx_id: Uuid, additional_writes: &[String]) -> ClawDBResult<()> {
let map = self.active.lock().expect("coordinator lock poisoned");
let my_writes: Vec<&str> = map
.get(&tx_id)
.map(|r| r.write_set.iter().map(|s| s.as_str()).collect())
.unwrap_or_default();
for (other_id, other) in map.iter() {
if *other_id == tx_id {
continue;
}
let conflict = my_writes
.iter()
.chain(additional_writes.iter().map(|s| s.as_str()).collect::<Vec<_>>().iter())
.any(|k| other.write_set.iter().any(|w| w == k));
if conflict {
return Err(ClawDBError::TransactionConflict {
tx_id,
conflicting_tx: *other_id,
});
}
}
Ok(())
}
pub fn check_conflicts(&self, tx_id: Uuid, write_set: &[String]) -> ClawDBResult<()> {
self.prepare(tx_id, write_set)
}
pub fn deregister(&self, tx_id: Uuid) {
self.active
.lock()
.expect("coordinator lock poisoned")
.remove(&tx_id);
}
pub fn active_count(&self) -> usize {
self.active.lock().expect("coordinator lock poisoned").len()
}
pub fn stale_transactions(
&self,
threshold: std::time::Duration,
) -> Vec<(Uuid, Uuid)> {
let map = self.active.lock().expect("coordinator lock poisoned");
map.iter()
.filter(|(_, r)| r.started_at.elapsed() > threshold)
.map(|(id, r)| (*id, r.session_id))
.collect()
}
}
impl Default for TransactionCoordinator {
fn default() -> Self {
Self::new()
}
}
pub fn active_count(&self) -> usize {
self.active.lock().expect("coordinator lock poisoned").len()
}
}
impl Default for TransactionCoordinator {
fn default() -> Self {
Self::new()
}
}