use bytemuck::{cast_slice, cast_slice_mut};
use crate::{
layouts::{
Backend, HostDataMut, HostDataRef, Module, VecZnxBackendRef, VecZnxBigBackendMut, VecZnxDft, VecZnxDftBackendMut,
VecZnxDftBackendRef, ZnxView, ZnxViewMut,
},
reference::ntt120::{
NttAdd, NttAddAssign, NttCopy, NttDFTExecute, NttFromZnx64, NttNegate, NttNegateAssign, NttSub, NttSubAssign,
NttSubNegateAssign, NttToZnx128, NttZero,
mat_vec::{BbbMeta, BbcMeta},
ntt::{NttTable, NttTableInv, intt_ref},
primes::{PrimeSet, Primes30},
types::Q120bScalar,
},
};
/// Read access to the precomputed `Primes30` NTT data used by the `ntt120` reference
/// functions: the forward and inverse transform tables and the matrix-vector metadata.
pub trait NttModuleHandle {
    /// Table passed to forward transforms (`NttDFTExecute<NttTable<Primes30>>`).
    fn get_ntt_table(&self) -> &NttTable<Primes30>;
    /// Table passed to inverse transforms (`NttDFTExecute<NttTableInv<Primes30>>`).
    fn get_intt_table(&self) -> &NttTableInv<Primes30>;
    /// Metadata for the `bbc` matrix-vector kernels (`mat_vec::BbcMeta`).
    fn get_bbc_meta(&self) -> &BbcMeta<Primes30>;
    /// Metadata for the `bbb` matrix-vector kernels (`mat_vec::BbbMeta`).
    fn get_bbb_meta(&self) -> &BbbMeta<Primes30>;
}
/// Implemented by a backend's `Handle` type to expose its NTT tables and metadata.
///
/// `Module<B>` implements [`NttModuleHandle`] by forwarding to this trait on `B::Handle`.
///
/// # Safety
///
/// NOTE(review): this trait is `unsafe`, but its contract is not stated in this file.
/// Presumably implementors must return tables and metadata that match the degree the handle
/// was built for, because the `ntt120` kernels index with them unchecked. TODO confirm and
/// document the exact obligations.
pub unsafe trait NttHandleProvider {
    /// Forward NTT table.
    fn get_ntt_table(&self) -> &NttTable<Primes30>;
    /// Inverse NTT table.
    fn get_intt_table(&self) -> &NttTableInv<Primes30>;
    /// `bbc` matrix-vector metadata.
    fn get_bbc_meta(&self) -> &BbcMeta<Primes30>;
    /// `bbb` matrix-vector metadata.
    fn get_bbb_meta(&self) -> &BbbMeta<Primes30>;
}
/// Constructs NTT handles.
///
/// # Safety
///
/// NOTE(review): this trait is `unsafe`, but its contract is not stated in this file.
/// Presumably the value returned by `create_ntt_handle(n)` must be a fully initialised handle
/// that is valid for `n`. TODO confirm and document.
pub unsafe trait NttHandleFactory: Sized {
    /// Builds a handle for size parameter `n` (presumably the ring degree; TODO confirm).
    fn create_ntt_handle(n: usize) -> Self;
    /// Hook for implementors to check runtime prerequisites before use (inferred from the
    /// name; TODO confirm). The default implementation does nothing.
    fn assert_ntt_runtime_support() {}
}
impl<B> NttModuleHandle for Module<B>
where
B: Backend,
B::Handle: NttHandleProvider,
{
fn get_ntt_table(&self) -> &NttTable<Primes30> {
unsafe { (&*self.ptr()).get_ntt_table() }
}
fn get_intt_table(&self) -> &NttTableInv<Primes30> {
unsafe { (&*self.ptr()).get_intt_table() }
}
fn get_bbc_meta(&self) -> &BbcMeta<Primes30> {
unsafe { (&*self.ptr()).get_bbc_meta() }
}
fn get_bbb_meta(&self) -> &BbbMeta<Primes30> {
unsafe { (&*self.ptr()).get_bbb_meta() }
}
}
/// Views limb `limb` of column `col` of a q120b DFT vector as a flat `u64` slice.
///
/// This is a zero-copy `bytemuck` reinterpretation of `v.at(col, limb)`.
#[inline(always)]
fn limb_u64<D: crate::layouts::HostDataRef, BE: Backend<ScalarPrep = Q120bScalar>>(
    v: &VecZnxDft<D, BE>,
    col: usize,
    limb: usize,
) -> &[u64] {
    let scalars = v.at(col, limb);
    cast_slice::<_, u64>(scalars)
}
/// Mutable counterpart of `limb_u64`: views limb `limb` of column `col` as `&mut [u64]`.
///
/// This is a zero-copy `bytemuck` reinterpretation of `v.at_mut(col, limb)`.
#[inline(always)]
fn limb_u64_mut<D: crate::layouts::HostDataMut, BE: Backend<ScalarPrep = Q120bScalar>>(
    v: &mut VecZnxDft<D, BE>,
    col: usize,
    limb: usize,
) -> &mut [u64] {
    let scalars = v.at_mut(col, limb);
    cast_slice_mut::<_, u64>(scalars)
}
/// Forward NTT of a strided selection of limbs of `a` (column `a_col`) into `res` (column `res_col`).
///
/// Output limb `j` is the transform of input limb `offset + j * step` when
/// `j < ceil(a.size() / step)` and that limb exists in `a`. Every other output limb is zeroed.
///
/// # Panics
///
/// Panics if `step == 0` (division by zero in `div_ceil`).
pub fn ntt120_vec_znx_dft_apply<BE>(
    module: &impl NttModuleHandle,
    step: usize,
    offset: usize,
    res: &mut VecZnxDftBackendMut<'_, BE>,
    res_col: usize,
    a: &VecZnxBackendRef<'_, BE>,
    a_col: usize,
) where
    BE: Backend<ScalarPrep = Q120bScalar> + NttDFTExecute<NttTable<Primes30>> + NttFromZnx64 + NttZero + 'static,
    for<'x> BE: Backend<BufRef<'x> = &'x [u8], BufMut<'x> = &'x mut [u8]>,
{
    let src_size = a.size();
    let dst_size = res.size();
    let table = module.get_ntt_table();
    // Output limbs beyond this count can never map to an input limb.
    let candidate_steps = src_size.div_ceil(step);

    for j in 0..dst_size {
        let src_limb = (j < candidate_steps)
            .then(|| offset + j * step)
            .filter(|&limb| limb < src_size);
        match src_limb {
            Some(limb) => {
                let dst = limb_u64_mut(res, res_col, j);
                BE::ntt_from_znx64(dst, a.at(a_col, limb));
                BE::ntt_dft_execute(table, dst);
            }
            None => BE::ntt_zero(limb_u64_mut(res, res_col, j)),
        }
    }
}
/// Size in bytes of the scratch buffer used by [`ntt120_vec_znx_idft_apply`] for degree `n`.
///
/// The scratch holds one q120b limb, which is `4 * n` `u64` words.
pub fn ntt120_vec_znx_idft_apply_tmp_bytes(n: usize) -> usize {
    let words_per_limb = n * 4;
    words_per_limb * size_of::<u64>()
}
/// Inverse NTT of column `a_col` of `a` into column `res_col` of the `i128` vector `res`.
///
/// Each of the first `min(res.size(), a.size())` limbs is copied into `tmp`, inverse-transformed
/// there and converted to `i128` coefficients. `a` is left untouched, and the remaining limbs of
/// `res` are zero-filled.
///
/// # Panics
///
/// Panics if at least one limb is transformed and `tmp.len() < 4 * res.n()`.
pub fn ntt120_vec_znx_idft_apply<BE>(
    module: &impl NttModuleHandle,
    res: &mut VecZnxBigBackendMut<'_, BE>,
    res_col: usize,
    a: &VecZnxDftBackendRef<'_, BE>,
    a_col: usize,
    tmp: &mut [u64],
) where
    BE: Backend<ScalarPrep = Q120bScalar, ScalarBig = i128> + NttDFTExecute<NttTableInv<Primes30>> + NttToZnx128 + NttCopy,
    for<'x> <BE as Backend>::BufMut<'x>: HostDataMut,
    for<'x> <BE as Backend>::BufRef<'x>: HostDataRef,
{
    let n = res.n();
    let res_size = res.size();
    let transformed = res_size.min(a.size());
    let table = module.get_intt_table();

    for j in 0..res_size {
        if j >= transformed {
            res.at_mut(res_col, j).fill(0i128);
            continue;
        }
        // Sliced per iteration so an undersized `tmp` only panics when it is actually needed.
        let scratch: &mut [u64] = &mut tmp[..4 * n];
        BE::ntt_copy(scratch, limb_u64(a, a_col, j));
        BE::ntt_dft_execute(table, scratch);
        BE::ntt_to_znx128(res.at_mut(res_col, j), n, scratch);
    }
}
/// Inverse NTT of column `a_col` of `a` into column `res_col` of `res`, using `a` as scratch.
///
/// The first `min(res.size(), a.size())` limbs of `a` are inverse-transformed in place, so
/// they are overwritten, and then converted to `i128` coefficients in `res`. The remaining
/// limbs of `res` are zero-filled.
pub fn ntt120_vec_znx_idft_apply_tmpa<BE>(
    module: &impl NttModuleHandle,
    res: &mut VecZnxBigBackendMut<'_, BE>,
    res_col: usize,
    a: &mut VecZnxDftBackendMut<'_, BE>,
    a_col: usize,
) where
    BE: Backend<ScalarPrep = Q120bScalar, ScalarBig = i128> + NttDFTExecute<NttTableInv<Primes30>> + NttToZnx128,
    for<'x> <BE as Backend>::BufMut<'x>: HostDataMut,
{
    let n = res.n();
    let res_size = res.size();
    let transformed = res_size.min(a.size());
    let table = module.get_intt_table();

    for j in 0..transformed {
        let limb = limb_u64_mut(a, a_col, j);
        BE::ntt_dft_execute(table, limb);
        BE::ntt_to_znx128(res.at_mut(res_col, j), n, &*limb);
    }
    for j in transformed..res_size {
        res.at_mut(res_col, j).fill(0i128);
    }
}
/// Consumes a DFT vector and turns it, in place, into a big-coefficient (`i128`) vector.
///
/// The backing buffer is treated as `a.cols() * a.size()` consecutive blocks.
/// `compact_all_blocks_scalar` inverse-transforms each block and writes its CRT-reconstructed
/// `i128` coefficients back into the front of the same buffer. The buffer is then reinterpreted
/// via `into_big`.
///
/// NOTE(review): this assumes `raw_mut()` exposes exactly those blocks, each `4 * n` `u64`
/// words, contiguously. It also assumes `into_big` expects the compacted `2 * n`-words-per-block
/// layout that `compact_all_blocks_scalar` writes. TODO confirm against the `VecZnxDft` layout.
// NOTE(review): no caller is visible in this chunk; document why this is kept behind `dead_code`.
#[allow(dead_code)]
pub fn ntt120_vec_znx_idft_apply_consume<'a, BE>(
    module: &impl NttModuleHandle,
    mut a: VecZnxDftBackendMut<'a, BE>,
) -> VecZnxBigBackendMut<'a, BE>
where
    BE: Backend<ScalarPrep = Q120bScalar, ScalarBig = i128>,
    for<'x> <BE as Backend>::BufMut<'x>: HostDataMut,
{
    let table = module.get_intt_table();
    // Read the geometry and derive a raw pointer inside a nested scope, so that no borrow of `a`
    // is still held while the pointer is used and `a` is later consumed by `into_big`.
    let (n, n_blocks, u64_ptr) = {
        let n = a.n();
        let n_blocks = a.cols() * a.size();
        let ptr: *mut u64 = {
            let s = a.raw_mut();
            cast_slice_mut::<_, u64>(s).as_mut_ptr()
        };
        (n, n_blocks, ptr)
    };
    // SAFETY: `u64_ptr` was derived from the whole raw buffer of `a`, which is not accessed
    // any other way until `into_big`. This relies on the layout assumption above: the buffer
    // must hold at least `4 * n * n_blocks` `u64` words. TODO confirm.
    unsafe { compact_all_blocks_scalar(n, n_blocks, u64_ptr, table) };
    a.into_big()
}
/// Reduces `x` modulo `q` with a Barrett estimate; `mu` must be `floor(2^61 / q)`.
///
/// The estimated quotient `floor(x * mu / 2^61)` never exceeds `floor(x / q)`, so the
/// subtraction cannot underflow. For `x < 2^62` the estimate is short by at most two, so two
/// conditional corrections bring the remainder into `[0, q)`.
// Only reached via the `dead_code`-allowed compaction path in this file.
#[allow(dead_code)]
#[inline(always)]
fn barrett_u61(x: u64, q: u64, mu: u64) -> u64 {
    let estimate = ((u128::from(x) * u128::from(mu)) >> 61) as u64;
    let mut rem = x - estimate * q;
    for _ in 0..2 {
        if rem >= q {
            rem -= q;
        }
    }
    rem
}
/// Returns `(x * crt) mod q` for a full 64-bit `x`, given `pow32_crt ≡ 2^32·crt` and
/// `pow16_crt ≡ 2^16·crt (mod q)`.
///
/// `x` is split as `hi·2^32 + mid·2^16 + lo`, and each piece is multiplied by its precomputed
/// constant. `hi` is first reduced by one conditional subtraction of `q` to keep the sum small.
/// NOTE(review): exactness relies on the sum staying below 2^62 (the range where `barrett_u61`
/// is exact). That holds when the constants are below `q` and `q` is roughly 2^30, as for
/// `Primes30`; TODO confirm for every prime.
// Only reached via the `dead_code`-allowed compaction path in this file.
#[allow(dead_code)]
#[inline(always)]
fn reduce_q120b_crt(x: u64, q: u64, mu: u64, pow32_crt: u64, pow16_crt: u64, crt: u64) -> u64 {
    let hi = match x >> 32 {
        h if h >= q => h - q,
        h => h,
    };
    let mid = (x >> 16) & 0xFFFF;
    let lo = x & 0xFFFF;

    let hi_term = hi.wrapping_mul(pow32_crt);
    let mid_term = mid.wrapping_mul(pow16_crt);
    let lo_term = lo.wrapping_mul(crt);
    barrett_u61(hi_term.wrapping_add(mid_term).wrapping_add(lo_term), q, mu)
}
/// Inverse-NTTs `n_blocks` consecutive q120b blocks in place and compacts each one into `n`
/// signed `i128` coefficients.
///
/// Block `k` is read from `u64` words `[4·n·k, 4·n·(k+1))`, which hold four residues per
/// coefficient, one per `Primes30` prime. First `intt_ref` is run over the block. Then each
/// coefficient `c` is CRT-reconstructed, centered, and written as an unaligned `i128` at word
/// offset `2·n·k + 2·c`.
///
/// Blocks are processed in increasing `k`, so the writes never clobber unread input:
/// * For `k = 0`, coefficient `c` writes words `2c, 2c+1` after reading `4c..4c+3`. All later
///   reads are at `>= 4(c+1)`.
/// * For `k >= 1`, the output range `[2nk, 2nk + 2n)` lies inside `[0, 4nk)`, the input of
///   blocks that were already processed.
///
/// Centering maps a reconstructed `v` in `[0, Q)` to `v - Q` when `v >= ceil(Q/2)`.
///
/// # Safety
///
/// `u64_ptr` must be valid for reads and writes of `4 * n * n_blocks` `u64` words, and that
/// memory must not be accessed through any other reference during the call.
/// NOTE(review): `table` presumably has to be built for degree `n`; TODO confirm.
// Only reached from `ntt120_vec_znx_idft_apply_consume`, which is itself `dead_code`-allowed.
#[allow(dead_code)]
unsafe fn compact_all_blocks_scalar(n: usize, n_blocks: usize, u64_ptr: *mut u64, table: &NttTableInv<Primes30>) {
    let q_u64: [u64; 4] = Primes30::Q.map(|qi| qi as u64);
    // Barrett constants floor(2^61 / q_i) for `barrett_u61`.
    let mu: [u64; 4] = q_u64.map(|qi| (1u64 << 61) / qi);
    // NOTE(review): presumably CRT_CST[i] = (Q / q_i)^{-1} mod q_i, as required by the
    // reconstruction below. TODO confirm.
    let crt: [u64; 4] = Primes30::CRT_CST.map(|c| c as u64);
    // (2^32 mod q_i) · crt_i, reduced mod q_i.
    let pow32_crt: [u64; 4] = std::array::from_fn(|k| {
        let pow32 = ((1u128 << 32) % q_u64[k] as u128) as u64;
        barrett_u61(pow32 * crt[k], q_u64[k], mu[k])
    });
    // 2^16 · crt_i, reduced mod q_i.
    let pow16_crt: [u64; 4] = std::array::from_fn(|k| barrett_u61((1u64 << 16) * crt[k], q_u64[k], mu[k]));
    let q: [u128; 4] = q_u64.map(|qi| qi as u128);
    // Q = q_0·q_1·q_2·q_3, and qm[i] = Q / q_i.
    let total_q: u128 = q[0] * q[1] * q[2] * q[3];
    let qm: [u128; 4] = [q[1] * q[2] * q[3], q[0] * q[2] * q[3], q[0] * q[1] * q[3], q[0] * q[1] * q[2]];
    // Centering threshold ceil(Q/2).
    let half_q: u128 = total_q.div_ceil(2);
    // Multiples of Q, indexed by the quotient estimate `v >> 120`.
    let total_q_mult: [u128; 4] = [0, total_q, total_q * 2, total_q * 3];
    for k in 0..n_blocks {
        let src_start = 4 * n * k;
        let dst_start = 2 * n * k;
        {
            // SAFETY: words [src_start, src_start + 4n) are inside the caller-guaranteed
            // region, and this slice is dropped before any raw access below.
            let blk: &mut [u64] = unsafe { std::slice::from_raw_parts_mut(u64_ptr.add(src_start), 4 * n) };
            intt_ref::<Primes30>(table, blk);
        }
        for c in 0..n {
            // SAFETY: indices are < src_start + 4n and inside the caller-guaranteed region;
            // they have not been overwritten yet (see the overlap argument in the doc comment).
            let (x0, x1, x2, x3) = unsafe {
                (
                    *u64_ptr.add(src_start + 4 * c),
                    *u64_ptr.add(src_start + 4 * c + 1),
                    *u64_ptr.add(src_start + 4 * c + 2),
                    *u64_ptr.add(src_start + 4 * c + 3),
                )
            };
            // t_i = x_i · crt_i mod q_i.
            let t0 = reduce_q120b_crt(x0, q_u64[0], mu[0], pow32_crt[0], pow16_crt[0], crt[0]);
            let t1 = reduce_q120b_crt(x1, q_u64[1], mu[1], pow32_crt[1], pow16_crt[1], crt[1]);
            let t2 = reduce_q120b_crt(x2, q_u64[2], mu[2], pow32_crt[2], pow16_crt[2], crt[2]);
            let t3 = reduce_q120b_crt(x3, q_u64[3], mu[3], pow32_crt[3], pow16_crt[3], crt[3]);
            // Sum of t_i · Q/q_i. With every t_i < q_i this is < 4Q.
            let mut v: u128 = t0 as u128 * qm[0] + t1 as u128 * qm[1] + t2 as u128 * qm[2] + t3 as u128 * qm[3];
            // NOTE(review): indexing needs `v >> 120 <= 3`, i.e. Q < 2^120 (each prime < 2^30).
            // The single correction below suffices when Q > 0.8·2^120. TODO confirm for Primes30.
            let q_approx = (v >> 120) as usize;
            v -= total_q_mult[q_approx];
            if v >= total_q {
                v -= total_q;
            }
            // Center v in [0, Q) into the symmetric signed range.
            let val: i128 = if v >= half_q { v as i128 - total_q as i128 } else { v as i128 };
            // SAFETY: words dst_start + 2c and +1 are inside the region and already consumed;
            // `write_unaligned` is used because the pointer is only `u64`-aligned.
            unsafe { (u64_ptr.add(dst_start + 2 * c) as *mut i128).write_unaligned(val) };
        }
    }
}
/// Writes `a[a_col] + b[b_col]` into `res[res_col]`, limb by limb.
///
/// Limbs present in both operands are summed. Limbs present only in the longer operand are
/// copied from it. Any remaining limbs of `res` are zeroed.
pub fn ntt120_vec_znx_dft_add_into<BE>(
    res: &mut VecZnxDftBackendMut<'_, BE>,
    res_col: usize,
    a: &VecZnxDftBackendRef<'_, BE>,
    a_col: usize,
    b: &VecZnxDftBackendRef<'_, BE>,
    b_col: usize,
) where
    BE: Backend<ScalarPrep = Q120bScalar> + NttAdd + NttCopy + NttZero,
    for<'x> <BE as Backend>::BufMut<'x>: HostDataMut,
    for<'x> <BE as Backend>::BufRef<'x>: HostDataRef,
{
    let res_size = res.size();
    let a_size = a.size();
    let b_size = b.size();
    let both = a_size.min(b_size).min(res_size);
    let either = a_size.max(b_size).min(res_size);
    // On a tie `both == either`, so the choice of "longer" operand never matters.
    let b_is_longer = a_size <= b_size;

    for j in 0..res_size {
        let dst = limb_u64_mut(res, res_col, j);
        if j < both {
            BE::ntt_add(dst, limb_u64(a, a_col, j), limb_u64(b, b_col, j));
        } else if j < either {
            let src = if b_is_longer { limb_u64(b, b_col, j) } else { limb_u64(a, a_col, j) };
            BE::ntt_copy(dst, src);
        } else {
            BE::ntt_zero(dst);
        }
    }
}
/// Adds `a[a_col]` into `res[res_col]` over the limbs the two vectors share.
///
/// Limbs of `res` beyond `a.size()` are left unchanged.
pub fn ntt120_vec_znx_dft_add_assign<BE>(
    res: &mut VecZnxDftBackendMut<'_, BE>,
    res_col: usize,
    a: &VecZnxDftBackendRef<'_, BE>,
    a_col: usize,
) where
    BE: Backend<ScalarPrep = Q120bScalar> + NttAddAssign,
    for<'x> <BE as Backend>::BufMut<'x>: HostDataMut,
    for<'x> <BE as Backend>::BufRef<'x>: HostDataRef,
{
    let shared = a.size().min(res.size());
    (0..shared).for_each(|limb| BE::ntt_add_assign(limb_u64_mut(res, res_col, limb), limb_u64(a, a_col, limb)));
}
/// Adds `a[a_col]`, shifted by `a_scale` limbs, into `res[res_col]`.
///
/// * `a_scale > 0`: `res[j] += a[j + shift]`, with `shift = min(a_scale, a.size())`.
/// * `a_scale < 0`: `res[j + shift] += a[j]`, with `shift = min(|a_scale|, res.size())`.
/// * `a_scale == 0`: `res[j] += a[j]`.
///
/// Only in-bounds limbs are touched; nothing is zeroed.
///
/// NOTE(review): for `a_scale > 0` the summed count is `min(a.size(), res.size()) - shift`,
/// although `min(res.size(), a.size() - shift)` would also stay in bounds. When
/// `a.size() > res.size()`, some top limbs of `res` therefore receive no contribution. The
/// negative branch uses the tight bound. Confirm this asymmetry is intended (e.g. to match
/// other backends).
pub fn ntt120_vec_znx_dft_add_scaled_assign<BE>(
    res: &mut VecZnxDftBackendMut<'_, BE>,
    res_col: usize,
    a: &VecZnxDftBackendRef<'_, BE>,
    a_col: usize,
    a_scale: i64,
) where
    BE: Backend<ScalarPrep = Q120bScalar> + NttAddAssign,
    for<'x> <BE as Backend>::BufMut<'x>: HostDataMut,
    for<'x> <BE as Backend>::BufRef<'x>: HostDataRef,
{
    let res_size = res.size();
    let a_size = a.size();

    // (first limb of res, first limb of a, number of limbs summed)
    let (res_off, a_off, count) = match a_scale.signum() {
        1 => {
            let shift = (a_scale as usize).min(a_size);
            (0, shift, a_size.min(res_size).saturating_sub(shift))
        }
        -1 => {
            let shift = (a_scale.unsigned_abs() as usize).min(res_size);
            (shift, 0, a_size.min(res_size.saturating_sub(shift)))
        }
        _ => (0, 0, a_size.min(res_size)),
    };

    for j in 0..count {
        BE::ntt_add_assign(limb_u64_mut(res, res_col, res_off + j), limb_u64(a, a_col, a_off + j));
    }
}
/// Writes `a[a_col] - b[b_col]` into `res[res_col]`, limb by limb.
///
/// Limbs present in both operands are subtracted. Limbs present only in `b` become `-b`, and
/// limbs present only in `a` are copied. Any remaining limbs of `res` are zeroed.
pub fn ntt120_vec_znx_dft_sub<BE>(
    res: &mut VecZnxDftBackendMut<'_, BE>,
    res_col: usize,
    a: &VecZnxDftBackendRef<'_, BE>,
    a_col: usize,
    b: &VecZnxDftBackendRef<'_, BE>,
    b_col: usize,
) where
    BE: Backend<ScalarPrep = Q120bScalar> + NttSub + NttNegate + NttCopy + NttZero,
    for<'x> <BE as Backend>::BufMut<'x>: HostDataMut,
    for<'x> <BE as Backend>::BufRef<'x>: HostDataRef,
{
    let res_size = res.size();
    let a_size = a.size();
    let b_size = b.size();
    let both = a_size.min(b_size).min(res_size);
    let either = a_size.max(b_size).min(res_size);

    for j in 0..res_size {
        let dst = limb_u64_mut(res, res_col, j);
        if j < both {
            BE::ntt_sub(dst, limb_u64(a, a_col, j), limb_u64(b, b_col, j));
        } else if j >= either {
            BE::ntt_zero(dst);
        } else if a_size <= b_size {
            // Only `b` has this limb: 0 - b.
            BE::ntt_negate(dst, limb_u64(b, b_col, j));
        } else {
            // Only `a` has this limb: a - 0.
            BE::ntt_copy(dst, limb_u64(a, a_col, j));
        }
    }
}
/// Subtracts `a[a_col]` from `res[res_col]` over the limbs the two vectors share.
///
/// Limbs of `res` beyond `a.size()` are left unchanged.
pub fn ntt120_vec_znx_dft_sub_assign<BE>(
    res: &mut VecZnxDftBackendMut<'_, BE>,
    res_col: usize,
    a: &VecZnxDftBackendRef<'_, BE>,
    a_col: usize,
) where
    BE: Backend<ScalarPrep = Q120bScalar> + NttSubAssign,
    for<'x> <BE as Backend>::BufMut<'x>: HostDataMut,
    for<'x> <BE as Backend>::BufRef<'x>: HostDataRef,
{
    let shared = a.size().min(res.size());
    (0..shared).for_each(|limb| BE::ntt_sub_assign(limb_u64_mut(res, res_col, limb), limb_u64(a, a_col, limb)));
}
/// Applies `ntt_sub_negate_assign` with `a[a_col]` on the limbs shared with `res[res_col]`,
/// and negates the remaining limbs of `res`.
///
/// Presumably this computes `res = a - res`, treating missing limbs of `a` as zero; that is
/// consistent with the tail being simply negated. TODO confirm against `NttSubNegateAssign`.
pub fn ntt120_vec_znx_dft_sub_negate_assign<BE>(
    res: &mut VecZnxDftBackendMut<'_, BE>,
    res_col: usize,
    a: &VecZnxDftBackendRef<'_, BE>,
    a_col: usize,
) where
    BE: Backend<ScalarPrep = Q120bScalar> + NttSubNegateAssign + NttNegateAssign,
    for<'x> <BE as Backend>::BufMut<'x>: HostDataMut,
    for<'x> <BE as Backend>::BufRef<'x>: HostDataRef,
{
    let res_size = res.size();
    let overlap = res_size.min(a.size());

    for j in 0..res_size {
        let dst = limb_u64_mut(res, res_col, j);
        if j < overlap {
            BE::ntt_sub_negate_assign(dst, limb_u64(a, a_col, j));
        } else {
            BE::ntt_negate_assign(dst);
        }
    }
}
/// Copies a strided selection of limbs of `a[a_col]` into `res[res_col]`.
///
/// Output limb `j` receives input limb `offset + j * step` when `j < ceil(a.size() / step)` and
/// that limb exists in `a`. Every other output limb is zeroed. In debug builds, asserts that
/// both vectors have the same `n`.
///
/// # Panics
///
/// Panics if `step == 0` (division by zero in `div_ceil`).
pub fn ntt120_vec_znx_dft_copy<BE>(
    step: usize,
    offset: usize,
    res: &mut VecZnxDftBackendMut<'_, BE>,
    res_col: usize,
    a: &VecZnxDftBackendRef<'_, BE>,
    a_col: usize,
) where
    BE: Backend<ScalarPrep = Q120bScalar> + NttCopy + NttZero,
    for<'x> <BE as Backend>::BufMut<'x>: HostDataMut,
    for<'x> <BE as Backend>::BufRef<'x>: HostDataRef,
{
    debug_assert_eq!(res.n(), a.n());

    let src_size = a.size();
    let dst_size = res.size();
    let candidate_steps = src_size.div_ceil(step);

    for j in 0..dst_size {
        let dst = limb_u64_mut(res, res_col, j);
        let src_limb = (j < candidate_steps)
            .then(|| offset + j * step)
            .filter(|&limb| limb < src_size);
        match src_limb {
            Some(limb) => BE::ntt_copy(dst, limb_u64(a, a_col, limb)),
            None => BE::ntt_zero(dst),
        }
    }
}
/// Zeroes every limb of column `res_col` of `res`.
pub fn ntt120_vec_znx_dft_zero<BE>(res: &mut VecZnxDftBackendMut<'_, BE>, res_col: usize)
where
    BE: Backend<ScalarPrep = Q120bScalar> + NttZero,
    for<'x> <BE as Backend>::BufMut<'x>: HostDataMut,
{
    (0..res.size()).for_each(|limb| BE::ntt_zero(limb_u64_mut(res, res_col, limb)));
}