1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
use super::*;
use russell_openblas::*;
use std::convert::TryInto;
/// Performs the matrix-matrix multiplication `c := alpha * a * b`
///
/// The product is delegated to the BLAS routine `dgemm`, called without
/// transposing either operand and with `beta = 0.0`, so any previous content
/// of `c` is overwritten.
///
/// # Arguments
///
/// * `c` - the (m, n) output matrix
/// * `alpha` - scalar multiplier applied to the product
/// * `a` - the (m, k) left operand
/// * `b` - the (k, n) right operand
///
/// # Panics
///
/// * if `a.nrow != c.nrow`, `b.nrow != a.ncol`, or `b.ncol != c.ncol`
/// * if any dimension cannot be represented as an `i32`
pub fn mat_mat_mul(c: &mut Matrix, alpha: f64, a: &Matrix, b: &Matrix) {
    if a.nrow != c.nrow {
        panic!("the number of rows of matrix [a] (={}) must be equal to the number of rows of matrix [c] (={})", a.nrow, c.nrow);
    }
    if b.nrow != a.ncol {
        panic!("the number of rows of matrix [b] (={}) must be equal to the number of columns of matrix [a] (={})", b.nrow, a.ncol);
    }
    if b.ncol != c.ncol {
        panic!("the number of columns of matrix [b] (={}) must be equal to the number of columns of matrix [c] (={})", b.ncol, c.ncol);
    }

    // BLAS requires every leading dimension to be at least 1, so empty
    // operands must not reach dgemm (lda/ldb/ldc would be 0 below).
    if c.nrow == 0 || c.ncol == 0 {
        // c has no entries: nothing to compute
        return;
    }
    if a.ncol == 0 {
        // inner dimension is zero: the product is an empty sum, hence c = 0
        // (consistent with beta = 0.0 in the general case)
        for value in c.data.iter_mut() {
            *value = 0.0;
        }
        return;
    }

    // BLAS takes 32-bit integer dimensions
    let m_i32: i32 = c.nrow.try_into().expect("the number of rows of [c] must fit in an i32");
    let n_i32: i32 = c.ncol.try_into().expect("the number of columns of [c] must fit in an i32");
    let k_i32: i32 = a.ncol.try_into().expect("the number of columns of [a] must fit in an i32");
    // The leading dimensions are the row counts of each operand
    let lda_i32: i32 = a.nrow.try_into().expect("the number of rows of [a] must fit in an i32");
    let ldb_i32: i32 = b.nrow.try_into().expect("the number of rows of [b] must fit in an i32");
    dgemm(
        false, // presumably: do not transpose a
        false, // presumably: do not transpose b
        m_i32,
        n_i32,
        k_i32,
        alpha,
        &a.data,
        lda_i32,
        &b.data,
        ldb_i32,
        0.0, // beta = 0: c is overwritten, not accumulated into
        &mut c.data,
        m_i32, // ldc equals the number of rows of c
    );
}
#[cfg(test)]
mod tests {
    use super::*;
    use russell_chk::*;

    /// Multiplies a (2 x 3) matrix by a (3 x 4) matrix, scaled by 2, and
    /// compares the raw data of the result against hand-computed values.
    #[test]
    fn mat_mat_mul_works() {
        // left operand: 2 rows, 3 columns
        #[rustfmt::skip]
        let lhs = Matrix::from(&[
            &[1.0, 2.00, 3.0],
            &[0.5, 0.75, 1.5],
        ]);

        // right operand: 3 rows, 4 columns
        #[rustfmt::skip]
        let rhs = Matrix::from(&[
            &[0.1, 0.5, 0.5, 0.75],
            &[0.2, 2.0, 2.0, 2.00],
            &[0.3, 0.5, 0.5, 0.50],
        ]);

        // expected value of 2 * lhs * rhs, converted with slice_to_colmajor
        #[rustfmt::skip]
        let expected = slice_to_colmajor(&[
            &[2.80, 12.0, 12.0, 12.50],
            &[1.30,  5.0,  5.0,  5.25],
        ]);

        let scale = 2.0;
        let mut product = Matrix::new(2, 4);
        mat_mat_mul(&mut product, scale, &lhs, &rhs);

        assert_vec_approx_eq!(product.data, expected, 1e-15);
    }
}