use fixed_bigint::MulAccOps;
use fixed_bigint::const_numtraits::ConstBorrowingSub;
use num_traits::ops::overflowing::OverflowingAdd;
use num_traits::{One, WrappingMul, Zero};
/// Word-by-word (CIOS-style) Montgomery multiplication of `a` and `b`.
///
/// Each loop iteration handles one word of `a`:
///
/// 1. `T::mul_acc_row` folds `a[i]` and `b` into the running accumulator and
///    hands back a carry word, which is added to the word kept above the
///    accumulator (tracking a possible wrap of that addition).
/// 2. A reduction factor is formed as the wrapping product of the lowest
///    accumulator word and `n_prime_0`.
/// 3. `T::mul_acc_shift_row` folds that factor and `modulus` into the
///    accumulator together with the high word; its return value is asserted
///    (debug builds only) to be `0` or `1`.
///
/// After the loop, `modulus` is subtracted once if the high word is non-zero
/// or the accumulator is not below `modulus`.
///
/// NOTE(review): the exact arithmetic performed by `mul_acc_row` and
/// `mul_acc_shift_row` is defined by `fixed_bigint::MulAccOps`, not here —
/// confirm against that crate's documentation.
///
/// # Parameters
/// - `a`, `b`: operands. Debug builds assert that each is below `modulus`.
/// - `modulus`: value used for the per-word reduction and the final
///   conditional subtraction.
/// - `n_prime_0`: word multiplied into the lowest accumulator word to obtain
///   the reduction factor. Presumably the lowest word of the Montgomery
///   constant for `modulus` — verify against the caller.
///
/// # Returns
/// `Some(result)` on success; `None` if `get_word` yields `None` for any
/// index that is read.
pub fn cios_montgomery_mul<T>(a: &T, b: &T, modulus: &T, n_prime_0: T::Word) -> Option<T>
where
    T: MulAccOps + PartialOrd + ConstBorrowingSub,
    T::Word: num_traits::Zero
        + num_traits::One
        + num_traits::WrappingMul
        + num_traits::ops::overflowing::OverflowingAdd
        + core::ops::Add<Output = T::Word>,
{
    debug_assert!(a < modulus, "CIOS input a must be in [0, modulus)");
    debug_assert!(b < modulus, "CIOS input b must be in [0, modulus)");

    let w_zero = <T::Word as Zero>::zero();
    let w_one = <T::Word as One>::one();

    // Running accumulator plus the single word carried above it between rows.
    let mut t = T::default();
    let mut t_hi = w_zero;

    for idx in 0..T::word_count() {
        // Multiply-accumulate row for a[idx].
        let a_word = a.get_word(idx)?;
        let row_carry = T::mul_acc_row(a_word, b, &mut t, w_zero);

        // Add the row carry into the high word; a wrap here is remembered in
        // `spill` and re-enters the high word after the reduction row.
        let (hi_sum, hi_wrapped) = t_hi.overflowing_add(&row_carry);
        let spill = if hi_wrapped { w_zero + w_one } else { w_zero };

        // Reduction row driven by the lowest accumulator word.
        let reduce_factor = t.get_word(0)?.wrapping_mul(&n_prime_0);
        let shift_carry = T::mul_acc_shift_row(reduce_factor, modulus, &mut t, hi_sum);
        debug_assert!(
            shift_carry == w_zero || shift_carry == w_one,
            "mul_acc_shift_row must return 0 or 1"
        );

        t_hi = spill + shift_carry;
    }

    // One conditional subtraction finishes the reduction.
    let must_reduce = t_hi > w_zero || t >= *modulus;
    if !must_reduce {
        return Some(t);
    }
    // The borrow-out is deliberately ignored.
    let (reduced, _) = <T as ConstBorrowingSub>::borrowing_sub(t, *modulus, false);
    Some(reduced)
}
/// Trait form of CIOS Montgomery multiplication, usable as a generic bound.
///
/// Implementors must also be `MulAccOps + PartialOrd + ConstBorrowingSub`.
pub trait CiosMontMul
where
    Self: MulAccOps + PartialOrd + ConstBorrowingSub,
{
    /// Multiplies `a` by `b` with respect to `modulus`.
    ///
    /// `n_prime` is passed as a full-width value, unlike the single word taken
    /// by the free function `cios_montgomery_mul`. Returns `None` when the
    /// computation cannot be carried out.
    fn cios_mont_mul(a: &Self, b: &Self, modulus: &Self, n_prime: &Self) -> Option<Self>;
}
/// Blanket implementation for every type that meets the bounds of
/// `cios_montgomery_mul`.
impl<T> CiosMontMul for T
where
    T: MulAccOps + PartialOrd + ConstBorrowingSub,
    T::Word: num_traits::Zero
        + num_traits::One
        + num_traits::WrappingMul
        + num_traits::ops::overflowing::OverflowingAdd
        + core::ops::Add<Output = T::Word>,
{
    /// Forwards to `cios_montgomery_mul`, passing only word 0 of `n_prime`.
    ///
    /// Returns `None` if word 0 of `n_prime` cannot be read or if the
    /// underlying routine returns `None`.
    #[inline]
    fn cios_mont_mul(a: &Self, b: &Self, modulus: &Self, n_prime: &Self) -> Option<Self> {
        let n_prime_0 = n_prime.get_word(0)?;
        cios_montgomery_mul(a, b, modulus, n_prime_0)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::montgomery::basic_mont::{
        compute_n_prime_newton, compute_r_mod_n, compute_r2_mod_n, type_bit_width,
        wide_montgomery_mul, wide_redc,
    };

    // Never called: exists so that naming `CiosMontMul` as a generic bound is
    // type-checked.
    fn _assert_generic_bound<T: CiosMontMul>() {}

    // All pairs of values in 0..13 on a 2 x u8 integer with modulus 13: the
    // CIOS result must equal `wide_montgomery_mul`.
    #[test]
    fn test_cios_vs_wide_redc_u8() {
        use fixed_bigint::FixedUInt;
        type U16 = FixedUInt<u8, 2>;

        let modulus = U16::from(13u16);
        let w = type_bit_width::<U16>();
        let n_prime = compute_n_prime_newton(modulus, w);
        let r_mod_n = compute_r_mod_n(modulus, w);
        let r2_mod_n = compute_r2_mod_n(r_mod_n, modulus, w);
        let n_prime_0 = n_prime.get_word(0).unwrap();

        // Wide-multiply by `r2_mod_n`, then reduce with `wide_redc`.
        let to_mont = |x: U16| {
            let (lo, hi) = crate::WideMul::wide_mul(&x, &r2_mod_n);
            wide_redc(lo, hi, modulus, n_prime)
        };

        for a_val in 0u16..13 {
            let a_m = to_mont(U16::from(a_val));
            for b_val in 0u16..13 {
                let b_m = to_mont(U16::from(b_val));
                let expected = wide_montgomery_mul(a_m, b_m, modulus, n_prime);
                let got = cios_montgomery_mul(&a_m, &b_m, &modulus, n_prime_0).unwrap();
                assert_eq!(
                    got, expected,
                    "CIOS mismatch for {a_val}*{b_val} mod 13: got {got:?}, expected {expected:?}"
                );
            }
        }
    }

    // Selected operand pairs on a 4 x u32 integer; the modulus is the all-ones
    // value minus 58.
    #[test]
    fn test_cios_u32x4() {
        use fixed_bigint::FixedUInt;
        type U128 = FixedUInt<u32, 4>;

        let modulus = !U128::from(0u64) - U128::from(58u64);
        let w = type_bit_width::<U128>();
        let n_prime = compute_n_prime_newton(modulus, w);
        let r_mod_n = compute_r_mod_n(modulus, w);
        let r2_mod_n = compute_r2_mod_n(r_mod_n, modulus, w);
        let n_prime_0 = n_prime.get_word(0).unwrap();

        // Wide-multiply by `r2_mod_n`, then reduce with `wide_redc`.
        let to_mont = |x: U128| {
            let (lo, hi) = crate::WideMul::wide_mul(&x, &r2_mod_n);
            wide_redc(lo, hi, modulus, n_prime)
        };

        let test_vals = [0u64, 1, 2, 42, 0xDEAD_BEEF, 0xCAFE_BABE];
        for &a_val in &test_vals {
            let a_m = to_mont(U128::from(a_val));
            for &b_val in &test_vals {
                let b_m = to_mont(U128::from(b_val));
                let expected = wide_montgomery_mul(a_m, b_m, modulus, n_prime);
                let got = cios_montgomery_mul(&a_m, &b_m, &modulus, n_prime_0).unwrap();
                assert_eq!(got, expected, "CIOS mismatch for {a_val:#x}*{b_val:#x}");
            }
        }
    }

    // Convert both operands, multiply with CIOS, reduce the product with
    // `wide_redc`, and compare against `basic_mod_mul` on the plain operands.
    #[test]
    fn test_cios_roundtrip() {
        use fixed_bigint::FixedUInt;
        type U128 = FixedUInt<u32, 4>;

        let modulus = !U128::from(0u64) - U128::from(58u64);
        let w = type_bit_width::<U128>();
        let n_prime = compute_n_prime_newton(modulus, w);
        let r_mod_n = compute_r_mod_n(modulus, w);
        let r2_mod_n = compute_r2_mod_n(r_mod_n, modulus, w);

        // Wide-multiply by `r2_mod_n`, then reduce with `wide_redc`.
        let to_mont = |x: U128| {
            let (lo, hi) = crate::WideMul::wide_mul(&x, &r2_mod_n);
            wide_redc(lo, hi, modulus, n_prime)
        };

        let a = U128::from(0xDEAD_BEEF_u64);
        let b = U128::from(0xCAFE_BABE_u64);
        let a_m = to_mont(a);
        let b_m = to_mont(b);

        let n_prime_0 = n_prime.get_word(0).unwrap();
        let result_m = cios_montgomery_mul(&a_m, &b_m, &modulus, n_prime_0).unwrap();

        // Reduce the product once more, with a zero high half.
        let result = wide_redc(result_m, U128::from(0u64), modulus, n_prime);
        let expected = crate::mul::basic_mod_mul(a, b, modulus);
        assert_eq!(result, expected);
    }

    // The trait entry point must agree with `wide_montgomery_mul`.
    #[test]
    fn test_cios_mont_mul_trait() {
        use fixed_bigint::FixedUInt;
        type U128 = FixedUInt<u32, 4>;

        let modulus = !U128::from(0u64) - U128::from(58u64);
        let w = type_bit_width::<U128>();
        let n_prime = compute_n_prime_newton(modulus, w);
        let r_mod_n = compute_r_mod_n(modulus, w);
        let r2_mod_n = compute_r2_mod_n(r_mod_n, modulus, w);

        // Wide-multiply by `r2_mod_n`, then reduce with `wide_redc`.
        let to_mont = |x: U128| {
            let (lo, hi) = crate::WideMul::wide_mul(&x, &r2_mod_n);
            wide_redc(lo, hi, modulus, n_prime)
        };

        let a_m = to_mont(U128::from(42u64));
        let b_m = to_mont(U128::from(69u64));

        let result =
            <U128 as CiosMontMul>::cios_mont_mul(&a_m, &b_m, &modulus, &n_prime).unwrap();
        let expected = wide_montgomery_mul(a_m, b_m, modulus, n_prime);
        assert_eq!(result, expected);
    }
}