Skip to main content

reifydb_routine/function/math/
count.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use std::mem;
5
6use indexmap::IndexMap;
7use reifydb_core::value::column::{
8	ColumnWithName,
9	buffer::ColumnBuffer,
10	columns::Columns,
11	view::group_by::{GroupByView, GroupKey},
12};
13use reifydb_type::value::{
14	Value,
15	r#type::{Type, input_types::InputTypes},
16};
17
18use crate::routine::{
19	Accumulator, Function, FunctionKind, Routine, RoutineInfo, context::FunctionContext, error::RoutineError,
20};
21
22pub struct Count {
23	info: RoutineInfo,
24}
25
26impl Default for Count {
27	fn default() -> Self {
28		Self::new()
29	}
30}
31
32impl Count {
33	pub fn new() -> Self {
34		Self {
35			info: RoutineInfo::new("math::count"),
36		}
37	}
38}
39
40impl<'a> Routine<FunctionContext<'a>> for Count {
41	fn info(&self) -> &RoutineInfo {
42		&self.info
43	}
44
45	fn return_type(&self, _input_types: &[Type]) -> Type {
46		Type::Int8
47	}
48
49	fn accepted_types(&self) -> InputTypes {
50		InputTypes::any()
51	}
52
53	fn propagates_options(&self) -> bool {
54		false
55	}
56
57	fn execute(&self, ctx: &mut FunctionContext<'a>, args: &Columns) -> Result<Columns, RoutineError> {
58		// SCALAR: Horizontal Count (count of non-null arguments in each row)
59		let row_count = args.row_count();
60		let mut counts = vec![0i64; row_count];
61
62		for col in args.iter() {
63			for (i, count) in counts.iter_mut().enumerate().take(row_count) {
64				if col.data().is_defined(i) {
65					*count += 1;
66				}
67			}
68		}
69
70		Ok(Columns::new(vec![ColumnWithName::new(ctx.fragment.clone(), ColumnBuffer::int8(counts))]))
71	}
72}
73
74impl Function for Count {
75	fn kinds(&self) -> &[FunctionKind] {
76		&[FunctionKind::Scalar, FunctionKind::Aggregate]
77	}
78
79	fn accumulator(&self, _ctx: &mut FunctionContext<'_>) -> Option<Box<dyn Accumulator>> {
80		Some(Box::new(CountAccumulator::new()))
81	}
82}
83
84struct CountAccumulator {
85	pub counts: IndexMap<GroupKey, i64>,
86}
87
88impl CountAccumulator {
89	pub fn new() -> Self {
90		Self {
91			counts: IndexMap::new(),
92		}
93	}
94}
95
96impl Accumulator for CountAccumulator {
97	fn update(&mut self, args: &Columns, groups: &GroupByView) -> Result<(), RoutineError> {
98		let column = &args[0];
99		let column_name = args.name_at(0);
100
101		// Check if this is count(*) by examining if we have a dummy column
102		let is_count_star = column_name.text() == "dummy" && matches!(column, ColumnBuffer::Int4(_));
103
104		if is_count_star {
105			for (group, indices) in groups.iter() {
106				let count = indices.len() as i64;
107				*self.counts.entry(group.clone()).or_insert(0) += count;
108			}
109		} else {
110			for (group, indices) in groups.iter() {
111				let count = indices.iter().filter(|&i| column.is_defined(*i)).count() as i64;
112				*self.counts.entry(group.clone()).or_insert(0) += count;
113			}
114		}
115		Ok(())
116	}
117
118	fn finalize(&mut self) -> Result<(Vec<GroupKey>, ColumnBuffer), RoutineError> {
119		let mut keys = Vec::with_capacity(self.counts.len());
120		let mut data = ColumnBuffer::int8_with_capacity(self.counts.len());
121
122		for (key, count) in mem::take(&mut self.counts) {
123			keys.push(key);
124			data.push_value(Value::Int8(count));
125		}
126
127		Ok((keys, data))
128	}
129}