poulpy_core/keyswitching/
glwe_ct.rs

1use poulpy_hal::{
2    api::{
3        ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
4        VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft,
5        VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
6    },
7    layouts::{Backend, DataMut, DataRef, DataViewMut, Module, Scratch, VecZnx, VecZnxBig, VecZnxDft, VmpPMat, ZnxInfos},
8};
9
10use crate::layouts::{GGLWELayoutInfos, GLWECiphertext, GLWEInfos, LWEInfos, prepared::GGLWESwitchingKeyPrepared};
11
impl GLWECiphertext<Vec<u8>> {
    /// Returns the number of scratch bytes required by [`GLWECiphertext::keyswitch`]
    /// for the given output/input ciphertext layouts and switching-key layout.
    ///
    /// The total accounts for:
    /// - `res_dft`: DFT accumulator with `rank_out + 1` columns of the key's size;
    /// - `ai_dft`: DFT of the input mask columns;
    /// - `vmp`: temporary space for the vector-matrix product;
    /// - `normalize_*`: temporaries for base-2k (re)normalization.
    ///
    /// NOTE(review): `|` below is a bitwise OR of byte counts; since `x | y >= max(x, y)`
    /// it appears to serve as a cheap upper bound for regions that are never live at the
    /// same time — confirm this is the intended crate-wide scratch convention.
    pub fn keyswitch_scratch_space<B: Backend, OUT, IN, KEY>(
        module: &Module<B>,
        out_infos: &OUT,
        in_infos: &IN,
        key_apply: &KEY,
    ) -> usize
    where
        OUT: GLWEInfos,
        IN: GLWEInfos,
        KEY: GGLWELayoutInfos,
        Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes,
    {
        // Number of input limbs once k is re-expressed in the key's base-2k,
        // then grouped by the key's digit count.
        let in_size: usize = in_infos
            .k()
            .div_ceil(key_apply.base2k())
            .div_ceil(key_apply.digits().into()) as usize;
        let out_size: usize = out_infos.size();
        let ksk_size: usize = key_apply.size();
        let res_dft: usize = module.vec_znx_dft_alloc_bytes((key_apply.rank_out() + 1).into(), ksk_size); // TODO OPTIMIZE
        let ai_dft: usize = module.vec_znx_dft_alloc_bytes((key_apply.rank_in()).into(), in_size);
        let vmp: usize = module.vmp_apply_dft_to_dft_tmp_bytes(
            out_size,
            in_size,
            in_size,
            (key_apply.rank_in()).into(),
            (key_apply.rank_out() + 1).into(),
            ksk_size,
        ) + module.vec_znx_dft_alloc_bytes((key_apply.rank_in()).into(), in_size);
        let normalize_big: usize = module.vec_znx_big_normalize_tmp_bytes();
        if in_infos.base2k() == key_apply.base2k() {
            // Bases match: no base-conversion temporary is needed.
            res_dft + ((ai_dft + vmp) | normalize_big)
        } else if key_apply.digits() == 1 {
            // In this case, we only need one column, temporary, that we can drop once a_dft is computed.
            let normalize_conv: usize = VecZnx::alloc_bytes(module.n(), 1, in_size) + module.vec_znx_normalize_tmp_bytes();
            res_dft + (((ai_dft + normalize_conv) | vmp) | normalize_big)
        } else {
            // Since we stride over a to get a_dft when digits > 1, we need to store the full columns of a within the base conversion.
            let normalize_conv: usize = VecZnx::alloc_bytes(module.n(), (key_apply.rank_in()).into(), in_size);
            res_dft + ((ai_dft + normalize_conv + (module.vec_znx_normalize_tmp_bytes() | vmp)) | normalize_big)
        }
    }

