fast_posit/posit/ops/
mul.rs

1use super::*;
2
3impl<
4  const N: u32,
5  const ES: u32,
6  Int: crate::Int,
7> Posit<N, ES, Int> {
8  #[inline]
9  pub(crate) unsafe fn mul_kernel(x: Decoded<N, ES, Int>, y: Decoded<N, ES, Int>) -> (Decoded<N, ES, Int>, Int) {
10    // Multiplying two numbers in the form `frac × 2^exp` is much easier than adding them. We have
11    //
12    //   (x.frac / FRAC_DENOM * 2^x.exp) * (y.frac / FRAC_DENOM * 2^y.exp)
13    //   = (x.frac * y.frac) / FRAC_DENOM² * 2^(x.exp + y.exp)
14    //   = (x.frac * y.frac / FRAC_DENOM) / FRAC_DENOM * 2^(x.exp + y.exp)
15    //
16    // In other words: the resulting `exp` is just the sum of the `exp`s, and the `frac` is the
17    // product of the `frac`s divided by `FRAC_DENOM`. Since we know `FRAC_DENOM` = `2^FRAC_WIDTH`
18    // = `2^(Int::BITS - 2)`, we can re-arrange the expression one more time:
19    //
20    //   = (x.frac * y.frac / 2^FRAC_WIDTH) / FRAC_DENOM * 2^(x.exp + y.exp)
21    //   = ((x.frac * y.frac) >> Int::BITS) / FRAC_DENOM * 2^(x.exp + y.exp + 2)
22    //
23    // Meaning the result has
24    //
25    //   frac = (x.frac * y.frac) >> Int::BITS
26    //    exp = x.exp + y.exp + 2
27    //
28    // Only a couple other points to keep in mind:
29    //
30    //   - The multiplication must use a type with double the precision of `Int`, so that there is
31    //     no chance of overflow.
32    //   - When we shift the frac right by `Int::BITS`, we must also accumulate the lower
33    //     `Int::BITS` to `sticky`.
34    //   - The `frac` must start with `0b01` or `0b10`, i.e. it must represent a `frac` in the
35    //     range [1., 2.[ or [-2., 1.[, but the result of multiplying the `frac`s may not. When
36    //     that happens, we may need to shift 1 or 2 places left. For example: 1. × 1. = 1., but
37    //     1.5 × 1.5 = 2.25; the former is good, the latter needs an extra shift by 1 to become
38    //     1.125. Of course, if we shift the `frac` left by n places we must subtract n from `exp`.
39    //
40    // Keeping these points in mind, the final result is
41    //
42    //   frac = (x.frac * y.frac) << underflow >> Int::BITS
43    //    exp = x.exp + y.exp + 2 - underflow
44
45    use crate::underlying::Double;
46    let mul = x.frac.doubling_mul(y.frac);
47    // SAFETY: `x.frac` and `y.frac` are not 0, so their product cannot be 0; nor can it ever be MIN
48    let underflow = unsafe { mul.leading_run_minus_one() };  // Can only be 0,1,2, optimise?
49    let (frac, sticky) = (mul << underflow).components_hi_lo();
50    let exp = x.exp + y.exp + Int::ONE + Int::ONE - Int::of_u32(underflow);
51
52    (Decoded{frac, exp}, sticky)
53  }
54
55  pub(crate) fn mul(self, other: Self) -> Self {
56    if self == Self::NAR || other == Self::NAR {
57      Self::NAR
58    } else if self == Self::ZERO || other == Self::ZERO {
59      Self::ZERO
60    } else {
61      // SAFETY: neither `self` nor `other` are 0 or NaR
62      let a = unsafe { self.decode_regular() };
63      let b = unsafe { other.decode_regular() };
64      // SAFETY: `self` and `other` aren't symmetrical
65      let (result, sticky) = unsafe { Self::mul_kernel(a, b) };
66      // SAFETY: `result` does not have an underflowing `frac`
67      unsafe { result.encode_regular_round(sticky) }
68    }
69  }
70}
71
72use core::ops::{Mul, MulAssign};
73
74super::mk_ops!{Mul, MulAssign, mul, mul_assign}
75
#[cfg(test)]
mod tests {
  use super::*;
  use malachite::rational::Rational;
  use proptest::prelude::*;

