use crate::connection::Connection;
use crate::error::SqliteError;
use crate::types::Param;
/// Locking mode requested when a transaction is opened with `BEGIN`.
///
/// Each variant is rendered into the `BEGIN <mode> TRANSACTION` statement via
/// [`IsolationLevel::as_sql`]. [`IsolationLevel::Deferred`] is the `Default`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum IsolationLevel {
    /// Rendered as `DEFERRED`; the default mode.
    #[default]
    Deferred,
    /// Rendered as `IMMEDIATE`.
    Immediate,
    /// Rendered as `EXCLUSIVE`.
    Exclusive,
}

impl IsolationLevel {
    /// Returns the SQL keyword for this mode, suitable for splicing into
    /// `BEGIN <keyword> TRANSACTION`.
    pub fn as_sql(&self) -> &'static str {
        let keyword = match *self {
            IsolationLevel::Deferred => "DEFERRED",
            IsolationLevel::Immediate => "IMMEDIATE",
            IsolationLevel::Exclusive => "EXCLUSIVE",
        };
        keyword
    }
}
/// Lifecycle of a [`Transaction`] guard.
///
/// A guard starts out `Active` and moves to `Committed` or `RolledBack` once the
/// corresponding statement has succeeded.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum TransactionState {
    /// `BEGIN` succeeded and no `COMMIT`/`ROLLBACK` has completed yet.
    Active,
    /// `COMMIT` completed successfully.
    Committed,
    /// `ROLLBACK` was issued (explicitly or on drop).
    RolledBack,
}
/// Guard for an open SQLite transaction on a borrowed [`Connection`].
///
/// Obtain one with [`Transaction::begin`], [`Transaction::begin_immediate`] or
/// [`Transaction::begin_with_isolation`], then finish it with
/// [`Transaction::commit`] or [`Transaction::rollback`]. A guard dropped while
/// still active issues a best-effort `ROLLBACK`; any error from it is discarded.
pub struct Transaction<'a> {
    /// Connection the `BEGIN` was issued on; every statement is routed through it.
    conn: &'a Connection,
    /// Lifecycle state; `Drop` only rolls back while this is `Active`.
    state: TransactionState,
}

impl<'a> Transaction<'a> {
    /// Opens a transaction in [`IsolationLevel::Deferred`] mode.
    ///
    /// # Errors
    /// Same as [`Transaction::begin_with_isolation`].
    pub fn begin(conn: &'a Connection) -> Result<Self, SqliteError> {
        Self::begin_with_isolation(conn, IsolationLevel::Deferred)
    }

    /// Opens a transaction by running `BEGIN <mode> TRANSACTION` on `conn`.
    ///
    /// # Errors
    /// Returns [`SqliteError::TransactionFailed`] with a `"BEGIN failed: "` prefix
    /// when the `BEGIN` statement fails.
    pub fn begin_with_isolation(
        conn: &'a Connection,
        isolation: IsolationLevel,
    ) -> Result<Self, SqliteError> {
        let begin_sql = format!("BEGIN {} TRANSACTION", isolation.as_sql());
        conn.execute_batch(&begin_sql)
            .map(|()| Transaction {
                conn,
                state: TransactionState::Active,
            })
            .map_err(|err| SqliteError::TransactionFailed(format!("BEGIN failed: {}", err)))
    }

    /// Opens a transaction in [`IsolationLevel::Immediate`] mode.
    ///
    /// # Errors
    /// Same as [`Transaction::begin_with_isolation`].
    pub fn begin_immediate(conn: &'a Connection) -> Result<Self, SqliteError> {
        Self::begin_with_isolation(conn, IsolationLevel::Immediate)
    }

    /// Current lifecycle state of this guard.
    pub fn state(&self) -> TransactionState {
        self.state
    }

    /// `true` until a commit or rollback has completed on this guard.
    pub fn is_active(&self) -> bool {
        matches!(self.state, TransactionState::Active)
    }

    /// Shared guard: errors with `TransactionFailed("Transaction is not active")`
    /// unless the transaction is still active.
    fn require_active(&self) -> Result<(), SqliteError> {
        if self.is_active() {
            Ok(())
        } else {
            Err(SqliteError::TransactionFailed(String::from(
                "Transaction is not active",
            )))
        }
    }

