1use crate::{
2 api::{
3 VmpApply, VmpApplyAdd, VmpApplyAddTmpBytes, VmpApplyTmpBytes, VmpPMatAlloc, VmpPMatAllocBytes, VmpPMatFromBytes,
4 VmpPrepare, VmpPrepareTmpBytes,
5 },
6 layouts::{Backend, MatZnxToRef, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VmpPMatOwned, VmpPMatToMut, VmpPMatToRef},
7 oep::{
8 VmpApplyAddImpl, VmpApplyAddTmpBytesImpl, VmpApplyImpl, VmpApplyTmpBytesImpl, VmpPMatAllocBytesImpl, VmpPMatAllocImpl,
9 VmpPMatFromBytesImpl, VmpPMatPrepareImpl, VmpPrepareTmpBytesImpl,
10 },
11};
12
13impl<B> VmpPMatAlloc<B> for Module<B>
14where
15 B: Backend + VmpPMatAllocImpl<B>,
16{
17 fn vmp_pmat_alloc(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned<B> {
18 B::vmp_pmat_alloc_impl(n, rows, cols_in, cols_out, size)
19 }
20}
21
22impl<B> VmpPMatAllocBytes for Module<B>
23where
24 B: Backend + VmpPMatAllocBytesImpl<B>,
25{
26 fn vmp_pmat_alloc_bytes(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
27 B::vmp_pmat_alloc_bytes_impl(n, rows, cols_in, cols_out, size)
28 }
29}
30
31impl<B> VmpPMatFromBytes<B> for Module<B>
32where
33 B: Backend + VmpPMatFromBytesImpl<B>,
34{
35 fn vmp_pmat_from_bytes(
36 &self,
37 n: usize,
38 rows: usize,
39 cols_in: usize,
40 cols_out: usize,
41 size: usize,
42 bytes: Vec<u8>,
43 ) -> VmpPMatOwned<B> {
44 B::vmp_pmat_from_bytes_impl(n, rows, cols_in, cols_out, size, bytes)
45 }
46}
47
48impl<B> VmpPrepareTmpBytes for Module<B>
49where
50 B: Backend + VmpPrepareTmpBytesImpl<B>,
51{
52 fn vmp_prepare_tmp_bytes(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
53 B::vmp_prepare_tmp_bytes_impl(self, n, rows, cols_in, cols_out, size)
54 }
55}
56
57impl<B> VmpPrepare<B> for Module<B>
58where
59 B: Backend + VmpPMatPrepareImpl<B>,
60{
61 fn vmp_prepare<R, A>(&self, res: &mut R, a: &A, scratch: &mut Scratch<B>)
62 where
63 R: VmpPMatToMut<B>,
64 A: MatZnxToRef,
65 {
66 B::vmp_prepare_impl(self, res, a, scratch)
67 }
68}
69
70impl<B> VmpApplyTmpBytes for Module<B>
71where
72 B: Backend + VmpApplyTmpBytesImpl<B>,
73{
74 fn vmp_apply_tmp_bytes(
75 &self,
76 n: usize,
77 res_size: usize,
78 a_size: usize,
79 b_rows: usize,
80 b_cols_in: usize,
81 b_cols_out: usize,
82 b_size: usize,
83 ) -> usize {
84 B::vmp_apply_tmp_bytes_impl(
85 self, n, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size,
86 )
87 }
88}
89
90impl<B> VmpApply<B> for Module<B>
91where
92 B: Backend + VmpApplyImpl<B>,
93{
94 fn vmp_apply<R, A, C>(&self, res: &mut R, a: &A, b: &C, scratch: &mut Scratch<B>)
95 where
96 R: VecZnxDftToMut<B>,
97 A: VecZnxDftToRef<B>,
98 C: VmpPMatToRef<B>,
99 {
100 B::vmp_apply_impl(self, res, a, b, scratch);
101 }
102}
103
104impl<B> VmpApplyAddTmpBytes for Module<B>
105where
106 B: Backend + VmpApplyAddTmpBytesImpl<B>,
107{
108 fn vmp_apply_add_tmp_bytes(
109 &self,
110 n: usize,
111 res_size: usize,
112 a_size: usize,
113 b_rows: usize,
114 b_cols_in: usize,
115 b_cols_out: usize,
116 b_size: usize,
117 ) -> usize {
118 B::vmp_apply_add_tmp_bytes_impl(
119 self, n, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size,
120 )
121 }
122}
123
124impl<B> VmpApplyAdd<B> for Module<B>
125where
126 B: Backend + VmpApplyAddImpl<B>,
127{
128 fn vmp_apply_add<R, A, C>(&self, res: &mut R, a: &A, b: &C, scale: usize, scratch: &mut Scratch<B>)
129 where
130 R: VecZnxDftToMut<B>,
131 A: VecZnxDftToRef<B>,
132 C: VmpPMatToRef<B>,
133 {
134 B::vmp_apply_add_impl(self, res, a, b, scale, scratch);
135 }
136}