1#![warn(missing_docs, unused_imports)]
2
3pub mod primes;
6
7use std::ops::Deref;
8
9use crate::errors::{Error, Result};
10use fhe_util::{is_prime, transcode_from_bytes, transcode_to_bytes};
11use itertools::{izip, Itertools};
12use num_bigint::BigUint;
13use num_traits::cast::ToPrimitive;
14use pulp::Arch;
15use rand::{distr::Uniform, CryptoRng, Rng, RngCore};
16
/// Returns `on_true` if `cond` is true and `on_false` otherwise.
///
/// The choice is made with a bit mask rather than an `if`, with the intent of
/// avoiding a branch that depends on `cond`.
const fn const_time_cond_select(on_true: u64, on_false: u64, cond: bool) -> u64 {
    // `cond as u64` is 0 or 1; its wrapping negation is 0 or u64::MAX.
    let select_mask = (cond as u64).wrapping_neg();
    // With an all-ones mask the XORs cancel `on_false`, leaving `on_true`.
    on_false ^ (select_mask & (on_true ^ on_false))
}
23
24#[derive(Debug, Clone)]
26pub struct Modulus {
27 pub(crate) p: u64,
28 barrett_hi: u64,
29 barrett_lo: u64,
30 leading_zeros: u32,
31 pub(crate) supports_opt: bool,
32 distribution: Uniform<u64>,
33 arch: Arch,
34}
35
36impl Eq for Modulus {}
38
39impl PartialEq for Modulus {
40 fn eq(&self, other: &Self) -> bool {
41 self.p == other.p
42 }
43}
44
45impl Deref for Modulus {
47 type Target = u64;
48
49 fn deref(&self) -> &Self::Target {
50 &self.p
51 }
52}
53
54impl Modulus {
55 pub fn new(p: u64) -> Result<Self> {
57 if p < 2 || (p >> 62) != 0 {
58 Err(Error::InvalidModulus(p))
59 } else {
60 let barrett = ((BigUint::from(1u64) << 128usize) / p).to_u128().unwrap(); Ok(Self {
62 p,
63 barrett_hi: (barrett >> 64) as u64,
64 barrett_lo: barrett as u64,
65 leading_zeros: p.leading_zeros(),
66 supports_opt: primes::supports_opt(p),
67 distribution: Uniform::new(0, p).unwrap(),
68 arch: Arch::new(),
69 })
70 }
71 }
72
73 pub const fn add(&self, a: u64, b: u64) -> u64 {
76 debug_assert!(a < self.p && b < self.p);
77 Self::reduce1(a + b, self.p)
78 }
79
80 pub const unsafe fn add_vt(&self, a: u64, b: u64) -> u64 {
87 debug_assert!(a < self.p && b < self.p);
88 Self::reduce1_vt(a + b, self.p)
89 }
90
91 pub const fn sub(&self, a: u64, b: u64) -> u64 {
94 debug_assert!(a < self.p && b < self.p);
95 Self::reduce1(a + self.p - b, self.p)
96 }
97
98 const unsafe fn sub_vt(&self, a: u64, b: u64) -> u64 {
105 debug_assert!(a < self.p && b < self.p);
106 Self::reduce1_vt(a + self.p - b, self.p)
107 }
108
109 pub const fn mul(&self, a: u64, b: u64) -> u64 {
112 debug_assert!(a < self.p && b < self.p);
113 self.reduce_u128((a as u128) * (b as u128))
114 }
115
116 const unsafe fn mul_vt(&self, a: u64, b: u64) -> u64 {
123 debug_assert!(a < self.p && b < self.p);
124 Self::reduce1_vt(self.lazy_reduce_u128((a as u128) * (b as u128)), self.p)
125 }
126
127 pub const fn mul_opt(&self, a: u64, b: u64) -> u64 {
131 debug_assert!(self.supports_opt);
132 debug_assert!(a < self.p && b < self.p);
133
134 self.reduce_opt_u128((a as u128) * (b as u128))
135 }
136
137 const unsafe fn mul_opt_vt(&self, a: u64, b: u64) -> u64 {
144 debug_assert!(self.supports_opt);
145 debug_assert!(a < self.p && b < self.p);
146
147 self.reduce_opt_u128_vt((a as u128) * (b as u128))
148 }
149
150 pub const fn neg(&self, a: u64) -> u64 {
154 debug_assert!(a < self.p);
155 Self::reduce1(self.p - a, self.p)
156 }
157
158 const unsafe fn neg_vt(&self, a: u64) -> u64 {
165 debug_assert!(a < self.p);
166 Self::reduce1_vt(self.p - a, self.p)
167 }
168
169 pub const fn shoup(&self, a: u64) -> u64 {
173 debug_assert!(a < self.p);
174
175 (((a as u128) << 64) / (self.p as u128)) as u64
176 }
177
178 pub const fn mul_shoup(&self, a: u64, b: u64, b_shoup: u64) -> u64 {
182 Self::reduce1(self.lazy_mul_shoup(a, b, b_shoup), self.p)
183 }
184
185 const unsafe fn mul_shoup_vt(&self, a: u64, b: u64, b_shoup: u64) -> u64 {
192 Self::reduce1_vt(self.lazy_mul_shoup(a, b, b_shoup), self.p)
193 }
194
195 pub const fn lazy_mul_shoup(&self, a: u64, b: u64, b_shoup: u64) -> u64 {
200 debug_assert!(b < self.p);
201 debug_assert!(b_shoup == self.shoup(b));
202
203 let q = ((a as u128) * (b_shoup as u128)) >> 64;
204 let r = ((a as u128) * (b as u128) - q * (self.p as u128)) as u64;
205
206 debug_assert!(r < 2 * self.p);
207
208 r
209 }
210
211 pub fn add_vec(&self, a: &mut [u64], b: &[u64]) {
216 debug_assert_eq!(a.len(), b.len());
217 self.arch.dispatch(|| {
218 izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.add(*ai, *bi))
219 })
220 }
221
222 pub unsafe fn add_vec_vt(&self, a: &mut [u64], b: &[u64]) {
230 let n = a.len();
231 debug_assert_eq!(n, b.len());
232
233 let p = self.p;
234 macro_rules! add_at {
235 ($idx:expr) => {
236 *a.get_unchecked_mut($idx) =
237 Self::reduce1_vt(*a.get_unchecked_mut($idx) + *b.get_unchecked($idx), p);
238 };
239 }
240
241 if n % 16 == 0 {
242 self.arch.dispatch(|| {
243 for i in 0..n / 16 {
244 add_at!(16 * i);
245 add_at!(16 * i + 1);
246 add_at!(16 * i + 2);
247 add_at!(16 * i + 3);
248 add_at!(16 * i + 4);
249 add_at!(16 * i + 5);
250 add_at!(16 * i + 6);
251 add_at!(16 * i + 7);
252 add_at!(16 * i + 8);
253 add_at!(16 * i + 9);
254 add_at!(16 * i + 10);
255 add_at!(16 * i + 11);
256 add_at!(16 * i + 12);
257 add_at!(16 * i + 13);
258 add_at!(16 * i + 14);
259 add_at!(16 * i + 15);
260 }
261 })
262 } else {
263 self.arch.dispatch(|| {
264 izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.add_vt(*ai, *bi))
265 })
266 }
267 }
268
269 pub fn sub_vec(&self, a: &mut [u64], b: &[u64]) {
274 debug_assert_eq!(a.len(), b.len());
275 self.arch.dispatch(|| {
276 izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.sub(*ai, *bi))
277 })
278 }
279
280 pub unsafe fn sub_vec_vt(&self, a: &mut [u64], b: &[u64]) {
288 let n = a.len();
289 debug_assert_eq!(n, b.len());
290
291 let p = self.p;
292 macro_rules! sub_at {
293 ($idx:expr) => {
294 *a.get_unchecked_mut($idx) =
295 Self::reduce1_vt(p + *a.get_unchecked_mut($idx) - *b.get_unchecked($idx), p);
296 };
297 }
298
299 if n % 16 == 0 {
300 self.arch.dispatch(|| {
301 for i in 0..n / 16 {
302 sub_at!(16 * i);
303 sub_at!(16 * i + 1);
304 sub_at!(16 * i + 2);
305 sub_at!(16 * i + 3);
306 sub_at!(16 * i + 4);
307 sub_at!(16 * i + 5);
308 sub_at!(16 * i + 6);
309 sub_at!(16 * i + 7);
310 sub_at!(16 * i + 8);
311 sub_at!(16 * i + 9);
312 sub_at!(16 * i + 10);
313 sub_at!(16 * i + 11);
314 sub_at!(16 * i + 12);
315 sub_at!(16 * i + 13);
316 sub_at!(16 * i + 14);
317 sub_at!(16 * i + 15);
318 }
319 })
320 } else {
321 self.arch.dispatch(|| {
322 izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.sub_vt(*ai, *bi))
323 })
324 }
325 }
326
327 pub fn mul_vec(&self, a: &mut [u64], b: &[u64]) {
332 debug_assert_eq!(a.len(), b.len());
333
334 if self.supports_opt {
335 self.arch.dispatch(|| {
336 izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.mul_opt(*ai, *bi))
337 })
338 } else {
339 self.arch.dispatch(|| {
340 izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.mul(*ai, *bi))
341 })
342 }
343 }
344
345 pub fn scalar_mul_vec(&self, a: &mut [u64], b: u64) {
349 let b_shoup = self.shoup(b);
350 self.arch.dispatch(|| {
351 a.iter_mut()
352 .for_each(|ai| *ai = self.mul_shoup(*ai, b, b_shoup))
353 })
354 }
355
356 pub unsafe fn scalar_mul_vec_vt(&self, a: &mut [u64], b: u64) {
363 let b_shoup = self.shoup(b);
364 self.arch.dispatch(|| {
365 a.iter_mut()
366 .for_each(|ai| *ai = self.mul_shoup_vt(*ai, b, b_shoup))
367 })
368 }
369
370 pub unsafe fn mul_vec_vt(&self, a: &mut [u64], b: &[u64]) {
378 debug_assert_eq!(a.len(), b.len());
379
380 if self.supports_opt {
381 self.arch.dispatch(|| {
382 izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.mul_opt_vt(*ai, *bi))
383 })
384 } else {
385 self.arch.dispatch(|| {
386 izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.mul_vt(*ai, *bi))
387 })
388 }
389 }
390
391 pub fn shoup_vec(&self, a: &[u64]) -> Vec<u64> {
395 self.arch
396 .dispatch(|| a.iter().map(|ai| self.shoup(*ai)).collect_vec())
397 }
398
399 pub fn mul_shoup_vec(&self, a: &mut [u64], b: &[u64], b_shoup: &[u64]) {
404 debug_assert_eq!(a.len(), b.len());
405 debug_assert_eq!(a.len(), b_shoup.len());
406 debug_assert_eq!(&b_shoup, &self.shoup_vec(b));
407
408 self.arch.dispatch(|| {
409 izip!(a.iter_mut(), b.iter(), b_shoup.iter())
410 .for_each(|(ai, bi, bi_shoup)| *ai = self.mul_shoup(*ai, *bi, *bi_shoup))
411 })
412 }
413
414 pub unsafe fn mul_shoup_vec_vt(&self, a: &mut [u64], b: &[u64], b_shoup: &[u64]) {
422 debug_assert_eq!(a.len(), b.len());
423 debug_assert_eq!(a.len(), b_shoup.len());
424 debug_assert_eq!(&b_shoup, &self.shoup_vec(b));
425
426 self.arch.dispatch(|| {
427 izip!(a.iter_mut(), b.iter(), b_shoup.iter())
428 .for_each(|(ai, bi, bi_shoup)| *ai = self.mul_shoup_vt(*ai, *bi, *bi_shoup))
429 })
430 }
431
432 pub fn reduce_vec(&self, a: &mut [u64]) {
434 self.arch
435 .dispatch(|| a.iter_mut().for_each(|ai| *ai = self.reduce(*ai)))
436 }
437
438 const unsafe fn center_vt(&self, a: u64) -> i64 {
445 debug_assert!(a < self.p);
446
447 if a >= self.p >> 1 {
448 (a as i64) - (self.p as i64)
449 } else {
450 a as i64
451 }
452 }
453
454 pub unsafe fn center_vec_vt(&self, a: &[u64]) -> Vec<i64> {
460 self.arch
461 .dispatch(|| a.iter().map(|ai| self.center_vt(*ai)).collect_vec())
462 }
463
464 pub unsafe fn reduce_vec_vt(&self, a: &mut [u64]) {
470 self.arch
471 .dispatch(|| a.iter_mut().for_each(|ai| *ai = self.reduce_vt(*ai)))
472 }
473
474 const fn reduce_i64(&self, a: i64) -> u64 {
476 self.reduce_u128((((self.p as i128) << 64) + (a as i128)) as u128)
477 }
478
479 const unsafe fn reduce_i64_vt(&self, a: i64) -> u64 {
485 self.reduce_u128_vt((((self.p as i128) << 64) + (a as i128)) as u128)
486 }
487
488 pub fn reduce_vec_i64(&self, a: &[i64]) -> Vec<u64> {
490 self.arch
491 .dispatch(|| a.iter().map(|ai| self.reduce_i64(*ai)).collect_vec())
492 }
493
494 pub unsafe fn reduce_vec_i64_vt(&self, a: &[i64]) -> Vec<u64> {
500 self.arch
501 .dispatch(|| a.iter().map(|ai| self.reduce_i64_vt(*ai)).collect())
502 }
503
504 pub fn reduce_vec_new(&self, a: &[u64]) -> Vec<u64> {
506 self.arch
507 .dispatch(|| a.iter().map(|ai| self.reduce(*ai)).collect())
508 }
509
510 pub unsafe fn reduce_vec_new_vt(&self, a: &[u64]) -> Vec<u64> {
516 self.arch
517 .dispatch(|| a.iter().map(|bi| self.reduce_vt(*bi)).collect())
518 }
519
520 pub fn neg_vec(&self, a: &mut [u64]) {
524 self.arch
525 .dispatch(|| a.iter_mut().for_each(|ai| *ai = self.neg(*ai)))
526 }
527
528 pub unsafe fn neg_vec_vt(&self, a: &mut [u64]) {
535 self.arch
536 .dispatch(|| a.iter_mut().for_each(|ai| *ai = self.neg_vt(*ai)))
537 }
538
539 pub fn pow(&self, a: u64, n: u64) -> u64 {
543 debug_assert!(a < self.p && n < self.p);
544
545 if n == 0 {
546 1
547 } else if n == 1 {
548 a
549 } else {
550 let mut r = a;
551 let mut i = (62 - n.leading_zeros()) as isize;
552 while i >= 0 {
553 r = self.mul(r, r);
554 if (n >> i) & 1 == 1 {
555 r = self.mul(r, a);
556 }
557 i -= 1;
558 }
559 r
560 }
561 }
562
563 pub fn inv(&self, a: u64) -> std::option::Option<u64> {
568 if !is_prime(self.p) || a == 0 {
569 None
570 } else {
571 let r = self.pow(a, self.p - 2);
572 debug_assert_eq!(self.mul(a, r), 1);
573 Some(r)
574 }
575 }
576
577 pub const fn reduce_u128(&self, a: u128) -> u64 {
579 Self::reduce1(self.lazy_reduce_u128(a), self.p)
580 }
581
582 pub const unsafe fn reduce_u128_vt(&self, a: u128) -> u64 {
588 Self::reduce1_vt(self.lazy_reduce_u128(a), self.p)
589 }
590
591 pub const fn reduce(&self, a: u64) -> u64 {
593 Self::reduce1(self.lazy_reduce(a), self.p)
594 }
595
596 pub const unsafe fn reduce_vt(&self, a: u64) -> u64 {
602 Self::reduce1_vt(self.lazy_reduce(a), self.p)
603 }
604
605 pub const fn reduce_opt_u128(&self, a: u128) -> u64 {
607 debug_assert!(self.supports_opt);
608 Self::reduce1(self.lazy_reduce_opt_u128(a), self.p)
609 }
610
611 pub(crate) const unsafe fn reduce_opt_u128_vt(&self, a: u128) -> u64 {
617 debug_assert!(self.supports_opt);
618 Self::reduce1_vt(self.lazy_reduce_opt_u128(a), self.p)
619 }
620
621 pub const fn reduce_opt(&self, a: u64) -> u64 {
623 Self::reduce1(self.lazy_reduce_opt(a), self.p)
624 }
625
626 pub const unsafe fn reduce_opt_vt(&self, a: u64) -> u64 {
632 Self::reduce1_vt(self.lazy_reduce_opt(a), self.p)
633 }
634
635 pub(crate) const fn reduce1(x: u64, p: u64) -> u64 {
638 debug_assert!(p >> 63 == 0);
639 debug_assert!(x < 2 * p);
640
641 let r = const_time_cond_select(x, x.wrapping_sub(p), x < p);
642
643 debug_assert!(r == x % p);
644
645 r
646 }
647
648 #[cfg(any(target_os = "macos", target_feature = "avx2"))]
655 pub(crate) const unsafe fn reduce1_vt(x: u64, p: u64) -> u64 {
656 debug_assert!(p >> 63 == 0);
657 debug_assert!(x < 2 * p);
658
659 if x >= p {
660 x - p
661 } else {
662 x
663 }
664 }
665
666 #[cfg(all(not(target_os = "macos"), not(target_feature = "avx2")))]
667 #[inline]
668 pub(crate) const unsafe fn reduce1_vt(x: u64, p: u64) -> u64 {
669 Self::reduce1(x, p)
670 }
671
672 pub const fn lazy_reduce_u128(&self, a: u128) -> u64 {
675 let a_lo = a as u64;
676 let a_hi = (a >> 64) as u64;
677 let p_lo_lo = ((a_lo as u128) * (self.barrett_lo as u128)) >> 64;
678 let p_hi_lo = (a_hi as u128) * (self.barrett_lo as u128);
679 let p_lo_hi = (a_lo as u128) * (self.barrett_hi as u128);
680
681 let q = ((p_lo_hi + p_hi_lo + p_lo_lo) >> 64) + (a_hi as u128) * (self.barrett_hi as u128);
682 let r = (a - q * (self.p as u128)) as u64;
683
684 debug_assert!((r as u128) < 2 * (self.p as u128));
685 debug_assert!(r % self.p == (a % (self.p as u128)) as u64);
686
687 r
688 }
689
690 pub const fn lazy_reduce(&self, a: u64) -> u64 {
693 let p_lo_lo = ((a as u128) * (self.barrett_lo as u128)) >> 64;
694 let p_lo_hi = (a as u128) * (self.barrett_hi as u128);
695
696 let q = (p_lo_hi + p_lo_lo) >> 64;
697 let r = (a as u128 - q * (self.p as u128)) as u64;
698
699 debug_assert!((r as u128) < 2 * (self.p as u128));
700 debug_assert!(r % self.p == a % self.p);
701
702 r
703 }
704
705 pub const fn lazy_reduce_opt_u128(&self, a: u128) -> u64 {
710 debug_assert!(a < (self.p as u128) * (self.p as u128));
711
712 let q = (((self.barrett_lo as u128) * (a >> 64)) + (a << self.leading_zeros)) >> 64;
713 let r = (a - q * (self.p as u128)) as u64;
714
715 debug_assert!((r as u128) < 2 * (self.p as u128));
716 debug_assert!(r % self.p == (a % (self.p as u128)) as u64);
717
718 r
719 }
720
721 const fn lazy_reduce_opt(&self, a: u64) -> u64 {
724 let q = a >> (64 - self.leading_zeros);
725 let r = ((a as u128) - (q as u128) * (self.p as u128)) as u64;
726
727 debug_assert!((r as u128) < 2 * (self.p as u128));
728 debug_assert!(r % self.p == a % self.p);
729
730 r
731 }
732
733 pub fn lazy_reduce_vec(&self, a: &mut [u64]) {
736 if self.supports_opt {
737 a.iter_mut().for_each(|ai| *ai = self.lazy_reduce_opt(*ai))
738 } else {
739 a.iter_mut().for_each(|ai| *ai = self.lazy_reduce(*ai))
740 }
741 }
742
743 pub fn random_vec<R: RngCore + CryptoRng>(&self, size: usize, rng: &mut R) -> Vec<u64> {
745 rng.sample_iter(self.distribution).take(size).collect_vec()
746 }
747
748 pub const fn serialization_length(&self, size: usize) -> usize {
752 assert!(size % 8 == 0);
753 let p_nbits = 64 - (self.p - 1).leading_zeros() as usize;
754 p_nbits * size / 8
755 }
756
757 pub fn serialize_vec(&self, a: &[u64]) -> Vec<u8> {
761 let p_nbits = 64 - (self.p - 1).leading_zeros() as usize;
762 transcode_to_bytes(a, p_nbits)
763 }
764
765 pub fn deserialize_vec(&self, b: &[u8]) -> Vec<u64> {
767 let p_nbits = 64 - (self.p - 1).leading_zeros() as usize;
768 transcode_from_bytes(b, p_nbits)
769 }
770}
771
#[cfg(test)]
mod tests {
    use super::{primes, Modulus};
    use itertools::{izip, Itertools};
    use proptest::collection::vec as prop_vec;
    use proptest::prelude::{any, BoxedStrategy, Just, Strategy};
    use rand::{rng, RngCore};

