use rustc_hash::FxHashMap;
use std::sync::Arc;
use std::time::Instant;
use crate::core::{Error, IsolationLevel, Result, Schema, SchemaColumn};
use crate::storage::mvcc::{get_fast_timestamp, MvccError, TransactionRegistry};
use crate::storage::traits::{QueryResult, Table, Transaction};
use crate::storage::Expression;
/// Lifecycle phase of an `MvccTransaction`.
///
/// A transaction starts out `Active`, is `Committing` while `commit` runs, and
/// normally finishes as either `Committed` or `RolledBack`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TransactionState {
    /// Open; operations guarded by `check_active` are accepted.
    Active,
    /// `commit` is currently running.
    Committing,
    /// Finished by a successful commit.
    Committed,
    /// Finished by a rollback (or by a failed commit).
    RolledBack,
}
/// A single MVCC transaction: its lifecycle state, optional isolation override,
/// savepoints, and the journal of DDL it performed (used to undo that DDL on rollback).
///
/// Catalog and table access go through `engine_operations`, which must be attached
/// with `set_engine_operations` before those operations can succeed.
pub struct MvccTransaction {
/// Transaction id, as handed to `MvccTransaction::new`.
id: i64,
/// Monotonic instant captured when this value was constructed.
start_time: Instant,
/// Current lifecycle phase; most operations first require it to be `Active`.
state: TransactionState,
/// Per-transaction table map keyed by table name.
/// NOTE(review): nothing in this file inserts a new entry (only `rename_table`
/// re-keys an existing one), so as far as this file shows the map stays empty and
/// the loops over it are no-ops — confirm whether that is intended.
tables: FxHashMap<String, Box<dyn Table>>,
/// Per-transaction isolation override; `None` means "use the registry's global level".
isolation_level: Option<IsolationLevel>,
/// Shared registry used to start/complete/abort commits and to store the override.
registry: Arc<TransactionRegistry>,
/// Begin sequence number supplied by the creator.
begin_seq: i64,
/// Presumably the most recently accessed table name.
/// NOTE(review): only ever cleared or renamed in this file, never set.
last_table_name: Option<String>,
/// Engine hooks; `None` until `set_engine_operations` is called.
engine_operations: Option<Arc<dyn TransactionEngineOperations>>,
/// Savepoint name -> timestamp obtained from `get_fast_timestamp` at creation.
savepoints: FxHashMap<String, i64>,
/// Lowercased names of tables created by this transaction; dropped again on rollback.
created_tables: Vec<String>,
/// Lowercased names and schemas of tables dropped by this transaction; recreated
/// (empty — their data is not restored) on rollback.
dropped_tables: Vec<(String, Schema)>,
}
/// Storage-engine hooks that an `MvccTransaction` delegates to.
///
/// The transaction keeps its own lifecycle state and DDL journal; catalog changes
/// and table data access are routed through this trait. Implementations are
/// `Send + Sync` and are held by the transaction behind an `Arc`.
pub trait TransactionEngineOperations: Send + Sync {
    // ---- table lookup ----

    /// Returns a handle to `table_name` for transaction `txn_id`.
    /// Used by `get_table` and to capture a table's schema before dropping it.
    fn get_table_for_transaction(&self, txn_id: i64, table_name: &str) -> Result<Box<dyn Table>>;

    /// Returns the names of the tables currently known to the engine.
    fn list_tables(&self) -> Result<Vec<String>>;

    /// Returns the tables in which transaction `txn_id` has pending changes.
    /// `commit` treats an empty result (together with no DDL) as a read-only commit.
    fn get_tables_with_pending_changes(&self, txn_id: i64) -> Result<Vec<Box<dyn Table>>>;

    // ---- catalog (DDL) ----

    /// Creates table `name` with `schema`; also used to recreate dropped tables on rollback.
    fn create_table(&self, name: &str, schema: Schema) -> Result<Box<dyn Table>>;

    /// Drops table `name`; also used to remove created tables on rollback.
    fn drop_table(&self, name: &str) -> Result<()>;

    /// Renames table `old_name` to `new_name`.
    fn rename_table(&self, old_name: &str, new_name: &str) -> Result<()>;

