// const_decimal/decimal.rs

1use std::cmp::Ordering;
2use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
3
4use crate::integer::{ScaledInteger, SignedScaledInteger};
5
/// A fixed-point decimal number backed by the integer type `I` with `D` decimal places.
///
/// The wrapped integer holds the value multiplied by `10^D`; for example, with `D = 3` the
/// raw value `5000` represents `5.000`.
#[repr(transparent)]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "borsh", derive(borsh::BorshSerialize, borsh::BorshDeserialize))]
pub struct Decimal<I, const D: u8>(pub I);
11
12impl<I, const D: u8> Decimal<I, D>
13where
14    I: ScaledInteger<D>,
15{
16    pub const ZERO: Decimal<I, D> = Decimal(I::ZERO);
17    pub const ONE: Decimal<I, D> = Decimal(I::SCALING_FACTOR);
18    pub const TWO: Decimal<I, D> = Decimal(I::TWO_SCALING_FACTOR);
19    pub const DECIMALS: u8 = D;
20    pub const SCALING_FACTOR: I = I::SCALING_FACTOR;
21
22    // TODO: See if we can generate a constant.
23    #[must_use]
24    pub fn min() -> Self {
25        Decimal(I::min_value())
26    }
27
28    // TODO: See if we can generate a constant.
29    #[must_use]
30    pub fn max() -> Self {
31        Decimal(I::max_value())
32    }
33
34    /// Losslessly converts a scaled integer to this type.
35    ///
36    /// # Examples
37    ///
38    /// ```rust
39    /// use const_decimal::Decimal;
40    ///
41    /// let five = Decimal::<u64, 3>::try_from_scaled(5, 0).unwrap();
42    /// assert_eq!(five, Decimal::TWO + Decimal::TWO + Decimal::ONE);
43    /// assert_eq!(five.0, 5000);
44    /// ```
45    pub fn try_from_scaled(integer: I, scale: u8) -> Option<Self> {
46        match scale.cmp(&D) {
47            Ordering::Greater => {
48                // SAFETY: We know `scale > D` so this cannot underflow.
49                #[allow(clippy::arithmetic_side_effects)]
50                let divisor = I::TEN.pow(u32::from(scale - D));
51
52                // SAFETY: `divisor` cannot be zero as `x.pow(y)` cannot return 0.
53                #[allow(clippy::arithmetic_side_effects)]
54                let remainder = integer % divisor;
55                if remainder != I::ZERO {
56                    // NB: Cast would lose precision.
57                    return None;
58                }
59
60                integer.checked_div(&divisor).map(Decimal)
61            }
62            Ordering::Less => {
63                // SAFETY: We know `scale < D` so this cannot underflow.
64                #[allow(clippy::arithmetic_side_effects)]
65                let multiplier = I::TEN.pow(u32::from(D - scale));
66
67                integer.checked_mul(&multiplier).map(Decimal)
68            }
69            Ordering::Equal => Some(Decimal(integer)),
70        }
71    }
72
73    pub fn is_zero(&self) -> bool {
74        self.0 == I::ZERO
75    }
76}
77
78impl<I, const D: u8> Add for Decimal<I, D>
79where
80    I: ScaledInteger<D>,
81{
82    type Output = Self;
83
84    #[inline]
85    fn add(self, rhs: Self) -> Self::Output {
86        Decimal(self.0.checked_add(&rhs.0).unwrap())
87    }
88}
89
90impl<I, const D: u8> Sub for Decimal<I, D>
91where
92    I: ScaledInteger<D>,
93{
94    type Output = Self;
95
96    #[inline]
97    fn sub(self, rhs: Self) -> Self::Output {
98        Decimal(self.0.checked_sub(&rhs.0).unwrap())
99    }
100}
101
102impl<I, const D: u8> Mul for Decimal<I, D>
103where
104    I: ScaledInteger<D>,
105{
106    type Output = Self;
107
108    #[inline]
109    fn mul(self, rhs: Self) -> Self::Output {
110        Decimal(I::full_mul_div(self.0, rhs.0, I::SCALING_FACTOR))
111    }
112}
113
114impl<I, const D: u8> Div for Decimal<I, D>
115where
116    I: ScaledInteger<D>,
117{
118    type Output = Self;
119
120    #[inline]
121    fn div(self, rhs: Self) -> Self::Output {
122        Decimal(I::full_mul_div(self.0, I::SCALING_FACTOR, rhs.0))
123    }
124}
125
126impl<I, const D: u8> Neg for Decimal<I, D>
127where
128    I: SignedScaledInteger<D>,
129{
130    type Output = Self;
131
132    fn neg(self) -> Self::Output {
133        Decimal(self.0.checked_neg().unwrap())
134    }
135}
136
137impl<I, const D: u8> AddAssign for Decimal<I, D>
138where
139    I: ScaledInteger<D>,
140{
141    #[inline]
142    fn add_assign(&mut self, rhs: Self) {
143        *self = Decimal(self.0.checked_add(&rhs.0).unwrap());
144    }
145}
146
147impl<I, const D: u8> SubAssign for Decimal<I, D>
148where
149    I: ScaledInteger<D>,
150{
151    #[inline]
152    fn sub_assign(&mut self, rhs: Self) {
153        *self = Decimal(self.0.checked_sub(&rhs.0).unwrap());
154    }
155}
156
157impl<I, const D: u8> MulAssign for Decimal<I, D>
158where
159    I: ScaledInteger<D>,
160{
161    #[inline]
162    fn mul_assign(&mut self, rhs: Self) {
163        *self = Decimal(I::full_mul_div(self.0, rhs.0, I::SCALING_FACTOR));
164    }
165}
166
167impl<I, const D: u8> DivAssign for Decimal<I, D>
168where
169    I: ScaledInteger<D>,
170{
171    #[inline]
172    fn div_assign(&mut self, rhs: Self) {
173        *self = Decimal(I::full_mul_div(self.0, I::SCALING_FACTOR, rhs.0));
174    }
175}
176
#[cfg(test)]
mod tests {
    use std::fmt::Debug;
    use std::ops::Shr;

