Skip to main content

reifydb_function/math/scalar/
avg.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// Copyright (c) 2025 ReifyDB
3
4use num_traits::ToPrimitive;
5use reifydb_core::value::column::data::ColumnData;
6use reifydb_type::value::r#type::Type;
7
8use crate::{ScalarFunction, ScalarFunctionContext, error::ScalarFunctionError, propagate_options};
9
/// Scalar `avg` function.
///
/// The type carries no state; all behaviour lives in its `ScalarFunction`
/// implementation.
pub struct Avg {}

impl Avg {
	/// Creates a new `Avg` function instance.
	pub fn new() -> Self {
		Self {}
	}
}

impl Default for Avg {
	/// Equivalent to [`Avg::new`]; provided so the type can be used wherever
	/// a `Default` bound is required.
	fn default() -> Self {
		Self::new()
	}
}
17
18impl ScalarFunction for Avg {
19	fn scalar(&self, ctx: ScalarFunctionContext) -> crate::error::ScalarFunctionResult<ColumnData> {
20		if let Some(result) = propagate_options(self, &ctx) {
21			return result;
22		}
23		let columns = ctx.columns;
24		let row_count = ctx.row_count;
25
26		// Validate at least 1 argument
27		if columns.is_empty() {
28			return Err(ScalarFunctionError::ArityMismatch {
29				function: ctx.fragment.clone(),
30				expected: 1,
31				actual: 0,
32			});
33		}
34
35		let mut sum = vec![0.0f64; row_count];
36		let mut count = vec![0u32; row_count];
37
38		for (col_idx, col) in columns.iter().enumerate() {
39			match &col.data() {
40				ColumnData::Int1(container) => {
41					for i in 0..row_count {
42						if let Some(value) = container.get(i) {
43							sum[i] += *value as f64;
44							count[i] += 1;
45						}
46					}
47				}
48				ColumnData::Int2(container) => {
49					for i in 0..row_count {
50						if let Some(value) = container.get(i) {
51							sum[i] += *value as f64;
52							count[i] += 1;
53						}
54					}
55				}
56				ColumnData::Int4(container) => {
57					for i in 0..row_count {
58						if let Some(value) = container.get(i) {
59							sum[i] += *value as f64;
60							count[i] += 1;
61						}
62					}
63				}
64				ColumnData::Int8(container) => {
65					for i in 0..row_count {
66						if let Some(value) = container.get(i) {
67							sum[i] += *value as f64;
68							count[i] += 1;
69						}
70					}
71				}
72				ColumnData::Int16(container) => {
73					for i in 0..row_count {
74						if let Some(value) = container.get(i) {
75							sum[i] += *value as f64;
76							count[i] += 1;
77						}
78					}
79				}
80				ColumnData::Uint1(container) => {
81					for i in 0..row_count {
82						if let Some(value) = container.get(i) {
83							sum[i] += *value as f64;
84							count[i] += 1;
85						}
86					}
87				}
88				ColumnData::Uint2(container) => {
89					for i in 0..row_count {
90						if let Some(value) = container.get(i) {
91							sum[i] += *value as f64;
92							count[i] += 1;
93						}
94					}
95				}
96				ColumnData::Uint4(container) => {
97					for i in 0..row_count {
98						if let Some(value) = container.get(i) {
99							sum[i] += *value as f64;
100							count[i] += 1;
101						}
102					}
103				}
104				ColumnData::Uint8(container) => {
105					for i in 0..row_count {
106						if let Some(value) = container.get(i) {
107							sum[i] += *value as f64;
108							count[i] += 1;
109						}
110					}
111				}
112				ColumnData::Uint16(container) => {
113					for i in 0..row_count {
114						if let Some(value) = container.get(i) {
115							sum[i] += *value as f64;
116							count[i] += 1;
117						}
118					}
119				}
120				ColumnData::Float4(container) => {
121					for i in 0..row_count {
122						if let Some(value) = container.get(i) {
123							sum[i] += *value as f64;
124							count[i] += 1;
125						}
126					}
127				}
128				ColumnData::Float8(container) => {
129					for i in 0..row_count {
130						if let Some(value) = container.get(i) {
131							sum[i] += *value;
132							count[i] += 1;
133						}
134					}
135				}
136				ColumnData::Int {
137					container,
138					..
139				} => {
140					for i in 0..row_count {
141						if let Some(value) = container.get(i) {
142							sum[i] += value.0.to_f64().unwrap_or(0.0);
143							count[i] += 1;
144						}
145					}
146				}
147				ColumnData::Uint {
148					container,
149					..
150				} => {
151					for i in 0..row_count {
152						if let Some(value) = container.get(i) {
153							sum[i] += value.0.to_f64().unwrap_or(0.0);
154							count[i] += 1;
155						}
156					}
157				}
158				ColumnData::Decimal {
159					container,
160					..
161				} => {
162					for i in 0..row_count {
163						if let Some(value) = container.get(i) {
164							sum[i] += value.0.to_f64().unwrap_or(0.0);
165							count[i] += 1;
166						}
167					}
168				}
169				other => {
170					return Err(ScalarFunctionError::InvalidArgumentType {
171						function: ctx.fragment.clone(),
172						argument_index: col_idx,
173						expected: vec![
174							Type::Int1,
175							Type::Int2,
176							Type::Int4,
177							Type::Int8,
178							Type::Int16,
179							Type::Uint1,
180							Type::Uint2,
181							Type::Uint4,
182							Type::Uint8,
183							Type::Uint16,
184							Type::Float4,
185							Type::Float8,
186							Type::Int,
187							Type::Uint,
188							Type::Decimal,
189						],
190						actual: other.get_type(),
191					});
192				}
193			}
194		}
195
196		let mut data = Vec::with_capacity(row_count);
197		let mut valids = Vec::with_capacity(row_count);
198
199		for i in 0..row_count {
200			if count[i] > 0 {
201				data.push(sum[i] / count[i] as f64);
202				valids.push(true);
203			} else {
204				data.push(0.0);
205				valids.push(false);
206			}
207		}
208
209		Ok(ColumnData::float8_with_bitvec(data, valids))
210	}
211
212	fn return_type(&self, _input_types: &[Type]) -> Type {
213		Type::Float8
214	}
215}