use std::fmt;
use crate::query::ast::{OrderTerm, SelectStatement};
use crate::value::{BindValue, Value};
/// Aggregate functions that can appear in an [`Expr::Aggregate`] node.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggFunc {
    Count,
    Sum,
    Avg,
    Min,
    Max,
    StringAggregation,
    ArrayAggregation,
    JsonAggregation,
    JsonbAggregation,
    BoolAnd,
    BoolOr,
}
impl AggFunc {
    /// Returns the SQL function name used for this aggregate.
    ///
    /// `COUNT`/`SUM`/`AVG`/`MIN`/`MAX` are spelled in upper case; every other
    /// aggregate is returned in lower case exactly as listed below.
    pub fn as_sql(self) -> &'static str {
        match self {
            // Core aggregates, upper case.
            Self::Count => "COUNT",
            Self::Sum => "SUM",
            Self::Avg => "AVG",
            Self::Min => "MIN",
            Self::Max => "MAX",
            // Collection / boolean aggregates, lower case.
            Self::StringAggregation => "string_agg",
            Self::ArrayAggregation => "array_agg",
            Self::JsonAggregation => "json_agg",
            Self::JsonbAggregation => "jsonb_agg",
            Self::BoolAnd => "bool_and",
            Self::BoolOr => "bool_or",
        }
    }
}
/// One endpoint of a window frame.
///
/// `Preceding` and `Following` carry an offset expression; the other variants
/// are fixed keywords.
#[derive(Debug, Clone)]
pub enum WindowBound {
    UnboundedPreceding,
    Preceding(Box<Expr>),
    CurrentRow,
    Following(Box<Expr>),
    UnboundedFollowing,
}
impl fmt::Display for WindowBound {
    /// Writes the keyword text of this bound.
    ///
    /// NOTE(review): for `Preceding`/`Following` only the keyword is written;
    /// the boxed offset expression is ignored, so this output alone is not a
    /// complete bound (e.g. `3 PRECEDING`). Presumably the SQL compiler emits
    /// the offset itself — confirm before relying on `to_string()`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let keyword = match self {
            Self::UnboundedPreceding => "UNBOUNDED PRECEDING",
            Self::Preceding(_) => "PRECEDING",
            Self::CurrentRow => "CURRENT ROW",
            Self::Following(_) => "FOLLOWING",
            Self::UnboundedFollowing => "UNBOUNDED FOLLOWING",
        };
        f.write_str(keyword)
    }
}
/// Unit in which a window frame's extent is measured.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WindowFrameUnit {
    Rows,
    Range,
    Groups,
}
impl fmt::Display for WindowFrameUnit {
    /// Writes the upper-case SQL keyword for this unit.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let keyword = match self {
            Self::Rows => "ROWS",
            Self::Range => "RANGE",
            Self::Groups => "GROUPS",
        };
        f.write_str(keyword)
    }
}
/// Frame clause of a window specification: a unit plus a start bound and an
/// optional end bound.
#[derive(Debug, Clone)]
pub struct WindowFrame {
    /// Whether the frame is measured in rows, range, or groups.
    pub unit: WindowFrameUnit,
    /// Starting bound of the frame.
    pub start: WindowBound,
    /// Ending bound; `None` means only a start bound was given.
    pub end: Option<WindowBound>,
}
/// Window specification attached to an [`Expr::Over`] node.
///
/// The default value is an empty specification: no partition expressions, no
/// ordering terms, and no frame. The derived `Default` produces exactly that,
/// replacing the previous hand-written impl (clippy `derivable_impls`).
#[derive(Debug, Clone, Default)]
pub struct Window {
    /// Partitioning expressions (`PARTITION BY ...`).
    pub partition_by: Vec<Expr>,
    /// Ordering terms inside the window (`ORDER BY ...`).
    pub order_by: Vec<OrderTerm>,
    /// Optional explicit frame; `None` leaves the frame unspecified.
    pub frame: Option<WindowFrame>,
}
pub struct ExprOver {
expr: Expr,
window: Window,
}
impl ExprOver {
pub fn partition_by(mut self, cols: impl IntoIterator<Item = Expr>) -> Self {
self.window.partition_by = cols.into_iter().collect();
self
}
pub fn order_by(mut self, terms: impl IntoIterator<Item = OrderTerm>) -> Self {
self.window.order_by = terms.into_iter().collect();
self
}
pub fn rows_between(mut self, start: WindowBound, end: WindowBound) -> Self {
self.window.frame = Some(WindowFrame {
unit: WindowFrameUnit::Rows,
start,
end: Some(end),
});
self
}
pub fn range_between(mut self, start: WindowBound, end: WindowBound) -> Self {
self.window.frame = Some(WindowFrame {
unit: WindowFrameUnit::Range,
start,
end: Some(end),
});
self
}
pub fn groups_between(mut self, start: WindowBound, end: WindowBound) -> Self {
self.window.frame = Some(WindowFrame {
unit: WindowFrameUnit::Groups,
start,
end: Some(end),
});
self
}
pub fn end(self) -> Expr {
Expr::Over {
expr: Box::new(self.expr),
window: Box::new(self.window),
}
}
}
impl From<ExprOver> for Expr {
fn from(over: ExprOver) -> Self {
over.end()
}
}
/// Binary operators usable in an [`Expr::Binary`] node.
///
/// The SQL spelling of each operator is given by [`BinaryOp::as_sql`]; many of
/// them (JSON, array containment, text search) are PostgreSQL operators.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinaryOp {
    /// `=`
    Eq,
    /// `<>`
    Ne,
    /// `>`
    Gt,
    /// `>=`
    Ge,
    /// `<`
    Lt,
    /// `<=`
    Le,
    /// `+`
    Add,
    /// `-`
    Sub,
    /// `*`
    Mul,
    /// `/`
    Div,
    /// `%`
    Mod,
    /// `LIKE`
    Like,
    /// `ILIKE`
    ILike,
    /// `->`
    JsonGet,
    /// `->>`
    JsonGetText,
    /// `@>`
    Contains,
    /// `&&`
    Overlap,
    /// `IS DISTINCT FROM`
    IsDistinctFrom,
    /// `IS NOT DISTINCT FROM`
    IsNotDistinctFrom,
    /// `?`
    JsonKeyExists,
    /// `?|`
    JsonKeyExistsAny,
    /// `?&`
    JsonKeyExistsAll,
    /// `#>`
    JsonPath,
    /// `#>>`
    JsonPathText,
    /// `<@`
    ArrayContainedBy,
    /// `@@`
    TsMatch,
}
impl BinaryOp {
    /// Returns the operator text emitted between the two operands.
    ///
    /// NOTE(review): the `?`, `?|` and `?&` operators may collide with `?`
    /// placeholder syntax on drivers/backends that use it — confirm for each
    /// supported backend.
    pub fn as_sql(self) -> &'static str {
        match self {
            // Comparison.
            Self::Eq => "=",
            Self::Ne => "<>",
            Self::Gt => ">",
            Self::Ge => ">=",
            Self::Lt => "<",
            Self::Le => "<=",
            Self::IsDistinctFrom => "IS DISTINCT FROM",
            Self::IsNotDistinctFrom => "IS NOT DISTINCT FROM",
            // Arithmetic.
            Self::Add => "+",
            Self::Sub => "-",
            Self::Mul => "*",
            Self::Div => "/",
            Self::Mod => "%",
            // Pattern matching.
            Self::Like => "LIKE",
            Self::ILike => "ILIKE",
            // JSON access and key tests.
            Self::JsonGet => "->",
            Self::JsonGetText => "->>",
            Self::JsonPath => "#>",
            Self::JsonPathText => "#>>",
            Self::JsonKeyExists => "?",
            Self::JsonKeyExistsAny => "?|",
            Self::JsonKeyExistsAll => "?&",
            // Containment / overlap.
            Self::Contains => "@>",
            Self::ArrayContainedBy => "<@",
            Self::Overlap => "&&",
            // Full-text search.
            Self::TsMatch => "@@",
        }
    }
}
/// Boolean connective used to join the items of an [`Expr::Logical`] node.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LogicalOp {
    And,
    Or,
}
impl LogicalOp {
    /// Returns the SQL keyword for this connective.
    pub fn as_sql(self) -> &'static str {
        let keyword = match self {
            Self::And => "AND",
            Self::Or => "OR",
        };
        keyword
    }
}
/// A node in the SQL expression tree.
///
/// Nodes are normally built with the constructors and combinators on
/// `impl Expr`. Turning a tree into SQL text is presumably done by a compiler
/// elsewhere in the crate (not visible here).
#[derive(Debug, Clone)]
pub enum Expr {
    /// Qualified column reference (`table.column`).
    Column {
        table: &'static str,
        column: &'static str,
    },
    /// A literal value.
    Value(Value),
    /// `left <op> right` for any [`BinaryOp`].
    Binary {
        left: Box<Expr>,
        op: BinaryOp,
        right: Box<Expr>,
    },
    /// `items` joined by a single [`LogicalOp`] (`AND` / `OR`).
    Logical {
        op: LogicalOp,
        items: Vec<Expr>,
    },
    /// Logical negation of the inner expression.
    Not(Box<Expr>),
    /// `expr IN (...)` against a list of literal values.
    InList {
        expr: Box<Expr>,
        values: Vec<Value>,
    },
    /// `expr IS NULL`; presumably `IS NOT NULL` when `negated` is true.
    IsNull {
        expr: Box<Expr>,
        negated: bool,
    },
    /// Aggregate call `func(args...)` with an optional filter predicate.
    Aggregate {
        func: AggFunc,
        args: Vec<Expr>,
        filter: Option<Box<Expr>>,
    },
    /// Arbitrary function call `name(args...)`.
    Func {
        name: String,
        args: Vec<Expr>,
    },
    /// `COUNT(*)`.
    CountStar,
    /// `expr AS alias`.
    Alias {
        expr: Box<Expr>,
        alias: &'static str,
    },
    /// `expr BETWEEN low AND high`.
    Between {
        expr: Box<Expr>,
        low: Box<Expr>,
        high: Box<Expr>,
    },
    /// `CASE` with `(condition, result)` branches and an optional `ELSE`.
    Case {
        whens: Vec<(Expr, Expr)>,
        else_expr: Option<Box<Expr>>,
    },
    /// A subquery used in expression position.
    Subquery(Box<SelectStatement>),
    /// `expr IN (subquery)`; presumably `NOT IN` when `negated` is true.
    InSubquery {
        expr: Box<Expr>,
        subquery: Box<SelectStatement>,
        negated: bool,
    },
    /// Raw SQL text with its own parameter values.
    /// NOTE(review): `sql` is presumably emitted verbatim — never build it from
    /// untrusted input.
    Raw {
        sql: String,
        params: Vec<Value>,
    },
    /// `EXISTS (subquery)`; `NOT EXISTS` when `negated` is true.
    Exists {
        subquery: Box<SelectStatement>,
        negated: bool,
    },
    /// Column of the `EXCLUDED` row (presumably for upsert `DO UPDATE` — confirm).
    Excluded(&'static str),
    /// Field extraction from `source` (presumably `EXTRACT(field FROM source)`).
    Extract {
        field: String,
        source: Box<Expr>,
    },
    /// Window-function call `expr OVER (window)`.
    Over {
        expr: Box<Expr>,
        window: Box<Window>,
    },
}
/// Builder for an [`Expr::Case`] expression.
///
/// Obtained from [`Expr::case`]; add branches with [`CaseWhen::when`], an
/// optional fallback with [`CaseWhen::else_`], then finish with
/// [`CaseWhen::end`]. No branch count is enforced here; a builder finished
/// with zero branches yields a `Case` with an empty `whens` list.
#[must_use = "a CaseWhen builds nothing until it is finished with `.end()`"]
#[derive(Debug, Clone)]
pub struct CaseWhen {
    whens: Vec<(Expr, Expr)>,
    else_expr: Option<Box<Expr>>,
}
impl CaseWhen {
    /// Appends a `(cond, result)` branch; branches keep insertion order.
    pub fn when(mut self, cond: Expr, result: Expr) -> Self {
        self.whens.push((cond, result));
        self
    }

