Skip to main content

poulpy_hal/api/
scratch.rs

1//! Scratch memory allocation, borrowing, and arena-style sub-allocation.
2//!
3//! Provides traits for creating scratch buffers, borrowing them as
4//! backend-native [`ScratchArena`] values, and carving typed layout
5//! objects (e.g., [`VecZnx`], [`VecZnxDft`], [`VmpPMat`]) out of them.
6
7use crate::{
8    api::{CnvPVecBytesOf, ModuleN, SvpPPolBytesOf, VecZnxBigBytesOf, VecZnxDftBytesOf, VmpPMatBytesOf},
9    layouts::{
10        Backend, CnvPVecL, CnvPVecLViewMut, CnvPVecR, CnvPVecRViewMut, MatZnx, MatZnxViewMut, ScalarZnx, ScalarZnxViewMut,
11        ScratchArena, SvpPPol, SvpPPolViewMut, VecZnx, VecZnxBig, VecZnxBigViewMut, VecZnxDft, VecZnxDftViewMut, VecZnxViewMut,
12        VmpPMat, VmpPMatViewMut,
13    },
14};
15
16/// Allocates a new [crate::layouts::ScratchOwned] of `size` aligned bytes.
17pub trait ScratchOwnedAlloc<B: Backend> {
18    fn alloc(size: usize) -> Self;
19}
20
21/// Borrows an owned scratch buffer as a backend-native arena.
22pub trait ScratchOwnedBorrow<B: Backend> {
23    fn borrow(&mut self) -> ScratchArena<'_, B>;
24}
25
/// Reports the remaining capacity of a scratch buffer.
///
/// `available()` returns the number of bytes that can still be taken.
pub trait ScratchAvailable {
    fn available(&self) -> usize;
}
30
/// A borrowed scratch region that can be viewed as host memory.
///
/// Implementing this promises that the region converts into a plain
/// `&'a mut [u8]`. Device backends should only implement it when their
/// borrowed mutable scratch region really is directly addressable from the
/// host as a byte slice.
pub trait HostBufMut<'a>: Sized {
    /// Consumes the handle and yields the underlying mutable byte slice.
    fn into_bytes(self) -> &'a mut [u8];
}
38
39impl<'a> HostBufMut<'a> for &'a mut [u8] {
40    #[inline]
41    fn into_bytes(self) -> &'a mut [u8] {
42        self
43    }
44}
45
46/// Backend-native arena allocation of typed HAL layouts.
47///
48/// This is the additive, backend-owned scratch path introduced for
49/// incremental device-backend integration. It consumes a [`ScratchArena`]
50/// by value and returns the carved layout together with the remaining arena.
51pub trait ScratchArenaTakeBasic<'a, B: Backend>: Sized {
52    /// Takes a [`CnvPVecL`] from the scratch arena.
53    fn take_cnv_pvec_left_scratch<M>(self, module: &M, cols: usize, size: usize) -> (CnvPVecLViewMut<'a, B>, Self)
54    where
55        B: 'a,
56        M: ModuleN + CnvPVecBytesOf;
57
58    /// Takes a [`CnvPVecR`] from the scratch arena.
59    fn take_cnv_pvec_right_scratch<M>(self, module: &M, cols: usize, size: usize) -> (CnvPVecRViewMut<'a, B>, Self)
60    where
61        B: 'a,
62        M: ModuleN + CnvPVecBytesOf;
63
64    /// Takes a [`ScalarZnx`] from the scratch arena.
65    fn take_scalar_znx_scratch(self, n: usize, cols: usize) -> (ScalarZnxViewMut<'a, B>, Self)
66    where
67        B: 'a;
68
69    /// Takes a [`SvpPPol`] from the scratch arena.
70    fn take_svp_ppol_scratch<M>(self, module: &M, cols: usize) -> (SvpPPolViewMut<'a, B>, Self)
71    where
72        B: 'a,
73        M: SvpPPolBytesOf + ModuleN;
74
75    /// Takes a [`VecZnx`] from the scratch arena.
76    fn take_vec_znx_scratch(self, n: usize, cols: usize, size: usize) -> (VecZnxViewMut<'a, B>, Self)
77    where
78        B: 'a;
79
80    /// Takes a [`VecZnxBig`] from the scratch arena.
81    fn take_vec_znx_big_scratch<M>(self, module: &M, cols: usize, size: usize) -> (VecZnxBigViewMut<'a, B>, Self)
82    where
83        B: 'a,
84        M: VecZnxBigBytesOf + ModuleN;
85
86    fn take_vec_znx_big_scratch_n(self, n: usize, cols: usize, size: usize) -> (VecZnxBigViewMut<'a, B>, Self)
87    where
88        B: 'a;
89
90    /// Takes a [`VecZnxDft`] from the scratch arena.
91    fn take_vec_znx_dft_scratch<M>(self, module: &M, cols: usize, size: usize) -> (VecZnxDftViewMut<'a, B>, Self)
92    where
93        B: 'a,
94        M: VecZnxDftBytesOf + ModuleN;
95
96    /// Takes `len` consecutive [`VecZnxDft`] objects from the scratch arena.
97    fn take_vec_znx_dft_slice_scratch<M>(
98        self,
99        module: &M,
100        len: usize,
101        cols: usize,
102        size: usize,
103    ) -> (Vec<VecZnxDftViewMut<'a, B>>, Self)
104    where
105        B: 'a,
106        M: VecZnxDftBytesOf + ModuleN,
107    {
108        let mut scratch: Self = self;
109        let mut slice: Vec<VecZnxDftViewMut<'a, B>> = Vec::with_capacity(len);
110        for _ in 0..len {
111            let (znx, rem) = scratch.take_vec_znx_dft_scratch(module, cols, size);
112            scratch = rem;
113            slice.push(znx);
114        }
115        (slice, scratch)
116    }
117
118    /// Takes `len` consecutive [`VecZnx`] objects from the scratch arena.
119    fn take_vec_znx_slice_scratch(self, len: usize, n: usize, cols: usize, size: usize) -> (Vec<VecZnxViewMut<'a, B>>, Self)
120    where
121        B: 'a,
122    {
123        let mut scratch: Self = self;
124        let mut slice: Vec<VecZnxViewMut<'a, B>> = Vec::with_capacity(len);
125        for _ in 0..len {
126            let (znx, rem) = scratch.take_vec_znx_scratch(n, cols, size);
127            scratch = rem;
128            slice.push(znx);
129        }
130        (slice, scratch)
131    }
132
133    /// Takes a [`VmpPMat`] from the scratch arena.
134    fn take_vmp_pmat_scratch<M>(
135        self,
136        module: &M,
137        rows: usize,
138        cols_in: usize,
139        cols_out: usize,
140        size: usize,
141    ) -> (VmpPMatViewMut<'a, B>, Self)
142    where
143        B: 'a,
144        M: VmpPMatBytesOf + ModuleN;
145
146    /// Takes a [`MatZnx`] from the scratch arena.
147    fn take_mat_znx_scratch(
148        self,
149        n: usize,
150        rows: usize,
151        cols_in: usize,
152        cols_out: usize,
153        size: usize,
154    ) -> (MatZnxViewMut<'a, B>, Self)
155    where
156        B: 'a;
157}
158
159impl<'a, B: Backend> ScratchArenaTakeBasic<'a, B> for ScratchArena<'a, B> {
160    fn take_cnv_pvec_left_scratch<M>(self, module: &M, cols: usize, size: usize) -> (CnvPVecLViewMut<'a, B>, Self)
161    where
162        B: 'a,
163        M: ModuleN + CnvPVecBytesOf,
164    {
165        let (data, arena) = self.take_region(module.bytes_of_cnv_pvec_left(cols, size));
166        (
167            CnvPVecLViewMut::from_inner(CnvPVecL::from_data(data, module.n(), cols, size)),
168            arena,
169        )
170    }
171
172    fn take_cnv_pvec_right_scratch<M>(self, module: &M, cols: usize, size: usize) -> (CnvPVecRViewMut<'a, B>, Self)
173    where
174        B: 'a,
175        M: ModuleN + CnvPVecBytesOf,
176    {
177        let (data, arena) = self.take_region(module.bytes_of_cnv_pvec_right(cols, size));
178        (
179            CnvPVecRViewMut::from_inner(CnvPVecR::from_data(data, module.n(), cols, size)),
180            arena,
181        )
182    }
183
184    fn take_scalar_znx_scratch(self, n: usize, cols: usize) -> (ScalarZnxViewMut<'a, B>, Self)
185    where
186        B: 'a,
187    {
188        let (data, arena) = self.take_region(ScalarZnx::bytes_of(n, cols));
189        (ScalarZnxViewMut::from_inner(ScalarZnx::from_data(data, n, cols)), arena)
190    }
191
192    fn take_svp_ppol_scratch<M>(self, module: &M, cols: usize) -> (SvpPPolViewMut<'a, B>, Self)
193    where
194        B: 'a,
195        M: SvpPPolBytesOf + ModuleN,
196    {
197        let (data, arena) = self.take_region(module.bytes_of_svp_ppol(cols));
198        (SvpPPolViewMut::from_inner(SvpPPol::from_data(data, module.n(), cols)), arena)
199    }
200
201    fn take_vec_znx_scratch(self, n: usize, cols: usize, size: usize) -> (VecZnxViewMut<'a, B>, Self)
202    where
203        B: 'a,
204    {
205        let (data, arena) = self.take_region(VecZnx::bytes_of(n, cols, size));
206        (VecZnxViewMut::from_inner(VecZnx::from_data(data, n, cols, size)), arena)
207    }
208
209    fn take_vec_znx_big_scratch<M>(self, module: &M, cols: usize, size: usize) -> (VecZnxBigViewMut<'a, B>, Self)
210    where
211        B: 'a,
212        M: VecZnxBigBytesOf + ModuleN,
213    {
214        self.take_vec_znx_big_scratch_n(module.n(), cols, size)
215    }
216
217    fn take_vec_znx_big_scratch_n(self, n: usize, cols: usize, size: usize) -> (VecZnxBigViewMut<'a, B>, Self)
218    where
219        B: 'a,
220    {
221        let (data, arena) = self.take_region(B::bytes_of_vec_znx_big(n, cols, size));
222        (VecZnxBigViewMut::from_inner(VecZnxBig::from_data(data, n, cols, size)), arena)
223    }
224
225    fn take_vec_znx_dft_scratch<M>(self, module: &M, cols: usize, size: usize) -> (VecZnxDftViewMut<'a, B>, Self)
226    where
227        B: 'a,
228        M: VecZnxDftBytesOf + ModuleN,
229    {
230        let (data, arena) = self.take_region(module.bytes_of_vec_znx_dft(cols, size));
231        (
232            VecZnxDftViewMut::from_inner(VecZnxDft::from_data(data, module.n(), cols, size)),
233            arena,
234        )
235    }
236
237    fn take_vmp_pmat_scratch<M>(
238        self,
239        module: &M,
240        rows: usize,
241        cols_in: usize,
242        cols_out: usize,
243        size: usize,
244    ) -> (VmpPMatViewMut<'a, B>, Self)
245    where
246        B: 'a,
247        M: VmpPMatBytesOf + ModuleN,
248    {
249        let (data, arena) = self.take_region(module.bytes_of_vmp_pmat(rows, cols_in, cols_out, size));
250        (
251            VmpPMatViewMut::from_inner(VmpPMat::from_data(data, module.n(), rows, cols_in, cols_out, size)),
252            arena,
253        )
254    }
255
256    fn take_mat_znx_scratch(
257        self,
258        n: usize,
259        rows: usize,
260        cols_in: usize,
261        cols_out: usize,
262        size: usize,
263    ) -> (MatZnxViewMut<'a, B>, Self)
264    where
265        B: 'a,
266    {
267        let (data, arena) = self.take_region(MatZnx::bytes_of(n, rows, cols_in, cols_out, size));
268        (
269            MatZnxViewMut::from_inner(MatZnx::from_data(data, n, rows, cols_in, cols_out, size)),
270            arena,
271        )
272    }
273}