Skip to main content

reifydb_function/math/aggregate/
count.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// Copyright (c) 2025 ReifyDB
3
4use indexmap::IndexMap;
5use reifydb_core::value::column::data::ColumnData;
6use reifydb_type::value::Value;
7
8use crate::{AggregateFunction, AggregateFunctionContext};
9
10pub struct Count {
11	pub counts: IndexMap<Vec<Value>, i64>,
12}
13
14impl Count {
15	pub fn new() -> Self {
16		Self {
17			counts: IndexMap::new(),
18		}
19	}
20}
21
22impl AggregateFunction for Count {
23	fn aggregate(&mut self, ctx: AggregateFunctionContext) -> crate::error::AggregateFunctionResult<()> {
24		let column = ctx.column;
25		let groups = &ctx.groups;
26
27		// Check if this is count(*) by examining if we have a dummy column
28		let is_count_star = column.name.text() == "dummy" && matches!(column.data(), ColumnData::Int4(_));
29
30		if is_count_star {
31			// For count(*), count all rows including those with undefined values
32			for (group, indices) in groups.iter() {
33				let count = indices.len() as i64;
34				self.counts.insert(group.clone(), count);
35			}
36		} else {
37			// For count(column), only count defined (non-null) values
38			match &column.data() {
39				ColumnData::Bool(container) => {
40					for (group, indices) in groups.iter() {
41						let count = indices.iter().filter(|&i| container.is_defined(*i)).count()
42							as i64;
43						self.counts.insert(group.clone(), count);
44					}
45				}
46				ColumnData::Float8(container) => {
47					for (group, indices) in groups.iter() {
48						let count = indices.iter().filter(|&i| container.is_defined(*i)).count()
49							as i64;
50						self.counts.insert(group.clone(), count);
51					}
52				}
53				ColumnData::Float4(container) => {
54					for (group, indices) in groups.iter() {
55						let count = indices.iter().filter(|&i| container.is_defined(*i)).count()
56							as i64;
57						self.counts.insert(group.clone(), count);
58					}
59				}
60				ColumnData::Int4(container) => {
61					for (group, indices) in groups.iter() {
62						let count = indices.iter().filter(|&i| container.is_defined(*i)).count()
63							as i64;
64						self.counts.insert(group.clone(), count);
65					}
66				}
67				ColumnData::Int8(container) => {
68					for (group, indices) in groups.iter() {
69						let count = indices.iter().filter(|&i| container.is_defined(*i)).count()
70							as i64;
71						self.counts.insert(group.clone(), count);
72					}
73				}
74				ColumnData::Int2(container) => {
75					for (group, indices) in groups.iter() {
76						let count = indices.iter().filter(|&i| container.is_defined(*i)).count()
77							as i64;
78						self.counts.insert(group.clone(), count);
79					}
80				}
81				ColumnData::Int1(container) => {
82					for (group, indices) in groups.iter() {
83						let count = indices.iter().filter(|&i| container.is_defined(*i)).count()
84							as i64;
85						self.counts.insert(group.clone(), count);
86					}
87				}
88				ColumnData::Int16(container) => {
89					for (group, indices) in groups.iter() {
90						let count = indices.iter().filter(|&i| container.is_defined(*i)).count()
91							as i64;
92						self.counts.insert(group.clone(), count);
93					}
94				}
95				ColumnData::Utf8 {
96					container,
97					..
98				} => {
99					for (group, indices) in groups.iter() {
100						let count = indices.iter().filter(|&i| container.is_defined(*i)).count()
101							as i64;
102						self.counts.insert(group.clone(), count);
103					}
104				}
105				ColumnData::Date(container) => {
106					for (group, indices) in groups.iter() {
107						let count = indices.iter().filter(|&i| container.is_defined(*i)).count()
108							as i64;
109						self.counts.insert(group.clone(), count);
110					}
111				}
112				ColumnData::DateTime(container) => {
113					for (group, indices) in groups.iter() {
114						let count = indices.iter().filter(|&i| container.is_defined(*i)).count()
115							as i64;
116						self.counts.insert(group.clone(), count);
117					}
118				}
119				_ => {
120					// For other column types, use generic is_defined check
121					for (group, indices) in groups.iter() {
122						let count = indices
123							.iter()
124							.filter(|&i| column.data().is_defined(*i))
125							.count() as i64;
126						self.counts.insert(group.clone(), count);
127					}
128				}
129			}
130		}
131		Ok(())
132	}
133
134	fn finalize(&mut self) -> crate::error::AggregateFunctionResult<(Vec<Vec<Value>>, ColumnData)> {
135		let mut keys = Vec::with_capacity(self.counts.len());
136		let mut data = ColumnData::int8_with_capacity(self.counts.len());
137
138		for (key, count) in std::mem::take(&mut self.counts) {
139			keys.push(key);
140			data.push_value(Value::Int8(count));
141		}
142
143		Ok((keys, data))
144	}
145}