1use crate::reduced::{impl_reduced_binary_pow, impl_reduced_ops};
2use crate::{imax, udouble, umax, ModularUnaryOps, Reducer};
3
// Generates everything a fixed trinomial Solinas reducer type needs: its
// associated constants, the `reduce_single` / `reduce_double` routines and the
// `Reducer` implementation.
//
// The modulus has the shape `2^P1 - 2^P2 + K`, hence `2^P1 ≡ 2^P2 - K`. A value
// split as `hi * 2^P1 + lo` is therefore congruent to `hi * 2^P2 - hi * K + lo`;
// replacing the former by the latter is called a "fold" below.
//
// Arguments of the entry arm:
// * `$TypeName`  - struct declared with `<const P1: u8, const P2: u8, const K: $K>`
// * `$T`         - unsigned type holding residues and the modulus
// * `$K`         - signed type of the `K` parameter
// * `$D`         - double-width type used for intermediate values
// * `$half_bits` - multiplication stays in `$T` while `P1 < $half_bits`
// * `$max_P1`    - largest `P1` accepted by `new` (debug builds only)
// * `$kind`      - `primitive` when `$D` is a native integer, `udouble` when it
//                  is the crate's `udouble` struct
//
// The remaining `@...` arms are internal helpers selected by `$kind`.
macro_rules! impl_fixed_trinomial_solinas {
    (
        $TypeName:ident,
        $T:ty,
        $K:ty,
        $D:ty,
        $half_bits:expr,
        $max_P1:expr,
        $kind:ident
    ) => {
        impl<const P1: u8, const P2: u8, const K: $K> $TypeName<P1, P2, K> {
            /// Mask selecting the low `P1` bits; all ones when `P1` is the full type width.
            const BITMASK: $T = match (1 as $T).checked_shl(P1 as u32) {
                Some(v) => v.wrapping_sub(1),
                None => <$T>::MAX,
            };

            /// The modulus `2^P1 - 2^P2 + K`.
            ///
            /// Computed with wrapping arithmetic so that `P1` equal to the type width
            /// (where `2^P1` itself is not representable and is taken as 0) still
            /// yields the right value.
            pub const MODULUS: $T = {
                let p1 = match (1 as $T).checked_shl(P1 as u32) {
                    Some(v) => v,
                    None => 0,
                };
                let p2 = match (1 as $T).checked_shl(P2 as u32) {
                    Some(v) => v,
                    None => panic!("P2 exceeds type width"),
                };
                if K >= 0 {
                    p1.wrapping_sub(p2).wrapping_add(K as $T)
                } else {
                    p1.wrapping_sub(p2).wrapping_sub((-K) as $T)
                }
            };

            /// Fold budget used by `reduce_double` to pick between the unrolled
            /// paths and the generic loop.
            ///
            /// The base is `ceil(P1 / (P1 - P2))`, with one extra fold for `K > 0`
            /// and two otherwise. `K == 0` is not a valid parameter (`new` requires
            /// an odd `K`), but the reduction routines are callable without `new`,
            /// so it gets the conservative budget instead of a constant that ignores
            /// the gap. Over-budgeting is harmless: a fold with `hi == 0` leaves
            /// `lo` and `hi` unchanged.
            const FOLDS: u32 = {
                let gap = (P1 - P2) as u32;
                let folds_ceil = ((P1 as u32) + gap - 1) / gap;
                if K > 0 {
                    folds_ceil + 1
                } else {
                    folds_ceil + 2
                }
            };

            /// Final step of every reduction: maps a folded value (already below
            /// `2^P1`) to its canonical residue in `[0, MODULUS)`.
            ///
            /// One subtraction is enough except when `P1 == P2 + 1` and `K < 0`.
            /// There `MODULUS < 2^(P1 - 1)`, so the value can still be at least
            /// `2 * MODULUS` and the remainder is taken explicitly.
            #[inline]
            const fn canonicalize(v: $T) -> $T {
                if v < Self::MODULUS {
                    return v;
                }
                let once = v - Self::MODULUS;
                if once < Self::MODULUS {
                    return once;
                }
                match once.checked_rem(Self::MODULUS) {
                    Some(r) => r,
                    // Only a zero modulus (invalid parameters) ends up here; keep the
                    // value untouched rather than introducing a panic.
                    None => once,
                }
            }

            impl_fixed_trinomial_solinas!(@reduce_single, $kind, $T, $D);
            impl_fixed_trinomial_solinas!(@reduce_double, $kind, $T, $D);
        }

        impl<const P1: u8, const P2: u8, const K: $K> Reducer<$T> for $TypeName<P1, P2, K> {
            /// Creates the reducer.
            ///
            /// # Panics
            ///
            /// Panics if `m` differs from `Self::MODULUS`. Debug builds additionally
            /// check that `P1` is within range, `0 < P2 < P1`, `K` is odd with
            /// `|K| < 2^P2`, and that the modulus is not a proper multiple of
            /// 3, 5, 7, 11 or 13.
            #[inline]
            fn new(m: &$T) -> Self {
                assert!(
                    *m == Self::MODULUS,
                    "the given modulus doesn't match with the generic params"
                );
                debug_assert!(P1 <= $max_P1);
                debug_assert!(P2 > 0 && P1 > P2);
                debug_assert!(K % 2 != 0);
                debug_assert!((K.unsigned_abs() as u128) < (1u128 << (P2 as u32)));
                debug_assert!(
                    (Self::MODULUS == 3 || Self::MODULUS % 3 != 0)
                        && (Self::MODULUS == 5 || Self::MODULUS % 5 != 0)
                        && (Self::MODULUS == 7 || Self::MODULUS % 7 != 0)
                        && (Self::MODULUS == 11 || Self::MODULUS % 11 != 0)
                        && (Self::MODULUS == 13 || Self::MODULUS % 13 != 0)
                );
                Self {}
            }

            /// Reduces `target` into `[0, MODULUS)`.
            #[inline]
            fn transform(&self, target: $T) -> $T {
                Self::reduce_single(target)
            }

            /// Residues are stored in plain form, so this is the identity.
            #[inline]
            fn residue(&self, target: $T) -> $T {
                target
            }

            impl_reduced_ops!($T);

            #[inline]
            fn mul(&self, lhs: &$T, rhs: &$T) -> $T {
                // Below the half-width threshold the product of two residues
                // cannot overflow the single-width type.
                if (P1 as u32) < $half_bits {
                    Self::reduce_single(lhs * rhs)
                } else {
                    Self::reduce_double(
                        impl_fixed_trinomial_solinas!(@widen_mul, $kind, $T, $D, lhs, rhs),
                    )
                }
            }

            #[inline]
            fn inv(&self, target: $T) -> Option<$T> {
                // The modulus is below 2^P1, so for small P1 both operands fit in
                // `usize` and the inversion is delegated to the `usize` implementation.
                if (P1 as u32) < usize::BITS {
                    (target as usize)
                        .invm(&(Self::MODULUS as usize))
                        .map(|v| v as $T)
                } else {
                    target.invm(&Self::MODULUS)
                }
            }

            #[inline]
            fn sqr(&self, target: $T) -> $T {
                // Same width split as `mul`.
                if (P1 as u32) < $half_bits {
                    Self::reduce_single(target * target)
                } else {
                    Self::reduce_double(
                        impl_fixed_trinomial_solinas!(@widen_sqr, $kind, $T, $D, target),
                    )
                }
            }

            impl_reduced_binary_pow!($T);
        }
    };

    // Single-width reduction, double-width type is a native integer.
    (@reduce_single, primitive, $T:ty, $D:ty) => {
        /// Reduces a single-width value modulo `Self::MODULUS`.
        pub const fn reduce_single(v: $T) -> $T {
            let mut v: $D = v as $D;
            // Fold until nothing is left above bit `P1`.
            while (v >> P1) > 0 {
                let lo = (v as $T) & Self::BITMASK;
                let hi = v >> P1;
                let mut sum: $D = (hi << (P2 as u32)) + (lo as $D);
                if K > 0 {
                    sum -= hi * (K as $D);
                } else if K < 0 {
                    sum += hi * ((-K) as $D);
                }
                v = sum;
            }
            Self::canonicalize(v as $T)
        }
    };

    // Single-width reduction, double-width type is the `udouble` struct.
    (@reduce_single, udouble, $T:ty, $D:ty) => {
        /// Reduces a single-width value modulo `Self::MODULUS`.
        pub fn reduce_single(v: $T) -> $T {
            let mut v: $D = udouble { hi: 0, lo: v };
            // Fold until nothing is left above bit `P1`.
            while v.hi > 0 || v.lo >> P1 > 0 {
                let lo = v.lo & Self::BITMASK;
                let hi = v >> P1;
                let mut sum = (hi << (P2 as u32)) + lo;
                if K > 0 {
                    sum -= hi * (K as umax);
                } else if K < 0 {
                    sum += hi * ((-K) as umax);
                }
                v = sum;
            }
            Self::canonicalize(v.lo)
        }
    };

    // Double-width reduction, double-width type is a native integer.
    (@reduce_double, primitive, $T:ty, $D:ty) => {
        /// Reduces a double-width value (typically a product of two residues)
        /// modulo `Self::MODULUS`.
        // The `hi` written by the last unrolled fold is intentionally never read.
        #[allow(unused_assignments)]
        pub fn reduce_double(v: $D) -> $T {
            let mut lo = (v as $T) & Self::BITMASK;
            let mut hi = v >> P1;
            // One fold: (hi, lo) <- split(hi * 2^P2 - hi * K + lo).
            macro_rules! solinas_fold {
                () => {
                    let mut sum: $D = (hi << (P2 as u32)) + (lo as $D);
                    if K > 0 {
                        sum -= hi * (K as $D);
                    } else if K < 0 {
                        sum += hi * ((-K) as $D);
                    }
                    lo = (sum as $T) & Self::BITMASK;
                    hi = sum >> P1;
                };
            }
            // `FOLDS` is a constant, so only one of these branches survives.
            if Self::FOLDS <= 3 {
                { solinas_fold!(); }
                { solinas_fold!(); }
                { solinas_fold!(); }
            } else if Self::FOLDS == 4 {
                { solinas_fold!(); }
                { solinas_fold!(); }
                { solinas_fold!(); }
                { solinas_fold!(); }
            } else {
                while hi > 0 {
                    solinas_fold!();
                }
            }
            Self::canonicalize(lo)
        }
    };

    // Double-width reduction, double-width type is the `udouble` struct.
    (@reduce_double, udouble, $T:ty, $D:ty) => {
        /// Reduces a double-width value (typically a product of two residues)
        /// modulo `Self::MODULUS`.
        // The `hi` written by the last unrolled fold is intentionally never read.
        #[allow(unused_assignments)]
        pub fn reduce_double(v: $D) -> $T {
            let mut lo = v.lo & Self::BITMASK;
            let mut hi = v >> P1;
            // One fold: (hi, lo) <- split(hi * 2^P2 - hi * K + lo).
            macro_rules! udouble_fold {
                () => {
                    let mut sum = (hi << (P2 as u32)) + lo;
                    if K > 0 {
                        sum -= hi * (K as umax);
                    } else if K < 0 {
                        sum += hi * ((-K) as umax);
                    }
                    lo = sum.lo & Self::BITMASK;
                    hi = sum >> P1;
                };
            }
            // `FOLDS` is a constant, so only one of these branches survives.
            if Self::FOLDS <= 3 {
                { udouble_fold!(); }
                { udouble_fold!(); }
                { udouble_fold!(); }
            } else if Self::FOLDS == 4 {
                { udouble_fold!(); }
                { udouble_fold!(); }
                { udouble_fold!(); }
                { udouble_fold!(); }
            } else {
                while hi.hi > 0 || hi.lo > 0 {
                    udouble_fold!();
                }
            }
            Self::canonicalize(lo)
        }
    };

    // Full-width product of two `&$T` operands as a native double-width integer.
    (@widen_mul, primitive, $T:ty, $D:ty, $lhs:expr, $rhs:expr) => {
        (*$lhs as $D) * (*$rhs as $D)
    };

    // Full-width product of two `&$T` operands as a `udouble`.
    (@widen_mul, udouble, $T:ty, $D:ty, $lhs:expr, $rhs:expr) => {
        <$D>::widening_mul(*$lhs, *$rhs)
    };

    // Full-width square as a native double-width integer.
    (@widen_sqr, primitive, $T:ty, $D:ty, $target:expr) => {
        ($target as $D) * ($target as $D)
    };

    // Full-width square as a `udouble`.
    (@widen_sqr, udouble, $T:ty, $D:ty, $target:expr) => {
        <$D>::widening_square($target)
    };
}
285
/// Field-less reducer for a 32-bit modulus of the form `2^P1 - 2^P2 + K`.
///
/// The modulus is fixed at compile time by the const parameters and exposed as
/// the associated constant `MODULUS`; the type itself stores nothing.
///
/// `Reducer::new` panics if the modulus passed in differs from `MODULUS`. In
/// debug builds it also checks `P1 <= 32`, `0 < P2 < P1`, that `K` is odd with
/// `|K| < 2^P2`, and that the modulus is not a proper multiple of 3, 5, 7, 11
/// or 13.
#[must_use]
#[derive(Debug, Clone, Copy)]
pub struct FixedTrinomialSolinas32<const P1: u8, const P2: u8, const K: i32>();

