#![warn(missing_docs, unused_imports)]
pub mod ntt;
pub mod primes;
use crate::errors::{Error, Result};
use fhe_util::{is_prime, transcode_from_bytes, transcode_to_bytes};
use itertools::{izip, Itertools};
use num_bigint::BigUint;
use num_traits::cast::ToPrimitive;
use rand::{distributions::Uniform, CryptoRng, Rng, RngCore};
/// Structure encapsulating an integer modulus `p`, with `2 <= p < 2^62`,
/// together with precomputed constants for fast modular arithmetic.
#[derive(Debug, Clone, PartialEq)]
pub struct Modulus {
    // The modulus value; `new` guarantees `2 <= p < 2^62`.
    p: u64,
    // Bit length of `p` (`64 - p.leading_zeros()`).
    nbits: usize,
    // High and low 64-bit words of `floor(2^128 / p)`, the Barrett constant.
    barrett_hi: u64,
    barrett_lo: u64,
    // Number of leading zero bits of `p`.
    leading_zeros: u32,
    // Whether `p` admits the optimized reduction (see `primes::supports_opt`).
    pub(crate) supports_opt: bool,
    // Uniform sampler over `[0, p)` used by `random_vec`.
    distribution: Uniform<u64>,
}
impl Eq for Modulus {}
impl Modulus {
pub fn new(p: u64) -> Result<Self> {
if p < 2 || (p >> 62) != 0 {
Err(Error::InvalidModulus(p))
} else {
let barrett = ((BigUint::from(1u64) << 128usize) / p).to_u128().unwrap(); Ok(Self {
p,
nbits: 64 - p.leading_zeros() as usize,
barrett_hi: (barrett >> 64) as u64,
barrett_lo: barrett as u64,
leading_zeros: p.leading_zeros(),
supports_opt: primes::supports_opt(p),
distribution: Uniform::from(0..p),
})
}
}
/// Returns the value of the modulus.
pub const fn modulus(&self) -> u64 {
    self.p
}

/// Modular addition of `a` and `b`, using the branchless `reduce1`.
///
/// In debug builds, asserts that `a < p` and `b < p`.
pub const fn add(&self, a: u64, b: u64) -> u64 {
    debug_assert!(a < self.p && b < self.p);
    Self::reduce1(a + b, self.p)
}

/// Variable-time modular addition of `a` and `b`.
///
/// # Safety
/// Not constant-time: the final reduction branches on the data, so timing
/// may leak information about the operands.
pub const unsafe fn add_vt(&self, a: u64, b: u64) -> u64 {
    debug_assert!(a < self.p && b < self.p);
    Self::reduce1_vt(a + b, self.p)
}

/// Modular subtraction: computes `a - b mod p`.
///
/// In debug builds, asserts that `a < p` and `b < p`.
pub const fn sub(&self, a: u64, b: u64) -> u64 {
    debug_assert!(a < self.p && b < self.p);
    // `a + p - b` is always non-negative and lies in [0, 2p).
    Self::reduce1(a + self.p - b, self.p)
}

/// Variable-time modular subtraction of `b` from `a`.
const unsafe fn sub_vt(&self, a: u64, b: u64) -> u64 {
    debug_assert!(a < self.p && b < self.p);
    Self::reduce1_vt(a + self.p - b, self.p)
}

/// Modular multiplication of `a` and `b`, via Barrett reduction of the
/// 128-bit product.
pub const fn mul(&self, a: u64, b: u64) -> u64 {
    debug_assert!(a < self.p && b < self.p);
    self.reduce_u128((a as u128) * (b as u128))
}

/// Variable-time modular multiplication of `a` and `b`.
const unsafe fn mul_vt(&self, a: u64, b: u64) -> u64 {
    debug_assert!(a < self.p && b < self.p);
    Self::reduce1_vt(self.lazy_reduce_u128((a as u128) * (b as u128)), self.p)
}

/// Optimized modular multiplication of `a` and `b`.
///
/// Only valid when `supports_opt` is true (see `primes::supports_opt`).
pub const fn mul_opt(&self, a: u64, b: u64) -> u64 {
    debug_assert!(self.supports_opt);
    debug_assert!(a < self.p && b < self.p);
    self.reduce_opt_u128((a as u128) * (b as u128))
}

/// Variable-time optimized modular multiplication of `a` and `b`.
const unsafe fn mul_opt_vt(&self, a: u64, b: u64) -> u64 {
    debug_assert!(self.supports_opt);
    debug_assert!(a < self.p && b < self.p);
    self.reduce_opt_u128_vt((a as u128) * (b as u128))
}

/// Modular negation: computes `p - a mod p`, so `neg(0) == 0`.
pub const fn neg(&self, a: u64) -> u64 {
    debug_assert!(a < self.p);
    Self::reduce1(self.p - a, self.p)
}

/// Variable-time modular negation of `a`.
const unsafe fn neg_vt(&self, a: u64) -> u64 {
    debug_assert!(a < self.p);
    Self::reduce1_vt(self.p - a, self.p)
}

/// Computes the Shoup representation of `a`, i.e. `floor(a * 2^64 / p)`,
/// used to accelerate multiplications by the constant `a`.
pub const fn shoup(&self, a: u64) -> u64 {
    debug_assert!(a < self.p);
    (((a as u128) << 64) / (self.p as u128)) as u64
}

/// Shoup multiplication of `a` by `b`; `b_shoup` must equal `self.shoup(b)`.
pub const fn mul_shoup(&self, a: u64, b: u64, b_shoup: u64) -> u64 {
    Self::reduce1(self.lazy_mul_shoup(a, b, b_shoup), self.p)
}

/// Variable-time Shoup multiplication of `a` by `b`.
const unsafe fn mul_shoup_vt(&self, a: u64, b: u64, b_shoup: u64) -> u64 {
    Self::reduce1_vt(self.lazy_mul_shoup(a, b, b_shoup), self.p)
}

/// Lazy Shoup multiplication of `a` by `b`: the result is only reduced
/// into the range `[0, 2p)`.
pub const fn lazy_mul_shoup(&self, a: u64, b: u64, b_shoup: u64) -> u64 {
    debug_assert!(b < self.p);
    debug_assert!(b_shoup == self.shoup(b));
    // q approximates floor(a * b / p); the true quotient is q or q + 1,
    // hence the remainder r below lies in [0, 2p).
    let q = ((a as u128) * (b_shoup as u128)) >> 64;
    let r = ((a as u128) * (b as u128) - q * (self.p as u128)) as u64;
    debug_assert!(r < 2 * self.p);
    r
}
pub fn add_vec(&self, a: &mut [u64], b: &[u64]) {
debug_assert_eq!(a.len(), b.len());
izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.add(*ai, *bi));
}
/// Variable-time modular addition of vectors in place.
///
/// # Safety
/// Uses unchecked indexing and a branching reduction: not constant-time,
/// and relies on `a` and `b` having equal lengths (debug-asserted only).
pub unsafe fn add_vec_vt(&self, a: &mut [u64], b: &[u64]) {
    let n = a.len();
    debug_assert_eq!(n, b.len());
    let p = self.p;
    // Adds b[idx] to a[idx] modulo p without bounds checks.
    macro_rules! add_at {
        ($idx:expr) => {
            *a.get_unchecked_mut($idx) =
                Self::reduce1_vt(*a.get_unchecked_mut($idx) + *b.get_unchecked($idx), p);
        };
    }
    // Manually unroll by 16 when the length allows; otherwise fall back to
    // a simple element-wise loop.
    if n % 16 == 0 {
        for i in 0..n / 16 {
            add_at!(16 * i);
            add_at!(16 * i + 1);
            add_at!(16 * i + 2);
            add_at!(16 * i + 3);
            add_at!(16 * i + 4);
            add_at!(16 * i + 5);
            add_at!(16 * i + 6);
            add_at!(16 * i + 7);
            add_at!(16 * i + 8);
            add_at!(16 * i + 9);
            add_at!(16 * i + 10);
            add_at!(16 * i + 11);
            add_at!(16 * i + 12);
            add_at!(16 * i + 13);
            add_at!(16 * i + 14);
            add_at!(16 * i + 15);
        }
    } else {
        izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.add_vt(*ai, *bi));
    }
}
pub fn sub_vec(&self, a: &mut [u64], b: &[u64]) {
debug_assert_eq!(a.len(), b.len());
izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.sub(*ai, *bi));
}
/// Variable-time modular subtraction of vectors in place.
///
/// # Safety
/// Uses unchecked indexing and a branching reduction: not constant-time,
/// and relies on `a` and `b` having equal lengths (debug-asserted only).
pub unsafe fn sub_vec_vt(&self, a: &mut [u64], b: &[u64]) {
    let n = a.len();
    debug_assert_eq!(n, b.len());
    let p = self.p;
    // Computes a[idx] - b[idx] mod p without bounds checks; the `p +`
    // keeps the intermediate value non-negative.
    macro_rules! sub_at {
        ($idx:expr) => {
            *a.get_unchecked_mut($idx) =
                Self::reduce1_vt(p + *a.get_unchecked_mut($idx) - *b.get_unchecked($idx), p);
        };
    }
    // Manually unroll by 16 when the length allows; otherwise fall back to
    // a simple element-wise loop.
    if n % 16 == 0 {
        for i in 0..n / 16 {
            sub_at!(16 * i);
            sub_at!(16 * i + 1);
            sub_at!(16 * i + 2);
            sub_at!(16 * i + 3);
            sub_at!(16 * i + 4);
            sub_at!(16 * i + 5);
            sub_at!(16 * i + 6);
            sub_at!(16 * i + 7);
            sub_at!(16 * i + 8);
            sub_at!(16 * i + 9);
            sub_at!(16 * i + 10);
            sub_at!(16 * i + 11);
            sub_at!(16 * i + 12);
            sub_at!(16 * i + 13);
            sub_at!(16 * i + 14);
            sub_at!(16 * i + 15);
        }
    } else {
        izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.sub_vt(*ai, *bi));
    }
}
pub fn mul_vec(&self, a: &mut [u64], b: &[u64]) {
debug_assert_eq!(a.len(), b.len());
if self.supports_opt {
izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.mul_opt(*ai, *bi));
} else {
izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.mul(*ai, *bi));
}
}
/// Multiplies every element of `a` by the scalar `b` modulo `p`, in place.
pub fn scalar_mul_vec(&self, a: &mut [u64], b: u64) {
    // Precompute the Shoup representation of b once for the whole vector.
    let b_shoup = self.shoup(b);
    for ai in a.iter_mut() {
        *ai = self.mul_shoup(*ai, b, b_shoup);
    }
}
/// Variable-time scalar multiplication of `a` by `b` modulo `p`, in place.
///
/// # Safety
/// Not constant-time; must not be used on secret data.
pub unsafe fn scalar_mul_vec_vt(&self, a: &mut [u64], b: u64) {
    // Precompute the Shoup representation of b once for the whole vector.
    let b_shoup = self.shoup(b);
    a.iter_mut()
        .for_each(|ai| *ai = self.mul_shoup_vt(*ai, b, b_shoup));
}
/// Variable-time modular multiplication of vectors in place.
///
/// # Safety
/// Not constant-time; must not be used on secret data.
pub unsafe fn mul_vec_vt(&self, a: &mut [u64], b: &[u64]) {
    debug_assert_eq!(a.len(), b.len());
    // Dispatch once on `supports_opt` rather than per element.
    if self.supports_opt {
        izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.mul_opt_vt(*ai, *bi));
    } else {
        izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.mul_vt(*ai, *bi));
    }
}
/// Computes the Shoup representation of every element of `a` into a new
/// vector.
pub fn shoup_vec(&self, a: &[u64]) -> Vec<u64> {
    let mut out = Vec::with_capacity(a.len());
    for ai in a {
        out.push(self.shoup(*ai));
    }
    out
}
/// Shoup multiplication of vectors in place, i.e. `a[i] = a[i] * b[i] mod p`,
/// using the precomputed Shoup representation `b_shoup` of `b`.
pub fn mul_shoup_vec(&self, a: &mut [u64], b: &[u64], b_shoup: &[u64]) {
    debug_assert_eq!(a.len(), b.len());
    debug_assert_eq!(a.len(), b_shoup.len());
    // In debug builds, verify that b_shoup really is shoup_vec(b).
    debug_assert_eq!(&b_shoup, &self.shoup_vec(b));
    izip!(a.iter_mut(), b.iter(), b_shoup.iter())
        .for_each(|(ai, bi, bi_shoup)| *ai = self.mul_shoup(*ai, *bi, *bi_shoup));
}

