Skip to main content

num_modular/
proth.rs

1use crate::impl_fixed_monty_ops;
2use crate::reduced::{impl_reduced_binary_pow, impl_reduced_ops};
3use crate::{powm_u32, powm_u64, udouble, umax, ModularUnaryOps, Reducer};
4
5// Proth primes: m = K * 2^N + 1 (K odd, K < 2^N)
6//
// Montgomery REDC with R = 2^BITS (so R > m always).  The Montgomery
// constant N0 = -m⁻¹ mod R is computed at compile time via Newton
// iteration.  Because R > m, for any REDC input t < m·R the result
// (t + q·m)/R is < 2·m, so a single conditional subtraction normalises.
11//
// The product q·m inside REDC (q = t·N0 mod R) is expanded using the
// Proth form:
//   q·(K·2^N + 1) = (q·K)<<N + q
// Since K ≤ 255 (u8), the multiplier K is at most 8 bits wide, so the
// full-width q·m multiply becomes a narrow multiply by K, a shift and an add.
16
17// --- macro for FixedProth32 / FixedProth64 ----------------------------------
18
/// Debug-only primality sanity check for a modulus.
///
/// In debug builds, asserts that `m` has no factor among the small primes
/// 3, 5, 7, 11 and 13, except when `m` equals that prime. Release builds
/// compile the check but never run it.
macro_rules! debug_assert_prime_candidate {
    ($m:expr) => {
        debug_assert!({
            let candidate = $m;
            // A divisor is only acceptable when it *is* the candidate.
            [3, 5, 7, 11, 13]
                .iter()
                .all(|&p| candidate == p || candidate % p != 0)
        })
    };
}
32
/// Implements the inherent items (`MODULUS`, `N0`, `R2`, `reduce`) of a
/// fixed Proth reducer for one word size.
///
/// * `$TypeName` – reducer type, generic over `const N: u8, const K: u8`;
/// * `$T` – single-word type (`u32` or `u64`);
/// * `$D` – double-word type used for the REDC intermediate;
/// * `$neginv_fn` – `const fn($T) -> $T` returning `-m⁻¹ mod 2^BITS`;
/// * `$powm` – `const fn(base, exp, modulus) -> $T` modular exponentiation.
///
/// Only 32- and 64-bit word types are handled by the bound check in `MODULUS`.
macro_rules! impl_fixed_proth_inherent {
    ($TypeName:ident, $T:ty, $D:ty, $neginv_fn:path, $powm:ident) => {
        impl<const N: u8, const K: u8> $TypeName<N, K> {
            /// Compile-time guard: `N` must be strictly less than the type bit-width.
            ///
            /// Associated consts are only evaluated when referenced, so `MODULUS`
            /// reads this constant to make the guard effective.
            const _N_BOUND_CHECK: () = assert!(
                (N as u32) < <$T>::BITS,
                "N must be less than type bit width"
            );

            /// The Proth modulus `K * 2^N + 1`.
            ///
            /// Evaluating it fails at compile time if `N` is not below the
            /// bit-width or if the value exceeds `⌊φ·R⌋` (`φ = (√5 − 1)/2`,
            /// `R = 2^BITS`). An overflowing `K * 2^N` wraps here rather than
            /// failing; that case is detected at runtime by `Reducer::new`.
            pub const MODULUS: $T = {
                let () = Self::_N_BOUND_CHECK;
                let p2n = match (1 as $T).checked_shl(N as u32) {
                    Some(v) => v,
                    None => unreachable!(),
                };
                let m = (K as $T).wrapping_mul(p2n).wrapping_add(1);
                // MODULUS ≤ ⌊φ·R⌋ guarantees `reduce` never overflows the
                // double-word sum t + q·MODULUS for t < MODULUS², because
                // (M/R)² + M/R ≤ 1 exactly when M/R ≤ φ.
                //   ⌊φ·2³²⌋ = 0x9E37_79B9           = 2_654_435_769
                //   ⌊φ·2⁶⁴⌋ = 0x9E37_79B9_7F4A_7C15 = 11_400_714_819_323_198_485
                assert!(
                    m as u128
                        <= match <$T>::BITS {
                            32 => 2_654_435_769u128,
                            64 => 11_400_714_819_323_198_485u128,
                            _ => unreachable!(),
                        },
                    "MODULUS exceeds overflow-free bound; lower N or use FixedMontgomery"
                );
                m
            };

            /// Montgomery constant:  -MODULUS⁻¹ mod 2^BITS
            const N0: $T = $neginv_fn(Self::MODULUS);

            /// R² mod MODULUS  (R = 2^BITS, so R² = 2^{2·BITS})
            const R2: $T = $powm(2, (2 * <$T>::BITS) as $T, Self::MODULUS);

            /// Montgomery reduction: returns `t · R⁻¹ mod MODULUS` in `[0, MODULUS)`.
            ///
            /// The overflow-free guarantee relies on `t < MODULUS²` (e.g. the
            /// product of two reduced residues) together with the compile-time
            /// bound `MODULUS ≤ φ·R`.
            #[must_use]
            #[inline]
            pub fn reduce(&self, t: $D) -> $T {
                // q = t · N0 mod R makes t + q·MODULUS divisible by R.
                let q = (t as $T).wrapping_mul(Self::N0);
                // q·MODULUS = q·(K·2^N + 1) = (q·K) << N + q; only the
                // multiplier K is narrow, so the product is formed in $D.
                let qp = ((q as $D) * (K as $D)) << N;
                let qp = qp.wrapping_add(q as $D);
                let r = (t.wrapping_add(qp) >> <$T>::BITS) as $T;
                if r >= Self::MODULUS {
                    r - Self::MODULUS
                } else {
                    r
                }
            }
        }
    };
}
84
/// A modular reducer for Proth primes `K * 2^N + 1` with 32-bit operands.
///
/// Uses Montgomery form with `R = 2³²`. The const parameters must satisfy
/// `0 < N < 32` and `K` odd (with `K < 2^N`, checked in debug builds only).
/// Additionally `K * 2^N + 1` must fit in a `u32` and must not exceed
/// `φ·2³² ≈ 2_654_435_769` (`φ = (√5 − 1)/2`), which keeps the 64-bit
/// intermediate sum of Montgomery reduction from overflowing.
///
/// # Panics
///
/// Evaluating `MODULUS` (which `Reducer::new` does) fails at compile time when
/// `N ≥ 32` or when the modulus exceeds the bound above. `Reducer::new` panics
/// when the given modulus differs from `MODULUS` or when any other constraint
/// is violated.
///
/// # Example
///
/// ```rust
/// use num_modular::{FixedProth32, Reducer};
///
/// const N: u8 = 4;
/// const K: u8 = 1;
/// let modulus = (K as u32) * (1u32 << N) + 1; // 1*2^4 + 1 = 17
/// let reducer = FixedProth32::<N, K>::new(&modulus);
/// let a = reducer.transform(3);
/// let b = reducer.transform(5);
/// assert_eq!(reducer.residue(reducer.add(&a, &b)), 8);
/// assert_eq!(reducer.residue(reducer.mul(&a, &b)), 15);
/// ```
#[must_use]
#[derive(Debug, Clone, Copy)]
pub struct FixedProth32<const N: u8, const K: u8>;
106
// 32-bit instantiation: `u32` words with `u64` double words for REDC, using
// the crate's 32-bit `neginv` and `powm_u32` to build the compile-time
// constants `N0` and `R2`.
impl_fixed_proth_inherent!(
    FixedProth32,
    u32,
    u64,
    crate::monty::neg_mod_inv::u32::neginv,
    powm_u32
);
114
115impl<const N: u8, const K: u8> Reducer<u32> for FixedProth32<N, K> {
116    #[inline]
117    fn new(m: &u32) -> Self {
118        assert!(
119            *m == Self::MODULUS,
120            "the given modulus doesn't match with the generic params"
121        );
122        assert!(N < 32, "N must be less than type bit width");
123        assert!(N > 0, "N must be positive");
124        assert!(K > 0, "K must be positive");
125        assert!(K % 2 == 1, "K must be odd");
126        assert!(
127            (K as u64) * (1_u64 << (N as u32)) < u32::MAX as u64,
128            "K·2^N + 1 exceeds type maximum"
129        );
130        debug_assert!((K as u32) < (1u32 << (N as u32)), "K must be less than 2^N");
131        debug_assert_prime_candidate!(Self::MODULUS);
132        Self {}
133    }
134    impl_fixed_monty_ops!(u32, u64, Self::R2, primitive);
135}
136
/// A modular reducer for Proth primes `K * 2^N + 1` with 64-bit operands.
///
/// Uses Montgomery form with `R = 2⁶⁴`. The const parameters must satisfy
/// `0 < N < 64` and `K` odd (with `K < 2^N`, checked in debug builds only).
/// Additionally `K * 2^N + 1` must fit in a `u64` and must not exceed
/// `φ·2⁶⁴` (about `0.618·2⁶⁴`, `φ = (√5 − 1)/2`), which keeps the 128-bit
/// intermediate sum of Montgomery reduction from overflowing.
///
/// # Panics
///
/// Evaluating `MODULUS` (which `Reducer::new` does) fails at compile time when
/// `N ≥ 64` or when the modulus exceeds the bound above. `Reducer::new` panics
/// when the given modulus differs from `MODULUS` or when any other constraint
/// is violated.
///
/// # Example
///
/// ```rust
/// use num_modular::{FixedProth64, Reducer};
///
/// const N: u8 = 5;
/// const K: u8 = 3;
/// let modulus = (K as u64) * (1u64 << N) + 1; // 3*2^5 + 1 = 97
/// let reducer = FixedProth64::<N, K>::new(&modulus);
/// let a = reducer.transform(10);
/// let b = reducer.transform(20);
/// assert_eq!(reducer.residue(reducer.mul(&a, &b)), (10u64 * 20) % 97);
/// ```
#[must_use]
#[derive(Debug, Clone, Copy)]
pub struct FixedProth64<const N: u8, const K: u8>;
157
// 64-bit instantiation: `u64` words with `u128` double words for REDC, using
// the crate's 64-bit `neginv` and `powm_u64` to build the compile-time
// constants `N0` and `R2`.
impl_fixed_proth_inherent!(
    FixedProth64,
    u64,
    u128,
    crate::monty::neg_mod_inv::u64::neginv,
    powm_u64
);
165
166impl<const N: u8, const K: u8> Reducer<u64> for FixedProth64<N, K> {
167    #[inline]
168    fn new(m: &u64) -> Self {
169        assert!(
170            *m == Self::MODULUS,
171            "the given modulus doesn't match with the generic params"
172        );
173        assert!(N < 64, "N must be less than type bit width");
174        assert!(N > 0, "N must be positive");
175        assert!(K > 0, "K must be positive");
176        assert!(K % 2 == 1, "K must be odd");
177        assert!(
178            (K as u128) * (1_u128 << (N as u32)) < u64::MAX as u128,
179            "K·2^N + 1 exceeds type maximum"
180        );
181        debug_assert!((K as u64) < (1u64 << (N as u32)), "K must be less than 2^N");
182        debug_assert_prime_candidate!(Self::MODULUS);
183        Self {}
184    }
185    impl_fixed_monty_ops!(u64, u128, Self::R2, primitive);
186}
187
188// ── FixedProth (umax / udouble) ──────────────────────────────────────────────
189
/// A modular reducer for Proth primes `K * 2^N + 1`.
///
/// Uses Montgomery form with `R = 2¹²⁸`. The const parameters must satisfy
/// `0 < N < 128` and `K` odd (with `K < 2^N`, checked in debug builds only).
/// Additionally `K * 2^N + 1` must fit in a `u128` and must not exceed
/// `φ·2¹²⁸` (about `0.618·2¹²⁸`, `φ = (√5 − 1)/2`), which keeps the 256-bit
/// intermediate sum of Montgomery reduction from overflowing.
///
/// # Panics
///
/// Evaluating `MODULUS` (which `Reducer::new` does) fails at compile time when
/// `N ≥ 128` or when the modulus exceeds the bound above. `Reducer::new` panics
/// when the given modulus differs from `MODULUS` or when any other constraint
/// is violated.
///
/// # Example
///
/// ```rust
/// use num_modular::{FixedProth, Reducer};
///
/// const N: u8 = 16;
/// const K: u8 = 1;
/// let modulus = (K as u128) * (1u128 << N) + 1; // 2^16 + 1 = 65537
/// let reducer = FixedProth::<N, K>::new(&modulus);
/// let a = reducer.transform(1000);
/// let b = reducer.transform(2000);
/// assert_eq!(reducer.residue(reducer.mul(&a, &b)), (1000u128 * 2000) % modulus);
/// ```
#[must_use]
#[derive(Debug, Clone, Copy)]
pub struct FixedProth<const N: u8, const K: u8>;
210
211impl<const N: u8, const K: u8> FixedProth<N, K> {
212    /// Compile-time guard: N must be strictly less than 128.
213    const _N_BOUND_CHECK_U128: () = assert!(N < 128);
214
215    pub const MODULUS: umax = {
216        let p2n = match 1u128.checked_shl(N as u32) {
217            Some(v) => v,
218            None => unreachable!(),
219        };
220        let m = (K as u128).wrapping_mul(p2n).wrapping_add(1);
221        // MODULUS ≤ φ·R guarantees `reduce` never overflows the udouble sum
222        // (φ = (√5−1)/2 ≈ 0.618).
223        assert!(
224            m <= 210_306_068_529_402_891_650_266_558_847_000_772_608,
225            "MODULUS exceeds overflow-free bound; lower N or use FixedMontgomery"
226        );
227        m
228    };
229
230    /// Montgomery constant:  -MODULUS⁻¹ mod 2¹²⁸
231    const N0: umax = crate::monty::neg_mod_inv::u128::neginv(Self::MODULUS);
232
233    /// R² mod MODULUS  (R = 2¹²⁸, so R² = 2²⁵⁶)
234    const R2: umax = {
235        let r = udouble { hi: 1, lo: 0 }.div_rem_2by1(Self::MODULUS).1; // 2¹²⁸ mod MODULUS
236        udouble::widening_square(r).div_rem_2by1(Self::MODULUS).1 // 2²⁵⁶ mod MODULUS
237    };
238
239    /// Montgomery REDC with R = 2¹²⁸ and Proth-optimised m·p product.
240    #[must_use]
241    #[inline]
242    pub fn reduce(&self, t: udouble) -> umax {
243        let m = t.lo.wrapping_mul(Self::N0);
244        // m·p = m·(K·2^N + 1) = (m·K)<<N + m
245        // K ≤ 255, so the widening_mul is narrow (at most 8 bits).
246        let mk = udouble::widening_mul(m, K as u128);
247        let mp = mk.shl_u32(N as u32) + udouble { hi: 0, lo: m };
248        let r = (t + mp).hi;
249        if r >= Self::MODULUS {
250            r - Self::MODULUS
251        } else {
252            r
253        }
254    }
255}
256
257impl<const N: u8, const K: u8> Reducer<umax> for FixedProth<N, K> {
258    #[inline]
259    fn new(m: &umax) -> Self {
260        assert!(
261            *m == Self::MODULUS,
262            "the given modulus doesn't match with the generic params"
263        );
264        assert!(N < 128, "N must be less than type bit width");
265        assert!(N > 0, "N must be positive");
266        assert!(K > 0, "K must be positive");
267        assert!(K % 2 == 1, "K must be odd");
268        assert!(
269            (K as u128) * (1u128 << (N as u32)) < u128::MAX,
270            "K·2^N + 1 exceeds type maximum"
271        );
272        debug_assert!(
273            (K as u128) < (1u128 << (N as u32)),
274            "K must be less than 2^N"
275        );
276        debug_assert_prime_candidate!(Self::MODULUS);
277        Self {}
278    }
279    impl_fixed_monty_ops!(umax, udouble, Self::R2, udouble);
280}
281
#[cfg(test)]
mod tests {
    // Each reducer is checked against the generic `ModularCoreOps` /
    // `ModularPow` implementations on random operands, plus targeted
    // regression and near-bound cases.
    use super::*;
    use crate::{ModularCoreOps, ModularPow};
    use rand::random;

