// poulpy_hal/delegates/vmp_pmat.rs

1use crate::{
2    api::{
3        VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftAddTmpBytes, VmpApplyDftToDftTmpBytes, VmpPMatAlloc,
4        VmpPMatAllocBytes, VmpPMatFromBytes, VmpPrepare, VmpPrepareTmpBytes,
5    },
6    layouts::{Backend, MatZnxToRef, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VmpPMatOwned, VmpPMatToMut, VmpPMatToRef},
7    oep::{
8        VmpApplyDftToDftAddImpl, VmpApplyDftToDftAddTmpBytesImpl, VmpApplyDftToDftImpl, VmpApplyDftToDftTmpBytesImpl,
9        VmpPMatAllocBytesImpl, VmpPMatAllocImpl, VmpPMatFromBytesImpl, VmpPMatPrepareImpl, VmpPrepareTmpBytesImpl,
10    },
11};
12
13impl<B> VmpPMatAlloc<B> for Module<B>
14where
15    B: Backend + VmpPMatAllocImpl<B>,
16{
17    fn vmp_pmat_alloc(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned<B> {
18        B::vmp_pmat_alloc_impl(self.n(), rows, cols_in, cols_out, size)
19    }
20}
21
22impl<B> VmpPMatAllocBytes for Module<B>
23where
24    B: Backend + VmpPMatAllocBytesImpl<B>,
25{
26    fn vmp_pmat_alloc_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
27        B::vmp_pmat_alloc_bytes_impl(self.n(), rows, cols_in, cols_out, size)
28    }
29}
30
31impl<B> VmpPMatFromBytes<B> for Module<B>
32where
33    B: Backend + VmpPMatFromBytesImpl<B>,
34{
35    fn vmp_pmat_from_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize, bytes: Vec<u8>) -> VmpPMatOwned<B> {
36        B::vmp_pmat_from_bytes_impl(self.n(), rows, cols_in, cols_out, size, bytes)
37    }
38}
39
40impl<B> VmpPrepareTmpBytes for Module<B>
41where
42    B: Backend + VmpPrepareTmpBytesImpl<B>,
43{
44    fn vmp_prepare_tmp_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
45        B::vmp_prepare_tmp_bytes_impl(self, rows, cols_in, cols_out, size)
46    }
47}
48
49impl<B> VmpPrepare<B> for Module<B>
50where
51    B: Backend + VmpPMatPrepareImpl<B>,
52{
53    fn vmp_prepare<R, A>(&self, res: &mut R, a: &A, scratch: &mut Scratch<B>)
54    where
55        R: VmpPMatToMut<B>,
56        A: MatZnxToRef,
57    {
58        B::vmp_prepare_impl(self, res, a, scratch)
59    }
60}
61
62impl<B> VmpApplyDftToDftTmpBytes for Module<B>
63where
64    B: Backend + VmpApplyDftToDftTmpBytesImpl<B>,
65{
66    fn vmp_apply_dft_to_dft_tmp_bytes(
67        &self,
68        res_size: usize,
69        a_size: usize,
70        b_rows: usize,
71        b_cols_in: usize,
72        b_cols_out: usize,
73        b_size: usize,
74    ) -> usize {
75        B::vmp_apply_dft_to_dft_tmp_bytes_impl(
76            self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size,
77        )
78    }
79}
80
81impl<B> VmpApplyDftToDft<B> for Module<B>
82where
83    B: Backend + VmpApplyDftToDftImpl<B>,
84{
85    fn vmp_apply_dft_to_dft<R, A, C>(&self, res: &mut R, a: &A, b: &C, scratch: &mut Scratch<B>)
86    where
87        R: VecZnxDftToMut<B>,
88        A: VecZnxDftToRef<B>,
89        C: VmpPMatToRef<B>,
90    {
91        B::vmp_apply_dft_to_dft_impl(self, res, a, b, scratch);
92    }
93}
94
95impl<B> VmpApplyDftToDftAddTmpBytes for Module<B>
96where
97    B: Backend + VmpApplyDftToDftAddTmpBytesImpl<B>,
98{
99    fn vmp_apply_dft_to_dft_add_tmp_bytes(
100        &self,
101        res_size: usize,
102        a_size: usize,
103        b_rows: usize,
104        b_cols_in: usize,
105        b_cols_out: usize,
106        b_size: usize,
107    ) -> usize {
108        B::vmp_apply_dft_to_dft_add_tmp_bytes_impl(
109            self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size,
110        )
111    }
112}
113
114impl<B> VmpApplyDftToDftAdd<B> for Module<B>
115where
116    B: Backend + VmpApplyDftToDftAddImpl<B>,
117{
118    fn vmp_apply_dft_to_dft_add<R, A, C>(&self, res: &mut R, a: &A, b: &C, scale: usize, scratch: &mut Scratch<B>)
119    where
120        R: VecZnxDftToMut<B>,
121        A: VecZnxDftToRef<B>,
122        C: VmpPMatToRef<B>,
123    {
124        B::vmp_apply_dft_to_dft_add_impl(self, res, a, b, scale, scratch);
125    }
126}