impl CudaExecutor {
    /// PAR-118: Graph-compatible Flash Decoding for single-sequence decode (M=1).
    ///
    /// Uses pre-allocated per-layer K/V pointer buffers and static `max_chunks`
    /// grid dimensions for CUDA graph compatibility. The sequence length is read
    /// from `seq_len_buf` during graph capture, or uploaded to a dedicated
    /// buffer in eager mode.
    ///
    /// Five-Whys: the baseline attention kernel launches only 28 blocks (one per
    /// head) on an RTX 4090 (128 SMs), leaving 78% of SMs idle. Flash Decoding
    /// splits the KV scan into `max_chunks` blocks per head, giving
    /// 28 × max_chunks = 224 blocks for 8 chunks. This 8x increase in block
    /// count directly improves SM occupancy and memory-level parallelism.
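    ///
    /// # Grid sizing (illustrative)
    ///
    /// A quick sketch of the launch math, using the RTX 4090 figures above:
    ///
    /// ```text
    /// num_heads = 28, max_chunks = 8
    /// chunk grid  = (28, 1, 8) -> 224 blocks (8x the 28 of one-block-per-head attention)
    /// reduce grid = (28, 1, 1) ->  28 blocks, one per head
    /// ```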
    #[allow(clippy::too_many_arguments)]
    pub fn flash_decoding_graphed(
        &mut self,
        layer_idx: usize,
        q_gpu: &GpuBuffer<f32>,
        out_gpu: &GpuBuffer<f32>,
        _use_graph_mode: bool,
        seq_len: u32,
    ) -> Result<(), GpuError> {
        use trueno_gpu::kernels::{FlashDecodingChunkKernel, FlashDecodingReduceKernel, Kernel};

        let num_heads = self.kv_num_heads;
        let num_kv_heads = self.kv_num_kv_heads;
        let head_dim = self.kv_head_dim;
        let max_len = self.kv_cache_max_len;
        let max_chunks = self.flash_decode_max_chunks;

        // Get the pre-allocated per-layer K/V pointer buffers.
        let k_ptrs_buf = self.flash_decode_k_ptrs.get(&layer_idx).ok_or_else(|| {
            GpuError::InvalidLaunchConfig(format!(
                "PAR-118: Flash Decoding K pointer buffer not found for layer {layer_idx}"
            ))
        })?;
        let v_ptrs_buf = self.flash_decode_v_ptrs.get(&layer_idx).ok_or_else(|| {
            GpuError::InvalidLaunchConfig(format!(
                "PAR-118: Flash Decoding V pointer buffer not found for layer {layer_idx}"
            ))
        })?;
        let partials_buf = self.flash_decode_partials.as_ref().ok_or_else(|| {
            GpuError::InvalidLaunchConfig(
                "PAR-118: Flash Decoding partials buffer not allocated".to_string(),
            )
        })?;
        // Compile and cache the chunk kernel, keyed by its shape parameters.
        let chunk_kernel = FlashDecodingChunkKernel::new(
            max_len as u32,
            head_dim as u32,
            num_heads as u32,
            num_kv_heads as u32,
            1, // M=1 for single-sequence decode
        );
        let chunk_kernel_name = chunk_kernel.name();
        let chunk_module_key =
            format!("flash_decode_chunk_{max_len}_{head_dim}_{num_heads}_{num_kv_heads}");
        if !self.modules.contains_key(&chunk_module_key) {
            let chunk_ptx = chunk_kernel.emit_ptx_for_target(&self.kernels.sm_target);
            let module = self.compile_ptx(&chunk_ptx)?;
            self.modules.insert(chunk_module_key.clone(), module);
        }
        // PAR-118: Static grid dimensions for CUDA graph compatibility.
        // max_chunks is fixed based on max_seq_len, not the actual seq_len.
        // The chunk kernel checks the actual seq_len and early-exits for empty
        // chunks (storing the sentinel max = -inf, sum = 0, which the reduce
        // kernel skips).
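        //
        // Illustration (assumes the kernel splits the KV range evenly across
        // chunks, consistent with the sentinel behaviour above): with
        // max_len = 4096 and max_chunks = 8, chunk c scans KV positions
        // [c * 512, (c + 1) * 512); at seq_len = 1000, chunks 0-1 do real work
        // while chunks 2-7 write the -inf/0 sentinel and exit.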
        let chunk_config = LaunchConfig {
            grid: (num_heads as u32, 1, max_chunks as u32),
            block: (32, 1, 1), // one warp per block
            shared_mem: 0,
        };
        let mut q_ptr = q_gpu.as_ptr();
        let mut k_ptrs_ptr = k_ptrs_buf.as_ptr();
        let mut v_ptrs_ptr = v_ptrs_buf.as_ptr();
        let mut partials_ptr = partials_buf.as_ptr();
        let mut max_chunks_val = max_chunks as u32;
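        // Descriptive note: the CUDA driver API takes kernel arguments as an
        // array of host addresses, one per parameter; the launch copies the
        // pointed-to bytes into the kernel's parameter space. Hence every
        // argument (device pointers and scalars alike) lives in a `mut` local
        // whose address is cast to *mut c_void in the launch calls below.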
        // PAR-118: seq_len pointer selection depends on capture vs. eager mode.
        let mut seq_lens_ptr = if self.is_capturing {
            // During graph capture: use seq_len_buf (already populated by
            // prepare_capture_buffers). No copy_from_host here; host-to-device
            // copies are forbidden during capture. During graph replay,
            // seq_len_buf is updated via an async copy before the graph launch,
            // so the captured kernels read fresh values.
            self.seq_len_buf
                .as_ref()
                .ok_or_else(|| {
                    GpuError::InvalidLaunchConfig(
                        "PAR-118: seq_len_buf not initialized for graph capture".to_string(),
                    )
                })?
                .as_ptr()
        } else {
            // Eager path: use the dedicated flash_decode_seq_lens_buf with an
            // explicit copy. CRITICAL: do NOT use seq_len_buf here; it may hold
            // stale values after a graph capture failure (the error 901
            // fallback, CUDA_ERROR_STREAM_CAPTURE_INVALIDATED).
            let buf = self.flash_decode_seq_lens_buf.as_mut().ok_or_else(|| {
                GpuError::InvalidLaunchConfig(
                    "PAR-118: flash_decode_seq_lens_buf not initialized".to_string(),
                )
            })?;
            buf.copy_from_host(&[seq_len])?;
            buf.as_ptr()
        };
        {
            let chunk_module = self
                .modules
                .get_mut(&chunk_module_key)
                .expect("chunk module compiled or cached above");
            // SAFETY: All pointers come from pre-allocated GPU buffers with
            // stable addresses. seq_lens_ptr points to a long-lived buffer that
            // outlives the kernel execution.
            unsafe {
                self.stream.launch_kernel(
                    chunk_module,
                    chunk_kernel_name,
                    &chunk_config,
                    &mut [
                        std::ptr::from_mut(&mut q_ptr) as *mut std::ffi::c_void,
                        std::ptr::from_mut(&mut k_ptrs_ptr) as *mut std::ffi::c_void,
                        std::ptr::from_mut(&mut v_ptrs_ptr) as *mut std::ffi::c_void,
                        std::ptr::from_mut(&mut partials_ptr) as *mut std::ffi::c_void,
                        std::ptr::from_mut(&mut seq_lens_ptr) as *mut std::ffi::c_void,
                        std::ptr::from_mut(&mut max_chunks_val) as *mut std::ffi::c_void,
                    ],
                )?;
            }
        }
        // trueno#243: Record the chunk kernel for manual graph construction.
        if self.graph_recording {
            let module = self.modules.get_mut(&chunk_module_key).expect("module exists");
            let func = module.get_function(chunk_kernel_name)?;
            self.graph_recorded_kernels.push(RecordedKernel {
                func: SendCUfunction(func),
                config: chunk_config,
                arg_data: vec![
                    q_ptr,
                    k_ptrs_ptr,
                    v_ptrs_ptr,
                    partials_ptr,
                    seq_lens_ptr,
                    max_chunks_val as u64,
                ],
            });
        }
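        // Descriptive note: arg_data widens every argument to a u64 slot so a
        // single Vec can carry device pointers and scalars alike. Assuming the
        // replay path passes the address of each slot as the kernel parameter,
        // a u32 such as max_chunks_val still reads correctly on little-endian
        // targets, where the low four bytes of the u64 slot hold the value.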
        // PAR-118-FIX: No synchronize is needed between the chunk and reduce
        // kernels. CUDA stream semantics guarantee that all blocks of kernel A
        // complete before kernel B starts on the same stream. The previous sync
        // was a CPU-GPU round-trip (~10µs) × 28 layers ≈ 280µs/token of pure
        // overhead.

        // Reduce kernel: one block per head, reducing across chunks.
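        //
        // Illustration (assuming the standard Flash Decoding partial layout of
        // a per-chunk running max m, running sum l, and a head_dim-wide partial
        // output): for each head the reduce pass merges the max_chunks partials
        // with a numerically stable softmax combine: rescale each partial by
        // exp(m_c - m_max), accumulate, then normalize by the combined sum.
        // That is only a few hundred floats of work per head, so one block per
        // head is ample here, unlike in the scan phase.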
        let reduce_kernel = FlashDecodingReduceKernel::new(
            head_dim as u32,
            num_heads as u32,
            1, // M=1
        );
        let reduce_kernel_name = reduce_kernel.name();
        let reduce_module_key = format!("flash_decode_reduce_{head_dim}_{num_heads}");
        if !self.modules.contains_key(&reduce_module_key) {
            let reduce_ptx = reduce_kernel.emit_ptx_for_target(&self.kernels.sm_target);
            let module = self.compile_ptx(&reduce_ptx)?;
            self.modules.insert(reduce_module_key.clone(), module);
        }
        let reduce_config = LaunchConfig {
            grid: (num_heads as u32, 1, 1),
            block: (32, 1, 1),
            shared_mem: 0,
        };
        let mut out_ptr = out_gpu.as_ptr();
        let mut partials_ptr2 = partials_buf.as_ptr();
        let mut max_chunks_val2 = max_chunks as u32;
        // Reuse the same seq_lens_ptr as the chunk kernel (set above; the
        // buffer is persistent).
        {
            let reduce_module = self
                .modules
                .get_mut(&reduce_module_key)
                .expect("reduce module compiled or cached above");
            // SAFETY: All pointers come from pre-allocated GPU buffers with
            // stable addresses. seq_lens_ptr uses the same long-lived buffer as
            // the chunk kernel.
            unsafe {
                self.stream.launch_kernel(
                    reduce_module,
                    reduce_kernel_name,
                    &reduce_config,
                    &mut [
                        std::ptr::from_mut(&mut partials_ptr2) as *mut std::ffi::c_void,
                        std::ptr::from_mut(&mut out_ptr) as *mut std::ffi::c_void,
                        std::ptr::from_mut(&mut seq_lens_ptr) as *mut std::ffi::c_void,
                        std::ptr::from_mut(&mut max_chunks_val2) as *mut std::ffi::c_void,
                    ],
                )?;
            }
        }
        // trueno#243: Record the reduce kernel for manual graph construction.
        if self.graph_recording {
            let module = self.modules.get_mut(&reduce_module_key).expect("module exists");
            let func = module.get_function(reduce_kernel_name)?;
            self.graph_recorded_kernels.push(RecordedKernel {
                func: SendCUfunction(func),
                config: reduce_config,
                arg_data: vec![partials_ptr2, out_ptr, seq_lens_ptr, max_chunks_val2 as u64],
            });
        }
        Ok(())
    }
}
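
// Usage sketch (illustrative only; the surrounding executor setup is assumed,
// not shown here): per decode step, the eager path is a single call per layer,
//
//     executor.flash_decoding_graphed(layer_idx, &q_gpu, &out_gpu, false, seq_len)?;
//
// while the graphed path records every layer's chunk + reduce launches once
// under capture and then replays the whole graph, refreshing seq_len_buf with
// an async copy before each replay so the captured kernels see the current
// sequence length.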