use crate::{Connection, Result};
use std::ops::Deref;
/// How a new transaction acquires locks, i.e. which `BEGIN` variant is issued.
///
/// Derives `Debug`/`PartialEq`/`Eq` for parity with [`DropBehavior`], so the
/// value can be logged and compared.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum TransactionBehavior {
    /// Start the transaction with `BEGIN DEFERRED`.
    Deferred,
    /// Start the transaction with `BEGIN IMMEDIATE`.
    Immediate,
    /// Start the transaction with `BEGIN EXCLUSIVE`.
    Exclusive,
}
/// What a [`Transaction`] or [`Savepoint`] does when it is finished or dropped
/// while still open.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DropBehavior {
    /// Roll back the pending changes.
    Rollback,
    /// Commit (or release) the pending changes; if that fails, fall back to
    /// rolling them back.
    Commit,
    /// Leave the transaction/savepoint untouched.
    Ignore,
    /// Panic, flagging that the guard was dropped without being ended explicitly.
    Panic,
}
/// An open SQLite transaction on a borrowed [`Connection`].
///
/// Derefs to the `Connection`, so statements can be executed through it directly.
/// If it is dropped while still open, its [`DropBehavior`] decides what happens
/// (new transactions start with `DropBehavior::Rollback`).
#[derive(Debug)]
pub struct Transaction<'conn> {
// Connection on which `BEGIN` was issued.
conn: &'conn Connection,
// Action taken by `finish`/`Drop` if the transaction is still active.
drop_behavior: DropBehavior,
}
/// A named SQLite savepoint (`SAVEPOINT <name>`) on a borrowed [`Connection`].
///
/// Savepoints can be nested and can also be opened outside of a transaction.
/// Derefs to the `Connection`. If it is dropped while still open, its
/// [`DropBehavior`] decides what happens (new savepoints start with
/// `DropBehavior::Rollback`).
#[derive(Debug)]
pub struct Savepoint<'conn> {
// Connection on which `SAVEPOINT` was issued.
conn: &'conn Connection,
// Identifier interpolated into `SAVEPOINT` / `RELEASE` / `ROLLBACK TO`.
name: String,
// Action taken by `finish`/`Drop` if the savepoint has not been released.
drop_behavior: DropBehavior,
// Set once `RELEASE` succeeds; afterwards `finish_` is a no-op.
committed: bool,
}
impl Transaction<'_> {
    /// Begin a new transaction on `conn` with the given `behavior`.
    ///
    /// Requiring `&mut Connection` means the borrow checker rejects starting a
    /// second transaction on the same connection while this one is alive. Use
    /// [`Transaction::new_unchecked`] when only a shared borrow is available.
    ///
    /// # Errors
    ///
    /// Returns an error if the `BEGIN` statement fails.
    #[inline]
    pub fn new(conn: &mut Connection, behavior: TransactionBehavior) -> Result<Transaction<'_>> {
        Self::new_unchecked(conn, behavior)
    }

    /// Begin a new transaction from a shared borrow of `conn`.
    ///
    /// Nesting is not prevented at compile time; SQLite itself rejects a `BEGIN`
    /// issued while a transaction is already active. The returned transaction
    /// rolls back on drop unless its drop behavior is changed.
    ///
    /// # Errors
    ///
    /// Returns an error if the `BEGIN` statement fails.
    #[inline]
    pub fn new_unchecked(
        conn: &Connection,
        behavior: TransactionBehavior,
    ) -> Result<Transaction<'_>> {
        let query = match behavior {
            TransactionBehavior::Deferred => "BEGIN DEFERRED",
            TransactionBehavior::Immediate => "BEGIN IMMEDIATE",
            TransactionBehavior::Exclusive => "BEGIN EXCLUSIVE",
        };
        conn.execute_batch(query).map(move |()| Transaction {
            conn,
            drop_behavior: DropBehavior::Rollback,
        })
    }

    /// Open a nested savepoint (named `_rusqlite_sp`) inside this transaction.
    ///
    /// Borrows the transaction mutably, so it cannot be used until the savepoint
    /// is gone.
    ///
    /// # Errors
    ///
    /// Returns an error if the `SAVEPOINT` statement fails.
    #[inline]
    pub fn savepoint(&mut self) -> Result<Savepoint<'_>> {
        Savepoint::new_(self.conn)
    }

    /// Open a nested savepoint with a caller-chosen `name`.
    ///
    /// The name is interpolated into SQL unquoted, so it must be a valid
    /// SQLite identifier.
    ///
    /// # Errors
    ///
    /// Returns an error if the `SAVEPOINT` statement fails.
    #[inline]
    pub fn savepoint_with_name<T: Into<String>>(&mut self, name: T) -> Result<Savepoint<'_>> {
        Savepoint::with_name_(self.conn, name)
    }

    /// The action taken if the transaction is finished or dropped while active.
    #[inline]
    #[must_use]
    pub fn drop_behavior(&self) -> DropBehavior {
        self.drop_behavior
    }

    /// Change the action taken if the transaction is finished or dropped while active.
    #[inline]
    pub fn set_drop_behavior(&mut self, drop_behavior: DropBehavior) {
        self.drop_behavior = drop_behavior;
    }

    /// Consume the transaction and issue `COMMIT`.
    ///
    /// # Errors
    ///
    /// Returns the `COMMIT` error. The transaction is then still dropped, which
    /// applies the current drop behavior.
    #[inline]
    pub fn commit(mut self) -> Result<()> {
        self.commit_()
    }

    #[inline]
    fn commit_(&mut self) -> Result<()> {
        self.conn.execute_batch("COMMIT")?;
        Ok(())
    }

    /// Consume the transaction and issue `ROLLBACK`.
    ///
    /// # Errors
    ///
    /// Returns an error if the `ROLLBACK` statement fails.
    #[inline]
    pub fn rollback(mut self) -> Result<()> {
        self.rollback_()
    }

    #[inline]
    fn rollback_(&mut self) -> Result<()> {
        self.conn.execute_batch("ROLLBACK")?;
        Ok(())
    }

    /// Consume the transaction, ending it according to its [`DropBehavior`].
    ///
    /// Does nothing if the connection is already back in autocommit mode.
    ///
    /// # Errors
    ///
    /// Returns an error if the requested action fails. With
    /// `DropBehavior::Commit`, a failed `COMMIT` is followed by a rollback and
    /// the `COMMIT` error is still returned.
    ///
    /// # Panics
    ///
    /// Panics if the drop behavior is `DropBehavior::Panic` and the transaction
    /// is still active.
    #[inline]
    pub fn finish(mut self) -> Result<()> {
        self.finish_()
    }

    #[inline]
    fn finish_(&mut self) -> Result<()> {
        // Autocommit mode means no transaction is open any more (it was committed
        // or rolled back already), so there is nothing left to finish.
        if self.conn.is_autocommit() {
            return Ok(());
        }
        match self.drop_behavior() {
            DropBehavior::Commit => match self.commit_() {
                Ok(()) => Ok(()),
                Err(commit_err) => {
                    // COMMIT failed. If the transaction is still open, roll it back so
                    // the connection is not left mid-transaction. Either way, report the
                    // COMMIT failure: a successful ROLLBACK must not be mistaken for the
                    // commit the caller asked for.
                    if !self.conn.is_autocommit() {
                        self.rollback_()?;
                    }
                    Err(commit_err)
                }
            },
            DropBehavior::Rollback => self.rollback_(),
            DropBehavior::Ignore => Ok(()),
            DropBehavior::Panic => panic!("Transaction dropped unexpectedly."),
        }
    }
}
/// Lets a [`Transaction`] be used anywhere a `&Connection` is expected, so
/// statements run inside the transaction.
impl Deref for Transaction<'_> {
    type Target = Connection;

    #[inline]
    fn deref(&self) -> &Self::Target {
        self.conn
    }
}
/// Ends a still-active transaction according to its [`DropBehavior`].
impl Drop for Transaction<'_> {
    #[inline]
    fn drop(&mut self) {
        // `drop` cannot return an error, so the outcome is deliberately discarded.
        let _ = self.finish_();
    }
}
impl Savepoint<'_> {
    /// Issue `SAVEPOINT <name>` on `conn` and wrap it in a [`Savepoint`] guard.
    ///
    /// The name is interpolated into the SQL text verbatim (not quoted or
    /// escaped), so it must be a valid SQLite identifier chosen by the program,
    /// never untrusted input.
    #[inline]
    fn with_name_<T: Into<String>>(conn: &Connection, name: T) -> Result<Savepoint<'_>> {
        let name = name.into();
        conn.execute_batch(&format!("SAVEPOINT {name}"))
            .map(|()| Savepoint {
                conn,
                name,
                drop_behavior: DropBehavior::Rollback,
                committed: false,
            })
    }

    /// Open a savepoint with the default name `_rusqlite_sp`.
    #[inline]
    fn new_(conn: &Connection) -> Result<Savepoint<'_>> {
        Savepoint::with_name_(conn, "_rusqlite_sp")
    }

    /// Open a savepoint on `conn` with the default name.
    ///
    /// # Errors
    ///
    /// Returns an error if the `SAVEPOINT` statement fails.
    #[inline]
    pub fn new(conn: &mut Connection) -> Result<Savepoint<'_>> {
        Savepoint::new_(conn)
    }

    /// Open a savepoint on `conn` with a caller-chosen `name` (interpolated
    /// into SQL unquoted).
    ///
    /// # Errors
    ///
    /// Returns an error if the `SAVEPOINT` statement fails.
    #[inline]
    pub fn with_name<T: Into<String>>(conn: &mut Connection, name: T) -> Result<Savepoint<'_>> {
        Savepoint::with_name_(conn, name)
    }

    /// Open a nested savepoint with the default name.
    ///
    /// # Errors
    ///
    /// Returns an error if the `SAVEPOINT` statement fails.
    #[inline]
    pub fn savepoint(&mut self) -> Result<Savepoint<'_>> {
        Savepoint::new_(self.conn)
    }

    /// Open a nested savepoint with a caller-chosen `name` (interpolated into
    /// SQL unquoted).
    ///
    /// # Errors
    ///
    /// Returns an error if the `SAVEPOINT` statement fails.
    #[inline]
    pub fn savepoint_with_name<T: Into<String>>(&mut self, name: T) -> Result<Savepoint<'_>> {
        Savepoint::with_name_(self.conn, name)
    }

    /// The action taken if the savepoint is finished or dropped unreleased.
    #[inline]
    #[must_use]
    pub fn drop_behavior(&self) -> DropBehavior {
        self.drop_behavior
    }

    /// Change the action taken if the savepoint is finished or dropped unreleased.
    #[inline]
    pub fn set_drop_behavior(&mut self, drop_behavior: DropBehavior) {
        self.drop_behavior = drop_behavior;
    }

    /// Consume the savepoint and issue `RELEASE <name>`.
    ///
    /// # Errors
    ///
    /// Returns the `RELEASE` error. The savepoint is then still dropped, which
    /// applies the current drop behavior.
    #[inline]
    pub fn commit(mut self) -> Result<()> {
        self.commit_()
    }

    #[inline]
    fn commit_(&mut self) -> Result<()> {
        self.conn.execute_batch(&format!("RELEASE {}", self.name))?;
        self.committed = true;
        Ok(())
    }

    /// Issue `ROLLBACK TO <name>`, discarding changes made since the savepoint.
    ///
    /// Does not consume or release the savepoint; it can keep being used.
    ///
    /// # Errors
    ///
    /// Returns an error if the `ROLLBACK TO` statement fails.
    #[inline]
    pub fn rollback(&mut self) -> Result<()> {
        self.conn
            .execute_batch(&format!("ROLLBACK TO {}", self.name))
    }

    /// Consume the savepoint, ending it according to its [`DropBehavior`].
    ///
    /// Does nothing if the savepoint was already released.
    ///
    /// # Errors
    ///
    /// Returns an error if the requested action fails. With
    /// `DropBehavior::Commit`, a failed `RELEASE` is followed by
    /// `ROLLBACK TO` + `RELEASE` and the original `RELEASE` error is still
    /// returned.
    ///
    /// # Panics
    ///
    /// Panics if the drop behavior is `DropBehavior::Panic` and the savepoint
    /// has not been released.
    #[inline]
    pub fn finish(mut self) -> Result<()> {
        self.finish_()
    }

    #[inline]
    fn finish_(&mut self) -> Result<()> {
        if self.committed {
            return Ok(());
        }
        match self.drop_behavior() {
            DropBehavior::Commit => match self.commit_() {
                Ok(()) => Ok(()),
                Err(release_err) => {
                    // RELEASE failed (e.g. a deferred foreign-key violation), so the
                    // savepoint is still open. Undo its changes and release it so the
                    // connection is not left inside it, then report the RELEASE
                    // failure: the caller asked for a commit and the changes were
                    // discarded instead.
                    self.rollback().and_then(|()| self.commit_())?;
                    Err(release_err)
                }
            },
            // Rolling back to a savepoint leaves it on the stack, so release it too.
            DropBehavior::Rollback => self.rollback().and_then(|()| self.commit_()),
            DropBehavior::Ignore => Ok(()),
            DropBehavior::Panic => panic!("Savepoint dropped unexpectedly."),
        }
    }
}
/// Lets a [`Savepoint`] be used anywhere a `&Connection` is expected, so
/// statements run inside the savepoint.
impl Deref for Savepoint<'_> {
    type Target = Connection;

    #[inline]
    fn deref(&self) -> &Self::Target {
        self.conn
    }
}
/// Ends an unreleased savepoint according to its [`DropBehavior`].
impl Drop for Savepoint<'_> {
    #[inline]
    fn drop(&mut self) {
        // `drop` cannot return an error, so the outcome is deliberately discarded.
        let _ = self.finish_();
    }
}
/// Transaction state of a database, as reported by
/// `Connection::transaction_state`.
#[cfg(feature = "modern_sqlite")]
#[cfg_attr(docsrs, doc(cfg(feature = "modern_sqlite")))]
#[non_exhaustive]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum TransactionState {
    /// No transaction is open (e.g. right after `BEGIN`, before any read).
    None,
    /// A read transaction is open.
    Read,
    /// A write transaction is open.
    Write,
}
impl Connection {
    /// Begin a new transaction using this connection's default
    /// [`TransactionBehavior`] (configured via
    /// [`set_transaction_behavior`](Connection::set_transaction_behavior)).
    ///
    /// The transaction rolls back when dropped unless it is committed or its
    /// drop behavior is changed.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying `BEGIN` fails.
    #[inline]
    pub fn transaction(&mut self) -> Result<Transaction<'_>> {
        let behavior = self.transaction_behavior;
        self.transaction_with_behavior(behavior)
    }

    /// Begin a new transaction with an explicit `behavior`, ignoring the
    /// connection's default.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying `BEGIN` fails.
    #[inline]
    pub fn transaction_with_behavior(
        &mut self,
        behavior: TransactionBehavior,
    ) -> Result<Transaction<'_>> {
        Transaction::new(self, behavior)
    }

    /// Begin a new transaction from a shared borrow, using the default
    /// [`TransactionBehavior`].
    ///
    /// Only `&self` is needed, so nesting is not prevented at compile time;
    /// SQLite reports an error if a transaction is already active.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying `BEGIN` fails.
    pub fn unchecked_transaction(&self) -> Result<Transaction<'_>> {
        let behavior = self.transaction_behavior;
        Transaction::new_unchecked(self, behavior)
    }

    /// Open a savepoint with the default name; works with or without an
    /// enclosing transaction.
    ///
    /// # Errors
    ///
    /// Returns an error if the `SAVEPOINT` statement fails.
    #[inline]
    pub fn savepoint(&mut self) -> Result<Savepoint<'_>> {
        Savepoint::new(self)
    }

    /// Open a savepoint with a caller-chosen `name` (interpolated into SQL
    /// unquoted, so it must be a valid identifier).
    ///
    /// # Errors
    ///
    /// Returns an error if the `SAVEPOINT` statement fails.
    #[inline]
    pub fn savepoint_with_name<T: Into<String>>(&mut self, name: T) -> Result<Savepoint<'_>> {
        Savepoint::with_name(self, name)
    }

    /// Report the [`TransactionState`] of `db_name`.
    ///
    /// NOTE(review): `None` presumably means "all schemas" — confirm against
    /// `txn_state`.
    ///
    /// # Errors
    ///
    /// Returns an error if the state cannot be queried.
    #[cfg(feature = "modern_sqlite")]
    #[cfg_attr(docsrs, doc(cfg(feature = "modern_sqlite")))]
    pub fn transaction_state(
        &self,
        db_name: Option<crate::DatabaseName<'_>>,
    ) -> Result<TransactionState> {
        self.db.borrow().txn_state(db_name)
    }

    /// Set the [`TransactionBehavior`] used by [`Connection::transaction`] and
    /// [`Connection::unchecked_transaction`].
    pub fn set_transaction_behavior(&mut self, behavior: TransactionBehavior) {
        self.transaction_behavior = behavior;
    }
}
// Unit tests for `Transaction`, `Savepoint`, and `DropBehavior`, all run
// against in-memory databases.
#[cfg(test)]
mod test {
use super::DropBehavior;
use crate::{Connection, Error, Result};
// Open an in-memory database containing an empty `foo (x INTEGER)` table.
fn checked_memory_handle() -> Result<Connection> {
let db = Connection::open_in_memory()?;
db.execute_batch("CREATE TABLE foo (x INTEGER)")?;
Ok(db)
}
// The default drop behavior rolls back; `DropBehavior::Commit` keeps the insert.
#[test]
fn test_drop() -> Result<()> {
let mut db = checked_memory_handle()?;
{
// Dropped with the default `Rollback`: the insert of 1 is discarded.
let tx = db.transaction()?;
tx.execute_batch("INSERT INTO foo VALUES(1)")?;
}
{
// Dropped with `Commit`: the insert of 2 persists.
let mut tx = db.transaction()?;
tx.execute_batch("INSERT INTO foo VALUES(2)")?;
tx.set_drop_behavior(DropBehavior::Commit)
}
{
let tx = db.transaction()?;
assert_eq!(2i32, tx.one_column::<i32>("SELECT SUM(x) FROM foo")?);
}
Ok(())
}
// Assert that `e` is the error SQLite raises when a transaction is started while
// another is active: a generic SQLITE_ERROR whose message mentions "transaction".
fn assert_nested_tx_error(e: Error) {
if let Error::SqliteFailure(e, Some(m)) = &e {
assert_eq!(e.extended_code, crate::ffi::SQLITE_ERROR);
assert_eq!(e.code, crate::ErrorCode::Unknown);
assert!(m.contains("transaction"));
} else {
panic!("Unexpected error type: {e:?}");
}
}
// `unchecked_transaction` cannot prevent nesting at compile time; SQLite rejects
// it at runtime, and the failed nested BEGIN leaves the outer transaction usable.
#[test]
fn test_unchecked_nesting() -> Result<()> {
let db = checked_memory_handle()?;
{
let tx = db.unchecked_transaction()?;
let e = tx.unchecked_transaction().unwrap_err();
assert_nested_tx_error(e);
}
{
let tx = db.unchecked_transaction()?;
tx.execute_batch("INSERT INTO foo VALUES(1)")?;
let e = tx.unchecked_transaction().unwrap_err();
assert_nested_tx_error(e);
tx.execute_batch("INSERT INTO foo VALUES(1)")?;
tx.commit()?;
}
assert_eq!(2i32, db.one_column::<i32>("SELECT SUM(x) FROM foo")?);
Ok(())
}
// `Savepoint::rollback` discards changes but keeps the savepoint usable; explicit
// commits then persist 2 + 4 = 6.
#[test]
fn test_explicit_rollback_commit() -> Result<()> {
let mut db = checked_memory_handle()?;
{
let mut tx = db.transaction()?;
{
let mut sp = tx.savepoint()?;
sp.execute_batch("INSERT INTO foo VALUES(1)")?;
sp.rollback()?;
sp.execute_batch("INSERT INTO foo VALUES(2)")?;
sp.commit()?;
}
tx.commit()?;
}
{
let tx = db.transaction()?;
tx.execute_batch("INSERT INTO foo VALUES(4)")?;
tx.commit()?;
}
{
let tx = db.transaction()?;
assert_eq!(6i32, tx.one_column::<i32>("SELECT SUM(x) FROM foo")?);
}
Ok(())
}
// Nested savepoints: sp3 is released into sp2, but sp2 and sp1 are dropped with
// the default `Rollback`, so only the transaction's own insert (1) is committed.
#[test]
fn test_savepoint() -> Result<()> {
let mut db = checked_memory_handle()?;
{
let mut tx = db.transaction()?;
tx.execute_batch("INSERT INTO foo VALUES(1)")?;
assert_current_sum(1, &tx)?;
tx.set_drop_behavior(DropBehavior::Commit);
{
let mut sp1 = tx.savepoint()?;
sp1.execute_batch("INSERT INTO foo VALUES(2)")?;
assert_current_sum(3, &sp1)?;
{
let mut sp2 = sp1.savepoint()?;
sp2.execute_batch("INSERT INTO foo VALUES(4)")?;
assert_current_sum(7, &sp2)?;
{
let sp3 = sp2.savepoint()?;
sp3.execute_batch("INSERT INTO foo VALUES(8)")?;
assert_current_sum(15, &sp3)?;
sp3.commit()?;
}
assert_current_sum(15, &sp2)?;
}
assert_current_sum(3, &sp1)?;
}
assert_current_sum(1, &tx)?;
}
assert_current_sum(1, &db)?;
Ok(())
}
// `DropBehavior::Ignore` leaves sp2's insert in place; releasing sp1 keeps it.
#[test]
fn test_ignore_drop_behavior() -> Result<()> {
let mut db = checked_memory_handle()?;
let mut tx = db.transaction()?;
{
let mut sp1 = tx.savepoint()?;
insert(1, &sp1)?;
sp1.rollback()?;
insert(2, &sp1)?;
{
let mut sp2 = sp1.savepoint()?;
sp2.set_drop_behavior(DropBehavior::Ignore);
insert(4, &sp2)?;
}
assert_current_sum(6, &sp1)?;
sp1.commit()?;
}
assert_current_sum(6, &tx)?;
Ok(())
}
// Both `Commit` and `Rollback` drop behaviors release a top-level savepoint,
// returning the connection to autocommit mode.
#[test]
fn test_savepoint_drop_behavior_releases() -> Result<()> {
let mut db = checked_memory_handle()?;
{
let mut sp = db.savepoint()?;
sp.set_drop_behavior(DropBehavior::Commit);
}
assert!(db.is_autocommit());
{
let mut sp = db.savepoint()?;
sp.set_drop_behavior(DropBehavior::Rollback);
}
assert!(db.is_autocommit());
Ok(())
}
// A deferred foreign-key violation makes RELEASE fail; the `Commit` drop path
// must still roll back and release, so the connection ends in autocommit mode.
#[test]
fn test_savepoint_release_error() -> Result<()> {
let mut db = checked_memory_handle()?;
db.pragma_update(None, "foreign_keys", true)?;
db.execute_batch("CREATE TABLE r(n INTEGER PRIMARY KEY NOT NULL); CREATE TABLE f(n REFERENCES r(n) DEFERRABLE INITIALLY DEFERRED);")?;
{
let mut sp = db.savepoint()?;
sp.execute("INSERT INTO f VALUES (0)", [])?;
sp.set_drop_behavior(DropBehavior::Commit);
}
assert!(db.is_autocommit());
Ok(())
}
// Explicitly named savepoints, including two nested savepoints that share a name.
#[test]
fn test_savepoint_names() -> Result<()> {
let mut db = checked_memory_handle()?;
{
let mut sp1 = db.savepoint_with_name("my_sp")?;
insert(1, &sp1)?;
assert_current_sum(1, &sp1)?;
{
let mut sp2 = sp1.savepoint_with_name("my_sp")?;
sp2.set_drop_behavior(DropBehavior::Commit);
insert(2, &sp2)?;
assert_current_sum(3, &sp2)?;
sp2.rollback()?;
assert_current_sum(1, &sp2)?;
insert(4, &sp2)?;
}
assert_current_sum(5, &sp1)?;
sp1.rollback()?;
{
let mut sp2 = sp1.savepoint_with_name("my_sp")?;
sp2.set_drop_behavior(DropBehavior::Ignore);
insert(8, &sp2)?;
}
assert_current_sum(8, &sp1)?;
// NOTE(review): both savepoints are named "my_sp"; per SQLite RELEASE semantics
// this presumably releases only the most recent "my_sp" (the ignored inner one),
// leaving the outer savepoint open — confirm this is the intended coverage.
sp1.commit()?;
}
assert_current_sum(8, &db)?;
Ok(())
}
// A `Transaction` can be moved into an `Rc` and unwrapped back out.
#[test]
fn test_rc() -> Result<()> {
use std::rc::Rc;
let mut conn = Connection::open_in_memory()?;
let rc_txn = Rc::new(conn.transaction()?);
Rc::try_unwrap(rc_txn).unwrap();
Ok(())
}
// Insert `x` into `foo`, returning the number of affected rows.
fn insert(x: i32, conn: &Connection) -> Result<usize> {
conn.execute("INSERT INTO foo VALUES(?1)", [x])
}
// Assert that `SUM(x)` over `foo` equals `x` as seen through `conn`.
fn assert_current_sum(x: i32, conn: &Connection) -> Result<()> {
let i = conn.one_column::<i32>("SELECT SUM(x) FROM foo")?;
assert_eq!(x, i);
Ok(())
}
// `BEGIN` alone leaves the state at None; the first read moves it to Read and a
// write moves it to Write.
#[test]
#[cfg(feature = "modern_sqlite")]
fn txn_state() -> Result<()> {
use super::TransactionState;
use crate::DatabaseName;
let db = Connection::open_in_memory()?;
assert_eq!(
TransactionState::None,
db.transaction_state(Some(DatabaseName::Main))?
);
assert_eq!(TransactionState::None, db.transaction_state(None)?);
db.execute_batch("BEGIN")?;
assert_eq!(TransactionState::None, db.transaction_state(None)?);
let _: i32 = db.pragma_query_value(None, "user_version", |row| row.get(0))?;
assert_eq!(TransactionState::Read, db.transaction_state(None)?);
db.pragma_update(None, "user_version", 1)?;
assert_eq!(TransactionState::Write, db.transaction_state(None)?);
db.execute_batch("ROLLBACK")?;
Ok(())
}
// In autocommit mode, a partially stepped read statement keeps the state at Read
// (even across an INSERT) until the statement is exhausted.
#[test]
#[cfg(feature = "modern_sqlite")]
fn auto_commit() -> Result<()> {
use super::TransactionState;
let db = Connection::open_in_memory()?;
db.execute_batch("CREATE TABLE t(i UNIQUE);")?;
assert!(db.is_autocommit());
let mut stmt = db.prepare("SELECT name FROM sqlite_master")?;
assert_eq!(TransactionState::None, db.transaction_state(None)?);
{
let mut rows = stmt.query([])?;
assert!(rows.next()?.is_some()); assert_eq!(TransactionState::Read, db.transaction_state(None)?);
db.execute("INSERT INTO t VALUES (1)", [])?; assert_eq!(TransactionState::Read, db.transaction_state(None)?);
assert!(rows.next()?.is_some()); assert_eq!(TransactionState::Read, db.transaction_state(None)?);
assert!(rows.next()?.is_none()); assert_eq!(TransactionState::None, db.transaction_state(None)?);
}
Ok(())
}
}