Skip to main content

trueno/brick/tracing/
tracer.rs

1// ============================================================================
2// E.11.7: Unified ModelTracer
3// ============================================================================
4
5use std::fmt;
6
7use super::activation::{LayerActivationTrace, ModelActivationTrace};
8use super::attention::{AttentionTraceConfig, AttentionWeightTrace};
9use super::kv_cache::{KvCacheSessionTrace, KvCacheStateTrace};
10use super::logit::LogitEvolutionTrace;
11use super::quant_error::{ModelQuantizationError, QuantizationErrorTrace};
12
13/// Configuration for model-level tracing.
14#[derive(Debug, Clone, Default)]
15pub struct ModelTracerConfig {
16    /// Enable layer activation tracing (MLT-01)
17    pub trace_activations: bool,
18    /// Enable attention weight tracing (MLT-02)
19    pub trace_attention: bool,
20    /// Attention trace configuration
21    pub attention_config: AttentionTraceConfig,
22    /// Enable logit evolution tracing (MLT-03)
23    pub trace_logits: bool,
24    /// Specific tokens to track (None = auto-select top-k)
25    pub tracked_tokens: Option<Vec<u32>>,
26    /// Enable quantization error tracing (MLT-04) - expensive!
27    pub trace_quant_error: bool,
28    /// Enable KV cache state tracing (MLT-05)
29    pub trace_kv_cache: bool,
30}
31
32impl ModelTracerConfig {
33    /// Create a config that traces everything (for debugging).
34    pub fn full() -> Self {
35        Self {
36            trace_activations: true,
37            trace_attention: true,
38            attention_config: AttentionTraceConfig::default(),
39            trace_logits: true,
40            tracked_tokens: None,
41            trace_quant_error: true,
42            trace_kv_cache: true,
43        }
44    }
45
46    /// Create a lightweight config (activations + KV cache only).
47    pub fn lightweight() -> Self {
48        Self { trace_activations: true, trace_kv_cache: true, ..Default::default() }
49    }
50
51    /// Check if any tracing is enabled.
52    pub fn is_enabled(&self) -> bool {
53        self.trace_activations
54            || self.trace_attention
55            || self.trace_logits
56            || self.trace_quant_error
57            || self.trace_kv_cache
58    }
59}
60
61/// Unified model tracer that coordinates all trace types.
62///
63/// # Example
64/// ```rust,ignore
65/// let config = ModelTracerConfig::lightweight();
66/// let mut tracer = ModelTracer::new(config);
67///
68/// tracer.begin_forward(position);
69/// // ... forward pass with trace hooks ...
70/// if let Some(anomaly) = tracer.end_forward() {
71///     log::warn!("Anomaly: {}", anomaly);
72/// }
73/// ```
74pub struct ModelTracer {
75    config: ModelTracerConfig,
76    /// Current forward pass position
77    current_position: usize,
78    /// Accumulated activation traces
79    activation_traces: Vec<ModelActivationTrace>,
80    /// Current activation trace (in progress)
81    current_activation_trace: Option<ModelActivationTrace>,
82    /// Accumulated attention traces
83    attention_traces: Vec<AttentionWeightTrace>,
84    /// Accumulated logit evolution traces
85    logit_traces: Vec<LogitEvolutionTrace>,
86    /// Current logit trace (in progress)
87    current_logit_trace: Option<LogitEvolutionTrace>,
88    /// Accumulated quantization error traces
89    quant_traces: Vec<ModelQuantizationError>,
90    /// KV cache session trace
91    kv_trace: KvCacheSessionTrace,
92}
93
94impl ModelTracer {
95    /// Create a new tracer with the given configuration.
96    pub fn new(config: ModelTracerConfig) -> Self {
97        Self {
98            config,
99            current_position: 0,
100            activation_traces: Vec::new(),
101            current_activation_trace: None,
102            attention_traces: Vec::new(),
103            logit_traces: Vec::new(),
104            current_logit_trace: None,
105            quant_traces: Vec::new(),
106            kv_trace: KvCacheSessionTrace::default(),
107        }
108    }
109
110    /// Get the configuration.
111    pub fn config(&self) -> &ModelTracerConfig {
112        &self.config
113    }
114
115    /// Get a reference to the current logit trace (if any).
116    pub fn current_logit_trace(&self) -> Option<&LogitEvolutionTrace> {
117        self.current_logit_trace.as_ref()
118    }
119
120    /// Set the current logit trace (for testing purposes).
121    pub fn set_current_logit_trace(&mut self, trace: Option<LogitEvolutionTrace>) {
122        self.current_logit_trace = trace;
123    }
124
125    /// Begin a forward pass at the given position.
126    pub fn begin_forward(&mut self, position: usize) {
127        self.current_position = position;
128
129        if self.config.trace_activations {
130            self.current_activation_trace = Some(ModelActivationTrace::default());
131        }
132
133        if self.config.trace_logits {
134            self.current_logit_trace = Some(LogitEvolutionTrace::new(position, 1.0, 1.0));
135        }
136    }
137
138    /// Record layer activation (called by executor after each layer).
139    pub fn record_layer_activation(&mut self, trace: LayerActivationTrace) {
140        if let Some(ref mut activation) = self.current_activation_trace {
141            activation.add_layer(trace);
142        }
143    }
144
145    /// Record attention weights (called by attention brick).
146    pub fn record_attention(&mut self, trace: AttentionWeightTrace) {
147        if self.config.trace_attention {
148            self.attention_traces.push(trace);
149        }
150    }
151
152    /// Record logit state at a layer (called by lm_head or probe).
153    pub fn record_logits(&mut self, layer_idx: usize, logits: &[f32]) {
154        if let Some(ref mut logit_trace) = self.current_logit_trace {
155            for token_evo in &mut logit_trace.tracked_tokens {
156                let logit = logits.get(token_evo.token_id as usize).copied().unwrap_or(0.0);
157                let rank = LogitEvolutionTrace::compute_rank(logits, token_evo.token_id);
158                token_evo.record_layer(logit, rank);
159            }
160            // Store decisive layer based on rank changes
161            logit_trace.decisive_layer = layer_idx;
162        }
163    }
164
165    /// Record KV cache state (called after each generation step).
166    pub fn record_kv_state(&mut self, trace: KvCacheStateTrace) {
167        if self.config.trace_kv_cache {
168            self.kv_trace.add_step(trace);
169        }
170    }
171
172    /// Record quantization error for a brick.
173    pub fn record_quant_error(&mut self, trace: QuantizationErrorTrace) {
174        if self.config.trace_quant_error {
175            if self.quant_traces.is_empty() {
176                self.quant_traces.push(ModelQuantizationError::default());
177            }
178            if let Some(model_error) = self.quant_traces.last_mut() {
179                model_error.add_error(trace);
180            }
181        }
182    }
183
184    /// Complete forward pass and check for anomalies.
185    ///
186    /// Returns a description of the first anomaly detected, if any.
187    pub fn end_forward(&mut self) -> Option<String> {
188        let mut anomaly = None;
189
190        // Finalize activation trace
191        if let Some(mut trace) = self.current_activation_trace.take() {
192            trace.finalize();
193            if trace.has_anomaly {
194                anomaly = trace.anomaly_desc.clone();
195            }
196            self.activation_traces.push(trace);
197        }
198
199        // Finalize logit trace
200        if let Some(trace) = self.current_logit_trace.take() {
201            self.logit_traces.push(trace);
202        }
203
204        anomaly
205    }
206
207    /// Get summary statistics.
208    pub fn summary(&self) -> ModelTracerSummary {
209        ModelTracerSummary {
210            total_forwards: self.activation_traces.len(),
211            anomalies_detected: self.activation_traces.iter().filter(|t| t.has_anomaly).count(),
212            attention_traces: self.attention_traces.len(),
213            logit_traces: self.logit_traces.len(),
214            kv_steps: self.kv_trace.steps.len(),
215            total_evictions: self.kv_trace.total_evictions,
216            avg_hit_rate: self.kv_trace.avg_hit_rate,
217            quant_warnings: self.quant_traces.iter().map(|t| t.warning_count()).sum(),
218            quant_criticals: self.quant_traces.iter().map(|t| t.critical_count()).sum(),
219        }
220    }
221
222    /// Export summary as JSON for artifact validation.
223    pub fn summary_to_json(&self) -> String {
224        let summary = self.summary();
225        format!(
226            r#"{{"total_forwards":{},"anomalies_detected":{},"attention_traces":{},"logit_traces":{},"kv_steps":{},"total_evictions":{},"avg_hit_rate":{:.4},"quant_warnings":{},"quant_criticals":{}}}"#,
227            summary.total_forwards,
228            summary.anomalies_detected,
229            summary.attention_traces,
230            summary.logit_traces,
231            summary.kv_steps,
232            summary.total_evictions,
233            summary.avg_hit_rate,
234            summary.quant_warnings,
235            summary.quant_criticals
236        )
237    }
238
239    /// Clear all accumulated traces (free memory).
240    pub fn clear(&mut self) {
241        self.activation_traces.clear();
242        self.attention_traces.clear();
243        self.logit_traces.clear();
244        self.quant_traces.clear();
245        self.kv_trace = KvCacheSessionTrace::default();
246    }
247}
248
/// Aggregate counters describing what a model tracer has collected.
#[derive(Clone, Debug, Default)]
pub struct ModelTracerSummary {
    /// Number of forward passes traced
    pub total_forwards: usize,
    /// How many of those forward passes had anomalies
    pub anomalies_detected: usize,
    /// Number of attention traces collected
    pub attention_traces: usize,
    /// Number of logit evolution traces
    pub logit_traces: usize,
    /// Number of KV cache steps traced
    pub kv_steps: usize,
    /// Number of KV cache evictions
    pub total_evictions: usize,
    /// Average KV cache hit rate (rendered as a percentage by `Display`,
    /// so presumably a fraction - confirm against the KV cache trace)
    pub avg_hit_rate: f32,
    /// Number of quantization warnings
    pub quant_warnings: usize,
    /// Number of quantization criticals
    pub quant_criticals: usize,
}

