1use crate::error::{ModelError, ModelResult};
13use scirs2_core::ndarray::{Array1, Array2};
14
15#[derive(Debug, Clone)]
23pub struct ActivationStats {
24 pub mean: Array1<f32>,
26 pub variance: Array1<f32>,
28 pub max: Array1<f32>,
30 pub min: Array1<f32>,
32 pub sparsity: f32,
34 pub l2_norm: f32,
36 pub num_steps: usize,
38
39 welford_m2: Array1<f32>,
41 near_zero_count: usize,
42 total_elements: usize,
43 l2_sum: f32,
44}
45
46impl ActivationStats {
47 pub fn from_sequence(activations: &[Array1<f32>]) -> ModelResult<Self> {
51 if activations.is_empty() {
52 return Err(ModelError::invalid_config(
53 "ActivationStats::from_sequence: empty activation sequence",
54 ));
55 }
56 let dim = activations[0].len();
57 for (i, a) in activations.iter().enumerate() {
58 if a.len() != dim {
59 return Err(ModelError::dimension_mismatch(
60 format!("activation[{i}]"),
61 dim,
62 a.len(),
63 ));
64 }
65 }
66
67 let mut stats = Self::zero(dim);
68 for a in activations {
69 stats.update(a);
70 }
71 Ok(stats)
72 }
73
74 fn zero(dim: usize) -> Self {
76 Self {
77 mean: Array1::zeros(dim),
78 variance: Array1::zeros(dim),
79 max: Array1::from_elem(dim, f32::NEG_INFINITY),
80 min: Array1::from_elem(dim, f32::INFINITY),
81 sparsity: 0.0,
82 l2_norm: 0.0,
83 num_steps: 0,
84 welford_m2: Array1::zeros(dim),
85 near_zero_count: 0,
86 total_elements: 0,
87 l2_sum: 0.0,
88 }
89 }
90
91 pub fn update(&mut self, activation: &Array1<f32>) {
93 let eps = 1e-6_f32;
94 self.num_steps += 1;
95 let n = self.num_steps as f32;
96
97 let mut sq_sum = 0.0_f32;
98 let mut nz = 0usize;
99
100 for (i, &v) in activation.iter().enumerate() {
101 if i >= self.mean.len() {
102 break;
103 }
104 let delta = v - self.mean[i];
106 self.mean[i] += delta / n;
107 let delta2 = v - self.mean[i];
108 self.welford_m2[i] += delta * delta2;
109 self.variance[i] = if self.num_steps > 1 {
110 self.welford_m2[i] / n
111 } else {
112 0.0
113 };
114
115 if v > self.max[i] {
117 self.max[i] = v;
118 }
119 if v < self.min[i] {
120 self.min[i] = v;
121 }
122
123 if v.abs() < eps {
125 nz += 1;
126 }
127
128 sq_sum += v * v;
129 }
130
131 self.near_zero_count += nz;
132 self.total_elements += activation.len();
133 self.l2_sum += sq_sum.sqrt();
134 self.l2_norm = self.l2_sum / n;
135 self.sparsity = if self.total_elements > 0 {
136 self.near_zero_count as f32 / self.total_elements as f32
137 } else {
138 0.0
139 };
140 }
141
142 pub fn reset(&mut self) {
144 let dim = self.mean.len();
145 self.mean.fill(0.0);
146 self.variance.fill(0.0);
147 self.max.fill(f32::NEG_INFINITY);
148 self.min.fill(f32::INFINITY);
149 self.sparsity = 0.0;
150 self.l2_norm = 0.0;
151 self.num_steps = 0;
152 self.welford_m2 = Array1::zeros(dim);
153 self.near_zero_count = 0;
154 self.total_elements = 0;
155 self.l2_sum = 0.0;
156 }
157}
158
159pub struct LayerProbe {
168 layer_name: String,
169 captured: Vec<Array1<f32>>,
170 max_capture: usize,
171 head: usize, filled: bool,
173 enabled: bool,
174}
175
176impl LayerProbe {
177 pub fn new(layer_name: &str, max_capture: usize) -> Self {
179 let max_capture = max_capture.max(1);
180 Self {
181 layer_name: layer_name.to_owned(),
182 captured: Vec::with_capacity(max_capture),
183 max_capture,
184 head: 0,
185 filled: false,
186 enabled: true,
187 }
188 }
189
190 pub fn capture(&mut self, activation: Array1<f32>) {
192 if !self.enabled {
193 return;
194 }
195 if self.captured.len() < self.max_capture {
196 self.captured.push(activation);
197 } else {
198 self.captured[self.head] = activation;
199 self.filled = true;
200 }
201 self.head = (self.head + 1) % self.max_capture;
202 }
203
204 pub fn stats(&self) -> ModelResult<ActivationStats> {
206 if self.captured.is_empty() {
207 return Err(ModelError::invalid_config(format!(
208 "LayerProbe '{}': no activations captured",
209 self.layer_name
210 )));
211 }
212 ActivationStats::from_sequence(&self.captured)
213 }
214
215 pub fn activations(&self) -> &[Array1<f32>] {
217 &self.captured
218 }
219
220 pub fn is_full(&self) -> bool {
222 self.filled
223 }
224
225 pub fn enable(&mut self) {
227 self.enabled = true;
228 }
229
230 pub fn disable(&mut self) {
232 self.enabled = false;
233 }
234
235 pub fn clear(&mut self) {
237 self.captured.clear();
238 self.head = 0;
239 self.filled = false;
240 }
241
242 pub fn layer_name(&self) -> &str {
244 &self.layer_name
245 }
246}
247
248#[derive(Debug, Clone)]
254pub struct GatingAnalysis {
255 pub gate_values: Vec<Array1<f32>>,
257 pub avg_gate: Array1<f32>,
259 pub dead_gates: Vec<usize>,
261 pub saturated_gates: Vec<usize>,
263 pub gate_entropy: f32,
265}
266
267impl GatingAnalysis {
268 pub fn from_activations(gate_values: Vec<Array1<f32>>, threshold: f32) -> ModelResult<Self> {
275 if gate_values.is_empty() {
276 return Err(ModelError::invalid_config(
277 "GatingAnalysis: no gate values provided",
278 ));
279 }
280 let dim = gate_values[0].len();
281 for (i, g) in gate_values.iter().enumerate() {
282 if g.len() != dim {
283 return Err(ModelError::dimension_mismatch(
284 format!("gate_values[{i}]"),
285 dim,
286 g.len(),
287 ));
288 }
289 }
290
291 let n = gate_values.len() as f32;
293 let mut avg_gate = Array1::zeros(dim);
294 for g in &gate_values {
295 for (i, &v) in g.iter().enumerate() {
296 avg_gate[i] += v;
297 }
298 }
299 avg_gate.mapv_inplace(|v: f32| v / n);
300
301 let threshold_clamped = threshold.clamp(0.0, 0.5);
302 let mut dead_gates = Vec::new();
303 let mut saturated_gates = Vec::new();
304 for (i, &v) in avg_gate.iter().enumerate() {
305 if v < threshold_clamped {
306 dead_gates.push(i);
307 } else if v > 1.0 - threshold_clamped {
308 saturated_gates.push(i);
309 }
310 }
311
312 let eps = 1e-9_f32;
314 let mut entropy = 0.0_f32;
315 for &p in avg_gate.iter() {
316 let p: f32 = p.clamp(eps, 1.0 - eps);
317 entropy -= p * p.log2() + (1.0 - p) * (1.0 - p).log2();
318 }
319 let gate_entropy = entropy / dim as f32;
320
321 Ok(Self {
322 gate_values,
323 avg_gate,
324 dead_gates,
325 saturated_gates,
326 gate_entropy,
327 })
328 }
329
330 pub fn effective_capacity(&self) -> f32 {
332 let total = self.avg_gate.len();
333 if total == 0 {
334 return 0.0;
335 }
336 let inactive = self.dead_gates.len() + self.saturated_gates.len();
337 let active = total.saturating_sub(inactive);
338 active as f32 / total as f32
339 }
340}
341
342pub struct StateTrajectory {
348 states: Vec<Array1<f32>>,
349 dim: usize,
350}
351
352impl StateTrajectory {
353 pub fn new(dim: usize) -> Self {
355 Self {
356 states: Vec::new(),
357 dim,
358 }
359 }
360
361 pub fn push(&mut self, state: Array1<f32>) -> ModelResult<()> {
363 if state.len() != self.dim {
364 return Err(ModelError::dimension_mismatch(
365 "StateTrajectory::push",
366 self.dim,
367 state.len(),
368 ));
369 }
370 self.states.push(state);
371 Ok(())
372 }
373
374 pub fn len(&self) -> usize {
376 self.states.len()
377 }
378
379 pub fn is_empty(&self) -> bool {
381 self.states.is_empty()
382 }
383
384 pub fn velocities(&self) -> ModelResult<Vec<f32>> {
388 if self.states.len() < 2 {
389 return Err(ModelError::invalid_config(
390 "StateTrajectory::velocities: need at least 2 states",
391 ));
392 }
393 let mut vels = Vec::with_capacity(self.states.len() - 1);
394 for w in self.states.windows(2) {
395 let diff = &w[1] - &w[0];
396 let norm = diff.iter().map(|&v| v * v).sum::<f32>().sqrt();
397 vels.push(norm);
398 }
399 Ok(vels)
400 }
401
402 pub fn participation_ratio(&self) -> ModelResult<f32> {
408 if self.states.is_empty() {
409 return Err(ModelError::invalid_config(
410 "StateTrajectory::participation_ratio: no states recorded",
411 ));
412 }
413
414 let n = self.states.len() as f32;
416 let mut mean: Array1<f32> = Array1::zeros(self.dim);
417 for s in &self.states {
418 for (i, &v) in s.iter().enumerate() {
419 mean[i] += v;
420 }
421 }
422 mean.mapv_inplace(|v: f32| v / n);
423
424 let mut var: Array1<f32> = Array1::zeros(self.dim);
425 for s in &self.states {
426 for (i, &v) in s.iter().enumerate() {
427 let d: f32 = v - mean[i];
428 var[i] += d * d;
429 }
430 }
431 var.mapv_inplace(|v: f32| v / n);
432
433 let sum_var: f32 = var.iter().sum();
434 let sum_var_sq: f32 = var.iter().map(|&v| v * v).sum();
435
436 if sum_var_sq < 1e-20 {
437 return Ok(0.0);
439 }
440
441 Ok((sum_var * sum_var) / sum_var_sq)
442 }
443
444 pub fn most_variable_dims(&self, k: usize) -> ModelResult<Vec<usize>> {
446 if self.states.is_empty() {
447 return Err(ModelError::invalid_config(
448 "StateTrajectory::most_variable_dims: no states recorded",
449 ));
450 }
451 let k = k.min(self.dim);
452
453 let n = self.states.len() as f32;
454 let mut mean = vec![0.0_f32; self.dim];
455 for s in &self.states {
456 for (i, &v) in s.iter().enumerate() {
457 mean[i] += v;
458 }
459 }
460 for m in &mut mean {
461 *m /= n;
462 }
463
464 let mut var = vec![0.0_f32; self.dim];
465 for s in &self.states {
466 for (i, &v) in s.iter().enumerate() {
467 let d = v - mean[i];
468 var[i] += d * d;
469 }
470 }
471 for v in &mut var {
472 *v /= n;
473 }
474
475 let mut idx: Vec<usize> = (0..self.dim).collect();
476 idx.sort_unstable_by(|&a, &b| {
477 var[b]
478 .partial_cmp(&var[a])
479 .unwrap_or(std::cmp::Ordering::Equal)
480 });
481 idx.truncate(k);
482 Ok(idx)
483 }
484
485 pub fn autocorrelation(&self, lag: usize) -> ModelResult<f32> {
490 if self.states.len() <= lag {
491 return Err(ModelError::invalid_config(format!(
492 "StateTrajectory::autocorrelation: lag {lag} requires at least {} states, have {}",
493 lag + 1,
494 self.states.len()
495 )));
496 }
497
498 let n_pairs = self.states.len() - lag;
499 let mut corr_sum = 0.0_f32;
500
501 for t in 0..n_pairs {
502 let s0 = &self.states[t];
503 let s1 = &self.states[t + lag];
504
505 let n = self.dim as f32;
507 let mean0: f32 = s0.iter().sum::<f32>() / n;
508 let mean1: f32 = s1.iter().sum::<f32>() / n;
509
510 let mut cov = 0.0_f32;
511 let mut std0 = 0.0_f32;
512 let mut std1 = 0.0_f32;
513 for (&a, &b) in s0.iter().zip(s1.iter()) {
514 let da = a - mean0;
515 let db = b - mean1;
516 cov += da * db;
517 std0 += da * da;
518 std1 += db * db;
519 }
520
521 let denom = (std0 * std1).sqrt();
522 if denom < 1e-10 {
523 corr_sum += 1.0;
525 } else {
526 corr_sum += cov / denom;
527 }
528 }
529
530 Ok(corr_sum / n_pairs as f32)
531 }
532
533 pub fn to_matrix(&self) -> ModelResult<Array2<f32>> {
535 if self.states.is_empty() {
536 return Err(ModelError::invalid_config(
537 "StateTrajectory::to_matrix: no states recorded",
538 ));
539 }
540 let t = self.states.len();
541 let d = self.dim;
542 let mut mat = Array2::zeros((t, d));
543 for (row, state) in self.states.iter().enumerate() {
544 for (col, &v) in state.iter().enumerate() {
545 mat[[row, col]] = v;
546 }
547 }
548 Ok(mat)
549 }
550}
551
552pub struct SensitivityAnalyzer {
558 input_dim: usize,
559}
560
561impl SensitivityAnalyzer {
562 pub fn new(input_dim: usize) -> Self {
564 Self { input_dim }
565 }
566
567 pub fn input_sensitivity<F>(
577 &self,
578 input: &Array1<f32>,
579 forward_fn: F,
580 eps: f32,
581 ) -> ModelResult<Array1<f32>>
582 where
583 F: Fn(&Array1<f32>) -> ModelResult<Array1<f32>>,
584 {
585 if input.len() != self.input_dim {
586 return Err(ModelError::dimension_mismatch(
587 "SensitivityAnalyzer::input_sensitivity",
588 self.input_dim,
589 input.len(),
590 ));
591 }
592
593 let base_out = forward_fn(input)?;
594 let base_norm = base_out.iter().map(|&v| v * v).sum::<f32>().sqrt();
595
596 let mut sensitivities = Array1::zeros(self.input_dim);
597 for i in 0..self.input_dim {
598 let mut perturbed = input.clone();
599 perturbed[i] += eps;
600 let pert_out = forward_fn(&perturbed)?;
601
602 let diff_norm = pert_out
604 .iter()
605 .zip(base_out.iter())
606 .map(|(&a, &b)| (a - b) * (a - b))
607 .sum::<f32>()
608 .sqrt();
609
610 sensitivities[i] = if eps.abs() > 1e-15 {
611 diff_norm / eps.abs()
612 } else {
613 base_norm
614 };
615 }
616
617 Ok(sensitivities)
618 }
619
620 pub fn rank_features<F>(
624 &self,
625 inputs: &[Array1<f32>],
626 forward_fn: F,
627 eps: f32,
628 ) -> ModelResult<Vec<(usize, f32)>>
629 where
630 F: Fn(&Array1<f32>) -> ModelResult<Array1<f32>>,
631 {
632 if inputs.is_empty() {
633 return Err(ModelError::invalid_config(
634 "SensitivityAnalyzer::rank_features: no inputs provided",
635 ));
636 }
637
638 let mut total: Array1<f32> = Array1::zeros(self.input_dim);
639 for input in inputs {
640 let sens = self.input_sensitivity(input, &forward_fn, eps)?;
641 for (i, &v) in sens.iter().enumerate() {
642 total[i] += v;
643 }
644 }
645
646 let n = inputs.len() as f32;
647 let mut ranked: Vec<(usize, f32)> =
648 total.iter().enumerate().map(|(i, &v)| (i, v / n)).collect();
649
650 ranked.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
651
652 Ok(ranked)
653 }
654}
655
656#[derive(Debug, Clone)]
662pub struct CompressionAnalysis {
663 pub weight_sparsity: f32,
665 pub effective_rank: f32,
667 pub quantization_error: f32,
669 pub recommended_rank: usize,
671 pub compression_potential: f32,
673}
674
675impl CompressionAnalysis {
676 pub fn analyze_weight(weight: &Array2<f32>, eps: f32) -> ModelResult<Self> {
680 let (rows, cols) = (weight.shape()[0], weight.shape()[1]);
681 let total = rows * cols;
682
683 if total == 0 {
684 return Err(ModelError::invalid_config(
685 "CompressionAnalysis: weight matrix is empty",
686 ));
687 }
688
689 let near_zero = weight.iter().filter(|&&v| v.abs() < eps).count();
691 let weight_sparsity = near_zero as f32 / total as f32;
692
693 let n = rows as f32;
696 let col_variances: Vec<f32> = (0..cols)
697 .map(|j| {
698 let col = weight.column(j);
699 let mean = col.iter().sum::<f32>() / n;
700 col.iter().map(|&v| (v - mean) * (v - mean)).sum::<f32>() / n
701 })
702 .collect();
703
704 let sum_var: f32 = col_variances.iter().sum();
705 let sum_var_sq: f32 = col_variances.iter().map(|&v| v * v).sum();
706 let effective_rank = if sum_var_sq > 1e-20 {
707 (sum_var * sum_var) / sum_var_sq
708 } else {
709 1.0
710 };
711
712 let max_abs = weight.iter().map(|v| v.abs()).fold(0.0_f32, f32::max);
715 let quantization_error = if max_abs > 0.0 { max_abs / 127.0 } else { 0.0 };
716
717 let recommended_rank = ((effective_rank / 4.0).ceil() as usize).max(1);
719
720 let max_dim = rows.max(cols) as f32;
723 let rank_score = 1.0 - (effective_rank / max_dim).clamp(0.0, 1.0);
724 let compression_potential = (0.6 * rank_score + 0.4 * weight_sparsity).clamp(0.0, 1.0);
725
726 Ok(Self {
727 weight_sparsity,
728 effective_rank,
729 quantization_error,
730 recommended_rank,
731 compression_potential,
732 })
733 }
734
735 pub fn analyze_multiple<'a>(
737 weights: &[(&'a str, &Array2<f32>)],
738 ) -> Vec<(&'a str, CompressionAnalysis)> {
739 weights
740 .iter()
741 .filter_map(|&(name, w)| Self::analyze_weight(w, 1e-6).ok().map(|a| (name, a)))
742 .collect()
743 }
744}
745
746pub struct InterpretabilityReport {
752 pub num_steps: usize,
754 pub layer_stats: Vec<(String, ActivationStats)>,
756 pub state_trajectory: StateTrajectory,
758 pub top_sensitive_features: Vec<(usize, f32)>,
760 pub overall_sparsity: f32,
762}
763
764impl Default for InterpretabilityReport {
765 fn default() -> Self {
766 Self::new()
767 }
768}
769
770impl InterpretabilityReport {
771 pub fn new() -> Self {
773 Self {
774 num_steps: 0,
775 layer_stats: Vec::new(),
776 state_trajectory: StateTrajectory::new(0),
777 top_sensitive_features: Vec::new(),
778 overall_sparsity: 0.0,
779 }
780 }
781
782 pub fn summary(&self) -> String {
784 let mut lines = Vec::new();
785 lines.push(format!(
786 "InterpretabilityReport — {} step(s), {} layer(s)",
787 self.num_steps,
788 self.layer_stats.len()
789 ));
790 lines.push(format!(
791 " Overall sparsity : {:.2}%",
792 self.overall_sparsity * 100.0
793 ));
794 lines.push(format!(
795 " State trajectory : {} entries, dim={}",
796 self.state_trajectory.len(),
797 self.state_trajectory.dim
798 ));
799
800 if !self.layer_stats.is_empty() {
801 lines.push(" Layer statistics:".to_owned());
802 for (name, stats) in &self.layer_stats {
803 lines.push(format!(
804 " {name}: sparsity={:.2}% l2={:.4} steps={}",
805 stats.sparsity * 100.0,
806 stats.l2_norm,
807 stats.num_steps
808 ));
809 }
810 }
811
812 if !self.top_sensitive_features.is_empty() {
813 lines.push(" Top sensitive features:".to_owned());
814 for &(idx, sens) in self.top_sensitive_features.iter().take(5) {
815 lines.push(format!(" feature {idx}: {sens:.4}"));
816 }
817 }
818
819 lines.join("\n")
820 }
821}
822
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Independent two-pass population variance, so streaming (Welford)
    /// results are checked against a separate computation rather than
    /// against themselves.
    fn reference_variance(values: &[f32]) -> f32 {
        let n = values.len() as f32;
        let mean = values.iter().sum::<f32>() / n;
        values.iter().map(|&v| (v - mean) * (v - mean)).sum::<f32>() / n
    }

    /// Identical vectors: mean equals the vector, variance ~0, no sparsity.
    #[test]
    fn test_activation_stats_basic() {
        let v = array![1.0_f32, 2.0, 3.0];
        let activations: Vec<Array1<f32>> = (0..5).map(|_| v.clone()).collect();
        let stats = ActivationStats::from_sequence(&activations).expect("stats");

        assert_eq!(stats.num_steps, 5);
        for (&m, &expected) in stats.mean.iter().zip(v.iter()) {
            assert!((m - expected).abs() < 1e-5, "mean mismatch");
        }
        for &var in stats.variance.iter() {
            assert!(
                var.abs() < 1e-5,
                "variance should be ~0 for identical vectors"
            );
        }
        assert_eq!(stats.sparsity, 0.0);
    }

    /// Empty and ragged sequences are rejected.
    #[test]
    fn test_activation_stats_rejects_bad_input() {
        let empty: Vec<Array1<f32>> = Vec::new();
        assert!(ActivationStats::from_sequence(&empty).is_err());

        let ragged = vec![array![1.0_f32, 2.0], array![1.0_f32]];
        assert!(ActivationStats::from_sequence(&ragged).is_err());
    }

    /// Batch and incremental construction agree, and both match a two-pass
    /// reference for mean, variance, extrema, sparsity and mean L2 norm.
    #[test]
    fn test_activation_stats_incremental() {
        let activations: Vec<Array1<f32>> =
            (0..10).map(|i| array![i as f32, (i * 2) as f32]).collect();

        let batch = ActivationStats::from_sequence(&activations).expect("batch");

        let mut incr = ActivationStats::zero(2);
        for a in &activations {
            incr.update(a);
        }

        for (&bm, &im) in batch.mean.iter().zip(incr.mean.iter()) {
            assert!((bm - im).abs() < 1e-4, "mean mismatch: {bm} vs {im}");
        }
        for (&bv, &iv) in batch.variance.iter().zip(incr.variance.iter()) {
            assert!((bv - iv).abs() < 1e-4, "variance mismatch: {bv} vs {iv}");
        }

        for d in 0..2 {
            let column: Vec<f32> = activations.iter().map(|a| a[d]).collect();
            let expected_mean = column.iter().sum::<f32>() / column.len() as f32;
            let expected_var = reference_variance(&column);
            assert!(
                (batch.mean[d] - expected_mean).abs() < 1e-4,
                "dim {d}: mean {} vs {expected_mean}",
                batch.mean[d]
            );
            assert!(
                (batch.variance[d] - expected_var).abs() < 1e-3,
                "dim {d}: variance {} vs {expected_var}",
                batch.variance[d]
            );
        }

        assert_eq!(batch.min[0], 0.0);
        assert_eq!(batch.max[0], 9.0);
        assert_eq!(batch.max[1], 18.0);
        // Only the first (all-zero) step has near-zero entries: 2 of 20.
        assert!((batch.sparsity - 0.1).abs() < 1e-6);
        // Per-step norms are i * sqrt(5) for i in 0..10; their mean is 4.5 * sqrt(5).
        let expected_l2 = 4.5 * 5.0_f32.sqrt();
        assert!((batch.l2_norm - expected_l2).abs() < 1e-3);
    }

    /// Captures accumulate, disabled probes ignore input, clear empties.
    #[test]
    fn test_layer_probe_capture() {
        let mut probe = LayerProbe::new("layer0", 1000);
        assert!(probe.activations().is_empty());

        for i in 0..7 {
            probe.capture(array![i as f32, 0.0]);
        }
        assert_eq!(probe.activations().len(), 7);
        assert_eq!(probe.layer_name(), "layer0");

        probe.disable();
        probe.capture(array![99.0, 0.0]);
        assert_eq!(probe.activations().len(), 7);

        probe.enable();
        probe.clear();
        assert!(probe.activations().is_empty());
    }

    /// Past capacity, the oldest slots are overwritten in place.
    #[test]
    fn test_layer_probe_ring_buffer_wraps() {
        let mut probe = LayerProbe::new("ring", 3);
        for i in 0..5 {
            probe.capture(array![i as f32]);
        }
        assert_eq!(probe.activations().len(), 3);
        assert!(probe.is_full());
        // Slots 0 and 1 were overwritten by captures 3 and 4; slot 2 holds capture 2.
        let stored: Vec<f32> = probe.activations().iter().map(|a| a[0]).collect();
        assert_eq!(stored, vec![3.0, 4.0, 2.0]);

        // A zero capacity is treated as one slot.
        let mut tiny = LayerProbe::new("tiny", 0);
        tiny.capture(array![1.0_f32]);
        tiny.capture(array![2.0_f32]);
        assert_eq!(tiny.activations().len(), 1);
        assert_eq!(tiny.activations()[0][0], 2.0);
    }

    /// Stats reflect the captured activations; an empty probe errors.
    #[test]
    fn test_layer_probe_stats() {
        let mut probe = LayerProbe::new("attn", 100);
        probe.capture(array![0.0_f32, 0.0]);
        probe.capture(array![2.0_f32, 4.0]);

        let stats = probe.stats().expect("stats");
        assert_eq!(stats.num_steps, 2);
        assert!((stats.mean[0] - 1.0).abs() < 1e-5);
        assert!((stats.mean[1] - 2.0).abs() < 1e-5);

        assert!(LayerProbe::new("empty", 4).stats().is_err());
    }

    /// Dead/saturated classification, capacity and entropy bounds.
    #[test]
    fn test_gating_analysis() {
        let gates = vec![array![0.0_f32, 1.0, 0.5], array![0.02_f32, 0.98, 0.5]];
        let analysis = GatingAnalysis::from_activations(gates, 0.05).expect("gating");

        assert_eq!(analysis.dead_gates, vec![0]);
        assert_eq!(analysis.saturated_gates, vec![1]);
        assert!((analysis.effective_capacity() - 1.0 / 3.0).abs() < 1e-6);
        assert!(analysis.gate_entropy.is_finite());
        assert!(analysis.gate_entropy > 0.0 && analysis.gate_entropy <= 1.0);

        assert!(GatingAnalysis::from_activations(Vec::new(), 0.1).is_err());
    }

    /// Uniform unit steps in 4 dims give velocity sqrt(4) = 2 each.
    #[test]
    fn test_state_trajectory_velocities() {
        let mut traj = StateTrajectory::new(4);
        for i in 0..5_u32 {
            traj.push(Array1::from_elem(4, i as f32)).expect("push");
        }
        let vels = traj.velocities().expect("velocities");
        assert_eq!(vels.len(), 4, "should have len-1 velocities");
        for v in &vels {
            assert!(v.is_finite(), "velocities must be finite");
            assert!(*v >= 0.0);
            assert!((*v - 2.0).abs() < 1e-6);
        }

        let mut single = StateTrajectory::new(2);
        single.push(array![1.0_f32, 2.0]).expect("push");
        assert!(single.velocities().is_err());
    }

    /// Variance-based summaries and matrix export on a known trajectory.
    #[test]
    fn test_state_trajectory_variance_summaries() {
        let mut traj = StateTrajectory::new(3);
        for i in 0..4_u32 {
            let x = i as f32;
            traj.push(array![0.0_f32, x, 2.0 * x]).expect("push");
        }
        assert!(traj.push(array![1.0_f32]).is_err());

        // Variances are [0, v, 4v], so PR = (5v)^2 / (17 v^2) = 25/17.
        let pr = traj.participation_ratio().expect("pr");
        assert!((pr - 25.0 / 17.0).abs() < 1e-4, "pr = {pr}");
        assert_eq!(traj.most_variable_dims(2).expect("dims"), vec![2, 1]);

        let mat = traj.to_matrix().expect("matrix");
        assert_eq!(mat.shape(), &[4, 3]);
        assert_eq!(mat[[3, 2]], 6.0);

        let empty = StateTrajectory::new(2);
        assert!(empty.participation_ratio().is_err());
        assert!(empty.most_variable_dims(1).is_err());
        assert!(empty.to_matrix().is_err());
    }

    /// Lag-0 autocorrelation is 1; lags beyond the data are rejected.
    #[test]
    fn test_state_trajectory_autocorrelation() {
        let mut traj = StateTrajectory::new(8);
        for i in 0..10_u32 {
            let s = Array1::from_shape_fn(8, |j| (i * 8 + j as u32) as f32);
            traj.push(s).expect("push");
        }

        let ac0 = traj.autocorrelation(0).expect("lag0");
        assert!(
            (ac0 - 1.0).abs() < 1e-5,
            "lag-0 autocorr should be 1.0, got {ac0}"
        );

        let ac1 = traj.autocorrelation(1).expect("lag1");
        assert!(ac1.is_finite());
        assert!((-1.0_f32..=1.0_f32).contains(&ac1));

        assert!(traj.autocorrelation(10).is_err());
    }

    /// The identity map has unit sensitivity in every feature.
    #[test]
    fn test_sensitivity_analyzer() {
        let analyzer = SensitivityAnalyzer::new(3);

        let forward = |x: &Array1<f32>| -> ModelResult<Array1<f32>> { Ok(x.clone()) };

        let input = array![1.0_f32, -0.5, 2.0];
        let sens = analyzer
            .input_sensitivity(&input, forward, 1e-3)
            .expect("sensitivity");

        assert_eq!(sens.len(), 3);
        for &s in sens.iter() {
            assert!(s >= 0.0, "sensitivity must be non-negative, got {s}");
            assert!(s.is_finite());
            assert!((s - 1.0).abs() < 1e-2, "identity sensitivity should be ~1, got {s}");
        }

        let wrong_dim = array![1.0_f32];
        assert!(analyzer
            .input_sensitivity(&wrong_dim, |x: &Array1<f32>| Ok(x.clone()), 1e-3)
            .is_err());
    }

    /// Per-feature scaling determines the ranking order.
    #[test]
    fn test_sensitivity_rank_features() {
        let analyzer = SensitivityAnalyzer::new(3);
        let scales = array![1.0_f32, 3.0, 2.0];
        let forward = |x: &Array1<f32>| -> ModelResult<Array1<f32>> { Ok(x * &scales) };

        let inputs = vec![array![0.5_f32, 0.5, 0.5], array![1.0_f32, -1.0, 0.25]];
        let ranked = analyzer.rank_features(&inputs, forward, 1e-2).expect("rank");
        let order: Vec<usize> = ranked.iter().map(|&(i, _)| i).collect();
        assert_eq!(order, vec![1, 2, 0]);

        let no_inputs: Vec<Array1<f32>> = Vec::new();
        assert!(analyzer
            .rank_features(&no_inputs, |x: &Array1<f32>| Ok(x.clone()), 1e-2)
            .is_err());
    }

    /// Identity matrix: known sparsity, rank proxy and quantization step;
    /// empty matrices are rejected and skipped by the batch helper.
    #[test]
    fn test_compression_analysis() {
        let w: Array2<f32> =
            Array2::from_shape_fn((16, 16), |(i, j)| if i == j { 1.0 } else { 0.0 });
        let analysis = CompressionAnalysis::analyze_weight(&w, 1e-6).expect("analysis");

        assert!((0.0..=1.0_f32).contains(&analysis.weight_sparsity));
        assert!(analysis.compression_potential >= 0.0);
        assert!(analysis.compression_potential <= 1.0);
        assert!(analysis.effective_rank > 0.0);
        assert!(analysis.recommended_rank >= 1);

        assert!((analysis.weight_sparsity - 0.9375).abs() < 1e-6);
        assert!((analysis.effective_rank - 16.0).abs() < 1e-3);
        assert_eq!(analysis.recommended_rank, 4);
        assert!((analysis.quantization_error - 1.0 / 127.0).abs() < 1e-6);

        let empty: Array2<f32> = Array2::zeros((0, 4));
        assert!(CompressionAnalysis::analyze_weight(&empty, 1e-6).is_err());

        let results = CompressionAnalysis::analyze_multiple(&[("w", &w), ("empty", &empty)]);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, "w");
    }

    /// The default report renders a header and omits empty sections.
    #[test]
    fn test_report_summary() {
        let report = InterpretabilityReport::default();
        let text = report.summary();
        assert!(text.starts_with("InterpretabilityReport — 0 step(s), 0 layer(s)"));
        assert!(!text.contains("Layer statistics"));
        assert!(!text.contains("Top sensitive features"));
    }
}