/// Variable-time Shoup multiplication of vectors in place.
///
/// # Safety
/// Not constant-time; must not be used on secret data.
pub unsafe fn mul_shoup_vec_vt(&self, a: &mut [u64], b: &[u64], b_shoup: &[u64]) {
    debug_assert_eq!(a.len(), b.len());
    debug_assert_eq!(a.len(), b_shoup.len());
    debug_assert_eq!(&b_shoup, &self.shoup_vec(b));
    izip!(a.iter_mut(), b.iter(), b_shoup.iter())
        .for_each(|(ai, bi, bi_shoup)| *ai = self.mul_shoup_vt(*ai, *bi, *bi_shoup));
}
/// Reduces every element of `a` modulo `p`, in place.
pub fn reduce_vec(&self, a: &mut [u64]) {
    for ai in a.iter_mut() {
        *ai = self.reduce(*ai);
    }
}
/// Variable-time centering: returns `a` as a signed value, mapping inputs
/// `>= floor(p / 2)` to `a - p`.
const unsafe fn center_vt(&self, a: u64) -> i64 {
    debug_assert!(a < self.p);
    if a >= self.p >> 1 {
        (a as i64) - (self.p as i64)
    } else {
        a as i64
    }
}

/// Variable-time centering of every element of `a` into a new signed vector.
///
/// # Safety
/// Not constant-time; must not be used on secret data.
pub unsafe fn center_vec_vt(&self, a: &[u64]) -> Vec<i64> {
    a.iter().map(|ai| self.center_vt(*ai)).collect_vec()
}

