Skip to main content

reifydb_routine/function/math/
max.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use std::mem;
5
6use indexmap::IndexMap;
7use reifydb_core::value::column::{
8	ColumnWithName,
9	buffer::ColumnBuffer,
10	columns::Columns,
11	view::group_by::{GroupByView, GroupKey},
12};
13use reifydb_type::{
14	fragment::Fragment,
15	value::{
16		Value,
17		decimal::Decimal,
18		int::Int,
19		r#type::{Type, input_types::InputTypes},
20		uint::Uint,
21	},
22};
23
24use crate::routine::{
25	Accumulator, Function, FunctionKind, Routine, RoutineInfo, context::FunctionContext, error::RoutineError,
26};
27
28pub struct Max {
29	info: RoutineInfo,
30}
31
32impl Default for Max {
33	fn default() -> Self {
34		Self::new()
35	}
36}
37
38impl Max {
39	pub fn new() -> Self {
40		Self {
41			info: RoutineInfo::new("math::max"),
42		}
43	}
44}
45
46impl<'a> Routine<FunctionContext<'a>> for Max {
47	fn info(&self) -> &RoutineInfo {
48		&self.info
49	}
50
51	fn return_type(&self, input_types: &[Type]) -> Type {
52		input_types.first().cloned().unwrap_or(Type::Float8)
53	}
54
55	fn accepted_types(&self) -> InputTypes {
56		InputTypes::numeric()
57	}
58
59	fn execute(&self, ctx: &mut FunctionContext<'a>, args: &Columns) -> Result<Columns, RoutineError> {
60		if args.is_empty() {
61			return Err(RoutineError::FunctionArityMismatch {
62				function: ctx.fragment.clone(),
63				expected: 1,
64				actual: 0,
65			});
66		}
67
68		for (i, col) in args.iter().enumerate() {
69			if !col.get_type().is_number() {
70				return Err(RoutineError::FunctionInvalidArgumentType {
71					function: ctx.fragment.clone(),
72					argument_index: i,
73					expected: InputTypes::numeric().expected_at(0).to_vec(),
74					actual: col.get_type(),
75				});
76			}
77		}
78
79		let row_count = args.row_count();
80		let input_type = args[0].get_type();
81		let mut data = ColumnBuffer::with_capacity(input_type, row_count);
82
83		for i in 0..row_count {
84			let mut row_max: Option<Value> = None;
85			for col in args.iter() {
86				if col.data().is_defined(i) {
87					let val = col.data().get_value(i);
88					row_max = Some(match row_max {
89						Some(current) if val > current => val,
90						Some(current) => current,
91						None => val,
92					});
93				}
94			}
95			data.push_value(row_max.unwrap_or(Value::none()));
96		}
97
98		Ok(Columns::new(vec![ColumnWithName::new(ctx.fragment.clone(), data)]))
99	}
100}
101
102impl Function for Max {
103	fn kinds(&self) -> &[FunctionKind] {
104		&[FunctionKind::Scalar, FunctionKind::Aggregate]
105	}
106
107	fn accumulator(&self, _ctx: &mut FunctionContext<'_>) -> Option<Box<dyn Accumulator>> {
108		Some(Box::new(MaxAccumulator::new()))
109	}
110}
111
112struct MaxAccumulator {
113	pub maxs: IndexMap<GroupKey, Value>,
114	input_type: Option<Type>,
115}
116
117impl MaxAccumulator {
118	pub fn new() -> Self {
119		Self {
120			maxs: IndexMap::new(),
121			input_type: None,
122		}
123	}
124}
125
126macro_rules! max_arm {
127	($self:expr, $column:expr, $groups:expr, $container:expr, $variant:ident) => {
128		for (group, indices) in $groups.iter() {
129			let mut max = None;
130			for &i in indices {
131				if $column.is_defined(i) {
132					if let Some(&val) = $container.get(i) {
133						max = Some(match max {
134							Some(current) if val > current => val,
135							Some(current) => current,
136							None => val,
137						});
138					}
139				}
140			}
141			if let Some(v) = max {
142				let merged = match $self.maxs.swap_remove(group) {
143					Some(Value::$variant(prev)) if prev > v => prev,
144					_ => v,
145				};
146				$self.maxs.insert(group.clone(), Value::$variant(merged));
147			} else {
148				$self.maxs.entry(group.clone()).or_insert(Value::none());
149			}
150		}
151	};
152}
153
154impl Accumulator for MaxAccumulator {
155	fn update(&mut self, args: &Columns, groups: &GroupByView) -> Result<(), RoutineError> {
156		let column = &args[0];
157		let (data, _bitvec) = column.unwrap_option();
158
159		if self.input_type.is_none() {
160			self.input_type = Some(data.get_type());
161		}
162
163		match data {
164			ColumnBuffer::Int1(container) => {
165				max_arm!(self, column, groups, container, Int1);
166				Ok(())
167			}
168			ColumnBuffer::Int2(container) => {
169				max_arm!(self, column, groups, container, Int2);
170				Ok(())
171			}
172			ColumnBuffer::Int4(container) => {
173				max_arm!(self, column, groups, container, Int4);
174				Ok(())
175			}
176			ColumnBuffer::Int8(container) => {
177				max_arm!(self, column, groups, container, Int8);
178				Ok(())
179			}
180			ColumnBuffer::Int16(container) => {
181				max_arm!(self, column, groups, container, Int16);
182				Ok(())
183			}
184			ColumnBuffer::Uint1(container) => {
185				max_arm!(self, column, groups, container, Uint1);
186				Ok(())
187			}
188			ColumnBuffer::Uint2(container) => {
189				max_arm!(self, column, groups, container, Uint2);
190				Ok(())
191			}
192			ColumnBuffer::Uint4(container) => {
193				max_arm!(self, column, groups, container, Uint4);
194				Ok(())
195			}
196			ColumnBuffer::Uint8(container) => {
197				max_arm!(self, column, groups, container, Uint8);
198				Ok(())
199			}
200			ColumnBuffer::Uint16(container) => {
201				max_arm!(self, column, groups, container, Uint16);
202				Ok(())
203			}
204			ColumnBuffer::Float4(container) => {
205				for (group, indices) in groups.iter() {
206					let mut max: Option<f32> = None;
207					for &i in indices {
208						if column.is_defined(i)
209							&& let Some(&val) = container.get(i)
210						{
211							max = Some(match max {
212								Some(current) => f32::max(current, val),
213								None => val,
214							});
215						}
216					}
217					if let Some(v) = max {
218						let merged = match self.maxs.swap_remove(group) {
219							Some(Value::Float4(prev)) => f32::max(prev.value(), v),
220							_ => v,
221						};
222						self.maxs.insert(group.clone(), Value::float4(merged));
223					} else {
224						self.maxs.entry(group.clone()).or_insert(Value::none());
225					}
226				}
227				Ok(())
228			}
229			ColumnBuffer::Float8(container) => {
230				for (group, indices) in groups.iter() {
231					let mut max: Option<f64> = None;
232					for &i in indices {
233						if column.is_defined(i)
234							&& let Some(&val) = container.get(i)
235						{
236							max = Some(match max {
237								Some(current) => f64::max(current, val),
238								None => val,
239							});
240						}
241					}
242					if let Some(v) = max {
243						let merged = match self.maxs.swap_remove(group) {
244							Some(Value::Float8(prev)) => f64::max(prev.value(), v),
245							_ => v,
246						};
247						self.maxs.insert(group.clone(), Value::float8(merged));
248					} else {
249						self.maxs.entry(group.clone()).or_insert(Value::none());
250					}
251				}
252				Ok(())
253			}
254			ColumnBuffer::Int {
255				container,
256				..
257			} => {
258				for (group, indices) in groups.iter() {
259					let mut max: Option<Int> = None;
260					for &i in indices {
261						if column.is_defined(i)
262							&& let Some(val) = container.get(i)
263						{
264							max = Some(match max {
265								Some(current) if *val > current => val.clone(),
266								Some(current) => current,
267								None => val.clone(),
268							});
269						}
270					}
271					if let Some(v) = max {
272						let merged = match self.maxs.swap_remove(group) {
273							Some(Value::Int(prev)) if prev > v => prev,
274							_ => v,
275						};
276						self.maxs.insert(group.clone(), Value::Int(merged));
277					} else {
278						self.maxs.entry(group.clone()).or_insert(Value::none());
279					}
280				}
281				Ok(())
282			}
283			ColumnBuffer::Uint {
284				container,
285				..
286			} => {
287				for (group, indices) in groups.iter() {
288					let mut max: Option<Uint> = None;
289					for &i in indices {
290						if column.is_defined(i)
291							&& let Some(val) = container.get(i)
292						{
293							max = Some(match max {
294								Some(current) if *val > current => val.clone(),
295								Some(current) => current,
296								None => val.clone(),
297							});
298						}
299					}
300					if let Some(v) = max {
301						let merged = match self.maxs.swap_remove(group) {
302							Some(Value::Uint(prev)) if prev > v => prev,
303							_ => v,
304						};
305						self.maxs.insert(group.clone(), Value::Uint(merged));
306					} else {
307						self.maxs.entry(group.clone()).or_insert(Value::none());
308					}
309				}
310				Ok(())
311			}
312			ColumnBuffer::Decimal {
313				container,
314				..
315			} => {
316				for (group, indices) in groups.iter() {
317					let mut max: Option<Decimal> = None;
318					for &i in indices {
319						if column.is_defined(i)
320							&& let Some(val) = container.get(i)
321						{
322							max = Some(match max {
323								Some(current) if *val > current => val.clone(),
324								Some(current) => current,
325								None => val.clone(),
326							});
327						}
328					}
329					if let Some(v) = max {
330						let merged = match self.maxs.swap_remove(group) {
331							Some(Value::Decimal(prev)) if prev > v => prev,
332							_ => v,
333						};
334						self.maxs.insert(group.clone(), Value::Decimal(merged));
335					} else {
336						self.maxs.entry(group.clone()).or_insert(Value::none());
337					}
338				}
339				Ok(())
340			}
341			other => Err(RoutineError::FunctionInvalidArgumentType {
342				function: Fragment::internal("math::max"),
343				argument_index: 0,
344				expected: InputTypes::numeric().expected_at(0).to_vec(),
345				actual: other.get_type(),
346			}),
347		}
348	}
349
350	fn finalize(&mut self) -> Result<(Vec<GroupKey>, ColumnBuffer), RoutineError> {
351		let ty = self.input_type.take().unwrap_or(Type::Float8);
352		let mut keys = Vec::with_capacity(self.maxs.len());
353		let mut data = ColumnBuffer::with_capacity(ty, self.maxs.len());
354
355		for (key, max) in mem::take(&mut self.maxs) {
356			keys.push(key);
357			data.push_value(max);
358		}
359
360		Ok((keys, data))
361	}
362}