use super::crt::{crt_compose_2, crt_decompose_2, crt_modulus, mod_inverse};
use super::gaussian::GaussianSampler;
use super::mod_q::{ModQ, DEFAULT_Q};
use super::ntt::NttContext;
use rand::Rng;
use rand::SeedableRng;
use rand_chacha::ChaCha20Rng;
use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
/// A polynomial with `dim` coefficients stored in residue (CRT) form over one or two moduli.
///
/// Residues are laid out limb-major: the residues modulo `moduli[m]` occupy
/// `coeffs[m * dim .. (m + 1) * dim]`. The `is_ntt` flag records whether those residues
/// currently hold coefficient values or their NTT-domain image.
#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
pub struct Poly {
    /// Limb-major residues, `dim * moduli.len()` entries in total.
    coeffs: Vec<u64>,
    /// The CRT moduli (one or two entries for any value built by the constructors).
    moduli: Vec<u64>,
    /// Composite modulus computed by `crt_modulus(&moduli)` (presumably the product — confirm in `crt`).
    q: u64,
    /// Number of coefficients per limb.
    dim: usize,
    /// `mod_inverse(moduli[0], moduli[1])` when there are two moduli, otherwise 0.
    crt_q0_inv_mod_q1: u64,
    /// `true` when `coeffs` holds NTT-domain values rather than coefficients.
    is_ntt: bool,
}
impl Poly {
    /// Validates `moduli` and precomputes the CRT data shared by every constructor.
    ///
    /// Returns `(moduli, q, q0_inv_mod_q1)`: `q` comes from `crt_modulus`, and the inverse is
    /// `mod_inverse(moduli[0], moduli[1])` for two moduli and 0 for a single modulus.
    ///
    /// # Panics
    /// If `moduli` is empty or has more than two entries.
    fn init_moduli(moduli: &[u64]) -> (Vec<u64>, u64, u64) {
        assert!(!moduli.is_empty(), "moduli must be non-empty");
        if moduli.len() > 2 {
            panic!("CRT with more than 2 moduli not supported");
        }
        let moduli_vec = moduli.to_vec();
        let q = crt_modulus(&moduli_vec);
        let inv = if moduli_vec.len() == 2 {
            mod_inverse(moduli_vec[0], moduli_vec[1])
        } else {
            0
        };
        (moduli_vec, q, inv)
    }

    /// Returns `(a + b) mod modulus` for `a, b < modulus`.
    ///
    /// Uses `overflowing_add` so the result stays correct even when `modulus > 2^63`,
    /// where the plain sum could exceed `u64::MAX`.
    #[inline]
    fn add_mod(a: u64, b: u64, modulus: u64) -> u64 {
        let (sum, carry) = a.overflowing_add(b);
        if carry || sum >= modulus {
            // True sum is < 2 * modulus, so one wrapping subtraction lands in range.
            sum.wrapping_sub(modulus)
        } else {
            sum
        }
    }

    /// All-zero polynomial of dimension `dim` over the single modulus `q`.
    pub fn zero(dim: usize, q: u64) -> Self {
        Self::zero_moduli(dim, &[q])
    }

    /// All-zero polynomial of dimension `dim` in CRT form over `moduli`.
    ///
    /// # Panics
    /// If `moduli` is empty or has more than two entries.
    pub fn zero_moduli(dim: usize, moduli: &[u64]) -> Self {
        let (moduli_vec, q, inv) = Self::init_moduli(moduli);
        let crt_count = moduli_vec.len();
        Self {
            coeffs: vec![0; dim * crt_count],
            moduli: moduli_vec,
            q,
            dim,
            crt_q0_inv_mod_q1: inv,
            is_ntt: false,
        }
    }

    /// All-zero polynomial of dimension `dim` over [`DEFAULT_Q`].
    pub fn zero_default(dim: usize) -> Self {
        Self::zero(dim, DEFAULT_Q)
    }

    /// Coefficient-domain polynomial from `coeffs`, each reduced modulo `q`.
    pub fn from_coeffs(coeffs: Vec<u64>, q: u64) -> Self {
        Self::from_coeffs_moduli(coeffs, &[q])
    }

    /// Coefficient-domain polynomial from integer coefficients, split into residues
    /// modulo each entry of `moduli`. The dimension is `coeffs.len()`.
    ///
    /// # Panics
    /// If `moduli` is empty or has more than two entries.
    pub fn from_coeffs_moduli(coeffs: Vec<u64>, moduli: &[u64]) -> Self {
        let dim = coeffs.len();
        let (moduli_vec, q, inv) = Self::init_moduli(moduli);
        let mut crt_coeffs = vec![0u64; dim * moduli_vec.len()];
        match moduli_vec.as_slice() {
            &[q0] => {
                for (dst, &c) in crt_coeffs.iter_mut().zip(&coeffs) {
                    *dst = c % q0;
                }
            }
            &[q0, q1] => {
                let (limb0, limb1) = crt_coeffs.split_at_mut(dim);
                for ((d0, d1), &c) in limb0.iter_mut().zip(limb1.iter_mut()).zip(&coeffs) {
                    let (c0, c1) = crt_decompose_2(c, q0, q1);
                    *d0 = c0;
                    *d1 = c1;
                }
            }
            // `init_moduli` has already rejected every other count.
            _ => unreachable!("init_moduli only accepts one or two moduli"),
        }
        let mut p = Self {
            coeffs: crt_coeffs,
            moduli: moduli_vec,
            q,
            dim,
            crt_q0_inv_mod_q1: inv,
            is_ntt: false,
        };
        p.reduce();
        p
    }

