1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
//! Auto-detected GPU kernel configuration.
//!
//! Replaces per-machine env var tuning (`DP4A_Q4K`, `HW_DP4A_Q4K`, `MWV_Q6K`, etc.)
//! with automatic detection based on `compute_capability()`.
//!
//! Env vars still work as overrides for experimentation, but the defaults are
//! now correct for each GPU — no more per-machine config drift.
use trueno_gpu::driver::CudaContext;
/// Kernel variant for Q4K GEMV dispatch.
///
/// Auto-detection (`GpuProfile::detect_q4k`) selects `HwDp4a` on DP4A-capable
/// GPUs (sm_75+) and `Mwv` otherwise; the remaining variants are reachable
/// only via env-var overrides (`WIDE_Q4K_DISABLE`, `WIDE_Q4K`, `VECTORIZED_Q4K`,
/// `HW_DP4A_Q4K`, `DP4A_Q4K`, `MWV_Q4K`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Q4kVariant {
    /// Legacy single-warp (32 threads), no DP4A. Fallback for sm < 7.5.
    Legacy,
    /// Wide: 128 threads per output row. (Override-only: `WIDE_Q4K`.)
    Wide,
    /// Vectorized: 32 threads with vectorized loads. (Override-only: `VECTORIZED_Q4K`.)
    Vectorized,
    /// Multi-warp DP4A: 32 threads/super-block with shfl broadcast.
    MwvDp4a,
    /// Half-warp DP4A: 16 threads/super-block, direct scale loads. Best on sm_75+;
    /// this is the auto-detected default on DP4A hardware.
    HwDp4a,
    /// Multi-warp vectorized (no DP4A). Auto-detected fallback below sm_75.
    Mwv,
}
/// Kernel variant for Q6K GEMV dispatch.
///
/// Auto-detection (`GpuProfile::detect_q6k`) selects `HwDp4a` on sm_75+
/// and `Mwv` otherwise; `Dp4a` and `Legacy` are reachable only via env-var
/// overrides (`DP4A_Q6K`, or disabling everything else).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Q6kVariant {
    /// Original single-warp Q6K (fallback).
    Legacy,
    /// Multi-warp vectorized Q6K (GH-118). Auto-detected fallback below sm_75.
    Mwv,
    /// DP4A Q6K with Q8 pre-quantization. (Override-only: `DP4A_Q6K`.)
    Dp4a,
    /// Half-warp DP4A Q6K: 16 threads/SB, direct scale loads (PMAT-030).
    /// Auto-detected default on DP4A hardware.
    HwDp4a,
}
/// Auto-detected GPU profile for kernel dispatch.
///
/// Computed once at executor init from `compute_capability()`.
/// All kernel dispatch reads from this instead of env vars.
#[derive(Debug, Clone)]
pub struct GpuProfile {
    /// Q4K GEMV kernel variant (auto-detected: HwDp4a on sm_75+, Mwv otherwise).
    pub q4k: Q4kVariant,
    /// Q6K GEMV kernel variant (auto-detected: HwDp4a on sm_75+, Mwv otherwise).
    pub q6k: Q6kVariant,
    /// Multi-warp GEMV warp count (default: 3, override: MWV_WARPS env).
    pub mwv_warps: u32,
    /// Whether batched prefill is enabled (default: true, override: BATCHED_PREFILL=0).
    pub batched_prefill: bool,
    /// Whether to use cuBLAS HGEMM for decode (M=1).
    /// Default OFF: PMAT-037 found it slower than Q4K GEMV at M=1.
    /// Override: HGEMM_DECODE=1 to force on.
    pub hgemm_decode: bool,
    /// Whether to use fused gate+up+SwiGLU kernel (PMAT-034).
    /// Saves 11% instructions + eliminates SwiGLU kernel + 4 buffer passes.
    /// Auto-detected: true when q4k=HwDp4a. Override: FUSED_GATE_UP=0/1.
    pub fused_gate_up: bool,
    /// PMAT-067: Use FP8 E4M3 weights for prefill GEMM (1 B/elem vs FP16's 2 B/elem).
    /// Auto-detected: true on sm_89+ (Ada Lovelace FP8 tensor cores).
    /// Override: FP8_PREFILL=0 to disable, FP8_PREFILL=1 to force.
    /// cuBLASLt FP8 GEMM halves weight bandwidth — TTFT improvement ~1.25x.
    pub fp8_prefill: bool,
    /// PMAT-090: Use FP8 cuBLASLt GEMM for batched decode (M>=2).
    /// DP4A Q4K GEMV is compute-bound at M>1 (DP4A ceiling = 306 tok/s at M=4).
    /// FP8 (1 B/elem) reads 1.78× more than Q4K (0.5625) but tensor cores keep
    /// it memory-bound — expected: ITL ~15→~8ms, aggregate ~257→~380 tok/s.
    /// Auto-detected: follows `fp8_prefill` (PMAT-410).
    /// Override: FP8_DECODE=0 to disable, FP8_DECODE=1 to force.
    pub fp8_decode: bool,
    /// PMAT-091: Use column-interleaved Q4K WMMA GEMM for batched decode (M>=2).
    /// W4A16: INT4 storage (Q4K, 0.5625 B/elem) + FP16 tensor core compute.
    /// Interleaved layout fixes 864-byte cross-column stride → perfect 128B coalescing.
    /// At 70% WMMA efficiency: est. +34% c=4 aggregate over FP8.
    /// Default OFF (experimental). Override: W4A16_INTERLEAVED=1 to enable (sm_70+).
    pub w4a16_interleaved: bool,
    /// PTX `.target` string for logging (e.g., "sm_89").
    /// NOTE: clamped to "sm_90" on newer GPUs (GH-480 — PTX 8.0's max target);
    /// it is NOT always the real device capability. Use `cc` for comparisons.
    pub sm_target: String,
    /// Numeric compute capability (major*10 + minor, e.g. 89 for sm_89).
    /// Used for numeric comparisons instead of string lexicographic (avoids sm_100 bug).
    pub cc: u32,
}
impl GpuProfile {
    /// Detect the best kernel configuration for the current GPU.
    ///
    /// Env vars take priority over hardware auto-detection, so overrides
    /// like `HW_DP4A_Q4K=1` remain usable for experiments, while production
    /// deployments can run with zero env vars.
    pub fn detect(context: &CudaContext) -> Self {
        let (major, minor) = context.compute_capability().unwrap_or((7, 0));
        // GH-480: the PTX source `.target` directive cannot exceed what PTX 8.0
        // understands (sm_90). The CUDA JIT compiler (`CU_JIT_TARGET` in module.rs)
        // still receives the REAL capability (e.g. 121 for Blackwell) and compiles
        // natively; `.target` only states the minimum ISA the source requires.
        let beyond_sm90 = major > 9 || (major == 9 && minor > 0);
        let (ptx_major, ptx_minor) = if beyond_sm90 { (9, 0) } else { (major, minor) };
        let sm_target = format!("sm_{ptx_major}{ptx_minor}");
        // DP4A (INT8 dot product) arrived with Turing, sm_75.
        let has_dp4a = major > 7 || (major == 7 && minor >= 5);
        let num_sms = context.multiprocessor_count().unwrap_or(8) as u32;

        let q4k = Self::detect_q4k(has_dp4a);
        let q6k = Self::detect_q6k(has_dp4a);
        let mwv_warps = Self::detect_mwv_warps();
        let batched_prefill = Self::detect_batched_prefill();
        let hgemm_decode = Self::detect_hgemm_decode(has_dp4a, num_sms);
        let fused_gate_up = Self::detect_fused_gate_up(&q4k);
        let cc = major as u32 * 10 + minor as u32;
        let fp8_prefill = Self::detect_fp8_prefill(cc);
        let fp8_decode = Self::detect_fp8_decode(fp8_prefill, cc);
        let w4a16_interleaved = Self::detect_w4a16_interleaved(cc);

        let profile = Self {
            q4k,
            q6k,
            mwv_warps,
            batched_prefill,
            hgemm_decode,
            fused_gate_up,
            fp8_prefill,
            fp8_decode,
            w4a16_interleaved,
            sm_target,
            cc,
        };
        // One-line startup banner so deployed configs are visible in logs.
        eprintln!(
            "[GpuProfile] {}: q4k={:?}, q6k={:?}, warps={}, batched_prefill={}, hgemm_decode={}, fused_gate_up={}, fp8_prefill={}, fp8_decode={}, w4a16_il={}, sms={}",
            profile.sm_target, profile.q4k, profile.q6k, profile.mwv_warps,
            profile.batched_prefill, profile.hgemm_decode, profile.fused_gate_up,
            profile.fp8_prefill, profile.fp8_decode, profile.w4a16_interleaved, num_sms,
        );
        profile
    }

