Skip to main content

poulpy_hal/layouts/
vmp_pmat.rs

1use std::{
2    hash::{DefaultHasher, Hasher},
3    marker::PhantomData,
4};
5
6use crate::layouts::{Backend, Data, DataView, DataViewMut, DigestU64, HostDataRef, ZnxInfos, ZnxView};
7
/// Dimensions of a `VmpPMat`: ring degree `n`, limb count `size`, number of
/// `rows`, and the number of input (`cols_in`) and output (`cols_out`) columns.
///
/// The struct is `#[repr(C)]`, so the field order below is part of its memory
/// layout; it also fixes the order used by the derived `Hash` and `Debug`.
#[repr(C)]
#[derive(PartialEq, Eq, Clone, Copy, Hash, Debug, Default)]
pub struct VmpPMatShape {
    n: usize,
    size: usize,
    rows: usize,
    cols_in: usize,
    cols_out: usize,
}

impl VmpPMatShape {
    /// Builds a shape from its components.
    ///
    /// Note that the argument order `(n, rows, cols_in, cols_out, size)`
    /// differs from the field declaration order.
    pub const fn new(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self {
        Self {
            rows,
            cols_in,
            cols_out,
            size,
            n,
        }
    }

    /// Ring degree `n`.
    pub const fn n(self) -> usize {
        let Self { n, .. } = self;
        n
    }

    /// Number of limbs (`size`).
    pub const fn size(self) -> usize {
        let Self { size, .. } = self;
        size
    }

    /// Number of matrix rows.
    pub const fn rows(self) -> usize {
        let Self { rows, .. } = self;
        rows
    }

    /// Number of input columns.
    pub const fn cols_in(self) -> usize {
        let Self { cols_in, .. } = self;
        cols_in
    }

    /// Number of output columns.
    pub const fn cols_out(self) -> usize {
        let Self { cols_out, .. } = self;
        cols_out
    }
}
49
/// Prepared (DFT-domain) polynomial matrix for vector-matrix products.
///
/// A `VmpPMat` stores a matrix of `rows * cols_in` entries, where each
/// entry is a [`VecZnxDft`](crate::layouts::VecZnxDft) of `cols_out`
/// columns and `size` limbs, all in the backend's prepared representation.
/// The dimensions live in a [`VmpPMatShape`]; the raw bytes live in `data`,
/// whose type `D` selects between owned storage and borrowed views (see
/// [`VmpPMatOwned`], [`VmpPMatRef`], [`VmpPMatBackendRef`] and
/// [`VmpPMatBackendMut`]).
///
/// Used as the right operand in
/// [`VmpApplyDftToDft`](crate::api::VmpApplyDftToDft). Create via
/// [`VmpPrepare`](crate::api::VmpPrepare) from a coefficient-domain
/// [`MatZnx`](crate::layouts::MatZnx).
///
/// The ring degree `n` is expected to be a power of two, presumably so that
/// each prepared polynomial's coefficient count lines up with SIMD lane widths
/// and buffer alignment. NOTE(review): neither `alloc` nor `from_data` checks
/// this; it is a caller obligation.
///
/// NOTE(review): `#[derive(PartialEq, Eq, Hash)]` puts `B: PartialEq`,
/// `B: Eq` and `B: Hash` bounds on the generated impls even though only a
/// `PhantomData<B>` is stored; a hand-written impl would avoid that.
#[repr(C)]
#[derive(PartialEq, Eq, Hash)]
pub struct VmpPMat<D: Data, B: Backend> {
    /// Backing storage: an owned buffer or a borrowed view, depending on `D`.
    data: D,
    /// Matrix dimensions.
    shape: VmpPMatShape,
    /// Ties the matrix to backend `B` without storing a value of it.
    _phantom: PhantomData<B>,
}
70
71impl<D: HostDataRef, B: Backend> DigestU64 for VmpPMat<D, B> {
72    fn digest_u64(&self) -> u64 {
73        let mut h: DefaultHasher = DefaultHasher::new();
74        h.write(self.data.as_ref());
75        h.write_usize(self.n());
76        h.write_usize(self.size());
77        h.write_usize(self.rows());
78        h.write_usize(self.cols_in());
79        h.write_usize(self.cols_out());
80        h.finish()
81    }
82}
83
84impl<D: HostDataRef, B: Backend> ZnxView for VmpPMat<D, B> {
85    type Scalar = B::ScalarPrep;
86}
87
88impl<D: Data, B: Backend> ZnxInfos for VmpPMat<D, B> {
89    fn cols(&self) -> usize {
90        self.shape.cols_in()
91    }
92
93    fn rows(&self) -> usize {
94        self.shape.rows()
95    }
96
97    fn n(&self) -> usize {
98        self.shape.n()
99    }
100
101    fn size(&self) -> usize {
102        self.shape.size()
103    }
104
105    fn poly_count(&self) -> usize {
106        self.rows() * self.cols_in() * self.size() * self.cols_out()
107    }
108}
109
110impl<D: Data, B: Backend> DataView for VmpPMat<D, B> {
111    type D = D;
112    fn data(&self) -> &Self::D {
113        &self.data
114    }
115}
116
117impl<D: Data, B: Backend> DataViewMut for VmpPMat<D, B> {
118    fn data_mut(&mut self) -> &mut Self::D {
119        &mut self.data
120    }
121}
122
123impl<D: Data, B: Backend> VmpPMat<D, B> {
124    pub fn shape(&self) -> VmpPMatShape {
125        self.shape
126    }
127
128    pub fn n(&self) -> usize {
129        self.shape.n()
130    }
131
132    pub fn rows(&self) -> usize {
133        self.shape.rows()
134    }
135
136    pub fn size(&self) -> usize {
137        self.shape.size()
138    }
139
140    /// Returns the number of input columns.
141    pub fn cols_in(&self) -> usize {
142        self.shape.cols_in()
143    }
144
145    /// Returns the number of output columns.
146    pub fn cols_out(&self) -> usize {
147        self.shape.cols_out()
148    }
149}
150
151impl<B: Backend> VmpPMat<<B as Backend>::OwnedBuf, B> {
152    pub fn alloc(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self {
153        let data: <B as Backend>::OwnedBuf = B::alloc_zeroed_bytes(B::bytes_of_vmp_pmat(n, rows, cols_in, cols_out, size));
154        Self {
155            data,
156            shape: VmpPMatShape::new(n, rows, cols_in, cols_out, size),
157            _phantom: PhantomData,
158        }
159    }
160}
161
/// Owned `VmpPMat` whose storage is the backend's `OwnedBuf` (the type built by `VmpPMat::alloc`).
pub type VmpPMatOwned<B> = VmpPMat<<B as Backend>::OwnedBuf, B>;
/// Immutably borrowed `VmpPMat` viewing a plain byte slice `&'a [u8]`.
pub type VmpPMatRef<'a, B> = VmpPMat<&'a [u8], B>;
/// Shared backend-native borrow of a `VmpPMat`, using the backend's `BufRef<'a>` view type.
pub type VmpPMatBackendRef<'a, B> = VmpPMat<<B as Backend>::BufRef<'a>, B>;
/// Mutable backend-native borrow of a `VmpPMat`, using the backend's `BufMut<'a>` view type.
pub type VmpPMatBackendMut<'a, B> = VmpPMat<<B as Backend>::BufMut<'a>, B>;
170
171/// Reborrow an immutable backend-native `VmpPMat` view as a shared backend-native view.
172pub fn vmp_pmat_backend_ref_from_ref<'a, 'b, B: Backend + 'b>(pmat: &'a VmpPMat<B::BufRef<'b>, B>) -> VmpPMatBackendRef<'a, B> {
173    VmpPMat {
174        data: B::view_ref(&pmat.data),
175        shape: pmat.shape,
176        _phantom: PhantomData,
177    }
178}
179
180/// Reborrow a mutable backend-native `VmpPMat` view as a shared backend-native view.
181pub fn vmp_pmat_backend_ref_from_mut<'a, B: Backend>(pmat: &'a VmpPMatBackendMut<'a, B>) -> VmpPMatBackendRef<'a, B> {
182    VmpPMat {
183        data: B::view_ref_mut(&pmat.data),
184        shape: pmat.shape,
185        _phantom: PhantomData,
186    }
187}
188
189pub fn vmp_pmat_backend_mut_from_mut<'a, 'b, B: Backend + 'b>(
190    pmat: &'a mut VmpPMatBackendMut<'b, B>,
191) -> VmpPMatBackendMut<'a, B> {
192    VmpPMat {
193        data: B::view_mut_ref(&mut pmat.data),
194        shape: pmat.shape,
195        _phantom: PhantomData,
196    }
197}
198
199/// Borrow a backend-owned `VmpPMat` using the backend's native view type.
200pub trait VmpPMatToBackendRef<B: Backend> {
201    fn to_backend_ref(&self) -> VmpPMatBackendRef<'_, B>;
202}
203
204impl<B: Backend> VmpPMatToBackendRef<B> for VmpPMat<B::OwnedBuf, B> {
205    fn to_backend_ref(&self) -> VmpPMatBackendRef<'_, B> {
206        VmpPMat {
207            data: B::view(&self.data),
208            shape: self.shape,
209            _phantom: std::marker::PhantomData,
210        }
211    }
212}
213
214impl<'b, B: Backend + 'b> VmpPMatToBackendRef<B> for &VmpPMat<B::BufRef<'b>, B> {
215    fn to_backend_ref(&self) -> VmpPMatBackendRef<'_, B> {
216        VmpPMat {
217            data: B::view_ref(&self.data),
218            shape: self.shape,
219            _phantom: PhantomData,
220        }
221    }
222}
223
224/// Reborrow an already backend-borrowed `VmpPMat` as a shared backend-native view.
225pub trait VmpPMatReborrowBackendRef<B: Backend> {
226    fn reborrow_backend_ref(&self) -> VmpPMatBackendRef<'_, B>;
227}
228
229impl<'b, B: Backend + 'b> VmpPMatReborrowBackendRef<B> for VmpPMat<B::BufMut<'b>, B> {
230    fn reborrow_backend_ref(&self) -> VmpPMatBackendRef<'_, B> {
231        VmpPMat {
232            data: B::view_ref_mut(&self.data),
233            shape: self.shape,
234            _phantom: std::marker::PhantomData,
235        }
236    }
237}
238
239/// Mutably borrow a backend-owned `VmpPMat` using the backend's native view type.
240pub trait VmpPMatToBackendMut<B: Backend> {
241    fn to_backend_mut(&mut self) -> VmpPMatBackendMut<'_, B>;
242}
243
244impl<B: Backend> VmpPMatToBackendMut<B> for VmpPMat<B::OwnedBuf, B> {
245    fn to_backend_mut(&mut self) -> VmpPMatBackendMut<'_, B> {
246        VmpPMat {
247            data: B::view_mut(&mut self.data),
248            shape: self.shape,
249            _phantom: std::marker::PhantomData,
250        }
251    }
252}
253
254impl<'b, B: Backend + 'b> VmpPMatToBackendMut<B> for &mut VmpPMat<B::BufMut<'b>, B> {
255    fn to_backend_mut(&mut self) -> VmpPMatBackendMut<'_, B> {
256        vmp_pmat_backend_mut_from_mut::<B>(self)
257    }
258}
259
260/// Reborrow an already backend-borrowed `VmpPMat` as a mutable backend-native view.
261pub trait VmpPMatReborrowBackendMut<B: Backend> {
262    fn reborrow_backend_mut(&mut self) -> VmpPMatBackendMut<'_, B>;
263}
264
265impl<'b, B: Backend + 'b> VmpPMatReborrowBackendMut<B> for VmpPMat<B::BufMut<'b>, B> {
266    fn reborrow_backend_mut(&mut self) -> VmpPMatBackendMut<'_, B> {
267        vmp_pmat_backend_mut_from_mut::<B>(self)
268    }
269}
270
271impl<D: Data, B: Backend> VmpPMat<D, B> {
272    pub fn from_data(data: D, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self {
273        Self {
274            data,
275            shape: VmpPMatShape::new(n, rows, cols_in, cols_out, size),
276            _phantom: PhantomData,
277        }
278    }
279}