1use crate::error::{ArithmeticError, ParseError};
4use crate::rounding::RoundingMode;
5use core::cmp::Ordering;
6use core::fmt;
7use core::ops::{Add, Div, Mul, Neg, Sub};
8use core::str::FromStr;
9use num_traits::Signed;
10use rust_decimal::prelude::MathematicalOps;
11use rust_decimal::Decimal as RustDecimal;
12use serde::{Deserialize, Serialize};
13
/// Largest scale (number of digits after the decimal point) accepted by
/// [`Decimal::rescale`]; larger requests are rejected with
/// `ArithmeticError::ScaleExceeded`.
pub const MAX_SCALE: u32 = 28_u32;
16
17#[derive(Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
23#[serde(transparent)]
24pub struct Decimal(RustDecimal);
25
26impl Decimal {
27 pub const ZERO: Self = Self(RustDecimal::ZERO);
29
30 pub const ONE: Self = Self(RustDecimal::ONE);
32
33 pub const NEGATIVE_ONE: Self = Self(RustDecimal::NEGATIVE_ONE);
35
36 pub const TEN: Self = Self(RustDecimal::TEN);
38
39 pub const ONE_HUNDRED: Self = Self(RustDecimal::ONE_HUNDRED);
41
42 pub const ONE_THOUSAND: Self = Self(RustDecimal::ONE_THOUSAND);
44
45 pub const MAX: Self = Self(RustDecimal::MAX);
47
48 pub const MIN: Self = Self(RustDecimal::MIN);
50
51 #[must_use]
59 pub fn new(mantissa: i64, scale: u32) -> Self {
60 Self(RustDecimal::new(mantissa, scale))
61 }
62
63 #[must_use]
68 pub const fn from_parts(lo: u32, mid: u32, hi: u32, negative: bool, scale: u32) -> Self {
69 Self(RustDecimal::from_parts(lo, mid, hi, negative, scale))
70 }
71
72 pub fn try_from_i128(value: i128) -> Result<Self, ArithmeticError> {
76 RustDecimal::try_from_i128_with_scale(value, 0)
77 .map(Self)
78 .map_err(|_| ArithmeticError::Overflow)
79 }
80
81 #[must_use]
83 pub fn to_parts(self) -> (i128, u32) {
84 let unpacked = self.0.unpack();
85 let mantissa = i128::from(unpacked.lo)
86 | (i128::from(unpacked.mid) << 32)
87 | (i128::from(unpacked.hi) << 64);
88 let signed = if unpacked.negative {
89 -mantissa
90 } else {
91 mantissa
92 };
93 (signed, unpacked.scale)
94 }
95
96 #[must_use]
98 pub fn scale(self) -> u32 {
99 self.0.scale()
100 }
101
102 #[must_use]
104 pub fn is_zero(self) -> bool {
105 self.0.is_zero()
106 }
107
108 #[must_use]
110 pub fn is_negative(self) -> bool {
111 self.0.is_sign_negative()
112 }
113
114 #[must_use]
116 pub fn is_positive(self) -> bool {
117 self.0.is_sign_positive() && !self.0.is_zero()
118 }
119
120 #[must_use]
122 pub fn abs(self) -> Self {
123 Self(self.0.abs())
124 }
125
126 #[must_use]
128 pub fn signum(self) -> Self {
129 Self(self.0.signum())
130 }
131
132 #[must_use]
134 pub fn checked_add(self, other: Self) -> Option<Self> {
135 self.0.checked_add(other.0).map(Self)
136 }
137
138 #[must_use]
140 pub fn checked_sub(self, other: Self) -> Option<Self> {
141 self.0.checked_sub(other.0).map(Self)
142 }
143
144 #[must_use]
146 pub fn checked_mul(self, other: Self) -> Option<Self> {
147 self.0.checked_mul(other.0).map(Self)
148 }
149
150 #[must_use]
152 pub fn checked_div(self, other: Self) -> Option<Self> {
153 self.0.checked_div(other.0).map(Self)
154 }
155
156 #[must_use]
158 pub fn checked_rem(self, other: Self) -> Option<Self> {
159 self.0.checked_rem(other.0).map(Self)
160 }
161
162 #[must_use]
164 pub fn saturating_add(self, other: Self) -> Self {
165 Self(self.0.saturating_add(other.0))
166 }
167
168 #[must_use]
170 pub fn saturating_sub(self, other: Self) -> Self {
171 Self(self.0.saturating_sub(other.0))
172 }
173
174 #[must_use]
176 pub fn saturating_mul(self, other: Self) -> Self {
177 Self(self.0.saturating_mul(other.0))
178 }
179
180 pub fn try_add(self, other: Self) -> Result<Self, ArithmeticError> {
182 self.checked_add(other).ok_or(ArithmeticError::Overflow)
183 }
184
185 pub fn try_sub(self, other: Self) -> Result<Self, ArithmeticError> {
187 self.checked_sub(other).ok_or(ArithmeticError::Overflow)
188 }
189
190 pub fn try_mul(self, other: Self) -> Result<Self, ArithmeticError> {
192 self.checked_mul(other).ok_or(ArithmeticError::Overflow)
193 }
194
195 pub fn try_div(self, other: Self) -> Result<Self, ArithmeticError> {
197 if other.is_zero() {
198 return Err(ArithmeticError::DivisionByZero);
199 }
200 self.checked_div(other).ok_or(ArithmeticError::Overflow)
201 }
202
203 #[must_use]
205 pub fn round(self, dp: u32, mode: RoundingMode) -> Self {
206 Self(self.0.round_dp_with_strategy(dp, mode.to_rust_decimal()))
207 }
208
209 #[must_use]
211 pub fn round_dp(self, dp: u32) -> Self {
212 self.round(dp, RoundingMode::HalfEven)
213 }
214
215 #[must_use]
217 pub fn trunc(self, dp: u32) -> Self {
218 self.round(dp, RoundingMode::TowardZero)
219 }
220
221 #[must_use]
223 pub fn floor(self) -> Self {
224 Self(self.0.floor())
225 }
226
227 #[must_use]
229 pub fn ceil(self) -> Self {
230 Self(self.0.ceil())
231 }
232
233 #[must_use]
235 pub fn normalize(self) -> Self {
236 Self(self.0.normalize())
237 }
238
239 pub fn rescale(&mut self, scale: u32) -> Result<(), ArithmeticError> {
243 if scale > MAX_SCALE {
244 return Err(ArithmeticError::ScaleExceeded);
245 }
246 self.0.rescale(scale);
247 Ok(())
248 }
249
250 #[must_use]
252 pub fn min(self, other: Self) -> Self {
253 if self <= other {
254 self
255 } else {
256 other
257 }
258 }
259
260 #[must_use]
262 pub fn max(self, other: Self) -> Self {
263 if self >= other {
264 self
265 } else {
266 other
267 }
268 }
269
270 #[must_use]
272 pub fn clamp(self, min: Self, max: Self) -> Self {
273 self.max(min).min(max)
274 }
275
276 #[must_use]
278 pub fn into_inner(self) -> RustDecimal {
279 self.0
280 }
281
282 #[must_use]
284 pub fn from_inner(inner: RustDecimal) -> Self {
285 Self(inner)
286 }
287
288 #[must_use]
308 pub fn sqrt(self) -> Option<Self> {
309 if self.is_negative() {
310 return None;
311 }
312 self.0.sqrt().map(Self)
313 }
314
315 pub fn try_sqrt(self) -> Result<Self, ArithmeticError> {
317 if self.is_negative() {
318 return Err(ArithmeticError::NegativeSqrt);
319 }
320 self.sqrt().ok_or(ArithmeticError::Overflow)
321 }
322
323 #[must_use]
336 pub fn exp(self) -> Option<Self> {
337 if self > Self::from(100i64) {
344 return None; }
346 if self < Self::from(-100i64) {
347 return Some(Self::ZERO); }
349
350 Some(Self(self.0.exp()))
351 }
352
353 pub fn try_exp(self) -> Result<Self, ArithmeticError> {
355 self.exp().ok_or(ArithmeticError::Overflow)
356 }
357
358 #[must_use]
374 pub fn ln(self) -> Option<Self> {
375 if !self.is_positive() {
376 return None;
377 }
378 Some(Self(self.0.ln()))
379 }
380
381 pub fn try_ln(self) -> Result<Self, ArithmeticError> {
383 if self.is_zero() {
384 return Err(ArithmeticError::LogOfZero);
385 }
386 if self.is_negative() {
387 return Err(ArithmeticError::LogOfNegative);
388 }
389 self.ln().ok_or(ArithmeticError::Overflow)
390 }
391
392 #[must_use]
396 pub fn log10(self) -> Option<Self> {
397 if !self.is_positive() {
398 return None;
399 }
400 Some(Self(self.0.log10()))
401 }
402
403 #[must_use]
424 pub fn pow(self, exponent: Self) -> Option<Self> {
425 if exponent.is_zero() {
427 return Some(Self::ONE);
428 }
429 if self.is_zero() {
430 return if exponent.is_positive() {
431 Some(Self::ZERO)
432 } else {
433 None };
435 }
436 if self == Self::ONE {
437 return Some(Self::ONE);
438 }
439
440 if self.is_negative() {
442 if exponent.floor() != exponent {
444 return None; }
446 let abs_base = self.abs();
447 let result = abs_base.ln()?.checked_mul(exponent)?;
448 let exp_result = result.exp()?;
449
450 let exp_int = exponent.floor();
452 let is_odd = (exp_int / Self::from(2i64)).floor() * Self::from(2i64) != exp_int;
453
454 return Some(if is_odd { -exp_result } else { exp_result });
455 }
456
457 let ln_x = self.ln()?;
459 let product = ln_x.checked_mul(exponent)?;
460 product.exp()
461 }
462
463 pub fn try_pow(self, exponent: Self) -> Result<Self, ArithmeticError> {
465 self.pow(exponent).ok_or(ArithmeticError::Overflow)
466 }
467
468 #[must_use]
482 pub fn powi(self, n: i32) -> Option<Self> {
483 if n == 0 {
484 return Some(Self::ONE);
485 }
486
487 let (mut base, mut exp) = if n < 0 {
488 (Self::ONE.checked_div(self)?, (-n) as u32)
489 } else {
490 (self, n as u32)
491 };
492
493 let mut result = Self::ONE;
494
495 while exp > 0 {
496 if exp & 1 == 1 {
497 result = result.checked_mul(base)?;
498 }
499 base = base.checked_mul(base)?;
500 exp >>= 1;
501 }
502
503 Some(result)
504 }
505
506 pub fn try_powi(self, n: i32) -> Result<Self, ArithmeticError> {
508 if n < 0 && self.is_zero() {
509 return Err(ArithmeticError::DivisionByZero);
510 }
511 self.powi(n).ok_or(ArithmeticError::Overflow)
512 }
513
514 pub fn e() -> Self {
516 Self::from_str("2.7182818284590452353602874713527")
517 .expect("E constant is valid")
518 }
519
520 pub fn pi() -> Self {
522 Self::from_str("3.1415926535897932384626433832795")
523 .expect("PI constant is valid")
524 }
525}
526
527impl Default for Decimal {
528 fn default() -> Self {
529 Self::ZERO
530 }
531}
532
533impl fmt::Debug for Decimal {
534 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
535 write!(f, "Decimal({})", self.0)
536 }
537}
538
539impl fmt::Display for Decimal {
540 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
541 fmt::Display::fmt(&self.0, f)
542 }
543}
544
545impl FromStr for Decimal {
546 type Err = ParseError;
547
548 fn from_str(s: &str) -> Result<Self, Self::Err> {
549 if s.is_empty() {
550 return Err(ParseError::Empty);
551 }
552 RustDecimal::from_str(s)
553 .map(Self)
554 .map_err(|_| ParseError::InvalidCharacter)
555 }
556}
557
558impl PartialOrd for Decimal {
559 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
560 Some(self.cmp(other))
561 }
562}
563
564impl Ord for Decimal {
565 fn cmp(&self, other: &Self) -> Ordering {
566 self.0.cmp(&other.0)
567 }
568}
569
570impl Neg for Decimal {
571 type Output = Self;
572
573 fn neg(self) -> Self::Output {
574 Self(-self.0)
575 }
576}
577
578impl Add for Decimal {
579 type Output = Self;
580
581 fn add(self, other: Self) -> Self::Output {
582 self.checked_add(other).expect("decimal overflow")
583 }
584}
585
586impl Sub for Decimal {
587 type Output = Self;
588
589 fn sub(self, other: Self) -> Self::Output {
590 self.checked_sub(other).expect("decimal overflow")
591 }
592}
593
594impl Mul for Decimal {
595 type Output = Self;
596
597 fn mul(self, other: Self) -> Self::Output {
598 self.checked_mul(other).expect("decimal overflow")
599 }
600}
601
602impl Div for Decimal {
603 type Output = Self;
604
605 fn div(self, other: Self) -> Self::Output {
606 self.checked_div(other).expect("decimal division error")
607 }
608}
609
/// Generates `From<$int> for Decimal` for each listed primitive integer
/// type by delegating to the matching `rust_decimal` conversion.
macro_rules! impl_from_int {
    ($($int:ty),*) => {
        $(
            impl From<$int> for Decimal {
                fn from(value: $int) -> Self {
                    Self(value.into())
                }
            }
        )*
    };
}
621
622impl_from_int!(i8, i16, i32, i64, u8, u16, u32, u64);
623
624impl From<i128> for Decimal {
625 fn from(n: i128) -> Self {
626 Self(RustDecimal::from(n))
627 }
628}
629
630impl From<u128> for Decimal {
631 fn from(n: u128) -> Self {
632 Self(RustDecimal::from(n))
633 }
634}
635
#[cfg(test)]
mod tests {
    extern crate alloc;