/// Variable-time in-place reduction of every element of `a` modulo `p`.
///
/// # Safety
/// Not constant-time; must not be used on secret data.
pub unsafe fn reduce_vec_vt(&self, a: &mut [u64]) {
    a.iter_mut().for_each(|ai| *ai = self.reduce_vt(*ai));
}

/// Reduces a signed integer modulo `p`.
///
/// `p << 64` is a multiple of `p`, so adding it makes the sum non-negative
/// (for any `i64`) without changing the residue.
const fn reduce_i64(&self, a: i64) -> u64 {
    self.reduce_u128((((self.p as i128) << 64) + (a as i128)) as u128)
}

/// Variable-time version of `reduce_i64`.
const unsafe fn reduce_i64_vt(&self, a: i64) -> u64 {
    self.reduce_u128_vt((((self.p as i128) << 64) + (a as i128)) as u128)
}

/// Reduces every element of a signed slice modulo `p` into a new vector.
pub fn reduce_vec_i64(&self, a: &[i64]) -> Vec<u64> {
    a.iter().map(|ai| self.reduce_i64(*ai)).collect_vec()
}

/// Variable-time version of `reduce_vec_i64`.
pub unsafe fn reduce_vec_i64_vt(&self, a: &[i64]) -> Vec<u64> {
    a.iter().map(|ai| self.reduce_i64_vt(*ai)).collect_vec()
}

/// Reduces every element of `a` modulo `p` into a new vector.
pub fn reduce_vec_new(&self, a: &[u64]) -> Vec<u64> {
    a.iter().map(|ai| self.reduce(*ai)).collect_vec()
}

