Skip to main content

trueno_gpu/graph/
executor.rs

1//! PMAT-291: Graph executor -- dispatches tensor operations to CUDA kernels.
2//!
3//! Each TensorOp maps to ONE kernel launch. The executor walks the compute
4//! graph in topological order and dispatches the appropriate kernel for each
5//! node. Combined with CUDA graph capture, this reduces 430 launches to
6//! ~15 tensor-level dispatches, then 1 graph replay.
7//!
8//! # Kernel Mapping
9//!
10//! | TensorOp  | Kernel                        | Source |
11//! |-----------|-------------------------------|--------|
12//! | MulMat    | BatchedHwDp4aQ4KGemvKernel    | trueno-gpu/kernels/quantize/q4k/ |
13//! | RmsNorm   | BatchedVectorizedRmsNormKernel | trueno-gpu/kernels/layernorm/ |
14//! | Add       | BatchedResidualAddKernel       | trueno-gpu/kernels/ |
15//! | Rope      | BatchedRopeKernel              | trueno-gpu/kernels/ |
16//! | Mul       | FusedGateUpSwigluKernel        | trueno-gpu/kernels/quantize/q4k/ |
17//! | SoftMax   | (attention dispatch)           | realizr attention module |
18//! | Copy      | cuMemcpyDtoDAsync              | trueno driver |
19
20use super::{ComputeGraph, TensorNode, TensorOp};
21
/// Summary of one compute-graph execution.
///
/// Plain data: it can be cloned, compared for equality, and default-constructed
/// (zero launches, no timing information).
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct GraphExecResult {
    /// Number of kernel launches performed.
    pub n_launches: usize,
    /// Total execution time in microseconds; `None` when timing was not enabled.
    pub elapsed_us: Option<u64>,
}
30
31/// Trait for dispatching tensor operations to GPU kernels.
32///
33/// Implementors provide the actual kernel launch logic for each TensorOp.
34/// This decouples the graph execution from the specific kernel implementations,
35/// allowing realizr to plug in its own kernel dispatch (DP4A, FP8, cuBLASLt).
36pub trait KernelDispatch {
37    /// Dispatch a MulMat operation (quantized GEMV or GEMM).
38    ///
39    /// # Arguments
40    /// * `node` - The tensor node with weight_ptr in params and input data
41    /// * `input_ptr` - Device pointer to input activation
42    /// * `output_ptr` - Device pointer to output buffer
43    /// * `m` - Batch size
44    /// * `n` - Output dimension
45    /// * `k` - Input dimension
46    fn dispatch_mul_mat(
47        &mut self,
48        node: &TensorNode,
49        input_ptr: u64,
50        output_ptr: u64,
51        m: u32,
52        n: u32,
53        k: u32,
54    ) -> Result<(), crate::GpuError>;
55
56    /// Dispatch a RmsNorm operation.
57    fn dispatch_rms_norm(
58        &mut self,
59        node: &TensorNode,
60        input_ptr: u64,
61        output_ptr: u64,
62        hidden_dim: u32,
63        m: u32,
64        epsilon: f32,
65    ) -> Result<(), crate::GpuError>;
66
67    /// Dispatch an element-wise Add (residual connection).
68    fn dispatch_add(
69        &mut self,
70        a_ptr: u64,
71        b_ptr: u64,
72        output_ptr: u64,
73        n_elements: usize,
74    ) -> Result<(), crate::GpuError>;
75
76    /// Dispatch RoPE position embedding.
77    fn dispatch_rope(
78        &mut self,
79        node: &TensorNode,
80        qk_ptr: u64,
81        positions: &[u32],
82        head_dim: u32,
83        num_heads: u32,
84    ) -> Result<(), crate::GpuError>;
85
86    /// Dispatch attention (incremental or flash).
87    fn dispatch_attention(
88        &mut self,
89        node: &TensorNode,
90        q_ptr: u64,
91        k_ptr: u64,
92        v_ptr: u64,
93        output_ptr: u64,
94        m: u32,
95        layer_idx: usize,
96    ) -> Result<(), crate::GpuError>;
97
98    /// Dispatch KV cache scatter (copy).
99    fn dispatch_copy(
100        &mut self,
101        src_ptr: u64,
102        dst_ptr: u64,
103        size_bytes: usize,
104    ) -> Result<(), crate::GpuError>;
105
106    /// Dispatch element-wise multiply (SwiGLU gate).
107    fn dispatch_mul(
108        &mut self,
109        a_ptr: u64,
110        b_ptr: u64,
111        output_ptr: u64,
112        n_elements: usize,
113    ) -> Result<(), crate::GpuError>;
114
115    /// Dispatch SiLU activation.
116    fn dispatch_silu(
117        &mut self,
118        input_ptr: u64,
119        output_ptr: u64,
120        n_elements: usize,
121    ) -> Result<(), crate::GpuError>;
122}
123
124/// Execute a compute graph using the provided kernel dispatcher.
125///
126/// Walks nodes in topological order. Leaf nodes (TensorOp::None) are
127/// skipped -- they represent input tensors whose data is already on device.
128///
129/// Returns the number of kernel launches performed.
130pub fn execute_graph<D: KernelDispatch>(
131    graph: &ComputeGraph,
132    dispatcher: &mut D,
133) -> Result<usize, crate::GpuError> {
134    let mut n_launches = 0;
135
136    for node in &graph.nodes {
137        match node.op {
138            TensorOp::None => {
139                // Leaf node -- input data already on device, nothing to dispatch
140            }
141            TensorOp::MulMat => {
142                let input_idx = node.inputs.first().copied().unwrap_or(0);
143                let input_ptr = graph.nodes[input_idx].data_ptr;
144                dispatcher.dispatch_mul_mat(
145                    node,
146                    input_ptr,
147                    node.data_ptr,
148                    node.shape[2], // m (batch)
149                    node.shape[0], // n (output dim)
150                    node.shape[1], // k (input dim)
151                )?;
152                n_launches += 1;
153            }
154            TensorOp::RmsNorm => {
155                let input_idx = node.inputs.first().copied().unwrap_or(0);
156                let input_ptr = graph.nodes[input_idx].data_ptr;
157                dispatcher.dispatch_rms_norm(
158                    node,
159                    input_ptr,
160                    node.data_ptr,
161                    node.shape[0],      // hidden_dim
162                    node.shape[2],      // m (batch)
163                    node.params.scalar, // epsilon
164                )?;
165                n_launches += 1;
166            }
167            TensorOp::Add => {
168                let a_idx = node.inputs.first().copied().unwrap_or(0);
169                let b_idx = node.inputs.get(1).copied().unwrap_or(0);
170                let a_ptr = graph.nodes[a_idx].data_ptr;
171                let b_ptr = graph.nodes[b_idx].data_ptr;
172                let n_elements = (node.shape[0] * node.shape[2]) as usize;
173                dispatcher.dispatch_add(a_ptr, b_ptr, node.data_ptr, n_elements)?;
174                n_launches += 1;
175            }
176            TensorOp::Rope => {
177                let input_idx = node.inputs.first().copied().unwrap_or(0);
178                let input_ptr = graph.nodes[input_idx].data_ptr;
179                // positions passed via params.int_param as base position
180                dispatcher.dispatch_rope(
181                    node,
182                    input_ptr,
183                    &[],           // positions filled at runtime
184                    node.shape[0], // head_dim
185                    node.shape[1], // num_heads
186                )?;
187                n_launches += 1;
188            }
189            TensorOp::SoftMax => {
190                // Attention is dispatched as a compound operation
191                // The dispatcher handles Q/K/V/output internally
192                let q_idx = node.inputs.first().copied().unwrap_or(0);
193                let k_idx = node.inputs.get(1).copied().unwrap_or(0);
194                let v_idx = node.inputs.get(2).copied().unwrap_or(0);
195                dispatcher.dispatch_attention(
196                    node,
197                    graph.nodes[q_idx].data_ptr,
198                    graph.nodes[k_idx].data_ptr,
199                    graph.nodes[v_idx].data_ptr,
200                    node.data_ptr,
201                    node.shape[2],                  // m
202                    node.params.int_param as usize, // layer_idx
203                )?;
204                n_launches += 1;
205            }
206            TensorOp::Copy => {
207                let src_idx = node.inputs.first().copied().unwrap_or(0);
208                let src_ptr = graph.nodes[src_idx].data_ptr;
209                let size = (node.shape[0] * node.shape[1] * 4) as usize; // f32
210                dispatcher.dispatch_copy(src_ptr, node.data_ptr, size)?;
211                n_launches += 1;
212            }
213            TensorOp::Mul => {
214                let a_idx = node.inputs.first().copied().unwrap_or(0);
215                let b_idx = node.inputs.get(1).copied().unwrap_or(0);
216                let n_elements = (node.shape[0] * node.shape[2]) as usize;
217                dispatcher.dispatch_mul(
218                    graph.nodes[a_idx].data_ptr,
219                    graph.nodes[b_idx].data_ptr,
220                    node.data_ptr,
221                    n_elements,
222                )?;
223                n_launches += 1;
224            }
225            TensorOp::Silu => {
226                let input_idx = node.inputs.first().copied().unwrap_or(0);
227                let n_elements = (node.shape[0] * node.shape[2]) as usize;
228                dispatcher.dispatch_silu(
229                    graph.nodes[input_idx].data_ptr,
230                    node.data_ptr,
231                    n_elements,
232                )?;
233                n_launches += 1;
234            }
235        }
236    }
237
238    Ok(n_launches)
239}
240
#[cfg(test)]
mod tests {
    use super::*;

