use std::collections::HashMap;
use thiserror::Error;
use crate::runtime::values::Value;
/// Backing key/value store that holds committed state for a [`TransactionManager`].
///
/// The manager reads committed values through [`StateStorage::get`] and pushes a
/// transaction's pending writes through [`StateStorage::set`] when it commits.
pub trait StateStorage {
    /// Returns an owned copy of the value stored under `key`, or `None` if the key is absent.
    fn get(&self, key: &str) -> Option<Value>;

    /// Stores `value` under `key`.
    fn set(&mut self, key: &str, value: Value);
}
/// [`StateStorage`] implementation that keeps committed values in a `HashMap`.
///
/// Nothing is persisted: the contents live only as long as this value does.
#[derive(Debug, Default)]
pub struct InMemoryStorage {
    /// Committed values keyed by name.
    state: HashMap<String, Value>,
}
impl InMemoryStorage {
    /// Creates an empty store.
    pub fn new() -> Self {
        Self {
            state: HashMap::new(),
        }
    }

    /// Creates a store pre-populated with every entry of `map`.
    pub fn from_map(map: HashMap<String, Value>) -> Self {
        Self { state: map }
    }
}
impl StateStorage for InMemoryStorage {
    /// Looks `key` up in the map and hands back a clone of the stored value, if any.
    fn get(&self, key: &str) -> Option<Value> {
        let stored = self.state.get(key)?;
        Some(stored.clone())
    }

    /// Inserts or overwrites the entry for `key`; any previous value is dropped.
    fn set(&mut self, key: &str, value: Value) {
        let owned_key = String::from(key);
        self.state.insert(owned_key, value);
    }
}
/// Errors returned by [`TransactionManager`] and [`Transaction`] operations.
#[derive(Error, Debug, Clone)]
pub enum TransactionError {
    /// An unknown transaction id, or an unknown savepoint name (the payload says which).
    #[error("Transaction not found: {0}")]
    NotFound(String),

    /// Not constructed by [`TransactionManager`] as currently written; presumably reserved
    /// for rejecting a nested begin.
    #[error("Transaction already active")]
    AlreadyActive,

    /// The transaction exists but is no longer in the `Active` state.
    #[error("No active transaction")]
    NoActiveTransaction,

    /// A lock on the requested key is held by a different transaction.
    #[error("Transaction conflict detected")]
    Conflict,

    /// Not constructed by [`TransactionManager`] as currently written (conflicting lock
    /// requests fail immediately with `Conflict` instead of waiting).
    #[error("Deadlock detected")]
    Deadlock,

    /// The transaction ran longer than its configured timeout.
    #[error("Transaction timeout")]
    Timeout,

