//! Matrix multiplication operations trait.

use crate::error::Result;
use crate::runtime::Runtime;
use crate::tensor::Tensor;

/// Matrix multiplication operations
pub trait MatmulOps<R: Runtime> {
/// Matrix multiplication: `a @ b`
///
/// Supports batched matmul for tensors with more than 2 dimensions.
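///
/// # Examples
///
/// A minimal usage sketch, reusing the prelude and CPU client from the
/// `matmul_bias` example below:
///
/// ```
/// # use numr::prelude::*;
/// # let device = CpuDevice::new();
/// # let client = CpuRuntime::default_client(&device);
/// use numr::ops::MatmulOps;
///
/// // [2, 3] @ [3, 4] -> [2, 4]
/// let a = client.randn(&[2, 3], DType::F32)?;
/// let b = client.randn(&[3, 4], DType::F32)?;
/// let c = client.matmul(&a, &b)?;
/// # Ok::<(), numr::error::Error>(())
/// ```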
fn matmul(&self, a: &Tensor<R>, b: &Tensor<R>) -> Result<Tensor<R>>;

/// Fused matrix multiplication with bias addition: `C = A @ B + bias`
///
/// This is a fused operation that combines matrix multiplication and bias addition
/// into a single kernel, avoiding an extra memory round-trip compared to separate
/// `matmul` followed by `add`. This is the core operation for neural network linear
/// layers: `output = input @ weight.T + bias`.
///
/// # Algorithm (Epilogue Fusion)
///
/// The bias addition is fused into the GEMM epilogue:
/// ```text
/// 1. Load tiles of A and B into shared memory
/// 2. Compute partial products, accumulate in registers
/// 3. Repeat for all K tiles
/// 4. EPILOGUE: for each output element C[i][j]:
///      C[i][j] = accumulated_value[i][j] + bias[j]
/// 5. Write final result to global memory
/// ```
///
/// This avoids one full global-memory write and read of the intermediate result,
/// compared to the naive two-step sequence:
/// ```text
/// temp = A @ B     // Write temp to global memory
/// C = temp + bias  // Read temp, write C
/// ```
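///
/// For intuition, here is an illustrative scalar reference of the fused loop. This
/// is a sketch of the technique only, not the library's tiled, backend-specific
/// kernel:
///
/// ```
/// // Illustrative only: C[i][j] = sum_k A[i][k] * B[k][j] + bias[j],
/// // with the bias added while the accumulator is still in a register.
/// fn matmul_bias_ref(a: &[f32], b: &[f32], bias: &[f32],
///                    m: usize, k: usize, n: usize) -> Vec<f32> {
///     let mut c = vec![0.0f32; m * n];
///     for i in 0..m {
///         for j in 0..n {
///             let mut acc = 0.0f32; // accumulate in a "register"
///             for p in 0..k {
///                 acc += a[i * k + p] * b[p * n + j];
///             }
///             c[i * n + j] = acc + bias[j]; // epilogue: fused bias add, single write
///         }
///     }
///     c
/// }
///
/// // A = [[1, 2], [3, 4]], B = identity, bias = [0.5, -0.5]
/// let c = matmul_bias_ref(&[1.0, 2.0, 3.0, 4.0], &[1.0, 0.0, 0.0, 1.0], &[0.5, -0.5], 2, 2, 2);
/// assert_eq!(c, vec![1.5, 1.5, 3.5, 3.5]);
/// ```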
///
/// # Arguments
///
/// * `a` - Input tensor of shape `[..., M, K]`
/// * `b` - Weight tensor of shape `[..., K, N]`
/// * `bias` - Bias tensor of shape `[N]` (1D, broadcast across all M rows)
///
/// # Returns
///
/// Output tensor of shape `[..., M, N]`, where `C[..., i, j] = sum_k(A[..., i, k] * B[..., k, j]) + bias[j]`
///
/// # Errors
///
/// Returns `Error::ShapeMismatch` if:
/// - Inner dimensions don't match (A's last dim != B's second-to-last dim)
/// - Bias shape doesn't match the output columns (`bias.len() != N`)
/// - Bias is not 1D
///
/// Returns `Error::DTypeMismatch` if A, B, and bias don't have the same dtype.
///
/// # Examples
///
/// ```
/// # use numr::prelude::*;
/// # let device = CpuDevice::new();
/// # let client = CpuRuntime::default_client(&device);
/// use numr::ops::MatmulOps;
///
/// // Linear layer: output = input @ weight.T + bias
/// let batch = 2; let seq_len = 5; let hidden = 3; let out_features = 4;
/// let input = client.randn(&[batch, seq_len, hidden], DType::F32)?;
/// let weight = client.randn(&[out_features, hidden], DType::F32)?;
/// let bias = client.randn(&[out_features], DType::F32)?;
///
/// // Using fused operation (faster):
/// let output = client.matmul_bias(&input, &weight.transpose(-1, -2)?, &bias)?;
///
/// // Equivalent to (slower - extra memory round-trip):
/// let temp = client.matmul(&input, &weight.transpose(-1, -2)?)?;
/// let output = client.add(&temp, &bias.unsqueeze(0)?.unsqueeze(0)?)?;
/// # Ok::<(), numr::error::Error>(())
/// ```
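///
/// A bias whose length does not match the output columns surfaces as a shape
/// error (a sketch, assuming the same client setup as above):
///
/// ```
/// # use numr::prelude::*;
/// # let device = CpuDevice::new();
/// # let client = CpuRuntime::default_client(&device);
/// # use numr::ops::MatmulOps;
/// let a = client.randn(&[2, 3], DType::F32)?;
/// let b = client.randn(&[3, 4], DType::F32)?;
/// let bad_bias = client.randn(&[5], DType::F32)?; // N is 4, so a length-5 bias mismatches
/// assert!(client.matmul_bias(&a, &b, &bad_bias).is_err());
/// # Ok::<(), numr::error::Error>(())
/// ```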
///
/// # Performance
///
/// Fusing the bias addition into the GEMM epilogue provides:
/// - Lower memory traffic: the intermediate `A @ B` result is never written to,
///   or re-read from, global memory
/// - Better cache utilization (the output stays in registers until the final write)
/// - Reduced kernel launch overhead (one kernel instead of two)
///
/// For large matrices, where the GEMM is compute-bound, the speedup is modest.
/// For smaller matrices (typical of LLM inference), the savings are more significant.
///
/// # Backend Support
///
/// | Backend | Supported DTypes | Tensor Dims | Notes |
/// |---------|------------------|-------------|-------|
/// | CPU | All dtypes | 2D, 3D+ | Full support via generic kernels |
/// | CUDA | F32, F64, F16, BF16 | 2D, 3D+ | Returns `UnsupportedDType` for integers |
/// | WebGPU | F32, I32, U32, F16 | 2D, 3D only | Returns error for >3D tensors |
///
/// 64-bit integer dtypes (I64, U64) are supported only on CPU; WebGPU additionally
/// supports I32 and U32, while CUDA returns `Error::UnsupportedDType` for all
/// integer `matmul_bias` operations. WebGPU dispatch is limited to 3D workgroups,
/// so tensors with more than 3 dimensions return an error on that backend.
fn matmul_bias(&self, a: &Tensor<R>, b: &Tensor<R>, bias: &Tensor<R>) -> Result<Tensor<R>>;
}