1use poulpy_hal::{
2 api::{
3 ScratchAvailable, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeSvpPPol,
4 TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxBigAddNormal, VecZnxBigAddSmallInplace,
5 VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume,
6 VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace,
7 },
8 layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, VecZnx, VecZnxBig, ZnxInfos, ZnxZero},
9 source::Source,
10};
11
12use crate::{
13 dist::Distribution,
14 encryption::{SIGMA, SIGMA_BOUND},
15 layouts::{
16 GLWECiphertext, GLWEPlaintext, Infos,
17 prepared::{GLWEPublicKeyPrepared, GLWESecretPrepared},
18 },
19};
20
21impl GLWECiphertext<Vec<u8>> {
22 pub fn encrypt_sk_scratch_space<B: Backend>(module: &Module<B>, basek: usize, k: usize) -> usize
23 where
24 Module<B>: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes,
25 {
26 let size: usize = k.div_ceil(basek);
27 module.vec_znx_normalize_tmp_bytes()
28 + 2 * VecZnx::alloc_bytes(module.n(), 1, size)
29 + module.vec_znx_dft_alloc_bytes(1, size)
30 }
31 pub fn encrypt_pk_scratch_space<B: Backend>(module: &Module<B>, basek: usize, k: usize) -> usize
32 where
33 Module<B>: VecZnxDftAllocBytes + SvpPPolAllocBytes + VecZnxBigAllocBytes + VecZnxNormalizeTmpBytes,
34 {
35 let size: usize = k.div_ceil(basek);
36 ((module.vec_znx_dft_alloc_bytes(1, size) + module.vec_znx_big_alloc_bytes(1, size))
37 | ScalarZnx::alloc_bytes(module.n(), 1))
38 + module.svp_ppol_alloc_bytes(1)
39 + module.vec_znx_normalize_tmp_bytes()
40 }
41}
42
impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
    /// Encrypts `pt` into `self` under the prepared secret key `sk`.
    ///
    /// The plaintext is attached to column 0 of the ciphertext (see
    /// [`Self::encrypt_sk_internal`]). `source_xa` drives the uniform mask
    /// sampling and `source_xe` the Gaussian error sampling; the two streams
    /// are consumed independently.
    ///
    /// # Panics
    /// In debug builds, if the ranks or ring degrees of `self`, `pt` and `sk`
    /// disagree, or if `scratch` is smaller than
    /// [`GLWECiphertext::encrypt_sk_scratch_space`].
    #[allow(clippy::too_many_arguments)]
    pub fn encrypt_sk<DataPt: DataRef, DataSk: DataRef, B: Backend>(
        &mut self,
        module: &Module<B>,
        pt: &GLWEPlaintext<DataPt>,
        sk: &GLWESecretPrepared<DataSk, B>,
        source_xa: &mut Source,
        source_xe: &mut Source,
        scratch: &mut Scratch<B>,
    ) where
        Module<B>: VecZnxDftAllocBytes
            + VecZnxBigNormalize<B>
            + VecZnxDftApply<B>
            + SvpApplyDftToDftInplace<B>
            + VecZnxIdftApplyConsume<B>
            + VecZnxNormalizeTmpBytes
            + VecZnxFillUniform
            + VecZnxSubABInplace
            + VecZnxAddInplace
            + VecZnxNormalizeInplace<B>
            + VecZnxAddNormal
            + VecZnxNormalize<B>
            + VecZnxSub,
        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
    {
        #[cfg(debug_assertions)]
        {
            assert_eq!(self.rank(), sk.rank());
            assert_eq!(sk.n(), self.n());
            assert_eq!(pt.n(), self.n());
            assert!(
                scratch.available() >= GLWECiphertext::encrypt_sk_scratch_space(module, self.basek(), self.k()),
                "scratch.available(): {} < GLWECiphertext::encrypt_sk_scratch_space: {}",
                scratch.available(),
                GLWECiphertext::encrypt_sk_scratch_space(module, self.basek(), self.k())
            )
        }

        // Plaintext goes into column 0 (the body column).
        self.encrypt_sk_internal(module, Some((pt, 0)), sk, source_xa, source_xe, scratch);
    }

    /// Encrypts zero into `self` under the prepared secret key `sk`.
    ///
    /// Identical to [`Self::encrypt_sk`] with no plaintext added; uses the
    /// same scratch-space bound.
    pub fn encrypt_zero_sk<DataSk: DataRef, B: Backend>(
        &mut self,
        module: &Module<B>,
        sk: &GLWESecretPrepared<DataSk, B>,
        source_xa: &mut Source,
        source_xe: &mut Source,
        scratch: &mut Scratch<B>,
    ) where
        Module<B>: VecZnxDftAllocBytes
            + VecZnxBigNormalize<B>
            + VecZnxDftApply<B>
            + SvpApplyDftToDftInplace<B>
            + VecZnxIdftApplyConsume<B>
            + VecZnxNormalizeTmpBytes
            + VecZnxFillUniform
            + VecZnxSubABInplace
            + VecZnxAddInplace
            + VecZnxNormalizeInplace<B>
            + VecZnxAddNormal
            + VecZnxNormalize<B>
            + VecZnxSub,
        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
    {
        #[cfg(debug_assertions)]
        {
            assert_eq!(self.rank(), sk.rank());
            assert_eq!(sk.n(), self.n());
            assert!(
                scratch.available() >= GLWECiphertext::encrypt_sk_scratch_space(module, self.basek(), self.k()),
                "scratch.available(): {} < GLWECiphertext::encrypt_sk_scratch_space: {}",
                scratch.available(),
                GLWECiphertext::encrypt_sk_scratch_space(module, self.basek(), self.k())
            )
        }
        // `None` plaintext; the turbofish-ed dummy type only satisfies the
        // generic parameter.
        self.encrypt_sk_internal(
            module,
            None::<(&GLWEPlaintext<Vec<u8>>, usize)>,
            sk,
            source_xa,
            source_xe,
            scratch,
        );
    }

    /// Shared implementation behind [`Self::encrypt_sk`] and
    /// [`Self::encrypt_zero_sk`].
    ///
    /// `pt` is an optional `(plaintext, column)` pair; when present the
    /// plaintext is folded into the given ciphertext column by
    /// [`glwe_encrypt_sk_internal`]. Always encrypts uncompressed
    /// (`compressed = false`) with the default noise parameter `SIGMA`.
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn encrypt_sk_internal<DataPt: DataRef, DataSk: DataRef, B: Backend>(
        &mut self,
        module: &Module<B>,
        pt: Option<(&GLWEPlaintext<DataPt>, usize)>,
        sk: &GLWESecretPrepared<DataSk, B>,
        source_xa: &mut Source,
        source_xe: &mut Source,
        scratch: &mut Scratch<B>,
    ) where
        Module<B>: VecZnxDftAllocBytes
            + VecZnxBigNormalize<B>
            + VecZnxDftApply<B>
            + SvpApplyDftToDftInplace<B>
            + VecZnxIdftApplyConsume<B>
            + VecZnxNormalizeTmpBytes
            + VecZnxFillUniform
            + VecZnxSubABInplace
            + VecZnxAddInplace
            + VecZnxNormalizeInplace<B>
            + VecZnxAddNormal
            + VecZnxNormalize<B>
            + VecZnxSub,
        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
    {
        // One body column plus `rank` mask columns.
        let cols: usize = self.rank() + 1;
        glwe_encrypt_sk_internal(
            module,
            self.basek(),
            self.k(),
            &mut self.data,
            cols,
            false,
            pt,
            sk,
            source_xa,
            source_xe,
            SIGMA,
            scratch,
        );
    }

    /// Encrypts `pt` into `self` under the prepared public key `pk`.
    ///
    /// The plaintext is attached to column 0. `source_xu` drives the sampling
    /// of the ephemeral secret `u` and `source_xe` the Gaussian error.
    #[allow(clippy::too_many_arguments)]
    pub fn encrypt_pk<DataPt: DataRef, DataPk: DataRef, B: Backend>(
        &mut self,
        module: &Module<B>,
        pt: &GLWEPlaintext<DataPt>,
        pk: &GLWEPublicKeyPrepared<DataPk, B>,
        source_xu: &mut Source,
        source_xe: &mut Source,
        scratch: &mut Scratch<B>,
    ) where
        Module<B>: SvpPrepare<B>
            + SvpApplyDftToDft<B>
            + VecZnxIdftApplyConsume<B>
            + VecZnxBigAddNormal<B>
            + VecZnxBigAddSmallInplace<B>
            + VecZnxBigNormalize<B>,
        Scratch<B>: TakeSvpPPol<B> + TakeScalarZnx + TakeVecZnxDft<B>,
    {
        self.encrypt_pk_internal::<DataPt, DataPk, B>(module, Some((pt, 0)), pk, source_xu, source_xe, scratch);
    }

    /// Encrypts zero into `self` under the prepared public key `pk`.
    ///
    /// Identical to [`Self::encrypt_pk`] with no plaintext added.
    pub fn encrypt_zero_pk<DataPk: DataRef, B: Backend>(
        &mut self,
        module: &Module<B>,
        pk: &GLWEPublicKeyPrepared<DataPk, B>,
        source_xu: &mut Source,
        source_xe: &mut Source,
        scratch: &mut Scratch<B>,
    ) where
        Module<B>: SvpPrepare<B>
            + SvpApplyDftToDft<B>
            + VecZnxIdftApplyConsume<B>
            + VecZnxBigAddNormal<B>
            + VecZnxBigAddSmallInplace<B>
            + VecZnxBigNormalize<B>,
        Scratch<B>: TakeSvpPPol<B> + TakeScalarZnx + TakeVecZnxDft<B>,
    {
        self.encrypt_pk_internal::<Vec<u8>, DataPk, B>(
            module,
            None::<(&GLWEPlaintext<Vec<u8>>, usize)>,
            pk,
            source_xu,
            source_xe,
            scratch,
        );
    }

    /// Shared implementation behind [`Self::encrypt_pk`] and
    /// [`Self::encrypt_zero_pk`].
    ///
    /// Samples an ephemeral secret `u` from the distribution declared by the
    /// public key, then for each column `i` computes
    /// `c_i = Normalize(u * pk_i + e_i (+ pt when i == col))`.
    ///
    /// # Panics
    /// In debug builds, if `self`, `pk` and (when present) `pt` disagree on
    /// `basek`, ring degree or rank. Always panics if `pk.dist` is
    /// `Distribution::NONE` (uninitialized key).
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn encrypt_pk_internal<DataPt: DataRef, DataPk: DataRef, B: Backend>(
        &mut self,
        module: &Module<B>,
        pt: Option<(&GLWEPlaintext<DataPt>, usize)>,
        pk: &GLWEPublicKeyPrepared<DataPk, B>,
        source_xu: &mut Source,
        source_xe: &mut Source,
        scratch: &mut Scratch<B>,
    ) where
        Module<B>: SvpPrepare<B>
            + SvpApplyDftToDft<B>
            + VecZnxIdftApplyConsume<B>
            + VecZnxBigAddNormal<B>
            + VecZnxBigAddSmallInplace<B>
            + VecZnxBigNormalize<B>,
        Scratch<B>: TakeSvpPPol<B> + TakeScalarZnx + TakeVecZnxDft<B>,
    {
        #[cfg(debug_assertions)]
        {
            assert_eq!(self.basek(), pk.basek());
            assert_eq!(self.n(), pk.n());
            assert_eq!(self.rank(), pk.rank());
            if let Some((pt, _)) = pt {
                assert_eq!(pt.basek(), pk.basek());
                assert_eq!(pt.n(), pk.n());
            }
        }

        let basek: usize = pk.basek();
        let size_pk: usize = pk.size();
        let cols: usize = self.rank() + 1;

        // Ephemeral secret u in the SVP (prepared scalar) domain; the raw
        // coefficient form only lives in the inner scope below.
        let (mut u_dft, scratch_1) = scratch.take_svp_ppol(self.n(), 1);

        {
            let (mut u, _) = scratch_1.take_scalar_znx(self.n(), 1);
            // Sample u from the distribution the public key was generated for.
            match pk.dist {
                Distribution::NONE => panic!(
                    "invalid public key: SecretDistribution::NONE, ensure it has been correctly intialized through \
                     Self::generate"
                ),
                Distribution::TernaryFixed(hw) => u.fill_ternary_hw(0, hw, source_xu),
                Distribution::TernaryProb(prob) => u.fill_ternary_prob(0, prob, source_xu),
                Distribution::BinaryFixed(hw) => u.fill_binary_hw(0, hw, source_xu),
                Distribution::BinaryProb(prob) => u.fill_binary_prob(0, prob, source_xu),
                Distribution::BinaryBlock(block_size) => u.fill_binary_block(0, block_size, source_xu),
                Distribution::ZERO => {}
            }

            module.svp_prepare(&mut u_dft, 0, &u, 0);
        }

        (0..cols).for_each(|i| {
            // ci = u * pk_i, computed in the DFT domain.
            let (mut ci_dft, scratch_2) = scratch_1.take_vec_znx_dft(self.n(), 1, size_pk);
            module.svp_apply_dft_to_dft(&mut ci_dft, 0, &u_dft, 0, &pk.data, i);

            let mut ci_big = module.vec_znx_idft_apply_consume(ci_dft);

            // Fresh Gaussian noise per column.
            module.vec_znx_big_add_normal(basek, &mut ci_big, 0, pk.k(), source_xe, SIGMA, SIGMA_BOUND);

            // Fold the plaintext into its designated column, before the
            // final normalization.
            if let Some((pt, col)) = pt
                && col == i
            {
                module.vec_znx_big_add_small_inplace(&mut ci_big, 0, &pt.data, 0);
            }

            module.vec_znx_big_normalize(basek, &mut self.data, i, &ci_big, 0, scratch_2);
        });
    }
}
296
/// Core GLWE secret-key encryption routine.
///
/// Fills the mask columns of `ct` with uniform limbs from `source_xa`, then
/// builds the body column as
/// `c_0 = -sum_i Normalize(a_i * s_{i-1}) + e (+ pt when its column is 0)`,
/// normalized into column 0 of `ct`.
///
/// * `cols` — total column count (body + masks); the loop runs over mask
///   indices `1..cols`.
/// * `compressed` — when true, `ct` must have exactly one column and every
///   mask `a_i` is drawn into column 0 (each draw overwrites the previous
///   one; presumably the masks are re-derivable from the seed of `source_xa`
///   so only the body needs to be stored — TODO confirm against the
///   decompression path).
/// * `pt` — optional `(plaintext, column)`; column 0 attaches it to the body,
///   a column `i >= 1` folds it into the `a_i * s_{i-1}` product instead.
/// * `sigma` — standard deviation of the Gaussian error (bounded by
///   `SIGMA_BOUND`).
#[allow(clippy::too_many_arguments)]
pub(crate) fn glwe_encrypt_sk_internal<DataCt: DataMut, DataPt: DataRef, DataSk: DataRef, B: Backend>(
    module: &Module<B>,
    basek: usize,
    k: usize,
    ct: &mut VecZnx<DataCt>,
    cols: usize,
    compressed: bool,
    pt: Option<(&GLWEPlaintext<DataPt>, usize)>,
    sk: &GLWESecretPrepared<DataSk, B>,
    source_xa: &mut Source,
    source_xe: &mut Source,
    sigma: f64,
    scratch: &mut Scratch<B>,
) where
    Module<B>: VecZnxDftAllocBytes
        + VecZnxBigNormalize<B>
        + VecZnxDftApply<B>
        + SvpApplyDftToDftInplace<B>
        + VecZnxIdftApplyConsume<B>
        + VecZnxNormalizeTmpBytes
        + VecZnxFillUniform
        + VecZnxSubABInplace
        + VecZnxAddInplace
        + VecZnxNormalizeInplace<B>
        + VecZnxAddNormal
        + VecZnxNormalize<B>
        + VecZnxSub,
    Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
{
    #[cfg(debug_assertions)]
    {
        if compressed {
            assert_eq!(
                ct.cols(),
                1,
                "invalid ciphertext: compressed tag=true but #cols={} != 1",
                ct.cols()
            )
        }
    }

    let size: usize = ct.size();

    // c0 accumulates the (negated) mask/secret products for the body column.
    let (mut c0, scratch_1) = scratch.take_vec_znx(ct.n(), 1, size);
    c0.zero();

    {
        // Per-column small-polynomial temporary, reused across iterations.
        let (mut ci, scratch_2) = scratch_1.take_vec_znx(ct.n(), 1, size);

        (1..cols).for_each(|i| {
            // Compressed ciphertexts keep a single column: every mask is
            // drawn into column 0.
            let col_ct: usize = if compressed { 0 } else { i };

            // a_i <- uniform, written directly into the ciphertext.
            module.vec_znx_fill_uniform(basek, ct, col_ct, source_xa);

            let (mut ci_dft, scratch_3) = scratch_2.take_vec_znx_dft(ct.n(), 1, size);

            if let Some((pt, col)) = pt {
                if i == col {
                    // Plaintext attached to mask column i: the secret product
                    // is taken over (a_i - pt) instead of a_i, while ct keeps
                    // the raw a_i (gadget-row style encryption).
                    module.vec_znx_sub(&mut ci, 0, ct, col_ct, &pt.data, 0);
                    module.vec_znx_normalize_inplace(basek, &mut ci, 0, scratch_3);
                    module.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, &ci, 0);
                } else {
                    module.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, ct, col_ct);
                }
            } else {
                module.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, ct, col_ct);
            }

            // ci_dft *= s_{i-1} (secret is indexed from 0, columns from 1),
            // then back out of the DFT domain.
            module.svp_apply_dft_to_dft_inplace(&mut ci_dft, 0, &sk.data, i - 1);
            let ci_big: VecZnxBig<&mut [u8], B> = module.vec_znx_idft_apply_consume(ci_dft);

            module.vec_znx_big_normalize(basek, &mut ci, 0, &ci_big, 0, scratch_3);

            // c0 -= ci (per the SubAB a-minus-b naming).
            module.vec_znx_sub_ab_inplace(&mut c0, 0, &ci, 0);
        });
    }

    // Add the Gaussian error to the body.
    module.vec_znx_add_normal(basek, &mut c0, 0, k, source_xe, sigma, SIGMA_BOUND);

    // Plaintext attached to the body column.
    if let Some((pt, col)) = pt
        && col == 0
    {
        module.vec_znx_add_inplace(&mut c0, 0, &pt.data, 0);
    }

    // Final normalized body written into column 0 of the ciphertext.
    module.vec_znx_normalize(basek, ct, 0, &c0, 0, scratch_1);
}