1use poulpy_hal::{
2 api::{
3 DFT, IDFTConsume, ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize,
4 VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
5 },
6 layouts::{Backend, DataMut, DataRef, DataViewMut, Module, Scratch, VecZnx, VecZnxBig, VecZnxDft, VmpPMat, ZnxInfos},
7};
8
9use crate::layouts::{GLWECiphertext, Infos, prepared::GGLWESwitchingKeyPrepared};
10
11impl GLWECiphertext<Vec<u8>> {
12 #[allow(clippy::too_many_arguments)]
13 pub fn keyswitch_scratch_space<B: Backend>(
14 module: &Module<B>,
15 basek: usize,
16 k_out: usize,
17 k_in: usize,
18 k_ksk: usize,
19 digits: usize,
20 rank_in: usize,
21 rank_out: usize,
22 ) -> usize
23 where
24 Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
25 {
26 let in_size: usize = k_in.div_ceil(basek).div_ceil(digits);
27 let out_size: usize = k_out.div_ceil(basek);
28 let ksk_size: usize = k_ksk.div_ceil(basek);
29 let res_dft: usize = module.vec_znx_dft_alloc_bytes(rank_out + 1, ksk_size); let ai_dft: usize = module.vec_znx_dft_alloc_bytes(rank_in, in_size);
31 let vmp: usize = module.vmp_apply_dft_to_dft_tmp_bytes(out_size, in_size, in_size, rank_in, rank_out + 1, ksk_size)
32 + module.vec_znx_dft_alloc_bytes(rank_in, in_size);
33 let normalize: usize = module.vec_znx_big_normalize_tmp_bytes();
34 res_dft + ((ai_dft + vmp) | normalize)
35 }
36
37 pub fn keyswitch_inplace_scratch_space<B: Backend>(
38 module: &Module<B>,
39 basek: usize,
40 k_out: usize,
41 k_ksk: usize,
42 digits: usize,
43 rank: usize,
44 ) -> usize
45 where
46 Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
47 {
48 Self::keyswitch_scratch_space(module, basek, k_out, k_out, k_ksk, digits, rank, rank)
49 }
50}
51
52impl<DataSelf: DataRef> GLWECiphertext<DataSelf> {
53 #[allow(dead_code)]
54 pub(crate) fn assert_keyswitch<B: Backend, DataLhs, DataRhs>(
55 &self,
56 module: &Module<B>,
57 lhs: &GLWECiphertext<DataLhs>,
58 rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
59 scratch: &Scratch<B>,
60 ) where
61 DataLhs: DataRef,
62 DataRhs: DataRef,
63 Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
64 Scratch<B>: ScratchAvailable,
65 {
66 let basek: usize = self.basek();
67 assert_eq!(
68 lhs.rank(),
69 rhs.rank_in(),
70 "lhs.rank(): {} != rhs.rank_in(): {}",
71 lhs.rank(),
72 rhs.rank_in()
73 );
74 assert_eq!(
75 self.rank(),
76 rhs.rank_out(),
77 "self.rank(): {} != rhs.rank_out(): {}",
78 self.rank(),
79 rhs.rank_out()
80 );
81 assert_eq!(self.basek(), basek);
82 assert_eq!(lhs.basek(), basek);
83 assert_eq!(rhs.n(), self.n());
84 assert_eq!(lhs.n(), self.n());
85 assert!(
86 scratch.available()
87 >= GLWECiphertext::keyswitch_scratch_space(
88 module,
89 self.basek(),
90 self.k(),
91 lhs.k(),
92 rhs.k(),
93 rhs.digits(),
94 rhs.rank_in(),
95 rhs.rank_out(),
96 ),
97 "scratch.available()={} < GLWECiphertext::keyswitch_scratch_space(
98 module,
99 self.basek(),
100 self.k(),
101 lhs.k(),
102 rhs.k(),
103 rhs.digits(),
104 rhs.rank_in(),
105 rhs.rank_out(),
106 )={}",
107 scratch.available(),
108 GLWECiphertext::keyswitch_scratch_space(
109 module,
110 self.basek(),
111 self.k(),
112 lhs.k(),
113 rhs.k(),
114 rhs.digits(),
115 rhs.rank_in(),
116 rhs.rank_out(),
117 )
118 );
119 }
120}
121
122impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
123 pub fn keyswitch<DataLhs: DataRef, DataRhs: DataRef, B: Backend>(
124 &mut self,
125 module: &Module<B>,
126 lhs: &GLWECiphertext<DataLhs>,
127 rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
128 scratch: &mut Scratch<B>,
129 ) where
130 Module<B>: VecZnxDftAllocBytes
131 + VmpApplyDftToDftTmpBytes
132 + VecZnxBigNormalizeTmpBytes
133 + VmpApplyDftToDftTmpBytes
134 + VmpApplyDftToDft<B>
135 + VmpApplyDftToDftAdd<B>
136 + DFT<B>
137 + IDFTConsume<B>
138 + VecZnxBigAddSmallInplace<B>
139 + VecZnxBigNormalize<B>,
140 Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
141 {
142 #[cfg(debug_assertions)]
143 {
144 self.assert_keyswitch(module, lhs, rhs, scratch);
145 }
146 let (res_dft, scratch1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); let res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, rhs, scratch1);
148 (0..self.cols()).for_each(|i| {
149 module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch1);
150 })
151 }
152
153 pub fn keyswitch_inplace<DataRhs: DataRef, B: Backend>(
154 &mut self,
155 module: &Module<B>,
156 rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
157 scratch: &mut Scratch<B>,
158 ) where
159 Module<B>: VecZnxDftAllocBytes
160 + VmpApplyDftToDftTmpBytes
161 + VecZnxBigNormalizeTmpBytes
162 + VmpApplyDftToDftTmpBytes
163 + VmpApplyDftToDft<B>
164 + VmpApplyDftToDftAdd<B>
165 + DFT<B>
166 + IDFTConsume<B>
167 + VecZnxBigAddSmallInplace<B>
168 + VecZnxBigNormalize<B>,
169 Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
170 {
171 unsafe {
172 let self_ptr: *mut GLWECiphertext<DataSelf> = self as *mut GLWECiphertext<DataSelf>;
173 self.keyswitch(module, &*self_ptr, rhs, scratch);
174 }
175 }
176}
177
178impl<D: DataRef> GLWECiphertext<D> {
179 pub(crate) fn keyswitch_internal<B: Backend, DataRes, DataKey>(
180 &self,
181 module: &Module<B>,
182 res_dft: VecZnxDft<DataRes, B>,
183 rhs: &GGLWESwitchingKeyPrepared<DataKey, B>,
184 scratch: &mut Scratch<B>,
185 ) -> VecZnxBig<DataRes, B>
186 where
187 DataRes: DataMut,
188 DataKey: DataRef,
189 Module<B>: VecZnxDftAllocBytes
190 + VmpApplyDftToDftTmpBytes
191 + VecZnxBigNormalizeTmpBytes
192 + VmpApplyDftToDftTmpBytes
193 + VmpApplyDftToDft<B>
194 + VmpApplyDftToDftAdd<B>
195 + DFT<B>
196 + IDFTConsume<B>
197 + VecZnxBigAddSmallInplace<B>
198 + VecZnxBigNormalize<B>,
199 Scratch<B>: TakeVecZnxDft<B>,
200 {
201 if rhs.digits() == 1 {
202 return keyswitch_vmp_one_digit(module, res_dft, &self.data, &rhs.key.data, scratch);
203 }
204
205 keyswitch_vmp_multiple_digits(
206 module,
207 res_dft,
208 &self.data,
209 &rhs.key.data,
210 rhs.digits(),
211 scratch,
212 )
213 }
214}
215
216fn keyswitch_vmp_one_digit<B: Backend, DataRes, DataIn, DataVmp>(
217 module: &Module<B>,
218 mut res_dft: VecZnxDft<DataRes, B>,
219 a: &VecZnx<DataIn>,
220 mat: &VmpPMat<DataVmp, B>,
221 scratch: &mut Scratch<B>,
222) -> VecZnxBig<DataRes, B>
223where
224 DataRes: DataMut,
225 DataIn: DataRef,
226 DataVmp: DataRef,
227 Module<B>: VecZnxDftAllocBytes + DFT<B> + VmpApplyDftToDft<B> + IDFTConsume<B> + VecZnxBigAddSmallInplace<B>,
228 Scratch<B>: TakeVecZnxDft<B>,
229{
230 let cols: usize = a.cols();
231 let (mut ai_dft, scratch1) = scratch.take_vec_znx_dft(a.n(), cols - 1, a.size());
232 (0..cols - 1).for_each(|col_i| {
233 module.dft(1, 0, &mut ai_dft, col_i, a, col_i + 1);
234 });
235 module.vmp_apply_dft_to_dft(&mut res_dft, &ai_dft, mat, scratch1);
236 let mut res_big: VecZnxBig<DataRes, B> = module.vec_znx_idft_consume(res_dft);
237 module.vec_znx_big_add_small_inplace(&mut res_big, 0, a, 0);
238 res_big
239}
240
241fn keyswitch_vmp_multiple_digits<B: Backend, DataRes, DataIn, DataVmp>(
242 module: &Module<B>,
243 mut res_dft: VecZnxDft<DataRes, B>,
244 a: &VecZnx<DataIn>,
245 mat: &VmpPMat<DataVmp, B>,
246 digits: usize,
247 scratch: &mut Scratch<B>,
248) -> VecZnxBig<DataRes, B>
249where
250 DataRes: DataMut,
251 DataIn: DataRef,
252 DataVmp: DataRef,
253 Module<B>: VecZnxDftAllocBytes
254 + DFT<B>
255 + VmpApplyDftToDft<B>
256 + VmpApplyDftToDftAdd<B>
257 + IDFTConsume<B>
258 + VecZnxBigAddSmallInplace<B>,
259 Scratch<B>: TakeVecZnxDft<B>,
260{
261 let cols: usize = a.cols();
262 let size: usize = a.size();
263 let (mut ai_dft, scratch1) = scratch.take_vec_znx_dft(a.n(), cols - 1, size.div_ceil(digits));
264
265 ai_dft.data_mut().fill(0);
266
267 (0..digits).for_each(|di| {
268 ai_dft.set_size((size + di) / digits);
269
270 res_dft.set_size(mat.size() - ((digits - di) as isize - 2).max(0) as usize);
278
279 (0..cols - 1).for_each(|col_i| {
280 module.dft(digits, digits - di - 1, &mut ai_dft, col_i, a, col_i + 1);
281 });
282
283 if di == 0 {
284 module.vmp_apply_dft_to_dft(&mut res_dft, &ai_dft, mat, scratch1);
285 } else {
286 module.vmp_apply_dft_to_dft_add(&mut res_dft, &ai_dft, mat, di, scratch1);
287 }
288 });
289
290 res_dft.set_size(res_dft.max_size());
291 let mut res_big: VecZnxBig<DataRes, B> = module.vec_znx_idft_consume(res_dft);
292 module.vec_znx_big_add_small_inplace(&mut res_big, 0, a, 0);
293 res_big
294}