    /// Not constructed by [`TransactionManager`] as currently written.
    #[error("Rollback failed: {0}")]
    RollbackFailed(String),
}
/// Isolation level requested when a transaction begins.
///
/// The variants are named after the SQL standard levels. Note that the only behavioural
/// distinction [`TransactionManager::read`] makes is that `ReadUncommitted` skips taking a
/// read lock; the other three levels take identical locking paths.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IsolationLevel {
    /// Reads take no read lock.
    ReadUncommitted,
    /// Reads take a read lock held until commit or rollback.
    ReadCommitted,
    /// Currently handled exactly like `ReadCommitted`.
    RepeatableRead,
    /// Currently handled exactly like `ReadCommitted`.
    Serializable,
}
/// Lifecycle state of a [`Transaction`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TransactionState {
    /// Newly created and accepting reads and writes.
    Active,
    /// Set at the start of the distributed (two-phase) commit path.
    Preparing,
    /// All pending writes have been applied to storage.
    Committed,
    /// Pending writes were discarded.
    RolledBack,
    /// Not assigned anywhere in this module as currently written.
    Failed,
}
/// Named snapshot of a transaction's pending writes, used to partially undo work.
#[derive(Debug, Clone)]
pub struct Savepoint {
    /// Name the savepoint was created with; used to look it up on rollback.
    pub name: String,
    /// Copy of the transaction's `modified_state` at the moment the savepoint was taken.
    pub state_snapshot: HashMap<String, Value>,
    /// Creation time in milliseconds since the Unix epoch.
    pub timestamp: u64,
}
/// A single transaction: its lifecycle state, timing, and buffered writes.
///
/// Writes are buffered in `modified_state` and only reach storage when the transaction commits.
#[derive(Debug, Clone)]
pub struct Transaction {
    /// Identifier; transactions started via [`TransactionManager::begin_transaction`] use the
    /// form `tx_<n>`.
    pub id: String,
    /// Current lifecycle state.
    pub state: TransactionState,
    /// Isolation level requested at begin time.
    pub isolation_level: IsolationLevel,
    /// Creation time in milliseconds since the Unix epoch.
    pub start_time: u64,
    /// Maximum lifetime in milliseconds measured from `start_time`; `None` disables the timeout.
    pub timeout_ms: Option<u64>,
    /// Committed value of each key as it was before this transaction first wrote it. Keys absent
    /// from storage at that point are not recorded.
    /// NOTE(review): nothing in this module reads this map back; confirm whether it is needed.
    pub original_state: HashMap<String, Value>,
    /// Pending writes, applied to storage on commit.
    pub modified_state: HashMap<String, Value>,
    /// Savepoints in creation order.
    pub savepoints: Vec<Savepoint>,
    /// Not read or written by this module; presumably the ids of participants in a distributed
    /// commit. TODO confirm.
    pub participants: Vec<String>,
    /// When `true`, commit goes through the two-phase commit path. Starts out `false`.
    pub is_distributed: bool,
}
/// Coordinates transactions over a pluggable [`StateStorage`] backend using per-key locks.
///
/// Locks are taken eagerly and held until the owning transaction commits or rolls back. A lock
/// request that collides with another transaction fails immediately with
/// [`TransactionError::Conflict`] instead of waiting.
pub struct TransactionManager {
    /// Transactions that have begun but not yet committed or rolled back, keyed by id.
    active_transactions: HashMap<String, Transaction>,
    /// Number of transactions started so far; used to mint `tx_<n>` ids.
    transaction_counter: u64,
    /// Committed state.
    storage: Box<dyn StateStorage>,
    /// Shared read locks: key -> ids of the transactions holding a read lock on it.
    read_locks: HashMap<String, Vec<String>>,
    /// Exclusive write locks: key -> id of the transaction holding the write lock.
    write_locks: HashMap<String, String>,
}
impl Transaction {
    /// Timeout, in milliseconds, that [`Transaction::new`] assigns to every new transaction.
    const DEFAULT_TIMEOUT_MS: u64 = 30_000;

    /// Creates an `Active`, non-distributed transaction with no pending writes or savepoints.
    ///
    /// The start time is the current wall-clock time (milliseconds since the Unix epoch). The
    /// timeout defaults to 30 seconds.
    pub fn new(id: String, isolation_level: IsolationLevel) -> Self {
        Self {
            id,
            state: TransactionState::Active,
            isolation_level,
            start_time: get_current_timestamp(),
            timeout_ms: Some(Self::DEFAULT_TIMEOUT_MS),
            original_state: HashMap::new(),
            modified_state: HashMap::new(),
            savepoints: Vec::new(),
            participants: Vec::new(),
            is_distributed: false,
        }
    }

    /// Returns `true` if a timeout is set and more than `timeout_ms` milliseconds have passed
    /// since `start_time`.
    ///
    /// The timestamps come from the system clock, which is not monotonic. After a backwards
    /// clock adjustment, "now" can be earlier than `start_time`. The elapsed time therefore
    /// saturates at zero instead of underflowing, which would panic in debug builds and, in
    /// release builds, wrap to a huge value and report a spurious timeout.
    pub fn is_timed_out(&self) -> bool {
        match self.timeout_ms {
            Some(timeout) => get_current_timestamp().saturating_sub(self.start_time) > timeout,
            None => false,
        }
    }

    /// Records a savepoint named `name` that snapshots the current pending writes.
    ///
    /// Names are not required to be unique.
    pub fn create_savepoint(&mut self, name: String) {
        self.savepoints.push(Savepoint {
            name,
            state_snapshot: self.modified_state.clone(),
            timestamp: get_current_timestamp(),
        });
    }

