1use std::fmt;
6
7use super::activation::{LayerActivationTrace, ModelActivationTrace};
8use super::attention::{AttentionTraceConfig, AttentionWeightTrace};
9use super::kv_cache::{KvCacheSessionTrace, KvCacheStateTrace};
10use super::logit::LogitEvolutionTrace;
11use super::quant_error::{ModelQuantizationError, QuantizationErrorTrace};
12
13#[derive(Debug, Clone, Default)]
15pub struct ModelTracerConfig {
16 pub trace_activations: bool,
18 pub trace_attention: bool,
20 pub attention_config: AttentionTraceConfig,
22 pub trace_logits: bool,
24 pub tracked_tokens: Option<Vec<u32>>,
26 pub trace_quant_error: bool,
28 pub trace_kv_cache: bool,
30}
31
32impl ModelTracerConfig {
33 pub fn full() -> Self {
35 Self {
36 trace_activations: true,
37 trace_attention: true,
38 attention_config: AttentionTraceConfig::default(),
39 trace_logits: true,
40 tracked_tokens: None,
41 trace_quant_error: true,
42 trace_kv_cache: true,
43 }
44 }
45
46 pub fn lightweight() -> Self {
48 Self { trace_activations: true, trace_kv_cache: true, ..Default::default() }
49 }
50
51 pub fn is_enabled(&self) -> bool {
53 self.trace_activations
54 || self.trace_attention
55 || self.trace_logits
56 || self.trace_quant_error
57 || self.trace_kv_cache
58 }
59}
60
61pub struct ModelTracer {
75 config: ModelTracerConfig,
76 current_position: usize,
78 activation_traces: Vec<ModelActivationTrace>,
80 current_activation_trace: Option<ModelActivationTrace>,
82 attention_traces: Vec<AttentionWeightTrace>,
84 logit_traces: Vec<LogitEvolutionTrace>,
86 current_logit_trace: Option<LogitEvolutionTrace>,
88 quant_traces: Vec<ModelQuantizationError>,
90 kv_trace: KvCacheSessionTrace,
92}
93
94impl ModelTracer {
95 pub fn new(config: ModelTracerConfig) -> Self {
97 Self {
98 config,
99 current_position: 0,
100 activation_traces: Vec::new(),
101 current_activation_trace: None,
102 attention_traces: Vec::new(),
103 logit_traces: Vec::new(),
104 current_logit_trace: None,
105 quant_traces: Vec::new(),
106 kv_trace: KvCacheSessionTrace::default(),
107 }
108 }
109
110 pub fn config(&self) -> &ModelTracerConfig {
112 &self.config
113 }
114
115 pub fn current_logit_trace(&self) -> Option<&LogitEvolutionTrace> {
117 self.current_logit_trace.as_ref()
118 }
119
120 pub fn set_current_logit_trace(&mut self, trace: Option<LogitEvolutionTrace>) {
122 self.current_logit_trace = trace;
123 }
124
125 pub fn begin_forward(&mut self, position: usize) {
127 self.current_position = position;
128
129 if self.config.trace_activations {
130 self.current_activation_trace = Some(ModelActivationTrace::default());
131 }
132
133 if self.config.trace_logits {
134 self.current_logit_trace = Some(LogitEvolutionTrace::new(position, 1.0, 1.0));
135 }
136 }
137
138 pub fn record_layer_activation(&mut self, trace: LayerActivationTrace) {
140 if let Some(ref mut activation) = self.current_activation_trace {
141 activation.add_layer(trace);
142 }
143 }
144
145 pub fn record_attention(&mut self, trace: AttentionWeightTrace) {
147 if self.config.trace_attention {
148 self.attention_traces.push(trace);
149 }
150 }
151
152 pub fn record_logits(&mut self, layer_idx: usize, logits: &[f32]) {
154 if let Some(ref mut logit_trace) = self.current_logit_trace {
155 for token_evo in &mut logit_trace.tracked_tokens {
156 let logit = logits.get(token_evo.token_id as usize).copied().unwrap_or(0.0);
157 let rank = LogitEvolutionTrace::compute_rank(logits, token_evo.token_id);
158 token_evo.record_layer(logit, rank);
159 }
160 logit_trace.decisive_layer = layer_idx;
162 }
163 }
164
165 pub fn record_kv_state(&mut self, trace: KvCacheStateTrace) {
167 if self.config.trace_kv_cache {
168 self.kv_trace.add_step(trace);
169 }
170 }
171
172 pub fn record_quant_error(&mut self, trace: QuantizationErrorTrace) {
174 if self.config.trace_quant_error {
175 if self.quant_traces.is_empty() {
176 self.quant_traces.push(ModelQuantizationError::default());
177 }
178 if let Some(model_error) = self.quant_traces.last_mut() {
179 model_error.add_error(trace);
180 }
181 }
182 }
183
184 pub fn end_forward(&mut self) -> Option<String> {
188 let mut anomaly = None;
189
190 if let Some(mut trace) = self.current_activation_trace.take() {
192 trace.finalize();
193 if trace.has_anomaly {
194 anomaly = trace.anomaly_desc.clone();
195 }
196 self.activation_traces.push(trace);
197 }
198
199 if let Some(trace) = self.current_logit_trace.take() {
201 self.logit_traces.push(trace);
202 }
203
204 anomaly
205 }
206
207 pub fn summary(&self) -> ModelTracerSummary {
209 ModelTracerSummary {
210 total_forwards: self.activation_traces.len(),
211 anomalies_detected: self.activation_traces.iter().filter(|t| t.has_anomaly).count(),
212 attention_traces: self.attention_traces.len(),
213 logit_traces: self.logit_traces.len(),
214 kv_steps: self.kv_trace.steps.len(),
215 total_evictions: self.kv_trace.total_evictions,
216 avg_hit_rate: self.kv_trace.avg_hit_rate,
217 quant_warnings: self.quant_traces.iter().map(|t| t.warning_count()).sum(),
218 quant_criticals: self.quant_traces.iter().map(|t| t.critical_count()).sum(),
219 }
220 }
221
222 pub fn summary_to_json(&self) -> String {
224 let summary = self.summary();
225 format!(
226 r#"{{"total_forwards":{},"anomalies_detected":{},"attention_traces":{},"logit_traces":{},"kv_steps":{},"total_evictions":{},"avg_hit_rate":{:.4},"quant_warnings":{},"quant_criticals":{}}}"#,
227 summary.total_forwards,
228 summary.anomalies_detected,
229 summary.attention_traces,
230 summary.logit_traces,
231 summary.kv_steps,
232 summary.total_evictions,
233 summary.avg_hit_rate,
234 summary.quant_warnings,
235 summary.quant_criticals
236 )
237 }
238
239 pub fn clear(&mut self) {
241 self.activation_traces.clear();
242 self.attention_traces.clear();
243 self.logit_traces.clear();
244 self.quant_traces.clear();
245 self.kv_trace = KvCacheSessionTrace::default();
246 }
247}
248
/// Aggregate counters describing what a `ModelTracer` has recorded.
#[derive(Clone, Debug, Default)]
pub struct ModelTracerSummary {
    /// Number of stored activation traces (completed traced forward passes).
    pub total_forwards: usize,
    /// How many of those activation traces were flagged as anomalous.
    pub anomalies_detected: usize,
    /// Number of stored attention traces.
    pub attention_traces: usize,
    /// Number of stored logit traces.
    pub logit_traces: usize,
    /// Number of recorded KV-cache steps.
    pub kv_steps: usize,
    /// Total KV-cache evictions reported by the session trace.
    pub total_evictions: usize,
    /// Average KV-cache hit rate reported by the session trace (rendered as a
    /// percentage by `Display`, i.e. multiplied by 100).
    pub avg_hit_rate: f32,
    /// Sum of quantization warnings across aggregates.
    pub quant_warnings: usize,
    /// Sum of critical quantization errors across aggregates.
    pub quant_criticals: usize,
}
271
272impl fmt::Display for ModelTracerSummary {
273 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
274 writeln!(f, "ModelTracer Summary:")?;
275 writeln!(f, " Forward passes: {}", self.total_forwards)?;
276 writeln!(f, " Anomalies: {}", self.anomalies_detected)?;
277 writeln!(f, " Attention traces: {}", self.attention_traces)?;
278 writeln!(f, " Logit traces: {}", self.logit_traces)?;
279 writeln!(f, " KV cache steps: {}", self.kv_steps)?;
280 writeln!(f, " KV evictions: {}", self.total_evictions)?;
281 writeln!(f, " Avg hit rate: {:.2}%", self.avg_hit_rate * 100.0)?;
282 writeln!(f, " Quant warnings: {}", self.quant_warnings)?;
283 write!(f, " Quant criticals: {}", self.quant_criticals)
284 }
285}
286
#[cfg(test)]
mod tests {
    use super::super::activation::TensorStats;
    use super::*;

