use crate::{
api::{CnvPVecBytesOf, ModuleN, SvpPPolBytesOf, VecZnxBigBytesOf, VecZnxDftBytesOf, VmpPMatBytesOf},
layouts::{
Backend, CnvPVecL, CnvPVecLViewMut, CnvPVecR, CnvPVecRViewMut, MatZnx, MatZnxViewMut, ScalarZnx, ScalarZnxViewMut,
ScratchArena, SvpPPol, SvpPPolViewMut, VecZnx, VecZnxBig, VecZnxBigViewMut, VecZnxDft, VecZnxDftViewMut, VecZnxViewMut,
VmpPMat, VmpPMatViewMut,
},
};
/// Constructor trait for scratch owners parameterised by a backend `B`.
pub trait ScratchOwnedAlloc<B: Backend> {
/// Builds a new value of the implementing type from `size`.
///
/// NOTE(review): `size` is presumably the capacity in bytes — confirm against the implementors.
fn alloc(size: usize) -> Self;
}
/// Hands out a [`ScratchArena`] for backend `B` from a mutably borrowed owner.
pub trait ScratchOwnedBorrow<B: Backend> {
/// Returns a [`ScratchArena`] whose lifetime is tied to the mutable borrow of `self`,
/// so the owner cannot be borrowed again while the returned arena is alive.
fn borrow(&mut self) -> ScratchArena<'_, B>;
}
/// Query trait reporting how much scratch is available.
pub trait ScratchAvailable {
/// Returns the available amount as a `usize`.
///
/// NOTE(review): presumably the number of bytes that can still be taken — confirm
/// the unit (and any alignment effects) against the implementors.
fn available(&self) -> usize;
}
/// A value that can be converted into a mutable byte slice living for `'a`.
pub trait HostBufMut<'a>: Sized {
/// Consumes `self` and returns the mutable bytes it stands for.
fn into_bytes(self) -> &'a mut [u8];
}
/// A mutable byte slice is trivially its own host buffer.
impl<'a> HostBufMut<'a> for &'a mut [u8] {
/// Returns the slice unchanged; no copy or re-borrow takes place.
#[inline]
fn into_bytes(self) -> &'a mut [u8] {
self
}
}
/// Consuming helpers that take typed mutable views out of a scratch arena.
///
/// Every method consumes `self` and returns the requested view together with
/// a `Self` value that is used for any subsequent take (this is how the
/// default slice helpers chain their calls).
pub trait ScratchArenaTakeBasic<'a, B: Backend>: Sized {
    /// Takes a left convolution prepared-vector view with `cols` columns and
    /// `size` limbs, using `module` for sizing.
    fn take_cnv_pvec_left_scratch<M>(self, module: &M, cols: usize, size: usize) -> (CnvPVecLViewMut<'a, B>, Self)
    where
        B: 'a,
        M: ModuleN + CnvPVecBytesOf;

    /// Takes a right convolution prepared-vector view with `cols` columns and
    /// `size` limbs, using `module` for sizing.
    fn take_cnv_pvec_right_scratch<M>(self, module: &M, cols: usize, size: usize) -> (CnvPVecRViewMut<'a, B>, Self)
    where
        B: 'a,
        M: ModuleN + CnvPVecBytesOf;

    /// Takes a scalar polynomial view described by `n` and `cols`.
    fn take_scalar_znx_scratch(self, n: usize, cols: usize) -> (ScalarZnxViewMut<'a, B>, Self)
    where
        B: 'a;

    /// Takes a prepared scalar-vector-product polynomial view with `cols`
    /// columns, using `module` for sizing.
    fn take_svp_ppol_scratch<M>(self, module: &M, cols: usize) -> (SvpPPolViewMut<'a, B>, Self)
    where
        B: 'a,
        M: SvpPPolBytesOf + ModuleN;

    /// Takes a vector view described by `n`, `cols` and `size`.
    fn take_vec_znx_scratch(self, n: usize, cols: usize, size: usize) -> (VecZnxViewMut<'a, B>, Self)
    where
        B: 'a;

    /// Takes a big-coefficient vector view with `cols` columns and `size`
    /// limbs, with the remaining dimension supplied through `module`.
    fn take_vec_znx_big_scratch<M>(self, module: &M, cols: usize, size: usize) -> (VecZnxBigViewMut<'a, B>, Self)
    where
        B: 'a,
        M: VecZnxBigBytesOf + ModuleN;

    /// Same as [`Self::take_vec_znx_big_scratch`], but with `n` passed
    /// explicitly instead of through a module.
    fn take_vec_znx_big_scratch_n(self, n: usize, cols: usize, size: usize) -> (VecZnxBigViewMut<'a, B>, Self)
    where
        B: 'a;

    /// Takes a DFT-domain vector view with `cols` columns and `size` limbs,
    /// using `module` for sizing.
    fn take_vec_znx_dft_scratch<M>(self, module: &M, cols: usize, size: usize) -> (VecZnxDftViewMut<'a, B>, Self)
    where
        B: 'a,
        M: VecZnxDftBytesOf + ModuleN;

    /// Takes `len` DFT-domain vector views, one after another.
    ///
    /// Calls [`Self::take_vec_znx_dft_scratch`] `len` times with the same
    /// `module`, `cols` and `size`, feeding the arena returned by each call
    /// into the next. The views are returned in the order they were taken,
    /// along with the arena returned by the last call (or `self` unchanged
    /// when `len == 0`).
    fn take_vec_znx_dft_slice_scratch<M>(
        self,
        module: &M,
        len: usize,
        cols: usize,
        size: usize,
    ) -> (Vec<VecZnxDftViewMut<'a, B>>, Self)
    where
        B: 'a,
        M: VecZnxDftBytesOf + ModuleN,
    {
        let views: Vec<VecZnxDftViewMut<'a, B>> = Vec::with_capacity(len);
        // Thread the arena through the fold so each take starts from the
        // arena handed back by the previous one.
        (0..len).fold((views, self), |(mut taken, arena), _| {
            let (view, rest) = arena.take_vec_znx_dft_scratch(module, cols, size);
            taken.push(view);
            (taken, rest)
        })
    }

    /// Takes `len` vector views, one after another.
    ///
    /// Calls [`Self::take_vec_znx_scratch`] `len` times with the same `n`,
    /// `cols` and `size`, feeding the arena returned by each call into the
    /// next. The views are returned in the order they were taken, along with
    /// the arena returned by the last call (or `self` unchanged when
    /// `len == 0`).
    fn take_vec_znx_slice_scratch(self, len: usize, n: usize, cols: usize, size: usize) -> (Vec<VecZnxViewMut<'a, B>>, Self)
    where
        B: 'a,
    {
        let views: Vec<VecZnxViewMut<'a, B>> = Vec::with_capacity(len);
        // Thread the arena through the fold so each take starts from the
        // arena handed back by the previous one.
        (0..len).fold((views, self), |(mut taken, arena), _| {
            let (view, rest) = arena.take_vec_znx_scratch(n, cols, size);
            taken.push(view);
            (taken, rest)
        })
    }

    /// Takes a prepared vector-matrix-product matrix view described by
    /// `rows`, `cols_in`, `cols_out` and `size`, using `module` for sizing.
    fn take_vmp_pmat_scratch<M>(
        self,
        module: &M,
        rows: usize,
        cols_in: usize,
        cols_out: usize,
        size: usize,
    ) -> (VmpPMatViewMut<'a, B>, Self)
    where
        B: 'a,
        M: VmpPMatBytesOf + ModuleN;

    /// Takes a matrix view described by `n`, `rows`, `cols_in`, `cols_out`
    /// and `size`.
    fn take_mat_znx_scratch(
        self,
        n: usize,
        rows: usize,
        cols_in: usize,
        cols_out: usize,
        size: usize,
    ) -> (MatZnxViewMut<'a, B>, Self)
    where
        B: 'a;
}
/// [`ScratchArenaTakeBasic`] for [`ScratchArena`].
///
/// Every method follows the same recipe: compute the byte count of the
/// requested layout, obtain a region of that many bytes through
/// `take_region`, build the layout over that region with `from_data`, and
/// wrap it in the matching view. The arena returned by `take_region` is
/// handed back to the caller as the second tuple element.
impl<'a, B: Backend> ScratchArenaTakeBasic<'a, B> for ScratchArena<'a, B> {
    /// Region size comes from `module.bytes_of_cnv_pvec_left(cols, size)`;
    /// the layout dimension `n` comes from `module.n()`.
    fn take_cnv_pvec_left_scratch<M>(self, module: &M, cols: usize, size: usize) -> (CnvPVecLViewMut<'a, B>, Self)
    where
        B: 'a,
        M: ModuleN + CnvPVecBytesOf,
    {
        let bytes = module.bytes_of_cnv_pvec_left(cols, size);
        let (region, rest) = self.take_region(bytes);
        let inner = CnvPVecL::from_data(region, module.n(), cols, size);
        (CnvPVecLViewMut::from_inner(inner), rest)
    }

    /// Region size comes from `module.bytes_of_cnv_pvec_right(cols, size)`;
    /// the layout dimension `n` comes from `module.n()`.
    fn take_cnv_pvec_right_scratch<M>(self, module: &M, cols: usize, size: usize) -> (CnvPVecRViewMut<'a, B>, Self)
    where
        B: 'a,
        M: ModuleN + CnvPVecBytesOf,
    {
        let bytes = module.bytes_of_cnv_pvec_right(cols, size);
        let (region, rest) = self.take_region(bytes);
        let inner = CnvPVecR::from_data(region, module.n(), cols, size);
        (CnvPVecRViewMut::from_inner(inner), rest)
    }

    /// Region size comes from `ScalarZnx::bytes_of(n, cols)`.
    fn take_scalar_znx_scratch(self, n: usize, cols: usize) -> (ScalarZnxViewMut<'a, B>, Self)
    where
        B: 'a,
    {
        let bytes = ScalarZnx::bytes_of(n, cols);
        let (region, rest) = self.take_region(bytes);
        let inner = ScalarZnx::from_data(region, n, cols);
        (ScalarZnxViewMut::from_inner(inner), rest)
    }

    /// Region size comes from `module.bytes_of_svp_ppol(cols)`; the layout
    /// dimension `n` comes from `module.n()`.
    fn take_svp_ppol_scratch<M>(self, module: &M, cols: usize) -> (SvpPPolViewMut<'a, B>, Self)
    where
        B: 'a,
        M: SvpPPolBytesOf + ModuleN,
    {
        let bytes = module.bytes_of_svp_ppol(cols);
        let (region, rest) = self.take_region(bytes);
        let inner = SvpPPol::from_data(region, module.n(), cols);
        (SvpPPolViewMut::from_inner(inner), rest)
    }

    /// Region size comes from `VecZnx::bytes_of(n, cols, size)`.
    fn take_vec_znx_scratch(self, n: usize, cols: usize, size: usize) -> (VecZnxViewMut<'a, B>, Self)
    where
        B: 'a,
    {
        let bytes = VecZnx::bytes_of(n, cols, size);
        let (region, rest) = self.take_region(bytes);
        let inner = VecZnx::from_data(region, n, cols, size);
        (VecZnxViewMut::from_inner(inner), rest)
    }

    /// Delegates to [`Self::take_vec_znx_big_scratch_n`] with `module.n()`.
    ///
    /// NOTE(review): only `module.n()` is used here; the byte count is
    /// computed by `B::bytes_of_vec_znx_big` in the delegate, so the
    /// `VecZnxBigBytesOf` bound on `M` is not exercised by this body.
    fn take_vec_znx_big_scratch<M>(self, module: &M, cols: usize, size: usize) -> (VecZnxBigViewMut<'a, B>, Self)
    where
        B: 'a,
        M: VecZnxBigBytesOf + ModuleN,
    {
        let n = module.n();
        self.take_vec_znx_big_scratch_n(n, cols, size)
    }

    /// Region size comes from `B::bytes_of_vec_znx_big(n, cols, size)`.
    fn take_vec_znx_big_scratch_n(self, n: usize, cols: usize, size: usize) -> (VecZnxBigViewMut<'a, B>, Self)
    where
        B: 'a,
    {
        let bytes = B::bytes_of_vec_znx_big(n, cols, size);
        let (region, rest) = self.take_region(bytes);
        let inner = VecZnxBig::from_data(region, n, cols, size);
        (VecZnxBigViewMut::from_inner(inner), rest)
    }

    /// Region size comes from `module.bytes_of_vec_znx_dft(cols, size)`; the
    /// layout dimension `n` comes from `module.n()`.
    fn take_vec_znx_dft_scratch<M>(self, module: &M, cols: usize, size: usize) -> (VecZnxDftViewMut<'a, B>, Self)
    where
        B: 'a,
        M: VecZnxDftBytesOf + ModuleN,
    {
        let bytes = module.bytes_of_vec_znx_dft(cols, size);
        let (region, rest) = self.take_region(bytes);
        let inner = VecZnxDft::from_data(region, module.n(), cols, size);
        (VecZnxDftViewMut::from_inner(inner), rest)
    }

    /// Region size comes from
    /// `module.bytes_of_vmp_pmat(rows, cols_in, cols_out, size)`; the layout
    /// dimension `n` comes from `module.n()`.
    fn take_vmp_pmat_scratch<M>(
        self,
        module: &M,
        rows: usize,
        cols_in: usize,
        cols_out: usize,
        size: usize,
    ) -> (VmpPMatViewMut<'a, B>, Self)
    where
        B: 'a,
        M: VmpPMatBytesOf + ModuleN,
    {
        let bytes = module.bytes_of_vmp_pmat(rows, cols_in, cols_out, size);
        let (region, rest) = self.take_region(bytes);
        let inner = VmpPMat::from_data(region, module.n(), rows, cols_in, cols_out, size);
        (VmpPMatViewMut::from_inner(inner), rest)
    }

    /// Region size comes from
    /// `MatZnx::bytes_of(n, rows, cols_in, cols_out, size)`.
    fn take_mat_znx_scratch(
        self,
        n: usize,
        rows: usize,
        cols_in: usize,
        cols_out: usize,
        size: usize,
    ) -> (MatZnxViewMut<'a, B>, Self)
    where
        B: 'a,
    {
        let bytes = MatZnx::bytes_of(n, rows, cols_in, cols_out, size);
        let (region, rest) = self.take_region(bytes);
        let inner = MatZnx::from_data(region, n, rows, cols_in, cols_out, size);
        (MatZnxViewMut::from_inner(inner), rest)
    }
}