    // u128 types
    type P128_1 = FixedProth<2, 1>; // m = 5
    type P128_2 = FixedProth<4, 1>; // m = 17
    type P128_3 = FixedProth<5, 3>; // m = 97
    type P128_4 = FixedProth<8, 3>; // m = 769
    type P128_5 = FixedProth<16, 1>; // m = 65537

    // u64 types
    type P64_1 = FixedProth64<4, 1>; // m = 17
    type P64_2 = FixedProth64<5, 3>; // m = 97
    type P64_3 = FixedProth64<8, 1>; // m = 257
    type P64_4 = FixedProth64<16, 1>; // m = 65537

    // u32 types
    type P32_1 = FixedProth32<2, 1>; // m = 5
    type P32_2 = FixedProth32<2, 3>; // m = 13
    type P32_3 = FixedProth32<4, 1>; // m = 17
    type P32_4 = FixedProth32<3, 5>; // m = 41

    const NRANDOM: u32 = 10;

    #[test]
    fn creation_test_u128() {
        for _ in 0..NRANDOM {
            let a = random::<u128>();

            const M1: u128 = <P128_1>::MODULUS;
            let r1 = P128_1::new(&M1);
            assert_eq!(r1.residue(r1.transform(a % M1)), a % M1);

            const M2: u128 = <P128_2>::MODULUS;
            let r2 = P128_2::new(&M2);
            assert_eq!(r2.residue(r2.transform(a % M2)), a % M2);

            const M3: u128 = <P128_3>::MODULUS;
            let r3 = P128_3::new(&M3);
            assert_eq!(r3.residue(r3.transform(a % M3)), a % M3);

            const M4: u128 = <P128_4>::MODULUS;
            let r4 = P128_4::new(&M4);
            assert_eq!(r4.residue(r4.transform(a % M4)), a % M4);

            const M5: u128 = <P128_5>::MODULUS;
            let r5 = P128_5::new(&M5);
            assert_eq!(r5.residue(r5.transform(a % M5)), a % M5);
        }
    }

