reifydb_routine/function/math/aggregate/
sum.rs1use std::mem;
5
6use indexmap::IndexMap;
7use reifydb_core::value::column::data::ColumnData;
8use reifydb_type::value::{
9 Value,
10 decimal::Decimal,
11 int::Int,
12 r#type::{Type, input_types::InputTypes},
13 uint::Uint,
14};
15
16use crate::function::{
17 AggregateFunction, AggregateFunctionContext,
18 error::{AggregateFunctionError, AggregateFunctionResult},
19};
20
21pub struct Sum {
22 pub sums: IndexMap<Vec<Value>, Value>,
23 input_type: Option<Type>,
24}
25
26impl Sum {
27 pub fn new() -> Self {
28 Self {
29 sums: IndexMap::new(),
30 input_type: None,
31 }
32 }
33}
34
/// Expands to the per-group summation loop for a primitive numeric column.
///
/// Inputs:
/// - `$self`: the `Sum` aggregate.
/// - `$column`: the source column.
/// - `$groups`: the group map of key to row indices.
/// - `$container`: the unwrapped typed container.
/// - `$t`: the accumulator type.
/// - `$ctor`: the `Value` constructor that wraps the final total.
///
/// For each group, rows that are defined and present in the container are
/// added with `+=`, starting from `Default::default()`. If at least one row
/// contributed, the group's entry is replaced with the wrapped total.
/// Otherwise `Value::none()` is stored, but only when the group has no entry yet.
///
/// NOTE(review): `+=` on integer accumulators panics on overflow in debug
/// builds and wraps silently in release builds. SUM over narrow types such as
/// `i8` overflows easily, so a checked or widening strategy should be
/// considered.
macro_rules! sum_arm {
	($self:expr, $column:expr, $groups:expr, $container:expr, $t:ty, $ctor:expr) => {
		for (group_key, row_indices) in $groups.iter() {
			let mut total: $t = Default::default();
			let mut seen_any = false;
			for &row in row_indices {
				if !$column.data().is_defined(row) {
					continue;
				}
				if let Some(&value) = $container.get(row) {
					total += value;
					seen_any = true;
				}
			}
			let key = group_key.clone();
			if seen_any {
				$self.sums.insert(key, $ctor(total));
			} else {
				$self.sums.entry(key).or_insert(Value::none());
			}
		}
	};
}
56
57impl AggregateFunction for Sum {
58 fn aggregate(&mut self, ctx: AggregateFunctionContext) -> AggregateFunctionResult<()> {
59 let column = ctx.column;
60 let groups = &ctx.groups;
61 let (data, _bitvec) = column.data().unwrap_option();
62
63 if self.input_type.is_none() {
64 self.input_type = Some(data.get_type());
65 }
66
67 match data {
68 ColumnData::Int1(container) => {
69 sum_arm!(self, column, groups, container, i8, Value::Int1);
70 Ok(())
71 }
72 ColumnData::Int2(container) => {
73 sum_arm!(self, column, groups, container, i16, Value::Int2);
74 Ok(())
75 }
76 ColumnData::Int4(container) => {
77 sum_arm!(self, column, groups, container, i32, Value::Int4);
78 Ok(())
79 }
80 ColumnData::Int8(container) => {
81 sum_arm!(self, column, groups, container, i64, Value::Int8);
82 Ok(())
83 }
84 ColumnData::Int16(container) => {
85 sum_arm!(self, column, groups, container, i128, Value::Int16);
86 Ok(())
87 }
88 ColumnData::Uint1(container) => {
89 sum_arm!(self, column, groups, container, u8, Value::Uint1);
90 Ok(())
91 }
92 ColumnData::Uint2(container) => {
93 sum_arm!(self, column, groups, container, u16, Value::Uint2);
94 Ok(())
95 }
96 ColumnData::Uint4(container) => {
97 sum_arm!(self, column, groups, container, u32, Value::Uint4);
98 Ok(())
99 }
100 ColumnData::Uint8(container) => {
101 sum_arm!(self, column, groups, container, u64, Value::Uint8);
102 Ok(())
103 }
104 ColumnData::Uint16(container) => {
105 sum_arm!(self, column, groups, container, u128, Value::Uint16);
106 Ok(())
107 }
108 ColumnData::Float4(container) => {
109 sum_arm!(self, column, groups, container, f32, Value::float4);
110 Ok(())
111 }
112 ColumnData::Float8(container) => {
113 sum_arm!(self, column, groups, container, f64, Value::float8);
114 Ok(())
115 }
116 ColumnData::Int {
117 container,
118 ..
119 } => {
120 for (group, indices) in groups.iter() {
121 let mut sum = Int::zero();
122 let mut has_value = false;
123 for &i in indices {
124 if column.data().is_defined(i) {
125 if let Some(val) = container.get(i) {
126 sum = Int(sum.0 + &val.0);
127 has_value = true;
128 }
129 }
130 }
131 if has_value {
132 self.sums.insert(group.clone(), Value::Int(sum));
133 } else {
134 self.sums.entry(group.clone()).or_insert(Value::none());
135 }
136 }
137 Ok(())
138 }
139 ColumnData::Uint {
140 container,
141 ..
142 } => {
143 for (group, indices) in groups.iter() {
144 let mut sum = Uint::zero();
145 let mut has_value = false;
146 for &i in indices {
147 if column.data().is_defined(i) {
148 if let Some(val) = container.get(i) {
149 sum = Uint(sum.0 + &val.0);
150 has_value = true;
151 }
152 }
153 }
154 if has_value {
155 self.sums.insert(group.clone(), Value::Uint(sum));
156 } else {
157 self.sums.entry(group.clone()).or_insert(Value::none());
158 }
159 }
160 Ok(())
161 }
162 ColumnData::Decimal {
163 container,
164 ..
165 } => {
166 for (group, indices) in groups.iter() {
167 let mut sum = Decimal::zero();
168 let mut has_value = false;
169 for &i in indices {
170 if column.data().is_defined(i) {
171 if let Some(val) = container.get(i) {
172 sum = Decimal(sum.0 + &val.0);
173 has_value = true;
174 }
175 }
176 }
177 if has_value {
178 self.sums.insert(group.clone(), Value::Decimal(sum));
179 } else {
180 self.sums.entry(group.clone()).or_insert(Value::none());
181 }
182 }
183 Ok(())
184 }
185 other => Err(AggregateFunctionError::InvalidArgumentType {
186 function: ctx.fragment.clone(),
187 argument_index: 0,
188 expected: self.accepted_types().expected_at(0).to_vec(),
189 actual: other.get_type(),
190 }),
191 }
192 }
193
194 fn finalize(&mut self) -> AggregateFunctionResult<(Vec<Vec<Value>>, ColumnData)> {
195 let ty = self.input_type.take().unwrap_or(Type::Int8);
196 let mut keys = Vec::with_capacity(self.sums.len());
197 let mut data = ColumnData::with_capacity(ty, self.sums.len());
198
199 for (key, sum) in mem::take(&mut self.sums) {
200 keys.push(key);
201 data.push_value(sum);
202 }
203
204 Ok((keys, data))
205 }
206
207 fn return_type(&self, input_type: &Type) -> Type {
208 input_type.clone()
209 }
210
211 fn accepted_types(&self) -> InputTypes {
212 InputTypes::numeric()
213 }
214}