1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
//! PMAT-291: Graph-based transformer layer decode.
//!
//! Alternative to `transformer_layer_batched()` that uses the tensor graph
//! executor. Dispatches 13 tensor-level operations instead of ~15 individual
//! kernel calls, providing a cleaner abstraction and foundation for CUDA
//! graph capture with fewer nodes.
//!
//! Enabled by default; set the `GRAPH_DISPATCH=0` environment variable to opt out.
use trueno_gpu::driver::GpuBuffer;
use trueno_gpu::graph::execute_graph;
use trueno_gpu::GpuError;
use crate::cuda::executor::graph_builder::build_layer_graph;
use crate::cuda::executor::CudaExecutor;
use crate::cuda::types::ValidatedLayerWeights;
impl CudaExecutor {
    /// PMAT-291: Graph-based batched transformer layer decode.
    ///
    /// Equivalent to `transformer_layer_batched()` but routes through the
    /// tensor graph executor. Each layer is expressed as a 14-node compute
    /// graph (1 leaf + 13 ops) that dispatches the same kernels.
    ///
    /// The graph path skips the fused DP4A QKV optimization (uses individual
    /// projections) but gains a simpler dispatch loop suitable for CUDA graph
    /// capture with ~14 nodes per layer instead of ~15+ individual kernel calls.
    ///
    /// # Arguments
    ///
    /// * `input` - device buffer whose pointer is used as the graph's input.
    /// * `layer_idx` - index of the layer, forwarded to the graph builder.
    /// * `layer_weights` - validated weights for this layer.
    /// * `m` - batch size forwarded to workspace validation and the graph builder.
    /// * `positions` - per-call positions, exposed to the graph dispatch through
    ///   `graph_dispatch_positions` for the duration of the execution.
    /// * `hidden_dim`, `intermediate_dim`, `epsilon` - forwarded unchanged to
    ///   the graph builder.
    ///
    /// # Errors
    ///
    /// Returns an error when `validate_batched_workspace` rejects the request,
    /// when any required workspace buffer is not allocated
    /// (`GpuError::InvalidLaunchConfig` naming the missing buffer), or when
    /// `execute_graph` fails. `graph_dispatch_positions` is cleared on both
    /// the success and the failure path of graph execution.
    // A flat parameter list is kept so existing callers stay source-compatible.
    #[allow(clippy::too_many_arguments)]
    pub fn transformer_layer_batched_graph(
        &mut self,
        input: &GpuBuffer<f32>,
        layer_idx: usize,
        layer_weights: &ValidatedLayerWeights,
        m: u32,
        positions: &[u32],
        hidden_dim: u32,
        intermediate_dim: u32,
        epsilon: f32,
    ) -> Result<(), GpuError> {
        self.validate_batched_workspace(m, positions.len())?;

        let q_dim = (self.kv_num_heads * self.kv_head_dim) as u32;
        let kv_dim = (self.kv_num_kv_heads * self.kv_head_dim) as u32;

        // Builds the error reported when a workspace buffer has not been allocated.
        let missing = |name: &str| GpuError::InvalidLaunchConfig(format!("PMAT-291: {name}"));

        // Extract the raw pointers up front; the shared borrow of the workspace
        // ends before the executor is borrowed mutably for graph execution.
        let ws = &self.workspace;
        let hidden_buf1_ptr = ws
            .hidden_buf1
            .as_ref()
            .ok_or_else(|| missing("hidden_buf1"))?
            .as_ptr();
        let hidden_buf2_ptr = ws
            .hidden_buf2
            .as_ref()
            .ok_or_else(|| missing("hidden_buf2"))?
            .as_ptr();
        let input_staging_ptr = ws
            .input_staging
            .as_ref()
            .ok_or_else(|| missing("input_staging"))?
            .as_ptr();
        let q_buf_ptr = ws
            .q_buf
            .as_ref()
            .ok_or_else(|| missing("q_buf"))?
            .as_ptr();
        let k_buf_ptr = ws
            .k_buf
            .as_ref()
            .ok_or_else(|| missing("k_buf"))?
            .as_ptr();
        let v_buf_ptr = ws
            .v_buf
            .as_ref()
            .ok_or_else(|| missing("v_buf"))?
            .as_ptr();
        let attn_out_ptr = ws
            .attn_out_buf
            .as_ref()
            .ok_or_else(|| missing("attn_out"))?
            .as_ptr();
        let ffn_gate_ptr = ws
            .ffn_gate_buf
            .as_ref()
            .ok_or_else(|| missing("ffn_gate"))?
            .as_ptr();
        let ffn_up_ptr = ws
            .ffn_up_buf
            .as_ref()
            .ok_or_else(|| missing("ffn_up"))?
            .as_ptr();
        let ffn_act_ptr = ws
            .ffn_act_buf
            .as_ref()
            .ok_or_else(|| missing("ffn_act"))?
            .as_ptr();

        // Build the layer graph. The returned output index is not needed here.
        let (graph, _output_idx) = build_layer_graph(
            layer_weights,
            input.as_ptr(),
            hidden_dim,
            intermediate_dim,
            q_dim,
            kv_dim,
            m,
            epsilon,
            layer_idx,
            hidden_buf1_ptr,
            hidden_buf2_ptr,
            q_buf_ptr,
            k_buf_ptr,
            v_buf_ptr,
            attn_out_ptr,
            ffn_gate_ptr,
            ffn_up_ptr,
            ffn_act_ptr,
            input_staging_ptr,
        );

        // Set positions side-channel for dispatch_attention. Refilling the
        // existing vector reuses its allocation across layers.
        self.graph_dispatch_positions.clear();
        self.graph_dispatch_positions.extend_from_slice(positions);

        // Execute the graph, but defer error propagation until the side-channel
        // has been reset so a failed launch never leaves stale positions behind.
        let launch_result = execute_graph(&graph, self);
        self.graph_dispatch_positions.clear();
        launch_result?;

        // Output is in hidden_buf2 (same as transformer_layer_batched)
        Ok(())
    }

    /// Check if graph dispatch is enabled (default: ON, opt-out with GRAPH_DISPATCH=0).
    ///
    /// Only the exact value `"0"` disables the graph path; an unset variable or
    /// any other value leaves it enabled. The answer is cached in a process-wide
    /// `OnceLock` after the first check to avoid repeated env var lookups, so
    /// later changes to the environment are not observed.
    pub(crate) fn use_graph_dispatch(&self) -> bool {
        static ENABLED: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
        *ENABLED.get_or_init(|| std::env::var("GRAPH_DISPATCH").as_deref() != Ok("0"))
    }
}