// poulpy_hal/layouts/convolution.rs

1use std::marker::PhantomData;
2
3use crate::layouts::{Backend, Data, DataView, DataViewMut, HostDataRef, ZnxInfos, ZnxView};
4
/// Dimensions of a prepared convolution operand.
///
/// Records three counts: `n`, the number of columns `cols`, and `size`.
/// The constructor takes them in the order `(n, cols, size)`, while the
/// fields (and therefore the `#[repr(C)]` layout and the `Debug` output)
/// are ordered `n, size, cols`.
#[repr(C)]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct CnvPVecShape {
    n: usize,
    size: usize,
    cols: usize,
}

impl CnvPVecShape {
    /// Builds a shape from `n`, `cols` and `size` (in that argument order).
    pub const fn new(n: usize, cols: usize, size: usize) -> Self {
        CnvPVecShape { n, cols, size }
    }

    /// Returns the stored `n`.
    pub const fn n(self) -> usize {
        let Self { n, .. } = self;
        n
    }

    /// Returns the stored `size`.
    pub const fn size(self) -> usize {
        let Self { size, .. } = self;
        size
    }

    /// Returns the stored number of columns.
    pub const fn cols(self) -> usize {
        let Self { cols, .. } = self;
        cols
    }
}
30
31/// Prepared right operand for bivariate convolution.
32///
33/// Holds a polynomial vector in the backend's prepared representation,
34/// ready to be used as the right operand of
35/// [`Convolution::cnv_apply_dft`](crate::api::Convolution::cnv_apply_dft).
36/// Created via [`Convolution::cnv_prepare_right`](crate::api::Convolution::cnv_prepare_right).
37pub struct CnvPVecR<D: Data, BE: Backend> {
38    data: D,
39    shape: CnvPVecShape,
40    _phantom: PhantomData<BE>,
41}
42
43impl<D: Data, BE: Backend> ZnxInfos for CnvPVecR<D, BE> {
44    fn cols(&self) -> usize {
45        self.shape.cols()
46    }
47
48    fn n(&self) -> usize {
49        self.shape.n()
50    }
51
52    fn rows(&self) -> usize {
53        1
54    }
55
56    fn size(&self) -> usize {
57        self.shape.size()
58    }
59}
60
61impl<D: Data, BE: Backend> DataView for CnvPVecR<D, BE> {
62    type D = D;
63    fn data(&self) -> &Self::D {
64        &self.data
65    }
66}
67
68impl<D: Data, B: Backend> DataViewMut for CnvPVecR<D, B> {
69    fn data_mut(&mut self) -> &mut Self::D {
70        &mut self.data
71    }
72}
73
74impl<D: HostDataRef, BE: Backend> ZnxView for CnvPVecR<D, BE> {
75    type Scalar = BE::ScalarPrep;
76}
77
78impl<D: Data, BE: Backend> CnvPVecR<D, BE> {
79    pub fn shape(&self) -> CnvPVecShape {
80        self.shape
81    }
82
83    pub fn n(&self) -> usize {
84        self.shape.n()
85    }
86
87    pub fn cols(&self) -> usize {
88        self.shape.cols()
89    }
90
91    pub fn size(&self) -> usize {
92        self.shape.size()
93    }
94}
95
96impl<B: Backend> CnvPVecR<B::OwnedBuf, B> {
97    pub fn alloc(n: usize, cols: usize, size: usize) -> Self {
98        let data: B::OwnedBuf = B::alloc_zeroed_bytes(B::bytes_of_cnv_pvec_right(n, cols, size));
99        Self {
100            data,
101            shape: CnvPVecShape::new(n, cols, size),
102            _phantom: PhantomData,
103        }
104    }
105
106    pub fn from_bytes(n: usize, cols: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
107        let data: Vec<u8> = bytes.into();
108        assert!(data.len() == B::bytes_of_cnv_pvec_right(n, cols, size));
109        let data: B::OwnedBuf = B::from_host_bytes(&data);
110        Self {
111            data,
112            shape: CnvPVecShape::new(n, cols, size),
113            _phantom: PhantomData,
114        }
115    }
116}
117
118impl<D: Data, B: Backend> CnvPVecR<D, B> {
119    pub fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self {
120        Self {
121            data,
122            shape: CnvPVecShape::new(n, cols, size),
123            _phantom: PhantomData,
124        }
125    }
126}
127
128/// Prepared left operand for bivariate convolution.
129///
130/// Holds a polynomial vector in the backend's prepared representation,
131/// ready to be used as the left operand of
132/// [`Convolution::cnv_apply_dft`](crate::api::Convolution::cnv_apply_dft).
133/// Created via [`Convolution::cnv_prepare_left`](crate::api::Convolution::cnv_prepare_left).
134pub struct CnvPVecL<D: Data, BE: Backend> {
135    data: D,
136    shape: CnvPVecShape,
137    _phantom: PhantomData<BE>,
138}
139
140impl<D: Data, BE: Backend> ZnxInfos for CnvPVecL<D, BE> {
141    fn cols(&self) -> usize {
142        self.shape.cols()
143    }
144
145    fn n(&self) -> usize {
146        self.shape.n()
147    }
148
149    fn rows(&self) -> usize {
150        1
151    }
152
153    fn size(&self) -> usize {
154        self.shape.size()
155    }
156}
157
158impl<D: Data, BE: Backend> DataView for CnvPVecL<D, BE> {
159    type D = D;
160    fn data(&self) -> &Self::D {
161        &self.data
162    }
163}
164
165impl<D: Data, B: Backend> DataViewMut for CnvPVecL<D, B> {
166    fn data_mut(&mut self) -> &mut Self::D {
167        &mut self.data
168    }
169}
170
171impl<D: HostDataRef, BE: Backend> ZnxView for CnvPVecL<D, BE> {
172    type Scalar = BE::ScalarPrep;
173}
174
175impl<D: Data, BE: Backend> CnvPVecL<D, BE> {
176    pub fn shape(&self) -> CnvPVecShape {
177        self.shape
178    }
179
180    pub fn n(&self) -> usize {
181        self.shape.n()
182    }
183
184    pub fn cols(&self) -> usize {
185        self.shape.cols()
186    }
187
188    pub fn size(&self) -> usize {
189        self.shape.size()
190    }
191}
192
193impl<B: Backend> CnvPVecL<B::OwnedBuf, B> {
194    pub fn alloc(n: usize, cols: usize, size: usize) -> Self {
195        let data: B::OwnedBuf = B::alloc_zeroed_bytes(B::bytes_of_cnv_pvec_left(n, cols, size));
196        Self {
197            data,
198            shape: CnvPVecShape::new(n, cols, size),
199            _phantom: PhantomData,
200        }
201    }
202
203    pub fn from_bytes(n: usize, cols: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
204        let data: Vec<u8> = bytes.into();
205        assert!(data.len() == B::bytes_of_cnv_pvec_left(n, cols, size));
206        let data: B::OwnedBuf = B::from_host_bytes(&data);
207        Self {
208            data,
209            shape: CnvPVecShape::new(n, cols, size),
210            _phantom: PhantomData,
211        }
212    }
213}
214
215impl<D: Data, B: Backend> CnvPVecL<D, B> {
216    pub fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self {
217        Self {
218            data,
219            shape: CnvPVecShape::new(n, cols, size),
220            _phantom: PhantomData,
221        }
222    }
223}
224
225/// Borrow a `CnvPVecR` as a shared reference view.
226pub type CnvPVecRBackendRef<'a, B> = CnvPVecR<<B as Backend>::BufRef<'a>, B>;
227pub type CnvPVecRBackendMut<'a, B> = CnvPVecR<<B as Backend>::BufMut<'a>, B>;
228pub type CnvPVecLBackendRef<'a, B> = CnvPVecL<<B as Backend>::BufRef<'a>, B>;
229pub type CnvPVecLBackendMut<'a, B> = CnvPVecL<<B as Backend>::BufMut<'a>, B>;
230
231/// Borrow a backend-owned `CnvPVecR` using the backend's native view type.
232pub trait CnvPVecRToBackendRef<BE: Backend> {
233    fn to_backend_ref(&self) -> CnvPVecRBackendRef<'_, BE>;
234}
235
236impl<BE: Backend> CnvPVecRToBackendRef<BE> for CnvPVecR<BE::OwnedBuf, BE> {
237    fn to_backend_ref(&self) -> CnvPVecRBackendRef<'_, BE> {
238        CnvPVecR {
239            data: BE::view(&self.data),
240            shape: self.shape,
241            _phantom: self._phantom,
242        }
243    }
244}
245
246/// Reborrow an already backend-borrowed `CnvPVecR` as a shared backend-native view.
247pub trait CnvPVecRReborrowBackendRef<BE: Backend> {
248    fn reborrow_backend_ref(&self) -> CnvPVecRBackendRef<'_, BE>;
249}
250
251impl<'b, BE: Backend + 'b> CnvPVecRReborrowBackendRef<BE> for CnvPVecR<BE::BufMut<'b>, BE> {
252    fn reborrow_backend_ref(&self) -> CnvPVecRBackendRef<'_, BE> {
253        CnvPVecR {
254            data: BE::view_ref_mut(&self.data),
255            shape: self.shape,
256            _phantom: self._phantom,
257        }
258    }
259}
260
261/// Mutably borrow a backend-owned `CnvPVecR` using the backend's native view type.
262pub trait CnvPVecRToBackendMut<BE: Backend> {
263    fn to_backend_mut(&mut self) -> CnvPVecRBackendMut<'_, BE>;
264}
265
266impl<BE: Backend> CnvPVecRToBackendMut<BE> for CnvPVecR<BE::OwnedBuf, BE> {
267    fn to_backend_mut(&mut self) -> CnvPVecRBackendMut<'_, BE> {
268        CnvPVecR {
269            data: BE::view_mut(&mut self.data),
270            shape: self.shape,
271            _phantom: self._phantom,
272        }
273    }
274}
275
276/// Reborrow an already backend-borrowed `CnvPVecR` as a mutable backend-native view.
277pub trait CnvPVecRReborrowBackendMut<BE: Backend> {
278    fn reborrow_backend_mut(&mut self) -> CnvPVecRBackendMut<'_, BE>;
279}
280
281impl<'b, BE: Backend + 'b> CnvPVecRReborrowBackendMut<BE> for CnvPVecR<BE::BufMut<'b>, BE> {
282    fn reborrow_backend_mut(&mut self) -> CnvPVecRBackendMut<'_, BE> {
283        CnvPVecR {
284            data: BE::view_mut_ref(&mut self.data),
285            shape: self.shape,
286            _phantom: self._phantom,
287        }
288    }
289}
290
291/// Borrow a `CnvPVecL` as a shared reference view.
292pub trait CnvPVecLToBackendRef<BE: Backend> {
293    fn to_backend_ref(&self) -> CnvPVecLBackendRef<'_, BE>;
294}
295
296impl<BE: Backend> CnvPVecLToBackendRef<BE> for CnvPVecL<BE::OwnedBuf, BE> {
297    fn to_backend_ref(&self) -> CnvPVecLBackendRef<'_, BE> {
298        CnvPVecL {
299            data: BE::view(&self.data),
300            shape: self.shape,
301            _phantom: self._phantom,
302        }
303    }
304}
305
306/// Reborrow an already backend-borrowed `CnvPVecL` as a shared backend-native view.
307pub trait CnvPVecLReborrowBackendRef<BE: Backend> {
308    fn reborrow_backend_ref(&self) -> CnvPVecLBackendRef<'_, BE>;
309}
310
311impl<'b, BE: Backend + 'b> CnvPVecLReborrowBackendRef<BE> for CnvPVecL<BE::BufMut<'b>, BE> {
312    fn reborrow_backend_ref(&self) -> CnvPVecLBackendRef<'_, BE> {
313        CnvPVecL {
314            data: BE::view_ref_mut(&self.data),
315            shape: self.shape,
316            _phantom: self._phantom,
317        }
318    }
319}
320
321/// Mutably borrow a backend-owned `CnvPVecL` using the backend's native view type.
322pub trait CnvPVecLToBackendMut<BE: Backend> {
323    fn to_backend_mut(&mut self) -> CnvPVecLBackendMut<'_, BE>;
324}
325
326impl<BE: Backend> CnvPVecLToBackendMut<BE> for CnvPVecL<BE::OwnedBuf, BE> {
327    fn to_backend_mut(&mut self) -> CnvPVecLBackendMut<'_, BE> {
328        CnvPVecL {
329            data: BE::view_mut(&mut self.data),
330            shape: self.shape,
331            _phantom: self._phantom,
332        }
333    }
334}
335
336/// Reborrow an already backend-borrowed `CnvPVecL` as a mutable backend-native view.
337pub trait CnvPVecLReborrowBackendMut<BE: Backend> {
338    fn reborrow_backend_mut(&mut self) -> CnvPVecLBackendMut<'_, BE>;
339}
340
341impl<'b, BE: Backend + 'b> CnvPVecLReborrowBackendMut<BE> for CnvPVecL<BE::BufMut<'b>, BE> {
342    fn reborrow_backend_mut(&mut self) -> CnvPVecLBackendMut<'_, BE> {
343        CnvPVecL {
344            data: BE::view_mut_ref(&mut self.data),
345            shape: self.shape,
346            _phantom: self._phantom,
347        }
348    }
349}