use crate::core::Value;
use crate::functions::AggregateFunction;
use super::DistinctTracker;
/// Running total kept by the `SUM` accumulators.
///
/// Starts out `Empty` so that a `SUM` over no numeric input can be told apart
/// from a genuine total of zero; it becomes `Integer` or `Float` once the first
/// numeric value is added.
#[derive(Default, Clone, Debug)]
pub enum SumState {
    /// Nothing numeric has been added yet.
    #[default]
    Empty,
    /// Total is held as an exact 64-bit integer.
    Integer(i64),
    /// Total is held as a double-precision float.
    Float(f64),
}
/// An aggregate function specialised into a concrete accumulator.
///
/// Built-in aggregates get a dedicated variant holding their state inline;
/// anything else is carried as a boxed [`AggregateFunction`] in `Dynamic`.
pub enum CompiledAggregate {
    /// `COUNT(*)`: number of accumulated values.
    CountStar {
        count: i64,
    },
    /// `COUNT(expr)`: number of non-NULL values.
    Count {
        count: i64,
    },
    /// `COUNT(DISTINCT expr)`: distinct non-NULL values, tracked by the tracker.
    CountDistinct {
        distinct_tracker: DistinctTracker,
    },
    /// `SUM(expr)`.
    Sum {
        state: SumState,
    },
    /// `SUM(DISTINCT expr)`: only values accepted by the tracker are summed.
    SumDistinct {
        state: SumState,
        distinct_tracker: DistinctTracker,
    },
    /// `AVG(expr)`: floating-point running sum plus the number of summed inputs.
    Avg {
        sum: f64,
        count: i64,
    },
    /// `AVG(DISTINCT expr)`.
    AvgDistinct {
        sum: f64,
        count: i64,
        distinct_tracker: DistinctTracker,
    },
    /// `MIN(expr)` over arbitrary values.
    Min {
        min_value: Option<Value>,
    },
    /// `MAX(expr)` over arbitrary values.
    Max {
        max_value: Option<Value>,
    },
    /// `MIN` fast path that only considers integer inputs.
    MinInteger {
        min_value: Option<i64>,
    },
    /// `MAX` fast path that only considers integer inputs.
    MaxInteger {
        max_value: Option<i64>,
    },
    /// `MIN` fast path that only considers float inputs.
    MinFloat {
        min_value: Option<f64>,
    },
    /// `MAX` fast path that only considers float inputs.
    MaxFloat {
        max_value: Option<f64>,
    },
    /// Any other aggregate, dispatched dynamically.
    Dynamic(Box<dyn AggregateFunction>),
}
impl std::fmt::Debug for CompiledAggregate {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
CompiledAggregate::CountStar { count } => {
f.debug_struct("CountStar").field("count", count).finish()
}
CompiledAggregate::Count { count } => {
f.debug_struct("Count").field("count", count).finish()
}
CompiledAggregate::CountDistinct { distinct_tracker } => f
.debug_struct("CountDistinct")
.field("distinct_tracker", distinct_tracker)
.finish(),
CompiledAggregate::Sum { state } => {
f.debug_struct("Sum").field("state", state).finish()
}
CompiledAggregate::SumDistinct {
state,
distinct_tracker,
} => f
.debug_struct("SumDistinct")
.field("state", state)
.field("distinct_tracker", distinct_tracker)
.finish(),
CompiledAggregate::Avg { sum, count } => f
.debug_struct("Avg")
.field("sum", sum)
.field("count", count)
.finish(),
CompiledAggregate::AvgDistinct {
sum,
count,
distinct_tracker,
} => f
.debug_struct("AvgDistinct")
.field("sum", sum)
.field("count", count)
.field("distinct_tracker", distinct_tracker)
.finish(),
CompiledAggregate::Min { min_value } => {
f.debug_struct("Min").field("min_value", min_value).finish()
}
CompiledAggregate::Max { max_value } => {
f.debug_struct("Max").field("max_value", max_value).finish()
}
CompiledAggregate::MinInteger { min_value } => f
.debug_struct("MinInteger")
.field("min_value", min_value)
.finish(),
CompiledAggregate::MaxInteger { max_value } => f
.debug_struct("MaxInteger")
.field("max_value", max_value)
.finish(),
CompiledAggregate::MinFloat { min_value } => f
.debug_struct("MinFloat")
.field("min_value", min_value)
.finish(),
CompiledAggregate::MaxFloat { max_value } => f
.debug_struct("MaxFloat")
.field("max_value", max_value)
.finish(),
CompiledAggregate::Dynamic(func) => {
f.debug_tuple("Dynamic").field(&func.name()).finish()
}
}
}
}
impl CompiledAggregate {
    /// Builds a `COUNT(*)` accumulator: every accumulated value is counted,
    /// NULLs included.
    pub fn count_star() -> Self {
        CompiledAggregate::CountStar { count: 0 }
    }

