//! trueno/brick/tracing/activation.rs — layer activation tracing (E.11.2, MLT-01).
// ============================================================================
// E.11.2: LayerActivationTrace (MLT-01)
// ============================================================================

/// Statistics for a tensor without storing the tensor itself.
///
/// Computes min, max, mean, std, L2 norm, NaN/Inf counts in a single pass.
/// Used for anomaly detection (explosion, vanishing gradients, NaN propagation).
///
/// # Example
/// ```rust,ignore
/// let stats = TensorStats::from_slice(&tensor_data);
/// if stats.has_anomaly() {
///     log::warn!("Anomaly detected: {}", stats.anomaly_description());
/// }
/// ```
#[derive(Debug, Clone, Default, PartialEq)]
pub struct TensorStats {
    /// Total number of elements in the source slice, INCLUDING NaN/Inf entries
    /// (`from_slice` stores `data.len()` here; the other statistics below are
    /// computed over finite values only)
    pub count: usize,
    /// Minimum value (ignoring NaN/Inf); 0.0 when no finite value was seen
    pub min: f32,
    /// Maximum value (ignoring NaN/Inf); 0.0 when no finite value was seen
    pub max: f32,
    /// Mean value (ignoring NaN/Inf)
    pub mean: f32,
    /// Sample standard deviation (ignoring NaN/Inf); 0.0 with fewer than 2 finite values
    pub std: f32,
    /// Count of NaN values
    pub nan_count: usize,
    /// Count of +/-Inf values
    pub inf_count: usize,
    /// L2 norm (sqrt of sum of squares of the finite values)
    pub l2_norm: f32,
}
36
37impl TensorStats {
38    /// Compute statistics from a slice in a single pass.
39    ///
40    /// Uses Welford's algorithm for numerically stable mean/variance.
41    pub fn from_slice(data: &[f32]) -> Self {
42        if data.is_empty() {
43            return Self::default();
44        }
45
46        let mut count = 0usize;
47        let mut nan_count = 0usize;
48        let mut inf_count = 0usize;
49        let mut min = f32::MAX;
50        let mut max = f32::MIN;
51        let mut sum_sq = 0.0f64;
52
53        // Welford's algorithm for online mean/variance
54        let mut mean = 0.0f64;
55        let mut m2 = 0.0f64;
56
57        for &val in data {
58            if val.is_nan() {
59                nan_count += 1;
60                continue;
61            }
62            if val.is_infinite() {
63                inf_count += 1;
64                continue;
65            }
66
67            count += 1;
68            min = min.min(val);
69            max = max.max(val);
70            sum_sq += (val as f64) * (val as f64);
71
72            // Welford's update
73            let delta = val as f64 - mean;
74            mean += delta / count as f64;
75            let delta2 = val as f64 - mean;
76            m2 += delta * delta2;
77        }
78
79        let std = if count > 1 { (m2 / (count - 1) as f64).sqrt() as f32 } else { 0.0 };
80
81        let l2_norm = sum_sq.sqrt() as f32;
82
83        Self {
84            count: data.len(),
85            min: if count > 0 { min } else { 0.0 },
86            max: if count > 0 { max } else { 0.0 },
87            mean: mean as f32,
88            std,
89            nan_count,
90            inf_count,
91            l2_norm,
92        }
93    }
94
95    /// Check if this tensor has any anomalies.
96    ///
97    /// Anomaly detection rules (from E.11.2):
98    /// - NaN detected: `nan_count > 0`
99    /// - Explosion: `max.abs() > 1e6` or `std > 1e4`
100    /// - Vanishing: `std < 1e-6` (should check after first few layers)
101    pub fn has_anomaly(&self) -> bool {
102        self.nan_count > 0
103            || self.inf_count > 0
104            || self.max.abs() > 1e6
105            || self.min.abs() > 1e6
106            || self.std > 1e4
107    }
108
109    /// Check if values are vanishing (for layers past warmup).
110    pub fn is_vanishing(&self) -> bool {
111        self.std < 1e-6 && self.count > 0
112    }
113
114    /// Get a description of any anomaly detected.
115    pub fn anomaly_description(&self) -> Option<String> {
116        if self.nan_count > 0 {
117            return Some(format!("NaN detected: {} values", self.nan_count));
118        }
119        if self.inf_count > 0 {
120            return Some(format!("Inf detected: {} values", self.inf_count));
121        }
122        if self.max.abs() > 1e6 || self.min.abs() > 1e6 {
123            return Some(format!("Explosion: min={:.2e}, max={:.2e}", self.min, self.max));
124        }
125        if self.std > 1e4 {
126            return Some(format!("High variance: std={:.2e}", self.std));
127        }
128        None
129    }
130}
131
/// Activation trace for a single transformer layer.
///
/// Records tensor statistics at each stage of a transformer layer:
/// input -> norm -> attention -> residual -> ffn -> output
#[derive(Debug, Clone, Default)]
pub struct LayerActivationTrace {
    /// Layer index (0-indexed)
    pub layer_idx: usize,
    /// Input hidden state statistics
    pub input_stats: TensorStats,
    /// After RMSNorm/LayerNorm statistics
    pub post_norm_stats: TensorStats,
    /// After attention statistics
    pub post_attn_stats: TensorStats,
    /// After FFN statistics
    pub post_ffn_stats: TensorStats,
    /// Output hidden state statistics
    pub output_stats: TensorStats,
    /// Residual connection magnitude ratio (output_norm / (output_norm + attn_norm)).
    /// Values near 1.0 mean the skip connection dominates; `has_anomaly` flags
    /// ratios above 0.99 as a skip-connection bypass.
    pub residual_ratio: f32,
}
153
154impl LayerActivationTrace {
155    /// Create a new layer activation trace.
156    pub fn new(layer_idx: usize) -> Self {
157        Self { layer_idx, ..Default::default() }
158    }
159
160    /// Check if this layer has any anomalies.
161    pub fn has_anomaly(&self) -> bool {
162        self.input_stats.has_anomaly()
163            || self.post_norm_stats.has_anomaly()
164            || self.post_attn_stats.has_anomaly()
165            || self.post_ffn_stats.has_anomaly()
166            || self.output_stats.has_anomaly()
167            || self.residual_ratio > 0.99 // Skip connection bypass
168    }
169
170    /// Get anomaly description for this layer.
171    pub fn anomaly_description(&self) -> Option<String> {
172        if let Some(desc) = self.input_stats.anomaly_description() {
173            return Some(format!("Layer {} input: {}", self.layer_idx, desc));
174        }
175        if let Some(desc) = self.post_norm_stats.anomaly_description() {
176            return Some(format!("Layer {} post_norm: {}", self.layer_idx, desc));
177        }
178        if let Some(desc) = self.post_attn_stats.anomaly_description() {
179            return Some(format!("Layer {} post_attn: {}", self.layer_idx, desc));
180        }
181        if let Some(desc) = self.post_ffn_stats.anomaly_description() {
182            return Some(format!("Layer {} post_ffn: {}", self.layer_idx, desc));
183        }
184        if let Some(desc) = self.output_stats.anomaly_description() {
185            return Some(format!("Layer {} output: {}", self.layer_idx, desc));
186        }
187        if self.residual_ratio > 0.99 {
188            return Some(format!(
189                "Layer {} residual dominance: ratio={:.4}",
190                self.layer_idx, self.residual_ratio
191            ));
192        }
193        None
194    }
195}
196
/// Full model activation trace for one forward pass.
///
/// Populated incrementally with `add_layer` during the forward pass;
/// call `finalize` afterwards to also screen embedding/logits statistics.
#[derive(Debug, Clone, Default)]
pub struct ModelActivationTrace {
    /// Per-layer activation traces, in the order they were added
    pub layers: Vec<LayerActivationTrace>,
    /// Embedding output statistics (checked during `finalize`)
    pub embedding_stats: TensorStats,
    /// Final logits statistics (checked during `finalize`)
    pub logits_stats: TensorStats,
    /// Whether any anomaly was detected so far
    pub has_anomaly: bool,
    /// Description of first anomaly found (`None` when no anomaly)
    pub anomaly_desc: Option<String>,
}
211
212impl ModelActivationTrace {
213    /// Create a new model activation trace with expected layer count.
214    pub fn with_capacity(num_layers: usize) -> Self {
215        Self { layers: Vec::with_capacity(num_layers), ..Default::default() }
216    }
217
218    /// Add a layer trace.
219    pub fn add_layer(&mut self, trace: LayerActivationTrace) {
220        if !self.has_anomaly {
221            if let Some(desc) = trace.anomaly_description() {
222                self.has_anomaly = true;
223                self.anomaly_desc = Some(desc);
224            }
225        }
226        self.layers.push(trace);
227    }
228
229    /// Finalize the trace and check embedding/logits.
230    pub fn finalize(&mut self) {
231        if !self.has_anomaly {
232            if let Some(desc) = self.embedding_stats.anomaly_description() {
233                self.has_anomaly = true;
234                self.anomaly_desc = Some(format!("Embedding: {}", desc));
235            }
236        }
237        if !self.has_anomaly {
238            if let Some(desc) = self.logits_stats.anomaly_description() {
239                self.has_anomaly = true;
240                self.anomaly_desc = Some(format!("Logits: {}", desc));
241            }
242        }
243    }
244}
245
#[cfg(test)]
mod tests {
    use super::*;

