use crate::backends::DatabaseTransaction;
use crate::database::ManagedPool;
use crate::error::{ModelError, ModelResult};
use tracing::{debug, warn};
/// SQL transaction isolation levels that can be requested when a transaction starts.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IsolationLevel {
    /// Rendered as `READ UNCOMMITTED`.
    ReadUncommitted,
    /// Rendered as `READ COMMITTED`.
    ReadCommitted,
    /// Rendered as `REPEATABLE READ`.
    RepeatableRead,
    /// Rendered as `SERIALIZABLE`.
    Serializable,
}

impl IsolationLevel {
    /// Returns the SQL keyword(s) for this level, suitable for splicing into a
    /// `SET TRANSACTION ISOLATION LEVEL …` statement.
    pub fn as_sql(&self) -> &'static str {
        let keywords = match *self {
            Self::ReadUncommitted => "READ UNCOMMITTED",
            Self::ReadCommitted => "READ COMMITTED",
            Self::RepeatableRead => "REPEATABLE READ",
            Self::Serializable => "SERIALIZABLE",
        };
        keywords
    }
}
/// Options applied when a transaction is started.
#[derive(Debug, Clone)]
pub struct TransactionConfig {
    /// Isolation level to request via `SET TRANSACTION ISOLATION LEVEL`.
    /// `None` means no such statement is sent.
    pub isolation_level: Option<IsolationLevel>,
    /// When `true`, `SET TRANSACTION READ ONLY` is issued right after the transaction begins.
    pub read_only: bool,
    /// When `true`, `with_transaction` re-runs an attempt that failed with a serialization failure.
    pub auto_retry: bool,
    /// Extra attempts `with_transaction` may make after the first one when `auto_retry` is on.
    pub max_retries: u32,
}

