custos_math/
matrix_multiply.rs

/// General matrix multiplication (GEMM) on strided, slice-backed matrices.
///
/// `A` is `m x k`, `B` is `k x n` and `C` is `m x n`. Each matrix is given as a
/// slice plus a row stride (`rs*`) and a column stride (`cs*`), following the
/// convention of the `matrixmultiply` crate: element `(i, j)` of `A` lives at
/// `a[i * rsa + j * csa]`, and likewise for `B` and `C`.
///
/// NOTE(review): the `f32`/`f64` implementations in this file call
/// `matrixmultiply` with `alpha = beta = 1`, i.e. they *accumulate*
/// (`C <- A * B + C`). Zero `c` beforehand if a plain product is wanted.
pub trait MatrixMultiply: Sized {
    /// Multiplies the `m x k` matrix in `a` by the `k x n` matrix in `b`,
    /// combining the product into the `m x n` matrix in `c`.
    fn gemm(
        m: usize,
        k: usize,
        n: usize,
        a: &[Self],
        rsa: usize,
        csa: usize,
        b: &[Self],
        rsb: usize,
        csb: usize,
        c: &mut [Self],
        rsc: usize,
        csc: usize,
    );
}
8
9#[allow(unused)]
10mod implements {
11    use super::MatrixMultiply;
12
13    #[rustfmt::skip]
14    impl MatrixMultiply for f32 {
15        #[inline]
16        fn gemm(m: usize, k: usize, n: usize,
17            a: &[Self], rsa: usize, csa: usize,
18            b: &[Self], rsb: usize, csb: usize,
19            c: &mut [Self], rsc: usize, csc: usize) {
20
21                #[cfg(not(feature = "matrixmultiply"))]
22                unimplemented!("Activate the matrixmultiply feature");
23            
24                #[cfg(feature = "matrixmultiply")]
25                unsafe {
26                    matrixmultiply::sgemm(m, k, n, 1., a.as_ptr(), rsa as isize, csa as isize, b.as_ptr(), rsb as isize, csb as isize, 1., c.as_mut_ptr(), rsc as isize, csc as isize);
27                }
28                
29        }
30    }
31
32    #[rustfmt::skip]
33    impl MatrixMultiply for f64 {
34        #[inline]
35        fn gemm(m: usize, k: usize, n: usize,
36            a: &[Self], rsa: usize, csa: usize,
37            b: &[Self], rsb: usize, csb: usize,
38            c: &mut [Self], rsc: usize, csc: usize) 
39        {
40            #[cfg(not(feature = "matrixmultiply"))]
41            unimplemented!("Activate the matrixmultiply feature");
42            #[cfg(feature = "matrixmultiply")]
43            unsafe {
44                matrixmultiply::dgemm(m, k, n, 1., a.as_ptr(), rsa as isize, csa as isize, b.as_ptr(), rsb as isize, csb as isize, 1., c.as_mut_ptr(), rsc as isize, csc as isize);
45            }           
46        }
47    }
48}