    /// Builds a `COUNT(expr)` accumulator that ignores NULLs.
    ///
    /// With `distinct`, each non-NULL value is counted once, as decided by a
    /// [`DistinctTracker`].
    pub fn count(distinct: bool) -> Self {
        if distinct {
            CompiledAggregate::CountDistinct {
                distinct_tracker: DistinctTracker::default(),
            }
        } else {
            CompiledAggregate::Count { count: 0 }
        }
    }

    /// Builds a `SUM(expr)` accumulator, or `SUM(DISTINCT expr)` when
    /// `distinct` is true.
    ///
    /// NULL and non-numeric inputs are skipped. The total stays an exact `i64`
    /// while only integers are seen and the sum fits. A float input, or an
    /// addition that would overflow `i64`, switches the total to `f64`.
    pub fn sum(distinct: bool) -> Self {
        if distinct {
            CompiledAggregate::SumDistinct {
                state: SumState::Empty,
                distinct_tracker: DistinctTracker::default(),
            }
        } else {
            CompiledAggregate::Sum {
                state: SumState::Empty,
            }
        }
    }

    /// Builds an `AVG(expr)` accumulator, or `AVG(DISTINCT expr)` when
    /// `distinct` is true.
    ///
    /// Numeric inputs are summed as `f64`; NULL and non-numeric inputs are
    /// skipped and do not contribute to the count.
    pub fn avg(distinct: bool) -> Self {
        if distinct {
            CompiledAggregate::AvgDistinct {
                sum: 0.0,
                count: 0,
                distinct_tracker: DistinctTracker::default(),
            }
        } else {
            CompiledAggregate::Avg { sum: 0.0, count: 0 }
        }
    }

    /// Builds a generic `MIN(expr)` accumulator over arbitrary values.
    pub fn min() -> Self {
        CompiledAggregate::Min { min_value: None }
    }

    /// Builds a `MIN` fast path that only looks at `Value::Integer` inputs.
    pub fn min_integer() -> Self {
        CompiledAggregate::MinInteger { min_value: None }
    }

    /// Builds a `MIN` fast path that only looks at `Value::Float` inputs.
    pub fn min_float() -> Self {
        CompiledAggregate::MinFloat { min_value: None }
    }

    /// Builds a generic `MAX(expr)` accumulator over arbitrary values.
    pub fn max() -> Self {
        CompiledAggregate::Max { max_value: None }
    }

    /// Builds a `MAX` fast path that only looks at `Value::Integer` inputs.
    pub fn max_integer() -> Self {
        CompiledAggregate::MaxInteger { max_value: None }
    }

    /// Builds a `MAX` fast path that only looks at `Value::Float` inputs.
    pub fn max_float() -> Self {
        CompiledAggregate::MaxFloat { max_value: None }
    }

    /// Wraps an arbitrary aggregate function for dynamic dispatch.
    pub fn dynamic(func: Box<dyn AggregateFunction>) -> Self {
        CompiledAggregate::Dynamic(func)
    }

    /// Picks a specialised accumulator for the aggregate called `name`.
    ///
    /// The name is matched case-insensitively against `COUNT`, `SUM`, `AVG`,
    /// `MIN` and `MAX`. `is_count_star` is only consulted for `COUNT`, and
    /// `distinct` only for `COUNT`/`SUM`/`AVG`. Any other name falls back to
    /// `dynamic_fallback`, so `None` is returned when the name is unknown and
    /// no fallback was supplied.
    pub fn compile(
        name: &str,
        is_count_star: bool,
        distinct: bool,
        dynamic_fallback: Option<Box<dyn AggregateFunction>>,
    ) -> Option<Self> {
        if name.eq_ignore_ascii_case("COUNT") {
            if is_count_star {
                Some(CompiledAggregate::count_star())
            } else {
                Some(CompiledAggregate::count(distinct))
            }
        } else if name.eq_ignore_ascii_case("SUM") {
            Some(CompiledAggregate::sum(distinct))
        } else if name.eq_ignore_ascii_case("AVG") {
            Some(CompiledAggregate::avg(distinct))
        } else if name.eq_ignore_ascii_case("MIN") {
            Some(CompiledAggregate::min())
        } else if name.eq_ignore_ascii_case("MAX") {
            Some(CompiledAggregate::max())
        } else {
            dynamic_fallback.map(CompiledAggregate::Dynamic)
        }
    }

