poulpy_core/keyswitching/ggsw_ct.rs

1use poulpy_hal::{
2    api::{
3        ScratchAvailable, TakeVecZnxBig, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize,
4        VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAddInplace, VecZnxDftAllocBytes, VecZnxDftCopy, VecZnxDftFromVecZnx,
5        VecZnxDftToVecZnxBigConsume, VecZnxDftToVecZnxBigTmpA, VecZnxNormalizeTmpBytes, VmpApply, VmpApplyAdd, VmpApplyTmpBytes,
6        ZnxInfos,
7    },
8    layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnx, VmpPMat},
9};
10
11use crate::{
12    layouts::{
13        GGLWECiphertext, GGSWCiphertext, GLWECiphertext, Infos,
14        prepared::{GGLWESwitchingKeyPrepared, GGLWETensorKeyPrepared},
15    },
16    operations::GLWEOperations,
17};
18
19impl GGSWCiphertext<Vec<u8>> {
20    pub(crate) fn expand_row_scratch_space<B: Backend>(
21        module: &Module<B>,
22        n: usize,
23        basek: usize,
24        self_k: usize,
25        k_tsk: usize,
26        digits: usize,
27        rank: usize,
28    ) -> usize
29    where
30        Module<B>: VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxBigAllocBytes + VecZnxNormalizeTmpBytes,
31    {
32        let tsk_size: usize = k_tsk.div_ceil(basek);
33        let self_size_out: usize = self_k.div_ceil(basek);
34        let self_size_in: usize = self_size_out.div_ceil(digits);
35
36        let tmp_dft_i: usize = module.vec_znx_dft_alloc_bytes(n, rank + 1, tsk_size);
37        let tmp_a: usize = module.vec_znx_dft_alloc_bytes(n, 1, self_size_in);
38        let vmp: usize = module.vmp_apply_tmp_bytes(
39            n,
40            self_size_out,
41            self_size_in,
42            self_size_in,
43            rank,
44            rank,
45            tsk_size,
46        );
47        let tmp_idft: usize = module.vec_znx_big_alloc_bytes(n, 1, tsk_size);
48        let norm: usize = module.vec_znx_normalize_tmp_bytes(n);
49        tmp_dft_i + ((tmp_a + vmp) | (tmp_idft + norm))
50    }
51
52    #[allow(clippy::too_many_arguments)]
53    pub fn keyswitch_scratch_space<B: Backend>(
54        module: &Module<B>,
55        n: usize,
56        basek: usize,
57        k_out: usize,
58        k_in: usize,
59        k_ksk: usize,
60        digits_ksk: usize,
61        k_tsk: usize,
62        digits_tsk: usize,
63        rank: usize,
64    ) -> usize
65    where
66        Module<B>:
67            VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxBigAllocBytes + VecZnxNormalizeTmpBytes + VecZnxBigNormalizeTmpBytes,
68    {
69        let out_size: usize = k_out.div_ceil(basek);
70        let res_znx: usize = VecZnx::alloc_bytes(n, rank + 1, out_size);
71        let ci_dft: usize = module.vec_znx_dft_alloc_bytes(n, rank + 1, out_size);
72        let ks: usize = GLWECiphertext::keyswitch_scratch_space(module, n, basek, k_out, k_in, k_ksk, digits_ksk, rank, rank);
73        let expand_rows: usize = GGSWCiphertext::expand_row_scratch_space(module, n, basek, k_out, k_tsk, digits_tsk, rank);
74        let res_dft: usize = module.vec_znx_dft_alloc_bytes(n, rank + 1, out_size);
75        res_znx + ci_dft + (ks | expand_rows | res_dft)
76    }
77
78    #[allow(clippy::too_many_arguments)]
79    pub fn keyswitch_inplace_scratch_space<B: Backend>(
80        module: &Module<B>,
81        n: usize,
82        basek: usize,
83        k_out: usize,
84        k_ksk: usize,
85        digits_ksk: usize,
86        k_tsk: usize,
87        digits_tsk: usize,
88        rank: usize,
89    ) -> usize
90    where
91        Module<B>:
92            VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxBigAllocBytes + VecZnxNormalizeTmpBytes + VecZnxBigNormalizeTmpBytes,
93    {
94        GGSWCiphertext::keyswitch_scratch_space(
95            module, n, basek, k_out, k_out, k_ksk, digits_ksk, k_tsk, digits_tsk, rank,
96        )
97    }
98}
99
100impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
101    pub fn from_gglwe<DataA, DataTsk, B: Backend>(
102        &mut self,
103        module: &Module<B>,
104        a: &GGLWECiphertext<DataA>,
105        tsk: &GGLWETensorKeyPrepared<DataTsk, B>,
106        scratch: &mut Scratch<B>,
107    ) where
108        DataA: DataRef,
109        DataTsk: DataRef,
110        Module<B>: VecZnxCopy
111            + VecZnxDftAllocBytes
112            + VmpApplyTmpBytes
113            + VecZnxBigAllocBytes
114            + VecZnxNormalizeTmpBytes
115            + VecZnxDftFromVecZnx<B>
116            + VecZnxDftCopy<B>
117            + VmpApply<B>
118            + VmpApplyAdd<B>
119            + VecZnxDftAddInplace<B>
120            + VecZnxBigNormalize<B>
121            + VecZnxDftToVecZnxBigTmpA<B>,
122        Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeVecZnxBig<B>,
123    {
124        #[cfg(debug_assertions)]
125        {
126            assert_eq!(self.rank(), a.rank());
127            assert_eq!(self.rows(), a.rows());
128            assert_eq!(self.n(), module.n());
129            assert_eq!(a.n(), module.n());
130            assert_eq!(tsk.n(), module.n());
131        }
132        (0..self.rows()).for_each(|row_i| {
133            self.at_mut(row_i, 0).copy(module, &a.at(row_i, 0));
134        });
135        self.expand_row(module, tsk, scratch);
136    }
137
138    pub fn keyswitch<DataLhs: DataRef, DataKsk: DataRef, DataTsk: DataRef, B: Backend>(
139        &mut self,
140        module: &Module<B>,
141        lhs: &GGSWCiphertext<DataLhs>,
142        ksk: &GGLWESwitchingKeyPrepared<DataKsk, B>,
143        tsk: &GGLWETensorKeyPrepared<DataTsk, B>,
144        scratch: &mut Scratch<B>,
145    ) where
146        Module<B>: VecZnxDftAllocBytes
147            + VmpApplyTmpBytes
148            + VecZnxBigNormalizeTmpBytes
149            + VmpApply<B>
150            + VmpApplyAdd<B>
151            + VecZnxDftFromVecZnx<B>
152            + VecZnxDftToVecZnxBigConsume<B>
153            + VecZnxBigAddSmallInplace<B>
154            + VecZnxBigNormalize<B>
155            + VecZnxDftAllocBytes
156            + VecZnxBigAllocBytes
157            + VecZnxNormalizeTmpBytes
158            + VecZnxDftCopy<B>
159            + VecZnxDftAddInplace<B>
160            + VecZnxDftToVecZnxBigTmpA<B>,
161        Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeVecZnxBig<B>,
162    {
163        self.keyswitch_internal(module, lhs, ksk, scratch);
164        self.expand_row(module, tsk, scratch);
165    }
166
167    pub fn keyswitch_inplace<DataKsk: DataRef, DataTsk: DataRef, B: Backend>(
168        &mut self,
169        module: &Module<B>,
170        ksk: &GGLWESwitchingKeyPrepared<DataKsk, B>,
171        tsk: &GGLWETensorKeyPrepared<DataTsk, B>,
172        scratch: &mut Scratch<B>,
173    ) where
174        Module<B>: VecZnxDftAllocBytes
175            + VmpApplyTmpBytes
176            + VecZnxBigNormalizeTmpBytes
177            + VmpApply<B>
178            + VmpApplyAdd<B>
179            + VecZnxDftFromVecZnx<B>
180            + VecZnxDftToVecZnxBigConsume<B>
181            + VecZnxBigAddSmallInplace<B>
182            + VecZnxBigNormalize<B>
183            + VecZnxDftAllocBytes
184            + VecZnxBigAllocBytes
185            + VecZnxNormalizeTmpBytes
186            + VecZnxDftCopy<B>
187            + VecZnxDftAddInplace<B>
188            + VecZnxDftToVecZnxBigTmpA<B>,
189        Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeVecZnxBig<B>,
190    {
191        unsafe {
192            let self_ptr: *mut GGSWCiphertext<DataSelf> = self as *mut GGSWCiphertext<DataSelf>;
193            self.keyswitch(module, &*self_ptr, ksk, tsk, scratch);
194        }
195    }
196
197    pub fn expand_row<DataTsk: DataRef, B: Backend>(
198        &mut self,
199        module: &Module<B>,
200        tsk: &GGLWETensorKeyPrepared<DataTsk, B>,
201        scratch: &mut Scratch<B>,
202    ) where
203        Module<B>: VecZnxDftAllocBytes
204            + VmpApplyTmpBytes
205            + VecZnxBigAllocBytes
206            + VecZnxNormalizeTmpBytes
207            + VecZnxDftFromVecZnx<B>
208            + VecZnxDftCopy<B>
209            + VmpApply<B>
210            + VmpApplyAdd<B>
211            + VecZnxDftAddInplace<B>
212            + VecZnxBigNormalize<B>
213            + VecZnxDftToVecZnxBigTmpA<B>,
214        Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeVecZnxBig<B>,
215    {
216        assert!(
217            scratch.available()
218                >= GGSWCiphertext::expand_row_scratch_space(
219                    module,
220                    self.n(),
221                    self.basek(),
222                    self.k(),
223                    tsk.k(),
224                    tsk.digits(),
225                    tsk.rank()
226                )
227        );
228
229        let n: usize = self.n();
230        let rank: usize = self.rank();
231        let cols: usize = rank + 1;
232
233        // Keyswitch the j-th row of the col 0
234        (0..self.rows()).for_each(|row_i| {
235            // Pre-compute DFT of (a0, a1, a2)
236            let (mut ci_dft, scratch1) = scratch.take_vec_znx_dft(n, cols, self.size());
237            (0..cols).for_each(|i| {
238                module.vec_znx_dft_from_vec_znx(1, 0, &mut ci_dft, i, &self.at(row_i, 0).data, i);
239            });
240
241            (1..cols).for_each(|col_j| {
242                // Example for rank 3:
243                //
244                // Note: M is a vector (m, Bm, B^2m, B^3m, ...), so each column is
245                // actually composed of that many rows and we focus on a specific row here
246                // implicitely given ci_dft.
247                //
248                // # Input
249                //
250                // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0    , a1    , a2    )
251                // col 1: (0, 0, 0, 0)
252                // col 2: (0, 0, 0, 0)
253                // col 3: (0, 0, 0, 0)
254                //
255                // # Output
256                //
257                // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0       , a1       , a2       )
258                // col 1: (-(b0s0 + b1s1 + b2s2)       , b0 + M[i], b1       , b2       )
259                // col 2: (-(c0s0 + c1s1 + c2s2)       , c0       , c1 + M[i], c2       )
260                // col 3: (-(d0s0 + d1s1 + d2s2)       , d0       , d1       , d2 + M[i])
261
262                let digits: usize = tsk.digits();
263
264                let (mut tmp_dft_i, scratch2) = scratch1.take_vec_znx_dft(n, cols, tsk.size());
265                let (mut tmp_a, scratch3) = scratch2.take_vec_znx_dft(n, 1, ci_dft.size().div_ceil(digits));
266
267                {
268                    // Performs a key-switch for each combination of s[i]*s[j], i.e. for a0, a1, a2
269                    //
270                    // # Example for col=1
271                    //
272                    // a0 * (-(f0s0 + f1s1 + f1s2) + s0^2, f0, f1, f2) = (-(a0f0s0 + a0f1s1 + a0f1s2) + a0s0^2, a0f0, a0f1, a0f2)
273                    // +
274                    // a1 * (-(g0s0 + g1s1 + g1s2) + s0s1, g0, g1, g2) = (-(a1g0s0 + a1g1s1 + a1g1s2) + a1s0s1, a1g0, a1g1, a1g2)
275                    // +
276                    // a2 * (-(h0s0 + h1s1 + h1s2) + s0s2, h0, h1, h2) = (-(a2h0s0 + a2h1s1 + a2h1s2) + a2s0s2, a2h0, a2h1, a2h2)
277                    // =
278                    // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0, x1, x2)
279                    (1..cols).for_each(|col_i| {
280                        let pmat: &VmpPMat<DataTsk, B> = &tsk.at(col_i - 1, col_j - 1).key.data; // Selects Enc(s[i]s[j])
281
282                        // Extracts a[i] and multipies with Enc(s[i]s[j])
283                        (0..digits).for_each(|di| {
284                            tmp_a.set_size((ci_dft.size() + di) / digits);
285
286                            // Small optimization for digits > 2
287                            // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then
288                            // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}.
289                            // As such we can ignore the last digits-2 limbs safely of the sum of vmp products.
290                            // It is possible to further ignore the last digits-1 limbs, but this introduce
291                            // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same
292                            // noise is kept with respect to the ideal functionality.
293                            tmp_dft_i.set_size(tsk.size() - ((digits - di) as isize - 2).max(0) as usize);
294
295                            module.vec_znx_dft_copy(digits, digits - 1 - di, &mut tmp_a, 0, &ci_dft, col_i);
296                            if di == 0 && col_i == 1 {
297                                module.vmp_apply(&mut tmp_dft_i, &tmp_a, pmat, scratch3);
298                            } else {
299                                module.vmp_apply_add(&mut tmp_dft_i, &tmp_a, pmat, di, scratch3);
300                            }
301                        });
302                    });
303                }
304
305                // Adds -(sum a[i] * s[i]) + m)  on the i-th column of tmp_idft_i
306                //
307                // (-(x0s0 + x1s1 + x2s2) + a0s0s0 + a1s0s1 + a2s0s2, x0, x1, x2)
308                // +
309                // (0, -(a0s0 + a1s1 + a2s2) + M[i], 0, 0)
310                // =
311                // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0 -(a0s0 + a1s1 + a2s2) + M[i], x1, x2)
312                // =
313                // (-(x0s0 + x1s1 + x2s2), x0 + M[i], x1, x2)
314                module.vec_znx_dft_add_inplace(&mut tmp_dft_i, col_j, &ci_dft, 0);
315                let (mut tmp_idft, scratch3) = scratch2.take_vec_znx_big(n, 1, tsk.size());
316                (0..cols).for_each(|i| {
317                    module.vec_znx_dft_to_vec_znx_big_tmp_a(&mut tmp_idft, 0, &mut tmp_dft_i, i);
318                    module.vec_znx_big_normalize(
319                        self.basek(),
320                        &mut self.at_mut(row_i, col_j).data,
321                        i,
322                        &tmp_idft,
323                        0,
324                        scratch3,
325                    );
326                });
327            })
328        })
329    }
330
331    fn keyswitch_internal<DataLhs: DataRef, DataKsk: DataRef, B: Backend>(
332        &mut self,
333        module: &Module<B>,
334        lhs: &GGSWCiphertext<DataLhs>,
335        ksk: &GGLWESwitchingKeyPrepared<DataKsk, B>,
336        scratch: &mut Scratch<B>,
337    ) where
338        Module<B>: VecZnxDftAllocBytes
339            + VmpApplyTmpBytes
340            + VecZnxBigNormalizeTmpBytes
341            + VmpApply<B>
342            + VmpApplyAdd<B>
343            + VecZnxDftFromVecZnx<B>
344            + VecZnxDftToVecZnxBigConsume<B>
345            + VecZnxBigAddSmallInplace<B>
346            + VecZnxBigNormalize<B>,
347        Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
348    {
349        // Keyswitch the j-th row of the col 0
350        (0..lhs.rows()).for_each(|row_i| {
351            // Key-switch column 0, i.e.
352            // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2)
353            self.at_mut(row_i, 0)
354                .keyswitch(module, &lhs.at(row_i, 0), ksk, scratch);
355        })
356    }
357}