Skip to main content

reifydb_routine/function/math/aggregate/
sum.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use std::mem;
5
6use indexmap::IndexMap;
7use reifydb_core::value::column::data::ColumnData;
8use reifydb_type::value::{
9	Value,
10	decimal::Decimal,
11	int::Int,
12	r#type::{Type, input_types::InputTypes},
13	uint::Uint,
14};
15
16use crate::function::{
17	AggregateFunction, AggregateFunctionContext,
18	error::{AggregateFunctionError, AggregateFunctionResult},
19};
20
/// Grouped `sum` aggregate state.
///
/// `aggregate` writes one entry per group into `sums`; `finalize` drains the
/// map into a key list plus a result column and resets the state.
pub struct Sum {
	/// Result per group key. A group with no defined input value is stored as
	/// `Value::none()` unless an entry for that group already exists.
	pub sums: IndexMap<Vec<Value>, Value>,
	/// Type of the first column data passed to `aggregate`; `finalize` uses it
	/// for the output column and falls back to `Type::Int8` when it is unset.
	input_type: Option<Type>,
}
25
26impl Sum {
27	pub fn new() -> Self {
28		Self {
29			sums: IndexMap::new(),
30			input_type: None,
31		}
32	}
33}
34
// Expands to the per-group summation loop for one fixed-width numeric column variant.
//
// Arguments:
// - `$self`:      the aggregator whose `sums` map receives the results
// - `$column`:    the column; `$column.data().is_defined(row)` gates each row
// - `$groups`:    iterable of `(group key, row indices)` pairs
// - `$container`: typed storage; `get(row)` yields an optional reference to the value
// - `$t`:         primitive accumulator type, started from `Default::default()`
// - `$ctor`:      callable wrapping the finished total into a `Value`
//
// NOTE(review): a group that contributed at least one value is written with `insert`,
// which replaces any entry already present for that key; a group without values only
// fills a missing entry with `Value::none()`.
// NOTE(review): the accumulation uses plain `+=`, so overflow follows the build's
// overflow-check setting.
macro_rules! sum_arm {
	($self:expr, $column:expr, $groups:expr, $container:expr, $t:ty, $ctor:expr) => {
		for (key, rows) in $groups.iter() {
			let mut total = <$t>::default();
			let mut seen = false;
			for &row in rows {
				// Rows flagged as undefined never contribute.
				if !$column.data().is_defined(row) {
					continue;
				}
				if let Some(&item) = $container.get(row) {
					total += item;
					seen = true;
				}
			}
			if !seen {
				// Keep whatever is already stored for this group, if anything.
				$self.sums.entry(key.clone()).or_insert(Value::none());
			} else {
				$self.sums.insert(key.clone(), $ctor(total));
			}
		}
	};
}
56
57impl AggregateFunction for Sum {
58	fn aggregate(&mut self, ctx: AggregateFunctionContext) -> AggregateFunctionResult<()> {
59		let column = ctx.column;
60		let groups = &ctx.groups;
61		let (data, _bitvec) = column.data().unwrap_option();
62
63		if self.input_type.is_none() {
64			self.input_type = Some(data.get_type());
65		}
66
67		match data {
68			ColumnData::Int1(container) => {
69				sum_arm!(self, column, groups, container, i8, Value::Int1);
70				Ok(())
71			}
72			ColumnData::Int2(container) => {
73				sum_arm!(self, column, groups, container, i16, Value::Int2);
74				Ok(())
75			}
76			ColumnData::Int4(container) => {
77				sum_arm!(self, column, groups, container, i32, Value::Int4);
78				Ok(())
79			}
80			ColumnData::Int8(container) => {
81				sum_arm!(self, column, groups, container, i64, Value::Int8);
82				Ok(())
83			}
84			ColumnData::Int16(container) => {
85				sum_arm!(self, column, groups, container, i128, Value::Int16);
86				Ok(())
87			}
88			ColumnData::Uint1(container) => {
89				sum_arm!(self, column, groups, container, u8, Value::Uint1);
90				Ok(())
91			}
92			ColumnData::Uint2(container) => {
93				sum_arm!(self, column, groups, container, u16, Value::Uint2);
94				Ok(())
95			}
96			ColumnData::Uint4(container) => {
97				sum_arm!(self, column, groups, container, u32, Value::Uint4);
98				Ok(())
99			}
100			ColumnData::Uint8(container) => {
101				sum_arm!(self, column, groups, container, u64, Value::Uint8);
102				Ok(())
103			}
104			ColumnData::Uint16(container) => {
105				sum_arm!(self, column, groups, container, u128, Value::Uint16);
106				Ok(())
107			}
108			ColumnData::Float4(container) => {
109				sum_arm!(self, column, groups, container, f32, Value::float4);
110				Ok(())
111			}
112			ColumnData::Float8(container) => {
113				sum_arm!(self, column, groups, container, f64, Value::float8);
114				Ok(())
115			}
116			ColumnData::Int {
117				container,
118				..
119			} => {
120				for (group, indices) in groups.iter() {
121					let mut sum = Int::zero();
122					let mut has_value = false;
123					for &i in indices {
124						if column.data().is_defined(i) {
125							if let Some(val) = container.get(i) {
126								sum = Int(sum.0 + &val.0);
127								has_value = true;
128							}
129						}
130					}
131					if has_value {
132						self.sums.insert(group.clone(), Value::Int(sum));
133					} else {
134						self.sums.entry(group.clone()).or_insert(Value::none());
135					}
136				}
137				Ok(())
138			}
139			ColumnData::Uint {
140				container,
141				..
142			} => {
143				for (group, indices) in groups.iter() {
144					let mut sum = Uint::zero();
145					let mut has_value = false;
146					for &i in indices {
147						if column.data().is_defined(i) {
148							if let Some(val) = container.get(i) {
149								sum = Uint(sum.0 + &val.0);
150								has_value = true;
151							}
152						}
153					}
154					if has_value {
155						self.sums.insert(group.clone(), Value::Uint(sum));
156					} else {
157						self.sums.entry(group.clone()).or_insert(Value::none());
158					}
159				}
160				Ok(())
161			}
162			ColumnData::Decimal {
163				container,
164				..
165			} => {
166				for (group, indices) in groups.iter() {
167					let mut sum = Decimal::zero();
168					let mut has_value = false;
169					for &i in indices {
170						if column.data().is_defined(i) {
171							if let Some(val) = container.get(i) {
172								sum = Decimal(sum.0 + &val.0);
173								has_value = true;
174							}
175						}
176					}
177					if has_value {
178						self.sums.insert(group.clone(), Value::Decimal(sum));
179					} else {
180						self.sums.entry(group.clone()).or_insert(Value::none());
181					}
182				}
183				Ok(())
184			}
185			other => Err(AggregateFunctionError::InvalidArgumentType {
186				function: ctx.fragment.clone(),
187				argument_index: 0,
188				expected: self.accepted_types().expected_at(0).to_vec(),
189				actual: other.get_type(),
190			}),
191		}
192	}
193
194	fn finalize(&mut self) -> AggregateFunctionResult<(Vec<Vec<Value>>, ColumnData)> {
195		let ty = self.input_type.take().unwrap_or(Type::Int8);
196		let mut keys = Vec::with_capacity(self.sums.len());
197		let mut data = ColumnData::with_capacity(ty, self.sums.len());
198
199		for (key, sum) in mem::take(&mut self.sums) {
200			keys.push(key);
201			data.push_value(sum);
202		}
203
204		Ok((keys, data))
205	}
206
207	fn return_type(&self, input_type: &Type) -> Type {
208		input_type.clone()
209	}
210
211	fn accepted_types(&self) -> InputTypes {
212		InputTypes::numeric()
213	}
214}