poulpy_core/encryption/
glwe_ct.rs

1use poulpy_hal::{
2    api::{
3        ScratchAvailable, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeSvpPPol,
4        TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxBigAddNormal, VecZnxBigAddSmallInplace,
5        VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume,
6        VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace,
7    },
8    layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, VecZnx, VecZnxBig, ZnxInfos, ZnxZero},
9    source::Source,
10};
11
12use crate::{
13    dist::Distribution,
14    encryption::{SIGMA, SIGMA_BOUND},
15    layouts::{
16        GLWECiphertext, GLWEInfos, GLWEPlaintext, LWEInfos,
17        prepared::{GLWEPublicKeyPrepared, GLWESecretPrepared},
18    },
19};
20
21impl GLWECiphertext<Vec<u8>> {
22    pub fn encrypt_sk_scratch_space<B: Backend, A>(module: &Module<B>, infos: &A) -> usize
23    where
24        A: GLWEInfos,
25        Module<B>: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes,
26    {
27        let size: usize = infos.size();
28        assert_eq!(module.n() as u32, infos.n());
29        module.vec_znx_normalize_tmp_bytes()
30            + 2 * VecZnx::alloc_bytes(module.n(), 1, size)
31            + module.vec_znx_dft_alloc_bytes(1, size)
32    }
33    pub fn encrypt_pk_scratch_space<B: Backend, A>(module: &Module<B>, infos: &A) -> usize
34    where
35        A: GLWEInfos,
36        Module<B>: VecZnxDftAllocBytes + SvpPPolAllocBytes + VecZnxBigAllocBytes + VecZnxNormalizeTmpBytes,
37    {
38        let size: usize = infos.size();
39        assert_eq!(module.n() as u32, infos.n());
40        ((module.vec_znx_dft_alloc_bytes(1, size) + module.vec_znx_big_alloc_bytes(1, size))
41            | ScalarZnx::alloc_bytes(module.n(), 1))
42            + module.svp_ppol_alloc_bytes(1)
43            + module.vec_znx_normalize_tmp_bytes()
44    }
45}
46
47impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
48    #[allow(clippy::too_many_arguments)]
49    pub fn encrypt_sk<DataPt: DataRef, DataSk: DataRef, B: Backend>(
50        &mut self,
51        module: &Module<B>,
52        pt: &GLWEPlaintext<DataPt>,
53        sk: &GLWESecretPrepared<DataSk, B>,
54        source_xa: &mut Source,
55        source_xe: &mut Source,
56        scratch: &mut Scratch<B>,
57    ) where
58        Module<B>: VecZnxDftAllocBytes
59            + VecZnxBigNormalize<B>
60            + VecZnxDftApply<B>
61            + SvpApplyDftToDftInplace<B>
62            + VecZnxIdftApplyConsume<B>
63            + VecZnxNormalizeTmpBytes
64            + VecZnxFillUniform
65            + VecZnxSubInplace
66            + VecZnxAddInplace
67            + VecZnxNormalizeInplace<B>
68            + VecZnxAddNormal
69            + VecZnxNormalize<B>
70            + VecZnxSub,
71        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
72    {
73        #[cfg(debug_assertions)]
74        {
75            assert_eq!(self.rank(), sk.rank());
76            assert_eq!(sk.n(), self.n());
77            assert_eq!(pt.n(), self.n());
78            assert!(
79                scratch.available() >= GLWECiphertext::encrypt_sk_scratch_space(module, self),
80                "scratch.available(): {} < GLWECiphertext::encrypt_sk_scratch_space: {}",
81                scratch.available(),
82                GLWECiphertext::encrypt_sk_scratch_space(module, self)
83            )
84        }
85
86        self.encrypt_sk_internal(module, Some((pt, 0)), sk, source_xa, source_xe, scratch);
87    }
88
89    pub fn encrypt_zero_sk<DataSk: DataRef, B: Backend>(
90        &mut self,
91        module: &Module<B>,
92        sk: &GLWESecretPrepared<DataSk, B>,
93        source_xa: &mut Source,
94        source_xe: &mut Source,
95        scratch: &mut Scratch<B>,
96    ) where
97        Module<B>: VecZnxDftAllocBytes
98            + VecZnxBigNormalize<B>
99            + VecZnxDftApply<B>
100            + SvpApplyDftToDftInplace<B>
101            + VecZnxIdftApplyConsume<B>
102            + VecZnxNormalizeTmpBytes
103            + VecZnxFillUniform
104            + VecZnxSubInplace
105            + VecZnxAddInplace
106            + VecZnxNormalizeInplace<B>
107            + VecZnxAddNormal
108            + VecZnxNormalize<B>
109            + VecZnxSub,
110        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
111    {
112        #[cfg(debug_assertions)]
113        {
114            assert_eq!(self.rank(), sk.rank());
115            assert_eq!(sk.n(), self.n());
116            assert!(
117                scratch.available() >= GLWECiphertext::encrypt_sk_scratch_space(module, self),
118                "scratch.available(): {} < GLWECiphertext::encrypt_sk_scratch_space: {}",
119                scratch.available(),
120                GLWECiphertext::encrypt_sk_scratch_space(module, self)
121            )
122        }
123        self.encrypt_sk_internal(
124            module,
125            None::<(&GLWEPlaintext<Vec<u8>>, usize)>,
126            sk,
127            source_xa,
128            source_xe,
129            scratch,
130        );
131    }
132
133    #[allow(clippy::too_many_arguments)]
134    pub(crate) fn encrypt_sk_internal<DataPt: DataRef, DataSk: DataRef, B: Backend>(
135        &mut self,
136        module: &Module<B>,
137        pt: Option<(&GLWEPlaintext<DataPt>, usize)>,
138        sk: &GLWESecretPrepared<DataSk, B>,
139        source_xa: &mut Source,
140        source_xe: &mut Source,
141        scratch: &mut Scratch<B>,
142    ) where
143        Module<B>: VecZnxDftAllocBytes
144            + VecZnxBigNormalize<B>
145            + VecZnxDftApply<B>
146            + SvpApplyDftToDftInplace<B>
147            + VecZnxIdftApplyConsume<B>
148            + VecZnxNormalizeTmpBytes
149            + VecZnxFillUniform
150            + VecZnxSubInplace
151            + VecZnxAddInplace
152            + VecZnxNormalizeInplace<B>
153            + VecZnxAddNormal
154            + VecZnxNormalize<B>
155            + VecZnxSub,
156        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
157    {
158        let cols: usize = (self.rank() + 1).into();
159        glwe_encrypt_sk_internal(
160            module,
161            self.base2k().into(),
162            self.k().into(),
163            &mut self.data,
164            cols,
165            false,
166            pt,
167            sk,
168            source_xa,
169            source_xe,
170            SIGMA,
171            scratch,
172        );
173    }
174
175    #[allow(clippy::too_many_arguments)]
176    pub fn encrypt_pk<DataPt: DataRef, DataPk: DataRef, B: Backend>(
177        &mut self,
178        module: &Module<B>,
179        pt: &GLWEPlaintext<DataPt>,
180        pk: &GLWEPublicKeyPrepared<DataPk, B>,
181        source_xu: &mut Source,
182        source_xe: &mut Source,
183        scratch: &mut Scratch<B>,
184    ) where
185        Module<B>: SvpPrepare<B>
186            + SvpApplyDftToDft<B>
187            + VecZnxIdftApplyConsume<B>
188            + VecZnxBigAddNormal<B>
189            + VecZnxBigAddSmallInplace<B>
190            + VecZnxBigNormalize<B>,
191        Scratch<B>: TakeSvpPPol<B> + TakeScalarZnx + TakeVecZnxDft<B>,
192    {
193        self.encrypt_pk_internal::<DataPt, DataPk, B>(module, Some((pt, 0)), pk, source_xu, source_xe, scratch);
194    }
195
196    pub fn encrypt_zero_pk<DataPk: DataRef, B: Backend>(
197        &mut self,
198        module: &Module<B>,
199        pk: &GLWEPublicKeyPrepared<DataPk, B>,
200        source_xu: &mut Source,
201        source_xe: &mut Source,
202        scratch: &mut Scratch<B>,
203    ) where
204        Module<B>: SvpPrepare<B>
205            + SvpApplyDftToDft<B>
206            + VecZnxIdftApplyConsume<B>
207            + VecZnxBigAddNormal<B>
208            + VecZnxBigAddSmallInplace<B>
209            + VecZnxBigNormalize<B>,
210        Scratch<B>: TakeSvpPPol<B> + TakeScalarZnx + TakeVecZnxDft<B>,
211    {
212        self.encrypt_pk_internal::<Vec<u8>, DataPk, B>(
213            module,
214            None::<(&GLWEPlaintext<Vec<u8>>, usize)>,
215            pk,
216            source_xu,
217            source_xe,
218            scratch,
219        );
220    }
221
222    #[allow(clippy::too_many_arguments)]
223    pub(crate) fn encrypt_pk_internal<DataPt: DataRef, DataPk: DataRef, B: Backend>(
224        &mut self,
225        module: &Module<B>,
226        pt: Option<(&GLWEPlaintext<DataPt>, usize)>,
227        pk: &GLWEPublicKeyPrepared<DataPk, B>,
228        source_xu: &mut Source,
229        source_xe: &mut Source,
230        scratch: &mut Scratch<B>,
231    ) where
232        Module<B>: SvpPrepare<B>
233            + SvpApplyDftToDft<B>
234            + VecZnxIdftApplyConsume<B>
235            + VecZnxBigAddNormal<B>
236            + VecZnxBigAddSmallInplace<B>
237            + VecZnxBigNormalize<B>,
238        Scratch<B>: TakeSvpPPol<B> + TakeScalarZnx + TakeVecZnxDft<B>,
239    {
240        #[cfg(debug_assertions)]
241        {
242            assert_eq!(self.base2k(), pk.base2k());
243            assert_eq!(self.n(), pk.n());
244            assert_eq!(self.rank(), pk.rank());
245            if let Some((pt, _)) = pt {
246                assert_eq!(pt.base2k(), pk.base2k());
247                assert_eq!(pt.n(), pk.n());
248            }
249        }
250
251        let base2k: usize = pk.base2k().into();
252        let size_pk: usize = pk.size();
253        let cols: usize = (self.rank() + 1).into();
254
255        // Generates u according to the underlying secret distribution.
256        let (mut u_dft, scratch_1) = scratch.take_svp_ppol(self.n().into(), 1);
257
258        {
259            let (mut u, _) = scratch_1.take_scalar_znx(self.n().into(), 1);
260            match pk.dist {
261                Distribution::NONE => panic!(
262                    "invalid public key: SecretDistribution::NONE, ensure it has been correctly intialized through \
263                     Self::generate"
264                ),
265                Distribution::TernaryFixed(hw) => u.fill_ternary_hw(0, hw, source_xu),
266                Distribution::TernaryProb(prob) => u.fill_ternary_prob(0, prob, source_xu),
267                Distribution::BinaryFixed(hw) => u.fill_binary_hw(0, hw, source_xu),
268                Distribution::BinaryProb(prob) => u.fill_binary_prob(0, prob, source_xu),
269                Distribution::BinaryBlock(block_size) => u.fill_binary_block(0, block_size, source_xu),
270                Distribution::ZERO => {}
271            }
272
273            module.svp_prepare(&mut u_dft, 0, &u, 0);
274        }
275
276        // ct[i] = pk[i] * u + ei (+ m if col = i)
277        (0..cols).for_each(|i| {
278            let (mut ci_dft, scratch_2) = scratch_1.take_vec_znx_dft(self.n().into(), 1, size_pk);
279            // ci_dft = DFT(u) * DFT(pk[i])
280            module.svp_apply_dft_to_dft(&mut ci_dft, 0, &u_dft, 0, &pk.data, i);
281
282            // ci_big = u * p[i]
283            let mut ci_big = module.vec_znx_idft_apply_consume(ci_dft);
284
285            // ci_big = u * pk[i] + e
286            module.vec_znx_big_add_normal(
287                base2k,
288                &mut ci_big,
289                0,
290                pk.k().into(),
291                source_xe,
292                SIGMA,
293                SIGMA_BOUND,
294            );
295
296            // ci_big = u * pk[i] + e + m (if col = i)
297            if let Some((pt, col)) = pt
298                && col == i
299            {
300                module.vec_znx_big_add_small_inplace(&mut ci_big, 0, &pt.data, 0);
301            }
302
303            // ct[i] = norm(ci_big)
304            module.vec_znx_big_normalize(base2k, &mut self.data, i, base2k, &ci_big, 0, scratch_2);
305        });
306    }
307}
308
309#[allow(clippy::too_many_arguments)]
310pub(crate) fn glwe_encrypt_sk_internal<DataCt: DataMut, DataPt: DataRef, DataSk: DataRef, B: Backend>(
311    module: &Module<B>,
312    base2k: usize,
313    k: usize,
314    ct: &mut VecZnx<DataCt>,
315    cols: usize,
316    compressed: bool,
317    pt: Option<(&GLWEPlaintext<DataPt>, usize)>,
318    sk: &GLWESecretPrepared<DataSk, B>,
319    source_xa: &mut Source,
320    source_xe: &mut Source,
321    sigma: f64,
322    scratch: &mut Scratch<B>,
323) where
324    Module<B>: VecZnxDftAllocBytes
325        + VecZnxBigNormalize<B>
326        + VecZnxDftApply<B>
327        + SvpApplyDftToDftInplace<B>
328        + VecZnxIdftApplyConsume<B>
329        + VecZnxNormalizeTmpBytes
330        + VecZnxFillUniform
331        + VecZnxSubInplace
332        + VecZnxAddInplace
333        + VecZnxNormalizeInplace<B>
334        + VecZnxAddNormal
335        + VecZnxNormalize<B>
336        + VecZnxSub,
337    Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
338{
339    #[cfg(debug_assertions)]
340    {
341        if compressed {
342            assert_eq!(
343                ct.cols(),
344                1,
345                "invalid ciphertext: compressed tag=true but #cols={} != 1",
346                ct.cols()
347            )
348        }
349    }
350
351    let size: usize = ct.size();
352
353    let (mut c0, scratch_1) = scratch.take_vec_znx(ct.n(), 1, size);
354    c0.zero();
355
356    {
357        let (mut ci, scratch_2) = scratch_1.take_vec_znx(ct.n(), 1, size);
358
359        // ct[i] = uniform
360        // ct[0] -= c[i] * s[i],
361        (1..cols).for_each(|i| {
362            let col_ct: usize = if compressed { 0 } else { i };
363
364            // ct[i] = uniform (+ pt)
365            module.vec_znx_fill_uniform(base2k, ct, col_ct, source_xa);
366
367            let (mut ci_dft, scratch_3) = scratch_2.take_vec_znx_dft(ct.n(), 1, size);
368
369            // ci = ct[i] - pt
370            // i.e. we act as we sample ct[i] already as uniform + pt
371            // and if there is a pt, then we subtract it before applying DFT
372            if let Some((pt, col)) = pt {
373                if i == col {
374                    module.vec_znx_sub(&mut ci, 0, ct, col_ct, &pt.data, 0);
375                    module.vec_znx_normalize_inplace(base2k, &mut ci, 0, scratch_3);
376                    module.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, &ci, 0);
377                } else {
378                    module.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, ct, col_ct);
379                }
380            } else {
381                module.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, ct, col_ct);
382            }
383
384            module.svp_apply_dft_to_dft_inplace(&mut ci_dft, 0, &sk.data, i - 1);
385            let ci_big: VecZnxBig<&mut [u8], B> = module.vec_znx_idft_apply_consume(ci_dft);
386
387            // use c[0] as buffer, which is overwritten later by the normalization step
388            module.vec_znx_big_normalize(base2k, &mut ci, 0, base2k, &ci_big, 0, scratch_3);
389
390            // c0_tmp = -c[i] * s[i] (use c[0] as buffer)
391            module.vec_znx_sub_inplace(&mut c0, 0, &ci, 0);
392        });
393    }
394
395    // c[0] += e
396    module.vec_znx_add_normal(base2k, &mut c0, 0, k, source_xe, sigma, SIGMA_BOUND);
397
398    // c[0] += m if col = 0
399    if let Some((pt, col)) = pt
400        && col == 0
401    {
402        module.vec_znx_add_inplace(&mut c0, 0, &pt.data, 0);
403    }
404
405    // c[0] = norm(c[0])
406    module.vec_znx_normalize(base2k, ct, 0, base2k, &c0, 0, scratch_1);
407}