poulpy_core/keyswitching/
ggsw.rs1use poulpy_hal::layouts::{Backend, DataMut, Module, Scratch, VecZnx};
2
3use crate::{
4 GGSWExpandRows, ScratchTakeCore,
5 keyswitching::GLWEKeyswitch,
6 layouts::{GGLWEInfos, GGLWEPreparedToRef, GGSW, GGSWInfos, GGSWToMut, GGSWToRef, prepared::GLWETensorKeyPreparedToRef},
7};
8
9impl GGSW<Vec<u8>> {
10 pub fn keyswitch_tmp_bytes<R, A, K, T, M, BE: Backend>(
11 module: &M,
12 res_infos: &R,
13 a_infos: &A,
14 key_infos: &K,
15 tsk_infos: &T,
16 ) -> usize
17 where
18 R: GGSWInfos,
19 A: GGSWInfos,
20 K: GGLWEInfos,
21 T: GGLWEInfos,
22 M: GGSWKeyswitch<BE>,
23 {
24 module.ggsw_keyswitch_tmp_bytes(res_infos, a_infos, key_infos, tsk_infos)
25 }
26}
27
28impl<D: DataMut> GGSW<D> {
29 pub fn keyswitch<M, A, K, T, BE: Backend>(&mut self, module: &M, a: &A, key: &K, tsk: &T, scratch: &mut Scratch<BE>)
30 where
31 A: GGSWToRef,
32 K: GGLWEPreparedToRef<BE>,
33 T: GLWETensorKeyPreparedToRef<BE>,
34 Scratch<BE>: ScratchTakeCore<BE>,
35 M: GGSWKeyswitch<BE>,
36 {
37 module.ggsw_keyswitch(self, a, key, tsk, scratch);
38 }
39
40 pub fn keyswitch_inplace<M, K, T, BE: Backend>(&mut self, module: &M, key: &K, tsk: &T, scratch: &mut Scratch<BE>)
41 where
42 K: GGLWEPreparedToRef<BE>,
43 T: GLWETensorKeyPreparedToRef<BE>,
44 Scratch<BE>: ScratchTakeCore<BE>,
45 M: GGSWKeyswitch<BE>,
46 {
47 module.ggsw_keyswitch_inplace(self, key, tsk, scratch);
48 }
49}
50
51impl<BE: Backend> GGSWKeyswitch<BE> for Module<BE> where Self: GLWEKeyswitch<BE> + GGSWExpandRows<BE> {}
52
53pub trait GGSWKeyswitch<BE: Backend>
54where
55 Self: GLWEKeyswitch<BE> + GGSWExpandRows<BE>,
56{
57 fn ggsw_keyswitch_tmp_bytes<R, A, K, T>(&self, res_infos: &R, a_infos: &A, key_infos: &K, tsk_infos: &T) -> usize
58 where
59 R: GGSWInfos,
60 A: GGSWInfos,
61 K: GGLWEInfos,
62 T: GGLWEInfos,
63 {
64 assert_eq!(key_infos.rank_in(), key_infos.rank_out());
65 assert_eq!(tsk_infos.rank_in(), tsk_infos.rank_out());
66 assert_eq!(key_infos.rank_in(), tsk_infos.rank_in());
67
68 let rank: usize = key_infos.rank_out().into();
69
70 let size_out: usize = res_infos.k().div_ceil(res_infos.base2k()) as usize;
71 let res_znx: usize = VecZnx::bytes_of(self.n(), rank + 1, size_out);
72 let ci_dft: usize = self.bytes_of_vec_znx_dft(rank + 1, size_out);
73 let ks: usize = self.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos);
74 let expand_rows: usize = self.ggsw_expand_rows_tmp_bytes(res_infos, tsk_infos);
75 let res_dft: usize = self.bytes_of_vec_znx_dft(rank + 1, size_out);
76
77 if a_infos.base2k() == tsk_infos.base2k() {
78 res_znx + ci_dft + (ks | expand_rows | res_dft)
79 } else {
80 let a_conv: usize = VecZnx::bytes_of(
81 self.n(),
82 1,
83 res_infos.k().div_ceil(tsk_infos.base2k()) as usize,
84 ) + self.vec_znx_normalize_tmp_bytes();
85 res_znx + ci_dft + (a_conv | ks | expand_rows | res_dft)
86 }
87 }
88
89 fn ggsw_keyswitch<R, A, K, T>(&self, res: &mut R, a: &A, key: &K, tsk: &T, scratch: &mut Scratch<BE>)
90 where
91 R: GGSWToMut,
92 A: GGSWToRef,
93 K: GGLWEPreparedToRef<BE>,
94 T: GLWETensorKeyPreparedToRef<BE>,
95 Scratch<BE>: ScratchTakeCore<BE>,
96 {
97 let res: &mut GGSW<&mut [u8]> = &mut res.to_mut();
98 let a: &GGSW<&[u8]> = &a.to_ref();
99
100 assert!(res.dnum() <= a.dnum());
101 assert_eq!(res.dsize(), a.dsize());
102
103 for row in 0..a.dnum().into() {
104 self.glwe_keyswitch(&mut res.at_mut(row, 0), &a.at(row, 0), key, scratch);
107 }
108
109 self.ggsw_expand_row(res, tsk, scratch);
110 }
111
112 fn ggsw_keyswitch_inplace<R, K, T>(&self, res: &mut R, key: &K, tsk: &T, scratch: &mut Scratch<BE>)
113 where
114 R: GGSWToMut,
115 K: GGLWEPreparedToRef<BE>,
116 T: GLWETensorKeyPreparedToRef<BE>,
117 Scratch<BE>: ScratchTakeCore<BE>,
118 {
119 let res: &mut GGSW<&mut [u8]> = &mut res.to_mut();
120
121 for row in 0..res.dnum().into() {
122 self.glwe_keyswitch_inplace(&mut res.at_mut(row, 0), key, scratch);
125 }
126
127 self.ggsw_expand_row(res, tsk, scratch);
128 }
129}