1use crate::attribute::{Diagonal, Side, Symmetry, Transpose};
8use crate::matrix::ll::*;
9use crate::matrix::Matrix;
10use crate::pointer::CPtr;
11use crate::scalar::Scalar;
12use num_complex::{Complex, Complex32, Complex64};
13
14pub trait Gemm: Sized {
15 fn gemm(
16 alpha: &Self,
17 at: Transpose,
18 a: &dyn Matrix<Self>,
19 bt: Transpose,
20 b: &dyn Matrix<Self>,
21 beta: &Self,
22 c: &mut dyn Matrix<Self>,
23 );
24}
25
26macro_rules! gemm_impl(($($t: ident), +) => (
27 $(
28 impl Gemm for $t {
29 fn gemm(alpha: &$t, at: Transpose, a: &dyn Matrix<$t>, bt: Transpose, b: &dyn Matrix<$t>, beta: &$t, c: &mut dyn Matrix<$t>) {
30 unsafe {
31 let (m, k) = match at {
32 Transpose::NoTrans => (a.rows(), a.cols()),
33 _ => (a.cols(), a.rows()),
34 };
35
36 let n = match bt {
37 Transpose::NoTrans => b.cols(),
38 _ => b.rows(),
39 };
40
41 prefix!($t, gemm)(a.order(),
42 at, bt,
43 m, n, k,
44 alpha.as_const(),
45 a.as_ptr().as_c_ptr(), a.lead_dim(),
46 b.as_ptr().as_c_ptr(), b.lead_dim(),
47 beta.as_const(),
48 c.as_mut_ptr().as_c_ptr(), c.lead_dim());
49 }
50 }
51 }
52 )+
53));
54
55gemm_impl!(f32, f64, Complex32, Complex64);
56
#[cfg(test)]
mod gemm_tests {
    use crate::attribute::Transpose;
    use crate::matrix::ops::Gemm;
    use crate::matrix::tests::M;

