Skip to main content

reifydb_routine/function/math/
avg.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use std::mem;
5
6use indexmap::IndexMap;
7use num_traits::ToPrimitive;
8use reifydb_core::value::column::{
9	ColumnWithName,
10	buffer::ColumnBuffer,
11	columns::Columns,
12	view::group_by::{GroupByView, GroupKey},
13};
14use reifydb_type::{
15	fragment::Fragment,
16	value::{
17		Value,
18		r#type::{Type, input_types::InputTypes},
19	},
20};
21
22use crate::routine::{
23	Accumulator, Function, FunctionKind, Routine, RoutineInfo, context::FunctionContext, error::RoutineError,
24};
25
/// The `math::avg` function.
///
/// It is usable both as a scalar function (per-row average across its argument
/// columns) and as an aggregate (per-group average of a single column).
pub struct Avg {
	/// Routine metadata; created with the name `"math::avg"`.
	info: RoutineInfo,
}
29
30impl Default for Avg {
31	fn default() -> Self {
32		Self::new()
33	}
34}
35
36impl Avg {
37	pub fn new() -> Self {
38		Self {
39			info: RoutineInfo::new("math::avg"),
40		}
41	}
42}
43
44impl<'a> Routine<FunctionContext<'a>> for Avg {
45	fn info(&self) -> &RoutineInfo {
46		&self.info
47	}
48
49	fn return_type(&self, _input_types: &[Type]) -> Type {
50		Type::Float8
51	}
52
53	fn accepted_types(&self) -> InputTypes {
54		InputTypes::numeric()
55	}
56
57	fn execute(&self, ctx: &mut FunctionContext<'a>, args: &Columns) -> Result<Columns, RoutineError> {
58		if args.is_empty() {
59			return Err(RoutineError::FunctionArityMismatch {
60				function: ctx.fragment.clone(),
61				expected: 1,
62				actual: 0,
63			});
64		}
65
66		let row_count = args.row_count();
67		let mut sum = vec![0.0f64; row_count];
68		let mut count = vec![0u32; row_count];
69
70		for (col_idx, col) in args.iter().enumerate() {
71			let (data, _bitvec) = col.data().unwrap_option();
72			match data {
73				ColumnBuffer::Int1(container) => {
74					for i in 0..row_count {
75						if let Some(value) = container.get(i) {
76							sum[i] += *value as f64;
77							count[i] += 1;
78						}
79					}
80				}
81				ColumnBuffer::Int2(container) => {
82					for i in 0..row_count {
83						if let Some(value) = container.get(i) {
84							sum[i] += *value as f64;
85							count[i] += 1;
86						}
87					}
88				}
89				ColumnBuffer::Int4(container) => {
90					for i in 0..row_count {
91						if let Some(value) = container.get(i) {
92							sum[i] += *value as f64;
93							count[i] += 1;
94						}
95					}
96				}
97				ColumnBuffer::Int8(container) => {
98					for i in 0..row_count {
99						if let Some(value) = container.get(i) {
100							sum[i] += *value as f64;
101							count[i] += 1;
102						}
103					}
104				}
105				ColumnBuffer::Int16(container) => {
106					for i in 0..row_count {
107						if let Some(value) = container.get(i) {
108							sum[i] += *value as f64;
109							count[i] += 1;
110						}
111					}
112				}
113				ColumnBuffer::Uint1(container) => {
114					for i in 0..row_count {
115						if let Some(value) = container.get(i) {
116							sum[i] += *value as f64;
117							count[i] += 1;
118						}
119					}
120				}
121				ColumnBuffer::Uint2(container) => {
122					for i in 0..row_count {
123						if let Some(value) = container.get(i) {
124							sum[i] += *value as f64;
125							count[i] += 1;
126						}
127					}
128				}
129				ColumnBuffer::Uint4(container) => {
130					for i in 0..row_count {
131						if let Some(value) = container.get(i) {
132							sum[i] += *value as f64;
133							count[i] += 1;
134						}
135					}
136				}
137				ColumnBuffer::Uint8(container) => {
138					for i in 0..row_count {
139						if let Some(value) = container.get(i) {
140							sum[i] += *value as f64;
141							count[i] += 1;
142						}
143					}
144				}
145				ColumnBuffer::Uint16(container) => {
146					for i in 0..row_count {
147						if let Some(value) = container.get(i) {
148							sum[i] += *value as f64;
149							count[i] += 1;
150						}
151					}
152				}
153				ColumnBuffer::Float4(container) => {
154					for i in 0..row_count {
155						if let Some(value) = container.get(i) {
156							sum[i] += *value as f64;
157							count[i] += 1;
158						}
159					}
160				}
161				ColumnBuffer::Float8(container) => {
162					for i in 0..row_count {
163						if let Some(value) = container.get(i) {
164							sum[i] += *value;
165							count[i] += 1;
166						}
167					}
168				}
169				ColumnBuffer::Int {
170					container,
171					..
172				} => {
173					for i in 0..row_count {
174						if let Some(value) = container.get(i) {
175							sum[i] += value.0.to_f64().unwrap_or(0.0);
176							count[i] += 1;
177						}
178					}
179				}
180				ColumnBuffer::Uint {
181					container,
182					..
183				} => {
184					for i in 0..row_count {
185						if let Some(value) = container.get(i) {
186							sum[i] += value.0.to_f64().unwrap_or(0.0);
187							count[i] += 1;
188						}
189					}
190				}
191				ColumnBuffer::Decimal {
192					container,
193					..
194				} => {
195					for i in 0..row_count {
196						if let Some(value) = container.get(i) {
197							sum[i] += value.0.to_f64().unwrap_or(0.0);
198							count[i] += 1;
199						}
200					}
201				}
202				other => {
203					return Err(RoutineError::FunctionInvalidArgumentType {
204						function: ctx.fragment.clone(),
205						argument_index: col_idx,
206						expected: self.accepted_types().expected_at(0).to_vec(),
207						actual: other.get_type(),
208					});
209				}
210			}
211		}
212
213		let mut data = Vec::with_capacity(row_count);
214		let mut valids = Vec::with_capacity(row_count);
215
216		for i in 0..row_count {
217			if count[i] > 0 {
218				data.push(sum[i] / count[i] as f64);
219				valids.push(true);
220			} else {
221				data.push(0.0);
222				valids.push(false);
223			}
224		}
225
226		Ok(Columns::new(vec![ColumnWithName::new(
227			ctx.fragment.clone(),
228			ColumnBuffer::float8_with_bitvec(data, valids),
229		)]))
230	}
231}
232
233impl Function for Avg {
234	fn kinds(&self) -> &[FunctionKind] {
235		&[FunctionKind::Scalar, FunctionKind::Aggregate]
236	}
237
238	fn accumulator(&self, _ctx: &mut FunctionContext<'_>) -> Option<Box<dyn Accumulator>> {
239		Some(Box::new(AvgAccumulator::new()))
240	}
241}
242
/// Running state for the aggregate form of `math::avg`.
///
/// Both maps are keyed by group and are always inserted into together, so a
/// group present in one is present in the other.
struct AvgAccumulator {
	/// Sum of the defined values seen so far, per group.
	pub sums: IndexMap<GroupKey, f64>,
	/// Number of defined values seen so far, per group.
	pub counts: IndexMap<GroupKey, u64>,
}
247
248impl AvgAccumulator {
249	pub fn new() -> Self {
250		Self {
251			sums: IndexMap::new(),
252			counts: IndexMap::new(),
253		}
254	}
255}
256
// Folds one batch of a primitive numeric container into an accumulator.
//
// For every `(group, indices)` pair yielded by `$groups.iter()`, the values at
// `indices` that `$column` reports as defined (and that `$container.get`
// actually yields) are cast to `f64` and summed. The partial sum and count are
// then merged into `$self.sums` / `$self.counts`. A group that contributed no
// value is still registered, with a sum of `0.0` and a count of `0`, unless it
// is already present.
macro_rules! avg_arm {
	($self:expr, $column:expr, $groups:expr, $container:expr) => {
		for (group, indices) in $groups.iter() {
			let mut partial_sum = 0.0f64;
			let mut partial_count = 0u64;
			for &row in indices {
				if !$column.is_defined(row) {
					continue;
				}
				if let Some(&val) = $container.get(row) {
					partial_sum += val as f64;
					partial_count += 1;
				}
			}
			match partial_count {
				0 => {
					// Nothing to add: only make sure the group exists.
					$self.sums.entry(group.clone()).or_insert(0.0);
					$self.counts.entry(group.clone()).or_insert(0);
				}
				_ => {
					$self.sums
						.entry(group.clone())
						.and_modify(|total| *total += partial_sum)
						.or_insert(partial_sum);
					$self.counts
						.entry(group.clone())
						.and_modify(|seen| *seen += partial_count)
						.or_insert(partial_count);
				}
			}
		}
	};
}
280
281impl Accumulator for AvgAccumulator {
282	fn update(&mut self, args: &Columns, groups: &GroupByView) -> Result<(), RoutineError> {
283		let column = &args[0];
284		let (data, _bitvec) = column.unwrap_option();
285
286		match data {
287			ColumnBuffer::Int1(container) => {
288				avg_arm!(self, column, groups, container);
289				Ok(())
290			}
291			ColumnBuffer::Int2(container) => {
292				avg_arm!(self, column, groups, container);
293				Ok(())
294			}
295			ColumnBuffer::Int4(container) => {
296				avg_arm!(self, column, groups, container);
297				Ok(())
298			}
299			ColumnBuffer::Int8(container) => {
300				avg_arm!(self, column, groups, container);
301				Ok(())
302			}
303			ColumnBuffer::Int16(container) => {
304				avg_arm!(self, column, groups, container);
305				Ok(())
306			}
307			ColumnBuffer::Uint1(container) => {
308				avg_arm!(self, column, groups, container);
309				Ok(())
310			}
311			ColumnBuffer::Uint2(container) => {
312				avg_arm!(self, column, groups, container);
313				Ok(())
314			}
315			ColumnBuffer::Uint4(container) => {
316				avg_arm!(self, column, groups, container);
317				Ok(())
318			}
319			ColumnBuffer::Uint8(container) => {
320				avg_arm!(self, column, groups, container);
321				Ok(())
322			}
323			ColumnBuffer::Uint16(container) => {
324				avg_arm!(self, column, groups, container);
325				Ok(())
326			}
327			ColumnBuffer::Float4(container) => {
328				avg_arm!(self, column, groups, container);
329				Ok(())
330			}
331			ColumnBuffer::Float8(container) => {
332				avg_arm!(self, column, groups, container);
333				Ok(())
334			}
335			ColumnBuffer::Int {
336				container,
337				..
338			} => {
339				for (group, indices) in groups.iter() {
340					let mut sum = 0.0f64;
341					let mut count = 0u64;
342					for &i in indices {
343						if column.is_defined(i)
344							&& let Some(val) = container.get(i)
345						{
346							sum += val.0.to_f64().unwrap_or(0.0);
347							count += 1;
348						}
349					}
350					if count > 0 {
351						self.sums.entry(group.clone()).and_modify(|v| *v += sum).or_insert(sum);
352						self.counts
353							.entry(group.clone())
354							.and_modify(|c| *c += count)
355							.or_insert(count);
356					} else {
357						self.sums.entry(group.clone()).or_insert(0.0);
358						self.counts.entry(group.clone()).or_insert(0);
359					}
360				}
361				Ok(())
362			}
363			ColumnBuffer::Uint {
364				container,
365				..
366			} => {
367				for (group, indices) in groups.iter() {
368					let mut sum = 0.0f64;
369					let mut count = 0u64;
370					for &i in indices {
371						if column.is_defined(i)
372							&& let Some(val) = container.get(i)
373						{
374							sum += val.0.to_f64().unwrap_or(0.0);
375							count += 1;
376						}
377					}
378					if count > 0 {
379						self.sums.entry(group.clone()).and_modify(|v| *v += sum).or_insert(sum);
380						self.counts
381							.entry(group.clone())
382							.and_modify(|c| *c += count)
383							.or_insert(count);
384					} else {
385						self.sums.entry(group.clone()).or_insert(0.0);
386						self.counts.entry(group.clone()).or_insert(0);
387					}
388				}
389				Ok(())
390			}
391			ColumnBuffer::Decimal {
392				container,
393				..
394			} => {
395				for (group, indices) in groups.iter() {
396					let mut sum = 0.0f64;
397					let mut count = 0u64;
398					for &i in indices {
399						if column.is_defined(i)
400							&& let Some(val) = container.get(i)
401						{
402							sum += val.0.to_f64().unwrap_or(0.0);
403							count += 1;
404						}
405					}
406					if count > 0 {
407						self.sums.entry(group.clone()).and_modify(|v| *v += sum).or_insert(sum);
408						self.counts
409							.entry(group.clone())
410							.and_modify(|c| *c += count)
411							.or_insert(count);
412					} else {
413						self.sums.entry(group.clone()).or_insert(0.0);
414						self.counts.entry(group.clone()).or_insert(0);
415					}
416				}
417				Ok(())
418			}
419			other => Err(RoutineError::FunctionInvalidArgumentType {
420				function: Fragment::internal("math::avg"),
421				argument_index: 0,
422				expected: InputTypes::numeric().expected_at(0).to_vec(),
423				actual: other.get_type(),
424			}),
425		}
426	}
427
428	fn finalize(&mut self) -> Result<(Vec<GroupKey>, ColumnBuffer), RoutineError> {
429		let mut keys = Vec::with_capacity(self.sums.len());
430		let mut data = ColumnBuffer::float8_with_capacity(self.sums.len());
431
432		for (key, sum) in mem::take(&mut self.sums) {
433			let count = self.counts.swap_remove(&key).unwrap_or(0);
434			keys.push(key);
435			if count > 0 {
436				data.push_value(Value::float8(sum / count as f64));
437			} else {
438				data.push_value(Value::none());
439			}
440		}
441
442		Ok((keys, data))
443	}
444}