impl fmt::Display for ModelTracerSummary {
    /// Writes a heading followed by one indented line per counter.
    ///
    /// The hit rate is multiplied by 100 and shown with two decimals; the
    /// final line carries no trailing newline.
    fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            total_forwards,
            anomalies_detected,
            attention_traces,
            logit_traces,
            kv_steps,
            total_evictions,
            avg_hit_rate,
            quant_warnings,
            quant_criticals,
        } = self;
        let hit_rate_percent = *avg_hit_rate * 100.0;

        writeln!(out, "ModelTracer Summary:")?;
        writeln!(out, "  Forward passes: {}", total_forwards)?;
        writeln!(out, "  Anomalies: {}", anomalies_detected)?;
        writeln!(out, "  Attention traces: {}", attention_traces)?;
        writeln!(out, "  Logit traces: {}", logit_traces)?;
        writeln!(out, "  KV cache steps: {}", kv_steps)?;
        writeln!(out, "  KV evictions: {}", total_evictions)?;
        writeln!(out, "  Avg hit rate: {:.2}%", hit_rate_percent)?;
        writeln!(out, "  Quant warnings: {}", quant_warnings)?;
        write!(out, "  Quant criticals: {}", quant_criticals)
    }
}
286
#[cfg(test)]
mod tests {
    use super::super::activation::TensorStats;
    use super::*;

