use poulpy_hal::{
    api::{
        DFT, IDFTConsume, IDFTTmpA, ScratchAvailable, TakeVecZnxBig, TakeVecZnxDft, VecZnxBigAddSmallInplace,
        VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAddInplace,
        VecZnxDftAllocBytes, VecZnxDftCopy, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd,
        VmpApplyDftToDftTmpBytes,
    },
    layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnx, VmpPMat, ZnxInfos},
};

use crate::{
    layouts::{
        GGLWECiphertext, GGSWCiphertext, GLWECiphertext, Infos,
        prepared::{GGLWESwitchingKeyPrepared, GGLWETensorKeyPrepared},
    },
    operations::GLWEOperations,
};

impl GGSWCiphertext<Vec<u8>> {
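    /// Returns the number of scratch bytes required by [`GGSWCiphertext::expand_row`]:
    /// the DFT/IDFT temporaries plus the space needed by the digit-wise VMP products
    /// against the tensor key.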
    pub(crate) fn expand_row_scratch_space<B: Backend>(
        module: &Module<B>,
        basek: usize,
        self_k: usize,
        k_tsk: usize,
        digits: usize,
        rank: usize,
    ) -> usize
    where
        Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigAllocBytes + VecZnxNormalizeTmpBytes,
    {
        let tsk_size: usize = k_tsk.div_ceil(basek);
        let self_size_out: usize = self_k.div_ceil(basek);
        let self_size_in: usize = self_size_out.div_ceil(digits);

        let tmp_dft_i: usize = module.vec_znx_dft_alloc_bytes(rank + 1, tsk_size);
        let tmp_a: usize = module.vec_znx_dft_alloc_bytes(1, self_size_in);
        let vmp: usize = module.vmp_apply_dft_to_dft_tmp_bytes(
            self_size_out,
            self_size_in,
            self_size_in,
            rank,
            rank,
            tsk_size,
        );
        let tmp_idft: usize = module.vec_znx_big_alloc_bytes(1, tsk_size);
        let norm: usize = module.vec_znx_normalize_tmp_bytes();
        tmp_dft_i + ((tmp_a + vmp) | (tmp_idft + norm))
    }

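    /// Returns the number of scratch bytes required by [`GGSWCiphertext::keyswitch`]:
    /// the temporary row and its DFT, plus the space reused across the per-row GLWE
    /// key-switch, the row expansion with the tensor key, and the result DFT.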
    #[allow(clippy::too_many_arguments)]
    pub fn keyswitch_scratch_space<B: Backend>(
        module: &Module<B>,
        basek: usize,
        k_out: usize,
        k_in: usize,
        k_ksk: usize,
        digits_ksk: usize,
        k_tsk: usize,
        digits_tsk: usize,
        rank: usize,
    ) -> usize
    where
        Module<B>: VecZnxDftAllocBytes
            + VmpApplyDftToDftTmpBytes
            + VecZnxBigAllocBytes
            + VecZnxNormalizeTmpBytes
            + VecZnxBigNormalizeTmpBytes,
    {
        let out_size: usize = k_out.div_ceil(basek);
        let res_znx: usize = VecZnx::alloc_bytes(module.n(), rank + 1, out_size);
        let ci_dft: usize = module.vec_znx_dft_alloc_bytes(rank + 1, out_size);
        let ks: usize = GLWECiphertext::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits_ksk, rank, rank);
        let expand_rows: usize = GGSWCiphertext::expand_row_scratch_space(module, basek, k_out, k_tsk, digits_tsk, rank);
        let res_dft: usize = module.vec_znx_dft_alloc_bytes(rank + 1, out_size);
        res_znx + ci_dft + (ks | expand_rows | res_dft)
    }

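    /// Returns the number of scratch bytes required by [`GGSWCiphertext::keyswitch_inplace`];
    /// identical to [`GGSWCiphertext::keyswitch_scratch_space`] with `k_in = k_out`.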
    #[allow(clippy::too_many_arguments)]
    pub fn keyswitch_inplace_scratch_space<B: Backend>(
        module: &Module<B>,
        basek: usize,
        k_out: usize,
        k_ksk: usize,
        digits_ksk: usize,
        k_tsk: usize,
        digits_tsk: usize,
        rank: usize,
    ) -> usize
    where
        Module<B>: VecZnxDftAllocBytes
            + VmpApplyDftToDftTmpBytes
            + VecZnxBigAllocBytes
            + VecZnxNormalizeTmpBytes
            + VecZnxBigNormalizeTmpBytes,
    {
        GGSWCiphertext::keyswitch_scratch_space(
            module, basek, k_out, k_out, k_ksk, digits_ksk, k_tsk, digits_tsk, rank,
        )
    }
}

impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
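    /// Fills this GGSW ciphertext from the GGLWE ciphertext `a`: column 0 of every row is
    /// copied from `a`, and the remaining columns are reconstructed from the prepared
    /// tensor key via [`GGSWCiphertext::expand_row`].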
    pub fn from_gglwe<DataA, DataTsk, B: Backend>(
        &mut self,
        module: &Module<B>,
        a: &GGLWECiphertext<DataA>,
        tsk: &GGLWETensorKeyPrepared<DataTsk, B>,
        scratch: &mut Scratch<B>,
    ) where
        DataA: DataRef,
        DataTsk: DataRef,
        Module<B>: VecZnxCopy
            + VecZnxDftAllocBytes
            + VmpApplyDftToDftTmpBytes
            + VecZnxBigAllocBytes
            + VecZnxNormalizeTmpBytes
            + DFT<B>
            + VecZnxDftCopy<B>
            + VmpApplyDftToDft<B>
            + VmpApplyDftToDftAdd<B>
            + VecZnxDftAddInplace<B>
            + VecZnxBigNormalize<B>
            + IDFTTmpA<B>,
        Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeVecZnxBig<B>,
    {
        #[cfg(debug_assertions)]
        {
            assert_eq!(self.rank(), a.rank());
            assert_eq!(self.rows(), a.rows());
            assert_eq!(self.n(), module.n());
            assert_eq!(a.n(), module.n());
            assert_eq!(tsk.n(), module.n());
        }
        (0..self.rows()).for_each(|row_i| {
            self.at_mut(row_i, 0).copy(module, &a.at(row_i, 0));
        });
        self.expand_row(module, tsk, scratch);
    }

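    /// Key-switches the GGSW ciphertext `lhs` into `self`: column 0 of every row is
    /// key-switched with `ksk`, then the remaining columns are rebuilt from the prepared
    /// tensor key `tsk` via [`GGSWCiphertext::expand_row`].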
    pub fn keyswitch<DataLhs: DataRef, DataKsk: DataRef, DataTsk: DataRef, B: Backend>(
        &mut self,
        module: &Module<B>,
        lhs: &GGSWCiphertext<DataLhs>,
        ksk: &GGLWESwitchingKeyPrepared<DataKsk, B>,
        tsk: &GGLWETensorKeyPrepared<DataTsk, B>,
        scratch: &mut Scratch<B>,
    ) where
        Module<B>: VecZnxDftAllocBytes
            + VmpApplyDftToDftTmpBytes
            + VecZnxBigNormalizeTmpBytes
            + VmpApplyDftToDft<B>
            + VmpApplyDftToDftAdd<B>
            + DFT<B>
            + IDFTConsume<B>
            + VecZnxBigAddSmallInplace<B>
            + VecZnxBigNormalize<B>
            + VecZnxBigAllocBytes
            + VecZnxNormalizeTmpBytes
            + VecZnxDftCopy<B>
            + VecZnxDftAddInplace<B>
            + IDFTTmpA<B>,
        Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeVecZnxBig<B>,
    {
        self.keyswitch_internal(module, lhs, ksk, scratch);
        self.expand_row(module, tsk, scratch);
    }

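    /// In-place variant of [`GGSWCiphertext::keyswitch`]: `self` is used both as the
    /// input and the output of the key-switch.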
    pub fn keyswitch_inplace<DataKsk: DataRef, DataTsk: DataRef, B: Backend>(
        &mut self,
        module: &Module<B>,
        ksk: &GGLWESwitchingKeyPrepared<DataKsk, B>,
        tsk: &GGLWETensorKeyPrepared<DataTsk, B>,
        scratch: &mut Scratch<B>,
    ) where
        Module<B>: VecZnxDftAllocBytes
            + VmpApplyDftToDftTmpBytes
            + VecZnxBigNormalizeTmpBytes
            + VmpApplyDftToDft<B>
            + VmpApplyDftToDftAdd<B>
            + DFT<B>
            + IDFTConsume<B>
            + VecZnxBigAddSmallInplace<B>
            + VecZnxBigNormalize<B>
            + VecZnxBigAllocBytes
            + VecZnxNormalizeTmpBytes
            + VecZnxDftCopy<B>
            + VecZnxDftAddInplace<B>
            + IDFTTmpA<B>,
        Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeVecZnxBig<B>,
    {
        unsafe {
            let self_ptr: *mut GGSWCiphertext<DataSelf> = self as *mut GGSWCiphertext<DataSelf>;
            self.keyswitch(module, &*self_ptr, ksk, tsk, scratch);
        }
    }

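    /// Reconstructs columns 1..=rank of every row from column 0 using the prepared tensor
    /// key: column 0 is brought into the DFT domain, multiplied digit by digit against the
    /// tensor-key matrices, and the accumulated result is normalized back into the row.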
    pub fn expand_row<DataTsk: DataRef, B: Backend>(
        &mut self,
        module: &Module<B>,
        tsk: &GGLWETensorKeyPrepared<DataTsk, B>,
        scratch: &mut Scratch<B>,
    ) where
        Module<B>: VecZnxDftAllocBytes
            + VmpApplyDftToDftTmpBytes
            + VecZnxBigAllocBytes
            + VecZnxNormalizeTmpBytes
            + DFT<B>
            + VecZnxDftCopy<B>
            + VmpApplyDftToDft<B>
            + VmpApplyDftToDftAdd<B>
            + VecZnxDftAddInplace<B>
            + VecZnxBigNormalize<B>
            + IDFTTmpA<B>,
        Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeVecZnxBig<B>,
    {
        assert!(
            scratch.available()
                >= GGSWCiphertext::expand_row_scratch_space(
                    module,
                    self.basek(),
                    self.k(),
                    tsk.k(),
                    tsk.digits(),
                    tsk.rank()
                )
        );

        let n: usize = self.n();
        let rank: usize = self.rank();
        let cols: usize = rank + 1;

        (0..self.rows()).for_each(|row_i| {
            // DFT of every column of the GLWE stored in column 0 of the current row.
            let (mut ci_dft, scratch1) = scratch.take_vec_znx_dft(n, cols, self.size());
            (0..cols).for_each(|i| {
                module.dft(1, 0, &mut ci_dft, i, &self.at(row_i, 0).data, i);
            });

            // Reconstruct each remaining column col_j from ci_dft and the tensor key.
            (1..cols).for_each(|col_j| {
                let digits: usize = tsk.digits();

                let (mut tmp_dft_i, scratch2) = scratch1.take_vec_znx_dft(n, cols, tsk.size());
                let (mut tmp_a, scratch3) = scratch2.take_vec_znx_dft(n, 1, ci_dft.size().div_ceil(digits));

                {
                    // Accumulate the digit-wise VMP products of columns 1..cols of ci_dft
                    // against the corresponding tensor-key matrices.
                    (1..cols).for_each(|col_i| {
                        let pmat: &VmpPMat<DataTsk, B> = &tsk.at(col_i - 1, col_j - 1).key.data;

                        (0..digits).for_each(|di| {
                            tmp_a.set_size((ci_dft.size() + di) / digits);

                            tmp_dft_i.set_size(tsk.size() - ((digits - di) as isize - 2).max(0) as usize);

                            module.vec_znx_dft_copy(digits, digits - 1 - di, &mut tmp_a, 0, &ci_dft, col_i);
                            if di == 0 && col_i == 1 {
                                module.vmp_apply_dft_to_dft(&mut tmp_dft_i, &tmp_a, pmat, scratch3);
                            } else {
                                module.vmp_apply_dft_to_dft_add(&mut tmp_dft_i, &tmp_a, pmat, di, scratch3);
                            }
                        });
                    });
                }

                // Add column 0 of ci_dft into column col_j of the accumulator.
                module.vec_znx_dft_add_inplace(&mut tmp_dft_i, col_j, &ci_dft, 0);

                // Back to the coefficient domain, then normalize into the output row.
                let (mut tmp_idft, scratch3) = scratch2.take_vec_znx_big(n, 1, tsk.size());
                (0..cols).for_each(|i| {
                    module.idft_tmp_a(&mut tmp_idft, 0, &mut tmp_dft_i, i);
                    module.vec_znx_big_normalize(
                        self.basek(),
                        &mut self.at_mut(row_i, col_j).data,
                        i,
                        &tmp_idft,
                        0,
                        scratch3,
                    );
                });
            })
        })
    }

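    /// Key-switches column 0 of every row of `lhs` into column 0 of the corresponding row
    /// of `self` with the GLWE switching key; the remaining columns are expected to be
    /// rebuilt afterwards with [`GGSWCiphertext::expand_row`].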
    fn keyswitch_internal<DataLhs: DataRef, DataKsk: DataRef, B: Backend>(
        &mut self,
        module: &Module<B>,
        lhs: &GGSWCiphertext<DataLhs>,
        ksk: &GGLWESwitchingKeyPrepared<DataKsk, B>,
        scratch: &mut Scratch<B>,
    ) where
        Module<B>: VecZnxDftAllocBytes
            + VmpApplyDftToDftTmpBytes
            + VecZnxBigNormalizeTmpBytes
            + VmpApplyDftToDft<B>
            + VmpApplyDftToDftAdd<B>
            + DFT<B>
            + IDFTConsume<B>
            + VecZnxBigAddSmallInplace<B>
            + VecZnxBigNormalize<B>,
        Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
    {
        (0..lhs.rows()).for_each(|row_i| {
            self.at_mut(row_i, 0)
                .keyswitch(module, &lhs.at(row_i, 0), ksk, scratch);
        })
    }
}