1use crate::ast::{
20 CallStyle, Extension, ExtensionFunction, ExtensionOutputValue, ExtensionValue, Literal, Name,
21 RepresentableExtensionValue, Type, Value, ValueKind,
22};
23use crate::entities::SchemaType;
24use crate::evaluator;
25
26use miette::Diagnostic;
27use std::str::FromStr;
28use std::sync::Arc;
29use thiserror::Error;
30
/// Number of digits kept after the decimal point. Parsed values are scaled by
/// `10^NUM_DIGITS` before being stored.
const NUM_DIGITS: u32 = 4;

/// A fixed-point decimal number.
///
/// The number is stored as a single `i64` equal to the decimal value
/// multiplied by `10^NUM_DIGITS` (so, with four digits, `1.5` is stored as
/// `15000`). Because of this representation, the derived integer ordering and
/// equality coincide with numeric ordering and equality of the decimals.
#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]
struct Decimal {
    /// The decimal value scaled by `10^NUM_DIGITS`.
    value: i64,
}
40
41#[expect(
42 clippy::expect_used,
43 clippy::unwrap_used,
44 reason = "The `Name`s and `Regex` here are valid"
45)]
46mod constants {
47 use super::EXTENSION_NAME;
48 use crate::ast::Name;
49 use regex::Regex;
50 use std::sync::LazyLock;
51
52 pub static DECIMAL_FROM_STR_NAME: LazyLock<Name> = LazyLock::new(|| {
53 Name::parse_unqualified_name(EXTENSION_NAME).expect("should be a valid identifier")
54 });
55 pub static LESS_THAN: LazyLock<Name> = LazyLock::new(|| {
56 Name::parse_unqualified_name("lessThan").expect("should be a valid identifier")
57 });
58 pub static LESS_THAN_OR_EQUAL: LazyLock<Name> = LazyLock::new(|| {
59 Name::parse_unqualified_name("lessThanOrEqual").expect("should be a valid identifier")
60 });
61 pub static GREATER_THAN: LazyLock<Name> = LazyLock::new(|| {
62 Name::parse_unqualified_name("greaterThan").expect("should be a valid identifier")
63 });
64 pub static GREATER_THAN_OR_EQUAL: LazyLock<Name> = LazyLock::new(|| {
65 Name::parse_unqualified_name("greaterThanOrEqual").expect("should be a valid identifier")
66 });
67
68 pub static DECIMAL_REGEX: LazyLock<Regex> =
70 LazyLock::new(|| Regex::new(r"^(-?\d+)\.(\d+)$").unwrap());
71}
72
/// Advice attached to the type error raised when a decimal operation receives a
/// string literal instead of a `decimal` value (see `as_decimal`).
const ADVICE_MSG: &str = "maybe you forgot to apply the `decimal` constructor?";
76
/// Errors produced while parsing a `decimal` literal in `Decimal::from_str`.
#[derive(Debug, Diagnostic, Error)]
enum Error {
    /// The input string does not match the `[-]<digits>.<digits>` shape.
    #[error("`{0}` is not a well-formed decimal value")]
    FailedParse(String),

    /// The input has more than `NUM_DIGITS` digits after the decimal point.
    #[error("too many digits after the decimal in `{0}`")]
    #[diagnostic(help("at most {NUM_DIGITS} digits are supported"))]
    TooManyDigits(String),

