use sqlx::{PgPool, MySqlPool, SqlitePool, AnyPool, Column, Row as SqlxRow};
use crate::{Row, DbResult, DbError, IdKind};
use alun_core::PageQuery;
use serde_json::{Value, Number};
/// Connection pool for one of the supported sqlx backends.
///
/// Each variant wraps the corresponding sqlx pool type; `Any` wraps the
/// driver-agnostic `AnyPool`.
#[derive(Clone)]
pub enum DbPool {
/// PostgreSQL pool.
Postgres(PgPool),
/// MySQL pool.
Mysql(MySqlPool),
/// SQLite pool.
Sqlite(SqlitePool),
/// Driver-agnostic pool.
Any(AnyPool),
}
/// Database handle that routes every operation to the backend pool it wraps.
#[derive(Clone)]
pub struct Db {
// Backend pool this handle dispatches to.
pool: DbPool,
}
/// Generates the backend-specific helper functions for one concrete sqlx driver.
///
/// `$pool_ty` is the pool type (for example `PgPool`) and `$db_mod` is the
/// matching `sqlx` database marker (for example `Postgres`). For each pair the
/// macro emits a row converter named after the database marker, plus
/// `query_one_*`, `query_all_*`, `count_*` and `execute_*` functions named
/// after the pool type (both converted to snake case by `paste`).
macro_rules! impl_db_ops {
    ($pool_ty:ty, $db_mod:ident) => {
        paste::paste! {
            // Converts one driver row into the crate's dynamic `Row`.
            //
            // Each column is probed with a fixed sequence of Rust types and the
            // first successful decode wins. A column that decodes as none of
            // them, or an `f64` that `Number::from_f64` rejects, is left out.
            fn [<typed_row_to_row_ $db_mod:snake>](
                row: &<sqlx::$db_mod as sqlx::Database>::Row
            ) -> Row {
                let mut out = Row::default();
                for column in row.columns() {
                    let position: usize = column.ordinal();
                    let decoded: Option<Value> =
                        if let Ok(n) = row.try_get::<i64, usize>(position) {
                            Some(Value::Number(n.into()))
                        } else if let Ok(n) = row.try_get::<i32, usize>(position) {
                            Some(Value::Number(i64::from(n).into()))
                        } else if let Ok(n) = row.try_get::<i16, usize>(position) {
                            Some(Value::Number(i64::from(n).into()))
                        } else if let Ok(text) = row.try_get::<String, usize>(position) {
                            Some(Value::String(text))
                        } else if let Ok(id) = row.try_get::<sqlx::types::Uuid, usize>(position) {
                            Some(Value::String(id.to_string()))
                        } else if let Ok(x) = row.try_get::<f64, usize>(position) {
                            Number::from_f64(x).map(Value::Number)
                        } else if let Ok(flag) = row.try_get::<bool, usize>(position) {
                            Some(Value::Bool(flag))
                        } else {
                            None
                        };
                    if let Some(value) = decoded {
                        out.data.insert(column.name().to_string(), value);
                    }
                }
                out.mark_all_changed();
                out
            }

            // Runs `sql` with `params` bound in order and converts the first row, if any.
            async fn [<query_one_ $pool_ty:snake>](
                pool: &$pool_ty, sql: &str, params: &[&str],
            ) -> DbResult<Option<Row>> {
                let mut stmt = sqlx::query::<sqlx::$db_mod>(sql);
                for &arg in params {
                    stmt = stmt.bind(arg);
                }
                let fetched = stmt.fetch_optional(pool).await?;
                Ok(fetched.as_ref().map([<typed_row_to_row_ $db_mod:snake>]))
            }

            // Runs `sql` with `params` bound in order and converts every returned row.
            async fn [<query_all_ $pool_ty:snake>](
                pool: &$pool_ty, sql: &str, params: &[&str],
            ) -> DbResult<Vec<Row>> {
                let mut stmt = sqlx::query::<sqlx::$db_mod>(sql);
                for &arg in params {
                    stmt = stmt.bind(arg);
                }
                let fetched = stmt.fetch_all(pool).await?;
                Ok(fetched.iter().map([<typed_row_to_row_ $db_mod:snake>]).collect())
            }

            // Runs a scalar `i64` query and returns it as `u64`; no row yields 0.
            async fn [<count_ $pool_ty:snake>](
                pool: &$pool_ty, sql: &str, params: &[&str],
            ) -> DbResult<u64> {
                let mut stmt = sqlx::query_scalar::<sqlx::$db_mod, i64>(sql);
                for &arg in params {
                    stmt = stmt.bind(arg);
                }
                let total = stmt.fetch_optional(pool).await?.unwrap_or(0);
                Ok(total as u64)
            }

            // Executes a statement and returns the number of affected rows.
            async fn [<execute_ $pool_ty:snake>](
                pool: &$pool_ty, sql: &str, params: &[&str],
            ) -> DbResult<u64> {
                let mut stmt = sqlx::query::<sqlx::$db_mod>(sql);
                for &arg in params {
                    stmt = stmt.bind(arg);
                }
                let outcome = stmt.execute(pool).await.map_err(DbError::from)?;
                Ok(outcome.rows_affected())
            }
        }
    };
}
// Instantiate the typed helper functions for each concrete backend pool.
// The generated names (`*_pg_pool`, `*_my_sql_pool`, `*_sqlite_pool`) are the
// ones the `Db` dispatch methods call.
impl_db_ops!(PgPool, Postgres);
impl_db_ops!(MySqlPool, MySql);
impl_db_ops!(SqlitePool, Sqlite);
/// Runs `sql` on the driver-agnostic pool with `params` bound in order and
/// returns the first row converted to a [`Row`], or `None` when nothing matched.
async fn query_one_any(pool: &AnyPool, sql: &str, params: &[&str]) -> DbResult<Option<Row>> {
    let mut stmt = sqlx::query::<sqlx::Any>(sql);
    for &arg in params {
        stmt = stmt.bind(arg);
    }
    let fetched = stmt.fetch_optional(pool).await?;
    Ok(fetched.as_ref().map(typed_row_to_row_any))
}
/// Runs `sql` on the driver-agnostic pool with `params` bound in order and
/// returns every row converted to a [`Row`].
async fn query_all_any(pool: &AnyPool, sql: &str, params: &[&str]) -> DbResult<Vec<Row>> {
    let mut stmt = sqlx::query::<sqlx::Any>(sql);
    for &arg in params {
        stmt = stmt.bind(arg);
    }
    let fetched = stmt.fetch_all(pool).await?;
    Ok(fetched.iter().map(typed_row_to_row_any).collect())
}
fn typed_row_to_row_any(row: &sqlx::any::AnyRow) -> Row {
let mut r = Row::default();
for col in row.columns() {
let name = col.name().to_string();
let idx: usize = col.ordinal();
if let Ok(v) = row.try_get::<i64, usize>(idx) {
r.data.insert(name, Value::Number(v.into()));
} else if let Ok(v) = row.try_get::<i32, usize>(idx) {
r.data.insert(name, Value::Number((v as i64).into()));
} else if let Ok(v) = row.try_get::<String, usize>(idx) {
r.data.insert(name, Value::String(v));
} else if let Ok(v) = row.try_get::<f64, usize>(idx) {
if let Some(n) = Number::from_f64(v) {
r.data.insert(name, Value::Number(n));
}
} else if let Ok(v) = row.try_get::<bool, usize>(idx) {
r.data.insert(name, Value::Bool(v));
}
}
r.mark_all_changed();
r
}
/// Runs a scalar `i64` query on the driver-agnostic pool and returns the value
/// as `u64`; a query that yields no row counts as 0.
async fn count_any(pool: &AnyPool, sql: &str, params: &[&str]) -> DbResult<u64> {
    let mut stmt = sqlx::query_scalar::<sqlx::Any, i64>(sql);
    for &arg in params {
        stmt = stmt.bind(arg);
    }
    let total = stmt.fetch_optional(pool).await?.unwrap_or(0);
    Ok(total as u64)
}
/// Executes a statement on the driver-agnostic pool with `params` bound in
/// order and returns the number of affected rows.
async fn execute_any(pool: &AnyPool, sql: &str, params: &[&str]) -> DbResult<u64> {
    let mut stmt = sqlx::query::<sqlx::Any>(sql);
    for &arg in params {
        stmt = stmt.bind(arg);
    }
    let outcome = stmt.execute(pool).await.map_err(DbError::from)?;
    Ok(outcome.rows_affected())
}
// NOTE(review): every statement built in this impl uses PostgreSQL-style `$n`
// placeholders and `::type` casts regardless of the wrapped backend, and table
// and column names are interpolated into the SQL text unescaped. Confirm the
// behaviour on MySQL/SQLite and only pass trusted identifiers.
impl Db {
    /// Creates a handle backed by a PostgreSQL pool.
    pub fn postgres(pool: PgPool) -> Self {
        Self { pool: DbPool::Postgres(pool) }
    }