    /// Coefficient-domain polynomial from residues already in limb-major layout
    /// (`dim` residues for `moduli[0]`, then `dim` for `moduli[1]`, ...).
    /// Each residue is reduced modulo its own modulus.
    ///
    /// # Panics
    /// If `coeffs.len()` is not a multiple of `moduli.len()`, or `moduli` is invalid.
    pub fn from_crt_coeffs(coeffs: Vec<u64>, moduli: &[u64]) -> Self {
        let (moduli_vec, q, inv) = Self::init_moduli(moduli);
        let crt_count = moduli_vec.len();
        assert!(
            coeffs.len().is_multiple_of(crt_count),
            "CRT coeffs length must be a multiple of crt_count"
        );
        let dim = coeffs.len() / crt_count;
        let mut p = Self {
            coeffs,
            moduli: moduli_vec,
            q,
            dim,
            crt_q0_inv_mod_q1: inv,
            is_ntt: false,
        };
        p.reduce();
        p
    }

    /// Coefficient-domain polynomial over [`DEFAULT_Q`].
    pub fn from_coeffs_default(coeffs: Vec<u64>) -> Self {
        Self::from_coeffs(coeffs, DEFAULT_Q)
    }

    /// Constant polynomial `value` (placed in coefficient 0) over the single modulus `q`.
    pub fn constant(value: u64, dim: usize, q: u64) -> Self {
        Self::constant_moduli(value, dim, &[q])
    }

    /// Constant polynomial `value` (placed in coefficient 0) in CRT form over `moduli`.
    ///
    /// For `dim == 0` there is no coefficient to hold `value`, so the empty polynomial
    /// is returned instead of panicking on an out-of-bounds index.
    pub fn constant_moduli(value: u64, dim: usize, moduli: &[u64]) -> Self {
        let mut coeffs = vec![0u64; dim];
        if let Some(first) = coeffs.first_mut() {
            *first = value;
        }
        Self::from_coeffs_moduli(coeffs, moduli)
    }

    /// Polynomial with coefficients drawn from `sampler`, over the single modulus `q`.
    pub fn sample_gaussian(dim: usize, q: u64, sampler: &mut GaussianSampler) -> Self {
        Self::sample_gaussian_moduli(dim, &[q], sampler)
    }

    /// Polynomial with `dim` signed samples from `sampler`, reduced into every CRT limb.
    ///
    /// One sample vector is drawn and the same signed values are mapped into each
    /// modulus, so all limbs represent the same small integers.
    pub fn sample_gaussian_moduli(
        dim: usize,
        moduli: &[u64],
        sampler: &mut GaussianSampler,
    ) -> Self {
        let (moduli_vec, q, inv) = Self::init_moduli(moduli);
        let crt_count = moduli_vec.len();
        let mut coeffs = vec![0u64; dim * crt_count];
        let samples = sampler.sample_vec(dim);
        for (m, &modulus) in moduli_vec.iter().enumerate() {
            let offset = m * dim;
            for i in 0..dim {
                coeffs[offset + i] = crate::math::ModQ::from_signed(samples[i], modulus);
            }
        }
        Self {
            coeffs,
            moduli: moduli_vec,
            q,
            dim,
            crt_q0_inv_mod_q1: inv,
            is_ntt: false,
        }
    }

    /// Uniformly random polynomial over `q` using the thread-local RNG (not reproducible).
    pub fn random(dim: usize, q: u64) -> Self {
        Self::random_moduli(dim, &[q])
    }

    /// Uniformly random CRT polynomial using the thread-local RNG (not reproducible).
    pub fn random_moduli(dim: usize, moduli: &[u64]) -> Self {
        let mut rng = rand::thread_rng();
        Self::random_with_rng_moduli(dim, moduli, &mut rng)
    }

    /// Uniformly random polynomial over `q` drawn from the caller's `rng`.
    pub fn random_with_rng<R: Rng>(dim: usize, q: u64, rng: &mut R) -> Self {
        Self::random_with_rng_moduli(dim, &[q], rng)
    }

    /// Uniformly random CRT polynomial drawn from `rng`.
    ///
    /// Each limb is sampled independently in `[0, moduli[m])`, limb by limb in order.
    /// That order is what makes the seeded constructors reproducible.
    pub fn random_with_rng_moduli<R: Rng>(dim: usize, moduli: &[u64], rng: &mut R) -> Self {
        let (moduli_vec, q, inv) = Self::init_moduli(moduli);
        let crt_count = moduli_vec.len();
        let mut coeffs = vec![0u64; dim * crt_count];
        for (m, &modulus) in moduli_vec.iter().enumerate() {
            for i in 0..dim {
                coeffs[m * dim + i] = rng.gen_range(0..modulus);
            }
        }
        Self {
            coeffs,
            moduli: moduli_vec,
            q,
            dim,
            crt_q0_inv_mod_q1: inv,
            is_ntt: false,
        }
    }

