use std::collections::HashMap;
/// Access mode requested when a transaction begins.
///
/// This is a small fieldless value type, so it is `Copy` and `Hash` as well as `Eq`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum TransactionMode {
    /// Read-only: the manager rejects `add_write` with `TxError::ReadOnly`.
    Read,
    /// May buffer writes.
    Write,
    /// May buffer writes; `allows_writes` treats it exactly like `Write`.
    ReadWrite,
}

impl TransactionMode {
    /// Returns `true` for modes that may buffer writes (`Write` and `ReadWrite`).
    #[must_use]
    pub fn allows_writes(&self) -> bool {
        matches!(self, TransactionMode::Write | TransactionMode::ReadWrite)
    }
}
/// Lifecycle state of a transaction.
///
/// Only `Active` transactions accept writes, commits or aborts. The manager never moves a
/// transaction out of any other state, so those states are terminal.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum TransactionState {
    /// Started and not yet finished.
    Active,
    /// Finished by a successful commit.
    Committed,
    /// Finished by an explicit abort.
    Aborted,
    /// Expired by a timeout sweep.
    TimedOut,
}
/// One buffered mutation recorded by a write-capable transaction and returned by `commit`.
///
/// The positional `String` fields are not named in this module. The tests fill them as
/// `("s", "p", "o", graph)`, which suggests subject / predicate / object plus an optional
/// graph name. NOTE(review): confirm against whatever consumes the committed ops.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WriteOp {
    /// Add an entry: three string components plus an optional scope name.
    Insert(String, String, String, Option<String>),
    /// Remove an entry: three string components plus an optional scope name.
    Delete(String, String, String, Option<String>),
    /// Clear data, optionally limited to the named scope.
    Clear(Option<String>),
}
/// A single transaction tracked by `TransactionManager`.
///
/// The lifecycle fields are public for inspection. The buffered write log is private and
/// exposed only through its length (`write_count`) and the manager's commit path.
#[derive(Debug, Clone)]
pub struct Transaction {
    /// Identifier assigned by the manager when the transaction began.
    pub id: u64,
    /// Access mode requested at `begin`.
    pub mode: TransactionMode,
    /// Current lifecycle state.
    pub state: TransactionState,
    /// Start timestamp in milliseconds, as supplied by the caller of `begin_at`.
    pub started_at: u64,
    /// Allowed lifetime in milliseconds, measured from `started_at`. `0` disables expiry.
    pub timeout_ms: u64,
    /// Operations buffered by `add_write`, in insertion order.
    writes: Vec<WriteOp>,
}
impl Transaction {
    /// Builds a transaction in the `Active` state with an empty write buffer.
    fn new(id: u64, mode: TransactionMode, started_at: u64, timeout_ms: u64) -> Self {
        let state = TransactionState::Active;
        let writes = Vec::new();
        Self {
            id,
            mode,
            state,
            started_at,
            timeout_ms,
            writes,
        }
    }

    /// Returns `true` while the transaction is still in the `Active` state.
    pub fn is_active(&self) -> bool {
        matches!(self.state, TransactionState::Active)
    }

    /// Number of write operations currently buffered on this transaction.
    pub fn write_count(&self) -> usize {
        let Self { writes, .. } = self;
        writes.len()
    }

    /// Reports whether the deadline has been reached at `current_time_ms`.
    ///
    /// A `timeout_ms` of `0` never times out. A `current_time_ms` earlier than
    /// `started_at` is treated as zero elapsed time, because the subtraction saturates.
    /// Reaching the deadline exactly counts as timed out.
    pub fn is_timed_out(&self, current_time_ms: u64) -> bool {
        if self.timeout_ms == 0 {
            return false;
        }
        let elapsed_ms = current_time_ms.saturating_sub(self.started_at);
        elapsed_ms >= self.timeout_ms
    }
}
/// Errors returned by `TransactionManager` operations.
///
/// Every payload is a plain `u64`, so the type is cheaply `Clone`. That lets callers keep
/// an error after reporting it, and `Hash` lets errors be used as map or set keys.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum TxError {
    /// No transaction with this id is known to the manager.
    NotFound(u64),
    /// The transaction has already been committed.
    AlreadyCommitted(u64),
    /// The transaction has already been aborted.
    AlreadyAborted(u64),
    /// A write was attempted on a transaction whose mode does not allow writes.
    ReadOnly(u64),
    /// Starting another write-capable transaction would exceed the configured limit.
    MaxConcurrentWritesExceeded,
    /// Reserved for conflict detection. `TransactionManager` in this module never
    /// returns it.
    WriteConflict,
    /// The transaction was expired by a timeout sweep.
    TimedOut(u64),
}

