reifydb_routine/function/math/
sum.rs1use std::mem;
5
6use indexmap::IndexMap;
7use reifydb_core::value::column::{
8 Column,
9 columns::Columns,
10 data::ColumnData,
11 view::group_by::{GroupByView, GroupKey},
12};
13use reifydb_type::{
14 fragment::Fragment,
15 value::{
16 Value,
17 decimal::Decimal,
18 int::Int,
19 r#type::{Type, input_types::InputTypes},
20 uint::Uint,
21 },
22};
23
24use crate::function::{Accumulator, Function, FunctionCapability, FunctionContext, FunctionInfo, error::FunctionError};
25
26pub struct Sum {
27 info: FunctionInfo,
28}
29
30impl Default for Sum {
31 fn default() -> Self {
32 Self::new()
33 }
34}
35
36impl Sum {
37 pub fn new() -> Self {
38 Self {
39 info: FunctionInfo::new("math::sum"),
40 }
41 }
42}
43
44impl Function for Sum {
45 fn info(&self) -> &FunctionInfo {
46 &self.info
47 }
48
49 fn capabilities(&self) -> &[FunctionCapability] {
50 &[FunctionCapability::Scalar, FunctionCapability::Aggregate]
51 }
52
53 fn return_type(&self, input_types: &[Type]) -> Type {
54 input_types.first().cloned().unwrap_or(Type::Int8)
55 }
56
57 fn accepted_types(&self) -> InputTypes {
58 InputTypes::numeric()
59 }
60
61 fn execute(&self, ctx: &FunctionContext, args: &Columns) -> Result<Columns, FunctionError> {
62 if args.is_empty() {
64 return Err(FunctionError::ArityMismatch {
65 function: ctx.fragment.clone(),
66 expected: 1,
67 actual: 0,
68 });
69 }
70
71 let row_count = args.row_count();
72 let mut results = Vec::with_capacity(row_count);
73
74 for i in 0..row_count {
75 let val1 = args[0].data().get_value(i);
78 results.push(Box::new(val1));
79 }
80
81 Ok(Columns::new(vec![Column::new(ctx.fragment.clone(), ColumnData::any(results))]))
82 }
83
84 fn accumulator(&self, _ctx: &FunctionContext) -> Option<Box<dyn Accumulator>> {
85 Some(Box::new(SumAccumulator::new()))
86 }
87}
88
89struct SumAccumulator {
90 pub sums: IndexMap<Vec<Value>, Value>,
91 input_type: Option<Type>,
92}
93
94impl SumAccumulator {
95 pub fn new() -> Self {
96 Self {
97 sums: IndexMap::new(),
98 input_type: None,
99 }
100 }
101}
102
103macro_rules! sum_arm {
104 ($self:expr, $column:expr, $groups:expr, $container:expr, $t:ty, $ctor:expr) => {
105 for (group, indices) in $groups.iter() {
106 let mut sum: $t = Default::default();
107 let mut has_value = false;
108 for &i in indices {
109 if $column.data().is_defined(i) {
110 if let Some(&val) = $container.get(i) {
111 sum += val;
112 has_value = true;
113 }
114 }
115 }
116 if has_value {
117 $self.sums.insert(group.clone(), $ctor(sum));
118 } else {
119 $self.sums.entry(group.clone()).or_insert(Value::none());
120 }
121 }
122 };
123}
124
125impl Accumulator for SumAccumulator {
126 fn update(&mut self, args: &Columns, groups: &GroupByView) -> Result<(), FunctionError> {
127 let column = &args[0];
128 let (data, _bitvec) = column.data().unwrap_option();
129
130 if self.input_type.is_none() {
131 self.input_type = Some(data.get_type());
132 }
133
134 match data {
135 ColumnData::Int1(container) => {
136 sum_arm!(self, column, groups, container, i8, Value::Int1);
137 Ok(())
138 }
139 ColumnData::Int2(container) => {
140 sum_arm!(self, column, groups, container, i16, Value::Int2);
141 Ok(())
142 }
143 ColumnData::Int4(container) => {
144 sum_arm!(self, column, groups, container, i32, Value::Int4);
145 Ok(())
146 }
147 ColumnData::Int8(container) => {
148 sum_arm!(self, column, groups, container, i64, Value::Int8);
149 Ok(())
150 }
151 ColumnData::Int16(container) => {
152 sum_arm!(self, column, groups, container, i128, Value::Int16);
153 Ok(())
154 }
155 ColumnData::Uint1(container) => {
156 sum_arm!(self, column, groups, container, u8, Value::Uint1);
157 Ok(())
158 }
159 ColumnData::Uint2(container) => {
160 sum_arm!(self, column, groups, container, u16, Value::Uint2);
161 Ok(())
162 }
163 ColumnData::Uint4(container) => {
164 sum_arm!(self, column, groups, container, u32, Value::Uint4);
165 Ok(())
166 }
167 ColumnData::Uint8(container) => {
168 sum_arm!(self, column, groups, container, u64, Value::Uint8);
169 Ok(())
170 }
171 ColumnData::Uint16(container) => {
172 sum_arm!(self, column, groups, container, u128, Value::Uint16);
173 Ok(())
174 }
175 ColumnData::Float4(container) => {
176 sum_arm!(self, column, groups, container, f32, Value::float4);
177 Ok(())
178 }
179 ColumnData::Float8(container) => {
180 sum_arm!(self, column, groups, container, f64, Value::float8);
181 Ok(())
182 }
183 ColumnData::Int {
184 container,
185 ..
186 } => {
187 for (group, indices) in groups.iter() {
188 let mut sum = Int::zero();
189 let mut has_value = false;
190 for &i in indices {
191 if column.data().is_defined(i)
192 && let Some(val) = container.get(i)
193 {
194 sum = Int(sum.0 + &val.0);
195 has_value = true;
196 }
197 }
198 if has_value {
199 self.sums.insert(group.clone(), Value::Int(sum));
200 } else {
201 self.sums.entry(group.clone()).or_insert(Value::none());
202 }
203 }
204 Ok(())
205 }
206 ColumnData::Uint {
207 container,
208 ..
209 } => {
210 for (group, indices) in groups.iter() {
211 let mut sum = Uint::zero();
212 let mut has_value = false;
213 for &i in indices {
214 if column.data().is_defined(i)
215 && let Some(val) = container.get(i)
216 {
217 sum = Uint(sum.0 + &val.0);
218 has_value = true;
219 }
220 }
221 if has_value {
222 self.sums.insert(group.clone(), Value::Uint(sum));
223 } else {
224 self.sums.entry(group.clone()).or_insert(Value::none());
225 }
226 }
227 Ok(())
228 }
229 ColumnData::Decimal {
230 container,
231 ..
232 } => {
233 for (group, indices) in groups.iter() {
234 let mut sum = Decimal::zero();
235 let mut has_value = false;
236 for &i in indices {
237 if column.data().is_defined(i)
238 && let Some(val) = container.get(i)
239 {
240 sum = Decimal(sum.0 + &val.0);
241 has_value = true;
242 }
243 }
244 if has_value {
245 self.sums.insert(group.clone(), Value::Decimal(sum));
246 } else {
247 self.sums.entry(group.clone()).or_insert(Value::none());
248 }
249 }
250 Ok(())
251 }
252 other => Err(FunctionError::InvalidArgumentType {
253 function: Fragment::internal("math::sum"),
254 argument_index: 0,
255 expected: InputTypes::numeric().expected_at(0).to_vec(),
256 actual: other.get_type(),
257 }),
258 }
259 }
260
261 fn finalize(&mut self) -> Result<(Vec<GroupKey>, ColumnData), FunctionError> {
262 let ty = self.input_type.take().unwrap_or(Type::Int8);
263 let mut keys = Vec::with_capacity(self.sums.len());
264 let mut data = ColumnData::with_capacity(ty, self.sums.len());
265
266 for (key, sum) in mem::take(&mut self.sums) {
267 keys.push(key);
268 data.push_value(sum);
269 }
270
271 Ok((keys, data))
272 }
273}