    /// Deterministic uniformly random polynomial over `q` derived from a 32-byte seed.
    pub fn from_seed(seed: &[u8; 32], dim: usize, q: u64) -> Self {
        Self::from_seed_moduli(seed, dim, &[q])
    }

    /// Deterministic uniformly random CRT polynomial from a ChaCha20 stream keyed by `seed`.
    pub fn from_seed_moduli(seed: &[u8; 32], dim: usize, moduli: &[u64]) -> Self {
        let mut rng = ChaCha20Rng::from_seed(*seed);
        Self::random_with_rng_moduli(dim, moduli, &mut rng)
    }

    /// Deterministic polynomial over `q` for position `index` under a shared seed.
    pub fn from_seed_indexed(seed: &[u8; 32], index: usize, dim: usize, q: u64) -> Self {
        Self::from_seed_indexed_moduli(seed, index, dim, &[q])
    }

    /// Deterministic CRT polynomial for position `index` under a shared seed.
    ///
    /// The per-index seed XORs the little-endian bytes of `index` (as `u64`) into the
    /// first 8 bytes of `seed`.
    pub fn from_seed_indexed_moduli(
        seed: &[u8; 32],
        index: usize,
        dim: usize,
        moduli: &[u64],
    ) -> Self {
        let mut derived_seed = *seed;
        for (byte, idx_byte) in derived_seed.iter_mut().zip((index as u64).to_le_bytes()) {
            *byte ^= idx_byte;
        }
        Self::from_seed_moduli(&derived_seed, dim, moduli)
    }

    /// Number of coefficients.
    pub fn dimension(&self) -> usize {
        self.dim
    }

    /// Number of coefficients (same as [`Poly::dimension`]).
    pub fn len(&self) -> usize {
        self.dim
    }

    /// `true` when the polynomial has no coefficients.
    pub fn is_empty(&self) -> bool {
        self.dim == 0
    }

    /// Composite modulus `q`.
    pub fn modulus(&self) -> u64 {
        self.q
    }

    /// The CRT moduli.
    pub fn moduli(&self) -> &[u64] {
        &self.moduli
    }

    /// Number of CRT limbs.
    pub fn crt_count(&self) -> usize {
        self.moduli.len()
    }

    /// `true` when the residues hold NTT-domain values.
    pub fn is_ntt(&self) -> bool {
        self.is_ntt
    }

    /// Marks the data as NTT-domain without transforming it (caller asserts it already is).
    #[inline]
    pub fn force_ntt_domain(&mut self) {
        self.is_ntt = true;
    }

    /// Marks the data as coefficient-domain without transforming it.
    #[inline]
    pub fn force_coeff_domain(&mut self) {
        self.is_ntt = false;
    }

    /// Coefficient `i` reconstructed modulo `q` (via CRT composition for two moduli).
    ///
    /// # Panics
    /// In NTT domain, or if `i` is out of range.
    pub fn coeff(&self, i: usize) -> u64 {
        assert!(!self.is_ntt, "Cannot access coefficients in NTT domain");
        match self.moduli.len() {
            1 => self.coeffs[i],
            2 => {
                let q0 = self.moduli[0];
                let q1 = self.moduli[1];
                let a0 = self.coeffs[i];
                let a1 = self.coeffs[i + self.dim];
                crt_compose_2(a0, a1, q0, q1, self.crt_q0_inv_mod_q1) % self.q
            }
            _ => panic!("CRT with more than 2 moduli not supported"),
        }
    }

    /// Sets coefficient `i` to `value`, storing its residue in every limb.
    ///
    /// # Panics
    /// In NTT domain, or if `i` is out of range.
    pub fn set_coeff(&mut self, i: usize, value: u64) {
        assert!(!self.is_ntt, "Cannot set coefficients in NTT domain");
        match self.moduli.len() {
            1 => {
                self.coeffs[i] = value % self.moduli[0];
            }
            2 => {
                let (c0, c1) = crt_decompose_2(value, self.moduli[0], self.moduli[1]);
                self.coeffs[i] = c0;
                self.coeffs[i + self.dim] = c1;
            }
            _ => panic!("CRT with more than 2 moduli not supported"),
        }
    }

    /// Raw limb-major residues for all moduli.
    pub fn coeffs(&self) -> &[u64] {
        &self.coeffs
    }

    /// Mutable raw residues. Arithmetic elsewhere assumes each residue stays below its modulus.
    pub fn coeffs_mut(&mut self) -> &mut [u64] {
        &mut self.coeffs
    }

    /// Residues of limb `modulus_idx` (panics if the index is out of range).
    pub fn coeffs_modulus(&self, modulus_idx: usize) -> &[u64] {
        let start = modulus_idx * self.dim;
        let end = start + self.dim;
        &self.coeffs[start..end]
    }

