Skip to main content

db_cores/
lib.rs

1// #![allow(unused_variables)]
2// #![allow(unused_imports)]
3
4mod common;
5pub use common::*;
6
7pub mod ast;
8pub mod plugin;
9pub mod utlis;
10pub mod to_json;
11pub mod verify;
12pub mod query;
13pub mod builder;
14
15
16
17
18
19
20use regex::Regex;
21use serde::Deserialize;
22use serde::Serialize;
23use serde_json::Value as JsonValue;
24pub use sql_builder::SqlBuilder;
25
26#[cfg(any(feature = "postgres", feature = "mysql", feature = "sqlite"))]
27use std::{collections::HashMap, path::Path, sync::{Arc, LazyLock}};
28
29#[cfg(feature = "postgres")]
30use sqlx::{PgPool, Pool, Postgres};
31#[cfg(feature = "mysql")]
32use sqlx::{MySql, MySqlPool};
33#[cfg(feature = "sqlite")]
34use sqlx::{sqlite::SqliteConnectOptions, Sqlite, SqlitePool};
35
36pub use serde;
37pub use serde_json;
38#[cfg(any(feature = "postgres", feature = "mysql", feature = "sqlite"))]
39pub use sqlx;
40pub use db_proc_macro;
41
/// Connection pool for one of the supported database backends.
///
/// Each variant only exists when its matching cargo feature (`mysql`,
/// `postgres`, `sqlite`) is enabled. Cloning a `DbPool` clones the wrapped
/// sqlx pool value (derived `Clone`).
#[derive(Debug, Clone)]
#[cfg(any(feature = "postgres", feature = "mysql", feature = "sqlite"))]
pub enum DbPool {
    /// MySQL pool; requires the `mysql` feature.
    #[cfg(feature = "mysql")]
    MySql(MySqlPool),
    /// PostgreSQL pool; requires the `postgres` feature.
    #[cfg(feature = "postgres")]
    PgSql(PgPool),
    /// SQLite pool; requires the `sqlite` feature.
    #[cfg(feature = "sqlite")]
    Sqlite(SqlitePool),
}
53
/// Query result converted to JSON values.
///
/// Serialized with serde using the field names `data`, `columns` and `count`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
pub struct ToJsonResult {
    /// Result rows as JSON values (presumably one value per row — TODO confirm).
    pub data: Vec<JsonValue>,
    /// Column descriptions as JSON values (exact shape defined by the producer — verify).
    pub columns: Vec<JsonValue>,
    /// Row count (presumably the total count before pagination — TODO confirm).
    pub count: i64,
}
60
61
62
/// Parameters required to create a new connection, one variant per database type.
///
/// Uses serde's default (externally tagged) enum representation, since no
/// `#[serde(...)]` attributes are present.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub enum NewConnectParams {
    /// MySQL server connection.
    MySql {
        id: String,
        connect_name: String,
        host: String,
        port: i64,
        username: String,
        password: String,
        default_db_name: Option<String>,
        charset: Option<String>,
    },

    /// PostgreSQL server connection.
    Postgres {
        id: String,
        connect_name: String,
        host: String,
        port: i64,
        username: String,
        password: String,
        default_db_name: Option<String>,
        schema: Option<String>,
    },

    /// SQL Server connection.
    SqlServer {
        id: String,
        connect_name: String,
        host: String,
        port: i64,
        username: String,
        password: String,
        default_db_name: Option<String>,
        instance: Option<String>,
    },

    /// File-based database connection.
    FileDB {
        id: String,
        kind: DatabaseKind, // file-based database type (SQLite, DuckDB, etc.)
        connect_name: String,
        file_dir: String,
        // NOTE(review): presumably selects an in-memory database instead of
        // `file_dir` — TODO confirm against the consumer of this variant.
        is_memory: bool,
    },
}
107
108
// Type of the global connection-pool cache: pools keyed by connection key,
// behind a tokio async `RwLock` shared through an `Arc`.
#[cfg(any(feature = "postgres", feature = "mysql", feature = "sqlite"))]
type PoolCache = Arc<tokio::sync::RwLock<HashMap<String, DbPool>>>;
112
/// Column name reserved for an injected row-count column.
///
/// NOTE(review): the random-looking value presumably exists to avoid clashing
/// with real user column names — TODO confirm with the query-rewriting code.
pub const COUNT_COLUMN_NAME: &str = "r8Bz1ae9BxYqe";
114
115// pub static RE_SELECT: LazyLock<Regex> =
116//     LazyLock::new(|| Regex::new(r"(?i)(SELECT\s+.*?\s+FROM)").unwrap());
117
118// pub static RE_SELECT_FIELDS: LazyLock<Regex> =
119//     LazyLock::new(|| Regex::new(r"(?i)^SELECT\s+(.*?)\s+FROM").unwrap());
120
121// pub static RE_LIMIT_OFFSET: LazyLock<Regex> =
122//     LazyLock::new(|| Regex::new(r"(?i)(LIMIT\s+\d+(\s+OFFSET\s+\d+)?)").unwrap());
123
124// pub static RE_ORDER_BY: LazyLock<Regex> =
125//     LazyLock::new(|| Regex::new(r"(?i)\s*ORDER BY.*").unwrap());
126
127// // static RE_TABLE_NAME: LazyLock<Regex> =
128//     LazyLock::new(|| Regex::new(r"(?i)\bFROM\s+(\w+)").unwrap());
129
130// static RE_WHERE_CLAUSE: LazyLock<Regex> =
131//     LazyLock::new(|| Regex::new(r"(?i)\s+WHERE\s+(.*?)(?:\s+LIMIT|\s+OFFSET|$)").unwrap());
132
/// Process-wide cache of database pools, created empty on first access.
///
/// Keys are the connection keys built by the pool constructors in this module;
/// values are the pools created for them.
#[cfg(any(feature = "postgres", feature = "mysql", feature = "sqlite"))]
pub static DB_POOL_CACHE: LazyLock<PoolCache> = LazyLock::new(|| {
    let pools: HashMap<String, DbPool> = HashMap::new();
    Arc::new(tokio::sync::RwLock::new(pools))
});
136
/// Returns the pool cached under `connect_key`, creating and caching it on a miss.
///
/// `create_pool` is only invoked when the key is absent. Two tasks that miss at
/// the same time may both run `create_pool`. The first one to insert wins, and
/// every caller receives that same pool, so the cache never replaces a pool that
/// has already been handed out.
///
/// # Errors
/// Propagates any error returned by `create_pool`; nothing is cached in that case.
#[cfg(any(feature = "postgres", feature = "mysql", feature = "sqlite"))]
async fn get_or_create_pool<F, Fut>(connect_key: &str, create_pool: F) -> anyhow::Result<DbPool>
where
    F: FnOnce() -> Fut, // factory invoked only on a cache miss
    Fut: std::future::Future<Output = anyhow::Result<DbPool>>,
{
    // Fast path: the read guard is a temporary and is released at the end of
    // this statement.
    let cached = DB_POOL_CACHE.read().await.get(connect_key).cloned();
    if let Some(pool) = cached {
        return Ok(pool);
    }

    // Slow path: connect without holding any lock, so a slow connection attempt
    // does not block lookups of other keys.
    let pool = create_pool().await?;

    // Another task may have inserted a pool for this key while we were
    // connecting. Keep the existing entry so all callers share one pool; in
    // that case our freshly created pool is dropped here.
    let mut cache = DB_POOL_CACHE.write().await;
    let shared = cache
        .entry(connect_key.to_string())
        .or_insert(pool)
        .clone();
    Ok(shared)
}
157
/// Evicts the pool cached under `database_path` (the key used for SQLite pools).
///
/// Returns `true` if a pool was cached under that path and has been removed, or
/// `false` if nothing was cached. After a successful removal it waits 100 ms
/// before returning, intended to give the released file handle time to close.
///
/// NOTE(review): this only drops the cache's own `DbPool` value. Clones
/// previously returned to callers are not affected; sqlx pools are shared
/// handles, so those presumably keep the connection open until they are
/// dropped or closed — confirm if the file must be released deterministically.
#[cfg(any(feature = "postgres", feature = "mysql", feature = "sqlite"))]
pub async fn remove_pool_from_cache(database_path: impl AsRef<Path>) -> bool {
    use tokio::time::{sleep, Duration};
    // Same key format as `get_sqlite_pool`.
    let connect_key = database_path.as_ref().to_string_lossy().into_owned();

    // Take the pool out and release the write lock immediately. Holding the
    // guard across the sleep below would block every other cache lookup.
    let removed = DB_POOL_CACHE.write().await.remove(&connect_key);
    match removed {
        Some(pool) => {
            drop(pool); // release the cache's handle explicitly
            sleep(Duration::from_millis(100)).await; // give the handle time to close
            true
        }
        None => false,
    }
}
173
/// Returns the cached SQLite pool for `database_path`, opening one on a cache miss.
///
/// The database file is created if it does not exist (`create_if_missing`).
/// The cache key is the lossy UTF-8 rendering of the path, which is also the
/// key `remove_pool_from_cache` evicts.
///
/// # Errors
/// Returns an error if sqlx fails to open the database.
#[cfg(feature = "sqlite")]
async fn get_sqlite_pool(database_path: impl AsRef<Path>) -> anyhow::Result<DbPool> {
    let path: &Path = database_path.as_ref();
    let connect_key = path.to_string_lossy().into_owned();
    get_or_create_pool(&connect_key, || async {
        let options = SqliteConnectOptions::new()
            .filename(path)
            .create_if_missing(true);
        // `connect_with` takes the options by value and they are not reused,
        // so no clone is needed.
        let pool = SqlitePool::connect_with(options).await?;
        Ok(DbPool::Sqlite(pool))
    })
    .await
}
187
/// Percent-encodes `raw` for use as the user or password part of a connection URL.
///
/// Every byte other than an RFC 3986 "unreserved" character (`A-Z a-z 0-9 - . _ ~`)
/// is written as `%XX`. This keeps characters such as `@`, `:`, `/`, `#` or `%`
/// in credentials from being read as URL delimiters.
#[cfg(any(feature = "mysql", feature = "postgres"))]
fn encode_url_userinfo(raw: &str) -> String {
    use std::fmt::Write as _;

    let mut encoded = String::with_capacity(raw.len());
    for byte in raw.bytes() {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~') {
            encoded.push(char::from(byte));
        } else {
            // Writing into a String cannot fail.
            let _ = write!(encoded, "%{byte:02X}");
        }
    }
    encoded
}

