Skip to main content

num_bigint/biguint/multiplication.rs

1use super::addition::{__add2, add2};
2use super::subtraction::sub2;
3use super::{biguint_from_vec, cmp_slice, BigUint, IntDigits};
4
5use crate::big_digit::{self, BigDigit, BigDigits, DoubleBigDigit};
6use crate::Sign::{self, Minus, NoSign, Plus};
7use crate::{BigInt, UsizePromotion};
8
9use core::cmp::Ordering;
10use core::iter::Product;
11use core::ops::{Mul, MulAssign};
12use num_traits::{CheckedMul, FromPrimitive, Zero};
13
14#[inline]
15pub(super) fn mac_with_carry(
16    a: BigDigit,
17    b: BigDigit,
18    c: BigDigit,
19    acc: &mut DoubleBigDigit,
20) -> BigDigit {
21    *acc += DoubleBigDigit::from(a);
22    *acc += DoubleBigDigit::from(b) * DoubleBigDigit::from(c);
23    let lo = *acc as BigDigit;
24    *acc >>= big_digit::BITS;
25    lo
26}
27
28#[inline]
29fn mul_with_carry(a: BigDigit, b: BigDigit, acc: &mut DoubleBigDigit) -> BigDigit {
30    *acc += DoubleBigDigit::from(a) * DoubleBigDigit::from(b);
31    let lo = *acc as BigDigit;
32    *acc >>= big_digit::BITS;
33    lo
34}
35
36/// Three argument multiply accumulate:
37/// acc += b * c
38fn mac_digit(acc: &mut [BigDigit], b: &[BigDigit], c: BigDigit) {
39    if c == 0 {
40        return;
41    }
42
43    let mut carry = 0;
44    let (a_lo, a_hi) = acc.split_at_mut(b.len());
45
46    for (a, &b) in a_lo.iter_mut().zip(b) {
47        *a = mac_with_carry(*a, b, c, &mut carry);
48    }
49
50    let (carry_hi, carry_lo) = big_digit::from_doublebigdigit(carry);
51    debug_assert_eq!(carry_hi, 0, "mac_with_carry never keeps high bits");
52
53    let final_carry = __add2(a_hi, &[carry_lo]);
54    assert_eq!(final_carry, 0, "carry overflow during multiplication!");
55}
56
57fn bigint_from_slice(slice: &[BigDigit]) -> BigInt {
58    let mut u = BigUint {
59        data: BigDigits::from_slice(slice),
60    };
61    u.normalize();
62    BigInt::from(u)
63}
64
/// Three argument multiply accumulate:
/// acc += b * c
///
/// `b` and `c` are little-endian digit slices (least-significant digit first).
/// The product is added on top of whatever `acc` already holds, so `acc` must
/// be long enough to absorb the result; `mul3` in this file passes a zeroed
/// buffer of `x.len() + y.len() + 1` digits. If `acc` is too short, the slicing
/// below or the carry assertion in `mac_digit` is expected to panic.
#[allow(clippy::many_single_char_names)]
fn mac3(mut acc: &mut [BigDigit], mut b: &[BigDigit], mut c: &[BigDigit]) {
    // Least-significant zeros have no effect on the output.
    //
    // Dropping `nz` low zero digits from an operand removes a factor of
    // NBASE ^ nz, which is restored by advancing `acc` by `nz` digits. An
    // operand that is entirely zero contributes nothing, so return early.
    if let Some(&0) = b.first() {
        if let Some(nz) = b.iter().position(|&d| d != 0) {
            b = &b[nz..];
            acc = &mut acc[nz..];
        } else {
            return;
        }
    }
    if let Some(&0) = c.first() {
        if let Some(nz) = c.iter().position(|&d| d != 0) {
            c = &c[nz..];
            acc = &mut acc[nz..];
        } else {
            return;
        }
    }

    // Rebind without `mut`: `acc` is not re-sliced from here on.
    let acc = acc;
    // `x` is the shorter operand (`c` when the lengths are equal), `y` the longer.
    let (x, y) = if b.len() < c.len() { (b, c) } else { (c, b) };

    // We use four algorithms for different input sizes.
    //
    // - For small inputs, long multiplication is fastest.
    // - If y is at least twice as long as x, split using Half-Karatsuba.
    // - Next we use Karatsuba multiplication (Toom-2), which we have optimized
    //   to avoid unnecessary allocations for intermediate values.
    // - For the largest inputs we use Toom-3, which better optimizes the
    //   number of operations, but uses more temporary allocations.
    //
    // The thresholds are somewhat arbitrary, chosen by evaluating the results
    // of `cargo bench --bench bigint multiply`.

    if x.len() <= 32 {
        // Long multiplication: one `mac_digit` row per digit of x, with row i
        // accumulated i digits up.
        for (i, xi) in x.iter().enumerate() {
            mac_digit(&mut acc[i..], y, *xi);
        }
    } else if x.len() * 2 <= y.len() {
        // Karatsuba Multiplication for factors with significant length disparity.
        //
        // The Half-Karatsuba Multiplication Algorithm is a specialized case of
        // the normal Karatsuba multiplication algorithm, designed for the scenario
        // where y has at least twice as many base digits as x.
        //
        // In this case y (the longer input) is split at m2 (half the length of
        // y) into low2 and high2, while x (the shorter input) is used directly
        // without splitting.
        //
        // The algorithm then proceeds as follows:
        //
        // 1. Compute the product z0 = x * low2.
        // 2. Compute the product temp = x * high2.
        // 3. Shift temp up by m2 digits (multiply it by NBASE ^ m2).
        // 4. Add temp and z0 to obtain the final result.
        //
        // Proof:
        //
        // The algorithm can be derived from the original Karatsuba algorithm by
        // simplifying the formula when the shorter factor x is not split into
        // high and low parts, as shown below.
        //
        // Original Karatsuba formula:
        //
        //     result = (z2 * NBASE ^ (m2 * 2)) + ((z1 - z2 - z0) * NBASE ^ m2) + z0
        //
        // Substitutions:
        //
        //     low1 = x
        //     high1 = 0
        //
        // Applying substitutions:
        //
        //     z0 = (low1 * low2)
        //        = (x * low2)
        //
        //     z1 = ((low1 + high1) * (low2 + high2))
        //        = ((x + 0) * (low2 + high2))
        //        = (x * low2) + (x * high2)
        //
        //     z2 = (high1 * high2)
        //        = (0 * high2)
        //        = 0
        //
        // Simplified using the above substitutions:
        //
        //     result = (z2 * NBASE ^ (m2 * 2)) + ((z1 - z2 - z0) * NBASE ^ m2) + z0
        //            = (0 * NBASE ^ (m2 * 2)) + ((z1 - 0 - z0) * NBASE ^ m2) + z0
        //            = ((z1 - z0) * NBASE ^ m2) + z0
        //            = (x * high2) * NBASE ^ m2 + z0        (since z1 - z0 = x * high2)
        let m2 = y.len() / 2;
        let (low2, high2) = y.split_at(m2);

        // (x * high2) * NBASE ^ m2 + z0:
        // z0 = x * low2 is accumulated at offset 0, and x * high2 at offset m2
        // (slicing acc at m2 is the multiplication by NBASE ^ m2).
        mac3(acc, x, low2);
        mac3(&mut acc[m2..], x, high2);
    } else if x.len() <= 256 {
        // Karatsuba multiplication:
        //
        // The idea is that we break x and y up into two smaller numbers that each have about half
        // as many digits, like so (note that multiplying by b is just a shift):
        //
        // x = x0 + x1 * b
        // y = y0 + y1 * b
        //
        // With some algebra, we can compute x * y with three smaller products, where the inputs to
        // each of the smaller products have only about half as many digits as x and y:
        //
        // x * y = (x0 + x1 * b) * (y0 + y1 * b)
        //
        // x * y = x0 * y0
        //       + x0 * y1 * b
        //       + x1 * y0 * b
        //       + x1 * y1 * b^2
        //
        // Let p0 = x0 * y0 and p2 = x1 * y1:
        //
        // x * y = p0
        //       + (x0 * y1 + x1 * y0) * b
        //       + p2 * b^2
        //
        // The real trick is that middle term:
        //
        //         x0 * y1 + x1 * y0
        //
        //       = x0 * y1 + x1 * y0 - p0 + p0 - p2 + p2
        //
        //       = x0 * y1 + x1 * y0 - x0 * y0 - x1 * y1 + p0 + p2
        //
        // Now we complete the square:
        //
        //       = -(x0 * y0 - x0 * y1 - x1 * y0 + x1 * y1) + p0 + p2
        //
        //       = -((x1 - x0) * (y1 - y0)) + p0 + p2
        //
        // Let p1 = (x1 - x0) * (y1 - y0), and substitute back into our original formula:
        //
        // x * y = p0
        //       + (p0 + p2 - p1) * b
        //       + p2 * b^2
        //
        // Where the three intermediate products are:
        //
        // p0 = x0 * y0
        // p1 = (x1 - x0) * (y1 - y0)
        // p2 = x1 * y1
        //
        // In doing the computation, we take great care to avoid unnecessary temporary variables
        // (since creating a BigUint requires a heap allocation): thus, we rearrange the formula a
        // bit so we can use the same temporary variable for all the intermediate products:
        //
        // x * y = p2 * b^2 + p2 * b
        //       + p0 * b + p0
        //       - p1 * b
        //
        // The other trick we use is instead of doing explicit shifts, we slice acc at the
        // appropriate offset when doing the add.

        // When x is smaller than y, it's significantly faster to pick b such that x is split in
        // half, not y:
        let b = x.len() / 2;
        let (x0, x1) = x.split_at(b);
        let (y0, y1) = y.split_at(b);

        // We reuse the same BigUint for all the intermediate multiplies and have to size p
        // appropriately here: x1.len() >= x0.len() and y1.len() >= y0.len():
        let len = x1.len() + y1.len() + 1;
        let mut p = BigUint {
            data: BigDigits::from_vec(vec![0; len]),
        };

        // p2 = x1 * y1
        mac3(&mut p.data, x1, y1);

        // Not required, but the adds go faster if we drop any unneeded 0s from the end:
        p.normalize();

        // acc += p2 * b + p2 * b^2
        add2(&mut acc[b..], &p.data);
        add2(&mut acc[b * 2..], &p.data);

        // Zero out p before the next multiply:
        p.data.clear();
        p.data.resize(len, 0);

        // p0 = x0 * y0
        mac3(&mut p.data, x0, y0);
        p.normalize();

        // acc += p0 + p0 * b
        add2(acc, &p.data);
        add2(&mut acc[b..], &p.data);

        // p1 = (x1 - x0) * (y1 - y0)
        // We do this one last, since it may be negative and acc can't ever be negative:
        let (j0_sign, j0) = sub_sign(x1, x0);
        let (j1_sign, j1) = sub_sign(y1, y0);

        match j0_sign * j1_sign {
            // p1 > 0: compute |p1| into the zeroed scratch buffer, then acc -= p1 * b.
            Plus => {
                p.data.clear();
                p.data.resize(len, 0);

                mac3(&mut p.data, &j0.data, &j1.data);
                p.normalize();

                sub2(&mut acc[b..], &p.data);
            }
            // p1 < 0: -p1 * b is positive, so |p1| * b is accumulated directly.
            Minus => {
                mac3(&mut acc[b..], &j0.data, &j1.data);
            }
            // p1 == 0: nothing to subtract.
            NoSign => (),
        }
    } else {
        // Toom-3 multiplication:
        //
        // Toom-3 is like Karatsuba above, but dividing the inputs into three parts.
        // Both are instances of Toom-Cook, using `k=3` and `k=2` respectively.
        //
        // The general idea is to treat the large integers' digits as
        // polynomials of a certain degree and determine the coefficients/digits
        // of the product of the two via interpolation of the polynomial product.
        //
        // Each part is (at most) i digits long; y is split into three parts
        // (the last possibly shorter), and x into up to three parts of the same size.
        let i = y.len() / 3 + 1;

        let x0_len = Ord::min(x.len(), i);
        let x1_len = Ord::min(x.len() - x0_len, i);

        let y0_len = i;
        let y1_len = Ord::min(y.len() - y0_len, i);

        // Break x and y into three parts, representing an order two polynomial.
        // t is chosen to be a power of the digit base (t = NBASE ^ i) so that
        // multiplying by t is just a shift by i digits rather than a real
        // multiplication.
        //
        // x(t) = x2*t^2 + x1*t + x0
        let x0 = bigint_from_slice(&x[..x0_len]);
        let x1 = bigint_from_slice(&x[x0_len..x0_len + x1_len]);
        let x2 = bigint_from_slice(&x[x0_len + x1_len..]);

        // y(t) = y2*t^2 + y1*t + y0
        let y0 = bigint_from_slice(&y[..y0_len]);
        let y1 = bigint_from_slice(&y[y0_len..y0_len + y1_len]);
        let y2 = bigint_from_slice(&y[y0_len + y1_len..]);

        // Let w(t) = x(t) * y(t)
        //
        // This gives us the following order-4 polynomial.
        //
        // w(t) = w4*t^4 + w3*t^3 + w2*t^2 + w1*t + w0
        //
        // We need to find the coefficients w4, w3, w2, w1 and w0. Instead
        // of simply multiplying the x and y in total, we can evaluate w
        // at 5 points. An n-degree polynomial is uniquely identified by (n + 1)
        // points.
        //
        // It is arbitrary as to what points we evaluate w at but we use the
        // following.
        //
        // w(t) at t = 0, 1, -1, -2 and inf
        //
        // The values for w(t) in terms of x(t)*y(t) at these points are:
        //
        // let a = w(0)   = x0 * y0
        // let b = w(1)   = (x2 + x1 + x0) * (y2 + y1 + y0)
        // let c = w(-1)  = (x2 - x1 + x0) * (y2 - y1 + y0)
        // let d = w(-2)  = (4*x2 - 2*x1 + x0) * (4*y2 - 2*y1 + y0)
        // let e = w(inf) = x2 * y2 as t -> inf

        // x0 + x2, avoiding temporaries
        let p = &x0 + &x2;

        // y0 + y2, avoiding temporaries
        let q = &y0 + &y2;

        // x2 - x1 + x0, avoiding temporaries
        let p2 = &p - &x1;

        // y2 - y1 + y0, avoiding temporaries
        let q2 = &q - &y1;

        // w(0)
        let r0 = &x0 * &y0;

        // w(inf)
        let r4 = &x2 * &y2;

        // w(1)
        let r1 = (p + x1) * (q + y1);

        // w(-1)
        let r2 = &p2 * &q2;

        // w(-2): ((x2 - x1 + x0) + x2) * 2 - x0 = 4*x2 - 2*x1 + x0, likewise for y
        let r3 = ((p2 + x2) * 2 - x0) * ((q2 + y2) * 2 - y0);

        // Evaluating these points gives us the following system of linear equations.
        // The columns are the coefficients w4, w3, w2, w1, w0 in that order.
        //
        //  0  0  0  0  1 | a
        //  1  1  1  1  1 | b
        //  1 -1  1 -1  1 | c
        // 16 -8  4 -2  1 | d
        //  1  0  0  0  0 | e
        //
        // The solved equation (after gaussian elimination or similar)
        // in terms of its coefficients:
        //
        // w0 = w(0)
        // w1 = w(0)/2 + w(1)/3 - w(-1) + w(-2)/6 - 2*w(inf)
        // w2 = -w(0) + w(1)/2 + w(-1)/2 - w(inf)
        // w3 = -w(0)/2 + w(1)/6 + w(-1)/2 - w(-2)/6 + 2*w(inf)
        // w4 = w(inf)
        //
        // This particular sequence is given by Bodrato and is an interpolation
        // of the above equations. Every division below is exact; after it runs,
        // comp1 = w1, comp2 = w2 and comp3 = w3.
        let mut comp3: BigInt = (r3 - &r1) / 3u32;
        let mut comp1: BigInt = (r1 - &r2) >> 1;
        let mut comp2: BigInt = r2 - &r0;
        comp3 = ((&comp2 - comp3) >> 1) + (&r4 << 1);
        comp2 += &comp1 - &r4;
        comp1 -= &comp3;

        // Recomposition. The coefficients of the polynomial are now known.
        //
        // Evaluate at w(t) where t is our given base to get the result.
        //
        //     let bits = u64::from(big_digit::BITS) * i as u64;
        //     let result = r0
        //         + (comp1 << bits)
        //         + (comp2 << (2 * bits))
        //         + (comp3 << (3 * bits))
        //         + (r4 << (4 * bits));
        //     let result_pos = result.to_biguint().unwrap();
        //     add2(&mut acc[..], &result_pos.data);
        //
        // But with less intermediate copying: coefficient j is added (or, if
        // negative, its magnitude subtracted) directly at digit offset i * j.
        //
        // NOTE(review): the true coefficients are sums of products of
        // non-negative parts, so the `Minus` arm looks purely defensive — confirm.
        for (j, result) in [&r0, &comp1, &comp2, &comp3, &r4].iter().enumerate().rev() {
            match result.sign() {
                Plus => add2(&mut acc[i * j..], result.digits()),
                Minus => sub2(&mut acc[i * j..], result.digits()),
                NoSign => {}
            }
        }
    }
}
412
413fn mul3(x: &[BigDigit], y: &[BigDigit]) -> BigUint {
414    let len = x.len() + y.len() + 1;
415    let mut prod = BigUint {
416        data: BigDigits::from_vec(vec![0; len]),
417    };
418
419    mac3(&mut prod.data, x, y);
420    prod.normalize();
421    prod
422}
423
424fn scalar_mul(a: &mut BigUint, b: BigDigit) {
425    match b {
426        0 => a.set_zero(),
427        1 => {}
428        _ => {
429            if b.is_power_of_two() {
430                *a <<= b.trailing_zeros();
431            } else {
432                let mut carry = 0;
433                for a in a.data.iter_mut() {
434                    *a = mul_with_carry(*a, b, &mut carry);
435                }
436                if carry != 0 {
437                    a.data.push(carry as BigDigit);
438                }
439            }
440        }
441    }
442}
443
444fn sub_sign(mut a: &[BigDigit], mut b: &[BigDigit]) -> (Sign, BigUint) {
445    // Normalize:
446    if let Some(&0) = a.last() {
447        a = &a[..a.iter().rposition(|&x| x != 0).map_or(0, |i| i + 1)];
448    }
449    if let Some(&0) = b.last() {
450        b = &b[..b.iter().rposition(|&x| x != 0).map_or(0, |i| i + 1)];
451    }
452
453    match cmp_slice(a, b) {
454        Ordering::Greater => {
455            let mut a = a.to_vec();
456            sub2(&mut a, b);
457            (Plus, biguint_from_vec(a))
458        }
459        Ordering::Less => {
460            let mut b = b.to_vec();
461            sub2(&mut b, a);
462            (Minus, biguint_from_vec(b))
463        }
464        Ordering::Equal => (NoSign, BigUint::ZERO),
465    }
466}
467
// Generates `Mul` impls with `Output = BigUint` for each listed combination of
// owned/borrowed operands. Dispatch is on the operands' digit slices: an empty
// slice (zero) yields zero, a single-digit operand goes through the scalar
// `* digit` impl, and anything else goes through `mul3`.
macro_rules! impl_mul {
    ($(impl Mul<$Other:ty> for $Self:ty;)*) => {$(
        impl Mul<$Other> for $Self {
            type Output = BigUint;

            #[inline]
            fn mul(self, other: $Other) -> BigUint {
                match (&*self.data, &*other.data) {
                    // multiply by zero
                    (&[], _) | (_, &[]) => BigUint::ZERO,
                    // multiply by a scalar
                    (_, &[digit]) => self * digit,
                    (&[digit], _) => other * digit,
                    // full multiplication
                    (x, y) => mul3(x, y),
                }
            }
        }
    )*}
}
// `BigUint * BigUint` for all four owned/borrowed operand combinations.
impl_mul! {
    impl Mul<BigUint> for BigUint;
    impl Mul<BigUint> for &BigUint;
    impl Mul<&BigUint> for BigUint;
    impl Mul<&BigUint> for &BigUint;
}
494
// Generates `MulAssign` impls on `BigUint` for each listed right-hand operand
// type, dispatching on the digit slices of `self` and `other`.
macro_rules! impl_mul_assign {
    ($(impl MulAssign<$Other:ty> for BigUint;)*) => {$(
        impl MulAssign<$Other> for BigUint {
            #[inline]
            fn mul_assign(&mut self, other: $Other) {
                match (&*self.data, &*other.data) {
                    // multiply by zero
                    // (self already zero: leave it; other zero: clear self)
                    (&[], _) => {},
                    (_, &[]) => self.set_zero(),
                    // multiply by a scalar
                    // (a single-digit self is replaced by `other * digit`)
                    (_, &[digit]) => *self *= digit,
                    (&[digit], _) => *self = other * digit,
                    // full multiplication
                    (x, y) => *self = mul3(x, y),
                }
            }
        }
    )*}
}
// `BigUint *= BigUint` and `BigUint *= &BigUint`.
impl_mul_assign! {
    impl MulAssign<BigUint> for BigUint;
    impl MulAssign<&BigUint> for BigUint;
}

