Skip to main content

reifydb_routine/function/math/
count.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use std::mem;
5
6use indexmap::IndexMap;
7use reifydb_core::value::column::{
8	Column,
9	columns::Columns,
10	data::ColumnData,
11	view::group_by::{GroupByView, GroupKey},
12};
13use reifydb_type::value::{
14	Value,
15	r#type::{Type, input_types::InputTypes},
16};
17
18use crate::function::{Accumulator, Function, FunctionCapability, FunctionContext, FunctionInfo, error::FunctionError};
19
20pub struct Count {
21	info: FunctionInfo,
22}
23
24impl Default for Count {
25	fn default() -> Self {
26		Self::new()
27	}
28}
29
30impl Count {
31	pub fn new() -> Self {
32		Self {
33			info: FunctionInfo::new("math::count"),
34		}
35	}
36}
37
38impl Function for Count {
39	fn info(&self) -> &FunctionInfo {
40		&self.info
41	}
42
43	fn capabilities(&self) -> &[FunctionCapability] {
44		&[FunctionCapability::Scalar, FunctionCapability::Aggregate]
45	}
46
47	fn return_type(&self, _input_types: &[Type]) -> Type {
48		Type::Int8
49	}
50
51	fn accepted_types(&self) -> InputTypes {
52		InputTypes::any()
53	}
54
55	fn propagates_options(&self) -> bool {
56		false
57	}
58
59	fn execute(&self, ctx: &FunctionContext, args: &Columns) -> Result<Columns, FunctionError> {
60		// SCALAR: Horizontal Count (count of non-null arguments in each row)
61		let row_count = args.row_count();
62		let mut counts = vec![0i64; row_count];
63
64		for col in args.iter() {
65			for (i, count) in counts.iter_mut().enumerate().take(row_count) {
66				if col.data().is_defined(i) {
67					*count += 1;
68				}
69			}
70		}
71
72		Ok(Columns::new(vec![Column::new(ctx.fragment.clone(), ColumnData::int8(counts))]))
73	}
74
75	fn accumulator(&self, _ctx: &FunctionContext) -> Option<Box<dyn Accumulator>> {
76		Some(Box::new(CountAccumulator::new()))
77	}
78}
79
80struct CountAccumulator {
81	pub counts: IndexMap<GroupKey, i64>,
82}
83
84impl CountAccumulator {
85	pub fn new() -> Self {
86		Self {
87			counts: IndexMap::new(),
88		}
89	}
90}
91
92impl Accumulator for CountAccumulator {
93	fn update(&mut self, args: &Columns, groups: &GroupByView) -> Result<(), FunctionError> {
94		let column = &args[0];
95
96		// Check if this is count(*) by examining if we have a dummy column
97		let is_count_star = column.name.text() == "dummy" && matches!(column.data(), ColumnData::Int4(_));
98
99		if is_count_star {
100			for (group, indices) in groups.iter() {
101				let count = indices.len() as i64;
102				*self.counts.entry(group.clone()).or_insert(0) += count;
103			}
104		} else {
105			for (group, indices) in groups.iter() {
106				let count = indices.iter().filter(|&i| column.data().is_defined(*i)).count() as i64;
107				*self.counts.entry(group.clone()).or_insert(0) += count;
108			}
109		}
110		Ok(())
111	}
112
113	fn finalize(&mut self) -> Result<(Vec<GroupKey>, ColumnData), FunctionError> {
114		let mut keys = Vec::with_capacity(self.counts.len());
115		let mut data = ColumnData::int8_with_capacity(self.counts.len());
116
117		for (key, count) in mem::take(&mut self.counts) {
118			keys.push(key);
119			data.push_value(Value::Int8(count));
120		}
121
122		Ok((keys, data))
123	}
124}