poulpy_core/encryption/compressed/
gglwe_ct.rs

1use poulpy_hal::{
2    api::{
3        ScratchAvailable, SvpApplyDftToDftInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal,
4        VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform,
5        VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace,
6    },
7    layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, ZnxZero},
8    source::Source,
9};
10
11use crate::{
12    TakeGLWEPt,
13    encryption::{SIGMA, glwe_encrypt_sk_internal},
14    layouts::{GGLWECiphertext, GGLWELayoutInfos, LWEInfos, compressed::GGLWECiphertextCompressed, prepared::GLWESecretPrepared},
15};
16
17impl GGLWECiphertextCompressed<Vec<u8>> {
18    pub fn encrypt_sk_scratch_space<B: Backend, A>(module: &Module<B>, infos: &A) -> usize
19    where
20        A: GGLWELayoutInfos,
21        Module<B>: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes,
22    {
23        GGLWECiphertext::encrypt_sk_scratch_space(module, infos)
24    }
25}
26
27impl<D: DataMut> GGLWECiphertextCompressed<D> {
28    #[allow(clippy::too_many_arguments)]
29    pub fn encrypt_sk<DataPt: DataRef, DataSk: DataRef, B: Backend>(
30        &mut self,
31        module: &Module<B>,
32        pt: &ScalarZnx<DataPt>,
33        sk: &GLWESecretPrepared<DataSk, B>,
34        seed: [u8; 32],
35        source_xe: &mut Source,
36        scratch: &mut Scratch<B>,
37    ) where
38        Module<B>: VecZnxAddScalarInplace
39            + VecZnxDftAllocBytes
40            + VecZnxBigNormalize<B>
41            + VecZnxDftApply<B>
42            + SvpApplyDftToDftInplace<B>
43            + VecZnxIdftApplyConsume<B>
44            + VecZnxNormalizeTmpBytes
45            + VecZnxFillUniform
46            + VecZnxSubInplace
47            + VecZnxAddInplace
48            + VecZnxNormalizeInplace<B>
49            + VecZnxAddNormal
50            + VecZnxNormalize<B>
51            + VecZnxSub,
52        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
53    {
54        #[cfg(debug_assertions)]
55        {
56            use poulpy_hal::layouts::ZnxInfos;
57
58            assert_eq!(
59                self.rank_in(),
60                pt.cols() as u32,
61                "self.rank_in(): {} != pt.cols(): {}",
62                self.rank_in(),
63                pt.cols()
64            );
65            assert_eq!(
66                self.rank_out(),
67                sk.rank(),
68                "self.rank_out(): {} != sk.rank(): {}",
69                self.rank_out(),
70                sk.rank()
71            );
72            assert_eq!(self.n(), sk.n());
73            assert_eq!(pt.n() as u32, sk.n());
74            assert!(
75                scratch.available() >= GGLWECiphertextCompressed::encrypt_sk_scratch_space(module, self),
76                "scratch.available: {} < GGLWECiphertext::encrypt_sk_scratch_space: {}",
77                scratch.available(),
78                GGLWECiphertextCompressed::encrypt_sk_scratch_space(module, self)
79            );
80            assert!(
81                self.rows().0 * self.digits().0 * self.base2k().0 <= self.k().0,
82                "self.rows() : {} * self.digits() : {} * self.base2k() : {} = {} >= self.k() = {}",
83                self.rows(),
84                self.digits(),
85                self.base2k(),
86                self.rows().0 * self.digits().0 * self.base2k().0,
87                self.k()
88            );
89        }
90
91        let rows: usize = self.rows().into();
92        let digits: usize = self.digits().into();
93        let base2k: usize = self.base2k().into();
94        let rank_in: usize = self.rank_in().into();
95        let cols: usize = (self.rank_out() + 1).into();
96
97        let mut source_xa = Source::new(seed);
98
99        let (mut tmp_pt, scrach_1) = scratch.take_glwe_pt(self);
100        (0..rank_in).for_each(|col_i| {
101            (0..rows).for_each(|row_i| {
102                // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt
103                tmp_pt.data.zero(); // zeroes for next iteration
104                module.vec_znx_add_scalar_inplace(
105                    &mut tmp_pt.data,
106                    0,
107                    (digits - 1) + row_i * digits,
108                    pt,
109                    col_i,
110                );
111                module.vec_znx_normalize_inplace(base2k, &mut tmp_pt.data, 0, scrach_1);
112
113                let (seed, mut source_xa_tmp) = source_xa.branch();
114                self.seed[col_i * rows + row_i] = seed;
115
116                glwe_encrypt_sk_internal(
117                    module,
118                    self.base2k().into(),
119                    self.k().into(),
120                    &mut self.at_mut(row_i, col_i).data,
121                    cols,
122                    true,
123                    Some((&tmp_pt, 0)),
124                    sk,
125                    &mut source_xa_tmp,
126                    source_xe,
127                    SIGMA,
128                    scrach_1,
129                );
130            });
131        });
132    }
133}