poulpy_core/keyswitching/ggsw_ct.rs

1use poulpy_hal::{
2    api::{
3        ScratchAvailable, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes,
4        VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAddInplace, VecZnxDftAllocBytes, VecZnxDftApply,
5        VecZnxDftCopy, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft,
6        VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
7    },
8    layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnx, VmpPMat, ZnxInfos},
9};
10
11use crate::{
12    layouts::{
13        GGLWECiphertext, GGLWELayoutInfos, GGSWCiphertext, GGSWInfos, GLWECiphertext, GLWEInfos, LWEInfos,
14        prepared::{GGLWESwitchingKeyPrepared, GGLWETensorKeyPrepared},
15    },
16    operations::GLWEOperations,
17};
18
19impl GGSWCiphertext<Vec<u8>> {
20    pub(crate) fn expand_row_scratch_space<B: Backend, OUT, TSK>(module: &Module<B>, out_infos: &OUT, tsk_infos: &TSK) -> usize
21    where
22        OUT: GGSWInfos,
23        TSK: GGLWELayoutInfos,
24        Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigAllocBytes + VecZnxNormalizeTmpBytes,
25    {
26        let tsk_size: usize = tsk_infos.k().div_ceil(tsk_infos.base2k()) as usize;
27        let size_in: usize = out_infos
28            .k()
29            .div_ceil(tsk_infos.base2k())
30            .div_ceil(tsk_infos.digits().into()) as usize;
31
32        let tmp_dft_i: usize = module.vec_znx_dft_alloc_bytes((tsk_infos.rank_out() + 1).into(), tsk_size);
33        let tmp_a: usize = module.vec_znx_dft_alloc_bytes(1, size_in);
34        let vmp: usize = module.vmp_apply_dft_to_dft_tmp_bytes(
35            tsk_size,
36            size_in,
37            size_in,
38            (tsk_infos.rank_in()).into(),  // Verify if rank+1
39            (tsk_infos.rank_out()).into(), // Verify if rank+1
40            tsk_size,
41        );
42        let tmp_idft: usize = module.vec_znx_big_alloc_bytes(1, tsk_size);
43        let norm: usize = module.vec_znx_normalize_tmp_bytes();
44
45        tmp_dft_i + ((tmp_a + vmp) | (tmp_idft + norm))
46    }
47
48    #[allow(clippy::too_many_arguments)]
49    pub fn keyswitch_scratch_space<B: Backend, OUT, IN, KEY, TSK>(
50        module: &Module<B>,
51        out_infos: &OUT,
52        in_infos: &IN,
53        apply_infos: &KEY,
54        tsk_infos: &TSK,
55    ) -> usize
56    where
57        OUT: GGSWInfos,
58        IN: GGSWInfos,
59        KEY: GGLWELayoutInfos,
60        TSK: GGLWELayoutInfos,
61        Module<B>: VecZnxDftAllocBytes
62            + VmpApplyDftToDftTmpBytes
63            + VecZnxBigAllocBytes
64            + VecZnxNormalizeTmpBytes
65            + VecZnxBigNormalizeTmpBytes,
66    {
67        #[cfg(debug_assertions)]
68        {
69            assert_eq!(apply_infos.rank_in(), apply_infos.rank_out());
70            assert_eq!(tsk_infos.rank_in(), tsk_infos.rank_out());
71            assert_eq!(apply_infos.rank_in(), tsk_infos.rank_in());
72        }
73
74        let rank: usize = apply_infos.rank_out().into();
75
76        let size_out: usize = out_infos.k().div_ceil(out_infos.base2k()) as usize;
77        let res_znx: usize = VecZnx::alloc_bytes(module.n(), rank + 1, size_out);
78        let ci_dft: usize = module.vec_znx_dft_alloc_bytes(rank + 1, size_out);
79        let ks: usize = GLWECiphertext::keyswitch_scratch_space(module, out_infos, in_infos, apply_infos);
80        let expand_rows: usize = GGSWCiphertext::expand_row_scratch_space(module, out_infos, tsk_infos);
81        let res_dft: usize = module.vec_znx_dft_alloc_bytes(rank + 1, size_out);
82
83        if in_infos.base2k() == tsk_infos.base2k() {
84            res_znx + ci_dft + (ks | expand_rows | res_dft)
85        } else {
86            let a_conv: usize = VecZnx::alloc_bytes(
87                module.n(),
88                1,
89                out_infos.k().div_ceil(tsk_infos.base2k()) as usize,
90            ) + module.vec_znx_normalize_tmp_bytes();
91            res_znx + ci_dft + (a_conv | ks | expand_rows | res_dft)
92        }
93    }
94
95    #[allow(clippy::too_many_arguments)]
96    pub fn keyswitch_inplace_scratch_space<B: Backend, OUT, KEY, TSK>(
97        module: &Module<B>,
98        out_infos: &OUT,
99        apply_infos: &KEY,
100        tsk_infos: &TSK,
101    ) -> usize
102    where
103        OUT: GGSWInfos,
104        KEY: GGLWELayoutInfos,
105        TSK: GGLWELayoutInfos,
106        Module<B>: VecZnxDftAllocBytes
107            + VmpApplyDftToDftTmpBytes
108            + VecZnxBigAllocBytes
109            + VecZnxNormalizeTmpBytes
110            + VecZnxBigNormalizeTmpBytes,
111    {
112        GGSWCiphertext::keyswitch_scratch_space(module, out_infos, out_infos, apply_infos, tsk_infos)
113    }
114}
115
116impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
117    pub fn from_gglwe<DataA, DataTsk, B: Backend>(
118        &mut self,
119        module: &Module<B>,
120        a: &GGLWECiphertext<DataA>,
121        tsk: &GGLWETensorKeyPrepared<DataTsk, B>,
122        scratch: &mut Scratch<B>,
123    ) where
124        DataA: DataRef,
125        DataTsk: DataRef,
126        Module<B>: VecZnxCopy
127            + VecZnxDftAllocBytes
128            + VmpApplyDftToDftTmpBytes
129            + VecZnxBigAllocBytes
130            + VecZnxNormalizeTmpBytes
131            + VecZnxDftApply<B>
132            + VecZnxDftCopy<B>
133            + VmpApplyDftToDft<B>
134            + VmpApplyDftToDftAdd<B>
135            + VecZnxDftAddInplace<B>
136            + VecZnxBigNormalize<B>
137            + VecZnxIdftApplyTmpA<B>
138            + VecZnxNormalize<B>,
139        Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeVecZnxBig<B> + TakeVecZnx,
140    {
141        #[cfg(debug_assertions)]
142        {
143            use crate::layouts::{GLWEInfos, LWEInfos};
144
145            assert_eq!(self.rank(), a.rank_out());
146            assert_eq!(self.rows(), a.rows());
147            assert_eq!(self.n(), module.n() as u32);
148            assert_eq!(a.n(), module.n() as u32);
149            assert_eq!(tsk.n(), module.n() as u32);
150        }
151        (0..self.rows().into()).for_each(|row_i| {
152            self.at_mut(row_i, 0).copy(module, &a.at(row_i, 0));
153        });
154        self.expand_row(module, tsk, scratch);
155    }
156
157    pub fn keyswitch<DataLhs: DataRef, DataKsk: DataRef, DataTsk: DataRef, B: Backend>(
158        &mut self,
159        module: &Module<B>,
160        lhs: &GGSWCiphertext<DataLhs>,
161        ksk: &GGLWESwitchingKeyPrepared<DataKsk, B>,
162        tsk: &GGLWETensorKeyPrepared<DataTsk, B>,
163        scratch: &mut Scratch<B>,
164    ) where
165        Module<B>: VecZnxDftAllocBytes
166            + VmpApplyDftToDftTmpBytes
167            + VecZnxBigNormalizeTmpBytes
168            + VmpApplyDftToDft<B>
169            + VmpApplyDftToDftAdd<B>
170            + VecZnxDftApply<B>
171            + VecZnxIdftApplyConsume<B>
172            + VecZnxBigAddSmallInplace<B>
173            + VecZnxBigNormalize<B>
174            + VecZnxDftAllocBytes
175            + VecZnxBigAllocBytes
176            + VecZnxNormalizeTmpBytes
177            + VecZnxDftCopy<B>
178            + VecZnxDftAddInplace<B>
179            + VecZnxIdftApplyTmpA<B>
180            + VecZnxNormalize<B>,
181        Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeVecZnxBig<B> + TakeVecZnx,
182    {
183        (0..lhs.rows().into()).for_each(|row_i| {
184            // Key-switch column 0, i.e.
185            // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2)
186            self.at_mut(row_i, 0)
187                .keyswitch(module, &lhs.at(row_i, 0), ksk, scratch);
188        });
189        self.expand_row(module, tsk, scratch);
190    }
191
192    pub fn keyswitch_inplace<DataKsk: DataRef, DataTsk: DataRef, B: Backend>(
193        &mut self,
194        module: &Module<B>,
195        ksk: &GGLWESwitchingKeyPrepared<DataKsk, B>,
196        tsk: &GGLWETensorKeyPrepared<DataTsk, B>,
197        scratch: &mut Scratch<B>,
198    ) where
199        Module<B>: VecZnxDftAllocBytes
200            + VmpApplyDftToDftTmpBytes
201            + VecZnxBigNormalizeTmpBytes
202            + VmpApplyDftToDft<B>
203            + VmpApplyDftToDftAdd<B>
204            + VecZnxDftApply<B>
205            + VecZnxIdftApplyConsume<B>
206            + VecZnxBigAddSmallInplace<B>
207            + VecZnxBigNormalize<B>
208            + VecZnxDftAllocBytes
209            + VecZnxBigAllocBytes
210            + VecZnxNormalizeTmpBytes
211            + VecZnxDftCopy<B>
212            + VecZnxDftAddInplace<B>
213            + VecZnxIdftApplyTmpA<B>
214            + VecZnxNormalize<B>,
215        Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeVecZnxBig<B> + TakeVecZnx,
216    {
217        (0..self.rows().into()).for_each(|row_i| {
218            // Key-switch column 0, i.e.
219            // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2)
220            self.at_mut(row_i, 0)
221                .keyswitch_inplace(module, ksk, scratch);
222        });
223        self.expand_row(module, tsk, scratch);
224    }
225
226    pub fn expand_row<DataTsk: DataRef, B: Backend>(
227        &mut self,
228        module: &Module<B>,
229        tsk: &GGLWETensorKeyPrepared<DataTsk, B>,
230        scratch: &mut Scratch<B>,
231    ) where
232        Module<B>: VecZnxDftAllocBytes
233            + VmpApplyDftToDftTmpBytes
234            + VecZnxBigAllocBytes
235            + VecZnxNormalizeTmpBytes
236            + VecZnxDftApply<B>
237            + VecZnxDftCopy<B>
238            + VmpApplyDftToDft<B>
239            + VmpApplyDftToDftAdd<B>
240            + VecZnxDftAddInplace<B>
241            + VecZnxBigNormalize<B>
242            + VecZnxIdftApplyTmpA<B>
243            + VecZnxNormalize<B>,
244        Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeVecZnxBig<B> + TakeVecZnx,
245    {
246        let basek_in: usize = self.base2k().into();
247        let basek_tsk: usize = tsk.base2k().into();
248
249        assert!(scratch.available() >= GGSWCiphertext::expand_row_scratch_space(module, self, tsk));
250
251        let n: usize = self.n().into();
252        let rank: usize = self.rank().into();
253        let cols: usize = rank + 1;
254
255        let a_size: usize = (self.size() * basek_in).div_ceil(basek_tsk);
256
257        // Keyswitch the j-th row of the col 0
258        for row_i in 0..self.rows().into() {
259            let a = &self.at(row_i, 0).data;
260
261            // Pre-compute DFT of (a0, a1, a2)
262            let (mut ci_dft, scratch_1) = scratch.take_vec_znx_dft(n, cols, a_size);
263
264            if basek_in == basek_tsk {
265                for i in 0..cols {
266                    module.vec_znx_dft_apply(1, 0, &mut ci_dft, i, a, i);
267                }
268            } else {
269                let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(n, 1, a_size);
270                for i in 0..cols {
271                    module.vec_znx_normalize(basek_tsk, &mut a_conv, 0, basek_in, a, i, scratch_2);
272                    module.vec_znx_dft_apply(1, 0, &mut ci_dft, i, &a_conv, 0);
273                }
274            }
275
276            for col_j in 1..cols {
277                // Example for rank 3:
278                //
279                // Note: M is a vector (m, Bm, B^2m, B^3m, ...), so each column is
280                // actually composed of that many rows and we focus on a specific row here
281                // implicitely given ci_dft.
282                //
283                // # Input
284                //
285                // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0    , a1    , a2    )
286                // col 1: (0, 0, 0, 0)
287                // col 2: (0, 0, 0, 0)
288                // col 3: (0, 0, 0, 0)
289                //
290                // # Output
291                //
292                // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0       , a1       , a2       )
293                // col 1: (-(b0s0 + b1s1 + b2s2)       , b0 + M[i], b1       , b2       )
294                // col 2: (-(c0s0 + c1s1 + c2s2)       , c0       , c1 + M[i], c2       )
295                // col 3: (-(d0s0 + d1s1 + d2s2)       , d0       , d1       , d2 + M[i])
296
297                let digits: usize = tsk.digits().into();
298
299                let (mut tmp_dft_i, scratch_2) = scratch_1.take_vec_znx_dft(n, cols, tsk.size());
300                let (mut tmp_a, scratch_3) = scratch_2.take_vec_znx_dft(n, 1, ci_dft.size().div_ceil(digits));
301
302                {
303                    // Performs a key-switch for each combination of s[i]*s[j], i.e. for a0, a1, a2
304                    //
305                    // # Example for col=1
306                    //
307                    // a0 * (-(f0s0 + f1s1 + f1s2) + s0^2, f0, f1, f2) = (-(a0f0s0 + a0f1s1 + a0f1s2) + a0s0^2, a0f0, a0f1, a0f2)
308                    // +
309                    // a1 * (-(g0s0 + g1s1 + g1s2) + s0s1, g0, g1, g2) = (-(a1g0s0 + a1g1s1 + a1g1s2) + a1s0s1, a1g0, a1g1, a1g2)
310                    // +
311                    // a2 * (-(h0s0 + h1s1 + h1s2) + s0s2, h0, h1, h2) = (-(a2h0s0 + a2h1s1 + a2h1s2) + a2s0s2, a2h0, a2h1, a2h2)
312                    // =
313                    // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0, x1, x2)
314                    for col_i in 1..cols {
315                        let pmat: &VmpPMat<DataTsk, B> = &tsk.at(col_i - 1, col_j - 1).key.data; // Selects Enc(s[i]s[j])
316
317                        // Extracts a[i] and multipies with Enc(s[i]s[j])
318                        for di in 0..digits {
319                            tmp_a.set_size((ci_dft.size() + di) / digits);
320
321                            // Small optimization for digits > 2
322                            // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then
323                            // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}.
324                            // As such we can ignore the last digits-2 limbs safely of the sum of vmp products.
325                            // It is possible to further ignore the last digits-1 limbs, but this introduce
326                            // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same
327                            // noise is kept with respect to the ideal functionality.
328                            tmp_dft_i.set_size(tsk.size() - ((digits - di) as isize - 2).max(0) as usize);
329
330                            module.vec_znx_dft_copy(digits, digits - 1 - di, &mut tmp_a, 0, &ci_dft, col_i);
331                            if di == 0 && col_i == 1 {
332                                module.vmp_apply_dft_to_dft(&mut tmp_dft_i, &tmp_a, pmat, scratch_3);
333                            } else {
334                                module.vmp_apply_dft_to_dft_add(&mut tmp_dft_i, &tmp_a, pmat, di, scratch_3);
335                            }
336                        }
337                    }
338                }
339
340                // Adds -(sum a[i] * s[i]) + m)  on the i-th column of tmp_idft_i
341                //
342                // (-(x0s0 + x1s1 + x2s2) + a0s0s0 + a1s0s1 + a2s0s2, x0, x1, x2)
343                // +
344                // (0, -(a0s0 + a1s1 + a2s2) + M[i], 0, 0)
345                // =
346                // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0 -(a0s0 + a1s1 + a2s2) + M[i], x1, x2)
347                // =
348                // (-(x0s0 + x1s1 + x2s2), x0 + M[i], x1, x2)
349                module.vec_znx_dft_add_inplace(&mut tmp_dft_i, col_j, &ci_dft, 0);
350                let (mut tmp_idft, scratch_3) = scratch_2.take_vec_znx_big(n, 1, tsk.size());
351                for i in 0..cols {
352                    module.vec_znx_idft_apply_tmpa(&mut tmp_idft, 0, &mut tmp_dft_i, i);
353                    module.vec_znx_big_normalize(
354                        basek_in,
355                        &mut self.at_mut(row_i, col_j).data,
356                        i,
357                        basek_tsk,
358                        &tmp_idft,
359                        0,
360                        scratch_3,
361                    );
362                }
363            }
364        }
365    }
366}