poulpy_core/
scratch.rs

1use poulpy_hal::{
2    api::{ModuleN, ScratchAvailable, ScratchFromBytes, ScratchTakeBasic, SvpPPolBytesOf, VecZnxDftBytesOf, VmpPMatBytesOf},
3    layouts::{Backend, Scratch},
4};
5
6use crate::{
7    dist::Distribution,
8    layouts::{
9        Degree, GGLWE, GGLWEInfos, GGLWELayout, GGSW, GGSWInfos, GLWE, GLWEAutomorphismKey, GLWEInfos, GLWEPlaintext,
10        GLWEPrepared, GLWEPublicKey, GLWESecret, GLWESecretTensor, GLWESwitchingKey, GLWETensorKey, LWE, LWEInfos, Rank,
11        prepared::{
12            GGLWEPrepared, GGSWPrepared, GLWEAutomorphismKeyPrepared, GLWEPublicKeyPrepared, GLWESecretPrepared,
13            GLWESwitchingKeyPrepared, GLWETensorKeyPrepared,
14        },
15    },
16};
17
18pub trait ScratchTakeCore<B: Backend>
19where
20    Self: ScratchTakeBasic + ScratchAvailable + ScratchFromBytes<B>,
21{
22    fn take_lwe<A>(&mut self, infos: &A) -> (LWE<&mut [u8]>, &mut Self)
23    where
24        A: LWEInfos,
25    {
26        let (data, scratch) = self.take_vec_znx(infos.n().into(), 1, infos.size());
27        (
28            LWE {
29                k: infos.k(),
30                base2k: infos.base2k(),
31                data,
32            },
33            scratch,
34        )
35    }
36
37    fn take_glwe<A>(&mut self, infos: &A) -> (GLWE<&mut [u8]>, &mut Self)
38    where
39        A: GLWEInfos,
40    {
41        let (data, scratch) = self.take_vec_znx(infos.n().into(), (infos.rank() + 1).into(), infos.size());
42        (
43            GLWE {
44                k: infos.k(),
45                base2k: infos.base2k(),
46                data,
47            },
48            scratch,
49        )
50    }
51
52    fn take_glwe_slice<A>(&mut self, size: usize, infos: &A) -> (Vec<GLWE<&mut [u8]>>, &mut Self)
53    where
54        A: GLWEInfos,
55    {
56        let mut scratch: &mut Self = self;
57        let mut cts: Vec<GLWE<&mut [u8]>> = Vec::with_capacity(size);
58        for _ in 0..size {
59            let (ct, new_scratch) = scratch.take_glwe(infos);
60            scratch = new_scratch;
61            cts.push(ct);
62        }
63        (cts, scratch)
64    }
65
66    fn take_glwe_plaintext<A>(&mut self, infos: &A) -> (GLWEPlaintext<&mut [u8]>, &mut Self)
67    where
68        A: GLWEInfos,
69    {
70        let (data, scratch) = self.take_vec_znx(infos.n().into(), 1, infos.size());
71        (
72            GLWEPlaintext {
73                k: infos.k(),
74                base2k: infos.base2k(),
75                data,
76            },
77            scratch,
78        )
79    }
80
81    fn take_gglwe<A>(&mut self, infos: &A) -> (GGLWE<&mut [u8]>, &mut Self)
82    where
83        A: GGLWEInfos,
84    {
85        let (data, scratch) = self.take_mat_znx(
86            infos.n().into(),
87            infos.dnum().0.div_ceil(infos.dsize().0) as usize,
88            infos.rank_in().into(),
89            (infos.rank_out() + 1).into(),
90            infos.size(),
91        );
92        (
93            GGLWE {
94                k: infos.k(),
95                base2k: infos.base2k(),
96                dsize: infos.dsize(),
97                data,
98            },
99            scratch,
100        )
101    }
102
103    fn take_gglwe_prepared<A, M>(&mut self, module: &M, infos: &A) -> (GGLWEPrepared<&mut [u8], B>, &mut Self)
104    where
105        A: GGLWEInfos,
106        M: ModuleN + VmpPMatBytesOf,
107    {
108        assert_eq!(module.n() as u32, infos.n());
109        let (data, scratch) = self.take_vmp_pmat(
110            module,
111            infos.dnum().into(),
112            infos.rank_in().into(),
113            (infos.rank_out() + 1).into(),
114            infos.size(),
115        );
116        (
117            GGLWEPrepared {
118                k: infos.k(),
119                base2k: infos.base2k(),
120                dsize: infos.dsize(),
121                data,
122            },
123            scratch,
124        )
125    }
126
127    fn take_ggsw<A>(&mut self, infos: &A) -> (GGSW<&mut [u8]>, &mut Self)
128    where
129        A: GGSWInfos,
130    {
131        let (data, scratch) = self.take_mat_znx(
132            infos.n().into(),
133            infos.dnum().into(),
134            (infos.rank() + 1).into(),
135            (infos.rank() + 1).into(),
136            infos.size(),
137        );
138        (
139            GGSW {
140                k: infos.k(),
141                base2k: infos.base2k(),
142                dsize: infos.dsize(),
143                data,
144            },
145            scratch,
146        )
147    }
148
149    fn take_ggsw_prepared<A, M>(&mut self, module: &M, infos: &A) -> (GGSWPrepared<&mut [u8], B>, &mut Self)
150    where
151        A: GGSWInfos,
152        M: ModuleN + VmpPMatBytesOf,
153    {
154        assert_eq!(module.n() as u32, infos.n());
155        let (data, scratch) = self.take_vmp_pmat(
156            module,
157            infos.dnum().into(),
158            (infos.rank() + 1).into(),
159            (infos.rank() + 1).into(),
160            infos.size(),
161        );
162        (
163            GGSWPrepared {
164                k: infos.k(),
165                base2k: infos.base2k(),
166                dsize: infos.dsize(),
167                data,
168            },
169            scratch,
170        )
171    }
172
173    fn take_ggsw_slice<A>(&mut self, size: usize, infos: &A) -> (Vec<GGSW<&mut [u8]>>, &mut Self)
174    where
175        A: GGSWInfos,
176    {
177        let mut scratch: &mut Self = self;
178        let mut cts: Vec<GGSW<&mut [u8]>> = Vec::with_capacity(size);
179        for _ in 0..size {
180            let (ct, new_scratch) = scratch.take_ggsw(infos);
181            scratch = new_scratch;
182            cts.push(ct)
183        }
184        (cts, scratch)
185    }
186
187    fn take_ggsw_prepared_slice<A, M>(
188        &mut self,
189        module: &M,
190        size: usize,
191        infos: &A,
192    ) -> (Vec<GGSWPrepared<&mut [u8], B>>, &mut Self)
193    where
194        A: GGSWInfos,
195        M: ModuleN + VmpPMatBytesOf,
196    {
197        let mut scratch: &mut Self = self;
198        let mut cts: Vec<GGSWPrepared<&mut [u8], B>> = Vec::with_capacity(size);
199        for _ in 0..size {
200            let (ct, new_scratch) = scratch.take_ggsw_prepared(module, infos);
201            scratch = new_scratch;
202            cts.push(ct)
203        }
204        (cts, scratch)
205    }
206
207    fn take_glwe_public_key<A>(&mut self, infos: &A) -> (GLWEPublicKey<&mut [u8]>, &mut Self)
208    where
209        A: GLWEInfos,
210    {
211        let (data, scratch) = self.take_glwe(infos);
212        (
213            GLWEPublicKey {
214                dist: Distribution::NONE,
215                key: data,
216            },
217            scratch,
218        )
219    }
220
221    fn take_glwe_public_key_prepared<A, M>(&mut self, module: &M, infos: &A) -> (GLWEPublicKeyPrepared<&mut [u8], B>, &mut Self)
222    where
223        A: GLWEInfos,
224        M: ModuleN + VecZnxDftBytesOf,
225    {
226        let (data, scratch) = self.take_glwe_prepared(module, infos);
227        (
228            GLWEPublicKeyPrepared {
229                dist: Distribution::NONE,
230                key: data,
231            },
232            scratch,
233        )
234    }
235
236    fn take_glwe_prepared<A, M>(&mut self, module: &M, infos: &A) -> (GLWEPrepared<&mut [u8], B>, &mut Self)
237    where
238        A: GLWEInfos,
239        M: ModuleN + VecZnxDftBytesOf,
240    {
241        assert_eq!(module.n() as u32, infos.n());
242        let (data, scratch) = self.take_vec_znx_dft(module, (infos.rank() + 1).into(), infos.size());
243        (
244            GLWEPrepared {
245                k: infos.k(),
246                base2k: infos.base2k(),
247                data,
248            },
249            scratch,
250        )
251    }
252
253    fn take_glwe_secret(&mut self, n: Degree, rank: Rank) -> (GLWESecret<&mut [u8]>, &mut Self) {
254        let (data, scratch) = self.take_scalar_znx(n.into(), rank.into());
255        (
256            GLWESecret {
257                data,
258                dist: Distribution::NONE,
259            },
260            scratch,
261        )
262    }
263
264    fn take_glwe_secret_tensor(&mut self, n: Degree, rank: Rank) -> (GLWESecretTensor<&mut [u8]>, &mut Self) {
265        let (data, scratch) = self.take_scalar_znx(n.into(), GLWESecretTensor::pairs(rank.into()));
266        (
267            GLWESecretTensor {
268                data,
269                rank,
270                dist: Distribution::NONE,
271            },
272            scratch,
273        )
274    }
275
276    fn take_glwe_secret_prepared<M>(&mut self, module: &M, rank: Rank) -> (GLWESecretPrepared<&mut [u8], B>, &mut Self)
277    where
278        M: ModuleN + SvpPPolBytesOf,
279    {
280        let (data, scratch) = self.take_svp_ppol(module, rank.into());
281        (
282            GLWESecretPrepared {
283                data,
284                dist: Distribution::NONE,
285            },
286            scratch,
287        )
288    }
289
290    fn take_glwe_switching_key<A>(&mut self, infos: &A) -> (GLWESwitchingKey<&mut [u8]>, &mut Self)
291    where
292        A: GGLWEInfos,
293    {
294        let (data, scratch) = self.take_gglwe(infos);
295        (
296            GLWESwitchingKey {
297                key: data,
298                input_degree: Degree(0),
299                output_degree: Degree(0),
300            },
301            scratch,
302        )
303    }
304
305    fn take_glwe_switching_key_prepared<A, M>(
306        &mut self,
307        module: &M,
308        infos: &A,
309    ) -> (GLWESwitchingKeyPrepared<&mut [u8], B>, &mut Self)
310    where
311        A: GGLWEInfos,
312        M: ModuleN + VmpPMatBytesOf,
313    {
314        assert_eq!(module.n() as u32, infos.n());
315        let (data, scratch) = self.take_gglwe_prepared(module, infos);
316        (
317            GLWESwitchingKeyPrepared {
318                key: data,
319                input_degree: Degree(0),
320                output_degree: Degree(0),
321            },
322            scratch,
323        )
324    }
325
326    fn take_glwe_automorphism_key<A>(&mut self, infos: &A) -> (GLWEAutomorphismKey<&mut [u8]>, &mut Self)
327    where
328        A: GGLWEInfos,
329    {
330        let (data, scratch) = self.take_gglwe(infos);
331        (GLWEAutomorphismKey { key: data, p: 0 }, scratch)
332    }
333
334    fn take_glwe_automorphism_key_prepared<A, M>(
335        &mut self,
336        module: &M,
337        infos: &A,
338    ) -> (GLWEAutomorphismKeyPrepared<&mut [u8], B>, &mut Self)
339    where
340        A: GGLWEInfos,
341        M: ModuleN + VmpPMatBytesOf,
342    {
343        assert_eq!(module.n() as u32, infos.n());
344        let (data, scratch) = self.take_gglwe_prepared(module, infos);
345        (GLWEAutomorphismKeyPrepared { key: data, p: 0 }, scratch)
346    }
347
348    fn take_glwe_tensor_key<A, M>(&mut self, infos: &A) -> (GLWETensorKey<&mut [u8]>, &mut Self)
349    where
350        A: GGLWEInfos,
351    {
352        assert_eq!(
353            infos.rank_in(),
354            infos.rank_out(),
355            "rank_in != rank_out is not supported for GLWETensorKey"
356        );
357
358        let pairs: u32 = (((infos.rank_out().0 + 1) * infos.rank_out().0) >> 1).max(1);
359        let mut ksk_infos: GGLWELayout = infos.gglwe_layout();
360        ksk_infos.rank_in = Rank(pairs);
361        let (data, scratch) = self.take_gglwe(&ksk_infos);
362        (GLWETensorKey(data), scratch)
363    }
364
365    fn take_glwe_tensor_key_prepared<A, M>(&mut self, module: &M, infos: &A) -> (GLWETensorKeyPrepared<&mut [u8], B>, &mut Self)
366    where
367        A: GGLWEInfos,
368        M: ModuleN + VmpPMatBytesOf,
369    {
370        assert_eq!(module.n() as u32, infos.n());
371        assert_eq!(
372            infos.rank_in(),
373            infos.rank_out(),
374            "rank_in != rank_out is not supported for GGLWETensorKeyPrepared"
375        );
376
377        let pairs: u32 = (((infos.rank_out().0 + 1) * infos.rank_out().0) >> 1).max(1);
378        let mut ksk_infos: GGLWELayout = infos.gglwe_layout();
379        ksk_infos.rank_in = Rank(pairs);
380        let (data, scratch) = self.take_gglwe_prepared(module, &ksk_infos);
381        (GLWETensorKeyPrepared(data), scratch)
382    }
383}
384
385impl<B: Backend> ScratchTakeCore<B> for Scratch<B> where Self: ScratchTakeBasic + ScratchAvailable + ScratchFromBytes<B> {}