use poulpy_hal::{
api::{ModuleN, ScratchArenaTakeBasic, SvpPPolBytesOf, VmpPMatBytesOf},
layouts::{Backend, ScratchArena},
};
use crate::{
dist::Distribution,
layouts::{
Degree, GGLWE, GGLWEInfos, GGLWEPreparedViewMut, GGLWEViewMut, GGSW, GGSWInfos, GGSWPreparedViewMut, GGSWViewMut, GLWE,
GLWEInfos, GLWEPlaintext, GLWEPlaintextViewMut, GLWESecret, GLWESecretPreparedViewMut, GLWESecretTensor,
GLWESecretTensorViewMut, GLWESecretViewMut, GLWETensor, GLWETensorViewMut, GLWEViewMut, LWE, LWEInfos, LWEPlaintext,
LWEPlaintextViewMut, LWEViewMut, Rank,
prepared::{GGLWEPrepared, GGSWPrepared, GLWESecretPrepared},
},
};
/// Typed scratch-allocation helpers layered on top of [`ScratchArenaTakeBasic`].
///
/// Every method consumes the arena by value, requests one or more raw buffers
/// through the underlying `take_*_scratch` primitives, wraps the result in the
/// matching `*ViewMut` layout and returns that view together with the arena
/// value handed back by the last primitive call.
pub trait ScratchArenaTakeCore<'a, B: Backend>: ScratchArenaTakeBasic<'a, B> + Sized {
    /// Takes an [`LWE`] view described by `infos`.
    ///
    /// Two vec-znx buffers are requested back to back: first the body with
    /// arguments `(1, 1, infos.size())`, then the mask with
    /// `(infos.n(), 1, infos.size())`.
    fn take_lwe_scratch<A>(self, infos: &A) -> (LWEViewMut<'a, B>, Self)
    where
        B: 'a,
        A: LWEInfos,
    {
        let (body_buf, after_body) = self.take_vec_znx_scratch(1, 1, infos.size());
        let (mask_buf, rest) = after_body.take_vec_znx_scratch(infos.n().into(), 1, infos.size());
        let lwe = LWE {
            base2k: infos.base2k(),
            body: body_buf.into_inner(),
            mask: mask_buf.into_inner(),
        };
        (LWEViewMut::from_inner(lwe), rest)
    }

    /// Takes an [`LWEPlaintext`] view described by `infos`.
    ///
    /// A single vec-znx buffer is requested with arguments
    /// `(1, 1, infos.size())`.
    fn take_lwe_plaintext_scratch<A>(self, infos: &A) -> (LWEPlaintextViewMut<'a, B>, Self)
    where
        B: 'a,
        A: LWEInfos,
    {
        let (buf, rest) = self.take_vec_znx_scratch(1, 1, infos.size());
        let plaintext = LWEPlaintext {
            base2k: infos.base2k(),
            data: buf.into_inner(),
        };
        (LWEPlaintextViewMut::from_inner(plaintext), rest)
    }

    /// Takes a [`GLWE`] view described by `infos`.
    ///
    /// A single vec-znx buffer is requested with arguments
    /// `(infos.n(), infos.rank() + 1, infos.size())`.
    fn take_glwe_scratch<A>(self, infos: &A) -> (GLWEViewMut<'a, B>, Self)
    where
        B: 'a,
        A: GLWEInfos,
    {
        let (buf, rest) = self.take_vec_znx_scratch(infos.n().into(), (infos.rank() + 1).into(), infos.size());
        let glwe = GLWE {
            base2k: infos.base2k(),
            data: buf.into_inner(),
        };
        (GLWEViewMut::from_inner(glwe), rest)
    }

    /// Takes `size` [`GLWE`] views, all described by the same `infos`.
    ///
    /// The views are taken one after another with
    /// [`take_glwe_scratch`](Self::take_glwe_scratch), each call operating on
    /// the arena value returned by the previous one. With `size == 0` the
    /// arena is returned unchanged alongside an empty vector.
    fn take_glwe_slice_scratch<A>(self, size: usize, infos: &A) -> (Vec<GLWEViewMut<'a, B>>, Self)
    where
        B: 'a,
        A: GLWEInfos,
    {
        let seed: (Vec<GLWEViewMut<'a, B>>, Self) = (Vec::with_capacity(size), self);
        // Thread the arena through the accumulator so each take starts from
        // whatever the previous take handed back.
        (0..size).fold(seed, |(mut views, arena), _| {
            let (view, rest) = arena.take_glwe_scratch(infos);
            views.push(view);
            (views, rest)
        })
    }

    /// Takes a [`GLWETensor`] view described by `infos`.
    ///
    /// With `cols = infos.rank() + 1`, the number of columns requested is
    /// `cols * (cols + 1) / 2` (the count of unordered pairs, repetition
    /// allowed, among `cols` items), clamped to at least 1.
    fn take_glwe_tensor_scratch<A>(self, infos: &A) -> (GLWETensorViewMut<'a, B>, Self)
    where
        B: 'a,
        A: GLWEInfos,
    {
        let cols: usize = infos.rank().as_usize() + 1;
        let pair_count: usize = (cols * (cols + 1) / 2).max(1);
        let (buf, rest) = self.take_vec_znx_scratch(infos.n().into(), pair_count, infos.size());
        let tensor = GLWETensor {
            base2k: infos.base2k(),
            rank: infos.rank(),
            data: buf.into_inner(),
        };
        (GLWETensorViewMut::from_inner(tensor), rest)
    }

    /// Takes a [`GLWEPlaintext`] view described by `infos`.
    ///
    /// A single vec-znx buffer is requested with arguments
    /// `(infos.n(), 1, infos.size())`.
    fn take_glwe_plaintext_scratch<A>(self, infos: &A) -> (GLWEPlaintextViewMut<'a, B>, Self)
    where
        B: 'a,
        A: GLWEInfos,
    {
        let (buf, rest) = self.take_vec_znx_scratch(infos.n().into(), 1, infos.size());
        let plaintext = GLWEPlaintext {
            base2k: infos.base2k(),
            data: buf.into_inner(),
        };
        (GLWEPlaintextViewMut::from_inner(plaintext), rest)
    }

    /// Takes a [`GLWESecretPrepared`] view for the given `rank`.
    ///
    /// The buffer comes from `take_svp_ppol_scratch(module, rank)`. The
    /// resulting secret is tagged with [`Distribution::NONE`].
    fn take_glwe_secret_prepared_scratch<M>(self, module: &M, rank: Rank) -> (GLWESecretPreparedViewMut<'a, B>, Self)
    where
        B: 'a,
        M: ModuleN + SvpPPolBytesOf,
    {
        let (buf, rest) = self.take_svp_ppol_scratch(module, rank.into());
        let secret = GLWESecretPrepared {
            data: buf.into_inner(),
            dist: Distribution::NONE,
        };
        (GLWESecretPreparedViewMut::from_inner(secret), rest)
    }

    /// Takes a [`GLWESecret`] view for degree `n` and the given `rank`.
    ///
    /// The buffer comes from `take_scalar_znx_scratch(n, rank)`. The
    /// resulting secret is tagged with [`Distribution::NONE`].
    fn take_glwe_secret_scratch(self, n: Degree, rank: Rank) -> (GLWESecretViewMut<'a, B>, Self)
    where
        B: 'a,
    {
        let (buf, rest) = self.take_scalar_znx_scratch(n.into(), rank.into());
        let secret = GLWESecret {
            data: buf.into_inner(),
            dist: Distribution::NONE,
        };
        (GLWESecretViewMut::from_inner(secret), rest)
    }

    /// Takes a [`GLWESecretTensor`] view for degree `n` and the given `rank`.
    ///
    /// The buffer comes from `take_scalar_znx_scratch`, with the column count
    /// given by `GLWESecretTensor::pairs(rank)`. The resulting tensor records
    /// `rank` and is tagged with [`Distribution::NONE`].
    fn take_glwe_secret_tensor_scratch(self, n: Degree, rank: Rank) -> (GLWESecretTensorViewMut<'a, B>, Self)
    where
        B: 'a,
    {
        let (buf, rest) = self.take_scalar_znx_scratch(n.into(), GLWESecretTensor::pairs(rank.into()));
        let tensor = GLWESecretTensor {
            data: buf.into_inner(),
            rank,
            dist: Distribution::NONE,
        };
        (GLWESecretTensorViewMut::from_inner(tensor), rest)
    }

    /// Takes a [`GGLWE`] view described by `infos`.
    ///
    /// The buffer comes from `take_mat_znx_scratch` with arguments
    /// `(n, ceil(dnum / dsize), rank_in, rank_out + 1, size)`.
    fn take_gglwe_scratch<A>(self, infos: &A) -> (GGLWEViewMut<'a, B>, Self)
    where
        B: 'a,
        A: GGLWEInfos,
    {
        // NOTE(review): this passes ceil(dnum / dsize), whereas
        // `take_gglwe_prepared_scratch` passes `dnum` unchanged; confirm
        // that the difference is intended.
        let rows = infos.dnum().0.div_ceil(infos.dsize().0) as usize;
        let (buf, rest) = self.take_mat_znx_scratch(
            infos.n().into(),
            rows,
            infos.rank_in().into(),
            (infos.rank_out() + 1).into(),
            infos.size(),
        );
        let gglwe = GGLWE {
            base2k: infos.base2k(),
            dsize: infos.dsize(),
            data: buf.into_inner(),
        };
        (GGLWEViewMut::from_inner(gglwe), rest)
    }

    /// Takes a [`GGLWEPrepared`] view described by `infos`.
    ///
    /// The buffer comes from `take_vmp_pmat_scratch` with arguments
    /// `(module, dnum, rank_in, rank_out + 1, size)`.
    ///
    /// # Panics
    ///
    /// Panics if `module.n()` (cast to `u32`) differs from `infos.n()`.
    fn take_gglwe_prepared_scratch<A, M>(self, module: &M, infos: &A) -> (GGLWEPreparedViewMut<'a, B>, Self)
    where
        B: 'a,
        A: GGLWEInfos,
        M: ModuleN + VmpPMatBytesOf,
    {
        assert_eq!(module.n() as u32, infos.n());
        let (buf, rest) = self.take_vmp_pmat_scratch(
            module,
            infos.dnum().into(),
            infos.rank_in().into(),
            (infos.rank_out() + 1).into(),
            infos.size(),
        );
        let prepared = GGLWEPrepared {
            base2k: infos.base2k(),
            dsize: infos.dsize(),
            data: buf.into_inner(),
        };
        (GGLWEPreparedViewMut::from_inner(prepared), rest)
    }

    /// Takes a [`GGSW`] view described by `infos`.
    ///
    /// The buffer comes from `take_mat_znx_scratch` with arguments
    /// `(n, dnum, rank + 1, rank + 1, size)`.
    fn take_ggsw_scratch<A>(self, infos: &A) -> (GGSWViewMut<'a, B>, Self)
    where
        B: 'a,
        A: GGSWInfos,
    {
        let (buf, rest) = self.take_mat_znx_scratch(
            infos.n().into(),
            infos.dnum().into(),
            (infos.rank() + 1).into(),
            (infos.rank() + 1).into(),
            infos.size(),
        );
        let ggsw = GGSW {
            base2k: infos.base2k(),
            dsize: infos.dsize(),
            data: buf.into_inner(),
        };
        (GGSWViewMut::from_inner(ggsw), rest)
    }

    /// Takes a [`GGSWPrepared`] view described by `infos`.
    ///
    /// The buffer comes from `take_vmp_pmat_scratch` with arguments
    /// `(module, dnum, rank + 1, rank + 1, size)`.
    ///
    /// # Panics
    ///
    /// Panics if `module.n()` (cast to `u32`) differs from `infos.n()`.
    fn take_ggsw_prepared_scratch<A, M>(self, module: &M, infos: &A) -> (GGSWPreparedViewMut<'a, B>, Self)
    where
        B: 'a,
        A: GGSWInfos,
        M: ModuleN + VmpPMatBytesOf,
    {
        assert_eq!(module.n() as u32, infos.n());
        let (buf, rest) = self.take_vmp_pmat_scratch(
            module,
            infos.dnum().into(),
            (infos.rank() + 1).into(),
            (infos.rank() + 1).into(),
            infos.size(),
        );
        let prepared = GGSWPrepared {
            base2k: infos.base2k(),
            dsize: infos.dsize(),
            data: buf.into_inner(),
        };
        (GGSWPreparedViewMut::from_inner(prepared), rest)
    }
}
/// `ScratchArena` picks up every `ScratchArenaTakeCore` method through the
/// trait's default implementations; nothing is overridden here.
impl<'a, B> ScratchArenaTakeCore<'a, B> for ScratchArena<'a, B> where B: Backend {}