drasi_core/evaluation/functions/numeric/numeric_round.rs
1// Copyright 2024 The Drasi Authors.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use async_trait::async_trait;
16use drasi_query_ast::ast;
17
18use crate::evaluation::functions::ScalarFunction;
19use crate::evaluation::variable_value::float::Float;
20use crate::evaluation::variable_value::integer::Integer;
21use crate::evaluation::variable_value::VariableValue;
22use crate::evaluation::{ExpressionEvaluationContext, FunctionError, FunctionEvaluationError};
23use std::collections::HashSet;
24
25use round::{round_down, round_up};
26
/// Cypher `round()` scalar function.
///
/// The type carries no state; all behavior lives in its `ScalarFunction`
/// implementation. `Default`/`Clone`/`Copy` are derived so it can be
/// constructed and shared freely when registering functions.
#[derive(Debug, Default, Clone, Copy)]
pub struct Round {}
29
30#[async_trait]
31impl ScalarFunction for Round {
32 async fn call(
33 &self,
34 _context: &ExpressionEvaluationContext,
35 expression: &ast::FunctionExpression,
36 args: Vec<VariableValue>,
37 ) -> Result<VariableValue, FunctionError> {
38 if args.is_empty() || args.len() > 3 {
39 return Err(FunctionError {
40 function_name: expression.name.to_string(),
41 error: FunctionEvaluationError::InvalidArgumentCount,
42 });
43 }
44 if args.contains(&VariableValue::Null) {
45 return Ok(VariableValue::Null);
46 }
47 match args.len() {
48 1 => {
49 match &args[0] {
50 VariableValue::Null => Ok(VariableValue::Null),
51 VariableValue::Integer(n) => Ok(VariableValue::Float(
52 match Float::from_f64(match n.as_i64() {
53 Some(i) => i as f64,
54 None => {
55 return Err(FunctionError {
56 function_name: expression.name.to_string(),
57 error: FunctionEvaluationError::OverflowError,
58 })
59 }
60 }) {
61 Some(f) => f,
62 None => {
63 return Err(FunctionError {
64 function_name: expression.name.to_string(),
65 error: FunctionEvaluationError::OverflowError,
66 })
67 }
68 },
69 )),
70 VariableValue::Float(n) => {
71 let input_as_f64 = match n.as_f64() {
72 Some(f) => f,
73 None => {
74 return Err(FunctionError {
75 function_name: expression.name.to_string(),
76 error: FunctionEvaluationError::OverflowError,
77 })
78 }
79 };
80 if input_as_f64.fract() == -0.5 {
81 //Cypher edge case
82 return Ok(VariableValue::Float(
83 match Float::from_f64(input_as_f64.trunc()) {
84 Some(f) => f,
85 None => {
86 return Err(FunctionError {
87 function_name: expression.name.to_string(),
88 error: FunctionEvaluationError::OverflowError,
89 })
90 }
91 },
92 ));
93 }
94 Ok(VariableValue::Float(
95 match Float::from_f64(match n.as_f64() {
96 Some(f) => f.round(),
97 None => {
98 return Err(FunctionError {
99 function_name: expression.name.to_string(),
100 error: FunctionEvaluationError::OverflowError,
101 })
102 }
103 }) {
104 Some(f) => f,
105 None => {
106 return Err(FunctionError {
107 function_name: expression.name.to_string(),
108 error: FunctionEvaluationError::OverflowError,
109 })
110 }
111 },
112 ))
113 }
114 _ => Err(FunctionError {
115 function_name: expression.name.to_string(),
116 error: FunctionEvaluationError::InvalidArgument(0),
117 }),
118 }
119 }
120 2 => {
121 match (&args[0], &args[1]) {
122 (VariableValue::Float(n), VariableValue::Integer(p)) => {
123 let multiplier = 10.0_f64.powi(match p.as_i64() {
124 Some(i) => i as i32,
125 None => {
126 return Err(FunctionError {
127 function_name: expression.name.to_string(),
128 error: FunctionEvaluationError::OverflowError,
129 })
130 }
131 });
132 //edge case
133
134 if match n.as_f64() {
135 Some(f) => f,
136 None => {
137 return Err(FunctionError {
138 function_name: expression.name.to_string(),
139 error: FunctionEvaluationError::OverflowError,
140 })
141 }
142 } > f64::MAX / multiplier
143 {
144 return Ok(VariableValue::Float(n.clone()));
145 }
146 let intermediate_value = match n.as_f64() {
147 Some(f) => f,
148 None => {
149 return Err(FunctionError {
150 function_name: expression.name.to_string(),
151 error: FunctionEvaluationError::OverflowError,
152 })
153 }
154 } * multiplier;
155 if intermediate_value.is_sign_negative()
156 && intermediate_value.fract() == -0.5
157 {
158 //Cypher edge case
159 let rounded_value = intermediate_value.trunc() / multiplier;
160 return Ok(VariableValue::Float(
161 match Float::from_f64(rounded_value) {
162 Some(f) => f,
163 None => {
164 return Err(FunctionError {
165 function_name: expression.name.to_string(),
166 error: FunctionEvaluationError::OverflowError,
167 })
168 }
169 },
170 ));
171 }
172 let rounded_value = (match n.as_f64() {
173 Some(f) => f,
174 None => {
175 return Err(FunctionError {
176 function_name: expression.name.to_string(),
177 error: FunctionEvaluationError::OverflowError,
178 })
179 }
180 } * multiplier)
181 .round()
182 / multiplier;
183 Ok(VariableValue::Float(match Float::from_f64(rounded_value) {
184 Some(f) => f,
185 None => {
186 return Err(FunctionError {
187 function_name: expression.name.to_string(),
188 error: FunctionEvaluationError::OverflowError,
189 })
190 }
191 }))
192 }
193 (VariableValue::Integer(n), VariableValue::Integer(_p)) => {
194 Ok(VariableValue::Integer(Integer::from(match n.as_i64() {
195 Some(i) => i,
196 None => {
197 return Err(FunctionError {
198 function_name: expression.name.to_string(),
199 error: FunctionEvaluationError::OverflowError,
200 })
201 }
202 })))
203 }
204 (VariableValue::Float(_n), _) => {
205 return Err(FunctionError {
206 function_name: expression.name.to_string(),
207 error: FunctionEvaluationError::InvalidArgument(1),
208 });
209 }
210 (VariableValue::Integer(_n), _) => {
211 return Err(FunctionError {
212 function_name: expression.name.to_string(),
213 error: FunctionEvaluationError::InvalidArgument(1),
214 });
215 }
216 _ => {
217 return Err(FunctionError {
218 function_name: expression.name.to_string(),
219 error: FunctionEvaluationError::InvalidArgument(0),
220 });
221 }
222 }
223 }
224 3 => {
225 match (&args[0], &args[1], &args[2]) {
226 (
227 VariableValue::Float(n),
228 VariableValue::Integer(p),
229 VariableValue::String(m),
230 ) => {
231 let valid_modes: HashSet<String> = [
232 "UP",
233 "DOWN",
234 "CEILING",
235 "FLOOR",
236 "HALF_UP",
237 "HALF_DOWN",
238 "HALF_EVEN",
239 ]
240 .iter()
241 .map(|s| s.to_string())
242 .collect();
243 // let valid_keys: HashSet<String> = vec!["year", "month", "week", "day", "ordinalDay", "quarter", "dayOfWeek", "dayOfQuarter"].iter().map(|s| s.to_string()).collect();
244 let mode = m.to_uppercase();
245 if !valid_modes.contains(&mode) {
246 return Err(FunctionError {
247 function_name: expression.name.to_string(),
248 error: FunctionEvaluationError::InvalidArgument(2),
249 });
250 }
251 let is_positive = match n.as_f64() {
252 Some(f) => f,
253 None => {
254 return Err(FunctionError {
255 function_name: expression.name.to_string(),
256 error: FunctionEvaluationError::OverflowError,
257 })
258 }
259 }
260 .is_sign_positive();
261 match mode.as_str() {
262 "UP" => {
263 if is_positive {
264 let result = round_up(
265 match n.as_f64() {
266 Some(f) => f,
267 None => {
268 return Err(FunctionError {
269 function_name: expression.name.to_string(),
270 error: FunctionEvaluationError::OverflowError,
271 })
272 }
273 },
274 match p.as_i64() {
275 Some(i) => i as i32,
276 None => {
277 return Err(FunctionError {
278 function_name: expression.name.to_string(),
279 error: FunctionEvaluationError::OverflowError,
280 })
281 }
282 },
283 );
284 return Ok(VariableValue::Float(
285 match Float::from_f64(result) {
286 Some(f) => f,
287 None => {
288 return Err(FunctionError {
289 function_name: expression.name.to_string(),
290 error: FunctionEvaluationError::OverflowError,
291 })
292 }
293 },
294 ));
295 } else {
296 //Cypher being weird :)
297 let result = round_down(
298 match n.as_f64() {
299 Some(f) => f,
300 None => {
301 return Err(FunctionError {
302 function_name: expression.name.to_string(),
303 error: FunctionEvaluationError::OverflowError,
304 })
305 }
306 },
307 match p.as_i64() {
308 Some(i) => i as i32,
309 None => {
310 return Err(FunctionError {
311 function_name: expression.name.to_string(),
312 error: FunctionEvaluationError::OverflowError,
313 })
314 }
315 },
316 );
317 return Ok(VariableValue::Float(
318 match Float::from_f64(result) {
319 Some(f) => f,
320 None => {
321 return Err(FunctionError {
322 function_name: expression.name.to_string(),
323 error: FunctionEvaluationError::OverflowError,
324 })
325 }
326 },
327 ));
328 }
329 }
330 "DOWN" => {
331 if is_positive {
332 let result = round_down(
333 match n.as_f64() {
334 Some(f) => f,
335 None => {
336 return Err(FunctionError {
337 function_name: expression.name.to_string(),
338 error: FunctionEvaluationError::OverflowError,
339 })
340 }
341 },
342 match p.as_i64() {
343 Some(i) => i as i32,
344 None => {
345 return Err(FunctionError {
346 function_name: expression.name.to_string(),
347 error: FunctionEvaluationError::OverflowError,
348 })
349 }
350 },
351 );
352 return Ok(VariableValue::Float(
353 match Float::from_f64(result) {
354 Some(f) => f,
355 None => {
356 return Err(FunctionError {
357 function_name: expression.name.to_string(),
358 error: FunctionEvaluationError::OverflowError,
359 })
360 }
361 },
362 ));
363 } else {
364 let result = round_up(
365 match n.as_f64() {
366 Some(f) => f,
367 None => {
368 return Err(FunctionError {
369 function_name: expression.name.to_string(),
370 error: FunctionEvaluationError::OverflowError,
371 })
372 }
373 },
374 match p.as_i64() {
375 Some(i) => i as i32,
376 None => {
377 return Err(FunctionError {
378 function_name: expression.name.to_string(),
379 error: FunctionEvaluationError::OverflowError,
380 })
381 }
382 },
383 );
384 return Ok(VariableValue::Float(
385 match Float::from_f64(result) {
386 Some(f) => f,
387 None => {
388 return Err(FunctionError {
389 function_name: expression.name.to_string(),
390 error: FunctionEvaluationError::OverflowError,
391 })
392 }
393 },
394 ));
395 }
396 }
397 "CEILING" => {
398 let result = round_up(
399 match n.as_f64() {
400 Some(f) => f,
401 None => {
402 return Err(FunctionError {
403 function_name: expression.name.to_string(),
404 error: FunctionEvaluationError::OverflowError,
405 })
406 }
407 },
408 match p.as_i64() {
409 Some(i) => i as i32,
410 None => {
411 return Err(FunctionError {
412 function_name: expression.name.to_string(),
413 error: FunctionEvaluationError::OverflowError,
414 })
415 }
416 },
417 );
418 return Ok(VariableValue::Float(match Float::from_f64(result) {
419 Some(f) => f,
420 None => {
421 return Err(FunctionError {
422 function_name: expression.name.to_string(),
423 error: FunctionEvaluationError::OverflowError,
424 })
425 }
426 }));
427 }
428 "FLOOR" => {
429 let result = round_down(
430 match n.as_f64() {
431 Some(f) => f,
432 None => {
433 return Err(FunctionError {
434 function_name: expression.name.to_string(),
435 error: FunctionEvaluationError::OverflowError,
436 })
437 }
438 },
439 match p.as_i64() {
440 Some(i) => i as i32,
441 None => {
442 return Err(FunctionError {
443 function_name: expression.name.to_string(),
444 error: FunctionEvaluationError::OverflowError,
445 })
446 }
447 },
448 );
449 return Ok(VariableValue::Float(match Float::from_f64(result) {
450 Some(f) => f,
451 None => {
452 return Err(FunctionError {
453 function_name: expression.name.to_string(),
454 error: FunctionEvaluationError::OverflowError,
455 })
456 }
457 }));
458 }
459 "HALF_UP" => {
460 let multiplier = 10.0_f64.powi(match p.as_i64() {
461 Some(i) => i as i32,
462 None => {
463 return Err(FunctionError {
464 function_name: expression.name.to_string(),
465 error: FunctionEvaluationError::OverflowError,
466 })
467 }
468 });
469 if match n.as_f64() {
470 Some(f) => f,
471 None => {
472 return Err(FunctionError {
473 function_name: expression.name.to_string(),
474 error: FunctionEvaluationError::OverflowError,
475 })
476 }
477 } > f64::MAX / multiplier
478 {
479 return Ok(VariableValue::Float(n.clone()));
480 }
481 let intermediate_value = match n.as_f64() {
482 Some(f) => f,
483 None => {
484 return Err(FunctionError {
485 function_name: expression.name.to_string(),
486 error: FunctionEvaluationError::OverflowError,
487 })
488 }
489 } * multiplier;
490 if intermediate_value.fract() == 0.5 {
491 let rounded_value =
492 (intermediate_value.trunc() + 1.0) / multiplier;
493 return Ok(VariableValue::Float(
494 match Float::from_f64(rounded_value) {
495 Some(f) => f,
496 None => {
497 return Err(FunctionError {
498 function_name: expression.name.to_string(),
499 error: FunctionEvaluationError::OverflowError,
500 })
501 }
502 },
503 ));
504 } else if intermediate_value.fract() == -0.5 {
505 let rounded_value =
506 (intermediate_value.trunc() - 1.0) / multiplier;
507 return Ok(VariableValue::Float(
508 match Float::from_f64(rounded_value) {
509 Some(f) => f,
510 None => {
511 return Err(FunctionError {
512 function_name: expression.name.to_string(),
513 error: FunctionEvaluationError::OverflowError,
514 })
515 }
516 },
517 ));
518 }
519 let rounded_value = (match n.as_f64() {
520 Some(f) => f,
521 None => {
522 return Err(FunctionError {
523 function_name: expression.name.to_string(),
524 error: FunctionEvaluationError::OverflowError,
525 })
526 }
527 } * multiplier)
528 .round()
529 / multiplier;
530 return Ok(VariableValue::Float(
531 match Float::from_f64(rounded_value) {
532 Some(f) => f,
533 None => {
534 return Err(FunctionError {
535 function_name: expression.name.to_string(),
536 error: FunctionEvaluationError::OverflowError,
537 })
538 }
539 },
540 ));
541 }
542 "HALF_DOWN" => {
543 let multiplier = 10.0_f64.powi(match p.as_i64() {
544 Some(i) => i as i32,
545 None => {
546 return Err(FunctionError {
547 function_name: expression.name.to_string(),
548 error: FunctionEvaluationError::OverflowError,
549 })
550 }
551 });
552 if match n.as_f64() {
553 Some(f) => f,
554 None => {
555 return Err(FunctionError {
556 function_name: expression.name.to_string(),
557 error: FunctionEvaluationError::OverflowError,
558 })
559 }
560 } > f64::MAX / multiplier
561 {
562 return Ok(VariableValue::Float(n.clone()));
563 }
564 let intermediate_value = match n.as_f64() {
565 Some(f) => f,
566 None => {
567 return Err(FunctionError {
568 function_name: expression.name.to_string(),
569 error: FunctionEvaluationError::OverflowError,
570 })
571 }
572 } * multiplier;
573 if intermediate_value.fract() == 0.5
574 || intermediate_value.fract() == -0.5
575 {
576 let rounded_value = (intermediate_value.trunc()) / multiplier;
577 return Ok(VariableValue::Float(
578 match Float::from_f64(rounded_value) {
579 Some(f) => f,
580 None => {
581 return Err(FunctionError {
582 function_name: expression.name.to_string(),
583 error: FunctionEvaluationError::OverflowError,
584 })
585 }
586 },
587 ));
588 }
589 let rounded_value = (match n.as_f64() {
590 Some(f) => f,
591 None => {
592 return Err(FunctionError {
593 function_name: expression.name.to_string(),
594 error: FunctionEvaluationError::OverflowError,
595 })
596 }
597 } * multiplier)
598 .round()
599 / multiplier;
600 return Ok(VariableValue::Float(
601 match Float::from_f64(rounded_value) {
602 Some(f) => f,
603 None => {
604 return Err(FunctionError {
605 function_name: expression.name.to_string(),
606 error: FunctionEvaluationError::OverflowError,
607 })
608 }
609 },
610 ));
611 }
612 "HALF_EVEN" => {
613 let multiplier = 10.0_f64.powi(match p.as_i64() {
614 Some(i) => i as i32,
615 None => {
616 return Err(FunctionError {
617 function_name: expression.name.to_string(),
618 error: FunctionEvaluationError::OverflowError,
619 })
620 }
621 });
622 if match n.as_f64() {
623 Some(f) => f,
624 None => {
625 return Err(FunctionError {
626 function_name: expression.name.to_string(),
627 error: FunctionEvaluationError::OverflowError,
628 })
629 }
630 } > f64::MAX / multiplier
631 {
632 return Ok(VariableValue::Float(n.clone()));
633 }
634 let intermediate_value = match n.as_f64() {
635 Some(f) => f,
636 None => {
637 return Err(FunctionError {
638 function_name: expression.name.to_string(),
639 error: FunctionEvaluationError::OverflowError,
640 })
641 }
642 } * multiplier;
643 if intermediate_value.fract() == 0.5 {
644 if intermediate_value.trunc() % 2.0 == 0.0 {
645 let rounded_value =
646 (intermediate_value.trunc()) / multiplier;
647 return Ok(VariableValue::Float(
648 match Float::from_f64(rounded_value) {
649 Some(f) => f,
650 None => {
651 return Err(FunctionError {
652 function_name: expression.name.to_string(),
653 error:
654 FunctionEvaluationError::OverflowError,
655 })
656 }
657 },
658 ));
659 } else {
660 let rounded_value =
661 (intermediate_value.trunc() + 1.0) / multiplier;
662 return Ok(VariableValue::Float(
663 match Float::from_f64(rounded_value) {
664 Some(f) => f,
665 None => {
666 return Err(FunctionError {
667 function_name: expression.name.to_string(),
668 error:
669 FunctionEvaluationError::OverflowError,
670 })
671 }
672 },
673 ));
674 }
675 } else if intermediate_value.fract() == -0.5 {
676 if intermediate_value.trunc() % 2.0 == 0.0 {
677 let rounded_value =
678 (intermediate_value.trunc()) / multiplier;
679 return Ok(VariableValue::Float(
680 match Float::from_f64(rounded_value) {
681 Some(f) => f,
682 None => {
683 return Err(FunctionError {
684 function_name: expression.name.to_string(),
685 error:
686 FunctionEvaluationError::OverflowError,
687 })
688 }
689 },
690 ));
691 } else {
692 let rounded_value =
693 (intermediate_value.trunc() - 1.0) / multiplier;
694 return Ok(VariableValue::Float(
695 match Float::from_f64(rounded_value) {
696 Some(f) => f,
697 None => {
698 return Err(FunctionError {
699 function_name: expression.name.to_string(),
700 error:
701 FunctionEvaluationError::OverflowError,
702 })
703 }
704 },
705 ));
706 }
707 }
708 let rounded_value = (match n.as_f64() {
709 Some(f) => f,
710 None => {
711 return Err(FunctionError {
712 function_name: expression.name.to_string(),
713 error: FunctionEvaluationError::OverflowError,
714 })
715 }
716 } * multiplier)
717 .round()
718 / multiplier;
719 return Ok(VariableValue::Float(
720 match Float::from_f64(rounded_value) {
721 Some(f) => f,
722 None => {
723 return Err(FunctionError {
724 function_name: expression.name.to_string(),
725 error: FunctionEvaluationError::OverflowError,
726 })
727 }
728 },
729 ));
730 }
731 _ => {
732 return Err(FunctionError {
733 function_name: expression.name.to_string(),
734 error: FunctionEvaluationError::InvalidArgument(2),
735 })
736 }
737 }
738 }
739 (
740 VariableValue::Integer(n),
741 VariableValue::Integer(_p),
742 VariableValue::Integer(_m),
743 ) => Ok(VariableValue::Integer(Integer::from(match n.as_i64() {
744 Some(i) => i,
745 None => {
746 return Err(FunctionError {
747 function_name: expression.name.to_string(),
748 error: FunctionEvaluationError::OverflowError,
749 })
750 }
751 }))),
752 _ => Err(FunctionError {
753 function_name: expression.name.to_string(),
754 error: FunctionEvaluationError::InvalidArgument(2),
755 }),
756 }
757 }
758 _ => unreachable!(),
759 }
760 }
761}