// poulpy_hal/api/scratch.rs

use crate::{
    api::{ModuleN, SvpPPolBytesOf, VecZnxBigBytesOf, VecZnxDftBytesOf, VmpPMatBytesOf},
    layouts::{Backend, MatZnx, ScalarZnx, Scratch, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat},
};

6/// Allocates a new [crate::layouts::ScratchOwned] of `size` aligned bytes.
7pub trait ScratchOwnedAlloc<B: Backend> {
8    fn alloc(size: usize) -> Self;
9}
10
11/// Borrows a slice of bytes into a [Scratch].
12pub trait ScratchOwnedBorrow<B: Backend> {
13    fn borrow(&mut self) -> &mut Scratch<B>;
14}
15
16/// Wrap an array of mutable borrowed bytes into a [Scratch].
17pub trait ScratchFromBytes<B: Backend> {
18    fn from_bytes(data: &mut [u8]) -> &mut Scratch<B>;
19}
20
/// Reports how many bytes can still be taken from a scratch.
pub trait ScratchAvailable {
    /// Returns the number of bytes that are still available to be taken.
    fn available(&self) -> usize;
}

/// Takes a slice out of a scratch buffer.
pub trait TakeSlice {
    /// Takes a `&mut [T]` of length `len` from `self` and returns it together with the
    /// remaining scratch (`&mut Self`).
    ///
    /// Both returned references borrow `self`, so the original handle cannot be used while
    /// either of them is alive.
    ///
    /// NOTE(review): the typed helpers in this module pass a byte count and use `T = u8`;
    /// whether `len` counts elements or bytes for other `T` is up to the implementor — confirm.
    fn take_slice<T>(&mut self, len: usize) -> (&mut [T], &mut Self);
}

31impl<B: Backend> ScratchTakeBasic for Scratch<B> where Self: TakeSlice {}
32
33pub trait ScratchTakeBasic
34where
35    Self: TakeSlice,
36{
37    fn take_scalar_znx(&mut self, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self) {
38        let (take_slice, rem_slice) = self.take_slice(ScalarZnx::bytes_of(n, cols));
39        (ScalarZnx::from_data(take_slice, n, cols), rem_slice)
40    }
41
42    fn take_svp_ppol<M, B: Backend>(&mut self, module: &M, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Self)
43    where
44        M: SvpPPolBytesOf + ModuleN,
45    {
46        let (take_slice, rem_slice) = self.take_slice(module.bytes_of_svp_ppol(cols));
47        (SvpPPol::from_data(take_slice, module.n(), cols), rem_slice)
48    }
49
50    fn take_vec_znx(&mut self, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self) {
51        let (take_slice, rem_slice) = self.take_slice(VecZnx::bytes_of(n, cols, size));
52        (VecZnx::from_data(take_slice, n, cols, size), rem_slice)
53    }
54
55    fn take_vec_znx_big<M, B: Backend>(&mut self, module: &M, cols: usize, size: usize) -> (VecZnxBig<&mut [u8], B>, &mut Self)
56    where
57        M: VecZnxBigBytesOf + ModuleN,
58    {
59        let (take_slice, rem_slice) = self.take_slice(module.bytes_of_vec_znx_big(cols, size));
60        (
61            VecZnxBig::from_data(take_slice, module.n(), cols, size),
62            rem_slice,
63        )
64    }
65
66    fn take_vec_znx_dft<M, B: Backend>(&mut self, module: &M, cols: usize, size: usize) -> (VecZnxDft<&mut [u8], B>, &mut Self)
67    where
68        M: VecZnxDftBytesOf + ModuleN,
69    {
70        let (take_slice, rem_slice) = self.take_slice(module.bytes_of_vec_znx_dft(cols, size));
71
72        (
73            VecZnxDft::from_data(take_slice, module.n(), cols, size),
74            rem_slice,
75        )
76    }
77
78    fn take_vec_znx_dft_slice<M, B: Backend>(
79        &mut self,
80        module: &M,
81        len: usize,
82        cols: usize,
83        size: usize,
84    ) -> (Vec<VecZnxDft<&mut [u8], B>>, &mut Self)
85    where
86        M: VecZnxDftBytesOf + ModuleN,
87    {
88        let mut scratch: &mut Self = self;
89        let mut slice: Vec<VecZnxDft<&mut [u8], B>> = Vec::with_capacity(len);
90        for _ in 0..len {
91            let (znx, new_scratch) = scratch.take_vec_znx_dft(module, cols, size);
92            scratch = new_scratch;
93            slice.push(znx);
94        }
95        (slice, scratch)
96    }
97
98    fn take_vec_znx_slice(&mut self, len: usize, n: usize, cols: usize, size: usize) -> (Vec<VecZnx<&mut [u8]>>, &mut Self) {
99        let mut scratch: &mut Self = self;
100        let mut slice: Vec<VecZnx<&mut [u8]>> = Vec::with_capacity(len);
101        for _ in 0..len {
102            let (znx, new_scratch) = scratch.take_vec_znx(n, cols, size);
103            scratch = new_scratch;
104            slice.push(znx);
105        }
106        (slice, scratch)
107    }
108
109    fn take_vmp_pmat<M, B: Backend>(
110        &mut self,
111        module: &M,
112        rows: usize,
113        cols_in: usize,
114        cols_out: usize,
115        size: usize,
116    ) -> (VmpPMat<&mut [u8], B>, &mut Self)
117    where
118        M: VmpPMatBytesOf + ModuleN,
119    {
120        let (take_slice, rem_slice) = self.take_slice(module.bytes_of_vmp_pmat(rows, cols_in, cols_out, size));
121        (
122            VmpPMat::from_data(take_slice, module.n(), rows, cols_in, cols_out, size),
123            rem_slice,
124        )
125    }
126
127    fn take_mat_znx(
128        &mut self,
129        n: usize,
130        rows: usize,
131        cols_in: usize,
132        cols_out: usize,
133        size: usize,
134    ) -> (MatZnx<&mut [u8]>, &mut Self) {
135        let (take_slice, rem_slice) = self.take_slice(MatZnx::bytes_of(n, rows, cols_in, cols_out, size));
136        (
137            MatZnx::from_data(take_slice, n, rows, cols_in, cols_out, size),
138            rem_slice,
139        )
140    }
141}