Skip to main content

reifydb_routine/function/math/aggregate/
count.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use std::mem;
5
6use indexmap::IndexMap;
7use reifydb_core::value::column::data::ColumnData;
8use reifydb_type::value::{
9	Value,
10	r#type::{Type, input_types::InputTypes},
11};
12
13use crate::function::{AggregateFunction, AggregateFunctionContext, error::AggregateFunctionResult};
14
15pub struct Count {
16	pub counts: IndexMap<Vec<Value>, i64>,
17}
18
19impl Count {
20	pub fn new() -> Self {
21		Self {
22			counts: IndexMap::new(),
23		}
24	}
25}
26
27impl AggregateFunction for Count {
28	fn aggregate(&mut self, ctx: AggregateFunctionContext) -> AggregateFunctionResult<()> {
29		let column = ctx.column;
30		let groups = &ctx.groups;
31
32		// Check if this is count(*) by examining if we have a dummy column
33		let is_count_star = column.name.text() == "dummy" && matches!(column.data(), ColumnData::Int4(_));
34
35		if is_count_star {
36			// For count(*), count all rows including those with undefined values
37			for (group, indices) in groups.iter() {
38				let count = indices.len() as i64;
39				self.counts.insert(group.clone(), count);
40			}
41		} else {
42			// For count(column), only count defined (non-null) values
43			// is_defined handles both plain and Option-wrapped columns
44			for (group, indices) in groups.iter() {
45				let count = indices.iter().filter(|&i| column.data().is_defined(*i)).count() as i64;
46				self.counts.insert(group.clone(), count);
47			}
48		}
49		Ok(())
50	}
51
52	fn finalize(&mut self) -> AggregateFunctionResult<(Vec<Vec<Value>>, ColumnData)> {
53		let mut keys = Vec::with_capacity(self.counts.len());
54		let mut data = ColumnData::int8_with_capacity(self.counts.len());
55
56		for (key, count) in mem::take(&mut self.counts) {
57			keys.push(key);
58			data.push_value(Value::Int8(count));
59		}
60
61		Ok((keys, data))
62	}
63
64	fn return_type(&self, _input_type: &Type) -> Type {
65		Type::Int8
66	}
67
68	fn accepted_types(&self) -> InputTypes {
69		InputTypes::any()
70	}
71}