Skip to main content

poulpy_hal/delegates/
convolution.rs

use crate::{
    api::{CnvPVecAlloc, CnvPVecBytesOf, Convolution},
    layouts::{
        Backend, CnvPVecL, CnvPVecLBackendMut, CnvPVecLBackendRef, CnvPVecR, CnvPVecRBackendMut, CnvPVecRBackendRef, Module,
        ScratchArena, VecZnxBackendRef, VecZnxBigBackendMut, VecZnxDftBackendMut,
    },
    oep::HalConvolutionImpl,
};

/// Generates an `impl` of a `Module<BE>`-side API trait whose methods are supplied by the caller.
///
/// Usage:
///
/// ```ignore
/// impl_convolution_delegate!(
///     SomeTrait<BE>,
///     fn a(&self, ..) -> .. { .. },
///     fn b(&self, ..) { .. },   // trailing comma is optional
/// );
/// ```
///
/// The first argument is the trait being implemented (it may name the generic `BE`). It is
/// followed by one or more comma-separated method items. The expansion is
///
/// ```ignore
/// impl<BE: Backend> SomeTrait<BE> for Module<BE>
/// where
///     BE: HalConvolutionImpl<BE>,
/// { /* the supplied methods */ }
/// ```
///
/// The `HalConvolutionImpl<BE>` bound is what lets the method bodies forward to the backend with
/// `<BE as HalConvolutionImpl<BE>>::…`.
macro_rules! impl_convolution_delegate {
    ($trait:ty, $($body:item),+ $(,)?) => {
        impl<BE: Backend> $trait for Module<BE>
        where
            BE: HalConvolutionImpl<BE>,
        {
            $($body)+
        }
    };
}

21impl<BE: Backend> CnvPVecAlloc<BE> for Module<BE> {
22    fn cnv_pvec_left_alloc(&self, cols: usize, size: usize) -> CnvPVecL<BE::OwnedBuf, BE> {
23        CnvPVecL::alloc(self.n(), cols, size)
24    }
25
26    fn cnv_pvec_right_alloc(&self, cols: usize, size: usize) -> CnvPVecR<BE::OwnedBuf, BE> {
27        CnvPVecR::alloc(self.n(), cols, size)
28    }
29}
30
31impl<BE: Backend> CnvPVecBytesOf for Module<BE> {
32    fn bytes_of_cnv_pvec_left(&self, cols: usize, size: usize) -> usize {
33        BE::bytes_of_cnv_pvec_left(self.n(), cols, size)
34    }
35
36    fn bytes_of_cnv_pvec_right(&self, cols: usize, size: usize) -> usize {
37        BE::bytes_of_cnv_pvec_right(self.n(), cols, size)
38    }
39}
40
41impl_convolution_delegate!(
42    Convolution<BE>,
43    fn cnv_prepare_left_tmp_bytes(&self, res_size: usize, a_size: usize) -> usize {
44        <BE as HalConvolutionImpl<BE>>::cnv_prepare_left_tmp_bytes(self, res_size, a_size)
45    },
46    fn cnv_prepare_left(
47        &self,
48        res: &mut CnvPVecLBackendMut<'_, BE>,
49        a: &VecZnxBackendRef<'_, BE>,
50        mask: i64,
51        scratch: &mut ScratchArena<'_, BE>,
52    ) {
53        <BE as HalConvolutionImpl<BE>>::cnv_prepare_left(self, res, a, mask, scratch);
54    },
55    fn cnv_prepare_right_tmp_bytes(&self, res_size: usize, a_size: usize) -> usize {
56        <BE as HalConvolutionImpl<BE>>::cnv_prepare_right_tmp_bytes(self, res_size, a_size)
57    },
58    fn cnv_prepare_right(
59        &self,
60        res: &mut CnvPVecRBackendMut<'_, BE>,
61        a: &VecZnxBackendRef<'_, BE>,
62        mask: i64,
63        scratch: &mut ScratchArena<'_, BE>,
64    ) {
65        <BE as HalConvolutionImpl<BE>>::cnv_prepare_right(self, res, a, mask, scratch);
66    },
67    fn cnv_apply_dft_tmp_bytes(&self, cnv_offset: usize, res_size: usize, a_size: usize, b_size: usize) -> usize {
68        <BE as HalConvolutionImpl<BE>>::cnv_apply_dft_tmp_bytes(self, cnv_offset, res_size, a_size, b_size)
69    },
70    fn cnv_by_const_apply_tmp_bytes(&self, res_size: usize, cnv_offset: usize, a_size: usize, b_size: usize) -> usize {
71        <BE as HalConvolutionImpl<BE>>::cnv_by_const_apply_tmp_bytes(self, res_size, cnv_offset, a_size, b_size)
72    },
73    fn cnv_by_const_apply(
74        &self,
75        cnv_offset: usize,
76        res: &mut VecZnxBigBackendMut<'_, BE>,
77        res_col: usize,
78        a: &VecZnxBackendRef<'_, BE>,
79        a_col: usize,
80        b: &VecZnxBackendRef<'_, BE>,
81        b_col: usize,
82        b_coeff: usize,
83        scratch: &mut ScratchArena<'_, BE>,
84    ) {
85        <BE as HalConvolutionImpl<BE>>::cnv_by_const_apply(self, cnv_offset, res, res_col, a, a_col, b, b_col, b_coeff, scratch)
86    },
87    fn cnv_apply_dft(
88        &self,
89        cnv_offset: usize,
90        res: &mut VecZnxDftBackendMut<'_, BE>,
91        res_col: usize,
92        a: &CnvPVecLBackendRef<'_, BE>,
93        a_col: usize,
94        b: &CnvPVecRBackendRef<'_, BE>,
95        b_col: usize,
96        scratch: &mut ScratchArena<'_, BE>,
97    ) {
98        <BE as HalConvolutionImpl<BE>>::cnv_apply_dft(self, cnv_offset, res, res_col, a, a_col, b, b_col, scratch)
99    },
100    fn cnv_pairwise_apply_dft_tmp_bytes(&self, cnv_offset: usize, res_size: usize, a_size: usize, b_size: usize) -> usize {
101        <BE as HalConvolutionImpl<BE>>::cnv_pairwise_apply_dft_tmp_bytes(self, cnv_offset, res_size, a_size, b_size)
102    },
103    fn cnv_pairwise_apply_dft(
104        &self,
105        cnv_offset: usize,
106        res: &mut VecZnxDftBackendMut<'_, BE>,
107        res_col: usize,
108        a: &CnvPVecLBackendRef<'_, BE>,
109        b: &CnvPVecRBackendRef<'_, BE>,
110        i: usize,
111        j: usize,
112        scratch: &mut ScratchArena<'_, BE>,
113    ) {
114        <BE as HalConvolutionImpl<BE>>::cnv_pairwise_apply_dft(self, cnv_offset, res, res_col, a, b, i, j, scratch)
115    },
116    fn cnv_prepare_self_tmp_bytes(&self, res_size: usize, a_size: usize) -> usize {
117        <BE as HalConvolutionImpl<BE>>::cnv_prepare_self_tmp_bytes(self, res_size, a_size)
118    },
119    fn cnv_prepare_self(
120        &self,
121        left: &mut CnvPVecLBackendMut<'_, BE>,
122        right: &mut CnvPVecRBackendMut<'_, BE>,
123        a: &VecZnxBackendRef<'_, BE>,
124        mask: i64,
125        scratch: &mut ScratchArena<'_, BE>,
126    ) {
127        <BE as HalConvolutionImpl<BE>>::cnv_prepare_self(self, left, right, a, mask, scratch)
128    }
129);