use crate::operation::Operation;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::{Duration, Instant};
use ucm_core::{Error, Result};
/// Identifier of a [`Transaction`], stored as a plain string.
///
/// Values are produced either by [`TransactionId::generate`] (`txn_<hex>`) or
/// taken verbatim from a caller-supplied name via [`TransactionId::named`].
#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub struct TransactionId(pub String);
impl TransactionId {
    /// Generates a new id of the form `txn_<hex>`.
    ///
    /// On non-wasm targets the hex value starts from the current UTC time in
    /// nanoseconds (or `0` when that is outside the `i64` range). Two calls can
    /// observe the same clock reading, and on `wasm32` no clock is read at all.
    /// To keep ids distinct anyway, the value is forced to be strictly greater
    /// than the previous value handed out by this function in the same process.
    pub fn generate() -> Self {
        #[cfg(not(target_arch = "wasm32"))]
        use chrono::Utc;
        use std::sync::atomic::{AtomicI64, Ordering};

        // Most recent value returned by `generate`, shared by all callers.
        static LAST_ISSUED: AtomicI64 = AtomicI64::new(0);

        #[cfg(not(target_arch = "wasm32"))]
        let now = Utc::now().timestamp_nanos_opt().unwrap_or(0);
        #[cfg(target_arch = "wasm32")]
        let now: i64 = 0;

        // Atomically advance LAST_ISSUED to max(now, last + 1). `fetch_update`
        // retries when another thread races us. The closure never returns
        // `None`, so the result is always `Ok(previous)`.
        let previous = LAST_ISSUED
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |last| {
                Some(now.max(last.wrapping_add(1)))
            })
            .unwrap_or_else(|last| last);
        let ts = now.max(previous.wrapping_add(1));
        Self(format!("txn_{:x}", ts))
    }

    /// Uses `name` verbatim as the id (no prefix, no uniqueness guarantee).
    pub fn named(name: impl Into<String>) -> Self {
        Self(name.into())
    }
}
impl std::fmt::Display for TransactionId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
/// Lifecycle state of a [`Transaction`].
///
/// Transactions start `Active`. Nothing in this module moves a transaction out
/// of any other state.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub enum TransactionState {
    /// Still accepting operations; may be committed or rolled back.
    Active,
    /// A commit succeeded.
    Committed,
    /// A rollback succeeded.
    RolledBack,
    /// The timeout was detected while adding an operation or committing.
    TimedOut,
}
/// A unit of work: an ordered list of operations plus bookkeeping.
#[derive(Clone, Debug)]
pub struct Transaction {
    /// Unique key of this transaction within a `TransactionManager`.
    pub id: TransactionId,
    /// Caller-supplied name; `None` for transactions created by `Transaction::new`.
    pub name: Option<String>,
    /// Operations recorded so far, in the order they were added.
    pub operations: Vec<Operation>,
    /// Savepoints recorded so far, in creation order.
    pub savepoints: Vec<Savepoint>,
    /// Current lifecycle state.
    pub state: TransactionState,
    /// Monotonic start time; the timeout is measured from here.
    #[cfg(not(target_arch = "wasm32"))]
    pub started_at: Instant,
    /// Wall-clock creation time.
    #[cfg(not(target_arch = "wasm32"))]
    pub created_at: DateTime<Utc>,
    /// Maximum allowed lifetime, measured from `started_at`.
    pub timeout: Duration,
}
/// A named marker into a transaction's operation list.
#[derive(Clone, Debug)]
pub struct Savepoint {
    /// Name given when the savepoint was taken.
    pub name: String,
    /// Number of operations recorded when the savepoint was taken, i.e. the
    /// index the next operation would occupy.
    pub operation_index: usize,
    /// Wall-clock time the savepoint was taken.
    #[cfg(not(target_arch = "wasm32"))]
    pub created_at: DateTime<Utc>,
}
impl Transaction {
pub fn new(timeout: Duration) -> Self {
Self {
id: TransactionId::generate(),
name: None,
operations: Vec::new(),
savepoints: Vec::new(),
state: TransactionState::Active,
#[cfg(not(target_arch = "wasm32"))]
started_at: Instant::now(),
#[cfg(not(target_arch = "wasm32"))]
created_at: Utc::now(),
timeout,
}
}
pub fn named(name: impl Into<String>, timeout: Duration) -> Self {
let name = name.into();
Self {
id: TransactionId::named(&name),
name: Some(name),
operations: Vec::new(),
savepoints: Vec::new(),
state: TransactionState::Active,
#[cfg(not(target_arch = "wasm32"))]
started_at: Instant::now(),
#[cfg(not(target_arch = "wasm32"))]
created_at: Utc::now(),
timeout,
}
}
pub fn add_operation(&mut self, op: Operation) -> Result<()> {
if self.state != TransactionState::Active {
return Err(Error::Internal(format!(
"Cannot add operation to {:?} transaction",
self.state
)));
}
if self.is_timed_out() {
self.state = TransactionState::TimedOut;
return Err(Error::new(
ucm_core::ErrorCode::E301TransactionTimeout,
"Transaction timed out",
));
}
self.operations.push(op);
Ok(())
}
pub fn savepoint(&mut self, name: impl Into<String>) {
self.savepoints.push(Savepoint {
name: name.into(),
operation_index: self.operations.len(),
#[cfg(not(target_arch = "wasm32"))]
created_at: Utc::now(),
});
}
pub fn is_timed_out(&self) -> bool {
#[cfg(not(target_arch = "wasm32"))]
return self.started_at.elapsed() > self.timeout;
#[cfg(target_arch = "wasm32")]
false
}
pub fn elapsed(&self) -> Duration {
#[cfg(not(target_arch = "wasm32"))]
return self.started_at.elapsed();
#[cfg(target_arch = "wasm32")]
Duration::from_secs(0)
}
pub fn operation_count(&self) -> usize {
self.operations.len()
}
}
#[derive(Debug, Default)]
pub struct TransactionManager {
transactions: HashMap<TransactionId, Transaction>,
default_timeout: Duration,
}
impl TransactionManager {
    /// Timeout used by [`TransactionManager::new`].
    const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);

    /// Creates an empty manager whose transactions time out after 30 seconds.
    pub fn new() -> Self {
        Self::with_timeout(Self::DEFAULT_TIMEOUT)
    }

    /// Creates an empty manager whose transactions time out after `timeout`.
    pub fn with_timeout(timeout: Duration) -> Self {
        Self {
            transactions: HashMap::new(),
            default_timeout: timeout,
        }
    }

    /// Starts an unnamed transaction and returns its generated id.
    pub fn begin(&mut self) -> TransactionId {
        self.register(Transaction::new(self.default_timeout))
    }

    /// Starts a transaction whose id is `name` verbatim.
    ///
    /// NOTE(review): if a transaction with the same id is already registered,
    /// it is replaced (`HashMap::insert` semantics), even if it is still
    /// `Active`. Its recorded operations are then lost.
    pub fn begin_named(&mut self, name: impl Into<String>) -> TransactionId {
        self.register(Transaction::named(name, self.default_timeout))
    }

    /// Stores `txn` under its own id and returns that id.
    fn register(&mut self, txn: Transaction) -> TransactionId {
        let id = txn.id.clone();
        self.transactions.insert(id.clone(), txn);
        id
    }

    /// Looks up `id` mutably, mapping "absent" to an `E303TransactionNotFound` error.
    fn find_mut(&mut self, id: &TransactionId) -> Result<&mut Transaction> {
        self.transactions.get_mut(id).ok_or_else(|| {
            Error::new(ucm_core::ErrorCode::E303TransactionNotFound, id.to_string())
        })
    }

    /// Returns the transaction registered under `id`, if any.
    pub fn get(&self, id: &TransactionId) -> Option<&Transaction> {
        self.transactions.get(id)
    }

    /// Returns a mutable reference to the transaction registered under `id`, if any.
    pub fn get_mut(&mut self, id: &TransactionId) -> Option<&mut Transaction> {
        self.transactions.get_mut(id)
    }

    /// Appends `op` to transaction `id`.
    ///
    /// # Errors
    /// `E303TransactionNotFound` if `id` is unknown. Otherwise, any error from
    /// [`Transaction::add_operation`] (not active, or timed out).
    pub fn add_operation(&mut self, id: &TransactionId, op: Operation) -> Result<()> {
        self.find_mut(id)?.add_operation(op)
    }

    /// Marks transaction `id` as `Committed` and returns a copy of its operations.
    ///
    /// The transaction stays registered (with state `Committed`) until `cleanup`.
    ///
    /// # Errors
    /// - `E303TransactionNotFound` if `id` is unknown.
    /// - `Error::Internal` if the transaction is not `Active`.
    /// - `E301TransactionTimeout` if its timeout has elapsed. The state is then
    ///   set to `TimedOut`.
    pub fn commit(&mut self, id: &TransactionId) -> Result<Vec<Operation>> {
        let txn = self.find_mut(id)?;
        if txn.state != TransactionState::Active {
            return Err(Error::Internal(format!(
                "Cannot commit {:?} transaction",
                txn.state
            )));
        }
        if txn.is_timed_out() {
            txn.state = TransactionState::TimedOut;
            return Err(Error::new(
                ucm_core::ErrorCode::E301TransactionTimeout,
                "Transaction timed out",
            ));
        }
        txn.state = TransactionState::Committed;
        Ok(txn.operations.clone())
    }

    /// Marks transaction `id` as `RolledBack`.
    ///
    /// Unlike `commit`, this does not check the timeout, so an `Active`
    /// transaction past its deadline can still be rolled back.
    ///
    /// # Errors
    /// - `E303TransactionNotFound` if `id` is unknown.
    /// - `Error::Internal` if the transaction is not `Active`.
    pub fn rollback(&mut self, id: &TransactionId) -> Result<()> {
        let txn = self.find_mut(id)?;
        if txn.state != TransactionState::Active {
            return Err(Error::Internal(format!(
                "Cannot rollback {:?} transaction",
                txn.state
            )));
        }
        txn.state = TransactionState::RolledBack;
        Ok(())
    }

    /// Removes every transaction that is not `Active`, plus every `Active`
    /// one whose timeout has elapsed.
    pub fn cleanup(&mut self) {
        self.transactions
            .retain(|_, txn| txn.state == TransactionState::Active && !txn.is_timed_out());
    }

    /// Counts transactions whose state is `Active`.
    ///
    /// This includes ones whose timeout has elapsed but has not yet been
    /// detected by `add_operation` or `commit`.
    pub fn active_count(&self) -> usize {
        self.transactions
            .values()
            .filter(|t| t.state == TransactionState::Active)
            .count()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::operation::PruneCondition;

    /// begin -> one operation -> commit hands back exactly that operation.
    #[test]
    fn test_transaction_lifecycle() {
        let mut manager = TransactionManager::new();
        let txn_id = manager.begin();
        assert_eq!(manager.active_count(), 1);

        let prune = Operation::Prune {
            condition: Some(PruneCondition::Unreachable),
        };
        manager.add_operation(&txn_id, prune).unwrap();

        let committed = manager.commit(&txn_id).unwrap();
        assert_eq!(committed.len(), 1);
    }

    /// A named transaction uses its name verbatim as the id.
    #[test]
    fn test_named_transaction() {
        let mut manager = TransactionManager::new();
        let txn_id = manager.begin_named("my-transaction");
        assert_eq!(txn_id.0, "my-transaction");
    }

    /// Rolling back leaves the transaction registered in the `RolledBack` state.
    #[test]
    fn test_rollback() {
        let mut manager = TransactionManager::new();
        let txn_id = manager.begin();
        manager.rollback(&txn_id).unwrap();

        let state = manager.get(&txn_id).unwrap().state;
        assert_eq!(state, TransactionState::RolledBack);
    }

    /// Committing after the timeout has elapsed fails.
    #[test]
    fn test_timeout() {
        let mut manager = TransactionManager::with_timeout(Duration::from_millis(1));
        let txn_id = manager.begin();
        std::thread::sleep(Duration::from_millis(10));

        let outcome = manager.commit(&txn_id);
        assert!(outcome.is_err());
    }
}