    use malachite::num::basic::traits::Zero;
    use malachite::{Integer, Rational};
    use paste::paste;
    use proptest::prelude::*;

    use super::*;

    // Deterministic sanity checks for every operator on `Decimal<$underlying, $decimals>`,
    // including multiplying/dividing the extreme values by one.
    macro_rules! test_basic_ops {
        ($underlying:ty, $decimals:literal) => {
            paste! {
                #[test]
                fn [<$underlying _ $decimals _add>]() {
                    assert_eq!(
                        Decimal::<$underlying, $decimals>::ONE + Decimal::ONE,
                        Decimal::TWO,
                    );
                }

                #[test]
                fn [<$underlying _ $decimals _sub>]() {
                    assert_eq!(
                        Decimal::<$underlying, $decimals>::ONE - Decimal::ONE,
                        Decimal::ZERO,
                    );
                }

                #[test]
                fn [<$underlying _ $decimals _mul>]() {
                    assert_eq!(
                        Decimal::<$underlying, $decimals>::ONE * Decimal::ONE,
                        Decimal::ONE,
                    );
                }

                #[test]
                fn [<$underlying _ $decimals _div>]() {
                    assert_eq!(
                        Decimal::<$underlying, $decimals>::ONE / Decimal::ONE,
                        Decimal::ONE,
                    );
                }

                #[test]
                fn [<$underlying _ $decimals _mul_min_by_one>]() {
                    assert_eq!(
                        Decimal::min() * Decimal::<$underlying, $decimals>::ONE,
                        Decimal::min()
                    );
                }

                #[test]
                fn [<$underlying _ $decimals _div_min_by_one>]() {
                    assert_eq!(
                        Decimal::min() / Decimal::<$underlying, $decimals>::ONE,
                        Decimal::min()
                    );
                }

                #[test]
                fn [<$underlying _ $decimals _mul_max_by_one>]() {
                    assert_eq!(
                        Decimal::max() * Decimal::<$underlying, $decimals>::ONE,
                        Decimal::max(),
                    );
                }

                #[test]
                fn [<$underlying _ $decimals _div_max_by_one>]() {
                    assert_eq!(
                        Decimal::max() / Decimal::<$underlying, $decimals>::ONE,
                        Decimal::max(),
                    );
                }

                #[test]
                fn [<$underlying _ $decimals _add_assign>]() {
                    let mut out = Decimal::<$underlying, $decimals>::ONE;
                    out += Decimal::ONE;

                    assert_eq!(out, Decimal::ONE + Decimal::ONE);
                }

                #[test]
                fn [<$underlying _ $decimals _sub_assign>]() {
                    let mut out = Decimal::<$underlying, $decimals>::ONE;
                    out -= Decimal::<$underlying, $decimals>::ONE;

                    assert_eq!(out, Decimal::ZERO);
                }

                #[test]
                fn [<$underlying _ $decimals _mul_assign>]() {
                    let mut out = Decimal::<$underlying, $decimals>::ONE;
                    out *= Decimal::TWO;

                    assert_eq!(out, Decimal::ONE + Decimal::ONE);
                }

                #[test]
                fn [<$underlying _ $decimals _div_assign>]() {
                    let mut out = Decimal::<$underlying, $decimals>::ONE;
                    out /= Decimal::TWO;

                    assert_eq!(out, Decimal::ONE / Decimal::TWO);
                }
            }
        };
    }