/// Returns the cached MySQL/PostgreSQL pool for `op`, connecting on a cache miss.
///
/// The cache key combines host, port, database name, username and password, so
/// changing any of them yields a separate pool.
///
/// # Errors
/// Returns an error if `op.kind` is not a server database enabled by the crate
/// features, or if sqlx fails to connect.
#[cfg(any(feature = "mysql", feature = "postgres"))]
async fn get_server_pool(op: &DbConnect) -> anyhow::Result<DbPool> {
    let connect_key = format!(
        "{}_{}_{}_{}_{}",
        op.host, op.port, op.db_name, op.username, op.password
    );
    get_or_create_pool(&connect_key, || async {
        // Credentials are percent-encoded so that characters like `@` or `/`
        // in a password cannot corrupt the connection URL.
        let user = encode_url_userinfo(&op.username);
        let password = encode_url_userinfo(&op.password);
        let db_pool = match op.kind {
            #[cfg(feature = "mysql")]
            DatabaseKind::MySql => {
                let connection_string = format!(
                    "mysql://{user}:{password}@{}:{}/{}",
                    op.host, op.port, op.db_name
                );
                let pool = MySqlPool::connect(&connection_string).await?;
                DbPool::MySql(pool)
            }
            #[cfg(feature = "postgres")]
            DatabaseKind::Postgres => {
                let connection_string = format!(
                    "postgres://{user}:{password}@{}:{}/{}",
                    op.host, op.port, op.db_name
                );
                let pool: Pool<Postgres> = PgPool::connect(&connection_string).await?;
                DbPool::PgSql(pool)
            }
            _ => return Err(anyhow::anyhow!("Unsupported database type")),
        };
        Ok(db_pool)
    })
    .await
}
220
/// Returns the sqlx SQLite pool for `database_path`, creating the file on first use.
///
/// # Errors
/// Fails if the pool cannot be opened, or if the cache holds a non-SQLite pool
/// under the same key.
#[cfg(feature = "sqlite")]
pub async fn sqlite_pool(database_path: impl AsRef<Path>) -> anyhow::Result<sqlx::Pool<Sqlite>> {
    match get_sqlite_pool(database_path).await? {
        DbPool::Sqlite(inner) => Ok(inner),
        _ => Err(anyhow::anyhow!("Unsupported sqlite_pool database type")),
    }
}
229
/// Returns the sqlx MySQL pool described by `op`, connecting on first use.
///
/// # Errors
/// Fails if the connection cannot be established, or if the cached pool for
/// this key is not a MySQL pool.
#[cfg(feature = "mysql")]
pub async fn mysql_pool(op: &DbConnect) -> anyhow::Result<sqlx::Pool<MySql>> {
    match get_server_pool(op).await? {
        DbPool::MySql(mysql) => Ok(mysql),
        _ => Err(anyhow::anyhow!("Invalid database type for MySql")),
    }
}
238
/// Returns the sqlx PostgreSQL pool described by `op`, connecting on first use.
///
/// # Errors
/// Fails if the connection cannot be established, or if the cached pool for
/// this key is not a PostgreSQL pool.
#[cfg(feature = "postgres")]
pub async fn postgres_pool(op: &DbConnect) -> anyhow::Result<Pool<Postgres>> {
    match get_server_pool(op).await? {
        DbPool::PgSql(pg) => Ok(pg),
        // Same error as `anyhow::Error::msg`: a literal `anyhow!` expands to it.
        _ => Err(anyhow::anyhow!("Invalid database type for PostgreSQL")),
    }
}
247
248
249
#[cfg(test)]
#[cfg(any(feature = "postgres", feature = "mysql", feature = "sqlite"))]
mod tests {
    use super::*;
    use crate::define_model; // bring the `define_model!` macro into scope
    // Defines a `Code2` model, presumably mapped to the "TABLE_CODE" table; the
    // generated items depend on `define_model!`, which is not visible here.
    define_model!(Code2 , "TABLE_CODE",
        id: String,
        name:String,
        kind: String, // entry kind: api (API request), code (code snippet), commend (command)
        content: String,
        params: Option<serde_json::Value>,
        req: Option<String>,
        res: Option<String>,
        created_at: Option<i64>,
        updated_at : Option<i64>,
        is_active: bool // bool (last field, no trailing comma)
    );


    /// Placeholder test.
    ///
    /// NOTE(review): the call under test is commented out, so this only checks
    /// that the module (including the `Code2` model) compiles; `sql` is unused.
    #[test]
    pub fn run() {
        let sql = "SELECT id FROM sample WHERE a=0 AND B>1 ORDER BY id LIMIT 1000 OFFSET 0";
        // let r = select_add_count(sql);
        // println!("{}", r);
    }
}
276