reifydb_routine/function/math/
count.rs1use std::mem;
5
6use indexmap::IndexMap;
7use reifydb_core::value::column::{
8 Column,
9 columns::Columns,
10 data::ColumnData,
11 view::group_by::{GroupByView, GroupKey},
12};
13use reifydb_type::value::{
14 Value,
15 r#type::{Type, input_types::InputTypes},
16};
17
18use crate::function::{Accumulator, Function, FunctionCapability, FunctionContext, FunctionInfo, error::FunctionError};
19
20pub struct Count {
21 info: FunctionInfo,
22}
23
24impl Default for Count {
25 fn default() -> Self {
26 Self::new()
27 }
28}
29
30impl Count {
31 pub fn new() -> Self {
32 Self {
33 info: FunctionInfo::new("math::count"),
34 }
35 }
36}
37
38impl Function for Count {
39 fn info(&self) -> &FunctionInfo {
40 &self.info
41 }
42
43 fn capabilities(&self) -> &[FunctionCapability] {
44 &[FunctionCapability::Scalar, FunctionCapability::Aggregate]
45 }
46
47 fn return_type(&self, _input_types: &[Type]) -> Type {
48 Type::Int8
49 }
50
51 fn accepted_types(&self) -> InputTypes {
52 InputTypes::any()
53 }
54
55 fn propagates_options(&self) -> bool {
56 false
57 }
58
59 fn execute(&self, ctx: &FunctionContext, args: &Columns) -> Result<Columns, FunctionError> {
60 let row_count = args.row_count();
62 let mut counts = vec![0i64; row_count];
63
64 for col in args.iter() {
65 for (i, count) in counts.iter_mut().enumerate().take(row_count) {
66 if col.data().is_defined(i) {
67 *count += 1;
68 }
69 }
70 }
71
72 Ok(Columns::new(vec![Column::new(ctx.fragment.clone(), ColumnData::int8(counts))]))
73 }
74
75 fn accumulator(&self, _ctx: &FunctionContext) -> Option<Box<dyn Accumulator>> {
76 Some(Box::new(CountAccumulator::new()))
77 }
78}
79
80struct CountAccumulator {
81 pub counts: IndexMap<GroupKey, i64>,
82}
83
84impl CountAccumulator {
85 pub fn new() -> Self {
86 Self {
87 counts: IndexMap::new(),
88 }
89 }
90}
91
92impl Accumulator for CountAccumulator {
93 fn update(&mut self, args: &Columns, groups: &GroupByView) -> Result<(), FunctionError> {
94 let column = &args[0];
95
96 let is_count_star = column.name.text() == "dummy" && matches!(column.data(), ColumnData::Int4(_));
98
99 if is_count_star {
100 for (group, indices) in groups.iter() {
101 let count = indices.len() as i64;
102 *self.counts.entry(group.clone()).or_insert(0) += count;
103 }
104 } else {
105 for (group, indices) in groups.iter() {
106 let count = indices.iter().filter(|&i| column.data().is_defined(*i)).count() as i64;
107 *self.counts.entry(group.clone()).or_insert(0) += count;
108 }
109 }
110 Ok(())
111 }
112
113 fn finalize(&mut self) -> Result<(Vec<GroupKey>, ColumnData), FunctionError> {
114 let mut keys = Vec::with_capacity(self.counts.len());
115 let mut data = ColumnData::int8_with_capacity(self.counts.len());
116
117 for (key, count) in mem::take(&mut self.counts) {
118 keys.push(key);
119 data.push_value(Value::Int8(count));
120 }
121
122 Ok((keys, data))
123 }
124}