1use poulpy_hal::{
2 api::{
3 DataViewMut, ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
4 VecZnxDftAllocBytes, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigConsume, VmpApply, VmpApplyAdd, VmpApplyTmpBytes, ZnxInfos,
5 },
6 layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnx, VecZnxBig, VecZnxDft, VmpPMat},
7};
8
9use crate::layouts::{GLWECiphertext, Infos, prepared::GGLWESwitchingKeyPrepared};
10
11impl GLWECiphertext<Vec<u8>> {
12 #[allow(clippy::too_many_arguments)]
13 pub fn keyswitch_scratch_space<B: Backend>(
14 module: &Module<B>,
15 n: usize,
16 basek: usize,
17 k_out: usize,
18 k_in: usize,
19 k_ksk: usize,
20 digits: usize,
21 rank_in: usize,
22 rank_out: usize,
23 ) -> usize
24 where
25 Module<B>: VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxBigNormalizeTmpBytes,
26 {
27 let in_size: usize = k_in.div_ceil(basek).div_ceil(digits);
28 let out_size: usize = k_out.div_ceil(basek);
29 let ksk_size: usize = k_ksk.div_ceil(basek);
30 let res_dft: usize = module.vec_znx_dft_alloc_bytes(n, rank_out + 1, ksk_size); let ai_dft: usize = module.vec_znx_dft_alloc_bytes(n, rank_in, in_size);
32 let vmp: usize = module.vmp_apply_tmp_bytes(
33 n,
34 out_size,
35 in_size,
36 in_size,
37 rank_in,
38 rank_out + 1,
39 ksk_size,
40 ) + module.vec_znx_dft_alloc_bytes(n, rank_in, in_size);
41 let normalize: usize = module.vec_znx_big_normalize_tmp_bytes(n);
42 res_dft + ((ai_dft + vmp) | normalize)
43 }
44
45 pub fn keyswitch_inplace_scratch_space<B: Backend>(
46 module: &Module<B>,
47 n: usize,
48 basek: usize,
49 k_out: usize,
50 k_ksk: usize,
51 digits: usize,
52 rank: usize,
53 ) -> usize
54 where
55 Module<B>: VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxBigNormalizeTmpBytes,
56 {
57 Self::keyswitch_scratch_space(module, n, basek, k_out, k_out, k_ksk, digits, rank, rank)
58 }
59}
60
61impl<DataSelf: DataRef> GLWECiphertext<DataSelf> {
62 #[allow(dead_code)]
63 pub(crate) fn assert_keyswitch<B: Backend, DataLhs, DataRhs>(
64 &self,
65 module: &Module<B>,
66 lhs: &GLWECiphertext<DataLhs>,
67 rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
68 scratch: &Scratch<B>,
69 ) where
70 DataLhs: DataRef,
71 DataRhs: DataRef,
72 Module<B>: VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxBigNormalizeTmpBytes,
73 Scratch<B>: ScratchAvailable,
74 {
75 let basek: usize = self.basek();
76 assert_eq!(
77 lhs.rank(),
78 rhs.rank_in(),
79 "lhs.rank(): {} != rhs.rank_in(): {}",
80 lhs.rank(),
81 rhs.rank_in()
82 );
83 assert_eq!(
84 self.rank(),
85 rhs.rank_out(),
86 "self.rank(): {} != rhs.rank_out(): {}",
87 self.rank(),
88 rhs.rank_out()
89 );
90 assert_eq!(self.basek(), basek);
91 assert_eq!(lhs.basek(), basek);
92 assert_eq!(rhs.n(), self.n());
93 assert_eq!(lhs.n(), self.n());
94 assert!(
95 scratch.available()
96 >= GLWECiphertext::keyswitch_scratch_space(
97 module,
98 self.n(),
99 self.basek(),
100 self.k(),
101 lhs.k(),
102 rhs.k(),
103 rhs.digits(),
104 rhs.rank_in(),
105 rhs.rank_out(),
106 ),
107 "scratch.available()={} < GLWECiphertext::keyswitch_scratch_space(
108 module,
109 self.basek(),
110 self.k(),
111 lhs.k(),
112 rhs.k(),
113 rhs.digits(),
114 rhs.rank_in(),
115 rhs.rank_out(),
116 )={}",
117 scratch.available(),
118 GLWECiphertext::keyswitch_scratch_space(
119 module,
120 self.n(),
121 self.basek(),
122 self.k(),
123 lhs.k(),
124 rhs.k(),
125 rhs.digits(),
126 rhs.rank_in(),
127 rhs.rank_out(),
128 )
129 );
130 }
131}
132
133impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
134 pub fn keyswitch<DataLhs: DataRef, DataRhs: DataRef, B: Backend>(
135 &mut self,
136 module: &Module<B>,
137 lhs: &GLWECiphertext<DataLhs>,
138 rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
139 scratch: &mut Scratch<B>,
140 ) where
141 Module<B>: VecZnxDftAllocBytes
142 + VmpApplyTmpBytes
143 + VecZnxBigNormalizeTmpBytes
144 + VmpApplyTmpBytes
145 + VmpApply<B>
146 + VmpApplyAdd<B>
147 + VecZnxDftFromVecZnx<B>
148 + VecZnxDftToVecZnxBigConsume<B>
149 + VecZnxBigAddSmallInplace<B>
150 + VecZnxBigNormalize<B>,
151 Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
152 {
153 #[cfg(debug_assertions)]
154 {
155 self.assert_keyswitch(module, lhs, rhs, scratch);
156 }
157 let (res_dft, scratch1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); let res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, rhs, scratch1);
159 (0..self.cols()).for_each(|i| {
160 module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch1);
161 })
162 }
163
164 pub fn keyswitch_inplace<DataRhs: DataRef, B: Backend>(
165 &mut self,
166 module: &Module<B>,
167 rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
168 scratch: &mut Scratch<B>,
169 ) where
170 Module<B>: VecZnxDftAllocBytes
171 + VmpApplyTmpBytes
172 + VecZnxBigNormalizeTmpBytes
173 + VmpApplyTmpBytes
174 + VmpApply<B>
175 + VmpApplyAdd<B>
176 + VecZnxDftFromVecZnx<B>
177 + VecZnxDftToVecZnxBigConsume<B>
178 + VecZnxBigAddSmallInplace<B>
179 + VecZnxBigNormalize<B>,
180 Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
181 {
182 unsafe {
183 let self_ptr: *mut GLWECiphertext<DataSelf> = self as *mut GLWECiphertext<DataSelf>;
184 self.keyswitch(module, &*self_ptr, rhs, scratch);
185 }
186 }
187}
188
189impl<D: DataRef> GLWECiphertext<D> {
190 pub(crate) fn keyswitch_internal<B: Backend, DataRes, DataKey>(
191 &self,
192 module: &Module<B>,
193 res_dft: VecZnxDft<DataRes, B>,
194 rhs: &GGLWESwitchingKeyPrepared<DataKey, B>,
195 scratch: &mut Scratch<B>,
196 ) -> VecZnxBig<DataRes, B>
197 where
198 DataRes: DataMut,
199 DataKey: DataRef,
200 Module<B>: VecZnxDftAllocBytes
201 + VmpApplyTmpBytes
202 + VecZnxBigNormalizeTmpBytes
203 + VmpApplyTmpBytes
204 + VmpApply<B>
205 + VmpApplyAdd<B>
206 + VecZnxDftFromVecZnx<B>
207 + VecZnxDftToVecZnxBigConsume<B>
208 + VecZnxBigAddSmallInplace<B>
209 + VecZnxBigNormalize<B>,
210 Scratch<B>: TakeVecZnxDft<B>,
211 {
212 if rhs.digits() == 1 {
213 return keyswitch_vmp_one_digit(module, res_dft, &self.data, &rhs.key.data, scratch);
214 }
215
216 keyswitch_vmp_multiple_digits(
217 module,
218 res_dft,
219 &self.data,
220 &rhs.key.data,
221 rhs.digits(),
222 scratch,
223 )
224 }
225}
226
227fn keyswitch_vmp_one_digit<B: Backend, DataRes, DataIn, DataVmp>(
228 module: &Module<B>,
229 mut res_dft: VecZnxDft<DataRes, B>,
230 a: &VecZnx<DataIn>,
231 mat: &VmpPMat<DataVmp, B>,
232 scratch: &mut Scratch<B>,
233) -> VecZnxBig<DataRes, B>
234where
235 DataRes: DataMut,
236 DataIn: DataRef,
237 DataVmp: DataRef,
238 Module<B>:
239 VecZnxDftAllocBytes + VecZnxDftFromVecZnx<B> + VmpApply<B> + VecZnxDftToVecZnxBigConsume<B> + VecZnxBigAddSmallInplace<B>,
240 Scratch<B>: TakeVecZnxDft<B>,
241{
242 let cols: usize = a.cols();
243 let (mut ai_dft, scratch1) = scratch.take_vec_znx_dft(a.n(), cols - 1, a.size());
244 (0..cols - 1).for_each(|col_i| {
245 module.vec_znx_dft_from_vec_znx(1, 0, &mut ai_dft, col_i, a, col_i + 1);
246 });
247 module.vmp_apply(&mut res_dft, &ai_dft, mat, scratch1);
248 let mut res_big: VecZnxBig<DataRes, B> = module.vec_znx_dft_to_vec_znx_big_consume(res_dft);
249 module.vec_znx_big_add_small_inplace(&mut res_big, 0, a, 0);
250 res_big
251}
252
/// Multi-digit key-switching product, used by `keyswitch_internal` when
/// `rhs.digits() != 1`.
///
/// Columns `1..a.cols()` of `a` are converted to the DFT domain one digit at a time and
/// multiplied by `mat`. Digit `di == 0` uses `vmp_apply`. Every later digit uses
/// `vmp_apply_add` with `di` passed as an extra argument. After the loop, `res_dft` is restored
/// to its full size and moved to the big domain (consuming it), and column 0 of `a` is added
/// into its column 0. The returned value is not normalized.
///
/// Scratch: takes one `VecZnxDft` of `a.cols() - 1` columns and `a.size().div_ceil(digits)`
/// limbs, then hands the remainder to the VMP calls.
fn keyswitch_vmp_multiple_digits<B: Backend, DataRes, DataIn, DataVmp>(
    module: &Module<B>,
    mut res_dft: VecZnxDft<DataRes, B>,
    a: &VecZnx<DataIn>,
    mat: &VmpPMat<DataVmp, B>,
    digits: usize,
    scratch: &mut Scratch<B>,
) -> VecZnxBig<DataRes, B>
where
    DataRes: DataMut,
    DataIn: DataRef,
    DataVmp: DataRef,
    Module<B>: VecZnxDftAllocBytes
        + VecZnxDftFromVecZnx<B>
        + VmpApply<B>
        + VmpApplyAdd<B>
        + VecZnxDftToVecZnxBigConsume<B>
        + VecZnxBigAddSmallInplace<B>,
    Scratch<B>: TakeVecZnxDft<B>,
{
    let cols: usize = a.cols();
    let size: usize = a.size();
    // Sized for the largest per-digit view: `(size + di) / digits` (set below) peaks at
    // `size.div_ceil(digits)` when `di == digits - 1`, so the view never exceeds this.
    let (mut ai_dft, scratch1) = scratch.take_vec_znx_dft(a.n(), cols - 1, size.div_ceil(digits));

    // NOTE(review): presumably zeroed because the view is resized on each digit and
    // leftover scratch bytes must not be read as data — confirm before removing.
    ai_dft.data_mut().fill(0);

    (0..digits).for_each(|di| {
        // `(size + di) / digits` equals the number of limb indices `j < size` with
        // `j % digits == digits - di - 1`, matching the offset passed to
        // `vec_znx_dft_from_vec_znx` below.
        ai_dft.set_size((size + di) / digits);

        // Trim `res_dft` to `mat.size() - max(digits - di - 2, 0)` limbs for this digit.
        // Earlier digits therefore use fewer limbs, and the last two digits use the full
        // `mat.size()`. NOTE(review): presumably this skips low limbs whose contribution is
        // dominated by the error of the later products; the rationale is not visible here,
        // so confirm before changing.
        res_dft.set_size(mat.size() - ((digits - di) as isize - 2).max(0) as usize);

        // Column `col_i + 1` of `a` feeds column `col_i` of `ai_dft`. NOTE(review): reading
        // the first two arguments as (step = digits, offset = digits - di - 1) is an
        // assumption — confirm against the poulpy_hal `VecZnxDftFromVecZnx` docs.
        (0..cols - 1).for_each(|col_i| {
            module.vec_znx_dft_from_vec_znx(digits, digits - di - 1, &mut ai_dft, col_i, a, col_i + 1);
        });

        if di == 0 {
            module.vmp_apply(&mut res_dft, &ai_dft, mat, scratch1);
        } else {
            module.vmp_apply_add(&mut res_dft, &ai_dft, mat, di, scratch1);
        }
    });

    // Undo the per-digit trimming before leaving the DFT domain.
    res_dft.set_size(res_dft.max_size());
    let mut res_big: VecZnxBig<DataRes, B> = module.vec_znx_dft_to_vec_znx_big_consume(res_dft);
    module.vec_znx_big_add_small_inplace(&mut res_big, 0, a, 0);
    res_big
}