1use poulpy_hal::{
2 api::{ModuleN, ScratchAvailable, ScratchTakeBasic, SvpPPolBytesOf, VecZnxDftBytesOf, VmpPMatBytesOf},
3 layouts::{Backend, Scratch},
4};
5
6use crate::{
7 dist::Distribution,
8 layouts::{
9 Degree, GGLWE, GGLWEInfos, GGLWELayout, GGSW, GGSWInfos, GLWE, GLWEAutomorphismKey, GLWEInfos, GLWEPlaintext,
10 GLWEPrepared, GLWEPublicKey, GLWESecret, GLWESwitchingKey, GLWETensorKey, Rank,
11 prepared::{
12 GGLWEPrepared, GGSWPrepared, GLWEAutomorphismKeyPrepared, GLWEPublicKeyPrepared, GLWESecretPrepared,
13 GLWESwitchingKeyPrepared, GLWETensorKeyPrepared,
14 },
15 },
16};
17
18pub trait ScratchTakeCore<B: Backend>
19where
20 Self: ScratchTakeBasic + ScratchAvailable,
21{
22 fn take_glwe<A>(&mut self, infos: &A) -> (GLWE<&mut [u8]>, &mut Self)
23 where
24 A: GLWEInfos,
25 {
26 let (data, scratch) = self.take_vec_znx(infos.n().into(), (infos.rank() + 1).into(), infos.size());
27 (
28 GLWE {
29 k: infos.k(),
30 base2k: infos.base2k(),
31 data,
32 },
33 scratch,
34 )
35 }
36
37 fn take_glwe_slice<A>(&mut self, size: usize, infos: &A) -> (Vec<GLWE<&mut [u8]>>, &mut Self)
38 where
39 A: GLWEInfos,
40 {
41 let mut scratch: &mut Self = self;
42 let mut cts: Vec<GLWE<&mut [u8]>> = Vec::with_capacity(size);
43 for _ in 0..size {
44 let (ct, new_scratch) = scratch.take_glwe(infos);
45 scratch = new_scratch;
46 cts.push(ct);
47 }
48 (cts, scratch)
49 }
50
51 fn take_glwe_plaintext<A>(&mut self, infos: &A) -> (GLWEPlaintext<&mut [u8]>, &mut Self)
52 where
53 A: GLWEInfos,
54 {
55 let (data, scratch) = self.take_vec_znx(infos.n().into(), 1, infos.size());
56 (
57 GLWEPlaintext {
58 k: infos.k(),
59 base2k: infos.base2k(),
60 data,
61 },
62 scratch,
63 )
64 }
65
66 fn take_gglwe<A>(&mut self, infos: &A) -> (GGLWE<&mut [u8]>, &mut Self)
67 where
68 A: GGLWEInfos,
69 {
70 let (data, scratch) = self.take_mat_znx(
71 infos.n().into(),
72 infos.dnum().0.div_ceil(infos.dsize().0) as usize,
73 infos.rank_in().into(),
74 (infos.rank_out() + 1).into(),
75 infos.size(),
76 );
77 (
78 GGLWE {
79 k: infos.k(),
80 base2k: infos.base2k(),
81 dsize: infos.dsize(),
82 data,
83 },
84 scratch,
85 )
86 }
87
88 fn take_gglwe_prepared<A, M>(&mut self, module: &M, infos: &A) -> (GGLWEPrepared<&mut [u8], B>, &mut Self)
89 where
90 A: GGLWEInfos,
91 M: ModuleN + VmpPMatBytesOf,
92 {
93 assert_eq!(module.n() as u32, infos.n());
94 let (data, scratch) = self.take_vmp_pmat(
95 module,
96 infos.dnum().into(),
97 infos.rank_in().into(),
98 (infos.rank_out() + 1).into(),
99 infos.size(),
100 );
101 (
102 GGLWEPrepared {
103 k: infos.k(),
104 base2k: infos.base2k(),
105 dsize: infos.dsize(),
106 data,
107 },
108 scratch,
109 )
110 }
111
112 fn take_ggsw<A>(&mut self, infos: &A) -> (GGSW<&mut [u8]>, &mut Self)
113 where
114 A: GGSWInfos,
115 {
116 let (data, scratch) = self.take_mat_znx(
117 infos.n().into(),
118 infos.dnum().into(),
119 (infos.rank() + 1).into(),
120 (infos.rank() + 1).into(),
121 infos.size(),
122 );
123 (
124 GGSW {
125 k: infos.k(),
126 base2k: infos.base2k(),
127 dsize: infos.dsize(),
128 data,
129 },
130 scratch,
131 )
132 }
133
134 fn take_ggsw_prepared<A, M>(&mut self, module: &M, infos: &A) -> (GGSWPrepared<&mut [u8], B>, &mut Self)
135 where
136 A: GGSWInfos,
137 M: ModuleN + VmpPMatBytesOf,
138 {
139 assert_eq!(module.n() as u32, infos.n());
140 let (data, scratch) = self.take_vmp_pmat(
141 module,
142 infos.dnum().into(),
143 (infos.rank() + 1).into(),
144 (infos.rank() + 1).into(),
145 infos.size(),
146 );
147 (
148 GGSWPrepared {
149 k: infos.k(),
150 base2k: infos.base2k(),
151 dsize: infos.dsize(),
152 data,
153 },
154 scratch,
155 )
156 }
157
158 fn take_ggsw_prepared_slice<A, M>(
159 &mut self,
160 module: &M,
161 size: usize,
162 infos: &A,
163 ) -> (Vec<GGSWPrepared<&mut [u8], B>>, &mut Self)
164 where
165 A: GGSWInfos,
166 M: ModuleN + VmpPMatBytesOf,
167 {
168 let mut scratch: &mut Self = self;
169 let mut cts: Vec<GGSWPrepared<&mut [u8], B>> = Vec::with_capacity(size);
170 for _ in 0..size {
171 let (ct, new_scratch) = scratch.take_ggsw_prepared(module, infos);
172 scratch = new_scratch;
173 cts.push(ct)
174 }
175 (cts, scratch)
176 }
177
178 fn take_glwe_public_key<A>(&mut self, infos: &A) -> (GLWEPublicKey<&mut [u8]>, &mut Self)
179 where
180 A: GLWEInfos,
181 {
182 let (data, scratch) = self.take_glwe(infos);
183 (
184 GLWEPublicKey {
185 dist: Distribution::NONE,
186 key: data,
187 },
188 scratch,
189 )
190 }
191
192 fn take_glwe_public_key_prepared<A, M>(&mut self, module: &M, infos: &A) -> (GLWEPublicKeyPrepared<&mut [u8], B>, &mut Self)
193 where
194 A: GLWEInfos,
195 M: ModuleN + VecZnxDftBytesOf,
196 {
197 let (data, scratch) = self.take_glwe_prepared(module, infos);
198 (
199 GLWEPublicKeyPrepared {
200 dist: Distribution::NONE,
201 key: data,
202 },
203 scratch,
204 )
205 }
206
207 fn take_glwe_prepared<A, M>(&mut self, module: &M, infos: &A) -> (GLWEPrepared<&mut [u8], B>, &mut Self)
208 where
209 A: GLWEInfos,
210 M: ModuleN + VecZnxDftBytesOf,
211 {
212 assert_eq!(module.n() as u32, infos.n());
213 let (data, scratch) = self.take_vec_znx_dft(module, (infos.rank() + 1).into(), infos.size());
214 (
215 GLWEPrepared {
216 k: infos.k(),
217 base2k: infos.base2k(),
218 data,
219 },
220 scratch,
221 )
222 }
223
224 fn take_glwe_secret(&mut self, n: Degree, rank: Rank) -> (GLWESecret<&mut [u8]>, &mut Self) {
225 let (data, scratch) = self.take_scalar_znx(n.into(), rank.into());
226 (
227 GLWESecret {
228 data,
229 dist: Distribution::NONE,
230 },
231 scratch,
232 )
233 }
234
235 fn take_glwe_secret_prepared<M>(&mut self, module: &M, rank: Rank) -> (GLWESecretPrepared<&mut [u8], B>, &mut Self)
236 where
237 M: ModuleN + SvpPPolBytesOf,
238 {
239 let (data, scratch) = self.take_svp_ppol(module, rank.into());
240 (
241 GLWESecretPrepared {
242 data,
243 dist: Distribution::NONE,
244 },
245 scratch,
246 )
247 }
248
249 fn take_glwe_switching_key<A>(&mut self, infos: &A) -> (GLWESwitchingKey<&mut [u8]>, &mut Self)
250 where
251 A: GGLWEInfos,
252 {
253 let (data, scratch) = self.take_gglwe(infos);
254 (
255 GLWESwitchingKey {
256 key: data,
257 input_degree: Degree(0),
258 output_degree: Degree(0),
259 },
260 scratch,
261 )
262 }
263
264 fn take_glwe_switching_key_prepared<A, M>(
265 &mut self,
266 module: &M,
267 infos: &A,
268 ) -> (GLWESwitchingKeyPrepared<&mut [u8], B>, &mut Self)
269 where
270 A: GGLWEInfos,
271 M: ModuleN + VmpPMatBytesOf,
272 {
273 assert_eq!(module.n() as u32, infos.n());
274 let (data, scratch) = self.take_gglwe_prepared(module, infos);
275 (
276 GLWESwitchingKeyPrepared {
277 key: data,
278 input_degree: Degree(0),
279 output_degree: Degree(0),
280 },
281 scratch,
282 )
283 }
284
285 fn take_glwe_automorphism_key<A>(&mut self, infos: &A) -> (GLWEAutomorphismKey<&mut [u8]>, &mut Self)
286 where
287 A: GGLWEInfos,
288 {
289 let (data, scratch) = self.take_gglwe(infos);
290 (GLWEAutomorphismKey { key: data, p: 0 }, scratch)
291 }
292
293 fn take_glwe_automorphism_key_prepared<A, M>(
294 &mut self,
295 module: &M,
296 infos: &A,
297 ) -> (GLWEAutomorphismKeyPrepared<&mut [u8], B>, &mut Self)
298 where
299 A: GGLWEInfos,
300 M: ModuleN + VmpPMatBytesOf,
301 {
302 assert_eq!(module.n() as u32, infos.n());
303 let (data, scratch) = self.take_gglwe_prepared(module, infos);
304 (GLWEAutomorphismKeyPrepared { key: data, p: 0 }, scratch)
305 }
306
307 fn take_glwe_tensor_key<A, M>(&mut self, infos: &A) -> (GLWETensorKey<&mut [u8]>, &mut Self)
308 where
309 A: GGLWEInfos,
310 {
311 assert_eq!(
312 infos.rank_in(),
313 infos.rank_out(),
314 "rank_in != rank_out is not supported for GLWETensorKey"
315 );
316 let mut keys: Vec<GGLWE<&mut [u8]>> = Vec::new();
317 let pairs: usize = (((infos.rank_out().0 + 1) * infos.rank_out().0) >> 1).max(1) as usize;
318
319 let mut scratch: &mut Self = self;
320
321 let mut ksk_infos: GGLWELayout = infos.gglwe_layout();
322 ksk_infos.rank_in = Rank(1);
323
324 if pairs != 0 {
325 let (gglwe, s) = scratch.take_gglwe(&ksk_infos);
326 scratch = s;
327 keys.push(gglwe);
328 }
329 for _ in 1..pairs {
330 let (gglwe, s) = scratch.take_gglwe(&ksk_infos);
331 scratch = s;
332 keys.push(gglwe);
333 }
334 (GLWETensorKey { keys }, scratch)
335 }
336
337 fn take_glwe_tensor_key_prepared<A, M>(&mut self, module: &M, infos: &A) -> (GLWETensorKeyPrepared<&mut [u8], B>, &mut Self)
338 where
339 A: GGLWEInfos,
340 M: ModuleN + VmpPMatBytesOf,
341 {
342 assert_eq!(module.n() as u32, infos.n());
343 assert_eq!(
344 infos.rank_in(),
345 infos.rank_out(),
346 "rank_in != rank_out is not supported for GGLWETensorKeyPrepared"
347 );
348
349 let mut keys: Vec<GGLWEPrepared<&mut [u8], B>> = Vec::new();
350 let pairs: usize = (((infos.rank_out().0 + 1) * infos.rank_out().0) >> 1).max(1) as usize;
351
352 let mut scratch: &mut Self = self;
353
354 let mut ksk_infos: GGLWELayout = infos.gglwe_layout();
355 ksk_infos.rank_in = Rank(1);
356
357 if pairs != 0 {
358 let (gglwe, s) = scratch.take_gglwe_prepared(module, &ksk_infos);
359 scratch = s;
360 keys.push(gglwe);
361 }
362 for _ in 1..pairs {
363 let (gglwe, s) = scratch.take_gglwe_prepared(module, &ksk_infos);
364 scratch = s;
365 keys.push(gglwe);
366 }
367 (GLWETensorKeyPrepared { keys }, scratch)
368 }
369}
370
371impl<B: Backend> ScratchTakeCore<B> for Scratch<B> where Self: ScratchTakeBasic + ScratchAvailable {}