reifydb_function/math/aggregate/
avg.rs1use std::mem;
5
6use indexmap::IndexMap;
7use reifydb_core::value::column::data::ColumnData;
8use reifydb_type::value::Value;
9
10use crate::{AggregateFunction, AggregateFunctionContext, error::AggregateFunctionResult};
11
12pub struct Avg {
13 pub sums: IndexMap<Vec<Value>, f64>,
14 pub counts: IndexMap<Vec<Value>, u64>,
15}
16
17impl Avg {
18 pub fn new() -> Self {
19 Self {
20 sums: IndexMap::new(),
21 counts: IndexMap::new(),
22 }
23 }
24}
25
26impl AggregateFunction for Avg {
27 fn aggregate(&mut self, ctx: AggregateFunctionContext) -> AggregateFunctionResult<()> {
28 let column = ctx.column;
29 let groups = &ctx.groups;
30
31 match &column.data() {
32 ColumnData::Float8(container) => {
33 for (group, indices) in groups.iter() {
34 let mut sum = 0.0;
35 let mut count = 0;
36
37 for &i in indices {
38 if let Some(value) = container.get(i) {
39 sum += *value;
40 count += 1;
41 }
42 }
43
44 if count > 0 {
45 self.sums.entry(group.clone()).and_modify(|v| *v += sum).or_insert(sum);
46
47 self.counts
48 .entry(group.clone())
49 .and_modify(|c| *c += count)
50 .or_insert(count);
51 }
52 }
53 Ok(())
54 }
55 ColumnData::Float4(container) => {
56 for (group, indices) in groups.iter() {
57 let mut sum = 0.0;
58 let mut count = 0;
59
60 for &i in indices {
61 if let Some(value) = container.get(i) {
62 sum += *value as f64;
63 count += 1;
64 }
65 }
66
67 if count > 0 {
68 self.sums.entry(group.clone()).and_modify(|v| *v += sum).or_insert(sum);
69
70 self.counts
71 .entry(group.clone())
72 .and_modify(|c| *c += count)
73 .or_insert(count);
74 }
75 }
76 Ok(())
77 }
78 ColumnData::Int2(container) => {
79 for (group, indices) in groups.iter() {
80 let mut sum = 0.0;
81 let mut count = 0;
82
83 for &i in indices {
84 if let Some(value) = container.get(i) {
85 sum += *value as f64;
86 count += 1;
87 }
88 }
89
90 if count > 0 {
91 self.sums.entry(group.clone()).and_modify(|v| *v += sum).or_insert(sum);
92
93 self.counts
94 .entry(group.clone())
95 .and_modify(|c| *c += count)
96 .or_insert(count);
97 }
98 }
99 Ok(())
100 }
101 ColumnData::Int4(container) => {
102 for (group, indices) in groups.iter() {
103 let mut sum = 0.0;
104 let mut count = 0;
105
106 for &i in indices {
107 if let Some(value) = container.get(i) {
108 sum += *value as f64;
109 count += 1;
110 }
111 }
112
113 if count > 0 {
114 self.sums.entry(group.clone()).and_modify(|v| *v += sum).or_insert(sum);
115
116 self.counts
117 .entry(group.clone())
118 .and_modify(|c| *c += count)
119 .or_insert(count);
120 }
121 }
122 Ok(())
123 }
124 ColumnData::Int8(container) => {
125 for (group, indices) in groups.iter() {
126 let mut sum = 0.0;
127 let mut count = 0;
128
129 for &i in indices {
130 if let Some(value) = container.get(i) {
131 sum += *value as f64;
132 count += 1;
133 }
134 }
135
136 if count > 0 {
137 self.sums.entry(group.clone()).and_modify(|v| *v += sum).or_insert(sum);
138
139 self.counts
140 .entry(group.clone())
141 .and_modify(|c| *c += count)
142 .or_insert(count);
143 }
144 }
145 Ok(())
146 }
147 _ => unimplemented!(),
148 }
149 }
150
151 fn finalize(&mut self) -> AggregateFunctionResult<(Vec<Vec<Value>>, ColumnData)> {
152 let mut keys = Vec::with_capacity(self.sums.len());
153 let mut data = ColumnData::float8_with_capacity(self.sums.len());
154
155 for (key, sum) in mem::take(&mut self.sums) {
156 let count = self.counts.swap_remove(&key).unwrap_or(0);
157 let avg = if count > 0 {
158 sum / count as f64
159 } else {
160 keys.push(key);
161 data.push_value(Value::none());
162 return Ok((keys, data));
163 };
164
165 keys.push(key);
166 data.push_value(Value::float8(avg));
167 }
168
169 Ok((keys, data))
170 }
171}