#![allow(dead_code)]
use std::future::Future;
use std::time::Duration;
use tracing::debug;
use crate::error::QueryResult;
/// Transaction isolation level, rendered into the `BEGIN` statement.
///
/// The default is [`IsolationLevel::ReadCommitted`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub enum IsolationLevel {
    /// Rendered as `READ UNCOMMITTED`.
    ReadUncommitted,
    /// Rendered as `READ COMMITTED`; this is the default variant.
    #[default]
    ReadCommitted,
    /// Rendered as `REPEATABLE READ`.
    RepeatableRead,
    /// Rendered as `SERIALIZABLE`.
    Serializable,
}

impl IsolationLevel {
    /// Returns the SQL keyword phrase for this level (the text that follows
    /// `ISOLATION LEVEL` in a `BEGIN` statement).
    pub fn as_sql(&self) -> &'static str {
        match *self {
            IsolationLevel::ReadUncommitted => "READ UNCOMMITTED",
            IsolationLevel::ReadCommitted => "READ COMMITTED",
            IsolationLevel::RepeatableRead => "REPEATABLE READ",
            IsolationLevel::Serializable => "SERIALIZABLE",
        }
    }
}
/// Whether a transaction may write, rendered into the `BEGIN` statement.
///
/// The default is [`AccessMode::ReadWrite`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AccessMode {
    /// Rendered as `READ WRITE`; this is the default variant.
    #[default]
    ReadWrite,
    /// Rendered as `READ ONLY`.
    ReadOnly,
}

impl AccessMode {
    /// Returns the SQL keyword phrase for this access mode.
    pub fn as_sql(&self) -> &'static str {
        match *self {
            AccessMode::ReadWrite => "READ WRITE",
            AccessMode::ReadOnly => "READ ONLY",
        }
    }
}
/// Options used to open a transaction.
///
/// Built fluently, e.g.
/// `TransactionConfig::new().isolation(IsolationLevel::Serializable).read_only()`.
#[derive(Debug, Clone, Default)]
pub struct TransactionConfig {
    /// Isolation level emitted in the `BEGIN` statement.
    pub isolation: IsolationLevel,
    /// Access mode (`READ WRITE` / `READ ONLY`) emitted in the `BEGIN` statement.
    pub access_mode: AccessMode,
    /// Optional timeout.
    ///
    /// Not emitted by [`TransactionConfig::to_begin_sql`].
    /// NOTE(review): confirm where (if anywhere) this is enforced.
    pub timeout: Option<Duration>,
    /// Requests `DEFERRABLE`; only emitted for serializable, read-only
    /// transactions (see [`TransactionConfig::to_begin_sql`]).
    pub deferrable: bool,
}

impl TransactionConfig {
    /// Creates a config with default settings: read committed, read write,
    /// no timeout, not deferrable.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the isolation level.
    #[must_use]
    pub fn isolation(mut self, level: IsolationLevel) -> Self {
        self.isolation = level;
        self
    }

    /// Sets the access mode.
    #[must_use]
    pub fn access_mode(mut self, mode: AccessMode) -> Self {
        self.access_mode = mode;
        self
    }

    /// Sets the timeout (not part of the generated `BEGIN` SQL).
    #[must_use]
    pub fn timeout(mut self, timeout: Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }

    /// Shorthand for `access_mode(AccessMode::ReadOnly)`.
    #[must_use]
    pub fn read_only(self) -> Self {
        self.access_mode(AccessMode::ReadOnly)
    }

    /// Requests a `DEFERRABLE` transaction.
    #[must_use]
    pub fn deferrable(mut self) -> Self {
        self.deferrable = true;
        self
    }

    /// Renders the `BEGIN` statement for this configuration.
    ///
    /// The output has the form
    /// `BEGIN ISOLATION LEVEL <level> <access mode>[ DEFERRABLE]`.
    /// Modes are space-separated and `timeout` is not included. A debug event
    /// is logged with the isolation level and access mode.
    #[must_use]
    pub fn to_begin_sql(&self) -> String {
        let mut sql = format!(
            "BEGIN ISOLATION LEVEL {} {}",
            self.isolation.as_sql(),
            self.access_mode.as_sql()
        );
        // DEFERRABLE is only emitted for SERIALIZABLE + READ ONLY; this mirrors
        // PostgreSQL, where the property has no effect in any other combination.
        if self.deferrable
            && self.isolation == IsolationLevel::Serializable
            && self.access_mode == AccessMode::ReadOnly
        {
            sql.push_str(" DEFERRABLE");
        }
        debug!(isolation = %self.isolation.as_sql(), access_mode = %self.access_mode.as_sql(), "Transaction BEGIN");
        sql
    }
}
/// A transaction handle pairing an engine with its [`TransactionConfig`].
///
/// Holds a "committed" flag that callers set via
/// [`Transaction::mark_committed`], and a counter used to hand out
/// sequential savepoint names.
pub struct Transaction<E> {
    engine: E,
    config: TransactionConfig,
    committed: bool,
    savepoint_count: u32,
}