    /// Q4K variant: first matching env override wins, else hardware default
    /// (HwDp4a on sm_75+, Mwv below).
    fn detect_q4k(has_dp4a: bool) -> Q4kVariant {
        // Overrides listed in priority order (experimentation only).
        let overrides = [
            ("WIDE_Q4K_DISABLE", Q4kVariant::Legacy),
            ("WIDE_Q4K", Q4kVariant::Wide),
            ("VECTORIZED_Q4K", Q4kVariant::Vectorized),
            ("HW_DP4A_Q4K", Q4kVariant::HwDp4a),
            ("DP4A_Q4K", Q4kVariant::MwvDp4a),
            // PMAT-096: force the FP32 MWV variant (no Q8 quantization overhead).
            ("MWV_Q4K", Q4kVariant::Mwv),
        ];
        for (name, variant) in overrides {
            if std::env::var(name).is_ok() {
                return variant;
            }
        }
        // Auto-detect: half-warp DP4A on any GPU with DP4A support (sm_75+).
        if has_dp4a { Q4kVariant::HwDp4a } else { Q4kVariant::Mwv }
    }

    /// Q6K variant: first matching env override wins, else hardware default
    /// (HwDp4a on sm_75+, Mwv below).
    fn detect_q6k(has_dp4a: bool) -> Q6kVariant {
        let overrides = [
            ("HW_DP4A_Q6K", Q6kVariant::HwDp4a),
            ("DP4A_Q6K", Q6kVariant::Dp4a),
            ("MWV_Q6K", Q6kVariant::Mwv),
        ];
        for (name, variant) in overrides {
            if std::env::var(name).is_ok() {
                return variant;
            }
        }
        if has_dp4a { Q6kVariant::HwDp4a } else { Q6kVariant::Mwv }
    }

