1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
//! Matrix multiplication operations for CUDA runtime
use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::ops::BinaryOps;
use crate::ops::{
MatmulOps, ShapeOps, matmul_bias_output_shape, matmul_output_shape, validate_matmul_bias_dtypes,
};
use crate::runtime::cuda::ops::helpers::{
matmul_batched_native, matmul_bias_batched_native, matmul_bias_native, matmul_native,
};
use crate::runtime::cuda::{CudaClient, CudaRuntime};
use crate::runtime::fallback::{matmul_fallback, validate_binary_dtypes};
use crate::tensor::Tensor;
impl MatmulOps<CudaRuntime> for CudaClient {
fn matmul(
&self,
a: &Tensor<CudaRuntime>,
b: &Tensor<CudaRuntime>,
) -> Result<Tensor<CudaRuntime>> {
let dtype = validate_binary_dtypes(a, b)?;
let a_shape = a.shape();
let b_shape = b.shape();
let m = if a_shape.len() >= 2 {
a_shape[a_shape.len() - 2]
} else {
1
};
let k = a_shape[a_shape.len() - 1];
let n = b_shape[b_shape.len() - 1];
let k_b = if b_shape.len() >= 2 {
b_shape[b_shape.len() - 2]
} else {
b_shape[b_shape.len() - 1]
};
if k != k_b {
return Err(Error::ShapeMismatch {
expected: a_shape.to_vec(),
got: b_shape.to_vec(),
});
}
let out_shape = matmul_output_shape(a_shape, b_shape).ok_or(Error::ShapeMismatch {
expected: a_shape.to_vec(),
got: b_shape.to_vec(),
})?;
let batch_size: usize = out_shape
.iter()
.take(out_shape.len().saturating_sub(2))
.product();
let batch_size = batch_size.max(1);
// Native tiled CUDA kernel
match dtype {
DType::F32 | DType::F64 | DType::F16 | DType::BF16 => {
if batch_size > 1 {
matmul_batched_native(self, a, b, dtype, batch_size, m, k, n)
} else {
// Pad unaligned F16/BF16 (m>16) up to 16-multiples so the WMMA
// tensor-core kernel fires. Critical for the varlen-embedding path:
// M = total_tokens is rarely a multiple of 16, so without this F16
// dropped to the ~150x-slower generic kernel (57 vs 8500 GFLOP/s).
// Zero-padding is exact (extra K contributes 0; extra M rows / N
// cols are sliced off); the WMMA kernel only ever sees aligned dims.
let pad_for_wmma = matches!(dtype, DType::F16 | DType::BF16)
&& m > 16
&& (!m.is_multiple_of(16)
|| !k.is_multiple_of(16)
|| !n.is_multiple_of(16));
if pad_for_wmma {
let m_pad = m.next_multiple_of(16);
let k_pad = k.next_multiple_of(16);
let n_pad = n.next_multiple_of(16);
// pad(t, [last_before, last_after, 2nd_last_before, 2nd_last_after])
// — only the last two dims (M=2nd-last of A, K=last of A; N=last
// of B, K=2nd-last of B) are padded; any leading batch dims are
// untouched.
let a_pad = self.pad(a, &[0, k_pad - k, 0, m_pad - m], 0.0)?;
let b_pad = self.pad(b, &[0, n_pad - n, 0, k_pad - k], 0.0)?;
let out_pad =
matmul_native(self, &a_pad, &b_pad, dtype, m_pad, k_pad, n_pad)?;
// Slice the M (2nd-last) and N (last) dims back via negative
// indexing — NOT dims 0/1, since the output may carry leading
// batch dims (e.g. a 3D [1, m, n] from the padded encoder forward,
// where narrowing dim 0 — the size-1 batch — gave a [0, …] tensor).
out_pad.narrow(-2, 0, m)?.narrow(-1, 0, n)?.contiguous()
} else {
matmul_native(self, a, b, dtype, m, k, n)
}
}
}
_ => matmul_fallback(a, b, &out_shape, &self.device, "matmul"),
}
}
fn matmul_bias(
&self,
a: &Tensor<CudaRuntime>,
b: &Tensor<CudaRuntime>,
bias: &Tensor<CudaRuntime>,
) -> Result<Tensor<CudaRuntime>> {
// Validate dtypes using unified helper (ensures consistent error handling across backends)
let dtype = validate_matmul_bias_dtypes(a.dtype(), b.dtype(), bias.dtype())?;
// Validate bias is 1D
if bias.shape().len() != 1 {
return Err(Error::InvalidArgument {
arg: "bias",
reason: format!("bias must be 1D tensor, got shape {:?}", bias.shape()),
});
}
let a_shape = a.shape();
let b_shape = b.shape();
let bias_shape = bias.shape();
let m = if a_shape.len() >= 2 {
a_shape[a_shape.len() - 2]
} else {
1
};
let k = a_shape[a_shape.len() - 1];
let n = b_shape[b_shape.len() - 1];
// Validate inner dimensions
let k_b = if b_shape.len() >= 2 {
b_shape[b_shape.len() - 2]
} else {
b_shape[b_shape.len() - 1]
};
if k != k_b {
return Err(Error::ShapeMismatch {
expected: a_shape.to_vec(),
got: b_shape.to_vec(),
});
}
// Validate bias length matches N
if bias_shape[0] != n {
return Err(Error::InvalidArgument {
arg: "bias",
reason: format!(
"bias length {} must match output columns {}",
bias_shape[0], n
),
});
}
let out_shape =
matmul_bias_output_shape(a_shape, b_shape, bias_shape).ok_or(Error::ShapeMismatch {
expected: a_shape.to_vec(),
got: b_shape.to_vec(),
})?;
let batch_size: usize = out_shape
.iter()
.take(out_shape.len().saturating_sub(2))
.product();
let batch_size = batch_size.max(1);
// Native tiled CUDA kernel with fused bias
match dtype {
DType::F32 | DType::F64 | DType::F16 | DType::BF16 => {
if batch_size > 1 {
matmul_bias_batched_native(self, a, b, bias, dtype, batch_size, m, k, n)
} else {
// Pad unaligned F16/BF16 (m>16) up to 16-multiples so WMMA fires
// (see matmul() for rationale). bias is [n] → pad to [n_pad].
let pad_for_wmma = matches!(dtype, DType::F16 | DType::BF16)
&& m > 16
&& (!m.is_multiple_of(16)
|| !k.is_multiple_of(16)
|| !n.is_multiple_of(16));
if pad_for_wmma {
let m_pad = m.next_multiple_of(16);
let k_pad = k.next_multiple_of(16);
let n_pad = n.next_multiple_of(16);
let a_pad = self.pad(a, &[0, k_pad - k, 0, m_pad - m], 0.0)?;
let b_pad = self.pad(b, &[0, n_pad - n, 0, k_pad - k], 0.0)?;
let bias_pad = self.pad(bias, &[0, n_pad - n], 0.0)?;
let out_pad = matmul_bias_native(
self, &a_pad, &b_pad, &bias_pad, dtype, m_pad, k_pad, n_pad,
)?;
// Slice M (2nd-last) and N (last) via negative indexing — see matmul().
out_pad.narrow(-2, 0, m)?.narrow(-1, 0, n)?.contiguous()
} else {
matmul_bias_native(self, a, b, bias, dtype, m, k, n)
}
}
}
_ => {
// FP8 and other dtypes: fall back to matmul + add
let mm = self.matmul(a, b)?;
self.add(&mm, &bias.reshape(&[1, n])?)
}
}
}
}