1use crate::error::{ArithmeticError, ParseError};
4use crate::rounding::RoundingMode;
5use core::cmp::Ordering;
6use core::fmt;
7use core::ops::{Add, Div, Mul, Neg, Sub};
8use core::str::FromStr;
9use num_traits::Signed;
10use rust_decimal::prelude::MathematicalOps;
11use rust_decimal::Decimal as RustDecimal;
12use serde::{Deserialize, Serialize};
13
/// Largest scale accepted by [`Decimal::rescale`]; larger requests are rejected
/// with `ArithmeticError::ScaleExceeded`.
pub const MAX_SCALE: u32 = 28;
16
/// Decimal number implemented as a newtype around `rust_decimal::Decimal`.
///
/// Equality, hashing and (de)serialization are all forwarded to the wrapped
/// value; `#[serde(transparent)]` means the wrapper adds no extra layer to the
/// serialized form.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(transparent)]
pub struct Decimal(RustDecimal);
25
26impl Decimal {
27 pub const ZERO: Self = Self(RustDecimal::ZERO);
29
30 pub const ONE: Self = Self(RustDecimal::ONE);
32
33 pub const NEGATIVE_ONE: Self = Self(RustDecimal::NEGATIVE_ONE);
35
36 pub const TEN: Self = Self(RustDecimal::TEN);
38
39 pub const ONE_HUNDRED: Self = Self(RustDecimal::ONE_HUNDRED);
41
42 pub const ONE_THOUSAND: Self = Self(RustDecimal::ONE_THOUSAND);
44
45 pub const MAX: Self = Self(RustDecimal::MAX);
47
48 pub const MIN: Self = Self(RustDecimal::MIN);
50
51 #[must_use]
59 pub fn new(mantissa: i64, scale: u32) -> Self {
60 Self(RustDecimal::new(mantissa, scale))
61 }
62
63 #[must_use]
68 pub const fn from_parts(lo: u32, mid: u32, hi: u32, negative: bool, scale: u32) -> Self {
69 Self(RustDecimal::from_parts(lo, mid, hi, negative, scale))
70 }
71
72 pub fn try_from_i128(value: i128) -> Result<Self, ArithmeticError> {
76 RustDecimal::try_from_i128_with_scale(value, 0)
77 .map(Self)
78 .map_err(|_| ArithmeticError::Overflow)
79 }
80
81 #[must_use]
83 pub fn to_parts(self) -> (i128, u32) {
84 let unpacked = self.0.unpack();
85 let mantissa = i128::from(unpacked.lo)
86 | (i128::from(unpacked.mid) << 32)
87 | (i128::from(unpacked.hi) << 64);
88 let signed = if unpacked.negative {
89 -mantissa
90 } else {
91 mantissa
92 };
93 (signed, unpacked.scale)
94 }
95
96 #[must_use]
98 pub fn scale(self) -> u32 {
99 self.0.scale()
100 }
101
102 #[must_use]
104 pub fn is_zero(self) -> bool {
105 self.0.is_zero()
106 }
107
108 #[must_use]
110 pub fn is_negative(self) -> bool {
111 self.0.is_sign_negative()
112 }
113
114 #[must_use]
116 pub fn is_positive(self) -> bool {
117 self.0.is_sign_positive() && !self.0.is_zero()
118 }
119
120 #[must_use]
122 pub fn abs(self) -> Self {
123 Self(self.0.abs())
124 }
125
126 #[must_use]
128 pub fn signum(self) -> Self {
129 Self(self.0.signum())
130 }
131
132 #[must_use]
134 pub fn checked_add(self, other: Self) -> Option<Self> {
135 self.0.checked_add(other.0).map(Self)
136 }
137
138 #[must_use]
140 pub fn checked_sub(self, other: Self) -> Option<Self> {
141 self.0.checked_sub(other.0).map(Self)
142 }
143
144 #[must_use]
146 pub fn checked_mul(self, other: Self) -> Option<Self> {
147 self.0.checked_mul(other.0).map(Self)
148 }
149
150 #[must_use]
152 pub fn checked_div(self, other: Self) -> Option<Self> {
153 self.0.checked_div(other.0).map(Self)
154 }
155
156 #[must_use]
158 pub fn checked_rem(self, other: Self) -> Option<Self> {
159 self.0.checked_rem(other.0).map(Self)
160 }
161
162 #[must_use]
164 pub fn saturating_add(self, other: Self) -> Self {
165 Self(self.0.saturating_add(other.0))
166 }
167
168 #[must_use]
170 pub fn saturating_sub(self, other: Self) -> Self {
171 Self(self.0.saturating_sub(other.0))
172 }
173
174 #[must_use]
176 pub fn saturating_mul(self, other: Self) -> Self {
177 Self(self.0.saturating_mul(other.0))
178 }
179
180 pub fn try_add(self, other: Self) -> Result<Self, ArithmeticError> {
182 self.checked_add(other).ok_or(ArithmeticError::Overflow)
183 }
184
185 pub fn try_sub(self, other: Self) -> Result<Self, ArithmeticError> {
187 self.checked_sub(other).ok_or(ArithmeticError::Overflow)
188 }
189
190 pub fn try_mul(self, other: Self) -> Result<Self, ArithmeticError> {
192 self.checked_mul(other).ok_or(ArithmeticError::Overflow)
193 }
194
195 pub fn try_div(self, other: Self) -> Result<Self, ArithmeticError> {
197 if other.is_zero() {
198 return Err(ArithmeticError::DivisionByZero);
199 }
200 self.checked_div(other).ok_or(ArithmeticError::Overflow)
201 }
202
203 #[must_use]
205 pub fn round(self, dp: u32, mode: RoundingMode) -> Self {
206 Self(self.0.round_dp_with_strategy(dp, mode.to_rust_decimal()))
207 }
208
209 #[must_use]
211 pub fn round_dp(self, dp: u32) -> Self {
212 self.round(dp, RoundingMode::HalfEven)
213 }
214
215 #[must_use]
217 pub fn trunc(self, dp: u32) -> Self {
218 self.round(dp, RoundingMode::TowardZero)
219 }
220
221 #[must_use]
223 pub fn floor(self) -> Self {
224 Self(self.0.floor())
225 }
226
227 #[must_use]
229 pub fn ceil(self) -> Self {
230 Self(self.0.ceil())
231 }
232
233 #[must_use]
235 pub fn normalize(self) -> Self {
236 Self(self.0.normalize())
237 }
238
239 pub fn rescale(&mut self, scale: u32) -> Result<(), ArithmeticError> {
243 if scale > MAX_SCALE {
244 return Err(ArithmeticError::ScaleExceeded);
245 }
246 self.0.rescale(scale);
247 Ok(())
248 }
249
250 #[must_use]
252 pub fn min(self, other: Self) -> Self {
253 if self <= other {
254 self
255 } else {
256 other
257 }
258 }
259
260 #[must_use]
262 pub fn max(self, other: Self) -> Self {
263 if self >= other {
264 self
265 } else {
266 other
267 }
268 }
269
270 #[must_use]
272 pub fn clamp(self, min: Self, max: Self) -> Self {
273 self.max(min).min(max)
274 }
275
276 #[must_use]
278 pub fn into_inner(self) -> RustDecimal {
279 self.0
280 }
281
282 #[must_use]
284 pub fn from_inner(inner: RustDecimal) -> Self {
285 Self(inner)
286 }
287
288 #[must_use]
308 pub fn sqrt(self) -> Option<Self> {
309 if self.is_negative() {
310 return None;
311 }
312 self.0.sqrt().map(Self)
313 }
314
315 pub fn try_sqrt(self) -> Result<Self, ArithmeticError> {
317 if self.is_negative() {
318 return Err(ArithmeticError::NegativeSqrt);
319 }
320 self.sqrt().ok_or(ArithmeticError::Overflow)
321 }
322
323 #[must_use]
336 pub fn exp(self) -> Option<Self> {
337 if self > Self::from(100i64) {
344 return None; }
346 if self < Self::from(-100i64) {
347 return Some(Self::ZERO); }
349
350 Some(Self(self.0.exp()))
351 }
352
353 pub fn try_exp(self) -> Result<Self, ArithmeticError> {
355 self.exp().ok_or(ArithmeticError::Overflow)
356 }
357
358 #[must_use]
374 pub fn ln(self) -> Option<Self> {
375 if !self.is_positive() {
376 return None;
377 }
378 Some(Self(self.0.ln()))
379 }
380
381 pub fn try_ln(self) -> Result<Self, ArithmeticError> {
383 if self.is_zero() {
384 return Err(ArithmeticError::LogOfZero);
385 }
386 if self.is_negative() {
387 return Err(ArithmeticError::LogOfNegative);
388 }
389 self.ln().ok_or(ArithmeticError::Overflow)
390 }
391
392 #[must_use]
396 pub fn log10(self) -> Option<Self> {
397 if !self.is_positive() {
398 return None;
399 }
400 Some(Self(self.0.log10()))
401 }
402
403 #[must_use]
424 pub fn pow(self, exponent: Self) -> Option<Self> {
425 if exponent.is_zero() {
427 return Some(Self::ONE);
428 }
429 if self.is_zero() {
430 return if exponent.is_positive() {
431 Some(Self::ZERO)
432 } else {
433 None };
435 }
436 if self == Self::ONE {
437 return Some(Self::ONE);
438 }
439
440 if self.is_negative() {
442 if exponent.floor() != exponent {
444 return None; }
446 let abs_base = self.abs();
447 let result = abs_base.ln()?.checked_mul(exponent)?;
448 let exp_result = result.exp()?;
449
450 let exp_int = exponent.floor();
452 let is_odd = (exp_int / Self::from(2i64)).floor() * Self::from(2i64) != exp_int;
453
454 return Some(if is_odd { -exp_result } else { exp_result });
455 }
456
457 let ln_x = self.ln()?;
459 let product = ln_x.checked_mul(exponent)?;
460 product.exp()
461 }
462
463 pub fn try_pow(self, exponent: Self) -> Result<Self, ArithmeticError> {
465 self.pow(exponent).ok_or(ArithmeticError::Overflow)
466 }
467
468 #[must_use]
482 pub fn powi(self, n: i32) -> Option<Self> {
483 if n == 0 {
484 return Some(Self::ONE);
485 }
486
487 let (mut base, mut exp) = if n < 0 {
488 (Self::ONE.checked_div(self)?, (-n) as u32)
489 } else {
490 (self, n as u32)
491 };
492
493 let mut result = Self::ONE;
494
495 while exp > 0 {
496 if exp & 1 == 1 {
497 result = result.checked_mul(base)?;
498 }
499 base = base.checked_mul(base)?;
500 exp >>= 1;
501 }
502
503 Some(result)
504 }
505
506 pub fn try_powi(self, n: i32) -> Result<Self, ArithmeticError> {
508 if n < 0 && self.is_zero() {
509 return Err(ArithmeticError::DivisionByZero);
510 }
511 self.powi(n).ok_or(ArithmeticError::Overflow)
512 }
513
514 pub fn e() -> Self {
516 Self::from_str("2.7182818284590452353602874713527").expect("E constant is valid")
517 }
518
519 pub fn pi() -> Self {
521 Self::from_str("3.1415926535897932384626433832795").expect("PI constant is valid")
522 }
523}
524
525impl Default for Decimal {
526 fn default() -> Self {
527 Self::ZERO
528 }
529}
530
531impl fmt::Debug for Decimal {
532 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
533 write!(f, "Decimal({})", self.0)
534 }
535}
536
537impl fmt::Display for Decimal {
538 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
539 fmt::Display::fmt(&self.0, f)
540 }
541}
542
543impl FromStr for Decimal {
544 type Err = ParseError;
545
546 fn from_str(s: &str) -> Result<Self, Self::Err> {
547 if s.is_empty() {
548 return Err(ParseError::Empty);
549 }
550 RustDecimal::from_str(s)
551 .map(Self)
552 .map_err(|_| ParseError::InvalidCharacter)
553 }
554}
555
556impl PartialOrd for Decimal {
557 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
558 Some(self.cmp(other))
559 }
560}
561
562impl Ord for Decimal {
563 fn cmp(&self, other: &Self) -> Ordering {
564 self.0.cmp(&other.0)
565 }
566}
567
568impl Neg for Decimal {
569 type Output = Self;
570
571 fn neg(self) -> Self::Output {
572 Self(-self.0)
573 }
574}
575
576impl Add for Decimal {
577 type Output = Self;
578
579 fn add(self, other: Self) -> Self::Output {
580 self.checked_add(other).expect("decimal overflow")
581 }
582}
583
584impl Sub for Decimal {
585 type Output = Self;
586
587 fn sub(self, other: Self) -> Self::Output {
588 self.checked_sub(other).expect("decimal overflow")
589 }
590}
591
592impl Mul for Decimal {
593 type Output = Self;
594
595 fn mul(self, other: Self) -> Self::Output {
596 self.checked_mul(other).expect("decimal overflow")
597 }
598}
599
600impl Div for Decimal {
601 type Output = Self;
602
603 fn div(self, other: Self) -> Self::Output {
604 self.checked_div(other).expect("decimal division error")
605 }
606}
607
608macro_rules! impl_from_int {
609 ($($t:ty),*) => {
610 $(
611 impl From<$t> for Decimal {
612 fn from(n: $t) -> Self {
613 Self(RustDecimal::from(n))
614 }
615 }
616 )*
617 };
618}
619
620impl_from_int!(i8, i16, i32, i64, u8, u16, u32, u64);
621
622impl From<i128> for Decimal {
623 fn from(n: i128) -> Self {
624 Self(RustDecimal::from(n))
625 }
626}
627
628impl From<u128> for Decimal {
629 fn from(n: u128) -> Self {
630 Self(RustDecimal::from(n))
631 }
632}
633
#[cfg(test)]
mod tests {
    extern crate alloc;
    use super::*;
    use alloc::string::ToString;

