1use crate::{
2 api::{
3 VmpApplyDft, VmpApplyDftTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAccumulate, VmpApplyDftToDftAccumulateTmpBytes,
4 VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPMatBytesOf, VmpPrepare, VmpPrepareTmpBytes, VmpZero,
5 },
6 layouts::{
7 Backend, MatZnxBackendRef, Module, ScratchArena, VecZnxBackendRef, VecZnxDftBackendMut, VecZnxDftBackendRef,
8 VecZnxDftToBackendMut, VmpPMatBackendMut, VmpPMatBackendRef, VmpPMatOwned,
9 },
10 oep::HalVmpImpl,
11};
12
/// Generates `impl<B> $trait for Module<B> where B: Backend + HalVmpImpl<B>`.
///
/// Arguments: first the trait to implement, matched as a type so it may mention the
/// backend parameter `B` (e.g. `VmpPrepare<B>`); then one or more method items that make
/// up the impl body, written exactly as they would appear inside an `impl` block.
/// The macro contributes only the impl header and its bounds; the method bodies decide
/// what is actually called.
macro_rules! impl_vmp_delegate {
    ($target_trait:ty, $($method:item)+) => {
        impl<B> $target_trait for Module<B>
        where
            B: Backend + HalVmpImpl<B>,
        {
            $($method)+
        }
    };
}
23
24impl<B: Backend> VmpPMatAlloc<B> for Module<B> {
25 fn vmp_pmat_alloc(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned<B> {
26 VmpPMatOwned::alloc(self.n(), rows, cols_in, cols_out, size)
27 }
28}
29
30impl<B: Backend> VmpPMatBytesOf for Module<B> {
31 fn bytes_of_vmp_pmat(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
32 B::bytes_of_vmp_pmat(self.n(), rows, cols_in, cols_out, size)
33 }
34}
35
36impl_vmp_delegate!(
37 VmpPrepareTmpBytes,
38 fn vmp_prepare_tmp_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
39 B::vmp_prepare_tmp_bytes(self, rows, cols_in, cols_out, size)
40 }
41);
42
43impl_vmp_delegate!(
44 VmpPrepare<B>,
45 fn vmp_prepare(&self, res: &mut VmpPMatBackendMut<'_, B>, a: &MatZnxBackendRef<'_, B>, scratch: &mut ScratchArena<'_, B>) {
46 B::vmp_prepare(self, res, a, scratch);
47 }
48);
49
50impl_vmp_delegate!(
51 VmpApplyDftTmpBytes,
52 fn vmp_apply_dft_tmp_bytes(
53 &self,
54 res_size: usize,
55 a_size: usize,
56 b_rows: usize,
57 b_cols_in: usize,
58 b_cols_out: usize,
59 b_size: usize,
60 ) -> usize {
61 B::vmp_apply_dft_tmp_bytes(self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size)
62 }
63);
64
65impl_vmp_delegate!(
66 VmpApplyDft<B>,
67 fn vmp_apply_dft<R>(
68 &self,
69 res: &mut R,
70 a: &VecZnxBackendRef<'_, B>,
71 b: &VmpPMatBackendRef<'_, B>,
72 scratch: &mut ScratchArena<'_, B>,
73 ) where
74 R: VecZnxDftToBackendMut<B>,
75 {
76 B::vmp_apply_dft(self, res, a, b, scratch)
77 }
78);
79
80impl_vmp_delegate!(
81 VmpApplyDftToDftTmpBytes,
82 fn vmp_apply_dft_to_dft_tmp_bytes(
83 &self,
84 res_size: usize,
85 a_size: usize,
86 b_rows: usize,
87 b_cols_in: usize,
88 b_cols_out: usize,
89 b_size: usize,
90 ) -> usize {
91 B::vmp_apply_dft_to_dft_tmp_bytes(self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size)
92 }
93);
94
95impl_vmp_delegate!(
96 VmpApplyDftToDft<B>,
97 fn vmp_apply_dft_to_dft(
98 &self,
99 res: &mut VecZnxDftBackendMut<'_, B>,
100 a: &VecZnxDftBackendRef<'_, B>,
101 b: &VmpPMatBackendRef<'_, B>,
102 limb_offset: usize,
103 scratch: &mut ScratchArena<'_, B>,
104 ) {
105 B::vmp_apply_dft_to_dft(self, res, a, b, limb_offset, scratch)
106 }
107);
108
109impl_vmp_delegate!(
110 VmpApplyDftToDftAccumulateTmpBytes,
111 fn vmp_apply_dft_to_dft_accumulate_tmp_bytes(
112 &self,
113 res_size: usize,
114 a_size: usize,
115 b_rows: usize,
116 b_cols_in: usize,
117 b_cols_out: usize,
118 b_size: usize,
119 ) -> usize {
120 B::vmp_apply_dft_to_dft_accumulate_tmp_bytes(self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size)
121 }
122);
123
124impl_vmp_delegate!(
125 VmpApplyDftToDftAccumulate<B>,
126 fn vmp_apply_dft_to_dft_accumulate(
127 &self,
128 res: &mut VecZnxDftBackendMut<'_, B>,
129 a: &VecZnxDftBackendRef<'_, B>,
130 b: &VmpPMatBackendRef<'_, B>,
131 limb_offset: usize,
132 scratch: &mut ScratchArena<'_, B>,
133 ) {
134 B::vmp_apply_dft_to_dft_accumulate(self, res, a, b, limb_offset, scratch);
135 }
136);
137
138impl_vmp_delegate!(
139 VmpZero<B>,
140 fn vmp_zero(&self, res: &mut VmpPMatBackendMut<'_, B>) {
141 B::vmp_zero(self, res);
142 }
143);