1use std::cmp::Ordering;
2use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
3
4use crate::integer::{ScaledInteger, SignedScaledInteger};
5
/// A fixed-point decimal number with `D` fractional decimal digits.
///
/// The wrapped integer is the raw, scaled representation: the numeric value is
/// `self.0 / 10^D`. For example, with `D = 2` a raw value of `150` represents
/// `1.5`.
///
/// Every value of a given `Decimal<I, D>` shares the same scale, so the derived
/// `PartialEq`/`PartialOrd`/`Ord`/`Hash` impls compare the raw integers directly
/// and agree with numeric ordering.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "borsh", derive(borsh::BorshSerialize, borsh::BorshDeserialize))]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
// Guarantees `Decimal<I, D>` has the same layout and ABI as `I`.
#[repr(transparent)]
pub struct Decimal<I, const D: u8>(pub I);
11
12impl<I, const D: u8> Decimal<I, D>
13where
14 I: ScaledInteger<D>,
15{
16 pub const ZERO: Decimal<I, D> = Decimal(I::ZERO);
17 pub const ONE: Decimal<I, D> = Decimal(I::SCALING_FACTOR);
18 pub const TWO: Decimal<I, D> = Decimal(I::TWO_SCALING_FACTOR);
19 pub const DECIMALS: u8 = D;
20 pub const SCALING_FACTOR: I = I::SCALING_FACTOR;
21
22 #[must_use]
24 pub fn min() -> Self {
25 Decimal(I::min_value())
26 }
27
28 #[must_use]
30 pub fn max() -> Self {
31 Decimal(I::max_value())
32 }
33
34 pub fn try_from_scaled(integer: I, scale: u8) -> Option<Self> {
46 match scale.cmp(&D) {
47 Ordering::Greater => {
48 #[allow(clippy::arithmetic_side_effects)]
50 let divisor = I::TEN.pow(u32::from(scale - D));
51
52 #[allow(clippy::arithmetic_side_effects)]
54 let remainder = integer % divisor;
55 if remainder != I::ZERO {
56 return None;
58 }
59
60 integer.checked_div(&divisor).map(Decimal)
61 }
62 Ordering::Less => {
63 #[allow(clippy::arithmetic_side_effects)]
65 let multiplier = I::TEN.pow(u32::from(D - scale));
66
67 integer.checked_mul(&multiplier).map(Decimal)
68 }
69 Ordering::Equal => Some(Decimal(integer)),
70 }
71 }
72
73 pub fn is_zero(&self) -> bool {
74 self.0 == I::ZERO
75 }
76}
77
78impl<I, const D: u8> Add for Decimal<I, D>
79where
80 I: ScaledInteger<D>,
81{
82 type Output = Self;
83
84 #[inline]
85 fn add(self, rhs: Self) -> Self::Output {
86 Decimal(self.0.checked_add(&rhs.0).unwrap())
87 }
88}
89
90impl<I, const D: u8> Sub for Decimal<I, D>
91where
92 I: ScaledInteger<D>,
93{
94 type Output = Self;
95
96 #[inline]
97 fn sub(self, rhs: Self) -> Self::Output {
98 Decimal(self.0.checked_sub(&rhs.0).unwrap())
99 }
100}
101
102impl<I, const D: u8> Mul for Decimal<I, D>
103where
104 I: ScaledInteger<D>,
105{
106 type Output = Self;
107
108 #[inline]
109 fn mul(self, rhs: Self) -> Self::Output {
110 Decimal(I::full_mul_div(self.0, rhs.0, I::SCALING_FACTOR))
111 }
112}
113
114impl<I, const D: u8> Div for Decimal<I, D>
115where
116 I: ScaledInteger<D>,
117{
118 type Output = Self;
119
120 #[inline]
121 fn div(self, rhs: Self) -> Self::Output {
122 Decimal(I::full_mul_div(self.0, I::SCALING_FACTOR, rhs.0))
123 }
124}
125
126impl<I, const D: u8> Neg for Decimal<I, D>
127where
128 I: SignedScaledInteger<D>,
129{
130 type Output = Self;
131
132 fn neg(self) -> Self::Output {
133 Decimal(self.0.checked_neg().unwrap())
134 }
135}
136
137impl<I, const D: u8> AddAssign for Decimal<I, D>
138where
139 I: ScaledInteger<D>,
140{
141 #[inline]
142 fn add_assign(&mut self, rhs: Self) {
143 *self = Decimal(self.0.checked_add(&rhs.0).unwrap());
144 }
145}
146
147impl<I, const D: u8> SubAssign for Decimal<I, D>
148where
149 I: ScaledInteger<D>,
150{
151 #[inline]
152 fn sub_assign(&mut self, rhs: Self) {
153 *self = Decimal(self.0.checked_sub(&rhs.0).unwrap());
154 }
155}
156
157impl<I, const D: u8> MulAssign for Decimal<I, D>
158where
159 I: ScaledInteger<D>,
160{
161 #[inline]
162 fn mul_assign(&mut self, rhs: Self) {
163 *self = Decimal(I::full_mul_div(self.0, rhs.0, I::SCALING_FACTOR));
164 }
165}
166
167impl<I, const D: u8> DivAssign for Decimal<I, D>
168where
169 I: ScaledInteger<D>,
170{
171 #[inline]
172 fn div_assign(&mut self, rhs: Self) {
173 *self = Decimal(I::full_mul_div(self.0, I::SCALING_FACTOR, rhs.0));
174 }
175}
176
#[cfg(test)]
mod tests {
    //! Tests for `Decimal`: fixed-input smoke tests, fuzz tests against the
    //! primitive integer operations, and differential fuzz tests against exact
    //! `malachite::Rational` arithmetic.

