Skip to main content

etensor_core/backends/cpu/
matmul.rs

//! High-performance CPU Matrix Multiplication (GEMM) kernel.
//!
//! This kernel uses the `matrixmultiply` crate, which implements a BLIS-style
//! macro/microkernel approach with cache-blocked packing, SIMD vectorization
//! (AVX/FMA/SSE2/NEON), and optional multithreading.
//!
//! The kernel strictly respects memory strides. This means that if a user
//! transposes a tensor (which is a zero-copy O(1) operation), this kernel reads
//! the memory in transposed order without ever allocating a duplicate buffer.
10
11use crate::tensor::Tensor;
12use crate::buffer::Buffer;
13use crate::shape::Shape;
14use crate::dtypes::DType;
15use crate::device::Device;
16use crate::errors::{EtensorError, EtensorResult};
17
18/// Executes the physical forward pass for 2D Matrix Multiplication: C = A @ B
19/// 
20/// Uses BLIS-style cache-tiled GEMM via the `matrixmultiply` crate.
21/// This computes: C = alpha * A @ B + beta * C
22/// With alpha=1.0, beta=0.0 (standard matmul).
23pub fn matmul_forward(a: &Tensor, b: &Tensor) -> EtensorResult<Tensor> {
24    // 1. Gatekeeping: Ensure mathematical dimensions are valid for Dot Product
25    if a.shape.rank() != 2 || b.shape.rank() != 2 {
26        return Err(EtensorError::ShapeMismatch {
27            expected: vec![2, 2], // Dummy values to indicate 2D requirement
28            got: vec![a.shape.rank(), b.shape.rank()],
29        });
30    }
31
32    let m = a.shape.dims[0];
33    let k_a = a.shape.dims[1];
34    let k_b = b.shape.dims[0];
35    let n = b.shape.dims[1];
36
37    if k_a != k_b {
38        return Err(EtensorError::ShapeMismatch {
39            expected: vec![m, k_a],
40            got: vec![k_a, n],
41        });
42    }
43
44    // 2. Extract physical memory
45    let slice_a = a.data.as_f32_slice()?;
46    let slice_b = b.data.as_f32_slice()?;
47
48    // 3. Allocate the contiguous output buffer (Shape: [M, N])
49    let mut out_vec = vec![0.0_f32; m * n];
50
51    // 4. Extract strides for zero-copy transposed view support
52    let stride_a0 = a.shape.strides[0] as isize; // row stride
53    let stride_a1 = a.shape.strides[1] as isize; // column stride
54    let stride_b0 = b.shape.strides[0] as isize;
55    let stride_b1 = b.shape.strides[1] as isize;
56
57    // 5. BLIS-style cache-tiled GEMM
58    // C = 1.0 * A @ B + 0.0 * C
59    // The matrixmultiply crate natively supports arbitrary strides, which means
60    // transposed tensors (where strides are swapped) work without any data copying.
61    unsafe {
62        matrixmultiply::sgemm(
63            m,           // rows of A / rows of C
64            k_a,         // cols of A / rows of B
65            n,           // cols of B / cols of C
66            1.0,         // alpha
67            slice_a.as_ptr(),
68            stride_a0,   // row stride of A
69            stride_a1,   // col stride of A
70            slice_b.as_ptr(),
71            stride_b0,   // row stride of B
72            stride_b1,   // col stride of B
73            0.0,         // beta
74            out_vec.as_mut_ptr(),
75            n as isize,  // row stride of C (contiguous, row-major)
76            1,           // col stride of C
77        );
78    }
79
80    // 6. Construct the final output tensor
81    let out_shape = Shape::new(vec![m, n]);
82    
83    Ok(Tensor::new(
84        Buffer::from_f32_vec(out_vec),
85        out_shape,
86        Device::Cpu,
87        DType::F32,
88        false, // Gradients are handled exclusively by the Dispatcher.
89    ))
90}
91
92// =====================================================================
93// UNIT TESTS
94// =====================================================================
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a CPU `f32` tensor with the given dims from `data`.
    ///
    /// The trailing `false` mirrors the flag `matmul_forward` passes to
    /// `Tensor::new` (presumably "requires grad"; confirm against `Tensor`).
    fn make_matrix(dims: Vec<usize>, data: Vec<f32>) -> Tensor {
        Tensor::new(
            Buffer::from_f32_vec(data),
            Shape::new(dims),
            Device::Cpu,
            DType::F32,
            false,
        )
    }

    #[test]
    fn test_standard_matmul() {
        // A: 2x3 matrix
        // [1.0, 2.0, 3.0]
        // [4.0, 5.0, 6.0]
        let a = make_matrix(vec![2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);

        // B: 3x2 matrix
        // [7.0, 8.0]
        // [9.0, 1.0]
        // [2.0, 3.0]
        let b = make_matrix(vec![3, 2], vec![7.0, 8.0, 9.0, 1.0, 2.0, 3.0]);

        // C = A @ B (expected shape: 2x2)
        // C[0, 0] = (1*7) + (2*9) + (3*2) = 7 + 18 + 6 = 31
        // C[0, 1] = (1*8) + (2*1) + (3*3) = 8 + 2 + 9 = 19
        // C[1, 0] = (4*7) + (5*9) + (6*2) = 28 + 45 + 12 = 85
        // C[1, 1] = (4*8) + (5*1) + (6*3) = 32 + 5 + 18 = 55
        let c = matmul_forward(&a, &b).unwrap();
        let slice = c.data.as_f32_slice().unwrap();

        assert_eq!(c.shape.dims, vec![2, 2]);
        assert_eq!(slice, &[31.0, 19.0, 85.0, 55.0]);
    }

    #[test]
    fn test_strided_zero_copy_matmul() {
        // Multiply a transposed view without rearranging its memory.

        // A: 3x2 matrix
        // [1.0, 4.0]
        // [2.0, 5.0]
        // [3.0, 6.0]
        let a_orig = make_matrix(vec![3, 2], vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);

        // Logically 2x3 after the transpose:
        // [1.0, 2.0, 3.0]
        // [4.0, 5.0, 6.0]
        let a_t = a_orig.transpose();

        let b = make_matrix(vec![3, 2], vec![7.0, 8.0, 9.0, 1.0, 2.0, 3.0]);

        // C = A^T @ B. `a_t` is logically the same matrix as `a` in
        // `test_standard_matmul`, so the result must be identical.
        let c = matmul_forward(&a_t, &b).unwrap();
        let slice = c.data.as_f32_slice().unwrap();

        assert_eq!(c.shape.dims, vec![2, 2]);
        assert_eq!(slice, &[31.0, 19.0, 85.0, 55.0]);
    }

    #[test]
    fn test_matmul_shape_rejection() {
        // 2x3 @ 4x2 -> inner dimensions (3 and 4) do not match.
        let a = make_matrix(vec![2, 3], vec![0.0; 6]);
        let b = make_matrix(vec![4, 2], vec![0.0; 8]);

        let result = matmul_forward(&a, &b);
        assert!(result.is_err());
    }

    #[test]
    fn test_matmul_rank_rejection() {
        // Only rank-2 operands are accepted, on either side of the product.
        let vector = make_matrix(vec![3], vec![0.0; 3]);
        let matrix = make_matrix(vec![3, 2], vec![0.0; 6]);

        assert!(matmul_forward(&vector, &matrix).is_err());
        assert!(matmul_forward(&matrix, &vector).is_err());
    }

    #[test]
    fn test_large_matmul_correctness() {
        // 64x64 identity @ 64x64 ones = 64x64 ones (every element stays 1.0).
        let n = 64;
        let mut identity = vec![0.0_f32; n * n];
        for i in 0..n {
            identity[i * n + i] = 1.0;
        }
        let ones = vec![1.0_f32; n * n];

        let a = make_matrix(vec![n, n], identity);
        let b = make_matrix(vec![n, n], ones);

        let c = matmul_forward(&a, &b).unwrap();
        let slice = c.data.as_f32_slice().unwrap();

        // Pin shape and length first so the element check cannot pass vacuously.
        assert_eq!(c.shape.dims, vec![n, n]);
        assert_eq!(slice.len(), n * n);

        // I @ ones = ones
        for &val in slice {
            assert_eq!(val, 1.0);
        }
    }
}
192}