  /// Type-checks `*` and `*=` with every combination of owned and borrowed operands. It is not a
  /// `#[test]`; it only needs to compile, hence `dead_code` is allowed.
  #[allow(dead_code)]
  fn ops() {
    let mut a = crate::p32::ONE;
    let mut b = crate::p32::MINUS_ONE;
    let _ = a * b;
    let _ = &a * b;
    let _ = a * &b;
    let _ = &a * &b;
    a *= b;
    b *= &a;
  }

  /// Aux function: check that `a * b` is rounded correctly.
  ///
  /// If both operands convert to a `Rational`, the posit product must be the correct rounding of
  /// the exact rational product. If either conversion fails (presumably only for NaR), the posit
  /// product must be NaR.
  fn is_correct_rounded<const N: u32, const ES: u32, Int: crate::Int>(
    a: Posit<N, ES, Int>,
    b: Posit<N, ES, Int>,
  ) -> bool
  where
    Rational: TryFrom<Posit<N, ES, Int>>,
    <Rational as TryFrom<Posit<N, ES, Int>>>::Error: core::fmt::Debug
  {
    let mul_posit = a * b;
    if let (Ok(a), Ok(b)) = (Rational::try_from(a), Rational::try_from(b)) {
      let mul_exact = a * b;
      super::rational::is_correct_rounded(mul_exact, mul_posit)
    } else {
      mul_posit == Posit::<N, ES, Int>::NAR
    }
  }

  // Generates one `#[test]` fn per `name: PositType` entry. Each one checks `is_correct_rounded`
  // for every pair of values yielded by `PositType::cases_exhaustive_all()`.
  macro_rules! exhaustive_tests {
    ($($name:ident: $posit:ty),* $(,)?) => {$(
      #[test]
      fn $name() {
        for a in <$posit>::cases_exhaustive_all() {
          for b in <$posit>::cases_exhaustive_all() {
            assert!(is_correct_rounded(a, b), "{:?} * {:?}", a, b)
          }
        }
      }
    )*};
  }

  exhaustive_tests! {
    posit_3_0_exhaustive: Posit<3, 0, i8>,
    posit_4_0_exhaustive: Posit<4, 0, i8>,
    posit_4_1_exhaustive: Posit<4, 1, i8>,
    posit_8_0_exhaustive: Posit<8, 0, i8>,
    posit_10_0_exhaustive: Posit<10, 0, i16>,
    posit_10_1_exhaustive: Posit<10, 1, i16>,
    posit_10_2_exhaustive: Posit<10, 2, i16>,
    posit_10_3_exhaustive: Posit<10, 3, i16>,
    p8_exhaustive: crate::p8,
  }

  const PROPTEST_CASES: u32 = if cfg!(debug_assertions) {0x1_0000} else {0x80_0000};
  proptest!{
    #![proptest_config(ProptestConfig::with_cases(PROPTEST_CASES))]

    #[test]
    fn p16_proptest(
      a in crate::p16::cases_proptest(),
      b in crate::p16::cases_proptest(),
    ) {
      assert!(is_correct_rounded(a, b), "{:?} * {:?}", a, b)
    }

    #[test]
    fn p32_proptest(
      a in crate::p32::cases_proptest(),
      b in crate::p32::cases_proptest(),
    ) {
      assert!(is_correct_rounded(a, b), "{:?} * {:?}", a, b)
    }

    #[test]
    fn p64_proptest(
      a in crate::p64::cases_proptest(),
      b in crate::p64::cases_proptest(),
    ) {
      assert!(is_correct_rounded(a, b), "{:?} * {:?}", a, b)
    }
  }
}
219    }
220  }
221}