1use poulpy_hal::{
2 api::{
3 ModuleN, ScratchAvailable, ScratchTakeBasic, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
4 VecZnxDftApply, VecZnxDftBytesOf, VecZnxDftCopy, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes,
5 VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
6 },
7 layouts::{Backend, DataMut, DataViewMut, Module, Scratch, VecZnxBig, VecZnxDft, VecZnxDftToRef, VmpPMat, ZnxInfos},
8};
9
10use crate::{
11 GLWENormalize, ScratchTakeCore,
12 layouts::{GGLWEInfos, GGLWEPrepared, GGLWEPreparedToRef, GLWE, GLWEInfos, GLWELayout, GLWEToMut, GLWEToRef, LWEInfos},
13};
14
15impl GLWE<Vec<u8>> {
16 pub fn keyswitch_tmp_bytes<M, R, A, B, BE: Backend>(module: &M, res_infos: &R, a_infos: &A, key_infos: &B) -> usize
17 where
18 R: GLWEInfos,
19 A: GLWEInfos,
20 B: GGLWEInfos,
21 M: GLWEKeyswitch<BE>,
22 {
23 module.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos)
24 }
25}
26
27impl<D: DataMut> GLWE<D> {
28 pub fn keyswitch<A, B, M, BE: Backend>(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch<BE>)
29 where
30 A: GLWEToRef + GLWEInfos,
31 B: GGLWEPreparedToRef<BE> + GGLWEInfos,
32 M: GLWEKeyswitch<BE>,
33 Scratch<BE>: ScratchTakeCore<BE>,
34 {
35 module.glwe_keyswitch(self, a, b, scratch);
36 }
37
38 pub fn keyswitch_inplace<A, M, BE: Backend>(&mut self, module: &M, a: &A, scratch: &mut Scratch<BE>)
39 where
40 A: GGLWEPreparedToRef<BE> + GGLWEInfos,
41 M: GLWEKeyswitch<BE>,
42 Scratch<BE>: ScratchTakeCore<BE>,
43 {
44 module.glwe_keyswitch_inplace(self, a, scratch);
45 }
46}
47
48impl<BE: Backend> GLWEKeyswitch<BE> for Module<BE>
49where
50 Self: Sized + GLWEKeySwitchInternal<BE> + VecZnxBigNormalizeTmpBytes + VecZnxBigNormalize<BE> + GLWENormalize<BE>,
51 Scratch<BE>: ScratchTakeCore<BE>,
52{
53 fn glwe_keyswitch_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, key_infos: &B) -> usize
54 where
55 R: GLWEInfos,
56 A: GLWEInfos,
57 B: GGLWEInfos,
58 {
59 let cols: usize = res_infos.rank().as_usize() + 1;
60 let size: usize = if a_infos.base2k() != key_infos.base2k() {
61 let a_conv_infos = &GLWELayout {
62 n: a_infos.n(),
63 base2k: key_infos.base2k(),
64 k: a_infos.k(),
65 rank: a_infos.rank(),
66 };
67 self.glwe_keyswitch_internal_tmp_bytes(res_infos, a_conv_infos, key_infos) + GLWE::bytes_of_from_infos(a_conv_infos)
68 } else {
69 self.glwe_keyswitch_internal_tmp_bytes(res_infos, a_infos, key_infos)
70 };
71
72 size.max(self.vec_znx_big_normalize_tmp_bytes()) + self.bytes_of_vec_znx_dft(cols, key_infos.size())
73 }
74
75 fn glwe_keyswitch<R, A, K>(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch<BE>)
76 where
77 R: GLWEToMut + GLWEInfos,
78 A: GLWEToRef + GLWEInfos,
79 K: GGLWEPreparedToRef<BE> + GGLWEInfos,
80 {
81 assert_eq!(
82 a.rank(),
83 key.rank_in(),
84 "a.rank(): {} != b.rank_in(): {}",
85 a.rank(),
86 key.rank_in()
87 );
88 assert_eq!(
89 res.rank(),
90 key.rank_out(),
91 "res.rank(): {} != b.rank_out(): {}",
92 res.rank(),
93 key.rank_out()
94 );
95
96 assert_eq!(res.n(), self.n() as u32);
97 assert_eq!(a.n(), self.n() as u32);
98 assert_eq!(key.n(), self.n() as u32);
99
100 let scrach_needed: usize = self.glwe_keyswitch_tmp_bytes(res, a, key);
101
102 assert!(
103 scratch.available() >= scrach_needed,
104 "scratch.available()={} < glwe_keyswitch_tmp_bytes={scrach_needed}",
105 scratch.available(),
106 );
107
108 let base2k_a: usize = a.base2k().into();
109 let base2k_key: usize = key.base2k().into();
110 let base2k_res: usize = res.base2k().into();
111
112 let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); let res_big: VecZnxBig<&mut [u8], BE> = if base2k_a != base2k_key {
115 let (mut a_conv, scratch_2) = scratch_1.take_glwe(&GLWELayout {
116 n: a.n(),
117 base2k: key.base2k(),
118 k: a.k(),
119 rank: a.rank(),
120 });
121 self.glwe_normalize(&mut a_conv, a, scratch_2);
122 self.glwe_keyswitch_internal(res_dft, &a_conv, key, scratch_2)
123 } else {
124 self.glwe_keyswitch_internal(res_dft, a, key, scratch_1)
125 };
126
127 let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
128 for i in 0..(res.rank() + 1).into() {
129 self.vec_znx_big_normalize(
130 base2k_res,
131 res.data_mut(),
132 i,
133 base2k_key,
134 &res_big,
135 i,
136 scratch_1,
137 );
138 }
139 }
140
141 fn glwe_keyswitch_inplace<R, K>(&self, res: &mut R, key: &K, scratch: &mut Scratch<BE>)
142 where
143 R: GLWEToMut + GLWEInfos,
144 K: GGLWEPreparedToRef<BE> + GGLWEInfos,
145 {
146 assert_eq!(
147 res.rank(),
148 key.rank_in(),
149 "res.rank(): {} != a.rank_in(): {}",
150 res.rank(),
151 key.rank_in()
152 );
153 assert_eq!(
154 res.rank(),
155 key.rank_out(),
156 "res.rank(): {} != b.rank_out(): {}",
157 res.rank(),
158 key.rank_out()
159 );
160
161 assert_eq!(res.n(), self.n() as u32);
162 assert_eq!(key.n(), self.n() as u32);
163
164 let scrach_needed: usize = self.glwe_keyswitch_tmp_bytes(res, res, key);
165
166 assert!(
167 scratch.available() >= scrach_needed,
168 "scratch.available()={} < glwe_keyswitch_tmp_bytes={scrach_needed}",
169 scratch.available(),
170 );
171
172 let base2k_res: usize = res.base2k().as_usize();
173 let base2k_key: usize = key.base2k().as_usize();
174
175 let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); let res_big: VecZnxBig<&mut [u8], BE> = if base2k_res != base2k_key {
178 let (mut res_conv, scratch_2) = scratch_1.take_glwe(&GLWELayout {
179 n: res.n(),
180 base2k: key.base2k(),
181 k: res.k(),
182 rank: res.rank(),
183 });
184 self.glwe_normalize(&mut res_conv, res, scratch_2);
185
186 self.glwe_keyswitch_internal(res_dft, &res_conv, key, scratch_2)
187 } else {
188 self.glwe_keyswitch_internal(res_dft, res, key, scratch_1)
189 };
190
191 let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
192 for i in 0..(res.rank() + 1).into() {
193 self.vec_znx_big_normalize(
194 base2k_res,
195 res.data_mut(),
196 i,
197 base2k_key,
198 &res_big,
199 i,
200 scratch_1,
201 );
202 }
203 }
204}
205
206pub trait GLWEKeyswitch<BE: Backend> {
207 fn glwe_keyswitch_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, key_infos: &B) -> usize
208 where
209 R: GLWEInfos,
210 A: GLWEInfos,
211 B: GGLWEInfos;
212
213 fn glwe_keyswitch<R, A, K>(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch<BE>)
214 where
215 R: GLWEToMut + GLWEInfos,
216 A: GLWEToRef + GLWEInfos,
217 K: GGLWEPreparedToRef<BE> + GGLWEInfos;
218
219 fn glwe_keyswitch_inplace<R, K>(&self, res: &mut R, key: &K, scratch: &mut Scratch<BE>)
220 where
221 R: GLWEToMut + GLWEInfos,
222 K: GGLWEPreparedToRef<BE> + GGLWEInfos;
223}
224
225impl<BE: Backend> GLWEKeySwitchInternal<BE> for Module<BE> where
226 Self: GGLWEProduct<BE>
227 + VecZnxDftApply<BE>
228 + VecZnxNormalize<BE>
229 + VecZnxIdftApplyConsume<BE>
230 + VecZnxBigAddSmallInplace<BE>
231 + VecZnxNormalizeTmpBytes
232{
233}
234
235pub(crate) trait GLWEKeySwitchInternal<BE: Backend>
236where
237 Self: GGLWEProduct<BE>
238 + VecZnxDftApply<BE>
239 + VecZnxNormalize<BE>
240 + VecZnxIdftApplyConsume<BE>
241 + VecZnxBigAddSmallInplace<BE>
242 + VecZnxNormalizeTmpBytes,
243{
244 fn glwe_keyswitch_internal_tmp_bytes<R, A, K>(&self, res_infos: &R, a_infos: &A, key_infos: &K) -> usize
245 where
246 R: GLWEInfos,
247 A: GLWEInfos,
248 K: GGLWEInfos,
249 {
250 let cols: usize = (a_infos.rank() + 1).into();
251 let a_size: usize = a_infos.size();
252 self.gglwe_product_dft_tmp_bytes(res_infos.size(), a_size, key_infos) + self.bytes_of_vec_znx_dft(cols - 1, a_size)
253 }
254
255 fn glwe_keyswitch_internal<DR, A, K>(
256 &self,
257 mut res: VecZnxDft<DR, BE>,
258 a: &A,
259 key: &K,
260 scratch: &mut Scratch<BE>,
261 ) -> VecZnxBig<DR, BE>
262 where
263 DR: DataMut,
264 A: GLWEToRef,
265 K: GGLWEPreparedToRef<BE>,
266 Scratch<BE>: ScratchTakeCore<BE>,
267 {
268 let a: &GLWE<&[u8]> = &a.to_ref();
269 let key: &GGLWEPrepared<&[u8], BE> = &key.to_ref();
270 assert_eq!(a.base2k(), key.base2k());
271 let cols: usize = (a.rank() + 1).into();
272 let a_size: usize = a.size();
273 let (mut a_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols - 1, a_size);
274 for col_i in 0..cols - 1 {
275 self.vec_znx_dft_apply(1, 0, &mut a_dft, col_i, a.data(), col_i + 1);
276 }
277 self.gglwe_product_dft(&mut res, &a_dft, key, scratch_1);
278 let mut res_big: VecZnxBig<DR, BE> = self.vec_znx_idft_apply_consume(res);
279 self.vec_znx_big_add_small_inplace(&mut res_big, 0, a.data(), 0);
280 res_big
281 }
282}
283
284impl<BE: Backend> GGLWEProduct<BE> for Module<BE> where
285 Self: Sized
286 + ModuleN
287 + VecZnxDftBytesOf
288 + VmpApplyDftToDftTmpBytes
289 + VmpApplyDftToDft<BE>
290 + VmpApplyDftToDftAdd<BE>
291 + VecZnxDftCopy<BE>
292{
293}
294
/// Product, in the DFT domain, between a `VecZnxDft` and the `VmpPMat` of a prepared GGLWE
/// key. It supports keys whose digit size `dsize` is 1 or larger.
pub(crate) trait GGLWEProduct<BE: Backend>
where
    Self: Sized
        + ModuleN
        + VecZnxDftBytesOf
        + VmpApplyDftToDftTmpBytes
        + VmpApplyDftToDft<BE>
        + VmpApplyDftToDftAdd<BE>
        + VecZnxDftCopy<BE>,
{
    /// Returns the scratch size (in bytes) required by [`GGLWEProduct::gglwe_product_dft`].
    /// The inputs are an output of `res_size` limbs, an input of `a_size` limbs, and a key
    /// described by `key_infos`.
    ///
    /// For `dsize == 1` this is the scratch of a single VMP. For `dsize > 1` it also covers the
    /// per-digit input buffer that `gglwe_product_dft` takes from the scratch. That buffer has
    /// `rank_in` columns and `ceil(a_size / dsize)` limbs, capped at `dnum`.
    fn gglwe_product_dft_tmp_bytes<K>(&self, res_size: usize, a_size: usize, key_infos: &K) -> usize
    where
        K: GGLWEInfos,
    {
        let dsize: usize = key_infos.dsize().as_usize();

        if dsize == 1 {
            self.vmp_apply_dft_to_dft_tmp_bytes(
                res_size,
                a_size,
                key_infos.dnum().into(),
                (key_infos.rank_in()).into(),
                (key_infos.rank_out() + 1).into(),
                key_infos.size(),
            )
        } else {
            let dnum: usize = key_infos.dnum().into();
            // Limb count of the per-digit buffer. This must stay in sync with the size of
            // `ai_dft` in `gglwe_product_dft`.
            let a_size: usize = a_size.div_ceil(dsize).min(dnum);
            // NOTE(review): `gglwe_product_dft` takes `a.cols()` columns; this assumes
            // `a.cols() == rank_in`. That holds for the GLWE key-switch caller, which asserts
            // `a.rank() == key.rank_in()`.
            let ai_dft: usize = self.bytes_of_vec_znx_dft(key_infos.rank_in().into(), a_size);

            let vmp: usize = self.vmp_apply_dft_to_dft_tmp_bytes(
                res_size,
                a_size,
                dnum,
                (key_infos.rank_in()).into(),
                (key_infos.rank_out() + 1).into(),
                key_infos.size(),
            );

            // `ai_dft` is taken first and the VMP runs on the remaining scratch, so both count.
            ai_dft + vmp
        }
    }

    /// Computes the product of `a` with the key's `VmpPMat` into `res`. Both `a` and `res` are
    /// in the DFT domain.
    ///
    /// For `dsize == 1` this is a single `vmp_apply_dft_to_dft`. For `dsize > 1`, the limbs of
    /// `a` are split into `dsize` interleaved digits. Each digit is multiplied by the matrix,
    /// and the partial products are accumulated into `res`.
    ///
    /// `res`'s size is temporarily reduced during the loop and is restored to its
    /// `max_size()` before returning.
    fn gglwe_product_dft<DR, A, K>(&self, res: &mut VecZnxDft<DR, BE>, a: &A, key: &K, scratch: &mut Scratch<BE>)
    where
        DR: DataMut,
        A: VecZnxDftToRef<BE>,
        K: GGLWEPreparedToRef<BE>,
        Scratch<BE>: ScratchTakeCore<BE>,
    {
        let a: &VecZnxDft<&[u8], BE> = &a.to_ref();
        let key: &GGLWEPrepared<&[u8], BE> = &key.to_ref();

        let cols: usize = a.cols();
        let a_size: usize = a.size();
        let pmat: &VmpPMat<&[u8], BE> = &key.data;

        if key.dsize() == 1 {
            // One digit per limb: the product is a plain vector-matrix product.
            self.vmp_apply_dft_to_dft(res, a, pmat, scratch);
        } else {
            let dsize: usize = key.dsize().into();
            let dnum: usize = key.dnum().into();

            // Buffer holding one digit of `a` at a time. It is sized for the largest digit
            // (ceil(a_size / dsize) limbs), bounded by the key's number of rows `dnum`.
            let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols, a_size.div_ceil(dsize).min(dnum));
            ai_dft.data_mut().fill(0);

            for di in 0..dsize {
                // floor((a_size + di) / dsize) is the number of limb indices in `0..a_size` that
                // are congruent to `dsize - di - 1` modulo `dsize`. That is, it is the number of
                // limbs in the digit copied below, capped at `dnum`.
                ai_dft.set_size(((a_size + di) / dsize).min(dnum));

                // Drops the last `dsize - di - 2` limbs of `res` (when that count is positive)
                // for this iteration. Only the final two digits use the full `pmat.size()`.
                // NOTE(review): presumably the dropped low limbs only carry accumulated VMP
                // noise. Confirm the noise rationale before changing this bound.
                res.set_size(pmat.size() - ((dsize - di) as isize - 2).max(0) as usize);

                // Gathers digit `di` of every column. This assumes
                // `vec_znx_dft_copy(step, offset, ..)` copies limbs `offset, offset + step, ...`
                // (TODO confirm against the backend docs).
                for j in 0..cols {
                    self.vec_znx_dft_copy(dsize, dsize - di - 1, &mut ai_dft, j, a, j);
                }

                if di == 0 {
                    // The first digit initializes `res`.
                    self.vmp_apply_dft_to_dft(res, &ai_dft, pmat, scratch_1);
                } else {
                    // Later digits are accumulated. The `di` argument presumably offsets the
                    // accumulation by `di` limbs (TODO confirm).
                    self.vmp_apply_dft_to_dft_add(res, &ai_dft, pmat, di, scratch_1);
                }
            }

            // Undo the temporary truncation applied inside the loop.
            res.set_size(res.max_size());
        }
    }
}