drasi_core/evaluation/functions/aggregation/
max.rs

1// Copyright 2024 The Drasi Authors.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{fmt::Debug, sync::Arc};
16
17use crate::{
18    evaluation::{
19        temporal_constants,
20        variable_value::{
21            duration::Duration, zoned_datetime::ZonedDateTime, zoned_time::ZonedTime,
22        },
23        {FunctionError, FunctionEvaluationError},
24    },
25    interface::ResultIndex,
26};
27
28use async_trait::async_trait;
29
30use drasi_query_ast::ast;
31
32use crate::evaluation::{
33    variable_value::float::Float, variable_value::VariableValue, ExpressionEvaluationContext,
34};
35
36use chrono::{DateTime, Duration as ChronoDuration, LocalResult};
37
38use super::{super::AggregatingFunction, lazy_sorted_set::LazySortedSet, Accumulator};
39
40#[derive(Clone)]
41pub struct Max {}
42
43#[async_trait]
44impl AggregatingFunction for Max {
45    fn initialize_accumulator(
46        &self,
47        _context: &ExpressionEvaluationContext,
48        expression: &ast::FunctionExpression,
49        grouping_keys: &Vec<VariableValue>,
50        index: Arc<dyn ResultIndex>,
51    ) -> Accumulator {
52        Accumulator::LazySortedSet(LazySortedSet::new(
53            expression.position_in_query,
54            grouping_keys,
55            index,
56        ))
57    }
58
59    fn accumulator_is_lazy(&self) -> bool {
60        true
61    }
62
63    async fn apply(
64        &self,
65        _context: &ExpressionEvaluationContext,
66        args: Vec<VariableValue>,
67        accumulator: &mut Accumulator,
68    ) -> Result<VariableValue, FunctionError> {
69        if args.len() != 1 {
70            return Err(FunctionError {
71                function_name: "Max".to_string(),
72                error: FunctionEvaluationError::InvalidArgumentCount,
73            });
74        }
75
76        log::info!("Applying Max with args: {:?}", args);
77        let accumulator = match accumulator {
78            Accumulator::LazySortedSet(accumulator) => accumulator,
79            _ => {
80                return Err(FunctionError {
81                    function_name: "Max".to_string(),
82                    error: FunctionEvaluationError::CorruptData,
83                })
84            }
85        };
86
87        match &args[0] {
88            VariableValue::Float(n) => {
89                let value = match n.as_f64() {
90                    Some(n) => n,
91                    None => {
92                        return Err(FunctionError {
93                            function_name: "Max".to_string(),
94                            error: FunctionEvaluationError::OverflowError,
95                        })
96                    }
97                };
98                accumulator.insert(value * -1.0).await;
99                match accumulator.get_head().await {
100                    Ok(Some(head)) => {
101                        Ok(VariableValue::Float(match Float::from_f64(head * -1.0) {
102                            Some(f) => f,
103                            None => {
104                                return Err(FunctionError {
105                                    function_name: "Max".to_string(),
106                                    error: FunctionEvaluationError::CorruptData,
107                                })
108                            }
109                        }))
110                    }
111                    Ok(None) => Ok(VariableValue::Null),
112                    Err(e) => Err(FunctionError {
113                        function_name: "Max".to_string(),
114                        error: FunctionEvaluationError::IndexError(e),
115                    }),
116                }
117            }
118            VariableValue::Integer(n) => {
119                let value = match n.as_i64() {
120                    Some(n) => n,
121                    None => {
122                        return Err(FunctionError {
123                            function_name: "Max".to_string(),
124                            error: FunctionEvaluationError::OverflowError,
125                        })
126                    }
127                };
128                accumulator.insert((value as f64) * -1.0).await;
129                match accumulator.get_head().await {
130                    Ok(Some(head)) => {
131                        Ok(VariableValue::Float(match Float::from_f64(head * -1.0) {
132                            Some(f) => f,
133                            None => {
134                                return Err(FunctionError {
135                                    function_name: "Max".to_string(),
136                                    error: FunctionEvaluationError::CorruptData,
137                                })
138                            }
139                        }))
140                    }
141                    Ok(None) => Ok(VariableValue::Null),
142                    Err(e) => Err(FunctionError {
143                        function_name: "Max".to_string(),
144                        error: FunctionEvaluationError::IndexError(e),
145                    }),
146                }
147            }
148            VariableValue::ZonedDateTime(zdt) => {
149                let value = zdt.datetime().timestamp_millis() as f64;
150                accumulator.insert(value * -1.0).await;
151                match accumulator.get_head().await {
152                    Ok(Some(head)) => Ok(VariableValue::ZonedDateTime(
153                        ZonedDateTime::from_epoch_millis((head * -1.0) as u64),
154                    )),
155                    Ok(None) => Ok(VariableValue::Null),
156                    Err(e) => Err(FunctionError {
157                        function_name: "Max".to_string(),
158                        error: FunctionEvaluationError::IndexError(e),
159                    }),
160                }
161            }
162            VariableValue::Duration(d) => {
163                let value = d.duration().num_milliseconds() as f64;
164                accumulator.insert(value * -1.0).await;
165                match accumulator.get_head().await {
166                    Ok(Some(head)) => Ok(VariableValue::Duration(Duration::new(
167                        ChronoDuration::milliseconds((head * -1.0) as i64),
168                        0,
169                        0,
170                    ))),
171                    Ok(None) => Ok(VariableValue::Null),
172                    Err(e) => Err(FunctionError {
173                        function_name: "Max".to_string(),
174                        error: FunctionEvaluationError::IndexError(e),
175                    }),
176                }
177            }
178            VariableValue::Date(d) => {
179                // For date (Chrono::NaiveDate), we can store the number of days since the epoch
180                let reference_date = *temporal_constants::EPOCH_NAIVE_DATE;
181                let days_since_epoch = d.signed_duration_since(reference_date).num_days() as f64;
182                accumulator.insert(days_since_epoch * -1.0).await;
183                match accumulator.get_head().await {
184                    Ok(Some(head)) => Ok(VariableValue::Date(
185                        reference_date + ChronoDuration::days((head * -1.0) as i64),
186                    )),
187                    Ok(None) => Ok(VariableValue::Null),
188                    Err(e) => Err(FunctionError {
189                        function_name: "Max".to_string(),
190                        error: FunctionEvaluationError::IndexError(e),
191                    }),
192                }
193            }
194            VariableValue::LocalTime(t) => {
195                let reference_time = *temporal_constants::MIDNIGHT_NAIVE_TIME;
196                let duration_since_midnight =
197                    t.signed_duration_since(reference_time).num_milliseconds() as f64;
198                accumulator.insert(duration_since_midnight * -1.0).await;
199
200                match accumulator.get_head().await {
201                    Ok(Some(head)) => Ok(VariableValue::LocalTime(
202                        reference_time + ChronoDuration::milliseconds((head * -1.0) as i64),
203                    )),
204                    Ok(None) => Ok(VariableValue::Null),
205                    Err(e) => Err(FunctionError {
206                        function_name: "Max".to_string(),
207                        error: FunctionEvaluationError::IndexError(e),
208                    }),
209                }
210            }
211            VariableValue::LocalDateTime(dt) => {
212                let duration_since_epoch = dt.and_utc().timestamp_millis() as f64;
213                accumulator.insert(duration_since_epoch * -1.0).await;
214                match accumulator.get_head().await {
215                    Ok(Some(head)) => Ok(VariableValue::LocalDateTime(
216                        DateTime::from_timestamp_millis(head as i64 * -1.0 as i64)
217                            .unwrap_or_default()
218                            .naive_local(),
219                    )),
220                    Ok(None) => Ok(VariableValue::Null),
221                    Err(e) => Err(FunctionError {
222                        function_name: "Max".to_string(),
223                        error: FunctionEvaluationError::IndexError(e),
224                    }),
225                }
226            }
227            VariableValue::ZonedTime(t) => {
228                let epoch_date = *temporal_constants::EPOCH_NAIVE_DATE;
229                let epoch_datetime = match epoch_date
230                    .and_time(*t.time())
231                    .and_local_timezone(*t.offset())
232                {
233                    LocalResult::Single(dt) => dt,
234                    _ => {
235                        return Err(FunctionError {
236                            function_name: "Max".to_string(),
237                            error: FunctionEvaluationError::InvalidFormat {
238                                expected: temporal_constants::INVALID_ZONED_TIME_FORMAT_ERROR
239                                    .to_string(),
240                            },
241                        })
242                    }
243                };
244                let duration_since_epoch = epoch_datetime.timestamp_millis() as f64;
245                accumulator.insert(duration_since_epoch * -1.0).await;
246                match accumulator.get_head().await {
247                    Ok(Some(head)) => Ok(VariableValue::ZonedTime(ZonedTime::new(
248                        (epoch_datetime + ChronoDuration::milliseconds((head * -1.0) as i64))
249                            .time(),
250                        *temporal_constants::UTC_FIXED_OFFSET,
251                    ))),
252                    Ok(None) => Ok(VariableValue::Null),
253                    Err(e) => Err(FunctionError {
254                        function_name: "Max".to_string(),
255                        error: FunctionEvaluationError::IndexError(e),
256                    }),
257                }
258            }
259            VariableValue::Null => Ok(VariableValue::Null),
260            _ => Err(FunctionError {
261                function_name: "Max".to_string(),
262                error: FunctionEvaluationError::InvalidArgument(0),
263            }),
264        }
265    }
266
267    async fn revert(
268        &self,
269        _context: &ExpressionEvaluationContext,
270        args: Vec<VariableValue>,
271        accumulator: &mut Accumulator,
272    ) -> Result<VariableValue, FunctionError> {
273        if args.len() != 1 {
274            return Err(FunctionError {
275                function_name: "Max".to_string(),
276                error: FunctionEvaluationError::InvalidArgumentCount,
277            });
278        }
279        let accumulator = match accumulator {
280            Accumulator::LazySortedSet(accumulator) => accumulator,
281            _ => {
282                return Err(FunctionError {
283                    function_name: "Max".to_string(),
284                    error: FunctionEvaluationError::CorruptData,
285                })
286            }
287        };
288
289        match &args[0] {
290            VariableValue::Float(n) => {
291                let value = match n.as_f64() {
292                    Some(n) => n,
293                    None => {
294                        return Err(FunctionError {
295                            function_name: "Max".to_string(),
296                            error: FunctionEvaluationError::OverflowError,
297                        })
298                    }
299                };
300                accumulator.remove(value * -1.0).await;
301                match accumulator.get_head().await {
302                    Ok(Some(head)) => {
303                        Ok(VariableValue::Float(match Float::from_f64(head * -1.0) {
304                            Some(f) => f,
305                            None => {
306                                return Err(FunctionError {
307                                    function_name: "Max".to_string(),
308                                    error: FunctionEvaluationError::CorruptData,
309                                })
310                            }
311                        }))
312                    }
313                    Ok(None) => Ok(VariableValue::Null),
314                    Err(e) => Err(FunctionError {
315                        function_name: "Max".to_string(),
316                        error: FunctionEvaluationError::IndexError(e),
317                    }),
318                }
319            }
320            VariableValue::Integer(n) => {
321                let value = match n.as_i64() {
322                    Some(n) => n,
323                    None => {
324                        return Err(FunctionError {
325                            function_name: "Max".to_string(),
326                            error: FunctionEvaluationError::OverflowError,
327                        })
328                    }
329                };
330                accumulator.remove((value as f64) * -1.0).await;
331                match accumulator.get_head().await {
332                    Ok(Some(head)) => {
333                        Ok(VariableValue::Float(match Float::from_f64(head * -1.0) {
334                            Some(f) => f,
335                            None => {
336                                return Err(FunctionError {
337                                    function_name: "Max".to_string(),
338                                    error: FunctionEvaluationError::CorruptData,
339                                })
340                            }
341                        }))
342                    }
343                    Ok(None) => Ok(VariableValue::Null),
344                    Err(e) => Err(FunctionError {
345                        function_name: "Max".to_string(),
346                        error: FunctionEvaluationError::IndexError(e),
347                    }),
348                }
349            }
350            VariableValue::ZonedDateTime(zdt) => {
351                let value = zdt.datetime().timestamp_millis() as f64;
352                accumulator.remove(value * -1.0).await;
353                match accumulator.get_head().await {
354                    Ok(Some(head)) => Ok(VariableValue::ZonedDateTime(
355                        ZonedDateTime::from_epoch_millis((head * -1.0) as u64),
356                    )),
357                    Ok(None) => Ok(VariableValue::Null),
358                    Err(e) => Err(FunctionError {
359                        function_name: "Max".to_string(),
360                        error: FunctionEvaluationError::IndexError(e),
361                    }),
362                }
363            }
364            VariableValue::Duration(d) => {
365                let value = d.duration().num_milliseconds() as f64;
366                accumulator.remove(value * -1.0).await;
367                match accumulator.get_head().await {
368                    Ok(Some(head)) => Ok(VariableValue::Duration(Duration::new(
369                        ChronoDuration::milliseconds((head * -1.0) as i64),
370                        0,
371                        0,
372                    ))),
373                    Ok(None) => Ok(VariableValue::Null),
374                    Err(e) => Err(FunctionError {
375                        function_name: "Max".to_string(),
376                        error: FunctionEvaluationError::IndexError(e),
377                    }),
378                }
379            }
380            VariableValue::Date(d) => {
381                // For date (Chrono::NaiveDate), we can store the number of days since the epoch
382                let reference_date = *temporal_constants::EPOCH_NAIVE_DATE;
383                let days_since_epoch = d.signed_duration_since(reference_date).num_days() as f64;
384                accumulator.remove(days_since_epoch * -1.0).await;
385                match accumulator.get_head().await {
386                    Ok(Some(head)) => Ok(VariableValue::Date(
387                        reference_date + ChronoDuration::days((head * -1.0) as i64),
388                    )),
389                    Ok(None) => Ok(VariableValue::Null),
390                    Err(e) => Err(FunctionError {
391                        function_name: "Max".to_string(),
392                        error: FunctionEvaluationError::IndexError(e),
393                    }),
394                }
395            }
396            VariableValue::LocalTime(t) => {
397                let reference_time = *temporal_constants::MIDNIGHT_NAIVE_TIME;
398                let duration_since_midnight =
399                    t.signed_duration_since(reference_time).num_milliseconds() as f64;
400                accumulator.remove(duration_since_midnight * -1.0).await;
401
402                match accumulator.get_head().await {
403                    Ok(Some(head)) => Ok(VariableValue::LocalTime(
404                        reference_time + ChronoDuration::milliseconds((head * -1.0) as i64),
405                    )),
406                    Ok(None) => Ok(VariableValue::Null),
407                    Err(e) => Err(FunctionError {
408                        function_name: "Max".to_string(),
409                        error: FunctionEvaluationError::IndexError(e),
410                    }),
411                }
412            }
413            VariableValue::LocalDateTime(dt) => {
414                let duration_since_epoch = dt.and_utc().timestamp_millis() as f64;
415                accumulator.remove(duration_since_epoch * -1.0).await;
416                match accumulator.get_head().await {
417                    Ok(Some(head)) => Ok(VariableValue::LocalDateTime(
418                        DateTime::from_timestamp_millis(head as i64 * -1.0 as i64)
419                            .unwrap_or_default()
420                            .naive_local(),
421                    )),
422                    Ok(None) => Ok(VariableValue::Null),
423                    Err(e) => Err(FunctionError {
424                        function_name: "Max".to_string(),
425                        error: FunctionEvaluationError::IndexError(e),
426                    }),
427                }
428            }
429            VariableValue::ZonedTime(t) => {
430                let epoch_date = *temporal_constants::EPOCH_NAIVE_DATE;
431                let epoch_datetime = match epoch_date
432                    .and_time(*t.time())
433                    .and_local_timezone(*t.offset())
434                {
435                    LocalResult::Single(dt) => dt,
436                    _ => {
437                        return Err(FunctionError {
438                            function_name: "Max".to_string(),
439                            error: FunctionEvaluationError::InvalidFormat {
440                                expected: temporal_constants::INVALID_ZONED_TIME_FORMAT_ERROR
441                                    .to_string(),
442                            },
443                        })
444                    }
445                };
446                let duration_since_epoch = epoch_datetime.timestamp_millis() as f64;
447                accumulator.remove(duration_since_epoch * -1.0).await;
448                match accumulator.get_head().await {
449                    Ok(Some(head)) => Ok(VariableValue::ZonedTime(ZonedTime::new(
450                        (epoch_datetime + ChronoDuration::milliseconds((head * -1.0) as i64))
451                            .time(),
452                        *temporal_constants::UTC_FIXED_OFFSET,
453                    ))),
454                    Ok(None) => Ok(VariableValue::Null),
455                    Err(e) => Err(FunctionError {
456                        function_name: "Max".to_string(),
457                        error: FunctionEvaluationError::IndexError(e),
458                    }),
459                }
460            }
461            VariableValue::Null => Ok(VariableValue::Null),
462            _ => Err(FunctionError {
463                function_name: "Max".to_string(),
464                error: FunctionEvaluationError::InvalidArgument(0),
465            }),
466        }
467    }
468
469    async fn snapshot(
470        &self,
471        _context: &ExpressionEvaluationContext,
472        args: Vec<VariableValue>,
473        accumulator: &Accumulator,
474    ) -> Result<VariableValue, FunctionError> {
475        if args.len() != 1 {
476            return Err(FunctionError {
477                function_name: "Max".to_string(),
478                error: FunctionEvaluationError::InvalidArgumentCount,
479            });
480        }
481
482        let accumulator = match accumulator {
483            Accumulator::LazySortedSet(accumulator) => accumulator,
484            _ => {
485                return Err(FunctionError {
486                    function_name: "Max".to_string(),
487                    error: FunctionEvaluationError::CorruptData,
488                })
489            }
490        };
491
492        let value = match accumulator.get_head().await {
493            Ok(Some(head)) => head * -1.0,
494            Ok(None) => return Ok(VariableValue::Null),
495            Err(e) => {
496                return Err(FunctionError {
497                    function_name: "Max".to_string(),
498                    error: FunctionEvaluationError::IndexError(e),
499                })
500            }
501        };
502
503        return match &args[0] {
504            VariableValue::Float(_) => Ok(VariableValue::Float(match Float::from_f64(value) {
505                Some(f) => f,
506                None => {
507                    return Err(FunctionError {
508                        function_name: "Max".to_string(),
509                        error: FunctionEvaluationError::OverflowError,
510                    })
511                }
512            })),
513            VariableValue::Integer(_) => Ok(VariableValue::Integer((value as i64).into())),
514            VariableValue::ZonedDateTime(_) => Ok(VariableValue::ZonedDateTime(
515                ZonedDateTime::from_epoch_millis(value as u64),
516            )),
517            VariableValue::Duration(_) => Ok(VariableValue::Duration(Duration::new(
518                ChronoDuration::milliseconds(value as i64),
519                0,
520                0,
521            ))),
522            VariableValue::Date(_) => {
523                let reference_date = *temporal_constants::EPOCH_NAIVE_DATE;
524                Ok(VariableValue::Date(
525                    reference_date + ChronoDuration::days(value as i64),
526                ))
527            }
528            VariableValue::LocalTime(_) => {
529                let reference_time = *temporal_constants::MIDNIGHT_NAIVE_TIME;
530                Ok(VariableValue::LocalTime(
531                    reference_time + ChronoDuration::milliseconds(value as i64),
532                ))
533            }
534            VariableValue::LocalDateTime(_) => Ok(VariableValue::LocalDateTime(
535                DateTime::from_timestamp_millis(value as i64)
536                    .unwrap_or_default()
537                    .naive_local(),
538            )),
539            VariableValue::ZonedTime(_) => {
540                let epoch_date = *temporal_constants::EPOCH_NAIVE_DATE;
541                let epoch_datetime = match epoch_date
542                    .and_time(*temporal_constants::MIDNIGHT_NAIVE_TIME)
543                    .and_local_timezone(*temporal_constants::UTC_FIXED_OFFSET)
544                {
545                    LocalResult::Single(dt) => dt,
546                    _ => {
547                        return Err(FunctionError {
548                            function_name: "Max".to_string(),
549                            error: FunctionEvaluationError::InvalidFormat {
550                                expected: temporal_constants::INVALID_ZONED_TIME_FORMAT_ERROR
551                                    .to_string(),
552                            },
553                        })
554                    }
555                };
556                Ok(VariableValue::ZonedTime(ZonedTime::new(
557                    (epoch_datetime + ChronoDuration::milliseconds(value as i64)).time(),
558                    *temporal_constants::UTC_FIXED_OFFSET,
559                )))
560            }
561            VariableValue::Null => Ok(VariableValue::Null),
562            _ => Err(FunctionError {
563                function_name: "Max".to_string(),
564                error: FunctionEvaluationError::InvalidArgument(0),
565            }),
566        };
567    }
568}
569
570impl Debug for Max {
571    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
572        write!(f, "Max")
573    }
574}