    /// Sets the fallback (`ELSE`) result, replacing any set previously.
    pub fn else_(mut self, default: Expr) -> Self {
        self.else_expr = Some(Box::new(default));
        self
    }

    /// Finishes the builder, producing an [`Expr::Case`] node.
    pub fn end(self) -> Expr {
        Expr::Case {
            whens: self.whens,
            else_expr: self.else_expr,
        }
    }
}
impl Expr {
    /// Builds a qualified column reference (`table.column`).
    pub fn column(table: &'static str, column: &'static str) -> Self {
        Expr::Column { table, column }
    }

    /// Wraps a literal [`Value`].
    pub fn value(value: Value) -> Self {
        Expr::Value(value)
    }

    /// References `column` on the `EXCLUDED` row (presumably for upsert
    /// `ON CONFLICT ... DO UPDATE` — confirm against the compiler).
    pub fn excluded(column: &'static str) -> Self {
        Expr::Excluded(column)
    }

    /// Builds `left <op> right`.
    pub fn binary(left: Expr, op: BinaryOp, right: Expr) -> Self {
        Expr::Binary {
            left: Box::new(left),
            op,
            right: Box::new(right),
        }
    }

    /// Joins `items` with `AND`. An empty iterator produces an `AND` node with
    /// no items; how that renders is decided by the SQL compiler.
    pub fn all(items: impl IntoIterator<Item = Expr>) -> Self {
        Expr::Logical {
            op: LogicalOp::And,
            items: items.into_iter().collect(),
        }
    }

