use std::collections::HashSet;
use std::path::PathBuf;
use std::time::Instant;
/// Configuration for inference tracing: whether it is on, which steps to
/// record, verbosity, and an optional output destination.
#[derive(Debug, Clone, Default)]
pub struct TraceConfig {
    /// Master switch; `should_trace` returns false whenever this is false.
    pub enabled: bool,
    /// Steps to trace; an empty set means "trace every step".
    pub steps: HashSet<TraceStep>,
    /// Request more detail — exact effect is decided by the tracer
    /// implementation (in the included files); confirm there.
    pub verbose: bool,
    /// Where to write trace output; `None` presumably means a default sink
    /// (stderr/stdout) — TODO confirm in the included tracer implementation.
    pub output: Option<PathBuf>,
}
impl TraceConfig {
    /// Builds a config with tracing switched on and every other field default
    /// (no step filter, not verbose, no output path).
    #[must_use]
    pub fn enabled() -> Self {
        Self {
            enabled: true,
            ..Self::default()
        }
    }

    /// True when tracing is on and `step` is selected.
    ///
    /// An empty step set acts as a wildcard: every step is traced.
    #[must_use]
    pub fn should_trace(&self, step: TraceStep) -> bool {
        if !self.enabled {
            return false;
        }
        self.steps.is_empty() || self.steps.contains(&step)
    }

    /// Parses a comma-separated list of step names (e.g. `"attn,ffn"`).
    /// Whitespace around entries is ignored; unrecognized entries are
    /// silently skipped.
    #[must_use]
    pub fn parse_steps(s: &str) -> HashSet<TraceStep> {
        let mut steps = HashSet::new();
        for part in s.split(',') {
            if let Some(step) = TraceStep::parse(part.trim()) {
                steps.insert(step);
            }
        }
        steps
    }
}
/// A traceable stage of the inference pipeline, from tokenization through
/// decoding, plus kernel-launch and brick-profiler instrumentation points.
#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
pub enum TraceStep {
    Tokenize,
    Embed,
    LayerNorm,
    Attention,
    FFN,
    TransformerBlock,
    LmHead,
    Sample,
    Decode,
    KernelLaunch,
    BrickProfile,
}

impl TraceStep {
    /// Parses a case-insensitive step name; each step accepts several
    /// aliases. Returns `None` for anything unrecognized.
    #[must_use]
    pub fn parse(s: &str) -> Option<Self> {
        let key = s.to_lowercase();
        match key.as_str() {
            "tokenize" | "encode" => Some(Self::Tokenize),
            "embed" | "embedding" => Some(Self::Embed),
            "layernorm" | "ln" | "norm" => Some(Self::LayerNorm),
            "attention" | "attn" => Some(Self::Attention),
            "ffn" | "mlp" => Some(Self::FFN),
            "transformer" | "transformer_block" | "layer" => Some(Self::TransformerBlock),
            "lmhead" | "lm_head" | "head" => Some(Self::LmHead),
            "sample" | "sampling" => Some(Self::Sample),
            "decode" | "detokenize" => Some(Self::Decode),
            "kernel" | "kernel_launch" | "ptx" | "cuda" => Some(Self::KernelLaunch),
            "brick" | "brick_profile" | "profiler" | "bricks" => Some(Self::BrickProfile),
            _ => None,
        }
    }

