use bytemuck::{cast_slice, cast_slice_mut};
use crate::{
layouts::{
Backend, CnvPVecL, CnvPVecLToMut, CnvPVecLToRef, CnvPVecR, CnvPVecRToMut, CnvPVecRToRef, VecZnx, VecZnxBig,
VecZnxBigToMut, VecZnxDft, VecZnxDftToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut,
},
reference::ntt120::{
NttDFTExecute, NttFromZnx64,
arithmetic::{b_from_znx64_ref, c_from_b_ref},
mat_vec::{accum_mul_q120_bc, accum_to_q120b},
ntt::{NttTable, ntt_ref},
primes::{PrimeSet, Primes30},
types::Q120bScalar,
vec_znx_dft::NttModuleHandle,
},
};
/// Scratch-space size, in bytes, to pair with [`ntt120_cnv_prepare_left`].
///
/// That routine never touches its `_tmp` buffer, so the answer is `0`
/// regardless of `_n`.
pub fn ntt120_cnv_prepare_left_tmp_bytes(_n: usize) -> usize {
    const NO_SCRATCH_BYTES: usize = 0;
    NO_SCRATCH_BYTES
}
/// Prepares the left operand of a convolution from the vector `a`.
///
/// For every column of `res` and every limb present in both `res` and `a`,
/// the limb of `a` is converted with `BE::ntt_from_znx64` into the
/// corresponding limb of `res` (viewed as `u64` words) and then transformed
/// in place with `BE::ntt_dft_execute`, using the table returned by
/// `module.get_ntt_table()`. Limbs of `res` that have no counterpart in `a`
/// are set to zero.
///
/// `_tmp` is unused.
pub fn ntt120_cnv_prepare_left<R, A, BE>(module: &impl NttModuleHandle, res: &mut R, a: &A, _tmp: &mut [u8])
where
    BE: Backend<ScalarPrep = Q120bScalar> + NttFromZnx64 + NttDFTExecute<NttTable<Primes30>>,
    R: CnvPVecLToMut<BE>,
    A: VecZnxToRef,
{
    let mut dst: CnvPVecL<&mut [u8], BE> = res.to_mut();
    let src: VecZnx<&[u8]> = a.to_ref();
    let ntt_table = module.get_ntt_table();

    let limbs_out = dst.size();
    // Only limbs that exist on both sides carry data.
    let limbs_in = limbs_out.min(src.size());

    for col in 0..dst.cols() {
        for limb in 0..limbs_out {
            let words: &mut [u64] = cast_slice_mut(dst.at_mut(col, limb));
            if limb < limbs_in {
                BE::ntt_from_znx64(words, src.at(col, limb));
                BE::ntt_dft_execute(ntt_table, words);
            } else {
                // No source limb: the prepared limb is all zeros.
                words.fill(0);
            }
        }
    }
}
/// Scratch-space size, in bytes, to pair with [`ntt120_cnv_prepare_right`].
///
/// The result is the byte size of `4 * n` values of type `u64`.
/// NOTE(review): `ntt120_cnv_prepare_right` takes its scratch as `&mut [u64]`,
/// so callers presumably need a buffer of `4 * n` words — confirm at call sites.
pub fn ntt120_cnv_prepare_right_tmp_bytes(n: usize) -> usize {
    let scratch_words = 4 * n;
    scratch_words * size_of::<u64>()
}
/// Prepares the right operand of a convolution from the vector `a`.
///
/// For every column of `res` and every limb present in both `res` and `a`:
/// 1. `b_from_znx64_ref` writes the limb of `a` into the scratch buffer `tmp`,
/// 2. `ntt_ref` transforms `tmp` in place with the table from
///    `module.get_ntt_table()`,
/// 3. `c_from_b_ref` converts `tmp` into the limb of `res`, viewed as `u32`
///    words.
///
/// Limbs of `res` that have no counterpart in `a` are set to zero.
///
/// `tmp` is overwritten on every processed limb.
pub fn ntt120_cnv_prepare_right<R, A, BE>(module: &impl NttModuleHandle, res: &mut R, a: &A, tmp: &mut [u64])
where
    BE: Backend<ScalarPrep = Q120bScalar>,
    R: CnvPVecRToMut<BE>,
    A: VecZnxToRef,
{
    let mut dst: CnvPVecR<&mut [u8], BE> = res.to_mut();
    let src: VecZnx<&[u8]> = a.to_ref();
    let ntt_table = module.get_ntt_table();

    let n = dst.n();
    let limbs_out = dst.size();
    // Only limbs that exist on both sides carry data.
    let limbs_in = limbs_out.min(src.size());

    for col in 0..dst.cols() {
        for limb in 0..limbs_out {
            if limb < limbs_in {
                b_from_znx64_ref::<Primes30>(n, tmp, src.at(col, limb));
                ntt_ref(ntt_table, tmp);
                let packed: &mut [u32] = cast_slice_mut(dst.at_mut(col, limb));
                c_from_b_ref::<Primes30>(n, packed, tmp);
            } else {
                // No source limb: the prepared limb is all zeros.
                cast_slice_mut::<_, u32>(dst.at_mut(col, limb)).fill(0);
            }
        }
    }
}
/// Scratch-space size, in bytes, to pair with [`ntt120_cnv_apply_dft`].
///
/// That routine never touches its `_tmp` buffer, so the answer is `0`
/// whatever the operand sizes are.
pub fn ntt120_cnv_apply_dft_tmp_bytes(_res_size: usize, _a_size: usize, _b_size: usize) -> usize {
    const NO_SCRATCH_BYTES: usize = 0;
    NO_SCRATCH_BYTES
}
/// Limb-wise convolution of a prepared left operand `a` (column `a_col`) with a
/// prepared right operand `b` (column `b_col`), written to column `res_col`
/// of `res`.
///
/// Let `bound = a.size() + b.size() - 1` and `offset = min(res_offset, bound)`.
/// For every output limb `k < min(res.size(), bound)` and every coefficient,
/// the products of limb `k + offset - j` of `a` with limb `j` of `b` are
/// accumulated with `accum_mul_q120_bc` over all `j` for which both limbs
/// exist; the accumulator is then converted with `accum_to_q120b` and stored
/// as four `u64` words. All remaining limbs of the result column are zeroed.
///
/// If either operand has no limbs the result column is zeroed entirely.
///
/// `_tmp` is unused.
///
/// # Panics
///
/// Panics if a limb slice is shorter than the `8 * n` (`u32`, operands) or
/// `4 * n` (`u64`, result) words indexed below, where `n = res.n()`.
// The signature mirrors the other convolution entry points.
#[allow(clippy::too_many_arguments)]
pub fn ntt120_cnv_apply_dft<R, A, B, BE>(
    module: &impl NttModuleHandle,
    res: &mut R,
    res_offset: usize,
    res_col: usize,
    a: &A,
    a_col: usize,
    b: &B,
    b_col: usize,
    _tmp: &mut [u8],
) where
    BE: Backend<ScalarPrep = Q120bScalar>,
    R: VecZnxDftToMut<BE>,
    A: CnvPVecLToRef<BE>,
    B: CnvPVecRToRef<BE>,
{
    let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
    let a: CnvPVecL<&[u8], BE> = a.to_ref();
    let b: CnvPVecR<&[u8], BE> = b.to_ref();

    let res_size = res.size();
    let a_size = a.size();
    let b_size = b.size();

    // A convolution with an empty operand is zero. Handling it here also keeps
    // `a_size + b_size - 1` and `a_size - 1` below from underflowing `usize`.
    if a_size == 0 || b_size == 0 {
        for limb in 0..res_size {
            cast_slice_mut::<_, u64>(res.at_mut(res_col, limb)).fill(0);
        }
        return;
    }

    let meta = module.get_bbc_meta();
    let n = res.n();

    // Number of limbs of the full (untruncated) convolution.
    let bound = a_size + b_size - 1;
    let min_size = res_size.min(bound);
    let offset = res_offset.min(bound);

    for k in 0..min_size {
        // Index of this output limb inside the full convolution.
        let k_abs = k + offset;
        // Range of `j` such that `k_abs - j` is a valid limb of `a` and `j` is
        // a valid limb of `b`; it is empty when `k_abs` lies past the last limb.
        let j_min = k_abs.saturating_sub(a_size - 1);
        let j_max = (k_abs + 1).min(b_size);

        let res_u64: &mut [u64] = cast_slice_mut(res.at_mut(res_col, k));
        for n_i in 0..n {
            let mut s = [0u64; 8];
            for j in j_min..j_max {
                // The sub-slices are exactly 8 words long, so `try_into` cannot fail.
                let ai: &[u32; 8] = cast_slice::<_, u32>(a.at(a_col, k_abs - j))[8 * n_i..8 * n_i + 8]
                    .try_into()
                    .unwrap();
                let bi: &[u32; 8] = cast_slice::<_, u32>(b.at(b_col, j))[8 * n_i..8 * n_i + 8]
                    .try_into()
                    .unwrap();
                accum_mul_q120_bc(&mut s, ai, bi);
            }
            let mut r4 = [0u64; 4];
            accum_to_q120b::<Primes30>(&mut r4, &s, meta);
            res_u64[4 * n_i..4 * n_i + 4].copy_from_slice(&r4);
        }
    }

    // Limbs beyond the computed range carry no contribution.
    for j in min_size..res_size {
        cast_slice_mut::<_, u64>(res.at_mut(res_col, j)).fill(0);
    }
}
/// Scratch-space size, in bytes, to pair with [`ntt120_cnv_by_const_apply`].
///
/// That routine never touches its `_tmp` buffer, so the answer is `0`
/// whatever the operand sizes are.
pub fn ntt120_cnv_by_const_apply_tmp_bytes(_res_size: usize, _a_size: usize, _b_size: usize) -> usize {
    const NO_SCRATCH_BYTES: usize = 0;
    NO_SCRATCH_BYTES
}
/// Limb-wise convolution of column `a_col` of `a` with the constant limb
/// sequence `b`, written as `i128` values to column `res_col` of `res`.
///
/// Let `bound = a.size() + b.len() - 1` and `offset = min(res_offset, bound)`.
/// For every output limb `k < min(res.size(), bound)` and every coefficient
/// index `n_i`, the result is the sum over `j` of
/// `a[k + offset - j][n_i] * b[j]` (computed in `i128`), restricted to the
/// `j` for which both limbs exist. All remaining limbs of the result column
/// are zeroed.
///
/// If `a` has no limbs or `b` is empty the result column is zeroed entirely.
///
/// `_tmp` is unused.
///
/// # Panics
///
/// Panics if a limb of `a` holds fewer coefficients than a limb of `res`.
// The signature mirrors the other convolution entry points.
#[allow(clippy::too_many_arguments)]
pub fn ntt120_cnv_by_const_apply<R, A, BE>(
    res: &mut R,
    res_offset: usize,
    res_col: usize,
    a: &A,
    a_col: usize,
    b: &[i64],
    _tmp: &mut [u8],
) where
    BE: Backend<ScalarPrep = Q120bScalar, ScalarBig = i128>,
    R: VecZnxBigToMut<BE>,
    A: VecZnxToRef,
{
    let mut res: VecZnxBig<&mut [u8], BE> = res.to_mut();
    let a: VecZnx<&[u8]> = a.to_ref();

    let res_size = res.size();
    let a_size = a.size();
    let b_size = b.len();

    // A convolution with an empty operand is zero. Handling it here also keeps
    // `a_size + b_size - 1` and `a_size - 1` below from underflowing `usize`.
    if a_size == 0 || b_size == 0 {
        for limb in 0..res_size {
            res.at_mut(res_col, limb).fill(0i128);
        }
        return;
    }

    // Number of limbs of the full (untruncated) convolution.
    let bound = a_size + b_size - 1;
    let min_size = res_size.min(bound);
    let offset = res_offset.min(bound);

    for k in 0..min_size {
        // Index of this output limb inside the full convolution.
        let k_abs = k + offset;
        // Range of `j` such that `k_abs - j` is a valid limb of `a` and `j` is
        // a valid index into `b`; it is empty when `k_abs` lies past the end.
        let j_min = k_abs.saturating_sub(a_size - 1);
        let j_max = (k_abs + 1).min(b_size);

        let res_limb: &mut [i128] = res.at_mut(res_col, k);
        for (n_i, r) in res_limb.iter_mut().enumerate() {
            let mut acc: i128 = 0;
            for (j, &b_j) in b.iter().enumerate().take(j_max).skip(j_min) {
                acc += a.at(a_col, k_abs - j)[n_i] as i128 * b_j as i128;
            }
            *r = acc;
        }
    }

    // Limbs beyond the computed range carry no contribution.
    for j in min_size..res_size {
        res.at_mut(res_col, j).fill(0i128);
    }
}
/// Scratch-space size, in bytes, to pair with [`ntt120_cnv_pairwise_apply_dft`].
///
/// That routine never touches its `_tmp` buffer, so the answer is `0`
/// whatever the operand sizes are.
pub fn ntt120_cnv_pairwise_apply_dft_tmp_bytes(_res_size: usize, _a_size: usize, _b_size: usize) -> usize {
    const NO_SCRATCH_BYTES: usize = 0;
    NO_SCRATCH_BYTES
}
/// Limb-wise convolution of the column sums `a[col_i] + a[col_j]` and
/// `b[col_i] + b[col_j]`, written to column `res_col` of `res`.
///
/// When `col_i == col_j` the call is forwarded unchanged to
/// [`ntt120_cnv_apply_dft`] on that single column (no summation happens).
///
/// Otherwise, let `bound = a.size() + b.size() - 1` and
/// `offset = min(res_offset, bound)`. For every output limb
/// `k < min(res.size(), bound)`, every coefficient and every `j` for which
/// limb `k + offset - j` of `a` and limb `j` of `b` both exist:
/// * the two `a` entries are recombined into `u64` values per prime, reduced
///   modulo `Primes30::Q[p]` and added;
/// * the two `b` entries are added word by word;
/// * the two sums are fed to `accum_mul_q120_bc`.
///
/// The accumulator is then converted with `accum_to_q120b` and stored as four
/// `u64` words. All remaining limbs of the result column are zeroed.
///
/// If either operand has no limbs the result column is zeroed entirely.
///
/// `_tmp` is unused.
///
/// # Panics
///
/// Panics if a limb slice is shorter than the `8 * n` (`u32`, operands) or
/// `4 * n` (`u64`, result) words indexed below, where `n = res.n()`.
// The signature mirrors the other convolution entry points.
#[allow(clippy::too_many_arguments)]
pub fn ntt120_cnv_pairwise_apply_dft<R, A, B, BE>(
    module: &impl NttModuleHandle,
    res: &mut R,
    res_offset: usize,
    res_col: usize,
    a: &A,
    b: &B,
    col_i: usize,
    col_j: usize,
    _tmp: &mut [u8],
) where
    BE: Backend<ScalarPrep = Q120bScalar>,
    R: VecZnxDftToMut<BE>,
    A: CnvPVecLToRef<BE>,
    B: CnvPVecRToRef<BE>,
{
    // A convolution with an empty operand is zero. Checking this before
    // anything else keeps `a_size + b_size - 1` and `a_size - 1` from
    // underflowing `usize`, on this path and on the delegated one.
    let operand_is_empty = {
        let a_view: CnvPVecL<&[u8], BE> = a.to_ref();
        let b_view: CnvPVecR<&[u8], BE> = b.to_ref();
        a_view.size() == 0 || b_view.size() == 0
    };
    if operand_is_empty {
        let mut out: VecZnxDft<&mut [u8], BE> = res.to_mut();
        for limb in 0..out.size() {
            cast_slice_mut::<_, u64>(out.at_mut(res_col, limb)).fill(0);
        }
        return;
    }

    if col_i == col_j {
        ntt120_cnv_apply_dft(module, res, res_offset, res_col, a, col_i, b, col_j, &mut []);
        return;
    }

    let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
    let a: CnvPVecL<&[u8], BE> = a.to_ref();
    let b: CnvPVecR<&[u8], BE> = b.to_ref();
    let meta = module.get_bbc_meta();

    let n = res.n();
    let res_size = res.size();
    let a_size = a.size();
    let b_size = b.size();

    // Number of limbs of the full (untruncated) convolution.
    let bound = a_size + b_size - 1;
    let min_size = res_size.min(bound);
    let offset = res_offset.min(bound);

    for k in 0..min_size {
        // Index of this output limb inside the full convolution.
        let k_abs = k + offset;
        // Range of `j` such that `k_abs - j` is a valid limb of `a` and `j` is
        // a valid limb of `b`; it is empty when `k_abs` lies past the last limb.
        let j_min = k_abs.saturating_sub(a_size - 1);
        let j_max = (k_abs + 1).min(b_size);

        let res_u64: &mut [u64] = cast_slice_mut(res.at_mut(res_col, k));
        for n_i in 0..n {
            let mut s = [0u64; 8];
            for j in j_min..j_max {
                // The sub-slices are exactly 8 words long, so `try_into` cannot fail.
                let ai: &[u32; 8] = cast_slice::<_, u32>(a.at(col_i, k_abs - j))[8 * n_i..8 * n_i + 8]
                    .try_into()
                    .unwrap();
                let aj: &[u32; 8] = cast_slice::<_, u32>(a.at(col_j, k_abs - j))[8 * n_i..8 * n_i + 8]
                    .try_into()
                    .unwrap();
                let bi: &[u32; 8] = cast_slice::<_, u32>(b.at(col_i, j))[8 * n_i..8 * n_i + 8]
                    .try_into()
                    .unwrap();
                let bj: &[u32; 8] = cast_slice::<_, u32>(b.at(col_j, j))[8 * n_i..8 * n_i + 8]
                    .try_into()
                    .unwrap();

                let mut a_sum = [0u32; 8];
                let mut b_sum = [0u32; 8];
                // `p` indexes the four primes; it is deliberately not called
                // `k` so it cannot be confused with the output-limb index.
                for p in 0..4 {
                    let q = Primes30::Q[p] as u64;
                    // Rebuild each 64-bit entry of `a` from its two `u32`
                    // halves (low word first).
                    let ai_p = (ai[2 * p] as u64) | ((ai[2 * p + 1] as u64) << 32);
                    let aj_p = (aj[2 * p] as u64) | ((aj[2 * p + 1] as u64) << 32);
                    // Only the low word is written; the high word stays zero.
                    a_sum[2 * p] = ((ai_p % q) + (aj_p % q)) as u32;
                    b_sum[2 * p] = bi[2 * p] + bj[2 * p];
                    b_sum[2 * p + 1] = bi[2 * p + 1] + bj[2 * p + 1];
                }
                accum_mul_q120_bc(&mut s, &a_sum, &b_sum);
            }
            let mut r = [0u64; 4];
            accum_to_q120b::<Primes30>(&mut r, &s, meta);
            res_u64[4 * n_i..4 * n_i + 4].copy_from_slice(&r);
        }
    }

    // Limbs beyond the computed range carry no contribution.
    for j in min_size..res_size {
        cast_slice_mut::<_, u64>(res.at_mut(res_col, j)).fill(0);
    }
}