    use std::fmt::Debug;
    use std::ops::Shr;

    use malachite::num::basic::traits::Zero;
    use malachite::{Integer, Rational};
    use paste::paste;
    use proptest::prelude::*;

    use super::*;

    // Generates fixed-input smoke tests for every arithmetic operator on
    // `Decimal<$underlying, $decimals>`, including multiplying/dividing the
    // extreme values by one.
    macro_rules! test_basic_ops {
        ($underlying:ty, $decimals:literal) => {
            paste! {
                #[test]
                fn [<$underlying _ $decimals _add>]() {
                    assert_eq!(
                        Decimal::<$underlying, $decimals>::ONE + Decimal::ONE,
                        Decimal::TWO,
                    );
                }

                #[test]
                fn [<$underlying _ $decimals _sub>]() {
                    assert_eq!(
                        Decimal::<$underlying, $decimals>::ONE - Decimal::ONE,
                        Decimal::ZERO,
                    );
                }

                #[test]
                fn [<$underlying _ $decimals _mul>]() {
                    assert_eq!(
                        Decimal::<$underlying, $decimals>::ONE * Decimal::ONE,
                        Decimal::ONE,
                    );
                }

                #[test]
                fn [<$underlying _ $decimals _div>]() {
                    assert_eq!(
                        Decimal::<$underlying, $decimals>::ONE / Decimal::ONE,
                        Decimal::ONE,
                    );
                }

                #[test]
                fn [<$underlying _ $decimals _mul_min_by_one>]() {
                    assert_eq!(
                        Decimal::min() * Decimal::<$underlying, $decimals>::ONE,
                        Decimal::min()
                    );
                }

                #[test]
                fn [<$underlying _ $decimals _div_min_by_one>]() {
                    assert_eq!(
                        Decimal::min() / Decimal::<$underlying, $decimals>::ONE,
                        Decimal::min()
                    );
                }

                #[test]
                fn [<$underlying _ $decimals _mul_max_by_one>]() {
                    assert_eq!(
                        Decimal::max() * Decimal::<$underlying, $decimals>::ONE,
                        Decimal::max(),
                    );
                }

                #[test]
                fn [<$underlying _ $decimals _div_max_by_one>]() {
                    assert_eq!(
                        Decimal::max() / Decimal::<$underlying, $decimals>::ONE,
                        Decimal::max(),
                    );
                }

                #[test]
                fn [<$underlying _ $decimals _add_assign>]() {
                    let mut out = Decimal::<$underlying, $decimals>::ONE;
                    out += Decimal::ONE;

                    assert_eq!(out, Decimal::ONE + Decimal::ONE);
                }

                #[test]
                fn [<$underlying _ $decimals _sub_assign>]() {
                    let mut out = Decimal::<$underlying, $decimals>::ONE;
                    out -= Decimal::<$underlying, $decimals>::ONE;

                    assert_eq!(out, Decimal::ZERO);
                }

                #[test]
                fn [<$underlying _ $decimals _mul_assign>]() {
                    let mut out = Decimal::<$underlying, $decimals>::ONE;
                    out *= Decimal::TWO;

                    assert_eq!(out, Decimal::ONE + Decimal::ONE);
                }

                #[test]
                fn [<$underlying _ $decimals _div_assign>]() {
                    let mut out = Decimal::<$underlying, $decimals>::ONE;
                    out /= Decimal::TWO;

                    assert_eq!(out, Decimal::ONE / Decimal::TWO);
                }
            }
        };
    }

