use std::{ops::Deref, sync::atomic::Ordering};
use crate::{Connection, Result, Statement};
/// How a new transaction acquires its locks, selecting which `BEGIN`
/// statement is issued when the transaction is opened.
///
/// See the SQLite `BEGIN TRANSACTION` documentation for the locking
/// semantics of each mode.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum TransactionBehavior {
    /// Opens the transaction with `BEGIN DEFERRED`.
    Deferred,
    /// Opens the transaction with `BEGIN IMMEDIATE`.
    Immediate,
    /// Opens the transaction with `BEGIN EXCLUSIVE`.
    Exclusive,
}
/// What happens to a still-open transaction when its handle is finished.
///
/// `Transaction::finish` executes the chosen action directly. A transaction
/// handle dropped while still in progress records its behavior on the
/// connection instead, to be dealt with later.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum DropBehavior {
    /// Roll the transaction back (the behavior a new transaction starts with).
    Rollback,
    /// Commit the transaction; if the commit fails, roll back instead.
    Commit,
    /// Do nothing and leave the transaction open.
    Ignore,
    /// Panic.
    Panic,
}

/// Encodes a [`DropBehavior`] as a small integer code in `0..=3`.
///
/// The mapping must stay the exact inverse of `From<u8> for DropBehavior`.
/// NOTE(review): presumably this exists so the behavior can live in an atomic
/// integer on the connection — confirm against the `dangling_tx` type.
impl From<DropBehavior> for u8 {
    fn from(behavior: DropBehavior) -> Self {
        match behavior {
            DropBehavior::Panic => 3,
            DropBehavior::Ignore => 2,
            DropBehavior::Commit => 1,
            DropBehavior::Rollback => 0,
        }
    }
}