    // ========================================================================
    // TensorStats Tests
    // ========================================================================

    #[test]
    fn test_tensor_stats_empty() {
        let empty = TensorStats::from_slice(&[]);
        assert_eq!(empty.count, 0);
        assert_eq!(empty.nan_count, 0);
        assert!(!empty.has_anomaly());
    }

    #[test]
    fn test_tensor_stats_basic() {
        let stats = TensorStats::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        assert_eq!(stats.count, 5);
        assert_eq!(stats.min, 1.0);
        assert_eq!(stats.max, 5.0);
        assert!((stats.mean - 3.0).abs() < 0.01);
        assert!(!stats.has_anomaly());
    }

    #[test]
    fn test_tensor_stats_nan_detection() {
        let stats = TensorStats::from_slice(&[1.0, f32::NAN, 3.0]);
        assert_eq!(stats.nan_count, 1);
        assert!(stats.has_anomaly());
        assert!(stats.anomaly_description().unwrap().contains("NaN"));
    }

    #[test]
    fn test_tensor_stats_inf_detection() {
        let stats = TensorStats::from_slice(&[1.0, f32::INFINITY, 3.0]);
        assert_eq!(stats.inf_count, 1);
        assert!(stats.has_anomaly());
    }

    #[test]
    fn test_tensor_stats_explosion() {
        let stats = TensorStats::from_slice(&[1e7, 2e7]);
        assert!(stats.has_anomaly());
        assert!(stats.anomaly_description().unwrap().contains("Explosion"));
    }

