use crate::error::QuickDbResult;
use crate::stored_procedure::types::*;
use std::collections::HashMap;
/// Fluent builder that assembles a [`StoredProcedureConfig`] one piece at a time.
///
/// Every `with_*` method consumes the builder and hands it back, so calls chain;
/// finish with [`StoredProcedureBuilder::build`].
pub struct StoredProcedureBuilder {
    config: StoredProcedureConfig,
}
impl StoredProcedureBuilder {
    /// Starts a builder for procedure `name` targeting the database alias `database`.
    ///
    /// The new config has no dependencies, joins, field mappings, or MongoDB pipeline.
    pub fn new(name: &str, database: &str) -> Self {
        let database = String::from(database);
        let procedure_name = String::from(name);
        let config = StoredProcedureConfig {
            procedure_name,
            database,
            dependencies: Vec::new(),
            joins: Vec::new(),
            fields: HashMap::new(),
            mongo_pipeline: None,
        };
        Self { config }
    }

    /// Records the metadata of model `T` as a dependency of the procedure.
    pub fn with_dependency<T: crate::model::Model>(mut self) -> Self {
        self.config.dependencies.push(T::meta());
        self
    }

    /// Adds a join against the collection/table of model `T`.
    ///
    /// `local_field` and `foreign_field` are stored verbatim; emptiness is only
    /// rejected later by `StoredProcedureConfig::validate`.
    pub fn with_join<T: crate::model::Model>(
        mut self,
        local_field: &str,
        foreign_field: &str,
        join_type: JoinType,
    ) -> Self {
        let meta = T::meta();
        let relation = JoinRelation {
            table: meta.collection_name.clone(),
            local_field: local_field.to_owned(),
            foreign_field: foreign_field.to_owned(),
            join_type,
        };
        self.config.joins.push(relation);
        self
    }

    /// Maps output field `field_name` to `expression`.
    ///
    /// Reusing a field name replaces its earlier expression.
    pub fn with_field(mut self, field_name: &str, expression: &str) -> Self {
        let (key, value) = (field_name.to_owned(), expression.to_owned());
        self.config.fields.insert(key, value);
        self
    }

    /// Sets the MongoDB aggregation pipeline, replacing any pipeline set before.
    pub fn with_mongo_pipeline(
        mut self,
        operations: Vec<crate::stored_procedure::types::MongoAggregationOperation>,
    ) -> Self {
        let pipeline = Some(operations);
        self.config.mongo_pipeline = pipeline;
        self
    }

    /// Switches to a [`MongoPipelineBuilder`] that builds the pipeline stage by stage.
    pub fn with_mongo_aggregation(self) -> MongoPipelineBuilder {
        MongoPipelineBuilder::new(self)
    }

    /// Returns the assembled configuration without validating it.
    pub fn build(self) -> StoredProcedureConfig {
        let Self { config } = self;
        config
    }
}
/// Fluent builder for the MongoDB aggregation pipeline of a stored procedure.
///
/// Created by [`StoredProcedureBuilder::with_mongo_aggregation`]. Each stage method appends
/// exactly one `MongoAggregationOperation` to the pipeline, in call order.
/// [`MongoPipelineBuilder::done`] stores the pipeline on the parent builder and returns it.
pub struct MongoPipelineBuilder {
    stored_procedure_builder: StoredProcedureBuilder,
    pipeline: Vec<crate::stored_procedure::types::MongoAggregationOperation>,
}
impl MongoPipelineBuilder {
    /// Wraps `stored_procedure_builder` with an initially empty pipeline.
    pub fn new(stored_procedure_builder: StoredProcedureBuilder) -> Self {
        Self {
            stored_procedure_builder,
            pipeline: Vec::new(),
        }
    }

    /// Appends `stage` to the pipeline and returns the builder for further chaining.
    fn push_stage(
        mut self,
        stage: crate::stored_procedure::types::MongoAggregationOperation,
    ) -> Self {
        self.pipeline.push(stage);
        self
    }

    /// Appends a `Project` stage built from `(output_name, expression)` pairs.
    ///
    /// The pairs are collected into a map, so a repeated name keeps its last expression.
    pub fn project(
        self,
        fields: Vec<(&str, crate::stored_procedure::types::MongoFieldExpression)>,
    ) -> Self {
        let fields = fields
            .into_iter()
            .map(|(name, expr)| (name.to_string(), expr))
            .collect();
        self.push_stage(
            crate::stored_procedure::types::MongoAggregationOperation::Project { fields },
        )
    }

