Skip to main content

reifydb_routine/function/math/
max.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use std::mem;
5
6use indexmap::IndexMap;
7use reifydb_core::value::column::{
8	ColumnWithName,
9	buffer::ColumnBuffer,
10	columns::Columns,
11	view::group_by::{GroupByView, GroupKey},
12};
13use reifydb_type::{
14	fragment::Fragment,
15	value::{
16		Value,
17		decimal::Decimal,
18		int::Int,
19		r#type::{Type, input_types::InputTypes},
20		uint::Uint,
21	},
22};
23
24use crate::routine::{
25	Accumulator, Function, FunctionKind, Routine, RoutineInfo, context::FunctionContext, error::RoutineError,
26};
27
28pub struct Max {
29	info: RoutineInfo,
30}
31
32impl Default for Max {
33	fn default() -> Self {
34		Self::new()
35	}
36}
37
38impl Max {
39	pub fn new() -> Self {
40		Self {
41			info: RoutineInfo::new("math::max"),
42		}
43	}
44}
45
46impl<'a> Routine<FunctionContext<'a>> for Max {
47	fn info(&self) -> &RoutineInfo {
48		&self.info
49	}
50
51	fn return_type(&self, input_types: &[Type]) -> Type {
52		input_types.first().cloned().unwrap_or(Type::Float8)
53	}
54
55	fn accepted_types(&self) -> InputTypes {
56		InputTypes::numeric()
57	}
58
59	fn execute(&self, ctx: &mut FunctionContext<'a>, args: &Columns) -> Result<Columns, RoutineError> {
60		if args.is_empty() {
61			return Err(RoutineError::FunctionArityMismatch {
62				function: ctx.fragment.clone(),
63				expected: 1,
64				actual: 0,
65			});
66		}
67
68		for (i, col) in args.iter().enumerate() {
69			if !col.get_type().is_number() {
70				return Err(RoutineError::FunctionInvalidArgumentType {
71					function: ctx.fragment.clone(),
72					argument_index: i,
73					expected: InputTypes::numeric().expected_at(0).to_vec(),
74					actual: col.get_type(),
75				});
76			}
77		}
78
79		let row_count = args.row_count();
80		let input_type = args[0].get_type();
81		let mut data = ColumnBuffer::with_capacity(input_type, row_count);
82
83		for i in 0..row_count {
84			let mut row_max: Option<Value> = None;
85			for col in args.iter() {
86				if col.data().is_defined(i) {
87					let val = col.data().get_value(i);
88					row_max = Some(match row_max {
89						Some(current) if val > current => val,
90						Some(current) => current,
91						None => val,
92					});
93				}
94			}
95			data.push_value(row_max.unwrap_or(Value::none()));
96		}
97
98		Ok(Columns::new(vec![ColumnWithName::new(ctx.fragment.clone(), data)]))
99	}
100}
101
102impl Function for Max {
103	fn kinds(&self) -> &[FunctionKind] {
104		&[FunctionKind::Scalar, FunctionKind::Aggregate]
105	}
106
107	fn accumulator(&self, _ctx: &mut FunctionContext<'_>) -> Option<Box<dyn Accumulator>> {
108		Some(Box::new(MaxAccumulator::new()))
109	}
110}
111
112struct MaxAccumulator {
113	pub maxs: IndexMap<GroupKey, Value>,
114	input_type: Option<Type>,
115}
116
117impl MaxAccumulator {
118	pub fn new() -> Self {
119		Self {
120			maxs: IndexMap::new(),
121			input_type: None,
122		}
123	}
124}
125
126macro_rules! max_arm {
127	($self:expr, $column:expr, $groups:expr, $container:expr, $ctor:expr) => {
128		for (group, indices) in $groups.iter() {
129			let mut max = None;
130			for &i in indices {
131				if $column.is_defined(i) {
132					if let Some(&val) = $container.get(i) {
133						max = Some(match max {
134							Some(current) if val > current => val,
135							Some(current) => current,
136							None => val,
137						});
138					}
139				}
140			}
141			if let Some(v) = max {
142				$self.maxs.insert(group.clone(), $ctor(v));
143			} else {
144				$self.maxs.entry(group.clone()).or_insert(Value::none());
145			}
146		}
147	};
148}
149
150impl Accumulator for MaxAccumulator {
151	fn update(&mut self, args: &Columns, groups: &GroupByView) -> Result<(), RoutineError> {
152		let column = &args[0];
153		let (data, _bitvec) = column.unwrap_option();
154
155		if self.input_type.is_none() {
156			self.input_type = Some(data.get_type());
157		}
158
159		match data {
160			ColumnBuffer::Int1(container) => {
161				max_arm!(self, column, groups, container, Value::Int1);
162				Ok(())
163			}
164			ColumnBuffer::Int2(container) => {
165				max_arm!(self, column, groups, container, Value::Int2);
166				Ok(())
167			}
168			ColumnBuffer::Int4(container) => {
169				max_arm!(self, column, groups, container, Value::Int4);
170				Ok(())
171			}
172			ColumnBuffer::Int8(container) => {
173				max_arm!(self, column, groups, container, Value::Int8);
174				Ok(())
175			}
176			ColumnBuffer::Int16(container) => {
177				max_arm!(self, column, groups, container, Value::Int16);
178				Ok(())
179			}
180			ColumnBuffer::Uint1(container) => {
181				max_arm!(self, column, groups, container, Value::Uint1);
182				Ok(())
183			}
184			ColumnBuffer::Uint2(container) => {
185				max_arm!(self, column, groups, container, Value::Uint2);
186				Ok(())
187			}
188			ColumnBuffer::Uint4(container) => {
189				max_arm!(self, column, groups, container, Value::Uint4);
190				Ok(())
191			}
192			ColumnBuffer::Uint8(container) => {
193				max_arm!(self, column, groups, container, Value::Uint8);
194				Ok(())
195			}
196			ColumnBuffer::Uint16(container) => {
197				max_arm!(self, column, groups, container, Value::Uint16);
198				Ok(())
199			}
200			ColumnBuffer::Float4(container) => {
201				for (group, indices) in groups.iter() {
202					let mut max: Option<f32> = None;
203					for &i in indices {
204						if column.is_defined(i)
205							&& let Some(&val) = container.get(i)
206						{
207							max = Some(match max {
208								Some(current) => f32::max(current, val),
209								None => val,
210							});
211						}
212					}
213					if let Some(v) = max {
214						self.maxs.insert(group.clone(), Value::float4(v));
215					} else {
216						self.maxs.entry(group.clone()).or_insert(Value::none());
217					}
218				}
219				Ok(())
220			}
221			ColumnBuffer::Float8(container) => {
222				for (group, indices) in groups.iter() {
223					let mut max: Option<f64> = None;
224					for &i in indices {
225						if column.is_defined(i)
226							&& let Some(&val) = container.get(i)
227						{
228							max = Some(match max {
229								Some(current) => f64::max(current, val),
230								None => val,
231							});
232						}
233					}
234					if let Some(v) = max {
235						self.maxs.insert(group.clone(), Value::float8(v));
236					} else {
237						self.maxs.entry(group.clone()).or_insert(Value::none());
238					}
239				}
240				Ok(())
241			}
242			ColumnBuffer::Int {
243				container,
244				..
245			} => {
246				for (group, indices) in groups.iter() {
247					let mut max: Option<Int> = None;
248					for &i in indices {
249						if column.is_defined(i)
250							&& let Some(val) = container.get(i)
251						{
252							max = Some(match max {
253								Some(current) if *val > current => val.clone(),
254								Some(current) => current,
255								None => val.clone(),
256							});
257						}
258					}
259					if let Some(v) = max {
260						self.maxs.insert(group.clone(), Value::Int(v));
261					} else {
262						self.maxs.entry(group.clone()).or_insert(Value::none());
263					}
264				}
265				Ok(())
266			}
267			ColumnBuffer::Uint {
268				container,
269				..
270			} => {
271				for (group, indices) in groups.iter() {
272					let mut max: Option<Uint> = None;
273					for &i in indices {
274						if column.is_defined(i)
275							&& let Some(val) = container.get(i)
276						{
277							max = Some(match max {
278								Some(current) if *val > current => val.clone(),
279								Some(current) => current,
280								None => val.clone(),
281							});
282						}
283					}
284					if let Some(v) = max {
285						self.maxs.insert(group.clone(), Value::Uint(v));
286					} else {
287						self.maxs.entry(group.clone()).or_insert(Value::none());
288					}
289				}
290				Ok(())
291			}
292			ColumnBuffer::Decimal {
293				container,
294				..
295			} => {
296				for (group, indices) in groups.iter() {
297					let mut max: Option<Decimal> = None;
298					for &i in indices {
299						if column.is_defined(i)
300							&& let Some(val) = container.get(i)
301						{
302							max = Some(match max {
303								Some(current) if *val > current => val.clone(),
304								Some(current) => current,
305								None => val.clone(),
306							});
307						}
308					}
309					if let Some(v) = max {
310						self.maxs.insert(group.clone(), Value::Decimal(v));
311					} else {
312						self.maxs.entry(group.clone()).or_insert(Value::none());
313					}
314				}
315				Ok(())
316			}
317			other => Err(RoutineError::FunctionInvalidArgumentType {
318				function: Fragment::internal("math::max"),
319				argument_index: 0,
320				expected: InputTypes::numeric().expected_at(0).to_vec(),
321				actual: other.get_type(),
322			}),
323		}
324	}
325
326	fn finalize(&mut self) -> Result<(Vec<GroupKey>, ColumnBuffer), RoutineError> {
327		let ty = self.input_type.take().unwrap_or(Type::Float8);
328		let mut keys = Vec::with_capacity(self.maxs.len());
329		let mut data = ColumnBuffer::with_capacity(ty, self.maxs.len());
330
331		for (key, max) in mem::take(&mut self.maxs) {
332			keys.push(key);
333			data.push_value(max);
334		}
335
336		Ok((keys, data))
337	}
338}