use super::{NTTBasisPolynomialRingZq, from::ConvolutionType};
use crate::{
integer::Z,
integer_mod_q::{Modulus, PolyOverZq},
traits::GetCoefficient,
utils::index::bit_reverse_permutation,
};
use flint_sys::fmpz_mod::{fmpz_mod_add, fmpz_mod_ctx, fmpz_mod_mul, fmpz_mod_sub};
impl NTTBasisPolynomialRingZq {
    /// Computes the forward number-theoretic transform of `poly` with respect to this basis.
    ///
    /// For [`ConvolutionType::Negacyclic`], the `i`-th coefficient is first multiplied by the
    /// `i`-th stored power of psi. The resulting coefficient vector (or, for cyclic convolution,
    /// the unmodified one) is then transformed with the stored powers of omega.
    ///
    /// Parameters:
    /// - `poly`: the polynomial to transform. Its degree must be smaller than `n`, and it must
    ///   use the same modulus as this basis.
    ///
    /// Returns the `n` transformed values as a [`Vec`] of [`Z`].
    ///
    /// # Panics
    /// - if the degree of `poly` is not smaller than `n`.
    /// - if the modulus of `poly` differs from the modulus of this basis.
    pub fn ntt(&self, poly: &PolyOverZq) -> Vec<Z> {
        assert!(poly.get_degree() < self.n);
        assert_eq!(poly.get_mod(), self.modulus);

        // Collecting over `0..n` already yields exactly `n` entries, so no further padding is
        // required. Coefficients above the degree are expected to read as zero.
        // NOTE(review): confirm `get_coeff_unchecked` guarantees zero beyond the degree.
        let mut poly_coeffs: Vec<Z> = (0..self.n)
            // SAFETY: every index lies in `0..self.n` and is therefore non-negative.
            // NOTE(review): confirm this is the full contract of `get_coeff_unchecked`.
            .map(|i| unsafe { poly.get_coeff_unchecked(i) })
            .collect();

        if self.convolution_type == ConvolutionType::Negacyclic {
            let modulus_ctx = self.modulus.get_fmpz_mod_ctx_struct();
            // Slicing first keeps the panic when fewer than `n` powers of psi are stored,
            // instead of silently leaving trailing coefficients unscaled.
            let psis = &self.powers_of_psi[..poly_coeffs.len()];
            for (x, psi) in poly_coeffs.iter_mut().zip(psis) {
                // SAFETY: `x` and `psi` are live `Z` values, and `modulus_ctx` belongs to
                // `self.modulus`, which outlives this call.
                unsafe {
                    fmpz_mod_mul(&mut x.value, &x.value, &psi.value, modulus_ctx);
                }
            }
        }

        iterative_ntt(poly_coeffs, &self.powers_of_omega, &self.modulus)
    }
}
/// Applies one layer of radix-2 butterflies to a single `chunk`, in place.
///
/// For every `i` in `0..stride`, the pair `(a, b) = (chunk[i], chunk[i + stride])` is replaced
/// by `(a + w * b, a - w * b)`, where `w = powers_of_omega_pointers[2^power_pointer * i]`. All
/// arithmetic is carried out with FLINT's `fmpz_mod_*` routines under `modulus_pointer`.
///
/// Parameters:
/// - `chunk`: the slice to transform; must hold at least `2 * stride` entries.
/// - `stride`: distance between the two inputs of each butterfly.
/// - `power_pointer`: base-2 exponent of the step between consecutive twiddle indices;
///   must be non-negative.
/// - `modulus_pointer`: the FLINT modulus context used for all arithmetic.
/// - `powers_of_omega_pointers`: powers of the root of unity, indexed by exponent.
///
/// # Panics
/// - if `chunk` has fewer than `2 * stride` entries.
/// - if a twiddle index `2^power_pointer * i` is out of bounds of `powers_of_omega_pointers`.
///
/// # Safety
/// The caller must pass a valid modulus context in `modulus_pointer`. It is handed directly to
/// FLINT together with the entries of `chunk` and `powers_of_omega_pointers`.
/// NOTE(review): FLINT's `fmpz_mod_*` functions presumably also expect their operands to be
/// reduced modulo that modulus. Confirm against the FLINT documentation.
unsafe fn ntt_stride_steps(
    chunk: &mut [Z],
    stride: usize,
    power_pointer: i64,
    modulus_pointer: &fmpz_mod_ctx,
    powers_of_omega_pointers: &[Z],
) {
    debug_assert!(power_pointer >= 0, "power_pointer must be non-negative");
    // The twiddle step depends only on the layer, so compute it once rather than per butterfly.
    let twiddle_step = 2_usize.pow(power_pointer as u32);

    // Split into lower/upper halves to obtain disjoint mutable references per butterfly.
    // Re-slicing `upper` panics when the chunk is too short, as direct indexing would.
    let (lower, upper) = chunk.split_at_mut(stride);
    let upper = &mut upper[..stride];

    // Scratch value for `w * b`. It is fully overwritten by `fmpz_mod_mul` each iteration.
    let mut temp = Z::default();
    for (i, (a, b)) in lower.iter_mut().zip(upper.iter_mut()).enumerate() {
        let twiddle = &powers_of_omega_pointers[twiddle_step * i];
        // SAFETY: all operands are live `Z` values; validity of `modulus_pointer` is
        // guaranteed by the caller (see `# Safety`).
        unsafe {
            fmpz_mod_mul(&mut temp.value, &twiddle.value, &b.value, modulus_pointer);
            fmpz_mod_sub(&mut b.value, &a.value, &temp.value, modulus_pointer);
            fmpz_mod_add(&mut a.value, &a.value, &temp.value, modulus_pointer);
        }
    }
}
/// Computes the forward NTT of `coefficients` with an iterative radix-2 algorithm.
///
/// The input is first put into bit-reversed order. Then `log2(n)` layers of butterflies are
/// applied, with the butterfly stride doubling from `1` up to `n / 2`.
///
/// Parameters:
/// - `coefficients`: the values to transform; its length `n` must be a power of two.
/// - `powers_of_omega`: powers of the root of unity, indexed by exponent. The largest index
///   accessed is `n / 2 - 1`.
/// - `modulus`: the modulus all arithmetic is performed with.
///
/// Returns the transformed values.
///
/// # Panics
/// Panics if `coefficients.len()` is not a power of two (this includes the empty input).
fn iterative_ntt(coefficients: Vec<Z>, powers_of_omega: &[Z], modulus: &Modulus) -> Vec<Z> {
    let n = coefficients.len();
    // Radix-2 layers only split evenly for power-of-two lengths. Anything else previously
    // failed mid-transform with an out-of-bounds index (or inside `ilog2` for `n == 0`).
    assert!(
        n.is_power_of_two(),
        "the NTT length must be a power of two, but was {n}"
    );
    let nr_iterations = n.ilog2();

    let mut res = coefficients;
    bit_reverse_permutation(&mut res);

    let modulus_pointer = modulus.get_fmpz_mod_ctx_struct();
    for layer in 0..nr_iterations {
        let stride = 1_usize << layer;
        // Twiddle-step exponent: butterfly `i` of this layer uses omega^(2^power_pointer * i).
        let power_pointer = i64::from(nr_iterations - 1 - layer);
        // `2 * stride` divides the power-of-two `n`, so every chunk is complete.
        for chunk in res.chunks_exact_mut(2 * stride) {
            // SAFETY: `modulus_pointer` is derived from `modulus`, which outlives this call.
            // Each chunk holds exactly `2 * stride` entries.
            unsafe {
                ntt_stride_steps(
                    chunk,
                    stride,
                    power_pointer,
                    modulus_pointer,
                    powers_of_omega,
                )
            }
        }
    }
    res
}
#[cfg(test)]
mod test_ntt {
    use crate::{
        integer::Z,
        integer_mod_q::{ConvolutionType, Modulus, NTTBasisPolynomialRingZq, PolyOverZq},
    };
    use std::str::FromStr;

