use super::UInt;
use crate::{Concat, Limb, Wrapping};
use core::ops::{Mul, MulAssign};
use subtle::CtOption;
impl<const LIMBS: usize> UInt<LIMBS> {
    /// Multiplies `self` by `rhs`, keeping every bit of the product.
    ///
    /// The product is twice as wide as the operands and is returned as a
    /// `(hi, lo)` tuple: the most significant half first, followed by the
    /// least significant half.
    pub const fn mul_wide(&self, rhs: &Self) -> (Self, Self) {
        let mut low = Self::ZERO;
        let mut high = Self::ZERO;

        // Limb-by-limb long multiplication: one pass over `rhs` per limb of
        // `self`. Only `while` loops are used so this remains a `const fn`.
        // NOTE(review): `Limb::mac(a, b, carry)` is presumed to yield the
        // `(result limb, carry out)` of `acc + a * b + carry` — confirm in `Limb`.
        let mut row = 0;
        while row < LIMBS {
            let factor = self.limbs[row];
            // Output position `row + col` lands in `low` while it is below
            // `LIMBS`; from `col == boundary` onwards it lands in `high`.
            let boundary = LIMBS - row;
            let mut carry = Limb::ZERO;
            let mut col = 0;

            while col < boundary {
                let slot = row + col;
                let (limb, carry_out) = low.limbs[slot].mac(factor, rhs.limbs[col], carry);
                low.limbs[slot] = limb;
                carry = carry_out;
                col += 1;
            }

            while col < LIMBS {
                let slot = col - boundary;
                let (limb, carry_out) = high.limbs[slot].mac(factor, rhs.limbs[col], carry);
                high.limbs[slot] = limb;
                carry = carry_out;
                col += 1;
            }

            // The carry out of this row belongs at position `row + LIMBS`,
            // i.e. `high.limbs[row]`, which no earlier row has written to.
            high.limbs[row] = carry;
            row += 1;
        }

        (high, low)
    }

    /// Multiplies `self` by `rhs` and returns only the low half of the wide
    /// product; the high half is discarded.
    pub const fn wrapping_mul(&self, rhs: &Self) -> Self {
        let (_hi, lo) = self.mul_wide(rhs);
        lo
    }

    /// Multiplies `self` by `rhs`, returning a [`CtOption`] that carries the
    /// low half of the product and is flagged with whether the high half is
    /// zero (i.e. whether the product fit in `Self`).
    pub fn checked_mul(&self, rhs: &Self) -> CtOption<Self> {
        let (overflow, product) = self.mul_wide(rhs);
        let fits = overflow.is_zero();
        CtOption::new(product, fits)
    }