    /// Feeds one input value into the accumulator.
    ///
    /// Every variant except `CountStar` ignores NULLs. The typed `Min*`/`Max*`
    /// fast paths ignore any value that is not of their type. `Dynamic`
    /// aggregates are always fed with `distinct = false`; use
    /// [`accumulate_with_distinct`](Self::accumulate_with_distinct) to forward
    /// a DISTINCT flag to them.
    #[inline(always)]
    pub fn accumulate(&mut self, value: &Value) {
        match self {
            CompiledAggregate::CountStar { count } => {
                *count += 1;
            }
            CompiledAggregate::Count { count } => {
                if !value.is_null() {
                    *count += 1;
                }
            }
            CompiledAggregate::CountDistinct { distinct_tracker } => {
                if !value.is_null() {
                    distinct_tracker.check_and_add(value);
                }
            }
            CompiledAggregate::Sum { state } => {
                if !value.is_null() {
                    Self::accumulate_sum(state, value);
                }
            }
            CompiledAggregate::SumDistinct {
                state,
                distinct_tracker,
            } => {
                // Only values the tracker reports as newly seen contribute.
                if !value.is_null() && distinct_tracker.check_and_add(value) {
                    Self::accumulate_sum(state, value);
                }
            }
            CompiledAggregate::Avg { sum, count } => {
                if let Some(n) = Self::as_f64(value) {
                    *sum += n;
                    *count += 1;
                }
            }
            CompiledAggregate::AvgDistinct {
                sum,
                count,
                distinct_tracker,
            } => {
                // Distinctness is checked before the numeric conversion, so
                // non-numeric values are still registered with the tracker.
                if !value.is_null() && distinct_tracker.check_and_add(value) {
                    if let Some(n) = Self::as_f64(value) {
                        *sum += n;
                        *count += 1;
                    }
                }
            }
            CompiledAggregate::Min { min_value } => {
                if !value.is_null() {
                    match min_value {
                        None => *min_value = Some(value.clone()),
                        Some(current) => {
                            if Self::is_less_than(value, current) {
                                *min_value = Some(value.clone());
                            }
                        }
                    }
                }
            }
            CompiledAggregate::Max { max_value } => {
                if !value.is_null() {
                    match max_value {
                        None => *max_value = Some(value.clone()),
                        Some(current) => {
                            if Self::is_greater_than(value, current) {
                                *max_value = Some(value.clone());
                            }
                        }
                    }
                }
            }
            CompiledAggregate::MinInteger { min_value } => {
                if let Value::Integer(v) = value {
                    match min_value {
                        None => *min_value = Some(*v),
                        Some(current) if *v < *current => *min_value = Some(*v),
                        _ => {}
                    }
                }
            }
            CompiledAggregate::MaxInteger { max_value } => {
                if let Value::Integer(v) = value {
                    match max_value {
                        None => *max_value = Some(*v),
                        Some(current) if *v > *current => *max_value = Some(*v),
                        _ => {}
                    }
                }
            }
            CompiledAggregate::MinFloat { min_value } => {
                // NOTE: every comparison involving NaN is false. A NaN stored
                // first is therefore never replaced, and a later NaN never
                // replaces an existing value.
                if let Value::Float(v) = value {
                    match min_value {
                        None => *min_value = Some(*v),
                        Some(current) if *v < *current => *min_value = Some(*v),
                        _ => {}
                    }
                }
            }
            CompiledAggregate::MaxFloat { max_value } => {
                // Same NaN caveat as `MinFloat`.
                if let Value::Float(v) = value {
                    match max_value {
                        None => *max_value = Some(*v),
                        Some(current) if *v > *current => *max_value = Some(*v),
                        _ => {}
                    }
                }
            }
            CompiledAggregate::Dynamic(func) => {
                func.accumulate(value, false);
            }
        }
    }

