use poulpy_hal::{
api::{
ModuleN, SvpApplyDftToDft, SvpPrepare, VecZnxBigAlloc, VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyTmpA,
},
layouts::{
Backend, Data, HostDataMut, HostDataRef, Module, ScalarZnx, ScalarZnxToBackendRef, ScratchArena, ScratchOwned,
SvpPPolReborrowBackendMut, SvpPPolReborrowBackendRef, VecZnxBigToBackendMut, VecZnxBigToBackendRef, VecZnxDft,
VecZnxDftToBackendMut, VecZnxDftToBackendRef, ZnxView, ZnxViewMut, scalar_znx_as_vec_znx_backend_mut_from_mut,
scalar_znx_as_vec_znx_backend_ref_from_ref,
},
};
use crate::{
GetDistribution, GetDistributionMut, ScratchArenaTakeCore,
dist::Distribution,
layouts::{
Base2K, Degree, GLWEInfos, GLWESecret, GLWESecretBackendMut, GLWESecretBackendRef, GLWESecretPreparedFactory,
GLWESecretToBackendMut, GLWESecretToBackendRef, LWEInfos, Rank,
},
};
/// Packed symmetric tensor of a GLWE secret: one [`ScalarZnx`] column per unordered pair
/// `(i, j)`, `i <= j`, of the `rank` secret columns.
///
/// The pair `(i, j)` with `i <= j` lives in column `i * rank + j - i * (i + 1) / 2`, i.e. the
/// upper triangle stored row by row (see [`GLWESecretTensor::at`]).
pub struct GLWESecretTensor<D: Data> {
/// Backing polynomials; `alloc` sizes this with `pairs(rank)` columns.
pub(crate) data: ScalarZnx<D>,
/// Rank of the secret the tensor is built from (not the number of stored pairs).
pub(crate) rank: Rank,
/// Distribution tag of the underlying secret; `alloc` initialises it to `Distribution::NONE`.
pub(crate) dist: Distribution,
}
impl GLWESecretTensor<Vec<u8>> {
    /// Number of unordered pairs `(i, j)` with `i <= j` drawn from `rank` columns, i.e. the
    /// triangular number `rank * (rank + 1) / 2`, clamped to at least one column.
    pub(crate) fn pairs(rank: usize) -> usize {
        let triangular = rank * (rank + 1) / 2;
        std::cmp::max(triangular, 1)
    }
}
impl<D: Data> GetDistribution for GLWESecretTensor<D> {
    /// Shared access to the distribution tag stored alongside the tensor.
    fn dist(&self) -> &Distribution {
        let Self { dist, .. } = self;
        dist
    }
}
impl<D: Data> GetDistributionMut for GLWESecretTensor<D> {
    /// Mutable access to the distribution tag stored alongside the tensor.
    fn dist_mut(&mut self) -> &mut Distribution {
        let Self { dist, .. } = self;
        dist
    }
}
impl<D: Data> LWEInfos for GLWESecretTensor<D> {
    /// Always reports `Base2K(0)`: no decomposition base is associated with the tensor.
    fn base2k(&self) -> Base2K {
        Base2K(0)
    }

    /// Ring degree, taken from the backing `ScalarZnx`.
    fn n(&self) -> Degree {
        let degree = self.data.n() as u32;
        Degree(degree)
    }

    /// The tensor always reports a single limb.
    fn size(&self) -> usize {
        1
    }
}
impl<D: Data> LWEInfos for &mut GLWESecretTensor<D> {
    /// Forwards to the owned tensor's implementation.
    fn base2k(&self) -> Base2K {
        <GLWESecretTensor<D> as LWEInfos>::base2k(self)
    }

    /// Forwards to the owned tensor's implementation.
    fn n(&self) -> Degree {
        <GLWESecretTensor<D> as LWEInfos>::n(self)
    }

