// poulpy_core/encryption/glwe.rs

1use poulpy_hal::{
2    api::{
3        ModuleN, ScratchAvailable, ScratchTakeBasic, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolBytesOf, SvpPrepare,
4        VecZnxAddInplace, VecZnxAddNormal, VecZnxBigAddNormal, VecZnxBigAddSmallInplace, VecZnxBigBytesOf, VecZnxBigNormalize,
5        VecZnxDftApply, VecZnxDftBytesOf, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace,
6        VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace,
7    },
8    layouts::{Backend, DataMut, Module, ScalarZnx, Scratch, VecZnx, VecZnxBig, VecZnxToMut, ZnxInfos, ZnxZero},
9    source::Source,
10};
11
12use crate::{
13    GetDistribution, ScratchTakeCore,
14    dist::Distribution,
15    encryption::{SIGMA, SIGMA_BOUND},
16    layouts::{
17        GLWE, GLWEInfos, GLWEPlaintext, GLWEPlaintextToRef, GLWEPrepared, GLWEPreparedToRef, GLWEToMut, LWEInfos,
18        prepared::{GLWESecretPrepared, GLWESecretPreparedToRef},
19    },
20};
21
22impl GLWE<Vec<u8>> {
23    pub fn encrypt_sk_tmp_bytes<M, A, BE: Backend>(module: &M, infos: &A) -> usize
24    where
25        A: GLWEInfos,
26        M: GLWEEncryptSk<BE>,
27    {
28        module.glwe_encrypt_sk_tmp_bytes(infos)
29    }
30
31    pub fn encrypt_pk_tmp_bytes<M, A, BE: Backend>(module: &M, infos: &A) -> usize
32    where
33        A: GLWEInfos,
34        M: GLWEEncryptPk<BE>,
35    {
36        module.glwe_encrypt_pk_tmp_bytes(infos)
37    }
38}
39
40impl<D: DataMut> GLWE<D> {
41    pub fn encrypt_sk<P, S, M, BE: Backend>(
42        &mut self,
43        module: &M,
44        pt: &P,
45        sk: &S,
46        source_xa: &mut Source,
47        source_xe: &mut Source,
48        scratch: &mut Scratch<BE>,
49    ) where
50        P: GLWEPlaintextToRef,
51        S: GLWESecretPreparedToRef<BE>,
52        M: GLWEEncryptSk<BE>,
53        Scratch<BE>: ScratchTakeCore<BE>,
54    {
55        module.glwe_encrypt_sk(self, pt, sk, source_xa, source_xe, scratch);
56    }
57
58    pub fn encrypt_zero_sk<S, M, BE: Backend>(
59        &mut self,
60        module: &M,
61        sk: &S,
62        source_xa: &mut Source,
63        source_xe: &mut Source,
64        scratch: &mut Scratch<BE>,
65    ) where
66        S: GLWESecretPreparedToRef<BE>,
67        M: GLWEEncryptSk<BE>,
68        Scratch<BE>: ScratchTakeCore<BE>,
69    {
70        module.glwe_encrypt_zero_sk(self, sk, source_xa, source_xe, scratch);
71    }
72
73    pub fn encrypt_pk<P, K, M, BE: Backend>(
74        &mut self,
75        module: &M,
76        pt: &P,
77        pk: &K,
78        source_xu: &mut Source,
79        source_xe: &mut Source,
80        scratch: &mut Scratch<BE>,
81    ) where
82        P: GLWEPlaintextToRef + GLWEInfos,
83        K: GLWEPreparedToRef<BE> + GetDistribution + GLWEInfos,
84        M: GLWEEncryptPk<BE>,
85    {
86        module.glwe_encrypt_pk(self, pt, pk, source_xu, source_xe, scratch);
87    }
88
89    pub fn encrypt_zero_pk<K, M, BE: Backend>(
90        &mut self,
91        module: &M,
92        pk: &K,
93        source_xu: &mut Source,
94        source_xe: &mut Source,
95        scratch: &mut Scratch<BE>,
96    ) where
97        K: GLWEPreparedToRef<BE> + GetDistribution + GLWEInfos,
98        M: GLWEEncryptPk<BE>,
99    {
100        module.glwe_encrypt_zero_pk(self, pk, source_xu, source_xe, scratch);
101    }
102}
103
104pub trait GLWEEncryptSk<BE: Backend> {
105    fn glwe_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize
106    where
107        A: GLWEInfos;
108
109    fn glwe_encrypt_sk<R, P, S>(
110        &self,
111        res: &mut R,
112        pt: &P,
113        sk: &S,
114        source_xa: &mut Source,
115        source_xe: &mut Source,
116        scratch: &mut Scratch<BE>,
117    ) where
118        R: GLWEToMut,
119        P: GLWEPlaintextToRef,
120        S: GLWESecretPreparedToRef<BE>;
121
122    fn glwe_encrypt_zero_sk<R, S>(
123        &self,
124        res: &mut R,
125        sk: &S,
126        source_xa: &mut Source,
127        source_xe: &mut Source,
128        scratch: &mut Scratch<BE>,
129    ) where
130        R: GLWEToMut,
131        S: GLWESecretPreparedToRef<BE>;
132}
133
134impl<BE: Backend> GLWEEncryptSk<BE> for Module<BE>
135where
136    Self: Sized + ModuleN + VecZnxNormalizeTmpBytes + VecZnxDftBytesOf + GLWEEncryptSkInternal<BE>,
137    Scratch<BE>: ScratchAvailable,
138{
139    fn glwe_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize
140    where
141        A: GLWEInfos,
142    {
143        let size: usize = infos.size();
144        assert_eq!(self.n() as u32, infos.n());
145        self.vec_znx_normalize_tmp_bytes() + 2 * VecZnx::bytes_of(self.n(), 1, size) + self.bytes_of_vec_znx_dft(1, size)
146    }
147
148    fn glwe_encrypt_sk<R, P, S>(
149        &self,
150        res: &mut R,
151        pt: &P,
152        sk: &S,
153        source_xa: &mut Source,
154        source_xe: &mut Source,
155        scratch: &mut Scratch<BE>,
156    ) where
157        R: GLWEToMut,
158        P: GLWEPlaintextToRef,
159        S: GLWESecretPreparedToRef<BE>,
160    {
161        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
162        let pt: &GLWEPlaintext<&[u8]> = &pt.to_ref();
163        let sk: &GLWESecretPrepared<&[u8], BE> = &sk.to_ref();
164
165        assert_eq!(res.rank(), sk.rank());
166        assert_eq!(res.n(), self.n() as u32);
167        assert_eq!(sk.n(), self.n() as u32);
168        assert_eq!(pt.n(), self.n() as u32);
169        assert!(
170            scratch.available() >= self.glwe_encrypt_sk_tmp_bytes(res),
171            "scratch.available(): {} < GLWE::encrypt_sk_tmp_bytes: {}",
172            scratch.available(),
173            self.glwe_encrypt_sk_tmp_bytes(res)
174        );
175
176        let cols: usize = (res.rank() + 1).into();
177        self.glwe_encrypt_sk_internal(
178            res.base2k().into(),
179            res.k().into(),
180            res.data_mut(),
181            cols,
182            false,
183            Some((pt, 0)),
184            sk,
185            source_xa,
186            source_xe,
187            SIGMA,
188            scratch,
189        );
190    }
191
192    fn glwe_encrypt_zero_sk<R, S>(
193        &self,
194        res: &mut R,
195        sk: &S,
196        source_xa: &mut Source,
197        source_xe: &mut Source,
198        scratch: &mut Scratch<BE>,
199    ) where
200        R: GLWEToMut,
201        S: GLWESecretPreparedToRef<BE>,
202    {
203        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
204        let sk: &GLWESecretPrepared<&[u8], BE> = &sk.to_ref();
205
206        assert_eq!(res.rank(), sk.rank());
207        assert_eq!(res.n(), self.n() as u32);
208        assert_eq!(sk.n(), self.n() as u32);
209        assert!(
210            scratch.available() >= self.glwe_encrypt_sk_tmp_bytes(res),
211            "scratch.available(): {} < GLWE::encrypt_sk_tmp_bytes: {}",
212            scratch.available(),
213            self.glwe_encrypt_sk_tmp_bytes(res)
214        );
215
216        let cols: usize = (res.rank() + 1).into();
217        self.glwe_encrypt_sk_internal(
218            res.base2k().into(),
219            res.k().into(),
220            res.data_mut(),
221            cols,
222            false,
223            None::<(&GLWEPlaintext<Vec<u8>>, usize)>,
224            sk,
225            source_xa,
226            source_xe,
227            SIGMA,
228            scratch,
229        );
230    }
231}
232
233pub trait GLWEEncryptPk<BE: Backend> {
234    fn glwe_encrypt_pk_tmp_bytes<A>(&self, infos: &A) -> usize
235    where
236        A: GLWEInfos;
237
238    fn glwe_encrypt_pk<R, P, K>(
239        &self,
240        res: &mut R,
241        pt: &P,
242        pk: &K,
243        source_xu: &mut Source,
244        source_xe: &mut Source,
245        scratch: &mut Scratch<BE>,
246    ) where
247        R: GLWEToMut,
248        P: GLWEPlaintextToRef + GLWEInfos,
249        K: GLWEPreparedToRef<BE> + GetDistribution + GLWEInfos;
250
251    fn glwe_encrypt_zero_pk<R, K>(
252        &self,
253        res: &mut R,
254        pk: &K,
255        source_xu: &mut Source,
256        source_xe: &mut Source,
257        scratch: &mut Scratch<BE>,
258    ) where
259        R: GLWEToMut,
260        K: GLWEPreparedToRef<BE> + GetDistribution + GLWEInfos;
261}
262
263impl<BE: Backend> GLWEEncryptPk<BE> for Module<BE>
264where
265    Self: GLWEEncryptPkInternal<BE> + VecZnxDftBytesOf + SvpPPolBytesOf + VecZnxBigBytesOf + VecZnxNormalizeTmpBytes,
266{
267    fn glwe_encrypt_pk_tmp_bytes<A>(&self, infos: &A) -> usize
268    where
269        A: GLWEInfos,
270    {
271        let size: usize = infos.size();
272        assert_eq!(self.n() as u32, infos.n());
273        ((self.bytes_of_vec_znx_dft(1, size) + self.bytes_of_vec_znx_big(1, size)).max(ScalarZnx::bytes_of(self.n(), 1)))
274            + self.bytes_of_svp_ppol(1)
275            + self.vec_znx_normalize_tmp_bytes()
276    }
277
278    fn glwe_encrypt_pk<R, P, K>(
279        &self,
280        res: &mut R,
281        pt: &P,
282        pk: &K,
283        source_xu: &mut Source,
284        source_xe: &mut Source,
285        scratch: &mut Scratch<BE>,
286    ) where
287        R: GLWEToMut,
288        P: GLWEPlaintextToRef + GLWEInfos,
289        K: GLWEPreparedToRef<BE> + GetDistribution + GLWEInfos,
290    {
291        self.glwe_encrypt_pk_internal(res, Some((pt, 0)), pk, source_xu, source_xe, scratch);
292    }
293
294    fn glwe_encrypt_zero_pk<R, K>(
295        &self,
296        res: &mut R,
297        pk: &K,
298        source_xu: &mut Source,
299        source_xe: &mut Source,
300        scratch: &mut Scratch<BE>,
301    ) where
302        R: GLWEToMut,
303        K: GLWEPreparedToRef<BE> + GetDistribution + GLWEInfos,
304    {
305        self.glwe_encrypt_pk_internal(
306            res,
307            None::<(&GLWEPlaintext<Vec<u8>>, usize)>,
308            pk,
309            source_xu,
310            source_xe,
311            scratch,
312        );
313    }
314}
315
316pub(crate) trait GLWEEncryptPkInternal<BE: Backend> {
317    fn glwe_encrypt_pk_internal<R, P, K>(
318        &self,
319        res: &mut R,
320        pt: Option<(&P, usize)>,
321        pk: &K,
322        source_xu: &mut Source,
323        source_xe: &mut Source,
324        scratch: &mut Scratch<BE>,
325    ) where
326        R: GLWEToMut,
327        P: GLWEPlaintextToRef + GLWEInfos,
328        K: GLWEPreparedToRef<BE> + GetDistribution + GLWEInfos;
329}
330
331impl<BE: Backend> GLWEEncryptPkInternal<BE> for Module<BE>
332where
333    Self: SvpPrepare<BE>
334        + SvpApplyDftToDft<BE>
335        + VecZnxIdftApplyConsume<BE>
336        + VecZnxBigAddNormal<BE>
337        + VecZnxBigAddSmallInplace<BE>
338        + VecZnxBigNormalize<BE>
339        + SvpPPolBytesOf
340        + ModuleN
341        + VecZnxDftBytesOf,
342    Scratch<BE>: ScratchTakeBasic,
343{
344    fn glwe_encrypt_pk_internal<R, P, K>(
345        &self,
346        res: &mut R,
347        pt: Option<(&P, usize)>,
348        pk: &K,
349        source_xu: &mut Source,
350        source_xe: &mut Source,
351        scratch: &mut Scratch<BE>,
352    ) where
353        R: GLWEToMut,
354        P: GLWEPlaintextToRef + GLWEInfos,
355        K: GLWEPreparedToRef<BE> + GetDistribution + GLWEInfos,
356    {
357        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
358
359        assert_eq!(res.base2k(), pk.base2k());
360        assert_eq!(res.n(), pk.n());
361        assert_eq!(res.rank(), pk.rank());
362        if let Some((pt, _)) = pt {
363            assert_eq!(pt.base2k(), pk.base2k());
364            assert_eq!(pt.n(), pk.n());
365        }
366
367        let base2k: usize = pk.base2k().into();
368        let size_pk: usize = pk.size();
369        let cols: usize = (res.rank() + 1).into();
370
371        // Generates u according to the underlying secret distribution.
372        let (mut u_dft, scratch_1) = scratch.take_svp_ppol(self, 1);
373
374        {
375            let (mut u, _) = scratch_1.take_scalar_znx(self.n(), 1);
376            match pk.dist() {
377                Distribution::NONE => panic!(
378                    "invalid public key: SecretDistribution::NONE, ensure it has been correctly intialized through \
379                     Self::generate"
380                ),
381                Distribution::TernaryFixed(hw) => u.fill_ternary_hw(0, *hw, source_xu),
382                Distribution::TernaryProb(prob) => u.fill_ternary_prob(0, *prob, source_xu),
383                Distribution::BinaryFixed(hw) => u.fill_binary_hw(0, *hw, source_xu),
384                Distribution::BinaryProb(prob) => u.fill_binary_prob(0, *prob, source_xu),
385                Distribution::BinaryBlock(block_size) => u.fill_binary_block(0, *block_size, source_xu),
386                Distribution::ZERO => {}
387            }
388
389            self.svp_prepare(&mut u_dft, 0, &u, 0);
390        }
391
392        {
393            let pk: &GLWEPrepared<&[u8], BE> = &pk.to_ref();
394
395            // ct[i] = pk[i] * u + ei (+ m if col = i)
396            for i in 0..cols {
397                let (mut ci_dft, scratch_2) = scratch_1.take_vec_znx_dft(self, 1, size_pk);
398                // ci_dft = DFT(u) * DFT(pk[i])
399                self.svp_apply_dft_to_dft(&mut ci_dft, 0, &u_dft, 0, &pk.data, i);
400
401                // ci_big = u * p[i]
402                let mut ci_big = self.vec_znx_idft_apply_consume(ci_dft);
403
404                // ci_big = u * pk[i] + e
405                self.vec_znx_big_add_normal(
406                    base2k,
407                    &mut ci_big,
408                    0,
409                    pk.k().into(),
410                    source_xe,
411                    SIGMA,
412                    SIGMA_BOUND,
413                );
414
415                // ci_big = u * pk[i] + e + m (if col = i)
416                if let Some((pt, col)) = pt
417                    && col == i
418                {
419                    self.vec_znx_big_add_small_inplace(&mut ci_big, 0, &pt.to_ref().data, 0);
420                }
421
422                // ct[i] = norm(ci_big)
423                self.vec_znx_big_normalize(base2k, &mut res.data, i, base2k, &ci_big, 0, scratch_2);
424            }
425        }
426    }
427}
428
429pub(crate) trait GLWEEncryptSkInternal<BE: Backend> {
430    #[allow(clippy::too_many_arguments)]
431    fn glwe_encrypt_sk_internal<R, P, S>(
432        &self,
433        base2k: usize,
434        k: usize,
435        res: &mut R,
436        cols: usize,
437        compressed: bool,
438        pt: Option<(&P, usize)>,
439        sk: &S,
440        source_xa: &mut Source,
441        source_xe: &mut Source,
442        sigma: f64,
443        scratch: &mut Scratch<BE>,
444    ) where
445        R: VecZnxToMut,
446        P: GLWEPlaintextToRef,
447        S: GLWESecretPreparedToRef<BE>;
448}
449
450impl<BE: Backend> GLWEEncryptSkInternal<BE> for Module<BE>
451where
452    Self: ModuleN
453        + VecZnxDftBytesOf
454        + VecZnxBigNormalize<BE>
455        + VecZnxDftApply<BE>
456        + SvpApplyDftToDftInplace<BE>
457        + VecZnxIdftApplyConsume<BE>
458        + VecZnxNormalizeTmpBytes
459        + VecZnxFillUniform
460        + VecZnxSubInplace
461        + VecZnxAddInplace
462        + VecZnxNormalizeInplace<BE>
463        + VecZnxAddNormal
464        + VecZnxNormalize<BE>
465        + VecZnxSub,
466    Scratch<BE>: ScratchTakeBasic,
467{
468    fn glwe_encrypt_sk_internal<R, P, S>(
469        &self,
470        base2k: usize,
471        k: usize,
472        res: &mut R,
473        cols: usize,
474        compressed: bool,
475        pt: Option<(&P, usize)>,
476        sk: &S,
477        source_xa: &mut Source,
478        source_xe: &mut Source,
479        sigma: f64,
480        scratch: &mut Scratch<BE>,
481    ) where
482        R: VecZnxToMut,
483        P: GLWEPlaintextToRef,
484        S: GLWESecretPreparedToRef<BE>,
485    {
486        let ct: &mut VecZnx<&mut [u8]> = &mut res.to_mut();
487        let sk: GLWESecretPrepared<&[u8], BE> = sk.to_ref();
488
489        if compressed {
490            assert_eq!(
491                ct.cols(),
492                1,
493                "invalid glwe: compressed tag=true but #cols={} != 1",
494                ct.cols()
495            )
496        }
497
498        assert!(
499            sk.dist != Distribution::NONE,
500            "glwe secret distribution is NONE (have you prepared the key?)"
501        );
502
503        let size: usize = ct.size();
504
505        let (mut c0, scratch_1) = scratch.take_vec_znx(self.n(), 1, size);
506        c0.zero();
507
508        {
509            let (mut ci, scratch_2) = scratch_1.take_vec_znx(self.n(), 1, size);
510
511            // ct[i] = uniform
512            // ct[0] -= c[i] * s[i],
513            (1..cols).for_each(|i| {
514                let col_ct: usize = if compressed { 0 } else { i };
515
516                // ct[i] = uniform (+ pt)
517                self.vec_znx_fill_uniform(base2k, ct, col_ct, source_xa);
518
519                // println!("vec_znx_fill_uniform: {}", ct);
520
521                let (mut ci_dft, scratch_3) = scratch_2.take_vec_znx_dft(self, 1, size);
522
523                // ci = ct[i] - pt
524                // i.e. we act as we sample ct[i] already as uniform + pt
525                // and if there is a pt, then we subtract it before applying DFT
526                if let Some((pt, col)) = pt {
527                    if i == col {
528                        self.vec_znx_sub(&mut ci, 0, ct, col_ct, &pt.to_ref().data, 0);
529                        self.vec_znx_normalize_inplace(base2k, &mut ci, 0, scratch_3);
530                        self.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, &ci, 0);
531                    } else {
532                        self.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, ct, col_ct);
533                    }
534                } else {
535                    self.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, ct, col_ct);
536                }
537
538                self.svp_apply_dft_to_dft_inplace(&mut ci_dft, 0, &sk.data, i - 1);
539                let ci_big: VecZnxBig<&mut [u8], BE> = self.vec_znx_idft_apply_consume(ci_dft);
540
541                // use c[0] as buffer, which is overwritten later by the normalization step
542                self.vec_znx_big_normalize(base2k, &mut ci, 0, base2k, &ci_big, 0, scratch_3);
543
544                // c0_tmp = -c[i] * s[i] (use c[0] as buffer)
545                self.vec_znx_sub_inplace(&mut c0, 0, &ci, 0);
546            });
547        }
548
549        // c[0] += e
550        self.vec_znx_add_normal(base2k, &mut c0, 0, k, source_xe, sigma, SIGMA_BOUND);
551
552        // c[0] += m if col = 0
553        if let Some((pt, col)) = pt
554            && col == 0
555        {
556            self.vec_znx_add_inplace(&mut c0, 0, &pt.to_ref().data, 0);
557        }
558
559        // c[0] = norm(c[0])
560        self.vec_znx_normalize(base2k, ct, 0, base2k, &c0, 0, scratch_1);
561    }
562}