use crate::{
add,
buffer::Buffer,
div,
ibig::IBig,
primitive::{double_word, extend_word, split_double_word, SignedWord, Word},
shift,
sign::Sign::{self, *},
ubig::{Repr::*, UBig},
};
use core::{
mem,
ops::{Mul, MulAssign},
};
/// When the shorter operand has fewer words than this, multiplication uses the
/// schoolbook (quadratic) method.
const THRESHOLD_SIMPLE: usize = 24;
/// When the shorter operand has at least `THRESHOLD_SIMPLE` but fewer than this many
/// words, Karatsuba is used; at or above this size, Toom-Cook 3-way is used.
const THRESHOLD_KARATSUBA: usize = 192;
impl Mul<UBig> for UBig {
    type Output = UBig;

    /// Multiplies two owned values. When exactly one side is large, its buffer is
    /// handed to `mul_large_word` so it can be reused for the result.
    fn mul(self, rhs: UBig) -> UBig {
        let lhs_repr = self.into_repr();
        let rhs_repr = rhs.into_repr();
        match (lhs_repr, rhs_repr) {
            (Small(x), Small(y)) => UBig::mul_word(x, y),
            (Large(big), Small(w)) | (Small(w), Large(big)) => UBig::mul_large_word(big, w),
            (Large(x), Large(y)) => UBig::mul_large(&x, &y),
        }
    }
}
impl Mul<&UBig> for UBig {
    type Output = UBig;

    /// Multiplies an owned value by a borrowed one. The borrowed buffer is cloned
    /// only when it is the sole large operand.
    fn mul(self, rhs: &UBig) -> UBig {
        match (self.into_repr(), rhs.repr()) {
            (Small(x), Small(y)) => UBig::mul_word(x, *y),
            (Small(x), Large(other)) => UBig::mul_large_word(other.clone(), x),
            (Large(own), Small(y)) => UBig::mul_large_word(own, *y),
            (Large(own), Large(other)) => UBig::mul_large(&own, other),
        }
    }
}
impl Mul<UBig> for &UBig {
    type Output = UBig;

    /// Multiplication is commutative, so delegate to the `owned * borrowed` impl,
    /// which can reuse the owned operand's storage.
    fn mul(self, rhs: UBig) -> UBig {
        rhs * self
    }
}
impl Mul<&UBig> for &UBig {
    type Output = UBig;

    /// Multiplies two borrowed values. A large operand's buffer is cloned when the
    /// other side is a single word.
    fn mul(self, rhs: &UBig) -> UBig {
        match self.repr() {
            Small(x) => match rhs.repr() {
                Small(y) => UBig::mul_word(*x, *y),
                Large(big) => UBig::mul_large_word(big.clone(), *x),
            },
            Large(big) => match rhs.repr() {
                Small(y) => UBig::mul_large_word(big.clone(), *y),
                Large(other) => UBig::mul_large(big, other),
            },
        }
    }
}
impl MulAssign<UBig> for UBig {
    /// `self *= rhs`. Moves the current value out with `mem::take` so that the
    /// owned `Mul` impl can consume it.
    fn mul_assign(&mut self, rhs: UBig) {
        let lhs = mem::take(self);
        *self = lhs * rhs;
    }
}
impl MulAssign<&UBig> for UBig {
    /// `self *= rhs` with a borrowed right-hand side. Moves the current value out
    /// with `mem::take` so that the `owned * borrowed` impl can consume it.
    fn mul_assign(&mut self, rhs: &UBig) {
        let lhs = mem::take(self);
        *self = lhs * rhs;
    }
}
impl Mul<IBig> for IBig {
    type Output = IBig;

    /// Signed product: multiplies the signs and the magnitudes independently.
    fn mul(self, rhs: IBig) -> IBig {
        let (lhs_sign, lhs_mag) = self.into_sign_magnitude();
        let (rhs_sign, rhs_mag) = rhs.into_sign_magnitude();
        let sign = lhs_sign * rhs_sign;
        let magnitude = lhs_mag * rhs_mag;
        IBig::from_sign_magnitude(sign, magnitude)
    }
}
impl Mul<&IBig> for IBig {
    type Output = IBig;

    /// Signed product with a borrowed right-hand side. The owned magnitude is
    /// consumed; the borrowed one is only read.
    fn mul(self, rhs: &IBig) -> IBig {
        let (lhs_sign, lhs_mag) = self.into_sign_magnitude();
        let sign = lhs_sign * rhs.sign();
        IBig::from_sign_magnitude(sign, lhs_mag * rhs.magnitude())
    }
}
impl Mul<IBig> for &IBig {
    type Output = IBig;

