use std::future::Future;
use std::sync::atomic::{AtomicU32, Ordering};
use crate::database::{Database, Pinned};
use crate::dialect::Dialect;
use crate::driver::ExecuteResult;
use crate::executor::Executor;
use crate::row::Row;
use crate::value::Value;
use crate::BoxFuture;
/// How a transaction is started: either a start mode or an ANSI SQL isolation level.
///
/// `Deferred`, `Immediate` and `Exclusive` have no standard SQL isolation-level keyword.
/// Presumably they select SQLite-style `BEGIN` modes in the dialect; confirm against the
/// dialect's `begin_with_sql`. The other four variants are the ANSI isolation levels.
///
/// The default is `Deferred`. The type is `Copy`, comparable and hashable, so it can be
/// used as a key in maps and sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum IsolationLevel {
    /// Default start mode; has no standard SQL isolation keyword.
    #[default]
    Deferred,
    /// Start mode with no standard SQL isolation keyword.
    Immediate,
    /// Start mode with no standard SQL isolation keyword.
    Exclusive,
    /// ANSI `READ UNCOMMITTED`.
    ReadUncommitted,
    /// ANSI `READ COMMITTED`.
    ReadCommitted,
    /// ANSI `REPEATABLE READ`.
    RepeatableRead,
    /// ANSI `SERIALIZABLE`.
    Serializable,
}
impl IsolationLevel {
pub fn standard_sql(self) -> Option<&'static str> {
match self {
IsolationLevel::ReadUncommitted => Some("READ UNCOMMITTED"),
IsolationLevel::ReadCommitted => Some("READ COMMITTED"),
IsolationLevel::RepeatableRead => Some("REPEATABLE READ"),
IsolationLevel::Serializable => Some("SERIALIZABLE"),
IsolationLevel::Deferred | IsolationLevel::Immediate | IsolationLevel::Exclusive => None,
}
}
}
/// An open database transaction bound to a single pinned connection.
///
/// Finish it with `commit` or `rollback`. If the value is dropped while `committed` is
/// still `false`, its `Drop` impl calls the connection's `rollback_now`. An ignored
/// `Transaction` is therefore rolled back immediately, which is why the type is `#[must_use]`.
#[must_use = "a Transaction that is dropped without calling `commit` is rolled back"]
pub struct Transaction {
    /// The connection that every statement of this transaction runs on.
    inner: Pinned,
    /// Set by `commit`/`rollback`; while it is `false`, `Drop` issues a rollback.
    committed: bool,
    /// Source of the unique savepoint names (`tork_sp_<n>`) used by `savepoint`.
    savepoint_counter: AtomicU32,
}
impl Transaction {
    /// Wraps a connection on which a transaction has already been started.
    ///
    /// The constructors in this module (`Database::begin`, `TransactionBuilder::begin`)
    /// execute the dialect's BEGIN statement on `inner` before calling this.
    pub(crate) fn new(inner: Pinned) -> Self {
        Self { inner, committed: false, savepoint_counter: AtomicU32::new(0) }
    }

    /// Commits the transaction.
    ///
    /// If the COMMIT statement fails, a best-effort ROLLBACK is issued before the commit
    /// error is returned.
    ///
    /// The transaction is marked finished only after the statements have completed.
    /// If this future is dropped part-way (for example by a timeout), `Drop` still sees
    /// an unfinished transaction and rolls the connection back. It does not leave the
    /// connection in an open transaction.
    ///
    /// # Errors
    /// Returns the error from executing COMMIT. Any error from the follow-up ROLLBACK
    /// is discarded.
    pub async fn commit(&mut self) -> crate::Result<()> {
        let commit_sql = self.inner.dialect().commit_sql().to_string();
        let result = self.inner.execute(commit_sql, vec![]).await;
        if result.is_err() {
            let rollback_sql = self.inner.dialect().rollback_sql().to_string();
            // Best effort: the COMMIT error is the one reported to the caller.
            let _ = self.inner.execute(rollback_sql, vec![]).await;
        }
        // Only now is the outcome settled; see the cancellation note above.
        self.committed = true;
        result.map(|_| ())
    }