impl std::fmt::Display for TxError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TxError::NotFound(id) => write!(f, "Transaction {id} not found"),
            TxError::AlreadyCommitted(id) => write!(f, "Transaction {id} already committed"),
            TxError::AlreadyAborted(id) => write!(f, "Transaction {id} already aborted"),
            TxError::ReadOnly(id) => write!(f, "Transaction {id} is read-only"),
            TxError::MaxConcurrentWritesExceeded => {
                f.write_str("Maximum concurrent write transactions exceeded")
            }
            TxError::WriteConflict => f.write_str("Write conflict detected"),
            TxError::TimedOut(id) => write!(f, "Transaction {id} timed out"),
        }
    }
}

impl std::error::Error for TxError {}
/// Hands out transaction ids and tracks every transaction it has started.
///
/// Finished transactions (committed, aborted or timed out) remain in the map, so their
/// terminal state can still be inspected and reported.
pub struct TransactionManager {
    /// Every transaction begun by this manager, keyed by id. Despite the name, it also
    /// holds transactions that are no longer `Active`.
    active: HashMap<u64, Transaction>,
    /// Id to assign to the next transaction.
    next_id: u64,
    /// Cap on simultaneously active write-capable transactions. `0` means unlimited.
    max_concurrent_writes: usize,
    /// Timeout in milliseconds given to each new transaction. `0` disables timeouts.
    default_timeout_ms: u64,
}
impl TransactionManager {
    /// Creates an empty manager whose first transaction id will be `1`.
    ///
    /// * `max_concurrent_writes` caps how many write-capable transactions may be
    ///   `Active` at once. `0` disables the cap.
    /// * `default_timeout_ms` is applied to every new transaction. `0` means transactions
    ///   never time out.
    pub fn new(max_concurrent_writes: usize, default_timeout_ms: u64) -> Self {
        Self {
            next_id: 1,
            active: HashMap::new(),
            max_concurrent_writes,
            default_timeout_ms,
        }
    }

    /// Begins a transaction with its start time recorded as `0`.
    ///
    /// With a non-zero default timeout, the deadline is measured from time `0`. Callers
    /// that pass real timestamps to `expire_timed_out` should use `begin_at` instead.
    ///
    /// # Errors
    /// Same as `begin_at`.
    pub fn begin(&mut self, mode: TransactionMode) -> Result<u64, TxError> {
        self.begin_at(mode, 0)
    }

    /// Begins a transaction in `mode` with its start time set to `current_time_ms`, and
    /// returns its id. Ids are assigned sequentially.
    ///
    /// # Errors
    /// Returns `TxError::MaxConcurrentWritesExceeded` when all of these hold:
    /// * `mode` allows writes;
    /// * a cap is configured;
    /// * that many write-capable transactions are already `Active`.
    ///
    /// Transactions past their deadline still count toward the cap until
    /// `expire_timed_out` marks them.
    pub fn begin_at(
        &mut self,
        mode: TransactionMode,
        current_time_ms: u64,
    ) -> Result<u64, TxError> {
        let capped = self.max_concurrent_writes > 0;
        if mode.allows_writes() && capped && self.write_count() >= self.max_concurrent_writes {
            return Err(TxError::MaxConcurrentWritesExceeded);
        }
        let id = self.next_id;
        self.next_id += 1;
        let tx = Transaction::new(id, mode, current_time_ms, self.default_timeout_ms);
        self.active.insert(id, tx);
        Ok(id)
    }

