use super::{types::*, Connection};
/// Locking mode requested when a top-level transaction is opened through
/// `Connection::transaction`.
///
/// Each variant selects the statement used to open the transaction:
/// `Deferred` issues `BEGIN`, `Immediate` issues `BEGIN IMMEDIATE`, and
/// `Exclusive` issues `BEGIN EXCLUSIVE`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransactionType {
    /// Opened with plain `BEGIN`.
    Deferred,
    /// Opened with `BEGIN IMMEDIATE`.
    Immediate,
    /// Opened with `BEGIN EXCLUSIVE`.
    Exclusive,
}
/// Lifecycle of a `Transaction` handle. It decides which SQL statement
/// commit/rollback issue, and whether dropping the handle has any work to do.
#[derive(PartialEq, Eq, Debug)]
enum TransactionState {
    /// A top-level transaction opened with one of the `BEGIN` statements.
    ActiveTransaction,
    /// A nested savepoint opened with `SAVEPOINT a`.
    ActiveSavepoint,
    /// Already committed or rolled back, or never successfully started.
    Inactive,
}
/// Handle for an open SQLite transaction or savepoint on a borrowed
/// `Connection`.
///
/// Statements are executed through it, because it derefs to the connection.
/// If it is dropped while still active, its pending work is rolled back.
#[derive(Debug)]
pub struct Transaction<'db> {
    /// Connection that all statements of this transaction run on.
    db: &'db Connection,
    /// Whether this handle still owns an open transaction or savepoint.
    state: TransactionState,
}
impl Connection {
    /// Begins a new top-level transaction on this connection.
    ///
    /// `tt` chooses between `BEGIN`, `BEGIN IMMEDIATE` and `BEGIN EXCLUSIVE`.
    /// If the returned handle is dropped without an explicit commit or
    /// rollback, the transaction is rolled back.
    ///
    /// # Errors
    /// Propagates the error from executing the `BEGIN` statement.
    pub fn transaction(&self, tt: TransactionType) -> Result<Transaction<'_>> {
        // The handle starts out inactive, so dropping it after a failed
        // `BEGIN` does not attempt a rollback.
        let mut pending = Transaction {
            db: self,
            state: TransactionState::Inactive,
        };
        pending.start(tt).map(|()| pending)
    }
}
impl<'db> Transaction<'db> {
    /// Opens the top-level transaction with the `BEGIN` flavour selected by
    /// `tt`, then marks this handle as active.
    ///
    /// If the statement fails, the state stays `Inactive`, so dropping the
    /// handle does not attempt a rollback.
    fn start(&mut self, tt: TransactionType) -> Result<()> {
        let sql = match tt {
            TransactionType::Deferred => "BEGIN",
            TransactionType::Immediate => "BEGIN IMMEDIATE",
            TransactionType::Exclusive => "BEGIN EXCLUSIVE",
        };
        self.execute(sql, ())?;
        self.state = TransactionState::ActiveTransaction;
        Ok(())
    }

    /// Commits a top-level transaction (`COMMIT`) or releases a savepoint
    /// (`RELEASE SAVEPOINT a`), and hands back the underlying connection.
    ///
    /// # Errors
    /// Returns the error from the SQL statement. The handle is consumed and
    /// marked inactive either way, so `Drop` will not issue a rollback.
    pub fn commit(mut self) -> Result<&'db Connection> {
        self.commit_mut().map(|_| self.db)
    }

    /// Shared implementation of `commit`. Always leaves `self` inactive.
    ///
    /// # Panics
    /// Panics if called on an inactive handle. `commit` and `rollback`
    /// consume `self`, and `Drop` checks the state first, so this should be
    /// unreachable.
    fn commit_mut(&mut self) -> Result<()> {
        let ret = match self.state {
            TransactionState::ActiveTransaction => self.execute("COMMIT", ()),
            TransactionState::ActiveSavepoint => self.execute("RELEASE SAVEPOINT a", ()),
            TransactionState::Inactive => panic!("lifetime error"),
        };
        // Mark inactive even if the statement failed, so that `Drop` does not
        // follow up with a second statement on a handle the caller gave up.
        // NOTE(review): SQLite documents that a COMMIT failing with
        // SQLITE_BUSY leaves the transaction open. Nothing here rolls it back
        // afterwards; confirm this is acceptable.
        self.state = TransactionState::Inactive;
        ret.map(|_| ())
    }

    /// Rolls back a top-level transaction (`ROLLBACK`), or undoes and removes
    /// a savepoint (`ROLLBACK TO a`, then `RELEASE a`). Hands back the
    /// underlying connection.
    ///
    /// # Errors
    /// Returns the error of the first failing statement. The handle is
    /// consumed and marked inactive either way.
    pub fn rollback(mut self) -> Result<&'db Connection> {
        self.rollback_mut().map(|_| self.db)
    }

    /// Shared implementation of `rollback` and of `Drop`. Always leaves
    /// `self` inactive.
    ///
    /// # Panics
    /// Panics if called on an inactive handle (this should be unreachable).
    fn rollback_mut(&mut self) -> Result<()> {
        let ret = match self.state {
            TransactionState::ActiveTransaction => self.execute("ROLLBACK", ()),
            // `ROLLBACK TO` reverts the savepoint's changes but keeps the
            // savepoint itself on SQLite's savepoint stack, so it must also be
            // released. Every savepoint here is named `a`, and SQLite resolves
            // a name to the most recent match. A leftover entry would
            // therefore make an enclosing savepoint's later `ROLLBACK TO a` or
            // `RELEASE SAVEPOINT a` act on this stale entry instead of its own.
            // For example, the outer rollback would silently keep the outer
            // savepoint's earlier changes.
            TransactionState::ActiveSavepoint => self
                .execute("ROLLBACK TO a", ())
                .and_then(|_| self.execute("RELEASE a", ())),
            TransactionState::Inactive => panic!("lifetime error"),
        };
        self.state = TransactionState::Inactive;
        ret.map(|_| ())
    }

    /// Opens a nested savepoint (`SAVEPOINT a`) inside this transaction.
    ///
    /// The savepoint mutably borrows `self`, so the parent cannot be used
    /// until the savepoint is committed, rolled back, or dropped. Dropping
    /// it rolls it back.
    ///
    /// # Errors
    /// Propagates the error from executing `SAVEPOINT a`.
    pub fn savepoint(&mut self) -> Result<Transaction<'_>> {
        self.execute("SAVEPOINT a", ())?;
        let txn = Self {
            db: self.db,
            state: TransactionState::ActiveSavepoint,
        };
        Ok(txn)
    }
}
/// Exposes the underlying connection. Statements can then be run directly on
/// a transaction or savepoint handle, e.g. `txn.execute(...)`.
impl std::ops::Deref for Transaction<'_> {
    type Target = Connection;

    fn deref(&self) -> &Self::Target {
        let conn: &Connection = self.db;
        conn
    }
}
impl Drop for Transaction<'_> {
    /// Rolls back a transaction or savepoint that was neither committed nor
    /// rolled back explicitly.
    ///
    /// A failed rollback panics. If the thread is already unwinding, the
    /// failure is only printed to stderr instead, because panicking again
    /// during unwinding would abort the process.
    fn drop(&mut self) {
        if self.state == TransactionState::Inactive {
            return;
        }
        match self.rollback_mut() {
            Ok(()) => {}
            Err(e) if std::thread::panicking() => {
                eprintln!("Error while closing SQLite transaction: {e:?}")
            }
            Err(e) => panic!("Error while closing SQLite transaction: {e:?}"),
        }
    }
}
#[cfg(all(test, feature = "static"))]
mod test {
    use crate::test_helpers::prelude::*;