    use super::*;
    use alloc::string::ToString;

    /// Parses a literal that the test knows to be valid.
    fn dec(text: &str) -> Decimal {
        Decimal::from_str(text).unwrap()
    }

    /// Shorthand for an integral `Decimal`.
    fn int(value: i64) -> Decimal {
        Decimal::from(value)
    }

    /// Asserts `|actual - expected| < 0.0001`.
    fn assert_close(actual: Decimal, expected: Decimal) {
        let tolerance = dec("0.0001");
        assert!((actual - expected).abs() < tolerance);
    }

    #[test]
    fn zero_identity() {
        let value = int(42);
        assert_eq!(value + Decimal::ZERO, value);
        assert_eq!(value - Decimal::ZERO, value);
        assert_eq!(value * Decimal::ZERO, Decimal::ZERO);
    }

    #[test]
    fn one_identity() {
        let value = int(42);
        assert_eq!(value * Decimal::ONE, value);
        assert_eq!(value / Decimal::ONE, value);
    }

    #[test]
    fn negation() {
        let value = int(42);
        assert_eq!(-(-value), value);
        assert_eq!(value + (-value), Decimal::ZERO);
    }

    #[test]
    fn basic_arithmetic() {
        let one = Decimal::new(100, 2);
        let two = Decimal::new(200, 2);
        assert_eq!(one + two, Decimal::new(300, 2));
        assert_eq!(two - one, Decimal::new(100, 2));
        assert_eq!(one * int(2), two);
        assert_eq!(two / int(2), one);
    }

