reifydb_engine/function/math/aggregate/
avg.rs

1// Copyright (c) reifydb.com 2025
2// This file is licensed under the AGPL-3.0-or-later, see license.md file
3
4use indexmap::IndexMap;
5use reifydb_core::value::column::ColumnData;
6use reifydb_type::Value;
7
8use crate::function::{AggregateFunction, AggregateFunctionContext};
9
10pub struct Avg {
11	pub sums: IndexMap<Vec<Value>, f64>,
12	pub counts: IndexMap<Vec<Value>, u64>,
13}
14
15impl Avg {
16	pub fn new() -> Self {
17		Self {
18			sums: IndexMap::new(),
19			counts: IndexMap::new(),
20		}
21	}
22}
23
24impl AggregateFunction for Avg {
25	fn aggregate(&mut self, ctx: AggregateFunctionContext) -> crate::Result<()> {
26		let column = ctx.column;
27		let groups = &ctx.groups;
28
29		match &column.data() {
30			ColumnData::Float8(container) => {
31				for (group, indices) in groups.iter() {
32					let mut sum = 0.0;
33					let mut count = 0;
34
35					for &i in indices {
36						if let Some(value) = container.get(i) {
37							sum += *value;
38							count += 1;
39						}
40					}
41
42					if count > 0 {
43						self.sums.entry(group.clone()).and_modify(|v| *v += sum).or_insert(sum);
44
45						self.counts
46							.entry(group.clone())
47							.and_modify(|c| *c += count)
48							.or_insert(count);
49					}
50				}
51				Ok(())
52			}
53			ColumnData::Float4(container) => {
54				for (group, indices) in groups.iter() {
55					let mut sum = 0.0;
56					let mut count = 0;
57
58					for &i in indices {
59						if let Some(value) = container.get(i) {
60							sum += *value as f64;
61							count += 1;
62						}
63					}
64
65					if count > 0 {
66						self.sums.entry(group.clone()).and_modify(|v| *v += sum).or_insert(sum);
67
68						self.counts
69							.entry(group.clone())
70							.and_modify(|c| *c += count)
71							.or_insert(count);
72					}
73				}
74				Ok(())
75			}
76			ColumnData::Int2(container) => {
77				for (group, indices) in groups.iter() {
78					let mut sum = 0.0;
79					let mut count = 0;
80
81					for &i in indices {
82						if let Some(value) = container.get(i) {
83							sum += *value as f64;
84							count += 1;
85						}
86					}
87
88					if count > 0 {
89						self.sums.entry(group.clone()).and_modify(|v| *v += sum).or_insert(sum);
90
91						self.counts
92							.entry(group.clone())
93							.and_modify(|c| *c += count)
94							.or_insert(count);
95					}
96				}
97				Ok(())
98			}
99			ColumnData::Int4(container) => {
100				for (group, indices) in groups.iter() {
101					let mut sum = 0.0;
102					let mut count = 0;
103
104					for &i in indices {
105						if let Some(value) = container.get(i) {
106							sum += *value as f64;
107							count += 1;
108						}
109					}
110
111					if count > 0 {
112						self.sums.entry(group.clone()).and_modify(|v| *v += sum).or_insert(sum);
113
114						self.counts
115							.entry(group.clone())
116							.and_modify(|c| *c += count)
117							.or_insert(count);
118					}
119				}
120				Ok(())
121			}
122			ColumnData::Int8(container) => {
123				for (group, indices) in groups.iter() {
124					let mut sum = 0.0;
125					let mut count = 0;
126
127					for &i in indices {
128						if let Some(value) = container.get(i) {
129							sum += *value as f64;
130							count += 1;
131						}
132					}
133
134					if count > 0 {
135						self.sums.entry(group.clone()).and_modify(|v| *v += sum).or_insert(sum);
136
137						self.counts
138							.entry(group.clone())
139							.and_modify(|c| *c += count)
140							.or_insert(count);
141					}
142				}
143				Ok(())
144			}
145			_ => unimplemented!(),
146		}
147	}
148
149	fn finalize(&mut self) -> crate::Result<(Vec<Vec<Value>>, ColumnData)> {
150		let mut keys = Vec::with_capacity(self.sums.len());
151		let mut data = ColumnData::float8_with_capacity(self.sums.len());
152
153		for (key, sum) in std::mem::take(&mut self.sums) {
154			let count = self.counts.swap_remove(&key).unwrap_or(0);
155			let avg = if count > 0 {
156				sum / count as f64
157			} else {
158				keys.push(key);
159				data.push_value(Value::undefined());
160				return Ok((keys, data));
161			};
162
163			keys.push(key);
164			data.push_value(Value::float8(avg));
165		}
166
167		Ok((keys, data))
168	}
169}