use std::marker::PhantomData;
use crate::core::{
AggregateExpr, AggregateQuery, Assignment, DeleteQuery, Filter, Model, ModelSchema, Op,
OrderClause, QueryError, SelectQuery, SqlValue, TypedAssignment, TypedExpr, UpdateQuery,
WhereExpr,
};
/// Builder for queries over model `T`.
///
/// Builder methods only record state. Field names are resolved and type-checked against
/// `T::SCHEMA` when one of the `compile*` methods is called.
pub struct QuerySet<T: Model> {
// Filters in insertion order; `resolve_pending` combines them into a single `WhereExpr::And`.
pending: Vec<PendingFilter>,
limit: Option<i64>,
offset: Option<i64>,
// Relation field names to turn into joins (see `lower_select_related`).
select_related: Vec<String>,
// (field name, descending) pairs, lowered by `lower_order_by`.
order_by: Vec<(String, bool)>,
// Marks the type parameter `T` without storing or owning a `T`.
_model: PhantomData<fn() -> T>,
}
/// A filter recorded by `QuerySet`, lowered into a `WhereExpr` node by `resolve_pending`.
enum PendingFilter {
/// String-keyed filter from `filter`/`eq`; field lookup and type check are deferred.
Raw(RawFilter),
/// A single predicate produced by a typed expression passed to `where_`.
Resolved(Filter),
/// A typed expression from `where_` that is not a single predicate.
Expr(WhereExpr),
}
/// An unresolved `field <op> value` filter, checked by `resolve_filter` at compile time.
#[derive(Debug, Clone)]
struct RawFilter {
// Model field name (not column name); looked up via `ModelSchema::field`.
field: String,
op: Op,
value: SqlValue,
}
/// An unresolved `field = value` assignment, checked by `resolve_assignment` at compile time.
#[derive(Debug, Clone)]
struct RawAssignment {
// Model field name (not column name); looked up via `ModelSchema::field`.
field: String,
value: SqlValue,
}
impl<T: Model> Default for QuerySet<T> {
fn default() -> Self {
Self::new()
}
}
impl<T: Model> QuerySet<T> {
#[must_use]
pub fn new() -> Self {
Self {
pending: Vec::new(),
limit: None,
offset: None,
select_related: Vec::new(),
order_by: Vec::new(),
_model: PhantomData,
}
}
#[must_use]
pub fn order_by(mut self, items: &[(&str, bool)]) -> Self {
for (field, desc) in items {
self.order_by.push(((*field).to_owned(), *desc));
}
self
}
#[must_use]
pub fn select_related(mut self, field: impl Into<String>) -> Self {
self.select_related.push(field.into());
self
}
#[must_use]
pub fn limit(mut self, n: i64) -> Self {
self.limit = Some(n);
self
}
#[must_use]
pub fn offset(mut self, n: i64) -> Self {
self.offset = Some(n);
self
}
#[must_use]
pub fn filter(mut self, field: impl Into<String>, op: Op, value: impl Into<SqlValue>) -> Self {
self.pending.push(PendingFilter::Raw(RawFilter {
field: field.into(),
op,
value: value.into(),
}));
self
}
#[must_use]
pub fn eq(self, field: impl Into<String>, value: impl Into<SqlValue>) -> Self {
self.filter(field, Op::Eq, value)
}
#[must_use]
pub fn where_<E: Into<TypedExpr<T>>>(mut self, predicate: E) -> Self {
let expr = predicate.into().into_expr();
match expr {
WhereExpr::Predicate(filter) => {
self.pending.push(PendingFilter::Resolved(filter));
}
other => {
self.pending.push(PendingFilter::Expr(other));
}
}
self
}
pub fn compile(self) -> Result<SelectQuery, QueryError> {
let model: &'static ModelSchema = T::SCHEMA;
let where_clause = resolve_pending(model, self.pending)?;
let joins = lower_select_related(model, &self.select_related)?;
let order_by = lower_order_by(model, &self.order_by)?;
Ok(SelectQuery {
model,
where_clause,
search: None,
joins,
order_by,
limit: self.limit,
offset: self.offset,
})
}
pub fn compile_delete(self) -> Result<DeleteQuery, QueryError> {
let model: &'static ModelSchema = T::SCHEMA;
let where_clause = resolve_pending(model, self.pending)?;
Ok(DeleteQuery {
model,
where_clause,
})
}
#[must_use]
pub fn update(self) -> UpdateBuilder<T> {
UpdateBuilder {
qs: self,
set: Vec::new(),
}
}
#[must_use]
pub fn aggregate(self) -> AggregateBuilder<T> {
AggregateBuilder {
qs: self,
group_by: Vec::new(),
aggregates: Vec::new(),
having: None,
order_by: Vec::new(),
limit: None,
offset: None,
}
}
}
/// Builder for an UPDATE over model `T`, created by `QuerySet::update`.
pub struct UpdateBuilder<T: Model> {
// Supplies the WHERE clause; `compile` reads only its pending filters.
qs: QuerySet<T>,
// Assignments in call order.
set: Vec<PendingAssignment>,
}
/// An assignment recorded by `UpdateBuilder`, lowered in `UpdateBuilder::compile`.
enum PendingAssignment {
/// String-keyed assignment from `set`; resolved and type-checked by `resolve_assignment`.
Raw(RawAssignment),
/// Already-lowered assignment from `set_typed`.
Resolved(Assignment),
}
impl<T: Model> UpdateBuilder<T> {
#[must_use]
pub fn set(mut self, field: impl Into<String>, value: impl Into<SqlValue>) -> Self {
self.set.push(PendingAssignment::Raw(RawAssignment {
field: field.into(),
value: value.into(),
}));
self
}
#[must_use]
pub fn set_typed(mut self, assignment: TypedAssignment<T>) -> Self {
self.set
.push(PendingAssignment::Resolved(assignment.into_assignment()));
self
}
pub fn compile(self) -> Result<UpdateQuery, QueryError> {
let model: &'static ModelSchema = T::SCHEMA;
let assignments = self
.set
.into_iter()
.map(|p| match p {
PendingAssignment::Raw(raw) => resolve_assignment(model, raw),
PendingAssignment::Resolved(assignment) => Ok(assignment),
})
.collect::<Result<Vec<_>, _>>()?;
let where_clause = resolve_pending(model, self.qs.pending)?;
Ok(UpdateQuery {
model,
set: assignments,
where_clause,
})
}
}
/// Maps `(field_name, descending)` pairs onto `OrderClause`s using the model's column names.
///
/// # Errors
/// `QueryError::UnknownField` for the first name that `model` does not define.
fn lower_order_by(
    model: &'static ModelSchema,
    items: &[(String, bool)],
) -> Result<Vec<crate::core::OrderClause>, QueryError> {
    items
        .iter()
        .map(|(field_name, desc)| {
            model
                .field(field_name)
                .map(|field| crate::core::OrderClause {
                    column: field.column,
                    desc: *desc,
                })
                .ok_or_else(|| QueryError::UnknownField {
                    model: model.name,
                    field: field_name.clone(),
                })
        })
        .collect()
}
/// Lowers `select_related` field names into joins against their target models.
///
/// Each name must be a `Relation::Fk` or `Relation::O2O` field on `model`. The target schema
/// is found among the registered `ModelEntry`s by table name, and every scalar column of the
/// target is projected. A name requested more than once yields a single join (first
/// occurrence wins): each join's alias is the field name, so a repeat would duplicate it.
///
/// # Errors
/// `QueryError::SelectRelatedInvalid` when a name is not a field of `model`, is not an
/// Fk/O2O relation, or targets a table that no registered model claims.
fn lower_select_related(
    model: &'static ModelSchema,
    names: &[String],
) -> Result<Vec<crate::core::Join>, QueryError> {
    use crate::core::{inventory, Join, ModelEntry, Relation};
    let mut out: Vec<Join> = Vec::with_capacity(names.len());
    for (idx, name) in names.iter().enumerate() {
        // A repeated request would add a second join with the same alias (`field.name`).
        // Keep only the first one; the inputs are tiny, so a linear scan is fine.
        if names[..idx].contains(name) {
            continue;
        }
        let field = model
            .field(name)
            .ok_or_else(|| QueryError::SelectRelatedInvalid {
                model: model.name,
                field: name.clone(),
                reason: format!("no field `{name}` on this model"),
            })?;
        let (to, on) = match field.relation {
            Some(Relation::Fk { to, on }) | Some(Relation::O2O { to, on }) => (to, on),
            _ => {
                return Err(QueryError::SelectRelatedInvalid {
                    model: model.name,
                    field: name.clone(),
                    reason: "not a `ForeignKey<T>` field".into(),
                });
            }
        };
        // Look the target model up by table name among the entries registered via `inventory`.
        let target = inventory::iter::<ModelEntry>
            .into_iter()
            .find(|e| e.schema.table == to)
            .map(|e| e.schema)
            .ok_or_else(|| QueryError::SelectRelatedInvalid {
                model: model.name,
                field: name.clone(),
                reason: format!(
                    "target table `{to}` is not registered (is the parent's `#[derive(Model)]` linked into the binary?)"
                ),
            })?;
        let project: Vec<&'static str> =
            target.scalar_fields().map(|f| f.column).collect();
        out.push(Join {
            target,
            on_local: field.column,
            on_remote: on,
            alias: field.name,
            project,
        });
    }
    Ok(out)
}
/// Lowers recorded filters, in insertion order, into a single `WhereExpr::And`.
///
/// An empty input produces `WhereExpr::And` with no children.
///
/// # Errors
/// Stops at the first raw filter that `resolve_filter` rejects.
fn resolve_pending(
    model: &'static ModelSchema,
    pending: Vec<PendingFilter>,
) -> Result<WhereExpr, QueryError> {
    let nodes = pending
        .into_iter()
        .map(|entry| match entry {
            PendingFilter::Raw(raw) => resolve_filter(model, raw).map(WhereExpr::Predicate),
            PendingFilter::Resolved(filter) => Ok(WhereExpr::Predicate(filter)),
            PendingFilter::Expr(expr) => Ok(expr),
        })
        .collect::<Result<Vec<WhereExpr>, QueryError>>()?;
    Ok(WhereExpr::And(nodes))
}
/// Resolves a string-keyed filter to a column-level `Filter`.
///
/// The value's type must match the field's type, with two exceptions: `Op::IsNull` and
/// `Op::In` skip the check, and so does any value whose `field_type()` is `None`.
///
/// # Errors
/// `QueryError::UnknownField` if the field does not exist, and `QueryError::TypeMismatch`
/// if the value type differs from the field type.
fn resolve_filter(model: &'static ModelSchema, raw: RawFilter) -> Result<Filter, QueryError> {
    let field = model
        .field(&raw.field)
        .ok_or_else(|| QueryError::UnknownField {
            model: model.name,
            field: raw.field.clone(),
        })?;
    let mismatch = if matches!(raw.op, Op::IsNull | Op::In) {
        None
    } else {
        raw.value.field_type().filter(|actual| *actual != field.ty)
    };
    if let Some(actual) = mismatch {
        return Err(QueryError::TypeMismatch {
            model: model.name,
            field: raw.field,
            expected: field.ty,
            actual,
        });
    }
    Ok(Filter {
        column: field.column,
        op: raw.op,
        value: raw.value,
    })
}
/// Resolves a string-keyed `field = value` assignment to a column-level `Assignment`.
///
/// Values whose `field_type()` is `None` skip the type check.
///
/// # Errors
/// `QueryError::UnknownField` if the field does not exist, and `QueryError::TypeMismatch`
/// if the value type differs from the field type.
fn resolve_assignment(
    model: &'static ModelSchema,
    raw: RawAssignment,
) -> Result<Assignment, QueryError> {
    let field = model
        .field(&raw.field)
        .ok_or_else(|| QueryError::UnknownField {
            model: model.name,
            field: raw.field.clone(),
        })?;
    if let Some(actual) = raw.value.field_type().filter(|ty| *ty != field.ty) {
        return Err(QueryError::TypeMismatch {
            model: model.name,
            field: raw.field,
            expected: field.ty,
            actual,
        });
    }
    Ok(Assignment {
        column: field.column,
        value: raw.value,
    })
}
/// Builder for an aggregate (GROUP BY) query over model `T`, created by `QuerySet::aggregate`.
///
/// `compile` reads only the source query set's pending filters. Its limit, offset,
/// ordering and `select_related` are replaced by this builder's own settings.
pub struct AggregateBuilder<T: Model> {
// Source query set; supplies the WHERE clause.
qs: QuerySet<T>,
// Column names passed straight through to the query; not checked against the schema here.
group_by: Vec<&'static str>,
// (alias, aggregate expression) pairs in call order.
aggregates: Vec<(&'static str, AggregateExpr)>,
// Post-aggregation predicate; later `having` calls are merged in with `push_and`.
having: Option<WhereExpr>,
// (column, descending) pairs; not checked against the schema here.
order_by: Vec<(&'static str, bool)>,
limit: Option<i64>,
offset: Option<i64>,
}
impl<T: Model> AggregateBuilder<T> {
#[must_use]
pub fn group_by(mut self, column: &'static str) -> Self {
self.group_by.push(column);
self
}
#[must_use]
pub fn annotate(mut self, alias: &'static str, expr: AggregateExpr) -> Self {
self.aggregates.push((alias, expr));
self
}
#[must_use]
pub fn having<E: Into<crate::core::TypedExpr<T>>>(mut self, predicate: E) -> Self {
let expr = predicate.into().into_expr();
match self.having {
None => self.having = Some(expr),
Some(ref mut existing) => existing.push_and(expr),
}
self
}
#[must_use]
pub fn order_by(mut self, items: &[(&'static str, bool)]) -> Self {
self.order_by.extend_from_slice(items);
self
}
#[must_use]
pub fn limit(mut self, n: i64) -> Self {
self.limit = Some(n);
self
}
#[must_use]
pub fn offset(mut self, n: i64) -> Self {
self.offset = Some(n);
self
}
pub fn compile(self) -> Result<AggregateQuery, QueryError> {
let model = T::SCHEMA;
let where_clause = resolve_pending(model, self.qs.pending)?;
let order_by = self.order_by.into_iter().map(|(col, desc)| OrderClause { column: col, desc }).collect();
Ok(AggregateQuery {
model,
where_clause,
group_by: self.group_by,
aggregates: self.aggregates,
having: self.having,
order_by,
limit: self.limit,
offset: self.offset,
})
}
}