    #[test]
    fn creation_test_u64() {
        for _ in 0..NRANDOM {
            let a = random::<u64>();

            const M1: u64 = <P64_1>::MODULUS;
            let r1 = P64_1::new(&M1);
            assert_eq!(r1.residue(r1.transform(a % M1)), a % M1);

            const M2: u64 = <P64_2>::MODULUS;
            let r2 = P64_2::new(&M2);
            assert_eq!(r2.residue(r2.transform(a % M2)), a % M2);

            const M3: u64 = <P64_3>::MODULUS;
            let r3 = P64_3::new(&M3);
            assert_eq!(r3.residue(r3.transform(a % M3)), a % M3);

            const M4: u64 = <P64_4>::MODULUS;
            let r4 = P64_4::new(&M4);
            assert_eq!(r4.residue(r4.transform(a % M4)), a % M4);
        }
    }

    #[test]
    fn creation_test_u32() {
        for _ in 0..NRANDOM {
            let a = random::<u32>();

            const M1: u32 = <P32_1>::MODULUS;
            let r1 = P32_1::new(&M1);
            assert_eq!(r1.residue(r1.transform(a % M1)), a % M1);

            const M2: u32 = <P32_2>::MODULUS;
            let r2 = P32_2::new(&M2);
            assert_eq!(r2.residue(r2.transform(a % M2)), a % M2);

            const M3: u32 = <P32_3>::MODULUS;
            let r3 = P32_3::new(&M3);
            assert_eq!(r3.residue(r3.transform(a % M3)), a % M3);

            const M4: u32 = <P32_4>::MODULUS;
            let r4 = P32_4::new(&M4);
            assert_eq!(r4.residue(r4.transform(a % M4)), a % M4);
        }
    }

