use std::fmt;
mod execution;
mod stats;
pub use execution::{EdgeType, ExecutionEdge, ExecutionNode, ExecutionNodeId, TransferDirection};
pub use stats::{BrickStats, CategoryStats, PtxRegistry};
/// One timing observation for a single brick execution.
///
/// Plain `Copy` data: three unsigned counters with no invariants enforced here.
#[derive(Clone, Copy, Debug)]
pub struct BrickSample {
    /// Numeric identifier of the brick that produced this sample.
    /// NOTE(review): presumably a `BrickId` discriminant widened to `u64` — confirm with producers.
    pub brick_id: u64,
    /// Elapsed time for the execution; the `_ns` suffix suggests nanoseconds.
    pub elapsed_ns: u64,
    /// Number of elements processed during the execution.
    pub elements: u64,
}
/// Which resource a brick is classified as bottlenecked on.
///
/// `Unknown` is the `Default` variant.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum BrickBottleneck {
    #[default]
    Unknown,
    Memory,
    Compute,
}

impl fmt::Display for BrickBottleneck {
    /// Writes a lowercase label (`unknown`, `memory`, `compute`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match *self {
            Self::Unknown => "unknown",
            Self::Memory => "memory",
            Self::Compute => "compute",
        };
        f.write_str(label)
    }
}
/// Identifier for every brick kind known to this module.
///
/// `#[repr(u8)]` with explicit, dense discriminants `0..=22`. As a result,
/// `id as usize` can serve as an index into per-brick arrays sized by `BrickId::COUNT`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[repr(u8)]
pub enum BrickId {
    // Grouped under BrickCategory::Norm by `BrickId::category`.
    RmsNorm = 0,
    LayerNorm = 1,
    // Grouped under BrickCategory::Attention.
    QkvProjection = 2,
    RopeEmbedding = 3,
    AttentionScore = 4,
    AttentionSoftmax = 5,
    AttentionOutput = 6,
    OutputProjection = 7,
    // Grouped under BrickCategory::Ffn.
    GateProjection = 8,
    UpProjection = 9,
    Activation = 10,
    DownProjection = 11,
    // Grouped under BrickCategory::Other.
    Embedding = 12,
    LmHead = 13,
    Sampling = 14,
    // Grouped under BrickCategory::Sparse.
    SpMV = 15,
    SpMM = 16,
    FormatConvert = 17,
    // Grouped under BrickCategory::Fft.
    FFT1D = 18,
    FFT2D = 19,
    // Grouped under BrickCategory::Solver.
    LUFactorize = 20,
    QRFactorize = 21,
    SVDCompute = 22,
}
impl BrickId {
pub const COUNT: usize = 23;
pub const ALL: [BrickId; Self::COUNT] = [
Self::RmsNorm,
Self::LayerNorm,
Self::QkvProjection,
Self::RopeEmbedding,
Self::AttentionScore,
Self::AttentionSoftmax,
Self::AttentionOutput,
Self::OutputProjection,
Self::GateProjection,
Self::UpProjection,
Self::Activation,
Self::DownProjection,
Self::Embedding,
Self::LmHead,
Self::Sampling,
Self::SpMV,
Self::SpMM,
Self::FormatConvert,
Self::FFT1D,
Self::FFT2D,
Self::LUFactorize,
Self::QRFactorize,
Self::SVDCompute,
];
#[inline]
pub fn validate_index(index: usize) -> bool {
debug_assert!(
index < Self::COUNT,
"CB-BUDGET: brick index {} out of bounds (max {})",
index,
Self::COUNT
);
index < Self::COUNT
}
#[inline]
pub fn category(self) -> BrickCategory {
match self {
Self::RmsNorm | Self::LayerNorm => BrickCategory::Norm,
Self::QkvProjection
| Self::RopeEmbedding
| Self::AttentionScore
| Self::AttentionSoftmax
| Self::AttentionOutput
| Self::OutputProjection => BrickCategory::Attention,
Self::GateProjection | Self::UpProjection | Self::Activation | Self::DownProjection => {
BrickCategory::Ffn
}
Self::Embedding | Self::LmHead | Self::Sampling => BrickCategory::Other,
Self::SpMV | Self::SpMM | Self::FormatConvert => BrickCategory::Sparse,
Self::FFT1D | Self::FFT2D => BrickCategory::Fft,
Self::LUFactorize | Self::QRFactorize | Self::SVDCompute => BrickCategory::Solver,
}
}
#[inline]
pub const fn name(self) -> &'static str {
match self {
Self::RmsNorm => "RmsNorm",
Self::LayerNorm => "LayerNorm",
Self::QkvProjection => "QkvProjection",
Self::RopeEmbedding => "RopeEmbedding",
Self::AttentionScore => "AttentionScore",
Self::AttentionSoftmax => "AttentionSoftmax",
Self::AttentionOutput => "AttentionOutput",
Self::OutputProjection => "OutputProjection",
Self::GateProjection => "GateProjection",
Self::UpProjection => "UpProjection",
Self::Activation => "Activation",
Self::DownProjection => "DownProjection",
Self::Embedding => "Embedding",
Self::LmHead => "LmHead",
Self::Sampling => "Sampling",
Self::SpMV => "SpMV",
Self::SpMM => "SpMM",
Self::FormatConvert => "FormatConvert",
Self::FFT1D => "FFT1D",
Self::FFT2D => "FFT2D",
Self::LUFactorize => "LUFactorize",
Self::QRFactorize => "QRFactorize",
Self::SVDCompute => "SVDCompute",
}
}
#[allow(clippy::should_implement_trait)]
pub fn from_str(s: &str) -> Option<Self> {
match s {
"RmsNorm" => Some(Self::RmsNorm),
"LayerNorm" => Some(Self::LayerNorm),
"QkvProjection" | "Qkv" => Some(Self::QkvProjection),
"RopeEmbedding" | "Rope" | "RoPE" => Some(Self::RopeEmbedding),
"AttentionScore" => Some(Self::AttentionScore),
"AttentionSoftmax" | "Softmax" => Some(Self::AttentionSoftmax),
"AttentionOutput" => Some(Self::AttentionOutput),
"OutputProjection" | "OutProj" => Some(Self::OutputProjection),
"GateProjection" | "Gate" => Some(Self::GateProjection),
"UpProjection" | "Up" => Some(Self::UpProjection),
"Activation" | "SiLU" | "GELU" | "ReLU" => Some(Self::Activation),
"DownProjection" | "Down" => Some(Self::DownProjection),
"Embedding" | "Embed" => Some(Self::Embedding),
"LmHead" | "Head" => Some(Self::LmHead),
"Sampling" | "Sample" => Some(Self::Sampling),
"SpMV" | "spmv" => Some(Self::SpMV),
"SpMM" | "spmm" => Some(Self::SpMM),
"FormatConvert" => Some(Self::FormatConvert),
"FFT1D" | "fft1d" | "FFT" => Some(Self::FFT1D),
"FFT2D" | "fft2d" => Some(Self::FFT2D),
"LUFactorize" | "LU" => Some(Self::LUFactorize),
"QRFactorize" | "QR" => Some(Self::QRFactorize),
"SVDCompute" | "SVD" => Some(Self::SVDCompute),
_ => None,
}
}
}
impl fmt::Display for BrickId {
    /// Writes the canonical brick name returned by `BrickId::name`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let canonical = self.name();
        f.write_str(canonical)
    }
}
/// Coarse grouping of bricks, used as the target of `BrickId::category`.
///
/// `#[repr(u8)]` with dense discriminants `0..=6`. `Other` is the `Default`.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
#[repr(u8)]
pub enum BrickCategory {
    Norm = 0,
    Attention = 1,
    Ffn = 2,
    #[default]
    Other = 3,
    Sparse = 4,
    Fft = 5,
    Solver = 6,
}

