opensrdk_linear_algebra/matrix/ge/
mm.rs

1use crate::matrix::MatrixError;
2use crate::{number::c64, Matrix};
3use blas::dgemm;
4use blas::zgemm;
5
6impl Matrix {
7    /// C = self
8    /// A = lhs
9    /// B = rhs
10    /// return alpha*op( A )*op( B ) + beta*C,
11    pub fn gemm(
12        self,
13        lhs: &Matrix,
14        rhs: &Matrix,
15        alpha: f64,
16        beta: f64,
17    ) -> Result<Matrix, MatrixError> {
18        if self.rows != lhs.rows || self.cols != rhs.cols || lhs.cols != rhs.rows {
19            return Err(MatrixError::DimensionMismatch);
20        }
21
22        let m = lhs.rows as i32;
23        let k = lhs.cols as i32;
24        let n = rhs.cols as i32;
25
26        let mut slf = self;
27
28        unsafe {
29            dgemm(
30                'N' as u8,
31                'N' as u8,
32                m,
33                n,
34                k,
35                alpha,
36                lhs.elems.as_slice(),
37                m,
38                rhs.elems.as_slice(),
39                k,
40                beta,
41                &mut slf.elems,
42                m,
43            );
44        }
45
46        Ok(slf)
47    }
48}
49
50impl Matrix<c64> {
51    pub fn gemm(
52        self,
53        lhs: &Matrix<c64>,
54        rhs: &Matrix<c64>,
55        alpha: c64,
56        beta: c64,
57    ) -> Result<Matrix<c64>, MatrixError> {
58        if self.rows != lhs.rows || self.cols != rhs.cols || lhs.cols != rhs.rows {
59            return Err(MatrixError::DimensionMismatch);
60        }
61
62        let m = lhs.rows as i32;
63        let k = lhs.cols as i32;
64        let n = rhs.cols as i32;
65
66        let mut slf = self;
67
68        unsafe {
69            zgemm(
70                'N' as u8,
71                'N' as u8,
72                m,
73                n,
74                k,
75                alpha,
76                rhs.elems.as_slice(),
77                m,
78                lhs.elems.as_slice(),
79                k,
80                beta,
81                &mut slf.elems,
82                m,
83            );
84        }
85
86        Ok(slf)
87    }
88}
89
#[cfg(test)]
mod tests {
    use crate::*;

    /// `gemm` must agree with the explicitly composed expression
    /// `alpha * a.dot(&b) + beta * c` on every entry of the 2x2 result,
    /// not just the top-left one.
    #[test]
    fn it_works() {
        let a = mat!(
            1.0, 2.0;
            3.0, 4.0
        );
        let b = mat!(
            2.0, 1.0;
            4.0, 3.0
        );
        let c = mat!(
            1.0, 3.0;
            5.0, 7.0
        );
        let alpha = 2.0;
        let beta = 3.0;
        let result = c.clone().gemm(&a, &b, alpha, beta).unwrap();
        let result2 = alpha * a.dot(&b) + beta * c;

        // All inputs are small integers, so both computations are exact in
        // f64 and can be compared with strict equality.
        for i in 0..2 {
            for j in 0..2 {
                assert_eq!(
                    result[(i, j)],
                    result2[(i, j)],
                    "gemm result differs from reference at ({}, {})",
                    i,
                    j
                );
            }
        }
    }
}