    #[test]
    fn test_tensor_stats_vanishing() {
        // Constant near-zero values -> std below the 1e-6 vanishing threshold.
        let stats = TensorStats::from_slice(&[1e-8; 3]);
        assert!(stats.is_vanishing());
    }

    // ========================================================================
    // LayerActivationTrace Tests
    // ========================================================================

    #[test]
    fn test_layer_activation_trace_new() {
        let trace = LayerActivationTrace::new(5);
        assert_eq!(trace.layer_idx, 5);
        assert!(!trace.has_anomaly());
    }

    #[test]
    fn test_layer_activation_trace_anomaly() {
        let trace = LayerActivationTrace {
            input_stats: TensorStats::from_slice(&[f32::NAN]),
            ..LayerActivationTrace::new(0)
        };
        assert!(trace.has_anomaly());
        assert!(trace.anomaly_description().is_some());
    }

    #[test]
    fn test_layer_activation_trace_residual_dominance() {
        let trace =
            LayerActivationTrace { residual_ratio: 0.999, ..LayerActivationTrace::new(0) };
        assert!(trace.has_anomaly());
        assert!(trace.anomaly_description().unwrap().contains("residual"));
    }

    // ========================================================================
    // ModelActivationTrace Tests
    // ========================================================================

    #[test]
    fn test_model_activation_trace_add_layer() {
        let mut model = ModelActivationTrace::with_capacity(32);
        for idx in 0..2 {
            model.add_layer(LayerActivationTrace::new(idx));
        }
        assert_eq!(model.layers.len(), 2);
        assert!(!model.has_anomaly);
    }

    #[test]
    fn test_model_activation_trace_anomaly_propagation() {
        let mut model = ModelActivationTrace::default();
        let bad_layer = LayerActivationTrace {
            input_stats: TensorStats::from_slice(&[f32::NAN]),
            ..LayerActivationTrace::new(0)
        };
        model.add_layer(bad_layer);
        assert!(model.has_anomaly);
    }

    /// FALSIFICATION TEST: TensorStats Welford algorithm numerical stability
    ///
    /// Welford's algorithm must produce correct mean/std even for large values.
    #[test]
    fn test_falsify_tensor_stats_welford_stability() {
        // Test with large offset - naive algorithm would lose precision
        let large_offset = 1e9;
        let data: Vec<f32> = (0..1000).map(|i| large_offset + i as f32).collect();
        let stats = TensorStats::from_slice(&data);

        // Mean should be large_offset + 499.5
        let expected_mean = large_offset + 499.5;
        assert!(
            (stats.mean - expected_mean as f32).abs() < 1.0,
            "FALSIFICATION FAILED: Welford mean {} != expected {} (relative error too high)",
            stats.mean,
            expected_mean
        );

        // Std should be ~288.7 (uniform distribution 0-999)
        assert!(
            stats.std > 280.0 && stats.std < 300.0,
            "FALSIFICATION FAILED: Welford std {} outside expected range [280, 300]",
            stats.std
        );
    }
}