1use poulpy_hal::{
2 api::{ModuleN, ScratchArenaTakeBasic, SvpPPolBytesOf, VmpPMatBytesOf},
3 layouts::{Backend, ScratchArena},
4};
5
6use crate::{
7 dist::Distribution,
8 layouts::{
9 Degree, GGLWE, GGLWEInfos, GGLWEPreparedViewMut, GGLWEViewMut, GGSW, GGSWInfos, GGSWPreparedViewMut, GGSWViewMut, GLWE,
10 GLWEInfos, GLWEPlaintext, GLWEPlaintextViewMut, GLWESecret, GLWESecretPreparedViewMut, GLWESecretTensor,
11 GLWESecretTensorViewMut, GLWESecretViewMut, GLWETensor, GLWETensorViewMut, GLWEViewMut, LWE, LWEInfos, LWEPlaintext,
12 LWEPlaintextViewMut, LWEViewMut, Rank,
13 prepared::{GGLWEPrepared, GGSWPrepared, GLWESecretPrepared},
14 },
15};
16
17pub trait ScratchArenaTakeCore<'a, B: Backend>: ScratchArenaTakeBasic<'a, B> + Sized {
21 fn take_lwe_scratch<A>(self, infos: &A) -> (LWEViewMut<'a, B>, Self)
23 where
24 B: 'a,
25 A: LWEInfos,
26 {
27 let (body, scratch_1) = self.take_vec_znx_scratch(1, 1, infos.size());
28 let (mask, scratch) = scratch_1.take_vec_znx_scratch(infos.n().into(), 1, infos.size());
29 (
30 LWEViewMut::from_inner(LWE {
31 base2k: infos.base2k(),
32 body: body.into_inner(),
33 mask: mask.into_inner(),
34 }),
35 scratch,
36 )
37 }
38
39 fn take_lwe_plaintext_scratch<A>(self, infos: &A) -> (LWEPlaintextViewMut<'a, B>, Self)
41 where
42 B: 'a,
43 A: LWEInfos,
44 {
45 let (data, scratch) = self.take_vec_znx_scratch(1, 1, infos.size());
46 (
47 LWEPlaintextViewMut::from_inner(LWEPlaintext {
48 base2k: infos.base2k(),
49 data: data.into_inner(),
50 }),
51 scratch,
52 )
53 }
54
55 fn take_glwe_scratch<A>(self, infos: &A) -> (GLWEViewMut<'a, B>, Self)
57 where
58 B: 'a,
59 A: GLWEInfos,
60 {
61 let (data, scratch) = self.take_vec_znx_scratch(infos.n().into(), (infos.rank() + 1).into(), infos.size());
62 (
63 GLWEViewMut::from_inner(GLWE {
64 base2k: infos.base2k(),
65 data: data.into_inner(),
66 }),
67 scratch,
68 )
69 }
70
71 fn take_glwe_slice_scratch<A>(self, size: usize, infos: &A) -> (Vec<GLWEViewMut<'a, B>>, Self)
73 where
74 B: 'a,
75 A: GLWEInfos,
76 {
77 let mut scratch: Self = self;
78 let mut cts: Vec<GLWEViewMut<'a, B>> = Vec::with_capacity(size);
79 for _ in 0..size {
80 let (ct, new_scratch) = scratch.take_glwe_scratch(infos);
81 scratch = new_scratch;
82 cts.push(ct);
83 }
84 (cts, scratch)
85 }
86
87 fn take_glwe_tensor_scratch<A>(self, infos: &A) -> (GLWETensorViewMut<'a, B>, Self)
89 where
90 B: 'a,
91 A: GLWEInfos,
92 {
93 let cols: usize = infos.rank().as_usize() + 1;
94 let pairs: usize = (((cols + 1) * cols) >> 1).max(1);
95 let (data, scratch) = self.take_vec_znx_scratch(infos.n().into(), pairs, infos.size());
96 (
97 GLWETensorViewMut::from_inner(GLWETensor {
98 base2k: infos.base2k(),
99 rank: infos.rank(),
100 data: data.into_inner(),
101 }),
102 scratch,
103 )
104 }
105
106 fn take_glwe_plaintext_scratch<A>(self, infos: &A) -> (GLWEPlaintextViewMut<'a, B>, Self)
108 where
109 B: 'a,
110 A: GLWEInfos,
111 {
112 let (data, scratch) = self.take_vec_znx_scratch(infos.n().into(), 1, infos.size());
113 (
114 GLWEPlaintextViewMut::from_inner(GLWEPlaintext {
115 base2k: infos.base2k(),
116 data: data.into_inner(),
117 }),
118 scratch,
119 )
120 }
121
122 fn take_glwe_secret_prepared_scratch<M>(self, module: &M, rank: Rank) -> (GLWESecretPreparedViewMut<'a, B>, Self)
124 where
125 B: 'a,
126 M: ModuleN + SvpPPolBytesOf,
127 {
128 let (data, scratch) = self.take_svp_ppol_scratch(module, rank.into());
129 (
130 GLWESecretPreparedViewMut::from_inner(GLWESecretPrepared {
131 data: data.into_inner(),
132 dist: Distribution::NONE,
133 }),
134 scratch,
135 )
136 }
137
138 fn take_glwe_secret_scratch(self, n: Degree, rank: Rank) -> (GLWESecretViewMut<'a, B>, Self)
140 where
141 B: 'a,
142 {
143 let (data, scratch) = self.take_scalar_znx_scratch(n.into(), rank.into());
144 (
145 GLWESecretViewMut::from_inner(GLWESecret {
146 data: data.into_inner(),
147 dist: Distribution::NONE,
148 }),
149 scratch,
150 )
151 }
152
153 fn take_glwe_secret_tensor_scratch(self, n: Degree, rank: Rank) -> (GLWESecretTensorViewMut<'a, B>, Self)
155 where
156 B: 'a,
157 {
158 let (data, scratch) = self.take_scalar_znx_scratch(n.into(), GLWESecretTensor::pairs(rank.into()));
159 (
160 GLWESecretTensorViewMut::from_inner(GLWESecretTensor {
161 data: data.into_inner(),
162 rank,
163 dist: Distribution::NONE,
164 }),
165 scratch,
166 )
167 }
168
169 fn take_gglwe_scratch<A>(self, infos: &A) -> (GGLWEViewMut<'a, B>, Self)
171 where
172 B: 'a,
173 A: GGLWEInfos,
174 {
175 let (data, scratch) = self.take_mat_znx_scratch(
176 infos.n().into(),
177 infos.dnum().0.div_ceil(infos.dsize().0) as usize,
178 infos.rank_in().into(),
179 (infos.rank_out() + 1).into(),
180 infos.size(),
181 );
182 (
183 GGLWEViewMut::from_inner(GGLWE {
184 base2k: infos.base2k(),
185 dsize: infos.dsize(),
186 data: data.into_inner(),
187 }),
188 scratch,
189 )
190 }
191
192 fn take_gglwe_prepared_scratch<A, M>(self, module: &M, infos: &A) -> (GGLWEPreparedViewMut<'a, B>, Self)
194 where
195 B: 'a,
196 A: GGLWEInfos,
197 M: ModuleN + VmpPMatBytesOf,
198 {
199 assert_eq!(module.n() as u32, infos.n());
200 let (data, scratch) = self.take_vmp_pmat_scratch(
201 module,
202 infos.dnum().into(),
203 infos.rank_in().into(),
204 (infos.rank_out() + 1).into(),
205 infos.size(),
206 );
207 (
208 GGLWEPreparedViewMut::from_inner(GGLWEPrepared {
209 base2k: infos.base2k(),
210 dsize: infos.dsize(),
211 data: data.into_inner(),
212 }),
213 scratch,
214 )
215 }
216
217 fn take_ggsw_scratch<A>(self, infos: &A) -> (GGSWViewMut<'a, B>, Self)
219 where
220 B: 'a,
221 A: GGSWInfos,
222 {
223 let (data, scratch) = self.take_mat_znx_scratch(
224 infos.n().into(),
225 infos.dnum().into(),
226 (infos.rank() + 1).into(),
227 (infos.rank() + 1).into(),
228 infos.size(),
229 );
230 (
231 GGSWViewMut::from_inner(GGSW {
232 base2k: infos.base2k(),
233 dsize: infos.dsize(),
234 data: data.into_inner(),
235 }),
236 scratch,
237 )
238 }
239
240 fn take_ggsw_prepared_scratch<A, M>(self, module: &M, infos: &A) -> (GGSWPreparedViewMut<'a, B>, Self)
242 where
243 B: 'a,
244 A: GGSWInfos,
245 M: ModuleN + VmpPMatBytesOf,
246 {
247 assert_eq!(module.n() as u32, infos.n());
248 let (data, scratch) = self.take_vmp_pmat_scratch(
249 module,
250 infos.dnum().into(),
251 (infos.rank() + 1).into(),
252 (infos.rank() + 1).into(),
253 infos.size(),
254 );
255 (
256 GGSWPreparedViewMut::from_inner(GGSWPrepared {
257 base2k: infos.base2k(),
258 dsize: infos.dsize(),
259 data: data.into_inner(),
260 }),
261 scratch,
262 )
263 }
264}
265
266impl<'a, B: Backend> ScratchArenaTakeCore<'a, B> for ScratchArena<'a, B> {}