    /// Creates schema `name`. Not called by `MvccTransaction`, whose schema methods are no-ops.
    fn create_schema(&self, name: &str) -> Result<()>;

    /// Drops schema `name`. Not called by `MvccTransaction`, whose schema methods are no-ops.
    fn drop_schema(&self, name: &str) -> Result<()>;

    // ---- commit / rollback ----

    /// Commits the pending changes of `table` made by transaction `txn_id`.
    fn commit_table(&self, txn_id: i64, table: &dyn Table) -> Result<()>;

    /// Called once after every dirty table has gone through `commit_table`.
    fn commit_all_tables(&self, txn_id: i64) -> Result<()>;

    /// Records a completed commit; `MvccTransaction::commit` ignores its error.
    fn record_commit(&self, txn_id: i64) -> Result<()>;

    /// Discards the changes of `table` made by transaction `txn_id`; cannot fail.
    fn rollback_table(&self, txn_id: i64, table: &dyn Table);

    /// Records a rollback; `MvccTransaction::rollback` ignores its error.
    fn record_rollback(&self, txn_id: i64) -> Result<()>;
}
impl MvccTransaction {
    /// Creates an `Active` transaction with the given id and begin sequence.
    ///
    /// No engine operations are attached; call
    /// [`set_engine_operations`](Self::set_engine_operations) before table access or DDL.
    pub fn new(id: i64, begin_seq: i64, registry: Arc<TransactionRegistry>) -> Self {
        Self {
            id,
            start_time: Instant::now(),
            state: TransactionState::Active,
            tables: FxHashMap::default(),
            isolation_level: None,
            registry,
            begin_seq,
            last_table_name: None,
            engine_operations: None,
            savepoints: FxHashMap::default(),
            created_tables: Vec::new(),
            dropped_tables: Vec::new(),
        }
    }

    /// Attaches (or replaces) the storage-engine hooks used by this transaction.
    pub fn set_engine_operations(&mut self, ops: Arc<dyn TransactionEngineOperations>) {
        self.engine_operations = Some(ops);
    }

    /// Instant at which this transaction object was constructed.
    pub fn start_time(&self) -> Instant {
        self.start_time
    }

    /// Begin sequence number supplied at construction.
    pub fn begin_seq(&self) -> i64 {
        self.begin_seq
    }

    /// Current lifecycle state.
    pub fn state(&self) -> TransactionState {
        self.state
    }

    /// Effective isolation level: the per-transaction override if one was set,
    /// otherwise the registry's global level.
    pub fn get_isolation_level(&self) -> IsolationLevel {
        self.isolation_level
            .unwrap_or_else(|| self.registry.get_global_isolation_level())
    }

    /// Fails with `MvccError::TransactionClosed` unless the transaction is `Active`.
    fn check_active(&self) -> Result<()> {
        if self.state != TransactionState::Active {
            return Err(MvccError::TransactionClosed.into());
        }
        Ok(())
    }

    /// Accepted but currently a no-op (only validates that the transaction is active).
    pub fn create_schema(&mut self, _name: &str) -> Result<()> {
        self.check_active()?;
        Ok(())
    }

    /// Accepted but currently a no-op (only validates that the transaction is active).
    pub fn drop_schema(&mut self, _name: &str) -> Result<()> {
        self.check_active()?;
        Ok(())
    }

    /// Returns the attached engine hooks, or an internal error if none were set.
    fn get_engine_ops(&self) -> Result<&Arc<dyn TransactionEngineOperations>> {
        self.engine_operations
            .as_ref()
            .ok_or_else(|| Error::internal("engine operations not set"))
    }

    /// Releases per-transaction bookkeeping once the transaction has finished.
    ///
    /// Clears the table map, the DDL journals and the savepoints, and removes any
    /// per-transaction isolation override from the registry.
    fn cleanup(&mut self) {
        self.last_table_name = None;
        self.tables.clear();
        self.created_tables.clear();
        self.dropped_tables.clear();
        // Savepoints are meaningless once the transaction is finished. Clearing them
        // keeps `has_savepoint` / `get_savepoint_ts` (which do not check the state)
        // from reporting stale entries after commit or rollback.
        self.savepoints.clear();
        self.registry.remove_transaction_isolation_level(self.id);
    }

