use crate::condition::{Condition, SqlValue};
use crate::error::DbError;
use crate::field::FieldType;
use sqlx::Transaction as SqlxTransaction;
use std::collections::HashMap;
/// A handle to an in-progress MySQL transaction backed by `sqlx`.
///
/// The underlying sqlx transaction is stored in an `Option` so that
/// `Transaction::commit` and `Transaction::rollback` can move it out with
/// `Option::take`.
pub struct Transaction {
    /// When `true`, statements run through this handle are traced with `log::debug!`.
    enable_logging: bool,
    /// The wrapped sqlx transaction; `None` once it has been taken out.
    tx: Option<SqlxTransaction<'static, sqlx::MySql>>,
}
impl Transaction {
    /// Wraps an already-started sqlx MySQL transaction.
    ///
    /// `enable_logging` controls whether this handle emits `log::debug!` traces;
    /// builders obtained from `Transaction::table` read the same flag.
    pub(crate) fn new(tx: SqlxTransaction<'static, sqlx::MySql>, enable_logging: bool) -> Self {
        Self {
            enable_logging,
            tx: Some(tx),
        }
    }

    /// Commits the transaction, consuming the handle.
    ///
    /// # Errors
    ///
    /// Returns the sqlx error (converted into `DbError`) if the commit fails.
    pub async fn commit(mut self) -> Result<(), DbError> {
        if self.enable_logging {
            log::debug!("提交事务");
        }
        match self.tx.take() {
            Some(inner) => inner.commit().await.map_err(DbError::from),
            None => Ok(()),
        }
    }

    /// Rolls the transaction back, consuming the handle.
    ///
    /// # Errors
    ///
    /// Returns the sqlx error (converted into `DbError`) if the rollback fails.
    pub async fn rollback(mut self) -> Result<(), DbError> {
        if self.enable_logging {
            log::debug!("回滚事务");
        }
        match self.tx.take() {
            Some(inner) => inner.rollback().await.map_err(DbError::from),
            None => Ok(()),
        }
    }

    /// Runs a raw SQL statement inside the transaction and returns the number
    /// of affected rows.
    ///
    /// The statement is sent without any bound parameters, so `sql` should not
    /// be assembled from untrusted input.
    ///
    /// # Errors
    ///
    /// Returns `DbError::TransactionError` if the inner transaction is no
    /// longer held, or the converted sqlx error if execution fails.
    pub async fn execute(&mut self, sql: &str) -> Result<u64, DbError> {
        if self.enable_logging {
            log::debug!("事务中执行: {}", sql);
        }
        let inner = match self.tx.as_mut() {
            Some(inner) => inner,
            None => return Err(DbError::TransactionError("事务已提交或回滚".to_string())),
        };
        let outcome = sqlx::query(sql).execute(&mut **inner).await?;
        Ok(outcome.rows_affected())
    }