    /// Executes one parameterized statement inside the transaction and returns
    /// the value reported by [`Connection::execute`].
    ///
    /// # Errors
    /// Fails if the transaction is not active, or with whatever the connection reports.
    pub fn execute(&self, sql: &str, params: &[Param]) -> Result<usize, SqliteError> {
        self.require_active()?;
        self.conn.execute(sql, params)
    }

    /// Executes raw SQL (no parameters) inside the transaction.
    ///
    /// # Errors
    /// Fails if the transaction is not active, or with whatever the connection reports.
    pub fn execute_batch(&self, sql: &str) -> Result<(), SqliteError> {
        self.require_active()?;
        self.conn.execute_batch(sql)
    }

    /// Commits the transaction, consuming the guard.
    ///
    /// # Errors
    /// Returns `TransactionFailed` if the guard is not active, or with a
    /// `"COMMIT failed: "` prefix if `COMMIT` fails. In the latter case the guard
    /// is still `Active` when it is dropped at the end of this call, so `Drop`
    /// issues a `ROLLBACK`.
    pub fn commit(mut self) -> Result<(), SqliteError> {
        self.require_active()?;
        if let Err(err) = self.conn.execute_batch("COMMIT") {
            return Err(SqliteError::TransactionFailed(format!(
                "COMMIT failed: {}",
                err
            )));
        }
        self.state = TransactionState::Committed;
        Ok(())
    }

    /// Rolls the transaction back, consuming the guard. Does nothing and returns
    /// `Ok(())` if the guard is no longer active.
    ///
    /// # Errors
    /// Returns `TransactionFailed` with a `"ROLLBACK failed: "` prefix if `ROLLBACK`
    /// fails; `Drop` then retries the rollback once more as the guard goes away.
    pub fn rollback(mut self) -> Result<(), SqliteError> {
        if self.is_active() {
            self.conn.execute_batch("ROLLBACK").map_err(|err| {
                SqliteError::TransactionFailed(format!("ROLLBACK failed: {}", err))
            })?;
            self.state = TransactionState::RolledBack;
        }
        Ok(())
    }

    /// Creates a named savepoint nested in this transaction.
    ///
    /// # Errors
    /// Fails if the transaction is not active or the `SAVEPOINT` statement fails.
    pub fn savepoint(&self, name: &str) -> Result<Savepoint<'_, 'a>, SqliteError> {
        self.require_active()?;
        Savepoint::new(self, name)
    }
}

impl Drop for Transaction<'_> {
    /// Best-effort rollback of a transaction that was never finished explicitly.
    fn drop(&mut self) {
        if !self.is_active() {
            return;
        }
        // `drop` cannot report failures, so the rollback result is intentionally ignored.
        let _ = self.conn.execute_batch("ROLLBACK");
        self.state = TransactionState::RolledBack;
    }
}
/// A named SQLite savepoint nested inside an active [`Transaction`].
///
/// Created with [`Transaction::savepoint`]. Finish it with [`Savepoint::release`]
/// to keep its changes or [`Savepoint::rollback`] to discard them. If it is
/// dropped without either, `Drop` performs a best-effort rollback-and-release.
pub struct Savepoint<'t, 'c> {
    /// Enclosing transaction; borrowed so it cannot be committed while this lives.
    tx: &'t Transaction<'c>,
    /// Savepoint name exactly as supplied by the caller (unquoted).
    name: String,
    /// Set once the savepoint has been released or rolled back explicitly.
    released: bool,
}

/// Quotes `name` as an SQL identifier (double quotes, embedded `"` doubled) so
/// caller-supplied text cannot inject SQL or break the statement it is spliced into.
fn quote_identifier(name: &str) -> String {
    format!("\"{}\"", name.replace('"', "\"\""))
}

impl<'t, 'c> Savepoint<'t, 'c> {
    /// Issues `SAVEPOINT <name>` inside `tx`.
    fn new(tx: &'t Transaction<'c>, name: &str) -> Result<Self, SqliteError> {
        tx.execute_batch(&format!("SAVEPOINT {}", quote_identifier(name)))?;
        Ok(Self {
            tx,
            name: name.to_string(),
            released: false,
        })
    }

    /// Releases the savepoint, keeping its changes in the enclosing transaction.
    ///
    /// # Errors
    /// Propagates any error from the transaction; on failure `Drop` still rolls
    /// the savepoint back.
    pub fn release(mut self) -> Result<(), SqliteError> {
        self.tx
            .execute_batch(&format!("RELEASE SAVEPOINT {}", quote_identifier(&self.name)))?;
        self.released = true;
        Ok(())
    }

