Skip to main content

num_modular/
mersenne.rs

1use crate::reduced::impl_reduced_binary_pow;
2use crate::{udouble, umax, ModularUnaryOps, Reducer};
3
// Implements the inherent helpers (`MODULUS`, `reduce_single`, `reduce_double`) and the
// `Reducer` trait for a fixed (pseudo-)Mersenne reducer type whose modulus `2^P - K` is
// encoded in its const generics.
//
// Public-arm arguments, in order:
// - `$TypeName`: the reducer type, declared as `struct $TypeName<const P: u8, const K: $T>()`;
// - `$T`: the single-width word type holding residues;
// - `$D`: the double-width type holding the full product of two words;
// - `$half_bits`: half the bit width of `$T`; when `P` is below it, the product of two
//   residues (each `< 2^P`) fits in `$T`, so the cheaper `reduce_single` is used;
// - `$max_P`: the largest supported `P` (the bit width of `$T`);
// - `$kind`: `primitive` when `$D` is a built-in integer, `udouble` for the crate's `udouble`.
//
// The `@...` arms are internal helpers selected by `$kind`.
macro_rules! impl_fixed_mersenne {
    (
        $TypeName:ident,
        $T:ty,
        $D:ty,
        $half_bits:expr,
        $max_P:expr,
        $kind:ident
    ) => {
        impl<const P: u8, const K: $T> $TypeName<P, K> {
            /// Mask selecting the low `P` bits, i.e. `2^P - 1` (all ones when `P` equals the
            /// word width, where `2^P` itself is not representable).
            const BITMASK: $T = match (1 as $T).checked_shl(P as u32) {
                Some(v) => v.wrapping_sub(1),
                None => <$T>::MAX,
            };

            /// The modulus `2^P - K`.
            ///
            /// When `P` equals the word width, `2^P` is represented as zero and the wrapping
            /// subtraction yields `2^P - K` exactly.
            pub const MODULUS: $T = {
                let p1 = match (1 as $T).checked_shl(P as u32) {
                    Some(v) => v,
                    None => 0,
                };
                p1.wrapping_sub(K)
            };

            /// Worst-case fold count for `reduce_double` on inputs below `2^(2P)`, which
            /// covers every product of two residues.
            ///
            /// Each fold replaces `V = hi·2^P + lo` with `hi·K + lo` (since `2^P ≡ K`).
            /// For `K = 1`: two folds always suffice (the second only absorbs one carry).
            /// For `K > 1`: each fold shrinks `hi` by `P - bits(K)` bits, where `bits(K)` is
            /// the bit width of `K`, giving `⌈P / (P - bits(K))⌉ + 1` folds.
            const FOLDS: u32 = if K == 1 {
                2
            } else {
                let s = <$T>::BITS - K.leading_zeros(); // bit width of K
                let gap = P as u32 - s;
                let folds_ceil = (P as u32 + gap - 1) / gap;
                folds_ceil + 1
            };

            /// Reduces a single-width value `v` modulo `2^P - K`.
            ///
            /// Any `v` is accepted: the bits above `P` are folded down (using `2^P ≡ K`)
            /// until none remain, followed by one conditional subtraction of the modulus.
            ///
            /// For the result of a widening multiplication or square, use
            /// [`reduce_double`](Self::reduce_double) instead.
            pub const fn reduce_single(v: $T) -> $T {
                let mut lo = v & Self::BITMASK;
                // `checked_shr` returns `None` when `P` equals the word width: no high bits.
                let mut hi = match v.checked_shr(P as u32) {
                    Some(s) => s,
                    None => 0,
                };
                while hi > 0 {
                    let sum = if K == 1 { hi + lo } else { hi * K + lo };
                    lo = sum & Self::BITMASK;
                    hi = match sum.checked_shr(P as u32) {
                        Some(s) => s,
                        None => 0,
                    };
                }
                // lo <= 2^P - 1, so at most one subtraction brings it below 2^P - K.
                if lo >= Self::MODULUS {
                    lo - Self::MODULUS
                } else {
                    lo
                }
            }

            impl_fixed_mersenne!(@reduce_double, $kind, $T, $D);
        }

        impl<const P: u8, const K: $T> Reducer<$T> for $TypeName<P, K> {
            // Residues are stored in plain form: `transform` reduces the input, `residue` is
            // the identity, and for operands in `[0, MODULUS)` every operation below returns
            // a value in `[0, MODULUS)`.

            /// Creates the reducer after checking that `m` equals `2^P - K`.
            ///
            /// # Panics
            ///
            /// Panics if `*m` differs from the const-generic modulus. Debug builds also panic
            /// when `P` exceeds the word width, when `K` is not an odd value in
            /// `(0, 2^(P-1))`, or when the modulus is divisible by 3, 5, 7, 11 or 13.
            #[inline]
            fn new(m: &$T) -> Self {
                assert!(
                    *m == Self::MODULUS,
                    "the given modulus doesn't match with the generic params"
                );
                debug_assert!(P <= $max_P);
                debug_assert!(K > 0 && K < (2 as $T).pow(P as u32 - 1) && K % 2 == 1);
                debug_assert!(
                    Self::MODULUS % 3 != 0
                        && Self::MODULUS % 5 != 0
                        && Self::MODULUS % 7 != 0
                        && Self::MODULUS % 11 != 0
                        && Self::MODULUS % 13 != 0
                ); // error on easy composites
                Self {}
            }

            /// Reduces an arbitrary word into `[0, MODULUS)`.
            #[inline]
            fn transform(&self, target: $T) -> $T {
                Self::reduce_single(target)
            }

            #[inline]
            fn check(&self, target: &$T) -> bool {
                *target < Self::MODULUS
            }

            /// Residues are already in plain form, so this is the identity.
            #[inline]
            fn residue(&self, target: $T) -> $T {
                target
            }

            #[inline]
            fn modulus(&self) -> $T {
                Self::MODULUS
            }

            #[inline]
            fn is_zero(&self, target: &$T) -> bool {
                target == &0
            }

            /// Modular addition.
            ///
            /// When `P` equals the word width (e.g. `2^64 - 59` on `u64`), `lhs + rhs` can
            /// exceed the word range. The true sum is still below `2 * MODULUS`, so a single
            /// wrapping subtraction of the modulus yields the exact residue whether or not
            /// the addition carried out of the word.
            #[inline]
            fn add(&self, lhs: &$T, rhs: &$T) -> $T {
                let (sum, carry) = (*lhs).overflowing_add(*rhs);
                if carry || sum >= Self::MODULUS {
                    sum.wrapping_sub(Self::MODULUS)
                } else {
                    sum
                }
            }

            /// Modular subtraction; borrows are resolved by adding the modulus back.
            #[inline]
            fn sub(&self, lhs: &$T, rhs: &$T) -> $T {
                if lhs >= rhs {
                    lhs - rhs
                } else {
                    Self::MODULUS - (rhs - lhs)
                }
            }

            #[inline]
            fn dbl(&self, target: $T) -> $T {
                self.add(&target, &target)
            }

            #[inline]
            fn neg(&self, target: $T) -> $T {
                if target == 0 {
                    0
                } else {
                    Self::MODULUS - target
                }
            }

            /// Modular multiplication.
            #[inline]
            fn mul(&self, lhs: &$T, rhs: &$T) -> $T {
                if (P as u32) < $half_bits {
                    // Both residues are below 2^P, so their product fits in a single word.
                    Self::reduce_single(lhs * rhs)
                } else {
                    Self::reduce_double(impl_fixed_mersenne!(@widen_mul, $kind, $T, $D, lhs, rhs))
                }
            }

            /// Modular inverse, delegated to `ModularUnaryOps::invm`.
            #[inline]
            fn inv(&self, target: $T) -> Option<$T> {
                if (P as u32) < usize::BITS {
                    // The modulus (and any residue) fits in `usize`, so invert there.
                    (target as usize)
                        .invm(&(Self::MODULUS as usize))
                        .map(|v| v as $T)
                } else {
                    target.invm(&Self::MODULUS)
                }
            }

            /// Modular squaring.
            #[inline]
            fn sqr(&self, target: $T) -> $T {
                if (P as u32) < $half_bits {
                    // The residue is below 2^P, so its square fits in a single word.
                    Self::reduce_single(target * target)
                } else {
                    Self::reduce_double(impl_fixed_mersenne!(@widen_sqr, $kind, $T, $D, target))
                }
            }

            impl_reduced_binary_pow!($T);
        }
    };

    // Internal: reduce_double for built-in double-width types (u32 -> u64, u64 -> u128).
    //
    // For real pseudo-Mersennes FOLDS is at most 3 (K = 1 -> 2; small K -> 3), so those folds
    // are unrolled into straight-line code; a fold past the true count is a no-op because `hi`
    // is already zero. The trailing `while hi > 0` loop finishes inputs of 2^(2P) or more
    // (reachable through the public `reduce_double`) and moduli needing more than three folds;
    // for products of two residues its condition is false on entry.
    //
    // No fold overflows the double-width type: with B the word width, hi <= 2^(2B - P) - 1 and
    // K < 2^P give hi * K + lo <= 2^(2B) - 1, and the next hi obeys the same bound.
    (@reduce_double, primitive, $T:ty, $D:ty) => {
        /// Reduces a double-width value `v` modulo `2^P - K`.
        ///
        /// This handles widening-multiplication or widening-square results, and more
        /// generally any double-width value. For single-width values, use
        /// [`reduce_single`](Self::reduce_single).
        pub fn reduce_double(v: $D) -> $T {
            let mut lo = (v as $T) & Self::BITMASK;
            let mut hi = v >> P;
            // One fold: V = hi·2^P + lo  ->  hi·K + lo.
            macro_rules! mersenne_fold {
                () => {
                    let sum = if K == 1 {
                        hi + lo as $D
                    } else {
                        hi * (K as $D) + lo as $D
                    };
                    lo = (sum as $T) & Self::BITMASK;
                    hi = sum >> P;
                };
            }
            mersenne_fold!();
            mersenne_fold!();
            if Self::FOLDS >= 3 {
                mersenne_fold!();
            }
            while hi > 0 {
                mersenne_fold!();
            }
            if lo >= Self::MODULUS {
                lo - Self::MODULUS
            } else {
                lo
            }
        }
    };

    // Internal: reduce_double for the crate's udouble (u128 -> udouble).
    //
    // Phase 1 folds in double-width arithmetic for as long as a single-word fold could
    // overflow: while `hi` does not fit in one word, or while `hi * K + lo` might exceed the
    // word range. The latter matters whenever P + bits(K) > 128 (e.g. 2^128 - 159), where
    // the first fold of a full product would otherwise overflow u128. After phase 1,
    // hi <= FOLD_LIMIT, and every later single-word fold stays within that bound (K < 2^P).
    // Phase 2 continues in u128 arithmetic, unrolled for FOLDS <= 3, with a loop fallback.
    (@reduce_double, udouble, $T:ty, $D:ty) => {
        /// Largest `hi` for which a single-word fold `hi * K + lo` (with `lo <= BITMASK`)
        /// cannot overflow. With `K == 0` the product vanishes, so every `hi` is safe.
        const FOLD_LIMIT: $T = match (<$T>::MAX - Self::BITMASK).checked_div(K) {
            Some(limit) => limit,
            None => <$T>::MAX,
        };

        /// Reduces a double-width value `v` modulo `2^P - K`.
        ///
        /// This handles widening-multiplication or widening-square results, and more
        /// generally any double-width value. For single-width values, use
        /// [`reduce_single`](Self::reduce_single).
        pub fn reduce_double(v: $D) -> $T {
            let mut lo = v.lo & Self::BITMASK;
            let mut hi = v >> P;
            while hi.hi > 0 || hi.lo > Self::FOLD_LIMIT {
                let sum = if K == 1 { hi + lo } else { hi * K + lo };
                lo = sum.lo & Self::BITMASK;
                hi = sum >> P;
            }
            let mut hi = hi.lo;
            // One single-word fold; `checked_shr` yields `None` (no high bits) when P = 128.
            macro_rules! mersenne_u128_fold {
                () => {
                    let sum = if K == 1 { hi + lo } else { hi * K + lo };
                    lo = sum & Self::BITMASK;
                    hi = match sum.checked_shr(P as u32) {
                        Some(s) => s,
                        None => 0,
                    };
                };
            }
            mersenne_u128_fold!();
            mersenne_u128_fold!();
            if Self::FOLDS >= 3 {
                mersenne_u128_fold!();
            }
            while hi > 0 {
                mersenne_u128_fold!();
            }
            if lo >= Self::MODULUS {
                lo - Self::MODULUS
            } else {
                lo
            }
        }
    };

    // Internal: widening multiplication for primitive types
    (@widen_mul, primitive, $T:ty, $D:ty, $lhs:expr, $rhs:expr) => {
        (*$lhs as $D) * (*$rhs as $D)
    };

    // Internal: widening multiplication for udouble
    (@widen_mul, udouble, $T:ty, $D:ty, $lhs:expr, $rhs:expr) => {
        <$D>::widening_mul(*$lhs, *$rhs)
    };

    // Internal: widening square for primitive types
    (@widen_sqr, primitive, $T:ty, $D:ty, $target:expr) => {
        ($target as $D) * ($target as $D)
    };

    // Internal: widening square for udouble
    (@widen_sqr, udouble, $T:ty, $D:ty, $target:expr) => {
        <$D>::widening_square($target)
    };
}
274
275/// A modular reducer for (pseudo) Mersenne numbers `2^P - K` as modulus with 32-bit operands.
276///
277/// Supports `P` up to 32 and `K < 2^(P-1)`. All inputs and outputs are `u32`.
278/// The modulus `2^P - K` must be prime for modular inverse and Fermat-based operations to be valid.
279///
280/// # Example
281///
282/// ```rust
283/// use num_modular::{FixedMersenne32, Reducer};
284///
285/// const P: u8 = 13;
286/// const K: u32 = 1;
287/// let modulus = (1u32 << P) - K; // 2^13 - 1 = 8191 (Mersenne prime)
288/// let reducer = FixedMersenne32::<P, K>::new(&modulus);
289/// let a = reducer.transform(100);
290/// let b = reducer.transform(200);
291/// assert_eq!(reducer.residue(reducer.add(&a, &b)), 300 % modulus);
292/// ```
293#[derive(Debug, Clone, Copy)]
294pub struct FixedMersenne32<const P: u8, const K: u32>();
295
296impl_fixed_mersenne!(FixedMersenne32, u32, u64, 16, 32, primitive);
297
298/// A modular reducer for (pseudo) Mersenne numbers `2^P - K` as modulus with 64-bit operands.
299///
300/// Supports `P` up to 64 and `K < 2^(P-1)`. All inputs and outputs are `u64`.
301/// Uses `u128` as the double-width intermediate for multiplication and reduction.
302/// The modulus `2^P - K` must be prime for modular inverse and Fermat-based operations to be valid.
303///
304/// # Example
305///
306/// ```rust
307/// use num_modular::{FixedMersenne64, Reducer};
308///
309/// const P: u8 = 61;
310/// const K: u64 = 1;
311/// let modulus = (1u64 << P) - K; // 2^61 - 1 (Mersenne prime)
312/// let reducer = FixedMersenne64::<P, K>::new(&modulus);
313/// let a = reducer.transform(1000);
314/// let b = reducer.transform(2000);
315/// assert_eq!(reducer.residue(reducer.mul(&a, &b)), (1000u64 * 2000) % modulus);
316/// ```
317#[derive(Debug, Clone, Copy)]
318pub struct FixedMersenne64<const P: u8, const K: u64>();
319
320impl_fixed_mersenne!(FixedMersenne64, u64, u128, 32, 64, primitive);
321
322/// A modular reducer for (pseudo) Mersenne numbers `2^P - K` as modulus.
323///
324/// Supports `P` up to 128 and `K < 2^(P-1)`. All inputs and outputs are [umax] (currently `u128`).
325///
326/// The `P` is limited to 128 so that overflow checks aren't necessary. This covers all Mersenne
327/// primes within the range of [umax] (i.e. `u128`).
328///
329/// # Example
330///
331/// ```rust
332/// use num_modular::{FixedMersenne, Reducer, umax};
333///
334/// const P: u8 = 31;
335/// const K: umax = 1;
336/// let modulus = (1 << P) - K; // 2^31 - 1 (Mersenne prime)
337/// let reducer = FixedMersenne::<P, K>::new(&modulus);
338/// let a = reducer.transform(1000);
339/// let b = reducer.transform(2000);
340/// assert_eq!(reducer.residue(reducer.mul(&a, &b)), (1000 * 2000) % modulus);
341/// ```
342#[derive(Debug, Clone, Copy)]
343pub struct FixedMersenne<const P: u8, const K: umax>();
344
345impl_fixed_mersenne!(FixedMersenne, umax, udouble, 64, 128, udouble);
346
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ModularCoreOps, ModularPow};
    use rand::random;

    // u128 (`umax`) reducers
    type M1 = FixedMersenne<31, 1>;
    type M2 = FixedMersenne<61, 1>;
    type M3 = FixedMersenne<127, 1>;
    type M4 = FixedMersenne<32, 5>;
    type M5 = FixedMersenne<56, 5>;
    type M6 = FixedMersenne<122, 3>;
    type M7 = FixedMersenne<128, 159>;

    // u64 reducers
    type M64_1 = FixedMersenne64<31, 1>;
    type M64_2 = FixedMersenne64<61, 1>;
    type M64_3 = FixedMersenne64<32, 5>;
    type M64_4 = FixedMersenne64<64, 59>;

    // u32 reducers
    type M32_1 = FixedMersenne32<13, 1>;
    type M32_2 = FixedMersenne32<31, 1>;
    type M32_3 = FixedMersenne32<16, 5>;

    /// Number of random samples drawn per test.
    const NRANDOM: u32 = 10;

    // Checks `transform`/`residue` of each listed reducer (word type `$T`) on the boundary
    // values 0, 1, p - 1, p, p + 1 and the largest word value.
    macro_rules! check_boundaries {
        ($T:ty; $($M:ty)*) => ($({
            const P: $T = <$M>::MODULUS;
            let m = <$M>::new(&P);
            assert_eq!(m.residue(m.transform(0)), 0);
            assert_eq!(m.residue(m.transform(1)), 1);
            assert_eq!(m.residue(m.transform(P - 1)), P - 1);
            assert_eq!(m.residue(m.transform(P)), 0);
            assert_eq!(m.residue(m.transform(P + 1)), 1);
            assert_eq!(m.residue(m.transform(<$T>::MAX)), <$T>::MAX % P);
        })*);
    }

    // Checks that each listed reducer maps the value `$a` to `$a % p`.
    macro_rules! check_transform {
        ($T:ty; $a:ident; $($M:ty)*) => ($({
            const P: $T = <$M>::MODULUS;
            let m = <$M>::new(&P);
            assert_eq!(m.residue(m.transform($a)), $a % P);
        })*);
    }

    // Compares every `Reducer` operation of each listed reducer against the generic
    // `Modular*Ops` implementations, with operands `$a`, `$b` and exponent `$e`.
    macro_rules! check_against_modops {
        ($T:ty; $a:ident, $b:ident, $e:ident; $($M:ty)*) => ($({
            const P: $T = <$M>::MODULUS;
            let r = <$M>::new(&P);
            let am = r.transform($a);
            let bm = r.transform($b);
            assert_eq!(r.add(&am, &bm), $a.addm($b, &P));
            assert_eq!(r.sub(&am, &bm), $a.subm($b, &P));
            assert_eq!(r.mul(&am, &bm), $a.mulm($b, &P));
            assert_eq!(r.neg(am), $a.negm(&P));
            assert_eq!(r.inv(am), $a.invm(&P));
            assert_eq!(r.dbl(am), $a.dblm(&P));
            assert_eq!(r.sqr(am), $a.sqm(&P));
            assert_eq!(r.pow(am, &$e), $a.powm($e, &P));
        })*);
    }

    #[test]
    fn modulus_constants() {
        let expected_u128: [umax; 7] = [
            (1 << 31) - 1,
            (1 << 61) - 1,
            (1 << 127) - 1,
            (1 << 32) - 5,
            (1 << 56) - 5,
            (1 << 122) - 3,
            umax::MAX - 158,
        ];
        let actual_u128 = [
            M1::MODULUS,
            M2::MODULUS,
            M3::MODULUS,
            M4::MODULUS,
            M5::MODULUS,
            M6::MODULUS,
            M7::MODULUS,
        ];
        assert_eq!(actual_u128, expected_u128);

        let expected_u64: [u64; 4] = [(1 << 31) - 1, (1 << 61) - 1, (1 << 32) - 5, u64::MAX - 58];
        let actual_u64 = [
            M64_1::MODULUS,
            M64_2::MODULUS,
            M64_3::MODULUS,
            M64_4::MODULUS,
        ];
        assert_eq!(actual_u64, expected_u64);

        let expected_u32: [u32; 3] = [(1 << 13) - 1, (1 << 31) - 1, (1 << 16) - 5];
        let actual_u32 = [M32_1::MODULUS, M32_2::MODULUS, M32_3::MODULUS];
        assert_eq!(actual_u32, expected_u32);
    }

    #[test]
    fn creation_test_u128() {
        check_boundaries!(umax; M1 M2 M3 M4 M5 M6 M7);
        for _ in 0..NRANDOM {
            let a = random::<umax>();
            check_transform!(umax; a; M1 M2 M3 M4 M5 M6 M7);
        }
    }

    #[test]
    fn creation_test_u64() {
        check_boundaries!(u64; M64_1 M64_2 M64_3 M64_4);
        for _ in 0..NRANDOM {
            let a = random::<u64>();
            check_transform!(u64; a; M64_1 M64_2 M64_3 M64_4);
        }
    }

    #[test]
    fn creation_test_u32() {
        check_boundaries!(u32; M32_1 M32_2 M32_3);
        for _ in 0..NRANDOM {
            let a = random::<u32>();
            check_transform!(u32; a; M32_1 M32_2 M32_3);
        }
    }

    #[test]
    fn test_against_modops_u128() {
        for _ in 0..NRANDOM {
            let (a, b) = (random::<umax>(), random::<umax>());
            let e = random::<u8>() as umax;
            check_against_modops!(umax; a, b, e; M1 M2 M3 M4 M5 M6);
        }
    }

    #[test]
    fn test_against_modops_u64() {
        for _ in 0..NRANDOM {
            let (a, b) = (random::<u64>(), random::<u64>());
            let e = random::<u8>() as u64;
            check_against_modops!(u64; a, b, e; M64_1 M64_2 M64_3);
        }
    }

    #[test]
    fn test_against_modops_u32() {
        for _ in 0..NRANDOM {
            let (a, b) = (random::<u32>(), random::<u32>());
            let e = random::<u8>() as u32;
            check_against_modops!(u32; a, b, e; M32_1 M32_2 M32_3);
        }
    }
}