// poulpy_hal/delegates/vmp_pmat.rs

1use crate::{
2    api::{
3        VmpApply, VmpApplyAdd, VmpApplyAddTmpBytes, VmpApplyTmpBytes, VmpPMatAlloc, VmpPMatAllocBytes, VmpPMatFromBytes,
4        VmpPrepare, VmpPrepareTmpBytes,
5    },
6    layouts::{Backend, MatZnxToRef, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VmpPMatOwned, VmpPMatToMut, VmpPMatToRef},
7    oep::{
8        VmpApplyAddImpl, VmpApplyAddTmpBytesImpl, VmpApplyImpl, VmpApplyTmpBytesImpl, VmpPMatAllocBytesImpl, VmpPMatAllocImpl,
9        VmpPMatFromBytesImpl, VmpPMatPrepareImpl, VmpPrepareTmpBytesImpl,
10    },
11};
12
13impl<B> VmpPMatAlloc<B> for Module<B>
14where
15    B: Backend + VmpPMatAllocImpl<B>,
16{
17    fn vmp_pmat_alloc(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned<B> {
18        B::vmp_pmat_alloc_impl(n, rows, cols_in, cols_out, size)
19    }
20}
21
22impl<B> VmpPMatAllocBytes for Module<B>
23where
24    B: Backend + VmpPMatAllocBytesImpl<B>,
25{
26    fn vmp_pmat_alloc_bytes(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
27        B::vmp_pmat_alloc_bytes_impl(n, rows, cols_in, cols_out, size)
28    }
29}
30
31impl<B> VmpPMatFromBytes<B> for Module<B>
32where
33    B: Backend + VmpPMatFromBytesImpl<B>,
34{
35    fn vmp_pmat_from_bytes(
36        &self,
37        n: usize,
38        rows: usize,
39        cols_in: usize,
40        cols_out: usize,
41        size: usize,
42        bytes: Vec<u8>,
43    ) -> VmpPMatOwned<B> {
44        B::vmp_pmat_from_bytes_impl(n, rows, cols_in, cols_out, size, bytes)
45    }
46}
47
48impl<B> VmpPrepareTmpBytes for Module<B>
49where
50    B: Backend + VmpPrepareTmpBytesImpl<B>,
51{
52    fn vmp_prepare_tmp_bytes(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
53        B::vmp_prepare_tmp_bytes_impl(self, n, rows, cols_in, cols_out, size)
54    }
55}
56
57impl<B> VmpPrepare<B> for Module<B>
58where
59    B: Backend + VmpPMatPrepareImpl<B>,
60{
61    fn vmp_prepare<R, A>(&self, res: &mut R, a: &A, scratch: &mut Scratch<B>)
62    where
63        R: VmpPMatToMut<B>,
64        A: MatZnxToRef,
65    {
66        B::vmp_prepare_impl(self, res, a, scratch)
67    }
68}
69
70impl<B> VmpApplyTmpBytes for Module<B>
71where
72    B: Backend + VmpApplyTmpBytesImpl<B>,
73{
74    fn vmp_apply_tmp_bytes(
75        &self,
76        n: usize,
77        res_size: usize,
78        a_size: usize,
79        b_rows: usize,
80        b_cols_in: usize,
81        b_cols_out: usize,
82        b_size: usize,
83    ) -> usize {
84        B::vmp_apply_tmp_bytes_impl(
85            self, n, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size,
86        )
87    }
88}
89
90impl<B> VmpApply<B> for Module<B>
91where
92    B: Backend + VmpApplyImpl<B>,
93{
94    fn vmp_apply<R, A, C>(&self, res: &mut R, a: &A, b: &C, scratch: &mut Scratch<B>)
95    where
96        R: VecZnxDftToMut<B>,
97        A: VecZnxDftToRef<B>,
98        C: VmpPMatToRef<B>,
99    {
100        B::vmp_apply_impl(self, res, a, b, scratch);
101    }
102}
103
104impl<B> VmpApplyAddTmpBytes for Module<B>
105where
106    B: Backend + VmpApplyAddTmpBytesImpl<B>,
107{
108    fn vmp_apply_add_tmp_bytes(
109        &self,
110        n: usize,
111        res_size: usize,
112        a_size: usize,
113        b_rows: usize,
114        b_cols_in: usize,
115        b_cols_out: usize,
116        b_size: usize,
117    ) -> usize {
118        B::vmp_apply_add_tmp_bytes_impl(
119            self, n, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size,
120        )
121    }
122}
123
124impl<B> VmpApplyAdd<B> for Module<B>
125where
126    B: Backend + VmpApplyAddImpl<B>,
127{
128    fn vmp_apply_add<R, A, C>(&self, res: &mut R, a: &A, b: &C, scale: usize, scratch: &mut Scratch<B>)
129    where
130        R: VecZnxDftToMut<B>,
131        A: VecZnxDftToRef<B>,
132        C: VmpPMatToRef<B>,
133    {
134        B::vmp_apply_add_impl(self, res, a, b, scale, scratch);
135    }
136}