    /// Scratch bytes for [`GLWECiphertext::keyswitch_inplace`]: identical to
    /// [`Self::keyswitch_scratch_space`] with the output layout used for both
    /// input and output.
    pub fn keyswitch_inplace_scratch_space<B: Backend, OUT, KEY>(module: &Module<B>, out_infos: &OUT, key_apply: &KEY) -> usize
    where
        OUT: GLWEInfos,
        KEY: GGLWELayoutInfos,
        Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes,
    {
        Self::keyswitch_scratch_space(module, out_infos, out_infos, key_apply)
    }
}
64
65impl<DataSelf: DataRef> GLWECiphertext<DataSelf> {
66    #[allow(dead_code)]
67    pub(crate) fn assert_keyswitch<B: Backend, DataLhs, DataRhs>(
68        &self,
69        module: &Module<B>,
70        lhs: &GLWECiphertext<DataLhs>,
71        rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
72        scratch: &Scratch<B>,
73    ) where
74        DataLhs: DataRef,
75        DataRhs: DataRef,
76        Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes,
77        Scratch<B>: ScratchAvailable,
78    {
79        assert_eq!(
80            lhs.rank(),
81            rhs.rank_in(),
82            "lhs.rank(): {} != rhs.rank_in(): {}",
83            lhs.rank(),
84            rhs.rank_in()
85        );
86        assert_eq!(
87            self.rank(),
88            rhs.rank_out(),
89            "self.rank(): {} != rhs.rank_out(): {}",
90            self.rank(),
91            rhs.rank_out()
92        );
93        assert_eq!(rhs.n(), self.n());
94        assert_eq!(lhs.n(), self.n());
95
96        let scrach_needed: usize = GLWECiphertext::keyswitch_scratch_space(module, self, lhs, rhs);
97
98        assert!(
99            scratch.available() >= scrach_needed,
100            "scratch.available()={} < GLWECiphertext::keyswitch_scratch_space(
101                    module,
102                    self.base2k(),
103                    self.k(),
104                    lhs.base2k(),
105                    lhs.k(),
106                    rhs.base2k(),
107                    rhs.k(),
108                    rhs.digits(),
109                    rhs.rank_in(),
110                    rhs.rank_out(),
111                )={scrach_needed}",
112            scratch.available(),
113        );
114    }
115
116    #[allow(dead_code)]
117    pub(crate) fn assert_keyswitch_inplace<B: Backend, DataRhs>(
118        &self,
119        module: &Module<B>,
120        rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
121        scratch: &Scratch<B>,
122    ) where
123        DataRhs: DataRef,
124        Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes,
125        Scratch<B>: ScratchAvailable,
126    {
127        assert_eq!(
128            self.rank(),
129            rhs.rank_out(),
130            "self.rank(): {} != rhs.rank_out(): {}",
131            self.rank(),
132            rhs.rank_out()
133        );
134
135        assert_eq!(rhs.n(), self.n());
136
137        let scrach_needed: usize = GLWECiphertext::keyswitch_inplace_scratch_space(module, self, rhs);
138
139        assert!(
140            scratch.available() >= scrach_needed,
141            "scratch.available()={} < GLWECiphertext::keyswitch_scratch_space()={scrach_needed}",
142            scratch.available(),
143        );
144    }
145}
146
147impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
148    pub fn keyswitch<DataLhs: DataRef, DataRhs: DataRef, B: Backend>(
149        &mut self,
150        module: &Module<B>,
151        glwe_in: &GLWECiphertext<DataLhs>,
152        rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
153        scratch: &mut Scratch<B>,
154    ) where
155        Module<B>: VecZnxDftAllocBytes
156            + VmpApplyDftToDftTmpBytes
157            + VecZnxBigNormalizeTmpBytes
158            + VmpApplyDftToDft<B>
159            + VmpApplyDftToDftAdd<B>
160            + VecZnxDftApply<B>
161            + VecZnxIdftApplyConsume<B>
162            + VecZnxBigAddSmallInplace<B>
163            + VecZnxBigNormalize<B>
164            + VecZnxNormalize<B>
165            + VecZnxNormalizeTmpBytes,
166        Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeVecZnx,
167    {
168        #[cfg(debug_assertions)]
169        {
170            self.assert_keyswitch(module, glwe_in, rhs, scratch);
171        }
172
173        let basek_out: usize = self.base2k().into();
174        let basek_ksk: usize = rhs.base2k().into();
175
176        let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n().into(), (self.rank() + 1).into(), rhs.size()); // Todo optimise
177        let res_big: VecZnxBig<_, B> = glwe_in.keyswitch_internal(module, res_dft, rhs, scratch_1);
178        (0..(self.rank() + 1).into()).for_each(|i| {
179            module.vec_znx_big_normalize(
180                basek_out,
181                &mut self.data,
182                i,
183                basek_ksk,
184                &res_big,
185                i,
186                scratch_1,
187            );
188        })
189    }
190
191    pub fn keyswitch_inplace<DataRhs: DataRef, B: Backend>(
192        &mut self,
193        module: &Module<B>,
194        rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
195        scratch: &mut Scratch<B>,
196    ) where
197        Module<B>: VecZnxDftAllocBytes
198            + VmpApplyDftToDftTmpBytes
199            + VecZnxBigNormalizeTmpBytes
200            + VmpApplyDftToDftTmpBytes
201            + VmpApplyDftToDft<B>
202            + VmpApplyDftToDftAdd<B>
203            + VecZnxDftApply<B>
204            + VecZnxIdftApplyConsume<B>
205            + VecZnxBigAddSmallInplace<B>
206            + VecZnxBigNormalize<B>
207            + VecZnxNormalize<B>
208            + VecZnxNormalizeTmpBytes,
209        Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeVecZnx,
210    {
211        #[cfg(debug_assertions)]
212        {
213            self.assert_keyswitch_inplace(module, rhs, scratch);
214        }
215
216        let basek_in: usize = self.base2k().into();
217        let basek_ksk: usize = rhs.base2k().into();
218
219        let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n().into(), (self.rank() + 1).into(), rhs.size()); // Todo optimise
220        let res_big: VecZnxBig<_, B> = self.keyswitch_internal(module, res_dft, rhs, scratch_1);
221        (0..(self.rank() + 1).into()).for_each(|i| {
222            module.vec_znx_big_normalize(
223                basek_in,
224                &mut self.data,
225                i,
226                basek_ksk,
227                &res_big,
228                i,
229                scratch_1,
230            );
231        })
232    }
233}
234
235impl<D: DataRef> GLWECiphertext<D> {
236    pub(crate) fn keyswitch_internal<B: Backend, DataRes, DataKey>(
237        &self,
238        module: &Module<B>,
239        res_dft: VecZnxDft<DataRes, B>,
240        rhs: &GGLWESwitchingKeyPrepared<DataKey, B>,
241        scratch: &mut Scratch<B>,
242    ) -> VecZnxBig<DataRes, B>
243    where
244        DataRes: DataMut,
245        DataKey: DataRef,
246        Module<B>: VecZnxDftAllocBytes
247            + VmpApplyDftToDftTmpBytes
248            + VecZnxBigNormalizeTmpBytes
249            + VmpApplyDftToDftTmpBytes
250            + VmpApplyDftToDft<B>
251            + VmpApplyDftToDftAdd<B>
252            + VecZnxDftApply<B>
253            + VecZnxIdftApplyConsume<B>
254            + VecZnxBigAddSmallInplace<B>
255            + VecZnxBigNormalize<B>
256            + VecZnxNormalize<B>,
257        Scratch<B>: TakeVecZnxDft<B> + TakeVecZnx,
258    {
259        if rhs.digits() == 1 {
260            return keyswitch_vmp_one_digit(
261                module,
262                self.base2k().into(),
263                rhs.base2k().into(),
264                res_dft,
265                &self.data,
266                &rhs.key.data,
267                scratch,
268            );
269        }
270
271        keyswitch_vmp_multiple_digits(
272            module,
273            self.base2k().into(),
274            rhs.base2k().into(),
275            res_dft,
276            &self.data,
277            &rhs.key.data,
278            rhs.digits().into(),
279            scratch,
280        )
281    }
282}
283
284fn keyswitch_vmp_one_digit<B: Backend, DataRes, DataIn, DataVmp>(
285    module: &Module<B>,
286    basek_in: usize,
287    basek_ksk: usize,
288    mut res_dft: VecZnxDft<DataRes, B>,
289    a: &VecZnx<DataIn>,
290    mat: &VmpPMat<DataVmp, B>,
291    scratch: &mut Scratch<B>,
292) -> VecZnxBig<DataRes, B>
293where
294    DataRes: DataMut,
295    DataIn: DataRef,
296    DataVmp: DataRef,
297    Module<B>: VecZnxDftAllocBytes
298        + VecZnxDftApply<B>
299        + VmpApplyDftToDft<B>
300        + VecZnxIdftApplyConsume<B>
301        + VecZnxBigAddSmallInplace<B>
302        + VecZnxNormalize<B>,
303    Scratch<B>: TakeVecZnxDft<B> + TakeVecZnx,
304{
305    let cols: usize = a.cols();
306
307    let a_size: usize = (a.size() * basek_in).div_ceil(basek_ksk);
308    let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(a.n(), cols - 1, a.size());
309
310    if basek_in == basek_ksk {
311        (0..cols - 1).for_each(|col_i| {
312            module.vec_znx_dft_apply(1, 0, &mut ai_dft, col_i, a, col_i + 1);
313        });
314    } else {
315        let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(a.n(), 1, a_size);
316        (0..cols - 1).for_each(|col_i| {
317            module.vec_znx_normalize(basek_ksk, &mut a_conv, 0, basek_in, a, col_i + 1, scratch_2);
318            module.vec_znx_dft_apply(1, 0, &mut ai_dft, col_i, &a_conv, 0);
319        });
320    }
321
322    module.vmp_apply_dft_to_dft(&mut res_dft, &ai_dft, mat, scratch_1);
323    let mut res_big: VecZnxBig<DataRes, B> = module.vec_znx_idft_apply_consume(res_dft);
324    module.vec_znx_big_add_small_inplace(&mut res_big, 0, a, 0);
325    res_big
326}
327
/// Multi-digit key-switch kernel: for each digit `di`, extracts the strided
/// sub-decomposition of the mask columns of `a`, multiplies it with `mat`, and
/// accumulates the products (shifted by `di` limbs) into `res_dft`; the body
/// column `a[0]` is added back after leaving the DFT domain.
///
/// NOTE(review): the two branches below are near-duplicates differing only in the
/// source vector (`a` vs the renormalized `a_conv`) and the scratch handle; they are
/// kept duplicated because the `a_conv` borrow from `scratch_1` must outlive the loop.
#[allow(clippy::too_many_arguments)]
fn keyswitch_vmp_multiple_digits<B: Backend, DataRes, DataIn, DataVmp>(
    module: &Module<B>,
    basek_in: usize,
    basek_ksk: usize,
    mut res_dft: VecZnxDft<DataRes, B>,
    a: &VecZnx<DataIn>,
    mat: &VmpPMat<DataVmp, B>,
    digits: usize,
    scratch: &mut Scratch<B>,
) -> VecZnxBig<DataRes, B>
where
    DataRes: DataMut,
    DataIn: DataRef,
    DataVmp: DataRef,
    Module<B>: VecZnxDftAllocBytes
        + VecZnxDftApply<B>
        + VmpApplyDftToDft<B>
        + VmpApplyDftToDftAdd<B>
        + VecZnxIdftApplyConsume<B>
        + VecZnxBigAddSmallInplace<B>
        + VecZnxNormalize<B>,
    Scratch<B>: TakeVecZnxDft<B> + TakeVecZnx,
{
    let cols: usize = a.cols();
    // Input limb count re-expressed in the key's base-2k.
    let a_size: usize = (a.size() * basek_in).div_ceil(basek_ksk);
    let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(a.n(), cols - 1, a_size.div_ceil(digits));
    // Zero the buffer: per-digit sizes vary, so stale limbs must not leak between digits.
    ai_dft.data_mut().fill(0);

    if basek_in == basek_ksk {
        for di in 0..digits {
            // Limbs belonging to digit `di` under the striding scheme.
            ai_dft.set_size((a_size + di) / digits);

            // Small optimization for digits > 2
            // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then
            // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}.
            // As such we can ignore the last digits-2 limbs safely of the sum of vmp products.
            // It is possible to further ignore the last digits-1 limbs, but this introduce
            // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same
            // noise is kept with respect to the ideal functionality.
            res_dft.set_size(mat.size() - ((digits - di) as isize - 2).max(0) as usize);

            for j in 0..cols - 1 {
                module.vec_znx_dft_apply(digits, digits - di - 1, &mut ai_dft, j, a, j + 1);
            }

            // First digit initializes the accumulator; subsequent digits add shifted by `di`.
            if di == 0 {
                module.vmp_apply_dft_to_dft(&mut res_dft, &ai_dft, mat, scratch_1);
            } else {
                module.vmp_apply_dft_to_dft_add(&mut res_dft, &ai_dft, mat, di, scratch_1);
            }
        }
    } else {
        // Base mismatch: renormalize all mask columns into the key's base-2k once,
        // since the digit striding below needs the full converted columns.
        let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(a.n(), cols - 1, a_size);
        for j in 0..cols - 1 {
            module.vec_znx_normalize(basek_ksk, &mut a_conv, j, basek_in, a, j + 1, scratch_2);
        }

        for di in 0..digits {
            ai_dft.set_size((a_size + di) / digits);

            // Small optimization for digits > 2
            // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then
            // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}.
            // As such we can ignore the last digits-2 limbs safely of the sum of vmp products.
            // It is possible to further ignore the last digits-1 limbs, but this introduce
            // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same
            // noise is kept with respect to the ideal functionality.
            res_dft.set_size(mat.size() - ((digits - di) as isize - 2).max(0) as usize);

            for j in 0..cols - 1 {
                module.vec_znx_dft_apply(digits, digits - di - 1, &mut ai_dft, j, &a_conv, j);
            }

            if di == 0 {
                module.vmp_apply_dft_to_dft(&mut res_dft, &ai_dft, mat, scratch_2);
            } else {
                module.vmp_apply_dft_to_dft_add(&mut res_dft, &ai_dft, mat, di, scratch_2);
            }
        }
    }

    // Restore the full limb count before leaving the DFT domain.
    res_dft.set_size(res_dft.max_size());
    let mut res_big: VecZnxBig<DataRes, B> = module.vec_znx_idft_apply_consume(res_dft);
    // Add the body term back onto column 0.
    module.vec_znx_big_add_small_inplace(&mut res_big, 0, a, 0);
    res_big
}