// poulpy_hal/delegates/vmp_pmat.rs

1use crate::{
2    api::{
3        VmpApplyDft, VmpApplyDftTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAccumulate, VmpApplyDftToDftAccumulateTmpBytes,
4        VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPMatBytesOf, VmpPrepare, VmpPrepareTmpBytes, VmpZero,
5    },
6    layouts::{
7        Backend, MatZnxBackendRef, Module, ScratchArena, VecZnxBackendRef, VecZnxDftBackendMut, VecZnxDftBackendRef,
8        VecZnxDftToBackendMut, VmpPMatBackendMut, VmpPMatBackendRef, VmpPMatOwned,
9    },
10    oep::HalVmpImpl,
11};
12
/// Implements the given trait for `Module<B>` for every backend
/// `B: Backend + HalVmpImpl<B>`, emitting the supplied method item(s) verbatim
/// as the body of that impl.
///
/// Usage: `impl_vmp_delegate!(SomeTrait<B>, fn method(&self, ..) { .. });`.
/// Inside the method bodies, `B` names the impl's backend type parameter.
macro_rules! impl_vmp_delegate {
    ($tr:ty, $($method:item)+) => (
        impl<B: Backend + HalVmpImpl<B>> $tr for Module<B> {
            $($method)+
        }
    );
}
23
24impl<B: Backend> VmpPMatAlloc<B> for Module<B> {
25    fn vmp_pmat_alloc(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned<B> {
26        VmpPMatOwned::alloc(self.n(), rows, cols_in, cols_out, size)
27    }
28}
29
30impl<B: Backend> VmpPMatBytesOf for Module<B> {
31    fn bytes_of_vmp_pmat(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
32        B::bytes_of_vmp_pmat(self.n(), rows, cols_in, cols_out, size)
33    }
34}
35
36impl_vmp_delegate!(
37    VmpPrepareTmpBytes,
38    fn vmp_prepare_tmp_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
39        B::vmp_prepare_tmp_bytes(self, rows, cols_in, cols_out, size)
40    }
41);
42
43impl_vmp_delegate!(
44    VmpPrepare<B>,
45    fn vmp_prepare(&self, res: &mut VmpPMatBackendMut<'_, B>, a: &MatZnxBackendRef<'_, B>, scratch: &mut ScratchArena<'_, B>) {
46        B::vmp_prepare(self, res, a, scratch);
47    }
48);
49
50impl_vmp_delegate!(
51    VmpApplyDftTmpBytes,
52    fn vmp_apply_dft_tmp_bytes(
53        &self,
54        res_size: usize,
55        a_size: usize,
56        b_rows: usize,
57        b_cols_in: usize,
58        b_cols_out: usize,
59        b_size: usize,
60    ) -> usize {
61        B::vmp_apply_dft_tmp_bytes(self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size)
62    }
63);
64
65impl_vmp_delegate!(
66    VmpApplyDft<B>,
67    fn vmp_apply_dft<R>(
68        &self,
69        res: &mut R,
70        a: &VecZnxBackendRef<'_, B>,
71        b: &VmpPMatBackendRef<'_, B>,
72        scratch: &mut ScratchArena<'_, B>,
73    ) where
74        R: VecZnxDftToBackendMut<B>,
75    {
76        B::vmp_apply_dft(self, res, a, b, scratch)
77    }
78);
79
80impl_vmp_delegate!(
81    VmpApplyDftToDftTmpBytes,
82    fn vmp_apply_dft_to_dft_tmp_bytes(
83        &self,
84        res_size: usize,
85        a_size: usize,
86        b_rows: usize,
87        b_cols_in: usize,
88        b_cols_out: usize,
89        b_size: usize,
90    ) -> usize {
91        B::vmp_apply_dft_to_dft_tmp_bytes(self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size)
92    }
93);
94
95impl_vmp_delegate!(
96    VmpApplyDftToDft<B>,
97    fn vmp_apply_dft_to_dft(
98        &self,
99        res: &mut VecZnxDftBackendMut<'_, B>,
100        a: &VecZnxDftBackendRef<'_, B>,
101        b: &VmpPMatBackendRef<'_, B>,
102        limb_offset: usize,
103        scratch: &mut ScratchArena<'_, B>,
104    ) {
105        B::vmp_apply_dft_to_dft(self, res, a, b, limb_offset, scratch)
106    }
107);
108
109impl_vmp_delegate!(
110    VmpApplyDftToDftAccumulateTmpBytes,
111    fn vmp_apply_dft_to_dft_accumulate_tmp_bytes(
112        &self,
113        res_size: usize,
114        a_size: usize,
115        b_rows: usize,
116        b_cols_in: usize,
117        b_cols_out: usize,
118        b_size: usize,
119    ) -> usize {
120        B::vmp_apply_dft_to_dft_accumulate_tmp_bytes(self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size)
121    }
122);
123
124impl_vmp_delegate!(
125    VmpApplyDftToDftAccumulate<B>,
126    fn vmp_apply_dft_to_dft_accumulate(
127        &self,
128        res: &mut VecZnxDftBackendMut<'_, B>,
129        a: &VecZnxDftBackendRef<'_, B>,
130        b: &VmpPMatBackendRef<'_, B>,
131        limb_offset: usize,
132        scratch: &mut ScratchArena<'_, B>,
133    ) {
134        B::vmp_apply_dft_to_dft_accumulate(self, res, a, b, limb_offset, scratch);
135    }
136);
137
138impl_vmp_delegate!(
139    VmpZero<B>,
140    fn vmp_zero(&self, res: &mut VmpPMatBackendMut<'_, B>) {
141        B::vmp_zero(self, res);
142    }
143);