1use std::collections::HashMap;
2
3use poulpy_hal::{
4 api::{
5 ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace,
6 VecZnxBigAutomorphismInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallBInplace, VecZnxCopy,
7 VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNegateInplace, VecZnxNormalizeInplace, VecZnxRotate,
8 VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, VecZnxSubABInplace, VmpApplyDftToDft, VmpApplyDftToDftAdd,
9 VmpApplyDftToDftTmpBytes,
10 },
11 layouts::{Backend, DataMut, DataRef, Module, Scratch},
12};
13
14use crate::{
15 GLWEOperations, TakeGLWECt,
16 layouts::{GLWECiphertext, Infos, prepared::GGLWEAutomorphismKeyPrepared},
17};
18
/// Incrementally packs several GLWE ciphertexts into a single GLWE ciphertext.
///
/// Ciphertexts are fed one at a time through [`GLWEPacker::add`]; once the
/// packer is full, [`GLWEPacker::flush`] copies out the packed result.
pub struct GLWEPacker {
    /// One accumulator per packing level: index `j` serves level `log_batch + j`,
    /// up to level `log2(n) - 1` in the last slot (which ends up holding the output).
    accumulators: Vec<Accumulator>,
    /// Packing level of the first accumulator; each `add` consumes `2^log_batch` slots.
    log_batch: usize,
    /// Number of coefficient slots consumed so far; advanced by `2^log_batch` per
    /// `add`, must equal `n` before `flush`, and is reset to 0 afterwards.
    counter: usize,
}
28
/// One level of the packing chain.
struct Accumulator {
    /// Ciphertext currently held at this level (only meaningful when `value` is set).
    data: GLWECiphertext<Vec<u8>>,
    /// `true` when `data` holds a real ciphertext; `false` when the slot stands
    /// for an absent (`None`) input.
    value: bool,
    /// `true` when the slot is occupied and waiting for a partner to be combined
    /// with; `false` when the next incoming ciphertext may simply overwrite it.
    control: bool,
}
36
37impl Accumulator {
38 pub fn alloc(n: usize, basek: usize, k: usize, rank: usize) -> Self {
47 Self {
48 data: GLWECiphertext::alloc(n, basek, k, rank),
49 value: false,
50 control: false,
51 }
52 }
53}
54
55impl GLWEPacker {
56 pub fn new(n: usize, log_batch: usize, basek: usize, k: usize, rank: usize) -> Self {
70 let mut accumulators: Vec<Accumulator> = Vec::<Accumulator>::new();
71 let log_n: usize = (usize::BITS - (n - 1).leading_zeros()) as _;
72 (0..log_n - log_batch).for_each(|_| accumulators.push(Accumulator::alloc(n, basek, k, rank)));
73 Self {
74 accumulators,
75 log_batch,
76 counter: 0,
77 }
78 }
79
80 fn reset(&mut self) {
82 for i in 0..self.accumulators.len() {
83 self.accumulators[i].value = false;
84 self.accumulators[i].control = false;
85 }
86 self.counter = 0;
87 }
88
89 pub fn scratch_space<B: Backend>(
91 module: &Module<B>,
92 basek: usize,
93 ct_k: usize,
94 k_ksk: usize,
95 digits: usize,
96 rank: usize,
97 ) -> usize
98 where
99 Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
100 {
101 pack_core_scratch_space(module, basek, ct_k, k_ksk, digits, rank)
102 }
103
104 pub fn galois_elements<B: Backend>(module: &Module<B>) -> Vec<i64> {
105 GLWECiphertext::trace_galois_elements(module)
106 }
107
108 pub fn add<DataA: DataRef, DataAK: DataRef, B: Backend>(
118 &mut self,
119 module: &Module<B>,
120 a: Option<&GLWECiphertext<DataA>>,
121 auto_keys: &HashMap<i64, GGLWEAutomorphismKeyPrepared<DataAK, B>>,
122 scratch: &mut Scratch<B>,
123 ) where
124 Module<B>: VecZnxDftAllocBytes
125 + VmpApplyDftToDftTmpBytes
126 + VecZnxBigNormalizeTmpBytes
127 + VmpApplyDftToDft<B>
128 + VmpApplyDftToDftAdd<B>
129 + VecZnxDftApply<B>
130 + VecZnxIdftApplyConsume<B>
131 + VecZnxBigAddSmallInplace<B>
132 + VecZnxBigNormalize<B>
133 + VecZnxCopy
134 + VecZnxRotateInplace<B>
135 + VecZnxSub
136 + VecZnxNegateInplace
137 + VecZnxRshInplace<B>
138 + VecZnxAddInplace
139 + VecZnxNormalizeInplace<B>
140 + VecZnxSubABInplace
141 + VecZnxRotate
142 + VecZnxAutomorphismInplace<B>
143 + VecZnxBigSubSmallBInplace<B>
144 + VecZnxBigAutomorphismInplace<B>,
145 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
146 {
147 assert!(
148 self.counter < self.accumulators[0].data.n(),
149 "Packing limit of {} reached",
150 self.accumulators[0].data.n() >> self.log_batch
151 );
152
153 pack_core(
154 module,
155 a,
156 &mut self.accumulators,
157 self.log_batch,
158 auto_keys,
159 scratch,
160 );
161 self.counter += 1 << self.log_batch;
162 }
163
164 pub fn flush<Data: DataMut, B: Backend>(&mut self, module: &Module<B>, res: &mut GLWECiphertext<Data>)
166 where
167 Module<B>: VecZnxCopy,
168 {
169 assert!(self.counter == self.accumulators[0].data.n());
170 res.copy(
172 module,
173 &self.accumulators[module.log_n() - self.log_batch - 1].data,
174 );
175
176 self.reset();
177 }
178}
179
180fn pack_core_scratch_space<B: Backend>(
181 module: &Module<B>,
182 basek: usize,
183 ct_k: usize,
184 k_ksk: usize,
185 digits: usize,
186 rank: usize,
187) -> usize
188where
189 Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
190{
191 combine_scratch_space(module, basek, ct_k, k_ksk, digits, rank)
192}
193
194fn pack_core<D: DataRef, DataAK: DataRef, B: Backend>(
195 module: &Module<B>,
196 a: Option<&GLWECiphertext<D>>,
197 accumulators: &mut [Accumulator],
198 i: usize,
199 auto_keys: &HashMap<i64, GGLWEAutomorphismKeyPrepared<DataAK, B>>,
200 scratch: &mut Scratch<B>,
201) where
202 Module<B>: VecZnxDftAllocBytes
203 + VmpApplyDftToDftTmpBytes
204 + VecZnxBigNormalizeTmpBytes
205 + VmpApplyDftToDft<B>
206 + VmpApplyDftToDftAdd<B>
207 + VecZnxDftApply<B>
208 + VecZnxIdftApplyConsume<B>
209 + VecZnxBigAddSmallInplace<B>
210 + VecZnxBigNormalize<B>
211 + VecZnxCopy
212 + VecZnxRotateInplace<B>
213 + VecZnxSub
214 + VecZnxNegateInplace
215 + VecZnxRshInplace<B>
216 + VecZnxAddInplace
217 + VecZnxNormalizeInplace<B>
218 + VecZnxSubABInplace
219 + VecZnxRotate
220 + VecZnxAutomorphismInplace<B>
221 + VecZnxBigSubSmallBInplace<B>
222 + VecZnxBigAutomorphismInplace<B>,
223 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
224{
225 let log_n: usize = module.log_n();
226
227 if i == log_n {
228 return;
229 }
230
231 let (acc_prev, acc_next) = accumulators.split_at_mut(1);
233
234 if !acc_prev[0].control {
236 let acc_mut_ref: &mut Accumulator = &mut acc_prev[0]; if let Some(a_ref) = a {
240 acc_mut_ref.data.copy(module, a_ref);
241 acc_mut_ref.value = true
242 } else {
243 acc_mut_ref.value = false
244 }
245 acc_mut_ref.control = true; } else {
247 combine(module, &mut acc_prev[0], a, i, auto_keys, scratch);
249 acc_prev[0].control = false;
250
251 if acc_prev[0].value {
253 pack_core(
254 module,
255 Some(&acc_prev[0].data),
256 acc_next,
257 i + 1,
258 auto_keys,
259 scratch,
260 );
261 } else {
262 pack_core(
263 module,
264 None::<&GLWECiphertext<Vec<u8>>>,
265 acc_next,
266 i + 1,
267 auto_keys,
268 scratch,
269 );
270 }
271 }
272}
273
274fn combine_scratch_space<B: Backend>(
275 module: &Module<B>,
276 basek: usize,
277 ct_k: usize,
278 k_ksk: usize,
279 digits: usize,
280 rank: usize,
281) -> usize
282where
283 Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
284{
285 GLWECiphertext::bytes_of(module.n(), basek, ct_k, rank)
286 + (GLWECiphertext::rsh_scratch_space(module.n())
287 | GLWECiphertext::automorphism_scratch_space(module, basek, ct_k, ct_k, k_ksk, digits, rank))
288}
289
290fn combine<D: DataRef, DataAK: DataRef, B: Backend>(
292 module: &Module<B>,
293 acc: &mut Accumulator,
294 b: Option<&GLWECiphertext<D>>,
295 i: usize,
296 auto_keys: &HashMap<i64, GGLWEAutomorphismKeyPrepared<DataAK, B>>,
297 scratch: &mut Scratch<B>,
298) where
299 Module<B>: VecZnxDftAllocBytes
300 + VmpApplyDftToDftTmpBytes
301 + VecZnxBigNormalizeTmpBytes
302 + VmpApplyDftToDft<B>
303 + VmpApplyDftToDftAdd<B>
304 + VecZnxDftApply<B>
305 + VecZnxIdftApplyConsume<B>
306 + VecZnxBigAddSmallInplace<B>
307 + VecZnxBigNormalize<B>
308 + VecZnxCopy
309 + VecZnxRotateInplace<B>
310 + VecZnxSub
311 + VecZnxNegateInplace
312 + VecZnxRshInplace<B>
313 + VecZnxAddInplace
314 + VecZnxNormalizeInplace<B>
315 + VecZnxSubABInplace
316 + VecZnxRotate
317 + VecZnxAutomorphismInplace<B>
318 + VecZnxBigSubSmallBInplace<B>
319 + VecZnxBigAutomorphismInplace<B>,
320 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
321{
322 let n: usize = acc.data.n();
323 let log_n: usize = (u64::BITS - (n - 1).leading_zeros()) as _;
324 let a: &mut GLWECiphertext<Vec<u8>> = &mut acc.data;
325 let basek: usize = a.basek();
326 let k: usize = a.k();
327 let rank: usize = a.rank();
328
329 let gal_el: i64 = if i == 0 {
330 -1
331 } else {
332 module.galois_element(1 << (i - 1))
333 };
334
335 let t: i64 = 1 << (log_n - i - 1);
336
337 if acc.value {
348 if let Some(b) = b {
349 let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(n, basek, k, rank);
350
351 a.rotate_inplace(module, -t, scratch_1);
353
354 tmp_b.sub(module, a, b);
356 tmp_b.rsh(module, 1, scratch_1);
357
358 a.add_inplace(module, b);
360 a.rsh(module, 1, scratch_1);
361
362 tmp_b.normalize_inplace(module, scratch_1);
363
364 if let Some(key) = auto_keys.get(&gal_el) {
366 tmp_b.automorphism_inplace(module, key, scratch_1);
367 } else {
368 panic!("auto_key[{}] not found", gal_el);
369 }
370
371 a.sub_inplace_ab(module, &tmp_b);
373 a.normalize_inplace(module, scratch_1);
374
375 a.rotate_inplace(module, t, scratch_1);
379 } else {
380 a.rsh(module, 1, scratch);
381 if let Some(key) = auto_keys.get(&gal_el) {
383 a.automorphism_add_inplace(module, key, scratch);
384 } else {
385 panic!("auto_key[{}] not found", gal_el);
386 }
387 }
388 } else if let Some(b) = b {
389 let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(n, basek, k, rank);
390 tmp_b.rotate(module, 1 << (log_n - i - 1), b);
391 tmp_b.rsh(module, 1, scratch_1);
392
393 if let Some(key) = auto_keys.get(&gal_el) {
395 a.automorphism_sub_ba(module, &tmp_b, key, scratch_1);
396 } else {
397 panic!("auto_key[{}] not found", gal_el);
398 }
399
400 acc.value = true;
401 }
402}