    /// Commutative: delegate to the `owned * borrowed` impl.
    fn mul(self, rhs: IBig) -> IBig {
        rhs * self
    }
}
impl Mul<&IBig> for &IBig {
    type Output = IBig;

    /// Signed product of two borrowed values. Both magnitudes are only read.
    fn mul(self, rhs: &IBig) -> IBig {
        let sign = self.sign() * rhs.sign();
        let magnitude = self.magnitude() * rhs.magnitude();
        IBig::from_sign_magnitude(sign, magnitude)
    }
}
impl MulAssign<IBig> for IBig {
    /// `self *= rhs`. Moves the current value out with `mem::take` so that the
    /// owned `Mul` impl can consume it.
    fn mul_assign(&mut self, rhs: IBig) {
        let lhs = mem::take(self);
        *self = lhs * rhs;
    }
}
impl MulAssign<&IBig> for IBig {
    /// `self *= rhs` with a borrowed right-hand side, using `mem::take` to move the
    /// current value into the `owned * borrowed` impl.
    fn mul_assign(&mut self, rhs: &IBig) {
        let lhs = mem::take(self);
        *self = lhs * rhs;
    }
}
impl Mul<Sign> for Sign {
    type Output = Sign;

    /// Sign of a product. `Positive` is the identity; `Negative` flips the other
    /// sign.
    fn mul(self, rhs: Sign) -> Sign {
        match self {
            Positive => rhs,
            Negative => match rhs {
                Positive => Negative,
                Negative => Positive,
            },
        }
    }
}
impl MulAssign<Sign> for Sign {
    /// `self *= rhs`, computed through the `Mul` impl above.
    fn mul_assign(&mut self, rhs: Sign) {
        let product = *self * rhs;
        *self = product;
    }
}
impl UBig {
    /// Product of two single words. Widens to a double word only when the
    /// single-word product overflows.
    fn mul_word(a: Word, b: Word) -> UBig {
        if let Some(c) = a.checked_mul(b) {
            UBig::from_word(c)
        } else {
            UBig::from(extend_word(a) * extend_word(b))
        }
    }

    /// Multiplies a multi-word `buffer` by the single word `a`, reusing the buffer.
    /// Multiplying by 0 discards the buffer; multiplying by 1 returns it untouched.
    fn mul_large_word(mut buffer: Buffer, a: Word) -> UBig {
        if a == 0 {
            return UBig::from_word(0);
        }
        if a != 1 {
            let carry = mul_word_in_place(&mut buffer, a);
            // A non-zero overflow word becomes a new most-significant word.
            if carry != 0 {
                buffer.push_may_reallocate(carry);
            }
        }
        buffer.into()
    }

