1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
//! PMAT-291: Transformer layer graph builder for Qwen2.5 architecture.
//!
//! Builds a ComputeGraph representing one transformer layer's decode step.
//! The graph has 14 nodes (1 leaf + 13 ops) vs the current ~15 individual
//! kernel launches per layer. Only the 13 op nodes launch kernels (the leaf
//! just names the input buffer), so across 28 layers total dispatches drop
//! from ~430 to 364 (13 x 28), and then, with CUDA graph capture, to a single
//! graph replay.
use trueno_gpu::graph::{ComputeGraph, OpParams, TensorOp};
use crate::cuda::types::ValidatedLayerWeights;
/// Build a compute graph for one transformer decoder layer (Qwen2.5 architecture).
///
/// The graph represents:
/// ```text
/// Phase 1 (Attention):
/// input -> rmsnorm -> Q_proj -> K_proj -> V_proj -> attention(+RoPE+KV) -> O_proj -> residual_1
/// Phase 2 (FFN):
/// residual_1 -> rmsnorm -> gate_proj -> up_proj -> swiglu(gate,up) -> down_proj -> residual_2
/// ```
///
/// Returns the graph and the index of the final output node.
#[allow(clippy::too_many_arguments)]
pub fn build_layer_graph(
layer_weights: &ValidatedLayerWeights,
input_ptr: u64,
hidden_dim: u32,
intermediate_dim: u32,
q_dim: u32,
kv_dim: u32,
m: u32,
epsilon: f32,
layer_idx: usize,
// Workspace buffer pointers (pre-allocated by init_batched_workspace)
hidden_buf1_ptr: u64,
hidden_buf2_ptr: u64,
q_buf_ptr: u64,
k_buf_ptr: u64,
v_buf_ptr: u64,
attn_out_ptr: u64,
ffn_gate_ptr: u64,
ffn_up_ptr: u64,
ffn_act_ptr: u64,
input_staging_ptr: u64,
) -> (ComputeGraph, usize) {
let mut g = ComputeGraph::new();
// ========== Leaf: input hidden state ==========
let input = g.add_leaf(input_ptr, [hidden_dim, 1, m, 0]);
// ========== Phase 1: Attention ==========
// 1. Pre-attention RMSNorm
let normed_attn = g.add_op(
TensorOp::RmsNorm,
hidden_buf1_ptr,
[hidden_dim, 1, m, 0],
vec![input],
OpParams {
gamma_ptr: layer_weights.attn_norm_ptr,
scalar: epsilon,
..Default::default()
},
);
// 2. Q projection (Q4K or Q6K GEMV/GEMM)
let q = g.add_op(
TensorOp::MulMat,
q_buf_ptr,
[q_dim, hidden_dim, m, 0],
vec![normed_attn],
OpParams {
weight_ptr: layer_weights.attn_q_ptr,
..Default::default()
},
);
// 3. K projection
let k = g.add_op(
TensorOp::MulMat,
k_buf_ptr,
[kv_dim, hidden_dim, m, 0],
vec![normed_attn],
OpParams {
weight_ptr: layer_weights.attn_k_ptr,
..Default::default()
},
);
// 4. V projection
let v = g.add_op(
TensorOp::MulMat,
v_buf_ptr,
[kv_dim, hidden_dim, m, 0],
vec![normed_attn],
OpParams {
weight_ptr: layer_weights.attn_v_ptr,
..Default::default()
},
);
// 5. Attention (compound: RoPE + KV scatter + incremental attention)
// The dispatcher handles all sub-operations internally via dispatch_attention.
// Positions are passed through CudaExecutor::graph_dispatch_positions side-channel.
let attn_out = g.add_op(
TensorOp::SoftMax,
attn_out_ptr,
[q_dim, 1, m, 0],
vec![q, k, v],
OpParams {
int_param: layer_idx as u32,
..Default::default()
},
);
// 6. Output projection
let o_proj = g.add_op(
TensorOp::MulMat,
hidden_buf1_ptr,
[hidden_dim, q_dim, m, 0],
vec![attn_out],
OpParams {
weight_ptr: layer_weights.attn_output_ptr,
..Default::default()
},
);
// 7. First residual: input + o_proj -> input_staging
let residual_1 = g.add_op(
TensorOp::Add,
input_staging_ptr,
[hidden_dim, 1, m, 0],
vec![input, o_proj],
OpParams::default(),
);
// ========== Phase 2: FFN ==========
// 8. Pre-FFN RMSNorm
let normed_ffn = g.add_op(
TensorOp::RmsNorm,
hidden_buf1_ptr,
[hidden_dim, 1, m, 0],
vec![residual_1],
OpParams {
gamma_ptr: layer_weights.ffn_norm_ptr,
scalar: epsilon,
..Default::default()
},
);
// 9. Gate projection
let gate = g.add_op(
TensorOp::MulMat,
ffn_gate_ptr,
[intermediate_dim, hidden_dim, m, 0],
vec![normed_ffn],
OpParams {
weight_ptr: layer_weights.ffn_gate_ptr,
..Default::default()
},
);
// 10. Up projection
let up = g.add_op(
TensorOp::MulMat,
ffn_up_ptr,
[intermediate_dim, hidden_dim, m, 0],
vec![normed_ffn],
OpParams {
weight_ptr: layer_weights.ffn_up_ptr,
..Default::default()
},
);
// 11. SwiGLU: gate * silu(up) — element-wise, dispatched as Mul(gate, up)
// The CudaExecutor dispatch_mul maps this to batched_swiglu_into.
let ffn_act = g.add_op(
TensorOp::Mul,
ffn_act_ptr,
[intermediate_dim, 1, m, 0],
vec![gate, up],
OpParams::default(),
);
// 12. Down projection
let down = g.add_op(
TensorOp::MulMat,
hidden_buf1_ptr,
[hidden_dim, intermediate_dim, m, 0],
vec![ffn_act],
OpParams {
weight_ptr: layer_weights.ffn_down_ptr,
..Default::default()
},
);
// 13. Second residual: input_staging + down -> hidden_buf2
let residual_2 = g.add_op(
TensorOp::Add,
hidden_buf2_ptr,
[hidden_dim, 1, m, 0],
vec![residual_1, down],
OpParams::default(),
);
(g, residual_2)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cuda::types::{IndexedLayerWeights, WeightQuantType};
    use trueno_gpu::graph::{execute_graph, KernelDispatch, TensorNode};
    use trueno_gpu::GpuError;

    // Layer dimensions shared by every test in this module.
    const HIDDEN_DIM: u32 = 1536;
    const INTERMEDIATE_DIM: u32 = 8960;
    const Q_DIM: u32 = 1536;
    const KV_DIM: u32 = 256;
    const BATCH_M: u32 = 4;
    const EPSILON: f32 = 1e-6;

    /// Fake, non-dereferenced weight pointers; biases and QK-norm are absent (0).
    fn mock_weights() -> ValidatedLayerWeights {
        ValidatedLayerWeights::new_unchecked(IndexedLayerWeights {
            // Attention projections
            attn_q_ptr: 0x10000,
            attn_q_len: 1024,
            attn_q_qtype: WeightQuantType::Q4K,
            attn_k_ptr: 0x20000,
            attn_k_len: 512,
            attn_k_qtype: WeightQuantType::Q4K,
            attn_v_ptr: 0x30000,
            attn_v_len: 512,
            attn_v_qtype: WeightQuantType::Q6K,
            attn_output_ptr: 0x40000,
            attn_output_len: 1024,
            attn_output_qtype: WeightQuantType::Q4K,
            // FFN projections
            ffn_gate_ptr: 0x50000,
            ffn_gate_len: 2048,
            ffn_gate_qtype: WeightQuantType::Q4K,
            ffn_up_ptr: 0x60000,
            ffn_up_len: 2048,
            ffn_up_qtype: WeightQuantType::Q4K,
            ffn_down_ptr: 0x70000,
            ffn_down_len: 2048,
            ffn_down_qtype: WeightQuantType::Q4K,
            // Norm gammas
            attn_norm_ptr: 0x80000,
            attn_norm_len: 256,
            ffn_norm_ptr: 0x90000,
            ffn_norm_len: 256,
            // Optional tensors, all absent
            attn_q_bias_ptr: 0,
            attn_q_bias_len: 0,
            attn_k_bias_ptr: 0,
            attn_k_bias_len: 0,
            attn_v_bias_ptr: 0,
            attn_v_bias_len: 0,
            attn_q_norm_ptr: 0,
            attn_q_norm_len: 0,
            attn_k_norm_ptr: 0,
            attn_k_norm_len: 0,
        })
    }

    /// Build layer 0's graph with fixed fake device pointers for every buffer.
    fn build_mock_graph() -> (ComputeGraph, usize) {
        let weights = mock_weights();
        build_layer_graph(
            &weights,
            0xA0000, // input
            HIDDEN_DIM,
            INTERMEDIATE_DIM,
            Q_DIM,
            KV_DIM,
            BATCH_M,
            EPSILON,
            0,        // layer_idx
            0xB0000,  // hidden_buf1
            0xC0000,  // hidden_buf2
            0xD0000,  // q_buf
            0xE0000,  // k_buf
            0xF0000,  // v_buf
            0x100000, // attn_out
            0x110000, // ffn_gate
            0x120000, // ffn_up
            0x130000, // ffn_act
            0x140000, // input_staging
        )
    }

    /// `KernelDispatch` stub that launches nothing and only tallies calls.
    #[derive(Default)]
    struct DispatchCounter {
        launches: usize,
    }

    impl DispatchCounter {
        fn record(&mut self) -> Result<(), GpuError> {
            self.launches += 1;
            Ok(())
        }
    }

    impl KernelDispatch for DispatchCounter {
        fn dispatch_mul_mat(
            &mut self,
            _: &TensorNode,
            _: u64,
            _: u64,
            _: u32,
            _: u32,
            _: u32,
        ) -> Result<(), GpuError> {
            self.record()
        }

        fn dispatch_rms_norm(
            &mut self,
            _: &TensorNode,
            _: u64,
            _: u64,
            _: u32,
            _: u32,
            _: f32,
        ) -> Result<(), GpuError> {
            self.record()
        }

        fn dispatch_add(&mut self, _: u64, _: u64, _: u64, _: usize) -> Result<(), GpuError> {
            self.record()
        }

        fn dispatch_rope(
            &mut self,
            _: &TensorNode,
            _: u64,
            _: &[u32],
            _: u32,
            _: u32,
        ) -> Result<(), GpuError> {
            self.record()
        }

        fn dispatch_attention(
            &mut self,
            _: &TensorNode,
            _: u64,
            _: u64,
            _: u64,
            _: u64,
            _: u32,
            _: usize,
        ) -> Result<(), GpuError> {
            self.record()
        }

        fn dispatch_copy(&mut self, _: u64, _: u64, _: usize) -> Result<(), GpuError> {
            self.record()
        }

        fn dispatch_mul(&mut self, _: u64, _: u64, _: u64, _: usize) -> Result<(), GpuError> {
            self.record()
        }

        fn dispatch_silu(&mut self, _: u64, _: u64, _: usize) -> Result<(), GpuError> {
            self.record()
        }
    }

    #[test]
    fn test_layer_graph_node_count() {
        let (graph, output_idx) = build_mock_graph();
        // One leaf (the input hidden state) plus 13 ops.
        assert_eq!(graph.nodes.len(), 14);
        assert_eq!(graph.n_leafs, 1);
        assert_eq!(graph.n_ops(), 13);
        // The second residual is the last node added.
        assert_eq!(output_idx, 13);
    }

    #[test]
    fn test_layer_graph_execution_count() {
        let (graph, _) = build_mock_graph();
        let mut counter = DispatchCounter::default();
        let executed = execute_graph(&graph, &mut counter).unwrap();
        // 2 rmsnorm + 7 mul_mat + 1 attention + 2 add + 1 swiglu = 13 launches.
        assert_eq!(executed, 13);
        assert_eq!(counter.launches, 13);
    }
}