1use crate::reduced::impl_reduced_binary_pow;
2use crate::{udouble, umax, ModularUnaryOps, Reducer};
3
/// Implements pseudo-Mersenne modular reduction (modulus `2^P - K`) for a fixed reducer type.
///
/// Parameters: the reducer type name, the single-width integer type `$T`, the double-width
/// type `$D` that holds full products, `$half_bits` (when `P` is below it, the product of two
/// residues fits in `$T`), the maximum supported `P`, and `$kind`, which selects the helper
/// arms (`primitive` when `$D` is a built-in integer, `udouble` for the crate's `udouble`).
/// Arms starting with `@` are internal helpers.
macro_rules! impl_fixed_mersenne {
    (
        $TypeName:ident,
        $T:ty,
        $D:ty,
        $half_bits:expr,
        $max_P:expr,
        $kind:ident
    ) => {
        impl<const P: u8, const K: $T> $TypeName<P, K> {
            /// Mask of the low `P` bits; all ones when `P` equals the bit width of `$T`.
            const BITMASK: $T = match (1 as $T).checked_shl(P as u32) {
                Some(v) => v.wrapping_sub(1),
                None => <$T>::MAX,
            };

            /// The modulus `2^P - K`.
            // When `P` equals the bit width, `2^P` is represented as 0 and the wrapping
            // subtraction yields exactly `2^P - K`.
            pub const MODULUS: $T = {
                let p1 = match (1 as $T).checked_shl(P as u32) {
                    Some(v) => v,
                    None => 0,
                };
                p1.wrapping_sub(K)
            };

            /// Number of fold steps `reduce_double` performs (unrolled when at most 3).
            // A fold rewrites `hi * 2^P + lo` as `hi * K + lo` (valid because 2^P ≡ K).
            // With K == 1 two folds are used; otherwise each fold removes roughly
            // `gap = P - bits(K)` bits from the high part, so ceil(P / gap) + 1 folds are used.
            const FOLDS: u32 = if K == 1 {
                2
            } else {
                let s = <$T>::BITS - K.leading_zeros();
                let gap = P as u32 - s;
                let folds_ceil = (P as u32 + gap - 1) / gap;
                folds_ceil + 1
            };

            /// Reduces an arbitrary single-width value `v` modulo [`Self::MODULUS`].
            ///
            /// Folds the bits above position `P` back in (using `2^P ≡ K`) until the value
            /// fits in `P` bits, then applies one conditional subtraction.
            pub const fn reduce_single(v: $T) -> $T {
                let mut lo = v & Self::BITMASK;
                // `checked_shr` yields None when `P` equals the bit width: nothing is above P.
                let mut hi = match v.checked_shr(P as u32) {
                    Some(s) => s,
                    None => 0,
                };
                while hi > 0 {
                    let sum = if K == 1 { hi + lo } else { hi * K + lo };
                    lo = sum & Self::BITMASK;
                    hi = match sum.checked_shr(P as u32) {
                        Some(s) => s,
                        None => 0,
                    };
                }
                if lo >= Self::MODULUS {
                    lo - Self::MODULUS
                } else {
                    lo
                }
            }

            // Double-width reduction, specialised for built-in vs. `udouble` wide types.
            impl_fixed_mersenne!(@reduce_double, $kind, $T, $D);
        }

        impl<const P: u8, const K: $T> Reducer<$T> for $TypeName<P, K> {
            /// Creates the reducer; panics if `m` differs from the modulus fixed by `P` and `K`.
            #[inline]
            fn new(m: &$T) -> Self {
                assert!(
                    *m == Self::MODULUS,
                    "the given modulus doesn't match with the generic params"
                );
                debug_assert!(P <= $max_P);
                debug_assert!(K > 0 && K < (2 as $T).pow(P as u32 - 1) && K % 2 == 1);
                // Cheap sanity check that the modulus has no small prime factors.
                debug_assert!(
                    Self::MODULUS % 3 != 0
                        && Self::MODULUS % 5 != 0
                        && Self::MODULUS % 7 != 0
                        && Self::MODULUS % 11 != 0
                        && Self::MODULUS % 13 != 0
                );
                Self {}
            }

            /// Residues are kept in normal form, so transforming is a plain reduction.
            #[inline]
            fn transform(&self, target: $T) -> $T {
                Self::reduce_single(target)
            }

            #[inline]
            fn check(&self, target: &$T) -> bool {
                *target < Self::MODULUS
            }

            /// Identity: the stored form already is the residue.
            #[inline]
            fn residue(&self, target: $T) -> $T {
                target
            }

            #[inline]
            fn modulus(&self) -> $T {
                Self::MODULUS
            }

            #[inline]
            fn is_zero(&self, target: &$T) -> bool {
                target == &0
            }

            /// Adds two residues (both expected to be `< MODULUS`).
            #[inline]
            fn add(&self, lhs: &$T, rhs: &$T) -> $T {
                // When `P` equals the bit width, the modulus exceeds `MAX / 2` and the sum can
                // overflow `$T`. An overflowed sum is certainly >= MODULUS, and because the
                // true result is < MODULUS, the wrapping subtraction recovers it exactly.
                let (sum, overflowed) = <$T>::overflowing_add(*lhs, *rhs);
                if overflowed || sum >= Self::MODULUS {
                    sum.wrapping_sub(Self::MODULUS)
                } else {
                    sum
                }
            }

            /// Subtracts two residues (both expected to be `< MODULUS`).
            #[inline]
            fn sub(&self, lhs: &$T, rhs: &$T) -> $T {
                if lhs >= rhs {
                    lhs - rhs
                } else {
                    Self::MODULUS - (rhs - lhs)
                }
            }

            #[inline]
            fn dbl(&self, target: $T) -> $T {
                self.add(&target, &target)
            }

            #[inline]
            fn neg(&self, target: $T) -> $T {
                if target == 0 {
                    0
                } else {
                    Self::MODULUS - target
                }
            }

            #[inline]
            fn mul(&self, lhs: &$T, rhs: &$T) -> $T {
                if (P as u32) < $half_bits {
                    // Both factors are below 2^P, so the product fits in `$T`.
                    Self::reduce_single(lhs * rhs)
                } else {
                    Self::reduce_double(impl_fixed_mersenne!(@widen_mul, $kind, $T, $D, lhs, rhs))
                }
            }

            /// Modular inverse; delegates to `invm` on `usize` when `P < usize::BITS`
            /// (the modulus then fits in a machine word).
            #[inline]
            fn inv(&self, target: $T) -> Option<$T> {
                if (P as u32) < usize::BITS {
                    (target as usize)
                        .invm(&(Self::MODULUS as usize))
                        .map(|v| v as $T)
                } else {
                    target.invm(&Self::MODULUS)
                }
            }

            #[inline]
            fn sqr(&self, target: $T) -> $T {
                if (P as u32) < $half_bits {
                    Self::reduce_single(target * target)
                } else {
                    Self::reduce_double(impl_fixed_mersenne!(@widen_sqr, $kind, $T, $D, target))
                }
            }

            impl_reduced_binary_pow!($T);
        }
    };

    (@reduce_double, primitive, $T:ty, $D:ty) => {
        /// Reduces a double-width value modulo [`Self::MODULUS`].
        ///
        /// Intended for products of two reduced residues, which is how `mul`/`sqr` call it.
        /// At most 3 folds are unrolled (see `FOLDS`), so much larger inputs may not be
        /// fully reduced.
        pub fn reduce_double(v: $D) -> $T {
            let mut lo = (v as $T) & Self::BITMASK;
            let mut hi = v >> P;
            // One fold: hi * 2^P + lo  ->  hi * K + lo, computed in `$D`.
            macro_rules! mersenne_fold {
                () => {
                    let sum = if K == 1 {
                        hi + lo as $D
                    } else {
                        hi * (K as $D) + lo as $D
                    };
                    lo = (sum as $T) & Self::BITMASK;
                    hi = sum >> P;
                };
            }
            if Self::FOLDS <= 2 {
                #[allow(unused_assignments)] { mersenne_fold!(); }
                #[allow(unused_assignments)] { mersenne_fold!(); }
            } else if Self::FOLDS == 3 {
                #[allow(unused_assignments)] { mersenne_fold!(); }
                #[allow(unused_assignments)] { mersenne_fold!(); }
                #[allow(unused_assignments)] { mersenne_fold!(); }
            } else {
                while hi > 0 {
                    mersenne_fold!();
                }
            }
            if lo >= Self::MODULUS {
                lo - Self::MODULUS
            } else {
                lo
            }
        }
    };

    (@reduce_double, udouble, $T:ty, $D:ty) => {
        /// Reduces a double-width `udouble` value modulo [`Self::MODULUS`].
        ///
        /// Intended for products of two reduced residues, which is how `mul`/`sqr` call it.
        /// The high part is first folded in double width until a single-width fold
        /// `hi * K + lo` can no longer overflow `$T`. This matters when `P + bits(K)` exceeds
        /// the bit width, e.g. `P == 128`. The remaining folds then run in `$T`.
        pub fn reduce_double(v: $D) -> $T {
            // Largest high part for which `hi * K + lo` (with `lo <= BITMASK`) fits in `$T`.
            // It is 0 when `P` equals the bit width, so folding then finishes in double width.
            let single_limit = match (<$T>::MAX - Self::BITMASK).checked_div(K) {
                Some(limit) => limit,
                None => <$T>::MAX,
            };
            let mut lo = v.lo & Self::BITMASK;
            let mut hi = v >> P;
            while hi.hi > 0 || hi.lo > single_limit {
                let sum = if K == 1 { hi + lo } else { hi * K + lo };
                lo = sum.lo & Self::BITMASK;
                hi = sum >> P;
            }
            let mut hi = hi.lo;
            // One single-width fold: hi * 2^P + lo  ->  hi * K + lo.
            macro_rules! mersenne_u128_fold {
                () => {
                    let sum = if K == 1 { hi + lo } else { hi * K + lo };
                    lo = sum & Self::BITMASK;
                    hi = match sum.checked_shr(P as u32) {
                        Some(s) => s,
                        None => 0,
                    };
                };
            }
            if Self::FOLDS <= 2 {
                #[allow(unused_assignments)] { mersenne_u128_fold!(); }
                #[allow(unused_assignments)] { mersenne_u128_fold!(); }
            } else if Self::FOLDS == 3 {
                #[allow(unused_assignments)] { mersenne_u128_fold!(); }
                #[allow(unused_assignments)] { mersenne_u128_fold!(); }
                #[allow(unused_assignments)] { mersenne_u128_fold!(); }
            } else {
                while hi > 0 {
                    mersenne_u128_fold!();
                }
            }
            if lo >= Self::MODULUS {
                lo - Self::MODULUS
            } else {
                lo
            }
        }
    };

    // Full double-width product of two `$T` values (built-in wide integer).
    (@widen_mul, primitive, $T:ty, $D:ty, $lhs:expr, $rhs:expr) => {
        (*$lhs as $D) * (*$rhs as $D)
    };

    // Full double-width product of two `$T` values (`udouble`).
    (@widen_mul, udouble, $T:ty, $D:ty, $lhs:expr, $rhs:expr) => {
        <$D>::widening_mul(*$lhs, *$rhs)
    };

    // Full double-width square of a `$T` value (built-in wide integer).
    (@widen_sqr, primitive, $T:ty, $D:ty, $target:expr) => {
        ($target as $D) * ($target as $D)
    };

    // Full double-width square of a `$T` value (`udouble`).
    (@widen_sqr, udouble, $T:ty, $D:ty, $target:expr) => {
        <$D>::widening_square($target)
    };
}
274
275#[derive(Debug, Clone, Copy)]
294pub struct FixedMersenne32<const P: u8, const K: u32>();
295
296impl_fixed_mersenne!(FixedMersenne32, u32, u64, 16, 32, primitive);
297
298#[derive(Debug, Clone, Copy)]
318pub struct FixedMersenne64<const P: u8, const K: u64>();
319
320impl_fixed_mersenne!(FixedMersenne64, u64, u128, 32, 64, primitive);
321
322#[derive(Debug, Clone, Copy)]
343pub struct FixedMersenne<const P: u8, const K: umax>();
344
345impl_fixed_mersenne!(FixedMersenne, umax, udouble, 64, 128, udouble);
346
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ModularCoreOps, ModularPow};
    use rand::random;

    type M1 = FixedMersenne<31, 1>;
    type M2 = FixedMersenne<61, 1>;
    type M3 = FixedMersenne<127, 1>;
    type M4 = FixedMersenne<32, 5>;
    type M5 = FixedMersenne<56, 5>;
    type M6 = FixedMersenne<122, 3>;
    type M7 = FixedMersenne<128, 159>;

    type M64_1 = FixedMersenne64<31, 1>;
    type M64_2 = FixedMersenne64<61, 1>;
    type M64_3 = FixedMersenne64<32, 5>;
    type M64_4 = FixedMersenne64<64, 59>;

    type M32_1 = FixedMersenne32<13, 1>;
    type M32_2 = FixedMersenne32<31, 1>;
    type M32_3 = FixedMersenne32<16, 5>;

    /// Number of random samples drawn by each randomized test.
    const NRANDOM: u32 = 10;

    /// Compares every `Reducer` operation of each listed reducer type `$M` (whose modulus
    /// has type `$T`) against the generic modular operations on plain integers.
    // `M7` and `M64_4` (P equal to the full bit width) are only exercised through
    // `transform` in the creation tests.
    macro_rules! tests_for {
        ($T:ty; $a:ident, $b:ident, $e:ident; $($M:ty)*) => ($({
            const P: $T = <$M>::MODULUS;
            let r = <$M>::new(&P);
            let am = r.transform($a);
            let bm = r.transform($b);
            assert_eq!(r.add(&am, &bm), $a.addm($b, &P));
            assert_eq!(r.sub(&am, &bm), $a.subm($b, &P));
            assert_eq!(r.mul(&am, &bm), $a.mulm($b, &P));
            assert_eq!(r.neg(am), $a.negm(&P));
            assert_eq!(r.inv(am), $a.invm(&P));
            assert_eq!(r.dbl(am), $a.dblm(&P));
            assert_eq!(r.sqr(am), $a.sqm(&P));
            assert_eq!(r.pow(am, &$e), $a.powm($e, &P));
        })*);
    }

    #[test]
    fn creation_test_u128() {
        const P: umax = (1 << 31) - 1;
        let m = M1::new(&P);
        assert_eq!(m.residue(m.transform(0)), 0);
        assert_eq!(m.residue(m.transform(1)), 1);
        assert_eq!(m.residue(m.transform(P)), 0);
        assert_eq!(m.residue(m.transform(P - 1)), P - 1);
        assert_eq!(m.residue(m.transform(P + 1)), 1);

        for _ in 0..NRANDOM {
            let a = random::<umax>();

            const P1: umax = (1 << 31) - 1;
            let m1 = M1::new(&P1);
            assert_eq!(m1.residue(m1.transform(a)), a % P1);
            const P2: umax = (1 << 61) - 1;
            let m2 = M2::new(&P2);
            assert_eq!(m2.residue(m2.transform(a)), a % P2);
            const P3: umax = (1 << 127) - 1;
            let m3 = M3::new(&P3);
            assert_eq!(m3.residue(m3.transform(a)), a % P3);
            const P4: umax = (1 << 32) - 5;
            let m4 = M4::new(&P4);
            assert_eq!(m4.residue(m4.transform(a)), a % P4);
            const P5: umax = (1 << 56) - 5;
            let m5 = M5::new(&P5);
            assert_eq!(m5.residue(m5.transform(a)), a % P5);
            const P6: umax = (1 << 122) - 3;
            let m6 = M6::new(&P6);
            assert_eq!(m6.residue(m6.transform(a)), a % P6);
            const P7: umax = M7::MODULUS;
            let m7 = M7::new(&P7);
            assert_eq!(m7.residue(m7.transform(a)), a % P7);
        }
    }

    #[test]
    fn creation_test_u64() {
        for _ in 0..NRANDOM {
            let a = random::<u64>();

            const P1: u64 = (1 << 31) - 1;
            let m1 = M64_1::new(&P1);
            assert_eq!(m1.residue(m1.transform(a)), a % P1);
            const P2: u64 = (1 << 61) - 1;
            let m2 = M64_2::new(&P2);
            assert_eq!(m2.residue(m2.transform(a)), a % P2);
            const P3: u64 = (1 << 32) - 5;
            let m3 = M64_3::new(&P3);
            assert_eq!(m3.residue(m3.transform(a)), a % P3);
            const P4: u64 = M64_4::MODULUS;
            let m4 = M64_4::new(&P4);
            assert_eq!(m4.residue(m4.transform(a)), a % P4);
        }
    }

    #[test]
    fn creation_test_u32() {
        for _ in 0..NRANDOM {
            let a = random::<u32>();

            const P1: u32 = (1 << 13) - 1;
            let m1 = M32_1::new(&P1);
            assert_eq!(m1.residue(m1.transform(a)), a % P1);
            const P2: u32 = (1 << 31) - 1;
            let m2 = M32_2::new(&P2);
            assert_eq!(m2.residue(m2.transform(a)), a % P2);
            const P3: u32 = (1 << 16) - 5;
            let m3 = M32_3::new(&P3);
            assert_eq!(m3.residue(m3.transform(a)), a % P3);
        }
    }

    #[test]
    fn test_against_modops_u128() {
        for _ in 0..NRANDOM {
            // Draw `umax` (not a hard-coded `u128`) so the operands match the modulus type.
            let (a, b) = (random::<umax>(), random::<umax>());
            let e = random::<u8>() as umax;
            tests_for!(umax; a, b, e; M1 M2 M3 M4 M5 M6);
        }
    }

    #[test]
    fn test_against_modops_u64() {
        for _ in 0..NRANDOM {
            let a = random::<u64>();
            let b = random::<u64>();
            let e = random::<u8>() as u64;
            tests_for!(u64; a, b, e; M64_1 M64_2 M64_3);
        }
    }

    #[test]
    fn test_against_modops_u32() {
        for _ in 0..NRANDOM {
            let a = random::<u32>();
            let b = random::<u32>();
            let e = random::<u8>() as u32;
            tests_for!(u32; a, b, e; M32_1 M32_2 M32_3);
        }
    }
}