// cmpa/montgomery_impl.rs
1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2023 SUSE LLC
3// Author: Nicolai Stange <nstange@suse.de>
4
5use super::add_impl::ct_sub_cond_mp_mp;
6use super::cmp_impl::{ct_geq_mp_mp, ct_lt_mp_mp};
7use super::div_impl::{
8    ct_mod_lshifted_mp_mp, ct_mod_pow2_mp, CtModLshiftedMpMpError, CtModPow2MpError, CtMpDivisor,
9    CtMpDivisorError,
10};
11use super::limb::{
12    ct_add_l_l, ct_eq_l_l, ct_inv_mod_l, ct_lsb_mask_l, ct_mul_add_l_l_l_c, LimbChoice, LimbType,
13    LIMB_BITS,
14};
15use super::limbs_buffer::{
16    ct_mp_limbs_align_len, ct_mp_nlimbs, MpMutNativeEndianUIntLimbsSlice, MpMutUInt,
17    MpMutUIntSlice, MpUIntCommon,
18};
19
/// Byte length of the Montgomery radix shift for a modulus of `n_len` bytes,
/// i.e. `n_len` aligned up to a whole multiple of limbs. The Montgomery radix
/// is R = 2^(8 * this value).
fn ct_montgomery_radix_shift_len(n_len: usize) -> usize {
    ct_mp_limbs_align_len(n_len)
}
23
/// Number of limbs making up the Montgomery radix shift for a modulus of
/// `n_len` bytes, i.e. the number of per-limb REDC passes needed to divide by
/// the full radix R.
fn ct_montgomery_radix_shift_mp_nlimbs(n_len: usize) -> usize {
    ct_mp_nlimbs(n_len)
}
27
/// Error returned by [`ct_montgomery_neg_n0_inv_mod_l_mp`].
#[derive(Debug)]
pub enum CtMontgomeryNegN0InvModLMpError {
    /// The modulus is empty or even -- no inverse modulo a power of two
    /// exists in that case.
    InvalidModulus,
}
32
33pub fn ct_montgomery_neg_n0_inv_mod_l_mp<NT: MpUIntCommon>(
34    n: &NT,
35) -> Result<LimbType, CtMontgomeryNegN0InvModLMpError> {
36    if n.is_empty() {
37        return Err(CtMontgomeryNegN0InvModLMpError::InvalidModulus);
38    }
39    let n0 = n.load_l(0);
40    if ct_eq_l_l(n0 & 1, 0).unwrap() != 0 {
41        return Err(CtMontgomeryNegN0InvModLMpError::InvalidModulus);
42    }
43    let n0_inv_mod_l = ct_inv_mod_l(n0);
44    Ok((!n0_inv_mod_l).wrapping_add(1))
45}
46
#[test]
fn test_ct_montgomery_neg_n0_inv_mod_l_mp() {
    use super::limb::LIMB_BYTES;
    use super::limbs_buffer::{MpMutBigEndianUIntByteSlice, MpMutUIntSlicePriv as _};

    const MERSENNE_PRIME_13: LimbType = 8191 as LimbType;
    for i in 0 as LimbType..128 {
        // Only odd limbs have an inverse modulo a power of two.
        let n0_base = 2 * i + 1;
        for j in 0..2048 {
            // Perturb the upper bits with a pseudo-random pattern while
            // keeping the low byte (and thus the parity) intact.
            let perturbation =
                MERSENNE_PRIME_13.wrapping_mul((511 as LimbType).wrapping_mul(j)) << 8;
            let n0 = n0_base.wrapping_add(perturbation);

            let mut n0_buf: [u8; LIMB_BYTES] = [0; LIMB_BYTES];
            let mut n = MpMutBigEndianUIntByteSlice::from_slice(n0_buf.as_mut_slice()).unwrap();
            n.store_l(0, n0);
            let neg_n0_inv = ct_montgomery_neg_n0_inv_mod_l_mp(&n).unwrap();
            // n0 * -n0^{-1} == -1 == !0 mod 2^LIMB_BITS.
            assert_eq!(n0.wrapping_mul(neg_n0_inv), !0);
        }
    }
}
68
/// Limb-wise Montgomery reduction (REDC) kernel: zeroes the low
/// `redc_pow2_exp` bits of a multi-precision value by adding a suitable
/// multiple of the (odd) modulus, one limb at a time.
pub struct CtMontgomeryRedcKernel {
    // Right shift distance (redc_pow2_exp % LIMB_BITS) for moving a reduced
    // limb's upper part into its final position; always < LIMB_BITS.
    redc_pow2_rshift: u32,
    // Mask covering the low redc_pow2_exp bits getting reduced away.
    redc_pow2_mask: LimbType,
    // Multiplier m = t0 * -n^{-1} mod 2^redc_pow2_exp selecting the multiple
    // of the modulus to add.
    m: LimbType,
    // Upper part of the previously reduced limb, to be merged with the lower
    // part of the next one.
    last_redced_val: LimbType,
    // Carry propagated along the multiply-add chain.
    carry: LimbType,
}
76
impl CtMontgomeryRedcKernel {
    /// Shift a reduced limb's upper part down into the low bit positions it
    /// will occupy in the subsequently produced result limb.
    const fn redc_rshift_lo(
        redc_pow2_rshift: u32,
        redc_pow2_mask: LimbType,
        val: LimbType,
    ) -> LimbType {
        // There are two possible cases where redc_pow2_rshift == 0:
        // a.) redc_pow2_exp == 0. In this case !redc_pow2_mask == !0.
        // b.) redc_pow2_exp == LIMB_BITS. In this case !redc_pow2_mask == !!0 == 0.
        (val & !redc_pow2_mask) >> redc_pow2_rshift
    }

    /// Shift a reduced limb's lower part up into the high bit positions it
    /// will occupy in the currently produced result limb.
    const fn redc_lshift_hi(
        redc_pow2_rshift: u32,
        redc_pow2_mask: LimbType,
        val: LimbType,
    ) -> LimbType {
        // There are two possible cases where redc_pow2_rshift == 0:
        // a.) redc_pow2_exp == 0. In this case redc_pow2_mask == 0 as well.
        // b.) redc_pow2_exp == LIMB_BITS. In this case redc_pow2_mask == !0.
        (val & redc_pow2_mask) << ((LIMB_BITS - redc_pow2_rshift) % LIMB_BITS)
    }

    /// Begin a REDC pass clearing the low `redc_pow2_exp` bits of the value.
    ///
    /// * `redc_pow2_exp`: number of low bits to reduce away, <= LIMB_BITS.
    /// * `t0_val`: the value's least significant limb.
    /// * `n0_val`: the modulus' least significant limb.
    /// * `neg_n0_inv_mod_l`: -n^{-1} mod 2^LIMB_BITS, as obtained from
    ///   [`ct_montgomery_neg_n0_inv_mod_l_mp`].
    pub fn start(
        redc_pow2_exp: u32,
        t0_val: LimbType,
        n0_val: LimbType,
        neg_n0_inv_mod_l: LimbType,
    ) -> Self {
        debug_assert!(redc_pow2_exp <= LIMB_BITS);

        // Calculate the shift distance for right shifting a redced limb into its final
        // position. Be careful to keep it < LIMB_BITS for not running into
        // undefined behaviour with the shift.  See the comments in
        // redc_rshift_lo()/redc_rshift_hi() for the mask<->rshift interaction.
        let redc_pow2_rshift = redc_pow2_exp % LIMB_BITS;
        let redc_pow2_mask = ct_lsb_mask_l(redc_pow2_exp);

        // For any i >= j, if n' == -n^{-1} mod 2^i, then n' mod 2^j == -n^{-1} mod 2^j.
        let neg_n0_inv_mod_l = neg_n0_inv_mod_l & redc_pow2_mask;
        debug_assert_eq!(
            neg_n0_inv_mod_l.wrapping_mul(n0_val) & redc_pow2_mask,
            redc_pow2_mask
        );

        // The multiple of the modulus whose addition zeroes t's low bits.
        let m = t0_val.wrapping_mul(neg_n0_inv_mod_l) & redc_pow2_mask;

        let (carry, redced_t0_val) = ct_mul_add_l_l_l_c(t0_val, m, n0_val, 0);
        debug_assert_eq!(redced_t0_val & redc_pow2_mask, 0);
        // If redc_pow2_exp < LIMB_BITS, the upper bits of the reduced zeroth limb
        // will become the lower bits of the resulting zeroth limb.
        let last_redced_val = Self::redc_rshift_lo(redc_pow2_rshift, redc_pow2_mask, redced_t0_val);

        Self {
            redc_pow2_rshift,
            redc_pow2_mask,
            m,
            last_redced_val,
            carry,
        }
    }

