trueno/brick/exec_graph/node/
mod.rs1use std::fmt;
10
11mod execution;
12mod stats;
13
14pub use execution::{EdgeType, ExecutionEdge, ExecutionNode, ExecutionNodeId, TransferDirection};
15pub use stats::{BrickStats, CategoryStats, PtxRegistry};
16
/// A single timing sample captured for one brick execution.
#[derive(Clone, Copy, Debug)]
pub struct BrickSample {
    /// Raw identifier of the sampled brick.
    /// NOTE(review): presumably a `BrickId` discriminant widened to `u64` — confirm with callers.
    pub brick_id: u64,
    /// Elapsed time, in nanoseconds (per the `_ns` naming).
    pub elapsed_ns: u64,
    /// Element count associated with the sample (presumably elements processed — confirm).
    pub elements: u64,
}
32
/// Which resource limits a brick's performance; `Unknown` is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum BrickBottleneck {
    /// Bottleneck not determined. This is the `Default` variant.
    #[default]
    Unknown,
    /// Memory-limited (displayed as `memory`).
    Memory,
    /// Compute-limited (displayed as `compute`).
    Compute,
}

impl fmt::Display for BrickBottleneck {
    /// Writes the lowercase label: `unknown`, `memory` or `compute`.
    ///
    /// Routed through [`fmt::Formatter::pad`] so width, fill, alignment and
    /// precision specifiers (e.g. `{:>8}`) are honored, like they are for `&str`.
    /// Without a spec the output is exactly the bare label.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Unknown => "unknown",
            Self::Memory => "memory",
            Self::Compute => "compute",
        };
        f.pad(label)
    }
}
54
/// Identifies one kind of compute brick.
///
/// The `u8` discriminants are explicit and contiguous (`0..=22`).
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
#[repr(u8)]
pub enum BrickId {
    // --- normalization ---
    /// RMS normalization.
    RmsNorm = 0,
    /// Layer normalization.
    LayerNorm = 1,

    // --- attention ---
    /// Fused Q/K/V projection.
    QkvProjection = 2,
    /// Rotary position embedding.
    RopeEmbedding = 3,
    /// Attention score computation.
    AttentionScore = 4,
    /// Softmax over attention scores.
    AttentionSoftmax = 5,
    /// Attention-weighted output.
    AttentionOutput = 6,
    /// Attention output projection.
    OutputProjection = 7,

    // --- feed-forward ---
    /// FFN gate projection.
    GateProjection = 8,
    /// FFN up projection.
    UpProjection = 9,
    /// FFN activation function.
    Activation = 10,
    /// FFN down projection.
    DownProjection = 11,

    // --- model input / output ---
    /// Token embedding lookup.
    Embedding = 12,
    /// Language-model head.
    LmHead = 13,
    /// Token sampling.
    Sampling = 14,

    // --- sparse ---
    /// Sparse matrix-vector product.
    SpMV = 15,
    /// Sparse matrix-matrix product.
    SpMM = 16,
    /// Sparse format conversion.
    FormatConvert = 17,

    // --- FFT ---
    /// One-dimensional FFT.
    FFT1D = 18,
    /// Two-dimensional FFT.
    FFT2D = 19,

