use crate::{
integer::Z,
integer_mod_q::NTTPolynomialRingZq,
macros::arithmetics::{
arithmetic_assign_trait_borrowed_to_owned, arithmetic_trait_borrowed_to_owned,
arithmetic_trait_mixed_borrowed_owned,
},
traits::CompareBase,
};
use flint_sys::fmpz_mod::fmpz_mod_add;
use std::ops::{Add, AddAssign};
impl Add for &NTTPolynomialRingZq {
    type Output = NTTPolynomialRingZq;

    /// Adds two [`NTTPolynomialRingZq`] entry-wise, reducing every entry modulo `q`.
    ///
    /// Parameters:
    /// - `other`: specifies the value to add to `self`
    ///
    /// Returns a fresh [`NTTPolynomialRingZq`] that carries a clone of `self`'s modulus and
    /// whose `i`-th entry is `self.poly[i] + other.poly[i] mod q`.
    ///
    /// # Panics ...
    /// - if `self` and `other` do not share the same base according to
    ///   [`CompareBase::compare_base`].
    /// - if `other.poly` holds fewer entries than `self.poly` (out-of-bounds indexing).
    fn add(self, other: Self) -> Self::Output {
        assert!(
            self.compare_base(other),
            "{}",
            self.call_compare_base_error(other).unwrap()
        );

        // Build the FLINT context for `q` once and reuse it for every entry.
        let q = self.modulus.get_q_as_modulus();
        let ctx = q.get_fmpz_mod_ctx_struct();

        let mut sum = NTTPolynomialRingZq {
            poly: vec![Z::default(); self.poly.len()],
            modulus: self.modulus.clone(),
        };

        for (idx, (dst, lhs)) in sum.poly.iter_mut().zip(&self.poly).enumerate() {
            // Indexing (rather than zipping) `other` keeps the panic on a shorter `other`.
            let rhs = &other.poly[idx];
            // SAFETY: `dst`, `lhs` and `rhs` are valid, initialised `fmpz` values owned by
            // live `Z` instances, and `ctx` borrows from `q`, which outlives this loop.
            // NOTE(review): FLINT's `fmpz_mod` routines expect reduced inputs in `[0, q)`;
            // presumably guaranteed by how `NTTPolynomialRingZq` values are constructed.
            unsafe { fmpz_mod_add(&mut dst.value, &lhs.value, &rhs.value, ctx) };
        }

        sum
    }
}
// Owned `a + b`; judging by the macro name it forwards to the `&a + &b` implementation above.
arithmetic_trait_borrowed_to_owned!(
    Add, add,
    NTTPolynomialRingZq, NTTPolynomialRingZq, NTTPolynomialRingZq
);
// Mixed-ownership additions (e.g. `&a + b`, as used in the tests), derived by the macro.
arithmetic_trait_mixed_borrowed_owned!(
    Add, add,
    NTTPolynomialRingZq, NTTPolynomialRingZq, NTTPolynomialRingZq
);
impl AddAssign<&NTTPolynomialRingZq> for NTTPolynomialRingZq {
    /// Adds `other` to `self` in place, entry-wise and modulo `q`.
    ///
    /// Parameters:
    /// - `other`: specifies the value to add to `self`
    ///
    /// After the call, the `i`-th entry of `self` is `self.poly[i] + other.poly[i] mod q`;
    /// the modulus of `self` is left untouched.
    ///
    /// # Panics ...
    /// - if `self` and `other` do not share the same base according to
    ///   [`CompareBase::compare_base`].
    /// - if `other.poly` holds fewer entries than `self.poly` (out-of-bounds indexing).
    fn add_assign(&mut self, other: &Self) {
        assert!(
            self.compare_base(other),
            "{}",
            self.call_compare_base_error(other).unwrap()
        );

        // Build the FLINT context for `q` once and reuse it for every entry.
        let q = self.modulus.get_q_as_modulus();
        let ctx = q.get_fmpz_mod_ctx_struct();

        for (idx, entry) in self.poly.iter_mut().enumerate() {
            // SAFETY: `entry` and `other.poly[idx]` are valid, initialised `fmpz` values owned
            // by live `Z` instances; FLINT permits the output to alias the first input, and
            // `ctx` borrows from `q`, which outlives this loop.
            unsafe {
                fmpz_mod_add(
                    &mut entry.value,
                    &entry.value,
                    &other.poly[idx].value,
                    ctx,
                )
            };
        }
    }
}
// `a += b` with an owned right-hand side (used as `ntt1.add_assign(ntt2)` and `a += b` in the
// tests); judging by the macro name it forwards to the `&NTTPolynomialRingZq` impl above.
arithmetic_assign_trait_borrowed_to_owned!(
    AddAssign, add_assign,
    NTTPolynomialRingZq, NTTPolynomialRingZq
);
#[cfg(test)]
mod test_add {
    use crate::{
        integer_mod_q::{
            ModulusPolynomialRingZq, NTTPolynomialRingZq, PolyOverZq, PolynomialRingZq,
        },
        traits::SetCoefficient,
    };
    use std::{ops::Add, str::FromStr};

