use bytemuck::{cast_slice, cast_slice_mut};
use crate::{
layouts::{
Backend, Module, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, ZnxInfos, ZnxView,
ZnxViewMut,
},
reference::ntt120::{
NttAdd, NttAddInplace, NttCopy, NttDFTExecute, NttFromZnx64, NttNegate, NttNegateInplace, NttSub, NttSubInplace,
NttSubNegateInplace, NttToZnx128, NttZero,
mat_vec::{BbbMeta, BbcMeta},
ntt::{NttTable, NttTableInv},
primes::Primes30,
types::Q120bScalar,
},
};
/// Read-only access to the precomputed NTT tables and matrix-vector metadata (all over
/// `Primes30`) consumed by the `ntt120_*` routines in this module.
///
/// `Module<B>` implements this trait by forwarding to its backend handle whenever
/// `B::Handle: NttHandleProvider`.
pub trait NttModuleHandle {
    /// Forward NTT table; `ntt120_vec_znx_dft_apply` hands it to `NttDFTExecute::ntt_dft_execute`.
    fn get_ntt_table(&self) -> &NttTable<Primes30>;

    /// Inverse NTT table; used by `ntt120_vec_znx_idft_apply` and `ntt120_vec_znx_idft_apply_tmpa`.
    fn get_intt_table(&self) -> &NttTableInv<Primes30>;

    /// `BbcMeta` metadata from `reference::ntt120::mat_vec`.
    fn get_bbc_meta(&self) -> &BbcMeta<Primes30>;

    /// `BbbMeta` metadata from `reference::ntt120::mat_vec`.
    fn get_bbb_meta(&self) -> &BbbMeta<Primes30>;
}
/// Supplies the precomputed NTT tables and matrix-vector metadata (over `Primes30`)
/// stored inside a backend handle.
///
/// `Module<B>` obtains its `NttModuleHandle` implementation by forwarding to these
/// methods whenever `B::Handle: NttHandleProvider`.
///
/// # Safety
///
/// NOTE(review): the exact contract is not spelled out in this file and must be confirmed.
/// The `ntt120_*` routines use the returned tables and metadata without any validation.
/// An implementor therefore presumably has to guarantee two things:
/// - the tables and metadata are fully initialized;
/// - they were built for the same ring degree as the module that owns the handle.
pub unsafe trait NttHandleProvider {
    /// Forward NTT table.
    fn get_ntt_table(&self) -> &NttTable<Primes30>;

    /// Inverse NTT table.
    fn get_intt_table(&self) -> &NttTableInv<Primes30>;

    /// `BbcMeta` matrix-vector metadata.
    fn get_bbc_meta(&self) -> &BbcMeta<Primes30>;

