poulpy_core/
glwe_packing.rs

1use std::collections::HashMap;
2
3use poulpy_hal::{
4    api::{
5        ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace,
6        VecZnxBigAutomorphismInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallNegateInplace, VecZnxCopy,
7        VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNegateInplace, VecZnxNormalize,
8        VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub,
9        VecZnxSubInplace, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
10    },
11    layouts::{Backend, DataMut, DataRef, Module, Scratch},
12};
13
14use crate::{
15    GLWEOperations, TakeGLWECt,
16    layouts::{GGLWELayoutInfos, GLWECiphertext, GLWEInfos, LWEInfos, prepared::GGLWEAutomorphismKeyPrepared},
17};
18
19/// [GLWEPacker] enables only the fly GLWE packing
20/// with constant memory of Log(N) ciphertexts.
21/// Main difference with usual GLWE packing is that
22/// the output is bit-reversed.
23pub struct GLWEPacker {
24    accumulators: Vec<Accumulator>,
25    log_batch: usize,
26    counter: usize,
27}
28
29/// [Accumulator] stores intermediate packing result.
30/// There are Log(N) such accumulators in a [GLWEPacker].
31struct Accumulator {
32    data: GLWECiphertext<Vec<u8>>,
33    value: bool,   // Implicit flag for zero ciphertext
34    control: bool, // Can be combined with incoming value
35}
36
37impl Accumulator {
38    /// Allocates a new [Accumulator].
39    ///
40    /// #Arguments
41    ///
42    /// * `module`: static backend FFT tables.
43    /// * `base2k`: base 2 logarithm of the GLWE ciphertext in memory digit representation.
44    /// * `k`: base 2 precision of the GLWE ciphertext precision over the Torus.
45    /// * `rank`: rank of the GLWE ciphertext.
46    pub fn alloc<A>(infos: &A) -> Self
47    where
48        A: GLWEInfos,
49    {
50        Self {
51            data: GLWECiphertext::alloc(infos),
52            value: false,
53            control: false,
54        }
55    }
56}
57
58impl GLWEPacker {
59    /// Instantiates a new [GLWEPacker].
60    ///
61    /// # Arguments
62    ///
63    /// * `module`: static backend FFT tables.
64    /// * `log_batch`: packs coefficients which are multiples of X^{N/2^log_batch}.
65    ///   i.e. with `log_batch=0` only the constant coefficient is packed
66    ///   and N GLWE ciphertext can be packed. With `log_batch=2` all coefficients
67    ///   which are multiples of X^{N/4} are packed. Meaning that N/4 ciphertexts
68    ///   can be packed.
69    pub fn new<A>(infos: &A, log_batch: usize) -> Self
70    where
71        A: GLWEInfos,
72    {
73        let mut accumulators: Vec<Accumulator> = Vec::<Accumulator>::new();
74        let log_n: usize = infos.n().log2();
75        (0..log_n - log_batch).for_each(|_| accumulators.push(Accumulator::alloc(infos)));
76        Self {
77            accumulators,
78            log_batch,
79            counter: 0,
80        }
81    }
82
83    /// Implicit reset of the internal state (to be called before a new packing procedure).
84    fn reset(&mut self) {
85        for i in 0..self.accumulators.len() {
86            self.accumulators[i].value = false;
87            self.accumulators[i].control = false;
88        }
89        self.counter = 0;
90    }
91
92    /// Number of scratch space bytes required to call [Self::add].
93    pub fn scratch_space<B: Backend, OUT, KEY>(module: &Module<B>, out_infos: &OUT, key_infos: &KEY) -> usize
94    where
95        OUT: GLWEInfos,
96        KEY: GGLWELayoutInfos,
97        Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes,
98    {
99        pack_core_scratch_space(module, out_infos, key_infos)
100    }
101
102    pub fn galois_elements<B: Backend>(module: &Module<B>) -> Vec<i64> {
103        GLWECiphertext::trace_galois_elements(module)
104    }
105
106    /// Adds a GLWE ciphertext to the [GLWEPacker].
107    /// #Arguments
108    ///
109    /// * `module`: static backend FFT tables.
110    /// * `res`: space to append fully packed ciphertext. Only when the number
111    ///   of packed ciphertexts reaches N/2^log_batch is a result written.
112    /// * `a`: ciphertext to pack. Can optionally give None to pack a 0 ciphertext.
113    /// * `auto_keys`: a [HashMap] containing the [AutomorphismKeyExec]s.
114    /// * `scratch`: scratch space of size at least [Self::scratch_space].
115    pub fn add<DataA: DataRef, DataAK: DataRef, B: Backend>(
116        &mut self,
117        module: &Module<B>,
118        a: Option<&GLWECiphertext<DataA>>,
119        auto_keys: &HashMap<i64, GGLWEAutomorphismKeyPrepared<DataAK, B>>,
120        scratch: &mut Scratch<B>,
121    ) where
122        Module<B>: VecZnxDftAllocBytes
123            + VmpApplyDftToDftTmpBytes
124            + VecZnxBigNormalizeTmpBytes
125            + VmpApplyDftToDft<B>
126            + VmpApplyDftToDftAdd<B>
127            + VecZnxDftApply<B>
128            + VecZnxIdftApplyConsume<B>
129            + VecZnxBigAddSmallInplace<B>
130            + VecZnxBigNormalize<B>
131            + VecZnxCopy
132            + VecZnxRotateInplace<B>
133            + VecZnxSub
134            + VecZnxNegateInplace
135            + VecZnxRshInplace<B>
136            + VecZnxAddInplace
137            + VecZnxNormalizeInplace<B>
138            + VecZnxSubInplace
139            + VecZnxRotate
140            + VecZnxAutomorphismInplace<B>
141            + VecZnxBigSubSmallNegateInplace<B>
142            + VecZnxBigAutomorphismInplace<B>
143            + VecZnxNormalize<B>
144            + VecZnxNormalizeTmpBytes,
145        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
146    {
147        assert!(
148            (self.counter as u32) < self.accumulators[0].data.n(),
149            "Packing limit of {} reached",
150            self.accumulators[0].data.n().0 as usize >> self.log_batch
151        );
152
153        pack_core(
154            module,
155            a,
156            &mut self.accumulators,
157            self.log_batch,
158            auto_keys,
159            scratch,
160        );
161        self.counter += 1 << self.log_batch;
162    }
163
164    /// Flush result to`res`.
165    pub fn flush<Data: DataMut, B: Backend>(&mut self, module: &Module<B>, res: &mut GLWECiphertext<Data>)
166    where
167        Module<B>: VecZnxCopy,
168    {
169        assert!(self.counter as u32 == self.accumulators[0].data.n());
170        // Copy result GLWE into res GLWE
171        res.copy(
172            module,
173            &self.accumulators[module.log_n() - self.log_batch - 1].data,
174        );
175
176        self.reset();
177    }
178}
179
180fn pack_core_scratch_space<B: Backend, OUT, KEY>(module: &Module<B>, out_infos: &OUT, key_infos: &KEY) -> usize
181where
182    OUT: GLWEInfos,
183    KEY: GGLWELayoutInfos,
184    Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes,
185{
186    combine_scratch_space(module, out_infos, key_infos)
187}
188
189fn pack_core<D: DataRef, DataAK: DataRef, B: Backend>(
190    module: &Module<B>,
191    a: Option<&GLWECiphertext<D>>,
192    accumulators: &mut [Accumulator],
193    i: usize,
194    auto_keys: &HashMap<i64, GGLWEAutomorphismKeyPrepared<DataAK, B>>,
195    scratch: &mut Scratch<B>,
196) where
197    Module<B>: VecZnxDftAllocBytes
198        + VmpApplyDftToDftTmpBytes
199        + VecZnxBigNormalizeTmpBytes
200        + VmpApplyDftToDft<B>
201        + VmpApplyDftToDftAdd<B>
202        + VecZnxDftApply<B>
203        + VecZnxIdftApplyConsume<B>
204        + VecZnxBigAddSmallInplace<B>
205        + VecZnxBigNormalize<B>
206        + VecZnxCopy
207        + VecZnxRotateInplace<B>
208        + VecZnxSub
209        + VecZnxNegateInplace
210        + VecZnxRshInplace<B>
211        + VecZnxAddInplace
212        + VecZnxNormalizeInplace<B>
213        + VecZnxSubInplace
214        + VecZnxRotate
215        + VecZnxAutomorphismInplace<B>
216        + VecZnxBigSubSmallNegateInplace<B>
217        + VecZnxBigAutomorphismInplace<B>
218        + VecZnxNormalize<B>
219        + VecZnxNormalizeTmpBytes,
220    Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
221{
222    let log_n: usize = module.log_n();
223
224    if i == log_n {
225        return;
226    }
227
228    // Isolate the first accumulator
229    let (acc_prev, acc_next) = accumulators.split_at_mut(1);
230
231    // Control = true accumlator is free to overide
232    if !acc_prev[0].control {
233        let acc_mut_ref: &mut Accumulator = &mut acc_prev[0]; // from split_at_mut
234
235        // No previous value -> copies and sets flags accordingly
236        if let Some(a_ref) = a {
237            acc_mut_ref.data.copy(module, a_ref);
238            acc_mut_ref.value = true
239        } else {
240            acc_mut_ref.value = false
241        }
242        acc_mut_ref.control = true; // Able to be combined on next call
243    } else {
244        // Compresses acc_prev <- combine(acc_prev, a).
245        combine(module, &mut acc_prev[0], a, i, auto_keys, scratch);
246        acc_prev[0].control = false;
247
248        // Propagates to next accumulator
249        if acc_prev[0].value {
250            pack_core(
251                module,
252                Some(&acc_prev[0].data),
253                acc_next,
254                i + 1,
255                auto_keys,
256                scratch,
257            );
258        } else {
259            pack_core(
260                module,
261                None::<&GLWECiphertext<Vec<u8>>>,
262                acc_next,
263                i + 1,
264                auto_keys,
265                scratch,
266            );
267        }
268    }
269}
270
271fn combine_scratch_space<B: Backend, OUT, KEY>(module: &Module<B>, out_infos: &OUT, key_infos: &KEY) -> usize
272where
273    OUT: GLWEInfos,
274    KEY: GGLWELayoutInfos,
275    Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes,
276{
277    GLWECiphertext::alloc_bytes(out_infos)
278        + (GLWECiphertext::rsh_scratch_space(module.n())
279            | GLWECiphertext::automorphism_inplace_scratch_space(module, out_infos, key_infos))
280}
281
282/// [combine] merges two ciphertexts together.
283fn combine<D: DataRef, DataAK: DataRef, B: Backend>(
284    module: &Module<B>,
285    acc: &mut Accumulator,
286    b: Option<&GLWECiphertext<D>>,
287    i: usize,
288    auto_keys: &HashMap<i64, GGLWEAutomorphismKeyPrepared<DataAK, B>>,
289    scratch: &mut Scratch<B>,
290) where
291    Module<B>: VecZnxDftAllocBytes
292        + VmpApplyDftToDftTmpBytes
293        + VecZnxBigNormalizeTmpBytes
294        + VmpApplyDftToDft<B>
295        + VmpApplyDftToDftAdd<B>
296        + VecZnxDftApply<B>
297        + VecZnxIdftApplyConsume<B>
298        + VecZnxBigAddSmallInplace<B>
299        + VecZnxBigNormalize<B>
300        + VecZnxCopy
301        + VecZnxRotateInplace<B>
302        + VecZnxSub
303        + VecZnxNegateInplace
304        + VecZnxRshInplace<B>
305        + VecZnxAddInplace
306        + VecZnxNormalizeInplace<B>
307        + VecZnxSubInplace
308        + VecZnxRotate
309        + VecZnxAutomorphismInplace<B>
310        + VecZnxBigSubSmallNegateInplace<B>
311        + VecZnxBigAutomorphismInplace<B>
312        + VecZnxNormalize<B>
313        + VecZnxNormalizeTmpBytes,
314    Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx + TakeGLWECt,
315{
316    let log_n: usize = acc.data.n().log2();
317    let a: &mut GLWECiphertext<Vec<u8>> = &mut acc.data;
318
319    let gal_el: i64 = if i == 0 {
320        -1
321    } else {
322        module.galois_element(1 << (i - 1))
323    };
324
325    let t: i64 = 1 << (log_n - i - 1);
326
327    // Goal is to evaluate: a = a + b*X^t + phi(a - b*X^t))
328    // We also use the identity: AUTO(a * X^t, g) = -X^t * AUTO(a, g)
329    // where t = 2^(log_n - i - 1) and g = 5^{2^(i - 1)}
330    // Different cases for wether a and/or b are zero.
331    //
332    // Implicite RSH without modulus switch, introduces extra I(X) * Q/2 on decryption.
333    // Necessary so that the scaling of the plaintext remains constant.
334    // It however is ok to do so here because coefficients are eventually
335    // either mapped to garbage or twice their value which vanishes I(X)
336    // since 2*(I(X) * Q/2) = I(X) * Q = 0 mod Q.
337    if acc.value {
338        if let Some(b) = b {
339            let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(a);
340
341            // a = a * X^-t
342            a.rotate_inplace(module, -t, scratch_1);
343
344            // tmp_b = a * X^-t - b
345            tmp_b.sub(module, a, b);
346            tmp_b.rsh(module, 1, scratch_1);
347
348            // a = a * X^-t + b
349            a.add_inplace(module, b);
350            a.rsh(module, 1, scratch_1);
351
352            tmp_b.normalize_inplace(module, scratch_1);
353
354            // tmp_b = phi(a * X^-t - b)
355            if let Some(key) = auto_keys.get(&gal_el) {
356                tmp_b.automorphism_inplace(module, key, scratch_1);
357            } else {
358                panic!("auto_key[{gal_el}] not found");
359            }
360
361            // a = a * X^-t + b - phi(a * X^-t - b)
362            a.sub_inplace_ab(module, &tmp_b);
363            a.normalize_inplace(module, scratch_1);
364
365            // a = a + b * X^t - phi(a * X^-t - b) * X^t
366            //   = a + b * X^t - phi(a * X^-t - b) * - phi(X^t)
367            //   = a + b * X^t + phi(a - b * X^t)
368            a.rotate_inplace(module, t, scratch_1);
369        } else {
370            a.rsh(module, 1, scratch);
371            // a = a + phi(a)
372            if let Some(key) = auto_keys.get(&gal_el) {
373                a.automorphism_add_inplace(module, key, scratch);
374            } else {
375                panic!("auto_key[{gal_el}] not found");
376            }
377        }
378    } else if let Some(b) = b {
379        let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(a);
380        tmp_b.rotate(module, 1 << (log_n - i - 1), b);
381        tmp_b.rsh(module, 1, scratch_1);
382
383        // a = (b* X^t - phi(b* X^t))
384        if let Some(key) = auto_keys.get(&gal_el) {
385            a.automorphism_sub_negate(module, &tmp_b, key, scratch_1);
386        } else {
387            panic!("auto_key[{gal_el}] not found");
388        }
389
390        acc.value = true;
391    }
392}