use std::fmt;
use super::activation::{LayerActivationTrace, ModelActivationTrace};
use super::attention::{AttentionTraceConfig, AttentionWeightTrace};
use super::kv_cache::{KvCacheSessionTrace, KvCacheStateTrace};
use super::logit::LogitEvolutionTrace;
use super::quant_error::{ModelQuantizationError, QuantizationErrorTrace};
/// Feature switches controlling which diagnostics a [`ModelTracer`] collects.
///
/// The derived `Default` turns every switch off (and leaves `tracked_tokens` as
/// `None`); see [`ModelTracerConfig::full`] and [`ModelTracerConfig::lightweight`]
/// for ready-made presets.
#[derive(Clone, Debug, Default)]
pub struct ModelTracerConfig {
    /// Start an activation trace for every forward pass.
    pub trace_activations: bool,
    /// Keep the attention-weight traces handed to the tracer.
    pub trace_attention: bool,
    /// Settings for attention tracing.
    ///
    /// NOTE(review): not read by `ModelTracer` in this file — presumably consumed
    /// by whoever produces the attention traces; confirm.
    pub attention_config: AttentionTraceConfig,
    /// Start a logit-evolution trace for every forward pass.
    pub trace_logits: bool,
    /// Explicit token ids whose logits should be followed (`None` = unspecified).
    ///
    /// NOTE(review): `ModelTracer` in this file never reads this field; confirm
    /// whether it is applied elsewhere.
    pub tracked_tokens: Option<Vec<u32>>,
    /// Accumulate quantization-error measurements.
    pub trace_quant_error: bool,
    /// Record KV-cache state snapshots.
    pub trace_kv_cache: bool,
}
impl ModelTracerConfig {
pub fn full() -> Self {
Self {
trace_activations: true,
trace_attention: true,
attention_config: AttentionTraceConfig::default(),
trace_logits: true,
tracked_tokens: None,
trace_quant_error: true,
trace_kv_cache: true,
}
}
pub fn lightweight() -> Self {
Self { trace_activations: true, trace_kv_cache: true, ..Default::default() }
}
pub fn is_enabled(&self) -> bool {
self.trace_activations
|| self.trace_attention
|| self.trace_logits
|| self.trace_quant_error
|| self.trace_kv_cache
}
}
/// Collects per-forward-pass diagnostics (activations, attention weights, logits,
/// quantization error, KV-cache state) as selected by a [`ModelTracerConfig`].
pub struct ModelTracer {
    /// Feature switches supplied at construction.
    config: ModelTracerConfig,
    /// Position given to `begin_forward` for the current pass.
    current_position: usize,
    /// Activation traces of forward passes that have been ended.
    activation_traces: Vec<ModelActivationTrace>,
    /// Activation trace of the pass in progress, if one was started.
    current_activation_trace: Option<ModelActivationTrace>,
    /// Attention traces stored while attention tracing was enabled.
    attention_traces: Vec<AttentionWeightTrace>,
    /// Logit traces of forward passes that have been ended.
    logit_traces: Vec<LogitEvolutionTrace>,
    /// Logit trace of the pass in progress, if any.
    current_logit_trace: Option<LogitEvolutionTrace>,
    /// Quantization-error aggregates (filled lazily by `record_quant_error`).
    quant_traces: Vec<ModelQuantizationError>,
    /// KV-cache steps recorded during the session.
    kv_trace: KvCacheSessionTrace,
}
impl ModelTracer {
pub fn new(config: ModelTracerConfig) -> Self {
Self {
config,
current_position: 0,
activation_traces: Vec::new(),
current_activation_trace: None,
attention_traces: Vec::new(),
logit_traces: Vec::new(),
current_logit_trace: None,
quant_traces: Vec::new(),
kv_trace: KvCacheSessionTrace::default(),
}
}
pub fn config(&self) -> &ModelTracerConfig {
&self.config
}
pub fn current_logit_trace(&self) -> Option<&LogitEvolutionTrace> {
self.current_logit_trace.as_ref()
}
pub fn set_current_logit_trace(&mut self, trace: Option<LogitEvolutionTrace>) {
self.current_logit_trace = trace;
}
pub fn begin_forward(&mut self, position: usize) {
self.current_position = position;
if self.config.trace_activations {
self.current_activation_trace = Some(ModelActivationTrace::default());
}
if self.config.trace_logits {
self.current_logit_trace = Some(LogitEvolutionTrace::new(position, 1.0, 1.0));
}
}
pub fn record_layer_activation(&mut self, trace: LayerActivationTrace) {
if let Some(ref mut activation) = self.current_activation_trace {
activation.add_layer(trace);
}
}
pub fn record_attention(&mut self, trace: AttentionWeightTrace) {
if self.config.trace_attention {
self.attention_traces.push(trace);
}
}
pub fn record_logits(&mut self, layer_idx: usize, logits: &[f32]) {
if let Some(ref mut logit_trace) = self.current_logit_trace {
for token_evo in &mut logit_trace.tracked_tokens {
let logit = logits.get(token_evo.token_id as usize).copied().unwrap_or(0.0);
let rank = LogitEvolutionTrace::compute_rank(logits, token_evo.token_id);
token_evo.record_layer(logit, rank);
}
logit_trace.decisive_layer = layer_idx;
}
}
pub fn record_kv_state(&mut self, trace: KvCacheStateTrace) {
if self.config.trace_kv_cache {
self.kv_trace.add_step(trace);
}
}
pub fn record_quant_error(&mut self, trace: QuantizationErrorTrace) {
if self.config.trace_quant_error {
if self.quant_traces.is_empty() {
self.quant_traces.push(ModelQuantizationError::default());
}
if let Some(model_error) = self.quant_traces.last_mut() {
model_error.add_error(trace);
}
}
}
pub fn end_forward(&mut self) -> Option<String> {
let mut anomaly = None;
if let Some(mut trace) = self.current_activation_trace.take() {
trace.finalize();
if trace.has_anomaly {
anomaly = trace.anomaly_desc.clone();
}
self.activation_traces.push(trace);
}
if let Some(trace) = self.current_logit_trace.take() {
self.logit_traces.push(trace);
}
anomaly
}
pub fn summary(&self) -> ModelTracerSummary {
ModelTracerSummary {
total_forwards: self.activation_traces.len(),
anomalies_detected: self.activation_traces.iter().filter(|t| t.has_anomaly).count(),
attention_traces: self.attention_traces.len(),
logit_traces: self.logit_traces.len(),
kv_steps: self.kv_trace.steps.len(),
total_evictions: self.kv_trace.total_evictions,
avg_hit_rate: self.kv_trace.avg_hit_rate,
quant_warnings: self.quant_traces.iter().map(|t| t.warning_count()).sum(),
quant_criticals: self.quant_traces.iter().map(|t| t.critical_count()).sum(),
}
}
pub fn summary_to_json(&self) -> String {
let summary = self.summary();
format!(
r#"{{"total_forwards":{},"anomalies_detected":{},"attention_traces":{},"logit_traces":{},"kv_steps":{},"total_evictions":{},"avg_hit_rate":{:.4},"quant_warnings":{},"quant_criticals":{}}}"#,
summary.total_forwards,
summary.anomalies_detected,
summary.attention_traces,
summary.logit_traces,
summary.kv_steps,
summary.total_evictions,
summary.avg_hit_rate,
summary.quant_warnings,
summary.quant_criticals
)
}
pub fn clear(&mut self) {
self.activation_traces.clear();
self.attention_traces.clear();
self.logit_traces.clear();
self.quant_traces.clear();
self.kv_trace = KvCacheSessionTrace::default();
}
}
/// Aggregate counters describing what a [`ModelTracer`] has recorded.
#[derive(Clone, Debug, Default)]
pub struct ModelTracerSummary {
    /// Number of ended forward passes that stored an activation trace.
    pub total_forwards: usize,
    /// How many of those activation traces report an anomaly.
    pub anomalies_detected: usize,
    /// Number of stored attention-weight traces.
    pub attention_traces: usize,
    /// Number of stored logit-evolution traces.
    pub logit_traces: usize,
    /// Number of KV-cache steps recorded.
    pub kv_steps: usize,
    /// Total KV-cache evictions reported by the session trace.
    pub total_evictions: usize,
    /// Average KV-cache hit rate from the session trace.
    ///
    /// Displayed multiplied by 100 as a percentage, so presumably a fraction in
    /// `[0, 1]` — confirm against the KV-cache trace.
    pub avg_hit_rate: f32,
    /// Total quantization warnings across the quantization aggregates.
    pub quant_warnings: usize,
    /// Total critical quantization errors across the quantization aggregates.
    pub quant_criticals: usize,
}