    /// Mutable residues of limb `modulus_idx` (panics if the index is out of range).
    pub fn coeffs_modulus_mut(&mut self, modulus_idx: usize) -> &mut [u64] {
        let start = modulus_idx * self.dim;
        let end = start + self.dim;
        &mut self.coeffs[start..end]
    }

    /// Reduces every residue modulo its own limb's modulus.
    fn reduce(&mut self) {
        for (m, &modulus) in self.moduli.iter().enumerate() {
            let start = m * self.dim;
            let end = start + self.dim;
            for c in &mut self.coeffs[start..end] {
                *c %= modulus;
            }
        }
    }

    /// Transforms in place to the NTT domain via `ctx.forward`; no-op if already there.
    pub fn to_ntt(&mut self, ctx: &NttContext) {
        if !self.is_ntt {
            debug_assert_eq!(
                self.moduli,
                ctx.moduli(),
                "NTT context moduli must match polynomial moduli"
            );
            ctx.forward(&mut self.coeffs);
            self.is_ntt = true;
        }
    }

    /// Transforms in place back to coefficients via `ctx.inverse`; no-op if not in NTT domain.
    pub fn from_ntt(&mut self, ctx: &NttContext) {
        if self.is_ntt {
            debug_assert_eq!(
                self.moduli,
                ctx.moduli(),
                "NTT context moduli must match polynomial moduli"
            );
            ctx.inverse(&mut self.coeffs);
            self.is_ntt = false;
        }
    }

    /// NTT-domain copy of `self`.
    pub fn to_ntt_new(&self, ctx: &NttContext) -> Self {
        let mut result = self.clone();
        result.to_ntt(ctx);
        result
    }

    /// Coefficient-domain copy of `self`.
    pub fn from_ntt_new(&self, ctx: &NttContext) -> Self {
        let mut result = self.clone();
        result.from_ntt(ctx);
        result
    }

    /// Returns `scalar * self`, reduced per limb; the domain flag is preserved.
    pub fn scalar_mul(&self, scalar: u64) -> Self {
        let mut result = self.clone();
        result.scalar_mul_assign(scalar);
        result
    }

    /// Multiplies every residue by `scalar` in place.
    ///
    /// The product is formed in `u128`, so this is valid for any 64-bit modulus.
    pub fn scalar_mul_assign(&mut self, scalar: u64) {
        for (m, &modulus) in self.moduli.iter().enumerate() {
            let scalar_mod = scalar % modulus;
            let start = m * self.dim;
            let end = start + self.dim;
            for c in &mut self.coeffs[start..end] {
                *c = ((*c as u128 * scalar_mod as u128) % modulus as u128) as u64;
            }
        }
    }

    /// Scalar multiplication by a [`ModQ`] value.
    pub fn scalar_mul_modq(&self, scalar: ModQ) -> Self {
        self.scalar_mul(scalar.value())
    }

    /// Ring product computed through the NTT; the result is returned in coefficient domain.
    ///
    /// Either input may already be in NTT domain.
    ///
    /// # Panics
    /// If moduli or dimensions differ.
    pub fn mul_ntt(&self, other: &Self, ctx: &NttContext) -> Self {
        assert_eq!(self.moduli, other.moduli, "Moduli must match");
        assert_eq!(
            self.coeffs.len(),
            other.coeffs.len(),
            "Dimensions must match"
        );
        let mut a = self.clone();
        let mut b = other.clone();
        a.to_ntt(ctx);
        b.to_ntt(ctx);
        let mut result = vec![0u64; self.coeffs.len()];
        ctx.pointwise_mul(&a.coeffs, &b.coeffs, &mut result);
        let mut poly = Self {
            coeffs: result,
            moduli: self.moduli.clone(),
            q: self.q,
            dim: self.dim,
            crt_q0_inv_mod_q1: self.crt_q0_inv_mod_q1,
            is_ntt: true,
        };
        poly.from_ntt(ctx);
        poly
    }

    /// Pointwise product of two NTT-domain polynomials; the result stays in NTT domain.
    ///
    /// # Panics
    /// If either input is not in NTT domain, or if moduli or dimensions differ.
    pub fn mul_ntt_domain(&self, other: &Self, ctx: &NttContext) -> Self {
        assert!(
            self.is_ntt && other.is_ntt,
            "Both polynomials must be in NTT domain"
        );
        assert_eq!(self.moduli, other.moduli, "Moduli must match");
        assert_eq!(
            self.coeffs.len(),
            other.coeffs.len(),
            "Dimensions must match"
        );
        let mut result = vec![0u64; self.coeffs.len()];
        ctx.pointwise_mul(&self.coeffs, &other.coeffs, &mut result);
        Self {
            coeffs: result,
            moduli: self.moduli.clone(),
            q: self.q,
            dim: self.dim,
            crt_q0_inv_mod_q1: self.crt_q0_inv_mod_q1,
            is_ntt: true,
        }
    }

