use bytemuck::{cast_slice, cast_slice_mut};
use crate::{
layouts::{
Backend, HostDataMut, HostDataRef, ScalarZnxBackendRef, SvpPPolBackendMut, SvpPPolBackendRef, VecZnxDftBackendMut,
VecZnxDftBackendRef, ZnxView, ZnxViewMut,
},
reference::ntt120::{
NttCFromB, NttDFTExecute, NttFromZnx64, NttMulBbc, NttZero, ntt::NttTable, primes::Primes30, types::Q120bScalar,
vec_znx_dft::NttModuleHandle,
},
};
/// Prepares one scalar polynomial column for scalar-vector products (SVP).
///
/// Column `a_col` of `a` is converted into the backend's NTT120 working buffer
/// (`ntt_from_znx64`), transformed with the module's NTT table (`ntt_dft_execute`),
/// and finally written in compact form into column `res_col` of `res` (`ntt_c_from_b`).
///
/// A scratch buffer of `4 * res.n()` `u64` words is allocated on every call.
pub fn ntt120_svp_prepare<'r, 'a, BE>(
    module: &impl NttModuleHandle,
    res: &mut SvpPPolBackendMut<'r, BE>,
    res_col: usize,
    a: &ScalarZnxBackendRef<'a, BE>,
    a_col: usize,
) where
    BE: Backend<ScalarPrep = Q120bScalar>
        + NttDFTExecute<NttTable<Primes30>>
        + NttFromZnx64
        + NttCFromB,
    BE::BufMut<'r>: HostDataMut,
    BE::BufRef<'a>: HostDataRef,
{
    let ring_degree = res.n();

    // Working buffer: four u64 words per coefficient (the q120b layout used by the
    // backend kernels; presumably one word per prime residue — TODO confirm).
    let mut scratch: Vec<u64> = vec![0; ring_degree * 4];

    // Lift the integer coefficients, then move them into the NTT domain in place.
    BE::ntt_from_znx64(&mut scratch, a.at(a_col, 0));
    BE::ntt_dft_execute(module.get_ntt_table(), &mut scratch);

    // The destination is viewed as raw u32 words for the compaction kernel.
    let dst: &mut [u32] = cast_slice_mut(res.at_mut(res_col, 0));
    BE::ntt_c_from_b(ring_degree, dst, &scratch);
}
/// Multiplies a prepared scalar polynomial by a DFT-domain vector, limb by limb.
///
/// For every limb `j` that both `res` and `b` have, each coefficient of column
/// `res_col` of `res` is set to the `ntt_mul_bbc` product of the matching coefficient
/// of column `b_col` of `b` and of column `a_col` of `a`. Limbs of `res` beyond
/// `b.size()` are cleared with `ntt_zero`.
///
/// # Panics
/// Panics on out-of-range slicing if any of the column views holds fewer words than
/// `res.n()` coefficients require.
pub fn ntt120_svp_apply_dft_to_dft<'r, 'a, 'b, BE>(
    module: &impl NttModuleHandle,
    res: &mut VecZnxDftBackendMut<'r, BE>,
    res_col: usize,
    a: &SvpPPolBackendRef<'a, BE>,
    a_col: usize,
    b: &VecZnxDftBackendRef<'b, BE>,
    b_col: usize,
) where
    BE: Backend<ScalarPrep = Q120bScalar> + NttMulBbc + NttZero,
    BE::BufMut<'r>: HostDataMut,
    for<'x> BE::BufRef<'x>: HostDataRef,
{
    let meta = module.get_bbc_meta();
    let coeffs = res.n();
    let out_limbs = res.size();
    let shared_limbs = out_limbs.min(b.size());

    // The prepared operand is the same for every limb, so view it once.
    let svp_words: &[u32] = cast_slice(a.at(a_col, 0));

    for limb in 0..out_limbs {
        if limb < shared_limbs {
            let out_words: &mut [u64] = cast_slice_mut(res.at_mut(res_col, limb));
            let in_words: &[u32] = cast_slice(b.at(b_col, limb));
            for coeff in 0..coeffs {
                // Output uses 4 u64 words per coefficient; both inputs use 8 u32 words.
                let out_at = 4 * coeff;
                let in_at = 8 * coeff;
                // The literal `1` is the kernel's length/count argument; presumably a
                // single product term per coefficient — TODO confirm against NttMulBbc.
                BE::ntt_mul_bbc(
                    meta,
                    1,
                    &mut out_words[out_at..out_at + 4],
                    &in_words[in_at..in_at + 8],
                    &svp_words[in_at..in_at + 8],
                );
            }
        } else {
            // `b` has no data for this limb: the result limb is cleared.
            BE::ntt_zero(cast_slice_mut(res.at_mut(res_col, limb)));
        }
    }
}
/// In-place variant of the SVP product: multiplies every limb of column `res_col` of
/// `res` by the prepared scalar polynomial in column `a_col` of `a`.
///
/// Each coefficient is copied out before calling `ntt_mul_bbc` and the result is
/// written back afterwards, so the kernel never receives overlapping input and
/// output slices.
///
/// # Panics
/// Panics on out-of-range indexing if a limb of `res` or the prepared column holds
/// fewer entries than `res.n()` coefficients require.
pub fn ntt120_svp_apply_dft_to_dft_assign<'r, 'a, BE>(
    module: &impl NttModuleHandle,
    res: &mut VecZnxDftBackendMut<'r, BE>,
    res_col: usize,
    a: &SvpPPolBackendRef<'a, BE>,
    a_col: usize,
) where
    BE: Backend<ScalarPrep = Q120bScalar> + NttMulBbc,
    BE::BufMut<'r>: HostDataMut,
    BE::BufRef<'a>: HostDataRef,
{
    let meta = module.get_bbc_meta();
    let coeffs = res.n();
    let limbs = res.size();

    // Prepared operand: 8 u32 words per coefficient, shared by every limb.
    let svp_words: &[u32] = cast_slice(a.at(a_col, 0));

    for limb in 0..limbs {
        let row: &mut [Q120bScalar] = res.at_mut(res_col, limb);
        // Output scratch for one coefficient (4 u64 words), reset at the start of each limb.
        let mut product = [0u64; 4];
        for coeff in 0..coeffs {
            // Copy the current value so the kernel reads from a buffer distinct from `product`.
            let current: Q120bScalar = row[coeff];
            let lhs: &[u32] = cast_slice(std::slice::from_ref(&current));
            let rhs = &svp_words[8 * coeff..8 * (coeff + 1)];
            BE::ntt_mul_bbc(meta, 1, &mut product, lhs, rhs);
            row[coeff] = Q120bScalar(product);
        }
    }
}