    /// Starts a query builder for `table_name` whose statements run inside
    /// this transaction.
    pub fn table(&mut self, table_name: &str) -> TransactionQueryBuilder<'_> {
        TransactionQueryBuilder::new(self, table_name)
    }
}
/// Fluent builder for INSERT / UPDATE / DELETE statements on one table,
/// executed inside a borrowed [`Transaction`].
pub struct TransactionQueryBuilder<'a> {
    /// Name of the table the generated statement targets.
    table: String,
    /// WHERE conditions collected by `where_and`; presumably AND-combined by
    /// the SQL generator (name-based assumption — TODO confirm).
    conditions: Vec<Condition>,
    /// Per-column type hints (JSON, DATETIME, DECIMAL, ...) handed to the SQL generator.
    field_types: HashMap<String, FieldType>,
    /// Transaction the generated statement is executed in.
    tx: &'a mut Transaction,
}
impl<'a> TransactionQueryBuilder<'a> {
fn new(tx: &'a mut Transaction, table_name: &str) -> Self {
Self {
tx,
table: table_name.to_string(),
conditions: Vec::new(),
field_types: HashMap::new(),
}
}
pub fn json(mut self, field: &str) -> Self {
self.field_types.insert(field.to_string(), FieldType::Json);
self
}
pub fn datetime(mut self, field: &str) -> Self {
self.field_types
.insert(field.to_string(), FieldType::DateTime);
self
}
pub fn timestamp(mut self, field: &str) -> Self {
self.field_types
.insert(field.to_string(), FieldType::Timestamp);
self
}
pub fn decimal(mut self, field: &str) -> Self {
self.field_types
.insert(field.to_string(), FieldType::Decimal);
self
}
pub fn blob(mut self, field: &str) -> Self {
self.field_types.insert(field.to_string(), FieldType::Blob);
self
}
pub fn text(mut self, field: &str) -> Self {
self.field_types.insert(field.to_string(), FieldType::Text);
self
}
pub fn where_and<V>(mut self, field: &str, op: &str, value: V) -> Self
where
V: Into<SqlValue>,
{
let sql_value = value.into();
let condition = match op {
"=" => Condition::Eq(field.to_string(), sql_value),
"!=" => Condition::Ne(field.to_string(), sql_value),
">" => Condition::Gt(field.to_string(), sql_value),
"<" => Condition::Lt(field.to_string(), sql_value),
">=" => Condition::Gte(field.to_string(), sql_value),
"<=" => Condition::Lte(field.to_string(), sql_value),
"like" | "LIKE" => {
if let SqlValue::String(s) = sql_value {
Condition::Like(field.to_string(), s)
} else {
Condition::Like(field.to_string(), format!("{:?}", sql_value))
}
}
_ => panic!("不支持的操作符: {}", op),
};
self.conditions.push(condition);
self
}
pub async fn insert<T>(self, data: &T) -> Result<u64, DbError>
where
T: serde::Serialize,
{
if self.tx.enable_logging {
log::debug!("事务中执行 insert() 操作,表: {}", self.table);
}
let json_data = serde_json::to_value(data)
.map_err(|e| DbError::SerializationError(format!("数据序列化失败: {}", e)))?;
let mut generator = crate::query_builder::SqlGenerator::new();
generator.build_insert(&self.table, &json_data, &self.field_types)?;
let sql = generator.get_sql();
let params = generator.get_params();
if self.tx.enable_logging {
log::debug!("事务中执行 insert() SQL: {}", sql);
log::debug!("参数: {:?}", params);
}
let mut query = sqlx::query(sql);
for param in params {
query = bind_execute_param(query, param);
}
if let Some(tx) = &mut self.tx.tx {
let result = query.execute(&mut **tx).await?;
let last_insert_id = result.last_insert_id();
if self.tx.enable_logging {
log::debug!("事务中 insert() 成功,插入 ID: {}", last_insert_id);
}
Ok(last_insert_id)
} else {
Err(DbError::TransactionError("事务已提交或回滚".to_string()))
}
}
pub async fn update<T>(self, data: &T) -> Result<u64, DbError>
where
T: serde::Serialize,
{
if self.tx.enable_logging {
log::debug!("事务中执行 update() 操作,表: {}", self.table);
}
if self.conditions.is_empty() {
log::warn!("事务中 update() 操作缺少 WHERE 条件,禁止全表更新");
return Err(DbError::MissingWhereClause);
}
let json_data = serde_json::to_value(data)
.map_err(|e| DbError::SerializationError(format!("数据序列化失败: {}", e)))?;
let mut generator = crate::query_builder::SqlGenerator::new();
generator.build_update(&self.table, &json_data, &self.field_types, &self.conditions)?;
let sql = generator.get_sql();
let params = generator.get_params();
if self.tx.enable_logging {
log::debug!("事务中执行 update() SQL: {}", sql);
log::debug!("参数: {:?}", params);
}
let mut query = sqlx::query(sql);
for param in params {
query = bind_execute_param(query, param);
}
if let Some(tx) = &mut self.tx.tx {
let result = query.execute(&mut **tx).await?;
let rows_affected = result.rows_affected();
if self.tx.enable_logging {
log::debug!("事务中 update() 成功,影响 {} 行", rows_affected);
}
Ok(rows_affected)
} else {
Err(DbError::TransactionError("事务已提交或回滚".to_string()))
}
}
pub async fn delete(self) -> Result<u64, DbError> {
if self.tx.enable_logging {
log::debug!("事务中执行 delete() 操作,表: {}", self.table);
}
if self.conditions.is_empty() {
log::warn!("事务中 delete() 操作缺少 WHERE 条件,禁止全表删除");
return Err(DbError::MissingWhereClause);
}
let mut generator = crate::query_builder::SqlGenerator::new();
generator.build_delete(&self.table, &self.conditions)?;
let sql = generator.get_sql();
let params = generator.get_params();
if self.tx.enable_logging {
log::debug!("事务中执行 delete() SQL: {}", sql);
log::debug!("参数: {:?}", params);
}
let mut query = sqlx::query(sql);
for param in params {
query = bind_execute_param(query, param);
}
if let Some(tx) = &mut self.tx.tx {
let result = query.execute(&mut **tx).await?;
let rows_affected = result.rows_affected();
if self.tx.enable_logging {
log::debug!("事务中 delete() 成功,影响 {} 行", rows_affected);
}
Ok(rows_affected)
} else {
Err(DbError::TransactionError("事务已提交或回滚".to_string()))
}
}
}
/// Binds one `SqlValue` as the next positional parameter of `query` and
/// returns the extended query.
///
/// String and byte payloads are cloned so the query owns its arguments. JSON
/// values are bound as their `to_string()` text. `Null` is bound as
/// `Option::<i32>::None`.
fn bind_execute_param<'q>(
    query: sqlx::query::Query<'q, sqlx::MySql, sqlx::mysql::MySqlArguments>,
    param: &SqlValue,
) -> sqlx::query::Query<'q, sqlx::MySql, sqlx::mysql::MySqlArguments> {
    match *param {
        SqlValue::Null => query.bind(None::<i32>),
        SqlValue::Bool(flag) => query.bind(flag),
        SqlValue::Int(number) => query.bind(number),
        SqlValue::Float(real) => query.bind(real),
        SqlValue::String(ref text) => query.bind(text.clone()),
        SqlValue::Bytes(ref raw) => query.bind(raw.clone()),
        SqlValue::Json(ref document) => query.bind(document.to_string()),
        SqlValue::DateTime(moment) => query.bind(moment),
        SqlValue::Timestamp(stamp) => query.bind(stamp),
    }
}