Skip to main content

reifydb_routine/function/math/
sum.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use std::mem;
5
6use indexmap::IndexMap;
7use reifydb_core::value::column::{
8	Column,
9	columns::Columns,
10	data::ColumnData,
11	view::group_by::{GroupByView, GroupKey},
12};
13use reifydb_type::{
14	fragment::Fragment,
15	value::{
16		Value,
17		decimal::Decimal,
18		int::Int,
19		r#type::{Type, input_types::InputTypes},
20		uint::Uint,
21	},
22};
23
24use crate::function::{Accumulator, Function, FunctionCapability, FunctionContext, FunctionInfo, error::FunctionError};
25
26pub struct Sum {
27	info: FunctionInfo,
28}
29
30impl Default for Sum {
31	fn default() -> Self {
32		Self::new()
33	}
34}
35
36impl Sum {
37	pub fn new() -> Self {
38		Self {
39			info: FunctionInfo::new("math::sum"),
40		}
41	}
42}
43
44impl Function for Sum {
45	fn info(&self) -> &FunctionInfo {
46		&self.info
47	}
48
49	fn capabilities(&self) -> &[FunctionCapability] {
50		&[FunctionCapability::Scalar, FunctionCapability::Aggregate]
51	}
52
53	fn return_type(&self, input_types: &[Type]) -> Type {
54		input_types.first().cloned().unwrap_or(Type::Int8)
55	}
56
57	fn accepted_types(&self) -> InputTypes {
58		InputTypes::numeric()
59	}
60
61	fn execute(&self, ctx: &FunctionContext, args: &Columns) -> Result<Columns, FunctionError> {
62		// SCALAR: Horizontal Sum (summing columns in each row)
63		if args.is_empty() {
64			return Err(FunctionError::ArityMismatch {
65				function: ctx.fragment.clone(),
66				expected: 1,
67				actual: 0,
68			});
69		}
70
71		let row_count = args.row_count();
72		let mut results = Vec::with_capacity(row_count);
73
74		for i in 0..row_count {
75			// Basic implementation: just use first arg for now or add them if possible
76			// In a full implementation we would use a unified adder
77			let val1 = args[0].data().get_value(i);
78			results.push(Box::new(val1));
79		}
80
81		Ok(Columns::new(vec![Column::new(ctx.fragment.clone(), ColumnData::any(results))]))
82	}
83
84	fn accumulator(&self, _ctx: &FunctionContext) -> Option<Box<dyn Accumulator>> {
85		Some(Box::new(SumAccumulator::new()))
86	}
87}
88
89struct SumAccumulator {
90	pub sums: IndexMap<Vec<Value>, Value>,
91	input_type: Option<Type>,
92}
93
94impl SumAccumulator {
95	pub fn new() -> Self {
96		Self {
97			sums: IndexMap::new(),
98			input_type: None,
99		}
100	}
101}
102
103macro_rules! sum_arm {
104	($self:expr, $column:expr, $groups:expr, $container:expr, $t:ty, $ctor:expr) => {
105		for (group, indices) in $groups.iter() {
106			let mut sum: $t = Default::default();
107			let mut has_value = false;
108			for &i in indices {
109				if $column.data().is_defined(i) {
110					if let Some(&val) = $container.get(i) {
111						sum += val;
112						has_value = true;
113					}
114				}
115			}
116			if has_value {
117				$self.sums.insert(group.clone(), $ctor(sum));
118			} else {
119				$self.sums.entry(group.clone()).or_insert(Value::none());
120			}
121		}
122	};
123}
124
125impl Accumulator for SumAccumulator {
126	fn update(&mut self, args: &Columns, groups: &GroupByView) -> Result<(), FunctionError> {
127		let column = &args[0];
128		let (data, _bitvec) = column.data().unwrap_option();
129
130		if self.input_type.is_none() {
131			self.input_type = Some(data.get_type());
132		}
133
134		match data {
135			ColumnData::Int1(container) => {
136				sum_arm!(self, column, groups, container, i8, Value::Int1);
137				Ok(())
138			}
139			ColumnData::Int2(container) => {
140				sum_arm!(self, column, groups, container, i16, Value::Int2);
141				Ok(())
142			}
143			ColumnData::Int4(container) => {
144				sum_arm!(self, column, groups, container, i32, Value::Int4);
145				Ok(())
146			}
147			ColumnData::Int8(container) => {
148				sum_arm!(self, column, groups, container, i64, Value::Int8);
149				Ok(())
150			}
151			ColumnData::Int16(container) => {
152				sum_arm!(self, column, groups, container, i128, Value::Int16);
153				Ok(())
154			}
155			ColumnData::Uint1(container) => {
156				sum_arm!(self, column, groups, container, u8, Value::Uint1);
157				Ok(())
158			}
159			ColumnData::Uint2(container) => {
160				sum_arm!(self, column, groups, container, u16, Value::Uint2);
161				Ok(())
162			}
163			ColumnData::Uint4(container) => {
164				sum_arm!(self, column, groups, container, u32, Value::Uint4);
165				Ok(())
166			}
167			ColumnData::Uint8(container) => {
168				sum_arm!(self, column, groups, container, u64, Value::Uint8);
169				Ok(())
170			}
171			ColumnData::Uint16(container) => {
172				sum_arm!(self, column, groups, container, u128, Value::Uint16);
173				Ok(())
174			}
175			ColumnData::Float4(container) => {
176				sum_arm!(self, column, groups, container, f32, Value::float4);
177				Ok(())
178			}
179			ColumnData::Float8(container) => {
180				sum_arm!(self, column, groups, container, f64, Value::float8);
181				Ok(())
182			}
183			ColumnData::Int {
184				container,
185				..
186			} => {
187				for (group, indices) in groups.iter() {
188					let mut sum = Int::zero();
189					let mut has_value = false;
190					for &i in indices {
191						if column.data().is_defined(i)
192							&& let Some(val) = container.get(i)
193						{
194							sum = Int(sum.0 + &val.0);
195							has_value = true;
196						}
197					}
198					if has_value {
199						self.sums.insert(group.clone(), Value::Int(sum));
200					} else {
201						self.sums.entry(group.clone()).or_insert(Value::none());
202					}
203				}
204				Ok(())
205			}
206			ColumnData::Uint {
207				container,
208				..
209			} => {
210				for (group, indices) in groups.iter() {
211					let mut sum = Uint::zero();
212					let mut has_value = false;
213					for &i in indices {
214						if column.data().is_defined(i)
215							&& let Some(val) = container.get(i)
216						{
217							sum = Uint(sum.0 + &val.0);
218							has_value = true;
219						}
220					}
221					if has_value {
222						self.sums.insert(group.clone(), Value::Uint(sum));
223					} else {
224						self.sums.entry(group.clone()).or_insert(Value::none());
225					}
226				}
227				Ok(())
228			}
229			ColumnData::Decimal {
230				container,
231				..
232			} => {
233				for (group, indices) in groups.iter() {
234					let mut sum = Decimal::zero();
235					let mut has_value = false;
236					for &i in indices {
237						if column.data().is_defined(i)
238							&& let Some(val) = container.get(i)
239						{
240							sum = Decimal(sum.0 + &val.0);
241							has_value = true;
242						}
243					}
244					if has_value {
245						self.sums.insert(group.clone(), Value::Decimal(sum));
246					} else {
247						self.sums.entry(group.clone()).or_insert(Value::none());
248					}
249				}
250				Ok(())
251			}
252			other => Err(FunctionError::InvalidArgumentType {
253				function: Fragment::internal("math::sum"),
254				argument_index: 0,
255				expected: InputTypes::numeric().expected_at(0).to_vec(),
256				actual: other.get_type(),
257			}),
258		}
259	}
260
261	fn finalize(&mut self) -> Result<(Vec<GroupKey>, ColumnData), FunctionError> {
262		let ty = self.input_type.take().unwrap_or(Type::Int8);
263		let mut keys = Vec::with_capacity(self.sums.len());
264		let mut data = ColumnData::with_capacity(ty, self.sums.len());
265
266		for (key, sum) in mem::take(&mut self.sums) {
267			keys.push(key);
268			data.push_value(sum);
269		}
270
271		Ok((keys, data))
272	}
273}