unistore-sqlite 0.1.0

SQLite embedded database capability for UniStore
Documentation
//! 事务支持
//!
//! 职责:提供 RAII 风格的事务管理

use crate::connection::Connection;
use crate::error::SqliteError;
use crate::types::Param;

/// Locking mode requested when a transaction is started.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum IsolationLevel {
    /// `DEFERRED`; this is the default variant.
    #[default]
    Deferred,
    /// `IMMEDIATE`; acquires the write lock right away.
    Immediate,
    /// `EXCLUSIVE`; acquires an exclusive lock.
    Exclusive,
}

impl IsolationLevel {
    /// Returns the SQL keyword that selects this level in a `BEGIN` statement.
    pub fn as_sql(&self) -> &'static str {
        match *self {
            IsolationLevel::Exclusive => "EXCLUSIVE",
            IsolationLevel::Immediate => "IMMEDIATE",
            IsolationLevel::Deferred => "DEFERRED",
        }
    }
}

/// Lifecycle state of a transaction.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransactionState {
    /// The transaction is active: neither committed nor rolled back yet.
    Active,
    /// The transaction has been committed.
    Committed,
    /// The transaction has been rolled back.
    RolledBack,
}

/// Transaction wrapper.
///
/// Uses the RAII pattern: a transaction that has not been committed is
/// rolled back automatically when the value leaves scope.
pub struct Transaction<'a> {
    /// Connection on which the transaction's statements are executed.
    conn: &'a Connection,
    /// Current lifecycle state of the transaction.
    state: TransactionState,
}

impl<'a> Transaction<'a> {
    /// 开始新事务
    pub fn begin(conn: &'a Connection) -> Result<Self, SqliteError> {
        Self::begin_with_isolation(conn, IsolationLevel::Deferred)
    }

    /// 开始事务(指定隔离级别)
    pub fn begin_with_isolation(
        conn: &'a Connection,
        isolation: IsolationLevel,
    ) -> Result<Self, SqliteError> {
        conn.execute_batch(&format!("BEGIN {} TRANSACTION", isolation.as_sql()))
            .map_err(|e| SqliteError::TransactionFailed(format!("BEGIN failed: {}", e)))?;

        Ok(Self {
            conn,
            state: TransactionState::Active,
        })
    }

    /// 开始立即事务(推荐用于写操作)
    pub fn begin_immediate(conn: &'a Connection) -> Result<Self, SqliteError> {
        Self::begin_with_isolation(conn, IsolationLevel::Immediate)
    }

    /// 获取事务状态
    pub fn state(&self) -> TransactionState {
        self.state
    }

    /// 判断事务是否活跃
    pub fn is_active(&self) -> bool {
        self.state == TransactionState::Active
    }

    /// 执行 SQL
    pub fn execute(&self, sql: &str, params: &[Param]) -> Result<usize, SqliteError> {
        if !self.is_active() {
            return Err(SqliteError::TransactionFailed(
                "Transaction is not active".to_string(),
            ));
        }
        self.conn.execute(sql, params)
    }

    /// 执行批量 SQL
    pub fn execute_batch(&self, sql: &str) -> Result<(), SqliteError> {
        if !self.is_active() {
            return Err(SqliteError::TransactionFailed(
                "Transaction is not active".to_string(),
            ));
        }
        self.conn.execute_batch(sql)
    }

    /// 提交事务
    pub fn commit(mut self) -> Result<(), SqliteError> {
        if !self.is_active() {
            return Err(SqliteError::TransactionFailed(
                "Transaction is not active".to_string(),
            ));
        }

        self.conn
            .execute_batch("COMMIT")
            .map_err(|e| SqliteError::TransactionFailed(format!("COMMIT failed: {}", e)))?;

        self.state = TransactionState::Committed;
        Ok(())
    }

