use super::expr::Expr;
use super::{validate::validate_value, ModelSchema, QueryError, SqlValue};
/// Operator applied between a left-hand column/expression and a right-hand side.
///
/// Used by [`Filter`], [`ColumnFilter`] and [`WhereExpr::ExprCompare`]. The SQL
/// each variant renders to is decided outside this file; the group comments below
/// are inferred from the variant names only.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Op {
    // Scalar comparisons.
    Eq,
    Ne,
    Lt,
    Lte,
    Gt,
    Gte,
    // Membership tests.
    In,
    NotIn,
    // Pattern matching; the `I*` variants are presumably case-insensitive.
    Like,
    NotLike,
    ILike,
    NotILike,
    Between,
    IsNull,
    // NULL-safe (in)equality, judging by the names.
    IsDistinctFrom,
    IsNotDistinctFrom,
    // JSON operators (names suggest PostgreSQL jsonb operators — confirm in the renderer).
    JsonContains,
    JsonContainedBy,
    JsonHasKey,
    JsonHasAnyKey,
    JsonHasAllKeys,
    // Regular-expression matching.
    Regex,
    NotRegex,
    IRegex,
    NotIRegex,
    // Similarity / text-search operators (presumably pg_trgm and full-text search).
    TrigramSimilar,
    TrigramWordSimilar,
    Search,
    // Array operators.
    ArrayContains,
    ArrayContainedBy,
    ArrayOverlap,
    // Range operators.
    RangeContains,
    RangeContainedBy,
    RangeOverlap,
    RangeStrictlyLeft,
    RangeStrictlyRight,
    RangeAdjacent,
}
/// A single predicate comparing a model column against a literal [`SqlValue`].
///
/// [`WhereExpr::validate`] checks that `column` exists on the target model; the
/// `value` is not type-checked there.
#[derive(Debug, Clone, PartialEq)]
pub struct Filter {
    /// Column on the query's model (left-hand side).
    pub column: &'static str,
    /// Operator applied between `column` and `value`.
    pub op: Op,
    /// Literal right-hand side.
    pub value: SqlValue,
}
/// A predicate comparing a model column against an arbitrary expression, e.g.
/// another column of the same model.
///
/// [`WhereExpr::validate`] checks `column` and every plain column referenced by `rhs`.
#[derive(Debug, Clone, PartialEq)]
pub struct ColumnFilter {
    /// Column on the query's model (left-hand side).
    pub column: &'static str,
    /// Operator applied between `column` and `rhs`.
    pub op: Op,
    /// Right-hand side expression.
    pub rhs: Expr,
}
/// Boolean condition tree used for `WHERE`-style filtering.
///
/// The default value is an empty `And`, which [`WhereExpr::is_empty`] reports as empty.
#[derive(Debug, Clone, PartialEq)]
pub enum WhereExpr {
    /// Column compared against a literal value.
    Predicate(Filter),
    /// Column compared against an expression.
    ColumnCompare(ColumnFilter),
    /// Conjunction of all children.
    And(Vec<WhereExpr>),
    /// Disjunction of the children.
    Or(Vec<WhereExpr>),
    /// Negation of the child.
    Not(Box<WhereExpr>),
    /// Exclusive-or over the children (exact SQL rendering is defined elsewhere).
    Xor(Vec<WhereExpr>),
    /// Existence test on a subquery; the subquery is not checked by `validate`.
    Exists(Box<SelectQuery>),
    /// Non-existence test on a subquery; the subquery is not checked by `validate`.
    NotExists(Box<SelectQuery>),
    /// `column` tested for membership in a subquery's results.
    InSubquery {
        /// Column on the outer query's model.
        column: &'static str,
        /// Selects the negated (not-in) form.
        negated: bool,
        /// Subquery producing the candidate values.
        subquery: Box<SelectQuery>,
    },
    /// Expression compared against expression; operands are not checked by `validate`.
    ExprCompare { lhs: Expr, op: Op, rhs: Expr },
}
impl WhereExpr {
    /// Returns `true` when this node is an `And` with no children (the
    /// [`Default`] value).
    ///
    /// Only the top-level node is inspected; nested empty groups do not count.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        if let Self::And(children) = self {
            children.is_empty()
        } else {
            false
        }
    }

    /// Builds a conjunction whose children are `filters` wrapped as
    /// [`WhereExpr::Predicate`], in input order.
    #[must_use]
    pub fn and_predicates(filters: Vec<Filter>) -> Self {
        let children: Vec<Self> = filters.into_iter().map(Self::from).collect();
        Self::And(children)
    }

    /// ANDs `child` onto this expression.
    ///
    /// An existing `And` simply gains `child` as its last element. Any other node is
    /// replaced by `And([previous, child])`.
    pub fn push_and(&mut self, child: Self) {
        if let Self::And(children) = self {
            children.push(child);
            return;
        }
        let previous = std::mem::take(self);
        *self = Self::And(vec![previous, child]);
    }

    /// Returns the predicates of a "flat" conjunction.
    ///
    /// A single `Predicate` yields a one-element list. An `And` whose children are
    /// all `Predicate`s yields them in order. Any other shape, including an `And`
    /// containing a nested group, yields `None`.
    #[must_use]
    pub fn as_flat_and(&self) -> Option<Vec<&Filter>> {
        let children = match self {
            Self::Predicate(filter) => return Some(vec![filter]),
            Self::And(children) => children,
            Self::ColumnCompare(_)
            | Self::Or(_)
            | Self::Xor(_)
            | Self::Not(_)
            | Self::Exists(_)
            | Self::NotExists(_)
            | Self::InSubquery { .. }
            | Self::ExprCompare { .. } => return None,
        };
        // Collecting into `Option` stops at the first non-predicate child.
        children
            .iter()
            .map(|child| match child {
                Self::Predicate(filter) => Some(filter),
                _ => None,
            })
            .collect()
    }

    /// Checks that every column this tree names directly exists on `model`.
    ///
    /// `ColumnCompare` right-hand sides are checked via `validate_expr_columns`.
    /// The following are accepted without inspection:
    /// - the subqueries of `Exists`, `NotExists` and `InSubquery`;
    /// - the operands of `ExprCompare`.
    ///
    /// # Errors
    ///
    /// Returns [`QueryError::UnknownField`] for the first unknown column found.
    pub fn validate(&self, model: &'static ModelSchema) -> Result<(), QueryError> {
        /// Fails with `UnknownField` unless `column` is a column of `model`.
        fn require_column(
            model: &'static ModelSchema,
            column: &'static str,
        ) -> Result<(), QueryError> {
            match model.field_by_column(column) {
                Some(_) => Ok(()),
                None => Err(QueryError::UnknownField {
                    model: model.name,
                    field: column.to_owned(),
                }),
            }
        }

        match self {
            Self::Predicate(filter) => require_column(model, filter.column),
            Self::ColumnCompare(compare) => {
                require_column(model, compare.column)?;
                validate_expr_columns(model, &compare.rhs)
            }
            Self::And(children) | Self::Or(children) | Self::Xor(children) => children
                .iter()
                .try_for_each(|child| child.validate(model)),
            Self::Not(inner) => inner.validate(model),
            Self::InSubquery { column, .. } => require_column(model, *column),
            Self::Exists(_) | Self::NotExists(_) | Self::ExprCompare { .. } => Ok(()),
        }
    }
}
/// Recursively checks that every plain column referenced by `expr` exists on `model`.
///
/// The walk descends into:
/// - binary operations and function arguments;
/// - `CASE` branches, whose conditions are checked with [`WhereExpr::validate`];
/// - window expressions (`partition_by`, `order_by`, then `args`).
///
/// Literals, subqueries, outer references, aliased columns and aggregates are
/// accepted without inspection.
///
/// # Errors
///
/// Returns [`QueryError::UnknownField`] for the first column not found on `model`.
fn validate_expr_columns(model: &'static ModelSchema, expr: &Expr) -> Result<(), QueryError> {
    match expr {
        // Nodes that carry no column of `model` to check.
        Expr::Literal(_)
        | Expr::Aggregate(_)
        | Expr::Subquery(_)
        | Expr::OuterRef(_)
        | Expr::AliasedColumn { .. } => Ok(()),
        Expr::Column(name) => model
            .field_by_column(name)
            .map(|_| ())
            .ok_or_else(|| QueryError::UnknownField {
                model: model.name,
                field: (*name).to_owned(),
            }),
        Expr::BinOp { left, right, .. } => {
            validate_expr_columns(model, left).and_then(|()| validate_expr_columns(model, right))
        }
        Expr::Function { args, .. } => args
            .iter()
            .try_for_each(|arg| validate_expr_columns(model, arg)),
        Expr::Case { branches, default } => {
            branches
                .iter()
                .try_for_each(|branch| -> Result<(), QueryError> {
                    branch.condition.validate(model)?;
                    validate_expr_columns(model, &branch.then)
                })?;
            default
                .iter()
                .try_for_each(|fallback| validate_expr_columns(model, fallback))
        }
        Expr::Window(window) => {
            window.partition_by.iter().try_for_each(|col| {
                model
                    .field_by_column(col)
                    .map(|_| ())
                    .ok_or_else(|| QueryError::UnknownField {
                        model: model.name,
                        field: (*col).to_owned(),
                    })
            })?;
            window.order_by.iter().try_for_each(|item| {
                model
                    .field_by_column(item.column)
                    .map(|_| ())
                    .ok_or_else(|| QueryError::UnknownField {
                        model: model.name,
                        field: item.column.to_owned(),
                    })
            })?;
            window
                .args
                .iter()
                .try_for_each(|arg| validate_expr_columns(model, arg))
        }
    }
}
impl Default for WhereExpr {
fn default() -> Self {
Self::And(Vec::new())
}
}
impl From<Filter> for WhereExpr {
fn from(f: Filter) -> Self {
Self::Predicate(f)
}
}
/// A `SELECT` over a single model, with optional joins, ordering, paging,
/// locking and compound (set-operation) branches.
///
/// `PartialEq` is implemented by hand because `model` is compared by pointer
/// identity rather than by value.
#[derive(Debug, Clone)]
pub struct SelectQuery {
    /// Schema of the model being selected from.
    pub model: &'static ModelSchema,
    /// Filter condition; the default (empty `And`) presumably means no WHERE clause.
    pub where_clause: WhereExpr,
    /// Optional search clause.
    pub search: Option<SearchClause>,
    /// Joined models.
    pub joins: Vec<Join>,
    /// Ordering terms, in priority order.
    pub order_by: Vec<OrderItem>,
    /// Maximum row count; `None` presumably means unlimited.
    pub limit: Option<i64>,
    /// Number of rows to skip; `None` presumably means no offset.
    pub offset: Option<i64>,
    /// Row-locking options; `None` presumably means no locking clause.
    pub lock_mode: Option<LockMode>,
    /// Queries combined with this one via set operations.
    pub compound: Vec<CompoundBranch>,
    /// Explicit column list; `None` presumably selects every model column.
    pub projection: Option<Vec<&'static str>>,
}
/// One additional query combined with a [`SelectQuery`] via a [`SetOp`].
#[derive(Debug, Clone)]
pub struct CompoundBranch {
    /// Set operation joining this branch to the preceding query.
    pub op: SetOp,
    /// The branch's query.
    pub query: Box<SelectQuery>,
}
/// Set operation used by a [`CompoundBranch`] to combine queries.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SetOp {
    /// Rendered as `UNION`.
    Union,
    /// Rendered as `UNION ALL`.
    UnionAll,
    /// Rendered as `INTERSECT`.
    Intersection,
    /// Rendered as `EXCEPT`.
    Difference,
}
impl SetOp {
    /// Returns the SQL keyword for this operation.
    ///
    /// This is a `const fn`, so the keyword can also be computed in constant
    /// contexts, e.g. `const K: &str = SetOp::Union.keyword();`.
    #[must_use]
    pub const fn keyword(self) -> &'static str {
        match self {
            Self::Union => "UNION",
            Self::UnionAll => "UNION ALL",
            Self::Intersection => "INTERSECT",
            Self::Difference => "EXCEPT",
        }
    }
}
/// Branches are equal when both the set operation and the branch query match.
impl PartialEq for CompoundBranch {
    fn eq(&self, other: &Self) -> bool {
        // Exhaustive destructuring: a new field must be handled here to compile.
        let Self { op, query } = self;
        op == &other.op && query == &other.query
    }
}
/// Row-locking options for a [`SelectQuery`].
///
/// Marked `#[non_exhaustive]`, so code outside this crate cannot build it with a
/// struct literal and should start from `LockMode::default()` (all flags off, no
/// `of` tables). Field meanings below are inferred from their names.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
#[non_exhaustive]
pub struct LockMode {
    /// Presumably selects the weaker "no key" update lock.
    pub no_key: bool,
    /// Presumably skips rows that are already locked.
    pub skip_locked: bool,
    /// Presumably fails immediately instead of waiting for a lock.
    pub nowait: bool,
    /// Presumably restricts locking to these tables/aliases.
    pub of: Vec<&'static str>,
}
/// Structural equality, except that `model` is compared by pointer identity.
impl PartialEq for SelectQuery {
    fn eq(&self, other: &Self) -> bool {
        // Exhaustive destructuring: adding a field to `SelectQuery` fails to
        // compile here until the comparison is updated.
        let Self {
            model,
            where_clause,
            search,
            joins,
            order_by,
            limit,
            offset,
            lock_mode,
            compound,
            projection,
        } = self;
        // `*model` is the `&'static ModelSchema` itself, not the address of the field.
        std::ptr::eq(*model, other.model)
            && where_clause == &other.where_clause
            && search == &other.search
            && joins == &other.joins
            && order_by == &other.order_by
            && limit == &other.limit
            && offset == &other.offset
            && lock_mode == &other.lock_mode
            && compound == &other.compound
            && projection == &other.projection
    }
}
/// Structural equality, except that `target` is compared by pointer identity.
impl PartialEq for Join {
    fn eq(&self, other: &Self) -> bool {
        // Exhaustive destructuring keeps this impl in sync with the struct.
        let Self {
            target,
            alias,
            kind,
            on,
            project,
        } = self;
        std::ptr::eq(*target, other.target)
            && alias == &other.alias
            && kind == &other.kind
            && on == &other.on
            && project == &other.project
    }
}
/// Simple single-column ordering.
///
/// Converts into [`OrderItem::Column`] with [`NullsOrder::Default`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OrderClause {
    /// Column to order by.
    pub column: &'static str,
    /// `true` for descending order.
    pub desc: bool,
}
/// Placement of NULLs within an ordering term.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum NullsOrder {
    /// No explicit placement (presumably the database default); this is `Default`.
    #[default]
    Default,
    /// NULLs sort before non-NULL values.
    First,
    /// NULLs sort after non-NULL values.
    Last,
}
/// One term of an `ORDER BY` list.
#[derive(Debug, Clone, PartialEq)]
pub enum OrderItem {
    /// Order by a model column.
    Column {
        column: &'static str,
        desc: bool,
        nulls: NullsOrder,
    },
    /// Order by an arbitrary expression.
    Expr {
        expr: Expr,
        desc: bool,
        nulls: NullsOrder,
    },
    /// Random ordering; reports ascending and default NULL placement.
    Random,
}
impl From<OrderClause> for OrderItem {
fn from(c: OrderClause) -> Self {
Self::Column {
column: c.column,
desc: c.desc,
nulls: NullsOrder::Default,
}
}
}
impl OrderItem {
    /// Column term with [`NullsOrder::Default`].
    #[must_use]
    pub fn column(column: &'static str, desc: bool) -> Self {
        Self::column_with_nulls(column, desc, NullsOrder::Default)
    }