    #[test]
    fn test_against_modops_u128() {
        macro_rules! tests_for {
            ($a:ident, $b:ident, $e:ident; $($M:ty)*) => ($({
                const P: u128 = <$M>::MODULUS;
                let r = <$M>::new(&P);
                let am = r.transform($a);
                let bm = r.transform($b);
                assert_eq!(r.residue(r.add(&am, &bm)), $a.addm($b, &P));
                assert_eq!(r.residue(r.sub(&am, &bm)), $a.subm($b, &P));
                assert_eq!(r.residue(r.mul(&am, &bm)), $a.mulm($b, &P));
                assert_eq!(r.residue(r.neg(am)), $a.negm(&P));
                assert_eq!(r.residue(r.dbl(am)), $a.dblm(&P));
                assert_eq!(r.residue(r.sqr(am)), $a.sqm(&P));
                assert_eq!(r.residue(r.pow(am, &$e)), $a.powm($e, &P));
                if let (Some(inv), Some(ref_inv)) = (r.inv(am), $a.invm(&P)) {
                    assert_eq!(r.residue(inv), ref_inv);
                }
            })*);
        }

        for _ in 0..NRANDOM {
            let a = random::<u128>();
            let b = random::<u128>();
            let e = random::<u8>() as u128;
            tests_for!(a, b, e; P128_1 P128_2 P128_3 P128_4 P128_5);
        }
    }