    /// Commits an active transaction and returns its buffered writes in insertion order.
    ///
    /// The transaction stays in the manager in the `Committed` state with an empty
    /// write buffer.
    ///
    /// # Errors
    /// * `NotFound` for unknown ids.
    /// * `AlreadyCommitted`, `AlreadyAborted` or `TimedOut` if the transaction is no
    ///   longer active.
    pub fn commit(&mut self, tx_id: u64) -> Result<Vec<WriteOp>, TxError> {
        let tx = self.active_tx_mut(tx_id)?;
        tx.state = TransactionState::Committed;
        // Move the buffer out, leaving an empty Vec, instead of draining into a new one.
        Ok(std::mem::take(&mut tx.writes))
    }

    /// Aborts an active transaction and discards its buffered writes.
    ///
    /// # Errors
    /// Same as `commit`.
    pub fn abort(&mut self, tx_id: u64) -> Result<(), TxError> {
        let tx = self.active_tx_mut(tx_id)?;
        tx.state = TransactionState::Aborted;
        tx.writes.clear();
        Ok(())
    }

    /// Appends `op` to an active transaction's write buffer.
    ///
    /// # Errors
    /// * Same as `commit` for unknown or finished transactions. The state check runs
    ///   first.
    /// * Then `ReadOnly` if the transaction's mode does not allow writes.
    pub fn add_write(&mut self, tx_id: u64, op: WriteOp) -> Result<(), TxError> {
        let tx = self.active_tx_mut(tx_id)?;
        if !tx.mode.allows_writes() {
            return Err(TxError::ReadOnly(tx_id));
        }
        tx.writes.push(op);
        Ok(())
    }

    /// Looks up any transaction this manager has begun, including finished ones, unless
    /// `purge_finished` has since removed it.
    pub fn get_transaction(&self, tx_id: u64) -> Option<&Transaction> {
        self.active.get(&tx_id)
    }

    /// Number of transactions currently in the `Active` state.
    pub fn active_count(&self) -> usize {
        self.active.values().filter(|tx| tx.is_active()).count()
    }

    /// Number of `Active` transactions whose mode allows writes.
    ///
    /// This is the quantity checked against `max_concurrent_writes`. It is unrelated to
    /// `Transaction::write_count`, which counts buffered operations.
    pub fn write_count(&self) -> usize {
        self.active
            .values()
            .filter(|tx| tx.is_active() && tx.mode.allows_writes())
            .count()
    }

    /// Marks every active transaction whose deadline has passed at `current_time_ms` as
    /// `TimedOut` and discards its buffered writes.
    ///
    /// Returns the ids of the newly expired transactions in ascending order.
    pub fn expire_timed_out(&mut self, current_time_ms: u64) -> Vec<u64> {
        let mut expired = Vec::new();
        for tx in self.active.values_mut() {
            if tx.is_active() && tx.is_timed_out(current_time_ms) {
                tx.state = TransactionState::TimedOut;
                tx.writes.clear();
                expired.push(tx.id);
            }
        }
        // HashMap iteration order is arbitrary; sort so the result is deterministic.
        expired.sort_unstable();
        expired
    }

    /// Returns `true` if `tx_id` is known and still `Active`.
    pub fn is_active(&self, tx_id: u64) -> bool {
        self.active.get(&tx_id).is_some_and(Transaction::is_active)
    }

    /// Removes every transaction that is no longer `Active` and returns how many were
    /// removed.
    ///
    /// Finished transactions are otherwise kept forever, so that later calls can report
    /// `AlreadyCommitted` and similar errors. In a long-running process that map would
    /// grow without bound. After a purge, operations on a removed id report `NotFound`,
    /// and `get_transaction` returns `None` for it.
    pub fn purge_finished(&mut self) -> usize {
        let before = self.active.len();
        self.active.retain(|_, tx| tx.is_active());
        before - self.active.len()
    }