    /// Strategy that yields only values accepted by `Modulus::new`.
    fn valid_moduli() -> impl Strategy<Value = Modulus> {
        any::<u64>().prop_filter_map("filter invalid moduli", |p| Modulus::new(p).ok())
    }

    /// Strategy that yields two vectors of the same (non-zero) length.
    fn vecs() -> BoxedStrategy<(Vec<u64>, Vec<u64>)> {
        prop_vec(any::<u64>(), 1..100)
            .prop_flat_map(|vec| {
                let len = vec.len();
                (Just(vec), prop_vec(any::<u64>(), len))
            })
            .boxed()
    }

    proptest! {
        #[test]
        fn constructor(p: u64) {
            // Moduli with bit 62 or 63 set are rejected.
            prop_assert!(Modulus::new(p | (1u64 << 62)).is_err());
            prop_assert!(Modulus::new(p | (1u64 << 63)).is_err());

            // 0 and 1 are rejected.
            prop_assert!(Modulus::new(0u64).is_err());
            prop_assert!(Modulus::new(1u64).is_err());

            // Any value in [2, 2^62) is accepted.
            prop_assume!(p >> 2 >= 2);
            let q = Modulus::new(p >> 2);
            prop_assert!(q.is_ok());
            prop_assert_eq!(*q.unwrap(), p >> 2);
        }

        #[test]
        fn neg(p in valid_moduli(), mut a: u64) {
            a = p.reduce(a);
            prop_assert_eq!(p.neg(a), (*p - a) % *p);
            unsafe { prop_assert_eq!(p.neg_vt(a), (*p - a) % *p) }

            #[cfg(debug_assertions)]
            {
                prop_assert!(std::panic::catch_unwind(|| p.neg(*p)).is_err());
                prop_assert!(std::panic::catch_unwind(|| p.neg(*p + 1)).is_err());
            }
        }

        #[test]
        fn add(p in valid_moduli(), mut a: u64, mut b: u64) {
            a = p.reduce(a);
            b = p.reduce(b);
            prop_assert_eq!(p.add(a, b), (a + b) % *p);
            unsafe { prop_assert_eq!(p.add_vt(a, b), (a + b) % *p) }

            #[cfg(debug_assertions)]
            {
                prop_assert!(std::panic::catch_unwind(|| p.add(*p, a)).is_err());
                prop_assert!(std::panic::catch_unwind(|| p.add(a, *p)).is_err());
                prop_assert!(std::panic::catch_unwind(|| p.add(*p + 1, a)).is_err());
                prop_assert!(std::panic::catch_unwind(|| p.add(a, *p + 1)).is_err());
            }
        }

        #[test]
        fn sub(p in valid_moduli(), mut a: u64, mut b: u64) {
            a = p.reduce(a);
            b = p.reduce(b);
            prop_assert_eq!(p.sub(a, b), (a + *p - b) % *p);
            unsafe { prop_assert_eq!(p.sub_vt(a, b), (a + *p - b) % *p) }

            #[cfg(debug_assertions)]
            {
                prop_assert!(std::panic::catch_unwind(|| p.sub(*p, a)).is_err());
                prop_assert!(std::panic::catch_unwind(|| p.sub(a, *p)).is_err());
                prop_assert!(std::panic::catch_unwind(|| p.sub(*p + 1, a)).is_err());
                prop_assert!(std::panic::catch_unwind(|| p.sub(a, *p + 1)).is_err());
            }
        }

        #[test]
        fn mul(p in valid_moduli(), mut a: u64, mut b: u64) {
            a = p.reduce(a);
            b = p.reduce(b);
            prop_assert_eq!(p.mul(a, b) as u128, ((a as u128) * (b as u128)) % (*p as u128));
            unsafe { prop_assert_eq!(p.mul_vt(a, b) as u128, ((a as u128) * (b as u128)) % (*p as u128)) }

            #[cfg(debug_assertions)]
            {
                prop_assert!(std::panic::catch_unwind(|| p.mul(*p, a)).is_err());
                prop_assert!(std::panic::catch_unwind(|| p.mul(a, *p)).is_err());
                prop_assert!(std::panic::catch_unwind(|| p.mul(*p + 1, a)).is_err());
                prop_assert!(std::panic::catch_unwind(|| p.mul(a, *p + 1)).is_err());
            }
        }

        #[test]
        fn mul_shoup(p in valid_moduli(), mut a: u64, mut b: u64) {
            a = p.reduce(a);
            b = p.reduce(b);

            let b_shoup = p.shoup(b);

            #[cfg(debug_assertions)]
            {
                prop_assert!(std::panic::catch_unwind(|| p.shoup(*p)).is_err());
                prop_assert!(std::panic::catch_unwind(|| p.shoup(*p + 1)).is_err());
            }

            prop_assert_eq!(p.mul_shoup(a, b, b_shoup) as u128, ((a as u128) * (b as u128)) % (*p as u128));
            unsafe { prop_assert_eq!(p.mul_shoup_vt(a, b, b_shoup) as u128, ((a as u128) * (b as u128)) % (*p as u128)) }

            // A Shoup value that does not match `b` must be caught in debug builds.
            #[cfg(debug_assertions)]
            {
                prop_assert!(std::panic::catch_unwind(|| p.mul_shoup(a, *p, b_shoup)).is_err());
                prop_assume!(a != b);
                prop_assert!(std::panic::catch_unwind(|| p.mul_shoup(a, a, b_shoup)).is_err());
            }
        }

        #[test]
        fn reduce(p in valid_moduli(), a: u64) {
            prop_assert_eq!(p.reduce(a), a % *p);
            unsafe { prop_assert_eq!(p.reduce_vt(a), a % *p) }
            if p.supports_opt {
                prop_assert_eq!(p.reduce_opt(a), a % *p);
                unsafe { prop_assert_eq!(p.reduce_opt_vt(a), a % *p) }
            }
        }

        #[test]
        fn lazy_reduce(p in valid_moduli(), a: u64) {
            prop_assert!(p.lazy_reduce(a) < 2 * *p);
            prop_assert_eq!(p.lazy_reduce(a) % *p, p.reduce(a));
        }

        #[test]
        fn reduce_i64(p in valid_moduli(), a: i64) {
            // `unsigned_abs` is well defined for i64::MIN, unlike `-a`.
            let b = if a < 0 { p.neg(p.reduce(a.unsigned_abs())) } else { p.reduce(a as u64) };
            prop_assert_eq!(p.reduce_i64(a), b);
            unsafe { prop_assert_eq!(p.reduce_i64_vt(a), b) }
        }

        #[test]
        fn reduce_u128(p in valid_moduli(), mut a: u128) {
            prop_assert_eq!(p.reduce_u128(a) as u128, a % (*p as u128));
            unsafe { prop_assert_eq!(p.reduce_u128_vt(a) as u128, a % (*p as u128)) }
            if p.supports_opt {
                let p_square = (*p as u128) * (*p as u128);
                a %= p_square;
                prop_assert_eq!(p.reduce_opt_u128(a) as u128, a % (*p as u128));
                unsafe { prop_assert_eq!(p.reduce_opt_u128_vt(a) as u128, a % (*p as u128)) }
            }
        }

        #[test]
        fn add_vec(p in valid_moduli(), (mut a, mut b) in vecs()) {
            p.reduce_vec(&mut a);
            p.reduce_vec(&mut b);
            let c = a.clone();
            p.add_vec(&mut a, &b);
            prop_assert_eq!(a.clone(), izip!(b.iter(), c.iter()).map(|(bi, ci)| p.add(*bi, *ci)).collect_vec());
            a.clone_from(&c);
            unsafe { p.add_vec_vt(&mut a, &b) }
            prop_assert_eq!(a, izip!(b.iter(), c.iter()).map(|(bi, ci)| p.add(*bi, *ci)).collect_vec());
        }

        #[test]
        fn sub_vec(p in valid_moduli(), (mut a, mut b) in vecs()) {
            p.reduce_vec(&mut a);
            p.reduce_vec(&mut b);
            let c = a.clone();
            p.sub_vec(&mut a, &b);
            prop_assert_eq!(a.clone(), izip!(b.iter(), c.iter()).map(|(bi, ci)| p.sub(*ci, *bi)).collect_vec());
            a.clone_from(&c);
            unsafe { p.sub_vec_vt(&mut a, &b) }
            prop_assert_eq!(a, izip!(b.iter(), c.iter()).map(|(bi, ci)| p.sub(*ci, *bi)).collect_vec());
        }

        #[test]
        fn mul_vec(p in valid_moduli(), (mut a, mut b) in vecs()) {
            p.reduce_vec(&mut a);
            p.reduce_vec(&mut b);
            let c = a.clone();
            p.mul_vec(&mut a, &b);
            prop_assert_eq!(a.clone(), izip!(b.iter(), c.iter()).map(|(bi, ci)| p.mul(*ci, *bi)).collect_vec());
            a.clone_from(&c);
            unsafe { p.mul_vec_vt(&mut a, &b); }
            prop_assert_eq!(a, izip!(b.iter(), c.iter()).map(|(bi, ci)| p.mul(*ci, *bi)).collect_vec());
        }

        #[test]
        fn scalar_mul_vec(p in valid_moduli(), mut a: Vec<u64>, mut b: u64) {
            p.reduce_vec(&mut a);
            b = p.reduce(b);
            let c = a.clone();

            p.scalar_mul_vec(&mut a, b);
            prop_assert_eq!(a.clone(), c.iter().map(|ci| p.mul(*ci, b)).collect_vec());

            a.clone_from(&c);
            unsafe { p.scalar_mul_vec_vt(&mut a, b) }
            prop_assert_eq!(a, c.iter().map(|ci| p.mul(*ci, b)).collect_vec());
        }

        #[test]
        fn mul_shoup_vec(p in valid_moduli(), (mut a, mut b) in vecs()) {
            p.reduce_vec(&mut a);
            p.reduce_vec(&mut b);
            let b_shoup = p.shoup_vec(&b);
            let c = a.clone();
            p.mul_shoup_vec(&mut a, &b, &b_shoup);
            prop_assert_eq!(a.clone(), izip!(b.iter(), c.iter()).map(|(bi, ci)| p.mul(*ci, *bi)).collect_vec());
            a.clone_from(&c);
            unsafe { p.mul_shoup_vec_vt(&mut a, &b, &b_shoup) }
            prop_assert_eq!(a, izip!(b.iter(), c.iter()).map(|(bi, ci)| p.mul(*ci, *bi)).collect_vec());
        }

        #[test]
        fn reduce_vec(p in valid_moduli(), a: Vec<u64>) {
            let mut b = a.clone();
            p.reduce_vec(&mut b);
            prop_assert_eq!(b.clone(), a.iter().map(|ai| p.reduce(*ai)).collect_vec());

            b.clone_from(&a);
            unsafe { p.reduce_vec_vt(&mut b) }
            prop_assert_eq!(b, a.iter().map(|ai| p.reduce(*ai)).collect_vec());
        }

        #[test]
        fn lazy_reduce_vec(p in valid_moduli(), a: Vec<u64>) {
            let mut b = a.clone();
            p.lazy_reduce_vec(&mut b);
            prop_assert!(b.iter().all(|bi| *bi < 2 * *p));
            prop_assert!(izip!(a, b).all(|(ai, bi)| bi % *p == ai % *p));
        }

        #[test]
        fn reduce_vec_new(p in valid_moduli(), a: Vec<u64>) {
            let b = p.reduce_vec_new(&a);
            prop_assert_eq!(b, a.iter().map(|ai| p.reduce(*ai)).collect_vec());
            prop_assert_eq!(p.reduce_vec_new(&a), unsafe { p.reduce_vec_new_vt(&a) });
        }

        #[test]
        fn reduce_vec_i64(p in valid_moduli(), a: Vec<i64>) {
            let b = p.reduce_vec_i64(&a);
            prop_assert_eq!(b, a.iter().map(|ai| p.reduce_i64(*ai)).collect_vec());
            let b = unsafe { p.reduce_vec_i64_vt(&a) };
            prop_assert_eq!(b, a.iter().map(|ai| p.reduce_i64(*ai)).collect_vec());
        }

        #[test]
        fn center_vec(p in valid_moduli(), mut a: Vec<u64>) {
            p.reduce_vec(&mut a);
            let c = unsafe { p.center_vec_vt(&a) };
            prop_assert_eq!(c.len(), a.len());
            for (ai, ci) in izip!(a.iter(), c.iter()) {
                // The centered value is congruent to the input modulo p...
                prop_assert_eq!((*ci as i128).rem_euclid(*p as i128) as u64, *ai);
                // ...and at most ceil(p/2) in magnitude.
                prop_assert!(ci.unsigned_abs() <= (*p + 1) / 2);
            }
        }

        #[test]
        fn neg_vec(p in valid_moduli(), mut a: Vec<u64>) {
            p.reduce_vec(&mut a);
            let mut b = a.clone();
            p.neg_vec(&mut b);
            prop_assert_eq!(b.clone(), a.iter().map(|ai| p.neg(*ai)).collect_vec());
            b.clone_from(&a);
            unsafe { p.neg_vec_vt(&mut b); }
            prop_assert_eq!(b, a.iter().map(|ai| p.neg(*ai)).collect_vec());
        }

        #[test]
        fn random_vec(p in valid_moduli(), size in 1..1000usize) {
            let mut rng = rng();

            let v = p.random_vec(size, &mut rng);
            prop_assert_eq!(v.len(), size);

            let w = p.random_vec(size, &mut rng);
            prop_assert_eq!(w.len(), size);

            // For large moduli, two independent samples coincide with negligible probability.
            if (*p).leading_zeros() <= 30 {
                prop_assert_ne!(v, w);
            }
        }

        #[test]
        fn serialize(p in valid_moduli(), mut a in prop_vec(any::<u64>(), 8)) {
            p.reduce_vec(&mut a);
            let b = p.serialize_vec(&a);
            let c = p.deserialize_vec(&b);
            prop_assert_eq!(a, c);
        }
    }