    /// Pointwise sum of two NTT-domain polynomials.
    ///
    /// # Panics
    /// If either input is not in NTT domain, or if moduli or dimensions differ.
    pub fn add_ntt_domain(&self, other: &Self) -> Self {
        assert!(
            self.is_ntt && other.is_ntt,
            "Both polynomials must be in NTT domain"
        );
        assert_eq!(self.moduli, other.moduli, "Moduli must match");
        assert_eq!(
            self.coeffs.len(),
            other.coeffs.len(),
            "Dimensions must match"
        );
        let mut coeffs = self.coeffs.clone();
        for (m, &modulus) in self.moduli.iter().enumerate() {
            let start = m * self.dim;
            let end = start + self.dim;
            for (c, &o) in coeffs[start..end].iter_mut().zip(&other.coeffs[start..end]) {
                *c = Self::add_mod(*c, o, modulus);
            }
        }
        Self {
            coeffs,
            moduli: self.moduli.clone(),
            q: self.q,
            dim: self.dim,
            crt_q0_inv_mod_q1: self.crt_q0_inv_mod_q1,
            is_ntt: true,
        }
    }

    /// In-place `self += other` for NTT-domain polynomials.
    ///
    /// # Panics
    /// If either input is not in NTT domain, or if moduli or dimensions differ.
    pub fn add_assign_ntt_domain(&mut self, other: &Self) {
        assert!(
            self.is_ntt && other.is_ntt,
            "Both polynomials must be in NTT domain"
        );
        assert_eq!(self.moduli, other.moduli, "Moduli must match");
        assert_eq!(
            self.coeffs.len(),
            other.coeffs.len(),
            "Dimensions must match"
        );
        for (m, &modulus) in self.moduli.iter().enumerate() {
            let start = m * self.dim;
            let end = start + self.dim;
            for (c, &o) in self.coeffs[start..end]
                .iter_mut()
                .zip(&other.coeffs[start..end])
            {
                *c = Self::add_mod(*c, o, modulus);
            }
        }
    }

    /// Fused multiply-accumulate `self += a * b` with all operands in NTT domain.
    ///
    /// Each product comes from `ctx.pointwise_mul_single_at(.., m)`. That product is assumed
    /// to be already reduced below `moduli[m]` (NOTE(review): confirm in `ntt`).
    ///
    /// # Panics
    /// If any operand is not in NTT domain, or if moduli or dimensions differ.
    pub fn mul_acc_ntt_domain(&mut self, a: &Self, b: &Self, ctx: &NttContext) {
        assert!(
            self.is_ntt && a.is_ntt && b.is_ntt,
            "All polynomials must be in NTT domain"
        );
        assert_eq!(self.moduli, a.moduli, "Moduli must match");
        assert_eq!(self.moduli, b.moduli, "Moduli must match");
        assert_eq!(self.coeffs.len(), a.coeffs.len(), "Dimensions must match");
        assert_eq!(self.coeffs.len(), b.coeffs.len(), "Dimensions must match");
        for (m, &modulus) in self.moduli.iter().enumerate() {
            let start = m * self.dim;
            let end = start + self.dim;
            for i in start..end {
                let prod = ctx.pointwise_mul_single_at(a.coeffs[i], b.coeffs[i], m);
                self.coeffs[i] = Self::add_mod(self.coeffs[i], prod, modulus);
            }
        }
    }

    /// `true` when every stored residue is zero (valid in either domain).
    pub fn is_zero(&self) -> bool {
        self.coeffs.iter().all(|&c| c == 0)
    }

    /// Infinity norm of the centered representatives: values above `q / 2` count as
    /// `q - c`.
    ///
    /// # Panics
    /// In NTT domain.
    pub fn linf_norm(&self) -> u64 {
        assert!(!self.is_ntt, "Cannot compute norm in NTT domain");
        let mut max_val = 0u64;
        for i in 0..self.dim {
            let c = self.coeff(i);
            let centered = if c <= self.q / 2 { c } else { self.q - c };
            if centered > max_val {
                max_val = centered;
            }
        }
        max_val
    }

    /// Squared Euclidean norm of the centered representatives.
    ///
    /// The magnitude is computed in `u64` and squared in `u128`, so there is no signed
    /// cast that could wrap when `q > i64::MAX`.
    ///
    /// # Panics
    /// In NTT domain.
    pub fn l2_norm_squared(&self) -> u128 {
        assert!(!self.is_ntt, "Cannot compute norm in NTT domain");
        let mut sum = 0u128;
        for i in 0..self.dim {
            let c = self.coeff(i);
            let magnitude = if c <= self.q / 2 { c } else { self.q - c };
            sum += u128::from(magnitude) * u128::from(magnitude);
        }
        sum
    }

    /// Ring product via [`Poly::mul_ntt`]; builds a fresh `NttContext` on every call.
    pub fn mul(&self, other: &Self) -> Self {
        let ctx = NttContext::with_moduli(self.dim, &self.moduli);
        self.mul_ntt(other, &ctx)
    }

    /// Sum of two polynomials (same as `&a + &b`).
    pub fn add(&self, other: &Self) -> Self {
        self + other
    }

    /// Difference of two polynomials (same as `&a - &b`).
    pub fn sub(&self, other: &Self) -> Self {
        self - other
    }

