reifydb_routine/function/math/
avg.rs1use std::mem;
5
6use indexmap::IndexMap;
7use num_traits::ToPrimitive;
8use reifydb_core::value::column::{
9 Column,
10 columns::Columns,
11 data::ColumnData,
12 view::group_by::{GroupByView, GroupKey},
13};
14use reifydb_type::{
15 fragment::Fragment,
16 value::{
17 Value,
18 r#type::{Type, input_types::InputTypes},
19 },
20};
21
22use crate::function::{Accumulator, Function, FunctionCapability, FunctionContext, FunctionInfo, error::FunctionError};
23
24pub struct Avg {
25 info: FunctionInfo,
26}
27
28impl Default for Avg {
29 fn default() -> Self {
30 Self::new()
31 }
32}
33
34impl Avg {
35 pub fn new() -> Self {
36 Self {
37 info: FunctionInfo::new("math::avg"),
38 }
39 }
40}
41
42impl Function for Avg {
43 fn info(&self) -> &FunctionInfo {
44 &self.info
45 }
46
47 fn capabilities(&self) -> &[FunctionCapability] {
48 &[FunctionCapability::Scalar, FunctionCapability::Aggregate]
49 }
50
51 fn return_type(&self, _input_types: &[Type]) -> Type {
52 Type::Float8
53 }
54
55 fn accepted_types(&self) -> InputTypes {
56 InputTypes::numeric()
57 }
58
59 fn execute(&self, ctx: &FunctionContext, args: &Columns) -> Result<Columns, FunctionError> {
60 if args.is_empty() {
61 return Err(FunctionError::ArityMismatch {
62 function: ctx.fragment.clone(),
63 expected: 1,
64 actual: 0,
65 });
66 }
67
68 let row_count = args.row_count();
69 let mut sum = vec![0.0f64; row_count];
70 let mut count = vec![0u32; row_count];
71
72 for (col_idx, col) in args.iter().enumerate() {
73 let (data, _bitvec) = col.data().unwrap_option();
74 match data {
75 ColumnData::Int1(container) => {
76 for i in 0..row_count {
77 if let Some(value) = container.get(i) {
78 sum[i] += *value as f64;
79 count[i] += 1;
80 }
81 }
82 }
83 ColumnData::Int2(container) => {
84 for i in 0..row_count {
85 if let Some(value) = container.get(i) {
86 sum[i] += *value as f64;
87 count[i] += 1;
88 }
89 }
90 }
91 ColumnData::Int4(container) => {
92 for i in 0..row_count {
93 if let Some(value) = container.get(i) {
94 sum[i] += *value as f64;
95 count[i] += 1;
96 }
97 }
98 }
99 ColumnData::Int8(container) => {
100 for i in 0..row_count {
101 if let Some(value) = container.get(i) {
102 sum[i] += *value as f64;
103 count[i] += 1;
104 }
105 }
106 }
107 ColumnData::Int16(container) => {
108 for i in 0..row_count {
109 if let Some(value) = container.get(i) {
110 sum[i] += *value as f64;
111 count[i] += 1;
112 }
113 }
114 }
115 ColumnData::Uint1(container) => {
116 for i in 0..row_count {
117 if let Some(value) = container.get(i) {
118 sum[i] += *value as f64;
119 count[i] += 1;
120 }
121 }
122 }
123 ColumnData::Uint2(container) => {
124 for i in 0..row_count {
125 if let Some(value) = container.get(i) {
126 sum[i] += *value as f64;
127 count[i] += 1;
128 }
129 }
130 }
131 ColumnData::Uint4(container) => {
132 for i in 0..row_count {
133 if let Some(value) = container.get(i) {
134 sum[i] += *value as f64;
135 count[i] += 1;
136 }
137 }
138 }
139 ColumnData::Uint8(container) => {
140 for i in 0..row_count {
141 if let Some(value) = container.get(i) {
142 sum[i] += *value as f64;
143 count[i] += 1;
144 }
145 }
146 }
147 ColumnData::Uint16(container) => {
148 for i in 0..row_count {
149 if let Some(value) = container.get(i) {
150 sum[i] += *value as f64;
151 count[i] += 1;
152 }
153 }
154 }
155 ColumnData::Float4(container) => {
156 for i in 0..row_count {
157 if let Some(value) = container.get(i) {
158 sum[i] += *value as f64;
159 count[i] += 1;
160 }
161 }
162 }
163 ColumnData::Float8(container) => {
164 for i in 0..row_count {
165 if let Some(value) = container.get(i) {
166 sum[i] += *value;
167 count[i] += 1;
168 }
169 }
170 }
171 ColumnData::Int {
172 container,
173 ..
174 } => {
175 for i in 0..row_count {
176 if let Some(value) = container.get(i) {
177 sum[i] += value.0.to_f64().unwrap_or(0.0);
178 count[i] += 1;
179 }
180 }
181 }
182 ColumnData::Uint {
183 container,
184 ..
185 } => {
186 for i in 0..row_count {
187 if let Some(value) = container.get(i) {
188 sum[i] += value.0.to_f64().unwrap_or(0.0);
189 count[i] += 1;
190 }
191 }
192 }
193 ColumnData::Decimal {
194 container,
195 ..
196 } => {
197 for i in 0..row_count {
198 if let Some(value) = container.get(i) {
199 sum[i] += value.0.to_f64().unwrap_or(0.0);
200 count[i] += 1;
201 }
202 }
203 }
204 other => {
205 return Err(FunctionError::InvalidArgumentType {
206 function: ctx.fragment.clone(),
207 argument_index: col_idx,
208 expected: self.accepted_types().expected_at(0).to_vec(),
209 actual: other.get_type(),
210 });
211 }
212 }
213 }
214
215 let mut data = Vec::with_capacity(row_count);
216 let mut valids = Vec::with_capacity(row_count);
217
218 for i in 0..row_count {
219 if count[i] > 0 {
220 data.push(sum[i] / count[i] as f64);
221 valids.push(true);
222 } else {
223 data.push(0.0);
224 valids.push(false);
225 }
226 }
227
228 Ok(Columns::new(vec![Column::new(ctx.fragment.clone(), ColumnData::float8_with_bitvec(data, valids))]))
229 }
230
231 fn accumulator(&self, _ctx: &FunctionContext) -> Option<Box<dyn Accumulator>> {
232 Some(Box::new(AvgAccumulator::new()))
233 }
234}
235
236struct AvgAccumulator {
237 pub sums: IndexMap<GroupKey, f64>,
238 pub counts: IndexMap<GroupKey, u64>,
239}
240
241impl AvgAccumulator {
242 pub fn new() -> Self {
243 Self {
244 sums: IndexMap::new(),
245 counts: IndexMap::new(),
246 }
247 }
248}
249
250macro_rules! avg_arm {
251 ($self:expr, $column:expr, $groups:expr, $container:expr) => {
252 for (group, indices) in $groups.iter() {
253 let mut sum = 0.0f64;
254 let mut count = 0u64;
255 for &i in indices {
256 if $column.data().is_defined(i) {
257 if let Some(&val) = $container.get(i) {
258 sum += val as f64;
259 count += 1;
260 }
261 }
262 }
263 if count > 0 {
264 $self.sums.entry(group.clone()).and_modify(|v| *v += sum).or_insert(sum);
265 $self.counts.entry(group.clone()).and_modify(|c| *c += count).or_insert(count);
266 } else {
267 $self.sums.entry(group.clone()).or_insert(0.0);
268 $self.counts.entry(group.clone()).or_insert(0);
269 }
270 }
271 };
272}
273
274impl Accumulator for AvgAccumulator {
275 fn update(&mut self, args: &Columns, groups: &GroupByView) -> Result<(), FunctionError> {
276 let column = &args[0];
277 let (data, _bitvec) = column.data().unwrap_option();
278
279 match data {
280 ColumnData::Int1(container) => {
281 avg_arm!(self, column, groups, container);
282 Ok(())
283 }
284 ColumnData::Int2(container) => {
285 avg_arm!(self, column, groups, container);
286 Ok(())
287 }
288 ColumnData::Int4(container) => {
289 avg_arm!(self, column, groups, container);
290 Ok(())
291 }
292 ColumnData::Int8(container) => {
293 avg_arm!(self, column, groups, container);
294 Ok(())
295 }
296 ColumnData::Int16(container) => {
297 avg_arm!(self, column, groups, container);
298 Ok(())
299 }
300 ColumnData::Uint1(container) => {
301 avg_arm!(self, column, groups, container);
302 Ok(())
303 }
304 ColumnData::Uint2(container) => {
305 avg_arm!(self, column, groups, container);
306 Ok(())
307 }
308 ColumnData::Uint4(container) => {
309 avg_arm!(self, column, groups, container);
310 Ok(())
311 }
312 ColumnData::Uint8(container) => {
313 avg_arm!(self, column, groups, container);
314 Ok(())
315 }
316 ColumnData::Uint16(container) => {
317 avg_arm!(self, column, groups, container);
318 Ok(())
319 }
320 ColumnData::Float4(container) => {
321 avg_arm!(self, column, groups, container);
322 Ok(())
323 }
324 ColumnData::Float8(container) => {
325 avg_arm!(self, column, groups, container);
326 Ok(())
327 }
328 ColumnData::Int {
329 container,
330 ..
331 } => {
332 for (group, indices) in groups.iter() {
333 let mut sum = 0.0f64;
334 let mut count = 0u64;
335 for &i in indices {
336 if column.data().is_defined(i)
337 && let Some(val) = container.get(i)
338 {
339 sum += val.0.to_f64().unwrap_or(0.0);
340 count += 1;
341 }
342 }
343 if count > 0 {
344 self.sums.entry(group.clone()).and_modify(|v| *v += sum).or_insert(sum);
345 self.counts
346 .entry(group.clone())
347 .and_modify(|c| *c += count)
348 .or_insert(count);
349 } else {
350 self.sums.entry(group.clone()).or_insert(0.0);
351 self.counts.entry(group.clone()).or_insert(0);
352 }
353 }
354 Ok(())
355 }
356 ColumnData::Uint {
357 container,
358 ..
359 } => {
360 for (group, indices) in groups.iter() {
361 let mut sum = 0.0f64;
362 let mut count = 0u64;
363 for &i in indices {
364 if column.data().is_defined(i)
365 && let Some(val) = container.get(i)
366 {
367 sum += val.0.to_f64().unwrap_or(0.0);
368 count += 1;
369 }
370 }
371 if count > 0 {
372 self.sums.entry(group.clone()).and_modify(|v| *v += sum).or_insert(sum);
373 self.counts
374 .entry(group.clone())
375 .and_modify(|c| *c += count)
376 .or_insert(count);
377 } else {
378 self.sums.entry(group.clone()).or_insert(0.0);
379 self.counts.entry(group.clone()).or_insert(0);
380 }
381 }
382 Ok(())
383 }
384 ColumnData::Decimal {
385 container,
386 ..
387 } => {
388 for (group, indices) in groups.iter() {
389 let mut sum = 0.0f64;
390 let mut count = 0u64;
391 for &i in indices {
392 if column.data().is_defined(i)
393 && let Some(val) = container.get(i)
394 {
395 sum += val.0.to_f64().unwrap_or(0.0);
396 count += 1;
397 }
398 }
399 if count > 0 {
400 self.sums.entry(group.clone()).and_modify(|v| *v += sum).or_insert(sum);
401 self.counts
402 .entry(group.clone())
403 .and_modify(|c| *c += count)
404 .or_insert(count);
405 } else {
406 self.sums.entry(group.clone()).or_insert(0.0);
407 self.counts.entry(group.clone()).or_insert(0);
408 }
409 }
410 Ok(())
411 }
412 other => Err(FunctionError::InvalidArgumentType {
413 function: Fragment::internal("math::avg"),
414 argument_index: 0,
415 expected: InputTypes::numeric().expected_at(0).to_vec(),
416 actual: other.get_type(),
417 }),
418 }
419 }
420
421 fn finalize(&mut self) -> Result<(Vec<GroupKey>, ColumnData), FunctionError> {
422 let mut keys = Vec::with_capacity(self.sums.len());
423 let mut data = ColumnData::float8_with_capacity(self.sums.len());
424
425 for (key, sum) in mem::take(&mut self.sums) {
426 let count = self.counts.swap_remove(&key).unwrap_or(0);
427 keys.push(key);
428 if count > 0 {
429 data.push_value(Value::float8(sum / count as f64));
430 } else {
431 data.push_value(Value::none());
432 }
433 }
434
435 Ok((keys, data))
436 }
437}