// poulpy_hal/delegates/vmp_pmat.rs

1use crate::{
2    api::{
3        VmpApplyDft, VmpApplyDftTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftAddTmpBytes,
4        VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPMatBytesOf, VmpPMatFromBytes, VmpPrepare, VmpPrepareTmpBytes,
5    },
6    layouts::{
7        Backend, MatZnxToRef, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, VmpPMatOwned, VmpPMatToMut,
8        VmpPMatToRef,
9    },
10    oep::{
11        VmpApplyDftImpl, VmpApplyDftTmpBytesImpl, VmpApplyDftToDftAddImpl, VmpApplyDftToDftAddTmpBytesImpl, VmpApplyDftToDftImpl,
12        VmpApplyDftToDftTmpBytesImpl, VmpPMatAllocBytesImpl, VmpPMatAllocImpl, VmpPMatFromBytesImpl, VmpPrepareImpl,
13        VmpPrepareTmpBytesImpl,
14    },
15};
16
17impl<B> VmpPMatAlloc<B> for Module<B>
18where
19    B: Backend + VmpPMatAllocImpl<B>,
20{
21    fn vmp_pmat_alloc(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned<B> {
22        B::vmp_pmat_alloc_impl(self.n(), rows, cols_in, cols_out, size)
23    }
24}
25
26impl<B> VmpPMatBytesOf for Module<B>
27where
28    B: Backend + VmpPMatAllocBytesImpl<B>,
29{
30    fn bytes_of_vmp_pmat(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
31        B::vmp_pmat_bytes_of_impl(self.n(), rows, cols_in, cols_out, size)
32    }
33}
34
35impl<B> VmpPMatFromBytes<B> for Module<B>
36where
37    B: Backend + VmpPMatFromBytesImpl<B>,
38{
39    fn vmp_pmat_from_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize, bytes: Vec<u8>) -> VmpPMatOwned<B> {
40        B::vmp_pmat_from_bytes_impl(self.n(), rows, cols_in, cols_out, size, bytes)
41    }
42}
43
44impl<B> VmpPrepareTmpBytes for Module<B>
45where
46    B: Backend + VmpPrepareTmpBytesImpl<B>,
47{
48    fn vmp_prepare_tmp_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
49        B::vmp_prepare_tmp_bytes_impl(self, rows, cols_in, cols_out, size)
50    }
51}
52
53impl<B> VmpPrepare<B> for Module<B>
54where
55    B: Backend + VmpPrepareImpl<B>,
56{
57    fn vmp_prepare<R, A>(&self, res: &mut R, a: &A, scratch: &mut Scratch<B>)
58    where
59        R: VmpPMatToMut<B>,
60        A: MatZnxToRef,
61    {
62        B::vmp_prepare_impl(self, res, a, scratch)
63    }
64}
65
66impl<B> VmpApplyDftTmpBytes for Module<B>
67where
68    B: Backend + VmpApplyDftTmpBytesImpl<B>,
69{
70    fn vmp_apply_dft_tmp_bytes(
71        &self,
72        res_size: usize,
73        a_size: usize,
74        b_rows: usize,
75        b_cols_in: usize,
76        b_cols_out: usize,
77        b_size: usize,
78    ) -> usize {
79        B::vmp_apply_dft_tmp_bytes_impl(
80            self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size,
81        )
82    }
83}
84
85impl<B> VmpApplyDft<B> for Module<B>
86where
87    B: Backend + VmpApplyDftImpl<B>,
88{
89    fn vmp_apply_dft<R, A, C>(&self, res: &mut R, a: &A, b: &C, scratch: &mut Scratch<B>)
90    where
91        R: VecZnxDftToMut<B>,
92        A: VecZnxToRef,
93        C: VmpPMatToRef<B>,
94    {
95        B::vmp_apply_dft_impl(self, res, a, b, scratch);
96    }
97}
98
99impl<B> VmpApplyDftToDftTmpBytes for Module<B>
100where
101    B: Backend + VmpApplyDftToDftTmpBytesImpl<B>,
102{
103    fn vmp_apply_dft_to_dft_tmp_bytes(
104        &self,
105        res_size: usize,
106        a_size: usize,
107        b_rows: usize,
108        b_cols_in: usize,
109        b_cols_out: usize,
110        b_size: usize,
111    ) -> usize {
112        B::vmp_apply_dft_to_dft_tmp_bytes_impl(
113            self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size,
114        )
115    }
116}
117
118impl<B> VmpApplyDftToDft<B> for Module<B>
119where
120    B: Backend + VmpApplyDftToDftImpl<B>,
121{
122    fn vmp_apply_dft_to_dft<R, A, C>(&self, res: &mut R, a: &A, b: &C, scratch: &mut Scratch<B>)
123    where
124        R: VecZnxDftToMut<B>,
125        A: VecZnxDftToRef<B>,
126        C: VmpPMatToRef<B>,
127    {
128        B::vmp_apply_dft_to_dft_impl(self, res, a, b, scratch);
129    }
130}
131
132impl<B> VmpApplyDftToDftAddTmpBytes for Module<B>
133where
134    B: Backend + VmpApplyDftToDftAddTmpBytesImpl<B>,
135{
136    fn vmp_apply_dft_to_dft_add_tmp_bytes(
137        &self,
138        res_size: usize,
139        a_size: usize,
140        b_rows: usize,
141        b_cols_in: usize,
142        b_cols_out: usize,
143        b_size: usize,
144    ) -> usize {
145        B::vmp_apply_dft_to_dft_add_tmp_bytes_impl(
146            self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size,
147        )
148    }
149}
150
151impl<B> VmpApplyDftToDftAdd<B> for Module<B>
152where
153    B: Backend + VmpApplyDftToDftAddImpl<B>,
154{
155    fn vmp_apply_dft_to_dft_add<R, A, C>(&self, res: &mut R, a: &A, b: &C, scale: usize, scratch: &mut Scratch<B>)
156    where
157        R: VecZnxDftToMut<B>,
158        A: VecZnxDftToRef<B>,
159        C: VmpPMatToRef<B>,
160    {
161        B::vmp_apply_dft_to_dft_add_impl(self, res, a, b, scale, scratch);
162    }
163}