    /// Additive inverse (same as `-&a`).
    pub fn negate(&self) -> Self {
        -self
    }
}
impl PartialEq for Poly {
    /// Two polynomials are equal when they share modulus, moduli, dimension and domain,
    /// and store identical residues.
    ///
    /// `crt_q0_inv_mod_q1` is derived from `moduli`, so it is not compared separately.
    fn eq(&self, other: &Self) -> bool {
        let lhs = (self.q, &self.moduli, self.dim, self.is_ntt, &self.coeffs);
        let rhs = (other.q, &other.moduli, other.dim, other.is_ntt, &other.coeffs);
        lhs == rhs
    }
}
impl Eq for Poly {}
impl Add for Poly {
    type Output = Self;
    /// Owned addition; forwards to the borrowed `&Poly + &Poly` implementation.
    fn add(self, rhs: Self) -> Self::Output {
        let (left, right) = (&self, &rhs);
        left + right
    }
}
impl Add for &Poly {
    type Output = Poly;
    /// Coefficient-wise modular addition, limb by limb; the domain flag is preserved.
    ///
    /// # Panics
    /// If moduli, NTT domains or dimensions differ.
    fn add(self, rhs: Self) -> Self::Output {
        assert_eq!(self.moduli, rhs.moduli, "Moduli must match");
        assert_eq!(self.is_ntt, rhs.is_ntt, "NTT domains must match");
        // Without this check, a mismatched `rhs` would either panic on slicing or be
        // silently read at the wrong offsets.
        assert_eq!(self.dim, rhs.dim, "Dimensions must match");
        let mut coeffs = self.coeffs.clone();
        for (m, &modulus) in self.moduli.iter().enumerate() {
            let start = m * self.dim;
            let end = start + self.dim;
            for (c, &r) in coeffs[start..end].iter_mut().zip(&rhs.coeffs[start..end]) {
                // `overflowing_add` keeps this correct when `modulus > 2^63`, where
                // `*c + r` could exceed `u64::MAX`.
                let (sum, carry) = (*c).overflowing_add(r);
                *c = if carry || sum >= modulus {
                    sum.wrapping_sub(modulus)
                } else {
                    sum
                };
            }
        }
        Poly {
            coeffs,
            moduli: self.moduli.clone(),
            q: self.q,
            dim: self.dim,
            crt_q0_inv_mod_q1: self.crt_q0_inv_mod_q1,
            is_ntt: self.is_ntt,
        }
    }
}
impl AddAssign for Poly {
    /// `self += rhs` by taking ownership of `rhs`; computes the borrowed sum, then stores it.
    fn add_assign(&mut self, rhs: Self) {
        let sum = &*self + &rhs;
        *self = sum;
    }
}
impl AddAssign<&Poly> for Poly {
    /// `self += rhs` from a borrowed right-hand side.
    fn add_assign(&mut self, rhs: &Self) {
        let sum = &*self + rhs;
        *self = sum;
    }
}
impl Sub for Poly {
    type Output = Self;
    /// Owned subtraction; forwards to the borrowed `&Poly - &Poly` implementation.
    fn sub(self, rhs: Self) -> Self::Output {
        let (left, right) = (&self, &rhs);
        left - right
    }
}
impl Sub for &Poly {
    type Output = Poly;
    /// Coefficient-wise modular subtraction, limb by limb; the domain flag is preserved.
    ///
    /// # Panics
    /// If moduli, NTT domains or dimensions differ.
    fn sub(self, rhs: Self) -> Self::Output {
        assert_eq!(self.moduli, rhs.moduli, "Moduli must match");
        assert_eq!(self.is_ntt, rhs.is_ntt, "NTT domains must match");
        // Without this check, a mismatched `rhs` would either panic on slicing or be
        // silently read at the wrong offsets.
        assert_eq!(self.dim, rhs.dim, "Dimensions must match");
        let mut coeffs = self.coeffs.clone();
        for (m, &modulus) in self.moduli.iter().enumerate() {
            let start = m * self.dim;
            let end = start + self.dim;
            for (c, &b) in coeffs[start..end].iter_mut().zip(&rhs.coeffs[start..end]) {
                let a = *c;
                // When a < b, `modulus - b` is at least 1 and adding `a` stays below
                // `modulus`, so neither step can underflow or overflow.
                *c = if a >= b { a - b } else { modulus - b + a };
            }
        }
        Poly {
            coeffs,
            moduli: self.moduli.clone(),
            q: self.q,
            dim: self.dim,
            crt_q0_inv_mod_q1: self.crt_q0_inv_mod_q1,
            is_ntt: self.is_ntt,
        }
    }
}
impl SubAssign for Poly {
    /// `self -= rhs` by taking ownership of `rhs`.
    fn sub_assign(&mut self, rhs: Self) {
        let diff = &*self - &rhs;
        *self = diff;
    }
}
impl SubAssign<&Poly> for Poly {
    /// `self -= rhs` from a borrowed right-hand side.
    fn sub_assign(&mut self, rhs: &Self) {
        let diff = &*self - rhs;
        *self = diff;
    }
}
impl Neg for Poly {
    type Output = Self;
    /// Owned negation; forwards to the borrowed `-&Poly` implementation.
    fn neg(self) -> Self::Output {
        let borrowed = &self;
        -borrowed
    }
}
impl Neg for &Poly {
    type Output = Poly;
    /// Additive inverse, limb by limb: zero residues stay zero and any other residue `c`
    /// becomes `modulus - c`. The domain flag is preserved.
    fn neg(self) -> Self::Output {
        let dim = self.dim;
        let mut negated = self.coeffs.clone();
        for (limb, &modulus) in self.moduli.iter().enumerate() {
            let offset = limb * dim;
            negated[offset..offset + dim]
                .iter_mut()
                .filter(|v| **v != 0)
                .for_each(|v| *v = modulus - *v);
        }
        Poly {
            coeffs: negated,
            moduli: self.moduli.clone(),
            q: self.q,
            dim,
            crt_q0_inv_mod_q1: self.crt_q0_inv_mod_q1,
            is_ntt: self.is_ntt,
        }
    }
}
impl Mul for Poly {
type Output = Self;
fn mul(self, rhs: Self) -> Self::Output {
assert_eq!(self.is_ntt, rhs.is_ntt, "NTT domains must match");
assert!(
self.is_ntt,
"Use mul_ntt for coefficient domain multiplication"
);
let ctx = NttContext::with_moduli(self.dim, &self.moduli);
self.mul_ntt_domain(&rhs, &ctx)
}
}
impl MulAssign for Poly {
fn mul_assign(&mut self, rhs: Self) {
*self = self.clone() * rhs;
}
}
#[cfg(test)]
mod tests {
    // Unit tests for `Poly`: construction, ring arithmetic, NTT round trips and norms.
    use super::*;

