reifydb_routine/function/math/
sum.rs1use std::mem;
5
6use indexmap::IndexMap;
7use reifydb_core::value::column::{
8 ColumnWithName,
9 buffer::ColumnBuffer,
10 columns::Columns,
11 view::group_by::{GroupByView, GroupKey},
12};
13use reifydb_type::{
14 fragment::Fragment,
15 value::{
16 Value,
17 decimal::Decimal,
18 int::Int,
19 r#type::{Type, input_types::InputTypes},
20 uint::Uint,
21 },
22};
23
24use crate::routine::{
25 Accumulator, Function, FunctionKind, Routine, RoutineInfo, context::FunctionContext, error::RoutineError,
26};
27
28pub struct Sum {
29 info: RoutineInfo,
30}
31
32impl Default for Sum {
33 fn default() -> Self {
34 Self::new()
35 }
36}
37
38impl Sum {
39 pub fn new() -> Self {
40 Self {
41 info: RoutineInfo::new("math::sum"),
42 }
43 }
44}
45
46impl<'a> Routine<FunctionContext<'a>> for Sum {
47 fn info(&self) -> &RoutineInfo {
48 &self.info
49 }
50
51 fn return_type(&self, input_types: &[Type]) -> Type {
52 input_types.first().cloned().unwrap_or(Type::Int8)
53 }
54
55 fn accepted_types(&self) -> InputTypes {
56 InputTypes::numeric()
57 }
58
59 fn execute(&self, ctx: &mut FunctionContext<'a>, args: &Columns) -> Result<Columns, RoutineError> {
60 if args.is_empty() {
61 return Err(RoutineError::FunctionArityMismatch {
62 function: ctx.fragment.clone(),
63 expected: 1,
64 actual: 0,
65 });
66 }
67
68 let row_count = args.row_count();
69 let mut results = Vec::with_capacity(row_count);
70
71 for i in 0..row_count {
72 let val1 = args[0].get_value(i);
73 results.push(Box::new(val1));
74 }
75
76 Ok(Columns::new(vec![ColumnWithName::new(ctx.fragment.clone(), ColumnBuffer::any(results))]))
77 }
78}
79
80impl Function for Sum {
81 fn kinds(&self) -> &[FunctionKind] {
82 &[FunctionKind::Scalar, FunctionKind::Aggregate]
83 }
84
85 fn accumulator(&self, _ctx: &mut FunctionContext<'_>) -> Option<Box<dyn Accumulator>> {
86 Some(Box::new(SumAccumulator::new()))
87 }
88}
89
90struct SumAccumulator {
91 pub sums: IndexMap<Vec<Value>, Value>,
92 input_type: Option<Type>,
93}
94
95impl SumAccumulator {
96 pub fn new() -> Self {
97 Self {
98 sums: IndexMap::new(),
99 input_type: None,
100 }
101 }
102}
103
/// Folds one batch of a fixed-width integer buffer into `$self.sums`.
///
/// - `$self`: the accumulator; its `sums` map holds one running total per group key.
/// - `$column`: the input column; rows where `is_defined(i)` is false are skipped.
/// - `$groups`: yields `(group_key, row_indices)` pairs for this batch.
/// - `$container`: the native buffer; `get(i)` returns the value of row `i`.
/// - `$t` / `$variant`: the native element type and the matching `Value` variant.
///
/// Existing totals are updated in place, so a group keeps the position in
/// `sums` where it was first inserted and its key is only cloned when new.
/// A group with no defined rows in this batch gets `Value::none()` unless it
/// already has an entry.
///
/// NOTE(review): the total keeps the input width (`delta: $t`), and `+`/`+=`
/// panic on overflow in debug builds and wrap in release builds. Confirm
/// whether checked or widened accumulation is intended.
macro_rules! sum_arm {
    ($self:expr, $column:expr, $groups:expr, $container:expr, $t:ty, $variant:ident) => {
        for (group, indices) in $groups.iter() {
            // Sum this batch's defined rows for the group.
            let mut delta: $t = Default::default();
            let mut has_value = false;
            for &i in indices {
                if $column.is_defined(i) {
                    if let Some(&val) = $container.get(i) {
                        delta += val;
                        has_value = true;
                    }
                }
            }

            if !has_value {
                // Still emit a row for the group, but never clobber an earlier total.
                if !$self.sums.contains_key(group) {
                    $self.sums.insert(group.clone(), Value::none());
                }
            } else if let Some(slot) = $self.sums.get_mut(group) {
                // A placeholder (`Value::none()`) or foreign variant is replaced by this batch's total.
                let merged = match slot {
                    Value::$variant(prev) => *prev + delta,
                    _ => delta,
                };
                *slot = Value::$variant(merged);
            } else {
                $self.sums.insert(group.clone(), Value::$variant(delta));
            }
        }
    };
}
129
/// Folds one batch of a floating-point buffer into `$self.sums`.
///
/// Same contract as the integer variant, except for how the stored total is
/// read and written. The previous total is read through `.value()` on the
/// variant's payload. The merged total is stored via `$ctor` (for example
/// `Value::float8`) instead of wrapping it in the variant directly.
///
/// Existing totals are updated in place, so a group keeps the position in
/// `sums` where it was first inserted and its key is only cloned when new.
/// A group with no defined rows in this batch gets `Value::none()` unless it
/// already has an entry.
macro_rules! sum_arm_float {
    ($self:expr, $column:expr, $groups:expr, $container:expr, $t:ty, $variant:ident, $ctor:expr) => {
        for (group, indices) in $groups.iter() {
            // Sum this batch's defined rows for the group.
            let mut delta: $t = Default::default();
            let mut has_value = false;
            for &i in indices {
                if $column.is_defined(i) {
                    if let Some(&val) = $container.get(i) {
                        delta += val;
                        has_value = true;
                    }
                }
            }

            if !has_value {
                // Still emit a row for the group, but never clobber an earlier total.
                if !$self.sums.contains_key(group) {
                    $self.sums.insert(group.clone(), Value::none());
                }
            } else if let Some(slot) = $self.sums.get_mut(group) {
                // A placeholder (`Value::none()`) or foreign variant is replaced by this batch's total.
                let merged = match slot {
                    Value::$variant(prev) => prev.value() + delta,
                    _ => delta,
                };
                *slot = $ctor(merged);
            } else {
                $self.sums.insert(group.clone(), $ctor(delta));
            }
        }
    };
}
155
156impl Accumulator for SumAccumulator {
157 fn update(&mut self, args: &Columns, groups: &GroupByView) -> Result<(), RoutineError> {
158 let column = &args[0];
159 let (data, _bitvec) = column.unwrap_option();
160
161 if self.input_type.is_none() {
162 self.input_type = Some(data.get_type());
163 }
164
165 match data {
166 ColumnBuffer::Int1(container) => {
167 sum_arm!(self, column, groups, container, i8, Int1);
168 Ok(())
169 }
170 ColumnBuffer::Int2(container) => {
171 sum_arm!(self, column, groups, container, i16, Int2);
172 Ok(())
173 }
174 ColumnBuffer::Int4(container) => {
175 sum_arm!(self, column, groups, container, i32, Int4);
176 Ok(())
177 }
178 ColumnBuffer::Int8(container) => {
179 sum_arm!(self, column, groups, container, i64, Int8);
180 Ok(())
181 }
182 ColumnBuffer::Int16(container) => {
183 sum_arm!(self, column, groups, container, i128, Int16);
184 Ok(())
185 }
186 ColumnBuffer::Uint1(container) => {
187 sum_arm!(self, column, groups, container, u8, Uint1);
188 Ok(())
189 }
190 ColumnBuffer::Uint2(container) => {
191 sum_arm!(self, column, groups, container, u16, Uint2);
192 Ok(())
193 }
194 ColumnBuffer::Uint4(container) => {
195 sum_arm!(self, column, groups, container, u32, Uint4);
196 Ok(())
197 }
198 ColumnBuffer::Uint8(container) => {
199 sum_arm!(self, column, groups, container, u64, Uint8);
200 Ok(())
201 }
202 ColumnBuffer::Uint16(container) => {
203 sum_arm!(self, column, groups, container, u128, Uint16);
204 Ok(())
205 }
206 ColumnBuffer::Float4(container) => {
207 sum_arm_float!(self, column, groups, container, f32, Float4, Value::float4);
208 Ok(())
209 }
210 ColumnBuffer::Float8(container) => {
211 sum_arm_float!(self, column, groups, container, f64, Float8, Value::float8);
212 Ok(())
213 }
214 ColumnBuffer::Int {
215 container,
216 ..
217 } => {
218 for (group, indices) in groups.iter() {
219 let mut delta = Int::zero();
220 let mut has_value = false;
221 for &i in indices {
222 if column.is_defined(i)
223 && let Some(val) = container.get(i)
224 {
225 delta = Int(delta.0 + &val.0);
226 has_value = true;
227 }
228 }
229 if has_value {
230 let merged = match self.sums.swap_remove(group) {
231 Some(Value::Int(prev)) => Int(prev.0 + &delta.0),
232 _ => delta,
233 };
234 self.sums.insert(group.clone(), Value::Int(merged));
235 } else {
236 self.sums.entry(group.clone()).or_insert(Value::none());
237 }
238 }
239 Ok(())
240 }
241 ColumnBuffer::Uint {
242 container,
243 ..
244 } => {
245 for (group, indices) in groups.iter() {
246 let mut delta = Uint::zero();
247 let mut has_value = false;
248 for &i in indices {
249 if column.is_defined(i)
250 && let Some(val) = container.get(i)
251 {
252 delta = Uint(delta.0 + &val.0);
253 has_value = true;
254 }
255 }
256 if has_value {
257 let merged = match self.sums.swap_remove(group) {
258 Some(Value::Uint(prev)) => Uint(prev.0 + &delta.0),
259 _ => delta,
260 };
261 self.sums.insert(group.clone(), Value::Uint(merged));
262 } else {
263 self.sums.entry(group.clone()).or_insert(Value::none());
264 }
265 }
266 Ok(())
267 }
268 ColumnBuffer::Decimal {
269 container,
270 ..
271 } => {
272 for (group, indices) in groups.iter() {
273 let mut delta = Decimal::zero();
274 let mut has_value = false;
275 for &i in indices {
276 if column.is_defined(i)
277 && let Some(val) = container.get(i)
278 {
279 delta = Decimal(delta.0 + &val.0);
280 has_value = true;
281 }
282 }
283 if has_value {
284 let merged = match self.sums.swap_remove(group) {
285 Some(Value::Decimal(prev)) => Decimal(prev.0 + &delta.0),
286 _ => delta,
287 };
288 self.sums.insert(group.clone(), Value::Decimal(merged));
289 } else {
290 self.sums.entry(group.clone()).or_insert(Value::none());
291 }
292 }
293 Ok(())
294 }
295 other => Err(RoutineError::FunctionInvalidArgumentType {
296 function: Fragment::internal("math::sum"),
297 argument_index: 0,
298 expected: InputTypes::numeric().expected_at(0).to_vec(),
299 actual: other.get_type(),
300 }),
301 }
302 }
303
304 fn finalize(&mut self) -> Result<(Vec<GroupKey>, ColumnBuffer), RoutineError> {
305 let ty = self.input_type.take().unwrap_or(Type::Int8);
306 let mut keys = Vec::with_capacity(self.sums.len());
307 let mut data = ColumnBuffer::with_capacity(ty, self.sums.len());
308
309 for (key, sum) in mem::take(&mut self.sums) {
310 keys.push(key);
311 data.push_value(sum);
312 }
313
314 Ok((keys, data))
315 }
316}