    /// Appends a `Match` stage holding `conditions`.
    pub fn match_condition(
        self,
        conditions: Vec<crate::stored_procedure::types::MongoCondition>,
    ) -> Self {
        self.push_stage(
            crate::stored_procedure::types::MongoAggregationOperation::Match { conditions },
        )
    }

    /// Appends a `Lookup` stage joining collection `from` on `local_field` == `foreign_field`.
    ///
    /// The matched documents are stored under `as_field`.
    pub fn lookup(
        self,
        from: &str,
        local_field: &str,
        foreign_field: &str,
        as_field: &str,
    ) -> Self {
        self.push_stage(
            crate::stored_procedure::types::MongoAggregationOperation::Lookup {
                from: from.to_string(),
                local_field: local_field.to_string(),
                foreign_field: foreign_field.to_string(),
                as_field: as_field.to_string(),
            },
        )
    }

    /// Appends an `Unwind` stage for the array field `field`.
    pub fn unwind(self, field: &str) -> Self {
        self.push_stage(
            crate::stored_procedure::types::MongoAggregationOperation::Unwind {
                field: field.to_string(),
            },
        )
    }

    /// Appends a `Group` stage keyed by `id`, with named accumulators.
    ///
    /// The accumulators are collected into a map, so a repeated name keeps its last accumulator.
    pub fn group(
        self,
        id: crate::stored_procedure::types::MongoGroupKey,
        accumulators: Vec<(&str, crate::stored_procedure::types::MongoAccumulator)>,
    ) -> Self {
        let accumulators = accumulators
            .into_iter()
            .map(|(name, acc)| (name.to_string(), acc))
            .collect();
        self.push_stage(
            crate::stored_procedure::types::MongoAggregationOperation::Group { id, accumulators },
        )
    }

    /// Appends a `Sort` stage.
    ///
    /// `fields` keeps its order because it is stored as a `Vec`, which matters for
    /// multi-key sorts.
    pub fn sort(self, fields: Vec<(&str, crate::types::SortDirection)>) -> Self {
        let fields = fields
            .into_iter()
            .map(|(name, dir)| (name.to_string(), dir))
            .collect();
        self.push_stage(
            crate::stored_procedure::types::MongoAggregationOperation::Sort { fields },
        )
    }

    /// Appends a `Limit` stage.
    ///
    /// `count` is forwarded unchecked. NOTE: MongoDB's `$limit` expects a positive integer,
    /// so zero or negative values are likely to be rejected by the server.
    pub fn limit(self, count: i64) -> Self {
        self.push_stage(crate::stored_procedure::types::MongoAggregationOperation::Limit { count })
    }

    /// Appends a `Skip` stage.
    ///
    /// `count` is forwarded unchecked. NOTE: MongoDB's `$skip` expects a non-negative integer.
    pub fn skip(self, count: i64) -> Self {
        self.push_stage(crate::stored_procedure::types::MongoAggregationOperation::Skip { count })
    }

    /// Appends an `AddFields` stage built from `(name, expression)` pairs.
    ///
    /// A repeated name keeps its last expression.
    pub fn add_fields(
        self,
        fields: Vec<(&str, crate::stored_procedure::types::MongoFieldExpression)>,
    ) -> Self {
        let fields = fields
            .into_iter()
            .map(|(name, expr)| (name.to_string(), expr))
            .collect();
        self.push_stage(
            crate::stored_procedure::types::MongoAggregationOperation::AddFields { fields },
        )
    }

    /// Stores the pipeline on the parent builder and returns that builder.
    ///
    /// The pipeline is stored as `Some(..)` even if no stages were added.
    pub fn done(self) -> StoredProcedureBuilder {
        self.stored_procedure_builder
            .with_mongo_pipeline(self.pipeline)
    }

    /// Appends a `Placeholder` stage tagged with `placeholder_type`.
    ///
    /// How placeholders are expanded is decided outside this builder (presumably at
    /// execution time; confirm with the executor).
    pub fn add_placeholder(self, placeholder_type: &str) -> Self {
        self.push_stage(
            crate::stored_procedure::types::MongoAggregationOperation::Placeholder {
                placeholder_type: placeholder_type.to_string(),
            },
        )
    }

    /// Appends placeholder stages, in this order: `where`, `group_by`, `having`,
    /// `order_by`, `limit`, `offset`.
    pub fn with_common_placeholders(self) -> Self {
        self.add_placeholder("where")
            .add_placeholder("group_by")
            .add_placeholder("having")
            .add_placeholder("order_by")
            .add_placeholder("limit")
            .add_placeholder("offset")
    }

