1use poulpy_hal::{
2 api::{
3 ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
4 VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft,
5 VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
6 },
7 layouts::{Backend, DataMut, DataRef, DataViewMut, Module, Scratch, VecZnx, VecZnxBig, VecZnxDft, VmpPMat, ZnxInfos},
8};
9
10use crate::layouts::{GGLWELayoutInfos, GLWECiphertext, GLWEInfos, LWEInfos, prepared::GGLWESwitchingKeyPrepared};
11
12impl GLWECiphertext<Vec<u8>> {
13 pub fn keyswitch_scratch_space<B: Backend, OUT, IN, KEY>(
14 module: &Module<B>,
15 out_infos: &OUT,
16 in_infos: &IN,
17 key_apply: &KEY,
18 ) -> usize
19 where
20 OUT: GLWEInfos,
21 IN: GLWEInfos,
22 KEY: GGLWELayoutInfos,
23 Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes,
24 {
25 let in_size: usize = in_infos
26 .k()
27 .div_ceil(key_apply.base2k())
28 .div_ceil(key_apply.digits().into()) as usize;
29 let out_size: usize = out_infos.size();
30 let ksk_size: usize = key_apply.size();
31 let res_dft: usize = module.vec_znx_dft_alloc_bytes((key_apply.rank_out() + 1).into(), ksk_size); let ai_dft: usize = module.vec_znx_dft_alloc_bytes((key_apply.rank_in()).into(), in_size);
33 let vmp: usize = module.vmp_apply_dft_to_dft_tmp_bytes(
34 out_size,
35 in_size,
36 in_size,
37 (key_apply.rank_in()).into(),
38 (key_apply.rank_out() + 1).into(),
39 ksk_size,
40 ) + module.vec_znx_dft_alloc_bytes((key_apply.rank_in()).into(), in_size);
41 let normalize_big: usize = module.vec_znx_big_normalize_tmp_bytes();
42 if in_infos.base2k() == key_apply.base2k() {
43 res_dft + ((ai_dft + vmp) | normalize_big)
44 } else if key_apply.digits() == 1 {
45 let normalize_conv: usize = VecZnx::alloc_bytes(module.n(), 1, in_size) + module.vec_znx_normalize_tmp_bytes();
47 res_dft + (((ai_dft + normalize_conv) | vmp) | normalize_big)
48 } else {
49 let normalize_conv: usize = VecZnx::alloc_bytes(module.n(), (key_apply.rank_in()).into(), in_size);
51 res_dft + ((ai_dft + normalize_conv + (module.vec_znx_normalize_tmp_bytes() | vmp)) | normalize_big)
52 }
53 }
54
55 pub fn keyswitch_inplace_scratch_space<B: Backend, OUT, KEY>(module: &Module<B>, out_infos: &OUT, key_apply: &KEY) -> usize
56 where
57 OUT: GLWEInfos,
58 KEY: GGLWELayoutInfos,
59 Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes,
60 {
61 Self::keyswitch_scratch_space(module, out_infos, out_infos, key_apply)
62 }
63}
64
65impl<DataSelf: DataRef> GLWECiphertext<DataSelf> {
66 #[allow(dead_code)]
67 pub(crate) fn assert_keyswitch<B: Backend, DataLhs, DataRhs>(
68 &self,
69 module: &Module<B>,
70 lhs: &GLWECiphertext<DataLhs>,
71 rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
72 scratch: &Scratch<B>,
73 ) where
74 DataLhs: DataRef,
75 DataRhs: DataRef,
76 Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes,
77 Scratch<B>: ScratchAvailable,
78 {
79 assert_eq!(
80 lhs.rank(),
81 rhs.rank_in(),
82 "lhs.rank(): {} != rhs.rank_in(): {}",
83 lhs.rank(),
84 rhs.rank_in()
85 );
86 assert_eq!(
87 self.rank(),
88 rhs.rank_out(),
89 "self.rank(): {} != rhs.rank_out(): {}",
90 self.rank(),
91 rhs.rank_out()
92 );
93 assert_eq!(rhs.n(), self.n());
94 assert_eq!(lhs.n(), self.n());
95
96 let scrach_needed: usize = GLWECiphertext::keyswitch_scratch_space(module, self, lhs, rhs);
97
98 assert!(
99 scratch.available() >= scrach_needed,
100 "scratch.available()={} < GLWECiphertext::keyswitch_scratch_space(
101 module,
102 self.base2k(),
103 self.k(),
104 lhs.base2k(),
105 lhs.k(),
106 rhs.base2k(),
107 rhs.k(),
108 rhs.digits(),
109 rhs.rank_in(),
110 rhs.rank_out(),
111 )={scrach_needed}",
112 scratch.available(),
113 );
114 }
115
116 #[allow(dead_code)]
117 pub(crate) fn assert_keyswitch_inplace<B: Backend, DataRhs>(
118 &self,
119 module: &Module<B>,
120 rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
121 scratch: &Scratch<B>,
122 ) where
123 DataRhs: DataRef,
124 Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes,
125 Scratch<B>: ScratchAvailable,
126 {
127 assert_eq!(
128 self.rank(),
129 rhs.rank_out(),
130 "self.rank(): {} != rhs.rank_out(): {}",
131 self.rank(),
132 rhs.rank_out()
133 );
134
135 assert_eq!(rhs.n(), self.n());
136
137 let scrach_needed: usize = GLWECiphertext::keyswitch_inplace_scratch_space(module, self, rhs);
138
139 assert!(
140 scratch.available() >= scrach_needed,
141 "scratch.available()={} < GLWECiphertext::keyswitch_scratch_space()={scrach_needed}",
142 scratch.available(),
143 );
144 }
145}
146
147impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
148 pub fn keyswitch<DataLhs: DataRef, DataRhs: DataRef, B: Backend>(
149 &mut self,
150 module: &Module<B>,
151 glwe_in: &GLWECiphertext<DataLhs>,
152 rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
153 scratch: &mut Scratch<B>,
154 ) where
155 Module<B>: VecZnxDftAllocBytes
156 + VmpApplyDftToDftTmpBytes
157 + VecZnxBigNormalizeTmpBytes
158 + VmpApplyDftToDft<B>
159 + VmpApplyDftToDftAdd<B>
160 + VecZnxDftApply<B>
161 + VecZnxIdftApplyConsume<B>
162 + VecZnxBigAddSmallInplace<B>
163 + VecZnxBigNormalize<B>
164 + VecZnxNormalize<B>
165 + VecZnxNormalizeTmpBytes,
166 Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeVecZnx,
167 {
168 #[cfg(debug_assertions)]
169 {
170 self.assert_keyswitch(module, glwe_in, rhs, scratch);
171 }
172
173 let basek_out: usize = self.base2k().into();
174 let basek_ksk: usize = rhs.base2k().into();
175
176 let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n().into(), (self.rank() + 1).into(), rhs.size()); let res_big: VecZnxBig<_, B> = glwe_in.keyswitch_internal(module, res_dft, rhs, scratch_1);
178 (0..(self.rank() + 1).into()).for_each(|i| {
179 module.vec_znx_big_normalize(
180 basek_out,
181 &mut self.data,
182 i,
183 basek_ksk,
184 &res_big,
185 i,
186 scratch_1,
187 );
188 })
189 }
190
191 pub fn keyswitch_inplace<DataRhs: DataRef, B: Backend>(
192 &mut self,
193 module: &Module<B>,
194 rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
195 scratch: &mut Scratch<B>,
196 ) where
197 Module<B>: VecZnxDftAllocBytes
198 + VmpApplyDftToDftTmpBytes
199 + VecZnxBigNormalizeTmpBytes
200 + VmpApplyDftToDftTmpBytes
201 + VmpApplyDftToDft<B>
202 + VmpApplyDftToDftAdd<B>
203 + VecZnxDftApply<B>
204 + VecZnxIdftApplyConsume<B>
205 + VecZnxBigAddSmallInplace<B>
206 + VecZnxBigNormalize<B>
207 + VecZnxNormalize<B>
208 + VecZnxNormalizeTmpBytes,
209 Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeVecZnx,
210 {
211 #[cfg(debug_assertions)]
212 {
213 self.assert_keyswitch_inplace(module, rhs, scratch);
214 }
215
216 let basek_in: usize = self.base2k().into();
217 let basek_ksk: usize = rhs.base2k().into();
218
219 let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n().into(), (self.rank() + 1).into(), rhs.size()); let res_big: VecZnxBig<_, B> = self.keyswitch_internal(module, res_dft, rhs, scratch_1);
221 (0..(self.rank() + 1).into()).for_each(|i| {
222 module.vec_znx_big_normalize(
223 basek_in,
224 &mut self.data,
225 i,
226 basek_ksk,
227 &res_big,
228 i,
229 scratch_1,
230 );
231 })
232 }
233}
234
235impl<D: DataRef> GLWECiphertext<D> {
236 pub(crate) fn keyswitch_internal<B: Backend, DataRes, DataKey>(
237 &self,
238 module: &Module<B>,
239 res_dft: VecZnxDft<DataRes, B>,
240 rhs: &GGLWESwitchingKeyPrepared<DataKey, B>,
241 scratch: &mut Scratch<B>,
242 ) -> VecZnxBig<DataRes, B>
243 where
244 DataRes: DataMut,
245 DataKey: DataRef,
246 Module<B>: VecZnxDftAllocBytes
247 + VmpApplyDftToDftTmpBytes
248 + VecZnxBigNormalizeTmpBytes
249 + VmpApplyDftToDftTmpBytes
250 + VmpApplyDftToDft<B>
251 + VmpApplyDftToDftAdd<B>
252 + VecZnxDftApply<B>
253 + VecZnxIdftApplyConsume<B>
254 + VecZnxBigAddSmallInplace<B>
255 + VecZnxBigNormalize<B>
256 + VecZnxNormalize<B>,
257 Scratch<B>: TakeVecZnxDft<B> + TakeVecZnx,
258 {
259 if rhs.digits() == 1 {
260 return keyswitch_vmp_one_digit(
261 module,
262 self.base2k().into(),
263 rhs.base2k().into(),
264 res_dft,
265 &self.data,
266 &rhs.key.data,
267 scratch,
268 );
269 }
270
271 keyswitch_vmp_multiple_digits(
272 module,
273 self.base2k().into(),
274 rhs.base2k().into(),
275 res_dft,
276 &self.data,
277 &rhs.key.data,
278 rhs.digits().into(),
279 scratch,
280 )
281 }
282}
283
284fn keyswitch_vmp_one_digit<B: Backend, DataRes, DataIn, DataVmp>(
285 module: &Module<B>,
286 basek_in: usize,
287 basek_ksk: usize,
288 mut res_dft: VecZnxDft<DataRes, B>,
289 a: &VecZnx<DataIn>,
290 mat: &VmpPMat<DataVmp, B>,
291 scratch: &mut Scratch<B>,
292) -> VecZnxBig<DataRes, B>
293where
294 DataRes: DataMut,
295 DataIn: DataRef,
296 DataVmp: DataRef,
297 Module<B>: VecZnxDftAllocBytes
298 + VecZnxDftApply<B>
299 + VmpApplyDftToDft<B>
300 + VecZnxIdftApplyConsume<B>
301 + VecZnxBigAddSmallInplace<B>
302 + VecZnxNormalize<B>,
303 Scratch<B>: TakeVecZnxDft<B> + TakeVecZnx,
304{
305 let cols: usize = a.cols();
306
307 let a_size: usize = (a.size() * basek_in).div_ceil(basek_ksk);
308 let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(a.n(), cols - 1, a.size());
309
310 if basek_in == basek_ksk {
311 (0..cols - 1).for_each(|col_i| {
312 module.vec_znx_dft_apply(1, 0, &mut ai_dft, col_i, a, col_i + 1);
313 });
314 } else {
315 let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(a.n(), 1, a_size);
316 (0..cols - 1).for_each(|col_i| {
317 module.vec_znx_normalize(basek_ksk, &mut a_conv, 0, basek_in, a, col_i + 1, scratch_2);
318 module.vec_znx_dft_apply(1, 0, &mut ai_dft, col_i, &a_conv, 0);
319 });
320 }
321
322 module.vmp_apply_dft_to_dft(&mut res_dft, &ai_dft, mat, scratch_1);
323 let mut res_big: VecZnxBig<DataRes, B> = module.vec_znx_idft_apply_consume(res_dft);
324 module.vec_znx_big_add_small_inplace(&mut res_big, 0, a, 0);
325 res_big
326}
327
328#[allow(clippy::too_many_arguments)]
329fn keyswitch_vmp_multiple_digits<B: Backend, DataRes, DataIn, DataVmp>(
330 module: &Module<B>,
331 basek_in: usize,
332 basek_ksk: usize,
333 mut res_dft: VecZnxDft<DataRes, B>,
334 a: &VecZnx<DataIn>,
335 mat: &VmpPMat<DataVmp, B>,
336 digits: usize,
337 scratch: &mut Scratch<B>,
338) -> VecZnxBig<DataRes, B>
339where
340 DataRes: DataMut,
341 DataIn: DataRef,
342 DataVmp: DataRef,
343 Module<B>: VecZnxDftAllocBytes
344 + VecZnxDftApply<B>
345 + VmpApplyDftToDft<B>
346 + VmpApplyDftToDftAdd<B>
347 + VecZnxIdftApplyConsume<B>
348 + VecZnxBigAddSmallInplace<B>
349 + VecZnxNormalize<B>,
350 Scratch<B>: TakeVecZnxDft<B> + TakeVecZnx,
351{
352 let cols: usize = a.cols();
353 let a_size: usize = (a.size() * basek_in).div_ceil(basek_ksk);
354 let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(a.n(), cols - 1, a_size.div_ceil(digits));
355 ai_dft.data_mut().fill(0);
356
357 if basek_in == basek_ksk {
358 for di in 0..digits {
359 ai_dft.set_size((a_size + di) / digits);
360
361 res_dft.set_size(mat.size() - ((digits - di) as isize - 2).max(0) as usize);
369
370 for j in 0..cols - 1 {
371 module.vec_znx_dft_apply(digits, digits - di - 1, &mut ai_dft, j, a, j + 1);
372 }
373
374 if di == 0 {
375 module.vmp_apply_dft_to_dft(&mut res_dft, &ai_dft, mat, scratch_1);
376 } else {
377 module.vmp_apply_dft_to_dft_add(&mut res_dft, &ai_dft, mat, di, scratch_1);
378 }
379 }
380 } else {
381 let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(a.n(), cols - 1, a_size);
382 for j in 0..cols - 1 {
383 module.vec_znx_normalize(basek_ksk, &mut a_conv, j, basek_in, a, j + 1, scratch_2);
384 }
385
386 for di in 0..digits {
387 ai_dft.set_size((a_size + di) / digits);
388
389 res_dft.set_size(mat.size() - ((digits - di) as isize - 2).max(0) as usize);
397
398 for j in 0..cols - 1 {
399 module.vec_znx_dft_apply(digits, digits - di - 1, &mut ai_dft, j, &a_conv, j);
400 }
401
402 if di == 0 {
403 module.vmp_apply_dft_to_dft(&mut res_dft, &ai_dft, mat, scratch_2);
404 } else {
405 module.vmp_apply_dft_to_dft_add(&mut res_dft, &ai_dft, mat, di, scratch_2);
406 }
407 }
408 }
409
410 res_dft.set_size(res_dft.max_size());
411 let mut res_big: VecZnxBig<DataRes, B> = module.vec_znx_idft_apply_consume(res_dft);
412 module.vec_znx_big_add_small_inplace(&mut res_big, 0, a, 0);
413 res_big
414}