Skip to main content

yang_db/mysql/
database.rs

1use crate::error::DbError;
2use crate::mysql::query_builder::QueryBuilder;
3use crate::mysql::transaction::Transaction;
4use sqlx::mysql::MySqlPool;
5
6/// 数据库配置
7///
8/// 用于配置数据库连接池的参数
9///
10/// # 示例
11///
12/// ```rust
13/// use yang_db::DatabaseConfig;
14///
15/// // 使用默认配置
16/// let config = DatabaseConfig::default();
17///
18/// // 自定义配置
19/// let config = DatabaseConfig {
20///     max_connections: 20,
21///     connect_timeout: 10,
22///     idle_timeout: 300,
23///     enable_logging: true,
24/// };
25/// ```
26#[derive(Debug, Clone)]
27pub struct DatabaseConfig {
28    /// 最大连接数
29    pub max_connections: u32,
30    /// 连接超时时间(秒)
31    pub connect_timeout: u64,
32    /// 空闲连接超时时间(秒)
33    pub idle_timeout: u64,
34    /// 是否启用日志
35    pub enable_logging: bool,
36}
37
38impl Default for DatabaseConfig {
39    fn default() -> Self {
40        Self {
41            max_connections: 10,
42            connect_timeout: 30,
43            idle_timeout: 600,
44            enable_logging: false,
45        }
46    }
47}
48
49/// 数据库连接管理器
50///
51/// 管理 MySQL 数据库连接池,提供查询构建和执行的入口点
52///
53/// # 示例
54///
55/// ```rust,no_run
56/// use yang_db::Database;
57///
58/// #[tokio::main]
59/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
60///     // 连接数据库
61///     let db = Database::connect("mysql://user:password@localhost:3306/test").await?;
62///     
63///     // 使用查询构建器
64///     let builder = db.table("users")
65///         .field("id")
66///         .field("name");
67///     
68///     // 执行查询(需要实现 select 方法)
69///     // let users = builder.select::<User>().await?;
70///     
71///     Ok(())
72/// }
73/// ```
74pub struct Database {
75    pool: MySqlPool,
76    config: DatabaseConfig,
77}
78
79impl Database {
80    /// 创建数据库连接
81    ///
82    /// # 参数
83    /// - url: 数据库连接字符串,格式:mysql://user:password@host:port/database
84    ///
85    /// # 返回
86    /// - Ok(Database): 成功创建的数据库实例
87    /// - Err(DbError): 连接失败错误
88    pub async fn connect(url: &str) -> Result<Self, DbError> {
89        Self::connect_with_config(url, DatabaseConfig::default()).await
90    }
91
92    /// 使用自定义配置创建数据库连接
93    pub async fn connect_with_config(url: &str, config: DatabaseConfig) -> Result<Self, DbError> {
94        use sqlx::mysql::MySqlPoolOptions;
95        use std::time::Duration;
96
97        // 使用配置参数创建连接池
98        let pool = MySqlPoolOptions::new()
99            .max_connections(config.max_connections)
100            .acquire_timeout(Duration::from_secs(config.connect_timeout))
101            .idle_timeout(Duration::from_secs(config.idle_timeout))
102            .connect(url)
103            .await?;
104
105        Ok(Self { pool, config })
106    }
107
108    /// 选择表,返回查询构建器
109    pub fn table(&self, table_name: &str) -> QueryBuilder<'_> {
110        QueryBuilder::new(&self.pool, table_name, self.config.enable_logging)
111    }
112
113    /// 执行原生 SELECT 查询
114    pub async fn query<T>(&self, sql: &str) -> Result<Vec<T>, DbError>
115    where
116        T: for<'r> sqlx::FromRow<'r, sqlx::mysql::MySqlRow> + Send + Unpin,
117    {
118        if self.config.enable_logging {
119            log::debug!("执行原生查询: {}", sql);
120        }
121
122        let rows = sqlx::query_as::<_, T>(sql).fetch_all(&self.pool).await?;
123
124        Ok(rows)
125    }
126
127    /// 执行原生 INSERT/UPDATE/DELETE 查询
128    pub async fn execute(&self, sql: &str) -> Result<u64, DbError> {
129        if self.config.enable_logging {
130            log::debug!("执行原生语句: {}", sql);
131        }
132
133        let result = sqlx::query(sql).execute(&self.pool).await?;
134
135        Ok(result.rows_affected())
136    }
137
138    /// 开始事务
139    pub async fn transaction(&self) -> Result<Transaction, DbError> {
140        let tx = self.pool.begin().await?;
141        Ok(Transaction::new(tx, self.config.enable_logging))
142    }
143
144    /// 初始化数据库(执行 SQL 脚本)
145    pub async fn init(&self, sql_script: &str) -> Result<(), DbError> {
146        // 按分号分割多个 SQL 语句
147        let statements: Vec<&str> = sql_script
148            .split(';')
149            .map(|s| s.trim())
150            .filter(|s| !s.is_empty())
151            .collect();
152
153        for statement in statements {
154            self.execute(statement).await?;
155        }
156
157        Ok(())
158    }
159
160    /// 创建表
161    pub async fn create_table(&self, create_sql: &str) -> Result<(), DbError> {
162        self.execute(create_sql).await?;
163        Ok(())
164    }
165
166    /// 删除表
167    pub async fn drop_table(&self, table_name: &str) -> Result<(), DbError> {
168        let sql = format!("DROP TABLE IF EXISTS `{}`", table_name);
169        self.execute(&sql).await?;
170        Ok(())
171    }
172
173    /// 检查表是否存在
174    pub async fn table_exists(&self, table_name: &str) -> Result<bool, DbError> {
175        let sql = format!(
176            "SELECT COUNT(*) as count FROM information_schema.tables \
177             WHERE table_schema = DATABASE() AND table_name = '{}'",
178            table_name
179        );
180
181        let row: (i64,) = sqlx::query_as(&sql).fetch_one(&self.pool).await?;
182
183        Ok(row.0 > 0)
184    }
185
186    /// 执行带参数的原生 SELECT 查询(参数化查询,防止 SQL 注入)
187    ///
188    /// # 参数
189    /// - sql: SQL 查询语句,使用 `?` 作为参数占位符
190    /// - params: 参数列表,使用 `serde_json::Value` 类型
191    ///
192    /// # 返回
193    /// - Ok(Vec<T>): 查询结果列表
194    /// - Err(DbError): 查询失败错误
195    ///
196    /// # 示例
197    ///
198    /// ```rust,no_run
199    /// use yang_db::Database;
200    /// use serde_json::json;
201    ///
202    /// # async fn example() -> Result<(), yang_db::DbError> {
203    /// let db = Database::connect("mysql://root:password@localhost/test").await?;
204    /// let params = vec![json!("admin"), json!(1)];
205    /// // let users: Vec<User> = db.query_with_params("SELECT * FROM users WHERE role = ? AND status = ?", params).await?;
206    /// # Ok(())
207    /// # }
208    /// ```
209    pub async fn query_with_params<T>(
210        &self,
211        sql: &str,
212        params: Vec<serde_json::Value>,
213    ) -> Result<Vec<T>, DbError>
214    where
215        T: for<'r> sqlx::FromRow<'r, sqlx::mysql::MySqlRow> + Send + Unpin,
216    {
217        if self.config.enable_logging {
218            log::debug!("执行参数化查询: {}, 参数数量: {}", sql, params.len());
219        }
220
221        // 构建查询并逐一绑定参数
222        let mut query = sqlx::query_as::<_, T>(sql);
223        for param in &params {
224            query = bind_json_param_as(query, param);
225        }
226
227        let rows = query.fetch_all(&self.pool).await?;
228        Ok(rows)
229    }
230
231    /// 执行带参数的原生 INSERT/UPDATE/DELETE 语句(参数化查询,防止 SQL 注入)
232    ///
233    /// # 参数
234    /// - sql: SQL 语句,使用 `?` 作为参数占位符
235    /// - params: 参数列表,使用 `serde_json::Value` 类型
236    ///
237    /// # 返回
238    /// - Ok(u64): 受影响的行数
239    /// - Err(DbError): 执行失败错误
240    ///
241    /// # 示例
242    ///
243    /// ```rust,no_run
244    /// use yang_db::Database;
245    /// use serde_json::json;
246    ///
247    /// # async fn example() -> Result<(), yang_db::DbError> {
248    /// let db = Database::connect("mysql://root:password@localhost/test").await?;
249    /// let params = vec![json!("张三"), json!("zhangsan@example.com")];
250    /// let rows = db.execute_with_params("INSERT INTO users (name, email) VALUES (?, ?)", params).await?;
251    /// # Ok(())
252    /// # }
253    /// ```
254    pub async fn execute_with_params(
255        &self,
256        sql: &str,
257        params: Vec<serde_json::Value>,
258    ) -> Result<u64, DbError> {
259        if self.config.enable_logging {
260            log::debug!("执行参数化语句: {}, 参数数量: {}", sql, params.len());
261        }
262
263        // 构建查询并逐一绑定参数
264        let mut query = sqlx::query(sql);
265        for param in &params {
266            query = bind_json_param(query, param);
267        }
268
269        let result = query.execute(&self.pool).await?;
270        Ok(result.rows_affected())
271    }
272}
273
274/// 将 `serde_json::Value` 参数绑定到 `query_as` 查询
275///
276/// # 参数
277/// - query: sqlx query_as 查询对象
278/// - param: JSON 参数值
279///
280/// # 返回
281/// - 绑定参数后的查询对象
282fn bind_json_param_as<'q, T>(
283    query: sqlx::query::QueryAs<'q, sqlx::MySql, T, sqlx::mysql::MySqlArguments>,
284    param: &serde_json::Value,
285) -> sqlx::query::QueryAs<'q, sqlx::MySql, T, sqlx::mysql::MySqlArguments>
286where
287    T: for<'r> sqlx::FromRow<'r, sqlx::mysql::MySqlRow>,
288{
289    match param {
290        // 字符串类型直接绑定
291        serde_json::Value::String(s) => query.bind(s.clone()),
292        // 数字类型转为 i64 绑定
293        serde_json::Value::Number(n) => {
294            if let Some(i) = n.as_i64() {
295                query.bind(i)
296            } else if let Some(f) = n.as_f64() {
297                // 浮点数转为字符串绑定,避免精度丢失
298                query.bind(f.to_string())
299            } else {
300                query.bind(Option::<String>::None)
301            }
302        }
303        // 布尔类型绑定
304        serde_json::Value::Bool(b) => query.bind(*b),
305        // NULL 类型绑定为 None
306        serde_json::Value::Null => query.bind(Option::<String>::None),
307        // 数组和对象类型序列化为 JSON 字符串绑定
308        other => query.bind(other.to_string()),
309    }
310}
311
312/// 将 `serde_json::Value` 参数绑定到执行查询
313///
314/// # 参数
315/// - query: sqlx 执行查询对象
316/// - param: JSON 参数值
317///
318/// # 返回
319/// - 绑定参数后的查询对象
320fn bind_json_param<'q>(
321    query: sqlx::query::Query<'q, sqlx::MySql, sqlx::mysql::MySqlArguments>,
322    param: &serde_json::Value,
323) -> sqlx::query::Query<'q, sqlx::MySql, sqlx::mysql::MySqlArguments> {
324    match param {
325        // 字符串类型直接绑定
326        serde_json::Value::String(s) => query.bind(s.clone()),
327        // 数字类型转为 i64 绑定
328        serde_json::Value::Number(n) => {
329            if let Some(i) = n.as_i64() {
330                query.bind(i)
331            } else if let Some(f) = n.as_f64() {
332                // 浮点数转为字符串绑定,避免精度丢失
333                query.bind(f.to_string())
334            } else {
335                query.bind(Option::<String>::None)
336            }
337        }
338        // 布尔类型绑定
339        serde_json::Value::Bool(b) => query.bind(*b),
340        // NULL 类型绑定为 None
341        serde_json::Value::Null => query.bind(Option::<String>::None),
342        // 数组和对象类型序列化为 JSON 字符串绑定
343        other => query.bind(other.to_string()),
344    }
345}