use rat_logger::{debug, info};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;
/// MySQL adapter state: a table-creation lock map and a map of stored
/// procedure information, each shared behind `Arc<Mutex<..>>` (tokio mutex).
pub struct MysqlAdapter {
/// Map of table name -> `()` marker, guarded by an async mutex.
/// `acquire_table_lock` inserts a marker for the table and hands the guard
/// of the whole map to the caller; `release_table_lock` removes the marker.
creation_locks: Arc<Mutex<HashMap<String, ()>>>,
/// Stored procedure information, visible inside the crate.
/// NOTE(review): presumably keyed by procedure name — confirm against the
/// code that populates this map.
pub(crate) stored_procedures:
Arc<Mutex<HashMap<String, crate::stored_procedure::StoredProcedureInfo>>>,
}
impl MysqlAdapter {
pub fn new() -> Self {
Self {
creation_locks: Arc::new(Mutex::new(HashMap::new())),
stored_procedures: Arc::new(Mutex::new(HashMap::new())),
}
}
pub(crate) async fn acquire_table_lock(
&self,
table: &str,
) -> tokio::sync::MutexGuard<'_, HashMap<String, ()>> {
let mut locks = self.creation_locks.lock().await;
if !locks.contains_key(table) {
locks.insert(table.to_string(), ());
debug!("🔒 获取表 {} 的创建锁", table);
}
locks
}
pub(crate) async fn release_table_lock(
&self,
table: &str,
mut locks: tokio::sync::MutexGuard<'_, HashMap<String, ()>>,
) {
locks.remove(table);
debug!("🔓 释放表 {} 的创建锁", table);
}
pub async fn generate_stored_procedure_sql(
&self,
config: &crate::stored_procedure::StoredProcedureConfig,
) -> crate::error::QuickDbResult<String> {
use crate::stored_procedure::JoinType;
let fields: Vec<String> = config
.fields
.iter()
.map(|(alias, expr)| {
if alias == expr {
expr.clone()
} else {
format!("{} AS {}", expr, alias)
}
})
.collect();
let base_table = config
.dependencies
.first()
.map(|model_meta| &model_meta.collection_name)
.ok_or_else(|| crate::error::QuickDbError::ValidationError {
field: "dependencies".to_string(),
message: "至少需要一个依赖表作为主表".to_string(),
})?;
let mut joins = Vec::new();
for join in config.joins.iter() {
let join_str = match join.join_type {
JoinType::Inner => "INNER JOIN",
JoinType::Left => "LEFT JOIN",
JoinType::Right => "RIGHT JOIN",
JoinType::Full => "FULL OUTER JOIN",
};
joins.push(format!(
" {} {} ON {} = {}",
join_str, join.table, join.local_field, join.foreign_field
));
}
let mut group_by_fields = Vec::new();
let mut has_aggregate_function = false;
for (alias, expr) in &config.fields {
let expr_upper = expr.to_uppercase();
if expr_upper.contains("COUNT(")
|| expr_upper.contains("SUM(")
|| expr_upper.contains("AVG(")
|| expr_upper.contains("MAX(")
|| expr_upper.contains("MIN(")
{
has_aggregate_function = true;
} else {
if let Some(dot_pos) = expr.rfind('.') {
let field_name = &expr[dot_pos + 1..];
group_by_fields.push(expr.clone());
} else {
group_by_fields.push(expr.clone());
}
}
}
let group_by_clause = if has_aggregate_function && !group_by_fields.is_empty() {
format!(" GROUP BY {}", group_by_fields.join(", "))
} else {
"".to_string()
};
let sql_template = format!(
"SELECT {SELECT_FIELDS} FROM {BASE_TABLE}{JOINS}{WHERE}{GROUP_BY}{HAVING}{ORDER_BY}{LIMIT}{OFFSET}",
SELECT_FIELDS = fields.join(", "),
BASE_TABLE = base_table,
JOINS = if joins.is_empty() {
"".to_string()
} else {
format!(" {}", joins.join(" "))
},
WHERE = "{WHERE}", GROUP_BY = group_by_clause, HAVING = "{HAVING}", ORDER_BY = "{ORDER_BY}", LIMIT = "{LIMIT}", OFFSET = "{OFFSET}" );
info!("生成的MySQL存储过程SQL模板: {}", sql_template);
Ok(sql_template)
}
}