use std::fmt;
use std::marker::PhantomData;
use serde_json::Value;
use sqlx::Row;
use crate::error::ModelResult;
use crate::model::Model;
/// Comparison operators usable in a `WHERE` or `HAVING` condition.
///
/// The `Display` impl yields the SQL token written between the column and
/// its operand(s).
#[derive(Debug, Clone, PartialEq)]
pub enum QueryOperator {
    /// `=`
    Equal,
    /// `!=`
    NotEqual,
    /// `>`
    GreaterThan,
    /// `>=`
    GreaterThanOrEqual,
    /// `<`
    LessThan,
    /// `<=`
    LessThanOrEqual,
    /// `LIKE`
    Like,
    /// `NOT LIKE`
    NotLike,
    /// `IN`
    In,
    /// `NOT IN`
    NotIn,
    /// `IS NULL`
    IsNull,
    /// `IS NOT NULL`
    IsNotNull,
    /// `BETWEEN`
    Between,
}

impl fmt::Display for QueryOperator {
    /// Writes the SQL spelling of the operator.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let token = match self {
            Self::Equal => "=",
            Self::NotEqual => "!=",
            Self::GreaterThan => ">",
            Self::GreaterThanOrEqual => ">=",
            Self::LessThan => "<",
            Self::LessThanOrEqual => "<=",
            Self::Like => "LIKE",
            Self::NotLike => "NOT LIKE",
            Self::In => "IN",
            Self::NotIn => "NOT IN",
            Self::IsNull => "IS NULL",
            Self::IsNotNull => "IS NOT NULL",
            Self::Between => "BETWEEN",
        };
        f.write_str(token)
    }
}
/// A single predicate of a `WHERE` or `HAVING` clause.
///
/// Operators with one right-hand operand read it from `value`; `IN`/`NOT IN`
/// lists and the two `BETWEEN` bounds are carried in `values`. The query
/// builder also uses sentinel `column` names (e.g. `"RAW"`) to carry
/// pre-rendered SQL text in `value`.
#[derive(Debug, Clone)]
pub struct WhereCondition {
    /// Left-hand column or expression; inserted into the SQL as-is.
    pub column: String,
    /// Operator relating `column` to its operand(s).
    pub operator: QueryOperator,
    /// Single right-hand operand, when the operator takes one.
    pub value: Option<Value>,
    /// Multiple right-hand operands (`IN` lists, `BETWEEN` bounds).
    pub values: Vec<Value>,
}
/// Kind of SQL join a [`JoinClause`] renders as.
#[derive(Debug, Clone, PartialEq)]
pub enum JoinType {
    /// `INNER JOIN`
    Inner,
    /// `LEFT JOIN`
    Left,
    /// `RIGHT JOIN`
    Right,
    /// `FULL JOIN`
    Full,
}

impl fmt::Display for JoinType {
    /// Writes the join keyword pair, e.g. `LEFT JOIN`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let kind = match self {
            JoinType::Inner => "INNER",
            JoinType::Left => "LEFT",
            JoinType::Right => "RIGHT",
            JoinType::Full => "FULL",
        };
        write!(f, "{} JOIN", kind)
    }
}

/// A join against `table`.
///
/// Each `(left, right)` pair in `on_conditions` is an equality between two
/// column references.
#[derive(Debug, Clone)]
pub struct JoinClause {
    /// Which kind of join to emit.
    pub join_type: JoinType,
    /// Joined table name, inserted into the SQL as-is.
    pub table: String,
    /// `(left_column, right_column)` equality pairs for the `ON` clause.
    pub on_conditions: Vec<(String, String)>,
}
/// Sort direction of an `ORDER BY` entry.
#[derive(Debug, Clone, PartialEq)]
pub enum OrderDirection {
    /// Ascending (`ASC`).
    Asc,
    /// Descending (`DESC`).
    Desc,
}

impl fmt::Display for OrderDirection {
    /// Writes `ASC` or `DESC`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let keyword = if matches!(self, OrderDirection::Asc) {
            "ASC"
        } else {
            "DESC"
        };
        f.write_str(keyword)
    }
}