    /// Like [`accumulate`](Self::accumulate), but forwards `distinct` to
    /// `Dynamic` aggregates.
    ///
    /// Built-in variants fix their distinctness at construction time, so they
    /// ignore the flag.
    #[inline(always)]
    pub fn accumulate_with_distinct(&mut self, value: &Value, distinct: bool) {
        if let CompiledAggregate::Dynamic(func) = self {
            func.accumulate(value, distinct);
        } else {
            self.accumulate(value);
        }
    }

    /// Returns the aggregate's current result without consuming its state.
    ///
    /// The COUNT variants return an integer count (0 when nothing was
    /// counted). `SUM`, `AVG`, `MIN` and `MAX` return NULL when no qualifying
    /// input was accumulated. `AVG` is always a `Value::Float`.
    #[inline]
    pub fn result(&self) -> Value {
        match self {
            CompiledAggregate::CountStar { count } => Value::Integer(*count),
            CompiledAggregate::Count { count } => Value::Integer(*count),
            CompiledAggregate::CountDistinct { distinct_tracker } => {
                Value::Integer(distinct_tracker.count() as i64)
            }
            CompiledAggregate::Sum { state } | CompiledAggregate::SumDistinct { state, .. } => {
                match state {
                    SumState::Empty => Value::null_unknown(),
                    SumState::Integer(sum) => Value::Integer(*sum),
                    SumState::Float(sum) => Value::Float(*sum),
                }
            }
            CompiledAggregate::Avg { sum, count }
            | CompiledAggregate::AvgDistinct { sum, count, .. } => {
                if *count == 0 {
                    Value::null_unknown()
                } else {
                    Value::Float(*sum / *count as f64)
                }
            }
            CompiledAggregate::Min { min_value } => {
                min_value.clone().unwrap_or_else(Value::null_unknown)
            }
            CompiledAggregate::Max { max_value } => {
                max_value.clone().unwrap_or_else(Value::null_unknown)
            }
            CompiledAggregate::MinInteger { min_value } => min_value
                .map(Value::Integer)
                .unwrap_or_else(Value::null_unknown),
            CompiledAggregate::MaxInteger { max_value } => max_value
                .map(Value::Integer)
                .unwrap_or_else(Value::null_unknown),
            CompiledAggregate::MinFloat { min_value } => min_value
                .map(Value::Float)
                .unwrap_or_else(Value::null_unknown),
            CompiledAggregate::MaxFloat { max_value } => max_value
                .map(Value::Float)
                .unwrap_or_else(Value::null_unknown),
            CompiledAggregate::Dynamic(func) => func.result(),
        }
    }

    /// Clears all accumulated state so the aggregate can be reused, for
    /// example for the next group.
    pub fn reset(&mut self) {
        match self {
            CompiledAggregate::CountStar { count } => *count = 0,
            CompiledAggregate::Count { count } => *count = 0,
            CompiledAggregate::CountDistinct { distinct_tracker } => distinct_tracker.reset(),
            CompiledAggregate::Sum { state } => *state = SumState::Empty,
            CompiledAggregate::SumDistinct {
                state,
                distinct_tracker,
            } => {
                *state = SumState::Empty;
                distinct_tracker.reset();
            }
            CompiledAggregate::Avg { sum, count } => {
                *sum = 0.0;
                *count = 0;
            }
            CompiledAggregate::AvgDistinct {
                sum,
                count,
                distinct_tracker,
            } => {
                *sum = 0.0;
                *count = 0;
                distinct_tracker.reset();
            }
            CompiledAggregate::Min { min_value } => *min_value = None,
            CompiledAggregate::Max { max_value } => *max_value = None,
            CompiledAggregate::MinInteger { min_value } => *min_value = None,
            CompiledAggregate::MaxInteger { max_value } => *max_value = None,
            CompiledAggregate::MinFloat { min_value } => *min_value = None,
            CompiledAggregate::MaxFloat { max_value } => *max_value = None,
            CompiledAggregate::Dynamic(func) => func.reset(),
        }
    }

    /// Adds a numeric `value` to the running `SUM` total; other values are ignored.
    ///
    /// Integer + integer stays exact as long as it fits in `i64`. If the
    /// addition would overflow, the total is promoted to `f64` instead.
    /// Without this check the addition panics in debug builds and silently
    /// wraps in release builds. Any float input also promotes the total to `f64`.
    #[inline(always)]
    fn accumulate_sum(state: &mut SumState, value: &Value) {
        match value {
            Value::Integer(i) => match state {
                SumState::Empty => *state = SumState::Integer(*i),
                SumState::Integer(sum) => match sum.checked_add(*i) {
                    Some(total) => *sum = total,
                    None => *state = SumState::Float(*sum as f64 + *i as f64),
                },
                SumState::Float(sum) => *sum += *i as f64,
            },
            Value::Float(f) => match state {
                SumState::Empty => *state = SumState::Float(*f),
                SumState::Integer(sum) => {
                    *state = SumState::Float(*sum as f64 + f);
                }
                SumState::Float(sum) => *sum += f,
            },
            _ => {}
        }
    }

