1use std::collections::HashMap;
2
3use poulpy_hal::{
4 api::{
5 ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace,
6 VecZnxBigAutomorphismInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallNegateInplace, VecZnxCopy,
7 VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNegateInplace, VecZnxNormalize,
8 VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub,
9 VecZnxSubInplace, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
10 },
11 layouts::{Backend, DataMut, DataRef, Module, Scratch},
12};
13
14use crate::{
15 GLWEOperations, TakeGLWECt,
16 layouts::{GGLWELayoutInfos, GLWECiphertext, GLWEInfos, LWEInfos, prepared::GGLWEAutomorphismKeyPrepared},
17};
18
19pub struct GLWEPacker {
24 accumulators: Vec<Accumulator>,
25 log_batch: usize,
26 counter: usize,
27}
28
29struct Accumulator {
32 data: GLWECiphertext<Vec<u8>>,
33 value: bool, control: bool, }
36
37impl Accumulator {
38 pub fn alloc<A>(infos: &A) -> Self
47 where
48 A: GLWEInfos,
49 {
50 Self {
51 data: GLWECiphertext::alloc(infos),
52 value: false,
53 control: false,
54 }
55 }
56}
57
58impl GLWEPacker {
59 pub fn new<A>(infos: &A, log_batch: usize) -> Self
70 where
71 A: GLWEInfos,
72 {
73 let mut accumulators: Vec<Accumulator> = Vec::<Accumulator>::new();
74 let log_n: usize = infos.n().log2();
75 (0..log_n - log_batch).for_each(|_| accumulators.push(Accumulator::alloc(infos)));
76 Self {
77 accumulators,
78 log_batch,
79 counter: 0,
80 }
81 }
82
83 fn reset(&mut self) {
85 for i in 0..self.accumulators.len() {
86 self.accumulators[i].value = false;
87 self.accumulators[i].control = false;
88 }
89 self.counter = 0;
90 }
91
92 pub fn scratch_space<B: Backend, OUT, KEY>(module: &Module<B>, out_infos: &OUT, key_infos: &KEY) -> usize
94 where
95 OUT: GLWEInfos,
96 KEY: GGLWELayoutInfos,
97 Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes,
98 {
99 pack_core_scratch_space(module, out_infos, key_infos)
100 }
101
102 pub fn galois_elements<B: Backend>(module: &Module<B>) -> Vec<i64> {
103 GLWECiphertext::trace_galois_elements(module)
104 }
105
106 pub fn add<DataA: DataRef, DataAK: DataRef, B: Backend>(
116 &mut self,
117 module: &Module<B>,
118 a: Option<&GLWECiphertext<DataA>>,
119 auto_keys: &HashMap<i64, GGLWEAutomorphismKeyPrepared<DataAK, B>>,
120 scratch: &mut Scratch<B>,
121 ) where
122 Module<B>: VecZnxDftAllocBytes
123 + VmpApplyDftToDftTmpBytes
124 + VecZnxBigNormalizeTmpBytes
125 + VmpApplyDftToDft<B>
126 + VmpApplyDftToDftAdd<B>
127 + VecZnxDftApply<B>
128 + VecZnxIdftApplyConsume<B>
129 + VecZnxBigAddSmallInplace<B>
130 + VecZnxBigNormalize<B>
131 + VecZnxCopy
132 + VecZnxRotateInplace<B>
133 + VecZnxSub
134 + VecZnxNegateInplace
135 + VecZnxRshInplace<B>
136 + VecZnxAddInplace
137 + VecZnxNormalizeInplace<B>
138 + VecZnxSubInplace
139 + VecZnxRotate
140 + VecZnxAutomorphismInplace<B>
141 + VecZnxBigSubSmallNegateInplace<B>
142 + VecZnxBigAutomorphismInplace<B>
143 + VecZnxNormalize<B>
144 + VecZnxNormalizeTmpBytes,
145 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
146 {
147 assert!(
148 (self.counter as u32) < self.accumulators[0].data.n(),
149 "Packing limit of {} reached",
150 self.accumulators[0].data.n().0 as usize >> self.log_batch
151 );
152
153 pack_core(
154 module,
155 a,
156 &mut self.accumulators,
157 self.log_batch,
158 auto_keys,
159 scratch,
160 );
161 self.counter += 1 << self.log_batch;
162 }
163
164 pub fn flush<Data: DataMut, B: Backend>(&mut self, module: &Module<B>, res: &mut GLWECiphertext<Data>)
166 where
167 Module<B>: VecZnxCopy,
168 {
169 assert!(self.counter as u32 == self.accumulators[0].data.n());
170 res.copy(
172 module,
173 &self.accumulators[module.log_n() - self.log_batch - 1].data,
174 );
175
176 self.reset();
177 }
178}
179
180fn pack_core_scratch_space<B: Backend, OUT, KEY>(module: &Module<B>, out_infos: &OUT, key_infos: &KEY) -> usize
181where
182 OUT: GLWEInfos,
183 KEY: GGLWELayoutInfos,
184 Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes,
185{
186 combine_scratch_space(module, out_infos, key_infos)
187}
188
189fn pack_core<D: DataRef, DataAK: DataRef, B: Backend>(
190 module: &Module<B>,
191 a: Option<&GLWECiphertext<D>>,
192 accumulators: &mut [Accumulator],
193 i: usize,
194 auto_keys: &HashMap<i64, GGLWEAutomorphismKeyPrepared<DataAK, B>>,
195 scratch: &mut Scratch<B>,
196) where
197 Module<B>: VecZnxDftAllocBytes
198 + VmpApplyDftToDftTmpBytes
199 + VecZnxBigNormalizeTmpBytes
200 + VmpApplyDftToDft<B>
201 + VmpApplyDftToDftAdd<B>
202 + VecZnxDftApply<B>
203 + VecZnxIdftApplyConsume<B>
204 + VecZnxBigAddSmallInplace<B>
205 + VecZnxBigNormalize<B>
206 + VecZnxCopy
207 + VecZnxRotateInplace<B>
208 + VecZnxSub
209 + VecZnxNegateInplace
210 + VecZnxRshInplace<B>
211 + VecZnxAddInplace
212 + VecZnxNormalizeInplace<B>
213 + VecZnxSubInplace
214 + VecZnxRotate
215 + VecZnxAutomorphismInplace<B>
216 + VecZnxBigSubSmallNegateInplace<B>
217 + VecZnxBigAutomorphismInplace<B>
218 + VecZnxNormalize<B>
219 + VecZnxNormalizeTmpBytes,
220 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
221{
222 let log_n: usize = module.log_n();
223
224 if i == log_n {
225 return;
226 }
227
228 let (acc_prev, acc_next) = accumulators.split_at_mut(1);
230
231 if !acc_prev[0].control {
233 let acc_mut_ref: &mut Accumulator = &mut acc_prev[0]; if let Some(a_ref) = a {
237 acc_mut_ref.data.copy(module, a_ref);
238 acc_mut_ref.value = true
239 } else {
240 acc_mut_ref.value = false
241 }
242 acc_mut_ref.control = true; } else {
244 combine(module, &mut acc_prev[0], a, i, auto_keys, scratch);
246 acc_prev[0].control = false;
247
248 if acc_prev[0].value {
250 pack_core(
251 module,
252 Some(&acc_prev[0].data),
253 acc_next,
254 i + 1,
255 auto_keys,
256 scratch,
257 );
258 } else {
259 pack_core(
260 module,
261 None::<&GLWECiphertext<Vec<u8>>>,
262 acc_next,
263 i + 1,
264 auto_keys,
265 scratch,
266 );
267 }
268 }
269}
270
271fn combine_scratch_space<B: Backend, OUT, KEY>(module: &Module<B>, out_infos: &OUT, key_infos: &KEY) -> usize
272where
273 OUT: GLWEInfos,
274 KEY: GGLWELayoutInfos,
275 Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes,
276{
277 GLWECiphertext::alloc_bytes(out_infos)
278 + (GLWECiphertext::rsh_scratch_space(module.n())
279 | GLWECiphertext::automorphism_inplace_scratch_space(module, out_infos, key_infos))
280}
281
282fn combine<D: DataRef, DataAK: DataRef, B: Backend>(
284 module: &Module<B>,
285 acc: &mut Accumulator,
286 b: Option<&GLWECiphertext<D>>,
287 i: usize,
288 auto_keys: &HashMap<i64, GGLWEAutomorphismKeyPrepared<DataAK, B>>,
289 scratch: &mut Scratch<B>,
290) where
291 Module<B>: VecZnxDftAllocBytes
292 + VmpApplyDftToDftTmpBytes
293 + VecZnxBigNormalizeTmpBytes
294 + VmpApplyDftToDft<B>
295 + VmpApplyDftToDftAdd<B>
296 + VecZnxDftApply<B>
297 + VecZnxIdftApplyConsume<B>
298 + VecZnxBigAddSmallInplace<B>
299 + VecZnxBigNormalize<B>
300 + VecZnxCopy
301 + VecZnxRotateInplace<B>
302 + VecZnxSub
303 + VecZnxNegateInplace
304 + VecZnxRshInplace<B>
305 + VecZnxAddInplace
306 + VecZnxNormalizeInplace<B>
307 + VecZnxSubInplace
308 + VecZnxRotate
309 + VecZnxAutomorphismInplace<B>
310 + VecZnxBigSubSmallNegateInplace<B>
311 + VecZnxBigAutomorphismInplace<B>
312 + VecZnxNormalize<B>
313 + VecZnxNormalizeTmpBytes,
314 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx + TakeGLWECt,
315{
316 let log_n: usize = acc.data.n().log2();
317 let a: &mut GLWECiphertext<Vec<u8>> = &mut acc.data;
318
319 let gal_el: i64 = if i == 0 {
320 -1
321 } else {
322 module.galois_element(1 << (i - 1))
323 };
324
325 let t: i64 = 1 << (log_n - i - 1);
326
327 if acc.value {
338 if let Some(b) = b {
339 let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(a);
340
341 a.rotate_inplace(module, -t, scratch_1);
343
344 tmp_b.sub(module, a, b);
346 tmp_b.rsh(module, 1, scratch_1);
347
348 a.add_inplace(module, b);
350 a.rsh(module, 1, scratch_1);
351
352 tmp_b.normalize_inplace(module, scratch_1);
353
354 if let Some(key) = auto_keys.get(&gal_el) {
356 tmp_b.automorphism_inplace(module, key, scratch_1);
357 } else {
358 panic!("auto_key[{gal_el}] not found");
359 }
360
361 a.sub_inplace_ab(module, &tmp_b);
363 a.normalize_inplace(module, scratch_1);
364
365 a.rotate_inplace(module, t, scratch_1);
369 } else {
370 a.rsh(module, 1, scratch);
371 if let Some(key) = auto_keys.get(&gal_el) {
373 a.automorphism_add_inplace(module, key, scratch);
374 } else {
375 panic!("auto_key[{gal_el}] not found");
376 }
377 }
378 } else if let Some(b) = b {
379 let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(a);
380 tmp_b.rotate(module, 1 << (log_n - i - 1), b);
381 tmp_b.rsh(module, 1, scratch_1);
382
383 if let Some(key) = auto_keys.get(&gal_el) {
385 a.automorphism_sub_negate(module, &tmp_b, key, scratch_1);
386 } else {
387 panic!("auto_key[{gal_el}] not found");
388 }
389
390 acc.value = true;
391 }
392}