    /// Parses a decimal literal, panicking on malformed input.
    fn dec(text: &str) -> Decimal {
        Decimal::from_str(text).unwrap()
    }

    /// Builds a whole-number decimal.
    fn int(value: i64) -> Decimal {
        Decimal::from(value)
    }

    /// Asserts that `actual` is strictly within `0.0001` of `expected`.
    fn assert_close(actual: Decimal, expected: Decimal) {
        let diff = (actual - expected).abs();
        assert!(diff < dec("0.0001"));
    }

    // Adding/subtracting zero is a no-op; multiplying by zero gives zero.
    #[test]
    fn zero_identity() {
        let value = int(42);
        assert_eq!(value + Decimal::ZERO, value);
        assert_eq!(value - Decimal::ZERO, value);
        assert_eq!(value * Decimal::ZERO, Decimal::ZERO);
    }

    // Multiplying or dividing by one is a no-op.
    #[test]
    fn one_identity() {
        let value = int(42);
        assert_eq!(value * Decimal::ONE, value);
        assert_eq!(value / Decimal::ONE, value);
    }

    // Double negation round-trips and a value cancels its own negation.
    #[test]
    fn negation() {
        let value = int(42);
        assert_eq!(-(-value), value);
        assert_eq!(value + (-value), Decimal::ZERO);
    }

