1use poulpy_hal::{
2 api::{
3 ModuleN, ScratchAvailable, ScratchTakeBasic, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
4 VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft,
5 VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
6 },
7 layouts::{Backend, DataMut, DataViewMut, Module, Scratch, VecZnx, VecZnxBig, VecZnxDft, VmpPMat, ZnxInfos},
8};
9
10use crate::{
11 ScratchTakeCore,
12 layouts::{GGLWEInfos, GGLWEPrepared, GGLWEPreparedToRef, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos},
13};
14
15impl GLWE<Vec<u8>> {
16 pub fn keyswitch_tmp_bytes<M, R, A, B, BE: Backend>(module: &M, res_infos: &R, a_infos: &A, key_infos: &B) -> usize
17 where
18 R: GLWEInfos,
19 A: GLWEInfos,
20 B: GGLWEInfos,
21 M: GLWEKeyswitch<BE>,
22 {
23 module.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos)
24 }
25}
26
27impl<D: DataMut> GLWE<D> {
28 pub fn keyswitch<A, B, M, BE: Backend>(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch<BE>)
29 where
30 A: GLWEToRef,
31 B: GGLWEPreparedToRef<BE>,
32 M: GLWEKeyswitch<BE>,
33 Scratch<BE>: ScratchTakeCore<BE>,
34 {
35 module.glwe_keyswitch(self, a, b, scratch);
36 }
37
38 pub fn keyswitch_inplace<A, M, BE: Backend>(&mut self, module: &M, a: &A, scratch: &mut Scratch<BE>)
39 where
40 A: GGLWEPreparedToRef<BE>,
41 M: GLWEKeyswitch<BE>,
42 Scratch<BE>: ScratchTakeCore<BE>,
43 {
44 module.glwe_keyswitch_inplace(self, a, scratch);
45 }
46}
47
48impl<BE: Backend> GLWEKeyswitch<BE> for Module<BE> where
49 Self: Sized
50 + ModuleN
51 + VecZnxDftBytesOf
52 + VmpApplyDftToDftTmpBytes
53 + VecZnxBigNormalizeTmpBytes
54 + VecZnxNormalizeTmpBytes
55 + VecZnxDftBytesOf
56 + VmpApplyDftToDftTmpBytes
57 + VecZnxBigNormalizeTmpBytes
58 + VmpApplyDftToDft<BE>
59 + VmpApplyDftToDftAdd<BE>
60 + VecZnxDftApply<BE>
61 + VecZnxIdftApplyConsume<BE>
62 + VecZnxBigAddSmallInplace<BE>
63 + VecZnxBigNormalize<BE>
64 + VecZnxNormalize<BE>
65 + VecZnxNormalizeTmpBytes
66{
67}
68
69pub trait GLWEKeyswitch<BE: Backend>
70where
71 Self: Sized
72 + ModuleN
73 + VecZnxDftBytesOf
74 + VmpApplyDftToDftTmpBytes
75 + VecZnxBigNormalizeTmpBytes
76 + VecZnxNormalizeTmpBytes
77 + VecZnxDftBytesOf
78 + VmpApplyDftToDftTmpBytes
79 + VecZnxBigNormalizeTmpBytes
80 + VmpApplyDftToDft<BE>
81 + VmpApplyDftToDftAdd<BE>
82 + VecZnxDftApply<BE>
83 + VecZnxIdftApplyConsume<BE>
84 + VecZnxBigAddSmallInplace<BE>
85 + VecZnxBigNormalize<BE>
86 + VecZnxNormalize<BE>
87 + VecZnxNormalizeTmpBytes,
88{
89 fn glwe_keyswitch_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, key_infos: &B) -> usize
90 where
91 R: GLWEInfos,
92 A: GLWEInfos,
93 B: GGLWEInfos,
94 {
95 let in_size: usize = a_infos
96 .k()
97 .div_ceil(key_infos.base2k())
98 .div_ceil(key_infos.dsize().into()) as usize;
99 let out_size: usize = res_infos.size();
100 let ksk_size: usize = key_infos.size();
101 let res_dft: usize = self.bytes_of_vec_znx_dft((key_infos.rank_out() + 1).into(), ksk_size); let ai_dft: usize = self.bytes_of_vec_znx_dft((key_infos.rank_in()).into(), in_size);
103 let vmp: usize = self.vmp_apply_dft_to_dft_tmp_bytes(
104 out_size,
105 in_size,
106 in_size,
107 (key_infos.rank_in()).into(),
108 (key_infos.rank_out() + 1).into(),
109 ksk_size,
110 ) + self.bytes_of_vec_znx_dft((key_infos.rank_in()).into(), in_size);
111 let normalize_big: usize = self.vec_znx_big_normalize_tmp_bytes();
112 if a_infos.base2k() == key_infos.base2k() {
113 res_dft + ((ai_dft + vmp) | normalize_big)
114 } else if key_infos.dsize() == 1 {
115 let normalize_conv: usize = VecZnx::bytes_of(self.n(), 1, in_size) + self.vec_znx_normalize_tmp_bytes();
117 res_dft + (((ai_dft + normalize_conv) | vmp) | normalize_big)
118 } else {
119 let normalize_conv: usize = VecZnx::bytes_of(self.n(), (key_infos.rank_in()).into(), in_size);
121 res_dft + ((ai_dft + normalize_conv + (self.vec_znx_normalize_tmp_bytes() | vmp)) | normalize_big)
122 }
123 }
124
125 fn glwe_keyswitch<R, A, K>(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch<BE>)
126 where
127 R: GLWEToMut,
128 A: GLWEToRef,
129 K: GGLWEPreparedToRef<BE>,
130 Scratch<BE>: ScratchTakeCore<BE>,
131 {
132 let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
133 let a: &GLWE<&[u8]> = &a.to_ref();
134 let b: &GGLWEPrepared<&[u8], BE> = &key.to_ref();
135
136 assert_eq!(
137 a.rank(),
138 b.rank_in(),
139 "a.rank(): {} != b.rank_in(): {}",
140 a.rank(),
141 b.rank_in()
142 );
143 assert_eq!(
144 res.rank(),
145 b.rank_out(),
146 "res.rank(): {} != b.rank_out(): {}",
147 res.rank(),
148 b.rank_out()
149 );
150
151 assert_eq!(res.n(), self.n() as u32);
152 assert_eq!(a.n(), self.n() as u32);
153 assert_eq!(b.n(), self.n() as u32);
154
155 let scrach_needed: usize = self.glwe_keyswitch_tmp_bytes(res, a, b);
156
157 assert!(
158 scratch.available() >= scrach_needed,
159 "scratch.available()={} < glwe_keyswitch_tmp_bytes={scrach_needed}",
160 scratch.available(),
161 );
162
163 let basek_out: usize = res.base2k().into();
164 let base2k_out: usize = b.base2k().into();
165
166 let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), b.size()); let res_big: VecZnxBig<&mut [u8], BE> = keyswitch_internal(self, res_dft, a, b, scratch_1);
168 (0..(res.rank() + 1).into()).for_each(|i| {
169 self.vec_znx_big_normalize(
170 basek_out,
171 &mut res.data,
172 i,
173 base2k_out,
174 &res_big,
175 i,
176 scratch_1,
177 );
178 })
179 }
180
181 fn glwe_keyswitch_inplace<R, K>(&self, res: &mut R, key: &K, scratch: &mut Scratch<BE>)
182 where
183 R: GLWEToMut,
184 K: GGLWEPreparedToRef<BE>,
185 Scratch<BE>: ScratchTakeCore<BE>,
186 {
187 let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
188 let a: &GGLWEPrepared<&[u8], BE> = &key.to_ref();
189
190 assert_eq!(
191 res.rank(),
192 a.rank_in(),
193 "res.rank(): {} != a.rank_in(): {}",
194 res.rank(),
195 a.rank_in()
196 );
197 assert_eq!(
198 res.rank(),
199 a.rank_out(),
200 "res.rank(): {} != b.rank_out(): {}",
201 res.rank(),
202 a.rank_out()
203 );
204
205 assert_eq!(res.n(), self.n() as u32);
206 assert_eq!(a.n(), self.n() as u32);
207
208 let scrach_needed: usize = self.glwe_keyswitch_tmp_bytes(res, res, a);
209
210 assert!(
211 scratch.available() >= scrach_needed,
212 "scratch.available()={} < glwe_keyswitch_tmp_bytes={scrach_needed}",
213 scratch.available(),
214 );
215
216 let base2k_in: usize = res.base2k().into();
217 let base2k_out: usize = a.base2k().into();
218
219 let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), a.size()); let res_big: VecZnxBig<&mut [u8], BE> = keyswitch_internal(self, res_dft, res, a, scratch_1);
221 (0..(res.rank() + 1).into()).for_each(|i| {
222 self.vec_znx_big_normalize(
223 base2k_in,
224 &mut res.data,
225 i,
226 base2k_out,
227 &res_big,
228 i,
229 scratch_1,
230 );
231 })
232 }
233}
234
// NOTE(review): empty impl block — likely a leftover after methods were moved
// into the `GLWEKeyswitch` trait above; candidate for removal.
impl GLWE<Vec<u8>> {}
236
// NOTE(review): empty impl block — likely a leftover; candidate for removal.
impl<DataSelf: DataMut> GLWE<DataSelf> {}
238
239pub(crate) fn keyswitch_internal<BE: Backend, M, DR, A, K>(
240 module: &M,
241 mut res: VecZnxDft<DR, BE>,
242 a: &A,
243 key: &K,
244 scratch: &mut Scratch<BE>,
245) -> VecZnxBig<DR, BE>
246where
247 DR: DataMut,
248 A: GLWEToRef,
249 K: GGLWEPreparedToRef<BE>,
250 M: ModuleN
251 + VecZnxDftBytesOf
252 + VmpApplyDftToDftTmpBytes
253 + VecZnxBigNormalizeTmpBytes
254 + VmpApplyDftToDftTmpBytes
255 + VmpApplyDftToDft<BE>
256 + VmpApplyDftToDftAdd<BE>
257 + VecZnxDftApply<BE>
258 + VecZnxIdftApplyConsume<BE>
259 + VecZnxBigAddSmallInplace<BE>
260 + VecZnxBigNormalize<BE>
261 + VecZnxNormalize<BE>,
262 Scratch<BE>: ScratchTakeCore<BE>,
263{
264 let a: &GLWE<&[u8]> = &a.to_ref();
265 let key: &GGLWEPrepared<&[u8], BE> = &key.to_ref();
266
267 let base2k_in: usize = a.base2k().into();
268 let base2k_out: usize = key.base2k().into();
269 let cols: usize = (a.rank() + 1).into();
270 let a_size: usize = (a.size() * base2k_in).div_ceil(base2k_out);
271 let pmat: &VmpPMat<&[u8], BE> = &key.data;
272
273 if key.dsize() == 1 {
274 let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(module, cols - 1, a.size());
275
276 if base2k_in == base2k_out {
277 (0..cols - 1).for_each(|col_i| {
278 module.vec_znx_dft_apply(1, 0, &mut ai_dft, col_i, a.data(), col_i + 1);
279 });
280 } else {
281 let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(module.n(), 1, a_size);
282 (0..cols - 1).for_each(|col_i| {
283 module.vec_znx_normalize(
284 base2k_out,
285 &mut a_conv,
286 0,
287 base2k_in,
288 a.data(),
289 col_i + 1,
290 scratch_2,
291 );
292 module.vec_znx_dft_apply(1, 0, &mut ai_dft, col_i, &a_conv, 0);
293 });
294 }
295
296 module.vmp_apply_dft_to_dft(&mut res, &ai_dft, pmat, scratch_1);
297 } else {
298 let dsize: usize = key.dsize().into();
299
300 let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(module, cols - 1, a_size.div_ceil(dsize));
301 ai_dft.data_mut().fill(0);
302
303 if base2k_in == base2k_out {
304 for di in 0..dsize {
305 ai_dft.set_size((a_size + di) / dsize);
306
307 res.set_size(pmat.size() - ((dsize - di) as isize - 2).max(0) as usize);
315
316 for j in 0..cols - 1 {
317 module.vec_znx_dft_apply(dsize, dsize - di - 1, &mut ai_dft, j, a.data(), j + 1);
318 }
319
320 if di == 0 {
321 module.vmp_apply_dft_to_dft(&mut res, &ai_dft, pmat, scratch_1);
322 } else {
323 module.vmp_apply_dft_to_dft_add(&mut res, &ai_dft, pmat, di, scratch_1);
324 }
325 }
326 } else {
327 let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(module.n(), cols - 1, a_size);
328 for j in 0..cols - 1 {
329 module.vec_znx_normalize(
330 base2k_out,
331 &mut a_conv,
332 j,
333 base2k_in,
334 a.data(),
335 j + 1,
336 scratch_2,
337 );
338 }
339
340 for di in 0..dsize {
341 ai_dft.set_size((a_size + di) / dsize);
342
343 res.set_size(pmat.size() - ((dsize - di) as isize - 2).max(0) as usize);
351
352 for j in 0..cols - 1 {
353 module.vec_znx_dft_apply(dsize, dsize - di - 1, &mut ai_dft, j, &a_conv, j);
354 }
355
356 if di == 0 {
357 module.vmp_apply_dft_to_dft(&mut res, &ai_dft, pmat, scratch_2);
358 } else {
359 module.vmp_apply_dft_to_dft_add(&mut res, &ai_dft, pmat, di, scratch_2);
360 }
361 }
362 }
363
364 res.set_size(res.max_size());
365 }
366
367 let mut res_big: VecZnxBig<DR, BE> = module.vec_znx_idft_apply_consume(res);
368 module.vec_znx_big_add_small_inplace(&mut res_big, 0, a.data(), 0);
369 res_big
370}