    /// Checks the cyclic NTT of `1 + 2X + 3X^2 + 4X^3` for `n = 4`, `q = 7681`, `omega = 3383`.
    #[test]
    fn example_34_multiplication_with_ntt() {
        let g_poly = PolyOverZq::from_str("4 1 2 3 4 mod 7681").unwrap();
        let modulus = Modulus::from(7681);
        let ntt_basis = NTTBasisPolynomialRingZq::init(4, 3383, &modulus, ConvolutionType::Cyclic);
        let ghat = ntt_basis.ntt(&g_poly);
        let cmp_ghat = vec![Z::from(10), Z::from(913), Z::from(7679), Z::from(6764)];
        assert_eq!(cmp_ghat, ghat);
    }

    /// A constant polynomial transforms to the constant in every slot (cyclic case).
    #[test]
    fn constant_polynomial() {
        let g_poly = PolyOverZq::from_str("1 5 mod 7681").unwrap();
        let modulus = Modulus::from(7681);
        let ntt_basis = NTTBasisPolynomialRingZq::init(4, 3383, &modulus, ConvolutionType::Cyclic);
        let ghat = ntt_basis.ntt(&g_poly);
        let cmp_ghat = vec![Z::from(5), Z::from(5), Z::from(5), Z::from(5)];
        assert_eq!(cmp_ghat, ghat);
    }

    /// Ensures that a polynomial whose degree is not smaller than `n` is rejected.
    #[test]
    #[should_panic]
    fn degree_too_high() {
        let g_poly = PolyOverZq::from_str("5 1 2 3 4 5 mod 7681").unwrap();
        let modulus = Modulus::from(7681);
        let ntt_basis = NTTBasisPolynomialRingZq::init(4, 3383, &modulus, ConvolutionType::Cyclic);
        let _ = ntt_basis.ntt(&g_poly);
    }

    /// Ensures that a polynomial whose modulus differs from the basis' modulus is rejected.
    ///
    /// The basis uses the same parameters as `example_34_multiplication_with_ntt`, which are
    /// accepted by `init`. The expected panic therefore has to come from `ntt`'s modulus check.
    #[test]
    #[should_panic]
    fn different_modulus() {
        let g_poly = PolyOverZq::from_str("4 1 2 3 4 mod 7682").unwrap();
        let modulus = Modulus::from(7681);
        let ntt_basis = NTTBasisPolynomialRingZq::init(4, 3383, &modulus, ConvolutionType::Cyclic);
        let _ = ntt_basis.ntt(&g_poly);
    }

    /// Checks the negacyclic NTT of a polynomial of degree below `n - 1`, which needs
    /// zero-padding, for `psi = 1925` (so `omega = psi^2 = 3383`).
    #[test]
    fn small_degree() {
        let g_poly = PolyOverZq::from_str("2 1 2 mod 7681").unwrap();
        let modulus = Modulus::from(7681);
        let ntt_basis =
            NTTBasisPolynomialRingZq::init(4, 1925, &modulus, ConvolutionType::Negacyclic);
        let ghat = ntt_basis.ntt(&g_poly);
        let cmp_ghat = vec![Z::from(3851), Z::from(5256), Z::from(3832), Z::from(2427)];
        assert_eq!(cmp_ghat, ghat);
    }
}