// modmath/montgomery/constrained_mont.rs

// Constrained Montgomery arithmetic functions.
// These work with references to avoid unnecessary copies, following the pattern from exp.rs.
3
4use super::basic_mont::NPrimeMethod;
5use crate::inv::constrained_mod_inv;
6
7/// Compute N' using trial search method - O(R) complexity (constrained version)
8/// Finds N' such that modulus * N' ≡ -1 (mod R)
9/// Returns None if N' cannot be found (should not happen for valid Montgomery parameters)
10fn compute_n_prime_trial_search_constrained<T>(modulus: &T, r: &T) -> Option<T>
11where
12    T: Clone
13        + num_traits::Zero
14        + num_traits::One
15        + PartialEq
16        + PartialOrd
17        + num_traits::ops::wrapping::WrappingAdd
18        + num_traits::ops::wrapping::WrappingSub
19        + for<'a> core::ops::Rem<&'a T, Output = T>,
20    for<'a> T: core::ops::RemAssign<&'a T> + core::ops::Mul<&'a T, Output = T>,
21    for<'a> &'a T: core::ops::Rem<&'a T, Output = T>,
22{
23    // We need to find N' where modulus * N' ≡ R - 1 (mod R)
24    let target = r.clone().wrapping_sub(&T::one()); // This is -1 mod R
25
26    // Simple trial search for N'
27    let mut n_prime = T::one();
28    loop {
29        if (modulus.clone() * &n_prime) % r == target {
30            return Some(n_prime);
31        }
32        n_prime = n_prime.wrapping_add(&T::one());
33
34        // Safety check to avoid infinite loop
35        if &n_prime >= r {
36            return None; // Could not find N' - should not happen for valid inputs
37        }
38    }
39}
40
41/// Compute N' using Extended Euclidean Algorithm - O(log R) complexity (constrained version)
42/// Finds N' such that modulus * N' ≡ -1 (mod R)
43/// Returns None if modular inverse cannot be found
44fn compute_n_prime_extended_euclidean_constrained<T>(modulus: &T, r: &T) -> Option<T>
45where
46    T: Clone
47        + num_traits::Zero
48        + num_traits::One
49        + PartialEq
50        + PartialOrd
51        + num_traits::ops::wrapping::WrappingAdd
52        + num_traits::ops::wrapping::WrappingSub,
53    for<'a> T: core::ops::Add<&'a T, Output = T> + core::ops::Sub<&'a T, Output = T>,
54    for<'a> &'a T: core::ops::Sub<T, Output = T> + core::ops::Div<&'a T, Output = T>,
55{
56    // We need to solve: modulus * N' ≡ -1 (mod R)
57    // This is equivalent to: N' ≡ -modulus^(-1) (mod R)
58
59    // Use constrained_mod_inv to find modulus^(-1) mod R
60    if let Some(modulus_inv) = constrained_mod_inv(modulus.clone(), r) {
61        // N' = -modulus^(-1) mod R = R - modulus^(-1) mod R
62        if modulus_inv == T::zero() {
63            Some(r.clone().wrapping_sub(&T::one())) // Handle edge case where inverse is 0
64        } else {
65            Some(r.clone().wrapping_sub(&modulus_inv))
66        }
67    } else {
68        None // Could not find modular inverse - gcd(modulus, R) should be 1 for valid Montgomery parameters
69    }
70}
71
72/// Compute N' using Hensel's lifting - O(log R) complexity, optimized for R = 2^k (constrained version)
73/// Finds N' such that modulus * N' ≡ -1 (mod R)
74/// Returns None if Hensel's lifting fails to produce correct N'
75fn compute_n_prime_hensels_lifting_constrained<T>(modulus: &T, r: &T, r_bits: usize) -> Option<T>
76where
77    T: Clone
78        + num_traits::Zero
79        + num_traits::One
80        + PartialEq
81        + PartialOrd
82        + num_traits::ops::wrapping::WrappingAdd
83        + num_traits::ops::wrapping::WrappingSub
84        + core::ops::Shl<usize, Output = T>
85        + for<'a> core::ops::Rem<&'a T, Output = T>,
86    for<'a> T: core::ops::RemAssign<&'a T> + core::ops::Mul<&'a T, Output = T>,
87    for<'a> &'a T: core::ops::Rem<&'a T, Output = T> + core::ops::BitAnd<&'a T, Output = T>,
88{
89    // Hensel's lifting requires modulus to be odd (prerequisite for Montgomery arithmetic)
90    debug_assert!(
91        modulus & &T::one() == T::one(),
92        "Hensel's lifting requires an odd modulus for Montgomery arithmetic"
93    );
94
95    // Hensel's lifting for N' computation when R = 2^k
96    let mut n_prime = T::one();
97
98    // Lift from 2^1 to 2^r_bits using Newton's method
99    for k in 2..=r_bits {
100        let target_mod = T::one() << k; // 2^k
101        let temp_prod = modulus.clone() * &n_prime;
102        let temp_sum = temp_prod.wrapping_add(&T::one());
103        let check_val = &temp_sum % &target_mod;
104
105        if check_val != T::zero() {
106            let prev_power = T::one() << (k - 1); // 2^(k-1)
107            if check_val == prev_power {
108                n_prime = n_prime.wrapping_add(&prev_power);
109            }
110        }
111    }
112
113    // Final check
114    let final_check = (modulus.clone() * &n_prime) % r;
115    let target = r.clone().wrapping_sub(&T::one()); // -1 mod R
116
117    if final_check != target {
118        None // Hensel lifting failed to produce correct N'
119    } else {
120        Some(n_prime)
121    }
122}
123
124/// Montgomery parameter computation (Constrained)
125/// Computes R, R^(-1) mod N, N', and R bit length for Montgomery arithmetic
126/// Returns None if N' computation fails or R^(-1) mod N cannot be found
127pub fn constrained_compute_montgomery_params_with_method<T>(
128    modulus: &T,
129    method: NPrimeMethod,
130) -> Option<(T, T, T, usize)>
131where
132    T: Clone
133        + num_traits::Zero
134        + num_traits::One
135        + PartialEq
136        + PartialOrd
137        + num_traits::ops::wrapping::WrappingAdd
138        + num_traits::ops::wrapping::WrappingSub
139        + core::ops::Shl<usize, Output = T>
140        + core::ops::Sub<Output = T>
141        + for<'a> core::ops::Rem<&'a T, Output = T>,
142    for<'a> T: core::ops::Add<&'a T, Output = T>
143        + core::ops::Sub<&'a T, Output = T>
144        + core::ops::Mul<&'a T, Output = T>
145        + core::ops::RemAssign<&'a T>,
146    for<'a> &'a T: core::ops::Sub<T, Output = T>
147        + core::ops::Div<&'a T, Output = T>
148        + core::ops::Rem<&'a T, Output = T>
149        + core::ops::BitAnd<&'a T, Output = T>,
150{
151    // Step 1: Find R = 2^k where R > modulus
152    let mut r = T::one();
153    let mut r_bits = 0usize;
154
155    while &r <= modulus {
156        r = r << 1; // r *= 2
157        r_bits += 1;
158    }
159
160    // Step 2: Compute R^(-1) mod modulus
161    let r_inv = constrained_mod_inv(r.clone(), modulus)?;
162
163    // Step 3: Compute N' such that N * N' ≡ -1 (mod R) using selected method
164    let n_prime = match method {
165        NPrimeMethod::TrialSearch => compute_n_prime_trial_search_constrained(modulus, &r)?,
166        NPrimeMethod::ExtendedEuclidean => {
167            compute_n_prime_extended_euclidean_constrained(modulus, &r)?
168        }
169        NPrimeMethod::HenselsLifting => {
170            compute_n_prime_hensels_lifting_constrained(modulus, &r, r_bits)?
171        }
172    };
173
174    Some((r, r_inv, n_prime, r_bits))
175}
176
177/// Montgomery parameter computation (Constrained) with default method
178/// Computes R, R^(-1) mod N, N', and R bit length for Montgomery arithmetic
179/// Returns None if parameter computation fails
180pub fn constrained_compute_montgomery_params<T>(modulus: &T) -> Option<(T, T, T, usize)>
181where
182    T: Clone
183        + num_traits::Zero
184        + num_traits::One
185        + PartialEq
186        + PartialOrd
187        + num_traits::ops::wrapping::WrappingAdd
188        + num_traits::ops::wrapping::WrappingSub
189        + core::ops::Shl<usize, Output = T>
190        + core::ops::Sub<Output = T>
191        + for<'a> core::ops::Rem<&'a T, Output = T>,
192    for<'a> T: core::ops::Add<&'a T, Output = T>
193        + core::ops::Sub<&'a T, Output = T>
194        + core::ops::Mul<&'a T, Output = T>
195        + core::ops::RemAssign<&'a T>,
196    for<'a> &'a T: core::ops::Sub<T, Output = T>
197        + core::ops::Div<&'a T, Output = T>
198        + core::ops::Rem<&'a T, Output = T>
199        + core::ops::BitAnd<&'a T, Output = T>,
200{
201    constrained_compute_montgomery_params_with_method(modulus, NPrimeMethod::default())
202}
203
204/// Convert to Montgomery form (Constrained): a -> (a * R) mod N
205pub fn constrained_to_montgomery<T>(a: T, modulus: &T, r: &T) -> T
206where
207    T: num_traits::Zero
208        + num_traits::One
209        + PartialOrd
210        + num_traits::ops::wrapping::WrappingAdd
211        + num_traits::ops::wrapping::WrappingSub
212        + core::ops::Shr<usize, Output = T>,
213    for<'a> T: core::ops::RemAssign<&'a T>,
214    for<'a> &'a T: core::ops::Rem<&'a T, Output = T> + core::ops::BitAnd<Output = T>,
215{
216    crate::mul::constrained_mod_mul(a, r, modulus)
217}
218
219/// Convert from Montgomery form (Constrained): (a * R) -> a mod N
220/// Uses Montgomery reduction algorithm
221pub fn constrained_from_montgomery<T>(a_mont: T, modulus: &T, n_prime: &T, r_bits: usize) -> T
222where
223    T: Clone
224        + num_traits::Zero
225        + num_traits::One
226        + PartialOrd
227        + core::ops::Shl<usize, Output = T>
228        + core::ops::Shr<usize, Output = T>
229        + num_traits::ops::wrapping::WrappingAdd
230        + num_traits::ops::wrapping::WrappingSub
231        + for<'a> core::ops::Rem<&'a T, Output = T>,
232    for<'a> T: core::ops::Mul<&'a T, Output = T>,
233    for<'a> &'a T: core::ops::Rem<&'a T, Output = T>,
234{
235    // Montgomery reduction algorithm:
236    let r = T::one() << r_bits; // R = 2^r_bits
237
238    // Step 1: m = (a_mont * N') mod R
239    let m = (a_mont.clone() * n_prime) % &r;
240
241    // Step 2: t = (a_mont + m * N) / R
242    let temp_prod = m.clone() * modulus;
243    let temp_sum = a_mont.wrapping_add(&temp_prod);
244    let t = temp_sum >> r_bits; // Divide by R = 2^r_bits
245
246    // Step 3: Final reduction
247    if &t >= modulus {
248        t.wrapping_sub(modulus)
249    } else {
250        t
251    }
252}
253
254/// Montgomery multiplication (Constrained): (a * R) * (b * R) -> (a * b * R) mod N
255pub fn constrained_montgomery_mul<T>(
256    a_mont: &T,
257    b_mont: &T,
258    modulus: &T,
259    n_prime: &T,
260    r_bits: usize,
261) -> T
262where
263    T: Clone
264        + num_traits::Zero
265        + num_traits::One
266        + PartialOrd
267        + core::ops::Shl<usize, Output = T>
268        + core::ops::Shr<usize, Output = T>
269        + num_traits::ops::wrapping::WrappingAdd
270        + num_traits::ops::wrapping::WrappingSub
271        + for<'a> core::ops::Rem<&'a T, Output = T>,
272    for<'a> T: core::ops::RemAssign<&'a T> + core::ops::Mul<&'a T, Output = T>,
273    for<'a> &'a T: core::ops::Rem<&'a T, Output = T> + core::ops::BitAnd<Output = T>,
274{
275    // Step 1: Regular modular multiplication in Montgomery domain
276    let product = crate::mul::constrained_mod_mul(a_mont.clone(), b_mont, modulus);
277
278    // Step 2: Apply Montgomery reduction to get result in Montgomery form
279    constrained_from_montgomery(product, modulus, n_prime, r_bits)
280}
281
282/// Complete Montgomery modular multiplication (Constrained): A * B mod N
283/// Returns None if Montgomery parameter computation fails
284pub fn constrained_montgomery_mod_mul<T>(a: T, b: &T, modulus: &T) -> Option<T>
285where
286    T: Clone
287        + num_traits::Zero
288        + num_traits::One
289        + PartialEq
290        + PartialOrd
291        + num_traits::ops::wrapping::WrappingAdd
292        + num_traits::ops::wrapping::WrappingSub
293        + core::ops::Shl<usize, Output = T>
294        + core::ops::Shr<usize, Output = T>
295        + core::ops::Sub<Output = T>
296        + for<'a> core::ops::Rem<&'a T, Output = T>,
297    for<'a> T: core::ops::Add<&'a T, Output = T>
298        + core::ops::Sub<&'a T, Output = T>
299        + core::ops::Mul<&'a T, Output = T>
300        + core::ops::RemAssign<&'a T>,
301    for<'a> &'a T: core::ops::Sub<T, Output = T>
302        + core::ops::Div<&'a T, Output = T>
303        + core::ops::Rem<&'a T, Output = T>
304        + core::ops::BitAnd<Output = T>,
305{
306    let (r, _r_inv, n_prime, r_bits) = constrained_compute_montgomery_params(modulus)?;
307    let a_mont = constrained_to_montgomery(a, modulus, &r);
308    let b_mont = constrained_to_montgomery(b.clone(), modulus, &r);
309    let result_mont = constrained_montgomery_mul(&a_mont, &b_mont, modulus, &n_prime, r_bits);
310    Some(constrained_from_montgomery(
311        result_mont,
312        modulus,
313        &n_prime,
314        r_bits,
315    ))
316}
317
318/// Montgomery-based modular exponentiation (Constrained): base^exponent mod modulus
319/// Uses Montgomery arithmetic for efficient repeated multiplication
320/// Returns None if Montgomery parameter computation fails
321pub fn constrained_montgomery_mod_exp<T>(mut base: T, exponent: &T, modulus: &T) -> Option<T>
322where
323    T: Clone
324        + num_traits::Zero
325        + num_traits::One
326        + PartialEq
327        + PartialOrd
328        + num_traits::ops::wrapping::WrappingAdd
329        + num_traits::ops::wrapping::WrappingSub
330        + core::ops::Shl<usize, Output = T>
331        + core::ops::Shr<usize, Output = T>
332        + core::ops::ShrAssign<usize>
333        + core::ops::Sub<Output = T>
334        + for<'a> core::ops::Rem<&'a T, Output = T>,
335    for<'a> T: core::ops::RemAssign<&'a T>
336        + core::ops::Add<&'a T, Output = T>
337        + core::ops::Sub<&'a T, Output = T>
338        + core::ops::Mul<&'a T, Output = T>,
339    for<'a> &'a T: core::ops::Sub<T, Output = T>
340        + core::ops::Div<&'a T, Output = T>
341        + core::ops::Rem<&'a T, Output = T>
342        + core::ops::BitAnd<Output = T>,
343{
344    // Compute Montgomery parameters
345    let (r, _r_inv, n_prime, r_bits) = constrained_compute_montgomery_params(modulus)?;
346
347    // Reduce base and convert to Montgomery form
348    base.rem_assign(modulus);
349    base = constrained_to_montgomery(base, modulus, &r);
350
351    // Montgomery form of 1 (the initial result)
352    let mut result = constrained_to_montgomery(T::one(), modulus, &r);
353
354    // Copy exponent for manipulation
355    let mut exp = exponent.clone();
356    let two = T::one().clone().wrapping_add(&T::one());
357
358    // Binary exponentiation using Montgomery multiplication
359    while exp > T::zero() {
360        // If exponent is odd, multiply result by current base power
361        if &exp % &two == T::one() {
362            result = constrained_montgomery_mul(&result, &base, modulus, &n_prime, r_bits);
363        }
364
365        // Square the base for next iteration
366        exp >>= 1;
367        base = constrained_montgomery_mul(&base, &base, modulus, &n_prime, r_bits);
368    }
369
370    // Convert result back from Montgomery form
371    Some(constrained_from_montgomery(
372        result, modulus, &n_prime, r_bits,
373    ))
374}
375
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_constrained_compute_n_prime_trial_search_failure() {
        // No N' exists when gcd(modulus, R) != 1.
        // Example: N = 4, R = 8 (both even, so gcd(4, 8) = 4 != 1)
        let modulus = 4u32;
        let r = 8u32;
        let result = compute_n_prime_trial_search_constrained(&modulus, &r);
        assert!(
            result.is_none(),
            "Should return None for invalid modulus-R pair"
        );
    }

    #[test]
    fn test_constrained_compute_montgomery_params_failure() {
        // Montgomery parameter computation must fail for an even modulus.
        let even_modulus = 4u32;
        let result = constrained_compute_montgomery_params(&even_modulus);
        assert!(result.is_none(), "Should return None for even modulus");
    }

    #[test]
    fn test_constrained_compute_montgomery_params_failure_with_method() {
        // Every N' computation method must reject an invalid (even) modulus.
        let invalid_modulus = 4u32; // Even modulus

        // Trial search should fail
        let trial_result = constrained_compute_montgomery_params_with_method(
            &invalid_modulus,
            NPrimeMethod::TrialSearch,
        );
        assert!(
            trial_result.is_none(),
            "Trial search should fail with even modulus"
        );

        // Extended Euclidean should fail
        let euclidean_result = constrained_compute_montgomery_params_with_method(
            &invalid_modulus,
            NPrimeMethod::ExtendedEuclidean,
        );
        assert!(
            euclidean_result.is_none(),
            "Extended Euclidean should fail with even modulus"
        );

        // Hensel's lifting should fail
        let hensels_result = constrained_compute_montgomery_params_with_method(
            &invalid_modulus,
            NPrimeMethod::HenselsLifting,
        );
        assert!(
            hensels_result.is_none(),
            "Hensel's lifting should fail with even modulus"
        );
    }

    #[test]
    fn test_constrained_montgomery_mod_mul_parameter_failure() {
        // montgomery_mod_mul returns None when parameter computation fails.
        let invalid_modulus = 4u32;
        let a = 2u32;
        let b = 3u32;

        let result = constrained_montgomery_mod_mul(a, &b, &invalid_modulus);
        assert!(
            result.is_none(),
            "Montgomery mod_mul should return None for invalid modulus"
        );
    }

    #[test]
    fn test_constrained_montgomery_mod_exp_parameter_failure() {
        // montgomery_mod_exp returns None when parameter computation fails.
        let invalid_modulus = 4u32;
        let base = 2u32;
        let exponent = 3u32;

        let result = constrained_montgomery_mod_exp(base, &exponent, &invalid_modulus);
        assert!(
            result.is_none(),
            "Montgomery mod_exp should return None for invalid modulus"
        );
    }

    #[test]
    fn test_constrained_montgomery_reduction_final_subtraction() {
        // For a_mont < N, REDC gives t = (a_mont + m*N) / R < N, so the final
        // `t >= modulus` subtraction never runs when round-tripping reduced
        // values. Unreduced inputs below N*R are needed to exercise it.
        let modulus = 15u32;
        let (r, _r_inv, n_prime, r_bits) = constrained_compute_montgomery_params(&modulus).unwrap();

        // Round trips of high residues (no final subtraction involved).
        for &value in &[13u32, 14u32] {
            let mont = constrained_to_montgomery(value, &modulus, &r);
            let result = constrained_from_montgomery(mont, &modulus, &n_prime, r_bits);
            assert_eq!(result, value);
        }

        // Here R = 16 ≡ 1 (mod 15) and N' = 1, so REDC(x) ≡ x (mod 15).
        // REDC(15): t = (15 + 15 * 15) / 16 = 15 >= 15 -> 0
        assert_eq!(
            constrained_from_montgomery(15u32, &modulus, &n_prime, r_bits),
            0u32
        );
        // REDC(31): t = (31 + 15 * 15) / 16 = 16 >= 15 -> 1
        assert_eq!(
            constrained_from_montgomery(31u32, &modulus, &n_prime, r_bits),
            1u32
        );
    }

    #[test]
    fn test_constrained_hensel_lifting_branches() {
        // Hensel's lifting over odd moduli that exercise both lifting branches.
        let test_moduli = [9u32, 15u32, 21u32, 35u32, 45u32]; // Various composite odd moduli

        for &modulus in &test_moduli {
            let hensels_result = constrained_compute_montgomery_params_with_method(
                &modulus,
                NPrimeMethod::HenselsLifting,
            );

            assert!(
                hensels_result.is_some(),
                "Hensel's lifting should work for modulus {}",
                modulus
            );

            // Verify the result is mathematically correct
            if let Some((r, _r_inv, n_prime, _r_bits)) = hensels_result {
                let check = (modulus * n_prime) % r;
                let expected = r - 1; // R - 1 is -1 mod R
                assert_eq!(
                    check, expected,
                    "N' verification failed for modulus {} with Hensel's lifting",
                    modulus
                );
            }
        }
    }

    #[test]
    fn test_constrained_multiplication_stress() {
        // Multiplication with operands close to the modulus.
        let modulus = 33u32; // 33 = 3 * 11, composite modulus
        let (r, _r_inv, n_prime, r_bits) = constrained_compute_montgomery_params(&modulus).unwrap();

        let test_pairs = [(31u32, 32u32), (29u32, 30u32), (25u32, 27u32)];

        for &(a, b) in &test_pairs {
            let a_mont = constrained_to_montgomery(a, &modulus, &r);
            let b_mont = constrained_to_montgomery(b, &modulus, &r);

            let result_mont =
                constrained_montgomery_mul(&a_mont, &b_mont, &modulus, &n_prime, r_bits);
            let result = constrained_from_montgomery(result_mont, &modulus, &n_prime, r_bits);

            let expected = (a * b) % modulus;
            assert_eq!(result, expected, "Failed for {} * {} mod {}", a, b, modulus);
        }
    }

    #[test]
    fn test_constrained_exponentiation_conditional_branches() {
        // Exponentiation over varied exponent bit patterns, cross-checked with exp.rs.
        let modulus = 19u32; // Prime modulus

        let test_cases = [
            (18u32, 2u32),  // Base near modulus, small exponent
            (17u32, 7u32),  // Various combinations
            (15u32, 31u32), // Larger exponent with specific bit pattern
            (2u32, 127u32), // Small base, large exponent
        ];

        for (base, exponent) in test_cases.iter() {
            let result = constrained_montgomery_mod_exp(*base, exponent, &modulus).unwrap();
            let expected = crate::exp::constrained_mod_exp(*base, exponent, &modulus);
            assert_eq!(
                result, expected,
                "Failed for {}^{} mod {}",
                base, exponent, modulus
            );
        }
    }

    #[test]
    fn test_constrained_extended_euclidean_edge_cases() {
        // Extended Euclidean N' computation over prime powers, cross-checked
        // against trial search.
        let edge_moduli = [7u32, 9u32, 25u32, 49u32, 121u32]; // Powers of primes and products

        for &modulus in &edge_moduli {
            let euclidean_result = constrained_compute_montgomery_params_with_method(
                &modulus,
                NPrimeMethod::ExtendedEuclidean,
            );

            assert!(
                euclidean_result.is_some(),
                "Extended Euclidean should work for modulus {}",
                modulus
            );

            // Cross-validate with trial search
            let trial_result = constrained_compute_montgomery_params_with_method(
                &modulus,
                NPrimeMethod::TrialSearch,
            );

            assert_eq!(
                euclidean_result, trial_result,
                "Extended Euclidean vs Trial Search mismatch for modulus {}",
                modulus
            );
        }
    }
}