poulpy_core/decryption/glwe.rs

1use poulpy_hal::{
2    api::{
3        ModuleN, ScratchTakeBasic, SvpApplyDftToDftInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigBytesOf,
4        VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalizeTmpBytes,
5    },
6    layouts::{Backend, DataRef, DataViewMut, Module, Scratch},
7};
8
9use crate::layouts::{
10    GLWE, GLWEInfos, GLWEPlaintext, GLWEPlaintextToMut, GLWEToRef, LWEInfos,
11    prepared::{GLWESecretPrepared, GLWESecretPreparedToRef},
12};
13
14impl GLWE<Vec<u8>> {
15    pub fn decrypt_tmp_bytes<A, M, BE: Backend>(module: &M, a_infos: &A) -> usize
16    where
17        A: GLWEInfos,
18        M: GLWEDecrypt<BE>,
19    {
20        module.glwe_decrypt_tmp_bytes(a_infos)
21    }
22}
23
24impl<DataSelf: DataRef> GLWE<DataSelf> {
25    pub fn decrypt<P, S, M, BE: Backend>(&self, module: &M, pt: &mut P, sk: &S, scratch: &mut Scratch<BE>)
26    where
27        P: GLWEPlaintextToMut,
28        S: GLWESecretPreparedToRef<BE>,
29        M: GLWEDecrypt<BE>,
30        Scratch<BE>: ScratchTakeBasic,
31    {
32        module.glwe_decrypt(self, pt, sk, scratch);
33    }
34}
35
36pub trait GLWEDecrypt<BE: Backend>
37where
38    Self: Sized
39        + ModuleN
40        + VecZnxDftBytesOf
41        + VecZnxNormalizeTmpBytes
42        + VecZnxBigBytesOf
43        + VecZnxDftApply<BE>
44        + SvpApplyDftToDftInplace<BE>
45        + VecZnxIdftApplyConsume<BE>
46        + VecZnxBigAddInplace<BE>
47        + VecZnxBigAddSmallInplace<BE>
48        + VecZnxBigNormalize<BE>,
49{
50    fn glwe_decrypt_tmp_bytes<A>(&self, infos: &A) -> usize
51    where
52        A: GLWEInfos,
53    {
54        let size: usize = infos.size();
55        (self.vec_znx_normalize_tmp_bytes() | self.bytes_of_vec_znx_dft(1, size)) + self.bytes_of_vec_znx_dft(1, size)
56    }
57
58    fn glwe_decrypt<R, P, S>(&self, res: &R, pt: &mut P, sk: &S, scratch: &mut Scratch<BE>)
59    where
60        R: GLWEToRef,
61        P: GLWEPlaintextToMut,
62        S: GLWESecretPreparedToRef<BE>,
63        Scratch<BE>: ScratchTakeBasic,
64    {
65        let res: &GLWE<&[u8]> = &res.to_ref();
66        let pt: &mut GLWEPlaintext<&mut [u8]> = &mut pt.to_ref();
67        let sk: &GLWESecretPrepared<&[u8], BE> = &sk.to_ref();
68
69        #[cfg(debug_assertions)]
70        {
71            assert_eq!(res.rank(), sk.rank());
72            assert_eq!(res.n(), sk.n());
73            assert_eq!(pt.n(), sk.n());
74        }
75
76        let cols: usize = (res.rank() + 1).into();
77
78        let (mut c0_big, scratch_1) = scratch.take_vec_znx_big(self, 1, res.size()); // TODO optimize size when pt << ct
79        c0_big.data_mut().fill(0);
80
81        {
82            (1..cols).for_each(|i| {
83                // ci_dft = DFT(a[i]) * DFT(s[i])
84                let (mut ci_dft, _) = scratch_1.take_vec_znx_dft(self, 1, res.size()); // TODO optimize size when pt << ct
85                self.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, &res.data, i);
86                self.svp_apply_dft_to_dft_inplace(&mut ci_dft, 0, &sk.data, i - 1);
87                let ci_big = self.vec_znx_idft_apply_consume(ci_dft);
88
89                // c0_big += a[i] * s[i]
90                self.vec_znx_big_add_inplace(&mut c0_big, 0, &ci_big, 0);
91            });
92        }
93
94        // c0_big = (a * s) + (-a * s + m + e) = BIG(m + e)
95        self.vec_znx_big_add_small_inplace(&mut c0_big, 0, &res.data, 0);
96
97        // pt = norm(BIG(m + e))
98        self.vec_znx_big_normalize(
99            res.base2k().into(),
100            &mut pt.data,
101            0,
102            res.base2k().into(),
103            &c0_big,
104            0,
105            scratch_1,
106        );
107
108        pt.base2k = res.base2k();
109        pt.k = pt.k().min(res.k());
110    }
111}
112
113impl<BE: Backend> GLWEDecrypt<BE> for Module<BE> where
114    Self: ModuleN
115        + VecZnxDftBytesOf
116        + VecZnxNormalizeTmpBytes
117        + VecZnxBigBytesOf
118        + VecZnxDftApply<BE>
119        + SvpApplyDftToDftInplace<BE>
120        + VecZnxIdftApplyConsume<BE>
121        + VecZnxBigAddInplace<BE>
122        + VecZnxBigAddSmallInplace<BE>
123        + VecZnxBigNormalize<BE>
124{
125}