/// Variable-time version of `reduce_vec_new`.
pub unsafe fn reduce_vec_new_vt(&self, a: &[u64]) -> Vec<u64> {
    a.iter().map(|bi| self.reduce_vt(*bi)).collect_vec()
}
pub fn neg_vec(&self, a: &mut [u64]) {
izip!(a.iter_mut()).for_each(|ai| *ai = self.neg(*ai));
}
pub unsafe fn neg_vec_vt(&self, a: &mut [u64]) {
izip!(a.iter_mut()).for_each(|ai| *ai = self.neg_vt(*ai));
}
/// Modular exponentiation by squaring: computes `a^n mod p`.
///
/// In debug builds, asserts that `a < p` and `n < p`.
pub fn pow(&self, a: u64, n: u64) -> u64 {
    debug_assert!(a < self.p && n < self.p);
    if n == 0 {
        1
    } else if n == 1 {
        a
    } else {
        // `r = a` already accounts for the most significant bit of n; scan
        // the remaining bits from high to low. Since 2 <= n < p < 2^62,
        // `leading_zeros` is at most 62 and the subtraction cannot
        // underflow; the next bit index is (63 - lz) - 1 == 62 - lz.
        let mut r = a;
        let mut i = (62 - n.leading_zeros()) as isize;
        while i >= 0 {
            r = self.mul(r, r);
            if (n >> i) & 1 == 1 {
                r = self.mul(r, a);
            }
            i -= 1;
        }
        r
    }
}
/// Computes the modular inverse of `a` via Fermat's little theorem, i.e.
/// `a^(p - 2) mod p`.
///
/// Returns `None` if `p` is not prime or `a` is zero.
pub fn inv(&self, a: u64) -> Option<u64> {
    if !is_prime(self.p) || a == 0 {
        None
    } else {
        let r = self.pow(a, self.p - 2);
        // Sanity-check the inverse in debug builds.
        debug_assert_eq!(self.mul(a, r), 1);
        Some(r)
    }
}
/// Reduces a 128-bit value modulo `p` using Barrett reduction.
pub const fn reduce_u128(&self, a: u128) -> u64 {
    Self::reduce1(self.lazy_reduce_u128(a), self.p)
}

/// Variable-time reduction of a 128-bit value modulo `p`.
///
/// # Safety
/// The final conditional subtraction branches on the data; not
/// constant-time.
pub const unsafe fn reduce_u128_vt(&self, a: u128) -> u64 {
    Self::reduce1_vt(self.lazy_reduce_u128(a), self.p)
}

/// Reduces a 64-bit value modulo `p` using Barrett reduction.
pub const fn reduce(&self, a: u64) -> u64 {
    Self::reduce1(self.lazy_reduce(a), self.p)
}

/// Variable-time reduction of a 64-bit value modulo `p`.
pub const unsafe fn reduce_vt(&self, a: u64) -> u64 {
    Self::reduce1_vt(self.lazy_reduce(a), self.p)
}

/// Optimized reduction of a 128-bit value modulo `p`.
///
/// Requires `supports_opt`, and `a < p^2` (debug-asserted downstream).
pub const fn reduce_opt_u128(&self, a: u128) -> u64 {
    debug_assert!(self.supports_opt);
    Self::reduce1(self.lazy_reduce_opt_u128(a), self.p)
}

/// Variable-time optimized reduction of a 128-bit value modulo `p`.
pub(crate) const unsafe fn reduce_opt_u128_vt(&self, a: u128) -> u64 {
    debug_assert!(self.supports_opt);
    Self::reduce1_vt(self.lazy_reduce_opt_u128(a), self.p)
}

/// Optimized reduction of a 64-bit value modulo `p`.
pub const fn reduce_opt(&self, a: u64) -> u64 {
    Self::reduce1(self.lazy_reduce_opt(a), self.p)
}

/// Variable-time optimized reduction of a 64-bit value modulo `p`.
pub const unsafe fn reduce_opt_vt(&self, a: u64) -> u64 {
    Self::reduce1_vt(self.lazy_reduce_opt(a), self.p)
}
/// Branchless conditional subtraction: reduces `x` from `[0, 2p)` into
/// `[0, p)` without any data-dependent branch (constant time).
const fn reduce1(x: u64, p: u64) -> u64 {
    debug_assert!(p >> 63 == 0);
    debug_assert!(x < 2 * p);
    // y = x - p (wrapping). Since p < 2^63 and x < 2p, the subtraction
    // borrows exactly when x < p, and the borrow shows up as the top bit
    // of y.
    let (y, _) = x.overflowing_sub(p);
    // (x ^ p) ^ (y ^ p) == x ^ y, and x ^ (x ^ y) == y: these XORs
    // reconstruct y while staying branch-free.
    let xp = x ^ p;
    let yp = y ^ p;
    let xy = xp ^ yp;
    let xxy = x ^ xy;
    // Top bit of y: 1 iff x < p.
    let xxy = xxy >> 63;
    // c == 0 when x < p, all-ones when x >= p.
    let (c, _) = xxy.overflowing_sub(1);
    // Mask-select: y (i.e. x - p) when x >= p, x otherwise.
    let r = (c & y) | ((!c) & x);
    debug_assert!(r == x % p);
    r
}
/// Variable-time conditional subtraction: reduces `x` from `[0, 2p)` into
/// `[0, p)` with a data-dependent branch.
///
/// # Safety
/// Not constant-time; must not be used on secret data.
const unsafe fn reduce1_vt(x: u64, p: u64) -> u64 {
    debug_assert!(p >> 63 == 0);
    debug_assert!(x < 2 * p);
    if x >= p {
        x - p
    } else {
        x
    }
}
/// Barrett reduction of a 128-bit value into the range `[0, 2p)`.
///
/// Estimates `q ~= floor(a / p)` by multiplying `a` with the precomputed
/// constant `floor(2^128 / p)` (split as `barrett_hi`/`barrett_lo`) and
/// keeping bits `>= 2^128`; the remainder `a - q * p` then fits in `u64`.
pub const fn lazy_reduce_u128(&self, a: u128) -> u64 {
    let a_lo = a as u64;
    let a_hi = (a >> 64) as u64;
    // Partial products of the 128x128-bit multiplication, keeping only
    // what contributes to the high 128 bits (the quotient estimate).
    let p_lo_lo = ((a_lo as u128) * (self.barrett_lo as u128)) >> 64;
    let p_hi_lo = (a_hi as u128) * (self.barrett_lo as u128);
    let p_lo_hi = (a_lo as u128) * (self.barrett_hi as u128);
    let q = ((p_lo_hi + p_hi_lo + p_lo_lo) >> 64) + (a_hi as u128) * (self.barrett_hi as u128);
    let r = (a - q * (self.p as u128)) as u64;
    debug_assert!((r as u128) < 2 * (self.p as u128));
    debug_assert!(r % self.p == (a % (self.p as u128)) as u64);
    r
}
/// Barrett reduction of a 64-bit value into the range `[0, 2p)`.
///
/// Same scheme as `lazy_reduce_u128` with the high input word equal to 0.
pub const fn lazy_reduce(&self, a: u64) -> u64 {
    let p_lo_lo = ((a as u128) * (self.barrett_lo as u128)) >> 64;
    let p_lo_hi = (a as u128) * (self.barrett_hi as u128);
    let q = (p_lo_hi + p_lo_lo) >> 64;
    let r = (a as u128 - q * (self.p as u128)) as u64;
    debug_assert!((r as u128) < 2 * (self.p as u128));
    debug_assert!(r % self.p == a % self.p);
    r
}
/// Optimized lazy reduction of a 128-bit value into `[0, 2p)`.
///
/// Only valid for moduli accepted by `primes::supports_opt`, and requires
/// `a < p^2` (debug-asserted).
pub const fn lazy_reduce_opt_u128(&self, a: u128) -> u64 {
    debug_assert!(a < (self.p as u128) * (self.p as u128));
    // Cheaper quotient estimate exploiting the special shape of p.
    let q = (((self.barrett_lo as u128) * (a >> 64)) + (a << self.leading_zeros)) >> 64;
    let r = (a - q * (self.p as u128)) as u64;
    debug_assert!((r as u128) < 2 * (self.p as u128));
    debug_assert!(r % self.p == (a % (self.p as u128)) as u64);
    r
}