/// Decodes an integer code produced by `From<DropBehavior> for u8`.
///
/// # Panics
///
/// Panics if `value` is not a valid code (anything above `3`).
impl From<u8> for DropBehavior {
    fn from(value: u8) -> Self {
        // Indexed by the integer code; order must match the encoder above.
        const BY_CODE: [DropBehavior; 4] = [
            DropBehavior::Rollback,
            DropBehavior::Commit,
            DropBehavior::Ignore,
            DropBehavior::Panic,
        ];
        BY_CODE
            .get(usize::from(value))
            .copied()
            .unwrap_or_else(|| panic!("Invalid drop behavior: {value}"))
    }
}
/// An open transaction running on a borrowed [`Connection`].
///
/// End it explicitly with `commit`, `rollback` or `finish`. If it is dropped
/// while still in progress, its [`DropBehavior`] is stored on the connection
/// (`dangling_tx`) rather than executed, because `Drop` cannot `.await`.
#[derive(Debug)]
pub struct Transaction<'conn> {
    // Connection the transaction runs on; also what `Deref` exposes.
    conn: &'conn Connection,
    // Action taken by `finish`, and the value recorded on drop while in
    // progress. New transactions start with `DropBehavior::Rollback`.
    drop_behavior: DropBehavior,
    // Stays true until a COMMIT or ROLLBACK issued through this handle succeeds.
    in_progress: bool,
}
impl Transaction<'_> {
#[inline]
pub async fn new(
conn: &mut Connection,
behavior: TransactionBehavior,
) -> Result<Transaction<'_>> {
Self::new_unchecked(conn, behavior).await
}
#[inline]
pub async fn new_unchecked(
conn: &Connection,
behavior: TransactionBehavior,
) -> Result<Transaction<'_>> {
let query = match behavior {
TransactionBehavior::Deferred => "BEGIN DEFERRED",
TransactionBehavior::Immediate => "BEGIN IMMEDIATE",
TransactionBehavior::Exclusive => "BEGIN EXCLUSIVE",
};
conn.execute(query, ()).await.map(move |_| Transaction {
conn,
drop_behavior: DropBehavior::Rollback,
in_progress: true,
})
}
pub async fn prepare(&self, sql: &str) -> Result<Statement> {
self.conn.prepare(sql).await
}
#[inline]
#[must_use]
pub fn drop_behavior(&self) -> DropBehavior {
self.drop_behavior
}
#[inline]
pub fn set_drop_behavior(&mut self, drop_behavior: DropBehavior) {
self.drop_behavior = drop_behavior;
}
#[inline]
pub async fn commit(mut self) -> Result<()> {
self._commit().await
}
#[inline]
async fn _commit(&mut self) -> Result<()> {
self.conn.execute("COMMIT", ()).await?;
self.in_progress = false;
Ok(())
}
#[inline]
pub async fn rollback(mut self) -> Result<()> {
self._rollback().await
}
#[inline]
async fn _rollback(&mut self) -> Result<()> {
self.conn.execute("ROLLBACK", ()).await?;
self.in_progress = false;
Ok(())
}
#[inline]
pub async fn finish(mut self) -> Result<()> {
self._finish().await
}
#[inline]
async fn _finish(&mut self) -> Result<()> {
if self.conn.is_autocommit()? {
return Ok(());
}
match self.drop_behavior() {
DropBehavior::Commit => {
if (self._commit().await).is_err() {
self._rollback().await
} else {
Ok(())
}
}
DropBehavior::Rollback => self._rollback().await,
DropBehavior::Ignore => Ok(()),
DropBehavior::Panic => panic!("Transaction dropped unexpectedly."),
}
}
}
/// Lets connection methods (`execute`, `query`, ...) be called directly on a
/// transaction handle.
impl Deref for Transaction<'_> {
    type Target = Connection;

    /// Returns the connection this transaction runs on.
    #[inline]
    fn deref(&self) -> &Self::Target {
        self.conn
    }
}
/// Records on the connection what, if anything, must still happen to this
/// transaction.
///
/// `drop` cannot `.await`, so no SQL is executed here. A handle that is still
/// in progress stores its drop behavior in the connection's `dangling_tx`
/// slot. A handle that was already committed or rolled back stores
/// [`DropBehavior::Ignore`].
impl Drop for Transaction<'_> {
    #[inline]
    fn drop(&mut self) {
        let pending = if !self.in_progress {
            DropBehavior::Ignore
        } else {
            self.drop_behavior()
        };
        self.conn.dangling_tx.store(pending, Ordering::SeqCst);
    }
}
impl Connection {
    /// Begins a transaction using the connection's default
    /// [`TransactionBehavior`] (configured with `set_transaction_behavior`).
    ///
    /// # Errors
    ///
    /// See [`Connection::transaction_with_behavior`].
    #[inline]
    pub async fn transaction(&mut self) -> Result<Transaction<'_>> {
        let behavior = self.transaction_behavior;
        self.transaction_with_behavior(behavior).await
    }

    /// Begins a transaction with an explicit [`TransactionBehavior`].
    ///
    /// First calls `maybe_handle_dangling_tx`, which deals with any action
    /// left behind by a previously dropped, unfinished transaction. Then opens
    /// the new transaction.
    ///
    /// # Errors
    ///
    /// Returns errors from handling the dangling transaction or from `BEGIN`.
    #[inline]
    pub async fn transaction_with_behavior(
        &mut self,
        behavior: TransactionBehavior,
    ) -> Result<Transaction<'_>> {
        self.maybe_handle_dangling_tx().await?;
        Transaction::new(self, behavior).await
    }

    /// Begins a transaction through a shared reference, using the default
    /// [`TransactionBehavior`].
    ///
    /// Unlike [`Connection::transaction_with_behavior`], this does not call
    /// `maybe_handle_dangling_tx` first, and it does not prevent nesting at
    /// compile time.
    ///
    /// # Errors
    ///
    /// Returns any error from executing `BEGIN`.
    pub async fn unchecked_transaction(&self) -> Result<Transaction<'_>> {
        let behavior = self.transaction_behavior;
        Transaction::new_unchecked(self, behavior).await
    }

    /// Sets the [`TransactionBehavior`] used by `transaction` and
    /// `unchecked_transaction`.
    pub fn set_transaction_behavior(&mut self, behavior: TransactionBehavior) {
        self.transaction_behavior = behavior;
    }
}
#[cfg(test)]
mod test {
    use crate::{Builder, Connection, Error, Result};
    use super::DropBehavior;

    /// Opens an in-memory database and creates the `foo (x INTEGER)` table
    /// that every test below writes to.
    async fn checked_memory_handle() -> Result<Connection> {
        let db = Builder::new_local(":memory:").build().await?;
        let conn = db.connect()?;
        conn.execute("CREATE TABLE foo (x INTEGER)", ()).await?;
        Ok(conn)
    }

    /// An unfinished transaction dropped at the end of its scope must not
    /// leak its insert into the next `conn.transaction()`: the sum sees only
    /// the second insert.
    #[tokio::test]
    async fn test_drop_rollback_on_new_transaction() {
        let mut conn = checked_memory_handle().await.unwrap();
        {
            // Dropped without commit/rollback/finish; default behavior is Rollback.
            let tx = conn.transaction().await.unwrap();
            tx.execute("INSERT INTO foo VALUES(?)", &[1]).await.unwrap();
        }
        let tx = conn.transaction().await.unwrap();
        tx.execute("INSERT INTO foo VALUES(?)", &[2]).await.unwrap();
        let result = tx
            .prepare("SELECT SUM(x) FROM foo")
            .await
            .unwrap()
            .query_row(())
            .await
            .unwrap();
        assert_eq!(2, result.get::<i32>(0).unwrap());
        tx.finish().await.unwrap();
    }

