// poulpy_core/keyswitching/glwe.rs

1use poulpy_hal::{
2    api::{
3        ModuleN, ScratchAvailable, ScratchTakeBasic, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
4        VecZnxDftApply, VecZnxDftBytesOf, VecZnxDftCopy, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes,
5        VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
6    },
7    layouts::{Backend, DataMut, DataViewMut, Module, Scratch, VecZnxBig, VecZnxDft, VecZnxDftToRef, VmpPMat, ZnxInfos},
8};
9
10use crate::{
11    GLWENormalize, ScratchTakeCore,
12    layouts::{GGLWEInfos, GGLWEPrepared, GGLWEPreparedToRef, GLWE, GLWEInfos, GLWELayout, GLWEToMut, GLWEToRef, LWEInfos},
13};
14
15impl GLWE<Vec<u8>> {
16    pub fn keyswitch_tmp_bytes<M, R, A, B, BE: Backend>(module: &M, res_infos: &R, a_infos: &A, key_infos: &B) -> usize
17    where
18        R: GLWEInfos,
19        A: GLWEInfos,
20        B: GGLWEInfos,
21        M: GLWEKeyswitch<BE>,
22    {
23        module.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos)
24    }
25}
26
27impl<D: DataMut> GLWE<D> {
28    pub fn keyswitch<A, B, M, BE: Backend>(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch<BE>)
29    where
30        A: GLWEToRef + GLWEInfos,
31        B: GGLWEPreparedToRef<BE> + GGLWEInfos,
32        M: GLWEKeyswitch<BE>,
33        Scratch<BE>: ScratchTakeCore<BE>,
34    {
35        module.glwe_keyswitch(self, a, b, scratch);
36    }
37
38    pub fn keyswitch_inplace<A, M, BE: Backend>(&mut self, module: &M, a: &A, scratch: &mut Scratch<BE>)
39    where
40        A: GGLWEPreparedToRef<BE> + GGLWEInfos,
41        M: GLWEKeyswitch<BE>,
42        Scratch<BE>: ScratchTakeCore<BE>,
43    {
44        module.glwe_keyswitch_inplace(self, a, scratch);
45    }
46}
47
48impl<BE: Backend> GLWEKeyswitch<BE> for Module<BE>
49where
50    Self: Sized + GLWEKeySwitchInternal<BE> + VecZnxBigNormalizeTmpBytes + VecZnxBigNormalize<BE> + GLWENormalize<BE>,
51    Scratch<BE>: ScratchTakeCore<BE>,
52{
53    fn glwe_keyswitch_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, key_infos: &B) -> usize
54    where
55        R: GLWEInfos,
56        A: GLWEInfos,
57        B: GGLWEInfos,
58    {
59        let cols: usize = res_infos.rank().as_usize() + 1;
60        let size: usize = if a_infos.base2k() != key_infos.base2k() {
61            let a_conv_infos = &GLWELayout {
62                n: a_infos.n(),
63                base2k: key_infos.base2k(),
64                k: a_infos.k(),
65                rank: a_infos.rank(),
66            };
67            self.glwe_keyswitch_internal_tmp_bytes(res_infos, a_conv_infos, key_infos) + GLWE::bytes_of_from_infos(a_conv_infos)
68        } else {
69            self.glwe_keyswitch_internal_tmp_bytes(res_infos, a_infos, key_infos)
70        };
71
72        size.max(self.vec_znx_big_normalize_tmp_bytes()) + self.bytes_of_vec_znx_dft(cols, key_infos.size())
73    }
74
75    fn glwe_keyswitch<R, A, K>(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch<BE>)
76    where
77        R: GLWEToMut + GLWEInfos,
78        A: GLWEToRef + GLWEInfos,
79        K: GGLWEPreparedToRef<BE> + GGLWEInfos,
80    {
81        assert_eq!(
82            a.rank(),
83            key.rank_in(),
84            "a.rank(): {} != b.rank_in(): {}",
85            a.rank(),
86            key.rank_in()
87        );
88        assert_eq!(
89            res.rank(),
90            key.rank_out(),
91            "res.rank(): {} != b.rank_out(): {}",
92            res.rank(),
93            key.rank_out()
94        );
95
96        assert_eq!(res.n(), self.n() as u32);
97        assert_eq!(a.n(), self.n() as u32);
98        assert_eq!(key.n(), self.n() as u32);
99
100        let scrach_needed: usize = self.glwe_keyswitch_tmp_bytes(res, a, key);
101
102        assert!(
103            scratch.available() >= scrach_needed,
104            "scratch.available()={} < glwe_keyswitch_tmp_bytes={scrach_needed}",
105            scratch.available(),
106        );
107
108        let base2k_a: usize = a.base2k().into();
109        let base2k_key: usize = key.base2k().into();
110        let base2k_res: usize = res.base2k().into();
111
112        let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // Todo optimise
113
114        let res_big: VecZnxBig<&mut [u8], BE> = if base2k_a != base2k_key {
115            let (mut a_conv, scratch_2) = scratch_1.take_glwe(&GLWELayout {
116                n: a.n(),
117                base2k: key.base2k(),
118                k: a.k(),
119                rank: a.rank(),
120            });
121            self.glwe_normalize(&mut a_conv, a, scratch_2);
122            self.glwe_keyswitch_internal(res_dft, &a_conv, key, scratch_2)
123        } else {
124            self.glwe_keyswitch_internal(res_dft, a, key, scratch_1)
125        };
126
127        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
128        for i in 0..(res.rank() + 1).into() {
129            self.vec_znx_big_normalize(
130                base2k_res,
131                res.data_mut(),
132                i,
133                base2k_key,
134                &res_big,
135                i,
136                scratch_1,
137            );
138        }
139    }
140
141    fn glwe_keyswitch_inplace<R, K>(&self, res: &mut R, key: &K, scratch: &mut Scratch<BE>)
142    where
143        R: GLWEToMut + GLWEInfos,
144        K: GGLWEPreparedToRef<BE> + GGLWEInfos,
145    {
146        assert_eq!(
147            res.rank(),
148            key.rank_in(),
149            "res.rank(): {} != a.rank_in(): {}",
150            res.rank(),
151            key.rank_in()
152        );
153        assert_eq!(
154            res.rank(),
155            key.rank_out(),
156            "res.rank(): {} != b.rank_out(): {}",
157            res.rank(),
158            key.rank_out()
159        );
160
161        assert_eq!(res.n(), self.n() as u32);
162        assert_eq!(key.n(), self.n() as u32);
163
164        let scrach_needed: usize = self.glwe_keyswitch_tmp_bytes(res, res, key);
165
166        assert!(
167            scratch.available() >= scrach_needed,
168            "scratch.available()={} < glwe_keyswitch_tmp_bytes={scrach_needed}",
169            scratch.available(),
170        );
171
172        let base2k_res: usize = res.base2k().as_usize();
173        let base2k_key: usize = key.base2k().as_usize();
174
175        let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // Todo optimise
176
177        let res_big: VecZnxBig<&mut [u8], BE> = if base2k_res != base2k_key {
178            let (mut res_conv, scratch_2) = scratch_1.take_glwe(&GLWELayout {
179                n: res.n(),
180                base2k: key.base2k(),
181                k: res.k(),
182                rank: res.rank(),
183            });
184            self.glwe_normalize(&mut res_conv, res, scratch_2);
185
186            self.glwe_keyswitch_internal(res_dft, &res_conv, key, scratch_2)
187        } else {
188            self.glwe_keyswitch_internal(res_dft, res, key, scratch_1)
189        };
190
191        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
192        for i in 0..(res.rank() + 1).into() {
193            self.vec_znx_big_normalize(
194                base2k_res,
195                res.data_mut(),
196                i,
197                base2k_key,
198                &res_big,
199                i,
200                scratch_1,
201            );
202        }
203    }
204}
205
/// Key-switching of GLWE ciphertexts under a prepared GGLWE keyswitch key.
pub trait GLWEKeyswitch<BE: Backend> {
    /// Returns the scratch-space size in bytes required by
    /// [`GLWEKeyswitch::glwe_keyswitch`] / [`GLWEKeyswitch::glwe_keyswitch_inplace`]
    /// for the given result, input and key layouts.
    fn glwe_keyswitch_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, key_infos: &B) -> usize
    where
        R: GLWEInfos,
        A: GLWEInfos,
        B: GGLWEInfos;

    /// Key-switches `a` under `key`, writing the result into `res`.
    fn glwe_keyswitch<R, A, K>(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch<BE>)
    where
        R: GLWEToMut + GLWEInfos,
        A: GLWEToRef + GLWEInfos,
        K: GGLWEPreparedToRef<BE> + GGLWEInfos;

    /// Key-switches `res` under `key` in place.
    fn glwe_keyswitch_inplace<R, K>(&self, res: &mut R, key: &K, scratch: &mut Scratch<BE>)
    where
        R: GLWEToMut + GLWEInfos,
        K: GGLWEPreparedToRef<BE> + GGLWEInfos;
}
224
// Blanket implementation: any `Module` providing the required product, DFT and
// normalization primitives gets the internal keyswitch routine for free (all
// trait methods have default bodies).
impl<BE: Backend> GLWEKeySwitchInternal<BE> for Module<BE> where
    Self: GGLWEProduct<BE>
        + VecZnxDftApply<BE>
        + VecZnxNormalize<BE>
        + VecZnxIdftApplyConsume<BE>
        + VecZnxBigAddSmallInplace<BE>
        + VecZnxNormalizeTmpBytes
{
}
234
/// Internal key-switching routine shared by [`GLWEKeyswitch`] implementations.
///
/// Operates directly on DFT / big-coefficient representations; callers are
/// responsible for base2k conversion of the input and for normalizing the
/// returned result.
pub(crate) trait GLWEKeySwitchInternal<BE: Backend>
where
    Self: GGLWEProduct<BE>
        + VecZnxDftApply<BE>
        + VecZnxNormalize<BE>
        + VecZnxIdftApplyConsume<BE>
        + VecZnxBigAddSmallInplace<BE>
        + VecZnxNormalizeTmpBytes,
{
    /// Scratch bytes needed by [`GLWEKeySwitchInternal::glwe_keyswitch_internal`]:
    /// the GGLWE product scratch plus one DFT buffer holding the `rank` mask
    /// columns (i.e. `cols - 1`) of the input.
    fn glwe_keyswitch_internal_tmp_bytes<R, A, K>(&self, res_infos: &R, a_infos: &A, key_infos: &K) -> usize
    where
        R: GLWEInfos,
        A: GLWEInfos,
        K: GGLWEInfos,
    {
        let cols: usize = (a_infos.rank() + 1).into();
        let a_size: usize = a_infos.size();
        self.gglwe_product_dft_tmp_bytes(res_infos.size(), a_size, key_infos) + self.bytes_of_vec_znx_dft(cols - 1, a_size)
    }

    /// Applies the keyswitch key to the mask columns of `a` (columns `1..=rank`),
    /// then adds the body column (column 0) of `a` onto the result.
    ///
    /// Returns the result in big-coefficient (non-normalized) form. `a` and
    /// `key` must share the same base2k (asserted).
    fn glwe_keyswitch_internal<DR, A, K>(
        &self,
        mut res: VecZnxDft<DR, BE>,
        a: &A,
        key: &K,
        scratch: &mut Scratch<BE>,
    ) -> VecZnxBig<DR, BE>
    where
        DR: DataMut,
        A: GLWEToRef,
        K: GGLWEPreparedToRef<BE>,
        Scratch<BE>: ScratchTakeCore<BE>,
    {
        let a: &GLWE<&[u8]> = &a.to_ref();
        let key: &GGLWEPrepared<&[u8], BE> = &key.to_ref();
        assert_eq!(a.base2k(), key.base2k());
        let cols: usize = (a.rank() + 1).into();
        let a_size: usize = a.size();
        // DFT of the mask columns only: a_dft[col_i] <- DFT(a[col_i + 1]).
        let (mut a_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols - 1, a_size);
        for col_i in 0..cols - 1 {
            self.vec_znx_dft_apply(1, 0, &mut a_dft, col_i, a.data(), col_i + 1);
        }
        // res <- a_dft * key (vector-matrix product in the DFT domain).
        self.gglwe_product_dft(&mut res, &a_dft, key, scratch_1);
        // Back to the coefficient (big) domain, consuming the DFT buffer.
        let mut res_big: VecZnxBig<DR, BE> = self.vec_znx_idft_apply_consume(res);
        // Add the body of `a` onto the body of the result.
        self.vec_znx_big_add_small_inplace(&mut res_big, 0, a.data(), 0);
        res_big
    }
}
283
// Blanket implementation: any `Module` providing the vector-matrix-product (vmp)
// and DFT-copy primitives gets the GGLWE product routine for free (all trait
// methods have default bodies).
impl<BE: Backend> GGLWEProduct<BE> for Module<BE> where
    Self: Sized
        + ModuleN
        + VecZnxDftBytesOf
        + VmpApplyDftToDftTmpBytes
        + VmpApplyDftToDft<BE>
        + VmpApplyDftToDftAdd<BE>
        + VecZnxDftCopy<BE>
{
}
294
/// Vector-matrix product between a DFT-domain GLWE decomposition and a prepared
/// GGLWE matrix, supporting digit sizes (`dsize`) larger than one.
pub(crate) trait GGLWEProduct<BE: Backend>
where
    Self: Sized
        + ModuleN
        + VecZnxDftBytesOf
        + VmpApplyDftToDftTmpBytes
        + VmpApplyDftToDft<BE>
        + VmpApplyDftToDftAdd<BE>
        + VecZnxDftCopy<BE>,
{
    /// Scratch bytes needed by [`GGLWEProduct::gglwe_product_dft`].
    ///
    /// For `dsize > 1` an extra DFT buffer is required to hold the regrouped
    /// input limbs of the current digit.
    fn gglwe_product_dft_tmp_bytes<K>(&self, res_size: usize, a_size: usize, key_infos: &K) -> usize
    where
        K: GGLWEInfos,
    {
        let dsize: usize = key_infos.dsize().as_usize();

        if dsize == 1 {
            self.vmp_apply_dft_to_dft_tmp_bytes(
                res_size,
                a_size,
                key_infos.dnum().into(),
                (key_infos.rank_in()).into(),
                (key_infos.rank_out() + 1).into(),
                key_infos.size(),
            )
        } else {
            let dnum: usize = key_infos.dnum().into();
            // Per-digit input size, bounded by the number of rows of the matrix.
            let a_size: usize = a_size.div_ceil(dsize).min(dnum);
            let ai_dft: usize = self.bytes_of_vec_znx_dft(key_infos.rank_in().into(), a_size);

            let vmp: usize = self.vmp_apply_dft_to_dft_tmp_bytes(
                res_size,
                a_size,
                dnum,
                (key_infos.rank_in()).into(),
                (key_infos.rank_out() + 1).into(),
                key_infos.size(),
            );

            ai_dft + vmp
        }
    }

    /// Computes `res <- a * key` in the DFT domain, where `key` holds a
    /// prepared matrix (`VmpPMat`).
    fn gglwe_product_dft<DR, A, K>(&self, res: &mut VecZnxDft<DR, BE>, a: &A, key: &K, scratch: &mut Scratch<BE>)
    where
        DR: DataMut,
        A: VecZnxDftToRef<BE>,
        K: GGLWEPreparedToRef<BE>,
        Scratch<BE>: ScratchTakeCore<BE>,
    {
        let a: &VecZnxDft<&[u8], BE> = &a.to_ref();
        let key: &GGLWEPrepared<&[u8], BE> = &key.to_ref();

        let cols: usize = a.cols();
        let a_size: usize = a.size();
        let pmat: &VmpPMat<&[u8], BE> = &key.data;

        // If dsize == 1, then the digit decomposition is equal to Base2K and we can
        // simply call the vmp API.
        if key.dsize() == 1 {
            self.vmp_apply_dft_to_dft(res, a, pmat, scratch);
        // If dsize != 1, then the digit decomposition is k * Base2K with k > 1.
        // As such we need to perform a bivariate polynomial convolution in (X, Y) / (X^{N}+1)
        // with Y = 2^-K (instead of a univariate one in X).
        //
        // Since the basis in Y is small (in practice degree 6-7 max), we perform it naively.
        // To do so, we group the limbs of a by their respective degree in Y: digit di takes
        // the limbs at offset dsize - di - 1 with stride dsize (see the vec_znx_dft_copy
        // call below), and we evaluate sum over di of (a_di * pmat * 2^{di*Base2k}).
        } else {
            let dsize: usize = key.dsize().into();
            let dnum: usize = key.dnum().into();

            // We bound ai_dft size by the number of rows of the matrix
            let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols, a_size.div_ceil(dsize).min(dnum));
            ai_dft.data_mut().fill(0);

            for di in 0..dsize {
                // Sets ai_dft size according to the current digit (if dsize does not divides a_size),
                // bounded by the number of rows (digits) in the prepared matrix.
                ai_dft.set_size(((a_size + di) / dsize).min(dnum));

                // Small optimization for dsize > 2
                // VMP produce some error e, and since we aggregate vmp * 2^{di * Base2k}, then
                // we also aggregate ei * 2^{di * Base2k}, with the largest error being ei * 2^{(dsize-1) * Base2k}.
                // As such we can ignore the last dsize-2 limbs safely of the sum of vmp products.
                // It is possible to further ignore the last dsize-1 limbs, but this introduce
                // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same
                // noise is kept with respect to the ideal functionality.
                res.set_size(pmat.size() - ((dsize - di) as isize - 2).max(0) as usize);

                // Gather the limbs of digit di: every dsize-th limb starting at
                // offset dsize - di - 1, for each column.
                for j in 0..cols {
                    self.vec_znx_dft_copy(dsize, dsize - di - 1, &mut ai_dft, j, a, j);
                }

                if di == 0 {
                    // res = pmat * ai_dft
                    self.vmp_apply_dft_to_dft(res, &ai_dft, pmat, scratch_1);
                } else {
                    // res = (pmat * ai_dft) * 2^{di * Base2k}
                    self.vmp_apply_dft_to_dft_add(res, &ai_dft, pmat, di, scratch_1);
                }
            }

            // Restore the full limb count after the per-digit size adjustments above.
            res.set_size(res.max_size());
        }
    }
}