    /// True when this transaction performed no DDL and no table in `self.tables`
    /// reports local changes.
    fn is_read_only(&self) -> bool {
        self.created_tables.is_empty()
            && self.dropped_tables.is_empty()
            && !self.tables.values().any(|table| table.has_local_changes())
    }

    /// Creates (or overwrites) savepoint `name`, stamped with the current fast timestamp.
    pub fn create_savepoint(&mut self, name: &str) -> Result<()> {
        self.check_active()?;
        let timestamp = get_fast_timestamp();
        self.savepoints.insert(name.to_string(), timestamp);
        Ok(())
    }

    /// Forgets savepoint `name`.
    ///
    /// # Errors
    /// Fails if the transaction is not active or the savepoint does not exist.
    pub fn release_savepoint(&mut self, name: &str) -> Result<()> {
        self.check_active()?;
        if self.savepoints.remove(name).is_none() {
            return Err(Error::invalid_argument(format!(
                "savepoint '{}' does not exist",
                name
            )));
        }
        Ok(())
    }

    /// Rolls tables in `self.tables` back to savepoint `name` and discards every
    /// savepoint with a later timestamp. The target savepoint itself is kept (as is
    /// any savepoint sharing its timestamp).
    ///
    /// NOTE(review): `self.tables` is never populated in this file, so the per-table
    /// loop is currently a no-op; rolling table data back presumably happens elsewhere
    /// (callers can read the timestamp via `get_savepoint_timestamp`) — confirm.
    ///
    /// # Errors
    /// Fails if the transaction is not active or the savepoint does not exist.
    pub fn rollback_to_savepoint(&mut self, name: &str) -> Result<()> {
        self.check_active()?;
        let savepoint_ts = self.savepoints.get(name).copied().ok_or_else(|| {
            Error::invalid_argument(format!("savepoint '{}' does not exist", name))
        })?;
        for table in self.tables.values() {
            table.rollback_to_timestamp(savepoint_ts);
        }
        self.savepoints.retain(|_, &mut ts| ts <= savepoint_ts);
        Ok(())
    }

    /// Whether savepoint `name` currently exists.
    pub fn has_savepoint(&self, name: &str) -> bool {
        self.savepoints.contains_key(name)
    }

