use crate::inv::basic_mod_inv;
use crate::parity::Parity;
use crate::wide_mul::WideMul;
/// Strategy used to derive `N'` (with `N·N' ≡ -1 (mod R)`) in
/// [`basic_compute_montgomery_params_with_method`].
#[derive(Default, Clone, Copy, Debug, Eq, PartialEq)]
pub enum NPrimeMethod {
    /// Linear search over candidates `1..R`; `O(R)` time, only practical for small moduli.
    TrialSearch,
    /// Negated modular inverse of `N` modulo `R`, obtained via `basic_mod_inv`.
    ExtendedEuclidean,
    /// Lifts the congruence one bit at a time, from mod 2 up to mod `R`.
    HenselsLifting,
    /// The default strategy. In `basic_compute_montgomery_params_with_method` it currently
    /// delegates to the extended Euclidean computation; the wide-word routines call
    /// [`compute_n_prime_newton`] directly.
    #[default]
    Newton,
}
/// Finds `N'` with `modulus * N' ≡ -1 (mod r)` by linear search over `1..r`.
///
/// Runs in `O(r)` time, so it is only practical for small `r`. The running value
/// `modulus * n_prime mod r` is maintained incrementally with an overflow-free modular
/// addition. The full product `modulus * n_prime` is never formed, so the search cannot
/// overflow `T` for any `r` representable in `T`.
///
/// Returns the smallest such `N'` in `1..r`, or `None` if none exists (for example when
/// `modulus` and `r` share a factor).
fn compute_n_prime_trial_search<T>(modulus: T, r: T) -> Option<T>
where
    T: Copy
        + num_traits::Zero
        + num_traits::One
        + PartialEq
        + PartialOrd
        + core::ops::Add<Output = T>
        + core::ops::Sub<Output = T>
        + core::ops::Mul<Output = T>
        + core::ops::Rem<Output = T>,
{
    let target = r - T::one();
    // `step` lies in [0, r) and `gap` in (0, r], so neither branch below can overflow.
    let step = modulus % r;
    let gap = r - step;
    let mut n_prime = T::one();
    // Invariant: product == (modulus * n_prime) mod r.
    let mut product = step;
    loop {
        if product == target {
            return Some(n_prime);
        }
        n_prime = n_prime + T::one();
        if n_prime >= r {
            return None;
        }
        // product = (product + step) mod r, without forming a sum that could exceed T::MAX.
        product = if product >= gap {
            product - gap
        } else {
            product + step
        };
    }
}
/// Computes `N' = -modulus⁻¹ mod r` from the modular inverse returned by `basic_mod_inv`.
///
/// Returns `None` whenever `basic_mod_inv(modulus, r)` does (e.g. non-coprime inputs).
fn compute_n_prime_extended_euclidean<T>(modulus: T, r: T) -> Option<T>
where
    T: Copy
        + num_traits::Zero
        + num_traits::One
        + PartialEq
        + PartialOrd
        + core::ops::Add<Output = T>
        + core::ops::Sub<Output = T>
        + core::ops::Mul<Output = T>
        + core::ops::Rem<Output = T>
        + core::ops::Div<Output = T>,
{
    // N' = r - inv. A zero inverse (mathematically only possible when r == 1) would give
    // `r - 0 == r`, which is out of range, so it is mapped to `r - 1` instead.
    basic_mod_inv(modulus, r).map(|inv| {
        let subtrahend = if inv == T::zero() { T::one() } else { inv };
        r - subtrahend
    })
}
/// Computes `N'` with `modulus * N' ≡ -1 (mod r)` by Hensel lifting, one bit at a time.
///
/// `r` must equal `2^r_bits`. Starting from `N' = 1` (correct mod 2 for odd `modulus`),
/// step `k` extends the solution from mod `2^(k-1)` to mod `2^k` by conditionally adding
/// `2^(k-1)`. The value `modulus * N' mod r` is tracked incrementally with overflow-free
/// modular additions, so no intermediate exceeds `r`. The full product `modulus * N'` is
/// never formed, so large `r` cannot overflow `T`. Assumes unsigned `T`.
///
/// Returns `None` if the lifted candidate does not satisfy the congruence, which is the
/// case for an even `modulus`.
fn compute_n_prime_hensels_lifting<T>(modulus: T, r: T, r_bits: usize) -> Option<T>
where
    T: Copy
        + num_traits::Zero
        + num_traits::One
        + PartialEq
        + PartialOrd
        + core::ops::Add<Output = T>
        + core::ops::Sub<Output = T>
        + core::ops::Mul<Output = T>
        + core::ops::Rem<Output = T>
        + core::ops::Shl<usize, Output = T>
        + core::ops::BitAnd<Output = T>,
{
    // (acc + addend) mod r for acc, addend in [0, r), without overflowing T.
    let add_mod_r = |acc: T, addend: T| {
        let gap = r - addend;
        if acc >= gap {
            acc - gap
        } else {
            acc + addend
        }
    };
    let mut n_prime = T::one();
    // Invariant: product == (modulus * n_prime) mod r.
    let mut product = modulus % r;
    for k in 2..=r_bits {
        let target_mod = T::one() << k;
        let prev_power = T::one() << (k - 1);
        // 2^k divides r, so `product mod 2^k` equals `modulus * n_prime mod 2^k`.
        // product < r, so `product + 1 <= r` fits in T.
        let check_val = (product + T::one()) % target_mod;
        // After lifting to mod 2^(k-1), check_val is either 0 (already correct mod 2^k)
        // or exactly 2^(k-1) (fix by adding 2^(k-1) to n_prime).
        if check_val == prev_power {
            n_prime = n_prime + prev_power;
            // modulus * 2^(k-1) mod r: bits shifted out of T are multiples of 2^w, hence of r.
            product = add_mod_r(product, (modulus << (k - 1)) % r);
        }
    }
    if product == r - T::one() {
        Some(n_prime)
    } else {
        None
    }
}
/// Computes the classic Montgomery parameters `(R, R⁻¹ mod N, N', r_bits)` for `modulus`,
/// deriving `N'` with the requested [`NPrimeMethod`].
///
/// `R` is the smallest power of two strictly greater than `modulus`, and `r_bits = log2(R)`.
/// `R⁻¹` is the inverse of `R` modulo `modulus`. `N'` satisfies `modulus·N' ≡ -1 (mod R)`.
///
/// Returns `None` when:
/// - `modulus` is zero;
/// - `R` is not representable in `T`, i.e. shifting wraps to zero because `modulus` has
///   the top bit of `T` set;
/// - `basic_mod_inv(R, modulus)` fails (e.g. for an even modulus); or
/// - the selected `N'` computation fails.
///
/// `NPrimeMethod::Newton` currently falls back to the extended Euclidean computation,
/// because [`compute_n_prime_newton`] works modulo the full type width rather than `R`.
pub fn basic_compute_montgomery_params_with_method<T>(
    modulus: T,
    method: NPrimeMethod,
) -> Option<(T, T, T, usize)>
where
    T: Copy
        + num_traits::Zero
        + num_traits::One
        + PartialEq
        + PartialOrd
        + core::ops::Shl<usize, Output = T>
        + core::ops::Div<Output = T>
        + core::ops::Sub<Output = T>
        + core::ops::Mul<Output = T>
        + core::ops::Rem<Output = T>
        + core::ops::Add<Output = T>
        + core::ops::BitAnd<Output = T>,
{
    if modulus == T::zero() {
        return None;
    }
    let mut r = T::one();
    let mut r_bits = 0usize;
    while r <= modulus {
        r = r << 1;
        r_bits += 1;
        // `<<` silently drops the top bit, so r wraps to zero once no power of two in T
        // exceeds modulus. Without this check the loop would spin forever on `0 <= modulus`.
        if r == T::zero() {
            return None;
        }
    }
    let r_inv = basic_mod_inv(r, modulus)?;
    let n_prime = match method {
        NPrimeMethod::TrialSearch => compute_n_prime_trial_search(modulus, r)?,
        NPrimeMethod::ExtendedEuclidean => compute_n_prime_extended_euclidean(modulus, r)?,
        NPrimeMethod::HenselsLifting => compute_n_prime_hensels_lifting(modulus, r, r_bits)?,
        // Newton's iteration in this module targets 2^w, not R; reuse the Euclidean path.
        NPrimeMethod::Newton => compute_n_prime_extended_euclidean(modulus, r)?,
    };
    Some((r, r_inv, n_prime, r_bits))
}
pub fn basic_compute_montgomery_params<T>(modulus: T) -> Option<(T, T, T, usize)>
where
T: Copy
+ num_traits::Zero
+ num_traits::One
+ PartialEq
+ PartialOrd
+ core::ops::Shl<usize, Output = T>
+ core::ops::Div<Output = T>
+ core::ops::Sub<Output = T>
+ core::ops::Mul<Output = T>
+ core::ops::Rem<Output = T>
+ core::ops::Add<Output = T>
+ core::ops::BitAnd<Output = T>,
{
basic_compute_montgomery_params_with_method(modulus, NPrimeMethod::default())
}
pub fn basic_to_montgomery<T>(a: T, modulus: T, r: T) -> T
where
T: core::cmp::PartialOrd
+ Copy
+ num_traits::Zero
+ num_traits::One
+ num_traits::ops::wrapping::WrappingAdd
+ num_traits::ops::wrapping::WrappingSub
+ core::ops::Shr<usize, Output = T>
+ core::ops::Rem<Output = T>
+ crate::parity::Parity,
{
crate::mul::basic_mod_mul(a, r, modulus)
}
/// Classic Montgomery reduction (REDC) with `R = 2^r_bits`: returns `a_mont·R⁻¹ mod modulus`.
///
/// `n_prime` must satisfy `modulus·n_prime ≡ -1 (mod R)`. The single conditional subtraction
/// yields a fully reduced result only when `a_mont < modulus·R`.
///
/// Intermediates are computed in `T` without widening. Both `(R-1)·n_prime` and
/// `a_mont + (R-1)·modulus` must therefore fit in `T`. Use [`wide_from_montgomery`] for
/// full-width moduli.
///
/// With `r_bits == 0` (`R = 1`) there is nothing to divide out, so the function only
/// subtracts `modulus` once if `a_mont >= modulus`.
pub fn basic_from_montgomery<T>(a_mont: T, modulus: T, n_prime: T, r_bits: usize) -> T
where
    T: Copy
        + num_traits::Zero
        + num_traits::One
        + PartialOrd
        + core::ops::Mul<Output = T>
        + core::ops::Add<Output = T>
        + core::ops::Sub<Output = T>
        + core::ops::Shr<usize, Output = T>
        + core::ops::Shl<usize, Output = T>
        + core::ops::BitAnd<Output = T>,
{
    if r_bits > 0 {
        // Masking with R - 1 implements "mod R" for the power-of-two R.
        let low_mask = (T::one() << r_bits) - T::one();
        // q ≡ a_mont·N' (mod R) makes a_mont + q·N divisible by R.
        let q = ((a_mont & low_mask) * n_prime) & low_mask;
        let shifted = (a_mont + q * modulus) >> r_bits;
        return if shifted >= modulus {
            shifted - modulus
        } else {
            shifted
        };
    }
    if a_mont >= modulus {
        a_mont - modulus
    } else {
        a_mont
    }
}
/// Leaves full-width Montgomery form: returns `a_mont·2^-w mod modulus`, where `w` is the
/// bit width of `T`.
///
/// This is [`wide_redc`] applied to the double-width value whose high word is zero.
/// `n_prime` must satisfy `modulus·n_prime ≡ -1 (mod 2^w)`.
pub fn wide_from_montgomery<T>(a_mont: T, modulus: T, n_prime: T) -> T
where
    T: Copy
        + num_traits::Zero
        + num_traits::One
        + PartialOrd
        + WideMul
        + num_traits::ops::overflowing::OverflowingAdd
        + num_traits::WrappingMul
        + num_traits::WrappingSub,
{
    let high_word = T::zero();
    wide_redc(a_mont, high_word, modulus, n_prime)
}
/// Multiplies two classic Montgomery-form values: returns `a_mont·b_mont·R⁻¹ mod modulus`.
///
/// The plain product is first reduced with `crate::mul::basic_mod_mul`, so the value handed
/// to [`basic_from_montgomery`] is already below `modulus`. That REDC step then divides out
/// one factor of `R`, leaving the product in Montgomery form.
pub fn basic_montgomery_mul<T>(a_mont: T, b_mont: T, modulus: T, n_prime: T, r_bits: usize) -> T
where
    T: Copy
        + num_traits::Zero
        + num_traits::One
        + PartialOrd
        + core::ops::Mul<Output = T>
        + core::ops::Add<Output = T>
        + core::ops::Sub<Output = T>
        + core::ops::Rem<Output = T>
        + core::ops::Shr<usize, Output = T>
        + core::ops::Shl<usize, Output = T>
        + core::ops::BitAnd<Output = T>
        + num_traits::ops::wrapping::WrappingAdd
        + num_traits::ops::wrapping::WrappingSub
        + crate::parity::Parity,
{
    basic_from_montgomery(
        crate::mul::basic_mod_mul(a_mont, b_mont, modulus),
        modulus,
        n_prime,
        r_bits,
    )
}
/// Number of bits in the in-memory representation of `T`, i.e. `size_of::<T>() * 8`.
///
/// For primitive integers this is their bit width. For composite types it counts every
/// byte of the type, including any padding.
pub const fn type_bit_width<T>() -> usize {
    let bytes = core::mem::size_of::<T>();
    bytes * 8
}
/// Returns `2·val mod modulus`, assuming `val < modulus`.
///
/// The doubling may carry out of `T`. In that case the true value `2^w + twice` is at
/// least `modulus`, and one wrapping subtraction brings it back into range.
fn mod_double<T>(val: T, modulus: T) -> T
where
    T: Copy + PartialOrd + num_traits::ops::overflowing::OverflowingAdd + num_traits::WrappingSub,
{
    let (twice, carried) = val.overflowing_add(&val);
    let exceeds = carried || twice >= modulus;
    if exceeds {
        twice.wrapping_sub(&modulus)
    } else {
        twice
    }
}
/// Computes `N'` with `modulus·N' ≡ -1 (mod 2^w)` using Newton's iteration for the
/// inverse modulo a power of two.
///
/// Starting from `x = 1` (correct to one bit for odd `modulus`), each step
/// `x ← x·(2 - modulus·x)` doubles the number of correct low bits. All arithmetic wraps,
/// i.e. works modulo `2^(bits of T)`. The result is meaningful only for odd `modulus`.
pub fn compute_n_prime_newton<T>(modulus: T, w: usize) -> T
where
    T: Copy
        + num_traits::One
        + num_traits::Zero
        + num_traits::WrappingMul
        + num_traits::WrappingSub
        + num_traits::WrappingAdd,
{
    let one = T::one();
    let two = one.wrapping_add(&one);
    let mut inv = one;
    let mut correct_bits = 1usize;
    while correct_bits < w {
        let correction = two.wrapping_sub(&modulus.wrapping_mul(&inv));
        inv = inv.wrapping_mul(&correction);
        correct_bits *= 2;
    }
    // N' = -modulus⁻¹.
    T::zero().wrapping_sub(&inv)
}
/// Returns `val·2^w mod modulus` by doubling `val` modulo `modulus` `w` times.
///
/// Each step uses [`mod_double`], which requires its input to be below `modulus`, so `val`
/// must already be reduced. A modulus of one short-circuits to zero.
fn mod_exp2<T>(val: T, modulus: T, w: usize) -> T
where
    T: Copy
        + PartialEq
        + PartialOrd
        + num_traits::Zero
        + num_traits::One
        + num_traits::ops::overflowing::OverflowingAdd
        + num_traits::WrappingSub,
{
    if modulus == T::one() {
        return T::zero();
    }
    (0..w).fold(val, |acc, _| mod_double(acc, modulus))
}
/// Returns `R mod modulus` for `R = 2^w`, computed without representing `R` itself.
///
/// This is also the Montgomery representative of `1`.
pub fn compute_r_mod_n<T>(modulus: T, w: usize) -> T
where
    T: Copy
        + PartialEq
        + PartialOrd
        + num_traits::Zero
        + num_traits::One
        + num_traits::ops::overflowing::OverflowingAdd
        + num_traits::WrappingSub,
{
    let one = T::one();
    mod_exp2(one, modulus, w)
}
/// Returns `R² mod modulus` for `R = 2^w`, given `r_mod_n = R mod modulus`.
///
/// It doubles `r_mod_n` another `w` times. `REDC(x·R²)` then yields the Montgomery form
/// `x·R mod modulus`.
pub fn compute_r2_mod_n<T>(r_mod_n: T, modulus: T, w: usize) -> T
where
    T: Copy
        + PartialEq
        + PartialOrd
        + num_traits::Zero
        + num_traits::One
        + num_traits::ops::overflowing::OverflowingAdd
        + num_traits::WrappingSub,
{
    mod_exp2::<T>(r_mod_n, modulus, w)
}
/// Folds the carry out of the low-half addition (`carry1`) into the high-half sum `result`.
///
/// Returns the adjusted high half and whether any carry escaped the high half: either the
/// existing `carry2`, or a new overflow from adding `carry1`.
fn accumulate_high_half_carry<T>(result: T, carry1: bool, carry2: bool) -> (T, bool)
where
    T: Copy + num_traits::One + num_traits::ops::overflowing::OverflowingAdd,
{
    if !carry1 {
        return (result, carry2);
    }
    let (bumped, wrapped) = result.overflowing_add(&T::one());
    (bumped, wrapped || carry2)
}
/// Full-width Montgomery reduction: for the double-width value `t = t_hi·2^w + t_lo`,
/// returns `t·2^-w mod modulus`.
///
/// `n_prime` must satisfy `modulus·n_prime ≡ -1 (mod 2^w)`. The single conditional
/// subtraction gives a fully reduced result when `t < modulus·2^w`. The `(w+1)`-th bit of
/// `(t + q·modulus) / 2^w` is carried in `overflowed` and handled by the wrapping
/// subtraction.
pub fn wide_redc<T>(t_lo: T, t_hi: T, modulus: T, n_prime: T) -> T
where
    T: Copy
        + num_traits::Zero
        + num_traits::One
        + PartialOrd
        + WideMul
        + num_traits::ops::overflowing::OverflowingAdd
        + num_traits::WrappingMul
        + num_traits::WrappingSub,
{
    // q = t_lo·N' mod 2^w, chosen so that t + q·modulus is divisible by 2^w.
    let q = t_lo.wrapping_mul(&n_prime);
    let (qn_lo, qn_hi) = q.wide_mul(&modulus);
    // The low word of the sum is zero by construction; only its carry is needed.
    let (_, low_carry) = t_lo.overflowing_add(&qn_lo);
    let (high_sum, high_carry) = t_hi.overflowing_add(&qn_hi);
    let (quotient, overflowed) = accumulate_high_half_carry(high_sum, low_carry, high_carry);
    let needs_subtraction = overflowed || quotient >= modulus;
    if needs_subtraction {
        quotient.wrapping_sub(&modulus)
    } else {
        quotient
    }
}
/// Multiplies two full-width Montgomery-form values: returns
/// `a_mont·b_mont·2^-w mod modulus`.
///
/// The double-width product from `WideMul` is reduced with [`wide_redc`].
pub fn wide_montgomery_mul<T>(a_mont: T, b_mont: T, modulus: T, n_prime: T) -> T
where
    T: Copy
        + num_traits::Zero
        + num_traits::One
        + PartialOrd
        + WideMul
        + num_traits::ops::overflowing::OverflowingAdd
        + num_traits::WrappingMul
        + num_traits::WrappingSub,
{
    let product = a_mont.wide_mul(&b_mont);
    wide_redc(product.0, product.1, modulus, n_prime)
}
/// Computes `a·b mod modulus` via Montgomery multiplication.
///
/// The `_method` argument is accepted for API compatibility but ignored: the computation
/// always goes through [`basic_montgomery_mod_mul`], which derives `N'` with
/// [`compute_n_prime_newton`]. Returns `None` for a zero or even modulus.
pub fn basic_montgomery_mod_mul_with_method<T>(
    a: T,
    b: T,
    modulus: T,
    _method: NPrimeMethod,
) -> Option<T>
where
    T: Copy
        + num_traits::Zero
        + num_traits::One
        + PartialEq
        + PartialOrd
        + WideMul
        + num_traits::ops::overflowing::OverflowingAdd
        + num_traits::WrappingMul
        + num_traits::WrappingAdd
        + num_traits::WrappingSub
        + Parity
        + core::ops::Rem<Output = T>,
{
    basic_montgomery_mod_mul::<T>(a, b, modulus)
}
/// Reduces `val` into `[0, modulus)` with the `%` operator.
fn reduce_mod<T>(val: T, modulus: T) -> T
where
    T: Copy + core::ops::Rem<Output = T>,
{
    core::ops::Rem::rem(val, modulus)
}
/// Computes `a·b mod modulus` using full-width Montgomery arithmetic (`R = 2^w`, where `w`
/// is the bit width of `T`).
///
/// Inputs are reduced modulo `modulus` first, so values `>= modulus` are accepted.
/// Returns `None` for a zero or even modulus, since `R` is then not invertible.
pub fn basic_montgomery_mod_mul<T>(a: T, b: T, modulus: T) -> Option<T>
where
    T: Copy
        + num_traits::Zero
        + num_traits::One
        + PartialEq
        + PartialOrd
        + WideMul
        + num_traits::ops::overflowing::OverflowingAdd
        + num_traits::WrappingMul
        + num_traits::WrappingAdd
        + num_traits::WrappingSub
        + Parity
        + core::ops::Rem<Output = T>,
{
    if modulus == T::zero() || modulus.is_even() {
        return None;
    }
    let width = type_bit_width::<T>();
    let n_prime = compute_n_prime_newton(modulus, width);
    let r2_mod_n = compute_r2_mod_n(compute_r_mod_n(modulus, width), modulus, width);
    // x ↦ REDC((x mod N)·R²) = x·R mod N, i.e. Montgomery form.
    let into_mont = |x: T| {
        let (lo, hi) = reduce_mod(x, modulus).wide_mul(&r2_mod_n);
        wide_redc(lo, hi, modulus, n_prime)
    };
    let product_mont = wide_montgomery_mul(into_mont(a), into_mont(b), modulus, n_prime);
    // REDC(x·R) with a zero high word = x: leave Montgomery form.
    Some(wide_redc(product_mont, T::zero(), modulus, n_prime))
}
/// Computes `base^exponent mod modulus` via Montgomery exponentiation.
///
/// The `_method` argument is accepted for API compatibility but ignored: the computation
/// always goes through [`basic_montgomery_mod_exp`]. Returns `None` for a zero or even
/// modulus.
pub fn basic_montgomery_mod_exp_with_method<T>(
    base: T,
    exponent: T,
    modulus: T,
    _method: NPrimeMethod,
) -> Option<T>
where
    T: Copy
        + num_traits::Zero
        + num_traits::One
        + PartialEq
        + PartialOrd
        + WideMul
        + num_traits::ops::overflowing::OverflowingAdd
        + num_traits::WrappingMul
        + num_traits::WrappingAdd
        + num_traits::WrappingSub
        + Parity
        + core::ops::Rem<Output = T>
        + core::ops::Shr<usize, Output = T>
        + core::ops::ShrAssign<usize>,
{
    basic_montgomery_mod_exp::<T>(base, exponent, modulus)
}
/// Computes `base^exponent mod modulus` with right-to-left binary exponentiation in
/// full-width Montgomery form (`R = 2^w`).
///
/// `base` is reduced modulo `modulus` first. An exponent of zero yields `1 mod modulus`.
/// Returns `None` for a zero or even modulus.
pub fn basic_montgomery_mod_exp<T>(base: T, exponent: T, modulus: T) -> Option<T>
where
    T: Copy
        + num_traits::Zero
        + num_traits::One
        + PartialEq
        + PartialOrd
        + WideMul
        + num_traits::ops::overflowing::OverflowingAdd
        + num_traits::WrappingMul
        + num_traits::WrappingAdd
        + num_traits::WrappingSub
        + Parity
        + core::ops::Rem<Output = T>
        + core::ops::Shr<usize, Output = T>
        + core::ops::ShrAssign<usize>,
{
    if modulus == T::zero() || modulus.is_even() {
        return None;
    }
    let width = type_bit_width::<T>();
    let n_prime = compute_n_prime_newton(modulus, width);
    let r_mod_n = compute_r_mod_n(modulus, width);
    let r2_mod_n = compute_r2_mod_n(r_mod_n, modulus, width);
    let mont_mul = |x: T, y: T| wide_montgomery_mul(x, y, modulus, n_prime);

    // Current power base^(2^i) in Montgomery form.
    let mut square = {
        let (lo, hi) = reduce_mod(base, modulus).wide_mul(&r2_mod_n);
        wide_redc(lo, hi, modulus, n_prime)
    };
    // R mod N is the Montgomery representative of 1.
    let mut acc = r_mod_n;
    let mut remaining = exponent;
    while remaining > T::zero() {
        if remaining.is_odd() {
            acc = mont_mul(acc, square);
        }
        remaining >>= 1;
        // Skip the final squaring, whose result would never be used.
        if remaining > T::zero() {
            square = mont_mul(square, square);
        }
    }
    Some(wide_redc(acc, T::zero(), modulus, n_prime))
}
#[cfg(test)]
mod tests {
// Unit tests for the classic (`basic_*`, R = next power of two above N) and full-width
// (`wide_*`, R = 2^w) Montgomery helpers. Expected values are either hand-computed or
// cross-checked against `crate::mul::basic_mod_mul` / `crate::exp::basic_mod_exp`.
use super::*;
// Even modulus 4 shares a factor with R = 8, so no N' exists.
#[test]
fn test_basic_compute_n_prime_trial_search_failure() {
let modulus = 4u32;
let r = 8u32;
let result = compute_n_prime_trial_search(modulus, r);
assert!(
result.is_none(),
"Should return None for invalid modulus-R pair"
);
}
// An even modulus is rejected by the default parameter computation.
#[test]
fn test_basic_compute_montgomery_params_failure() {
let even_modulus = 4u32;
let result = basic_compute_montgomery_params(even_modulus);
assert!(result.is_none(), "Should return None for even modulus");
}
// Every explicit N' strategy must reject the even modulus 4.
#[test]
fn test_basic_compute_montgomery_params_failure_with_method() {
let invalid_modulus = 4u32;
let trial_result =
basic_compute_montgomery_params_with_method(invalid_modulus, NPrimeMethod::TrialSearch);
assert!(
trial_result.is_none(),
"Trial search should fail with even modulus"
);
let euclidean_result = basic_compute_montgomery_params_with_method(
invalid_modulus,
NPrimeMethod::ExtendedEuclidean,
);
assert!(
euclidean_result.is_none(),
"Extended Euclidean should fail with even modulus"
);
let hensels_result = basic_compute_montgomery_params_with_method(
invalid_modulus,
NPrimeMethod::HenselsLifting,
);
assert!(
hensels_result.is_none(),
"Hensel's lifting should fail with even modulus"
);
}
// The full-width mod_mul path rejects an even modulus up front.
#[test]
fn test_basic_montgomery_mod_mul_parameter_failure() {
let invalid_modulus = 4u32;
let a = 2u32;
let b = 3u32;
let result = basic_montgomery_mod_mul(a, b, invalid_modulus);
assert!(
result.is_none(),
"Montgomery mod_mul should return None for invalid modulus"
);
}
// The full-width mod_exp path rejects an even modulus up front.
#[test]
fn test_basic_montgomery_mod_exp_parameter_failure() {
let invalid_modulus = 4u32;
let base = 2u32;
let exponent = 3u32;
let result = basic_montgomery_mod_exp(base, exponent, invalid_modulus);
assert!(
result.is_none(),
"Montgomery mod_exp should return None for invalid modulus"
);
}
// Round-trips the largest residue (14) through classic Montgomery form with N = 15.
#[test]
fn test_basic_montgomery_reduction_final_subtraction() {
let modulus = 15u32;
let (r, _r_inv, n_prime, r_bits) = basic_compute_montgomery_params(modulus).unwrap();
let high_value = 14u32; let mont_high = basic_to_montgomery(high_value, modulus, r);
let result = basic_from_montgomery(mont_high, modulus, n_prime, r_bits);
assert_eq!(result, high_value);
let mont_max = basic_to_montgomery(modulus - 1, modulus, r);
let result_max = basic_from_montgomery(mont_max, modulus, n_prime, r_bits);
assert_eq!(result_max, modulus - 1);
}
// Multiplies two large residues mod 21 in classic Montgomery form and compares with plain arithmetic.
#[test]
fn test_basic_multiplication_edge_cases() {
let modulus = 21u32; let (r, _r_inv, n_prime, r_bits) = basic_compute_montgomery_params(modulus).unwrap();
let a = 20u32; let b = 19u32;
let a_mont = basic_to_montgomery(a, modulus, r);
let b_mont = basic_to_montgomery(b, modulus, r);
let result_mont = basic_montgomery_mul(a_mont, b_mont, modulus, n_prime, r_bits);
let result = basic_from_montgomery(result_mont, modulus, n_prime, r_bits);
let expected = (a * b) % modulus;
assert_eq!(result, expected);
}
// All three explicit N' strategies must succeed and return identical parameter tuples.
#[test]
fn test_basic_n_prime_computation_edge_cases() {
let test_moduli = [9u32, 15u32, 21u32, 25u32, 27u32];
for &modulus in &test_moduli {
let trial_result =
basic_compute_montgomery_params_with_method(modulus, NPrimeMethod::TrialSearch);
let euclidean_result = basic_compute_montgomery_params_with_method(
modulus,
NPrimeMethod::ExtendedEuclidean,
);
let hensels_result =
basic_compute_montgomery_params_with_method(modulus, NPrimeMethod::HenselsLifting);
assert!(
trial_result.is_some(),
"Trial search failed for modulus {}",
modulus
);
assert!(
euclidean_result.is_some(),
"Extended Euclidean failed for modulus {}",
modulus
);
assert!(
hensels_result.is_some(),
"Hensel's lifting failed for modulus {}",
modulus
);
assert_eq!(
trial_result, euclidean_result,
"Methods disagree for modulus {}",
modulus
);
assert_eq!(
trial_result, hensels_result,
"Methods disagree for modulus {}",
modulus
);
}
}
// Exercises the square-and-multiply loop with all-ones exponents (15, 255) and a power of two (16).
#[test]
fn test_basic_exponentiation_loop_branches() {
let modulus = 17u32;
let base = 16u32; let exponent = 15u32;
let result = basic_montgomery_mod_exp(base, exponent, modulus).unwrap();
let expected = crate::exp::basic_mod_exp(base, exponent, modulus);
assert_eq!(result, expected);
let exp_pow2 = 16u32; let result_pow2 = basic_montgomery_mod_exp(3u32, exp_pow2, modulus).unwrap();
let expected_pow2 = crate::exp::basic_mod_exp(3u32, exp_pow2, modulus);
assert_eq!(result_pow2, expected_pow2);
let exp_odd = 255u32; let result_odd = basic_montgomery_mod_exp(2u32, exp_odd, modulus).unwrap();
let expected_odd = crate::exp::basic_mod_exp(2u32, exp_odd, modulus);
assert_eq!(result_odd, expected_odd);
}
// compute_n_prime_extended_euclidean must propagate basic_mod_inv's None for non-coprime pairs.
#[test]
fn test_basic_extended_euclidean_none_case() {
let even_modulus = 6u32; let r = 8u32;
assert!(
crate::inv::basic_mod_inv(even_modulus, r).is_none(),
"basic_mod_inv should return None for non-coprime inputs"
);
let result = compute_n_prime_extended_euclidean(even_modulus, r);
assert!(
result.is_none(),
"Should return None when basic_mod_inv fails"
);
let test_cases = [
(4u32, 8u32), (6u32, 12u32), (10u32, 16u32), (12u32, 8u32), ];
for (modulus, r) in test_cases.iter() {
assert!(
crate::inv::basic_mod_inv(*modulus, *r).is_none(),
"basic_mod_inv({}, {}) should return None",
modulus,
r
);
let result = compute_n_prime_extended_euclidean(*modulus, *r);
assert!(
result.is_none(),
"compute_n_prime_extended_euclidean({}, {}) should return None",
modulus,
r
);
}
}
// type_bit_width reports 8 bits per byte of the type's size.
#[test]
fn test_type_bit_width() {
assert_eq!(type_bit_width::<u8>(), 8);
assert_eq!(type_bit_width::<u16>(), 16);
assert_eq!(type_bit_width::<u32>(), 32);
assert_eq!(type_bit_width::<u64>(), 64);
}
// mod_double, including the carry-out case (200 + 200 overflows u8).
#[test]
fn test_mod_double() {
assert_eq!(mod_double(5u8, 13), 10);
assert_eq!(mod_double(8u8, 13), 3);
assert_eq!(mod_double(200u8, 201), 199);
assert_eq!(mod_double(0u8, 13), 0);
}
// Newton's N' must satisfy N·N' ≡ -1 (mod 2^w) for odd moduli.
#[test]
fn test_compute_n_prime_newton() {
let test_moduli: &[u8] = &[3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 255];
for &n in test_moduli {
let np = compute_n_prime_newton(n, 8);
let product = n.wrapping_mul(np);
assert_eq!(
product, 0xFF,
"n={n}: n*n_prime={product:#04x}, expected 0xFF"
);
}
let np32 = compute_n_prime_newton(13u32, 32);
assert_eq!(13u32.wrapping_mul(np32), u32::MAX);
}
// R mod N for R = 2^w, e.g. 2^8 mod 13 = 9 and 2^32 mod 13 = 9.
#[test]
fn test_compute_r_mod_n() {
assert_eq!(compute_r_mod_n(13u8, 8), 9);
assert_eq!(compute_r_mod_n(255u8, 8), 1);
assert_eq!(compute_r_mod_n(3u8, 8), 1);
assert_eq!(compute_r_mod_n(13u32, 32), 9);
}
// 35 ≡ 9 ≡ 2^8 (mod 13), so REDC(35) = 35·2^-8 ≡ 1 (mod 13).
#[test]
fn test_wide_redc() {
let n: u8 = 13;
let n_prime = compute_n_prime_newton(n, 8);
let result = wide_redc(35u8, 0u8, n, n_prime);
assert_eq!(result, 1);
}
// Converting via R² and reducing back must return every residue mod 13 unchanged.
#[test]
fn test_wide_montgomery_roundtrip() {
let modulus = 13u8;
let w = type_bit_width::<u8>();
let n_prime = compute_n_prime_newton(modulus, w);
let r_mod_n = compute_r_mod_n(modulus, w);
let r2_mod_n = compute_r2_mod_n(r_mod_n, modulus, w);
for a in 0u8..13 {
let (lo, hi) = a.wide_mul(&r2_mod_n);
let a_m = wide_redc(lo, hi, modulus, n_prime);
let back = wide_redc(a_m, 0u8, modulus, n_prime);
assert_eq!(back, a, "roundtrip failed for a={a}");
}
}
// Same round-trip for an odd u32 modulus close to 2^32.
#[test]
fn test_wide_montgomery_roundtrip_u32() {
let modulus = 0xFFFF_FFF1u32; let w = type_bit_width::<u32>();
let n_prime = compute_n_prime_newton(modulus, w);
let r_mod_n = compute_r_mod_n(modulus, w);
let r2_mod_n = compute_r2_mod_n(r_mod_n, modulus, w);
let test_vals = [0u32, 1, 2, 100, modulus - 1, modulus - 2, 0x7FFF_FFFF];
for &a in &test_vals {
let (lo, hi) = a.wide_mul(&r2_mod_n);
let a_m = wide_redc(lo, hi, modulus, n_prime);
let back = wide_redc(a_m, 0u32, modulus, n_prime);
assert_eq!(back, a, "roundtrip failed for a={a:#x}");
}
}
// Exhaustive check of every product of residues mod 13 in u8.
#[test]
fn test_wide_mod_mul_u8_exhaustive() {
let modulus = 13u8;
for a in 0u8..13 {
for b in 0u8..13 {
let expected = ((a as u16 * b as u16) % 13) as u8;
let got = basic_montgomery_mod_mul(a, b, modulus).unwrap();
assert_eq!(
got, expected,
"{a}*{b} mod 13: got {got}, expected {expected}"
);
}
}
}
// Full-width u32 products against the reference basic_mod_mul.
#[test]
fn test_wide_mod_mul_u32() {
let modulus = 0xFFFF_FFF1u32;
let vals = [0u32, 1, 2, 7, 1000, 0x7FFF_FFFF, modulus - 1, modulus - 2];
for &a in &vals {
for &b in &vals {
let expected = crate::mul::basic_mod_mul(a, b, modulus);
let got = basic_montgomery_mod_mul(a, b, modulus).unwrap();
assert_eq!(
got, expected,
"{a:#x}*{b:#x} mod {modulus:#x}: got {got:#x}, expected {expected:#x}"
);
}
}
}
// Small bases and exponents (including exponent 0) against the reference basic_mod_exp.
#[test]
fn test_wide_mod_exp_small() {
let modulus = 13u8;
for base in 0u8..13 {
for exp in 0u8..20 {
let expected = crate::exp::basic_mod_exp(base, exp, modulus);
let got = basic_montgomery_mod_exp(base, exp, modulus).unwrap();
assert_eq!(
got, expected,
"{base}^{exp} mod 13: got {got}, expected {expected}"
);
}
}
}
// Larger u32 exponentiations against the reference.
#[test]
fn test_wide_mod_exp_u32() {
let modulus = 0xFFFF_FFF1u32;
let expected = crate::exp::basic_mod_exp(2u32, 100, modulus);
let got = basic_montgomery_mod_exp(2, 100, modulus).unwrap();
assert_eq!(got, expected);
let expected = crate::exp::basic_mod_exp(3u32, 1000, modulus);
let got = basic_montgomery_mod_exp(3, 1000, modulus).unwrap();
assert_eq!(got, expected);
let expected = crate::exp::basic_mod_exp(modulus - 1, modulus - 2, modulus);
let got = basic_montgomery_mod_exp(modulus - 1, modulus - 2, modulus).unwrap();
assert_eq!(got, expected);
}
// 64-bit odd modulus with a base larger than the modulus, against the reference.
#[test]
fn test_wide_mod_exp_large_u64() {
let modulus = 0xFFFF_FFFF_FFFF_FFC5u64; let base = 0xDEAD_BEEF_CAFE_BABEu64;
let exp = 0x1234_5678u64;
let expected = crate::exp::basic_mod_exp(base, exp, modulus);
let got = basic_montgomery_mod_exp(base, exp, modulus).unwrap();
assert_eq!(got, expected);
}
// Even and zero moduli are rejected by mod_mul.
#[test]
fn test_wide_params_even_modulus() {
assert!(basic_montgomery_mod_mul(2u32, 3u32, 4u32).is_none());
assert!(basic_montgomery_mod_mul(2u32, 3u32, 0u32).is_none());
}
// An even modulus is rejected by mod_mul.
#[test]
fn test_wide_mod_mul_even_modulus() {
assert!(basic_montgomery_mod_mul(2u32, 3u32, 4u32).is_none());
}
// An even modulus is rejected by mod_exp.
#[test]
fn test_wide_mod_exp_even_modulus() {
assert!(basic_montgomery_mod_exp(2u32, 3u32, 4u32).is_none());
}
// Same full-width routines on a 128-bit FixedUInt with modulus 2^128 - 59.
#[test]
fn test_wide_fixed_bigint() {
use fixed_bigint::FixedUInt;
type U128 = FixedUInt<u32, 4>;
let modulus = !U128::from(0u64) - U128::from(58u64);
let w = type_bit_width::<U128>();
let n_prime = compute_n_prime_newton(modulus, w);
let r_mod_n = compute_r_mod_n(modulus, w);
let r2_mod_n = compute_r2_mod_n(r_mod_n, modulus, w);
let a = U128::from(0xDEAD_BEEF_u64);
let (lo, hi) = a.wide_mul(&r2_mod_n);
let a_m = wide_redc(lo, hi, modulus, n_prime);
let back = wide_redc(a_m, U128::from(0u64), modulus, n_prime);
assert_eq!(back, a);
let b = U128::from(0xCAFE_BABE_u64);
let expected = crate::mul::basic_mod_mul(a, b, modulus);
let got = basic_montgomery_mod_mul(a, b, modulus).unwrap();
assert_eq!(got, expected);
let base = U128::from(42u64);
let exp = U128::from(1000u64);
let expected = crate::exp::basic_mod_exp(base, exp, modulus);
let got = basic_montgomery_mod_exp(base, exp, modulus).unwrap();
assert_eq!(got, expected);
}
// Inputs >= modulus are reduced first: 27 ≡ 1, 39 ≡ 0, 100 ≡ 9, 200 ≡ 5 (mod 13).
#[test]
fn test_input_reduction_mod_mul() {
let modulus = 13u32;
let a = 27u32; let b = 39u32; let got = basic_montgomery_mod_mul(a, b, modulus).unwrap();
assert_eq!(got, 0, "27 * 39 mod 13 should be 0");
let a = 100u32; let b = 200u32; let expected = (9 * 5) % 13; let got = basic_montgomery_mod_mul(a, b, modulus).unwrap();
assert_eq!(got, expected);
let got = basic_montgomery_mod_mul(13u32, 5u32, modulus).unwrap();
assert_eq!(got, 0, "modulus * 5 mod modulus should be 0");
}
// 27 ≡ 1 and 100 ≡ 9 (mod 13), so the results match the reduced bases.
#[test]
fn test_input_reduction_mod_exp() {
let modulus = 13u32;
let base = 27u32; let exp = 100u32;
let expected = 1u32; let got = basic_montgomery_mod_exp(base, exp, modulus).unwrap();
assert_eq!(got, expected);
let base = 100u32; let exp = 3u32;
let expected = crate::exp::basic_mod_exp(9u32, 3, modulus); let got = basic_montgomery_mod_exp(base, exp, modulus).unwrap();
assert_eq!(got, expected);
}
// wide_from_montgomery inverts the R²-based conversion for every residue mod 13.
#[test]
fn test_wide_from_montgomery() {
let modulus = 13u8;
let w = type_bit_width::<u8>();
let n_prime = compute_n_prime_newton(modulus, w);
let r_mod_n = compute_r_mod_n(modulus, w);
let r2_mod_n = compute_r2_mod_n(r_mod_n, modulus, w);
for a in 0u8..13 {
let (lo, hi) = a.wide_mul(&r2_mod_n);
let a_m = wide_redc(lo, hi, modulus, n_prime);
let back = wide_from_montgomery(a_m, modulus, n_prime);
assert_eq!(back, a, "wide_from_montgomery roundtrip failed for {a}");
}
}
}