reifydb_function/math/scalar/
avg.rs1use num_traits::ToPrimitive;
5use reifydb_core::value::column::data::ColumnData;
6use reifydb_type::value::r#type::Type;
7
8use crate::{
9 ScalarFunction, ScalarFunctionContext,
10 error::{ScalarFunctionError, ScalarFunctionResult},
11 propagate_options,
12};
13
/// Scalar `avg` function.
///
/// The type carries no state; all behaviour lives in its `ScalarFunction`
/// implementation, which produces a per-row average over its argument columns.
#[derive(Default)]
pub struct Avg {}

impl Avg {
    /// Creates a new `Avg` function instance.
    ///
    /// Equivalent to `Avg::default()`.
    pub fn new() -> Self {
        Self {}
    }
}
21
22impl ScalarFunction for Avg {
23 fn scalar(&self, ctx: ScalarFunctionContext) -> ScalarFunctionResult<ColumnData> {
24 if let Some(result) = propagate_options(self, &ctx) {
25 return result;
26 }
27 let columns = ctx.columns;
28 let row_count = ctx.row_count;
29
30 if columns.is_empty() {
32 return Err(ScalarFunctionError::ArityMismatch {
33 function: ctx.fragment.clone(),
34 expected: 1,
35 actual: 0,
36 });
37 }
38
39 let mut sum = vec![0.0f64; row_count];
40 let mut count = vec![0u32; row_count];
41
42 for (col_idx, col) in columns.iter().enumerate() {
43 match &col.data() {
44 ColumnData::Int1(container) => {
45 for i in 0..row_count {
46 if let Some(value) = container.get(i) {
47 sum[i] += *value as f64;
48 count[i] += 1;
49 }
50 }
51 }
52 ColumnData::Int2(container) => {
53 for i in 0..row_count {
54 if let Some(value) = container.get(i) {
55 sum[i] += *value as f64;
56 count[i] += 1;
57 }
58 }
59 }
60 ColumnData::Int4(container) => {
61 for i in 0..row_count {
62 if let Some(value) = container.get(i) {
63 sum[i] += *value as f64;
64 count[i] += 1;
65 }
66 }
67 }
68 ColumnData::Int8(container) => {
69 for i in 0..row_count {
70 if let Some(value) = container.get(i) {
71 sum[i] += *value as f64;
72 count[i] += 1;
73 }
74 }
75 }
76 ColumnData::Int16(container) => {
77 for i in 0..row_count {
78 if let Some(value) = container.get(i) {
79 sum[i] += *value as f64;
80 count[i] += 1;
81 }
82 }
83 }
84 ColumnData::Uint1(container) => {
85 for i in 0..row_count {
86 if let Some(value) = container.get(i) {
87 sum[i] += *value as f64;
88 count[i] += 1;
89 }
90 }
91 }
92 ColumnData::Uint2(container) => {
93 for i in 0..row_count {
94 if let Some(value) = container.get(i) {
95 sum[i] += *value as f64;
96 count[i] += 1;
97 }
98 }
99 }
100 ColumnData::Uint4(container) => {
101 for i in 0..row_count {
102 if let Some(value) = container.get(i) {
103 sum[i] += *value as f64;
104 count[i] += 1;
105 }
106 }
107 }
108 ColumnData::Uint8(container) => {
109 for i in 0..row_count {
110 if let Some(value) = container.get(i) {
111 sum[i] += *value as f64;
112 count[i] += 1;
113 }
114 }
115 }
116 ColumnData::Uint16(container) => {
117 for i in 0..row_count {
118 if let Some(value) = container.get(i) {
119 sum[i] += *value as f64;
120 count[i] += 1;
121 }
122 }
123 }
124 ColumnData::Float4(container) => {
125 for i in 0..row_count {
126 if let Some(value) = container.get(i) {
127 sum[i] += *value as f64;
128 count[i] += 1;
129 }
130 }
131 }
132 ColumnData::Float8(container) => {
133 for i in 0..row_count {
134 if let Some(value) = container.get(i) {
135 sum[i] += *value;
136 count[i] += 1;
137 }
138 }
139 }
140 ColumnData::Int {
141 container,
142 ..
143 } => {
144 for i in 0..row_count {
145 if let Some(value) = container.get(i) {
146 sum[i] += value.0.to_f64().unwrap_or(0.0);
147 count[i] += 1;
148 }
149 }
150 }
151 ColumnData::Uint {
152 container,
153 ..
154 } => {
155 for i in 0..row_count {
156 if let Some(value) = container.get(i) {
157 sum[i] += value.0.to_f64().unwrap_or(0.0);
158 count[i] += 1;
159 }
160 }
161 }
162 ColumnData::Decimal {
163 container,
164 ..
165 } => {
166 for i in 0..row_count {
167 if let Some(value) = container.get(i) {
168 sum[i] += value.0.to_f64().unwrap_or(0.0);
169 count[i] += 1;
170 }
171 }
172 }
173 other => {
174 return Err(ScalarFunctionError::InvalidArgumentType {
175 function: ctx.fragment.clone(),
176 argument_index: col_idx,
177 expected: vec![
178 Type::Int1,
179 Type::Int2,
180 Type::Int4,
181 Type::Int8,
182 Type::Int16,
183 Type::Uint1,
184 Type::Uint2,
185 Type::Uint4,
186 Type::Uint8,
187 Type::Uint16,
188 Type::Float4,
189 Type::Float8,
190 Type::Int,
191 Type::Uint,
192 Type::Decimal,
193 ],
194 actual: other.get_type(),
195 });
196 }
197 }
198 }
199
200 let mut data = Vec::with_capacity(row_count);
201 let mut valids = Vec::with_capacity(row_count);
202
203 for i in 0..row_count {
204 if count[i] > 0 {
205 data.push(sum[i] / count[i] as f64);
206 valids.push(true);
207 } else {
208 data.push(0.0);
209 valids.push(false);
210 }
211 }
212
213 Ok(ColumnData::float8_with_bitvec(data, valids))
214 }
215
216 fn return_type(&self, _input_types: &[Type]) -> Type {
217 Type::Float8
218 }
219}