reifydb_function/math/scalar/
avg.rs1use num_traits::ToPrimitive;
5use reifydb_core::value::column::data::ColumnData;
6use reifydb_type::value::r#type::Type;
7
8use crate::{ScalarFunction, ScalarFunctionContext, error::ScalarFunctionError, propagate_options};
9
/// Scalar `avg` function: computes, for each row, the arithmetic mean of its
/// numeric arguments (see the `ScalarFunction` impl for the exact semantics).
#[derive(Debug, Default, Clone, Copy)]
pub struct Avg {}

impl Avg {
	/// Creates a new `Avg` function instance (equivalent to `Avg::default()`).
	pub fn new() -> Self {
		Self {}
	}
}
17
18impl ScalarFunction for Avg {
19 fn scalar(&self, ctx: ScalarFunctionContext) -> crate::error::ScalarFunctionResult<ColumnData> {
20 if let Some(result) = propagate_options(self, &ctx) {
21 return result;
22 }
23 let columns = ctx.columns;
24 let row_count = ctx.row_count;
25
26 if columns.is_empty() {
28 return Err(ScalarFunctionError::ArityMismatch {
29 function: ctx.fragment.clone(),
30 expected: 1,
31 actual: 0,
32 });
33 }
34
35 let mut sum = vec![0.0f64; row_count];
36 let mut count = vec![0u32; row_count];
37
38 for (col_idx, col) in columns.iter().enumerate() {
39 match &col.data() {
40 ColumnData::Int1(container) => {
41 for i in 0..row_count {
42 if let Some(value) = container.get(i) {
43 sum[i] += *value as f64;
44 count[i] += 1;
45 }
46 }
47 }
48 ColumnData::Int2(container) => {
49 for i in 0..row_count {
50 if let Some(value) = container.get(i) {
51 sum[i] += *value as f64;
52 count[i] += 1;
53 }
54 }
55 }
56 ColumnData::Int4(container) => {
57 for i in 0..row_count {
58 if let Some(value) = container.get(i) {
59 sum[i] += *value as f64;
60 count[i] += 1;
61 }
62 }
63 }
64 ColumnData::Int8(container) => {
65 for i in 0..row_count {
66 if let Some(value) = container.get(i) {
67 sum[i] += *value as f64;
68 count[i] += 1;
69 }
70 }
71 }
72 ColumnData::Int16(container) => {
73 for i in 0..row_count {
74 if let Some(value) = container.get(i) {
75 sum[i] += *value as f64;
76 count[i] += 1;
77 }
78 }
79 }
80 ColumnData::Uint1(container) => {
81 for i in 0..row_count {
82 if let Some(value) = container.get(i) {
83 sum[i] += *value as f64;
84 count[i] += 1;
85 }
86 }
87 }
88 ColumnData::Uint2(container) => {
89 for i in 0..row_count {
90 if let Some(value) = container.get(i) {
91 sum[i] += *value as f64;
92 count[i] += 1;
93 }
94 }
95 }
96 ColumnData::Uint4(container) => {
97 for i in 0..row_count {
98 if let Some(value) = container.get(i) {
99 sum[i] += *value as f64;
100 count[i] += 1;
101 }
102 }
103 }
104 ColumnData::Uint8(container) => {
105 for i in 0..row_count {
106 if let Some(value) = container.get(i) {
107 sum[i] += *value as f64;
108 count[i] += 1;
109 }
110 }
111 }
112 ColumnData::Uint16(container) => {
113 for i in 0..row_count {
114 if let Some(value) = container.get(i) {
115 sum[i] += *value as f64;
116 count[i] += 1;
117 }
118 }
119 }
120 ColumnData::Float4(container) => {
121 for i in 0..row_count {
122 if let Some(value) = container.get(i) {
123 sum[i] += *value as f64;
124 count[i] += 1;
125 }
126 }
127 }
128 ColumnData::Float8(container) => {
129 for i in 0..row_count {
130 if let Some(value) = container.get(i) {
131 sum[i] += *value;
132 count[i] += 1;
133 }
134 }
135 }
136 ColumnData::Int {
137 container,
138 ..
139 } => {
140 for i in 0..row_count {
141 if let Some(value) = container.get(i) {
142 sum[i] += value.0.to_f64().unwrap_or(0.0);
143 count[i] += 1;
144 }
145 }
146 }
147 ColumnData::Uint {
148 container,
149 ..
150 } => {
151 for i in 0..row_count {
152 if let Some(value) = container.get(i) {
153 sum[i] += value.0.to_f64().unwrap_or(0.0);
154 count[i] += 1;
155 }
156 }
157 }
158 ColumnData::Decimal {
159 container,
160 ..
161 } => {
162 for i in 0..row_count {
163 if let Some(value) = container.get(i) {
164 sum[i] += value.0.to_f64().unwrap_or(0.0);
165 count[i] += 1;
166 }
167 }
168 }
169 other => {
170 return Err(ScalarFunctionError::InvalidArgumentType {
171 function: ctx.fragment.clone(),
172 argument_index: col_idx,
173 expected: vec![
174 Type::Int1,
175 Type::Int2,
176 Type::Int4,
177 Type::Int8,
178 Type::Int16,
179 Type::Uint1,
180 Type::Uint2,
181 Type::Uint4,
182 Type::Uint8,
183 Type::Uint16,
184 Type::Float4,
185 Type::Float8,
186 Type::Int,
187 Type::Uint,
188 Type::Decimal,
189 ],
190 actual: other.get_type(),
191 });
192 }
193 }
194 }
195
196 let mut data = Vec::with_capacity(row_count);
197 let mut valids = Vec::with_capacity(row_count);
198
199 for i in 0..row_count {
200 if count[i] > 0 {
201 data.push(sum[i] / count[i] as f64);
202 valids.push(true);
203 } else {
204 data.push(0.0);
205 valids.push(false);
206 }
207 }
208
209 Ok(ColumnData::float8_with_bitvec(data, valids))
210 }
211
212 fn return_type(&self, _input_types: &[Type]) -> Type {
213 Type::Float8
214 }
215}