    /// Canonical uppercase name used in trace output.
    #[must_use]
    pub fn name(&self) -> &'static str {
        match *self {
            Self::Tokenize => "TOKENIZE",
            Self::Embed => "EMBED",
            Self::LayerNorm => "LAYER_NORM",
            Self::Attention => "ATTENTION",
            Self::FFN => "FFN",
            Self::TransformerBlock => "TRANSFORMER_BLOCK",
            Self::LmHead => "LM_HEAD",
            Self::Sample => "SAMPLE",
            Self::Decode => "DECODE",
            Self::KernelLaunch => "KERNEL_LAUNCH",
            Self::BrickProfile => "BRICK_PROFILE",
        }
    }

    /// Pre-3.0 step names kept for output compatibility; they differ from
    /// `name()` only for `Tokenize` ("ENCODE") and `TransformerBlock`
    /// ("TRANSFORMER").
    #[deprecated(since = "3.0.0", note = "Use name() instead")]
    #[must_use]
    pub fn legacy_name(&self) -> &'static str {
        match *self {
            Self::Tokenize => "ENCODE",
            Self::TransformerBlock => "TRANSFORMER",
            // Every remaining variant keeps its canonical spelling.
            Self::Embed => "EMBED",
            Self::LayerNorm => "LAYER_NORM",
            Self::Attention => "ATTENTION",
            Self::FFN => "FFN",
            Self::LmHead => "LM_HEAD",
            Self::Sample => "SAMPLE",
            Self::Decode => "DECODE",
            Self::KernelLaunch => "KERNEL_LAUNCH",
            Self::BrickProfile => "BRICK_PROFILE",
        }
    }

    /// 1-based pipeline stage number. The per-layer steps all share stage 3,
    /// and the instrumentation steps (kernel launch, brick profile) come last.
    #[must_use]
    pub fn step_number(&self) -> usize {
        match *self {
            Self::Tokenize => 1,
            Self::Embed => 2,
            Self::Attention | Self::FFN | Self::LayerNorm | Self::TransformerBlock => 3,
            Self::LmHead => 4,
            Self::Sample => 5,
            Self::Decode => 6,
            Self::KernelLaunch => 7,
            Self::BrickProfile => 8,
        }
    }
}
/// Summary statistics over a tensor's elements, used for trace diagnostics.
///
/// Non-finite elements (NaN / ±Inf) are excluded from `min`/`max`/`mean`/`std`
/// and reported through the `has_nan`/`has_inf` flags instead.
#[derive(Debug, Clone, Default)]
pub struct TensorStats {
    /// Smallest finite element (0.0 when there are no finite elements).
    pub min: f32,
    /// Largest finite element (0.0 when there are no finite elements).
    pub max: f32,
    /// Mean of the finite elements.
    pub mean: f32,
    /// Population standard deviation of the finite elements.
    pub std: f32,
    /// True if any element was NaN.
    pub has_nan: bool,
    /// True if any element was +Inf or -Inf.
    pub has_inf: bool,
}

impl TensorStats {
    /// Computes statistics over `data`: one pass for min/max/sum/flags,
    /// a second pass for variance.
    ///
    /// Fixes relative to the previous version:
    /// - mean and std divide by the number of *finite* elements rather than
    ///   the total length, so NaN/Inf entries no longer skew the statistics
    ///   (only the sum skipped them before, but the divisor did not);
    /// - the variance uses the f64 mean directly instead of a value already
    ///   rounded to f32;
    /// - when no element is finite, min/max/mean/std are all 0.0 instead of
    ///   leaving the ±Infinity scan sentinels in `min`/`max`.
    #[must_use]
    pub fn from_slice(data: &[f32]) -> Self {
        if data.is_empty() {
            return Self::default();
        }
        let mut min = f32::INFINITY;
        let mut max = f32::NEG_INFINITY;
        let mut sum = 0.0f64;
        let mut finite = 0usize;
        let mut has_nan = false;
        let mut has_inf = false;
        for &v in data {
            if v.is_nan() {
                has_nan = true;
            } else if v.is_infinite() {
                has_inf = true;
            } else {
                min = min.min(v);
                max = max.max(v);
                sum += f64::from(v);
                finite += 1;
            }
        }
        if finite == 0 {
            // Only NaN/Inf values: keep the numeric fields at a neutral 0.0
            // and let the flags carry the diagnosis.
            return Self {
                has_nan,
                has_inf,
                ..Self::default()
            };
        }
        let mean = sum / finite as f64;
        // Second pass: population variance over the finite elements only.
        let var_sum: f64 = data
            .iter()
            .filter(|v| v.is_finite())
            .map(|&v| {
                let diff = f64::from(v) - mean;
                diff * diff
            })
            .sum();
        let std = (var_sum / finite as f64).sqrt();
        Self {
            min,
            max,
            mean: mean as f32,
            std: std as f32,
            has_nan,
            has_inf,
        }
    }

