Skip to main content

etensor_core/backends/cpu/
fusion.rs

//! High-performance CPU operator-fusion kernels.
//!
//! Operator fusion combines several mathematical operations into a single pass,
//! avoiding intermediate tensor allocations and reducing memory-bandwidth pressure.
//! Uses `matrixmultiply` for BLIS-style GEMM in the fused linear kernel.
6
7use crate::tensor::Tensor;
8use crate::buffer::Buffer;
9use crate::shape::Shape;
10use crate::dtypes::DType;
11use crate::device::Device;
12use crate::errors::{EtensorError, EtensorResult};
13
14/// Fused Element-wise Addition and ReLU: f(A, B) = max(0, A + B)
15/// 
16/// Bandwidth-bound: read 2 values, add, max, write 1 value per element.
17/// Single-threaded SIMD loop already saturates the memory bus.
18pub fn add_relu_forward(a: &Tensor, b: &Tensor) -> EtensorResult<Tensor> {
19    if a.shape.dims != b.shape.dims {
20        return Err(EtensorError::ShapeMismatch {
21            expected: a.shape.dims.clone(),
22            got: b.shape.dims.clone(),
23        });
24    }
25
26    let slice_a = a.data.as_f32_slice()?;
27    let slice_b = b.data.as_f32_slice()?;
28
29    let out_vec: Vec<f32> = slice_a.iter().zip(slice_b).map(|(x, y)| (x + y).max(0.0)).collect();
30
31    Ok(Tensor::new(
32        Buffer::from_f32_vec(out_vec),
33        a.shape.clone(),
34        Device::Cpu,
35        a.dtype,
36        false, // Gradients are exclusively managed by the Dispatcher.
37    ))
38}
39
40/// Fused Linear Layer (MatMul + Bias): y = X @ W + b
41/// Uses BLIS-style cache-tiled GEMM for the matmul portion, then adds bias in a single pass.
42pub fn linear_forward(x: &Tensor, w: &Tensor, b: &Tensor) -> EtensorResult<Tensor> {
43    if x.shape.rank() != 2 || w.shape.rank() != 2 || b.shape.rank() != 1 {
44        return Err(EtensorError::InternalError(
45            "Fused Linear requires 2D Input, 2D Weight, and 1D Bias.".to_string(),
46        ));
47    }
48
49    let m = x.shape.dims[0];
50    let k_x = x.shape.dims[1];
51    let k_w = w.shape.dims[0];
52    let n = w.shape.dims[1];
53
54    if k_x != k_w {
55        return Err(EtensorError::ShapeMismatch {
56            expected: vec![m, k_x],
57            got: vec![k_w, n],
58        });
59    }
60
61    // Bias must perfectly match the output feature dimension (N)
62    if b.shape.dims[0] != n {
63        return Err(EtensorError::ShapeMismatch {
64            expected: vec![n],
65            got: b.shape.dims.clone(),
66        });
67    }
68
69    let slice_x = x.data.as_f32_slice()?;
70    let slice_w = w.data.as_f32_slice()?;
71    let slice_b = b.data.as_f32_slice()?;
72
73    // Pre-fill output with the bias vector (replicated across rows)
74    // This lets us use beta=1.0 in sgemm to fuse the bias addition!
75    let mut out_vec = Vec::with_capacity(m * n);
76    for _ in 0..m {
77        out_vec.extend_from_slice(slice_b);
78    }
79
80    let stride_x0 = x.shape.strides[0] as isize;
81    let stride_x1 = x.shape.strides[1] as isize;
82    let stride_w0 = w.shape.strides[0] as isize;
83    let stride_w1 = w.shape.strides[1] as isize;
84
85    // Fused GEMM + Bias: C = 1.0 * X @ W + 1.0 * C (where C is pre-filled with bias)
86    unsafe {
87        matrixmultiply::sgemm(
88            m, k_x, n,
89            1.0,                        // alpha
90            slice_x.as_ptr(),
91            stride_x0, stride_x1,       // X strides
92            slice_w.as_ptr(),
93            stride_w0, stride_w1,       // W strides
94            1.0,                        // beta = 1.0 means: C = matmul + existing C (bias!)
95            out_vec.as_mut_ptr(),
96            n as isize, 1,              // C strides (contiguous row-major)
97        );
98    }
99
100    Ok(Tensor::new(
101        Buffer::from_f32_vec(out_vec),
102        Shape::new(vec![m, n]),
103        Device::Cpu,
104        DType::F32,
105        false,
106    ))
107}
108
109// =====================================================================
110// UNIT TESTS
111// =====================================================================
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a CPU f32 tensor from `data` with dims `dims` via `Shape::new`.
    fn make_test_tensor(data: Vec<f32>, dims: Vec<usize>) -> Tensor {
        Tensor::new(
            Buffer::from_f32_vec(data),
            Shape::new(dims),
            Device::Cpu,
            DType::F32,
            false,
        )
    }

    #[test]
    fn test_cpu_fusion_add_relu() {
        let a = make_test_tensor(vec![-2.0, 1.0, 3.0], vec![3]);
        let b = make_test_tensor(vec![1.0, -5.0, 2.0], vec![3]);

        // A + B = [-1.0, -4.0, 5.0]  ->  ReLU(A + B) = [0.0, 0.0, 5.0]
        let c = add_relu_forward(&a, &b).unwrap();
        let slice = c.data.as_f32_slice().unwrap();

        assert_eq!(c.shape.dims, vec![3]);
        assert_eq!(slice, &[0.0, 0.0, 5.0]);
    }

    #[test]
    fn test_cpu_fusion_add_relu_rejects_shape_mismatch() {
        let a = make_test_tensor(vec![1.0, 2.0, 3.0], vec![3]);
        let b = make_test_tensor(vec![1.0, 2.0], vec![2]);

        assert!(matches!(
            add_relu_forward(&a, &b),
            Err(EtensorError::ShapeMismatch { .. })
        ));
    }

    #[test]
    fn test_cpu_fusion_linear() {
        // X = [[1, 2], [3, 4]], W = 2 * I, b = [10, 20]
        let x = make_test_tensor(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        let w = make_test_tensor(vec![2.0, 0.0, 0.0, 2.0], vec![2, 2]);
        let b = make_test_tensor(vec![10.0, 20.0], vec![2]);

        // X @ W     = [[2, 4], [6, 8]]
        // X @ W + b = [[12, 24], [16, 28]]
        let y = linear_forward(&x, &w, &b).unwrap();
        let slice = y.data.as_f32_slice().unwrap();

        assert_eq!(y.shape.dims, vec![2, 2]);
        assert_eq!(slice, &[12.0, 24.0, 16.0, 28.0]);
    }

    #[test]
    fn test_cpu_fusion_linear_non_square() {
        // m = 2, k = 3, n = 1: catches any mix-up between m, k and n.
        let x = make_test_tensor(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let w = make_test_tensor(vec![1.0, 0.0, 2.0], vec![3, 1]);
        let b = make_test_tensor(vec![0.5], vec![1]);

        // Row 0: 1*1 + 2*0 + 3*2 = 7  -> 7.5
        // Row 1: 4*1 + 5*0 + 6*2 = 16 -> 16.5
        let y = linear_forward(&x, &w, &b).unwrap();
        assert_eq!(y.shape.dims, vec![2, 1]);
        assert_eq!(y.data.as_f32_slice().unwrap(), &[7.5, 16.5]);

        // m = 1, k = 3, n = 2: the bias must land on the correct columns.
        let x = make_test_tensor(vec![1.0, 2.0, 3.0], vec![1, 3]);
        let w = make_test_tensor(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![3, 2]);
        let b = make_test_tensor(vec![-1.0, 1.0], vec![2]);

        // X @ W = [1 + 6 + 15, 2 + 8 + 18] = [22, 28]  ->  + b = [21, 29]
        let y = linear_forward(&x, &w, &b).unwrap();
        assert_eq!(y.shape.dims, vec![1, 2]);
        assert_eq!(y.data.as_f32_slice().unwrap(), &[21.0, 29.0]);
    }

    #[test]
    fn test_cpu_fusion_linear_rejects_bad_shapes() {
        let x = make_test_tensor(vec![1.0; 6], vec![2, 3]);
        let w = make_test_tensor(vec![1.0; 6], vec![3, 2]);
        let b = make_test_tensor(vec![0.0; 2], vec![2]);

        // Inner dimension mismatch: X is 2x3 but W is 2x2.
        let w_bad_k = make_test_tensor(vec![1.0; 4], vec![2, 2]);
        assert!(matches!(
            linear_forward(&x, &w_bad_k, &b),
            Err(EtensorError::ShapeMismatch { .. })
        ));

        // Bias length (3) does not match the output width (2).
        let b_bad_len = make_test_tensor(vec![0.0; 3], vec![3]);
        assert!(matches!(
            linear_forward(&x, &w, &b_bad_len),
            Err(EtensorError::ShapeMismatch { .. })
        ));

        // Bias must be rank 1.
        let b_bad_rank = make_test_tensor(vec![0.0; 2], vec![1, 2]);
        assert!(matches!(
            linear_forward(&x, &w, &b_bad_rank),
            Err(EtensorError::InternalError(_))
        ));
    }
}