    /// Test double: performs no GPU work, only tallies how often it is called.
    struct CountingDispatcher {
        launches: usize,
    }

    impl CountingDispatcher {
        /// Register one dispatch and report success.
        fn record(&mut self) -> Result<(), crate::GpuError> {
            self.launches += 1;
            Ok(())
        }
    }

    impl KernelDispatch for CountingDispatcher {
        fn dispatch_mul_mat(
            &mut self,
            _node: &TensorNode,
            _input_ptr: u64,
            _output_ptr: u64,
            _m: u32,
            _n: u32,
            _k: u32,
        ) -> Result<(), crate::GpuError> {
            self.record()
        }

        fn dispatch_rms_norm(
            &mut self,
            _node: &TensorNode,
            _input_ptr: u64,
            _output_ptr: u64,
            _hidden_dim: u32,
            _m: u32,
            _epsilon: f32,
        ) -> Result<(), crate::GpuError> {
            self.record()
        }

        fn dispatch_add(
            &mut self,
            _a_ptr: u64,
            _b_ptr: u64,
            _output_ptr: u64,
            _n_elements: usize,
        ) -> Result<(), crate::GpuError> {
            self.record()
        }

        fn dispatch_rope(
            &mut self,
            _node: &TensorNode,
            _qk_ptr: u64,
            _positions: &[u32],
            _head_dim: u32,
            _num_heads: u32,
        ) -> Result<(), crate::GpuError> {
            self.record()
        }