    // --- dense solvers ---
    /// LU factorization.
    LUFactorize = 20,
    /// QR factorization.
    QRFactorize = 21,
    /// Singular value decomposition.
    SVDCompute = 22,
}
135
136impl BrickId {
137 pub const COUNT: usize = 23;
139
140 pub const ALL: [BrickId; Self::COUNT] = [
144 Self::RmsNorm,
145 Self::LayerNorm,
146 Self::QkvProjection,
147 Self::RopeEmbedding,
148 Self::AttentionScore,
149 Self::AttentionSoftmax,
150 Self::AttentionOutput,
151 Self::OutputProjection,
152 Self::GateProjection,
153 Self::UpProjection,
154 Self::Activation,
155 Self::DownProjection,
156 Self::Embedding,
157 Self::LmHead,
158 Self::Sampling,
159 Self::SpMV,
160 Self::SpMM,
161 Self::FormatConvert,
162 Self::FFT1D,
163 Self::FFT2D,
164 Self::LUFactorize,
165 Self::QRFactorize,
166 Self::SVDCompute,
167 ];
168
169 #[inline]
171 pub fn validate_index(index: usize) -> bool {
172 debug_assert!(
173 index < Self::COUNT,
174 "CB-BUDGET: brick index {} out of bounds (max {})",
175 index,
176 Self::COUNT
177 );
178 index < Self::COUNT
179 }
180
181 #[inline]
183 pub fn category(self) -> BrickCategory {
184 match self {
185 Self::RmsNorm | Self::LayerNorm => BrickCategory::Norm,
186 Self::QkvProjection
187 | Self::RopeEmbedding
188 | Self::AttentionScore
189 | Self::AttentionSoftmax
190 | Self::AttentionOutput
191 | Self::OutputProjection => BrickCategory::Attention,
192 Self::GateProjection | Self::UpProjection | Self::Activation | Self::DownProjection => {
193 BrickCategory::Ffn
194 }
195 Self::Embedding | Self::LmHead | Self::Sampling => BrickCategory::Other,
196 Self::SpMV | Self::SpMM | Self::FormatConvert => BrickCategory::Sparse,
197 Self::FFT1D | Self::FFT2D => BrickCategory::Fft,
198 Self::LUFactorize | Self::QRFactorize | Self::SVDCompute => BrickCategory::Solver,
199 }
200 }
201
202 #[inline]
204 pub const fn name(self) -> &'static str {
205 match self {
206 Self::RmsNorm => "RmsNorm",
207 Self::LayerNorm => "LayerNorm",
208 Self::QkvProjection => "QkvProjection",
209 Self::RopeEmbedding => "RopeEmbedding",
210 Self::AttentionScore => "AttentionScore",
211 Self::AttentionSoftmax => "AttentionSoftmax",
212 Self::AttentionOutput => "AttentionOutput",
213 Self::OutputProjection => "OutputProjection",
214 Self::GateProjection => "GateProjection",
215 Self::UpProjection => "UpProjection",
216 Self::Activation => "Activation",
217 Self::DownProjection => "DownProjection",
218 Self::Embedding => "Embedding",
219 Self::LmHead => "LmHead",
220 Self::Sampling => "Sampling",
221 Self::SpMV => "SpMV",
222 Self::SpMM => "SpMM",
223 Self::FormatConvert => "FormatConvert",
224 Self::FFT1D => "FFT1D",
225 Self::FFT2D => "FFT2D",
226 Self::LUFactorize => "LUFactorize",
227 Self::QRFactorize => "QRFactorize",
228 Self::SVDCompute => "SVDCompute",
229 }
230 }
231
232 #[allow(clippy::should_implement_trait)]
234 pub fn from_str(s: &str) -> Option<Self> {
235 match s {
236 "RmsNorm" => Some(Self::RmsNorm),
237 "LayerNorm" => Some(Self::LayerNorm),
238 "QkvProjection" | "Qkv" => Some(Self::QkvProjection),
239 "RopeEmbedding" | "Rope" | "RoPE" => Some(Self::RopeEmbedding),
240 "AttentionScore" => Some(Self::AttentionScore),
241 "AttentionSoftmax" | "Softmax" => Some(Self::AttentionSoftmax),
242 "AttentionOutput" => Some(Self::AttentionOutput),
243 "OutputProjection" | "OutProj" => Some(Self::OutputProjection),
244 "GateProjection" | "Gate" => Some(Self::GateProjection),
245 "UpProjection" | "Up" => Some(Self::UpProjection),
246 "Activation" | "SiLU" | "GELU" | "ReLU" => Some(Self::Activation),
247 "DownProjection" | "Down" => Some(Self::DownProjection),
248 "Embedding" | "Embed" => Some(Self::Embedding),
249 "LmHead" | "Head" => Some(Self::LmHead),
250 "Sampling" | "Sample" => Some(Self::Sampling),
251 "SpMV" | "spmv" => Some(Self::SpMV),
252 "SpMM" | "spmm" => Some(Self::SpMM),
253 "FormatConvert" => Some(Self::FormatConvert),
254 "FFT1D" | "fft1d" | "FFT" => Some(Self::FFT1D),
255 "FFT2D" | "fft2d" => Some(Self::FFT2D),
256 "LUFactorize" | "LU" => Some(Self::LUFactorize),
257 "QRFactorize" | "QR" => Some(Self::QRFactorize),
258 "SVDCompute" | "SVD" => Some(Self::SVDCompute),
259 _ => None,
260 }
261 }
262}
263
264impl fmt::Display for BrickId {
265 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
266 write!(f, "{}", self.name())
267 }
268}
269
/// Coarse grouping that each `BrickId` maps into (see `BrickId::category`).
///
/// Discriminants are explicit and contiguous (`0..=6`); `Other` is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
#[repr(u8)]
pub enum BrickCategory {
    /// Normalization bricks.
    Norm = 0,
    /// Attention-block bricks.
    Attention = 1,
    /// Feed-forward bricks (displayed as `FFN`).
    Ffn = 2,
    /// Catch-all category; this is the `Default` variant.
    #[default]
    Other = 3,
    /// Sparse linear-algebra bricks.
    Sparse = 4,
    /// FFT bricks (displayed as `FFT`).
    Fft = 5,
    /// Dense factorization / decomposition bricks.
    Solver = 6,
}

impl BrickCategory {
    /// Number of categories; also the length of [`BrickCategory::ALL`].
    pub const COUNT: usize = 7;

    /// Every category, listed in discriminant order.
    pub const ALL: [BrickCategory; Self::COUNT] = [
        Self::Norm,
        Self::Attention,
        Self::Ffn,
        Self::Other,
        Self::Sparse,
        Self::Fft,
        Self::Solver,
    ];

    /// Human-readable label. Note that `Ffn` and `Fft` are upper-cased
    /// (`FFN`, `FFT`).
    #[inline]
    pub const fn name(self) -> &'static str {
        match self {
            Self::Norm => "Norm",
            Self::Attention => "Attention",
            Self::Ffn => "FFN",
            Self::Other => "Other",
            Self::Sparse => "Sparse",
            Self::Fft => "FFT",
            Self::Solver => "Solver",
        }
    }
}

impl fmt::Display for BrickCategory {
    /// Formats the category as its [`BrickCategory::name`].
    ///
    /// Uses [`fmt::Formatter::pad`] so width, fill, alignment and precision
    /// specifiers are honored. Without a spec the output is exactly the name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(self.name())
    }
}
328
/// Synchronization policy; `Deferred` is the default.
///
/// NOTE(review): the exact meaning of each mode is defined by the consumers of
/// this enum, which are not visible here. The per-variant notes below are
/// inferred from names only — confirm.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum SyncMode {
    /// Presumably synchronizes after each operation.
    Immediate,
    /// Presumably synchronizes once per layer.
    PerLayer,
    /// Default mode; presumably defers synchronization as long as possible.
    #[default]
    Deferred,
    /// Presumably performs no explicit synchronization.
    None,
}