// custos_math/matrix_multiply.rs

/// Dense matrix–matrix multiplication (GEMM) over strided slices of `Self`.
///
/// Implementing this trait for an element type lets generic code hand a
/// multiplication of `Self` buffers to whichever backend the implementation
/// chooses.
#[rustfmt::skip]
pub trait MatrixMultiply: Sized {
    /// Multiplies the `m × k` matrix `a` by the `k × n` matrix `b` into the
    /// `m × n` matrix `c`.
    ///
    /// Each matrix is described by a slice together with a row stride (`rs*`)
    /// and a column stride (`cs*`), counted in elements. NOTE(review): the
    /// `f32`/`f64` impls forward these to `matrixmultiply::{s,d}gemm`, where
    /// element `(i, j)` is found at offset `i * rs + j * cs`.
    ///
    /// The `f32`/`f64` implementations in this file accumulate: they compute
    /// `c += a · b` rather than overwriting `c`.
    fn gemm(
        m: usize, k: usize, n: usize,
        a: &[Self], rsa: usize, csa: usize,
        b: &[Self], rsb: usize, csb: usize,
        c: &mut [Self], rsc: usize, csc: usize,
    );
}

9#[allow(unused)]
10mod implements {
11 use super::MatrixMultiply;
12
13 #[rustfmt::skip]
14 impl MatrixMultiply for f32 {
15 #[inline]
16 fn gemm(m: usize, k: usize, n: usize,
17 a: &[Self], rsa: usize, csa: usize,
18 b: &[Self], rsb: usize, csb: usize,
19 c: &mut [Self], rsc: usize, csc: usize) {
20
21 #[cfg(not(feature = "matrixmultiply"))]
22 unimplemented!("Activate the matrixmultiply feature");
23
24 #[cfg(feature = "matrixmultiply")]
25 unsafe {
26 matrixmultiply::sgemm(m, k, n, 1., a.as_ptr(), rsa as isize, csa as isize, b.as_ptr(), rsb as isize, csb as isize, 1., c.as_mut_ptr(), rsc as isize, csc as isize);
27 }
28
29 }
30 }
31
32 #[rustfmt::skip]
33 impl MatrixMultiply for f64 {
34 #[inline]
35 fn gemm(m: usize, k: usize, n: usize,
36 a: &[Self], rsa: usize, csa: usize,
37 b: &[Self], rsb: usize, csb: usize,
38 c: &mut [Self], rsc: usize, csc: usize)
39 {
40 #[cfg(not(feature = "matrixmultiply"))]
41 unimplemented!("Activate the matrixmultiply feature");
42 #[cfg(feature = "matrixmultiply")]
43 unsafe {
44 matrixmultiply::dgemm(m, k, n, 1., a.as_ptr(), rsa as isize, csa as isize, b.as_ptr(), rsb as isize, csb as isize, 1., c.as_mut_ptr(), rsc as isize, csc as isize);
45 }
46 }
47 }
48}