use poulpy_hal::{
api::{
ScratchAvailable, ScratchTakeBasic, VecZnxBigAddSmallInplace, VecZnxBigBytesOf, VecZnxBigNormalize,
VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize,
VecZnxNormalizeTmpBytes,
},
layouts::{Backend, DataMut, Module, Scratch, VecZnx, VecZnxBig, VecZnxDft, VecZnxDftToRef, VecZnxToRef, ZnxZero},
};
use crate::{
GGLWEProduct, GLWECopy, ScratchTakeCore,
layouts::{
GGLWE, GGLWEInfos, GGLWEToGGSWKeyPrepared, GGLWEToGGSWKeyPreparedToRef, GGLWEToRef, GGSW, GGSWInfos, GGSWToMut, GLWE,
GLWEInfos, LWEInfos,
},
};
impl GGLWE<Vec<u8>> {
    /// Returns the number of scratch bytes needed to build a [`GGSW`] from a [`GGLWE`]
    /// (see [`GGSW::from_gglwe`]).
    ///
    /// This is a thin wrapper around [`GGSWFromGGLWE::ggsw_from_gglwe_tmp_bytes`]:
    /// `res_infos` describes the output GGSW and `tsk_infos` the conversion key.
    ///
    /// NOTE(review): the name reads `gglw` (missing `e`), and the method lives on `GGLWE`
    /// although it sizes the GGSW-from-GGLWE conversion. It is kept as-is for API compatibility.
    pub fn from_gglw_tmp_bytes<R, A, M, BE: Backend>(module: &M, res_infos: &R, tsk_infos: &A) -> usize
    where
        M: GGSWFromGGLWE<BE>,
        R: GGSWInfos,
        A: GGLWEInfos,
    {
        <M as GGSWFromGGLWE<BE>>::ggsw_from_gglwe_tmp_bytes(module, res_infos, tsk_infos)
    }
}
impl<D: DataMut> GGSW<D> {
    /// Fills `self` with the GGSW obtained from `gglwe`, using the prepared key `tsk`.
    ///
    /// This delegates to [`GGSWFromGGLWE::ggsw_from_gglwe`]. `scratch` should hold at least
    /// [`GGSWFromGGLWE::ggsw_from_gglwe_tmp_bytes`] bytes; the `Module` implementation
    /// asserts this.
    pub fn from_gglwe<G, M, T, BE: Backend>(&mut self, module: &M, gglwe: &G, tsk: &T, scratch: &mut Scratch<BE>)
    where
        M: GGSWFromGGLWE<BE>,
        G: GGLWEToRef,
        T: GGLWEToGGSWKeyPreparedToRef<BE>,
        Scratch<BE>: ScratchTakeCore<BE>,
    {
        <M as GGSWFromGGLWE<BE>>::ggsw_from_gglwe(module, self, gglwe, tsk, scratch)
    }
}
impl<BE: Backend> GGSWFromGGLWE<BE> for Module<BE>
where
    Self: GGSWExpandRows<BE> + GLWECopy,
{
    /// Scratch needed by [`GGSWFromGGLWE::ggsw_from_gglwe`].
    ///
    /// The row copy takes no scratch, so this is exactly the requirement of
    /// [`GGSWExpandRows::ggsw_expand_row`].
    fn ggsw_from_gglwe_tmp_bytes<R, A>(&self, res_infos: &R, tsk_infos: &A) -> usize
    where
        R: GGSWInfos,
        A: GGLWEInfos,
    {
        self.ggsw_expand_rows_tmp_bytes(res_infos, tsk_infos)
    }

    /// Builds the GGSW `res` from the GGLWE `a`.
    ///
    /// For every row, column 0 of `res` is copied from column 0 of the matching row of `a`.
    /// The remaining columns are then produced by [`GGSWExpandRows::ggsw_expand_row`]
    /// with the prepared key `tsk`.
    ///
    /// # Panics
    /// - If `res.rank() != a.rank_out()` or `res.dnum() != a.dnum()`.
    /// - If the ring degree of `res`, `a` or `tsk` differs from the module's.
    /// - If `res` and `a` use different base2k.
    /// - If `scratch` holds fewer than [`GGSWFromGGLWE::ggsw_from_gglwe_tmp_bytes`] bytes.
    fn ggsw_from_gglwe<R, A, T>(&self, res: &mut R, a: &A, tsk: &T, scratch: &mut Scratch<BE>)
    where
        R: GGSWToMut,
        A: GGLWEToRef,
        T: GGLWEToGGSWKeyPreparedToRef<BE>,
        Scratch<BE>: ScratchTakeCore<BE>,
    {
        let ggsw: &mut GGSW<&mut [u8]> = &mut res.to_mut();
        let gglwe: &GGLWE<&[u8]> = &a.to_ref();
        let key: &GGLWEToGGSWKeyPrepared<&[u8], BE> = &tsk.to_ref();

        let ring_degree: u32 = self.n() as u32;
        assert_eq!(ggsw.rank(), gglwe.rank_out());
        assert_eq!(ggsw.dnum(), gglwe.dnum());
        assert_eq!(ggsw.n(), ring_degree);
        assert_eq!(gglwe.n(), ring_degree);
        assert_eq!(key.n(), ring_degree);
        assert_eq!(ggsw.base2k(), gglwe.base2k());

        let required: usize = self.ggsw_from_gglwe_tmp_bytes(ggsw, key);
        assert!(
            scratch.available() >= required,
            "scratch.available(): {} < GGSWFromGGLWE::ggsw_from_gglwe_tmp_bytes: {}",
            scratch.available(),
            required
        );

        // Seed column 0 of each row; the expansion below derives the other columns from it.
        for row in 0..ggsw.dnum().as_usize() {
            let src = gglwe.at(row, 0);
            self.glwe_copy(&mut ggsw.at_mut(row, 0), &src);
        }

        self.ggsw_expand_row(ggsw, key, scratch);
    }
}
/// Conversion of a GGLWE into a GGSW using a prepared [`GGLWEToGGSWKeyPrepared`] key.
pub trait GGSWFromGGLWE<BE: Backend> {
/// Returns the number of scratch bytes [`GGSWFromGGLWE::ggsw_from_gglwe`] needs for an
/// output with layout `res_infos` and a key with layout `tsk_infos`.
fn ggsw_from_gglwe_tmp_bytes<R, A>(&self, res_infos: &R, tsk_infos: &A) -> usize
where
R: GGSWInfos,
A: GGLWEInfos;
/// Writes into `res` the GGSW obtained from the GGLWE `a` and the key `tsk`,
/// taking temporaries from `scratch`.
fn ggsw_from_gglwe<R, A, T>(&self, res: &mut R, a: &A, tsk: &T, scratch: &mut Scratch<BE>)
where
R: GGSWToMut,
A: GGLWEToRef,
T: GGLWEToGGSWKeyPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>;
}
/// Completes the rows of a GGSW from their column 0 using a prepared
/// [`GGLWEToGGSWKeyPrepared`] key. In the `Module` implementation, column 0 of every row is
/// read and columns `1..=rank` are written.
pub trait GGSWExpandRows<BE: Backend> {
/// Returns the number of scratch bytes [`GGSWExpandRows::ggsw_expand_row`] needs for a GGSW
/// with layout `res_infos` and a key with layout `tsk_infos`.
fn ggsw_expand_rows_tmp_bytes<R, A>(&self, res_infos: &R, tsk_infos: &A) -> usize
where
R: GGSWInfos,
A: GGLWEInfos;
/// Expands every row of `res` in place from its column 0 using `tsk`,
/// taking temporaries from `scratch`.
fn ggsw_expand_row<R, T>(&self, res: &mut R, tsk: &T, scratch: &mut Scratch<BE>)
where
R: GGSWToMut,
T: GGLWEToGGSWKeyPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>;
}
impl<BE: Backend> GGSWExpandRows<BE> for Module<BE>
where
    Self: GGLWEProduct<BE>
        + VecZnxBigNormalize<BE>
        + VecZnxBigNormalizeTmpBytes
        + VecZnxBigBytesOf
        + VecZnxDftBytesOf
        + VecZnxDftApply<BE>
        + VecZnxNormalize<BE>
        + VecZnxNormalizeTmpBytes
        + VecZnxBigAddSmallInplace<BE>
        + VecZnxIdftApplyConsume<BE>
        + VecZnxCopy,
{
    /// Scratch bytes needed by [`GGSWExpandRows::ggsw_expand_row`].
    ///
    /// The requirement is split into nested levels:
    /// - Level 0 is held for the whole call. It contains `a_dft` (`rank` DFT columns) and
    ///   `a_0` (one column), both with `a_size` limbs in the key's base2k.
    /// - Level 1 is live per output column. It contains the product buffer (`rank + 1` DFT
    ///   columns with the key's limb count) plus the larger of the GGLWE-product scratch
    ///   and the big-normalize scratch.
    /// - Level 2 is used only when the GGSW and the key use different base2k. It is the
    ///   scratch for re-expressing the input row in the key's base2k.
    ///
    /// # Panics
    /// If the ring degree of `res_infos` or `tsk_infos` differs from the module's.
    fn ggsw_expand_rows_tmp_bytes<R, A>(&self, res_infos: &R, tsk_infos: &A) -> usize
    where
        R: GGSWInfos,
        A: GGLWEInfos,
    {
        assert_eq!(self.n() as u32, res_infos.n());
        assert_eq!(self.n() as u32, tsk_infos.n());

        let tsk_base2k: usize = tsk_infos.base2k().into();
        let rank: usize = res_infos.rank().into();
        let cols: usize = rank + 1;

        // Limbs needed to hold the GGSW's precision (`max_k`) in the key's base2k.
        let a_size: usize = res_infos.max_k().as_usize().div_ceil(tsk_base2k);

        // `ggsw_expand_rows_internal` takes its per-column product buffer with the key's limb
        // count (`tsk.size()`). That buffer, and the scratch of the product writing into it,
        // must therefore be sized from the key. Sizing them from `a_size` or
        // `res_infos.size()` under-reserves whenever the key carries more limbs than the GGSW.
        let res_dft_size: usize = tsk_infos.size();

        // Level 0: `a_dft` + `a_0`, live for the whole expansion.
        let lvl_0: usize = self.bytes_of_vec_znx_dft(cols - 1, a_size) + VecZnx::bytes_of(self.n(), 1, a_size);

        // Level 1: product buffer (kept alive through normalization, since the big vector
        // reuses its memory), then the product or normalize scratch.
        let lvl_1_res_dft: usize = self.bytes_of_vec_znx_dft(cols, res_dft_size);
        let lvl_1_gglwe_prod: usize = self.gglwe_product_dft_tmp_bytes(res_dft_size, a_size, tsk_infos);
        let lvl_1_norm_big: usize = self.vec_znx_big_normalize_tmp_bytes();
        let lvl_1: usize = lvl_1_res_dft + lvl_1_gglwe_prod.max(lvl_1_norm_big);

        // Level 2: base2k conversion of the input row, only needed when the bases differ.
        let lvl_2: usize = if res_infos.base2k() == tsk_infos.base2k() {
            0
        } else {
            self.vec_znx_normalize_tmp_bytes()
        };

        lvl_0 + lvl_1.max(lvl_2)
    }

    /// Fills columns `1..=rank` of every row of `res` from that row's column 0, using `tsk`.
    ///
    /// For each row, column 0 is first re-expressed in the key's base2k:
    /// - When the bases match, its columns `1..=rank` are DFT'd directly into `a_dft` and its
    ///   column 0 is copied into `a_0`.
    /// - Otherwise, each column is normalized into the key's base2k first.
    ///
    /// `ggsw_expand_rows_internal` then writes the remaining columns of that row.
    ///
    /// # Panics
    /// If `scratch` holds fewer than [`GGSWExpandRows::ggsw_expand_rows_tmp_bytes`] bytes,
    /// or if the ring degrees of `res`/`tsk` differ from the module's.
    fn ggsw_expand_row<R, T>(&self, res: &mut R, tsk: &T, scratch: &mut Scratch<BE>)
    where
        R: GGSWToMut,
        T: GGLWEToGGSWKeyPreparedToRef<BE>,
        Scratch<BE>: ScratchTakeCore<BE>,
    {
        let res: &mut GGSW<&mut [u8]> = &mut res.to_mut();
        let tsk: &GGLWEToGGSWKeyPrepared<&[u8], BE> = &tsk.to_ref();

        let res_base2k: usize = res.base2k().into();
        let tsk_base2k: usize = tsk.base2k().into();

        let required: usize = self.ggsw_expand_rows_tmp_bytes(res, tsk);
        assert!(
            scratch.available() >= required,
            "scratch.available(): {} < GGSWExpandRows::ggsw_expand_rows_tmp_bytes: {}",
            scratch.available(),
            required
        );

        let rank: usize = res.rank().into();
        let cols: usize = rank + 1;

        // Limbs needed to carry the GGSW's precision in the key's base2k (level 0 above).
        let res_conv_size: usize = res.max_k().as_usize().div_ceil(tsk_base2k);

        let (mut a_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols - 1, res_conv_size);
        let (mut a_0, scratch_2) = scratch_1.take_vec_znx(self.n(), 1, res_conv_size);

        for row in 0..res.dnum().as_usize() {
            let glwe_mi_1: &GLWE<&[u8]> = &res.at(row, 0);

            if res_base2k == tsk_base2k {
                // Same base: DFT columns 1..=rank directly and copy column 0 as-is.
                for col_i in 0..cols - 1 {
                    self.vec_znx_dft_apply(1, 0, &mut a_dft, col_i, glwe_mi_1.data(), col_i + 1);
                }
                self.vec_znx_copy(&mut a_0, 0, glwe_mi_1.data(), 0);
            } else {
                // Different bases: normalize each column into the key's base2k first.
                // `a_0` doubles as the staging buffer for columns 1..=rank, and finally
                // receives column 0.
                for i in 0..cols - 1 {
                    self.vec_znx_normalize(&mut a_0, tsk_base2k, 0, 0, glwe_mi_1.data(), res_base2k, i + 1, scratch_2);
                    self.vec_znx_dft_apply(1, 0, &mut a_dft, i, &a_0, 0);
                }
                self.vec_znx_normalize(&mut a_0, tsk_base2k, 0, 0, glwe_mi_1.data(), res_base2k, 0, scratch_2);
            }

            ggsw_expand_rows_internal(self, row, res, &a_0, &a_dft, tsk, scratch_2)
        }
    }
}
/// Writes columns `1..=rank` of row `row` of `res` from that row's column 0.
///
/// In the only caller visible here (`ggsw_expand_row`), the inputs are prepared as follows:
/// - `a_dft` holds columns `1..=rank` of the row's column-0 GLWE in the DFT domain.
/// - `a_0` holds its column 0.
/// - Both are expressed in the key's base2k.
///
/// For each output column `col`, the function:
/// 1. Multiplies `a_dft` with `tsk.at(col - 1)` into a `rank + 1`-column buffer of
///    `tsk.size()` limbs.
/// 2. Leaves the DFT domain.
/// 3. Adds `a_0` onto column `col`.
/// 4. Normalizes every column from the key's base2k into `res`'s base2k.
///
/// NOTE(review): presumably step 3 places the message term on the diagonal of the GGSW
/// (column `col` of the key-switched row); confirm against the scheme's derivation.
fn ggsw_expand_rows_internal<M, R, C, A, T, BE: Backend>(
    module: &M,
    row: usize,
    res: &mut R,
    a_0: &C,
    a_dft: &A,
    tsk: &T,
    scratch: &mut Scratch<BE>,
) where
    R: GGSWToMut,
    C: VecZnxToRef,
    A: VecZnxDftToRef<BE>,
    M: GGLWEProduct<BE> + VecZnxIdftApplyConsume<BE> + VecZnxBigAddSmallInplace<BE> + VecZnxBigNormalize<BE>,
    T: GGLWEToGGSWKeyPreparedToRef<BE>,
    Scratch<BE>: ScratchTakeCore<BE>,
{
    let ggsw: &mut GGSW<&mut [u8]> = &mut res.to_mut();
    let seed_c0: &VecZnx<&[u8]> = &a_0.to_ref();
    let seed_dft: &VecZnxDft<&[u8], BE> = &a_dft.to_ref();
    let key: &GGLWEToGGSWKeyPrepared<&[u8], BE> = &tsk.to_ref();

    // Loop-invariant layout parameters.
    let cols: usize = ggsw.rank().as_usize() + 1;
    let out_base2k: usize = ggsw.base2k().as_usize();
    let key_base2k: usize = key.base2k().as_usize();
    let key_size: usize = key.size();

    for col in 1..cols {
        // The buffer is re-taken each iteration; its memory is released when `big` goes out of scope.
        let (mut prod_dft, scratch_rest) = scratch.take_vec_znx_dft(module, cols, key_size);
        prod_dft.zero();

        module.gglwe_product_dft(&mut prod_dft, seed_dft, key.at(col - 1), scratch_rest);

        // Leave the DFT domain (reusing the buffer) and add column 0 of the seed onto `col`.
        let mut big: VecZnxBig<&mut [u8], BE> = module.vec_znx_idft_apply_consume(prod_dft);
        module.vec_znx_big_add_small_inplace(&mut big, col, seed_c0, 0);

        for limb_col in 0..cols {
            module.vec_znx_big_normalize(
                ggsw.at_mut(row, col).data_mut(),
                out_base2k,
                0,
                limb_col,
                &big,
                key_base2k,
                limb_col,
                scratch_rest,
            );
        }
    }
}