Skip to main content

trueno/brick/exec_graph/node/
mod.rs

1//! Execution Graph Node Types and Profiling Primitives
2//!
3//! This module contains all type definitions for execution path tracking:
4//!
5//! - **PAR-073**: BrickSample, BrickBottleneck - foundational profiling primitives
6//! - **PAR-200**: BrickId, BrickCategory, SyncMode - O(1) hot path brick identification
7//! - **PAR-201**: ExecutionNode, EdgeType, etc. - execution hierarchy types
8
9use std::fmt;
10
11mod execution;
12mod stats;
13
14pub use execution::{EdgeType, ExecutionEdge, ExecutionNode, ExecutionNodeId, TransferDirection};
15pub use stats::{BrickStats, CategoryStats, PtxRegistry};
16
17// ============================================================================
18// BrickProfiler: FOUNDATIONAL Real-Time Per-Brick Timing (PAR-073)
19// ============================================================================
20
/// One timing measurement recorded for a single brick.
///
/// The record is a small `Copy` value made of three `u64` counters.
/// Timing is done in pure Rust with `std::time::Instant`.
#[derive(Clone, Copy, Debug)]
pub struct BrickSample {
    /// Hash of the brick name, used as a fast lookup key.
    pub brick_id: u64,
    /// Time spent in the brick, expressed in nanoseconds.
    pub elapsed_ns: u64,
    /// Count of elements the brick processed.
    pub elements: u64,
}
32
/// Roofline-analysis verdict on what limits a brick (PMAT-451).
///
/// `Unknown` is the `Default` value, i.e. "no classification yet".
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum BrickBottleneck {
    /// No classification has been made.
    #[default]
    Unknown,
    /// Memory bandwidth is the limiting factor.
    Memory,
    /// Compute throughput is the limiting factor.
    Compute,
}
44
45impl fmt::Display for BrickBottleneck {
46    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
47        match self {
48            BrickBottleneck::Unknown => write!(f, "unknown"),
49            BrickBottleneck::Memory => write!(f, "memory"),
50            BrickBottleneck::Compute => write!(f, "compute"),
51        }
52    }
53}
54
55// ============================================================================
56// PAR-200: BrickProfiler v2 - O(1) Hot Path with BrickId Enum
57// ============================================================================
58
/// Fixed set of well-known brick kinds, usable as an O(1) key on the hot path.
///
/// PAR-200: identifying a brick by this one-byte enum avoids string
/// allocation and HashMap hashing while profiling.
///
/// Discriminants are explicit and contiguous (`0..=22`), grouped by area.
/// This enum declares no catch-all variant, so a name outside this list is
/// not representable as a `BrickId`.
/// NOTE(review): the string-keyed fallback for custom bricks presumably lives
/// in the profiler rather than here — confirm.
///
/// # Example
/// ```rust
/// use trueno::brick::BrickId;
///
/// let brick = BrickId::RmsNorm;
/// assert_eq!(brick.category(), trueno::brick::BrickCategory::Norm);
/// assert_eq!(brick.name(), "RmsNorm");
/// ```
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[repr(u8)]
pub enum BrickId {
    // ---- Normalization: discriminants 0-1 ----
    /// RMS normalization.
    RmsNorm = 0,
    /// Layer normalization.
    LayerNorm = 1,

    // ---- Attention: discriminants 2-7 ----
    /// Projection producing Q/K/V (fused or separate).
    QkvProjection = 2,
    /// Rotary position embedding.
    RopeEmbedding = 3,
    /// Attention scores (Q @ K^T).
    AttentionScore = 4,
    /// Softmax over the attention scores.
    AttentionSoftmax = 5,
    /// Attention-weighted values (scores @ V).
    AttentionOutput = 6,
    /// Projection applied after attention.
    OutputProjection = 7,

    // ---- FFN: discriminants 8-11 ----
    /// Gate projection of a gated FFN.
    GateProjection = 8,
    /// Up projection.
    UpProjection = 9,
    /// Activation function (SiLU/GELU/ReLU).
    Activation = 10,
    /// Down projection.
    DownProjection = 11,

    // ---- Other: discriminants 12-14 ----
    /// Token embedding lookup.
    Embedding = 12,
    /// Language-model head producing logits.
    LmHead = 13,
    /// Token sampling.
    Sampling = 14,

    // ---- Sparse: discriminants 15-17 (CUDA-parity-spec Phase 1) ----
    /// Sparse matrix x vector multiply (cuSPARSE parity).
    SpMV = 15,
    /// Sparse matrix x dense matrix multiply.
    SpMM = 16,
    /// Conversion between sparse formats.
    FormatConvert = 17,

