// poulpy_core/encryption/glwe_ct.rs

1use poulpy_hal::{
2    api::{
3        ScratchAvailable, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeSvpPPol,
4        TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxBigAddNormal, VecZnxBigAddSmallInplace,
5        VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume,
6        VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace,
7    },
8    layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, VecZnx, VecZnxBig, ZnxInfos, ZnxZero},
9    source::Source,
10};
11
12use crate::{
13    dist::Distribution,
14    encryption::{SIGMA, SIGMA_BOUND},
15    layouts::{
16        GLWECiphertext, GLWEPlaintext, Infos,
17        prepared::{GLWEPublicKeyPrepared, GLWESecretPrepared},
18    },
19};
20
21impl GLWECiphertext<Vec<u8>> {
22    pub fn encrypt_sk_scratch_space<B: Backend>(module: &Module<B>, basek: usize, k: usize) -> usize
23    where
24        Module<B>: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes,
25    {
26        let size: usize = k.div_ceil(basek);
27        module.vec_znx_normalize_tmp_bytes()
28            + 2 * VecZnx::alloc_bytes(module.n(), 1, size)
29            + module.vec_znx_dft_alloc_bytes(1, size)
30    }
31    pub fn encrypt_pk_scratch_space<B: Backend>(module: &Module<B>, basek: usize, k: usize) -> usize
32    where
33        Module<B>: VecZnxDftAllocBytes + SvpPPolAllocBytes + VecZnxBigAllocBytes + VecZnxNormalizeTmpBytes,
34    {
35        let size: usize = k.div_ceil(basek);
36        ((module.vec_znx_dft_alloc_bytes(1, size) + module.vec_znx_big_alloc_bytes(1, size))
37            | ScalarZnx::alloc_bytes(module.n(), 1))
38            + module.svp_ppol_alloc_bytes(1)
39            + module.vec_znx_normalize_tmp_bytes()
40    }
41}
42
43impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
44    #[allow(clippy::too_many_arguments)]
45    pub fn encrypt_sk<DataPt: DataRef, DataSk: DataRef, B: Backend>(
46        &mut self,
47        module: &Module<B>,
48        pt: &GLWEPlaintext<DataPt>,
49        sk: &GLWESecretPrepared<DataSk, B>,
50        source_xa: &mut Source,
51        source_xe: &mut Source,
52        scratch: &mut Scratch<B>,
53    ) where
54        Module<B>: VecZnxDftAllocBytes
55            + VecZnxBigNormalize<B>
56            + VecZnxDftApply<B>
57            + SvpApplyDftToDftInplace<B>
58            + VecZnxIdftApplyConsume<B>
59            + VecZnxNormalizeTmpBytes
60            + VecZnxFillUniform
61            + VecZnxSubABInplace
62            + VecZnxAddInplace
63            + VecZnxNormalizeInplace<B>
64            + VecZnxAddNormal
65            + VecZnxNormalize<B>
66            + VecZnxSub,
67        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
68    {
69        #[cfg(debug_assertions)]
70        {
71            assert_eq!(self.rank(), sk.rank());
72            assert_eq!(sk.n(), self.n());
73            assert_eq!(pt.n(), self.n());
74            assert!(
75                scratch.available() >= GLWECiphertext::encrypt_sk_scratch_space(module, self.basek(), self.k()),
76                "scratch.available(): {} < GLWECiphertext::encrypt_sk_scratch_space: {}",
77                scratch.available(),
78                GLWECiphertext::encrypt_sk_scratch_space(module, self.basek(), self.k())
79            )
80        }
81
82        self.encrypt_sk_internal(module, Some((pt, 0)), sk, source_xa, source_xe, scratch);
83    }
84
85    pub fn encrypt_zero_sk<DataSk: DataRef, B: Backend>(
86        &mut self,
87        module: &Module<B>,
88        sk: &GLWESecretPrepared<DataSk, B>,
89        source_xa: &mut Source,
90        source_xe: &mut Source,
91        scratch: &mut Scratch<B>,
92    ) where
93        Module<B>: VecZnxDftAllocBytes
94            + VecZnxBigNormalize<B>
95            + VecZnxDftApply<B>
96            + SvpApplyDftToDftInplace<B>
97            + VecZnxIdftApplyConsume<B>
98            + VecZnxNormalizeTmpBytes
99            + VecZnxFillUniform
100            + VecZnxSubABInplace
101            + VecZnxAddInplace
102            + VecZnxNormalizeInplace<B>
103            + VecZnxAddNormal
104            + VecZnxNormalize<B>
105            + VecZnxSub,
106        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
107    {
108        #[cfg(debug_assertions)]
109        {
110            assert_eq!(self.rank(), sk.rank());
111            assert_eq!(sk.n(), self.n());
112            assert!(
113                scratch.available() >= GLWECiphertext::encrypt_sk_scratch_space(module, self.basek(), self.k()),
114                "scratch.available(): {} < GLWECiphertext::encrypt_sk_scratch_space: {}",
115                scratch.available(),
116                GLWECiphertext::encrypt_sk_scratch_space(module, self.basek(), self.k())
117            )
118        }
119        self.encrypt_sk_internal(
120            module,
121            None::<(&GLWEPlaintext<Vec<u8>>, usize)>,
122            sk,
123            source_xa,
124            source_xe,
125            scratch,
126        );
127    }
128
129    #[allow(clippy::too_many_arguments)]
130    pub(crate) fn encrypt_sk_internal<DataPt: DataRef, DataSk: DataRef, B: Backend>(
131        &mut self,
132        module: &Module<B>,
133        pt: Option<(&GLWEPlaintext<DataPt>, usize)>,
134        sk: &GLWESecretPrepared<DataSk, B>,
135        source_xa: &mut Source,
136        source_xe: &mut Source,
137        scratch: &mut Scratch<B>,
138    ) where
139        Module<B>: VecZnxDftAllocBytes
140            + VecZnxBigNormalize<B>
141            + VecZnxDftApply<B>
142            + SvpApplyDftToDftInplace<B>
143            + VecZnxIdftApplyConsume<B>
144            + VecZnxNormalizeTmpBytes
145            + VecZnxFillUniform
146            + VecZnxSubABInplace
147            + VecZnxAddInplace
148            + VecZnxNormalizeInplace<B>
149            + VecZnxAddNormal
150            + VecZnxNormalize<B>
151            + VecZnxSub,
152        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
153    {
154        let cols: usize = self.rank() + 1;
155        glwe_encrypt_sk_internal(
156            module,
157            self.basek(),
158            self.k(),
159            &mut self.data,
160            cols,
161            false,
162            pt,
163            sk,
164            source_xa,
165            source_xe,
166            SIGMA,
167            scratch,
168        );
169    }
170
171    #[allow(clippy::too_many_arguments)]
172    pub fn encrypt_pk<DataPt: DataRef, DataPk: DataRef, B: Backend>(
173        &mut self,
174        module: &Module<B>,
175        pt: &GLWEPlaintext<DataPt>,
176        pk: &GLWEPublicKeyPrepared<DataPk, B>,
177        source_xu: &mut Source,
178        source_xe: &mut Source,
179        scratch: &mut Scratch<B>,
180    ) where
181        Module<B>: SvpPrepare<B>
182            + SvpApplyDftToDft<B>
183            + VecZnxIdftApplyConsume<B>
184            + VecZnxBigAddNormal<B>
185            + VecZnxBigAddSmallInplace<B>
186            + VecZnxBigNormalize<B>,
187        Scratch<B>: TakeSvpPPol<B> + TakeScalarZnx + TakeVecZnxDft<B>,
188    {
189        self.encrypt_pk_internal::<DataPt, DataPk, B>(module, Some((pt, 0)), pk, source_xu, source_xe, scratch);
190    }
191
192    pub fn encrypt_zero_pk<DataPk: DataRef, B: Backend>(
193        &mut self,
194        module: &Module<B>,
195        pk: &GLWEPublicKeyPrepared<DataPk, B>,
196        source_xu: &mut Source,
197        source_xe: &mut Source,
198        scratch: &mut Scratch<B>,
199    ) where
200        Module<B>: SvpPrepare<B>
201            + SvpApplyDftToDft<B>
202            + VecZnxIdftApplyConsume<B>
203            + VecZnxBigAddNormal<B>
204            + VecZnxBigAddSmallInplace<B>
205            + VecZnxBigNormalize<B>,
206        Scratch<B>: TakeSvpPPol<B> + TakeScalarZnx + TakeVecZnxDft<B>,
207    {
208        self.encrypt_pk_internal::<Vec<u8>, DataPk, B>(
209            module,
210            None::<(&GLWEPlaintext<Vec<u8>>, usize)>,
211            pk,
212            source_xu,
213            source_xe,
214            scratch,
215        );
216    }
217
218    #[allow(clippy::too_many_arguments)]
219    pub(crate) fn encrypt_pk_internal<DataPt: DataRef, DataPk: DataRef, B: Backend>(
220        &mut self,
221        module: &Module<B>,
222        pt: Option<(&GLWEPlaintext<DataPt>, usize)>,
223        pk: &GLWEPublicKeyPrepared<DataPk, B>,
224        source_xu: &mut Source,
225        source_xe: &mut Source,
226        scratch: &mut Scratch<B>,
227    ) where
228        Module<B>: SvpPrepare<B>
229            + SvpApplyDftToDft<B>
230            + VecZnxIdftApplyConsume<B>
231            + VecZnxBigAddNormal<B>
232            + VecZnxBigAddSmallInplace<B>
233            + VecZnxBigNormalize<B>,
234        Scratch<B>: TakeSvpPPol<B> + TakeScalarZnx + TakeVecZnxDft<B>,
235    {
236        #[cfg(debug_assertions)]
237        {
238            assert_eq!(self.basek(), pk.basek());
239            assert_eq!(self.n(), pk.n());
240            assert_eq!(self.rank(), pk.rank());
241            if let Some((pt, _)) = pt {
242                assert_eq!(pt.basek(), pk.basek());
243                assert_eq!(pt.n(), pk.n());
244            }
245        }
246
247        let basek: usize = pk.basek();
248        let size_pk: usize = pk.size();
249        let cols: usize = self.rank() + 1;
250
251        // Generates u according to the underlying secret distribution.
252        let (mut u_dft, scratch_1) = scratch.take_svp_ppol(self.n(), 1);
253
254        {
255            let (mut u, _) = scratch_1.take_scalar_znx(self.n(), 1);
256            match pk.dist {
257                Distribution::NONE => panic!(
258                    "invalid public key: SecretDistribution::NONE, ensure it has been correctly intialized through \
259                     Self::generate"
260                ),
261                Distribution::TernaryFixed(hw) => u.fill_ternary_hw(0, hw, source_xu),
262                Distribution::TernaryProb(prob) => u.fill_ternary_prob(0, prob, source_xu),
263                Distribution::BinaryFixed(hw) => u.fill_binary_hw(0, hw, source_xu),
264                Distribution::BinaryProb(prob) => u.fill_binary_prob(0, prob, source_xu),
265                Distribution::BinaryBlock(block_size) => u.fill_binary_block(0, block_size, source_xu),
266                Distribution::ZERO => {}
267            }
268
269            module.svp_prepare(&mut u_dft, 0, &u, 0);
270        }
271
272        // ct[i] = pk[i] * u + ei (+ m if col = i)
273        (0..cols).for_each(|i| {
274            let (mut ci_dft, scratch_2) = scratch_1.take_vec_znx_dft(self.n(), 1, size_pk);
275            // ci_dft = DFT(u) * DFT(pk[i])
276            module.svp_apply_dft_to_dft(&mut ci_dft, 0, &u_dft, 0, &pk.data, i);
277
278            // ci_big = u * p[i]
279            let mut ci_big = module.vec_znx_idft_apply_consume(ci_dft);
280
281            // ci_big = u * pk[i] + e
282            module.vec_znx_big_add_normal(basek, &mut ci_big, 0, pk.k(), source_xe, SIGMA, SIGMA_BOUND);
283
284            // ci_big = u * pk[i] + e + m (if col = i)
285            if let Some((pt, col)) = pt
286                && col == i
287            {
288                module.vec_znx_big_add_small_inplace(&mut ci_big, 0, &pt.data, 0);
289            }
290
291            // ct[i] = norm(ci_big)
292            module.vec_znx_big_normalize(basek, &mut self.data, i, &ci_big, 0, scratch_2);
293        });
294    }
295}
296
297#[allow(clippy::too_many_arguments)]
298pub(crate) fn glwe_encrypt_sk_internal<DataCt: DataMut, DataPt: DataRef, DataSk: DataRef, B: Backend>(
299    module: &Module<B>,
300    basek: usize,
301    k: usize,
302    ct: &mut VecZnx<DataCt>,
303    cols: usize,
304    compressed: bool,
305    pt: Option<(&GLWEPlaintext<DataPt>, usize)>,
306    sk: &GLWESecretPrepared<DataSk, B>,
307    source_xa: &mut Source,
308    source_xe: &mut Source,
309    sigma: f64,
310    scratch: &mut Scratch<B>,
311) where
312    Module<B>: VecZnxDftAllocBytes
313        + VecZnxBigNormalize<B>
314        + VecZnxDftApply<B>
315        + SvpApplyDftToDftInplace<B>
316        + VecZnxIdftApplyConsume<B>
317        + VecZnxNormalizeTmpBytes
318        + VecZnxFillUniform
319        + VecZnxSubABInplace
320        + VecZnxAddInplace
321        + VecZnxNormalizeInplace<B>
322        + VecZnxAddNormal
323        + VecZnxNormalize<B>
324        + VecZnxSub,
325    Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
326{
327    #[cfg(debug_assertions)]
328    {
329        if compressed {
330            assert_eq!(
331                ct.cols(),
332                1,
333                "invalid ciphertext: compressed tag=true but #cols={} != 1",
334                ct.cols()
335            )
336        }
337    }
338
339    let size: usize = ct.size();
340
341    let (mut c0, scratch_1) = scratch.take_vec_znx(ct.n(), 1, size);
342    c0.zero();
343
344    {
345        let (mut ci, scratch_2) = scratch_1.take_vec_znx(ct.n(), 1, size);
346
347        // ct[i] = uniform
348        // ct[0] -= c[i] * s[i],
349        (1..cols).for_each(|i| {
350            let col_ct: usize = if compressed { 0 } else { i };
351
352            // ct[i] = uniform (+ pt)
353            module.vec_znx_fill_uniform(basek, ct, col_ct, source_xa);
354
355            let (mut ci_dft, scratch_3) = scratch_2.take_vec_znx_dft(ct.n(), 1, size);
356
357            // ci = ct[i] - pt
358            // i.e. we act as we sample ct[i] already as uniform + pt
359            // and if there is a pt, then we subtract it before applying DFT
360            if let Some((pt, col)) = pt {
361                if i == col {
362                    module.vec_znx_sub(&mut ci, 0, ct, col_ct, &pt.data, 0);
363                    module.vec_znx_normalize_inplace(basek, &mut ci, 0, scratch_3);
364                    module.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, &ci, 0);
365                } else {
366                    module.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, ct, col_ct);
367                }
368            } else {
369                module.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, ct, col_ct);
370            }
371
372            module.svp_apply_dft_to_dft_inplace(&mut ci_dft, 0, &sk.data, i - 1);
373            let ci_big: VecZnxBig<&mut [u8], B> = module.vec_znx_idft_apply_consume(ci_dft);
374
375            // use c[0] as buffer, which is overwritten later by the normalization step
376            module.vec_znx_big_normalize(basek, &mut ci, 0, &ci_big, 0, scratch_3);
377
378            // c0_tmp = -c[i] * s[i] (use c[0] as buffer)
379            module.vec_znx_sub_ab_inplace(&mut c0, 0, &ci, 0);
380        });
381    }
382
383    // c[0] += e
384    module.vec_znx_add_normal(basek, &mut c0, 0, k, source_xe, sigma, SIGMA_BOUND);
385
386    // c[0] += m if col = 0
387    if let Some((pt, col)) = pt
388        && col == 0
389    {
390        module.vec_znx_add_inplace(&mut c0, 0, &pt.data, 0);
391    }
392
393    // c[0] = norm(c[0])
394    module.vec_znx_normalize(basek, ct, 0, &c0, 0, scratch_1);
395}