Skip to main content

yang_db/mysql/
transaction.rs

1use crate::error::DbError;
2use crate::mysql::condition::{Condition, SqlValue};
3use crate::mysql::field::FieldType;
4use sqlx::Transaction as SqlxTransaction;
5use std::collections::HashMap;
6
7/// 数据库事务
8pub struct Transaction {
9    tx: Option<SqlxTransaction<'static, sqlx::MySql>>,
10    enable_logging: bool,
11}
12
13impl Transaction {
14    /// 创建新的事务实例
15    pub(crate) fn new(tx: SqlxTransaction<'static, sqlx::MySql>, enable_logging: bool) -> Self {
16        Self {
17            tx: Some(tx),
18            enable_logging,
19        }
20    }
21
22    /// 提交事务
23    pub async fn commit(mut self) -> Result<(), DbError> {
24        if self.enable_logging {
25            log::debug!("提交事务");
26        }
27
28        if let Some(tx) = self.tx.take() {
29            tx.commit().await?;
30        }
31
32        Ok(())
33    }
34
35    /// 回滚事务
36    pub async fn rollback(mut self) -> Result<(), DbError> {
37        if self.enable_logging {
38            log::debug!("回滚事务");
39        }
40
41        if let Some(tx) = self.tx.take() {
42            tx.rollback().await?;
43        }
44
45        Ok(())
46    }
47
48    /// 执行原生 SQL
49    pub async fn execute(&mut self, sql: &str) -> Result<u64, DbError> {
50        if self.enable_logging {
51            log::debug!("事务中执行: {}", sql);
52        }
53
54        if let Some(tx) = &mut self.tx {
55            let result = sqlx::query(sql).execute(&mut **tx).await?;
56            Ok(result.rows_affected())
57        } else {
58            Err(DbError::TransactionError("事务已提交或回滚".to_string()))
59        }
60    }
61
62    /// 选择表,返回事务中的查询构建器
63    ///
64    /// # 参数
65    /// - table_name: 表名
66    ///
67    /// # 返回
68    /// - TransactionQueryBuilder: 事务查询构建器
69    ///
70    /// # 示例
71    /// ```no_run
72    /// use yang_db::Database;
73    /// use serde_json::json;
74    ///
75    /// # async fn example() -> Result<(), yang_db::DbError> {
76    /// let db = Database::connect("mysql://root:password@localhost/test").await?;
77    /// let mut tx = db.transaction().await?;
78    ///
79    /// // 在事务中插入数据
80    /// let user_data = json!({"name": "张三", "email": "zhangsan@example.com"});
81    /// let user_id = tx.table("users").insert(&user_data).await?;
82    ///
83    /// // 在事务中更新数据
84    /// let update_data = json!({"status": 1});
85    /// tx.table("users")
86    ///     .where_and("id", "=", user_id)
87    ///     .update(&update_data)
88    ///     .await?;
89    ///
90    /// // 提交事务
91    /// tx.commit().await?;
92    /// # Ok(())
93    /// # }
94    /// ```
95    pub fn table(&mut self, table_name: &str) -> TransactionQueryBuilder<'_> {
96        TransactionQueryBuilder::new(self, table_name)
97    }
98}
99
100/// 事务查询构建器
101///
102/// 用于在事务上下文中构建和执行查询
103pub struct TransactionQueryBuilder<'a> {
104    tx: &'a mut Transaction,
105    table: String,
106    conditions: Vec<Condition>,
107    field_types: HashMap<String, FieldType>,
108}
109
110impl<'a> TransactionQueryBuilder<'a> {
111    /// 创建新的事务查询构建器
112    fn new(tx: &'a mut Transaction, table_name: &str) -> Self {
113        Self {
114            tx,
115            table: table_name.to_string(),
116            conditions: Vec::new(),
117            field_types: HashMap::new(),
118        }
119    }
120
121    /// 标记字段为 JSON 类型
122    pub fn json(mut self, field: &str) -> Self {
123        self.field_types.insert(field.to_string(), FieldType::Json);
124        self
125    }
126
127    /// 标记字段为 DATETIME 类型
128    pub fn datetime(mut self, field: &str) -> Self {
129        self.field_types
130            .insert(field.to_string(), FieldType::DateTime);
131        self
132    }
133
134    /// 标记字段为 TIMESTAMP 类型
135    pub fn timestamp(mut self, field: &str) -> Self {
136        self.field_types
137            .insert(field.to_string(), FieldType::Timestamp);
138        self
139    }
140
141    /// 标记字段为 DECIMAL 类型
142    pub fn decimal(mut self, field: &str) -> Self {
143        self.field_types
144            .insert(field.to_string(), FieldType::Decimal);
145        self
146    }
147
148    /// 标记字段为 BLOB 类型
149    pub fn blob(mut self, field: &str) -> Self {
150        self.field_types.insert(field.to_string(), FieldType::Blob);
151        self
152    }
153
154    /// 标记字段为 TEXT 类型
155    pub fn text(mut self, field: &str) -> Self {
156        self.field_types.insert(field.to_string(), FieldType::Text);
157        self
158    }
159
160    /// 添加 AND 条件
161    pub fn where_and<V>(mut self, field: &str, op: &str, value: V) -> Self
162    where
163        V: Into<SqlValue>,
164    {
165        let sql_value = value.into();
166        let condition = match op {
167            "=" => Condition::Eq(field.to_string(), sql_value),
168            "!=" => Condition::Ne(field.to_string(), sql_value),
169            ">" => Condition::Gt(field.to_string(), sql_value),
170            "<" => Condition::Lt(field.to_string(), sql_value),
171            ">=" => Condition::Gte(field.to_string(), sql_value),
172            "<=" => Condition::Lte(field.to_string(), sql_value),
173            "like" | "LIKE" => {
174                if let SqlValue::String(s) = sql_value {
175                    Condition::Like(field.to_string(), s)
176                } else {
177                    Condition::Like(field.to_string(), format!("{:?}", sql_value))
178                }
179            }
180            _ => panic!("不支持的操作符: {}", op),
181        };
182
183        self.conditions.push(condition);
184        self
185    }
186
187    /// 插入数据
188    ///
189    /// 在事务中执行 INSERT 操作
190    ///
191    /// # 类型参数
192    /// - T: 数据类型,必须实现 Serialize trait
193    ///
194    /// # 参数
195    /// - data: 要插入的数据
196    ///
197    /// # 返回
198    /// - Ok(u64): 插入成功,返回插入记录的 ID(自增主键)
199    /// - Err(DbError): 插入失败
200    pub async fn insert<T>(self, data: &T) -> Result<u64, DbError>
201    where
202        T: serde::Serialize,
203    {
204        // 记录日志
205        if self.tx.enable_logging {
206            log::debug!("事务中执行 insert() 操作,表: {}", self.table);
207        }
208
209        // 将数据序列化为 JSON
210        let json_data = serde_json::to_value(data)
211            .map_err(|e| DbError::SerializationError(format!("数据序列化失败: {}", e)))?;
212
213        // 生成 INSERT 语句
214        let mut generator = crate::mysql::query_builder::SqlGenerator::new();
215        generator.build_insert(&self.table, &json_data, &self.field_types)?;
216
217        let sql = generator.get_sql();
218        let params = generator.get_params();
219
220        // 记录日志
221        if self.tx.enable_logging {
222            log::debug!("事务中执行 insert() SQL: {}", sql);
223            log::debug!("参数: {:?}", params);
224        }
225
226        // 构建查询
227        let mut query = sqlx::query(sql);
228
229        // 绑定参数
230        for param in params {
231            query = bind_execute_param(query, param);
232        }
233
234        // 执行插入
235        if let Some(tx) = &mut self.tx.tx {
236            let result = query.execute(&mut **tx).await?;
237            let last_insert_id = result.last_insert_id();
238
239            if self.tx.enable_logging {
240                log::debug!("事务中 insert() 成功,插入 ID: {}", last_insert_id);
241            }
242
243            Ok(last_insert_id)
244        } else {
245            Err(DbError::TransactionError("事务已提交或回滚".to_string()))
246        }
247    }
248
249    /// 更新数据
250    ///
251    /// 在事务中执行 UPDATE 操作
252    /// 为了防止误操作,必须提供 WHERE 条件,否则会返回错误
253    ///
254    /// # 类型参数
255    /// - T: 数据类型,必须实现 Serialize trait
256    ///
257    /// # 参数
258    /// - data: 要更新的数据
259    ///
260    /// # 返回
261    /// - Ok(u64): 更新成功,返回受影响的行数
262    /// - Err(DbError): 更新失败
263    pub async fn update<T>(self, data: &T) -> Result<u64, DbError>
264    where
265        T: serde::Serialize,
266    {
267        // 记录日志
268        if self.tx.enable_logging {
269            log::debug!("事务中执行 update() 操作,表: {}", self.table);
270        }
271
272        // 检查是否有 WHERE 条件
273        if self.conditions.is_empty() {
274            log::warn!("事务中 update() 操作缺少 WHERE 条件,禁止全表更新");
275            return Err(DbError::MissingWhereClause);
276        }
277
278        // 将数据序列化为 JSON
279        let json_data = serde_json::to_value(data)
280            .map_err(|e| DbError::SerializationError(format!("数据序列化失败: {}", e)))?;
281
282        // 生成 UPDATE 语句
283        let mut generator = crate::mysql::query_builder::SqlGenerator::new();
284        generator.build_update(&self.table, &json_data, &self.field_types, &self.conditions)?;
285
286        let sql = generator.get_sql();
287        let params = generator.get_params();
288
289        // 记录日志
290        if self.tx.enable_logging {
291            log::debug!("事务中执行 update() SQL: {}", sql);
292            log::debug!("参数: {:?}", params);
293        }
294
295        // 构建查询
296        let mut query = sqlx::query(sql);
297
298        // 绑定参数
299        for param in params {
300            query = bind_execute_param(query, param);
301        }
302
303        // 执行更新
304        if let Some(tx) = &mut self.tx.tx {
305            let result = query.execute(&mut **tx).await?;
306            let rows_affected = result.rows_affected();
307
308            if self.tx.enable_logging {
309                log::debug!("事务中 update() 成功,影响 {} 行", rows_affected);
310            }
311
312            Ok(rows_affected)
313        } else {
314            Err(DbError::TransactionError("事务已提交或回滚".to_string()))
315        }
316    }
317
318    /// 删除数据
319    ///
320    /// 在事务中执行 DELETE 操作
321    /// 为了防止误操作,必须提供 WHERE 条件,否则会返回错误
322    ///
323    /// # 返回
324    /// - Ok(u64): 删除成功,返回受影响的行数
325    /// - Err(DbError): 删除失败
326    pub async fn delete(self) -> Result<u64, DbError> {
327        // 记录日志
328        if self.tx.enable_logging {
329            log::debug!("事务中执行 delete() 操作,表: {}", self.table);
330        }
331
332        // 检查是否有 WHERE 条件
333        if self.conditions.is_empty() {
334            log::warn!("事务中 delete() 操作缺少 WHERE 条件,禁止全表删除");
335            return Err(DbError::MissingWhereClause);
336        }
337
338        // 生成 DELETE 语句
339        let mut generator = crate::mysql::query_builder::SqlGenerator::new();
340        generator.build_delete(&self.table, &self.conditions)?;
341
342        let sql = generator.get_sql();
343        let params = generator.get_params();
344
345        // 记录日志
346        if self.tx.enable_logging {
347            log::debug!("事务中执行 delete() SQL: {}", sql);
348            log::debug!("参数: {:?}", params);
349        }
350
351        // 构建查询
352        let mut query = sqlx::query(sql);
353
354        // 绑定参数
355        for param in params {
356            query = bind_execute_param(query, param);
357        }
358
359        // 执行删除
360        if let Some(tx) = &mut self.tx.tx {
361            let result = query.execute(&mut **tx).await?;
362            let rows_affected = result.rows_affected();
363
364            if self.tx.enable_logging {
365                log::debug!("事务中 delete() 成功,影响 {} 行", rows_affected);
366            }
367
368            Ok(rows_affected)
369        } else {
370            Err(DbError::TransactionError("事务已提交或回滚".to_string()))
371        }
372    }
373}
374
375/// 绑定参数到执行查询(用于事务中的 INSERT/UPDATE/DELETE)
376///
377/// # 参数
378/// - query: sqlx 查询对象
379/// - param: SQL 参数值
380///
381/// # 返回
382/// - 绑定参数后的查询对象
383fn bind_execute_param<'q>(
384    query: sqlx::query::Query<'q, sqlx::MySql, sqlx::mysql::MySqlArguments>,
385    param: &SqlValue,
386) -> sqlx::query::Query<'q, sqlx::MySql, sqlx::mysql::MySqlArguments> {
387    match param {
388        SqlValue::Null => query.bind(Option::<i32>::None),
389        SqlValue::Bool(b) => query.bind(*b),
390        SqlValue::Int(i) => query.bind(*i),
391        SqlValue::Float(f) => query.bind(*f),
392        SqlValue::String(s) => query.bind(s.clone()),
393        SqlValue::Bytes(b) => query.bind(b.clone()),
394        SqlValue::Json(j) => query.bind(j.to_string()),
395        SqlValue::DateTime(dt) => query.bind(*dt),
396        SqlValue::Timestamp(ts) => query.bind(*ts),
397    }
398}