etensor-core 0.0.1

The pure Rust tensor math and autograd engine
Documentation
//! High-performance CPU Matrix Multiplication (GEMM) kernel.
//! 
//! This kernel uses the `matrixmultiply` crate, which implements a BLIS-style
//! macro/microkernel approach with cache-aware blocking, SIMD vectorization
//! (AVX/FMA/SSE2/NEON), and optional multithreading.
//!
//! The kernel strictly respects memory strides. This means if a user transposes 
//! a tensor (which is a zero-copy O(1) operation), this kernel correctly reads 
//! the memory in transposed order without ever allocating a duplicate buffer.

use crate::tensor::Tensor;
use crate::buffer::Buffer;
use crate::shape::Shape;
use crate::dtypes::DType;
use crate::device::Device;
use crate::errors::{EtensorError, EtensorResult};

/// Executes the physical forward pass for 2D matrix multiplication: `C = A @ B`.
///
/// The product is delegated to `matrixmultiply::sgemm`, which computes
/// `C = alpha * A @ B + beta * C`; it is called here with `alpha = 1.0` and
/// `beta = 0.0` (standard matmul). Both operands are read through their own
/// row/column strides, so a strided view such as a transposed tensor is
/// multiplied without copying its data. The result is a freshly allocated,
/// row-major tensor of shape `[M, N]`.
///
/// # Errors
///
/// Returns [`EtensorError::ShapeMismatch`] when:
/// * either operand is not rank 2 (`expected` is the placeholder `[2, 2]`,
///   `got` holds the two ranks);
/// * the inner dimensions differ (`expected` is the shape `B` needed,
///   `[K_a, N]`; `got` is the shape `B` actually has, `[K_b, N]`);
/// * an operand's dims and strides address elements outside its backing
///   buffer (`expected` holds the view's dims, `got` the buffer length).
///
/// Any error returned by `as_f32_slice` is propagated unchanged.
///
/// # Panics
///
/// Panics if `M * N` overflows `usize`.
pub fn matmul_forward(a: &Tensor, b: &Tensor) -> EtensorResult<Tensor> {
    // 1. Gatekeeping: both operands must be matrices.
    if a.shape.rank() != 2 || b.shape.rank() != 2 {
        return Err(EtensorError::ShapeMismatch {
            expected: vec![2, 2], // Dummy values to indicate 2D requirement
            got: vec![a.shape.rank(), b.shape.rank()],
        });
    }

    let m = a.shape.dims[0];
    let k_a = a.shape.dims[1];
    let k_b = b.shape.dims[0];
    let n = b.shape.dims[1];

    // The inner dimensions must agree. Report the shape B was required to
    // have against the shape it really has, so the offending K_b is visible.
    if k_a != k_b {
        return Err(EtensorError::ShapeMismatch {
            expected: vec![k_a, n],
            got: vec![k_b, n],
        });
    }

    // 2. Extract physical memory
    let slice_a = a.data.as_f32_slice()?;
    let slice_b = b.data.as_f32_slice()?;

    // 3. Extract strides for zero-copy transposed view support
    let stride_a0 = a.shape.strides[0] as isize; // row stride
    let stride_a1 = a.shape.strides[1] as isize; // column stride
    let stride_b0 = b.shape.strides[0] as isize;
    let stride_b1 = b.shape.strides[1] as isize;

    // 4. Refuse views that would make the raw-pointer kernel read outside
    //    the backing buffers; without this check the call below is unsound.
    if !strided_view_fits(m, k_a, stride_a0, stride_a1, slice_a.len()) {
        return Err(EtensorError::ShapeMismatch {
            expected: vec![m, k_a],
            got: vec![slice_a.len()],
        });
    }
    if !strided_view_fits(k_b, n, stride_b0, stride_b1, slice_b.len()) {
        return Err(EtensorError::ShapeMismatch {
            expected: vec![k_b, n],
            got: vec![slice_b.len()],
        });
    }

    // 5. Allocate the contiguous output buffer (Shape: [M, N]).
    //    A wrapped product would yield an undersized buffer that the kernel
    //    then overruns, so overflow panics in every build profile.
    let out_len = m
        .checked_mul(n)
        .expect("matmul output element count overflows usize");
    let mut out_vec = vec![0.0_f32; out_len];

    // 6. BLIS-style cache-tiled GEMM
    // C = 1.0 * A @ B + 0.0 * C
    // SAFETY:
    // * `slice_a` / `slice_b` are live for the whole call, and every offset
    //   `i * row_stride + j * col_stride` the kernel can form for A (M x K)
    //   and B (K x N) was verified above to lie inside the respective slice.
    // * `out_vec` holds exactly `M * N` elements and is addressed with
    //   strides `(N, 1)`, so every write stays inside it.
    // * `out_vec` is freshly allocated and therefore cannot alias A or B.
    unsafe {
        matrixmultiply::sgemm(
            m,           // rows of A / rows of C
            k_a,         // cols of A / rows of B
            n,           // cols of B / cols of C
            1.0,         // alpha
            slice_a.as_ptr(),
            stride_a0,   // row stride of A
            stride_a1,   // col stride of A
            slice_b.as_ptr(),
            stride_b0,   // row stride of B
            stride_b1,   // col stride of B
            0.0,         // beta
            out_vec.as_mut_ptr(),
            n as isize,  // row stride of C (contiguous, row-major)
            1,           // col stride of C
        );
    }

    // 7. Construct the final output tensor
    let out_shape = Shape::new(vec![m, n]);

    Ok(Tensor::new(
        Buffer::from_f32_vec(out_vec),
        out_shape,
        Device::Cpu,
        DType::F32,
        false, // Gradients are handled exclusively by the Dispatcher.
    ))
}