impl BrickCategory {
    /// Number of categories; also the length of [`BrickCategory::ALL`].
    pub const COUNT: usize = 7;

    /// Every category, listed in discriminant order.
    pub const ALL: [BrickCategory; Self::COUNT] = [
        Self::Norm,
        Self::Attention,
        Self::Ffn,
        Self::Other,
        Self::Sparse,
        Self::Fft,
        Self::Solver,
    ];

    /// Returns the display name. `Ffn` and `Fft` are rendered upper-case (`FFN`, `FFT`).
    #[inline]
    pub const fn name(self) -> &'static str {
        let label = match self {
            Self::Norm => "Norm",
            Self::Attention => "Attention",
            Self::Ffn => "FFN",
            Self::Other => "Other",
            Self::Sparse => "Sparse",
            Self::Fft => "FFT",
            Self::Solver => "Solver",
        };
        label
    }
}

impl fmt::Display for BrickCategory {
    /// Writes the value of `BrickCategory::name`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.name())
    }
}
/// Synchronization policy selector. `Deferred` is the `Default`.
///
/// NOTE(review): the meaning of each mode is defined by the code that consumes
/// this enum, which is not shown here. The variant names suggest when a
/// synchronization happens (every operation, per layer, postponed, never);
/// confirm this against the consumers.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum SyncMode {
    Immediate,
    PerLayer,
    #[default]
    Deferred,
    None,
}