poulpy_core/keyswitching/glwe_ct.rs

1use poulpy_hal::{
2    api::{
3        ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
4        VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VmpApplyDftToDft, VmpApplyDftToDftAdd,
5        VmpApplyDftToDftTmpBytes,
6    },
7    layouts::{Backend, DataMut, DataRef, DataViewMut, Module, Scratch, VecZnx, VecZnxBig, VecZnxDft, VmpPMat, ZnxInfos},
8};
9
10use crate::layouts::{GLWECiphertext, Infos, prepared::GGLWESwitchingKeyPrepared};
11
12impl GLWECiphertext<Vec<u8>> {
13    #[allow(clippy::too_many_arguments)]
14    pub fn keyswitch_scratch_space<B: Backend>(
15        module: &Module<B>,
16        basek: usize,
17        k_out: usize,
18        k_in: usize,
19        k_ksk: usize,
20        digits: usize,
21        rank_in: usize,
22        rank_out: usize,
23    ) -> usize
24    where
25        Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
26    {
27        let in_size: usize = k_in.div_ceil(basek).div_ceil(digits);
28        let out_size: usize = k_out.div_ceil(basek);
29        let ksk_size: usize = k_ksk.div_ceil(basek);
30        let res_dft: usize = module.vec_znx_dft_alloc_bytes(rank_out + 1, ksk_size); // TODO OPTIMIZE
31        let ai_dft: usize = module.vec_znx_dft_alloc_bytes(rank_in, in_size);
32        let vmp: usize = module.vmp_apply_dft_to_dft_tmp_bytes(out_size, in_size, in_size, rank_in, rank_out + 1, ksk_size)
33            + module.vec_znx_dft_alloc_bytes(rank_in, in_size);
34        let normalize: usize = module.vec_znx_big_normalize_tmp_bytes();
35        res_dft + ((ai_dft + vmp) | normalize)
36    }
37
38    pub fn keyswitch_inplace_scratch_space<B: Backend>(
39        module: &Module<B>,
40        basek: usize,
41        k_out: usize,
42        k_ksk: usize,
43        digits: usize,
44        rank: usize,
45    ) -> usize
46    where
47        Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
48    {
49        Self::keyswitch_scratch_space(module, basek, k_out, k_out, k_ksk, digits, rank, rank)
50    }
51}
52
53impl<DataSelf: DataRef> GLWECiphertext<DataSelf> {
54    #[allow(dead_code)]
55    pub(crate) fn assert_keyswitch<B: Backend, DataLhs, DataRhs>(
56        &self,
57        module: &Module<B>,
58        lhs: &GLWECiphertext<DataLhs>,
59        rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
60        scratch: &Scratch<B>,
61    ) where
62        DataLhs: DataRef,
63        DataRhs: DataRef,
64        Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
65        Scratch<B>: ScratchAvailable,
66    {
67        let basek: usize = self.basek();
68        assert_eq!(
69            lhs.rank(),
70            rhs.rank_in(),
71            "lhs.rank(): {} != rhs.rank_in(): {}",
72            lhs.rank(),
73            rhs.rank_in()
74        );
75        assert_eq!(
76            self.rank(),
77            rhs.rank_out(),
78            "self.rank(): {} != rhs.rank_out(): {}",
79            self.rank(),
80            rhs.rank_out()
81        );
82        assert_eq!(self.basek(), basek);
83        assert_eq!(lhs.basek(), basek);
84        assert_eq!(rhs.n(), self.n());
85        assert_eq!(lhs.n(), self.n());
86        assert!(
87            scratch.available()
88                >= GLWECiphertext::keyswitch_scratch_space(
89                    module,
90                    self.basek(),
91                    self.k(),
92                    lhs.k(),
93                    rhs.k(),
94                    rhs.digits(),
95                    rhs.rank_in(),
96                    rhs.rank_out(),
97                ),
98            "scratch.available()={} < GLWECiphertext::keyswitch_scratch_space(
99                    module,
100                    self.basek(),
101                    self.k(),
102                    lhs.k(),
103                    rhs.k(),
104                    rhs.digits(),
105                    rhs.rank_in(),
106                    rhs.rank_out(),
107                )={}",
108            scratch.available(),
109            GLWECiphertext::keyswitch_scratch_space(
110                module,
111                self.basek(),
112                self.k(),
113                lhs.k(),
114                rhs.k(),
115                rhs.digits(),
116                rhs.rank_in(),
117                rhs.rank_out(),
118            )
119        );
120    }
121
122    #[allow(dead_code)]
123    pub(crate) fn assert_keyswitch_inplace<B: Backend, DataRhs>(
124        &self,
125        module: &Module<B>,
126        rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
127        scratch: &Scratch<B>,
128    ) where
129        DataRhs: DataRef,
130        Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
131        Scratch<B>: ScratchAvailable,
132    {
133        let basek: usize = self.basek();
134        assert_eq!(
135            self.rank(),
136            rhs.rank_out(),
137            "self.rank(): {} != rhs.rank_out(): {}",
138            self.rank(),
139            rhs.rank_out()
140        );
141        assert_eq!(self.basek(), basek);
142        assert_eq!(rhs.n(), self.n());
143        assert!(
144            scratch.available()
145                >= GLWECiphertext::keyswitch_scratch_space(
146                    module,
147                    self.basek(),
148                    self.k(),
149                    self.k(),
150                    rhs.k(),
151                    rhs.digits(),
152                    rhs.rank_in(),
153                    rhs.rank_out(),
154                ),
155            "scratch.available()={} < GLWECiphertext::keyswitch_scratch_space(
156                    module,
157                    self.basek(),
158                    self.k(),
159                    self.k(),
160                    rhs.k(),
161                    rhs.digits(),
162                    rhs.rank_in(),
163                    rhs.rank_out(),
164                )={}",
165            scratch.available(),
166            GLWECiphertext::keyswitch_scratch_space(
167                module,
168                self.basek(),
169                self.k(),
170                self.k(),
171                rhs.k(),
172                rhs.digits(),
173                rhs.rank_in(),
174                rhs.rank_out(),
175            )
176        );
177    }
178}
179
180impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
181    pub fn keyswitch<DataLhs: DataRef, DataRhs: DataRef, B: Backend>(
182        &mut self,
183        module: &Module<B>,
184        lhs: &GLWECiphertext<DataLhs>,
185        rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
186        scratch: &mut Scratch<B>,
187    ) where
188        Module<B>: VecZnxDftAllocBytes
189            + VmpApplyDftToDftTmpBytes
190            + VecZnxBigNormalizeTmpBytes
191            + VmpApplyDftToDft<B>
192            + VmpApplyDftToDftAdd<B>
193            + VecZnxDftApply<B>
194            + VecZnxIdftApplyConsume<B>
195            + VecZnxBigAddSmallInplace<B>
196            + VecZnxBigNormalize<B>,
197        Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
198    {
199        #[cfg(debug_assertions)]
200        {
201            self.assert_keyswitch(module, lhs, rhs, scratch);
202        }
203        let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // Todo optimise
204        let res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, rhs, scratch_1);
205        (0..self.cols()).for_each(|i| {
206            module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch_1);
207        })
208    }
209
210    pub fn keyswitch_inplace<DataRhs: DataRef, B: Backend>(
211        &mut self,
212        module: &Module<B>,
213        rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
214        scratch: &mut Scratch<B>,
215    ) where
216        Module<B>: VecZnxDftAllocBytes
217            + VmpApplyDftToDftTmpBytes
218            + VecZnxBigNormalizeTmpBytes
219            + VmpApplyDftToDftTmpBytes
220            + VmpApplyDftToDft<B>
221            + VmpApplyDftToDftAdd<B>
222            + VecZnxDftApply<B>
223            + VecZnxIdftApplyConsume<B>
224            + VecZnxBigAddSmallInplace<B>
225            + VecZnxBigNormalize<B>,
226        Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
227    {
228        #[cfg(debug_assertions)]
229        {
230            self.assert_keyswitch_inplace(module, rhs, scratch);
231        }
232        let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // Todo optimise
233        let res_big: VecZnxBig<_, B> = self.keyswitch_internal(module, res_dft, rhs, scratch_1);
234        (0..self.cols()).for_each(|i| {
235            module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch_1);
236        })
237    }
238}
239
240impl<D: DataRef> GLWECiphertext<D> {
241    pub(crate) fn keyswitch_internal<B: Backend, DataRes, DataKey>(
242        &self,
243        module: &Module<B>,
244        res_dft: VecZnxDft<DataRes, B>,
245        rhs: &GGLWESwitchingKeyPrepared<DataKey, B>,
246        scratch: &mut Scratch<B>,
247    ) -> VecZnxBig<DataRes, B>
248    where
249        DataRes: DataMut,
250        DataKey: DataRef,
251        Module<B>: VecZnxDftAllocBytes
252            + VmpApplyDftToDftTmpBytes
253            + VecZnxBigNormalizeTmpBytes
254            + VmpApplyDftToDftTmpBytes
255            + VmpApplyDftToDft<B>
256            + VmpApplyDftToDftAdd<B>
257            + VecZnxDftApply<B>
258            + VecZnxIdftApplyConsume<B>
259            + VecZnxBigAddSmallInplace<B>
260            + VecZnxBigNormalize<B>,
261        Scratch<B>: TakeVecZnxDft<B>,
262    {
263        if rhs.digits() == 1 {
264            return keyswitch_vmp_one_digit(module, res_dft, &self.data, &rhs.key.data, scratch);
265        }
266
267        keyswitch_vmp_multiple_digits(
268            module,
269            res_dft,
270            &self.data,
271            &rhs.key.data,
272            rhs.digits(),
273            scratch,
274        )
275    }
276}
277
278fn keyswitch_vmp_one_digit<B: Backend, DataRes, DataIn, DataVmp>(
279    module: &Module<B>,
280    mut res_dft: VecZnxDft<DataRes, B>,
281    a: &VecZnx<DataIn>,
282    mat: &VmpPMat<DataVmp, B>,
283    scratch: &mut Scratch<B>,
284) -> VecZnxBig<DataRes, B>
285where
286    DataRes: DataMut,
287    DataIn: DataRef,
288    DataVmp: DataRef,
289    Module<B>:
290        VecZnxDftAllocBytes + VecZnxDftApply<B> + VmpApplyDftToDft<B> + VecZnxIdftApplyConsume<B> + VecZnxBigAddSmallInplace<B>,
291    Scratch<B>: TakeVecZnxDft<B>,
292{
293    let cols: usize = a.cols();
294    let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(a.n(), cols - 1, a.size());
295    (0..cols - 1).for_each(|col_i| {
296        module.vec_znx_dft_apply(1, 0, &mut ai_dft, col_i, a, col_i + 1);
297    });
298    module.vmp_apply_dft_to_dft(&mut res_dft, &ai_dft, mat, scratch_1);
299    let mut res_big: VecZnxBig<DataRes, B> = module.vec_znx_idft_apply_consume(res_dft);
300    module.vec_znx_big_add_small_inplace(&mut res_big, 0, a, 0);
301    res_big
302}
303
/// Multi-digit key-switching product.
///
/// The columns `1..a.cols()` of `a` are processed one digit at a time, for `di` in
/// `0..digits`:
/// - `ai_dft` is filled with `vec_znx_dft_apply(digits, digits - di - 1, ..)`, presumably
///   selecting every `digits`-th limb of `a` starting at offset `digits - di - 1`
///   (TODO confirm against `poulpy_hal`);
/// - the product with `mat` is written into `res_dft` with `vmp_apply_dft_to_dft` for
///   `di == 0`, and with `vmp_apply_dft_to_dft_add` (which receives `di`) for later digits.
///
/// Finally, the inverse DFT consumes `res_dft`, and column 0 of `a` (which does not go
/// through the VMP) is added into column 0 of the result. The returned value is not
/// normalized.
fn keyswitch_vmp_multiple_digits<B: Backend, DataRes, DataIn, DataVmp>(
    module: &Module<B>,
    mut res_dft: VecZnxDft<DataRes, B>,
    a: &VecZnx<DataIn>,
    mat: &VmpPMat<DataVmp, B>,
    digits: usize,
    scratch: &mut Scratch<B>,
) -> VecZnxBig<DataRes, B>
where
    DataRes: DataMut,
    DataIn: DataRef,
    DataVmp: DataRef,
    Module<B>: VecZnxDftAllocBytes
        + VecZnxDftApply<B>
        + VmpApplyDftToDft<B>
        + VmpApplyDftToDftAdd<B>
        + VecZnxIdftApplyConsume<B>
        + VecZnxBigAddSmallInplace<B>,
    Scratch<B>: TakeVecZnxDft<B>,
{
    let cols: usize = a.cols();
    let size: usize = a.size();
    // One DFT buffer for the columns 1.., sized for the largest digit (ceil(size / digits) limbs).
    let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(a.n(), cols - 1, size.div_ceil(digits));

    // Zero the whole buffer once; earlier digits only use a prefix of it (see set_size below),
    // presumably so that limbs beyond the active size never hold stale scratch data.
    ai_dft.data_mut().fill(0);

    (0..digits).for_each(|di| {
        // Digit `di` uses (size + di) / digits limbs; the last digit uses ceil(size / digits).
        ai_dft.set_size((size + di) / digits);

        // Small optimization for digits > 2:
        // the VMP produces some error e, and since we aggregate vmp * 2^{di * B}, we also
        // aggregate ei * 2^{di * B}, the largest error being ei * 2^{(digits-1) * B}.
        // As such, the last digits-2 limbs of the sum of VMP products can safely be ignored.
        // It is possible to further ignore the last digits-1 limbs, but this introduces
        // ~0.5 to 1 bit of additional noise, and is thus not done here, so that the noise
        // stays the same as in the ideal functionality.
        // NOTE(review): max(digits - di - 2, 0) limbs are dropped, so res_dft is smallest
        // at di == 0 (the non-`_add` pass) and grows for later digits. Limbs beyond the
        // di == 0 size may then be accumulated into by `vmp_apply_dft_to_dft_add`, and they are
        // read by the final inverse DFT. In the callers visible in this file, res_dft comes from
        // scratch memory, so confirm that the HAL guarantees these limbs are initialized.
        res_dft.set_size(mat.size() - ((digits - di) as isize - 2).max(0) as usize);

        // Transform columns 1.. of `a` for this digit into columns 0.. of ai_dft.
        (0..cols - 1).for_each(|col_i| {
            module.vec_znx_dft_apply(digits, digits - di - 1, &mut ai_dft, col_i, a, col_i + 1);
        });

        if di == 0 {
            // First digit: non-accumulating VMP.
            module.vmp_apply_dft_to_dft(&mut res_dft, &ai_dft, mat, scratch_1);
        } else {
            // Later digits: accumulating VMP; `di` is passed to the HAL, presumably as this
            // digit's limb shift (TODO confirm).
            module.vmp_apply_dft_to_dft_add(&mut res_dft, &ai_dft, mat, di, scratch_1);
        }
    });

    // Restore the full size before the inverse DFT.
    res_dft.set_size(res_dft.max_size());
    let mut res_big: VecZnxBig<DataRes, B> = module.vec_znx_idft_apply_consume(res_dft);
    // Column 0 of `a` did not go through the VMP; add it back into column 0.
    module.vec_znx_big_add_small_inplace(&mut res_big, 0, a, 0);
    res_big
}