use std::marker::PhantomData;
use crate::core::{
Assignment, DeleteQuery, Filter, Model, ModelSchema, Op, QueryError, SelectQuery, SqlValue,
TypedAssignment, TypedExpr, UpdateQuery, WhereExpr,
};
/// Query builder bound to the model `T`.
///
/// Accumulates filters plus an optional limit/offset. Name-based filters are
/// only looked up in `T::SCHEMA` when one of the `compile*` methods runs.
pub struct QuerySet<T: Model> {
/// Filters in the order they were added; raw (name-based) entries are resolved at compile time.
pending: Vec<PendingFilter>,
/// Row limit copied into `SelectQuery`; `None` when never set.
limit: Option<i64>,
/// Row offset copied into `SelectQuery`; `None` when never set.
offset: Option<i64>,
/// Zero-sized marker tying the builder to `T` without storing a `T` value.
_model: PhantomData<fn() -> T>,
}
/// One filter entry queued on a `QuerySet`, in whatever form it was supplied.
enum PendingFilter {
/// Filter given by field name; still has to be resolved against the model schema.
Raw(RawFilter),
/// A `Filter` that is used as-is, with no schema lookup.
Resolved(Filter),
/// A `WhereExpr` that is used as-is, with no schema lookup.
Expr(WhereExpr),
}
/// A filter as written by the caller, before the field name is resolved to a column.
#[derive(Debug, Clone)]
struct RawFilter {
/// Field name exactly as passed to `QuerySet::filter`.
field: String,
/// Comparison operator to apply.
op: Op,
/// Right-hand value of the comparison.
value: SqlValue,
}
/// An assignment as written by the caller, before the field name is resolved to a column.
#[derive(Debug, Clone)]
struct RawAssignment {
/// Field name exactly as passed to `UpdateBuilder::set`.
field: String,
/// Value to store in the field.
value: SqlValue,
}
impl<T: Model> Default for QuerySet<T> {
fn default() -> Self {
Self::new()
}
}
impl<T: Model> QuerySet<T> {
    /// Creates an empty query set: no filters, no limit and no offset.
    #[must_use]
    pub fn new() -> Self {
        let pending = Vec::new();
        Self {
            pending,
            limit: None,
            offset: None,
            _model: PhantomData,
        }
    }

    /// Sets the row limit, replacing any previously set limit.
    ///
    /// The value is stored as given; no range check is performed here.
    #[must_use]
    pub fn limit(self, n: i64) -> Self {
        Self {
            limit: Some(n),
            ..self
        }
    }

    /// Sets the row offset, replacing any previously set offset.
    ///
    /// The value is stored as given; no range check is performed here.
    #[must_use]
    pub fn offset(self, n: i64) -> Self {
        Self {
            offset: Some(n),
            ..self
        }
    }

    /// Queues a filter on `field` by name.
    ///
    /// The field name is not validated here; lookup and type checking are
    /// deferred until one of the `compile*` methods runs.
    #[must_use]
    pub fn filter(mut self, field: impl Into<String>, op: Op, value: impl Into<SqlValue>) -> Self {
        let raw = RawFilter {
            field: field.into(),
            op,
            value: value.into(),
        };
        self.pending.push(PendingFilter::Raw(raw));
        self
    }

    /// Shorthand for [`filter`](Self::filter) with [`Op::Eq`].
    #[must_use]
    pub fn eq(self, field: impl Into<String>, value: impl Into<SqlValue>) -> Self {
        Self::filter(self, field, Op::Eq, value)
    }

    /// Queues a typed predicate.
    ///
    /// A plain predicate is stored as an already-resolved filter; any other
    /// expression is stored whole. Neither goes through the schema lookup.
    #[must_use]
    pub fn where_<E: Into<TypedExpr<T>>>(mut self, predicate: E) -> Self {
        let typed: TypedExpr<T> = predicate.into();
        let entry = match typed.into_expr() {
            WhereExpr::Predicate(filter) => PendingFilter::Resolved(filter),
            other => PendingFilter::Expr(other),
        };
        self.pending.push(entry);
        self
    }

    /// Compiles the accumulated state into a [`SelectQuery`].
    ///
    /// `search` is left as `None` and `joins` empty; limit and offset are
    /// carried over unchanged.
    ///
    /// # Errors
    ///
    /// Returns the first [`QueryError`] produced while resolving a name-based
    /// filter against `T::SCHEMA`.
    pub fn compile(self) -> Result<SelectQuery, QueryError> {
        let Self {
            pending,
            limit,
            offset,
            ..
        } = self;
        let model: &'static ModelSchema = T::SCHEMA;
        resolve_pending(model, pending).map(|where_clause| SelectQuery {
            model,
            where_clause,
            search: None,
            joins: Vec::new(),
            limit,
            offset,
        })
    }

    /// Compiles the accumulated filters into a [`DeleteQuery`].
    ///
    /// Only the filters are used: any limit or offset set on this query set
    /// is not carried into the result.
    ///
    /// # Errors
    ///
    /// Returns the first [`QueryError`] produced while resolving a name-based
    /// filter against `T::SCHEMA`.
    pub fn compile_delete(self) -> Result<DeleteQuery, QueryError> {
        let model: &'static ModelSchema = T::SCHEMA;
        resolve_pending(model, self.pending).map(|where_clause| DeleteQuery {
            model,
            where_clause,
        })
    }

