baichun_framework_db/
executor.rs

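//! Database executor module.
//!
//! Provides the core functionality for running database queries and transactions, including:
//! - basic query operations (execute a statement, fetch a single row, fetch many rows, ...)
//! - transaction management
//! - a unified interface over the connection pool and transactions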
7
8use crate::error::{Error, Result};
9use async_trait::async_trait;
10use sqlx::{MySql, Transaction};
11use std::future::Future;
12
13/// 数据库执行器特征
14///
15/// 提供了基本的数据库操作接口,包括执行查询、获取结果等。
16///
17/// # 示例
18///
19/// ```rust
20/// use baichun_framework_db::{Executor, DbPool};
21///
22/// async fn example(pool: &mut DbPool) -> Result<()> {
23///     // 执行插入操作
24///     pool.execute("INSERT INTO users (name) VALUES ('Alice')").await?;
25///
26///     // 查询单条记录
27///     let user: User = pool.fetch_one("SELECT * FROM users WHERE id = 1").await?;
28///
29///     // 查询多条记录
30///     let users: Vec<User> = pool.fetch_all("SELECT * FROM users").await?;
31///
32///     Ok(())
33/// }
34/// ```
35#[async_trait]
36pub trait Executor {
37    /// 执行不返回行的查询
38    ///
39    /// 适用于 INSERT、UPDATE、DELETE 等操作。
40    async fn execute(&mut self, query: &str) -> Result<sqlx::mysql::MySqlQueryResult>;
41
42    /// 执行返回多行的查询
43    ///
44    /// 将结果转换为指定类型的向量。
45    async fn fetch_all<'q, T>(&mut self, query: &'q str) -> Result<Vec<T>>
46    where
47        T: Send + Unpin + for<'r> sqlx::FromRow<'r, sqlx::mysql::MySqlRow>;
48
49    /// 执行返回单行的查询
50    ///
51    /// 如果查询没有返回行,将返回错误。
52    async fn fetch_one<'q, T>(&mut self, query: &'q str) -> Result<T>
53    where
54        T: Send + Unpin + for<'r> sqlx::FromRow<'r, sqlx::mysql::MySqlRow>;
55
56    /// 执行返回可选单行的查询
57    ///
58    /// 如果查询没有返回行,将返回 None。
59    async fn fetch_optional<'q, T>(&mut self, query: &'q str) -> Result<Option<T>>
60    where
61        T: Send + Unpin + for<'r> sqlx::FromRow<'r, sqlx::mysql::MySqlRow>;
62}
63
64/// 数据库执行器内部特征
65///
66/// 为连接池和事务提供统一的执行接口。
67#[async_trait]
68pub trait DbExecutor<'c>: Send + Sync {
69    /// 执行查询
70    async fn execute_query(&mut self, query: &str) -> Result<sqlx::mysql::MySqlQueryResult>;
71
72    /// 获取所有结果
73    async fn fetch_all_query<'q, T>(&mut self, query: &'q str) -> Result<Vec<T>>
74    where
75        T: Send + Unpin + for<'r> sqlx::FromRow<'r, sqlx::mysql::MySqlRow>;
76
77    /// 获取单行结果
78    async fn fetch_one_query<'q, T>(&mut self, query: &'q str) -> Result<T>
79    where
80        T: Send + Unpin + for<'r> sqlx::FromRow<'r, sqlx::mysql::MySqlRow>;
81
82    /// 获取可选的单行结果
83    async fn fetch_optional_query<'q, T>(&mut self, query: &'q str) -> Result<Option<T>>
84    where
85        T: Send + Unpin + for<'r> sqlx::FromRow<'r, sqlx::mysql::MySqlRow>;
86}
87
88#[async_trait]
89impl<'c> DbExecutor<'c> for sqlx::Pool<MySql> {
90    async fn execute_query(&mut self, query: &str) -> Result<sqlx::mysql::MySqlQueryResult> {
91        sqlx::query(query)
92            .execute(&*self)
93            .await
94            .map_err(|e| Error::Query(e.to_string()))
95    }
96
97    async fn fetch_all_query<'q, T>(&mut self, query: &'q str) -> Result<Vec<T>>
98    where
99        T: Send + Unpin + for<'r> sqlx::FromRow<'r, sqlx::mysql::MySqlRow>,
100    {
101        sqlx::query_as::<_, T>(query)
102            .fetch_all(&*self)
103            .await
104            .map_err(|e| Error::Query(e.to_string()))
105    }
106
107    async fn fetch_one_query<'q, T>(&mut self, query: &'q str) -> Result<T>
108    where
109        T: Send + Unpin + for<'r> sqlx::FromRow<'r, sqlx::mysql::MySqlRow>,
110    {
111        sqlx::query_as::<_, T>(query)
112            .fetch_one(&*self)
113            .await
114            .map_err(|e| Error::Query(e.to_string()))
115    }
116
117    async fn fetch_optional_query<'q, T>(&mut self, query: &'q str) -> Result<Option<T>>
118    where
119        T: Send + Unpin + for<'r> sqlx::FromRow<'r, sqlx::mysql::MySqlRow>,
120    {
121        sqlx::query_as::<_, T>(query)
122            .fetch_optional(&*self)
123            .await
124            .map_err(|e| Error::Query(e.to_string()))
125    }
126}
127
128#[async_trait]
129impl<'c> DbExecutor<'c> for Transaction<'c, MySql> {
130    async fn execute_query(&mut self, query: &str) -> Result<sqlx::mysql::MySqlQueryResult> {
131        sqlx::query(query)
132            .execute(&mut **self)
133            .await
134            .map_err(|e| Error::Query(e.to_string()))
135    }
136
137    async fn fetch_all_query<'q, T>(&mut self, query: &'q str) -> Result<Vec<T>>
138    where
139        T: Send + Unpin + for<'r> sqlx::FromRow<'r, sqlx::mysql::MySqlRow>,
140    {
141        sqlx::query_as::<_, T>(query)
142            .fetch_all(&mut **self)
143            .await
144            .map_err(|e| Error::Query(e.to_string()))
145    }
146
147    async fn fetch_one_query<'q, T>(&mut self, query: &'q str) -> Result<T>
148    where
149        T: Send + Unpin + for<'r> sqlx::FromRow<'r, sqlx::mysql::MySqlRow>,
150    {
151        sqlx::query_as::<_, T>(query)
152            .fetch_one(&mut **self)
153            .await
154            .map_err(|e| Error::Query(e.to_string()))
155    }
156
157    async fn fetch_optional_query<'q, T>(&mut self, query: &'q str) -> Result<Option<T>>
158    where
159        T: Send + Unpin + for<'r> sqlx::FromRow<'r, sqlx::mysql::MySqlRow>,
160    {
161        sqlx::query_as::<_, T>(query)
162            .fetch_optional(&mut **self)
163            .await
164            .map_err(|e| Error::Query(e.to_string()))
165    }
166}
167
168#[async_trait]
169impl<'c, T> Executor for T
170where
171    T: DbExecutor<'c> + Send + Sync,
172{
173    async fn execute(&mut self, query: &str) -> Result<sqlx::mysql::MySqlQueryResult> {
174        self.execute_query(query).await
175    }
176
177    async fn fetch_all<'q, U>(&mut self, query: &'q str) -> Result<Vec<U>>
178    where
179        U: Send + Unpin + for<'r> sqlx::FromRow<'r, sqlx::mysql::MySqlRow>,
180    {
181        self.fetch_all_query(query).await
182    }
183
184    async fn fetch_one<'q, U>(&mut self, query: &'q str) -> Result<U>
185    where
186        U: Send + Unpin + for<'r> sqlx::FromRow<'r, sqlx::mysql::MySqlRow>,
187    {
188        self.fetch_one_query(query).await
189    }
190
191    async fn fetch_optional<'q, U>(&mut self, query: &'q str) -> Result<Option<U>>
192    where
193        U: Send + Unpin + for<'r> sqlx::FromRow<'r, sqlx::mysql::MySqlRow>,
194    {
195        self.fetch_optional_query(query).await
196    }
197}
198
199/// 事务管理器
200///
201/// 提供了事务的自动提交和回滚功能。
202///
203/// # 示例
204///
205/// ```rust
206/// use baichun_framework_db::{DbPool, TransactionManager};
207///
208/// async fn example(pool: &DbPool) -> Result<()> {
209///     let tx = pool.begin().await?;
210///     let mut tm = TransactionManager::new(tx);
211///
212///     tm.execute(|tx| Box::pin(async move {
213///         tx.execute("INSERT INTO users (name) VALUES ('Alice')").await?;
214///         tx.execute("UPDATE users SET status = 'active' WHERE name = 'Alice'").await?;
215///         Ok(())
216///     })).await?;
217///
218///     Ok(())
219/// }
220/// ```
221pub struct TransactionManager<'c> {
222    tx: Option<Transaction<'c, MySql>>,
223}
224
225impl<'c> TransactionManager<'c> {
226    /// 创建新的事务管理器
227    pub fn new(tx: Transaction<'c, MySql>) -> Self {
228        Self { tx: Some(tx) }
229    }
230
231    /// 在事务中执行操作
232    ///
233    /// 如果操作成功,事务将被提交;如果操作失败,事务将被回滚。
234    pub async fn execute<F, T, E>(&mut self, f: F) -> Result<T>
235    where
236        F: for<'a> FnOnce(
237            &'a mut Transaction<'c, MySql>,
238        ) -> std::pin::Pin<
239            Box<dyn Future<Output = std::result::Result<T, E>> + Send + 'a>,
240        >,
241        E: Into<Error>,
242        T: Send,
243    {
244        let tx = self
245            .tx
246            .take()
247            .ok_or_else(|| Error::Transaction("Transaction already used".to_string()))?;
248        let mut tx = tx;
249
250        match f(&mut tx).await {
251            Ok(value) => {
252                tx.commit()
253                    .await
254                    .map_err(|e| Error::Transaction(e.to_string()))?;
255                Ok(value)
256            }
257            Err(e) => {
258                if let Err(e) = tx.rollback().await {
259                    return Err(Error::Transaction(e.to_string()));
260                }
261                Err(e.into())
262            }
263        }
264    }
265}