    /// Reduce the next limb pair: add `m * n_val` (plus the running carry) to
    /// `t_val` and return the next completed result limb.
    pub fn update(&mut self, t_val: LimbType, n_val: LimbType) -> LimbType {
        let redced_t_val;
        (self.carry, redced_t_val) = ct_mul_add_l_l_l_c(t_val, self.m, n_val, self.carry);

        // If redc_pow2_exp < LIMB_BITS, the lower bits of the reduced current limb
        // correspond to the upper bits of the returned result limb.
        let result_val = self.last_redced_val
            | Self::redc_lshift_hi(self.redc_pow2_rshift, self.redc_pow2_mask, redced_t_val);
        // If redc_pow2_exp < LIMB_BITS, the upper bits of the reduced current limb
        // will become the lower bits of the subsequently returned result limb.
        self.last_redced_val =
            Self::redc_rshift_lo(self.redc_pow2_rshift, self.redc_pow2_mask, redced_t_val);

        result_val
    }

    /// Finalize the pass: absorb the accumulated carry into the (possibly
    /// virtual) limb `t_val` beyond the last one processed and return the
    /// pair of (carry out, final result limb).
    pub fn finish(self, t_val: LimbType) -> (LimbType, LimbType) {
        debug_assert_eq!(t_val & !self.redc_pow2_mask, 0);
        let (carry, redced_t_val) = ct_add_l_l(t_val, self.carry);
        (
            carry,
            self.last_redced_val
                | Self::redc_lshift_hi(self.redc_pow2_rshift, self.redc_pow2_mask, redced_t_val),
        )
    }

    /// Like [`finish`](Self::finish), but for a `t_val` in two's complement
    /// representation: returns (resulting sign, final result limb with the
    /// sign extended into the reduced-away bit positions).
    pub fn finish_in_twos_complement(self, t_val: LimbType) -> (LimbType, LimbType) {
        let t_val_sign = t_val >> (LIMB_BITS - 1);
        let (carry, redced_t_val) = ct_add_l_l(t_val, self.carry);

        // In two's complement representation, the addition overflows iff the sign
        // bit (indicating a virtual borrow) is getting neutralized.
        debug_assert!(carry == 0 || t_val_sign == 1);
        debug_assert!(carry == 0 || redced_t_val >> (LIMB_BITS - 1) == 0);
        debug_assert!(carry == 1 || redced_t_val >> (LIMB_BITS - 1) == t_val_sign);
        let redced_t_val_sign = carry ^ t_val_sign;
        debug_assert!(redced_t_val >> (LIMB_BITS - 1) == redced_t_val_sign);
        // Replicate the sign bit into the positions vacated by the reduction.
        let redced_t_val_extended_sign =
            (0 as LimbType).wrapping_sub(redced_t_val_sign) & !self.redc_pow2_mask;
        (
            redced_t_val_sign,
            self.last_redced_val
                | Self::redc_lshift_hi(self.redc_pow2_rshift, self.redc_pow2_mask, redced_t_val)
                | redced_t_val_extended_sign,
        )
    }
}
186
/// Error returned by [`ct_montgomery_redc_mp`].
#[derive(Debug)]
pub enum CtMontgomeryRedcMpError {
    /// The modulus is empty or even.
    InvalidModulus,
    /// The value buffer `t` is too short to receive a result mod `n`.
    InsufficientResultSpace,
    /// The value buffer `t` is wider than the supported maximum of twice the
    /// modulus' length.
    InputValueOutOfRange,
}
193
/// Montgomery reduction (REDC): replace `t` by t / R mod n, with
/// R = 2^(LIMB_BITS * ct_montgomery_radix_shift_mp_nlimbs(n.len())).
///
/// * `t`: value to reduce, at most twice the modulus' width; receives the
///   result, fully reduced mod `n`.
/// * `n`: the modulus; must be odd.
/// * `neg_n0_inv_mod_l`: -n^{-1} mod 2^LIMB_BITS, as obtained from
///   [`ct_montgomery_neg_n0_inv_mod_l_mp`].
///
/// # Errors
/// * [`CtMontgomeryRedcMpError::InvalidModulus`]: `n` is empty or even.
/// * [`CtMontgomeryRedcMpError::InsufficientResultSpace`]: `t` cannot hold a
///   value mod `n`.
/// * [`CtMontgomeryRedcMpError::InputValueOutOfRange`]: `t` is wider than
///   twice the modulus' length.
pub fn ct_montgomery_redc_mp<TT: MpMutUInt, NT: MpUIntCommon>(
    t: &mut TT,
    n: &NT,
    neg_n0_inv_mod_l: LimbType,
) -> Result<(), CtMontgomeryRedcMpError> {
    if n.is_empty() {
        return Err(CtMontgomeryRedcMpError::InvalidModulus);
    }
    let n0_val = n.load_l(0);
    if ct_eq_l_l(n0_val & 1, 0).unwrap() != 0 {
        return Err(CtMontgomeryRedcMpError::InvalidModulus);
    }
    if !n.len_is_compatible_with(t.len()) {
        return Err(CtMontgomeryRedcMpError::InsufficientResultSpace);
    }
    if !t.len_is_compatible_with(2 * n.len()) {
        return Err(CtMontgomeryRedcMpError::InputValueOutOfRange);
    }
    let t_nlimbs = t.nlimbs();
    let n_nlimbs = n.nlimbs();
    debug_assert!(n0_val.wrapping_mul(neg_n0_inv_mod_l) == !0);

    let mut reduced_t_carry = 0;
    // t's high limb might be a partial one, do not update directly in the course of
    // reducing in order to avoid overflowing it. Use a shadow instead.
    let mut t_high_shadow = t.load_l(t_nlimbs - 1);
    // One full-limb REDC pass per limb of the Montgomery radix shift.
    for _i in 0..ct_montgomery_radix_shift_mp_nlimbs(n.len()) {
        let mut redc_kernel =
            CtMontgomeryRedcKernel::start(LIMB_BITS, t.load_l(0), n0_val, neg_n0_inv_mod_l);
        let mut j = 0;
        // Reduce against all but the modulus' (possibly partial) top limb,
        // shifting the result down by one limb position as we go.
        while j + 2 < n_nlimbs {
            debug_assert!(j < t_nlimbs - 1);
            t.store_l_full(
                j,
                redc_kernel.update(t.load_l_full(j + 1), n.load_l_full(j + 1)),
            );
            j += 1;
        }

        // NOTE(review): this presumes n_nlimbs >= 2; a single-limb modulus
        // would trip the assertion below -- confirm callers never pass one.
        debug_assert_eq!(j + 2, n_nlimbs);
        // Do not read the potentially partial, stale high limb directly from t, use the
        // t_high_shadow shadow instead.
        let t_val = if j + 2 != t_nlimbs {
            t.load_l_full(j + 1)
        } else {
            t_high_shadow
        };
        t.store_l_full(j, redc_kernel.update(t_val, n.load_l(j + 1)));
        j += 1;

        // Propagate the reduction through t's remaining limbs above the
        // modulus' width (the virtual n limbs there are zero).
        while j + 2 < t_nlimbs {
            t.store_l_full(j, redc_kernel.update(t.load_l_full(j + 1), 0));
            j += 1;
        }
        if j + 1 == t_nlimbs - 1 {
            t.store_l_full(j, redc_kernel.update(t_high_shadow, 0));
            j += 1;
        }
        debug_assert_eq!(j, t_nlimbs - 1);

        (reduced_t_carry, t_high_shadow) = redc_kernel.finish(reduced_t_carry);
    }

    // Now apply the high limb shadow back.
    let t_high_shadow_mask = t.partial_high_mask();
    let t_high_shadow_shift = t.partial_high_shift();
    assert!(t_high_shadow_shift == 0 || reduced_t_carry == 0);
    reduced_t_carry |= (t_high_shadow & !t_high_shadow_mask) >> t_high_shadow_shift;
    t_high_shadow &= t_high_shadow_mask;
    t.store_l(t_nlimbs - 1, t_high_shadow);

    // Conditional final subtraction brings the result into [0, n).
    ct_sub_cond_mp_mp(t, n, LimbChoice::from(reduced_t_carry) | ct_geq_mp_mp(t, n));
    debug_assert!(ct_geq_mp_mp(t, n).unwrap() == 0);
    Ok(())
}
269
/// Common test body for [`ct_montgomery_redc_mp`], generic over the limb
/// buffer types of the value (`TT`) and the modulus (`NT`). Verifies that a
/// to-Montgomery-form / REDC round trip is the identity mod n.
#[cfg(test)]
fn test_ct_montgomery_redc_mp<TT: MpMutUIntSlice, NT: MpMutUIntSlice>() {
    use super::div_impl::ct_mod_mp_mp;
    use super::limb::LIMB_BYTES;

    // Mersenne-prime based pseudo-random limb patterns for the test vectors.
    for i in 0..64 {
        const MERSENNE_PRIME_13: LimbType = 8191 as LimbType;
        let n_high = MERSENNE_PRIME_13.wrapping_mul((16385 as LimbType).wrapping_mul(i));
        for j in 0..64 {
            const MERSENNE_PRIME_17: LimbType = 131071 as LimbType;
            let n_low = MERSENNE_PRIME_17.wrapping_mul((1023 as LimbType).wrapping_mul(j));
            // Force n_low odd.
            let n_low = n_low | 1;
            let mut n = tst_mk_mp_backing_vec!(NT, 2 * LIMB_BYTES);
            let mut n = NT::from_slice(n.as_mut_slice()).unwrap();
            n.store_l(0, n_low);
            n.store_l(1, n_high);
            let neg_n0_inv = ct_montgomery_neg_n0_inv_mod_l_mp(&n).unwrap();

            for k in 0..8 {
                let t_high = MERSENNE_PRIME_17.wrapping_mul((8191 as LimbType).wrapping_mul(k));
                for l in 0..8 {
                    let t_low =
                        MERSENNE_PRIME_13.wrapping_mul((131087 as LimbType).wrapping_mul(l));

                    let mut t = tst_mk_mp_backing_vec!(TT, 2 * LIMB_BYTES);
                    let mut t = TT::from_slice(t.as_mut_slice()).unwrap();
                    t.store_l(0, t_low);
                    t.store_l(1, t_high);

                    // All montgomery operations are defined mod n, compute t mod n
                    ct_mod_mp_mp(None, &mut t, &CtMpDivisor::new(&n, None).unwrap());
                    let t_low = t.load_l(0);
                    let t_high = t.load_l(1);

                    // To Montgomery form: t * R mod n
                    ct_to_montgomery_form_direct_mp(&mut t, &n).unwrap();

                    // And back to normal: (t * R mod n) / R mod n
                    ct_montgomery_redc_mp(&mut t, &n, neg_n0_inv).unwrap();
                    // The round trip must reproduce t mod n.
                    assert_eq!(t.load_l(0), t_low);
                    assert_eq!(t.load_l(1), t_high);
                }
            }
        }
    }
}
317
/// Exercise [`ct_montgomery_redc_mp`] with big-endian byte buffers for both
/// the value and the modulus.
#[test]
fn test_ct_montgomery_redc_be_be() {
    use super::limbs_buffer::MpMutBigEndianUIntByteSlice;
    test_ct_montgomery_redc_mp::<MpMutBigEndianUIntByteSlice, MpMutBigEndianUIntByteSlice>()
}
323
/// Exercise [`ct_montgomery_redc_mp`] with little-endian byte buffers for
/// both the value and the modulus.
#[test]
fn test_ct_montgomery_redc_le_le() {
    use super::limbs_buffer::MpMutLittleEndianUIntByteSlice;
    test_ct_montgomery_redc_mp::<MpMutLittleEndianUIntByteSlice, MpMutLittleEndianUIntByteSlice>()
}
329
/// Exercise [`ct_montgomery_redc_mp`] with native-endian limb buffers for
/// both the value and the modulus.
#[test]
fn test_ct_montgomery_redc_ne_ne() {
    use super::limbs_buffer::MpMutNativeEndianUIntLimbsSlice;
    test_ct_montgomery_redc_mp::<MpMutNativeEndianUIntLimbsSlice, MpMutNativeEndianUIntLimbsSlice>()
}
335
/// Error returned by [`ct_montgomery_mul_mod_mp_mp`].
#[derive(Debug)]
pub enum CtMontgomeryMulModMpMpError {
    /// The modulus is empty or even.
    InvalidModulus,
    /// The result buffer cannot hold a value mod `n`.
    InsufficientResultSpace,
    /// One of the operands is longer than the modulus.
    InconsistentOperandLengths,
}
342
/// Fused Montgomery multiplication: compute op0 * op1 / R mod n into
/// `result`, with R = 2^(LIMB_BITS * ct_montgomery_radix_shift_mp_nlimbs(n.len())).
///
/// * `result`: receives the product, fully reduced mod `n`.
/// * `op0`, `op1`: the factors; expected to be < `n` (debug-asserted).
/// * `n`: the modulus; must be odd.
/// * `neg_n0_inv_mod_l`: -n^{-1} mod 2^LIMB_BITS, as obtained from
///   [`ct_montgomery_neg_n0_inv_mod_l_mp`].
///
/// # Errors
/// * [`CtMontgomeryMulModMpMpError::InvalidModulus`]: `n` is empty or even.
/// * [`CtMontgomeryMulModMpMpError::InsufficientResultSpace`]: `result`
///   cannot hold a value mod `n`.
/// * [`CtMontgomeryMulModMpMpError::InconsistentOperandLengths`]: one of the
///   factors is longer than `n`.
pub fn ct_montgomery_mul_mod_mp_mp<
    RT: MpMutUIntSlice,
    T0: MpUIntCommon,
    T1: MpUIntCommon,
    NT: MpUIntCommon,