    /// 2x2 product with neither operand transposed.
    #[test]
    fn real() {
        let lhs = M(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let rhs = M(2, 2, vec![-1.0, 3.0, 1.0, 1.0]);
        let mut out = M(2, 2, vec![0.0; 4]);

        Gemm::gemm(
            &1f32,
            Transpose::NoTrans,
            &lhs,
            Transpose::NoTrans,
            &rhs,
            &0f32,
            &mut out,
        );

        assert_eq!(out.2, vec![1.0, 5.0, 1.0, 13.0]);
    }

    /// Product of a transposed 3x2 matrix with a transposed 2x3 matrix.
    #[test]
    fn transpose() {
        let lhs = M(3, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let rhs = M(2, 3, vec![-1.0, 3.0, 1.0, 1.0, 1.0, 1.0]);
        let mut out = M(2, 2, vec![0.0; 4]);

        Gemm::gemm(
            &1f32,
            Transpose::Trans,
            &lhs,
            Transpose::Trans,
            &rhs,
            &0f32,
            &mut out,
        );

        assert_eq!(out.2, vec![13.0, 9.0, 16.0, 12.0]);
    }
}
88
89pub trait Symm: Sized {
90 fn symm(
91 side: Side,
92 symmetry: Symmetry,
93 alpha: &Self,
94 a: &dyn Matrix<Self>,
95 b: &dyn Matrix<Self>,
96 beta: &Self,
97 c: &mut dyn Matrix<Self>,
98 );
99}
100
101pub trait Hemm: Sized {
102 fn hemm(
103 side: Side,
104 symmetry: Symmetry,
105 alpha: &Self,
106 a: &dyn Matrix<Self>,
107 b: &dyn Matrix<Self>,
108 beta: &Self,
109 c: &mut dyn Matrix<Self>,
110 );
111}
112
113macro_rules! symm_impl(($trait_name: ident, $fn_name: ident, $($t: ident), +) => (
114 $(
115 impl $trait_name for $t {
116 fn $fn_name(side: Side, symmetry: Symmetry, alpha: &$t, a: &dyn Matrix<$t>, b: &dyn Matrix<$t>, beta: &$t, c: &mut dyn Matrix<$t>) {
117 unsafe {
118 prefix!($t, $fn_name)(a.order(),
119 side, symmetry,
120 a.rows(), b.cols(),
121 alpha.as_const(),
122 a.as_ptr().as_c_ptr(), a.lead_dim(),
123 b.as_ptr().as_c_ptr(), b.lead_dim(),
124 beta.as_const(),
125 c.as_mut_ptr().as_c_ptr(), c.lead_dim());
126 }
127 }
128 }
129 )+
130));
131
132symm_impl!(Symm, symm, f32, f64, Complex32, Complex64);
133symm_impl!(Hemm, hemm, Complex32, Complex64);
134
135pub trait Trmm: Sized {
136 fn trmm(
137 side: Side,
138 symmetry: Symmetry,
139 trans: Transpose,
140 diag: Diagonal,
141 alpha: &Self,
142 a: &dyn Matrix<Self>,
143 b: &mut dyn Matrix<Self>,
144 );
145}
146
147pub trait Trsm: Sized {
148 fn trsm(
149 side: Side,
150 symmetry: Symmetry,
151 trans: Transpose,
152 diag: Diagonal,
153 alpha: &Self,
154 a: &dyn Matrix<Self>,
155 b: &mut dyn Matrix<Self>,
156 );
157}
158
159macro_rules! trmm_impl(($trait_name: ident, $fn_name: ident, $($t: ident), +) => (
160 $(
161 impl $trait_name for $t {
162 fn $fn_name(side: Side, symmetry: Symmetry, trans: Transpose, diag: Diagonal, alpha: &$t, a: &dyn Matrix<$t>, b: &mut dyn Matrix<$t>) {
163 unsafe {
164 prefix!($t, $fn_name)(a.order(),
165 side, symmetry, trans, diag,
166 b.rows(), b.cols(),
167 alpha.as_const(),
168 a.as_ptr().as_c_ptr(), a.lead_dim(),
169 b.as_mut_ptr().as_c_ptr(), b.lead_dim());
170 }
171 }
172 }
173 )+
174));
175
176trmm_impl!(Trmm, trmm, f32, f64, Complex32, Complex64);
177trmm_impl!(Trsm, trsm, Complex32, Complex64);
178
179pub trait Herk: Sized {
180 fn herk(
181 symmetry: Symmetry,
182 trans: Transpose,
183 alpha: &Self,
184 a: &dyn Matrix<Complex<Self>>,
185 beta: &Self,
186 c: &mut dyn Matrix<Complex<Self>>,
187 );
188}
189
190pub trait Her2k: Sized {
191 fn her2k(
192 symmetry: Symmetry,
193 trans: Transpose,
194 alpha: Complex<Self>,
195 a: &dyn Matrix<Complex<Self>>,
196 b: &dyn Matrix<Complex<Self>>,
197 beta: &Self,
198 c: &mut dyn Matrix<Complex<Self>>,
199 );
200}
201
202macro_rules! herk_impl(($($t: ident), +) => (
203 $(
204 impl Herk for $t {
205 fn herk(symmetry: Symmetry, trans: Transpose, alpha: &$t, a: &dyn Matrix<Complex<$t>>, beta: &$t, c: &mut dyn Matrix<Complex<$t>>) {
206 unsafe {
207 prefix!(Complex<$t>, herk)(a.order(),
208 symmetry, trans,
209 a.rows(), a.cols(),
210 *alpha,
211 a.as_ptr().as_c_ptr(), a.lead_dim(),
212 *beta,
213 c.as_mut_ptr().as_c_ptr(), c.lead_dim());
214 }
215 }
216 }
217
218 impl Her2k for $t {
219 fn her2k(symmetry: Symmetry, trans: Transpose, alpha: Complex<$t>, a: &dyn Matrix<Complex<$t>>, b: &dyn Matrix<Complex<$t>>, beta: &$t, c: &mut dyn Matrix<Complex<$t>>) {
220 unsafe {
221 prefix!(Complex<$t>, her2k)(a.order(),
222 symmetry, trans,
223 a.rows(), a.cols(),
224 alpha.as_const(),
225 a.as_ptr().as_c_ptr(), a.lead_dim(),
226 b.as_ptr().as_c_ptr(), b.lead_dim(),
227 *beta,
228 c.as_mut_ptr().as_c_ptr(), c.lead_dim());
229 }
230 }
231 }
232 )+
233));
234
235herk_impl!(f32, f64);
236
237pub trait Syrk: Sized {
238 fn syrk(
239 symmetry: Symmetry,
240 trans: Transpose,
241 alpha: &Self,
242 a: &dyn Matrix<Self>,
243 beta: &Self,
244 c: &mut dyn Matrix<Self>,
245 );
246}
247
248pub trait Syr2k: Sized {
249 fn syr2k(
250 symmetry: Symmetry,
251 trans: Transpose,
252 alpha: &Self,
253 a: &dyn Matrix<Self>,
254 b: &dyn Matrix<Self>,
255 beta: &Self,
256 c: &mut dyn Matrix<Self>,
257 );
258}
259
260macro_rules! syrk_impl(($($t: ident), +) => (
261 $(
262 impl Syrk for $t {
263 fn syrk(symmetry: Symmetry, trans: Transpose, alpha: &$t, a: &dyn Matrix<$t>, beta: &$t, c: &mut dyn Matrix<$t>) {
264 unsafe {
265 prefix!($t, syrk)(a.order(),
266 symmetry, trans,
267 a.rows(), a.cols(),
268 alpha.as_const(),
269 a.as_ptr().as_c_ptr(), a.lead_dim(),
270 beta.as_const(),
271 c.as_mut_ptr().as_c_ptr(), c.lead_dim());
272 }
273 }
274 }
275
276 impl Syr2k for $t {
277 fn syr2k(symmetry: Symmetry, trans: Transpose, alpha: &$t, a: &dyn Matrix<$t>, b: &dyn Matrix<$t>, beta: &$t, c: &mut dyn Matrix<$t>) {
278 unsafe {
279 prefix!($t, syr2k)(a.order(),
280 symmetry, trans,
281 a.rows(), a.cols(),
282 alpha.as_const(),
283 a.as_ptr().as_c_ptr(), a.lead_dim(),
284 b.as_ptr().as_c_ptr(), b.lead_dim(),
285 beta.as_const(),
286 c.as_mut_ptr().as_c_ptr(), c.lead_dim());
287 }
288 }
289 }
290 )+
291));
292
293syrk_impl!(f32, f64, Complex32, Complex64);