    /// Forwards to the owned tensor's implementation.
    fn size(&self) -> usize {
        <GLWESecretTensor<D> as LWEInfos>::size(self)
    }
}
impl<D: HostDataRef> GLWESecretTensor<D> {
    /// Returns a read-only view of the `(i, j)` entry of the tensor.
    ///
    /// The tensor is symmetric, so `(i, j)` and `(j, i)` address the same column; only the
    /// upper triangle (`i <= j`) is stored, packed row by row.
    ///
    /// # Panics
    ///
    /// Panics if `i` or `j` is not smaller than the tensor's rank.
    pub fn at(&self, i: usize, j: usize) -> ScalarZnx<&[u8]> {
        let (lo, hi) = if i <= j { (i, j) } else { (j, i) };
        let rank: usize = self.rank().into();
        // Without this check an out-of-range index can land on a *different* valid pair,
        // e.g. with rank 2 the pair (0, 2) maps to the column of (1, 1).
        assert!(
            hi < rank,
            "GLWESecretTensor::at: index ({i}, {j}) out of range for rank {rank}"
        );
        let col = lo * rank + hi - lo * (lo + 1) / 2;
        ScalarZnx::from_data(bytemuck::cast_slice(self.data.at(col, 0)), self.n().into(), 1)
    }
}
impl<D: HostDataMut> GLWESecretTensor<D> {
    /// Returns a mutable view of the `(i, j)` entry of the tensor.
    ///
    /// The tensor is symmetric, so `(i, j)` and `(j, i)` address the same column; only the
    /// upper triangle (`i <= j`) is stored, packed row by row.
    ///
    /// # Panics
    ///
    /// Panics if `i` or `j` is not smaller than the tensor's rank.
    pub fn at_mut(&mut self, i: usize, j: usize) -> ScalarZnx<&mut [u8]> {
        let (lo, hi) = if i <= j { (i, j) } else { (j, i) };
        let rank: usize = self.rank().into();
        // Without this check an out-of-range index can land on a *different* valid pair,
        // e.g. with rank 2 the pair (0, 2) maps to the column of (1, 1).
        assert!(
            hi < rank,
            "GLWESecretTensor::at_mut: index ({i}, {j}) out of range for rank {rank}"
        );
        // Read the degree before `self.data` is mutably borrowed below.
        let n = self.n().into();
        let col = lo * rank + hi - lo * (lo + 1) / 2;
        ScalarZnx::from_data(bytemuck::cast_slice_mut(self.data.at_mut(col, 0)), n, 1)
    }
}
impl<D: Data> GLWEInfos for GLWESecretTensor<D> {
    /// Rank of the secret the tensor was built from (not the number of stored pairs).
    fn rank(&self) -> Rank {
        let Self { rank, .. } = self;
        *rank
    }
}
impl<D: Data> GLWEInfos for &mut GLWESecretTensor<D> {
    /// Forwards to the owned tensor's implementation.
    fn rank(&self) -> Rank {
        <GLWESecretTensor<D> as GLWEInfos>::rank(self)
    }
}
impl<BE: Backend> GLWESecretToBackendRef<BE> for GLWESecretTensor<BE::OwnedBuf> {
    /// Shared backend view of an owned tensor; the distribution tag is copied into the view.
    fn to_backend_ref(&self) -> GLWESecretBackendRef<'_, BE> {
        let dist = self.dist;
        let data = <ScalarZnx<BE::OwnedBuf> as ScalarZnxToBackendRef<BE>>::to_backend_ref(&self.data);
        GLWESecret { data, dist }
    }
}
impl<'b, BE: Backend + 'b> GLWESecretToBackendRef<BE> for &mut GLWESecretTensor<BE::BufMut<'b>> {
    /// Shared backend view over a mutably-borrowed tensor, obtained by re-viewing its mutable
    /// buffer through `BE::view_ref_mut`; the distribution tag is copied into the view.
    fn to_backend_ref(&self) -> GLWESecretBackendRef<'_, BE> {
        let znx = &self.data;
        let n = znx.n();
        let cols = znx.cols();
        let data = ScalarZnx::from_data(BE::view_ref_mut(&znx.data), n, cols);
        GLWESecret { data, dist: self.dist }
    }
}
impl<BE: Backend> GLWESecretToBackendMut<BE> for GLWESecretTensor<BE::OwnedBuf> {
    /// Mutable backend view of an owned tensor. The distribution tag is *copied* into the
    /// view, so writes to the view's `dist` do not reach `self`.
    fn to_backend_mut(&mut self) -> GLWESecretBackendMut<'_, BE> {
        let dist = self.dist;
        let data =
            <ScalarZnx<BE::OwnedBuf> as poulpy_hal::layouts::ScalarZnxToBackendMut<BE>>::to_backend_mut(&mut self.data);
        GLWESecret { data, dist }
    }
}
impl<'b, BE: Backend + 'b> GLWESecretToBackendMut<BE> for &mut GLWESecretTensor<BE::BufMut<'b>> {
    /// Mutable backend view over a mutably-borrowed tensor. The distribution tag is *copied*
    /// into the view, so writes to the view's `dist` do not reach the tensor.
    fn to_backend_mut(&mut self) -> GLWESecretBackendMut<'_, BE> {
        let dist = self.dist;
        // Dimensions are read before the buffer is mutably borrowed.
        let (n, cols) = (self.data.n(), self.data.cols());
        let data = ScalarZnx::from_data(BE::view_mut_ref(&mut self.data.data), n, cols);
        GLWESecret { data, dist }
    }
}
#[expect(
    dead_code,
    reason = "host-owned constructors are kept for serialization and host-only staging"
)]
impl GLWESecretTensor<Vec<u8>> {
    /// Allocates a host tensor sized from `infos` (degree and source-secret rank).
    pub(crate) fn alloc_from_infos<A>(infos: &A) -> Self
    where
        A: GLWEInfos,
    {
        Self::alloc(infos.n(), infos.rank())
    }

