use poulpy_hal::{
    api::{
        ScratchAvailable, SvpApply, SvpApplyInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeSvpPPol, TakeVecZnx,
        TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxBigAddNormal, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes,
        VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigConsume, VecZnxFillUniform,
        VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, ZnxInfos, ZnxZero,
    },
    layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, VecZnx, VecZnxBig},
    source::Source,
};

use crate::{
    SIX_SIGMA,
    dist::Distribution,
    layouts::{
        GLWECiphertext, GLWEPlaintext, Infos,
        prepared::{GLWEPublicKeyPrepared, GLWESecretPrepared},
    },
};

impl GLWECiphertext<Vec<u8>> {
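    /// Returns the number of scratch bytes required by [`Self::encrypt_sk`] for a
    /// ciphertext of ring degree `n` with `k.div_ceil(basek)` limbs.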
    pub fn encrypt_sk_scratch_space<B: Backend>(module: &Module<B>, n: usize, basek: usize, k: usize) -> usize
    where
        Module<B>: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes,
    {
        let size: usize = k.div_ceil(basek);
        module.vec_znx_normalize_tmp_bytes(n) + 2 * VecZnx::alloc_bytes(n, 1, size) + module.vec_znx_dft_alloc_bytes(n, 1, size)
    }
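
    /// Returns the number of scratch bytes required by [`Self::encrypt_pk`] for a
    /// ciphertext of ring degree `n` with `k.div_ceil(basek)` limbs.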
    pub fn encrypt_pk_scratch_space<B: Backend>(module: &Module<B>, n: usize, basek: usize, k: usize) -> usize
    where
        Module<B>: VecZnxDftAllocBytes + SvpPPolAllocBytes + VecZnxBigAllocBytes + VecZnxNormalizeTmpBytes,
    {
        let size: usize = k.div_ceil(basek);
        ((module.vec_znx_dft_alloc_bytes(n, 1, size) + module.vec_znx_big_alloc_bytes(n, 1, size)) | ScalarZnx::alloc_bytes(n, 1))
            + module.svp_ppol_alloc_bytes(n, 1)
            + module.vec_znx_normalize_tmp_bytes(n)
    }
}

impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
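    /// Encrypts `pt` into `self` under the prepared secret key `sk`.
    ///
    /// Uniform mask coefficients are drawn from `source_xa` and the error from
    /// `source_xe` with standard deviation `sigma`. In debug builds this panics if
    /// the dimensions of `self`, `pt` and `sk` disagree or if `scratch` is smaller
    /// than [`GLWECiphertext::encrypt_sk_scratch_space`].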
    #[allow(clippy::too_many_arguments)]
    pub fn encrypt_sk<DataPt: DataRef, DataSk: DataRef, B: Backend>(
        &mut self,
        module: &Module<B>,
        pt: &GLWEPlaintext<DataPt>,
        sk: &GLWESecretPrepared<DataSk, B>,
        source_xa: &mut Source,
        source_xe: &mut Source,
        sigma: f64,
        scratch: &mut Scratch<B>,
    ) where
        Module<B>: VecZnxDftAllocBytes
            + VecZnxBigNormalize<B>
            + VecZnxDftFromVecZnx<B>
            + SvpApplyInplace<B>
            + VecZnxDftToVecZnxBigConsume<B>
            + VecZnxNormalizeTmpBytes
            + VecZnxFillUniform
            + VecZnxSubABInplace
            + VecZnxAddInplace
            + VecZnxNormalizeInplace<B>
            + VecZnxAddNormal
            + VecZnxNormalize<B>
            + VecZnxSub,
        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
    {
        #[cfg(debug_assertions)]
        {
            assert_eq!(self.rank(), sk.rank());
            assert_eq!(sk.n(), self.n());
            assert_eq!(pt.n(), self.n());
            assert!(
                scratch.available() >= GLWECiphertext::encrypt_sk_scratch_space(module, self.n(), self.basek(), self.k()),
                "scratch.available(): {} < GLWECiphertext::encrypt_sk_scratch_space: {}",
                scratch.available(),
                GLWECiphertext::encrypt_sk_scratch_space(module, self.n(), self.basek(), self.k())
            )
        }

        self.encrypt_sk_internal(
            module,
            Some((pt, 0)),
            sk,
            source_xa,
            source_xe,
            sigma,
            scratch,
        );
    }

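    /// Encrypts zero into `self` under the prepared secret key `sk`.
    ///
    /// Identical to [`Self::encrypt_sk`] but without a plaintext term.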
    pub fn encrypt_zero_sk<DataSk: DataRef, B: Backend>(
        &mut self,
        module: &Module<B>,
        sk: &GLWESecretPrepared<DataSk, B>,
        source_xa: &mut Source,
        source_xe: &mut Source,
        sigma: f64,
        scratch: &mut Scratch<B>,
    ) where
        Module<B>: VecZnxDftAllocBytes
            + VecZnxBigNormalize<B>
            + VecZnxDftFromVecZnx<B>
            + SvpApplyInplace<B>
            + VecZnxDftToVecZnxBigConsume<B>
            + VecZnxNormalizeTmpBytes
            + VecZnxFillUniform
            + VecZnxSubABInplace
            + VecZnxAddInplace
            + VecZnxNormalizeInplace<B>
            + VecZnxAddNormal
            + VecZnxNormalize<B>
            + VecZnxSub,
        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
    {
        #[cfg(debug_assertions)]
        {
            assert_eq!(self.rank(), sk.rank());
            assert_eq!(sk.n(), self.n());
            assert!(
                scratch.available() >= GLWECiphertext::encrypt_sk_scratch_space(module, self.n(), self.basek(), self.k()),
                "scratch.available(): {} < GLWECiphertext::encrypt_sk_scratch_space: {}",
                scratch.available(),
                GLWECiphertext::encrypt_sk_scratch_space(module, self.n(), self.basek(), self.k())
            )
        }
        self.encrypt_sk_internal(
            module,
            None::<(&GLWEPlaintext<Vec<u8>>, usize)>,
            sk,
            source_xa,
            source_xe,
            sigma,
            scratch,
        );
    }

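    /// Shared body of [`Self::encrypt_sk`] and [`Self::encrypt_zero_sk`]: forwards to
    /// [`glwe_encrypt_sk_internal`] over the raw ciphertext data, with `compressed`
    /// set to `false` and the optional plaintext placed in the column given by `pt`.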
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn encrypt_sk_internal<DataPt: DataRef, DataSk: DataRef, B: Backend>(
        &mut self,
        module: &Module<B>,
        pt: Option<(&GLWEPlaintext<DataPt>, usize)>,
        sk: &GLWESecretPrepared<DataSk, B>,
        source_xa: &mut Source,
        source_xe: &mut Source,
        sigma: f64,
        scratch: &mut Scratch<B>,
    ) where
        Module<B>: VecZnxDftAllocBytes
            + VecZnxBigNormalize<B>
            + VecZnxDftFromVecZnx<B>
            + SvpApplyInplace<B>
            + VecZnxDftToVecZnxBigConsume<B>
            + VecZnxNormalizeTmpBytes
            + VecZnxFillUniform
            + VecZnxSubABInplace
            + VecZnxAddInplace
            + VecZnxNormalizeInplace<B>
            + VecZnxAddNormal
            + VecZnxNormalize<B>
            + VecZnxSub,
        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
    {
        let cols: usize = self.rank() + 1;
        glwe_encrypt_sk_internal(
            module,
            self.basek(),
            self.k(),
            &mut self.data,
            cols,
            false,
            pt,
            sk,
            source_xa,
            source_xe,
            sigma,
            scratch,
        );
    }

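    /// Encrypts `pt` into `self` under the prepared public key `pk`.
    ///
    /// An ephemeral secret `u` is sampled from `source_xu` according to the
    /// distribution of `pk`, and the error is drawn from `source_xe` with standard
    /// deviation `sigma`.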
    #[allow(clippy::too_many_arguments)]
    pub fn encrypt_pk<DataPt: DataRef, DataPk: DataRef, B: Backend>(
        &mut self,
        module: &Module<B>,
        pt: &GLWEPlaintext<DataPt>,
        pk: &GLWEPublicKeyPrepared<DataPk, B>,
        source_xu: &mut Source,
        source_xe: &mut Source,
        sigma: f64,
        scratch: &mut Scratch<B>,
    ) where
        Module<B>: SvpPrepare<B>
            + SvpApply<B>
            + VecZnxDftToVecZnxBigConsume<B>
            + VecZnxBigAddNormal<B>
            + VecZnxBigAddSmallInplace<B>
            + VecZnxBigNormalize<B>,
        Scratch<B>: TakeSvpPPol<B> + TakeScalarZnx + TakeVecZnxDft<B>,
    {
        self.encrypt_pk_internal::<DataPt, DataPk, B>(
            module,
            Some((pt, 0)),
            pk,
            source_xu,
            source_xe,
            sigma,
            scratch,
        );
    }

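    /// Encrypts zero into `self` under the prepared public key `pk`.
    ///
    /// Identical to [`Self::encrypt_pk`] but without a plaintext term.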
    pub fn encrypt_zero_pk<DataPk: DataRef, B: Backend>(
        &mut self,
        module: &Module<B>,
        pk: &GLWEPublicKeyPrepared<DataPk, B>,
        source_xu: &mut Source,
        source_xe: &mut Source,
        sigma: f64,
        scratch: &mut Scratch<B>,
    ) where
        Module<B>: SvpPrepare<B>
            + SvpApply<B>
            + VecZnxDftToVecZnxBigConsume<B>
            + VecZnxBigAddNormal<B>
            + VecZnxBigAddSmallInplace<B>
            + VecZnxBigNormalize<B>,
        Scratch<B>: TakeSvpPPol<B> + TakeScalarZnx + TakeVecZnxDft<B>,
    {
        self.encrypt_pk_internal::<Vec<u8>, DataPk, B>(
            module,
            None::<(&GLWEPlaintext<Vec<u8>>, usize)>,
            pk,
            source_xu,
            source_xe,
            sigma,
            scratch,
        );
    }

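    /// Shared body of [`Self::encrypt_pk`] and [`Self::encrypt_zero_pk`]: for every
    /// column `i` it computes `u * pk[i] + e_i`, adds the plaintext to the column
    /// selected by `pt`, and normalizes the result into `self`.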
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn encrypt_pk_internal<DataPt: DataRef, DataPk: DataRef, B: Backend>(
        &mut self,
        module: &Module<B>,
        pt: Option<(&GLWEPlaintext<DataPt>, usize)>,
        pk: &GLWEPublicKeyPrepared<DataPk, B>,
        source_xu: &mut Source,
        source_xe: &mut Source,
        sigma: f64,
        scratch: &mut Scratch<B>,
    ) where
        Module<B>: SvpPrepare<B>
            + SvpApply<B>
            + VecZnxDftToVecZnxBigConsume<B>
            + VecZnxBigAddNormal<B>
            + VecZnxBigAddSmallInplace<B>
            + VecZnxBigNormalize<B>,
        Scratch<B>: TakeSvpPPol<B> + TakeScalarZnx + TakeVecZnxDft<B>,
    {
        #[cfg(debug_assertions)]
        {
            assert_eq!(self.basek(), pk.basek());
            assert_eq!(self.n(), pk.n());
            assert_eq!(self.rank(), pk.rank());
            if let Some((pt, _)) = pt {
                assert_eq!(pt.basek(), pk.basek());
                assert_eq!(pt.n(), pk.n());
            }
        }

        let basek: usize = pk.basek();
        let size_pk: usize = pk.size();
        let cols: usize = self.rank() + 1;

        let (mut u_dft, scratch_1) = scratch.take_svp_ppol(self.n(), 1);

        // Sample the ephemeral secret u following the distribution of the public key
        // and prepare it for scalar-vector products.
        {
            let (mut u, _) = scratch_1.take_scalar_znx(self.n(), 1);
            match pk.dist {
                Distribution::NONE => panic!(
                    "invalid public key: SecretDistribution::NONE, ensure it has been correctly initialized through \
                     Self::generate"
                ),
                Distribution::TernaryFixed(hw) => u.fill_ternary_hw(0, hw, source_xu),
                Distribution::TernaryProb(prob) => u.fill_ternary_prob(0, prob, source_xu),
                Distribution::BinaryFixed(hw) => u.fill_binary_hw(0, hw, source_xu),
                Distribution::BinaryProb(prob) => u.fill_binary_prob(0, prob, source_xu),
                Distribution::BinaryBlock(block_size) => u.fill_binary_block(0, block_size, source_xu),
                Distribution::ZERO => {}
            }

            module.svp_prepare(&mut u_dft, 0, &u, 0);
        }

        // For each column i: ct[i] = u * pk[i] + e_i, with the plaintext added to the
        // column it targets, then normalized into self.
        (0..cols).for_each(|i| {
            let (mut ci_dft, scratch_2) = scratch_1.take_vec_znx_dft(self.n(), 1, size_pk);
            module.svp_apply(&mut ci_dft, 0, &u_dft, 0, &pk.data, i);

            let mut ci_big = module.vec_znx_dft_to_vec_znx_big_consume(ci_dft);

            module.vec_znx_big_add_normal(
                basek,
                &mut ci_big,
                0,
                pk.k(),
                source_xe,
                sigma,
                sigma * SIX_SIGMA,
            );

            if let Some((pt, col)) = pt
                && col == i
            {
                module.vec_znx_big_add_small_inplace(&mut ci_big, 0, &pt.data, 0);
            }

            module.vec_znx_big_normalize(basek, &mut self.data, i, &ci_big, 0, scratch_2);
        });
    }
}

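/// Core secret-key encryption routine shared by [`GLWECiphertext::encrypt_sk`] and
/// [`GLWECiphertext::encrypt_zero_sk`].
///
/// Columns `1..cols` receive uniform masks `a_i` drawn from `source_xa`; the body
/// `-sum_i a_i * s_i + e` (plus the plaintext when it targets column 0) is written to
/// column 0, while a plaintext targeting a column `i > 0` is instead subtracted from
/// `a_i` before the product with `s_i`. When `compressed` is set the ciphertext holds
/// a single column: the masks are sampled but not stored, only the body is written out.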
#[allow(clippy::too_many_arguments)]
pub(crate) fn glwe_encrypt_sk_internal<DataCt: DataMut, DataPt: DataRef, DataSk: DataRef, B: Backend>(
    module: &Module<B>,
    basek: usize,
    k: usize,
    ct: &mut VecZnx<DataCt>,
    cols: usize,
    compressed: bool,
    pt: Option<(&GLWEPlaintext<DataPt>, usize)>,
    sk: &GLWESecretPrepared<DataSk, B>,
    source_xa: &mut Source,
    source_xe: &mut Source,
    sigma: f64,
    scratch: &mut Scratch<B>,
) where
    Module<B>: VecZnxDftAllocBytes
        + VecZnxBigNormalize<B>
        + VecZnxDftFromVecZnx<B>
        + SvpApplyInplace<B>
        + VecZnxDftToVecZnxBigConsume<B>
        + VecZnxNormalizeTmpBytes
        + VecZnxFillUniform
        + VecZnxSubABInplace
        + VecZnxAddInplace
        + VecZnxNormalizeInplace<B>
        + VecZnxAddNormal
        + VecZnxNormalize<B>
        + VecZnxSub,
    Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
{
    #[cfg(debug_assertions)]
    {
        if compressed {
            use poulpy_hal::api::ZnxInfos;

            assert_eq!(
                ct.cols(),
                1,
                "invalid ciphertext: compressed tag=true but #cols={} != 1",
                ct.cols()
            )
        }
    }

    let size: usize = ct.size();

    // c0 accumulates the body: -sum_i a_i * s_i, to which the error and (optionally)
    // the plaintext are added before being normalized into column 0.
    let (mut c0, scratch_1) = scratch.take_vec_znx(ct.n(), 1, size);
    c0.zero();

    {
        let (mut ci, scratch_2) = scratch_1.take_vec_znx(ct.n(), 1, size);

        (1..cols).for_each(|i| {
            // When compressed, the single available column is reused for every mask a_i.
            let col_ct: usize = if compressed { 0 } else { i };

            // Sample the uniform mask a_i.
            module.vec_znx_fill_uniform(basek, ct, col_ct, k, source_xa);

            let (mut ci_dft, scratch_3) = scratch_2.take_vec_znx_dft(ct.n(), 1, size);

            if let Some((pt, col)) = pt {
                if i == col {
                    // The plaintext targets this column: subtract it from the mask
                    // before the product with the matching secret column.
                    module.vec_znx_sub(&mut ci, 0, ct, col_ct, &pt.data, 0);
                    module.vec_znx_normalize_inplace(basek, &mut ci, 0, scratch_3);
                    module.vec_znx_dft_from_vec_znx(1, 0, &mut ci_dft, 0, &ci, 0);
                } else {
                    module.vec_znx_dft_from_vec_znx(1, 0, &mut ci_dft, 0, ct, col_ct);
                }
            } else {
                module.vec_znx_dft_from_vec_znx(1, 0, &mut ci_dft, 0, ct, col_ct);
            }

            // Multiply by the matching secret column, normalize back to basek limbs,
            // and subtract the result from the body c0.
            module.svp_apply_inplace(&mut ci_dft, 0, &sk.data, i - 1);
            let ci_big: VecZnxBig<&mut [u8], B> = module.vec_znx_dft_to_vec_znx_big_consume(ci_dft);

            module.vec_znx_big_normalize(basek, &mut ci, 0, &ci_big, 0, scratch_3);

            module.vec_znx_sub_ab_inplace(&mut c0, 0, &ci, 0);
        });
    }

    // Add the Gaussian error and, if the plaintext targets column 0, the plaintext itself.
    module.vec_znx_add_normal(basek, &mut c0, 0, k, source_xe, sigma, sigma * SIX_SIGMA);

    if let Some((pt, col)) = pt
        && col == 0
    {
        module.vec_znx_add_inplace(&mut c0, 0, &pt.data, 0);
    }

    module.vec_znx_normalize(basek, ct, 0, &c0, 0, scratch_1);
}