use bytemuck::{cast_slice, cast_slice_mut};
use crate::{
layouts::{
Backend, CnvPVecLBackendMut, CnvPVecLBackendRef, CnvPVecRBackendMut, CnvPVecRBackendRef, HostDataRef, VecZnxBackendRef,
VecZnxBigBackendMut, VecZnxDftBackendMut, ZnxView, ZnxViewMut,
},
reference::ntt120::{
NttAddAssign, NttCFromB, NttDFTExecute, NttFromZnx64, NttMulBbc1ColX2, NttMulBbc2ColsX2, NttPackLeft1BlkX2,
NttPackRight1BlkX2, NttPairwisePackLeft1BlkX2, NttPairwisePackRight1BlkX2, ntt::NttTable, primes::Primes30,
types::Q120bScalar, vec_znx_dft::NttModuleHandle,
},
};
/// Scratch size, in bytes, required by [`ntt120_cnv_prepare_left`].
///
/// The left preparation writes straight into its output and never touches its scratch
/// argument, so the answer is zero for every ring degree.
pub fn ntt120_cnv_prepare_left_tmp_bytes(_n: usize) -> usize {
    0_usize
}
/// Prepares `a` as the left operand of an NTT120 convolution, writing into `res`.
///
/// For every column of `res`, each of the first `min(res.size(), a.size())` limbs of `a` is
/// loaded with `BE::ntt_from_znx64`, except the last of those limbs, which goes through
/// `BE::ntt_from_znx64_masked` together with `mask`. The loaded limb is then transformed in
/// place with `BE::ntt_dft_execute` using the module's NTT table. Limbs of `res` beyond that
/// range are zeroed.
///
/// `_tmp` is unused (see [`ntt120_cnv_prepare_left_tmp_bytes`]).
///
/// NOTE(review): columns are iterated over `res.cols()`, so `a` is assumed to have at least
/// as many columns as `res` — confirm with callers.
pub fn ntt120_cnv_prepare_left<BE>(
    module: &impl NttModuleHandle,
    res: &mut CnvPVecLBackendMut<'_, BE>,
    a: &VecZnxBackendRef<'_, BE>,
    mask: i64,
    _tmp: &mut [u8],
) where
    BE: Backend<ScalarPrep = Q120bScalar> + NttFromZnx64 + NttDFTExecute<NttTable<Primes30>> + 'static,
    for<'x> BE: Backend<BufRef<'x> = &'x [u8], BufMut<'x> = &'x mut [u8]>,
{
    let ntt_table = module.get_ntt_table();
    let limbs_out = res.size();
    let limbs_in = limbs_out.min(a.size());

    for c in 0..res.cols() {
        for limb in 0..limbs_in {
            let dst: &mut [u64] = cast_slice_mut(res.at_mut(c, limb));
            // Only the final copied limb is passed through the masked loader.
            if limb + 1 == limbs_in {
                BE::ntt_from_znx64_masked(dst, a.at(c, limb), mask);
            } else {
                BE::ntt_from_znx64(dst, a.at(c, limb));
            }
            BE::ntt_dft_execute(ntt_table, dst);
        }
        // Limbs with no counterpart in `a` are cleared.
        for limb in limbs_in..limbs_out {
            cast_slice_mut::<_, u64>(res.at_mut(c, limb)).fill(0);
        }
    }
}
/// Scratch size, in bytes, required by [`ntt120_cnv_prepare_right`].
///
/// The right preparation stages one limb at a time in a buffer of `4 * n` `u64` words
/// before converting it into the output.
pub fn ntt120_cnv_prepare_right_tmp_bytes(n: usize) -> usize {
    let staged_words = 4 * n;
    staged_words * size_of::<u64>()
}
/// Prepares `a` as the right operand of an NTT120 convolution, writing into `res`.
///
/// For every column of `res`, each of the first `min(res.size(), a.size())` limbs of `a` is
/// loaded into `tmp` with `BE::ntt_from_znx64`. The last of those limbs goes through
/// `BE::ntt_from_znx64_masked` together with `mask` instead. The staged limb is transformed in
/// place with `BE::ntt_dft_execute` using the module's NTT table, then converted into `res`
/// with `BE::ntt_c_from_b`. Limbs of `res` beyond that range are zeroed.
///
/// `tmp` is scratch for a single limb. Only its first `4 * res.n()` words are used (see
/// [`ntt120_cnv_prepare_right_tmp_bytes`]), so a larger buffer is accepted.
///
/// # Panics
///
/// Panics if at least one limb is converted and `tmp` holds fewer than `4 * res.n()` words.
pub fn ntt120_cnv_prepare_right<BE>(
    module: &impl NttModuleHandle,
    res: &mut CnvPVecRBackendMut<'_, BE>,
    a: &VecZnxBackendRef<'_, BE>,
    mask: i64,
    tmp: &mut [u64],
) where
    BE: Backend<ScalarPrep = Q120bScalar> + NttFromZnx64 + NttDFTExecute<NttTable<Primes30>> + NttCFromB + 'static,
    for<'x> BE: Backend<BufRef<'x> = &'x [u8], BufMut<'x> = &'x mut [u8]>,
{
    let n = res.n();
    let table = module.get_ntt_table();
    let cols = res.cols();
    let res_size = res.size();
    let min_size = res_size.min(a.size());

    // The `_tmp_bytes` contract is a minimum, so callers may hand over a larger scratch
    // buffer. Restrict the kernels to exactly one limb's worth of words so that surplus is
    // never fed to `ntt_from_znx64` / `ntt_dft_execute`, whose handling of a longer slice
    // is not established here. Nothing is required when no limb is converted.
    let limb_words = 4 * n;
    assert!(
        min_size == 0 || tmp.len() >= limb_words,
        "ntt120_cnv_prepare_right: tmp holds {} words, need {}",
        tmp.len(),
        limb_words
    );
    let tmp = &mut tmp[..limb_words.min(tmp.len())];

    for col in 0..cols {
        for j in 0..min_size {
            // Only the final copied limb is passed through the masked loader.
            if j + 1 == min_size {
                BE::ntt_from_znx64_masked(tmp, a.at(col, j), mask);
            } else {
                BE::ntt_from_znx64(tmp, a.at(col, j));
            }
            BE::ntt_dft_execute(table, tmp);
            let res_u32: &mut [u32] = cast_slice_mut(res.at_mut(col, j));
            BE::ntt_c_from_b(n, res_u32, tmp);
        }
        for j in min_size..res_size {
            cast_slice_mut::<_, u32>(res.at_mut(col, j)).fill(0);
        }
    }
}
/// Scratch size, in bytes, required by [`ntt120_cnv_prepare_self`].
///
/// The self preparation derives the right operand from the already-written left operand
/// and ignores its scratch argument, so no scratch is needed for any ring degree.
pub fn ntt120_cnv_prepare_self_tmp_bytes(_n: usize) -> usize {
    0_usize
}
/// Prepares `a` as both the left (`left`) and right (`right`) operand of an NTT120
/// convolution in a single pass.
///
/// Each of the first `min(left.size(), a.size())` limbs of every column is processed the
/// same way as in the left preparation:
/// - it is loaded with `BE::ntt_from_znx64`, or with `BE::ntt_from_znx64_masked` and `mask`
///   for the last such limb;
/// - it is transformed in place with `BE::ntt_dft_execute`.
///
/// The result in `left` is then converted into `right` with `BE::ntt_c_from_b`. Limbs beyond
/// that range are zeroed in both outputs. `_tmp` is unused.
///
/// NOTE(review): dimensions are taken from `left` only; `right` is assumed to have the same
/// `cols()`/`size()` (and `a` at least as many columns) — confirm with callers.
pub fn ntt120_cnv_prepare_self<BE>(
    module: &impl NttModuleHandle,
    left: &mut CnvPVecLBackendMut<'_, BE>,
    right: &mut CnvPVecRBackendMut<'_, BE>,
    a: &VecZnxBackendRef<'_, BE>,
    mask: i64,
    _tmp: &mut [u8],
) where
    BE: Backend<ScalarPrep = Q120bScalar> + NttFromZnx64 + NttDFTExecute<NttTable<Primes30>> + NttCFromB + 'static,
    for<'x> BE: Backend<BufRef<'x> = &'x [u8], BufMut<'x> = &'x mut [u8]>,
{
    let ntt_table = module.get_ntt_table();
    let degree = left.n();
    let limbs_out = left.size();
    let limbs_in = limbs_out.min(a.size());

    for c in 0..left.cols() {
        for limb in 0..limbs_in {
            // Fill the left limb in place; the mutable borrow of `left` ends with this scope.
            {
                let dst: &mut [u64] = cast_slice_mut(left.at_mut(c, limb));
                if limb + 1 == limbs_in {
                    BE::ntt_from_znx64_masked(dst, a.at(c, limb), mask);
                } else {
                    BE::ntt_from_znx64(dst, a.at(c, limb));
                }
                BE::ntt_dft_execute(ntt_table, dst);
            }
            // Derive the matching right limb from the freshly transformed left limb.
            let src: &[u64] = cast_slice(left.at(c, limb));
            let dst_right: &mut [u32] = cast_slice_mut(right.at_mut(c, limb));
            BE::ntt_c_from_b(degree, dst_right, src);
        }
        for limb in limbs_in..limbs_out {
            cast_slice_mut::<_, u64>(left.at_mut(c, limb)).fill(0);
            cast_slice_mut::<_, u32>(right.at_mut(c, limb)).fill(0);
        }
    }
}
/// Scratch size, in bytes, required by [`ntt120_cnv_apply_dft`].
///
/// The apply step packs one block of every limb of both operands at a time, using 16 `u32`
/// words per limb, so the requirement is `16 * (a_size + b_size)` words. The output size
/// does not matter.
pub fn ntt120_cnv_apply_dft_tmp_bytes(_res_size: usize, a_size: usize, b_size: usize) -> usize {
    let words_per_limb = 16;
    let packed_limbs = a_size + b_size;
    words_per_limb * packed_limbs * size_of::<u32>()
}
#[allow(clippy::too_many_arguments)]
pub fn ntt120_cnv_apply_dft<BE>(
module: &impl NttModuleHandle,
cnv_offset: usize,
res: &mut VecZnxDftBackendMut<'_, BE>,
res_col: usize,
a: &CnvPVecLBackendRef<'_, BE>,
a_col: usize,
b: &CnvPVecRBackendRef<'_, BE>,
b_col: usize,
tmp: &mut [u8],
) where
BE: Backend<ScalarPrep = Q120bScalar> + NttMulBbc1ColX2 + NttPackLeft1BlkX2 + NttPackRight1BlkX2,
for<'x> <BE as Backend>::BufRef<'x>: HostDataRef,
for<'x> <BE as Backend>::BufMut<'x>: crate::layouts::HostDataMut,
{
let n = res.n();
let res_size = res.size();
let a_size = a.size();
let b_size = b.size();
if res_size == 0 || a_size == 0 || b_size == 0 {
for j in 0..res_size {
cast_slice_mut::<_, u64>(res.at_mut(res_col, j)).fill(0);
}
return;
}
let bound = a_size + b_size - 1;
let offset = cnv_offset.min(bound);
let min_size = res_size.min((bound + 1).saturating_sub(offset));
let meta = module.get_bbc_meta();
let a_cols = a.cols();
let b_cols = b.cols();
let n_blks = n / 2;
let a_row_stride_u64 = 4 * n * a_cols;
let b_row_stride_u32 = 8 * n * b_cols;
let a_col_offset_u64 = 4 * n * a_col;
let b_col_offset_u32 = 8 * n * b_col;
let a_raw_u64: &[u64] = cast_slice(a.raw());
let b_raw_u32: &[u32] = cast_slice(b.raw());
let (prefix, tmp_u32, suffix) = unsafe { tmp.align_to_mut::<u32>() };
debug_assert!(prefix.is_empty());
debug_assert!(suffix.is_empty());
debug_assert!(tmp_u32.len() >= 16 * (a_size + b_size));
let (a_tmp, b_tmp) = tmp_u32.split_at_mut(16 * a_size);
for blk in 0..n_blks {
BE::ntt_pack_left_1blk_x2(a_tmp, &a_raw_u64[a_col_offset_u64..], a_size, a_row_stride_u64, blk);
BE::ntt_pack_right_1blk_x2(b_tmp, &b_raw_u32[b_col_offset_u32..], b_size, b_row_stride_u32, blk);
for k in 0..min_size {
let k_abs = k + offset;
let j_min = k_abs.saturating_sub(a_size - 1);
let j_max = (k_abs + 1).min(b_size);
let ell = j_max - j_min;
let a_start = k_abs + 1 - j_max;
let b_start = b_size - j_max;
let res_u64: &mut [u64] = cast_slice_mut(res.at_mut(res_col, k));
BE::ntt_mul_bbc_1col_x2(
meta,
ell,
&mut res_u64[8 * blk..8 * blk + 8],
&a_tmp[16 * a_start..],
&b_tmp[16 * b_start..],
);
}
}
for j in min_size..res_size {
cast_slice_mut::<_, u64>(res.at_mut(res_col, j)).fill(0);
}
}
/// Scratch size, in bytes, required by [`ntt120_cnv_by_const_apply`].
///
/// That routine accumulates directly into its output and ignores its scratch argument,
/// so no scratch is needed for any operand sizes.
pub fn ntt120_cnv_by_const_apply_tmp_bytes(_res_size: usize, _a_size: usize, _b_size: usize) -> usize {
    0_usize
}
/// Convolves column `a_col` of `a` with a single coefficient per limb of `b`, writing `i128`
/// results into column `res_col` of `res`.
///
/// The coefficient used from each limb `b[j]` of column `b_col` is `b[j][b_coeff]`. Let:
/// - `offset = min(cnv_offset, a.size() + b.size() - 1)` and `k_abs = k + offset`;
/// - `j_min = k_abs.saturating_sub(a.size() - 1)` and `j_max = min(k_abs + 1, b.size())`.
///
/// Output limb `k` then holds, for every coefficient index `i`, the sum of
/// `a[k_abs - j][i] * b[j][b_coeff]` over `j_min <= j < j_max`.
///
/// Output limbs `k >= min(res.size(), a.size() + b.size() - offset)` are zeroed, and all
/// output limbs are zeroed if any of the sizes is 0. `_tmp` is unused.
///
/// # Panics
///
/// Panics if `b_coeff` is out of range for a limb of `b`, or if a limb of `a` is shorter than
/// an output limb.
#[allow(clippy::too_many_arguments)]
pub fn ntt120_cnv_by_const_apply<BE>(
    cnv_offset: usize,
    res: &mut VecZnxBigBackendMut<'_, BE>,
    res_col: usize,
    a: &VecZnxBackendRef<'_, BE>,
    a_col: usize,
    b: &VecZnxBackendRef<'_, BE>,
    b_col: usize,
    b_coeff: usize,
    _tmp: &mut [u8],
) where
    BE: Backend<ScalarPrep = Q120bScalar, ScalarBig = i128> + 'static,
    for<'x> BE: Backend<BufRef<'x> = &'x [u8]>,
    for<'x> <BE as Backend>::BufMut<'x>: crate::layouts::HostDataMut,
{
    let res_size = res.size();
    let a_size = a.size();
    let b_size = b.size();
    if res_size == 0 || a_size == 0 || b_size == 0 {
        for j in 0..res_size {
            res.at_mut(res_col, j).fill(0i128);
        }
        return;
    }
    let bound = a_size + b_size - 1;
    let offset = cnv_offset.min(bound);
    let min_size = res_size.min((bound + 1).saturating_sub(offset));
    for k in 0..min_size {
        let k_abs = k + offset;
        let j_min = k_abs.saturating_sub(a_size - 1);
        let j_max = (k_abs + 1).min(b_size);
        let res_limb: &mut [i128] = res.at_mut(res_col, k);
        let limb_len = res_limb.len();
        res_limb.fill(0i128);
        // Terms outer, coefficients inner. The scalar b[j][b_coeff] is loaded once per term
        // rather than once per output coefficient, and each limb of `a` is read contiguously.
        // For every coefficient the terms are still added in ascending `j`.
        for j in j_min..j_max {
            let b_j = b.at(b_col, j)[b_coeff] as i128;
            // Slicing to the output length panics if this limb of `a` is too short, rather
            // than letting `zip` silently truncate.
            let a_limb = &a.at(a_col, k_abs - j)[..limb_len];
            for (r, &a_i) in res_limb.iter_mut().zip(a_limb) {
                *r += a_i as i128 * b_j;
            }
        }
    }
    for j in min_size..res_size {
        res.at_mut(res_col, j).fill(0i128);
    }
}
/// Scratch size, in bytes, required by [`ntt120_cnv_pairwise_apply_dft`].
///
/// Returns 0 when any of the sizes is 0; in that case the pairwise routine (and the
/// single-column routine it may delegate to) returns before touching scratch. Otherwise
/// returns the larger of two requirements:
/// - the pairwise packing requirement, 16 `u32` words per limb of each operand;
/// - [`ntt120_cnv_apply_dft_tmp_bytes`], because the pairwise routine delegates to
///   `ntt120_cnv_apply_dft` when both column indices are equal.
pub fn ntt120_cnv_pairwise_apply_dft_tmp_bytes(res_size: usize, a_size: usize, b_size: usize) -> usize {
    if [res_size, a_size, b_size].contains(&0) {
        return 0;
    }
    let pairwise_bytes = 16 * (a_size + b_size) * size_of::<u32>();
    let single_col_bytes = ntt120_cnv_apply_dft_tmp_bytes(res_size, a_size, b_size);
    pairwise_bytes.max(single_col_bytes)
}
#[allow(clippy::too_many_arguments)]
pub fn ntt120_cnv_pairwise_apply_dft<BE>(
module: &impl NttModuleHandle,
cnv_offset: usize,
res: &mut VecZnxDftBackendMut<'_, BE>,
res_col: usize,
a: &CnvPVecLBackendRef<'_, BE>,
b: &CnvPVecRBackendRef<'_, BE>,
col_i: usize,
col_j: usize,
tmp: &mut [u8],
) where
BE: Backend<ScalarPrep = Q120bScalar>
+ NttAddAssign
+ NttMulBbc1ColX2
+ NttMulBbc2ColsX2
+ NttPackLeft1BlkX2
+ NttPackRight1BlkX2
+ NttPairwisePackLeft1BlkX2
+ NttPairwisePackRight1BlkX2,
for<'x> <BE as Backend>::BufRef<'x>: HostDataRef,
for<'x> <BE as Backend>::BufMut<'x>: crate::layouts::HostDataMut,
{
if col_i == col_j {
ntt120_cnv_apply_dft(module, cnv_offset, res, res_col, a, col_i, b, col_j, tmp);
return;
}
let meta = module.get_bbc_meta();
let n = res.n();
let res_size = res.size();
let a_size = a.size();
let b_size = b.size();
if res_size == 0 || a_size == 0 || b_size == 0 {
for j in 0..res_size {
cast_slice_mut::<_, u64>(res.at_mut(res_col, j)).fill(0);
}
return;
}
let a_cols = a.cols();
let b_cols = b.cols();
let bound = a_size + b_size - 1;
let offset = cnv_offset.min(bound);
let min_size = res_size.min((bound + 1).saturating_sub(offset));
let n_blks = n / 2;
let a_row_stride_u64 = 4 * n * a_cols;
let b_row_stride_u32 = 8 * n * b_cols;
let a_col_offset_u64_i = 4 * n * col_i;
let a_col_offset_u64_j = 4 * n * col_j;
let b_col_offset_u32_i = 8 * n * col_i;
let b_col_offset_u32_j = 8 * n * col_j;
let a_raw_u64: &[u64] = cast_slice(a.raw());
let b_raw_u32: &[u32] = cast_slice(b.raw());
let (prefix, tmp_u32, suffix) = unsafe { tmp.align_to_mut::<u32>() };
debug_assert!(prefix.is_empty());
debug_assert!(suffix.is_empty());
debug_assert!(tmp_u32.len() >= 16 * (a_size + b_size));
let (a_tmp, b_tmp) = tmp_u32.split_at_mut(16 * a_size);
for blk in 0..n_blks {
BE::ntt_pairwise_pack_left_1blk_x2(
a_tmp,
&a_raw_u64[a_col_offset_u64_i..],
&a_raw_u64[a_col_offset_u64_j..],
a_size,
a_row_stride_u64,
blk,
);
BE::ntt_pairwise_pack_right_1blk_x2(
b_tmp,
&b_raw_u32[b_col_offset_u32_i..],
&b_raw_u32[b_col_offset_u32_j..],
b_size,
b_row_stride_u32,
blk,
);
for k in 0..min_size {
let k_abs = k + offset;
let j_min = k_abs.saturating_sub(a_size - 1);
let j_max = (k_abs + 1).min(b_size);
let ell = j_max - j_min;
let a_start = k_abs + 1 - j_max;
let b_start = b_size - j_max;
let res_u64: &mut [u64] = cast_slice_mut(res.at_mut(res_col, k));
BE::ntt_mul_bbc_1col_x2(
meta,
ell,
&mut res_u64[8 * blk..8 * blk + 8],
&a_tmp[16 * a_start..],
&b_tmp[16 * b_start..],
);
}
}
for j in min_size..res_size {
cast_slice_mut::<_, u64>(res.at_mut(res_col, j)).fill(0);
}
}