    /// Restores the pending writes captured by savepoint `name`.
    ///
    /// Savepoints created after it are discarded; the savepoint itself is kept, so it can be
    /// rolled back to again. If several savepoints share `name`, the earliest one is used.
    /// NOTE(review): SQL databases such as PostgreSQL use the most recent savepoint with a
    /// duplicated name; confirm that the earliest-match behaviour is intended.
    ///
    /// # Errors
    ///
    /// [`TransactionError::NotFound`] if no savepoint named `name` exists.
    pub fn rollback_to_savepoint(&mut self, name: &str) -> Result<(), TransactionError> {
        let pos = self
            .savepoints
            .iter()
            .position(|sp| sp.name == name)
            .ok_or_else(|| {
                TransactionError::NotFound(format!("Savepoint '{}' not found", name))
            })?;
        self.modified_state = self.savepoints[pos].state_snapshot.clone();
        self.savepoints.truncate(pos + 1);
        Ok(())
    }
}
impl TransactionManager {
    /// Creates a manager backed by an empty [`InMemoryStorage`].
    pub fn new() -> Self {
        Self::with_storage(Box::new(InMemoryStorage::new()))
    }

    /// Creates a manager that reads committed values from, and commits into, `storage`.
    pub fn with_storage(storage: Box<dyn StateStorage>) -> Self {
        Self {
            active_transactions: HashMap::new(),
            transaction_counter: 0,
            storage,
            read_locks: HashMap::new(),
            write_locks: HashMap::new(),
        }
    }

    /// Reads `key` straight from committed storage, outside any transaction and without locking.
    pub fn get_committed(&self, key: &str) -> Option<Value> {
        self.storage.get(key)
    }

    /// Returns transaction `tx_id` if it is still active.
    ///
    /// Committed and rolled-back transactions are removed from the manager and yield `None`.
    pub fn get_transaction(&self, tx_id: &str) -> Option<&Transaction> {
        self.active_transactions.get(tx_id)
    }

    /// Replaces the timeout of transaction `tx_id`; `None` disables the timeout.
    ///
    /// # Errors
    ///
    /// [`TransactionError::NotFound`] if `tx_id` is not an active transaction.
    pub fn set_transaction_timeout(
        &mut self,
        tx_id: &str,
        timeout_ms: Option<u64>,
    ) -> Result<(), TransactionError> {
        let tx = self
            .active_transactions
            .get_mut(tx_id)
            .ok_or_else(|| TransactionError::NotFound(tx_id.to_string()))?;
        tx.timeout_ms = timeout_ms;
        Ok(())
    }

    /// Starts a new transaction and returns its id, of the form `tx_<n>`.
    ///
    /// # Errors
    ///
    /// Never returns `Err` in the current implementation.
    pub fn begin_transaction(
        &mut self,
        isolation_level: IsolationLevel,
    ) -> Result<String, TransactionError> {
        self.transaction_counter += 1;
        let tx_id = format!("tx_{}", self.transaction_counter);
        let transaction = Transaction::new(tx_id.clone(), isolation_level);
        self.active_transactions.insert(tx_id.clone(), transaction);
        Ok(tx_id)
    }

    /// Reads `key` within transaction `tx_id`.
    ///
    /// Returns the transaction's own pending write for `key` if there is one; otherwise returns
    /// the committed value from storage. Except under `ReadUncommitted`, a read lock on `key` is
    /// taken first and held until the transaction commits or rolls back.
    ///
    /// # Errors
    ///
    /// * [`TransactionError::NotFound`] if `tx_id` is not an active transaction.
    /// * [`TransactionError::NoActiveTransaction`] if the transaction is not `Active`.
    /// * [`TransactionError::Timeout`] if the transaction has exceeded its timeout.
    /// * [`TransactionError::Conflict`] if another transaction holds the write lock on `key`.
    pub fn read(&mut self, tx_id: &str, key: &str) -> Result<Option<Value>, TransactionError> {
        let (should_lock, pending) = {
            let tx = self.active_tx(tx_id)?;
            if tx.is_timed_out() {
                return Err(TransactionError::Timeout);
            }
            (
                tx.isolation_level != IsolationLevel::ReadUncommitted,
                tx.modified_state.get(key).cloned(),
            )
        };
        if should_lock {
            self.acquire_read_lock(tx_id, key)?;
        }
        Ok(pending.or_else(|| self.storage.get(key)))
    }