    /// A component could not be parsed as an `i64`, or scaling/combining the
    /// components overflowed `i64`.
    #[error("overflow when converting to decimal")]
    Overflow,
}
95
96fn checked_mul_pow(x: i64, y: u32) -> Result<i64, Error> {
98 if let Some(z) = i64::checked_pow(10, y) {
99 if let Some(w) = i64::checked_mul(x, z) {
100 return Ok(w);
101 }
102 };
103 Err(Error::Overflow)
104}
105
106impl Decimal {
107 fn typename() -> Name {
109 constants::DECIMAL_FROM_STR_NAME.clone()
110 }
111
112 fn from_str(str: impl AsRef<str>) -> Result<Self, Error> {
121 let caps = constants::DECIMAL_REGEX
123 .captures(str.as_ref())
124 .ok_or_else(|| Error::FailedParse(str.as_ref().to_owned()))?;
125 let l_str = caps
126 .get(1)
127 .ok_or_else(|| Error::FailedParse(str.as_ref().to_owned()))?
128 .as_str();
129 let r_str = caps
130 .get(2)
131 .ok_or_else(|| Error::FailedParse(str.as_ref().to_owned()))?
132 .as_str();
133
134 let l = i64::from_str(l_str).map_err(|_| Error::Overflow)?;
136 let l = checked_mul_pow(l, NUM_DIGITS)?;
137
138 let len: u32 = r_str.len().try_into().map_err(|_| Error::Overflow)?;
140 if NUM_DIGITS < len {
141 return Err(Error::TooManyDigits(str.as_ref().to_string()));
142 }
143 let r = i64::from_str(r_str).map_err(|_| Error::Overflow)?;
144 let r = checked_mul_pow(r, NUM_DIGITS - len)?;
145
146 if !l_str.starts_with('-') {
148 l.checked_add(r)
149 } else {
150 l.checked_sub(r)
151 }
152 .map(|value| Self { value })
153 .ok_or(Error::Overflow)
154 }
155}
156
157impl std::fmt::Display for Decimal {
158 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
159 let abs = i128::from(self.value).abs();
160 if self.value.is_negative() {
161 write!(f, "-")?;
162 }
163 let pow = i128::pow(10, NUM_DIGITS);
164 write!(f, "{}.{:04}", abs / pow, abs % pow)
165 }
166}
167
168impl ExtensionValue for Decimal {
169 fn typename(&self) -> Name {
170 Self::typename()
171 }
172 fn supports_operator_overloading(&self) -> bool {
173 false
174 }
175}
176
/// Name of this extension. It is parsed into `constants::DECIMAL_FROM_STR_NAME`,
/// which names both the extension and its constructor.
const EXTENSION_NAME: &str = "decimal";
178
179fn extension_err(msg: impl Into<String>, advice: Option<String>) -> evaluator::EvaluationError {
180 evaluator::EvaluationError::failed_extension_function_application(
181 constants::DECIMAL_FROM_STR_NAME.clone(),
182 msg.into(),
183 None,
184 advice, )
186}
187
188fn decimal_from_str(arg: &Value) -> evaluator::Result<ExtensionOutputValue> {
191 let str = arg.get_as_string()?;
192 let decimal =
193 Decimal::from_str(str.as_str()).map_err(|e| extension_err(e.to_string(), None))?;
194 let arg_source_loc = arg.source_loc();
195 let e = RepresentableExtensionValue::new(
196 Arc::new(decimal),
197 constants::DECIMAL_FROM_STR_NAME.clone(),
198 vec![arg.clone().into()],
199 );
200 Ok(Value {
201 value: ValueKind::ExtensionValue(Arc::new(e)),
202 loc: arg_source_loc.cloned(), }
204 .into())
205}
206
207fn as_decimal(v: &Value) -> Result<&Decimal, evaluator::EvaluationError> {
209 match &v.value {
210 ValueKind::ExtensionValue(ev) if ev.typename() == Decimal::typename() => {
211 #[expect(clippy::expect_used, reason = "Conditional above performs a typecheck")]
212 let d = ev
213 .value()
214 .as_any()
215 .downcast_ref::<Decimal>()
216 .expect("already typechecked, so this downcast should succeed");
217 Ok(d)
218 }
219 ValueKind::Lit(Literal::String(_)) => {
220 Err(evaluator::EvaluationError::type_error_with_advice_single(
221 Type::Extension {
222 name: Decimal::typename(),
223 },
224 v,
225 ADVICE_MSG.into(),
226 ))
227 }
228 _ => Err(evaluator::EvaluationError::type_error_single(
229 Type::Extension {
230 name: Decimal::typename(),
231 },
232 v,
233 )),
234 }
235}
236
237fn decimal_lt(left: &Value, right: &Value) -> evaluator::Result<ExtensionOutputValue> {
240 let left = as_decimal(left)?;
241 let right = as_decimal(right)?;
242 Ok(Value::from(left < right).into())
243}
244
245fn decimal_le(left: &Value, right: &Value) -> evaluator::Result<ExtensionOutputValue> {
248 let left = as_decimal(left)?;
249 let right = as_decimal(right)?;
250 Ok(Value::from(left <= right).into())
251}
252
253fn decimal_gt(left: &Value, right: &Value) -> evaluator::Result<ExtensionOutputValue> {
256 let left = as_decimal(left)?;
257 let right = as_decimal(right)?;
258 Ok(Value::from(left > right).into())
259}
260
261fn decimal_ge(left: &Value, right: &Value) -> evaluator::Result<ExtensionOutputValue> {
264 let left = as_decimal(left)?;
265 let right = as_decimal(right)?;
266 Ok(Value::from(left >= right).into())
267}
268
269pub fn extension() -> Extension {
271 let decimal_type = SchemaType::Extension {
272 name: Decimal::typename(),
273 };
274 Extension::new(
275 constants::DECIMAL_FROM_STR_NAME.clone(),
276 vec![
277 ExtensionFunction::unary(
278 constants::DECIMAL_FROM_STR_NAME.clone(),
279 CallStyle::FunctionStyle,
280 Box::new(decimal_from_str),
281 decimal_type.clone(),
282 SchemaType::String,
283 ),
284 ExtensionFunction::binary(
285 constants::LESS_THAN.clone(),
286 CallStyle::MethodStyle,
287 Box::new(decimal_lt),
288 SchemaType::Bool,
289 (decimal_type.clone(), decimal_type.clone()),
290 ),
291 ExtensionFunction::binary(
292 constants::LESS_THAN_OR_EQUAL.clone(),
293 CallStyle::MethodStyle,
294 Box::new(decimal_le),
295 SchemaType::Bool,
296 (decimal_type.clone(), decimal_type.clone()),
297 ),
298 ExtensionFunction::binary(
299 constants::GREATER_THAN.clone(),
300 CallStyle::MethodStyle,
301 Box::new(decimal_gt),
302 SchemaType::Bool,
303 (decimal_type.clone(), decimal_type.clone()),
304 ),
305 ExtensionFunction::binary(
306 constants::GREATER_THAN_OR_EQUAL.clone(),
307 CallStyle::MethodStyle,
308 Box::new(decimal_ge),
309 SchemaType::Bool,
310 (decimal_type.clone(), decimal_type),
311 ),
312 ],
313 std::iter::empty(),
314 )
315}
316
#[cfg(test)]
#[expect(clippy::panic, reason = "Unit Test Code")]
mod tests {
    // Unit tests for the `decimal` extension: construction, equality, the
    // comparison methods, and Display round-tripping.
    use super::*;
    use crate::ast::{Expr, Type, Value};
    use crate::evaluator::test::{basic_entities, basic_request};
    use crate::evaluator::{evaluation_errors, EvaluationError, Evaluator};
    use crate::extensions::Extensions;
    use crate::parser::parse_expr;
    use cool_asserts::assert_matches;
    use nonempty::nonempty;