    // Property tests comparing `Decimal` operators against the equivalent arithmetic on the raw
    // primitive. Both sides run under `catch_unwind` so that "both panicked" counts as agreement.
    macro_rules! fuzz_against_primitive {
        ($primitive:tt, $decimals:literal) => {
            paste! {
                proptest! {
                    /// Addition functions the same as regular primitive integer addition.
                    #[test]
                    fn [<fuzz_primitive_ $primitive _ $decimals _add>](
                        x in $primitive::MIN..$primitive::MAX,
                        y in $primitive::MIN..$primitive::MAX,
                    ) {
                        let decimal = std::panic::catch_unwind(
                            || Decimal::<_, $decimals>(x) + Decimal(y)
                        );
                        let primitive = std::panic::catch_unwind(|| x.checked_add(y).unwrap());

                        match (decimal, primitive) {
                            (Ok(decimal), Ok(primitive)) => assert_eq!(decimal.0, primitive),
                            (Err(_), Err(_)) => {}
                            (decimal, primitive) => panic!(
                                "Mismatch; decimal={decimal:?}; primitive={primitive:?}"
                            )
                        }
                    }

                    /// Subtraction functions the same as regular primitive integer subtraction.
                    #[test]
                    fn [<fuzz_primitive_ $primitive _ $decimals _sub>](
                        x in $primitive::MIN..$primitive::MAX,
                        y in $primitive::MIN..$primitive::MAX,
                    ) {
                        let decimal = std::panic::catch_unwind(
                            || Decimal::<_, $decimals>(x) - Decimal(y)
                        );
                        let primitive = std::panic::catch_unwind(|| x.checked_sub(y).unwrap());

                        match (decimal, primitive) {
                            (Ok(decimal), Ok(primitive)) => assert_eq!(decimal.0, primitive),
                            (Err(_), Err(_)) => {}
                            (decimal, primitive) => panic!(
                                "Mismatch; decimal={decimal:?}; primitive={primitive:?}",
                            )
                        }
                    }

                    /// Multiplication requires the result to be divided by the scaling factor.
                    #[test]
                    fn [<fuzz_primitive_ $primitive _ $decimals _mul>](
                        x in ($primitive::MIN.shr($primitive::BITS / 2))
                            ..($primitive::MAX.shr($primitive::BITS / 2)),
                        y in ($primitive::MIN.shr($primitive::BITS / 2))
                            ..($primitive::MAX.shr($primitive::BITS / 2)),
                    ) {
                        let decimal = std::panic::catch_unwind(
                            || Decimal::<_, $decimals>(x) * Decimal(y)
                        );
                        let primitive = std::panic::catch_unwind(
                            || x
                                .checked_mul(y)
                                .unwrap()
                                .checked_div($primitive::pow(10, $decimals))
                                .unwrap()
                        );

                        match (decimal, primitive) {
                            (Ok(decimal), Ok(primitive)) => assert_eq!(decimal.0, primitive),
                            (Err(_), Err(_)) => {}
                            (decimal, primitive) => panic!(
                                "Mismatch; decimal={decimal:?}; primitive={primitive:?}"
                            )
                        }
                    }

                    /// Division requires the numerator to first be scaled by the scaling factor.
                    #[test]
                    fn [<fuzz_primitive_ $primitive _ $decimals _div>](
                        x in ($primitive::MIN / $primitive::pow(10, $decimals))
                            ..($primitive::MAX / $primitive::pow(10, $decimals)),
                        y in ($primitive::MIN / $primitive::pow(10, $decimals))
                            ..($primitive::MAX / $primitive::pow(10, $decimals)),
                    ) {
                        let decimal = std::panic::catch_unwind(
                            || Decimal::<_, $decimals>(x) / Decimal(y)
                        );
                        let primitive = std::panic::catch_unwind(
                            || x
                                .checked_mul($primitive::pow(10, $decimals))
                                .unwrap()
                                .checked_div(y)
                                .unwrap()
                        );

                        match (decimal, primitive) {
                            (Ok(decimal), Ok(primitive)) => assert_eq!(decimal.0, primitive),
                            (Err(_), Err(_)) => {}
                            (decimal, primitive) => panic!(
                                "Mismatch; decimal={decimal:?}; primitive={primitive:?}"
                            )
                        }
                    }
                }
            }
        };
    }