    // Generates proptest cases comparing `Decimal<$primitive, $decimals>`
    // arithmetic against the equivalent checked primitive arithmetic: both
    // sides must either agree on the raw result or both panic.
    macro_rules! fuzz_against_primitive {
        ($primitive:tt, $decimals:literal) => {
            paste! {
                proptest! {
                    #[test]
                    fn [<fuzz_primitive_ $primitive _ $decimals _add>](
                        x in $primitive::MIN..$primitive::MAX,
                        y in $primitive::MIN..$primitive::MAX,
                    ) {
                        let decimal = std::panic::catch_unwind(
                            || Decimal::<_, $decimals>(x) + Decimal(y)
                        );
                        let primitive = std::panic::catch_unwind(|| x.checked_add(y).unwrap());

                        match (decimal, primitive) {
                            (Ok(decimal), Ok(primitive)) => assert_eq!(decimal.0, primitive),
                            (Err(_), Err(_)) => {}
                            (decimal, primitive) => panic!(
                                "Mismatch; decimal={decimal:?}; primitive={primitive:?}"
                            )
                        }
                    }

                    #[test]
                    fn [<fuzz_primitive_ $primitive _ $decimals _sub>](
                        x in $primitive::MIN..$primitive::MAX,
                        y in $primitive::MIN..$primitive::MAX,
                    ) {
                        let decimal = std::panic::catch_unwind(
                            || Decimal::<_, $decimals>(x) - Decimal(y)
                        );
                        let primitive = std::panic::catch_unwind(|| x.checked_sub(y).unwrap());

                        match (decimal, primitive) {
                            (Ok(decimal), Ok(primitive)) => assert_eq!(decimal.0, primitive),
                            (Err(_), Err(_)) => {}
                            (decimal, primitive) => panic!(
                                "Mismatch; decimal={decimal:?}; primitive={primitive:?}",
                            )
                        }
                    }

                    // Operands are limited to half the bit width so the primitive
                    // reference product `x * y` cannot overflow.
                    #[test]
                    fn [<fuzz_primitive_ $primitive _ $decimals _mul>](
                        x in ($primitive::MIN.shr($primitive::BITS / 2))
                            ..($primitive::MAX.shr($primitive::BITS / 2)),
                        y in ($primitive::MIN.shr($primitive::BITS / 2))
                            ..($primitive::MAX.shr($primitive::BITS / 2)),
                    ) {
                        let decimal = std::panic::catch_unwind(
                            || Decimal::<_, $decimals>(x) * Decimal(y)
                        );
                        let primitive = std::panic::catch_unwind(
                            || x
                                .checked_mul(y)
                                .unwrap()
                                .checked_div($primitive::pow(10, $decimals))
                                .unwrap()
                        );

                        match (decimal, primitive) {
                            (Ok(decimal), Ok(primitive)) => assert_eq!(decimal.0, primitive),
                            (Err(_), Err(_)) => {}
                            (decimal, primitive) => panic!(
                                "Mismatch; decimal={decimal:?}; primitive={primitive:?}"
                            )
                        }
                    }

                    // Operands are bounded by `MAX / 10^D` so the primitive
                    // reference `x * 10^D` cannot overflow.
                    #[test]
                    fn [<fuzz_primitive_ $primitive _ $decimals _div>](
                        x in ($primitive::MIN / $primitive::pow(10, $decimals))
                            ..($primitive::MAX / $primitive::pow(10, $decimals)),
                        y in ($primitive::MIN / $primitive::pow(10, $decimals))
                            ..($primitive::MAX / $primitive::pow(10, $decimals)),
                    ) {
                        let decimal = std::panic::catch_unwind(
                            || Decimal::<_, $decimals>(x) / Decimal(y)
                        );
                        let primitive = std::panic::catch_unwind(
                            || x
                                .checked_mul($primitive::pow(10, $decimals))
                                .unwrap()
                                .checked_div(y)
                                .unwrap()
                        );

                        match (decimal, primitive) {
                            (Ok(decimal), Ok(primitive)) => assert_eq!(decimal.0, primitive),
                            (Err(_), Err(_)) => {}
                            (decimal, primitive) => panic!(
                                "Mismatch; decimal={decimal:?}; primitive={primitive:?}"
                            )
                        }
                    }
                }
            }
        };
    }