    /// MWV warp count: `MWV_WARPS` override, else 3.
    /// PMAT-089: 4 warps FALSIFIED (-2% decode due to register pressure). 3 is optimal.
    fn detect_mwv_warps() -> u32 {
        match std::env::var("MWV_WARPS") {
            Ok(raw) => raw.parse().unwrap_or(3),
            Err(_) => 3,
        }
    }

    /// Batched prefill: on unless explicitly disabled with `BATCHED_PREFILL=0`.
    fn detect_batched_prefill() -> bool {
        !matches!(std::env::var("BATCHED_PREFILL").as_deref(), Ok("0"))
    }

    /// Fused gate+up+SwiGLU: enabled whenever HW DP4A Q4K is active (the fused
    /// kernel exists only for that variant). Saves 11% instructions plus the
    /// SwiGLU kernel and 4 intermediate buffer passes.
    /// Override: `FUSED_GATE_UP=0` disables, any other value forces on.
    fn detect_fused_gate_up(q4k: &Q4kVariant) -> bool {
        match std::env::var("FUSED_GATE_UP") {
            Ok(v) => v != "0",
            Err(_) => matches!(q4k, Q4kVariant::HwDp4a),
        }
    }

    /// PMAT-053b: FP8 prefill — default ON for sm_89+ (Ada/Hopper and newer).
    ///
    /// FP8 E4M3 weights are 1 B/elem vs FP16's 2 B/elem — halves weight bandwidth.
    /// Per-tensor absmax scaling recovers dynamic range (TTFT 46.4→35.5ms, 1.31x).
    /// Override: `FP8_PREFILL=0` to disable, `FP8_PREFILL=1` to force on older GPUs.
    ///
    /// GH-542: Blackwell (sm_100+, cc >= 100) FP8 warmup crashes the context,
    /// but FP8 cuBLASLt GEMM itself works on sm_121 (PMAT-410 verified). So FP8
    /// prefill stays enabled on all cc >= 89; warmup_fp8_cache guards the crash
    /// separately (cc < 100 check in attention.rs).
    fn detect_fp8_prefill(cc: u32) -> bool {
        let forced = std::env::var("FP8_PREFILL").ok();
        match forced.as_deref() {
            Some("0") => false,
            Some("1") => true,
            _ => cc >= 89,
        }
    }

    /// PMAT-090: FP8 batched decode — cuBLASLt FP8 GEMM replaces DP4A Q4K GEMV at M>=2.
    ///
    /// DP4A GEMV is compute-bound at M>1: 4 independent DP4A accumulation chains
    /// saturate INT32 units (theoretical ceiling 306 tok/s at M=4). FP8 cuBLASLt
    /// reads 1.78× more bandwidth (1 B/elem vs Q4K's 0.5625) but stays
    /// memory-bound via tensor cores. Expected: ~1.5× c=4 aggregate improvement.
    ///
    /// PMAT-410: follows `fp8_prefill`; no separate cc guard needed.
    /// Override: `FP8_DECODE=0` to disable, `FP8_DECODE=1` to force.
    fn detect_fp8_decode(fp8_prefill: bool, _cc: u32) -> bool {
        let forced = std::env::var("FP8_DECODE").ok();
        match forced.as_deref() {
            Some("0") => false,
            Some("1") => true,
            _ => fp8_prefill,
        }
    }

    /// PMAT-091: W4A16 interleaved WMMA for batched decode.
    /// Requires sm_70+ for WMMA tensor cores. Default OFF (experimental,
    /// pending benchmarks); `W4A16_INTERLEAVED=1` enables on capable hardware.
    fn detect_w4a16_interleaved(cc: u32) -> bool {
        match std::env::var("W4A16_INTERLEAVED").ok().as_deref() {
            Some("1") => cc >= 70,
            // "0", any other value, or unset: stay off until benchmarked.
            _ => false,
        }
    }

    /// HGEMM decode: cuBLAS HGEMM (cached FP16 weights) for M=1 decode.
    ///
    /// PMAT-037 RESULT: cuBLAS HGEMM at M=1 is SLOWER than Q4K GEMV on both
    /// 4090 (109 vs 193 tok/s) and Jetson Orin — FP16 reads 3.56x more data
    /// and cuBLAS launch overhead dominates. Default OFF; only `HGEMM_DECODE=1`
    /// turns it on. (The batched M>=4 prefill path is where HGEMM pays off.)
    fn detect_hgemm_decode(_has_dp4a: bool, _num_sms: u32) -> bool {
        matches!(std::env::var("HGEMM_DECODE").as_deref(), Ok("1"))
    }
}