// Scalar multiplication. The primitive impls for `u32`, `u64` and `u128` are
// written out by hand below; these macros (defined elsewhere in the crate)
// presumably widen the remaining unsigned primitives onto those impls and
// forward the borrowed/commuted operand forms to the owned `BigUint * scalar`
// impls — confirm against their definitions.
promote_unsigned_scalars!(impl Mul for BigUint, mul);
promote_unsigned_scalars_assign!(impl MulAssign for BigUint, mul_assign);
forward_all_scalar_binop_to_val_val_commutative!(impl Mul<u32> for BigUint, mul);
forward_all_scalar_binop_to_val_val_commutative!(impl Mul<u64> for BigUint, mul);
forward_all_scalar_binop_to_val_val_commutative!(impl Mul<u128> for BigUint, mul);
524
525impl Mul<u32> for BigUint {
526    type Output = BigUint;
527
528    #[inline]
529    fn mul(mut self, other: u32) -> BigUint {
530        self *= other;
531        self
532    }
533}
534impl MulAssign<u32> for BigUint {
535    #[inline]
536    fn mul_assign(&mut self, other: u32) {
537        scalar_mul(self, other as BigDigit);
538    }
539}
540
541impl Mul<u64> for BigUint {
542    type Output = BigUint;
543
544    #[inline]
545    fn mul(mut self, other: u64) -> BigUint {
546        self *= other;
547        self
548    }
549}
550impl MulAssign<u64> for BigUint {
551    cfg_digit!(
552        #[inline]
553        fn mul_assign(&mut self, other: u64) {
554            if let Some(other) = BigDigit::from_u64(other) {
555                scalar_mul(self, other);
556            } else {
557                let (hi, lo) = big_digit::from_doublebigdigit(other);
558                *self = mul3(&self.data, &[lo, hi]);
559            }
560        }
561
562        #[inline]
563        fn mul_assign(&mut self, other: u64) {
564            scalar_mul(self, other);
565        }
566    );
567}
568
569impl Mul<u128> for BigUint {
570    type Output = BigUint;
571
572    #[inline]
573    fn mul(mut self, other: u128) -> BigUint {
574        self *= other;
575        self
576    }
577}
578
579impl MulAssign<u128> for BigUint {
580    cfg_digit!(
581        #[inline]
582        fn mul_assign(&mut self, other: u128) {
583            if let Some(other) = BigDigit::from_u128(other) {
584                scalar_mul(self, other);
585            } else {
586                *self = match super::u32_from_u128(other) {
587                    (0, 0, c, d) => mul3(&self.data, &[d, c]),
588                    (0, b, c, d) => mul3(&self.data, &[d, c, b]),
589                    (a, b, c, d) => mul3(&self.data, &[d, c, b, a]),
590                };
591            }
592        }
593
594        #[inline]
595        fn mul_assign(&mut self, other: u128) {
596            if let Some(other) = BigDigit::from_u128(other) {
597                scalar_mul(self, other);
598            } else {
599                let (hi, lo) = big_digit::from_doublebigdigit(other);
600                *self = mul3(&self.data, &[lo, hi]);
601            }
602        }
603    );
604}
605
606impl CheckedMul for BigUint {
607    #[inline]
608    fn checked_mul(&self, v: &BigUint) -> Option<BigUint> {
609        Some(self.mul(v))
610    }
611}
612
613impl_product_iter_type!(BigUint);
614
#[test]
fn test_sub_sign() {
    use crate::BigInt;
    use num_traits::Num;

    // Runs `sub_sign` and folds the (sign, magnitude) pair back into a signed
    // `BigInt` so it can be compared against ordinary subtraction.
    fn sub_sign_i(a: &[BigDigit], b: &[BigDigit]) -> BigInt {
        let (sign, val) = sub_sign(a, b);
        BigInt::from_biguint(sign, val)
    }

    let a = BigUint::from_str_radix("265252859812191058636308480000000", 10).unwrap();
    let b = BigUint::from_str_radix("26525285981219105863630848000000", 10).unwrap();
    let a_i = BigInt::from(a.clone());
    let b_i = BigInt::from(b.clone());

    // Greater and Less branches.
    assert_eq!(sub_sign_i(&a.data, &b.data), &a_i - &b_i);
    assert_eq!(sub_sign_i(&b.data, &a.data), &b_i - &a_i);

    // Equal branch: identical magnitudes give zero.
    assert_eq!(sub_sign_i(&a.data, &a.data), &a_i - &a_i);

    // Most-significant zero digits must be ignored when comparing magnitudes.
    let mut a_padded = a.data.to_vec();
    a_padded.extend_from_slice(&[0, 0]);
    assert_eq!(sub_sign_i(&a_padded, &b.data), &a_i - &b_i);
    assert_eq!(sub_sign_i(&b.data, &a_padded), &b_i - &a_i);
    assert_eq!(sub_sign_i(&a_padded, &a.data), &a_i - &a_i);

    // An empty slice is zero.
    assert_eq!(sub_sign_i(&[], &b.data), -b_i.clone());
}