Skip to main content

reifydb_function/math/aggregate/
sum.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// Copyright (c) 2025 ReifyDB
3
4use std::mem;
5
6use indexmap::IndexMap;
7use reifydb_core::value::column::data::ColumnData;
8use reifydb_type::value::{Value, r#type::Type};
9
10use crate::{AggregateFunction, AggregateFunctionContext, error::AggregateFunctionResult};
11
12pub struct Sum {
13	pub sums: IndexMap<Vec<Value>, Value>,
14}
15
16impl Sum {
17	pub fn new() -> Self {
18		Self {
19			sums: IndexMap::new(),
20		}
21	}
22}
23
24impl AggregateFunction for Sum {
25	fn aggregate(&mut self, ctx: AggregateFunctionContext) -> AggregateFunctionResult<()> {
26		let column = ctx.column;
27		let groups = &ctx.groups;
28
29		match &column.data() {
30			ColumnData::Float8(container) => {
31				for (group, indices) in groups.iter() {
32					let sum: f64 = indices
33						.iter()
34						.filter(|&i| container.is_defined(*i))
35						.filter_map(|&i| container.get(i))
36						.sum();
37
38					self.sums.insert(group.clone(), Value::float8(sum));
39				}
40				Ok(())
41			}
42			ColumnData::Float4(container) => {
43				for (group, indices) in groups.iter() {
44					let sum: f32 = indices
45						.iter()
46						.filter(|&i| container.is_defined(*i))
47						.filter_map(|&i| container.get(i))
48						.sum();
49
50					self.sums.insert(group.clone(), Value::float4(sum));
51				}
52				Ok(())
53			}
54			ColumnData::Int2(container) => {
55				for (group, indices) in groups.iter() {
56					let sum: i16 = indices.iter().filter_map(|&i| container.get(i)).sum();
57
58					self.sums.insert(group.clone(), Value::Int2(sum));
59				}
60				Ok(())
61			}
62			ColumnData::Int4(container) => {
63				for (group, indices) in groups.iter() {
64					let sum: i32 = indices
65						.iter()
66						.filter(|&i| container.is_defined(*i))
67						.filter_map(|&i| container.get(i))
68						.sum();
69					self.sums.insert(group.clone(), Value::Int4(sum));
70				}
71				Ok(())
72			}
73			ColumnData::Int8(container) => {
74				for (group, indices) in groups.iter() {
75					let sum: i64 = indices
76						.iter()
77						.filter(|&i| container.is_defined(*i))
78						.filter_map(|&i| container.get(i))
79						.sum();
80
81					self.sums.insert(group.clone(), Value::Int8(sum));
82				}
83				Ok(())
84			}
85			_ => unimplemented!("{}", column.get_type()),
86		}
87	}
88
89	fn finalize(&mut self) -> AggregateFunctionResult<(Vec<Vec<Value>>, ColumnData)> {
90		let mut keys = Vec::with_capacity(self.sums.len());
91		let mut data = ColumnData::none_typed(Type::Boolean, 0);
92
93		for (key, sum) in mem::take(&mut self.sums) {
94			keys.push(key);
95			data.push_value(sum);
96		}
97
98		Ok((keys, data))
99	}
100}