    /// Buffers a write of `value` to `key` in transaction `tx_id`, taking the write lock on `key`.
    ///
    /// The value only reaches storage when the transaction commits.
    ///
    /// # Errors
    ///
    /// * [`TransactionError::NotFound`] if `tx_id` is not an active transaction.
    /// * [`TransactionError::NoActiveTransaction`] if the transaction is not `Active`.
    /// * [`TransactionError::Timeout`] if the transaction has exceeded its timeout.
    /// * [`TransactionError::Conflict`] if another transaction holds a read or write lock on `key`.
    pub fn write(&mut self, tx_id: &str, key: String, value: Value) -> Result<(), TransactionError> {
        // Validate the transaction *before* touching the lock table. Locking first would leave an
        // orphaned write lock behind for an unknown or already-finished `tx_id`, because
        // `release_locks` only ever runs when a live transaction commits or rolls back.
        if self.active_tx(tx_id)?.is_timed_out() {
            return Err(TransactionError::Timeout);
        }
        self.acquire_write_lock(tx_id, &key)?;

        let tx = self
            .active_transactions
            .get_mut(tx_id)
            .ok_or_else(|| TransactionError::NotFound(tx_id.to_string()))?;
        // Remember the committed value the first time this transaction writes `key`.
        if !tx.original_state.contains_key(&key) {
            if let Some(original) = self.storage.get(&key) {
                tx.original_state.insert(key.clone(), original);
            }
        }
        tx.modified_state.insert(key, value);
        Ok(())
    }

    /// Commits transaction `tx_id`, applying its pending writes to storage and releasing its locks.
    ///
    /// Distributed transactions are delegated to `two_phase_commit`. A transaction that has
    /// exceeded its timeout is rolled back instead of committed.
    ///
    /// # Errors
    ///
    /// * [`TransactionError::NotFound`] if `tx_id` is not an active transaction.
    /// * [`TransactionError::NoActiveTransaction`] if the transaction is not `Active`.
    /// * [`TransactionError::Timeout`] if the transaction timed out; it has been rolled back.
    pub fn commit(&mut self, tx_id: &str) -> Result<(), TransactionError> {
        let tx = self
            .active_transactions
            .get_mut(tx_id)
            .ok_or_else(|| TransactionError::NotFound(tx_id.to_string()))?;
        if tx.state != TransactionState::Active {
            return Err(TransactionError::NoActiveTransaction);
        }
        if tx.is_timed_out() {
            self.rollback(tx_id)?;
            return Err(TransactionError::Timeout);
        }
        if tx.is_distributed {
            return self.two_phase_commit(tx_id);
        }
        for (key, value) in &tx.modified_state {
            self.storage.set(key, value.clone());
        }
        tx.state = TransactionState::Committed;
        self.release_locks(tx_id);
        self.active_transactions.remove(tx_id);
        Ok(())
    }

    /// Discards transaction `tx_id`, dropping its pending writes and releasing its locks.
    ///
    /// Storage is left untouched because writes are only buffered until commit.
    ///
    /// # Errors
    ///
    /// [`TransactionError::NotFound`] if `tx_id` is not an active transaction.
    pub fn rollback(&mut self, tx_id: &str) -> Result<(), TransactionError> {
        let tx = self
            .active_transactions
            .get_mut(tx_id)
            .ok_or_else(|| TransactionError::NotFound(tx_id.to_string()))?;
        tx.state = TransactionState::RolledBack;
        self.release_locks(tx_id);
        self.active_transactions.remove(tx_id);
        Ok(())
    }

    /// Commit path for distributed transactions.
    ///
    /// NOTE(review): despite the name, no prepare/vote round with `participants` happens here.
    /// The state is set to `Preparing`, and the writes are then applied exactly as on the local
    /// path.
    fn two_phase_commit(&mut self, tx_id: &str) -> Result<(), TransactionError> {
        let tx = self
            .active_transactions
            .get_mut(tx_id)
            .ok_or_else(|| TransactionError::NotFound(tx_id.to_string()))?;
        tx.state = TransactionState::Preparing;
        for (key, value) in &tx.modified_state {
            self.storage.set(key, value.clone());
        }
        tx.state = TransactionState::Committed;
        self.release_locks(tx_id);
        self.active_transactions.remove(tx_id);
        Ok(())
    }

    /// Looks up `tx_id` and checks that it is still in the `Active` state.
    fn active_tx(&self, tx_id: &str) -> Result<&Transaction, TransactionError> {
        let tx = self
            .active_transactions
            .get(tx_id)
            .ok_or_else(|| TransactionError::NotFound(tx_id.to_string()))?;
        if tx.state != TransactionState::Active {
            return Err(TransactionError::NoActiveTransaction);
        }
        Ok(tx)
    }