impl<E> Transaction<E> {
    /// Creates an uncommitted transaction whose savepoint counter starts at zero.
    pub fn new(engine: E, config: TransactionConfig) -> Self {
        let committed = false;
        let savepoint_count = 0;
        Self {
            engine,
            config,
            committed,
            savepoint_count,
        }
    }

    /// Borrows the configuration this transaction was created with.
    pub fn config(&self) -> &TransactionConfig {
        &self.config
    }

    /// Borrows the underlying engine.
    pub fn engine(&self) -> &E {
        &self.engine
    }

    /// Advances the counter and returns the next savepoint name
    /// (`sp_1`, `sp_2`, ...).
    pub fn savepoint_name(&mut self) -> String {
        let next = self.savepoint_count + 1;
        self.savepoint_count = next;
        format!("sp_{next}")
    }

    /// Sets the committed flag; this does not issue any SQL.
    pub fn mark_committed(&mut self) {
        self.committed = true;
    }

    /// Reports whether [`Transaction::mark_committed`] has been called.
    pub fn is_committed(&self) -> bool {
        self.committed
    }
}
pub struct TransactionBuilder<E, F, Fut, T>
where
F: FnOnce(Transaction<E>) -> Fut,
Fut: Future<Output = QueryResult<T>>,
{
engine: E,
callback: F,
config: TransactionConfig,
}
impl<E, F, Fut, T> TransactionBuilder<E, F, Fut, T>
where
F: FnOnce(Transaction<E>) -> Fut,
Fut: Future<Output = QueryResult<T>>,
{
pub fn new(engine: E, callback: F) -> Self {
Self {
engine,
callback,
config: TransactionConfig::default(),
}
}
pub fn isolation(mut self, level: IsolationLevel) -> Self {
self.config.isolation = level;
self
}
pub fn read_only(mut self) -> Self {
self.config.access_mode = AccessMode::ReadOnly;
self
}
pub fn timeout(mut self, timeout: Duration) -> Self {
self.config.timeout = Some(timeout);
self
}
pub fn deferrable(mut self) -> Self {
self.config.deferrable = true;
self
}
}
/// A manually driven transaction.
///
/// Its methods return the SQL text for each step (BEGIN, COMMIT, ROLLBACK,
/// savepoints). They do not run anything against the engine themselves.
pub struct InteractiveTransaction<E> {
    inner: Transaction<E>,
    started: bool,
}

impl<E> InteractiveTransaction<E> {
    /// Creates an unstarted transaction with the default configuration.
    pub fn new(engine: E) -> Self {
        Self::with_config(engine, TransactionConfig::default())
    }

    /// Creates an unstarted transaction with the given configuration.
    pub fn with_config(engine: E, config: TransactionConfig) -> Self {
        let inner = Transaction::new(engine, config);
        Self {
            inner,
            started: false,
        }
    }

    /// Borrows the underlying engine.
    pub fn engine(&self) -> &E {
        &self.inner.engine
    }

    /// Reports whether [`InteractiveTransaction::mark_started`] has been called.
    pub fn is_started(&self) -> bool {
        self.started
    }

    /// The `BEGIN` statement for this transaction's configuration.
    pub fn begin_sql(&self) -> String {
        let config = &self.inner.config;
        config.to_begin_sql()
    }