    #[test]
    fn mul_opt() {
        let ntests = 100;
        let mut rng = rand::rng();

        #[allow(clippy::single_element_loop)]
        for p in [4611686018326724609] {
            let q = Modulus::new(p).unwrap();
            assert!(primes::supports_opt(p));

            assert_eq!(q.mul_opt(0, 1), 0);
            assert_eq!(q.mul_opt(1, 1), 1);
            assert_eq!(q.mul_opt(2 % p, 3 % p), 6 % p);
            assert_eq!(q.mul_opt(p - 1, 1), p - 1);
            assert_eq!(q.mul_opt(p - 1, 2 % p), p - 2);

            #[cfg(debug_assertions)]
            {
                assert!(std::panic::catch_unwind(|| q.mul_opt(p, 1)).is_err());
                assert!(std::panic::catch_unwind(|| q.mul_opt(p << 1, 1)).is_err());
                assert!(std::panic::catch_unwind(|| q.mul_opt(0, p)).is_err());
                assert!(std::panic::catch_unwind(|| q.mul_opt(0, p << 1)).is_err());
            }

            for _ in 0..ntests {
                let a = rng.next_u64() % p;
                let b = rng.next_u64() % p;
                assert_eq!(
                    q.mul_opt(a, b),
                    (((a as u128) * (b as u128)) % (p as u128)) as u64
                );
            }
        }
    }

