poulpy_core/encryption/glwe_ct.rs

1use poulpy_hal::{
2    api::{
3        ScratchAvailable, SvpApply, SvpApplyInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeSvpPPol, TakeVecZnx,
4        TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxBigAddNormal, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes,
5        VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigConsume, VecZnxFillUniform,
6        VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, ZnxInfos, ZnxZero,
7    },
8    layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, VecZnx, VecZnxBig},
9    source::Source,
10};
11
12use crate::{
13    SIX_SIGMA,
14    dist::Distribution,
15    layouts::{
16        GLWECiphertext, GLWEPlaintext, Infos,
17        prepared::{GLWEPublicKeyPrepared, GLWESecretPrepared},
18    },
19};
20
21impl GLWECiphertext<Vec<u8>> {
22    pub fn encrypt_sk_scratch_space<B: Backend>(module: &Module<B>, n: usize, basek: usize, k: usize) -> usize
23    where
24        Module<B>: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes,
25    {
26        let size: usize = k.div_ceil(basek);
27        module.vec_znx_normalize_tmp_bytes(n) + 2 * VecZnx::alloc_bytes(n, 1, size) + module.vec_znx_dft_alloc_bytes(n, 1, size)
28    }
29    pub fn encrypt_pk_scratch_space<B: Backend>(module: &Module<B>, n: usize, basek: usize, k: usize) -> usize
30    where
31        Module<B>: VecZnxDftAllocBytes + SvpPPolAllocBytes + VecZnxBigAllocBytes + VecZnxNormalizeTmpBytes,
32    {
33        let size: usize = k.div_ceil(basek);
34        ((module.vec_znx_dft_alloc_bytes(n, 1, size) + module.vec_znx_big_alloc_bytes(n, 1, size)) | ScalarZnx::alloc_bytes(n, 1))
35            + module.svp_ppol_alloc_bytes(n, 1)
36            + module.vec_znx_normalize_tmp_bytes(n)
37    }
38}
39
40impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
41    #[allow(clippy::too_many_arguments)]
42    pub fn encrypt_sk<DataPt: DataRef, DataSk: DataRef, B: Backend>(
43        &mut self,
44        module: &Module<B>,
45        pt: &GLWEPlaintext<DataPt>,
46        sk: &GLWESecretPrepared<DataSk, B>,
47        source_xa: &mut Source,
48        source_xe: &mut Source,
49        sigma: f64,
50        scratch: &mut Scratch<B>,
51    ) where
52        Module<B>: VecZnxDftAllocBytes
53            + VecZnxBigNormalize<B>
54            + VecZnxDftFromVecZnx<B>
55            + SvpApplyInplace<B>
56            + VecZnxDftToVecZnxBigConsume<B>
57            + VecZnxNormalizeTmpBytes
58            + VecZnxFillUniform
59            + VecZnxSubABInplace
60            + VecZnxAddInplace
61            + VecZnxNormalizeInplace<B>
62            + VecZnxAddNormal
63            + VecZnxNormalize<B>
64            + VecZnxSub,
65        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
66    {
67        #[cfg(debug_assertions)]
68        {
69            assert_eq!(self.rank(), sk.rank());
70            assert_eq!(sk.n(), self.n());
71            assert_eq!(pt.n(), self.n());
72            assert!(
73                scratch.available() >= GLWECiphertext::encrypt_sk_scratch_space(module, self.n(), self.basek(), self.k()),
74                "scratch.available(): {} < GLWECiphertext::encrypt_sk_scratch_space: {}",
75                scratch.available(),
76                GLWECiphertext::encrypt_sk_scratch_space(module, self.n(), self.basek(), self.k())
77            )
78        }
79
80        self.encrypt_sk_internal(
81            module,
82            Some((pt, 0)),
83            sk,
84            source_xa,
85            source_xe,
86            sigma,
87            scratch,
88        );
89    }
90
91    pub fn encrypt_zero_sk<DataSk: DataRef, B: Backend>(
92        &mut self,
93        module: &Module<B>,
94        sk: &GLWESecretPrepared<DataSk, B>,
95        source_xa: &mut Source,
96        source_xe: &mut Source,
97        sigma: f64,
98        scratch: &mut Scratch<B>,
99    ) where
100        Module<B>: VecZnxDftAllocBytes
101            + VecZnxBigNormalize<B>
102            + VecZnxDftFromVecZnx<B>
103            + SvpApplyInplace<B>
104            + VecZnxDftToVecZnxBigConsume<B>
105            + VecZnxNormalizeTmpBytes
106            + VecZnxFillUniform
107            + VecZnxSubABInplace
108            + VecZnxAddInplace
109            + VecZnxNormalizeInplace<B>
110            + VecZnxAddNormal
111            + VecZnxNormalize<B>
112            + VecZnxSub,
113        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
114    {
115        #[cfg(debug_assertions)]
116        {
117            assert_eq!(self.rank(), sk.rank());
118            assert_eq!(sk.n(), self.n());
119            assert!(
120                scratch.available() >= GLWECiphertext::encrypt_sk_scratch_space(module, self.n(), self.basek(), self.k()),
121                "scratch.available(): {} < GLWECiphertext::encrypt_sk_scratch_space: {}",
122                scratch.available(),
123                GLWECiphertext::encrypt_sk_scratch_space(module, self.n(), self.basek(), self.k())
124            )
125        }
126        self.encrypt_sk_internal(
127            module,
128            None::<(&GLWEPlaintext<Vec<u8>>, usize)>,
129            sk,
130            source_xa,
131            source_xe,
132            sigma,
133            scratch,
134        );
135    }
136
137    #[allow(clippy::too_many_arguments)]
138    pub(crate) fn encrypt_sk_internal<DataPt: DataRef, DataSk: DataRef, B: Backend>(
139        &mut self,
140        module: &Module<B>,
141        pt: Option<(&GLWEPlaintext<DataPt>, usize)>,
142        sk: &GLWESecretPrepared<DataSk, B>,
143        source_xa: &mut Source,
144        source_xe: &mut Source,
145        sigma: f64,
146        scratch: &mut Scratch<B>,
147    ) where
148        Module<B>: VecZnxDftAllocBytes
149            + VecZnxBigNormalize<B>
150            + VecZnxDftFromVecZnx<B>
151            + SvpApplyInplace<B>
152            + VecZnxDftToVecZnxBigConsume<B>
153            + VecZnxNormalizeTmpBytes
154            + VecZnxFillUniform
155            + VecZnxSubABInplace
156            + VecZnxAddInplace
157            + VecZnxNormalizeInplace<B>
158            + VecZnxAddNormal
159            + VecZnxNormalize<B>
160            + VecZnxSub,
161        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
162    {
163        let cols: usize = self.rank() + 1;
164        glwe_encrypt_sk_internal(
165            module,
166            self.basek(),
167            self.k(),
168            &mut self.data,
169            cols,
170            false,
171            pt,
172            sk,
173            source_xa,
174            source_xe,
175            sigma,
176            scratch,
177        );
178    }
179
180    #[allow(clippy::too_many_arguments)]
181    pub fn encrypt_pk<DataPt: DataRef, DataPk: DataRef, B: Backend>(
182        &mut self,
183        module: &Module<B>,
184        pt: &GLWEPlaintext<DataPt>,
185        pk: &GLWEPublicKeyPrepared<DataPk, B>,
186        source_xu: &mut Source,
187        source_xe: &mut Source,
188        sigma: f64,
189        scratch: &mut Scratch<B>,
190    ) where
191        Module<B>: SvpPrepare<B>
192            + SvpApply<B>
193            + VecZnxDftToVecZnxBigConsume<B>
194            + VecZnxBigAddNormal<B>
195            + VecZnxBigAddSmallInplace<B>
196            + VecZnxBigNormalize<B>,
197        Scratch<B>: TakeSvpPPol<B> + TakeScalarZnx + TakeVecZnxDft<B>,
198    {
199        self.encrypt_pk_internal::<DataPt, DataPk, B>(
200            module,
201            Some((pt, 0)),
202            pk,
203            source_xu,
204            source_xe,
205            sigma,
206            scratch,
207        );
208    }
209
210    pub fn encrypt_zero_pk<DataPk: DataRef, B: Backend>(
211        &mut self,
212        module: &Module<B>,
213        pk: &GLWEPublicKeyPrepared<DataPk, B>,
214        source_xu: &mut Source,
215        source_xe: &mut Source,
216        sigma: f64,
217        scratch: &mut Scratch<B>,
218    ) where
219        Module<B>: SvpPrepare<B>
220            + SvpApply<B>
221            + VecZnxDftToVecZnxBigConsume<B>
222            + VecZnxBigAddNormal<B>
223            + VecZnxBigAddSmallInplace<B>
224            + VecZnxBigNormalize<B>,
225        Scratch<B>: TakeSvpPPol<B> + TakeScalarZnx + TakeVecZnxDft<B>,
226    {
227        self.encrypt_pk_internal::<Vec<u8>, DataPk, B>(
228            module,
229            None::<(&GLWEPlaintext<Vec<u8>>, usize)>,
230            pk,
231            source_xu,
232            source_xe,
233            sigma,
234            scratch,
235        );
236    }
237
238    #[allow(clippy::too_many_arguments)]
239    pub(crate) fn encrypt_pk_internal<DataPt: DataRef, DataPk: DataRef, B: Backend>(
240        &mut self,
241        module: &Module<B>,
242        pt: Option<(&GLWEPlaintext<DataPt>, usize)>,
243        pk: &GLWEPublicKeyPrepared<DataPk, B>,
244        source_xu: &mut Source,
245        source_xe: &mut Source,
246        sigma: f64,
247        scratch: &mut Scratch<B>,
248    ) where
249        Module<B>: SvpPrepare<B>
250            + SvpApply<B>
251            + VecZnxDftToVecZnxBigConsume<B>
252            + VecZnxBigAddNormal<B>
253            + VecZnxBigAddSmallInplace<B>
254            + VecZnxBigNormalize<B>,
255        Scratch<B>: TakeSvpPPol<B> + TakeScalarZnx + TakeVecZnxDft<B>,
256    {
257        #[cfg(debug_assertions)]
258        {
259            assert_eq!(self.basek(), pk.basek());
260            assert_eq!(self.n(), pk.n());
261            assert_eq!(self.rank(), pk.rank());
262            if let Some((pt, _)) = pt {
263                assert_eq!(pt.basek(), pk.basek());
264                assert_eq!(pt.n(), pk.n());
265            }
266        }
267
268        let basek: usize = pk.basek();
269        let size_pk: usize = pk.size();
270        let cols: usize = self.rank() + 1;
271
272        // Generates u according to the underlying secret distribution.
273        let (mut u_dft, scratch_1) = scratch.take_svp_ppol(self.n(), 1);
274
275        {
276            let (mut u, _) = scratch_1.take_scalar_znx(self.n(), 1);
277            match pk.dist {
278                Distribution::NONE => panic!(
279                    "invalid public key: SecretDistribution::NONE, ensure it has been correctly intialized through \
280                     Self::generate"
281                ),
282                Distribution::TernaryFixed(hw) => u.fill_ternary_hw(0, hw, source_xu),
283                Distribution::TernaryProb(prob) => u.fill_ternary_prob(0, prob, source_xu),
284                Distribution::BinaryFixed(hw) => u.fill_binary_hw(0, hw, source_xu),
285                Distribution::BinaryProb(prob) => u.fill_binary_prob(0, prob, source_xu),
286                Distribution::BinaryBlock(block_size) => u.fill_binary_block(0, block_size, source_xu),
287                Distribution::ZERO => {}
288            }
289
290            module.svp_prepare(&mut u_dft, 0, &u, 0);
291        }
292
293        // ct[i] = pk[i] * u + ei (+ m if col = i)
294        (0..cols).for_each(|i| {
295            let (mut ci_dft, scratch_2) = scratch_1.take_vec_znx_dft(self.n(), 1, size_pk);
296            // ci_dft = DFT(u) * DFT(pk[i])
297            module.svp_apply(&mut ci_dft, 0, &u_dft, 0, &pk.data, i);
298
299            // ci_big = u * p[i]
300            let mut ci_big = module.vec_znx_dft_to_vec_znx_big_consume(ci_dft);
301
302            // ci_big = u * pk[i] + e
303            module.vec_znx_big_add_normal(
304                basek,
305                &mut ci_big,
306                0,
307                pk.k(),
308                source_xe,
309                sigma,
310                sigma * SIX_SIGMA,
311            );
312
313            // ci_big = u * pk[i] + e + m (if col = i)
314            if let Some((pt, col)) = pt
315                && col == i
316            {
317                module.vec_znx_big_add_small_inplace(&mut ci_big, 0, &pt.data, 0);
318            }
319
320            // ct[i] = norm(ci_big)
321            module.vec_znx_big_normalize(basek, &mut self.data, i, &ci_big, 0, scratch_2);
322        });
323    }
324}
325
326#[allow(clippy::too_many_arguments)]
327pub(crate) fn glwe_encrypt_sk_internal<DataCt: DataMut, DataPt: DataRef, DataSk: DataRef, B: Backend>(
328    module: &Module<B>,
329    basek: usize,
330    k: usize,
331    ct: &mut VecZnx<DataCt>,
332    cols: usize,
333    compressed: bool,
334    pt: Option<(&GLWEPlaintext<DataPt>, usize)>,
335    sk: &GLWESecretPrepared<DataSk, B>,
336    source_xa: &mut Source,
337    source_xe: &mut Source,
338    sigma: f64,
339    scratch: &mut Scratch<B>,
340) where
341    Module<B>: VecZnxDftAllocBytes
342        + VecZnxBigNormalize<B>
343        + VecZnxDftFromVecZnx<B>
344        + SvpApplyInplace<B>
345        + VecZnxDftToVecZnxBigConsume<B>
346        + VecZnxNormalizeTmpBytes
347        + VecZnxFillUniform
348        + VecZnxSubABInplace
349        + VecZnxAddInplace
350        + VecZnxNormalizeInplace<B>
351        + VecZnxAddNormal
352        + VecZnxNormalize<B>
353        + VecZnxSub,
354    Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
355{
356    #[cfg(debug_assertions)]
357    {
358        if compressed {
359            use poulpy_hal::api::ZnxInfos;
360
361            assert_eq!(
362                ct.cols(),
363                1,
364                "invalid ciphertext: compressed tag=true but #cols={} != 1",
365                ct.cols()
366            )
367        }
368    }
369
370    let size: usize = ct.size();
371
372    let (mut c0, scratch_1) = scratch.take_vec_znx(ct.n(), 1, size);
373    c0.zero();
374
375    {
376        let (mut ci, scratch_2) = scratch_1.take_vec_znx(ct.n(), 1, size);
377
378        // ct[i] = uniform
379        // ct[0] -= c[i] * s[i],
380        (1..cols).for_each(|i| {
381            let col_ct: usize = if compressed { 0 } else { i };
382
383            // ct[i] = uniform (+ pt)
384            module.vec_znx_fill_uniform(basek, ct, col_ct, k, source_xa);
385
386            let (mut ci_dft, scratch_3) = scratch_2.take_vec_znx_dft(ct.n(), 1, size);
387
388            // ci = ct[i] - pt
389            // i.e. we act as we sample ct[i] already as uniform + pt
390            // and if there is a pt, then we subtract it before applying DFT
391            if let Some((pt, col)) = pt {
392                if i == col {
393                    module.vec_znx_sub(&mut ci, 0, ct, col_ct, &pt.data, 0);
394                    module.vec_znx_normalize_inplace(basek, &mut ci, 0, scratch_3);
395                    module.vec_znx_dft_from_vec_znx(1, 0, &mut ci_dft, 0, &ci, 0);
396                } else {
397                    module.vec_znx_dft_from_vec_znx(1, 0, &mut ci_dft, 0, ct, col_ct);
398                }
399            } else {
400                module.vec_znx_dft_from_vec_znx(1, 0, &mut ci_dft, 0, ct, col_ct);
401            }
402
403            module.svp_apply_inplace(&mut ci_dft, 0, &sk.data, i - 1);
404            let ci_big: VecZnxBig<&mut [u8], B> = module.vec_znx_dft_to_vec_znx_big_consume(ci_dft);
405
406            // use c[0] as buffer, which is overwritten later by the normalization step
407            module.vec_znx_big_normalize(basek, &mut ci, 0, &ci_big, 0, scratch_3);
408
409            // c0_tmp = -c[i] * s[i] (use c[0] as buffer)
410            module.vec_znx_sub_ab_inplace(&mut c0, 0, &ci, 0);
411        });
412    }
413
414    // c[0] += e
415    module.vec_znx_add_normal(basek, &mut c0, 0, k, source_xe, sigma, sigma * SIX_SIGMA);
416
417    // c[0] += m if col = 0
418    if let Some((pt, col)) = pt
419        && col == 0
420    {
421        module.vec_znx_add_inplace(&mut c0, 0, &pt.data, 0);
422    }
423
424    // c[0] = norm(c[0])
425    module.vec_znx_normalize(basek, ct, 0, &c0, 0, scratch_1);
426}