use std::marker::PhantomData;
use crate::core::{
AggregateExpr, AggregateQuery, Assignment, DeleteQuery, Filter, Model, ModelSchema, Op,
QueryError, SelectQuery, SqlValue, TypedAssignment, TypedExpr, UpdateQuery, WhereExpr,
};
/// Lazily-built `SELECT` over model `T`.
///
/// Builder methods only record intent. Field names, lookups and
/// `select_related` targets are resolved against `T::SCHEMA` when one of the
/// `compile*` methods runs, and any accumulated error surfaces there.
pub struct QuerySet<T: Model> {
    // WHERE conjuncts in call order; `resolve_pending` ANDs them together.
    pending: Vec<PendingFilter>,
    limit: Option<i64>,
    offset: Option<i64>,
    // Relation field names, lowered into LEFT joins by `lower_select_related`.
    select_related: Vec<String>,
    // Caller-supplied joins, appended after the `select_related` joins.
    ad_hoc_joins: Vec<crate::core::Join>,
    order_by: Vec<PendingOrderItem>,
    // `None` until one of the row-lock builder methods is called.
    lock_mode: Option<crate::core::LockMode>,
    // Set-operation branches (union / intersection / difference).
    compound: Vec<crate::core::CompoundBranch>,
    // `fn() -> T` marks the model type without storing a `T`. Function pointers
    // are always `Send + Sync`, so `T` itself does not affect those auto traits.
    _model: PhantomData<fn() -> T>,
}
/// An `ORDER BY` entry as recorded by the builder. `lower_order_items` turns
/// it into a `crate::core::OrderItem` at compile time.
#[derive(Debug, Clone)]
enum PendingOrderItem {
    /// A model field referenced by its *field* name (resolved to a column later).
    Field {
        name: String,
        desc: bool,
        nulls: crate::core::NullsOrder,
    },
    /// An arbitrary expression, passed through without schema validation.
    Expr {
        expr: crate::core::Expr,
        desc: bool,
        nulls: crate::core::NullsOrder,
    },
    /// Random ordering (lowered through `OrderItem::random()`).
    Random,
}
/// One `WHERE` conjunct as recorded by the builder.
enum PendingFilter {
    /// Field name + operator + value; resolved and type-checked at compile time.
    Raw(RawFilter),
    /// Already bound to a column (e.g. produced by the typed `where_` API).
    Resolved(Filter),
    /// A pre-built expression, used as-is.
    Expr(WhereExpr),
    /// A builder-time error; `resolve_pending` returns it when it is reached.
    Error(QueryError),
}
/// An unresolved `field <op> value` predicate, keyed by field name.
#[derive(Debug, Clone)]
struct RawFilter {
    field: String,
    op: Op,
    value: SqlValue,
}
/// An unresolved `SET field = <literal>` assignment, keyed by field name.
#[derive(Debug, Clone)]
struct RawAssignment {
    field: String,
    value: SqlValue,
}
/// An unresolved `SET field = <expression>` assignment, keyed by field name.
#[derive(Debug, Clone)]
struct RawExprAssignment {
    field: String,
    value: crate::core::Expr,
}
impl<T: Model> Default for QuerySet<T> {
fn default() -> Self {
Self::new()
}
}
impl<T: Model> QuerySet<T> {
    /// Creates an empty queryset over `T`: no filters, joins, ordering,
    /// limit/offset, row lock or set-operation branches.
    #[must_use]
    pub fn new() -> Self {
        Self {
            pending: Vec::new(),
            limit: None,
            offset: None,
            select_related: Vec::new(),
            ad_hoc_joins: Vec::new(),
            order_by: Vec::new(),
            lock_mode: None,
            compound: Vec::new(),
            _model: PhantomData,
        }
    }

    /// Requests a row lock with `LockMode::default()` settings.
    ///
    /// This *replaces* any lock options set earlier in the chain, so call it
    /// before `skip_locked` / `nowait` / `no_key` / `of`.
    #[must_use]
    pub fn select_for_update(mut self) -> Self {
        self.lock_mode = Some(crate::core::LockMode::default());
        self
    }

    /// Sets `skip_locked` on the row lock, creating a default lock first if
    /// none was requested.
    #[must_use]
    pub fn skip_locked(self) -> Self {
        self.modify_lock(|lock| lock.skip_locked = true)
    }

    /// Sets `nowait` on the row lock, creating a default lock first if none
    /// was requested.
    #[must_use]
    pub fn nowait(self) -> Self {
        self.modify_lock(|lock| lock.nowait = true)
    }

    /// Sets `no_key` on the row lock, creating a default lock first if none
    /// was requested.
    #[must_use]
    pub fn no_key(self) -> Self {
        self.modify_lock(|lock| lock.no_key = true)
    }

