1use crate::{
2 api::{ModuleN, SvpPPolBytesOf, VecZnxBigBytesOf, VecZnxDftBytesOf, VmpPMatBytesOf},
3 layouts::{Backend, MatZnx, ScalarZnx, Scratch, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat},
4};
5
/// Construction of an owned scratch value for backend `B`.
pub trait ScratchOwnedAlloc<B: Backend> {
    /// Creates a new owned scratch value sized by `size`.
    ///
    /// NOTE(review): `size` is presumably a byte count (it is compared in bytes by
    /// `Scratch::split_mut` once borrowed) — confirm against implementors.
    fn alloc(size: usize) -> Self;
}
10
/// Mutable access to an owned scratch value as a borrowed [`Scratch`].
pub trait ScratchOwnedBorrow<B: Backend> {
    /// Returns `self` as a mutable `Scratch<B>`; the result borrows from `self`.
    fn borrow(&mut self) -> &mut Scratch<B>;
}
15
/// Reinterpretation of a raw byte buffer as a [`Scratch`].
pub trait ScratchFromBytes<B: Backend> {
    /// Wraps `data` as a mutable `Scratch<B>`; the returned reference borrows `data`.
    ///
    /// `Scratch::split_at_mut` uses this to turn a byte slice carved off by
    /// `TakeSlice::take_slice` back into a standalone scratch.
    fn from_bytes(data: &mut [u8]) -> &mut Scratch<B>;
}
20
/// Query for the remaining capacity of a scratch buffer.
pub trait ScratchAvailable {
    /// Returns how much scratch space is still available.
    ///
    /// `Scratch::split_mut` compares this value against a total byte count, so
    /// implementors are expected to report bytes.
    fn available(&self) -> usize;
}
25
/// Carving a typed slice out of a scratch buffer.
pub trait TakeSlice {
    /// Splits a `&mut [T]` sized by `len` out of `self`, returning it together with the
    /// remaining `Self`. Both outputs borrow from `self`.
    ///
    /// NOTE(review): every caller in this module uses `T = u8`, so `len` is a byte count
    /// there. Where the slice is taken from, alignment handling, and behaviour on
    /// insufficient space are up to the implementor — confirm.
    fn take_slice<T>(&mut self, len: usize) -> (&mut [T], &mut Self);
}
30
31impl<BE: Backend> Scratch<BE>
32where
33 Self: TakeSlice + ScratchAvailable + ScratchFromBytes<BE>,
34{
35 pub fn split_at_mut(&mut self, len: usize) -> (&mut Scratch<BE>, &mut Self) {
36 let (take_slice, rem_slice) = self.take_slice(len);
37 (Self::from_bytes(take_slice), rem_slice)
38 }
39
40 pub fn split_mut(&mut self, n: usize, len: usize) -> (Vec<&mut Scratch<BE>>, &mut Self) {
41 assert!(self.available() >= n * len);
42 let mut scratches: Vec<&mut Scratch<BE>> = Vec::with_capacity(n);
43 let mut scratch: &mut Scratch<BE> = self;
44 for _ in 0..n {
45 let (tmp, scratch_new) = scratch.split_at_mut(len);
46 scratch = scratch_new;
47 scratches.push(tmp);
48 }
49 (scratches, scratch)
50 }
51}
52
53impl<B: Backend> ScratchTakeBasic for Scratch<B> where Self: TakeSlice + ScratchFromBytes<B> {}
54
55pub trait ScratchTakeBasic
56where
57 Self: TakeSlice,
58{
59 fn take_scalar_znx(&mut self, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self) {
60 let (take_slice, rem_slice) = self.take_slice(ScalarZnx::bytes_of(n, cols));
61 (ScalarZnx::from_data(take_slice, n, cols), rem_slice)
62 }
63
64 fn take_svp_ppol<M, B: Backend>(&mut self, module: &M, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Self)
65 where
66 M: SvpPPolBytesOf + ModuleN,
67 {
68 let (take_slice, rem_slice) = self.take_slice(module.bytes_of_svp_ppol(cols));
69 (SvpPPol::from_data(take_slice, module.n(), cols), rem_slice)
70 }
71
72 fn take_vec_znx(&mut self, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self) {
73 let (take_slice, rem_slice) = self.take_slice(VecZnx::bytes_of(n, cols, size));
74 (VecZnx::from_data(take_slice, n, cols, size), rem_slice)
75 }
76
77 fn take_vec_znx_big<M, B: Backend>(&mut self, module: &M, cols: usize, size: usize) -> (VecZnxBig<&mut [u8], B>, &mut Self)
78 where
79 M: VecZnxBigBytesOf + ModuleN,
80 {
81 let (take_slice, rem_slice) = self.take_slice(module.bytes_of_vec_znx_big(cols, size));
82 (
83 VecZnxBig::from_data(take_slice, module.n(), cols, size),
84 rem_slice,
85 )
86 }
87
88 fn take_vec_znx_dft<M, B: Backend>(&mut self, module: &M, cols: usize, size: usize) -> (VecZnxDft<&mut [u8], B>, &mut Self)
89 where
90 M: VecZnxDftBytesOf + ModuleN,
91 {
92 let (take_slice, rem_slice) = self.take_slice(module.bytes_of_vec_znx_dft(cols, size));
93
94 (
95 VecZnxDft::from_data(take_slice, module.n(), cols, size),
96 rem_slice,
97 )
98 }
99
100 fn take_vec_znx_dft_slice<M, B: Backend>(
101 &mut self,
102 module: &M,
103 len: usize,
104 cols: usize,
105 size: usize,
106 ) -> (Vec<VecZnxDft<&mut [u8], B>>, &mut Self)
107 where
108 M: VecZnxDftBytesOf + ModuleN,
109 {
110 let mut scratch: &mut Self = self;
111 let mut slice: Vec<VecZnxDft<&mut [u8], B>> = Vec::with_capacity(len);
112 for _ in 0..len {
113 let (znx, new_scratch) = scratch.take_vec_znx_dft(module, cols, size);
114 scratch = new_scratch;
115 slice.push(znx);
116 }
117 (slice, scratch)
118 }
119
120 fn take_vec_znx_slice(&mut self, len: usize, n: usize, cols: usize, size: usize) -> (Vec<VecZnx<&mut [u8]>>, &mut Self) {
121 let mut scratch: &mut Self = self;
122 let mut slice: Vec<VecZnx<&mut [u8]>> = Vec::with_capacity(len);
123 for _ in 0..len {
124 let (znx, new_scratch) = scratch.take_vec_znx(n, cols, size);
125 scratch = new_scratch;
126 slice.push(znx);
127 }
128 (slice, scratch)
129 }
130
131 fn take_vmp_pmat<M, B: Backend>(
132 &mut self,
133 module: &M,
134 rows: usize,
135 cols_in: usize,
136 cols_out: usize,
137 size: usize,
138 ) -> (VmpPMat<&mut [u8], B>, &mut Self)
139 where
140 M: VmpPMatBytesOf + ModuleN,
141 {
142 let (take_slice, rem_slice) = self.take_slice(module.bytes_of_vmp_pmat(rows, cols_in, cols_out, size));
143 (
144 VmpPMat::from_data(take_slice, module.n(), rows, cols_in, cols_out, size),
145 rem_slice,
146 )
147 }
148
149 fn take_mat_znx(
150 &mut self,
151 n: usize,
152 rows: usize,
153 cols_in: usize,
154 cols_out: usize,
155 size: usize,
156 ) -> (MatZnx<&mut [u8]>, &mut Self) {
157 let (take_slice, rem_slice) = self.take_slice(MatZnx::bytes_of(n, rows, cols_in, cols_out, size));
158 (
159 MatZnx::from_data(take_slice, n, rows, cols_in, cols_out, size),
160 rem_slice,
161 )
162 }
163}