modmath/
exp.rs

1use super::mul::{basic_mod_mul, constrained_mod_mul, strict_mod_mul};
2
3
4/// # Modular Exponentiation (Basic)
5/// Simple version that operates on values and copies them. Requires
6/// `WrappingAdd` and `WrappingSub` traits to be implemented.
7pub fn basic_mod_exp<T>(mut base: T, exponent: T, modulus: T) -> T
8where
9    T: PartialOrd
10        + num_traits::One
11        + num_traits::Zero
12        + core::ops::BitAnd<Output = T>
13        + core::ops::Rem<Output = T>
14        + core::ops::Shr<usize, Output = T>
15        + num_traits::ops::wrapping::WrappingAdd
16        + num_traits::ops::wrapping::WrappingSub
17        + core::ops::ShrAssign<usize>
18        + core::ops::RemAssign<T>
19        + Copy,
20{
21    let two = T::one() + T::one();
22    let mut result = T::one();
23    let mut exp = exponent;
24
25    base %= modulus; // reduce base initially
26
27    while exp > T::zero() {
28        // If the exponent is odd, multiply the result by base
29        if exp % two == T::one() {
30            result = basic_mod_mul(result, base, modulus);
31        }
32        // Right shift the exponent (divide by 2)
33        exp >>= 1;
34
35        // Only square base if exp > 0 (avoid unnecessary squaring in final step)
36        if exp > T::zero() {
37            // Square the base using modular multiplication
38            base = basic_mod_mul(base, base, modulus);
39        }
40    }
41    result
42}
43
44/// # Modular Exponentiation (Constrained)
45/// Version that works with references, requires `WrappingAdd` and
46/// `WrappingSub` traits to be implemented.
47pub fn constrained_mod_exp<T>(mut base: T, exponent: &T, modulus: &T) -> T
48where
49    T: PartialOrd
50        + num_traits::One
51        + num_traits::Zero
52        + num_traits::ops::wrapping::WrappingAdd
53        + num_traits::ops::wrapping::WrappingSub
54        + core::ops::ShrAssign<usize>
55        + core::ops::Shr<usize, Output = T>,
56    for<'a> T: core::ops::RemAssign<&'a T>
57        + core::ops::DivAssign<&'a T>
58        + core::ops::Rem<&'a T, Output = T>,
59    for<'a> &'a T: core::ops::Rem<&'a T, Output = T> + core::ops::BitAnd<Output = T>,
60{
61    base.rem_assign(modulus);
62    let mut result = T::one();
63    let mut exp = T::zero().wrapping_add(exponent);
64    let two = T::one().wrapping_add(&T::one());
65    while exp > T::zero() {
66        if &exp % &two == T::one() {
67            result = constrained_mod_mul(result, &base, modulus);
68        }
69        exp >>= 1;
70        if exp > T::zero() {
71            let tmp_base = T::zero().wrapping_add(&base);
72            base = constrained_mod_mul(base, &tmp_base, modulus);
73        }
74    }
75    result
76}
77
78/// # Modular Exponentiation (Strict)
79/// Most constrained version that works with references. Requires
80/// `OverflowingAdd` and `OverflowingSub` traits to be implemented, and
81/// all multiplication contraints as well.
82pub fn strict_mod_exp<T>(mut base: T, exponent: &T, modulus: &T) -> T
83where
84    T: PartialOrd
85        + num_traits::One
86        + num_traits::Zero
87        + num_traits::ops::overflowing::OverflowingAdd
88        + num_traits::ops::overflowing::OverflowingSub
89        + core::ops::Shr<usize, Output = T>,
90    for<'a> T: core::ops::RemAssign<&'a T>
91        + core::ops::DivAssign<&'a T>
92        + core::ops::ShrAssign<usize>
93        + core::ops::Rem<&'a T, Output = T>,
94    for<'a> &'a T: core::ops::Rem<&'a T, Output = T> + core::ops::BitAnd<Output = T>,
95{
96    let two = T::one().overflowing_add(&T::one()).0;
97    let mut result = T::one();
98    base.rem_assign(modulus);
99    let mut exp = T::zero().overflowing_add(exponent).0;
100
101    while exp > T::zero() {
102        if &exp % &two == T::one() {
103            result = strict_mod_mul(result, &base, modulus);
104        }
105        exp >>= 1;
106
107        if exp > T::zero() {
108            let tmp_base = T::zero().overflowing_add(&base).0;
109            base = strict_mod_mul(base, &tmp_base, modulus);
110        }
111    }
112    result
113}
114
#[cfg(test)]
// Emits a local `mod_exp(base, &exponent, &modulus)` shim around `$target` so
// that the shared test bodies can call every flavour the same way.
//   * `by_val` - `$target` takes exponent and modulus by value; the shim
//     dereferences them.
//   * `by_ref` - `$target` takes exponent and modulus by reference; the shim
//     forwards the references untouched.
macro_rules! select_mod_exp {
    ($target:path, $num:ty, by_val) => {
        fn mod_exp(base: $num, exponent: &$num, modulus: &$num) -> $num {
            $target(base, *exponent, *modulus)
        }
    };
    ($target:path, $num:ty, by_ref) => {
        fn mod_exp(base: $num, exponent: &$num, modulus: &$num) -> $num {
            $target(base, exponent, modulus)
        }
    };
}
128
#[cfg(test)]
// Generates the `u64` test suite for one modular-exponentiation flavour.
// `$mod_exp` is the function under test, `$num` the integer type handed to
// `select_mod_exp!`, and `$passing` is `by_ref` or `by_val`.
macro_rules! generate_mod_exp_tests_block_64 {
    ($mod_exp:path, $num:ty, $passing:tt) => {
        select_mod_exp!($mod_exp, $num, $passing);

        #[test]
        fn test_basic_small_values() {
            // 2^3 = 8, and 8 % 5 = 3
            assert_eq!(mod_exp(2_u64, &3_u64, &5_u64), 3_u64);
            // 5^0 % 7 = 1
            assert_eq!(mod_exp(5_u64, &0_u64, &7_u64), 1_u64);
        }

        #[test]
        fn test_basic_base_or_exponent_1() {
            // 1^10 % 7 = 1
            assert_eq!(mod_exp(1_u64, &10_u64, &7_u64), 1_u64);
            // 7^1 % 13 = 7
            assert_eq!(mod_exp(7_u64, &1_u64, &13_u64), 7_u64);
        }

        #[test]
        fn test_identity_modulus_of_1() {
            // Every integer is congruent to 0 modulo 1
            assert_eq!(mod_exp(10_u64, &10_u64, &1_u64), 0_u64);
        }

        #[test]
        fn test_identity_exponent_of_0() {
            // 5^0 % 9 = 1
            assert_eq!(mod_exp(5_u64, &0_u64, &9_u64), 1_u64);
        }

        #[test]
        fn test_identity_zero_to_the_zero() {
            // The implementations under test follow the convention 0^0 = 1.
            assert_eq!(mod_exp(0_u64, &0_u64, &7_u64), 1_u64);
        }

        #[test]
        fn test_edge_max_u64_values() {
            // (2^64 - 1)^2 % (2^64 - 1) = 0
            assert_eq!(mod_exp(u64::MAX, &2_u64, &u64::MAX), 0_u64);
            // Squaring the largest value must not overflow before reduction
            let prime = 1_000_000_007_u64;
            assert_eq!(mod_exp(u64::MAX, &2_u64, &prime), 114_944_269_u64);
        }

        #[test]
        fn test_edge_base_of_zero() {
            // 0^10 % 7 = 0
            assert_eq!(mod_exp(0_u64, &10_u64, &7_u64), 0_u64);
        }

        #[test]
        fn test_prime_modulus() {
            // 7^13 % 19 = 7
            assert_eq!(mod_exp(7_u64, &13_u64, &19_u64), 7_u64);
            // 3^13 % 17 = 12
            assert_eq!(mod_exp(3_u64, &13_u64, &17_u64), 12_u64);
        }

        #[test]
        fn test_large_exponent() {
            // Only feasible with square-and-multiply: 7^(2^20) % 13 = 9
            assert_eq!(mod_exp(7_u64, &(1 << 20), &13_u64), 9_u64);
        }

        #[test]
        fn test_overflow_handling() {
            // Base and exponent both far larger than the modulus
            let two_pow_32 = 2_u64.pow(32);
            assert_eq!(mod_exp(two_pow_32, &two_pow_32, &97_u64), 35_u64);

            let two_pow_63 = 2_u64.pow(63);
            assert_eq!(
                mod_exp(two_pow_63, &two_pow_63, &1_000_000_007_u64),
                719_537_220_u64
            );
        }

        #[test]
        fn test_coprime_values() {
            let outcome = mod_exp(123_456_789_u64, &987_654_321_u64, &1_000_000_007_u64);
            assert_eq!(outcome, 652_541_198_u64);
        }
    };
}
207
#[cfg(test)]
// Generates the `u8` test suite for one modular-exponentiation flavour.
// `$mod_exp` is the function under test, `$num` the integer type handed to
// `select_mod_exp!`, and `$passing` is `by_ref` or `by_val`.
macro_rules! generate_mod_exp_tests_block_8 {
    ($mod_exp:path, $num:ty, $passing:tt) => {
        select_mod_exp!($mod_exp, $num, $passing);

        #[test]
        fn test_edge_max_u8_values() {
            // u8 counterpart of mod_exp(u64::MAX, 2, u64::MAX)
            // 255^2 % 255 = 0
            assert_eq!(mod_exp(u8::MAX, &2_u8, &u8::MAX), 0_u8);
            // 255^2 % 97 = 35
            assert_eq!(mod_exp(u8::MAX, &2_u8, &97_u8), 35_u8);
        }

        #[test]
        fn test_big_exponent_mod_u8() {
            // 255^2 % 97 = 35
            let outcome = mod_exp(u8::MAX, &2_u8, &97_u8);
            assert_eq!(outcome, 35_u8);
        }

        #[test]
        fn test_overflow_handling_u8() {
            // u8 counterpart of mod_exp(2^32, 2^32, 97): 16^16 % 97 = 61
            let sixteen = 2_u8.pow(4);
            assert_eq!(mod_exp(sixteen, &sixteen, &97_u8), 61_u8);
        }

        #[test]
        fn test_prime_modulus_u8() {
            // u8 counterpart of mod_exp(7_u64, 13_u64, 19_u64): 7^13 % 19 = 7
            assert_eq!(mod_exp(7_u8, &13_u8, &19_u8), 7_u8);
        }
    };
}
#[cfg(test)]
// Generates the `u16` test suite for one modular-exponentiation flavour.
// `$mod_exp` is the function under test, `$num` the integer type handed to
// `select_mod_exp!`, and `$passing` is `by_ref` or `by_val`.
macro_rules! generate_mod_exp_tests_block_16 {
    ($mod_exp:path, $num:ty, $passing:tt) => {
        select_mod_exp!($mod_exp, $num, $passing);

        #[test]
        fn test_edge_max_u16_values() {
            // u16 counterpart of mod_exp(u64::MAX, 2, u64::MAX):
            // 65535^2 % 65535 = 0
            let outcome = mod_exp(u16::MAX, &2_u16, &u16::MAX);
            assert_eq!(outcome, 0_u16);
        }
    };
}
250
#[cfg(test)]
// Generates the `u32` test suite for one modular-exponentiation flavour.
// `$mod_exp` is the function under test, `$num` the integer type handed to
// `select_mod_exp!`, and `$passing` is `by_ref` or `by_val`.
macro_rules! generate_mod_exp_tests_block_32 {
    ($mod_exp:path, $num:ty, $passing:tt) => {
        select_mod_exp!($mod_exp, $num, $passing);

        #[test]
        fn test_edge_max_u32_values() {
            // u32 counterpart of mod_exp(u64::MAX, 2, u64::MAX):
            // 4294967295^2 % 4294967295 = 0
            let outcome = mod_exp(u32::MAX, &2_u32, &u32::MAX);
            assert_eq!(outcome, 0_u32);
        }
    };
}
263
#[cfg(test)]
mod strict_mod_exp_tests {
    // Imported here so each width-specific submodule reaches it as `super::strict_mod_exp`.
    use super::strict_mod_exp;

