reifydb_routine/function/math/aggregate/
avg.rs1use std::mem;
5
6use indexmap::IndexMap;
7use num_traits::ToPrimitive;
8use reifydb_core::value::column::data::ColumnData;
9use reifydb_type::value::{
10 Value,
11 r#type::{Type, input_types::InputTypes},
12};
13
14use crate::function::{
15 AggregateFunction, AggregateFunctionContext,
16 error::{AggregateFunctionError, AggregateFunctionResult},
17};
18
19pub struct Avg {
20 pub sums: IndexMap<Vec<Value>, f64>,
21 pub counts: IndexMap<Vec<Value>, u64>,
22}
23
24impl Avg {
25 pub fn new() -> Self {
26 Self {
27 sums: IndexMap::new(),
28 counts: IndexMap::new(),
29 }
30 }
31}
32
/// Folds a fixed-width numeric container into `$self`'s per-group running state.
///
/// For each `(group, indices)` pair produced by `$groups.iter()`, every row index
/// that `$column.data()` reports as defined and that `$container.get(..)` returns
/// a value for is widened with `as f64`, summed and counted. The batch totals are
/// then added to `$self.sums` / `$self.counts` under a clone of the group key.
/// A group with no contributing rows still gets a `0.0` / `0` entry if it has
/// none yet, and any totals it already has are left untouched.
macro_rules! avg_arm {
	($self:expr, $column:expr, $groups:expr, $container:expr) => {{
		let column_data = $column.data();
		for (group, indices) in $groups.iter() {
			let mut batch_total = 0.0f64;
			let mut batch_rows = 0u64;
			for &row in indices {
				if !column_data.is_defined(row) {
					continue;
				}
				if let Some(&value) = $container.get(row) {
					batch_total += value as f64;
					batch_rows += 1;
				}
			}

			let key = group.clone();
			if batch_rows == 0 {
				// Register the group without disturbing any earlier totals.
				$self.sums.entry(key.clone()).or_insert(0.0);
				$self.counts.entry(key).or_insert(0);
			} else {
				*$self.sums.entry(key.clone()).or_insert(0.0) += batch_total;
				*$self.counts.entry(key).or_insert(0) += batch_rows;
			}
		}
	}};
}
56
57impl AggregateFunction for Avg {
58 fn aggregate(&mut self, ctx: AggregateFunctionContext) -> AggregateFunctionResult<()> {
59 let column = ctx.column;
60 let groups = &ctx.groups;
61 let (data, _bitvec) = column.data().unwrap_option();
62
63 match data {
64 ColumnData::Int1(container) => {
65 avg_arm!(self, column, groups, container);
66 Ok(())
67 }
68 ColumnData::Int2(container) => {
69 avg_arm!(self, column, groups, container);
70 Ok(())
71 }
72 ColumnData::Int4(container) => {
73 avg_arm!(self, column, groups, container);
74 Ok(())
75 }
76 ColumnData::Int8(container) => {
77 avg_arm!(self, column, groups, container);
78 Ok(())
79 }
80 ColumnData::Int16(container) => {
81 avg_arm!(self, column, groups, container);
82 Ok(())
83 }
84 ColumnData::Uint1(container) => {
85 avg_arm!(self, column, groups, container);
86 Ok(())
87 }
88 ColumnData::Uint2(container) => {
89 avg_arm!(self, column, groups, container);
90 Ok(())
91 }
92 ColumnData::Uint4(container) => {
93 avg_arm!(self, column, groups, container);
94 Ok(())
95 }
96 ColumnData::Uint8(container) => {
97 avg_arm!(self, column, groups, container);
98 Ok(())
99 }
100 ColumnData::Uint16(container) => {
101 avg_arm!(self, column, groups, container);
102 Ok(())
103 }
104 ColumnData::Float4(container) => {
105 avg_arm!(self, column, groups, container);
106 Ok(())
107 }
108 ColumnData::Float8(container) => {
109 avg_arm!(self, column, groups, container);
110 Ok(())
111 }
112 ColumnData::Int {
113 container,
114 ..
115 } => {
116 for (group, indices) in groups.iter() {
117 let mut sum = 0.0f64;
118 let mut count = 0u64;
119 for &i in indices {
120 if column.data().is_defined(i) {
121 if let Some(val) = container.get(i) {
122 sum += val.0.to_f64().unwrap_or(0.0);
123 count += 1;
124 }
125 }
126 }
127 if count > 0 {
128 self.sums.entry(group.clone()).and_modify(|v| *v += sum).or_insert(sum);
129 self.counts
130 .entry(group.clone())
131 .and_modify(|c| *c += count)
132 .or_insert(count);
133 } else {
134 self.sums.entry(group.clone()).or_insert(0.0);
135 self.counts.entry(group.clone()).or_insert(0);
136 }
137 }
138 Ok(())
139 }
140 ColumnData::Uint {
141 container,
142 ..
143 } => {
144 for (group, indices) in groups.iter() {
145 let mut sum = 0.0f64;
146 let mut count = 0u64;
147 for &i in indices {
148 if column.data().is_defined(i) {
149 if let Some(val) = container.get(i) {
150 sum += val.0.to_f64().unwrap_or(0.0);
151 count += 1;
152 }
153 }
154 }
155 if count > 0 {
156 self.sums.entry(group.clone()).and_modify(|v| *v += sum).or_insert(sum);
157 self.counts
158 .entry(group.clone())
159 .and_modify(|c| *c += count)
160 .or_insert(count);
161 } else {
162 self.sums.entry(group.clone()).or_insert(0.0);
163 self.counts.entry(group.clone()).or_insert(0);
164 }
165 }
166 Ok(())
167 }
168 ColumnData::Decimal {
169 container,
170 ..
171 } => {
172 for (group, indices) in groups.iter() {
173 let mut sum = 0.0f64;
174 let mut count = 0u64;
175 for &i in indices {
176 if column.data().is_defined(i) {
177 if let Some(val) = container.get(i) {
178 sum += val.0.to_f64().unwrap_or(0.0);
179 count += 1;
180 }
181 }
182 }
183 if count > 0 {
184 self.sums.entry(group.clone()).and_modify(|v| *v += sum).or_insert(sum);
185 self.counts
186 .entry(group.clone())
187 .and_modify(|c| *c += count)
188 .or_insert(count);
189 } else {
190 self.sums.entry(group.clone()).or_insert(0.0);
191 self.counts.entry(group.clone()).or_insert(0);
192 }
193 }
194 Ok(())
195 }
196 other => Err(AggregateFunctionError::InvalidArgumentType {
197 function: ctx.fragment.clone(),
198 argument_index: 0,
199 expected: self.accepted_types().expected_at(0).to_vec(),
200 actual: other.get_type(),
201 }),
202 }
203 }
204
205 fn finalize(&mut self) -> AggregateFunctionResult<(Vec<Vec<Value>>, ColumnData)> {
206 let mut keys = Vec::with_capacity(self.sums.len());
207 let mut data = ColumnData::float8_with_capacity(self.sums.len());
208
209 for (key, sum) in mem::take(&mut self.sums) {
210 let count = self.counts.swap_remove(&key).unwrap_or(0);
211 keys.push(key);
212 if count > 0 {
213 data.push_value(Value::float8(sum / count as f64));
214 } else {
215 data.push_value(Value::none());
216 }
217 }
218
219 Ok((keys, data))
220 }
221
222 fn return_type(&self, _input_type: &Type) -> Type {
223 Type::Float8
224 }
225
226 fn accepted_types(&self) -> InputTypes {
227 InputTypes::numeric()
228 }
229}