    /// Discards every change made since the savepoint was created, then removes
    /// the savepoint itself.
    ///
    /// `ROLLBACK TO` alone leaves the savepoint on SQLite's savepoint stack; since
    /// this guard is consumed here, nothing could release it later, so it is
    /// released right after the rollback succeeds.
    ///
    /// # Errors
    /// Propagates any error from the transaction; on failure `Drop` retries.
    pub fn rollback(mut self) -> Result<(), SqliteError> {
        let ident = quote_identifier(&self.name);
        self.tx
            .execute_batch(&format!("ROLLBACK TO SAVEPOINT {}", ident))?;
        self.tx.execute_batch(&format!("RELEASE SAVEPOINT {}", ident))?;
        self.released = true;
        Ok(())
    }
}

impl<'t, 'c> Drop for Savepoint<'t, 'c> {
    /// Best-effort rollback of a savepoint that was never finished explicitly.
    fn drop(&mut self) {
        if self.released {
            return;
        }
        let ident = quote_identifier(&self.name);
        // Errors cannot escape `drop`. Only release after a successful rollback:
        // releasing first would fold the savepoint's changes into the transaction.
        if self
            .tx
            .conn
            .execute_batch(&format!("ROLLBACK TO SAVEPOINT {}", ident))
            .is_ok()
        {
            let _ = self
                .tx
                .conn
                .execute_batch(&format!("RELEASE SAVEPOINT {}", ident));
        }
    }
}
/// Runs `f` inside a `DEFERRED` transaction, committing if it returns `Ok` and
/// rolling back if it returns `Err`.
///
/// # Errors
/// Returns the `BEGIN` error, the `COMMIT` error, or `f`'s own error. If `f`
/// fails and the subsequent `ROLLBACK` also fails, the rollback error is
/// returned in place of `f`'s error.
pub fn with_transaction<F, T>(conn: &Connection, f: F) -> Result<T, SqliteError>
where
    F: FnOnce(&Transaction<'_>) -> Result<T, SqliteError>,
{
    let tx = Transaction::begin(conn)?;
    complete_transaction(tx, f)
}

/// Like [`with_transaction`], but opens the transaction in `IMMEDIATE` mode.
///
/// # Errors
/// Same as [`with_transaction`].
pub fn with_immediate_transaction<F, T>(conn: &Connection, f: F) -> Result<T, SqliteError>
where
    F: FnOnce(&Transaction<'_>) -> Result<T, SqliteError>,
{
    let tx = Transaction::begin_immediate(conn)?;
    complete_transaction(tx, f)
}

/// Shared tail of the `with_*transaction` helpers: runs `f` against an already
/// open `tx`, then commits on success or rolls back on failure.
fn complete_transaction<F, T>(tx: Transaction<'_>, f: F) -> Result<T, SqliteError>
where
    F: FnOnce(&Transaction<'_>) -> Result<T, SqliteError>,
{
    let outcome = f(&tx);
    if outcome.is_ok() {
        tx.commit()?;
    } else {
        tx.rollback()?;
    }
    outcome
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens an in-memory database containing an empty `test` table.
    fn setup_test_db() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        conn.execute_batch("CREATE TABLE test (id INTEGER PRIMARY KEY, value INTEGER)")
            .unwrap();
        conn
    }

    #[test]
    fn test_commit() {
        let conn = setup_test_db();
        {
            let tx = Transaction::begin(&conn).unwrap();
            tx.execute("INSERT INTO test (value) VALUES (?)", &[1i32.into()])
                .unwrap();
            tx.commit().unwrap();
        }
        let rows = conn.query("SELECT * FROM test", &[]).unwrap();
        assert_eq!(rows.len(), 1);
    }

    #[test]
    fn test_rollback() {
        let conn = setup_test_db();
        {
            let tx = Transaction::begin(&conn).unwrap();
            tx.execute("INSERT INTO test (value) VALUES (?)", &[1i32.into()])
                .unwrap();
            tx.rollback().unwrap();
        }
        let rows = conn.query("SELECT * FROM test", &[]).unwrap();
        assert_eq!(rows.len(), 0);
    }

    #[test]
    fn test_auto_rollback_on_drop() {
        let conn = setup_test_db();
        {
            let tx = Transaction::begin(&conn).unwrap();
            tx.execute("INSERT INTO test (value) VALUES (?)", &[1i32.into()])
                .unwrap();
        }
        let rows = conn.query("SELECT * FROM test", &[]).unwrap();
        // Dropping an unfinished transaction must discard its writes.
        assert_eq!(rows.len(), 0);
    }

    #[test]
    fn test_new_transaction_is_active() {
        let conn = setup_test_db();
        let tx = Transaction::begin(&conn).unwrap();
        assert!(tx.is_active());
        assert_eq!(tx.state(), TransactionState::Active);
        tx.rollback().unwrap();
    }

    #[test]
    fn test_nested_begin_fails() {
        let conn = setup_test_db();
        let outer = Transaction::begin(&conn).unwrap();
        // SQLite rejects BEGIN while a transaction is already open; the error is
        // surfaced as TransactionFailed by begin_with_isolation.
        assert!(matches!(
            Transaction::begin(&conn),
            Err(SqliteError::TransactionFailed(_))
        ));
        outer.rollback().unwrap();
    }

    #[test]
    fn test_with_transaction() {
        let conn = setup_test_db();
        let result = with_transaction(&conn, |tx| {
            tx.execute("INSERT INTO test (value) VALUES (?)", &[1i32.into()])?;
            Ok(42)
        });
        assert_eq!(result.unwrap(), 42);
        let rows = conn.query("SELECT * FROM test", &[]).unwrap();
        assert_eq!(rows.len(), 1);
        let result: Result<i32, SqliteError> = with_transaction(&conn, |tx| {
            tx.execute("INSERT INTO test (value) VALUES (?)", &[2i32.into()])?;
            Err(SqliteError::Internal("test error".to_string()))
        });
        assert!(result.is_err());
        let rows = conn.query("SELECT * FROM test", &[]).unwrap();
        // The failing closure's insert must have been rolled back.
        assert_eq!(rows.len(), 1);
    }

    #[test]
    fn test_with_immediate_transaction() {
        let conn = setup_test_db();
        let result = with_immediate_transaction(&conn, |tx| {
            tx.execute("INSERT INTO test (value) VALUES (?)", &[1i32.into()])?;
            Ok(7)
        });
        assert_eq!(result.unwrap(), 7);
        let rows = conn.query("SELECT * FROM test", &[]).unwrap();
        assert_eq!(rows.len(), 1);
        let result: Result<i32, SqliteError> = with_immediate_transaction(&conn, |tx| {
            tx.execute("INSERT INTO test (value) VALUES (?)", &[2i32.into()])?;
            Err(SqliteError::Internal("test error".to_string()))
        });
        assert!(result.is_err());
        let rows = conn.query("SELECT * FROM test", &[]).unwrap();
        assert_eq!(rows.len(), 1);
    }

    #[test]
    fn test_savepoint() {
        let conn = setup_test_db();
        let tx = Transaction::begin(&conn).unwrap();
        tx.execute("INSERT INTO test (value) VALUES (?)", &[1i32.into()])
            .unwrap();
        {
            let sp = tx.savepoint("sp1").unwrap();
            tx.execute("INSERT INTO test (value) VALUES (?)", &[2i32.into()])
                .unwrap();
            // Only the insert made after the savepoint is undone.
            sp.rollback().unwrap();
        }
        tx.commit().unwrap();
        let rows = conn.query("SELECT * FROM test", &[]).unwrap();
        assert_eq!(rows.len(), 1);
    }

    #[test]
    fn test_savepoint_release_keeps_changes() {
        let conn = setup_test_db();
        let tx = Transaction::begin(&conn).unwrap();
        let sp = tx.savepoint("sp1").unwrap();
        tx.execute("INSERT INTO test (value) VALUES (?)", &[1i32.into()])
            .unwrap();
        sp.release().unwrap();
        tx.commit().unwrap();
        let rows = conn.query("SELECT * FROM test", &[]).unwrap();
        assert_eq!(rows.len(), 1);
    }

    #[test]
    fn test_savepoint_rollback_on_drop() {
        let conn = setup_test_db();
        let tx = Transaction::begin(&conn).unwrap();
        tx.execute("INSERT INTO test (value) VALUES (?)", &[1i32.into()])
            .unwrap();
        {
            let _sp = tx.savepoint("sp1").unwrap();
            tx.execute("INSERT INTO test (value) VALUES (?)", &[2i32.into()])
                .unwrap();
            // `_sp` is dropped here without release/rollback.
        }
        tx.commit().unwrap();
        let rows = conn.query("SELECT * FROM test", &[]).unwrap();
        assert_eq!(rows.len(), 1);
    }
}