opensrdk_linear_algebra/matrix/ge/operations/
dot.rs1use crate::c64;
2use crate::Matrix;
3use blas::dgemm;
4use blas::zgemm;
5
6impl Matrix<f64> {
7    pub fn dot(&self, rhs: &Self) -> Self {
8        let lhs = self;
9        if lhs.cols != rhs.rows {
10            panic!("Dimension mismatch.")
11        }
12
13        let m = lhs.rows as i32;
14        let k = lhs.cols as i32;
15        let n = rhs.cols as i32;
16
17        let mut new_matrix = Matrix::new(lhs.rows, rhs.cols);
18
19        unsafe {
20            dgemm(
21                'N' as u8,
22                'N' as u8,
23                m,
24                n,
25                k,
26                1.0,
27                lhs.elems.as_slice(),
28                m,
29                rhs.elems.as_slice(),
30                k,
31                0.0,
32                &mut new_matrix.elems,
33                m,
34            );
35        }
36
37        new_matrix
38    }
39}
40
41impl Matrix<c64> {
42    pub fn dot(&self, rhs: &Self) -> Self {
43        let lhs = self;
44        if lhs.cols != rhs.rows {
45            panic!("Dimension mismatch.")
46        }
47
48        let m = lhs.rows as i32;
49        let k = lhs.cols as i32;
50        let n = rhs.cols as i32;
51
52        let mut new_matrix = Matrix::<c64>::new(lhs.rows, rhs.cols);
53
54        unsafe {
55            zgemm(
56                'N' as u8,
57                'N' as u8,
58                m,
59                n,
60                k,
61                blas::c64::new(1.0, 0.0),
62                &lhs.elems,
63                m,
64                &rhs.elems,
65                k,
66                blas::c64::new(0.0, 0.0),
67                &mut new_matrix.elems,
68                m,
69            );
70        }
71
72        new_matrix
73    }
74}
75
#[cfg(test)]
mod tests {
    use crate::*;

    /// Square product: every entry is checked, so a transposed or mis-strided
    /// result is caught rather than only the (0, 0) element.
    #[test]
    fn it_works() {
        let a = mat!(
            1.0, 2.0;
            3.0, 4.0
        )
        .dot(&mat!(
            5.0, 6.0;
            7.0, 8.0
        ));
        assert_eq!(a[(0, 0)], 19.0);
        assert_eq!(a[(0, 1)], 22.0);
        assert_eq!(a[(1, 0)], 43.0);
        assert_eq!(a[(1, 1)], 50.0);
    }

    /// (2 x 3) * (3 x 4): m, k and n all differ, which exercises each
    /// leading dimension passed to BLAS independently.
    #[test]
    fn rectangular() {
        let c = mat!(
            1.0, 2.0, 3.0;
            4.0, 5.0, 6.0
        )
        .dot(&mat!(
            1.0, 0.0, 0.0, 1.0;
            0.0, 1.0, 0.0, 1.0;
            0.0, 0.0, 1.0, 1.0
        ));
        let expected = [[1.0, 2.0, 3.0, 6.0], [4.0, 5.0, 6.0, 15.0]];
        for (i, row) in expected.iter().enumerate() {
            for (j, &value) in row.iter().enumerate() {
                assert_eq!(c[(i, j)], value);
            }
        }
    }

    /// A (2 x 2) times a (3 x 3) has incompatible inner dimensions.
    #[test]
    #[should_panic(expected = "Dimension mismatch.")]
    fn dimension_mismatch_panics() {
        let _ = mat!(
            1.0, 2.0;
            3.0, 4.0
        )
        .dot(&mat!(
            1.0, 2.0, 3.0;
            4.0, 5.0, 6.0;
            7.0, 8.0, 9.0
        ));
    }
}