    /// Builds `Z_q[X]/(X^256 + 1)` with `q = 2^23 - 2^13 + 1` (Dilithium parameters) and
    /// enables the NTT with `1753` via `set_ntt_unchecked`.
    fn dilithium_modulus() -> ModulusPolynomialRingZq {
        let n = 256;
        let modulus = 2_i64.pow(23) - 2_i64.pow(13) + 1;
        let mut mod_poly = PolyOverZq::from(modulus);
        mod_poly.set_coeff(0, 1).unwrap();
        mod_poly.set_coeff(n, 1).unwrap();
        let mut polynomial_modulus = ModulusPolynomialRingZq::from(&mod_poly);
        polynomial_modulus.set_ntt_unchecked(1753);
        polynomial_modulus
    }

    /// Builds `Z_q[X]/(X^1024 + 1)` with `q = 12289` (Hawk-1024 parameters) and enables the
    /// NTT with `1945` via `set_ntt_unchecked`.
    fn hawk1024_modulus() -> ModulusPolynomialRingZq {
        let n = 1024;
        let modulus = 12289;
        let mut mod_poly = PolyOverZq::from(modulus);
        mod_poly.set_coeff(0, 1).unwrap();
        mod_poly.set_coeff(n, 1).unwrap();
        let mut polynomial_modulus = ModulusPolynomialRingZq::from(&mod_poly);
        polynomial_modulus.set_ntt_unchecked(1945);
        polynomial_modulus
    }

    /// Ensures that borrowed + owned addition in NTT form matches addition in coefficient form.
    #[test]
    fn test_dilithium_params() {
        let polynomial_modulus = dilithium_modulus();
        let p1 = PolynomialRingZq::sample_uniform(&polynomial_modulus);
        let p2 = PolynomialRingZq::sample_uniform(&polynomial_modulus);
        let ntt1 = NTTPolynomialRingZq::from(&p1);
        let ntt2 = NTTPolynomialRingZq::from(&p2);
        let res = (&ntt1).add(ntt2);
        assert_eq!(&p1 + &p2, PolynomialRingZq::from(res))
    }

    /// Ensures that owned + owned addition in NTT form matches addition in coefficient form.
    #[test]
    fn test_hawk1024_params() {
        let polynomial_modulus = hawk1024_modulus();
        let p1 = PolynomialRingZq::sample_uniform(&polynomial_modulus);
        let p2 = PolynomialRingZq::sample_uniform(&polynomial_modulus);
        let ntt1 = NTTPolynomialRingZq::from(&p1);
        let ntt2 = NTTPolynomialRingZq::from(&p2);
        let res = ntt1.add(ntt2);
        assert_eq!(&p1 + &p2, PolynomialRingZq::from(res))
    }

    /// Ensures that every ownership combination of the operands (`&a + &b`, `a + b`,
    /// `&a + b`, `a + &b`) is available and yields the same, correct sum.
    #[test]
    fn all_ownership_combinations() {
        let polynomial_modulus = dilithium_modulus();
        let p1 = PolynomialRingZq::sample_uniform(&polynomial_modulus);
        let p2 = PolynomialRingZq::sample_uniform(&polynomial_modulus);
        let expected = &p1 + &p2;

        let borrowed_borrowed =
            &NTTPolynomialRingZq::from(&p1) + &NTTPolynomialRingZq::from(&p2);
        let owned_owned = NTTPolynomialRingZq::from(&p1) + NTTPolynomialRingZq::from(&p2);
        let borrowed_owned = &NTTPolynomialRingZq::from(&p1) + NTTPolynomialRingZq::from(&p2);
        let owned_borrowed = NTTPolynomialRingZq::from(&p1) + &NTTPolynomialRingZq::from(&p2);

        assert_eq!(expected, PolynomialRingZq::from(borrowed_borrowed));
        assert_eq!(expected, PolynomialRingZq::from(owned_owned));
        assert_eq!(expected, PolynomialRingZq::from(borrowed_owned));
        assert_eq!(expected, PolynomialRingZq::from(owned_borrowed));
    }