    /// Rolls the transaction back.
    ///
    /// As with `commit`, the transaction is marked finished only after ROLLBACK has
    /// completed, so a cancelled call still leaves the rollback to `Drop`.
    ///
    /// # Errors
    /// Returns the error from executing ROLLBACK.
    pub async fn rollback(&mut self) -> crate::Result<()> {
        let rollback_sql = self.inner.dialect().rollback_sql().to_string();
        let result = self.inner.execute(rollback_sql, vec![]).await;
        self.committed = true;
        result.map(|_| ())
    }

    /// Runs `f` inside a uniquely named savepoint (`tork_sp_<n>`) of this transaction.
    ///
    /// On `Ok`, the savepoint is released and the value is returned. On `Err`, the work
    /// done since the savepoint is rolled back (best effort) and the savepoint is then
    /// released. `ROLLBACK TO` alone keeps the savepoint established (e.g. SQLite,
    /// PostgreSQL), so without the release, repeated failing calls would pile up
    /// savepoints. The error from `f` is returned unchanged.
    ///
    /// # Errors
    /// Returns an error if creating the savepoint fails, if releasing it after a
    /// successful `f` fails, or the error returned by `f`.
    pub async fn savepoint<F, R>(&self, f: F) -> crate::Result<R>
    where
        F: for<'a> FnOnce(&'a Transaction) -> BoxFuture<'a, crate::Result<R>>,
        R: Send + 'static,
    {
        // Relaxed is enough: the counter only has to hand out distinct numbers.
        let n = self.savepoint_counter.fetch_add(1, Ordering::Relaxed);
        let sp_name = format!("tork_sp_{n}");
        let savepoint_sql = self.inner.dialect().savepoint_sql(&sp_name);
        let release_sql = self.inner.dialect().release_sql(&sp_name);
        let rollback_to_sql = self.inner.dialect().rollback_to_sql(&sp_name);
        self.inner.execute(savepoint_sql, vec![]).await?;
        match f(self).await {
            Ok(value) => {
                self.inner.execute(release_sql, vec![]).await?;
                Ok(value)
            }
            Err(error) => {
                // Best effort cleanup; the caller cares about `error`, not cleanup failures.
                if self.inner.execute(rollback_to_sql, vec![]).await.is_ok() {
                    let _ = self.inner.execute(release_sql, vec![]).await;
                }
                Err(error)
            }
        }
    }
}
impl Executor for Transaction {
    /// The SQL dialect of the pinned connection backing this transaction.
    fn dialect(&self) -> &dyn Dialect {
        let conn = &self.inner;
        conn.dialect()
    }

    /// Runs a query on the transaction's connection and collects every row.
    fn fetch_all(
        &self,
        sql: String,
        params: Vec<Value>,
    ) -> impl Future<Output = crate::Result<Vec<Row>>> + Send {
        let conn = &self.inner;
        conn.fetch_all(sql, params)
    }