    /// Asserts that `res` is a failed extension-function execution attributed
    /// to the `decimal` extension, printing the error message to aid debugging.
    #[track_caller]
    fn assert_decimal_err<T: std::fmt::Debug>(res: evaluator::Result<T>) {
        assert_matches!(res, Err(evaluator::EvaluationError::FailedExtensionFunctionExecution(evaluation_errors::ExtensionFunctionExecutionError {
            extension_name,
            msg,
            ..
        })) => {
            println!("{msg}");
            assert_eq!(
                extension_name,
                Name::parse_unqualified_name("decimal")
                    .expect("should be a valid identifier")
            )
        });
    }

    /// Asserts that `res` is `Ok` and holds an extension value whose type name
    /// is the decimal type name.
    #[track_caller]
    fn assert_decimal_valid(res: evaluator::Result<Value>) {
        assert_matches!(res, Ok(Value { value: ValueKind::ExtensionValue(ev), .. }) => {
            assert_eq!(ev.typename(), Decimal::typename());
        });
    }

    // Only `decimal` itself should be classified as a single-argument
    // constructor; the four comparison methods should not.
    #[test]
    fn constructors() {
        let ext = extension();
        assert!(ext
            .get_func(
                &Name::parse_unqualified_name("decimal").expect("should be a valid identifier")
            )
            .expect("function should exist")
            .is_single_arg_constructor());
        assert!(!ext
            .get_func(
                &Name::parse_unqualified_name("lessThan").expect("should be a valid identifier")
            )
            .expect("function should exist")
            .is_single_arg_constructor());
        assert!(!ext
            .get_func(
                &Name::parse_unqualified_name("lessThanOrEqual")
                    .expect("should be a valid identifier")
            )
            .expect("function should exist")
            .is_single_arg_constructor());
        assert!(!ext
            .get_func(
                &Name::parse_unqualified_name("greaterThan").expect("should be a valid identifier")
            )
            .expect("function should exist")
            .is_single_arg_constructor());
        assert!(!ext
            .get_func(
                &Name::parse_unqualified_name("greaterThanOrEqual")
                    .expect("should be a valid identifier")
            )
            .expect("function should exist")
            .is_single_arg_constructor(),);
    }

