/// Generates a `matmul_with_out` method for a tensor-view wrapper type.
///
/// - `$view_type`: the wrapper type (not referenced in this body; kept so the
///   macro signature matches its sibling `impl_matmul!`).
/// - `$lock`: fn/closure mapping `&handle` to the inner tensor cell.
/// - `$into_handle`: unused here (consumed by the paired `impl_matmul!`).
#[macro_export]
macro_rules! impl_matmul_with_out {
    ($view_type:ident, $lock:ident, $into_handle:expr) => {
        /// Computes `self @ other` and writes the result into the
        /// preallocated `out` tensor. Supports only 2-D f32 tensors; all
        /// three tensors must share dtype and device kind.
        ///
        /// NOTE(review): if `out` shares its underlying handle with `self`
        /// or `other` (e.g. through a cloned Rc), the `borrow_mut` below
        /// panics at runtime — TODO confirm callers never alias handles.
        fn matmul_with_out(&self, other: &Self, out: &mut Self) -> Result<(), String> {
            // -- shape validation ------------------------------------------
            let shape_self = self.shape();
            let shape_other = other.shape();
            let shape_out = out.shape();
            if shape_self.len() != 2 || shape_other.len() != 2 || shape_out.len() != 2 {
                return Err("matmul only supports 2D tensors".into());
            }
            let (m, k1) = (shape_self[0], shape_self[1]);
            let (k2, n) = (shape_other[0], shape_other[1]);
            if k1 != k2 {
                return Err("Inner dimensions must match".into());
            }
            if shape_out != &[m, n] {
                return Err("Output shape must be [M, N]".into());
            }
            // -- borrow the three tensors (early returns drop the borrows) --
            let a_cell = $lock(&self.handle);
            let a_t = a_cell.borrow();
            let b_cell = $lock(&other.handle);
            let b_t = b_cell.borrow();
            let c_cell = $lock(&out.handle);
            let mut c_t = c_cell.borrow_mut();
            // -- dtype validation ------------------------------------------
            if a_t.dtype() != b_t.dtype() || a_t.dtype() != c_t.dtype() {
                return Err("Dtype mismatch".into());
            }
            if a_t.dtype() != $crate::DTYPE_FLOAT32 {
                return Err("matmul only supports f32 for now".into());
            }
            // -- device validation (fix) -----------------------------------
            // The kernel below is dispatched only on `a`'s device; without
            // this check a mixed CPU/GPU call would hand a foreign pointer
            // to the wrong kernel inside the `unsafe` block. This compares
            // device *kind* only — it does not distinguish two different
            // CUDA device ids; TODO tighten if multi-GPU becomes a concern.
            let same_device_kind = matches!(
                (a_t.device(), b_t.device(), c_t.device()),
                (
                    $crate::device::Device::Cpu,
                    $crate::device::Device::Cpu,
                    $crate::device::Device::Cpu
                ) | (
                    $crate::device::Device::Cuda(_),
                    $crate::device::Device::Cuda(_),
                    $crate::device::Device::Cuda(_)
                )
            );
            if !same_device_kind {
                return Err("Device mismatch".into());
            }
            // -- strides and raw pointers ----------------------------------
            let a_strides = self.strides();
            let b_strides = other.strides();
            let c_strides = out.strides();
            let a_stride_row = a_strides[0];
            let a_stride_col = a_strides[1];
            let b_stride_row = b_strides[0];
            let b_stride_col = b_strides[1];
            let c_stride_row = c_strides[0];
            let c_stride_col = c_strides[1];
            let a_ptr = a_t.data_ptr(None);
            let b_ptr = b_t.data_ptr(None);
            let c_ptr = c_t.data_mut_ptr(None);
            // -- kernel dispatch -------------------------------------------
            match a_t.device() {
                // SAFETY: shapes, strides, dtype, and device kind were
                // validated above; pointers come from live borrows held for
                // the duration of the call.
                $crate::device::Device::Cpu => unsafe {
                    $crate::kernel::cpu_matmul_strided_f32(
                        a_ptr as *const f32,
                        a_stride_row,
                        a_stride_col,
                        b_ptr as *const f32,
                        b_stride_row,
                        b_stride_col,
                        c_ptr as *mut f32,
                        c_stride_row,
                        c_stride_col,
                        m as i32,
                        n as i32,
                        k1 as i32,
                    );
                },
                $crate::device::Device::Cuda(_) => {
                    let stream = cuda::get_stream().map_err(|e| e.to_string())?;
                    let stream_ptr = stream.as_ptr();
                    // SAFETY: same invariants as the CPU arm, plus a valid
                    // stream obtained just above; the kernel reports failure
                    // through its nonzero return code.
                    unsafe {
                        let err = $crate::kernel::gpu_matmul_strided_f32(
                            a_ptr as *const f32,
                            a_stride_row,
                            a_stride_col,
                            b_ptr as *const f32,
                            b_stride_row,
                            b_stride_col,
                            c_ptr as *mut f32,
                            c_stride_row,
                            c_stride_col,
                            m as i32,
                            n as i32,
                            k1 as i32,
                            stream_ptr,
                        );
                        if err != 0 {
                            return Err(format!("GPU matmul failed with error {}", err));
                        }
                    }
                }
            }
            Ok(())
        }
    };
}
/// Generates an allocating `matmul` method implemented on top of
/// `matmul_with_out`.
///
/// - `$view_type`: the wrapper type (not referenced in this body; kept for
///   signature symmetry with `impl_matmul_with_out!`).
/// - `$into_handle`: converts the freshly allocated `Tensor` into the
///   handle wrapped by `Self::new`.
#[macro_export]
macro_rules! impl_matmul {
    ($view_type:ident, $lock:ident, $into_handle:expr) => {
        /// Computes `self @ other` into a new contiguous `[M, N]` tensor
        /// and returns a view over it.
        ///
        /// # Errors
        /// Returns `Err` for non-2D inputs, allocation failure, or any
        /// failure reported by `matmul_with_out` (dimension/dtype/device
        /// mismatch, kernel errors).
        fn matmul(&self, other: &Self) -> Result<Self, String> {
            let shape_self = self.shape();
            let shape_other = other.shape();
            // Fix: validate rank up front instead of panicking on the
            // indexing below when an input is not 2-D; the error text
            // matches `matmul_with_out` for consistency.
            if shape_self.len() != 2 || shape_other.len() != 2 {
                return Err("matmul only supports 2D tensors".into());
            }
            let (m, n) = (shape_self[0], shape_other[1]);
            let out_tensor = $crate::tensor::Tensor::new_contiguous(vec![m, n], self.dtype())?;
            let mut out_view = Self::new($into_handle(out_tensor));
            self.matmul_with_out(other, &mut out_view)?;
            Ok(out_view)
        }
    };
}
#[cfg(test)]
mod tests {
    use crate::view::rc_view_to_vec_f32;
    use crate::*;

    /// [[1,2],[3,4]] x [[5,6],[7,8]] = [[19,22],[43,50]]
    #[test]
    fn test_matmul() {
        let lhs = Tensor::new_cpu_from_f32(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        let rhs = Tensor::new_cpu_from_f32(vec![5.0, 6.0, 7.0, 8.0], vec![2, 2]);
        let rhs_view = rhs.into_rc().as_view();
        let product = lhs.into_rc().as_view().matmul(&rhs_view).unwrap();
        assert_eq!(rc_view_to_vec_f32(&product), vec![19.0, 22.0, 43.0, 50.0]);
    }
}