    /// Runs a statement on the transaction's connection.
    fn execute(
        &self,
        sql: String,
        params: Vec<Value>,
    ) -> impl Future<Output = crate::Result<ExecuteResult>> + Send {
        let conn = &self.inner;
        conn.execute(sql, params)
    }
}
impl Drop for Transaction {
    /// Asks the connection to roll back if the transaction was never finished.
    ///
    /// "Finished" means `committed` was set by `commit` or `rollback`.
    fn drop(&mut self) {
        if self.committed {
            return;
        }
        self.inner.rollback_now();
    }
}
impl Database {
    /// Runs `f` inside a new transaction, committing on success and rolling back on error.
    ///
    /// The transaction is started with `Database::begin`. If `f` returns `Err`, a
    /// best-effort rollback is issued and `f`'s error is returned unchanged.
    ///
    /// # Errors
    /// Returns any error from starting the transaction, from `f`, or from COMMIT.
    pub async fn transaction<F, R>(&self, f: F) -> crate::Result<R>
    where
        F: for<'a> FnOnce(&'a Transaction) -> BoxFuture<'a, crate::Result<R>>,
        R: Send + 'static,
    {
        let mut tx = self.begin().await?;
        match f(&tx).await {
            Ok(value) => {
                tx.commit().await?;
                Ok(value)
            }
            Err(error) => {
                let _ = tx.rollback().await;
                Err(error)
            }
        }
    }

    /// Like `transaction`, but reruns the whole transaction when it fails with an error
    /// whose `is_retryable()` returns `true`.
    ///
    /// - **Attempt count.** At most `max_attempts` attempts are made; `0` is treated as `1`.
    /// - **Fresh transaction each time.** Every attempt uses a new transaction, so `f`
    ///   must be safe to call more than once.
    /// - **Uniform retry policy.** Failures to start the transaction, failures of `f`,
    ///   and failures of COMMIT are all retried the same way.
    /// - **No backoff.** There is no delay between attempts.
    ///
    /// # Errors
    /// Returns the first non-retryable error, or the last error once attempts run out.
    pub async fn transaction_retry<F, R>(&self, max_attempts: u32, f: F) -> crate::Result<R>
    where
        F: for<'a> Fn(&'a Transaction) -> BoxFuture<'a, crate::Result<R>>,
        R: Send + 'static,
    {
        let attempts = max_attempts.max(1);
        let mut attempt = 0;
        loop {
            attempt += 1;
            // An error from BEGIN goes through the same retry check below. It no longer
            // aborts the loop on the spot.
            let outcome = match self.begin().await {
                Ok(mut tx) => match f(&tx).await {
                    Ok(value) => tx.commit().await.map(|()| value),
                    Err(error) => {
                        let _ = tx.rollback().await;
                        Err(error)
                    }
                },
                Err(error) => Err(error),
            };
            match outcome {
                Ok(value) => return Ok(value),
                Err(error) if attempt < attempts && error.is_retryable() => continue,
                Err(error) => return Err(error),
            }
        }
    }

    /// Checks out a pinned connection and starts a transaction with the dialect's
    /// `begin_sql` statement.
    ///
    /// # Errors
    /// Returns any error from acquiring the connection or executing BEGIN.
    pub async fn begin(&self) -> crate::Result<Transaction> {
        let pinned = self.pinned().await?;
        let begin_sql = pinned.dialect().begin_sql().to_string();
        pinned.execute(begin_sql, vec![]).await?;
        Ok(Transaction::new(pinned))
    }

    /// Returns a builder for a transaction with an explicit isolation level / start mode.
    ///
    /// The level is initially `IsolationLevel::Deferred`.
    pub fn transaction_with(&self) -> TransactionBuilder<'_> {
        TransactionBuilder { db: self, level: IsolationLevel::Deferred }
    }
}
/// Configures the isolation level / start mode of a transaction before starting it.
///
/// Obtained from `Database::transaction_with`. The setters only record the level;
/// nothing is sent to the database until `begin` or `run` is called. Discarding the
/// builder (or a setter's return value) is therefore always a mistake, hence `#[must_use]`.
#[must_use = "a TransactionBuilder does nothing until `begin` or `run` is called"]
pub struct TransactionBuilder<'db> {
    /// Database the transaction will be started on.
    db: &'db Database,
    /// Level handed to the dialect's `isolation_setup_sql` and `begin_with_sql`.
    level: IsolationLevel,
}
impl<'db> TransactionBuilder<'db> {
pub fn deferred(mut self) -> Self {
self.level = IsolationLevel::Deferred;
self
}
pub fn immediate(mut self) -> Self {
self.level = IsolationLevel::Immediate;
self
}
pub fn exclusive(mut self) -> Self {
self.level = IsolationLevel::Exclusive;
self
}
pub fn read_uncommitted(mut self) -> Self {
self.level = IsolationLevel::ReadUncommitted;
self
}
pub fn read_committed(mut self) -> Self {
self.level = IsolationLevel::ReadCommitted;
self
}
pub fn repeatable_read(mut self) -> Self {
self.level = IsolationLevel::RepeatableRead;
self
}
pub fn serializable(mut self) -> Self {
self.level = IsolationLevel::Serializable;
self
}
pub async fn begin(self) -> crate::Result<Transaction> {
let pinned = self.db.pinned().await?;
if let Some(setup) = pinned.dialect().isolation_setup_sql(self.level) {
pinned.execute(setup, vec![]).await?;
}
let sql = pinned.dialect().begin_with_sql(self.level);
pinned.execute(sql, vec![]).await?;
Ok(Transaction::new(pinned))
}
pub async fn run<F, R>(self, f: F) -> crate::Result<R>
where
F: for<'a> FnOnce(&'a Transaction) -> BoxFuture<'a, crate::Result<R>>,
R: Send + 'static,
{
let mut tx = self.begin().await?;
match f(&tx).await {
Ok(value) => {
tx.commit().await?;
Ok(value)
}
Err(error) => {
let _ = tx.rollback().await;
Err(error)
}
}
}
}