poulpy_hal/api/
scratch.rs

1use crate::{
2    api::{ModuleN, SvpPPolBytesOf, VecZnxBigBytesOf, VecZnxDftBytesOf, VmpPMatBytesOf},
3    layouts::{Backend, MatZnx, ScalarZnx, Scratch, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat},
4};
5
6/// Allocates a new [crate::layouts::ScratchOwned] of `size` aligned bytes.
7pub trait ScratchOwnedAlloc<B: Backend> {
8    fn alloc(size: usize) -> Self;
9}
10
11/// Borrows a slice of bytes into a [Scratch].
12pub trait ScratchOwnedBorrow<B: Backend> {
13    fn borrow(&mut self) -> &mut Scratch<B>;
14}
15
16/// Wrap an array of mutable borrowed bytes into a [Scratch].
17pub trait ScratchFromBytes<B: Backend> {
18    fn from_bytes(data: &mut [u8]) -> &mut Scratch<B>;
19}
20
/// Reports how many bytes can still be taken from the scratch.
pub trait ScratchAvailable {
    /// Returns the number of bytes that are left to take.
    fn available(&self) -> usize;
}
25
/// Takes a typed slice out of a scratch space (e.g. a [Scratch]) and returns it together
/// with the remaining scratch, i.e. the scratch minus the taken memory.
pub trait TakeSlice {
    /// Returns `(taken, remainder)`; both references borrow from `self`.
    ///
    /// The element type `T` is chosen by the caller (callers in this file use `u8`
    /// and pass a byte count as `len`).
    ///
    /// NOTE(review): `len` is presumably the number of `T` elements in the returned
    /// slice; confirm against the implementation.
    fn take_slice<T>(&mut self, len: usize) -> (&mut [T], &mut Self);
}
30
31impl<BE: Backend> Scratch<BE>
32where
33    Self: TakeSlice + ScratchAvailable + ScratchFromBytes<BE>,
34{
35    pub fn split_at_mut(&mut self, len: usize) -> (&mut Scratch<BE>, &mut Self) {
36        let (take_slice, rem_slice) = self.take_slice(len);
37        (Self::from_bytes(take_slice), rem_slice)
38    }
39
40    pub fn split_mut(&mut self, n: usize, len: usize) -> (Vec<&mut Scratch<BE>>, &mut Self) {
41        assert!(self.available() >= n * len);
42        let mut scratches: Vec<&mut Scratch<BE>> = Vec::with_capacity(n);
43        let mut scratch: &mut Scratch<BE> = self;
44        for _ in 0..n {
45            let (tmp, scratch_new) = scratch.split_at_mut(len);
46            scratch = scratch_new;
47            scratches.push(tmp);
48        }
49        (scratches, scratch)
50    }
51}
52
53impl<B: Backend> ScratchTakeBasic for Scratch<B> where Self: TakeSlice + ScratchFromBytes<B> {}
54
55pub trait ScratchTakeBasic
56where
57    Self: TakeSlice,
58{
59    fn take_scalar_znx(&mut self, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self) {
60        let (take_slice, rem_slice) = self.take_slice(ScalarZnx::bytes_of(n, cols));
61        (ScalarZnx::from_data(take_slice, n, cols), rem_slice)
62    }
63
64    fn take_svp_ppol<M, B: Backend>(&mut self, module: &M, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Self)
65    where
66        M: SvpPPolBytesOf + ModuleN,
67    {
68        let (take_slice, rem_slice) = self.take_slice(module.bytes_of_svp_ppol(cols));
69        (SvpPPol::from_data(take_slice, module.n(), cols), rem_slice)
70    }
71
72    fn take_vec_znx(&mut self, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self) {
73        let (take_slice, rem_slice) = self.take_slice(VecZnx::bytes_of(n, cols, size));
74        (VecZnx::from_data(take_slice, n, cols, size), rem_slice)
75    }
76
77    fn take_vec_znx_big<M, B: Backend>(&mut self, module: &M, cols: usize, size: usize) -> (VecZnxBig<&mut [u8], B>, &mut Self)
78    where
79        M: VecZnxBigBytesOf + ModuleN,
80    {
81        let (take_slice, rem_slice) = self.take_slice(module.bytes_of_vec_znx_big(cols, size));
82        (
83            VecZnxBig::from_data(take_slice, module.n(), cols, size),
84            rem_slice,
85        )
86    }
87
88    fn take_vec_znx_dft<M, B: Backend>(&mut self, module: &M, cols: usize, size: usize) -> (VecZnxDft<&mut [u8], B>, &mut Self)
89    where
90        M: VecZnxDftBytesOf + ModuleN,
91    {
92        let (take_slice, rem_slice) = self.take_slice(module.bytes_of_vec_znx_dft(cols, size));
93
94        (
95            VecZnxDft::from_data(take_slice, module.n(), cols, size),
96            rem_slice,
97        )
98    }
99
100    fn take_vec_znx_dft_slice<M, B: Backend>(
101        &mut self,
102        module: &M,
103        len: usize,
104        cols: usize,
105        size: usize,
106    ) -> (Vec<VecZnxDft<&mut [u8], B>>, &mut Self)
107    where
108        M: VecZnxDftBytesOf + ModuleN,
109    {
110        let mut scratch: &mut Self = self;
111        let mut slice: Vec<VecZnxDft<&mut [u8], B>> = Vec::with_capacity(len);
112        for _ in 0..len {
113            let (znx, new_scratch) = scratch.take_vec_znx_dft(module, cols, size);
114            scratch = new_scratch;
115            slice.push(znx);
116        }
117        (slice, scratch)
118    }
119
120    fn take_vec_znx_slice(&mut self, len: usize, n: usize, cols: usize, size: usize) -> (Vec<VecZnx<&mut [u8]>>, &mut Self) {
121        let mut scratch: &mut Self = self;
122        let mut slice: Vec<VecZnx<&mut [u8]>> = Vec::with_capacity(len);
123        for _ in 0..len {
124            let (znx, new_scratch) = scratch.take_vec_znx(n, cols, size);
125            scratch = new_scratch;
126            slice.push(znx);
127        }
128        (slice, scratch)
129    }
130
131    fn take_vmp_pmat<M, B: Backend>(
132        &mut self,
133        module: &M,
134        rows: usize,
135        cols_in: usize,
136        cols_out: usize,
137        size: usize,
138    ) -> (VmpPMat<&mut [u8], B>, &mut Self)
139    where
140        M: VmpPMatBytesOf + ModuleN,
141    {
142        let (take_slice, rem_slice) = self.take_slice(module.bytes_of_vmp_pmat(rows, cols_in, cols_out, size));
143        (
144            VmpPMat::from_data(take_slice, module.n(), rows, cols_in, cols_out, size),
145            rem_slice,
146        )
147    }
148
149    fn take_mat_znx(
150        &mut self,
151        n: usize,
152        rows: usize,
153        cols_in: usize,
154        cols_out: usize,
155        size: usize,
156    ) -> (MatZnx<&mut [u8]>, &mut Self) {
157        let (take_slice, rem_slice) = self.take_slice(MatZnx::bytes_of(n, rows, cols_in, cols_out, size));
158        (
159            MatZnx::from_data(take_slice, n, rows, cols_in, cols_out, size),
160            rem_slice,
161        )
162    }
163}