    // One submodule per primitive width, narrowest first; the strict flavour
    // takes exponent and modulus by reference.
    mod u8_tests {
        generate_mod_exp_tests_block_8!(super::strict_mod_exp, u8, by_ref);
    }

    mod u16_tests {
        generate_mod_exp_tests_block_16!(super::strict_mod_exp, u16, by_ref);
    }

    mod u32_tests {
        generate_mod_exp_tests_block_32!(super::strict_mod_exp, u32, by_ref);
    }

    mod u64_tests {
        generate_mod_exp_tests_block_64!(super::strict_mod_exp, u64, by_ref);
    }
}
280
#[cfg(test)]
mod constrained_mod_exp_tests {
    // Imported here so each width-specific submodule reaches it as `super::constrained_mod_exp`.
    use super::constrained_mod_exp;

    // One submodule per primitive width, narrowest first; the constrained
    // flavour takes exponent and modulus by reference.
    mod u8_tests {
        generate_mod_exp_tests_block_8!(super::constrained_mod_exp, u8, by_ref);
    }

    mod u16_tests {
        generate_mod_exp_tests_block_16!(super::constrained_mod_exp, u16, by_ref);
    }

    mod u32_tests {
        generate_mod_exp_tests_block_32!(super::constrained_mod_exp, u32, by_ref);
    }