/// One `ORDER BY` entry: a column paired with its sort direction.
#[derive(Debug, Clone)]
pub struct OrderByClause {
    /// Column or expression to sort by, inserted into the SQL as-is.
    pub column: String,
    /// Direction for this column.
    pub direction: OrderDirection,
}
/// Fluent builder that assembles a `SELECT` statement.
///
/// `M` is the model type that the execution methods decode rows into. It
/// defaults to `()` for builders that are only rendered to SQL, and it is
/// carried purely at the type level through `PhantomData`.
#[derive(Debug)]
pub struct QueryBuilder<M = ()> {
    /// Expressions of the `SELECT` list; an empty list renders as `*`.
    select_fields: Vec<String>,
    /// Table named in the `FROM` clause, if any.
    from_table: Option<String>,
    /// Predicates joined with `AND` in the `WHERE` clause.
    where_conditions: Vec<WhereCondition>,
    /// Join clauses, emitted in insertion order.
    joins: Vec<JoinClause>,
    /// `ORDER BY` entries, emitted in insertion order.
    order_by: Vec<OrderByClause>,
    /// `GROUP BY` columns.
    group_by: Vec<String>,
    /// Predicates joined with `AND` in the `HAVING` clause.
    having_conditions: Vec<WhereCondition>,
    /// `LIMIT` value, when set.
    limit_value: Option<i64>,
    /// `OFFSET` value, when set.
    offset_value: Option<i64>,
    /// Emit `SELECT DISTINCT` instead of `SELECT`.
    distinct: bool,
    /// Binds the builder to `M` without storing an `M`.
    _phantom: PhantomData<M>,
}
/// Field-by-field copy of a `QueryBuilder`.
///
/// Written by hand because `#[derive(Clone)]` would add an `M: Clone` bound,
/// while `M` is only a `PhantomData` marker in this type.
impl<M> Clone for QueryBuilder<M> {
    fn clone(&self) -> Self {
        let Self {
            select_fields,
            from_table,
            where_conditions,
            joins,
            order_by,
            group_by,
            having_conditions,
            limit_value,
            offset_value,
            distinct,
            _phantom: _,
        } = self;
        Self {
            select_fields: select_fields.clone(),
            from_table: from_table.clone(),
            where_conditions: where_conditions.clone(),
            joins: joins.clone(),
            order_by: order_by.clone(),
            group_by: group_by.clone(),
            having_conditions: having_conditions.clone(),
            limit_value: *limit_value,
            offset_value: *offset_value,
            distinct: *distinct,
            _phantom: PhantomData,
        }
    }
}
impl<M> Default for QueryBuilder<M> {
fn default() -> Self {
Self::new()
}
}
impl<M> QueryBuilder<M> {
pub fn new() -> Self {
Self {
select_fields: Vec::new(),
from_table: None,
where_conditions: Vec::new(),
joins: Vec::new(),
order_by: Vec::new(),
group_by: Vec::new(),
having_conditions: Vec::new(),
limit_value: None,
offset_value: None,
distinct: false,
_phantom: PhantomData,
}
}
pub fn select(mut self, fields: &str) -> Self {
if fields == "*" {
self.select_fields.push("*".to_string());
} else {
self.select_fields.extend(
fields
.split(',')
.map(|f| f.trim().to_string())
.collect::<Vec<String>>()
);
}
self
}
pub fn select_distinct(mut self, fields: &str) -> Self {
self.distinct = true;
self.select(fields)
}
pub fn from(mut self, table: &str) -> Self {
self.from_table = Some(table.to_string());
self
}
pub fn where_eq<T: Into<Value>>(mut self, column: &str, value: T) -> Self {
self.where_conditions.push(WhereCondition {
column: column.to_string(),
operator: QueryOperator::Equal,
value: Some(value.into()),
values: Vec::new(),
});
self
}
pub fn where_ne<T: Into<Value>>(mut self, column: &str, value: T) -> Self {
self.where_conditions.push(WhereCondition {
column: column.to_string(),
operator: QueryOperator::NotEqual,
value: Some(value.into()),
values: Vec::new(),
});
self
}
pub fn where_gt<T: Into<Value>>(mut self, column: &str, value: T) -> Self {
self.where_conditions.push(WhereCondition {
column: column.to_string(),
operator: QueryOperator::GreaterThan,
value: Some(value.into()),
values: Vec::new(),
});
self
}
pub fn where_gte<T: Into<Value>>(mut self, column: &str, value: T) -> Self {
self.where_conditions.push(WhereCondition {
column: column.to_string(),
operator: QueryOperator::GreaterThanOrEqual,
value: Some(value.into()),
values: Vec::new(),
});
self
}
pub fn where_lt<T: Into<Value>>(mut self, column: &str, value: T) -> Self {
self.where_conditions.push(WhereCondition {
column: column.to_string(),
operator: QueryOperator::LessThan,
value: Some(value.into()),
values: Vec::new(),
});
self
}
pub fn where_lte<T: Into<Value>>(mut self, column: &str, value: T) -> Self {
self.where_conditions.push(WhereCondition {
column: column.to_string(),
operator: QueryOperator::LessThanOrEqual,
value: Some(value.into()),
values: Vec::new(),
});
self
}
pub fn where_like(mut self, column: &str, pattern: &str) -> Self {
self.where_conditions.push(WhereCondition {
column: column.to_string(),
operator: QueryOperator::Like,
value: Some(Value::String(pattern.to_string())),
values: Vec::new(),
});
self
}
pub fn where_not_like(mut self, column: &str, pattern: &str) -> Self {
self.where_conditions.push(WhereCondition {
column: column.to_string(),
operator: QueryOperator::NotLike,
value: Some(Value::String(pattern.to_string())),
values: Vec::new(),
});
self
}
pub fn where_in<T: Into<Value>>(mut self, column: &str, values: Vec<T>) -> Self {
self.where_conditions.push(WhereCondition {
column: column.to_string(),
operator: QueryOperator::In,
value: None,
values: values.into_iter().map(|v| v.into()).collect(),
});
self
}
pub fn where_not_in<T: Into<Value>>(mut self, column: &str, values: Vec<T>) -> Self {
self.where_conditions.push(WhereCondition {
column: column.to_string(),
operator: QueryOperator::NotIn,
value: None,
values: values.into_iter().map(|v| v.into()).collect(),
});
self
}
pub fn where_null(mut self, column: &str) -> Self {
self.where_conditions.push(WhereCondition {
column: column.to_string(),
operator: QueryOperator::IsNull,
value: None,
values: Vec::new(),
});
self
}
pub fn where_not_null(mut self, column: &str) -> Self {
self.where_conditions.push(WhereCondition {
column: column.to_string(),
operator: QueryOperator::IsNotNull,
value: None,
values: Vec::new(),
});
self
}
pub fn where_between<T: Into<Value>>(mut self, column: &str, start: T, end: T) -> Self {
self.where_conditions.push(WhereCondition {
column: column.to_string(),
operator: QueryOperator::Between,
value: None,
values: vec![start.into(), end.into()],
});
self
}
pub fn join(mut self, table: &str, left_col: &str, right_col: &str) -> Self {
self.joins.push(JoinClause {
join_type: JoinType::Inner,
table: table.to_string(),
on_conditions: vec![(left_col.to_string(), right_col.to_string())],
});
self
}
pub fn left_join(mut self, table: &str, left_col: &str, right_col: &str) -> Self {
self.joins.push(JoinClause {
join_type: JoinType::Left,
table: table.to_string(),
on_conditions: vec![(left_col.to_string(), right_col.to_string())],
});
self
}
pub fn right_join(mut self, table: &str, left_col: &str, right_col: &str) -> Self {
self.joins.push(JoinClause {
join_type: JoinType::Right,
table: table.to_string(),
on_conditions: vec![(left_col.to_string(), right_col.to_string())],
});
self
}
pub fn order_by(mut self, column: &str) -> Self {
self.order_by.push(OrderByClause {
column: column.to_string(),
direction: OrderDirection::Asc,
});
self
}
pub fn order_by_desc(mut self, column: &str) -> Self {
self.order_by.push(OrderByClause {
column: column.to_string(),
direction: OrderDirection::Desc,
});
self
}
pub fn group_by(mut self, column: &str) -> Self {
self.group_by.push(column.to_string());
self
}
pub fn having_eq<T: Into<Value>>(mut self, column: &str, value: T) -> Self {
self.having_conditions.push(WhereCondition {
column: column.to_string(),
operator: QueryOperator::Equal,
value: Some(value.into()),
values: Vec::new(),
});
self
}
pub fn select_count(mut self, column: &str, alias: Option<&str>) -> Self {
let select_expr = if let Some(alias) = alias {
format!("COUNT({}) AS {}", column, alias)
} else {
format!("COUNT({})", column)
};
self.select_fields.push(select_expr);
self
}
pub fn select_sum(mut self, column: &str, alias: Option<&str>) -> Self {
let select_expr = if let Some(alias) = alias {
format!("SUM({}) AS {}", column, alias)
} else {
format!("SUM({})", column)
};
self.select_fields.push(select_expr);
self
}
pub fn select_avg(mut self, column: &str, alias: Option<&str>) -> Self {
let select_expr = if let Some(alias) = alias {
format!("AVG({}) AS {}", column, alias)
} else {
format!("AVG({})", column)
};
self.select_fields.push(select_expr);
self
}
pub fn select_min(mut self, column: &str, alias: Option<&str>) -> Self {
let select_expr = if let Some(alias) = alias {
format!("MIN({}) AS {}", column, alias)
} else {
format!("MIN({})", column)
};
self.select_fields.push(select_expr);
self
}
pub fn select_max(mut self, column: &str, alias: Option<&str>) -> Self {
let select_expr = if let Some(alias) = alias {
format!("MAX({}) AS {}", column, alias)
} else {
format!("MAX({})", column)
};
self.select_fields.push(select_expr);
self
}
pub fn select_raw(mut self, expression: &str) -> Self {
self.select_fields.push(expression.to_string());
self
}
pub fn limit(mut self, count: i64) -> Self {
self.limit_value = Some(count);
self
}
pub fn offset(mut self, count: i64) -> Self {
self.offset_value = Some(count);
self
}
pub fn paginate(mut self, per_page: i64, page: i64) -> Self {
self.limit_value = Some(per_page);
self.offset_value = Some((page - 1) * per_page);
self
}
pub fn paginate_cursor<T: Into<Value>>(mut self, cursor_column: &str, cursor_value: Option<T>, per_page: i64, direction: OrderDirection) -> Self {
self.limit_value = Some(per_page);
if let Some(cursor_val) = cursor_value {
match direction {
OrderDirection::Asc => {
self = self.where_gt(cursor_column, cursor_val);
}
OrderDirection::Desc => {
self = self.where_lt(cursor_column, cursor_val);
}
}
}
self.order_by.push(OrderByClause {
column: cursor_column.to_string(),
direction,
});
self
}
pub fn union(self, _other_query: QueryBuilder<M>) -> Self {
self
}
pub fn union_all(self, _other_query: QueryBuilder<M>) -> Self {
self
}
pub fn where_subquery<T: Into<Value>>(mut self, column: &str, operator: QueryOperator, subquery: QueryBuilder<M>) -> Self {
let subquery_sql = subquery.to_sql();
let formatted_value = format!("({})", subquery_sql);
self.where_conditions.push(WhereCondition {
column: column.to_string(),
operator,
value: Some(Value::String(formatted_value)),
values: Vec::new(),
});
self
}
pub fn where_exists(mut self, subquery: QueryBuilder<M>) -> Self {
self.where_conditions.push(WhereCondition {
column: "EXISTS".to_string(),
operator: QueryOperator::Equal,
value: Some(Value::String(format!("({})", subquery.to_sql()))),
values: Vec::new(),
});
self
}
pub fn where_not_exists(mut self, subquery: QueryBuilder<M>) -> Self {
self.where_conditions.push(WhereCondition {
column: "NOT EXISTS".to_string(),
operator: QueryOperator::Equal,
value: Some(Value::String(format!("({})", subquery.to_sql()))),
values: Vec::new(),
});
self
}
pub fn where_raw(mut self, raw_condition: &str) -> Self {
self.where_conditions.push(WhereCondition {
column: "RAW".to_string(),
operator: QueryOperator::Equal,
value: Some(Value::String(raw_condition.to_string())),
values: Vec::new(),
});
self
}
pub fn or_where<F>(mut self, closure: F) -> Self
where
F: FnOnce(QueryBuilder<M>) -> QueryBuilder<M>,
{
let inner_query = closure(QueryBuilder::new());
self.where_conditions.extend(inner_query.where_conditions);
self
}
pub fn to_sql(&self) -> String {
let mut sql = String::new();
if self.distinct {
sql.push_str("SELECT DISTINCT ");
} else {
sql.push_str("SELECT ");
}
if self.select_fields.is_empty() {
sql.push('*');
} else {
sql.push_str(&self.select_fields.join(", "));
}
if let Some(table) = &self.from_table {
sql.push_str(&format!(" FROM {}", table));
}
for join in &self.joins {
sql.push_str(&format!(" {} {}", join.join_type, join.table));
if !join.on_conditions.is_empty() {
sql.push_str(" ON ");
let conditions: Vec<String> = join
.on_conditions
.iter()
.map(|(left, right)| format!("{} = {}", left, right))
.collect();
sql.push_str(&conditions.join(" AND "));
}
}
if !self.where_conditions.is_empty() {
sql.push_str(" WHERE ");
let conditions = self.build_where_conditions(&self.where_conditions);
sql.push_str(&conditions.join(" AND "));
}
if !self.group_by.is_empty() {
sql.push_str(&format!(" GROUP BY {}", self.group_by.join(", ")));
}
if !self.having_conditions.is_empty() {
sql.push_str(" HAVING ");
let conditions = self.build_where_conditions(&self.having_conditions);
sql.push_str(&conditions.join(" AND "));
}
if !self.order_by.is_empty() {
sql.push_str(" ORDER BY ");
let order_clauses: Vec<String> = self
.order_by
.iter()
.map(|clause| format!("{} {}", clause.column, clause.direction))
.collect();
sql.push_str(&order_clauses.join(", "));
}
if let Some(limit) = self.limit_value {
sql.push_str(&format!(" LIMIT {}", limit));
}
if let Some(offset) = self.offset_value {
sql.push_str(&format!(" OFFSET {}", offset));
}
sql
}
fn build_where_conditions(&self, conditions: &[WhereCondition]) -> Vec<String> {
conditions
.iter()
.map(|condition| {
if condition.column == "RAW" {
if let Some(Value::String(raw_sql)) = &condition.value {
return raw_sql.clone();
}
}
if condition.column == "EXISTS" || condition.column == "NOT EXISTS" {
if let Some(Value::String(subquery)) = &condition.value {
return format!("{} {}", condition.column, subquery);
}
}
match &condition.operator {
QueryOperator::IsNull | QueryOperator::IsNotNull => {
format!("{} {}", condition.column, condition.operator)
}
QueryOperator::In | QueryOperator::NotIn => {
if let Some(Value::String(subquery)) = &condition.value {
if subquery.starts_with('(') && subquery.ends_with(')') {
format!("{} {} {}", condition.column, condition.operator, subquery)
} else {
format!("{} {} ({})", condition.column, condition.operator, self.format_value(&condition.value.as_ref().unwrap()))
}
} else {
let values: Vec<String> = condition
.values
.iter()
.map(|v| self.format_value(v))
.collect();
format!("{} {} ({})", condition.column, condition.operator, values.join(", "))
}
}
QueryOperator::Between => {
if condition.values.len() == 2 {
format!(
"{} BETWEEN {} AND {}",
condition.column,
self.format_value(&condition.values[0]),
self.format_value(&condition.values[1])
)
} else {
format!("{} = NULL", condition.column) }
}
_ => {
if let Some(value) = &condition.value {
if let Value::String(val_str) = value {
if val_str.starts_with('(') && val_str.ends_with(')') {
format!("{} {} {}", condition.column, condition.operator, val_str)
} else {
format!("{} {} {}", condition.column, condition.operator, self.format_value(value))
}
} else {
format!("{} {} {}", condition.column, condition.operator, self.format_value(value))
}
} else {
format!("{} = NULL", condition.column) }
}
}
})
.collect()
}
fn format_value(&self, value: &Value) -> String {
match value {
Value::String(s) => format!("'{}'", s.replace('\'', "''")), Value::Number(n) => n.to_string(),
Value::Bool(b) => b.to_string(),
Value::Null => "NULL".to_string(),
_ => "NULL".to_string(), }
}
pub fn bindings(&self) -> Vec<Value> {
let mut bindings = Vec::new();
for condition in &self.where_conditions {
if matches!(condition.column.as_str(), "RAW" | "EXISTS" | "NOT EXISTS") {
continue;
}
if let Some(value) = &condition.value {
if let Value::String(val_str) = value {
if !val_str.starts_with('(') || !val_str.ends_with(')') {
bindings.push(value.clone());
}
} else {
bindings.push(value.clone());
}
}
bindings.extend(condition.values.clone());
}
for condition in &self.having_conditions {
if let Some(value) = &condition.value {
bindings.push(value.clone());
}
bindings.extend(condition.values.clone());
}
bindings
}
pub fn clone_for_subquery(&self) -> Self {
self.clone()
}
pub fn optimize(self) -> Self {
self
}
pub fn complexity_score(&self) -> u32 {
let mut score = 0;
score += self.where_conditions.len() as u32;
score += self.joins.len() as u32 * 2; score += self.group_by.len() as u32;
score += self.having_conditions.len() as u32;
if self.distinct {
score += 1;
}
score
}
}
impl<M: Model> QueryBuilder<M> {
pub async fn get(self, pool: &sqlx::Pool<sqlx::Postgres>) -> ModelResult<Vec<M>> {
let sql = self.to_sql();
let rows = sqlx::query(&sql)
.fetch_all(pool)
.await?;
let mut models = Vec::new();
for row in rows {
models.push(M::from_row(&row)?);
}
Ok(models)
}
pub async fn chunk<F>(
mut self,
pool: &sqlx::Pool<sqlx::Postgres>,
chunk_size: i64,
mut callback: F
) -> ModelResult<()>
where
F: FnMut(Vec<M>) -> Result<(), crate::error::ModelError>,
{
let mut offset = 0;
loop {
let chunk_query = self.clone()
.limit(chunk_size)
.offset(offset);
let chunk = chunk_query.get(pool).await?;
if chunk.is_empty() {
break;
}
callback(chunk)?;
offset += chunk_size;
}
Ok(())
}
pub async fn get_raw(self, pool: &sqlx::Pool<sqlx::Postgres>) -> ModelResult<Vec<serde_json::Value>> {
let sql = self.to_sql();
let rows = sqlx::query(&sql)
.fetch_all(pool)
.await?;
let mut results = Vec::new();
for row in rows {
let mut json_row = serde_json::Map::new();
for i in 0..row.len() {
if let Ok(column) = row.try_get::<Option<String>, _>(i) {
let column_name = format!("column_{}", i); json_row.insert(column_name, serde_json::Value::String(column.unwrap_or_default()));
}
}
results.push(serde_json::Value::Object(json_row));
}
Ok(results)
}
pub async fn first(self, pool: &sqlx::Pool<sqlx::Postgres>) -> ModelResult<Option<M>> {
let query = self.limit(1);
let mut results = query.get(pool).await?;
Ok(results.pop())
}
pub async fn first_or_fail(self, pool: &sqlx::Pool<sqlx::Postgres>) -> ModelResult<M> {
self.first(pool)
.await?
.ok_or_else(|| crate::error::ModelError::NotFound(M::table_name().to_string()))
}
pub async fn count(mut self, pool: &sqlx::Pool<sqlx::Postgres>) -> ModelResult<i64> {
self.select_fields = vec!["COUNT(*)".to_string()];
let sql = self.to_sql();
let row = sqlx::query(&sql)
.fetch_one(pool)
.await?;
let count: i64 = row.try_get(0)?;
Ok(count)
}
pub async fn aggregate(self, pool: &sqlx::Pool<sqlx::Postgres>) -> ModelResult<Option<serde_json::Value>> {
let sql = self.to_sql();
let row_opt = sqlx::query(&sql)
.fetch_optional(pool)
.await?;
if let Some(row) = row_opt {
if let Ok(result) = row.try_get::<Option<i64>, _>(0) {
return Ok(Some(serde_json::Value::Number(serde_json::Number::from(result.unwrap_or(0)))));
} else if let Ok(result) = row.try_get::<Option<f64>, _>(0) {
return Ok(Some(serde_json::Number::from_f64(result.unwrap_or(0.0)).map(serde_json::Value::Number).unwrap_or(serde_json::Value::Null)));
} else if let Ok(result) = row.try_get::<Option<String>, _>(0) {
return Ok(Some(serde_json::Value::String(result.unwrap_or_default())));
}
}
Ok(None)
}
}