use std::mem::size_of;
use bytemuck::cast_slice_mut;
use poulpy_hal::{
api::TakeSlice,
layouts::{
Data, Module, Scratch, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, ZnxInfos,
ZnxViewMut,
},
oep::{
TakeSliceImpl, VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAddScaledInplaceImpl, VecZnxDftApplyImpl,
VecZnxDftCopyImpl, VecZnxDftSubImpl, VecZnxDftSubInplaceImpl, VecZnxDftSubNegateInplaceImpl, VecZnxDftZeroImpl,
VecZnxIdftApplyConsumeImpl, VecZnxIdftApplyImpl, VecZnxIdftApplyTmpAImpl, VecZnxIdftApplyTmpBytesImpl,
},
reference::ntt120::{
ntt::{NttTableInv, intt_ref},
primes::{PrimeSet, Primes30},
vec_znx_dft::{
NttModuleHandle, ntt120_vec_znx_dft_add, ntt120_vec_znx_dft_add_inplace, ntt120_vec_znx_dft_add_scaled_inplace,
ntt120_vec_znx_dft_apply, ntt120_vec_znx_dft_copy, ntt120_vec_znx_dft_sub, ntt120_vec_znx_dft_sub_inplace,
ntt120_vec_znx_dft_sub_negate_inplace, ntt120_vec_znx_dft_zero, ntt120_vec_znx_idft_apply,
ntt120_vec_znx_idft_apply_tmp_bytes, ntt120_vec_znx_idft_apply_tmpa,
},
},
};
use crate::NTT120Ref;
/// Applies the inverse NTT to every block of a `u64` buffer and compacts each block,
/// in place, from four `u64` residues per coefficient to one `i128` per coefficient.
///
/// Block `k` is read from the `u64` range `[4 * n * k, 4 * n * (k + 1))`; its `n`
/// reconstructed values are written to the `u64` range `[2 * n * k, 2 * n * (k + 1))`.
///
/// Each coefficient is rebuilt as
/// `sum_i ((x_i mod q_i) * CRT_CST[i] mod q_i) * (Q / q_i)` reduced modulo
/// `Q = q_0 * q_1 * q_2 * q_3`, and then mapped to a centered representative:
/// values `>= (Q + 1) / 2` have `Q` subtracted.
///
/// NOTE(review): `Primes30::CRT_CST` is presumably the per-prime CRT reconstruction
/// constant; its definition is not visible in this file.
///
/// # Safety
///
/// `u64_ptr` must be valid for reads and writes of `4 * n * n_blocks` consecutive
/// `u64` values, and that memory must not be accessed through any other reference
/// for the duration of the call.
unsafe fn compact_all_blocks(n: usize, n_blocks: usize, u64_ptr: *mut u64, table: &NttTableInv<Primes30>) {
    let moduli: [i128; 4] = Primes30::Q.map(|qi| qi as i128);
    let big_q: i128 = moduli[0] * moduli[1] * moduli[2] * moduli[3];
    // cofactors[i] is the product of every modulus except moduli[i].
    let cofactors: [i128; 4] = [
        moduli[1] * moduli[2] * moduli[3],
        moduli[0] * moduli[2] * moduli[3],
        moduli[0] * moduli[1] * moduli[3],
        moduli[0] * moduli[1] * moduli[2],
    ];
    let crt_consts: [i128; 4] = Primes30::CRT_CST.map(|c| c as i128);
    let centering_threshold: i128 = (big_q + 1) / 2;

    for block in 0..n_blocks {
        let read_base = 4 * n * block;
        let write_base = 2 * n * block;

        // SAFETY: the caller guarantees `4 * n * n_blocks` valid `u64`s behind `u64_ptr`,
        // and `read_base + 4 * n <= 4 * n * n_blocks`. The slice is only used for this call.
        intt_ref::<Primes30>(table, unsafe {
            std::slice::from_raw_parts_mut(u64_ptr.add(read_base), 4 * n)
        });

        for coeff in 0..n {
            let at = read_base + 4 * coeff;
            // SAFETY: `at + 3 < read_base + 4 * n`, which is inside the caller-provided buffer.
            let residues: [u64; 4] = unsafe {
                [
                    *u64_ptr.add(at),
                    *u64_ptr.add(at + 1),
                    *u64_ptr.add(at + 2),
                    *u64_ptr.add(at + 3),
                ]
            };

            let folded: i128 = (0..4)
                .map(|i| (residues[i] % Primes30::Q[i] as u64) as i128 * crt_consts[i] % moduli[i] * cofactors[i])
                .sum::<i128>()
                % big_q;

            let centered: i128 = if folded >= centering_threshold {
                folded - big_q
            } else {
                folded
            };

            // The destination never runs ahead of the residues that are still unread:
            // `write_base + 2 * coeff + 2 <= read_base + 4 * (coeff + 1)`.
            // SAFETY: the two `u64` slots at `write_base + 2 * coeff` are inside the buffer;
            // the pointer is only guaranteed `u64`-aligned, hence the unaligned write.
            unsafe {
                u64_ptr
                    .add(write_base + 2 * coeff)
                    .cast::<i128>()
                    .write_unaligned(centered)
            };
        }
    }
}
/// Scratch-size query for the inverse DFT on the `NTT120Ref` backend.
unsafe impl VecZnxIdftApplyTmpBytesImpl<Self> for NTT120Ref {
    /// Returns the value reported by [`ntt120_vec_znx_idft_apply_tmp_bytes`] for the
    /// module's `n`.
    fn vec_znx_idft_apply_tmp_bytes_impl(module: &Module<Self>) -> usize {
        let n = module.n();
        ntt120_vec_znx_idft_apply_tmp_bytes(n)
    }
}
/// Inverse DFT from a `VecZnxDft` column into a `VecZnxBig` column, using caller-provided
/// scratch space.
unsafe impl VecZnxIdftApplyImpl<Self> for NTT120Ref
where
    Self: TakeSliceImpl<Self>,
{
    /// Takes a temporary slice from `scratch` and forwards to
    /// [`ntt120_vec_znx_idft_apply`].
    ///
    /// The slice length is the byte count from [`ntt120_vec_znx_idft_apply_tmp_bytes`]
    /// divided by `size_of::<u64>()`.
    fn vec_znx_idft_apply_impl<R, A>(
        module: &Module<Self>,
        res: &mut R,
        res_col: usize,
        a: &A,
        a_col: usize,
        scratch: &mut Scratch<Self>,
    ) where
        R: VecZnxBigToMut<Self>,
        A: VecZnxDftToRef<Self>,
    {
        let tmp_bytes = ntt120_vec_znx_idft_apply_tmp_bytes(module.n());
        let tmp_len = tmp_bytes / size_of::<u64>();
        // The remainder of the scratch space is not needed here.
        let (tmp, _) = scratch.take_slice(tmp_len);
        ntt120_vec_znx_idft_apply::<R, A, Self>(module, res, res_col, a, a_col, tmp);
    }
}
/// Inverse DFT variant that takes its input mutably instead of using separate scratch space.
unsafe impl VecZnxIdftApplyTmpAImpl<Self> for NTT120Ref {
    /// Forwards all arguments unchanged to [`ntt120_vec_znx_idft_apply_tmpa`].
    ///
    /// NOTE(review): `a` is borrowed mutably, presumably so the callee can use it as
    /// temporary storage; confirm against the reference implementation.
    fn vec_znx_idft_apply_tmpa_impl<R, A>(
        module: &Module<Self>,
        res: &mut R,
        res_col: usize,
        a: &mut A,
        a_col: usize,
    ) where
        R: VecZnxBigToMut<Self>,
        A: VecZnxDftToMut<Self>,
    {
        ntt120_vec_znx_idft_apply_tmpa::<R, A, Self>(module, res, res_col, a, a_col);
    }
}
/// Inverse DFT that consumes its input and reuses the same storage for the result.
unsafe impl VecZnxIdftApplyConsumeImpl<Self> for NTT120Ref {
    /// Converts the storage of `a` in place with `compact_all_blocks`, treating it as
    /// `cols * size` blocks, then returns it via `into_big`.
    fn vec_znx_idft_apply_consume_impl<D: Data>(
        module: &Module<NTT120Ref>,
        mut a: VecZnxDft<D, NTT120Ref>,
    ) -> VecZnxBig<D, NTT120Ref>
    where
        VecZnxDft<D, NTT120Ref>: VecZnxDftToMut<NTT120Ref>,
    {
        let intt_table = module.get_intt_table();

        // The mutable view is confined to this block so that its borrow of `a` has ended
        // before `a` is consumed below; only plain values and a raw pointer leave it.
        let (coeffs_per_block, block_count, words) = {
            let mut view: VecZnxDft<&mut [u8], NTT120Ref> = a.to_mut();
            let dims = (view.n(), view.cols() * view.size());
            let words: *mut u64 = cast_slice_mut::<_, u64>(view.raw_mut()).as_mut_ptr();
            (dims.0, dims.1, words)
        };

        // SAFETY: `words` points at the storage owned by `a`, which is still alive and is
        // not otherwise accessed during the call.
        // NOTE(review): assumes that storage holds `4 * n` `u64` words per (column, limb)
        // block — confirm against the `VecZnxDft` layout for this backend.
        unsafe { compact_all_blocks(coeffs_per_block, block_count, words, intt_table) };

        a.into_big()
    }
}
/// Forward DFT from a `VecZnx` column into a `VecZnxDft` column.
unsafe impl VecZnxDftApplyImpl<Self> for NTT120Ref {
    /// Forwards to [`ntt120_vec_znx_dft_apply`]; `step` and `offset` are passed through
    /// unchanged.
    fn vec_znx_dft_apply_impl<R, A>(
        module: &Module<Self>,
        step: usize,
        offset: usize,
        res: &mut R,
        res_col: usize,
        a: &A,
        a_col: usize,
    ) where
        R: VecZnxDftToMut<Self>,
        A: VecZnxToRef,
    {
        ntt120_vec_znx_dft_apply::<R, A, Self>(module, step, offset, res, res_col, a, a_col);
    }
}
/// Three-operand addition in the DFT domain.
unsafe impl VecZnxDftAddImpl<Self> for NTT120Ref {
    /// Forwards the result column and both operand columns to
    /// [`ntt120_vec_znx_dft_add`]. The module handle is not used.
    fn vec_znx_dft_add_impl<R, A, D>(
        _module: &Module<Self>,
        res: &mut R,
        res_col: usize,
        a: &A,
        a_col: usize,
        b: &D,
        b_col: usize,
    ) where
        R: VecZnxDftToMut<Self>,
        A: VecZnxDftToRef<Self>,
        D: VecZnxDftToRef<Self>,
    {
        ntt120_vec_znx_dft_add::<R, A, D, Self>(res, res_col, a, a_col, b, b_col);
    }
}
/// In-place scaled addition in the DFT domain.
unsafe impl VecZnxDftAddScaledInplaceImpl<Self> for NTT120Ref {
    /// Forwards to [`ntt120_vec_znx_dft_add_scaled_inplace`], passing `a_scale` through
    /// unchanged. The module handle is not used.
    fn vec_znx_dft_add_scaled_inplace_impl<R, A>(
        _module: &Module<Self>,
        res: &mut R,
        res_col: usize,
        a: &A,
        a_col: usize,
        a_scale: i64,
    ) where
        R: VecZnxDftToMut<Self>,
        A: VecZnxDftToRef<Self>,
    {
        ntt120_vec_znx_dft_add_scaled_inplace::<R, A, Self>(res, res_col, a, a_col, a_scale);
    }
}
/// In-place addition in the DFT domain.
unsafe impl VecZnxDftAddInplaceImpl<Self> for NTT120Ref {
    /// Forwards to [`ntt120_vec_znx_dft_add_inplace`]. The module handle is not used.
    fn vec_znx_dft_add_inplace_impl<R, A>(
        _module: &Module<Self>,
        res: &mut R,
        res_col: usize,
        a: &A,
        a_col: usize,
    ) where
        R: VecZnxDftToMut<Self>,
        A: VecZnxDftToRef<Self>,
    {
        ntt120_vec_znx_dft_add_inplace::<R, A, Self>(res, res_col, a, a_col);
    }
}
/// Three-operand subtraction in the DFT domain.
unsafe impl VecZnxDftSubImpl<Self> for NTT120Ref {
    /// Forwards the result column and both operand columns to
    /// [`ntt120_vec_znx_dft_sub`]. The module handle is not used.
    fn vec_znx_dft_sub_impl<R, A, D>(
        _module: &Module<Self>,
        res: &mut R,
        res_col: usize,
        a: &A,
        a_col: usize,
        b: &D,
        b_col: usize,
    ) where
        R: VecZnxDftToMut<Self>,
        A: VecZnxDftToRef<Self>,
        D: VecZnxDftToRef<Self>,
    {
        ntt120_vec_znx_dft_sub::<R, A, D, Self>(res, res_col, a, a_col, b, b_col);
    }
}
/// In-place subtraction in the DFT domain.
unsafe impl VecZnxDftSubInplaceImpl<Self> for NTT120Ref {
    /// Forwards to [`ntt120_vec_znx_dft_sub_inplace`]. The module handle is not used.
    fn vec_znx_dft_sub_inplace_impl<R, A>(
        _module: &Module<Self>,
        res: &mut R,
        res_col: usize,
        a: &A,
        a_col: usize,
    ) where
        R: VecZnxDftToMut<Self>,
        A: VecZnxDftToRef<Self>,
    {
        ntt120_vec_znx_dft_sub_inplace::<R, A, Self>(res, res_col, a, a_col);
    }
}
/// In-place negated subtraction in the DFT domain.
unsafe impl VecZnxDftSubNegateInplaceImpl<Self> for NTT120Ref {
    /// Forwards to [`ntt120_vec_znx_dft_sub_negate_inplace`]. The module handle is not
    /// used.
    fn vec_znx_dft_sub_negate_inplace_impl<R, A>(
        _module: &Module<Self>,
        res: &mut R,
        res_col: usize,
        a: &A,
        a_col: usize,
    ) where
        R: VecZnxDftToMut<Self>,
        A: VecZnxDftToRef<Self>,
    {
        ntt120_vec_znx_dft_sub_negate_inplace::<R, A, Self>(res, res_col, a, a_col);
    }
}
/// Column copy between two `VecZnxDft` containers.
unsafe impl VecZnxDftCopyImpl<Self> for NTT120Ref {
    /// Forwards to [`ntt120_vec_znx_dft_copy`]; `step` and `offset` are passed through
    /// unchanged. The module handle is not used.
    fn vec_znx_dft_copy_impl<R, A>(
        _module: &Module<Self>,
        step: usize,
        offset: usize,
        res: &mut R,
        res_col: usize,
        a: &A,
        a_col: usize,
    ) where
        R: VecZnxDftToMut<Self>,
        A: VecZnxDftToRef<Self>,
    {
        ntt120_vec_znx_dft_copy::<R, A, Self>(step, offset, res, res_col, a, a_col);
    }
}
/// Zeroing of a single `VecZnxDft` column.
unsafe impl VecZnxDftZeroImpl<Self> for NTT120Ref {
    /// Forwards to [`ntt120_vec_znx_dft_zero`] for column `res_col` of `res`. The module
    /// handle is not used.
    fn vec_znx_dft_zero_impl<R>(
        _module: &Module<Self>,
        res: &mut R,
        res_col: usize,
    ) where
        R: VecZnxDftToMut<Self>,
    {
        ntt120_vec_znx_dft_zero::<R, Self>(res, res_col);
    }
}