    /// Creates a handle backed by a MySQL pool.
    pub fn mysql(pool: MySqlPool) -> Self {
        Self { pool: DbPool::Mysql(pool) }
    }

    /// Creates a handle backed by a SQLite pool.
    pub fn sqlite(pool: SqlitePool) -> Self {
        Self { pool: DbPool::Sqlite(pool) }
    }

    /// Returns the wrapped PostgreSQL pool.
    ///
    /// # Panics
    ///
    /// Panics when the handle wraps any other backend.
    pub fn pg_pool(&self) -> &PgPool {
        match &self.pool {
            DbPool::Postgres(p) => p,
            _ => panic!("不是 PG"),
        }
    }

    /// Returns the wrapped MySQL pool.
    ///
    /// # Panics
    ///
    /// Panics when the handle wraps any other backend.
    pub fn mysql_pool(&self) -> &MySqlPool {
        match &self.pool {
            DbPool::Mysql(p) => p,
            _ => panic!("不是 MySQL"),
        }
    }

    /// Returns the wrapped SQLite pool.
    ///
    /// # Panics
    ///
    /// Panics when the handle wraps any other backend.
    pub fn sqlite_pool(&self) -> &SqlitePool {
        match &self.pool {
            DbPool::Sqlite(p) => p,
            _ => panic!("不是 SQLite"),
        }
    }

