use std::fmt::Write as _;
use crate::core::{
AggregateExpr, AggregateQuery, BulkInsertQuery, BulkUpdateQuery, CountQuery, DeleteQuery,
Filter, InsertQuery, ModelSchema, Op, OrderClause, SearchClause, SelectQuery, SqlValue,
UpdateQuery, WhereExpr,
};
use super::{CompiledStatement, Dialect, SqlError};
/// Incremental SQL statement builder.
///
/// SQL text accumulates in `sql` and bound values in `params`. Each placeholder
/// written by [`Sql::push_param_typed`] is numbered by the 1-based position of
/// the value pushed immediately before it.
// The `sql` field shares the struct's name; the lint is silenced rather than
// renaming a field that every writer in this module touches.
#[allow(clippy::struct_field_names)]
pub(super) struct Sql<'d> {
    /// Dialect used for identifier quoting, placeholders and casts.
    pub d: &'d dyn Dialect,
    /// SQL text written so far.
    pub sql: String,
    /// Bound parameter values, in placeholder order.
    pub params: Vec<SqlValue>,
}
impl<'d> Sql<'d> {
    /// Creates an empty builder for dialect `d`.
    pub(super) fn new(d: &'d dyn Dialect) -> Self {
        Self::with_capacity(d, 0)
    }
    /// Creates an empty builder whose parameter list is pre-sized for `cap` values.
    ///
    /// Only `params` is pre-allocated; the SQL text buffer starts empty.
    pub(super) fn with_capacity(d: &'d dyn Dialect, cap: usize) -> Self {
        let params = Vec::with_capacity(cap);
        Self {
            d,
            sql: String::new(),
            params,
        }
    }
    /// Appends `name` quoted as an identifier by the dialect.
    pub(super) fn write_ident(&mut self, name: &str) {
        let quoted = self.d.quote_ident(name);
        self.sql.push_str(&quoted);
    }
    /// Binds `value` and appends its placeholder.
    ///
    /// When `value` is `SqlValue::Null` and `cast` is `Some(ty)`, the placeholder
    /// is followed by `::ty`. Non-null values never receive a cast.
    pub(super) fn push_param_typed(&mut self, value: SqlValue, cast: Option<&'static str>) {
        let null_cast = match value {
            SqlValue::Null => cast,
            _ => None,
        };
        self.params.push(value);
        let placeholder = self.d.placeholder(self.params.len());
        self.sql.push_str(&placeholder);
        if let Some(ty) = null_cast {
            self.sql.push_str("::");
            self.sql.push_str(ty);
        }
    }
    /// Binds `value` and appends its placeholder without any null cast.
    pub(super) fn push_param(&mut self, value: SqlValue) {
        self.push_param_typed(value, None);
    }
    /// Consumes the builder, yielding the finished SQL text and its parameters.
    pub(super) fn finish(self) -> CompiledStatement {
        let Self { sql, params, .. } = self;
        CompiledStatement { sql, params }
    }
}
/// Looks up the dialect's null-cast type name for `column` of `model`.
///
/// Returns `None` when `column` is not a field of `model`, or when the dialect
/// reports no cast for that field's type.
pub(super) fn null_cast_for(
    d: &dyn Dialect,
    model: &ModelSchema,
    column: &str,
) -> Option<&'static str> {
    model
        .field_by_column(column)
        .and_then(|field| d.null_cast(field.ty))
}
/// Writes a `SELECT` statement. The select list contains the model's scalar
/// columns plus each LEFT JOIN's projected columns. WHERE/search and
/// ORDER BY/LIMIT/OFFSET follow.
///
/// When the query has joins, base-table columns are qualified with the base
/// table name. This applies in the select list, WHERE, search and ORDER BY.
/// Each joined column is aliased as `<alias>__<column>`.
///
/// # Errors
/// Propagates errors from writing the WHERE/search clause.
pub(super) fn write_select(b: &mut Sql<'_>, query: &SelectQuery) -> Result<(), SqlError> {
    let model = query.model;
    let base = model.table;
    // Qualify base-table columns only when joins could make them ambiguous.
    let qualifier = if query.joins.is_empty() {
        None
    } else {
        Some(base)
    };

    b.sql.push_str("SELECT ");
    let mut sep = "";
    for field in model.scalar_fields() {
        b.sql.push_str(sep);
        sep = ", ";
        if let Some(table) = qualifier {
            b.write_ident(table);
            b.sql.push('.');
        }
        b.write_ident(field.column);
    }
    for join in &query.joins {
        let alias = join.alias;
        for col in &join.project {
            b.sql.push_str(", ");
            b.write_ident(alias);
            b.sql.push('.');
            b.write_ident(col);
            b.sql.push_str(" AS ");
            b.write_ident(&format!("{alias}__{col}"));
        }
    }

    b.sql.push_str(" FROM ");
    b.write_ident(base);
    for join in &query.joins {
        b.sql.push_str(" LEFT JOIN ");
        b.write_ident(join.target.table);
        b.sql.push_str(" AS ");
        b.write_ident(join.alias);
        b.sql.push_str(" ON ");
        b.write_ident(base);
        b.sql.push('.');
        b.write_ident(join.on_local);
        b.sql.push_str(" = ");
        b.write_ident(join.alias);
        b.sql.push('.');
        b.write_ident(join.on_remote);
    }

    write_where_with_search(
        b,
        &query.where_clause,
        query.search.as_ref(),
        qualifier,
        Some(model),
    )?;
    write_order_limit_offset(b, &query.order_by, query.limit, query.offset, qualifier);
    Ok(())
}
/// Writes `SELECT COUNT(*) FROM <table>` followed by the optional WHERE/search clause.
///
/// Columns are never table-qualified here.
///
/// # Errors
/// Propagates errors from writing the WHERE/search clause.
pub(super) fn write_count(b: &mut Sql<'_>, query: &CountQuery) -> Result<(), SqlError> {
    let model = query.model;
    b.sql.push_str("SELECT COUNT(*) FROM ");
    b.write_ident(model.table);
    write_where_with_search(b, &query.where_clause, query.search.as_ref(), None, Some(model))
}
/// Writes an aggregate `SELECT`. The select list holds the group-by columns,
/// then each aggregate expression aliased by its name. WHERE, GROUP BY, HAVING
/// and ORDER BY/LIMIT/OFFSET follow. Columns are never table-qualified.
///
/// `SUM` is wrapped by the dialect's `cast_aggregate_to_int` and `AVG` by
/// `cast_aggregate_to_float`. `COUNT`, `MAX` and `MIN` are written unwrapped.
///
/// An empty HAVING expression is skipped entirely, mirroring how `write_where`
/// omits an empty WHERE.
///
/// # Errors
/// Propagates errors from writing the WHERE and HAVING expressions.
pub(super) fn write_aggregate(b: &mut Sql<'_>, query: &AggregateQuery) -> Result<(), SqlError> {
    // NOTE(review): with no group-by columns and no aggregates this emits
    // `SELECT  FROM ...`, which is invalid SQL. Confirm callers never build such
    // a query, or reject it with a dedicated error.
    b.sql.push_str("SELECT ");
    for (i, col) in query.group_by.iter().enumerate() {
        if i > 0 {
            b.sql.push_str(", ");
        }
        b.write_ident(col);
    }
    for (i, (alias, expr)) in query.aggregates.iter().enumerate() {
        // Aggregates follow the group-by columns, so they need a separator
        // whenever anything was written before them.
        if !query.group_by.is_empty() || i > 0 {
            b.sql.push_str(", ");
        }
        match expr {
            AggregateExpr::Count(None) => b.sql.push_str("COUNT(*)"),
            AggregateExpr::Count(Some(col)) => {
                b.sql.push_str("COUNT(");
                b.write_ident(col);
                b.sql.push(')');
            }
            AggregateExpr::Sum(col) => {
                let inner = format!("SUM({})", b.d.quote_ident(col));
                let wrapped = b.d.cast_aggregate_to_int(&inner);
                b.sql.push_str(&wrapped);
            }
            AggregateExpr::Avg(col) => {
                let inner = format!("AVG({})", b.d.quote_ident(col));
                let wrapped = b.d.cast_aggregate_to_float(&inner);
                b.sql.push_str(&wrapped);
            }
            AggregateExpr::Max(col) => {
                b.sql.push_str("MAX(");
                b.write_ident(col);
                b.sql.push(')');
            }
            AggregateExpr::Min(col) => {
                b.sql.push_str("MIN(");
                b.write_ident(col);
                b.sql.push(')');
            }
        }
        b.sql.push_str(" AS ");
        b.write_ident(alias);
    }
    b.sql.push_str(" FROM ");
    b.write_ident(query.model.table);
    write_where(b, &query.where_clause, Some(query.model))?;
    if !query.group_by.is_empty() {
        b.sql.push_str(" GROUP BY ");
        for (i, col) in query.group_by.iter().enumerate() {
            if i > 0 {
                b.sql.push_str(", ");
            }
            b.write_ident(col);
        }
    }
    // An empty expression would otherwise leave a dangling ` HAVING ` keyword.
    if let Some(having) = query.having.as_ref().filter(|h| !h.is_empty()) {
        b.sql.push_str(" HAVING ");
        write_where_expr(b, having, None, Some(query.model))?;
    }
    write_order_limit_offset(b, &query.order_by, query.limit, query.offset, None);
    Ok(())
}
/// Writes a single-row `INSERT`, with an optional conflict clause and RETURNING list.
///
/// With no columns the statement uses `DEFAULT VALUES`. That form is only
/// accepted when at least one column is returned.
///
/// # Errors
/// - `SqlError::EmptyInsert` when there are neither columns nor returning columns.
/// - `SqlError::InsertShapeMismatch` when `columns` and `values` differ in length.
/// - Any error from the dialect's conflict-clause writer or from `write_returning`.
pub(super) fn write_insert(b: &mut Sql<'_>, query: &InsertQuery) -> Result<(), SqlError> {
    let n_cols = query.columns.len();
    let n_vals = query.values.len();
    if n_cols == 0 && query.returning.is_empty() {
        return Err(SqlError::EmptyInsert);
    }
    if n_cols != n_vals {
        return Err(SqlError::InsertShapeMismatch {
            columns: n_cols,
            values: n_vals,
        });
    }
    let model = query.model;
    b.sql.push_str("INSERT INTO ");
    b.write_ident(model.table);
    if n_cols == 0 {
        b.sql.push_str(" DEFAULT VALUES");
    } else {
        b.sql.push_str(" (");
        for (idx, col) in query.columns.iter().enumerate() {
            if idx > 0 {
                b.sql.push_str(", ");
            }
            b.write_ident(col);
        }
        b.sql.push_str(") VALUES (");
        for (idx, (col, value)) in query.columns.iter().zip(&query.values).enumerate() {
            if idx > 0 {
                b.sql.push_str(", ");
            }
            let cast = null_cast_for(b.d, model, col);
            b.push_param_typed(value.clone(), cast);
        }
        b.sql.push(')');
    }
    if let Some(conflict) = &query.on_conflict {
        b.d.write_conflict_clause(&mut b.sql, conflict)?;
    }
    write_returning(b, &query.returning)
}
/// Writes a multi-row `INSERT`, with an optional conflict clause and RETURNING list.
///
/// With no columns, one `(DEFAULT)` row is written per requested row. The
/// target column is the first RETURNING column.
///
/// # Errors
/// - `SqlError::EmptyBulkInsert` when there are no rows.
/// - `SqlError::EmptyInsert` when there are neither columns nor returning columns.
/// - `SqlError::InsertShapeMismatch` for the first row whose length differs from
///   the column count.
/// - Any error from the dialect's conflict-clause writer or from `write_returning`.
pub(super) fn write_bulk_insert(b: &mut Sql<'_>, query: &BulkInsertQuery) -> Result<(), SqlError> {
    if query.rows.is_empty() {
        return Err(SqlError::EmptyBulkInsert);
    }
    let width = query.columns.len();
    if width == 0 && query.returning.is_empty() {
        return Err(SqlError::EmptyInsert);
    }
    if let Some(bad) = query.rows.iter().find(|row| row.len() != width) {
        return Err(SqlError::InsertShapeMismatch {
            columns: width,
            values: bad.len(),
        });
    }
    b.sql.push_str("INSERT INTO ");
    b.write_ident(query.model.table);
    if width == 0 {
        // No explicit columns: name the first RETURNING column (treated as the
        // primary key, `pk`) and give it DEFAULT once per row.
        let pk = query
            .returning
            .first()
            .copied()
            .ok_or(SqlError::EmptyInsert)?;
        b.sql.push_str(" (");
        b.write_ident(pk);
        b.sql.push_str(") VALUES ");
        let defaults = vec!["(DEFAULT)"; query.rows.len()];
        b.sql.push_str(&defaults.join(", "));
    } else {
        b.sql.push_str(" (");
        for (idx, col) in query.columns.iter().enumerate() {
            if idx > 0 {
                b.sql.push_str(", ");
            }
            b.write_ident(col);
        }
        b.sql.push_str(") VALUES ");
        for (row_idx, row) in query.rows.iter().enumerate() {
            if row_idx > 0 {
                b.sql.push_str(", ");
            }
            b.sql.push('(');
            for (idx, (col, value)) in query.columns.iter().zip(row).enumerate() {
                if idx > 0 {
                    b.sql.push_str(", ");
                }
                let cast = null_cast_for(b.d, query.model, col);
                b.push_param_typed(value.clone(), cast);
            }
            b.sql.push(')');
        }
    }
    if let Some(conflict) = &query.on_conflict {
        b.d.write_conflict_clause(&mut b.sql, conflict)?;
    }
    write_returning(b, &query.returning)
}
/// Writes `UPDATE <table> SET col = ?, ...` followed by the optional WHERE clause.
///
/// NULL assignments receive the column's dialect null cast.
///
/// # Errors
/// `SqlError::EmptyUpdateSet` when there is nothing to assign, plus any error
/// from writing the WHERE clause.
pub(super) fn write_update(b: &mut Sql<'_>, query: &UpdateQuery) -> Result<(), SqlError> {
    if query.set.is_empty() {
        return Err(SqlError::EmptyUpdateSet);
    }
    let model = query.model;
    b.sql.push_str("UPDATE ");
    b.write_ident(model.table);
    b.sql.push_str(" SET ");
    let mut sep = "";
    for assignment in &query.set {
        b.sql.push_str(sep);
        sep = ", ";
        let column = assignment.column;
        b.write_ident(column);
        b.sql.push_str(" = ");
        let cast = null_cast_for(b.d, model, column);
        b.push_param_typed(assignment.value.clone(), cast);
    }
    write_where(b, &query.where_clause, Some(model))
}
/// Writes `DELETE FROM <table>` followed by the optional WHERE clause.
///
/// # Errors
/// Propagates errors from writing the WHERE clause.
pub(super) fn write_delete(b: &mut Sql<'_>, query: &DeleteQuery) -> Result<(), SqlError> {
    let model = query.model;
    b.sql.push_str("DELETE FROM ");
    b.write_ident(model.table);
    write_where(b, &query.where_clause, Some(model))
}
pub(super) fn write_bulk_update_pg(
b: &mut Sql<'_>,
query: &BulkUpdateQuery,
) -> Result<(), SqlError> {
if query.rows.is_empty() {
return Err(SqlError::EmptyBulkInsert);
}
if query.update_columns.is_empty() {
return Err(SqlError::EmptyUpdateSet);
}
let pk_field = query
.model
.primary_key()
.ok_or(SqlError::MissingPrimaryKey)?;
b.sql.push_str("UPDATE ");
b.write_ident(query.model.table);
b.sql.push_str(" SET ");
let mut first = true;
for col in &query.update_columns {
if !first {
b.sql.push_str(", ");
}
first = false;
b.write_ident(col);
b.sql.push_str(" = __data.");
b.write_ident(col);
}
b.sql.push_str(" FROM (VALUES ");
let mut first_row = true;
for row in &query.rows {
if !first_row {
b.sql.push_str(", ");
}
first_row = false;
b.sql.push('(');
for (i, val) in row.iter().enumerate() {
if i > 0 {
b.sql.push_str(", ");
}
b.push_param(val.clone());
}
b.sql.push(')');
}
b.sql.push_str(") AS __data(");
b.write_ident(pk_field.column);
for col in &query.update_columns {
b.sql.push_str(", ");
b.write_ident(col);
}
b.sql.push_str(") WHERE ");
b.write_ident(query.model.table);
b.sql.push('.');
b.write_ident(pk_field.column);
b.sql.push_str(" = __data.");
b.write_ident(pk_field.column);
Ok(())
}
/// Writes ` WHERE <expr>` when `where_clause` is non-empty; otherwise writes nothing.
///
/// Columns are left unqualified. `model`, when given, supplies null casts.
///
/// # Errors
/// Propagates errors from `write_where_expr`.
pub(super) fn write_where(
    b: &mut Sql<'_>,
    where_clause: &WhereExpr,
    model: Option<&'static ModelSchema>,
) -> Result<(), SqlError> {
    if !where_clause.is_empty() {
        b.sql.push_str(" WHERE ");
        write_where_expr(b, where_clause, None, model)?;
    }
    Ok(())
}
pub(super) fn write_where_with_search(
b: &mut Sql<'_>,
where_clause: &WhereExpr,
search: Option<&SearchClause>,
qualify_with: Option<&str>,
model: Option<&'static ModelSchema>,
) -> Result<(), SqlError> {
let has_search = search.is_some_and(|s| !s.columns.is_empty() && !s.query.is_empty());
let has_where = !where_clause.is_empty();
if !has_where && !has_search {
return Ok(());
}
b.sql.push_str(" WHERE ");
if has_where {
write_where_expr(b, where_clause, qualify_with, model)?;
}
if has_search {
let s = search.expect("checked above");
if has_where {
b.sql.push_str(" AND ");
}
b.params.push(SqlValue::String(format!("%{}%", s.query)));
let placeholder = b.d.placeholder(b.params.len());
b.sql.push('(');
for (i, col) in s.columns.iter().enumerate() {
if i > 0 {
b.sql.push_str(" OR ");
}
let mut qualified = String::new();
if let Some(table) = qualify_with {
qualified.push_str(&b.d.quote_ident(table));
qualified.push('.');
}
qualified.push_str(&b.d.quote_ident(col));
b.d.write_ilike(&mut b.sql, &qualified, &placeholder, false);
}
b.sql.push(')');
}
Ok(())
}
/// Writes `expr` without surrounding parentheses.
///
/// AND/OR children are written through `write_child`, which parenthesizes
/// nested compound expressions. NOT wraps its child in `NOT (...)`.
///
/// An empty AND is vacuously true and is written as `1 = 1`. That keeps the
/// output valid SQL wherever it appears (nested in an OR, under NOT, or as a
/// HAVING expression), instead of producing `()` or a dangling keyword.
///
/// # Errors
/// `SqlError::EmptyOrBranch` for an empty OR, plus any error from writing a filter.
pub(super) fn write_where_expr(
    b: &mut Sql<'_>,
    expr: &WhereExpr,
    qualify_with: Option<&str>,
    model: Option<&'static ModelSchema>,
) -> Result<(), SqlError> {
    match expr {
        WhereExpr::Predicate(filter) => write_filter(b, filter, qualify_with, model),
        WhereExpr::And(items) if items.is_empty() => {
            b.sql.push_str("1 = 1");
            Ok(())
        }
        WhereExpr::And(items) => write_joined(b, items, " AND ", qualify_with, model),
        WhereExpr::Or(items) if items.is_empty() => Err(SqlError::EmptyOrBranch),
        WhereExpr::Or(items) => write_joined(b, items, " OR ", qualify_with, model),
        WhereExpr::Not(child) => {
            b.sql.push_str("NOT (");
            write_where_expr(b, child, qualify_with, model)?;
            b.sql.push(')');
            Ok(())
        }
    }
}
/// Writes each of `items` via `write_child`, separated by `sep` (e.g. `" AND "`).
///
/// Stops at, and returns, the first error.
fn write_joined(
    b: &mut Sql<'_>,
    items: &[WhereExpr],
    sep: &str,
    qualify_with: Option<&str>,
    model: Option<&'static ModelSchema>,
) -> Result<(), SqlError> {
    for (idx, child) in items.iter().enumerate() {
        if idx != 0 {
            b.sql.push_str(sep);
        }
        write_child(b, child, qualify_with, model)?;
    }
    Ok(())
}
/// Writes one operand of an AND/OR list.
///
/// Plain predicates are written as-is. Compound expressions (AND, OR, NOT) are
/// parenthesized so the generated SQL follows the expression tree rather than
/// SQL operator precedence.
fn write_child(
    b: &mut Sql<'_>,
    expr: &WhereExpr,
    qualify_with: Option<&str>,
    model: Option<&'static ModelSchema>,
) -> Result<(), SqlError> {
    if let WhereExpr::Predicate(filter) = expr {
        return write_filter(b, filter, qualify_with, model);
    }
    b.sql.push('(');
    write_where_expr(b, expr, qualify_with, model)?;
    b.sql.push(')');
    Ok(())
}
/// Writes a single filter predicate on its (optionally qualified) column.
///
/// Operand values are bound as parameters. When `model` is given, a
/// `SqlValue::Null` operand of a comparison, an IN element, a BETWEEN bound or a
/// null-safe (IS [NOT] DISTINCT FROM) comparison receives the dialect null cast
/// for the column's type (via `null_cast_for`).
///
/// # Errors
/// - `SqlError::OperatorNotSupportedInDialect` for ILIKE, null-safe and JSON
///   operators that the dialect does not support.
/// - `SqlError::InRequiresList` / `SqlError::EmptyInList` for malformed IN lists.
/// - `SqlError::BetweenRequiresTwoElementList` unless the value is a 2-element list.
/// - `SqlError::IsNullRequiresBool` unless the IS NULL value is a bool.
/// - `SqlError::JsonOpRequiresJson`, `SqlError::JsonKeyRequiresString` and
///   `SqlError::JsonKeysRequiresList` for mistyped JSON operands.
// A single exhaustive match over every operator is long by nature; splitting it
// would only scatter the per-operator rules.
#[allow(clippy::too_many_lines)]
fn write_filter(
    b: &mut Sql<'_>,
    filter: &Filter,
    qualify_with: Option<&str>,
    model: Option<&'static ModelSchema>,
) -> Result<(), SqlError> {
    let qualified_col = render_qualified_col(b.d, qualify_with, filter.column);
    let cast = model.and_then(|m| null_cast_for(b.d, m, filter.column));
    match filter.op {
        Op::Eq => simple_op(b, &qualified_col, " = ", filter.value.clone(), cast),
        Op::Ne => simple_op(b, &qualified_col, " <> ", filter.value.clone(), cast),
        Op::Lt => simple_op(b, &qualified_col, " < ", filter.value.clone(), cast),
        Op::Lte => simple_op(b, &qualified_col, " <= ", filter.value.clone(), cast),
        Op::Gt => simple_op(b, &qualified_col, " > ", filter.value.clone(), cast),
        Op::Gte => simple_op(b, &qualified_col, " >= ", filter.value.clone(), cast),
        Op::Like => simple_op(b, &qualified_col, " LIKE ", filter.value.clone(), cast),
        Op::NotLike => simple_op(b, &qualified_col, " NOT LIKE ", filter.value.clone(), cast),
        Op::ILike | Op::NotILike => {
            require_op(b.d, filter.op)?;
            b.params.push(filter.value.clone());
            let p = b.d.placeholder(b.params.len());
            b.d.write_ilike(
                &mut b.sql,
                &qualified_col,
                &p,
                matches!(filter.op, Op::NotILike),
            );
        }
        Op::In | Op::NotIn => {
            let SqlValue::List(elements) = &filter.value else {
                return Err(SqlError::InRequiresList);
            };
            if elements.is_empty() {
                return Err(SqlError::EmptyInList);
            }
            b.sql.push_str(&qualified_col);
            b.sql.push_str(if matches!(filter.op, Op::In) {
                " IN ("
            } else {
                " NOT IN ("
            });
            let mut first = true;
            for elem in elements {
                if !first {
                    b.sql.push_str(", ");
                }
                first = false;
                b.push_param_typed(elem.clone(), cast);
            }
            b.sql.push(')');
        }
        Op::Between => {
            let SqlValue::List(bounds) = &filter.value else {
                return Err(SqlError::BetweenRequiresTwoElementList);
            };
            if bounds.len() != 2 {
                return Err(SqlError::BetweenRequiresTwoElementList);
            }
            b.sql.push_str(&qualified_col);
            b.sql.push_str(" BETWEEN ");
            b.push_param_typed(bounds[0].clone(), cast);
            b.sql.push_str(" AND ");
            b.push_param_typed(bounds[1].clone(), cast);
        }
        Op::IsNull => {
            // The bool selects IS NULL (true) vs IS NOT NULL (false); nothing is bound.
            let SqlValue::Bool(is_null) = filter.value else {
                return Err(SqlError::IsNullRequiresBool);
            };
            b.sql.push_str(&qualified_col);
            b.sql
                .push_str(if is_null { " IS NULL" } else { " IS NOT NULL" });
        }
        Op::IsDistinctFrom | Op::IsNotDistinctFrom => {
            require_op(b.d, filter.op)?;
            // Null-safe comparison is the operator most likely to receive a NULL.
            // Give it the same null cast `push_param_typed` applies to every
            // other comparison. The dialect writer receives the placeholder
            // text (including the cast) as-is.
            let is_null = matches!(filter.value, SqlValue::Null);
            b.params.push(filter.value.clone());
            let mut p = b.d.placeholder(b.params.len());
            if let (true, Some(ty)) = (is_null, cast) {
                p.push_str("::");
                p.push_str(ty);
            }
            b.d.write_null_safe_eq(
                &mut b.sql,
                &qualified_col,
                &p,
                matches!(filter.op, Op::IsDistinctFrom),
            );
        }
        Op::JsonContains => {
            require_op(b.d, filter.op)?;
            let SqlValue::Json(_) = &filter.value else {
                return Err(SqlError::JsonOpRequiresJson);
            };
            b.params.push(filter.value.clone());
            let p = b.d.placeholder(b.params.len());
            b.d.write_json_contains(&mut b.sql, &qualified_col, &p);
        }
        Op::JsonContainedBy => {
            require_op(b.d, filter.op)?;
            let SqlValue::Json(_) = &filter.value else {
                return Err(SqlError::JsonOpRequiresJson);
            };
            b.params.push(filter.value.clone());
            let p = b.d.placeholder(b.params.len());
            b.d.write_json_contained_by(&mut b.sql, &qualified_col, &p);
        }
        Op::JsonHasKey => {
            require_op(b.d, filter.op)?;
            let SqlValue::String(_) = &filter.value else {
                return Err(SqlError::JsonKeyRequiresString);
            };
            b.params.push(filter.value.clone());
            let p = b.d.placeholder(b.params.len());
            b.d.write_json_has_key(&mut b.sql, &qualified_col, &p);
        }
        Op::JsonHasAnyKey | Op::JsonHasAllKeys => {
            require_op(b.d, filter.op)?;
            let SqlValue::List(keys) = &filter.value else {
                return Err(SqlError::JsonKeysRequiresList);
            };
            // Each key is bound separately; the dialect receives all placeholders.
            let placeholders = bind_param_list(b, keys);
            if matches!(filter.op, Op::JsonHasAnyKey) {
                b.d.write_json_has_any_keys(&mut b.sql, &qualified_col, &placeholders);
            } else {
                b.d.write_json_has_all_keys(&mut b.sql, &qualified_col, &placeholders);
            }
        }
    }
    Ok(())
}
/// Writes `<qualified_col><kw><placeholder>` for a binary operator such as `" = "`.
///
/// `cast` is forwarded to `push_param_typed`, so it applies only to NULL values.
fn simple_op(
    b: &mut Sql<'_>,
    qualified_col: &str,
    kw: &str,
    value: SqlValue,
    cast: Option<&'static str>,
) {
    b.sql.extend([qualified_col, kw]);
    b.push_param_typed(value, cast);
}
/// Renders `column` as a quoted identifier.
///
/// When `qualify_with` is `Some(table)`, the result is prefixed with the
/// quoted table name and a dot.
fn render_qualified_col(d: &dyn Dialect, qualify_with: Option<&str>, column: &str) -> String {
    let mut rendered = qualify_with
        .map(|table| format!("{}.", d.quote_ident(table)))
        .unwrap_or_default();
    rendered.push_str(&d.quote_ident(column));
    rendered
}
/// Binds every value in `values`, in order, and returns their placeholders
/// without writing any SQL text.
fn bind_param_list(b: &mut Sql<'_>, values: &[SqlValue]) -> Vec<String> {
    values
        .iter()
        .map(|value| {
            b.params.push(value.clone());
            b.d.placeholder(b.params.len())
        })
        .collect()
}
/// Fails with `SqlError::OperatorNotSupportedInDialect` unless dialect `d` supports `op`.
fn require_op(d: &dyn Dialect, op: Op) -> Result<(), SqlError> {
    if !d.supports_op(op) {
        return Err(SqlError::OperatorNotSupportedInDialect {
            op: op_label(op),
            dialect: d.name(),
        });
    }
    Ok(())
}
/// Returns the SQL spelling of `op`, used by `require_op` in
/// "operator not supported" errors.
fn op_label(op: Op) -> &'static str {
    match op {
        // Comparison operators.
        Op::Eq => "=",
        Op::Ne => "<>",
        Op::Lt => "<",
        Op::Lte => "<=",
        Op::Gt => ">",
        Op::Gte => ">=",
        // Pattern matching.
        Op::Like => "LIKE",
        Op::NotLike => "NOT LIKE",
        Op::ILike => "ILIKE",
        Op::NotILike => "NOT ILIKE",
        // Membership, range and null tests.
        Op::In => "IN",
        Op::NotIn => "NOT IN",
        Op::Between => "BETWEEN",
        Op::IsNull => "IS NULL",
        Op::IsDistinctFrom => "IS DISTINCT FROM",
        Op::IsNotDistinctFrom => "IS NOT DISTINCT FROM",
        // JSON operators.
        Op::JsonContains => "@>",
        Op::JsonContainedBy => "<@",
        Op::JsonHasKey => "? (json)",
        Op::JsonHasAnyKey => "?| (json)",
        Op::JsonHasAllKeys => "?& (json)",
    }
}
/// Writes the optional ` ORDER BY ...`, ` LIMIT n` and ` OFFSET n` tail.
///
/// ORDER BY columns are qualified with `qualify_with` when given. `DESC` is
/// appended for descending clauses; ascending is left implicit.
///
/// NOTE(review): `OFFSET` is emitted even when `limit` is `None`. SQLite and
/// MySQL only accept OFFSET as part of a LIMIT clause. Confirm that no such
/// dialect reaches this path with an offset but no limit.
fn write_order_limit_offset(
    b: &mut Sql<'_>,
    order_by: &[OrderClause],
    limit: Option<i64>,
    offset: Option<i64>,
    qualify_with: Option<&str>,
) {
    for (i, clause) in order_by.iter().enumerate() {
        b.sql.push_str(if i == 0 { " ORDER BY " } else { ", " });
        if let Some(table) = qualify_with {
            b.write_ident(table);
            b.sql.push('.');
        }
        b.write_ident(clause.column);
        if clause.desc {
            b.sql.push_str(" DESC");
        }
    }
    for (keyword, value) in [(" LIMIT ", limit), (" OFFSET ", offset)] {
        if let Some(n) = value {
            // Writing into a String cannot fail.
            let _ = write!(b.sql, "{keyword}{n}");
        }
    }
}
/// Writes ` RETURNING col, ...` when `returning` is non-empty; otherwise writes nothing.
///
/// # Errors
/// `SqlError::OperatorNotSupportedInDialect` (labelled `"RETURNING"`) when
/// columns are requested but the dialect does not support RETURNING.
fn write_returning(b: &mut Sql<'_>, returning: &[&'static str]) -> Result<(), SqlError> {
    let Some((head, tail)) = returning.split_first() else {
        return Ok(());
    };
    if !b.d.supports_returning() {
        return Err(SqlError::OperatorNotSupportedInDialect {
            op: "RETURNING",
            dialect: b.d.name(),
        });
    }
    b.sql.push_str(" RETURNING ");
    b.write_ident(head);
    for col in tail {
        b.sql.push_str(", ");
        b.write_ident(col);
    }
    Ok(())
}
/// Compiles only the WHERE/search clause and the ORDER BY/LIMIT/OFFSET tail
/// into a standalone statement. The SQL starts with its leading space, e.g.
/// ` WHERE ...`, and placeholders are numbered from 1.
///
/// # Errors
/// Propagates errors from writing the WHERE/search clause.
// Mirrors the individual query parts directly, hence the long parameter list.
#[allow(clippy::too_many_arguments)]
pub(crate) fn compile_where_order_tail(
    d: &dyn Dialect,
    where_clause: &WhereExpr,
    search: Option<&SearchClause>,
    order_by: &[OrderClause],
    limit: Option<i64>,
    offset: Option<i64>,
    qualify_with: Option<&str>,
    model: Option<&'static ModelSchema>,
) -> Result<CompiledStatement, SqlError> {
    let mut tail = Sql::new(d);
    write_where_with_search(&mut tail, where_clause, search, qualify_with, model)?;
    write_order_limit_offset(&mut tail, order_by, limit, offset, qualify_with);
    let compiled = tail.finish();
    Ok(compiled)
}