Skip to main content

reifydb_routine/function/math/
sum.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use std::mem;
5
6use indexmap::IndexMap;
7use reifydb_core::value::column::{
8	ColumnWithName,
9	buffer::ColumnBuffer,
10	columns::Columns,
11	view::group_by::{GroupByView, GroupKey},
12};
13use reifydb_type::{
14	fragment::Fragment,
15	value::{
16		Value,
17		decimal::Decimal,
18		int::Int,
19		r#type::{Type, input_types::InputTypes},
20		uint::Uint,
21	},
22};
23
24use crate::routine::{
25	Accumulator, Function, FunctionKind, Routine, RoutineInfo, context::FunctionContext, error::RoutineError,
26};
27
28pub struct Sum {
29	info: RoutineInfo,
30}
31
32impl Default for Sum {
33	fn default() -> Self {
34		Self::new()
35	}
36}
37
38impl Sum {
39	pub fn new() -> Self {
40		Self {
41			info: RoutineInfo::new("math::sum"),
42		}
43	}
44}
45
46impl<'a> Routine<FunctionContext<'a>> for Sum {
47	fn info(&self) -> &RoutineInfo {
48		&self.info
49	}
50
51	fn return_type(&self, input_types: &[Type]) -> Type {
52		input_types.first().cloned().unwrap_or(Type::Int8)
53	}
54
55	fn accepted_types(&self) -> InputTypes {
56		InputTypes::numeric()
57	}
58
59	fn execute(&self, ctx: &mut FunctionContext<'a>, args: &Columns) -> Result<Columns, RoutineError> {
60		if args.is_empty() {
61			return Err(RoutineError::FunctionArityMismatch {
62				function: ctx.fragment.clone(),
63				expected: 1,
64				actual: 0,
65			});
66		}
67
68		let row_count = args.row_count();
69		let mut results = Vec::with_capacity(row_count);
70
71		for i in 0..row_count {
72			let val1 = args[0].get_value(i);
73			results.push(Box::new(val1));
74		}
75
76		Ok(Columns::new(vec![ColumnWithName::new(ctx.fragment.clone(), ColumnBuffer::any(results))]))
77	}
78}
79
80impl Function for Sum {
81	fn kinds(&self) -> &[FunctionKind] {
82		&[FunctionKind::Scalar, FunctionKind::Aggregate]
83	}
84
85	fn accumulator(&self, _ctx: &mut FunctionContext<'_>) -> Option<Box<dyn Accumulator>> {
86		Some(Box::new(SumAccumulator::new()))
87	}
88}
89
90struct SumAccumulator {
91	pub sums: IndexMap<Vec<Value>, Value>,
92	input_type: Option<Type>,
93}
94
95impl SumAccumulator {
96	pub fn new() -> Self {
97		Self {
98			sums: IndexMap::new(),
99			input_type: None,
100		}
101	}
102}
103
// Folds one batch of a fixed-width integer column into `$self.sums`.
//
// For each group yielded by `$groups.iter()`, the defined values at that
// group's row indices are added into a local `$t` partial sum. The partial sum
// is then merged into the group's running `Value::$variant` total. If the stored
// value is a different variant (e.g. `Value::none()` left by an earlier batch
// with no defined values), the partial sum replaces it.
//
// A group with no defined value in this batch is registered as `Value::none()`
// only if it is not already present, so an existing total is never discarded.
//
// The running total is updated in place through `get_mut` instead of being
// removed and re-inserted. `IndexMap::swap_remove` would move the group to the
// end and pull the last group into its slot, scrambling the order in which
// groups are emitted. The key is also cloned only when the group is new.
//
// NOTE(review): `+=` / `+` on the primitive `$t` panics on overflow in debug
// builds and silently wraps in release builds (e.g. an `Int1` sum past
// `i8::MAX`). Consider a checked add once a suitable `RoutineError` variant is
// available.
macro_rules! sum_arm {
	($self:expr, $column:expr, $groups:expr, $container:expr, $t:ty, $variant:ident) => {
		for (group, indices) in $groups.iter() {
			let mut delta: $t = Default::default();
			let mut has_value = false;
			for &i in indices {
				if $column.is_defined(i) {
					if let Some(&val) = $container.get(i) {
						delta += val;
						has_value = true;
					}
				}
			}
			if has_value {
				if let Some(slot) = $self.sums.get_mut(group) {
					// Take the previous total out by value so it can be combined.
					let merged = match ::std::mem::replace(slot, Value::none()) {
						Value::$variant(prev) => prev + delta,
						_ => delta,
					};
					*slot = Value::$variant(merged);
				} else {
					$self.sums.insert(group.clone(), Value::$variant(delta));
				}
			} else if !$self.sums.contains_key(group) {
				$self.sums.insert(group.clone(), Value::none());
			}
		}
	};
}
129
// Floating-point counterpart of `sum_arm!`.
//
// The defined values of each group in this batch are summed into a local `$t`.
// The result is merged with the previous `Value::$variant` total. That total's
// primitive is read back with `.value()`, and the merged number is re-wrapped
// through the `$ctor` constructor (e.g. `Value::float4`). If the stored value is
// a different variant, the partial sum replaces it.
//
// A group with no defined value in the batch is registered as `Value::none()`
// only if it is not already present.
//
// The total is updated in place via `get_mut`, not via `swap_remove` + `insert`.
// This keeps each group at the position where it was first inserted, and the
// key is cloned only for new groups.
macro_rules! sum_arm_float {
	($self:expr, $column:expr, $groups:expr, $container:expr, $t:ty, $variant:ident, $ctor:expr) => {
		for (group, indices) in $groups.iter() {
			let mut delta: $t = Default::default();
			let mut has_value = false;
			for &i in indices {
				if $column.is_defined(i) {
					if let Some(&val) = $container.get(i) {
						delta += val;
						has_value = true;
					}
				}
			}
			if has_value {
				if let Some(slot) = $self.sums.get_mut(group) {
					// Take the previous total out by value so it can be combined.
					let merged = match ::std::mem::replace(slot, Value::none()) {
						Value::$variant(prev) => prev.value() + delta,
						_ => delta,
					};
					*slot = $ctor(merged);
				} else {
					$self.sums.insert(group.clone(), $ctor(delta));
				}
			} else if !$self.sums.contains_key(group) {
				$self.sums.insert(group.clone(), Value::none());
			}
		}
	};
}
155
156impl Accumulator for SumAccumulator {
157	fn update(&mut self, args: &Columns, groups: &GroupByView) -> Result<(), RoutineError> {
158		let column = &args[0];
159		let (data, _bitvec) = column.unwrap_option();
160
161		if self.input_type.is_none() {
162			self.input_type = Some(data.get_type());
163		}
164
165		match data {
166			ColumnBuffer::Int1(container) => {
167				sum_arm!(self, column, groups, container, i8, Int1);
168				Ok(())
169			}
170			ColumnBuffer::Int2(container) => {
171				sum_arm!(self, column, groups, container, i16, Int2);
172				Ok(())
173			}
174			ColumnBuffer::Int4(container) => {
175				sum_arm!(self, column, groups, container, i32, Int4);
176				Ok(())
177			}
178			ColumnBuffer::Int8(container) => {
179				sum_arm!(self, column, groups, container, i64, Int8);
180				Ok(())
181			}
182			ColumnBuffer::Int16(container) => {
183				sum_arm!(self, column, groups, container, i128, Int16);
184				Ok(())
185			}
186			ColumnBuffer::Uint1(container) => {
187				sum_arm!(self, column, groups, container, u8, Uint1);
188				Ok(())
189			}
190			ColumnBuffer::Uint2(container) => {
191				sum_arm!(self, column, groups, container, u16, Uint2);
192				Ok(())
193			}
194			ColumnBuffer::Uint4(container) => {
195				sum_arm!(self, column, groups, container, u32, Uint4);
196				Ok(())
197			}
198			ColumnBuffer::Uint8(container) => {
199				sum_arm!(self, column, groups, container, u64, Uint8);
200				Ok(())
201			}
202			ColumnBuffer::Uint16(container) => {
203				sum_arm!(self, column, groups, container, u128, Uint16);
204				Ok(())
205			}
206			ColumnBuffer::Float4(container) => {
207				sum_arm_float!(self, column, groups, container, f32, Float4, Value::float4);
208				Ok(())
209			}
210			ColumnBuffer::Float8(container) => {
211				sum_arm_float!(self, column, groups, container, f64, Float8, Value::float8);
212				Ok(())
213			}
214			ColumnBuffer::Int {
215				container,
216				..
217			} => {
218				for (group, indices) in groups.iter() {
219					let mut delta = Int::zero();
220					let mut has_value = false;
221					for &i in indices {
222						if column.is_defined(i)
223							&& let Some(val) = container.get(i)
224						{
225							delta = Int(delta.0 + &val.0);
226							has_value = true;
227						}
228					}
229					if has_value {
230						let merged = match self.sums.swap_remove(group) {
231							Some(Value::Int(prev)) => Int(prev.0 + &delta.0),
232							_ => delta,
233						};
234						self.sums.insert(group.clone(), Value::Int(merged));
235					} else {
236						self.sums.entry(group.clone()).or_insert(Value::none());
237					}
238				}
239				Ok(())
240			}
241			ColumnBuffer::Uint {
242				container,
243				..
244			} => {
245				for (group, indices) in groups.iter() {
246					let mut delta = Uint::zero();
247					let mut has_value = false;
248					for &i in indices {
249						if column.is_defined(i)
250							&& let Some(val) = container.get(i)
251						{
252							delta = Uint(delta.0 + &val.0);
253							has_value = true;
254						}
255					}
256					if has_value {
257						let merged = match self.sums.swap_remove(group) {
258							Some(Value::Uint(prev)) => Uint(prev.0 + &delta.0),
259							_ => delta,
260						};
261						self.sums.insert(group.clone(), Value::Uint(merged));
262					} else {
263						self.sums.entry(group.clone()).or_insert(Value::none());
264					}
265				}
266				Ok(())
267			}
268			ColumnBuffer::Decimal {
269				container,
270				..
271			} => {
272				for (group, indices) in groups.iter() {
273					let mut delta = Decimal::zero();
274					let mut has_value = false;
275					for &i in indices {
276						if column.is_defined(i)
277							&& let Some(val) = container.get(i)
278						{
279							delta = Decimal(delta.0 + &val.0);
280							has_value = true;
281						}
282					}
283					if has_value {
284						let merged = match self.sums.swap_remove(group) {
285							Some(Value::Decimal(prev)) => Decimal(prev.0 + &delta.0),
286							_ => delta,
287						};
288						self.sums.insert(group.clone(), Value::Decimal(merged));
289					} else {
290						self.sums.entry(group.clone()).or_insert(Value::none());
291					}
292				}
293				Ok(())
294			}
295			other => Err(RoutineError::FunctionInvalidArgumentType {
296				function: Fragment::internal("math::sum"),
297				argument_index: 0,
298				expected: InputTypes::numeric().expected_at(0).to_vec(),
299				actual: other.get_type(),
300			}),
301		}
302	}
303
304	fn finalize(&mut self) -> Result<(Vec<GroupKey>, ColumnBuffer), RoutineError> {
305		let ty = self.input_type.take().unwrap_or(Type::Int8);
306		let mut keys = Vec::with_capacity(self.sums.len());
307		let mut data = ColumnBuffer::with_capacity(ty, self.sums.len());
308
309		for (key, sum) in mem::take(&mut self.sums) {
310			keys.push(key);
311			data.push_value(sum);
312		}
313
314		Ok((keys, data))
315	}
316}