// ferrite/tensor/device/cpu/kernels/blas.rs
use std::os::raw::c_int;

use crate::*;
3const CBLAS_ROW_MAJOR: u8 = 101;
5const CBLAS_COL_MAJOR: u8 = 102;
6
7const CBLAS_NO_TRANS: u8 = 111;
9const CBLAS_TRANS: u8 = 112;
10const CBLAS_CONJ_TRANS: u8 = 113;
11
12#[link(name = "openblas")] extern "C" {
14 fn cblas_ddot(n: i32, x: *const f64, incx: i32, y: *const f64, incy: i32) -> f64;
15 fn cblas_dgemv(Layout: u8, trans: u8, m: i32, n: i32, alpha: f64, a: *const f64, lda: i32, x: *const f64, incx: i32, beta: f64, y: *mut f64, incy: i32);
16 fn cblas_sgemm(Layout: u8, transa: u8, transb: u8, m: i32, n: i32, k: i32, alpha: f32, a: *const f32, lda: i32, b: *const f32, ldb: i32, beta: f32, c: *mut f32, ldc: i32);
17}
18
19impl BlasOps for CpuStorage {
20 fn matmul(&self, other: &Self, transpose_self: bool, transpose_other: bool) -> Self {
21 if self.shape().len() != 2 { panic!("Can't Matmul on non-matrices"); }
22
23 if transpose_self && (self.shape()[0] != other.shape()[0]) {
25 panic!("Matrix dimensions do not match for multiplication.");
26 } else if transpose_other && (self.shape()[1] != other.shape()[1]) {
27 panic!("Matrix dimensions do not match for multiplication.");
28 } else if !transpose_other && !transpose_self && self.shape()[1] != other.shape()[0] {
29 panic!("Matrix dimensions do not match for multiplication.");
30 }
31
32 let layout = CBLAS_ROW_MAJOR;
33 let trans_a = if transpose_self { CBLAS_TRANS } else { CBLAS_NO_TRANS };
34 let trans_b = if transpose_other { CBLAS_TRANS } else { CBLAS_NO_TRANS };
35
36 let m = if !transpose_self { self.shape()[0] } else { self.shape()[1] };
38 let k = if !transpose_self { self.shape()[1] } else { self.shape()[0] };
39 let n = if !transpose_other { other.shape()[1] } else { other.shape()[0] };
40
41 let (a_data, lda) = self.make_contiguous();
43 let (b_data, ldb) = other.make_contiguous();
44
45 let mut c = vec![0.0; (m * n) as usize];
46 let ldc = n as i32;
47
48 unsafe {
49 cblas_sgemm(
50 layout, trans_a, trans_b,
51 m as i32, n as i32, k as i32, 1.0,
52 a_data.as_ptr(), lda,
53 b_data.as_ptr(), ldb, 0.0,
54 c.as_mut_ptr(), ldc
55 );
56 }
57
58 CpuStorage::new(c, vec![m as usize, n as usize])
59 }
60}