Skip to main content

reifydb_routine/function/math/aggregate/
avg.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use std::mem;
5
6use indexmap::IndexMap;
7use num_traits::ToPrimitive;
8use reifydb_core::value::column::data::ColumnData;
9use reifydb_type::value::{
10	Value,
11	r#type::{Type, input_types::InputTypes},
12};
13
14use crate::function::{
15	AggregateFunction, AggregateFunctionContext,
16	error::{AggregateFunctionError, AggregateFunctionResult},
17};
18
19pub struct Avg {
20	pub sums: IndexMap<Vec<Value>, f64>,
21	pub counts: IndexMap<Vec<Value>, u64>,
22}
23
24impl Avg {
25	pub fn new() -> Self {
26		Self {
27			sums: IndexMap::new(),
28			counts: IndexMap::new(),
29		}
30	}
31}
32
// Folds one primitive numeric container into the running per-group state.
//
// Arguments:
// * `$self`      - accumulator exposing the `sums` and `counts` maps
// * `$column`    - column whose `data().is_defined(i)` tells whether row `i` holds a value
// * `$groups`    - iterable of `(group key, row indices)` pairs
// * `$container` - typed storage; `get(i)` yields a reference to a value castable to `f64`
//
// Every group that is visited ends up with an entry in both maps, even when
// none of its rows contributed a value (it is then registered as 0.0 / 0).
macro_rules! avg_arm {
	($self:expr, $column:expr, $groups:expr, $container:expr) => {
		for (key, rows) in $groups.iter() {
			let mut total = 0.0f64;
			let mut seen = 0u64;
			for &row in rows {
				// Undefined rows contribute neither to the sum nor to the count.
				if !$column.data().is_defined(row) {
					continue;
				}
				if let Some(&value) = $container.get(row) {
					total += value as f64;
					seen += 1;
				}
			}
			if seen == 0 {
				// Nothing to add: only make sure the group is known,
				// leaving any previously accumulated state untouched.
				$self.sums.entry(key.clone()).or_insert(0.0);
				$self.counts.entry(key.clone()).or_insert(0);
				continue;
			}
			*$self.sums.entry(key.clone()).or_insert(0.0) += total;
			*$self.counts.entry(key.clone()).or_insert(0) += seen;
		}
	};
}
56
57impl AggregateFunction for Avg {
58	fn aggregate(&mut self, ctx: AggregateFunctionContext) -> AggregateFunctionResult<()> {
59		let column = ctx.column;
60		let groups = &ctx.groups;
61		let (data, _bitvec) = column.data().unwrap_option();
62
63		match data {
64			ColumnData::Int1(container) => {
65				avg_arm!(self, column, groups, container);
66				Ok(())
67			}
68			ColumnData::Int2(container) => {
69				avg_arm!(self, column, groups, container);
70				Ok(())
71			}
72			ColumnData::Int4(container) => {
73				avg_arm!(self, column, groups, container);
74				Ok(())
75			}
76			ColumnData::Int8(container) => {
77				avg_arm!(self, column, groups, container);
78				Ok(())
79			}
80			ColumnData::Int16(container) => {
81				avg_arm!(self, column, groups, container);
82				Ok(())
83			}
84			ColumnData::Uint1(container) => {
85				avg_arm!(self, column, groups, container);
86				Ok(())
87			}
88			ColumnData::Uint2(container) => {
89				avg_arm!(self, column, groups, container);
90				Ok(())
91			}
92			ColumnData::Uint4(container) => {
93				avg_arm!(self, column, groups, container);
94				Ok(())
95			}
96			ColumnData::Uint8(container) => {
97				avg_arm!(self, column, groups, container);
98				Ok(())
99			}
100			ColumnData::Uint16(container) => {
101				avg_arm!(self, column, groups, container);
102				Ok(())
103			}
104			ColumnData::Float4(container) => {
105				avg_arm!(self, column, groups, container);
106				Ok(())
107			}
108			ColumnData::Float8(container) => {
109				avg_arm!(self, column, groups, container);
110				Ok(())
111			}
112			ColumnData::Int {
113				container,
114				..
115			} => {
116				for (group, indices) in groups.iter() {
117					let mut sum = 0.0f64;
118					let mut count = 0u64;
119					for &i in indices {
120						if column.data().is_defined(i) {
121							if let Some(val) = container.get(i) {
122								sum += val.0.to_f64().unwrap_or(0.0);
123								count += 1;
124							}
125						}
126					}
127					if count > 0 {
128						self.sums.entry(group.clone()).and_modify(|v| *v += sum).or_insert(sum);
129						self.counts
130							.entry(group.clone())
131							.and_modify(|c| *c += count)
132							.or_insert(count);
133					} else {
134						self.sums.entry(group.clone()).or_insert(0.0);
135						self.counts.entry(group.clone()).or_insert(0);
136					}
137				}
138				Ok(())
139			}
140			ColumnData::Uint {
141				container,
142				..
143			} => {
144				for (group, indices) in groups.iter() {
145					let mut sum = 0.0f64;
146					let mut count = 0u64;
147					for &i in indices {
148						if column.data().is_defined(i) {
149							if let Some(val) = container.get(i) {
150								sum += val.0.to_f64().unwrap_or(0.0);
151								count += 1;
152							}
153						}
154					}
155					if count > 0 {
156						self.sums.entry(group.clone()).and_modify(|v| *v += sum).or_insert(sum);
157						self.counts
158							.entry(group.clone())
159							.and_modify(|c| *c += count)
160							.or_insert(count);
161					} else {
162						self.sums.entry(group.clone()).or_insert(0.0);
163						self.counts.entry(group.clone()).or_insert(0);
164					}
165				}
166				Ok(())
167			}
168			ColumnData::Decimal {
169				container,
170				..
171			} => {
172				for (group, indices) in groups.iter() {
173					let mut sum = 0.0f64;
174					let mut count = 0u64;
175					for &i in indices {
176						if column.data().is_defined(i) {
177							if let Some(val) = container.get(i) {
178								sum += val.0.to_f64().unwrap_or(0.0);
179								count += 1;
180							}
181						}
182					}
183					if count > 0 {
184						self.sums.entry(group.clone()).and_modify(|v| *v += sum).or_insert(sum);
185						self.counts
186							.entry(group.clone())
187							.and_modify(|c| *c += count)
188							.or_insert(count);
189					} else {
190						self.sums.entry(group.clone()).or_insert(0.0);
191						self.counts.entry(group.clone()).or_insert(0);
192					}
193				}
194				Ok(())
195			}
196			other => Err(AggregateFunctionError::InvalidArgumentType {
197				function: ctx.fragment.clone(),
198				argument_index: 0,
199				expected: self.accepted_types().expected_at(0).to_vec(),
200				actual: other.get_type(),
201			}),
202		}
203	}
204
205	fn finalize(&mut self) -> AggregateFunctionResult<(Vec<Vec<Value>>, ColumnData)> {
206		let mut keys = Vec::with_capacity(self.sums.len());
207		let mut data = ColumnData::float8_with_capacity(self.sums.len());
208
209		for (key, sum) in mem::take(&mut self.sums) {
210			let count = self.counts.swap_remove(&key).unwrap_or(0);
211			keys.push(key);
212			if count > 0 {
213				data.push_value(Value::float8(sum / count as f64));
214			} else {
215				data.push_value(Value::none());
216			}
217		}
218
219		Ok((keys, data))
220	}
221
222	fn return_type(&self, _input_type: &Type) -> Type {
223		Type::Float8
224	}
225
226	fn accepted_types(&self) -> InputTypes {
227		InputTypes::numeric()
228	}
229}