1use poulpy_hal::{
2 api::{
3 ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
4 VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VmpApplyDftToDft, VmpApplyDftToDftAdd,
5 VmpApplyDftToDftTmpBytes,
6 },
7 layouts::{Backend, DataMut, DataRef, DataViewMut, Module, Scratch, VecZnx, VecZnxBig, VecZnxDft, VmpPMat, ZnxInfos},
8};
9
10use crate::layouts::{GLWECiphertext, Infos, prepared::GGLWESwitchingKeyPrepared};
11
12impl GLWECiphertext<Vec<u8>> {
13 #[allow(clippy::too_many_arguments)]
14 pub fn keyswitch_scratch_space<B: Backend>(
15 module: &Module<B>,
16 basek: usize,
17 k_out: usize,
18 k_in: usize,
19 k_ksk: usize,
20 digits: usize,
21 rank_in: usize,
22 rank_out: usize,
23 ) -> usize
24 where
25 Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
26 {
27 let in_size: usize = k_in.div_ceil(basek).div_ceil(digits);
28 let out_size: usize = k_out.div_ceil(basek);
29 let ksk_size: usize = k_ksk.div_ceil(basek);
30 let res_dft: usize = module.vec_znx_dft_alloc_bytes(rank_out + 1, ksk_size); let ai_dft: usize = module.vec_znx_dft_alloc_bytes(rank_in, in_size);
32 let vmp: usize = module.vmp_apply_dft_to_dft_tmp_bytes(out_size, in_size, in_size, rank_in, rank_out + 1, ksk_size)
33 + module.vec_znx_dft_alloc_bytes(rank_in, in_size);
34 let normalize: usize = module.vec_znx_big_normalize_tmp_bytes();
35 res_dft + ((ai_dft + vmp) | normalize)
36 }
37
38 pub fn keyswitch_inplace_scratch_space<B: Backend>(
39 module: &Module<B>,
40 basek: usize,
41 k_out: usize,
42 k_ksk: usize,
43 digits: usize,
44 rank: usize,
45 ) -> usize
46 where
47 Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
48 {
49 Self::keyswitch_scratch_space(module, basek, k_out, k_out, k_ksk, digits, rank, rank)
50 }
51}
52
53impl<DataSelf: DataRef> GLWECiphertext<DataSelf> {
54 #[allow(dead_code)]
55 pub(crate) fn assert_keyswitch<B: Backend, DataLhs, DataRhs>(
56 &self,
57 module: &Module<B>,
58 lhs: &GLWECiphertext<DataLhs>,
59 rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
60 scratch: &Scratch<B>,
61 ) where
62 DataLhs: DataRef,
63 DataRhs: DataRef,
64 Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
65 Scratch<B>: ScratchAvailable,
66 {
67 let basek: usize = self.basek();
68 assert_eq!(
69 lhs.rank(),
70 rhs.rank_in(),
71 "lhs.rank(): {} != rhs.rank_in(): {}",
72 lhs.rank(),
73 rhs.rank_in()
74 );
75 assert_eq!(
76 self.rank(),
77 rhs.rank_out(),
78 "self.rank(): {} != rhs.rank_out(): {}",
79 self.rank(),
80 rhs.rank_out()
81 );
82 assert_eq!(self.basek(), basek);
83 assert_eq!(lhs.basek(), basek);
84 assert_eq!(rhs.n(), self.n());
85 assert_eq!(lhs.n(), self.n());
86 assert!(
87 scratch.available()
88 >= GLWECiphertext::keyswitch_scratch_space(
89 module,
90 self.basek(),
91 self.k(),
92 lhs.k(),
93 rhs.k(),
94 rhs.digits(),
95 rhs.rank_in(),
96 rhs.rank_out(),
97 ),
98 "scratch.available()={} < GLWECiphertext::keyswitch_scratch_space(
99 module,
100 self.basek(),
101 self.k(),
102 lhs.k(),
103 rhs.k(),
104 rhs.digits(),
105 rhs.rank_in(),
106 rhs.rank_out(),
107 )={}",
108 scratch.available(),
109 GLWECiphertext::keyswitch_scratch_space(
110 module,
111 self.basek(),
112 self.k(),
113 lhs.k(),
114 rhs.k(),
115 rhs.digits(),
116 rhs.rank_in(),
117 rhs.rank_out(),
118 )
119 );
120 }
121
122 #[allow(dead_code)]
123 pub(crate) fn assert_keyswitch_inplace<B: Backend, DataRhs>(
124 &self,
125 module: &Module<B>,
126 rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
127 scratch: &Scratch<B>,
128 ) where
129 DataRhs: DataRef,
130 Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
131 Scratch<B>: ScratchAvailable,
132 {
133 let basek: usize = self.basek();
134 assert_eq!(
135 self.rank(),
136 rhs.rank_out(),
137 "self.rank(): {} != rhs.rank_out(): {}",
138 self.rank(),
139 rhs.rank_out()
140 );
141 assert_eq!(self.basek(), basek);
142 assert_eq!(rhs.n(), self.n());
143 assert!(
144 scratch.available()
145 >= GLWECiphertext::keyswitch_scratch_space(
146 module,
147 self.basek(),
148 self.k(),
149 self.k(),
150 rhs.k(),
151 rhs.digits(),
152 rhs.rank_in(),
153 rhs.rank_out(),
154 ),
155 "scratch.available()={} < GLWECiphertext::keyswitch_scratch_space(
156 module,
157 self.basek(),
158 self.k(),
159 self.k(),
160 rhs.k(),
161 rhs.digits(),
162 rhs.rank_in(),
163 rhs.rank_out(),
164 )={}",
165 scratch.available(),
166 GLWECiphertext::keyswitch_scratch_space(
167 module,
168 self.basek(),
169 self.k(),
170 self.k(),
171 rhs.k(),
172 rhs.digits(),
173 rhs.rank_in(),
174 rhs.rank_out(),
175 )
176 );
177 }
178}
179
180impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
181 pub fn keyswitch<DataLhs: DataRef, DataRhs: DataRef, B: Backend>(
182 &mut self,
183 module: &Module<B>,
184 lhs: &GLWECiphertext<DataLhs>,
185 rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
186 scratch: &mut Scratch<B>,
187 ) where
188 Module<B>: VecZnxDftAllocBytes
189 + VmpApplyDftToDftTmpBytes
190 + VecZnxBigNormalizeTmpBytes
191 + VmpApplyDftToDft<B>
192 + VmpApplyDftToDftAdd<B>
193 + VecZnxDftApply<B>
194 + VecZnxIdftApplyConsume<B>
195 + VecZnxBigAddSmallInplace<B>
196 + VecZnxBigNormalize<B>,
197 Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
198 {
199 #[cfg(debug_assertions)]
200 {
201 self.assert_keyswitch(module, lhs, rhs, scratch);
202 }
203 let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); let res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, rhs, scratch_1);
205 (0..self.cols()).for_each(|i| {
206 module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch_1);
207 })
208 }
209
210 pub fn keyswitch_inplace<DataRhs: DataRef, B: Backend>(
211 &mut self,
212 module: &Module<B>,
213 rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
214 scratch: &mut Scratch<B>,
215 ) where
216 Module<B>: VecZnxDftAllocBytes
217 + VmpApplyDftToDftTmpBytes
218 + VecZnxBigNormalizeTmpBytes
219 + VmpApplyDftToDftTmpBytes
220 + VmpApplyDftToDft<B>
221 + VmpApplyDftToDftAdd<B>
222 + VecZnxDftApply<B>
223 + VecZnxIdftApplyConsume<B>
224 + VecZnxBigAddSmallInplace<B>
225 + VecZnxBigNormalize<B>,
226 Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
227 {
228 #[cfg(debug_assertions)]
229 {
230 self.assert_keyswitch_inplace(module, rhs, scratch);
231 }
232 let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); let res_big: VecZnxBig<_, B> = self.keyswitch_internal(module, res_dft, rhs, scratch_1);
234 (0..self.cols()).for_each(|i| {
235 module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch_1);
236 })
237 }
238}
239
240impl<D: DataRef> GLWECiphertext<D> {
241 pub(crate) fn keyswitch_internal<B: Backend, DataRes, DataKey>(
242 &self,
243 module: &Module<B>,
244 res_dft: VecZnxDft<DataRes, B>,
245 rhs: &GGLWESwitchingKeyPrepared<DataKey, B>,
246 scratch: &mut Scratch<B>,
247 ) -> VecZnxBig<DataRes, B>
248 where
249 DataRes: DataMut,
250 DataKey: DataRef,
251 Module<B>: VecZnxDftAllocBytes
252 + VmpApplyDftToDftTmpBytes
253 + VecZnxBigNormalizeTmpBytes
254 + VmpApplyDftToDftTmpBytes
255 + VmpApplyDftToDft<B>
256 + VmpApplyDftToDftAdd<B>
257 + VecZnxDftApply<B>
258 + VecZnxIdftApplyConsume<B>
259 + VecZnxBigAddSmallInplace<B>
260 + VecZnxBigNormalize<B>,
261 Scratch<B>: TakeVecZnxDft<B>,
262 {
263 if rhs.digits() == 1 {
264 return keyswitch_vmp_one_digit(module, res_dft, &self.data, &rhs.key.data, scratch);
265 }
266
267 keyswitch_vmp_multiple_digits(
268 module,
269 res_dft,
270 &self.data,
271 &rhs.key.data,
272 rhs.digits(),
273 scratch,
274 )
275 }
276}
277
278fn keyswitch_vmp_one_digit<B: Backend, DataRes, DataIn, DataVmp>(
279 module: &Module<B>,
280 mut res_dft: VecZnxDft<DataRes, B>,
281 a: &VecZnx<DataIn>,
282 mat: &VmpPMat<DataVmp, B>,
283 scratch: &mut Scratch<B>,
284) -> VecZnxBig<DataRes, B>
285where
286 DataRes: DataMut,
287 DataIn: DataRef,
288 DataVmp: DataRef,
289 Module<B>:
290 VecZnxDftAllocBytes + VecZnxDftApply<B> + VmpApplyDftToDft<B> + VecZnxIdftApplyConsume<B> + VecZnxBigAddSmallInplace<B>,
291 Scratch<B>: TakeVecZnxDft<B>,
292{
293 let cols: usize = a.cols();
294 let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(a.n(), cols - 1, a.size());
295 (0..cols - 1).for_each(|col_i| {
296 module.vec_znx_dft_apply(1, 0, &mut ai_dft, col_i, a, col_i + 1);
297 });
298 module.vmp_apply_dft_to_dft(&mut res_dft, &ai_dft, mat, scratch_1);
299 let mut res_big: VecZnxBig<DataRes, B> = module.vec_znx_idft_apply_consume(res_dft);
300 module.vec_znx_big_add_small_inplace(&mut res_big, 0, a, 0);
301 res_big
302}
303
304fn keyswitch_vmp_multiple_digits<B: Backend, DataRes, DataIn, DataVmp>(
305 module: &Module<B>,
306 mut res_dft: VecZnxDft<DataRes, B>,
307 a: &VecZnx<DataIn>,
308 mat: &VmpPMat<DataVmp, B>,
309 digits: usize,
310 scratch: &mut Scratch<B>,
311) -> VecZnxBig<DataRes, B>
312where
313 DataRes: DataMut,
314 DataIn: DataRef,
315 DataVmp: DataRef,
316 Module<B>: VecZnxDftAllocBytes
317 + VecZnxDftApply<B>
318 + VmpApplyDftToDft<B>
319 + VmpApplyDftToDftAdd<B>
320 + VecZnxIdftApplyConsume<B>
321 + VecZnxBigAddSmallInplace<B>,
322 Scratch<B>: TakeVecZnxDft<B>,
323{
324 let cols: usize = a.cols();
325 let size: usize = a.size();
326 let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(a.n(), cols - 1, size.div_ceil(digits));
327
328 ai_dft.data_mut().fill(0);
329
330 (0..digits).for_each(|di| {
331 ai_dft.set_size((size + di) / digits);
332
333 res_dft.set_size(mat.size() - ((digits - di) as isize - 2).max(0) as usize);
341
342 (0..cols - 1).for_each(|col_i| {
343 module.vec_znx_dft_apply(digits, digits - di - 1, &mut ai_dft, col_i, a, col_i + 1);
344 });
345
346 if di == 0 {
347 module.vmp_apply_dft_to_dft(&mut res_dft, &ai_dft, mat, scratch_1);
348 } else {
349 module.vmp_apply_dft_to_dft_add(&mut res_dft, &ai_dft, mat, di, scratch_1);
350 }
351 });
352
353 res_dft.set_size(res_dft.max_size());
354 let mut res_big: VecZnxBig<DataRes, B> = module.vec_znx_idft_apply_consume(res_dft);
355 module.vec_znx_big_add_small_inplace(&mut res_big, 0, a, 0);
356 res_big
357}