    // Generates one `#[test]` per differential-fuzz helper below for
    // `Decimal<$underlying, $decimals>`.
    macro_rules! differential_fuzz {
        ($underlying:ty, $decimals:literal) => {
            paste! {
                #[test]
                fn [<differential_fuzz_ $underlying _ $decimals _add>]() {
                    differential_fuzz_add::<$underlying, $decimals>();
                }

                #[test]
                fn [<differential_fuzz_ $underlying _ $decimals _sub>]() {
                    differential_fuzz_sub::<$underlying, $decimals>();
                }

                #[test]
                fn [<differential_fuzz_ $underlying _ $decimals _mul>]() {
                    differential_fuzz_mul::<$underlying, $decimals>();
                }

                #[test]
                fn [<differential_fuzz_ $underlying _ $decimals _div>]() {
                    differential_fuzz_div::<$underlying, $decimals>();
                }

                #[test]
                fn [<differential_fuzz_ $underlying _ $decimals _add_assign>]() {
                    differential_fuzz_add_assign::<$underlying, $decimals>();
                }

                #[test]
                fn [<differential_fuzz_ $underlying _ $decimals _sub_assign>]() {
                    differential_fuzz_sub_assign::<$underlying, $decimals>();
                }

                #[test]
                fn [<differential_fuzz_ $underlying _ $decimals _mul_assign>]() {
                    differential_fuzz_mul_assign::<$underlying, $decimals>();
                }

                #[test]
                fn [<differential_fuzz_ $underlying _ $decimals _div_assign>]() {
                    differential_fuzz_div_assign::<$underlying, $decimals>();
                }

                #[test]
                fn [<differential_fuzz_ $underlying _ $decimals _from_scaled>]() {
                    differential_fuzz_from_scaled::<$underlying, $decimals>();
                }
            }
        };
    }

    /// `a + b` must equal the exact rational sum whenever it does not panic.
    fn differential_fuzz_add<I, const D: u8>()
    where
        I: ScaledInteger<D> + Arbitrary + std::panic::RefUnwindSafe,
        Rational: From<Decimal<I, D>>,
    {
        proptest!(|(a: Decimal<I, D>, b: Decimal<I, D>)| {
            let out = match std::panic::catch_unwind(|| a + b) {
                Ok(out) => out,
                Err(_) => return Ok(()),
            };
            let reference_out = Rational::from(a) + Rational::from(b);

            assert_eq!(Rational::from(out), reference_out);
        });
    }

    /// `a - b` must equal the exact rational difference whenever it does not panic.
    fn differential_fuzz_sub<I, const D: u8>()
    where
        I: ScaledInteger<D> + Arbitrary + std::panic::RefUnwindSafe,
        Rational: From<Decimal<I, D>>,
    {
        proptest!(|(a: Decimal<I, D>, b: Decimal<I, D>)| {
            let out = match std::panic::catch_unwind(|| a - b) {
                Ok(out) => out,
                Err(_) => return Ok(()),
            };
            let reference_out = Rational::from(a) - Rational::from(b);

            assert_eq!(Rational::from(out), reference_out);
        });
    }

