1use crate::{
2 api::{ModuleN, SvpPPolBytesOf, VecZnxBigBytesOf, VecZnxDftBytesOf, VmpPMatBytesOf},
3 layouts::{Backend, MatZnx, ScalarZnx, Scratch, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat},
4};
5
6pub trait ScratchOwnedAlloc<B: Backend> {
8 fn alloc(size: usize) -> Self;
9}
10
11pub trait ScratchOwnedBorrow<B: Backend> {
13 fn borrow(&mut self) -> &mut Scratch<B>;
14}
15
16pub trait ScratchFromBytes<B: Backend> {
18 fn from_bytes(data: &mut [u8]) -> &mut Scratch<B>;
19}
20
/// Reports how much room is left in a scratch region.
pub trait ScratchAvailable {
    /// Returns the remaining capacity of `self`.
    ///
    /// NOTE(review): presumably measured in bytes — confirm against the implementations.
    fn available(&self) -> usize;
}
25
/// Splits a typed slice off a scratch region.
pub trait TakeSlice {
    /// Takes a slice of `T` sized by `len` out of `self`.
    ///
    /// Returns the taken slice together with the remaining scratch, so that further
    /// allocations can be chained on the second element. Both references are tied to the
    /// borrow of `self`.
    ///
    /// NOTE(review): `len` is presumably an element count of `T`. That equals a byte count
    /// when `T = u8`, which is how the byte-sized callers in this file use it. Confirm the
    /// meaning for other `T`, and the behaviour when the region is too small, in the
    /// implementations.
    fn take_slice<T>(&mut self, len: usize) -> (&mut [T], &mut Self);
}
30
31impl<B: Backend> ScratchTakeBasic for Scratch<B> where Self: TakeSlice {}
32
33pub trait ScratchTakeBasic
34where
35 Self: TakeSlice,
36{
37 fn take_scalar_znx(&mut self, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self) {
38 let (take_slice, rem_slice) = self.take_slice(ScalarZnx::bytes_of(n, cols));
39 (ScalarZnx::from_data(take_slice, n, cols), rem_slice)
40 }
41
42 fn take_svp_ppol<M, B: Backend>(&mut self, module: &M, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Self)
43 where
44 M: SvpPPolBytesOf + ModuleN,
45 {
46 let (take_slice, rem_slice) = self.take_slice(module.bytes_of_svp_ppol(cols));
47 (SvpPPol::from_data(take_slice, module.n(), cols), rem_slice)
48 }
49
50 fn take_vec_znx(&mut self, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self) {
51 let (take_slice, rem_slice) = self.take_slice(VecZnx::bytes_of(n, cols, size));
52 (VecZnx::from_data(take_slice, n, cols, size), rem_slice)
53 }
54
55 fn take_vec_znx_big<M, B: Backend>(&mut self, module: &M, cols: usize, size: usize) -> (VecZnxBig<&mut [u8], B>, &mut Self)
56 where
57 M: VecZnxBigBytesOf + ModuleN,
58 {
59 let (take_slice, rem_slice) = self.take_slice(module.bytes_of_vec_znx_big(cols, size));
60 (
61 VecZnxBig::from_data(take_slice, module.n(), cols, size),
62 rem_slice,
63 )
64 }
65
66 fn take_vec_znx_dft<M, B: Backend>(&mut self, module: &M, cols: usize, size: usize) -> (VecZnxDft<&mut [u8], B>, &mut Self)
67 where
68 M: VecZnxDftBytesOf + ModuleN,
69 {
70 let (take_slice, rem_slice) = self.take_slice(module.bytes_of_vec_znx_dft(cols, size));
71
72 (
73 VecZnxDft::from_data(take_slice, module.n(), cols, size),
74 rem_slice,
75 )
76 }
77
78 fn take_vec_znx_dft_slice<M, B: Backend>(
79 &mut self,
80 module: &M,
81 len: usize,
82 cols: usize,
83 size: usize,
84 ) -> (Vec<VecZnxDft<&mut [u8], B>>, &mut Self)
85 where
86 M: VecZnxDftBytesOf + ModuleN,
87 {
88 let mut scratch: &mut Self = self;
89 let mut slice: Vec<VecZnxDft<&mut [u8], B>> = Vec::with_capacity(len);
90 for _ in 0..len {
91 let (znx, new_scratch) = scratch.take_vec_znx_dft(module, cols, size);
92 scratch = new_scratch;
93 slice.push(znx);
94 }
95 (slice, scratch)
96 }
97
98 fn take_vec_znx_slice(&mut self, len: usize, n: usize, cols: usize, size: usize) -> (Vec<VecZnx<&mut [u8]>>, &mut Self) {
99 let mut scratch: &mut Self = self;
100 let mut slice: Vec<VecZnx<&mut [u8]>> = Vec::with_capacity(len);
101 for _ in 0..len {
102 let (znx, new_scratch) = scratch.take_vec_znx(n, cols, size);
103 scratch = new_scratch;
104 slice.push(znx);
105 }
106 (slice, scratch)
107 }
108
109 fn take_vmp_pmat<M, B: Backend>(
110 &mut self,
111 module: &M,
112 rows: usize,
113 cols_in: usize,
114 cols_out: usize,
115 size: usize,
116 ) -> (VmpPMat<&mut [u8], B>, &mut Self)
117 where
118 M: VmpPMatBytesOf + ModuleN,
119 {
120 let (take_slice, rem_slice) = self.take_slice(module.bytes_of_vmp_pmat(rows, cols_in, cols_out, size));
121 (
122 VmpPMat::from_data(take_slice, module.n(), rows, cols_in, cols_out, size),
123 rem_slice,
124 )
125 }
126
127 fn take_mat_znx(
128 &mut self,
129 n: usize,
130 rows: usize,
131 cols_in: usize,
132 cols_out: usize,
133 size: usize,
134 ) -> (MatZnx<&mut [u8]>, &mut Self) {
135 let (take_slice, rem_slice) = self.take_slice(MatZnx::bytes_of(n, rows, cols_in, cols_out, size));
136 (
137 MatZnx::from_data(take_slice, n, rows, cols_in, cols_out, size),
138 rem_slice,
139 )
140 }
141}