    #[test]
    fn division_precision() {
        let third = int(1) / int(3);
        assert_eq!(third.round_dp(6), dec("0.333333"));
    }

    #[test]
    fn rounding_modes() {
        let two_and_half = dec("2.5");
        let cases = [
            (RoundingMode::HalfEven, 2),
            (RoundingMode::HalfUp, 3),
            (RoundingMode::Down, 2),
            (RoundingMode::Up, 3),
        ];
        for (mode, expected) in cases {
            assert_eq!(two_and_half.round(0, mode), int(expected));
        }

        assert_eq!(dec("3.5").round(0, RoundingMode::HalfEven), int(4));
    }

    #[test]
    fn checked_operations() {
        assert_eq!(Decimal::MAX.checked_add(Decimal::ONE), None);
        assert_eq!(Decimal::MIN.checked_sub(Decimal::ONE), None);
        assert_eq!(Decimal::ZERO.checked_div(Decimal::ZERO), None);
    }

    #[test]
    fn try_operations() {
        let overflow = Decimal::MAX.try_add(Decimal::ONE);
        assert!(matches!(overflow, Err(ArithmeticError::Overflow)));

        let by_zero = Decimal::ONE.try_div(Decimal::ZERO);
        assert!(matches!(by_zero, Err(ArithmeticError::DivisionByZero)));
    }