>(
    result: &mut RT,
    op0: &T0,
    op1: &T1,
    n: &NT,
    neg_n0_inv_mod_l: LimbType,
) -> Result<(), CtMontgomeryMulModMpMpError> {
    // This is an implementation of the "Finely Integrated Operand Scanning (FIOS)
    // Method" approach to fused multiplication and Montgomery reduction, as
    // described in "Analyzing and Comparing Montgomery Multiplication
    // Algorithm", IEEE Micro, 16(3):26-33, June 1996.
    if n.is_empty() {
        return Err(CtMontgomeryMulModMpMpError::InvalidModulus);
    }
    let n0_val = n.load_l(0);
    if ct_eq_l_l(n0_val & 1, 0).unwrap() != 0 {
        return Err(CtMontgomeryMulModMpMpError::InvalidModulus);
    }
    if !n.len_is_compatible_with(result.len()) {
        return Err(CtMontgomeryMulModMpMpError::InsufficientResultSpace);
    }
    debug_assert!(n.nlimbs() <= result.nlimbs());
    if !op0.len_is_compatible_with(n.len()) || !op1.len_is_compatible_with(n.len()) {
        return Err(CtMontgomeryMulModMpMpError::InconsistentOperandLengths);
    }
    debug_assert!(op0.nlimbs() <= n.nlimbs());
    debug_assert!(ct_lt_mp_mp(op0, n).unwrap() != 0);
    debug_assert!(op1.nlimbs() <= n.nlimbs());
    debug_assert!(ct_lt_mp_mp(op1, n).unwrap() != 0);

    let op0_nlimbs = op0.nlimbs();
    let op1_nlimbs = op1.nlimbs();
    let n_nlimbs = n.nlimbs();
    debug_assert!(n0_val.wrapping_mul(neg_n0_inv_mod_l) == !0);

    result.clear_bytes_above(0);
    let mut result = result.shrink_to(n.len());
    debug_assert_eq!(result.nlimbs(), n.nlimbs());
    let mut result_carry = 0;
    // result's high limb might be a partial one, do not update directly in the
    // course of reducing in order to avoid overflowing it. Use a shadow
    // instead.
    let mut result_high_shadow = 0;
    // Outer FIOS loop: for each limb of op0, accumulate op0[i] * op1 into the
    // running result and immediately REDC one limb's worth of radix away.
    for i in 0..op0_nlimbs {
        debug_assert!(result_carry <= 1); // Loop invariant.
        let op0_val = op0.load_l(i);

        // Do not read the potentially partial, stale high limb directly from result,
        // use the result_high_shadow shadow instead.
        let result_val = if n_nlimbs != 1 {
            result.load_l_full(0)
        } else {
            result_high_shadow
        };
        let op1_val = op1.load_l(0);
        let (mut op0_op1_add_carry, result_val) =
            ct_mul_add_l_l_l_c(result_val, op0_val, op1_val, 0);

        let mut redc_kernel =
            CtMontgomeryRedcKernel::start(LIMB_BITS, result_val, n0_val, neg_n0_inv_mod_l);

        let mut j = 0;
        while j + 1 < op1_nlimbs {
            let op1_val = op1.load_l(j + 1);

            // Do not read the potentially partial, stale high limb directly from result,
            // use the result_high_shadow shadow instead.
            let mut result_val = if j + 1 != n_nlimbs - 1 {
                result.load_l_full(j + 1)
            } else {
                result_high_shadow
            };

            (op0_op1_add_carry, result_val) =
                ct_mul_add_l_l_l_c(result_val, op0_val, op1_val, op0_op1_add_carry);

            let n_val = n.load_l(j + 1);
            let result_val = redc_kernel.update(result_val, n_val);
            result.store_l_full(j, result_val);
            j += 1;
        }
        debug_assert_eq!(j + 1, op1_nlimbs);

        // If op1_nlimbs < n_nlimbs, handle the rest by propagating the multiplication
        // carry and continue redcing.
        while j + 1 < n_nlimbs {
            // Do not read the potentially partial, stale high limb directly from result,
            // use the result_high_shadow shadow instead.
            let mut result_val = if j + 1 != n_nlimbs - 1 {
                result.load_l_full(j + 1)
            } else {
                result_high_shadow
            };

            (op0_op1_add_carry, result_val) = ct_add_l_l(result_val, op0_op1_add_carry);

            let n_val = n.load_l(j + 1);
            let result_val = redc_kernel.update(result_val, n_val);
            result.store_l_full(j, result_val);
            j += 1;
        }
        debug_assert_eq!(j + 1, n_nlimbs);

        // Fold the multiplication carry and the previous iteration's result
        // carry into the new high limb produced by the REDC kernel.
        let mut result_val;
        debug_assert!(result_carry <= 1);
        (result_carry, result_val) = ct_add_l_l(result_carry, op0_op1_add_carry);
        debug_assert!(result_carry <= 1);
        debug_assert!(result_carry == 0 || result_val == 0);

        (result_carry, result_val) = redc_kernel.finish(result_val);
        debug_assert!(result_carry <= 1);
        result_high_shadow = result_val;
    }

    // If op0_nlimbs < the montgomery radix shift distance, handle the rest by
    // REDCing it.
    for _i in op0_nlimbs..ct_montgomery_radix_shift_mp_nlimbs(n.len()) {
        // Do not read the potentially partial, stale high limb directly from result,
        // use the result_high_shadow shadow instead.
        let result_val = if n_nlimbs != 1 {
            result.load_l_full(0)
        } else {
            result_high_shadow
        };

        let mut redc_kernel =
            CtMontgomeryRedcKernel::start(LIMB_BITS, result_val, n0_val, neg_n0_inv_mod_l);

        let mut j = 0;
        while j + 1 < n_nlimbs {
            // Do not read the potentially partial, stale high limb directly from result,
            // use the result_high_shadow shadow instead.
            let result_val = if j + 1 != n_nlimbs - 1 {
                result.load_l_full(j + 1)
            } else {
                result_high_shadow
            };

            let n_val = n.load_l(j + 1);
            let result_val = redc_kernel.update(result_val, n_val);
            result.store_l_full(j, result_val);
            j += 1;
        }
        debug_assert_eq!(j + 1, n_nlimbs);

        (result_carry, result_high_shadow) = redc_kernel.finish(result_carry);
        debug_assert!(result_carry <= 1);
    }

    // Now apply the high limb shadow back.
    debug_assert!(result.nlimbs() == n.nlimbs());
    let result_high_shadow_mask = n.partial_high_mask();
    let result_high_shadow_shift = n.partial_high_shift();
    assert!(result_high_shadow_shift == 0 || result_carry == 0);
    result_carry |= (result_high_shadow & !result_high_shadow_mask) >> result_high_shadow_shift;
    result_high_shadow &= result_high_shadow_mask;
    result.store_l(n_nlimbs - 1, result_high_shadow);

    // Conditional final subtraction brings the result into [0, n).
    let result_geq_n = LimbChoice::from(result_carry) | ct_geq_mp_mp(&result, n);
    ct_sub_cond_mp_mp(&mut result, n, result_geq_n);
    debug_assert!(ct_geq_mp_mp(&result, n).unwrap() == 0);
    Ok(())
}
511
/// Common test body for [`ct_montgomery_mul_mod_mp_mp`], generic over the
/// limb buffer types of the result, both factors and the modulus. Checks the
/// Montgomery product (scaled back via r_mod_n) against a conventional
/// multiply + mod computation.
#[cfg(test)]
fn test_ct_montgomery_mul_mod_mp_mp<
    RT: MpMutUIntSlice,
    T0: MpMutUIntSlice,
    T1: MpMutUIntSlice,
    NT: MpMutUIntSlice,
