use super::basic_mont::NPrimeMethod;
use crate::inv::constrained_mod_inv;
/// Finds `N'` such that `modulus · N' ≡ r - 1 (mod r)` by exhaustive search.
///
/// Candidates are tried in increasing order starting at `1`; the first candidate whose
/// product with `modulus` leaves remainder `r - 1` is returned. The search stops with
/// `None` as soon as the next candidate would reach `r`, i.e. when no value in `1..r`
/// satisfies the congruence (for example when `modulus` and `r` share a factor).
///
/// This costs O(r) multiplications, so it is only practical for small `r`. The product
/// `modulus · candidate` is formed with plain `*` in `T`, so for fixed-width integers it
/// must fit in `T`.
fn compute_n_prime_trial_search_constrained<T>(modulus: &T, r: &T) -> Option<T>
where
    T: Clone
        + num_traits::Zero
        + num_traits::One
        + PartialEq
        + PartialOrd
        + num_traits::ops::wrapping::WrappingAdd
        + num_traits::ops::wrapping::WrappingSub
        + for<'a> core::ops::Rem<&'a T, Output = T>,
    for<'a> T: core::ops::RemAssign<&'a T> + core::ops::Mul<&'a T, Output = T>,
    for<'a> &'a T: core::ops::Rem<&'a T, Output = T>,
{
    let one = T::one();
    // -1 mod r, the residue a valid N' must produce.
    let minus_one = r.clone().wrapping_sub(&one);
    let mut candidate = one.clone();
    while (modulus.clone() * &candidate) % r != minus_one {
        candidate = candidate.wrapping_add(&one);
        if &candidate >= r {
            return None;
        }
    }
    Some(candidate)
}
/// Computes `N' = -modulus⁻¹ mod r` from the modular inverse returned by
/// `constrained_mod_inv(modulus, r)`.
///
/// A non-zero inverse `inv` yields `r - inv`. An inverse of zero yields `r - 1` instead;
/// this is the `r = 1` situation, where it gives `0` rather than the out-of-range `r`.
/// Returns `None` when `constrained_mod_inv` finds no inverse.
fn compute_n_prime_extended_euclidean_constrained<T>(modulus: &T, r: &T) -> Option<T>
where
    T: Clone
        + num_traits::Zero
        + num_traits::One
        + PartialEq
        + PartialOrd
        + num_traits::ops::wrapping::WrappingAdd
        + num_traits::ops::wrapping::WrappingSub,
    for<'a> T: core::ops::Add<&'a T, Output = T> + core::ops::Sub<&'a T, Output = T>,
    for<'a> &'a T: core::ops::Sub<T, Output = T> + core::ops::Div<&'a T, Output = T>,
{
    let inverse = constrained_mod_inv(modulus.clone(), r)?;
    let subtrahend = if inverse == T::zero() {
        T::one()
    } else {
        inverse
    };
    Some(r.clone().wrapping_sub(&subtrahend))
}
/// Computes `N'` with `modulus · N' ≡ -1 (mod r)` by lifting one bit at a time.
///
/// Starting from `N' = 1` (correct mod 2 for odd `modulus`), each step `k` in
/// `2..=r_bits` inspects `(modulus · N' + 1) mod 2^k`. If that residue equals
/// `2^(k-1)`, the bit `2^(k-1)` is added to `N'` so the congruence holds mod `2^k`.
///
/// The lifted value is then verified against `r` directly. `None` is returned if the
/// verification fails, which happens for example when `r` is not `2^r_bits` or when
/// `modulus` is even. Products use plain `*` in `T` and must not overflow.
fn compute_n_prime_hensels_lifting_constrained<T>(modulus: &T, r: &T, r_bits: usize) -> Option<T>
where
    T: Clone
        + num_traits::Zero
        + num_traits::One
        + PartialEq
        + PartialOrd
        + num_traits::ops::wrapping::WrappingAdd
        + num_traits::ops::wrapping::WrappingSub
        + core::ops::Shl<usize, Output = T>
        + for<'a> core::ops::Rem<&'a T, Output = T>,
    for<'a> T: core::ops::RemAssign<&'a T> + core::ops::Mul<&'a T, Output = T>,
    for<'a> &'a T: core::ops::Rem<&'a T, Output = T>,
{
    let one = T::one();
    let mut n_prime = one.clone();
    for k in 2..=r_bits {
        let window = one.clone() << k;
        let half_window = one.clone() << (k - 1);
        let lifted = (modulus.clone() * &n_prime).wrapping_add(&one);
        let residue = &lifted % &window;
        // A non-zero residue of exactly 2^(k-1) means bit k-1 of N' must be set.
        if residue != T::zero() && residue == half_window {
            n_prime = n_prime.wrapping_add(&half_window);
        }
    }

    let expected = r.clone().wrapping_sub(&one);
    if (modulus.clone() * &n_prime) % r == expected {
        Some(n_prime)
    } else {
        None
    }
}
/// Computes the Montgomery parameters for `modulus`, using `method` to find `N'`.
///
/// Returns `(r, r_inv, n_prime, r_bits)` where:
/// - `r = 2^r_bits` is the smallest power of two strictly greater than `modulus`;
/// - `r_inv` is the result of `constrained_mod_inv(r, modulus)`, i.e. `R⁻¹ mod modulus`;
/// - `n_prime` satisfies `modulus · n_prime ≡ -1 (mod r)`;
/// - `r_bits` is the exponent of `r`.
///
/// `NPrimeMethod::Newton` currently shares the extended-Euclidean path.
///
/// Returns `None` when:
/// - `R` cannot be represented in `T`. For fixed-width integers this happens when
///   `modulus` is at least the highest representable power of two, because shifting
///   that bit out yields zero;
/// - `R` has no inverse modulo `modulus` (e.g. an even modulus);
/// - the selected `N'` method fails.
///
/// NOTE(review): the `N'` helpers multiply `modulus · n_prime` with plain `*` in `T`. For
/// fixed-width integers with large moduli (around `R² > T::MAX`) this can overflow.
pub fn constrained_compute_montgomery_params_with_method<T>(
    modulus: &T,
    method: NPrimeMethod,
) -> Option<(T, T, T, usize)>
where
    T: Clone
        + num_traits::Zero
        + num_traits::One
        + PartialEq
        + PartialOrd
        + num_traits::ops::wrapping::WrappingAdd
        + num_traits::ops::wrapping::WrappingSub
        + core::ops::Shl<usize, Output = T>
        + core::ops::Sub<Output = T>
        + for<'a> core::ops::Rem<&'a T, Output = T>,
    for<'a> T: core::ops::Add<&'a T, Output = T>
        + core::ops::Sub<&'a T, Output = T>
        + core::ops::Mul<&'a T, Output = T>
        + core::ops::RemAssign<&'a T>,
    for<'a> &'a T: core::ops::Sub<T, Output = T>
        + core::ops::Div<&'a T, Output = T>
        + core::ops::Rem<&'a T, Output = T>,
{
    // Find R = 2^r_bits > modulus.
    let mut r = T::one();
    let mut r_bits = 0usize;
    while &r <= modulus {
        r = r << 1;
        r_bits += 1;
        // In a fixed-width T the doubling eventually shifts the only set bit out, leaving
        // zero. Zero stays <= modulus forever, so bail out instead of looping endlessly.
        if r.is_zero() {
            return None;
        }
    }

    let r_inv = constrained_mod_inv(r.clone(), modulus)?;
    let n_prime = match method {
        NPrimeMethod::TrialSearch => compute_n_prime_trial_search_constrained(modulus, &r)?,
        NPrimeMethod::ExtendedEuclidean | NPrimeMethod::Newton => {
            compute_n_prime_extended_euclidean_constrained(modulus, &r)?
        }
        NPrimeMethod::HenselsLifting => {
            compute_n_prime_hensels_lifting_constrained(modulus, &r, r_bits)?
        }
    };
    Some((r, r_inv, n_prime, r_bits))
}
/// Computes `(r, r_inv, n_prime, r_bits)` for `modulus` with the default `N'` method.
///
/// Equivalent to `constrained_compute_montgomery_params_with_method(modulus,
/// NPrimeMethod::default())`; returns `None` under the same conditions.
pub fn constrained_compute_montgomery_params<T>(modulus: &T) -> Option<(T, T, T, usize)>
where
    T: Clone
        + num_traits::Zero
        + num_traits::One
        + PartialEq
        + PartialOrd
        + num_traits::ops::wrapping::WrappingAdd
        + num_traits::ops::wrapping::WrappingSub
        + core::ops::Shl<usize, Output = T>
        + core::ops::Sub<Output = T>
        + for<'a> core::ops::Rem<&'a T, Output = T>,
    for<'a> T: core::ops::Add<&'a T, Output = T>
        + core::ops::Sub<&'a T, Output = T>
        + core::ops::Mul<&'a T, Output = T>
        + core::ops::RemAssign<&'a T>,
    for<'a> &'a T: core::ops::Sub<T, Output = T>
        + core::ops::Div<&'a T, Output = T>
        + core::ops::Rem<&'a T, Output = T>,
{
    let default_method = NPrimeMethod::default();
    constrained_compute_montgomery_params_with_method(modulus, default_method)
}
pub fn constrained_to_montgomery<T>(a: T, modulus: &T, r: &T) -> T
where
T: num_traits::Zero
+ num_traits::One
+ PartialOrd
+ num_traits::ops::wrapping::WrappingAdd
+ num_traits::ops::wrapping::WrappingSub
+ core::ops::Shr<usize, Output = T>
+ crate::parity::Parity,
for<'a> T: core::ops::RemAssign<&'a T>,
for<'a> &'a T: core::ops::Rem<&'a T, Output = T>,
{
crate::mul::constrained_mod_mul(a, r, modulus)
}
/// Converts a value out of the Montgomery domain using Montgomery reduction (REDC).
///
/// With `R = 2^r_bits`, this computes:
/// - `m = ((a_mont mod R) · n_prime) mod R`;
/// - `t = (a_mont + m · modulus) >> r_bits`.
///
/// `modulus` is then subtracted once if `t >= modulus`. When
/// `modulus · n_prime ≡ -1 (mod R)`, the sum `a_mont + m · modulus` is divisible by `R`,
/// and the result is congruent to `a_mont · R⁻¹ (mod modulus)`. It is fully reduced
/// provided `a_mont < R · modulus`.
///
/// When `r_bits == 0` the reduction is skipped and only the conditional subtraction is
/// applied.
///
/// # Overflow
/// All intermediates live in `T` itself:
/// - `(a_mont mod R) · n_prime` and `m · modulus` use plain `*`, bounded by `R²` and `R · modulus`;
/// - the addition with `a_mont` wraps silently.
///
/// For fixed-width integers the caller must ensure these values fit.
pub fn constrained_from_montgomery<T>(a_mont: T, modulus: &T, n_prime: &T, r_bits: usize) -> T
where
    T: Clone
        + num_traits::Zero
        + num_traits::One
        + PartialOrd
        + core::ops::Shl<usize, Output = T>
        + core::ops::Shr<usize, Output = T>
        + num_traits::ops::wrapping::WrappingAdd
        + num_traits::ops::wrapping::WrappingSub,
    for<'a> T: core::ops::Mul<&'a T, Output = T>,
    for<'a> &'a T: core::ops::BitAnd<&'a T, Output = T>,
{
    // Shared final step: bring a value in [0, 2·modulus) down below `modulus`.
    let subtract_once = |value: T| -> T {
        if &value >= modulus {
            value.wrapping_sub(modulus)
        } else {
            value
        }
    };

    if r_bits == 0 {
        return subtract_once(a_mont);
    }

    // low_mask = R - 1, so `x & low_mask` is `x mod R`.
    let low_mask = (T::one() << r_bits).wrapping_sub(&T::one());
    let low_part = &a_mont & &low_mask;
    let m_full = low_part * n_prime;
    let m = &m_full & &low_mask;
    let correction = m * modulus;
    // a_mont + m·modulus is a multiple of R, so the shift is an exact division.
    let reduced = a_mont.wrapping_add(&correction) >> r_bits;
    subtract_once(reduced)
}
/// Multiplies two Montgomery-form values and returns the product in Montgomery form.
///
/// The operands are first multiplied modulo `modulus` via `crate::mul::constrained_mod_mul`.
/// That reduced product is then passed through [`constrained_from_montgomery`], which
/// multiplies by `R⁻¹`. For `a_mont ≡ a·R` and `b_mont ≡ b·R (mod modulus)`, the result
/// is congruent to `a·b·R`.
///
/// Reducing before REDC keeps the REDC input small, assuming `constrained_mod_mul`
/// returns a value below `modulus` (TODO confirm in `crate::mul`).
pub fn constrained_montgomery_mul<T>(
    a_mont: &T,
    b_mont: &T,
    modulus: &T,
    n_prime: &T,
    r_bits: usize,
) -> T
where
    T: Clone
        + num_traits::Zero
        + num_traits::One
        + PartialOrd
        + core::ops::Shl<usize, Output = T>
        + core::ops::Shr<usize, Output = T>
        + num_traits::ops::wrapping::WrappingAdd
        + num_traits::ops::wrapping::WrappingSub
        + crate::parity::Parity
        + for<'a> core::ops::Rem<&'a T, Output = T>,
    for<'a> T: core::ops::RemAssign<&'a T> + core::ops::Mul<&'a T, Output = T>,
    for<'a> &'a T: core::ops::Rem<&'a T, Output = T> + core::ops::BitAnd<Output = T>,
{
    constrained_from_montgomery(
        crate::mul::constrained_mod_mul(a_mont.clone(), b_mont, modulus),
        modulus,
        n_prime,
        r_bits,
    )
}
/// Computes `a · b mod modulus` by a full round trip through the Montgomery domain.
///
/// The steps are:
/// 1. Derive the parameters with `method`.
/// 2. Encode both operands.
/// 3. Multiply them in Montgomery form.
/// 4. Decode the product.
///
/// Returns `None` when the Montgomery parameters cannot be computed for `modulus`
/// (e.g. an even modulus).
pub fn constrained_montgomery_mod_mul_with_method<T>(
    a: T,
    b: &T,
    modulus: &T,
    method: NPrimeMethod,
) -> Option<T>
where
    T: Clone
        + num_traits::Zero
        + num_traits::One
        + crate::parity::Parity
        + PartialEq
        + PartialOrd
        + num_traits::ops::wrapping::WrappingAdd
        + num_traits::ops::wrapping::WrappingSub
        + core::ops::Shl<usize, Output = T>
        + core::ops::Shr<usize, Output = T>
        + core::ops::Sub<Output = T>
        + for<'a> core::ops::Rem<&'a T, Output = T>,
    for<'a> T: core::ops::Add<&'a T, Output = T>
        + core::ops::Sub<&'a T, Output = T>
        + core::ops::Mul<&'a T, Output = T>
        + core::ops::RemAssign<&'a T>,
    for<'a> &'a T: core::ops::Sub<T, Output = T>
        + core::ops::Div<&'a T, Output = T>
        + core::ops::Rem<&'a T, Output = T>
        + core::ops::BitAnd<Output = T>,
{
    let (r, _r_inv, n_prime, r_bits) =
        constrained_compute_montgomery_params_with_method(modulus, method)?;
    let encode = |value: T| constrained_to_montgomery(value, modulus, &r);

    let lhs = encode(a);
    let rhs = encode(b.clone());
    let product_mont = constrained_montgomery_mul(&lhs, &rhs, modulus, &n_prime, r_bits);
    let product = constrained_from_montgomery(product_mont, modulus, &n_prime, r_bits);
    Some(product)
}
/// Computes `a · b mod modulus` via Montgomery multiplication with the default `N'` method.
///
/// Equivalent to `constrained_montgomery_mod_mul_with_method(a, b, modulus,
/// NPrimeMethod::default())`; returns `None` if the Montgomery parameters cannot be derived.
pub fn constrained_montgomery_mod_mul<T>(a: T, b: &T, modulus: &T) -> Option<T>
where
    T: Clone
        + num_traits::Zero
        + num_traits::One
        + crate::parity::Parity
        + PartialEq
        + PartialOrd
        + num_traits::ops::wrapping::WrappingAdd
        + num_traits::ops::wrapping::WrappingSub
        + core::ops::Shl<usize, Output = T>
        + core::ops::Shr<usize, Output = T>
        + core::ops::Sub<Output = T>
        + for<'a> core::ops::Rem<&'a T, Output = T>,
    for<'a> T: core::ops::Add<&'a T, Output = T>
        + core::ops::Sub<&'a T, Output = T>
        + core::ops::Mul<&'a T, Output = T>
        + core::ops::RemAssign<&'a T>,
    for<'a> &'a T: core::ops::Sub<T, Output = T>
        + core::ops::Div<&'a T, Output = T>
        + core::ops::Rem<&'a T, Output = T>
        + core::ops::BitAnd<Output = T>,
{
    let default_method = NPrimeMethod::default();
    constrained_montgomery_mod_mul_with_method(a, b, modulus, default_method)
}
/// Computes `base^exponent mod modulus` with right-to-left binary exponentiation
/// in the Montgomery domain.
///
/// The steps are:
/// 1. Reduce `base` modulo `modulus` and encode it.
/// 2. Seed the accumulator with the Montgomery form of `1`.
/// 3. For each exponent bit, least significant first:
///    - multiply the accumulator by the current power when the bit is set;
///    - square the power, skipping the square after the last bit.
/// 4. Decode the accumulator.
///
/// The loop runs while the remaining exponent is `> 0`, so an exponent of zero yields
/// `1 mod modulus`. Returns `None` when the Montgomery parameters cannot be computed
/// for `modulus` (e.g. an even modulus).
pub fn constrained_montgomery_mod_exp_with_method<T>(
    mut base: T,
    exponent: &T,
    modulus: &T,
    method: NPrimeMethod,
) -> Option<T>
where
    T: Clone
        + num_traits::Zero
        + num_traits::One
        + crate::parity::Parity
        + PartialEq
        + PartialOrd
        + num_traits::ops::wrapping::WrappingAdd
        + num_traits::ops::wrapping::WrappingSub
        + core::ops::Shl<usize, Output = T>
        + core::ops::Shr<usize, Output = T>
        + core::ops::ShrAssign<usize>
        + core::ops::Sub<Output = T>
        + for<'a> core::ops::Rem<&'a T, Output = T>,
    for<'a> T: core::ops::RemAssign<&'a T>
        + core::ops::Add<&'a T, Output = T>
        + core::ops::Sub<&'a T, Output = T>
        + core::ops::Mul<&'a T, Output = T>,
    for<'a> &'a T: core::ops::Sub<T, Output = T>
        + core::ops::Div<&'a T, Output = T>
        + core::ops::Rem<&'a T, Output = T>
        + core::ops::BitAnd<Output = T>,
{
    let (r, _r_inv, n_prime, r_bits) =
        constrained_compute_montgomery_params_with_method(modulus, method)?;
    let mont_mul = |x: &T, y: &T| constrained_montgomery_mul(x, y, modulus, &n_prime, r_bits);

    base %= modulus;
    let mut power = constrained_to_montgomery(base, modulus, &r);
    let mut acc = constrained_to_montgomery(T::one(), modulus, &r);

    let mut remaining = exponent.clone();
    while remaining > T::zero() {
        if remaining.is_odd() {
            acc = mont_mul(&acc, &power);
        }
        remaining >>= 1;
        // Only square when another bit will consume the result.
        if remaining > T::zero() {
            power = mont_mul(&power, &power);
        }
    }

    Some(constrained_from_montgomery(acc, modulus, &n_prime, r_bits))
}
/// Computes `base^exponent mod modulus` via Montgomery arithmetic with the default `N'` method.
///
/// Equivalent to `constrained_montgomery_mod_exp_with_method(base, exponent, modulus,
/// NPrimeMethod::default())`; returns `None` if the Montgomery parameters cannot be derived.
pub fn constrained_montgomery_mod_exp<T>(base: T, exponent: &T, modulus: &T) -> Option<T>
where
    T: Clone
        + num_traits::Zero
        + num_traits::One
        + crate::parity::Parity
        + PartialEq
        + PartialOrd
        + num_traits::ops::wrapping::WrappingAdd
        + num_traits::ops::wrapping::WrappingSub
        + core::ops::Shl<usize, Output = T>
        + core::ops::Shr<usize, Output = T>
        + core::ops::ShrAssign<usize>
        + core::ops::Sub<Output = T>
        + for<'a> core::ops::Rem<&'a T, Output = T>,
    for<'a> T: core::ops::RemAssign<&'a T>
        + core::ops::Add<&'a T, Output = T>
        + core::ops::Sub<&'a T, Output = T>
        + core::ops::Mul<&'a T, Output = T>,
    for<'a> &'a T: core::ops::Sub<T, Output = T>
        + core::ops::Div<&'a T, Output = T>
        + core::ops::Rem<&'a T, Output = T>
        + core::ops::BitAnd<Output = T>,
{
    let default_method = NPrimeMethod::default();
    constrained_montgomery_mod_exp_with_method(base, exponent, modulus, default_method)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `4 · n' mod 8` is always 0 or 4, never `8 - 1`, so the search must give up.
    #[test]
    fn test_constrained_compute_n_prime_trial_search_failure() {
        let modulus = 4u32;
        let r = 8u32;
        let result = compute_n_prime_trial_search_constrained(&modulus, &r);
        assert!(
            result.is_none(),
            "Should return None for invalid modulus-R pair"
        );
    }

    #[test]
    fn test_constrained_compute_montgomery_params_failure() {
        let even_modulus = 4u32;
        let result = constrained_compute_montgomery_params(&even_modulus);
        assert!(result.is_none(), "Should return None for even modulus");
    }

    /// Every `N'` method must reject an even modulus.
    #[test]
    fn test_constrained_compute_montgomery_params_failure_with_method() {
        let invalid_modulus = 4u32;
        let trial_result = constrained_compute_montgomery_params_with_method(
            &invalid_modulus,
            NPrimeMethod::TrialSearch,
        );
        assert!(
            trial_result.is_none(),
            "Trial search should fail with even modulus"
        );
        let euclidean_result = constrained_compute_montgomery_params_with_method(
            &invalid_modulus,
            NPrimeMethod::ExtendedEuclidean,
        );
        assert!(
            euclidean_result.is_none(),
            "Extended Euclidean should fail with even modulus"
        );
        let hensels_result = constrained_compute_montgomery_params_with_method(
            &invalid_modulus,
            NPrimeMethod::HenselsLifting,
        );
        assert!(
            hensels_result.is_none(),
            "Hensel's lifting should fail with even modulus"
        );
    }

    #[test]
    fn test_constrained_montgomery_mod_mul_parameter_failure() {
        let invalid_modulus = 4u32;
        let a = 2u32;
        let b = 3u32;
        let result = constrained_montgomery_mod_mul(a, &b, &invalid_modulus);
        assert!(
            result.is_none(),
            "Montgomery mod_mul should return None for invalid modulus"
        );
    }

    #[test]
    fn test_constrained_montgomery_mod_exp_parameter_failure() {
        let invalid_modulus = 4u32;
        let base = 2u32;
        let exponent = 3u32;
        let result = constrained_montgomery_mod_exp(base, &exponent, &invalid_modulus);
        assert!(
            result.is_none(),
            "Montgomery mod_exp should return None for invalid modulus"
        );
    }

    /// Round-trips values just below the modulus, exercising REDC's final subtraction.
    #[test]
    fn test_constrained_montgomery_reduction_final_subtraction() {
        let modulus = 15u32;
        let (r, _r_inv, n_prime, r_bits) = constrained_compute_montgomery_params(&modulus).unwrap();
        let high_value = 14u32;
        let mont_high = constrained_to_montgomery(high_value, &modulus, &r);
        let result = constrained_from_montgomery(mont_high, &modulus, &n_prime, r_bits);
        assert_eq!(result, high_value);
        let mont_13 = constrained_to_montgomery(13u32, &modulus, &r);
        let result_13 = constrained_from_montgomery(mont_13, &modulus, &n_prime, r_bits);
        assert_eq!(result_13, 13u32);
    }

    /// Hensel's lifting must produce `N'` with `N · N' ≡ -1 (mod R)` for odd moduli.
    #[test]
    fn test_constrained_hensel_lifting_branches() {
        let test_moduli = [9u32, 15u32, 21u32, 35u32, 45u32];
        for &modulus in &test_moduli {
            let hensels_result = constrained_compute_montgomery_params_with_method(
                &modulus,
                NPrimeMethod::HenselsLifting,
            );
            assert!(
                hensels_result.is_some(),
                "Hensel's lifting should work for modulus {}",
                modulus
            );
            if let Some((r, _r_inv, n_prime, _r_bits)) = hensels_result {
                let check = (modulus * n_prime) % r;
                let expected = r - 1;
                assert_eq!(
                    check, expected,
                    "N' verification failed for modulus {} with Hensel's lifting",
                    modulus
                );
            }
        }
    }

    /// Operands close to the modulus stress the Montgomery multiply path.
    #[test]
    fn test_constrained_multiplication_stress() {
        let modulus = 33u32;
        let (r, _r_inv, n_prime, r_bits) = constrained_compute_montgomery_params(&modulus).unwrap();
        let test_pairs = [(31u32, 32u32), (29u32, 30u32), (25u32, 27u32)];
        for (a, b) in test_pairs.iter() {
            let a_mont = constrained_to_montgomery(*a, &modulus, &r);
            let b_mont = constrained_to_montgomery(*b, &modulus, &r);
            let result_mont =
                constrained_montgomery_mul(&a_mont, &b_mont, &modulus, &n_prime, r_bits);
            let result = constrained_from_montgomery(result_mont, &modulus, &n_prime, r_bits);
            let expected = (a * b) % modulus;
            assert_eq!(result, expected, "Failed for {} * {} mod {}", a, b, modulus);
        }
    }

    /// Exponents with mixed bit patterns exercise both branches of the square-and-multiply loop.
    #[test]
    fn test_constrained_exponentiation_conditional_branches() {
        let modulus = 19u32;
        let test_cases = [(18u32, 2u32), (17u32, 7u32), (15u32, 31u32), (2u32, 127u32)];
        for (base, exponent) in test_cases.iter() {
            let result = constrained_montgomery_mod_exp(*base, exponent, &modulus).unwrap();
            let expected = crate::exp::constrained_mod_exp(*base, exponent, &modulus);
            assert_eq!(
                result, expected,
                "Failed for {}^{} mod {}",
                base, exponent, modulus
            );
        }
    }

    /// The extended-Euclidean and trial-search methods must agree on every parameter.
    #[test]
    fn test_constrained_extended_euclidean_edge_cases() {
        let edge_moduli = [7u32, 9u32, 25u32, 49u32, 121u32];
        for &modulus in &edge_moduli {
            let euclidean_result = constrained_compute_montgomery_params_with_method(
                &modulus,
                NPrimeMethod::ExtendedEuclidean,
            );
            assert!(
                euclidean_result.is_some(),
                "Extended Euclidean should work for modulus {}",
                modulus
            );
            let trial_result = constrained_compute_montgomery_params_with_method(
                &modulus,
                NPrimeMethod::TrialSearch,
            );
            assert_eq!(
                euclidean_result, trial_result,
                "Extended Euclidean vs Trial Search mismatch for modulus {}",
                modulus
            );
        }
    }
}