reifydb_routine/function/math/
sum.rs1use std::mem;
5
6use indexmap::IndexMap;
7use reifydb_core::value::column::{
8 ColumnWithName,
9 buffer::ColumnBuffer,
10 columns::Columns,
11 view::group_by::{GroupByView, GroupKey},
12};
13use reifydb_type::{
14 fragment::Fragment,
15 value::{
16 Value,
17 decimal::Decimal,
18 int::Int,
19 r#type::{Type, input_types::InputTypes},
20 uint::Uint,
21 },
22};
23
24use crate::routine::{
25 Accumulator, Function, FunctionKind, Routine, RoutineInfo, context::FunctionContext, error::RoutineError,
26};
27
28pub struct Sum {
29 info: RoutineInfo,
30}
31
32impl Default for Sum {
33 fn default() -> Self {
34 Self::new()
35 }
36}
37
38impl Sum {
39 pub fn new() -> Self {
40 Self {
41 info: RoutineInfo::new("math::sum"),
42 }
43 }
44}
45
46impl<'a> Routine<FunctionContext<'a>> for Sum {
47 fn info(&self) -> &RoutineInfo {
48 &self.info
49 }
50
51 fn return_type(&self, input_types: &[Type]) -> Type {
52 input_types.first().cloned().unwrap_or(Type::Int8)
53 }
54
55 fn accepted_types(&self) -> InputTypes {
56 InputTypes::numeric()
57 }
58
59 fn execute(&self, ctx: &mut FunctionContext<'a>, args: &Columns) -> Result<Columns, RoutineError> {
60 if args.is_empty() {
62 return Err(RoutineError::FunctionArityMismatch {
63 function: ctx.fragment.clone(),
64 expected: 1,
65 actual: 0,
66 });
67 }
68
69 let row_count = args.row_count();
70 let mut results = Vec::with_capacity(row_count);
71
72 for i in 0..row_count {
73 let val1 = args[0].get_value(i);
76 results.push(Box::new(val1));
77 }
78
79 Ok(Columns::new(vec![ColumnWithName::new(ctx.fragment.clone(), ColumnBuffer::any(results))]))
80 }
81}
82
83impl Function for Sum {
84 fn kinds(&self) -> &[FunctionKind] {
85 &[FunctionKind::Scalar, FunctionKind::Aggregate]
86 }
87
88 fn accumulator(&self, _ctx: &mut FunctionContext<'_>) -> Option<Box<dyn Accumulator>> {
89 Some(Box::new(SumAccumulator::new()))
90 }
91}
92
93struct SumAccumulator {
94 pub sums: IndexMap<Vec<Value>, Value>,
95 input_type: Option<Type>,
96}
97
98impl SumAccumulator {
99 pub fn new() -> Self {
100 Self {
101 sums: IndexMap::new(),
102 input_type: None,
103 }
104 }
105}
106
107macro_rules! sum_arm {
108 ($self:expr, $column:expr, $groups:expr, $container:expr, $t:ty, $ctor:expr) => {
109 for (group, indices) in $groups.iter() {
110 let mut sum: $t = Default::default();
111 let mut has_value = false;
112 for &i in indices {
113 if $column.is_defined(i) {
114 if let Some(&val) = $container.get(i) {
115 sum += val;
116 has_value = true;
117 }
118 }
119 }
120 if has_value {
121 $self.sums.insert(group.clone(), $ctor(sum));
122 } else {
123 $self.sums.entry(group.clone()).or_insert(Value::none());
124 }
125 }
126 };
127}
128
129impl Accumulator for SumAccumulator {
130 fn update(&mut self, args: &Columns, groups: &GroupByView) -> Result<(), RoutineError> {
131 let column = &args[0];
132 let (data, _bitvec) = column.unwrap_option();
133
134 if self.input_type.is_none() {
135 self.input_type = Some(data.get_type());
136 }
137
138 match data {
139 ColumnBuffer::Int1(container) => {
140 sum_arm!(self, column, groups, container, i8, Value::Int1);
141 Ok(())
142 }
143 ColumnBuffer::Int2(container) => {
144 sum_arm!(self, column, groups, container, i16, Value::Int2);
145 Ok(())
146 }
147 ColumnBuffer::Int4(container) => {
148 sum_arm!(self, column, groups, container, i32, Value::Int4);
149 Ok(())
150 }
151 ColumnBuffer::Int8(container) => {
152 sum_arm!(self, column, groups, container, i64, Value::Int8);
153 Ok(())
154 }
155 ColumnBuffer::Int16(container) => {
156 sum_arm!(self, column, groups, container, i128, Value::Int16);
157 Ok(())
158 }
159 ColumnBuffer::Uint1(container) => {
160 sum_arm!(self, column, groups, container, u8, Value::Uint1);
161 Ok(())
162 }
163 ColumnBuffer::Uint2(container) => {
164 sum_arm!(self, column, groups, container, u16, Value::Uint2);
165 Ok(())
166 }
167 ColumnBuffer::Uint4(container) => {
168 sum_arm!(self, column, groups, container, u32, Value::Uint4);
169 Ok(())
170 }
171 ColumnBuffer::Uint8(container) => {
172 sum_arm!(self, column, groups, container, u64, Value::Uint8);
173 Ok(())
174 }
175 ColumnBuffer::Uint16(container) => {
176 sum_arm!(self, column, groups, container, u128, Value::Uint16);
177 Ok(())
178 }
179 ColumnBuffer::Float4(container) => {
180 sum_arm!(self, column, groups, container, f32, Value::float4);
181 Ok(())
182 }
183 ColumnBuffer::Float8(container) => {
184 sum_arm!(self, column, groups, container, f64, Value::float8);
185 Ok(())
186 }
187 ColumnBuffer::Int {
188 container,
189 ..
190 } => {
191 for (group, indices) in groups.iter() {
192 let mut sum = Int::zero();
193 let mut has_value = false;
194 for &i in indices {
195 if column.is_defined(i)
196 && let Some(val) = container.get(i)
197 {
198 sum = Int(sum.0 + &val.0);
199 has_value = true;
200 }
201 }
202 if has_value {
203 self.sums.insert(group.clone(), Value::Int(sum));
204 } else {
205 self.sums.entry(group.clone()).or_insert(Value::none());
206 }
207 }
208 Ok(())
209 }
210 ColumnBuffer::Uint {
211 container,
212 ..
213 } => {
214 for (group, indices) in groups.iter() {
215 let mut sum = Uint::zero();
216 let mut has_value = false;
217 for &i in indices {
218 if column.is_defined(i)
219 && let Some(val) = container.get(i)
220 {
221 sum = Uint(sum.0 + &val.0);
222 has_value = true;
223 }
224 }
225 if has_value {
226 self.sums.insert(group.clone(), Value::Uint(sum));
227 } else {
228 self.sums.entry(group.clone()).or_insert(Value::none());
229 }
230 }
231 Ok(())
232 }
233 ColumnBuffer::Decimal {
234 container,
235 ..
236 } => {
237 for (group, indices) in groups.iter() {
238 let mut sum = Decimal::zero();
239 let mut has_value = false;
240 for &i in indices {
241 if column.is_defined(i)
242 && let Some(val) = container.get(i)
243 {
244 sum = Decimal(sum.0 + &val.0);
245 has_value = true;
246 }
247 }
248 if has_value {
249 self.sums.insert(group.clone(), Value::Decimal(sum));
250 } else {
251 self.sums.entry(group.clone()).or_insert(Value::none());
252 }
253 }
254 Ok(())
255 }
256 other => Err(RoutineError::FunctionInvalidArgumentType {
257 function: Fragment::internal("math::sum"),
258 argument_index: 0,
259 expected: InputTypes::numeric().expected_at(0).to_vec(),
260 actual: other.get_type(),
261 }),
262 }
263 }
264
265 fn finalize(&mut self) -> Result<(Vec<GroupKey>, ColumnBuffer), RoutineError> {
266 let ty = self.input_type.take().unwrap_or(Type::Int8);
267 let mut keys = Vec::with_capacity(self.sums.len());
268 let mut data = ColumnBuffer::with_capacity(ty, self.sums.len());
269
270 for (key, sum) in mem::take(&mut self.sums) {
271 keys.push(key);
272 data.push_value(sum);
273 }
274
275 Ok((keys, data))
276 }
277}