1use poulpy_hal::{
2 api::{
3 ScratchAvailable, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeSvpPPol,
4 TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxBigAddNormal, VecZnxBigAddSmallInplace,
5 VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume,
6 VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace,
7 },
8 layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, VecZnx, VecZnxBig, ZnxInfos, ZnxZero},
9 source::Source,
10};
11
12use crate::{
13 dist::Distribution,
14 encryption::{SIGMA, SIGMA_BOUND},
15 layouts::{
16 GLWECiphertext, GLWEInfos, GLWEPlaintext, LWEInfos,
17 prepared::{GLWEPublicKeyPrepared, GLWESecretPrepared},
18 },
19};
20
21impl GLWECiphertext<Vec<u8>> {
22 pub fn encrypt_sk_scratch_space<B: Backend, A>(module: &Module<B>, infos: &A) -> usize
23 where
24 A: GLWEInfos,
25 Module<B>: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes,
26 {
27 let size: usize = infos.size();
28 assert_eq!(module.n() as u32, infos.n());
29 module.vec_znx_normalize_tmp_bytes()
30 + 2 * VecZnx::alloc_bytes(module.n(), 1, size)
31 + module.vec_znx_dft_alloc_bytes(1, size)
32 }
33 pub fn encrypt_pk_scratch_space<B: Backend, A>(module: &Module<B>, infos: &A) -> usize
34 where
35 A: GLWEInfos,
36 Module<B>: VecZnxDftAllocBytes + SvpPPolAllocBytes + VecZnxBigAllocBytes + VecZnxNormalizeTmpBytes,
37 {
38 let size: usize = infos.size();
39 assert_eq!(module.n() as u32, infos.n());
40 ((module.vec_znx_dft_alloc_bytes(1, size) + module.vec_znx_big_alloc_bytes(1, size))
41 | ScalarZnx::alloc_bytes(module.n(), 1))
42 + module.svp_ppol_alloc_bytes(1)
43 + module.vec_znx_normalize_tmp_bytes()
44 }
45}
46
47impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
48 #[allow(clippy::too_many_arguments)]
49 pub fn encrypt_sk<DataPt: DataRef, DataSk: DataRef, B: Backend>(
50 &mut self,
51 module: &Module<B>,
52 pt: &GLWEPlaintext<DataPt>,
53 sk: &GLWESecretPrepared<DataSk, B>,
54 source_xa: &mut Source,
55 source_xe: &mut Source,
56 scratch: &mut Scratch<B>,
57 ) where
58 Module<B>: VecZnxDftAllocBytes
59 + VecZnxBigNormalize<B>
60 + VecZnxDftApply<B>
61 + SvpApplyDftToDftInplace<B>
62 + VecZnxIdftApplyConsume<B>
63 + VecZnxNormalizeTmpBytes
64 + VecZnxFillUniform
65 + VecZnxSubInplace
66 + VecZnxAddInplace
67 + VecZnxNormalizeInplace<B>
68 + VecZnxAddNormal
69 + VecZnxNormalize<B>
70 + VecZnxSub,
71 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
72 {
73 #[cfg(debug_assertions)]
74 {
75 assert_eq!(self.rank(), sk.rank());
76 assert_eq!(sk.n(), self.n());
77 assert_eq!(pt.n(), self.n());
78 assert!(
79 scratch.available() >= GLWECiphertext::encrypt_sk_scratch_space(module, self),
80 "scratch.available(): {} < GLWECiphertext::encrypt_sk_scratch_space: {}",
81 scratch.available(),
82 GLWECiphertext::encrypt_sk_scratch_space(module, self)
83 )
84 }
85
86 self.encrypt_sk_internal(module, Some((pt, 0)), sk, source_xa, source_xe, scratch);
87 }
88
89 pub fn encrypt_zero_sk<DataSk: DataRef, B: Backend>(
90 &mut self,
91 module: &Module<B>,
92 sk: &GLWESecretPrepared<DataSk, B>,
93 source_xa: &mut Source,
94 source_xe: &mut Source,
95 scratch: &mut Scratch<B>,
96 ) where
97 Module<B>: VecZnxDftAllocBytes
98 + VecZnxBigNormalize<B>
99 + VecZnxDftApply<B>
100 + SvpApplyDftToDftInplace<B>
101 + VecZnxIdftApplyConsume<B>
102 + VecZnxNormalizeTmpBytes
103 + VecZnxFillUniform
104 + VecZnxSubInplace
105 + VecZnxAddInplace
106 + VecZnxNormalizeInplace<B>
107 + VecZnxAddNormal
108 + VecZnxNormalize<B>
109 + VecZnxSub,
110 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
111 {
112 #[cfg(debug_assertions)]
113 {
114 assert_eq!(self.rank(), sk.rank());
115 assert_eq!(sk.n(), self.n());
116 assert!(
117 scratch.available() >= GLWECiphertext::encrypt_sk_scratch_space(module, self),
118 "scratch.available(): {} < GLWECiphertext::encrypt_sk_scratch_space: {}",
119 scratch.available(),
120 GLWECiphertext::encrypt_sk_scratch_space(module, self)
121 )
122 }
123 self.encrypt_sk_internal(
124 module,
125 None::<(&GLWEPlaintext<Vec<u8>>, usize)>,
126 sk,
127 source_xa,
128 source_xe,
129 scratch,
130 );
131 }
132
133 #[allow(clippy::too_many_arguments)]
134 pub(crate) fn encrypt_sk_internal<DataPt: DataRef, DataSk: DataRef, B: Backend>(
135 &mut self,
136 module: &Module<B>,
137 pt: Option<(&GLWEPlaintext<DataPt>, usize)>,
138 sk: &GLWESecretPrepared<DataSk, B>,
139 source_xa: &mut Source,
140 source_xe: &mut Source,
141 scratch: &mut Scratch<B>,
142 ) where
143 Module<B>: VecZnxDftAllocBytes
144 + VecZnxBigNormalize<B>
145 + VecZnxDftApply<B>
146 + SvpApplyDftToDftInplace<B>
147 + VecZnxIdftApplyConsume<B>
148 + VecZnxNormalizeTmpBytes
149 + VecZnxFillUniform
150 + VecZnxSubInplace
151 + VecZnxAddInplace
152 + VecZnxNormalizeInplace<B>
153 + VecZnxAddNormal
154 + VecZnxNormalize<B>
155 + VecZnxSub,
156 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
157 {
158 let cols: usize = (self.rank() + 1).into();
159 glwe_encrypt_sk_internal(
160 module,
161 self.base2k().into(),
162 self.k().into(),
163 &mut self.data,
164 cols,
165 false,
166 pt,
167 sk,
168 source_xa,
169 source_xe,
170 SIGMA,
171 scratch,
172 );
173 }
174
175 #[allow(clippy::too_many_arguments)]
176 pub fn encrypt_pk<DataPt: DataRef, DataPk: DataRef, B: Backend>(
177 &mut self,
178 module: &Module<B>,
179 pt: &GLWEPlaintext<DataPt>,
180 pk: &GLWEPublicKeyPrepared<DataPk, B>,
181 source_xu: &mut Source,
182 source_xe: &mut Source,
183 scratch: &mut Scratch<B>,
184 ) where
185 Module<B>: SvpPrepare<B>
186 + SvpApplyDftToDft<B>
187 + VecZnxIdftApplyConsume<B>
188 + VecZnxBigAddNormal<B>
189 + VecZnxBigAddSmallInplace<B>
190 + VecZnxBigNormalize<B>,
191 Scratch<B>: TakeSvpPPol<B> + TakeScalarZnx + TakeVecZnxDft<B>,
192 {
193 self.encrypt_pk_internal::<DataPt, DataPk, B>(module, Some((pt, 0)), pk, source_xu, source_xe, scratch);
194 }
195
196 pub fn encrypt_zero_pk<DataPk: DataRef, B: Backend>(
197 &mut self,
198 module: &Module<B>,
199 pk: &GLWEPublicKeyPrepared<DataPk, B>,
200 source_xu: &mut Source,
201 source_xe: &mut Source,
202 scratch: &mut Scratch<B>,
203 ) where
204 Module<B>: SvpPrepare<B>
205 + SvpApplyDftToDft<B>
206 + VecZnxIdftApplyConsume<B>
207 + VecZnxBigAddNormal<B>
208 + VecZnxBigAddSmallInplace<B>
209 + VecZnxBigNormalize<B>,
210 Scratch<B>: TakeSvpPPol<B> + TakeScalarZnx + TakeVecZnxDft<B>,
211 {
212 self.encrypt_pk_internal::<Vec<u8>, DataPk, B>(
213 module,
214 None::<(&GLWEPlaintext<Vec<u8>>, usize)>,
215 pk,
216 source_xu,
217 source_xe,
218 scratch,
219 );
220 }
221
222 #[allow(clippy::too_many_arguments)]
223 pub(crate) fn encrypt_pk_internal<DataPt: DataRef, DataPk: DataRef, B: Backend>(
224 &mut self,
225 module: &Module<B>,
226 pt: Option<(&GLWEPlaintext<DataPt>, usize)>,
227 pk: &GLWEPublicKeyPrepared<DataPk, B>,
228 source_xu: &mut Source,
229 source_xe: &mut Source,
230 scratch: &mut Scratch<B>,
231 ) where
232 Module<B>: SvpPrepare<B>
233 + SvpApplyDftToDft<B>
234 + VecZnxIdftApplyConsume<B>
235 + VecZnxBigAddNormal<B>
236 + VecZnxBigAddSmallInplace<B>
237 + VecZnxBigNormalize<B>,
238 Scratch<B>: TakeSvpPPol<B> + TakeScalarZnx + TakeVecZnxDft<B>,
239 {
240 #[cfg(debug_assertions)]
241 {
242 assert_eq!(self.base2k(), pk.base2k());
243 assert_eq!(self.n(), pk.n());
244 assert_eq!(self.rank(), pk.rank());
245 if let Some((pt, _)) = pt {
246 assert_eq!(pt.base2k(), pk.base2k());
247 assert_eq!(pt.n(), pk.n());
248 }
249 }
250
251 let base2k: usize = pk.base2k().into();
252 let size_pk: usize = pk.size();
253 let cols: usize = (self.rank() + 1).into();
254
255 let (mut u_dft, scratch_1) = scratch.take_svp_ppol(self.n().into(), 1);
257
258 {
259 let (mut u, _) = scratch_1.take_scalar_znx(self.n().into(), 1);
260 match pk.dist {
261 Distribution::NONE => panic!(
262 "invalid public key: SecretDistribution::NONE, ensure it has been correctly intialized through \
263 Self::generate"
264 ),
265 Distribution::TernaryFixed(hw) => u.fill_ternary_hw(0, hw, source_xu),
266 Distribution::TernaryProb(prob) => u.fill_ternary_prob(0, prob, source_xu),
267 Distribution::BinaryFixed(hw) => u.fill_binary_hw(0, hw, source_xu),
268 Distribution::BinaryProb(prob) => u.fill_binary_prob(0, prob, source_xu),
269 Distribution::BinaryBlock(block_size) => u.fill_binary_block(0, block_size, source_xu),
270 Distribution::ZERO => {}
271 }
272
273 module.svp_prepare(&mut u_dft, 0, &u, 0);
274 }
275
276 (0..cols).for_each(|i| {
278 let (mut ci_dft, scratch_2) = scratch_1.take_vec_znx_dft(self.n().into(), 1, size_pk);
279 module.svp_apply_dft_to_dft(&mut ci_dft, 0, &u_dft, 0, &pk.data, i);
281
282 let mut ci_big = module.vec_znx_idft_apply_consume(ci_dft);
284
285 module.vec_znx_big_add_normal(
287 base2k,
288 &mut ci_big,
289 0,
290 pk.k().into(),
291 source_xe,
292 SIGMA,
293 SIGMA_BOUND,
294 );
295
296 if let Some((pt, col)) = pt
298 && col == i
299 {
300 module.vec_znx_big_add_small_inplace(&mut ci_big, 0, &pt.data, 0);
301 }
302
303 module.vec_znx_big_normalize(base2k, &mut self.data, i, base2k, &ci_big, 0, scratch_2);
305 });
306 }
307}
308
309#[allow(clippy::too_many_arguments)]
310pub(crate) fn glwe_encrypt_sk_internal<DataCt: DataMut, DataPt: DataRef, DataSk: DataRef, B: Backend>(
311 module: &Module<B>,
312 base2k: usize,
313 k: usize,
314 ct: &mut VecZnx<DataCt>,
315 cols: usize,
316 compressed: bool,
317 pt: Option<(&GLWEPlaintext<DataPt>, usize)>,
318 sk: &GLWESecretPrepared<DataSk, B>,
319 source_xa: &mut Source,
320 source_xe: &mut Source,
321 sigma: f64,
322 scratch: &mut Scratch<B>,
323) where
324 Module<B>: VecZnxDftAllocBytes
325 + VecZnxBigNormalize<B>
326 + VecZnxDftApply<B>
327 + SvpApplyDftToDftInplace<B>
328 + VecZnxIdftApplyConsume<B>
329 + VecZnxNormalizeTmpBytes
330 + VecZnxFillUniform
331 + VecZnxSubInplace
332 + VecZnxAddInplace
333 + VecZnxNormalizeInplace<B>
334 + VecZnxAddNormal
335 + VecZnxNormalize<B>
336 + VecZnxSub,
337 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
338{
339 #[cfg(debug_assertions)]
340 {
341 if compressed {
342 assert_eq!(
343 ct.cols(),
344 1,
345 "invalid ciphertext: compressed tag=true but #cols={} != 1",
346 ct.cols()
347 )
348 }
349 }
350
351 let size: usize = ct.size();
352
353 let (mut c0, scratch_1) = scratch.take_vec_znx(ct.n(), 1, size);
354 c0.zero();
355
356 {
357 let (mut ci, scratch_2) = scratch_1.take_vec_znx(ct.n(), 1, size);
358
359 (1..cols).for_each(|i| {
362 let col_ct: usize = if compressed { 0 } else { i };
363
364 module.vec_znx_fill_uniform(base2k, ct, col_ct, source_xa);
366
367 let (mut ci_dft, scratch_3) = scratch_2.take_vec_znx_dft(ct.n(), 1, size);
368
369 if let Some((pt, col)) = pt {
373 if i == col {
374 module.vec_znx_sub(&mut ci, 0, ct, col_ct, &pt.data, 0);
375 module.vec_znx_normalize_inplace(base2k, &mut ci, 0, scratch_3);
376 module.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, &ci, 0);
377 } else {
378 module.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, ct, col_ct);
379 }
380 } else {
381 module.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, ct, col_ct);
382 }
383
384 module.svp_apply_dft_to_dft_inplace(&mut ci_dft, 0, &sk.data, i - 1);
385 let ci_big: VecZnxBig<&mut [u8], B> = module.vec_znx_idft_apply_consume(ci_dft);
386
387 module.vec_znx_big_normalize(base2k, &mut ci, 0, base2k, &ci_big, 0, scratch_3);
389
390 module.vec_znx_sub_inplace(&mut c0, 0, &ci, 0);
392 });
393 }
394
395 module.vec_znx_add_normal(base2k, &mut c0, 0, k, source_xe, sigma, SIGMA_BOUND);
397
398 if let Some((pt, col)) = pt
400 && col == 0
401 {
402 module.vec_znx_add_inplace(&mut c0, 0, &pt.data, 0);
403 }
404
405 module.vec_znx_normalize(base2k, ct, 0, base2k, &c0, 0, scratch_1);
407}