poulpy_core/encryption/
ggsw_ct.rs

1use poulpy_hal::{
2    api::{
3        ScratchAvailable, SvpApplyInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace,
4        VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigConsume, VecZnxFillUniform,
5        VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, ZnxZero,
6    },
7    layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, VecZnx},
8    source::Source,
9};
10
11use crate::{
12    TakeGLWEPt,
13    layouts::{GGSWCiphertext, GLWECiphertext, Infos, prepared::GLWESecretPrepared},
14};
15
16impl GGSWCiphertext<Vec<u8>> {
17    pub fn encrypt_sk_scratch_space<B: Backend>(module: &Module<B>, n: usize, basek: usize, k: usize, rank: usize) -> usize
18    where
19        Module<B>: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes,
20    {
21        let size = k.div_ceil(basek);
22        GLWECiphertext::encrypt_sk_scratch_space(module, n, basek, k)
23            + VecZnx::alloc_bytes(n, rank + 1, size)
24            + VecZnx::alloc_bytes(n, 1, size)
25            + module.vec_znx_dft_alloc_bytes(n, rank + 1, size)
26    }
27}
28
29impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
30    #[allow(clippy::too_many_arguments)]
31    pub fn encrypt_sk<DataPt: DataRef, DataSk: DataRef, B: Backend>(
32        &mut self,
33        module: &Module<B>,
34        pt: &ScalarZnx<DataPt>,
35        sk: &GLWESecretPrepared<DataSk, B>,
36        source_xa: &mut Source,
37        source_xe: &mut Source,
38        sigma: f64,
39        scratch: &mut Scratch<B>,
40    ) where
41        Module<B>: VecZnxAddScalarInplace
42            + VecZnxDftAllocBytes
43            + VecZnxBigNormalize<B>
44            + VecZnxDftFromVecZnx<B>
45            + SvpApplyInplace<B>
46            + VecZnxDftToVecZnxBigConsume<B>
47            + VecZnxNormalizeTmpBytes
48            + VecZnxFillUniform
49            + VecZnxSubABInplace
50            + VecZnxAddInplace
51            + VecZnxNormalizeInplace<B>
52            + VecZnxAddNormal
53            + VecZnxNormalize<B>
54            + VecZnxSub,
55        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
56    {
57        #[cfg(debug_assertions)]
58        {
59            use poulpy_hal::api::ZnxInfos;
60
61            assert_eq!(self.rank(), sk.rank());
62            assert_eq!(self.n(), sk.n());
63            assert_eq!(pt.n(), sk.n());
64        }
65
66        let basek: usize = self.basek();
67        let k: usize = self.k();
68        let rank: usize = self.rank();
69        let digits: usize = self.digits();
70
71        let (mut tmp_pt, scratch1) = scratch.take_glwe_pt(self.n(), basek, k);
72
73        (0..self.rows()).for_each(|row_i| {
74            tmp_pt.data.zero();
75
76            // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt
77            module.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, (digits - 1) + row_i * digits, pt, 0);
78            module.vec_znx_normalize_inplace(basek, &mut tmp_pt.data, 0, scratch1);
79
80            (0..rank + 1).for_each(|col_j| {
81                // rlwe encrypt of vec_znx_pt into vec_znx_ct
82
83                self.at_mut(row_i, col_j).encrypt_sk_internal(
84                    module,
85                    Some((&tmp_pt, col_j)),
86                    sk,
87                    source_xa,
88                    source_xe,
89                    sigma,
90                    scratch1,
91                );
92            });
93        });
94    }
95}