    #[test]
    fn test_against_modops_u64() {
        macro_rules! tests_for {
            ($a:ident, $b:ident, $e:ident; $($M:ty)*) => ($({
                const P: u64 = <$M>::MODULUS;
                let r = <$M>::new(&P);
                let am = r.transform($a);
                let bm = r.transform($b);
                assert_eq!(r.residue(r.add(&am, &bm)), $a.addm($b, &P));
                assert_eq!(r.residue(r.sub(&am, &bm)), $a.subm($b, &P));
                assert_eq!(r.residue(r.mul(&am, &bm)), $a.mulm($b, &P));
                assert_eq!(r.residue(r.neg(am)), $a.negm(&P));
                assert_eq!(r.residue(r.dbl(am)), $a.dblm(&P));
                assert_eq!(r.residue(r.sqr(am)), $a.sqm(&P));
                assert_eq!(r.residue(r.pow(am, &$e)), $a.powm($e, &P));
                if let (Some(inv), Some(ref_inv)) = (r.inv(am), $a.invm(&P)) {
                    assert_eq!(r.residue(inv), ref_inv);
                }
            })*);
        }

        for _ in 0..NRANDOM {
            let a = random::<u64>();
            let b = random::<u64>();
            let e = random::<u8>() as u64;
            tests_for!(a, b, e; P64_1 P64_2 P64_3 P64_4);
        }
    }

