reifydb_engine/function/math/aggregate/
avg.rs

1// Copyright (c) reifydb.com 2025
2// This file is licensed under the AGPL-3.0-or-later, see license.md file
3
4use std::collections::HashMap;
5
6use reifydb_core::value::column::ColumnData;
7use reifydb_type::Value;
8
9use crate::function::{AggregateFunction, AggregateFunctionContext};
10
11pub struct Avg {
12	pub sums: HashMap<Vec<Value>, f64>,
13	pub counts: HashMap<Vec<Value>, u64>,
14}
15
16impl Avg {
17	pub fn new() -> Self {
18		Self {
19			sums: HashMap::new(),
20			counts: HashMap::new(),
21		}
22	}
23}
24
25impl AggregateFunction for Avg {
26	fn aggregate(&mut self, ctx: AggregateFunctionContext) -> crate::Result<()> {
27		let column = ctx.column;
28		let groups = &ctx.groups;
29
30		match &column.data() {
31			ColumnData::Float8(container) => {
32				for (group, indices) in groups.iter() {
33					let mut sum = 0.0;
34					let mut count = 0;
35
36					for &i in indices {
37						if let Some(value) = container.get(i) {
38							sum += *value;
39							count += 1;
40						}
41					}
42
43					if count > 0 {
44						self.sums.entry(group.clone()).and_modify(|v| *v += sum).or_insert(sum);
45
46						self.counts
47							.entry(group.clone())
48							.and_modify(|c| *c += count)
49							.or_insert(count);
50					}
51				}
52				Ok(())
53			}
54			ColumnData::Float4(container) => {
55				for (group, indices) in groups.iter() {
56					let mut sum = 0.0;
57					let mut count = 0;
58
59					for &i in indices {
60						if let Some(value) = container.get(i) {
61							sum += *value as f64;
62							count += 1;
63						}
64					}
65
66					if count > 0 {
67						self.sums.entry(group.clone()).and_modify(|v| *v += sum).or_insert(sum);
68
69						self.counts
70							.entry(group.clone())
71							.and_modify(|c| *c += count)
72							.or_insert(count);
73					}
74				}
75				Ok(())
76			}
77			ColumnData::Int2(container) => {
78				for (group, indices) in groups.iter() {
79					let mut sum = 0.0;
80					let mut count = 0;
81
82					for &i in indices {
83						if let Some(value) = container.get(i) {
84							sum += *value as f64;
85							count += 1;
86						}
87					}
88
89					if count > 0 {
90						self.sums.entry(group.clone()).and_modify(|v| *v += sum).or_insert(sum);
91
92						self.counts
93							.entry(group.clone())
94							.and_modify(|c| *c += count)
95							.or_insert(count);
96					}
97				}
98				Ok(())
99			}
100			ColumnData::Int4(container) => {
101				for (group, indices) in groups.iter() {
102					let mut sum = 0.0;
103					let mut count = 0;
104
105					for &i in indices {
106						if let Some(value) = container.get(i) {
107							sum += *value as f64;
108							count += 1;
109						}
110					}
111
112					if count > 0 {
113						self.sums.entry(group.clone()).and_modify(|v| *v += sum).or_insert(sum);
114
115						self.counts
116							.entry(group.clone())
117							.and_modify(|c| *c += count)
118							.or_insert(count);
119					}
120				}
121				Ok(())
122			}
123			ColumnData::Int8(container) => {
124				for (group, indices) in groups.iter() {
125					let mut sum = 0.0;
126					let mut count = 0;
127
128					for &i in indices {
129						if let Some(value) = container.get(i) {
130							sum += *value as f64;
131							count += 1;
132						}
133					}
134
135					if count > 0 {
136						self.sums.entry(group.clone()).and_modify(|v| *v += sum).or_insert(sum);
137
138						self.counts
139							.entry(group.clone())
140							.and_modify(|c| *c += count)
141							.or_insert(count);
142					}
143				}
144				Ok(())
145			}
146			_ => unimplemented!(),
147		}
148	}
149
150	fn finalize(&mut self) -> crate::Result<(Vec<Vec<Value>>, ColumnData)> {
151		let mut keys = Vec::with_capacity(self.sums.len());
152		let mut data = ColumnData::float8_with_capacity(self.sums.len());
153
154		for (key, sum) in std::mem::take(&mut self.sums) {
155			let count = self.counts.remove(&key).unwrap_or(0);
156			let avg = if count > 0 {
157				sum / count as f64
158			} else {
159				f64::NAN // or return Value::Undefined if preferred
160			};
161
162			keys.push(key);
163			data.push_value(Value::float8(avg));
164		}
165
166		Ok((keys, data))
167	}
168}