use crate::{algorithms, impl_bin_op, nlimbs, Uint};
use core::{
iter::Product,
num::Wrapping,
ops::{Mul, MulAssign},
};
impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
    /// Computes `self * rhs`, returning [`None`] if the exact product does not
    /// fit in `BITS` bits.
    // `inline_always` is silenced deliberately: this is a thin wrapper around
    // `overflowing_mul`.
    #[allow(clippy::inline_always)]
    #[inline(always)]
    #[must_use]
    pub fn checked_mul(self, rhs: Self) -> Option<Self> {
        match self.overflowing_mul(rhs) {
            (value, false) => Some(value),
            (_, true) => None,
        }
    }

    /// Computes `self * rhs` modulo `2^BITS`.
    ///
    /// Returns the wrapped product together with a flag that is `true` when
    /// the exact product did not fit in `BITS` bits.
    #[must_use]
    pub fn overflowing_mul(self, rhs: Self) -> (Self, bool) {
        let mut result = Self::ZERO;
        // `addmul` accumulates `self * rhs` into the zeroed `result`; its
        // return value is used as the flag for carries past the last limb.
        let mut overflow = algorithms::addmul(&mut result.limbs, self.as_limbs(), rhs.as_limbs());
        // The guard keeps `LIMBS - 1` from underflowing for the zero-bit type
        // (which presumably has no limbs).
        if BITS > 0 {
            // The top limb may be only partially used: any bit above `MASK`
            // is also overflow and must be cleared to keep `result` in range.
            overflow |= result.limbs[LIMBS - 1] > Self::MASK;
            result.limbs[LIMBS - 1] &= Self::MASK;
        }
        (result, overflow)
    }

    /// Computes `self * rhs`, returning [`Self::MAX`] if the exact product does
    /// not fit in `BITS` bits.
    // `inline_always` is silenced deliberately: this is a thin wrapper around
    // `overflowing_mul`.
    #[allow(clippy::inline_always)]
    #[inline(always)]
    #[must_use]
    pub fn saturating_mul(self, rhs: Self) -> Self {
        match self.overflowing_mul(rhs) {
            (value, false) => value,
            (_, true) => Self::MAX,
        }
    }

    /// Computes `self * rhs` modulo `2^BITS`, discarding any overflow.
    ///
    /// This is the implementation behind the `*` and `*=` operators.
    // `inline_always` is silenced deliberately: this is the hot path of `*`.
    #[allow(clippy::inline_always)]
    #[inline(always)]
    #[must_use]
    pub fn wrapping_mul(self, rhs: Self) -> Self {
        let mut result = Self::ZERO;
        // `addmul_n` presumably keeps only the limbs that fit in `result`;
        // bits above `BITS` in the top limb are masked off below.
        algorithms::addmul_n(&mut result.limbs, self.as_limbs(), rhs.as_limbs());
        // The guard keeps `LIMBS - 1` from underflowing for the zero-bit type.
        if BITS > 0 {
            result.limbs[LIMBS - 1] &= Self::MASK;
        }
        result
    }

    /// Computes the multiplicative inverse of `self` modulo `2^BITS`.
    ///
    /// Returns [`None`] when no inverse exists: `self` is even, or the type has
    /// zero bits.
    ///
    /// The inverse of the lowest limb is found with the Newton iteration
    /// `x <- x * (2 - n * x)` in `u64` arithmetic. The same iteration is then
    /// repeated on the full `Self` value. Each step doubles the number of
    /// correct low bits.
    #[must_use]
    pub fn inv_ring(self) -> Option<Self> {
        // Only odd numbers are invertible modulo a power of two.
        if BITS == 0 || self.limbs[0] & 1 == 0 {
            return None;
        }

        let mut result = Self::ZERO;
        result.limbs[0] = {
            const W2: Wrapping<u64> = Wrapping(2);
            const W3: Wrapping<u64> = Wrapping(3);
            let n = Wrapping(self.limbs[0]);
            // Seed: `(3n) ^ 2` is an inverse of odd `n` on the low 5 bits.
            let mut inv = (n * W3) ^ W2;
            // Four Newton steps: 5 -> 10 -> 20 -> 40 -> 80 (>= 64) correct bits.
            inv *= W2 - n * inv;
            inv *= W2 - n * inv;
            inv *= W2 - n * inv;
            inv *= W2 - n * inv;
            debug_assert_eq!(n.0.wrapping_mul(inv.0), 1);
            inv.0
        };

        // Lift to the full width. With `k` correct limbs, one Newton step on
        // `Self` gives `2k` correct limbs. `*` and `*=` delegate to
        // `wrapping_mul`, so the arithmetic is modulo `2^BITS`.
        //
        // `Self::from(2)` stays inside the loop on purpose. The loop only runs
        // when `LIMBS >= 2`. Hoisting it would also evaluate it for `BITS == 1`,
        // where 2 does not fit (and `from` presumably panics).
        let mut correct_limbs = 1;
        while correct_limbs < LIMBS {
            result *= Self::from(2) - self * result;
            correct_limbs *= 2;
        }
        // When `LIMBS == 1`, the loop did not run and the 64-bit seed may have
        // bits above `BITS`; clear them.
        result.limbs[LIMBS - 1] &= Self::MASK;
        Some(result)
    }

    /// Computes the full product `self * rhs`, which can never overflow.
    ///
    /// `rhs` may be a [`Uint`] of any size. The result has
    /// `BITS_RES = BITS + BITS_RHS` bits stored in `LIMBS_RES = nlimbs(BITS_RES)`
    /// limbs.
    ///
    /// # Panics
    ///
    /// Panics if `BITS_RES != BITS + BITS_RHS` or
    /// `LIMBS_RES != nlimbs(BITS_RES)`.
    #[must_use]
    // The `*_RHS` and `*_RES` const parameters differ by a single letter; the
    // names are intentional.
    #[allow(clippy::similar_names)]
    pub fn widening_mul<
        const BITS_RHS: usize,
        const LIMBS_RHS: usize,
        const BITS_RES: usize,
        const LIMBS_RES: usize,
    >(
        self,
        rhs: Uint<BITS_RHS, LIMBS_RHS>,
    ) -> Uint<BITS_RES, LIMBS_RES> {
        assert_eq!(BITS_RES, BITS + BITS_RHS);
        assert_eq!(LIMBS_RES, nlimbs(BITS_RES));
        let mut result = Uint::<BITS_RES, LIMBS_RES>::ZERO;
        // A `BITS`-bit value times a `BITS_RHS`-bit value always fits in
        // `BITS + BITS_RHS` bits, so the overflow flag from `addmul` is unused.
        algorithms::addmul(&mut result.limbs, self.as_limbs(), rhs.as_limbs());
        if LIMBS_RES > 0 {
            debug_assert!(result.limbs[LIMBS_RES - 1] <= Uint::<BITS_RES, LIMBS_RES>::MASK);
        }
        result
    }
}
/// Multiplies every item of the iterator, wrapping around at `2^BITS`.
///
/// An empty iterator yields one. For the zero-bit type, whose only value is
/// zero, the result is always zero.
impl<const BITS: usize, const LIMBS: usize> Product<Self> for Uint<BITS, LIMBS> {
    fn product<I>(iter: I) -> Self
    where
        I: Iterator<Item = Self>,
    {
        // `Self::from(1)` has no meaning without any bits, so short-circuit
        // before building the multiplicative identity.
        if BITS == 0 {
            Self::ZERO
        } else {
            let mut acc = Self::from(1);
            for factor in iter {
                acc = acc.wrapping_mul(factor);
            }
            acc
        }
    }
}
/// Multiplies every referenced item of the iterator, wrapping around at
/// `2^BITS`.
///
/// Each item is copied, and the work is handed to the by-value [`Product`]
/// implementation. That implementation handles the empty iterator and the
/// zero-bit type.
impl<'a, const BITS: usize, const LIMBS: usize> Product<&'a Self> for Uint<BITS, LIMBS> {
    fn product<I>(iter: I) -> Self
    where
        I: Iterator<Item = &'a Self>,
    {
        iter.copied().product()
    }
}
// Wire the `Mul` / `MulAssign` operators to `wrapping_mul`, so `a * b` and
// `a *= b` compute the product modulo `2^BITS`. The macro (defined elsewhere)
// decides which owned/borrowed operand combinations are covered.
impl_bin_op!(Mul, mul, MulAssign, mul_assign, wrapping_mul);
// Property tests for multiplication. `const_for!` instantiates each test for
// every bit width in the named list (`SIZES`, `NON_ZERO`, `BENCH` are presumably
// width lists understood by that macro), and `proptest!` feeds random operands.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{const_for, nlimbs};
    use proptest::proptest;

    // `*` is commutative.
    #[test]
    fn test_commutative() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, b: U)| {
                assert_eq!(a * b, b * a);
            });
        });
    }

    // `*` is associative.
    #[test]
    fn test_associative() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, b: U, c: U)| {
                assert_eq!(a * (b * c), (a * b) * c);
            });
        });
    }

    // `*` distributes over `+`. This holds for arithmetic modulo `2^BITS` when
    // both operators wrap.
    #[test]
    fn test_distributive() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, b: U, c: U)| {
                assert_eq!(a * (b + c), (a * b) + (a *c));
            });
        });
    }

    // Zero is absorbing and one is the identity. Run only for non-zero widths,
    // since `U::from(1)` needs at least one bit.
    #[test]
    fn test_identity() {
        const_for!(BITS in NON_ZERO {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(value: U)| {
                assert_eq!(value * U::from(0), U::ZERO);
                assert_eq!(value * U::from(1), value);
            });
        });
    }

    // `inv_ring` returns a true inverse, and inverting twice round-trips.
    #[test]
    fn test_inverse() {
        const_for!(BITS in NON_ZERO {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(mut a: U)| {
                // Force `a` odd so that the inverse exists and `unwrap` is safe.
                a |= U::from(1); assert_eq!(a * a.inv_ring().unwrap(), U::from(1));
                assert_eq!(a.inv_ring().unwrap().inv_ring().unwrap(), a);
            });
        });
    }

    // `widening_mul` must match plain multiplication after both operands are
    // widened to the result type. `Res` has `BITS_LHS + BITS_RHS` bits, so that
    // product cannot wrap.
    #[test]
    fn test_widening_mul() {
        const_for!(BITS_LHS in BENCH {
            const LIMBS_LHS: usize = nlimbs(BITS_LHS);
            type Lhs = Uint<BITS_LHS, LIMBS_LHS>;
            const_for!(BITS_RHS in BENCH {
                const LIMBS_RHS: usize = nlimbs(BITS_RHS);
                type Rhs = Uint<BITS_RHS, LIMBS_RHS>;
                const BITS_RES: usize = BITS_LHS + BITS_RHS;
                const LIMBS_RES: usize = nlimbs(BITS_RES);
                type Res = Uint<BITS_RES, LIMBS_RES>;
                proptest!(|(lhs: Lhs, rhs: Rhs)| {
                    let expected = Res::from(lhs) * Res::from(rhs);
                    assert_eq!(lhs.widening_mul(rhs), expected);
                });
            });
        });
    }
}
// Criterion benchmarks for multiplication, compiled only with the `bench`
// feature. `group` is the entry point; presumably it is registered by a bench
// harness elsewhere in the crate.
#[cfg(feature = "bench")]
#[doc(hidden)]
pub mod bench {
    use super::*;
    use crate::{const_for, nlimbs};
    use ::proptest::{
        arbitrary::Arbitrary,
        strategy::{Strategy, ValueTree},
        test_runner::TestRunner,
    };
    use criterion::{black_box, BatchSize, Criterion};