impl Default for TransactionConfig {
    /// No explicit isolation level, read-write, no automatic retry, retry budget of 3.
    fn default() -> Self {
        const DEFAULT_MAX_RETRIES: u32 = 3;
        TransactionConfig {
            auto_retry: false,
            isolation_level: None,
            max_retries: DEFAULT_MAX_RETRIES,
            read_only: false,
        }
    }
}
pub struct Transaction {
inner: Option<Box<dyn DatabaseTransaction>>,
config: TransactionConfig,
committed: bool,
}
impl Transaction {
pub async fn begin(
pool: &ManagedPool,
config: TransactionConfig,
) -> Result<Transaction, ModelError> {
debug!("Beginning transaction with config: {:?}", config);
let mut tx = pool
.begin_transaction()
.await
.map_err(|e| ModelError::Transaction(format!("Failed to begin transaction: {}", e)))?;
if let Some(isolation_level) = config.isolation_level {
let sql = format!(
"SET TRANSACTION ISOLATION LEVEL {}",
isolation_level.as_sql()
);
tx.execute(&sql, &[]).await.map_err(|e| {
ModelError::Transaction(format!("Failed to set isolation level: {}", e))
})?;
debug!("Transaction isolation level set to: {:?}", isolation_level);
}
if config.read_only {
tx.execute("SET TRANSACTION READ ONLY", &[])
.await
.map_err(|e| {
ModelError::Transaction(format!("Failed to set read-only mode: {}", e))
})?;
debug!("Transaction set to read-only mode");
}
Ok(Transaction {
inner: Some(tx),
config,
committed: false,
})
}
pub async fn begin_default(pool: &ManagedPool) -> Result<Transaction, ModelError> {
Self::begin(pool, TransactionConfig::default()).await
}
pub async fn begin_read_only(pool: &ManagedPool) -> Result<Transaction, ModelError> {
let config = TransactionConfig {
read_only: true,
..Default::default()
};
Self::begin(pool, config).await
}
pub async fn begin_serializable(pool: &ManagedPool) -> Result<Transaction, ModelError> {
let config = TransactionConfig {
isolation_level: Some(IsolationLevel::Serializable),
auto_retry: true, ..Default::default()
};
Self::begin(pool, config).await
}
pub fn is_active(&self) -> bool {
self.inner.is_some() && !self.committed
}
pub fn as_mut(&mut self) -> Option<&mut Box<dyn DatabaseTransaction>> {
self.inner.as_mut()
}
pub fn as_ref(&self) -> Option<&Box<dyn DatabaseTransaction>> {
self.inner.as_ref()
}
pub async fn execute(
&mut self,
sql: &str,
params: &[crate::backends::DatabaseValue],
) -> Result<u64, ModelError> {
if let Some(tx) = &mut self.inner {
tx.execute(sql, params)
.await
.map_err(|e| ModelError::Transaction(format!("Transaction query failed: {}", e)))
} else {
Err(ModelError::Transaction(
"Transaction has been consumed".to_string(),
))
}
}
pub async fn fetch_all(
&mut self,
sql: &str,
params: &[crate::backends::DatabaseValue],
) -> Result<Vec<Box<dyn crate::backends::DatabaseRow>>, ModelError> {
if let Some(tx) = &mut self.inner {
tx.fetch_all(sql, params)
.await
.map_err(|e| ModelError::Transaction(format!("Transaction query failed: {}", e)))
} else {
Err(ModelError::Transaction(
"Transaction has been consumed".to_string(),
))
}
}
pub async fn fetch_optional(
&mut self,
sql: &str,
params: &[crate::backends::DatabaseValue],
) -> Result<Option<Box<dyn crate::backends::DatabaseRow>>, ModelError> {
if let Some(tx) = &mut self.inner {
tx.fetch_optional(sql, params)
.await
.map_err(|e| ModelError::Transaction(format!("Transaction query failed: {}", e)))
} else {
Err(ModelError::Transaction(
"Transaction has been consumed".to_string(),
))
}
}
pub async fn execute_with<F, Fut, R>(&mut self, f: F) -> Result<R, ModelError>
where
F: FnOnce(&mut Self) -> Fut,
Fut: std::future::Future<Output = Result<R, ModelError>>,
{
if self.inner.is_some() {
f(self).await
} else {
Err(ModelError::Transaction(
"Transaction has been consumed".to_string(),
))
}
}
pub async fn commit(mut self) -> ModelResult<()> {
if let Some(tx) = self.inner.take() {
debug!("Committing transaction");
tx.commit().await.map_err(|e| {
ModelError::Transaction(format!("Failed to commit transaction: {}", e))
})?;
debug!("Transaction committed successfully");
Ok(())
} else {
Err(ModelError::Transaction(
"Transaction has already been consumed".to_string(),
))
}
}
pub async fn rollback(mut self) -> ModelResult<()> {
if let Some(tx) = self.inner.take() {
debug!("Rolling back transaction");
tx.rollback().await.map_err(|e| {
ModelError::Transaction(format!("Failed to rollback transaction: {}", e))
})?;
debug!("Transaction rolled back successfully");
Ok(())
} else {
Err(ModelError::Transaction(
"Transaction has already been consumed".to_string(),
))
}
}
pub fn is_committed(&self) -> bool {
self.committed
}
pub fn config(&self) -> &TransactionConfig {
&self.config
}
}
/// Warns when a `Transaction` is dropped while still holding its backend transaction.
///
/// `commit` and `rollback` both take `inner`, so a handle still present here was never
/// explicitly finished. The handle is released by dropping it. NOTE(review): the "automatic
/// rollback" mentioned in the warning relies on the backend's own drop behavior — confirm.
impl Drop for Transaction {
    fn drop(&mut self) {
        match self.inner.take() {
            Some(pending) if !self.committed => {
                warn!("Transaction dropped without explicit commit or rollback - this will cause an automatic rollback");
                drop(pending);
            }
            // Nothing left to release, or already committed: any taken handle is dropped
            // here silently.
            _ => {}
        }
    }
}
/// Runs `f` inside a new transaction on `pool`.
/// Commits when `f` succeeds and rolls back when it fails.
///
/// When `config.auto_retry` is set, a whole attempt (begin, `f`, commit) is repeated up to
/// `config.max_retries` extra times while the failure satisfies [`is_serialization_failure`].
/// Commit failures are eligible for retry too: under serializable isolation, conflicts can be
/// reported by `COMMIT` itself (e.g. in PostgreSQL).
///
/// # Errors
///
/// Returns the error from `Transaction::begin`, from `f`, or from `commit` of the last attempt.
/// A failed rollback is logged and otherwise ignored, so it never masks the original error.
pub async fn with_transaction<F, Fut, R>(
    pool: &ManagedPool,
    config: TransactionConfig,
    f: F,
) -> Result<R, ModelError>
where
    F: Fn(&mut Transaction) -> Fut,
    Fut: std::future::Future<Output = Result<R, ModelError>>,
{
    // `saturating_add` keeps `max_retries == u32::MAX` from overflowing (a panic in debug builds).
    let max_attempts = if config.auto_retry {
        config.max_retries.saturating_add(1)
    } else {
        1
    };
    let mut attempts: u32 = 0;
    loop {
        attempts += 1;
        debug!(
            "Starting transaction attempt {} of {}",
            attempts, max_attempts
        );
        let mut tx = Transaction::begin(pool, config.clone()).await?;
        let outcome = match f(&mut tx).await {
            // A failed COMMIT goes through the retry check below instead of returning early.
            Ok(result) => tx.commit().await.map(|()| result),
            Err(e) => {
                // Best effort: the caller's error is what gets reported or retried.
                if let Err(rollback_err) = tx.rollback().await {
                    warn!("Failed to roll back transaction: {}", rollback_err);
                }
                Err(e)
            }
        };
        match outcome {
            Ok(result) => return Ok(result),
            Err(e)
                if config.auto_retry
                    && attempts < max_attempts
                    && is_serialization_failure(&e) =>
            {
                warn!(
                    "Serialization failure on attempt {}, retrying: {}",
                    attempts, e
                );
            }
            Err(e) => return Err(e),
        }
    }
}
/// Convenience wrapper around [`with_transaction`] that uses [`TransactionConfig::default`].
pub async fn with_transaction_default<F, Fut, R>(pool: &ManagedPool, f: F) -> Result<R, ModelError>
where
    F: Fn(&mut Transaction) -> Fut,
    Fut: std::future::Future<Output = Result<R, ModelError>>,
{
    let default_config = TransactionConfig::default();
    with_transaction(pool, default_config, f).await
}
/// Reports whether `error` looks like a retryable concurrency conflict.
///
/// Only `ModelError::Database` and `ModelError::Transaction` messages are inspected. They
/// match when they contain any of these markers:
/// - SQLSTATE `40001` (serialization_failure);
/// - SQLSTATE `40P01` (PostgreSQL deadlock_detected);
/// - the text "could not serialize access".
///
/// Every other variant returns `false`.
pub fn is_serialization_failure(error: &ModelError) -> bool {
    const RETRYABLE_MARKERS: [&str; 3] = ["40001", "40P01", "could not serialize access"];

    let message = match error {
        ModelError::Database(msg) | ModelError::Transaction(msg) => msg,
        _ => return false,
    };
    RETRYABLE_MARKERS
        .iter()
        .any(|marker| message.contains(*marker))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Each isolation level renders the SQL keywords used in `SET TRANSACTION ISOLATION LEVEL`.
    #[test]
    fn test_isolation_level_sql() {
        assert_eq!(IsolationLevel::ReadUncommitted.as_sql(), "READ UNCOMMITTED");
        assert_eq!(IsolationLevel::ReadCommitted.as_sql(), "READ COMMITTED");
        assert_eq!(IsolationLevel::RepeatableRead.as_sql(), "REPEATABLE READ");
        assert_eq!(IsolationLevel::Serializable.as_sql(), "SERIALIZABLE");
    }

    /// The default config: no isolation override, read-write, no retry, budget of 3.
    #[test]
    fn test_transaction_config_default() {
        let config = TransactionConfig::default();
        assert!(config.isolation_level.is_none());
        assert!(!config.read_only);
        assert!(!config.auto_retry);
        assert_eq!(config.max_retries, 3);
    }

    /// Every marker recognised by `is_serialization_failure` is covered, plus negative cases.
    #[test]
    fn test_serialization_failure_detection() {
        let err1 = ModelError::Database(
            "ERROR: could not serialize access due to concurrent update".to_string(),
        );
        assert!(is_serialization_failure(&err1));
        let err2 = ModelError::Transaction("ERROR: 40001".to_string());
        assert!(is_serialization_failure(&err2));
        let err3 = ModelError::Validation("Invalid input".to_string());
        assert!(!is_serialization_failure(&err3));

        // The deadlock SQLSTATE is also treated as retryable.
        let deadlock = ModelError::Database("ERROR: 40P01 deadlock detected".to_string());
        assert!(is_serialization_failure(&deadlock));

        // A matching variant whose message has no marker is not retryable.
        let unrelated = ModelError::Database(
            "ERROR: 23505 duplicate key value violates unique constraint".to_string(),
        );
        assert!(!is_serialization_failure(&unrelated));
    }
}