arpfloat/
arithmetic.rs

1//! This module contains the implementation of the basic arithmetic operations:
2//! Addition, Subtraction, Multiplication, Division.
3extern crate alloc;
4use crate::bigint::BigInt;
5
6use super::bigint::LossFraction;
7use super::float::{Category, Float, RoundingMode};
8use core::cmp::Ordering;
9use core::ops::{
10    Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign,
11};
12
13impl Float {
14    /// An inner function that performs the addition and subtraction of normal
15    /// numbers (no NaN, Inf, Zeros).
16    /// See Pg 247.  Chapter 8. Algorithms for the Five Basic Operations.
17    /// This implementation follows the APFloat implementation, that does not
18    /// swap the operands.
19    fn add_or_sub_normals(
20        a: &Self,
21        b: &Self,
22        subtract: bool,
23    ) -> (Self, LossFraction) {
24        debug_assert_eq!(a.get_semantics(), b.get_semantics());
25        let sem = a.get_semantics();
26        let loss;
27        let mut a = a.clone();
28        let mut b = b.clone();
29
30        // Align the input numbers on the same exponent.
31        let bits = a.get_exp() - b.get_exp();
32
33        // Can transform (a-b) to (a + -b), either way, there are cases where
34        // subtraction needs to happen.
35        let subtract = subtract ^ (a.get_sign() ^ b.get_sign());
36        if subtract {
37            // Align the input numbers. We shift LHS one bit to the left to
38            // allow carry/borrow in case of underflow as result of subtraction.
39            match bits.cmp(&0) {
40                Ordering::Equal => {
41                    loss = LossFraction::ExactlyZero;
42                }
43                Ordering::Greater => {
44                    loss = b.shift_significand_right((bits - 1) as u64);
45                    a.shift_significand_left(1);
46                }
47                Ordering::Less => {
48                    loss = a.shift_significand_right((-bits - 1) as u64);
49                    b.shift_significand_left(1);
50                }
51            }
52
53            let a_mantissa = a.get_mantissa();
54            let b_mantissa = b.get_mantissa();
55            let ab_mantissa;
56            let mut sign = a.get_sign();
57
58            // Figure out the carry from the shifting operations that dropped
59            // bits.
60            let c = !loss.is_exactly_zero() as u64;
61            let c = BigInt::from_u64(c);
62
63            // Figure out which mantissa is larger, to make sure that we don't
64            // overflow the subtraction.
65            if a_mantissa < b_mantissa {
66                // A < B
67                ab_mantissa = b_mantissa - a_mantissa - c;
68                sign = !sign;
69            } else {
70                // A >= B
71                ab_mantissa = a_mantissa - b_mantissa - c;
72            }
73            (
74                Self::from_parts(sem, sign, a.get_exp(), ab_mantissa),
75                loss.invert(),
76            )
77        } else {
78            // Handle the easy case of Add:
79            let mut b = b.clone();
80            let mut a = a.clone();
81            if bits > 0 {
82                loss = b.shift_significand_right(bits as u64);
83            } else {
84                loss = a.shift_significand_right(-bits as u64);
85            }
86            debug_assert!(a.get_exp() == b.get_exp());
87            let ab_mantissa = a.get_mantissa() + b.get_mantissa();
88            (
89                Self::from_parts(sem, a.get_sign(), a.get_exp(), ab_mantissa),
90                loss,
91            )
92        }
93    }
94
95    /// Computes a+b using the rounding mode `rm`.
96    pub fn add_with_rm(a: &Self, b: &Self, rm: RoundingMode) -> Self {
97        Self::add_sub(a, b, false, rm)
98    }
99    /// Computes a-b using the rounding mode `rm`.
100    pub fn sub_with_rm(a: &Self, b: &Self, rm: RoundingMode) -> Self {
101        Self::add_sub(a, b, true, rm)
102    }
103
104    fn add_sub(a: &Self, b: &Self, subtract: bool, rm: RoundingMode) -> Self {
105        let sem = a.get_semantics();
106        // Table 8.2: Specification of addition for positive floating-point
107        // data. Pg 247.
108        match (a.get_category(), b.get_category()) {
109            (Category::NaN, Category::Infinity)
110            | (Category::NaN, Category::NaN)
111            | (Category::NaN, Category::Normal)
112            | (Category::NaN, Category::Zero)
113            | (Category::Normal, Category::Zero)
114            | (Category::Infinity, Category::Normal)
115            | (Category::Infinity, Category::Zero) => a.clone(),
116
117            (Category::Zero, Category::NaN)
118            | (Category::Normal, Category::NaN)
119            | (Category::Infinity, Category::NaN) => {
120                Self::nan(sem, b.get_sign())
121            }
122
123            (Category::Normal, Category::Infinity)
124            | (Category::Zero, Category::Infinity) => {
125                Self::inf(sem, b.get_sign() ^ subtract)
126            }
127
128            (Category::Zero, Category::Normal) => Self::from_parts(
129                sem,
130                b.get_sign() ^ subtract,
131                b.get_exp(),
132                b.get_mantissa(),
133            ),
134
135            (Category::Zero, Category::Zero) => {
136                Self::zero(sem, a.get_sign() && b.get_sign())
137            }
138
139            (Category::Infinity, Category::Infinity) => {
140                if a.get_sign() ^ b.get_sign() ^ subtract {
141                    return Self::nan(sem, a.get_sign() ^ b.get_sign());
142                }
143                Self::inf(sem, a.get_sign())
144            }
145
146            (Category::Normal, Category::Normal) => {
147                // The IEEE 754 spec (section 6.3) states that cancellation
148                // results in a positive zero, except for the case of the
149                // negative rounding mode.
150                let cancellation = subtract == (a.get_sign() == b.get_sign());
151                let same_absolute_number = a.same_absolute_value(b);
152                if cancellation && same_absolute_number {
153                    let is_negative = RoundingMode::Negative == rm;
154                    return Self::zero(sem, is_negative);
155                }
156
157                let mut res = Self::add_or_sub_normals(a, b, subtract);
158                res.0.normalize(rm, res.1);
159                res.0
160            }
161        }
162    }
163}
164
#[test]
fn test_add() {
    use super::float::FP64;
    // Smoke test: adding two small integers must not panic.
    let lhs = Float::from_u64(FP64, 1);
    let rhs = Float::from_u64(FP64, 2);
    let _sum = Float::add(lhs, rhs);
}
172
#[test]
fn test_addition() {
    // Adds two f64 values by round-tripping them through Float.
    fn add_helper(a: f64, b: f64) -> f64 {
        Float::add(Float::from_f64(a), Float::from_f64(b)).as_f64()
    }

    // (lhs, rhs, expected sum)
    let cases: &[(f64, f64, f64)] = &[
        (0., -4., -4.),
        (-4., 0., -4.),
        (1., 1., 2.),
        (8., 4., 12.),
        (128., 2., 130.),
        (128., -8., 120.),
        (64., -60., 4.),
        (69., -65., 4.),
        (69., 69., 138.),
        (69., 1., 70.),
        (-128., -8., -136.),
        (64., -65., -1.),
        (-64., -65., -129.),
        (-15., -15., -30.),
        (-15., 15., 0.),
    ];
    for &(lhs, rhs, expected) in cases {
        assert_eq!(add_helper(lhs, rhs), expected);
    }

    // Exhaustively compare small integer sums against native f64 addition.
    for i in -4..15 {
        for j in i..15 {
            let (fi, fj) = (f64::from(i), f64::from(j));
            assert_eq!(add_helper(fj, fi), fi + fj);
        }
    }

    // Adding a value to its negation, or subtracting it from itself, must
    // produce a positive zero under the default rounding mode.
    let four = Float::from_f64(4.0);
    let minus_four = Float::from_f64(-4.0);
    let sum = Float::add(four.clone(), minus_four);
    let diff = Float::sub(four.clone(), four);
    for zero in [&sum, &diff] {
        assert!(zero.is_zero());
        assert!(!zero.is_negative());
    }
}
220
// Pg 120.  Chapter 4. Basic Properties and Algorithms.
// Finds the first power of two `a` for which (a + 1) - a != 1, i.e. where the
// spacing between FP64 values exceeds one, and then the smallest positive
// integer `b` for which (a + b) - a == b.
#[test]
fn test_addition_large_numbers() {
    use super::float::FP64;
    let rm = RoundingMode::NearestTiesToEven;

    // (x + y) - x, rounding after each operation.
    let add_then_sub = |x: &Float, y: &Float| {
        Float::sub_with_rm(&Float::add_with_rm(x, y, rm), x, rm)
    };

    let one = Float::from_i64(FP64, 1);

    let mut a = one.clone();
    while add_then_sub(&a, &one) == one {
        a = Float::add_with_rm(&a, &a, rm);
    }

    let mut b = one.clone();
    while add_then_sub(&a, &b) != b {
        b = Float::add_with_rm(&b, &one, rm);
    }

    assert_eq!(a.as_f64(), 9007199254740992.);
    assert_eq!(b.as_f64(), 2.);
}
242
#[test]
fn add_denormals() {
    // Adds two f64 values through Float, first checking that `a` converts
    // into a Float losslessly.
    fn add_f64(a: f64, b: f64) -> f64 {
        let lhs = Float::from_f64(a);
        let rhs = Float::from_f64(b);
        assert_eq!(lhs.as_f64(), a);
        Float::add(lhs, rhs).as_f64()
    }

    // v0 and v1 have a zero exponent field (subnormals); v2 has a non-zero
    // exponent field and is therefore a normal number.
    let v0 = f64::from_bits(0x0000_0000_0010_0010);
    let v1 = f64::from_bits(0x0000_0000_1001_0010);
    let v2 = f64::from_bits(0x1000_0000_0001_0010);
    assert_eq!(add_f64(v2, -v1), v2 - v1);

    // A subnormal must survive the round trip through Float.
    assert_eq!(Float::from_f64(v0).as_f64(), v0);

    // (lhs, rhs, expected sum)
    let cases = [
        // Add and subtract subnormals, and mix them with v2.
        (v0, v1, v0 + v1),
        (v0, -v0, v0 - v0),
        (v0, v2, v0 + v2),
        (v2, v1, v2 + v1),
        (v2, -v1, v2 - v1),
        // Add and subtract subnormals and larger normal numbers.
        (v0, 10., v0 + 10.),
        (v0, -10., v0 - 10.),
        (10000., v0, 10000. + v0),
    ];
    for (lhs, rhs, expected) in cases {
        assert_eq!(add_f64(lhs, rhs), expected);
    }
}
272
#[cfg(feature = "std")]
#[test]
fn add_special_values() {
    use crate::utils;

    // Adds two f64 values by round-tripping them through Float.
    fn add_f64(a: f64, b: f64) -> f64 {
        Float::add(Float::from_f64(a), Float::from_f64(b)).as_f64()
    }

    // Compare the addition of various irregular values with native f64.
    let values = utils::get_special_test_values();
    for v0 in values {
        for v1 in values {
            let got = add_f64(v0, v1);
            let want = v0 + v1;
            assert_eq!(got.is_finite(), want.is_finite());
            assert_eq!(got.is_nan(), want.is_nan());
            assert_eq!(got.is_infinite(), want.is_infinite());
            assert_eq!(got.is_normal(), want.is_normal());
            // Bit-exact equality is only enforced for normal results. Zeros,
            // subnormals, infinities and NaNs are only classified above.
            assert!(!got.is_normal() || got.to_bits() == want.to_bits());
        }
    }
}
302
#[test]
fn test_add_random_vals() {
    use crate::utils;

    // Adds two f64 values by round-tripping them through Float.
    fn add_f64(a: f64, b: f64) -> f64 {
        Float::add(Float::from_f64(a), Float::from_f64(b)).as_f64()
    }

    // Compares Float addition with native f64 addition for one pair of bit
    // patterns.
    fn check(bits0: u64, bits1: u64) {
        let f0 = f64::from_bits(bits0);
        let f1 = f64::from_bits(bits1);
        let got = add_f64(f0, f1);
        let want = f0 + f1;
        assert_eq!(got.is_finite(), want.is_finite());
        assert_eq!(got.is_nan(), want.is_nan());
        assert_eq!(got.is_infinite(), want.is_infinite());
        // The results must be bit identical unless the expected value is NaN.
        assert!(want.is_nan() || got.to_bits() == want.to_bits());
    }

    // A fixed pair of inputs, checked before the random ones.
    check(0x645e91f69778bad3, 0xe4d91b16be9ae0c5);

    let mut lfsr = utils::Lfsr::new();
    for _ in 0..50000 {
        let bits0 = lfsr.get64();
        let bits1 = lfsr.get64();
        check(bits0, bits1);
    }
}
352
353impl Float {
354    /// Compute a*b using the rounding mode `rm`.
355    pub fn mul_with_rm(a: &Self, b: &Self, rm: RoundingMode) -> Self {
356        let sem = a.get_semantics();
357        let sign = a.get_sign() ^ b.get_sign();
358
359        // Table 8.4: Specification of multiplication for floating-point data of
360        // positive sign. Page 251.
361        match (a.get_category(), b.get_category()) {
362            (Category::Zero, Category::NaN)
363            | (Category::Normal, Category::NaN)
364            | (Category::Infinity, Category::NaN) => {
365                Self::nan(sem, b.get_sign())
366            }
367            (Category::NaN, Category::Infinity)
368            | (Category::NaN, Category::NaN)
369            | (Category::NaN, Category::Normal)
370            | (Category::NaN, Category::Zero) => Self::nan(sem, a.get_sign()),
371            (Category::Normal, Category::Infinity)
372            | (Category::Infinity, Category::Normal)
373            | (Category::Infinity, Category::Infinity) => Self::inf(sem, sign),
374            (Category::Normal, Category::Zero)
375            | (Category::Zero, Category::Normal)
376            | (Category::Zero, Category::Zero) => Self::zero(sem, sign),
377
378            (Category::Zero, Category::Infinity)
379            | (Category::Infinity, Category::Zero) => Self::nan(sem, sign),
380
381            (Category::Normal, Category::Normal) => {
382                let (mut res, loss) = Self::mul_normals(a, b, sign);
383                res.normalize(rm, loss);
384                res
385            }
386        }
387    }
388
389    /// See Pg 251. 8.4 Floating-Point Multiplication
390    fn mul_normals(a: &Self, b: &Self, sign: bool) -> (Self, LossFraction) {
391        debug_assert_eq!(a.get_semantics(), b.get_semantics());
392        let sem = a.get_semantics();
393        // We multiply digits in the format 1.xx * 2^(e), or mantissa * 2^(e+1).
394        // When we multiply two 2^(e+1) numbers, we get:
395        // log(2^(e_a+1)*2^(e_b+1)) = e_a + e_b + 2.
396        let mut exp = a.get_exp() + b.get_exp();
397
398        let a_significand = a.get_mantissa();
399        let b_significand = b.get_mantissa();
400        let ab_significand = a_significand * b_significand;
401
402        // The exponent is correct, but the bits are not in the right place.
403        // Set the right exponent for where the bits are placed, and fix the
404        // exponent below.
405        exp -= sem.get_mantissa_len() as i64;
406
407        let loss = LossFraction::ExactlyZero;
408        (Self::from_parts(sem, sign, exp, ab_significand), loss)
409    }
410}
411
#[test]
fn test_mul_simple() {
    let (a, b): (f64, f64) = (-24.0, 0.1);
    let product = Float::mul(Float::from_f64(a), Float::from_f64(b));
    assert_eq!(product.as_f64(), a * b);
}
425
#[test]
fn mul_regular_values() {
    // Test the multiplication of regular values, including signed zeros.
    let values: [f64; 11] =
        [-5.0, 0., -0., 24., 1., 11., 10000., 256., 0.1, 3., 17.5];

    // Multiplies two f64 values by round-tripping them through Float.
    fn mul_f64(a: f64, b: f64) -> f64 {
        Float::mul(Float::from_f64(a), Float::from_f64(b)).as_f64()
    }

    for &lhs in &values {
        for &rhs in &values {
            // Every product must be bit identical to the native result; this
            // also checks the sign of zero products.
            assert_eq!(mul_f64(lhs, rhs).to_bits(), (lhs * rhs).to_bits());
        }
    }
}
448
#[cfg(feature = "std")]
#[test]
fn test_mul_special_values() {
    use super::utils;

    // Multiplies two f64 values by round-tripping them through Float.
    fn mul_f64(a: f64, b: f64) -> f64 {
        Float::mul(Float::from_f64(a), Float::from_f64(b)).as_f64()
    }

    // Compare the multiplication of various irregular values with native f64.
    let values = utils::get_special_test_values();
    for v0 in values {
        for v1 in values {
            let got = mul_f64(v0, v1);
            let want = v0 * v1;
            assert_eq!(got.is_finite(), want.is_finite());
            assert_eq!(got.is_nan(), want.is_nan());
            assert_eq!(got.is_infinite(), want.is_infinite());
            // Bit identical, unless the expected result is NaN.
            assert!(want.is_nan() || got.to_bits() == want.to_bits());
        }
    }
}
477
#[test]
fn test_mul_random_vals() {
    use super::utils;

    // Multiplies two f64 values by round-tripping them through Float.
    fn mul_f64(a: f64, b: f64) -> f64 {
        Float::mul(Float::from_f64(a), Float::from_f64(b)).as_f64()
    }

    let mut lfsr = utils::Lfsr::new();
    for _ in 0..50000 {
        // Interpret two random words as f64 bit patterns.
        let f0 = f64::from_bits(lfsr.get64());
        let f1 = f64::from_bits(lfsr.get64());

        let got = mul_f64(f0, f1);
        let want = f0 * f1;
        assert_eq!(got.is_finite(), want.is_finite());
        assert_eq!(got.is_nan(), want.is_nan());
        assert_eq!(got.is_infinite(), want.is_infinite());
        // Bit identical, unless the expected result is NaN.
        assert!(want.is_nan() || got.to_bits() == want.to_bits());
    }
}
509
510impl Float {
511    /// Compute a/b, with the rounding mode `rm`.
512    pub fn div_with_rm(a: &Self, b: &Self, rm: RoundingMode) -> Self {
513        let sem = a.get_semantics();
514        let sign = a.get_sign() ^ b.get_sign();
515        // Table 8.5: Special values for x/y - Page 263.
516        match (a.get_category(), b.get_category()) {
517            (Category::NaN, _)
518            | (_, Category::NaN)
519            | (Category::Zero, Category::Zero)
520            | (Category::Infinity, Category::Infinity) => Self::nan(sem, sign),
521
522            (_, Category::Infinity) => Self::zero(sem, sign),
523            (Category::Zero, _) => Self::zero(sem, sign),
524            (_, Category::Zero) => Self::inf(sem, sign),
525            (Category::Infinity, _) => Self::inf(sem, sign),
526            (Category::Normal, Category::Normal) => {
527                let (mut res, loss) = Self::div_normals(a, b);
528                res.normalize(rm, loss);
529                res
530            }
531        }
532    }
533
534    /// Compute a/b, where both `a` and `b` are normals.
535    /// Page 262 8.6. Floating-Point Division.
536    /// This implementation uses a regular integer division for the mantissa.
537    fn div_normals(a: &Self, b: &Self) -> (Self, LossFraction) {
538        debug_assert_eq!(a.get_semantics(), b.get_semantics());
539        let sem = a.get_semantics();
540
541        let mut a = a.clone();
542        let mut b = b.clone();
543        // Start by normalizing the dividend and divisor to the MSB.
544        a.align_mantissa(); // Normalize the dividend.
545        b.align_mantissa(); // Normalize the divisor.
546
547        let mut a_mantissa = a.get_mantissa();
548        let b_mantissa = b.get_mantissa();
549
550        // Calculate the sign and exponent.
551        let mut exp = a.get_exp() - b.get_exp();
552        let sign = a.get_sign() ^ b.get_sign();
553
554        // Make sure that A >= B, to allow the integer division to generate all
555        // of the bits of the result.
556        if a_mantissa < b_mantissa {
557            a_mantissa.shift_left(1);
558            exp -= 1;
559        }
560
561        // The bits are now aligned to the MSB of the mantissa. The
562        // semantics need to be 1.xxxxx, but we perform integer division.
563        // Shift the dividend to make sure that we generate the bits after
564        // the period.
565        a_mantissa.shift_left(sem.get_mantissa_len());
566        let reminder = a_mantissa.inplace_div(&b_mantissa);
567
568        // Find 2 x reminder, to be able to compare to the reminder and figure
569        // out the kind of loss that we have.
570        let mut reminder_2x = reminder;
571        reminder_2x.shift_left(1);
572
573        let reminder = reminder_2x.cmp(&b_mantissa);
574        let is_zero = reminder_2x.is_zero();
575        let loss = match reminder {
576            Ordering::Less => {
577                if is_zero {
578                    LossFraction::ExactlyZero
579                } else {
580                    LossFraction::LessThanHalf
581                }
582            }
583            Ordering::Equal => LossFraction::ExactlyHalf,
584            Ordering::Greater => LossFraction::MoreThanHalf,
585        };
586
587        let x = Self::from_parts(sem, sign, exp, a_mantissa);
588        (x, loss)
589    }
590}
591
#[test]
fn test_div_simple() {
    let (a, b): (f64, f64) = (1.0, 7.0);
    let quotient = Float::div_with_rm(
        &Float::from_f64(a),
        &Float::from_f64(b),
        RoundingMode::NearestTiesToEven,
    );
    assert_eq!(quotient.as_f64(), a / b);
}
605
#[cfg(feature = "std")]
#[test]
fn test_div_special_values() {
    use super::utils;

    // Divides two f64 values by round-tripping them through Float.
    fn div_f64(a: f64, b: f64) -> f64 {
        let lhs = Float::from_f64(a);
        let rhs = Float::from_f64(b);
        Float::div_with_rm(&lhs, &rhs, RoundingMode::NearestTiesToEven).as_f64()
    }

    // Compare the division of various irregular values with native f64.
    let values = utils::get_special_test_values();
    for v0 in values {
        for v1 in values {
            let got = div_f64(v0, v1);
            let want = v0 / v1;
            assert_eq!(got.is_finite(), want.is_finite());
            assert_eq!(got.is_nan(), want.is_nan());
            assert_eq!(got.is_infinite(), want.is_infinite());
            // Bit identical, unless the expected result is NaN.
            assert!(want.is_nan() || got.to_bits() == want.to_bits());
        }
    }
}
634
635macro_rules! declare_operator {
636    ($trait_name:ident,
637     $func_name:ident,
638     $func_impl_name:ident) => {
639        // Self + Self
640        impl $trait_name for Float {
641            type Output = Self;
642            fn $func_name(self, rhs: Self) -> Self {
643                let sem = self.get_semantics();
644                Self::$func_impl_name(&self, &rhs, sem.get_rounding_mode())
645            }
646        }
647
648        // Self + u64
649        impl $trait_name<u64> for Float {
650            type Output = Self;
651            fn $func_name(self, rhs: u64) -> Self {
652                let sem = self.get_semantics();
653                Self::$func_impl_name(
654                    &self,
655                    &Self::Output::from_u64(sem, rhs),
656                    sem.get_rounding_mode(),
657                )
658            }
659        }
660        // &Self + &Self
661        impl $trait_name<Self> for &Float {
662            type Output = Float;
663            fn $func_name(self, rhs: Self) -> Self::Output {
664                let sem = self.get_semantics();
665                Self::Output::$func_impl_name(
666                    &self,
667                    rhs,
668                    sem.get_rounding_mode(),
669                )
670            }
671        }
672        // &Self + u64
673        impl $trait_name<u64> for &Float {
674            type Output = Float;
675            fn $func_name(self, rhs: u64) -> Self::Output {
676                let sem = self.get_semantics();
677                Self::Output::$func_impl_name(
678                    &self,
679                    &Self::Output::from_u64(self.get_semantics(), rhs),
680                    sem.get_rounding_mode(),
681                )
682            }
683        }
684
685        // &Self + Self
686        impl $trait_name<Float> for &Float {
687            type Output = Float;
688            fn $func_name(self, rhs: Float) -> Self::Output {
689                let sem = self.get_semantics();
690                Self::Output::$func_impl_name(
691                    &self,
692                    &rhs,
693                    sem.get_rounding_mode(),
694                )
695            }
696        }
697    };
698}
699
700declare_operator!(Add, add, add_with_rm);
701declare_operator!(Sub, sub, sub_with_rm);
702declare_operator!(Mul, mul, mul_with_rm);
703declare_operator!(Div, div, div_with_rm);
704
705macro_rules! declare_assign_operator {
706    ($trait_name:ident,
707     $func_name:ident,
708     $func_impl_name:ident) => {
709        impl $trait_name for Float {
710            fn $func_name(&mut self, rhs: Self) {
711                let sem = self.get_semantics();
712                *self =
713                    Self::$func_impl_name(self, &rhs, sem.get_rounding_mode());
714            }
715        }
716
717        impl $trait_name<&Float> for Float {
718            fn $func_name(&mut self, rhs: &Self) {
719                let sem = self.get_semantics();
720                *self =
721                    Self::$func_impl_name(self, rhs, sem.get_rounding_mode());
722            }
723        }
724    };
725}
726
727declare_assign_operator!(AddAssign, add_assign, add_with_rm);
728declare_assign_operator!(SubAssign, sub_assign, sub_with_rm);
729declare_assign_operator!(MulAssign, mul_assign, mul_with_rm);
730declare_assign_operator!(DivAssign, div_assign, div_with_rm);
731
#[test]
fn test_operators() {
    use crate::FP64;
    let a = Float::from_f32(8.0).cast(FP64);
    let b = Float::from_f32(2.0).cast(FP64);
    // Exercise the `&Float op &Float` operator implementations.
    let results = [
        ((&a + &b).as_f64(), 10.0),
        ((&a - &b).as_f64(), 6.0),
        ((&a * &b).as_f64(), 16.0),
        ((&a / &b).as_f64(), 4.0),
    ];
    for (got, want) in results {
        assert_eq!(got, want);
    }
}
746
#[test]
fn test_slow_sqrt_2_test() {
    use crate::FP128;
    use crate::FP64;

    // Approximate sqrt(2) by bisecting the interval [1, 2] in FP128.
    let two = Float::from_f64(2.0).cast(FP128);
    let mut low = Float::from_f64(1.0).cast(FP128);
    let mut high = two.clone();

    for _ in 0..25 {
        let mid = (&high + &low) / 2;
        let mid_squared = &mid * &mid;
        if mid_squared < two {
            low = mid;
        } else {
            high = mid;
        }
    }

    let approx = low.cast(FP64).as_f64();
    assert!(approx < 1.4142137_f64);
    assert!(approx > 1.4142134_f64);
}
770
#[cfg(feature = "std")]
#[test]
fn test_famous_pentium4_bug() {
    use crate::std::string::ToString;
    use crate::FP128;
    // https://en.wikipedia.org/wiki/Pentium_FDIV_bug
    // The original Pentium FPU returned wrong digits for this division; make
    // sure the leading digits are correct here.
    let dividend = Float::from_u64(FP128, 4_195_835);
    let divisor = Float::from_u64(FP128, 3_145_727);
    let text = (dividend / divisor).to_string();
    assert!(text.starts_with("1.333820449136241002"));
}
784
785impl Float {
786    // Perform a fused multiply-add of normal numbers, without rounding.
787    fn fused_mul_add_normals(
788        a: &Self,
789        b: &Self,
790        c: &Self,
791    ) -> (Self, LossFraction) {
792        debug_assert_eq!(a.get_semantics(), b.get_semantics());
793        let sem = a.get_semantics();
794
795        // Multiply a and b, without rounding.
796        let sign = a.get_sign() ^ b.get_sign();
797        let mut ab = Self::mul_normals(a, b, sign).0;
798
799        // Shift the product, to allow enough precision for the addition.
800        // Notice that this can be implemented more efficiently with 3 extra
801        // bits and sticky bits.
802        // See 8.5. Floating-Point Fused Multiply-Add, Page 255.
803        let mut c = c.clone();
804        let extra_bits = sem.get_precision() + 1;
805        ab.shift_significand_left(extra_bits as u64);
806        c.shift_significand_left(extra_bits as u64);
807
808        // Perform the addition, without rounding.
809        Self::add_or_sub_normals(&ab, &c, false)
810    }
811
812    /// Compute a*b + c, with the rounding mode `rm`.
813    pub fn fused_mul_add_with_rm(
814        a: &Self,
815        b: &Self,
816        c: &Self,
817        rm: RoundingMode,
818    ) -> Self {
819        if a.is_normal() && b.is_normal() && c.is_normal() {
820            let (mut res, loss) = Self::fused_mul_add_normals(a, b, c);
821            res.normalize(rm, loss); // Finally, round the result.
822            res
823        } else {
824            // Perform two operations. First, handle non-normal values.
825
826            // NaN anything = NaN
827            if a.is_nan() || b.is_nan() || c.is_nan() {
828                return Self::nan(a.get_semantics(), a.get_sign());
829            }
830            // (infinity * 0) + c = NaN
831            if (a.is_inf() && b.is_zero()) || (a.is_zero() && b.is_inf()) {
832                return Self::nan(a.get_semantics(), a.get_sign());
833            }
834            // (normal * normal) + infinity = infinity
835            if a.is_normal() && b.is_normal() && c.is_inf() {
836                return c.clone();
837            }
838            // (normal * 0) + c = c
839            if a.is_zero() || b.is_zero() {
840                return c.clone();
841            }
842
843            // Multiply (with rounding), and add (with rounding).
844            let ab = Self::mul_with_rm(a, b, rm);
845            Self::add_with_rm(&ab, c, rm)
846        }
847    }
848
849    /// Compute a*b + c.
850    pub fn fma(a: &Self, b: &Self, c: &Self) -> Self {
851        Self::fused_mul_add_with_rm(a, b, c, c.get_rounding_mode())
852    }
853}
854
#[test]
fn test_fma() {
    let (v0, v1, v2): (f64, f64, f64) =
        (-10., -1.1, 0.000000000000000000000000000000000000001);

    let fused = Float::fused_mul_add_with_rm(
        &Float::from_f64(v0),
        &Float::from_f64(v1),
        &Float::from_f64(v2),
        RoundingMode::NearestTiesToEven,
    );

    // The result must match the native single-rounding fused multiply-add.
    assert_eq!(f64::mul_add(v0, v1, v2), fused.as_f64());
}
873
#[cfg(feature = "std")]
#[test]
fn test_fma_simple() {
    use super::utils;
    // Test fused multiply-add on every triple of irregular values.
    let values = utils::get_special_test_values();
    for a in values {
        for b in values {
            for c in values {
                let fused = Float::fused_mul_add_with_rm(
                    &Float::from_f64(a),
                    &Float::from_f64(b),
                    &Float::from_f64(c),
                    RoundingMode::NearestTiesToEven,
                );

                let got = fused.as_f64();
                let want: f64 = a.mul_add(b, c);
                assert_eq!(got.is_finite(), want.is_finite());
                assert_eq!(got.is_nan(), want.is_nan());
                assert_eq!(got.is_infinite(), want.is_infinite());
                // Finite results must compare equal. `==` treats +0 and -0
                // as equal, so the sign of a zero result is not checked here.
                assert!(want.is_nan() || want.is_infinite() || got == want);
            }
        }
    }
}
905
#[test]
fn test_fma_random_vals() {
    use super::utils;

    // Fused multiply-add of three f32 values through Float.
    fn fma_f32(a: f32, b: f32, c: f32) -> f32 {
        Float::fused_mul_add_with_rm(
            &Float::from_f32(a),
            &Float::from_f32(b),
            &Float::from_f32(c),
            RoundingMode::NearestTiesToEven,
        )
        .as_f32()
    }

    let mut lfsr = utils::Lfsr::new();
    for _ in 0..50000 {
        // Use the low 32 bits of each random word as an f32 bit pattern.
        let f0 = f32::from_bits(lfsr.get64() as u32);
        let f1 = f32::from_bits(lfsr.get64() as u32);
        let f2 = f32::from_bits(lfsr.get64() as u32);

        let got = fma_f32(f0, f1, f2);
        let want = f32::mul_add(f0, f1, f2);
        assert_eq!(got.is_finite(), want.is_finite());
        assert_eq!(got.is_nan(), want.is_nan());
        assert_eq!(got.is_infinite(), want.is_infinite());
        // Bit identical, unless the expected result is NaN.
        assert!(want.is_nan() || got.to_bits() == want.to_bits());
    }
}