poulpy_core/keyswitching/
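//! GLWE key-switching: re-encrypts a GLWE ciphertext under a new secret key using a
//! prepared GGLWE switching key, via vector-matrix products (VMP) in the DFT domain.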

use poulpy_hal::{
    api::{
        ModuleN, ScratchAvailable, ScratchTakeBasic, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
        VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft,
        VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
    },
    layouts::{Backend, DataMut, DataViewMut, Module, Scratch, VecZnx, VecZnxBig, VecZnxDft, VmpPMat, ZnxInfos},
};

use crate::{
    ScratchTakeCore,
    layouts::{GGLWEInfos, GGLWEPrepared, GGLWEPreparedToRef, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos},
};
impl GLWE<Vec<u8>> {
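    /// Returns the scratch space (in bytes) required by [`GLWE::keyswitch`] for the given
    /// output (`res_infos`), input (`a_infos`) and switching-key (`key_infos`) layouts.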
    pub fn keyswitch_tmp_bytes<M, R, A, B, BE: Backend>(module: &M, res_infos: &R, a_infos: &A, key_infos: &B) -> usize
    where
        R: GLWEInfos,
        A: GLWEInfos,
        B: GGLWEInfos,
        M: GLWEKeyswitch<BE>,
    {
        module.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos)
    }
}

impl<D: DataMut> GLWE<D> {
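    /// Key-switches the GLWE ciphertext `a` with the prepared switching key `b`,
    /// writing the result into `self`.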
    pub fn keyswitch<A, B, M, BE: Backend>(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch<BE>)
    where
        A: GLWEToRef,
        B: GGLWEPreparedToRef<BE>,
        M: GLWEKeyswitch<BE>,
        Scratch<BE>: ScratchTakeCore<BE>,
    {
        module.glwe_keyswitch(self, a, b, scratch);
    }
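
    /// Key-switches `self` in place with the prepared switching key `a`.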
    pub fn keyswitch_inplace<A, M, BE: Backend>(&mut self, module: &M, a: &A, scratch: &mut Scratch<BE>)
    where
        A: GGLWEPreparedToRef<BE>,
        M: GLWEKeyswitch<BE>,
        Scratch<BE>: ScratchTakeCore<BE>,
    {
        module.glwe_keyswitch_inplace(self, a, scratch);
    }
}

impl<BE: Backend> GLWEKeyswitch<BE> for Module<BE> where
    Self: Sized
        + ModuleN
        + VecZnxDftBytesOf
        + VmpApplyDftToDftTmpBytes
        + VecZnxBigNormalizeTmpBytes
        + VecZnxNormalizeTmpBytes
        + VmpApplyDftToDft<BE>
        + VmpApplyDftToDftAdd<BE>
        + VecZnxDftApply<BE>
        + VecZnxIdftApplyConsume<BE>
        + VecZnxBigAddSmallInplace<BE>
        + VecZnxBigNormalize<BE>
        + VecZnxNormalize<BE>
{
}
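
/// Backend-level GLWE key-switching.
///
/// The input mask is lifted to the DFT domain, multiplied with the prepared key as a
/// vector-matrix product (VMP), brought back to coefficient space, and normalized to
/// the output base-2^k decomposition.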
pub trait GLWEKeyswitch<BE: Backend>
where
    Self: Sized
        + ModuleN
        + VecZnxDftBytesOf
        + VmpApplyDftToDftTmpBytes
        + VecZnxBigNormalizeTmpBytes
        + VecZnxNormalizeTmpBytes
        + VmpApplyDftToDft<BE>
        + VmpApplyDftToDftAdd<BE>
        + VecZnxDftApply<BE>
        + VecZnxIdftApplyConsume<BE>
        + VecZnxBigAddSmallInplace<BE>
        + VecZnxBigNormalize<BE>
        + VecZnxNormalize<BE>,
{
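    /// Returns the scratch space (in bytes) consumed by [`GLWEKeyswitch::glwe_keyswitch`].
    ///
    /// The bound covers the DFT accumulators and VMP temporaries, plus, when the input
    /// and key differ in base-2^k, a buffer for the base conversion.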
    fn glwe_keyswitch_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, key_infos: &B) -> usize
    where
        R: GLWEInfos,
        A: GLWEInfos,
        B: GGLWEInfos,
    {
        let in_size: usize = a_infos
            .k()
            .div_ceil(key_infos.base2k())
            .div_ceil(key_infos.dsize().into()) as usize;
        let out_size: usize = res_infos.size();
        let ksk_size: usize = key_infos.size();
        let res_dft: usize = self.bytes_of_vec_znx_dft((key_infos.rank_out() + 1).into(), ksk_size); // TODO: optimize
        let ai_dft: usize = self.bytes_of_vec_znx_dft((key_infos.rank_in()).into(), in_size);
        let vmp: usize = self.vmp_apply_dft_to_dft_tmp_bytes(
            out_size,
            in_size,
            in_size,
            (key_infos.rank_in()).into(),
            (key_infos.rank_out() + 1).into(),
            ksk_size,
        ) + self.bytes_of_vec_znx_dft((key_infos.rank_in()).into(), in_size);
        let normalize_big: usize = self.vec_znx_big_normalize_tmp_bytes();
        if a_infos.base2k() == key_infos.base2k() {
            res_dft + ((ai_dft + vmp) | normalize_big)
        } else if key_infos.dsize() == 1 {
            // A single temporary column suffices for the base conversion; it can be
            // dropped once ai_dft has been computed.
            let normalize_conv: usize = VecZnx::bytes_of(self.n(), 1, in_size) + self.vec_znx_normalize_tmp_bytes();
            res_dft + (((ai_dft + normalize_conv) | vmp) | normalize_big)
        } else {
            // Since we stride over a to compute ai_dft when dsize > 1, the base
            // conversion needs to store all columns of a.
            let normalize_conv: usize = VecZnx::bytes_of(self.n(), (key_infos.rank_in()).into(), in_size);
            res_dft + ((ai_dft + normalize_conv + (self.vec_znx_normalize_tmp_bytes() | vmp)) | normalize_big)
        }
    }
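
    /// Key-switches `a` with the prepared switching key `key`, writing the result into `res`.
    ///
    /// # Panics
    /// If the ranks or ring degrees are inconsistent, or if `scratch` holds fewer bytes
    /// than [`GLWEKeyswitch::glwe_keyswitch_tmp_bytes`].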
    fn glwe_keyswitch<R, A, K>(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch<BE>)
    where
        R: GLWEToMut,
        A: GLWEToRef,
        K: GGLWEPreparedToRef<BE>,
        Scratch<BE>: ScratchTakeCore<BE>,
    {
        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
        let a: &GLWE<&[u8]> = &a.to_ref();
        let key: &GGLWEPrepared<&[u8], BE> = &key.to_ref();

        assert_eq!(
            a.rank(),
            key.rank_in(),
            "a.rank(): {} != key.rank_in(): {}",
            a.rank(),
            key.rank_in()
        );
        assert_eq!(
            res.rank(),
            key.rank_out(),
            "res.rank(): {} != key.rank_out(): {}",
            res.rank(),
            key.rank_out()
        );

        assert_eq!(res.n(), self.n() as u32);
        assert_eq!(a.n(), self.n() as u32);
        assert_eq!(key.n(), self.n() as u32);

        let scratch_needed: usize = self.glwe_keyswitch_tmp_bytes(res, a, key);

        assert!(
            scratch.available() >= scratch_needed,
            "scratch.available()={} < glwe_keyswitch_tmp_bytes={scratch_needed}",
            scratch.available(),
        );

        let base2k_res: usize = res.base2k().into();
        let base2k_key: usize = key.base2k().into();

        let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimize
        let res_big: VecZnxBig<&mut [u8], BE> = keyswitch_internal(self, res_dft, a, key, scratch_1);
        (0..(res.rank() + 1).into()).for_each(|i| {
            self.vec_znx_big_normalize(
                base2k_res,
                &mut res.data,
                i,
                base2k_key,
                &res_big,
                i,
                scratch_1,
            );
        })
    }
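
    /// Key-switches `res` in place with the prepared switching key `key`; equivalent to
    /// [`GLWEKeyswitch::glwe_keyswitch`] with `res` as both input and output.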
    fn glwe_keyswitch_inplace<R, K>(&self, res: &mut R, key: &K, scratch: &mut Scratch<BE>)
    where
        R: GLWEToMut,
        K: GGLWEPreparedToRef<BE>,
        Scratch<BE>: ScratchTakeCore<BE>,
    {
        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
        let key: &GGLWEPrepared<&[u8], BE> = &key.to_ref();

        assert_eq!(
            res.rank(),
            key.rank_in(),
            "res.rank(): {} != key.rank_in(): {}",
            res.rank(),
            key.rank_in()
        );
        assert_eq!(
            res.rank(),
            key.rank_out(),
            "res.rank(): {} != key.rank_out(): {}",
            res.rank(),
            key.rank_out()
        );

        assert_eq!(res.n(), self.n() as u32);
        assert_eq!(key.n(), self.n() as u32);

        let scratch_needed: usize = self.glwe_keyswitch_tmp_bytes(res, res, key);

        assert!(
            scratch.available() >= scratch_needed,
            "scratch.available()={} < glwe_keyswitch_tmp_bytes={scratch_needed}",
            scratch.available(),
        );

        let base2k_res: usize = res.base2k().into();
        let base2k_key: usize = key.base2k().into();

        let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimize
        let res_big: VecZnxBig<&mut [u8], BE> = keyswitch_internal(self, res_dft, res, key, scratch_1);
        (0..(res.rank() + 1).into()).for_each(|i| {
            self.vec_znx_big_normalize(
                base2k_res,
                &mut res.data,
                i,
                base2k_key,
                &res_big,
                i,
                scratch_1,
            );
        })
    }
}
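
/// Shared core of [`GLWEKeyswitch`]: computes the VMP between the mask of `a` (in the
/// DFT domain) and the prepared key, then adds the body of `a` onto the result.
///
/// When `dsize > 1`, the limbs of `a` are strided into `dsize` interleaved DFT vectors
/// whose partial VMP products are aggregated at increasing scaling offsets. When the
/// input and key differ in base-2^k, the mask is first renormalized to the key's base.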
pub(crate) fn keyswitch_internal<BE: Backend, M, DR, A, K>(
    module: &M,
    mut res: VecZnxDft<DR, BE>,
    a: &A,
    key: &K,
    scratch: &mut Scratch<BE>,
) -> VecZnxBig<DR, BE>
where
    DR: DataMut,
    A: GLWEToRef,
    K: GGLWEPreparedToRef<BE>,
    M: ModuleN
        + VecZnxDftBytesOf
        + VmpApplyDftToDftTmpBytes
        + VecZnxBigNormalizeTmpBytes
        + VmpApplyDftToDft<BE>
        + VmpApplyDftToDftAdd<BE>
        + VecZnxDftApply<BE>
        + VecZnxIdftApplyConsume<BE>
        + VecZnxBigAddSmallInplace<BE>
        + VecZnxBigNormalize<BE>
        + VecZnxNormalize<BE>,
    Scratch<BE>: ScratchTakeCore<BE>,
{
    let a: &GLWE<&[u8]> = &a.to_ref();
    let key: &GGLWEPrepared<&[u8], BE> = &key.to_ref();

    let base2k_in: usize = a.base2k().into();
    let base2k_out: usize = key.base2k().into();
    let cols: usize = (a.rank() + 1).into();
    // Number of limbs of a once renormalized to the key's base-2^k.
    let a_size: usize = (a.size() * base2k_in).div_ceil(base2k_out);
    let pmat: &VmpPMat<&[u8], BE> = &key.data;

    if key.dsize() == 1 {
        let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(module, cols - 1, a_size);

        if base2k_in == base2k_out {
            (0..cols - 1).for_each(|col_i| {
                module.vec_znx_dft_apply(1, 0, &mut ai_dft, col_i, a.data(), col_i + 1);
            });
        } else {
            // A single temporary column suffices: it is consumed into ai_dft
            // immediately after each base conversion.
            let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(module.n(), 1, a_size);
            (0..cols - 1).for_each(|col_i| {
                module.vec_znx_normalize(
                    base2k_out,
                    &mut a_conv,
                    0,
                    base2k_in,
                    a.data(),
                    col_i + 1,
                    scratch_2,
                );
                module.vec_znx_dft_apply(1, 0, &mut ai_dft, col_i, &a_conv, 0);
            });
        }

        module.vmp_apply_dft_to_dft(&mut res, &ai_dft, pmat, scratch_1);
    } else {
        let dsize: usize = key.dsize().into();

        let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(module, cols - 1, a_size.div_ceil(dsize));
        ai_dft.data_mut().fill(0);

        if base2k_in == base2k_out {
            for di in 0..dsize {
                ai_dft.set_size((a_size + di) / dsize);

                // Small optimization for dsize > 2:
                // each VMP product introduces some error e_i, and since we aggregate
                // vmp * 2^{di * B}, we also aggregate e_i * 2^{di * B}, the largest
                // error being e_i * 2^{(dsize-1) * B}. As such, we can safely ignore
                // the last dsize-2 limbs of the sum of VMP products. It would be
                // possible to further ignore the last dsize-1 limbs, but this
                // introduces ~0.5 to 1 bit of additional noise, so it is not done
                // here, to preserve the noise of the ideal functionality.
                res.set_size(pmat.size() - ((dsize - di) as isize - 2).max(0) as usize);

                for j in 0..cols - 1 {
                    module.vec_znx_dft_apply(dsize, dsize - di - 1, &mut ai_dft, j, a.data(), j + 1);
                }

                if di == 0 {
                    module.vmp_apply_dft_to_dft(&mut res, &ai_dft, pmat, scratch_1);
                } else {
                    module.vmp_apply_dft_to_dft_add(&mut res, &ai_dft, pmat, di, scratch_1);
                }
            }
        } else {
            // Renormalize all mask columns of a to the key's base-2^k before striding.
            let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(module.n(), cols - 1, a_size);
            for j in 0..cols - 1 {
                module.vec_znx_normalize(
                    base2k_out,
                    &mut a_conv,
                    j,
                    base2k_in,
                    a.data(),
                    j + 1,
                    scratch_2,
                );
            }

            for di in 0..dsize {
                ai_dft.set_size((a_size + di) / dsize);

                // Same limb-truncation optimization as in the equal-base branch above.
                res.set_size(pmat.size() - ((dsize - di) as isize - 2).max(0) as usize);

                for j in 0..cols - 1 {
                    module.vec_znx_dft_apply(dsize, dsize - di - 1, &mut ai_dft, j, &a_conv, j);
                }

                if di == 0 {
                    module.vmp_apply_dft_to_dft(&mut res, &ai_dft, pmat, scratch_2);
                } else {
                    module.vmp_apply_dft_to_dft_add(&mut res, &ai_dft, pmat, di, scratch_2);
                }
            }
        }

        res.set_size(res.max_size());
    }

    let mut res_big: VecZnxBig<DR, BE> = module.vec_znx_idft_apply_consume(res);
    module.vec_znx_big_add_small_inplace(&mut res_big, 0, a.data(), 0);
    res_big
}