    #[test]
    fn test_model_tracer_lightweight() {
        let config = ModelTracerConfig::lightweight();
        assert!(config.trace_activations);
        assert!(config.trace_kv_cache);
        assert!(!config.trace_attention);
        assert!(!config.trace_logits);
        assert!(!config.trace_quant_error);
        assert!(config.tracked_tokens.is_none());
        assert!(config.is_enabled());
    }

    #[test]
    fn test_model_tracer_full() {
        let config = ModelTracerConfig::full();
        assert!(config.trace_activations);
        assert!(config.trace_attention);
        assert!(config.trace_logits);
        assert!(config.trace_quant_error);
        assert!(config.trace_kv_cache);
        assert!(config.tracked_tokens.is_none());
        assert!(config.is_enabled());
    }

    /// The derived default switches every trace kind off.
    #[test]
    fn test_default_config_is_disabled() {
        assert!(!ModelTracerConfig::default().is_enabled());
    }

    /// With everything disabled, a forward pass leaves no per-pass traces.
    #[test]
    fn test_disabled_tracer_records_nothing() {
        let mut tracer = ModelTracer::new(ModelTracerConfig::default());
        tracer.begin_forward(0);
        tracer.record_layer_activation(LayerActivationTrace::new(0));
        assert!(tracer.end_forward().is_none());

        let summary = tracer.summary();
        assert_eq!(summary.total_forwards, 0);
        assert_eq!(summary.logit_traces, 0);
        assert_eq!(summary.attention_traces, 0);
    }

