use poulpy_hal::{
    api::{
        ScratchAvailable, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes,
        VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAddInplace, VecZnxDftAllocBytes, VecZnxDftApply,
        VecZnxDftCopy, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft,
        VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
    },
    layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnx, VmpPMat, ZnxInfos},
};

use crate::{
    layouts::{
        GGLWECiphertext, GGLWELayoutInfos, GGSWCiphertext, GGSWInfos, GLWECiphertext, GLWEInfos, LWEInfos,
        prepared::{GGLWESwitchingKeyPrepared, GGLWETensorKeyPrepared},
    },
    operations::GLWEOperations,
};

impl GGSWCiphertext<Vec<u8>> {
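    /// Returns the scratch space, in bytes, required by [`Self::expand_row`] for an
    /// output with layout `out_infos` and a tensor key with layout `tsk_infos`.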
    pub(crate) fn expand_row_scratch_space<B: Backend, OUT, TSK>(module: &Module<B>, out_infos: &OUT, tsk_infos: &TSK) -> usize
    where
        OUT: GGSWInfos,
        TSK: GGLWELayoutInfos,
        Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigAllocBytes + VecZnxNormalizeTmpBytes,
    {
        let tsk_size: usize = tsk_infos.k().div_ceil(tsk_infos.base2k()) as usize;
        let size_in: usize = out_infos
            .k()
            .div_ceil(tsk_infos.base2k())
            .div_ceil(tsk_infos.digits().into()) as usize;

        let tmp_dft_i: usize = module.vec_znx_dft_alloc_bytes((tsk_infos.rank_out() + 1).into(), tsk_size);
        let tmp_a: usize = module.vec_znx_dft_alloc_bytes(1, size_in);
        let vmp: usize = module.vmp_apply_dft_to_dft_tmp_bytes(
            tsk_size,
            size_in,
            size_in,
            (tsk_infos.rank_in()).into(),
            (tsk_infos.rank_out()).into(),
            tsk_size,
        );
        let tmp_idft: usize = module.vec_znx_big_alloc_bytes(1, tsk_size);
        let norm: usize = module.vec_znx_normalize_tmp_bytes();

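        // `tmp_a + vmp` (digit extraction + products) and `tmp_idft + norm`
        // (inverse DFT + normalization) occupy the same region after `tmp_dft_i`
        // and are not used simultaneously, so `|` merges the two requirements
        // as a cheap upper bound on their maximum.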
        tmp_dft_i + ((tmp_a + vmp) | (tmp_idft + norm))
    }

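    /// Returns the scratch space, in bytes, required by [`Self::keyswitch`].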
    #[allow(clippy::too_many_arguments)]
    pub fn keyswitch_scratch_space<B: Backend, OUT, IN, KEY, TSK>(
        module: &Module<B>,
        out_infos: &OUT,
        in_infos: &IN,
        apply_infos: &KEY,
        tsk_infos: &TSK,
    ) -> usize
    where
        OUT: GGSWInfos,
        IN: GGSWInfos,
        KEY: GGLWELayoutInfos,
        TSK: GGLWELayoutInfos,
        Module<B>: VecZnxDftAllocBytes
            + VmpApplyDftToDftTmpBytes
            + VecZnxBigAllocBytes
            + VecZnxNormalizeTmpBytes
            + VecZnxBigNormalizeTmpBytes,
    {
        #[cfg(debug_assertions)]
        {
            assert_eq!(apply_infos.rank_in(), apply_infos.rank_out());
            assert_eq!(tsk_infos.rank_in(), tsk_infos.rank_out());
            assert_eq!(apply_infos.rank_in(), tsk_infos.rank_in());
        }

        let rank: usize = apply_infos.rank_out().into();

        let size_out: usize = out_infos.k().div_ceil(out_infos.base2k()) as usize;
        let res_znx: usize = VecZnx::alloc_bytes(module.n(), rank + 1, size_out);
        let ci_dft: usize = module.vec_znx_dft_alloc_bytes(rank + 1, size_out);
        let ks: usize = GLWECiphertext::keyswitch_scratch_space(module, out_infos, in_infos, apply_infos);
        let expand_rows: usize = GGSWCiphertext::expand_row_scratch_space(module, out_infos, tsk_infos);
        let res_dft: usize = module.vec_znx_dft_alloc_bytes(rank + 1, size_out);

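        // When the input and tensor-key bases match, the rows can be processed
        // directly; otherwise an extra buffer is needed to renormalize the input
        // to the tensor key's base2k.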
        if in_infos.base2k() == tsk_infos.base2k() {
            res_znx + ci_dft + (ks | expand_rows | res_dft)
        } else {
            let a_conv: usize = VecZnx::alloc_bytes(
                module.n(),
                1,
                out_infos.k().div_ceil(tsk_infos.base2k()) as usize,
            ) + module.vec_znx_normalize_tmp_bytes();
            res_znx + ci_dft + (a_conv | ks | expand_rows | res_dft)
        }
    }

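    /// Returns the scratch space, in bytes, required by [`Self::keyswitch_inplace`].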
    #[allow(clippy::too_many_arguments)]
    pub fn keyswitch_inplace_scratch_space<B: Backend, OUT, KEY, TSK>(
        module: &Module<B>,
        out_infos: &OUT,
        apply_infos: &KEY,
        tsk_infos: &TSK,
    ) -> usize
    where
        OUT: GGSWInfos,
        KEY: GGLWELayoutInfos,
        TSK: GGLWELayoutInfos,
        Module<B>: VecZnxDftAllocBytes
            + VmpApplyDftToDftTmpBytes
            + VecZnxBigAllocBytes
            + VecZnxNormalizeTmpBytes
            + VecZnxBigNormalizeTmpBytes,
    {
        GGSWCiphertext::keyswitch_scratch_space(module, out_infos, out_infos, apply_infos, tsk_infos)
    }
}

impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
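    /// Converts the GGLWE ciphertext `a` into a GGSW ciphertext: the first column
    /// of each row is copied as-is, and the remaining columns are reconstructed
    /// with the tensor key `tsk` via [`Self::expand_row`].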
    pub fn from_gglwe<DataA, DataTsk, B: Backend>(
        &mut self,
        module: &Module<B>,
        a: &GGLWECiphertext<DataA>,
        tsk: &GGLWETensorKeyPrepared<DataTsk, B>,
        scratch: &mut Scratch<B>,
    ) where
        DataA: DataRef,
        DataTsk: DataRef,
        Module<B>: VecZnxCopy
            + VecZnxDftAllocBytes
            + VmpApplyDftToDftTmpBytes
            + VecZnxBigAllocBytes
            + VecZnxNormalizeTmpBytes
            + VecZnxDftApply<B>
            + VecZnxDftCopy<B>
            + VmpApplyDftToDft<B>
            + VmpApplyDftToDftAdd<B>
            + VecZnxDftAddInplace<B>
            + VecZnxBigNormalize<B>
            + VecZnxIdftApplyTmpA<B>
            + VecZnxNormalize<B>,
        Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeVecZnxBig<B> + TakeVecZnx,
    {
        #[cfg(debug_assertions)]
        {
            use crate::layouts::{GLWEInfos, LWEInfos};

            assert_eq!(self.rank(), a.rank_out());
            assert_eq!(self.rows(), a.rows());
            assert_eq!(self.n(), module.n() as u32);
            assert_eq!(a.n(), module.n() as u32);
            assert_eq!(tsk.n(), module.n() as u32);
        }
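        // Copy the first column of each row of `a`; the remaining columns are
        // reconstructed by `expand_row` below.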
        (0..self.rows().into()).for_each(|row_i| {
            self.at_mut(row_i, 0).copy(module, &a.at(row_i, 0));
        });
        self.expand_row(module, tsk, scratch);
    }

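    /// Key-switches the GGSW ciphertext `lhs` into `self`: the first column of
    /// each row is key-switched with `ksk`, then the remaining columns are
    /// reconstructed with the tensor key `tsk` via [`Self::expand_row`].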
    pub fn keyswitch<DataLhs: DataRef, DataKsk: DataRef, DataTsk: DataRef, B: Backend>(
        &mut self,
        module: &Module<B>,
        lhs: &GGSWCiphertext<DataLhs>,
        ksk: &GGLWESwitchingKeyPrepared<DataKsk, B>,
        tsk: &GGLWETensorKeyPrepared<DataTsk, B>,
        scratch: &mut Scratch<B>,
    ) where
        Module<B>: VecZnxDftAllocBytes
            + VmpApplyDftToDftTmpBytes
            + VecZnxBigNormalizeTmpBytes
            + VmpApplyDftToDft<B>
            + VmpApplyDftToDftAdd<B>
            + VecZnxDftApply<B>
            + VecZnxIdftApplyConsume<B>
            + VecZnxBigAddSmallInplace<B>
            + VecZnxBigNormalize<B>
            + VecZnxBigAllocBytes
            + VecZnxNormalizeTmpBytes
            + VecZnxDftCopy<B>
            + VecZnxDftAddInplace<B>
            + VecZnxIdftApplyTmpA<B>
            + VecZnxNormalize<B>,
        Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeVecZnxBig<B> + TakeVecZnx,
    {
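        // Key-switch the first column of each row of `lhs`; the remaining
        // columns are reconstructed by `expand_row` below.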
        (0..lhs.rows().into()).for_each(|row_i| {
            self.at_mut(row_i, 0)
                .keyswitch(module, &lhs.at(row_i, 0), ksk, scratch);
        });
        self.expand_row(module, tsk, scratch);
    }

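    /// In-place variant of [`Self::keyswitch`]: key-switches each row of `self`
    /// with `ksk`, then reconstructs the remaining columns with the tensor key `tsk`.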
    pub fn keyswitch_inplace<DataKsk: DataRef, DataTsk: DataRef, B: Backend>(
        &mut self,
        module: &Module<B>,
        ksk: &GGLWESwitchingKeyPrepared<DataKsk, B>,
        tsk: &GGLWETensorKeyPrepared<DataTsk, B>,
        scratch: &mut Scratch<B>,
    ) where
        Module<B>: VecZnxDftAllocBytes
            + VmpApplyDftToDftTmpBytes
            + VecZnxBigNormalizeTmpBytes
            + VmpApplyDftToDft<B>
            + VmpApplyDftToDftAdd<B>
            + VecZnxDftApply<B>
            + VecZnxIdftApplyConsume<B>
            + VecZnxBigAddSmallInplace<B>
            + VecZnxBigNormalize<B>
            + VecZnxBigAllocBytes
            + VecZnxNormalizeTmpBytes
            + VecZnxDftCopy<B>
            + VecZnxDftAddInplace<B>
            + VecZnxIdftApplyTmpA<B>
            + VecZnxNormalize<B>,
        Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeVecZnxBig<B> + TakeVecZnx,
    {
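        // Key-switch each row's first column in place; the remaining columns
        // are reconstructed by `expand_row` below.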
        (0..self.rows().into()).for_each(|row_i| {
            self.at_mut(row_i, 0)
                .keyswitch_inplace(module, ksk, scratch);
        });
        self.expand_row(module, tsk, scratch);
    }

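    /// Reconstructs columns `1..=rank` of every row of `self` from column 0,
    /// using the tensor key `tsk`: each mask column is digit-decomposed and
    /// multiplied with the corresponding tensor-key matrix in the DFT domain,
    /// then mapped back to the coefficient domain and normalized.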
    pub fn expand_row<DataTsk: DataRef, B: Backend>(
        &mut self,
        module: &Module<B>,
        tsk: &GGLWETensorKeyPrepared<DataTsk, B>,
        scratch: &mut Scratch<B>,
    ) where
        Module<B>: VecZnxDftAllocBytes
            + VmpApplyDftToDftTmpBytes
            + VecZnxBigAllocBytes
            + VecZnxNormalizeTmpBytes
            + VecZnxDftApply<B>
            + VecZnxDftCopy<B>
            + VmpApplyDftToDft<B>
            + VmpApplyDftToDftAdd<B>
            + VecZnxDftAddInplace<B>
            + VecZnxBigNormalize<B>
            + VecZnxIdftApplyTmpA<B>
            + VecZnxNormalize<B>,
        Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeVecZnxBig<B> + TakeVecZnx,
    {
        let basek_in: usize = self.base2k().into();
        let basek_tsk: usize = tsk.base2k().into();

        assert!(scratch.available() >= GGSWCiphertext::expand_row_scratch_space(module, self, tsk));

        let n: usize = self.n().into();
        let rank: usize = self.rank().into();
        let cols: usize = rank + 1;

        let a_size: usize = (self.size() * basek_in).div_ceil(basek_tsk);

        for row_i in 0..self.rows().into() {
            let a = &self.at(row_i, 0).data;

            let (mut ci_dft, scratch_1) = scratch.take_vec_znx_dft(n, cols, a_size);

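            // Lift all `rank + 1` columns of the row's first GLWE into the DFT
            // domain, renormalizing to the tensor key's base2k first if the
            // bases differ.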
            if basek_in == basek_tsk {
                for i in 0..cols {
                    module.vec_znx_dft_apply(1, 0, &mut ci_dft, i, a, i);
                }
            } else {
                let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(n, 1, a_size);
                for i in 0..cols {
                    module.vec_znx_normalize(basek_tsk, &mut a_conv, 0, basek_in, a, i, scratch_2);
                    module.vec_znx_dft_apply(1, 0, &mut ci_dft, i, &a_conv, 0);
                }
            }

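            // Reconstruct column `col_j`: accumulate the digit-decomposed
            // products of every mask column of `ci_dft` with the matching
            // tensor-key matrix.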
            for col_j in 1..cols {
                let digits: usize = tsk.digits().into();

                let (mut tmp_dft_i, scratch_2) = scratch_1.take_vec_znx_dft(n, cols, tsk.size());
                let (mut tmp_a, scratch_3) = scratch_2.take_vec_znx_dft(n, 1, ci_dft.size().div_ceil(digits));

                {
                    for col_i in 1..cols {
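                        // Prepared matrix of tensor-key entry (col_i - 1, col_j - 1),
                        // applied once per digit of the decomposition below.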
                        let pmat: &VmpPMat<DataTsk, B> = &tsk.at(col_i - 1, col_j - 1).key.data;

                        for di in 0..digits {
                            tmp_a.set_size((ci_dft.size() + di) / digits);

                            tmp_dft_i.set_size(tsk.size() - ((digits - di) as isize - 2).max(0) as usize);

                            module.vec_znx_dft_copy(digits, digits - 1 - di, &mut tmp_a, 0, &ci_dft, col_i);
                            if di == 0 && col_i == 1 {
                                module.vmp_apply_dft_to_dft(&mut tmp_dft_i, &tmp_a, pmat, scratch_3);
                            } else {
                                module.vmp_apply_dft_to_dft_add(&mut tmp_dft_i, &tmp_a, pmat, di, scratch_3);
                            }
                        }
                    }
                }

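                // Fold column 0 of `ci_dft` (left untouched by the tensor key)
                // into column `col_j` of the accumulator.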
                module.vec_znx_dft_add_inplace(&mut tmp_dft_i, col_j, &ci_dft, 0);
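                // Map the accumulator back to the coefficient domain and
                // normalize each column into the output row at `self`'s base2k.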
                let (mut tmp_idft, scratch_3) = scratch_2.take_vec_znx_big(n, 1, tsk.size());
                for i in 0..cols {
                    module.vec_znx_idft_apply_tmpa(&mut tmp_idft, 0, &mut tmp_dft_i, i);
                    module.vec_znx_big_normalize(
                        basek_in,
                        &mut self.at_mut(row_i, col_j).data,
                        i,
                        basek_tsk,
                        &tmp_idft,
                        0,
                        scratch_3,
                    );
                }
            }
        }
    }
}