    // The four operators on 1.00 and 2.00.
    #[test]
    fn basic_arithmetic() {
        let one = Decimal::new(100, 2);
        let two = Decimal::new(200, 2);
        assert_eq!(one + two, Decimal::new(300, 2));
        assert_eq!(two - one, Decimal::new(100, 2));
        assert_eq!(one * int(2), two);
        assert_eq!(two / int(2), one);
    }

    // 1/3 rounded to six places.
    #[test]
    fn division_precision() {
        let third = int(1) / int(3);
        assert_eq!(third.round_dp(6), dec("0.333333"));
    }

    // Midpoint values under each rounding mode.
    #[test]
    fn rounding_modes() {
        let two_and_half = dec("2.5");
        assert_eq!(two_and_half.round(0, RoundingMode::HalfEven), int(2));
        assert_eq!(two_and_half.round(0, RoundingMode::HalfUp), int(3));
        assert_eq!(two_and_half.round(0, RoundingMode::Down), int(2));
        assert_eq!(two_and_half.round(0, RoundingMode::Up), int(3));

        let three_and_half = dec("3.5");
        assert_eq!(three_and_half.round(0, RoundingMode::HalfEven), int(4));
    }

    // Checked operations report failure as `None`.
    #[test]
    fn checked_operations() {
        let above_max = Decimal::MAX.checked_add(Decimal::ONE);
        let below_min = Decimal::MIN.checked_sub(Decimal::ONE);
        let zero_by_zero = Decimal::ZERO.checked_div(Decimal::ZERO);

        assert!(above_max.is_none());
        assert!(below_min.is_none());
        assert!(zero_by_zero.is_none());
    }

