use std::{
fmt::{Debug, Display},
marker::PhantomData,
ptr::NonNull,
};
use bytemuck::Pod;
use rand_distr::num_traits::Zero;
use crate::layouts::{Data, Location, MatZnx, ScalarZnx, VecZnx};
use crate::{
GALOISGENERATOR,
api::{ModuleLogN, ModuleN},
};
/// Compile-time description of a compute backend: its scalar types, its
/// buffer types, and the associated functions used to allocate, copy and
/// slice those buffers.
///
/// The summaries on the required items below are read off their signatures;
/// the precise contract of each one is set by the implementor.
#[allow(clippy::missing_safety_doc)]
pub trait Backend: Sized + Sync + Send {
    /// Scalar whose size drives [`Backend::bytes_of_vec_znx_big`].
    type ScalarBig: Copy + Zero + Display + Debug + Pod;
    /// Scalar whose size drives the DFT / prepared-layout byte counts
    /// (`bytes_of_vec_znx_dft`, `bytes_of_svp_ppol`, `bytes_of_vmp_pmat`,
    /// `bytes_of_cnv_pvec_*`).
    type ScalarPrep: Copy + Zero + Display + Debug + Pod;
    /// Owned buffer type.
    type OwnedBuf: Data + Send + Sync;
    /// Borrowed, shared buffer view.
    type BufRef<'a>: Data + Sync
    where
        Self: 'a;
    /// Borrowed, exclusive buffer view.
    type BufMut<'a>: Data + Send
    where
        Self: 'a;
    /// Opaque backend handle type; `Module` stores a `NonNull` pointer to it.
    type Handle: 'static;
    /// Location marker type for this backend.
    type Location: Location;

    /// Default value is 64. NOTE(review): presumably the byte alignment of
    /// scratch buffers; confirm against the users of this constant.
    const SCRATCH_ALIGN: usize = 64;

    /// Allocates an owned buffer for `len`; the initial contents are not
    /// specified by this trait (see [`Backend::alloc_zeroed_bytes`]).
    fn alloc_bytes(len: usize) -> Self::OwnedBuf;

    /// Allocates a buffer with [`Backend::alloc_bytes`] and then overwrites
    /// it with `len` zero bytes through [`Backend::copy_from_host`].
    fn alloc_zeroed_bytes(len: usize) -> Self::OwnedBuf {
        let mut owned = Self::alloc_bytes(len);
        // Build the zeros in a temporary `Vec`, then reuse the generic copy path.
        let host_zeros = vec![0u8; len];
        Self::copy_from_host(&mut owned, host_zeros.as_slice());
        owned
    }

    /// Builds an owned buffer from a borrowed byte slice.
    fn from_host_bytes(bytes: &[u8]) -> Self::OwnedBuf;
    /// Builds an owned buffer from an owned byte vector.
    fn from_bytes(bytes: Vec<u8>) -> Self::OwnedBuf;
    /// Returns the buffer contents as a byte vector.
    fn to_host_bytes(buf: &Self::OwnedBuf) -> Vec<u8>;
    /// Copies from `buf` into `dst`.
    fn copy_to_host(buf: &Self::OwnedBuf, dst: &mut [u8]);
    /// Copies from `src` into `buf`.
    fn copy_from_host(buf: &mut Self::OwnedBuf, src: &[u8]);
    /// Returns the length of `buf`, presumably in bytes (per the name).
    fn len_bytes(buf: &Self::OwnedBuf) -> usize;

    /// Shared view of an owned buffer.
    fn view(buf: &Self::OwnedBuf) -> Self::BufRef<'_>;
    /// Shared view re-borrowed from a shared view, with the shorter lifetime `'a`.
    fn view_ref<'a, 'b>(buf: &'a Self::BufRef<'b>) -> Self::BufRef<'a>
    where
        Self: 'b;
    /// Shared view obtained from a shared borrow of an exclusive view.
    fn view_ref_mut<'a, 'b>(buf: &'a Self::BufMut<'b>) -> Self::BufRef<'a>
    where
        Self: 'b;
    /// Exclusive view re-borrowed from an exclusive view, with the shorter lifetime `'a`.
    fn view_mut_ref<'a, 'b>(buf: &'a mut Self::BufMut<'b>) -> Self::BufMut<'a>
    where
        Self: 'b;
    /// Exclusive view of an owned buffer.
    fn view_mut(buf: &mut Self::OwnedBuf) -> Self::BufMut<'_>;

