reifydb_engine/function/math/aggregate/
avg.rs1use std::collections::HashMap;
5
6use reifydb_core::value::column::ColumnData;
7use reifydb_type::Value;
8
9use crate::function::{AggregateFunction, AggregateFunctionContext};
10
11pub struct Avg {
12 pub sums: HashMap<Vec<Value>, f64>,
13 pub counts: HashMap<Vec<Value>, u64>,
14}
15
16impl Avg {
17 pub fn new() -> Self {
18 Self {
19 sums: HashMap::new(),
20 counts: HashMap::new(),
21 }
22 }
23}
24
25impl AggregateFunction for Avg {
26 fn aggregate(&mut self, ctx: AggregateFunctionContext) -> crate::Result<()> {
27 let column = ctx.column;
28 let groups = &ctx.groups;
29
30 match &column.data() {
31 ColumnData::Float8(container) => {
32 for (group, indices) in groups.iter() {
33 let mut sum = 0.0;
34 let mut count = 0;
35
36 for &i in indices {
37 if let Some(value) = container.get(i) {
38 sum += *value;
39 count += 1;
40 }
41 }
42
43 if count > 0 {
44 self.sums.entry(group.clone()).and_modify(|v| *v += sum).or_insert(sum);
45
46 self.counts
47 .entry(group.clone())
48 .and_modify(|c| *c += count)
49 .or_insert(count);
50 }
51 }
52 Ok(())
53 }
54 ColumnData::Float4(container) => {
55 for (group, indices) in groups.iter() {
56 let mut sum = 0.0;
57 let mut count = 0;
58
59 for &i in indices {
60 if let Some(value) = container.get(i) {
61 sum += *value as f64;
62 count += 1;
63 }
64 }
65
66 if count > 0 {
67 self.sums.entry(group.clone()).and_modify(|v| *v += sum).or_insert(sum);
68
69 self.counts
70 .entry(group.clone())
71 .and_modify(|c| *c += count)
72 .or_insert(count);
73 }
74 }
75 Ok(())
76 }
77 ColumnData::Int2(container) => {
78 for (group, indices) in groups.iter() {
79 let mut sum = 0.0;
80 let mut count = 0;
81
82 for &i in indices {
83 if let Some(value) = container.get(i) {
84 sum += *value as f64;
85 count += 1;
86 }
87 }
88
89 if count > 0 {
90 self.sums.entry(group.clone()).and_modify(|v| *v += sum).or_insert(sum);
91
92 self.counts
93 .entry(group.clone())
94 .and_modify(|c| *c += count)
95 .or_insert(count);
96 }
97 }
98 Ok(())
99 }
100 ColumnData::Int4(container) => {
101 for (group, indices) in groups.iter() {
102 let mut sum = 0.0;
103 let mut count = 0;
104
105 for &i in indices {
106 if let Some(value) = container.get(i) {
107 sum += *value as f64;
108 count += 1;
109 }
110 }
111
112 if count > 0 {
113 self.sums.entry(group.clone()).and_modify(|v| *v += sum).or_insert(sum);
114
115 self.counts
116 .entry(group.clone())
117 .and_modify(|c| *c += count)
118 .or_insert(count);
119 }
120 }
121 Ok(())
122 }
123 ColumnData::Int8(container) => {
124 for (group, indices) in groups.iter() {
125 let mut sum = 0.0;
126 let mut count = 0;
127
128 for &i in indices {
129 if let Some(value) = container.get(i) {
130 sum += *value as f64;
131 count += 1;
132 }
133 }
134
135 if count > 0 {
136 self.sums.entry(group.clone()).and_modify(|v| *v += sum).or_insert(sum);
137
138 self.counts
139 .entry(group.clone())
140 .and_modify(|c| *c += count)
141 .or_insert(count);
142 }
143 }
144 Ok(())
145 }
146 _ => unimplemented!(),
147 }
148 }
149
150 fn finalize(&mut self) -> crate::Result<(Vec<Vec<Value>>, ColumnData)> {
151 let mut keys = Vec::with_capacity(self.sums.len());
152 let mut data = ColumnData::float8_with_capacity(self.sums.len());
153
154 for (key, sum) in std::mem::take(&mut self.sums) {
155 let count = self.counts.remove(&key).unwrap_or(0);
156 let avg = if count > 0 {
157 sum / count as f64
158 } else {
159 f64::NAN };
161
162 keys.push(key);
163 data.push_value(Value::float8(avg));
164 }
165
166 Ok((keys, data))
167 }
168}