    // Instantiates the differential (exact `Rational` reference) fuzzers below for one
    // `(underlying, decimals)` combination.
    macro_rules! differential_fuzz {
        ($underlying:ty, $decimals:literal) => {
            paste! {
                #[test]
                fn [<differential_fuzz_ $underlying _ $decimals _add>]() {
                    differential_fuzz_add::<$underlying, $decimals>();
                }

                #[test]
                fn [<differential_fuzz_ $underlying _ $decimals _sub>]() {
                    differential_fuzz_sub::<$underlying, $decimals>();
                }

                #[test]
                fn [<differential_fuzz_ $underlying _ $decimals _mul>]() {
                    differential_fuzz_mul::<$underlying, $decimals>();
                }

                #[test]
                fn [<differential_fuzz_ $underlying _ $decimals _div>]() {
                    differential_fuzz_div::<$underlying, $decimals>();
                }

                #[test]
                fn [<differential_fuzz_ $underlying _ $decimals _add_assign>]() {
                    differential_fuzz_add_assign::<$underlying, $decimals>();
                }

                #[test]
                fn [<differential_fuzz_ $underlying _ $decimals _sub_assign>]() {
                    differential_fuzz_sub_assign::<$underlying, $decimals>();
                }

                #[test]
                fn [<differential_fuzz_ $underlying _ $decimals _mul_assign>]() {
                    differential_fuzz_mul_assign::<$underlying, $decimals>();
                }

                #[test]
                fn [<differential_fuzz_ $underlying _ $decimals _div_assign>]() {
                    differential_fuzz_div_assign::<$underlying, $decimals>();
                }

                #[test]
                fn [<differential_fuzz_ $underlying _ $decimals _from_scaled>]() {
                    differential_fuzz_from_scaled::<$underlying, $decimals>();
                }
            }
        };
    }

    /// Non-panicking `a + b` must equal the exact rational sum.
    fn differential_fuzz_add<I, const D: u8>()
    where
        I: ScaledInteger<D> + Arbitrary + std::panic::RefUnwindSafe,
        Rational: From<Decimal<I, D>>,
    {
        proptest!(|(a: Decimal<I, D>, b: Decimal<I, D>)| {
            let out = match std::panic::catch_unwind(|| a + b) {
                Ok(out) => out,
                Err(_) => return Ok(()),
            };
            let reference_out = Rational::from(a) + Rational::from(b);

            assert_eq!(Rational::from(out), reference_out);
        });
    }