    /// `a * b` must equal the exact rational product whenever it does not panic
    /// and the exact product is representable with `D` decimals.
    fn differential_fuzz_mul<I, const D: u8>()
    where
        I: ScaledInteger<D> + Arbitrary + std::panic::RefUnwindSafe + Into<Integer>,
        Rational: From<Decimal<I, D>>,
    {
        proptest!(|(a: Decimal<I, D>, b: Decimal<I, D>)| {
            let out = match std::panic::catch_unwind(|| a * b) {
                Ok(out) => out,
                Err(_) => return Ok(()),
            };
            let reference_out = Rational::from(a) * Rational::from(b);

            // Skip results whose reduced denominator does not divide the scaling
            // factor: they are not exactly representable with `D` decimals.
            let scaling: Integer = Decimal::<I, D>::SCALING_FACTOR.into();
            let divisor = Integer::from(reference_out.denominator_ref());
            if scaling % divisor != Integer::ZERO {
                return Ok(());
            }

            assert_eq!(
                Rational::from(out),
                reference_out,
                "{} {a:?} {b:?} {out:?} {reference_out:?}",
                I::SCALING_FACTOR
            );
        });
    }

    /// `a / b` must equal the exact rational quotient whenever `b` is non-zero,
    /// the operation does not panic, and the exact quotient is representable.
    fn differential_fuzz_div<I, const D: u8>()
    where
        I: ScaledInteger<D> + Arbitrary + std::panic::RefUnwindSafe + Into<Integer>,
        Rational: From<Decimal<I, D>>,
    {
        proptest!(|(a: Decimal<I, D>, b: Decimal<I, D>)| {
            // The rational reference cannot divide by zero.
            if b == Decimal::ZERO {
                return Ok(());
            }

            let out = match std::panic::catch_unwind(|| a / b) {
                Ok(out) => out,
                Err(_) => return Ok(()),
            };
            let reference_out = Rational::from(a) / Rational::from(b);

            // Skip results whose reduced denominator does not divide the scaling
            // factor: they are not exactly representable with `D` decimals.
            let scaling: Integer = Decimal::<I, D>::SCALING_FACTOR.into();
            let divisor = Integer::from(reference_out.denominator_ref());
            if scaling % divisor != Integer::ZERO {
                return Ok(());
            }

            assert_eq!(Rational::from(out), reference_out);
        });
    }

    /// `a += b` must equal the exact rational sum whenever it does not panic.
    fn differential_fuzz_add_assign<I, const D: u8>()
    where
        I: ScaledInteger<D> + Arbitrary + std::panic::RefUnwindSafe,
        Rational: From<Decimal<I, D>>,
    {
        proptest!(|(a: Decimal<I, D>, b: Decimal<I, D>)| {
            let out = match std::panic::catch_unwind(|| {
                let mut out = a;
                out += b;

                out
            }) {
                Ok(out) => out,
                Err(_) => return Ok(()),
            };
            let reference_out = Rational::from(a) + Rational::from(b);

            assert_eq!(Rational::from(out), reference_out);
        });
    }

    /// `a -= b` must equal the exact rational difference whenever it does not panic.
    fn differential_fuzz_sub_assign<I, const D: u8>()
    where
        I: ScaledInteger<D> + Arbitrary + std::panic::RefUnwindSafe,
        Rational: From<Decimal<I, D>>,
    {
        proptest!(|(a: Decimal<I, D>, b: Decimal<I, D>)| {
            let out = match std::panic::catch_unwind(|| {
                let mut out = a;
                out -= b;

                out
            }) {
                Ok(out) => out,
                Err(_) => return Ok(()),
            };
            let reference_out = Rational::from(a) - Rational::from(b);

            assert_eq!(Rational::from(out), reference_out);
        });
    }

