use crate::tensor::Tensor;
use crate::buffer::Buffer;
use crate::shape::Shape;
use crate::dtypes::DType;
use crate::device::Device;
use crate::errors::{EtensorError, EtensorResult};
/// Multiplies two rank-2 `f32` tensors on the CPU, returning `a @ b` with shape `[m, n]`.
///
/// `a` is `[m, k]` and `b` is `[k, n]`. Each input's strides are handed directly to
/// `matrixmultiply::sgemm`, so strided views (e.g. the result of `transpose()`) are multiplied
/// without materialising a contiguous copy. The output is a freshly allocated buffer whose
/// elements are written in row-major order (row stride `n`, column stride `1`).
///
/// # Errors
///
/// - `EtensorError::ShapeMismatch { expected: [2, 2], got: [rank(a), rank(b)] }` if either
///   input is not rank 2.
/// - `EtensorError::ShapeMismatch { expected: [k_a, n], got: [k_b, n] }` if the inner
///   dimensions differ, i.e. `expected` is the shape `b` would need and `got` is its actual shape.
/// - Any error propagated from `as_f32_slice` on either input's buffer.
///
/// # Panics
///
/// - If `m * n` overflows `usize`.
/// - If an input's shape and strides address elements outside its buffer, or use a negative
///   stride along an axis of length > 1. Either would otherwise be an out-of-bounds read inside
///   `sgemm`.
pub fn matmul_forward(a: &Tensor, b: &Tensor) -> EtensorResult<Tensor> {
    if a.shape.rank() != 2 || b.shape.rank() != 2 {
        return Err(EtensorError::ShapeMismatch {
            expected: vec![2, 2],
            got: vec![a.shape.rank(), b.shape.rank()],
        });
    }

    let (m, k_a) = (a.shape.dims[0], a.shape.dims[1]);
    let (k_b, n) = (b.shape.dims[0], b.shape.dims[1]);
    if k_a != k_b {
        // Report the shape `b` needed versus the shape it actually has, so the offending
        // inner dimension (`k_b`) is visible in the error.
        return Err(EtensorError::ShapeMismatch {
            expected: vec![k_a, n],
            got: vec![k_b, n],
        });
    }

    let slice_a = a.data.as_f32_slice()?;
    let slice_b = b.data.as_f32_slice()?;

    let stride_a0 = a.shape.strides[0] as isize;
    let stride_a1 = a.shape.strides[1] as isize;
    let stride_b0 = b.shape.strides[0] as isize;
    let stride_b1 = b.shape.strides[1] as isize;

    // sgemm reads relative to the start of each slice, so every addressed element must lie
    // inside it. NOTE(review): this assumes tensors carry no separate storage offset into
    // their buffer — confirm against the Tensor/Shape definitions.
    assert!(
        strided_matrix_fits(m, k_a, stride_a0, stride_a1, slice_a.len()),
        "matmul_forward: shape/strides of `a` address elements outside its buffer"
    );
    assert!(
        strided_matrix_fits(k_b, n, stride_b0, stride_b1, slice_b.len()),
        "matmul_forward: shape/strides of `b` address elements outside its buffer"
    );

    // A wrapped product would allocate too little and let sgemm write past the end.
    let out_len = m
        .checked_mul(n)
        .expect("matmul_forward: output size m * n overflows usize");
    let mut out_vec = vec![0.0_f32; out_len];

    // SAFETY: the asserts above guarantee every element sgemm reads from `slice_a`/`slice_b`
    // is in bounds, and both borrows outlive the call. `out_vec` holds exactly `m * n`
    // elements and is written with row stride `n` and column stride 1, so the highest index
    // touched is `m * n - 1`. It was freshly allocated, so it cannot alias either input.
    unsafe {
        matrixmultiply::sgemm(
            m,
            k_a,
            n,
            1.0,
            slice_a.as_ptr(),
            stride_a0,
            stride_a1,
            slice_b.as_ptr(),
            stride_b0,
            stride_b1,
            0.0,
            out_vec.as_mut_ptr(),
            n as isize,
            1,
        );
    }

    let out_shape = Shape::new(vec![m, n]);
    // NOTE(review): the trailing `false` is passed as-is from the original code; presumably
    // it is a requires-grad style flag — confirm against `Tensor::new`.
    Ok(Tensor::new(
        Buffer::from_f32_vec(out_vec),
        out_shape,
        Device::Cpu,
        DType::F32,
        false,
    ))
}