    /// The lightweight preset enables only activation and KV cache tracing.
    #[test]
    fn test_model_tracer_lightweight() {
        let config = ModelTracerConfig::lightweight();
        assert!(config.trace_activations);
        assert!(config.trace_kv_cache);
        assert!(!config.trace_attention);
        assert!(!config.trace_logits);
        assert!(!config.trace_quant_error);
    }

    /// The full preset enables every trace flag.
    #[test]
    fn test_model_tracer_full() {
        let config = ModelTracerConfig::full();
        assert!(config.trace_activations);
        assert!(config.trace_attention);
        assert!(config.trace_logits);
        assert!(config.trace_quant_error);
        assert!(config.trace_kv_cache);
    }

    /// `is_enabled` is false for the default config and true once any flag is set.
    #[test]
    fn test_model_tracer_config_is_enabled() {
        assert!(!ModelTracerConfig::default().is_enabled());
        assert!(ModelTracerConfig::lightweight().is_enabled());
        assert!(ModelTracerConfig::full().is_enabled());

        let logits_only = ModelTracerConfig { trace_logits: true, ..Default::default() };
        assert!(logits_only.is_enabled());
    }

    /// A clean forward pass is counted and reports no anomaly.
    #[test]
    fn test_model_tracer_forward_pass() {
        let config = ModelTracerConfig::lightweight();
        let mut tracer = ModelTracer::new(config);

        tracer.begin_forward(0);
        tracer.record_layer_activation(LayerActivationTrace::new(0));
        tracer.record_layer_activation(LayerActivationTrace::new(1));
        let anomaly = tracer.end_forward();

        assert!(anomaly.is_none());
        let summary = tracer.summary();
        assert_eq!(summary.total_forwards, 1);
        assert_eq!(summary.anomalies_detected, 0);
    }