    /// Ensures that adding values defined over different moduli panics.
    #[test]
    #[should_panic]
    fn different_moduli() {
        let mut modulus0 = ModulusPolynomialRingZq::from_str("5 1 0 0 0 1 mod 257").unwrap();
        modulus0.set_ntt_unchecked(64);
        let mut modulus1 = ModulusPolynomialRingZq::from_str("6 1 0 0 0 0 1 mod 257").unwrap();
        modulus1.set_ntt_unchecked(64);
        let a = NTTPolynomialRingZq::sample_uniform(&modulus0);
        let b = NTTPolynomialRingZq::sample_uniform(&modulus1);
        let _ = a + b;
    }
}
#[cfg(test)]
mod test_add_assign {
    use crate::{
        integer_mod_q::{
            ModulusPolynomialRingZq, NTTPolynomialRingZq, PolyOverZq, PolynomialRingZq,
        },
        traits::SetCoefficient,
    };
    use std::{ops::AddAssign, str::FromStr};

    /// Builds `Z_q[X]/(X^256 + 1)` with `q = 2^23 - 2^13 + 1` (Dilithium parameters) and
    /// enables the NTT with `1753` via `set_ntt_unchecked`.
    fn dilithium_ring() -> ModulusPolynomialRingZq {
        let n = 256;
        let modulus = 2_i64.pow(23) - 2_i64.pow(13) + 1;
        let mut mod_poly = PolyOverZq::from(modulus);
        mod_poly.set_coeff(0, 1).unwrap();
        mod_poly.set_coeff(n, 1).unwrap();
        let mut ring = ModulusPolynomialRingZq::from(&mod_poly);
        ring.set_ntt_unchecked(1753);
        ring
    }

    /// Builds `Z_q[X]/(X^1024 + 1)` with `q = 12289` (Hawk-1024 parameters) and enables the
    /// NTT with `1945` via `set_ntt_unchecked`.
    fn hawk1024_ring() -> ModulusPolynomialRingZq {
        let n = 1024;
        let modulus = 12289;
        let mut mod_poly = PolyOverZq::from(modulus);
        mod_poly.set_coeff(0, 1).unwrap();
        mod_poly.set_coeff(n, 1).unwrap();
        let mut ring = ModulusPolynomialRingZq::from(&mod_poly);
        ring.set_ntt_unchecked(1945);
        ring
    }

    /// Ensures `+=` with an owned right-hand side matches coefficient-form addition.
    #[test]
    fn test_dilithium_params() {
        let ring = dilithium_ring();
        let (lhs, rhs) = (
            PolynomialRingZq::sample_uniform(&ring),
            PolynomialRingZq::sample_uniform(&ring),
        );

        let mut acc = NTTPolynomialRingZq::from(&lhs);
        acc.add_assign(NTTPolynomialRingZq::from(&rhs));

        assert_eq!(&lhs + &rhs, PolynomialRingZq::from(acc))
    }

    /// Ensures `+=` with a borrowed right-hand side matches coefficient-form addition.
    #[test]
    fn test_hawk1024_params() {
        let ring = hawk1024_ring();
        let (lhs, rhs) = (
            PolynomialRingZq::sample_uniform(&ring),
            PolynomialRingZq::sample_uniform(&ring),
        );

        let mut acc = NTTPolynomialRingZq::from(&lhs);
        let addend = NTTPolynomialRingZq::from(&rhs);
        acc.add_assign(&addend);

        assert_eq!(&lhs + &rhs, PolynomialRingZq::from(acc))
    }

    /// Ensures that `+=` between values defined over different moduli panics.
    #[test]
    #[should_panic]
    fn different_moduli() {
        let mut ring_a = ModulusPolynomialRingZq::from_str("5 1 0 0 0 1 mod 257").unwrap();
        ring_a.set_ntt_unchecked(64);
        let mut ring_b = ModulusPolynomialRingZq::from_str("6 1 0 0 0 0 1 mod 257").unwrap();
        ring_b.set_ntt_unchecked(64);

        let mut a = NTTPolynomialRingZq::sample_uniform(&ring_a);
        let b = NTTPolynomialRingZq::sample_uniform(&ring_b);
        a += b;
    }
}