1use std::marker::PhantomData;
2
3use crate::layouts::{Backend, Data, DataView, DataViewMut, HostDataRef, ZnxInfos, ZnxView};
4
/// Shape descriptor shared by [`CnvPVecL`] and [`CnvPVecR`].
///
/// The `#[repr(C)]` layout stores the fields in the order `n`, `size`, `cols`,
/// which is *not* the argument order of [`CnvPVecShape::new`] (`n`, `cols`, `size`).
#[repr(C)]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct CnvPVecShape {
    n: usize,
    size: usize,
    cols: usize,
}

impl CnvPVecShape {
    /// Builds a shape from `n`, `cols` and `size` (in that argument order).
    pub const fn new(n: usize, cols: usize, size: usize) -> Self {
        Self { n, cols, size }
    }

    /// Returns the stored `n` dimension.
    pub const fn n(self) -> usize {
        let Self { n, .. } = self;
        n
    }

    /// Returns the stored `size` dimension.
    pub const fn size(self) -> usize {
        let Self { size, .. } = self;
        size
    }

    /// Returns the stored `cols` dimension.
    pub const fn cols(self) -> usize {
        let Self { cols, .. } = self;
        cols
    }
}
30
31pub struct CnvPVecR<D: Data, BE: Backend> {
38 data: D,
39 shape: CnvPVecShape,
40 _phantom: PhantomData<BE>,
41}
42
43impl<D: Data, BE: Backend> ZnxInfos for CnvPVecR<D, BE> {
44 fn cols(&self) -> usize {
45 self.shape.cols()
46 }
47
48 fn n(&self) -> usize {
49 self.shape.n()
50 }
51
52 fn rows(&self) -> usize {
53 1
54 }
55
56 fn size(&self) -> usize {
57 self.shape.size()
58 }
59}
60
61impl<D: Data, BE: Backend> DataView for CnvPVecR<D, BE> {
62 type D = D;
63 fn data(&self) -> &Self::D {
64 &self.data
65 }
66}
67
68impl<D: Data, B: Backend> DataViewMut for CnvPVecR<D, B> {
69 fn data_mut(&mut self) -> &mut Self::D {
70 &mut self.data
71 }
72}
73
74impl<D: HostDataRef, BE: Backend> ZnxView for CnvPVecR<D, BE> {
75 type Scalar = BE::ScalarPrep;
76}
77
78impl<D: Data, BE: Backend> CnvPVecR<D, BE> {
79 pub fn shape(&self) -> CnvPVecShape {
80 self.shape
81 }
82
83 pub fn n(&self) -> usize {
84 self.shape.n()
85 }
86
87 pub fn cols(&self) -> usize {
88 self.shape.cols()
89 }
90
91 pub fn size(&self) -> usize {
92 self.shape.size()
93 }
94}
95
96impl<B: Backend> CnvPVecR<B::OwnedBuf, B> {
97 pub fn alloc(n: usize, cols: usize, size: usize) -> Self {
98 let data: B::OwnedBuf = B::alloc_zeroed_bytes(B::bytes_of_cnv_pvec_right(n, cols, size));
99 Self {
100 data,
101 shape: CnvPVecShape::new(n, cols, size),
102 _phantom: PhantomData,
103 }
104 }
105
106 pub fn from_bytes(n: usize, cols: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
107 let data: Vec<u8> = bytes.into();
108 assert!(data.len() == B::bytes_of_cnv_pvec_right(n, cols, size));
109 let data: B::OwnedBuf = B::from_host_bytes(&data);
110 Self {
111 data,
112 shape: CnvPVecShape::new(n, cols, size),
113 _phantom: PhantomData,
114 }
115 }
116}
117
118impl<D: Data, B: Backend> CnvPVecR<D, B> {
119 pub fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self {
120 Self {
121 data,
122 shape: CnvPVecShape::new(n, cols, size),
123 _phantom: PhantomData,
124 }
125 }
126}
127
128pub struct CnvPVecL<D: Data, BE: Backend> {
135 data: D,
136 shape: CnvPVecShape,
137 _phantom: PhantomData<BE>,
138}
139
140impl<D: Data, BE: Backend> ZnxInfos for CnvPVecL<D, BE> {
141 fn cols(&self) -> usize {
142 self.shape.cols()
143 }
144
145 fn n(&self) -> usize {
146 self.shape.n()
147 }
148
149 fn rows(&self) -> usize {
150 1
151 }
152
153 fn size(&self) -> usize {
154 self.shape.size()
155 }
156}
157
158impl<D: Data, BE: Backend> DataView for CnvPVecL<D, BE> {
159 type D = D;
160 fn data(&self) -> &Self::D {
161 &self.data
162 }
163}
164
165impl<D: Data, B: Backend> DataViewMut for CnvPVecL<D, B> {
166 fn data_mut(&mut self) -> &mut Self::D {
167 &mut self.data
168 }
169}
170
171impl<D: HostDataRef, BE: Backend> ZnxView for CnvPVecL<D, BE> {
172 type Scalar = BE::ScalarPrep;
173}
174
175impl<D: Data, BE: Backend> CnvPVecL<D, BE> {
176 pub fn shape(&self) -> CnvPVecShape {
177 self.shape
178 }
179
180 pub fn n(&self) -> usize {
181 self.shape.n()
182 }
183
184 pub fn cols(&self) -> usize {
185 self.shape.cols()
186 }
187
188 pub fn size(&self) -> usize {
189 self.shape.size()
190 }
191}
192
193impl<B: Backend> CnvPVecL<B::OwnedBuf, B> {
194 pub fn alloc(n: usize, cols: usize, size: usize) -> Self {
195 let data: B::OwnedBuf = B::alloc_zeroed_bytes(B::bytes_of_cnv_pvec_left(n, cols, size));
196 Self {
197 data,
198 shape: CnvPVecShape::new(n, cols, size),
199 _phantom: PhantomData,
200 }
201 }
202
203 pub fn from_bytes(n: usize, cols: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
204 let data: Vec<u8> = bytes.into();
205 assert!(data.len() == B::bytes_of_cnv_pvec_left(n, cols, size));
206 let data: B::OwnedBuf = B::from_host_bytes(&data);
207 Self {
208 data,
209 shape: CnvPVecShape::new(n, cols, size),
210 _phantom: PhantomData,
211 }
212 }
213}
214
215impl<D: Data, B: Backend> CnvPVecL<D, B> {
216 pub fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self {
217 Self {
218 data,
219 shape: CnvPVecShape::new(n, cols, size),
220 _phantom: PhantomData,
221 }
222 }
223}
224
225pub type CnvPVecRBackendRef<'a, B> = CnvPVecR<<B as Backend>::BufRef<'a>, B>;
227pub type CnvPVecRBackendMut<'a, B> = CnvPVecR<<B as Backend>::BufMut<'a>, B>;
228pub type CnvPVecLBackendRef<'a, B> = CnvPVecL<<B as Backend>::BufRef<'a>, B>;
229pub type CnvPVecLBackendMut<'a, B> = CnvPVecL<<B as Backend>::BufMut<'a>, B>;
230
231pub trait CnvPVecRToBackendRef<BE: Backend> {
233 fn to_backend_ref(&self) -> CnvPVecRBackendRef<'_, BE>;
234}
235
236impl<BE: Backend> CnvPVecRToBackendRef<BE> for CnvPVecR<BE::OwnedBuf, BE> {
237 fn to_backend_ref(&self) -> CnvPVecRBackendRef<'_, BE> {
238 CnvPVecR {
239 data: BE::view(&self.data),
240 shape: self.shape,
241 _phantom: self._phantom,
242 }
243 }
244}
245
246pub trait CnvPVecRReborrowBackendRef<BE: Backend> {
248 fn reborrow_backend_ref(&self) -> CnvPVecRBackendRef<'_, BE>;
249}
250
251impl<'b, BE: Backend + 'b> CnvPVecRReborrowBackendRef<BE> for CnvPVecR<BE::BufMut<'b>, BE> {
252 fn reborrow_backend_ref(&self) -> CnvPVecRBackendRef<'_, BE> {
253 CnvPVecR {
254 data: BE::view_ref_mut(&self.data),
255 shape: self.shape,
256 _phantom: self._phantom,
257 }
258 }
259}
260
261pub trait CnvPVecRToBackendMut<BE: Backend> {
263 fn to_backend_mut(&mut self) -> CnvPVecRBackendMut<'_, BE>;
264}
265
266impl<BE: Backend> CnvPVecRToBackendMut<BE> for CnvPVecR<BE::OwnedBuf, BE> {
267 fn to_backend_mut(&mut self) -> CnvPVecRBackendMut<'_, BE> {
268 CnvPVecR {
269 data: BE::view_mut(&mut self.data),
270 shape: self.shape,
271 _phantom: self._phantom,
272 }
273 }
274}
275
276pub trait CnvPVecRReborrowBackendMut<BE: Backend> {
278 fn reborrow_backend_mut(&mut self) -> CnvPVecRBackendMut<'_, BE>;
279}
280
281impl<'b, BE: Backend + 'b> CnvPVecRReborrowBackendMut<BE> for CnvPVecR<BE::BufMut<'b>, BE> {
282 fn reborrow_backend_mut(&mut self) -> CnvPVecRBackendMut<'_, BE> {
283 CnvPVecR {
284 data: BE::view_mut_ref(&mut self.data),
285 shape: self.shape,
286 _phantom: self._phantom,
287 }
288 }
289}
290
291pub trait CnvPVecLToBackendRef<BE: Backend> {
293 fn to_backend_ref(&self) -> CnvPVecLBackendRef<'_, BE>;
294}
295
296impl<BE: Backend> CnvPVecLToBackendRef<BE> for CnvPVecL<BE::OwnedBuf, BE> {
297 fn to_backend_ref(&self) -> CnvPVecLBackendRef<'_, BE> {
298 CnvPVecL {
299 data: BE::view(&self.data),
300 shape: self.shape,
301 _phantom: self._phantom,
302 }
303 }
304}
305
306pub trait CnvPVecLReborrowBackendRef<BE: Backend> {
308 fn reborrow_backend_ref(&self) -> CnvPVecLBackendRef<'_, BE>;
309}
310
311impl<'b, BE: Backend + 'b> CnvPVecLReborrowBackendRef<BE> for CnvPVecL<BE::BufMut<'b>, BE> {
312 fn reborrow_backend_ref(&self) -> CnvPVecLBackendRef<'_, BE> {
313 CnvPVecL {
314 data: BE::view_ref_mut(&self.data),
315 shape: self.shape,
316 _phantom: self._phantom,
317 }
318 }
319}
320
321pub trait CnvPVecLToBackendMut<BE: Backend> {
323 fn to_backend_mut(&mut self) -> CnvPVecLBackendMut<'_, BE>;
324}
325
326impl<BE: Backend> CnvPVecLToBackendMut<BE> for CnvPVecL<BE::OwnedBuf, BE> {
327 fn to_backend_mut(&mut self) -> CnvPVecLBackendMut<'_, BE> {
328 CnvPVecL {
329 data: BE::view_mut(&mut self.data),
330 shape: self.shape,
331 _phantom: self._phantom,
332 }
333 }
334}
335
336pub trait CnvPVecLReborrowBackendMut<BE: Backend> {
338 fn reborrow_backend_mut(&mut self) -> CnvPVecLBackendMut<'_, BE>;
339}
340
341impl<'b, BE: Backend + 'b> CnvPVecLReborrowBackendMut<BE> for CnvPVecL<BE::BufMut<'b>, BE> {
342 fn reborrow_backend_mut(&mut self) -> CnvPVecLBackendMut<'_, BE> {
343 CnvPVecL {
344 data: BE::view_mut_ref(&mut self.data),
345 shape: self.shape,
346 _phantom: self._phantom,
347 }
348 }
349}