// kizzasi_model/interpretability.rs

1//! Model Interpretability Tools
2//!
3//! This module provides tools for understanding what RWKV and other SSM models
4//! have learned: activation statistics, gating pattern analysis, state trajectory
5//! inspection, sensitivity analysis, and compression potential estimation.
6//!
7//! # Design
8//!
9//! All analysis is performed in pure Rust using `scirs2-core` arrays.
10//! No external Python or visualization dependencies are required.
11
12use crate::error::{ModelError, ModelResult};
13use scirs2_core::ndarray::{Array1, Array2};
14
15// ---------------------------------------------------------------------------
16// ActivationStats
17// ---------------------------------------------------------------------------
18
/// Running statistics over a sequence of activations.
///
/// Uses Welford's online algorithm for numerically stable mean/variance.
/// Build one in a single pass with [`ActivationStats::from_sequence`], or
/// accumulate step by step with [`ActivationStats::update`].
#[derive(Debug, Clone)]
pub struct ActivationStats {
    /// Per-dimension mean
    pub mean: Array1<f32>,
    /// Per-dimension variance (population variance, i.e. divided by n)
    pub variance: Array1<f32>,
    /// Per-dimension maximum encountered value
    pub max: Array1<f32>,
    /// Per-dimension minimum encountered value
    pub min: Array1<f32>,
    /// Fraction of values with absolute value below `1e-6` (hard-coded in `update`)
    pub sparsity: f32,
    /// Mean per-step L2 norm (sum of each step's vector norm, divided by steps)
    pub l2_norm: f32,
    /// Number of activation vectors accumulated
    pub num_steps: usize,

