use std::marker::PhantomData;
use crate::core::{
AggregateExpr, AggregateQuery, Assignment, DeleteQuery, Expr, Filter, Model, ModelSchema, Op,
QueryError, ScalarFn, SelectQuery, SqlValue, TypedAssignment, TypedExpr, UpdateQuery,
WhereExpr,
};
mod q;
pub use q::Q;
/// Lazily-built query over model `T`.
///
/// Builder methods only record intent (filters, ordering, joins, locking,
/// set-algebra branches). Names are resolved against `T::SCHEMA` when the
/// queryset is consumed by `compile()`, `compile_delete()` or
/// `update().compile()`.
pub struct QuerySet<T: Model> {
/// Filters in insertion order; enabled global scopes are prepended at compile time.
pending: Vec<PendingFilter>,
/// Row limit. Moved into `head_limit` when the first set-algebra branch is
/// added, so a limit set afterwards applies to the combined result.
limit: Option<i64>,
/// Row offset; moved into `head_offset` on the first set-algebra branch, like `limit`.
offset: Option<i64>,
/// `distinct()` / `distinct_on()` setting.
distinct: Option<crate::core::DistinctMode>,
/// `__`-separated relation chains requested via `select_related()`.
select_related: Vec<String>,
/// Joins added verbatim via `join()`.
ad_hoc_joins: Vec<crate::core::Join>,
/// Joins against subqueries (plain or lateral).
subquery_joins: Vec<crate::core::SubqueryJoin>,
/// Ordering; after the first set-algebra branch this orders the combined result.
order_by: Vec<PendingOrderItem>,
/// Row-lock request and its modifiers.
lock_mode: Option<crate::core::LockMode>,
/// Compiled set-algebra branches (UNION / INTERSECT / EXCEPT style).
compound: Vec<crate::core::CompoundBranch>,
/// Ordering of the head SELECT, captured when the first branch was added.
head_order_by: Vec<PendingOrderItem>,
/// Limit of the head SELECT, captured when the first branch was added.
head_limit: Option<i64>,
/// Offset of the head SELECT, captured when the first branch was added.
head_offset: Option<i64>,
/// Set by `none()`: the compiled query must match no rows.
is_none: bool,
/// Global scopes switched off by name via `without_global_scope()`.
disabled_global_scopes: Vec<&'static str>,
/// Set by `without_global_scopes()`: skip every global scope.
disable_all_global_scopes: bool,
/// Ties the queryset to `T` without storing a `T`.
_model: PhantomData<fn() -> T>,
}
/// One ordering term as recorded by the builder, before lowering to
/// `crate::core::OrderItem` in `lower_order_items`.
#[derive(Debug, Clone)]
enum PendingOrderItem {
/// Order by a model field looked up by name at compile time.
Field {
name: String,
desc: bool,
nulls: crate::core::NullsOrder,
},
/// Order by an arbitrary expression (passed through unchanged).
Expr {
expr: crate::core::Expr,
desc: bool,
nulls: crate::core::NullsOrder,
},
/// Random ordering; lowered to `OrderItem::random()`.
Random,
}
/// A filter as recorded by the builder, resolved at compile time.
#[derive(Clone)]
enum PendingFilter {
/// `field op value` with an unresolved field name (`filter_op`, plain lookups).
Raw(RawFilter),
/// Comparison against a date/time transform of a field (e.g. `created__year`).
DateTransform(DateTransformFilter),
/// Already-resolved typed predicate from `where_()`.
Resolved(Filter),
/// Arbitrary predicate (`where_raw`, global scopes, relation helpers).
Expr(WhereExpr),
/// Logical negation of the wrapped filter (`exclude`).
Negated(Box<PendingFilter>),
/// Lookup key that traverses a relation; lowered to joins by
/// `lower_relation_spans` inside `compile()`.
RelationSpan { raw_key: String, value: SqlValue },
/// Error recorded by a builder method, reported when the queryset is compiled.
Error(QueryError),
}
/// `field op value` where `field` is a not-yet-resolved field name.
#[derive(Debug, Clone)]
struct RawFilter {
field: String,
op: Op,
value: SqlValue,
}
/// `transform(field) op value`, e.g. `ExtractYear(created) = 2024`.
#[derive(Debug, Clone)]
struct DateTransformFilter {
field: String,
transform: ScalarFn,
op: Op,
value: SqlValue,
}
/// `SET field = value` with an unresolved field name (`UpdateBuilder::set`).
#[derive(Debug, Clone)]
struct RawAssignment {
field: String,
value: SqlValue,
}
/// `SET field = <expr>` with an unresolved field name (`UpdateBuilder::set_expr`).
#[derive(Debug, Clone)]
struct RawExprAssignment {
field: String,
value: crate::core::Expr,
}
impl<T: Model> Default for QuerySet<T> {
fn default() -> Self {
Self::new()
}
}
impl<T: Model> Clone for QuerySet<T> {
    /// Field-by-field copy of the builder state.
    ///
    /// Written by hand because `#[derive(Clone)]` would add a `T: Clone`
    /// bound, although no `T` is ever stored. The exhaustive destructure makes
    /// adding a field without cloning it a compile error.
    fn clone(&self) -> Self {
        let Self {
            pending,
            limit,
            offset,
            distinct,
            select_related,
            ad_hoc_joins,
            subquery_joins,
            order_by,
            lock_mode,
            compound,
            head_order_by,
            head_limit,
            head_offset,
            is_none,
            disabled_global_scopes,
            disable_all_global_scopes,
            _model: _,
        } = self;
        Self {
            pending: pending.clone(),
            limit: *limit,
            offset: *offset,
            distinct: distinct.clone(),
            select_related: select_related.clone(),
            ad_hoc_joins: ad_hoc_joins.clone(),
            subquery_joins: subquery_joins.clone(),
            order_by: order_by.clone(),
            lock_mode: lock_mode.clone(),
            compound: compound.clone(),
            head_order_by: head_order_by.clone(),
            head_limit: *head_limit,
            head_offset: *head_offset,
            is_none: *is_none,
            disabled_global_scopes: disabled_global_scopes.clone(),
            disable_all_global_scopes: *disable_all_global_scopes,
            _model: std::marker::PhantomData,
        }
    }
}
impl<T: Model> QuerySet<T> {
/// An empty queryset: no filters, ordering, joins, limits, locks or
/// set-algebra branches, with every global scope enabled.
#[must_use]
pub fn new() -> Self {
    Self {
        _model: PhantomData,
        pending: vec![],
        select_related: vec![],
        ad_hoc_joins: vec![],
        subquery_joins: vec![],
        order_by: vec![],
        compound: vec![],
        head_order_by: vec![],
        disabled_global_scopes: vec![],
        limit: None,
        offset: None,
        head_limit: None,
        head_offset: None,
        distinct: None,
        lock_mode: None,
        is_none: false,
        disable_all_global_scopes: false,
    }
}
#[must_use]
pub fn without_global_scope(mut self, name: &'static str) -> Self {
if !self.disabled_global_scopes.contains(&name) {
self.disabled_global_scopes.push(name);
}
self
}
#[must_use]
pub fn without_global_scopes(mut self) -> Self {
self.disable_all_global_scopes = true;
self
}
/// Prepend one `PendingFilter::Expr` per enabled global scope of `T::SCHEMA`.
///
/// Scopes keep their schema order and go ahead of the user's filters. This
/// does nothing when `without_global_scopes()` was called or when no scope
/// remains enabled.
fn apply_global_scopes(&mut self) {
    if self.disable_all_global_scopes {
        return;
    }
    let disabled = &self.disabled_global_scopes;
    let mut scoped: Vec<PendingFilter> = T::SCHEMA
        .global_scopes
        .iter()
        .filter(|scope| !disabled.contains(&scope.name))
        .map(|scope| PendingFilter::Expr((scope.apply)()))
        .collect();
    if scoped.is_empty() {
        return;
    }
    scoped.append(&mut self.pending);
    self.pending = scoped;
}
#[must_use]
pub fn none(mut self) -> Self {
self.is_none = true;
self
}
#[must_use]
pub fn distinct(mut self) -> Self {
self.distinct = Some(crate::core::DistinctMode::All);
self
}
#[must_use]
pub fn distinct_on(mut self, fields: &[&'static str]) -> Self {
self.distinct = Some(crate::core::DistinctMode::On(fields.to_vec()));
self
}
/// Request a row lock on the selected rows.
///
/// Lock modifiers already set by [`Self::skip_locked`], [`Self::nowait`],
/// [`Self::no_key`], [`Self::of`] or [`Self::silent_on_sqlite`] are kept.
/// Previously this reset the lock to its default, so
/// `.skip_locked().select_for_update()` silently lost `skip_locked`.
#[must_use]
pub fn select_for_update(self) -> Self {
    self.with_lock_mode(|_| {})
}
/// Alias of [`Self::select_for_update`].
#[must_use]
pub fn lock_for_update(self) -> Self {
    self.select_for_update()
}
/// Set `skip_locked` on the lock, enabling the default lock if none was requested yet.
#[must_use]
pub fn skip_locked(self) -> Self {
    self.with_lock_mode(|lock| lock.skip_locked = true)
}
/// Set `nowait` on the lock, enabling the default lock if none was requested yet.
#[must_use]
pub fn nowait(self) -> Self {
    self.with_lock_mode(|lock| lock.nowait = true)
}
/// Set `no_key` on the lock, enabling the default lock if none was requested yet.
#[must_use]
pub fn no_key(self) -> Self {
    self.with_lock_mode(|lock| lock.no_key = true)
}
/// Append `tables` to the lock's `of` list, enabling the default lock if
/// none was requested yet.
#[must_use]
pub fn of(self, tables: &[&'static str]) -> Self {
    self.with_lock_mode(|lock| lock.of.extend_from_slice(tables))
}
/// Set the lock's `silent_on_sqlite` flag, enabling the default lock if none
/// was requested yet. (Presumably this makes the lock a no-op instead of an
/// error on SQLite; confirm against the SQL writer.)
#[must_use]
pub fn silent_on_sqlite(self) -> Self {
    self.with_lock_mode(|lock| lock.silent_on_sqlite = true)
}
/// Ensure a lock is requested (default settings if absent), then let `edit`
/// adjust it in place.
fn with_lock_mode(mut self, edit: impl FnOnce(&mut crate::core::LockMode)) -> Self {
    edit(
        self.lock_mode
            .get_or_insert_with(crate::core::LockMode::default),
    );
    self
}
#[must_use]
pub fn union(self, other: QuerySet<T>) -> Self {
self.add_compound(crate::core::SetOp::Union, other)
}
#[must_use]
pub fn union_all(self, other: QuerySet<T>) -> Self {
self.add_compound(crate::core::SetOp::UnionAll, other)
}
#[must_use]
pub fn intersection(self, other: QuerySet<T>) -> Self {
self.add_compound(crate::core::SetOp::Intersection, other)
}
#[must_use]
pub fn difference(self, other: QuerySet<T>) -> Self {
self.add_compound(crate::core::SetOp::Difference, other)
}
#[must_use]
pub fn with_compound(self, op: crate::core::SetOp, branch: crate::core::SelectQuery) -> Self {
self.add_compound_compiled(op, branch)
}
/// Compile `other` and append it as a set-operation branch under `op`.
///
/// If `other` fails to compile, no branch is appended. The error is recorded
/// through [`Self::with_pending_error`], the same deferred-error channel
/// `where_key` and `where_has` use, so it is reported when this queryset is
/// compiled. Previously the builder panicked here.
fn add_compound(self, op: crate::core::SetOp, other: QuerySet<T>) -> Self {
    match other.compile() {
        Ok(branch) => self.add_compound_compiled(op, branch),
        Err(e) => self.with_pending_error(e),
    }
}
/// Append a compiled branch.
///
/// On the first branch, the current ordering, limit and offset are moved into
/// the `head_*` fields, so they keep applying to the head SELECT. Anything set
/// afterwards applies to the combined result.
fn add_compound_compiled(
    mut self,
    op: crate::core::SetOp,
    branch: crate::core::SelectQuery,
) -> Self {
    let first_branch = self.compound.is_empty();
    if first_branch {
        self.head_order_by = std::mem::take(&mut self.order_by);
        self.head_limit = self.limit.take();
        self.head_offset = self.offset.take();
    }
    let query = Box::new(branch);
    self.compound
        .push(crate::core::CompoundBranch { op, query });
    self
}
/// Append `(field, descending)` terms to the ordering. Field names are
/// resolved against `T::SCHEMA` at compile time.
#[must_use]
pub fn order_by(mut self, items: &[(&str, bool)]) -> Self {
    let appended = items.iter().map(|&(field, desc)| PendingOrderItem::Field {
        name: field.to_owned(),
        desc,
        nulls: crate::core::NullsOrder::Default,
    });
    self.order_by.extend(appended);
    self
}
/// Replace the whole ordering with `items`.
#[must_use]
pub fn reorder(mut self, items: &[(&str, bool)]) -> Self {
    self.order_by.clear();
    self.order_by(items)
}
/// Append a descending term for `column`.
#[must_use]
pub fn order_by_desc(self, column: &str) -> Self {
    let term = [(column, true)];
    self.order_by(&term)
}
/// Append an ascending term for `column`.
#[must_use]
pub fn order_by_asc(self, column: &str) -> Self {
    let term = [(column, false)];
    self.order_by(&term)
}
/// Prepend `T::SCHEMA.default_order` to the current ordering.
///
/// Nothing changes when the model has no default order, or when the ordering
/// already begins with exactly those fields and directions. That makes a
/// repeated call a no-op. Ordering added earlier becomes a tiebreaker after
/// the defaults.
#[must_use]
pub fn with_default_order(mut self) -> Self {
    let defaults: Vec<PendingOrderItem> = T::SCHEMA
        .default_order
        .iter()
        .map(|&(name, desc)| PendingOrderItem::Field {
            name: name.to_owned(),
            desc,
            nulls: crate::core::NullsOrder::Default,
        })
        .collect();
    if defaults.is_empty() {
        return self;
    }
    let is_prefix = self.order_by.len() >= defaults.len()
        && self.order_by.iter().zip(&defaults).all(|pair| match pair {
            (
                PendingOrderItem::Field {
                    name: have_name,
                    desc: have_desc,
                    ..
                },
                PendingOrderItem::Field {
                    name: want_name,
                    desc: want_desc,
                    ..
                },
            ) => have_name == want_name && have_desc == want_desc,
            _ => false,
        });
    if !is_prefix {
        let existing = std::mem::take(&mut self.order_by);
        self.order_by = defaults;
        self.order_by.extend(existing);
    }
    self
}
/// Drop every ordering term.
#[must_use]
pub fn unordered(mut self) -> Self {
    self.order_by = Vec::new();
    self
}
/// Append `(field, descending, nulls placement)` terms to the ordering.
#[must_use]
pub fn order_by_with_nulls(
    mut self,
    items: &[(&'static str, bool, crate::core::NullsOrder)],
) -> Self {
    let appended = items
        .iter()
        .map(|&(field, desc, nulls)| PendingOrderItem::Field {
            name: field.to_owned(),
            desc,
            nulls,
        });
    self.order_by.extend(appended);
    self
}
/// Append an expression ordering term with default null placement.
#[must_use]
pub fn order_by_expr(mut self, expr: impl Into<crate::core::Expr>, desc: bool) -> Self {
    let term = PendingOrderItem::Expr {
        expr: expr.into(),
        desc,
        nulls: crate::core::NullsOrder::Default,
    };
    self.order_by.push(term);
    self
}
#[must_use]
pub fn order_by_distance(
self,
column: &'static str,
query: Vec<f32>,
metric: crate::core::VectorMetric,
) -> Self {
let expr = crate::core::Expr::BinOp {
left: Box::new(crate::core::Expr::Column(column)),
op: metric.to_binop(),
right: Box::new(crate::core::Expr::Literal(SqlValue::Vector(query))),
};
self.order_by_expr(expr, false)
}
#[must_use]
pub fn k_nearest(
self,
column: &'static str,
query: Vec<f32>,
k: i64,
metric: crate::core::VectorMetric,
) -> Self {
self.order_by_distance(column, query, metric).limit(k)
}
#[must_use]
pub fn order_by_distance_to(self, column: &'static str, point: crate::sql::Point) -> Self {
let expr = crate::core::funcs::st_distance(crate::core::Expr::Column(column), point);
self.order_by_expr(expr, false)
}
#[must_use]
pub fn filter_dwithin(
self,
column: &'static str,
point: crate::sql::Point,
distance: f64,
) -> Self {
let pred = WhereExpr::ExprCompare {
lhs: crate::core::funcs::st_dwithin(
crate::core::Expr::Column(column),
point,
crate::core::Expr::Literal(SqlValue::F64(distance)),
),
op: Op::Eq,
rhs: crate::core::Expr::Literal(SqlValue::Bool(true)),
};
self.where_raw(pred)
}
#[must_use]
pub fn order_by_expr_with_nulls(
mut self,
expr: impl Into<crate::core::Expr>,
desc: bool,
nulls: crate::core::NullsOrder,
) -> Self {
self.order_by.push(PendingOrderItem::Expr {
expr: expr.into(),
desc,
nulls,
});
self
}
#[must_use]
pub fn order_random(mut self) -> Self {
self.order_by.push(PendingOrderItem::Random);
self
}
#[must_use]
pub fn in_random_order(self) -> Self {
self.order_random()
}
#[must_use]
pub fn replace_order_by(mut self, items: &[(&str, bool)]) -> Self {
self.order_by.clear();
self.order_by(items)
}
#[must_use]
pub fn reverse(mut self) -> Self {
for item in &mut self.order_by {
match item {
PendingOrderItem::Field { desc, .. } | PendingOrderItem::Expr { desc, .. } => {
*desc = !*desc
}
PendingOrderItem::Random => {}
}
}
self
}
#[must_use]
pub fn flip_order_by(mut self) -> Self {
for entry in &mut self.order_by {
match entry {
PendingOrderItem::Field { desc, nulls, .. }
| PendingOrderItem::Expr { desc, nulls, .. } => {
*desc = !*desc;
*nulls = match *nulls {
crate::core::NullsOrder::First => crate::core::NullsOrder::Last,
crate::core::NullsOrder::Last => crate::core::NullsOrder::First,
crate::core::NullsOrder::Default => crate::core::NullsOrder::Default,
};
}
PendingOrderItem::Random => {}
}
}
self
}
/// Whether any ordering term is currently recorded (head ordering captured
/// by a set-algebra branch is not counted).
#[must_use]
pub fn has_order_by(&self) -> bool {
    !self.order_by.is_empty()
}
/// Eager-load a `__`-separated FK/one-to-one chain via LEFT JOINs.
#[must_use]
pub fn select_related(mut self, field: impl Into<String>) -> Self {
    let path: String = field.into();
    self.select_related.push(path);
    self
}
/// Add a join verbatim.
#[must_use]
pub fn join(mut self, join: crate::core::Join) -> Self {
    let extra = join;
    self.ad_hoc_joins.push(extra);
    self
}
#[must_use]
pub fn join_sub(self, sub: SelectQuery, alias: &'static str, on: WhereExpr) -> Self {
self.push_subquery_join(
sub,
alias,
crate::core::JoinKind::Inner,
on,
false,
)
}
#[must_use]
pub fn left_join_sub(self, sub: SelectQuery, alias: &'static str, on: WhereExpr) -> Self {
self.push_subquery_join(
sub,
alias,
crate::core::JoinKind::Left,
on,
false,
)
}
#[must_use]
pub fn join_lateral(self, sub: SelectQuery, alias: &'static str, on: WhereExpr) -> Self {
self.push_subquery_join(
sub,
alias,
crate::core::JoinKind::Inner,
on,
true,
)
}
#[must_use]
pub fn left_join_lateral(self, sub: SelectQuery, alias: &'static str, on: WhereExpr) -> Self {
self.push_subquery_join(
sub,
alias,
crate::core::JoinKind::Left,
on,
true,
)
}
#[must_use]
fn push_subquery_join(
mut self,
sub: SelectQuery,
alias: &'static str,
kind: crate::core::JoinKind,
on: WhereExpr,
lateral: bool,
) -> Self {
self.subquery_joins.push(crate::core::SubqueryJoin {
subquery: Box::new(sub),
alias,
kind,
on,
lateral,
});
self
}
#[must_use]
pub fn limit(mut self, n: i64) -> Self {
self.limit = Some(n);
self
}
#[must_use]
pub fn offset(mut self, n: i64) -> Self {
self.offset = Some(n);
self
}
#[must_use]
pub fn take(self, n: i64) -> Self {
self.limit(n)
}
#[must_use]
pub fn skip(self, n: i64) -> Self {
self.offset(n)
}
#[must_use]
pub fn filter_op(
mut self,
field: impl Into<String>,
op: Op,
value: impl Into<SqlValue>,
) -> Self {
self.pending.push(PendingFilter::Raw(RawFilter {
field: field.into(),
op,
value: value.into(),
}));
self
}
#[must_use]
pub fn filter(mut self, key: &str, value: impl Into<SqlValue>) -> Self {
self.pending.push(parse_to_pending(key, value.into()));
self
}
#[must_use]
pub fn exclude(mut self, key: &str, value: impl Into<SqlValue>) -> Self {
self.pending
.push(PendingFilter::Negated(Box::new(parse_to_pending(
key,
value.into(),
))));
self
}
fn with_pending_error(mut self, e: QueryError) -> Self {
self.pending.push(PendingFilter::Error(e));
self
}
#[must_use]
pub fn eq(self, field: impl Into<String>, value: impl Into<SqlValue>) -> Self {
self.filter_op(field, Op::Eq, value)
}
#[must_use]
pub fn where_key(self, pk: impl Into<SqlValue>) -> Self {
match T::SCHEMA.primary_key() {
Some(f) => self.filter_op(f.name, Op::Eq, pk),
None => self.with_pending_error(QueryError::UnknownField {
model: T::SCHEMA.name,
field: "<pk>".to_string(),
}),
}
}
#[must_use]
pub fn where_key_not(self, pk: impl Into<SqlValue>) -> Self {
match T::SCHEMA.primary_key() {
Some(f) => self.filter_op(f.name, Op::Ne, pk),
None => self.with_pending_error(QueryError::UnknownField {
model: T::SCHEMA.name,
field: "<pk>".to_string(),
}),
}
}
/// Keep rows for which `subquery` returns at least one row.
#[must_use]
pub fn where_exists(self, subquery: SelectQuery) -> Self {
    let predicate = crate::core::subquery::exists(subquery);
    self.where_raw(predicate)
}
/// Keep rows for which `subquery` returns no rows.
#[must_use]
pub fn where_not_exists(self, subquery: SelectQuery) -> Self {
    let predicate = crate::core::subquery::not_exists(subquery);
    self.where_raw(predicate)
}
/// Keep rows whose `column` value appears in `subquery`'s result.
#[must_use]
pub fn where_in_subquery(self, column: &'static str, subquery: SelectQuery) -> Self {
    let predicate = crate::core::subquery::in_subquery(column, subquery);
    self.where_raw(predicate)
}
/// Keep rows whose `column` value does not appear in `subquery`'s result.
#[must_use]
pub fn where_not_in_subquery(self, column: &'static str, subquery: SelectQuery) -> Self {
    let predicate = crate::core::subquery::not_in_subquery(column, subquery);
    self.where_raw(predicate)
}
/// Keep rows that have at least one related row through relation `name`
/// (reverse FK, many-to-many, or generic reverse). An unknown name records
/// an `UnknownField` error.
#[must_use]
pub fn where_has(self, name: &str) -> Self {
    let Some(expr) = Self::resolve_rel_exists(name, false) else {
        return self.with_pending_error(QueryError::UnknownField {
            model: T::SCHEMA.name,
            field: name.to_string(),
        });
    };
    self.where_raw(expr)
}
/// Column of `T`'s primary key, falling back to `"id"` when none is declared.
fn self_pk_column() -> &'static str {
    match T::SCHEMA.primary_key() {
        Some(pk) => pk.column,
        None => "id",
    }
}
/// Build an EXISTS (or NOT EXISTS when `negated`) predicate for relation `name`.
///
/// The lookup order is reverse FK relations, then many-to-many fields, then
/// generic reverse relations. The result is `None` when nothing matches.
fn resolve_rel_exists(name: &str, negated: bool) -> Option<WhereExpr> {
    use crate::core::subquery;
    T::reverse_relations()
        .iter()
        .find(|r| r.name == name)
        .map(|rel| {
            if negated {
                subquery::reverse_has_not_exists(rel)
            } else {
                subquery::reverse_has_exists(rel)
            }
        })
        .or_else(|| {
            T::SCHEMA
                .m2m
                .iter()
                .find(|m| m.name == name)
                .map(|m2m| subquery::m2m_has_exists(m2m, Self::self_pk_column(), negated))
        })
        .or_else(|| {
            T::generic_reverse_relations()
                .iter()
                .find(|g| g.name == name)
                .map(|g| subquery::generic_has_exists(g, T::SCHEMA.table, negated))
        })
}
/// Build a related-row COUNT expression for relation `name`, using the same
/// lookup order as [`Self::resolve_rel_exists`].
fn resolve_rel_count(name: &str) -> Option<Expr> {
    use crate::core::{subquery, RelAggKind};
    if let Some(rel) = T::reverse_relations().iter().find(|r| r.name == name) {
        return Some(subquery::reverse_has_count(rel));
    }
    T::SCHEMA
        .m2m
        .iter()
        .find(|m| m.name == name)
        .map(|m2m| {
            subquery::m2m_has_aggregate(m2m, Self::self_pk_column(), RelAggKind::Count, None)
        })
        .or_else(|| {
            T::generic_reverse_relations()
                .iter()
                .find(|g| g.name == name)
                .map(|g| {
                    subquery::generic_has_aggregate(g, T::SCHEMA.table, RelAggKind::Count, None)
                })
        })
}
/// Keep rows with no related row through relation `name`. An unknown name
/// records an `UnknownField` error.
#[must_use]
pub fn where_doesnt_have(self, name: &str) -> Self {
    let Some(expr) = Self::resolve_rel_exists(name, true) else {
        return self.with_pending_error(QueryError::UnknownField {
            model: T::SCHEMA.name,
            field: name.to_string(),
        });
    };
    self.where_raw(expr)
}
/// Keep rows with at least one reverse-related row matching `inner`.
#[must_use]
pub fn where_has_filter(self, name: &str, inner: SelectQuery) -> Self {
    let negated = false;
    self.where_has_filter_impl(name, inner, negated)
}
/// Keep rows with no reverse-related row matching `inner`.
#[must_use]
pub fn where_doesnt_have_filter(self, name: &str, inner: SelectQuery) -> Self {
    let negated = true;
    self.where_has_filter_impl(name, inner, negated)
}
/// Shared implementation of [`Self::where_has_filter`] and
/// [`Self::where_doesnt_have_filter`].
///
/// The steps are:
/// - Look `name` up among `T`'s reverse relations only.
/// - Check that `inner` selects from that relation's child model.
/// - AND the correlation `child.<fk column> = <outer pk column>` onto
///   `inner`'s WHERE clause.
/// - Wrap the result in `Exists`, or `NotExists` when `negated`.
///
/// Failures are deferred through [`Self::with_pending_error`]. An unknown
/// relation name reports the name itself. A model mismatch reports a message
/// that names the public method the caller used (previously it always said
/// `where_has_filter`).
fn where_has_filter_impl(self, name: &str, mut inner: SelectQuery, negated: bool) -> Self {
    let Some(rel) = T::reverse_relations().iter().find(|r| r.name == name) else {
        return self.with_pending_error(QueryError::UnknownField {
            model: T::SCHEMA.name,
            field: name.to_string(),
        });
    };
    if inner.model.name != rel.child_schema.name {
        let method = if negated {
            "where_doesnt_have_filter"
        } else {
            "where_has_filter"
        };
        return self.with_pending_error(QueryError::UnknownField {
            model: T::SCHEMA.name,
            field: format!(
                "{method}({name}): inner queryset model mismatch — \
                 expected `{}`, got `{}`",
                rel.child_schema.name, inner.model.name
            ),
        });
    }
    // Correlate each inner row with the outer row being filtered.
    let correlated = WhereExpr::ExprCompare {
        lhs: crate::core::Expr::Column(rel.child_fk_column),
        op: Op::Eq,
        rhs: crate::core::Expr::OuterRef(rel.self_pk_column),
    };
    // Extend an existing top-level AND instead of nesting another one.
    inner.where_clause =
        match std::mem::replace(&mut inner.where_clause, WhereExpr::And(Vec::new())) {
            WhereExpr::And(mut v) => {
                v.push(correlated);
                WhereExpr::And(v)
            }
            other => WhereExpr::And(vec![other, correlated]),
        };
    let wrapped = if negated {
        WhereExpr::NotExists(Box::new(inner))
    } else {
        WhereExpr::Exists(Box::new(inner))
    };
    self.where_raw(wrapped)
}
/// Keep rows whose related-row count through `name` satisfies `count op n`.
/// An unknown name records an `UnknownField` error.
#[must_use]
pub fn where_has_count(self, name: &str, op: Op, n: i64) -> Self {
    let Some(count) = Self::resolve_rel_count(name) else {
        return self.with_pending_error(QueryError::UnknownField {
            model: T::SCHEMA.name,
            field: name.to_string(),
        });
    };
    let threshold = Expr::Literal(SqlValue::I64(n));
    self.where_raw(WhereExpr::ExprCompare {
        lhs: count,
        op,
        rhs: threshold,
    })
}
/// Annotate each row with the count of related rows through `name`
/// (alias `{name}_count`).
#[must_use]
pub fn annotate_count(self, name: &str) -> AggregateBuilder<T> {
    let agg = AggregateExpr::Count(None);
    self.annotate_relation_aggregate(name, None, agg, "count")
}
/// Annotate with `SUM(column)` over related rows (alias `{name}_sum_{column}`).
#[must_use]
pub fn annotate_sum(self, name: &str, column: &'static str) -> AggregateBuilder<T> {
    let agg = AggregateExpr::Sum(column);
    self.annotate_relation_aggregate(name, Some(column), agg, "sum")
}
/// Annotate with `AVG(column)` over related rows (alias `{name}_avg_{column}`).
#[must_use]
pub fn annotate_avg(self, name: &str, column: &'static str) -> AggregateBuilder<T> {
    let agg = AggregateExpr::Avg(column);
    self.annotate_relation_aggregate(name, Some(column), agg, "avg")
}
/// Annotate with `MAX(column)` over related rows (alias `{name}_max_{column}`).
#[must_use]
pub fn annotate_max(self, name: &str, column: &'static str) -> AggregateBuilder<T> {
    let agg = AggregateExpr::Max(column);
    self.annotate_relation_aggregate(name, Some(column), agg, "max")
}
/// Annotate with `MIN(column)` over related rows (alias `{name}_min_{column}`).
#[must_use]
pub fn annotate_min(self, name: &str, column: &'static str) -> AggregateBuilder<T> {
    let agg = AggregateExpr::Min(column);
    self.annotate_relation_aggregate(name, Some(column), agg, "min")
}
/// Annotate each row with an integer exists flag for relation `name`
/// (alias `{name}_exists`). An unknown name records the first deferred
/// `UnknownField` error on the builder.
#[must_use]
pub fn annotate_exists(self, name: &str) -> AggregateBuilder<T> {
    let mut builder = self.aggregate();
    if let Some(exists) = Self::resolve_rel_exists(name, false) {
        let as_int = crate::core::subquery::exists_as_int(exists);
        builder.aggregates.push((
            std::borrow::Cow::Owned(format!("{name}_exists")),
            AggregateExpr::RelatedAggregate(Box::new(as_int)),
        ));
    } else {
        builder
            .deferred_error
            .get_or_insert(QueryError::UnknownField {
                model: T::SCHEMA.name,
                field: name.to_owned(),
            });
    }
    builder
}
/// Shared implementation of the `annotate_{count,sum,avg,max,min}` helpers.
///
/// `name` is looked up, in order, among `T`'s reverse relations,
/// many-to-many fields, and generic reverse relations. The matching
/// `subquery::*_has_aggregate` expression is pushed, wrapped in
/// `AggregateExpr::RelatedAggregate`, under the alias `{name}_{suffix}` or
/// `{name}_{suffix}_{column}`.
///
/// For reverse and generic relations, a `column` that is not a column of the
/// child model records an `UnknownField` error instead. An unknown `name`
/// does the same. Only the first deferred error is kept.
fn annotate_relation_aggregate(
    self,
    name: &str,
    column: Option<&'static str>,
    agg: AggregateExpr,
    suffix: &str,
) -> AggregateBuilder<T> {
    use crate::core::{subquery, RelAggKind};
    let mut builder = self.aggregate();
    let alias = match column {
        Some(col) => format!("{name}_{suffix}_{col}"),
        None => format!("{name}_{suffix}"),
    };
    // The m2m and generic lowerings take an aggregate kind, not the expression.
    let kind = match &agg {
        AggregateExpr::Sum(_) => RelAggKind::Sum,
        AggregateExpr::Avg(_) => RelAggKind::Avg,
        AggregateExpr::Max(_) => RelAggKind::Max,
        AggregateExpr::Min(_) => RelAggKind::Min,
        _ => RelAggKind::Count,
    };
    if let Some(rel) = T::reverse_relations().iter().find(|r| r.name == name) {
        match column {
            Some(col) if rel.child_schema.field_by_column(col).is_none() => {
                builder
                    .deferred_error
                    .get_or_insert(QueryError::UnknownField {
                        model: rel.child_schema.name,
                        field: col.to_owned(),
                    });
            }
            _ => {
                let sub = subquery::reverse_has_aggregate(rel, agg);
                builder.aggregates.push((
                    std::borrow::Cow::Owned(alias),
                    AggregateExpr::RelatedAggregate(Box::new(sub)),
                ));
            }
        }
        return builder;
    }
    if let Some(m2m) = T::SCHEMA.m2m.iter().find(|m| m.name == name) {
        let sub = subquery::m2m_has_aggregate(m2m, Self::self_pk_column(), kind, column);
        builder.aggregates.push((
            std::borrow::Cow::Owned(alias),
            AggregateExpr::RelatedAggregate(Box::new(sub)),
        ));
        return builder;
    }
    if let Some(g) = T::generic_reverse_relations()
        .iter()
        .find(|g| g.name == name)
    {
        match column {
            Some(col) if g.child_schema.field_by_column(col).is_none() => {
                builder
                    .deferred_error
                    .get_or_insert(QueryError::UnknownField {
                        model: g.child_schema.name,
                        field: col.to_owned(),
                    });
            }
            _ => {
                let sub = subquery::generic_has_aggregate(g, T::SCHEMA.table, kind, column);
                builder.aggregates.push((
                    std::borrow::Cow::Owned(alias),
                    AggregateExpr::RelatedAggregate(Box::new(sub)),
                ));
            }
        }
        return builder;
    }
    builder
        .deferred_error
        .get_or_insert(QueryError::UnknownField {
            model: T::SCHEMA.name,
            field: name.to_owned(),
        });
    builder
}
/// Keep rows where column `col1` equals column `col2`.
#[must_use]
pub fn where_column(self, col1: &'static str, col2: &'static str) -> Self {
    self.where_column_op(col1, Op::Eq, col2)
}
/// Keep rows where `col1 op col2` holds (column-to-column comparison).
#[must_use]
pub fn where_column_op(self, col1: &'static str, op: Op, col2: &'static str) -> Self {
    use crate::core::F;
    let lhs = F(col1).into();
    let rhs = F(col2).into();
    self.where_raw(WhereExpr::ExprCompare { lhs, op, rhs })
}
/// Add an arbitrary predicate as-is.
#[must_use]
pub fn where_raw(mut self, expr: WhereExpr) -> Self {
    let filter = PendingFilter::Expr(expr);
    self.pending.push(filter);
    self
}
/// Apply `f` to the queryset only when `condition` is true.
#[must_use]
pub fn when<F>(self, condition: bool, f: F) -> Self
where
    F: FnOnce(Self) -> Self,
{
    match condition {
        true => f(self),
        false => self,
    }
}
/// Apply `f` to the queryset only when `condition` is false.
#[must_use]
pub fn unless<F>(self, condition: bool, f: F) -> Self
where
    F: FnOnce(Self) -> Self,
{
    self.when(!condition, f)
}
/// Let `f` inspect the queryset (e.g. for logging) and pass it through unchanged.
#[must_use]
pub fn tap<F>(self, f: F) -> Self
where
    F: FnOnce(&Self),
{
    let view = &self;
    f(view);
    self
}
/// Add the model's soft-delete "not trashed" filter, if the model defines one.
#[must_use]
pub fn active(self) -> Self {
    if let Some(expr) = crate::soft_delete::active_filter(T::SCHEMA) {
        self.where_raw(expr)
    } else {
        self
    }
}
/// Add the model's soft-delete "trashed only" filter, if the model defines one.
#[must_use]
pub fn only_trashed(self) -> Self {
    if let Some(expr) = crate::soft_delete::trashed_filter(T::SCHEMA) {
        self.where_raw(expr)
    } else {
        self
    }
}
/// Adds no filter. Trashed rows are only excluded by [`Self::active`]
/// (global scopes are not disabled by this call).
#[must_use]
pub fn with_trashed(self) -> Self {
    let unchanged = self;
    unchanged
}
/// Add a typed predicate.
///
/// A single resolved predicate is stored as `PendingFilter::Resolved`; any
/// other expression shape is stored as `PendingFilter::Expr`.
#[must_use]
pub fn where_<E: Into<TypedExpr<T>>>(mut self, predicate: E) -> Self {
    let pending = match predicate.into().into_expr() {
        WhereExpr::Predicate(filter) => PendingFilter::Resolved(filter),
        other => PendingFilter::Expr(other),
    };
    self.pending.push(pending);
    self
}
/// Consume the queryset and lower it into a [`SelectQuery`].
///
/// The steps run in this order:
/// - Prepend enabled global scopes.
/// - Lower relation-spanning filters with `lower_relation_spans`, which
///   yields extra joins.
/// - Resolve the remaining filters into one WHERE expression.
/// - Build joins: `select_related` joins first, then `join()` joins, then
///   span joins whose alias is not already present.
/// - Lower the ordering.
/// - Validate `DISTINCT ON`.
/// - Apply `none()`.
///
/// With set-algebra branches, the head SELECT uses the ordering, limit and
/// offset captured when the first branch was added. The ordering, limit and
/// offset set afterwards become the combined result's
/// `compound_order_by` / `compound_limit` / `compound_offset`.
///
/// # Errors
/// - Any error from resolving filters, joins or ordering, e.g.
///   [`QueryError::UnknownField`].
/// - [`QueryError::DistinctOnEmpty`] for `distinct_on(&[])`.
/// - `UnknownField` for a `DISTINCT ON` name that is not a column of `T`.
/// - [`QueryError::DistinctOnOrderByMismatch`] when the head ordering does
///   not begin with the `DISTINCT ON` columns, as plain columns, in the same
///   order.
pub fn compile(mut self) -> Result<SelectQuery, QueryError> {
self.apply_global_scopes();
let model: &'static ModelSchema = T::SCHEMA;
let (pending, span_joins) = lower_relation_spans(model, self.pending)?;
let where_clause = resolve_pending(model, pending)?;
let mut joins = lower_select_related(model, &self.select_related)?;
joins.extend(self.ad_hoc_joins);
// Span joins may duplicate an alias already joined above; keep the first.
for j in span_joins {
if !joins.iter().any(|e| e.alias == j.alias) {
joins.push(j);
}
}
let has_compound = !self.compound.is_empty();
// Split ordering/limit/offset between the head SELECT and the combined result.
let (head_pending, head_limit, head_offset, comb_pending, comb_limit, comb_offset) =
if has_compound {
(
std::mem::take(&mut self.head_order_by),
self.head_limit,
self.head_offset,
std::mem::take(&mut self.order_by),
self.limit,
self.offset,
)
} else {
(
std::mem::take(&mut self.order_by),
self.limit,
self.offset,
Vec::new(),
None,
None,
)
};
let order_by = lower_order_items(model, head_pending)?;
let compound_order_by = lower_order_items(model, comb_pending)?;
// DISTINCT ON requires the head ordering to start with the same columns, in order.
if let Some(crate::core::DistinctMode::On(cols)) = &self.distinct {
if cols.is_empty() {
return Err(QueryError::DistinctOnEmpty);
}
for col in cols {
if model.field_by_column(col).is_none() {
return Err(QueryError::UnknownField {
model: model.name,
field: (*col).to_owned(),
});
}
}
if order_by.len() < cols.len() {
return Err(QueryError::DistinctOnOrderByMismatch {
distinct_on: cols.iter().map(|s| (*s).to_owned()).collect(),
order_by: order_by_column_names(&order_by),
});
}
for (i, col) in cols.iter().enumerate() {
let head = match &order_by[i] {
crate::core::OrderItem::Column { column, .. } => *column,
_ => {
return Err(QueryError::DistinctOnOrderByMismatch {
distinct_on: cols.iter().map(|s| (*s).to_owned()).collect(),
order_by: order_by_column_names(&order_by),
});
}
};
if head != *col {
return Err(QueryError::DistinctOnOrderByMismatch {
distinct_on: cols.iter().map(|s| (*s).to_owned()).collect(),
order_by: order_by_column_names(&order_by),
});
}
}
}
// none(): limit 0 on the combined result when branches exist, else on the head.
let (limit, compound_limit) = if self.is_none {
if has_compound {
(head_limit, Some(0))
} else {
(Some(0), comb_limit)
}
} else {
(head_limit, comb_limit)
};
Ok(SelectQuery {
model,
where_clause,
search: None,
joins,
subquery_joins: self.subquery_joins,
order_by,
limit,
offset: head_offset,
lock_mode: self.lock_mode,
compound: self.compound,
projection: None,
distinct: self.distinct,
compound_order_by,
compound_limit,
compound_offset: comb_offset,
})
}
#[must_use]
pub fn values_dict(self, cols: &[&'static str]) -> ValuesQuerySet<T> {
ValuesQuerySet {
qs: self,
cols: cols.to_vec(),
}
}
#[must_use]
pub fn values_list(self, cols: &[&'static str]) -> ValuesListQuerySet<T> {
ValuesListQuerySet {
qs: self,
cols: cols.to_vec(),
}
}
#[must_use]
pub fn values_list_flat(self, col: &'static str) -> ValuesFlatQuerySet<T> {
ValuesFlatQuerySet { qs: self, col }
}
#[must_use]
pub fn only(self, cols: &[&'static str]) -> ValuesQuerySet<T> {
self.values_dict(cols)
}
/// Project every scalar column of `T` except those in `cols`.
///
/// Names in `cols` that are not columns of `T` are appended to the
/// projection rather than dropped. Presumably this is so the bad name
/// surfaces as an error when the values query compiles; confirm.
#[must_use]
pub fn defer(self, cols: &[&'static str]) -> ValuesQuerySet<T> {
    let model = T::SCHEMA;
    let skipped: std::collections::HashSet<&'static str> = cols.iter().copied().collect();
    let mut projection: Vec<&'static str> = model
        .scalar_fields()
        .map(|f| f.column)
        .filter(|column| !skipped.contains(column))
        .collect();
    projection.extend(
        cols.iter()
            .copied()
            .filter(|&col| model.field_by_column(col).is_none()),
    );
    self.values_dict(&projection)
}
/// Consume the queryset and lower it into a [`DeleteQuery`] over the
/// matched rows.
///
/// Global scopes are applied first. Only the filters are used; ordering,
/// limits and joins are ignored. With `none()`, the WHERE clause is wrapped by
/// `never_match_clause`.
///
/// # Errors
/// Any error from resolving the filters or building the never-match clause.
pub fn compile_delete(mut self) -> Result<DeleteQuery, QueryError> {
    self.apply_global_scopes();
    let model: &'static ModelSchema = T::SCHEMA;
    let resolved = resolve_pending(model, self.pending)?;
    let where_clause = match self.is_none {
        true => never_match_clause(model, resolved)?,
        false => resolved,
    };
    Ok(DeleteQuery {
        model,
        where_clause,
    })
}
/// Start an UPDATE over the rows this queryset matches.
#[must_use]
pub fn update(self) -> UpdateBuilder<T> {
    UpdateBuilder {
        set: Vec::new(),
        qs: self,
    }
}
/// Start an aggregate / annotation query with no grouping, aggregates,
/// HAVING, ordering or limits yet.
#[must_use]
pub fn aggregate(self) -> AggregateBuilder<T> {
    AggregateBuilder {
        group_by: vec![],
        aggregates: vec![],
        aliases: vec![],
        order_by: vec![],
        having: None,
        limit: None,
        offset: None,
        deferred_error: None,
        values: None,
        qs: self,
    }
}
/// Like [`Self::aggregate`], but with `columns` recorded as the values
/// projection.
#[must_use]
pub fn values(self, columns: &[&'static str]) -> AggregateBuilder<T> {
    let mut builder = self.aggregate();
    builder.values = Some(columns.to_vec());
    builder
}
/// Start an aggregate query and add `expr` under `alias`.
#[must_use]
pub fn annotate(self, alias: &'static str, expr: AggregateExpr) -> AggregateBuilder<T> {
    let builder = self.aggregate();
    builder.annotate(alias, expr)
}
/// Start an aggregate query and add the scalar subquery `inner` under `alias`.
#[must_use]
pub fn annotate_subquery(
    self,
    alias: &'static str,
    inner: crate::core::SelectQuery,
) -> AggregateBuilder<T> {
    let builder = self.aggregate();
    let scalar = crate::core::subquery::scalar_subquery(inner);
    builder.annotate(alias, scalar)
}
}
/// Builder for an UPDATE over the rows matched by a [`QuerySet`]; created by
/// [`QuerySet::update`].
pub struct UpdateBuilder<T: Model> {
/// Supplies the WHERE clause, global scopes and the `none()` flag.
qs: QuerySet<T>,
/// Assignments in the order they were added.
set: Vec<PendingAssignment>,
}
/// One `SET` entry as recorded by [`UpdateBuilder`], resolved in `compile()`.
enum PendingAssignment {
/// `field = value` with an unresolved field name.
Raw(RawAssignment),
/// `field = <expr>` with an unresolved field name.
RawExpr(RawExprAssignment),
/// Already-resolved typed assignment.
Resolved(Assignment),
}
impl<T: Model> UpdateBuilder<T> {
#[must_use]
pub fn set(mut self, field: impl Into<String>, value: impl Into<SqlValue>) -> Self {
self.set.push(PendingAssignment::Raw(RawAssignment {
field: field.into(),
value: value.into(),
}));
self
}
#[must_use]
pub fn set_typed(mut self, assignment: TypedAssignment<T>) -> Self {
self.set
.push(PendingAssignment::Resolved(assignment.into_assignment()));
self
}
#[must_use]
pub fn set_expr(
mut self,
field: impl Into<String>,
expr: impl Into<crate::core::Expr>,
) -> Self {
self.set.push(PendingAssignment::RawExpr(RawExprAssignment {
field: field.into(),
value: expr.into(),
}));
self
}
pub fn compile(mut self) -> Result<UpdateQuery, QueryError> {
self.qs.apply_global_scopes();
let model: &'static ModelSchema = T::SCHEMA;
let assignments = self
.set
.into_iter()
.map(|p| match p {
PendingAssignment::Raw(raw) => resolve_assignment(model, raw),
PendingAssignment::RawExpr(raw) => resolve_assignment_expr(model, raw),
PendingAssignment::Resolved(assignment) => Ok(assignment),
})
.collect::<Result<Vec<_>, _>>()?;
let where_clause = resolve_pending(model, self.qs.pending)?;
let where_clause = if self.qs.is_none {
never_match_clause(model, where_clause)?
} else {
where_clause
};
Ok(UpdateQuery {
model,
set: assignments,
where_clause,
})
}
}
fn order_by_column_names(order_by: &[crate::core::OrderItem]) -> Vec<String> {
order_by
.iter()
.map(|item| match item {
crate::core::OrderItem::Column { column, .. } => (*column).to_owned(),
crate::core::OrderItem::Expr { .. } => "<expr>".to_owned(),
crate::core::OrderItem::Random => "RANDOM".to_owned(),
})
.collect()
}
fn lower_order_items(
model: &'static ModelSchema,
items: Vec<PendingOrderItem>,
) -> Result<Vec<crate::core::OrderItem>, QueryError> {
let mut out = Vec::with_capacity(items.len());
for item in items {
match item {
PendingOrderItem::Field { name, desc, nulls } => {
let field = model.field(&name).ok_or_else(|| QueryError::UnknownField {
model: model.name,
field: name.clone(),
})?;
out.push(crate::core::OrderItem::column_with_nulls(
field.column,
desc,
nulls,
));
}
PendingOrderItem::Expr { expr, desc, nulls } => {
out.push(crate::core::OrderItem::expr_with_nulls(expr, desc, nulls));
}
PendingOrderItem::Random => {
out.push(crate::core::OrderItem::random());
}
}
}
Ok(out)
}
/// Lower `select_related` chains into LEFT JOINs.
///
/// Each name is a `__`-separated chain of FK / one-to-one field names starting
/// at `model`, e.g. `author` or `author__publisher`. Every hop becomes one
/// `Join`. Each join projects all scalar columns of the target model and joins
/// on `<prev alias>.<fk column> = <alias>.<target column>`.
///
/// Alias rules:
/// - A single-hop chain uses the field name.
/// - Each hop of a multi-hop chain uses the `__`-joined path up to that hop,
///   e.g. `author`, then `author__publisher`.
///
/// Since the alias depends only on the path, hops shared between names are
/// emitted once instead of joining the same alias twice. That covers both
/// `author` plus `author__publisher`, and the same name requested twice.
///
/// Multi-hop aliases are interned with `intern_alias`. Each distinct alias
/// string is leaked once per process instead of once per `compile()`.
///
/// # Errors
/// `SelectRelatedInvalid` when a chain:
/// - has an empty hop;
/// - names a missing field or a non-FK/O2O field;
/// - targets a table with no registered model.
fn lower_select_related(
    model: &'static ModelSchema,
    names: &[String],
) -> Result<Vec<crate::core::Join>, QueryError> {
    use crate::core::{inventory, Expr, Join, JoinKind, ModelEntry, Op, Relation, WhereExpr};
    let mut out: Vec<Join> = Vec::with_capacity(names.len());
    for name in names {
        let hops: Vec<&str> = name.split("__").collect();
        if hops.iter().any(|h| h.is_empty()) {
            return Err(QueryError::SelectRelatedInvalid {
                model: model.name,
                field: name.clone(),
                reason: format!(
                    "empty hop in chain `{name}` (use `field` or `field__nested`, not `field____nested`)"
                ),
            });
        }
        let mut current: &'static ModelSchema = model;
        let mut prev_alias: &'static str = model.table;
        for (depth, hop) in hops.iter().enumerate() {
            let field = current
                .field(hop)
                .ok_or_else(|| QueryError::SelectRelatedInvalid {
                    model: current.name,
                    field: name.clone(),
                    reason: format!(
                        "no field `{hop}` on `{}` (hop {} of chain `{name}`)",
                        current.name,
                        depth + 1
                    ),
                })?;
            let (to, on) = match field.relation {
                Some(Relation::Fk { to, on }) | Some(Relation::O2O { to, on }) => (to, on),
                _ => {
                    return Err(QueryError::SelectRelatedInvalid {
                        model: current.name,
                        field: name.clone(),
                        reason: format!(
                            "`{hop}` on `{}` is not a `ForeignKey<T>` field (hop {} of chain `{name}`)",
                            current.name,
                            depth + 1
                        ),
                    });
                }
            };
            let target = inventory::iter::<ModelEntry>
                .into_iter()
                .find(|e| e.schema.table == to)
                .map(|e| e.schema)
                .ok_or_else(|| QueryError::SelectRelatedInvalid {
                    model: current.name,
                    field: name.clone(),
                    reason: format!(
                        "target table `{to}` is not registered (hop {} of chain `{name}` — is `{}`'s `#[derive(Model)]` linked into the binary?)",
                        depth + 1,
                        current.name
                    ),
                })?;
            // In a multi-hop chain every earlier alias was interned by this
            // same rule, so `prev_alias` already holds the path up to the
            // previous hop.
            let alias: &'static str = if hops.len() == 1 {
                field.name
            } else if depth == 0 {
                intern_alias(hop)
            } else {
                intern_alias(&format!("{prev_alias}__{hop}"))
            };
            // Same alias means same path, so this hop is already joined.
            if !out.iter().any(|existing| existing.alias == alias) {
                let project: Vec<&'static str> =
                    target.scalar_fields().map(|f| f.column).collect();
                out.push(Join {
                    target,
                    alias,
                    kind: JoinKind::Left,
                    on: WhereExpr::ExprCompare {
                        lhs: Expr::AliasedColumn {
                            alias: prev_alias,
                            column: field.column,
                        },
                        op: Op::Eq,
                        rhs: Expr::AliasedColumn { alias, column: on },
                    },
                    project,
                });
            }
            current = target;
            prev_alias = alias;
        }
    }
    Ok(out)
}
/// Return a `&'static str` equal to `alias`, leaking a copy only the first
/// time that alias is seen in this process.
///
/// Join aliases must be `&'static str`. Leaking a fresh string on every
/// compile grew memory without bound when the same multi-hop
/// `select_related` query was compiled repeatedly. The interned set is
/// bounded by the distinct chains the program uses.
fn intern_alias(alias: &str) -> &'static str {
    static INTERNED: std::sync::Mutex<std::collections::BTreeSet<&'static str>> =
        std::sync::Mutex::new(std::collections::BTreeSet::new());
    // A poisoned lock only means another thread panicked while holding it;
    // the set still contains valid leaked strings, so keep using it.
    let mut set = INTERNED
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    if let Some(&existing) = set.get(alias) {
        return existing;
    }
    let leaked: &'static str = Box::leak(alias.to_owned().into_boxed_str());
    set.insert(leaked);
    leaked
}
/// Result of [`parse_lookup`]: a filter key split into its field and lookup,
/// before any schema resolution.
enum ParsedLookup {
/// `field <op> value` on a column of the queried model.
Raw {
field: String,
op: Op,
value: SqlValue,
},
/// `transform(field) <op> value`, e.g. `created__year__gte`.
DateTransform {
field: String,
transform: ScalarFn,
op: Op,
value: SqlValue,
},
/// Key whose suffix is not a known lookup; kept whole (`raw_key`) so it can be
/// resolved later as a path across relations.
RelationSpan { raw_key: String, value: SqlValue },
}
/// Maps a `__<token>` date/time transform name to the scalar function that
/// applies it (`"year"` → `ExtractYear`, `"date"` → `TruncDate`, ...).
///
/// Returns `None` when `token` is not a recognised transform.
fn date_transform_fn(token: &str) -> Option<ScalarFn> {
    let transform = match token {
        "year" => ScalarFn::ExtractYear,
        "month" => ScalarFn::ExtractMonth,
        "day" => ScalarFn::ExtractDay,
        "hour" => ScalarFn::ExtractHour,
        "minute" => ScalarFn::ExtractMinute,
        "second" => ScalarFn::ExtractSecond,
        "quarter" => ScalarFn::ExtractQuarter,
        "week" => ScalarFn::ExtractWeek,
        "week_day" => ScalarFn::ExtractWeekDay,
        "date" => ScalarFn::TruncDate,
        _ => return None,
    };
    Some(transform)
}
/// Maps the comparison segment that may follow a date transform
/// (`created__year__<suffix>`) to its operator.
///
/// Only the six ordering/equality comparisons are accepted; anything else
/// yields `None`.
fn date_compare_op(suffix: &str) -> Option<Op> {
    let op = match suffix {
        "exact" => Op::Eq,
        "ne" => Op::Ne,
        "gt" => Op::Gt,
        "gte" => Op::Gte,
        "lt" => Op::Lt,
        "lte" => Op::Lte,
        _ => return None,
    };
    Some(op)
}
/// Parses a Django-style filter key into a [`ParsedLookup`].
///
/// Recognised key shapes:
/// - `field` (no `__`): equality on `field`.
/// - `field__<transform>` or `field__<transform>__<cmp>`, where `<transform>` is
///   a date/time transform accepted by `date_transform_fn` and `<cmp>` is one of
///   `exact`/`ne`/`gt`/`gte`/`lt`/`lte`: a [`ParsedLookup::DateTransform`]. The
///   comparison defaults to `Op::Eq` when omitted.
/// - `field__<lookup>` for the lookups matched below: a [`ParsedLookup::Raw`].
///   `contains`/`startswith`/`endswith` (and their `i*` variants) wrap the string
///   value in `%` wildcards, `range_*` values become `SqlValue::RangeLiteral`,
///   and `array_*` lists become `SqlValue::Array`.
/// - Anything else: a [`ParsedLookup::RelationSpan`] carrying the whole key, to
///   be resolved against the schema later.
///
/// Only the first `__` separates the field name; everything after it is the
/// suffix.
///
/// # Errors
/// - [`QueryError::UnknownLookup`] when a date transform is followed by an
///   unsupported comparison.
/// - [`QueryError::InvalidLookupValue`] when `value`'s shape does not match what
///   the lookup requires (e.g. a non-list for `in`, a non-bool for `isnull`).
///
/// NOTE(review): `%` and `_` inside the user-supplied value are not escaped for
/// the LIKE-based lookups (`contains`, `startswith`, `endswith`, `iexact`, ...),
/// so they act as wildcards. Confirm whether that is intended.
fn parse_lookup(key: &str, value: SqlValue) -> Result<ParsedLookup, QueryError> {
// No `__` separator: plain equality on the whole key.
let Some(split_at) = key.find("__") else {
return Ok(ParsedLookup::Raw {
field: key.to_owned(),
op: Op::Eq,
value,
});
};
// Split at the first `__`: the field name, then the full lookup suffix.
let field = key[..split_at].to_owned();
let suffix = &key[split_at + 2..];
// Peel off the first suffix segment in case it names a date/time transform.
let (transform_token, trailing) = match suffix.find("__") {
Some(at) => (&suffix[..at], Some(&suffix[at + 2..])),
None => (suffix, None),
};
// Date/time transform (`created__year`, `created__year__gte`): the optional
// trailing segment picks the comparison.
if let Some(transform) = date_transform_fn(transform_token) {
let op = match trailing {
None => Op::Eq,
Some(t) => date_compare_op(t).ok_or_else(|| QueryError::UnknownLookup {
field: field.clone(),
suffix: suffix.to_owned(),
})?,
};
return Ok(ParsedLookup::DateTransform {
field,
transform,
op,
value,
});
}
// Builds a plain `field <op> value` lookup.
let pair = |field: String, op: Op, value: SqlValue| ParsedLookup::Raw { field, op, value };
match suffix {
"exact" => Ok(pair(field, Op::Eq, value)),
"ne" => Ok(pair(field, Op::Ne, value)),
"gt" => Ok(pair(field, Op::Gt, value)),
"gte" => Ok(pair(field, Op::Gte, value)),
"lt" => Ok(pair(field, Op::Lt, value)),
"lte" => Ok(pair(field, Op::Lte, value)),
// Case-insensitive equality is expressed as ILIKE with the value unchanged.
"iexact" => Ok(pair(field, Op::ILike, value)),
// Substring / prefix / suffix lookups: wrap the string in `%` wildcards.
"contains" => {
let v = wrap_like(&value, "%", "%", &field, suffix)?;
Ok(pair(field, Op::Like, v))
}
"icontains" => {
let v = wrap_like(&value, "%", "%", &field, suffix)?;
Ok(pair(field, Op::ILike, v))
}
"startswith" => {
let v = wrap_like(&value, "", "%", &field, suffix)?;
Ok(pair(field, Op::Like, v))
}
"istartswith" => {
let v = wrap_like(&value, "", "%", &field, suffix)?;
Ok(pair(field, Op::ILike, v))
}
"endswith" => {
let v = wrap_like(&value, "%", "", &field, suffix)?;
Ok(pair(field, Op::Like, v))
}
"iendswith" => {
let v = wrap_like(&value, "%", "", &field, suffix)?;
Ok(pair(field, Op::ILike, v))
}
// Raw LIKE patterns: the caller supplies the wildcards.
"like" | "ilike" | "not_like" | "not_ilike" => {
if !matches!(value, SqlValue::String(_)) {
return Err(QueryError::InvalidLookupValue {
field,
suffix: suffix.to_owned(),
expected: "SqlValue::String(<LIKE pattern>)",
actual: sql_value_shape_name(&value),
});
}
let op = match suffix {
"like" => Op::Like,
"ilike" => Op::ILike,
"not_like" => Op::NotLike,
"not_ilike" => Op::NotILike,
_ => unreachable!(),
};
Ok(pair(field, op, value))
}
// Membership tests require a list value.
"in" | "not_in" => {
if !matches!(value, SqlValue::List(_)) {
return Err(QueryError::InvalidLookupValue {
field,
suffix: suffix.to_owned(),
expected: "SqlValue::List(...)",
actual: sql_value_shape_name(&value),
});
}
let op = if suffix == "not_in" {
Op::NotIn
} else {
Op::In
};
Ok(pair(field, op, value))
}
// The Bool value is kept as-is and travels with `Op::IsNull`.
"isnull" => {
if !matches!(value, SqlValue::Bool(_)) {
return Err(QueryError::InvalidLookupValue {
field,
suffix: suffix.to_owned(),
expected: "SqlValue::Bool(true|false)",
actual: sql_value_shape_name(&value),
});
}
Ok(pair(field, Op::IsNull, value))
}
// `between` and `range` are synonyms; both need exactly `[lo, hi]`.
"between" | "range" | "not_between" | "not_range" => {
match &value {
SqlValue::List(items) if items.len() == 2 => {}
SqlValue::List(_) => {
return Err(QueryError::InvalidLookupValue {
field,
suffix: suffix.to_owned(),
expected: "SqlValue::List with exactly 2 elements [lo, hi]",
actual: "SqlValue::List with wrong arity",
});
}
other => {
return Err(QueryError::InvalidLookupValue {
field,
suffix: suffix.to_owned(),
expected: "SqlValue::List([lo, hi])",
actual: sql_value_shape_name(other),
});
}
}
let op = if matches!(suffix, "not_between" | "not_range") {
Op::NotBetween
} else {
Op::Between
};
Ok(pair(field, op, value))
}
"regex" | "iregex" => {
if !matches!(value, SqlValue::String(_)) {
return Err(QueryError::InvalidLookupValue {
field,
suffix: suffix.to_owned(),
expected: "SqlValue::String(<regex pattern>)",
actual: sql_value_shape_name(&value),
});
}
let op = if suffix == "regex" {
Op::Regex
} else {
Op::IRegex
};
Ok(pair(field, op, value))
}
"trigram_similar" | "trigram_word_similar" => {
if !matches!(value, SqlValue::String(_)) {
return Err(QueryError::InvalidLookupValue {
field,
suffix: suffix.to_owned(),
expected: "SqlValue::String(<trigram pattern>)",
actual: sql_value_shape_name(&value),
});
}
let op = if suffix == "trigram_similar" {
Op::TrigramSimilar
} else {
Op::TrigramWordSimilar
};
Ok(pair(field, op, value))
}
"search" => {
if !matches!(value, SqlValue::String(_)) {
return Err(QueryError::InvalidLookupValue {
field,
suffix: suffix.to_owned(),
expected: "SqlValue::String(<search query>)",
actual: sql_value_shape_name(&value),
});
}
Ok(pair(field, Op::Search, value))
}
// Range operators accept a plain string or an existing RangeLiteral; both
// are normalised to `SqlValue::RangeLiteral`.
"range_contains"
| "range_contained_by"
| "range_overlap"
| "range_strictly_left"
| "range_strictly_right"
| "range_adjacent" => {
let literal = match value {
SqlValue::String(s) => s,
SqlValue::RangeLiteral(s) => s,
other => {
return Err(QueryError::InvalidLookupValue {
field,
suffix: suffix.to_owned(),
expected: "SqlValue::String(<PG range literal>) — e.g. \"[1, 10)\"",
actual: sql_value_shape_name(&other),
});
}
};
let op = match suffix {
"range_contains" => Op::RangeContains,
"range_contained_by" => Op::RangeContainedBy,
"range_overlap" => Op::RangeOverlap,
"range_strictly_left" => Op::RangeStrictlyLeft,
"range_strictly_right" => Op::RangeStrictlyRight,
"range_adjacent" => Op::RangeAdjacent,
_ => unreachable!(),
};
Ok(pair(field, op, SqlValue::RangeLiteral(literal)))
}
// Array operators take a list and re-tag it as `SqlValue::Array`.
"array_contains" | "array_contained_by" | "array_overlap" => {
let SqlValue::List(elems) = value else {
return Err(QueryError::InvalidLookupValue {
field,
suffix: suffix.to_owned(),
expected: "SqlValue::List(<elements>) — typed homogeneous array",
actual: sql_value_shape_name(&value),
});
};
let op = match suffix {
"array_contains" => Op::ArrayContains,
"array_contained_by" => Op::ArrayContainedBy,
"array_overlap" => Op::ArrayOverlap,
_ => unreachable!(),
};
Ok(pair(field, op, SqlValue::Array(elems)))
}
// Unrecognised suffix: treated as a relation span (e.g. `author__name`).
// The whole key is kept so it can be resolved against the schema later.
_ => Ok(ParsedLookup::RelationSpan {
raw_key: key.to_owned(),
value,
}),
}
}
/// Builds a LIKE pattern `"{prefix}{value}{suffix_char}"` from a string value.
///
/// `field` and `suffix` (the lookup name) are used only to describe the
/// failure.
///
/// # Errors
/// Returns [`QueryError::InvalidLookupValue`] when `value` is not
/// `SqlValue::String`.
fn wrap_like(
    value: &SqlValue,
    prefix: &str,
    suffix_char: &str,
    field: &str,
    suffix: &str,
) -> Result<SqlValue, QueryError> {
    if let SqlValue::String(text) = value {
        let mut pattern = String::with_capacity(prefix.len() + text.len() + suffix_char.len());
        pattern.push_str(prefix);
        pattern.push_str(text);
        pattern.push_str(suffix_char);
        return Ok(SqlValue::String(pattern));
    }
    Err(QueryError::InvalidLookupValue {
        field: field.to_owned(),
        suffix: suffix.to_owned(),
        expected: "SqlValue::String(...)",
        actual: sql_value_shape_name(value),
    })
}
/// Returns a short, human-readable name for `v`'s variant, used in
/// [`QueryError::InvalidLookupValue`] messages.
///
/// Variants without a dedicated arm are reported as `SqlValue::<other>`.
fn sql_value_shape_name(v: &SqlValue) -> &'static str {
    match v {
        SqlValue::Null => "SqlValue::Null",
        SqlValue::I16(_) => "SqlValue::I16",
        SqlValue::I32(_) => "SqlValue::I32",
        SqlValue::I64(_) => "SqlValue::I64",
        SqlValue::F32(_) => "SqlValue::F32",
        SqlValue::F64(_) => "SqlValue::F64",
        SqlValue::Bool(_) => "SqlValue::Bool",
        SqlValue::String(_) => "SqlValue::String",
        SqlValue::List(_) => "SqlValue::List",
        // The range and array lookups in `parse_lookup` report these values when
        // a caller passes the wrong shape, so name them explicitly.
        SqlValue::RangeLiteral(_) => "SqlValue::RangeLiteral",
        SqlValue::Array(_) => "SqlValue::Array",
        _ => "SqlValue::<other>",
    }
}
/// AND-s `base` with a predicate that can never be true (`<pk> IS NULL`), so the
/// resulting WHERE clause matches no rows.
///
/// An empty `And` is replaced by the predicate outright. A non-empty `And` gets
/// it appended, and any other expression is wrapped together with it.
///
/// # Errors
/// Returns [`QueryError::UnknownField`] (field `"<primary_key>"`) when `model`
/// has no primary key.
fn never_match_clause(
    model: &'static ModelSchema,
    base: WhereExpr,
) -> Result<WhereExpr, QueryError> {
    let Some(pk) = model.primary_key() else {
        return Err(QueryError::UnknownField {
            model: model.name,
            field: "<primary_key>".to_owned(),
        });
    };
    // Relies on the primary-key column never being NULL.
    let never = WhereExpr::Predicate(crate::core::Filter {
        column: pk.column,
        op: Op::IsNull,
        value: SqlValue::Bool(true),
    });
    let combined = match base {
        WhereExpr::And(mut nodes) => {
            if nodes.is_empty() {
                never
            } else {
                nodes.push(never);
                WhereExpr::And(nodes)
            }
        }
        other => WhereExpr::And(vec![other, never]),
    };
    Ok(combined)
}
fn parse_to_pending(key: &str, value: SqlValue) -> PendingFilter {
match parse_lookup(key, value) {
Ok(ParsedLookup::Raw { field, op, value }) => {
PendingFilter::Raw(RawFilter { field, op, value })
}
Ok(ParsedLookup::DateTransform {
field,
transform,
op,
value,
}) => PendingFilter::DateTransform(DateTransformFilter {
field,
transform,
op,
value,
}),
Ok(ParsedLookup::RelationSpan { raw_key, value }) => {
PendingFilter::RelationSpan { raw_key, value }
}
Err(e) => PendingFilter::Error(e),
}
}
/// Resolves every pending filter against `model` and AND-s them together, in
/// their original order.
///
/// # Errors
/// Returns the first error produced by [`resolve_one_pending`].
fn resolve_pending(
    model: &'static ModelSchema,
    pending: Vec<PendingFilter>,
) -> Result<WhereExpr, QueryError> {
    pending
        .into_iter()
        .map(|entry| resolve_one_pending(model, entry))
        .collect::<Result<Vec<WhereExpr>, QueryError>>()
        .map(WhereExpr::And)
}
fn resolve_one_pending(
model: &'static ModelSchema,
entry: PendingFilter,
) -> Result<WhereExpr, QueryError> {
Ok(match entry {
PendingFilter::Raw(raw) => WhereExpr::Predicate(resolve_filter(model, raw)?),
PendingFilter::DateTransform(dt) => resolve_date_transform(model, dt)?,
PendingFilter::Resolved(filter) => WhereExpr::Predicate(filter),
PendingFilter::Expr(expr) => expr,
PendingFilter::Negated(inner) => {
WhereExpr::Not(Box::new(resolve_one_pending(model, *inner)?))
}
PendingFilter::RelationSpan { raw_key, value } => {
return match resolve_span(model, &raw_key, value) {
Ok(_) => Err(QueryError::RelationSpanUnsupportedHere { key: raw_key }),
Err(e) => Err(e),
}
}
PendingFilter::Error(e) => return Err(e),
})
}
/// Resolves a parsed `field <op> value` filter against `model`'s schema into a
/// column-level [`Filter`].
///
/// The value's type is checked against the field's declared type, except for
/// ops whose value does not have the column's shape (see the comment on
/// `skip_type_check`).
///
/// # Errors
/// - [`QueryError::UnknownField`] when `raw.field` is not a field of `model`.
/// - [`QueryError::TypeMismatch`] when the value reports a type different from
///   the field's.
fn resolve_filter(model: &'static ModelSchema, raw: RawFilter) -> Result<Filter, QueryError> {
    let field = model
        .field(&raw.field)
        .ok_or_else(|| QueryError::UnknownField {
            model: model.name,
            field: raw.field.clone(),
        })?;
    // `isnull` carries a Bool flag, and `in`/`not_in` carry a list. Neither value
    // is comparable with the column's type. `parse_lookup` builds `NotIn` from
    // exactly the same `SqlValue::List` shape as `In`, so it must be exempted too.
    let skip_type_check = matches!(raw.op, Op::IsNull | Op::In | Op::NotIn);
    if !skip_type_check {
        if let Some(value_ty) = raw.value.field_type() {
            if value_ty != field.ty {
                return Err(QueryError::TypeMismatch {
                    model: model.name,
                    field: raw.field,
                    expected: field.ty,
                    actual: value_ty,
                });
            }
        }
    }
    Ok(Filter {
        column: field.column,
        op: raw.op,
        value: raw.value,
    })
}
fn resolve_span(
model: &'static ModelSchema,
raw_key: &str,
value: SqlValue,
) -> Result<(Vec<crate::core::Join>, WhereExpr), QueryError> {
use crate::core::{inventory, Expr, Join, JoinKind, ModelEntry, Relation};
let segs: Vec<&str> = raw_key.split("__").collect();
let mut joins: Vec<Join> = Vec::new();
let mut current: &'static ModelSchema = model;
let mut prev_alias: &'static str = model.table;
let mut alias_path = String::new();
let mut term_i = segs.len() - 1;
for (i, seg) in segs.iter().enumerate() {
let is_last = i + 1 == segs.len();
let field = current.field(seg);
let fk = field.and_then(|f| match f.relation {
Some(Relation::Fk { to, on }) | Some(Relation::O2O { to, on }) => Some((f, to, on)),
_ => None,
});
match fk {
Some((field, to, on)) if !is_last => {
let target = inventory::iter::<ModelEntry>
.into_iter()
.find(|e| e.schema.table == to)
.map(|e| e.schema)
.ok_or_else(|| QueryError::UnknownField {
model: current.name,
field: format!("{seg} (target table `{to}` not registered)"),
})?;
alias_path = if alias_path.is_empty() {
(*seg).to_owned()
} else {
format!("{alias_path}__{seg}")
};
let alias: &'static str = Box::leak(alias_path.clone().into_boxed_str());
let project: Vec<&'static str> = target.scalar_fields().map(|f| f.column).collect();
joins.push(Join {
target,
alias,
kind: JoinKind::Left,
on: WhereExpr::ExprCompare {
lhs: Expr::AliasedColumn {
alias: prev_alias,
column: field.column,
},
op: Op::Eq,
rhs: Expr::AliasedColumn { alias, column: on },
},
project,
});
current = target;
prev_alias = alias;
}
_ => {
term_i = i;
break;
}
}
}
if joins.is_empty() {
let suffix = raw_key.splitn(2, "__").nth(1).unwrap_or("").to_owned();
return Err(QueryError::UnknownLookup {
field: segs.first().map_or_else(String::new, |s| (*s).to_owned()),
suffix,
});
}
let term_field = current
.field(segs[term_i])
.ok_or_else(|| QueryError::UnknownField {
model: current.name,
field: segs[term_i].to_owned(),
})?;
let suffix_segs = &segs[term_i + 1..];
let (op, value, transform) = if suffix_segs.is_empty() {
(Op::Eq, value, None)
} else {
let synth = format!("{}__{}", term_field.name, suffix_segs.join("__"));
match parse_lookup(&synth, value) {
Ok(ParsedLookup::Raw { op, value, .. }) => (op, value, None),
Ok(ParsedLookup::DateTransform {
transform,
op,
value,
..
}) => (op, value, Some(transform)),
Ok(ParsedLookup::RelationSpan { .. }) => {
return Err(QueryError::UnknownLookup {
field: term_field.name.to_owned(),
suffix: suffix_segs.join("__"),
})
}
Err(e) => return Err(e),
}
};
let column = Expr::AliasedColumn {
alias: prev_alias,
column: term_field.column,
};
let lhs = match transform {
None => column,
Some(kind) => Expr::Function {
kind,
args: vec![column],
},
};
let predicate = WhereExpr::ExprCompare {
lhs,
op,
rhs: Expr::Literal(value),
};
Ok((joins, predicate))
}
/// Rewrites every relation-span filter in `pending` (including ones under a
/// negation) into a plain expression filter. The joins those spans require are
/// returned alongside.
///
/// Joins are de-duplicated by alias across all filters.
///
/// # Errors
/// Returns the first error from [`convert_relation_span`].
fn lower_relation_spans(
    model: &'static ModelSchema,
    pending: Vec<PendingFilter>,
) -> Result<(Vec<PendingFilter>, Vec<crate::core::Join>), QueryError> {
    let mut joins: Vec<crate::core::Join> = Vec::new();
    let lowered = pending
        .into_iter()
        .map(|pf| convert_relation_span(model, pf, &mut joins))
        .collect::<Result<Vec<PendingFilter>, QueryError>>()?;
    Ok((lowered, joins))
}
/// Converts one pending filter: a relation span becomes `PendingFilter::Expr`
/// with its joins merged into `joins` (skipping aliases already present), and a
/// negation is converted recursively.
///
/// Every other filter is returned unchanged.
fn convert_relation_span(
    model: &'static ModelSchema,
    pf: PendingFilter,
    joins: &mut Vec<crate::core::Join>,
) -> Result<PendingFilter, QueryError> {
    match pf {
        PendingFilter::RelationSpan { raw_key, value } => {
            let (span_joins, predicate) = resolve_span(model, &raw_key, value)?;
            for join in span_joins {
                let already_joined = joins.iter().any(|existing| existing.alias == join.alias);
                if !already_joined {
                    joins.push(join);
                }
            }
            Ok(PendingFilter::Expr(predicate))
        }
        PendingFilter::Negated(inner) => {
            let converted = convert_relation_span(model, *inner, joins)?;
            Ok(PendingFilter::Negated(Box::new(converted)))
        }
        untouched => Ok(untouched),
    }
}
fn resolve_date_transform(
model: &'static ModelSchema,
dt: DateTransformFilter,
) -> Result<WhereExpr, QueryError> {
let field = model
.field(&dt.field)
.ok_or_else(|| QueryError::UnknownField {
model: model.name,
field: dt.field.clone(),
})?;
Ok(WhereExpr::ExprCompare {
lhs: Expr::Function {
kind: dt.transform,
args: vec![Expr::Column(field.column)],
},
op: dt.op,
rhs: Expr::Literal(dt.value),
})
}
fn resolve_assignment(
model: &'static ModelSchema,
raw: RawAssignment,
) -> Result<Assignment, QueryError> {
let field = model
.field(&raw.field)
.ok_or_else(|| QueryError::UnknownField {
model: model.name,
field: raw.field.clone(),
})?;
if let Some(value_ty) = raw.value.field_type() {
if value_ty != field.ty {
return Err(QueryError::TypeMismatch {
model: model.name,
field: raw.field,
expected: field.ty,
actual: value_ty,
});
}
}
Ok(Assignment {
column: field.column,
value: raw.value.into(),
})
}
fn resolve_assignment_expr(
model: &'static ModelSchema,
raw: RawExprAssignment,
) -> Result<Assignment, QueryError> {
let field = model
.field(&raw.field)
.ok_or_else(|| QueryError::UnknownField {
model: model.name,
field: raw.field.clone(),
})?;
validate_expr_columns_in_model(model, &raw.value)?;
Ok(Assignment {
column: field.column,
value: raw.value,
})
}
/// Walks `expr` and checks that every bare column it references is a column of
/// `model`.
///
/// Recurses into binary ops, function args, casts, CASE branches (including
/// their conditions), JSON-path sources and window functions (partition,
/// order and args). Subqueries, outer refs, relation aggregates, aliased columns
/// and aggregates are not inspected.
///
/// # Errors
/// Returns [`QueryError::UnknownField`] for the first unknown column found, or
/// any error from validating a CASE condition.
fn validate_expr_columns_in_model(
    model: &'static ModelSchema,
    expr: &crate::core::Expr,
) -> Result<(), QueryError> {
    use crate::core::Expr;

    /// Fails with `UnknownField` unless `column` is a column of `model`.
    fn require_column(model: &'static ModelSchema, column: &str) -> Result<(), QueryError> {
        match model.field_by_column(column) {
            Some(_) => Ok(()),
            None => Err(QueryError::UnknownField {
                model: model.name,
                field: column.to_owned(),
            }),
        }
    }

    match expr {
        Expr::Literal(_)
        | Expr::Subquery(_)
        | Expr::AggregateSubquery(_)
        | Expr::OuterRef(_)
        | Expr::RelAggregate { .. }
        | Expr::AliasedColumn { .. }
        | Expr::Aggregate(_) => Ok(()),
        Expr::Column(name) => require_column(model, name),
        Expr::BinOp { left, right, .. } => {
            validate_expr_columns_in_model(model, left)?;
            validate_expr_columns_in_model(model, right)
        }
        Expr::Function { args, .. } => args
            .iter()
            .try_for_each(|arg| validate_expr_columns_in_model(model, arg)),
        Expr::Cast { expr: inner, .. } => validate_expr_columns_in_model(model, inner),
        Expr::JsonPath { source, .. } => validate_expr_columns_in_model(model, source),
        Expr::Case { branches, default } => {
            for branch in branches {
                branch.condition.validate(model)?;
                validate_expr_columns_in_model(model, &branch.then)?;
            }
            match default {
                Some(fallback) => validate_expr_columns_in_model(model, fallback),
                None => Ok(()),
            }
        }
        Expr::Window(w) => {
            for col in &w.partition_by {
                require_column(model, col)?;
            }
            for o in &w.order_by {
                require_column(model, o.column)?;
            }
            w.args
                .iter()
                .try_for_each(|arg| validate_expr_columns_in_model(model, arg))
        }
    }
}
/// Builder for a grouped / aggregating query over model `T`, finished with
/// [`AggregateBuilder::compile`] into an [`AggregateQuery`].
pub struct AggregateBuilder<T: Model> {
/// Underlying queryset: supplies the pending WHERE filters, global scopes and
/// the `is_none` flag.
qs: QuerySet<T>,
/// Explicit GROUP BY columns; validated against the model in `compile`.
group_by: Vec<&'static str>,
/// `(alias, aggregate)` pairs added through `annotate`.
aggregates: Vec<(std::borrow::Cow<'static, str>, AggregateExpr)>,
/// `(name, aggregate)` pairs added through `alias`; `filter` and `order_by`
/// look names up here as well as in `aggregates`.
aliases: Vec<(std::borrow::Cow<'static, str>, AggregateExpr)>,
/// HAVING predicate built from `having` and from aggregate-targeted `filter` calls.
having: Option<WhereExpr>,
/// `(name, desc)` pairs; names matching an aggregate alias order by that aggregate.
order_by: Vec<(&'static str, bool)>,
limit: Option<i64>,
offset: Option<i64>,
/// First error recorded by `filter`; `compile` returns it before doing anything else.
deferred_error: Option<crate::core::QueryError>,
/// Grouping columns set through `values`; used only when `group_by` is empty.
values: Option<Vec<&'static str>>,
}
impl<T: Model> AggregateBuilder<T> {
/// Appends `column` to the GROUP BY list (validated in [`Self::compile`]).
#[must_use]
pub fn group_by(mut self, column: &'static str) -> Self {
self.group_by.push(column);
self
}
/// Sets the grouping columns to use when no explicit `group_by` was given.
/// `compile` rejects this unless at least one annotation aggregates.
#[must_use]
pub fn values(mut self, columns: &[&'static str]) -> Self {
self.values = Some(columns.to_vec());
self
}
/// Adds an aggregate annotation named `alias`.
#[must_use]
pub fn annotate(mut self, alias: &'static str, expr: AggregateExpr) -> Self {
self.aggregates
.push((std::borrow::Cow::Borrowed(alias), expr));
self
}
/// Annotates `alias` with a scalar subquery built from `inner`.
#[must_use]
pub fn annotate_subquery(self, alias: &'static str, inner: crate::core::SelectQuery) -> Self {
self.annotate(alias, crate::core::subquery::scalar_subquery(inner))
}
/// Registers an aggregate under `name` in the separate `aliases` list. It can be
/// referenced by `filter`/`order_by` like an annotation; how the compiled
/// `AggregateQuery` renders it is decided outside this builder.
#[must_use]
pub fn alias(mut self, name: &'static str, expr: AggregateExpr) -> Self {
self.aliases.push((std::borrow::Cow::Borrowed(name), expr));
self
}
/// AND-s `predicate` into the HAVING clause.
#[must_use]
pub fn having<E: Into<crate::core::TypedExpr<T>>>(mut self, predicate: E) -> Self {
let expr = predicate.into().into_expr();
match self.having {
None => self.having = Some(expr),
Some(ref mut existing) => existing.push_and(expr),
}
self
}
/// Adds a filter on `field`.
///
/// If `field` names an `annotate`d or `alias`ed aggregate, `<aggregate> <op> value`
/// is AND-ed into HAVING. The JSON ops and `IS [NOT] DISTINCT FROM` are rejected
/// there by recording a deferred [`QueryError::HavingOpNotSupported`]. Any other
/// name is forwarded to the underlying queryset through `filter_op`.
///
/// Does nothing once an earlier call has recorded a deferred error.
#[must_use]
pub fn filter(
mut self,
field: &'static str,
op: crate::core::Op,
value: impl Into<crate::core::SqlValue>,
) -> Self {
if self.deferred_error.is_some() {
return self;
}
// Annotations take precedence over aliases when both use the same name.
let agg = self
.aggregates
.iter()
.chain(self.aliases.iter())
.find(|(alias, _)| *alias == field)
.map(|(_, expr)| expr.clone());
if let Some(agg) = agg {
if matches!(
op,
crate::core::Op::JsonContains
| crate::core::Op::JsonContainedBy
| crate::core::Op::JsonHasKey
| crate::core::Op::JsonHasAnyKey
| crate::core::Op::JsonHasAllKeys
| crate::core::Op::IsDistinctFrom
| crate::core::Op::IsNotDistinctFrom
) {
self.deferred_error = Some(crate::core::QueryError::HavingOpNotSupported {
alias: field.to_owned(),
op,
});
return self;
}
let pred = WhereExpr::ExprCompare {
lhs: crate::core::Expr::Aggregate(Box::new(agg)),
op,
rhs: crate::core::Expr::Literal(value.into()),
};
match self.having {
None => self.having = Some(pred),
Some(ref mut existing) => existing.push_and(pred),
}
} else {
self.qs = self.qs.filter_op(field, op, value);
}
self
}
/// Appends `(name, desc)` ordering items. Names matching an aggregate alias
/// order by that aggregate expression; others order by the column of that name.
#[must_use]
pub fn order_by(mut self, items: &[(&'static str, bool)]) -> Self {
self.order_by.extend_from_slice(items);
self
}
#[must_use]
pub fn limit(mut self, n: i64) -> Self {
self.limit = Some(n);
self
}
#[must_use]
pub fn offset(mut self, n: i64) -> Self {
self.offset = Some(n);
self
}
/// Compiles the builder into an [`AggregateQuery`].
///
/// Steps, in order:
/// 1. Return any deferred `filter` error.
/// 2. Apply the queryset's global scopes.
/// 3. Resolve its pending filters into the WHERE clause.
/// 4. Validate the column references of every annotation and alias.
/// 5. Map `order_by` names that match an aggregate to that aggregate.
/// 6. Derive GROUP BY.
///
/// GROUP BY is, in priority order:
/// - the explicit `group_by` columns;
/// - otherwise the `values` columns, which require at least one aggregating
///   annotation;
/// - otherwise, when anything aggregates, every scalar column of the model;
/// - otherwise nothing.
///
/// If the queryset's `is_none` flag is set, the query gets a never-matching
/// WHERE clause and `LIMIT 0`.
///
/// # Errors
/// - The deferred error from `filter`, if any.
/// - Errors resolving pending filters (relation spans are rejected on this path
///   with `RelationSpanUnsupportedHere`).
/// - [`QueryError::UnknownField`] for unknown columns in annotations, `group_by`
///   or `values`.
/// - [`QueryError::ValuesRequiresAggregate`] when `values` is used without any
///   aggregating annotation.
pub fn compile(mut self) -> Result<AggregateQuery, QueryError> {
if let Some(e) = self.deferred_error {
return Err(e);
}
self.qs.apply_global_scopes();
let model = T::SCHEMA;
let where_clause = resolve_pending(model, self.qs.pending)?;
for (_alias, expr) in self.aggregates.iter().chain(self.aliases.iter()) {
validate_aggregate_expr_columns(model, expr)?;
}
// Looks up an annotation/alias by name (annotations first).
let alias_for = |name: &str| -> Option<crate::core::AggregateExpr> {
self.aggregates
.iter()
.chain(self.aliases.iter())
.find(|(a, _)| a.as_ref() == name)
.map(|(_, e)| e.clone())
};
let order_by = self
.order_by
.into_iter()
.map(|(col, desc)| match alias_for(col) {
Some(agg) => {
crate::core::OrderItem::expr(crate::core::Expr::Aggregate(Box::new(agg)), desc)
}
None => crate::core::OrderItem::column(col, desc),
})
.collect();
let has_aggregating = self
.aggregates
.iter()
.chain(self.aliases.iter())
.any(|(_, e)| e.is_aggregating());
let group_by = if !self.group_by.is_empty() {
for col in &self.group_by {
if model.field_by_column(col).is_none() {
return Err(QueryError::UnknownField {
model: model.name,
field: (*col).to_owned(),
});
}
}
self.group_by
} else if let Some(cols) = self.values.as_ref() {
if !has_aggregating {
return Err(QueryError::ValuesRequiresAggregate { cols: cols.clone() });
}
for col in cols {
if model.field_by_column(col).is_none() {
return Err(QueryError::UnknownField {
model: model.name,
field: (*col).to_owned(),
});
}
}
cols.clone()
} else if has_aggregating {
// No explicit grouping: group by every scalar column of the model.
model.scalar_fields().map(|f| f.column).collect()
} else {
Vec::new()
};
let (where_clause, limit) = if self.qs.is_none {
(never_match_clause(model, where_clause)?, Some(0))
} else {
(where_clause, self.limit)
};
Ok(AggregateQuery {
model,
where_clause,
group_by,
aggregates: self.aggregates,
aliases: self.aliases,
having: self.having,
order_by,
limit,
offset: self.offset,
})
}
}
/// Checks that the column references inside an aggregate expression belong to
/// `model`.
///
/// Filtered aggregates validate their filter, then their inner aggregate.
/// Coalesced aggregates validate their inner aggregate. Window aggregates check
/// partition/order columns and their argument expressions. Other aggregate kinds
/// are accepted as-is.
///
/// # Errors
/// Returns [`QueryError::UnknownField`] for the first unknown column, or any
/// error from validating a filter.
fn validate_aggregate_expr_columns(
    model: &'static ModelSchema,
    expr: &crate::core::AggregateExpr,
) -> Result<(), QueryError> {
    use crate::core::AggregateExpr;

    /// Fails with `UnknownField` unless `column` is a column of `model`.
    fn require_column(model: &'static ModelSchema, column: &str) -> Result<(), QueryError> {
        if model.field_by_column(column).is_some() {
            Ok(())
        } else {
            Err(QueryError::UnknownField {
                model: model.name,
                field: column.to_owned(),
            })
        }
    }

    match expr {
        AggregateExpr::Filtered { inner, filter } => {
            filter.validate(model)?;
            validate_aggregate_expr_columns(model, inner)
        }
        AggregateExpr::Coalesced { inner, .. } => validate_aggregate_expr_columns(model, inner),
        AggregateExpr::Window(window) => {
            for col in &window.partition_by {
                require_column(model, col)?;
            }
            for item in &window.order_by {
                require_column(model, item.column)?;
            }
            for arg in &window.args {
                validate_expr_columns_in_model(model, arg)?;
            }
            Ok(())
        }
        _ => Ok(()),
    }
}
/// Queryset that projects a chosen set of columns (the `values` form); compiles
/// to a [`SelectQuery`] whose projection is `cols`.
pub struct ValuesQuerySet<T: Model> {
pub(crate) qs: QuerySet<T>,
/// Column names to project, in order.
pub(crate) cols: Vec<&'static str>,
}
/// Queryset that projects a chosen list of columns (the `values_list` form);
/// compiles the same way as [`ValuesQuerySet`].
pub struct ValuesListQuerySet<T: Model> {
pub(crate) qs: QuerySet<T>,
/// Column names to project, in order.
pub(crate) cols: Vec<&'static str>,
}
/// Queryset that projects exactly one column (the flat `values_list` form).
pub struct ValuesFlatQuerySet<T: Model> {
pub(crate) qs: QuerySet<T>,
/// The single column to project.
pub(crate) col: &'static str,
}
impl<T: Model> ValuesQuerySet<T> {
    /// Compiles the underlying queryset with its projection narrowed to the
    /// stored columns.
    ///
    /// # Errors
    /// Fails when the column list is empty, names an unknown column, or the
    /// queryset itself fails to compile.
    pub fn compile(self) -> Result<SelectQuery, QueryError> {
        let Self { qs, cols } = self;
        compile_values_select(qs, cols)
    }
    /// The columns that will be projected, in stored order.
    #[must_use]
    pub fn columns(&self) -> &[&'static str] {
        self.cols.as_slice()
    }
}
impl<T: Model> ValuesListQuerySet<T> {
    /// Compiles the underlying queryset with its projection narrowed to the
    /// stored columns.
    ///
    /// # Errors
    /// Fails when the column list is empty, names an unknown column, or the
    /// queryset itself fails to compile.
    pub fn compile(self) -> Result<SelectQuery, QueryError> {
        let Self { qs, cols } = self;
        compile_values_select(qs, cols)
    }
    /// The columns that will be projected, in stored order.
    #[must_use]
    pub fn columns(&self) -> &[&'static str] {
        self.cols.as_slice()
    }
}
impl<T: Model> ValuesFlatQuerySet<T> {
    /// Compiles the underlying queryset projecting only the single stored column.
    ///
    /// # Errors
    /// Fails when the column is unknown or the queryset fails to compile.
    pub fn compile(self) -> Result<SelectQuery, QueryError> {
        let Self { qs, col } = self;
        compile_values_select(qs, Vec::from([col]))
    }
}
/// Granularity used by [`QuerySet::dates`]: truncate a date/datetime column to
/// the start of its year, month or day.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DateKind {
    Year,
    Month,
    Day,
}
impl DateKind {
    /// Renders a dialect-specific SQL expression that truncates `col_quoted` (an
    /// already-quoted column reference) to this granularity.
    ///
    /// `dialect_name` is matched against `"postgres"`, `"mysql"` and `"sqlite"`;
    /// any other name gets the generic `DATE_TRUNC` / `DATE` forms.
    pub(crate) fn trunc_sql(self, dialect_name: &str, col_quoted: &str) -> String {
        match dialect_name {
            "postgres" => match self {
                DateKind::Year => format!("DATE_TRUNC('year', {col_quoted})::date"),
                DateKind::Month => format!("DATE_TRUNC('month', {col_quoted})::date"),
                DateKind::Day => format!("DATE({col_quoted})"),
            },
            "mysql" => match self {
                DateKind::Year => format!("DATE(DATE_FORMAT({col_quoted}, '%Y-01-01'))"),
                DateKind::Month => format!("DATE(DATE_FORMAT({col_quoted}, '%Y-%m-01'))"),
                DateKind::Day => format!("DATE({col_quoted})"),
            },
            "sqlite" => match self {
                DateKind::Year => format!("date(strftime('%Y-01-01', {col_quoted}))"),
                DateKind::Month => format!("date(strftime('%Y-%m-01', {col_quoted}))"),
                DateKind::Day => format!("date({col_quoted})"),
            },
            _ => match self {
                DateKind::Year => format!("DATE_TRUNC('year', {col_quoted})"),
                DateKind::Month => format!("DATE_TRUNC('month', {col_quoted})"),
                DateKind::Day => format!("DATE({col_quoted})"),
            },
        }
    }
}
/// Queryset over the truncated values of one date/datetime field, built by
/// [`QuerySet::dates`].
pub struct DatesQuerySet<T: Model> {
pub(crate) qs: QuerySet<T>,
/// Field name; checked to be a `Date`/`DateTime` field by `resolve_column`.
pub(crate) field: &'static str,
/// Truncation granularity.
pub(crate) kind: DateKind,
/// Set by `order_desc`; presumably the result sort direction. It is consumed
/// outside this chunk (TODO confirm).
pub(crate) descending: bool,
}
impl<T: Model> DatesQuerySet<T> {
    /// Sets the `descending` flag (it starts out `false`).
    #[must_use]
    pub fn order_desc(self, desc: bool) -> Self {
        Self {
            descending: desc,
            ..self
        }
    }
    /// Returns the column backing `field` after checking it is a `Date` or
    /// `DateTime` field of `T`.
    ///
    /// # Errors
    /// - [`QueryError::UnknownField`] when `field` is not on `T`.
    /// - [`QueryError::TypeMismatch`] when the field is neither `Date` nor
    ///   `DateTime`.
    pub fn resolve_column(&self) -> Result<&'static str, QueryError> {
        let model: &'static ModelSchema = T::SCHEMA;
        let Some(field) = model.field(self.field) else {
            return Err(QueryError::UnknownField {
                model: model.name,
                field: self.field.to_owned(),
            });
        };
        match field.ty {
            crate::core::FieldType::Date | crate::core::FieldType::DateTime => Ok(field.column),
            actual => Err(QueryError::TypeMismatch {
                model: model.name,
                field: self.field.to_owned(),
                expected: crate::core::FieldType::DateTime,
                actual,
            }),
        }
    }
}
impl<T: Model> QuerySet<T> {
    /// Wraps this queryset into a [`DatesQuerySet`] over `field`, truncated to
    /// `kind`, with `descending` unset.
    #[must_use]
    pub fn dates(self, field: &'static str, kind: DateKind) -> DatesQuerySet<T> {
        let descending = false;
        DatesQuerySet {
            qs: self,
            field,
            kind,
            descending,
        }
    }
    /// Wraps this queryset into a [`DateTimesQuerySet`] over `field`, truncated
    /// to `kind`, with `descending` unset.
    #[must_use]
    pub fn datetimes(self, field: &'static str, kind: DateTimeKind) -> DateTimesQuerySet<T> {
        let descending = false;
        DateTimesQuerySet {
            qs: self,
            field,
            kind,
            descending,
        }
    }
}
/// Granularity used by [`QuerySet::datetimes`]: truncate a datetime column to
/// the start of its year, month, day, hour, minute or second.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DateTimeKind {
    Year,
    Month,
    Day,
    Hour,
    Minute,
    Second,
}
impl DateTimeKind {
    /// Renders a dialect-specific SQL expression that truncates `col_quoted` (an
    /// already-quoted column reference) to this granularity.
    ///
    /// - `"postgres"` uses `DATE_TRUNC`.
    /// - `"mysql"` formats with `DATE_FORMAT` and casts back to `DATETIME`.
    /// - Every other dialect uses `strftime`.
    pub(crate) fn trunc_sql(self, dialect_name: &str, col_quoted: &str) -> String {
        // (Postgres DATE_TRUNC unit, MySQL DATE_FORMAT pattern, strftime pattern).
        // MySQL spells minutes/seconds `%i`/`%s`; strftime spells them `%M`/`%S`.
        let (unit, mysql_fmt, strftime_fmt) = match self {
            DateTimeKind::Year => ("year", "%Y-01-01 00:00:00", "%Y-01-01 00:00:00"),
            DateTimeKind::Month => ("month", "%Y-%m-01 00:00:00", "%Y-%m-01 00:00:00"),
            DateTimeKind::Day => ("day", "%Y-%m-%d 00:00:00", "%Y-%m-%d 00:00:00"),
            DateTimeKind::Hour => ("hour", "%Y-%m-%d %H:00:00", "%Y-%m-%d %H:00:00"),
            DateTimeKind::Minute => ("minute", "%Y-%m-%d %H:%i:00", "%Y-%m-%d %H:%M:00"),
            DateTimeKind::Second => ("second", "%Y-%m-%d %H:%i:%s", "%Y-%m-%d %H:%M:%S"),
        };
        match dialect_name {
            "postgres" => format!("DATE_TRUNC('{unit}', {col_quoted})"),
            "mysql" => format!("CAST(DATE_FORMAT({col_quoted}, '{mysql_fmt}') AS DATETIME)"),
            _ => format!("strftime('{strftime_fmt}', {col_quoted})"),
        }
    }
}
/// Queryset over the truncated values of one date/datetime field, built by
/// [`QuerySet::datetimes`].
pub struct DateTimesQuerySet<T: Model> {
pub(crate) qs: QuerySet<T>,
/// Field name; checked to be a `Date`/`DateTime` field by `resolve_column`.
pub(crate) field: &'static str,
/// Truncation granularity.
pub(crate) kind: DateTimeKind,
/// Set by `order_desc`; presumably the result sort direction. It is consumed
/// outside this chunk (TODO confirm).
pub(crate) descending: bool,
}
impl<T: Model> DateTimesQuerySet<T> {
    /// Sets the `descending` flag (it starts out `false`).
    #[must_use]
    pub fn order_desc(self, desc: bool) -> Self {
        Self {
            descending: desc,
            ..self
        }
    }
    /// Returns the column backing `field` after checking it is a `Date` or
    /// `DateTime` field of `T`.
    ///
    /// # Errors
    /// - [`QueryError::UnknownField`] when `field` is not on `T`.
    /// - [`QueryError::TypeMismatch`] when the field is neither `Date` nor
    ///   `DateTime`.
    pub fn resolve_column(&self) -> Result<&'static str, QueryError> {
        let schema: &'static ModelSchema = T::SCHEMA;
        let field = schema
            .field(self.field)
            .ok_or_else(|| QueryError::UnknownField {
                model: schema.name,
                field: self.field.to_owned(),
            })?;
        let is_temporal = matches!(
            field.ty,
            crate::core::FieldType::Date | crate::core::FieldType::DateTime
        );
        if is_temporal {
            Ok(field.column)
        } else {
            Err(QueryError::TypeMismatch {
                model: schema.name,
                field: self.field.to_owned(),
                expected: crate::core::FieldType::DateTime,
                actual: field.ty,
            })
        }
    }
}
/// Compiles `qs` as a SELECT whose projection is exactly `cols`.
///
/// The columns are validated against `T`'s schema before `qs` is compiled.
///
/// # Errors
/// - [`QueryError::EmptyValuesProjection`] when `cols` is empty.
/// - [`QueryError::UnknownField`] for the first entry of `cols` that is not a
///   column of `T`.
/// - Any error from compiling `qs`.
fn compile_values_select<T: Model>(
    qs: QuerySet<T>,
    cols: Vec<&'static str>,
) -> Result<SelectQuery, QueryError> {
    if cols.is_empty() {
        return Err(QueryError::EmptyValuesProjection);
    }
    let model: &'static ModelSchema = T::SCHEMA;
    if let Some(unknown) = cols
        .iter()
        .find(|col| model.field_by_column(col).is_none())
    {
        return Err(QueryError::UnknownField {
            model: model.name,
            field: (*unknown).to_owned(),
        });
    }
    let mut query = qs.compile()?;
    query.projection = Some(cols);
    Ok(query)
}