Skip to main content

num_modular/
mersenne.rs

1use crate::reduced::impl_reduced_binary_pow;
2use crate::{udouble, umax, ModularUnaryOps, Reducer};
3
/// Generates the inherent reduction helpers and the [`Reducer`] implementation for a
/// fixed (pseudo-)Mersenne reducer type with modulus `2^P - K`.
///
/// Arguments of the public arm:
/// * `$TypeName` - the `struct` (generic over `const P: u8, const K: $T`) to implement for;
/// * `$T` - the operand / residue type;
/// * `$D` - the double-width type holding a full `$T * $T` product;
/// * `$half_bits` - half the bit width of `$T`; while `P` is below it a product of two
///   residues still fits in `$T` and is reduced without widening;
/// * `$max_P` - largest supported `P` (debug-asserted in `new`);
/// * `$kind` - `primitive` when `$D` is a built-in integer, `udouble` when it is the
///   two-word `udouble` type.
///
/// The `@…` arms are internal helpers selected by `$kind`.
macro_rules! impl_fixed_mersenne {
    (
        $TypeName:ident,
        $T:ty,
        $D:ty,
        $half_bits:expr,
        $max_P:expr,
        $kind:ident
    ) => {
        impl<const P: u8, const K: $T> $TypeName<P, K> {
            /// Mask selecting the low `P` bits; all ones when `1 << P` does not fit in the
            /// operand type (i.e. `P` is the full bit width).
            const BITMASK: $T = match (1 as $T).checked_shl(P as u32) {
                Some(v) => v.wrapping_sub(1),
                None => <$T>::MAX,
            };
            /// The modulus `2^P - K`. When `2^P` itself does not fit in the operand type
            /// it is represented as `0` and the subtraction wraps around to `2^P - K`.
            pub const MODULUS: $T = {
                let p1 = match (1 as $T).checked_shl(P as u32) {
                    Some(v) => v,
                    None => 0,
                };
                p1.wrapping_sub(K)
            };

            /// Worst-case fold count for `reduce_double`.
            /// Each fold replaces V = hi·2^P + lo with hi·K + lo (since 2^P ≡ K).
            /// For K = 1: always 2 folds (the carry chain terminates in at most one
            /// extra step). For K > 1: ⌈P/(P−s)⌉ + 1 folds, where s is the bit-width of K.
            const FOLDS: u32 = if K == 1 {
                2
            } else {
                let s = <$T>::BITS - K.leading_zeros(); // bit-width of K
                let gap = P as u32 - s;
                let folds_ceil = (P as u32 + gap - 1) / gap;
                folds_ceil + 1
            };

            /// Reduces a single-width value modulo `MODULUS` by repeatedly folding the
            /// bits above position `P` back into the low part, followed by one
            /// conditional subtraction.
            const fn reduce_single(v: $T) -> $T {
                let mut lo = v & Self::BITMASK;
                // `checked_shr` yields `None` when P is the full bit width: nothing lies above P
                let mut hi = match v.checked_shr(P as u32) {
                    Some(s) => s,
                    None => 0,
                };
                while hi > 0 {
                    let sum = if K == 1 { hi + lo } else { hi * K + lo };
                    lo = sum & Self::BITMASK;
                    hi = match sum.checked_shr(P as u32) {
                        Some(s) => s,
                        None => 0,
                    };
                }
                // lo < 2^P = MODULUS + K, so a single subtraction completes the reduction
                if lo >= Self::MODULUS {
                    lo - Self::MODULUS
                } else {
                    lo
                }
            }

            // `reduce_double` depends on how the double-width type is represented
            impl_fixed_mersenne!(@reduce_double, $kind, $T, $D);
        }

        impl<const P: u8, const K: $T> Reducer<$T> for $TypeName<P, K> {
            /// Creates the reducer.
            ///
            /// Panics if `m` differs from `2^P - K`. In debug builds it additionally
            /// checks `P <= $max_P`, that `K` is odd with `0 < K < 2^(P-1)`, and that the
            /// modulus is not a proper multiple of 3, 5, 7, 11 or 13.
            #[inline]
            fn new(m: &$T) -> Self {
                assert!(
                    *m == Self::MODULUS,
                    "the given modulus doesn't match with the generic params"
                );
                debug_assert!(P <= $max_P);
                debug_assert!(K > 0 && K < (2 as $T).pow(P as u32 - 1) && K % 2 == 1);
                // error on easy composites, while still accepting the small primes themselves
                const SMALL_PRIMES: [$T; 5] = [3, 5, 7, 11, 13];
                debug_assert!(SMALL_PRIMES
                    .iter()
                    .all(|&q| Self::MODULUS == q || Self::MODULUS % q != 0));
                Self {}
            }
            /// Maps an arbitrary value to its residue modulo `MODULUS`.
            #[inline]
            fn transform(&self, target: $T) -> $T {
                Self::reduce_single(target)
            }
            /// Returns whether `target` is already a reduced residue.
            #[inline]
            fn check(&self, target: &$T) -> bool {
                *target < Self::MODULUS
            }
            /// Residues are stored as plain integers, so this is the identity.
            #[inline]
            fn residue(&self, target: $T) -> $T {
                target
            }
            #[inline]
            fn modulus(&self) -> $T {
                Self::MODULUS
            }
            #[inline]
            fn is_zero(&self, target: &$T) -> bool {
                target == &0
            }

            /// Modular addition of two reduced residues.
            #[inline]
            fn add(&self, lhs: &$T, rhs: &$T) -> $T {
                // When P is the full bit width the true sum may not fit in `$T`. The
                // wrapped sum is then off by exactly 2^BITS, which the wrapping
                // subtraction of MODULUS compensates for.
                let (sum, overflowed) = lhs.overflowing_add(*rhs);
                if overflowed || sum >= Self::MODULUS {
                    sum.wrapping_sub(Self::MODULUS)
                } else {
                    sum
                }
            }
            /// Modular subtraction of two reduced residues.
            #[inline]
            fn sub(&self, lhs: &$T, rhs: &$T) -> $T {
                if lhs >= rhs {
                    lhs - rhs
                } else {
                    Self::MODULUS - (rhs - lhs)
                }
            }
            /// Modular doubling, implemented as `add(target, target)`.
            #[inline]
            fn dbl(&self, target: $T) -> $T {
                self.add(&target, &target)
            }
            /// Modular negation; zero maps to zero so the result stays below `MODULUS`.
            #[inline]
            fn neg(&self, target: $T) -> $T {
                if target == 0 {
                    0
                } else {
                    Self::MODULUS - target
                }
            }
            /// Modular multiplication of two reduced residues.
            #[inline]
            fn mul(&self, lhs: &$T, rhs: &$T) -> $T {
                if (P as u32) < $half_bits {
                    // both operands are below 2^P, so the product fits in `$T`
                    Self::reduce_single(lhs * rhs)
                } else {
                    Self::reduce_double(impl_fixed_mersenne!(@widen_mul, $kind, $T, $D, lhs, rhs))
                }
            }
            /// Modular inverse via `invm`; the computation is done in `usize` whenever
            /// the modulus is guaranteed to fit there.
            #[inline]
            fn inv(&self, target: $T) -> Option<$T> {
                if (P as u32) < usize::BITS {
                    (target as usize)
                        .invm(&(Self::MODULUS as usize))
                        .map(|v| v as $T)
                } else {
                    target.invm(&Self::MODULUS)
                }
            }
            /// Modular squaring of a reduced residue.
            #[inline]
            fn sqr(&self, target: $T) -> $T {
                if (P as u32) < $half_bits {
                    Self::reduce_single(target * target)
                } else {
                    Self::reduce_double(impl_fixed_mersenne!(@widen_sqr, $kind, $T, $D, target))
                }
            }

            impl_reduced_binary_pow!($T);
        }
    };

    // Internal: reduce_double for primitive double-width types (u32→u64, u64→u128)
    //
    // For real pseudo-Mersennes, FOLDS is always ≤ 3 (K=1 → 2; small K → 3).
    // Unrolling replaces the data-dependent while loop with straight-line folds.
    // Extra folds past the true count are no-ops (hi reaches 0).
    (@reduce_double, primitive, $T:ty, $D:ty) => {
        /// Reduces a double-width value modulo `MODULUS`. The callers in this impl pass
        /// the product of two reduced operands, which is the bound the unrolled fold
        /// counts are derived for.
        // the `hi` written by the last unrolled fold is intentionally never read
        #[allow(unused_assignments)]
        fn reduce_double(v: $D) -> $T {
            let mut lo = (v as $T) & Self::BITMASK;
            let mut hi = v >> P;
            // one fold step, carried out in the double-width type so hi·K cannot overflow
            macro_rules! mersenne_fold {
                () => {
                    let sum = if K == 1 {
                        hi + lo as $D
                    } else {
                        hi * (K as $D) + lo as $D
                    };
                    lo = (sum as $T) & Self::BITMASK;
                    hi = sum >> P;
                };
            }
            if Self::FOLDS <= 2 {
                mersenne_fold!();
                mersenne_fold!();
            } else if Self::FOLDS == 3 {
                mersenne_fold!();
                mersenne_fold!();
                mersenne_fold!();
            } else {
                while hi > 0 {
                    mersenne_fold!();
                }
            }
            if lo >= Self::MODULUS {
                lo - Self::MODULUS
            } else {
                lo
            }
        }
    };

    // Internal: reduce_double for udouble (u128→udouble)
    //
    // Phase 1 folds in double-width arithmetic. It runs while the high part does not fit
    // in one word, and also for as long as a single-word fold could overflow.
    // Phase 2 uses single-word arithmetic and is unrolled when FOLDS ≤ 3.
    (@reduce_double, udouble, $T:ty, $D:ty) => {
        /// Reduces a double-width value modulo `MODULUS`. The callers in this impl pass
        /// the product of two reduced operands, which is the bound the unrolled fold
        /// counts are derived for.
        // the `hi` written by the last unrolled fold is intentionally never read
        #[allow(unused_assignments)]
        fn reduce_double(v: $D) -> $T {
            // With hi < 2^P and lo < 2^P, a fold hi·K + lo stays below 2^(P + k_bits).
            // Single-word folds are therefore only safe when that bound fits in `$T`.
            let k_bits = <$T>::BITS - K.leading_zeros();
            let single_word_ok = (P as u32) + k_bits <= <$T>::BITS;

            let mut lo = v.lo & Self::BITMASK;
            let mut hi = v >> P;
            while hi.hi > 0 || (!single_word_ok && hi.lo > 0) {
                let sum = if K == 1 { hi + lo } else { hi * K + lo };
                lo = sum.lo & Self::BITMASK;
                hi = sum >> P;
            }
            // If the wide loop ran to completion, hi is 0 here and the folds below are no-ops.
            let mut hi = hi.lo;
            macro_rules! mersenne_u128_fold {
                () => {
                    let sum = if K == 1 { hi + lo } else { hi * K + lo };
                    lo = sum & Self::BITMASK;
                    hi = match sum.checked_shr(P as u32) {
                        Some(s) => s,
                        None => 0,
                    };
                };
            }
            if Self::FOLDS <= 2 {
                mersenne_u128_fold!();
                mersenne_u128_fold!();
            } else if Self::FOLDS == 3 {
                mersenne_u128_fold!();
                mersenne_u128_fold!();
                mersenne_u128_fold!();
            } else {
                while hi > 0 {
                    mersenne_u128_fold!();
                }
            }
            if lo >= Self::MODULUS {
                lo - Self::MODULUS
            } else {
                lo
            }
        }
    };

    // Internal: widening multiplication for primitive types
    (@widen_mul, primitive, $T:ty, $D:ty, $lhs:expr, $rhs:expr) => {
        (*$lhs as $D) * (*$rhs as $D)
    };

    // Internal: widening multiplication for udouble
    (@widen_mul, udouble, $T:ty, $D:ty, $lhs:expr, $rhs:expr) => {
        <$D>::widening_mul(*$lhs, *$rhs)
    };

    // Internal: widening square for primitive types
    (@widen_sqr, primitive, $T:ty, $D:ty, $target:expr) => {
        ($target as $D) * ($target as $D)
    };

    // Internal: widening square for udouble
    (@widen_sqr, udouble, $T:ty, $D:ty, $target:expr) => {
        <$D>::widening_square($target)
    };
}
262
/// Compile-time reducer for moduli of the form `2^P - K` (Mersenne and pseudo-Mersenne
/// numbers) operating on 32-bit values.
///
/// `P` may be as large as 32 and `K` must stay below `2^(P-1)`; every value going in or
/// coming out is a `u32`. Modular inversion and Fermat-based operations are only valid
/// when `2^P - K` is prime.
///
/// # Example
///
/// ```rust
/// use num_modular::{FixedMersenne32, Reducer};
///
/// const P: u8 = 13;
/// const K: u32 = 1;
/// let modulus = (1u32 << P) - K; // 2^13 - 1 = 8191 (Mersenne prime)
/// let reducer = FixedMersenne32::<P, K>::new(&modulus);
/// let a = reducer.transform(100);
/// let b = reducer.transform(200);
/// assert_eq!(reducer.residue(reducer.add(&a, &b)), 300 % modulus);
/// ```
#[derive(Clone, Copy, Debug)]
pub struct FixedMersenne32<const P: u8, const K: u32>();
283
// Implement the reduction helpers and `Reducer<u32>` for `FixedMersenne32`: `u64` is the
// double-width product type, products are reduced directly in `u32` while `P < 16`, and
// `new` debug-asserts `P <= 32`.
impl_fixed_mersenne!(FixedMersenne32, u32, u64, 16, 32, primitive);
285
/// Compile-time reducer for moduli of the form `2^P - K` (Mersenne and pseudo-Mersenne
/// numbers) operating on 64-bit values.
///
/// `P` may be as large as 64 and `K` must stay below `2^(P-1)`; every value going in or
/// coming out is a `u64`. Wide products are formed and reduced in `u128`. Modular
/// inversion and Fermat-based operations are only valid when `2^P - K` is prime.
///
/// # Example
///
/// ```rust
/// use num_modular::{FixedMersenne64, Reducer};
///
/// const P: u8 = 61;
/// const K: u64 = 1;
/// let modulus = (1u64 << P) - K; // 2^61 - 1 (Mersenne prime)
/// let reducer = FixedMersenne64::<P, K>::new(&modulus);
/// let a = reducer.transform(1000);
/// let b = reducer.transform(2000);
/// assert_eq!(reducer.residue(reducer.mul(&a, &b)), (1000u64 * 2000) % modulus);
/// ```
#[derive(Clone, Copy, Debug)]
pub struct FixedMersenne64<const P: u8, const K: u64>();
307
// Implement the reduction helpers and `Reducer<u64>` for `FixedMersenne64`: `u128` is the
// double-width product type, products are reduced directly in `u64` while `P < 32`, and
// `new` debug-asserts `P <= 64`.
impl_fixed_mersenne!(FixedMersenne64, u64, u128, 32, 64, primitive);
309
310/// A modular reducer for (pseudo) Mersenne numbers `2^P - K` as modulus.
311///
312/// Supports `P` up to 128 and `K < 2^(P-1)`. All inputs and outputs are [umax] (currently `u128`).
313///
314/// The `P` is limited to 128 so that overflow checks aren't necessary. This covers all Mersenne
315/// primes within the range of [umax] (i.e. `u128`).
316///
317/// # Example
318///
319/// ```rust
320/// use num_modular::{FixedMersenne, Reducer, umax};
321///
322/// const P: u8 = 31;
323/// const K: umax = 1;
324/// let modulus = (1 << P) - K; // 2^31 - 1 (Mersenne prime)
325/// let reducer = FixedMersenne::<P, K>::new(&modulus);
326/// let a = reducer.transform(1000);
327/// let b = reducer.transform(2000);
328/// assert_eq!(reducer.residue(reducer.mul(&a, &b)), (1000 * 2000) % modulus);
329/// ```
330#[derive(Debug, Clone, Copy)]
331pub struct FixedMersenne<const P: u8, const K: umax>();
332
// Implement the reduction helpers and `Reducer<umax>` for `FixedMersenne`: the two-word
// `udouble` is the double-width product type, products are reduced directly in `umax`
// while `P < 64`, and `new` debug-asserts `P <= 128`.
impl_fixed_mersenne!(FixedMersenne, umax, udouble, 64, 128, udouble);
334
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ModularCoreOps, ModularPow};
    use rand::random;

    // reducers over umax (u128)
    type M1 = FixedMersenne<31, 1>;
    type M2 = FixedMersenne<61, 1>;
    type M3 = FixedMersenne<127, 1>;
    type M4 = FixedMersenne<32, 5>;
    type M5 = FixedMersenne<56, 5>;
    type M6 = FixedMersenne<122, 3>;
    type M7 = FixedMersenne<128, 159>;

    // reducers over u64
    type M64_1 = FixedMersenne64<31, 1>;
    type M64_2 = FixedMersenne64<61, 1>;
    type M64_3 = FixedMersenne64<32, 5>;
    type M64_4 = FixedMersenne64<64, 59>;

    // reducers over u32
    type M32_1 = FixedMersenne32<13, 1>;
    type M32_2 = FixedMersenne32<31, 1>;
    type M32_3 = FixedMersenne32<16, 5>;

    const NRANDOM: u32 = 10;

    /// For each `Reducer => modulus` pair: build the reducer from the given modulus and
    /// check that transforming `$value` yields `$value % modulus`.
    macro_rules! assert_transform_is_remainder {
        ($T:ty, $value:expr; $($M:ty => $modulus:expr),+ $(,)?) => {$({
            const MODULUS: $T = $modulus;
            let reducer = <$M>::new(&MODULUS);
            assert_eq!(reducer.residue(reducer.transform($value)), $value % MODULUS);
        })+};
    }

    /// For each listed reducer type: compare every `Reducer` operation on the transformed
    /// operands against the corresponding generic modular operation on the raw integers.
    macro_rules! assert_agrees_with_modops {
        ($T:ty; $a:ident, $b:ident, $e:ident; $($M:ty),+ $(,)?) => {$({
            const MODULUS: $T = <$M>::MODULUS;
            let reducer = <$M>::new(&MODULUS);
            let lhs = reducer.transform($a);
            let rhs = reducer.transform($b);
            assert_eq!(reducer.add(&lhs, &rhs), $a.addm($b, &MODULUS));
            assert_eq!(reducer.sub(&lhs, &rhs), $a.subm($b, &MODULUS));
            assert_eq!(reducer.mul(&lhs, &rhs), $a.mulm($b, &MODULUS));
            assert_eq!(reducer.neg(lhs), $a.negm(&MODULUS));
            assert_eq!(reducer.inv(lhs), $a.invm(&MODULUS));
            assert_eq!(reducer.dbl(lhs), $a.dblm(&MODULUS));
            assert_eq!(reducer.sqr(lhs), $a.sqm(&MODULUS));
            assert_eq!(reducer.pow(lhs, &$e), $a.powm($e, &MODULUS));
        })+};
    }

    #[test]
    fn creation_test_u128() {
        // fixed values around the modulus boundary
        const P: umax = (1 << 31) - 1;
        let m = M1::new(&P);
        let boundary_cases: [(umax, umax); 5] =
            [(0, 0), (1, 1), (P, 0), (P - 1, P - 1), (P + 1, 1)];
        for &(input, expected) in boundary_cases.iter() {
            assert_eq!(m.residue(m.transform(input)), expected);
        }

        for _ in 0..NRANDOM {
            let a = random::<umax>();
            assert_transform_is_remainder!(umax, a;
                M1 => (1 << 31) - 1,
                M2 => (1 << 61) - 1,
                M3 => (1 << 127) - 1,
                M4 => (1 << 32) - 5,
                M5 => (1 << 56) - 5,
                M6 => (1 << 122) - 3,
                M7 => M7::MODULUS,
            );
        }
    }

    #[test]
    fn creation_test_u64() {
        for _ in 0..NRANDOM {
            let a = random::<u64>();
            assert_transform_is_remainder!(u64, a;
                M64_1 => (1 << 31) - 1,
                M64_2 => (1 << 61) - 1,
                M64_3 => (1 << 32) - 5,
                M64_4 => M64_4::MODULUS,
            );
        }
    }

    #[test]
    fn creation_test_u32() {
        for _ in 0..NRANDOM {
            let a = random::<u32>();
            assert_transform_is_remainder!(u32, a;
                M32_1 => (1 << 13) - 1,
                M32_2 => (1 << 31) - 1,
                M32_3 => (1 << 16) - 5,
            );
        }
    }

    #[test]
    fn test_against_modops_u128() {
        for _ in 0..NRANDOM {
            let (a, b) = (random::<u128>(), random::<u128>());
            let e = random::<u8>() as umax;
            assert_agrees_with_modops!(umax; a, b, e; M1, M2, M3, M4, M5, M6);
        }
    }

    #[test]
    fn test_against_modops_u64() {
        for _ in 0..NRANDOM {
            let a = random::<u64>();
            let b = random::<u64>();
            let e = random::<u8>() as u64;
            assert_agrees_with_modops!(u64; a, b, e; M64_1, M64_2, M64_3);
        }
    }

    #[test]
    fn test_against_modops_u32() {
        for _ in 0..NRANDOM {
            let a = random::<u32>();
            let b = random::<u32>();
            let e = random::<u8>() as u32;
            assert_agrees_with_modops!(u32; a, b, e; M32_1, M32_2, M32_3);
        }
    }
}