1use crate::reduced::impl_reduced_binary_pow;
2use crate::{imax, udouble, umax, ModularUnaryOps, Reducer};
3
/// Implements a reducer whose modulus is fixed at compile time to the trinomial
/// `2^P1 - 2^P2 + K` (a Solinas / generalized Mersenne form).
///
/// Reduction uses the congruence `2^P1 = 2^P2 - K (mod MODULUS)`: a value written as
/// `hi * 2^P1 + lo` is congruent to `hi * (2^P2 - K) + lo`. That replacement ("fold") lowers
/// the value by exactly `hi * MODULUS`, so repeating it until `hi == 0` leaves a value
/// below `2^P1`, which a single conditional subtraction brings into `[0, MODULUS)`.
///
/// # Main arm
///
/// `impl_fixed_trinomial_solinas!(TypeName, T, K, D, half_bits, max_P1, kind)` where
/// - `TypeName` is a struct generic over `const P1: u8, const P2: u8, const K: $K`,
/// - `T` is the unsigned residue type and `K` the signed type of the `K` parameter,
/// - `D` is the double-width type used for full products of two `T` values,
/// - `half_bits` is the `P1` threshold below which products are computed directly in `T`,
/// - `max_P1` is the largest accepted `P1` (checked in debug builds by `new`),
/// - `kind` is `primitive` when `D` is a built-in integer and `udouble` when it is `udouble`.
///
/// # Internal arms
///
/// The `@reduce_single`, `@reduce_double`, `@widen_mul` and `@widen_sqr` arms are helpers
/// selected by `kind`. The `@reduce_*` arms expand to associated functions and therefore
/// must be invoked inside an `impl` that has `P1`, `P2`, `K`, `Self::BITMASK`,
/// `Self::MODULUS` and `Self::FOLDS` in scope.
macro_rules! impl_fixed_trinomial_solinas {
    (
        $TypeName:ident,
        $T:ty,
        $K:ty,
        $D:ty,
        $half_bits:expr,
        $max_P1:expr,
        $kind:ident
    ) => {
        impl<const P1: u8, const P2: u8, const K: $K> $TypeName<P1, P2, K> {
            /// Mask selecting the low `P1` bits (every bit when `P1` is the full type width).
            const BITMASK: $T = match (1 as $T).checked_shl(P1 as u32) {
                Some(v) => v.wrapping_sub(1),
                None => <$T>::MAX,
            };

            /// The modulus `2^P1 - 2^P2 + K`.
            pub const MODULUS: $T = {
                // When `P1` equals the type width, `2^P1` is represented as 0 and the
                // wrapping arithmetic below still yields the correct modulus.
                let p1 = match (1 as $T).checked_shl(P1 as u32) {
                    Some(v) => v,
                    None => 0,
                };
                let p2 = match (1 as $T).checked_shl(P2 as u32) {
                    Some(v) => v,
                    None => panic!("P2 exceeds type width"),
                };
                if K >= 0 {
                    p1.wrapping_sub(p2).wrapping_add(K as $T)
                } else {
                    p1.wrapping_sub(p2).wrapping_sub((-K) as $T)
                }
            };

            /// Number of folds `reduce_double` performs unconditionally before it starts
            /// checking the high part.
            ///
            /// Each fold shortens the high part by roughly `P1 - P2` bits; a negative `K`
            /// makes the folding factor slightly larger, hence the extra fold. This is only
            /// an unrolling hint: `reduce_double` keeps folding while the high part is
            /// non-zero, so correctness does not depend on this value.
            const FOLDS: u32 = {
                let gap = (P1 - P2) as u32;
                // ceil(P1 / gap)
                let folds_ceil = ((P1 as u32) + gap - 1) / gap;
                if K > 0 {
                    folds_ceil + 1
                } else if K < 0 {
                    folds_ceil + 2
                } else {
                    // `K == 0` is rejected by `new` in debug builds (`K` must be odd).
                    1
                }
            };

            // Kind-specific reduction routines.
            impl_fixed_trinomial_solinas!(@reduce_single, $kind, $T, $D);
            impl_fixed_trinomial_solinas!(@reduce_double, $kind, $T, $D);
        }

        impl<const P1: u8, const P2: u8, const K: $K> Reducer<$T> for $TypeName<P1, P2, K> {
            /// Creates the reducer.
            ///
            /// # Panics
            ///
            /// Panics if `m` differs from `Self::MODULUS`. In debug builds it additionally
            /// panics when the generic parameters are out of range or the modulus has a
            /// small prime factor.
            #[inline]
            fn new(m: &$T) -> Self {
                assert!(
                    *m == Self::MODULUS,
                    "the given modulus doesn't match with the generic params"
                );
                debug_assert!(P1 <= $max_P1);
                debug_assert!(P2 > 0 && P1 > P2);
                // With `P2 > 0` the term `2^P1 - 2^P2` is even, so an even `K` would give
                // an even modulus.
                debug_assert!(K % 2 != 0);
                // `|K| < 2^P2` keeps `2^P2 - K` positive, so folding cannot underflow.
                debug_assert!((K.unsigned_abs() as u128) < (1u128 << (P2 as u32)));
                // Cheap trial division by the first few odd primes.
                debug_assert!(
                    (Self::MODULUS == 3 || Self::MODULUS % 3 != 0)
                        && (Self::MODULUS == 5 || Self::MODULUS % 5 != 0)
                        && (Self::MODULUS == 7 || Self::MODULUS % 7 != 0)
                        && (Self::MODULUS == 11 || Self::MODULUS % 11 != 0)
                        && (Self::MODULUS == 13 || Self::MODULUS % 13 != 0)
                );
                Self {}
            }

            /// Reduces an arbitrary value into `[0, MODULUS)`.
            #[inline]
            fn transform(&self, target: $T) -> $T {
                Self::reduce_single(target)
            }

            /// Returns whether `target` is already a reduced residue.
            #[inline]
            fn check(&self, target: &$T) -> bool {
                *target < Self::MODULUS
            }

            /// Residues are stored in plain form, so this is the identity.
            #[inline]
            fn residue(&self, target: $T) -> $T {
                target
            }

            /// Returns the compile-time modulus.
            #[inline]
            fn modulus(&self) -> $T {
                Self::MODULUS
            }

            /// Returns whether `target` is the zero residue.
            #[inline]
            fn is_zero(&self, target: &$T) -> bool {
                *target == 0
            }

            /// Modular addition of two reduced residues.
            #[inline]
            fn add(&self, lhs: &$T, rhs: &$T) -> $T {
                let (sum, overflow) = lhs.overflowing_add(*rhs);
                if overflow || sum >= Self::MODULUS {
                    // If the addition wrapped, the wrapping subtraction wraps back and
                    // still produces the exact result.
                    sum.wrapping_sub(Self::MODULUS)
                } else {
                    sum
                }
            }

            /// Modular subtraction of two reduced residues.
            #[inline]
            fn sub(&self, lhs: &$T, rhs: &$T) -> $T {
                if lhs >= rhs {
                    lhs - rhs
                } else {
                    Self::MODULUS - (rhs - lhs)
                }
            }

            /// Modular doubling of a reduced residue.
            #[inline]
            fn dbl(&self, target: $T) -> $T {
                let (sum, overflow) = target.overflowing_add(target);
                if overflow || sum >= Self::MODULUS {
                    sum.wrapping_sub(Self::MODULUS)
                } else {
                    sum
                }
            }

            /// Modular negation of a reduced residue.
            #[inline]
            fn neg(&self, target: $T) -> $T {
                if target == 0 {
                    0
                } else {
                    Self::MODULUS - target
                }
            }

            /// Modular multiplication of two reduced residues.
            #[inline]
            fn mul(&self, lhs: &$T, rhs: &$T) -> $T {
                if (P1 as u32) < $half_bits {
                    // Both factors are below `2^P1`, so the product fits in the residue type.
                    Self::reduce_single(lhs * rhs)
                } else {
                    Self::reduce_double(impl_fixed_trinomial_solinas!(@widen_mul, $kind, $T, $D, lhs, rhs))
                }
            }

            /// Modular inverse, or `None` when `invm` reports that none exists.
            #[inline]
            fn inv(&self, target: $T) -> Option<$T> {
                if (P1 as u32) < usize::BITS {
                    // The modulus fits in a machine word, so use the `usize` routine.
                    (target as usize)
                        .invm(&(Self::MODULUS as usize))
                        .map(|v| v as $T)
                } else {
                    target.invm(&Self::MODULUS)
                }
            }

            /// Modular squaring of a reduced residue.
            #[inline]
            fn sqr(&self, target: $T) -> $T {
                if (P1 as u32) < $half_bits {
                    Self::reduce_single(target * target)
                } else {
                    Self::reduce_double(impl_fixed_trinomial_solinas!(@widen_sqr, $kind, $T, $D, target))
                }
            }

            // The remaining trait items are supplied by the shared helper macro.
            impl_reduced_binary_pow!($T);
        }
    };

    (@reduce_single, primitive, $T:ty, $D:ty) => {
        /// Reduces a single-width value into `[0, MODULUS)`.
        pub const fn reduce_single(v: $T) -> $T {
            // Work in the wide type so the fold cannot overflow.
            let mut v: $D = v as $D;
            while (v >> P1) > 0 {
                let lo = (v as $T) & Self::BITMASK;
                let hi = v >> P1;
                // hi * 2^P1 + lo  ->  hi * (2^P2 - K) + lo
                let mut sum: $D = (hi << (P2 as u32)) + (lo as $D);
                if K > 0 {
                    sum -= hi * (K as $D);
                } else if K < 0 {
                    sum += hi * ((-K) as $D);
                }
                v = sum;
            }
            // `v < 2^P1` here, so one conditional subtraction finishes the reduction.
            let v = v as $T;
            if v >= Self::MODULUS {
                v - Self::MODULUS
            } else {
                v
            }
        }
    };

    (@reduce_single, udouble, $T:ty, $D:ty) => {
        /// Reduces a single-width value into `[0, MODULUS)`.
        pub fn reduce_single(v: $T) -> $T {
            let mut v: $D = udouble { hi: 0, lo: v };
            while v.hi > 0 || (v.lo >> P1) > 0 {
                let lo = v.lo & Self::BITMASK;
                let hi = v >> P1;
                // hi * 2^P1 + lo  ->  hi * (2^P2 - K) + lo
                let mut sum = (hi << (P2 as u32)) + lo;
                if K > 0 {
                    sum -= hi * (K as umax);
                } else if K < 0 {
                    sum += hi * ((-K) as umax);
                }
                v = sum;
            }
            let v = v.lo;
            if v >= Self::MODULUS {
                v - Self::MODULUS
            } else {
                v
            }
        }
    };

    (@reduce_double, primitive, $T:ty, $D:ty) => {
        /// Reduces a double-width value (typically a product of two residues) into
        /// `[0, MODULUS)`.
        ///
        /// Up to `FOLDS` folds are unrolled for the common case; any high part that is
        /// still non-zero afterwards is folded in a loop, so every input is reduced
        /// correctly.
        pub fn reduce_double(v: $D) -> $T {
            let mut lo = (v as $T) & Self::BITMASK;
            let mut hi = v >> P1;
            // One fold: hi * 2^P1 + lo  ->  hi * (2^P2 - K) + lo, split again at bit P1.
            // Folding with `hi == 0` leaves `lo` unchanged.
            macro_rules! solinas_fold {
                () => {
                    let mut sum: $D = (hi << (P2 as u32)) + (lo as $D);
                    if K > 0 {
                        sum -= hi * (K as $D);
                    } else if K < 0 {
                        sum += hi * ((-K) as $D);
                    }
                    lo = (sum as $T) & Self::BITMASK;
                    hi = sum >> P1;
                };
            }
            if Self::FOLDS <= 4 {
                solinas_fold!();
                solinas_fold!();
                solinas_fold!();
                if Self::FOLDS == 4 {
                    solinas_fold!();
                }
            }
            // Finish whatever the unrolled folds left over (nothing for in-range products).
            while hi > 0 {
                solinas_fold!();
            }
            if lo >= Self::MODULUS {
                lo - Self::MODULUS
            } else {
                lo
            }
        }
    };

    (@reduce_double, udouble, $T:ty, $D:ty) => {
        /// Reduces a double-width value (typically a product of two residues) into
        /// `[0, MODULUS)`.
        ///
        /// Up to `FOLDS` folds are unrolled for the common case; any high part that is
        /// still non-zero afterwards is folded in a loop, so every input is reduced
        /// correctly.
        pub fn reduce_double(v: $D) -> $T {
            let mut lo = v.lo & Self::BITMASK;
            let mut hi = v >> P1;
            // One fold: hi * 2^P1 + lo  ->  hi * (2^P2 - K) + lo, split again at bit P1.
            // Folding with `hi == 0` leaves `lo` unchanged.
            macro_rules! udouble_fold {
                () => {
                    let mut sum = (hi << (P2 as u32)) + lo;
                    if K > 0 {
                        sum -= hi * (K as umax);
                    } else if K < 0 {
                        sum += hi * ((-K) as umax);
                    }
                    lo = sum.lo & Self::BITMASK;
                    hi = sum >> P1;
                };
            }
            if Self::FOLDS <= 4 {
                udouble_fold!();
                udouble_fold!();
                udouble_fold!();
                if Self::FOLDS == 4 {
                    udouble_fold!();
                }
            }
            // Finish whatever the unrolled folds left over (nothing for in-range products).
            while hi.hi > 0 || hi.lo > 0 {
                udouble_fold!();
            }
            if lo >= Self::MODULUS {
                lo - Self::MODULUS
            } else {
                lo
            }
        }
    };

    // Full-width product of two referenced residues, built-in wide integer.
    (@widen_mul, primitive, $T:ty, $D:ty, $lhs:expr, $rhs:expr) => {
        (*$lhs as $D) * (*$rhs as $D)
    };

    // Full-width product of two referenced residues, `udouble` representation.
    (@widen_mul, udouble, $T:ty, $D:ty, $lhs:expr, $rhs:expr) => {
        <$D>::widening_mul(*$lhs, *$rhs)
    };

    // Full-width square, built-in wide integer.
    (@widen_sqr, primitive, $T:ty, $D:ty, $target:expr) => {
        ($target as $D) * ($target as $D)
    };

    // Full-width square, `udouble` representation.
    (@widen_sqr, udouble, $T:ty, $D:ty, $target:expr) => {
        <$D>::widening_square($target)
    };
}
331
/// Zero-sized reducer for the fixed modulus `2^P1 - 2^P2 + K` on `u32` residues.
///
/// The modulus is determined entirely by the const parameters, so values of this type
/// carry no data; any two values with the same parameters are equal.
///
/// - `P1`: exponent of the leading power of two
/// - `P2`: exponent of the subtracted power of two (must be smaller than `P1`)
/// - `K`: signed constant term
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct FixedTrinomialSolinas32<const P1: u8, const P2: u8, const K: i32>();
353
// `u32` residues with an `i32` constant term: products are computed directly in `u32`
// while `P1 < 16` and widened to `u64` otherwise; `P1` may be at most 31.
impl_fixed_trinomial_solinas!(FixedTrinomialSolinas32, u32, i32, u64, 16, 31, primitive);
355
/// Zero-sized reducer for the fixed modulus `2^P1 - 2^P2 + K` on `u64` residues.
///
/// The modulus is determined entirely by the const parameters, so values of this type
/// carry no data; any two values with the same parameters are equal.
///
/// - `P1`: exponent of the leading power of two
/// - `P2`: exponent of the subtracted power of two (must be smaller than `P1`)
/// - `K`: signed constant term
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct FixedTrinomialSolinas64<const P1: u8, const P2: u8, const K: i64>();
378
// `u64` residues with an `i64` constant term: products are computed directly in `u64`
// while `P1 < 32` and widened to `u128` otherwise; `P1` may be as large as 64.
impl_fixed_trinomial_solinas!(FixedTrinomialSolinas64, u64, i64, u128, 32, 64, primitive);
380
381#[derive(Debug, Clone, Copy)]
403pub struct FixedTrinomialSolinas<const P1: u8, const P2: u8, const K: imax>();
404
// `umax` residues with an `imax` constant term: products are computed directly in `umax`
// while `P1 < 64` and widened to `udouble` otherwise; `P1` may be at most 127.
impl_fixed_trinomial_solinas!(FixedTrinomialSolinas, umax, imax, udouble, 64, 127, udouble);
406
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ModularCoreOps, ModularPow};
    use rand::random;

    // Reducers on `umax` residues.
    type S1 = FixedTrinomialSolinas<31, 13, 1>;
    type S2 = FixedTrinomialSolinas<61, 30, 1>;
    type S3 = FixedTrinomialSolinas<127, 64, 1>;
    type S4 = FixedTrinomialSolinas<32, 16, 1>;
    type S5 = FixedTrinomialSolinas<56, 28, -1>;
    type S6 = FixedTrinomialSolinas<122, 61, -3>;

    // Reducers on `u64` residues; `S64_4` uses the full width of the type.
    type S64_1 = FixedTrinomialSolinas64<31, 13, 1>;
    type S64_2 = FixedTrinomialSolinas64<61, 30, 1>;
    type S64_3 = FixedTrinomialSolinas64<32, 16, 1>;
    type S64_4 = FixedTrinomialSolinas64<64, 32, 1>;

    // Reducers on `u32` residues with tiny moduli (13, 23 and 61).
    type S32_1 = FixedTrinomialSolinas32<4, 2, 1>;
    type S32_2 = FixedTrinomialSolinas32<5, 3, -1>;
    type S32_3 = FixedTrinomialSolinas32<6, 2, 1>;

    /// Number of random samples drawn per randomized test.
    const NRANDOM: u32 = 10;

    /// For every listed reducer type over residue type `$T`, checks that `transform`
    /// followed by `residue` agrees with the built-in `%` operator on `$a`.
    macro_rules! assert_transform_matches_rem {
        ($T:ty, $a:ident; $($M:ty)*) => ($({
            const P: $T = <$M>::MODULUS;
            let r = <$M>::new(&P);
            assert_eq!(r.residue(r.transform($a)), $a % P);
        })*);
    }

    /// For every listed reducer type over residue type `$T`, checks each reducer
    /// operation against the corresponding generic modular operation on the raw inputs.
    macro_rules! assert_matches_modops {
        ($T:ty, $a:ident, $b:ident, $e:ident; $($M:ty)*) => ($({
            const P: $T = <$M>::MODULUS;
            let r = <$M>::new(&P);
            let am = r.transform($a);
            let bm = r.transform($b);
            assert_eq!(r.add(&am, &bm), $a.addm($b, &P));
            assert_eq!(r.sub(&am, &bm), $a.subm($b, &P));
            assert_eq!(r.mul(&am, &bm), $a.mulm($b, &P));
            assert_eq!(r.neg(am), $a.negm(&P));
            assert_eq!(r.inv(am), $a.invm(&P));
            assert_eq!(r.dbl(am), $a.dblm(&P));
            assert_eq!(r.sqr(am), $a.sqm(&P));
            assert_eq!(r.pow(am, &$e), $a.powm($e, &P));
        })*);
    }

    #[test]
    fn creation_test_u128() {
        // Fixed values around the modulus boundary.
        const P: umax = <S1>::MODULUS;
        let m = S1::new(&P);
        assert_eq!(m.residue(m.transform(0)), 0);
        assert_eq!(m.residue(m.transform(1)), 1);
        assert_eq!(m.residue(m.transform(P)), 0);
        assert_eq!(m.residue(m.transform(P - 1)), P - 1);
        assert_eq!(m.residue(m.transform(P + 1)), 1);

        for _ in 0..NRANDOM {
            let a = random::<umax>();
            assert_transform_matches_rem!(umax, a; S1 S2 S3 S4 S5 S6);
        }
    }

    #[test]
    fn creation_test_u64() {
        for _ in 0..NRANDOM {
            let a = random::<u64>();
            assert_transform_matches_rem!(u64, a; S64_1 S64_2 S64_3 S64_4);
        }
    }

    #[test]
    fn creation_test_u32() {
        for _ in 0..NRANDOM {
            let a = random::<u32>();
            assert_transform_matches_rem!(u32, a; S32_1 S32_2 S32_3);
        }
    }

    #[test]
    fn test_against_modops_u128() {
        for _ in 0..NRANDOM {
            let a = random::<umax>();
            let b = random::<umax>();
            let e = random::<u8>() as umax;
            assert_matches_modops!(umax, a, b, e; S1 S2 S3 S4 S5 S6);
        }
    }

    #[test]
    fn test_against_modops_u64() {
        for _ in 0..NRANDOM {
            let a = random::<u64>();
            let b = random::<u64>();
            let e = random::<u8>() as u64;
            assert_matches_modops!(u64, a, b, e; S64_1 S64_2 S64_3 S64_4);
        }
    }

    #[test]
    fn test_against_modops_u32() {
        for _ in 0..NRANDOM {
            let a = random::<u32>();
            let b = random::<u32>();
            let e = random::<u8>() as u32;
            assert_matches_modops!(u32, a, b, e; S32_1 S32_2 S32_3);
        }
    }

    #[test]
    fn test_add_near_overflow_u64() {
        // With a full-width modulus the sum of two residues overflows `u64`, so `add`
        // and `dbl` must handle the carry.
        type S = FixedTrinomialSolinas64<64, 32, 1>;
        const P: u64 = <S>::MODULUS;
        assert_eq!(P, 0xFFFFFFFF00000001);
        let r = S::new(&P);
        let a = r.transform(P - 1);
        let b = r.transform(P - 2);
        assert_eq!(r.residue(r.add(&a, &b)), P - 3);
        let c = r.transform(P - 1);
        assert_eq!(r.residue(r.dbl(c)), P - 2);
    }
}