    /// NTT context over `DEFAULT_Q` for dimension `n`.
    fn make_ctx(n: usize) -> NttContext {
        NttContext::with_default_q(n)
    }

    #[test]
    fn test_zero_polynomial() {
        let p = Poly::zero_default(256);
        assert!(p.is_zero());
        assert_eq!(p.dimension(), 256);
    }

    #[test]
    fn test_constant_polynomial() {
        let p = Poly::constant(42, 256, DEFAULT_Q);
        assert_eq!(p.coeff(0), 42);
        assert!(p.coeffs()[1..].iter().all(|&c| c == 0));
    }

    #[test]
    fn test_addition() {
        let a = Poly::from_coeffs_default(vec![1, 2, 3, 4]);
        let b = Poly::from_coeffs_default(vec![5, 6, 7, 8]);
        let c = &a + &b;
        assert_eq!(c.coeff(0), 6);
        assert_eq!(c.coeff(1), 8);
        assert_eq!(c.coeff(2), 10);
        assert_eq!(c.coeff(3), 12);
    }

    #[test]
    fn test_subtraction() {
        let q = DEFAULT_Q;
        let a = Poly::from_coeffs(vec![10, 20, 30, 40], q);
        let b = Poly::from_coeffs(vec![5, 6, 7, 8], q);
        let c = &a - &b;
        assert_eq!(c.coeff(0), 5);
        assert_eq!(c.coeff(1), 14);
        assert_eq!(c.coeff(2), 23);
        assert_eq!(c.coeff(3), 32);
    }

    #[test]
    fn test_subtraction_underflow() {
        let q = DEFAULT_Q;
        let a = Poly::from_coeffs(vec![5, 6, 7, 8], q);
        let b = Poly::from_coeffs(vec![10, 20, 30, 40], q);
        let c = &a - &b;
        assert_eq!(c.coeff(0), q - 5);
        assert_eq!(c.coeff(1), q - 14);
    }

    #[test]
    fn test_negation() {
        let q = DEFAULT_Q;
        let a = Poly::from_coeffs(vec![1, 2, 3, 0], q);
        let neg_a = -&a;
        assert_eq!(neg_a.coeff(0), q - 1);
        assert_eq!(neg_a.coeff(1), q - 2);
        assert_eq!(neg_a.coeff(2), q - 3);
        assert_eq!(neg_a.coeff(3), 0);
        let sum = &a + &neg_a;
        assert!(sum.is_zero());
    }

    #[test]
    fn test_scalar_multiplication() {
        let a = Poly::from_coeffs_default(vec![1, 2, 3, 4]);
        let b = a.scalar_mul(10);
        assert_eq!(b.coeff(0), 10);
        assert_eq!(b.coeff(1), 20);
        assert_eq!(b.coeff(2), 30);
        assert_eq!(b.coeff(3), 40);
    }

    #[test]
    fn test_ntt_roundtrip() {
        let ctx = make_ctx(256);
        let mut p = Poly::from_coeffs_default((0..256).collect());
        let original = p.clone();
        p.to_ntt(&ctx);
        assert!(p.is_ntt());
        p.from_ntt(&ctx);
        assert!(!p.is_ntt());
        assert_eq!(p, original);
    }

    #[test]
    fn test_poly_mul_ntt_identity() {
        let n = 256;
        let ctx = make_ctx(n);
        let a = Poly::from_coeffs_default((0..n as u64).collect());
        let one = Poly::constant(1, n, DEFAULT_Q);
        let result = a.mul_ntt(&one, &ctx);
        assert_eq!(result, a);
    }