    /// Shorthand for `self.done().build()`. The returned config is not validated.
    pub fn build(self) -> StoredProcedureConfig {
        self.done().build()
    }
}
/// Convenience constructors for MongoDB field expressions.
impl crate::stored_procedure::types::MongoFieldExpression {
    /// Reference to the document field at `field`.
    pub fn field(field: &str) -> Self {
        let path = field.to_owned();
        Self::Field(path)
    }

    /// Wraps a literal `value`.
    pub fn constant(value: crate::types::DataValue) -> Self {
        Self::Constant(value)
    }

    /// Aggregate expression for the size of the array at `field`.
    pub fn size(field: &str) -> Self {
        let field = field.to_owned();
        let expr = crate::stored_procedure::types::MongoAggregateExpression::Size { field };
        Self::Aggregate(expr)
    }

    /// Aggregate expression summing `field`.
    pub fn sum(field: &str) -> Self {
        let field = field.to_owned();
        let expr = crate::stored_procedure::types::MongoAggregateExpression::Sum { field };
        Self::Aggregate(expr)
    }

    /// Aggregate expression averaging `field`.
    pub fn avg(field: &str) -> Self {
        let field = field.to_owned();
        let expr = crate::stored_procedure::types::MongoAggregateExpression::Avg { field };
        Self::Aggregate(expr)
    }

    /// Aggregate expression taking the maximum of `field`.
    pub fn max(field: &str) -> Self {
        let field = field.to_owned();
        let expr = crate::stored_procedure::types::MongoAggregateExpression::Max { field };
        Self::Aggregate(expr)
    }

    /// Aggregate expression taking the minimum of `field`.
    pub fn min(field: &str) -> Self {
        let field = field.to_owned();
        let expr = crate::stored_procedure::types::MongoAggregateExpression::Min { field };
        Self::Aggregate(expr)
    }

    /// Aggregate expression that yields `field`, falling back to `default` (boxed) when it is null.
    pub fn if_null(
        field: &str,
        default: crate::stored_procedure::types::MongoFieldExpression,
    ) -> Self {
        let fallback = Box::new(default);
        let expr = crate::stored_procedure::types::MongoAggregateExpression::IfNull {
            field: field.to_owned(),
            default: fallback,
        };
        Self::Aggregate(expr)
    }
}
/// Convenience constructors for MongoDB match conditions.
impl crate::stored_procedure::types::MongoCondition {
    /// Condition: `field` equals `value`.
    pub fn eq(field: &str, value: crate::types::DataValue) -> Self {
        let field = field.to_owned();
        Self::Eq { field, value }
    }

    /// Condition: `field` differs from `value`.
    pub fn ne(field: &str, value: crate::types::DataValue) -> Self {
        let field = field.to_owned();
        Self::Ne { field, value }
    }

    /// Condition: `field` is greater than `value`.
    pub fn gt(field: &str, value: crate::types::DataValue) -> Self {
        let field = field.to_owned();
        Self::Gt { field, value }
    }

    /// Condition: `field` is greater than or equal to `value`.
    pub fn gte(field: &str, value: crate::types::DataValue) -> Self {
        let field = field.to_owned();
        Self::Gte { field, value }
    }

    /// Condition: `field` is less than `value`.
    pub fn lt(field: &str, value: crate::types::DataValue) -> Self {
        let field = field.to_owned();
        Self::Lt { field, value }
    }

    /// Condition: `field` is less than or equal to `value`.
    pub fn lte(field: &str, value: crate::types::DataValue) -> Self {
        let field = field.to_owned();
        Self::Lte { field, value }
    }

    /// Logical conjunction of `conditions`.
    pub fn and(conditions: Vec<Self>) -> Self {
        Self::And { conditions }
    }

    /// Logical disjunction of `conditions`.
    pub fn or(conditions: Vec<Self>) -> Self {
        Self::Or { conditions }
    }

    /// Condition on whether `field` is present (`exists == true`) or absent.
    pub fn exists(field: &str, exists: bool) -> Self {
        let field = field.to_owned();
        Self::Exists { field, exists }
    }

    /// Condition: `field` matches the regular expression `pattern`.
    pub fn regex(field: &str, pattern: &str) -> Self {
        let (field, pattern) = (field.to_owned(), pattern.to_owned());
        Self::Regex { field, pattern }
    }
}
impl StoredProcedureConfig {
    /// Shorthand for [`StoredProcedureBuilder::new`].
    pub fn builder(name: &str, database: &str) -> StoredProcedureBuilder {
        StoredProcedureBuilder::new(name, database)
    }