    /// Non-panicking `a - b` must equal the exact rational difference.
    fn differential_fuzz_sub<I, const D: u8>()
    where
        I: ScaledInteger<D> + Arbitrary + std::panic::RefUnwindSafe,
        Rational: From<Decimal<I, D>>,
    {
        proptest!(|(a: Decimal<I, D>, b: Decimal<I, D>)| {
            let out = match std::panic::catch_unwind(|| a - b) {
                Ok(out) => out,
                Err(_) => return Ok(()),
            };
            let reference_out = Rational::from(a) - Rational::from(b);

            assert_eq!(Rational::from(out), reference_out);
        });
    }

    /// Non-panicking `a * b` must equal the exact rational product whenever that product is
    /// representable at `D` decimals (i.e. no truncation occurred).
    fn differential_fuzz_mul<I, const D: u8>()
    where
        I: ScaledInteger<D> + Arbitrary + std::panic::RefUnwindSafe + Into<Integer>,
        Rational: From<Decimal<I, D>>,
    {
        proptest!(|(a: Decimal<I, D>, b: Decimal<I, D>)| {
            let out = match std::panic::catch_unwind(|| a * b) {
                Ok(out) => out,
                Err(_) => return Ok(()),
            };
            let reference_out = Rational::from(a) * Rational::from(b);

            // If the multiplication contains truncation ignore it.
            let scaling: Integer = Decimal::<I, D>::SCALING_FACTOR.into();
            let divisor = Integer::from(reference_out.denominator_ref());
            if scaling % divisor != Integer::ZERO {
                // TODO: Can we assert they are within N of each other?
                return Ok(());
            }

            assert_eq!(Rational::from(out), reference_out, "{} {a:?} {b:?} {out:?} {reference_out:?}", I::SCALING_FACTOR);
        });
    }

    /// Non-panicking `a / b` (for non-zero `b`) must equal the exact rational quotient whenever
    /// that quotient is representable at `D` decimals.
    fn differential_fuzz_div<I, const D: u8>()
    where
        I: ScaledInteger<D> + Arbitrary + std::panic::RefUnwindSafe + Into<Integer>,
        Rational: From<Decimal<I, D>>,
    {
        proptest!(|(a: Decimal<I, D>, b: Decimal<I, D>)| {
            if b == Decimal::ZERO {
                return Ok(());
            }

            let out = match std::panic::catch_unwind(|| a / b) {
                Ok(out) => out,
                Err(_) => return Ok(()),
            };
            let reference_out = Rational::from(a) / Rational::from(b);

            // If the division contains truncation ignore it.
            let scaling: Integer = Decimal::<I, D>::SCALING_FACTOR.into();
            let divisor = Integer::from(reference_out.denominator_ref());
            if scaling % divisor != Integer::ZERO {
                // TODO: Can we assert they are within N of each other?
                return Ok(());
            }

            assert_eq!(Rational::from(out), reference_out);
        });
    }

    /// Same property as `differential_fuzz_add`, exercised through `+=`.
    fn differential_fuzz_add_assign<I, const D: u8>()
    where
        I: ScaledInteger<D> + Arbitrary + std::panic::RefUnwindSafe,
        Rational: From<Decimal<I, D>>,
    {
        proptest!(|(a: Decimal<I, D>, b: Decimal<I, D>)| {
            let out = match std::panic::catch_unwind(|| {
                let mut out = a;
                out += b;

                out
            }) {
                Ok(out) => out,
                Err(_) => return Ok(()),
            };
            let reference_out = Rational::from(a) + Rational::from(b);

            assert_eq!(Rational::from(out), reference_out);
        });
    }

    /// Same property as `differential_fuzz_sub`, exercised through `-=`.
    fn differential_fuzz_sub_assign<I, const D: u8>()
    where
        I: ScaledInteger<D> + Arbitrary + std::panic::RefUnwindSafe,
        Rational: From<Decimal<I, D>>,
    {
        proptest!(|(a: Decimal<I, D>, b: Decimal<I, D>)| {
            let out = match std::panic::catch_unwind(|| {
                let mut out = a;
                out -= b;

                out
            }) {
                Ok(out) => out,
                Err(_) => return Ok(()),
            };
            let reference_out = Rational::from(a) - Rational::from(b);

            assert_eq!(Rational::from(out), reference_out);
        });
    }

