etensor_core/backends/cpu/matmul.rs
1//! High-performance CPU Matrix Multiplication (GEMM) kernel.
2//!
//! This kernel uses the `matrixmultiply` crate which implements a BLIS-style
//! macro/microkernel approach with cache-aware blocked tiling, SIMD vectorization
//! (AVX/FMA/SSE2/NEON), and optional multithreading.
6//!
7//! The kernel strictly respects memory strides. This means if a user transposes
8//! a tensor (which is a zero-copy O(1) operation), this kernel correctly reads
9//! the memory in transposed order without ever allocating a duplicate buffer.
10
11use crate::tensor::Tensor;
12use crate::buffer::Buffer;
13use crate::shape::Shape;
14use crate::dtypes::DType;
15use crate::device::Device;
16use crate::errors::{EtensorError, EtensorResult};
17
18/// Executes the physical forward pass for 2D Matrix Multiplication: C = A @ B
19///
20/// Uses BLIS-style cache-tiled GEMM via the `matrixmultiply` crate.
21/// This computes: C = alpha * A @ B + beta * C
22/// With alpha=1.0, beta=0.0 (standard matmul).
23pub fn matmul_forward(a: &Tensor, b: &Tensor) -> EtensorResult<Tensor> {
24 // 1. Gatekeeping: Ensure mathematical dimensions are valid for Dot Product
25 if a.shape.rank() != 2 || b.shape.rank() != 2 {
26 return Err(EtensorError::ShapeMismatch {
27 expected: vec![2, 2], // Dummy values to indicate 2D requirement
28 got: vec![a.shape.rank(), b.shape.rank()],
29 });
30 }
31
32 let m = a.shape.dims[0];
33 let k_a = a.shape.dims[1];
34 let k_b = b.shape.dims[0];
35 let n = b.shape.dims[1];
36
37 if k_a != k_b {
38 return Err(EtensorError::ShapeMismatch {
39 expected: vec![m, k_a],
40 got: vec![k_a, n],
41 });
42 }
43
44 // 2. Extract physical memory
45 let slice_a = a.data.as_f32_slice()?;
46 let slice_b = b.data.as_f32_slice()?;
47
48 // 3. Allocate the contiguous output buffer (Shape: [M, N])
49 let mut out_vec = vec![0.0_f32; m * n];
50
51 // 4. Extract strides for zero-copy transposed view support
52 let stride_a0 = a.shape.strides[0] as isize; // row stride
53 let stride_a1 = a.shape.strides[1] as isize; // column stride
54 let stride_b0 = b.shape.strides[0] as isize;
55 let stride_b1 = b.shape.strides[1] as isize;
56
57 // 5. BLIS-style cache-tiled GEMM
58 // C = 1.0 * A @ B + 0.0 * C
59 // The matrixmultiply crate natively supports arbitrary strides, which means
60 // transposed tensors (where strides are swapped) work without any data copying.
61 unsafe {
62 matrixmultiply::sgemm(
63 m, // rows of A / rows of C
64 k_a, // cols of A / rows of B
65 n, // cols of B / cols of C
66 1.0, // alpha
67 slice_a.as_ptr(),
68 stride_a0, // row stride of A
69 stride_a1, // col stride of A
70 slice_b.as_ptr(),
71 stride_b0, // row stride of B
72 stride_b1, // col stride of B
73 0.0, // beta
74 out_vec.as_mut_ptr(),
75 n as isize, // row stride of C (contiguous, row-major)
76 1, // col stride of C
77 );
78 }
79
80 // 6. Construct the final output tensor
81 let out_shape = Shape::new(vec![m, n]);
82
83 Ok(Tensor::new(
84 Buffer::from_f32_vec(out_vec),
85 out_shape,
86 Device::Cpu,
87 DType::F32,
88 false, // Gradients are handled exclusively by the Dispatcher.
89 ))
90}
91
92// =====================================================================
93// UNIT TESTS
94// =====================================================================
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a CPU `f32` tensor with the given dims from a flat data vector.
    ///
    /// The expected values in the tests below assume `Shape::new` lays the
    /// data out row-major.
    fn make_matrix(dims: Vec<usize>, data: Vec<f32>) -> Tensor {
        Tensor::new(
            Buffer::from_f32_vec(data),
            Shape::new(dims),
            Device::Cpu,
            DType::F32,
            false,
        )
    }

    #[test]
    fn test_standard_matmul() {
        // A: 2x3 matrix
        // [1.0, 2.0, 3.0]
        // [4.0, 5.0, 6.0]
        let a = make_matrix(vec![2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);

        // B: 3x2 matrix
        // [7.0, 8.0]
        // [9.0, 1.0]
        // [2.0, 3.0]
        let b = make_matrix(vec![3, 2], vec![7.0, 8.0, 9.0, 1.0, 2.0, 3.0]);

        // C = A @ B (Expected Shape: 2x2)
        // C[0, 0] = (1*7) + (2*9) + (3*2) = 7 + 18 + 6 = 31
        // C[0, 1] = (1*8) + (2*1) + (3*3) = 8 + 2 + 9 = 19
        // C[1, 0] = (4*7) + (5*9) + (6*2) = 28 + 45 + 12 = 85
        // C[1, 1] = (4*8) + (5*1) + (6*3) = 32 + 5 + 18 = 55
        let c = matmul_forward(&a, &b).unwrap();
        let slice = c.data.as_f32_slice().unwrap();

        assert_eq!(c.shape.dims, vec![2, 2]);
        assert_eq!(slice, &[31.0, 19.0, 85.0, 55.0]);
    }

    #[test]
    fn test_strided_zero_copy_matmul() {
        // Multiply through a transposed view of A without rearranging its memory.

        // A: 3x2 matrix
        // [1.0, 4.0]
        // [2.0, 5.0]
        // [3.0, 6.0]
        let a_orig = make_matrix(vec![3, 2], vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);

        // Logically a 2x3 matrix after the transpose.
        let a_t = a_orig.transpose();

        let b = make_matrix(vec![3, 2], vec![7.0, 8.0, 9.0, 1.0, 2.0, 3.0]);

        // C = A^T @ B
        // a_t is mathematically identical to A in `test_standard_matmul`,
        // so the result must be the same [31, 19, 85, 55].
        let c = matmul_forward(&a_t, &b).unwrap();
        let slice = c.data.as_f32_slice().unwrap();

        assert_eq!(c.shape.dims, vec![2, 2]);
        assert_eq!(slice, &[31.0, 19.0, 85.0, 55.0]);
    }

    #[test]
    fn test_matmul_shape_rejection() {
        // 2x3 @ 4x2 -> inner dimensions (3 and 4) do not match.
        let a = make_matrix(vec![2, 3], vec![0.0; 6]);
        let b = make_matrix(vec![4, 2], vec![0.0; 8]);

        let result = matmul_forward(&a, &b);

        // Pin the error variant, not just "some error".
        assert!(matches!(
            result,
            Err(EtensorError::ShapeMismatch { .. })
        ));
    }

    #[test]
    fn test_large_matmul_correctness() {
        // 64x64 identity @ B must reproduce B exactly. B holds distinct values
        // (0, 1, 2, ...) so a misplaced or transposed element is detected; every
        // value is a small integer and therefore exactly representable in f32.
        let n = 64;
        let mut identity = vec![0.0_f32; n * n];
        for i in 0..n {
            identity[i * n + i] = 1.0;
        }
        let values: Vec<f32> = (0..n * n).map(|i| i as f32).collect();

        let a = make_matrix(vec![n, n], identity);
        let b = make_matrix(vec![n, n], values.clone());

        let c = matmul_forward(&a, &b).unwrap();
        let slice = c.data.as_f32_slice().unwrap();

        // Check the shape and length explicitly so an empty output cannot pass.
        assert_eq!(c.shape.dims, vec![n, n]);
        assert_eq!(slice.len(), n * n);

        // I @ B = B
        assert_eq!(slice, &values[..]);
    }
}