1use crate::impl_fixed_monty_ops;
2use crate::reduced::{impl_reduced_binary_pow, impl_reduced_ops};
3use crate::{powm_u32, powm_u64, udouble, umax, ModularUnaryOps, Reducer};
4
/// Debug-build sanity check that `$m` is not divisible by any of the small primes
/// 3, 5, 7, 11 or 13, unless `$m` is that prime itself.
///
/// This only rules out moduli with a tiny factor; it is not a primality test.
/// Divisibility by 2 is not checked. At the call sites in this file the argument is
/// `K·2^N + 1` with `N > 0` (asserted just before), which is always odd.
/// `$m` is expanded several times, so it should be a side-effect-free expression
/// (the callers pass `Self::MODULUS`).
macro_rules! debug_assert_prime_candidate {
    ($m:expr) => {
        debug_assert!(
            ($m == 3 || $m % 3 != 0)
                && ($m == 5 || $m % 5 != 0)
                && ($m == 7 || $m % 7 != 0)
                && ($m == 11 || $m % 11 != 0)
                && ($m == 13 || $m % 13 != 0)
        )
    };
}
32
/// Generates the inherent items shared by the fixed-width Proth reducers.
///
/// - `$TypeName`: unit struct generic over `<const N: u8, const K: u8>`.
/// - `$T`: single-width modulus type (only 32- and 64-bit widths are supported).
/// - `$D`: double-width type holding products of two `$T` values.
/// - `$neginv_fn`: const fn that must return `-p⁻¹ mod 2^BITS` for `reduce` to be correct.
/// - `$powm`: const fn `(base, exp, modulus)`, assumed to compute `base^exp mod modulus`.
macro_rules! impl_fixed_proth_inherent {
    ($TypeName:ident, $T:ty, $D:ty, $neginv_fn:path, $powm:ident) => {
        impl<const N: u8, const K: u8> $TypeName<N, K> {
            /// Compile-time guard: `N` must be smaller than the bit width of `$T`.
            ///
            /// Associated consts are only evaluated when referenced, so `MODULUS`
            /// references this guard explicitly; otherwise it would never run.
            const _N_BOUND_CHECK: () = assert!((N as u32) < <$T>::BITS);

            /// The Proth modulus `p = K·2^N + 1`.
            ///
            /// Const evaluation fails when `N` is not smaller than the bit width, or when
            /// `p` exceeds `floor(2^BITS / φ)`. In `reduce`, for `t < p²` and `m < 2^BITS`,
            /// `t + m·p < p² + 2^BITS·p`, which stays within the double-width type
            /// exactly when `p ≤ 2^BITS / φ`.
            ///
            /// `K·2^N` is computed with wrapping arithmetic, so an overflowing product
            /// is not rejected here; the `Reducer::new` impls check for it at runtime.
            pub const MODULUS: $T = {
                let () = Self::_N_BOUND_CHECK;
                let p2n = match (1 as $T).checked_shl(N as u32) {
                    Some(v) => v,
                    None => panic!("N must be less than the bit width of the modulus type"),
                };
                let m = (K as $T).wrapping_mul(p2n).wrapping_add(1);
                assert!(
                    m as u128
                        <= match <$T>::BITS {
                            // floor(2^32 / φ) = 2_654_435_769
                            32 => 0x9E37_79B9_u128,
                            // floor(2^64 / φ) = 11_400_714_819_323_198_485
                            64 => 0x9E37_79B9_7F4A_7C15_u128,
                            _ => unreachable!(),
                        },
                    "MODULUS exceeds overflow-free bound; lower N or use FixedMontgomery"
                );
                m
            };

            /// Montgomery constant `-p⁻¹ mod 2^BITS` consumed by `reduce`
            /// (provided `$neginv_fn` returns the negated inverse).
            const N0: $T = $neginv_fn(Self::MODULUS);

            /// `2^(2·BITS) mod p`, i.e. `R² mod p` with `R = 2^BITS`.
            const R2: $T = $powm(2, (2 * <$T>::BITS) as $T, Self::MODULUS);

            /// Montgomery reduction specialised for `p = K·2^N + 1`.
            ///
            /// Returns `t·2^-BITS mod p` in `[0, p)` for inputs `t < p²` (for example
            /// the product of two reduced values). Under the `MODULUS` bound the
            /// intermediate sum cannot overflow. `m·p` is assembled as `((m·K) << N) + m`:
            /// a multiply by the small factor `K` and a shift instead of a full
            /// double-width product.
            #[must_use]
            #[inline]
            pub fn reduce(&self, t: $D) -> $T {
                let m = (t as $T).wrapping_mul(Self::N0);
                // m·K < 2^(BITS + 8) fits in $D, and the shift loses no bits because
                // m·K·2^N < m·p < 2^(2·BITS).
                let mp = ((m as $D) * (K as $D)) << N;
                let mp = mp.wrapping_add(m as $D);
                let r = (t.wrapping_add(mp) >> <$T>::BITS) as $T;
                // r < 2p, so a single conditional subtraction fully reduces it.
                if r >= Self::MODULUS {
                    r - Self::MODULUS
                } else {
                    r
                }
            }
        }
    };
}
84
/// Reducer for a 32-bit Proth modulus `K·2^N + 1` fixed at compile time.
///
/// The modulus lives entirely in the const generic parameters, so the type is zero-sized.
#[derive(Debug, Clone, Copy)]
#[must_use]
pub struct FixedProth32<const N: u8, const K: u8>;
106
107impl_fixed_proth_inherent!(
108 FixedProth32,
109 u32,
110 u64,
111 crate::monty::neg_mod_inv::u32::neginv,
112 powm_u32
113);
114
115impl<const N: u8, const K: u8> Reducer<u32> for FixedProth32<N, K> {
116 #[inline]
117 fn new(m: &u32) -> Self {
118 assert!(
119 *m == Self::MODULUS,
120 "the given modulus doesn't match with the generic params"
121 );
122 assert!(N < 32, "N must be less than type bit width");
123 assert!(N > 0, "N must be positive");
124 assert!(K > 0, "K must be positive");
125 assert!(K % 2 == 1, "K must be odd");
126 assert!(
127 (K as u64) * (1_u64 << (N as u32)) < u32::MAX as u64,
128 "K·2^N + 1 exceeds type maximum"
129 );
130 debug_assert!((K as u32) < (1u32 << (N as u32)), "K must be less than 2^N");
131 debug_assert_prime_candidate!(Self::MODULUS);
132 Self {}
133 }
134 impl_fixed_monty_ops!(u32, u64, Self::R2, primitive);
135}
136
/// Reducer for a 64-bit Proth modulus `K·2^N + 1` fixed at compile time.
///
/// The modulus lives entirely in the const generic parameters, so the type is zero-sized.
#[derive(Debug, Clone, Copy)]
#[must_use]
pub struct FixedProth64<const N: u8, const K: u8>;
157
158impl_fixed_proth_inherent!(
159 FixedProth64,
160 u64,
161 u128,
162 crate::monty::neg_mod_inv::u64::neginv,
163 powm_u64
164);
165
166impl<const N: u8, const K: u8> Reducer<u64> for FixedProth64<N, K> {
167 #[inline]
168 fn new(m: &u64) -> Self {
169 assert!(
170 *m == Self::MODULUS,
171 "the given modulus doesn't match with the generic params"
172 );
173 assert!(N < 64, "N must be less than type bit width");
174 assert!(N > 0, "N must be positive");
175 assert!(K > 0, "K must be positive");
176 assert!(K % 2 == 1, "K must be odd");
177 assert!(
178 (K as u128) * (1_u128 << (N as u32)) < u64::MAX as u128,
179 "K·2^N + 1 exceeds type maximum"
180 );
181 debug_assert!((K as u64) < (1u64 << (N as u32)), "K must be less than 2^N");
182 debug_assert_prime_candidate!(Self::MODULUS);
183 Self {}
184 }
185 impl_fixed_monty_ops!(u64, u128, Self::R2, primitive);
186}
187
/// Reducer for a Proth modulus `K·2^N + 1` in `umax`, fixed at compile time.
///
/// Intermediate products use the double-word `udouble` type. The modulus lives entirely
/// in the const generic parameters, so the type is zero-sized.
#[derive(Debug, Clone, Copy)]
#[must_use]
pub struct FixedProth<const N: u8, const K: u8>;
210
211impl<const N: u8, const K: u8> FixedProth<N, K> {
212 const _N_BOUND_CHECK_U128: () = assert!(N < 128);
214
215 pub const MODULUS: umax = {
216 let p2n = match 1u128.checked_shl(N as u32) {
217 Some(v) => v,
218 None => unreachable!(),
219 };
220 let m = (K as u128).wrapping_mul(p2n).wrapping_add(1);
221 assert!(
224 m <= 210_306_068_529_402_891_650_266_558_847_000_772_608,
225 "MODULUS exceeds overflow-free bound; lower N or use FixedMontgomery"
226 );
227 m
228 };
229
230 const N0: umax = crate::monty::neg_mod_inv::u128::neginv(Self::MODULUS);
232
233 const R2: umax = {
235 let r = udouble { hi: 1, lo: 0 }.div_rem_2by1(Self::MODULUS).1; udouble::widening_square(r).div_rem_2by1(Self::MODULUS).1 };
238
239 #[must_use]
241 #[inline]
242 pub fn reduce(&self, t: udouble) -> umax {
243 let m = t.lo.wrapping_mul(Self::N0);
244 let mk = udouble::widening_mul(m, K as u128);
247 let mp = mk.shl_u32(N as u32) + udouble { hi: 0, lo: m };
248 let r = (t + mp).hi;
249 if r >= Self::MODULUS {
250 r - Self::MODULUS
251 } else {
252 r
253 }
254 }
255}
256
257impl<const N: u8, const K: u8> Reducer<umax> for FixedProth<N, K> {
258 #[inline]
259 fn new(m: &umax) -> Self {
260 assert!(
261 *m == Self::MODULUS,
262 "the given modulus doesn't match with the generic params"
263 );
264 assert!(N < 128, "N must be less than type bit width");
265 assert!(N > 0, "N must be positive");
266 assert!(K > 0, "K must be positive");
267 assert!(K % 2 == 1, "K must be odd");
268 assert!(
269 (K as u128) * (1u128 << (N as u32)) < u128::MAX,
270 "K·2^N + 1 exceeds type maximum"
271 );
272 debug_assert!(
273 (K as u128) < (1u128 << (N as u32)),
274 "K must be less than 2^N"
275 );
276 debug_assert_prime_candidate!(Self::MODULUS);
277 Self {}
278 }
279 impl_fixed_monty_ops!(umax, udouble, Self::R2, udouble);
280}
281
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ModularCoreOps, ModularPow};
    use rand::random;

    type P128_1 = FixedProth<2, 1>; // 5
    type P128_2 = FixedProth<4, 1>; // 17
    type P128_3 = FixedProth<5, 3>; // 97
    type P128_4 = FixedProth<8, 3>; // 769
    type P128_5 = FixedProth<16, 1>; // 65537
    type P64_1 = FixedProth64<4, 1>; // 17
    type P64_2 = FixedProth64<5, 3>; // 97
    type P64_3 = FixedProth64<8, 1>; // 257
    type P64_4 = FixedProth64<16, 1>; // 65537
    type P32_1 = FixedProth32<2, 1>; // 5
    type P32_2 = FixedProth32<2, 3>; // 13
    type P32_3 = FixedProth32<4, 1>; // 17
    type P32_4 = FixedProth32<3, 5>; // 41

    /// Number of random samples drawn per randomized test.
    const NRANDOM: u32 = 10;

    /// For each reducer type, checks that `residue(transform(a % M)) == a % M`.
    macro_rules! check_roundtrip {
        ($T:ty, $a:ident; $($M:ty)*) => ($({
            const M: $T = <$M>::MODULUS;
            let r = <$M>::new(&M);
            assert_eq!(r.residue(r.transform($a % M)), $a % M);
        })*);
    }

    /// For each reducer type, compares every reducer operation on `a`, `b` (and the
    /// exponent `e`) against the corresponding plain modular operation.
    macro_rules! check_against_modops {
        ($T:ty, $a:ident, $b:ident, $e:ident; $($M:ty)*) => ($({
            const P: $T = <$M>::MODULUS;
            let r = <$M>::new(&P);
            let am = r.transform($a);
            let bm = r.transform($b);
            assert_eq!(r.residue(r.add(&am, &bm)), $a.addm($b, &P));
            assert_eq!(r.residue(r.sub(&am, &bm)), $a.subm($b, &P));
            assert_eq!(r.residue(r.mul(&am, &bm)), $a.mulm($b, &P));
            assert_eq!(r.residue(r.neg(am)), $a.negm(&P));
            assert_eq!(r.residue(r.dbl(am)), $a.dblm(&P));
            assert_eq!(r.residue(r.sqr(am)), $a.sqm(&P));
            assert_eq!(r.residue(r.pow(am, &$e)), $a.powm($e, &P));
            if let (Some(inv), Some(ref_inv)) = (r.inv(am), $a.invm(&P)) {
                assert_eq!(r.residue(inv), ref_inv);
            }
        })*);
    }

    #[test]
    fn creation_test_u128() {
        for _ in 0..NRANDOM {
            let a = random::<u128>();
            check_roundtrip!(u128, a; P128_1 P128_2 P128_3 P128_4 P128_5);
        }
    }

    #[test]
    fn creation_test_u64() {
        for _ in 0..NRANDOM {
            let a = random::<u64>();
            check_roundtrip!(u64, a; P64_1 P64_2 P64_3 P64_4);
        }
    }

    #[test]
    fn creation_test_u32() {
        for _ in 0..NRANDOM {
            let a = random::<u32>();
            check_roundtrip!(u32, a; P32_1 P32_2 P32_3 P32_4);
        }
    }

    #[test]
    fn test_against_modops_u128() {
        for _ in 0..NRANDOM {
            let a = random::<u128>();
            let b = random::<u128>();
            let e = random::<u8>() as u128;
            check_against_modops!(u128, a, b, e; P128_1 P128_2 P128_3 P128_4 P128_5);
        }
    }

    #[test]
    fn test_against_modops_u64() {
        for _ in 0..NRANDOM {
            let a = random::<u64>();
            let b = random::<u64>();
            let e = random::<u8>() as u64;
            check_against_modops!(u64, a, b, e; P64_1 P64_2 P64_3 P64_4);
        }
    }

    #[test]
    fn test_against_modops_u32() {
        for _ in 0..NRANDOM {
            let a = random::<u32>();
            let b = random::<u32>();
            let e = random::<u8>() as u32;
            check_against_modops!(u32, a, b, e; P32_1 P32_2 P32_3 P32_4);
        }
    }

    #[test]
    fn test_add_near_overflow_u64() {
        type S = FixedProth64<32, 3>;
        const M: u64 = <S>::MODULUS;
        let r = S::new(&M);

        let a = M - 1;
        let b = M - 2;
        let am = r.transform(a);
        let bm = r.transform(b);
        let sum = r.add(&am, &bm);
        assert_eq!(r.residue(sum), a.addm(b, &M));

        let a2 = M - 1;
        let a2m = r.transform(a2);
        let dbl = r.dbl(a2m);
        assert_eq!(r.residue(dbl), a2.dblm(&M));
    }

    #[test]
    fn test_reduce_near_bound() {
        type S = FixedProth32<23, 255>;
        const M: u32 = <S>::MODULUS;
        let r = S::new(&M);

        for _ in 0..10 {
            let a = random::<u32>() % M;
            let b = random::<u32>() % M;
            let am = r.transform(a);
            let bm = r.transform(b);
            let result = r.residue(r.mul(&am, &bm));
            assert_eq!(result, a.mulm(b, &M));
        }
    }

    #[test]
    fn test_inv_no_truncation_u128() {
        type S = FixedProth<60, 31>;
        const M: u128 = <S>::MODULUS;
        assert!(
            M > u64::MAX as u128,
            "MODULUS must exceed u64::MAX for this test"
        );
        let r = S::new(&M);

        let a: u128 = 1234567890123456789 % M;
        let a_mont = r.transform(a);
        let inv = r.inv(a_mont).expect("inv should succeed");
        let result = r.residue(inv);
        assert_eq!(result.mulm(a, &M), 1u128, "inv truncation bug");
    }

    #[test]
    #[should_panic(expected = "exceeds type maximum")]
    fn test_modulus_overflow_panics_u32() {
        type S = FixedProth32<31, 3>;
        const M: u32 = <S>::MODULUS;
        let _ = S::new(&M);
    }

    #[test]
    fn test_reduce_n_gt_64() {
        type S = FixedProth<65, 3>;
        const M: u128 = <S>::MODULUS;
        let r = S::new(&M);

        for _ in 0..10 {
            let a = random::<u128>() % M;
            let b = random::<u128>() % M;
            let am = r.transform(a);
            let bm = r.transform(b);
            let result = r.residue(r.mul(&am, &bm));
            assert_eq!(result, a.mulm(b, &M), "shift truncation bug for N>64");
        }
    }
}