    /// Loads the row of `table` whose `id` column equals `id`.
    ///
    /// The id is bound as its string form, with the cast chosen by `id_cast`
    /// appended to the placeholder. Returns `Ok(None)` when no row matches.
    pub async fn find_by_id(&self, table: &str, id: impl Into<serde_json::Value>) -> DbResult<Option<Row>> {
        let value: serde_json::Value = id.into();
        let pk = "id";
        let id_str = value_to_string(&value);
        let sql = format!("SELECT * FROM {} WHERE {}=$1{}", table, pk, id_cast(&value));
        self.query_one(&sql, &[id_str.as_str()]).await
    }

    /// Runs `sql` with `params` bound in order and returns the first row, if any.
    pub async fn query_one(&self, sql: &str, params: &[&str]) -> DbResult<Option<Row>> {
        match &self.pool {
            DbPool::Postgres(pool) => query_one_pg_pool(pool, sql, params).await,
            DbPool::Mysql(pool) => query_one_my_sql_pool(pool, sql, params).await,
            DbPool::Sqlite(pool) => query_one_sqlite_pool(pool, sql, params).await,
            DbPool::Any(pool) => query_one_any(pool, sql, params).await,
        }
    }

    /// Runs `sql` with `params` bound in order and returns every row.
    pub async fn query(&self, sql: &str, params: &[&str]) -> DbResult<Vec<Row>> {
        match &self.pool {
            DbPool::Postgres(pool) => query_all_pg_pool(pool, sql, params).await,
            DbPool::Mysql(pool) => query_all_my_sql_pool(pool, sql, params).await,
            DbPool::Sqlite(pool) => query_all_sqlite_pool(pool, sql, params).await,
            DbPool::Any(pool) => query_all_any(pool, sql, params).await,
        }
    }

    /// Runs `sql` as a paged query and returns `(rows, total)`.
    ///
    /// The total comes from a `COUNT(*)` over `sql` wrapped as a subquery; the
    /// rows come from `sql` with `LIMIT`/`OFFSET` taken from `page` appended.
    /// The same `params` are bound to both statements.
    pub async fn query_page(&self, sql: &str, params: &[&str], page: &PageQuery) -> DbResult<(Vec<Row>, u64)> {
        let count_sql = format!("SELECT COUNT(*) as cnt FROM ({}) AS _count_sub", sql);
        let total = self.count(&count_sql, params).await?;
        let page_sql = format!("{} LIMIT {} OFFSET {}", sql, page.limit(), page.offset());
        let rows = self.query(&page_sql, params).await?;
        Ok((rows, total))
    }