    /// 回滚事务
    pub fn rollback(mut self) -> Result<(), SqliteError> {
        if !self.is_active() {
            return Ok(()); // 已经不活跃,无需回滚
        }

        self.conn
            .execute_batch("ROLLBACK")
            .map_err(|e| SqliteError::TransactionFailed(format!("ROLLBACK failed: {}", e)))?;

        self.state = TransactionState::RolledBack;
        Ok(())
    }

    /// 创建保存点
    pub fn savepoint(&self, name: &str) -> Result<Savepoint<'_, 'a>, SqliteError> {
        if !self.is_active() {
            return Err(SqliteError::TransactionFailed(
                "Transaction is not active".to_string(),
            ));
        }
        Savepoint::new(self, name)
    }
}

impl<'a> Drop for Transaction<'a> {
    /// Rolls back a transaction that is still active when the value is dropped.
    fn drop(&mut self) {
        if !self.is_active() {
            return;
        }
        // Errors cannot be reported from `drop`, so the result is discarded.
        let _ = self.conn.execute_batch("ROLLBACK");
        self.state = TransactionState::RolledBack;
    }
}

/// 保存点
pub struct Savepoint<'t, 'c> {
    tx: &'t Transaction<'c>,
    name: String,
    released: bool,
}

impl<'t, 'c> Savepoint<'t, 'c> {
    /// 创建保存点
    fn new(tx: &'t Transaction<'c>, name: &str) -> Result<Self, SqliteError> {
        tx.execute_batch(&format!("SAVEPOINT {}", name))?;
        Ok(Self {
            tx,
            name: name.to_string(),
            released: false,
        })
    }

    /// 释放保存点(提交到父事务)
    pub fn release(mut self) -> Result<(), SqliteError> {
        self.tx.execute_batch(&format!("RELEASE SAVEPOINT {}", self.name))?;
        self.released = true;
        Ok(())
    }

    /// 回滚到保存点
    pub fn rollback(mut self) -> Result<(), SqliteError> {
        self.tx
            .execute_batch(&format!("ROLLBACK TO SAVEPOINT {}", self.name))?;
        self.released = true;
        Ok(())
    }
}

impl<'t, 'c> Drop for Savepoint<'t, 'c> {
    fn drop(&mut self) {
        // 如果没有明确释放或回滚,自动回滚
        if !self.released {
            let _ = self
                .tx
                .conn
                .execute_batch(&format!("ROLLBACK TO SAVEPOINT {}", self.name));
        }
    }
}

