use std::collections::HashMap;
use std::sync::{Arc, Mutex, RwLock};
use crate::exec::error::ExecutionError;
use crate::session::SessionManager;
use super::isolation::IsolationLevel;
use super::state::{AccessMode, OperationType, TransactionId, TransactionState, TxnIsolationLevel};
use super::wal::{PersistentWAL, WALEntry, WALEntryType};
/// Coordinates transaction lifecycles (begin, commit, rollback) and records each step in a
/// write-ahead log.
pub struct TransactionManager {
    /// Every transaction started by this manager, keyed by id. Committed and rolled-back
    /// transactions remain here until `cleanup_completed_transactions` removes them.
    active_transactions: Arc<RwLock<HashMap<TransactionId, Arc<Mutex<TransactionState>>>>>,
    /// Isolation level used when neither the caller nor the pending characteristics specify one.
    default_isolation_level: IsolationLevel,
    /// Characteristics stored by `set_next_transaction_characteristics`; cleared when the next
    /// transaction is started.
    next_transaction_characteristics: Arc<Mutex<Option<(IsolationLevel, AccessMode)>>>,
    /// Write-ahead log that receives BEGIN, COMMIT, ROLLBACK and per-operation entries.
    wal: Arc<PersistentWAL>,
    /// Optional session registry used to validate session ids and to resolve a transaction's
    /// current graph. NOTE(review): `new` sets this to `None` and no setter is visible in this
    /// file, so those session lookups are skipped unless it is assigned elsewhere — confirm.
    session_manager: Option<Arc<SessionManager>>,
}
impl TransactionManager {
    /// Creates a manager whose write-ahead log is initialized from `db_path`.
    ///
    /// The default isolation level is READ COMMITTED and no session manager is attached.
    ///
    /// # Errors
    ///
    /// Returns `ExecutionError::RuntimeError` if the WAL cannot be initialized.
    pub fn new(db_path: std::path::PathBuf) -> Result<Self, ExecutionError> {
        let wal = PersistentWAL::new(db_path).map_err(|e| {
            ExecutionError::RuntimeError(format!("Failed to initialize WAL: {}", e))
        })?;
        Ok(Self {
            active_transactions: Arc::new(RwLock::new(HashMap::new())),
            default_isolation_level: IsolationLevel::ReadCommitted,
            next_transaction_characteristics: Arc::new(Mutex::new(None)),
            wal: Arc::new(wal),
            session_manager: None,
        })
    }
    /// Starts a transaction that is not bound to any session.
    ///
    /// Equivalent to [`Self::start_transaction_with_session`] with `session_id = None`.
    pub fn start_transaction(
        &self,
        isolation_level: Option<IsolationLevel>,
        access_mode: Option<AccessMode>,
    ) -> Result<TransactionId, ExecutionError> {
        self.start_transaction_with_session(isolation_level, access_mode, None)
    }
    /// Starts a transaction, optionally bound to `session_id`, and returns its id.
    ///
    /// Explicit arguments take precedence over characteristics stored by
    /// [`Self::set_next_transaction_characteristics`]. Those in turn take precedence over the
    /// defaults (the manager's default isolation level and READ WRITE). Any pending
    /// characteristics are consumed by this call, even when it later fails.
    ///
    /// # Errors
    ///
    /// * `UnsupportedOperator` if the resolved isolation level is not READ COMMITTED.
    /// * `RuntimeError` if a session manager is attached and `session_id` is unknown.
    /// * `RuntimeError` if the BEGIN entry cannot be written to the WAL.
    /// * `RuntimeError` if a lock is poisoned.
    pub fn start_transaction_with_session(
        &self,
        isolation_level: Option<IsolationLevel>,
        access_mode: Option<AccessMode>,
        session_id: Option<String>,
    ) -> Result<TransactionId, ExecutionError> {
        // Read and clear the pending characteristics under ONE lock acquisition. Otherwise a
        // concurrent `set_next_transaction_characteristics` landing between the read and the
        // reset would be silently discarded.
        let pending = self
            .next_transaction_characteristics
            .lock()
            .map_err(|_| ExecutionError::RuntimeError("Failed to acquire lock".to_string()))?
            .take();
        let (final_isolation_level, final_access_mode) =
            self.resolve_characteristics(isolation_level, access_mode, pending);
        if !matches!(final_isolation_level, IsolationLevel::ReadCommitted) {
            return Err(ExecutionError::UnsupportedOperator(format!(
                "Isolation level {} not yet supported. Only READ COMMITTED is currently implemented.",
                final_isolation_level.as_str()
            )));
        }
        let txn_isolation_level = match final_isolation_level {
            IsolationLevel::ReadUncommitted => TxnIsolationLevel::ReadUncommitted,
            IsolationLevel::ReadCommitted => TxnIsolationLevel::ReadCommitted,
            IsolationLevel::RepeatableRead => TxnIsolationLevel::RepeatableRead,
            IsolationLevel::Serializable => TxnIsolationLevel::Serializable,
        };
        // Session ids are only validated when a session manager is attached.
        if let (Some(session_id), Some(session_manager)) = (&session_id, &self.session_manager) {
            if session_manager.get_session(session_id).is_none() {
                return Err(ExecutionError::RuntimeError(format!(
                    "Session {} not found",
                    session_id
                )));
            }
        }
        let access_label = Self::access_mode_label(final_access_mode);
        let begin_description = match &session_id {
            Some(session_id) => format!(
                "BEGIN TRANSACTION (Session: {}) - {} isolation level, {} access mode",
                session_id,
                final_isolation_level.as_str(),
                access_label
            ),
            None => format!(
                "BEGIN TRANSACTION - {} isolation level, {} access mode",
                final_isolation_level.as_str(),
                access_label
            ),
        };
        // `session_id` is moved into the state here; the description above only borrowed it.
        let mut transaction = match session_id {
            Some(session_id) => TransactionState::new_with_session(
                txn_isolation_level,
                final_access_mode,
                session_id,
            ),
            None => TransactionState::new(txn_isolation_level, final_access_mode),
        };
        let transaction_id = transaction.id;
        let wal_entry = WALEntry::new(
            WALEntryType::TransactionBegin,
            transaction_id,
            self.wal.next_global_sequence(),
            0,
            None,
            begin_description.clone(),
        );
        self.wal.write_entry(wal_entry).map_err(|e| {
            ExecutionError::RuntimeError(format!("Failed to write BEGIN to WAL: {}", e))
        })?;
        transaction.add_operation(OperationType::Other, begin_description);
        self.active_transactions
            .write()
            .map_err(Self::transactions_lock_error)?
            .insert(transaction_id, Arc::new(Mutex::new(transaction)));
        Ok(transaction_id)
    }
    /// Commits an active transaction, writing a COMMIT entry to the WAL first.
    ///
    /// The transaction stays in the map (no longer active) until
    /// [`Self::cleanup_completed_transactions`] removes it.
    ///
    /// # Errors
    ///
    /// Returns `RuntimeError` in any of these cases:
    /// * the transaction is unknown or not active;
    /// * the WAL write fails (the transaction then stays active);
    /// * a lock is poisoned.
    pub fn commit_transaction(&self, transaction_id: TransactionId) -> Result<(), ExecutionError> {
        self.finish_transaction(
            transaction_id,
            WALEntryType::TransactionCommit,
            "COMMIT",
            |transaction| {
                transaction.commit();
            },
        )
    }
    /// Rolls back an active transaction, writing a ROLLBACK entry to the WAL first.
    ///
    /// # Errors
    ///
    /// Same conditions as [`Self::commit_transaction`].
    pub fn rollback_transaction(
        &self,
        transaction_id: TransactionId,
    ) -> Result<(), ExecutionError> {
        self.finish_transaction(
            transaction_id,
            WALEntryType::TransactionRollback,
            "ROLLBACK",
            |transaction| {
                transaction.rollback();
            },
        )
    }
    /// Stores characteristics to be applied to the next started transaction.
    ///
    /// `None` arguments keep any previously pending value, falling back to the defaults. The
    /// isolation level is not validated here; an unsupported level is rejected when the next
    /// transaction starts.
    ///
    /// # Errors
    ///
    /// Returns `RuntimeError` if the characteristics lock is poisoned.
    pub fn set_next_transaction_characteristics(
        &self,
        isolation_level: Option<IsolationLevel>,
        access_mode: Option<AccessMode>,
    ) -> Result<(), ExecutionError> {
        let mut next_characteristics = self
            .next_transaction_characteristics
            .lock()
            .map_err(|_| ExecutionError::RuntimeError("Failed to acquire lock".to_string()))?;
        let merged =
            self.resolve_characteristics(isolation_level, access_mode, *next_characteristics);
        *next_characteristics = Some(merged);
        Ok(())
    }
    /// Returns a shared handle to the transaction's state, or `None` if it is not tracked.
    ///
    /// # Errors
    ///
    /// Returns `RuntimeError` if the transaction map lock is poisoned.
    pub fn get_transaction(
        &self,
        transaction_id: TransactionId,
    ) -> Result<Option<Arc<Mutex<TransactionState>>>, ExecutionError> {
        let active_txns = self
            .active_transactions
            .read()
            .map_err(Self::transactions_lock_error)?;
        Ok(active_txns.get(&transaction_id).cloned())
    }
    /// Returns the ids of all tracked transactions that are still active (order unspecified).
    ///
    /// # Errors
    ///
    /// Returns `RuntimeError` if any lock is poisoned.
    pub fn get_active_transaction_ids(&self) -> Result<Vec<TransactionId>, ExecutionError> {
        let active_txns = self
            .active_transactions
            .read()
            .map_err(Self::transactions_lock_error)?;
        let mut active_ids = Vec::new();
        for (id, txn_arc) in active_txns.iter() {
            if Self::lock_transaction(txn_arc)?.is_active() {
                active_ids.push(*id);
            }
        }
        Ok(active_ids)
    }
    /// Removes every tracked transaction that is no longer active.
    ///
    /// Returns how many entries were removed.
    ///
    /// # Errors
    ///
    /// Returns `RuntimeError` if any lock is poisoned. In that case nothing is removed.
    pub fn cleanup_completed_transactions(&self) -> Result<usize, ExecutionError> {
        let mut active_txns = self
            .active_transactions
            .write()
            .map_err(Self::transactions_lock_error)?;
        // Collect first so that a poisoned transaction lock aborts before anything is removed.
        let mut finished = Vec::new();
        for (id, txn_arc) in active_txns.iter() {
            if !Self::lock_transaction(txn_arc)?.is_active() {
                finished.push(*id);
            }
        }
        for id in &finished {
            active_txns.remove(id);
        }
        Ok(finished.len())
    }
    /// Records an operation on an active transaction and appends it to the WAL.
    ///
    /// # Errors
    ///
    /// Returns `RuntimeError` in any of these cases:
    /// * the transaction is unknown or not active;
    /// * the WAL write fails;
    /// * a lock is poisoned.
    pub fn log_operation(
        &self,
        transaction_id: TransactionId,
        operation_type: OperationType,
        description: String,
    ) -> Result<(), ExecutionError> {
        let txn_arc = self.require_transaction(transaction_id)?;
        let mut transaction = Self::lock_transaction(&txn_arc)?;
        Self::ensure_active(&transaction, transaction_id)?;
        // The WAL entry carries the sequence number read *after* recording the operation.
        // NOTE(review): if the WAL write below fails, the in-memory log already contains
        // this operation.
        transaction.add_operation(operation_type.clone(), description.clone());
        let txn_sequence = transaction.get_sequence_number();
        let wal_entry = WALEntry::new(
            WALEntryType::TransactionOperation,
            transaction_id,
            self.wal.next_global_sequence(),
            txn_sequence,
            Some(operation_type),
            description,
        );
        self.wal.write_entry(wal_entry).map_err(|e| {
            ExecutionError::RuntimeError(format!("Failed to write operation to WAL: {}", e))
        })?;
        Ok(())
    }
    /// Returns the ids of active transactions bound to `session_id` (order unspecified).
    ///
    /// # Errors
    ///
    /// Returns `RuntimeError` if any lock is poisoned.
    pub fn get_session_transactions(
        &self,
        session_id: &str,
    ) -> Result<Vec<TransactionId>, ExecutionError> {
        let active_txns = self
            .active_transactions
            .read()
            .map_err(Self::transactions_lock_error)?;
        let mut session_txns = Vec::new();
        for (id, txn_arc) in active_txns.iter() {
            let transaction = Self::lock_transaction(txn_arc)?;
            if transaction.is_active()
                && matches!(
                    transaction.session_id,
                    Some(ref txn_session_id) if txn_session_id == session_id
                )
            {
                session_txns.push(*id);
            }
        }
        Ok(session_txns)
    }
    /// Returns the current graph of the session that owns the transaction, if any.
    ///
    /// Yields `Ok(None)` in any of these cases:
    /// * the transaction is unknown;
    /// * it has no session;
    /// * no session manager is attached;
    /// * the session is missing;
    /// * the session lock cannot be read.
    ///
    /// # Errors
    ///
    /// Returns `RuntimeError` if the transaction map or transaction lock is poisoned.
    pub fn get_transaction_current_graph(
        &self,
        transaction_id: TransactionId,
    ) -> Result<Option<String>, ExecutionError> {
        let txn_arc = match self.get_transaction(transaction_id)? {
            Some(txn_arc) => txn_arc,
            None => return Ok(None),
        };
        // Copy the session id out so no transaction lock is held while session locks are taken.
        let session_id = Self::lock_transaction(&txn_arc)?.session_id.clone();
        if let (Some(session_id), Some(session_manager)) = (session_id, &self.session_manager) {
            if let Some(session_arc) = session_manager.get_session(&session_id) {
                if let Ok(session) = session_arc.read() {
                    return Ok(session.current_graph.clone());
                }
            }
        }
        Ok(None)
    }
    /// Counts tracked transactions by status.
    ///
    /// The count includes finished transactions that have not been cleaned up yet.
    ///
    /// # Errors
    ///
    /// Returns `RuntimeError` if any lock is poisoned.
    pub fn get_statistics(&self) -> Result<TransactionStatistics, ExecutionError> {
        let active_txns = self
            .active_transactions
            .read()
            .map_err(Self::transactions_lock_error)?;
        let mut stats = TransactionStatistics::default();
        for txn_arc in active_txns.values() {
            let transaction = Self::lock_transaction(txn_arc)?;
            stats.total_transactions += 1;
            match &transaction.status {
                super::state::TransactionStatus::Active => stats.active_transactions += 1,
                super::state::TransactionStatus::Committed => stats.committed_transactions += 1,
                super::state::TransactionStatus::RolledBack => stats.rolled_back_transactions += 1,
                super::state::TransactionStatus::Failed(_) => stats.failed_transactions += 1,
                _ => {}
            }
        }
        Ok(stats)
    }
    /// Shared COMMIT/ROLLBACK path.
    ///
    /// Validates that the transaction is active, then writes the WAL entry, then records the
    /// operation, and finally applies `finalize`.
    fn finish_transaction(
        &self,
        transaction_id: TransactionId,
        entry_type: WALEntryType,
        label: &str,
        finalize: impl FnOnce(&mut TransactionState),
    ) -> Result<(), ExecutionError> {
        // Only this transaction's mutex is held during WAL I/O; the map lock is already released.
        let txn_arc = self.require_transaction(transaction_id)?;
        let mut transaction = Self::lock_transaction(&txn_arc)?;
        Self::ensure_active(&transaction, transaction_id)?;
        let final_sequence = transaction.get_sequence_number();
        let description = format!(
            "{} TRANSACTION - final sequence number: {}",
            label, final_sequence
        );
        let wal_entry = WALEntry::new(
            entry_type,
            transaction_id,
            self.wal.next_global_sequence(),
            final_sequence,
            None,
            description.clone(),
        );
        self.wal.write_entry(wal_entry).map_err(|e| {
            ExecutionError::RuntimeError(format!("Failed to write {} to WAL: {}", label, e))
        })?;
        transaction.add_operation(OperationType::Other, description);
        finalize(&mut *transaction);
        Ok(())
    }
    /// Merges characteristics, each field falling back from explicit to pending to default.
    fn resolve_characteristics(
        &self,
        isolation_level: Option<IsolationLevel>,
        access_mode: Option<AccessMode>,
        pending: Option<(IsolationLevel, AccessMode)>,
    ) -> (IsolationLevel, AccessMode) {
        let (pending_isolation, pending_access) = match pending {
            Some((iso, acc)) => (Some(iso), Some(acc)),
            None => (None, None),
        };
        (
            isolation_level
                .or(pending_isolation)
                .unwrap_or(self.default_isolation_level),
            access_mode
                .or(pending_access)
                .unwrap_or(AccessMode::ReadWrite),
        )
    }
    /// SQL-style label for an access mode, used in BEGIN descriptions.
    fn access_mode_label(mode: AccessMode) -> &'static str {
        if mode == AccessMode::ReadOnly {
            "READ ONLY"
        } else {
            "READ WRITE"
        }
    }
    /// Looks up a transaction handle, turning absence into a "not found" error.
    fn require_transaction(
        &self,
        transaction_id: TransactionId,
    ) -> Result<Arc<Mutex<TransactionState>>, ExecutionError> {
        self.get_transaction(transaction_id)?.ok_or_else(|| {
            ExecutionError::RuntimeError(format!("Transaction {} not found", transaction_id))
        })
    }
    /// Errors unless `transaction` is still active.
    fn ensure_active(
        transaction: &TransactionState,
        transaction_id: TransactionId,
    ) -> Result<(), ExecutionError> {
        if transaction.is_active() {
            Ok(())
        } else {
            Err(ExecutionError::RuntimeError(format!(
                "Transaction {} is not active",
                transaction_id
            )))
        }
    }
    /// Locks a single transaction, mapping poisoning to an `ExecutionError`.
    fn lock_transaction(
        txn: &Mutex<TransactionState>,
    ) -> Result<std::sync::MutexGuard<'_, TransactionState>, ExecutionError> {
        txn.lock().map_err(|_| {
            ExecutionError::RuntimeError("Failed to acquire transaction lock".to_string())
        })
    }
    /// Error used when the transaction map lock is poisoned; usable directly with `map_err`.
    fn transactions_lock_error<T>(_: T) -> ExecutionError {
        ExecutionError::RuntimeError("Failed to acquire transactions lock".to_string())
    }
}
/// Transaction counts produced by `TransactionManager::get_statistics`.
///
/// The counts cover every transaction the manager still tracks. That includes finished
/// transactions not yet removed by `cleanup_completed_transactions`. A status other than the
/// four listed below is counted only in `total_transactions`.
#[derive(Debug, Clone, Default)]
pub struct TransactionStatistics {
    /// Number of tracked transactions, regardless of status.
    pub total_transactions: u64,
    /// Tracked transactions whose status is `Active`.
    pub active_transactions: u64,
    /// Tracked transactions whose status is `Committed`.
    pub committed_transactions: u64,
    /// Tracked transactions whose status is `RolledBack`.
    pub rolled_back_transactions: u64,
    /// Tracked transactions whose status is `Failed(_)`.
    pub failed_transactions: u64,
}