Skip to main content

reifydb_function/math/aggregate/
avg.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// Copyright (c) 2025 ReifyDB
3
4use std::mem;
5
6use indexmap::IndexMap;
7use reifydb_core::value::column::data::ColumnData;
8use reifydb_type::value::Value;
9
10use crate::{AggregateFunction, AggregateFunctionContext, error::AggregateFunctionResult};
11
12pub struct Avg {
13	pub sums: IndexMap<Vec<Value>, f64>,
14	pub counts: IndexMap<Vec<Value>, u64>,
15}
16
17impl Avg {
18	pub fn new() -> Self {
19		Self {
20			sums: IndexMap::new(),
21			counts: IndexMap::new(),
22		}
23	}
24}
25
26impl AggregateFunction for Avg {
27	fn aggregate(&mut self, ctx: AggregateFunctionContext) -> AggregateFunctionResult<()> {
28		let column = ctx.column;
29		let groups = &ctx.groups;
30
31		match &column.data() {
32			ColumnData::Float8(container) => {
33				for (group, indices) in groups.iter() {
34					let mut sum = 0.0;
35					let mut count = 0;
36
37					for &i in indices {
38						if let Some(value) = container.get(i) {
39							sum += *value;
40							count += 1;
41						}
42					}
43
44					if count > 0 {
45						self.sums.entry(group.clone()).and_modify(|v| *v += sum).or_insert(sum);
46
47						self.counts
48							.entry(group.clone())
49							.and_modify(|c| *c += count)
50							.or_insert(count);
51					}
52				}
53				Ok(())
54			}
55			ColumnData::Float4(container) => {
56				for (group, indices) in groups.iter() {
57					let mut sum = 0.0;
58					let mut count = 0;
59
60					for &i in indices {
61						if let Some(value) = container.get(i) {
62							sum += *value as f64;
63							count += 1;
64						}
65					}
66
67					if count > 0 {
68						self.sums.entry(group.clone()).and_modify(|v| *v += sum).or_insert(sum);
69
70						self.counts
71							.entry(group.clone())
72							.and_modify(|c| *c += count)
73							.or_insert(count);
74					}
75				}
76				Ok(())
77			}
78			ColumnData::Int2(container) => {
79				for (group, indices) in groups.iter() {
80					let mut sum = 0.0;
81					let mut count = 0;
82
83					for &i in indices {
84						if let Some(value) = container.get(i) {
85							sum += *value as f64;
86							count += 1;
87						}
88					}
89
90					if count > 0 {
91						self.sums.entry(group.clone()).and_modify(|v| *v += sum).or_insert(sum);
92
93						self.counts
94							.entry(group.clone())
95							.and_modify(|c| *c += count)
96							.or_insert(count);
97					}
98				}
99				Ok(())
100			}
101			ColumnData::Int4(container) => {
102				for (group, indices) in groups.iter() {
103					let mut sum = 0.0;
104					let mut count = 0;
105
106					for &i in indices {
107						if let Some(value) = container.get(i) {
108							sum += *value as f64;
109							count += 1;
110						}
111					}
112
113					if count > 0 {
114						self.sums.entry(group.clone()).and_modify(|v| *v += sum).or_insert(sum);
115
116						self.counts
117							.entry(group.clone())
118							.and_modify(|c| *c += count)
119							.or_insert(count);
120					}
121				}
122				Ok(())
123			}
124			ColumnData::Int8(container) => {
125				for (group, indices) in groups.iter() {
126					let mut sum = 0.0;
127					let mut count = 0;
128
129					for &i in indices {
130						if let Some(value) = container.get(i) {
131							sum += *value as f64;
132							count += 1;
133						}
134					}
135
136					if count > 0 {
137						self.sums.entry(group.clone()).and_modify(|v| *v += sum).or_insert(sum);
138
139						self.counts
140							.entry(group.clone())
141							.and_modify(|c| *c += count)
142							.or_insert(count);
143					}
144				}
145				Ok(())
146			}
147			_ => unimplemented!(),
148		}
149	}
150
151	fn finalize(&mut self) -> AggregateFunctionResult<(Vec<Vec<Value>>, ColumnData)> {
152		let mut keys = Vec::with_capacity(self.sums.len());
153		let mut data = ColumnData::float8_with_capacity(self.sums.len());
154
155		for (key, sum) in mem::take(&mut self.sums) {
156			let count = self.counts.swap_remove(&key).unwrap_or(0);
157			let avg = if count > 0 {
158				sum / count as f64
159			} else {
160				keys.push(key);
161				data.push_value(Value::none());
162				return Ok((keys, data));
163			};
164
165			keys.push(key);
166			data.push_value(Value::float8(avg));
167		}
168
169		Ok((keys, data))
170	}
171}