// poulpy_core/encryption/ggsw_ct.rs

1use poulpy_hal::{
2    api::{
3        DFT, IDFTConsume, ScratchAvailable, SvpApplyInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal,
4        VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxFillUniform, VecZnxNormalize,
5        VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace,
6    },
7    layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, VecZnx, ZnxZero},
8    source::Source,
9};
10
11use crate::{
12    TakeGLWEPt,
13    layouts::{GGSWCiphertext, GLWECiphertext, Infos, prepared::GLWESecretPrepared},
14};
15
16impl GGSWCiphertext<Vec<u8>> {
17    pub fn encrypt_sk_scratch_space<B: Backend>(module: &Module<B>, basek: usize, k: usize, rank: usize) -> usize
18    where
19        Module<B>: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes,
20    {
21        let size = k.div_ceil(basek);
22        GLWECiphertext::encrypt_sk_scratch_space(module, basek, k)
23            + VecZnx::alloc_bytes(module.n(), rank + 1, size)
24            + VecZnx::alloc_bytes(module.n(), 1, size)
25            + module.vec_znx_dft_alloc_bytes(rank + 1, size)
26    }
27}
28
29impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
30    #[allow(clippy::too_many_arguments)]
31    pub fn encrypt_sk<DataPt: DataRef, DataSk: DataRef, B: Backend>(
32        &mut self,
33        module: &Module<B>,
34        pt: &ScalarZnx<DataPt>,
35        sk: &GLWESecretPrepared<DataSk, B>,
36        source_xa: &mut Source,
37        source_xe: &mut Source,
38        scratch: &mut Scratch<B>,
39    ) where
40        Module<B>: VecZnxAddScalarInplace
41            + VecZnxDftAllocBytes
42            + VecZnxBigNormalize<B>
43            + DFT<B>
44            + SvpApplyInplace<B>
45            + IDFTConsume<B>
46            + VecZnxNormalizeTmpBytes
47            + VecZnxFillUniform
48            + VecZnxSubABInplace
49            + VecZnxAddInplace
50            + VecZnxNormalizeInplace<B>
51            + VecZnxAddNormal
52            + VecZnxNormalize<B>
53            + VecZnxSub,
54        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
55    {
56        #[cfg(debug_assertions)]
57        {
58            use poulpy_hal::layouts::ZnxInfos;
59
60            assert_eq!(self.rank(), sk.rank());
61            assert_eq!(self.n(), sk.n());
62            assert_eq!(pt.n(), sk.n());
63        }
64
65        let basek: usize = self.basek();
66        let k: usize = self.k();
67        let rank: usize = self.rank();
68        let digits: usize = self.digits();
69
70        let (mut tmp_pt, scratch1) = scratch.take_glwe_pt(self.n(), basek, k);
71
72        (0..self.rows()).for_each(|row_i| {
73            tmp_pt.data.zero();
74
75            // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt
76            module.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, (digits - 1) + row_i * digits, pt, 0);
77            module.vec_znx_normalize_inplace(basek, &mut tmp_pt.data, 0, scratch1);
78
79            (0..rank + 1).for_each(|col_j| {
80                // rlwe encrypt of vec_znx_pt into vec_znx_ct
81
82                self.at_mut(row_i, col_j).encrypt_sk_internal(
83                    module,
84                    Some((&tmp_pt, col_j)),
85                    sk,
86                    source_xa,
87                    source_xe,
88                    scratch1,
89                );
90            });
91        });
92    }
93}