use super::query::{AggregateExpr, WhereExpr};
use super::SqlValue;
/// Fluent builder for a single aggregate expression.
///
/// Obtained from the constructor functions in this module (for example [`count`] or
/// [`sum`]), optionally refined with [`AggregateBuilder::filter`] and
/// [`AggregateBuilder::default`], and finished with [`AggregateBuilder::build`] or
/// `.into()`.
#[must_use = "an AggregateBuilder does nothing until it is built into an AggregateExpr"]
pub struct AggregateBuilder {
    /// The bare aggregate being built (e.g. `AggregateExpr::Count`, `AggregateExpr::Sum`).
    kind: AggregateExpr,
    /// Optional row predicate; `build` wraps the aggregate in `AggregateExpr::Filtered`.
    filter: Option<WhereExpr>,
    /// Optional fallback value; `build` wraps the result in `AggregateExpr::Coalesced`.
    default: Option<SqlValue>,
}
impl AggregateBuilder {
fn new(kind: AggregateExpr) -> Self {
Self {
kind,
filter: None,
default: None,
}
}
pub fn filter(mut self, predicate: impl Into<WhereExpr>) -> Self {
self.filter = Some(predicate.into());
self
}
pub fn default(mut self, value: impl Into<SqlValue>) -> Self {
self.default = Some(value.into());
self
}
#[must_use]
pub fn build(self) -> AggregateExpr {
let mut out = self.kind;
if let Some(f) = self.filter {
out = AggregateExpr::Filtered {
inner: Box::new(out),
filter: f,
};
}
if let Some(d) = self.default {
out = AggregateExpr::Coalesced {
inner: Box::new(out),
default: d,
};
}
out
}
}
impl From<AggregateBuilder> for AggregateExpr {
fn from(b: AggregateBuilder) -> Self {
b.build()
}
}
/// Starts a count over `column` ([`AggregateExpr::Count`] with `Some(column)`).
///
/// No `#[must_use]` is needed here: the returned [`AggregateBuilder`] type already
/// carries it, and repeating it trips `clippy::double_must_use`.
pub fn count(column: &'static str) -> AggregateBuilder {
    AggregateBuilder::new(AggregateExpr::Count(Some(column)))
}
/// Starts a column-less count ([`AggregateExpr::Count`] with `None`), intended as a
/// count of all rows.
pub fn count_all() -> AggregateBuilder {
    AggregateBuilder::new(AggregateExpr::Count(None))
}
/// Starts a distinct count over `column` ([`AggregateExpr::CountDistinct`]).
pub fn count_distinct(column: &'static str) -> AggregateBuilder {
    AggregateBuilder::new(AggregateExpr::CountDistinct(column))
}
/// Starts a sum over `column` ([`AggregateExpr::Sum`]).
pub fn sum(column: &'static str) -> AggregateBuilder {
    AggregateBuilder::new(AggregateExpr::Sum(column))
}
/// Starts an average over `column` ([`AggregateExpr::Avg`]).
pub fn avg(column: &'static str) -> AggregateBuilder {
    AggregateBuilder::new(AggregateExpr::Avg(column))
}
/// Starts a maximum over `column` ([`AggregateExpr::Max`]).
pub fn max(column: &'static str) -> AggregateBuilder {
    AggregateBuilder::new(AggregateExpr::Max(column))
}
/// Starts a minimum over `column` ([`AggregateExpr::Min`]).
pub fn min(column: &'static str) -> AggregateBuilder {
    AggregateBuilder::new(AggregateExpr::Min(column))
}
/// Starts a standard-deviation aggregate over `column` ([`AggregateExpr::StdDev`]).
///
/// See [`stddev_pop`] for the population form.
pub fn stddev(column: &'static str) -> AggregateBuilder {
    AggregateBuilder::new(AggregateExpr::StdDev(column))
}
/// Starts a population standard-deviation aggregate over `column`
/// ([`AggregateExpr::StdDevPop`]).
pub fn stddev_pop(column: &'static str) -> AggregateBuilder {
    AggregateBuilder::new(AggregateExpr::StdDevPop(column))
}
/// Starts a variance aggregate over `column` ([`AggregateExpr::Variance`]).
///
/// See [`variance_pop`] for the population form.
pub fn variance(column: &'static str) -> AggregateBuilder {
    AggregateBuilder::new(AggregateExpr::Variance(column))
}
/// Starts a population variance aggregate over `column` ([`AggregateExpr::VariancePop`]).
pub fn variance_pop(column: &'static str) -> AggregateBuilder {
    AggregateBuilder::new(AggregateExpr::VariancePop(column))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{Filter, Op};

    /// Builds a simple `col = true` predicate to use as an aggregate filter.
    fn predicate(col: &'static str) -> WhereExpr {
        WhereExpr::Predicate(Filter {
            column: col,
            op: Op::Eq,
            value: SqlValue::Bool(true),
        })
    }

    #[test]
    fn bare_count_lowers_to_count_variant() {
        let e: AggregateExpr = count("id").into();
        assert!(matches!(e, AggregateExpr::Count(Some("id"))));
    }

    #[test]
    fn count_all_emits_count_none() {
        let e: AggregateExpr = count_all().into();
        assert!(matches!(e, AggregateExpr::Count(None)));
    }

    #[test]
    fn single_column_constructors_lower_to_matching_variants() {
        assert!(matches!(
            count_distinct("x").build(),
            AggregateExpr::CountDistinct("x")
        ));
        assert!(matches!(sum("x").build(), AggregateExpr::Sum("x")));
        assert!(matches!(avg("x").build(), AggregateExpr::Avg("x")));
        assert!(matches!(max("x").build(), AggregateExpr::Max("x")));
        assert!(matches!(min("x").build(), AggregateExpr::Min("x")));
    }

    #[test]
    fn filter_wraps_in_filtered_variant() {
        let e: AggregateExpr = count("id").filter(predicate("active")).into();
        match e {
            AggregateExpr::Filtered { inner, .. } => {
                assert!(matches!(*inner, AggregateExpr::Count(Some("id"))));
            }
            _ => panic!("expected Filtered at the top"),
        }
    }

    #[test]
    fn repeated_filter_replaces_rather_than_nests() {
        let e: AggregateExpr = count("id")
            .filter(predicate("first"))
            .filter(predicate("second"))
            .into();
        match e {
            AggregateExpr::Filtered { inner, filter } => {
                // Only one Filtered layer, carrying the most recent predicate.
                assert!(matches!(*inner, AggregateExpr::Count(Some("id"))));
                assert!(matches!(
                    filter,
                    WhereExpr::Predicate(Filter {
                        column: "second",
                        ..
                    })
                ));
            }
            _ => panic!("expected a single Filtered wrapper"),
        }
    }

    #[test]
    fn default_wraps_in_coalesced_variant() {
        let e: AggregateExpr = sum("price").default(0_i64).into();
        match e {
            AggregateExpr::Coalesced { inner, .. } => {
                assert!(matches!(*inner, AggregateExpr::Sum("price")));
            }
            _ => panic!("expected Coalesced at the top"),
        }
    }

    #[test]
    fn filter_then_default_wraps_coalesced_outside_filtered() {
        let e: AggregateExpr = sum("price")
            .filter(predicate("active"))
            .default(0_i64)
            .into();
        match e {
            AggregateExpr::Coalesced { inner, .. } => match *inner {
                AggregateExpr::Filtered { inner, .. } => {
                    assert!(matches!(*inner, AggregateExpr::Sum("price")));
                }
                _ => panic!("expected Filtered inside Coalesced"),
            },
            _ => panic!("expected Coalesced at the top"),
        }
    }

    #[test]
    fn default_then_filter_still_wraps_coalesced_outside_filtered() {
        let e: AggregateExpr = sum("price")
            .default(0_i64)
            .filter(predicate("active"))
            .into();
        match e {
            AggregateExpr::Coalesced { inner, .. } => match *inner {
                AggregateExpr::Filtered { inner, .. } => {
                    assert!(matches!(*inner, AggregateExpr::Sum("price")));
                }
                _ => panic!("expected Filtered inside Coalesced"),
            },
            _ => panic!("expected Coalesced at the top"),
        }
    }

    #[test]
    fn stddev_family_lowers_to_dedicated_variants() {
        assert!(matches!(stddev("x").build(), AggregateExpr::StdDev("x")));
        assert!(matches!(
            stddev_pop("x").build(),
            AggregateExpr::StdDevPop("x")
        ));
        assert!(matches!(
            variance("x").build(),
            AggregateExpr::Variance("x")
        ));
        assert!(matches!(
            variance_pop("x").build(),
            AggregateExpr::VariancePop("x")
        ));
    }
}