// poulpy_core/encryption/ggsw_ct.rs

1use poulpy_hal::{
2    api::{
3        ScratchAvailable, SvpApplyDftToDftInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal,
4        VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform,
5        VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace,
6    },
7    layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, VecZnx, ZnxZero},
8    source::Source,
9};
10
11use crate::{
12    TakeGLWEPt,
13    layouts::{GGSWCiphertext, GGSWInfos, GLWECiphertext, GLWEInfos, LWEInfos, prepared::GLWESecretPrepared},
14};
15
16impl GGSWCiphertext<Vec<u8>> {
17    pub fn encrypt_sk_scratch_space<B: Backend, A>(module: &Module<B>, infos: &A) -> usize
18    where
19        A: GGSWInfos,
20        Module<B>: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes,
21    {
22        let size = infos.size();
23        GLWECiphertext::encrypt_sk_scratch_space(module, &infos.glwe_layout())
24            + VecZnx::alloc_bytes(module.n(), (infos.rank() + 1).into(), size)
25            + VecZnx::alloc_bytes(module.n(), 1, size)
26            + module.vec_znx_dft_alloc_bytes((infos.rank() + 1).into(), size)
27    }
28}
29
30impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
31    #[allow(clippy::too_many_arguments)]
32    pub fn encrypt_sk<DataPt: DataRef, DataSk: DataRef, B: Backend>(
33        &mut self,
34        module: &Module<B>,
35        pt: &ScalarZnx<DataPt>,
36        sk: &GLWESecretPrepared<DataSk, B>,
37        source_xa: &mut Source,
38        source_xe: &mut Source,
39        scratch: &mut Scratch<B>,
40    ) where
41        Module<B>: VecZnxAddScalarInplace
42            + VecZnxDftAllocBytes
43            + VecZnxBigNormalize<B>
44            + VecZnxDftApply<B>
45            + SvpApplyDftToDftInplace<B>
46            + VecZnxIdftApplyConsume<B>
47            + VecZnxNormalizeTmpBytes
48            + VecZnxFillUniform
49            + VecZnxSubInplace
50            + VecZnxAddInplace
51            + VecZnxNormalizeInplace<B>
52            + VecZnxAddNormal
53            + VecZnxNormalize<B>
54            + VecZnxSub,
55        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
56    {
57        #[cfg(debug_assertions)]
58        {
59            use poulpy_hal::layouts::ZnxInfos;
60
61            assert_eq!(self.rank(), sk.rank());
62            assert_eq!(self.n(), sk.n());
63            assert_eq!(pt.n() as u32, sk.n());
64        }
65
66        let base2k: usize = self.base2k().into();
67        let rank: usize = self.rank().into();
68        let digits: usize = self.digits().into();
69
70        let (mut tmp_pt, scratch_1) = scratch.take_glwe_pt(&self.glwe_layout());
71
72        (0..self.rows().into()).for_each(|row_i| {
73            tmp_pt.data.zero();
74
75            // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt
76            module.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, (digits - 1) + row_i * digits, pt, 0);
77            module.vec_znx_normalize_inplace(base2k, &mut tmp_pt.data, 0, scratch_1);
78
79            (0..rank + 1).for_each(|col_j| {
80                // rlwe encrypt of vec_znx_pt into vec_znx_ct
81
82                self.at_mut(row_i, col_j).encrypt_sk_internal(
83                    module,
84                    Some((&tmp_pt, col_j)),
85                    sk,
86                    source_xa,
87                    source_xe,
88                    scratch_1,
89                );
90            });
91        });
92    }
93}