    /// Checks that this configuration is complete and usable on its target database.
    ///
    /// Checks run in this order:
    /// 1. The procedure name is non-empty.
    /// 2. The database alias is non-empty.
    /// 3. The config is compatible with the database type registered for that alias.
    /// 4. At least one field mapping or a non-empty aggregation pipeline is present.
    /// 5. Every JOIN has both a local and a foreign field.
    ///
    /// # Errors
    ///
    /// Returns `QuickDbError::ValidationError` for the first check that fails.
    pub fn validate(&self) -> QuickDbResult<()> {
        if self.procedure_name.is_empty() {
            return Err(crate::error::QuickDbError::ValidationError {
                field: "procedure_name".to_string(),
                message: "存储过程名称不能为空".to_string(),
            });
        }
        if self.database.is_empty() {
            return Err(crate::error::QuickDbError::ValidationError {
                field: "database".to_string(),
                message: "数据库别名不能为空".to_string(),
            });
        }
        self.validate_database_type_compatibility()?;
        // An empty pipeline (e.g. `with_mongo_aggregation().done()` with no stages) has no
        // operations, so it must not satisfy the "at least one operation" requirement.
        if !self.has_mongo_pipeline_stages() && self.fields.is_empty() {
            return Err(crate::error::QuickDbError::ValidationError {
                field: "fields".to_string(),
                message: "至少需要一个字段或聚合管道操作".to_string(),
            });
        }
        let has_incomplete_join = self
            .joins
            .iter()
            .any(|join| join.local_field.is_empty() || join.foreign_field.is_empty());
        if has_incomplete_join {
            return Err(crate::error::QuickDbError::ValidationError {
                field: "join_fields".to_string(),
                message: "JOIN字段不能为空".to_string(),
            });
        }
        Ok(())
    }

    /// Returns `true` only when a MongoDB pipeline is set and contains at least one stage.
    fn has_mongo_pipeline_stages(&self) -> bool {
        matches!(&self.mongo_pipeline, Some(stages) if !stages.is_empty())
    }

    /// Checks this config against the type of the database registered under `self.database`.
    ///
    /// The type is looked up through the global pool manager.
    /// - MongoDB: requires field mappings or a non-empty pipeline, and logs a warning when
    ///   more than one JOIN is configured.
    /// - SQLite/MySQL/PostgreSQL: rejects any MongoDB pipeline and requires field mappings.
    ///
    /// # Errors
    ///
    /// Returns `QuickDbError::ValidationError` if the alias cannot be resolved or a rule fails.
    fn validate_database_type_compatibility(&self) -> QuickDbResult<()> {
        use crate::manager::get_global_pool_manager;
        let db_type = get_global_pool_manager()
            .get_database_type(&self.database)
            .map_err(|_| crate::error::QuickDbError::ValidationError {
                field: "database".to_string(),
                message: format!("数据库别名 '{}' 不存在", self.database),
            })?;
        match db_type {
            crate::types::DatabaseType::MongoDB => {
                if !self.has_mongo_pipeline_stages() && self.fields.is_empty() {
                    return Err(crate::error::QuickDbError::ValidationError {
                        field: "mongo_config".to_string(),
                        message: "MongoDB存储过程必须使用聚合管道或字段映射".to_string(),
                    });
                }
                if self.joins.len() > 1 {
                    rat_logger::warn!(
                        "警告:MongoDB对复杂JOIN支持有限,建议使用聚合管道中的$lookup操作"
                    );
                }
            }
            crate::types::DatabaseType::SQLite
            | crate::types::DatabaseType::MySQL
            | crate::types::DatabaseType::PostgreSQL => {
                // Human-readable name used in both error messages below.
                let db_label = match db_type {
                    crate::types::DatabaseType::SQLite => "SQLite",
                    crate::types::DatabaseType::MySQL => "MySQL",
                    crate::types::DatabaseType::PostgreSQL => "PostgreSQL",
                    _ => "该数据库",
                };
                // Any pipeline, even an empty one, is a config error for SQL backends.
                if self.mongo_pipeline.is_some() {
                    return Err(crate::error::QuickDbError::ValidationError {
                        field: "mongo_pipeline".to_string(),
                        message: format!(
                            "{} 不支持MongoDB聚合管道,请使用传统字段映射和JOIN配置",
                            db_label
                        ),
                    });
                }
                if self.fields.is_empty() {
                    return Err(crate::error::QuickDbError::ValidationError {
                        field: "fields".to_string(),
                        message: format!("{} 存储过程必须定义字段映射", db_label),
                    });
                }
            }
        }
        Ok(())
    }
}