    #[test]
    fn test_against_modops_u32() {
        macro_rules! tests_for {
            ($a:ident, $b:ident, $e:ident; $($M:ty)*) => ($({
                const P: u32 = <$M>::MODULUS;
                let r = <$M>::new(&P);
                let am = r.transform($a);
                let bm = r.transform($b);
                assert_eq!(r.residue(r.add(&am, &bm)), $a.addm($b, &P));
                assert_eq!(r.residue(r.sub(&am, &bm)), $a.subm($b, &P));
                assert_eq!(r.residue(r.mul(&am, &bm)), $a.mulm($b, &P));
                assert_eq!(r.residue(r.neg(am)), $a.negm(&P));
                assert_eq!(r.residue(r.dbl(am)), $a.dblm(&P));
                assert_eq!(r.residue(r.sqr(am)), $a.sqm(&P));
                assert_eq!(r.residue(r.pow(am, &$e)), $a.powm($e, &P));
                if let (Some(inv), Some(ref_inv)) = (r.inv(am), $a.invm(&P)) {
                    assert_eq!(r.residue(inv), ref_inv);
                }
            })*);
        }

        for _ in 0..NRANDOM {
            let a = random::<u32>();
            let b = random::<u32>();
            let e = random::<u8>() as u32;
            tests_for!(a, b, e; P32_1 P32_2 P32_3 P32_4);
        }
    }

    #[test]
    fn test_add_near_overflow_u64() {
        type S = FixedProth64<32, 3>;
        const M: u64 = <S>::MODULUS;
        let r = S::new(&M);

        let a = M - 1;
        let b = M - 2;
        let am = r.transform(a);
        let bm = r.transform(b);
        let sum = r.add(&am, &bm);
        assert_eq!(r.residue(sum), a.addm(b, &M));

        let a2 = M - 1;
        let a2m = r.transform(a2);
        let dbl = r.dbl(a2m);
        assert_eq!(r.residue(dbl), a2.dblm(&M));
    }

    /// Reduce correctness with a 32-bit MODULUS of moderate size relative to
    /// the overflow-free bound.
    #[test]
    fn test_reduce_near_bound() {
        // 255·2^23 + 1 = 2,139,095,041 (about 80% of the φ·2^32 threshold
        // 2,654,435,769)
        type S = FixedProth32<23, 255>;
        const M: u32 = <S>::MODULUS;
        let r = S::new(&M);

        for _ in 0..10 {
            let a = random::<u32>() % M;
            let b = random::<u32>() % M;
            let am = r.transform(a);
            let bm = r.transform(b);
            let result = r.residue(r.mul(&am, &bm));
            assert_eq!(result, a.mulm(b, &M));
        }
    }

