1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
// trueno#243: Manual CUDA graph construction for decode loop.
//
// Bypasses stream capture via cuStreamBeginCapture (which fails with error
// code 901 on driver 570.207) by building the graph explicitly with
// cuGraphAddKernelNode.
//
// Protocol:
// 1. First decode token: set graph_recording=true, run eager forward
// (kernels execute AND get recorded)
// 2. Build CudaGraph from recorded kernels
// 3. Subsequent tokens: replay graph (single cuGraphLaunch)
// 4. Before replay: update position_buf + seq_len_buf via async memcpy
#![allow(clippy::wildcard_imports)]
use super::super::*;
impl CudaExecutor {
/// Start recording kernel launches for manual graph construction.
/// Kernels still execute (eager) but are also recorded.
pub(crate) fn begin_graph_recording(&mut self) {
self.graph_recorded_kernels.clear();
self.graph_recording = true;
}
/// Stop recording and build a CudaGraph from recorded kernels.
///
/// Returns the number of kernels captured.
pub(crate) fn end_graph_recording(&mut self) -> Result<usize, GpuError> {
self.graph_recording = false;
let num_kernels = self.graph_recorded_kernels.len();
if num_kernels == 0 {
return Ok(0);
}
// Build graph from recorded kernels (linear dependency chain)
let mut graph = trueno_gpu::driver::CudaGraph::new()?;
let mut prev_node = None;
for record in &self.graph_recorded_kernels {
// Reconstruct arg pointers from stored u64 values
let mut arg_storage: Vec<u64> = record.arg_data.clone();
let mut arg_ptrs: Vec<*mut std::ffi::c_void> = arg_storage
.iter_mut()
.map(|v| std::ptr::from_mut(v) as *mut std::ffi::c_void)
.collect();
let deps: Vec<trueno_gpu::driver::sys::CUgraphNode> = match prev_node {
Some(node) => vec![node],
None => vec![],
};
let node = graph.add_kernel_node(
record.func.0,
(
record.config.grid.0,
record.config.grid.1,
record.config.grid.2,
),
(
record.config.block.0,
record.config.block.1,
record.config.block.2,
),
record.config.shared_mem,
&mut arg_ptrs,
&deps,
)?;
prev_node = Some(node);
}
// Instantiate
let graph_exec = graph.instantiate()?;
self.decode_graph = Some(graph_exec);
self.decode_token_count = 1;
// realizr#198 DEBUG: Log recorded arg pointers for first AND last kernel
// First kernel = RMSNorm (reads graph_input_buf)
// Last kernel = LM head GEMV or bias add (writes logits_buf)
let first_args = self.graph_recorded_kernels.first().map(|k| {
k.arg_data
.iter()
.map(|a| format!("{:#x}", a))
.collect::<Vec<_>>()
});
let last_args = self.graph_recorded_kernels.last().map(|k| {
k.arg_data
.iter()
.map(|a| format!("{:#x}", a))
.collect::<Vec<_>>()
});
// Also log logits_buf current pointer for comparison
let logits_ptr = self
.workspace
.logits_buf
.as_ref()
.map(|b| b.as_ptr())
.unwrap_or(0);
eprintln!(
"[trueno#243] ✓ Manual graph: {} kernels. first_args={:?}, last_args={:?}, current_logits_buf={:#x}",
num_kernels, first_args, last_args, logits_ptr
);
Ok(num_kernels)
}
/// Record a kernel launch for manual graph construction.
/// Called by kernel dispatch functions when graph_recording is true.
pub(crate) fn record_kernel_launch(
&mut self,
func: trueno_gpu::driver::sys::CUfunction,
config: &LaunchConfig,
args: &[u64],
) {
if self.graph_recording {
self.graph_recorded_kernels.push(RecordedKernel {
func: SendCUfunction(func),
config: config.clone(),
arg_data: args.to_vec(),
});
}
}
}