Skip to main content

reifydb_routine/function/math/
sum.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use std::mem;
5
6use indexmap::IndexMap;
7use reifydb_core::value::column::{
8	ColumnWithName,
9	buffer::ColumnBuffer,
10	columns::Columns,
11	view::group_by::{GroupByView, GroupKey},
12};
13use reifydb_type::{
14	fragment::Fragment,
15	value::{
16		Value,
17		decimal::Decimal,
18		int::Int,
19		r#type::{Type, input_types::InputTypes},
20		uint::Uint,
21	},
22};
23
24use crate::routine::{
25	Accumulator, Function, FunctionKind, Routine, RoutineInfo, context::FunctionContext, error::RoutineError,
26};
27
28pub struct Sum {
29	info: RoutineInfo,
30}
31
32impl Default for Sum {
33	fn default() -> Self {
34		Self::new()
35	}
36}
37
38impl Sum {
39	pub fn new() -> Self {
40		Self {
41			info: RoutineInfo::new("math::sum"),
42		}
43	}
44}
45
46impl<'a> Routine<FunctionContext<'a>> for Sum {
47	fn info(&self) -> &RoutineInfo {
48		&self.info
49	}
50
51	fn return_type(&self, input_types: &[Type]) -> Type {
52		input_types.first().cloned().unwrap_or(Type::Int8)
53	}
54
55	fn accepted_types(&self) -> InputTypes {
56		InputTypes::numeric()
57	}
58
59	fn execute(&self, ctx: &mut FunctionContext<'a>, args: &Columns) -> Result<Columns, RoutineError> {
60		// SCALAR: Horizontal Sum (summing columns in each row)
61		if args.is_empty() {
62			return Err(RoutineError::FunctionArityMismatch {
63				function: ctx.fragment.clone(),
64				expected: 1,
65				actual: 0,
66			});
67		}
68
69		let row_count = args.row_count();
70		let mut results = Vec::with_capacity(row_count);
71
72		for i in 0..row_count {
73			// Basic implementation: just use first arg for now or add them if possible
74			// In a full implementation we would use a unified adder
75			let val1 = args[0].get_value(i);
76			results.push(Box::new(val1));
77		}
78
79		Ok(Columns::new(vec![ColumnWithName::new(ctx.fragment.clone(), ColumnBuffer::any(results))]))
80	}
81}
82
83impl Function for Sum {
84	fn kinds(&self) -> &[FunctionKind] {
85		&[FunctionKind::Scalar, FunctionKind::Aggregate]
86	}
87
88	fn accumulator(&self, _ctx: &mut FunctionContext<'_>) -> Option<Box<dyn Accumulator>> {
89		Some(Box::new(SumAccumulator::new()))
90	}
91}
92
93struct SumAccumulator {
94	pub sums: IndexMap<Vec<Value>, Value>,
95	input_type: Option<Type>,
96}
97
98impl SumAccumulator {
99	pub fn new() -> Self {
100		Self {
101			sums: IndexMap::new(),
102			input_type: None,
103		}
104	}
105}
106
/// Sums one fixed-width numeric container per group and stores the result.
///
/// Arguments: the accumulator (`$self`, must expose a `sums` map), the column
/// used for `is_defined` checks, the group view (iterated as key / row-index
/// pairs), the typed container read with `get`, the element type `$t`, and the
/// `Value` constructor applied to the finished sum.
///
/// A row contributes only when the column reports it as defined and the
/// container returns a value for it. A group with at least one contribution
/// has its entry replaced by the new sum; a group with none gets
/// `Value::none()` only if it has no entry yet.
macro_rules! sum_arm {
	($self:expr, $column:expr, $groups:expr, $container:expr, $t:ty, $ctor:expr) => {
		for (key, rows) in $groups.iter() {
			// Stays `None` until the first defined, present value is seen.
			let mut total: Option<$t> = None;
			for &row in rows {
				if !$column.is_defined(row) {
					continue;
				}
				if let Some(&value) = $container.get(row) {
					total = Some(total.unwrap_or_default() + value);
				}
			}
			match total {
				Some(sum) => {
					$self.sums.insert(key.clone(), $ctor(sum));
				}
				None => {
					$self.sums.entry(key.clone()).or_insert(Value::none());
				}
			}
		}
	};
}
128
129impl Accumulator for SumAccumulator {
130	fn update(&mut self, args: &Columns, groups: &GroupByView) -> Result<(), RoutineError> {
131		let column = &args[0];
132		let (data, _bitvec) = column.unwrap_option();
133
134		if self.input_type.is_none() {
135			self.input_type = Some(data.get_type());
136		}
137
138		match data {
139			ColumnBuffer::Int1(container) => {
140				sum_arm!(self, column, groups, container, i8, Value::Int1);
141				Ok(())
142			}
143			ColumnBuffer::Int2(container) => {
144				sum_arm!(self, column, groups, container, i16, Value::Int2);
145				Ok(())
146			}
147			ColumnBuffer::Int4(container) => {
148				sum_arm!(self, column, groups, container, i32, Value::Int4);
149				Ok(())
150			}
151			ColumnBuffer::Int8(container) => {
152				sum_arm!(self, column, groups, container, i64, Value::Int8);
153				Ok(())
154			}
155			ColumnBuffer::Int16(container) => {
156				sum_arm!(self, column, groups, container, i128, Value::Int16);
157				Ok(())
158			}
159			ColumnBuffer::Uint1(container) => {
160				sum_arm!(self, column, groups, container, u8, Value::Uint1);
161				Ok(())
162			}
163			ColumnBuffer::Uint2(container) => {
164				sum_arm!(self, column, groups, container, u16, Value::Uint2);
165				Ok(())
166			}
167			ColumnBuffer::Uint4(container) => {
168				sum_arm!(self, column, groups, container, u32, Value::Uint4);
169				Ok(())
170			}
171			ColumnBuffer::Uint8(container) => {
172				sum_arm!(self, column, groups, container, u64, Value::Uint8);
173				Ok(())
174			}
175			ColumnBuffer::Uint16(container) => {
176				sum_arm!(self, column, groups, container, u128, Value::Uint16);
177				Ok(())
178			}
179			ColumnBuffer::Float4(container) => {
180				sum_arm!(self, column, groups, container, f32, Value::float4);
181				Ok(())
182			}
183			ColumnBuffer::Float8(container) => {
184				sum_arm!(self, column, groups, container, f64, Value::float8);
185				Ok(())
186			}
187			ColumnBuffer::Int {
188				container,
189				..
190			} => {
191				for (group, indices) in groups.iter() {
192					let mut sum = Int::zero();
193					let mut has_value = false;
194					for &i in indices {
195						if column.is_defined(i)
196							&& let Some(val) = container.get(i)
197						{
198							sum = Int(sum.0 + &val.0);
199							has_value = true;
200						}
201					}
202					if has_value {
203						self.sums.insert(group.clone(), Value::Int(sum));
204					} else {
205						self.sums.entry(group.clone()).or_insert(Value::none());
206					}
207				}
208				Ok(())
209			}
210			ColumnBuffer::Uint {
211				container,
212				..
213			} => {
214				for (group, indices) in groups.iter() {
215					let mut sum = Uint::zero();
216					let mut has_value = false;
217					for &i in indices {
218						if column.is_defined(i)
219							&& let Some(val) = container.get(i)
220						{
221							sum = Uint(sum.0 + &val.0);
222							has_value = true;
223						}
224					}
225					if has_value {
226						self.sums.insert(group.clone(), Value::Uint(sum));
227					} else {
228						self.sums.entry(group.clone()).or_insert(Value::none());
229					}
230				}
231				Ok(())
232			}
233			ColumnBuffer::Decimal {
234				container,
235				..
236			} => {
237				for (group, indices) in groups.iter() {
238					let mut sum = Decimal::zero();
239					let mut has_value = false;
240					for &i in indices {
241						if column.is_defined(i)
242							&& let Some(val) = container.get(i)
243						{
244							sum = Decimal(sum.0 + &val.0);
245							has_value = true;
246						}
247					}
248					if has_value {
249						self.sums.insert(group.clone(), Value::Decimal(sum));
250					} else {
251						self.sums.entry(group.clone()).or_insert(Value::none());
252					}
253				}
254				Ok(())
255			}
256			other => Err(RoutineError::FunctionInvalidArgumentType {
257				function: Fragment::internal("math::sum"),
258				argument_index: 0,
259				expected: InputTypes::numeric().expected_at(0).to_vec(),
260				actual: other.get_type(),
261			}),
262		}
263	}
264
265	fn finalize(&mut self) -> Result<(Vec<GroupKey>, ColumnBuffer), RoutineError> {
266		let ty = self.input_type.take().unwrap_or(Type::Int8);
267		let mut keys = Vec::with_capacity(self.sums.len());
268		let mut data = ColumnBuffer::with_capacity(ty, self.sums.len());
269
270		for (key, sum) in mem::take(&mut self.sums) {
271			keys.push(key);
272			data.push_value(sum);
273		}
274
275		Ok((keys, data))
276	}
277}