reifydb_routine/function/math/
max.rs1use std::mem;
5
6use indexmap::IndexMap;
7use reifydb_core::value::column::{
8 ColumnWithName,
9 buffer::ColumnBuffer,
10 columns::Columns,
11 view::group_by::{GroupByView, GroupKey},
12};
13use reifydb_type::{
14 fragment::Fragment,
15 value::{
16 Value,
17 decimal::Decimal,
18 int::Int,
19 r#type::{Type, input_types::InputTypes},
20 uint::Uint,
21 },
22};
23
24use crate::routine::{
25 Accumulator, Function, FunctionKind, Routine, RoutineInfo, context::FunctionContext, error::RoutineError,
26};
27
28pub struct Max {
29 info: RoutineInfo,
30}
31
32impl Default for Max {
33 fn default() -> Self {
34 Self::new()
35 }
36}
37
38impl Max {
39 pub fn new() -> Self {
40 Self {
41 info: RoutineInfo::new("math::max"),
42 }
43 }
44}
45
46impl<'a> Routine<FunctionContext<'a>> for Max {
47 fn info(&self) -> &RoutineInfo {
48 &self.info
49 }
50
51 fn return_type(&self, input_types: &[Type]) -> Type {
52 input_types.first().cloned().unwrap_or(Type::Float8)
53 }
54
55 fn accepted_types(&self) -> InputTypes {
56 InputTypes::numeric()
57 }
58
59 fn execute(&self, ctx: &mut FunctionContext<'a>, args: &Columns) -> Result<Columns, RoutineError> {
60 if args.is_empty() {
61 return Err(RoutineError::FunctionArityMismatch {
62 function: ctx.fragment.clone(),
63 expected: 1,
64 actual: 0,
65 });
66 }
67
68 for (i, col) in args.iter().enumerate() {
69 if !col.get_type().is_number() {
70 return Err(RoutineError::FunctionInvalidArgumentType {
71 function: ctx.fragment.clone(),
72 argument_index: i,
73 expected: InputTypes::numeric().expected_at(0).to_vec(),
74 actual: col.get_type(),
75 });
76 }
77 }
78
79 let row_count = args.row_count();
80 let input_type = args[0].get_type();
81 let mut data = ColumnBuffer::with_capacity(input_type, row_count);
82
83 for i in 0..row_count {
84 let mut row_max: Option<Value> = None;
85 for col in args.iter() {
86 if col.data().is_defined(i) {
87 let val = col.data().get_value(i);
88 row_max = Some(match row_max {
89 Some(current) if val > current => val,
90 Some(current) => current,
91 None => val,
92 });
93 }
94 }
95 data.push_value(row_max.unwrap_or(Value::none()));
96 }
97
98 Ok(Columns::new(vec![ColumnWithName::new(ctx.fragment.clone(), data)]))
99 }
100}
101
102impl Function for Max {
103 fn kinds(&self) -> &[FunctionKind] {
104 &[FunctionKind::Scalar, FunctionKind::Aggregate]
105 }
106
107 fn accumulator(&self, _ctx: &mut FunctionContext<'_>) -> Option<Box<dyn Accumulator>> {
108 Some(Box::new(MaxAccumulator::new()))
109 }
110}
111
112struct MaxAccumulator {
113 pub maxs: IndexMap<GroupKey, Value>,
114 input_type: Option<Type>,
115}
116
117impl MaxAccumulator {
118 pub fn new() -> Self {
119 Self {
120 maxs: IndexMap::new(),
121 input_type: None,
122 }
123 }
124}
125
126macro_rules! max_arm {
127 ($self:expr, $column:expr, $groups:expr, $container:expr, $ctor:expr) => {
128 for (group, indices) in $groups.iter() {
129 let mut max = None;
130 for &i in indices {
131 if $column.is_defined(i) {
132 if let Some(&val) = $container.get(i) {
133 max = Some(match max {
134 Some(current) if val > current => val,
135 Some(current) => current,
136 None => val,
137 });
138 }
139 }
140 }
141 if let Some(v) = max {
142 $self.maxs.insert(group.clone(), $ctor(v));
143 } else {
144 $self.maxs.entry(group.clone()).or_insert(Value::none());
145 }
146 }
147 };
148}
149
150impl Accumulator for MaxAccumulator {
151 fn update(&mut self, args: &Columns, groups: &GroupByView) -> Result<(), RoutineError> {
152 let column = &args[0];
153 let (data, _bitvec) = column.unwrap_option();
154
155 if self.input_type.is_none() {
156 self.input_type = Some(data.get_type());
157 }
158
159 match data {
160 ColumnBuffer::Int1(container) => {
161 max_arm!(self, column, groups, container, Value::Int1);
162 Ok(())
163 }
164 ColumnBuffer::Int2(container) => {
165 max_arm!(self, column, groups, container, Value::Int2);
166 Ok(())
167 }
168 ColumnBuffer::Int4(container) => {
169 max_arm!(self, column, groups, container, Value::Int4);
170 Ok(())
171 }
172 ColumnBuffer::Int8(container) => {
173 max_arm!(self, column, groups, container, Value::Int8);
174 Ok(())
175 }
176 ColumnBuffer::Int16(container) => {
177 max_arm!(self, column, groups, container, Value::Int16);
178 Ok(())
179 }
180 ColumnBuffer::Uint1(container) => {
181 max_arm!(self, column, groups, container, Value::Uint1);
182 Ok(())
183 }
184 ColumnBuffer::Uint2(container) => {
185 max_arm!(self, column, groups, container, Value::Uint2);
186 Ok(())
187 }
188 ColumnBuffer::Uint4(container) => {
189 max_arm!(self, column, groups, container, Value::Uint4);
190 Ok(())
191 }
192 ColumnBuffer::Uint8(container) => {
193 max_arm!(self, column, groups, container, Value::Uint8);
194 Ok(())
195 }
196 ColumnBuffer::Uint16(container) => {
197 max_arm!(self, column, groups, container, Value::Uint16);
198 Ok(())
199 }
200 ColumnBuffer::Float4(container) => {
201 for (group, indices) in groups.iter() {
202 let mut max: Option<f32> = None;
203 for &i in indices {
204 if column.is_defined(i)
205 && let Some(&val) = container.get(i)
206 {
207 max = Some(match max {
208 Some(current) => f32::max(current, val),
209 None => val,
210 });
211 }
212 }
213 if let Some(v) = max {
214 self.maxs.insert(group.clone(), Value::float4(v));
215 } else {
216 self.maxs.entry(group.clone()).or_insert(Value::none());
217 }
218 }
219 Ok(())
220 }
221 ColumnBuffer::Float8(container) => {
222 for (group, indices) in groups.iter() {
223 let mut max: Option<f64> = None;
224 for &i in indices {
225 if column.is_defined(i)
226 && let Some(&val) = container.get(i)
227 {
228 max = Some(match max {
229 Some(current) => f64::max(current, val),
230 None => val,
231 });
232 }
233 }
234 if let Some(v) = max {
235 self.maxs.insert(group.clone(), Value::float8(v));
236 } else {
237 self.maxs.entry(group.clone()).or_insert(Value::none());
238 }
239 }
240 Ok(())
241 }
242 ColumnBuffer::Int {
243 container,
244 ..
245 } => {
246 for (group, indices) in groups.iter() {
247 let mut max: Option<Int> = None;
248 for &i in indices {
249 if column.is_defined(i)
250 && let Some(val) = container.get(i)
251 {
252 max = Some(match max {
253 Some(current) if *val > current => val.clone(),
254 Some(current) => current,
255 None => val.clone(),
256 });
257 }
258 }
259 if let Some(v) = max {
260 self.maxs.insert(group.clone(), Value::Int(v));
261 } else {
262 self.maxs.entry(group.clone()).or_insert(Value::none());
263 }
264 }
265 Ok(())
266 }
267 ColumnBuffer::Uint {
268 container,
269 ..
270 } => {
271 for (group, indices) in groups.iter() {
272 let mut max: Option<Uint> = None;
273 for &i in indices {
274 if column.is_defined(i)
275 && let Some(val) = container.get(i)
276 {
277 max = Some(match max {
278 Some(current) if *val > current => val.clone(),
279 Some(current) => current,
280 None => val.clone(),
281 });
282 }
283 }
284 if let Some(v) = max {
285 self.maxs.insert(group.clone(), Value::Uint(v));
286 } else {
287 self.maxs.entry(group.clone()).or_insert(Value::none());
288 }
289 }
290 Ok(())
291 }
292 ColumnBuffer::Decimal {
293 container,
294 ..
295 } => {
296 for (group, indices) in groups.iter() {
297 let mut max: Option<Decimal> = None;
298 for &i in indices {
299 if column.is_defined(i)
300 && let Some(val) = container.get(i)
301 {
302 max = Some(match max {
303 Some(current) if *val > current => val.clone(),
304 Some(current) => current,
305 None => val.clone(),
306 });
307 }
308 }
309 if let Some(v) = max {
310 self.maxs.insert(group.clone(), Value::Decimal(v));
311 } else {
312 self.maxs.entry(group.clone()).or_insert(Value::none());
313 }
314 }
315 Ok(())
316 }
317 other => Err(RoutineError::FunctionInvalidArgumentType {
318 function: Fragment::internal("math::max"),
319 argument_index: 0,
320 expected: InputTypes::numeric().expected_at(0).to_vec(),
321 actual: other.get_type(),
322 }),
323 }
324 }
325
326 fn finalize(&mut self) -> Result<(Vec<GroupKey>, ColumnBuffer), RoutineError> {
327 let ty = self.input_type.take().unwrap_or(Type::Float8);
328 let mut keys = Vec::with_capacity(self.maxs.len());
329 let mut data = ColumnBuffer::with_capacity(ty, self.maxs.len());
330
331 for (key, max) in mem::take(&mut self.maxs) {
332 keys.push(key);
333 data.push_value(max);
334 }
335
336 Ok((keys, data))
337 }
338}