1use crate::{
8 api::{CnvPVecBytesOf, ModuleN, SvpPPolBytesOf, VecZnxBigBytesOf, VecZnxDftBytesOf, VmpPMatBytesOf},
9 layouts::{
10 Backend, CnvPVecL, CnvPVecLViewMut, CnvPVecR, CnvPVecRViewMut, MatZnx, MatZnxViewMut, ScalarZnx, ScalarZnxViewMut,
11 ScratchArena, SvpPPol, SvpPPolViewMut, VecZnx, VecZnxBig, VecZnxBigViewMut, VecZnxDft, VecZnxDftViewMut, VecZnxViewMut,
12 VmpPMat, VmpPMatViewMut,
13 },
14};
15
16pub trait ScratchOwnedAlloc<B: Backend> {
18 fn alloc(size: usize) -> Self;
19}
20
21pub trait ScratchOwnedBorrow<B: Backend> {
23 fn borrow(&mut self) -> ScratchArena<'_, B>;
24}
25
/// Query for how much scratch space is still available.
pub trait ScratchAvailable {
    /// Returns the amount of scratch space still available.
    ///
    /// NOTE(review): the unit is presumably bytes — confirm against implementors.
    fn available(&self) -> usize;
}
30
/// A mutable host buffer that can be turned into a plain byte slice.
pub trait HostBufMut<'a>: Sized {
    /// Consumes the buffer, yielding its contents as a `&'a mut [u8]`.
    fn into_bytes(self) -> &'a mut [u8];
}
38
39impl<'a> HostBufMut<'a> for &'a mut [u8] {
40 #[inline]
41 fn into_bytes(self) -> &'a mut [u8] {
42 self
43 }
44}
45
46pub trait ScratchArenaTakeBasic<'a, B: Backend>: Sized {
52 fn take_cnv_pvec_left_scratch<M>(self, module: &M, cols: usize, size: usize) -> (CnvPVecLViewMut<'a, B>, Self)
54 where
55 B: 'a,
56 M: ModuleN + CnvPVecBytesOf;
57
58 fn take_cnv_pvec_right_scratch<M>(self, module: &M, cols: usize, size: usize) -> (CnvPVecRViewMut<'a, B>, Self)
60 where
61 B: 'a,
62 M: ModuleN + CnvPVecBytesOf;
63
64 fn take_scalar_znx_scratch(self, n: usize, cols: usize) -> (ScalarZnxViewMut<'a, B>, Self)
66 where
67 B: 'a;
68
69 fn take_svp_ppol_scratch<M>(self, module: &M, cols: usize) -> (SvpPPolViewMut<'a, B>, Self)
71 where
72 B: 'a,
73 M: SvpPPolBytesOf + ModuleN;
74
75 fn take_vec_znx_scratch(self, n: usize, cols: usize, size: usize) -> (VecZnxViewMut<'a, B>, Self)
77 where
78 B: 'a;
79
80 fn take_vec_znx_big_scratch<M>(self, module: &M, cols: usize, size: usize) -> (VecZnxBigViewMut<'a, B>, Self)
82 where
83 B: 'a,
84 M: VecZnxBigBytesOf + ModuleN;
85
86 fn take_vec_znx_big_scratch_n(self, n: usize, cols: usize, size: usize) -> (VecZnxBigViewMut<'a, B>, Self)
87 where
88 B: 'a;
89
90 fn take_vec_znx_dft_scratch<M>(self, module: &M, cols: usize, size: usize) -> (VecZnxDftViewMut<'a, B>, Self)
92 where
93 B: 'a,
94 M: VecZnxDftBytesOf + ModuleN;
95
96 fn take_vec_znx_dft_slice_scratch<M>(
98 self,
99 module: &M,
100 len: usize,
101 cols: usize,
102 size: usize,
103 ) -> (Vec<VecZnxDftViewMut<'a, B>>, Self)
104 where
105 B: 'a,
106 M: VecZnxDftBytesOf + ModuleN,
107 {
108 let mut scratch: Self = self;
109 let mut slice: Vec<VecZnxDftViewMut<'a, B>> = Vec::with_capacity(len);
110 for _ in 0..len {
111 let (znx, rem) = scratch.take_vec_znx_dft_scratch(module, cols, size);
112 scratch = rem;
113 slice.push(znx);
114 }
115 (slice, scratch)
116 }
117
118 fn take_vec_znx_slice_scratch(self, len: usize, n: usize, cols: usize, size: usize) -> (Vec<VecZnxViewMut<'a, B>>, Self)
120 where
121 B: 'a,
122 {
123 let mut scratch: Self = self;
124 let mut slice: Vec<VecZnxViewMut<'a, B>> = Vec::with_capacity(len);
125 for _ in 0..len {
126 let (znx, rem) = scratch.take_vec_znx_scratch(n, cols, size);
127 scratch = rem;
128 slice.push(znx);
129 }
130 (slice, scratch)
131 }
132
133 fn take_vmp_pmat_scratch<M>(
135 self,
136 module: &M,
137 rows: usize,
138 cols_in: usize,
139 cols_out: usize,
140 size: usize,
141 ) -> (VmpPMatViewMut<'a, B>, Self)
142 where
143 B: 'a,
144 M: VmpPMatBytesOf + ModuleN;
145
146 fn take_mat_znx_scratch(
148 self,
149 n: usize,
150 rows: usize,
151 cols_in: usize,
152 cols_out: usize,
153 size: usize,
154 ) -> (MatZnxViewMut<'a, B>, Self)
155 where
156 B: 'a;
157}
158
159impl<'a, B: Backend> ScratchArenaTakeBasic<'a, B> for ScratchArena<'a, B> {
160 fn take_cnv_pvec_left_scratch<M>(self, module: &M, cols: usize, size: usize) -> (CnvPVecLViewMut<'a, B>, Self)
161 where
162 B: 'a,
163 M: ModuleN + CnvPVecBytesOf,
164 {
165 let (data, arena) = self.take_region(module.bytes_of_cnv_pvec_left(cols, size));
166 (
167 CnvPVecLViewMut::from_inner(CnvPVecL::from_data(data, module.n(), cols, size)),
168 arena,
169 )
170 }
171
172 fn take_cnv_pvec_right_scratch<M>(self, module: &M, cols: usize, size: usize) -> (CnvPVecRViewMut<'a, B>, Self)
173 where
174 B: 'a,
175 M: ModuleN + CnvPVecBytesOf,
176 {
177 let (data, arena) = self.take_region(module.bytes_of_cnv_pvec_right(cols, size));
178 (
179 CnvPVecRViewMut::from_inner(CnvPVecR::from_data(data, module.n(), cols, size)),
180 arena,
181 )
182 }
183
184 fn take_scalar_znx_scratch(self, n: usize, cols: usize) -> (ScalarZnxViewMut<'a, B>, Self)
185 where
186 B: 'a,
187 {
188 let (data, arena) = self.take_region(ScalarZnx::bytes_of(n, cols));
189 (ScalarZnxViewMut::from_inner(ScalarZnx::from_data(data, n, cols)), arena)
190 }
191
192 fn take_svp_ppol_scratch<M>(self, module: &M, cols: usize) -> (SvpPPolViewMut<'a, B>, Self)
193 where
194 B: 'a,
195 M: SvpPPolBytesOf + ModuleN,
196 {
197 let (data, arena) = self.take_region(module.bytes_of_svp_ppol(cols));
198 (SvpPPolViewMut::from_inner(SvpPPol::from_data(data, module.n(), cols)), arena)
199 }
200
201 fn take_vec_znx_scratch(self, n: usize, cols: usize, size: usize) -> (VecZnxViewMut<'a, B>, Self)
202 where
203 B: 'a,
204 {
205 let (data, arena) = self.take_region(VecZnx::bytes_of(n, cols, size));
206 (VecZnxViewMut::from_inner(VecZnx::from_data(data, n, cols, size)), arena)
207 }
208
209 fn take_vec_znx_big_scratch<M>(self, module: &M, cols: usize, size: usize) -> (VecZnxBigViewMut<'a, B>, Self)
210 where
211 B: 'a,
212 M: VecZnxBigBytesOf + ModuleN,
213 {
214 self.take_vec_znx_big_scratch_n(module.n(), cols, size)
215 }
216
217 fn take_vec_znx_big_scratch_n(self, n: usize, cols: usize, size: usize) -> (VecZnxBigViewMut<'a, B>, Self)
218 where
219 B: 'a,
220 {
221 let (data, arena) = self.take_region(B::bytes_of_vec_znx_big(n, cols, size));
222 (VecZnxBigViewMut::from_inner(VecZnxBig::from_data(data, n, cols, size)), arena)
223 }
224
225 fn take_vec_znx_dft_scratch<M>(self, module: &M, cols: usize, size: usize) -> (VecZnxDftViewMut<'a, B>, Self)
226 where
227 B: 'a,
228 M: VecZnxDftBytesOf + ModuleN,
229 {
230 let (data, arena) = self.take_region(module.bytes_of_vec_znx_dft(cols, size));
231 (
232 VecZnxDftViewMut::from_inner(VecZnxDft::from_data(data, module.n(), cols, size)),
233 arena,
234 )
235 }
236
237 fn take_vmp_pmat_scratch<M>(
238 self,
239 module: &M,
240 rows: usize,
241 cols_in: usize,
242 cols_out: usize,
243 size: usize,
244 ) -> (VmpPMatViewMut<'a, B>, Self)
245 where
246 B: 'a,
247 M: VmpPMatBytesOf + ModuleN,
248 {
249 let (data, arena) = self.take_region(module.bytes_of_vmp_pmat(rows, cols_in, cols_out, size));
250 (
251 VmpPMatViewMut::from_inner(VmpPMat::from_data(data, module.n(), rows, cols_in, cols_out, size)),
252 arena,
253 )
254 }
255
256 fn take_mat_znx_scratch(
257 self,
258 n: usize,
259 rows: usize,
260 cols_in: usize,
261 cols_out: usize,
262 size: usize,
263 ) -> (MatZnxViewMut<'a, B>, Self)
264 where
265 B: 'a,
266 {
267 let (data, arena) = self.take_region(MatZnx::bytes_of(n, rows, cols_in, cols_out, size));
268 (
269 MatZnxViewMut::from_inner(MatZnx::from_data(data, n, rows, cols_in, cols_out, size)),
270 arena,
271 )
272 }
273}