    /// Same property as `differential_fuzz_mul`, exercised through `*=`.
    fn differential_fuzz_mul_assign<I, const D: u8>()
    where
        I: ScaledInteger<D> + Arbitrary + std::panic::RefUnwindSafe + Into<Integer>,
        Rational: From<Decimal<I, D>>,
    {
        proptest!(|(a: Decimal<I, D>, b: Decimal<I, D>)| {
            let out = match std::panic::catch_unwind(|| {
                let mut out = a;
                out *= b;

                out
            }) {
                Ok(out) => out,
                Err(_) => return Ok(()),
            };
            let reference_out = Rational::from(a) * Rational::from(b);

            // If the multiplication contains truncation ignore it.
            let scaling: Integer = Decimal::<I, D>::SCALING_FACTOR.into();
            let divisor = Integer::from(reference_out.denominator_ref());
            if scaling % divisor != Integer::ZERO {
                // TODO: Can we assert they are within N of each other?
                return Ok(());
            }

            assert_eq!(Rational::from(out), reference_out);
        });
    }

    /// Same property as `differential_fuzz_div`, exercised through `/=`.
    fn differential_fuzz_div_assign<I, const D: u8>()
    where
        I: ScaledInteger<D> + Arbitrary + std::panic::RefUnwindSafe + Into<Integer>,
        Rational: From<Decimal<I, D>>,
    {
        proptest!(|(a: Decimal<I, D>, b: Decimal<I, D>)| {
            // Mirror `differential_fuzz_div`: the `Rational` reference below cannot divide by
            // zero, and that panic would not be caught by the `catch_unwind` around `/=`.
            if b == Decimal::ZERO {
                return Ok(());
            }

            let out = match std::panic::catch_unwind(|| {
                let mut out = a;
                out /= b;

                out
            }) {
                Ok(out) => out,
                Err(_) => return Ok(()),
            };
            let reference_out = Rational::from(a) / Rational::from(b);

            // If the division contains truncation ignore it.
            let scaling: Integer = Decimal::<I, D>::SCALING_FACTOR.into();
            let divisor = Integer::from(reference_out.denominator_ref());
            if scaling % divisor != Integer::ZERO {
                // TODO: Can we assert they are within N of each other?
                return Ok(());
            }

            assert_eq!(Rational::from(out), reference_out);
        });
    }

    /// `try_from_scaled` must succeed with the exact value whenever the input is representable
    /// at `D` decimals, and may only return `None` for precision loss or out-of-range values.
    fn differential_fuzz_from_scaled<I, const D: u8>()
    where
        I: ScaledInteger<D> + Arbitrary + std::panic::RefUnwindSafe + Into<Integer> + TryInto<u64>,
        Rational: From<I> + From<Decimal<I, D>>,
        <I as TryInto<u64>>::Error: Debug,
    {
        proptest!(|(integer: I, decimals_percent in 0..100u64)| {
            let max_decimals: u64 = crate::algorithms::log10(I::max_value()).try_into().unwrap();
            let decimals = u8::try_from(decimals_percent * max_decimals / 100).unwrap();
            let scaling = I::TEN.pow(u32::from(decimals));

            let out = Decimal::try_from_scaled(integer, decimals);
            let reference_out = Rational::from_integers(integer.into(), scaling.into());

            match out {
                Some(out) => assert_eq!(Rational::from(out), reference_out),
                None => {
                    let scaling: Integer = Decimal::<I, D>::SCALING_FACTOR.into();
                    let remainder = &scaling % Integer::from(reference_out.denominator_ref());
                    let information = &reference_out * Rational::from(scaling);

                    assert!(
                        remainder != 0
                            || information > Rational::from(I::max_value())
                            || information < Rational::from(I::min_value()) ,
                        "Failed to parse valid input; integer={integer}; input_scale={decimals}; \
                        output_scale={D}",
                    );
                }
            }
        });
    }

    crate::macros::apply_to_common_variants!(test_basic_ops);
    crate::macros::apply_to_common_variants!(fuzz_against_primitive);
    crate::macros::apply_to_common_variants!(differential_fuzz);
}