    #[test]
    fn parse_and_display() {
        for text in ["123.456", "-0.001"] {
            let parsed: Decimal = text.parse().unwrap();
            assert_eq!(parsed.to_string(), text);
        }
    }

    #[test]
    fn ordering() {
        let (low, high) = (int(1), int(2));
        assert!(low < high);
        assert!(high > low);
        assert_eq!(low.min(high), low);
        assert_eq!(low.max(high), high);
    }

    #[test]
    fn abs_and_signum() {
        let five = int(5);
        let minus_five = int(-5);

        assert_eq!(five.abs(), five);
        assert_eq!(minus_five.abs(), five);
        assert_eq!(five.signum(), Decimal::ONE);
        assert_eq!(minus_five.signum(), Decimal::NEGATIVE_ONE);
        assert_eq!(Decimal::ZERO.signum(), Decimal::ZERO);
    }

    #[test]
    fn clamp() {
        let (lower, upper) = (int(0), int(100));

        assert_eq!(int(50).clamp(lower, upper), int(50));
        assert_eq!(int(-10).clamp(lower, upper), lower);
        assert_eq!(int(150).clamp(lower, upper), upper);
    }

    #[test]
    fn sqrt_perfect_squares() {
        for (square, root) in [(4, 2), (9, 3), (16, 4), (100, 10)] {
            assert_eq!(int(square).sqrt(), Some(int(root)));
        }
        assert_eq!(Decimal::ZERO.sqrt(), Some(Decimal::ZERO));
        assert_eq!(Decimal::ONE.sqrt(), Some(Decimal::ONE));
    }