    /// Runs a scalar count query and returns its value (0 when no row comes back).
    pub async fn count(&self, sql: &str, params: &[&str]) -> DbResult<u64> {
        match &self.pool {
            DbPool::Postgres(pool) => count_pg_pool(pool, sql, params).await,
            DbPool::Mysql(pool) => count_my_sql_pool(pool, sql, params).await,
            DbPool::Sqlite(pool) => count_sqlite_pool(pool, sql, params).await,
            DbPool::Any(pool) => count_any(pool, sql, params).await,
        }
    }

    /// Inserts the changed columns of `row` into `row.table` and returns the stored row.
    ///
    /// On PostgreSQL the statement carries `RETURNING *` and that row is
    /// returned. On every other backend the row is inserted and then re-read
    /// with [`Db::find_by_id`] using the `id` entry of `row.data`.
    ///
    /// # Errors
    ///
    /// Returns `DbError::Argument` when the row has no table name, no changed
    /// columns, or a changed column without a value, and (non-PostgreSQL only)
    /// when `row.data` has no `id` entry.
    pub async fn insert(&self, row: &Row) -> DbResult<Row> {
        let table = row
            .table
            .as_deref()
            .ok_or_else(|| DbError::Argument("Row 缺少表名".into()))?;
        let columns: Vec<&String> = row.changes.iter().collect();
        if columns.is_empty() {
            return Err(DbError::Argument("没有变更的字段".into()));
        }

        // Build placeholders and parameters in lock-step so that `$n` always
        // refers to the n-th bound value; a changed column without a value is
        // rejected instead of silently shifting the later parameters.
        let mut placeholders: Vec<String> = Vec::with_capacity(columns.len());
        let mut values: Vec<String> = Vec::with_capacity(columns.len());
        for (i, col) in columns.iter().enumerate() {
            let value = row
                .data
                .get(*col)
                .ok_or_else(|| DbError::Argument("变更字段缺少对应的值".into()))?;
            placeholders.push(format!("${}{}", i + 1, value_cast(value)));
            values.push(value_to_string(value));
        }

        let col_str = columns.iter().map(|c| c.as_str()).collect::<Vec<_>>().join(", ");
        let placeholder_sql = placeholders.join(", ");
        let val_refs: Vec<&str> = values.iter().map(|s| s.as_str()).collect();

        if matches!(&self.pool, DbPool::Postgres(_)) {
            let sql = format!(
                "INSERT INTO {} ({}) VALUES ({}) RETURNING *",
                table, col_str, placeholder_sql
            );
            self.query_one(&sql, &val_refs)
                .await?
                .ok_or_else(|| DbError::Other("INSERT 返回空".into()))
        } else {
            let sql = format!("INSERT INTO {} ({}) VALUES ({})", table, col_str, placeholder_sql);
            self.execute(&sql, &val_refs).await?;
            // NOTE(review): the INSERT above has already run when the missing-id
            // error below is returned; kept as-is so existing callers see the
            // same side effects.
            match row.data.get("id") {
                Some(v) => self
                    .find_by_id(table, v.clone())
                    .await?
                    .ok_or_else(|| DbError::Other("INSERT 后查不到".into())),
                None => Err(DbError::Argument("非 PG 数据库需 Row 含主键".into())),
            }
        }
    }

    /// Inserts all `rows` with a single multi-row `INSERT` and returns the
    /// number of affected rows.
    ///
    /// The table name and column list are taken from the first row. A row
    /// that lacks a value for one of those columns binds an empty string for it.
    /// An empty slice is a no-op returning 0.
    ///
    /// # Errors
    ///
    /// Returns `DbError::Argument` when the first row has no table name or no
    /// changed columns.
    pub async fn batch_insert(&self, rows: &[Row]) -> DbResult<u64> {
        let first = match rows.first() {
            Some(row) => row,
            None => return Ok(0),
        };
        let table = first
            .table
            .as_deref()
            .ok_or_else(|| DbError::Argument("Row 缺少表名".into()))?;
        let columns: Vec<&String> = first.changes.iter().collect();
        if columns.is_empty() {
            return Err(DbError::Argument("没有变更的字段".into()));
        }
        let col_names = columns.iter().map(|c| c.as_str()).collect::<Vec<_>>().join(", ");

        let mut all_params: Vec<String> = Vec::with_capacity(rows.len() * columns.len());
        let mut groups: Vec<String> = Vec::with_capacity(rows.len());
        for (ri, row) in rows.iter().enumerate() {
            // Placeholders are numbered consecutively across all rows.
            let offset = ri * columns.len();
            let mut ph: Vec<String> = Vec::with_capacity(columns.len());
            for (ci, c) in columns.iter().enumerate() {
                let value = row.data.get(*c);
                let cast = value.map(|v| value_cast(v)).unwrap_or("");
                ph.push(format!("${}{}", offset + ci + 1, cast));
                all_params.push(value.map(value_to_string).unwrap_or_default());
            }
            groups.push(format!("({})", ph.join(", ")));
        }

        let sql = format!("INSERT INTO {} ({}) VALUES {}", table, col_names, groups.join(", "));
        let val_refs: Vec<&str> = all_params.iter().map(|s| s.as_str()).collect();
        self.execute(&sql, &val_refs).await
    }

