ferrite/tensor/device/cpu/kernels/
blas.rs

1use crate::*;
2
3// CBLAS_LAYOUT
4const CBLAS_ROW_MAJOR: u8 = 101;
5const CBLAS_COL_MAJOR: u8 = 102;
6
7// CBLAS_TRANSPOSE
8const CBLAS_NO_TRANS: u8 = 111;
9const CBLAS_TRANS: u8 = 112;
10const CBLAS_CONJ_TRANS: u8 = 113;
11
12#[link(name = "openblas")] // Replace "openblas" with the library you installed if different
13extern "C" {
14  fn cblas_ddot(n: i32, x: *const f64, incx: i32, y: *const f64, incy: i32) -> f64;
15  fn cblas_dgemv(Layout: u8, trans: u8, m: i32, n: i32, alpha: f64, a: *const f64, lda: i32, x: *const f64, incx: i32, beta: f64, y: *mut f64, incy: i32);
16  fn cblas_sgemm(Layout: u8, transa: u8, transb: u8, m: i32, n: i32, k: i32, alpha: f32, a: *const f32, lda: i32, b: *const f32, ldb: i32, beta: f32, c: *mut f32, ldc: i32);
17}
18
19impl BlasOps for CpuStorage {
20  fn matmul(&self, other: &Self, transpose_self: bool, transpose_other: bool) -> Self {
21    if self.shape().len() != 2 { panic!("Can't Matmul on non-matrices"); }
22
23    // Check dimensions
24    if transpose_self && (self.shape()[0] != other.shape()[0]) { 
25      panic!("Matrix dimensions do not match for multiplication.");
26    } else if transpose_other && (self.shape()[1] != other.shape()[1]) { 
27      panic!("Matrix dimensions do not match for multiplication.");
28    } else if !transpose_other && !transpose_self && self.shape()[1] != other.shape()[0] { 
29      panic!("Matrix dimensions do not match for multiplication.");
30    }
31
32    let layout = CBLAS_ROW_MAJOR;
33    let trans_a = if transpose_self { CBLAS_TRANS } else { CBLAS_NO_TRANS };
34    let trans_b = if transpose_other { CBLAS_TRANS } else { CBLAS_NO_TRANS };
35    
36    // Get dimensions
37    let m = if !transpose_self { self.shape()[0] } else { self.shape()[1] };
38    let k = if !transpose_self { self.shape()[1] } else { self.shape()[0] };
39    let n = if !transpose_other { other.shape()[1] } else { other.shape()[0] };
40
41    // Get contiguous data
42    let (a_data, lda) = self.make_contiguous();
43    let (b_data, ldb) = other.make_contiguous();
44
45    let mut c = vec![0.0; (m * n) as usize];
46    let ldc = n as i32;
47
48    unsafe {
49      cblas_sgemm(
50        layout, trans_a, trans_b,
51        m as i32, n as i32, k as i32, 1.0,
52        a_data.as_ptr(), lda,
53        b_data.as_ptr(), ldb, 0.0,
54        c.as_mut_ptr(), ldc
55      );
56    }
57
58    CpuStorage::new(c, vec![m as usize, n as usize])
59  }
60}