    /// Shared view selected by `offset` and `len` within an owned buffer.
    /// NOTE(review): units and out-of-range behaviour are implementor-defined.
    fn region(buf: &Self::OwnedBuf, offset: usize, len: usize) -> Self::BufRef<'_>;
    /// Exclusive counterpart of [`Backend::region`].
    fn region_mut(buf: &mut Self::OwnedBuf, offset: usize, len: usize) -> Self::BufMut<'_>;
    /// Shared sub-view of a shared view.
    fn region_ref<'a, 'b>(buf: &'a Self::BufRef<'b>, offset: usize, len: usize) -> Self::BufRef<'a>
    where
        Self: 'b;
    /// Shared sub-view obtained from a shared borrow of an exclusive view.
    fn region_ref_mut<'a, 'b>(buf: &'a Self::BufMut<'b>, offset: usize, len: usize) -> Self::BufRef<'a>
    where
        Self: 'b;
    /// Exclusive sub-view of an exclusive view.
    fn region_mut_ref<'a, 'b>(buf: &'a mut Self::BufMut<'b>, offset: usize, len: usize) -> Self::BufMut<'a>
    where
        Self: 'b;

    /// `size_of::<Self::ScalarBig>()`.
    fn size_of_scalar_big() -> usize {
        std::mem::size_of::<Self::ScalarBig>()
    }

    /// `size_of::<Self::ScalarPrep>()`.
    fn size_of_scalar_prep() -> usize {
        std::mem::size_of::<Self::ScalarPrep>()
    }

    /// `n * cols * size` scalars of [`Backend::size_of_scalar_prep`] bytes each.
    fn bytes_of_vec_znx_dft(n: usize, cols: usize, size: usize) -> usize {
        let scalars = n * cols * size;
        scalars * Self::size_of_scalar_prep()
    }

    /// `n * cols * size` scalars of [`Backend::size_of_scalar_big`] bytes each.
    fn bytes_of_vec_znx_big(n: usize, cols: usize, size: usize) -> usize {
        let scalars = n * cols * size;
        scalars * Self::size_of_scalar_big()
    }

    /// `n * cols` scalars of [`Backend::size_of_scalar_prep`] bytes each.
    fn bytes_of_svp_ppol(n: usize, cols: usize) -> usize {
        let scalars = n * cols;
        scalars * Self::size_of_scalar_prep()
    }

    /// `n * rows * cols_in * cols_out * size` scalars of
    /// [`Backend::size_of_scalar_prep`] bytes each.
    fn bytes_of_vmp_pmat(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
        let scalars = n * rows * cols_in * cols_out * size;
        scalars * Self::size_of_scalar_prep()
    }

    /// `n * cols * size` scalars of [`Backend::size_of_scalar_prep`] bytes each.
    fn bytes_of_cnv_pvec_left(n: usize, cols: usize, size: usize) -> usize {
        let scalars = n * cols * size;
        scalars * Self::size_of_scalar_prep()
    }

    /// `n * cols * size` scalars of [`Backend::size_of_scalar_prep`] bytes each.
    fn bytes_of_cnv_pvec_right(n: usize, cols: usize, size: usize) -> usize {
        let scalars = n * cols * size;
        scalars * Self::size_of_scalar_prep()
    }

