use std::ptr::NonNull;
use poulpy_hal::{
layouts::{Backend, Module},
oep::ModuleNewImpl,
reference::ntt120::{
mat_vec::{BbbMeta, BbcMeta},
ntt::{NttTable, NttTableInv},
primes::Primes30,
types::Q120bScalar,
vec_znx_dft::NttHandleProvider,
},
};
use super::NTT120Avx;
/// Per-module state owned by the [`NTT120Avx`] backend.
///
/// An instance is created and heap-allocated by `ModuleNewImpl::new_impl`
/// and exposed to the generic ntt120 code through [`NttHandleProvider`].
///
/// NOTE(review): `#[repr(C)]` presumably pins the field layout because the
/// handle travels as a raw `NonNull` pointer inside `Module` — confirm whether
/// anything actually depends on this field order before reordering.
#[repr(C)]
pub struct NTT120AvxHandle {
    /// Forward NTT table over `Primes30`, built for the module's ring degree `n`.
    table_ntt: NttTable<Primes30>,
    /// Inverse NTT table over `Primes30`, built for the module's ring degree `n`.
    table_intt: NttTableInv<Primes30>,
    /// `BbcMeta` metadata from `reference::ntt120::mat_vec` (constructed without parameters).
    meta_bbc: BbcMeta<Primes30>,
    /// `BbbMeta` metadata from `reference::ntt120::mat_vec` (constructed without parameters).
    meta_bbb: BbbMeta<Primes30>,
}
impl Backend for NTT120Avx {
    type ScalarPrep = Q120bScalar;
    type ScalarBig = i128;
    type Handle = NTT120AvxHandle;

    /// Frees a handle that was previously leaked from a `Box<NTT120AvxHandle>`.
    ///
    /// # Safety
    ///
    /// `handle` must come from a `Box<NTT120AvxHandle>` (e.g. via `Box::leak`
    /// or `Box::into_raw`). It must not have been freed already and must not
    /// be used again after this call.
    unsafe fn destroy(handle: NonNull<Self::Handle>) {
        // SAFETY: the caller guarantees `handle` is the unique, still-live
        // pointer produced from a `Box<Self::Handle>` allocation, which is
        // exactly the precondition of `Box::from_raw`.
        let owned: Box<Self::Handle> = unsafe { Box::from_raw(handle.as_ptr()) };
        drop(owned);
    }
}
unsafe impl ModuleNewImpl<Self> for NTT120Avx {
    /// Builds a [`Module`] of ring degree `n` backed by a heap-allocated
    /// [`NTT120AvxHandle`].
    ///
    /// The handle is leaked out of its `Box` and the raw pointer is handed to
    /// the returned `Module`. It is presumably released later through
    /// `Backend::destroy`, which rebuilds the `Box` and drops it.
    ///
    /// # Panics
    ///
    /// - If the running CPU does not report AVX2 support.
    /// - If `n` does not fit in `usize` on the current target.
    fn new_impl(n: u64) -> Module<Self> {
        // Refuse to build the module on a CPU without AVX2 support.
        assert!(
            std::arch::is_x86_feature_detected!("avx2"),
            "arch must support avx2"
        );

        // Convert once, with a checked conversion. A silent `as` truncation
        // would build the tables for a different degree than the `n` that is
        // stored in the `Module` below.
        let degree = usize::try_from(n).expect("ring degree n must fit in usize");

        let handle = NTT120AvxHandle {
            table_ntt: NttTable::new(degree),
            table_intt: NttTableInv::new(degree),
            meta_bbc: BbcMeta::new(),
            meta_bbb: BbbMeta::new(),
        };
        let ptr: NonNull<NTT120AvxHandle> = NonNull::from(Box::leak(Box::new(handle)));

        // SAFETY: `ptr` comes from a freshly leaked `Box`, so it is non-null,
        // properly aligned and uniquely owned. The tables were built for the
        // same degree `n` that is passed alongside it.
        // NOTE(review): confirm this matches `Module::from_nonnull`'s
        // documented contract.
        unsafe { Module::from_nonnull(ptr, n) }
    }
}
// NOTE(review): the safety contract of `NttHandleProvider` is defined in
// `poulpy_hal`. This impl only hands out shared references to the tables and
// metadata built in `new_impl` — confirm that this satisfies the trait's
// requirements.
unsafe impl NttHandleProvider for NTT120AvxHandle {
    /// Borrows the forward NTT table held by this handle.
    fn get_ntt_table(&self) -> &NttTable<Primes30> {
        let Self { table_ntt, .. } = self;
        table_ntt
    }

    /// Borrows the inverse NTT table held by this handle.
    fn get_intt_table(&self) -> &NttTableInv<Primes30> {
        let Self { table_intt, .. } = self;
        table_intt
    }

    /// Borrows the `BbcMeta` metadata held by this handle.
    fn get_bbc_meta(&self) -> &BbcMeta<Primes30> {
        let Self { meta_bbc, .. } = self;
        meta_bbc
    }

    /// Borrows the `BbbMeta` metadata held by this handle.
    fn get_bbb_meta(&self) -> &BbbMeta<Primes30> {
        let Self { meta_bbb, .. } = self;
        meta_bbb
    }
}