    /// After an unfinished transaction is dropped, a plain `conn.query` is
    /// expected to see none of its rows.
    #[tokio::test]
    async fn test_drop_rollback_on_query() {
        let mut conn = checked_memory_handle().await.unwrap();
        {
            let tx = conn.transaction().await.unwrap();
            tx.execute("INSERT INTO foo VALUES(?)", &[1]).await.unwrap();
        }
        let mut rows = conn.query("SELECT count(*) FROM foo", ()).await.unwrap();
        let result = rows.next().await.unwrap().unwrap();
        assert_eq!(0, result.get::<i32>(0).unwrap());
    }

    /// After an unfinished transaction is dropped, a plain `conn.execute`
    /// insert is kept, while the dropped transaction's insert is not.
    #[tokio::test]
    async fn test_drop_rollback_on_execute() {
        let mut conn = checked_memory_handle().await.unwrap();
        {
            let tx = conn.transaction().await.unwrap();
            tx.execute("INSERT INTO foo VALUES(?)", &[1]).await.unwrap();
        }
        conn.execute("INSERT INTO foo VALUES(?)", &[2])
            .await
            .unwrap();
        let mut rows = conn.query("SELECT count(*) FROM foo", ()).await.unwrap();
        let result = rows.next().await.unwrap().unwrap();
        assert_eq!(1, result.get::<i32>(0).unwrap());
    }

    /// Drop behavior is honoured: the default (Rollback) discards the first
    /// insert, and `DropBehavior::Commit` keeps the second one.
    #[tokio::test]
    async fn test_drop() -> Result<()> {
        let _ = tracing_subscriber::fmt::try_init();
        let mut conn = checked_memory_handle().await?;
        {
            let tx = conn.transaction().await?;
            tx.execute("INSERT INTO foo VALUES(?)", &[1]).await?;
        }
        {
            let mut tx = conn.transaction().await?;
            tx.execute("INSERT INTO foo VALUES(?)", &[2]).await?;
            tx.set_drop_behavior(DropBehavior::Commit);
        }
        {
            let tx = conn.transaction().await?;
            let result = tx
                .prepare("SELECT SUM(x) FROM foo")
                .await?
                .query_row(())
                .await?;
            assert_eq!(2, result.get::<i32>(0)?);
        }
        Ok(())
    }

    /// Asserts that `e` is an `Error::Error` whose message mentions
    /// "transaction", i.e. the failure reported for a nested `BEGIN`.
    fn assert_nested_tx_error(e: Error) {
        if let Error::Error(e) = &e {
            assert!(e.contains("transaction"));
        } else {
            panic!("Unexpected error type: {e:?}");
        }
    }

    /// Starting an unchecked transaction inside another one fails, and the
    /// outer transaction is still usable afterwards. The first outer
    /// transaction is finished (rolled back by default); the second commits
    /// both inserts.
    #[tokio::test]
    async fn test_unchecked_nesting() -> Result<()> {
        let conn = checked_memory_handle().await?;
        {
            let tx = conn.unchecked_transaction().await?;
            // `tx` derefs to the connection, so this is a nested BEGIN.
            let e = tx.unchecked_transaction().await.unwrap_err();
            assert_nested_tx_error(e);
            tx.finish().await?;
        }
        {
            let tx = conn.unchecked_transaction().await?;
            tx.execute("INSERT INTO foo VALUES(?)", &[1]).await?;
            let e = tx.unchecked_transaction().await.unwrap_err();
            assert_nested_tx_error(e);
            tx.execute("INSERT INTO foo VALUES(?)", &[1]).await?;
            tx.commit().await?;
        }
        let result = conn
            .prepare("SELECT SUM(x) FROM foo")
            .await?
            .query_row(())
            .await?;
        assert_eq!(2, result.get::<i32>(0)?);
        Ok(())
    }

    /// Explicit `rollback` discards its insert (1), while explicit `commit`
    /// keeps theirs (2 + 4), giving a sum of 6.
    #[tokio::test]
    async fn test_explicit_rollback_commit() -> Result<()> {
        let mut conn = checked_memory_handle().await?;
        {
            let tx = conn.transaction().await?;
            tx.execute("INSERT INTO foo VALUES(?)", &[1]).await?;
            tx.rollback().await?;
            let tx = conn.transaction().await?;
            tx.execute("INSERT INTO foo VALUES(?)", &[2]).await?;
            tx.commit().await?;
        }
        {
            let tx = conn.transaction().await?;
            tx.execute("INSERT INTO foo VALUES(?)", &[4]).await?;
            tx.commit().await?;
        }
        {
            let result = conn
                .prepare("SELECT SUM(x) FROM foo")
                .await?
                .query_row(())
                .await?;
            assert_eq!(6, result.get::<i32>(0)?);
        }
        Ok(())
    }
}