Skip to main content

poulpy_core/
scratch.rs

1use poulpy_hal::{
2    api::{ModuleN, ScratchArenaTakeBasic, SvpPPolBytesOf, VmpPMatBytesOf},
3    layouts::{Backend, ScratchArena},
4};
5
6use crate::{
7    dist::Distribution,
8    layouts::{
9        Degree, GGLWE, GGLWEInfos, GGLWEPreparedViewMut, GGLWEViewMut, GGSW, GGSWInfos, GGSWPreparedViewMut, GGSWViewMut, GLWE,
10        GLWEInfos, GLWEPlaintext, GLWEPlaintextViewMut, GLWESecret, GLWESecretPreparedViewMut, GLWESecretTensor,
11        GLWESecretTensorViewMut, GLWESecretViewMut, GLWETensor, GLWETensorViewMut, GLWEViewMut, LWE, LWEInfos, LWEPlaintext,
12        LWEPlaintextViewMut, LWEViewMut, Rank,
13        prepared::{GGLWEPrepared, GGSWPrepared, GLWESecretPrepared},
14    },
15};
16
17/// Backend-native arena allocation for core ciphertext/key layouts.
18///
19/// Returns backend-native borrows (`B::BufMut<'a>`) carved from a [`ScratchArena`].
20pub trait ScratchArenaTakeCore<'a, B: Backend>: ScratchArenaTakeBasic<'a, B> + Sized {
21    /// Allocates an [`LWE`] ciphertext from scratch space.
22    fn take_lwe_scratch<A>(self, infos: &A) -> (LWEViewMut<'a, B>, Self)
23    where
24        B: 'a,
25        A: LWEInfos,
26    {
27        let (body, scratch_1) = self.take_vec_znx_scratch(1, 1, infos.size());
28        let (mask, scratch) = scratch_1.take_vec_znx_scratch(infos.n().into(), 1, infos.size());
29        (
30            LWEViewMut::from_inner(LWE {
31                base2k: infos.base2k(),
32                body: body.into_inner(),
33                mask: mask.into_inner(),
34            }),
35            scratch,
36        )
37    }
38
39    /// Allocates an [`LWEPlaintext`] from scratch space.
40    fn take_lwe_plaintext_scratch<A>(self, infos: &A) -> (LWEPlaintextViewMut<'a, B>, Self)
41    where
42        B: 'a,
43        A: LWEInfos,
44    {
45        let (data, scratch) = self.take_vec_znx_scratch(1, 1, infos.size());
46        (
47            LWEPlaintextViewMut::from_inner(LWEPlaintext {
48                base2k: infos.base2k(),
49                data: data.into_inner(),
50            }),
51            scratch,
52        )
53    }
54
55    /// Allocates a [`GLWE`] ciphertext from scratch space.
56    fn take_glwe_scratch<A>(self, infos: &A) -> (GLWEViewMut<'a, B>, Self)
57    where
58        B: 'a,
59        A: GLWEInfos,
60    {
61        let (data, scratch) = self.take_vec_znx_scratch(infos.n().into(), (infos.rank() + 1).into(), infos.size());
62        (
63            GLWEViewMut::from_inner(GLWE {
64                base2k: infos.base2k(),
65                data: data.into_inner(),
66            }),
67            scratch,
68        )
69    }
70
71    /// Allocates a `Vec` of `size` [`GLWE`] ciphertexts from scratch space.
72    fn take_glwe_slice_scratch<A>(self, size: usize, infos: &A) -> (Vec<GLWEViewMut<'a, B>>, Self)
73    where
74        B: 'a,
75        A: GLWEInfos,
76    {
77        let mut scratch: Self = self;
78        let mut cts: Vec<GLWEViewMut<'a, B>> = Vec::with_capacity(size);
79        for _ in 0..size {
80            let (ct, new_scratch) = scratch.take_glwe_scratch(infos);
81            scratch = new_scratch;
82            cts.push(ct);
83        }
84        (cts, scratch)
85    }
86
87    /// Allocates a [`GLWETensor`] from scratch space.
88    fn take_glwe_tensor_scratch<A>(self, infos: &A) -> (GLWETensorViewMut<'a, B>, Self)
89    where
90        B: 'a,
91        A: GLWEInfos,
92    {
93        let cols: usize = infos.rank().as_usize() + 1;
94        let pairs: usize = (((cols + 1) * cols) >> 1).max(1);
95        let (data, scratch) = self.take_vec_znx_scratch(infos.n().into(), pairs, infos.size());
96        (
97            GLWETensorViewMut::from_inner(GLWETensor {
98                base2k: infos.base2k(),
99                rank: infos.rank(),
100                data: data.into_inner(),
101            }),
102            scratch,
103        )
104    }
105
106    /// Allocates a [`GLWEPlaintext`] from scratch space.
107    fn take_glwe_plaintext_scratch<A>(self, infos: &A) -> (GLWEPlaintextViewMut<'a, B>, Self)
108    where
109        B: 'a,
110        A: GLWEInfos,
111    {
112        let (data, scratch) = self.take_vec_znx_scratch(infos.n().into(), 1, infos.size());
113        (
114            GLWEPlaintextViewMut::from_inner(GLWEPlaintext {
115                base2k: infos.base2k(),
116                data: data.into_inner(),
117            }),
118            scratch,
119        )
120    }
121
122    /// Allocates a [`GLWESecretPrepared`] (DFT-domain secret key) from scratch space.
123    fn take_glwe_secret_prepared_scratch<M>(self, module: &M, rank: Rank) -> (GLWESecretPreparedViewMut<'a, B>, Self)
124    where
125        B: 'a,
126        M: ModuleN + SvpPPolBytesOf,
127    {
128        let (data, scratch) = self.take_svp_ppol_scratch(module, rank.into());
129        (
130            GLWESecretPreparedViewMut::from_inner(GLWESecretPrepared {
131                data: data.into_inner(),
132                dist: Distribution::NONE,
133            }),
134            scratch,
135        )
136    }
137
138    /// Allocates a [`GLWESecret`] from scratch space.
139    fn take_glwe_secret_scratch(self, n: Degree, rank: Rank) -> (GLWESecretViewMut<'a, B>, Self)
140    where
141        B: 'a,
142    {
143        let (data, scratch) = self.take_scalar_znx_scratch(n.into(), rank.into());
144        (
145            GLWESecretViewMut::from_inner(GLWESecret {
146                data: data.into_inner(),
147                dist: Distribution::NONE,
148            }),
149            scratch,
150        )
151    }
152
153    /// Allocates a [`GLWESecretTensor`] from scratch space.
154    fn take_glwe_secret_tensor_scratch(self, n: Degree, rank: Rank) -> (GLWESecretTensorViewMut<'a, B>, Self)
155    where
156        B: 'a,
157    {
158        let (data, scratch) = self.take_scalar_znx_scratch(n.into(), GLWESecretTensor::pairs(rank.into()));
159        (
160            GLWESecretTensorViewMut::from_inner(GLWESecretTensor {
161                data: data.into_inner(),
162                rank,
163                dist: Distribution::NONE,
164            }),
165            scratch,
166        )
167    }
168
169    /// Allocates a [`GGLWE`] ciphertext from scratch space.
170    fn take_gglwe_scratch<A>(self, infos: &A) -> (GGLWEViewMut<'a, B>, Self)
171    where
172        B: 'a,
173        A: GGLWEInfos,
174    {
175        let (data, scratch) = self.take_mat_znx_scratch(
176            infos.n().into(),
177            infos.dnum().0.div_ceil(infos.dsize().0) as usize,
178            infos.rank_in().into(),
179            (infos.rank_out() + 1).into(),
180            infos.size(),
181        );
182        (
183            GGLWEViewMut::from_inner(GGLWE {
184                base2k: infos.base2k(),
185                dsize: infos.dsize(),
186                data: data.into_inner(),
187            }),
188            scratch,
189        )
190    }
191
192    /// Allocates a [`GGLWEPrepared`] (DFT-domain GGLWE) from scratch space.
193    fn take_gglwe_prepared_scratch<A, M>(self, module: &M, infos: &A) -> (GGLWEPreparedViewMut<'a, B>, Self)
194    where
195        B: 'a,
196        A: GGLWEInfos,
197        M: ModuleN + VmpPMatBytesOf,
198    {
199        assert_eq!(module.n() as u32, infos.n());
200        let (data, scratch) = self.take_vmp_pmat_scratch(
201            module,
202            infos.dnum().into(),
203            infos.rank_in().into(),
204            (infos.rank_out() + 1).into(),
205            infos.size(),
206        );
207        (
208            GGLWEPreparedViewMut::from_inner(GGLWEPrepared {
209                base2k: infos.base2k(),
210                dsize: infos.dsize(),
211                data: data.into_inner(),
212            }),
213            scratch,
214        )
215    }
216
217    /// Allocates a [`GGSW`] ciphertext from scratch space.
218    fn take_ggsw_scratch<A>(self, infos: &A) -> (GGSWViewMut<'a, B>, Self)
219    where
220        B: 'a,
221        A: GGSWInfos,
222    {
223        let (data, scratch) = self.take_mat_znx_scratch(
224            infos.n().into(),
225            infos.dnum().into(),
226            (infos.rank() + 1).into(),
227            (infos.rank() + 1).into(),
228            infos.size(),
229        );
230        (
231            GGSWViewMut::from_inner(GGSW {
232                base2k: infos.base2k(),
233                dsize: infos.dsize(),
234                data: data.into_inner(),
235            }),
236            scratch,
237        )
238    }
239
240    /// Allocates a [`GGSWPrepared`] (DFT-domain GGSW) from scratch space.
241    fn take_ggsw_prepared_scratch<A, M>(self, module: &M, infos: &A) -> (GGSWPreparedViewMut<'a, B>, Self)
242    where
243        B: 'a,
244        A: GGSWInfos,
245        M: ModuleN + VmpPMatBytesOf,
246    {
247        assert_eq!(module.n() as u32, infos.n());
248        let (data, scratch) = self.take_vmp_pmat_scratch(
249            module,
250            infos.dnum().into(),
251            (infos.rank() + 1).into(),
252            (infos.rank() + 1).into(),
253            infos.size(),
254        );
255        (
256            GGSWPreparedViewMut::from_inner(GGSWPrepared {
257                base2k: infos.base2k(),
258                dsize: infos.dsize(),
259                data: data.into_inner(),
260            }),
261            scratch,
262        )
263    }
264}
265
266impl<'a, B: Backend> ScratchArenaTakeCore<'a, B> for ScratchArena<'a, B> {}