1use std::collections::HashMap;
2
3use poulpy_hal::{
4 api::{
5 ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace,
6 VecZnxBigAutomorphismInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallBInplace, VecZnxCopy,
7 VecZnxDftAllocBytes, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigConsume, VecZnxNegateInplace, VecZnxNormalizeInplace,
8 VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, VecZnxSubABInplace, VmpApply, VmpApplyAdd,
9 VmpApplyTmpBytes,
10 },
11 layouts::{Backend, DataMut, DataRef, Module, Scratch},
12};
13
14use crate::{
15 GLWEOperations, TakeGLWECt,
16 layouts::{GLWECiphertext, Infos, prepared::GGLWEAutomorphismKeyPrepared},
17};
18
19pub struct GLWEPacker {
24 accumulators: Vec<Accumulator>,
25 log_batch: usize,
26 counter: usize,
27}
28
29struct Accumulator {
32 data: GLWECiphertext<Vec<u8>>,
33 value: bool, control: bool, }
36
37impl Accumulator {
38 pub fn alloc(n: usize, basek: usize, k: usize, rank: usize) -> Self {
47 Self {
48 data: GLWECiphertext::alloc(n, basek, k, rank),
49 value: false,
50 control: false,
51 }
52 }
53}
54
55impl GLWEPacker {
56 pub fn new(n: usize, log_batch: usize, basek: usize, k: usize, rank: usize) -> Self {
70 let mut accumulators: Vec<Accumulator> = Vec::<Accumulator>::new();
71 let log_n: usize = (usize::BITS - (n - 1).leading_zeros()) as _;
72 (0..log_n - log_batch).for_each(|_| accumulators.push(Accumulator::alloc(n, basek, k, rank)));
73 Self {
74 accumulators,
75 log_batch,
76 counter: 0,
77 }
78 }
79
80 fn reset(&mut self) {
82 for i in 0..self.accumulators.len() {
83 self.accumulators[i].value = false;
84 self.accumulators[i].control = false;
85 }
86 self.counter = 0;
87 }
88
89 pub fn scratch_space<B: Backend>(
91 module: &Module<B>,
92 n: usize,
93 basek: usize,
94 ct_k: usize,
95 k_ksk: usize,
96 digits: usize,
97 rank: usize,
98 ) -> usize
99 where
100 Module<B>: VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxBigNormalizeTmpBytes,
101 {
102 pack_core_scratch_space(module, n, basek, ct_k, k_ksk, digits, rank)
103 }
104
105 pub fn galois_elements<B: Backend>(module: &Module<B>) -> Vec<i64> {
106 GLWECiphertext::trace_galois_elements(module)
107 }
108
109 pub fn add<DataA: DataRef, DataAK: DataRef, B: Backend>(
119 &mut self,
120 module: &Module<B>,
121 a: Option<&GLWECiphertext<DataA>>,
122 auto_keys: &HashMap<i64, GGLWEAutomorphismKeyPrepared<DataAK, B>>,
123 scratch: &mut Scratch<B>,
124 ) where
125 Module<B>: VecZnxDftAllocBytes
126 + VmpApplyTmpBytes
127 + VecZnxBigNormalizeTmpBytes
128 + VmpApply<B>
129 + VmpApplyAdd<B>
130 + VecZnxDftFromVecZnx<B>
131 + VecZnxDftToVecZnxBigConsume<B>
132 + VecZnxBigAddSmallInplace<B>
133 + VecZnxBigNormalize<B>
134 + VecZnxCopy
135 + VecZnxRotateInplace
136 + VecZnxSub
137 + VecZnxNegateInplace
138 + VecZnxRshInplace
139 + VecZnxAddInplace
140 + VecZnxNormalizeInplace<B>
141 + VecZnxSubABInplace
142 + VecZnxRotate
143 + VecZnxAutomorphismInplace
144 + VecZnxBigSubSmallBInplace<B>
145 + VecZnxBigAutomorphismInplace<B>,
146 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
147 {
148 assert!(
149 self.counter < self.accumulators[0].data.n(),
150 "Packing limit of {} reached",
151 self.accumulators[0].data.n() >> self.log_batch
152 );
153
154 pack_core(
155 module,
156 a,
157 &mut self.accumulators,
158 self.log_batch,
159 auto_keys,
160 scratch,
161 );
162 self.counter += 1 << self.log_batch;
163 }
164
165 pub fn flush<Data: DataMut, B: Backend>(&mut self, module: &Module<B>, res: &mut GLWECiphertext<Data>)
167 where
168 Module<B>: VecZnxCopy,
169 {
170 assert!(self.counter == self.accumulators[0].data.n());
171 res.copy(
173 module,
174 &self.accumulators[module.log_n() - self.log_batch - 1].data,
175 );
176
177 self.reset();
178 }
179}
180
181fn pack_core_scratch_space<B: Backend>(
182 module: &Module<B>,
183 n: usize,
184 basek: usize,
185 ct_k: usize,
186 k_ksk: usize,
187 digits: usize,
188 rank: usize,
189) -> usize
190where
191 Module<B>: VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxBigNormalizeTmpBytes,
192{
193 combine_scratch_space(module, n, basek, ct_k, k_ksk, digits, rank)
194}
195
196fn pack_core<D: DataRef, DataAK: DataRef, B: Backend>(
197 module: &Module<B>,
198 a: Option<&GLWECiphertext<D>>,
199 accumulators: &mut [Accumulator],
200 i: usize,
201 auto_keys: &HashMap<i64, GGLWEAutomorphismKeyPrepared<DataAK, B>>,
202 scratch: &mut Scratch<B>,
203) where
204 Module<B>: VecZnxDftAllocBytes
205 + VmpApplyTmpBytes
206 + VecZnxBigNormalizeTmpBytes
207 + VmpApply<B>
208 + VmpApplyAdd<B>
209 + VecZnxDftFromVecZnx<B>
210 + VecZnxDftToVecZnxBigConsume<B>
211 + VecZnxBigAddSmallInplace<B>
212 + VecZnxBigNormalize<B>
213 + VecZnxCopy
214 + VecZnxRotateInplace
215 + VecZnxSub
216 + VecZnxNegateInplace
217 + VecZnxRshInplace
218 + VecZnxAddInplace
219 + VecZnxNormalizeInplace<B>
220 + VecZnxSubABInplace
221 + VecZnxRotate
222 + VecZnxAutomorphismInplace
223 + VecZnxBigSubSmallBInplace<B>
224 + VecZnxBigAutomorphismInplace<B>,
225 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
226{
227 let log_n: usize = module.log_n();
228
229 if i == log_n {
230 return;
231 }
232
233 let (acc_prev, acc_next) = accumulators.split_at_mut(1);
235
236 if !acc_prev[0].control {
238 let acc_mut_ref: &mut Accumulator = &mut acc_prev[0]; if let Some(a_ref) = a {
242 acc_mut_ref.data.copy(module, a_ref);
243 acc_mut_ref.value = true
244 } else {
245 acc_mut_ref.value = false
246 }
247 acc_mut_ref.control = true; } else {
249 combine(module, &mut acc_prev[0], a, i, auto_keys, scratch);
251 acc_prev[0].control = false;
252
253 if acc_prev[0].value {
255 pack_core(
256 module,
257 Some(&acc_prev[0].data),
258 acc_next,
259 i + 1,
260 auto_keys,
261 scratch,
262 );
263 } else {
264 pack_core(
265 module,
266 None::<&GLWECiphertext<Vec<u8>>>,
267 acc_next,
268 i + 1,
269 auto_keys,
270 scratch,
271 );
272 }
273 }
274}
275
276fn combine_scratch_space<B: Backend>(
277 module: &Module<B>,
278 n: usize,
279 basek: usize,
280 ct_k: usize,
281 k_ksk: usize,
282 digits: usize,
283 rank: usize,
284) -> usize
285where
286 Module<B>: VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxBigNormalizeTmpBytes,
287{
288 GLWECiphertext::bytes_of(n, basek, ct_k, rank)
289 + (GLWECiphertext::rsh_scratch_space(n)
290 | GLWECiphertext::automorphism_scratch_space(module, n, basek, ct_k, ct_k, k_ksk, digits, rank))
291}
292
293fn combine<D: DataRef, DataAK: DataRef, B: Backend>(
295 module: &Module<B>,
296 acc: &mut Accumulator,
297 b: Option<&GLWECiphertext<D>>,
298 i: usize,
299 auto_keys: &HashMap<i64, GGLWEAutomorphismKeyPrepared<DataAK, B>>,
300 scratch: &mut Scratch<B>,
301) where
302 Module<B>: VecZnxDftAllocBytes
303 + VmpApplyTmpBytes
304 + VecZnxBigNormalizeTmpBytes
305 + VmpApply<B>
306 + VmpApplyAdd<B>
307 + VecZnxDftFromVecZnx<B>
308 + VecZnxDftToVecZnxBigConsume<B>
309 + VecZnxBigAddSmallInplace<B>
310 + VecZnxBigNormalize<B>
311 + VecZnxCopy
312 + VecZnxRotateInplace
313 + VecZnxSub
314 + VecZnxNegateInplace
315 + VecZnxRshInplace
316 + VecZnxAddInplace
317 + VecZnxNormalizeInplace<B>
318 + VecZnxSubABInplace
319 + VecZnxRotate
320 + VecZnxAutomorphismInplace
321 + VecZnxBigSubSmallBInplace<B>
322 + VecZnxBigAutomorphismInplace<B>,
323 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
324{
325 let n: usize = acc.data.n();
326 let log_n: usize = (u64::BITS - (n - 1).leading_zeros()) as _;
327 let a: &mut GLWECiphertext<Vec<u8>> = &mut acc.data;
328 let basek: usize = a.basek();
329 let k: usize = a.k();
330 let rank: usize = a.rank();
331
332 let gal_el: i64 = if i == 0 {
333 -1
334 } else {
335 module.galois_element(1 << (i - 1))
336 };
337
338 let t: i64 = 1 << (log_n - i - 1);
339
340 if acc.value {
351 if let Some(b) = b {
352 let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(n, basek, k, rank);
353
354 a.rotate_inplace(module, -t);
356
357 tmp_b.sub(module, a, b);
359 tmp_b.rsh(module, 1);
360
361 a.add_inplace(module, b);
363 a.rsh(module, 1);
364
365 tmp_b.normalize_inplace(module, scratch_1);
366
367 if let Some(key) = auto_keys.get(&gal_el) {
369 tmp_b.automorphism_inplace(module, key, scratch_1);
370 } else {
371 panic!("auto_key[{}] not found", gal_el);
372 }
373
374 a.sub_inplace_ab(module, &tmp_b);
376 a.normalize_inplace(module, scratch_1);
377
378 a.rotate_inplace(module, t);
382 } else {
383 a.rsh(module, 1);
384 if let Some(key) = auto_keys.get(&gal_el) {
386 a.automorphism_add_inplace(module, key, scratch);
387 } else {
388 panic!("auto_key[{}] not found", gal_el);
389 }
390 }
391 } else if let Some(b) = b {
392 let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(n, basek, k, rank);
393 tmp_b.rotate(module, 1 << (log_n - i - 1), b);
394 tmp_b.rsh(module, 1);
395
396 if let Some(key) = auto_keys.get(&gal_el) {
398 a.automorphism_sub_ba(module, &tmp_b, key, scratch_1);
399 } else {
400 panic!("auto_key[{}] not found", gal_el);
401 }
402
403 acc.value = true;
404 }
405}