    /// Joins `items` with `OR`. An empty iterator produces an `OR` node with
    /// no items; how that renders is decided by the SQL compiler.
    pub fn any(items: impl IntoIterator<Item = Expr>) -> Self {
        Expr::Logical {
            op: LogicalOp::Or,
            items: items.into_iter().collect(),
        }
    }

    /// Logical negation of `expr`.
    // Named like `std::ops::Not` on purpose: it builds an expression node.
    #[allow(clippy::should_implement_trait)]
    pub fn not(expr: Expr) -> Self {
        Expr::Not(Box::new(expr))
    }

    /// `expr IN (values...)` against literal values.
    pub fn in_list(expr: Expr, values: Vec<Value>) -> Self {
        Expr::InList {
            expr: Box::new(expr),
            values,
        }
    }

    /// Null test on `expr`; `negated` selects the negated form.
    pub fn is_null(expr: Expr, negated: bool) -> Self {
        Expr::IsNull {
            expr: Box::new(expr),
            negated,
        }
    }

    /// Aggregate call `func(args...)` with no filter.
    pub fn aggregate(func: AggFunc, args: impl IntoIterator<Item = Expr>) -> Self {
        Expr::Aggregate {
            func,
            args: args.into_iter().collect(),
            filter: None,
        }
    }

