rust_blas/matrix/
ops.rs

// Copyright 2014 Michael Yang. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.

//! Wrappers for matrix functions.
6
7use crate::attribute::{Diagonal, Side, Symmetry, Transpose};
8use crate::matrix::ll::*;
9use crate::matrix::Matrix;
10use crate::pointer::CPtr;
11use crate::scalar::Scalar;
12use num_complex::{Complex, Complex32, Complex64};
13
14pub trait Gemm: Sized {
15    fn gemm(
16        alpha: &Self,
17        at: Transpose,
18        a: &dyn Matrix<Self>,
19        bt: Transpose,
20        b: &dyn Matrix<Self>,
21        beta: &Self,
22        c: &mut dyn Matrix<Self>,
23    );
24}
25
26macro_rules! gemm_impl(($($t: ident), +) => (
27    $(
28        impl Gemm for $t {
29            fn gemm(alpha: &$t, at: Transpose, a: &dyn Matrix<$t>, bt: Transpose, b: &dyn Matrix<$t>, beta: &$t, c: &mut dyn Matrix<$t>) {
30                unsafe {
31                    let (m, k)  = match at {
32                        Transpose::NoTrans => (a.rows(), a.cols()),
33                        _ => (a.cols(), a.rows()),
34                    };
35
36                    let n = match bt {
37                        Transpose::NoTrans => b.cols(),
38                        _ => b.rows(),
39                    };
40
41                    prefix!($t, gemm)(a.order(),
42                        at, bt,
43                        m, n, k,
44                        alpha.as_const(),
45                        a.as_ptr().as_c_ptr(), a.lead_dim(),
46                        b.as_ptr().as_c_ptr(), b.lead_dim(),
47                        beta.as_const(),
48                        c.as_mut_ptr().as_c_ptr(), c.lead_dim());
49                }
50            }
51        }
52    )+
53));
54
55gemm_impl!(f32, f64, Complex32, Complex64);
56
#[cfg(test)]
mod gemm_tests {
    use crate::attribute::Transpose;
    use crate::matrix::ops::Gemm;
    use crate::matrix::tests::M;

