use std::ops;
use super::value::SqlValue;
/// Binary operators that can join two [`Expr`] operands in an [`Expr::BinOp`] node.
///
/// Each variant is produced by the matching Rust operator on `F` / `Expr`
/// (see the `impl_binop!` table in this module).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BinOp {
    /// Produced by `+`.
    Add,
    /// Produced by `-`.
    Sub,
    /// Produced by `*`.
    Mul,
    /// Produced by `/`.
    Div,
    /// Produced by `%` (Rust's `Rem` trait).
    Mod,
    /// Produced by `&`.
    BitAnd,
    /// Produced by `|`.
    BitOr,
    /// Produced by `^`.
    BitXor,
    /// Produced by `<<` (Rust's `Shl` trait).
    BitShl,
    /// Produced by `>>` (Rust's `Shr` trait).
    BitShr,
}
/// Abstract syntax tree for a SQL value expression.
///
/// Leaves are literals and column references; interior nodes combine
/// sub-expressions through operators, scalar function calls, conditional
/// arms, subqueries, window expressions, or aggregates.
#[derive(Clone, Debug, PartialEq)]
pub enum Expr {
    /// A constant value.
    Literal(SqlValue),
    /// A column referenced by bare name.
    Column(&'static str),
    /// `left op right`.
    BinOp {
        /// Left operand.
        left: Box<Expr>,
        /// Operator joining the two operands.
        op: BinOp,
        /// Right operand.
        right: Box<Expr>,
    },
    /// Call of the scalar function `kind` with the positional `args`.
    Function {
        kind: ScalarFn,
        args: Vec<Expr>,
    },
    /// Conditional expression: ordered arms plus an optional fallback value
    /// (presumably rendered as `CASE WHEN … THEN … [ELSE …] END` — confirm in the compiler).
    Case {
        branches: Vec<CaseBranch>,
        default: Option<Box<Expr>>,
    },
    /// A nested select query used as a value.
    Subquery(Box<super::query::SelectQuery>),
    /// Column reference by name; presumably resolved against an enclosing
    /// (outer) query, as in a correlated subquery — TODO confirm.
    OuterRef(&'static str),
    /// Column qualified by a table alias (presumably rendered `alias.column`).
    AliasedColumn {
        alias: &'static str,
        column: &'static str,
    },
    /// A window-function expression.
    Window(Box<super::window::WindowExpr>),
    /// An aggregate expression.
    Aggregate(Box<super::query::AggregateExpr>),
}
/// One conditional arm of an [`Expr::Case`] expression.
///
/// `then` is the value associated with `condition`; presumably rendered as
/// `WHEN <condition> THEN <then>` — confirm against the SQL compiler.
#[derive(Clone, Debug, PartialEq)]
pub struct CaseBranch {
    /// Predicate guarding this arm.
    pub condition: super::query::WhereExpr,
    /// Result expression for this arm.
    pub then: Expr,
}
/// Scalar SQL functions callable through [`Expr::Function`].
///
/// Variant order is significant for any `as` casts elsewhere, so new
/// functions should be appended rather than inserted.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ScalarFn {
    // String functions.
    Lower,
    Upper,
    Length,
    Concat,
    Substr,
    Trim,
    LTrim,
    RTrim,
    Replace,
    // Numeric functions.
    Abs,
    Ceil,
    Floor,
    Round,
    // NULL handling and comparison.
    Coalesce,
    Greatest,
    Least,
    NullIf,
    // Current timestamp.
    Now,
    // Date/time field extraction.
    ExtractYear,
    ExtractMonth,
    ExtractDay,
    ExtractHour,
    ExtractMinute,
    ExtractSecond,
    ExtractWeek,
    ExtractWeekDay,
    ExtractQuarter,
    // Date/time truncation.
    TruncDate,
    TruncYear,
    TruncMonth,
    TruncDay,
    // Trigram similarity (names suggest PostgreSQL `pg_trgm` — confirm).
    TrigramSimilarity,
    TrigramWordSimilarity,
    // Full-text search (names suggest PostgreSQL text-search functions — confirm).
    ToTsVector,
    PlainToTsQuery,
    TsRank,
    TsHeadline,
    PhraseToTsQuery,
    WebsearchToTsQuery,
    ToTsQuery,
    TsRankCd,
}
impl Expr {
#[must_use]
pub fn col(name: &'static str) -> Self {
Self::Column(name)
}
#[must_use]
pub fn binop(self, op: BinOp, rhs: impl Into<Expr>) -> Self {
Self::BinOp {
left: Box::new(self),
op,
right: Box::new(rhs.into()),
}
}
#[must_use]
pub fn is_literal(&self) -> bool {
matches!(self, Self::Literal(_))
}
#[must_use]
pub fn as_literal(&self) -> Option<&SqlValue> {
match self {
Self::Literal(v) => Some(v),
_ => None,
}
}
}
impl From<SqlValue> for Expr {
fn from(v: SqlValue) -> Self {
Self::Literal(v)
}
}
macro_rules! expr_from_primitive {
($($t:ty),+ $(,)?) => {
$(
impl From<$t> for Expr {
fn from(v: $t) -> Self { Self::Literal(SqlValue::from(v)) }
}
)+
};
}
expr_from_primitive! {
i16, i32, i64, f32, f64, bool, String, &'static str,
chrono::DateTime<chrono::Utc>, chrono::NaiveDate, uuid::Uuid,
serde_json::Value,
}
/// Lightweight, copyable handle naming a column.
///
/// Serves as an operand of the arithmetic and bitwise operator overloads in
/// this module (e.g. `F("views") + 1`); the wrapped string is the column name.
// The former `#[allow(non_camel_case_types)]` was dead: `F` is already
// UpperCamelCase, so the lint never fired.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct F(pub &'static str);

impl F {
    /// Creates a handle for `column`; equivalent to `F(column)`.
    ///
    /// `const` so column handles can be declared as constants.
    #[must_use]
    pub const fn new(column: &'static str) -> Self {
        Self(column)
    }
}
impl From<F> for Expr {
fn from(f: F) -> Self {
Self::Column(f.0)
}
}
// Generates operator overloads for both `Expr` and `F`.
// Each entry `Trait::method => Variant` implements `ops::Trait<R>` for any
// `R: Into<Expr>`, producing an `Expr::BinOp` node tagged `BinOp::Variant`
// whose left operand is the receiver and right operand is `rhs`.
macro_rules! impl_binop {
    ($($trait_name:ident :: $fn_name:ident => $variant:ident),+ $(,)?) => {
        $(
            impl<R: Into<Expr>> ops::$trait_name<R> for Expr {
                type Output = Expr;
                fn $fn_name(self, rhs: R) -> Expr {
                    self.binop(BinOp::$variant, rhs)
                }
            }
            impl<R: Into<Expr>> ops::$trait_name<R> for F {
                type Output = Expr;
                fn $fn_name(self, rhs: R) -> Expr {
                    // Lift the handle to `Expr::Column` and reuse the `Expr` impl.
                    <Expr as ops::$trait_name<R>>::$fn_name(Expr::from(self), rhs)
                }
            }
        )+
    };
}

// Rust operator trait => AST operator. Note the renames: `Rem` => `Mod`,
// `Shl` => `BitShl`, `Shr` => `BitShr`.
impl_binop! {
    Add::add => Add,
    Sub::sub => Sub,
    Mul::mul => Mul,
    Div::div => Div,
    Rem::rem => Mod,
    BitAnd::bitand => BitAnd,
    BitOr::bitor => BitOr,
    BitXor::bitxor => BitXor,
    Shl::shl => BitShl,
    Shr::shr => BitShr,
}
// Unit tests for the expression DSL: conversions into `Expr` and the
// operator overloads generated by `impl_binop!`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn f_lifts_to_column_expr() {
        let e: Expr = F("views").into();
        assert_eq!(e, Expr::Column("views"));
    }

    #[test]
    fn f_new_matches_tuple_constructor() {
        assert_eq!(F::new("v"), F("v"));
        assert_eq!(Expr::from(F::new("v")), Expr::col("v"));
    }

    #[test]
    fn f_add_int_builds_binop() {
        let e: Expr = F("views") + 1;
        assert_eq!(
            e,
            Expr::BinOp {
                left: Box::new(Expr::Column("views")),
                op: BinOp::Add,
                right: Box::new(Expr::Literal(SqlValue::I32(1))),
            }
        );
    }

    #[test]
    fn f_add_f_builds_column_column_binop() {
        let e: Expr = F("a") + F("b");
        assert_eq!(
            e,
            Expr::BinOp {
                left: Box::new(Expr::Column("a")),
                op: BinOp::Add,
                right: Box::new(Expr::Column("b")),
            }
        );
    }

    #[test]
    fn arithmetic_chains_left_assoc() {
        let e: Expr = F("a") + 1 - 2;
        let Expr::BinOp { left, op, right } = e else {
            panic!("expected outer BinOp");
        };
        assert_eq!(op, BinOp::Sub);
        assert_eq!(*right, Expr::Literal(SqlValue::I32(2)));
        let Expr::BinOp { op: inner_op, .. } = *left else {
            panic!("expected nested BinOp")
        };
        assert_eq!(inner_op, BinOp::Add);
    }

    // Every Rust operator must map to the intended `BinOp`, including the
    // renamed ones (`%` => `Mod`, `<<` => `BitShl`, `>>` => `BitShr`).
    #[test]
    fn every_operator_maps_to_its_binop() {
        let cases: [(Expr, BinOp); 10] = [
            (F("a") + 1_i32, BinOp::Add),
            (F("a") - 1_i32, BinOp::Sub),
            (F("a") * 1_i32, BinOp::Mul),
            (F("a") / 1_i32, BinOp::Div),
            (F("a") % 1_i32, BinOp::Mod),
            (F("a") & 1_i32, BinOp::BitAnd),
            (F("a") | 1_i32, BinOp::BitOr),
            (F("a") ^ 1_i32, BinOp::BitXor),
            (F("a") << 1_i32, BinOp::BitShl),
            (F("a") >> 1_i32, BinOp::BitShr),
        ];
        for (expr, expected) in cases {
            let Expr::BinOp { left, op, right } = expr else {
                panic!("expected BinOp for {expected:?}");
            };
            assert_eq!(op, expected);
            assert_eq!(*left, Expr::Column("a"));
            assert_eq!(*right, Expr::Literal(SqlValue::I32(1)));
        }
    }

    // Operators on an `Expr` receiver nest the receiver as the left operand.
    #[test]
    fn expr_lhs_operator_nests_receiver() {
        let e: Expr = (F("a") + 1_i32) * F("b");
        assert_eq!(
            e,
            Expr::BinOp {
                left: Box::new(Expr::BinOp {
                    left: Box::new(Expr::Column("a")),
                    op: BinOp::Add,
                    right: Box::new(Expr::Literal(SqlValue::I32(1))),
                }),
                op: BinOp::Mul,
                right: Box::new(Expr::Column("b")),
            }
        );
    }

    #[test]
    fn sqlvalue_lifts_into_expr_literal() {
        let e: Expr = SqlValue::I64(42).into();
        assert_eq!(e, Expr::Literal(SqlValue::I64(42)));
    }

    #[test]
    fn primitives_lift_into_expr_literal() {
        let e: Expr = 7i64.into();
        assert_eq!(e, Expr::Literal(SqlValue::I64(7)));
        let e: Expr = "hi".into();
        assert_eq!(e, Expr::Literal(SqlValue::String("hi".to_owned())));
    }

    #[test]
    fn is_literal_distinguishes() {
        assert!(Expr::Literal(SqlValue::I32(1)).is_literal());
        assert!(!Expr::Column("x").is_literal());
        assert!(!(F("a") + 1).is_literal());
    }

    #[test]
    fn as_literal_borrows_only_literals() {
        assert_eq!(
            Expr::Literal(SqlValue::I32(3)).as_literal(),
            Some(&SqlValue::I32(3))
        );
        assert_eq!(Expr::Column("x").as_literal(), None);
        assert_eq!((F("a") + 1_i32).as_literal(), None);
    }

    #[test]
    fn bitwise_operators_compile_and_compose() {
        let e: Expr = F("mask") & 0xff_i32;
        assert!(matches!(
            e,
            Expr::BinOp {
                op: BinOp::BitAnd,
                ..
            }
        ));
        let e: Expr = F("a") | F("b");
        assert!(matches!(
            e,
            Expr::BinOp {
                op: BinOp::BitOr,
                ..
            }
        ));
        let e: Expr = F("a") << 4_i32;
        assert!(matches!(
            e,
            Expr::BinOp {
                op: BinOp::BitShl,
                ..
            }
        ));
    }
}