reifydb_routine/function/math/
avg.rs1use std::mem;
5
6use indexmap::IndexMap;
7use num_traits::ToPrimitive;
8use reifydb_core::value::column::{
9 ColumnWithName,
10 buffer::ColumnBuffer,
11 columns::Columns,
12 view::group_by::{GroupByView, GroupKey},
13};
14use reifydb_type::{
15 fragment::Fragment,
16 value::{
17 Value,
18 r#type::{Type, input_types::InputTypes},
19 },
20};
21
22use crate::routine::{
23 Accumulator, Function, FunctionKind, Routine, RoutineInfo, context::FunctionContext, error::RoutineError,
24};
25
26pub struct Avg {
27 info: RoutineInfo,
28}
29
30impl Default for Avg {
31 fn default() -> Self {
32 Self::new()
33 }
34}
35
36impl Avg {
37 pub fn new() -> Self {
38 Self {
39 info: RoutineInfo::new("math::avg"),
40 }
41 }
42}
43
44impl<'a> Routine<FunctionContext<'a>> for Avg {
45 fn info(&self) -> &RoutineInfo {
46 &self.info
47 }
48
49 fn return_type(&self, _input_types: &[Type]) -> Type {
50 Type::Float8
51 }
52
53 fn accepted_types(&self) -> InputTypes {
54 InputTypes::numeric()
55 }
56
57 fn execute(&self, ctx: &mut FunctionContext<'a>, args: &Columns) -> Result<Columns, RoutineError> {
58 if args.is_empty() {
59 return Err(RoutineError::FunctionArityMismatch {
60 function: ctx.fragment.clone(),
61 expected: 1,
62 actual: 0,
63 });
64 }
65
66 let row_count = args.row_count();
67 let mut sum = vec![0.0f64; row_count];
68 let mut count = vec![0u32; row_count];
69
70 for (col_idx, col) in args.iter().enumerate() {
71 let (data, _bitvec) = col.data().unwrap_option();
72 match data {
73 ColumnBuffer::Int1(container) => {
74 for i in 0..row_count {
75 if let Some(value) = container.get(i) {
76 sum[i] += *value as f64;
77 count[i] += 1;
78 }
79 }
80 }
81 ColumnBuffer::Int2(container) => {
82 for i in 0..row_count {
83 if let Some(value) = container.get(i) {
84 sum[i] += *value as f64;
85 count[i] += 1;
86 }
87 }
88 }
89 ColumnBuffer::Int4(container) => {
90 for i in 0..row_count {
91 if let Some(value) = container.get(i) {
92 sum[i] += *value as f64;
93 count[i] += 1;
94 }
95 }
96 }
97 ColumnBuffer::Int8(container) => {
98 for i in 0..row_count {
99 if let Some(value) = container.get(i) {
100 sum[i] += *value as f64;
101 count[i] += 1;
102 }
103 }
104 }
105 ColumnBuffer::Int16(container) => {
106 for i in 0..row_count {
107 if let Some(value) = container.get(i) {
108 sum[i] += *value as f64;
109 count[i] += 1;
110 }
111 }
112 }
113 ColumnBuffer::Uint1(container) => {
114 for i in 0..row_count {
115 if let Some(value) = container.get(i) {
116 sum[i] += *value as f64;
117 count[i] += 1;
118 }
119 }
120 }
121 ColumnBuffer::Uint2(container) => {
122 for i in 0..row_count {
123 if let Some(value) = container.get(i) {
124 sum[i] += *value as f64;
125 count[i] += 1;
126 }
127 }
128 }
129 ColumnBuffer::Uint4(container) => {
130 for i in 0..row_count {
131 if let Some(value) = container.get(i) {
132 sum[i] += *value as f64;
133 count[i] += 1;
134 }
135 }
136 }
137 ColumnBuffer::Uint8(container) => {
138 for i in 0..row_count {
139 if let Some(value) = container.get(i) {
140 sum[i] += *value as f64;
141 count[i] += 1;
142 }
143 }
144 }
145 ColumnBuffer::Uint16(container) => {
146 for i in 0..row_count {
147 if let Some(value) = container.get(i) {
148 sum[i] += *value as f64;
149 count[i] += 1;
150 }
151 }
152 }
153 ColumnBuffer::Float4(container) => {
154 for i in 0..row_count {
155 if let Some(value) = container.get(i) {
156 sum[i] += *value as f64;
157 count[i] += 1;
158 }
159 }
160 }
161 ColumnBuffer::Float8(container) => {
162 for i in 0..row_count {
163 if let Some(value) = container.get(i) {
164 sum[i] += *value;
165 count[i] += 1;
166 }
167 }
168 }
169 ColumnBuffer::Int {
170 container,
171 ..
172 } => {
173 for i in 0..row_count {
174 if let Some(value) = container.get(i) {
175 sum[i] += value.0.to_f64().unwrap_or(0.0);
176 count[i] += 1;
177 }
178 }
179 }
180 ColumnBuffer::Uint {
181 container,
182 ..
183 } => {
184 for i in 0..row_count {
185 if let Some(value) = container.get(i) {
186 sum[i] += value.0.to_f64().unwrap_or(0.0);
187 count[i] += 1;
188 }
189 }
190 }
191 ColumnBuffer::Decimal {
192 container,
193 ..
194 } => {
195 for i in 0..row_count {
196 if let Some(value) = container.get(i) {
197 sum[i] += value.0.to_f64().unwrap_or(0.0);
198 count[i] += 1;
199 }
200 }
201 }
202 other => {
203 return Err(RoutineError::FunctionInvalidArgumentType {
204 function: ctx.fragment.clone(),
205 argument_index: col_idx,
206 expected: self.accepted_types().expected_at(0).to_vec(),
207 actual: other.get_type(),
208 });
209 }
210 }
211 }
212
213 let mut data = Vec::with_capacity(row_count);
214 let mut valids = Vec::with_capacity(row_count);
215
216 for i in 0..row_count {
217 if count[i] > 0 {
218 data.push(sum[i] / count[i] as f64);
219 valids.push(true);
220 } else {
221 data.push(0.0);
222 valids.push(false);
223 }
224 }
225
226 Ok(Columns::new(vec![ColumnWithName::new(
227 ctx.fragment.clone(),
228 ColumnBuffer::float8_with_bitvec(data, valids),
229 )]))
230 }
231}
232
233impl Function for Avg {
234 fn kinds(&self) -> &[FunctionKind] {
235 &[FunctionKind::Scalar, FunctionKind::Aggregate]
236 }
237
238 fn accumulator(&self, _ctx: &mut FunctionContext<'_>) -> Option<Box<dyn Accumulator>> {
239 Some(Box::new(AvgAccumulator::new()))
240 }
241}
242
243struct AvgAccumulator {
244 pub sums: IndexMap<GroupKey, f64>,
245 pub counts: IndexMap<GroupKey, u64>,
246}
247
248impl AvgAccumulator {
249 pub fn new() -> Self {
250 Self {
251 sums: IndexMap::new(),
252 counts: IndexMap::new(),
253 }
254 }
255}
256
257macro_rules! avg_arm {
258 ($self:expr, $column:expr, $groups:expr, $container:expr) => {
259 for (group, indices) in $groups.iter() {
260 let mut sum = 0.0f64;
261 let mut count = 0u64;
262 for &i in indices {
263 if $column.is_defined(i) {
264 if let Some(&val) = $container.get(i) {
265 sum += val as f64;
266 count += 1;
267 }
268 }
269 }
270 if count > 0 {
271 $self.sums.entry(group.clone()).and_modify(|v| *v += sum).or_insert(sum);
272 $self.counts.entry(group.clone()).and_modify(|c| *c += count).or_insert(count);
273 } else {
274 $self.sums.entry(group.clone()).or_insert(0.0);
275 $self.counts.entry(group.clone()).or_insert(0);
276 }
277 }
278 };
279}
280
281impl Accumulator for AvgAccumulator {
282 fn update(&mut self, args: &Columns, groups: &GroupByView) -> Result<(), RoutineError> {
283 let column = &args[0];
284 let (data, _bitvec) = column.unwrap_option();
285
286 match data {
287 ColumnBuffer::Int1(container) => {
288 avg_arm!(self, column, groups, container);
289 Ok(())
290 }
291 ColumnBuffer::Int2(container) => {
292 avg_arm!(self, column, groups, container);
293 Ok(())
294 }
295 ColumnBuffer::Int4(container) => {
296 avg_arm!(self, column, groups, container);
297 Ok(())
298 }
299 ColumnBuffer::Int8(container) => {
300 avg_arm!(self, column, groups, container);
301 Ok(())
302 }
303 ColumnBuffer::Int16(container) => {
304 avg_arm!(self, column, groups, container);
305 Ok(())
306 }
307 ColumnBuffer::Uint1(container) => {
308 avg_arm!(self, column, groups, container);
309 Ok(())
310 }
311 ColumnBuffer::Uint2(container) => {
312 avg_arm!(self, column, groups, container);
313 Ok(())
314 }
315 ColumnBuffer::Uint4(container) => {
316 avg_arm!(self, column, groups, container);
317 Ok(())
318 }
319 ColumnBuffer::Uint8(container) => {
320 avg_arm!(self, column, groups, container);
321 Ok(())
322 }
323 ColumnBuffer::Uint16(container) => {
324 avg_arm!(self, column, groups, container);
325 Ok(())
326 }
327 ColumnBuffer::Float4(container) => {
328 avg_arm!(self, column, groups, container);
329 Ok(())
330 }
331 ColumnBuffer::Float8(container) => {
332 avg_arm!(self, column, groups, container);
333 Ok(())
334 }
335 ColumnBuffer::Int {
336 container,
337 ..
338 } => {
339 for (group, indices) in groups.iter() {
340 let mut sum = 0.0f64;
341 let mut count = 0u64;
342 for &i in indices {
343 if column.is_defined(i)
344 && let Some(val) = container.get(i)
345 {
346 sum += val.0.to_f64().unwrap_or(0.0);
347 count += 1;
348 }
349 }
350 if count > 0 {
351 self.sums.entry(group.clone()).and_modify(|v| *v += sum).or_insert(sum);
352 self.counts
353 .entry(group.clone())
354 .and_modify(|c| *c += count)
355 .or_insert(count);
356 } else {
357 self.sums.entry(group.clone()).or_insert(0.0);
358 self.counts.entry(group.clone()).or_insert(0);
359 }
360 }
361 Ok(())
362 }
363 ColumnBuffer::Uint {
364 container,
365 ..
366 } => {
367 for (group, indices) in groups.iter() {
368 let mut sum = 0.0f64;
369 let mut count = 0u64;
370 for &i in indices {
371 if column.is_defined(i)
372 && let Some(val) = container.get(i)
373 {
374 sum += val.0.to_f64().unwrap_or(0.0);
375 count += 1;
376 }
377 }
378 if count > 0 {
379 self.sums.entry(group.clone()).and_modify(|v| *v += sum).or_insert(sum);
380 self.counts
381 .entry(group.clone())
382 .and_modify(|c| *c += count)
383 .or_insert(count);
384 } else {
385 self.sums.entry(group.clone()).or_insert(0.0);
386 self.counts.entry(group.clone()).or_insert(0);
387 }
388 }
389 Ok(())
390 }
391 ColumnBuffer::Decimal {
392 container,
393 ..
394 } => {
395 for (group, indices) in groups.iter() {
396 let mut sum = 0.0f64;
397 let mut count = 0u64;
398 for &i in indices {
399 if column.is_defined(i)
400 && let Some(val) = container.get(i)
401 {
402 sum += val.0.to_f64().unwrap_or(0.0);
403 count += 1;
404 }
405 }
406 if count > 0 {
407 self.sums.entry(group.clone()).and_modify(|v| *v += sum).or_insert(sum);
408 self.counts
409 .entry(group.clone())
410 .and_modify(|c| *c += count)
411 .or_insert(count);
412 } else {
413 self.sums.entry(group.clone()).or_insert(0.0);
414 self.counts.entry(group.clone()).or_insert(0);
415 }
416 }
417 Ok(())
418 }
419 other => Err(RoutineError::FunctionInvalidArgumentType {
420 function: Fragment::internal("math::avg"),
421 argument_index: 0,
422 expected: InputTypes::numeric().expected_at(0).to_vec(),
423 actual: other.get_type(),
424 }),
425 }
426 }
427
428 fn finalize(&mut self) -> Result<(Vec<GroupKey>, ColumnBuffer), RoutineError> {
429 let mut keys = Vec::with_capacity(self.sums.len());
430 let mut data = ColumnBuffer::float8_with_capacity(self.sums.len());
431
432 for (key, sum) in mem::take(&mut self.sums) {
433 let count = self.counts.swap_remove(&key).unwrap_or(0);
434 keys.push(key);
435 if count > 0 {
436 data.push_value(Value::float8(sum / count as f64));
437 } else {
438 data.push_value(Value::none());
439 }
440 }
441
442 Ok((keys, data))
443 }
444}