    #[test]
    fn test_poly_mul_ntt_zero() {
        let n = 256;
        let ctx = make_ctx(n);
        let a = Poly::from_coeffs_default((0..n as u64).collect());
        let zero = Poly::zero_default(n);
        let result = a.mul_ntt(&zero, &ctx);
        assert!(result.is_zero());
    }

    #[test]
    fn test_poly_mul_ntt_simple() {
        let n = 256;
        let ctx = make_ctx(n);
        let q = DEFAULT_Q;
        let mut coeffs = vec![0u64; n];
        coeffs[0] = 1;
        coeffs[1] = 1;
        let a = Poly::from_coeffs(coeffs, q);
        // (1 + x)^2 = 1 + 2x + x^2
        let result = a.mul_ntt(&a, &ctx);
        assert_eq!(result.coeff(0), 1);
        assert_eq!(result.coeff(1), 2);
        assert_eq!(result.coeff(2), 1);
        assert!(result.coeffs()[3..].iter().all(|&c| c == 0));
    }

    #[test]
    fn test_poly_mul_ntt_negacyclic() {
        let n = 256;
        let ctx = make_ctx(n);
        let q = DEFAULT_Q;
        let mut a_coeffs = vec![0u64; n];
        a_coeffs[1] = 1; // a = x
        let a = Poly::from_coeffs(a_coeffs, q);
        let mut b_coeffs = vec![0u64; n];
        b_coeffs[n - 1] = 1; // b = x^(n-1)
        let b = Poly::from_coeffs(b_coeffs, q);
        let result = a.mul_ntt(&b, &ctx);
        // x * x^(n-1) = x^n = -1 in the negacyclic ring.
        assert_eq!(result.coeff(0), q - 1);
        assert!(result.coeffs()[1..].iter().all(|&c| c == 0));
    }

    #[test]
    fn test_poly_mul_associativity() {
        let n = 256;
        let ctx = make_ctx(n);
        let q = DEFAULT_Q;
        let a = Poly::from_coeffs((0..n as u64).map(|i| i % 100).collect(), q);
        let b = Poly::from_coeffs((0..n as u64).map(|i| (i * 7) % 100).collect(), q);
        let c = Poly::from_coeffs((0..n as u64).map(|i| (i * 13) % 100).collect(), q);
        let ab = a.mul_ntt(&b, &ctx);
        let ab_c = ab.mul_ntt(&c, &ctx);
        let bc = b.mul_ntt(&c, &ctx);
        let a_bc = a.mul_ntt(&bc, &ctx);
        assert_eq!(ab_c, a_bc);
    }

    #[test]
    fn test_poly_mul_commutativity() {
        let n = 256;
        let ctx = make_ctx(n);
        let q = DEFAULT_Q;
        let a = Poly::from_coeffs((0..n as u64).map(|i| i % 100).collect(), q);
        let b = Poly::from_coeffs((0..n as u64).map(|i| (i * 7) % 100).collect(), q);
        let ab = a.mul_ntt(&b, &ctx);
        let ba = b.mul_ntt(&a, &ctx);
        assert_eq!(ab, ba);
    }

    #[test]
    fn test_poly_mul_distributivity() {
        let n = 256;
        let ctx = make_ctx(n);
        let q = DEFAULT_Q;
        let a = Poly::from_coeffs((0..n as u64).map(|i| i % 50).collect(), q);
        let b = Poly::from_coeffs((0..n as u64).map(|i| (i * 3) % 50).collect(), q);
        let c = Poly::from_coeffs((0..n as u64).map(|i| (i * 5) % 50).collect(), q);
        let b_plus_c = &b + &c;
        let left = a.mul_ntt(&b_plus_c, &ctx);
        let ab = a.mul_ntt(&b, &ctx);
        let ac = a.mul_ntt(&c, &ctx);
        let right = &ab + &ac;
        assert_eq!(left, right);
    }

    #[test]
    fn test_linf_norm() {
        let q = DEFAULT_Q;
        let mut coeffs = vec![0u64; 16];
        coeffs[0] = 100;
        coeffs[1] = q - 50; // centered value -50
        let p = Poly::from_coeffs(coeffs, q);
        assert_eq!(p.linf_norm(), 100);
    }

    #[test]
    fn test_l2_norm() {
        let q = DEFAULT_Q;
        let mut coeffs = vec![0u64; 4];
        coeffs[0] = 3;
        coeffs[1] = 4;
        let p = Poly::from_coeffs(coeffs, q);
        assert_eq!(p.l2_norm_squared(), 25);
    }

    #[test]
    fn test_ntt_domain_multiplication() {
        let n = 256;
        let ctx = make_ctx(n);
        let q = DEFAULT_Q;
        let a = Poly::from_coeffs((0..n as u64).map(|i| i % 100).collect(), q);
        let b = Poly::from_coeffs((0..n as u64).map(|i| (i * 7) % 100).collect(), q);
        let result1 = a.mul_ntt(&b, &ctx);
        let a_ntt = a.to_ntt_new(&ctx);
        let b_ntt = b.to_ntt_new(&ctx);
        let mut result2 = a_ntt.mul_ntt_domain(&b_ntt, &ctx);
        result2.from_ntt(&ctx);
        assert_eq!(result1, result2);
    }
}