Skip to main content

reifydb_function/math/aggregate/
count.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// Copyright (c) 2025 ReifyDB
3
4use std::mem;
5
6use indexmap::IndexMap;
7use reifydb_core::value::column::data::ColumnData;
8use reifydb_type::value::Value;
9
10use crate::{AggregateFunction, AggregateFunctionContext, error::AggregateFunctionResult};
11
/// State for the `count` aggregate function: one row count per group.
pub struct Count {
	/// Count recorded for each group, keyed by the group's key values.
	pub counts: IndexMap<Vec<Value>, i64>,
}
15
16impl Count {
17	pub fn new() -> Self {
18		Self {
19			counts: IndexMap::new(),
20		}
21	}
22}
23
24impl AggregateFunction for Count {
25	fn aggregate(&mut self, ctx: AggregateFunctionContext) -> AggregateFunctionResult<()> {
26		let column = ctx.column;
27		let groups = &ctx.groups;
28
29		// Check if this is count(*) by examining if we have a dummy column
30		let is_count_star = column.name.text() == "dummy" && matches!(column.data(), ColumnData::Int4(_));
31
32		if is_count_star {
33			// For count(*), count all rows including those with undefined values
34			for (group, indices) in groups.iter() {
35				let count = indices.len() as i64;
36				self.counts.insert(group.clone(), count);
37			}
38		} else {
39			// For count(column), only count defined (non-null) values
40			match &column.data() {
41				ColumnData::Bool(container) => {
42					for (group, indices) in groups.iter() {
43						let count = indices.iter().filter(|&i| container.is_defined(*i)).count()
44							as i64;
45						self.counts.insert(group.clone(), count);
46					}
47				}
48				ColumnData::Float8(container) => {
49					for (group, indices) in groups.iter() {
50						let count = indices.iter().filter(|&i| container.is_defined(*i)).count()
51							as i64;
52						self.counts.insert(group.clone(), count);
53					}
54				}
55				ColumnData::Float4(container) => {
56					for (group, indices) in groups.iter() {
57						let count = indices.iter().filter(|&i| container.is_defined(*i)).count()
58							as i64;
59						self.counts.insert(group.clone(), count);
60					}
61				}
62				ColumnData::Int4(container) => {
63					for (group, indices) in groups.iter() {
64						let count = indices.iter().filter(|&i| container.is_defined(*i)).count()
65							as i64;
66						self.counts.insert(group.clone(), count);
67					}
68				}
69				ColumnData::Int8(container) => {
70					for (group, indices) in groups.iter() {
71						let count = indices.iter().filter(|&i| container.is_defined(*i)).count()
72							as i64;
73						self.counts.insert(group.clone(), count);
74					}
75				}
76				ColumnData::Int2(container) => {
77					for (group, indices) in groups.iter() {
78						let count = indices.iter().filter(|&i| container.is_defined(*i)).count()
79							as i64;
80						self.counts.insert(group.clone(), count);
81					}
82				}
83				ColumnData::Int1(container) => {
84					for (group, indices) in groups.iter() {
85						let count = indices.iter().filter(|&i| container.is_defined(*i)).count()
86							as i64;
87						self.counts.insert(group.clone(), count);
88					}
89				}
90				ColumnData::Int16(container) => {
91					for (group, indices) in groups.iter() {
92						let count = indices.iter().filter(|&i| container.is_defined(*i)).count()
93							as i64;
94						self.counts.insert(group.clone(), count);
95					}
96				}
97				ColumnData::Utf8 {
98					container,
99					..
100				} => {
101					for (group, indices) in groups.iter() {
102						let count = indices.iter().filter(|&i| container.is_defined(*i)).count()
103							as i64;
104						self.counts.insert(group.clone(), count);
105					}
106				}
107				ColumnData::Date(container) => {
108					for (group, indices) in groups.iter() {
109						let count = indices.iter().filter(|&i| container.is_defined(*i)).count()
110							as i64;
111						self.counts.insert(group.clone(), count);
112					}
113				}
114				ColumnData::DateTime(container) => {
115					for (group, indices) in groups.iter() {
116						let count = indices.iter().filter(|&i| container.is_defined(*i)).count()
117							as i64;
118						self.counts.insert(group.clone(), count);
119					}
120				}
121				_ => {
122					// For other column types, use generic is_defined check
123					for (group, indices) in groups.iter() {
124						let count = indices
125							.iter()
126							.filter(|&i| column.data().is_defined(*i))
127							.count() as i64;
128						self.counts.insert(group.clone(), count);
129					}
130				}
131			}
132		}
133		Ok(())
134	}
135
136	fn finalize(&mut self) -> AggregateFunctionResult<(Vec<Vec<Value>>, ColumnData)> {
137		let mut keys = Vec::with_capacity(self.counts.len());
138		let mut data = ColumnData::int8_with_capacity(self.counts.len());
139
140		for (key, count) in mem::take(&mut self.counts) {
141			keys.push(key);
142			data.push_value(Value::Int8(count));
143		}
144
145		Ok((keys, data))
146	}
147}