1use crate::{DbResult, DbError, db::DbPool};
4use sqlx::{Row, Column};
5use serde_json::{Value, Number};
6use tracing::debug;
7
/// Transaction isolation level requested for a transaction.
///
/// Variants are declared in SQL-standard order, from least to most strict; the derived
/// `Ord` follows that declaration order, so `ReadUncommitted < Serializable`.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub enum Isolation {
    /// `READ UNCOMMITTED`
    ReadUncommitted,
    /// `READ COMMITTED`
    ReadCommitted,
    /// `REPEATABLE READ`
    RepeatableRead,
    /// `SERIALIZABLE`
    Serializable,
}

impl std::fmt::Display for Isolation {
    /// Writes the level as its upper-case SQL keyword phrase (e.g. `READ COMMITTED`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let keyword = match self {
            Isolation::ReadUncommitted => "READ UNCOMMITTED",
            Isolation::ReadCommitted => "READ COMMITTED",
            Isolation::RepeatableRead => "REPEATABLE READ",
            Isolation::Serializable => "SERIALIZABLE",
        };
        f.write_str(keyword)
    }
}
33
34pub struct ActiveTx {
40 inner: ActiveTxInner,
41 committed: bool,
42 rolled_back: bool,
43}
44
45enum ActiveTxInner {
46 Postgres(sqlx::pool::PoolConnection<sqlx::Postgres>),
48 Mysql(sqlx::pool::PoolConnection<sqlx::MySql>),
50 Sqlite(sqlx::pool::PoolConnection<sqlx::Sqlite>),
52}
53
54impl ActiveTx {
55 pub async fn execute(&mut self, sql: &str, params: &[&str]) -> DbResult<u64> {
59 match &mut self.inner {
60 ActiveTxInner::Postgres(c) => {
61 let mut q = sqlx::query::<sqlx::Postgres>(sql);
62 for p in params { q = q.bind(*p); }
63 q.execute(&mut **c).await.map_err(DbError::from).map(|r| r.rows_affected())
64 }
65 ActiveTxInner::Mysql(c) => {
66 let mut q = sqlx::query::<sqlx::MySql>(sql);
67 for p in params { q = q.bind(*p); }
68 q.execute(&mut **c).await.map_err(DbError::from).map(|r| r.rows_affected())
69 }
70 ActiveTxInner::Sqlite(c) => {
71 let mut q = sqlx::query::<sqlx::Sqlite>(sql);
72 for p in params { q = q.bind(*p); }
73 q.execute(&mut **c).await.map_err(DbError::from).map(|r| r.rows_affected())
74 }
75 }
76 }
77
78 pub async fn query_one(&mut self, sql: &str, params: &[&str]) -> DbResult<Option<crate::Row>> {
83 match &mut self.inner {
84 ActiveTxInner::Postgres(c) => {
85 let mut q = sqlx::query::<sqlx::Postgres>(sql);
86 for p in params { q = q.bind(*p); }
87 Ok(q.fetch_optional(&mut **c).await?.as_ref().map(tx_row_to_row_pg))
88 }
89 ActiveTxInner::Mysql(c) => {
90 let mut q = sqlx::query::<sqlx::MySql>(sql);
91 for p in params { q = q.bind(*p); }
92 Ok(q.fetch_optional(&mut **c).await?.as_ref().map(tx_row_to_row_my))
93 }
94 ActiveTxInner::Sqlite(c) => {
95 let mut q = sqlx::query::<sqlx::Sqlite>(sql);
96 for p in params { q = q.bind(*p); }
97 Ok(q.fetch_optional(&mut **c).await?.as_ref().map(tx_row_to_row_sqlite))
98 }
99 }
100 }
101
102 pub fn set_rollback_only(&mut self) { self.committed = false; self.rolled_back = true; }
106
107 async fn commit(mut self) -> DbResult<()> {
108 if self.rolled_back { self.rollback().await; return Ok(()); }
109 debug!("事务提交");
110 match &mut self.inner {
111 ActiveTxInner::Postgres(c) => { sqlx::query::<sqlx::Postgres>("COMMIT").execute(&mut **c).await.map_err(DbError::from)?; }
112 ActiveTxInner::Mysql(c) => { sqlx::query::<sqlx::MySql>("COMMIT").execute(&mut **c).await.map_err(DbError::from)?; }
113 ActiveTxInner::Sqlite(c) => { sqlx::query::<sqlx::Sqlite>("COMMIT").execute(&mut **c).await.map_err(DbError::from)?; }
114 };
115 self.committed = true;
116 Ok(())
117 }
118
119 async fn rollback(&mut self) {
120 debug!("事务回滚");
121 match &mut self.inner {
122 ActiveTxInner::Postgres(c) => { let _ = sqlx::query::<sqlx::Postgres>("ROLLBACK").execute(&mut **c).await; }
123 ActiveTxInner::Mysql(c) => { let _ = sqlx::query::<sqlx::MySql>("ROLLBACK").execute(&mut **c).await; }
124 ActiveTxInner::Sqlite(c) => { let _ = sqlx::query::<sqlx::Sqlite>("ROLLBACK").execute(&mut **c).await; }
125 };
126 self.rolled_back = true;
127 }
128}
129
130impl Drop for ActiveTx {
131 fn drop(&mut self) {
132 if !self.committed && !self.rolled_back {
133 tracing::warn!("事务未提交也未回滚,连接返回池时将自动回滚(依赖数据库特性)");
134 }
135 }
136}
137
138macro_rules! tx_row_convert {
139 ($func_name:ident, $db_ty:ty) => {
140 fn $func_name(row: &<$db_ty as sqlx::Database>::Row) -> crate::Row {
141 let mut r = crate::Row::default();
142 for col in <$db_ty as sqlx::Database>::Row::columns(row) {
143 let name = col.name().to_string();
144 let idx: usize = col.ordinal();
145 if let Ok(v) = row.try_get::<i64, usize>(idx) {
146 r.data.insert(name, Value::Number(v.into()));
147 } else if let Ok(v) = row.try_get::<i32, usize>(idx) {
148 r.data.insert(name, Value::Number((v as i64).into()));
149 } else if let Ok(v) = row.try_get::<i16, usize>(idx) {
150 r.data.insert(name, Value::Number((v as i64).into()));
151 } else if let Ok(v) = row.try_get::<String, usize>(idx) {
152 r.data.insert(name, Value::String(v));
153 } else if let Ok(v) = row.try_get::<sqlx::types::Uuid, usize>(idx) {
154 r.data.insert(name, Value::String(v.to_string()));
155 } else if let Ok(v) = row.try_get::<f64, usize>(idx) {
156 if let Some(n) = Number::from_f64(v) {
157 r.data.insert(name, Value::Number(n));
158 }
159 } else if let Ok(v) = row.try_get::<bool, usize>(idx) {
160 r.data.insert(name, Value::Bool(v));
161 }
162 }
163 r.mark_all_changed();
164 r
165 }
166 };
167}
168
169tx_row_convert!(tx_row_to_row_pg, sqlx::Postgres);
170tx_row_convert!(tx_row_to_row_my, sqlx::MySql);
171tx_row_convert!(tx_row_to_row_sqlite, sqlx::Sqlite);
172
173pub(crate) async fn execute_transaction<F, Fut, T>(
178 pool: &DbPool, _isolation: Isolation, _rollback_only: &mut bool, f: F,
179) -> DbResult<T>
180where
181 F: FnOnce(ActiveTx) -> Fut + Send,
182 Fut: std::future::Future<Output = (ActiveTx, DbResult<T>)> + Send,
183 T: Send,
184{
185 let tx = match pool {
186 DbPool::Postgres(p) => {
187 let mut conn = p.acquire().await?;
188 sqlx::query::<sqlx::Postgres>("BEGIN").execute(&mut *conn).await?;
189 ActiveTx { inner: ActiveTxInner::Postgres(conn), committed: false, rolled_back: false }
190 }
191 DbPool::Mysql(p) => {
192 let mut conn = p.acquire().await?;
193 sqlx::query::<sqlx::MySql>("BEGIN").execute(&mut *conn).await?;
194 ActiveTx { inner: ActiveTxInner::Mysql(conn), committed: false, rolled_back: false }
195 }
196 DbPool::Sqlite(p) => {
197 let mut conn = p.acquire().await?;
198 sqlx::query::<sqlx::Sqlite>("BEGIN").execute(&mut *conn).await?;
199 ActiveTx { inner: ActiveTxInner::Sqlite(conn), committed: false, rolled_back: false }
200 }
201 DbPool::Any(_) => return Err(DbError::Other("Any pool 不支持事务".into())),
202 };
203
204 let (mut tx, result) = f(tx).await;
205 match result {
206 Ok(val) => { tx.commit().await?; Ok(val) }
207 Err(e) => { tx.rollback().await; Err(e) }
208 }
209}
210
#[cfg(test)]
mod tests {
    use super::Isolation;

    /// Every isolation level renders as its SQL keyword phrase.
    #[test]
    fn test_isolation_display() {
        assert_eq!(Isolation::ReadUncommitted.to_string(), "READ UNCOMMITTED");
        assert_eq!(Isolation::ReadCommitted.to_string(), "READ COMMITTED");
        assert_eq!(Isolation::RepeatableRead.to_string(), "REPEATABLE READ");
        assert_eq!(Isolation::Serializable.to_string(), "SERIALIZABLE");
    }

    /// The derived ordering follows declaration order, from least to most strict.
    #[test]
    fn test_isolation_ordering() {
        assert!(Isolation::ReadUncommitted < Isolation::ReadCommitted);
        assert!(Isolation::ReadCommitted < Isolation::RepeatableRead);
        assert!(Isolation::RepeatableRead < Isolation::Serializable);
    }
}