Skip to main content

reifydb_routine/function/math/aggregate/
max.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use std::mem;
5
6use indexmap::IndexMap;
7use reifydb_core::value::column::data::ColumnData;
8use reifydb_type::value::{
9	Value,
10	decimal::Decimal,
11	int::Int,
12	r#type::{Type, input_types::InputTypes},
13	uint::Uint,
14};
15
16use crate::function::{
17	AggregateFunction, AggregateFunctionContext,
18	error::{AggregateFunctionError, AggregateFunctionResult},
19};
20
21pub struct Max {
22	pub maxs: IndexMap<Vec<Value>, Value>,
23	input_type: Option<Type>,
24}
25
26impl Max {
27	pub fn new() -> Self {
28		Self {
29			maxs: IndexMap::new(),
30			input_type: None,
31		}
32	}
33}
34
35macro_rules! max_arm {
36	($self:expr, $column:expr, $groups:expr, $container:expr, $ctor:expr) => {
37		for (group, indices) in $groups.iter() {
38			let mut max = None;
39			for &i in indices {
40				if $column.data().is_defined(i) {
41					if let Some(&val) = $container.get(i) {
42						max = Some(match max {
43							Some(current) if val > current => val,
44							Some(current) => current,
45							None => val,
46						});
47					}
48				}
49			}
50			match max {
51				Some(v) => {
52					$self.maxs.insert(group.clone(), $ctor(v));
53				}
54				None => {
55					$self.maxs.entry(group.clone()).or_insert(Value::none());
56				}
57			}
58		}
59	};
60}
61
62impl AggregateFunction for Max {
63	fn aggregate(&mut self, ctx: AggregateFunctionContext) -> AggregateFunctionResult<()> {
64		let column = ctx.column;
65		let groups = &ctx.groups;
66		let (data, _bitvec) = column.data().unwrap_option();
67
68		if self.input_type.is_none() {
69			self.input_type = Some(data.get_type());
70		}
71
72		match data {
73			ColumnData::Int1(container) => {
74				max_arm!(self, column, groups, container, Value::Int1);
75				Ok(())
76			}
77			ColumnData::Int2(container) => {
78				max_arm!(self, column, groups, container, Value::Int2);
79				Ok(())
80			}
81			ColumnData::Int4(container) => {
82				max_arm!(self, column, groups, container, Value::Int4);
83				Ok(())
84			}
85			ColumnData::Int8(container) => {
86				max_arm!(self, column, groups, container, Value::Int8);
87				Ok(())
88			}
89			ColumnData::Int16(container) => {
90				max_arm!(self, column, groups, container, Value::Int16);
91				Ok(())
92			}
93			ColumnData::Uint1(container) => {
94				max_arm!(self, column, groups, container, Value::Uint1);
95				Ok(())
96			}
97			ColumnData::Uint2(container) => {
98				max_arm!(self, column, groups, container, Value::Uint2);
99				Ok(())
100			}
101			ColumnData::Uint4(container) => {
102				max_arm!(self, column, groups, container, Value::Uint4);
103				Ok(())
104			}
105			ColumnData::Uint8(container) => {
106				max_arm!(self, column, groups, container, Value::Uint8);
107				Ok(())
108			}
109			ColumnData::Uint16(container) => {
110				max_arm!(self, column, groups, container, Value::Uint16);
111				Ok(())
112			}
113			ColumnData::Float4(container) => {
114				for (group, indices) in groups.iter() {
115					let mut max: Option<f32> = None;
116					for &i in indices {
117						if column.data().is_defined(i) {
118							if let Some(&val) = container.get(i) {
119								max = Some(match max {
120									Some(current) => f32::max(current, val),
121									None => val,
122								});
123							}
124						}
125					}
126					match max {
127						Some(v) => {
128							self.maxs.insert(group.clone(), Value::float4(v));
129						}
130						None => {
131							self.maxs.entry(group.clone()).or_insert(Value::none());
132						}
133					}
134				}
135				Ok(())
136			}
137			ColumnData::Float8(container) => {
138				for (group, indices) in groups.iter() {
139					let mut max: Option<f64> = None;
140					for &i in indices {
141						if column.data().is_defined(i) {
142							if let Some(&val) = container.get(i) {
143								max = Some(match max {
144									Some(current) => f64::max(current, val),
145									None => val,
146								});
147							}
148						}
149					}
150					match max {
151						Some(v) => {
152							self.maxs.insert(group.clone(), Value::float8(v));
153						}
154						None => {
155							self.maxs.entry(group.clone()).or_insert(Value::none());
156						}
157					}
158				}
159				Ok(())
160			}
161			ColumnData::Int {
162				container,
163				..
164			} => {
165				for (group, indices) in groups.iter() {
166					let mut max: Option<Int> = None;
167					for &i in indices {
168						if column.data().is_defined(i) {
169							if let Some(val) = container.get(i) {
170								max = Some(match max {
171									Some(current) if *val > current => val.clone(),
172									Some(current) => current,
173									None => val.clone(),
174								});
175							}
176						}
177					}
178					match max {
179						Some(v) => {
180							self.maxs.insert(group.clone(), Value::Int(v));
181						}
182						None => {
183							self.maxs.entry(group.clone()).or_insert(Value::none());
184						}
185					}
186				}
187				Ok(())
188			}
189			ColumnData::Uint {
190				container,
191				..
192			} => {
193				for (group, indices) in groups.iter() {
194					let mut max: Option<Uint> = None;
195					for &i in indices {
196						if column.data().is_defined(i) {
197							if let Some(val) = container.get(i) {
198								max = Some(match max {
199									Some(current) if *val > current => val.clone(),
200									Some(current) => current,
201									None => val.clone(),
202								});
203							}
204						}
205					}
206					match max {
207						Some(v) => {
208							self.maxs.insert(group.clone(), Value::Uint(v));
209						}
210						None => {
211							self.maxs.entry(group.clone()).or_insert(Value::none());
212						}
213					}
214				}
215				Ok(())
216			}
217			ColumnData::Decimal {
218				container,
219				..
220			} => {
221				for (group, indices) in groups.iter() {
222					let mut max: Option<Decimal> = None;
223					for &i in indices {
224						if column.data().is_defined(i) {
225							if let Some(val) = container.get(i) {
226								max = Some(match max {
227									Some(current) if *val > current => val.clone(),
228									Some(current) => current,
229									None => val.clone(),
230								});
231							}
232						}
233					}
234					match max {
235						Some(v) => {
236							self.maxs.insert(group.clone(), Value::Decimal(v));
237						}
238						None => {
239							self.maxs.entry(group.clone()).or_insert(Value::none());
240						}
241					}
242				}
243				Ok(())
244			}
245			other => Err(AggregateFunctionError::InvalidArgumentType {
246				function: ctx.fragment.clone(),
247				argument_index: 0,
248				expected: self.accepted_types().expected_at(0).to_vec(),
249				actual: other.get_type(),
250			}),
251		}
252	}
253
254	fn finalize(&mut self) -> AggregateFunctionResult<(Vec<Vec<Value>>, ColumnData)> {
255		let ty = self.input_type.take().unwrap_or(Type::Float8);
256		let mut keys = Vec::with_capacity(self.maxs.len());
257		let mut data = ColumnData::with_capacity(ty, self.maxs.len());
258
259		for (key, max) in mem::take(&mut self.maxs) {
260			keys.push(key);
261			data.push_value(max);
262		}
263
264		Ok((keys, data))
265	}
266
267	fn return_type(&self, input_type: &Type) -> Type {
268		input_type.clone()
269	}
270
271	fn accepted_types(&self) -> InputTypes {
272		InputTypes::numeric()
273	}
274}