/// Optimized lazy reduction of a 64-bit value into `[0, 2p)`.
///
/// `64 - leading_zeros` is the bit length of `p`, so `q = a >> nbits`.
const fn lazy_reduce_opt(&self, a: u64) -> u64 {
    let q = a >> (64 - self.leading_zeros);
    let r = ((a as u128) - (q as u128) * (self.p as u128)) as u64;
    debug_assert!((r as u128) < 2 * (self.p as u128));
    debug_assert!(r % self.p == a % self.p);
    r
}
/// Lazily reduces every element of `a` in place into the range `[0, 2p)`.
pub fn lazy_reduce_vec(&self, a: &mut [u64]) {
    // Pick the reduction routine once, then apply it element-wise.
    if self.supports_opt {
        for ai in a.iter_mut() {
            *ai = self.lazy_reduce_opt(*ai);
        }
    } else {
        for ai in a.iter_mut() {
            *ai = self.lazy_reduce(*ai);
        }
    }
}
/// Returns a vector of `size` elements drawn uniformly at random from
/// `[0, p)` using the provided cryptographically secure RNG.
pub fn random_vec<R: RngCore + CryptoRng>(&self, size: usize, rng: &mut R) -> Vec<u64> {
    rng.sample_iter(self.distribution).take(size).collect_vec()
}
/// Length in bytes of the serialization of a vector of `size` elements.
///
/// Panics if `size` is not a multiple of 8.
pub const fn serialization_length(&self, size: usize) -> usize {
    assert!(size % 8 == 0);
    // Bits needed to represent the largest residue, p - 1.
    let p_nbits = 64 - (self.p - 1).leading_zeros() as usize;
    p_nbits * size / 8
}

/// Serializes a vector of elements into a compact byte representation,
/// using the bit length of `p - 1` bits per element.
pub fn serialize_vec(&self, a: &[u64]) -> Vec<u8> {
    let p_nbits = 64 - (self.p - 1).leading_zeros() as usize;
    transcode_to_bytes(a, p_nbits)
}

