opensrdk_linear_algebra/matrix/ge/
mm.rs1use crate::matrix::MatrixError;
2use crate::{number::c64, Matrix};
3use blas::dgemm;
4use blas::zgemm;
5
6impl Matrix {
7 pub fn gemm(
12 self,
13 lhs: &Matrix,
14 rhs: &Matrix,
15 alpha: f64,
16 beta: f64,
17 ) -> Result<Matrix, MatrixError> {
18 if self.rows != lhs.rows || self.cols != rhs.cols || lhs.cols != rhs.rows {
19 return Err(MatrixError::DimensionMismatch);
20 }
21
22 let m = lhs.rows as i32;
23 let k = lhs.cols as i32;
24 let n = rhs.cols as i32;
25
26 let mut slf = self;
27
28 unsafe {
29 dgemm(
30 'N' as u8,
31 'N' as u8,
32 m,
33 n,
34 k,
35 alpha,
36 lhs.elems.as_slice(),
37 m,
38 rhs.elems.as_slice(),
39 k,
40 beta,
41 &mut slf.elems,
42 m,
43 );
44 }
45
46 Ok(slf)
47 }
48}
49
50impl Matrix<c64> {
51 pub fn gemm(
52 self,
53 lhs: &Matrix<c64>,
54 rhs: &Matrix<c64>,
55 alpha: c64,
56 beta: c64,
57 ) -> Result<Matrix<c64>, MatrixError> {
58 if self.rows != lhs.rows || self.cols != rhs.cols || lhs.cols != rhs.rows {
59 return Err(MatrixError::DimensionMismatch);
60 }
61
62 let m = lhs.rows as i32;
63 let k = lhs.cols as i32;
64 let n = rhs.cols as i32;
65
66 let mut slf = self;
67
68 unsafe {
69 zgemm(
70 'N' as u8,
71 'N' as u8,
72 m,
73 n,
74 k,
75 alpha,
76 rhs.elems.as_slice(),
77 m,
78 lhs.elems.as_slice(),
79 k,
80 beta,
81 &mut slf.elems,
82 m,
83 );
84 }
85
86 Ok(slf)
87 }
88}
89
#[cfg(test)]
mod tests {
    use crate::*;

    /// gemm must agree with the composed `alpha * a.dot(b) + beta * c` on every entry.
    #[test]
    fn it_works() {
        let a = mat!(
            1.0, 2.0;
            3.0, 4.0
        );
        let b = mat!(
            2.0, 1.0;
            4.0, 3.0
        );
        let c = mat!(
            1.0, 3.0;
            5.0, 7.0
        );
        let alpha = 2.0;
        let beta = 3.0;
        let result = c.clone().gemm(&a, &b, alpha, beta).unwrap();
        let result2 = alpha * a.dot(&b) + beta * c;
        for i in 0..2 {
            for j in 0..2 {
                assert_eq!(result[(i, j)], result2[(i, j)]);
            }
        }
    }

    /// Non-square operands exercise the leading dimensions, which a square
    /// test cannot distinguish.
    #[test]
    fn gemm_non_square() {
        let a = mat!(
            1.0, 2.0, 3.0;
            4.0, 5.0, 6.0
        );
        let b = mat!(
            1.0, 2.0;
            3.0, 4.0;
            5.0, 6.0
        );
        let c = mat!(
            1.0, 3.0;
            5.0, 7.0
        );
        // a * b = [[22, 28], [49, 64]]; 2 * (a * b) + 3 * c:
        let expected = [[47.0, 65.0], [113.0, 149.0]];
        let result = c.gemm(&a, &b, 2.0, 3.0).unwrap();
        for i in 0..2 {
            for j in 0..2 {
                assert_eq!(result[(i, j)], expected[i][j]);
            }
        }
    }

    /// Incompatible shapes are reported as an error rather than passed to BLAS.
    #[test]
    fn gemm_rejects_mismatched_dimensions() {
        let a = mat!(
            1.0, 2.0, 3.0;
            4.0, 5.0, 6.0
        );
        let c = mat!(
            1.0, 3.0;
            5.0, 7.0
        );
        // rhs is 2×3, so a 2×3 product cannot be stored into the 2×2 `c`.
        assert!(c.gemm(&a, &a, 1.0, 0.0).is_err());
    }
}