    /// Converts integer and float values to `f64`; anything else yields `None`.
    #[inline(always)]
    fn as_f64(value: &Value) -> Option<f64> {
        match value {
            Value::Integer(i) => Some(*i as f64),
            Value::Float(f) => Some(*f),
            _ => None,
        }
    }

    /// Strict "a < b" used by `MIN`.
    ///
    /// Integers and floats compare numerically with each other, and
    /// `false < true` for booleans. NULLs and incomparable type pairs yield
    /// `false`, so such values never replace the current minimum.
    #[inline(always)]
    fn is_less_than(a: &Value, b: &Value) -> bool {
        match (a, b) {
            (Value::Null(_), _) | (_, Value::Null(_)) => false,
            (Value::Integer(a), Value::Integer(b)) => a < b,
            (Value::Float(a), Value::Float(b)) => a < b,
            (Value::Integer(a), Value::Float(b)) => (*a as f64) < *b,
            (Value::Float(a), Value::Integer(b)) => *a < (*b as f64),
            (Value::Text(a), Value::Text(b)) => a < b,
            (Value::Boolean(a), Value::Boolean(b)) => !a && *b,
            (Value::Timestamp(a), Value::Timestamp(b)) => a < b,
            _ => false,
        }
    }

    /// Strict "a > b" used by `MAX`; the mirror image of `is_less_than`.
    #[inline(always)]
    fn is_greater_than(a: &Value, b: &Value) -> bool {
        match (a, b) {
            (Value::Null(_), _) | (_, Value::Null(_)) => false,
            (Value::Integer(a), Value::Integer(b)) => a > b,
            (Value::Float(a), Value::Float(b)) => a > b,
            (Value::Integer(a), Value::Float(b)) => (*a as f64) > *b,
            (Value::Float(a), Value::Integer(b)) => *a > (*b as f64),
            (Value::Text(a), Value::Text(b)) => a > b,
            (Value::Boolean(a), Value::Boolean(b)) => *a && !b,
            (Value::Timestamp(a), Value::Timestamp(b)) => a > b,
            _ => false,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds `values` into `agg` in order and hands the aggregate back.
    fn fed(mut agg: CompiledAggregate, values: &[Value]) -> CompiledAggregate {
        values.iter().for_each(|v| agg.accumulate(v));
        agg
    }

    #[test]
    fn test_count_star() {
        // COUNT(*) counts the NULL row too.
        let agg = fed(
            CompiledAggregate::count_star(),
            &[Value::Integer(1), Value::null_unknown(), Value::Integer(3)],
        );
        assert_eq!(agg.result(), Value::Integer(3));
    }

    #[test]
    fn test_count() {
        // COUNT(expr) skips the NULL row.
        let agg = fed(
            CompiledAggregate::count(false),
            &[Value::Integer(1), Value::null_unknown(), Value::Integer(3)],
        );
        assert_eq!(agg.result(), Value::Integer(2));
    }

    #[test]
    fn test_count_distinct() {
        // Duplicate 1 and the NULL are both excluded.
        let agg = fed(
            CompiledAggregate::count(true),
            &[
                Value::Integer(1),
                Value::Integer(1),
                Value::Integer(2),
                Value::null_unknown(),
            ],
        );
        assert_eq!(agg.result(), Value::Integer(2));
    }

    #[test]
    fn test_sum_integers() {
        let agg = fed(
            CompiledAggregate::sum(false),
            &[Value::Integer(1), Value::Integer(2), Value::Integer(3)],
        );
        assert_eq!(agg.result(), Value::Integer(6));
    }

    #[test]
    fn test_sum_floats() {
        let agg = fed(
            CompiledAggregate::sum(false),
            &[Value::Float(1.5), Value::Float(2.5)],
        );
        assert_eq!(agg.result(), Value::Float(4.0));
    }

    #[test]
    fn test_sum_mixed() {
        // An integer followed by a float promotes the total to a float.
        let agg = fed(
            CompiledAggregate::sum(false),
            &[Value::Integer(1), Value::Float(2.5)],
        );
        assert_eq!(agg.result(), Value::Float(3.5));
    }

    #[test]
    fn test_sum_distinct() {
        // The repeated 1 is summed once: 1 + 2.
        let agg = fed(
            CompiledAggregate::sum(true),
            &[Value::Integer(1), Value::Integer(1), Value::Integer(2)],
        );
        assert_eq!(agg.result(), Value::Integer(3));
    }

    #[test]
    fn test_avg() {
        let agg = fed(
            CompiledAggregate::avg(false),
            &[Value::Integer(1), Value::Integer(2), Value::Integer(3)],
        );
        assert_eq!(agg.result(), Value::Float(2.0));
    }

    #[test]
    fn test_avg_distinct() {
        // Distinct inputs are 1 and 3, so the mean is 2.
        let agg = fed(
            CompiledAggregate::avg(true),
            &[Value::Integer(1), Value::Integer(1), Value::Integer(3)],
        );
        assert_eq!(agg.result(), Value::Float(2.0));
    }

    #[test]
    fn test_min_integers() {
        let agg = fed(
            CompiledAggregate::min(),
            &[Value::Integer(5), Value::Integer(2), Value::Integer(8)],
        );
        assert_eq!(agg.result(), Value::Integer(2));
    }

    #[test]
    fn test_max_integers() {
        let agg = fed(
            CompiledAggregate::max(),
            &[Value::Integer(5), Value::Integer(2), Value::Integer(8)],
        );
        assert_eq!(agg.result(), Value::Integer(8));
    }

    #[test]
    fn test_min_integer_fast() {
        let agg = fed(
            CompiledAggregate::min_integer(),
            &[Value::Integer(5), Value::Integer(2), Value::Integer(8)],
        );
        assert_eq!(agg.result(), Value::Integer(2));
    }

    #[test]
    fn test_max_integer_fast() {
        let agg = fed(
            CompiledAggregate::max_integer(),
            &[Value::Integer(5), Value::Integer(2), Value::Integer(8)],
        );
        assert_eq!(agg.result(), Value::Integer(8));
    }

    #[test]
    fn test_min_strings() {
        let agg = fed(
            CompiledAggregate::min(),
            &[
                Value::text("banana"),
                Value::text("apple"),
                Value::text("cherry"),
            ],
        );
        assert_eq!(agg.result(), Value::text("apple"));
    }

    #[test]
    fn test_max_strings() {
        let agg = fed(
            CompiledAggregate::max(),
            &[
                Value::text("banana"),
                Value::text("apple"),
                Value::text("cherry"),
            ],
        );
        assert_eq!(agg.result(), Value::text("cherry"));
    }

    #[test]
    fn test_empty_aggregates() {
        // With no input, value aggregates yield NULL while counts yield 0.
        for agg in [
            CompiledAggregate::sum(false),
            CompiledAggregate::avg(false),
            CompiledAggregate::min(),
            CompiledAggregate::max(),
        ] {
            assert!(agg.result().is_null());
        }
        for agg in [CompiledAggregate::count(false), CompiledAggregate::count_star()] {
            assert_eq!(agg.result(), Value::Integer(0));
        }
    }

    #[test]
    fn test_reset() {
        let mut summed = fed(CompiledAggregate::sum(false), &[Value::Integer(10)]);
        summed.reset();
        assert!(summed.result().is_null());

        let mut counted = fed(CompiledAggregate::count(false), &[Value::Integer(10)]);
        counted.reset();
        assert_eq!(counted.result(), Value::Integer(0));
    }

    #[test]
    fn test_compile() {
        let build = |name: &str, star: bool, distinct: bool| {
            CompiledAggregate::compile(name, star, distinct, None)
        };
        assert!(matches!(build("count", true, false), Some(CompiledAggregate::CountStar { .. })));
        assert!(matches!(build("COUNT", false, false), Some(CompiledAggregate::Count { .. })));
        assert!(matches!(build("sum", false, true), Some(CompiledAggregate::SumDistinct { .. })));
        assert!(matches!(build("avg", false, false), Some(CompiledAggregate::Avg { .. })));
        assert!(matches!(build("min", false, false), Some(CompiledAggregate::Min { .. })));
        assert!(matches!(build("max", false, false), Some(CompiledAggregate::Max { .. })));
        assert!(build("unknown", false, false).is_none());
    }
}