reifydb_routine/function/math/
count.rs1use std::mem;
5
6use indexmap::IndexMap;
7use reifydb_core::value::column::{
8 ColumnWithName,
9 buffer::ColumnBuffer,
10 columns::Columns,
11 view::group_by::{GroupByView, GroupKey},
12};
13use reifydb_type::value::{
14 Value,
15 r#type::{Type, input_types::InputTypes},
16};
17
18use crate::routine::{
19 Accumulator, Function, FunctionKind, Routine, RoutineInfo, context::FunctionContext, error::RoutineError,
20};
21
22pub struct Count {
23 info: RoutineInfo,
24}
25
26impl Default for Count {
27 fn default() -> Self {
28 Self::new()
29 }
30}
31
32impl Count {
33 pub fn new() -> Self {
34 Self {
35 info: RoutineInfo::new("math::count"),
36 }
37 }
38}
39
40impl<'a> Routine<FunctionContext<'a>> for Count {
41 fn info(&self) -> &RoutineInfo {
42 &self.info
43 }
44
45 fn return_type(&self, _input_types: &[Type]) -> Type {
46 Type::Int8
47 }
48
49 fn accepted_types(&self) -> InputTypes {
50 InputTypes::any()
51 }
52
53 fn propagates_options(&self) -> bool {
54 false
55 }
56
57 fn execute(&self, ctx: &mut FunctionContext<'a>, args: &Columns) -> Result<Columns, RoutineError> {
58 let row_count = args.row_count();
59 let mut counts = vec![0i64; row_count];
60
61 for col in args.iter() {
62 for (i, count) in counts.iter_mut().enumerate().take(row_count) {
63 if col.data().is_defined(i) {
64 *count += 1;
65 }
66 }
67 }
68
69 Ok(Columns::new(vec![ColumnWithName::new(ctx.fragment.clone(), ColumnBuffer::int8(counts))]))
70 }
71}
72
73impl Function for Count {
74 fn kinds(&self) -> &[FunctionKind] {
75 &[FunctionKind::Scalar, FunctionKind::Aggregate]
76 }
77
78 fn accumulator(&self, _ctx: &mut FunctionContext<'_>) -> Option<Box<dyn Accumulator>> {
79 Some(Box::new(CountAccumulator::new()))
80 }
81}
82
83struct CountAccumulator {
84 pub counts: IndexMap<GroupKey, i64>,
85}
86
87impl CountAccumulator {
88 pub fn new() -> Self {
89 Self {
90 counts: IndexMap::new(),
91 }
92 }
93}
94
95impl Accumulator for CountAccumulator {
96 fn update(&mut self, args: &Columns, groups: &GroupByView) -> Result<(), RoutineError> {
97 let column = &args[0];
98 let column_name = args.name_at(0);
99
100 let is_count_star = column_name.text() == "dummy" && matches!(column, ColumnBuffer::Int4(_));
101
102 if is_count_star {
103 for (group, indices) in groups.iter() {
104 let count = indices.len() as i64;
105 *self.counts.entry(group.clone()).or_insert(0) += count;
106 }
107 } else {
108 for (group, indices) in groups.iter() {
109 let count = indices.iter().filter(|&i| column.is_defined(*i)).count() as i64;
110 *self.counts.entry(group.clone()).or_insert(0) += count;
111 }
112 }
113 Ok(())
114 }
115
116 fn finalize(&mut self) -> Result<(Vec<GroupKey>, ColumnBuffer), RoutineError> {
117 let mut keys = Vec::with_capacity(self.counts.len());
118 let mut data = ColumnBuffer::int8_with_capacity(self.counts.len());
119
120 for (key, count) in mem::take(&mut self.counts) {
121 keys.push(key);
122 data.push_value(Value::Int8(count));
123 }
124
125 Ok((keys, data))
126 }
127}