impl fmt::Display for ModelTracerSummary {
    /// Renders a multi-line, human-readable report with no trailing newline.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let report = [
            String::from("ModelTracer Summary:"),
            format!(" Forward passes: {}", self.total_forwards),
            format!(" Anomalies: {}", self.anomalies_detected),
            format!(" Attention traces: {}", self.attention_traces),
            format!(" Logit traces: {}", self.logit_traces),
            format!(" KV cache steps: {}", self.kv_steps),
            format!(" KV evictions: {}", self.total_evictions),
            format!(" Avg hit rate: {:.2}%", self.avg_hit_rate * 100.0),
            format!(" Quant warnings: {}", self.quant_warnings),
            format!(" Quant criticals: {}", self.quant_criticals),
        ]
        .join("\n");
        f.write_str(&report)
    }
}
#[cfg(test)]
mod tests {
    use super::super::activation::TensorStats;
    use super::*;

    /// Fresh tracer using the lightweight preset (activations + KV cache).
    fn lightweight_tracer() -> ModelTracer {
        ModelTracer::new(ModelTracerConfig::lightweight())
    }

    #[test]
    fn test_model_tracer_lightweight() {
        let cfg = ModelTracerConfig::lightweight();
        // Switched on by the preset.
        assert!(cfg.trace_activations);
        assert!(cfg.trace_kv_cache);
        // Left at their `Default` (off) values.
        assert!(!cfg.trace_attention);
        assert!(!cfg.trace_quant_error);
    }

