poulpy_core/
scratch.rs

1use poulpy_hal::{
2    api::{ModuleN, ScratchAvailable, ScratchTakeBasic, SvpPPolBytesOf, VecZnxDftBytesOf, VmpPMatBytesOf},
3    layouts::{Backend, Scratch},
4};
5
6use crate::{
7    dist::Distribution,
8    layouts::{
9        Degree, GGLWE, GGLWEInfos, GGLWELayout, GGSW, GGSWInfos, GLWE, GLWEAutomorphismKey, GLWEInfos, GLWEPlaintext,
10        GLWEPrepared, GLWEPublicKey, GLWESecret, GLWESwitchingKey, GLWETensorKey, Rank,
11        prepared::{
12            GGLWEPrepared, GGSWPrepared, GLWEAutomorphismKeyPrepared, GLWEPublicKeyPrepared, GLWESecretPrepared,
13            GLWESwitchingKeyPrepared, GLWETensorKeyPrepared,
14        },
15    },
16};
17
18pub trait ScratchTakeCore<B: Backend>
19where
20    Self: ScratchTakeBasic + ScratchAvailable,
21{
22    fn take_glwe<A>(&mut self, infos: &A) -> (GLWE<&mut [u8]>, &mut Self)
23    where
24        A: GLWEInfos,
25    {
26        let (data, scratch) = self.take_vec_znx(infos.n().into(), (infos.rank() + 1).into(), infos.size());
27        (
28            GLWE {
29                k: infos.k(),
30                base2k: infos.base2k(),
31                data,
32            },
33            scratch,
34        )
35    }
36
37    fn take_glwe_slice<A>(&mut self, size: usize, infos: &A) -> (Vec<GLWE<&mut [u8]>>, &mut Self)
38    where
39        A: GLWEInfos,
40    {
41        let mut scratch: &mut Self = self;
42        let mut cts: Vec<GLWE<&mut [u8]>> = Vec::with_capacity(size);
43        for _ in 0..size {
44            let (ct, new_scratch) = scratch.take_glwe(infos);
45            scratch = new_scratch;
46            cts.push(ct);
47        }
48        (cts, scratch)
49    }
50
51    fn take_glwe_plaintext<A>(&mut self, infos: &A) -> (GLWEPlaintext<&mut [u8]>, &mut Self)
52    where
53        A: GLWEInfos,
54    {
55        let (data, scratch) = self.take_vec_znx(infos.n().into(), 1, infos.size());
56        (
57            GLWEPlaintext {
58                k: infos.k(),
59                base2k: infos.base2k(),
60                data,
61            },
62            scratch,
63        )
64    }
65
66    fn take_gglwe<A>(&mut self, infos: &A) -> (GGLWE<&mut [u8]>, &mut Self)
67    where
68        A: GGLWEInfos,
69    {
70        let (data, scratch) = self.take_mat_znx(
71            infos.n().into(),
72            infos.dnum().0.div_ceil(infos.dsize().0) as usize,
73            infos.rank_in().into(),
74            (infos.rank_out() + 1).into(),
75            infos.size(),
76        );
77        (
78            GGLWE {
79                k: infos.k(),
80                base2k: infos.base2k(),
81                dsize: infos.dsize(),
82                data,
83            },
84            scratch,
85        )
86    }
87
88    fn take_gglwe_prepared<A, M>(&mut self, module: &M, infos: &A) -> (GGLWEPrepared<&mut [u8], B>, &mut Self)
89    where
90        A: GGLWEInfos,
91        M: ModuleN + VmpPMatBytesOf,
92    {
93        assert_eq!(module.n() as u32, infos.n());
94        let (data, scratch) = self.take_vmp_pmat(
95            module,
96            infos.dnum().into(),
97            infos.rank_in().into(),
98            (infos.rank_out() + 1).into(),
99            infos.size(),
100        );
101        (
102            GGLWEPrepared {
103                k: infos.k(),
104                base2k: infos.base2k(),
105                dsize: infos.dsize(),
106                data,
107            },
108            scratch,
109        )
110    }
111
112    fn take_ggsw<A>(&mut self, infos: &A) -> (GGSW<&mut [u8]>, &mut Self)
113    where
114        A: GGSWInfos,
115    {
116        let (data, scratch) = self.take_mat_znx(
117            infos.n().into(),
118            infos.dnum().into(),
119            (infos.rank() + 1).into(),
120            (infos.rank() + 1).into(),
121            infos.size(),
122        );
123        (
124            GGSW {
125                k: infos.k(),
126                base2k: infos.base2k(),
127                dsize: infos.dsize(),
128                data,
129            },
130            scratch,
131        )
132    }
133
134    fn take_ggsw_prepared<A, M>(&mut self, module: &M, infos: &A) -> (GGSWPrepared<&mut [u8], B>, &mut Self)
135    where
136        A: GGSWInfos,
137        M: ModuleN + VmpPMatBytesOf,
138    {
139        assert_eq!(module.n() as u32, infos.n());
140        let (data, scratch) = self.take_vmp_pmat(
141            module,
142            infos.dnum().into(),
143            (infos.rank() + 1).into(),
144            (infos.rank() + 1).into(),
145            infos.size(),
146        );
147        (
148            GGSWPrepared {
149                k: infos.k(),
150                base2k: infos.base2k(),
151                dsize: infos.dsize(),
152                data,
153            },
154            scratch,
155        )
156    }
157
158    fn take_ggsw_prepared_slice<A, M>(
159        &mut self,
160        module: &M,
161        size: usize,
162        infos: &A,
163    ) -> (Vec<GGSWPrepared<&mut [u8], B>>, &mut Self)
164    where
165        A: GGSWInfos,
166        M: ModuleN + VmpPMatBytesOf,
167    {
168        let mut scratch: &mut Self = self;
169        let mut cts: Vec<GGSWPrepared<&mut [u8], B>> = Vec::with_capacity(size);
170        for _ in 0..size {
171            let (ct, new_scratch) = scratch.take_ggsw_prepared(module, infos);
172            scratch = new_scratch;
173            cts.push(ct)
174        }
175        (cts, scratch)
176    }
177
178    fn take_glwe_public_key<A>(&mut self, infos: &A) -> (GLWEPublicKey<&mut [u8]>, &mut Self)
179    where
180        A: GLWEInfos,
181    {
182        let (data, scratch) = self.take_glwe(infos);
183        (
184            GLWEPublicKey {
185                dist: Distribution::NONE,
186                key: data,
187            },
188            scratch,
189        )
190    }
191
192    fn take_glwe_public_key_prepared<A, M>(&mut self, module: &M, infos: &A) -> (GLWEPublicKeyPrepared<&mut [u8], B>, &mut Self)
193    where
194        A: GLWEInfos,
195        M: ModuleN + VecZnxDftBytesOf,
196    {
197        let (data, scratch) = self.take_glwe_prepared(module, infos);
198        (
199            GLWEPublicKeyPrepared {
200                dist: Distribution::NONE,
201                key: data,
202            },
203            scratch,
204        )
205    }
206
207    fn take_glwe_prepared<A, M>(&mut self, module: &M, infos: &A) -> (GLWEPrepared<&mut [u8], B>, &mut Self)
208    where
209        A: GLWEInfos,
210        M: ModuleN + VecZnxDftBytesOf,
211    {
212        assert_eq!(module.n() as u32, infos.n());
213        let (data, scratch) = self.take_vec_znx_dft(module, (infos.rank() + 1).into(), infos.size());
214        (
215            GLWEPrepared {
216                k: infos.k(),
217                base2k: infos.base2k(),
218                data,
219            },
220            scratch,
221        )
222    }
223
224    fn take_glwe_secret(&mut self, n: Degree, rank: Rank) -> (GLWESecret<&mut [u8]>, &mut Self) {
225        let (data, scratch) = self.take_scalar_znx(n.into(), rank.into());
226        (
227            GLWESecret {
228                data,
229                dist: Distribution::NONE,
230            },
231            scratch,
232        )
233    }
234
235    fn take_glwe_secret_prepared<M>(&mut self, module: &M, rank: Rank) -> (GLWESecretPrepared<&mut [u8], B>, &mut Self)
236    where
237        M: ModuleN + SvpPPolBytesOf,
238    {
239        let (data, scratch) = self.take_svp_ppol(module, rank.into());
240        (
241            GLWESecretPrepared {
242                data,
243                dist: Distribution::NONE,
244            },
245            scratch,
246        )
247    }
248
249    fn take_glwe_switching_key<A>(&mut self, infos: &A) -> (GLWESwitchingKey<&mut [u8]>, &mut Self)
250    where
251        A: GGLWEInfos,
252    {
253        let (data, scratch) = self.take_gglwe(infos);
254        (
255            GLWESwitchingKey {
256                key: data,
257                input_degree: Degree(0),
258                output_degree: Degree(0),
259            },
260            scratch,
261        )
262    }
263
264    fn take_glwe_switching_key_prepared<A, M>(
265        &mut self,
266        module: &M,
267        infos: &A,
268    ) -> (GLWESwitchingKeyPrepared<&mut [u8], B>, &mut Self)
269    where
270        A: GGLWEInfos,
271        M: ModuleN + VmpPMatBytesOf,
272    {
273        assert_eq!(module.n() as u32, infos.n());
274        let (data, scratch) = self.take_gglwe_prepared(module, infos);
275        (
276            GLWESwitchingKeyPrepared {
277                key: data,
278                input_degree: Degree(0),
279                output_degree: Degree(0),
280            },
281            scratch,
282        )
283    }
284
285    fn take_glwe_automorphism_key<A>(&mut self, infos: &A) -> (GLWEAutomorphismKey<&mut [u8]>, &mut Self)
286    where
287        A: GGLWEInfos,
288    {
289        let (data, scratch) = self.take_gglwe(infos);
290        (GLWEAutomorphismKey { key: data, p: 0 }, scratch)
291    }
292
293    fn take_glwe_automorphism_key_prepared<A, M>(
294        &mut self,
295        module: &M,
296        infos: &A,
297    ) -> (GLWEAutomorphismKeyPrepared<&mut [u8], B>, &mut Self)
298    where
299        A: GGLWEInfos,
300        M: ModuleN + VmpPMatBytesOf,
301    {
302        assert_eq!(module.n() as u32, infos.n());
303        let (data, scratch) = self.take_gglwe_prepared(module, infos);
304        (GLWEAutomorphismKeyPrepared { key: data, p: 0 }, scratch)
305    }
306
307    fn take_glwe_tensor_key<A, M>(&mut self, infos: &A) -> (GLWETensorKey<&mut [u8]>, &mut Self)
308    where
309        A: GGLWEInfos,
310    {
311        assert_eq!(
312            infos.rank_in(),
313            infos.rank_out(),
314            "rank_in != rank_out is not supported for GLWETensorKey"
315        );
316        let mut keys: Vec<GGLWE<&mut [u8]>> = Vec::new();
317        let pairs: usize = (((infos.rank_out().0 + 1) * infos.rank_out().0) >> 1).max(1) as usize;
318
319        let mut scratch: &mut Self = self;
320
321        let mut ksk_infos: GGLWELayout = infos.gglwe_layout();
322        ksk_infos.rank_in = Rank(1);
323
324        if pairs != 0 {
325            let (gglwe, s) = scratch.take_gglwe(&ksk_infos);
326            scratch = s;
327            keys.push(gglwe);
328        }
329        for _ in 1..pairs {
330            let (gglwe, s) = scratch.take_gglwe(&ksk_infos);
331            scratch = s;
332            keys.push(gglwe);
333        }
334        (GLWETensorKey { keys }, scratch)
335    }
336
337    fn take_glwe_tensor_key_prepared<A, M>(&mut self, module: &M, infos: &A) -> (GLWETensorKeyPrepared<&mut [u8], B>, &mut Self)
338    where
339        A: GGLWEInfos,
340        M: ModuleN + VmpPMatBytesOf,
341    {
342        assert_eq!(module.n() as u32, infos.n());
343        assert_eq!(
344            infos.rank_in(),
345            infos.rank_out(),
346            "rank_in != rank_out is not supported for GGLWETensorKeyPrepared"
347        );
348
349        let mut keys: Vec<GGLWEPrepared<&mut [u8], B>> = Vec::new();
350        let pairs: usize = (((infos.rank_out().0 + 1) * infos.rank_out().0) >> 1).max(1) as usize;
351
352        let mut scratch: &mut Self = self;
353
354        let mut ksk_infos: GGLWELayout = infos.gglwe_layout();
355        ksk_infos.rank_in = Rank(1);
356
357        if pairs != 0 {
358            let (gglwe, s) = scratch.take_gglwe_prepared(module, &ksk_infos);
359            scratch = s;
360            keys.push(gglwe);
361        }
362        for _ in 1..pairs {
363            let (gglwe, s) = scratch.take_gglwe_prepared(module, &ksk_infos);
364            scratch = s;
365            keys.push(gglwe);
366        }
367        (GLWETensorKeyPrepared { keys }, scratch)
368    }
369}
370
371impl<B: Backend> ScratchTakeCore<B> for Scratch<B> where Self: ScratchTakeBasic + ScratchAvailable {}