>() {
    use super::div_impl::ct_mod_mp_mp;
    use super::limb::LIMB_BYTES;
    use super::mul_impl::ct_mul_trunc_mp_mp;

    // Mersenne-prime based pseudo-random limb patterns for the test vectors.
    for i in 0..16 {
        const MERSENNE_PRIME_13: LimbType = 8191 as LimbType;
        let n_high = MERSENNE_PRIME_13.wrapping_mul((65543 as LimbType).wrapping_mul(i));
        for j in 0..16 {
            const MERSENNE_PRIME_17: LimbType = 131071 as LimbType;
            let n_low = MERSENNE_PRIME_17.wrapping_mul((4095 as LimbType).wrapping_mul(j));
            // Force n_low odd.
            let n_low = n_low | 1;
            let mut n_buf = tst_mk_mp_backing_vec!(NT, 2 * LIMB_BYTES);
            let mut n = NT::from_slice(n_buf.as_mut_slice()).unwrap();
            n.store_l(0, n_low);
            n.store_l(1, n_high);
            drop(n);
            // Exercise an unaligned modulus length only if all involved
            // buffer types support unaligned lengths.
            let n_lengths = if !RT::SUPPORTS_UNALIGNED_BUFFER_LENGTHS
                || !T0::SUPPORTS_UNALIGNED_BUFFER_LENGTHS
                || !T1::SUPPORTS_UNALIGNED_BUFFER_LENGTHS
                || !NT::SUPPORTS_UNALIGNED_BUFFER_LENGTHS
            {
                [LIMB_BYTES, 2 * LIMB_BYTES]
            } else {
                [2 * LIMB_BYTES - 1, 2 * LIMB_BYTES]
            };
            for n_len in n_lengths {
                let mut n_buf = n_buf.clone();
                let mut n = NT::from_slice(n_buf.as_mut_slice()).unwrap();
                n.clear_bytes_above(n_len);
                let n = n.shrink_to(n_len);
                let neg_n0_inv = ct_montgomery_neg_n0_inv_mod_l_mp(&n).unwrap();

                // r_mod_n = R mod n, with R = 2^(LIMB_BITS *
                // ct_montgomery_radix_shift_mp_nlimbs(n_len)).
                let mut r_mod_n = tst_mk_mp_backing_vec!(RT, 3 * LIMB_BYTES);
                let mut r_mod_n = RT::from_slice(r_mod_n.as_mut_slice()).unwrap();
                r_mod_n.store_l_full(ct_montgomery_radix_shift_mp_nlimbs(n_len), 1);
                ct_mod_mp_mp(None, &mut r_mod_n, &CtMpDivisor::new(&n, None).unwrap());
                let r_mod_n = r_mod_n.shrink_to(n.len());

                for k in 0..4 {
                    let a_high =
                        MERSENNE_PRIME_17.wrapping_mul((16383 as LimbType).wrapping_mul(k));
                    for l in 0..4 {
                        let a_low =
                            MERSENNE_PRIME_13.wrapping_mul((262175 as LimbType).wrapping_mul(l));
                        let mut a_buf = tst_mk_mp_backing_vec!(T0, 2 * LIMB_BYTES);
                        let mut a = T0::from_slice(a_buf.as_mut_slice()).unwrap();
                        a.store_l(0, a_low);
                        a.store_l(1, a_high);
                        // All montgomery operations are defined mod n, compute a mod n
                        ct_mod_mp_mp(None, &mut a, &CtMpDivisor::new(&n, None).unwrap());
                        drop(a);
                        for s in 0..4 {
                            let b_high = MERSENNE_PRIME_13
                                .wrapping_mul((262175 as LimbType).wrapping_mul(s));
                            for t in 0..4 {
                                const MERSENNE_PRIME_19: LimbType = 524287 as LimbType;
                                let b_low = MERSENNE_PRIME_19
                                    .wrapping_mul((4095 as LimbType).wrapping_mul(t));
                                let mut b_buf = tst_mk_mp_backing_vec!(T1, 2 * LIMB_BYTES);
                                let mut b = T1::from_slice(b_buf.as_mut_slice()).unwrap();
                                b.store_l(0, b_low);
                                b.store_l(1, b_high);
                                // All montgomery operations are defined mod n, compute b mod n
                                ct_mod_mp_mp(None, &mut b, &CtMpDivisor::new(&n, None).unwrap());
                                drop(b);

                                // Also exercise truncated operand lengths.
                                for op_len in [0, 1 * LIMB_BYTES, n_len] {
                                    let mut a_buf = a_buf.clone();
                                    let mut a = T0::from_slice(a_buf.as_mut_slice()).unwrap();
                                    a.clear_bytes_above(op_len);
                                    let a = a.shrink_to(op_len);
                                    let mut b_buf = b_buf.clone();
                                    let mut b = T1::from_slice(b_buf.as_mut_slice()).unwrap();
                                    b.clear_bytes_above(op_len);
                                    let b = b.shrink_to(op_len);

                                    let mut _result = tst_mk_mp_backing_vec!(RT, 4 * LIMB_BYTES);
                                    let mut result =
                                        RT::from_slice(_result.as_mut_slice()).unwrap();
                                    let mut mg_mul_result = result.shrink_to(n_len);
                                    ct_montgomery_mul_mod_mp_mp(
                                        &mut mg_mul_result,
                                        &a,
                                        &b,
                                        &n,
                                        neg_n0_inv,
                                    )
                                    .unwrap();
                                    drop(mg_mul_result);

                                    // For testing against the expected result computed using the
                                    // "conventional" methods only, multiply by r_mod_n -- this
                                    // avoids having to multiply
                                    // the conventional product by r^-1 mod n, which is
                                    // not known without implementing Euklid's algorithm.
                                    ct_mul_trunc_mp_mp(&mut result, n.len(), &r_mod_n);
                                    ct_mod_mp_mp(
                                        None,
                                        &mut result,
                                        &CtMpDivisor::new(&n, None).unwrap(),
                                    );
                                    drop(result);

                                    let mut _expected = tst_mk_mp_backing_vec!(RT, 4 * LIMB_BYTES);
                                    let mut expected =
                                        RT::from_slice(_expected.as_mut_slice()).unwrap();
                                    expected.copy_from(&a);
                                    ct_mul_trunc_mp_mp(&mut expected, op_len, &b);
                                    ct_mod_mp_mp(
                                        None,
                                        &mut expected,
                                        &CtMpDivisor::new(&n, None).unwrap(),
                                    );
                                    drop(expected);

                                    assert_eq!(_result, _expected);
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}
646
/// Exercise [`ct_montgomery_mul_mod_mp_mp`] with big-endian byte buffers for
/// all operands.
#[test]
fn test_ct_montgomery_mul_mod_be_be_be_be() {
    use super::limbs_buffer::MpMutBigEndianUIntByteSlice;
    test_ct_montgomery_mul_mod_mp_mp::<
        MpMutBigEndianUIntByteSlice,
        MpMutBigEndianUIntByteSlice,
        MpMutBigEndianUIntByteSlice,
        MpMutBigEndianUIntByteSlice,
    >()
}
657
/// Exercise [`ct_montgomery_mul_mod_mp_mp`] with little-endian byte buffers
/// for all operands.
#[test]
fn test_ct_montgomery_mul_mod_le_le_le_le() {
    use super::limbs_buffer::MpMutLittleEndianUIntByteSlice;
    test_ct_montgomery_mul_mod_mp_mp::<
        MpMutLittleEndianUIntByteSlice,
        MpMutLittleEndianUIntByteSlice,
        MpMutLittleEndianUIntByteSlice,
        MpMutLittleEndianUIntByteSlice,
    >()
}
668
/// Exercise [`ct_montgomery_mul_mod_mp_mp`] with native-endian limb buffers
/// for all operands.
#[test]
fn test_ct_montgomery_mul_mod_ne_ne_ne_ne() {
    use super::limbs_buffer::MpMutNativeEndianUIntLimbsSlice;
    test_ct_montgomery_mul_mod_mp_mp::<
        MpMutNativeEndianUIntLimbsSlice,
        MpMutNativeEndianUIntLimbsSlice,
        MpMutNativeEndianUIntLimbsSlice,
        MpMutNativeEndianUIntLimbsSlice,
    >()
}
679
/// Error returned by the Montgomery form transformation primitives
/// [`ct_to_montgomery_form_direct_mp`] and [`ct_montgomery_radix2_mod_n_mp`].
#[derive(Debug)]
pub enum CtMontgomeryTransformationError {
    /// The modulus is even (and hence, in particular, possibly zero).
    InvalidModulus,
    /// The destination buffer cannot hold a value mod `n`.
    InsufficientResultSpace,
}
685
686pub fn ct_to_montgomery_form_direct_mp<TT: MpMutUInt, NT: MpUIntCommon>(
687    t: &mut TT,
688    n: &NT,
689) -> Result<(), CtMontgomeryTransformationError> {
690    if n.test_bit(0).unwrap() == 0 {
691        return Err(CtMontgomeryTransformationError::InvalidModulus);
692    }
693    if !n.len_is_compatible_with(t.len()) {
694        return Err(CtMontgomeryTransformationError::InsufficientResultSpace);
695    }
696    debug_assert!(t.nlimbs() >= n.nlimbs());
697    let radix_shift_len = ct_montgomery_radix_shift_len(n.len());
698    let n = CtMpDivisor::new(n, None).map_err(|e| match e {
699        CtMpDivisorError::DivisorIsZero => {
700            // n had been checked for being odd above, so should be unreachable, but play
701            // safe.
702            debug_assert!(false);
703            CtMontgomeryTransformationError::InvalidModulus
704        }
705    })?;
706    ct_mod_lshifted_mp_mp(t, t.len(), radix_shift_len, &n).map_err(|e| match e {
707        CtModLshiftedMpMpError::InsufficientRemainderSpace => {
708            // The result space had been checked at function entry already, but play safe.
709            debug_assert!(false);
710            CtMontgomeryTransformationError::InsufficientResultSpace
711        }
712    })?;
713    Ok(())
714}
715
716pub fn ct_montgomery_radix2_mod_n_mp<RX2T: MpMutUIntSlice, NT: MpUIntCommon>(
717    radix2_mod_n_out: &mut RX2T,
718    n: &NT,
719) -> Result<(), CtMontgomeryTransformationError> {
720    if n.test_bit(0).unwrap() == 0 {
721        return Err(CtMontgomeryTransformationError::InvalidModulus);
722    }
723    if !n.len_is_compatible_with(radix2_mod_n_out.len()) {
724        return Err(CtMontgomeryTransformationError::InsufficientResultSpace);
725    }
726
727    radix2_mod_n_out.clear_bytes_above(n.len());
728    let mut radix2_mod_n_out = radix2_mod_n_out.shrink_to(n.len());
729    debug_assert_eq!(radix2_mod_n_out.nlimbs(), n.nlimbs());
730
731    let radix_shift_len = ct_montgomery_radix_shift_len(n.len());
732    let n = CtMpDivisor::new(n, None).map_err(|e| match e {
733        CtMpDivisorError::DivisorIsZero => {
734            // n had been checked for being odd above, so should be unreachable, but play
735            // safe.
736            debug_assert!(false);
737            CtMontgomeryTransformationError::InvalidModulus
738        }
739    })?;
740    ct_mod_pow2_mp::<_, _>(2 * 8 * radix_shift_len, &mut radix2_mod_n_out, &n).map_err(
741        |e| match e {
742            CtModPow2MpError::InsufficientRemainderSpace => {
743                // The result space had been checked at function entry already, but play safe.
744                debug_assert!(false);
745                CtMontgomeryTransformationError::InsufficientResultSpace
746            }
747        },
748    )?;
749    Ok(())
750}
751
/// Error information returned by [`ct_to_montgomery_form_mp`].
#[derive(Debug)]
pub enum CtToMontgomeryFormMpError {
    /// The modulus is invalid, e.g. even.
    InvalidModulus,
    /// The result buffer cannot accommodate a value of the modulus' length.
    InsufficientResultSpace,
    /// The length of the value to transform is inconsistent with the
    /// modulus' length.
    InconsistentOperandLengths,
    /// The length of the radix squared mod n operand is inconsistent with
    /// the modulus' length.
    // NOTE(review): the variant name misspells "Length"; kept as-is because
    // renaming would break the public API.
    InconsistentRadix2ModNLenth,
}
759
760pub fn ct_to_montgomery_form_mp<
761    RT: MpMutUIntSlice,
762    TT: MpUIntCommon,
763    NT: MpUIntCommon,
764    RX2T: MpUIntCommon,
765>(
766    result: &mut RT,
767    t: &TT,
768    n: &NT,
769    neg_n0_inv_mod_l: LimbType,
770    radix2_mod_n: &RX2T,
771) -> Result<(), CtToMontgomeryFormMpError> {
772    // The Montgomery multiplication will do all error checking needed. However, to
773    // disambiguate which of the two factors has a length inconsistent with n,
774    // if any, check that here.
775    if !t.len_is_compatible_with(n.len()) {
776        return Err(CtToMontgomeryFormMpError::InconsistentOperandLengths);
777    }
778    if !radix2_mod_n.len_is_compatible_with(n.len()) {
779        return Err(CtToMontgomeryFormMpError::InconsistentRadix2ModNLenth);
780    }
781    debug_assert!(ct_lt_mp_mp(t, n).unwrap() != 0);
782    debug_assert!(ct_lt_mp_mp(radix2_mod_n, n).unwrap() != 0);
783
784    // All input arguments have been validated above, just unwrap().
785    ct_montgomery_mul_mod_mp_mp(result, t, radix2_mod_n, n, neg_n0_inv_mod_l).map_err(
786        |e| match e {
787            CtMontgomeryMulModMpMpError::InsufficientResultSpace => {
788                CtToMontgomeryFormMpError::InsufficientResultSpace
789            }
790            CtMontgomeryMulModMpMpError::InvalidModulus => {
791                CtToMontgomeryFormMpError::InvalidModulus
792            }
793            CtMontgomeryMulModMpMpError::InconsistentOperandLengths => {
794                // The multiplication's factors have been validated above, but play safe.
795                CtToMontgomeryFormMpError::InconsistentOperandLengths
796            }
797        },
798    )?;
799
800    Ok(())
801}
802
// Generic test body for ct_to_montgomery_form_mp(): cross-checks the
// multiplication-based Montgomery transformation against the direct,
// division-based one (ct_to_montgomery_form_direct_mp()) over a range of
// pseudo-random moduli and operands, for all combinations of the three
// operand buffer types.
#[cfg(test)]
fn test_ct_to_montgomery_form_mp<TT: MpMutUIntSlice, NT: MpMutUIntSlice, RX2T: MpMutUIntSlice>() {
    use super::cmp_impl::ct_eq_mp_mp;
    use super::div_impl::ct_mod_mp_mp;
    use super::limb::LIMB_BYTES;

    // Generate a spread of test moduli from products of Mersenne primes with
    // varying factors.
    for i in 0..16 {
        const MERSENNE_PRIME_13: LimbType = 8191 as LimbType;
        let n_high = MERSENNE_PRIME_13.wrapping_mul((65543 as LimbType).wrapping_mul(i));
        for j in 0..16 {
            const MERSENNE_PRIME_17: LimbType = 131071 as LimbType;
            let n_low = MERSENNE_PRIME_17.wrapping_mul((4095 as LimbType).wrapping_mul(j));
            // Force n_low odd.
            let n_low = n_low | 1;
            let mut n_buf = tst_mk_mp_backing_vec!(NT, 2 * LIMB_BYTES);
            let mut n = NT::from_slice(n_buf.as_mut_slice()).unwrap();
            n.store_l(0, n_low);
            n.store_l(1, n_high);
            drop(n);
            // Exercise both an aligned and, where the buffer type supports
            // it, an unaligned modulus length.
            let n_lengths = if !NT::SUPPORTS_UNALIGNED_BUFFER_LENGTHS {
                [LIMB_BYTES, 2 * LIMB_BYTES]
            } else {
                [2 * LIMB_BYTES - 1, 2 * LIMB_BYTES]
            };

            for n_len in n_lengths {
                // Truncate the modulus to the current test length.
                let mut n_buf = n_buf.clone();
                let mut n = NT::from_slice(n_buf.as_mut_slice()).unwrap();
                n.clear_bytes_above(n_len);
                let n = n.shrink_to(n_len);
                let neg_n0_inv = ct_montgomery_neg_n0_inv_mod_l_mp(&n).unwrap();

                // Precompute radix^2 mod n for the transformation under test.
                let mut radix2_mod_n = tst_mk_mp_backing_vec!(RX2T, n_len);
                let mut radix2_mod_n = RX2T::from_slice(radix2_mod_n.as_mut_slice()).unwrap();
                ct_montgomery_radix2_mod_n_mp(&mut radix2_mod_n, &n).unwrap();

                // Generate a spread of test operands, analogous to the moduli
                // above.
                for k in 0..4 {
                    let a_high =
                        MERSENNE_PRIME_17.wrapping_mul((16383 as LimbType).wrapping_mul(k));
                    for l in 0..4 {
                        let a_low =
                            MERSENNE_PRIME_13.wrapping_mul((262175 as LimbType).wrapping_mul(l));
                        let mut a = tst_mk_mp_backing_vec!(TT, 2 * LIMB_BYTES);
                        let mut a = TT::from_slice(a.as_mut_slice()).unwrap();
                        a.store_l(0, a_low);
                        a.store_l(1, a_high);
                        // All montgomery operations are defined mod n, compute a mod n
                        ct_mod_mp_mp(None, &mut a, &CtMpDivisor::new(&n, None).unwrap());
                        let mut a = a.shrink_to(n_len);

                        let mut result = tst_mk_mp_backing_vec!(TT, n_len);
                        let mut result = TT::from_slice(result.as_mut_slice()).unwrap();
                        ct_to_montgomery_form_mp(&mut result, &a, &n, neg_n0_inv, &radix2_mod_n)
                            .unwrap();

                        // The direct transformation (in-place on a) must
                        // agree with the multiplication-based one.
                        ct_to_montgomery_form_direct_mp(&mut a, &n).unwrap();
                        assert_eq!(ct_eq_mp_mp(&result, &a).unwrap(), 1);
                    }
                }
            }
        }
    }
}
866
#[test]
fn test_ct_to_montgomery_form_be_be_be() {
    // Instantiate the generic test with big-endian byte buffers for all
    // three operand types.
    use super::limbs_buffer::MpMutBigEndianUIntByteSlice as BeSlice;
    test_ct_to_montgomery_form_mp::<BeSlice, BeSlice, BeSlice>()
}
876
#[test]
fn test_ct_to_montgomery_form_le_le_le() {
    // Instantiate the generic test with little-endian byte buffers for all
    // three operand types.
    use super::limbs_buffer::MpMutLittleEndianUIntByteSlice as LeSlice;
    test_ct_to_montgomery_form_mp::<LeSlice, LeSlice, LeSlice>()
}
886
#[test]
fn test_ct_to_montgomery_form_ne_ne_ne() {
    // Instantiate the generic test with native-endian limb buffers for all
    // three operand types.
    use super::limbs_buffer::MpMutNativeEndianUIntLimbsSlice as NeSlice;
    test_ct_to_montgomery_form_mp::<NeSlice, NeSlice, NeSlice>()
}
896
// Common square-and-multiply core for the Montgomery exponentiation entry
// points below, processing the low `exponent_nbits` bits of `exponent`
// MSB-first. Each iteration performs both the squaring and the
// multiplication by op0 unconditionally and "undoes" the latter via a
// conditional copy for zero exponent bits, so the sequence of multiplication
// operations does not depend on the exponent's value.
//
// op0 and the result are handled in Montgomery form.
// NOTE(review): op0 is presumably expected in Montgomery form — the callers
// below either transform it first or seed the accumulator with a Montgomery
// one; confirm against any new caller.
//
// result must have been initialized with a one in Montgomery form before the
// call.
fn _ct_montogmery_exp_mod_mp_mp<
    RT: MpMutUIntSlice,
    T0: MpUIntCommon,
    NT: MpUIntCommon,
    ET: MpUIntCommon,
>(
    result: &mut RT,
    op0: &T0,
    n: &NT,
    neg_n0_inv_mod_l: LimbType,
    exponent: &ET,
    exponent_nbits: usize,
    scratch: &mut [LimbType],
) {
    debug_assert_eq!(result.nlimbs(), n.nlimbs());

    // The scratch buffer must provide at least n's limb count; it is used to
    // hold the intermediate squaring result.
    let n_nlimbs = MpMutNativeEndianUIntLimbsSlice::nlimbs_for_len(n.len());
    debug_assert!(scratch.len() >= n_nlimbs);
    let mut scratch = MpMutNativeEndianUIntLimbsSlice::from_limbs(scratch);
    scratch.clear_bytes_above(n.len());
    let mut scratch = scratch.shrink_to(n.len());

    // Cap the bit count so that test_bit() below never indexes past the
    // exponent's width.
    let exponent_nbits = exponent_nbits.min(8 * exponent.len());
    for i in 0..exponent_nbits {
        // Input arguments have been validated/setup by callers, just unwrap() the
        // result.
        // Square: scratch = result^2 (in Montgomery form).
        ct_montgomery_mul_mod_mp_mp(&mut scratch, result, result, n, neg_n0_inv_mod_l).unwrap();
        // Multiply: result = scratch * op0 (in Montgomery form).
        ct_montgomery_mul_mod_mp_mp(result, &scratch, op0, n, neg_n0_inv_mod_l).unwrap();
        // If the current exponent bit is zero, "undo" the latter multiplication.
        result.copy_from_cond(&scratch, !exponent.test_bit(exponent_nbits - i - 1));
    }
}
931
/// Error information returned by [`ct_montogmery_exp_mod_odd_mp_mp`].
#[derive(Debug)]
pub enum CtMontgomeryExpModOddMpMpError {
    /// The modulus is invalid, e.g. even.
    InvalidModulus,
    /// The result buffer cannot accommodate a value of the modulus' length.
    InsufficientResultSpace,
    /// The scratch buffer is smaller than the modulus' limb count.
    InsufficientScratchSpace,
    /// The base operand's length is inconsistent with the modulus' length.
    InconsistentOperandLengths,
    /// The length of the radix mod n operand is inconsistent with the
    /// modulus' length.
    // NOTE(review): the variant name misspells "Inconsistent...Length";
    // kept as-is because renaming would break the public API.
    InconsistendRadixModNLengh,
}
940
941#[allow(clippy::too_many_arguments)]
942pub fn ct_montogmery_exp_mod_odd_mp_mp<
943    RT: MpMutUIntSlice,
944    T0: MpUIntCommon,
945    NT: MpUIntCommon,
946    RXT: MpUIntCommon,
947    ET: MpUIntCommon,
948>(
949    result: &mut RT,
950    op0: &T0,
951    n: &NT,
952    neg_n0_inv_mod_l: LimbType,
953    radix_mod_n: &RXT,
954    exponent: &ET,
955    exponent_nbits: usize,
956    scratch: &mut [LimbType],
957) -> Result<(), CtMontgomeryExpModOddMpMpError> {
958    if n.test_bit(0).unwrap() == 0 {
959        return Err(CtMontgomeryExpModOddMpMpError::InvalidModulus);
960    }
961    if !n.len_is_compatible_with(result.len()) {
962        return Err(CtMontgomeryExpModOddMpMpError::InsufficientResultSpace);
963    }
964    if scratch.len() < MpMutNativeEndianUIntLimbsSlice::nlimbs_for_len(n.len()) {
965        return Err(CtMontgomeryExpModOddMpMpError::InsufficientScratchSpace);
966    }
967    if !op0.len_is_compatible_with(n.len()) {
968        return Err(CtMontgomeryExpModOddMpMpError::InconsistentOperandLengths);
969    }
970    debug_assert!(ct_lt_mp_mp(op0, n).unwrap() != 0);
971    if !radix_mod_n.len_is_compatible_with(n.len()) {
972        return Err(CtMontgomeryExpModOddMpMpError::InconsistendRadixModNLengh);
973    }
974    debug_assert!(ct_lt_mp_mp(radix_mod_n, n).unwrap() != 0);
975
976    // Initialize the result with a one in Montgomery form.
977    result.clear_bytes_above(n.len());
978    let mut result = result.shrink_to(n.len());
979    debug_assert_eq!(result.nlimbs(), n.nlimbs());
980    result.copy_from(radix_mod_n);
981
982    _ct_montogmery_exp_mod_mp_mp(
983        &mut result,
984        op0,
985        n,
986        neg_n0_inv_mod_l,
987        exponent,
988        exponent_nbits,
989        scratch,
990    );
991    Ok(())
992}
993
/// Error information returned by [`ct_exp_mod_odd_mp_mp`].
#[derive(Debug)]
pub enum CtExpModOddMpMpError {
    /// The modulus is invalid, e.g. even.
    InvalidModulus,
    /// The result buffer cannot accommodate a value of the modulus' length.
    InsufficientResultSpace,
    /// The scratch buffer is smaller than the modulus' limb count.
    InsufficientScratchSpace,
    /// The base operand's length is inconsistent with the modulus' length.
    InconsistentOperandLengths,
}
1001
/// Compute a power of `op0` modulo the odd modulus `n` into `result`.
///
/// Convenience frontend around the Montgomery exponentiation which derives
/// all Montgomery auxiliary quantities (the negated modulus limb inverse and
/// the radix squared mod n) itself and transforms the result back from
/// Montgomery form before returning.
///
/// Note that `op0` gets clobbered: it is overwritten in place with its
/// Montgomery form in the course of the computation, which saves a scratch
/// buffer.
///
/// # Arguments
/// * `result` - receives the exponentiation result (in plain, non-Montgomery
///   form), must be able to hold a value of `n`'s length.
/// * `op0` - the base, must be `< n` and its buffer must match `n`'s length
///   in both directions; gets overwritten, see above.
/// * `n` - the modulus, must be odd.
/// * `exponent`/`exponent_nbits` - the exponent and the number of its bits
///   to process.
/// * `scratch` - scratch space of at least `n`'s limb count.
pub fn ct_exp_mod_odd_mp_mp<
    RT: MpMutUIntSlice,
    T0: MpMutUInt,
    NT: MpUIntCommon,
    ET: MpUIntCommon,
>(
    result: &mut RT,
    op0: &mut T0,
    n: &NT,
    exponent: &ET,
    exponent_nbits: usize,
    scratch: &mut [LimbType],
) -> Result<(), CtExpModOddMpMpError> {
    if !n.len_is_compatible_with(result.len()) {
        return Err(CtExpModOddMpMpError::InsufficientResultSpace);
    }
    if !op0.len_is_compatible_with(n.len()) {
        return Err(CtExpModOddMpMpError::InconsistentOperandLengths);
    }
    debug_assert!(ct_lt_mp_mp(op0, n).unwrap() != 0);
    if !n.len_is_compatible_with(op0.len()) {
        // op0 will get transformed in-place into Montgomery form. So the
        // backing byte slice must be large enough.
        return Err(CtExpModOddMpMpError::InconsistentOperandLengths);
    }
    debug_assert_eq!(op0.nlimbs(), n.nlimbs());
    if scratch.len() < MpMutNativeEndianUIntLimbsSlice::nlimbs_for_len(n.len()) {
        return Err(CtExpModOddMpMpError::InsufficientScratchSpace);
    }

    // This checks the modulus for validity as a side-effect.
    let neg_n0_inv_mod_l = ct_montgomery_neg_n0_inv_mod_l_mp(n).map_err(|e| match e {
        CtMontgomeryNegN0InvModLMpError::InvalidModulus => CtExpModOddMpMpError::InvalidModulus,
    })?;

    // Shrink result[] to the length of n. It will be used to temporarily hold the
    // radix^2 mod n and ct_to_montgomery_form_mp() below would complain if its
    // length is unexpectedly large.
    result.clear_bytes_above(n.len());
    let mut result = result.shrink_to(n.len());
    debug_assert_eq!(result.nlimbs(), n.nlimbs());

    // The radix squared mod n gets into result[], it will be reduced
    // later on to a one in Montgomery form.
    ct_montgomery_radix2_mod_n_mp(&mut result, n).unwrap();

    // Transform op0 into Montgomery form, the function argument will get
    // overwritten to save an extra scratch buffer.
    let mut mg_op0 = MpMutNativeEndianUIntLimbsSlice::from_limbs(scratch);
    mg_op0.clear_bytes_above(n.len());
    let mut mg_op0 = mg_op0.shrink_to(n.len());
    // result[] still holds radix^2 mod n here, as needed for the
    // transformation.
    ct_to_montgomery_form_mp(&mut mg_op0, op0, n, neg_n0_inv_mod_l, &result).unwrap();
    // Copy back so that scratch can be reused by the exponentiation below.
    op0.copy_from(&mg_op0);

    // Reduce the radix squared mod n in result[] to the radix mod n,
    // i.e. to a one in Montgomery form.
    ct_montgomery_redc_mp(&mut result, n, neg_n0_inv_mod_l).unwrap();

    // Do the Montgomery exponentiation.
    _ct_montogmery_exp_mod_mp_mp(
        &mut result,
        op0,
        n,
        neg_n0_inv_mod_l,
        exponent,
        exponent_nbits,
        scratch,
    );

    // And transform the result back from Montgomery form.
    ct_montgomery_redc_mp(&mut result, n, neg_n0_inv_mod_l).unwrap();
    Ok(())
}
1075
// Generic test body for ct_exp_mod_odd_mp_mp(): cross-checks the Montgomery
// exponentiation against a naive square-and-multiply with explicit modular
// reductions, over a range of modulus lengths, base operands and exponents,
// for all combinations of the four operand buffer types.
#[cfg(test)]
fn test_ct_exp_mod_odd_mp_mp<
    RT: MpMutUIntSlice,
    T0: MpMutUIntSlice,
    NT: MpMutUIntSlice,
    ET: MpMutUIntSlice,
>() {
    extern crate alloc;
    use super::limb::LIMB_BYTES;
    use super::mul_impl::ct_mul_trunc_mp_l;
    use super::shift_impl::ct_lshift_mp;
    use alloc::vec;

    // Run one exponentiation and verify it against the naive reference
    // computation.
    fn test_one<'a, RT: MpMutUIntSlice, T0: MpMutUIntSlice, NT: MpUIntCommon, ET: MpUIntCommon>(
        op0: &T0,
        n: &NT,
        exponent: &'a ET,
    ) {
        use super::cmp_impl::ct_eq_mp_mp;
        use super::div_impl::ct_mod_mp_mp;
        use super::mul_impl::{ct_mul_trunc_mp_mp, ct_square_trunc_mp};

        let n_len = n.len();

        // Reduce op0 mod n up front -- the exponentiation expects op0 < n.
        let mut op0_mod_n = tst_mk_mp_backing_vec!(T0, n_len.max(op0.len()));
        let mut op0_mod_n = T0::from_slice(&mut op0_mod_n).unwrap();
        op0_mod_n.copy_from(op0);
        ct_mod_mp_mp(None, &mut op0_mod_n, &CtMpDivisor::new(n, None).unwrap());

        // ct_exp_mod_odd_mp_mp() clobbers its op0 argument; work on a copy.
        let mut op0_scratch = tst_mk_mp_backing_vec!(T0, n_len);
        let mut op0_scratch = T0::from_slice(&mut op0_scratch).unwrap();
        op0_scratch.copy_from(&op0_mod_n);
        let mut result = tst_mk_mp_backing_vec!(RT, n_len);
        let mut result = RT::from_slice(&mut result).unwrap();
        let mut scratch =
            vec![0 as LimbType; MpMutNativeEndianUIntLimbsSlice::nlimbs_for_len(n_len)];
        ct_exp_mod_odd_mp_mp(
            &mut result,
            &mut op0_scratch,
            n,
            exponent,
            8 * exponent.len(),
            &mut scratch,
        )
        .unwrap();

        // Compute the expected value using repeated multiplications/squarings and
        // modular reductions.
        let mut expected = tst_mk_mp_backing_vec!(RT, 2 * n_len);
        let mut expected = RT::from_slice(&mut expected).unwrap();
        expected.clear_bytes_above(0);
        expected.store_l(0, 1);
        for i in 0..8 * exponent.len() {
            ct_square_trunc_mp(&mut expected, n_len);
            ct_mod_mp_mp(None, &mut expected, &CtMpDivisor::new(n, None).unwrap());
            if exponent.test_bit(8 * exponent.len() - i - 1).unwrap() != 0 {
                ct_mul_trunc_mp_mp(&mut expected, n_len, &op0_mod_n);
                ct_mod_mp_mp(None, &mut expected, &CtMpDivisor::new(n, None).unwrap());
            }
        }
        assert_ne!(ct_eq_mp_mp(&result, &expected).unwrap(), 0);
    }

    // Exponent test vectors: zero, one, two and an 0xef-byte pattern, each
    // one limb plus one byte wide.
    let exponent_len = LIMB_BYTES + 1;
    let mut e0_buf = tst_mk_mp_backing_vec!(ET, exponent_len);
    let mut e1_buf = tst_mk_mp_backing_vec!(ET, exponent_len);
    let mut e1 = ET::from_slice(&mut e1_buf).unwrap();
    e1.store_l(0, 1);
    drop(e1);
    let mut e2_buf = tst_mk_mp_backing_vec!(ET, exponent_len);
    let mut e2 = ET::from_slice(&mut e2_buf).unwrap();
    e2.store_l(0, 2);
    drop(e2);
    let mut ef_buf = tst_mk_mp_backing_vec!(ET, exponent_len);
    ef_buf.fill(0xefu8.into());

    for n_len in [1, LIMB_BYTES + 1, 2 * LIMB_BYTES - 1, 3 * LIMB_BYTES] {
        // Build an odd modulus filling the full n_len by repeatedly
        // multiplying by the (odd) prime 251 until the most significant byte
        // is non-zero.
        let mut n = tst_mk_mp_backing_vec!(NT, n_len);
        let mut n = NT::from_slice(&mut n).unwrap();
        n.store_l(0, 1);
        while n.load_l((n_len - 1) / LIMB_BYTES) >> (8 * ((n_len - 1) % LIMB_BYTES)) == 0 {
            ct_mul_trunc_mp_l(&mut n, n_len, 251);
        }

        for op0_len in 1..n_len {
            // Build a base operand of op0_len significant bytes from powers
            // of the prime 241, shifted up to exercise high positions.
            let mut op0 = tst_mk_mp_backing_vec!(T0, n_len);
            let mut op0 = T0::from_slice(&mut op0).unwrap();
            op0.store_l(0, 1);
            for _ in 0..op0_len {
                ct_mul_trunc_mp_l(&mut op0, n_len, 241);
            }
            ct_lshift_mp(&mut op0, 8 * (n_len - op0_len));
            for e_buf in [&mut e0_buf, &mut e1_buf, &mut e2_buf, &mut ef_buf] {
                let e = ET::from_slice(e_buf).unwrap();
                test_one::<RT, _, _, _>(&op0, &n, &e);
            }
        }
    }
}
1175
#[test]
fn test_ct_exp_mod_odd_be_be_be_be() {
    // Instantiate the generic test with big-endian byte buffers for all
    // four operand types.
    use super::limbs_buffer::MpMutBigEndianUIntByteSlice as BeSlice;
    test_ct_exp_mod_odd_mp_mp::<BeSlice, BeSlice, BeSlice, BeSlice>()
}
1186
#[test]
fn test_ct_exp_mod_odd_le_le_le_le() {
    // Instantiate the generic test with little-endian byte buffers for all
    // four operand types.
    use super::limbs_buffer::MpMutLittleEndianUIntByteSlice as LeSlice;
    test_ct_exp_mod_odd_mp_mp::<LeSlice, LeSlice, LeSlice, LeSlice>()
}
1197
#[test]
fn test_ct_exp_mod_odd_ne_ne_ne_ne() {
    // Instantiate the generic test with native-endian limb buffers for all
    // four operand types.
    use super::limbs_buffer::MpMutNativeEndianUIntLimbsSlice as NeSlice;
    test_ct_exp_mod_odd_mp_mp::<NeSlice, NeSlice, NeSlice, NeSlice>()
}