use super::expr::{CaseBranch, Expr};
use super::query::WhereExpr;
/// Fluent builder that accumulates the parts of an `Expr::Case`.
///
/// Start one with `case()`, add branches with `when`, optionally set a
/// fallback with `default`, and finish with `build` (or convert with
/// `.into()`, since `Expr` implements `From<CaseBuilder>`).
#[must_use = "a `CaseBuilder` has no effect until it is turned into an `Expr` with `build()`"]
pub struct CaseBuilder {
    /// Condition/result pairs, kept in the order they were added.
    branches: Vec<CaseBranch>,
    /// Optional fallback expression; `None` until `default` is called.
    default: Option<Box<Expr>>,
}
/// Starts an empty `CaseBuilder`: no branches and no default.
///
/// Add branches with `CaseBuilder::when`, optionally set a fallback with
/// `CaseBuilder::default`, then finish with `CaseBuilder::build` (or `.into()`).
// The reason string matters: `CaseBuilder` is itself `#[must_use]`, and an empty
// attribute on a function returning such a type trips `clippy::double_must_use`.
#[must_use = "the builder has no effect until it is finished with `build()`"]
pub fn case() -> CaseBuilder {
    CaseBuilder {
        branches: Vec::new(),
        default: None,
    }
}
/// Wraps a plain value in `Expr::Literal`.
///
/// Accepts anything convertible into `SqlValue`; the conversion is performed
/// here so callers can pass the raw value directly.
#[must_use]
pub fn value(v: impl Into<super::SqlValue>) -> Expr {
    let literal: super::SqlValue = v.into();
    Expr::Literal(literal)
}
impl CaseBuilder {
    /// Appends one branch pairing `condition` with the `then` expression.
    ///
    /// Both arguments are converted eagerly. Branches are pushed onto the end
    /// of the list, so they appear in the built expression in call order.
    // Reason strings on the chaining methods avoid `clippy::double_must_use`,
    // which fires for an empty `#[must_use]` on a function returning an
    // already-`#[must_use]` type.
    #[must_use = "`when` returns the updated builder; dropping it discards the new branch"]
    pub fn when(mut self, condition: impl Into<WhereExpr>, then: impl Into<Expr>) -> Self {
        self.branches.push(CaseBranch {
            condition: condition.into(),
            then: then.into(),
        });
        self
    }

    /// Sets the fallback expression stored in the `default` field of the
    /// resulting `Expr::Case`.
    ///
    /// Calling this more than once replaces the previously stored value; only
    /// the last call is kept.
    #[must_use = "`default` returns the updated builder; dropping it discards the fallback"]
    pub fn default(mut self, value: impl Into<Expr>) -> Self {
        self.default = Some(Box::new(value.into()));
        self
    }

    /// Consumes the builder and produces an `Expr::Case` holding the
    /// accumulated branches and the optional default.
    ///
    /// No validation is performed: a builder with no branches yields a `Case`
    /// with an empty branch list.
    #[must_use = "building consumes the builder; the resulting `Expr` should be used"]
    pub fn build(self) -> Expr {
        Expr::Case {
            branches: self.branches,
            default: self.default,
        }
    }
}
impl From<CaseBuilder> for Expr {
fn from(b: CaseBuilder) -> Self {
b.build()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{Filter, Op, SqlValue};

    /// Builds the `col = value` predicate used as a branch condition in these tests.
    fn predicate(col: &'static str, value: i64) -> WhereExpr {
        WhereExpr::Predicate(Filter {
            column: col,
            op: Op::Eq,
            value: SqlValue::I64(value),
        })
    }

    /// Splits an `Expr::Case` into `(branches, default)`.
    ///
    /// Panics with a descriptive message for any other variant.
    /// `#[track_caller]` makes the failure point at the calling test rather
    /// than at this helper.
    #[track_caller]
    fn case_parts(expr: Expr) -> (Vec<CaseBranch>, Option<Box<Expr>>) {
        let Expr::Case { branches, default } = expr else {
            panic!("expected Case variant")
        };
        (branches, default)
    }

    #[test]
    fn empty_builder_yields_case_with_no_branches() {
        let (branches, default) = case_parts(case().build());
        assert!(branches.is_empty());
        assert!(default.is_none());
    }

    #[test]
    fn single_when_no_default_produces_one_branch() {
        let e: Expr = case().when(predicate("status", 1), 100_i64).build();
        let (branches, default) = case_parts(e);
        assert_eq!(branches.len(), 1);
        assert_eq!(branches[0].condition, predicate("status", 1));
        assert_eq!(branches[0].then, Expr::Literal(SqlValue::I64(100)));
        assert!(default.is_none());
    }

    #[test]
    fn multiple_branches_preserve_source_order() {
        let e: Expr = case()
            .when(predicate("a", 1), 10_i64)
            .when(predicate("a", 2), 20_i64)
            .when(predicate("a", 3), 30_i64)
            .build();
        let (branches, _) = case_parts(e);
        assert_eq!(branches.len(), 3);
        assert_eq!(branches[0].then, Expr::Literal(SqlValue::I64(10)));
        assert_eq!(branches[1].then, Expr::Literal(SqlValue::I64(20)));
        assert_eq!(branches[2].then, Expr::Literal(SqlValue::I64(30)));
    }

    #[test]
    fn default_is_stored_as_boxed_else() {
        let e: Expr = case()
            .when(predicate("status", 1), 100_i64)
            .default(999_i64)
            .build();
        let (_, default) = case_parts(e);
        assert_eq!(default.as_deref(), Some(&Expr::Literal(SqlValue::I64(999))));
    }

    #[test]
    fn last_default_call_wins() {
        let e: Expr = case().default(1_i64).default(2_i64).build();
        let (_, default) = case_parts(e);
        assert_eq!(default.as_deref(), Some(&Expr::Literal(SqlValue::I64(2))));
    }

    #[test]
    fn case_implements_into_expr() {
        let e: Expr = case().when(predicate("x", 1), 1_i64).into();
        // The conversion must yield the same shape `build()` would.
        let (branches, default) = case_parts(e);
        assert_eq!(branches.len(), 1);
        assert!(default.is_none());
    }

    #[test]
    fn value_sugars_a_literal_into_expr() {
        let e: Expr = value("hello");
        assert_eq!(e, Expr::Literal(SqlValue::String("hello".into())));
        let e: Expr = value(42_i64);
        assert_eq!(e, Expr::Literal(SqlValue::I64(42)));
    }
}