    mod u64_tests {
        generate_mod_exp_tests_block_64!(super::constrained_mod_exp, u64, by_ref);
    }
}
298
#[cfg(test)]
mod basic_mod_exp_tests {
    // Imported here so each width-specific submodule reaches it as `super::basic_mod_exp`.
    use super::basic_mod_exp;

    // One submodule per primitive width, narrowest first; the basic flavour
    // takes every argument by value, hence `by_val`.
    mod u8_tests {
        generate_mod_exp_tests_block_8!(super::basic_mod_exp, u8, by_val);
    }

    mod u16_tests {
        generate_mod_exp_tests_block_16!(super::basic_mod_exp, u16, by_val);
    }

    mod u32_tests {
        generate_mod_exp_tests_block_32!(super::basic_mod_exp, u32, by_val);
    }

    mod u64_tests {
        generate_mod_exp_tests_block_64!(super::basic_mod_exp, u64, by_val);
    }
}
316
#[cfg(test)]
// Generates a `<stem>_tests` module that runs the same two modular
// exponentiation checks against a big-integer type named `U256`.
//
// `U256` is either imported directly through `$type_path` or introduced by
// the optional `type ... = ...;` alias. Each of `strict`, `constrained` and
// `basic` is forwarded to `crate::maybe_test!` together with the matching
// assertion (presumably compiled only when the flag is `on` - confirm against
// the definition of `maybe_test!`).
macro_rules! exp_test_module {
    (
        $stem:ident,           // Base name (e.g., "bnum")
        $type_path:path,       // Full path to the type
        $(type $type_def:ty = $type_expr:ty;)? // Optional type definition
        strict: $strict:expr,
        constrained: $constrained:expr,
        basic: $basic:expr,
    ) => {
        paste::paste! {
            mod [<$stem _tests>] {     // This creates e.g., mod bnum_tests
                #[allow(unused_imports)]
                use $type_path;
                $( type $type_def = $type_expr; )?

                #[test]
                // Bindings go unused whenever a flavour is switched off.
                #[allow(unused_variables)]
                fn test_mod_exp_basic() {
                    // pow(5,3,13)
                    let a_val = 5u8;
                    let a = U256::from(a_val);
                    let b = U256::from(3u8);
                    let m = U256::from(13u8);
                    let result = U256::from(8u8);

                    // `a` is passed by value to each call, so it is rebuilt
                    // from `a_val` before the next flavour runs.
                    crate::maybe_test!($strict, assert_eq!(super::strict_mod_exp(a, &b, &m), result));
                    let a = U256::from(a_val);
                    crate::maybe_test!($constrained, assert_eq!(super::constrained_mod_exp(a, &b, &m), result));
                    let a = U256::from(a_val);
                    crate::maybe_test!($basic, assert_eq!(super::basic_mod_exp(a, b, m), result));

                    // pow(123,45,1000)
                    let a_val = 123u8;
                    let a = U256::from(a_val);
                    let b = U256::from(45u8);
                    let m = U256::from(1000u16);
                    let result = U256::from(43u16);

                    crate::maybe_test!($strict, assert_eq!(super::strict_mod_exp(a, &b, &m), result));
                    let a = U256::from(a_val);
                    crate::maybe_test!($constrained, assert_eq!(super::constrained_mod_exp(a, &b, &m), result));
                    let a = U256::from(a_val);
                    crate::maybe_test!($basic, assert_eq!(super::basic_mod_exp(a, b, m), result));
                }
            }
        }
    };
}
370
#[cfg(test)]
mod bnum_exp_tests {
    // The modules generated below call these through `super::`.
    use super::{basic_mod_exp, constrained_mod_exp, strict_mod_exp};

