Skip to main content

num_modular/
mersenne.rs

1use crate::reduced::{impl_reduced_binary_pow, impl_reduced_ops};
2use crate::{udouble, umax, ModularUnaryOps, Reducer};
3
macro_rules! impl_fixed_mersenne {
    (
        $TypeName:ident,
        $T:ty,
        $D:ty,
        $half_bits:expr,
        $max_P:expr,
        $kind:ident
    ) => {
        impl<const P: u8, const K: $T> $TypeName<P, K> {
            /// Mask selecting the low `P` bits: `2^P - 1`, or all ones when `2^P` does not
            /// fit in the operand type.
            const BITMASK: $T = match (1 as $T).checked_shl(P as u32) {
                Some(v) => v.wrapping_sub(1),
                None => <$T>::MAX,
            };
            /// The modulus `2^P - K`.
            ///
            /// When `P` equals the operand bit width, `2^P` is represented as 0 and the
            /// wrapping subtraction yields the intended `2^BITS - K`.
            pub const MODULUS: $T = {
                let p1 = match (1 as $T).checked_shl(P as u32) {
                    Some(v) => v,
                    None => 0,
                };
                p1.wrapping_sub(K)
            };

            /// Worst-case fold count for `reduce_double`, assuming the initial high part
            /// `v >> P` is below 2^P (true for the product of two residues).
            ///
            /// Each fold replaces V = hi·2^P + lo with hi·K + lo (since 2^P ≡ K).
            /// For K = 1: always 2 folds (the carry chain terminates in at most one
            /// extra step). For K > 1: ⌈P/(P−⌈log₂K⌉)⌉ + 1 folds.
            const FOLDS: u32 = if K == 1 {
                2
            } else {
                let s = <$T>::BITS - K.leading_zeros(); // bit-width of K
                let gap = P as u32 - s;
                let folds_ceil = (P as u32 + gap - 1) / gap;
                folds_ceil + 1
            };

            /// Reduces a single-width value `v` modulo `2^P - K`.
            ///
            /// For the result of a widening multiplication or square, use
            /// [`reduce_double`](Self::reduce_double) instead.
            pub const fn reduce_single(v: $T) -> $T {
                let mut lo = v & Self::BITMASK;
                // `checked_shr` yields `None` when `P` equals the bit width: nothing to fold.
                let mut hi = match v.checked_shr(P as u32) {
                    Some(s) => s,
                    None => 0,
                };
                while hi > 0 {
                    let sum = if K == 1 { hi + lo } else { hi * K + lo };
                    lo = sum & Self::BITMASK;
                    hi = match sum.checked_shr(P as u32) {
                        Some(s) => s,
                        None => 0,
                    };
                }
                // `lo <= 2^P - 1`, so a single conditional subtraction completes the reduction.
                if lo >= Self::MODULUS {
                    lo - Self::MODULUS
                } else {
                    lo
                }
            }

            impl_fixed_mersenne!(@reduce_double, $kind, $T, $D);
        }

        impl<const P: u8, const K: $T> Reducer<$T> for $TypeName<P, K> {
            #[inline]
            fn new(m: &$T) -> Self {
                assert!(
                    *m == Self::MODULUS,
                    "the given modulus doesn't match with the generic params"
                );
                debug_assert!(P <= $max_P);
                debug_assert!(K > 0 && K < (2 as $T).pow(P as u32 - 1) && K % 2 == 1);
                debug_assert!(
                    (Self::MODULUS == 3 || Self::MODULUS % 3 != 0)
                        && (Self::MODULUS == 5 || Self::MODULUS % 5 != 0)
                        && (Self::MODULUS == 7 || Self::MODULUS % 7 != 0)
                        && (Self::MODULUS == 11 || Self::MODULUS % 11 != 0)
                        && (Self::MODULUS == 13 || Self::MODULUS % 13 != 0)
                ); // error on easy composites
                Self {}
            }
            #[inline]
            fn transform(&self, target: $T) -> $T {
                Self::reduce_single(target)
            }
            #[inline]
            fn residue(&self, target: $T) -> $T {
                // Values are stored as plain residues; no conversion back is needed.
                target
            }

            impl_reduced_ops!($T);

            #[inline]
            fn mul(&self, lhs: &$T, rhs: &$T) -> $T {
                // With P below half the width, the product of two residues fits in `$T`.
                if (P as u32) < $half_bits {
                    Self::reduce_single(lhs * rhs)
                } else {
                    Self::reduce_double(impl_fixed_mersenne!(@widen_mul, $kind, $T, $D, lhs, rhs))
                }
            }
            #[inline]
            fn inv(&self, target: $T) -> Option<$T> {
                // Use the native word size when the modulus fits in it.
                if (P as u32) < usize::BITS {
                    (target as usize)
                        .invm(&(Self::MODULUS as usize))
                        .map(|v| v as $T)
                } else {
                    target.invm(&Self::MODULUS)
                }
            }
            #[inline]
            fn sqr(&self, target: $T) -> $T {
                if (P as u32) < $half_bits {
                    Self::reduce_single(target * target)
                } else {
                    Self::reduce_double(impl_fixed_mersenne!(@widen_sqr, $kind, $T, $D, target))
                }
            }

            impl_reduced_binary_pow!($T);
        }
    };

    // Internal: reduce_double for primitive double-width types (u32→u64, u64→u128)
    //
    // For real pseudo-Mersennes, FOLDS is always ≤ 3 (K=1 → 2; small K → 3).
    // The straight-line folds cover every input whose high part is below 2^P
    // (e.g. every product of two residues). Extra folds past the true count are
    // no-ops (hi reaches 0). A trailing loop finishes any remaining folds, so
    // arbitrary double-width inputs are also reduced correctly.
    (@reduce_double, primitive, $T:ty, $D:ty) => {
        /// Reduces a double-width value `v` modulo `2^P - K`.
        ///
        /// This handles widening-multiplication or widening-square results, as well as any
        /// other double-width value. For single-width values, use
        /// [`reduce_single`](Self::reduce_single).
        pub fn reduce_double(v: $D) -> $T {
            let mut lo = (v as $T) & Self::BITMASK;
            let mut hi = v >> P;
            // Fold arithmetic stays in `$D`: hi < 2^(2W-P) and K < 2^(P-1), so
            // hi * K + lo cannot overflow.
            macro_rules! mersenne_fold {
                () => {
                    let sum = if K == 1 {
                        hi + lo as $D
                    } else {
                        hi * (K as $D) + lo as $D
                    };
                    lo = (sum as $T) & Self::BITMASK;
                    hi = sum >> P;
                };
            }
            if Self::FOLDS <= 2 {
                mersenne_fold!();
                mersenne_fold!();
            } else if Self::FOLDS == 3 {
                mersenne_fold!();
                mersenne_fold!();
                mersenne_fold!();
            }
            // No-op for in-range inputs after the unrolled folds; required for inputs
            // whose high part exceeds 2^P (possible when 2P < width of `$D`) and the
            // only folding path when FOLDS > 3.
            while hi > 0 {
                mersenne_fold!();
            }
            if lo >= Self::MODULUS {
                lo - Self::MODULUS
            } else {
                lo
            }
        }
    };

    // Internal: reduce_double for udouble (u128→udouble)
    //
    // Phase 1 folds in udouble arithmetic while a u128 fold `hi * K + lo` could
    // overflow. This covers inputs with a nonzero high word, and also products of
    // residues when P + bitwidth(K) > 128 (e.g. 2^128 - 159). Phase 2 uses u128
    // arithmetic and is unrolled when FOLDS ≤ 3, followed by a loop that finishes
    // any remaining folds.
    (@reduce_double, udouble, $T:ty, $D:ty) => {
        /// Largest high part `hi` for which the single-width fold `hi * K + lo`
        /// (with `lo <= BITMASK`) cannot overflow the operand type.
        const NARROW_FOLD_LIMIT: $T = match K {
            0 => <$T>::MAX,
            k => (<$T>::MAX - Self::BITMASK) / k,
        };

        /// Reduces a double-width value `v` modulo `2^P - K`.
        ///
        /// This handles widening-multiplication or widening-square results, as well as any
        /// other double-width value. For single-width values, use
        /// [`reduce_single`](Self::reduce_single).
        pub fn reduce_double(v: $D) -> $T {
            let mut lo = v.lo & Self::BITMASK;
            let mut hi = v >> P;
            // Phase 1: double-width folds until a single-width fold is overflow-free.
            while hi.hi > 0 || hi.lo > Self::NARROW_FOLD_LIMIT {
                let sum = if K == 1 { hi + lo } else { hi * K + lo };
                lo = sum.lo & Self::BITMASK;
                hi = sum >> P;
            }
            // Phase 2: `hi <= NARROW_FOLD_LIMIT` now holds, and each fold preserves it
            // (the new hi is at most MAX >> P), so no u128 fold below can overflow.
            let mut hi = hi.lo;
            macro_rules! mersenne_u128_fold {
                () => {
                    let sum = if K == 1 { hi + lo } else { hi * K + lo };
                    lo = sum & Self::BITMASK;
                    hi = match sum.checked_shr(P as u32) {
                        Some(s) => s,
                        None => 0,
                    };
                };
            }
            if Self::FOLDS <= 2 {
                mersenne_u128_fold!();
                mersenne_u128_fold!();
            } else if Self::FOLDS == 3 {
                mersenne_u128_fold!();
                mersenne_u128_fold!();
                mersenne_u128_fold!();
            }
            // No-op when the unrolled folds sufficed (high part started below 2^P);
            // otherwise finishes the reduction, and it is the only path when FOLDS > 3.
            while hi > 0 {
                mersenne_u128_fold!();
            }
            if lo >= Self::MODULUS {
                lo - Self::MODULUS
            } else {
                lo
            }
        }
    };

    // Internal: widening multiplication for primitive types
    (@widen_mul, primitive, $T:ty, $D:ty, $lhs:expr, $rhs:expr) => {
        (*$lhs as $D) * (*$rhs as $D)
    };

    // Internal: widening multiplication for udouble
    (@widen_mul, udouble, $T:ty, $D:ty, $lhs:expr, $rhs:expr) => {
        <$D>::widening_mul(*$lhs, *$rhs)
    };

    // Internal: widening square for primitive types
    (@widen_sqr, primitive, $T:ty, $D:ty, $target:expr) => {
        ($target as $D) * ($target as $D)
    };

    // Internal: widening square for udouble
    (@widen_sqr, udouble, $T:ty, $D:ty, $target:expr) => {
        <$D>::widening_square($target)
    };
}
236
237/// A modular reducer for (pseudo) Mersenne numbers `2^P - K` as modulus with 32-bit operands.
238///
239/// Supports `P` up to 32 and `K < 2^(P-1)`. All inputs and outputs are `u32`.
240/// The modulus `2^P - K` must be prime for modular inverse and Fermat-based operations to be valid.
241///
242/// # Example
243///
244/// ```rust
245/// use num_modular::{FixedMersenne32, Reducer};
246///
247/// const P: u8 = 13;
248/// const K: u32 = 1;
249/// let modulus = (1u32 << P) - K; // 2^13 - 1 = 8191 (Mersenne prime)
250/// let reducer = FixedMersenne32::<P, K>::new(&modulus);
251/// let a = reducer.transform(100);
252/// let b = reducer.transform(200);
253/// assert_eq!(reducer.residue(reducer.add(&a, &b)), 300 % modulus);
254/// ```
255#[must_use]
256#[derive(Debug, Clone, Copy)]
257pub struct FixedMersenne32<const P: u8, const K: u32>();
258
259impl_fixed_mersenne!(FixedMersenne32, u32, u64, 16, 32, primitive);
260
261/// A modular reducer for (pseudo) Mersenne numbers `2^P - K` as modulus with 64-bit operands.
262///
263/// Supports `P` up to 64 and `K < 2^(P-1)`. All inputs and outputs are `u64`.
264/// Uses `u128` as the double-width intermediate for multiplication and reduction.
265/// The modulus `2^P - K` must be prime for modular inverse and Fermat-based operations to be valid.
266///
267/// # Example
268///
269/// ```rust
270/// use num_modular::{FixedMersenne64, Reducer};
271///
272/// const P: u8 = 61;
273/// const K: u64 = 1;
274/// let modulus = (1u64 << P) - K; // 2^61 - 1 (Mersenne prime)
275/// let reducer = FixedMersenne64::<P, K>::new(&modulus);
276/// let a = reducer.transform(1000);
277/// let b = reducer.transform(2000);
278/// assert_eq!(reducer.residue(reducer.mul(&a, &b)), (1000u64 * 2000) % modulus);
279/// ```
280#[must_use]
281#[derive(Debug, Clone, Copy)]
282pub struct FixedMersenne64<const P: u8, const K: u64>();
283
284impl_fixed_mersenne!(FixedMersenne64, u64, u128, 32, 64, primitive);
285
286/// A modular reducer for (pseudo) Mersenne numbers `2^P - K` as modulus.
287///
288/// Supports `P` up to 128 and `K < 2^(P-1)`. All inputs and outputs are [umax] (currently `u128`).
289///
290/// The `P` is limited to 128 so that overflow checks aren't necessary. This covers all Mersenne
291/// primes within the range of [umax] (i.e. `u128`).
292///
293/// # Example
294///
295/// ```rust
296/// use num_modular::{FixedMersenne, Reducer, umax};
297///
298/// const P: u8 = 31;
299/// const K: umax = 1;
300/// let modulus = (1 << P) - K; // 2^31 - 1 (Mersenne prime)
301/// let reducer = FixedMersenne::<P, K>::new(&modulus);
302/// let a = reducer.transform(1000);
303/// let b = reducer.transform(2000);
304/// assert_eq!(reducer.residue(reducer.mul(&a, &b)), (1000 * 2000) % modulus);
305/// ```
306#[must_use]
307#[derive(Debug, Clone, Copy)]
308pub struct FixedMersenne<const P: u8, const K: umax>();
309
310impl_fixed_mersenne!(FixedMersenne, umax, udouble, 64, 128, udouble);
311
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ModularCoreOps, ModularPow};
    use rand::random;

    // u128 (umax) reducers
    type M1 = FixedMersenne<31, 1>;
    type M2 = FixedMersenne<61, 1>;
    type M3 = FixedMersenne<127, 1>;
    type M4 = FixedMersenne<32, 5>;
    type M5 = FixedMersenne<56, 5>;
    type M6 = FixedMersenne<122, 3>;
    type M7 = FixedMersenne<128, 159>;

    // u64 reducers
    type M64_1 = FixedMersenne64<31, 1>;
    type M64_2 = FixedMersenne64<61, 1>;
    type M64_3 = FixedMersenne64<32, 5>;
    type M64_4 = FixedMersenne64<64, 59>; // P equals the operand width: 2^64 - 59

    // u32 reducers
    type M32_1 = FixedMersenne32<13, 1>;
    type M32_2 = FixedMersenne32<31, 1>;
    type M32_3 = FixedMersenne32<16, 5>;

    const NRANDOM: u32 = 10;

    #[test]
    fn creation_test_u128() {
        const P: umax = (1 << 31) - 1;
        let m = M1::new(&P);
        assert_eq!(m.residue(m.transform(0)), 0);
        assert_eq!(m.residue(m.transform(1)), 1);
        assert_eq!(m.residue(m.transform(P)), 0);
        assert_eq!(m.residue(m.transform(P - 1)), P - 1);
        assert_eq!(m.residue(m.transform(P + 1)), 1);

        for _ in 0..NRANDOM {
            let a = random::<umax>();

            const P1: umax = (1 << 31) - 1;
            let m1 = M1::new(&P1);
            assert_eq!(m1.residue(m1.transform(a)), a % P1);
            const P2: umax = (1 << 61) - 1;
            let m2 = M2::new(&P2);
            assert_eq!(m2.residue(m2.transform(a)), a % P2);
            const P3: umax = (1 << 127) - 1;
            let m3 = M3::new(&P3);
            assert_eq!(m3.residue(m3.transform(a)), a % P3);
            const P4: umax = (1 << 32) - 5;
            let m4 = M4::new(&P4);
            assert_eq!(m4.residue(m4.transform(a)), a % P4);
            const P5: umax = (1 << 56) - 5;
            let m5 = M5::new(&P5);
            assert_eq!(m5.residue(m5.transform(a)), a % P5);
            const P6: umax = (1 << 122) - 3;
            let m6 = M6::new(&P6);
            assert_eq!(m6.residue(m6.transform(a)), a % P6);
            const P7: umax = M7::MODULUS;
            let m7 = M7::new(&P7);
            assert_eq!(m7.residue(m7.transform(a)), a % P7);
        }
    }

    #[test]
    fn creation_test_u64() {
        for _ in 0..NRANDOM {
            let a = random::<u64>();

            const P1: u64 = (1 << 31) - 1;
            let m1 = M64_1::new(&P1);
            assert_eq!(m1.residue(m1.transform(a)), a % P1);
            const P2: u64 = (1 << 61) - 1;
            let m2 = M64_2::new(&P2);
            assert_eq!(m2.residue(m2.transform(a)), a % P2);
            const P3: u64 = (1 << 32) - 5;
            let m3 = M64_3::new(&P3);
            assert_eq!(m3.residue(m3.transform(a)), a % P3);
            const P4: u64 = M64_4::MODULUS;
            let m4 = M64_4::new(&P4);
            assert_eq!(m4.residue(m4.transform(a)), a % P4);
        }
    }

    #[test]
    fn creation_test_u32() {
        for _ in 0..NRANDOM {
            let a = random::<u32>();

            const P1: u32 = (1 << 13) - 1;
            let m1 = M32_1::new(&P1);
            assert_eq!(m1.residue(m1.transform(a)), a % P1);
            const P2: u32 = (1 << 31) - 1;
            let m2 = M32_2::new(&P2);
            assert_eq!(m2.residue(m2.transform(a)), a % P2);
            const P3: u32 = (1 << 16) - 5;
            let m3 = M32_3::new(&P3);
            assert_eq!(m3.residue(m3.transform(a)), a % P3);
        }
    }

    #[test]
    fn test_against_modops_u128() {
        macro_rules! tests_for {
            ($a:tt, $b:tt, $e:tt; $($M:ty)*) => ($({
                const P: umax = <$M>::MODULUS;
                let r = <$M>::new(&P);
                let am = r.transform($a);
                let bm = r.transform($b);
                assert_eq!(r.add(&am, &bm), $a.addm($b, &P));
                assert_eq!(r.sub(&am, &bm), $a.subm($b, &P));
                assert_eq!(r.mul(&am, &bm), $a.mulm($b, &P));
                assert_eq!(r.neg(am), $a.negm(&P));
                assert_eq!(r.inv(am), $a.invm(&P));
                assert_eq!(r.dbl(am), $a.dblm(&P));
                assert_eq!(r.sqr(am), $a.sqm(&P));
                assert_eq!(r.pow(am, &$e), $a.powm($e, &P));
            })*);
        }

        for _ in 0..NRANDOM {
            let (a, b) = (random::<u128>(), random::<u128>());
            let e = random::<u8>() as umax;
            tests_for!(a, b, e; M1 M2 M3 M4 M5 M6);
        }
    }

    #[test]
    fn test_against_modops_u64() {
        macro_rules! tests_for {
            ($a:ident, $b:ident, $e:ident; $($M:ty)*) => ($({
                const P: u64 = <$M>::MODULUS;
                let r = <$M>::new(&P);
                let am = r.transform($a);
                let bm = r.transform($b);
                assert_eq!(r.add(&am, &bm), $a.addm($b, &P));
                assert_eq!(r.sub(&am, &bm), $a.subm($b, &P));
                assert_eq!(r.mul(&am, &bm), $a.mulm($b, &P));
                assert_eq!(r.neg(am), $a.negm(&P));
                assert_eq!(r.inv(am), $a.invm(&P));
                assert_eq!(r.dbl(am), $a.dblm(&P));
                assert_eq!(r.sqr(am), $a.sqm(&P));
                assert_eq!(r.pow(am, &$e), $a.powm($e, &P));
            })*);
        }

        for _ in 0..NRANDOM {
            let a = random::<u64>();
            let b = random::<u64>();
            let e = random::<u8>() as u64;
            // M64_4 exercises the P == 64 boundary (all-ones BITMASK, wrapped MODULUS).
            tests_for!(a, b, e; M64_1 M64_2 M64_3 M64_4);
        }
    }

    #[test]
    fn test_against_modops_u32() {
        macro_rules! tests_for {
            ($a:ident, $b:ident, $e:ident; $($M:ty)*) => ($({
                const P: u32 = <$M>::MODULUS;
                let r = <$M>::new(&P);
                let am = r.transform($a);
                let bm = r.transform($b);
                assert_eq!(r.add(&am, &bm), $a.addm($b, &P));
                assert_eq!(r.sub(&am, &bm), $a.subm($b, &P));
                assert_eq!(r.mul(&am, &bm), $a.mulm($b, &P));
                assert_eq!(r.neg(am), $a.negm(&P));
                assert_eq!(r.inv(am), $a.invm(&P));
                assert_eq!(r.dbl(am), $a.dblm(&P));
                assert_eq!(r.sqr(am), $a.sqm(&P));
                assert_eq!(r.pow(am, &$e), $a.powm($e, &P));
            })*);
        }

        for _ in 0..NRANDOM {
            let a = random::<u32>();
            let b = random::<u32>();
            let e = random::<u8>() as u32;
            tests_for!(a, b, e; M32_1 M32_2 M32_3);
        }
    }

    #[test]
    fn small_prime_moduli() {
        // Moduli that are themselves one of the trial divisors (3, 5, 7, 11, 13)
        // must not trigger the composite-check debug_assert in `new`.
        // Covered here: 2^2 - 1 = 3 and 2^3 - 1 = 7 (2^4 - 1 = 15 is composite).
        type M3 = FixedMersenne<2, 1>; // modulus = 3
        type M7 = FixedMersenne<3, 1>; // modulus = 7

        const M3_MOD: umax = M3::MODULUS;
        let r3 = M3::new(&M3_MOD);
        assert_eq!(r3.residue(r3.transform(5 % M3_MOD)), 5 % M3_MOD);

        const M7_MOD: umax = M7::MODULUS;
        let r7 = M7::new(&M7_MOD);
        assert_eq!(r7.residue(r7.transform(10 % M7_MOD)), 10 % M7_MOD);

        // 32-bit variants
        type M32_3 = FixedMersenne32<2, 1>; // modulus = 3
        type M32_7 = FixedMersenne32<3, 1>; // modulus = 7

        const P3: u32 = M32_3::MODULUS;
        let r3 = M32_3::new(&P3);
        assert_eq!(r3.residue(r3.transform(5 % P3)), 5 % P3);

        const P7: u32 = M32_7::MODULUS;
        let r7 = M32_7::new(&P7);
        assert_eq!(r7.residue(r7.transform(10 % P7)), 10 % P7);
    }
}