use poulpy_hal::{
    api::{
        ScratchAvailable, TakeVecZnxBig, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize,
        VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAddInplace, VecZnxDftAllocBytes, VecZnxDftCopy, VecZnxDftFromVecZnx,
        VecZnxDftToVecZnxBigConsume, VecZnxDftToVecZnxBigTmpA, VecZnxNormalizeTmpBytes, VmpApply, VmpApplyAdd, VmpApplyTmpBytes,
        ZnxInfos,
    },
    layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnx, VmpPMat},
};

use crate::{
    layouts::{
        GGLWECiphertext, GGSWCiphertext, GLWECiphertext, Infos,
        prepared::{GGLWESwitchingKeyPrepared, GGLWETensorKeyPrepared},
    },
    operations::GLWEOperations,
};

impl GGSWCiphertext<Vec<u8>> {
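    /// Returns the scratch space (in bytes) required by [`GGSWCiphertext::expand_row`]
    /// for a ciphertext of `rank + 1` columns and precision `self_k`, expanded with a
    /// tensor key of precision `k_tsk` and decomposition `digits`. Limb counts are
    /// derived as `k.div_ceil(basek)`.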
    pub(crate) fn expand_row_scratch_space<B: Backend>(
        module: &Module<B>,
        n: usize,
        basek: usize,
        self_k: usize,
        k_tsk: usize,
        digits: usize,
        rank: usize,
    ) -> usize
    where
        Module<B>: VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxBigAllocBytes + VecZnxNormalizeTmpBytes,
    {
        let tsk_size: usize = k_tsk.div_ceil(basek);
        let self_size_out: usize = self_k.div_ceil(basek);
        let self_size_in: usize = self_size_out.div_ceil(digits);

        let tmp_dft_i: usize = module.vec_znx_dft_alloc_bytes(n, rank + 1, tsk_size);
        let tmp_a: usize = module.vec_znx_dft_alloc_bytes(n, 1, self_size_in);
        let vmp: usize = module.vmp_apply_tmp_bytes(
            n,
            self_size_out,
            self_size_in,
            self_size_in,
            rank,
            rank,
            tsk_size,
        );
        let tmp_idft: usize = module.vec_znx_big_alloc_bytes(n, 1, tsk_size);
        let norm: usize = module.vec_znx_normalize_tmp_bytes(n);
        // `tmp_a`/`vmp` and `tmp_idft`/`norm` are never live at the same time, so the
        // bitwise OR (an upper bound on `max` for non-negative sizes) covers whichever
        // of the two regions is larger.
        tmp_dft_i + ((tmp_a + vmp) | (tmp_idft + norm))
    }

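    /// Returns the scratch space (in bytes) required by [`GGSWCiphertext::keyswitch`]:
    /// the persistent row buffers plus the larger of the per-row GLWE key-switch and
    /// the tensor-key row expansion.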
    #[allow(clippy::too_many_arguments)]
    pub fn keyswitch_scratch_space<B: Backend>(
        module: &Module<B>,
        n: usize,
        basek: usize,
        k_out: usize,
        k_in: usize,
        k_ksk: usize,
        digits_ksk: usize,
        k_tsk: usize,
        digits_tsk: usize,
        rank: usize,
    ) -> usize
    where
        Module<B>:
            VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxBigAllocBytes + VecZnxNormalizeTmpBytes + VecZnxBigNormalizeTmpBytes,
    {
        let out_size: usize = k_out.div_ceil(basek);
        let res_znx: usize = VecZnx::alloc_bytes(n, rank + 1, out_size);
        let ci_dft: usize = module.vec_znx_dft_alloc_bytes(n, rank + 1, out_size);
        let ks: usize = GLWECiphertext::keyswitch_scratch_space(module, n, basek, k_out, k_in, k_ksk, digits_ksk, rank, rank);
        let expand_rows: usize = GGSWCiphertext::expand_row_scratch_space(module, n, basek, k_out, k_tsk, digits_tsk, rank);
        let res_dft: usize = module.vec_znx_dft_alloc_bytes(n, rank + 1, out_size);
        // `ks`, `expand_rows` and `res_dft` are not needed simultaneously; their
        // bitwise OR upper-bounds the maximum of the three.
        res_znx + ci_dft + (ks | expand_rows | res_dft)
    }

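    /// Returns the scratch space (in bytes) required by
    /// [`GGSWCiphertext::keyswitch_inplace`]; identical to
    /// [`Self::keyswitch_scratch_space`] with `k_in = k_out`.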
    #[allow(clippy::too_many_arguments)]
    pub fn keyswitch_inplace_scratch_space<B: Backend>(
        module: &Module<B>,
        n: usize,
        basek: usize,
        k_out: usize,
        k_ksk: usize,
        digits_ksk: usize,
        k_tsk: usize,
        digits_tsk: usize,
        rank: usize,
    ) -> usize
    where
        Module<B>:
            VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxBigAllocBytes + VecZnxNormalizeTmpBytes + VecZnxBigNormalizeTmpBytes,
    {
        GGSWCiphertext::keyswitch_scratch_space(
            module, n, basek, k_out, k_out, k_ksk, digits_ksk, k_tsk, digits_tsk, rank,
        )
    }
}

impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
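    /// Populates this GGSW ciphertext from the GGLWE ciphertext `a`: the first (GLWE)
    /// column of each row is copied verbatim, and the remaining `rank` columns are
    /// reconstructed with the tensor key `tsk` via [`Self::expand_row`].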
    pub fn from_gglwe<DataA, DataTsk, B: Backend>(
        &mut self,
        module: &Module<B>,
        a: &GGLWECiphertext<DataA>,
        tsk: &GGLWETensorKeyPrepared<DataTsk, B>,
        scratch: &mut Scratch<B>,
    ) where
        DataA: DataRef,
        DataTsk: DataRef,
        Module<B>: VecZnxCopy
            + VecZnxDftAllocBytes
            + VmpApplyTmpBytes
            + VecZnxBigAllocBytes
            + VecZnxNormalizeTmpBytes
            + VecZnxDftFromVecZnx<B>
            + VecZnxDftCopy<B>
            + VmpApply<B>
            + VmpApplyAdd<B>
            + VecZnxDftAddInplace<B>
            + VecZnxBigNormalize<B>
            + VecZnxDftToVecZnxBigTmpA<B>,
        Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeVecZnxBig<B>,
    {
        #[cfg(debug_assertions)]
        {
            assert_eq!(self.rank(), a.rank());
            assert_eq!(self.rows(), a.rows());
            assert_eq!(self.n(), module.n());
            assert_eq!(a.n(), module.n());
            assert_eq!(tsk.n(), module.n());
        }
        // Copy the GLWE column of each row; the remaining columns are rebuilt below.
        (0..self.rows()).for_each(|row_i| {
            self.at_mut(row_i, 0).copy(module, &a.at(row_i, 0));
        });
        self.expand_row(module, tsk, scratch);
    }

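    /// Key-switches the GGSW ciphertext `lhs` into `self`: the first column of each
    /// row is key-switched with `ksk`, after which the remaining columns are rebuilt
    /// under the new key using the tensor key `tsk`.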
    pub fn keyswitch<DataLhs: DataRef, DataKsk: DataRef, DataTsk: DataRef, B: Backend>(
        &mut self,
        module: &Module<B>,
        lhs: &GGSWCiphertext<DataLhs>,
        ksk: &GGLWESwitchingKeyPrepared<DataKsk, B>,
        tsk: &GGLWETensorKeyPrepared<DataTsk, B>,
        scratch: &mut Scratch<B>,
    ) where
        Module<B>: VecZnxDftAllocBytes
            + VmpApplyTmpBytes
            + VecZnxBigNormalizeTmpBytes
            + VmpApply<B>
            + VmpApplyAdd<B>
            + VecZnxDftFromVecZnx<B>
            + VecZnxDftToVecZnxBigConsume<B>
            + VecZnxBigAddSmallInplace<B>
            + VecZnxBigNormalize<B>
            + VecZnxDftAllocBytes
            + VecZnxBigAllocBytes
            + VecZnxNormalizeTmpBytes
            + VecZnxDftCopy<B>
            + VecZnxDftAddInplace<B>
            + VecZnxDftToVecZnxBigTmpA<B>,
        Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeVecZnxBig<B>,
    {
        self.keyswitch_internal(module, lhs, ksk, scratch);
        self.expand_row(module, tsk, scratch);
    }

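    /// In-place variant of [`Self::keyswitch`], with `self` acting as both source and
    /// destination.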
    pub fn keyswitch_inplace<DataKsk: DataRef, DataTsk: DataRef, B: Backend>(
        &mut self,
        module: &Module<B>,
        ksk: &GGLWESwitchingKeyPrepared<DataKsk, B>,
        tsk: &GGLWETensorKeyPrepared<DataTsk, B>,
        scratch: &mut Scratch<B>,
    ) where
        Module<B>: VecZnxDftAllocBytes
            + VmpApplyTmpBytes
            + VecZnxBigNormalizeTmpBytes
            + VmpApply<B>
            + VmpApplyAdd<B>
            + VecZnxDftFromVecZnx<B>
            + VecZnxDftToVecZnxBigConsume<B>
            + VecZnxBigAddSmallInplace<B>
            + VecZnxBigNormalize<B>
            + VecZnxDftAllocBytes
            + VecZnxBigAllocBytes
            + VecZnxNormalizeTmpBytes
            + VecZnxDftCopy<B>
            + VecZnxDftAddInplace<B>
            + VecZnxDftToVecZnxBigTmpA<B>,
        Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeVecZnxBig<B>,
    {
        // SAFETY: `self` is aliased as both source and destination. This relies on the
        // per-row key-switch staging its input in scratch buffers, so each row is fully
        // read before the corresponding output row is written.
        unsafe {
            let self_ptr: *mut GGSWCiphertext<DataSelf> = self as *mut GGSWCiphertext<DataSelf>;
            self.keyswitch(module, &*self_ptr, ksk, tsk, scratch);
        }
    }

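    /// Rebuilds columns `1..=rank` of every row from the row's first (GLWE) column.
    /// The column-0 polynomials are lifted to the DFT domain, multiplied digit by
    /// digit against the prepared tensor key, and the accumulated products are
    /// normalized back into `self`.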
    pub fn expand_row<DataTsk: DataRef, B: Backend>(
        &mut self,
        module: &Module<B>,
        tsk: &GGLWETensorKeyPrepared<DataTsk, B>,
        scratch: &mut Scratch<B>,
    ) where
        Module<B>: VecZnxDftAllocBytes
            + VmpApplyTmpBytes
            + VecZnxBigAllocBytes
            + VecZnxNormalizeTmpBytes
            + VecZnxDftFromVecZnx<B>
            + VecZnxDftCopy<B>
            + VmpApply<B>
            + VmpApplyAdd<B>
            + VecZnxDftAddInplace<B>
            + VecZnxBigNormalize<B>
            + VecZnxDftToVecZnxBigTmpA<B>,
        Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeVecZnxBig<B>,
    {
        assert!(
            scratch.available()
                >= GGSWCiphertext::expand_row_scratch_space(
                    module,
                    self.n(),
                    self.basek(),
                    self.k(),
                    tsk.k(),
                    tsk.digits(),
                    tsk.rank()
                )
        );

        let n: usize = self.n();
        let rank: usize = self.rank();
        let cols: usize = rank + 1;

        (0..self.rows()).for_each(|row_i| {
            // Lift every column of the row's first GLWE ciphertext into the DFT domain.
            let (mut ci_dft, scratch1) = scratch.take_vec_znx_dft(n, cols, self.size());
            (0..cols).for_each(|i| {
                module.vec_znx_dft_from_vec_znx(1, 0, &mut ci_dft, i, &self.at(row_i, 0).data, i);
            });

            (1..cols).for_each(|col_j| {
                let digits: usize = tsk.digits();

                let (mut tmp_dft_i, scratch2) = scratch1.take_vec_znx_dft(n, cols, tsk.size());
                let (mut tmp_a, scratch3) = scratch2.take_vec_znx_dft(n, 1, ci_dft.size().div_ceil(digits));

                {
                    (1..cols).for_each(|col_i| {
                        let pmat: &VmpPMat<DataTsk, B> = &tsk.at(col_i - 1, col_j - 1).key.data;
                        (0..digits).for_each(|di| {
                            // Adjust the working sizes to the limbs touched by digit `di`.
                            tmp_a.set_size((ci_dft.size() + di) / digits);
                            tmp_dft_i.set_size(tsk.size() - ((digits - di) as isize - 2).max(0) as usize);

                            // Gather the limbs of column `col_i` belonging to digit `di`, then
                            // multiply them with the matching tensor-key matrix; the first
                            // product overwrites `tmp_dft_i`, the rest accumulate into it.
                            module.vec_znx_dft_copy(digits, digits - 1 - di, &mut tmp_a, 0, &ci_dft, col_i);
                            if di == 0 && col_i == 1 {
                                module.vmp_apply(&mut tmp_dft_i, &tmp_a, pmat, scratch3);
                            } else {
                                module.vmp_apply_add(&mut tmp_dft_i, &tmp_a, pmat, di, scratch3);
                            }
                        });
                    });
                }

                // Add the column-0 contribution, then bring each column back to the
                // coefficient domain and normalize it into the output row.
                module.vec_znx_dft_add_inplace(&mut tmp_dft_i, col_j, &ci_dft, 0);
                let (mut tmp_idft, scratch3) = scratch2.take_vec_znx_big(n, 1, tsk.size());
                (0..cols).for_each(|i| {
                    module.vec_znx_dft_to_vec_znx_big_tmp_a(&mut tmp_idft, 0, &mut tmp_dft_i, i);
                    module.vec_znx_big_normalize(
                        self.basek(),
                        &mut self.at_mut(row_i, col_j).data,
                        i,
                        &tmp_idft,
                        0,
                        scratch3,
                    );
                });
            })
        })
    }

331 fn keyswitch_internal<DataLhs: DataRef, DataKsk: DataRef, B: Backend>(
332 &mut self,
333 module: &Module<B>,
334 lhs: &GGSWCiphertext<DataLhs>,
335 ksk: &GGLWESwitchingKeyPrepared<DataKsk, B>,
336 scratch: &mut Scratch<B>,
337 ) where
338 Module<B>: VecZnxDftAllocBytes
339 + VmpApplyTmpBytes
340 + VecZnxBigNormalizeTmpBytes
341 + VmpApply<B>
342 + VmpApplyAdd<B>
343 + VecZnxDftFromVecZnx<B>
344 + VecZnxDftToVecZnxBigConsume<B>
345 + VecZnxBigAddSmallInplace<B>
346 + VecZnxBigNormalize<B>,
347 Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
348 {
349 (0..lhs.rows()).for_each(|row_i| {
351 self.at_mut(row_i, 0)
354 .keyswitch(module, &lhs.at(row_i, 0), ksk, scratch);
355 })
356 }
357}