use crate::core::Value;
use crate::functions::{
AggregateFunction, FunctionDataType, FunctionInfo, FunctionSignature, FunctionType,
};
use super::DistinctTracker;
/// `COUNT` aggregate: counts non-NULL input values, or the number of distinct
/// non-NULL values when values are accumulated with `distinct = true`.
///
/// For `COUNT(*)`, callers appear to pass the text value `"*"` once per row
/// (see `test_count_star`); being non-NULL, it is counted like any other value.
#[derive(Default)]
pub struct CountFunction {
    /// Number of non-NULL values accumulated in non-distinct mode.
    count: i64,
    /// Created on the first non-NULL value accumulated with `distinct = true`.
    /// When present, `result` reports its count instead of `count`.
    distinct_tracker: Option<DistinctTracker>,
}

impl AggregateFunction for CountFunction {
    fn name(&self) -> &str {
        "COUNT"
    }

    /// Describes `COUNT` as an aggregate that returns `Integer` and accepts
    /// `Any`-typed arguments.
    fn info(&self) -> FunctionInfo {
        FunctionInfo::new(
            "COUNT",
            FunctionType::Aggregate,
            "Returns the number of rows matching the query criteria",
            FunctionSignature::new(
                FunctionDataType::Integer,
                vec![FunctionDataType::Any],
                0, // presumably the minimum argument count — confirm against `FunctionSignature::new`
                1, // presumably the maximum argument count — confirm against `FunctionSignature::new`
            ),
        )
    }

    /// Adds one input value to the running count.
    ///
    /// NULL values are ignored. With `distinct = true` the value is recorded in
    /// the distinct tracker (created on first use); otherwise the plain counter
    /// is incremented.
    ///
    /// The `"*"` marker used for `COUNT(*)` needs no special case. In
    /// non-distinct mode it is just another non-NULL value. In distinct mode,
    /// a text `"*"` coming from a real column must be de-duplicated like any
    /// other value. Previously it bypassed the tracker and was silently
    /// dropped from the distinct result.
    fn accumulate(&mut self, value: &Value, distinct: bool) {
        if value.is_null() {
            return;
        }
        if distinct {
            self.distinct_tracker
                .get_or_insert_with(DistinctTracker::default)
                .check_and_add(value);
        } else {
            self.count += 1;
        }
    }

    /// Returns the distinct count if any value was accumulated in distinct
    /// mode; otherwise returns the plain non-NULL count (0 if nothing counted).
    fn result(&self) -> Value {
        let n = match &self.distinct_tracker {
            Some(tracker) => tracker.count() as i64,
            None => self.count,
        };
        Value::Integer(n)
    }

    /// Clears all accumulated state, returning to the freshly-constructed state.
    fn reset(&mut self) {
        *self = Self::default();
    }

    /// Returns a new, empty `COUNT` accumulator; accumulated state is NOT copied.
    fn clone_box(&self) -> Box<dyn AggregateFunction> {
        Box::new(Self::default())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_count_basic() {
        let mut count = CountFunction::default();
        count.accumulate(&Value::Integer(1), false);
        count.accumulate(&Value::Integer(2), false);
        count.accumulate(&Value::Integer(3), false);
        assert_eq!(count.result(), Value::Integer(3));
    }

    #[test]
    fn test_count_star() {
        let mut count = CountFunction::default();
        count.accumulate(&Value::text("*"), false);
        count.accumulate(&Value::text("*"), false);
        count.accumulate(&Value::text("*"), false);
        assert_eq!(count.result(), Value::Integer(3));
    }

    #[test]
    fn test_count_ignores_null() {
        let mut count = CountFunction::default();
        count.accumulate(&Value::Integer(1), false);
        count.accumulate(&Value::null_unknown(), false);
        count.accumulate(&Value::Integer(3), false);
        assert_eq!(count.result(), Value::Integer(2));
    }

    #[test]
    fn test_count_distinct() {
        let mut count = CountFunction::default();
        count.accumulate(&Value::Integer(1), true);
        count.accumulate(&Value::Integer(1), true);
        count.accumulate(&Value::Integer(2), true);
        count.accumulate(&Value::Integer(2), true);
        count.accumulate(&Value::Integer(3), true);
        assert_eq!(count.result(), Value::Integer(3));
    }

    #[test]
    fn test_count_distinct_ignores_null() {
        let mut count = CountFunction::default();
        count.accumulate(&Value::Integer(1), true);
        count.accumulate(&Value::null_unknown(), true);
        count.accumulate(&Value::Integer(1), true);
        count.accumulate(&Value::null_unknown(), true);
        assert_eq!(count.result(), Value::Integer(1));
    }

    #[test]
    fn test_count_reset() {
        let mut count = CountFunction::default();
        count.accumulate(&Value::Integer(1), false);
        count.accumulate(&Value::Integer(2), false);
        count.reset();
        assert_eq!(count.result(), Value::Integer(0));
    }

    #[test]
    fn test_count_reset_after_distinct() {
        let mut count = CountFunction::default();
        count.accumulate(&Value::Integer(1), true);
        count.accumulate(&Value::Integer(2), true);
        count.reset();
        assert_eq!(count.result(), Value::Integer(0));
        // After a reset the accumulator must behave like a fresh one.
        count.accumulate(&Value::Integer(7), false);
        assert_eq!(count.result(), Value::Integer(1));
    }

    #[test]
    fn test_count_clone_box_is_empty() {
        let mut count = CountFunction::default();
        count.accumulate(&Value::Integer(1), false);
        let fresh = count.clone_box();
        assert_eq!(fresh.result(), Value::Integer(0));
        assert_eq!(count.result(), Value::Integer(1));
    }

    #[test]
    fn test_count_empty() {
        let count = CountFunction::default();
        assert_eq!(count.result(), Value::Integer(0));
    }
}