    /// Registers the `mul/{BITS}` benchmarks for every width in `BENCH`. Also
    /// registers `widening_mul/{LHS}/{RHS}` for every pair of widths drawn from
    /// 64, 256 and 1024.
    pub fn group(criterion: &mut Criterion) {
        const_for!(BITS in BENCH {
            const LIMBS: usize = nlimbs(BITS);
            bench_mul::<BITS, LIMBS>(criterion);
        });
        const_for!(BITS_LHS in [64, 256,1024] {
            const LIMBS_LHS: usize = nlimbs(BITS_LHS);
            const_for!(BITS_RHS in [64, 256,1024] {
                const LIMBS_RHS: usize = nlimbs(BITS_RHS);
                // The result width is fixed by the operand widths.
                const BITS_RES: usize = BITS_LHS + BITS_RHS;
                const LIMBS_RES: usize = nlimbs(BITS_RES);
                bench_widening_mul::<BITS_LHS, LIMBS_LHS, BITS_RHS, LIMBS_RHS, BITS_RES, LIMBS_RES>(criterion);
            });
        });
    }

    // Times `a * b` on random `BITS`-bit operands. Operands are drawn from a
    // deterministic proptest runner in the `iter_batched` setup closure, so
    // input generation is not part of the measured routine.
    fn bench_mul<const BITS: usize, const LIMBS: usize>(criterion: &mut Criterion) {
        let input = (Uint::<BITS, LIMBS>::arbitrary(), Uint::arbitrary());
        let mut runner = TestRunner::deterministic();
        criterion.bench_function(&format!("mul/{BITS}"), move |bencher| {
            bencher.iter_batched(
                || input.new_tree(&mut runner).unwrap().current(),
                // `black_box` keeps the optimizer from folding the product away.
                |(a, b)| black_box(black_box(a) * black_box(b)),
                BatchSize::SmallInput,
            );
        });
    }

    // Times `widening_mul` on random operands of the two given widths. The
    // result const parameters are passed explicitly via turbofish.
    fn bench_widening_mul<
        const BITS_LHS: usize,
        const LIMBS_LHS: usize,
        const BITS_RHS: usize,
        const LIMBS_RHS: usize,
        const BITS_RES: usize,
        const LIMBS_RES: usize,
    >(
        criterion: &mut Criterion,
    ) {
        let input = (
            Uint::<BITS_LHS, LIMBS_LHS>::arbitrary(),
            Uint::<BITS_RHS, LIMBS_RHS>::arbitrary(),
        );
        let mut runner = TestRunner::deterministic();
        criterion.bench_function(
            &format!("widening_mul/{BITS_LHS}/{BITS_RHS}"),
            move |bencher| {
                bencher.iter_batched(
                    || input.new_tree(&mut runner).unwrap().current(),
                    |(a, b)| {
                        black_box(
                            black_box(a).widening_mul::<BITS_RHS, LIMBS_RHS, BITS_RES, LIMBS_RES>(
                                black_box(b),
                            ),
                        )
                    },
                    BatchSize::SmallInput,
                );
            },
        );
    }
}