1use std::{
2 hash::{DefaultHasher, Hasher},
3 marker::PhantomData,
4};
5
6use crate::layouts::{Backend, Data, DataView, DataViewMut, DigestU64, HostDataRef, ZnxInfos, ZnxView};
7
/// Dimension record attached to a [`VmpPMat`].
///
/// Bundles five `usize` extents (`n`, `size`, `rows`, `cols_in`, `cols_out`)
/// into a single `Copy` value. The struct is `#[repr(C)]`, so the field order
/// below is part of its layout and must not be rearranged.
#[repr(C)]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct VmpPMatShape {
    n: usize,
    size: usize,
    rows: usize,
    cols_in: usize,
    cols_out: usize,
}

impl VmpPMatShape {
    /// Builds a shape from its five extents.
    ///
    /// The argument order is `n, rows, cols_in, cols_out, size`, which differs
    /// from the field declaration order (there `size` comes second).
    pub const fn new(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self {
        VmpPMatShape {
            n,
            size,
            rows,
            cols_in,
            cols_out,
        }
    }

    /// Returns the stored `n` extent.
    pub const fn n(self) -> usize {
        self.n
    }

    /// Returns the stored `rows` extent.
    pub const fn rows(self) -> usize {
        self.rows
    }

    /// Returns the stored `cols_in` extent.
    pub const fn cols_in(self) -> usize {
        self.cols_in
    }

    /// Returns the stored `cols_out` extent.
    pub const fn cols_out(self) -> usize {
        self.cols_out
    }

    /// Returns the stored `size` extent.
    pub const fn size(self) -> usize {
        self.size
    }
}
49
/// A data buffer `D` tagged with a [`VmpPMatShape`] and a backend marker `B`.
///
/// `B` is carried only through `PhantomData`; no value of type `B` is stored.
/// The struct is `#[repr(C)]`, so the field order is part of its layout.
///
/// NOTE(review): the name suggests a backend-prepared matrix for a
/// vector-matrix product, with the byte layout of `data` chosen by the
/// backend — confirm against the `Backend` documentation.
#[repr(C)]
#[derive(PartialEq, Eq, Hash)]
pub struct VmpPMat<D: Data, B: Backend> {
    // Underlying buffer (owned or borrowed, depending on `D`).
    data: D,
    // Extents describing how `data` is to be interpreted.
    shape: VmpPMatShape,
    // Zero-sized marker binding the value to backend `B`.
    _phantom: PhantomData<B>,
}
70
71impl<D: HostDataRef, B: Backend> DigestU64 for VmpPMat<D, B> {
72 fn digest_u64(&self) -> u64 {
73 let mut h: DefaultHasher = DefaultHasher::new();
74 h.write(self.data.as_ref());
75 h.write_usize(self.n());
76 h.write_usize(self.size());
77 h.write_usize(self.rows());
78 h.write_usize(self.cols_in());
79 h.write_usize(self.cols_out());
80 h.finish()
81 }
82}
83
84impl<D: HostDataRef, B: Backend> ZnxView for VmpPMat<D, B> {
85 type Scalar = B::ScalarPrep;
86}
87
88impl<D: Data, B: Backend> ZnxInfos for VmpPMat<D, B> {
89 fn cols(&self) -> usize {
90 self.shape.cols_in()
91 }
92
93 fn rows(&self) -> usize {
94 self.shape.rows()
95 }
96
97 fn n(&self) -> usize {
98 self.shape.n()
99 }
100
101 fn size(&self) -> usize {
102 self.shape.size()
103 }
104
105 fn poly_count(&self) -> usize {
106 self.rows() * self.cols_in() * self.size() * self.cols_out()
107 }
108}
109
110impl<D: Data, B: Backend> DataView for VmpPMat<D, B> {
111 type D = D;
112 fn data(&self) -> &Self::D {
113 &self.data
114 }
115}
116
117impl<D: Data, B: Backend> DataViewMut for VmpPMat<D, B> {
118 fn data_mut(&mut self) -> &mut Self::D {
119 &mut self.data
120 }
121}
122
123impl<D: Data, B: Backend> VmpPMat<D, B> {
124 pub fn shape(&self) -> VmpPMatShape {
125 self.shape
126 }
127
128 pub fn n(&self) -> usize {
129 self.shape.n()
130 }
131
132 pub fn rows(&self) -> usize {
133 self.shape.rows()
134 }
135
136 pub fn size(&self) -> usize {
137 self.shape.size()
138 }
139
140 pub fn cols_in(&self) -> usize {
142 self.shape.cols_in()
143 }
144
145 pub fn cols_out(&self) -> usize {
147 self.shape.cols_out()
148 }
149}
150
151impl<B: Backend> VmpPMat<<B as Backend>::OwnedBuf, B> {
152 pub fn alloc(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self {
153 let data: <B as Backend>::OwnedBuf = B::alloc_zeroed_bytes(B::bytes_of_vmp_pmat(n, rows, cols_in, cols_out, size));
154 Self {
155 data,
156 shape: VmpPMatShape::new(n, rows, cols_in, cols_out, size),
157 _phantom: PhantomData,
158 }
159 }
160}
161
/// A [`VmpPMat`] whose data is the backend's `OwnedBuf` type.
pub type VmpPMatOwned<B> = VmpPMat<<B as Backend>::OwnedBuf, B>;
/// A [`VmpPMat`] whose data is a borrowed byte slice (`&'a [u8]`).
pub type VmpPMatRef<'a, B> = VmpPMat<&'a [u8], B>;
/// A [`VmpPMat`] whose data is the backend's `BufRef<'a>` type.
pub type VmpPMatBackendRef<'a, B> = VmpPMat<<B as Backend>::BufRef<'a>, B>;
/// A [`VmpPMat`] whose data is the backend's `BufMut<'a>` type.
pub type VmpPMatBackendMut<'a, B> = VmpPMat<<B as Backend>::BufMut<'a>, B>;
170
171pub fn vmp_pmat_backend_ref_from_ref<'a, 'b, B: Backend + 'b>(pmat: &'a VmpPMat<B::BufRef<'b>, B>) -> VmpPMatBackendRef<'a, B> {
173 VmpPMat {
174 data: B::view_ref(&pmat.data),
175 shape: pmat.shape,
176 _phantom: PhantomData,
177 }
178}
179
180pub fn vmp_pmat_backend_ref_from_mut<'a, B: Backend>(pmat: &'a VmpPMatBackendMut<'a, B>) -> VmpPMatBackendRef<'a, B> {
182 VmpPMat {
183 data: B::view_ref_mut(&pmat.data),
184 shape: pmat.shape,
185 _phantom: PhantomData,
186 }
187}
188
189pub fn vmp_pmat_backend_mut_from_mut<'a, 'b, B: Backend + 'b>(
190 pmat: &'a mut VmpPMatBackendMut<'b, B>,
191) -> VmpPMatBackendMut<'a, B> {
192 VmpPMat {
193 data: B::view_mut_ref(&mut pmat.data),
194 shape: pmat.shape,
195 _phantom: PhantomData,
196 }
197}
198
/// Conversion of a value into a read-only backend view
/// ([`VmpPMatBackendRef`]) that borrows from it.
pub trait VmpPMatToBackendRef<B: Backend> {
    /// Returns a [`VmpPMatBackendRef`] whose lifetime is tied to `&self`.
    fn to_backend_ref(&self) -> VmpPMatBackendRef<'_, B>;
}
203
204impl<B: Backend> VmpPMatToBackendRef<B> for VmpPMat<B::OwnedBuf, B> {
205 fn to_backend_ref(&self) -> VmpPMatBackendRef<'_, B> {
206 VmpPMat {
207 data: B::view(&self.data),
208 shape: self.shape,
209 _phantom: std::marker::PhantomData,
210 }
211 }
212}
213
214impl<'b, B: Backend + 'b> VmpPMatToBackendRef<B> for &VmpPMat<B::BufRef<'b>, B> {
215 fn to_backend_ref(&self) -> VmpPMatBackendRef<'_, B> {
216 VmpPMat {
217 data: B::view_ref(&self.data),
218 shape: self.shape,
219 _phantom: PhantomData,
220 }
221 }
222}
223
/// Shared reborrow of a value as a read-only backend view
/// ([`VmpPMatBackendRef`]).
pub trait VmpPMatReborrowBackendRef<B: Backend> {
    /// Returns a [`VmpPMatBackendRef`] whose lifetime is tied to `&self`.
    fn reborrow_backend_ref(&self) -> VmpPMatBackendRef<'_, B>;
}
228
229impl<'b, B: Backend + 'b> VmpPMatReborrowBackendRef<B> for VmpPMat<B::BufMut<'b>, B> {
230 fn reborrow_backend_ref(&self) -> VmpPMatBackendRef<'_, B> {
231 VmpPMat {
232 data: B::view_ref_mut(&self.data),
233 shape: self.shape,
234 _phantom: std::marker::PhantomData,
235 }
236 }
237}
238
/// Conversion of a value into a mutable backend view
/// ([`VmpPMatBackendMut`]) that exclusively borrows from it.
pub trait VmpPMatToBackendMut<B: Backend> {
    /// Returns a [`VmpPMatBackendMut`] whose lifetime is tied to `&mut self`.
    fn to_backend_mut(&mut self) -> VmpPMatBackendMut<'_, B>;
}
243
244impl<B: Backend> VmpPMatToBackendMut<B> for VmpPMat<B::OwnedBuf, B> {
245 fn to_backend_mut(&mut self) -> VmpPMatBackendMut<'_, B> {
246 VmpPMat {
247 data: B::view_mut(&mut self.data),
248 shape: self.shape,
249 _phantom: std::marker::PhantomData,
250 }
251 }
252}
253
254impl<'b, B: Backend + 'b> VmpPMatToBackendMut<B> for &mut VmpPMat<B::BufMut<'b>, B> {
255 fn to_backend_mut(&mut self) -> VmpPMatBackendMut<'_, B> {
256 vmp_pmat_backend_mut_from_mut::<B>(self)
257 }
258}
259
/// Exclusive reborrow of a value as a mutable backend view
/// ([`VmpPMatBackendMut`]).
pub trait VmpPMatReborrowBackendMut<B: Backend> {
    /// Returns a [`VmpPMatBackendMut`] whose lifetime is tied to `&mut self`.
    fn reborrow_backend_mut(&mut self) -> VmpPMatBackendMut<'_, B>;
}
264
265impl<'b, B: Backend + 'b> VmpPMatReborrowBackendMut<B> for VmpPMat<B::BufMut<'b>, B> {
266 fn reborrow_backend_mut(&mut self) -> VmpPMatBackendMut<'_, B> {
267 vmp_pmat_backend_mut_from_mut::<B>(self)
268 }
269}
270
271impl<D: Data, B: Backend> VmpPMat<D, B> {
272 pub fn from_data(data: D, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self {
273 Self {
274 data,
275 shape: VmpPMatShape::new(n, rows, cols_in, cols_out, size),
276 _phantom: PhantomData,
277 }
278 }
279}