    /// Allocates a host tensor of degree `n` holding `pairs(rank)` columns, with the
    /// distribution tag set to `Distribution::NONE`.
    pub(crate) fn alloc(n: Degree, rank: Rank) -> Self {
        let n: usize = n.into();
        let cols: usize = Self::pairs(rank.into());
        let bytes = poulpy_hal::layouts::HostBytesBackend::alloc_bytes(ScalarZnx::<Vec<u8>>::bytes_of(n, cols));
        GLWESecretTensor {
            data: ScalarZnx::from_data(bytes, n, cols),
            rank,
            dist: Distribution::NONE,
        }
    }

    /// Byte size of a tensor described by `infos`; matches what [`Self::alloc_from_infos`]
    /// allocates.
    pub fn bytes_of_from_infos<A>(infos: &A) -> usize
    where
        A: GLWEInfos,
    {
        // `bytes_of` takes the *source* rank and converts it to a pair count itself;
        // pre-applying `pairs` here would count pairs of pairs and over-size the buffer.
        Self::bytes_of(infos.n(), infos.rank())
    }

    /// Byte size of a tensor of degree `n` built from a secret of rank `rank`
    /// (i.e. `pairs(rank)` columns).
    pub fn bytes_of(n: Degree, rank: Rank) -> usize {
        ScalarZnx::bytes_of(n.into(), Self::pairs(rank.into()))
    }
}
/// Builds a [`GLWESecretTensor`] (pairwise products of a secret's columns) on backend `BE`.
pub trait GLWESecretTensorFactory<BE: Backend> {
    /// Scratch bytes required by [`Self::glwe_secret_tensor_prepare`] for a secret of rank `rank`.
    fn glwe_secret_tensor_prepare_tmp_bytes(&self, rank: Rank) -> usize;