    /// Takes a shared read lock on `key` for `tx_id`.
    ///
    /// Fails with `Conflict` if a different transaction holds the write lock.
    fn acquire_read_lock(&mut self, tx_id: &str, key: &str) -> Result<(), TransactionError> {
        if let Some(write_owner) = self.write_locks.get(key) {
            if write_owner != tx_id {
                return Err(TransactionError::Conflict);
            }
        }
        let readers = self.read_locks.entry(key.to_string()).or_default();
        // Re-reading a key must not append another copy of the same id: the lock is already
        // held, and duplicates would make the reader list grow without bound for hot keys.
        if !readers.iter().any(|reader| reader == tx_id) {
            readers.push(tx_id.to_string());
        }
        Ok(())
    }

    /// Takes the exclusive write lock on `key` for `tx_id`.
    ///
    /// Fails with `Conflict` if another transaction holds the write lock or any read lock on
    /// `key`. A transaction that is the sole reader may upgrade its read lock.
    fn acquire_write_lock(&mut self, tx_id: &str, key: &str) -> Result<(), TransactionError> {
        if let Some(write_owner) = self.write_locks.get(key) {
            if write_owner != tx_id {
                return Err(TransactionError::Conflict);
            }
        }
        if let Some(readers) = self.read_locks.get(key) {
            if readers.iter().any(|r| r != tx_id) {
                return Err(TransactionError::Conflict);
            }
        }
        self.write_locks.insert(key.to_string(), tx_id.to_string());
        Ok(())
    }

    /// Drops every read and write lock held by `tx_id`, pruning keys left with no readers.
    fn release_locks(&mut self, tx_id: &str) {
        self.read_locks.retain(|_, readers| {
            readers.retain(|r| r != tx_id);
            !readers.is_empty()
        });
        self.write_locks.retain(|_, owner| owner != tx_id);
    }

    /// Records savepoint `name` in transaction `tx_id`.
    ///
    /// # Errors
    ///
    /// [`TransactionError::NotFound`] if `tx_id` is not an active transaction.
    pub fn create_savepoint(&mut self, tx_id: &str, name: String) -> Result<(), TransactionError> {
        let tx = self
            .active_transactions
            .get_mut(tx_id)
            .ok_or_else(|| TransactionError::NotFound(tx_id.to_string()))?;
        tx.create_savepoint(name);
        Ok(())
    }

