poulpy_core/decryption/
lwe.rs1use poulpy_hal::{
2 api::{ScratchOwnedAlloc, ScratchOwnedBorrow, ZnNormalizeInplace},
3 layouts::{Backend, DataMut, DataRef, Module, ScratchOwned, ZnxView, ZnxViewMut},
4};
5
6use crate::layouts::{LWE, LWEInfos, LWEPlaintext, LWEPlaintextToMut, LWESecret, LWESecretToRef, LWEToMut};
7
8impl<DataSelf: DataRef + DataMut> LWE<DataSelf> {
9 pub fn decrypt<P, S, M, B: Backend>(&mut self, module: &M, pt: &mut P, sk: &S)
10 where
11 P: LWEPlaintextToMut,
12 S: LWESecretToRef,
13 M: LWEDecrypt<B>,
14 {
15 module.lwe_decrypt(self, pt, sk);
16 }
17}
18
19pub trait LWEDecrypt<BE: Backend> {
20 fn lwe_decrypt<R, P, S>(&self, res: &mut R, pt: &mut P, sk: &S)
21 where
22 R: LWEToMut,
23 P: LWEPlaintextToMut,
24 S: LWESecretToRef;
25}
26
27impl<BE: Backend> LWEDecrypt<BE> for Module<BE>
28where
29 Self: Sized + ZnNormalizeInplace<BE>,
30 ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
31{
32 fn lwe_decrypt<R, P, S>(&self, res: &mut R, pt: &mut P, sk: &S)
33 where
34 R: LWEToMut,
35 P: LWEPlaintextToMut,
36 S: LWESecretToRef,
37 {
38 let res: &mut LWE<&mut [u8]> = &mut res.to_mut();
39 let pt: &mut LWEPlaintext<&mut [u8]> = &mut pt.to_mut();
40 let sk: LWESecret<&[u8]> = sk.to_ref();
41
42 #[cfg(debug_assertions)]
43 {
44 assert_eq!(res.n(), sk.n());
45 }
46
47 (0..pt.size().min(res.size())).for_each(|i| {
48 pt.data.at_mut(0, i)[0] = res.data.at(0, i)[0]
49 + res.data.at(0, i)[1..]
50 .iter()
51 .zip(sk.data.at(0, 0))
52 .map(|(x, y)| x * y)
53 .sum::<i64>();
54 });
55 self.zn_normalize_inplace(
56 1,
57 res.base2k().into(),
58 &mut pt.data,
59 0,
60 ScratchOwned::alloc(size_of::<i64>()).borrow(),
61 );
62 pt.base2k = res.base2k();
63 pt.k = crate::layouts::TorusPrecision(res.k().0.min(pt.size() as u32 * res.base2k().0));
64 }
65}