    // `try_*` operations report failure with the matching error variant.
    #[test]
    fn try_operations() {
        let overflow = Decimal::MAX.try_add(Decimal::ONE);
        assert!(matches!(overflow, Err(ArithmeticError::Overflow)));

        let by_zero = Decimal::ONE.try_div(Decimal::ZERO);
        assert!(matches!(by_zero, Err(ArithmeticError::DivisionByZero)));
    }

    // Parsing followed by display reproduces the input text.
    #[test]
    fn parse_and_display() {
        let positive: Decimal = "123.456".parse().unwrap();
        assert_eq!(positive.to_string(), "123.456");

        let negative: Decimal = "-0.001".parse().unwrap();
        assert_eq!(negative.to_string(), "-0.001");
    }

    // Comparison operators and the inherent `min`/`max`.
    #[test]
    fn ordering() {
        let small = int(1);
        let large = int(2);
        assert!(small < large);
        assert!(large > small);
        assert_eq!(small.min(large), small);
        assert_eq!(small.max(large), large);
    }

    // `abs` and `signum` for positive, negative and zero inputs.
    #[test]
    fn abs_and_signum() {
        let positive = int(5);
        let negative = int(-5);

        assert_eq!(positive.abs(), positive);
        assert_eq!(negative.abs(), positive);
        assert_eq!(positive.signum(), Decimal::ONE);
        assert_eq!(negative.signum(), Decimal::NEGATIVE_ONE);
        assert_eq!(Decimal::ZERO.signum(), Decimal::ZERO);
    }