    /// True when the traced tensor contained any NaN or Inf values.
    #[must_use]
    pub fn has_error(&self) -> bool {
        self.has_nan || self.has_inf
    }
}
/// Diagnostic errors surfaced while tracing an inference run.
#[derive(Debug, Clone)]
pub enum TraceError {
    /// A token id was at or beyond the vocabulary size.
    VocabOverflow {
        token_id: u32,
        vocab_size: usize,
    },
    /// NaN values appeared, optionally attributed to a specific layer.
    NaNDetected {
        layer: Option<usize>,
    },
    /// Inf values appeared, optionally attributed to a specific layer.
    InfDetected {
        layer: Option<usize>,
    },
    /// Decoded output looked like garbage; `sample` holds an excerpt.
    GarbageOutput {
        sample: String,
    },
    /// A token id had no vocabulary entry.
    UnknownToken {
        token_id: u32,
    },
    /// A tensor's shape did not match what the step expected.
    ShapeMismatch {
        expected: Vec<usize>,
        actual: Vec<usize>,
    },
    /// The traced execution itself failed.
    ExecutionFailed {
        cause: String,
    },
}

impl std::fmt::Display for TraceError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Message text is load-bearing (it may be matched by log consumers);
        // keep each string exactly as before.
        match self {
            Self::VocabOverflow {
                token_id,
                vocab_size,
            } => write!(f, "Token ID {token_id} exceeds vocab size {vocab_size}"),
            Self::NaNDetected { layer: Some(l) } => {
                write!(f, "NaN values detected in layer {l}")
            },
            Self::NaNDetected { layer: None } => write!(f, "NaN values detected"),
            Self::InfDetected { layer: Some(l) } => {
                write!(f, "Inf values detected in layer {l}")
            },
            Self::InfDetected { layer: None } => write!(f, "Inf values detected"),
            Self::GarbageOutput { sample } => {
                write!(f, "Garbage output detected: {sample:?}")
            },
            Self::UnknownToken { token_id } => write!(f, "Unknown token ID: {token_id}"),
            Self::ShapeMismatch { expected, actual } => {
                write!(f, "Shape mismatch: expected {expected:?}, got {actual:?}")
            },
            Self::ExecutionFailed { cause } => write!(f, "Execution failed: {cause}"),
        }
    }
}

// Fix: the type had `Display` but not `Error`, so it could not be returned as
// `Box<dyn Error>` or wrapped by error-handling crates. Marker impl suffices:
// there is no underlying source to expose.
impl std::error::Error for TraceError {}
/// Event classification following AWS Step Functions history-event naming.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AwsEventType {
    TaskStateEntered,
    TaskStateExited,
    ExecutionFailed,
}