    /// Attaches `predicate` as the aggregate's filter (the SQL
    /// `FILTER (WHERE ...)` clause).
    ///
    /// If the aggregate already has a filter, the existing predicate and
    /// `predicate` are combined with `AND`, so repeated `.filter()` calls
    /// narrow the filter. Previously this case panicked with a message
    /// claiming the receiver was not an aggregate.
    ///
    /// # Panics
    ///
    /// Panics if `self` is not an [`Expr::Aggregate`].
    pub fn filter(self, predicate: Expr) -> Self {
        match self {
            Expr::Aggregate { func, args, filter } => {
                let combined = match filter {
                    None => predicate,
                    Some(existing) => Expr::all([*existing, predicate]),
                };
                Expr::Aggregate {
                    func,
                    args,
                    filter: Some(Box::new(combined)),
                }
            }
            _ => panic!("filter() can only be called on Aggregate expressions"),
        }
    }

    /// Function call `name(args...)`.
    /// NOTE(review): `name` is presumably emitted verbatim — never pass
    /// untrusted text.
    pub fn func(name: impl Into<String>, args: impl IntoIterator<Item = Expr>) -> Self {
        Expr::Func {
            name: name.into(),
            args: args.into_iter().collect(),
        }
    }

    /// `expr BETWEEN low AND high`.
    pub fn between(expr: Expr, low: Expr, high: Expr) -> Self {
        Expr::Between {
            expr: Box::new(expr),
            low: Box::new(low),
            high: Box::new(high),
        }
    }

    /// Starts a [`CaseWhen`] builder with no branches and no fallback.
    pub fn case() -> CaseWhen {
        CaseWhen {
            whens: Vec::new(),
            else_expr: None,
        }
    }

    /// Uses `stmt` as an expression.
    pub fn subquery(stmt: SelectStatement) -> Self {
        Expr::Subquery(Box::new(stmt))
    }

    /// Membership test of `expr` against the rows of `stmt`; `negated`
    /// selects the negated form.
    pub fn in_subquery(expr: Expr, stmt: SelectStatement, negated: bool) -> Self {
        Expr::InSubquery {
            expr: Box::new(expr),
            subquery: Box::new(stmt),
            negated,
        }
    }

    /// Raw SQL fragment with no bound parameters.
    /// NOTE(review): `sql` is presumably spliced verbatim — never build it from
    /// untrusted input.
    pub fn raw(sql: impl Into<String>) -> Self {
        Expr::Raw {
            sql: sql.into(),
            params: Vec::new(),
        }
    }

    /// Starts an [`ExprOver`] window builder around `self` with an empty window.
    pub fn over(self) -> ExprOver {
        ExprOver {
            expr: self,
            window: Window::default(),
        }
    }

    /// `EXISTS (...)` over the statement produced by `qs`.
    pub fn exists<X: crate::model::Model>(qs: crate::query::queryset::QuerySet<X>) -> Self {
        Expr::Exists {
            subquery: Box::new(qs.into_statement()),
            negated: false,
        }
    }