    /// Writes the changed columns of `row` to the record identified by its primary key.
    ///
    /// The key column is the first entry of `row.primary_keys`, or `id` when
    /// that list is empty. On PostgreSQL the updated row comes from
    /// `RETURNING *`; on other backends it is re-read with [`Db::find_by_id`]
    /// when at least one row was affected, otherwise `Ok(None)` is returned.
    ///
    /// # Errors
    ///
    /// Returns `DbError::Argument` when the row has no table name, no value for
    /// the key column, no changed columns, or a changed column without a value.
    pub async fn update(&self, row: &Row) -> DbResult<Option<Row>> {
        let table = row
            .table
            .as_deref()
            .ok_or_else(|| DbError::Argument("Row 缺少表名".into()))?;
        let pk = row.primary_keys.first().map(|s| s.as_str()).unwrap_or("id");
        let id_value = row
            .data
            .get(pk)
            .ok_or_else(|| DbError::Argument("Row 缺少主键".into()))?;
        // An empty SET list would produce invalid SQL; reject it up front.
        if row.changes.is_empty() {
            return Err(DbError::Argument("没有要更新的字段".into()));
        }

        // SET clauses and parameters are built together so the key placeholder
        // (`$len + 1`) always lines up with the last bound value.
        let change_count = row.changes.len();
        let mut sets: Vec<String> = Vec::with_capacity(change_count);
        let mut params: Vec<String> = Vec::with_capacity(change_count + 1);
        for (i, col) in row.changes.iter().enumerate() {
            let value = row
                .data
                .get(col)
                .ok_or_else(|| DbError::Argument("变更字段缺少对应的值".into()))?;
            sets.push(format!("{} = ${}{}", col, i + 1, value_cast(value)));
            params.push(value_to_string(value));
        }
        params.push(value_to_string(id_value));
        let val_refs: Vec<&str> = params.iter().map(|s| s.as_str()).collect();
        let id_cast_sql = id_cast(id_value);

        if matches!(&self.pool, DbPool::Postgres(_)) {
            let sql = format!(
                "UPDATE {} SET {} WHERE {}=${}{} RETURNING *",
                table,
                sets.join(", "),
                pk,
                change_count + 1,
                id_cast_sql
            );
            self.query_one(&sql, &val_refs).await
        } else {
            let sql = format!(
                "UPDATE {} SET {} WHERE {}=${}{}",
                table,
                sets.join(", "),
                pk,
                change_count + 1,
                id_cast_sql
            );
            let n = self.execute(&sql, &val_refs).await?;
            if n > 0 {
                self.find_by_id(table, id_value.clone()).await
            } else {
                Ok(None)
            }
        }
    }