impl AwsEventType {
    /// Returns the event type's canonical (wire-format) name.
    #[must_use]
    pub fn name(&self) -> &'static str {
        match *self {
            AwsEventType::TaskStateEntered => "TaskStateEntered",
            AwsEventType::TaskStateExited => "TaskStateExited",
            AwsEventType::ExecutionFailed => "ExecutionFailed",
        }
    }
}
/// One recorded trace event describing a single execution of a pipeline step.
#[derive(Debug, Clone)]
pub struct TraceEvent {
    /// Event id — presumably assigned from `InferenceTracer::next_event_id`;
    /// TODO confirm in the included tracer implementation.
    pub id: u64,
    /// Timestamp rendered as a string; the exact format is not visible here.
    pub timestamp: String,
    /// AWS-Step-Functions-style lifecycle classification of this event.
    pub event_type: AwsEventType,
    /// Id of the causally preceding event, if any — NOTE(review): likely the
    /// matching `TaskStateEntered` (cf. `last_entered_id`); confirm.
    pub previous_event_id: Option<u64>,
    /// Pipeline step this event belongs to.
    pub step: TraceStep,
    /// Iteration index — assumed to be the generation/decode step; confirm
    /// against the callers.
    pub iteration: usize,
    /// Layer index, when the step is layer-scoped.
    pub layer: Option<usize>,
    /// Shape of the step's input tensor.
    pub input_shape: Vec<usize>,
    /// Shape of the step's output tensor.
    pub output_shape: Vec<usize>,
    /// Numeric summary of the step's output (see `TensorStats`).
    pub stats: TensorStats,
    /// Step duration; the `_us` suffix suggests microseconds — TODO confirm.
    pub duration_us: u64,
    /// Error detected during this step, if any.
    pub error: Option<TraceError>,
    /// Free-form cause string accompanying failures, if any.
    pub cause: Option<String>,
    /// Step-specific payload fields.
    pub details: TraceDetails,
}
/// Optional, step-specific payload attached to a `TraceEvent`; only the
/// fields relevant to the event's step are expected to be populated.
#[derive(Debug, Clone, Default)]
pub struct TraceDetails {
    /// Raw input text (tokenize step).
    pub input_text: Option<String>,
    /// Token ids produced by tokenization.
    pub output_tokens: Option<Vec<u32>>,
    /// Vocabulary strings, presumably for display alongside token ids —
    /// confirm against the tracer implementation.
    pub vocab_entries: Option<Vec<String>>,
    /// Top-k (token id, logit) pairs from the LM head.
    pub top_k_logits: Option<Vec<(u32, f32)>>,
    /// Top-k (token id, probability) pairs — presumably post-softmax; confirm.
    pub top_k_probs: Option<Vec<(u32, f32)>>,
    /// Token id chosen by sampling.
    pub sampled_token: Option<u32>,
    /// Text decoded from the sampled token(s).
    pub decoded_text: Option<String>,
    /// String form of a single token.
    pub token_string: Option<String>,
    /// Sampling temperature in effect.
    pub temperature: Option<f32>,
    /// Sampling top-k cutoff in effect.
    pub top_k: Option<usize>,
    /// Kernel name (kernel-launch step).
    pub kernel_name: Option<String>,
    /// Launch grid dimensions [x, y, z] — GPU-style launch geometry.
    pub grid_dims: Option<[u32; 3]>,
    /// Launch block dimensions [x, y, z].
    pub block_dims: Option<[u32; 3]>,
    /// Shared memory requested for the launch, in bytes.
    pub shared_mem_bytes: Option<u32>,
    /// Layer the kernel ran for, when layer-scoped.
    pub kernel_layer: Option<usize>,
    /// Name of the dispatch strategy used — semantics defined elsewhere.
    pub dispatch_strategy: Option<String>,
    /// Brick-profiler per-category pairs (name, u64) — the u64's meaning
    /// (count vs. time) is not visible here; TODO confirm.
    pub brick_categories: Option<Vec<(String, u64)>>,
    /// Brick-profiler per-brick tuples (name, u64, u64) — units/meaning not
    /// visible here; confirm against the profiler.
    pub brick_timings: Option<Vec<(String, u64, u64)>>,
}
/// Collects `TraceEvent`s for an inference run according to a `TraceConfig`.
///
/// Method implementations live in the `include!`d files at the bottom of
/// this module; only the state is declared here.
#[derive(Debug)]
pub struct InferenceTracer {
    /// What to trace and where output goes.
    config: TraceConfig,
    /// Events recorded so far, in order.
    events: Vec<TraceEvent>,
    /// Model metadata; `ModelInfo` is defined in the included `model_info.rs`.
    model_info: ModelInfo,
    /// Start instant of the step currently being timed, if one is open.
    step_start: Option<Instant>,
    /// Error tally — assumed to count error-bearing events; confirm in the
    /// included implementation.
    error_count: usize,
    /// Warning tally — assumed; confirm in the included implementation.
    warning_count: usize,
    /// Next event id to hand out.
    next_event_id: u64,
    /// Id of the most recent `TaskStateEntered` event — NOTE(review):
    /// inferred from the name, presumably used to link exit events; confirm.
    last_entered_id: Option<u64>,
}
include!("model_info.rs");
include!("execution_failure.rs");
include!("mod_get_top_compute.rs");
include!("tracer_contracts.rs");