    /// Multiplies two multi-word operands (each at least two words long) into a
    /// freshly allocated result of `lhs.len() + rhs.len()` words.
    fn mul_large(lhs: &[Word], rhs: &[Word]) -> UBig {
        debug_assert!(lhs.len() >= 2 && rhs.len() >= 2);
        let result_len = lhs.len() + rhs.len();
        let mut buffer = Buffer::allocate(result_len);
        buffer.push_zeros(result_len);
        // The scratch space needed by the chosen algorithm depends on the shorter operand.
        let scratch_len = mul_temp_buffer_len(lhs.len(), rhs.len());
        let mut scratch = Buffer::allocate_no_extra(scratch_len);
        scratch.push_zeros(scratch_len);
        let overflow = add_signed_mul(&mut buffer, Positive, lhs, rhs, &mut scratch);
        assert!(overflow == 0);
        buffer.into()
    }
}
/// Multiplies `words` in place by the single word `rhs`, starting with no carry,
/// and returns the word that overflows past the top.
fn mul_word_in_place(words: &mut [Word], rhs: Word) -> Word {
    const NO_CARRY: Word = 0;
    mul_word_in_place_with_carry(words, rhs, NO_CARRY)
}
/// Computes `words = words * rhs + carry` in place, from the least significant
/// word upward, and returns the final carry word.
///
/// Each step computes `word * rhs + carry_in` in double-word precision; this cannot
/// overflow because `(B-1)^2 + (B-1) < B^2`.
pub(crate) fn mul_word_in_place_with_carry(words: &mut [Word], rhs: Word, carry: Word) -> Word {
    let multiplier = extend_word(rhs);
    words.iter_mut().fold(carry, |carry_in, word| {
        let product = extend_word(*word) * multiplier + extend_word(carry_in);
        let (low, high) = split_double_word(product);
        *word = low;
        high
    })
}
/// Computes `words += mult * rhs` for equal-length slices and returns the carry
/// word out of the top.
///
/// Each step is `word + carry + mult * src`, which is at most `B^2 - 1`, so it fits
/// in a double word.
fn add_mul_word_same_len_in_place(words: &mut [Word], mult: Word, rhs: &[Word]) -> Word {
    assert!(words.len() == rhs.len());
    let factor = extend_word(mult);
    words
        .iter_mut()
        .zip(rhs)
        .fold(0, |carry: Word, (dst, &src)| {
            let sum = extend_word(*dst) + extend_word(carry) + factor * extend_word(src);
            let (low, high) = split_double_word(sum);
            *dst = low;
            high
        })
}
/// Computes `words += mult * rhs`, where `words` may be longer than `rhs`.
///
/// Returns the carry out of the top of `words`. This is a full word when the
/// lengths match; otherwise it is 0 or 1, after the carry has been propagated
/// through the extra high words.
fn add_mul_word_in_place(words: &mut [Word], mult: Word, rhs: &[Word]) -> Word {
    assert!(words.len() >= rhs.len());
    let (low, high) = words.split_at_mut(rhs.len());
    let carry = add_mul_word_same_len_in_place(low, mult, rhs);
    if high.is_empty() {
        carry
    } else {
        Word::from(add::add_word_in_place(high, carry))
    }
}
/// Computes `words -= mult * rhs` for equal-length slices and returns the borrow
/// out of the top word, i.e. the amount still to be subtracted from the next
/// word up.
///
/// The running borrow is stored biased, as `carry_plus_max = Word::MAX - borrow`,
/// so that every double-word intermediate `v` stays non-negative.
/// NOTE(review): this assumes `double_word(lo, hi)` builds `lo + hi * 2^WORD_BITS`.
/// Under that assumption the constant term equals
/// `2^(2*WORD_BITS) - 2^(WORD_BITS+1) + 1`, so
/// `v = a - borrow - mult * b + Word::MAX * 2^WORD_BITS`. The low half of `v` is
/// the new word, and the high half is `Word::MAX - new_borrow`.
pub(crate) fn sub_mul_word_same_len_in_place(words: &mut [Word], mult: Word, rhs: &[Word]) -> Word {
    assert!(words.len() == rhs.len());
    // Biased borrow: Word::MAX encodes "borrow = 0".
    let mut carry_plus_max = Word::MAX;
    for (a, b) in words.iter_mut().zip(rhs.iter()) {
        let v = extend_word(*a)
            + extend_word(carry_plus_max)
            + (double_word(0, Word::MAX) - extend_word(Word::MAX))
            - extend_word(mult) * extend_word(*b);
        let (v_lo, v_hi) = split_double_word(v);
        *a = v_lo;
        carry_plus_max = v_hi;
    }
    // Remove the bias to recover the actual borrow.
    Word::MAX - carry_plus_max
}
/// Computes `words -= mult * rhs`, where `words` may be longer than `rhs`.
///
/// Returns the borrow out of the top of `words`. This is a full word when the
/// lengths match; otherwise it is 0 or 1, after the borrow has been propagated
/// through the extra high words.
fn sub_mul_word_in_place(words: &mut [Word], mult: Word, rhs: &[Word]) -> Word {
    assert!(words.len() >= rhs.len());
    let (low, high) = words.split_at_mut(rhs.len());
    let borrow = sub_mul_word_same_len_in_place(low, mult, rhs);
    if high.is_empty() {
        borrow
    } else {
        Word::from(add::sub_word_in_place(high, borrow))
    }
}
/// Number of scratch words that `add_signed_mul` needs for operands of lengths
/// `a_len` and `b_len`.
///
/// The size is driven by the shorter operand `n`:
/// - 0 for schoolbook;
/// - `2n + 2*ceil(log2 n)` for Karatsuba;
/// - `4n + 13*ceil(log2 n)` for Toom-3.
fn mul_temp_buffer_len(a_len: usize, b_len: usize) -> usize {
    let n = a_len.min(b_len);
    if n < THRESHOLD_SIMPLE {
        return 0;
    }
    // ceil(log2 n): bounds the recursion depth of the divide-and-conquer algorithms.
    let log_n = n.next_power_of_two().trailing_zeros() as usize;
    if n < THRESHOLD_KARATSUBA {
        2 * n + 2 * log_n
    } else {
        4 * n + 13 * log_n
    }
}
/// Computes `c += sign * a * b` and returns the signed carry out of the top of `c`.
///
/// `c` must have exactly `a.len() + b.len()` words. The algorithm (schoolbook,
/// Karatsuba or Toom-3) is chosen from the length of the shorter operand.
fn add_signed_mul<'a>(
    c: &mut [Word],
    sign: Sign,
    a: &'a [Word],
    b: &'a [Word],
    temp: &mut [Word],
) -> SignedWord {
    debug_assert!(c.len() == a.len() + b.len());
    // Normalize so that `long` is at least as long as `short`.
    let (long, short) = if a.len() < b.len() { (b, a) } else { (a, b) };
    let short_len = short.len();
    if short_len < THRESHOLD_SIMPLE {
        add_signed_mul_simple(c, sign, long, short)
    } else if short_len < THRESHOLD_KARATSUBA {
        add_signed_mul_karatsuba(c, sign, long, short, temp)
    } else {
        add_signed_mul_toom_3(c, sign, long, short, temp)
    }
}
/// Schoolbook `c += sign * a * b`. Returns 1 on carry-out when adding, -1 on
/// borrow-out when subtracting, and 0 otherwise.
fn add_signed_mul_simple(c: &mut [Word], sign: Sign, a: &[Word], b: &[Word]) -> SignedWord {
    match sign {
        Positive => {
            let carried = add_mul_simple(c, a, b);
            if carried { 1 } else { 0 }
        }
        Negative => {
            let borrowed = sub_mul_simple(c, a, b);
            if borrowed { -1 } else { 0 }
        }
    }
}
/// Schoolbook `c += a * b`. Each word of `b` adds a shifted multiple of `a` into
/// `c`. Returns whether a carry came out of the top of `c`.
fn add_mul_simple(c: &mut [Word], a: &[Word], b: &[Word]) -> bool {
    debug_assert!(a.len() >= b.len() && c.len() == a.len() + b.len());
    debug_assert!(b.len() < THRESHOLD_SIMPLE);
    let total_carry: Word = b
        .iter()
        .enumerate()
        .map(|(offset, &digit)| add_mul_word_in_place(&mut c[offset..], digit, a))
        .sum();
    debug_assert!(total_carry <= 1);
    total_carry != 0
}
/// Schoolbook `c -= a * b`. Each word of `b` subtracts a shifted multiple of `a`
/// from `c`. Returns whether a borrow came out of the top of `c`.
fn sub_mul_simple(c: &mut [Word], a: &[Word], b: &[Word]) -> bool {
    debug_assert!(a.len() >= b.len() && c.len() == a.len() + b.len());
    debug_assert!(b.len() < THRESHOLD_SIMPLE);
    let total_borrow: Word = b
        .iter()
        .enumerate()
        .map(|(offset, &digit)| sub_mul_word_in_place(&mut c[offset..], digit, a))
        .sum();
    debug_assert!(total_borrow <= 1);
    total_borrow != 0
}
/// Computes `c += sign * a * b` for operands of possibly very different lengths
/// (`a.len() >= b.len()`, `b.len() < THRESHOLD_KARATSUBA`). Returns the signed
/// carry out of the top of `c`, which is debug-asserted to be in `-1..=1`.
///
/// The longer operand `a` is consumed in chunks of `n = b.len()` words. Each chunk
/// is multiplied by `b` with `add_signed_mul_karatsuba_same_len` into a window
/// `c[..2 * n]` that slides forward by `n` words per chunk.
///
/// The leftover part of `a` (fewer than `n` words) then becomes the short operand
/// and the old `b` becomes the long one. This repeats until the short side drops
/// below `THRESHOLD_SIMPLE`, and the schoolbook method finishes the product.
fn add_signed_mul_karatsuba<'a>(
    mut c: &mut [Word],
    sign: Sign,
    mut a: &'a [Word],
    mut b: &'a [Word],
    temp: &mut [Word],
) -> SignedWord {
    debug_assert!(a.len() >= b.len() && c.len() == a.len() + b.len());
    debug_assert!(b.len() < THRESHOLD_KARATSUBA);
    // Carry/borrow out of the top of the whole `c`, accumulated over all passes.
    let mut carry: SignedWord = 0;
    while b.len() >= THRESHOLD_SIMPLE {
        let n = b.len();
        // Carry out of word `2 * n` of the current window.
        let mut carry_n: SignedWord = 0;
        while a.len() >= n {
            // The window moved forward by `n`, so the previous window's carry lands at
            // word `n` of this window. Fold it in now; any overflow stays in `carry_n`.
            carry_n = add::add_signed_word_in_place(&mut c[n..2 * n], carry_n);
            let (a_lo, a_hi) = a.split_at(n);
            carry_n += add_signed_mul_karatsuba_same_len(&mut c[..2 * n], sign, a_lo, b, temp);
            a = a_hi;
            c = &mut c[n..];
        }
        // `c[n..]` is everything above the last window; push its carry through to the top.
        carry += add::add_signed_word_in_place(&mut c[n..], carry_n);
        // What remains of `a` is shorter than `b`: swap roles and continue.
        mem::swap(&mut a, &mut b);
    }
    carry += add_signed_mul_simple(c, sign, a, b);
    debug_assert!(carry.abs() <= 1);
    carry
}
/// Overwrites `c` with `a * b` using the Karatsuba driver. `c` is zeroed first, so
/// no carry may come out; this is asserted.
fn mul_karatsuba(c: &mut [Word], a: &[Word], b: &[Word], temp: &mut [Word]) {
    c.iter_mut().for_each(|word| *word = 0);
    let carry_out = add_signed_mul_karatsuba(c, Positive, a, b, temp);
    assert!(carry_out == 0);
}
/// Computes `c += sign * a * b` for equal-length `a` and `b` with one level of
/// Karatsuba. Returns the signed carry out of the top of `c`, which is asserted to
/// be in `-1..=1`. Below `THRESHOLD_SIMPLE` words it uses the schoolbook method.
///
/// With `mid = ceil(n / 2)` and `x = 2^(WORD_BITS * mid)`, it uses
/// `a*b = lo + (lo + hi - (a_lo - a_hi)*(b_lo - b_hi))*x + hi*x^2`,
/// where `lo = a_lo*b_lo` and `hi = a_hi*b_hi`.
///
/// The first `2 * mid` words of `temp` are local scratch (`my_temp`); the rest is
/// passed down to the recursive calls.
fn add_signed_mul_karatsuba_same_len(
    c: &mut [Word],
    sign: Sign,
    a: &[Word],
    b: &[Word],
    temp: &mut [Word],
) -> SignedWord {
    let n = a.len();
    debug_assert!(b.len() == n && c.len() == 2 * n);
    debug_assert!(n < THRESHOLD_KARATSUBA);
    if n < THRESHOLD_SIMPLE {
        return add_signed_mul_simple(c, sign, a, b);
    }
    let mid = (n + 1) / 2;
    let (a_lo, a_hi) = a.split_at(mid);
    let (b_lo, b_hi) = b.split_at(mid);
    let (my_temp, temp) = temp.split_at_mut(2 * mid);
    // Carry out of the top of `c` (word `2 * n`).
    let mut carry: SignedWord = 0;
    // Carry out of `c[..2 * mid]`, i.e. landing at word `2 * mid`.
    let mut carry_c0: SignedWord = 0;
    // Carry out of `c[mid..3 * mid]`, i.e. landing at word `3 * mid`.
    let mut carry_c1: SignedWord = 0;
    {
        // lo = a_lo * b_lo, added at offsets 0 and mid.
        let c_lo = &mut my_temp[..];
        mul_karatsuba(c_lo, a_lo, b_lo, temp);
        carry_c0 += add::add_signed_same_len_in_place(&mut c[..2 * mid], sign, c_lo);
        carry_c1 += add::add_signed_same_len_in_place(&mut c[mid..3 * mid], sign, c_lo);
    }
    {
        // hi = a_hi * b_hi, added at offsets 2 * mid and mid.
        let c_hi = &mut my_temp[..2 * (n - mid)];
        mul_karatsuba(c_hi, a_hi, b_hi, temp);
        carry += add::add_signed_same_len_in_place(&mut c[2 * mid..], sign, c_hi);
        carry_c1 += add::add_signed_in_place(&mut c[mid..3 * mid], sign, c_hi);
    }
    {
        // As used here, `sub_in_place_with_sign` presumably leaves |lhs - rhs| in lhs and
        // returns the sign of the difference. Then
        // (a_lo - a_hi)*(b_lo - b_hi) = diff_sign * |a_diff| * |b_diff|, and that term is
        // subtracted (hence `-sign`) at offset mid.
        let (a_diff, b_diff) = my_temp.split_at_mut(mid);
        a_diff.copy_from_slice(a_lo);
        let mut diff_sign = add::sub_in_place_with_sign(a_diff, a_hi);
        b_diff.copy_from_slice(b_lo);
        diff_sign *= add::sub_in_place_with_sign(b_diff, b_hi);
        carry_c1 += add_signed_mul_karatsuba(
            &mut c[mid..3 * mid],
            -sign * diff_sign,
            a_diff,
            b_diff,
            temp,
        );
    }
    // Ripple the intermediate carries upward, lowest first.
    carry_c1 += add::add_signed_word_in_place(&mut c[2 * mid..3 * mid], carry_c0);
    carry += add::add_signed_word_in_place(&mut c[3 * mid..], carry_c1);
    assert!(carry.abs() <= 1);
    carry
}
/// Computes `c += sign * a * b` for `a.len() >= b.len()`, where `c` has exactly
/// `a.len() + b.len()` words. Returns the signed carry out of the top of `c`,
/// which is debug-asserted to be in `-1..=1`.
///
/// While the shorter operand has at least `THRESHOLD_KARATSUBA` words, the longer
/// one is consumed in chunks of `n = b.len()` words. Each chunk is multiplied with
/// `add_signed_mul_toom_3_same_len` into a window that slides forward by `n` words.
/// The leftover part of `a` then becomes the short operand (roles swapped).
///
/// Once the short side drops below the threshold, the Karatsuba driver finishes.
fn add_signed_mul_toom_3<'a>(
    mut c: &mut [Word],
    sign: Sign,
    mut a: &'a [Word],
    mut b: &'a [Word],
    temp: &mut [Word],
) -> SignedWord {
    assert!(a.len() >= b.len() && c.len() == a.len() + b.len());
    // Carry/borrow out of the top of the whole `c`, accumulated over all passes.
    let mut carry: SignedWord = 0;
    while b.len() >= THRESHOLD_KARATSUBA {
        let n = b.len();
        // Carry out of word `2 * n` of the current window.
        let mut carry_n: SignedWord = 0;
        while a.len() >= n {
            // The previous window's carry lands at word `n` of this window; fold it in.
            carry_n = add::add_signed_word_in_place(&mut c[n..2 * n], carry_n);
            let (a_lo, a_hi) = a.split_at(n);
            carry_n += add_signed_mul_toom_3_same_len(&mut c[..2 * n], sign, a_lo, b, temp);
            a = a_hi;
            c = &mut c[n..];
        }
        // `c[n..]` is everything above the last window; push its carry through to the top.
        carry += add::add_signed_word_in_place(&mut c[n..], carry_n);
        // What remains of `a` is shorter than `b`: swap roles and continue.
        mem::swap(&mut a, &mut b);
    }
    carry += add_signed_mul_karatsuba(c, sign, a, b, temp);
    debug_assert!(carry.abs() <= 1);
    carry
}
/// Overwrites `c` with `a * b` for equal-length operands using Toom-3. `c` is
/// zeroed first, so no carry may come out; this is asserted.
fn mul_toom_3_same_len(c: &mut [Word], a: &[Word], b: &[Word], temp: &mut [Word]) {
    c.iter_mut().for_each(|word| *word = 0);
    let carry_out = add_signed_mul_toom_3_same_len(c, Positive, a, b, temp);
    assert!(carry_out == 0);
}
/// Computes `c += sign * a * b` for equal-length `a` and `b` using Toom-Cook 3-way
/// splitting. Returns the signed carry out of the top of `c`, which is asserted to
/// be in `-1..=1`. Below `THRESHOLD_KARATSUBA` words it delegates to Karatsuba.
///
/// Let `x = 2^(WORD_BITS * n3)` and `a = a0 + a1*x + a2*x^2` (likewise `b`).
/// `a0` and `a1` have `n3 = ceil(n / 3)` words; `a2` has the remaining `n3_short`.
/// Write the product as `p0 + p1*x + p2*x^2 + p3*x^3 + p4*x^4` and `c(t)` for the
/// product polynomial evaluated at `t`. The function evaluates at 0, 1, -1, 2 and
/// infinity:
///   - `p0 = a0*b0` and `p4 = a2*b2` are added to `c` directly;
///   - `t1` ends as `(3*p0 + c(2) - 12*p4 + 2*c(-1)) / 6 = p0 + p2 + p3 + p4`;
///   - `t2` ends as `(c(1) + c(-1)) / 2 = p0 + p2 + p4`.
///
/// Write `@k` for a shift by `k * n3` words. The contributions are:
///   - `+p0@0`;
///   - `+c(1)@1` and `-t1@1`;
///   - `-p0@2`, `-p4@2` and `+t2@2`;
///   - `+t1@3` and `-t2@3`;
///   - `+p4@4`.
/// They sum to exactly `a*b`.
///
/// `temp` is carved into:
///   - `a_eval` and `b_eval`, `n3 + 1` words each;
///   - `c_eval`, `t1` and `t2`, `2*n3 + 2` words each.
/// The remainder is passed to the recursive calls.
fn add_signed_mul_toom_3_same_len(
    c: &mut [Word],
    sign: Sign,
    a: &[Word],
    b: &[Word],
    temp: &mut [Word],
) -> SignedWord {
    let n = a.len();
    debug_assert!(b.len() == n && c.len() == 2 * n);
    if n < THRESHOLD_KARATSUBA {
        return add_signed_mul_karatsuba(c, sign, a, b, temp);
    }
    let n3 = (n + 2) / 3;
    let n3_short = n - 2 * n3;
    let (a0, a12) = a.split_at(n3);
    let (a1, a2) = a12.split_at(n3);
    let (b0, b12) = b.split_at(n3);
    let (b1, b2) = b12.split_at(n3);
    let (a_eval, temp) = temp.split_at_mut(n3 + 1);
    let (b_eval, temp) = temp.split_at_mut(n3 + 1);
    let (c_eval, temp) = temp.split_at_mut(2 * n3 + 2);
    let (t1, temp) = temp.split_at_mut(2 * n3 + 2);
    let (t2, temp) = temp.split_at_mut(2 * n3 + 2);
    // Carry out of the top of `c` (word `2 * n`).
    let mut carry: SignedWord = 0;
    // Carry out of `c[..2 * n3]` (lands at word `2 * n3`).
    let mut carry_c0: SignedWord = 0;
    // Carry out of `c[n3..3 * n3 + 2]` (lands at word `3 * n3 + 2`).
    let mut carry_c1: SignedWord = 0;
    // Carry out of `c[2 * n3..4 * n3 + 2]` (lands at word `4 * n3 + 2`).
    let mut carry_c2: SignedWord = 0;
    // Carry out of `c[3 * n3..5 * n3 + 2]` (lands at word `5 * n3 + 2`).
    let mut carry_c3: SignedWord = 0;
    {
        // p0 = a0 * b0: add at offset 0 and subtract at offset 2*n3. Then t1 = 3 * p0,
        // with the overflow word stored in t1[2 * n3] and the top word cleared.
        let t1_short = &mut t1[..2 * n3];
        mul_toom_3_same_len(t1_short, a0, b0, temp);
        carry_c0 += add::add_signed_same_len_in_place(&mut c[..2 * n3], sign, t1_short);
        carry_c2 += add::add_signed_in_place(&mut c[2 * n3..4 * n3 + 2], -sign, t1_short);
        t1[2 * n3] = mul_word_in_place(t1_short, 3);
        t1[2 * n3 + 1] = 0;
    }
    // a_eval = a(2) = a0 + 2*a1 + 4*a2 and b_eval = b(2); the top word holds the overflow.
    a_eval[..n3].copy_from_slice(a0);
    a_eval[n3] = add_mul_word_same_len_in_place(&mut a_eval[..n3], 2, a1);
    a_eval[n3] += add_mul_word_in_place(&mut a_eval[..n3], 4, a2);
    b_eval[..n3].copy_from_slice(b0);
    b_eval[n3] = add_mul_word_same_len_in_place(&mut b_eval[..n3], 2, b1);
    b_eval[n3] += add_mul_word_in_place(&mut b_eval[..n3], 4, b2);
    // t1 = 3*p0 + c(2).
    let overflow = add_signed_mul_toom_3_same_len(t1, Positive, a_eval, b_eval, temp);
    assert!(overflow == 0);
    {
        // p4 = a2 * b2: subtract at offset 2*n3 and add at offset 4*n3; then c_eval = 12 * p4.
        let c_eval_short = &mut c_eval[..2 * n3_short];
        mul_toom_3_same_len(c_eval_short, a2, b2, temp);
        carry_c2 += add::add_signed_in_place(&mut c[2 * n3..4 * n3 + 2], -sign, c_eval_short);
        carry += add::add_signed_same_len_in_place(&mut c[4 * n3..], sign, c_eval_short);
        c_eval[2 * n3_short] = mul_word_in_place(c_eval_short, 12);
    }
    // t1 = 3*p0 + c(2) - 12*p4.
    let overflow = add::sub_in_place(t1, &c_eval[..2 * n3_short + 1]);
    assert!(!overflow);
    // Sign of c(-1) = a(-1) * b(-1).
    let mut value_neg1_sign;
    {
        // c_eval is reused as scratch holding a02 = a0 + a2 and b02 = b0 + b2.
        let (a02, b02) = c_eval.split_at_mut(n3 + 1);
        a02[..n3].copy_from_slice(a0);
        a02[n3] = Word::from(add::add_in_place(&mut a02[..n3], a2));
        // a_eval = a(1) = a02 + a1.
        a_eval.copy_from_slice(a02);
        a_eval[n3] += Word::from(add::add_same_len_in_place(&mut a_eval[..n3], a1));
        b02[..n3].copy_from_slice(b0);
        b02[n3] = Word::from(add::add_in_place(&mut b02[..n3], b2));
        // b_eval = b(1) = b02 + b1.
        b_eval.copy_from_slice(b02);
        b_eval[n3] += Word::from(add::add_same_len_in_place(&mut b_eval[..n3], b1));
        // t2 = c(1), added at offset n3.
        mul_toom_3_same_len(t2, a_eval, b_eval, temp);
        carry_c1 += add::add_signed_in_place(&mut c[n3..3 * n3 + 2], sign, t2);
        // a_eval = |a(-1)| = |a02 - a1| and b_eval = |b(-1)|; the product of the two
        // signs is the sign of c(-1).
        a_eval.copy_from_slice(a02);
        value_neg1_sign = add::sub_in_place_with_sign(a_eval, a1);
        b_eval.copy_from_slice(b02);
        value_neg1_sign *= add::sub_in_place_with_sign(b_eval, b1);
    }
    // c_eval = |c(-1)|.
    mul_toom_3_same_len(c_eval, a_eval, b_eval, temp);
    // t2 = c(1) + c(-1).
    let overflow = add::add_signed_same_len_in_place(t2, value_neg1_sign, c_eval);
    assert!(overflow == 0);
    // t1 += 2 * c(-1), giving 6 * (p0 + p2 + p3 + p4).
    match value_neg1_sign {
        Positive => {
            let overflow = add_mul_word_same_len_in_place(t1, 2, c_eval);
            assert!(overflow == 0);
        }
        Negative => {
            let overflow = sub_mul_word_same_len_in_place(t1, 2, c_eval);
            assert!(overflow == 0);
        }
    }
    // Exact divisions: t1 /= 6 (remainder asserted zero) and t2 /= 2, presumably a
    // one-bit right shift (t2 is asserted even first).
    let t1_rem = div::div_rem_by_word_in_place(t1, 6);
    assert_eq!(t1_rem, 0);
    assert_eq!(t2[0] & 1, 0);
    shift::shr_in_place(t2, 1);
    // Interpolation: -t1@1, +t1@3, +t2@2, -t2@3.
    carry_c1 += add::add_signed_same_len_in_place(&mut c[n3..3 * n3 + 2], -sign, t1);
    carry_c3 += add::add_signed_same_len_in_place(&mut c[3 * n3..5 * n3 + 2], sign, t1);
    carry_c2 += add::add_signed_same_len_in_place(&mut c[2 * n3..4 * n3 + 2], sign, t2);
    carry_c3 += add::add_signed_same_len_in_place(&mut c[3 * n3..5 * n3 + 2], -sign, t2);
    // Ripple each intermediate carry into the words just above its region, lowest first.
    carry_c1 += add::add_signed_word_in_place(&mut c[2 * n3..3 * n3 + 2], carry_c0);
    carry_c2 += add::add_signed_word_in_place(&mut c[3 * n3 + 2..4 * n3 + 2], carry_c1);
    carry_c3 += add::add_signed_word_in_place(&mut c[4 * n3 + 2..5 * n3 + 2], carry_c2);
    carry += add::add_signed_word_in_place(&mut c[5 * n3 + 2..], carry_c3);
    assert!(carry.abs() <= 1);
    carry
}