1use poulpy_hal::{
2 api::{ModuleN, ScratchAvailable, ScratchFromBytes, ScratchTakeBasic, SvpPPolBytesOf, VecZnxDftBytesOf, VmpPMatBytesOf},
3 layouts::{Backend, Scratch},
4};
5
6use crate::{
7 dist::Distribution,
8 layouts::{
9 Degree, GGLWE, GGLWEInfos, GGLWELayout, GGSW, GGSWInfos, GLWE, GLWEAutomorphismKey, GLWEInfos, GLWEPlaintext,
10 GLWEPrepared, GLWEPublicKey, GLWESecret, GLWESecretTensor, GLWESwitchingKey, GLWETensorKey, LWE, LWEInfos, Rank,
11 prepared::{
12 GGLWEPrepared, GGSWPrepared, GLWEAutomorphismKeyPrepared, GLWEPublicKeyPrepared, GLWESecretPrepared,
13 GLWESwitchingKeyPrepared, GLWETensorKeyPrepared,
14 },
15 },
16};
17
18pub trait ScratchTakeCore<B: Backend>
19where
20 Self: ScratchTakeBasic + ScratchAvailable + ScratchFromBytes<B>,
21{
22 fn take_lwe<A>(&mut self, infos: &A) -> (LWE<&mut [u8]>, &mut Self)
23 where
24 A: LWEInfos,
25 {
26 let (data, scratch) = self.take_vec_znx(infos.n().into(), 1, infos.size());
27 (
28 LWE {
29 k: infos.k(),
30 base2k: infos.base2k(),
31 data,
32 },
33 scratch,
34 )
35 }
36
37 fn take_glwe<A>(&mut self, infos: &A) -> (GLWE<&mut [u8]>, &mut Self)
38 where
39 A: GLWEInfos,
40 {
41 let (data, scratch) = self.take_vec_znx(infos.n().into(), (infos.rank() + 1).into(), infos.size());
42 (
43 GLWE {
44 k: infos.k(),
45 base2k: infos.base2k(),
46 data,
47 },
48 scratch,
49 )
50 }
51
52 fn take_glwe_slice<A>(&mut self, size: usize, infos: &A) -> (Vec<GLWE<&mut [u8]>>, &mut Self)
53 where
54 A: GLWEInfos,
55 {
56 let mut scratch: &mut Self = self;
57 let mut cts: Vec<GLWE<&mut [u8]>> = Vec::with_capacity(size);
58 for _ in 0..size {
59 let (ct, new_scratch) = scratch.take_glwe(infos);
60 scratch = new_scratch;
61 cts.push(ct);
62 }
63 (cts, scratch)
64 }
65
66 fn take_glwe_plaintext<A>(&mut self, infos: &A) -> (GLWEPlaintext<&mut [u8]>, &mut Self)
67 where
68 A: GLWEInfos,
69 {
70 let (data, scratch) = self.take_vec_znx(infos.n().into(), 1, infos.size());
71 (
72 GLWEPlaintext {
73 k: infos.k(),
74 base2k: infos.base2k(),
75 data,
76 },
77 scratch,
78 )
79 }
80
81 fn take_gglwe<A>(&mut self, infos: &A) -> (GGLWE<&mut [u8]>, &mut Self)
82 where
83 A: GGLWEInfos,
84 {
85 let (data, scratch) = self.take_mat_znx(
86 infos.n().into(),
87 infos.dnum().0.div_ceil(infos.dsize().0) as usize,
88 infos.rank_in().into(),
89 (infos.rank_out() + 1).into(),
90 infos.size(),
91 );
92 (
93 GGLWE {
94 k: infos.k(),
95 base2k: infos.base2k(),
96 dsize: infos.dsize(),
97 data,
98 },
99 scratch,
100 )
101 }
102
103 fn take_gglwe_prepared<A, M>(&mut self, module: &M, infos: &A) -> (GGLWEPrepared<&mut [u8], B>, &mut Self)
104 where
105 A: GGLWEInfos,
106 M: ModuleN + VmpPMatBytesOf,
107 {
108 assert_eq!(module.n() as u32, infos.n());
109 let (data, scratch) = self.take_vmp_pmat(
110 module,
111 infos.dnum().into(),
112 infos.rank_in().into(),
113 (infos.rank_out() + 1).into(),
114 infos.size(),
115 );
116 (
117 GGLWEPrepared {
118 k: infos.k(),
119 base2k: infos.base2k(),
120 dsize: infos.dsize(),
121 data,
122 },
123 scratch,
124 )
125 }
126
127 fn take_ggsw<A>(&mut self, infos: &A) -> (GGSW<&mut [u8]>, &mut Self)
128 where
129 A: GGSWInfos,
130 {
131 let (data, scratch) = self.take_mat_znx(
132 infos.n().into(),
133 infos.dnum().into(),
134 (infos.rank() + 1).into(),
135 (infos.rank() + 1).into(),
136 infos.size(),
137 );
138 (
139 GGSW {
140 k: infos.k(),
141 base2k: infos.base2k(),
142 dsize: infos.dsize(),
143 data,
144 },
145 scratch,
146 )
147 }
148
149 fn take_ggsw_prepared<A, M>(&mut self, module: &M, infos: &A) -> (GGSWPrepared<&mut [u8], B>, &mut Self)
150 where
151 A: GGSWInfos,
152 M: ModuleN + VmpPMatBytesOf,
153 {
154 assert_eq!(module.n() as u32, infos.n());
155 let (data, scratch) = self.take_vmp_pmat(
156 module,
157 infos.dnum().into(),
158 (infos.rank() + 1).into(),
159 (infos.rank() + 1).into(),
160 infos.size(),
161 );
162 (
163 GGSWPrepared {
164 k: infos.k(),
165 base2k: infos.base2k(),
166 dsize: infos.dsize(),
167 data,
168 },
169 scratch,
170 )
171 }
172
173 fn take_ggsw_slice<A>(&mut self, size: usize, infos: &A) -> (Vec<GGSW<&mut [u8]>>, &mut Self)
174 where
175 A: GGSWInfos,
176 {
177 let mut scratch: &mut Self = self;
178 let mut cts: Vec<GGSW<&mut [u8]>> = Vec::with_capacity(size);
179 for _ in 0..size {
180 let (ct, new_scratch) = scratch.take_ggsw(infos);
181 scratch = new_scratch;
182 cts.push(ct)
183 }
184 (cts, scratch)
185 }
186
187 fn take_ggsw_prepared_slice<A, M>(
188 &mut self,
189 module: &M,
190 size: usize,
191 infos: &A,
192 ) -> (Vec<GGSWPrepared<&mut [u8], B>>, &mut Self)
193 where
194 A: GGSWInfos,
195 M: ModuleN + VmpPMatBytesOf,
196 {
197 let mut scratch: &mut Self = self;
198 let mut cts: Vec<GGSWPrepared<&mut [u8], B>> = Vec::with_capacity(size);
199 for _ in 0..size {
200 let (ct, new_scratch) = scratch.take_ggsw_prepared(module, infos);
201 scratch = new_scratch;
202 cts.push(ct)
203 }
204 (cts, scratch)
205 }
206
207 fn take_glwe_public_key<A>(&mut self, infos: &A) -> (GLWEPublicKey<&mut [u8]>, &mut Self)
208 where
209 A: GLWEInfos,
210 {
211 let (data, scratch) = self.take_glwe(infos);
212 (
213 GLWEPublicKey {
214 dist: Distribution::NONE,
215 key: data,
216 },
217 scratch,
218 )
219 }
220
221 fn take_glwe_public_key_prepared<A, M>(&mut self, module: &M, infos: &A) -> (GLWEPublicKeyPrepared<&mut [u8], B>, &mut Self)
222 where
223 A: GLWEInfos,
224 M: ModuleN + VecZnxDftBytesOf,
225 {
226 let (data, scratch) = self.take_glwe_prepared(module, infos);
227 (
228 GLWEPublicKeyPrepared {
229 dist: Distribution::NONE,
230 key: data,
231 },
232 scratch,
233 )
234 }
235
236 fn take_glwe_prepared<A, M>(&mut self, module: &M, infos: &A) -> (GLWEPrepared<&mut [u8], B>, &mut Self)
237 where
238 A: GLWEInfos,
239 M: ModuleN + VecZnxDftBytesOf,
240 {
241 assert_eq!(module.n() as u32, infos.n());
242 let (data, scratch) = self.take_vec_znx_dft(module, (infos.rank() + 1).into(), infos.size());
243 (
244 GLWEPrepared {
245 k: infos.k(),
246 base2k: infos.base2k(),
247 data,
248 },
249 scratch,
250 )
251 }
252
253 fn take_glwe_secret(&mut self, n: Degree, rank: Rank) -> (GLWESecret<&mut [u8]>, &mut Self) {
254 let (data, scratch) = self.take_scalar_znx(n.into(), rank.into());
255 (
256 GLWESecret {
257 data,
258 dist: Distribution::NONE,
259 },
260 scratch,
261 )
262 }
263
264 fn take_glwe_secret_tensor(&mut self, n: Degree, rank: Rank) -> (GLWESecretTensor<&mut [u8]>, &mut Self) {
265 let (data, scratch) = self.take_scalar_znx(n.into(), GLWESecretTensor::pairs(rank.into()));
266 (
267 GLWESecretTensor {
268 data,
269 rank,
270 dist: Distribution::NONE,
271 },
272 scratch,
273 )
274 }
275
276 fn take_glwe_secret_prepared<M>(&mut self, module: &M, rank: Rank) -> (GLWESecretPrepared<&mut [u8], B>, &mut Self)
277 where
278 M: ModuleN + SvpPPolBytesOf,
279 {
280 let (data, scratch) = self.take_svp_ppol(module, rank.into());
281 (
282 GLWESecretPrepared {
283 data,
284 dist: Distribution::NONE,
285 },
286 scratch,
287 )
288 }
289
290 fn take_glwe_switching_key<A>(&mut self, infos: &A) -> (GLWESwitchingKey<&mut [u8]>, &mut Self)
291 where
292 A: GGLWEInfos,
293 {
294 let (data, scratch) = self.take_gglwe(infos);
295 (
296 GLWESwitchingKey {
297 key: data,
298 input_degree: Degree(0),
299 output_degree: Degree(0),
300 },
301 scratch,
302 )
303 }
304
305 fn take_glwe_switching_key_prepared<A, M>(
306 &mut self,
307 module: &M,
308 infos: &A,
309 ) -> (GLWESwitchingKeyPrepared<&mut [u8], B>, &mut Self)
310 where
311 A: GGLWEInfos,
312 M: ModuleN + VmpPMatBytesOf,
313 {
314 assert_eq!(module.n() as u32, infos.n());
315 let (data, scratch) = self.take_gglwe_prepared(module, infos);
316 (
317 GLWESwitchingKeyPrepared {
318 key: data,
319 input_degree: Degree(0),
320 output_degree: Degree(0),
321 },
322 scratch,
323 )
324 }
325
326 fn take_glwe_automorphism_key<A>(&mut self, infos: &A) -> (GLWEAutomorphismKey<&mut [u8]>, &mut Self)
327 where
328 A: GGLWEInfos,
329 {
330 let (data, scratch) = self.take_gglwe(infos);
331 (GLWEAutomorphismKey { key: data, p: 0 }, scratch)
332 }
333
334 fn take_glwe_automorphism_key_prepared<A, M>(
335 &mut self,
336 module: &M,
337 infos: &A,
338 ) -> (GLWEAutomorphismKeyPrepared<&mut [u8], B>, &mut Self)
339 where
340 A: GGLWEInfos,
341 M: ModuleN + VmpPMatBytesOf,
342 {
343 assert_eq!(module.n() as u32, infos.n());
344 let (data, scratch) = self.take_gglwe_prepared(module, infos);
345 (GLWEAutomorphismKeyPrepared { key: data, p: 0 }, scratch)
346 }
347
348 fn take_glwe_tensor_key<A, M>(&mut self, infos: &A) -> (GLWETensorKey<&mut [u8]>, &mut Self)
349 where
350 A: GGLWEInfos,
351 {
352 assert_eq!(
353 infos.rank_in(),
354 infos.rank_out(),
355 "rank_in != rank_out is not supported for GLWETensorKey"
356 );
357
358 let pairs: u32 = (((infos.rank_out().0 + 1) * infos.rank_out().0) >> 1).max(1);
359 let mut ksk_infos: GGLWELayout = infos.gglwe_layout();
360 ksk_infos.rank_in = Rank(pairs);
361 let (data, scratch) = self.take_gglwe(&ksk_infos);
362 (GLWETensorKey(data), scratch)
363 }
364
365 fn take_glwe_tensor_key_prepared<A, M>(&mut self, module: &M, infos: &A) -> (GLWETensorKeyPrepared<&mut [u8], B>, &mut Self)
366 where
367 A: GGLWEInfos,
368 M: ModuleN + VmpPMatBytesOf,
369 {
370 assert_eq!(module.n() as u32, infos.n());
371 assert_eq!(
372 infos.rank_in(),
373 infos.rank_out(),
374 "rank_in != rank_out is not supported for GGLWETensorKeyPrepared"
375 );
376
377 let pairs: u32 = (((infos.rank_out().0 + 1) * infos.rank_out().0) >> 1).max(1);
378 let mut ksk_infos: GGLWELayout = infos.gglwe_layout();
379 ksk_infos.rank_in = Rank(pairs);
380 let (data, scratch) = self.take_gglwe_prepared(module, &ksk_infos);
381 (GLWETensorKeyPrepared(data), scratch)
382 }
383}
384
385impl<B: Backend> ScratchTakeCore<B> for Scratch<B> where Self: ScratchTakeBasic + ScratchAvailable + ScratchFromBytes<B> {}