/// Returns `true` when every element of a `rows x cols` strided view that
/// starts at index 0 lies inside a buffer of `len` elements.
fn strided_view_fits(
    rows: usize,
    cols: usize,
    row_stride: isize,
    col_stride: isize,
    len: usize,
) -> bool {
    // An empty view touches no memory at all.
    if rows == 0 || cols == 0 {
        return true;
    }
    // The view is anchored at the start of the buffer, so a negative stride
    // would step in front of it.
    if row_stride < 0 || col_stride < 0 {
        return false;
    }
    // With non-negative strides the largest offset is reached at the last
    // row and last column; checked arithmetic treats overflow as "too far".
    let last_row = (rows - 1).checked_mul(row_stride as usize);
    let last_col = (cols - 1).checked_mul(col_stride as usize);
    match (last_row, last_col) {
        (Some(r), Some(c)) => r.checked_add(c).map_or(false, |last| last < len),
        _ => false,
    }
}

// =====================================================================
// UNIT TESTS
// =====================================================================
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a CPU `f32` tensor with the given dims from a flat data vector.
    fn make_matrix(dims: Vec<usize>, data: Vec<f32>) -> Tensor {
        Tensor::new(
            Buffer::from_f32_vec(data),
            Shape::new(dims),
            Device::Cpu,
            DType::F32,
            false,
        )
    }

    #[test]
    fn test_standard_matmul() {
        // A: 2x3 matrix
        // [1.0, 2.0, 3.0]
        // [4.0, 5.0, 6.0]
        let a = make_matrix(vec![2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);

        // B: 3x2 matrix
        // [7.0, 8.0]
        // [9.0, 1.0]
        // [2.0, 3.0]
        let b = make_matrix(vec![3, 2], vec![7.0, 8.0, 9.0, 1.0, 2.0, 3.0]);

        // C = A @ B (Expected Shape: 2x2)
        // C[0, 0] = (1*7) + (2*9) + (3*2) = 7 + 18 + 6 = 31
        // C[0, 1] = (1*8) + (2*1) + (3*3) = 8 + 2 + 9 = 19
        // C[1, 0] = (4*7) + (5*9) + (6*2) = 28 + 45 + 12 = 85
        // C[1, 1] = (4*8) + (5*1) + (6*3) = 32 + 5 + 18 = 55
        let c = matmul_forward(&a, &b).unwrap();
        let slice = c.data.as_f32_slice().unwrap();

        assert_eq!(c.shape.dims, vec![2, 2]);
        assert_eq!(slice, &[31.0, 19.0, 85.0, 55.0]);
    }

    #[test]
    fn test_strided_zero_copy_matmul() {
        // Multiply a transposed view without first materialising it.

        // A: 3x2 matrix
        // [1.0, 4.0]
        // [2.0, 5.0]
        // [3.0, 6.0]
        let a_orig = make_matrix(vec![3, 2], vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);

        // Logically a 2x3 matrix: [[1, 2, 3], [4, 5, 6]].
        let a_t = a_orig.transpose();

        let b = make_matrix(vec![3, 2], vec![7.0, 8.0, 9.0, 1.0, 2.0, 3.0]);

        // C = A^T @ B
        // a_t holds the same logical values as A in `test_standard_matmul`,
        // so the result must be the same [31, 19, 85, 55].
        let c = matmul_forward(&a_t, &b).unwrap();
        let slice = c.data.as_f32_slice().unwrap();

        assert_eq!(c.shape.dims, vec![2, 2]);
        assert_eq!(slice, &[31.0, 19.0, 85.0, 55.0]);
    }

    #[test]
    fn test_matmul_shape_rejection() {
        // 2x3 @ 4x2 -> Inner dimensions (3 and 4) do not match!
        let a = make_matrix(vec![2, 3], vec![0.0; 6]);
        let b = make_matrix(vec![4, 2], vec![0.0; 8]);

        let result = matmul_forward(&a, &b);
        assert!(result.is_err());
    }

    #[test]
    fn test_matmul_rank_rejection() {
        // A rank-1 operand is not a matrix and must be refused.
        let a = make_matrix(vec![3], vec![0.0; 3]);
        let b = make_matrix(vec![3, 2], vec![0.0; 6]);

        assert!(matmul_forward(&a, &b).is_err());
    }

    #[test]
    fn test_large_matmul_correctness() {
        // 64x64 identity @ 64x64 ones = 64x64 ones (every element is 1.0).
        let n = 64;
        let mut identity = vec![0.0_f32; n * n];
        for i in 0..n {
            identity[i * n + i] = 1.0;
        }
        let ones = vec![1.0_f32; n * n];

        let a = make_matrix(vec![n, n], identity);
        let b = make_matrix(vec![n, n], ones);

        let c = matmul_forward(&a, &b).unwrap();
        let slice = c.data.as_f32_slice().unwrap();

        // Pin shape and length first so the element check cannot pass
        // vacuously on an empty or truncated result.
        assert_eq!(c.shape.dims, vec![n, n]);
        assert_eq!(slice.len(), n * n);

        // I @ ones = ones
        for &val in slice {
            assert_eq!(val, 1.0);
        }
    }
}