    #[test]
    fn pow() {
        let ntests = 10;
        let mut rng = rand::rng();

        for p in [2u64, 3, 17, 1987, 4611686018326724609] {
            let q = Modulus::new(p).unwrap();

            assert_eq!(q.pow(p - 1, 0), 1);
            assert_eq!(q.pow(p - 1, 1), p - 1);
            assert_eq!(q.pow(p - 1, 2 % p), 1);
            assert_eq!(q.pow(1, p - 2), 1);
            assert_eq!(q.pow(1, p - 1), 1);

            #[cfg(debug_assertions)]
            {
                assert!(std::panic::catch_unwind(|| q.pow(p, 1)).is_err());
                assert!(std::panic::catch_unwind(|| q.pow(p << 1, 1)).is_err());
                assert!(std::panic::catch_unwind(|| q.pow(0, p)).is_err());
                assert!(std::panic::catch_unwind(|| q.pow(0, p << 1)).is_err());
            }

            for _ in 0..ntests {
                let a = rng.next_u64() % p;
                let b = (rng.next_u64() % p) % 1000;
                // Reference result by repeated multiplication.
                let mut c = b;
                let mut r = 1;
                while c > 0 {
                    r = q.mul(r, a);
                    c -= 1;
                }
                assert_eq!(q.pow(a, b), r);
            }
        }
    }

    #[test]
    fn inv() {
        let ntests = 100;
        let mut rng = rand::rng();

        for p in [2u64, 3, 17, 1987, 4611686018326724609] {
            let q = Modulus::new(p).unwrap();

            assert!(q.inv(0).is_none());
            assert_eq!(q.inv(1).unwrap(), 1);
            assert_eq!(q.inv(p - 1).unwrap(), p - 1);

            #[cfg(debug_assertions)]
            {
                assert!(std::panic::catch_unwind(|| q.inv(p)).is_err());
                assert!(std::panic::catch_unwind(|| q.inv(p << 1)).is_err());
            }

            for _ in 0..ntests {
                let a = rng.next_u64() % p;
                let b = q.inv(a);

                if a == 0 {
                    assert!(b.is_none())
                } else {
                    assert!(b.is_some());
                    assert_eq!(q.mul(a, b.unwrap()), 1)
                }
            }
        }
    }
}