num_bigint/biguint/multiplication.rs
1use super::addition::{__add2, add2};
2use super::subtraction::sub2;
3use super::{biguint_from_vec, cmp_slice, BigUint, IntDigits};
4
5use crate::big_digit::{self, BigDigit, BigDigits, DoubleBigDigit};
6use crate::Sign::{self, Minus, NoSign, Plus};
7use crate::{BigInt, UsizePromotion};
8
9use core::cmp::Ordering;
10use core::iter::Product;
11use core::ops::{Mul, MulAssign};
12use num_traits::{CheckedMul, FromPrimitive, Zero};
13
14#[inline]
15pub(super) fn mac_with_carry(
16 a: BigDigit,
17 b: BigDigit,
18 c: BigDigit,
19 acc: &mut DoubleBigDigit,
20) -> BigDigit {
21 *acc += DoubleBigDigit::from(a);
22 *acc += DoubleBigDigit::from(b) * DoubleBigDigit::from(c);
23 let lo = *acc as BigDigit;
24 *acc >>= big_digit::BITS;
25 lo
26}
27
28#[inline]
29fn mul_with_carry(a: BigDigit, b: BigDigit, acc: &mut DoubleBigDigit) -> BigDigit {
30 *acc += DoubleBigDigit::from(a) * DoubleBigDigit::from(b);
31 let lo = *acc as BigDigit;
32 *acc >>= big_digit::BITS;
33 lo
34}
35
36/// Three argument multiply accumulate:
37/// acc += b * c
38fn mac_digit(acc: &mut [BigDigit], b: &[BigDigit], c: BigDigit) {
39 if c == 0 {
40 return;
41 }
42
43 let mut carry = 0;
44 let (a_lo, a_hi) = acc.split_at_mut(b.len());
45
46 for (a, &b) in a_lo.iter_mut().zip(b) {
47 *a = mac_with_carry(*a, b, c, &mut carry);
48 }
49
50 let (carry_hi, carry_lo) = big_digit::from_doublebigdigit(carry);
51 debug_assert_eq!(carry_hi, 0, "mac_with_carry never keeps high bits");
52
53 let final_carry = __add2(a_hi, &[carry_lo]);
54 assert_eq!(final_carry, 0, "carry overflow during multiplication!");
55}
56
57fn bigint_from_slice(slice: &[BigDigit]) -> BigInt {
58 let mut u = BigUint {
59 data: BigDigits::from_slice(slice),
60 };
61 u.normalize();
62 BigInt::from(u)
63}
64
/// Three argument multiply accumulate:
/// acc += b * c
///
/// `b` and `c` are digit slices with the least-significant digit first (the
/// zero-trimming below relies on this). `acc` must be long enough to hold the
/// product — `mul3` allocates `b.len() + c.len() + 1` digits for it — or the
/// slicing below panics.
///
/// The algorithm is chosen from the operand lengths (see the comment before
/// the `if` chain); the Half-Karatsuba and Karatsuba paths recurse into
/// `mac3` for their sub-products.
// Single-letter names (x, y, b, p, q, ...) mirror the formulas in the comments.
#[allow(clippy::many_single_char_names)]
fn mac3(mut acc: &mut [BigDigit], mut b: &[BigDigit], mut c: &[BigDigit]) {
    // Least-significant zeros have no effect on the output.
    // Skipping `nz` low digits of an operand moves where its product lands,
    // so `acc` is advanced by the same amount. An all-zero operand makes the
    // product zero, so there is nothing to accumulate.
    if let Some(&0) = b.first() {
        if let Some(nz) = b.iter().position(|&d| d != 0) {
            b = &b[nz..];
            acc = &mut acc[nz..];
        } else {
            return;
        }
    }
    if let Some(&0) = c.first() {
        if let Some(nz) = c.iter().position(|&d| d != 0) {
            c = &c[nz..];
            acc = &mut acc[nz..];
        } else {
            return;
        }
    }

    // Rebind without `mut` now that the trimming is done.
    let acc = acc;
    // x is the shorter operand and y the longer one (on a tie, x is `c`).
    let (x, y) = if b.len() < c.len() { (b, c) } else { (c, b) };

    // We use four algorithms for different input sizes.
    //
    // - For small inputs, long multiplication is fastest.
    // - If y is at least twice as long as x, split using Half-Karatsuba.
    // - Next we use Karatsuba multiplication (Toom-2), which we have optimized
    //   to avoid unnecessary allocations for intermediate values.
    // - For the largest inputs we use Toom-3, which better optimizes the
    //   number of operations, but uses more temporary allocations.
    //
    // The thresholds are somewhat arbitrary, chosen by evaluating the results
    // of `cargo bench --bench bigint multiply`.

    if x.len() <= 32 {
        // Long multiplication: for each digit x[i], accumulate y * x[i]
        // shifted up by i digits (the shift is done by slicing acc).
        for (i, xi) in x.iter().enumerate() {
            mac_digit(&mut acc[i..], y, *xi);
        }
    } else if x.len() * 2 <= y.len() {
        // Karatsuba Multiplication for factors with significant length disparity.
        //
        // The Half-Karatsuba Multiplication Algorithm is a specialized case of
        // the normal Karatsuba multiplication algorithm, designed for the scenario
        // where y has at least twice as many base digits as x.
        //
        // In this case y (the longer input) is split at m2 (half the length of y)
        // into low2 and high2, while x (the shorter input) is used directly
        // without splitting. NBASE below is the digit base, 2 ^ big_digit::BITS.
        //
        // The algorithm then proceeds as follows:
        //
        // 1. Compute the product z0 = x * low2.
        // 2. Compute the product temp = x * high2.
        // 3. Shift temp up by m2 digits (multiply it by NBASE ^ m2).
        // 4. Add temp and z0 to obtain the final result.
        //
        // Proof:
        //
        // The algorithm can be derived from the original Karatsuba algorithm by
        // simplifying the formula when the shorter factor x is not split into
        // high and low parts, as shown below.
        //
        // Original Karatsuba formula:
        //
        // result = (z2 * NBASE ^ (m2 * 2)) + ((z1 - z2 - z0) * NBASE ^ m2) + z0
        //
        // Substitutions:
        //
        // low1 = x
        // high1 = 0
        //
        // Applying substitutions:
        //
        // z0 = (low1 * low2)
        //    = (x * low2)
        //
        // z1 = ((low1 + high1) * (low2 + high2))
        //    = ((x + 0) * (low2 + high2))
        //    = (x * low2) + (x * high2)
        //
        // z2 = (high1 * high2)
        //    = (0 * high2)
        //    = 0
        //
        // Simplified using the above substitutions:
        //
        // result = (z2 * NBASE ^ (m2 * 2)) + ((z1 - z2 - z0) * NBASE ^ m2) + z0
        //        = (0 * NBASE ^ (m2 * 2)) + ((z1 - 0 - z0) * NBASE ^ m2) + z0
        //        = ((z1 - z0) * NBASE ^ m2) + z0
        //        = (x * high2) * NBASE ^ m2 + z0
        let m2 = y.len() / 2;
        let (low2, high2) = y.split_at(m2);

        // (x * high2) * NBASE ^ m2 + z0
        // Both products accumulate directly into acc; slicing acc at m2
        // performs the multiplication by NBASE ^ m2.
        mac3(acc, x, low2);
        mac3(&mut acc[m2..], x, high2);
    } else if x.len() <= 256 {
        // Karatsuba multiplication:
        //
        // The idea is that we break x and y up into two smaller numbers that each have about half
        // as many digits, like so (note that multiplying by b is just a shift):
        //
        // x = x0 + x1 * b
        // y = y0 + y1 * b
        //
        // With some algebra, we can compute x * y with three smaller products, where the inputs to
        // each of the smaller products have only about half as many digits as x and y:
        //
        // x * y = (x0 + x1 * b) * (y0 + y1 * b)
        //
        // x * y = x0 * y0
        //       + x0 * y1 * b
        //       + x1 * y0 * b
        //       + x1 * y1 * b^2
        //
        // Let p0 = x0 * y0 and p2 = x1 * y1:
        //
        // x * y = p0
        //       + (x0 * y1 + x1 * y0) * b
        //       + p2 * b^2
        //
        // The real trick is that middle term:
        //
        //   x0 * y1 + x1 * y0
        //
        // = x0 * y1 + x1 * y0 - p0 + p0 - p2 + p2
        //
        // = x0 * y1 + x1 * y0 - x0 * y0 - x1 * y1 + p0 + p2
        //
        // Now we complete the square:
        //
        // = -(x0 * y0 - x0 * y1 - x1 * y0 + x1 * y1) + p0 + p2
        //
        // = -((x1 - x0) * (y1 - y0)) + p0 + p2
        //
        // Let p1 = (x1 - x0) * (y1 - y0), and substitute back into our original formula:
        //
        // x * y = p0
        //       + (p0 + p2 - p1) * b
        //       + p2 * b^2
        //
        // Where the three intermediate products are:
        //
        // p0 = x0 * y0
        // p1 = (x1 - x0) * (y1 - y0)
        // p2 = x1 * y1
        //
        // In doing the computation, we take great care to avoid unnecessary temporary variables
        // (since creating a BigUint requires a heap allocation): thus, we rearrange the formula a
        // bit so we can use the same temporary variable for all the intermediate products:
        //
        // x * y = p2 * b^2 + p2 * b
        //       + p0 * b + p0
        //       - p1 * b
        //
        // The other trick we use is instead of doing explicit shifts, we slice acc at the
        // appropriate offset when doing the add.

        // When x is smaller than y, it's significantly faster to pick b such that x is split in
        // half, not y:
        // (This `b` shadows the `b` parameter, which is already folded into x/y;
        // it is the split point in digits, i.e. the shift amount.)
        let b = x.len() / 2;
        let (x0, x1) = x.split_at(b);
        let (y0, y1) = y.split_at(b);

        // We reuse the same BigUint for all the intermediate multiplies and have to size p
        // appropriately here: x1.len() >= x0.len() and y1.len() >= y0.len():
        let len = x1.len() + y1.len() + 1;
        let mut p = BigUint {
            data: BigDigits::from_vec(vec![0; len]),
        };

        // p2 = x1 * y1
        mac3(&mut p.data, x1, y1);

        // Not required, but the adds go faster if we drop any unneeded 0s from the end:
        p.normalize();

        // acc += p2 * b + p2 * b^2
        add2(&mut acc[b..], &p.data);
        add2(&mut acc[b * 2..], &p.data);

        // Zero out p before the next multiply:
        p.data.clear();
        p.data.resize(len, 0);

        // p0 = x0 * y0
        mac3(&mut p.data, x0, y0);
        p.normalize();

        // acc += p0 + p0 * b
        add2(acc, &p.data);
        add2(&mut acc[b..], &p.data);

        // p1 = (x1 - x0) * (y1 - y0)
        // We do this one last, since it may be negative and acc can't ever be negative:
        let (j0_sign, j0) = sub_sign(x1, x0);
        let (j1_sign, j1) = sub_sign(y1, y0);

        match j0_sign * j1_sign {
            // p1 > 0: compute it in p, then acc -= p1 * b.
            Plus => {
                p.data.clear();
                p.data.resize(len, 0);

                mac3(&mut p.data, &j0.data, &j1.data);
                p.normalize();

                sub2(&mut acc[b..], &p.data);
            }
            // p1 < 0: subtracting it means adding |p1| * b, which mac3 can
            // accumulate straight into acc.
            Minus => {
                mac3(&mut acc[b..], &j0.data, &j1.data);
            }
            // p1 == 0: nothing to subtract.
            NoSign => (),
        }
    } else {
        // Toom-3 multiplication:
        //
        // Toom-3 is like Karatsuba above, but dividing the inputs into three parts.
        // Both are instances of Toom-Cook, using `k=3` and `k=2` respectively.
        //
        // The general idea is to treat the large integers' digits as
        // polynomials of a certain degree and determine the coefficients/digits
        // of the product of the two via interpolation of the polynomial product.
        //
        // `i` is the size of each part in digits. The `+ 1` makes `3 * i`
        // exceed `y.len()`, so three parts always cover y; the top part of y,
        // and possibly the upper parts of x, are shorter (or empty).
        let i = y.len() / 3 + 1;

        let x0_len = Ord::min(x.len(), i);
        let x1_len = Ord::min(x.len() - x0_len, i);

        let y0_len = i;
        let y1_len = Ord::min(y.len() - y0_len, i);

        // Break x and y into three parts, representing a degree-two polynomial.
        // t is chosen to be NBASE ^ i (i whole digits), so multiplying by a
        // power of t is a digit shift rather than a real multiplication.
        //
        // x(t) = x2*t^2 + x1*t + x0
        let x0 = bigint_from_slice(&x[..x0_len]);
        let x1 = bigint_from_slice(&x[x0_len..x0_len + x1_len]);
        let x2 = bigint_from_slice(&x[x0_len + x1_len..]);

        // y(t) = y2*t^2 + y1*t + y0
        let y0 = bigint_from_slice(&y[..y0_len]);
        let y1 = bigint_from_slice(&y[y0_len..y0_len + y1_len]);
        let y2 = bigint_from_slice(&y[y0_len + y1_len..]);

        // Let w(t) = x(t) * y(t)
        //
        // This gives us the following degree-4 polynomial.
        //
        // w(t) = w4*t^4 + w3*t^3 + w2*t^2 + w1*t + w0
        //
        // We need to find the coefficients w4, w3, w2, w1 and w0. Instead
        // of simply multiplying the x and y in total, we can evaluate w
        // at 5 points. An n-degree polynomial is uniquely identified by (n + 1)
        // points.
        //
        // It is arbitrary as to what points we evaluate w at but we use the
        // following.
        //
        // w(t) at t = 0, 1, -1, -2 and inf
        //
        // The values for w(t) in terms of x(t)*y(t) at these points are:
        //
        // let a = w(0)   = x0 * y0
        // let b = w(1)   = (x2 + x1 + x0) * (y2 + y1 + y0)
        // let c = w(-1)  = (x2 - x1 + x0) * (y2 - y1 + y0)
        // let d = w(-2)  = (4*x2 - 2*x1 + x0) * (4*y2 - 2*y1 + y0)
        // let e = w(inf) = x2 * y2 as t -> inf

        // x0 + x2, avoiding temporaries
        let p = &x0 + &x2;

        // y0 + y2, avoiding temporaries
        let q = &y0 + &y2;

        // x2 - x1 + x0, avoiding temporaries
        let p2 = &p - &x1;

        // y2 - y1 + y0, avoiding temporaries
        let q2 = &q - &y1;

        // w(0)
        let r0 = &x0 * &y0;

        // w(inf)
        let r4 = &x2 * &y2;

        // w(1)
        let r1 = (p + x1) * (q + y1);

        // w(-1)
        let r2 = &p2 * &q2;

        // w(-2)
        // (p2 + x2) * 2 - x0 = 4*x2 - 2*x1 + x0, and likewise for y.
        let r3 = ((p2 + x2) * 2 - x0) * ((q2 + y2) * 2 - y0);

        // Evaluating these points gives us the following system of linear
        // equations (columns are the unknowns w4, w3, w2, w1, w0):
        //
        //  0  0  0  0  1 | a
        //  1  1  1  1  1 | b
        //  1 -1  1 -1  1 | c
        // 16 -8  4 -2  1 | d
        //  1  0  0  0  0 | e
        //
        // The solved equation (after gaussian elimination or similar)
        // in terms of its coefficients:
        //
        // w0 = w(0)
        // w1 = w(0)/2 + w(1)/3 - w(-1) + w(-2)/6 - 2*w(inf)
        // w2 = -w(0) + w(1)/2 + w(-1)/2 - w(inf)
        // w3 = -w(0)/2 + w(1)/6 + w(-1)/2 - w(-2)/6 + 2*w(inf)
        // w4 = w(inf)
        //
        // This particular sequence is given by Bodrato and is an interpolation
        // of the above equations. At the end, comp1 = w1, comp2 = w2 and
        // comp3 = w3 (w0 = r0 and w4 = r4 are used directly). Every division by
        // 3 and every shift by 1 in the sequence is exact, because the values
        // divided are multiples of 3 and 2 respectively.
        let mut comp3: BigInt = (r3 - &r1) / 3u32;
        let mut comp1: BigInt = (r1 - &r2) >> 1;
        let mut comp2: BigInt = r2 - &r0;
        comp3 = ((&comp2 - comp3) >> 1) + (&r4 << 1);
        comp2 += &comp1 - &r4;
        comp1 -= &comp3;

        // Recomposition. The coefficients of the polynomial are now known.
        //
        // Evaluate at w(t) where t is our given base to get the result.
        //
        //     let bits = u64::from(big_digit::BITS) * i as u64;
        //     let result = r0
        //         + (comp1 << bits)
        //         + (comp2 << (2 * bits))
        //         + (comp3 << (3 * bits))
        //         + (r4 << (4 * bits));
        //     let result_pos = result.to_biguint().unwrap();
        //     add2(&mut acc[..], &result_pos.data);
        //
        // But with less intermediate copying:
        // coefficient j is added (or, if negative, subtracted) at digit offset
        // i * j. NOTE(review): the loop runs from the highest coefficient down,
        // presumably so `acc` never has to dip below zero part-way through
        // (it is unsigned and `sub2` cannot represent that) — confirm.
        for (j, result) in [&r0, &comp1, &comp2, &comp3, &r4].iter().enumerate().rev() {
            match result.sign() {
                Plus => add2(&mut acc[i * j..], result.digits()),
                Minus => sub2(&mut acc[i * j..], result.digits()),
                NoSign => {}
            }
        }
    }
}
412
413fn mul3(x: &[BigDigit], y: &[BigDigit]) -> BigUint {
414 let len = x.len() + y.len() + 1;
415 let mut prod = BigUint {
416 data: BigDigits::from_vec(vec![0; len]),
417 };
418
419 mac3(&mut prod.data, x, y);
420 prod.normalize();
421 prod
422}
423
424fn scalar_mul(a: &mut BigUint, b: BigDigit) {
425 match b {
426 0 => a.set_zero(),
427 1 => {}
428 _ => {
429 if b.is_power_of_two() {
430 *a <<= b.trailing_zeros();
431 } else {
432 let mut carry = 0;
433 for a in a.data.iter_mut() {
434 *a = mul_with_carry(*a, b, &mut carry);
435 }
436 if carry != 0 {
437 a.data.push(carry as BigDigit);
438 }
439 }
440 }
441 }
442}
443
444fn sub_sign(mut a: &[BigDigit], mut b: &[BigDigit]) -> (Sign, BigUint) {
445 // Normalize:
446 if let Some(&0) = a.last() {
447 a = &a[..a.iter().rposition(|&x| x != 0).map_or(0, |i| i + 1)];
448 }
449 if let Some(&0) = b.last() {
450 b = &b[..b.iter().rposition(|&x| x != 0).map_or(0, |i| i + 1)];
451 }
452
453 match cmp_slice(a, b) {
454 Ordering::Greater => {
455 let mut a = a.to_vec();
456 sub2(&mut a, b);
457 (Plus, biguint_from_vec(a))
458 }
459 Ordering::Less => {
460 let mut b = b.to_vec();
461 sub2(&mut b, a);
462 (Minus, biguint_from_vec(b))
463 }
464 Ordering::Equal => (NoSign, BigUint::ZERO),
465 }
466}
467
468macro_rules! impl_mul {
469 ($(impl Mul<$Other:ty> for $Self:ty;)*) => {$(
470 impl Mul<$Other> for $Self {
471 type Output = BigUint;
472
473 #[inline]
474 fn mul(self, other: $Other) -> BigUint {
475 match (&*self.data, &*other.data) {
476 // multiply by zero
477 (&[], _) | (_, &[]) => BigUint::ZERO,
478 // multiply by a scalar
479 (_, &[digit]) => self * digit,
480 (&[digit], _) => other * digit,
481 // full multiplication
482 (x, y) => mul3(x, y),
483 }
484 }
485 }
486 )*}
487}
488impl_mul! {
489 impl Mul<BigUint> for BigUint;
490 impl Mul<BigUint> for &BigUint;
491 impl Mul<&BigUint> for BigUint;
492 impl Mul<&BigUint> for &BigUint;
493}
494
495macro_rules! impl_mul_assign {
496 ($(impl MulAssign<$Other:ty> for BigUint;)*) => {$(
497 impl MulAssign<$Other> for BigUint {
498 #[inline]
499 fn mul_assign(&mut self, other: $Other) {
500 match (&*self.data, &*other.data) {
501 // multiply by zero
502 (&[], _) => {},
503 (_, &[]) => self.set_zero(),
504 // multiply by a scalar
505 (_, &[digit]) => *self *= digit,
506 (&[digit], _) => *self = other * digit,
507 // full multiplication
508 (x, y) => *self = mul3(x, y),
509 }
510 }
511 }
512 )*}
513}
514impl_mul_assign! {
515 impl MulAssign<BigUint> for BigUint;
516 impl MulAssign<&BigUint> for BigUint;
517}
518
// Scalar multiplication operator impls generated by crate-level helper macros.
// NOTE(review): going by their names, the `promote_unsigned_scalars*` macros
// route the narrower unsigned scalar types (and their assign forms) through
// the u32/u64 impls, and `forward_all_scalar_binop_to_val_val_commutative!`
// adds the by-reference and `scalar * BigUint` forms on top of the by-value
// `Mul<u32>` / `Mul<u64>` / `Mul<u128>` impls below — confirm in the macros module.
promote_unsigned_scalars!(impl Mul for BigUint, mul);
promote_unsigned_scalars_assign!(impl MulAssign for BigUint, mul_assign);
forward_all_scalar_binop_to_val_val_commutative!(impl Mul<u32> for BigUint, mul);
forward_all_scalar_binop_to_val_val_commutative!(impl Mul<u64> for BigUint, mul);
forward_all_scalar_binop_to_val_val_commutative!(impl Mul<u128> for BigUint, mul);
524
525impl Mul<u32> for BigUint {
526 type Output = BigUint;
527
528 #[inline]
529 fn mul(mut self, other: u32) -> BigUint {
530 self *= other;
531 self
532 }
533}
534impl MulAssign<u32> for BigUint {
535 #[inline]
536 fn mul_assign(&mut self, other: u32) {
537 scalar_mul(self, other as BigDigit);
538 }
539}
540
541impl Mul<u64> for BigUint {
542 type Output = BigUint;
543
544 #[inline]
545 fn mul(mut self, other: u64) -> BigUint {
546 self *= other;
547 self
548 }
549}
550impl MulAssign<u64> for BigUint {
551 cfg_digit!(
552 #[inline]
553 fn mul_assign(&mut self, other: u64) {
554 if let Some(other) = BigDigit::from_u64(other) {
555 scalar_mul(self, other);
556 } else {
557 let (hi, lo) = big_digit::from_doublebigdigit(other);
558 *self = mul3(&self.data, &[lo, hi]);
559 }
560 }
561
562 #[inline]
563 fn mul_assign(&mut self, other: u64) {
564 scalar_mul(self, other);
565 }
566 );
567}
568
569impl Mul<u128> for BigUint {
570 type Output = BigUint;
571
572 #[inline]
573 fn mul(mut self, other: u128) -> BigUint {
574 self *= other;
575 self
576 }
577}
578
579impl MulAssign<u128> for BigUint {
580 cfg_digit!(
581 #[inline]
582 fn mul_assign(&mut self, other: u128) {
583 if let Some(other) = BigDigit::from_u128(other) {
584 scalar_mul(self, other);
585 } else {
586 *self = match super::u32_from_u128(other) {
587 (0, 0, c, d) => mul3(&self.data, &[d, c]),
588 (0, b, c, d) => mul3(&self.data, &[d, c, b]),
589 (a, b, c, d) => mul3(&self.data, &[d, c, b, a]),
590 };
591 }
592 }
593
594 #[inline]
595 fn mul_assign(&mut self, other: u128) {
596 if let Some(other) = BigDigit::from_u128(other) {
597 scalar_mul(self, other);
598 } else {
599 let (hi, lo) = big_digit::from_doublebigdigit(other);
600 *self = mul3(&self.data, &[lo, hi]);
601 }
602 }
603 );
604}
605
606impl CheckedMul for BigUint {
607 #[inline]
608 fn checked_mul(&self, v: &BigUint) -> Option<BigUint> {
609 Some(self.mul(v))
610 }
611}
612
613impl_product_iter_type!(BigUint);
614
#[test]
fn test_sub_sign() {
    use crate::BigInt;
    use num_traits::Num;

    // Runs `sub_sign` and folds its (sign, magnitude) pair into a `BigInt`.
    fn signed_difference(a: &[BigDigit], b: &[BigDigit]) -> BigInt {
        let (sign, magnitude) = sub_sign(a, b);
        BigInt::from_biguint(sign, magnitude)
    }

    let big = BigUint::from_str_radix("265252859812191058636308480000000", 10).unwrap();
    let small = BigUint::from_str_radix("26525285981219105863630848000000", 10).unwrap();
    let big_signed = BigInt::from(big.clone());
    let small_signed = BigInt::from(small.clone());

    // Both orderings must agree with plain signed subtraction.
    assert_eq!(
        signed_difference(&big.data, &small.data),
        &big_signed - &small_signed
    );
    assert_eq!(
        signed_difference(&small.data, &big.data),
        &small_signed - &big_signed
    );
}