    /// Turns this query set into an [`UpdateBuilder`] with no assignments yet.
    #[must_use]
    pub fn update(self) -> UpdateBuilder<T> {
        UpdateBuilder {
            set: Vec::new(),
            qs: self,
        }
    }
}
/// Builder for an update over the rows selected by a `QuerySet`.
pub struct UpdateBuilder<T: Model> {
/// Query set whose pending filters become the update's WHERE clause.
qs: QuerySet<T>,
/// Assignments in the order they were added.
set: Vec<PendingAssignment>,
}
/// One assignment queued on an `UpdateBuilder`, in whatever form it was supplied.
enum PendingAssignment {
/// Assignment given by field name; still has to be resolved against the model schema.
Raw(RawAssignment),
/// An `Assignment` that is used as-is, with no schema lookup.
Resolved(Assignment),
}
impl<T: Model> UpdateBuilder<T> {
    /// Queues an assignment to `field` by name.
    ///
    /// The field name is not validated here; lookup and type checking are
    /// deferred until [`compile`](Self::compile).
    #[must_use]
    pub fn set(mut self, field: impl Into<String>, value: impl Into<SqlValue>) -> Self {
        let raw = RawAssignment {
            field: field.into(),
            value: value.into(),
        };
        self.set.push(PendingAssignment::Raw(raw));
        self
    }

    /// Queues a typed assignment, which is stored already resolved.
    #[must_use]
    pub fn set_typed(mut self, assignment: TypedAssignment<T>) -> Self {
        let resolved = assignment.into_assignment();
        self.set.push(PendingAssignment::Resolved(resolved));
        self
    }

    /// Compiles the queued assignments and the query set's filters into an
    /// [`UpdateQuery`].
    ///
    /// Assignments are resolved first, in insertion order, then the filters.
    /// The query set's limit and offset are not carried into the result, and
    /// no check is made that at least one assignment was queued.
    ///
    /// # Errors
    ///
    /// Returns the first [`QueryError`] produced while resolving a name-based
    /// assignment or filter against `T::SCHEMA`.
    pub fn compile(self) -> Result<UpdateQuery, QueryError> {
        let Self {
            qs,
            set: queued,
        } = self;
        let model: &'static ModelSchema = T::SCHEMA;

        let mut assignments = Vec::with_capacity(queued.len());
        for entry in queued {
            let assignment = match entry {
                PendingAssignment::Raw(raw) => resolve_assignment(model, raw)?,
                PendingAssignment::Resolved(ready) => ready,
            };
            assignments.push(assignment);
        }

        let where_clause = resolve_pending(model, qs.pending)?;
        Ok(UpdateQuery {
            model,
            set: assignments,
            where_clause,
        })
    }
}
/// Turns the queued filters into a single `WhereExpr::And` node.
///
/// Entries keep their insertion order. Raw filters are resolved against
/// `model`; resolved filters and whole expressions pass through untouched.
/// The result is always an `And`, even for zero or one entries.
///
/// Stops at, and returns, the first error from resolving a raw filter.
fn resolve_pending(
    model: &'static ModelSchema,
    pending: Vec<PendingFilter>,
) -> Result<WhereExpr, QueryError> {
    pending
        .into_iter()
        .map(|entry| match entry {
            PendingFilter::Raw(raw) => resolve_filter(model, raw).map(WhereExpr::Predicate),
            PendingFilter::Resolved(filter) => Ok(WhereExpr::Predicate(filter)),
            PendingFilter::Expr(expr) => Ok(expr),
        })
        .collect::<Result<Vec<WhereExpr>, QueryError>>()
        .map(WhereExpr::And)
}
fn resolve_filter(model: &'static ModelSchema, raw: RawFilter) -> Result<Filter, QueryError> {
let field = model
.field(&raw.field)
.ok_or_else(|| QueryError::UnknownField {
model: model.name,
field: raw.field.clone(),
})?;
let skip_type_check = matches!(raw.op, Op::IsNull | Op::In);
if !skip_type_check {
if let Some(value_ty) = raw.value.field_type() {
if value_ty != field.ty {
return Err(QueryError::TypeMismatch {
model: model.name,
field: raw.field,
expected: field.ty,
actual: value_ty,
});
}
}
}
Ok(Filter {
column: field.column,
op: raw.op,
value: raw.value,
})
}
fn resolve_assignment(
model: &'static ModelSchema,
raw: RawAssignment,
) -> Result<Assignment, QueryError> {
let field = model
.field(&raw.field)
.ok_or_else(|| QueryError::UnknownField {
model: model.name,
field: raw.field.clone(),
})?;
if let Some(value_ty) = raw.value.field_type() {
if value_ty != field.ty {
return Err(QueryError::TypeMismatch {
model: model.name,
field: raw.field,
expected: field.ty,
actual: value_ty,
});
}
}
Ok(Assignment {
column: field.column,
value: raw.value,
})
}