    /// Column term with explicit NULL placement.
    #[must_use]
    pub fn column_with_nulls(column: &'static str, desc: bool, nulls: NullsOrder) -> Self {
        Self::Column {
            nulls,
            desc,
            column,
        }
    }

    /// Expression term with [`NullsOrder::Default`].
    #[must_use]
    pub fn expr(expr: Expr, desc: bool) -> Self {
        Self::expr_with_nulls(expr, desc, NullsOrder::Default)
    }

    /// Expression term with explicit NULL placement.
    #[must_use]
    pub fn expr_with_nulls(expr: Expr, desc: bool, nulls: NullsOrder) -> Self {
        Self::Expr {
            nulls,
            desc,
            expr,
        }
    }

    /// Random-order term.
    #[must_use]
    pub fn random() -> Self {
        Self::Random
    }

    /// The ordered column name, for `Column` terms only.
    #[must_use]
    pub fn column_name(&self) -> Option<&'static str> {
        match *self {
            Self::Column { column, .. } => Some(column),
            Self::Expr { .. } | Self::Random => None,
        }
    }

    /// Whether the term sorts descending; always `false` for `Random`.
    #[must_use]
    pub fn is_desc(&self) -> bool {
        matches!(
            self,
            Self::Column { desc: true, .. } | Self::Expr { desc: true, .. }
        )
    }

    /// The term's NULL placement; `Random` reports [`NullsOrder::Default`].
    #[must_use]
    pub fn nulls_order(&self) -> NullsOrder {
        match *self {
            Self::Column { nulls, .. } | Self::Expr { nulls, .. } => nulls,
            Self::Random => NullsOrder::Default,
        }
    }
}
/// Kind of join used by a [`Join`]; `Left` is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum JoinKind {
    Inner,
    #[default]
    Left,
    Right,
    Full,
}
/// A join from the query's model to another model.
///
/// `PartialEq` is implemented by hand because `target` is compared by pointer identity.
#[derive(Debug, Clone)]
pub struct Join {
    /// Schema of the joined model.
    pub target: &'static ModelSchema,
    /// Alias under which the joined model is referenced.
    pub alias: &'static str,
    /// Join kind (defaults to `Left` via `JoinKind::default`).
    pub kind: JoinKind,
    /// Join condition.
    pub on: WhereExpr,
    /// Columns of the joined model to project (presumably into the result row).
    pub project: Vec<&'static str>,
}
/// Search over one or more columns (presumably full-text search).
#[derive(Debug, Clone, PartialEq)]
pub struct SearchClause {
    /// Columns to search.
    pub columns: Vec<&'static str>,
    /// User-supplied search text.
    pub query: String,
}
/// Behaviour when an insert conflicts with an existing row (presumably rendered
/// as an `ON CONFLICT` clause).
#[derive(Debug, Clone)]
pub enum ConflictClause {
    /// Skip conflicting rows.
    DoNothing,
    /// Update the existing row instead of inserting.
    DoUpdate {
        /// Conflict-target columns.
        target: Vec<&'static str>,
        /// Columns overwritten on conflict.
        update_columns: Vec<&'static str>,
    },
}
/// Single-row `INSERT`.
///
/// `columns` and `values` are parallel: `values[i]` is the value for `columns[i]`.
#[derive(Debug, Clone)]
pub struct InsertQuery {
    /// Target model.
    pub model: &'static ModelSchema,
    /// Columns being inserted.
    pub columns: Vec<&'static str>,
    /// Values, one per entry in `columns`.
    pub values: Vec<SqlValue>,
    /// Columns to return after the insert.
    pub returning: Vec<&'static str>,
    /// Optional conflict handling.
    pub on_conflict: Option<ConflictClause>,
}
impl InsertQuery {
    /// Checks each `(column, value)` pair against the model schema.
    ///
    /// `columns` and `values` are walked in lockstep. If their lengths differ, only
    /// the overlapping prefix is checked.
    ///
    /// # Errors
    ///
    /// For the first failing pair, returns either:
    /// - [`QueryError::UnknownField`] if the column is not on the model, or
    /// - the error from `validate_value` if it rejects the value.
    pub fn validate(&self) -> Result<(), QueryError> {
        let model = self.model;
        self.columns
            .iter()
            .zip(&self.values)
            .try_for_each(|(column, value)| -> Result<(), QueryError> {
                let field = match model.field_by_column(column) {
                    Some(field) => field,
                    None => {
                        return Err(QueryError::UnknownField {
                            model: model.name,
                            field: (*column).to_owned(),
                        })
                    }
                };
                validate_value(model.name, field, value)?;
                Ok(())
            })
    }
}
/// Multi-row `INSERT`.
///
/// Each entry of `rows` is parallel to `columns`: `row[i]` is the value for `columns[i]`.
#[derive(Debug, Clone)]
pub struct BulkInsertQuery {
    /// Target model.
    pub model: &'static ModelSchema,
    /// Columns being inserted (shared by every row).
    pub columns: Vec<&'static str>,
    /// Row values, each laid out like `columns`.
    pub rows: Vec<Vec<SqlValue>>,
    /// Columns to return after the insert.
    pub returning: Vec<&'static str>,
    /// Optional conflict handling.
    pub on_conflict: Option<ConflictClause>,
}
impl BulkInsertQuery {
    /// Checks every row of the bulk insert against the model schema.
    ///
    /// Column names are resolved to schema fields once, before any row is examined:
    /// - the lookup is not repeated for every row;
    /// - an unknown column is reported even when `rows` is empty or a row is shorter
    ///   than `columns`.
    ///
    /// Each row is then walked in lockstep with the resolved fields. A row whose
    /// length differs from `columns` is only checked on the overlapping prefix.
    ///
    /// # Errors
    ///
    /// - [`QueryError::UnknownField`] for the first column not found on the model.
    ///   This is checked before any value.
    /// - Otherwise, the error from `validate_value` for the first rejected value,
    ///   in row-major order.
    pub fn validate(&self) -> Result<(), QueryError> {
        // Loop-invariant: resolve each column's field exactly once.
        let fields = self
            .columns
            .iter()
            .map(|column| {
                self.model
                    .field_by_column(column)
                    .ok_or_else(|| QueryError::UnknownField {
                        model: self.model.name,
                        field: (*column).to_owned(),
                    })
            })
            .collect::<Result<Vec<_>, _>>()?;

        for row in &self.rows {
            for (field, value) in fields.iter().zip(row) {
                validate_value(self.model.name, field, value)?;
            }
        }
        Ok(())
    }
}
/// One `column = value` assignment in an `UPDATE`.
#[derive(Debug, Clone, PartialEq)]
pub struct Assignment {
    /// Column being assigned.
    pub column: &'static str,
    /// New value; only literal expressions are type-checked by `UpdateQuery::validate`.
    pub value: Expr,
}
/// `UPDATE` of a model's rows.
#[derive(Debug, Clone)]
pub struct UpdateQuery {
    /// Target model.
    pub model: &'static ModelSchema,
    /// Assignments to apply.
    pub set: Vec<Assignment>,
    /// Rows to update.
    pub where_clause: WhereExpr,
}
impl UpdateQuery {
    /// Checks every assignment against the model schema.
    ///
    /// Each assigned column must exist on the model. Only assignments whose value
    /// is a literal (per `Expr::as_literal`) have that value type-checked. The
    /// `where_clause` is not inspected.
    ///
    /// # Errors
    ///
    /// For the first failing assignment, returns either:
    /// - [`QueryError::UnknownField`] if the column is not on the model, or
    /// - the error from `validate_value` if it rejects the literal.
    pub fn validate(&self) -> Result<(), QueryError> {
        let model = self.model;
        self.set
            .iter()
            .try_for_each(|assignment| -> Result<(), QueryError> {
                let column = assignment.column;
                let field =
                    model
                        .field_by_column(column)
                        .ok_or_else(|| QueryError::UnknownField {
                            model: model.name,
                            field: column.to_owned(),
                        })?;
                // Non-literal expressions cannot be type-checked here.
                if let Some(literal) = assignment.value.as_literal() {
                    validate_value(model.name, field, literal)?;
                }
                Ok(())
            })
    }
}
/// `DELETE` of a model's rows.
#[derive(Debug, Clone)]
pub struct DeleteQuery {
    /// Target model.
    pub model: &'static ModelSchema,
    /// Rows to delete.
    pub where_clause: WhereExpr,
}
/// Row count over a model, optionally filtered and searched.
#[derive(Debug, Clone)]
pub struct CountQuery {
    /// Model being counted.
    pub model: &'static ModelSchema,
    /// Filter condition.
    pub where_clause: WhereExpr,
    /// Optional search clause.
    pub search: Option<SearchClause>,
}
/// Multi-row update.
///
/// NOTE(review): the layout of each row (whether it carries a key column in
/// addition to `update_columns`) is not visible here; confirm against the renderer.
#[derive(Debug, Clone)]
pub struct BulkUpdateQuery {
    /// Target model.
    pub model: &'static ModelSchema,
    /// Columns being updated.
    pub update_columns: Vec<&'static str>,
    /// Per-row values.
    pub rows: Vec<Vec<SqlValue>>,
}
/// Aggregate (or window) expression used by [`AggregateQuery`].
///
/// All variants except `Window` count as aggregating; see `AggregateExpr::is_aggregating`.
#[derive(Debug, Clone, PartialEq)]
pub enum AggregateExpr {
    /// Row count; `None` presumably counts all rows rather than one column.
    Count(Option<&'static str>),
    /// Count of distinct values in a column.
    CountDistinct(&'static str),
    Sum(&'static str),
    Avg(&'static str),
    Max(&'static str),
    Min(&'static str),
    /// Standard deviation; the `*Pop` variants are presumably population statistics.
    StdDev(&'static str),
    StdDevPop(&'static str),
    Variance(&'static str),
    VariancePop(&'static str),
    /// `inner` restricted to rows matching `filter`.
    Filtered {
        inner: Box<AggregateExpr>,
        filter: WhereExpr,
    },
    /// `inner` with `default` substituted (presumably when it yields NULL).
    Coalesced {
        inner: Box<AggregateExpr>,
        default: SqlValue,
    },
    /// Window expression; not aggregating.
    Window(Box<super::window::WindowExpr>),
    /// Collects column values into an array, optionally de-duplicated.
    ArrayAgg {
        column: &'static str,
        distinct: bool,
    },
    /// Concatenates column values with `delimiter`, optionally de-duplicated.
    StringAgg {
        column: &'static str,
        delimiter: String,
        distinct: bool,
    },
    /// Collects column values into a JSON array (presumably `jsonb_agg`).
    JsonbAgg { column: &'static str },
}
impl AggregateExpr {
    /// Returns whether this is a true aggregate rather than a window expression.
    ///
    /// `Filtered` and `Coalesced` wrappers are looked through; the answer is that of
    /// the innermost wrapped expression.
    #[must_use]
    pub fn is_aggregating(&self) -> bool {
        let mut current = self;
        loop {
            match current {
                AggregateExpr::Filtered { inner, .. } | AggregateExpr::Coalesced { inner, .. } => {
                    current = &**inner;
                }
                AggregateExpr::Window(_) => return false,
                AggregateExpr::Count(_)
                | AggregateExpr::CountDistinct(_)
                | AggregateExpr::Sum(_)
                | AggregateExpr::Avg(_)
                | AggregateExpr::Max(_)
                | AggregateExpr::Min(_)
                | AggregateExpr::StdDev(_)
                | AggregateExpr::StdDevPop(_)
                | AggregateExpr::Variance(_)
                | AggregateExpr::VariancePop(_)
                | AggregateExpr::ArrayAgg { .. }
                | AggregateExpr::StringAgg { .. }
                | AggregateExpr::JsonbAgg { .. } => return true,
            }
        }
    }

    /// `ArrayAgg` over `column` with `distinct: false`.
    #[must_use]
    pub const fn array_agg(column: &'static str) -> Self {
        Self::ArrayAgg {
            distinct: false,
            column,
        }
    }

    /// `ArrayAgg` over `column` with `distinct: true`.
    #[must_use]
    pub const fn array_agg_distinct(column: &'static str) -> Self {
        Self::ArrayAgg {
            distinct: true,
            column,
        }
    }

    /// `StringAgg` over `column` joined by `delimiter`, with `distinct: false`.
    #[must_use]
    pub fn string_agg(column: &'static str, delimiter: impl Into<String>) -> Self {
        let delimiter: String = delimiter.into();
        Self::StringAgg {
            distinct: false,
            column,
            delimiter,
        }
    }

    /// `StringAgg` over `column` joined by `delimiter`, with `distinct: true`.
    #[must_use]
    pub fn string_agg_distinct(column: &'static str, delimiter: impl Into<String>) -> Self {
        let delimiter: String = delimiter.into();
        Self::StringAgg {
            distinct: true,
            column,
            delimiter,
        }
    }

    /// `JsonbAgg` over `column`.
    #[must_use]
    pub const fn jsonb_agg(column: &'static str) -> Self {
        Self::JsonbAgg { column }
    }
}
/// Grouped aggregate query over a model.
#[derive(Debug, Clone)]
pub struct AggregateQuery {
    /// Model being aggregated.
    pub model: &'static ModelSchema,
    /// Row filter applied before grouping.
    pub where_clause: WhereExpr,
    /// Grouping columns.
    pub group_by: Vec<&'static str>,
    /// `(name, expression)` pairs; the name is presumably the output alias.
    pub aggregates: Vec<(&'static str, AggregateExpr)>,
    /// Optional post-grouping filter.
    pub having: Option<WhereExpr>,
    /// Ordering terms.
    pub order_by: Vec<OrderItem>,
    /// Maximum row count; `None` presumably means unlimited.
    pub limit: Option<i64>,
    /// Number of rows to skip; `None` presumably means no offset.
    pub offset: Option<i64>,
}