    /// Appends `tables` to the row lock's `of` list, creating a default lock
    /// first if needed. Repeated calls accumulate.
    #[must_use]
    pub fn of(self, tables: &[&'static str]) -> Self {
        self.modify_lock(|lock| lock.of.extend_from_slice(tables))
    }

    /// Shared body of the lock-option setters. Takes the current lock mode
    /// (or `LockMode::default()`), lets `edit` mutate it, and stores it back.
    fn modify_lock(mut self, edit: impl FnOnce(&mut crate::core::LockMode)) -> Self {
        let mut lock = self.lock_mode.take().unwrap_or_default();
        edit(&mut lock);
        self.lock_mode = Some(lock);
        self
    }

    /// Appends `other` as a `SetOp::Union` branch.
    ///
    /// If `other` fails to compile, its error is recorded and returned by this
    /// queryset's `compile()` instead of panicking.
    #[must_use]
    pub fn union(self, other: QuerySet<T>) -> Self {
        self.add_compound(crate::core::SetOp::Union, other)
    }

    /// Appends `other` as a `SetOp::UnionAll` branch (compile errors deferred
    /// as for [`QuerySet::union`]).
    #[must_use]
    pub fn union_all(self, other: QuerySet<T>) -> Self {
        self.add_compound(crate::core::SetOp::UnionAll, other)
    }

    /// Appends `other` as a `SetOp::Intersection` branch (compile errors
    /// deferred as for [`QuerySet::union`]).
    #[must_use]
    pub fn intersection(self, other: QuerySet<T>) -> Self {
        self.add_compound(crate::core::SetOp::Intersection, other)
    }

    /// Appends `other` as a `SetOp::Difference` branch (compile errors
    /// deferred as for [`QuerySet::union`]).
    #[must_use]
    pub fn difference(self, other: QuerySet<T>) -> Self {
        self.add_compound(crate::core::SetOp::Difference, other)
    }

    /// Appends an already-compiled `branch` under set operation `op`.
    #[must_use]
    pub fn with_compound(self, op: crate::core::SetOp, branch: crate::core::SelectQuery) -> Self {
        self.add_compound_compiled(op, branch)
    }

    /// Compiles `other` and appends it as an `op` branch.
    ///
    /// A branch that fails to compile is stored as a pending error, the same
    /// way `filter` defers lookup-parse errors. Every compile path that
    /// resolves this queryset's filters (`compile`, `compile_delete`, values,
    /// update and aggregate compiles) will then return it. The builder does
    /// not panic.
    fn add_compound(self, op: crate::core::SetOp, other: QuerySet<T>) -> Self {
        match other.compile() {
            Ok(branch) => self.add_compound_compiled(op, branch),
            Err(e) => self.with_pending_error(e),
        }
    }

    fn add_compound_compiled(
        mut self,
        op: crate::core::SetOp,
        branch: crate::core::SelectQuery,
    ) -> Self {
        self.compound.push(crate::core::CompoundBranch {
            op,
            query: Box::new(branch),
        });
        self
    }

    /// Appends `(field_name, descending)` ordering entries with the default
    /// NULLS placement. Field names are resolved at compile time.
    #[must_use]
    pub fn order_by(mut self, items: &[(&str, bool)]) -> Self {
        for (field, desc) in items {
            self.order_by.push(PendingOrderItem::Field {
                name: (*field).to_owned(),
                desc: *desc,
                nulls: crate::core::NullsOrder::Default,
            });
        }
        self
    }

    /// Like [`QuerySet::order_by`], but with an explicit NULLS placement per
    /// entry. Accepts any `&str` field name; the names are copied.
    #[must_use]
    pub fn order_by_with_nulls(
        mut self,
        items: &[(&str, bool, crate::core::NullsOrder)],
    ) -> Self {
        for (field, desc, nulls) in items {
            self.order_by.push(PendingOrderItem::Field {
                name: (*field).to_owned(),
                desc: *desc,
                nulls: *nulls,
            });
        }
        self
    }

    /// Appends an expression ordering entry with the default NULLS placement.
    /// The expression is not validated against the schema.
    #[must_use]
    pub fn order_by_expr(mut self, expr: impl Into<crate::core::Expr>, desc: bool) -> Self {
        self.order_by.push(PendingOrderItem::Expr {
            expr: expr.into(),
            desc,
            nulls: crate::core::NullsOrder::Default,
        });
        self
    }

    /// Appends an expression ordering entry with an explicit NULLS placement.
    #[must_use]
    pub fn order_by_expr_with_nulls(
        mut self,
        expr: impl Into<crate::core::Expr>,
        desc: bool,
        nulls: crate::core::NullsOrder,
    ) -> Self {
        self.order_by.push(PendingOrderItem::Expr {
            expr: expr.into(),
            desc,
            nulls,
        });
        self
    }

    /// Appends a random-ordering entry.
    #[must_use]
    pub fn order_random(mut self) -> Self {
        self.order_by.push(PendingOrderItem::Random);
        self
    }

    /// Discards every existing ordering entry (fields, expressions and random),
    /// then behaves like [`QuerySet::order_by`].
    #[must_use]
    pub fn replace_order_by(mut self, items: &[(&str, bool)]) -> Self {
        self.order_by.clear();
        self.order_by(items)
    }

    /// Reverses every field/expression ordering entry: it toggles `desc` and
    /// swaps explicit `First`/`Last` NULLS placement. `Default` NULLS placement
    /// and random entries are left unchanged.
    #[must_use]
    pub fn flip_order_by(mut self) -> Self {
        for entry in &mut self.order_by {
            match entry {
                PendingOrderItem::Field { desc, nulls, .. }
                | PendingOrderItem::Expr { desc, nulls, .. } => {
                    *desc = !*desc;
                    *nulls = match *nulls {
                        crate::core::NullsOrder::First => crate::core::NullsOrder::Last,
                        crate::core::NullsOrder::Last => crate::core::NullsOrder::First,
                        crate::core::NullsOrder::Default => crate::core::NullsOrder::Default,
                    };
                }
                PendingOrderItem::Random => {}
            }
        }
        self
    }

    /// Whether any ordering entry has been recorded.
    #[must_use]
    pub fn has_order_by(&self) -> bool {
        !self.order_by.is_empty()
    }

    /// Requests a LEFT join for the FK / one-to-one field `field`. The field
    /// is validated at compile time; repeated names are joined once.
    #[must_use]
    pub fn select_related(mut self, field: impl Into<String>) -> Self {
        self.select_related.push(field.into());
        self
    }

    /// Adds a caller-built join. At compile time it is emitted after the
    /// `select_related` joins.
    #[must_use]
    pub fn join(mut self, join: crate::core::Join) -> Self {
        self.ad_hoc_joins.push(join);
        self
    }

    /// Sets the row limit (last call wins).
    #[must_use]
    pub fn limit(mut self, n: i64) -> Self {
        self.limit = Some(n);
        self
    }

    /// Sets the row offset (last call wins).
    #[must_use]
    pub fn offset(mut self, n: i64) -> Self {
        self.offset = Some(n);
        self
    }

    /// Adds the conjunct `field <op> value`. The field name, and for ops other
    /// than `IsNull` / `In` the value's type, are checked at compile time.
    #[must_use]
    pub fn filter_op(
        mut self,
        field: impl Into<String>,
        op: Op,
        value: impl Into<SqlValue>,
    ) -> Self {
        self.pending.push(PendingFilter::Raw(RawFilter {
            field: field.into(),
            op,
            value: value.into(),
        }));
        self
    }

    /// Adds a Django-style lookup such as `"age__gte"` or `"name__icontains"`.
    ///
    /// A key without `__` means equality. An unknown suffix, or a value of the
    /// wrong shape for the suffix, is recorded and reported by `compile()`.
    #[must_use]
    pub fn filter(self, key: &str, value: impl Into<SqlValue>) -> Self {
        let raw = value.into();
        match parse_lookup(key, raw) {
            Ok((field, op, parsed_value)) => self.filter_op(field, op, parsed_value),
            Err(e) => self.with_pending_error(e),
        }
    }

    /// Records `e` so that compilation fails with it.
    fn with_pending_error(mut self, e: QueryError) -> Self {
        self.pending.push(PendingFilter::Error(e));
        self
    }

    /// Shorthand for `filter_op(field, Op::Eq, value)`.
    #[must_use]
    pub fn eq(self, field: impl Into<String>, value: impl Into<SqlValue>) -> Self {
        self.filter_op(field, Op::Eq, value)
    }

    /// Adds a pre-built WHERE expression verbatim. This builder does not
    /// validate it against the schema.
    #[must_use]
    pub fn where_raw(mut self, expr: WhereExpr) -> Self {
        self.pending.push(PendingFilter::Expr(expr));
        self
    }

    /// Adds a typed predicate. A single-predicate expression is stored as a
    /// resolved filter; anything else is stored as an expression.
    #[must_use]
    pub fn where_<E: Into<TypedExpr<T>>>(mut self, predicate: E) -> Self {
        let expr = predicate.into().into_expr();
        match expr {
            WhereExpr::Predicate(filter) => {
                self.pending.push(PendingFilter::Resolved(filter));
            }
            other => {
                self.pending.push(PendingFilter::Expr(other));
            }
        }
        self
    }

    /// Compiles into a `SelectQuery` over `T::SCHEMA`.
    ///
    /// Filters are ANDed into the WHERE clause. `select_related` joins come
    /// first, followed by ad-hoc joins. Ordering field names are resolved to
    /// columns.
    ///
    /// # Errors
    /// The first recorded builder error (bad lookup, failed set-operation
    /// branch), an unknown or mistyped filter field, an invalid
    /// `select_related` field, or an unknown ordering field.
    pub fn compile(self) -> Result<SelectQuery, QueryError> {
        let model: &'static ModelSchema = T::SCHEMA;
        let where_clause = resolve_pending(model, self.pending)?;
        let mut joins = lower_select_related(model, &self.select_related)?;
        joins.extend(self.ad_hoc_joins);
        let order_by = lower_order_items(model, self.order_by)?;
        Ok(SelectQuery {
            model,
            where_clause,
            search: None,
            joins,
            order_by,
            limit: self.limit,
            offset: self.offset,
            lock_mode: self.lock_mode,
            compound: self.compound,
            projection: None,
        })
    }

    /// Projects the query onto the given column names (dictionary-style rows).
    #[must_use]
    pub fn values_dict(self, cols: &[&'static str]) -> ValuesQuerySet<T> {
        ValuesQuerySet {
            qs: self,
            cols: cols.to_vec(),
        }
    }

    /// Projects the query onto the given column names (tuple-style rows).
    #[must_use]
    pub fn values_list(self, cols: &[&'static str]) -> ValuesListQuerySet<T> {
        ValuesListQuerySet {
            qs: self,
            cols: cols.to_vec(),
        }
    }

    /// Projects the query onto a single column.
    #[must_use]
    pub fn values_list_flat(self, col: &'static str) -> ValuesFlatQuerySet<T> {
        ValuesFlatQuerySet { qs: self, col }
    }

    /// Alias of [`QuerySet::values_dict`].
    #[must_use]
    pub fn only(self, cols: &[&'static str]) -> ValuesQuerySet<T> {
        self.values_dict(cols)
    }

    /// Projects every scalar column of `T` except those listed in `cols`.
    ///
    /// Entries of `cols` that are not columns of `T` are appended to the
    /// projection on purpose. Compiling then reports them as `UnknownField`
    /// instead of silently ignoring them.
    #[must_use]
    pub fn defer(self, cols: &[&'static str]) -> ValuesQuerySet<T> {
        let model = T::SCHEMA;
        let exclude: std::collections::HashSet<&'static str> = cols.iter().copied().collect();
        let mut projection: Vec<&'static str> = model
            .scalar_fields()
            .filter(|f| !exclude.contains(f.column))
            .map(|f| f.column)
            .collect();
        for &col in cols {
            if model.field_by_column(col).is_none() {
                projection.push(col);
            }
        }
        self.values_dict(&projection)
    }

    /// Compiles into a `DeleteQuery`. Only the filters are used; ordering,
    /// joins, limit/offset, locks and set-operation branches are ignored.
    ///
    /// # Errors
    /// The first recorded builder error, or an unknown or mistyped filter field.
    pub fn compile_delete(self) -> Result<DeleteQuery, QueryError> {
        let model: &'static ModelSchema = T::SCHEMA;
        let where_clause = resolve_pending(model, self.pending)?;
        Ok(DeleteQuery {
            model,
            where_clause,
        })
    }

    /// Starts an `UPDATE` restricted by this queryset's filters.
    #[must_use]
    pub fn update(self) -> UpdateBuilder<T> {
        UpdateBuilder {
            qs: self,
            set: Vec::new(),
        }
    }

    /// Starts an aggregate query with no grouping, annotations or HAVING clause.
    #[must_use]
    pub fn aggregate(self) -> AggregateBuilder<T> {
        AggregateBuilder {
            qs: self,
            group_by: Vec::new(),
            aggregates: Vec::new(),
            having: None,
            order_by: Vec::new(),
            limit: None,
            offset: None,
            deferred_error: None,
            values: None,
        }
    }

    /// Starts an aggregate query grouped by `columns`, unless an explicit
    /// `group_by` is set later. Compiling requires at least one aggregating
    /// annotation.
    #[must_use]
    pub fn values(self, columns: &[&'static str]) -> AggregateBuilder<T> {
        self.aggregate().values(columns)
    }

    /// Shorthand for `aggregate().annotate(alias, expr)`.
    #[must_use]
    pub fn annotate(self, alias: &'static str, expr: AggregateExpr) -> AggregateBuilder<T> {
        self.aggregate().annotate(alias, expr)
    }
}
/// Builder for an `UPDATE` of `T` rows, created by `QuerySet::update`.
///
/// Only the source queryset's filters are used (as the WHERE clause).
pub struct UpdateBuilder<T: Model> {
    qs: QuerySet<T>,
    // Assignments in call order; resolved against `T::SCHEMA` by `compile`.
    set: Vec<PendingAssignment>,
}
/// One `SET` entry as recorded by `UpdateBuilder`.
enum PendingAssignment {
    /// Field name + literal value; resolved and type-checked at compile time.
    Raw(RawAssignment),
    /// Field name + expression; the expression's columns are validated at compile time.
    RawExpr(RawExprAssignment),
    /// Already resolved by the typed API.
    Resolved(Assignment),
}
impl<T: Model> UpdateBuilder<T> {
    /// Queues `SET field = value`. The field name and value type are checked
    /// against `T::SCHEMA` by [`UpdateBuilder::compile`].
    #[must_use]
    pub fn set(mut self, field: impl Into<String>, value: impl Into<SqlValue>) -> Self {
        let pending = RawAssignment {
            field: field.into(),
            value: value.into(),
        };
        self.set.push(PendingAssignment::Raw(pending));
        self
    }

    /// Queues an assignment that the typed API has already resolved.
    #[must_use]
    pub fn set_typed(mut self, assignment: TypedAssignment<T>) -> Self {
        let resolved = assignment.into_assignment();
        self.set.push(PendingAssignment::Resolved(resolved));
        self
    }

    /// Queues `SET field = <expr>`. The field name and the expression's column
    /// references are checked by [`UpdateBuilder::compile`].
    #[must_use]
    pub fn set_expr(
        mut self,
        field: impl Into<String>,
        expr: impl Into<crate::core::Expr>,
    ) -> Self {
        let pending = RawExprAssignment {
            field: field.into(),
            value: expr.into(),
        };
        self.set.push(PendingAssignment::RawExpr(pending));
        self
    }

    /// Resolves the queued assignments, then the filters, into an `UpdateQuery`.
    ///
    /// # Errors
    /// The first failing assignment (checked in call order). If all
    /// assignments succeed, the first filter error.
    pub fn compile(self) -> Result<UpdateQuery, QueryError> {
        let model: &'static ModelSchema = T::SCHEMA;
        let mut resolved = Vec::with_capacity(self.set.len());
        for pending in self.set {
            let assignment = match pending {
                PendingAssignment::Raw(raw) => resolve_assignment(model, raw)?,
                PendingAssignment::RawExpr(raw) => resolve_assignment_expr(model, raw)?,
                PendingAssignment::Resolved(done) => done,
            };
            resolved.push(assignment);
        }
        let where_clause = resolve_pending(model, self.qs.pending)?;
        Ok(UpdateQuery {
            model,
            set: resolved,
            where_clause,
        })
    }
}
fn lower_order_items(
model: &'static ModelSchema,
items: Vec<PendingOrderItem>,
) -> Result<Vec<crate::core::OrderItem>, QueryError> {
let mut out = Vec::with_capacity(items.len());
for item in items {
match item {
PendingOrderItem::Field { name, desc, nulls } => {
let field = model.field(&name).ok_or_else(|| QueryError::UnknownField {
model: model.name,
field: name.clone(),
})?;
out.push(crate::core::OrderItem::column_with_nulls(
field.column,
desc,
nulls,
));
}
PendingOrderItem::Expr { expr, desc, nulls } => {
out.push(crate::core::OrderItem::expr_with_nulls(expr, desc, nulls));
}
PendingOrderItem::Random => {
out.push(crate::core::OrderItem::random());
}
}
}
Ok(out)
}
/// Lowers `select_related` field names into `LEFT` joins against the related
/// models' tables.
///
/// Each name must be a `Relation::Fk` or `Relation::O2O` field of `model`
/// whose target table is registered as a `ModelEntry` in `inventory`. Each
/// join:
/// - is aliased by the field name;
/// - matches `<model.table>.<fk column>` to `<alias>.<on>`;
/// - projects every scalar column of the target.
///
/// Repeated names are lowered only once. Emitting the same alias twice would
/// join one alias twice in the same query, which databases generally reject.
///
/// # Errors
/// `QueryError::SelectRelatedInvalid` when the field does not exist, is not an
/// FK / one-to-one relation, or its target model is not registered.
fn lower_select_related(
    model: &'static ModelSchema,
    names: &[String],
) -> Result<Vec<crate::core::Join>, QueryError> {
    use crate::core::{inventory, Expr, Join, JoinKind, ModelEntry, Op, Relation, WhereExpr};
    let mut out: Vec<Join> = Vec::with_capacity(names.len());
    // Names already lowered; a second `select_related("x")` is a no-op.
    let mut seen: std::collections::HashSet<&str> =
        std::collections::HashSet::with_capacity(names.len());
    for name in names {
        if !seen.insert(name.as_str()) {
            continue;
        }
        let field = model
            .field(name)
            .ok_or_else(|| QueryError::SelectRelatedInvalid {
                model: model.name,
                field: name.clone(),
                reason: format!("no field `{name}` on this model"),
            })?;
        let (to, on) = match field.relation {
            Some(Relation::Fk { to, on }) | Some(Relation::O2O { to, on }) => (to, on),
            _ => {
                return Err(QueryError::SelectRelatedInvalid {
                    model: model.name,
                    field: name.clone(),
                    reason: "not a `ForeignKey<T>` field".into(),
                });
            }
        };
        let target = inventory::iter::<ModelEntry>
            .into_iter()
            .find(|e| e.schema.table == to)
            .map(|e| e.schema)
            .ok_or_else(|| QueryError::SelectRelatedInvalid {
                model: model.name,
                field: name.clone(),
                reason: format!(
                    "target table `{to}` is not registered (is the parent's `#[derive(Model)]` linked into the binary?)"
                ),
            })?;
        let project: Vec<&'static str> = target.scalar_fields().map(|f| f.column).collect();
        let alias = field.name;
        out.push(Join {
            target,
            alias,
            kind: JoinKind::Left,
            on: WhereExpr::ExprCompare {
                lhs: Expr::AliasedColumn {
                    alias: model.table,
                    column: field.column,
                },
                op: Op::Eq,
                rhs: Expr::AliasedColumn { alias, column: on },
            },
            project,
        });
    }
    Ok(out)
}
/// Splits a Django-style lookup key (`"field"` or `"field__suffix"`) into a
/// field name and operator, and validates or reshapes `value` for the suffix.
///
/// Only the first `__` separates the field from the suffix, so `"a__b__gt"`
/// yields the suffix `"b__gt"` and is reported as an unknown lookup.
///
/// NOTE(review): the LIKE-based suffixes (`iexact`, `contains`, `startswith`,
/// `endswith` and their `i` forms) use `value` verbatim. `%` / `_` inside it
/// therefore act as wildcards. Confirm this is intended.
///
/// # Errors
/// - `QueryError::UnknownLookup` for an unrecognised suffix.
/// - `QueryError::InvalidLookupValue` when `value` has the wrong shape for the suffix.
fn parse_lookup(key: &str, value: SqlValue) -> Result<(String, Op, SqlValue), QueryError> {
    let Some((name, suffix)) = key.split_once("__") else {
        return Ok((key.to_owned(), Op::Eq, value));
    };
    let field = name.to_owned();
    // Builds the "wrong value shape" error for this lookup suffix.
    let bad_value = |field: String, expected: &'static str, actual: &'static str| {
        QueryError::InvalidLookupValue {
            field,
            suffix: suffix.to_owned(),
            expected,
            actual,
        }
    };
    match suffix {
        "exact" => Ok((field, Op::Eq, value)),
        "ne" => Ok((field, Op::Ne, value)),
        "gt" => Ok((field, Op::Gt, value)),
        "gte" => Ok((field, Op::Gte, value)),
        "lt" => Ok((field, Op::Lt, value)),
        "lte" => Ok((field, Op::Lte, value)),
        "iexact" => Ok((field, Op::ILike, value)),
        "contains" | "icontains" | "startswith" | "istartswith" | "endswith" | "iendswith" => {
            // A leading `i` selects the case-insensitive (ILIKE) operator.
            let (op, base) = match suffix.strip_prefix('i') {
                Some(rest) => (Op::ILike, rest),
                None => (Op::Like, suffix),
            };
            let (head, tail) = match base {
                "contains" => ("%", "%"),
                "startswith" => ("", "%"),
                "endswith" => ("%", ""),
                _ => unreachable!(),
            };
            let pattern = wrap_like(&value, head, tail, &field, suffix)?;
            Ok((field, op, pattern))
        }
        "in" => {
            if matches!(value, SqlValue::List(_)) {
                Ok((field, Op::In, value))
            } else {
                let actual = sql_value_shape_name(&value);
                Err(bad_value(field, "SqlValue::List(...)", actual))
            }
        }
        "isnull" => {
            if matches!(value, SqlValue::Bool(_)) {
                Ok((field, Op::IsNull, value))
            } else {
                let actual = sql_value_shape_name(&value);
                Err(bad_value(field, "SqlValue::Bool(true|false)", actual))
            }
        }
        "between" | "range" => {
            // `None` means the value is a two-element list `[lo, hi]`.
            let problem = match &value {
                SqlValue::List(items) if items.len() == 2 => None,
                SqlValue::List(_) => Some((
                    "SqlValue::List with exactly 2 elements [lo, hi]",
                    "SqlValue::List with wrong arity",
                )),
                other => Some(("SqlValue::List([lo, hi])", sql_value_shape_name(other))),
            };
            match problem {
                None => Ok((field, Op::Between, value)),
                Some((expected, actual)) => Err(bad_value(field, expected, actual)),
            }
        }
        "regex" | "iregex" | "trigram_similar" | "trigram_word_similar" | "search" => {
            let (op, expected) = match suffix {
                "regex" => (Op::Regex, "SqlValue::String(<regex pattern>)"),
                "iregex" => (Op::IRegex, "SqlValue::String(<regex pattern>)"),
                "trigram_similar" => (Op::TrigramSimilar, "SqlValue::String(<trigram pattern>)"),
                "trigram_word_similar" => {
                    (Op::TrigramWordSimilar, "SqlValue::String(<trigram pattern>)")
                }
                "search" => (Op::Search, "SqlValue::String(<search query>)"),
                _ => unreachable!(),
            };
            if matches!(value, SqlValue::String(_)) {
                Ok((field, op, value))
            } else {
                let actual = sql_value_shape_name(&value);
                Err(bad_value(field, expected, actual))
            }
        }
        "range_contains"
        | "range_contained_by"
        | "range_overlap"
        | "range_strictly_left"
        | "range_strictly_right"
        | "range_adjacent" => {
            // Both plain strings and range literals are normalised to `RangeLiteral`.
            let literal = match value {
                SqlValue::String(s) | SqlValue::RangeLiteral(s) => s,
                other => {
                    let actual = sql_value_shape_name(&other);
                    return Err(bad_value(
                        field,
                        "SqlValue::String(<PG range literal>) — e.g. \"[1, 10)\"",
                        actual,
                    ));
                }
            };
            let op = match suffix {
                "range_contains" => Op::RangeContains,
                "range_contained_by" => Op::RangeContainedBy,
                "range_overlap" => Op::RangeOverlap,
                "range_strictly_left" => Op::RangeStrictlyLeft,
                "range_strictly_right" => Op::RangeStrictlyRight,
                "range_adjacent" => Op::RangeAdjacent,
                _ => unreachable!(),
            };
            Ok((field, op, SqlValue::RangeLiteral(literal)))
        }
        "array_contains" | "array_contained_by" | "array_overlap" => {
            // The list's elements are re-wrapped as an `SqlValue::Array`.
            let elems = match value {
                SqlValue::List(elems) => elems,
                other => {
                    let actual = sql_value_shape_name(&other);
                    return Err(bad_value(
                        field,
                        "SqlValue::List(<elements>) — typed homogeneous array",
                        actual,
                    ));
                }
            };
            let op = match suffix {
                "array_contains" => Op::ArrayContains,
                "array_contained_by" => Op::ArrayContainedBy,
                "array_overlap" => Op::ArrayOverlap,
                _ => unreachable!(),
            };
            Ok((field, op, SqlValue::Array(elems)))
        }
        unknown => Err(QueryError::UnknownLookup {
            field,
            suffix: unknown.to_owned(),
        }),
    }
}
/// Builds the LIKE pattern `prefix + value + suffix_char` from a string value.
///
/// `field` and `suffix` are used only to build the error. The value is
/// inserted verbatim; `%` / `_` inside it are not escaped.
///
/// # Errors
/// `QueryError::InvalidLookupValue` when `value` is not `SqlValue::String`.
fn wrap_like(
    value: &SqlValue,
    prefix: &str,
    suffix_char: &str,
    field: &str,
    suffix: &str,
) -> Result<SqlValue, QueryError> {
    if let SqlValue::String(text) = value {
        let mut pattern = String::with_capacity(prefix.len() + text.len() + suffix_char.len());
        pattern.push_str(prefix);
        pattern.push_str(text);
        pattern.push_str(suffix_char);
        Ok(SqlValue::String(pattern))
    } else {
        Err(QueryError::InvalidLookupValue {
            field: field.to_owned(),
            suffix: suffix.to_owned(),
            expected: "SqlValue::String(...)",
            actual: sql_value_shape_name(value),
        })
    }
}
/// Returns the variant name of `v` for use as the `actual` part of
/// `QueryError::InvalidLookupValue`.
///
/// Variants without a dedicated arm report `"SqlValue::<other>"`.
fn sql_value_shape_name(v: &SqlValue) -> &'static str {
    match v {
        SqlValue::Null => "SqlValue::Null",
        SqlValue::I16(_) => "SqlValue::I16",
        SqlValue::I32(_) => "SqlValue::I32",
        SqlValue::I64(_) => "SqlValue::I64",
        SqlValue::F32(_) => "SqlValue::F32",
        SqlValue::F64(_) => "SqlValue::F64",
        SqlValue::Bool(_) => "SqlValue::Bool",
        SqlValue::String(_) => "SqlValue::String",
        SqlValue::List(_) => "SqlValue::List",
        // The lookup parser itself produces these two. Naming them keeps the
        // error precise when one is passed to a lookup that rejects it.
        SqlValue::Array(_) => "SqlValue::Array",
        SqlValue::RangeLiteral(_) => "SqlValue::RangeLiteral",
        _ => "SqlValue::<other>",
    }
}
fn resolve_pending(
model: &'static ModelSchema,
pending: Vec<PendingFilter>,
) -> Result<WhereExpr, QueryError> {
let mut nodes: Vec<WhereExpr> = Vec::with_capacity(pending.len());
for entry in pending {
match entry {
PendingFilter::Raw(raw) => {
nodes.push(WhereExpr::Predicate(resolve_filter(model, raw)?));
}
PendingFilter::Resolved(filter) => {
nodes.push(WhereExpr::Predicate(filter));
}
PendingFilter::Expr(expr) => {
nodes.push(expr);
}
PendingFilter::Error(e) => {
return Err(e);
}
}
}
Ok(WhereExpr::And(nodes))
}
fn resolve_filter(model: &'static ModelSchema, raw: RawFilter) -> Result<Filter, QueryError> {
let field = model
.field(&raw.field)
.ok_or_else(|| QueryError::UnknownField {
model: model.name,
field: raw.field.clone(),
})?;
let skip_type_check = matches!(raw.op, Op::IsNull | Op::In);
if !skip_type_check {
if let Some(value_ty) = raw.value.field_type() {
if value_ty != field.ty {
return Err(QueryError::TypeMismatch {
model: model.name,
field: raw.field,
expected: field.ty,
actual: value_ty,
});
}
}
}
Ok(Filter {
column: field.column,
op: raw.op,
value: raw.value,
})
}
fn resolve_assignment(
model: &'static ModelSchema,
raw: RawAssignment,
) -> Result<Assignment, QueryError> {
let field = model
.field(&raw.field)
.ok_or_else(|| QueryError::UnknownField {
model: model.name,
field: raw.field.clone(),
})?;
if let Some(value_ty) = raw.value.field_type() {
if value_ty != field.ty {
return Err(QueryError::TypeMismatch {
model: model.name,
field: raw.field,
expected: field.ty,
actual: value_ty,
});
}
}
Ok(Assignment {
column: field.column,
value: raw.value.into(),
})
}
fn resolve_assignment_expr(
model: &'static ModelSchema,
raw: RawExprAssignment,
) -> Result<Assignment, QueryError> {
let field = model
.field(&raw.field)
.ok_or_else(|| QueryError::UnknownField {
model: model.name,
field: raw.field.clone(),
})?;
validate_expr_columns_in_model(model, &raw.value)?;
Ok(Assignment {
column: field.column,
value: raw.value,
})
}
/// Recursively checks that every bare column referenced by `expr` is a column
/// of `model`.
///
/// These expressions are accepted without inspection:
/// - subqueries, outer references and aliased columns, which refer to other
///   scopes;
/// - aggregate expressions (NOTE(review): their inner columns are not checked here).
///
/// `CASE` conditions are checked with their own `validate`.
///
/// # Errors
/// `QueryError::UnknownField` for the first unknown column found.
fn validate_expr_columns_in_model(
    model: &'static ModelSchema,
    expr: &crate::core::Expr,
) -> Result<(), QueryError> {
    use crate::core::Expr;

    /// Fails with `UnknownField` unless `column` is a column of `model`.
    fn require_column(model: &'static ModelSchema, column: &str) -> Result<(), QueryError> {
        match model.field_by_column(column) {
            Some(_) => Ok(()),
            None => Err(QueryError::UnknownField {
                model: model.name,
                field: column.to_owned(),
            }),
        }
    }

    match expr {
        Expr::Literal(_)
        | Expr::Subquery(_)
        | Expr::OuterRef(_)
        | Expr::AliasedColumn { .. }
        | Expr::Aggregate(_) => Ok(()),
        Expr::Column(name) => require_column(model, name),
        Expr::BinOp { left, right, .. } => {
            validate_expr_columns_in_model(model, left)?;
            validate_expr_columns_in_model(model, right)
        }
        Expr::Function { args, .. } => args
            .iter()
            .try_for_each(|arg| validate_expr_columns_in_model(model, arg)),
        Expr::Case { branches, default } => {
            for branch in branches {
                branch.condition.validate(model)?;
                validate_expr_columns_in_model(model, &branch.then)?;
            }
            if let Some(fallback) = default {
                validate_expr_columns_in_model(model, fallback)
            } else {
                Ok(())
            }
        }
        Expr::Window(w) => {
            for col in &w.partition_by {
                require_column(model, col)?;
            }
            for o in &w.order_by {
                require_column(model, o.column)?;
            }
            w.args
                .iter()
                .try_for_each(|arg| validate_expr_columns_in_model(model, arg))
        }
    }
}
/// Builder for a grouped / aggregating query over `T`, created by
/// `QuerySet::aggregate`, `QuerySet::values` or `QuerySet::annotate`.
pub struct AggregateBuilder<T: Model> {
    // Source queryset; only its pending filters feed the WHERE clause.
    qs: QuerySet<T>,
    // Explicit GROUP BY columns; when non-empty they take precedence over `values`.
    group_by: Vec<&'static str>,
    // (alias, expression) pairs in annotation order.
    aggregates: Vec<(&'static str, AggregateExpr)>,
    // HAVING predicate; later conditions are ANDed on with `push_and`.
    having: Option<WhereExpr>,
    // (column, descending) pairs.
    order_by: Vec<(&'static str, bool)>,
    limit: Option<i64>,
    offset: Option<i64>,
    // First error recorded by `filter`; `compile` returns it before anything else.
    deferred_error: Option<crate::core::QueryError>,
    // Columns from `values(...)`; used as GROUP BY when `group_by` is empty.
    values: Option<Vec<&'static str>>,
}
impl<T: Model> AggregateBuilder<T> {
#[must_use]
pub fn group_by(mut self, column: &'static str) -> Self {
self.group_by.push(column);
self
}
#[must_use]
pub fn values(mut self, columns: &[&'static str]) -> Self {
self.values = Some(columns.to_vec());
self
}
#[must_use]
pub fn annotate(mut self, alias: &'static str, expr: AggregateExpr) -> Self {
self.aggregates.push((alias, expr));
self
}
#[must_use]
pub fn having<E: Into<crate::core::TypedExpr<T>>>(mut self, predicate: E) -> Self {
let expr = predicate.into().into_expr();
match self.having {
None => self.having = Some(expr),
Some(ref mut existing) => existing.push_and(expr),
}
self
}
#[must_use]
pub fn filter(
mut self,
field: &'static str,
op: crate::core::Op,
value: impl Into<crate::core::SqlValue>,
) -> Self {
if self.deferred_error.is_some() {
return self;
}
let agg = self
.aggregates
.iter()
.find(|(alias, _)| *alias == field)
.map(|(_, expr)| expr.clone());
if let Some(agg) = agg {
if matches!(
op,
crate::core::Op::JsonContains
| crate::core::Op::JsonContainedBy
| crate::core::Op::JsonHasKey
| crate::core::Op::JsonHasAnyKey
| crate::core::Op::JsonHasAllKeys
| crate::core::Op::IsDistinctFrom
| crate::core::Op::IsNotDistinctFrom
) {
self.deferred_error = Some(crate::core::QueryError::HavingOpNotSupported {
alias: field.to_owned(),
op,
});
return self;
}
let pred = WhereExpr::ExprCompare {
lhs: crate::core::Expr::Aggregate(Box::new(agg)),
op,
rhs: crate::core::Expr::Literal(value.into()),
};
match self.having {
None => self.having = Some(pred),
Some(ref mut existing) => existing.push_and(pred),
}
} else {
self.qs = self.qs.filter_op(field, op, value);
}
self
}
#[must_use]
pub fn order_by(mut self, items: &[(&'static str, bool)]) -> Self {
self.order_by.extend_from_slice(items);
self
}
#[must_use]
pub fn limit(mut self, n: i64) -> Self {
self.limit = Some(n);
self
}
#[must_use]
pub fn offset(mut self, n: i64) -> Self {
self.offset = Some(n);
self
}
pub fn compile(self) -> Result<AggregateQuery, QueryError> {
if let Some(e) = self.deferred_error {
return Err(e);
}
let model = T::SCHEMA;
let where_clause = resolve_pending(model, self.qs.pending)?;
for (_alias, expr) in &self.aggregates {
validate_aggregate_expr_columns(model, expr)?;
}
let order_by = self
.order_by
.into_iter()
.map(|(col, desc)| crate::core::OrderItem::column(col, desc))
.collect();
let has_aggregating = self.aggregates.iter().any(|(_, e)| e.is_aggregating());
let group_by = if !self.group_by.is_empty() {
for col in &self.group_by {
if model.field_by_column(col).is_none() {
return Err(QueryError::UnknownField {
model: model.name,
field: (*col).to_owned(),
});
}
}
self.group_by
} else if let Some(cols) = self.values.as_ref() {
if !has_aggregating {
return Err(QueryError::ValuesRequiresAggregate { cols: cols.clone() });
}
for col in cols {
if model.field_by_column(col).is_none() {
return Err(QueryError::UnknownField {
model: model.name,
field: (*col).to_owned(),
});
}
}
cols.clone()
} else if has_aggregating {
model.scalar_fields().map(|f| f.column).collect()
} else {
Vec::new()
};
Ok(AggregateQuery {
model,
where_clause,
group_by,
aggregates: self.aggregates,
having: self.having,
order_by,
limit: self.limit,
offset: self.offset,
})
}
}
/// Checks the column references that an aggregate expression carries
/// explicitly:
/// - `Filtered`: the filter, then the inner aggregate;
/// - `Coalesced`: the inner aggregate;
/// - `Window`: the partition/order columns and the argument expressions.
///
/// Every other variant is accepted as-is.
///
/// # Errors
/// The first validation failure, typically `QueryError::UnknownField`.
fn validate_aggregate_expr_columns(
    model: &'static ModelSchema,
    expr: &crate::core::AggregateExpr,
) -> Result<(), QueryError> {
    use crate::core::AggregateExpr;

    /// Fails with `UnknownField` unless `column` is a column of `model`.
    fn require_column(model: &'static ModelSchema, column: &str) -> Result<(), QueryError> {
        if model.field_by_column(column).is_some() {
            Ok(())
        } else {
            Err(QueryError::UnknownField {
                model: model.name,
                field: column.to_owned(),
            })
        }
    }

    match expr {
        AggregateExpr::Filtered { inner, filter } => {
            filter.validate(model)?;
            validate_aggregate_expr_columns(model, inner)
        }
        AggregateExpr::Coalesced { inner, .. } => validate_aggregate_expr_columns(model, inner),
        AggregateExpr::Window(w) => {
            w.partition_by
                .iter()
                .try_for_each(|col| require_column(model, col))?;
            w.order_by
                .iter()
                .try_for_each(|o| require_column(model, o.column))?;
            w.args
                .iter()
                .try_for_each(|arg| validate_expr_columns_in_model(model, arg))
        }
        _ => Ok(()),
    }
}
/// A `QuerySet` projected onto named columns (dictionary-style rows).
pub struct ValuesQuerySet<T: Model> {
    pub(crate) qs: QuerySet<T>,
    // Projected column names; validated against `T::SCHEMA` on compile.
    pub(crate) cols: Vec<&'static str>,
}
/// A `QuerySet` projected onto named columns (tuple-style rows).
pub struct ValuesListQuerySet<T: Model> {
    pub(crate) qs: QuerySet<T>,
    // Projected column names; validated against `T::SCHEMA` on compile.
    pub(crate) cols: Vec<&'static str>,
}
/// A `QuerySet` projected onto a single column.
pub struct ValuesFlatQuerySet<T: Model> {
    pub(crate) qs: QuerySet<T>,
    // The single projected column; validated against `T::SCHEMA` on compile.
    pub(crate) col: &'static str,
}
impl<T: Model> ValuesQuerySet<T> {
    /// Compiles the wrapped queryset with its projection set to `columns()`.
    ///
    /// # Errors
    /// An empty projection, an unknown column, or any error from `QuerySet::compile`.
    pub fn compile(self) -> Result<SelectQuery, QueryError> {
        let Self { qs, cols } = self;
        compile_values_select(qs, cols)
    }

    /// The projected column names, in order.
    #[must_use]
    pub fn columns(&self) -> &[&'static str] {
        self.cols.as_slice()
    }
}
impl<T: Model> ValuesListQuerySet<T> {
    /// Compiles the wrapped queryset with its projection set to `columns()`.
    ///
    /// # Errors
    /// An empty projection, an unknown column, or any error from `QuerySet::compile`.
    pub fn compile(self) -> Result<SelectQuery, QueryError> {
        let Self { qs, cols } = self;
        compile_values_select(qs, cols)
    }

    /// The projected column names, in order.
    #[must_use]
    pub fn columns(&self) -> &[&'static str] {
        self.cols.as_slice()
    }
}
impl<T: Model> ValuesFlatQuerySet<T> {
    /// Compiles the wrapped queryset projected onto its single column.
    ///
    /// # Errors
    /// An unknown column, or any error from `QuerySet::compile`.
    pub fn compile(self) -> Result<SelectQuery, QueryError> {
        let Self { qs, col } = self;
        compile_values_select(qs, vec![col])
    }
}
/// Shared compile path for the values querysets. It validates `cols` against
/// `T::SCHEMA`, compiles `qs`, and sets the projection to `cols`.
///
/// # Errors
/// Errors are checked in this order:
/// 1. `EmptyValuesProjection` when `cols` is empty;
/// 2. `UnknownField` for the first column `T` lacks;
/// 3. any error from `QuerySet::compile`.
fn compile_values_select<T: Model>(
    qs: QuerySet<T>,
    cols: Vec<&'static str>,
) -> Result<SelectQuery, QueryError> {
    let model: &'static ModelSchema = T::SCHEMA;
    if cols.is_empty() {
        return Err(QueryError::EmptyValuesProjection);
    }
    if let Some(missing) = cols
        .iter()
        .copied()
        .find(|col| model.field_by_column(col).is_none())
    {
        return Err(QueryError::UnknownField {
            model: model.name,
            field: missing.to_owned(),
        });
    }
    let mut query = qs.compile()?;
    query.projection = Some(cols);
    Ok(query)
}