1use sqlx::{PgPool, MySqlPool, SqlitePool, AnyPool, Column, Row as SqlxRow};
6use crate::{Row, DbResult, DbError, IdKind};
7use alun_core::PageQuery;
8use serde_json::{Value, Number};
9
/// Connection pool for one of the supported database backends.
///
/// [`Db`] holds exactly one of these and matches on the variant to route each
/// call to the corresponding backend-specific helper.
#[derive(Clone)]
pub enum DbPool {
    /// PostgreSQL pool.
    Postgres(PgPool),
    /// MySQL pool.
    Mysql(MySqlPool),
    /// SQLite pool.
    Sqlite(SqlitePool),
    /// sqlx `AnyPool`.
    Any(AnyPool),
}
24
/// Database handle exposing raw-SQL queries and simple row-level CRUD helpers.
///
/// Wraps a [`DbPool`]; every operation dispatches on the pool variant.
#[derive(Clone)]
pub struct Db {
    pool: DbPool,
}
41
42macro_rules! impl_db_ops {
46 ($pool_ty:ty, $db_mod:ident) => {
47 paste::paste! {
48 fn [<typed_row_to_row_ $db_mod:snake>](
49 row: &<sqlx::$db_mod as sqlx::Database>::Row
50 ) -> Row {
51 let mut r = Row::default();
52 for col in row.columns() {
53 let name = col.name().to_string();
54 let idx: usize = col.ordinal();
55 if let Ok(v) = row.try_get::<i64, usize>(idx) {
56 r.data.insert(name, Value::Number(v.into()));
57 } else if let Ok(v) = row.try_get::<i32, usize>(idx) {
58 r.data.insert(name, Value::Number((v as i64).into()));
59 } else if let Ok(v) = row.try_get::<i16, usize>(idx) {
60 r.data.insert(name, Value::Number((v as i64).into()));
61 } else if let Ok(v) = row.try_get::<String, usize>(idx) {
62 r.data.insert(name, Value::String(v));
63 } else if let Ok(v) = row.try_get::<sqlx::types::Uuid, usize>(idx) {
64 r.data.insert(name, Value::String(v.to_string()));
65 } else if let Ok(v) = row.try_get::<f64, usize>(idx) {
66 if let Some(n) = Number::from_f64(v) {
67 r.data.insert(name, Value::Number(n));
68 }
69 } else if let Ok(v) = row.try_get::<bool, usize>(idx) {
70 r.data.insert(name, Value::Bool(v));
71 }
72 }
73 r.mark_all_changed();
74 r
75 }
76
77 async fn [<query_one_ $pool_ty:snake>](
78 pool: &$pool_ty, sql: &str, params: &[&str],
79 ) -> DbResult<Option<Row>> {
80 let mut q = sqlx::query::<sqlx::$db_mod>(sql);
81 for p in params { q = q.bind(*p); }
82 Ok(q.fetch_optional(pool).await?.as_ref()
83 .map([<typed_row_to_row_ $db_mod:snake>]))
84 }
85
86 async fn [<query_all_ $pool_ty:snake>](
87 pool: &$pool_ty, sql: &str, params: &[&str],
88 ) -> DbResult<Vec<Row>> {
89 let mut q = sqlx::query::<sqlx::$db_mod>(sql);
90 for p in params { q = q.bind(*p); }
91 let rows = q.fetch_all(pool).await?;
92 Ok(rows.iter().map([<typed_row_to_row_ $db_mod:snake>]).collect())
93 }
94
95 async fn [<count_ $pool_ty:snake>](
96 pool: &$pool_ty, sql: &str, params: &[&str],
97 ) -> DbResult<u64> {
98 let mut q = sqlx::query_scalar::<sqlx::$db_mod, i64>(sql);
99 for p in params { q = q.bind(*p); }
100 Ok(q.fetch_optional(pool).await?.unwrap_or(0) as u64)
101 }
102
103 async fn [<execute_ $pool_ty:snake>](
104 pool: &$pool_ty, sql: &str, params: &[&str],
105 ) -> DbResult<u64> {
106 let mut q = sqlx::query::<sqlx::$db_mod>(sql);
107 for p in params { q = q.bind(*p); }
108 q.execute(pool).await.map_err(DbError::from).map(|r| r.rows_affected())
109 }
110 }
111 };
112}
113
114impl_db_ops!(PgPool, Postgres);
115impl_db_ops!(MySqlPool, MySql);
116impl_db_ops!(SqlitePool, Sqlite);
117
118async fn query_one_any(pool: &AnyPool, sql: &str, params: &[&str]) -> DbResult<Option<Row>> {
119 let mut q = sqlx::query(sql);
120 for p in params { q = q.bind(*p); }
121 Ok(q.fetch_optional(pool).await?.as_ref().map(typed_row_to_row_any))
122}
123
124async fn query_all_any(pool: &AnyPool, sql: &str, params: &[&str]) -> DbResult<Vec<Row>> {
125 let mut q = sqlx::query(sql);
126 for p in params { q = q.bind(*p); }
127 let rows = q.fetch_all(pool).await?;
128 Ok(rows.iter().map(typed_row_to_row_any).collect())
129}
130
131fn typed_row_to_row_any(row: &sqlx::any::AnyRow) -> Row {
132 let mut r = Row::default();
133 for col in row.columns() {
134 let name = col.name().to_string();
135 let idx: usize = col.ordinal();
136 if let Ok(v) = row.try_get::<i64, usize>(idx) {
137 r.data.insert(name, Value::Number(v.into()));
138 } else if let Ok(v) = row.try_get::<i32, usize>(idx) {
139 r.data.insert(name, Value::Number((v as i64).into()));
140 } else if let Ok(v) = row.try_get::<String, usize>(idx) {
141 r.data.insert(name, Value::String(v));
142 } else if let Ok(v) = row.try_get::<f64, usize>(idx) {
143 if let Some(n) = Number::from_f64(v) {
144 r.data.insert(name, Value::Number(n));
145 }
146 } else if let Ok(v) = row.try_get::<bool, usize>(idx) {
147 r.data.insert(name, Value::Bool(v));
148 }
149 }
150 r.mark_all_changed();
151 r
152}
153
154async fn count_any(pool: &AnyPool, sql: &str, params: &[&str]) -> DbResult<u64> {
155 let mut q = sqlx::query_scalar::<sqlx::Any, i64>(sql);
156 for p in params { q = q.bind(*p); }
157 Ok(q.fetch_optional(pool).await?.unwrap_or(0) as u64)
158}
159
160async fn execute_any(pool: &AnyPool, sql: &str, params: &[&str]) -> DbResult<u64> {
161 let mut q = sqlx::query(sql);
162 for p in params { q = q.bind(*p); }
163 Ok(q.execute(pool).await.map_err(DbError::from)?.rows_affected())
164}
165
166impl Db {
167 pub fn postgres(pool: PgPool) -> Self { Self { pool: DbPool::Postgres(pool) } }
169 pub fn mysql(pool: MySqlPool) -> Self { Self { pool: DbPool::Mysql(pool) } }
171 pub fn sqlite(pool: SqlitePool) -> Self { Self { pool: DbPool::Sqlite(pool) } }
173
174 pub fn pg_pool(&self) -> &PgPool { match &self.pool { DbPool::Postgres(p) => p, _ => panic!("不是 PG"), } }
176 pub fn mysql_pool(&self) -> &MySqlPool { match &self.pool { DbPool::Mysql(p) => p, _ => panic!("不是 MySQL"), } }
178 pub fn sqlite_pool(&self) -> &SqlitePool { match &self.pool { DbPool::Sqlite(p) => p, _ => panic!("不是 SQLite"), } }
180
181 pub async fn find_by_id(&self, table: &str, id: impl Into<serde_json::Value>) -> DbResult<Option<Row>> {
198 let value: serde_json::Value = id.into();
199 let pk = "id";
200 let id_str = value_to_string(&value);
201 let sql = format!("SELECT * FROM {} WHERE {}=$1{}", table, pk, id_cast(&value));
202 let params = vec![id_str.as_str()];
203 self.query_one(&sql, ¶ms).await
204 }
205
206 pub async fn query_one(&self, sql: &str, params: &[&str]) -> DbResult<Option<Row>> {
210 match &self.pool {
211 DbPool::Postgres(pool) => query_one_pg_pool(pool, sql, params).await,
212 DbPool::Mysql(pool) => query_one_my_sql_pool(pool, sql, params).await,
213 DbPool::Sqlite(pool) => query_one_sqlite_pool(pool, sql, params).await,
214 DbPool::Any(pool) => query_one_any(pool, sql, params).await,
215 }
216 }
217
218 pub async fn query(&self, sql: &str, params: &[&str]) -> DbResult<Vec<Row>> {
220 match &self.pool {
221 DbPool::Postgres(pool) => query_all_pg_pool(pool, sql, params).await,
222 DbPool::Mysql(pool) => query_all_my_sql_pool(pool, sql, params).await,
223 DbPool::Sqlite(pool) => query_all_sqlite_pool(pool, sql, params).await,
224 DbPool::Any(pool) => query_all_any(pool, sql, params).await,
225 }
226 }
227
228 pub async fn query_page(&self, sql: &str, params: &[&str], page: &PageQuery) -> DbResult<(Vec<Row>, u64)> {
232 let count_sql = format!("SELECT COUNT(*) as cnt FROM ({}) AS _count_sub", sql);
233 let total = self.count(&count_sql, params).await?;
234 let page_sql = format!("{} LIMIT {} OFFSET {}", sql, page.limit(), page.offset());
235 let rows = self.query(&page_sql, params).await?;
236 Ok((rows, total))
237 }
238
239 pub async fn count(&self, sql: &str, params: &[&str]) -> DbResult<u64> {
241 match &self.pool {
242 DbPool::Postgres(pool) => count_pg_pool(pool, sql, params).await,
243 DbPool::Mysql(pool) => count_my_sql_pool(pool, sql, params).await,
244 DbPool::Sqlite(pool) => count_sqlite_pool(pool, sql, params).await,
245 DbPool::Any(pool) => count_any(pool, sql, params).await,
246 }
247 }
248
249 pub async fn insert(&self, row: &Row) -> DbResult<Row> {
261 let table = row.table.as_deref().ok_or(DbError::Argument("Row 缺少表名".into()))?;
262 let columns: Vec<&String> = row.changes.iter().collect();
263 if columns.is_empty() { return Err(DbError::Argument("没有变更的字段".into())); }
264
265 let placeholders: Vec<String> = columns.iter().enumerate().map(|(i, c)| {
266 let cast = row.data.get(*c).map(|v| value_cast(v)).unwrap_or("");
267 format!("${}{}", i + 1, cast)
268 }).collect();
269 let col_str = columns.iter().map(|c| c.as_str()).collect::<Vec<_>>().join(", ");
270 let values: Vec<String> = columns.iter()
271 .filter_map(|c| row.data.get(*c)).map(value_to_string).collect();
272 let val_refs: Vec<&str> = values.iter().map(|s| s.as_str()).collect();
273
274 if matches!(&self.pool, DbPool::Postgres(_)) {
275 let sql = format!("INSERT INTO {} ({}) VALUES ({}) RETURNING *", table, col_str, placeholders.join(", "));
276 self.query_one(&sql, &val_refs).await?.ok_or_else(|| DbError::Other("INSERT 返回空".into()))
277 } else {
278 let sql = format!("INSERT INTO {} ({}) VALUES ({})", table, col_str, placeholders.join(", "));
279 self.execute(&sql, &val_refs).await?;
280 let pk_val = row.data.get("id");
281 match pk_val {
282 Some(v) => self.find_by_id(table, v.clone()).await?.ok_or(DbError::Other("INSERT 后查不到".into())),
283 None => Err(DbError::Argument("非 PG 数据库需 Row 含主键".into())),
284 }
285 }
286 }
287
288 pub async fn batch_insert(&self, rows: &[Row]) -> DbResult<u64> {
292 if rows.is_empty() { return Ok(0); }
293 let table = rows[0].table.as_deref().ok_or(DbError::Argument("Row 缺少表名".into()))?;
294 let columns: Vec<&String> = rows[0].changes.iter().collect();
295 if columns.is_empty() { return Err(DbError::Argument("没有变更的字段".into())); }
296
297 let col_names = columns.iter().map(|c| c.as_str()).collect::<Vec<_>>().join(", ");
298 let mut all_params: Vec<String> = Vec::new();
299 let mut groups: Vec<String> = Vec::new();
300 for (ri, row) in rows.iter().enumerate() {
301 let offset = ri * columns.len();
302 let ph: Vec<String> = columns.iter().enumerate().map(|(ci, c)| {
303 let cast = row.data.get(*c).map(|v| value_cast(v)).unwrap_or("");
304 format!("${}{}", offset + ci + 1, cast)
305 }).collect();
306 groups.push(format!("({})", ph.join(", ")));
307 for c in &columns {
308 all_params.push(row.data.get(*c).map(value_to_string).unwrap_or_default());
309 }
310 }
311 let sql = format!("INSERT INTO {} ({}) VALUES {}", table, col_names, groups.join(", "));
312 let val_refs: Vec<&str> = all_params.iter().map(|s| s.as_str()).collect();
313 self.execute(&sql, &val_refs).await
314 }
315
316 pub async fn update(&self, row: &Row) -> DbResult<Option<Row>> {
321 let table = row.table.as_deref().ok_or(DbError::Argument("Row 缺少表名".into()))?;
322 let sets: Vec<String> = row.changes.iter().enumerate()
323 .map(|(i, col)| {
324 let cast = row.data.get(col).map(|v| value_cast(v)).unwrap_or("");
325 format!("{} = ${}{}", col, i + 1, cast)
326 }).collect();
327 let pk = row.primary_keys.first().map(|s| s.as_str()).unwrap_or("id");
328 let id_value = row.data.get(pk).ok_or(DbError::Argument("Row 缺少主键".into()))?;
329
330 let mut params: Vec<String> = row.changes.iter()
331 .filter_map(|c| row.data.get(c)).map(value_to_string).collect();
332 params.push(value_to_string(id_value));
333 let val_refs: Vec<&str> = params.iter().map(|s| s.as_str()).collect();
334
335 let id_cast_sql = id_cast(id_value);
336 if matches!(&self.pool, DbPool::Postgres(_)) {
337 let sql = format!("UPDATE {} SET {} WHERE {}=${}{} RETURNING *",
338 table, sets.join(", "), pk, row.changes.len() + 1, id_cast_sql);
339 self.query_one(&sql, &val_refs).await
340 } else {
341 let sql = format!("UPDATE {} SET {} WHERE {}=${}{}",
342 table, sets.join(", "), pk, row.changes.len() + 1, id_cast_sql);
343 let n = self.execute(&sql, &val_refs).await?;
344 if n > 0 { self.find_by_id(table, id_value.clone()).await } else { Ok(None) }
345 }
346 }
347
348 pub async fn batch_update(&self, table: &str, sets: &Row, where_sql: &str, where_params: &[&str]) -> DbResult<u64> {
354 if sets.changes.is_empty() { return Err(DbError::Argument("没有要更新的字段".into())); }
355 let set_clauses: Vec<String> = sets.changes.iter().enumerate()
356 .map(|(i, col)| {
357 let cast = sets.data.get(col).map(|v| value_cast(v)).unwrap_or("");
358 format!("{} = ${}{}", col, i + 1, cast)
359 }).collect();
360 let set_values: Vec<String> = sets.changes.iter()
361 .filter_map(|c| sets.data.get(c)).map(value_to_string).collect();
362
363 let offset = sets.changes.len();
364 let adjusted_where = adjust_param_indices_with_casts(where_sql, offset, where_params);
365 let sql = format!("UPDATE {} SET {} WHERE {}", table, set_clauses.join(", "), adjusted_where);
366 let mut all: Vec<String> = set_values;
367 all.extend(where_params.iter().map(|s| s.to_string()));
368 let val_refs: Vec<&str> = all.iter().map(|s| s.as_str()).collect();
369 self.execute(&sql, &val_refs).await
370 }
371
372 pub async fn delete_by_id(&self, table: &str, id: impl Into<serde_json::Value>) -> DbResult<bool> {
374 let value: serde_json::Value = id.into();
375 let pk = "id";
376 let id_str = value_to_string(&value);
377 let sql = format!("DELETE FROM {} WHERE {}=$1{}",
378 table, pk, id_cast(&value));
379 let n = self.execute(&sql, &[&id_str]).await?;
380 Ok(n > 0)
381 }
382
383 pub async fn batch_delete_by_ids(&self, table: &str, ids: &[impl AsRef<str>]) -> DbResult<u64> {
385 if ids.is_empty() { return Ok(0); }
386 let is_uuid = ids.first().map(|id| {
387 let s = id.as_ref();
388 s.len() == 36 && s.chars().filter(|&c| c == '-').count() == 4
389 }).unwrap_or(false);
390 let cast = if is_uuid { "::uuid" } else { "" };
391 let placeholders: Vec<String> = (1..=ids.len()).map(|i| format!("${}{}", i, cast)).collect();
392 let sql = format!("DELETE FROM {} WHERE id IN ({})", table, placeholders.join(", "));
393 let params: Vec<&str> = ids.iter().map(|id| id.as_ref()).collect();
394 self.execute(&sql, ¶ms).await
395 }
396
397 pub async fn execute(&self, sql: &str, params: &[&str]) -> DbResult<u64> {
399 match &self.pool {
400 DbPool::Postgres(pool) => execute_pg_pool(pool, sql, params).await,
401 DbPool::Mysql(pool) => execute_my_sql_pool(pool, sql, params).await,
402 DbPool::Sqlite(pool) => execute_sqlite_pool(pool, sql, params).await,
403 DbPool::Any(pool) => execute_any(pool, sql, params).await,
404 }
405 }
406
407 pub async fn transaction<F, Fut, T>(&self, f: F) -> DbResult<T>
424 where
425 F: FnOnce(crate::tx::ActiveTx) -> Fut + Send,
426 Fut: std::future::Future<Output = (crate::tx::ActiveTx, DbResult<T>)> + Send,
427 T: Send,
428 {
429 let mut rollback_only = false;
430 crate::tx::execute_transaction(&self.pool, crate::tx::Isolation::ReadCommitted, &mut rollback_only, f).await
431 }
432}
433
434pub(crate) fn value_to_string(v: &serde_json::Value) -> String {
437 match v {
438 serde_json::Value::String(s) => s.clone(),
439 serde_json::Value::Number(n) => n.to_string(),
440 serde_json::Value::Bool(b) => b.to_string(),
441 serde_json::Value::Null => String::new(),
442 other => other.to_string(),
443 }
444}
445
446fn adjust_param_indices_with_casts(sql: &str, offset: usize, params: &[&str]) -> String {
447 let re = regex::Regex::new(r"\$(\d+)").unwrap();
448 if offset == 0 {
449 return re.replace_all(sql, |caps: ®ex::Captures| {
450 let n: usize = caps[1].parse().unwrap_or(0);
451 let cast = params.get(n.wrapping_sub(1)).map(|v| {
452 let s: &str = v;
453 if s.len() == 36 && s.chars().filter(|&c| c == '-').count() == 4 { "::uuid" }
454 else if s.parse::<i64>().is_ok() { "::bigint" }
455 else if s.parse::<f64>().is_ok() { "::double precision" }
456 else { "" }
457 }).unwrap_or("");
458 format!("${}{}", n, cast)
459 }).to_string();
460 }
461 re.replace_all(sql, |caps: ®ex::Captures| {
462 let n: usize = caps[1].parse().unwrap_or(0);
463 let cast = params.get(n.wrapping_sub(1)).map(|v| {
464 let s: &str = v;
465 if s.len() == 36 && s.chars().filter(|&c| c == '-').count() == 4 { "::uuid" }
466 else if s.parse::<i64>().is_ok() { "::bigint" }
467 else if s.parse::<f64>().is_ok() { "::double precision" }
468 else { "" }
469 }).unwrap_or("");
470 format!("${}{}", n + offset, cast)
471 }).to_string()
472}
473
474fn id_cast(value: &Value) -> &'static str {
475 match IdKind::detect(value) {
476 IdKind::Uuid => "::uuid",
477 IdKind::I64 => "::bigint",
478 _ => "",
479 }
480}
481
482fn value_cast(value: &Value) -> &'static str {
483 match value {
484 Value::Object(_) | Value::Array(_) => "::jsonb",
485 Value::String(s) => {
486 if is_inet_format(s) {
487 "::inet"
488 } else {
489 match IdKind::detect(value) {
490 IdKind::Uuid => "::uuid",
491 IdKind::I64 => "::bigint",
492 IdKind::F64 => "::double precision",
493 IdKind::Bool => "::boolean",
494 _ => "",
495 }
496 }
497 }
498 _ => match IdKind::detect(value) {
499 IdKind::Uuid => "::uuid",
500 IdKind::I64 => "::bigint",
501 IdKind::F64 => "::double precision",
502 IdKind::Bool => "::boolean",
503 _ => "",
504 },
505 }
506}
507
/// Returns `true` when `s` looks like a literal IP address.
///
/// IPv4: exactly four dot-separated parts that each parse as `u8`.
/// IPv6: anything containing `:` that the standard library accepts as an
/// `Ipv6Addr`, including embedded-IPv4 forms like `::ffff:1.2.3.4`.
///
/// The previous IPv6 heuristic accepted any 2–8 hex groups, any string
/// containing `::`, and even a lone `:`. As a result, clock times such as
/// `"12:30"` and paths like `"Foo::bar"` received a bogus `::inet` cast.
fn is_inet_format(s: &str) -> bool {
    if s.contains(':') {
        return s.parse::<std::net::Ipv6Addr>().is_ok();
    }
    let octets: Vec<&str> = s.split('.').collect();
    octets.len() == 4 && octets.iter().all(|o| o.parse::<u8>().is_ok())
}
527
528