    /// `a *= b` must equal the exact rational product whenever it does not panic
    /// and the exact product is representable with `D` decimals.
    fn differential_fuzz_mul_assign<I, const D: u8>()
    where
        I: ScaledInteger<D> + Arbitrary + std::panic::RefUnwindSafe + Into<Integer>,
        Rational: From<Decimal<I, D>>,
    {
        proptest!(|(a: Decimal<I, D>, b: Decimal<I, D>)| {
            let out = match std::panic::catch_unwind(|| {
                let mut out = a;
                out *= b;

                out
            }) {
                Ok(out) => out,
                Err(_) => return Ok(()),
            };
            let reference_out = Rational::from(a) * Rational::from(b);

            // Skip results whose reduced denominator does not divide the scaling
            // factor: they are not exactly representable with `D` decimals.
            let scaling: Integer = Decimal::<I, D>::SCALING_FACTOR.into();
            let divisor = Integer::from(reference_out.denominator_ref());
            if scaling % divisor != Integer::ZERO {
                return Ok(());
            }

            assert_eq!(Rational::from(out), reference_out);
        });
    }

    /// `a /= b` must equal the exact rational quotient whenever `b` is non-zero,
    /// the operation does not panic, and the exact quotient is representable.
    fn differential_fuzz_div_assign<I, const D: u8>()
    where
        I: ScaledInteger<D> + Arbitrary + std::panic::RefUnwindSafe + Into<Integer>,
        Rational: From<Decimal<I, D>>,
    {
        proptest!(|(a: Decimal<I, D>, b: Decimal<I, D>)| {
            // The rational reference cannot divide by zero; mirror the guard in
            // `differential_fuzz_div` so this never depends on `/=` panicking.
            if b == Decimal::ZERO {
                return Ok(());
            }

            let out = match std::panic::catch_unwind(|| {
                let mut out = a;
                out /= b;

                out
            }) {
                Ok(out) => out,
                Err(_) => return Ok(()),
            };
            let reference_out = Rational::from(a) / Rational::from(b);

            // Skip results whose reduced denominator does not divide the scaling
            // factor: they are not exactly representable with `D` decimals.
            let scaling: Integer = Decimal::<I, D>::SCALING_FACTOR.into();
            let divisor = Integer::from(reference_out.denominator_ref());
            if scaling % divisor != Integer::ZERO {
                return Ok(());
            }

            assert_eq!(Rational::from(out), reference_out);
        });
    }

    /// `try_from_scaled` must return the exact rational value when it succeeds,
    /// and may only fail when the value needs more than `D` decimals or falls
    /// outside the range of `I` once rescaled.
    fn differential_fuzz_from_scaled<I, const D: u8>()
    where
        I: ScaledInteger<D> + Arbitrary + std::panic::RefUnwindSafe + Into<Integer> + TryInto<u64>,
        Rational: From<I> + From<Decimal<I, D>>,
        <I as TryInto<u64>>::Error: Debug,
    {
        proptest!(|(integer: I, decimals_percent in 0..100u64)| {
            let max_decimals: u64 = crate::algorithms::log10(I::max_value()).try_into().unwrap();
            let decimals = u8::try_from(decimals_percent * max_decimals / 100).unwrap();
            let scaling = I::TEN.pow(decimals as u32);

            let out = Decimal::try_from_scaled(integer, decimals);
            let reference_out = Rational::from_integers(integer.into(), scaling.into());

            match out {
                Some(out) => assert_eq!(Rational::from(out), reference_out),
                None => {
                    let scaling: Integer = Decimal::<I, D>::SCALING_FACTOR.into();
                    let remainder = &scaling % Integer::from(reference_out.denominator_ref());
                    let information = &reference_out * Rational::from(scaling);

                    assert!(
                        remainder != 0
                            || information > Rational::from(I::max_value())
                            || information < Rational::from(I::min_value()),
                        "Failed to parse valid input; integer={integer}; input_scale={decimals}; \
                        output_scale={D}",
                    );
                }
            }
        });
    }

    crate::macros::apply_to_common_variants!(test_basic_ops);
    crate::macros::apply_to_common_variants!(fuzz_against_primitive);
    crate::macros::apply_to_common_variants!(differential_fuzz);
}