    /// `BbbMeta` matrix-vector metadata.
    fn get_bbb_meta(&self) -> &BbbMeta<Primes30>;
}
impl<B> NttModuleHandle for Module<B>
where
B: Backend,
B::Handle: NttHandleProvider,
{
fn get_ntt_table(&self) -> &NttTable<Primes30> {
unsafe { (&*self.ptr()).get_ntt_table() }
}
fn get_intt_table(&self) -> &NttTableInv<Primes30> {
unsafe { (&*self.ptr()).get_intt_table() }
}
fn get_bbc_meta(&self) -> &BbcMeta<Primes30> {
unsafe { (&*self.ptr()).get_bbc_meta() }
}
fn get_bbb_meta(&self) -> &BbbMeta<Primes30> {
unsafe { (&*self.ptr()).get_bbb_meta() }
}
}
/// Reinterprets limb `limb` of column `col` of a Q120b DFT vector as a flat `u64` slice,
/// which is the form the `Ntt*` kernels in this module take.
#[inline(always)]
fn limb_u64<D: crate::layouts::DataRef, BE: Backend<ScalarPrep = Q120bScalar>>(
    v: &VecZnxDft<D, BE>,
    col: usize,
    limb: usize,
) -> &[u64] {
    let scalars = v.at(col, limb);
    cast_slice::<_, u64>(scalars)
}
/// Mutable counterpart of `limb_u64`: exposes limb `limb` of column `col` of a Q120b DFT
/// vector as a flat, writable `u64` slice for the `Ntt*` kernels.
#[inline(always)]
fn limb_u64_mut<D: crate::layouts::DataMut, BE: Backend<ScalarPrep = Q120bScalar>>(
    v: &mut VecZnxDft<D, BE>,
    col: usize,
    limb: usize,
) -> &mut [u64] {
    let scalars = v.at_mut(col, limb);
    cast_slice_mut::<_, u64>(scalars)
}
/// Forward NTT of selected limbs of column `a_col` of `a` into column `res_col` of `res`.
///
/// Output limb `j` receives the forward NTT (using `module`'s NTT table) of input limb
/// `offset + j * step`. An output limb is zeroed if either of the following holds:
/// - its source index falls at or past `a.size()`;
/// - its index is at or beyond `a.size().div_ceil(step)`.
///
/// # Panics
///
/// Panics if `step == 0`. Debug builds also check that `a` and `res` report the same `n`,
/// the same dimension check `ntt120_vec_znx_dft_copy` performs.
pub fn ntt120_vec_znx_dft_apply<R, A, BE>(
    module: &impl NttModuleHandle,
    step: usize,
    offset: usize,
    res: &mut R,
    res_col: usize,
    a: &A,
    a_col: usize,
) where
    BE: Backend<ScalarPrep = Q120bScalar> + NttDFTExecute<NttTable<Primes30>> + NttFromZnx64 + NttZero,
    R: VecZnxDftToMut<BE>,
    A: VecZnxToRef,
{
    let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
    let a = a.to_ref();

    debug_assert!(step > 0, "ntt120_vec_znx_dft_apply: step must be non-zero");
    debug_assert_eq!(a.n(), res.n());

    let a_size = a.size();
    let res_size = res.size();
    let table = module.get_ntt_table();
    let steps = a_size.div_ceil(step);
    let min_steps = res_size.min(steps);

    for j in 0..min_steps {
        let limb = offset + j * step;
        let res_slice: &mut [u64] = limb_u64_mut(&mut res, res_col, j);
        if limb < a_size {
            BE::ntt_from_znx64(res_slice, a.at(a_col, limb));
            BE::ntt_dft_execute(table, res_slice);
        } else {
            // The selected source limb does not exist in `a`; it contributes zero.
            BE::ntt_zero(res_slice);
        }
    }
    // Output limbs beyond the number of selected input limbs are cleared.
    for j in min_steps..res_size {
        BE::ntt_zero(limb_u64_mut(&mut res, res_col, j));
    }
}
/// Number of scratch bytes `ntt120_vec_znx_idft_apply` needs for ring degree `n`.
///
/// That routine works in a `u64` buffer of `4 * n` words, i.e. four 8-byte words per
/// coefficient.
pub fn ntt120_vec_znx_idft_apply_tmp_bytes(n: usize) -> usize {
    let bytes_per_coeff = std::mem::size_of::<[u64; 4]>();
    bytes_per_coeff * n
}
/// Inverse NTT of column `a_col` of `a` into column `res_col` of the big (`i128`) vector
/// `res`, leaving `a` untouched.
///
/// Each of the first `min(res.size(), a.size())` limbs goes through three steps:
/// 1. it is copied into the scratch buffer `tmp`;
/// 2. it is inverse-transformed there with `module`'s inverse table;
/// 3. it is written to `res` via `ntt_to_znx128`.
///
/// The remaining limbs of `res` are zeroed.
///
/// `tmp` must hold at least `4 * res.n()` words (see
/// `ntt120_vec_znx_idft_apply_tmp_bytes`); otherwise slicing it panics as soon as a limb
/// is processed.
pub fn ntt120_vec_znx_idft_apply<R, A, BE>(
    module: &impl NttModuleHandle,
    res: &mut R,
    res_col: usize,
    a: &A,
    a_col: usize,
    tmp: &mut [u64],
) where
    BE: Backend<ScalarPrep = Q120bScalar, ScalarBig = i128> + NttDFTExecute<NttTableInv<Primes30>> + NttToZnx128 + NttCopy,
    R: VecZnxBigToMut<BE>,
    A: VecZnxDftToRef<BE>,
{
    let mut res: VecZnxBig<&mut [u8], BE> = res.to_mut();
    let a: VecZnxDft<&[u8], BE> = a.to_ref();

    let ring_degree = res.n();
    let out_limbs = res.size();
    let converted = out_limbs.min(a.size());
    let inv_table = module.get_intt_table();

    for limb in 0..converted {
        let src: &[u64] = limb_u64(&a, a_col, limb);
        // Transform a copy so the caller's DFT vector is preserved.
        let scratch: &mut [u64] = &mut tmp[..4 * ring_degree];
        BE::ntt_copy(scratch, src);
        BE::ntt_dft_execute(inv_table, scratch);
        BE::ntt_to_znx128(res.at_mut(res_col, limb), ring_degree, scratch);
    }

    for limb in converted..out_limbs {
        res.at_mut(res_col, limb).fill(0i128);
    }
}
/// Inverse NTT of column `a_col` of `a` into column `res_col` of the big (`i128`) vector
/// `res`, using `a` itself as scratch space.
///
/// Each of the first `min(res.size(), a.size())` limbs of `a` is handled in two steps:
/// 1. it is overwritten in place with its inverse-transformed value;
/// 2. it is converted into the matching limb of `res`.
///
/// The remaining limbs of `res` are zeroed. Unlike `ntt120_vec_znx_idft_apply`, no external
/// buffer is needed, at the cost of clobbering those limbs of `a`.
pub fn ntt120_vec_znx_idft_apply_tmpa<R, A, BE>(
    module: &impl NttModuleHandle,
    res: &mut R,
    res_col: usize,
    a: &mut A,
    a_col: usize,
) where
    BE: Backend<ScalarPrep = Q120bScalar, ScalarBig = i128> + NttDFTExecute<NttTableInv<Primes30>> + NttToZnx128,
    R: VecZnxBigToMut<BE>,
    A: VecZnxDftToMut<BE>,
{
    let mut res: VecZnxBig<&mut [u8], BE> = res.to_mut();
    let mut a: VecZnxDft<&mut [u8], BE> = a.to_mut();

    let ring_degree = res.n();
    let out_limbs = res.size();
    let converted = out_limbs.min(a.size());
    let inv_table = module.get_intt_table();

    for limb in 0..converted {
        let work: &mut [u64] = limb_u64_mut(&mut a, a_col, limb);
        BE::ntt_dft_execute(inv_table, work);
        BE::ntt_to_znx128(res.at_mut(res_col, limb), ring_degree, work);
    }

    for limb in converted..out_limbs {
        res.at_mut(res_col, limb).fill(0i128);
    }
}
/// Computes `res = a + b` on columns `res_col`, `a_col` and `b_col`, truncated to `res.size()` limbs.
///
/// - Limbs present in both inputs are added.
/// - Limbs present only in the longer input are copied.
/// - Limbs of `res` beyond both inputs are zeroed.
pub fn ntt120_vec_znx_dft_add<R, A, B, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
    BE: Backend<ScalarPrep = Q120bScalar> + NttAdd + NttCopy + NttZero,
    R: VecZnxDftToMut<BE>,
    A: VecZnxDftToRef<BE>,
    B: VecZnxDftToRef<BE>,
{
    let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
    let a: VecZnxDft<&[u8], BE> = a.to_ref();
    let b: VecZnxDft<&[u8], BE> = b.to_ref();

    let out_limbs = res.size();
    let (a_limbs, b_limbs) = (a.size(), b.size());

    // Both operands contribute up to the shorter one; the longer one alone up to its own size.
    let both_end = a_limbs.min(b_limbs).min(out_limbs);
    let one_end = a_limbs.max(b_limbs).min(out_limbs);
    // On equal sizes the copy range is empty, so the choice of `b` there is immaterial.
    let (longer, longer_col) = if a_limbs <= b_limbs { (&b, b_col) } else { (&a, a_col) };

    for j in 0..both_end {
        BE::ntt_add(
            limb_u64_mut(&mut res, res_col, j),
            limb_u64(&a, a_col, j),
            limb_u64(&b, b_col, j),
        );
    }
    for j in both_end..one_end {
        BE::ntt_copy(limb_u64_mut(&mut res, res_col, j), limb_u64(longer, longer_col, j));
    }
    for j in one_end..out_limbs {
        BE::ntt_zero(limb_u64_mut(&mut res, res_col, j));
    }
}
/// Computes `res += a` limb by limb on columns `res_col` / `a_col`.
///
/// Only limbs present in both vectors are touched. Extra limbs of `res` keep their value,
/// and extra limbs of `a` are ignored.
pub fn ntt120_vec_znx_dft_add_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
    BE: Backend<ScalarPrep = Q120bScalar> + NttAddInplace,
    R: VecZnxDftToMut<BE>,
    A: VecZnxDftToRef<BE>,
{
    let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
    let a: VecZnxDft<&[u8], BE> = a.to_ref();

    let common = res.size().min(a.size());
    (0..common).for_each(|j| {
        BE::ntt_add_inplace(limb_u64_mut(&mut res, res_col, j), limb_u64(&a, a_col, j));
    });
}
/// Adds a limb-shifted copy of column `a_col` of `a` into column `res_col` of `res`, in place.
///
/// - `a_scale >= 0`: `res[j] += a[j + a_scale]` for every `j` where both limbs exist.
/// - `a_scale < 0`: `res[j + |a_scale|] += a[j]` for every `j` where both limbs exist.
///
/// Limbs of `res` that receive no contribution are left unchanged.
/// NOTE(review): presumably this realises scaling of `a` by a power of the limb base — confirm.
pub fn ntt120_vec_znx_dft_add_scaled_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize, a_scale: i64)
where
    BE: Backend<ScalarPrep = Q120bScalar> + NttAddInplace,
    R: VecZnxDftToMut<BE>,
    A: VecZnxDftToRef<BE>,
{
    let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
    let a: VecZnxDft<&[u8], BE> = a.to_ref();
    let res_size = res.size();
    let a_size = a.size();

    if a_scale >= 0 {
        // Saturating conversion: a plain `as usize` could wrap on 32-bit targets.
        let shift = usize::try_from(a_scale).unwrap_or(usize::MAX).min(a_size);
        // `a` has `a_size - shift` limbs left after the shift; `res` bounds the rest.
        // Truncating to `min(a_size, res_size)` before subtracting `shift` would drop
        // valid contributions whenever `res_size < a_size`.
        let sum_size = (a_size - shift).min(res_size);
        for j in 0..sum_size {
            BE::ntt_add_inplace(limb_u64_mut(&mut res, res_col, j), limb_u64(&a, a_col, j + shift));
        }
    } else {
        let shift = usize::try_from(a_scale.unsigned_abs())
            .unwrap_or(usize::MAX)
            .min(res_size);
        // `res` has `res_size - shift` limbs left to receive `a`; `a` bounds the rest.
        let sum_size = a_size.min(res_size - shift);
        for j in 0..sum_size {
            BE::ntt_add_inplace(limb_u64_mut(&mut res, res_col, j + shift), limb_u64(&a, a_col, j));
        }
    }
}
/// Computes `res = a - b` on columns `res_col`, `a_col` and `b_col`, truncated to `res.size()` limbs.
///
/// - Where both inputs have a limb, it is subtracted.
/// - Where only `b` has a limb, its negation is written.
/// - Where only `a` has a limb, it is copied.
/// - Limbs of `res` beyond both inputs are zeroed.
pub fn ntt120_vec_znx_dft_sub<R, A, B, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
    BE: Backend<ScalarPrep = Q120bScalar> + NttSub + NttNegate + NttCopy + NttZero,
    R: VecZnxDftToMut<BE>,
    A: VecZnxDftToRef<BE>,
    B: VecZnxDftToRef<BE>,
{
    let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
    let a: VecZnxDft<&[u8], BE> = a.to_ref();
    let b: VecZnxDft<&[u8], BE> = b.to_ref();

    let out_limbs = res.size();
    let (a_limbs, b_limbs) = (a.size(), b.size());
    let b_is_longer_or_equal = a_limbs <= b_limbs;

    let both_end = a_limbs.min(b_limbs).min(out_limbs);
    let one_end = a_limbs.max(b_limbs).min(out_limbs);

    for j in 0..both_end {
        BE::ntt_sub(
            limb_u64_mut(&mut res, res_col, j),
            limb_u64(&a, a_col, j),
            limb_u64(&b, b_col, j),
        );
    }
    for j in both_end..one_end {
        let dst = limb_u64_mut(&mut res, res_col, j);
        if b_is_longer_or_equal {
            // Only `b` has this limb: a - b = -b.
            BE::ntt_negate(dst, limb_u64(&b, b_col, j));
        } else {
            // Only `a` has this limb: a - b = a.
            BE::ntt_copy(dst, limb_u64(&a, a_col, j));
        }
    }
    for j in one_end..out_limbs {
        BE::ntt_zero(limb_u64_mut(&mut res, res_col, j));
    }
}
/// Computes `res -= a` limb by limb on columns `res_col` / `a_col`.
///
/// Only the limbs common to both vectors are touched. Limbs of `res` past `a.size()` already
/// equal `res - 0`, so they are left unchanged.
pub fn ntt120_vec_znx_dft_sub_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
    BE: Backend<ScalarPrep = Q120bScalar> + NttSubInplace,
    R: VecZnxDftToMut<BE>,
    A: VecZnxDftToRef<BE>,
{
    let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
    let a: VecZnxDft<&[u8], BE> = a.to_ref();

    let common = res.size().min(a.size());
    for limb in 0..common {
        let dst = limb_u64_mut(&mut res, res_col, limb);
        BE::ntt_sub_inplace(dst, limb_u64(&a, a_col, limb));
    }
}
/// Computes `res = a - res` limb by limb on columns `res_col` / `a_col`.
///
/// Limbs of `res` past `a.size()` have no matching limb in `a`, so `a - res` reduces to
/// `-res` there and they are negated in place.
pub fn ntt120_vec_znx_dft_sub_negate_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
    BE: Backend<ScalarPrep = Q120bScalar> + NttSubNegateInplace + NttNegateInplace,
    R: VecZnxDftToMut<BE>,
    A: VecZnxDftToRef<BE>,
{
    let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
    let a: VecZnxDft<&[u8], BE> = a.to_ref();

    let total = res.size();
    let common = total.min(a.size());
    for j in 0..total {
        let dst = limb_u64_mut(&mut res, res_col, j);
        if j < common {
            BE::ntt_sub_negate_inplace(dst, limb_u64(&a, a_col, j));
        } else {
            BE::ntt_negate_inplace(dst);
        }
    }
}
/// Copies every `step`-th limb of column `a_col` of `a`, starting at limb `offset`, into
/// consecutive limbs of column `res_col` of `res`.
///
/// Output limb `j` receives input limb `offset + j * step`. An output limb is zeroed if
/// either of the following holds:
/// - its source index is at or past `a.size()`;
/// - its index is at or beyond `a.size().div_ceil(step)`.
///
/// Panics if `step == 0` (via `div_ceil`). Debug builds also check that `res` and `a` have
/// the same `n`.
pub fn ntt120_vec_znx_dft_copy<R, A, BE>(step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
    BE: Backend<ScalarPrep = Q120bScalar> + NttCopy + NttZero,
    R: VecZnxDftToMut<BE>,
    A: VecZnxDftToRef<BE>,
{
    let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
    let a: VecZnxDft<&[u8], BE> = a.to_ref();
    debug_assert_eq!(res.n(), a.n());

    let src_limbs = a.size();
    let dst_limbs = res.size();
    let selected = dst_limbs.min(src_limbs.div_ceil(step));

    for j in 0..dst_limbs {
        // The source index is only formed for selected output limbs, and only kept if `a` has it.
        let src = (j < selected)
            .then(|| offset + j * step)
            .filter(|&limb| limb < src_limbs);
        let dst = limb_u64_mut(&mut res, res_col, j);
        match src {
            Some(limb) => BE::ntt_copy(dst, limb_u64(&a, a_col, limb)),
            None => BE::ntt_zero(dst),
        }
    }
}
/// Sets every limb of column `res_col` of `res` to zero.
pub fn ntt120_vec_znx_dft_zero<R, BE>(res: &mut R, res_col: usize)
where
    BE: Backend<ScalarPrep = Q120bScalar> + NttZero,
    R: VecZnxDftToMut<BE>,
{
    let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
    let limbs = res.size();
    (0..limbs).for_each(|j| BE::ntt_zero(limb_u64_mut(&mut res, res_col, j)));
}