    /// A layer whose input stats contain NaN is surfaced by `end_forward`.
    #[test]
    fn test_model_tracer_detects_anomaly() {
        let config = ModelTracerConfig::lightweight();
        let mut tracer = ModelTracer::new(config);

        tracer.begin_forward(0);
        let mut bad_layer = LayerActivationTrace::new(0);
        bad_layer.input_stats = TensorStats::from_slice(&[f32::NAN]);
        tracer.record_layer_activation(bad_layer);
        let anomaly = tracer.end_forward();

        assert!(anomaly.is_some());
        assert!(anomaly.unwrap().contains("NaN"));
        assert_eq!(tracer.summary().anomalies_detected, 1);
    }

    /// With activation tracing off, layer records are ignored and no pass is counted.
    #[test]
    fn test_model_tracer_disabled_records_nothing() {
        let mut tracer = ModelTracer::new(ModelTracerConfig::default());

        tracer.begin_forward(0);
        tracer.record_layer_activation(LayerActivationTrace::new(0));
        let anomaly = tracer.end_forward();

        assert!(anomaly.is_none());
        let summary = tracer.summary();
        assert_eq!(summary.total_forwards, 0);
        assert_eq!(summary.logit_traces, 0);
    }

    /// The JSON export carries the summary counters.
    #[test]
    fn test_model_tracer_json_output() {
        let config = ModelTracerConfig::lightweight();
        let mut tracer = ModelTracer::new(config);

        tracer.begin_forward(0);
        tracer.end_forward();

        let json = tracer.summary_to_json();
        assert!(json.contains("\"total_forwards\":1"));
        assert!(json.contains("\"anomalies_detected\":0"));
    }

    /// `clear` drops the accumulated traces so the summary counters reset.
    #[test]
    fn test_model_tracer_clear() {
        let mut tracer = ModelTracer::new(ModelTracerConfig::lightweight());

        tracer.begin_forward(0);
        tracer.end_forward();
        assert_eq!(tracer.summary().total_forwards, 1);

        tracer.clear();
        let summary = tracer.summary();
        assert_eq!(summary.total_forwards, 0);
        assert_eq!(summary.anomalies_detected, 0);
        assert_eq!(summary.attention_traces, 0);
        assert_eq!(summary.logit_traces, 0);
    }

    /// FALSIFICATION TEST: ModelTracer layer count must match recorded layers
    #[test]
    fn test_falsify_tracer_layer_count() {
        let config = ModelTracerConfig::lightweight();
        let mut tracer = ModelTracer::new(config);

        tracer.begin_forward(0);
        let num_layers = 32;
        for i in 0..num_layers {
            tracer.record_layer_activation(LayerActivationTrace::new(i));
        }
        tracer.end_forward();

        // The activation trace should have exactly num_layers entries
        assert_eq!(
            tracer.activation_traces[0].layers.len(),
            num_layers,
            "FALSIFICATION FAILED: recorded {} layers but expected {}",
            tracer.activation_traces[0].layers.len(),
            num_layers
        );
    }
}