    // Values inside, below and above the clamp range.
    #[test]
    fn clamp() {
        let lower = int(0);
        let upper = int(100);

        assert_eq!(int(50).clamp(lower, upper), int(50));
        assert_eq!(int(-10).clamp(lower, upper), lower);
        assert_eq!(int(150).clamp(lower, upper), upper);
    }

    // Perfect squares have exact roots.
    #[test]
    fn sqrt_perfect_squares() {
        for &(square, root) in [(4i64, 2i64), (9, 3), (16, 4), (100, 10)].iter() {
            assert_eq!(int(square).sqrt(), Some(int(root)));
        }
        assert_eq!(Decimal::ZERO.sqrt(), Some(Decimal::ZERO));
        assert_eq!(Decimal::ONE.sqrt(), Some(Decimal::ONE));
    }

    // Negative inputs have no square root.
    #[test]
    fn sqrt_negative_returns_none() {
        for &value in [-1i64, -100].iter() {
            assert_eq!(int(value).sqrt(), None);
        }
    }

    // sqrt(2) matches the known value within tolerance.
    #[test]
    fn sqrt_non_perfect() {
        let root = int(2).sqrt().unwrap();
        assert_close(root, dec("1.4142135623730951"));
    }

    // e^0 is exactly one and e^1 is close to the `e` constant.
    #[test]
    fn exp_basic() {
        assert_eq!(Decimal::ZERO.exp(), Some(Decimal::ONE));

        let computed = Decimal::ONE.exp().unwrap();
        assert_close(computed, Decimal::e());
    }

    // Very large inputs overflow to `None`; very negative ones collapse to zero.
    #[test]
    fn exp_overflow_protection() {
        assert_eq!(int(200).exp(), None);

        let underflowed = int(-200).exp();
        assert_eq!(underflowed, Some(Decimal::ZERO));
    }

    // ln(1) is exactly zero and ln(e) is close to one.
    #[test]
    fn ln_basic() {
        assert_eq!(Decimal::ONE.ln(), Some(Decimal::ZERO));

        let ln_e = Decimal::e().ln().unwrap();
        assert_close(ln_e, Decimal::ONE);
    }

    // Zero and negative inputs have no logarithm.
    #[test]
    fn ln_invalid_inputs() {
        assert_eq!(Decimal::ZERO.ln(), None);

        assert_eq!(int(-1).ln(), None);
    }

    // `exp` and `ln` undo each other within tolerance, in both orders.
    #[test]
    fn exp_ln_inverse() {
        let five = int(5);
        let exp_of_ln = five.ln().unwrap().exp().unwrap();
        assert_close(exp_of_ln, five);

        let two = int(2);
        let ln_of_exp = two.exp().unwrap().ln().unwrap();
        assert_close(ln_of_exp, two);
    }

    // Integer exponents through the general `pow` path.
    #[test]
    fn pow_basic() {
        let cubed = int(2).pow(int(3)).unwrap();
        assert_close(cubed, int(8));

        assert_eq!(int(100).pow(Decimal::ZERO), Some(Decimal::ONE));

        let unchanged = int(42).pow(Decimal::ONE).unwrap();
        assert_close(unchanged, int(42));
    }

    // 4^0.5 is close to 2.
    #[test]
    fn pow_fractional_exponent() {
        let root = int(4).pow(dec("0.5")).unwrap();
        assert_close(root, int(2));
    }

    // Sanity bounds for the built-in constants.
    #[test]
    fn constants() {
        let e = Decimal::e();
        assert!(e > int(2));
        assert!(e < int(3));

        let pi = Decimal::pi();
        assert!(pi > int(3));
        assert!(pi < int(4));
    }

    // Integer powers of two are exact, including negative exponents.
    #[test]
    fn powi_exact() {
        for &(n, expected) in [(0i32, 1i64), (1, 2), (2, 4), (3, 8), (10, 1024)].iter() {
            assert_eq!(int(2).powi(n), Some(int(expected)));
        }

        let half = int(2).powi(-1).unwrap();
        assert_eq!(half, dec("0.5"));

        let quarter = int(2).powi(-2).unwrap();
        assert_eq!(quarter, dec("0.25"));
    }
}