use poulpy_hal::{
    api::{
        ModuleN, ScratchAvailable, ScratchTakeBasic, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolBytesOf, SvpPrepare,
        VecZnxAddInplace, VecZnxAddNormal, VecZnxBigAddNormal, VecZnxBigAddSmallInplace, VecZnxBigBytesOf, VecZnxBigNormalize,
        VecZnxDftApply, VecZnxDftBytesOf, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace,
        VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace,
    },
    layouts::{Backend, DataMut, Module, ScalarZnx, Scratch, VecZnx, VecZnxBig, VecZnxToMut, ZnxInfos, ZnxZero},
    source::Source,
};

use crate::{
    GetDistribution, ScratchTakeCore,
    dist::Distribution,
    encryption::{SIGMA, SIGMA_BOUND},
    layouts::{
        GLWE, GLWEInfos, GLWEPlaintext, GLWEPlaintextToRef, GLWEPrepared, GLWEPreparedToRef, GLWEToMut, LWEInfos,
        prepared::{GLWESecretPrepared, GLWESecretPreparedToRef},
    },
};

impl GLWE<Vec<u8>> {
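    /// Returns the number of scratch bytes required by [`GLWE::encrypt_sk`] and
    /// [`GLWE::encrypt_zero_sk`] for a ciphertext described by `infos`.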
    pub fn encrypt_sk_tmp_bytes<M, A, BE: Backend>(module: &M, infos: &A) -> usize
    where
        A: GLWEInfos,
        M: GLWEEncryptSk<BE>,
    {
        module.glwe_encrypt_sk_tmp_bytes(infos)
    }

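    /// Returns the number of scratch bytes required by [`GLWE::encrypt_pk`] and
    /// [`GLWE::encrypt_zero_pk`] for a ciphertext described by `infos`.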
    pub fn encrypt_pk_tmp_bytes<M, A, BE: Backend>(module: &M, infos: &A) -> usize
    where
        A: GLWEInfos,
        M: GLWEEncryptPk<BE>,
    {
        module.glwe_encrypt_pk_tmp_bytes(infos)
    }
}

impl<D: DataMut> GLWE<D> {
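    /// Encrypts `pt` into `self` under the prepared secret key `sk`, drawing the uniform mask
    /// from `source_xa` and the Gaussian error from `source_xe`.
    ///
    /// A minimal usage sketch; the allocation and scratch helpers below are illustrative
    /// placeholders, not necessarily the constructors this crate exposes:
    ///
    /// ```ignore
    /// // Hypothetical helpers: `GLWE::alloc`, `ScratchOwned::alloc` and the prepared secret key
    /// // `sk_prep` stand in for whatever this crate actually provides.
    /// let mut ct = GLWE::alloc(&infos);
    /// let mut scratch = ScratchOwned::alloc(GLWE::encrypt_sk_tmp_bytes(&module, &infos));
    /// ct.encrypt_sk(&module, &pt, &sk_prep, &mut source_xa, &mut source_xe, scratch.borrow());
    /// ```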
    pub fn encrypt_sk<P, S, M, BE: Backend>(
        &mut self,
        module: &M,
        pt: &P,
        sk: &S,
        source_xa: &mut Source,
        source_xe: &mut Source,
        scratch: &mut Scratch<BE>,
    ) where
        P: GLWEPlaintextToRef,
        S: GLWESecretPreparedToRef<BE>,
        M: GLWEEncryptSk<BE>,
        Scratch<BE>: ScratchTakeCore<BE>,
    {
        module.glwe_encrypt_sk(self, pt, sk, source_xa, source_xe, scratch);
    }

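    /// Encrypts zero into `self` under the prepared secret key `sk`, i.e. a fresh encryption of
    /// the zero plaintext.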
    pub fn encrypt_zero_sk<S, M, BE: Backend>(
        &mut self,
        module: &M,
        sk: &S,
        source_xa: &mut Source,
        source_xe: &mut Source,
        scratch: &mut Scratch<BE>,
    ) where
        S: GLWESecretPreparedToRef<BE>,
        M: GLWEEncryptSk<BE>,
        Scratch<BE>: ScratchTakeCore<BE>,
    {
        module.glwe_encrypt_zero_sk(self, sk, source_xa, source_xe, scratch);
    }

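    /// Encrypts `pt` into `self` under the prepared public key `pk`, drawing the ephemeral secret
    /// from `source_xu` and the Gaussian error from `source_xe`.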
    pub fn encrypt_pk<P, K, M, BE: Backend>(
        &mut self,
        module: &M,
        pt: &P,
        pk: &K,
        source_xu: &mut Source,
        source_xe: &mut Source,
        scratch: &mut Scratch<BE>,
    ) where
        P: GLWEPlaintextToRef + GLWEInfos,
        K: GLWEPreparedToRef<BE> + GetDistribution + GLWEInfos,
        M: GLWEEncryptPk<BE>,
    {
        module.glwe_encrypt_pk(self, pt, pk, source_xu, source_xe, scratch);
    }

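    /// Encrypts zero into `self` under the prepared public key `pk`.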
    pub fn encrypt_zero_pk<K, M, BE: Backend>(
        &mut self,
        module: &M,
        pk: &K,
        source_xu: &mut Source,
        source_xe: &mut Source,
        scratch: &mut Scratch<BE>,
    ) where
        K: GLWEPreparedToRef<BE> + GetDistribution + GLWEInfos,
        M: GLWEEncryptPk<BE>,
    {
        module.glwe_encrypt_zero_pk(self, pk, source_xu, source_xe, scratch);
    }
}

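/// Backend entry points for GLWE secret-key encryption, implemented by [`Module`] for backends
/// that provide the required polynomial and DFT operations.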
pub trait GLWEEncryptSk<BE: Backend> {
    fn glwe_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize
    where
        A: GLWEInfos;

    fn glwe_encrypt_sk<R, P, S>(
        &self,
        res: &mut R,
        pt: &P,
        sk: &S,
        source_xa: &mut Source,
        source_xe: &mut Source,
        scratch: &mut Scratch<BE>,
    ) where
        R: GLWEToMut,
        P: GLWEPlaintextToRef,
        S: GLWESecretPreparedToRef<BE>;

    fn glwe_encrypt_zero_sk<R, S>(
        &self,
        res: &mut R,
        sk: &S,
        source_xa: &mut Source,
        source_xe: &mut Source,
        scratch: &mut Scratch<BE>,
    ) where
        R: GLWEToMut,
        S: GLWESecretPreparedToRef<BE>;
}

impl<BE: Backend> GLWEEncryptSk<BE> for Module<BE>
where
    Self: Sized + ModuleN + VecZnxNormalizeTmpBytes + VecZnxDftBytesOf + GLWEEncryptSkInternal<BE>,
    Scratch<BE>: ScratchAvailable,
{
    fn glwe_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize
    where
        A: GLWEInfos,
    {
        let size: usize = infos.size();
        assert_eq!(self.n() as u32, infos.n());
        self.vec_znx_normalize_tmp_bytes() + 2 * VecZnx::bytes_of(self.n(), 1, size) + self.bytes_of_vec_znx_dft(1, size)
    }

    fn glwe_encrypt_sk<R, P, S>(
        &self,
        res: &mut R,
        pt: &P,
        sk: &S,
        source_xa: &mut Source,
        source_xe: &mut Source,
        scratch: &mut Scratch<BE>,
    ) where
        R: GLWEToMut,
        P: GLWEPlaintextToRef,
        S: GLWESecretPreparedToRef<BE>,
    {
        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
        let pt: &GLWEPlaintext<&[u8]> = &pt.to_ref();
        let sk: &GLWESecretPrepared<&[u8], BE> = &sk.to_ref();

        assert_eq!(res.rank(), sk.rank());
        assert_eq!(res.n(), self.n() as u32);
        assert_eq!(sk.n(), self.n() as u32);
        assert_eq!(pt.n(), self.n() as u32);
        assert!(
            scratch.available() >= self.glwe_encrypt_sk_tmp_bytes(res),
            "scratch.available(): {} < GLWE::encrypt_sk_tmp_bytes: {}",
            scratch.available(),
            self.glwe_encrypt_sk_tmp_bytes(res)
        );

        let cols: usize = (res.rank() + 1).into();
        self.glwe_encrypt_sk_internal(
            res.base2k().into(),
            res.k().into(),
            res.data_mut(),
            cols,
            false,
            Some((pt, 0)),
            sk,
            source_xa,
            source_xe,
            SIGMA,
            scratch,
        );
    }

    fn glwe_encrypt_zero_sk<R, S>(
        &self,
        res: &mut R,
        sk: &S,
        source_xa: &mut Source,
        source_xe: &mut Source,
        scratch: &mut Scratch<BE>,
    ) where
        R: GLWEToMut,
        S: GLWESecretPreparedToRef<BE>,
    {
        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
        let sk: &GLWESecretPrepared<&[u8], BE> = &sk.to_ref();

        assert_eq!(res.rank(), sk.rank());
        assert_eq!(res.n(), self.n() as u32);
        assert_eq!(sk.n(), self.n() as u32);
        assert!(
            scratch.available() >= self.glwe_encrypt_sk_tmp_bytes(res),
            "scratch.available(): {} < GLWE::encrypt_sk_tmp_bytes: {}",
            scratch.available(),
            self.glwe_encrypt_sk_tmp_bytes(res)
        );

        let cols: usize = (res.rank() + 1).into();
        self.glwe_encrypt_sk_internal(
            res.base2k().into(),
            res.k().into(),
            res.data_mut(),
            cols,
            false,
            None::<(&GLWEPlaintext<Vec<u8>>, usize)>,
            sk,
            source_xa,
            source_xe,
            SIGMA,
            scratch,
        );
    }
}

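/// Backend entry points for GLWE public-key encryption, implemented by [`Module`] for backends
/// that provide the required polynomial and DFT operations.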
pub trait GLWEEncryptPk<BE: Backend> {
    fn glwe_encrypt_pk_tmp_bytes<A>(&self, infos: &A) -> usize
    where
        A: GLWEInfos;

    fn glwe_encrypt_pk<R, P, K>(
        &self,
        res: &mut R,
        pt: &P,
        pk: &K,
        source_xu: &mut Source,
        source_xe: &mut Source,
        scratch: &mut Scratch<BE>,
    ) where
        R: GLWEToMut,
        P: GLWEPlaintextToRef + GLWEInfos,
        K: GLWEPreparedToRef<BE> + GetDistribution + GLWEInfos;

    fn glwe_encrypt_zero_pk<R, K>(
        &self,
        res: &mut R,
        pk: &K,
        source_xu: &mut Source,
        source_xe: &mut Source,
        scratch: &mut Scratch<BE>,
    ) where
        R: GLWEToMut,
        K: GLWEPreparedToRef<BE> + GetDistribution + GLWEInfos;
}

impl<BE: Backend> GLWEEncryptPk<BE> for Module<BE>
where
    Self: GLWEEncryptPkInternal<BE> + VecZnxDftBytesOf + SvpPPolBytesOf + VecZnxBigBytesOf + VecZnxNormalizeTmpBytes,
{
    fn glwe_encrypt_pk_tmp_bytes<A>(&self, infos: &A) -> usize
    where
        A: GLWEInfos,
    {
        let size: usize = infos.size();
        assert_eq!(self.n() as u32, infos.n());
        ((self.bytes_of_vec_znx_dft(1, size) + self.bytes_of_vec_znx_big(1, size)).max(ScalarZnx::bytes_of(self.n(), 1)))
            + self.bytes_of_svp_ppol(1)
            + self.vec_znx_normalize_tmp_bytes()
    }

    fn glwe_encrypt_pk<R, P, K>(
        &self,
        res: &mut R,
        pt: &P,
        pk: &K,
        source_xu: &mut Source,
        source_xe: &mut Source,
        scratch: &mut Scratch<BE>,
    ) where
        R: GLWEToMut,
        P: GLWEPlaintextToRef + GLWEInfos,
        K: GLWEPreparedToRef<BE> + GetDistribution + GLWEInfos,
    {
        self.glwe_encrypt_pk_internal(res, Some((pt, 0)), pk, source_xu, source_xe, scratch);
    }

    fn glwe_encrypt_zero_pk<R, K>(
        &self,
        res: &mut R,
        pk: &K,
        source_xu: &mut Source,
        source_xe: &mut Source,
        scratch: &mut Scratch<BE>,
    ) where
        R: GLWEToMut,
        K: GLWEPreparedToRef<BE> + GetDistribution + GLWEInfos,
    {
        self.glwe_encrypt_pk_internal(
            res,
            None::<(&GLWEPlaintext<Vec<u8>>, usize)>,
            pk,
            source_xu,
            source_xe,
            scratch,
        );
    }
}

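/// Shared implementation behind public-key encryption: `glwe_encrypt_pk` and
/// `glwe_encrypt_zero_pk` both forward here, the latter with `pt = None`.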
pub(crate) trait GLWEEncryptPkInternal<BE: Backend> {
    fn glwe_encrypt_pk_internal<R, P, K>(
        &self,
        res: &mut R,
        pt: Option<(&P, usize)>,
        pk: &K,
        source_xu: &mut Source,
        source_xe: &mut Source,
        scratch: &mut Scratch<BE>,
    ) where
        R: GLWEToMut,
        P: GLWEPlaintextToRef + GLWEInfos,
        K: GLWEPreparedToRef<BE> + GetDistribution + GLWEInfos;
}

impl<BE: Backend> GLWEEncryptPkInternal<BE> for Module<BE>
where
    Self: SvpPrepare<BE>
        + SvpApplyDftToDft<BE>
        + VecZnxIdftApplyConsume<BE>
        + VecZnxBigAddNormal<BE>
        + VecZnxBigAddSmallInplace<BE>
        + VecZnxBigNormalize<BE>
        + SvpPPolBytesOf
        + ModuleN
        + VecZnxDftBytesOf,
    Scratch<BE>: ScratchTakeBasic,
{
    fn glwe_encrypt_pk_internal<R, P, K>(
        &self,
        res: &mut R,
        pt: Option<(&P, usize)>,
        pk: &K,
        source_xu: &mut Source,
        source_xe: &mut Source,
        scratch: &mut Scratch<BE>,
    ) where
        R: GLWEToMut,
        P: GLWEPlaintextToRef + GLWEInfos,
        K: GLWEPreparedToRef<BE> + GetDistribution + GLWEInfos,
    {
        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();

        assert_eq!(res.base2k(), pk.base2k());
        assert_eq!(res.n(), pk.n());
        assert_eq!(res.rank(), pk.rank());
        if let Some((pt, _)) = pt {
            assert_eq!(pt.base2k(), pk.base2k());
            assert_eq!(pt.n(), pk.n());
        }

        let base2k: usize = pk.base2k().into();
        let size_pk: usize = pk.size();
        let cols: usize = (res.rank() + 1).into();

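        // Sample the ephemeral secret u according to the public key's declared distribution and
        // prepare it for scalar-vector products in the DFT domain.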
        let (mut u_dft, scratch_1) = scratch.take_svp_ppol(self, 1);

        {
            let (mut u, _) = scratch_1.take_scalar_znx(self.n(), 1);
            match pk.dist() {
                Distribution::NONE => panic!(
                    "invalid public key: Distribution::NONE, ensure it has been correctly initialized through \
                     Self::generate"
                ),
                Distribution::TernaryFixed(hw) => u.fill_ternary_hw(0, *hw, source_xu),
                Distribution::TernaryProb(prob) => u.fill_ternary_prob(0, *prob, source_xu),
                Distribution::BinaryFixed(hw) => u.fill_binary_hw(0, *hw, source_xu),
                Distribution::BinaryProb(prob) => u.fill_binary_prob(0, *prob, source_xu),
                Distribution::BinaryBlock(block_size) => u.fill_binary_block(0, *block_size, source_xu),
                Distribution::ZERO => {}
            }

            self.svp_prepare(&mut u_dft, 0, &u, 0);
        }

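        // For each column i, compute c_i = u * pk_i + e_i, add the plaintext on its target
        // column (when one is provided), and normalize the result to `base2k` limbs.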
        {
            let pk: &GLWEPrepared<&[u8], BE> = &pk.to_ref();

            for i in 0..cols {
                let (mut ci_dft, scratch_2) = scratch_1.take_vec_znx_dft(self, 1, size_pk);
                self.svp_apply_dft_to_dft(&mut ci_dft, 0, &u_dft, 0, &pk.data, i);

                let mut ci_big = self.vec_znx_idft_apply_consume(ci_dft);

                self.vec_znx_big_add_normal(
                    base2k,
                    &mut ci_big,
                    0,
                    pk.k().into(),
                    source_xe,
                    SIGMA,
                    SIGMA_BOUND,
                );

                if let Some((pt, col)) = pt
                    && col == i
                {
                    self.vec_znx_big_add_small_inplace(&mut ci_big, 0, &pt.to_ref().data, 0);
                }

                self.vec_znx_big_normalize(base2k, &mut res.data, i, base2k, &ci_big, 0, scratch_2);
            }
        }
    }
}

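/// Shared implementation behind secret-key encryption: `glwe_encrypt_sk` and
/// `glwe_encrypt_zero_sk` both forward here. `pt` optionally carries the plaintext together with
/// the column it is added to, and `compressed` asserts a single-column output layout.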
pub(crate) trait GLWEEncryptSkInternal<BE: Backend> {
    #[allow(clippy::too_many_arguments)]
    fn glwe_encrypt_sk_internal<R, P, S>(
        &self,
        base2k: usize,
        k: usize,
        res: &mut R,
        cols: usize,
        compressed: bool,
        pt: Option<(&P, usize)>,
        sk: &S,
        source_xa: &mut Source,
        source_xe: &mut Source,
        sigma: f64,
        scratch: &mut Scratch<BE>,
    ) where
        R: VecZnxToMut,
        P: GLWEPlaintextToRef,
        S: GLWESecretPreparedToRef<BE>;
}

impl<BE: Backend> GLWEEncryptSkInternal<BE> for Module<BE>
where
    Self: ModuleN
        + VecZnxDftBytesOf
        + VecZnxBigNormalize<BE>
        + VecZnxDftApply<BE>
        + SvpApplyDftToDftInplace<BE>
        + VecZnxIdftApplyConsume<BE>
        + VecZnxNormalizeTmpBytes
        + VecZnxFillUniform
        + VecZnxSubInplace
        + VecZnxAddInplace
        + VecZnxNormalizeInplace<BE>
        + VecZnxAddNormal
        + VecZnxNormalize<BE>
        + VecZnxSub,
    Scratch<BE>: ScratchTakeBasic,
{
    fn glwe_encrypt_sk_internal<R, P, S>(
        &self,
        base2k: usize,
        k: usize,
        res: &mut R,
        cols: usize,
        compressed: bool,
        pt: Option<(&P, usize)>,
        sk: &S,
        source_xa: &mut Source,
        source_xe: &mut Source,
        sigma: f64,
        scratch: &mut Scratch<BE>,
    ) where
        R: VecZnxToMut,
        P: GLWEPlaintextToRef,
        S: GLWESecretPreparedToRef<BE>,
    {
        let ct: &mut VecZnx<&mut [u8]> = &mut res.to_mut();
        let sk: GLWESecretPrepared<&[u8], BE> = sk.to_ref();

        if compressed {
            assert_eq!(
                ct.cols(),
                1,
                "invalid glwe: compressed tag=true but #cols={} != 1",
                ct.cols()
            )
        }

        assert!(
            sk.dist != Distribution::NONE,
            "glwe secret distribution is NONE (have you prepared the key?)"
        );

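        // Columns 1..cols receive uniform masks a_i, and the body is accumulated as
        // c_0 = -sum_i a_i * s_i + e (+ pt, added on column 0 or folded into the product when it
        // targets a column i >= 1), with every column normalized to `base2k` limbs.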
        let size: usize = ct.size();

        let (mut c0, scratch_1) = scratch.take_vec_znx(self.n(), 1, size);
        c0.zero();

        {
            let (mut ci, scratch_2) = scratch_1.take_vec_znx(self.n(), 1, size);

            (1..cols).for_each(|i| {
                let col_ct: usize = if compressed { 0 } else { i };

                self.vec_znx_fill_uniform(base2k, ct, col_ct, source_xa);

                let (mut ci_dft, scratch_3) = scratch_2.take_vec_znx_dft(self, 1, size);

                if let Some((pt, col)) = pt {
                    if i == col {
                        self.vec_znx_sub(&mut ci, 0, ct, col_ct, &pt.to_ref().data, 0);
                        self.vec_znx_normalize_inplace(base2k, &mut ci, 0, scratch_3);
                        self.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, &ci, 0);
                    } else {
                        self.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, ct, col_ct);
                    }
                } else {
                    self.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, ct, col_ct);
                }

                self.svp_apply_dft_to_dft_inplace(&mut ci_dft, 0, &sk.data, i - 1);
                let ci_big: VecZnxBig<&mut [u8], BE> = self.vec_znx_idft_apply_consume(ci_dft);

                self.vec_znx_big_normalize(base2k, &mut ci, 0, base2k, &ci_big, 0, scratch_3);

                self.vec_znx_sub_inplace(&mut c0, 0, &ci, 0);
            });
        }

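        // Add the Gaussian error to the body, add the plaintext if it targets column 0, and
        // normalize the result into the ciphertext's first column.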
        self.vec_znx_add_normal(base2k, &mut c0, 0, k, source_xe, sigma, SIGMA_BOUND);

        if let Some((pt, col)) = pt
            && col == 0
        {
            self.vec_znx_add_inplace(&mut c0, 0, &pt.to_ref().data, 0);
        }

        self.vec_znx_normalize(base2k, ct, 0, base2k, &c0, 0, scratch_1);
    }
}