    /// `NOT EXISTS (...)` over the statement produced by `qs`.
    pub fn not_exists<X: crate::model::Model>(qs: crate::query::queryset::QuerySet<X>) -> Self {
        Expr::Exists {
            subquery: Box::new(qs.into_statement()),
            negated: true,
        }
    }

    /// Aliases `self` as `alias`.
    pub fn as_(self, alias: &'static str) -> Self {
        Expr::Alias {
            expr: Box::new(self),
            alias,
        }
    }

    /// Shared helper for the comparison methods: `self <op> value`, with
    /// `value` converted to a literal via [`BindValue::to_value`].
    fn compare(self, op: BinaryOp, value: impl BindValue) -> Self {
        Expr::binary(self, op, Expr::Value(value.to_value()))
    }

    /// `self = value`.
    pub fn eq(self, value: impl BindValue) -> Self {
        self.compare(BinaryOp::Eq, value)
    }

    /// `self <> value`.
    pub fn ne(self, value: impl BindValue) -> Self {
        self.compare(BinaryOp::Ne, value)
    }

    /// `self > value`.
    pub fn gt(self, value: impl BindValue) -> Self {
        self.compare(BinaryOp::Gt, value)
    }

    /// `self >= value`.
    pub fn ge(self, value: impl BindValue) -> Self {
        self.compare(BinaryOp::Ge, value)
    }

    /// `self < value`.
    pub fn lt(self, value: impl BindValue) -> Self {
        self.compare(BinaryOp::Lt, value)
    }

    /// `self <= value`.
    pub fn le(self, value: impl BindValue) -> Self {
        self.compare(BinaryOp::Le, value)
    }

    // The arithmetic builders below are deliberately named like the
    // `std::ops` traits (as `not` is above); they build expression nodes
    // rather than computing, so the same clippy lint is allowed on each.

    /// `self + rhs`.
    #[allow(clippy::should_implement_trait)]
    pub fn add(self, rhs: Expr) -> Self {
        Expr::binary(self, BinaryOp::Add, rhs)
    }

    /// `self - rhs`.
    #[allow(clippy::should_implement_trait)]
    pub fn sub(self, rhs: Expr) -> Self {
        Expr::binary(self, BinaryOp::Sub, rhs)
    }

    /// `self * rhs`.
    #[allow(clippy::should_implement_trait)]
    pub fn mul(self, rhs: Expr) -> Self {
        Expr::binary(self, BinaryOp::Mul, rhs)
    }

    /// `self / rhs`.
    #[allow(clippy::should_implement_trait)]
    pub fn div(self, rhs: Expr) -> Self {
        Expr::binary(self, BinaryOp::Div, rhs)
    }

    /// `self % rhs`.
    #[allow(clippy::should_implement_trait)]
    pub fn rem(self, rhs: Expr) -> Self {
        Expr::binary(self, BinaryOp::Mod, rhs)
    }

    /// Full-text match `self @@ query`.
    #[cfg(feature = "postgres")]
    pub fn matches(self, query: impl Into<Expr>) -> Expr {
        Expr::binary(self, BinaryOp::TsMatch, query.into())
    }

    /// `to_tsvector(config, self)`, with `config` passed as a text value.
    /// NOTE(review): confirm the driver/server accepts a text value for the
    /// configuration argument.
    #[cfg(feature = "postgres")]
    pub fn to_tsvector(self, config: &str) -> Expr {
        Expr::func("to_tsvector", [Expr::value(Value::Text(config.to_string())), self])
    }

    /// `to_tsvector(config, self) @@ to_tsquery(config, query_text)`.
    /// NOTE(review): `to_tsquery` expects tsquery syntax, so free-form user
    /// text may raise a database error; consider `plainto_tsquery` or
    /// `websearch_to_tsquery` for such input.
    #[cfg(feature = "postgres")]
    pub fn ts_match(self, config: &str, query_text: &str) -> Expr {
        let vector = self.to_tsvector(config);
        let query = Expr::func(
            "to_tsquery",
            [
                Expr::value(Value::Text(config.to_string())),
                Expr::value(Value::Text(query_text.to_string())),
            ],
        );
        Expr::binary(vector, BinaryOp::TsMatch, query)
    }

    /// Ascending ordering term on `self`.
    pub fn asc(self) -> OrderTerm {
        OrderTerm::new(self, false)
    }

    /// Descending ordering term on `self`.
    pub fn desc(self) -> OrderTerm {
        OrderTerm::new(self, true)
    }
}