    // ---- FFT: discriminants 18-19 (CUDA-parity-spec Phase 2) ----
    /// One-dimensional Fast Fourier Transform.
    FFT1D = 18,
    /// Two-dimensional Fast Fourier Transform.
    FFT2D = 19,

    // ---- Solvers: discriminants 20-22 (CUDA-parity-spec Phase 2) ----
    /// LU factorization.
    LUFactorize = 20,
    /// QR factorization.
    QRFactorize = 21,
    /// Singular Value Decomposition.
    SVDCompute = 22,
}
135
136impl BrickId {
137    /// Number of well-known brick types.
138    pub const COUNT: usize = 23;
139
140    /// All BrickId variants in order, for safe index-based iteration.
141    ///
142    /// Eliminates need for `transmute::<u8, BrickId>` in array initialization.
143    pub const ALL: [BrickId; Self::COUNT] = [
144        Self::RmsNorm,
145        Self::LayerNorm,
146        Self::QkvProjection,
147        Self::RopeEmbedding,
148        Self::AttentionScore,
149        Self::AttentionSoftmax,
150        Self::AttentionOutput,
151        Self::OutputProjection,
152        Self::GateProjection,
153        Self::UpProjection,
154        Self::Activation,
155        Self::DownProjection,
156        Self::Embedding,
157        Self::LmHead,
158        Self::Sampling,
159        Self::SpMV,
160        Self::SpMM,
161        Self::FormatConvert,
162        Self::FFT1D,
163        Self::FFT2D,
164        Self::LUFactorize,
165        Self::QRFactorize,
166        Self::SVDCompute,
167    ];
168
169    /// Validate that a raw u8 is within the BrickId range.
170    #[inline]
171    pub fn validate_index(index: usize) -> bool {
172        debug_assert!(
173            index < Self::COUNT,
174            "CB-BUDGET: brick index {} out of bounds (max {})",
175            index,
176            Self::COUNT
177        );
178        index < Self::COUNT
179    }
180
181    /// Get the category for hierarchical aggregation.
182    #[inline]
183    pub fn category(self) -> BrickCategory {
184        match self {
185            Self::RmsNorm | Self::LayerNorm => BrickCategory::Norm,
186            Self::QkvProjection
187            | Self::RopeEmbedding
188            | Self::AttentionScore
189            | Self::AttentionSoftmax
190            | Self::AttentionOutput
191            | Self::OutputProjection => BrickCategory::Attention,
192            Self::GateProjection | Self::UpProjection | Self::Activation | Self::DownProjection => {
193                BrickCategory::Ffn
194            }
195            Self::Embedding | Self::LmHead | Self::Sampling => BrickCategory::Other,
196            Self::SpMV | Self::SpMM | Self::FormatConvert => BrickCategory::Sparse,
197            Self::FFT1D | Self::FFT2D => BrickCategory::Fft,
198            Self::LUFactorize | Self::QRFactorize | Self::SVDCompute => BrickCategory::Solver,
199        }
200    }
201
202    /// Get the string name of this brick.
203    #[inline]
204    pub const fn name(self) -> &'static str {
205        match self {
206            Self::RmsNorm => "RmsNorm",
207            Self::LayerNorm => "LayerNorm",
208            Self::QkvProjection => "QkvProjection",
209            Self::RopeEmbedding => "RopeEmbedding",
210            Self::AttentionScore => "AttentionScore",
211            Self::AttentionSoftmax => "AttentionSoftmax",
212            Self::AttentionOutput => "AttentionOutput",
213            Self::OutputProjection => "OutputProjection",
214            Self::GateProjection => "GateProjection",
215            Self::UpProjection => "UpProjection",
216            Self::Activation => "Activation",
217            Self::DownProjection => "DownProjection",
218            Self::Embedding => "Embedding",
219            Self::LmHead => "LmHead",
220            Self::Sampling => "Sampling",
221            Self::SpMV => "SpMV",
222            Self::SpMM => "SpMM",
223            Self::FormatConvert => "FormatConvert",
224            Self::FFT1D => "FFT1D",
225            Self::FFT2D => "FFT2D",
226            Self::LUFactorize => "LUFactorize",
227            Self::QRFactorize => "QRFactorize",
228            Self::SVDCompute => "SVDCompute",
229        }
230    }
231
232    /// Try to parse a string into a BrickId.
233    #[allow(clippy::should_implement_trait)]
234    pub fn from_str(s: &str) -> Option<Self> {
235        match s {
236            "RmsNorm" => Some(Self::RmsNorm),
237            "LayerNorm" => Some(Self::LayerNorm),
238            "QkvProjection" | "Qkv" => Some(Self::QkvProjection),
239            "RopeEmbedding" | "Rope" | "RoPE" => Some(Self::RopeEmbedding),
240            "AttentionScore" => Some(Self::AttentionScore),
241            "AttentionSoftmax" | "Softmax" => Some(Self::AttentionSoftmax),
242            "AttentionOutput" => Some(Self::AttentionOutput),
243            "OutputProjection" | "OutProj" => Some(Self::OutputProjection),
244            "GateProjection" | "Gate" => Some(Self::GateProjection),
245            "UpProjection" | "Up" => Some(Self::UpProjection),
246            "Activation" | "SiLU" | "GELU" | "ReLU" => Some(Self::Activation),
247            "DownProjection" | "Down" => Some(Self::DownProjection),
248            "Embedding" | "Embed" => Some(Self::Embedding),
249            "LmHead" | "Head" => Some(Self::LmHead),
250            "Sampling" | "Sample" => Some(Self::Sampling),
251            "SpMV" | "spmv" => Some(Self::SpMV),
252            "SpMM" | "spmm" => Some(Self::SpMM),
253            "FormatConvert" => Some(Self::FormatConvert),
254            "FFT1D" | "fft1d" | "FFT" => Some(Self::FFT1D),
255            "FFT2D" | "fft2d" => Some(Self::FFT2D),
256            "LUFactorize" | "LU" => Some(Self::LUFactorize),
257            "QRFactorize" | "QR" => Some(Self::QRFactorize),
258            "SVDCompute" | "SVD" => Some(Self::SVDCompute),
259            _ => None,
260        }
261    }
262}
263
264impl fmt::Display for BrickId {
265    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
266        write!(f, "{}", self.name())
267    }
268}
269
/// Coarse grouping of bricks, used to roll per-brick statistics up a level.
///
/// PAR-200: related bricks share a category so performance can be read at a
/// higher level than individual bricks. Discriminants are explicit and
/// contiguous (`0..=6`); `Other` is the `Default` value.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
#[repr(u8)]
pub enum BrickCategory {
    /// Normalization layers (RmsNorm, LayerNorm).
    Norm = 0,
    /// Attention mechanism (QKV, RoPE, scores, softmax, output).
    Attention = 1,
    /// Feed-forward network (gate, up, activation, down).
    Ffn = 2,
    /// Everything else (embedding, lm_head, sampling).
    #[default]
    Other = 3,
    /// Sparse linear algebra (SpMV, SpMM, format conversion).
    Sparse = 4,
    /// FFT operations (FFT1D, FFT2D).
    Fft = 5,
    /// Dense solvers (LU, QR, SVD).
    Solver = 6,
}
292
293impl BrickCategory {
294    /// Number of categories.
295    pub const COUNT: usize = 7;
296
297    /// All BrickCategory variants in order, for safe index-based iteration.
298    pub const ALL: [BrickCategory; Self::COUNT] = [
299        Self::Norm,
300        Self::Attention,
301        Self::Ffn,
302        Self::Other,
303        Self::Sparse,
304        Self::Fft,
305        Self::Solver,
306    ];
307
308    /// Get the string name of this category.
309    #[inline]
310    pub const fn name(self) -> &'static str {
311        match self {
312            Self::Norm => "Norm",
313            Self::Attention => "Attention",
314            Self::Ffn => "FFN",
315            Self::Other => "Other",
316            Self::Sparse => "Sparse",
317            Self::Fft => "FFT",
318            Self::Solver => "Solver",
319        }
320    }
321}
322
323impl fmt::Display for BrickCategory {
324    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
325        write!(f, "{}", self.name())
326    }
327}
328
/// How often GPU work is synchronized while profiling.
///
/// PAR-200: each mode trades measurement accuracy against profiling
/// overhead. `Deferred` is the `Default` value.
///
/// # Performance Characteristics
///
/// | Mode | Overhead | Accuracy | Use Case |
/// |------|----------|----------|----------|
/// | `Immediate` | ~200% | Exact per-kernel | Debugging |
/// | `PerLayer` | ~20% | Per-layer exact | Development |
/// | `Deferred` | ~5% | Approximate | Production |
/// | `None` | 0% | N/A | Disabled |
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum SyncMode {
    /// Synchronize after every kernel: accurate but slow.
    /// Suited to debugging and detailed optimization work.
    Immediate,
    /// Synchronize once per transformer layer.
    /// A reasonable middle ground during development.
    PerLayer,
    /// Synchronize once per forward pass: fast but approximate.
    /// Suited to profiling in production.
    #[default]
    Deferred,
    /// Never synchronize (profiling disabled, or CPU-only execution).
    None,
}
355}