1use crate::error::DbError;
2use crate::mysql::condition::{Condition, SqlValue};
3use crate::mysql::field::FieldType;
4use sqlx::Transaction as SqlxTransaction;
5use std::collections::HashMap;
6
/// Handle to an open MySQL transaction.
///
/// The underlying sqlx transaction is kept in an `Option` so that `commit`
/// and `rollback` can move it out with `take()`; methods that need the
/// connection return `DbError::TransactionError` when it is absent.
pub struct Transaction {
    // The live sqlx transaction; `None` once it has been taken by commit/rollback.
    tx: Option<SqlxTransaction<'static, sqlx::MySql>>,
    // When true, this handle and its query builders emit `log::debug!` messages.
    enable_logging: bool,
}
12
13impl Transaction {
14 pub(crate) fn new(tx: SqlxTransaction<'static, sqlx::MySql>, enable_logging: bool) -> Self {
16 Self {
17 tx: Some(tx),
18 enable_logging,
19 }
20 }
21
22 pub async fn commit(mut self) -> Result<(), DbError> {
24 if self.enable_logging {
25 log::debug!("提交事务");
26 }
27
28 if let Some(tx) = self.tx.take() {
29 tx.commit().await?;
30 }
31
32 Ok(())
33 }
34
35 pub async fn rollback(mut self) -> Result<(), DbError> {
37 if self.enable_logging {
38 log::debug!("回滚事务");
39 }
40
41 if let Some(tx) = self.tx.take() {
42 tx.rollback().await?;
43 }
44
45 Ok(())
46 }
47
48 pub async fn execute(&mut self, sql: &str) -> Result<u64, DbError> {
50 if self.enable_logging {
51 log::debug!("事务中执行: {}", sql);
52 }
53
54 if let Some(tx) = &mut self.tx {
55 let result = sqlx::query(sql).execute(&mut **tx).await?;
56 Ok(result.rows_affected())
57 } else {
58 Err(DbError::TransactionError("事务已提交或回滚".to_string()))
59 }
60 }
61
62 pub fn table(&mut self, table_name: &str) -> TransactionQueryBuilder<'_> {
96 TransactionQueryBuilder::new(self, table_name)
97 }
98}
99
/// Builder for single-table `insert` / `update` / `delete` statements that
/// run inside a [`Transaction`].
///
/// Created by `Transaction::table`; it mutably borrows the transaction until
/// one of its consuming terminal methods is called.
pub struct TransactionQueryBuilder<'a> {
    // Transaction the generated statement is executed on.
    tx: &'a mut Transaction,
    // Target table name.
    table: String,
    // Conditions accumulated by `where_and`; used by update/delete, ignored by insert.
    conditions: Vec<Condition>,
    // Per-column type hints handed to the SQL generator.
    field_types: HashMap<String, FieldType>,
}
109
110impl<'a> TransactionQueryBuilder<'a> {
111 fn new(tx: &'a mut Transaction, table_name: &str) -> Self {
113 Self {
114 tx,
115 table: table_name.to_string(),
116 conditions: Vec::new(),
117 field_types: HashMap::new(),
118 }
119 }
120
121 pub fn json(mut self, field: &str) -> Self {
123 self.field_types.insert(field.to_string(), FieldType::Json);
124 self
125 }
126
127 pub fn datetime(mut self, field: &str) -> Self {
129 self.field_types
130 .insert(field.to_string(), FieldType::DateTime);
131 self
132 }
133
134 pub fn timestamp(mut self, field: &str) -> Self {
136 self.field_types
137 .insert(field.to_string(), FieldType::Timestamp);
138 self
139 }
140
141 pub fn decimal(mut self, field: &str) -> Self {
143 self.field_types
144 .insert(field.to_string(), FieldType::Decimal);
145 self
146 }
147
148 pub fn blob(mut self, field: &str) -> Self {
150 self.field_types.insert(field.to_string(), FieldType::Blob);
151 self
152 }
153
154 pub fn text(mut self, field: &str) -> Self {
156 self.field_types.insert(field.to_string(), FieldType::Text);
157 self
158 }
159
160 pub fn where_and<V>(mut self, field: &str, op: &str, value: V) -> Self
162 where
163 V: Into<SqlValue>,
164 {
165 let sql_value = value.into();
166 let condition = match op {
167 "=" => Condition::Eq(field.to_string(), sql_value),
168 "!=" => Condition::Ne(field.to_string(), sql_value),
169 ">" => Condition::Gt(field.to_string(), sql_value),
170 "<" => Condition::Lt(field.to_string(), sql_value),
171 ">=" => Condition::Gte(field.to_string(), sql_value),
172 "<=" => Condition::Lte(field.to_string(), sql_value),
173 "like" | "LIKE" => {
174 if let SqlValue::String(s) = sql_value {
175 Condition::Like(field.to_string(), s)
176 } else {
177 Condition::Like(field.to_string(), format!("{:?}", sql_value))
178 }
179 }
180 _ => panic!("不支持的操作符: {}", op),
181 };
182
183 self.conditions.push(condition);
184 self
185 }
186
187 pub async fn insert<T>(self, data: &T) -> Result<u64, DbError>
201 where
202 T: serde::Serialize,
203 {
204 if self.tx.enable_logging {
206 log::debug!("事务中执行 insert() 操作,表: {}", self.table);
207 }
208
209 let json_data = serde_json::to_value(data)
211 .map_err(|e| DbError::SerializationError(format!("数据序列化失败: {}", e)))?;
212
213 let mut generator = crate::mysql::query_builder::SqlGenerator::new();
215 generator.build_insert(&self.table, &json_data, &self.field_types)?;
216
217 let sql = generator.get_sql();
218 let params = generator.get_params();
219
220 if self.tx.enable_logging {
222 log::debug!("事务中执行 insert() SQL: {}", sql);
223 log::debug!("参数: {:?}", params);
224 }
225
226 let mut query = sqlx::query(sql);
228
229 for param in params {
231 query = bind_execute_param(query, param);
232 }
233
234 if let Some(tx) = &mut self.tx.tx {
236 let result = query.execute(&mut **tx).await?;
237 let last_insert_id = result.last_insert_id();
238
239 if self.tx.enable_logging {
240 log::debug!("事务中 insert() 成功,插入 ID: {}", last_insert_id);
241 }
242
243 Ok(last_insert_id)
244 } else {
245 Err(DbError::TransactionError("事务已提交或回滚".to_string()))
246 }
247 }
248
249 pub async fn update<T>(self, data: &T) -> Result<u64, DbError>
264 where
265 T: serde::Serialize,
266 {
267 if self.tx.enable_logging {
269 log::debug!("事务中执行 update() 操作,表: {}", self.table);
270 }
271
272 if self.conditions.is_empty() {
274 log::warn!("事务中 update() 操作缺少 WHERE 条件,禁止全表更新");
275 return Err(DbError::MissingWhereClause);
276 }
277
278 let json_data = serde_json::to_value(data)
280 .map_err(|e| DbError::SerializationError(format!("数据序列化失败: {}", e)))?;
281
282 let mut generator = crate::mysql::query_builder::SqlGenerator::new();
284 generator.build_update(&self.table, &json_data, &self.field_types, &self.conditions)?;
285
286 let sql = generator.get_sql();
287 let params = generator.get_params();
288
289 if self.tx.enable_logging {
291 log::debug!("事务中执行 update() SQL: {}", sql);
292 log::debug!("参数: {:?}", params);
293 }
294
295 let mut query = sqlx::query(sql);
297
298 for param in params {
300 query = bind_execute_param(query, param);
301 }
302
303 if let Some(tx) = &mut self.tx.tx {
305 let result = query.execute(&mut **tx).await?;
306 let rows_affected = result.rows_affected();
307
308 if self.tx.enable_logging {
309 log::debug!("事务中 update() 成功,影响 {} 行", rows_affected);
310 }
311
312 Ok(rows_affected)
313 } else {
314 Err(DbError::TransactionError("事务已提交或回滚".to_string()))
315 }
316 }
317
318 pub async fn delete(self) -> Result<u64, DbError> {
327 if self.tx.enable_logging {
329 log::debug!("事务中执行 delete() 操作,表: {}", self.table);
330 }
331
332 if self.conditions.is_empty() {
334 log::warn!("事务中 delete() 操作缺少 WHERE 条件,禁止全表删除");
335 return Err(DbError::MissingWhereClause);
336 }
337
338 let mut generator = crate::mysql::query_builder::SqlGenerator::new();
340 generator.build_delete(&self.table, &self.conditions)?;
341
342 let sql = generator.get_sql();
343 let params = generator.get_params();
344
345 if self.tx.enable_logging {
347 log::debug!("事务中执行 delete() SQL: {}", sql);
348 log::debug!("参数: {:?}", params);
349 }
350
351 let mut query = sqlx::query(sql);
353
354 for param in params {
356 query = bind_execute_param(query, param);
357 }
358
359 if let Some(tx) = &mut self.tx.tx {
361 let result = query.execute(&mut **tx).await?;
362 let rows_affected = result.rows_affected();
363
364 if self.tx.enable_logging {
365 log::debug!("事务中 delete() 成功,影响 {} 行", rows_affected);
366 }
367
368 Ok(rows_affected)
369 } else {
370 Err(DbError::TransactionError("事务已提交或回滚".to_string()))
371 }
372 }
373}
374
375fn bind_execute_param<'q>(
384 query: sqlx::query::Query<'q, sqlx::MySql, sqlx::mysql::MySqlArguments>,
385 param: &SqlValue,
386) -> sqlx::query::Query<'q, sqlx::MySql, sqlx::mysql::MySqlArguments> {
387 match param {
388 SqlValue::Null => query.bind(Option::<i32>::None),
389 SqlValue::Bool(b) => query.bind(*b),
390 SqlValue::Int(i) => query.bind(*i),
391 SqlValue::Float(f) => query.bind(*f),
392 SqlValue::String(s) => query.bind(s.clone()),
393 SqlValue::Bytes(b) => query.bind(b.clone()),
394 SqlValue::Json(j) => query.bind(j.to_string()),
395 SqlValue::DateTime(dt) => query.bind(*dt),
396 SqlValue::Timestamp(ts) => query.bind(*ts),
397 }
398}