1use crate::{
2 api::{
3 VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftAddTmpBytes, VmpApplyDftToDftTmpBytes, VmpPMatAlloc,
4 VmpPMatAllocBytes, VmpPMatFromBytes, VmpPrepare, VmpPrepareTmpBytes,
5 },
6 layouts::{Backend, MatZnxToRef, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VmpPMatOwned, VmpPMatToMut, VmpPMatToRef},
7 oep::{
8 VmpApplyDftToDftAddImpl, VmpApplyDftToDftAddTmpBytesImpl, VmpApplyDftToDftImpl, VmpApplyDftToDftTmpBytesImpl,
9 VmpPMatAllocBytesImpl, VmpPMatAllocImpl, VmpPMatFromBytesImpl, VmpPMatPrepareImpl, VmpPrepareTmpBytesImpl,
10 },
11};
12
13impl<B> VmpPMatAlloc<B> for Module<B>
14where
15 B: Backend + VmpPMatAllocImpl<B>,
16{
17 fn vmp_pmat_alloc(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned<B> {
18 B::vmp_pmat_alloc_impl(self.n(), rows, cols_in, cols_out, size)
19 }
20}
21
22impl<B> VmpPMatAllocBytes for Module<B>
23where
24 B: Backend + VmpPMatAllocBytesImpl<B>,
25{
26 fn vmp_pmat_alloc_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
27 B::vmp_pmat_alloc_bytes_impl(self.n(), rows, cols_in, cols_out, size)
28 }
29}
30
31impl<B> VmpPMatFromBytes<B> for Module<B>
32where
33 B: Backend + VmpPMatFromBytesImpl<B>,
34{
35 fn vmp_pmat_from_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize, bytes: Vec<u8>) -> VmpPMatOwned<B> {
36 B::vmp_pmat_from_bytes_impl(self.n(), rows, cols_in, cols_out, size, bytes)
37 }
38}
39
40impl<B> VmpPrepareTmpBytes for Module<B>
41where
42 B: Backend + VmpPrepareTmpBytesImpl<B>,
43{
44 fn vmp_prepare_tmp_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
45 B::vmp_prepare_tmp_bytes_impl(self, rows, cols_in, cols_out, size)
46 }
47}
48
49impl<B> VmpPrepare<B> for Module<B>
50where
51 B: Backend + VmpPMatPrepareImpl<B>,
52{
53 fn vmp_prepare<R, A>(&self, res: &mut R, a: &A, scratch: &mut Scratch<B>)
54 where
55 R: VmpPMatToMut<B>,
56 A: MatZnxToRef,
57 {
58 B::vmp_prepare_impl(self, res, a, scratch)
59 }
60}
61
62impl<B> VmpApplyDftToDftTmpBytes for Module<B>
63where
64 B: Backend + VmpApplyDftToDftTmpBytesImpl<B>,
65{
66 fn vmp_apply_dft_to_dft_tmp_bytes(
67 &self,
68 res_size: usize,
69 a_size: usize,
70 b_rows: usize,
71 b_cols_in: usize,
72 b_cols_out: usize,
73 b_size: usize,
74 ) -> usize {
75 B::vmp_apply_dft_to_dft_tmp_bytes_impl(
76 self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size,
77 )
78 }
79}
80
81impl<B> VmpApplyDftToDft<B> for Module<B>
82where
83 B: Backend + VmpApplyDftToDftImpl<B>,
84{
85 fn vmp_apply_dft_to_dft<R, A, C>(&self, res: &mut R, a: &A, b: &C, scratch: &mut Scratch<B>)
86 where
87 R: VecZnxDftToMut<B>,
88 A: VecZnxDftToRef<B>,
89 C: VmpPMatToRef<B>,
90 {
91 B::vmp_apply_dft_to_dft_impl(self, res, a, b, scratch);
92 }
93}
94
95impl<B> VmpApplyDftToDftAddTmpBytes for Module<B>
96where
97 B: Backend + VmpApplyDftToDftAddTmpBytesImpl<B>,
98{
99 fn vmp_apply_dft_to_dft_add_tmp_bytes(
100 &self,
101 res_size: usize,
102 a_size: usize,
103 b_rows: usize,
104 b_cols_in: usize,
105 b_cols_out: usize,
106 b_size: usize,
107 ) -> usize {
108 B::vmp_apply_dft_to_dft_add_tmp_bytes_impl(
109 self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size,
110 )
111 }
112}
113
114impl<B> VmpApplyDftToDftAdd<B> for Module<B>
115where
116 B: Backend + VmpApplyDftToDftAddImpl<B>,
117{
118 fn vmp_apply_dft_to_dft_add<R, A, C>(&self, res: &mut R, a: &A, b: &C, scale: usize, scratch: &mut Scratch<B>)
119 where
120 R: VecZnxDftToMut<B>,
121 A: VecZnxDftToRef<B>,
122 C: VmpPMatToRef<B>,
123 {
124 B::vmp_apply_dft_to_dft_add_impl(self, res, a, b, scale, scratch);
125 }
126}