// poulpy_core/decryption/glwe_ct.rs

1use poulpy_hal::{
2    api::{
3        SvpApplyDftToDftInplace, TakeVecZnxBig, TakeVecZnxDft, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigNormalize,
4        VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNormalizeTmpBytes,
5    },
6    layouts::{Backend, DataMut, DataRef, DataViewMut, Module, Scratch},
7};
8
9use crate::layouts::{GLWECiphertext, GLWEInfos, GLWEPlaintext, LWEInfos, prepared::GLWESecretPrepared};
10
11impl GLWECiphertext<Vec<u8>> {
12    pub fn decrypt_scratch_space<B: Backend, A>(module: &Module<B>, infos: &A) -> usize
13    where
14        A: GLWEInfos,
15        Module<B>: VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes,
16    {
17        let size: usize = infos.size();
18        (module.vec_znx_normalize_tmp_bytes() | module.vec_znx_dft_alloc_bytes(1, size)) + module.vec_znx_dft_alloc_bytes(1, size)
19    }
20}
21
22impl<DataSelf: DataRef> GLWECiphertext<DataSelf> {
23    pub fn decrypt<DataPt: DataMut, DataSk: DataRef, B: Backend>(
24        &self,
25        module: &Module<B>,
26        pt: &mut GLWEPlaintext<DataPt>,
27        sk: &GLWESecretPrepared<DataSk, B>,
28        scratch: &mut Scratch<B>,
29    ) where
30        Module<B>: VecZnxDftApply<B>
31            + SvpApplyDftToDftInplace<B>
32            + VecZnxIdftApplyConsume<B>
33            + VecZnxBigAddInplace<B>
34            + VecZnxBigAddSmallInplace<B>
35            + VecZnxBigNormalize<B>,
36        Scratch<B>: TakeVecZnxDft<B> + TakeVecZnxBig<B>,
37    {
38        #[cfg(debug_assertions)]
39        {
40            assert_eq!(self.rank(), sk.rank());
41            assert_eq!(self.n(), sk.n());
42            assert_eq!(pt.n(), sk.n());
43        }
44
45        let cols: usize = (self.rank() + 1).into();
46
47        let (mut c0_big, scratch_1) = scratch.take_vec_znx_big(self.n().into(), 1, self.size()); // TODO optimize size when pt << ct
48        c0_big.data_mut().fill(0);
49
50        {
51            (1..cols).for_each(|i| {
52                // ci_dft = DFT(a[i]) * DFT(s[i])
53                let (mut ci_dft, _) = scratch_1.take_vec_znx_dft(self.n().into(), 1, self.size()); // TODO optimize size when pt << ct
54                module.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, &self.data, i);
55                module.svp_apply_dft_to_dft_inplace(&mut ci_dft, 0, &sk.data, i - 1);
56                let ci_big = module.vec_znx_idft_apply_consume(ci_dft);
57
58                // c0_big += a[i] * s[i]
59                module.vec_znx_big_add_inplace(&mut c0_big, 0, &ci_big, 0);
60            });
61        }
62
63        // c0_big = (a * s) + (-a * s + m + e) = BIG(m + e)
64        module.vec_znx_big_add_small_inplace(&mut c0_big, 0, &self.data, 0);
65
66        // pt = norm(BIG(m + e))
67        module.vec_znx_big_normalize(
68            self.base2k().into(),
69            &mut pt.data,
70            0,
71            self.base2k().into(),
72            &c0_big,
73            0,
74            scratch_1,
75        );
76
77        pt.base2k = self.base2k();
78        pt.k = pt.k().min(self.k());
79    }
80}