    // Number of rows currently visible in `tbl` through `$h.db`. A query
    // error is propagated with `?`, so use this only inside tests that
    // return `Result`.
    macro_rules! row_count {
        ($h:expr) => {
            $h.db
                .query_row("SELECT COUNT(*) FROM tbl", (), |r| Ok(r[0].get_i64()))?
        };
    }

    // Fresh test database holding one empty single-column table, `tbl`.
    fn setup() -> Result<TestHelpers> {
        let h = TestHelpers::new();
        h.db.execute("CREATE TABLE tbl(col)", ())?;
        Ok(h)
    }

    #[test]
    fn commit() -> Result<()> {
        let h = setup()?;
        let txn = h.db.transaction(TransactionType::Deferred)?;
        txn.execute("INSERT INTO tbl VALUES (1)", ())?;
        txn.commit()?;
        assert_eq!(row_count!(h), 1);
        Ok(())
    }

    #[test]
    fn rollback() -> Result<()> {
        let h = setup()?;
        let txn = h.db.transaction(TransactionType::Deferred)?;
        txn.execute("INSERT INTO tbl VALUES (1)", ())?;
        txn.rollback()?;
        assert_eq!(row_count!(h), 0);
        Ok(())
    }

    #[test]
    fn drop() -> Result<()> {
        let h = setup()?;
        let txn = h.db.transaction(TransactionType::Deferred)?;
        txn.execute("INSERT INTO tbl VALUES (1)", ())?;
        // Fully qualified: inside this module, plain `drop` names this test fn.
        std::mem::drop(txn);
        assert_eq!(row_count!(h), 0);
        Ok(())
    }

    #[test]
    fn savepoint_commit() -> Result<()> {
        let h = setup()?;
        let mut txn = h.db.transaction(TransactionType::Deferred)?;
        txn.execute("INSERT INTO tbl VALUES (1)", ())?;
        let sp = txn.savepoint()?;
        sp.execute("INSERT INTO tbl VALUES (2)", ())?;
        sp.commit()?;
        txn.commit()?;
        assert_eq!(row_count!(h), 2);
        Ok(())
    }

    #[test]
    fn savepoint_rollback() -> Result<()> {
        let h = setup()?;
        let mut txn = h.db.transaction(TransactionType::Deferred)?;
        txn.execute("INSERT INTO tbl VALUES (1)", ())?;
        let sp = txn.savepoint()?;
        sp.execute("INSERT INTO tbl VALUES (2)", ())?;
        sp.rollback()?;
        txn.commit()?;
        assert_eq!(row_count!(h), 1);
        Ok(())
    }

    #[test]
    fn savepoint_drop() -> Result<()> {
        let h = setup()?;
        let mut txn = h.db.transaction(TransactionType::Deferred)?;
        txn.execute("INSERT INTO tbl VALUES (1)", ())?;
        let sp = txn.savepoint()?;
        sp.execute("INSERT INTO tbl VALUES (2)", ())?;
        // Dropping the savepoint rolls it back and releases the borrow of `txn`.
        std::mem::drop(sp);
        txn.commit()?;
        assert_eq!(row_count!(h), 1);
        Ok(())
    }

    #[test]
    fn commit_fail() -> Result<()> {
        let h = setup()?;
        let txn = h.db.transaction(TransactionType::Deferred)?;
        // End the transaction behind the handle's back, so COMMIT must fail.
        txn.execute("ROLLBACK", ())?;
        match txn.commit() {
            Err(e) => assert_eq!(e.to_string(), "cannot commit - no transaction is active"),
            Ok(_) => unreachable!(),
        }
        Ok(())
    }
}