1use crate::{
2 api::{CnvPVecAlloc, CnvPVecBytesOf, Convolution},
3 layouts::{
4 Backend, CnvPVecL, CnvPVecLBackendMut, CnvPVecLBackendRef, CnvPVecR, CnvPVecRBackendMut, CnvPVecRBackendRef, Module,
5 ScratchArena, VecZnxBackendRef, VecZnxBigBackendMut, VecZnxDftBackendMut,
6 },
7 oep::HalConvolutionImpl,
8};
9
// Emits a single `impl <api> for Module<BE>` block whose body is the list of
// comma-separated method items handed to the macro (a trailing comma is
// accepted). The generated impl is only available for backends that implement
// both `Backend` and `HalConvolutionImpl<BE>`.
//
// The generic parameter is always named `BE`; method bodies passed in by the
// caller may refer to it by that name.
macro_rules! impl_convolution_delegate {
    ($api:ty, $($method:item),+ $(,)?) => {
        impl<BE> $api for Module<BE>
        where
            BE: Backend + HalConvolutionImpl<BE>,
        {
            $($method)+
        }
    };
}
20
21impl<BE: Backend> CnvPVecAlloc<BE> for Module<BE> {
22 fn cnv_pvec_left_alloc(&self, cols: usize, size: usize) -> CnvPVecL<BE::OwnedBuf, BE> {
23 CnvPVecL::alloc(self.n(), cols, size)
24 }
25
26 fn cnv_pvec_right_alloc(&self, cols: usize, size: usize) -> CnvPVecR<BE::OwnedBuf, BE> {
27 CnvPVecR::alloc(self.n(), cols, size)
28 }
29}
30
31impl<BE: Backend> CnvPVecBytesOf for Module<BE> {
32 fn bytes_of_cnv_pvec_left(&self, cols: usize, size: usize) -> usize {
33 BE::bytes_of_cnv_pvec_left(self.n(), cols, size)
34 }
35
36 fn bytes_of_cnv_pvec_right(&self, cols: usize, size: usize) -> usize {
37 BE::bytes_of_cnv_pvec_right(self.n(), cols, size)
38 }
39}
40
41impl_convolution_delegate!(
42 Convolution<BE>,
43 fn cnv_prepare_left_tmp_bytes(&self, res_size: usize, a_size: usize) -> usize {
44 <BE as HalConvolutionImpl<BE>>::cnv_prepare_left_tmp_bytes(self, res_size, a_size)
45 },
46 fn cnv_prepare_left(
47 &self,
48 res: &mut CnvPVecLBackendMut<'_, BE>,
49 a: &VecZnxBackendRef<'_, BE>,
50 mask: i64,
51 scratch: &mut ScratchArena<'_, BE>,
52 ) {
53 <BE as HalConvolutionImpl<BE>>::cnv_prepare_left(self, res, a, mask, scratch);
54 },
55 fn cnv_prepare_right_tmp_bytes(&self, res_size: usize, a_size: usize) -> usize {
56 <BE as HalConvolutionImpl<BE>>::cnv_prepare_right_tmp_bytes(self, res_size, a_size)
57 },
58 fn cnv_prepare_right(
59 &self,
60 res: &mut CnvPVecRBackendMut<'_, BE>,
61 a: &VecZnxBackendRef<'_, BE>,
62 mask: i64,
63 scratch: &mut ScratchArena<'_, BE>,
64 ) {
65 <BE as HalConvolutionImpl<BE>>::cnv_prepare_right(self, res, a, mask, scratch);
66 },
67 fn cnv_apply_dft_tmp_bytes(&self, cnv_offset: usize, res_size: usize, a_size: usize, b_size: usize) -> usize {
68 <BE as HalConvolutionImpl<BE>>::cnv_apply_dft_tmp_bytes(self, cnv_offset, res_size, a_size, b_size)
69 },
70 fn cnv_by_const_apply_tmp_bytes(&self, res_size: usize, cnv_offset: usize, a_size: usize, b_size: usize) -> usize {
71 <BE as HalConvolutionImpl<BE>>::cnv_by_const_apply_tmp_bytes(self, res_size, cnv_offset, a_size, b_size)
72 },
73 fn cnv_by_const_apply(
74 &self,
75 cnv_offset: usize,
76 res: &mut VecZnxBigBackendMut<'_, BE>,
77 res_col: usize,
78 a: &VecZnxBackendRef<'_, BE>,
79 a_col: usize,
80 b: &VecZnxBackendRef<'_, BE>,
81 b_col: usize,
82 b_coeff: usize,
83 scratch: &mut ScratchArena<'_, BE>,
84 ) {
85 <BE as HalConvolutionImpl<BE>>::cnv_by_const_apply(self, cnv_offset, res, res_col, a, a_col, b, b_col, b_coeff, scratch)
86 },
87 fn cnv_apply_dft(
88 &self,
89 cnv_offset: usize,
90 res: &mut VecZnxDftBackendMut<'_, BE>,
91 res_col: usize,
92 a: &CnvPVecLBackendRef<'_, BE>,
93 a_col: usize,
94 b: &CnvPVecRBackendRef<'_, BE>,
95 b_col: usize,
96 scratch: &mut ScratchArena<'_, BE>,
97 ) {
98 <BE as HalConvolutionImpl<BE>>::cnv_apply_dft(self, cnv_offset, res, res_col, a, a_col, b, b_col, scratch)
99 },
100 fn cnv_pairwise_apply_dft_tmp_bytes(&self, cnv_offset: usize, res_size: usize, a_size: usize, b_size: usize) -> usize {
101 <BE as HalConvolutionImpl<BE>>::cnv_pairwise_apply_dft_tmp_bytes(self, cnv_offset, res_size, a_size, b_size)
102 },
103 fn cnv_pairwise_apply_dft(
104 &self,
105 cnv_offset: usize,
106 res: &mut VecZnxDftBackendMut<'_, BE>,
107 res_col: usize,
108 a: &CnvPVecLBackendRef<'_, BE>,
109 b: &CnvPVecRBackendRef<'_, BE>,
110 i: usize,
111 j: usize,
112 scratch: &mut ScratchArena<'_, BE>,
113 ) {
114 <BE as HalConvolutionImpl<BE>>::cnv_pairwise_apply_dft(self, cnv_offset, res, res_col, a, b, i, j, scratch)
115 },
116 fn cnv_prepare_self_tmp_bytes(&self, res_size: usize, a_size: usize) -> usize {
117 <BE as HalConvolutionImpl<BE>>::cnv_prepare_self_tmp_bytes(self, res_size, a_size)
118 },
119 fn cnv_prepare_self(
120 &self,
121 left: &mut CnvPVecLBackendMut<'_, BE>,
122 right: &mut CnvPVecRBackendMut<'_, BE>,
123 a: &VecZnxBackendRef<'_, BE>,
124 mask: i64,
125 scratch: &mut ScratchArena<'_, BE>,
126 ) {
127 <BE as HalConvolutionImpl<BE>>::cnv_prepare_self(self, left, right, a, mask, scratch)
128 }
129);