1extern crate alloc;
4use crate::bigint::BigInt;
5
6use super::bigint::LossFraction;
7use super::float::{Category, Float, RoundingMode};
8use core::cmp::Ordering;
9use core::ops::{
10 Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign,
11};
12
13impl Float {
14 fn add_or_sub_normals(
20 a: &Self,
21 b: &Self,
22 subtract: bool,
23 ) -> (Self, LossFraction) {
24 debug_assert_eq!(a.get_semantics(), b.get_semantics());
25 let sem = a.get_semantics();
26 let loss;
27 let mut a = a.clone();
28 let mut b = b.clone();
29
30 let bits = a.get_exp() - b.get_exp();
32
33 let subtract = subtract ^ (a.get_sign() ^ b.get_sign());
36 if subtract {
37 match bits.cmp(&0) {
40 Ordering::Equal => {
41 loss = LossFraction::ExactlyZero;
42 }
43 Ordering::Greater => {
44 loss = b.shift_significand_right((bits - 1) as u64);
45 a.shift_significand_left(1);
46 }
47 Ordering::Less => {
48 loss = a.shift_significand_right((-bits - 1) as u64);
49 b.shift_significand_left(1);
50 }
51 }
52
53 let a_mantissa = a.get_mantissa();
54 let b_mantissa = b.get_mantissa();
55 let ab_mantissa;
56 let mut sign = a.get_sign();
57
58 let c = !loss.is_exactly_zero() as u64;
61 let c = BigInt::from_u64(c);
62
63 if a_mantissa < b_mantissa {
66 ab_mantissa = b_mantissa - a_mantissa - c;
68 sign = !sign;
69 } else {
70 ab_mantissa = a_mantissa - b_mantissa - c;
72 }
73 (
74 Self::from_parts(sem, sign, a.get_exp(), ab_mantissa),
75 loss.invert(),
76 )
77 } else {
78 let mut b = b.clone();
80 let mut a = a.clone();
81 if bits > 0 {
82 loss = b.shift_significand_right(bits as u64);
83 } else {
84 loss = a.shift_significand_right(-bits as u64);
85 }
86 debug_assert!(a.get_exp() == b.get_exp());
87 let ab_mantissa = a.get_mantissa() + b.get_mantissa();
88 (
89 Self::from_parts(sem, a.get_sign(), a.get_exp(), ab_mantissa),
90 loss,
91 )
92 }
93 }
94
95 pub fn add_with_rm(a: &Self, b: &Self, rm: RoundingMode) -> Self {
97 Self::add_sub(a, b, false, rm)
98 }
99 pub fn sub_with_rm(a: &Self, b: &Self, rm: RoundingMode) -> Self {
101 Self::add_sub(a, b, true, rm)
102 }
103
104 fn add_sub(a: &Self, b: &Self, subtract: bool, rm: RoundingMode) -> Self {
105 let sem = a.get_semantics();
106 match (a.get_category(), b.get_category()) {
109 (Category::NaN, Category::Infinity)
110 | (Category::NaN, Category::NaN)
111 | (Category::NaN, Category::Normal)
112 | (Category::NaN, Category::Zero)
113 | (Category::Normal, Category::Zero)
114 | (Category::Infinity, Category::Normal)
115 | (Category::Infinity, Category::Zero) => a.clone(),
116
117 (Category::Zero, Category::NaN)
118 | (Category::Normal, Category::NaN)
119 | (Category::Infinity, Category::NaN) => {
120 Self::nan(sem, b.get_sign())
121 }
122
123 (Category::Normal, Category::Infinity)
124 | (Category::Zero, Category::Infinity) => {
125 Self::inf(sem, b.get_sign() ^ subtract)
126 }
127
128 (Category::Zero, Category::Normal) => Self::from_parts(
129 sem,
130 b.get_sign() ^ subtract,
131 b.get_exp(),
132 b.get_mantissa(),
133 ),
134
135 (Category::Zero, Category::Zero) => {
136 Self::zero(sem, a.get_sign() && b.get_sign())
137 }
138
139 (Category::Infinity, Category::Infinity) => {
140 if a.get_sign() ^ b.get_sign() ^ subtract {
141 return Self::nan(sem, a.get_sign() ^ b.get_sign());
142 }
143 Self::inf(sem, a.get_sign())
144 }
145
146 (Category::Normal, Category::Normal) => {
147 let cancellation = subtract == (a.get_sign() == b.get_sign());
151 let same_absolute_number = a.same_absolute_value(b);
152 if cancellation && same_absolute_number {
153 let is_negative = RoundingMode::Negative == rm;
154 return Self::zero(sem, is_negative);
155 }
156
157 let mut res = Self::add_or_sub_normals(a, b, subtract);
158 res.0.normalize(rm, res.1);
159 res.0
160 }
161 }
162 }
163}
164
/// Smoke test: adding two small integers must not panic.
#[test]
fn test_add() {
    use super::float::FP64;
    let one = Float::from_u64(FP64, 1);
    let two = Float::from_u64(FP64, 2);
    let _sum = one + two;
}
172
/// Checks addition of small values against hand-computed and native `f64`
/// results, including the sign of an exactly cancelled sum.
#[test]
fn test_addition() {
    // Adds two doubles through `Float` and converts the sum back.
    fn add_helper(a: f64, b: f64) -> f64 {
        (Float::from_f64(a) + Float::from_f64(b)).as_f64()
    }

    // (lhs, rhs, expected sum)
    let cases: [(f64, f64, f64); 15] = [
        (0., -4., -4.),
        (-4., 0., -4.),
        (1., 1., 2.),
        (8., 4., 12.),
        (128., 2., 130.),
        (128., -8., 120.),
        (64., -60., 4.),
        (69., -65., 4.),
        (69., 69., 138.),
        (69., 1., 70.),
        (-128., -8., -136.),
        (64., -65., -1.),
        (-64., -65., -129.),
        (-15., -15., -30.),
        (-15., 15., 0.),
    ];
    for (lhs, rhs, expected) in cases {
        assert_eq!(add_helper(lhs, rhs), expected, "{} + {}", lhs, rhs);
    }

    // Every pair of small integers with `i <= j` must match the native sum.
    for i in -4..15 {
        for j in i..15 {
            let (x, y) = (f64::from(j), f64::from(i));
            assert_eq!(add_helper(x, y), y + x);
        }
    }

    // Values that cancel exactly produce a positive zero.
    let four = Float::from_f64(4.0);
    let minus_four = Float::from_f64(-4.0);
    let sum = four.clone() + minus_four;
    let diff = four.clone() - four;
    assert!(sum.is_zero());
    assert!(!sum.is_negative());
    assert!(diff.is_zero());
    assert!(!diff.is_negative());
}
220
/// Finds, by doubling, the first power of two `a` for which `(a + 1) - a`
/// is no longer `1`, and then the smallest integer `b` for which
/// `(a + b) - a == b`.
#[test]
fn test_addition_large_numbers() {
    use super::float::FP64;

    // Computes `(base + delta) - base` with nearest-ties-to-even rounding.
    fn add_then_sub(base: &Float, delta: &Float) -> Float {
        let rm = RoundingMode::NearestTiesToEven;
        let sum = Float::add_with_rm(base, delta, rm);
        Float::sub_with_rm(&sum, base, rm)
    }

    let rm = RoundingMode::NearestTiesToEven;
    let one = Float::from_i64(FP64, 1);

    // Keep doubling while adding one is still exactly recoverable.
    let mut a = Float::from_i64(FP64, 1);
    while add_then_sub(&a, &one) == one {
        a = Float::add_with_rm(&a, &a, rm);
    }

    // Count upwards until the increment survives the round trip.
    let mut b = one.clone();
    while add_then_sub(&a, &b) != b {
        b = Float::add_with_rm(&b, &one, rm);
    }

    assert_eq!(a.as_f64(), 9007199254740992.);
    assert_eq!(b.as_f64(), 2.);
}
242
/// Adds values built from raw bit patterns, mixed with ordinary numbers,
/// and compares each sum with the native `f64` result.
#[test]
fn add_denormals() {
    // Adds two doubles through `Float`; also checks that the first operand
    // survives the conversion to `Float` and back unchanged.
    fn add_f64(a: f64, b: f64) -> f64 {
        let a0 = Float::from_f64(a);
        let b0 = Float::from_f64(b);
        assert_eq!(a0.as_f64(), a);
        Float::add(a0, b0).as_f64()
    }

    let v0 = f64::from_bits(0x0000_0000_0010_0010);
    let v1 = f64::from_bits(0x0000_0000_1001_0010);
    let v2 = f64::from_bits(0x1000_0000_0001_0010);

    // Conversion round trip on its own.
    assert_eq!(Float::from_f64(v0).as_f64(), v0);

    // (lhs, rhs) pairs; the expected value is the native sum.
    let cases = [
        (v2, -v1),
        (v0, v1),
        (v0, -v0),
        (v0, v2),
        (v2, v1),
        (v0, 10.),
        (v0, -10.),
        (10000., v0),
    ];
    for (lhs, rhs) in cases {
        assert_eq!(add_f64(lhs, rhs), lhs + rhs, "{:e} + {:e}", lhs, rhs);
    }
}
272
/// Adds every pair of the special test values and compares the
/// classification of the result with the native `f64` sum; normal results
/// must also match bit for bit.
#[cfg(feature = "std")]
#[test]
fn add_special_values() {
    use crate::utils;

    // Adds two doubles through `Float` and converts the sum back.
    fn add_f64(a: f64, b: f64) -> f64 {
        (Float::from_f64(a) + Float::from_f64(b)).as_f64()
    }

    let values = utils::get_special_test_values();

    for lhs in values {
        for rhs in values {
            let actual = add_f64(lhs, rhs);
            let expected = lhs + rhs;
            assert_eq!(actual.is_finite(), expected.is_finite());
            assert_eq!(actual.is_nan(), expected.is_nan());
            assert_eq!(actual.is_infinite(), expected.is_infinite());
            assert_eq!(actual.is_normal(), expected.is_normal());
            // Bit-exact comparison only for normal results.
            if actual.is_normal() {
                assert!(actual.to_bits() == expected.to_bits());
            }
        }
    }
}
302
/// Adds one fixed pair and 50000 pseudo-random pairs of doubles and
/// compares every result with the native `f64` sum.
#[test]
fn test_add_random_vals() {
    use crate::utils;

    // Interprets both arguments as `f64` bit patterns, adds them through
    // `Float` and checks the result against the native sum. NaN results
    // are only compared by classification.
    fn check_sum(lhs_bits: u64, rhs_bits: u64) {
        let lhs = f64::from_bits(lhs_bits);
        let rhs = f64::from_bits(rhs_bits);

        let actual = (Float::from_f64(lhs) + Float::from_f64(rhs)).as_f64();
        let expected = lhs + rhs;

        assert_eq!(actual.is_finite(), expected.is_finite());
        assert_eq!(actual.is_nan(), expected.is_nan());
        assert_eq!(actual.is_infinite(), expected.is_infinite());
        assert!(expected.is_nan() || actual.to_bits() == expected.to_bits());
    }

    let mut lfsr = utils::Lfsr::new();

    // A fixed pair that is always checked first.
    check_sum(0x645e91f69778bad3, 0xe4d91b16be9ae0c5);

    for _ in 0..50000 {
        let lhs_bits = lfsr.get64();
        let rhs_bits = lfsr.get64();
        check_sum(lhs_bits, rhs_bits);
    }
}
352
353impl Float {
354 pub fn mul_with_rm(a: &Self, b: &Self, rm: RoundingMode) -> Self {
356 let sem = a.get_semantics();
357 let sign = a.get_sign() ^ b.get_sign();
358
359 match (a.get_category(), b.get_category()) {
362 (Category::Zero, Category::NaN)
363 | (Category::Normal, Category::NaN)
364 | (Category::Infinity, Category::NaN) => {
365 Self::nan(sem, b.get_sign())
366 }
367 (Category::NaN, Category::Infinity)
368 | (Category::NaN, Category::NaN)
369 | (Category::NaN, Category::Normal)
370 | (Category::NaN, Category::Zero) => Self::nan(sem, a.get_sign()),
371 (Category::Normal, Category::Infinity)
372 | (Category::Infinity, Category::Normal)
373 | (Category::Infinity, Category::Infinity) => Self::inf(sem, sign),
374 (Category::Normal, Category::Zero)
375 | (Category::Zero, Category::Normal)
376 | (Category::Zero, Category::Zero) => Self::zero(sem, sign),
377
378 (Category::Zero, Category::Infinity)
379 | (Category::Infinity, Category::Zero) => Self::nan(sem, sign),
380
381 (Category::Normal, Category::Normal) => {
382 let (mut res, loss) = Self::mul_normals(a, b, sign);
383 res.normalize(rm, loss);
384 res
385 }
386 }
387 }
388
389 fn mul_normals(a: &Self, b: &Self, sign: bool) -> (Self, LossFraction) {
391 debug_assert_eq!(a.get_semantics(), b.get_semantics());
392 let sem = a.get_semantics();
393 let mut exp = a.get_exp() + b.get_exp();
397
398 let a_significand = a.get_mantissa();
399 let b_significand = b.get_mantissa();
400 let ab_significand = a_significand * b_significand;
401
402 exp -= sem.get_mantissa_len() as i64;
406
407 let loss = LossFraction::ExactlyZero;
408 (Self::from_parts(sem, sign, exp, ab_significand), loss)
409 }
410}
411
/// Multiplies one pair of doubles and compares with the native product.
#[test]
fn test_mul_simple() {
    let (a, b): (f64, f64) = (-24.0, 0.1);

    let product = Float::from_f64(a) * Float::from_f64(b);

    assert_eq!(product.as_f64(), a * b);
}
425
/// Multiplies every pair from a list of ordinary values and requires the
/// result to match the native `f64` product bit for bit.
#[test]
fn mul_regular_values() {
    // Multiplies two doubles through `Float` and converts the result back.
    fn mul_f64(a: f64, b: f64) -> f64 {
        (Float::from_f64(a) * Float::from_f64(b)).as_f64()
    }

    let values: [f64; 11] =
        [-5.0, 0., -0., 24., 1., 11., 10000., 256., 0.1, 3., 17.5];

    for lhs in values {
        for rhs in values {
            let actual = mul_f64(lhs, rhs).to_bits();
            let expected = (lhs * rhs).to_bits();
            assert_eq!(actual, expected);
        }
    }
}
448
/// Multiplies every pair of the special test values and compares with the
/// native `f64` product; NaN results are only compared by classification.
#[cfg(feature = "std")]
#[test]
fn test_mul_special_values() {
    use super::utils;

    // Multiplies two doubles through `Float` and converts the result back.
    fn mul_f64(a: f64, b: f64) -> f64 {
        (Float::from_f64(a) * Float::from_f64(b)).as_f64()
    }

    let values = utils::get_special_test_values();

    for lhs in values {
        for rhs in values {
            let actual = mul_f64(lhs, rhs);
            let expected = lhs * rhs;
            assert_eq!(actual.is_finite(), expected.is_finite());
            assert_eq!(actual.is_nan(), expected.is_nan());
            assert_eq!(actual.is_infinite(), expected.is_infinite());
            assert!(
                expected.is_nan() || actual.to_bits() == expected.to_bits()
            );
        }
    }
}
477
/// Multiplies 50000 pseudo-random pairs of doubles and compares every
/// result with the native `f64` product.
#[test]
fn test_mul_random_vals() {
    use super::utils;

    // Multiplies two doubles through `Float` and converts the result back.
    fn mul_f64(a: f64, b: f64) -> f64 {
        let product = Float::from_f64(a) * Float::from_f64(b);
        product.as_f64()
    }

    let mut lfsr = utils::Lfsr::new();

    for _ in 0..50000 {
        let lhs = f64::from_bits(lfsr.get64());
        let rhs = f64::from_bits(lfsr.get64());

        let actual = mul_f64(lhs, rhs);
        let expected = lhs * rhs;

        assert_eq!(actual.is_finite(), expected.is_finite());
        assert_eq!(actual.is_nan(), expected.is_nan());
        assert_eq!(actual.is_infinite(), expected.is_infinite());
        // NaN results are only compared by classification.
        assert!(expected.is_nan() || actual.to_bits() == expected.to_bits());
    }
}
509
510impl Float {
511 pub fn div_with_rm(a: &Self, b: &Self, rm: RoundingMode) -> Self {
513 let sem = a.get_semantics();
514 let sign = a.get_sign() ^ b.get_sign();
515 match (a.get_category(), b.get_category()) {
517 (Category::NaN, _)
518 | (_, Category::NaN)
519 | (Category::Zero, Category::Zero)
520 | (Category::Infinity, Category::Infinity) => Self::nan(sem, sign),
521
522 (_, Category::Infinity) => Self::zero(sem, sign),
523 (Category::Zero, _) => Self::zero(sem, sign),
524 (_, Category::Zero) => Self::inf(sem, sign),
525 (Category::Infinity, _) => Self::inf(sem, sign),
526 (Category::Normal, Category::Normal) => {
527 let (mut res, loss) = Self::div_normals(a, b);
528 res.normalize(rm, loss);
529 res
530 }
531 }
532 }
533
534 fn div_normals(a: &Self, b: &Self) -> (Self, LossFraction) {
538 debug_assert_eq!(a.get_semantics(), b.get_semantics());
539 let sem = a.get_semantics();
540
541 let mut a = a.clone();
542 let mut b = b.clone();
543 a.align_mantissa(); b.align_mantissa(); let mut a_mantissa = a.get_mantissa();
548 let b_mantissa = b.get_mantissa();
549
550 let mut exp = a.get_exp() - b.get_exp();
552 let sign = a.get_sign() ^ b.get_sign();
553
554 if a_mantissa < b_mantissa {
557 a_mantissa.shift_left(1);
558 exp -= 1;
559 }
560
561 a_mantissa.shift_left(sem.get_mantissa_len());
566 let reminder = a_mantissa.inplace_div(&b_mantissa);
567
568 let mut reminder_2x = reminder;
571 reminder_2x.shift_left(1);
572
573 let reminder = reminder_2x.cmp(&b_mantissa);
574 let is_zero = reminder_2x.is_zero();
575 let loss = match reminder {
576 Ordering::Less => {
577 if is_zero {
578 LossFraction::ExactlyZero
579 } else {
580 LossFraction::LessThanHalf
581 }
582 }
583 Ordering::Equal => LossFraction::ExactlyHalf,
584 Ordering::Greater => LossFraction::MoreThanHalf,
585 };
586
587 let x = Self::from_parts(sem, sign, exp, a_mantissa);
588 (x, loss)
589 }
590}
591
/// Divides one pair of doubles and compares with the native quotient.
#[test]
fn test_div_simple() {
    let (a, b): (f64, f64) = (1.0, 7.0);

    let quotient = Float::div_with_rm(
        &Float::from_f64(a),
        &Float::from_f64(b),
        RoundingMode::NearestTiesToEven,
    );

    assert_eq!(quotient.as_f64(), a / b);
}
605
/// Divides every pair of the special test values and compares with the
/// native `f64` quotient; NaN results are only compared by classification.
#[cfg(feature = "std")]
#[test]
fn test_div_special_values() {
    use super::utils;

    // Divides two doubles through `Float` with nearest-ties-to-even
    // rounding and converts the result back.
    fn div_f64(a: f64, b: f64) -> f64 {
        let lhs = Float::from_f64(a);
        let rhs = Float::from_f64(b);
        let rm = RoundingMode::NearestTiesToEven;
        Float::div_with_rm(&lhs, &rhs, rm).as_f64()
    }

    let values = utils::get_special_test_values();

    for lhs in values {
        for rhs in values {
            let actual = div_f64(lhs, rhs);
            let expected = lhs / rhs;
            assert_eq!(actual.is_finite(), expected.is_finite());
            assert_eq!(actual.is_nan(), expected.is_nan());
            assert_eq!(actual.is_infinite(), expected.is_infinite());
            assert!(
                expected.is_nan() || actual.to_bits() == expected.to_bits()
            );
        }
    }
}
634
// Implements one binary operator trait for `Float`.
//
// Arguments: the operator trait, its method name, and the `Float` function
// taking `(&Float, &Float, RoundingMode)` that implements it. The rounding
// mode always comes from the semantics of the left operand. Implementations
// are generated for `Float op Float`, `Float op u64`, `&Float op &Float`,
// `&Float op u64` and `&Float op Float`.
macro_rules! declare_operator {
    ($op_trait:ident,
     $method:ident,
     $with_rm:ident) => {
        // Float op Float
        impl $op_trait for Float {
            type Output = Self;
            fn $method(self, rhs: Self) -> Self {
                let rm = self.get_semantics().get_rounding_mode();
                Float::$with_rm(&self, &rhs, rm)
            }
        }

        // Float op u64: the integer is converted using the semantics of
        // the left operand.
        impl $op_trait<u64> for Float {
            type Output = Self;
            fn $method(self, rhs: u64) -> Self {
                let sem = self.get_semantics();
                let rhs = Float::from_u64(sem, rhs);
                Float::$with_rm(&self, &rhs, sem.get_rounding_mode())
            }
        }

        // &Float op &Float
        impl $op_trait for &Float {
            type Output = Float;
            fn $method(self, rhs: Self) -> Self::Output {
                let rm = self.get_semantics().get_rounding_mode();
                Float::$with_rm(self, rhs, rm)
            }
        }

        // &Float op u64
        impl $op_trait<u64> for &Float {
            type Output = Float;
            fn $method(self, rhs: u64) -> Self::Output {
                let sem = self.get_semantics();
                let rhs = Float::from_u64(sem, rhs);
                Float::$with_rm(self, &rhs, sem.get_rounding_mode())
            }
        }

        // &Float op Float
        impl $op_trait<Float> for &Float {
            type Output = Float;
            fn $method(self, rhs: Float) -> Self::Output {
                let rm = self.get_semantics().get_rounding_mode();
                Float::$with_rm(self, &rhs, rm)
            }
        }
    };
}
699
// Implement the binary operators `+`, `-`, `*` and `/` for `Float` on top of
// the corresponding `*_with_rm` functions.
declare_operator!(Add, add, add_with_rm);
declare_operator!(Sub, sub, sub_with_rm);
declare_operator!(Mul, mul, mul_with_rm);
declare_operator!(Div, div, div_with_rm);
704
// Implements one compound-assignment operator trait for `Float`.
//
// Arguments: the operator trait, its method name, and the `Float` function
// taking `(&Float, &Float, RoundingMode)` that computes the new value. The
// rounding mode comes from the semantics of the left operand.
// Implementations are generated for a `Float` and a `&Float` right-hand
// side.
macro_rules! declare_assign_operator {
    ($op_trait:ident,
     $method:ident,
     $with_rm:ident) => {
        // Float op= Float
        impl $op_trait for Float {
            fn $method(&mut self, rhs: Self) {
                let rm = self.get_semantics().get_rounding_mode();
                let updated = Float::$with_rm(self, &rhs, rm);
                *self = updated;
            }
        }

        // Float op= &Float
        impl $op_trait<&Float> for Float {
            fn $method(&mut self, rhs: &Self) {
                let rm = self.get_semantics().get_rounding_mode();
                let updated = Float::$with_rm(self, rhs, rm);
                *self = updated;
            }
        }
    };
}
726
// Implement the compound-assignment operators `+=`, `-=`, `*=` and `/=` for
// `Float` on top of the corresponding `*_with_rm` functions.
declare_assign_operator!(AddAssign, add_assign, add_with_rm);
declare_assign_operator!(SubAssign, sub_assign, sub_with_rm);
declare_assign_operator!(MulAssign, mul_assign, mul_with_rm);
declare_assign_operator!(DivAssign, div_assign, div_with_rm);
731
/// Exercises the four by-reference binary operators on simple values.
#[test]
fn test_operators() {
    use crate::FP64;

    let eight = Float::from_f32(8.0).cast(FP64);
    let two = Float::from_f32(2.0).cast(FP64);

    let sum = &eight + &two;
    let difference = &eight - &two;
    let product = &eight * &two;
    let quotient = &eight / &two;

    assert_eq!(sum.as_f64(), 10.0);
    assert_eq!(difference.as_f64(), 6.0);
    assert_eq!(product.as_f64(), 16.0);
    assert_eq!(quotient.as_f64(), 4.0);
}
746
/// Approximates the square root of two with 25 bisection steps on the
/// interval [1, 2] in FP128 and checks the lower bound after casting it to
/// FP64.
#[test]
fn test_slow_sqrt_2_test() {
    use crate::FP128;
    use crate::FP64;

    let target = Float::from_f64(2.0).cast(FP128);
    let mut lower = Float::from_f64(1.0).cast(FP128);
    let mut upper = Float::from_f64(2.0).cast(FP128);

    for _ in 0..25 {
        let mid = (&upper + &lower) / 2;
        // Keep the half of the interval whose squares bracket the target.
        let square_is_below = (&mid * &mid) < target;
        if square_is_below {
            lower = mid;
        } else {
            upper = mid;
        }
    }

    let approx = lower.cast(FP64).as_f64();
    assert!(approx < 1.4142137_f64);
    assert!(approx > 1.4142134_f64);
}
770
/// Divides 4195835 by 3145727 in FP128 and checks the leading decimal
/// digits of the printed quotient.
#[cfg(feature = "std")]
#[test]
fn test_famous_pentium4_bug() {
    use crate::std::string::ToString;
    use crate::FP128;

    let numerator = Float::from_u64(FP128, 4_195_835);
    let denominator = Float::from_u64(FP128, 3_145_727);

    let printed = (numerator / denominator).to_string();

    assert!(printed.starts_with("1.333820449136241002"));
}
784
785impl Float {
786 fn fused_mul_add_normals(
788 a: &Self,
789 b: &Self,
790 c: &Self,
791 ) -> (Self, LossFraction) {
792 debug_assert_eq!(a.get_semantics(), b.get_semantics());
793 let sem = a.get_semantics();
794
795 let sign = a.get_sign() ^ b.get_sign();
797 let mut ab = Self::mul_normals(a, b, sign).0;
798
799 let mut c = c.clone();
804 let extra_bits = sem.get_precision() + 1;
805 ab.shift_significand_left(extra_bits as u64);
806 c.shift_significand_left(extra_bits as u64);
807
808 Self::add_or_sub_normals(&ab, &c, false)
810 }
811
812 pub fn fused_mul_add_with_rm(
814 a: &Self,
815 b: &Self,
816 c: &Self,
817 rm: RoundingMode,
818 ) -> Self {
819 if a.is_normal() && b.is_normal() && c.is_normal() {
820 let (mut res, loss) = Self::fused_mul_add_normals(a, b, c);
821 res.normalize(rm, loss); res
823 } else {
824 if a.is_nan() || b.is_nan() || c.is_nan() {
828 return Self::nan(a.get_semantics(), a.get_sign());
829 }
830 if (a.is_inf() && b.is_zero()) || (a.is_zero() && b.is_inf()) {
832 return Self::nan(a.get_semantics(), a.get_sign());
833 }
834 if a.is_normal() && b.is_normal() && c.is_inf() {
836 return c.clone();
837 }
838 if a.is_zero() || b.is_zero() {
840 return c.clone();
841 }
842
843 let ab = Self::mul_with_rm(a, b, rm);
845 Self::add_with_rm(&ab, c, rm)
846 }
847 }
848
849 pub fn fma(a: &Self, b: &Self, c: &Self) -> Self {
851 Self::fused_mul_add_with_rm(a, b, c, c.get_rounding_mode())
852 }
853}
854
/// Compares one fused multiply-add, with an addend far smaller than the
/// product, against `f64::mul_add`.
#[test]
fn test_fma() {
    let a: f64 = -10.;
    let b: f64 = -1.1;
    let c: f64 = 0.000000000000000000000000000000000000001;

    let fused = Float::fused_mul_add_with_rm(
        &Float::from_f64(a),
        &Float::from_f64(b),
        &Float::from_f64(c),
        RoundingMode::NearestTiesToEven,
    );

    assert_eq!(f64::mul_add(a, b, c), fused.as_f64());
}
873
/// Runs the fused multiply-add over every triple of the special test values
/// and compares with `f64::mul_add`; NaN and infinite results are only
/// compared by classification.
#[cfg(feature = "std")]
#[test]
fn test_fma_simple() {
    use super::utils;

    // Computes `a * b + c` through `Float` with nearest-ties-to-even
    // rounding and converts the result back.
    fn fma_f64(a: f64, b: f64, c: f64) -> f64 {
        let af = Float::from_f64(a);
        let bf = Float::from_f64(b);
        let cf = Float::from_f64(c);
        let rm = RoundingMode::NearestTiesToEven;
        Float::fused_mul_add_with_rm(&af, &bf, &cf, rm).as_f64()
    }

    let values = utils::get_special_test_values();
    for a in values {
        for b in values {
            for c in values {
                let actual = fma_f64(a, b, c);
                let expected: f64 = a.mul_add(b, c);
                assert_eq!(actual.is_finite(), expected.is_finite());
                assert_eq!(actual.is_nan(), expected.is_nan());
                assert_eq!(actual.is_infinite(), expected.is_infinite());
                assert!(
                    expected.is_nan()
                        || expected.is_infinite()
                        || actual == expected
                );
            }
        }
    }
}
905
/// Runs the fused multiply-add on 50000 pseudo-random `f32` triples and
/// compares every result with `f32::mul_add`.
#[test]
fn test_fma_random_vals() {
    use super::utils;

    // Computes `a * b + c` through `Float` with nearest-ties-to-even
    // rounding and converts the result back to `f32`.
    fn fma_f32(a: f32, b: f32, c: f32) -> f32 {
        let af = Float::from_f32(a);
        let bf = Float::from_f32(b);
        let cf = Float::from_f32(c);
        let rm = RoundingMode::NearestTiesToEven;
        Float::fused_mul_add_with_rm(&af, &bf, &cf, rm).as_f32()
    }

    let mut lfsr = utils::Lfsr::new();

    for _ in 0..50000 {
        // Use the low 32 bits of each generated word as an `f32` pattern.
        let a = f32::from_bits(lfsr.get64() as u32);
        let b = f32::from_bits(lfsr.get64() as u32);
        let c = f32::from_bits(lfsr.get64() as u32);

        let actual = fma_f32(a, b, c);
        let expected = f32::mul_add(a, b, c);

        assert_eq!(actual.is_finite(), expected.is_finite());
        assert_eq!(actual.is_nan(), expected.is_nan());
        assert_eq!(actual.is_infinite(), expected.is_infinite());
        // NaN results are only compared by classification.
        assert!(expected.is_nan() || actual.to_bits() == expected.to_bits());
    }
}