    /// Returns the transaction for `tx_id` only if it is still `Active`.
    ///
    /// Unknown ids map to `NotFound`, and each terminal state maps to its matching error.
    /// Shared by `commit`, `abort` and `add_write` so their error behaviour stays
    /// identical.
    fn active_tx_mut(&mut self, tx_id: u64) -> Result<&mut Transaction, TxError> {
        let tx = self.active.get_mut(&tx_id).ok_or(TxError::NotFound(tx_id))?;
        match tx.state {
            TransactionState::Active => Ok(tx),
            TransactionState::Committed => Err(TxError::AlreadyCommitted(tx_id)),
            TransactionState::Aborted => Err(TxError::AlreadyAborted(tx_id)),
            TransactionState::TimedOut => Err(TxError::TimedOut(tx_id)),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Manager with a write cap of 10 and no timeout.
    fn mgr() -> TransactionManager {
        TransactionManager::new(10, 0)
    }

    /// The canonical single insert used by most write tests.
    fn insert_op() -> WriteOp {
        WriteOp::Insert("s".into(), "p".into(), "o".into(), None)
    }

    #[test]
    fn test_begin_read() {
        let mut m = mgr();
        let id = m.begin(TransactionMode::Read).unwrap();
        assert_eq!(id, 1);
        assert!(m.is_active(id));
    }

    #[test]
    fn test_begin_write() {
        let mut m = mgr();
        let id = m.begin(TransactionMode::Write).unwrap();
        assert!(m.is_active(id));
    }

    #[test]
    fn test_begin_readwrite() {
        let mut m = mgr();
        let id = m.begin(TransactionMode::ReadWrite).unwrap();
        assert!(m.is_active(id));
    }

    #[test]
    fn test_begin_increments_id() {
        let mut m = mgr();
        let id1 = m.begin(TransactionMode::Read).unwrap();
        let id2 = m.begin(TransactionMode::Read).unwrap();
        let id3 = m.begin(TransactionMode::Write).unwrap();
        assert!(id1 < id2);
        assert!(id2 < id3);
    }

    #[test]
    fn test_active_count() {
        let mut m = mgr();
        assert_eq!(m.active_count(), 0);
        m.begin(TransactionMode::Read).unwrap();
        m.begin(TransactionMode::Write).unwrap();
        assert_eq!(m.active_count(), 2);
    }

    #[test]
    fn test_commit_read_transaction() {
        let mut m = mgr();
        let id = m.begin(TransactionMode::Read).unwrap();
        let ops = m.commit(id).unwrap();
        assert!(ops.is_empty());
        assert!(!m.is_active(id));
    }

    #[test]
    fn test_commit_returns_write_ops() {
        let mut m = mgr();
        let id = m.begin(TransactionMode::Write).unwrap();
        m.add_write(id, insert_op()).unwrap();
        let ops = m.commit(id).unwrap();
        assert_eq!(ops, vec![insert_op()]);
    }

    #[test]
    fn test_commit_already_committed() {
        let mut m = mgr();
        let id = m.begin(TransactionMode::Read).unwrap();
        m.commit(id).unwrap();
        let result = m.commit(id);
        assert_eq!(result, Err(TxError::AlreadyCommitted(id)));
    }

    #[test]
    fn test_commit_already_aborted() {
        let mut m = mgr();
        let id = m.begin(TransactionMode::Write).unwrap();
        m.abort(id).unwrap();
        let result = m.commit(id);
        assert_eq!(result, Err(TxError::AlreadyAborted(id)));
    }

    #[test]
    fn test_commit_not_found() {
        let mut m = mgr();
        let result = m.commit(999);
        assert_eq!(result, Err(TxError::NotFound(999)));
    }

    #[test]
    fn test_abort_active_transaction() {
        let mut m = mgr();
        let id = m.begin(TransactionMode::Write).unwrap();
        m.add_write(id, insert_op()).unwrap();
        m.abort(id).unwrap();
        assert!(!m.is_active(id));
        assert_eq!(m.commit(id), Err(TxError::AlreadyAborted(id)));
    }

    #[test]
    fn test_abort_already_aborted() {
        let mut m = mgr();
        let id = m.begin(TransactionMode::Read).unwrap();
        m.abort(id).unwrap();
        assert_eq!(m.abort(id), Err(TxError::AlreadyAborted(id)));
    }

    #[test]
    fn test_abort_already_committed() {
        let mut m = mgr();
        let id = m.begin(TransactionMode::Read).unwrap();
        m.commit(id).unwrap();
        assert_eq!(m.abort(id), Err(TxError::AlreadyCommitted(id)));
    }

    #[test]
    fn test_abort_not_found() {
        let mut m = mgr();
        assert_eq!(m.abort(999), Err(TxError::NotFound(999)));
    }

    #[test]
    fn test_add_write_to_write_transaction() {
        let mut m = mgr();
        let id = m.begin(TransactionMode::Write).unwrap();
        m.add_write(id, insert_op()).unwrap();
        let tx = m.get_transaction(id).unwrap();
        assert_eq!(tx.write_count(), 1);
    }

    #[test]
    fn test_add_write_to_read_transaction_fails() {
        let mut m = mgr();
        let id = m.begin(TransactionMode::Read).unwrap();
        let result = m.add_write(id, insert_op());
        assert_eq!(result, Err(TxError::ReadOnly(id)));
    }

    #[test]
    fn test_add_write_multiple_ops() {
        let mut m = mgr();
        let id = m.begin(TransactionMode::Write).unwrap();
        for i in 0..5u32 {
            m.add_write(
                id,
                WriteOp::Insert(format!("s{}", i), "p".into(), format!("o{}", i), None),
            )
            .unwrap();
        }
        let tx = m.get_transaction(id).unwrap();
        assert_eq!(tx.write_count(), 5);
    }

    #[test]
    fn test_add_write_clear_op() {
        let mut m = mgr();
        let id = m.begin(TransactionMode::Write).unwrap();
        m.add_write(id, WriteOp::Clear(None)).unwrap();
        let ops = m.commit(id).unwrap();
        assert_eq!(ops[0], WriteOp::Clear(None));
    }

    #[test]
    fn test_add_write_delete_op() {
        let mut m = mgr();
        let id = m.begin(TransactionMode::Write).unwrap();
        m.add_write(
            id,
            WriteOp::Delete("s".into(), "p".into(), "o".into(), Some("g".into())),
        )
        .unwrap();
        let ops = m.commit(id).unwrap();
        assert_eq!(
            ops[0],
            WriteOp::Delete("s".into(), "p".into(), "o".into(), Some("g".into()))
        );
    }

    #[test]
    fn test_add_write_after_commit_fails() {
        let mut m = mgr();
        let id = m.begin(TransactionMode::Write).unwrap();
        m.commit(id).unwrap();
        let result = m.add_write(id, insert_op());
        assert_eq!(result, Err(TxError::AlreadyCommitted(id)));
    }

    #[test]
    fn test_max_concurrent_writes_respected() {
        let mut m = TransactionManager::new(2, 0);
        m.begin(TransactionMode::Write).unwrap();
        m.begin(TransactionMode::Write).unwrap();
        let result = m.begin(TransactionMode::Write);
        assert_eq!(result, Err(TxError::MaxConcurrentWritesExceeded));
    }

    #[test]
    fn test_read_allowed_when_writes_at_limit() {
        let mut m = TransactionManager::new(1, 0);
        m.begin(TransactionMode::Write).unwrap();
        let id = m.begin(TransactionMode::Read).unwrap();
        assert!(m.is_active(id));
    }

    #[test]
    fn test_write_allowed_after_commit_frees_slot() {
        let mut m = TransactionManager::new(1, 0);
        let id = m.begin(TransactionMode::Write).unwrap();
        m.commit(id).unwrap();
        let id2 = m.begin(TransactionMode::Write).unwrap();
        assert!(m.is_active(id2));
    }

    #[test]
    fn test_write_allowed_after_timeout_frees_slot() {
        let mut m = TransactionManager::new(1, 500);
        m.begin_at(TransactionMode::Write, 0).unwrap();
        assert_eq!(
            m.begin_at(TransactionMode::Write, 100),
            Err(TxError::MaxConcurrentWritesExceeded)
        );
        m.expire_timed_out(1000);
        let id2 = m.begin_at(TransactionMode::Write, 1000).unwrap();
        assert!(m.is_active(id2));
    }

    #[test]
    fn test_write_count() {
        let mut m = mgr();
        let r = m.begin(TransactionMode::Read).unwrap();
        let w = m.begin(TransactionMode::Write).unwrap();
        assert_eq!(m.write_count(), 1);
        m.commit(w).unwrap();
        assert_eq!(m.write_count(), 0);
        m.commit(r).unwrap();
        assert_eq!(m.write_count(), 0);
    }

    #[test]
    fn test_unlimited_writes_when_max_zero() {
        let mut m = TransactionManager::new(0, 0);
        for _ in 0..50 {
            m.begin(TransactionMode::Write).unwrap();
        }
        assert_eq!(m.write_count(), 50);
    }

    #[test]
    fn test_expire_timed_out() {
        let mut m = TransactionManager::new(10, 1000);
        let id = m.begin_at(TransactionMode::Write, 0).unwrap();
        let expired = m.expire_timed_out(2000);
        assert_eq!(expired, vec![id]);
        assert!(!m.is_active(id));
        assert_eq!(
            m.get_transaction(id).unwrap().state,
            TransactionState::TimedOut
        );
    }

    #[test]
    fn test_no_expiry_before_deadline() {
        let mut m = TransactionManager::new(10, 5000);
        let _id = m.begin_at(TransactionMode::Write, 0).unwrap();
        let expired = m.expire_timed_out(4999);
        assert!(expired.is_empty());
    }

    #[test]
    fn test_expiry_exactly_at_deadline() {
        let mut m = TransactionManager::new(10, 5000);
        let id = m.begin_at(TransactionMode::Write, 0).unwrap();
        assert_eq!(m.expire_timed_out(5000), vec![id]);
    }

    #[test]
    fn test_expire_timed_out_returns_multiple() {
        let mut m = TransactionManager::new(10, 500);
        let id1 = m.begin_at(TransactionMode::Write, 0).unwrap();
        let id2 = m.begin_at(TransactionMode::Read, 0).unwrap();
        let mut expired = m.expire_timed_out(1000);
        expired.sort();
        assert_eq!(expired, vec![id1, id2]);
    }

    #[test]
    fn test_commit_after_timeout_fails() {
        let mut m = TransactionManager::new(10, 500);
        let id = m.begin_at(TransactionMode::Write, 0).unwrap();
        m.expire_timed_out(1000);
        assert_eq!(m.commit(id), Err(TxError::TimedOut(id)));
    }

    #[test]
    fn test_add_write_after_timeout_fails() {
        let mut m = TransactionManager::new(10, 500);
        let id = m.begin_at(TransactionMode::Write, 0).unwrap();
        m.add_write(id, insert_op()).unwrap();
        m.expire_timed_out(1000);
        assert_eq!(m.add_write(id, insert_op()), Err(TxError::TimedOut(id)));
        assert_eq!(m.get_transaction(id).unwrap().write_count(), 0);
    }

    #[test]
    fn test_no_timeout_when_zero() {
        let mut m = TransactionManager::new(10, 0);
        let id = m.begin_at(TransactionMode::Write, 0).unwrap();
        let expired = m.expire_timed_out(u64::MAX);
        assert!(expired.is_empty());
        assert!(m.is_active(id));
    }

    #[test]
    fn test_get_transaction_exists() {
        let mut m = mgr();
        let id = m.begin(TransactionMode::Read).unwrap();
        let tx = m.get_transaction(id);
        assert!(tx.is_some());
        assert_eq!(tx.unwrap().id, id);
    }

    #[test]
    fn test_get_transaction_missing() {
        let m = mgr();
        assert!(m.get_transaction(999).is_none());
    }

    #[test]
    fn test_transaction_mode_preserved() {
        let mut m = mgr();
        let id = m.begin(TransactionMode::ReadWrite).unwrap();
        let tx = m.get_transaction(id).unwrap();
        assert_eq!(tx.mode, TransactionMode::ReadWrite);
    }

    #[test]
    fn test_is_active_after_commit() {
        let mut m = mgr();
        let id = m.begin(TransactionMode::Read).unwrap();
        m.commit(id).unwrap();
        assert!(!m.is_active(id));
    }

    #[test]
    fn test_is_active_unknown_id() {
        let m = mgr();
        assert!(!m.is_active(42));
    }

    #[test]
    fn test_tx_error_display() {
        assert!(!TxError::NotFound(1).to_string().is_empty());
        assert!(!TxError::AlreadyCommitted(1).to_string().is_empty());
        assert!(!TxError::AlreadyAborted(1).to_string().is_empty());
        assert!(!TxError::ReadOnly(1).to_string().is_empty());
        assert!(!TxError::MaxConcurrentWritesExceeded.to_string().is_empty());
        assert!(!TxError::WriteConflict.to_string().is_empty());
        assert!(!TxError::TimedOut(1).to_string().is_empty());
    }

    #[test]
    fn test_abort_after_timeout() {
        let mut m = TransactionManager::new(10, 500);
        let id = m.begin_at(TransactionMode::Write, 0).unwrap();
        m.expire_timed_out(1000);
        assert_eq!(m.abort(id), Err(TxError::TimedOut(id)));
    }

    #[test]
    fn test_add_write_after_abort() {
        let mut m = mgr();
        let id = m.begin(TransactionMode::Write).unwrap();
        m.abort(id).unwrap();
        let result = m.add_write(id, insert_op());
        assert_eq!(result, Err(TxError::AlreadyAborted(id)));
    }

    #[test]
    fn test_add_write_not_found() {
        let mut m = mgr();
        let result = m.add_write(999, insert_op());
        assert_eq!(result, Err(TxError::NotFound(999)));
    }

    #[test]
    fn test_is_active_after_abort() {
        let mut m = mgr();
        let id = m.begin(TransactionMode::Write).unwrap();
        m.abort(id).unwrap();
        assert!(!m.is_active(id));
    }

    #[test]
    fn test_active_count_decrements_after_commit() {
        let mut m = mgr();
        let id = m.begin(TransactionMode::Read).unwrap();
        assert_eq!(m.active_count(), 1);
        m.commit(id).unwrap();
        assert_eq!(m.active_count(), 0);
    }

    #[test]
    fn test_active_count_decrements_after_abort() {
        let mut m = mgr();
        let id = m.begin(TransactionMode::Write).unwrap();
        m.abort(id).unwrap();
        assert_eq!(m.active_count(), 0);
    }

    #[test]
    fn test_write_count_after_abort() {
        let mut m = mgr();
        let id = m.begin(TransactionMode::Write).unwrap();
        m.abort(id).unwrap();
        assert_eq!(m.write_count(), 0);
    }

    #[test]
    fn test_transaction_state_active() {
        let mut m = mgr();
        let id = m.begin(TransactionMode::Read).unwrap();
        assert_eq!(
            m.get_transaction(id).unwrap().state,
            TransactionState::Active
        );
    }

    #[test]
    fn test_transaction_state_committed() {
        let mut m = mgr();
        let id = m.begin(TransactionMode::Read).unwrap();
        m.commit(id).unwrap();
        assert_eq!(
            m.get_transaction(id).unwrap().state,
            TransactionState::Committed
        );
    }

    #[test]
    fn test_transaction_state_aborted() {
        let mut m = mgr();
        let id = m.begin(TransactionMode::Write).unwrap();
        m.abort(id).unwrap();
        assert_eq!(
            m.get_transaction(id).unwrap().state,
            TransactionState::Aborted
        );
    }

    #[test]
    fn test_transaction_mode_allows_writes() {
        assert!(TransactionMode::Write.allows_writes());
        assert!(TransactionMode::ReadWrite.allows_writes());
        assert!(!TransactionMode::Read.allows_writes());
    }

    #[test]
    fn test_commit_clears_write_ops() {
        let mut m = mgr();
        let id = m.begin(TransactionMode::Write).unwrap();
        for i in 0..3u32 {
            m.add_write(
                id,
                WriteOp::Insert(format!("s{}", i), "p".into(), "o".into(), None),
            )
            .unwrap();
        }
        let ops = m.commit(id).unwrap();
        assert_eq!(ops.len(), 3);
        // The ops were handed back, so the transaction's own buffer must now be empty.
        assert_eq!(m.get_transaction(id).unwrap().write_count(), 0);
    }

    #[test]
    fn test_abort_clears_write_ops() {
        let mut m = mgr();
        let id = m.begin(TransactionMode::Write).unwrap();
        m.add_write(id, insert_op()).unwrap();
        m.abort(id).unwrap();
        assert_eq!(m.get_transaction(id).unwrap().write_count(), 0);
    }
}