    #[test]
    fn sqrt_negative_returns_none() {
        assert_eq!(int(-1).sqrt(), None);
        assert_eq!(int(-100).sqrt(), None);
    }

    #[test]
    fn sqrt_non_perfect() {
        let root_two = int(2).sqrt().unwrap();
        assert_close(root_two, dec("1.4142135623730951"));
    }

    #[test]
    fn exp_basic() {
        assert_eq!(Decimal::ZERO.exp(), Some(Decimal::ONE));

        let e = Decimal::ONE.exp().unwrap();
        assert_close(e, Decimal::e());
    }

    #[test]
    fn exp_overflow_protection() {
        assert_eq!(int(200).exp(), None);
        assert_eq!(int(-200).exp(), Some(Decimal::ZERO));
    }

    #[test]
    fn ln_basic() {
        assert_eq!(Decimal::ONE.ln(), Some(Decimal::ZERO));

        let ln_e = Decimal::e().ln().unwrap();
        assert_close(ln_e, Decimal::ONE);
    }

    #[test]
    fn ln_invalid_inputs() {
        assert_eq!(Decimal::ZERO.ln(), None);
        assert_eq!(int(-1).ln(), None);
    }

    #[test]
    fn exp_ln_inverse() {
        let five = int(5);
        assert_close(five.ln().unwrap().exp().unwrap(), five);

        let two = int(2);
        assert_close(two.exp().unwrap().ln().unwrap(), two);
    }

    #[test]
    fn pow_basic() {
        assert_close(int(2).pow(int(3)).unwrap(), int(8));
        assert_eq!(int(100).pow(Decimal::ZERO), Some(Decimal::ONE));
        assert_close(int(42).pow(Decimal::ONE).unwrap(), int(42));
    }

    #[test]
    fn pow_fractional_exponent() {
        let root = int(4).pow(dec("0.5")).unwrap();
        assert_close(root, int(2));
    }

    #[test]
    fn constants() {
        let e = Decimal::e();
        assert!(int(2) < e && e < int(3));

        let pi = Decimal::pi();
        assert!(int(3) < pi && pi < int(4));
    }

    #[test]
    fn powi_exact() {
        let two = int(2);
        assert_eq!(two.powi(0), Some(Decimal::ONE));
        for (n, expected) in [(1, 2), (2, 4), (3, 8), (10, 1024)] {
            assert_eq!(two.powi(n), Some(int(expected)));
        }

        assert_eq!(two.powi(-1).unwrap(), dec("0.5"));
        assert_eq!(two.powi(-2).unwrap(), dec("0.25"));
    }
}