    #[test]
    fn decimal_creation() {
        let ext_array = [extension()];
        let exts = Extensions::specific_extensions(&ext_array).unwrap();
        let request = basic_request();
        let entities = basic_entities();
        let eval = Evaluator::new(request, &entities, &exts);

        // Well-formed inputs, including the smallest representable value.
        assert_decimal_valid(
            eval.interpret_inline_policy(&parse_expr(r#"decimal("1.0")"#).expect("parsing error")),
        );
        assert_decimal_valid(
            eval.interpret_inline_policy(&parse_expr(r#"decimal("-1.0")"#).expect("parsing error")),
        );
        assert_decimal_valid(
            eval.interpret_inline_policy(
                &parse_expr(r#"decimal("123.456")"#).expect("parsing error"),
            ),
        );
        assert_decimal_valid(
            eval.interpret_inline_policy(
                &parse_expr(r#"decimal("0.1234")"#).expect("parsing error"),
            ),
        );
        assert_decimal_valid(
            eval.interpret_inline_policy(
                &parse_expr(r#"decimal("-0.0123")"#).expect("parsing error"),
            ),
        );
        assert_decimal_valid(
            eval.interpret_inline_policy(&parse_expr(r#"decimal("55.1")"#).expect("parsing error")),
        );
        assert_decimal_valid(eval.interpret_inline_policy(
            &parse_expr(r#"decimal("-922337203685477.5808")"#).expect("parsing error"),
        ));

        // Leading zeros in the integer part are accepted.
        assert_decimal_valid(
            eval.interpret_inline_policy(
                &parse_expr(r#"decimal("00.000")"#).expect("parsing error"),
            ),
        );

        // Malformed inputs: missing point, extra point, empty integer or
        // fractional part, non-digit characters.
        assert_decimal_err(
            eval.interpret_inline_policy(&parse_expr(r#"decimal("1234")"#).expect("parsing error")),
        );
        assert_decimal_err(
            eval.interpret_inline_policy(&parse_expr(r#"decimal("1.0.")"#).expect("parsing error")),
        );
        assert_decimal_err(
            eval.interpret_inline_policy(&parse_expr(r#"decimal("1.")"#).expect("parsing error")),
        );
        assert_decimal_err(
            eval.interpret_inline_policy(&parse_expr(r#"decimal(".1")"#).expect("parsing error")),
        );
        assert_decimal_err(
            eval.interpret_inline_policy(&parse_expr(r#"decimal("1.a")"#).expect("parsing error")),
        );
        assert_decimal_err(
            eval.interpret_inline_policy(&parse_expr(r#"decimal("-.")"#).expect("parsing error")),
        );

        // Values whose scaled representation falls outside the i64 range.
        assert_decimal_err(eval.interpret_inline_policy(
            &parse_expr(r#"decimal("1000000000000000.0")"#).expect("parsing error"),
        ));
        assert_decimal_err(eval.interpret_inline_policy(
            &parse_expr(r#"decimal("922337203685477.5808")"#).expect("parsing error"),
        ));
        assert_decimal_err(eval.interpret_inline_policy(
            &parse_expr(r#"decimal("-922337203685477.5809")"#).expect("parsing error"),
        ));
        assert_decimal_err(eval.interpret_inline_policy(
            &parse_expr(r#"decimal("-922337203685478.0")"#).expect("parsing error"),
        ));

        // More than NUM_DIGITS digits after the point is rejected...
        assert_decimal_err(
            eval.interpret_inline_policy(
                &parse_expr(r#"decimal("0.12345")"#).expect("parsing error"),
            ),
        );

        // ...even when the extra digits are all zeros.
        assert_decimal_err(
            eval.interpret_inline_policy(
                &parse_expr(r#"decimal("0.00000")"#).expect("parsing error"),
            ),
        );

        // The constructor is function-style only; method-style use must not parse.
        parse_expr(r#" "1.0".decimal() "#).expect_err("should fail");
    }

    #[test]
    fn decimal_equality() {
        let ext_array = [extension()];
        let exts = Extensions::specific_extensions(&ext_array).unwrap();
        let request = basic_request();
        let entities = basic_entities();
        let eval = Evaluator::new(request, &entities, &exts);

        let a = parse_expr(r#"decimal("123.0")"#).expect("parsing error");
        let b = parse_expr(r#"decimal("123.0000")"#).expect("parsing error");
        let c = parse_expr(r#"decimal("0123.0")"#).expect("parsing error");
        let d = parse_expr(r#"decimal("123.456")"#).expect("parsing error");
        let e = parse_expr(r#"decimal("1.23")"#).expect("parsing error");
        let f = parse_expr(r#"decimal("0.0")"#).expect("parsing error");
        let g = parse_expr(r#"decimal("-0.0")"#).expect("parsing error");

        // Same numeric value, differing only in leading/trailing zeros, is equal.
        assert_eq!(
            eval.interpret_inline_policy(&Expr::is_eq(a.clone(), a.clone())),
            Ok(Value::from(true))
        );
        assert_eq!(
            eval.interpret_inline_policy(&Expr::is_eq(a.clone(), b.clone())),
            Ok(Value::from(true))
        );
        assert_eq!(
            eval.interpret_inline_policy(&Expr::is_eq(b.clone(), c.clone())),
            Ok(Value::from(true))
        );
        assert_eq!(
            eval.interpret_inline_policy(&Expr::is_eq(c, a.clone())),
            Ok(Value::from(true))
        );

        // Different numeric values are not equal.
        assert_eq!(
            eval.interpret_inline_policy(&Expr::is_eq(b, d.clone())),
            Ok(Value::from(false))
        );
        assert_eq!(
            eval.interpret_inline_policy(&Expr::is_eq(a.clone(), e.clone())),
            Ok(Value::from(false))
        );
        assert_eq!(
            eval.interpret_inline_policy(&Expr::is_eq(d, e)),
            Ok(Value::from(false))
        );

        // Negative zero equals zero.
        assert_eq!(
            eval.interpret_inline_policy(&Expr::is_eq(f, g)),
            Ok(Value::from(true))
        );

        // A decimal never equals a value of another type (string or long).
        assert_eq!(
            eval.interpret_inline_policy(&Expr::is_eq(a.clone(), Expr::val("123.0"))),
            Ok(Value::from(false))
        );
        assert_eq!(
            eval.interpret_inline_policy(&Expr::is_eq(a, Expr::val(1))),
            Ok(Value::from(false))
        );
    }

    /// Evaluates the extension function named `op` on each `(l, r)` pair and
    /// asserts it returns the paired boolean.
    fn decimal_ops_helper(op: &str, tests: Vec<((Expr, Expr), bool)>) {
        let ext_array = [extension()];
        let exts = Extensions::specific_extensions(&ext_array).unwrap();
        let request = basic_request();
        let entities = basic_entities();
        let eval = Evaluator::new(request, &entities, &exts);

        for ((l, r), res) in tests {
            assert_eq!(
                eval.interpret_inline_policy(&Expr::call_extension_fn(
                    Name::parse_unqualified_name(op).expect("should be a valid identifier"),
                    vec![l, r]
                )),
                Ok(Value::from(res))
            );
        }
    }

    #[test]
    fn decimal_ops() {
        let a = parse_expr(r#"decimal("1.23")"#).expect("parsing error");
        let b = parse_expr(r#"decimal("1.24")"#).expect("parsing error");
        let c = parse_expr(r#"decimal("123.45")"#).expect("parsing error");
        let d = parse_expr(r#"decimal("-1.23")"#).expect("parsing error");
        let e = parse_expr(r#"decimal("-1.24")"#).expect("parsing error");

        // lessThan
        let tests = vec![
            ((a.clone(), b.clone()), true),
            ((a.clone(), a.clone()), false),
            ((c.clone(), a.clone()), false),
            ((d.clone(), a.clone()), true),
            ((d.clone(), e.clone()), false),
        ];
        decimal_ops_helper("lessThan", tests);

        // lessThanOrEqual
        let tests = vec![
            ((a.clone(), b.clone()), true),
            ((a.clone(), a.clone()), true),
            ((c.clone(), a.clone()), false),
            ((d.clone(), a.clone()), true),
            ((d.clone(), e.clone()), false),
        ];
        decimal_ops_helper("lessThanOrEqual", tests);

        // greaterThan
        let tests = vec![
            ((a.clone(), b.clone()), false),
            ((a.clone(), a.clone()), false),
            ((c.clone(), a.clone()), true),
            ((d.clone(), a.clone()), false),
            ((d.clone(), e.clone()), true),
        ];
        decimal_ops_helper("greaterThan", tests);

        // greaterThanOrEqual
        let tests = vec![
            ((a.clone(), b), false),
            ((a.clone(), a.clone()), true),
            ((c, a.clone()), true),
            ((d.clone(), a), false),
            ((d, e), true),
        ];
        decimal_ops_helper("greaterThanOrEqual", tests);

        let ext_array = [extension()];
        let exts = Extensions::specific_extensions(&ext_array).unwrap();
        let request = basic_request();
        let entities = basic_entities();
        let eval = Evaluator::new(request, &entities, &exts);

        // The built-in `<` operator is not overloaded for decimals.
        assert_matches!(
            eval.interpret_inline_policy(
                &parse_expr(r#"decimal("1.23") < decimal("1.24")"#).expect("parsing error")
            ),
            Err(EvaluationError::TypeError(evaluation_errors::TypeError { expected, actual, advice, .. })) => {
                assert_eq!(expected, nonempty![Type::Long]);
                assert_eq!(actual, Type::Extension {
                    name: Name::parse_unqualified_name("decimal")
                        .expect("should be a valid identifier")
                });
                assert_eq!(advice, Some("Only types long support comparison".into()));
            }
        );
        // Passing a string where a decimal is expected yields ADVICE_MSG.
        assert_matches!(
            eval.interpret_inline_policy(
                &parse_expr(r#"decimal("-1.23").lessThan("1.23")"#).expect("parsing error")
            ),
            Err(EvaluationError::TypeError(evaluation_errors::TypeError { expected, actual, advice, .. })) => {
                assert_eq!(expected, nonempty![Type::Extension {
                    name: Name::parse_unqualified_name("decimal")
                        .expect("should be a valid identifier")
                }]);
                assert_eq!(actual, Type::String);
                assert_matches!(advice, Some(a) => assert_eq!(a, ADVICE_MSG));
            }
        );
        // Comparison methods are method-style only; function-style use must not parse.
        parse_expr(r#"lessThan(decimal("-1.23"), decimal("1.23"))"#).expect_err("should fail");

        assert_eq!(
            eval.interpret_inline_policy(
                &parse_expr(r#"decimal("-0.23") != decimal("0.23")"#).expect("parsing error")
            ),
            Ok(true.into())
        );

        // Sign-handling edge cases for values with a zero or small integer part.
        assert_eq!(
            eval.interpret_inline_policy(
                &parse_expr(r#"decimal("-0.0001").lessThan(decimal("0.0"))"#)
                    .expect("parsing error")
            ),
            Ok(true.into())
        );

        assert_eq!(
            eval.interpret_inline_policy(
                &parse_expr(r#"decimal("-0.0023").lessThan(decimal("-0.23"))"#)
                    .expect("parsing error")
            ),
            Ok(false.into())
        );

        assert_eq!(
            eval.interpret_inline_policy(
                &parse_expr(r#"decimal("-1.0000").lessThan(decimal("-0.9999"))"#)
                    .expect("parsing error")
            ),
            Ok(true.into())
        );
    }

    /// Asserts that formatting the parsed `s` and parsing the result again
    /// yields the same `Decimal`.
    fn check_round_trip(s: &str) {
        let d = Decimal::from_str(s).expect("should be a valid decimal");
        assert_eq!(d, Decimal::from_str(d.to_string()).unwrap());
    }

    #[test]
    fn decimal_display() {
        check_round_trip("123.0");
        check_round_trip("1.2300");
        check_round_trip("123.4560");
        check_round_trip("-123.4560");
        check_round_trip("0.0");
    }
}