    /// Reduce correctness for a 64-bit MODULUS above 2^63, where the REDC
    /// double-word sum t + q·MODULUS reaches roughly 0.81·2^128.
    #[test]
    fn test_reduce_near_bound_u64() {
        // 17·2^59 + 1 ≈ 0.53·2^64 (the φ bound is ≈ 0.618·2^64)
        type S = FixedProth64<59, 17>;
        const M: u64 = <S>::MODULUS;
        let r = S::new(&M);

        for _ in 0..NRANDOM {
            let a = random::<u64>() % M;
            let b = random::<u64>() % M;
            let am = r.transform(a);
            let bm = r.transform(b);
            assert_eq!(r.residue(r.mul(&am, &bm)), a.mulm(b, &M));
            assert_eq!(r.residue(r.sqr(am)), a.sqm(&M));
        }

        // (M - 1)² ≡ 1 (mod M) exercises the largest residue.
        let top = r.transform(M - 1);
        assert_eq!(r.residue(r.sqr(top)), 1);
    }

    /// Reduce correctness for a 128-bit MODULUS above 2^127, where the REDC
    /// double-word sum t + q·MODULUS reaches roughly 0.81·2^256.
    #[test]
    fn test_reduce_near_bound_u128() {
        // 17·2^123 + 1 ≈ 0.53·2^128 (the φ bound is ≈ 0.618·2^128)
        type S = FixedProth<123, 17>;
        const M: u128 = <S>::MODULUS;
        let r = S::new(&M);

        for _ in 0..NRANDOM {
            let a = random::<u128>() % M;
            let b = random::<u128>() % M;
            let am = r.transform(a);
            let bm = r.transform(b);
            assert_eq!(r.residue(r.mul(&am, &bm)), a.mulm(b, &M));
            assert_eq!(r.residue(r.sqr(am)), a.sqm(&M));
        }

        // (M - 1)² ≡ 1 (mod M) exercises the largest residue.
        let top = r.transform(M - 1);
        assert_eq!(r.residue(r.sqr(top)), 1);
    }

    /// inv with MODULUS > usize::MAX should not truncate.
    #[test]
    fn test_inv_no_truncation_u128() {
        // N=60 < 64 but MODULUS = 31·2^60+1 > u64::MAX, so the old
        // `N < usize::BITS` gate would incorrectly take the usize path.
        type S = FixedProth<60, 31>;
        const M: u128 = <S>::MODULUS;
        assert!(
            M > u64::MAX as u128,
            "MODULUS must exceed usize for this test"
        );
        let r = S::new(&M);

        let a: u128 = 1234567890123456789 % M;
        let a_mont = r.transform(a);
        let inv = r.inv(a_mont).expect("inv should succeed");
        let result = r.residue(inv);
        assert_eq!(result.mulm(a, &M), 1u128, "inv truncation bug");
    }

    /// K·2^N exceeding type max should panic, not silently wrap.
    #[test]
    #[should_panic(expected = "exceeds type maximum")]
    fn test_modulus_overflow_panics_u32() {
        type S = FixedProth32<31, 3>; // 3·2^31+1 > 2^32
        const M: u32 = <S>::MODULUS; // wraps to 2^31+1
        let _ = S::new(&M); // should panic
    }

    /// FixedProth with N>64 should compute reduce correctly
    /// (no shift truncation in the Proth-optimised m·p product).
    #[test]
    fn test_reduce_n_gt_64() {
        type S = FixedProth<65, 3>; // MODULUS = 3·2^65 + 1
        const M: u128 = <S>::MODULUS;
        let r = S::new(&M);

        for _ in 0..10 {
            let a = random::<u128>() % M;
            let b = random::<u128>() % M;
            let am = r.transform(a);
            let bm = r.transform(b);
            let result = r.residue(r.mul(&am, &bm));
            assert_eq!(result, a.mulm(b, &M), "shift truncation bug for N>64");
        }
    }
}