        fn dispatch_attention(
            &mut self,
            _node: &TensorNode,
            _q_ptr: u64,
            _k_ptr: u64,
            _v_ptr: u64,
            _output_ptr: u64,
            _m: u32,
            _layer_idx: usize,
        ) -> Result<(), crate::GpuError> {
            self.record()
        }

        fn dispatch_copy(
            &mut self,
            _src_ptr: u64,
            _dst_ptr: u64,
            _size_bytes: usize,
        ) -> Result<(), crate::GpuError> {
            self.record()
        }

        fn dispatch_mul(
            &mut self,
            _a_ptr: u64,
            _b_ptr: u64,
            _output_ptr: u64,
            _n_elements: usize,
        ) -> Result<(), crate::GpuError> {
            self.record()
        }

        fn dispatch_silu(
            &mut self,
            _input_ptr: u64,
            _output_ptr: u64,
            _n_elements: usize,
        ) -> Result<(), crate::GpuError> {
            self.record()
        }
    }

    /// An empty graph must not trigger any dispatch.
    #[test]
    fn test_execute_empty_graph() {
        let graph = ComputeGraph::new();
        let mut counter = CountingDispatcher { launches: 0 };

        let reported = execute_graph(&graph, &mut counter).unwrap();

        assert_eq!(reported, 0);
        assert_eq!(counter.launches, 0);
    }

    /// One leaf plus six ops must yield exactly six dispatches.
    #[test]
    fn test_execute_single_layer_graph() {
        use super::super::OpParams;

        let mut graph = ComputeGraph::new();

        // Minimal transformer layer:
        //   input -> rmsnorm -> {Q, K, V projections} -> attention -> residual add
        let x = graph.add_leaf(0x1000, [1536, 1, 4, 0]);

        let x_norm = graph.add_op(
            TensorOp::RmsNorm,
            0x2000,
            [1536, 1, 4, 0],
            vec![x],
            OpParams {
                gamma_ptr: 0x3000,
                scalar: 1e-6,
                ..Default::default()
            },
        );

        let q_proj = graph.add_op(
            TensorOp::MulMat,
            0x4000,
            [1536, 1536, 4, 0],
            vec![x_norm],
            OpParams {
                weight_ptr: 0x5000,
                ..Default::default()
            },
        );
        let k_proj = graph.add_op(
            TensorOp::MulMat,
            0x6000,
            [256, 1536, 4, 0],
            vec![x_norm],
            OpParams {
                weight_ptr: 0x7000,
                ..Default::default()
            },
        );
        let v_proj = graph.add_op(
            TensorOp::MulMat,
            0x8000,
            [256, 1536, 4, 0],
            vec![x_norm],
            OpParams {
                weight_ptr: 0x9000,
                ..Default::default()
            },
        );

        let attention = graph.add_op(
            TensorOp::SoftMax,
            0xA000,
            [1536, 1, 4, 0],
            vec![q_proj, k_proj, v_proj],
            OpParams {
                int_param: 0,
                ..Default::default()
            },
        );

        let _output = graph.add_op(
            TensorOp::Add,
            0xB000,
            [1536, 1, 4, 0],
            vec![x, attention],
            OpParams::default(),
        );

        let mut counter = CountingDispatcher { launches: 0 };
        let reported = execute_graph(&graph, &mut counter).unwrap();

        // rmsnorm + 3 mul_mat + attention + add = 6 launches; the leaf is skipped.
        assert_eq!(reported, 6);
        assert_eq!(counter.launches, 6);
        assert_eq!(graph.n_ops(), 6);
    }
}