Skip to main content

drasi_core/evaluation/functions/numeric/
numeric_round.rs

1// Copyright 2024 The Drasi Authors.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use async_trait::async_trait;
16use drasi_query_ast::ast;
17
18use crate::evaluation::functions::ScalarFunction;
19use crate::evaluation::variable_value::float::Float;
20use crate::evaluation::variable_value::integer::Integer;
21use crate::evaluation::variable_value::VariableValue;
22use crate::evaluation::{ExpressionEvaluationContext, FunctionError, FunctionEvaluationError};
23use std::collections::HashSet;
24
25use round::{round_down, round_up};
26
/// Numeric rounding scalar function (`round`).
///
/// This is a field-less marker type. All of its behavior lives in its
/// `ScalarFunction` implementation, which accepts one to three arguments:
/// the value, an optional integer precision, and an optional rounding-mode
/// string.
#[derive(Debug)]
pub struct Round {}
29
30#[async_trait]
31impl ScalarFunction for Round {
32    async fn call(
33        &self,
34        _context: &ExpressionEvaluationContext,
35        expression: &ast::FunctionExpression,
36        args: Vec<VariableValue>,
37    ) -> Result<VariableValue, FunctionError> {
38        if args.is_empty() || args.len() > 3 {
39            return Err(FunctionError {
40                function_name: expression.name.to_string(),
41                error: FunctionEvaluationError::InvalidArgumentCount,
42            });
43        }
44        if args.contains(&VariableValue::Null) {
45            return Ok(VariableValue::Null);
46        }
47        match args.len() {
48            1 => {
49                match &args[0] {
50                    VariableValue::Null => Ok(VariableValue::Null),
51                    VariableValue::Integer(n) => Ok(VariableValue::Float(
52                        match Float::from_f64(match n.as_i64() {
53                            Some(i) => i as f64,
54                            None => {
55                                return Err(FunctionError {
56                                    function_name: expression.name.to_string(),
57                                    error: FunctionEvaluationError::OverflowError,
58                                })
59                            }
60                        }) {
61                            Some(f) => f,
62                            None => {
63                                return Err(FunctionError {
64                                    function_name: expression.name.to_string(),
65                                    error: FunctionEvaluationError::OverflowError,
66                                })
67                            }
68                        },
69                    )),
70                    VariableValue::Float(n) => {
71                        let input_as_f64 = match n.as_f64() {
72                            Some(f) => f,
73                            None => {
74                                return Err(FunctionError {
75                                    function_name: expression.name.to_string(),
76                                    error: FunctionEvaluationError::OverflowError,
77                                })
78                            }
79                        };
80                        if input_as_f64.fract() == -0.5 {
81                            //Cypher edge case
82                            return Ok(VariableValue::Float(
83                                match Float::from_f64(input_as_f64.trunc()) {
84                                    Some(f) => f,
85                                    None => {
86                                        return Err(FunctionError {
87                                            function_name: expression.name.to_string(),
88                                            error: FunctionEvaluationError::OverflowError,
89                                        })
90                                    }
91                                },
92                            ));
93                        }
94                        Ok(VariableValue::Float(
95                            match Float::from_f64(match n.as_f64() {
96                                Some(f) => f.round(),
97                                None => {
98                                    return Err(FunctionError {
99                                        function_name: expression.name.to_string(),
100                                        error: FunctionEvaluationError::OverflowError,
101                                    })
102                                }
103                            }) {
104                                Some(f) => f,
105                                None => {
106                                    return Err(FunctionError {
107                                        function_name: expression.name.to_string(),
108                                        error: FunctionEvaluationError::OverflowError,
109                                    })
110                                }
111                            },
112                        ))
113                    }
114                    _ => Err(FunctionError {
115                        function_name: expression.name.to_string(),
116                        error: FunctionEvaluationError::InvalidArgument(0),
117                    }),
118                }
119            }
120            2 => {
121                match (&args[0], &args[1]) {
122                    (VariableValue::Float(n), VariableValue::Integer(p)) => {
123                        let multiplier = 10.0_f64.powi(match p.as_i64() {
124                            Some(i) => i as i32,
125                            None => {
126                                return Err(FunctionError {
127                                    function_name: expression.name.to_string(),
128                                    error: FunctionEvaluationError::OverflowError,
129                                })
130                            }
131                        });
132                        //edge case
133
134                        if match n.as_f64() {
135                            Some(f) => f,
136                            None => {
137                                return Err(FunctionError {
138                                    function_name: expression.name.to_string(),
139                                    error: FunctionEvaluationError::OverflowError,
140                                })
141                            }
142                        } > f64::MAX / multiplier
143                        {
144                            return Ok(VariableValue::Float(n.clone()));
145                        }
146                        let intermediate_value = match n.as_f64() {
147                            Some(f) => f,
148                            None => {
149                                return Err(FunctionError {
150                                    function_name: expression.name.to_string(),
151                                    error: FunctionEvaluationError::OverflowError,
152                                })
153                            }
154                        } * multiplier;
155                        if intermediate_value.is_sign_negative()
156                            && intermediate_value.fract() == -0.5
157                        {
158                            //Cypher edge case
159                            let rounded_value = intermediate_value.trunc() / multiplier;
160                            return Ok(VariableValue::Float(
161                                match Float::from_f64(rounded_value) {
162                                    Some(f) => f,
163                                    None => {
164                                        return Err(FunctionError {
165                                            function_name: expression.name.to_string(),
166                                            error: FunctionEvaluationError::OverflowError,
167                                        })
168                                    }
169                                },
170                            ));
171                        }
172                        let rounded_value = (match n.as_f64() {
173                            Some(f) => f,
174                            None => {
175                                return Err(FunctionError {
176                                    function_name: expression.name.to_string(),
177                                    error: FunctionEvaluationError::OverflowError,
178                                })
179                            }
180                        } * multiplier)
181                            .round()
182                            / multiplier;
183                        Ok(VariableValue::Float(match Float::from_f64(rounded_value) {
184                            Some(f) => f,
185                            None => {
186                                return Err(FunctionError {
187                                    function_name: expression.name.to_string(),
188                                    error: FunctionEvaluationError::OverflowError,
189                                })
190                            }
191                        }))
192                    }
193                    (VariableValue::Integer(n), VariableValue::Integer(_p)) => {
194                        Ok(VariableValue::Integer(Integer::from(match n.as_i64() {
195                            Some(i) => i,
196                            None => {
197                                return Err(FunctionError {
198                                    function_name: expression.name.to_string(),
199                                    error: FunctionEvaluationError::OverflowError,
200                                })
201                            }
202                        })))
203                    }
204                    (VariableValue::Float(_n), _) => {
205                        return Err(FunctionError {
206                            function_name: expression.name.to_string(),
207                            error: FunctionEvaluationError::InvalidArgument(1),
208                        });
209                    }
210                    (VariableValue::Integer(_n), _) => {
211                        return Err(FunctionError {
212                            function_name: expression.name.to_string(),
213                            error: FunctionEvaluationError::InvalidArgument(1),
214                        });
215                    }
216                    _ => {
217                        return Err(FunctionError {
218                            function_name: expression.name.to_string(),
219                            error: FunctionEvaluationError::InvalidArgument(0),
220                        });
221                    }
222                }
223            }
224            3 => {
225                match (&args[0], &args[1], &args[2]) {
226                    (
227                        VariableValue::Float(n),
228                        VariableValue::Integer(p),
229                        VariableValue::String(m),
230                    ) => {
231                        let valid_modes: HashSet<String> = [
232                            "UP",
233                            "DOWN",
234                            "CEILING",
235                            "FLOOR",
236                            "HALF_UP",
237                            "HALF_DOWN",
238                            "HALF_EVEN",
239                        ]
240                        .iter()
241                        .map(|s| s.to_string())
242                        .collect();
243                        // let valid_keys: HashSet<String> = vec!["year", "month", "week", "day", "ordinalDay", "quarter", "dayOfWeek", "dayOfQuarter"].iter().map(|s| s.to_string()).collect();
244                        let mode = m.to_uppercase();
245                        if !valid_modes.contains(&mode) {
246                            return Err(FunctionError {
247                                function_name: expression.name.to_string(),
248                                error: FunctionEvaluationError::InvalidArgument(2),
249                            });
250                        }
251                        let is_positive = match n.as_f64() {
252                            Some(f) => f,
253                            None => {
254                                return Err(FunctionError {
255                                    function_name: expression.name.to_string(),
256                                    error: FunctionEvaluationError::OverflowError,
257                                })
258                            }
259                        }
260                        .is_sign_positive();
261                        match mode.as_str() {
262                            "UP" => {
263                                if is_positive {
264                                    let result = round_up(
265                                        match n.as_f64() {
266                                            Some(f) => f,
267                                            None => {
268                                                return Err(FunctionError {
269                                                    function_name: expression.name.to_string(),
270                                                    error: FunctionEvaluationError::OverflowError,
271                                                })
272                                            }
273                                        },
274                                        match p.as_i64() {
275                                            Some(i) => i as i32,
276                                            None => {
277                                                return Err(FunctionError {
278                                                    function_name: expression.name.to_string(),
279                                                    error: FunctionEvaluationError::OverflowError,
280                                                })
281                                            }
282                                        },
283                                    );
284                                    return Ok(VariableValue::Float(
285                                        match Float::from_f64(result) {
286                                            Some(f) => f,
287                                            None => {
288                                                return Err(FunctionError {
289                                                    function_name: expression.name.to_string(),
290                                                    error: FunctionEvaluationError::OverflowError,
291                                                })
292                                            }
293                                        },
294                                    ));
295                                } else {
296                                    //Cypher being weird :)
297                                    let result = round_down(
298                                        match n.as_f64() {
299                                            Some(f) => f,
300                                            None => {
301                                                return Err(FunctionError {
302                                                    function_name: expression.name.to_string(),
303                                                    error: FunctionEvaluationError::OverflowError,
304                                                })
305                                            }
306                                        },
307                                        match p.as_i64() {
308                                            Some(i) => i as i32,
309                                            None => {
310                                                return Err(FunctionError {
311                                                    function_name: expression.name.to_string(),
312                                                    error: FunctionEvaluationError::OverflowError,
313                                                })
314                                            }
315                                        },
316                                    );
317                                    return Ok(VariableValue::Float(
318                                        match Float::from_f64(result) {
319                                            Some(f) => f,
320                                            None => {
321                                                return Err(FunctionError {
322                                                    function_name: expression.name.to_string(),
323                                                    error: FunctionEvaluationError::OverflowError,
324                                                })
325                                            }
326                                        },
327                                    ));
328                                }
329                            }
330                            "DOWN" => {
331                                if is_positive {
332                                    let result = round_down(
333                                        match n.as_f64() {
334                                            Some(f) => f,
335                                            None => {
336                                                return Err(FunctionError {
337                                                    function_name: expression.name.to_string(),
338                                                    error: FunctionEvaluationError::OverflowError,
339                                                })
340                                            }
341                                        },
342                                        match p.as_i64() {
343                                            Some(i) => i as i32,
344                                            None => {
345                                                return Err(FunctionError {
346                                                    function_name: expression.name.to_string(),
347                                                    error: FunctionEvaluationError::OverflowError,
348                                                })
349                                            }
350                                        },
351                                    );
352                                    return Ok(VariableValue::Float(
353                                        match Float::from_f64(result) {
354                                            Some(f) => f,
355                                            None => {
356                                                return Err(FunctionError {
357                                                    function_name: expression.name.to_string(),
358                                                    error: FunctionEvaluationError::OverflowError,
359                                                })
360                                            }
361                                        },
362                                    ));
363                                } else {
364                                    let result = round_up(
365                                        match n.as_f64() {
366                                            Some(f) => f,
367                                            None => {
368                                                return Err(FunctionError {
369                                                    function_name: expression.name.to_string(),
370                                                    error: FunctionEvaluationError::OverflowError,
371                                                })
372                                            }
373                                        },
374                                        match p.as_i64() {
375                                            Some(i) => i as i32,
376                                            None => {
377                                                return Err(FunctionError {
378                                                    function_name: expression.name.to_string(),
379                                                    error: FunctionEvaluationError::OverflowError,
380                                                })
381                                            }
382                                        },
383                                    );
384                                    return Ok(VariableValue::Float(
385                                        match Float::from_f64(result) {
386                                            Some(f) => f,
387                                            None => {
388                                                return Err(FunctionError {
389                                                    function_name: expression.name.to_string(),
390                                                    error: FunctionEvaluationError::OverflowError,
391                                                })
392                                            }
393                                        },
394                                    ));
395                                }
396                            }
397                            "CEILING" => {
398                                let result = round_up(
399                                    match n.as_f64() {
400                                        Some(f) => f,
401                                        None => {
402                                            return Err(FunctionError {
403                                                function_name: expression.name.to_string(),
404                                                error: FunctionEvaluationError::OverflowError,
405                                            })
406                                        }
407                                    },
408                                    match p.as_i64() {
409                                        Some(i) => i as i32,
410                                        None => {
411                                            return Err(FunctionError {
412                                                function_name: expression.name.to_string(),
413                                                error: FunctionEvaluationError::OverflowError,
414                                            })
415                                        }
416                                    },
417                                );
418                                return Ok(VariableValue::Float(match Float::from_f64(result) {
419                                    Some(f) => f,
420                                    None => {
421                                        return Err(FunctionError {
422                                            function_name: expression.name.to_string(),
423                                            error: FunctionEvaluationError::OverflowError,
424                                        })
425                                    }
426                                }));
427                            }
428                            "FLOOR" => {
429                                let result = round_down(
430                                    match n.as_f64() {
431                                        Some(f) => f,
432                                        None => {
433                                            return Err(FunctionError {
434                                                function_name: expression.name.to_string(),
435                                                error: FunctionEvaluationError::OverflowError,
436                                            })
437                                        }
438                                    },
439                                    match p.as_i64() {
440                                        Some(i) => i as i32,
441                                        None => {
442                                            return Err(FunctionError {
443                                                function_name: expression.name.to_string(),
444                                                error: FunctionEvaluationError::OverflowError,
445                                            })
446                                        }
447                                    },
448                                );
449                                return Ok(VariableValue::Float(match Float::from_f64(result) {
450                                    Some(f) => f,
451                                    None => {
452                                        return Err(FunctionError {
453                                            function_name: expression.name.to_string(),
454                                            error: FunctionEvaluationError::OverflowError,
455                                        })
456                                    }
457                                }));
458                            }
459                            "HALF_UP" => {
460                                let multiplier = 10.0_f64.powi(match p.as_i64() {
461                                    Some(i) => i as i32,
462                                    None => {
463                                        return Err(FunctionError {
464                                            function_name: expression.name.to_string(),
465                                            error: FunctionEvaluationError::OverflowError,
466                                        })
467                                    }
468                                });
469                                if match n.as_f64() {
470                                    Some(f) => f,
471                                    None => {
472                                        return Err(FunctionError {
473                                            function_name: expression.name.to_string(),
474                                            error: FunctionEvaluationError::OverflowError,
475                                        })
476                                    }
477                                } > f64::MAX / multiplier
478                                {
479                                    return Ok(VariableValue::Float(n.clone()));
480                                }
481                                let intermediate_value = match n.as_f64() {
482                                    Some(f) => f,
483                                    None => {
484                                        return Err(FunctionError {
485                                            function_name: expression.name.to_string(),
486                                            error: FunctionEvaluationError::OverflowError,
487                                        })
488                                    }
489                                } * multiplier;
490                                if intermediate_value.fract() == 0.5 {
491                                    let rounded_value =
492                                        (intermediate_value.trunc() + 1.0) / multiplier;
493                                    return Ok(VariableValue::Float(
494                                        match Float::from_f64(rounded_value) {
495                                            Some(f) => f,
496                                            None => {
497                                                return Err(FunctionError {
498                                                    function_name: expression.name.to_string(),
499                                                    error: FunctionEvaluationError::OverflowError,
500                                                })
501                                            }
502                                        },
503                                    ));
504                                } else if intermediate_value.fract() == -0.5 {
505                                    let rounded_value =
506                                        (intermediate_value.trunc() - 1.0) / multiplier;
507                                    return Ok(VariableValue::Float(
508                                        match Float::from_f64(rounded_value) {
509                                            Some(f) => f,
510                                            None => {
511                                                return Err(FunctionError {
512                                                    function_name: expression.name.to_string(),
513                                                    error: FunctionEvaluationError::OverflowError,
514                                                })
515                                            }
516                                        },
517                                    ));
518                                }
519                                let rounded_value = (match n.as_f64() {
520                                    Some(f) => f,
521                                    None => {
522                                        return Err(FunctionError {
523                                            function_name: expression.name.to_string(),
524                                            error: FunctionEvaluationError::OverflowError,
525                                        })
526                                    }
527                                } * multiplier)
528                                    .round()
529                                    / multiplier;
530                                return Ok(VariableValue::Float(
531                                    match Float::from_f64(rounded_value) {
532                                        Some(f) => f,
533                                        None => {
534                                            return Err(FunctionError {
535                                                function_name: expression.name.to_string(),
536                                                error: FunctionEvaluationError::OverflowError,
537                                            })
538                                        }
539                                    },
540                                ));
541                            }
542                            "HALF_DOWN" => {
543                                let multiplier = 10.0_f64.powi(match p.as_i64() {
544                                    Some(i) => i as i32,
545                                    None => {
546                                        return Err(FunctionError {
547                                            function_name: expression.name.to_string(),
548                                            error: FunctionEvaluationError::OverflowError,
549                                        })
550                                    }
551                                });
552                                if match n.as_f64() {
553                                    Some(f) => f,
554                                    None => {
555                                        return Err(FunctionError {
556                                            function_name: expression.name.to_string(),
557                                            error: FunctionEvaluationError::OverflowError,
558                                        })
559                                    }
560                                } > f64::MAX / multiplier
561                                {
562                                    return Ok(VariableValue::Float(n.clone()));
563                                }
564                                let intermediate_value = match n.as_f64() {
565                                    Some(f) => f,
566                                    None => {
567                                        return Err(FunctionError {
568                                            function_name: expression.name.to_string(),
569                                            error: FunctionEvaluationError::OverflowError,
570                                        })
571                                    }
572                                } * multiplier;
573                                if intermediate_value.fract() == 0.5
574                                    || intermediate_value.fract() == -0.5
575                                {
576                                    let rounded_value = (intermediate_value.trunc()) / multiplier;
577                                    return Ok(VariableValue::Float(
578                                        match Float::from_f64(rounded_value) {
579                                            Some(f) => f,
580                                            None => {
581                                                return Err(FunctionError {
582                                                    function_name: expression.name.to_string(),
583                                                    error: FunctionEvaluationError::OverflowError,
584                                                })
585                                            }
586                                        },
587                                    ));
588                                }
589                                let rounded_value = (match n.as_f64() {
590                                    Some(f) => f,
591                                    None => {
592                                        return Err(FunctionError {
593                                            function_name: expression.name.to_string(),
594                                            error: FunctionEvaluationError::OverflowError,
595                                        })
596                                    }
597                                } * multiplier)
598                                    .round()
599                                    / multiplier;
600                                return Ok(VariableValue::Float(
601                                    match Float::from_f64(rounded_value) {
602                                        Some(f) => f,
603                                        None => {
604                                            return Err(FunctionError {
605                                                function_name: expression.name.to_string(),
606                                                error: FunctionEvaluationError::OverflowError,
607                                            })
608                                        }
609                                    },
610                                ));
611                            }
612                            "HALF_EVEN" => {
613                                let multiplier = 10.0_f64.powi(match p.as_i64() {
614                                    Some(i) => i as i32,
615                                    None => {
616                                        return Err(FunctionError {
617                                            function_name: expression.name.to_string(),
618                                            error: FunctionEvaluationError::OverflowError,
619                                        })
620                                    }
621                                });
622                                if match n.as_f64() {
623                                    Some(f) => f,
624                                    None => {
625                                        return Err(FunctionError {
626                                            function_name: expression.name.to_string(),
627                                            error: FunctionEvaluationError::OverflowError,
628                                        })
629                                    }
630                                } > f64::MAX / multiplier
631                                {
632                                    return Ok(VariableValue::Float(n.clone()));
633                                }
634                                let intermediate_value = match n.as_f64() {
635                                    Some(f) => f,
636                                    None => {
637                                        return Err(FunctionError {
638                                            function_name: expression.name.to_string(),
639                                            error: FunctionEvaluationError::OverflowError,
640                                        })
641                                    }
642                                } * multiplier;
643                                if intermediate_value.fract() == 0.5 {
644                                    if intermediate_value.trunc() % 2.0 == 0.0 {
645                                        let rounded_value =
646                                            (intermediate_value.trunc()) / multiplier;
647                                        return Ok(VariableValue::Float(
648                                            match Float::from_f64(rounded_value) {
649                                                Some(f) => f,
650                                                None => {
651                                                    return Err(FunctionError {
652                                                        function_name: expression.name.to_string(),
653                                                        error:
654                                                            FunctionEvaluationError::OverflowError,
655                                                    })
656                                                }
657                                            },
658                                        ));
659                                    } else {
660                                        let rounded_value =
661                                            (intermediate_value.trunc() + 1.0) / multiplier;
662                                        return Ok(VariableValue::Float(
663                                            match Float::from_f64(rounded_value) {
664                                                Some(f) => f,
665                                                None => {
666                                                    return Err(FunctionError {
667                                                        function_name: expression.name.to_string(),
668                                                        error:
669                                                            FunctionEvaluationError::OverflowError,
670                                                    })
671                                                }
672                                            },
673                                        ));
674                                    }
675                                } else if intermediate_value.fract() == -0.5 {
676                                    if intermediate_value.trunc() % 2.0 == 0.0 {
677                                        let rounded_value =
678                                            (intermediate_value.trunc()) / multiplier;
679                                        return Ok(VariableValue::Float(
680                                            match Float::from_f64(rounded_value) {
681                                                Some(f) => f,
682                                                None => {
683                                                    return Err(FunctionError {
684                                                        function_name: expression.name.to_string(),
685                                                        error:
686                                                            FunctionEvaluationError::OverflowError,
687                                                    })
688                                                }
689                                            },
690                                        ));
691                                    } else {
692                                        let rounded_value =
693                                            (intermediate_value.trunc() - 1.0) / multiplier;
694                                        return Ok(VariableValue::Float(
695                                            match Float::from_f64(rounded_value) {
696                                                Some(f) => f,
697                                                None => {
698                                                    return Err(FunctionError {
699                                                        function_name: expression.name.to_string(),
700                                                        error:
701                                                            FunctionEvaluationError::OverflowError,
702                                                    })
703                                                }
704                                            },
705                                        ));
706                                    }
707                                }
708                                let rounded_value = (match n.as_f64() {
709                                    Some(f) => f,
710                                    None => {
711                                        return Err(FunctionError {
712                                            function_name: expression.name.to_string(),
713                                            error: FunctionEvaluationError::OverflowError,
714                                        })
715                                    }
716                                } * multiplier)
717                                    .round()
718                                    / multiplier;
719                                return Ok(VariableValue::Float(
720                                    match Float::from_f64(rounded_value) {
721                                        Some(f) => f,
722                                        None => {
723                                            return Err(FunctionError {
724                                                function_name: expression.name.to_string(),
725                                                error: FunctionEvaluationError::OverflowError,
726                                            })
727                                        }
728                                    },
729                                ));
730                            }
731                            _ => {
732                                return Err(FunctionError {
733                                    function_name: expression.name.to_string(),
734                                    error: FunctionEvaluationError::InvalidArgument(2),
735                                })
736                            }
737                        }
738                    }
739                    (
740                        VariableValue::Integer(n),
741                        VariableValue::Integer(_p),
742                        VariableValue::Integer(_m),
743                    ) => Ok(VariableValue::Integer(Integer::from(match n.as_i64() {
744                        Some(i) => i,
745                        None => {
746                            return Err(FunctionError {
747                                function_name: expression.name.to_string(),
748                                error: FunctionEvaluationError::OverflowError,
749                            })
750                        }
751                    }))),
752                    _ => Err(FunctionError {
753                        function_name: expression.name.to_string(),
754                        error: FunctionEvaluationError::InvalidArgument(2),
755                    }),
756                }
757            }
758            _ => unreachable!(),
759        }
760    }
761}