Skip to main content

yang_db/mysql/
database.rs

1use crate::error::DbError;
2use crate::mysql::query_builder::QueryBuilder;
3use crate::mysql::transaction::Transaction;
4use sqlx::mysql::MySqlPool;
5
/// Database configuration.
///
/// Holds the parameters used to set up the database connection pool.
///
/// # Examples
///
/// ```rust
/// use yang_db::DatabaseConfig;
///
/// // Use the default configuration
/// let config = DatabaseConfig::default();
///
/// // Custom configuration
/// let config = DatabaseConfig {
///     max_connections: 20,
///     connect_timeout: 10,
///     idle_timeout: 300,
///     enable_logging: true,
/// };
/// ```
#[derive(Clone, Debug)]
pub struct DatabaseConfig {
    /// Maximum number of connections.
    pub max_connections: u32,
    /// Connection timeout, in seconds.
    pub connect_timeout: u64,
    /// Idle-connection timeout, in seconds.
    pub idle_timeout: u64,
    /// Whether logging is enabled.
    pub enable_logging: bool,
}
37
38impl Default for DatabaseConfig {
39    fn default() -> Self {
40        Self {
41            max_connections: 10,
42            connect_timeout: 30,
43            idle_timeout: 600,
44            enable_logging: false,
45        }
46    }
47}
48
49/// 数据库连接管理器
50///
51/// 管理 MySQL 数据库连接池,提供查询构建和执行的入口点
52///
53/// # 示例
54///
55/// ```rust,no_run
56/// use yang_db::Database;
57///
58/// #[tokio::main]
59/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
60///     // 连接数据库
61///     let db = Database::connect("mysql://user:password@localhost:3306/test").await?;
62///     
63///     // 使用查询构建器
64///     let builder = db.table("users")
65///         .field("id")
66///         .field("name");
67///     
68///     // 执行查询(需要实现 select 方法)
69///     // let users = builder.select::<User>().await?;
70///     
71///     Ok(())
72/// }
73/// ```
74pub struct Database {
75    pool: MySqlPool,
76    config: DatabaseConfig,
77}
78
79impl Database {
80    /// 创建数据库连接
81    ///
82    /// # 参数
83    /// - url: 数据库连接字符串,格式:mysql://user:password@host:port/database
84    ///
85    /// # 返回
86    /// - Ok(Database): 成功创建的数据库实例
87    /// - Err(DbError): 连接失败错误
88    pub async fn connect(url: &str) -> Result<Self, DbError> {
89        Self::connect_with_config(url, DatabaseConfig::default()).await
90    }
91
92    /// 使用自定义配置创建数据库连接
93    pub async fn connect_with_config(url: &str, config: DatabaseConfig) -> Result<Self, DbError> {
94        use sqlx::mysql::MySqlPoolOptions;
95        use std::time::Duration;
96
97        // 使用配置参数创建连接池
98        let pool = MySqlPoolOptions::new()
99            .max_connections(config.max_connections)
100            .acquire_timeout(Duration::from_secs(config.connect_timeout))
101            .idle_timeout(Duration::from_secs(config.idle_timeout))
102            .connect(url)
103            .await?;
104
105        Ok(Self { pool, config })
106    }
107
108    /// 选择表,返回查询构建器
109    pub fn table(&self, table_name: &str) -> QueryBuilder<'_> {
110        QueryBuilder::new(&self.pool, table_name, self.config.enable_logging)
111    }
112
113    /// 执行原生 SELECT 查询
114    pub async fn query<T>(&self, sql: &str) -> Result<Vec<T>, DbError>
115    where
116        T: for<'r> sqlx::FromRow<'r, sqlx::mysql::MySqlRow> + Send + Unpin,
117    {
118        if self.config.enable_logging {
119            log::debug!("执行原生查询: {}", sql);
120        }
121
122        let rows = sqlx::query_as::<_, T>(sql).fetch_all(&self.pool).await?;
123
124        Ok(rows)
125    }
126
127    /// 执行原生 INSERT/UPDATE/DELETE 查询
128    pub async fn execute(&self, sql: &str) -> Result<u64, DbError> {
129        if self.config.enable_logging {
130            log::debug!("执行原生语句: {}", sql);
131        }
132
133        let result = sqlx::query(sql).execute(&self.pool).await?;
134
135        Ok(result.rows_affected())
136    }
137
138    /// 开始事务
139    pub async fn transaction(&self) -> Result<Transaction, DbError> {
140        let tx = self.pool.begin().await?;
141        Ok(Transaction::new(tx, self.config.enable_logging))
142    }
143
144    /// 初始化数据库(执行 SQL 脚本)
145    pub async fn init(&self, sql_script: &str) -> Result<(), DbError> {
146        // 按分号分割多个 SQL 语句
147        let statements: Vec<&str> = sql_script
148            .split(';')
149            .map(|s| s.trim())
150            .filter(|s| !s.is_empty())
151            .collect();
152
153        for statement in statements {
154            self.execute(statement).await?;
155        }
156
157        Ok(())
158    }
159
160    /// 创建表
161    pub async fn create_table(&self, create_sql: &str) -> Result<(), DbError> {
162        self.execute(create_sql).await?;
163        Ok(())
164    }
165
166    /// 删除表
167    pub async fn drop_table(&self, table_name: &str) -> Result<(), DbError> {
168        let sql = format!("DROP TABLE IF EXISTS `{}`", table_name);
169        self.execute(&sql).await?;
170        Ok(())
171    }
172
173    /// 检查表是否存在
174    pub async fn table_exists(&self, table_name: &str) -> Result<bool, DbError> {
175        let sql = format!(
176            "SELECT COUNT(*) as count FROM information_schema.tables \
177             WHERE table_schema = DATABASE() AND table_name = '{}'",
178            table_name
179        );
180
181        let row: (i64,) = sqlx::query_as(&sql).fetch_one(&self.pool).await?;
182
183        Ok(row.0 > 0)
184    }
185}