/// Returns `true` when every element `(i, j)` of a `rows x cols` matrix lies inside a buffer of
/// `len` elements, where element `(i, j)` sits at offset `i * row_stride + j * col_stride` from
/// the buffer start.
///
/// Returns `false` in three cases: an out-of-range offset, a negative stride along an axis with
/// more than one element (which would point before the buffer start), or arithmetic overflow.
fn strided_matrix_fits(
    rows: usize,
    cols: usize,
    row_stride: isize,
    col_stride: isize,
    len: usize,
) -> bool {
    if rows == 0 || cols == 0 {
        // An empty matrix addresses no elements at all.
        return true;
    }
    // Largest offset reached along one axis, or `None` if unrepresentable or negative.
    let span = |count: usize, stride: isize| -> Option<usize> {
        match count - 1 {
            0 => Some(0),
            steps if stride >= 0 => steps.checked_mul(stride as usize),
            _ => None,
        }
    };
    match (span(rows, row_stride), span(cols, col_stride)) {
        (Some(r), Some(c)) => matches!(r.checked_add(c), Some(last) if last < len),
        _ => false,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a CPU `f32` tensor with the given dims, backed by `data` (in row-major order).
    fn make_matrix(dims: Vec<usize>, data: Vec<f32>) -> Tensor {
        Tensor::new(
            Buffer::from_f32_vec(data),
            Shape::new(dims),
            Device::Cpu,
            DType::F32,
            false,
        )
    }

    #[test]
    fn test_standard_matmul() {
        let a = make_matrix(vec![2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = make_matrix(vec![3, 2], vec![7.0, 8.0, 9.0, 1.0, 2.0, 3.0]);
        let c = matmul_forward(&a, &b).unwrap();
        let slice = c.data.as_f32_slice().unwrap();
        assert_eq!(c.shape.dims, vec![2, 2]);
        assert_eq!(slice, &[31.0, 19.0, 85.0, 55.0]);
    }

    /// A transposed left-hand operand (a strided view) must give the same result as the
    /// contiguous one.
    #[test]
    fn test_strided_zero_copy_matmul() {
        let a_orig = make_matrix(vec![3, 2], vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
        let a_t = a_orig.transpose();
        let b = make_matrix(vec![3, 2], vec![7.0, 8.0, 9.0, 1.0, 2.0, 3.0]);
        let c = matmul_forward(&a_t, &b).unwrap();
        let slice = c.data.as_f32_slice().unwrap();
        assert_eq!(c.shape.dims, vec![2, 2]);
        assert_eq!(slice, &[31.0, 19.0, 85.0, 55.0]);
    }

    /// A transposed right-hand operand must also be handled through its strides.
    #[test]
    fn test_strided_rhs_matmul() {
        let a = make_matrix(vec![2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        // Transposing [[7, 9, 2], [8, 1, 3]] yields [[7, 8], [9, 1], [2, 3]], the same B as in
        // test_standard_matmul.
        let b_orig = make_matrix(vec![2, 3], vec![7.0, 9.0, 2.0, 8.0, 1.0, 3.0]);
        let b_t = b_orig.transpose();
        let c = matmul_forward(&a, &b_t).unwrap();
        let slice = c.data.as_f32_slice().unwrap();
        assert_eq!(c.shape.dims, vec![2, 2]);
        assert_eq!(slice, &[31.0, 19.0, 85.0, 55.0]);
    }

    #[test]
    fn test_matmul_shape_rejection() {
        let a = make_matrix(vec![2, 3], vec![0.0; 6]);
        let b = make_matrix(vec![4, 2], vec![0.0; 8]);
        let result = matmul_forward(&a, &b);
        assert!(matches!(result, Err(EtensorError::ShapeMismatch { .. })));
    }

    #[test]
    fn test_matmul_rank_rejection() {
        let a = make_matrix(vec![2, 3], vec![0.0; 6]);
        let v = make_matrix(vec![3], vec![0.0; 3]);
        assert!(matches!(
            matmul_forward(&a, &v),
            Err(EtensorError::ShapeMismatch { .. })
        ));
        assert!(matches!(
            matmul_forward(&v, &a),
            Err(EtensorError::ShapeMismatch { .. })
        ));
    }

    #[test]
    fn test_large_matmul_correctness() {
        let n = 64;
        let mut identity = vec![0.0_f32; n * n];
        for i in 0..n {
            identity[i * n + i] = 1.0;
        }
        let ones = vec![1.0_f32; n * n];
        let a = make_matrix(vec![n, n], identity);
        let b = make_matrix(vec![n, n], ones);
        let c = matmul_forward(&a, &b).unwrap();
        let slice = c.data.as_f32_slice().unwrap();
        // Pin the size first: the per-element loop below would pass vacuously on empty output.
        assert_eq!(c.shape.dims, vec![n, n]);
        assert_eq!(slice.len(), n * n);
        for (idx, &val) in slice.iter().enumerate() {
            assert_eq!(val, 1.0, "unexpected value at flat index {}", idx);
        }
    }
}