/// Deserializes a byte buffer produced by `serialize_vec` back into a
/// vector of elements.
///
/// NOTE(review): decoded values are not checked to be `< p` here —
/// presumably the caller guarantees a valid serialization; confirm before
/// relying on untrusted input.
pub fn deserialize_vec(&self, b: &[u8]) -> Vec<u64> {
    let p_nbits = 64 - (self.p - 1).leading_zeros() as usize;
    transcode_from_bytes(b, p_nbits)
}
}
#[cfg(test)]
mod tests {
use super::{primes, Modulus};
use fhe_util::catch_unwind;
use itertools::{izip, Itertools};
use proptest::collection::vec as prop_vec;
use proptest::prelude::{any, BoxedStrategy, Just, Strategy};
use proptest::{prop_assert, prop_assert_eq, prop_assert_ne, prop_assume, proptest};
use rand::{thread_rng, RngCore};
// Strategy generating valid moduli: arbitrary u64 values filtered through
// `Modulus::new`.
fn valid_moduli() -> impl Strategy<Value = Modulus> {
    any::<u64>().prop_filter_map("filter invalid moduli", |p| Modulus::new(p).ok())
}
// Strategy generating two u64 vectors of the same random length in 1..100.
fn vecs() -> BoxedStrategy<(Vec<u64>, Vec<u64>)> {
    prop_vec(any::<u64>(), 1..100)
        .prop_flat_map(|vec| {
            // Generate the second vector with the first one's length.
            let len = vec.len();
            (Just(vec), prop_vec(any::<u64>(), len))
        })
        .boxed()
}
proptest! {
    // In the tests below, variables named `q` are remapped into
    // `[p + 1, u64::MAX]` so that calling an operation with them violates
    // the `a < p` debug assertions; `catch_unwind` checks the panic.
    #[test]
    fn constructor(p: u64) {
        // 63- and 64-bit values, 0 and 1 must all be rejected.
        prop_assert!(Modulus::new(p | (1u64 << 62)).is_err());
        prop_assert!(Modulus::new(p | (1u64 << 63)).is_err());
        prop_assert!(Modulus::new(0u64).is_err());
        prop_assert!(Modulus::new(1u64).is_err());
        // p >> 2 is at most 62 bits; accept it when it is at least 2.
        prop_assume!(p >> 2 >= 2);
        let q = Modulus::new(p >> 2);
        prop_assert!(q.is_ok());
        prop_assert_eq!(q.unwrap().modulus(), p >> 2);
    }
    #[test]
    fn neg(p in valid_moduli(), mut a: u64, mut q: u64) {
        a = p.reduce(a);
        prop_assert_eq!(p.neg(a), (p.modulus() - a) % p.modulus());
        unsafe { prop_assert_eq!(p.neg_vt(a), (p.modulus() - a) % p.modulus()) }
        // Out-of-range operand must trip the debug assertion.
        q = (q % (u64::MAX - p.modulus())) + 1 + p.modulus();
        prop_assert!(catch_unwind(|| p.neg(q)).is_err());
    }
    #[test]
    fn add(p in valid_moduli(), mut a: u64, mut b: u64, mut q: u64) {
        a = p.reduce(a);
        b = p.reduce(b);
        prop_assert_eq!(p.add(a, b), (a + b) % p.modulus());
        unsafe { prop_assert_eq!(p.add_vt(a, b), (a + b) % p.modulus()) }
        // Out-of-range operands must trip the debug assertion on each side.
        q = (q % (u64::MAX - p.modulus())) + 1 + p.modulus();
        prop_assert!(catch_unwind(|| p.add(q, a)).is_err());
        prop_assert!(catch_unwind(|| p.add(a, q)).is_err());
    }
    #[test]
    fn sub(p in valid_moduli(), mut a: u64, mut b: u64, mut q: u64) {
        a = p.reduce(a);
        b = p.reduce(b);
        prop_assert_eq!(p.sub(a, b), (a + p.modulus() - b) % p.modulus());
        unsafe { prop_assert_eq!(p.sub_vt(a, b), (a + p.modulus() - b) % p.modulus()) }
        // Out-of-range operands must trip the debug assertion on each side.
        q = (q % (u64::MAX - p.modulus())) + 1 + p.modulus();
        prop_assert!(catch_unwind(|| p.sub(q, a)).is_err());
        prop_assert!(catch_unwind(|| p.sub(a, q)).is_err());
    }
    #[test]
    fn mul(p in valid_moduli(), mut a: u64, mut b: u64, mut q: u64) {
        a = p.reduce(a);
        b = p.reduce(b);
        // Compare against 128-bit reference arithmetic.
        prop_assert_eq!(p.mul(a, b) as u128, ((a as u128) * (b as u128)) % (p.modulus() as u128));
        unsafe { prop_assert_eq!(p.mul_vt(a, b) as u128, ((a as u128) * (b as u128)) % (p.modulus() as u128)) }
        q = (q % (u64::MAX - p.modulus())) + 1 + p.modulus();
        prop_assert!(catch_unwind(|| p.mul(q, a)).is_err());
        prop_assert!(catch_unwind(|| p.mul(a, q)).is_err());
    }
    #[test]
    fn mul_shoup(p in valid_moduli(), mut a: u64, mut b: u64, mut q: u64) {
        a = p.reduce(a);
        b = p.reduce(b);
        q = (q % (u64::MAX - p.modulus())) + 1 + p.modulus();
        let b_shoup = p.shoup(b);
        // shoup() itself must reject out-of-range inputs.
        prop_assert!(catch_unwind(|| p.shoup(q)).is_err());
        prop_assert_eq!(p.mul_shoup(a, b, b_shoup) as u128, ((a as u128) * (b as u128)) % (p.modulus() as u128));
        unsafe { prop_assert_eq!(p.mul_shoup_vt(a, b, b_shoup) as u128, ((a as u128) * (b as u128)) % (p.modulus() as u128)) }
        prop_assert!(catch_unwind(|| p.mul_shoup(a, q, b_shoup)).is_err());
        // A mismatched Shoup representation (b_shoup != shoup(a)) must trip
        // the debug assertion in lazy_mul_shoup.
        prop_assume!(a != b);
        prop_assert!(catch_unwind(|| p.mul_shoup(a, a, b_shoup)).is_err());
    }
    #[test]
    fn reduce(p in valid_moduli(), a: u64) {
        prop_assert_eq!(p.reduce(a), a % p.modulus());
        unsafe { prop_assert_eq!(p.reduce_vt(a), a % p.modulus()) }
        // The optimized path is only exercised when the modulus supports it.
        if p.supports_opt {
            prop_assert_eq!(p.reduce_opt(a), a % p.modulus());
            unsafe { prop_assert_eq!(p.reduce_opt_vt(a), a % p.modulus()) }
        }
    }
    #[test]
    fn lazy_reduce(p in valid_moduli(), a: u64) {
        // Lazy reduction lands in [0, 2p) and agrees with full reduction
        // modulo p.
        prop_assert!(p.lazy_reduce(a) < 2 * p.modulus());
        prop_assert_eq!(p.lazy_reduce(a) % p.modulus(), p.reduce(a));
    }
    #[test]
    fn reduce_i64(p in valid_moduli(), a: i64) {
        // Reference result computed from the sign of a.
        // NOTE(review): `-a as u64` overflows (panics in debug) when
        // a == i64::MIN; proptest can generate that value — confirm.
        let b = if a < 0 { p.neg(p.reduce(-a as u64)) } else { p.reduce(a as u64) };
        prop_assert_eq!(p.reduce_i64(a), b);
        unsafe { prop_assert_eq!(p.reduce_i64_vt(a), b) }
    }
    #[test]
    fn reduce_u128(p in valid_moduli(), mut a: u128) {
        prop_assert_eq!(p.reduce_u128(a) as u128, a % (p.modulus() as u128));
        unsafe { prop_assert_eq!(p.reduce_u128_vt(a) as u128, a % (p.modulus() as u128)) }
        if p.supports_opt {
            // The optimized path requires a < p^2.
            let p_square = (p.modulus() as u128) * (p.modulus() as u128);
            a %= p_square;
            prop_assert_eq!(p.reduce_opt_u128(a) as u128, a % (p.modulus() as u128));
            unsafe { prop_assert_eq!(p.reduce_opt_u128_vt(a) as u128, a % (p.modulus() as u128)) }
        }
    }
    #[test]
    fn add_vec(p in valid_moduli(), (mut a, mut b) in vecs()) {
        p.reduce_vec(&mut a);
        p.reduce_vec(&mut b);
        // Keep a copy of the original a for the element-wise reference.
        let c = a.clone();
        p.add_vec(&mut a, &b);
        prop_assert_eq!(a, izip!(b.iter(), c.iter()).map(|(bi, ci)| p.add(*bi, *ci)).collect_vec());
        a = c.clone();
        unsafe { p.add_vec_vt(&mut a, &b) }
        prop_assert_eq!(a, izip!(b.iter(), c.iter()).map(|(bi, ci)| p.add(*bi, *ci)).collect_vec());
    }
    #[test]
    fn sub_vec(p in valid_moduli(), (mut a, mut b) in vecs()) {
        p.reduce_vec(&mut a);
        p.reduce_vec(&mut b);
        let c = a.clone();
        p.sub_vec(&mut a, &b);
        prop_assert_eq!(a, izip!(b.iter(), c.iter()).map(|(bi, ci)| p.sub(*ci, *bi)).collect_vec());
        a = c.clone();
        unsafe { p.sub_vec_vt(&mut a, &b) }
        prop_assert_eq!(a, izip!(b.iter(), c.iter()).map(|(bi, ci)| p.sub(*ci, *bi)).collect_vec());
    }
    #[test]
    fn mul_vec(p in valid_moduli(), (mut a, mut b) in vecs()) {
        p.reduce_vec(&mut a);
        p.reduce_vec(&mut b);
        let c = a.clone();
        p.mul_vec(&mut a, &b);
        prop_assert_eq!(a, izip!(b.iter(), c.iter()).map(|(bi, ci)| p.mul(*ci, *bi)).collect_vec());
        a = c.clone();
        unsafe { p.mul_vec_vt(&mut a, &b); }
        prop_assert_eq!(a, izip!(b.iter(), c.iter()).map(|(bi, ci)| p.mul(*ci, *bi)).collect_vec());
    }
    #[test]
    fn scalar_mul_vec(p in valid_moduli(), mut a: Vec<u64>, mut b: u64) {
        p.reduce_vec(&mut a);
        b = p.reduce(b);
        let c = a.clone();
        p.scalar_mul_vec(&mut a, b);
        prop_assert_eq!(a, c.iter().map(|ci| p.mul(*ci, b)).collect_vec());
        a = c.clone();
        unsafe { p.scalar_mul_vec_vt(&mut a, b) }
        prop_assert_eq!(a, c.iter().map(|ci| p.mul(*ci, b)).collect_vec());
    }
    #[test]
    fn mul_shoup_vec(p in valid_moduli(), (mut a, mut b) in vecs()) {
        p.reduce_vec(&mut a);
        p.reduce_vec(&mut b);
        let b_shoup = p.shoup_vec(&b);
        let c = a.clone();
        p.mul_shoup_vec(&mut a, &b, &b_shoup);
        prop_assert_eq!(a, izip!(b.iter(), c.iter()).map(|(bi, ci)| p.mul(*ci, *bi)).collect_vec());
        a = c.clone();
        unsafe { p.mul_shoup_vec_vt(&mut a, &b, &b_shoup) }
        prop_assert_eq!(a, izip!(b.iter(), c.iter()).map(|(bi, ci)| p.mul(*ci, *bi)).collect_vec());
    }
    #[test]
    fn reduce_vec(p in valid_moduli(), a: Vec<u64>) {
        let mut b = a.clone();
        p.reduce_vec(&mut b);
        prop_assert_eq!(b, a.iter().map(|ai| p.reduce(*ai)).collect_vec());
        b = a.clone();
        unsafe { p.reduce_vec_vt(&mut b) }
        prop_assert_eq!(b, a.iter().map(|ai| p.reduce(*ai)).collect_vec());
    }
    #[test]
    fn lazy_reduce_vec(p in valid_moduli(), a: Vec<u64>) {
        let mut b = a.clone();
        p.lazy_reduce_vec(&mut b);
        // Lazily reduced values are in [0, 2p) and congruent to the input.
        prop_assert!(b.iter().all(|bi| *bi < 2 * p.modulus()));
        prop_assert!(izip!(a, b).all(|(ai, bi)| bi % p.modulus() == ai % p.modulus()));
    }
    #[test]
    fn reduce_vec_new(p in valid_moduli(), a: Vec<u64>) {
        let b = p.reduce_vec_new(&a);
        prop_assert_eq!(b, a.iter().map(|ai| p.reduce(*ai)).collect_vec());
        // The variable-time variant must agree with the constant-time one.
        prop_assert_eq!(p.reduce_vec_new(&a), unsafe { p.reduce_vec_new_vt(&a) });
    }
    #[test]
    fn reduce_vec_i64(p in valid_moduli(), a: Vec<i64>) {
        let b = p.reduce_vec_i64(&a);
        prop_assert_eq!(b, a.iter().map(|ai| p.reduce_i64(*ai)).collect_vec());
        let b = unsafe { p.reduce_vec_i64_vt(&a) };
        prop_assert_eq!(b, a.iter().map(|ai| p.reduce_i64(*ai)).collect_vec());
    }
    #[test]
    fn neg_vec(p in valid_moduli(), mut a: Vec<u64>) {
        p.reduce_vec(&mut a);
        let mut b = a.clone();
        p.neg_vec(&mut b);
        prop_assert_eq!(b, a.iter().map(|ai| p.neg(*ai)).collect_vec());
        b = a.clone();
        unsafe { p.neg_vec_vt(&mut b); }
        prop_assert_eq!(b, a.iter().map(|ai| p.neg(*ai)).collect_vec());
    }
    #[test]
    fn random_vec(p in valid_moduli(), size in 1..1000usize) {
        let mut rng = thread_rng();
        let v = p.random_vec(size, &mut rng);
        prop_assert_eq!(v.len(), size);
        let w = p.random_vec(size, &mut rng);
        prop_assert_eq!(w.len(), size);
        // Only check distinctness when the modulus is large enough (at
        // least 34 bits) for a collision to be overwhelmingly unlikely.
        if p.modulus().leading_zeros() <= 30 {
            prop_assert_ne!(v, w);
        }
    }
    #[test]
    fn serialize(p in valid_moduli(), mut a in prop_vec(any::<u64>(), 8)) {
        // Round-trip: serialize then deserialize must be the identity on
        // reduced vectors of length a multiple of 8.
        p.reduce_vec(&mut a);
        let b = p.serialize_vec(&a);
        let c = p.deserialize_vec(&b);
        prop_assert_eq!(a, c);
    }
}
#[test]
fn mul_opt() {
    let ntests = 100;
    let mut rng = rand::thread_rng();
    // Single known modulus supporting the optimized multiplication.
    #[allow(clippy::single_element_loop)]
    for p in [4611686018326724609] {
        let q = Modulus::new(p).unwrap();
        assert!(primes::supports_opt(p));
        // Fixed edge cases.
        assert_eq!(q.mul_opt(0, 1), 0);
        assert_eq!(q.mul_opt(1, 1), 1);
        assert_eq!(q.mul_opt(2 % p, 3 % p), 6 % p);
        assert_eq!(q.mul_opt(p - 1, 1), p - 1);
        assert_eq!(q.mul_opt(p - 1, 2 % p), p - 2);
        // Out-of-range operands must trip the debug assertions.
        assert!(catch_unwind(|| q.mul_opt(p, 1)).is_err());
        assert!(catch_unwind(|| q.mul_opt(p << 1, 1)).is_err());
        assert!(catch_unwind(|| q.mul_opt(0, p)).is_err());
        assert!(catch_unwind(|| q.mul_opt(0, p << 1)).is_err());
        // Random comparisons against 128-bit reference arithmetic.
        for _ in 0..ntests {
            let a = rng.next_u64() % p;
            let b = rng.next_u64() % p;
            assert_eq!(
                q.mul_opt(a, b),
                (((a as u128) * (b as u128)) % (p as u128)) as u64
            );
        }
    }
}
#[test]
fn pow() {
    let ntests = 10;
    let mut rng = rand::thread_rng();
    for p in [2u64, 3, 17, 1987, 4611686018326724609] {
        let q = Modulus::new(p).unwrap();
        // Fixed edge cases: exponents 0, 1, 2 and base 1.
        assert_eq!(q.pow(p - 1, 0), 1);
        assert_eq!(q.pow(p - 1, 1), p - 1);
        assert_eq!(q.pow(p - 1, 2 % p), 1);
        assert_eq!(q.pow(1, p - 2), 1);
        assert_eq!(q.pow(1, p - 1), 1);
        // Out-of-range operands must trip the debug assertions.
        assert!(catch_unwind(|| q.pow(p, 1)).is_err());
        assert!(catch_unwind(|| q.pow(p << 1, 1)).is_err());
        assert!(catch_unwind(|| q.pow(0, p)).is_err());
        assert!(catch_unwind(|| q.pow(0, p << 1)).is_err());
        // Compare against naive repeated multiplication (exponent < 1000).
        for _ in 0..ntests {
            let a = rng.next_u64() % p;
            let b = (rng.next_u64() % p) % 1000;
            let mut c = b;
            let mut r = 1;
            while c > 0 {
                r = q.mul(r, a);
                c -= 1;
            }
            assert_eq!(q.pow(a, b), r);
        }
    }
}
#[test]
fn inv() {
    let ntests = 100;
    let mut rng = rand::thread_rng();
    for p in [2u64, 3, 17, 1987, 4611686018326724609] {
        let q = Modulus::new(p).unwrap();
        // 0 has no inverse; 1 and p - 1 are their own inverses.
        assert!(q.inv(0).is_none());
        assert_eq!(q.inv(1).unwrap(), 1);
        assert_eq!(q.inv(p - 1).unwrap(), p - 1);
        // Out-of-range operands must trip the debug assertions.
        assert!(catch_unwind(|| q.inv(p)).is_err());
        assert!(catch_unwind(|| q.inv(p << 1)).is_err());
        // Random elements: a * inv(a) == 1 whenever a != 0.
        for _ in 0..ntests {
            let a = rng.next_u64() % p;
            let b = q.inv(a);
            if a == 0 {
                assert!(b.is_none())
            } else {
                assert!(b.is_some());
                assert_eq!(q.mul(a, b.unwrap()), 1)
            }
        }
    }
}
}