// trueno_gpu/graph/mod.rs
//! PMAT-291: Tensor Compute Graph for GPU Inference
//!
//! Inspired by ggml's compute graph pattern (transpiled via decy for reference).
//! Pure Rust, no FFI. Reduces ~430 individual cuLaunchKernel dispatches to ~15
//! tensor-level operations per decode step.
//!
//! # Design (from cross-project analysis)
//!
//! - ggml: C tensor graph with ~15 nodes, CUDA graph replay = 1 launch
//! - vLLM: PyTorch/inductor IR fusion + CUDA graphs (~80 nodes)
//! - realizr current: 430 individual kernel dispatches
//! - realizr target: ~15 tensor ops via this module + CUDA graph replay
//!
//! # Academic References
//!
//! - [Kwon et al., SOSP 2023] PagedAttention (arxiv:2309.06180)
//! - [Yu et al., OSDI 2022] Orca iteration-level scheduling
//! - [Dao, NeurIPS 2022] FlashAttention (arxiv:2205.14135)

20pub mod executor;
21
22pub use executor::{execute_graph, KernelDispatch};
23
/// Tensor operation types for decoder inference.
///
/// Each variant maps to ONE kernel dispatch. The goal is to express
/// an entire transformer layer as ~5 operations:
/// 1. RmsNorm (pre-attention)
/// 2. QKV+Attention (fused projection + attention + output projection)
/// 3. Residual add
/// 4. RmsNorm (pre-FFN)
/// 5. FFN (gate+up+swiglu+down fused)
/// 6. Residual add
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorOp {
    /// Matrix-vector multiply (Q4K dequant+GEMV or cuBLASLt GEMM)
    MulMat,
    /// Element-wise add (residual connections)
    Add,
    /// RMS normalization
    RmsNorm,
    /// Rotary position embedding
    Rope,
    /// Softmax (attention scores)
    SoftMax,
    /// Element-wise multiply (SwiGLU gate)
    Mul,
    /// SiLU activation
    Silu,
    /// Memory copy (KV cache scatter)
    Copy,
    /// No-op (input tensor, leaf node)
    None,
}

56/// A node in the compute graph.
57///
58/// Each node represents a tensor operation with device-memory pointers
59/// to its data and references to input nodes (by index).
60#[derive(Debug, Clone)]
61pub struct TensorNode {
62 /// Operation to perform
63 pub op: TensorOp,
64 /// Device pointer to output data
65 pub data_ptr: u64,
66 /// Output dimensions [rows, cols, batch, unused]
67 pub shape: [u32; 4],
68 /// Indices of input nodes in the graph (max 3: src0, src1, src2)
69 pub inputs: Vec<usize>,
70 /// Operation-specific parameters (e.g., epsilon for RmsNorm, position for RoPE)
71 pub params: OpParams,
72}
73
/// Operation-specific parameters.
///
/// All fields default to zero; fields not used by a given [`TensorOp`]
/// are simply ignored by the executor. All fields are plain numbers, so
/// the struct is `Copy` (cheap to pass by value).
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct OpParams {
    /// Weight pointer (for MulMat: quantized weights on device)
    pub weight_ptr: u64,
    /// Normalization gamma pointer (for RmsNorm)
    pub gamma_ptr: u64,
    /// Scalar parameter (epsilon for RmsNorm, etc.)
    pub scalar: f32,
    /// Integer parameter (position for RoPE, etc.)
    pub int_param: u32,
}

87/// Compute graph: topologically sorted list of tensor operations.
88///
89/// Built once per model architecture, reused every decode step.
90/// Only the data pointers and parameters change between steps.
91#[derive(Debug, Clone)]
92pub struct ComputeGraph {
93 /// Nodes in topological order (leafs first, output last)
94 pub nodes: Vec<TensorNode>,
95 /// Number of leaf nodes (inputs, no operation)
96 pub n_leafs: usize,
97}
98
99impl ComputeGraph {
100 /// Create an empty compute graph.
101 pub fn new() -> Self {
102 Self {
103 nodes: Vec::new(),
104 n_leafs: 0,
105 }
106 }
107
108 /// Add a leaf node (input tensor, no operation).
109 pub fn add_leaf(&mut self, data_ptr: u64, shape: [u32; 4]) -> usize {
110 let idx = self.nodes.len();
111 self.nodes.push(TensorNode {
112 op: TensorOp::None,
113 data_ptr,
114 shape,
115 inputs: Vec::new(),
116 params: OpParams::default(),
117 });
118 self.n_leafs += 1;
119 idx
120 }
121
122 /// Add an operation node with inputs.
123 pub fn add_op(
124 &mut self,
125 op: TensorOp,
126 data_ptr: u64,
127 shape: [u32; 4],
128 inputs: Vec<usize>,
129 params: OpParams,
130 ) -> usize {
131 let idx = self.nodes.len();
132 self.nodes.push(TensorNode {
133 op,
134 data_ptr,
135 shape,
136 inputs,
137 params,
138 });
139 idx
140 }
141
142 /// Number of operation nodes (excludes leafs).
143 pub fn n_ops(&self) -> usize {
144 self.nodes.len() - self.n_leafs
145 }
146}
147
148impl Default for ComputeGraph {
149 fn default() -> Self {
150 Self::new()
151 }
152}
153
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh graph has no nodes, no leafs, and therefore no ops.
    #[test]
    fn test_empty_graph() {
        let g = ComputeGraph::new();
        assert_eq!(g.nodes.len(), 0);
        assert_eq!(g.n_leafs, 0);
        assert_eq!(g.n_ops(), 0);
    }

    /// Build leaf -> RmsNorm -> MulMat and verify indices/counts.
    #[test]
    fn test_simple_graph() {
        let mut g = ComputeGraph::new();

        // Input tensor (leaf)
        let input = g.add_leaf(0x1000, [1536, 1, 1, 0]);

        // RmsNorm
        let normed = g.add_op(
            TensorOp::RmsNorm,
            0x2000,
            [1536, 1, 1, 0],
            vec![input],
            OpParams {
                gamma_ptr: 0x3000,
                scalar: 1e-6,
                ..Default::default()
            },
        );

        // MulMat (Q projection)
        let q = g.add_op(
            TensorOp::MulMat,
            0x4000,
            [1536, 1, 1, 0],
            vec![normed],
            OpParams {
                weight_ptr: 0x5000,
                ..Default::default()
            },
        );

        assert_eq!(g.nodes.len(), 3);
        assert_eq!(g.n_leafs, 1);
        assert_eq!(g.n_ops(), 2);
        assert_eq!(g.nodes[q].op, TensorOp::MulMat);
        assert_eq!(g.nodes[q].inputs, vec![normed]);
    }
}