/// Runs `f` inside a transaction started with [`Transaction::begin`].
///
/// The transaction is committed when `f` returns `Ok` and rolled back when it
/// returns `Err`.
///
/// # Errors
///
/// Returns the error from starting the transaction, the error produced by `f`,
/// or the error from committing. When `f` fails, its error is returned even if
/// the subsequent rollback also fails.
pub fn with_transaction<F, T>(conn: &Connection, f: F) -> Result<T, SqliteError>
where
    F: FnOnce(&Transaction<'_>) -> Result<T, SqliteError>,
{
    let tx = Transaction::begin(conn)?;
    match f(&tx) {
        Ok(result) => {
            tx.commit()?;
            Ok(result)
        }
        Err(e) => {
            // Best-effort rollback: SQLite may already have rolled the
            // transaction back in response to the failure, in which case the
            // explicit ROLLBACK reports an error. That error must not replace
            // the root cause returned by the closure.
            let _ = tx.rollback();
            Err(e)
        }
    }
}

/// Runs `f` inside a transaction started with [`Transaction::begin_immediate`].
///
/// The transaction is committed when `f` returns `Ok` and rolled back when it
/// returns `Err`.
///
/// # Errors
///
/// Returns the error from starting the transaction, the error produced by `f`,
/// or the error from committing. When `f` fails, its error is returned even if
/// the subsequent rollback also fails.
pub fn with_immediate_transaction<F, T>(conn: &Connection, f: F) -> Result<T, SqliteError>
where
    F: FnOnce(&Transaction<'_>) -> Result<T, SqliteError>,
{
    let tx = Transaction::begin_immediate(conn)?;
    match f(&tx) {
        Ok(result) => {
            tx.commit()?;
            Ok(result)
        }
        Err(e) => {
            // Best-effort rollback: SQLite may already have rolled the
            // transaction back in response to the failure, in which case the
            // explicit ROLLBACK reports an error. That error must not replace
            // the root cause returned by the closure.
            let _ = tx.rollback();
            Err(e)
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Statement used by every test to add one row.
    const INSERT_SQL: &str = "INSERT INTO test (value) VALUES (?)";
    /// Statement used by every test to read the table back.
    const SELECT_ALL_SQL: &str = "SELECT * FROM test";

    /// Opens an in-memory database containing an empty `test` table.
    fn setup_test_db() -> Connection {
        let db = Connection::open_in_memory().unwrap();
        db.execute_batch("CREATE TABLE test (id INTEGER PRIMARY KEY, value INTEGER)")
            .unwrap();
        db
    }

    /// Inserts one row holding `value` through the transaction, panicking on failure.
    fn insert_value(tx: &Transaction<'_>, value: i32) {
        tx.execute(INSERT_SQL, &[value.into()]).unwrap();
    }

    #[test]
    fn test_commit() {
        let conn = setup_test_db();

        let tx = Transaction::begin(&conn).unwrap();
        insert_value(&tx, 1);
        tx.commit().unwrap();

        let stored = conn.query(SELECT_ALL_SQL, &[]).unwrap();
        assert_eq!(stored.len(), 1);
    }

    #[test]
    fn test_rollback() {
        let conn = setup_test_db();

        let tx = Transaction::begin(&conn).unwrap();
        insert_value(&tx, 1);
        tx.rollback().unwrap();

        let stored = conn.query(SELECT_ALL_SQL, &[]).unwrap();
        assert_eq!(stored.len(), 0);
    }

    #[test]
    fn test_auto_rollback_on_drop() {
        let conn = setup_test_db();

        let tx = Transaction::begin(&conn).unwrap();
        insert_value(&tx, 1);
        // Neither commit nor rollback is called; dropping the value handles it.
        drop(tx);

        let stored = conn.query(SELECT_ALL_SQL, &[]).unwrap();
        assert_eq!(stored.len(), 0); // the insert should have been rolled back
    }

    #[test]
    fn test_with_transaction() {
        let conn = setup_test_db();

        // Success path: the closure's value is returned and the row is kept.
        let committed = with_transaction(&conn, |tx| {
            tx.execute(INSERT_SQL, &[1i32.into()])?;
            Ok(42)
        });

        assert_eq!(committed.unwrap(), 42);
        let stored = conn.query(SELECT_ALL_SQL, &[]).unwrap();
        assert_eq!(stored.len(), 1);

        // Failure path: the closure's error is returned and the row is discarded.
        let failed: Result<i32, SqliteError> = with_transaction(&conn, |tx| {
            tx.execute(INSERT_SQL, &[2i32.into()])?;
            Err(SqliteError::Internal("test error".to_string()))
        });

        assert!(failed.is_err());
        let stored = conn.query(SELECT_ALL_SQL, &[]).unwrap();
        assert_eq!(stored.len(), 1); // still one row; the second insert was rolled back
    }

    #[test]
    fn test_savepoint() {
        let conn = setup_test_db();

        let tx = Transaction::begin(&conn).unwrap();
        insert_value(&tx, 1);

        let sp = tx.savepoint("sp1").unwrap();
        insert_value(&tx, 2);
        sp.rollback().unwrap(); // undo the second insert only

        tx.commit().unwrap();

        let stored = conn.query(SELECT_ALL_SQL, &[]).unwrap();
        assert_eq!(stored.len(), 1); // only the first insert survives
    }
}