    /// Releases a backend handle.
    ///
    /// # Safety
    ///
    /// The contract is not spelled out in this file. `Module`'s `Drop` impl
    /// calls this once with the handle it owns; presumably `handle` must be a
    /// live handle belonging to this backend and must not be used afterwards.
    /// TODO confirm with the backend implementations.
    unsafe fn destroy(handle: NonNull<Self::Handle>);
}
/// Wrapper around a backend handle pointer together with the dimension `n`.
///
/// The handle is passed to `B::destroy` when the value is dropped. The
/// `#[repr(C)]` layout matters: `reinterpret` / `reinterpret_mut` cast a
/// `Module<B>` reference to a `Module<Other>` reference with the same
/// `Handle` type, so the field order must not change.
#[repr(C)]
pub struct Module<B: Backend> {
// Non-null pointer to the backend handle.
ptr: NonNull<B::Handle>,
// Dimension; the `from_nonnull` / `from_raw_parts` constructors assert it is a power of two.
n: u64,
// Ties the struct to the backend type `B` without storing a value of it.
_marker: PhantomData<B>,
}
// `Module` stores a `NonNull<B::Handle>`, which is neither `Send` nor `Sync`
// on its own, so both markers are asserted manually and unconditionally.
// NOTE(review): soundness relies on every backend's handle being usable from
// and movable between threads; confirm per backend.
unsafe impl<B: Backend> Sync for Module<B> {}
unsafe impl<B: Backend> Send for Module<B> {}
impl<B: Backend> Module<B> {
#[inline]
pub fn new(n: u64) -> Self
where
Self: crate::api::ModuleNew<B>,
{
crate::api::ModuleNew::new(n)
}
#[allow(clippy::missing_safety_doc)]
#[inline]
pub unsafe fn from_nonnull(ptr: NonNull<B::Handle>, n: u64) -> Self {
assert!(n.is_power_of_two(), "n must be a power of two, got {n}");
Self {
ptr,
n,
_marker: PhantomData,
}
}
#[inline]
#[allow(clippy::missing_safety_doc)]
pub unsafe fn from_raw_parts(ptr: *mut B::Handle, n: u64) -> Self {
assert!(n.is_power_of_two(), "n must be a power of two, got {n}");
Self {
ptr: NonNull::new(ptr).expect("null module ptr"),
n,
_marker: PhantomData,
}
}
#[allow(clippy::missing_safety_doc)]
#[inline]
pub unsafe fn ptr(&self) -> *mut <B as Backend>::Handle {
self.ptr.as_ptr()
}
#[inline]
pub fn n(&self) -> usize {
self.n as usize
}
#[inline]
pub fn scalar_znx_alloc(&self, cols: usize) -> ScalarZnx<B::OwnedBuf> {
let n = self.n();
let len = ScalarZnx::<Vec<u8>>::bytes_of(n, cols);
let bytes = B::alloc_zeroed_bytes(len);
ScalarZnx::from_data(bytes, n, cols)
}
#[inline]
pub fn vec_znx_alloc(&self, cols: usize, size: usize) -> VecZnx<B::OwnedBuf> {
self.vec_znx_alloc_with_max_size(cols, size, size)
}
#[inline]
pub fn vec_znx_alloc_with_max_size(&self, cols: usize, size: usize, max_size: usize) -> VecZnx<B::OwnedBuf> {
let n = self.n();
let len = VecZnx::<Vec<u8>>::bytes_of(n, cols, max_size);
let bytes = B::alloc_zeroed_bytes(len);
VecZnx::from_data_with_max_size(bytes, n, cols, size, max_size)
}
#[inline]
pub fn mat_znx_alloc(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnx<B::OwnedBuf> {
let n = self.n();
let len = MatZnx::<Vec<u8>>::bytes_of(n, rows, cols_in, cols_out, size);
let bytes = B::alloc_zeroed_bytes(len);
MatZnx::from_data(bytes, n, rows, cols_in, cols_out, size)
}
#[inline]
pub fn as_mut_ptr(&self) -> *mut B::Handle {
self.ptr.as_ptr()
}
#[inline]
pub fn log_n(&self) -> usize {
(usize::BITS - (self.n() - 1).leading_zeros()) as _
}
#[inline]
pub fn reinterpret<Other>(&self) -> &Module<Other>
where
Other: Backend<Handle = B::Handle>,
{
unsafe { &*(self as *const Self as *const Module<Other>) }
}
#[inline]
pub fn reinterpret_mut<Other>(&mut self) -> &mut Module<Other>
where
Other: Backend<Handle = B::Handle>,
{
unsafe { &mut *(self as *mut Self as *mut Module<Other>) }
}
}
/// Supplies a cyclotomic order computed from `self.n()`.
pub trait CyclotomicOrder: ModuleN {
    /// Returns `self.n() << 1` (twice `n`, barring shifted-out bits) cast to `i64`.
    fn cyclotomic_order(&self) -> i64 {
        let doubled = self.n() << 1;
        doubled as i64
    }
}
// Empty impls: whenever `Module<BE>` implements `ModuleN`, it also implements
// `ModuleLogN` and `CyclotomicOrder`, using only what those traits provide by default.
impl<BE: Backend> ModuleLogN for Module<BE> where Self: ModuleN {}
impl<BE: Backend> CyclotomicOrder for Module<BE> where Self: ModuleN {}
/// Returns `GALOISGENERATOR^|generator|` reduced modulo `cyclotomic_order`,
/// multiplied by the sign of `generator`.
///
/// A `generator` of `0` yields `1`. The reduction is a bit-mask with
/// `cyclotomic_order - 1`, which is why `cyclotomic_order` must be a positive
/// power of two; this is only checked in debug builds.
#[inline(always)]
pub fn galois_element(generator: i64, cyclotomic_order: i64) -> i64 {
    debug_assert!(
        cyclotomic_order > 0 && (cyclotomic_order as u64).is_power_of_two(),
        "cyclotomic_order must be a power of two, got {cyclotomic_order}"
    );
    match generator {
        0 => 1,
        exponent => {
            // The power is computed with wrapping 64-bit arithmetic; masking
            // afterwards is exact because the order divides 2^64.
            let power = mod_exp_u64(GALOISGENERATOR, exponent.unsigned_abs() as usize);
            let mask = (cyclotomic_order - 1) as u64;
            (power & mask) as i64 * exponent.signum()
        }
    }
}
/// Galois-element helpers parameterised by `self.cyclotomic_order()`.
pub trait GaloisElement: CyclotomicOrder {
    /// Forwards to the free function [`galois_element`] with this object's
    /// cyclotomic order.
    fn galois_element(&self, generator: i64) -> i64 {
        let order = self.cyclotomic_order();
        galois_element(generator, order)
    }

