// poulpy_core/decryption/lwe.rs

1use poulpy_hal::{
2    api::{ScratchOwnedAlloc, ScratchOwnedBorrow, ZnNormalizeInplace},
3    layouts::{Backend, DataMut, DataRef, Module, ScratchOwned, ZnxView, ZnxViewMut},
4};
5
6use crate::layouts::{LWE, LWEInfos, LWEPlaintext, LWEPlaintextToMut, LWESecret, LWESecretToRef, LWEToMut};
7
8impl<DataSelf: DataRef + DataMut> LWE<DataSelf> {
9    pub fn decrypt<P, S, M, B: Backend>(&mut self, module: &M, pt: &mut P, sk: &S)
10    where
11        P: LWEPlaintextToMut,
12        S: LWESecretToRef,
13        M: LWEDecrypt<B>,
14    {
15        module.lwe_decrypt(self, pt, sk);
16    }
17}
18
19pub trait LWEDecrypt<BE: Backend> {
20    fn lwe_decrypt<R, P, S>(&self, res: &mut R, pt: &mut P, sk: &S)
21    where
22        R: LWEToMut,
23        P: LWEPlaintextToMut,
24        S: LWESecretToRef;
25}
26
27impl<BE: Backend> LWEDecrypt<BE> for Module<BE>
28where
29    Self: Sized + ZnNormalizeInplace<BE>,
30    ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
31{
32    fn lwe_decrypt<R, P, S>(&self, res: &mut R, pt: &mut P, sk: &S)
33    where
34        R: LWEToMut,
35        P: LWEPlaintextToMut,
36        S: LWESecretToRef,
37    {
38        let res: &mut LWE<&mut [u8]> = &mut res.to_mut();
39        let pt: &mut LWEPlaintext<&mut [u8]> = &mut pt.to_mut();
40        let sk: LWESecret<&[u8]> = sk.to_ref();
41
42        #[cfg(debug_assertions)]
43        {
44            assert_eq!(res.n(), sk.n());
45        }
46
47        (0..pt.size().min(res.size())).for_each(|i| {
48            pt.data.at_mut(0, i)[0] = res.data.at(0, i)[0]
49                + res.data.at(0, i)[1..]
50                    .iter()
51                    .zip(sk.data.at(0, 0))
52                    .map(|(x, y)| x * y)
53                    .sum::<i64>();
54        });
55        self.zn_normalize_inplace(
56            1,
57            res.base2k().into(),
58            &mut pt.data,
59            0,
60            ScratchOwned::alloc(size_of::<i64>()).borrow(),
61        );
62        pt.base2k = res.base2k();
63        pt.k = crate::layouts::TorusPrecision(res.k().0.min(pt.size() as u32 * res.base2k().0));
64    }
65}