1use crate::reduced::impl_reduced_binary_pow;
2use crate::{udouble, umax, ModularUnaryOps, Reducer};
3
/// Generates the inherent constants/helpers and the `Reducer` implementation for a
/// fixed pseudo-Mersenne reducer `TypeName<P, K>` whose modulus is `2^P - K`.
///
/// Arguments of the public arm, in order:
/// - type name: zero-sized struct generic over `const P: u8` and `const K: T`;
/// - `T`: word type holding residues;
/// - `D`: type wide enough to hold the full product of two `T` values;
/// - half bits: half the bit width of `T`; when `P` is below it, the product of two
///   residues fits in `T` and is reduced with `reduce_single`;
/// - max P: largest accepted `P` (checked by a debug assertion in `new`);
/// - kind: `primitive` when `D` is a built-in integer, `udouble` when it is the crate's
///   two-word `udouble`; selects the matching internal `@...` arms below.
macro_rules! impl_fixed_mersenne {
    (
        $TypeName:ident,
        $T:ty,
        $D:ty,
        $half_bits:expr,
        $max_P:expr,
        $kind:ident
    ) => {
        impl<const P: u8, const K: $T> $TypeName<P, K> {
            /// Mask of the low `P` bits (all ones when `P` is at least the word width).
            const BITMASK: $T = match (1 as $T).checked_shl(P as u32) {
                Some(v) => v.wrapping_sub(1),
                None => <$T>::MAX,
            };

            /// The modulus `2^P - K`.
            ///
            /// `2^P` is taken as 0 when `P` is at least the word width, so the wrapping
            /// subtraction still yields `2^P - K` for a full-width `P`.
            pub const MODULUS: $T = {
                let p1 = match (1 as $T).checked_shl(P as u32) {
                    Some(v) => v,
                    None => 0,
                };
                p1.wrapping_sub(K)
            };

            /// Fold count used by the fixed-count path of `reduce_double` (always >= 2).
            ///
            /// Each fold `hi * K + lo` shortens the high part by roughly
            /// `P - bitlen(K)` bits, so `ceil(P / (P - bitlen(K)))` folds plus one for the
            /// final carry suffice. Values above 3 make `reduce_double` loop instead.
            const FOLDS: u32 = if K == 1 {
                2
            } else {
                let s = <$T>::BITS - K.leading_zeros();
                let gap = P as u32 - s;
                let folds_ceil = (P as u32 + gap - 1) / gap;
                folds_ceil + 1
            };

            /// Reduces a single-width value modulo `MODULUS`.
            ///
            /// Uses `2^P ≡ K (mod MODULUS)`: writing `v = hi * 2^P + lo`, the value is
            /// replaced by `hi * K + lo` until no bits above `P` remain, followed by a
            /// single conditional subtraction.
            const fn reduce_single(v: $T) -> $T {
                let mut lo = v & Self::BITMASK;
                // `checked_shr` because `P` may equal the word width.
                let mut hi = match v.checked_shr(P as u32) {
                    Some(s) => s,
                    None => 0,
                };
                while hi > 0 {
                    let sum = if K == 1 { hi + lo } else { hi * K + lo };
                    lo = sum & Self::BITMASK;
                    hi = match sum.checked_shr(P as u32) {
                        Some(s) => s,
                        None => 0,
                    };
                }
                if lo >= Self::MODULUS {
                    lo - Self::MODULUS
                } else {
                    lo
                }
            }

            impl_fixed_mersenne!(@reduce_double, $kind, $T, $D);
        }

        impl<const P: u8, const K: $T> Reducer<$T> for $TypeName<P, K> {
            /// Creates the reducer; `m` must equal `MODULUS`.
            ///
            /// # Panics
            /// Panics if `m` differs from `2^P - K`. Debug builds additionally check the
            /// upper bound on `P`, that `K` is odd with `0 < K < 2^(P-1)`, and that the
            /// modulus is not divisible by 3, 5, 7, 11 or 13.
            #[inline]
            fn new(m: &$T) -> Self {
                assert!(
                    *m == Self::MODULUS,
                    "the given modulus doesn't match with the generic params"
                );
                debug_assert!(P <= $max_P);
                debug_assert!(K > 0 && K < (2 as $T).pow(P as u32 - 1) && K % 2 == 1);
                // Cheap rejection of moduli with small prime factors.
                debug_assert!(
                    Self::MODULUS % 3 != 0
                        && Self::MODULUS % 5 != 0
                        && Self::MODULUS % 7 != 0
                        && Self::MODULUS % 11 != 0
                        && Self::MODULUS % 13 != 0
                );
                Self {}
            }

            /// Maps an arbitrary value to its reduced residue.
            #[inline]
            fn transform(&self, target: $T) -> $T {
                Self::reduce_single(target)
            }

            /// Whether `target` is below the modulus.
            #[inline]
            fn check(&self, target: &$T) -> bool {
                *target < Self::MODULUS
            }

            /// Residues are kept in plain (untransformed) form, so this is the identity.
            #[inline]
            fn residue(&self, target: $T) -> $T {
                target
            }

            #[inline]
            fn modulus(&self) -> $T {
                Self::MODULUS
            }

            #[inline]
            fn is_zero(&self, target: &$T) -> bool {
                target == &0
            }

            /// Modular addition of two reduced residues.
            #[inline]
            fn add(&self, lhs: &$T, rhs: &$T) -> $T {
                // When `P` equals the word width, `lhs + rhs` can exceed the word range.
                // On overflow the true sum is `sum + 2^BITS`, and since the true result is
                // below MODULUS, a wrapping subtraction of MODULUS gives it exactly.
                let (sum, overflowed) = lhs.overflowing_add(*rhs);
                if overflowed || sum >= Self::MODULUS {
                    sum.wrapping_sub(Self::MODULUS)
                } else {
                    sum
                }
            }

            /// Modular subtraction of two reduced residues.
            #[inline]
            fn sub(&self, lhs: &$T, rhs: &$T) -> $T {
                if lhs >= rhs {
                    lhs - rhs
                } else {
                    Self::MODULUS - (rhs - lhs)
                }
            }

            #[inline]
            fn dbl(&self, target: $T) -> $T {
                self.add(&target, &target)
            }

            #[inline]
            fn neg(&self, target: $T) -> $T {
                if target == 0 {
                    0
                } else {
                    Self::MODULUS - target
                }
            }

            /// Modular multiplication of two reduced residues.
            #[inline]
            fn mul(&self, lhs: &$T, rhs: &$T) -> $T {
                if (P as u32) < $half_bits {
                    // Both operands are below 2^P, so the product fits in a single word.
                    Self::reduce_single(lhs * rhs)
                } else {
                    Self::reduce_double(impl_fixed_mersenne!(@widen_mul, $kind, $T, $D, lhs, rhs))
                }
            }

            /// Modular inverse, or `None` when it does not exist.
            #[inline]
            fn inv(&self, target: $T) -> Option<$T> {
                // NOTE(review): when the modulus fits in `usize` the inversion is done in
                // `usize`, presumably because it is cheaper than wide-word arithmetic.
                if (P as u32) < usize::BITS {
                    (target as usize)
                        .invm(&(Self::MODULUS as usize))
                        .map(|v| v as $T)
                } else {
                    target.invm(&Self::MODULUS)
                }
            }

            /// Modular squaring of a reduced residue.
            #[inline]
            fn sqr(&self, target: $T) -> $T {
                if (P as u32) < $half_bits {
                    Self::reduce_single(target * target)
                } else {
                    Self::reduce_double(impl_fixed_mersenne!(@widen_sqr, $kind, $T, $D, target))
                }
            }

            impl_reduced_binary_pow!($T);
        }
    };

    (@reduce_double, primitive, $T:ty, $D:ty) => {
        /// Reduces a double-width product `v` (`v < 2^(2P)`) modulo `MODULUS`.
        ///
        /// `D` is a built-in integer twice as wide as `T`, so every fold
        /// `hi * K + lo` is computed without overflow.
        fn reduce_double(v: $D) -> $T {
            let mut lo = (v as $T) & Self::BITMASK;
            let mut hi = v >> P;
            // One step of `hi * 2^P + lo ≡ hi * K + lo (mod 2^P - K)`.
            macro_rules! mersenne_fold {
                () => {
                    let sum = if K == 1 {
                        hi + lo as $D
                    } else {
                        hi * (K as $D) + lo as $D
                    };
                    lo = (sum as $T) & Self::BITMASK;
                    hi = sum >> P;
                };
            }
            if Self::FOLDS <= 3 {
                // Fixed trip count (2 or 3), independent of the value being reduced.
                for _ in 0..Self::FOLDS {
                    mersenne_fold!();
                }
            } else {
                while hi > 0 {
                    mersenne_fold!();
                }
            }
            if lo >= Self::MODULUS {
                lo - Self::MODULUS
            } else {
                lo
            }
        }
    };

    (@reduce_double, udouble, $T:ty, $D:ty) => {
        /// Largest high part `h` for which the single-width fold `h * K + lo` cannot
        /// overflow the word for any `lo <= BITMASK`. Zero when `P` equals the word width.
        const FOLD_SAFE_HI: $T = if K == 0 {
            <$T>::MAX
        } else {
            (<$T>::MAX - Self::BITMASK) / K
        };

        /// Reduces a double-width product `v` modulo `MODULUS`.
        ///
        /// Folds are first done in the two-word type until the high part is small
        /// enough for single-word folds to be overflow-free, then finished in one word.
        fn reduce_double(v: $D) -> $T {
            let mut lo = v.lo & Self::BITMASK;
            let mut hi = v >> P;
            // Checking only `hi.hi > 0` is not enough: for large `P`/`K` (e.g. `P` equal to
            // the word width) `hi.lo * K` alone can overflow a single word.
            while hi.hi > 0 || hi.lo > Self::FOLD_SAFE_HI {
                let sum = if K == 1 { hi + lo } else { hi * K + lo };
                lo = sum.lo & Self::BITMASK;
                hi = sum >> P;
            }

            let mut hi = hi.lo;
            macro_rules! mersenne_word_fold {
                () => {
                    let sum = if K == 1 { hi + lo } else { hi * K + lo };
                    lo = sum & Self::BITMASK;
                    // `checked_shr` because `P` may equal the word width.
                    hi = match sum.checked_shr(P as u32) {
                        Some(s) => s,
                        None => 0,
                    };
                };
            }
            if Self::FOLDS <= 3 {
                // Fixed trip count (2 or 3); folds with `hi == 0` leave `lo` unchanged.
                for _ in 0..Self::FOLDS {
                    mersenne_word_fold!();
                }
            } else {
                while hi > 0 {
                    mersenne_word_fold!();
                }
            }
            if lo >= Self::MODULUS {
                lo - Self::MODULUS
            } else {
                lo
            }
        }
    };

    (@widen_mul, primitive, $T:ty, $D:ty, $lhs:expr, $rhs:expr) => {
        (*$lhs as $D) * (*$rhs as $D)
    };

    (@widen_mul, udouble, $T:ty, $D:ty, $lhs:expr, $rhs:expr) => {
        <$D>::widening_mul(*$lhs, *$rhs)
    };

    (@widen_sqr, primitive, $T:ty, $D:ty, $target:expr) => {
        ($target as $D) * ($target as $D)
    };

    (@widen_sqr, udouble, $T:ty, $D:ty, $target:expr) => {
        <$D>::widening_square($target)
    };
}
262
263#[derive(Debug, Clone, Copy)]
282pub struct FixedMersenne32<const P: u8, const K: u32>();
283
284impl_fixed_mersenne!(FixedMersenne32, u32, u64, 16, 32, primitive);
285
286#[derive(Debug, Clone, Copy)]
306pub struct FixedMersenne64<const P: u8, const K: u64>();
307
308impl_fixed_mersenne!(FixedMersenne64, u64, u128, 32, 64, primitive);
309
310#[derive(Debug, Clone, Copy)]
331pub struct FixedMersenne<const P: u8, const K: umax>();
332
333impl_fixed_mersenne!(FixedMersenne, umax, udouble, 64, 128, udouble);
334
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ModularCoreOps, ModularPow};
    use rand::random;

    // Reducers over `umax`.
    type M1 = FixedMersenne<31, 1>;
    type M2 = FixedMersenne<61, 1>;
    type M3 = FixedMersenne<127, 1>;
    type M4 = FixedMersenne<32, 5>;
    type M5 = FixedMersenne<56, 5>;
    type M6 = FixedMersenne<122, 3>;
    type M7 = FixedMersenne<128, 159>;

    // Reducers over `u64`.
    type M64_1 = FixedMersenne64<31, 1>;
    type M64_2 = FixedMersenne64<61, 1>;
    type M64_3 = FixedMersenne64<32, 5>;
    type M64_4 = FixedMersenne64<64, 59>;

    // Reducers over `u32`.
    type M32_1 = FixedMersenne32<13, 1>;
    type M32_2 = FixedMersenne32<31, 1>;
    type M32_3 = FixedMersenne32<16, 5>;

    /// Number of random samples drawn by each randomized test.
    const NRANDOM: u32 = 10;

    /// For every `Reducer => modulus` pair, builds the reducer from the explicit modulus
    /// (so `new` also checks it against the generic parameters) and verifies that a
    /// `transform` / `residue` round trip agrees with the `%` operator on `$x`.
    macro_rules! check_transform {
        ($T:ty, $x:expr; $($M:ty => $p:expr),+ $(,)?) => {
            $({
                let modulus: $T = $p;
                let reducer = <$M>::new(&modulus);
                assert_eq!(reducer.residue(reducer.transform($x)), $x % modulus);
            })+
        };
    }

    /// For every reducer type, compares each `Reducer` operation on the transformed
    /// operands with the corresponding free-standing modular operation.
    macro_rules! check_against_modops {
        ($T:ty; $a:expr, $b:expr, $e:expr; $($M:ty),+ $(,)?) => {
            $({
                const P: $T = <$M>::MODULUS;
                let r = <$M>::new(&P);
                let am = r.transform($a);
                let bm = r.transform($b);
                assert_eq!(r.add(&am, &bm), $a.addm($b, &P));
                assert_eq!(r.sub(&am, &bm), $a.subm($b, &P));
                assert_eq!(r.mul(&am, &bm), $a.mulm($b, &P));
                assert_eq!(r.neg(am), $a.negm(&P));
                assert_eq!(r.inv(am), $a.invm(&P));
                assert_eq!(r.dbl(am), $a.dblm(&P));
                assert_eq!(r.sqr(am), $a.sqm(&P));
                assert_eq!(r.pow(am, &$e), $a.powm($e, &P));
            })+
        };
    }

    #[test]
    fn creation_test_u128() {
        const P: umax = (1 << 31) - 1;
        let m = M1::new(&P);
        let edge_cases: [(umax, umax); 5] =
            [(0, 0), (1, 1), (P, 0), (P - 1, P - 1), (P + 1, 1)];
        for &(input, expected) in edge_cases.iter() {
            assert_eq!(m.residue(m.transform(input)), expected);
        }

        for _ in 0..NRANDOM {
            let a = random::<umax>();
            check_transform!(umax, a;
                M1 => (1 << 31) - 1,
                M2 => (1 << 61) - 1,
                M3 => (1 << 127) - 1,
                M4 => (1 << 32) - 5,
                M5 => (1 << 56) - 5,
                M6 => (1 << 122) - 3,
                M7 => M7::MODULUS,
            );
        }
    }

    #[test]
    fn creation_test_u64() {
        for _ in 0..NRANDOM {
            let a = random::<u64>();
            check_transform!(u64, a;
                M64_1 => (1 << 31) - 1,
                M64_2 => (1 << 61) - 1,
                M64_3 => (1 << 32) - 5,
                M64_4 => M64_4::MODULUS,
            );
        }
    }

    #[test]
    fn creation_test_u32() {
        for _ in 0..NRANDOM {
            let a = random::<u32>();
            check_transform!(u32, a;
                M32_1 => (1 << 13) - 1,
                M32_2 => (1 << 31) - 1,
                M32_3 => (1 << 16) - 5,
            );
        }
    }

    #[test]
    fn test_against_modops_u128() {
        for _ in 0..NRANDOM {
            let (a, b) = (random::<u128>(), random::<u128>());
            let e = random::<u8>() as umax;
            check_against_modops!(umax; a, b, e; M1, M2, M3, M4, M5, M6);
        }
    }

    #[test]
    fn test_against_modops_u64() {
        for _ in 0..NRANDOM {
            let a = random::<u64>();
            let b = random::<u64>();
            let e = random::<u8>() as u64;
            check_against_modops!(u64; a, b, e; M64_1, M64_2, M64_3);
        }
    }

    #[test]
    fn test_against_modops_u32() {
        for _ in 0..NRANDOM {
            let a = random::<u32>();
            let b = random::<u32>();
            let e = random::<u8>() as u32;
            check_against_modops!(u32; a, b, e; M32_1, M32_2, M32_3);
        }
    }
}