// Residues are `u32`, `K` is `i32`, double-width intermediates are native `u64`.
// Products stay in `u32` while `P1 < 16`; `P1` may be at most 32.
impl_fixed_trinomial_solinas!(FixedTrinomialSolinas32, u32, i32, u64, 16, 32, primitive);
310
/// Field-less reducer for a 64-bit modulus of the form `2^P1 - 2^P2 + K`.
///
/// The modulus is fixed at compile time by the const parameters and exposed as
/// the associated constant `MODULUS`; the type itself stores nothing.
///
/// `Reducer::new` panics if the modulus passed in differs from `MODULUS`. In
/// debug builds it also checks `P1 <= 64`, `0 < P2 < P1`, that `K` is odd with
/// `|K| < 2^P2`, and that the modulus is not a proper multiple of 3, 5, 7, 11
/// or 13.
#[must_use]
#[derive(Debug, Clone, Copy)]
pub struct FixedTrinomialSolinas64<const P1: u8, const P2: u8, const K: i64>();

// Residues are `u64`, `K` is `i64`, double-width intermediates are native `u128`.
// Products stay in `u64` while `P1 < 32`; `P1` may be at most 64.
impl_fixed_trinomial_solinas!(FixedTrinomialSolinas64, u64, i64, u128, 32, 64, primitive);
336
/// Field-less reducer for a `umax`-sized modulus of the form `2^P1 - 2^P2 + K`.
///
/// The modulus is fixed at compile time by the const parameters and exposed as
/// the associated constant `MODULUS`; the type itself stores nothing.
///
/// `Reducer::new` panics if the modulus passed in differs from `MODULUS`. In
/// debug builds it also checks `P1 <= 127`, `0 < P2 < P1`, that `K` is odd with
/// `|K| < 2^P2`, and that the modulus is not a proper multiple of 3, 5, 7, 11
/// or 13.
#[must_use]
#[derive(Debug, Clone, Copy)]
pub struct FixedTrinomialSolinas<const P1: u8, const P2: u8, const K: imax>();

