Skip to main content

reifydb_function/math/scalar/
avg.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// Copyright (c) 2025 ReifyDB
3
4use num_traits::ToPrimitive;
5use reifydb_core::value::column::data::ColumnData;
6use reifydb_type::value::r#type::Type;
7
8use crate::{
9	ScalarFunction, ScalarFunctionContext,
10	error::{ScalarFunctionError, ScalarFunctionResult},
11	propagate_options,
12};
13
/// Scalar `avg` function: for every row, averages the numeric values supplied
/// across its argument columns (see the `ScalarFunction` implementation in this file).
///
/// The type carries no state, so `Avg::new()` and `Avg::default()` are equivalent.
#[derive(Debug, Default)]
pub struct Avg {}

impl Avg {
	/// Creates a new `Avg` function instance.
	pub fn new() -> Self {
		Self {}
	}
}
21
22impl ScalarFunction for Avg {
23	fn scalar(&self, ctx: ScalarFunctionContext) -> ScalarFunctionResult<ColumnData> {
24		if let Some(result) = propagate_options(self, &ctx) {
25			return result;
26		}
27		let columns = ctx.columns;
28		let row_count = ctx.row_count;
29
30		// Validate at least 1 argument
31		if columns.is_empty() {
32			return Err(ScalarFunctionError::ArityMismatch {
33				function: ctx.fragment.clone(),
34				expected: 1,
35				actual: 0,
36			});
37		}
38
39		let mut sum = vec![0.0f64; row_count];
40		let mut count = vec![0u32; row_count];
41
42		for (col_idx, col) in columns.iter().enumerate() {
43			match &col.data() {
44				ColumnData::Int1(container) => {
45					for i in 0..row_count {
46						if let Some(value) = container.get(i) {
47							sum[i] += *value as f64;
48							count[i] += 1;
49						}
50					}
51				}
52				ColumnData::Int2(container) => {
53					for i in 0..row_count {
54						if let Some(value) = container.get(i) {
55							sum[i] += *value as f64;
56							count[i] += 1;
57						}
58					}
59				}
60				ColumnData::Int4(container) => {
61					for i in 0..row_count {
62						if let Some(value) = container.get(i) {
63							sum[i] += *value as f64;
64							count[i] += 1;
65						}
66					}
67				}
68				ColumnData::Int8(container) => {
69					for i in 0..row_count {
70						if let Some(value) = container.get(i) {
71							sum[i] += *value as f64;
72							count[i] += 1;
73						}
74					}
75				}
76				ColumnData::Int16(container) => {
77					for i in 0..row_count {
78						if let Some(value) = container.get(i) {
79							sum[i] += *value as f64;
80							count[i] += 1;
81						}
82					}
83				}
84				ColumnData::Uint1(container) => {
85					for i in 0..row_count {
86						if let Some(value) = container.get(i) {
87							sum[i] += *value as f64;
88							count[i] += 1;
89						}
90					}
91				}
92				ColumnData::Uint2(container) => {
93					for i in 0..row_count {
94						if let Some(value) = container.get(i) {
95							sum[i] += *value as f64;
96							count[i] += 1;
97						}
98					}
99				}
100				ColumnData::Uint4(container) => {
101					for i in 0..row_count {
102						if let Some(value) = container.get(i) {
103							sum[i] += *value as f64;
104							count[i] += 1;
105						}
106					}
107				}
108				ColumnData::Uint8(container) => {
109					for i in 0..row_count {
110						if let Some(value) = container.get(i) {
111							sum[i] += *value as f64;
112							count[i] += 1;
113						}
114					}
115				}
116				ColumnData::Uint16(container) => {
117					for i in 0..row_count {
118						if let Some(value) = container.get(i) {
119							sum[i] += *value as f64;
120							count[i] += 1;
121						}
122					}
123				}
124				ColumnData::Float4(container) => {
125					for i in 0..row_count {
126						if let Some(value) = container.get(i) {
127							sum[i] += *value as f64;
128							count[i] += 1;
129						}
130					}
131				}
132				ColumnData::Float8(container) => {
133					for i in 0..row_count {
134						if let Some(value) = container.get(i) {
135							sum[i] += *value;
136							count[i] += 1;
137						}
138					}
139				}
140				ColumnData::Int {
141					container,
142					..
143				} => {
144					for i in 0..row_count {
145						if let Some(value) = container.get(i) {
146							sum[i] += value.0.to_f64().unwrap_or(0.0);
147							count[i] += 1;
148						}
149					}
150				}
151				ColumnData::Uint {
152					container,
153					..
154				} => {
155					for i in 0..row_count {
156						if let Some(value) = container.get(i) {
157							sum[i] += value.0.to_f64().unwrap_or(0.0);
158							count[i] += 1;
159						}
160					}
161				}
162				ColumnData::Decimal {
163					container,
164					..
165				} => {
166					for i in 0..row_count {
167						if let Some(value) = container.get(i) {
168							sum[i] += value.0.to_f64().unwrap_or(0.0);
169							count[i] += 1;
170						}
171					}
172				}
173				other => {
174					return Err(ScalarFunctionError::InvalidArgumentType {
175						function: ctx.fragment.clone(),
176						argument_index: col_idx,
177						expected: vec![
178							Type::Int1,
179							Type::Int2,
180							Type::Int4,
181							Type::Int8,
182							Type::Int16,
183							Type::Uint1,
184							Type::Uint2,
185							Type::Uint4,
186							Type::Uint8,
187							Type::Uint16,
188							Type::Float4,
189							Type::Float8,
190							Type::Int,
191							Type::Uint,
192							Type::Decimal,
193						],
194						actual: other.get_type(),
195					});
196				}
197			}
198		}
199
200		let mut data = Vec::with_capacity(row_count);
201		let mut valids = Vec::with_capacity(row_count);
202
203		for i in 0..row_count {
204			if count[i] > 0 {
205				data.push(sum[i] / count[i] as f64);
206				valids.push(true);
207			} else {
208				data.push(0.0);
209				valids.push(false);
210			}
211		}
212
213		Ok(ColumnData::float8_with_bitvec(data, valids))
214	}
215
216	fn return_type(&self, _input_types: &[Type]) -> Type {
217		Type::Float8
218	}
219}