// trueno_gpu/graph/mod.rs
1//! PMAT-291: Tensor Compute Graph for GPU Inference
2//!
3//! Inspired by ggml's compute graph pattern (transpiled via decy for reference).
4//! Pure Rust, no FFI. Reduces ~430 individual cuLaunchKernel dispatches to ~15
5//! tensor-level operations per decode step.
6//!
7//! # Design (from cross-project analysis)
8//!
9//! - ggml: C tensor graph with ~15 nodes, CUDA graph replay = 1 launch
10//! - vLLM: PyTorch/inductor IR fusion + CUDA graphs (~80 nodes)
11//! - realizr current: 430 individual kernel dispatches
12//! - realizr target: ~15 tensor ops via this module + CUDA graph replay
13//!
14//! # Academic References
15//!
16//! - [Kwon et al., SOSP 2023] PagedAttention (arxiv:2309.06180)
17//! - [Yu et al., OSDI 2022] Orca iteration-level scheduling
18//! - [Dao, NeurIPS 2022] FlashAttention (arxiv:2205.14135)
19
20pub mod executor;
21
22pub use executor::{execute_graph, KernelDispatch};
23
/// Tensor operation types for decoder inference.
///
/// Each variant maps to ONE kernel dispatch. The goal is to express
/// an entire transformer layer as ~6 operations:
/// 1. RmsNorm (pre-attention)
/// 2. QKV+Attention (fused projection + attention + output projection)
/// 3. Residual add
/// 4. RmsNorm (pre-FFN)
/// 5. FFN (gate+up+swiglu+down fused)
/// 6. Residual add
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorOp {
    /// Matrix-vector multiply (Q4K dequant+GEMV or cuBLASLt GEMM)
    MulMat,
    /// Element-wise add (residual connections)
    Add,
    /// RMS normalization
    RmsNorm,
    /// Rotary position embedding
    Rope,
    /// Softmax (attention scores)
    SoftMax,
    /// Element-wise multiply (SwiGLU gate)
    Mul,
    /// SiLU activation
    Silu,
    /// Memory copy (KV cache scatter)
    Copy,
    /// No-op: marks an input tensor (leaf node) that carries data but
    /// dispatches no kernel
    None,
}
55
/// A node in the compute graph.
///
/// Each node represents a tensor operation with device-memory pointers
/// to its data and references to input nodes (by index into
/// [`ComputeGraph::nodes`]).
#[derive(Debug, Clone)]
pub struct TensorNode {
    /// Operation to perform (`TensorOp::None` for leaf/input nodes)
    pub op: TensorOp,
    /// Device pointer to output data (raw address; u64 so it is
    /// representable regardless of host pointer width)
    pub data_ptr: u64,
    /// Output dimensions [rows, cols, batch, unused]
    pub shape: [u32; 4],
    /// Indices of input nodes in the graph (max 3: src0, src1, src2)
    pub inputs: Vec<usize>,
    /// Operation-specific parameters (e.g., epsilon for RmsNorm, position for RoPE)
    pub params: OpParams,
}
73
/// Operation-specific parameters.
///
/// A flat bag of optional knobs shared by all [`TensorOp`] variants;
/// each op reads only the fields it needs and ignores the rest
/// (defaults are zero via `Default`).
#[derive(Debug, Clone, Default)]
pub struct OpParams {
    /// Weight pointer (for MulMat: quantized weights on device)
    pub weight_ptr: u64,
    /// Normalization gamma pointer (for RmsNorm)
    pub gamma_ptr: u64,
    /// Scalar parameter (epsilon for RmsNorm, etc.)
    pub scalar: f32,
    /// Integer parameter (position for RoPE, etc.)
    pub int_param: u32,
}
86
/// Compute graph: topologically sorted list of tensor operations.
///
/// Built once per model architecture, reused every decode step.
/// Only the data pointers and parameters change between steps.
///
/// Invariant: `n_leafs <= nodes.len()`, and every node's `inputs`
/// reference earlier indices (topological order by construction).
#[derive(Debug, Clone)]
pub struct ComputeGraph {
    /// Nodes in topological order (leafs first, output last)
    pub nodes: Vec<TensorNode>,
    /// Number of leaf nodes (inputs, no operation)
    pub n_leafs: usize,
}
98
99impl ComputeGraph {
100    /// Create an empty compute graph.
101    pub fn new() -> Self {
102        Self {
103            nodes: Vec::new(),
104            n_leafs: 0,
105        }
106    }
107
108    /// Add a leaf node (input tensor, no operation).
109    pub fn add_leaf(&mut self, data_ptr: u64, shape: [u32; 4]) -> usize {
110        let idx = self.nodes.len();
111        self.nodes.push(TensorNode {
112            op: TensorOp::None,
113            data_ptr,
114            shape,
115            inputs: Vec::new(),
116            params: OpParams::default(),
117        });
118        self.n_leafs += 1;
119        idx
120    }
121
122    /// Add an operation node with inputs.
123    pub fn add_op(
124        &mut self,
125        op: TensorOp,
126        data_ptr: u64,
127        shape: [u32; 4],
128        inputs: Vec<usize>,
129        params: OpParams,
130    ) -> usize {
131        let idx = self.nodes.len();
132        self.nodes.push(TensorNode {
133            op,
134            data_ptr,
135            shape,
136            inputs,
137            params,
138        });
139        idx
140    }
141
142    /// Number of operation nodes (excludes leafs).
143    pub fn n_ops(&self) -> usize {
144        self.nodes.len() - self.n_leafs
145    }
146}
147
148impl Default for ComputeGraph {
149    fn default() -> Self {
150        Self::new()
151    }
152}
153
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_empty_graph() {
        let g = ComputeGraph::new();
        assert_eq!(g.nodes.len(), 0);
        assert_eq!(g.n_leafs, 0);
        assert_eq!(g.n_ops(), 0);
    }

    #[test]
    fn test_default_matches_new() {
        // `Default` must agree with `new()`: an empty graph.
        let g = ComputeGraph::default();
        assert_eq!(g.nodes.len(), 0);
        assert_eq!(g.n_leafs, 0);
        assert_eq!(g.n_ops(), 0);
    }

    #[test]
    fn test_simple_graph() {
        let mut g = ComputeGraph::new();

        // Input tensor (leaf)
        let input = g.add_leaf(0x1000, [1536, 1, 1, 0]);

        // RmsNorm
        let normed = g.add_op(
            TensorOp::RmsNorm,
            0x2000,
            [1536, 1, 1, 0],
            vec![input],
            OpParams {
                gamma_ptr: 0x3000,
                scalar: 1e-6,
                ..Default::default()
            },
        );

        // MulMat (Q projection)
        let q = g.add_op(
            TensorOp::MulMat,
            0x4000,
            [1536, 1, 1, 0],
            vec![normed],
            OpParams {
                weight_ptr: 0x5000,
                ..Default::default()
            },
        );

        assert_eq!(g.nodes.len(), 3);
        assert_eq!(g.n_leafs, 1);
        assert_eq!(g.n_ops(), 2);

        // Leaf node properties: None op, stored pointer/shape, no inputs.
        assert_eq!(g.nodes[input].op, TensorOp::None);
        assert_eq!(g.nodes[input].data_ptr, 0x1000);
        assert_eq!(g.nodes[input].shape, [1536, 1, 1, 0]);
        assert!(g.nodes[input].inputs.is_empty());

        // Op nodes wire up to their inputs and keep their params.
        assert_eq!(g.nodes[normed].op, TensorOp::RmsNorm);
        assert_eq!(g.nodes[normed].inputs, vec![input]);
        assert_eq!(g.nodes[normed].params.gamma_ptr, 0x3000);
        assert_eq!(g.nodes[normed].params.scalar, 1e-6);

        assert_eq!(g.nodes[q].op, TensorOp::MulMat);
        assert_eq!(g.nodes[q].inputs, vec![normed]);
        assert_eq!(g.nodes[q].params.weight_ptr, 0x5000);
    }
}
204}