    /// Returns `|gal_el|^(order - 1)` masked with `order - 1`, multiplied by
    /// the sign of `gal_el`, where `order` is `self.cyclotomic_order()`.
    ///
    /// For an odd `gal_el` and a power-of-two order this is the multiplicative
    /// inverse modulo the order, since `x^order ≡ 1` for every odd `x`.
    ///
    /// # Panics
    ///
    /// Panics with `"cannot invert 0"` when `gal_el == 0`.
    fn galois_element_inv(&self, gal_el: i64) -> i64 {
        assert!(gal_el != 0, "cannot invert 0");
        let exponent = (self.cyclotomic_order() - 1) as usize;
        let power = mod_exp_u64(gal_el.unsigned_abs(), exponent);
        let mask = (self.cyclotomic_order() - 1) as u64;
        (power & mask) as i64 * gal_el.signum()
    }
}
// Empty impl: `Module<BE>` uses the provided `GaloisElement` methods whenever
// it implements `CyclotomicOrder`.
impl<BE: Backend> GaloisElement for Module<BE> where Self: CyclotomicOrder {}
impl<B: Backend> Drop for Module<B> {
    /// Passes the wrapped handle to `B::destroy`.
    fn drop(&mut self) {
        let handle = self.ptr;
        // SAFETY: `drop` runs at most once per value, so the handle is passed
        // to `destroy` a single time. Its validity is the responsibility of
        // whoever constructed this `Module` through the unsafe constructors.
        unsafe { B::destroy(handle) }
    }
}
/// Computes `x^e` modulo 2^64 by binary (square-and-multiply) exponentiation
/// with wrapping multiplication.
///
/// `e == 0` returns `1` for every `x`, including `x == 0`.
pub fn mod_exp_u64(x: u64, e: usize) -> u64 {
    // Only bit positions up to the highest set bit of `e` can contribute.
    let significant_bits = usize::BITS - e.leading_zeros();
    let (result, _) = (0..significant_bits).fold((1u64, x), |(acc, base), bit| {
        let acc = if (e >> bit) & 1 == 1 {
            acc.wrapping_mul(base)
        } else {
            acc
        };
        (acc, base.wrapping_mul(base))
    });
    result
}