    /// Layer activations recorded outside begin/end are dropped.
    #[test]
    fn test_activation_outside_forward_is_ignored() {
        let mut tracer = ModelTracer::new(ModelTracerConfig::lightweight());
        tracer.record_layer_activation(LayerActivationTrace::new(0));
        assert!(tracer.end_forward().is_none());
        assert_eq!(tracer.summary().total_forwards, 0);
    }

    #[test]
    fn test_model_tracer_forward_pass() {
        let config = ModelTracerConfig::lightweight();
        let mut tracer = ModelTracer::new(config);

        tracer.begin_forward(0);
        tracer.record_layer_activation(LayerActivationTrace::new(0));
        tracer.record_layer_activation(LayerActivationTrace::new(1));
        let anomaly = tracer.end_forward();

        assert!(anomaly.is_none());
        let summary = tracer.summary();
        assert_eq!(summary.total_forwards, 1);
        assert_eq!(summary.anomalies_detected, 0);
        assert_eq!(summary.logit_traces, 0);
    }

    /// With logit tracing on, a pass opens a logit trace and end_forward stores it.
    #[test]
    fn test_full_tracer_collects_logit_trace() {
        let mut tracer = ModelTracer::new(ModelTracerConfig::full());
        tracer.begin_forward(3);
        assert!(tracer.current_logit_trace().is_some());
        tracer.end_forward();
        assert!(tracer.current_logit_trace().is_none());

        let summary = tracer.summary();
        assert_eq!(summary.total_forwards, 1);
        assert_eq!(summary.logit_traces, 1);
    }

    #[test]
    fn test_model_tracer_detects_anomaly() {
        let config = ModelTracerConfig::lightweight();
        let mut tracer = ModelTracer::new(config);

        tracer.begin_forward(0);
        let mut bad_layer = LayerActivationTrace::new(0);
        bad_layer.input_stats = TensorStats::from_slice(&[f32::NAN]);
        tracer.record_layer_activation(bad_layer);
        let anomaly = tracer.end_forward();

        assert!(anomaly.is_some());
        assert!(anomaly.unwrap().contains("NaN"));
        assert_eq!(tracer.summary().anomalies_detected, 1);
    }

    #[test]
    fn test_model_tracer_json_output() {
        let config = ModelTracerConfig::lightweight();
        let mut tracer = ModelTracer::new(config);

        tracer.begin_forward(0);
        tracer.end_forward();

        let json = tracer.summary_to_json();
        assert!(json.contains("\"total_forwards\":1"));
        assert!(json.contains("\"anomalies_detected\":0"));
    }

    /// clear() drops every collected per-pass trace.
    #[test]
    fn test_clear_resets_collected_traces() {
        let mut tracer = ModelTracer::new(ModelTracerConfig::full());
        tracer.begin_forward(0);
        tracer.end_forward();
        tracer.clear();

        let summary = tracer.summary();
        assert_eq!(summary.total_forwards, 0);
        assert_eq!(summary.logit_traces, 0);
        assert_eq!(summary.attention_traces, 0);
    }

    #[test]
    fn test_falsify_tracer_layer_count() {
        let config = ModelTracerConfig::lightweight();
        let mut tracer = ModelTracer::new(config);

        tracer.begin_forward(0);
        let num_layers = 32;
        for i in 0..num_layers {
            tracer.record_layer_activation(LayerActivationTrace::new(i));
        }
        tracer.end_forward();

        assert_eq!(
            tracer.activation_traces[0].layers.len(),
            num_layers,
            "FALSIFICATION FAILED: recorded {} layers but expected {}",
            tracer.activation_traces[0].layers.len(),
            num_layers
        );
    }
}