Skip to main content

alun_db/
tx.rs

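//! Transactions with real COMMIT/ROLLBACK: the closure must hand its `ActiveTx` back, so on normal completion the type system ensures a COMMIT or ROLLBACK is issued (panics or cancellation are only reported by a warning when the handle is dropped).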
2
3use crate::{DbResult, DbError, db::DbPool};
4use sqlx::{Row, Column};
5use serde_json::{Value, Number};
6use tracing::debug;
7
/// Transaction isolation level.
///
/// The four SQL-standard levels, declared from weakest to strongest so that the derived
/// `PartialOrd`/`Ord` compares levels by strength.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum Isolation {
    /// Read uncommitted: the weakest level; dirty reads are possible.
    ReadUncommitted,
    /// Read committed: no dirty reads (the level this crate describes as its default).
    ReadCommitted,
    /// Repeatable read: repeated reads inside one transaction see consistent data.
    RepeatableRead,
    /// Serializable: the strongest level, full isolation between transactions.
    Serializable,
}

impl std::fmt::Display for Isolation {
    /// Writes the SQL keyword spelling of the level, e.g. `READ COMMITTED`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let keyword = match *self {
            Self::ReadUncommitted => "READ UNCOMMITTED",
            Self::ReadCommitted => "READ COMMITTED",
            Self::RepeatableRead => "REPEATABLE READ",
            Self::Serializable => "SERIALIZABLE",
        };
        f.write_str(keyword)
    }
}
33
34/// 活跃事务句柄 —— 封装不同数据库的连接(已 BEGIN)
35///
36/// 事务通过 `Db::transaction()` 创建,不直接构造。
37/// 闭包正常返回则 COMMIT,返回 Err 则 ROLLBACK。
38/// 当 `ActiveTx` 被 drop 且未提交/回滚时,日志会输出警告。
39pub struct ActiveTx {
40    inner: ActiveTxInner,
41    committed: bool,
42    rolled_back: bool,
43}
44
45enum ActiveTxInner {
46    /// PostgreSQL 连接
47    Postgres(sqlx::pool::PoolConnection<sqlx::Postgres>),
48    /// MySQL 连接
49    Mysql(sqlx::pool::PoolConnection<sqlx::MySql>),
50    /// SQLite 连接
51    Sqlite(sqlx::pool::PoolConnection<sqlx::Sqlite>),
52}
53
54impl ActiveTx {
55    /// 在事务中执行写操作(INSERT/UPDATE/DELETE),返回受影响行数
56    ///
57    /// 参数使用 `$1`、`$2` 占位符,按顺序绑定。
58    pub async fn execute(&mut self, sql: &str, params: &[&str]) -> DbResult<u64> {
59        match &mut self.inner {
60            ActiveTxInner::Postgres(c) => {
61                let mut q = sqlx::query::<sqlx::Postgres>(sql);
62                for p in params { q = q.bind(*p); }
63                q.execute(&mut **c).await.map_err(DbError::from).map(|r| r.rows_affected())
64            }
65            ActiveTxInner::Mysql(c) => {
66                let mut q = sqlx::query::<sqlx::MySql>(sql);
67                for p in params { q = q.bind(*p); }
68                q.execute(&mut **c).await.map_err(DbError::from).map(|r| r.rows_affected())
69            }
70            ActiveTxInner::Sqlite(c) => {
71                let mut q = sqlx::query::<sqlx::Sqlite>(sql);
72                for p in params { q = q.bind(*p); }
73                q.execute(&mut **c).await.map_err(DbError::from).map(|r| r.rows_affected())
74            }
75        }
76    }
77
78    /// 在事务中执行查询,返回 `Option<Row>`
79    ///
80    /// 参数使用 `$1`、`$2` 占位符,按顺序绑定。
81    /// 未找到记录返回 `Ok(None)`。
82    pub async fn query_one(&mut self, sql: &str, params: &[&str]) -> DbResult<Option<crate::Row>> {
83        match &mut self.inner {
84            ActiveTxInner::Postgres(c) => {
85                let mut q = sqlx::query::<sqlx::Postgres>(sql);
86                for p in params { q = q.bind(*p); }
87                Ok(q.fetch_optional(&mut **c).await?.as_ref().map(tx_row_to_row_pg))
88            }
89            ActiveTxInner::Mysql(c) => {
90                let mut q = sqlx::query::<sqlx::MySql>(sql);
91                for p in params { q = q.bind(*p); }
92                Ok(q.fetch_optional(&mut **c).await?.as_ref().map(tx_row_to_row_my))
93            }
94            ActiveTxInner::Sqlite(c) => {
95                let mut q = sqlx::query::<sqlx::Sqlite>(sql);
96                for p in params { q = q.bind(*p); }
97                Ok(q.fetch_optional(&mut **c).await?.as_ref().map(tx_row_to_row_sqlite))
98            }
99        }
100    }
101
102    /// 标记事务需回滚(即使闭包返回 `Ok`,也会执行 ROLLBACK)
103    ///
104    /// 用于业务逻辑判断失败但不想中断闭包流程的场景。
105    pub fn set_rollback_only(&mut self) { self.committed = false; self.rolled_back = true; }
106
107    async fn commit(mut self) -> DbResult<()> {
108        if self.rolled_back { self.rollback().await; return Ok(()); }
109        debug!("事务提交");
110        match &mut self.inner {
111            ActiveTxInner::Postgres(c) => { sqlx::query::<sqlx::Postgres>("COMMIT").execute(&mut **c).await.map_err(DbError::from)?; }
112            ActiveTxInner::Mysql(c)    => { sqlx::query::<sqlx::MySql>("COMMIT").execute(&mut **c).await.map_err(DbError::from)?; }
113            ActiveTxInner::Sqlite(c)   => { sqlx::query::<sqlx::Sqlite>("COMMIT").execute(&mut **c).await.map_err(DbError::from)?; }
114        };
115        self.committed = true;
116        Ok(())
117    }
118
119    async fn rollback(&mut self) {
120        debug!("事务回滚");
121        match &mut self.inner {
122            ActiveTxInner::Postgres(c) => { let _ = sqlx::query::<sqlx::Postgres>("ROLLBACK").execute(&mut **c).await; }
123            ActiveTxInner::Mysql(c)    => { let _ = sqlx::query::<sqlx::MySql>("ROLLBACK").execute(&mut **c).await; }
124            ActiveTxInner::Sqlite(c)   => { let _ = sqlx::query::<sqlx::Sqlite>("ROLLBACK").execute(&mut **c).await; }
125        };
126        self.rolled_back = true;
127    }
128}
129
130impl Drop for ActiveTx {
131    fn drop(&mut self) {
132        if !self.committed && !self.rolled_back {
133            tracing::warn!("事务未提交也未回滚,连接返回池时将自动回滚(依赖数据库特性)");
134        }
135    }
136}
137
/// Generates `$func_name`, which converts one `$db_ty` result row into a [`crate::Row`].
///
/// For each column, the generated function tries these decodings in order and stores the
/// first that succeeds under the column name:
/// 1. SQL `NULL` becomes `Value::Null`.
/// 2. `i64`, `i32` or `i16` become a JSON number.
/// 3. `String` becomes a JSON string.
/// 4. `Uuid` becomes its string form.
/// 5. `f64` or `f32` become a JSON number, or `Value::Null` for NaN/±inf, which JSON cannot
///    represent.
/// 6. `bool` becomes a JSON boolean.
///
/// Columns of any other type are skipped. Finally `mark_all_changed()` is called on the
/// resulting row.
macro_rules! tx_row_convert {
    ($func_name:ident, $db_ty:ty) => {
        fn $func_name(row: &<$db_ty as sqlx::Database>::Row) -> crate::Row {
            let mut r = crate::Row::default();
            for col in <$db_ty as sqlx::Database>::Row::columns(row) {
                let name = col.name().to_string();
                let idx: usize = col.ordinal();
                // A typed `try_get` fails on SQL NULL, so without this check NULL columns would
                // fall through every branch below and silently disappear from the row.
                let is_null = row
                    .try_get_raw(idx)
                    .map(|raw| sqlx::ValueRef::is_null(&raw))
                    .unwrap_or(false);
                let value = if is_null {
                    Some(Value::Null)
                } else if let Ok(v) = row.try_get::<i64, usize>(idx) {
                    Some(Value::Number(v.into()))
                } else if let Ok(v) = row.try_get::<i32, usize>(idx) {
                    Some(Value::Number(i64::from(v).into()))
                } else if let Ok(v) = row.try_get::<i16, usize>(idx) {
                    Some(Value::Number(i64::from(v).into()))
                } else if let Ok(v) = row.try_get::<String, usize>(idx) {
                    Some(Value::String(v))
                } else if let Ok(v) = row.try_get::<sqlx::types::Uuid, usize>(idx) {
                    Some(Value::String(v.to_string()))
                } else if let Ok(v) = row.try_get::<f64, usize>(idx) {
                    Some(Number::from_f64(v).map_or(Value::Null, Value::Number))
                } else if let Ok(v) = row.try_get::<f32, usize>(idx) {
                    Some(Number::from_f64(f64::from(v)).map_or(Value::Null, Value::Number))
                } else if let Ok(v) = row.try_get::<bool, usize>(idx) {
                    Some(Value::Bool(v))
                } else {
                    // NOTE(review): other column types (timestamps, decimals, bytes, ...) are
                    // not decoded and are left out of the row.
                    None
                };
                if let Some(value) = value {
                    r.data.insert(name, value);
                }
            }
            r.mark_all_changed();
            r
        }
    };
}
168
169tx_row_convert!(tx_row_to_row_pg, sqlx::Postgres);
170tx_row_convert!(tx_row_to_row_my, sqlx::MySql);
171tx_row_convert!(tx_row_to_row_sqlite, sqlx::Sqlite);
172
173/// 执行事务 —— 传入闭包接收 `ActiveTx`,自动管理 BEGIN / COMMIT / ROLLBACK
174///
175/// 闭包接收一个 `ActiveTx`(已 BEGIN),需返回 `(ActiveTx, DbResult<T>)`。
176/// 返回 `Ok` 时自动 COMMIT,返回 `Err` 或 Drop 未提交时自动 ROLLBACK。
177pub(crate) async fn execute_transaction<F, Fut, T>(
178    pool: &DbPool, _isolation: Isolation, _rollback_only: &mut bool, f: F,
179) -> DbResult<T>
180where
181    F: FnOnce(ActiveTx) -> Fut + Send,
182    Fut: std::future::Future<Output = (ActiveTx, DbResult<T>)> + Send,
183    T: Send,
184{
185    let tx = match pool {
186        DbPool::Postgres(p) => {
187            let mut conn = p.acquire().await?;
188            sqlx::query::<sqlx::Postgres>("BEGIN").execute(&mut *conn).await?;
189            ActiveTx { inner: ActiveTxInner::Postgres(conn), committed: false, rolled_back: false }
190        }
191        DbPool::Mysql(p) => {
192            let mut conn = p.acquire().await?;
193            sqlx::query::<sqlx::MySql>("BEGIN").execute(&mut *conn).await?;
194            ActiveTx { inner: ActiveTxInner::Mysql(conn), committed: false, rolled_back: false }
195        }
196        DbPool::Sqlite(p) => {
197            let mut conn = p.acquire().await?;
198            sqlx::query::<sqlx::Sqlite>("BEGIN").execute(&mut *conn).await?;
199            ActiveTx { inner: ActiveTxInner::Sqlite(conn), committed: false, rolled_back: false }
200        }
201        DbPool::Any(_) => return Err(DbError::Other("Any pool 不支持事务".into())),
202    };
203
204    let (mut tx, result) = f(tx).await;
205    match result {
206        Ok(val) => { tx.commit().await?; Ok(val) }
207        Err(e)  => { tx.rollback().await; Err(e) }
208    }
209}
210
#[cfg(test)]
mod tests {
    use super::Isolation;

    /// Every level renders as its SQL keyword spelling.
    #[test]
    fn test_isolation_display() {
        assert_eq!(Isolation::ReadUncommitted.to_string(), "READ UNCOMMITTED");
        assert_eq!(Isolation::ReadCommitted.to_string(), "READ COMMITTED");
        assert_eq!(Isolation::RepeatableRead.to_string(), "REPEATABLE READ");
        assert_eq!(Isolation::Serializable.to_string(), "SERIALIZABLE");
    }

    /// Variants are declared weakest-first, so the derived `Ord` ranks them by strength.
    #[test]
    fn test_isolation_ordering() {
        assert!(Isolation::ReadUncommitted < Isolation::ReadCommitted);
        assert!(Isolation::ReadCommitted < Isolation::RepeatableRead);
        assert!(Isolation::RepeatableRead < Isolation::Serializable);
        assert_eq!(
            Isolation::ReadCommitted.max(Isolation::Serializable),
            Isolation::Serializable
        );
    }
}