    /// Restores transaction `tx_id`'s pending writes to savepoint `name`.
    ///
    /// Locks taken after the savepoint are kept.
    ///
    /// # Errors
    ///
    /// [`TransactionError::NotFound`] if `tx_id` is not an active transaction or the savepoint
    /// does not exist.
    pub fn rollback_to_savepoint(&mut self, tx_id: &str, name: &str) -> Result<(), TransactionError> {
        let tx = self
            .active_transactions
            .get_mut(tx_id)
            .ok_or_else(|| TransactionError::NotFound(tx_id.to_string()))?;
        tx.rollback_to_savepoint(name)
    }
}
impl Default for TransactionManager {
fn default() -> Self {
Self::new()
}
}
/// Returns the current wall-clock time in milliseconds since the Unix epoch.
///
/// This reads the system clock, which is not monotonic: consecutive calls may go backwards if
/// the clock is adjusted. A clock set before the epoch yields `0`. The `u128` millisecond count
/// is clamped to `u64::MAX` instead of being silently truncated by a bare `as` cast.
fn get_current_timestamp() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    let millis = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_millis();
    // Clamp first so the cast below can never drop high bits.
    millis.min(u128::from(u64::MAX)) as u64
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A committed write reaches storage, and the transaction is no longer tracked.
    #[test]
    fn test_transaction_begin_commit() {
        let mut manager = TransactionManager::new();
        let tx_id = manager.begin_transaction(IsolationLevel::ReadCommitted).unwrap();
        assert!(manager.active_transactions.contains_key(&tx_id));
        manager.write(&tx_id, "key1".to_string(), Value::Int(42)).unwrap();
        manager.commit(&tx_id).unwrap();
        assert!(manager.get_transaction(&tx_id).is_none());
        assert_eq!(manager.get_committed("key1"), Some(Value::Int(42)));
    }

    /// Rolling back leaves the previously committed value untouched.
    #[test]
    fn test_transaction_rollback() {
        let mut manager = TransactionManager::with_storage(Box::new(
            InMemoryStorage::from_map(HashMap::from([("key1".to_string(), Value::Int(10))])),
        ));
        let tx_id = manager.begin_transaction(IsolationLevel::ReadCommitted).unwrap();
        manager.write(&tx_id, "key1".to_string(), Value::Int(42)).unwrap();
        manager.rollback(&tx_id).unwrap();
        assert_eq!(manager.get_committed("key1"), Some(Value::Int(10)));
    }

    /// Rolling back to a savepoint restores the pending writes captured at that point.
    #[test]
    fn test_savepoint_rollback() {
        let mut manager = TransactionManager::new();
        let tx_id = manager.begin_transaction(IsolationLevel::ReadCommitted).unwrap();
        manager.write(&tx_id, "key1".to_string(), Value::Int(1)).unwrap();
        manager.create_savepoint(&tx_id, "sp1".to_string()).unwrap();
        manager.write(&tx_id, "key1".to_string(), Value::Int(2)).unwrap();
        manager.rollback_to_savepoint(&tx_id, "sp1").unwrap();
        let tx = manager.get_transaction(&tx_id).unwrap();
        assert_eq!(tx.modified_state.get("key1"), Some(&Value::Int(1)));
    }

    /// Rolling back to a savepoint that was never created reports `NotFound`.
    #[test]
    fn test_rollback_to_unknown_savepoint() {
        let mut manager = TransactionManager::new();
        let tx_id = manager.begin_transaction(IsolationLevel::ReadCommitted).unwrap();
        let result = manager.rollback_to_savepoint(&tx_id, "missing");
        assert!(matches!(result, Err(TransactionError::NotFound(_))));
    }

    /// A reader conflicts with an uncommitted writer, then sees the value once it commits.
    #[test]
    fn test_isolation_read_committed() {
        let mut manager = TransactionManager::with_storage(Box::new(
            InMemoryStorage::from_map(HashMap::from([("counter".to_string(), Value::Int(0))])),
        ));
        let tx1 = manager.begin_transaction(IsolationLevel::ReadCommitted).unwrap();
        manager.write(&tx1, "counter".to_string(), Value::Int(1)).unwrap();
        let tx2 = manager.begin_transaction(IsolationLevel::ReadCommitted).unwrap();
        // tx1 still holds the write lock, so tx2 cannot take a read lock yet.
        let blocked = manager.read(&tx2, "counter");
        assert!(matches!(blocked, Err(TransactionError::Conflict)));
        manager.commit(&tx1).unwrap();
        let value = manager.read(&tx2, "counter").unwrap();
        assert_eq!(value, Some(Value::Int(1)));
        manager.commit(&tx2).unwrap();
    }

    /// `ReadUncommitted` takes no read lock, so it does not conflict with a pending writer.
    #[test]
    fn test_read_uncommitted_skips_read_locks() {
        let mut manager = TransactionManager::with_storage(Box::new(
            InMemoryStorage::from_map(HashMap::from([("k".to_string(), Value::Int(0))])),
        ));
        let writer = manager.begin_transaction(IsolationLevel::ReadCommitted).unwrap();
        manager.write(&writer, "k".to_string(), Value::Int(5)).unwrap();
        let reader = manager.begin_transaction(IsolationLevel::ReadUncommitted).unwrap();
        assert!(manager.read(&reader, "k").is_ok());
    }

    /// Two transactions cannot hold the write lock on the same key.
    #[test]
    fn test_write_conflict() {
        let mut manager = TransactionManager::new();
        let tx1 = manager.begin_transaction(IsolationLevel::ReadCommitted).unwrap();
        let tx2 = manager.begin_transaction(IsolationLevel::ReadCommitted).unwrap();
        manager.write(&tx1, "key1".to_string(), Value::Int(1)).unwrap();
        let result = manager.write(&tx2, "key1".to_string(), Value::Int(2));
        assert!(matches!(result, Err(TransactionError::Conflict)));
    }

    /// Committing after the timeout has elapsed fails with `Timeout`.
    #[test]
    fn test_transaction_timeout() {
        let mut manager = TransactionManager::new();
        let tx_id = manager.begin_transaction(IsolationLevel::ReadCommitted).unwrap();
        manager.set_transaction_timeout(&tx_id, Some(1)).unwrap();
        std::thread::sleep(std::time::Duration::from_millis(10));
        let result = manager.commit(&tx_id);
        assert!(matches!(result, Err(TransactionError::Timeout)));
    }
}