    // Welford running state (not serialized / public)
    // Running sum of squared deviations from the mean (Welford's M2)
    welford_m2: Array1<f32>,
    // Count of elements with |v| < 1e-6 seen so far (sparsity numerator)
    near_zero_count: usize,
    // Total number of elements accumulated (sparsity denominator)
    total_elements: usize,
    // Running sum of per-step L2 norms (feeds `l2_norm`)
    l2_sum: f32,
}
45
46impl ActivationStats {
47    /// Compute statistics from a fixed batch of activations.
48    ///
49    /// Returns an error if the slice is empty or dimensions are inconsistent.
50    pub fn from_sequence(activations: &[Array1<f32>]) -> ModelResult<Self> {
51        if activations.is_empty() {
52            return Err(ModelError::invalid_config(
53                "ActivationStats::from_sequence: empty activation sequence",
54            ));
55        }
56        let dim = activations[0].len();
57        for (i, a) in activations.iter().enumerate() {
58            if a.len() != dim {
59                return Err(ModelError::dimension_mismatch(
60                    format!("activation[{i}]"),
61                    dim,
62                    a.len(),
63                ));
64            }
65        }
66
67        let mut stats = Self::zero(dim);
68        for a in activations {
69            stats.update(a);
70        }
71        Ok(stats)
72    }
73
74    /// Create a zeroed stats object for the given dimension.
75    fn zero(dim: usize) -> Self {
76        Self {
77            mean: Array1::zeros(dim),
78            variance: Array1::zeros(dim),
79            max: Array1::from_elem(dim, f32::NEG_INFINITY),
80            min: Array1::from_elem(dim, f32::INFINITY),
81            sparsity: 0.0,
82            l2_norm: 0.0,
83            num_steps: 0,
84            welford_m2: Array1::zeros(dim),
85            near_zero_count: 0,
86            total_elements: 0,
87            l2_sum: 0.0,
88        }
89    }
90
91    /// Incrementally incorporate a new activation vector (Welford's algorithm).
92    pub fn update(&mut self, activation: &Array1<f32>) {
93        let eps = 1e-6_f32;
94        self.num_steps += 1;
95        let n = self.num_steps as f32;
96
97        let mut sq_sum = 0.0_f32;
98        let mut nz = 0usize;
99
100        for (i, &v) in activation.iter().enumerate() {
101            if i >= self.mean.len() {
102                break;
103            }
104            // Welford update
105            let delta = v - self.mean[i];
106            self.mean[i] += delta / n;
107            let delta2 = v - self.mean[i];
108            self.welford_m2[i] += delta * delta2;
109            self.variance[i] = if self.num_steps > 1 {
110                self.welford_m2[i] / n
111            } else {
112                0.0
113            };
114
115            // Min / max
116            if v > self.max[i] {
117                self.max[i] = v;
118            }
119            if v < self.min[i] {
120                self.min[i] = v;
121            }
122
123            // Near-zero count
124            if v.abs() < eps {
125                nz += 1;
126            }
127
128            sq_sum += v * v;
129        }
130
131        self.near_zero_count += nz;
132        self.total_elements += activation.len();
133        self.l2_sum += sq_sum.sqrt();
134        self.l2_norm = self.l2_sum / n;
135        self.sparsity = if self.total_elements > 0 {
136            self.near_zero_count as f32 / self.total_elements as f32
137        } else {
138            0.0
139        };
140    }
141
142    /// Reset all statistics back to zero.
143    pub fn reset(&mut self) {
144        let dim = self.mean.len();
145        self.mean.fill(0.0);
146        self.variance.fill(0.0);
147        self.max.fill(f32::NEG_INFINITY);
148        self.min.fill(f32::INFINITY);
149        self.sparsity = 0.0;
150        self.l2_norm = 0.0;
151        self.num_steps = 0;
152        self.welford_m2 = Array1::zeros(dim);
153        self.near_zero_count = 0;
154        self.total_elements = 0;
155        self.l2_sum = 0.0;
156    }
157}
158
159// ---------------------------------------------------------------------------
160// LayerProbe
161// ---------------------------------------------------------------------------
162
/// Ring-buffer probe for capturing intermediate layer outputs.
///
/// When `enabled`, each call to [`LayerProbe::capture`] stores the activation.
/// The buffer is bounded to `max_capture` entries; once full, the oldest
/// entries are overwritten in ring-buffer order.
pub struct LayerProbe {
    // Human-readable name of the probed layer (for error messages / lookup)
    layer_name: String,
    // Captured activations; ring order once the buffer has wrapped
    captured: Vec<Array1<f32>>,
    // Maximum number of activations kept (clamped to at least 1 in `new`)
    max_capture: usize,
    head: usize, // write index into ring buffer (wraps around)
    // Whether the buffer has wrapped at least once (set on first overwrite)
    filled: bool,
    // When false, `capture` calls are silently ignored
    enabled: bool,
}
175
176impl LayerProbe {
177    /// Create a new probe for `layer_name` that holds up to `max_capture` activations.
178    pub fn new(layer_name: &str, max_capture: usize) -> Self {
179        let max_capture = max_capture.max(1);
180        Self {
181            layer_name: layer_name.to_owned(),
182            captured: Vec::with_capacity(max_capture),
183            max_capture,
184            head: 0,
185            filled: false,
186            enabled: true,
187        }
188    }
189
190    /// Capture one activation vector. No-op if disabled.
191    pub fn capture(&mut self, activation: Array1<f32>) {
192        if !self.enabled {
193            return;
194        }
195        if self.captured.len() < self.max_capture {
196            self.captured.push(activation);
197        } else {
198            self.captured[self.head] = activation;
199            self.filled = true;
200        }
201        self.head = (self.head + 1) % self.max_capture;
202    }
203
204    /// Compute statistics over all captured activations.
205    pub fn stats(&self) -> ModelResult<ActivationStats> {
206        if self.captured.is_empty() {
207            return Err(ModelError::invalid_config(format!(
208                "LayerProbe '{}': no activations captured",
209                self.layer_name
210            )));
211        }
212        ActivationStats::from_sequence(&self.captured)
213    }
214
215    /// Read all captured activations (in capture order if not wrapped, otherwise ring order).
216    pub fn activations(&self) -> &[Array1<f32>] {
217        &self.captured
218    }
219
220    /// Whether the ring buffer has wrapped around at least once.
221    pub fn is_full(&self) -> bool {
222        self.filled
223    }
224
225    /// Enable capturing.
226    pub fn enable(&mut self) {
227        self.enabled = true;
228    }
229
230    /// Disable capturing (future captures are silently dropped).
231    pub fn disable(&mut self) {
232        self.enabled = false;
233    }
234
235    /// Clear all captured activations and reset ring-buffer state.
236    pub fn clear(&mut self) {
237        self.captured.clear();
238        self.head = 0;
239        self.filled = false;
240    }
241
242    /// Return the layer name associated with this probe.
243    pub fn layer_name(&self) -> &str {
244        &self.layer_name
245    }
246}
247
248// ---------------------------------------------------------------------------
249// GatingAnalysis
250// ---------------------------------------------------------------------------
251
/// Analysis of gating patterns in gated models (RWKV, Mamba, etc.).
#[derive(Debug, Clone)]
pub struct GatingAnalysis {
    /// All captured gate activation vectors (one per step)
    pub gate_values: Vec<Array1<f32>>,
    /// Average gate value per unit, across all steps
    pub avg_gate: Array1<f32>,
    /// Indices of "dead" gates: average value below `threshold`
    pub dead_gates: Vec<usize>,
    /// Indices of "saturated" gates: average value above `1 - threshold`
    pub saturated_gates: Vec<usize>,
    /// Mean binary (Bernoulli) entropy per gate, in bits: each entry of
    /// `avg_gate` is treated as the probability of that gate being open, and
    /// the per-gate entropies are averaged over all units (range 0..=1)
    pub gate_entropy: f32,
}
266
267impl GatingAnalysis {
268    /// Analyse the given gate activations.
269    ///
270    /// # Parameters
271    ///
272    /// - `gate_values` — one `Array1<f32>` per time-step; all must have the same length
273    /// - `threshold` — gates with avg < threshold are "dead"; avg > 1 - threshold are "saturated"
274    pub fn from_activations(gate_values: Vec<Array1<f32>>, threshold: f32) -> ModelResult<Self> {
275        if gate_values.is_empty() {
276            return Err(ModelError::invalid_config(
277                "GatingAnalysis: no gate values provided",
278            ));
279        }
280        let dim = gate_values[0].len();
281        for (i, g) in gate_values.iter().enumerate() {
282            if g.len() != dim {
283                return Err(ModelError::dimension_mismatch(
284                    format!("gate_values[{i}]"),
285                    dim,
286                    g.len(),
287                ));
288            }
289        }
290
291        // Compute average gate per dimension
292        let n = gate_values.len() as f32;
293        let mut avg_gate = Array1::zeros(dim);
294        for g in &gate_values {
295            for (i, &v) in g.iter().enumerate() {
296                avg_gate[i] += v;
297            }
298        }
299        avg_gate.mapv_inplace(|v: f32| v / n);
300
301        let threshold_clamped = threshold.clamp(0.0, 0.5);
302        let mut dead_gates = Vec::new();
303        let mut saturated_gates = Vec::new();
304        for (i, &v) in avg_gate.iter().enumerate() {
305            if v < threshold_clamped {
306                dead_gates.push(i);
307            } else if v > 1.0 - threshold_clamped {
308                saturated_gates.push(i);
309            }
310        }
311
312        // Shannon entropy: H = -sum p*log2(p), treating each avg as a probability of being open
313        let eps = 1e-9_f32;
314        let mut entropy = 0.0_f32;
315        for &p in avg_gate.iter() {
316            let p: f32 = p.clamp(eps, 1.0 - eps);
317            entropy -= p * p.log2() + (1.0 - p) * (1.0 - p).log2();
318        }
319        let gate_entropy = entropy / dim as f32;
320
321        Ok(Self {
322            gate_values,
323            avg_gate,
324            dead_gates,
325            saturated_gates,
326            gate_entropy,
327        })
328    }
329
330    /// Fraction of gates that are neither dead nor saturated (effective gates).
331    pub fn effective_capacity(&self) -> f32 {
332        let total = self.avg_gate.len();
333        if total == 0 {
334            return 0.0;
335        }
336        let inactive = self.dead_gates.len() + self.saturated_gates.len();
337        let active = total.saturating_sub(inactive);
338        active as f32 / total as f32
339    }
340}
341
342// ---------------------------------------------------------------------------
343// StateTrajectory
344// ---------------------------------------------------------------------------
345
/// Records and analyses the trajectory of hidden states over time.
pub struct StateTrajectory {
    // Recorded state vectors, in push order
    states: Vec<Array1<f32>>,
    // Required dimensionality of every state vector (checked in `push`)
    dim: usize,
}
351
352impl StateTrajectory {
353    /// Create a new trajectory recorder for the given state dimension.
354    pub fn new(dim: usize) -> Self {
355        Self {
356            states: Vec::new(),
357            dim,
358        }
359    }
360
361    /// Append a state vector; returns an error on dimension mismatch.
362    pub fn push(&mut self, state: Array1<f32>) -> ModelResult<()> {
363        if state.len() != self.dim {
364            return Err(ModelError::dimension_mismatch(
365                "StateTrajectory::push",
366                self.dim,
367                state.len(),
368            ));
369        }
370        self.states.push(state);
371        Ok(())
372    }
373
374    /// Number of states recorded.
375    pub fn len(&self) -> usize {
376        self.states.len()
377    }
378
379    /// Whether no states have been recorded.
380    pub fn is_empty(&self) -> bool {
381        self.states.is_empty()
382    }
383
384    /// Compute per-step velocities: `||s_{t+1} - s_t||_2` for `t` in `0..len-1`.
385    ///
386    /// Returns an error if fewer than 2 states have been recorded.
387    pub fn velocities(&self) -> ModelResult<Vec<f32>> {
388        if self.states.len() < 2 {
389            return Err(ModelError::invalid_config(
390                "StateTrajectory::velocities: need at least 2 states",
391            ));
392        }
393        let mut vels = Vec::with_capacity(self.states.len() - 1);
394        for w in self.states.windows(2) {
395            let diff = &w[1] - &w[0];
396            let norm = diff.iter().map(|&v| v * v).sum::<f32>().sqrt();
397            vels.push(norm);
398        }
399        Ok(vels)
400    }
401
402    /// Effective dimensionality via participation ratio.
403    ///
404    /// PR = (Σ λ_i)² / Σ λ_i²  where λ_i are per-dimension variances.
405    /// A high value means the state uses many dimensions; a low value means
406    /// information is concentrated in a few dimensions.
407    pub fn participation_ratio(&self) -> ModelResult<f32> {
408        if self.states.is_empty() {
409            return Err(ModelError::invalid_config(
410                "StateTrajectory::participation_ratio: no states recorded",
411            ));
412        }
413
414        // Compute per-dimension variance
415        let n = self.states.len() as f32;
416        let mut mean: Array1<f32> = Array1::zeros(self.dim);
417        for s in &self.states {
418            for (i, &v) in s.iter().enumerate() {
419                mean[i] += v;
420            }
421        }
422        mean.mapv_inplace(|v: f32| v / n);
423
424        let mut var: Array1<f32> = Array1::zeros(self.dim);
425        for s in &self.states {
426            for (i, &v) in s.iter().enumerate() {
427                let d: f32 = v - mean[i];
428                var[i] += d * d;
429            }
430        }
431        var.mapv_inplace(|v: f32| v / n);
432
433        let sum_var: f32 = var.iter().sum();
434        let sum_var_sq: f32 = var.iter().map(|&v| v * v).sum();
435
436        if sum_var_sq < 1e-20 {
437            // All states identical → effectively 0-dimensional
438            return Ok(0.0);
439        }
440
441        Ok((sum_var * sum_var) / sum_var_sq)
442    }
443
444    /// Return indices of the `k` dimensions with highest variance, sorted descending.
445    pub fn most_variable_dims(&self, k: usize) -> ModelResult<Vec<usize>> {
446        if self.states.is_empty() {
447            return Err(ModelError::invalid_config(
448                "StateTrajectory::most_variable_dims: no states recorded",
449            ));
450        }
451        let k = k.min(self.dim);
452
453        let n = self.states.len() as f32;
454        let mut mean = vec![0.0_f32; self.dim];
455        for s in &self.states {
456            for (i, &v) in s.iter().enumerate() {
457                mean[i] += v;
458            }
459        }
460        for m in &mut mean {
461            *m /= n;
462        }
463
464        let mut var = vec![0.0_f32; self.dim];
465        for s in &self.states {
466            for (i, &v) in s.iter().enumerate() {
467                let d = v - mean[i];
468                var[i] += d * d;
469            }
470        }
471        for v in &mut var {
472            *v /= n;
473        }
474
475        let mut idx: Vec<usize> = (0..self.dim).collect();
476        idx.sort_unstable_by(|&a, &b| {
477            var[b]
478                .partial_cmp(&var[a])
479                .unwrap_or(std::cmp::Ordering::Equal)
480        });
481        idx.truncate(k);
482        Ok(idx)
483    }
484
485    /// Compute state autocorrelation at lag `lag`.
486    ///
487    /// Returns the mean Pearson correlation of `s_t` and `s_{t+lag}` across all
488    /// valid pairs.  A lag-0 autocorrelation is always 1.0.
489    pub fn autocorrelation(&self, lag: usize) -> ModelResult<f32> {
490        if self.states.len() <= lag {
491            return Err(ModelError::invalid_config(format!(
492                "StateTrajectory::autocorrelation: lag {lag} requires at least {} states, have {}",
493                lag + 1,
494                self.states.len()
495            )));
496        }
497
498        let n_pairs = self.states.len() - lag;
499        let mut corr_sum = 0.0_f32;
500
501        for t in 0..n_pairs {
502            let s0 = &self.states[t];
503            let s1 = &self.states[t + lag];
504
505            // Pearson correlation between s0 and s1 (as two length-dim vectors)
506            let n = self.dim as f32;
507            let mean0: f32 = s0.iter().sum::<f32>() / n;
508            let mean1: f32 = s1.iter().sum::<f32>() / n;
509
510            let mut cov = 0.0_f32;
511            let mut std0 = 0.0_f32;
512            let mut std1 = 0.0_f32;
513            for (&a, &b) in s0.iter().zip(s1.iter()) {
514                let da = a - mean0;
515                let db = b - mean1;
516                cov += da * db;
517                std0 += da * da;
518                std1 += db * db;
519            }
520
521            let denom = (std0 * std1).sqrt();
522            if denom < 1e-10 {
523                // Constant vectors → perfectly correlated by convention
524                corr_sum += 1.0;
525            } else {
526                corr_sum += cov / denom;
527            }
528        }
529
530        Ok(corr_sum / n_pairs as f32)
531    }
532
533    /// Flatten all states into an `(num_steps, dim)` matrix.
534    pub fn to_matrix(&self) -> ModelResult<Array2<f32>> {
535        if self.states.is_empty() {
536            return Err(ModelError::invalid_config(
537                "StateTrajectory::to_matrix: no states recorded",
538            ));
539        }
540        let t = self.states.len();
541        let d = self.dim;
542        let mut mat = Array2::zeros((t, d));
543        for (row, state) in self.states.iter().enumerate() {
544            for (col, &v) in state.iter().enumerate() {
545                mat[[row, col]] = v;
546            }
547        }
548        Ok(mat)
549    }
550}
551
552// ---------------------------------------------------------------------------
553// SensitivityAnalyzer
554// ---------------------------------------------------------------------------
555
/// Measures feature importance via finite-difference sensitivity analysis.
pub struct SensitivityAnalyzer {
    // Required length of every input vector (checked in `input_sensitivity`)
    input_dim: usize,
}
560
561impl SensitivityAnalyzer {
562    /// Create an analyzer for inputs of size `input_dim`.
563    pub fn new(input_dim: usize) -> Self {
564        Self { input_dim }
565    }
566
567    /// Estimate per-feature sensitivity using finite differences.
568    ///
569    /// For each feature `i`, computes `||f(x + eps*e_i) - f(x)|| / eps`.
570    ///
571    /// # Arguments
572    ///
573    /// - `input` — base input vector (length must equal `input_dim`)
574    /// - `forward_fn` — model forward pass (called `input_dim + 1` times)
575    /// - `eps` — perturbation magnitude (default 1e-3 is reasonable)
576    pub fn input_sensitivity<F>(
577        &self,
578        input: &Array1<f32>,
579        forward_fn: F,
580        eps: f32,
581    ) -> ModelResult<Array1<f32>>
582    where
583        F: Fn(&Array1<f32>) -> ModelResult<Array1<f32>>,
584    {
585        if input.len() != self.input_dim {
586            return Err(ModelError::dimension_mismatch(
587                "SensitivityAnalyzer::input_sensitivity",
588                self.input_dim,
589                input.len(),
590            ));
591        }
592
593        let base_out = forward_fn(input)?;
594        let base_norm = base_out.iter().map(|&v| v * v).sum::<f32>().sqrt();
595
596        let mut sensitivities = Array1::zeros(self.input_dim);
597        for i in 0..self.input_dim {
598            let mut perturbed = input.clone();
599            perturbed[i] += eps;
600            let pert_out = forward_fn(&perturbed)?;
601
602            // Measure change in output norm
603            let diff_norm = pert_out
604                .iter()
605                .zip(base_out.iter())
606                .map(|(&a, &b)| (a - b) * (a - b))
607                .sum::<f32>()
608                .sqrt();
609
610            sensitivities[i] = if eps.abs() > 1e-15 {
611                diff_norm / eps.abs()
612            } else {
613                base_norm
614            };
615        }
616
617        Ok(sensitivities)
618    }
619
620    /// Rank features by average sensitivity across multiple input vectors.
621    ///
622    /// Returns a sorted list of `(feature_index, avg_sensitivity)` in descending order.
623    pub fn rank_features<F>(
624        &self,
625        inputs: &[Array1<f32>],
626        forward_fn: F,
627        eps: f32,
628    ) -> ModelResult<Vec<(usize, f32)>>
629    where
630        F: Fn(&Array1<f32>) -> ModelResult<Array1<f32>>,
631    {
632        if inputs.is_empty() {
633            return Err(ModelError::invalid_config(
634                "SensitivityAnalyzer::rank_features: no inputs provided",
635            ));
636        }
637
638        let mut total: Array1<f32> = Array1::zeros(self.input_dim);
639        for input in inputs {
640            let sens = self.input_sensitivity(input, &forward_fn, eps)?;
641            for (i, &v) in sens.iter().enumerate() {
642                total[i] += v;
643            }
644        }
645
646        let n = inputs.len() as f32;
647        let mut ranked: Vec<(usize, f32)> =
648            total.iter().enumerate().map(|(i, &v)| (i, v / n)).collect();
649
650        ranked.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
651
652        Ok(ranked)
653    }
654}
655
656// ---------------------------------------------------------------------------
657// CompressionAnalysis
658// ---------------------------------------------------------------------------
659
/// Estimates how compressible a weight matrix is.
#[derive(Debug, Clone)]
pub struct CompressionAnalysis {
    /// Fraction of weight values with |w| < eps
    pub weight_sparsity: f32,
    /// Effective rank via the participation ratio of per-column variances
    /// (higher = less compressible)
    pub effective_rank: f32,
    /// Estimated symmetric INT8 quantization step size (max |w| / 127),
    /// used as a proxy for the quantization error
    pub quantization_error: f32,
    /// Suggested LoRA rank for this layer (heuristic: ceil(effective_rank / 4))
    pub recommended_rank: usize,
    /// Overall compressibility score in [0, 1] (1 = very compressible)
    pub compression_potential: f32,
}
674
675impl CompressionAnalysis {
676    /// Analyse a single weight matrix.
677    ///
678    /// `eps` is the threshold below which a weight is considered "zero".
679    pub fn analyze_weight(weight: &Array2<f32>, eps: f32) -> ModelResult<Self> {
680        let (rows, cols) = (weight.shape()[0], weight.shape()[1]);
681        let total = rows * cols;
682
683        if total == 0 {
684            return Err(ModelError::invalid_config(
685                "CompressionAnalysis: weight matrix is empty",
686            ));
687        }
688
689        // --- Sparsity ---
690        let near_zero = weight.iter().filter(|&&v| v.abs() < eps).count();
691        let weight_sparsity = near_zero as f32 / total as f32;
692
693        // --- Per-column variances as surrogate for singular values ---
694        // (Full SVD would be expensive; variance-based PR is a good approximation)
695        let n = rows as f32;
696        let col_variances: Vec<f32> = (0..cols)
697            .map(|j| {
698                let col = weight.column(j);
699                let mean = col.iter().sum::<f32>() / n;
700                col.iter().map(|&v| (v - mean) * (v - mean)).sum::<f32>() / n
701            })
702            .collect();
703
704        let sum_var: f32 = col_variances.iter().sum();
705        let sum_var_sq: f32 = col_variances.iter().map(|&v| v * v).sum();
706        let effective_rank = if sum_var_sq > 1e-20 {
707            (sum_var * sum_var) / sum_var_sq
708        } else {
709            1.0
710        };
711
712        // --- INT8 quantization error estimate ---
713        // max_abs * scale_error, where scale_error ≈ 1/(2^8 - 1)
714        let max_abs = weight.iter().map(|v| v.abs()).fold(0.0_f32, f32::max);
715        let quantization_error = if max_abs > 0.0 { max_abs / 127.0 } else { 0.0 };
716
717        // --- Recommended LoRA rank (heuristic: ceil(effective_rank / 4)) ---
718        let recommended_rank = ((effective_rank / 4.0).ceil() as usize).max(1);
719
720        // --- Compression potential ---
721        // Combines sparsity and low effective rank
722        let max_dim = rows.max(cols) as f32;
723        let rank_score = 1.0 - (effective_rank / max_dim).clamp(0.0, 1.0);
724        let compression_potential = (0.6 * rank_score + 0.4 * weight_sparsity).clamp(0.0, 1.0);
725
726        Ok(Self {
727            weight_sparsity,
728            effective_rank,
729            quantization_error,
730            recommended_rank,
731            compression_potential,
732        })
733    }
734
735    /// Analyse multiple weight matrices, returning a list of (name, analysis) pairs.
736    pub fn analyze_multiple<'a>(
737        weights: &[(&'a str, &Array2<f32>)],
738    ) -> Vec<(&'a str, CompressionAnalysis)> {
739        weights
740            .iter()
741            .filter_map(|&(name, w)| Self::analyze_weight(w, 1e-6).ok().map(|a| (name, a)))
742            .collect()
743    }
744}
745
746// ---------------------------------------------------------------------------
747// InterpretabilityReport
748// ---------------------------------------------------------------------------
749
/// High-level interpretability summary collected from a single model run.
///
/// Assembled field-by-field by the caller; see [`InterpretabilityReport::summary`]
/// for a human-readable rendering.
pub struct InterpretabilityReport {
    /// Total number of steps included in this report
    pub num_steps: usize,
    /// Per-layer activation statistics: (layer_name, stats)
    pub layer_stats: Vec<(String, ActivationStats)>,
    /// State trajectory of the primary hidden state
    pub state_trajectory: StateTrajectory,
    /// Feature sensitivity ranking (feature_index, avg_sensitivity)
    pub top_sensitive_features: Vec<(usize, f32)>,
    /// Overall fraction of near-zero activations across all layers
    pub overall_sparsity: f32,
}
763
impl Default for InterpretabilityReport {
    /// Equivalent to [`InterpretabilityReport::new`]: an empty report.
    fn default() -> Self {
        Self::new()
    }
}
769
770impl InterpretabilityReport {
771    /// Create an empty report.
772    pub fn new() -> Self {
773        Self {
774            num_steps: 0,
775            layer_stats: Vec::new(),
776            state_trajectory: StateTrajectory::new(0),
777            top_sensitive_features: Vec::new(),
778            overall_sparsity: 0.0,
779        }
780    }
781
782    /// Generate a human-readable summary string.
783    pub fn summary(&self) -> String {
784        let mut lines = Vec::new();
785        lines.push(format!(
786            "InterpretabilityReport — {} step(s), {} layer(s)",
787            self.num_steps,
788            self.layer_stats.len()
789        ));
790        lines.push(format!(
791            "  Overall sparsity : {:.2}%",
792            self.overall_sparsity * 100.0
793        ));
794        lines.push(format!(
795            "  State trajectory : {} entries, dim={}",
796            self.state_trajectory.len(),
797            self.state_trajectory.dim
798        ));
799
800        if !self.layer_stats.is_empty() {
801            lines.push("  Layer statistics:".to_owned());
802            for (name, stats) in &self.layer_stats {
803                lines.push(format!(
804                    "    {name}: sparsity={:.2}% l2={:.4} steps={}",
805                    stats.sparsity * 100.0,
806                    stats.l2_norm,
807                    stats.num_steps
808                ));
809            }
810        }
811
812        if !self.top_sensitive_features.is_empty() {
813            lines.push("  Top sensitive features:".to_owned());
814            for &(idx, sens) in self.top_sensitive_features.iter().take(5) {
815                lines.push(format!("    feature {idx}: {sens:.4}"));
816            }
817        }
818
819        lines.join("\n")
820    }
821}
822
823// ---------------------------------------------------------------------------
824// Tests
825// ---------------------------------------------------------------------------
826
827#[cfg(test)]
828mod tests {
829    use super::*;
830    use scirs2_core::ndarray::array;
831
832    // -----------------------------------------------------------------------
833    // Test 1: ActivationStats basic
834    // -----------------------------------------------------------------------
    #[test]
    fn test_activation_stats_basic() {
        // Five identical vectors: mean == value, variance == 0
        let v = array![1.0_f32, 2.0, 3.0];
        let activations: Vec<Array1<f32>> = (0..5).map(|_| v.clone()).collect();
        let stats = ActivationStats::from_sequence(&activations).expect("stats");

        assert_eq!(stats.num_steps, 5);
        // The mean of identical vectors is the vector itself.
        for (&m, &expected) in stats.mean.iter().zip(v.iter()) {
            assert!((m - expected).abs() < 1e-5, "mean mismatch");
        }
        for &var in stats.variance.iter() {
            assert!(
                var.abs() < 1e-5,
                "variance should be ~0 for identical vectors"
            );
        }
        // No near-zero elements (all >= 1.0)
        assert_eq!(stats.sparsity, 0.0);
    }
855
856    // -----------------------------------------------------------------------
857    // Test 2: Incremental vs batch
858    // -----------------------------------------------------------------------
    #[test]
    fn test_activation_stats_incremental() {
        // Batch construction and step-by-step `update` calls must agree.
        let activations: Vec<Array1<f32>> =
            (0..10).map(|i| array![i as f32, (i * 2) as f32]).collect();

        let batch = ActivationStats::from_sequence(&activations).expect("batch");

        let mut incr = ActivationStats::zero(2);
        for a in &activations {
            incr.update(a);
        }

        for (&bm, &im) in batch.mean.iter().zip(incr.mean.iter()) {
            assert!((bm - im).abs() < 1e-4, "mean mismatch: {bm} vs {im}");
        }
        for (&bv, &iv) in batch.variance.iter().zip(incr.variance.iter()) {
            assert!((bv - iv).abs() < 1e-4, "variance mismatch: {bv} vs {iv}");
        }
    }
878
879    // -----------------------------------------------------------------------
880    // Test 3: LayerProbe capture count
881    // -----------------------------------------------------------------------
882    #[test]
883    fn test_layer_probe_capture() {
884        let mut probe = LayerProbe::new("layer0", 1000);
885        assert!(probe.activations().is_empty());
886
887        for i in 0..7 {
888            probe.capture(array![i as f32, 0.0]);
889        }
890        assert_eq!(probe.activations().len(), 7);
891        assert_eq!(probe.layer_name(), "layer0");
892
893        // Disable → capture is ignored
894        probe.disable();
895        probe.capture(array![99.0, 0.0]);
896        assert_eq!(probe.activations().len(), 7);
897
898        // Clear
899        probe.enable();
900        probe.clear();
901        assert!(probe.activations().is_empty());
902    }
903
904    // -----------------------------------------------------------------------
905    // Test 4: LayerProbe stats
906    // -----------------------------------------------------------------------
907    #[test]
908    fn test_layer_probe_stats() {
909        let mut probe = LayerProbe::new("attn", 100);
910        probe.capture(array![0.0_f32, 0.0]);
911        probe.capture(array![2.0_f32, 4.0]);
912
913        let stats = probe.stats().expect("stats");
914        assert_eq!(stats.num_steps, 2);
915        // Mean should be [1.0, 2.0]
916        assert!((stats.mean[0] - 1.0).abs() < 1e-5);
917        assert!((stats.mean[1] - 2.0).abs() < 1e-5);
918    }
919
920    // -----------------------------------------------------------------------
921    // Test 5: StateTrajectory velocities shape
922    // -----------------------------------------------------------------------
923    #[test]
924    fn test_state_trajectory_velocities() {
925        let mut traj = StateTrajectory::new(4);
926        for i in 0..5_u32 {
927            traj.push(Array1::from_elem(4, i as f32)).expect("push");
928        }
929        let vels = traj.velocities().expect("velocities");
930        assert_eq!(vels.len(), 4, "should have len-1 velocities");
931        for v in &vels {
932            assert!(v.is_finite(), "velocities must be finite");
933            assert!(*v >= 0.0);
934        }
935    }
936
937    // -----------------------------------------------------------------------
938    // Test 6: Lag-0 autocorrelation == 1.0
939    // -----------------------------------------------------------------------
940    #[test]
941    fn test_state_trajectory_autocorrelation() {
942        let mut traj = StateTrajectory::new(8);
943        for i in 0..10_u32 {
944            let s = Array1::from_shape_fn(8, |j| (i * 8 + j as u32) as f32);
945            traj.push(s).expect("push");
946        }
947
948        let ac0 = traj.autocorrelation(0).expect("lag0");
949        assert!(
950            (ac0 - 1.0).abs() < 1e-5,
951            "lag-0 autocorr should be 1.0, got {ac0}"
952        );
953
954        // Lag-1 should be finite and in [-1, 1]
955        let ac1 = traj.autocorrelation(1).expect("lag1");
956        assert!(ac1.is_finite());
957        assert!((-1.0_f32..=1.0_f32).contains(&ac1));
958    }
959
960    // -----------------------------------------------------------------------
961    // Test 7: SensitivityAnalyzer gives non-negative sensitivities
962    // -----------------------------------------------------------------------
963    #[test]
964    fn test_sensitivity_analyzer() {
965        let analyzer = SensitivityAnalyzer::new(3);
966
967        // Forward function: identity (output == input)
968        let forward = |x: &Array1<f32>| -> ModelResult<Array1<f32>> { Ok(x.clone()) };
969
970        let input = array![1.0_f32, -0.5, 2.0];
971        let sens = analyzer
972            .input_sensitivity(&input, forward, 1e-3)
973            .expect("sensitivity");
974
975        assert_eq!(sens.len(), 3);
976        for &s in sens.iter() {
977            assert!(s >= 0.0, "sensitivity must be non-negative, got {s}");
978            assert!(s.is_finite());
979        }
980    }
981
982    // -----------------------------------------------------------------------
983    // Test 8: CompressionAnalysis valid ratios
984    // -----------------------------------------------------------------------
985    #[test]
986    fn test_compression_analysis() {
987        let w: Array2<f32> =
988            Array2::from_shape_fn((16, 16), |(i, j)| if i == j { 1.0 } else { 0.0 });
989        let analysis = CompressionAnalysis::analyze_weight(&w, 1e-6).expect("analysis");
990
991        // Identity matrix: nearly all zeros → high sparsity
992        assert!((0.0..=1.0_f32).contains(&analysis.weight_sparsity));
993        assert!(analysis.compression_potential >= 0.0);
994        assert!(analysis.compression_potential <= 1.0);
995        assert!(analysis.effective_rank > 0.0);
996        assert!(analysis.recommended_rank >= 1);
997    }
998}