    /// The `COMMIT` statement.
    pub fn commit_sql(&self) -> &'static str {
        "COMMIT"
    }

    /// The `ROLLBACK` statement.
    pub fn rollback_sql(&self) -> &'static str {
        "ROLLBACK"
    }

    /// Builds a `SAVEPOINT` statement.
    ///
    /// With `None`, a name is generated from the inner counter (`sp_1`,
    /// `sp_2`, ...). A caller-supplied name does not advance that counter.
    /// The name is interpolated verbatim (no quoting or escaping), so callers
    /// must pass a valid, trusted identifier.
    pub fn savepoint_sql(&mut self, name: Option<&str>) -> String {
        let name = match name {
            Some(given) => given.to_owned(),
            None => self.inner.savepoint_name(),
        };
        format!("SAVEPOINT {name}")
    }

    /// Builds a `ROLLBACK TO SAVEPOINT` statement; `name` is interpolated verbatim.
    pub fn rollback_to_sql(&self, name: &str) -> String {
        format!("ROLLBACK TO SAVEPOINT {name}")
    }

    /// Builds a `RELEASE SAVEPOINT` statement; `name` is interpolated verbatim.
    pub fn release_savepoint_sql(&self, name: &str) -> String {
        format!("RELEASE SAVEPOINT {name}")
    }

    /// Sets the started flag; this does not issue any SQL.
    pub fn mark_started(&mut self) {
        self.started = true;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Engine stand-in; the SQL builders never touch it.
    #[derive(Clone)]
    struct MockEngine;

    #[test]
    fn test_isolation_level() {
        assert_eq!(IsolationLevel::ReadUncommitted.as_sql(), "READ UNCOMMITTED");
        assert_eq!(IsolationLevel::ReadCommitted.as_sql(), "READ COMMITTED");
        assert_eq!(IsolationLevel::RepeatableRead.as_sql(), "REPEATABLE READ");
        assert_eq!(IsolationLevel::Serializable.as_sql(), "SERIALIZABLE");
    }

    #[test]
    fn test_access_mode() {
        assert_eq!(AccessMode::ReadWrite.as_sql(), "READ WRITE");
        assert_eq!(AccessMode::ReadOnly.as_sql(), "READ ONLY");
    }

    #[test]
    fn test_transaction_config_default() {
        let config = TransactionConfig::new();
        assert_eq!(config.isolation, IsolationLevel::ReadCommitted);
        assert_eq!(config.access_mode, AccessMode::ReadWrite);
        assert!(config.timeout.is_none());
        assert!(!config.deferrable);
    }

    #[test]
    fn test_transaction_config_builder() {
        let config = TransactionConfig::new()
            .isolation(IsolationLevel::Serializable)
            .read_only()
            .deferrable()
            .timeout(Duration::from_secs(30));
        assert_eq!(config.isolation, IsolationLevel::Serializable);
        assert_eq!(config.access_mode, AccessMode::ReadOnly);
        assert!(config.deferrable);
        assert_eq!(config.timeout, Some(Duration::from_secs(30)));
    }

    #[test]
    fn test_begin_sql() {
        // Exact match: `contains` checks would not catch reordered or extra tokens.
        let sql = TransactionConfig::new().to_begin_sql();
        assert_eq!(sql, "BEGIN ISOLATION LEVEL READ COMMITTED READ WRITE");
    }

    #[test]
    fn test_begin_sql_serializable_deferrable() {
        let config = TransactionConfig::new()
            .isolation(IsolationLevel::Serializable)
            .read_only()
            .deferrable();
        assert_eq!(
            config.to_begin_sql(),
            "BEGIN ISOLATION LEVEL SERIALIZABLE READ ONLY DEFERRABLE"
        );
    }

    #[test]
    fn test_begin_sql_deferrable_requires_serializable_read_only() {
        let read_write = TransactionConfig::new()
            .isolation(IsolationLevel::Serializable)
            .deferrable();
        assert_eq!(
            read_write.to_begin_sql(),
            "BEGIN ISOLATION LEVEL SERIALIZABLE READ WRITE"
        );

        let not_serializable = TransactionConfig::new().read_only().deferrable();
        assert_eq!(
            not_serializable.to_begin_sql(),
            "BEGIN ISOLATION LEVEL READ COMMITTED READ ONLY"
        );
    }

    #[test]
    fn test_interactive_transaction() {
        let mut tx = InteractiveTransaction::new(MockEngine);
        assert!(!tx.is_started());
        assert_eq!(
            tx.begin_sql(),
            "BEGIN ISOLATION LEVEL READ COMMITTED READ WRITE"
        );
        assert_eq!(tx.commit_sql(), "COMMIT");
        assert_eq!(tx.rollback_sql(), "ROLLBACK");
        let sp = tx.savepoint_sql(Some("test_sp"));
        assert_eq!(sp, "SAVEPOINT test_sp");
        let rollback_to = tx.rollback_to_sql("test_sp");
        assert_eq!(rollback_to, "ROLLBACK TO SAVEPOINT test_sp");
        let release = tx.release_savepoint_sql("test_sp");
        assert_eq!(release, "RELEASE SAVEPOINT test_sp");
        tx.mark_started();
        assert!(tx.is_started());
    }

    #[test]
    fn test_interactive_transaction_with_config() {
        let config = TransactionConfig::new()
            .isolation(IsolationLevel::RepeatableRead)
            .read_only();
        let tx = InteractiveTransaction::with_config(MockEngine, config.clone());
        assert_eq!(tx.begin_sql(), config.to_begin_sql());
    }

    #[test]
    fn test_auto_savepoint_names() {
        let mut tx = InteractiveTransaction::new(MockEngine);
        assert_eq!(tx.savepoint_sql(None), "SAVEPOINT sp_1");
        // A caller-supplied name must not advance the generated sequence.
        assert_eq!(tx.savepoint_sql(Some("named")), "SAVEPOINT named");
        assert_eq!(tx.savepoint_sql(None), "SAVEPOINT sp_2");
    }

    #[test]
    fn test_transaction_flags_and_savepoints() {
        let mut tx = Transaction::new(MockEngine, TransactionConfig::default());
        assert!(!tx.is_committed());
        assert_eq!(tx.config().isolation, IsolationLevel::ReadCommitted);
        assert_eq!(tx.savepoint_name(), "sp_1");
        assert_eq!(tx.savepoint_name(), "sp_2");
        tx.mark_committed();
        assert!(tx.is_committed());
    }
}