    /// Applies the changed columns of `sets` to every row of `table` matching
    /// `where_sql`, returning the number of affected rows.
    ///
    /// `where_sql` uses its own `$1..$n` numbering for `where_params`; the
    /// placeholders are shifted past the SET parameters by
    /// `adjust_param_indices_with_casts` before the statement is assembled.
    ///
    /// # Errors
    ///
    /// Returns `DbError::Argument` when `sets` has no changed columns or a
    /// changed column without a value.
    pub async fn batch_update(&self, table: &str, sets: &Row, where_sql: &str, where_params: &[&str]) -> DbResult<u64> {
        if sets.changes.is_empty() {
            return Err(DbError::Argument("没有要更新的字段".into()));
        }
        let offset = sets.changes.len();
        let mut set_clauses: Vec<String> = Vec::with_capacity(offset);
        let mut all: Vec<String> = Vec::with_capacity(offset + where_params.len());
        for (i, col) in sets.changes.iter().enumerate() {
            // Keep clause and parameter aligned; the WHERE placeholders are
            // shifted by exactly `offset`, so no SET value may be skipped.
            let value = sets
                .data
                .get(col)
                .ok_or_else(|| DbError::Argument("变更字段缺少对应的值".into()))?;
            set_clauses.push(format!("{} = ${}{}", col, i + 1, value_cast(value)));
            all.push(value_to_string(value));
        }
        let adjusted_where = adjust_param_indices_with_casts(where_sql, offset, where_params);
        let sql = format!("UPDATE {} SET {} WHERE {}", table, set_clauses.join(", "), adjusted_where);
        all.extend(where_params.iter().map(|s| s.to_string()));
        let val_refs: Vec<&str> = all.iter().map(|s| s.as_str()).collect();
        self.execute(&sql, &val_refs).await
    }

    /// Deletes the row of `table` whose `id` column equals `id`.
    ///
    /// Returns `true` when at least one row was removed.
    pub async fn delete_by_id(&self, table: &str, id: impl Into<serde_json::Value>) -> DbResult<bool> {
        let value: serde_json::Value = id.into();
        let pk = "id";
        let id_str = value_to_string(&value);
        let sql = format!("DELETE FROM {} WHERE {}=$1{}", table, pk, id_cast(&value));
        let n = self.execute(&sql, &[id_str.as_str()]).await?;
        Ok(n > 0)
    }

    /// Deletes every row of `table` whose `id` is in `ids` and returns the
    /// number of affected rows.
    ///
    /// A `::uuid` cast is added to all placeholders when the first id is 36
    /// characters long and contains four dashes. An empty slice is a no-op
    /// returning 0.
    pub async fn batch_delete_by_ids(&self, table: &str, ids: &[impl AsRef<str>]) -> DbResult<u64> {
        let first: &str = match ids.first() {
            Some(id) => id.as_ref(),
            None => return Ok(0),
        };
        let is_uuid = first.len() == 36 && first.chars().filter(|&c| c == '-').count() == 4;
        let cast = if is_uuid { "::uuid" } else { "" };
        let placeholders: Vec<String> = (1..=ids.len()).map(|i| format!("${}{}", i, cast)).collect();
        let sql = format!("DELETE FROM {} WHERE id IN ({})", table, placeholders.join(", "));
        let params: Vec<&str> = ids.iter().map(|id| id.as_ref()).collect();
        self.execute(&sql, &params).await
    }

    /// Executes a statement with `params` bound in order and returns the
    /// number of affected rows.
    pub async fn execute(&self, sql: &str, params: &[&str]) -> DbResult<u64> {
        match &self.pool {
            DbPool::Postgres(pool) => execute_pg_pool(pool, sql, params).await,
            DbPool::Mysql(pool) => execute_my_sql_pool(pool, sql, params).await,
            DbPool::Sqlite(pool) => execute_sqlite_pool(pool, sql, params).await,
            DbPool::Any(pool) => execute_any(pool, sql, params).await,
        }
    }