    #[test]
    fn test_model_tracer_full() {
        let cfg = ModelTracerConfig::full();
        let switches = [
            ("trace_activations", cfg.trace_activations),
            ("trace_attention", cfg.trace_attention),
            ("trace_logits", cfg.trace_logits),
            ("trace_quant_error", cfg.trace_quant_error),
            ("trace_kv_cache", cfg.trace_kv_cache),
        ];
        for &(name, enabled) in switches.iter() {
            assert!(enabled, "{} should be enabled by ModelTracerConfig::full()", name);
        }
    }

    #[test]
    fn test_model_tracer_forward_pass() {
        let mut tracer = lightweight_tracer();
        tracer.begin_forward(0);
        for layer_idx in 0..2 {
            tracer.record_layer_activation(LayerActivationTrace::new(layer_idx));
        }
        assert!(tracer.end_forward().is_none());

        let summary = tracer.summary();
        assert_eq!(summary.total_forwards, 1);
        assert_eq!(summary.anomalies_detected, 0);
    }

    #[test]
    fn test_model_tracer_detects_anomaly() {
        let mut tracer = lightweight_tracer();
        tracer.begin_forward(0);

        // A single NaN in a layer's input should surface as an anomaly.
        let mut bad_layer = LayerActivationTrace::new(0);
        bad_layer.input_stats = TensorStats::from_slice(&[f32::NAN]);
        tracer.record_layer_activation(bad_layer);

        let description = tracer
            .end_forward()
            .expect("a NaN input should be reported as an anomaly");
        assert!(description.contains("NaN"));
        assert_eq!(tracer.summary().anomalies_detected, 1);
    }

    #[test]
    fn test_model_tracer_json_output() {
        let mut tracer = lightweight_tracer();
        tracer.begin_forward(0);
        let _ = tracer.end_forward();

        let json = tracer.summary_to_json();
        for expected in ["\"total_forwards\":1", "\"anomalies_detected\":0"].iter() {
            assert!(json.contains(*expected), "missing {} in {}", expected, json);
        }
    }

    #[test]
    fn test_falsify_tracer_layer_count() {
        const NUM_LAYERS: usize = 32;
        let mut tracer = lightweight_tracer();
        tracer.begin_forward(0);
        (0..NUM_LAYERS).for_each(|i| tracer.record_layer_activation(LayerActivationTrace::new(i)));
        let _ = tracer.end_forward();

        let recorded = tracer.activation_traces[0].layers.len();
        assert_eq!(
            recorded, NUM_LAYERS,
            "FALSIFICATION FAILED: recorded {} layers but expected {}",
            recorded, NUM_LAYERS
        );
    }
}