    /// Timestamp recorded for savepoint `name`, if it exists.
    pub fn get_savepoint_ts(&self, name: &str) -> Option<i64> {
        self.savepoints.get(name).copied()
    }
}
impl Transaction for MvccTransaction {
fn id(&self) -> i64 {
self.id
}
fn begin(&mut self) -> Result<()> {
self.check_active()
}
fn commit(&mut self) -> Result<()> {
self.check_active()?;
self.state = TransactionState::Committing;
let tables_with_changes = if let Some(ops) = &self.engine_operations {
ops.get_tables_with_pending_changes(self.id)?
} else {
Vec::new()
};
let is_read_only = self.created_tables.is_empty()
&& self.dropped_tables.is_empty()
&& tables_with_changes.is_empty();
if !is_read_only {
self.registry.start_commit(self.id);
if let Some(ops) = &self.engine_operations {
for table in tables_with_changes.iter() {
if let Err(e) = ops.commit_table(self.id, table.as_ref()) {
self.registry.abort_transaction(self.id);
self.state = TransactionState::RolledBack;
self.cleanup();
return Err(e);
}
}
}
if let Some(ops) = &self.engine_operations {
if let Err(e) = ops.commit_all_tables(self.id) {
self.registry.abort_transaction(self.id);
self.state = TransactionState::RolledBack;
self.cleanup();
return Err(e);
}
}
self.registry.complete_commit(self.id);
if let Some(ops) = &self.engine_operations {
let _ = ops.record_commit(self.id);
}
} else {
self.registry.complete_commit(self.id);
}
self.state = TransactionState::Committed;
self.cleanup();
Ok(())
}
fn rollback(&mut self) -> Result<()> {
self.check_active()?;
let is_read_only = self.is_read_only();
self.registry.abort_transaction(self.id);
let mut ddl_rollback_errors: Vec<String> = Vec::new();
if let Some(ops) = &self.engine_operations {
for table_name in self.created_tables.iter().rev() {
if let Err(e) = ops.drop_table(table_name) {
ddl_rollback_errors.push(format!(
"Failed to drop table '{}' during rollback: {}",
table_name, e
));
}
}
for (table_name, schema) in self.dropped_tables.iter().rev() {
if let Err(e) = ops.create_table(table_name, schema.clone()) {
ddl_rollback_errors.push(format!(
"Failed to recreate table '{}' during rollback: {}",
table_name, e
));
} else {
eprintln!(
"Warning: Table '{}' was recreated during rollback but data was lost. \
DROP TABLE is not fully transactional.",
table_name
);
}
}
}
if !ddl_rollback_errors.is_empty() {
eprintln!(
"Warning: DDL rollback encountered {} error(s): {:?}",
ddl_rollback_errors.len(),
ddl_rollback_errors
);
}
for (_, table) in self.tables.iter_mut() {
table.rollback();
}
if let Some(ops) = &self.engine_operations {
for (_, table) in self.tables.iter() {
ops.rollback_table(self.id, table.as_ref());
}
}
if !is_read_only {
if let Some(ops) = &self.engine_operations {
let _ = ops.record_rollback(self.id);
}
}
self.state = TransactionState::RolledBack;
self.cleanup();
Ok(())
}
fn create_savepoint(&mut self, name: &str) -> Result<()> {
MvccTransaction::create_savepoint(self, name)
}
fn release_savepoint(&mut self, name: &str) -> Result<()> {
MvccTransaction::release_savepoint(self, name)
}
fn rollback_to_savepoint(&mut self, name: &str) -> Result<()> {
MvccTransaction::rollback_to_savepoint(self, name)
}
fn get_savepoint_timestamp(&self, name: &str) -> Option<i64> {
MvccTransaction::get_savepoint_ts(self, name)
}
fn set_isolation_level(&mut self, level: IsolationLevel) -> Result<()> {
self.check_active()?;
self.isolation_level = Some(level);
self.registry
.set_transaction_isolation_level(self.id, level);
Ok(())
}
fn create_table(&mut self, name: &str, schema: Schema) -> Result<Box<dyn Table>> {
self.check_active()?;
let ops = self.get_engine_ops()?;
let table = ops.create_table(name, schema)?;
self.created_tables.push(name.to_lowercase());
Ok(table)
}
fn drop_table(&mut self, name: &str) -> Result<()> {
self.check_active()?;
let schema = {
let ops = self.get_engine_ops()?;
let table = ops.get_table_for_transaction(self.id, name)?;
table.schema().clone()
};
self.dropped_tables.push((name.to_lowercase(), schema));
let ops = self.get_engine_ops()?;
ops.drop_table(name)?;
self.tables.remove(name);
if let Some(last_name) = &self.last_table_name {
if last_name == name {
self.last_table_name = None;
}
}
Ok(())
}
fn get_table(&self, name: &str) -> Result<Box<dyn Table>> {
self.check_active()?;
let ops = self.get_engine_ops()?;
ops.get_table_for_transaction(self.id, name)
}
fn list_tables(&self) -> Result<Vec<String>> {
self.check_active()?;
let ops = self.get_engine_ops()?;
let mut tables = ops.list_tables()?;
tables.retain(|t| !self.dropped_tables.iter().any(|(name, _)| name == t));
Ok(tables)
}
fn rename_table(&mut self, old_name: &str, new_name: &str) -> Result<()> {
self.check_active()?;
let ops = self.get_engine_ops()?;
ops.rename_table(old_name, new_name)?;
if let Some(table) = self.tables.remove(old_name) {
self.tables.insert(new_name.to_string(), table);
}
if let Some(last_name) = &self.last_table_name {
if last_name == old_name {
self.last_table_name = Some(new_name.to_string());
}
}
Ok(())
}
fn create_schema(&mut self, _name: &str) -> Result<()> {
self.check_active()?;
Ok(())
}
fn drop_schema(&mut self, _name: &str) -> Result<()> {
self.check_active()?;
Ok(())
}
fn create_table_index(
&mut self,
table_name: &str,
index_name: &str,
columns: &[String],
is_unique: bool,
) -> Result<()> {
self.check_active()?;
let table = self.get_table(table_name)?;
let col_refs: Vec<&str> = columns.iter().map(|s| s.as_str()).collect();
table.create_index(index_name, &col_refs, is_unique)
}
fn drop_table_index(&mut self, table_name: &str, index_name: &str) -> Result<()> {
self.check_active()?;
let table = self.get_table(table_name)?;
table.drop_index(index_name)
}
fn create_table_btree_index(
&mut self,
table_name: &str,
column_name: &str,
is_unique: bool,
custom_name: Option<&str>,
) -> Result<()> {
self.check_active()?;
let table = self.get_table(table_name)?;
table.create_btree_index(column_name, is_unique, custom_name)
}
fn drop_table_btree_index(&mut self, table_name: &str, column_name: &str) -> Result<()> {
self.check_active()?;
let table = self.get_table(table_name)?;
table.drop_btree_index(column_name)
}
fn add_table_column(&mut self, table_name: &str, column: SchemaColumn) -> Result<()> {
self.check_active()?;
let mut table = self.get_table(table_name)?;
table.create_column(&column.name, column.data_type, column.nullable)
}
fn drop_table_column(&mut self, table_name: &str, column_name: &str) -> Result<()> {
self.check_active()?;
let mut table = self.get_table(table_name)?;
table.drop_column(column_name)
}
fn rename_table_column(
&mut self,
table_name: &str,
old_name: &str,
new_name: &str,
) -> Result<()> {
self.check_active()?;
let table = self.get_table(table_name)?;
table.rename_column(old_name, new_name)
}
fn modify_table_column(&mut self, table_name: &str, column: SchemaColumn) -> Result<()> {
self.check_active()?;
let table = self.get_table(table_name)?;
table.modify_column(&column.name, column.data_type, column.nullable)
}
fn select(
&self,
table_name: &str,
columns_to_fetch: &[String],
expr: Option<&dyn Expression>,
_original_columns: Option<&[String]>,
) -> Result<Box<dyn QueryResult>> {
self.check_active()?;
let table = self.get_table(table_name)?;
let col_refs: Vec<&str> = columns_to_fetch.iter().map(|s| s.as_str()).collect();
table.select(&col_refs, expr)
}
fn select_with_aliases(
&self,
table_name: &str,
columns_to_fetch: &[String],
expr: Option<&dyn Expression>,
aliases: &FxHashMap<String, String>,
_original_columns: Option<&[String]>,
) -> Result<Box<dyn QueryResult>> {
self.check_active()?;
let table = self.get_table(table_name)?;
let col_refs: Vec<&str> = columns_to_fetch.iter().map(|s| s.as_str()).collect();
table.select_with_aliases(&col_refs, expr, aliases)
}
fn select_as_of(
&self,
table_name: &str,
columns_to_fetch: &[String],
expr: Option<&dyn Expression>,
temporal_type: &str,
temporal_value: i64,
_original_columns: Option<&[String]>,
) -> Result<Box<dyn QueryResult>> {
self.check_active()?;
let table = self.get_table(table_name)?;
let col_refs: Vec<&str> = columns_to_fetch.iter().map(|s| s.as_str()).collect();
table.select_as_of(&col_refs, expr, temporal_type, temporal_value)
}
}
impl std::fmt::Debug for MvccTransaction {
    /// Formats only the identifying fields (`id`, `state`, `begin_seq`);
    /// all other fields are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            id,
            state,
            begin_seq,
            ..
        } = self;
        let mut builder = f.debug_struct("MvccTransaction");
        builder.field("id", id);
        builder.field("state", state);
        builder.field("begin_seq", begin_seq);
        builder.finish()
    }
}
impl Drop for MvccTransaction {
    /// Aborts a transaction that is dropped while still `Active`, i.e. without an
    /// explicit commit or rollback; transactions in any other state are left alone.
    ///
    /// Only the registry entry and local bookkeeping are released here; DDL already
    /// performed by the transaction is not undone.
    fn drop(&mut self) {
        let still_open = self.state == TransactionState::Active;
        if !still_open {
            return;
        }
        self.registry.abort_transaction(self.id);
        self.cleanup();
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Registers a fresh transaction and wraps it in an `MvccTransaction`
    /// (no engine operations attached). Returns `(txn_id, begin_seq, txn)`.
    fn new_txn() -> (i64, i64, MvccTransaction) {
        let registry = Arc::new(TransactionRegistry::new());
        let (txn_id, begin_seq) = registry.begin_transaction();
        // The transaction keeps its own `Arc` clone, so the registry outlives this helper.
        let txn = MvccTransaction::new(txn_id, begin_seq, Arc::clone(&registry));
        (txn_id, begin_seq, txn)
    }

    #[test]
    fn test_transaction_creation() {
        let (txn_id, begin_seq, txn) = new_txn();
        assert_eq!(txn.id(), txn_id);
        assert_eq!(txn.begin_seq(), begin_seq);
        assert_eq!(txn.state(), TransactionState::Active);
    }

    #[test]
    fn test_transaction_state_transitions() {
        let (_, _, mut txn) = new_txn();
        assert_eq!(txn.state(), TransactionState::Active);
        txn.begin().unwrap();
        assert_eq!(txn.state(), TransactionState::Active);
        txn.commit().unwrap();
        assert_eq!(txn.state(), TransactionState::Committed);
        assert!(txn.begin().is_err());
    }

    #[test]
    fn test_transaction_rollback() {
        let (_, _, mut txn) = new_txn();
        assert_eq!(txn.state(), TransactionState::Active);
        txn.rollback().unwrap();
        assert_eq!(txn.state(), TransactionState::RolledBack);
        assert!(txn.begin().is_err());
    }

    #[test]
    fn test_transaction_isolation_level() {
        let (_, _, mut txn) = new_txn();
        let default_level = txn.get_isolation_level();
        assert_eq!(default_level, IsolationLevel::ReadCommitted);
        txn.set_isolation_level(IsolationLevel::SnapshotIsolation)
            .unwrap();
        assert_eq!(txn.get_isolation_level(), IsolationLevel::SnapshotIsolation);
    }

    #[test]
    fn test_transaction_double_commit() {
        let (_, _, mut txn) = new_txn();
        txn.commit().unwrap();
        assert!(txn.commit().is_err());
    }

    #[test]
    fn test_transaction_commit_after_rollback() {
        let (_, _, mut txn) = new_txn();
        txn.rollback().unwrap();
        assert!(txn.commit().is_err());
    }

    #[test]
    fn test_transaction_debug() {
        let (_, _, txn) = new_txn();
        let debug_str = format!("{:?}", txn);
        assert!(debug_str.contains("MvccTransaction"));
        assert!(debug_str.contains("Active"));
    }

    #[test]
    fn test_savepoint_create_and_release() {
        let (_, _, mut txn) = new_txn();
        txn.create_savepoint("sp1").unwrap();
        assert!(txn.has_savepoint("sp1"));
        assert!(txn.get_savepoint_ts("sp1").is_some());
        txn.release_savepoint("sp1").unwrap();
        assert!(!txn.has_savepoint("sp1"));
        assert!(txn.release_savepoint("sp1").is_err());
    }

    #[test]
    fn test_rollback_to_savepoint_keeps_target() {
        let (_, _, mut txn) = new_txn();
        txn.create_savepoint("sp1").unwrap();
        txn.rollback_to_savepoint("sp1").unwrap();
        assert!(txn.has_savepoint("sp1"));
        assert!(txn.rollback_to_savepoint("missing").is_err());
    }

    #[test]
    fn test_savepoints_rejected_after_commit() {
        let (_, _, mut txn) = new_txn();
        txn.commit().unwrap();
        assert!(txn.create_savepoint("sp1").is_err());
        assert!(txn.rollback_to_savepoint("sp1").is_err());
    }

    #[test]
    fn test_table_access_requires_engine_operations() {
        let (_, _, txn) = new_txn();
        assert!(txn.get_table("t").is_err());
        assert!(txn.list_tables().is_err());
    }
}