    /// Runs `f` through `crate::tx::execute_transaction` at
    /// `Isolation::ReadCommitted`, starting with the rollback-only flag cleared.
    ///
    /// `f` receives the active transaction and must hand it back together with
    /// its result.
    pub async fn transaction<F, Fut, T>(&self, f: F) -> DbResult<T>
    where
        F: FnOnce(crate::tx::ActiveTx) -> Fut + Send,
        Fut: std::future::Future<Output = (crate::tx::ActiveTx, DbResult<T>)> + Send,
        T: Send,
    {
        let mut rollback_only = false;
        crate::tx::execute_transaction(
            &self.pool,
            crate::tx::Isolation::ReadCommitted,
            &mut rollback_only,
            f,
        )
        .await
    }
}
pub(crate) fn value_to_string(v: &serde_json::Value) -> String {
match v {
serde_json::Value::String(s) => s.clone(),
serde_json::Value::Number(n) => n.to_string(),
serde_json::Value::Bool(b) => b.to_string(),
serde_json::Value::Null => String::new(),
other => other.to_string(),
}
}
/// Renumbers the `$n` placeholders in `sql` by `offset` and appends a type
/// cast guessed from the matching entry of `params`.
///
/// `params` is indexed by the ORIGINAL placeholder number (`$1` maps to
/// `params[0]`). A placeholder without a matching entry (including `$0`) gets
/// no cast. The cast is guessed from the parameter text alone:
/// 36 characters with four dashes gives `::uuid`, text that parses as `i64`
/// gives `::bigint`, text that parses as `f64` gives `::double precision`,
/// and anything else gets no cast.
///
/// NOTE(review): because the guess only looks at the text, numeric-looking
/// strings meant for text columns are cast as numbers too.
fn adjust_param_indices_with_casts(sql: &str, offset: usize, params: &[&str]) -> String {
    /// Guesses the cast suffix for one parameter from its textual form.
    fn guess_cast(text: &str) -> &'static str {
        if text.len() == 36 && text.chars().filter(|&c| c == '-').count() == 4 {
            "::uuid"
        } else if text.parse::<i64>().is_ok() {
            "::bigint"
        } else if text.parse::<f64>().is_ok() {
            "::double precision"
        } else {
            ""
        }
    }

    // Compile the pattern once instead of on every call.
    static PLACEHOLDER: std::sync::OnceLock<regex::Regex> = std::sync::OnceLock::new();
    let re = PLACEHOLDER.get_or_init(|| {
        regex::Regex::new(r"\$(\d+)").expect("placeholder pattern is a valid regex")
    });

    // A zero offset needs no special case: `n + 0` leaves the index unchanged.
    re.replace_all(sql, |caps: &regex::Captures| {
        let n: usize = caps[1].parse().unwrap_or(0);
        // `wrapping_sub` turns `$0` into an out-of-range index, i.e. no cast.
        let cast = params.get(n.wrapping_sub(1)).map_or("", |p| guess_cast(p));
        format!("${}{}", n + offset, cast)
    })
    .into_owned()
}
/// Picks the SQL cast suffix for an id value from what `IdKind::detect`
/// reports: `::uuid` for UUIDs, `::bigint` for 64-bit integers, otherwise none.
fn id_cast(value: &Value) -> &'static str {
    let kind = IdKind::detect(value);
    if matches!(kind, IdKind::Uuid) {
        "::uuid"
    } else if matches!(kind, IdKind::I64) {
        "::bigint"
    } else {
        ""
    }
}
/// Picks the SQL cast suffix for a column value.
///
/// Objects and arrays map to `::jsonb` and strings accepted by
/// `is_inet_format` map to `::inet`. Everything else is classified with
/// `IdKind::detect`: `::uuid`, `::bigint`, `::double precision`, `::boolean`,
/// or no cast when none of those kinds is reported.
fn value_cast(value: &Value) -> &'static str {
    // Shape-based casts take precedence over the `IdKind` classification.
    match value {
        Value::Object(_) | Value::Array(_) => return "::jsonb",
        Value::String(text) if is_inet_format(text) => return "::inet",
        _ => {}
    }

    match IdKind::detect(value) {
        IdKind::Uuid => "::uuid",
        IdKind::I64 => "::bigint",
        IdKind::F64 => "::double precision",
        IdKind::Bool => "::boolean",
        _ => "",
    }
}
/// Returns `true` when `s` is an IPv4 or IPv6 address, optionally followed by
/// a `/prefix` length (at most 32 for IPv4, 128 for IPv6).
///
/// The address part is validated with the standard library parser, so
/// look-alikes such as `12:30:45` (a time) or `std::fmt` (any text containing
/// `::`) are rejected. The parser is strict: IPv4 octets with leading zeros
/// or a sign are not accepted.
fn is_inet_format(s: &str) -> bool {
    // Separate an optional CIDR suffix from the address itself.
    let (addr, prefix) = match s.split_once('/') {
        Some((addr, bits)) => (addr, Some(bits)),
        None => (s, None),
    };

    let ip: std::net::IpAddr = match addr.parse() {
        Ok(ip) => ip,
        Err(_) => return false,
    };

    match prefix {
        None => true,
        Some(bits) => {
            let max_bits: u8 = if ip.is_ipv4() { 32 } else { 128 };
            // Require plain digits: `u8::from_str` alone would also accept "+24".
            !bits.is_empty()
                && bits.bytes().all(|b| b.is_ascii_digit())
                && bits.parse::<u8>().map_or(false, |n| n <= max_bits)
        }
    }
}