1use crate::reduced::{impl_reduced_binary_pow, impl_reduced_ops};
2use crate::{udouble, umax, ModularUnaryOps, Reducer};
3
macro_rules! impl_fixed_mersenne {
    (
        $TypeName:ident,
        $T:ty,
        $D:ty,
        $half_bits:expr,
        $max_P:expr,
        $kind:ident
    ) => {
        impl<const P: u8, const K: $T> $TypeName<P, K> {
            /// Mask of the low `P` bits (`2^P - 1`); all bits set when `P` equals the bit width.
            const BITMASK: $T = match (1 as $T).checked_shl(P as u32) {
                Some(v) => v.wrapping_sub(1),
                None => <$T>::MAX,
            };

            /// The modulus `2^P - K`, computed with wrapping arithmetic so that `P` may equal
            /// the bit width of the underlying integer type.
            pub const MODULUS: $T = {
                let p1 = match (1 as $T).checked_shl(P as u32) {
                    Some(v) => v,
                    None => 0,
                };
                p1.wrapping_sub(K)
            };

            /// Number of unrolled folds `reduce_double` performs on a product of two residues.
            ///
            /// Two folds suffice for `K == 1`. Otherwise each fold removes roughly
            /// `gap = P - bitlen(K)` bits above position `P`, so `ceil(P / gap)` folds are used,
            /// plus one for the final carry. Values above 3 make `reduce_double` fall back to a loop.
            /// Evaluating this constant fails at compile time unless `K < 2^(P-1)` (so `gap >= 1`).
            const FOLDS: u32 = if K == 1 {
                2
            } else {
                let s = <$T>::BITS - K.leading_zeros();
                let gap = P as u32 - s;
                let folds_ceil = (P as u32 + gap - 1) / gap;
                folds_ceil + 1
            };

            /// Reduces a single-width value `v` modulo [`Self::MODULUS`].
            ///
            /// Bits above position `P` are repeatedly folded back in using `2^P ≡ K`, after which
            /// one conditional subtraction yields the canonical residue.
            pub const fn reduce_single(v: $T) -> $T {
                let mut lo = v & Self::BITMASK;
                // `checked_shr` is `None` when `P` equals the bit width: nothing lies above bit `P`.
                let mut hi = match v.checked_shr(P as u32) {
                    Some(s) => s,
                    None => 0,
                };
                while hi > 0 {
                    let sum = if K == 1 { hi + lo } else { hi * K + lo };
                    lo = sum & Self::BITMASK;
                    hi = match sum.checked_shr(P as u32) {
                        Some(s) => s,
                        None => 0,
                    };
                }
                // lo < 2^P = MODULUS + K and K < MODULUS, so one subtraction is enough.
                if lo >= Self::MODULUS {
                    lo - Self::MODULUS
                } else {
                    lo
                }
            }

            impl_fixed_mersenne!(@reduce_double, $kind, $T, $D);
        }

        impl<const P: u8, const K: $T> Reducer<$T> for $TypeName<P, K> {
            /// Creates the reducer after checking that `m` equals [`Self::MODULUS`].
            ///
            /// # Panics
            /// Panics if `m != 2^P - K`. Debug builds also check:
            /// - `P` is within the type's limit;
            /// - `K` is odd with `0 < K < 2^(P-1)`;
            /// - the modulus has no factor among 3, 5, 7, 11 and 13 unless it equals that prime.
            #[inline]
            fn new(m: &$T) -> Self {
                assert!(
                    *m == Self::MODULUS,
                    "the given modulus doesn't match with the generic params"
                );
                debug_assert!(P <= $max_P);
                debug_assert!(K > 0 && K < (2 as $T).pow(P as u32 - 1) && K % 2 == 1);
                debug_assert!(
                    (Self::MODULUS == 3 || Self::MODULUS % 3 != 0)
                        && (Self::MODULUS == 5 || Self::MODULUS % 5 != 0)
                        && (Self::MODULUS == 7 || Self::MODULUS % 7 != 0)
                        && (Self::MODULUS == 11 || Self::MODULUS % 11 != 0)
                        && (Self::MODULUS == 13 || Self::MODULUS % 13 != 0)
                );
                Self {}
            }

            /// Maps an arbitrary value into the residue domain (plain reduction mod `MODULUS`).
            #[inline]
            fn transform(&self, target: $T) -> $T {
                Self::reduce_single(target)
            }

            /// Residues are stored as canonical values, so no conversion is needed.
            #[inline]
            fn residue(&self, target: $T) -> $T {
                target
            }

            impl_reduced_ops!($T);

            /// Multiplies two residues. When `P` is below half the bit width the product fits in
            /// the single-width type; otherwise a double-width product is reduced.
            #[inline]
            fn mul(&self, lhs: &$T, rhs: &$T) -> $T {
                if (P as u32) < $half_bits {
                    Self::reduce_single(lhs * rhs)
                } else {
                    Self::reduce_double(impl_fixed_mersenne!(@widen_mul, $kind, $T, $D, lhs, rhs))
                }
            }

            /// Modular inverse via `invm`; narrows to `usize` when the modulus fits in it.
            #[inline]
            fn inv(&self, target: $T) -> Option<$T> {
                if (P as u32) < usize::BITS {
                    (target as usize)
                        .invm(&(Self::MODULUS as usize))
                        .map(|v| v as $T)
                } else {
                    target.invm(&Self::MODULUS)
                }
            }

            /// Squares a residue, using the same width selection as `mul`.
            #[inline]
            fn sqr(&self, target: $T) -> $T {
                if (P as u32) < $half_bits {
                    Self::reduce_single(target * target)
                } else {
                    Self::reduce_double(impl_fixed_mersenne!(@widen_sqr, $kind, $T, $D, target))
                }
            }

            impl_reduced_binary_pow!($T);
        }
    };

    // Double-width reduction when the double type `$D` is a primitive integer.
    (@reduce_double, primitive, $T:ty, $D:ty) => {
        /// Reduces a double-width value `v` modulo [`Self::MODULUS`].
        ///
        /// Products of two residues are fully reduced by `FOLDS` unrolled folds. Any bits still
        /// above position `P` afterwards (only possible for larger inputs) are folded by the
        /// trailing loop, so every `v` is reduced correctly.
        pub fn reduce_double(v: $D) -> $T {
            let mut lo = (v as $T) & Self::BITMASK;
            let mut hi = v >> P;
            // One fold: v = hi * 2^P + lo ≡ hi * K + lo, computed in double width (cannot overflow
            // because hi < 2^(2*BITS - P) and K < 2^(P-1)).
            macro_rules! mersenne_fold {
                () => {
                    let sum = if K == 1 {
                        hi + lo as $D
                    } else {
                        hi * (K as $D) + lo as $D
                    };
                    lo = (sum as $T) & Self::BITMASK;
                    hi = sum >> P;
                };
            }
            if Self::FOLDS <= 2 {
                mersenne_fold!();
                mersenne_fold!();
            } else if Self::FOLDS == 3 {
                mersenne_fold!();
                mersenne_fold!();
                mersenne_fold!();
            }
            // Generic path for FOLDS > 3, and a safety net for out-of-range inputs otherwise.
            while hi > 0 {
                mersenne_fold!();
            }
            if lo >= Self::MODULUS {
                lo - Self::MODULUS
            } else {
                lo
            }
        }
    };

    // Double-width reduction when the double type is the software `udouble`.
    (@reduce_double, udouble, $T:ty, $D:ty) => {
        /// Reduces a double-width value `v` modulo [`Self::MODULUS`].
        ///
        /// The value is folded in double width until the high part is small enough that a
        /// single-width fold `hi * K + lo` cannot overflow; the rest is done in single width.
        pub fn reduce_double(v: $D) -> $T {
            // Largest `hi` with `hi * K + BITMASK <= MAX`. This is 0 when `P` equals the bit width,
            // where a single-width fold could otherwise overflow for any non-zero `hi`.
            let single_width_limit = match (<$T>::MAX - Self::BITMASK).checked_div(K) {
                Some(limit) => limit,
                // K == 0: the fold never multiplies, so any `hi` is safe.
                None => <$T>::MAX,
            };

            let mut lo = v.lo & Self::BITMASK;
            let mut hi = v >> P;
            while hi.hi > 0 || hi.lo > single_width_limit {
                let sum = if K == 1 { hi + lo } else { hi * K + lo };
                lo = sum.lo & Self::BITMASK;
                hi = sum >> P;
            }

            let mut hi = hi.lo;
            macro_rules! mersenne_u128_fold {
                () => {
                    let sum = if K == 1 { hi + lo } else { hi * K + lo };
                    lo = sum & Self::BITMASK;
                    // `None` when `P` equals the bit width: nothing lies above bit `P`.
                    hi = match sum.checked_shr(P as u32) {
                        Some(s) => s,
                        None => 0,
                    };
                };
            }
            if Self::FOLDS <= 2 {
                mersenne_u128_fold!();
                mersenne_u128_fold!();
            } else if Self::FOLDS == 3 {
                mersenne_u128_fold!();
                mersenne_u128_fold!();
                mersenne_u128_fold!();
            }
            // Generic path for FOLDS > 3, and a safety net for out-of-range inputs otherwise.
            while hi > 0 {
                mersenne_u128_fold!();
            }
            if lo >= Self::MODULUS {
                lo - Self::MODULUS
            } else {
                lo
            }
        }
    };

    // Full product of two single-width values, widened to a primitive double type.
    (@widen_mul, primitive, $T:ty, $D:ty, $lhs:expr, $rhs:expr) => {
        (*$lhs as $D) * (*$rhs as $D)
    };

    // Full product of two single-width values via `udouble`.
    (@widen_mul, udouble, $T:ty, $D:ty, $lhs:expr, $rhs:expr) => {
        <$D>::widening_mul(*$lhs, *$rhs)
    };

    // Full square of a single-width value, widened to a primitive double type.
    (@widen_sqr, primitive, $T:ty, $D:ty, $target:expr) => {
        ($target as $D) * ($target as $D)
    };

    // Full square of a single-width value via `udouble`.
    (@widen_sqr, udouble, $T:ty, $D:ty, $target:expr) => {
        <$D>::widening_square($target)
    };
}
236
237#[must_use]
256#[derive(Debug, Clone, Copy)]
257pub struct FixedMersenne32<const P: u8, const K: u32>();
258
259impl_fixed_mersenne!(FixedMersenne32, u32, u64, 16, 32, primitive);
260
261#[must_use]
281#[derive(Debug, Clone, Copy)]
282pub struct FixedMersenne64<const P: u8, const K: u64>();
283
284impl_fixed_mersenne!(FixedMersenne64, u64, u128, 32, 64, primitive);
285
286#[must_use]
307#[derive(Debug, Clone, Copy)]
308pub struct FixedMersenne<const P: u8, const K: umax>();
309
310impl_fixed_mersenne!(FixedMersenne, umax, udouble, 64, 128, udouble);
311
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ModularCoreOps, ModularPow};
    use rand::random;

    // Reducers over `umax`; M7 uses the full bit width (P = 128).
    type M1 = FixedMersenne<31, 1>;
    type M2 = FixedMersenne<61, 1>;
    type M3 = FixedMersenne<127, 1>;
    type M4 = FixedMersenne<32, 5>;
    type M5 = FixedMersenne<56, 5>;
    type M6 = FixedMersenne<122, 3>;
    type M7 = FixedMersenne<128, 159>;

    // Reducers over `u64`; M64_4 uses the full bit width (P = 64).
    type M64_1 = FixedMersenne64<31, 1>;
    type M64_2 = FixedMersenne64<61, 1>;
    type M64_3 = FixedMersenne64<32, 5>;
    type M64_4 = FixedMersenne64<64, 59>;

    // Reducers over `u32`.
    type M32_1 = FixedMersenne32<13, 1>;
    type M32_2 = FixedMersenne32<31, 1>;
    type M32_3 = FixedMersenne32<16, 5>;

    /// Number of random samples drawn by each randomized test.
    const NRANDOM: u32 = 10;

    /// For each `Reducer => modulus` pair, builds the reducer from the expected modulus and
    /// checks that a transform/residue round trip of `$a` agrees with `$a % modulus`.
    macro_rules! check_round_trip {
        ($T:ty, $a:expr; $($M:ty => $modulus:expr),+ $(,)?) => {
            $({
                const MODULUS: $T = $modulus;
                let reducer = <$M>::new(&MODULUS);
                assert_eq!(reducer.residue(reducer.transform($a)), $a % MODULUS);
            })+
        };
    }

    /// Compares every reducer operation against the matching `Modular*Ops` method on `$T`.
    macro_rules! check_against_modops {
        ($T:ty, $a:ident, $b:ident, $e:ident; $($M:ty)*) => ($({
            const MODULUS: $T = <$M>::MODULUS;
            let reducer = <$M>::new(&MODULUS);
            let am = reducer.transform($a);
            let bm = reducer.transform($b);
            assert_eq!(reducer.add(&am, &bm), $a.addm($b, &MODULUS));
            assert_eq!(reducer.sub(&am, &bm), $a.subm($b, &MODULUS));
            assert_eq!(reducer.mul(&am, &bm), $a.mulm($b, &MODULUS));
            assert_eq!(reducer.neg(am), $a.negm(&MODULUS));
            assert_eq!(reducer.inv(am), $a.invm(&MODULUS));
            assert_eq!(reducer.dbl(am), $a.dblm(&MODULUS));
            assert_eq!(reducer.sqr(am), $a.sqm(&MODULUS));
            assert_eq!(reducer.pow(am, &$e), $a.powm($e, &MODULUS));
        })*);
    }

    #[test]
    fn creation_test_u128() {
        const P: umax = (1 << 31) - 1;
        let m = M1::new(&P);
        let round_trip = |x: umax| m.residue(m.transform(x));
        assert_eq!(round_trip(0), 0);
        assert_eq!(round_trip(1), 1);
        assert_eq!(round_trip(P), 0);
        assert_eq!(round_trip(P - 1), P - 1);
        assert_eq!(round_trip(P + 1), 1);

        for _ in 0..NRANDOM {
            let a = random::<umax>();
            check_round_trip!(umax, a;
                M1 => (1 << 31) - 1,
                M2 => (1 << 61) - 1,
                M3 => (1 << 127) - 1,
                M4 => (1 << 32) - 5,
                M5 => (1 << 56) - 5,
                M6 => (1 << 122) - 3,
                M7 => M7::MODULUS,
            );
        }
    }

    #[test]
    fn creation_test_u64() {
        for _ in 0..NRANDOM {
            let a = random::<u64>();
            check_round_trip!(u64, a;
                M64_1 => (1 << 31) - 1,
                M64_2 => (1 << 61) - 1,
                M64_3 => (1 << 32) - 5,
                M64_4 => M64_4::MODULUS,
            );
        }
    }

    #[test]
    fn creation_test_u32() {
        for _ in 0..NRANDOM {
            let a = random::<u32>();
            check_round_trip!(u32, a;
                M32_1 => (1 << 13) - 1,
                M32_2 => (1 << 31) - 1,
                M32_3 => (1 << 16) - 5,
            );
        }
    }

    #[test]
    fn test_against_modops_u128() {
        for _ in 0..NRANDOM {
            let (a, b) = (random::<u128>(), random::<u128>());
            let e = random::<u8>() as umax;
            check_against_modops!(umax, a, b, e; M1 M2 M3 M4 M5 M6);
        }
    }

    #[test]
    fn test_against_modops_u64() {
        for _ in 0..NRANDOM {
            let (a, b) = (random::<u64>(), random::<u64>());
            let e = random::<u8>() as u64;
            check_against_modops!(u64, a, b, e; M64_1 M64_2 M64_3);
        }
    }

    #[test]
    fn test_against_modops_u32() {
        for _ in 0..NRANDOM {
            let (a, b) = (random::<u32>(), random::<u32>());
            let e = random::<u8>() as u32;
            check_against_modops!(u32, a, b, e; M32_1 M32_2 M32_3);
        }
    }

    #[test]
    fn small_prime_moduli() {
        // 2^2 - 1 = 3 and 2^3 - 1 = 7, for both the umax and the u32 reducers.
        type Wide3 = FixedMersenne<2, 1>;
        type Wide7 = FixedMersenne<3, 1>;
        type Narrow3 = FixedMersenne32<2, 1>;
        type Narrow7 = FixedMersenne32<3, 1>;

        const WIDE3: umax = Wide3::MODULUS;
        let wide3 = Wide3::new(&WIDE3);
        assert_eq!(wide3.residue(wide3.transform(5 % WIDE3)), 5 % WIDE3);

        const WIDE7: umax = Wide7::MODULUS;
        let wide7 = Wide7::new(&WIDE7);
        assert_eq!(wide7.residue(wide7.transform(10 % WIDE7)), 10 % WIDE7);

        const NARROW3: u32 = Narrow3::MODULUS;
        let narrow3 = Narrow3::new(&NARROW3);
        assert_eq!(narrow3.residue(narrow3.transform(5 % NARROW3)), 5 % NARROW3);

        const NARROW7: u32 = Narrow7::MODULUS;
        let narrow7 = Narrow7::new(&NARROW7);
        assert_eq!(narrow7.residue(narrow7.transform(10 % NARROW7)), 10 % NARROW7);
    }
}