// Residues are `umax`, `K` is `imax`, double-width intermediates use the crate's
// `udouble` struct (hence the `udouble` arms of the macro).
// Products stay in `umax` while `P1 < 64`; `P1` may be at most 127.
impl_fixed_trinomial_solinas!(FixedTrinomialSolinas, umax, imax, udouble, 64, 127, udouble);
363
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ModularCoreOps, ModularPow};
    use rand::random;

    // Reducers backed by `umax`.
    type S1 = FixedTrinomialSolinas<31, 13, 1>;
    type S2 = FixedTrinomialSolinas<61, 30, 1>;
    type S3 = FixedTrinomialSolinas<127, 64, 1>;
    type S4 = FixedTrinomialSolinas<32, 16, 1>;
    type S5 = FixedTrinomialSolinas<56, 28, -1>;
    type S6 = FixedTrinomialSolinas<122, 61, -3>;

    // Reducers backed by `u64`.
    type S64_1 = FixedTrinomialSolinas64<31, 13, 1>;
    type S64_2 = FixedTrinomialSolinas64<61, 30, 1>;
    type S64_3 = FixedTrinomialSolinas64<32, 16, 1>;
    type S64_4 = FixedTrinomialSolinas64<64, 32, 1>;

    // Reducers backed by `u32`.
    type S32_1 = FixedTrinomialSolinas32<4, 2, 1>;
    type S32_2 = FixedTrinomialSolinas32<5, 3, -1>;
    type S32_3 = FixedTrinomialSolinas32<6, 2, 1>;
    type S32_4 = FixedTrinomialSolinas32<32, 20, 1>;

    /// Number of random samples drawn by each randomized test.
    const NRANDOM: u32 = 10;

    // For every listed reducer: transforming `$x` and reading the residue back
    // must agree with the plain `%` operator.
    macro_rules! assert_roundtrip {
        ($T:ty, $x:ident; $($M:ty)*) => {$({
            const P: $T = <$M>::MODULUS;
            let reducer = <$M>::new(&P);
            assert_eq!(reducer.residue(reducer.transform($x)), $x % P);
        })*};
    }

    // For every listed reducer: each arithmetic operation must agree with the
    // corresponding generic modular operation on the raw integers.
    macro_rules! assert_matches_modops {
        ($T:ty; $a:ident, $b:ident, $e:ident; $($M:ty)*) => {$({
            const P: $T = <$M>::MODULUS;
            let r = <$M>::new(&P);
            let am = r.transform($a);
            let bm = r.transform($b);
            assert_eq!(r.add(&am, &bm), $a.addm($b, &P));
            assert_eq!(r.sub(&am, &bm), $a.subm($b, &P));
            assert_eq!(r.mul(&am, &bm), $a.mulm($b, &P));
            assert_eq!(r.neg(am), $a.negm(&P));
            assert_eq!(r.inv(am), $a.invm(&P));
            assert_eq!(r.dbl(am), $a.dblm(&P));
            assert_eq!(r.sqr(am), $a.sqm(&P));
            assert_eq!(r.pow(am, &$e), $a.powm($e, &P));
        })*};
    }

    #[test]
    fn creation_test_u128() {
        // Fixed values around the modulus first.
        const P: umax = <S1>::MODULUS;
        let m = S1::new(&P);
        assert_eq!(m.residue(m.transform(0)), 0);
        assert_eq!(m.residue(m.transform(1)), 1);
        assert_eq!(m.residue(m.transform(P)), 0);
        assert_eq!(m.residue(m.transform(P - 1)), P - 1);
        assert_eq!(m.residue(m.transform(P + 1)), 1);

        for _ in 0..NRANDOM {
            let a = random::<umax>();
            assert_roundtrip!(umax, a; S1 S2 S3 S4 S5 S6);
        }
    }

    #[test]
    fn creation_test_u64() {
        for _ in 0..NRANDOM {
            let a = random::<u64>();
            assert_roundtrip!(u64, a; S64_1 S64_2 S64_3 S64_4);
        }
    }

    #[test]
    fn creation_test_u32() {
        for _ in 0..NRANDOM {
            let a = random::<u32>();
            assert_roundtrip!(u32, a; S32_1 S32_2 S32_3 S32_4);
        }
    }

    #[test]
    fn test_against_modops_u128() {
        for _ in 0..NRANDOM {
            let (a, b) = (random::<u128>(), random::<u128>());
            let e = random::<u8>() as umax;
            assert_matches_modops!(umax; a, b, e; S1 S2 S3 S4 S5 S6);
        }
    }

    #[test]
    fn test_against_modops_u64() {
        for _ in 0..NRANDOM {
            let a = random::<u64>();
            let b = random::<u64>();
            let e = random::<u8>() as u64;
            assert_matches_modops!(u64; a, b, e; S64_1 S64_2 S64_3 S64_4);
        }
    }

    #[test]
    fn test_against_modops_u32() {
        for _ in 0..NRANDOM {
            let a = random::<u32>();
            let b = random::<u32>();
            let e = random::<u8>() as u32;
            assert_matches_modops!(u32; a, b, e; S32_1 S32_2 S32_3 S32_4);
        }
    }

    #[test]
    fn test_add_near_overflow_u64() {
        // The modulus sits just below 2^64, so sums of large residues exceed `u64`.
        type S = FixedTrinomialSolinas64<64, 32, 1>;
        const P: u64 = <S>::MODULUS;
        assert_eq!(P, 0xFFFFFFFF00000001);

        let r = S::new(&P);

        // (P - 1) + (P - 2) = 2P - 3, which is P - 3 modulo P.
        let a = r.transform(P - 1);
        let b = r.transform(P - 2);
        assert_eq!(r.residue(r.add(&a, &b)), P - 3);

        // 2 * (P - 1) = 2P - 2, which is P - 2 modulo P.
        let c = r.transform(P - 1);
        assert_eq!(r.residue(r.dbl(c)), P - 2);
    }
}