use poulpy_hal::{
api::{
ScratchAvailable, ScratchTakeBasic, VecZnxAddNormal, VecZnxFillUniform, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes,
},
layouts::{Backend, DataMut, Module, Scratch, ZnxView, ZnxViewMut, ZnxZero},
source::Source,
};
use crate::{
ScratchTakeCore,
encryption::{SIGMA, SIGMA_BOUND},
layouts::{LWE, LWEInfos, LWEPlaintext, LWEPlaintextToRef, LWESecret, LWESecretToRef, LWEToMut},
};
impl LWE<Vec<u8>> {
    /// Returns the number of scratch bytes `module` reports as necessary for a
    /// secret-key LWE encryption of a ciphertext described by `infos`.
    ///
    /// This is a convenience front-end: the value is obtained by forwarding to
    /// [`LWEEncryptSk::lwe_encrypt_sk_tmp_bytes`] on `module`.
    pub fn encrypt_sk_tmp_bytes<M, A, BE: Backend>(module: &M, infos: &A) -> usize
    where
        A: LWEInfos,
        M: LWEEncryptSk<BE>,
    {
        <M as LWEEncryptSk<BE>>::lwe_encrypt_sk_tmp_bytes(module, infos)
    }
}
impl<DataSelf: DataMut> LWE<DataSelf> {
    /// Secret-key encryption of `pt` under `sk`, written into `self`.
    ///
    /// All arguments are forwarded unchanged to
    /// [`LWEEncryptSk::lwe_encrypt_sk`] on `module`, with `self` as the
    /// destination ciphertext.
    ///
    /// * `module` - backend object implementing [`LWEEncryptSk`].
    /// * `pt` - plaintext to encrypt.
    /// * `sk` - secret key to encrypt under.
    /// * `source_xa`, `source_xe` - randomness sources handed to the backend.
    /// * `scratch` - temporary workspace handed to the backend.
    pub fn encrypt_sk<P, S, M, BE: Backend>(
        &mut self,
        module: &M,
        pt: &P,
        sk: &S,
        source_xa: &mut Source,
        source_xe: &mut Source,
        scratch: &mut Scratch<BE>,
    ) where
        P: LWEPlaintextToRef,
        S: LWESecretToRef,
        M: LWEEncryptSk<BE>,
        Scratch<BE>: ScratchTakeCore<BE>,
    {
        <M as LWEEncryptSk<BE>>::lwe_encrypt_sk(module, self, pt, sk, source_xa, source_xe, scratch);
    }
}
/// Secret-key LWE encryption, parameterized by the backend `BE`.
///
/// Implementors provide both the encryption routine and a companion query
/// returning how many scratch bytes that routine needs.
pub trait LWEEncryptSk<BE: Backend> {
/// Returns the number of scratch bytes [`Self::lwe_encrypt_sk`] requires for
/// a ciphertext described by `infos`.
fn lwe_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize
where
A: LWEInfos;
/// Encrypts `pt` under the secret `sk` and writes the ciphertext into `res`.
///
/// * `res` - destination ciphertext.
/// * `pt` - plaintext to encrypt.
/// * `sk` - secret key.
/// * `source_xa` - randomness source; the `Module<BE>` implementation in
///   this file draws the uniform part of the ciphertext from it.
/// * `source_xe` - randomness source; the `Module<BE>` implementation in
///   this file draws the normal noise from it.
/// * `scratch` - temporary workspace; see
///   [`Self::lwe_encrypt_sk_tmp_bytes`] for the required size.
fn lwe_encrypt_sk<R, P, S>(
&self,
res: &mut R,
pt: &P,
sk: &S,
source_xa: &mut Source,
source_xe: &mut Source,
scratch: &mut Scratch<BE>,
) where
R: LWEToMut,
P: LWEPlaintextToRef,
S: LWESecretToRef,
Scratch<BE>: ScratchTakeCore<BE>;
}
impl<BE: Backend> LWEEncryptSk<BE> for Module<BE>
where
    Self: Sized + VecZnxFillUniform + VecZnxAddNormal + VecZnxNormalizeInplace<BE> + VecZnxNormalizeTmpBytes,
    Scratch<BE>: ScratchTakeBasic + ScratchAvailable,
{
    /// Scratch bytes needed by [`LWEEncryptSk::lwe_encrypt_sk`]: the bytes of an
    /// `LWEPlaintext` with `infos.size()` limbs plus the workspace reported by
    /// `vec_znx_normalize_tmp_bytes`.
    fn lwe_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize
    where
        A: LWEInfos,
    {
        LWEPlaintext::bytes_of(infos.size()) + self.vec_znx_normalize_tmp_bytes()
    }

    /// Encrypts `pt` under `sk` into `res`.
    ///
    /// Steps, in order:
    /// 1. Column 0 of `res.data` is filled with uniform values drawn from
    ///    `source_xa`.
    /// 2. For every limb `i` of `res`, an accumulator coefficient is set to
    ///    `pt[i][0]` (or `0` when `pt` has no limb `i`) minus the inner product
    ///    of `res.data[i][1..]` with the first limb of `sk`.
    /// 3. Normal noise parameterized by `SIGMA` / `SIGMA_BOUND` is added from
    ///    `source_xe` at precision `res.k()`, then the accumulator is
    ///    normalized in base `2^base2k`.
    /// 4. The accumulator is copied into coefficient 0 of every limb of `res`.
    ///
    /// # Panics
    ///
    /// Panics if `scratch` has fewer available bytes than
    /// [`LWEEncryptSk::lwe_encrypt_sk_tmp_bytes`] reports for `res`, and, in
    /// debug builds, if `res.n() != sk.n()`.
    fn lwe_encrypt_sk<R, P, S>(
        &self,
        res: &mut R,
        pt: &P,
        sk: &S,
        source_xa: &mut Source,
        source_xe: &mut Source,
        scratch: &mut Scratch<BE>,
    ) where
        R: LWEToMut,
        P: LWEPlaintextToRef,
        S: LWESecretToRef,
        Scratch<BE>: ScratchTakeCore<BE>,
    {
        let res: &mut LWE<&mut [u8]> = &mut res.to_mut();
        let pt: &LWEPlaintext<&[u8]> = &pt.to_ref();
        let sk: &LWESecret<&[u8]> = &sk.to_ref();

        debug_assert_eq!(res.n(), sk.n());

        let required: usize = self.lwe_encrypt_sk_tmp_bytes(res);
        assert!(
            scratch.available() >= required,
            "scratch.available(): {} < LWEEncryptSk::lwe_encrypt_sk_tmp_bytes: {}",
            scratch.available(),
            required
        );

        let base2k: usize = res.base2k().into();
        let k: usize = res.k().into();
        let limbs: usize = res.size();

        // Uniform mask first: the body computed below depends on it.
        self.vec_znx_fill_uniform(base2k, &mut res.data, 0, source_xa);

        // Accumulator for the body coefficient of each limb.
        let (mut acc, scratch_rest) = scratch.take_vec_znx(1, 1, limbs);
        acc.zero();

        // <mask of limb `limb`, secret>; the mask is everything after coefficient 0.
        let mask_dot_sk = |limb: usize| -> i64 {
            res.data.at(0, limb)[1..]
                .iter()
                .zip(sk.data.at(0, 0))
                .map(|(a, s)| a * s)
                .sum::<i64>()
        };

        // Limbs present in both `res` and `pt` start from the plaintext value;
        // the remaining limbs of `res` start from the zeroed accumulator.
        let shared: usize = limbs.min(pt.size());
        for limb in 0..limbs {
            let dot: i64 = mask_dot_sk(limb);
            if limb < shared {
                acc.at_mut(0, limb)[0] = pt.data.at(0, limb)[0] - dot;
            } else {
                acc.at_mut(0, limb)[0] -= dot;
            }
        }

        self.vec_znx_add_normal(base2k, &mut acc, 0, k, source_xe, SIGMA, SIGMA_BOUND);
        self.vec_znx_normalize_inplace(base2k, &mut acc, 0, scratch_rest);

        // Store the noisy, normalized body in coefficient 0 of each limb.
        for limb in 0..limbs {
            res.data.at_mut(0, limb)[0] = acc.at(0, limb)[0];
        }
    }
}