use std::fmt::Write as _;
use crate::core::{
AggregateExpr, AggregateQuery, BulkInsertQuery, BulkUpdateQuery, CountQuery, DeleteQuery,
Filter, InsertQuery, ModelSchema, Op, SearchClause, SelectQuery, SqlValue, UpdateQuery,
WhereExpr,
};
use super::{CompiledStatement, Dialect, SqlError};
/// Incremental SQL builder shared by every statement writer in this module.
///
/// Writers append SQL text to `sql` and bind values into `params`; the dialect
/// `d` decides identifier quoting, placeholder syntax and feature support.
// The `sql` field repeats the struct's name, which this clippy lint flags.
#[allow(clippy::struct_field_names)]
pub(super) struct Sql<'d> {
    /// Dialect used for quoting, placeholders and feature checks.
    pub d: &'d dyn Dialect,
    /// SQL text emitted so far.
    pub sql: String,
    /// Bound parameter values, in placeholder order.
    pub params: Vec<SqlValue>,
    /// Models of the enclosing statement scopes, innermost last.
    /// `Expr::OuterRef` resolves against the second-to-last entry.
    pub scope_stack: Vec<&'static ModelSchema>,
    /// When set, bare `Expr::Column` references are prefixed with this alias
    /// (used while writing JOIN ... ON conditions).
    pub current_qualify_alias: Option<&'static str>,
    /// Whether `Expr::Aggregate` may be written at the current position
    /// (enabled for HAVING and aggregate ORDER BY, disabled inside subqueries).
    pub aggregate_allowed: bool,
}
impl<'d> Sql<'d> {
pub(super) fn new(d: &'d dyn Dialect) -> Self {
Self {
d,
sql: String::new(),
params: Vec::new(),
scope_stack: Vec::new(),
current_qualify_alias: None,
aggregate_allowed: false,
}
}
pub(super) fn with_capacity(d: &'d dyn Dialect, cap: usize) -> Self {
Self {
d,
sql: String::new(),
params: Vec::with_capacity(cap),
scope_stack: Vec::new(),
current_qualify_alias: None,
aggregate_allowed: false,
}
}
pub(super) fn write_ident(&mut self, name: &str) {
self.sql.push_str(&self.d.quote_ident(name));
}
pub(super) fn push_param_typed(&mut self, value: SqlValue, cast: Option<&'static str>) {
let is_null = matches!(value, SqlValue::Null);
self.params.push(value);
let p = self.d.placeholder(self.params.len());
self.sql.push_str(&p);
if is_null {
if let Some(ty) = cast {
self.sql.push_str("::");
self.sql.push_str(ty);
}
}
}
pub(super) fn push_param(&mut self, value: SqlValue) {
self.push_param_typed(value, None);
}
pub(super) fn finish(self) -> CompiledStatement {
CompiledStatement {
sql: self.sql,
params: self.params,
}
}
}
/// Looks up the dialect's NULL cast type for `column` of `model`.
///
/// Returns `None` when the model has no field for that column or when the
/// dialect defines no cast for the field's type.
pub(super) fn null_cast_for(
    d: &dyn Dialect,
    model: &ModelSchema,
    column: &str,
) -> Option<&'static str> {
    model
        .field_by_column(column)
        .and_then(|field| d.null_cast(field.ty))
}
/// Writes a complete SELECT statement, including any compound (set-operation)
/// branches.
///
/// `query.model` is pushed onto the scope stack while the statement is written,
/// so nested subqueries can resolve outer references. It is popped again
/// whether or not writing succeeds.
pub(super) fn write_select(b: &mut Sql<'_>, query: &SelectQuery) -> Result<(), SqlError> {
    let has_compound = !query.compound.is_empty();
    b.scope_stack.push(query.model);
    let outcome = if has_compound {
        write_compound_select(b, query)
    } else {
        write_select_inner(b, query)
    };
    b.scope_stack.pop();
    outcome
}
/// Writes a compound SELECT: `(head) OP (branch) OP (branch) ...`, followed by
/// the ORDER BY / LIMIT / OFFSET and row-lock clause of the whole query.
///
/// The head is `query` with its trailing clauses (ordering, limit, offset,
/// lock, compound branches) removed. Those clauses apply to the combined result
/// and are written once, after the last branch. The head keeps `query`'s column
/// projection so its select list is the one the caller asked for. Without it,
/// the head would list every scalar field while projected branches would not,
/// and the SELECT lists could disagree in column count.
///
/// Each branch is wrapped in parentheses and gets its own scope-stack frame.
/// Branches that are themselves compound recurse. Callers push `query.model`
/// onto the scope stack before calling this.
fn write_compound_select(b: &mut Sql<'_>, query: &SelectQuery) -> Result<(), SqlError> {
    let head = SelectQuery {
        model: query.model,
        where_clause: query.where_clause.clone(),
        search: query.search.clone(),
        joins: query.joins.clone(),
        order_by: Vec::new(),
        limit: None,
        offset: None,
        lock_mode: None,
        compound: Vec::new(),
        // Keep the caller's projection: the head's select list must be the
        // requested one, not the model's full scalar field list.
        projection: query.projection.clone(),
    };
    b.sql.push('(');
    write_select_inner(b, &head)?;
    b.sql.push(')');
    for branch in &query.compound {
        b.sql.push(' ');
        b.sql.push_str(branch.op.keyword());
        b.sql.push_str(" (");
        b.scope_stack.push(branch.query.model);
        let r = if branch.query.compound.is_empty() {
            write_select_inner(b, &branch.query)
        } else {
            write_compound_select(b, &branch.query)
        };
        // Pop before propagating so the scope stack stays balanced on error.
        b.scope_stack.pop();
        r?;
        b.sql.push(')');
    }
    // Trailing clauses apply to the combined result, so they are unqualified.
    write_order_limit_offset(b, &query.order_by, query.limit, query.offset, None)?;
    if let Some(lock) = &query.lock_mode {
        write_lock_clause(b, lock);
    }
    Ok(())
}
fn write_select_inner(b: &mut Sql<'_>, query: &SelectQuery) -> Result<(), SqlError> {
let qualify = !query.joins.is_empty();
b.sql.push_str("SELECT ");
let mut first_col = true;
if let Some(cols) = query.projection.as_ref() {
for col in cols {
if !first_col {
b.sql.push_str(", ");
}
first_col = false;
if qualify {
b.write_ident(query.model.table);
b.sql.push('.');
}
b.write_ident(col);
}
} else {
for field in query.model.scalar_fields() {
if !first_col {
b.sql.push_str(", ");
}
first_col = false;
if qualify {
b.write_ident(query.model.table);
b.sql.push('.');
}
b.write_ident(field.column);
}
}
for join in &query.joins {
for col in &join.project {
b.sql.push_str(", ");
b.write_ident(join.alias);
b.sql.push('.');
b.write_ident(col);
b.sql.push_str(" AS ");
b.write_ident(&format!("{}__{}", join.alias, col));
}
}
b.sql.push_str(" FROM ");
b.write_ident(query.model.table);
for join in &query.joins {
use crate::core::JoinKind;
let kind_kw = match (join.kind, b.d.name()) {
(JoinKind::Inner, _) => "INNER JOIN",
(JoinKind::Left, _) => "LEFT JOIN",
(JoinKind::Right, "sqlite") => {
return Err(SqlError::JoinKindNotSupported {
kind: "RIGHT",
dialect: b.d.name(),
});
}
(JoinKind::Right, _) => "RIGHT JOIN",
(JoinKind::Full, "postgres") => "FULL OUTER JOIN",
(JoinKind::Full, _) => {
return Err(SqlError::JoinKindNotSupported {
kind: "FULL",
dialect: b.d.name(),
});
}
};
if join.on.is_empty() {
return Err(SqlError::EmptyJoinOnCondition);
}
b.sql.push(' ');
b.sql.push_str(kind_kw);
b.sql.push(' ');
b.write_ident(join.target.table);
b.sql.push_str(" AS ");
b.write_ident(join.alias);
b.sql.push_str(" ON ");
let prior_qualify = b.current_qualify_alias.replace(join.alias);
let on_result = write_where_expr(b, &join.on, Some(join.alias), Some(join.target));
b.current_qualify_alias = prior_qualify;
on_result?;
}
write_where_with_search(
b,
&query.where_clause,
query.search.as_ref(),
qualify.then_some(query.model.table),
Some(query.model),
)?;
write_order_limit_offset(
b,
&query.order_by,
query.limit,
query.offset,
qualify.then_some(query.model.table),
)?;
if let Some(lock) = &query.lock_mode {
write_lock_clause(b, lock);
}
Ok(())
}
/// Appends a row-locking clause: ` FOR UPDATE` or ` FOR NO KEY UPDATE`, an
/// optional ` OF t1, t2`, and ` SKIP LOCKED` or ` NOWAIT`.
///
/// Writes nothing for SQLite. `no_key` only takes effect on PostgreSQL; other
/// dialects get plain `FOR UPDATE`. If both `skip_locked` and `nowait` are set,
/// only ` SKIP LOCKED` is written.
fn write_lock_clause(b: &mut Sql<'_>, lock: &crate::core::LockMode) {
    let dialect = b.d.name();
    if dialect == "sqlite" {
        return;
    }
    let strength = if lock.no_key && dialect == "postgres" {
        "NO KEY UPDATE"
    } else {
        "UPDATE"
    };
    b.sql.push_str(" FOR ");
    b.sql.push_str(strength);
    for (idx, table) in lock.of.iter().enumerate() {
        b.sql.push_str(if idx == 0 { " OF " } else { ", " });
        let quoted = b.d.quote_ident(table);
        b.sql.push_str(&quoted);
    }
    let wait_policy = if lock.skip_locked {
        Some(" SKIP LOCKED")
    } else if lock.nowait {
        Some(" NOWAIT")
    } else {
        None
    };
    if let Some(policy) = wait_policy {
        b.sql.push_str(policy);
    }
}
/// Writes `SELECT COUNT(*) FROM <table>`, followed by the WHERE/search clause.
///
/// `query.model` is on the scope stack while the statement is written. It is
/// popped again even if the WHERE clause fails.
pub(super) fn write_count(b: &mut Sql<'_>, query: &CountQuery) -> Result<(), SqlError> {
    b.scope_stack.push(query.model);
    b.sql.push_str("SELECT COUNT(*) FROM ");
    b.write_ident(query.model.table);
    let filtered = write_where_with_search(
        b,
        &query.where_clause,
        query.search.as_ref(),
        None,
        Some(query.model),
    );
    b.scope_stack.pop();
    filtered?;
    Ok(())
}
/// Writes a grouped/aggregate SELECT for `query`.
///
/// `query.model` is on the scope stack while the body is written and is
/// popped again whether or not writing succeeds.
pub(super) fn write_aggregate(b: &mut Sql<'_>, query: &AggregateQuery) -> Result<(), SqlError> {
    b.scope_stack.push(query.model);
    let outcome = write_aggregate_inner(b, query);
    let _ = b.scope_stack.pop();
    outcome
}
/// Writes the body of an aggregate query:
/// `SELECT <group cols>, <agg> AS <alias>, ... FROM <table> [WHERE] [GROUP BY] [HAVING] [ORDER BY/LIMIT/OFFSET]`.
///
/// Aggregate expressions are allowed inside HAVING and the trailing
/// ORDER BY / LIMIT / OFFSET. The previous `aggregate_allowed` value is
/// restored afterwards, even on error.
fn write_aggregate_inner(b: &mut Sql<'_>, query: &AggregateQuery) -> Result<(), SqlError> {
    b.sql.push_str("SELECT ");
    // One separator shared by the group columns and the aggregates, so the
    // first aggregate is preceded by ", " only when group columns were written.
    let mut separator = "";
    for column in query.group_by.iter() {
        b.sql.push_str(separator);
        b.write_ident(column);
        separator = ", ";
    }
    for (alias, expr) in query.aggregates.iter() {
        b.sql.push_str(separator);
        write_aggregate_expr(b, expr, query.model)?;
        b.sql.push_str(" AS ");
        b.write_ident(alias);
        separator = ", ";
    }

    b.sql.push_str(" FROM ");
    b.write_ident(query.model.table);
    write_where(b, &query.where_clause, Some(query.model))?;

    for (idx, column) in query.group_by.iter().enumerate() {
        b.sql.push_str(if idx == 0 { " GROUP BY " } else { ", " });
        b.write_ident(column);
    }

    if let Some(having) = &query.having {
        b.sql.push_str(" HAVING ");
        let saved = std::mem::replace(&mut b.aggregate_allowed, true);
        let outcome = write_where_expr(b, having, None, Some(query.model));
        b.aggregate_allowed = saved;
        outcome?;
    }

    let saved = std::mem::replace(&mut b.aggregate_allowed, true);
    let outcome = write_order_limit_offset(b, &query.order_by, query.limit, query.offset, None);
    b.aggregate_allowed = saved;
    outcome?;
    Ok(())
}
/// Numeric type an aggregate's result is cast to
/// (chosen by `aggregate_cast_kind`, applied by `apply_agg_cast`).
#[derive(Clone, Copy, Debug)]
enum AggCast {
    /// Integer cast (used for `SUM`).
    Int,
    /// Floating-point cast (used for `AVG`, standard deviation and variance).
    Float,
}
/// Returns the numeric cast applied to an aggregate's result.
///
/// `SUM` gets an integer cast. `AVG`, standard deviation and variance get a
/// float cast. Every other variant, including the wrappers, gets no cast.
fn aggregate_cast_kind(expr: &AggregateExpr) -> Option<AggCast> {
    let is_float = matches!(
        expr,
        AggregateExpr::Avg(_)
            | AggregateExpr::StdDev(_)
            | AggregateExpr::StdDevPop(_)
            | AggregateExpr::Variance(_)
            | AggregateExpr::VariancePop(_)
    );
    if is_float {
        Some(AggCast::Float)
    } else if matches!(expr, AggregateExpr::Sum(_)) {
        Some(AggCast::Int)
    } else {
        None
    }
}
/// Wraps the already-rendered aggregate SQL `inner` in the dialect's integer
/// or float cast, as selected by `kind`.
fn apply_agg_cast(d: &dyn Dialect, kind: AggCast, inner: &str) -> String {
    if let AggCast::Float = kind {
        d.cast_aggregate_to_float(inner)
    } else {
        d.cast_aggregate_to_int(inner)
    }
}
/// Renders a plain aggregate call (`COUNT`, `COUNT(DISTINCT ...)`, `SUM`,
/// `AVG`, `MAX`, `MIN`, standard deviation / variance) as a string, with no
/// cast applied.
///
/// # Errors
/// - `AggregateNotSupported` for standard deviation / variance on SQLite.
/// - `NestedAggregateWrapper` for `Filtered`, `Coalesced`, `Window` and the
///   PostgreSQL-only aggregates, which have no bare form here.
fn format_bare_aggregate(b: &Sql<'_>, expr: &AggregateExpr) -> Result<String, SqlError> {
    // (function name, argument prefix, column) for every renderable aggregate.
    let (func, modifier, column): (&str, &str, &str) = match expr {
        AggregateExpr::Count(None) => return Ok(String::from("COUNT(*)")),
        AggregateExpr::Count(Some(c)) => ("COUNT", "", *c),
        AggregateExpr::CountDistinct(c) => ("COUNT", "DISTINCT ", *c),
        AggregateExpr::Sum(c) => ("SUM", "", *c),
        AggregateExpr::Avg(c) => ("AVG", "", *c),
        AggregateExpr::Max(c) => ("MAX", "", *c),
        AggregateExpr::Min(c) => ("MIN", "", *c),
        AggregateExpr::StdDev(c)
        | AggregateExpr::StdDevPop(c)
        | AggregateExpr::Variance(c)
        | AggregateExpr::VariancePop(c) => {
            let name = stddev_variance_name(expr);
            if b.d.name() == "sqlite" {
                return Err(SqlError::AggregateNotSupported {
                    aggregate: name,
                    dialect: b.d.name(),
                });
            }
            (name, "", *c)
        }
        AggregateExpr::Filtered { .. } | AggregateExpr::Coalesced { .. } => {
            return Err(SqlError::NestedAggregateWrapper {
                wrapper: "wrapper at format_bare_aggregate site",
            });
        }
        AggregateExpr::Window(_) => {
            return Err(SqlError::NestedAggregateWrapper {
                wrapper: "Window at format_bare_aggregate site",
            });
        }
        AggregateExpr::ArrayAgg { .. }
        | AggregateExpr::StringAgg { .. }
        | AggregateExpr::JsonbAgg { .. } => {
            return Err(SqlError::NestedAggregateWrapper {
                wrapper: "PG-aggregate at format_bare_aggregate site",
            });
        }
    };
    Ok(format!("{func}({modifier}{})", b.d.quote_ident(column)))
}
fn write_aggregate_expr(
b: &mut Sql<'_>,
expr: &AggregateExpr,
model: &'static ModelSchema,
) -> Result<(), SqlError> {
match expr {
AggregateExpr::Coalesced { inner, default } => {
if matches!(inner.as_ref(), AggregateExpr::Coalesced { .. }) {
return Err(SqlError::NestedAggregateWrapper {
wrapper: "Coalesced",
});
}
b.sql.push_str("COALESCE(");
write_aggregate_expr(b, inner, model)?;
b.sql.push_str(", ");
let cast = aggregate_column(inner).and_then(|c| null_cast_for(b.d, model, c));
b.push_param_typed(default.clone(), cast);
b.sql.push(')');
Ok(())
}
AggregateExpr::Filtered { inner, filter } => {
if matches!(inner.as_ref(), AggregateExpr::Filtered { .. }) {
return Err(SqlError::NestedAggregateWrapper {
wrapper: "Filtered",
});
}
if matches!(inner.as_ref(), AggregateExpr::Coalesced { .. }) {
return Err(SqlError::NestedAggregateWrapper {
wrapper: "Filtered(Coalesced)",
});
}
if matches!(inner.as_ref(), AggregateExpr::Window(_)) {
return Err(SqlError::NestedAggregateWrapper {
wrapper: "Filtered(Window)",
});
}
if b.d.name() == "mysql" {
return write_aggregate_as_case_when(b, inner, filter);
}
let bare = format_bare_aggregate(b, inner)?;
let prior = b.sql.len();
b.sql.push_str(&bare);
b.sql.push_str(" FILTER (WHERE ");
write_where_expr(b, filter, None, Some(model))?;
b.sql.push(')');
if let Some(kind) = aggregate_cast_kind(inner) {
let emitted = b.sql[prior..].to_string();
b.sql.truncate(prior);
let wrapped = apply_agg_cast(b.d, kind, &format!("({emitted})"));
b.sql.push_str(&wrapped);
}
Ok(())
}
AggregateExpr::Window(w) => write_window_expr(b, w),
AggregateExpr::ArrayAgg { column, distinct } => {
if b.d.name() != "postgres" {
return Err(SqlError::AggregateNotSupportedInDialect {
aggregate: "array_agg",
dialect: b.d.name(),
});
}
b.sql.push_str("array_agg(");
if *distinct {
b.sql.push_str("DISTINCT ");
}
b.write_ident(column);
b.sql.push(')');
Ok(())
}
AggregateExpr::StringAgg {
column,
delimiter,
distinct,
} => {
if b.d.name() != "postgres" {
return Err(SqlError::AggregateNotSupportedInDialect {
aggregate: "string_agg",
dialect: b.d.name(),
});
}
b.sql.push_str("string_agg(");
if *distinct {
b.sql.push_str("DISTINCT ");
}
b.write_ident(column);
b.sql.push_str(", ");
b.push_param(crate::core::SqlValue::String(delimiter.clone()));
b.sql.push(')');
Ok(())
}
AggregateExpr::JsonbAgg { column } => {
if b.d.name() != "postgres" {
return Err(SqlError::AggregateNotSupportedInDialect {
aggregate: "jsonb_agg",
dialect: b.d.name(),
});
}
b.sql.push_str("jsonb_agg(");
b.write_ident(column);
b.sql.push(')');
Ok(())
}
_ => write_aggregate_kind(b, expr),
}
}
/// Writes a plain aggregate call, wrapped in its numeric cast when
/// `aggregate_cast_kind` selects one.
fn write_aggregate_kind(b: &mut Sql<'_>, expr: &AggregateExpr) -> Result<(), SqlError> {
    let mut rendered = format_bare_aggregate(b, expr)?;
    if let Some(kind) = aggregate_cast_kind(expr) {
        rendered = apply_agg_cast(b.d, kind, &rendered);
    }
    b.sql.push_str(&rendered);
    Ok(())
}
/// Emulates `agg(col) FILTER (WHERE cond)` as
/// `agg([DISTINCT ]CASE WHEN cond THEN col END)`. `COUNT(*)` becomes
/// `THEN 1`, so only matching rows are counted. Any numeric cast is applied
/// to the whole call. `write_aggregate_expr` uses this for MySQL.
///
/// The filter predicate is written without a model (`None`).
///
/// # Errors
/// `NestedAggregateWrapper` when `inner` is a wrapper, a window function or a
/// PostgreSQL-only aggregate, plus errors from the predicate writer.
fn write_aggregate_as_case_when(
    b: &mut Sql<'_>,
    inner: &AggregateExpr,
    filter: &WhereExpr,
) -> Result<(), SqlError> {
    // (function name, argument prefix, column for THEN; None means `THEN 1`).
    let (func, distinct, then_column): (&str, &str, Option<&str>) = match inner {
        AggregateExpr::Count(None) => ("COUNT", "", None),
        AggregateExpr::Count(Some(c)) => ("COUNT", "", Some(*c)),
        AggregateExpr::CountDistinct(c) => ("COUNT", "DISTINCT ", Some(*c)),
        AggregateExpr::Sum(c) => ("SUM", "", Some(*c)),
        AggregateExpr::Avg(c) => ("AVG", "", Some(*c)),
        AggregateExpr::Max(c) => ("MAX", "", Some(*c)),
        AggregateExpr::Min(c) => ("MIN", "", Some(*c)),
        AggregateExpr::StdDev(c)
        | AggregateExpr::StdDevPop(c)
        | AggregateExpr::Variance(c)
        | AggregateExpr::VariancePop(c) => (stddev_variance_name(inner), "", Some(*c)),
        AggregateExpr::Filtered { .. } | AggregateExpr::Coalesced { .. } => {
            return Err(SqlError::NestedAggregateWrapper {
                wrapper: "wrapper inside Filtered fallback",
            });
        }
        AggregateExpr::Window(_) => {
            return Err(SqlError::NestedAggregateWrapper {
                wrapper: "Filtered(Window)",
            });
        }
        AggregateExpr::ArrayAgg { .. }
        | AggregateExpr::StringAgg { .. }
        | AggregateExpr::JsonbAgg { .. } => {
            return Err(SqlError::NestedAggregateWrapper {
                wrapper: "Filtered(PG-aggregate)",
            });
        }
    };

    let start = b.sql.len();
    b.sql.push_str(func);
    b.sql.push('(');
    b.sql.push_str(distinct);
    b.sql.push_str("CASE WHEN ");
    write_where_expr(b, filter, None, None)?;
    b.sql.push_str(" THEN ");
    if let Some(column) = then_column {
        b.write_ident(column);
    } else {
        b.sql.push('1');
    }
    b.sql.push_str(" END)");
    if let Some(kind) = aggregate_cast_kind(inner) {
        let call = b.sql.split_off(start);
        let wrapped = apply_agg_cast(b.d, kind, &call);
        b.sql.push_str(&wrapped);
    }
    Ok(())
}
/// Returns the column an aggregate operates on, if there is one.
///
/// Wrappers (`Filtered`, `Coalesced`) delegate to their inner aggregate. A
/// window function yields its first plain `Expr::Column` argument.
/// `COUNT(*)` yields `None`.
fn aggregate_column(expr: &AggregateExpr) -> Option<&'static str> {
    match expr {
        AggregateExpr::Filtered { inner, .. } | AggregateExpr::Coalesced { inner, .. } => {
            aggregate_column(inner)
        }
        AggregateExpr::Window(w) => w.args.iter().find_map(|arg| {
            if let crate::core::Expr::Column(c) = arg {
                Some(*c)
            } else {
                None
            }
        }),
        AggregateExpr::Count(maybe_column) => *maybe_column,
        AggregateExpr::ArrayAgg { column, .. }
        | AggregateExpr::StringAgg { column, .. }
        | AggregateExpr::JsonbAgg { column } => Some(*column),
        AggregateExpr::CountDistinct(c)
        | AggregateExpr::Sum(c)
        | AggregateExpr::Avg(c)
        | AggregateExpr::Max(c)
        | AggregateExpr::Min(c)
        | AggregateExpr::StdDev(c)
        | AggregateExpr::StdDevPop(c)
        | AggregateExpr::Variance(c)
        | AggregateExpr::VariancePop(c) => Some(*c),
    }
}
/// SQL function name for the standard-deviation / variance aggregates.
fn stddev_variance_name(expr: &AggregateExpr) -> &'static str {
    match expr {
        AggregateExpr::StdDev(_) => "STDDEV_SAMP",
        AggregateExpr::StdDevPop(_) => "STDDEV_POP",
        AggregateExpr::Variance(_) => "VAR_SAMP",
        AggregateExpr::VariancePop(_) => "VAR_POP",
        // Callers in this module only pass the four variants above; this is a
        // harmless placeholder for anything else.
        _ => "(unknown)",
    }
}
/// Writes `INSERT INTO <table> (cols) VALUES (params)`, or
/// `INSERT INTO <table> DEFAULT VALUES` when no columns are given. The
/// dialect's conflict clause and the RETURNING list follow.
///
/// A NULL value gets its column's NULL cast from `null_cast_for`.
///
/// # Errors
/// `EmptyInsert` when there are neither columns nor RETURNING columns.
/// `InsertShapeMismatch` when the column and value counts differ.
pub(super) fn write_insert(b: &mut Sql<'_>, query: &InsertQuery) -> Result<(), SqlError> {
    let column_count = query.columns.len();
    if column_count == 0 && query.returning.is_empty() {
        return Err(SqlError::EmptyInsert);
    }
    if column_count != query.values.len() {
        return Err(SqlError::InsertShapeMismatch {
            columns: column_count,
            values: query.values.len(),
        });
    }

    b.sql.push_str("INSERT INTO ");
    b.write_ident(query.model.table);
    if column_count == 0 {
        b.sql.push_str(" DEFAULT VALUES");
    } else {
        for (idx, column) in query.columns.iter().enumerate() {
            b.sql.push_str(if idx == 0 { " (" } else { ", " });
            b.write_ident(column);
        }
        b.sql.push_str(") VALUES (");
        for (idx, (column, value)) in query.columns.iter().zip(&query.values).enumerate() {
            if idx > 0 {
                b.sql.push_str(", ");
            }
            let cast = null_cast_for(b.d, query.model, column);
            b.push_param_typed(value.clone(), cast);
        }
        b.sql.push(')');
    }

    if let Some(conflict) = &query.on_conflict {
        b.d.write_conflict_clause(&mut b.sql, conflict)?;
    }
    write_returning(b, &query.returning)?;
    Ok(())
}
/// Writes a multi-row `INSERT INTO <table> (cols) VALUES (...), (...)`, then the
/// dialect's conflict clause and the RETURNING list.
///
/// When no columns are given, the first RETURNING column is listed and every
/// row is written as `(DEFAULT)`. A NULL value gets its column's NULL cast.
///
/// # Errors
/// `EmptyBulkInsert` when there are no rows. `EmptyInsert` when there are
/// neither columns nor RETURNING columns. `InsertShapeMismatch` for the first
/// row whose length differs from the column count.
pub(super) fn write_bulk_insert(b: &mut Sql<'_>, query: &BulkInsertQuery) -> Result<(), SqlError> {
    if query.rows.is_empty() {
        return Err(SqlError::EmptyBulkInsert);
    }
    let width = query.columns.len();
    if width == 0 && query.returning.is_empty() {
        return Err(SqlError::EmptyInsert);
    }
    if let Some(bad_row) = query.rows.iter().find(|row| row.len() != width) {
        return Err(SqlError::InsertShapeMismatch {
            columns: width,
            values: bad_row.len(),
        });
    }

    b.sql.push_str("INSERT INTO ");
    b.write_ident(query.model.table);
    if width == 0 {
        // No explicit columns: name the first RETURNING column (presumably the
        // primary key — verify against callers) and let every row use DEFAULT.
        let pk = match query.returning.first() {
            Some(pk) => *pk,
            None => return Err(SqlError::EmptyInsert),
        };
        b.sql.push_str(" (");
        b.write_ident(pk);
        b.sql.push_str(") VALUES ");
        for idx in 0..query.rows.len() {
            if idx > 0 {
                b.sql.push_str(", ");
            }
            b.sql.push_str("(DEFAULT)");
        }
    } else {
        b.sql.push_str(" (");
        for (idx, column) in query.columns.iter().enumerate() {
            if idx > 0 {
                b.sql.push_str(", ");
            }
            b.write_ident(column);
        }
        b.sql.push_str(") VALUES ");
        for (row_idx, row) in query.rows.iter().enumerate() {
            b.sql.push_str(if row_idx == 0 { "(" } else { ", (" });
            for (idx, (column, value)) in query.columns.iter().zip(row).enumerate() {
                if idx > 0 {
                    b.sql.push_str(", ");
                }
                let cast = null_cast_for(b.d, query.model, column);
                b.push_param_typed(value.clone(), cast);
            }
            b.sql.push(')');
        }
    }

    if let Some(conflict) = &query.on_conflict {
        b.d.write_conflict_clause(&mut b.sql, conflict)?;
    }
    write_returning(b, &query.returning)?;
    Ok(())
}
/// Writes `UPDATE <table> SET col = expr, ...`, followed by the WHERE clause.
///
/// Each assigned expression receives its target column's NULL cast. The model
/// is on the scope stack while the SET list and WHERE clause are written, and
/// it is popped again even on error.
///
/// # Errors
/// `EmptyUpdateSet` when there are no assignments (checked before anything is
/// written), plus errors from the expression and WHERE writers.
pub(super) fn write_update(b: &mut Sql<'_>, query: &UpdateQuery) -> Result<(), SqlError> {
    /// Writes the statement text; the caller manages the scope stack.
    fn emit(b: &mut Sql<'_>, query: &UpdateQuery) -> Result<(), SqlError> {
        b.sql.push_str("UPDATE ");
        b.write_ident(query.model.table);
        for (idx, assignment) in query.set.iter().enumerate() {
            b.sql.push_str(if idx == 0 { " SET " } else { ", " });
            b.write_ident(assignment.column);
            b.sql.push_str(" = ");
            let cast = null_cast_for(b.d, query.model, assignment.column);
            write_expr(b, &assignment.value, cast)?;
        }
        write_where(b, &query.where_clause, Some(query.model))?;
        Ok(())
    }

    if query.set.is_empty() {
        return Err(SqlError::EmptyUpdateSet);
    }
    b.scope_stack.push(query.model);
    let outcome = emit(b, query);
    b.scope_stack.pop();
    outcome
}
/// Writes a scalar expression.
///
/// `cast` is used only when `expr` itself is a `Literal` (as its NULL cast);
/// nested operands are written without one. A bare column is qualified with
/// `current_qualify_alias` when that is set. `OuterRef` is qualified with the
/// table of the enclosing scope (second-to-last scope frame). Subqueries are
/// written with aggregates disallowed, and the previous flag is restored
/// afterwards.
///
/// # Errors
/// `OpNotSupportedInDialect` for bitwise XOR on SQLite.
/// `OuterRefOutsideSubquery` when fewer than two scopes are active.
/// `AggregateOutsideAggregateContext` for an aggregate where none is allowed.
/// Plus errors from the nested writers.
///
/// # Panics
/// If an allowed aggregate is written while the scope stack is empty.
fn write_expr(
    b: &mut Sql<'_>,
    expr: &crate::core::Expr,
    cast: Option<&'static str>,
) -> Result<(), SqlError> {
    use crate::core::{BinOp, Expr};

    /// SQL spelling of a binary operator for the given dialect.
    fn operator_token(op: &crate::core::BinOp, dialect: &str) -> &'static str {
        use crate::core::BinOp as B;
        match op {
            B::Add => "+",
            B::Sub => "-",
            B::Mul => "*",
            B::Div => "/",
            B::Mod => "%",
            B::BitAnd => "&",
            B::BitOr => "|",
            B::BitXor if dialect == "postgres" => "#",
            B::BitXor => "^",
            B::BitShl => "<<",
            B::BitShr => ">>",
        }
    }

    /// Appends `qualifier.column`, both parts quoted.
    fn push_qualified(b: &mut Sql<'_>, qualifier: &str, column: &str) {
        let text = format!("{}.{}", b.d.quote_ident(qualifier), b.d.quote_ident(column));
        b.sql.push_str(&text);
    }

    match expr {
        Expr::Literal(v) => {
            b.push_param_typed(v.clone(), cast);
            Ok(())
        }
        Expr::Column(name) => {
            match b.current_qualify_alias {
                Some(alias) => push_qualified(b, alias, name),
                None => b.write_ident(name),
            }
            Ok(())
        }
        Expr::BinOp { left, op, right } => {
            let dialect = b.d.name();
            if dialect == "sqlite" && matches!(op, BinOp::BitXor) {
                return Err(SqlError::OpNotSupportedInDialect {
                    op: "BitXor",
                    dialect,
                });
            }
            b.sql.push('(');
            write_expr(b, left, None)?;
            b.sql.push(' ');
            b.sql.push_str(operator_token(op, dialect));
            b.sql.push(' ');
            write_expr(b, right, None)?;
            b.sql.push(')');
            Ok(())
        }
        Expr::Function { kind, args } => write_function(b, *kind, args),
        Expr::Case { branches, default } => write_case(b, branches, default.as_deref()),
        Expr::Subquery(inner) => {
            b.sql.push('(');
            let saved = std::mem::replace(&mut b.aggregate_allowed, false);
            let outcome = write_select(b, inner);
            b.aggregate_allowed = saved;
            outcome?;
            b.sql.push(')');
            Ok(())
        }
        Expr::OuterRef(col) => {
            let depth = b.scope_stack.len();
            if depth < 2 {
                return Err(SqlError::OuterRefOutsideSubquery { column: col });
            }
            let outer = b.scope_stack[depth - 2];
            push_qualified(b, outer.table, col);
            Ok(())
        }
        Expr::AliasedColumn { alias, column } => {
            push_qualified(b, alias, column);
            Ok(())
        }
        Expr::Window(w) => write_window_expr(b, w),
        Expr::Aggregate(agg) => {
            if !b.aggregate_allowed {
                return Err(SqlError::AggregateOutsideAggregateContext);
            }
            let model = *b
                .scope_stack
                .last()
                .expect("Expr::Aggregate emitted outside any scope frame");
            write_aggregate_expr(b, agg, model)
        }
    }
}
/// Writes `FN(args) OVER ([PARTITION BY ...] [ORDER BY ...] [frame])`.
///
/// When `NTILE`'s bucket count (argument 0) or `LAG`/`LEAD`'s offset
/// (argument 1) is an `I64` literal, it is inlined as SQL text instead of
/// being bound as a parameter. All other arguments go through `write_expr`.
fn write_window_expr(b: &mut Sql<'_>, w: &crate::core::WindowExpr) -> Result<(), SqlError> {
    use crate::core::{Expr, WindowFn};

    let (fn_name, inline_int_at): (&str, Option<usize>) = match w.kind {
        WindowFn::RowNumber => ("ROW_NUMBER", None),
        WindowFn::Rank => ("RANK", None),
        WindowFn::DenseRank => ("DENSE_RANK", None),
        WindowFn::Ntile => ("NTILE", Some(0)),
        WindowFn::Lag => ("LAG", Some(1)),
        WindowFn::Lead => ("LEAD", Some(1)),
        WindowFn::FirstValue => ("FIRST_VALUE", None),
        WindowFn::LastValue => ("LAST_VALUE", None),
    };

    b.sql.push_str(fn_name);
    b.sql.push('(');
    for (i, arg) in w.args.iter().enumerate() {
        if i > 0 {
            b.sql.push_str(", ");
        }
        match arg {
            Expr::Literal(SqlValue::I64(n)) if inline_int_at == Some(i) => {
                let _ = write!(b.sql, "{n}");
            }
            other => write_expr(b, other, None)?,
        }
    }
    b.sql.push_str(") OVER (");

    // Set once any clause has been written, so the next one gets a leading space.
    let mut need_space = false;
    for (i, column) in w.partition_by.iter().enumerate() {
        b.sql.push_str(if i == 0 { "PARTITION BY " } else { ", " });
        b.write_ident(column);
        need_space = true;
    }
    if !w.order_by.is_empty() {
        if need_space {
            b.sql.push(' ');
        }
        for (i, item) in w.order_by.iter().enumerate() {
            b.sql.push_str(if i == 0 { "ORDER BY " } else { ", " });
            b.write_ident(item.column);
            if item.desc {
                b.sql.push_str(" DESC");
            }
        }
        need_space = true;
    }
    if let Some(frame) = &w.frame {
        if need_space {
            b.sql.push(' ');
        }
        write_window_frame(b, frame);
    }
    b.sql.push(')');
    Ok(())
}
/// Writes a window frame: `ROWS|RANGE <start>`, or
/// `ROWS|RANGE BETWEEN <start> AND <end>` when an end bound is present.
fn write_window_frame(b: &mut Sql<'_>, frame: &crate::core::WindowFrame) {
    use crate::core::{FrameBoundary, FrameKind};

    /// Appends the SQL spelling of one frame boundary.
    fn push_boundary(b: &mut Sql<'_>, bound: FrameBoundary) {
        match bound {
            FrameBoundary::UnboundedPreceding => b.sql.push_str("UNBOUNDED PRECEDING"),
            FrameBoundary::Preceding(n) => {
                let _ = write!(b.sql, "{n} PRECEDING");
            }
            FrameBoundary::CurrentRow => b.sql.push_str("CURRENT ROW"),
            FrameBoundary::Following(n) => {
                let _ = write!(b.sql, "{n} FOLLOWING");
            }
            FrameBoundary::UnboundedFollowing => b.sql.push_str("UNBOUNDED FOLLOWING"),
        }
    }

    let unit = match frame.kind {
        FrameKind::Rows => "ROWS",
        FrameKind::Range => "RANGE",
    };
    b.sql.push_str(unit);
    match frame.end {
        Some(end) => {
            b.sql.push_str(" BETWEEN ");
            push_boundary(b, frame.start);
            b.sql.push_str(" AND ");
            push_boundary(b, end);
        }
        None => {
            b.sql.push(' ');
            push_boundary(b, frame.start);
        }
    }
}
/// Writes `CASE WHEN cond THEN expr ... [ELSE expr] END`.
///
/// # Errors
/// `EmptyCaseBranches` when there are no branches, and
/// `EmptyCaseWhenCondition` when a branch has an empty condition. Branches
/// before the offending one will already have been written. Errors from the
/// nested writers are also returned.
fn write_case(
    b: &mut Sql<'_>,
    branches: &[crate::core::CaseBranch],
    default: Option<&crate::core::Expr>,
) -> Result<(), SqlError> {
    if branches.is_empty() {
        return Err(SqlError::EmptyCaseBranches);
    }
    b.sql.push_str("CASE");
    branches.iter().try_for_each(|branch| {
        if branch.condition.is_empty() {
            return Err(SqlError::EmptyCaseWhenCondition);
        }
        b.sql.push_str(" WHEN ");
        write_where_expr(b, &branch.condition, None, None)?;
        b.sql.push_str(" THEN ");
        write_expr(b, &branch.then, None)
    })?;
    if let Some(fallback) = default {
        b.sql.push_str(" ELSE ");
        write_expr(b, fallback, None)?;
    }
    b.sql.push_str(" END");
    Ok(())
}
/// Writes a scalar function call in the active dialect's spelling.
///
/// Arity is validated first (`FunctionArityMismatch`), then dialect support
/// (`OpNotSupportedInDialect`); SQL is emitted only after both pass.
/// Date-part extraction and truncation are delegated to `write_extract_int`,
/// `write_extract_weekday` and `write_trunc`. The trigram and full-text-search
/// functions are accepted only on PostgreSQL.
// One arm per scalar function keeps the dialect mapping in a single place.
#[allow(clippy::too_many_lines)]
fn write_function(
    b: &mut Sql<'_>,
    kind: crate::core::ScalarFn,
    args: &[crate::core::Expr],
) -> Result<(), SqlError> {
    use crate::core::ScalarFn as F;

    /// Builds the arity error for `func`.
    fn arity(func: &'static str, expected: &'static str, got: usize) -> SqlError {
        SqlError::FunctionArityMismatch {
            func,
            expected,
            got,
        }
    }

    /// Fails with `OpNotSupportedInDialect { op }` unless the dialect is PostgreSQL.
    fn postgres_only(b: &Sql<'_>, op: &'static str) -> Result<(), SqlError> {
        let dialect = b.d.name();
        if dialect == "postgres" {
            Ok(())
        } else {
            Err(SqlError::OpNotSupportedInDialect { op, dialect })
        }
    }

    /// Writes `name(arg0, arg1, ...)`, rendering each argument with `write_expr`.
    fn emit_call(
        b: &mut Sql<'_>,
        name: &str,
        args: &[crate::core::Expr],
    ) -> Result<(), SqlError> {
        b.sql.push_str(name);
        b.sql.push('(');
        for (i, arg) in args.iter().enumerate() {
            if i > 0 {
                b.sql.push_str(", ");
            }
            write_expr(b, arg, None)?;
        }
        b.sql.push(')');
        Ok(())
    }

    let dialect = b.d.name();
    match kind {
        F::Lower => write_call_unary(b, "LOWER", args),
        F::Upper => write_call_unary(b, "UPPER", args),
        F::Length => write_call_unary(b, "LENGTH", args),
        F::Trim => write_call_unary(b, "TRIM", args),
        F::LTrim => write_call_unary(b, "LTRIM", args),
        F::RTrim => write_call_unary(b, "RTRIM", args),
        F::Abs => write_call_unary(b, "ABS", args),
        F::Floor => write_call_unary(b, "FLOOR", args),
        F::Ceil => write_call_unary(b, "CEIL", args),
        F::Replace => {
            if args.len() != 3 {
                return Err(arity("REPLACE", "3", args.len()));
            }
            write_call(b, "REPLACE", args)
        }
        F::Concat => {
            if args.is_empty() {
                return Err(arity("CONCAT", ">= 1", 0));
            }
            if dialect != "sqlite" {
                return write_call(b, "CONCAT", args);
            }
            // The SQLite path joins the arguments with the `||` operator.
            b.sql.push('(');
            for (i, arg) in args.iter().enumerate() {
                if i > 0 {
                    b.sql.push_str(" || ");
                }
                write_expr(b, arg, None)?;
            }
            b.sql.push(')');
            Ok(())
        }
        F::Substr => {
            if args.len() != 3 {
                return Err(arity("SUBSTRING", "3", args.len()));
            }
            match dialect {
                "postgres" => {
                    b.sql.push_str("SUBSTRING(");
                    write_expr(b, &args[0], None)?;
                    b.sql.push_str(" FROM ");
                    write_expr(b, &args[1], None)?;
                    b.sql.push_str(" FOR ");
                    write_expr(b, &args[2], None)?;
                    b.sql.push(')');
                    Ok(())
                }
                "mysql" => write_call(b, "SUBSTRING", args),
                _ => write_call(b, "SUBSTR", args),
            }
        }
        F::Round => {
            if !(1..=2).contains(&args.len()) {
                return Err(arity("ROUND", "1 or 2", args.len()));
            }
            write_call(b, "ROUND", args)
        }
        F::Coalesce => {
            if args.is_empty() {
                return Err(arity("COALESCE", ">= 1", 0));
            }
            write_call(b, "COALESCE", args)
        }
        F::Greatest | F::Least => {
            let (standard, sqlite_name, single_arg_op) = if matches!(kind, F::Greatest) {
                (
                    "GREATEST",
                    "MAX",
                    "GREATEST with 1 argument (SQLite collides with the aggregate MAX)",
                )
            } else {
                (
                    "LEAST",
                    "MIN",
                    "LEAST with 1 argument (SQLite collides with the aggregate MIN)",
                )
            };
            if args.is_empty() {
                return Err(arity(standard, ">= 1", 0));
            }
            if dialect != "sqlite" {
                return write_call(b, standard, args);
            }
            // SQLite spells these as multi-argument MAX/MIN; a single argument
            // would be read as the aggregate instead.
            if args.len() == 1 {
                return Err(SqlError::OpNotSupportedInDialect {
                    op: single_arg_op,
                    dialect: "sqlite",
                });
            }
            write_call(b, sqlite_name, args)
        }
        F::NullIf => {
            if args.len() != 2 {
                return Err(arity("NULLIF", "2", args.len()));
            }
            write_call(b, "NULLIF", args)
        }
        F::Now => {
            if !args.is_empty() {
                return Err(arity("NOW", "0", args.len()));
            }
            b.sql.push_str(if dialect == "sqlite" {
                "CURRENT_TIMESTAMP"
            } else {
                "NOW()"
            });
            Ok(())
        }
        F::ExtractYear
        | F::ExtractMonth
        | F::ExtractDay
        | F::ExtractHour
        | F::ExtractMinute
        | F::ExtractSecond
        | F::ExtractWeek => write_extract_int(b, kind, args),
        F::ExtractWeekDay => write_extract_weekday(b, args),
        F::ExtractQuarter => {
            if args.len() != 1 {
                return Err(arity("EXTRACT(QUARTER)", "1", args.len()));
            }
            if dialect == "sqlite" {
                return Err(SqlError::OpNotSupportedInDialect {
                    op: "EXTRACT(QUARTER) (SQLite has no native quarter token)",
                    dialect: "sqlite",
                });
            }
            write_extract_int(b, kind, args)
        }
        F::TruncDate => {
            if args.len() != 1 {
                return Err(arity("DATE", "1", args.len()));
            }
            emit_call(b, "DATE", args)
        }
        F::TruncYear | F::TruncMonth | F::TruncDay => write_trunc(b, kind, args),
        F::TrigramSimilarity | F::TrigramWordSimilarity => {
            let (fname, op) = if matches!(kind, F::TrigramWordSimilarity) {
                (
                    "WORD_SIMILARITY",
                    "WORD_SIMILARITY (pg_trgm) is Postgres-only",
                )
            } else {
                ("SIMILARITY", "SIMILARITY (pg_trgm) is Postgres-only")
            };
            if args.len() != 2 {
                return Err(arity(fname, "2", args.len()));
            }
            postgres_only(b, op)?;
            emit_call(b, fname, args)
        }
        F::ToTsVector | F::PlainToTsQuery => {
            let (fname, op) = if matches!(kind, F::ToTsVector) {
                ("to_tsvector", "to_tsvector (FTS) is Postgres-only")
            } else {
                ("plainto_tsquery", "plainto_tsquery (FTS) is Postgres-only")
            };
            if args.len() != 1 {
                return Err(arity(fname, "1", args.len()));
            }
            postgres_only(b, op)?;
            emit_call(b, fname, args)
        }
        F::TsRank | F::TsRankCd => {
            let (fname, op) = if matches!(kind, F::TsRank) {
                ("ts_rank", "ts_rank (FTS) is Postgres-only")
            } else {
                ("ts_rank_cd", "ts_rank_cd (FTS) is Postgres-only")
            };
            if args.len() != 2 {
                return Err(arity(fname, "2", args.len()));
            }
            postgres_only(b, op)?;
            emit_call(b, fname, args)
        }
        F::TsHeadline => {
            if !(2..=3).contains(&args.len()) {
                return Err(arity("ts_headline", "2 or 3", args.len()));
            }
            postgres_only(b, "ts_headline (FTS) is Postgres-only")?;
            emit_call(b, "ts_headline", args)
        }
        F::PhraseToTsQuery | F::WebsearchToTsQuery | F::ToTsQuery => {
            let (fname, op) = match kind {
                F::PhraseToTsQuery => (
                    "phraseto_tsquery",
                    "phraseto_tsquery (FTS) is Postgres-only",
                ),
                F::WebsearchToTsQuery => (
                    "websearch_to_tsquery",
                    "websearch_to_tsquery (FTS) is Postgres-only",
                ),
                F::ToTsQuery => ("to_tsquery", "to_tsquery (FTS) is Postgres-only"),
                _ => unreachable!(),
            };
            if args.len() != 1 {
                return Err(arity(fname, "1", args.len()));
            }
            postgres_only(b, op)?;
            emit_call(b, fname, args)
        }
    }
}
/// Writes an integer date-part extraction of `args[0]` for `YEAR`, `MONTH`,
/// `DAY`, `HOUR`, `MINUTE`, `SECOND`, `WEEK` or `QUARTER`.
///
/// - PostgreSQL: `CAST(EXTRACT(<field> FROM x) AS INTEGER)`. `SECOND` is
///   floored first, because `EXTRACT(SECOND ...)` includes fractional seconds
///   and a cast to INTEGER rounds. Without the floor, 59.7 s would yield 60
///   rather than the 59 that MySQL's `SECOND()` and SQLite's `%S` return.
/// - MySQL: `<FIELD>(x)`.
/// - SQLite: `CAST(strftime('<token>', x) AS INTEGER)`. There is no QUARTER
///   token, so callers must reject QUARTER for SQLite before calling.
///
/// # Errors
/// `FunctionArityMismatch` unless exactly one argument is given, plus errors
/// from writing the argument.
///
/// # Panics
/// If `kind` is not one of the extract variants above, or if QUARTER reaches
/// the SQLite branch.
fn write_extract_int(
    b: &mut Sql<'_>,
    kind: crate::core::ScalarFn,
    args: &[crate::core::Expr],
) -> Result<(), SqlError> {
    use crate::core::ScalarFn as F;
    if args.len() != 1 {
        return Err(SqlError::FunctionArityMismatch {
            func: "EXTRACT",
            expected: "1",
            got: args.len(),
        });
    }
    let field = match kind {
        F::ExtractYear => "YEAR",
        F::ExtractMonth => "MONTH",
        F::ExtractDay => "DAY",
        F::ExtractHour => "HOUR",
        F::ExtractMinute => "MINUTE",
        F::ExtractSecond => "SECOND",
        F::ExtractWeek => "WEEK",
        F::ExtractQuarter => "QUARTER",
        _ => unreachable!("write_extract_int called with non-extract kind: {kind:?}"),
    };
    let dialect = b.d.name();
    if dialect == "postgres" {
        // EXTRACT(SECOND ...) carries a fractional part, and numeric -> integer
        // casts round instead of truncating; FLOOR first so the whole-second
        // value matches the other dialects. All other fields are integral.
        let floor_first = field == "SECOND";
        b.sql.push_str(if floor_first {
            "CAST(FLOOR(EXTRACT("
        } else {
            "CAST(EXTRACT("
        });
        b.sql.push_str(field);
        b.sql.push_str(" FROM ");
        write_expr(b, &args[0], None)?;
        b.sql.push_str(if floor_first {
            ")) AS INTEGER)"
        } else {
            ") AS INTEGER)"
        });
    } else if dialect == "mysql" {
        b.sql.push_str(field);
        b.sql.push('(');
        write_expr(b, &args[0], None)?;
        b.sql.push(')');
    } else {
        let token = match field {
            "YEAR" => "%Y",
            "MONTH" => "%m",
            "DAY" => "%d",
            "HOUR" => "%H",
            "MINUTE" => "%M",
            "SECOND" => "%S",
            "WEEK" => "%W",
            _ => unreachable!("sqlite extract: {field}"),
        };
        b.sql.push_str("CAST(strftime('");
        b.sql.push_str(token);
        b.sql.push_str("', ");
        write_expr(b, &args[0], None)?;
        b.sql.push_str(") AS INTEGER)");
    }
    Ok(())
}
/// Renders the day of the week of the single argument as an integer, with
/// Sunday = 0 in every dialect.
///
/// * Postgres: `CAST(EXTRACT(DOW FROM x) AS INTEGER)`.
/// * MySQL: `(DAYOFWEEK(x) - 1)`, shifting MySQL's 1-based numbering down.
/// * Any other dialect (SQLite): `CAST(strftime('%w', x) AS INTEGER)`.
///
/// # Errors
/// `FunctionArityMismatch` unless exactly one argument is supplied; errors
/// from rendering the argument propagate.
fn write_extract_weekday(b: &mut Sql<'_>, args: &[crate::core::Expr]) -> Result<(), SqlError> {
    if args.len() != 1 {
        return Err(SqlError::FunctionArityMismatch {
            func: "EXTRACT(WEEKDAY)",
            expected: "1",
            got: args.len(),
        });
    }
    // Each dialect differs only in the text wrapped around the argument.
    let (prefix, suffix) = match b.d.name() {
        "postgres" => ("CAST(EXTRACT(DOW FROM ", ") AS INTEGER)"),
        "mysql" => ("(DAYOFWEEK(", ") - 1)"),
        _ => ("CAST(strftime('%w', ", ") AS INTEGER)"),
    };
    b.sql.push_str(prefix);
    write_expr(b, &args[0], None)?;
    b.sql.push_str(suffix);
    Ok(())
}
/// Renders truncation of a date/time expression to the start of its year,
/// month, or day.
///
/// * Postgres: `DATE_TRUNC('<unit>', x)`.
/// * MySQL: `DATE(x)` for days, otherwise `DATE_FORMAT(x, '<pattern>')`.
/// * Any other dialect (SQLite): `date(x)` for days, otherwise
///   `strftime('<pattern>', x)`.
///
/// NOTE(review): the result type differs by dialect (Postgres yields a
/// timestamp, the formatted MySQL/SQLite variants yield strings) — confirm
/// callers tolerate that.
///
/// # Errors
/// `FunctionArityMismatch` unless exactly one argument is supplied; errors
/// from rendering the argument propagate.
///
/// # Panics
/// When `kind` is not `TruncYear`, `TruncMonth` or `TruncDay` (caller bug).
fn write_trunc(
    b: &mut Sql<'_>,
    kind: crate::core::ScalarFn,
    args: &[crate::core::Expr],
) -> Result<(), SqlError> {
    use crate::core::ScalarFn as F;
    if args.len() != 1 {
        return Err(SqlError::FunctionArityMismatch {
            func: "DATE_TRUNC",
            expected: "1",
            got: args.len(),
        });
    }
    // (Postgres DATE_TRUNC unit, strftime / DATE_FORMAT pattern) per granularity.
    let (pg_unit, pattern) = match kind {
        F::TruncYear => ("year", "%Y-01-01"),
        F::TruncMonth => ("month", "%Y-%m-01"),
        F::TruncDay => ("day", "%Y-%m-%d"),
        _ => unreachable!("write_trunc: non-trunc kind: {kind:?}"),
    };
    // Day truncation has a dedicated DATE()/date() function outside Postgres.
    let whole_day = matches!(kind, F::TruncDay);
    let dialect = b.d.name();
    match dialect {
        "postgres" => {
            b.sql.push_str("DATE_TRUNC('");
            b.sql.push_str(pg_unit);
            b.sql.push_str("', ");
            write_expr(b, &args[0], None)?;
            b.sql.push(')');
        }
        "mysql" if whole_day => {
            b.sql.push_str("DATE(");
            write_expr(b, &args[0], None)?;
            b.sql.push(')');
        }
        "mysql" => {
            b.sql.push_str("DATE_FORMAT(");
            write_expr(b, &args[0], None)?;
            b.sql.push_str(", '");
            b.sql.push_str(pattern);
            b.sql.push_str("')");
        }
        _ if whole_day => {
            b.sql.push_str("date(");
            write_expr(b, &args[0], None)?;
            b.sql.push(')');
        }
        _ => {
            b.sql.push_str("strftime('");
            b.sql.push_str(pattern);
            b.sql.push_str("', ");
            write_expr(b, &args[0], None)?;
            b.sql.push(')');
        }
    }
    Ok(())
}
/// Renders a plain function call `name(arg1, arg2, ...)`, writing every
/// argument through `write_expr`. No arity check is performed here.
///
/// # Errors
/// Propagates the first error from rendering an argument.
fn write_call(b: &mut Sql<'_>, name: &str, args: &[crate::core::Expr]) -> Result<(), SqlError> {
    b.sql.push_str(name);
    b.sql.push('(');
    for (idx, arg) in args.iter().enumerate() {
        if idx > 0 {
            b.sql.push_str(", ");
        }
        write_expr(b, arg, None)?;
    }
    b.sql.push(')');
    Ok(())
}
/// Renders `name(arg)` after checking that exactly one argument was supplied.
///
/// # Errors
/// `FunctionArityMismatch` (reporting `name`) for any other argument count;
/// otherwise whatever `write_call` returns.
fn write_call_unary(
    b: &mut Sql<'_>,
    name: &'static str,
    args: &[crate::core::Expr],
) -> Result<(), SqlError> {
    match args.len() {
        1 => write_call(b, name, args),
        got => Err(SqlError::FunctionArityMismatch {
            func: name,
            expected: "1",
            got,
        }),
    }
}
/// Renders `DELETE FROM <table>` followed by the WHERE clause, if any.
///
/// The query's model is pushed onto `scope_stack` for the duration of
/// rendering and popped again on both the success and the error path.
///
/// # Errors
/// Propagates errors from rendering the WHERE clause.
pub(super) fn write_delete(b: &mut Sql<'_>, query: &DeleteQuery) -> Result<(), SqlError> {
    b.scope_stack.push(query.model);
    b.sql.push_str("DELETE FROM ");
    b.write_ident(query.model.table);
    let outcome = write_where(b, &query.where_clause, Some(query.model));
    // Pop unconditionally so a failed render leaves the scope stack balanced.
    b.scope_stack.pop();
    outcome
}
pub(super) fn write_bulk_update_pg(
b: &mut Sql<'_>,
query: &BulkUpdateQuery,
) -> Result<(), SqlError> {
if query.rows.is_empty() {
return Err(SqlError::EmptyBulkInsert);
}
if query.update_columns.is_empty() {
return Err(SqlError::EmptyUpdateSet);
}
let pk_field = query
.model
.primary_key()
.ok_or(SqlError::MissingPrimaryKey)?;
b.sql.push_str("UPDATE ");
b.write_ident(query.model.table);
b.sql.push_str(" SET ");
let mut first = true;
for col in &query.update_columns {
if !first {
b.sql.push_str(", ");
}
first = false;
b.write_ident(col);
b.sql.push_str(" = __data.");
b.write_ident(col);
}
b.sql.push_str(" FROM (VALUES ");
let mut first_row = true;
for row in &query.rows {
if !first_row {
b.sql.push_str(", ");
}
first_row = false;
b.sql.push('(');
for (i, val) in row.iter().enumerate() {
if i > 0 {
b.sql.push_str(", ");
}
b.push_param(val.clone());
}
b.sql.push(')');
}
b.sql.push_str(") AS __data(");
b.write_ident(pk_field.column);
for col in &query.update_columns {
b.sql.push_str(", ");
b.write_ident(col);
}
b.sql.push_str(") WHERE ");
b.write_ident(query.model.table);
b.sql.push('.');
b.write_ident(pk_field.column);
b.sql.push_str(" = __data.");
b.write_ident(pk_field.column);
Ok(())
}
/// Appends ` WHERE <expr>` for a non-empty `where_clause`, writing column
/// references unqualified. An empty clause writes nothing.
///
/// # Errors
/// Propagates errors from rendering the expression tree.
pub(super) fn write_where(
    b: &mut Sql<'_>,
    where_clause: &WhereExpr,
    model: Option<&'static ModelSchema>,
) -> Result<(), SqlError> {
    if !where_clause.is_empty() {
        b.sql.push_str(" WHERE ");
        write_where_expr(b, where_clause, None, model)?;
    }
    Ok(())
}
pub(super) fn write_where_with_search(
b: &mut Sql<'_>,
where_clause: &WhereExpr,
search: Option<&SearchClause>,
qualify_with: Option<&str>,
model: Option<&'static ModelSchema>,
) -> Result<(), SqlError> {
let has_search = search.is_some_and(|s| !s.columns.is_empty() && !s.query.is_empty());
let has_where = !where_clause.is_empty();
if !has_where && !has_search {
return Ok(());
}
b.sql.push_str(" WHERE ");
if has_where {
write_where_expr(b, where_clause, qualify_with, model)?;
}
if has_search {
let s = search.expect("checked above");
if has_where {
b.sql.push_str(" AND ");
}
b.params.push(SqlValue::String(format!("%{}%", s.query)));
let placeholder = b.d.placeholder(b.params.len());
b.sql.push('(');
for (i, col) in s.columns.iter().enumerate() {
if i > 0 {
b.sql.push_str(" OR ");
}
let mut qualified = String::new();
if let Some(table) = qualify_with {
qualified.push_str(&b.d.quote_ident(table));
qualified.push('.');
}
qualified.push_str(&b.d.quote_ident(col));
b.d.write_ilike(&mut b.sql, &qualified, &placeholder, false);
}
b.sql.push(')');
}
Ok(())
}
/// Renders a boolean `WhereExpr` tree, without the leading `WHERE`.
///
/// * `Predicate` / `ColumnCompare` render a single comparison, qualified with
///   `qualify_with` when given.
/// * `And` / `Or` join their children; nested groups are parenthesized by
///   `write_child`. An empty `And` is vacuously true and renders as `1 = 1`.
/// * `Xor` is expanded by `write_xor`; `Not` renders as `NOT (...)`.
/// * `Exists` / `NotExists` / `InSubquery` embed a rendered subquery.
/// * `ExprCompare` compares two arbitrary expressions.
///
/// # Errors
/// `EmptyOrBranch` for an `Or` with no children, plus any error from the
/// child renderers.
pub(super) fn write_where_expr(
    b: &mut Sql<'_>,
    expr: &WhereExpr,
    qualify_with: Option<&str>,
    model: Option<&'static ModelSchema>,
) -> Result<(), SqlError> {
    match expr {
        WhereExpr::Predicate(filter) => write_filter(b, filter, qualify_with, model),
        WhereExpr::ColumnCompare(cf) => write_column_compare(b, cf, qualify_with, model),
        // Emitting nothing for an empty conjunction would leave `()` or
        // `NOT ()` behind when the group is nested, which no dialect accepts.
        WhereExpr::And(items) if items.is_empty() => {
            b.sql.push_str("1 = 1");
            Ok(())
        }
        WhereExpr::And(items) => write_joined(b, items, " AND ", qualify_with, model),
        WhereExpr::Or(items) => {
            if items.is_empty() {
                return Err(SqlError::EmptyOrBranch);
            }
            write_joined(b, items, " OR ", qualify_with, model)
        }
        WhereExpr::Xor(items) => write_xor(b, items, qualify_with, model),
        WhereExpr::Not(child) => {
            b.sql.push_str("NOT (");
            write_where_expr(b, child, qualify_with, model)?;
            b.sql.push(')');
            Ok(())
        }
        WhereExpr::Exists(subq) => {
            b.sql.push_str("EXISTS (");
            write_select(b, subq)?;
            b.sql.push(')');
            Ok(())
        }
        WhereExpr::NotExists(subq) => {
            b.sql.push_str("NOT EXISTS (");
            write_select(b, subq)?;
            b.sql.push(')');
            Ok(())
        }
        WhereExpr::InSubquery {
            column,
            negated,
            subquery,
        } => {
            let qualified = render_qualified_col(b.d, qualify_with, column);
            b.sql.push_str(&qualified);
            b.sql.push_str(if *negated { " NOT IN (" } else { " IN (" });
            write_select(b, subquery)?;
            b.sql.push(')');
            Ok(())
        }
        WhereExpr::ExprCompare { lhs, op, rhs } => write_expr_compare(b, lhs, *op, rhs),
    }
}
/// Renders `lhs <op> rhs` where both sides are arbitrary expressions.
///
/// Comparison and LIKE/NOT LIKE operators render both sides with `write_expr`.
/// The other supported operators require `rhs` to be a literal:
/// * `ILIKE` / `NOT ILIKE`: the literal is bound and handed, together with the
///   rendered LHS, to `Dialect::write_ilike` (after `require_op`).
/// * `IN` / `NOT IN`: a non-empty `SqlValue::List`, each element bound separately.
/// * `BETWEEN`: a two-element `SqlValue::List`.
/// * `IS NULL`: a `SqlValue::Bool`, where `true` gives `IS NULL` and `false`
///   gives `IS NOT NULL`.
///
/// # Errors
/// `OpNotSupportedInDialect` for any other operator or a non-literal ILIKE
/// RHS; `InRequiresList`, `EmptyInList`, `BetweenRequiresTwoElementList`,
/// `IsNullRequiresBool` for malformed literals; rendering errors propagate.
fn write_expr_compare(
    b: &mut Sql<'_>,
    lhs: &crate::core::Expr,
    op: crate::core::Op,
    rhs: &crate::core::Expr,
) -> Result<(), SqlError> {
    use crate::core::{Expr, Op};

    /// Writes `left`, the infix `keyword`, then `right`.
    fn infix(
        b: &mut Sql<'_>,
        left: &crate::core::Expr,
        keyword: &str,
        right: &crate::core::Expr,
    ) -> Result<(), SqlError> {
        write_expr(b, left, None)?;
        b.sql.push_str(keyword);
        write_expr(b, right, None)?;
        Ok(())
    }

    match op {
        Op::Eq => infix(b, lhs, " = ", rhs),
        Op::Ne => infix(b, lhs, " <> ", rhs),
        Op::Lt => infix(b, lhs, " < ", rhs),
        Op::Lte => infix(b, lhs, " <= ", rhs),
        Op::Gt => infix(b, lhs, " > ", rhs),
        Op::Gte => infix(b, lhs, " >= ", rhs),
        Op::Like => infix(b, lhs, " LIKE ", rhs),
        Op::NotLike => infix(b, lhs, " NOT LIKE ", rhs),
        Op::ILike | Op::NotILike => {
            require_op(b.d, op)?;
            // Render the LHS into the buffer, then cut it back out so the
            // dialect hook can place it inside its own ILIKE template.
            let mark = b.sql.len();
            write_expr(b, lhs, None)?;
            let lhs_sql = b.sql.split_off(mark);
            let Expr::Literal(pattern) = rhs else {
                return Err(SqlError::OpNotSupportedInDialect {
                    op: "ILIKE with non-literal RHS in ExprCompare",
                    dialect: b.d.name(),
                });
            };
            b.params.push(pattern.clone());
            let slot = b.d.placeholder(b.params.len());
            b.d.write_ilike(&mut b.sql, &lhs_sql, &slot, matches!(op, Op::NotILike));
            Ok(())
        }
        Op::In | Op::NotIn => {
            let Expr::Literal(SqlValue::List(members)) = rhs else {
                return Err(SqlError::InRequiresList);
            };
            if members.is_empty() {
                return Err(SqlError::EmptyInList);
            }
            write_expr(b, lhs, None)?;
            let opener = if matches!(op, Op::NotIn) {
                " NOT IN ("
            } else {
                " IN ("
            };
            b.sql.push_str(opener);
            for (idx, member) in members.iter().enumerate() {
                if idx > 0 {
                    b.sql.push_str(", ");
                }
                b.push_param_typed(member.clone(), None);
            }
            b.sql.push(')');
            Ok(())
        }
        Op::Between => {
            let Expr::Literal(SqlValue::List(bounds)) = rhs else {
                return Err(SqlError::BetweenRequiresTwoElementList);
            };
            let [low, high] = &bounds[..] else {
                return Err(SqlError::BetweenRequiresTwoElementList);
            };
            write_expr(b, lhs, None)?;
            b.sql.push_str(" BETWEEN ");
            b.push_param_typed(low.clone(), None);
            b.sql.push_str(" AND ");
            b.push_param_typed(high.clone(), None);
            Ok(())
        }
        Op::IsNull => {
            let Expr::Literal(SqlValue::Bool(want_null)) = rhs else {
                return Err(SqlError::IsNullRequiresBool);
            };
            write_expr(b, lhs, None)?;
            b.sql
                .push_str(if *want_null { " IS NULL" } else { " IS NOT NULL" });
            Ok(())
        }
        _ => Err(SqlError::OpNotSupportedInDialect {
            op: "non-binary comparison in ExprCompare",
            dialect: b.d.name(),
        }),
    }
}
fn write_column_compare(
b: &mut Sql<'_>,
cf: &crate::core::ColumnFilter,
qualify_with: Option<&str>,
_model: Option<&'static ModelSchema>,
) -> Result<(), SqlError> {
let qualified = render_qualified_col(b.d, qualify_with, cf.column);
b.sql.push_str(&qualified);
let op_str = match cf.op {
crate::core::Op::Eq => " = ",
crate::core::Op::Ne => " <> ",
crate::core::Op::Lt => " < ",
crate::core::Op::Lte => " <= ",
crate::core::Op::Gt => " > ",
crate::core::Op::Gte => " >= ",
_ => {
return Err(SqlError::OpNotSupportedInDialect {
op: "non-binary comparison in ColumnCompare",
dialect: b.d.name(),
});
}
};
b.sql.push_str(op_str);
write_expr(b, &cf.rhs, None)?;
Ok(())
}
/// Renders `items` separated by `sep` (e.g. `" AND "`), delegating each
/// element to `write_child` so nested boolean groups get parentheses.
///
/// # Errors
/// Propagates the first error from rendering a child.
fn write_joined(
    b: &mut Sql<'_>,
    items: &[WhereExpr],
    sep: &str,
    qualify_with: Option<&str>,
    model: Option<&'static ModelSchema>,
) -> Result<(), SqlError> {
    for (idx, child) in items.iter().enumerate() {
        if idx != 0 {
            b.sql.push_str(sep);
        }
        write_child(b, child, qualify_with, model)?;
    }
    Ok(())
}
/// Renders one operand of an AND/OR/XOR chain.
///
/// Compound nodes (`And`, `Or`, `Xor`, `Not`) are wrapped in parentheses.
/// All other nodes are written bare.
///
/// # Errors
/// Propagates errors from `write_where_expr`.
fn write_child(
    b: &mut Sql<'_>,
    expr: &WhereExpr,
    qualify_with: Option<&str>,
    model: Option<&'static ModelSchema>,
) -> Result<(), SqlError> {
    // Exhaustive on purpose: a new variant must decide whether it needs parens.
    let grouped = match expr {
        WhereExpr::And(_) | WhereExpr::Or(_) | WhereExpr::Xor(_) | WhereExpr::Not(_) => true,
        WhereExpr::Predicate(_)
        | WhereExpr::ColumnCompare(_)
        | WhereExpr::Exists(_)
        | WhereExpr::NotExists(_)
        | WhereExpr::InSubquery { .. }
        | WhereExpr::ExprCompare { .. } => false,
    };
    if grouped {
        b.sql.push('(');
    }
    write_where_expr(b, expr, qualify_with, model)?;
    if grouped {
        b.sql.push(')');
    }
    Ok(())
}
/// Renders an exclusive-or over `items`.
///
/// * one item: rendered as-is;
/// * two items `a`, `b`: `(a AND NOT (b)) OR (NOT (a) AND b)`;
/// * three or more: counts true operands as
///   `((CASE WHEN a THEN 1 ELSE 0 END) + ...) % 2 = 1`, i.e. odd parity.
///
/// In the two-item form each operand is rendered twice, so its parameters
/// are bound twice.
///
/// # Errors
/// `EmptyXorBranch` for an empty list; rendering errors propagate.
fn write_xor(
    b: &mut Sql<'_>,
    items: &[WhereExpr],
    qualify_with: Option<&str>,
    model: Option<&'static ModelSchema>,
) -> Result<(), SqlError> {
    match items {
        [] => Err(SqlError::EmptyXorBranch),
        [only] => write_where_expr(b, only, qualify_with, model),
        [left, right] => {
            b.sql.push('(');
            write_child(b, left, qualify_with, model)?;
            b.sql.push_str(" AND NOT (");
            write_where_expr(b, right, qualify_with, model)?;
            b.sql.push_str(")) OR (NOT (");
            write_where_expr(b, left, qualify_with, model)?;
            b.sql.push_str(") AND ");
            write_child(b, right, qualify_with, model)?;
            b.sql.push(')');
            Ok(())
        }
        _ => {
            b.sql.push('(');
            for (idx, child) in items.iter().enumerate() {
                if idx > 0 {
                    b.sql.push_str(" + ");
                }
                b.sql.push_str("(CASE WHEN ");
                write_child(b, child, qualify_with, model)?;
                b.sql.push_str(" THEN 1 ELSE 0 END)");
            }
            b.sql.push_str(") % 2 = 1");
            Ok(())
        }
    }
}
/// Renders a single column predicate (`Filter`), e.g. `"t"."col" = <param>`.
///
/// `qualify_with`, when set, prefixes the column with a quoted table/alias.
/// `model`, when known, supplies the column's NULL cast via `null_cast_for`.
/// That cast is used by the operators that bind through `push_param_typed`:
/// comparisons, LIKE, IN and BETWEEN.
///
/// Dialect-specific operators first call `require_op`, then push the value
/// into `b.params` and pass the resulting placeholder to the matching
/// `Dialect::write_*` hook.
///
/// # Errors
/// * `OperatorNotSupportedInDialect` from `require_op`.
/// * Shape errors when the value has the wrong variant: `InRequiresList`,
///   `EmptyInList`, `BetweenRequiresTwoElementList`, `IsNullRequiresBool`,
///   `ArrayOpRequiresArray`, `JsonOpRequiresJson`, `JsonKeyRequiresString`,
///   `JsonKeysRequiresList`.
/// * Any error returned by the dialect hooks that return `Result`.
// One arm per `Op` variant (the match is exhaustive), hence the length allowance.
#[allow(clippy::too_many_lines)] fn write_filter(
    b: &mut Sql<'_>,
    filter: &Filter,
    qualify_with: Option<&str>,
    model: Option<&'static ModelSchema>,
) -> Result<(), SqlError> {
    let qualified_col = render_qualified_col(b.d, qualify_with, filter.column);
    // `push_param_typed` only appends this cast when the bound value is NULL.
    let cast = model.and_then(|m| null_cast_for(b.d, m, filter.column));
    match filter.op {
        Op::Eq => simple_op(b, &qualified_col, " = ", filter.value.clone(), cast),
        Op::Ne => simple_op(b, &qualified_col, " <> ", filter.value.clone(), cast),
        Op::Lt => simple_op(b, &qualified_col, " < ", filter.value.clone(), cast),
        Op::Lte => simple_op(b, &qualified_col, " <= ", filter.value.clone(), cast),
        Op::Gt => simple_op(b, &qualified_col, " > ", filter.value.clone(), cast),
        Op::Gte => simple_op(b, &qualified_col, " >= ", filter.value.clone(), cast),
        Op::Like => simple_op(b, &qualified_col, " LIKE ", filter.value.clone(), cast),
        Op::NotLike => simple_op(b, &qualified_col, " NOT LIKE ", filter.value.clone(), cast),
        Op::ILike | Op::NotILike => {
            require_op(b.d, filter.op)?;
            // Bind first so the placeholder index matches the new parameter.
            b.params.push(filter.value.clone());
            let p = b.d.placeholder(b.params.len());
            b.d.write_ilike(
                &mut b.sql,
                &qualified_col,
                &p,
                matches!(filter.op, Op::NotILike),
            );
        }
        Op::Regex | Op::NotRegex | Op::IRegex | Op::NotIRegex => {
            require_op(b.d, filter.op)?;
            b.params.push(filter.value.clone());
            let p = b.d.placeholder(b.params.len());
            // The two flags are presumably (case-sensitive, negated), judging by the
            // variants each one selects — confirm against `Dialect::write_regex`.
            b.d.write_regex(
                &mut b.sql,
                &qualified_col,
                &p,
                matches!(filter.op, Op::Regex | Op::NotRegex),
                matches!(filter.op, Op::NotRegex | Op::NotIRegex),
            );
        }
        Op::TrigramSimilar | Op::TrigramWordSimilar => {
            require_op(b.d, filter.op)?;
            b.params.push(filter.value.clone());
            let p = b.d.placeholder(b.params.len());
            b.d.write_trigram_similar(
                &mut b.sql,
                &qualified_col,
                &p,
                matches!(filter.op, Op::TrigramWordSimilar),
            )?;
        }
        Op::Search => {
            require_op(b.d, filter.op)?;
            b.params.push(filter.value.clone());
            let p = b.d.placeholder(b.params.len());
            b.d.write_search(&mut b.sql, &qualified_col, &p)?;
        }
        Op::ArrayContains | Op::ArrayContainedBy | Op::ArrayOverlap => {
            // The value shape is validated before the dialect capability check here
            // (the JSON arms below check the dialect first).
            if !matches!(filter.value, SqlValue::Array(_)) {
                return Err(SqlError::ArrayOpRequiresArray);
            }
            require_op(b.d, filter.op)?;
            b.params.push(filter.value.clone());
            let p = b.d.placeholder(b.params.len());
            let op_str: &'static str = match filter.op {
                Op::ArrayContains => "@>",
                Op::ArrayContainedBy => "<@",
                Op::ArrayOverlap => "&&",
                _ => unreachable!(),
            };
            b.d.write_array_op(&mut b.sql, &qualified_col, &p, op_str)?;
        }
        Op::RangeContains
        | Op::RangeContainedBy
        | Op::RangeOverlap
        | Op::RangeStrictlyLeft
        | Op::RangeStrictlyRight
        | Op::RangeAdjacent => {
            require_op(b.d, filter.op)?;
            b.params.push(filter.value.clone());
            let p = b.d.placeholder(b.params.len());
            let op_str: &'static str = match filter.op {
                Op::RangeContains => "@>",
                Op::RangeContainedBy => "<@",
                Op::RangeOverlap => "&&",
                Op::RangeStrictlyLeft => "<<",
                Op::RangeStrictlyRight => ">>",
                Op::RangeAdjacent => "-|-",
                _ => unreachable!(),
            };
            b.d.write_range_op(&mut b.sql, &qualified_col, &p, op_str)?;
        }
        Op::In | Op::NotIn => {
            let SqlValue::List(elements) = &filter.value else {
                return Err(SqlError::InRequiresList);
            };
            // `IN ()` is not valid SQL, so an empty list is rejected up front.
            if elements.is_empty() {
                return Err(SqlError::EmptyInList);
            }
            b.sql.push_str(&qualified_col);
            b.sql.push_str(if matches!(filter.op, Op::In) {
                " IN ("
            } else {
                " NOT IN ("
            });
            let mut first = true;
            for elem in elements {
                if !first {
                    b.sql.push_str(", ");
                }
                first = false;
                b.push_param_typed(elem.clone(), cast);
            }
            b.sql.push(')');
        }
        Op::Between => {
            let SqlValue::List(bounds) = &filter.value else {
                return Err(SqlError::BetweenRequiresTwoElementList);
            };
            if bounds.len() != 2 {
                return Err(SqlError::BetweenRequiresTwoElementList);
            }
            b.sql.push_str(&qualified_col);
            b.sql.push_str(" BETWEEN ");
            b.push_param_typed(bounds[0].clone(), cast);
            b.sql.push_str(" AND ");
            b.push_param_typed(bounds[1].clone(), cast);
        }
        Op::IsNull => {
            // The Bool selects the form: true -> IS NULL, false -> IS NOT NULL.
            // Nothing is bound.
            let SqlValue::Bool(is_null) = filter.value else {
                return Err(SqlError::IsNullRequiresBool);
            };
            b.sql.push_str(&qualified_col);
            b.sql
                .push_str(if is_null { " IS NULL" } else { " IS NOT NULL" });
        }
        Op::IsDistinctFrom | Op::IsNotDistinctFrom => {
            require_op(b.d, filter.op)?;
            b.params.push(filter.value.clone());
            let p = b.d.placeholder(b.params.len());
            // The final flag is true for IS DISTINCT FROM, false for IS NOT DISTINCT FROM.
            b.d.write_null_safe_eq(
                &mut b.sql,
                &qualified_col,
                &p,
                matches!(filter.op, Op::IsDistinctFrom),
            );
        }
        Op::JsonContains => {
            require_op(b.d, filter.op)?;
            let SqlValue::Json(_) = &filter.value else {
                return Err(SqlError::JsonOpRequiresJson);
            };
            b.params.push(filter.value.clone());
            let p = b.d.placeholder(b.params.len());
            b.d.write_json_contains(&mut b.sql, &qualified_col, &p);
        }
        Op::JsonContainedBy => {
            require_op(b.d, filter.op)?;
            let SqlValue::Json(_) = &filter.value else {
                return Err(SqlError::JsonOpRequiresJson);
            };
            b.params.push(filter.value.clone());
            let p = b.d.placeholder(b.params.len());
            b.d.write_json_contained_by(&mut b.sql, &qualified_col, &p);
        }
        Op::JsonHasKey => {
            require_op(b.d, filter.op)?;
            let SqlValue::String(_) = &filter.value else {
                return Err(SqlError::JsonKeyRequiresString);
            };
            b.params.push(filter.value.clone());
            let p = b.d.placeholder(b.params.len());
            b.d.write_json_has_key(&mut b.sql, &qualified_col, &p);
        }
        Op::JsonHasAnyKey | Op::JsonHasAllKeys => {
            require_op(b.d, filter.op)?;
            let SqlValue::List(keys) = &filter.value else {
                return Err(SqlError::JsonKeysRequiresList);
            };
            // Each key is bound as its own parameter, one placeholder per key.
            let placeholders = bind_param_list(b, keys);
            if matches!(filter.op, Op::JsonHasAnyKey) {
                b.d.write_json_has_any_keys(&mut b.sql, &qualified_col, &placeholders);
            } else {
                b.d.write_json_has_all_keys(&mut b.sql, &qualified_col, &placeholders);
            }
        }
    }
    Ok(())
}
/// Writes `<qualified_col><kw><param>`, binding `value`.
///
/// When `value` is NULL, `cast` (if any) is appended to the placeholder.
fn simple_op(
    b: &mut Sql<'_>,
    qualified_col: &str,
    kw: &str,
    value: SqlValue,
    cast: Option<&'static str>,
) {
    b.sql.extend([qualified_col, kw]);
    b.push_param_typed(value, cast);
}
/// Returns `column` quoted for the dialect. When `qualify_with` is given, the
/// result is prefixed with the quoted table/alias and a dot, e.g. `"t"."col"`.
fn render_qualified_col(d: &dyn Dialect, qualify_with: Option<&str>, column: &str) -> String {
    match qualify_with {
        Some(table) => format!("{}.{}", d.quote_ident(table), d.quote_ident(column)),
        None => d.quote_ident(column).to_string(),
    }
}
/// Binds each of `values` as its own parameter, in order, and returns the
/// placeholder text generated for each one.
fn bind_param_list(b: &mut Sql<'_>, values: &[SqlValue]) -> Vec<String> {
    values
        .iter()
        .map(|value| {
            b.params.push(value.clone());
            b.d.placeholder(b.params.len())
        })
        .collect()
}
/// Checks that dialect `d` supports `op`.
///
/// # Errors
/// `OperatorNotSupportedInDialect`, labelled via `op_label`, when it does not.
fn require_op(d: &dyn Dialect, op: Op) -> Result<(), SqlError> {
    if !d.supports_op(op) {
        return Err(SqlError::OperatorNotSupportedInDialect {
            op: op_label(op),
            dialect: d.name(),
        });
    }
    Ok(())
}
/// Human-readable label for `op`, used in `OperatorNotSupportedInDialect`
/// errors. Most dialect-specific operators carry a parenthesized hint naming
/// the feature, since several share the same symbol.
fn op_label(op: Op) -> &'static str {
    match op {
        // Plain comparisons.
        Op::Eq => "=",
        Op::Ne => "<>",
        Op::Lt => "<",
        Op::Lte => "<=",
        Op::Gt => ">",
        Op::Gte => ">=",
        // Membership, ranges of values and NULL handling.
        Op::In => "IN",
        Op::NotIn => "NOT IN",
        Op::Between => "BETWEEN",
        Op::IsNull => "IS NULL",
        Op::IsDistinctFrom => "IS DISTINCT FROM",
        Op::IsNotDistinctFrom => "IS NOT DISTINCT FROM",
        // Text matching.
        Op::Like => "LIKE",
        Op::NotLike => "NOT LIKE",
        Op::ILike => "ILIKE",
        Op::NotILike => "NOT ILIKE",
        Op::Regex => "~ (regex)",
        Op::NotRegex => "!~ (regex)",
        Op::IRegex => "~* (iregex)",
        Op::NotIRegex => "!~* (iregex)",
        Op::TrigramSimilar => "% (trigram_similar)",
        Op::TrigramWordSimilar => "%> (trigram_word_similar)",
        Op::Search => "@@ (search)",
        // JSON.
        Op::JsonContains => "@>",
        Op::JsonContainedBy => "<@",
        Op::JsonHasKey => "? (json)",
        Op::JsonHasAnyKey => "?| (json)",
        Op::JsonHasAllKeys => "?& (json)",
        // Arrays.
        Op::ArrayContains => "@> (array_contains)",
        Op::ArrayContainedBy => "<@ (array_contained_by)",
        Op::ArrayOverlap => "&& (array_overlap)",
        // Range types.
        Op::RangeContains => "@> (range_contains)",
        Op::RangeContainedBy => "<@ (range_contained_by)",
        Op::RangeOverlap => "&& (range_overlap)",
        Op::RangeStrictlyLeft => "<< (range_strictly_left)",
        Op::RangeStrictlyRight => ">> (range_strictly_right)",
        Op::RangeAdjacent => "-|- (range_adjacent)",
    }
}
/// Appends ` ORDER BY ...`, ` LIMIT n` and ` OFFSET n` as requested.
///
/// NULL ordering: when the dialect supports `NULLS FIRST/LAST` it is emitted
/// directly. Otherwise a non-default request is emulated with an extra
/// `<target> IS NULL DESC` (nulls first) or `<target> IS NULL ASC` (nulls
/// last) sort key placed before the target itself.
///
/// Offset without limit: MySQL and SQLite only accept OFFSET as part of a
/// LIMIT clause. An "unbounded" LIMIT is emitted for them: the documented
/// maximum `18446744073709551615` for MySQL, and `-1` (no upper bound) for
/// any other non-Postgres dialect. Postgres accepts a bare OFFSET.
///
/// # Errors
/// Propagates errors from rendering order targets.
fn write_order_limit_offset(
    b: &mut Sql<'_>,
    order_by: &[crate::core::OrderItem],
    limit: Option<i64>,
    offset: Option<i64>,
    qualify_with: Option<&str>,
) -> Result<(), SqlError> {
    use crate::core::{NullsOrder, OrderItem};
    if !order_by.is_empty() {
        b.sql.push_str(" ORDER BY ");
        let supports_nulls = b.d.supports_nulls_order();
        let mut first = true;
        for item in order_by {
            let (desc, nulls) = match item {
                OrderItem::Column { desc, nulls, .. } => (*desc, *nulls),
                OrderItem::Expr { desc, nulls, .. } => (*desc, *nulls),
                OrderItem::Random => (false, NullsOrder::Default),
            };
            // Emulate NULLS FIRST/LAST with a leading `IS NULL` sort key.
            if !supports_nulls && !matches!(nulls, NullsOrder::Default) {
                if !first {
                    b.sql.push_str(", ");
                }
                first = false;
                write_order_target(b, item, qualify_with)?;
                b.sql.push_str(" IS NULL");
                match nulls {
                    NullsOrder::First => b.sql.push_str(" DESC"),
                    NullsOrder::Last => b.sql.push_str(" ASC"),
                    NullsOrder::Default => unreachable!(),
                }
            }
            if !first {
                b.sql.push_str(", ");
            }
            first = false;
            write_order_target(b, item, qualify_with)?;
            if desc {
                b.sql.push_str(" DESC");
            }
            if supports_nulls {
                match nulls {
                    NullsOrder::First => b.sql.push_str(" NULLS FIRST"),
                    NullsOrder::Last => b.sql.push_str(" NULLS LAST"),
                    NullsOrder::Default => {}
                }
            }
        }
    }
    match (limit, offset) {
        (Some(n), _) => {
            let _ = write!(b.sql, " LIMIT {n}");
        }
        // OFFSET alone is a syntax error outside Postgres; see doc comment.
        (None, Some(_)) => match b.d.name() {
            "postgres" => {}
            "mysql" => b.sql.push_str(" LIMIT 18446744073709551615"),
            _ => b.sql.push_str(" LIMIT -1"),
        },
        (None, None) => {}
    }
    if let Some(n) = offset {
        let _ = write!(b.sql, " OFFSET {n}");
    }
    Ok(())
}
/// Writes the sort key of one ORDER BY item, without direction or NULL
/// ordering.
///
/// * `Column`: the quoted column, prefixed with the quoted `qualify_with`
///   table/alias when given.
/// * `Expr`: the rendered expression.
/// * `Random`: the dialect's random function followed by `()`.
///
/// # Errors
/// Propagates errors from rendering an expression item.
fn write_order_target(
    b: &mut Sql<'_>,
    item: &crate::core::OrderItem,
    qualify_with: Option<&str>,
) -> Result<(), SqlError> {
    use crate::core::OrderItem;
    match item {
        OrderItem::Random => {
            let func = b.d.random_fn();
            b.sql.push_str(func);
            b.sql.push_str("()");
        }
        OrderItem::Expr { expr, .. } => {
            write_expr(b, expr, None)?;
        }
        OrderItem::Column { column, .. } => {
            if let Some(table) = qualify_with {
                b.write_ident(table);
                b.sql.push('.');
            }
            b.write_ident(column);
        }
    }
    Ok(())
}
/// Appends ` RETURNING col1, col2, ...` when `returning` is non-empty.
/// Writes nothing for an empty list.
///
/// # Errors
/// `OperatorNotSupportedInDialect` (op `"RETURNING"`) when columns are
/// requested but the dialect reports no RETURNING support.
fn write_returning(b: &mut Sql<'_>, returning: &[&'static str]) -> Result<(), SqlError> {
    let Some((head, tail)) = returning.split_first() else {
        return Ok(());
    };
    if !b.d.supports_returning() {
        return Err(SqlError::OperatorNotSupportedInDialect {
            op: "RETURNING",
            dialect: b.d.name(),
        });
    }
    b.sql.push_str(" RETURNING ");
    b.write_ident(head);
    for col in tail {
        b.sql.push_str(", ");
        b.write_ident(col);
    }
    Ok(())
}
/// Compiles only the trailing part of a query (WHERE with optional search,
/// ORDER BY, LIMIT, OFFSET) into a standalone statement fragment whose SQL
/// starts with a leading space.
///
/// # Errors
/// Propagates errors from rendering the WHERE clause or the order targets.
#[allow(unused)]
#[allow(clippy::too_many_arguments)]
pub(crate) fn compile_where_order_tail(
    d: &dyn Dialect,
    where_clause: &WhereExpr,
    search: Option<&SearchClause>,
    order_by: &[crate::core::OrderItem],
    limit: Option<i64>,
    offset: Option<i64>,
    qualify_with: Option<&str>,
    model: Option<&'static ModelSchema>,
) -> Result<CompiledStatement, SqlError> {
    let mut tail = Sql::new(d);
    write_where_with_search(&mut tail, where_clause, search, qualify_with, model)?;
    write_order_limit_offset(&mut tail, order_by, limit, offset, qualify_with)?;
    Ok(tail.finish())
}