    /// Multiplies `self` by itself and returns the full-width result, built
    /// by calling `concat` on the high half with the low half as argument.
    pub fn square(&self) -> <Self as Concat>::Output
    where
        Self: Concat,
    {
        let (upper, lower) = self.mul_wide(self);
        upper.concat(&lower)
    }
}
impl<const LIMBS: usize> Mul for Wrapping<UInt<LIMBS>> {
type Output = Self;
fn mul(self, rhs: Self) -> Wrapping<UInt<LIMBS>> {
Wrapping(self.0.wrapping_mul(&rhs.0))
}
}
impl<const LIMBS: usize> Mul<&Wrapping<UInt<LIMBS>>> for Wrapping<UInt<LIMBS>> {
type Output = Wrapping<UInt<LIMBS>>;
fn mul(self, rhs: &Wrapping<UInt<LIMBS>>) -> Wrapping<UInt<LIMBS>> {
Wrapping(self.0.wrapping_mul(&rhs.0))
}
}
impl<const LIMBS: usize> Mul<Wrapping<UInt<LIMBS>>> for &Wrapping<UInt<LIMBS>> {
type Output = Wrapping<UInt<LIMBS>>;
fn mul(self, rhs: Wrapping<UInt<LIMBS>>) -> Wrapping<UInt<LIMBS>> {
Wrapping(self.0.wrapping_mul(&rhs.0))
}
}
impl<const LIMBS: usize> Mul<&Wrapping<UInt<LIMBS>>> for &Wrapping<UInt<LIMBS>> {
type Output = Wrapping<UInt<LIMBS>>;
fn mul(self, rhs: &Wrapping<UInt<LIMBS>>) -> Wrapping<UInt<LIMBS>> {
Wrapping(self.0.wrapping_mul(&rhs.0))
}
}
/// `Wrapping *= Wrapping`: replaces the inner integer with its
/// `wrapping_mul` product with `other`'s inner integer.
impl<const LIMBS: usize> MulAssign for Wrapping<UInt<LIMBS>> {
    fn mul_assign(&mut self, other: Self) {
        // Same computation the by-value `Mul` impl performs, written in place.
        self.0 = self.0.wrapping_mul(&other.0);
    }
}
/// `Wrapping *= &Wrapping`: replaces the inner integer with its
/// `wrapping_mul` product with `other`'s inner integer.
impl<const LIMBS: usize> MulAssign<&Wrapping<UInt<LIMBS>>> for Wrapping<UInt<LIMBS>> {
    fn mul_assign(&mut self, other: &Self) {
        // Same computation the by-reference `Mul` impl performs, written in place.
        self.0 = self.0.wrapping_mul(&other.0);
    }
}
#[cfg(test)]
mod tests {
    use crate::{Split, U64};

    /// Products involving zero are zero in both halves; `1 * 1` is `1` in the
    /// low half only. `mul_wide` returns `(hi, lo)`.
    #[test]
    fn mul_wide_zero_and_one() {
        let zero = U64::ZERO;
        let one = U64::ONE;

        assert_eq!(zero.mul_wide(&zero), (U64::ZERO, U64::ZERO));
        assert_eq!(zero.mul_wide(&one), (U64::ZERO, U64::ZERO));
        assert_eq!(one.mul_wide(&zero), (U64::ZERO, U64::ZERO));
        assert_eq!(one.mul_wide(&one), (U64::ZERO, U64::ONE));
    }

    /// Every pairwise product of these 32-bit operands fits in 64 bits, so the
    /// high half must be zero and the low half must match native `u64` math.
    #[test]
    fn mul_wide_lo_only() {
        let operands: [u32; 5] = [3, 5, 17, 256, 65537];

        for &a in operands.iter() {
            for &b in operands.iter() {
                let expected = U64::from_u64(u64::from(a) * u64::from(b));
                let (hi, lo) = U64::from_u32(a).mul_wide(&U64::from_u32(b));

                assert_eq!(lo, expected);
                assert!(bool::from(hi.is_zero()));
            }
        }
    }

    /// Squaring the largest 32-bit value still fits in 64 bits.
    #[test]
    fn checked_mul_ok() {
        let operand = U64::from_u32(0xffff_ffff);
        let product = operand.checked_mul(&operand).unwrap();

        assert_eq!(product, U64::from_u64(0xffff_fffe_0000_0001));
    }

    /// Squaring the largest 64-bit value overflows, so no value is returned.
    #[test]
    fn checked_mul_overflow() {
        let operand = U64::from_u64(0xffff_ffff_ffff_ffff);
        let outcome = operand.checked_mul(&operand);

        assert!(bool::from(outcome.is_none()));
    }

    /// `(2^64 - 1)^2` is `2^128 - 2^65 + 1`: the low half is `1` and the high
    /// half is `2^64 - 2`.
    #[test]
    fn square() {
        let operand = U64::from_u64(0xffff_ffff_ffff_ffff);
        let (hi, lo) = operand.square().split();

        assert_eq!(lo, U64::from_u64(1));
        assert_eq!(hi, U64::from_u64(0xffff_ffff_ffff_fffe));
    }
}