    /// Fills `res` from the secret `other`, using `scratch` for temporaries.
    fn glwe_secret_tensor_prepare<R, A>(&self, res: &mut R, other: &A, scratch: &mut ScratchArena<'_, BE>)
    where
        A: GLWESecretToBackendRef<BE> + GetDistribution + GLWEInfos,
        R: GLWESecretToBackendMut<BE> + GetDistributionMut + GLWEInfos;
}
impl<BE: Backend> GLWESecretTensorFactory<BE> for Module<BE>
where
Self: ModuleN
+ GLWESecretPreparedFactory<BE>
+ VecZnxBigNormalize<BE>
+ VecZnxDftApply<BE>
+ SvpApplyDftToDft<BE>
+ VecZnxIdftApplyTmpA<BE>
+ VecZnxBigNormalize<BE>
+ VecZnxDftBytesOf
+ VecZnxBigBytesOf
+ VecZnxBigNormalizeTmpBytes,
{
fn glwe_secret_tensor_prepare_tmp_bytes(&self, rank: Rank) -> usize {
self.glwe_secret_prepared_bytes_of(rank)
}
fn glwe_secret_tensor_prepare<R, A>(&self, res: &mut R, a: &A, scratch: &mut ScratchArena<'_, BE>)
where
R: GLWESecretToBackendMut<BE> + GetDistributionMut + GLWEInfos,
A: GLWESecretToBackendRef<BE> + GetDistribution + GLWEInfos,
{
let res = &mut res.to_backend_mut();
let a = a.to_backend_ref();
assert_eq!(res.rank(), GLWESecretTensor::pairs(a.rank().into()) as u32);
assert_eq!(res.n(), self.n() as u32);
assert_eq!(a.n(), self.n() as u32);
assert!(
scratch.available() >= self.glwe_secret_tensor_prepare_tmp_bytes(a.rank()),
"scratch.available(): {} < GLWESecretTensorFactory::glwe_secret_tensor_prepare_tmp_bytes: {}",
scratch.available(),
self.glwe_secret_tensor_prepare_tmp_bytes(a.rank())
);
let rank: usize = a.rank().into();
let scratch = scratch.borrow();
let (mut a_prepared, _scratch_1) = scratch.take_glwe_secret_prepared_scratch(self, rank.into());
{
let mut a_prepared_data = a_prepared.data.reborrow_backend_mut();
for i in 0..rank {
self.svp_prepare(&mut a_prepared_data, i, &a.data, i);
}
}
a_prepared.dist = *a.dist();
let base2k: usize = 17;
let mut a_dft = VecZnxDft::<BE::OwnedBuf, BE>::alloc(self.n(), rank, 1);
let a_backend_vec = scalar_znx_as_vec_znx_backend_ref_from_ref::<BE>(&a.data);
for i in 0..rank {
let mut a_dft_backend = a_dft.to_backend_mut();
self.vec_znx_dft_apply(1, 0, &mut a_dft_backend, i, &a_backend_vec, i);
}
let mut a_ij_dft = VecZnxDft::<BE::OwnedBuf, BE>::alloc(self.n(), 1, 1);
let a_prepared_backend_ref = a_prepared.data.reborrow_backend_ref();
let mut a_ij_big_backend = self.vec_znx_big_alloc(1, 1);
let mut norm_scratch = ScratchOwned {
data: BE::alloc_bytes(self.vec_znx_big_normalize_tmp_bytes()),
_phantom: std::marker::PhantomData,
};
let mut res_backend = scalar_znx_as_vec_znx_backend_mut_from_mut::<BE>(&mut res.data);
for i in 0..rank {
for j in i..rank {
let idx: usize = i * rank + j - (i * (i + 1) / 2);
let a_dft_ref = a_dft.to_backend_ref();
{
let mut a_ij_dft_backend = a_ij_dft.to_backend_mut();
self.svp_apply_dft_to_dft(&mut a_ij_dft_backend, 0, &a_prepared_backend_ref, j, &a_dft_ref, i);
}
{
let mut a_ij_big = a_ij_big_backend.to_backend_mut();
let mut a_ij_dft = a_ij_dft.to_backend_mut();
self.vec_znx_idft_apply_tmpa(&mut a_ij_big, 0, &mut a_ij_dft, 0);
}
{
let a_ij_big = a_ij_big_backend.to_backend_ref();
self.vec_znx_big_normalize(
&mut res_backend,
base2k,
0,
idx,
&a_ij_big,
base2k,
0,
&mut norm_scratch.arena(),
);
}
}
}
res.dist = *a.dist();
}
}