    /// Plain 2x2 product with `alpha = 1` and `beta = 0`, no transposition.
    #[test]
    fn real() {
        let lhs = M(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let rhs = M(2, 2, vec![-1.0, 3.0, 1.0, 1.0]);
        let mut product = M(2, 2, vec![0.0; 4]);

        Gemm::gemm(
            &1f32,
            Transpose::NoTrans,
            &lhs,
            Transpose::NoTrans,
            &rhs,
            &0f32,
            &mut product,
        );

        assert_eq!(product.2, vec![1.0, 5.0, 1.0, 13.0]);
    }

    /// Both operands transposed: a 3x2 and a 2x3 input yield a 2x2 result.
    #[test]
    fn transpose() {
        let lhs = M(3, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let rhs = M(2, 3, vec![-1.0, 3.0, 1.0, 1.0, 1.0, 1.0]);
        let mut product = M(2, 2, vec![0.0; 4]);

        Gemm::gemm(
            &1f32,
            Transpose::Trans,
            &lhs,
            Transpose::Trans,
            &rhs,
            &0f32,
            &mut product,
        );

        assert_eq!(product.2, vec![13.0, 9.0, 16.0, 12.0]);
    }
}
88
89pub trait Symm: Sized {
90    fn symm(
91        side: Side,
92        symmetry: Symmetry,
93        alpha: &Self,
94        a: &dyn Matrix<Self>,
95        b: &dyn Matrix<Self>,
96        beta: &Self,
97        c: &mut dyn Matrix<Self>,
98    );
99}
100
101pub trait Hemm: Sized {
102    fn hemm(
103        side: Side,
104        symmetry: Symmetry,
105        alpha: &Self,
106        a: &dyn Matrix<Self>,
107        b: &dyn Matrix<Self>,
108        beta: &Self,
109        c: &mut dyn Matrix<Self>,
110    );
111}
112
113macro_rules! symm_impl(($trait_name: ident, $fn_name: ident, $($t: ident), +) => (
114    $(
115        impl $trait_name for $t {
116            fn $fn_name(side: Side, symmetry: Symmetry, alpha: &$t, a: &dyn Matrix<$t>, b: &dyn Matrix<$t>, beta: &$t, c: &mut dyn Matrix<$t>) {
117                unsafe {
118                    prefix!($t, $fn_name)(a.order(),
119                        side, symmetry,
120                        a.rows(), b.cols(),
121                        alpha.as_const(),
122                        a.as_ptr().as_c_ptr(), a.lead_dim(),
123                        b.as_ptr().as_c_ptr(), b.lead_dim(),
124                        beta.as_const(),
125                        c.as_mut_ptr().as_c_ptr(), c.lead_dim());
126                }
127            }
128        }
129    )+
130));
131
132symm_impl!(Symm, symm, f32, f64, Complex32, Complex64);
133symm_impl!(Hemm, hemm, Complex32, Complex64);
134
135pub trait Trmm: Sized {
136    fn trmm(
137        side: Side,
138        symmetry: Symmetry,
139        trans: Transpose,
140        diag: Diagonal,
141        alpha: &Self,
142        a: &dyn Matrix<Self>,
143        b: &mut dyn Matrix<Self>,
144    );
145}
146
147pub trait Trsm: Sized {
148    fn trsm(
149        side: Side,
150        symmetry: Symmetry,
151        trans: Transpose,
152        diag: Diagonal,
153        alpha: &Self,
154        a: &dyn Matrix<Self>,
155        b: &mut dyn Matrix<Self>,
156    );
157}
158
159macro_rules! trmm_impl(($trait_name: ident, $fn_name: ident, $($t: ident), +) => (
160    $(
161        impl $trait_name for $t {
162            fn $fn_name(side: Side, symmetry: Symmetry, trans: Transpose, diag: Diagonal, alpha: &$t, a: &dyn Matrix<$t>, b: &mut dyn Matrix<$t>) {
163                unsafe {
164                    prefix!($t, $fn_name)(a.order(),
165                        side, symmetry, trans, diag,
166                        b.rows(), b.cols(),
167                        alpha.as_const(),
168                        a.as_ptr().as_c_ptr(), a.lead_dim(),
169                        b.as_mut_ptr().as_c_ptr(), b.lead_dim());
170                }
171            }
172        }
173    )+
174));
175
176trmm_impl!(Trmm, trmm, f32, f64, Complex32, Complex64);
177trmm_impl!(Trsm, trsm, Complex32, Complex64);
178
179pub trait Herk: Sized {
180    fn herk(
181        symmetry: Symmetry,
182        trans: Transpose,
183        alpha: &Self,
184        a: &dyn Matrix<Complex<Self>>,
185        beta: &Self,
186        c: &mut dyn Matrix<Complex<Self>>,
187    );
188}
189
190pub trait Her2k: Sized {
191    fn her2k(
192        symmetry: Symmetry,
193        trans: Transpose,
194        alpha: Complex<Self>,
195        a: &dyn Matrix<Complex<Self>>,
196        b: &dyn Matrix<Complex<Self>>,
197        beta: &Self,
198        c: &mut dyn Matrix<Complex<Self>>,
199    );
200}
201
202macro_rules! herk_impl(($($t: ident), +) => (
203    $(
204        impl Herk for $t {
205            fn herk(symmetry: Symmetry, trans: Transpose, alpha: &$t, a: &dyn Matrix<Complex<$t>>, beta: &$t, c: &mut dyn Matrix<Complex<$t>>) {
206                unsafe {
207                    prefix!(Complex<$t>, herk)(a.order(),
208                        symmetry, trans,
209                        a.rows(), a.cols(),
210                        *alpha,
211                        a.as_ptr().as_c_ptr(), a.lead_dim(),
212                        *beta,
213                        c.as_mut_ptr().as_c_ptr(), c.lead_dim());
214                }
215            }
216        }
217
218        impl Her2k for $t {
219            fn her2k(symmetry: Symmetry, trans: Transpose, alpha: Complex<$t>, a: &dyn Matrix<Complex<$t>>, b: &dyn Matrix<Complex<$t>>, beta: &$t, c: &mut dyn Matrix<Complex<$t>>) {
220                unsafe {
221                    prefix!(Complex<$t>, her2k)(a.order(),
222                        symmetry, trans,
223                        a.rows(), a.cols(),
224                        alpha.as_const(),
225                        a.as_ptr().as_c_ptr(), a.lead_dim(),
226                        b.as_ptr().as_c_ptr(), b.lead_dim(),
227                        *beta,
228                        c.as_mut_ptr().as_c_ptr(), c.lead_dim());
229                }
230            }
231        }
232    )+
233));
234
235herk_impl!(f32, f64);
236
237pub trait Syrk: Sized {
238    fn syrk(
239        symmetry: Symmetry,
240        trans: Transpose,
241        alpha: &Self,
242        a: &dyn Matrix<Self>,
243        beta: &Self,
244        c: &mut dyn Matrix<Self>,
245    );
246}
247
248pub trait Syr2k: Sized {
249    fn syr2k(
250        symmetry: Symmetry,
251        trans: Transpose,
252        alpha: &Self,
253        a: &dyn Matrix<Self>,
254        b: &dyn Matrix<Self>,
255        beta: &Self,
256        c: &mut dyn Matrix<Self>,
257    );
258}
259
260macro_rules! syrk_impl(($($t: ident), +) => (
261    $(
262        impl Syrk for $t {
263            fn syrk(symmetry: Symmetry, trans: Transpose, alpha: &$t, a: &dyn Matrix<$t>, beta: &$t, c: &mut dyn Matrix<$t>) {
264                unsafe {
265                    prefix!($t, syrk)(a.order(),
266                        symmetry, trans,
267                        a.rows(), a.cols(),
268                        alpha.as_const(),
269                        a.as_ptr().as_c_ptr(), a.lead_dim(),
270                        beta.as_const(),
271                        c.as_mut_ptr().as_c_ptr(), c.lead_dim());
272                }
273            }
274        }
275
276        impl Syr2k for $t {
277            fn syr2k(symmetry: Symmetry, trans: Transpose, alpha: &$t, a: &dyn Matrix<$t>, b: &dyn Matrix<$t>, beta: &$t, c: &mut dyn Matrix<$t>) {
278                unsafe {
279                    prefix!($t, syr2k)(a.order(),
280                        symmetry, trans,
281                        a.rows(), a.cols(),
282                        alpha.as_const(),
283                        a.as_ptr().as_c_ptr(), a.lead_dim(),
284                        b.as_ptr().as_c_ptr(), b.lead_dim(),
285                        beta.as_const(),
286                        c.as_mut_ptr().as_c_ptr(), c.lead_dim());
287                }
288            }
289        }
290    )+
291));
292
293syrk_impl!(f32, f64, Complex32, Complex64);