    // Each invocation targets one big-integer library; a flavour is switched
    // `off` when the library lacks a trait that flavour needs.

    // bnum, upstream and patched
    exp_test_module!(
        bnum,
        bnum::types::U256,
        strict: off, // no OverflowingAdd / OverflowingSub impl
        constrained: on,
        basic: on,
    );

    exp_test_module!(
        bnum_patched,
        bnum_patched::types::U256,
        strict: on,
        constrained: on,
        basic: on,
    );

    // crypto-bigint, upstream and patched
    exp_test_module!(
        crypto_bigint,
        crypto_bigint::U256,
        strict: off, // no OverflowingAdd impl
        constrained: off, // no RemAssign impl
        basic: off, // no RemAssign impl
    );

    exp_test_module!(
        crypto_bigint_patched,
        crypto_bigint_patched::U256,
        strict: on,
        constrained: on,
        basic: on,
    );

    // num-bigint, upstream and patched
    exp_test_module!(
        num_bigint,
        num_bigint::BigUint,
        type U256 = num_bigint::BigUint;
        strict: off, // no OverflowingAdd impl
        constrained: off, // no WrappingAdd impl
        basic: off, // heap-allocated, so it cannot be Copy
    );

    exp_test_module!(
        num_bigint_patched,
        num_bigint_patched::BigUint,
        type U256 = num_bigint_patched::BigUint;
        strict: on,
        constrained: on,
        basic: off, // heap-allocated, so it cannot be Copy
    );

    // ibig, upstream and patched
    exp_test_module!(
        ibig,
        ibig::UBig,
        type U256 = ibig::UBig;
        strict: off, // no OverflowingAdd impl
        constrained: off, // no WrappingAdd impl
        basic: off, // heap-allocated, so it cannot be Copy
    );

    exp_test_module!(
        ibig_patched,
        ibig_patched::UBig,
        type U256 = ibig_patched::UBig;
        strict: on,
        constrained: on,
        basic: off, // heap-allocated, so it cannot be Copy
    );

    // fixed-bigint: every flavour is enabled
    exp_test_module!(
        fixed_bigint,
        fixed_bigint::FixedUInt,
        type U256 = fixed_bigint::FixedUInt<u8, 4>;
        strict: on,
        constrained: on,
        basic: on,
    );
}