Skip to main content

reifydb_routine/function/math/
avg.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use std::mem;
5
6use indexmap::IndexMap;
7use num_traits::ToPrimitive;
8use reifydb_core::value::column::{
9	Column,
10	columns::Columns,
11	data::ColumnData,
12	view::group_by::{GroupByView, GroupKey},
13};
14use reifydb_type::{
15	fragment::Fragment,
16	value::{
17		Value,
18		r#type::{Type, input_types::InputTypes},
19	},
20};
21
22use crate::function::{Accumulator, Function, FunctionCapability, FunctionContext, FunctionInfo, error::FunctionError};
23
24pub struct Avg {
25	info: FunctionInfo,
26}
27
28impl Default for Avg {
29	fn default() -> Self {
30		Self::new()
31	}
32}
33
34impl Avg {
35	pub fn new() -> Self {
36		Self {
37			info: FunctionInfo::new("math::avg"),
38		}
39	}
40}
41
42impl Function for Avg {
43	fn info(&self) -> &FunctionInfo {
44		&self.info
45	}
46
47	fn capabilities(&self) -> &[FunctionCapability] {
48		&[FunctionCapability::Scalar, FunctionCapability::Aggregate]
49	}
50
51	fn return_type(&self, _input_types: &[Type]) -> Type {
52		Type::Float8
53	}
54
55	fn accepted_types(&self) -> InputTypes {
56		InputTypes::numeric()
57	}
58
59	fn execute(&self, ctx: &FunctionContext, args: &Columns) -> Result<Columns, FunctionError> {
60		if args.is_empty() {
61			return Err(FunctionError::ArityMismatch {
62				function: ctx.fragment.clone(),
63				expected: 1,
64				actual: 0,
65			});
66		}
67
68		let row_count = args.row_count();
69		let mut sum = vec![0.0f64; row_count];
70		let mut count = vec![0u32; row_count];
71
72		for (col_idx, col) in args.iter().enumerate() {
73			let (data, _bitvec) = col.data().unwrap_option();
74			match data {
75				ColumnData::Int1(container) => {
76					for i in 0..row_count {
77						if let Some(value) = container.get(i) {
78							sum[i] += *value as f64;
79							count[i] += 1;
80						}
81					}
82				}
83				ColumnData::Int2(container) => {
84					for i in 0..row_count {
85						if let Some(value) = container.get(i) {
86							sum[i] += *value as f64;
87							count[i] += 1;
88						}
89					}
90				}
91				ColumnData::Int4(container) => {
92					for i in 0..row_count {
93						if let Some(value) = container.get(i) {
94							sum[i] += *value as f64;
95							count[i] += 1;
96						}
97					}
98				}
99				ColumnData::Int8(container) => {
100					for i in 0..row_count {
101						if let Some(value) = container.get(i) {
102							sum[i] += *value as f64;
103							count[i] += 1;
104						}
105					}
106				}
107				ColumnData::Int16(container) => {
108					for i in 0..row_count {
109						if let Some(value) = container.get(i) {
110							sum[i] += *value as f64;
111							count[i] += 1;
112						}
113					}
114				}
115				ColumnData::Uint1(container) => {
116					for i in 0..row_count {
117						if let Some(value) = container.get(i) {
118							sum[i] += *value as f64;
119							count[i] += 1;
120						}
121					}
122				}
123				ColumnData::Uint2(container) => {
124					for i in 0..row_count {
125						if let Some(value) = container.get(i) {
126							sum[i] += *value as f64;
127							count[i] += 1;
128						}
129					}
130				}
131				ColumnData::Uint4(container) => {
132					for i in 0..row_count {
133						if let Some(value) = container.get(i) {
134							sum[i] += *value as f64;
135							count[i] += 1;
136						}
137					}
138				}
139				ColumnData::Uint8(container) => {
140					for i in 0..row_count {
141						if let Some(value) = container.get(i) {
142							sum[i] += *value as f64;
143							count[i] += 1;
144						}
145					}
146				}
147				ColumnData::Uint16(container) => {
148					for i in 0..row_count {
149						if let Some(value) = container.get(i) {
150							sum[i] += *value as f64;
151							count[i] += 1;
152						}
153					}
154				}
155				ColumnData::Float4(container) => {
156					for i in 0..row_count {
157						if let Some(value) = container.get(i) {
158							sum[i] += *value as f64;
159							count[i] += 1;
160						}
161					}
162				}
163				ColumnData::Float8(container) => {
164					for i in 0..row_count {
165						if let Some(value) = container.get(i) {
166							sum[i] += *value;
167							count[i] += 1;
168						}
169					}
170				}
171				ColumnData::Int {
172					container,
173					..
174				} => {
175					for i in 0..row_count {
176						if let Some(value) = container.get(i) {
177							sum[i] += value.0.to_f64().unwrap_or(0.0);
178							count[i] += 1;
179						}
180					}
181				}
182				ColumnData::Uint {
183					container,
184					..
185				} => {
186					for i in 0..row_count {
187						if let Some(value) = container.get(i) {
188							sum[i] += value.0.to_f64().unwrap_or(0.0);
189							count[i] += 1;
190						}
191					}
192				}
193				ColumnData::Decimal {
194					container,
195					..
196				} => {
197					for i in 0..row_count {
198						if let Some(value) = container.get(i) {
199							sum[i] += value.0.to_f64().unwrap_or(0.0);
200							count[i] += 1;
201						}
202					}
203				}
204				other => {
205					return Err(FunctionError::InvalidArgumentType {
206						function: ctx.fragment.clone(),
207						argument_index: col_idx,
208						expected: self.accepted_types().expected_at(0).to_vec(),
209						actual: other.get_type(),
210					});
211				}
212			}
213		}
214
215		let mut data = Vec::with_capacity(row_count);
216		let mut valids = Vec::with_capacity(row_count);
217
218		for i in 0..row_count {
219			if count[i] > 0 {
220				data.push(sum[i] / count[i] as f64);
221				valids.push(true);
222			} else {
223				data.push(0.0);
224				valids.push(false);
225			}
226		}
227
228		Ok(Columns::new(vec![Column::new(ctx.fragment.clone(), ColumnData::float8_with_bitvec(data, valids))]))
229	}
230
231	fn accumulator(&self, _ctx: &FunctionContext) -> Option<Box<dyn Accumulator>> {
232		Some(Box::new(AvgAccumulator::new()))
233	}
234}
235
236struct AvgAccumulator {
237	pub sums: IndexMap<GroupKey, f64>,
238	pub counts: IndexMap<GroupKey, u64>,
239}
240
241impl AvgAccumulator {
242	pub fn new() -> Self {
243		Self {
244			sums: IndexMap::new(),
245			counts: IndexMap::new(),
246		}
247	}
248}
249
/// Folds a fixed-width numeric container into an `AvgAccumulator`'s per-group totals.
///
/// Arguments:
/// - `$self`: the accumulator.
/// - `$column`: the original column; its `is_defined` filters out none entries.
/// - `$groups`: yields `(group key, row indices)` pairs.
/// - `$container`: the unwrapped data, whose elements are cast to `f64`.
///
/// A group with no defined values is still registered (sum `0.0`, count `0`), so it shows up
/// when the accumulator is finalized.
macro_rules! avg_arm {
	($self:expr, $column:expr, $groups:expr, $container:expr) => {
		for (group, indices) in $groups.iter() {
			let mut group_sum = 0.0f64;
			let mut group_count = 0u64;
			for &row in indices {
				if !$column.data().is_defined(row) {
					continue;
				}
				if let Some(&value) = $container.get(row) {
					group_sum += value as f64;
					group_count += 1;
				}
			}
			if group_count == 0 {
				$self.sums.entry(group.clone()).or_insert(0.0);
				$self.counts.entry(group.clone()).or_insert(0);
			} else {
				$self.sums.entry(group.clone()).and_modify(|s| *s += group_sum).or_insert(group_sum);
				$self.counts.entry(group.clone()).and_modify(|n| *n += group_count).or_insert(group_count);
			}
		}
	};
}
273
274impl Accumulator for AvgAccumulator {
275	fn update(&mut self, args: &Columns, groups: &GroupByView) -> Result<(), FunctionError> {
276		let column = &args[0];
277		let (data, _bitvec) = column.data().unwrap_option();
278
279		match data {
280			ColumnData::Int1(container) => {
281				avg_arm!(self, column, groups, container);
282				Ok(())
283			}
284			ColumnData::Int2(container) => {
285				avg_arm!(self, column, groups, container);
286				Ok(())
287			}
288			ColumnData::Int4(container) => {
289				avg_arm!(self, column, groups, container);
290				Ok(())
291			}
292			ColumnData::Int8(container) => {
293				avg_arm!(self, column, groups, container);
294				Ok(())
295			}
296			ColumnData::Int16(container) => {
297				avg_arm!(self, column, groups, container);
298				Ok(())
299			}
300			ColumnData::Uint1(container) => {
301				avg_arm!(self, column, groups, container);
302				Ok(())
303			}
304			ColumnData::Uint2(container) => {
305				avg_arm!(self, column, groups, container);
306				Ok(())
307			}
308			ColumnData::Uint4(container) => {
309				avg_arm!(self, column, groups, container);
310				Ok(())
311			}
312			ColumnData::Uint8(container) => {
313				avg_arm!(self, column, groups, container);
314				Ok(())
315			}
316			ColumnData::Uint16(container) => {
317				avg_arm!(self, column, groups, container);
318				Ok(())
319			}
320			ColumnData::Float4(container) => {
321				avg_arm!(self, column, groups, container);
322				Ok(())
323			}
324			ColumnData::Float8(container) => {
325				avg_arm!(self, column, groups, container);
326				Ok(())
327			}
328			ColumnData::Int {
329				container,
330				..
331			} => {
332				for (group, indices) in groups.iter() {
333					let mut sum = 0.0f64;
334					let mut count = 0u64;
335					for &i in indices {
336						if column.data().is_defined(i)
337							&& let Some(val) = container.get(i)
338						{
339							sum += val.0.to_f64().unwrap_or(0.0);
340							count += 1;
341						}
342					}
343					if count > 0 {
344						self.sums.entry(group.clone()).and_modify(|v| *v += sum).or_insert(sum);
345						self.counts
346							.entry(group.clone())
347							.and_modify(|c| *c += count)
348							.or_insert(count);
349					} else {
350						self.sums.entry(group.clone()).or_insert(0.0);
351						self.counts.entry(group.clone()).or_insert(0);
352					}
353				}
354				Ok(())
355			}
356			ColumnData::Uint {
357				container,
358				..
359			} => {
360				for (group, indices) in groups.iter() {
361					let mut sum = 0.0f64;
362					let mut count = 0u64;
363					for &i in indices {
364						if column.data().is_defined(i)
365							&& let Some(val) = container.get(i)
366						{
367							sum += val.0.to_f64().unwrap_or(0.0);
368							count += 1;
369						}
370					}
371					if count > 0 {
372						self.sums.entry(group.clone()).and_modify(|v| *v += sum).or_insert(sum);
373						self.counts
374							.entry(group.clone())
375							.and_modify(|c| *c += count)
376							.or_insert(count);
377					} else {
378						self.sums.entry(group.clone()).or_insert(0.0);
379						self.counts.entry(group.clone()).or_insert(0);
380					}
381				}
382				Ok(())
383			}
384			ColumnData::Decimal {
385				container,
386				..
387			} => {
388				for (group, indices) in groups.iter() {
389					let mut sum = 0.0f64;
390					let mut count = 0u64;
391					for &i in indices {
392						if column.data().is_defined(i)
393							&& let Some(val) = container.get(i)
394						{
395							sum += val.0.to_f64().unwrap_or(0.0);
396							count += 1;
397						}
398					}
399					if count > 0 {
400						self.sums.entry(group.clone()).and_modify(|v| *v += sum).or_insert(sum);
401						self.counts
402							.entry(group.clone())
403							.and_modify(|c| *c += count)
404							.or_insert(count);
405					} else {
406						self.sums.entry(group.clone()).or_insert(0.0);
407						self.counts.entry(group.clone()).or_insert(0);
408					}
409				}
410				Ok(())
411			}
412			other => Err(FunctionError::InvalidArgumentType {
413				function: Fragment::internal("math::avg"),
414				argument_index: 0,
415				expected: InputTypes::numeric().expected_at(0).to_vec(),
416				actual: other.get_type(),
417			}),
418		}
419	}
420
421	fn finalize(&mut self) -> Result<(Vec<GroupKey>, ColumnData), FunctionError> {
422		let mut keys = Vec::with_capacity(self.sums.len());
423		let mut data = ColumnData::float8_with_capacity(self.sums.len());
424
425		for (key, sum) in mem::take(&mut self.sums) {
426			let count = self.counts.swap_remove(&key).unwrap_or(0);
427			keys.push(key);
428			if count > 0 {
429				data.push_value(Value::float8(sum / count as f64));
430			} else {
431				data.push_value(Value::none());
432			}
433		}
434
435		Ok((keys, data))
436	}
437}