pub(crate) use super::*;
/// CAL-01: a freshly constructed `ActivationStats` starts with a zero sample
/// count and a zero-filled `input_norms` buffer of the requested width.
#[test]
fn test_activation_stats_new() {
    let stats = ActivationStats::new(4);

    assert_eq!(
        stats.count, 0,
        "CAL-01 FALSIFIED: initial count should be 0"
    );

    let norms = stats.input_norms.data();
    assert_eq!(
        norms.len(),
        4,
        "CAL-01 FALSIFIED: input_norms should have 4 elements"
    );
    // "no element differs from zero" is the De Morgan form of "all are zero".
    assert!(
        !norms.iter().any(|&v| v != 0.0),
        "CAL-01 FALSIFIED: initial norms should be all zeros"
    );
}
/// CAL-02: feeding a batch with zero rows must leave both the sample count
/// and the accumulated norms untouched.
#[test]
fn test_activation_stats_empty_batch_noop() {
    let mut stats = ActivationStats::new(4);
    let seed = Tensor::new(&[1.0, 2.0, 3.0, 4.0], &[1, 4]);
    stats.update(&seed);

    // Snapshot after a real update so the no-op check compares non-trivial state.
    let (count_snapshot, norms_snapshot) = (stats.count, stats.input_norms.data().to_vec());

    let empty = Tensor::new(&[], &[0, 4]);
    stats.update(&empty);

    assert_eq!(
        stats.count, count_snapshot,
        "CAL-02 FALSIFIED: empty batch should not change count"
    );
    assert_eq!(
        stats.input_norms.data(),
        &norms_snapshot[..],
        "CAL-02 FALSIFIED: empty batch should not change norms"
    );
}
/// CAL-03: `get_stats` yields `Some` for a registered layer and `None` for an
/// unknown one.
#[test]
fn test_calibration_context_get_stats() {
    let mut ctx = CalibrationContext::new(String::from("test_dataset"));
    ctx.add_layer_stats(String::from("model.layer0"), ActivationStats::new(512));

    let found = ctx.get_stats("model.layer0");
    assert!(found.is_some(), "CAL-03 FALSIFIED: should find added layer");

    let missing = ctx.get_stats("nonexistent");
    assert!(
        missing.is_none(),
        "CAL-03 FALSIFIED: should not find non-existent layer"
    );
}
/// CAL-04: `require_stats` on an unregistered layer fails with
/// `PruningError::MissingActivationStats` carrying the requested layer name.
#[test]
fn test_calibration_context_missing_stats_error() {
    let ctx = CalibrationContext::new("test".to_string());
    let outcome = ctx.require_stats("missing_layer");
    assert!(
        outcome.is_err(),
        "CAL-04 FALSIFIED: should error on missing stats"
    );

    let err = outcome.unwrap_err();
    if let PruningError::MissingActivationStats { layer } = err {
        assert_eq!(
            layer, "missing_layer",
            "CAL-04 FALSIFIED: error should contain layer name"
        );
    } else {
        panic!("CAL-04 FALSIFIED: expected MissingActivationStats error");
    }
}
/// CAL-05: a single one-row batch `[3.0, 4.0]` yields a count of 1 and
/// per-feature norms of `[3.0, 4.0]`.
#[test]
fn test_activation_stats_single_sample() {
    let mut stats = ActivationStats::new(2);
    stats.update(&Tensor::new(&[3.0, 4.0], &[1, 2]));
    assert_eq!(stats.count, 1, "CAL-05 FALSIFIED: count should be 1");

    let tol = 1e-5;
    let norms = stats.input_norms.data();

    let first = norms[0];
    assert!(
        (first - 3.0).abs() < tol,
        "CAL-05 FALSIFIED: norm[0] should be 3.0, got {}",
        first
    );

    let second = norms[1];
    assert!(
        (second - 4.0).abs() < tol,
        "CAL-05 FALSIFIED: norm[1] should be 4.0, got {}",
        second
    );
}
/// CAL-06: `count` accumulates the number of rows across successive batches
/// (2 rows, then 1 more, for a total of 3).
#[test]
fn test_activation_stats_multi_batch() {
    let mut stats = ActivationStats::new(2);

    // Each step: the batch to feed, the expected running count, the failure message.
    let steps = [
        (
            Tensor::new(&[1.0, 2.0, 3.0, 4.0], &[2, 2]),
            2,
            "CAL-06 FALSIFIED: count should be 2 after first batch",
        ),
        (
            Tensor::new(&[5.0, 6.0], &[1, 2]),
            3,
            "CAL-06 FALSIFIED: count should be 3 after second batch",
        ),
    ];

    for (batch, expected, message) in &steps {
        stats.update(batch);
        assert_eq!(stats.count, *expected, "{}", message);
    }
}
/// CAL-07: a new context remembers its dataset name and starts with zero samples.
#[test]
fn test_calibration_context_dataset() {
    let name = "c4";
    let ctx = CalibrationContext::new(name.to_owned());

    assert_eq!(
        ctx.dataset, name,
        "CAL-07 FALSIFIED: dataset should be 'c4'"
    );
    assert_eq!(
        ctx.num_samples, 0,
        "CAL-07 FALSIFIED: initial samples should be 0"
    );
}
/// CAL-08: `has_stats` reports false before and true after a layer is registered.
#[test]
fn test_calibration_context_has_stats() {
    let layer = "layer0";
    let mut ctx = CalibrationContext::new("test".to_string());

    let before = ctx.has_stats(layer);
    assert!(!before, "CAL-08 FALSIFIED: should not have stats initially");

    ctx.add_layer_stats(layer.to_string(), ActivationStats::new(10));

    let after = ctx.has_stats(layer);
    assert!(after, "CAL-08 FALSIFIED: should have stats after adding");
}
/// CAL-09: `layer_names` lists every registered layer exactly once.
#[test]
fn test_calibration_context_layer_names() {
    let mut ctx = CalibrationContext::new("test".to_string());
    for (name, width) in [("layer_a", 10), ("layer_b", 20)] {
        ctx.add_layer_stats(name.to_string(), ActivationStats::new(width));
    }

    let names = ctx.layer_names();
    assert_eq!(
        names.len(),
        2,
        "CAL-09 FALSIFIED: should have 2 layer names"
    );

    let expectations = [
        ("layer_a", "CAL-09 FALSIFIED: should contain layer_a"),
        ("layer_b", "CAL-09 FALSIFIED: should contain layer_b"),
    ];
    for (wanted, message) in expectations {
        assert!(names.contains(&wanted), "{}", message);
    }
}
/// CAL-10: `increment_samples` adds to the running sample total (10, then +5 = 15).
#[test]
fn test_calibration_context_increment_samples() {
    let mut ctx = CalibrationContext::new("test".to_string());

    // (amount to add, expected running total, failure message)
    let plan = [
        (10, 10, "CAL-10 FALSIFIED: should be 10"),
        (5, 15, "CAL-10 FALSIFIED: should be 15"),
    ];
    for (delta, total, message) in plan {
        ctx.increment_samples(delta);
        assert_eq!(ctx.num_samples, total, "{}", message);
    }
}
/// CAL-11: `input_features` reports the width given to the constructor.
#[test]
fn test_activation_stats_input_features() {
    let width = ActivationStats::new(128).input_features();
    assert_eq!(
        width, 128,
        "CAL-11 FALSIFIED: input_features should be 128"
    );
}
/// CAL-12: `get_stats_mut` gives write access to a registered layer's stats
/// and returns `None` for an unknown layer.
#[test]
fn test_calibration_context_get_stats_mut() {
    let mut ctx = CalibrationContext::new("test".to_string());
    ctx.add_layer_stats("layer0".to_string(), ActivationStats::new(10));

    // Write through the mutable handle; the borrow ends with this statement.
    ctx.get_stats_mut("layer0").unwrap().count = 42;

    let observed = ctx.get_stats("layer0").unwrap().count;
    assert_eq!(
        observed, 42,
        "CAL-12 FALSIFIED: Mutable access should allow modification"
    );

    let absent = ctx.get_stats_mut("nonexistent");
    assert!(
        absent.is_none(),
        "CAL-12 FALSIFIED: Non-existent layer should return None"
    );
}
/// CAL-13: cloning `ActivationStats` after an update preserves both the
/// sample count and the accumulated norms.
#[test]
fn test_activation_stats_clone() {
    let mut original = ActivationStats::new(4);
    original.update(&Tensor::new(&[1.0, 2.0, 3.0, 4.0], &[1, 4]));
    let copy = original.clone();

    assert_eq!(
        original.count, copy.count,
        "CAL-13 FALSIFIED: Clone should preserve count"
    );

    let (lhs, rhs) = (original.input_norms.data(), copy.input_norms.data());
    assert_eq!(
        lhs, rhs,
        "CAL-13 FALSIFIED: Clone should preserve input_norms"
    );
}
/// CAL-14: cloning a `CalibrationContext` preserves its sample count, its
/// dataset name, and its registered per-layer stats.
#[test]
fn test_calibration_context_clone() {
    let mut ctx = CalibrationContext::new("test".to_string());
    ctx.add_layer_stats("layer0".to_string(), ActivationStats::new(10));
    ctx.increment_samples(5);
    let cloned = ctx.clone();
    assert_eq!(
        ctx.num_samples, cloned.num_samples,
        "CAL-14 FALSIFIED: Clone should preserve sample count"
    );
    assert_eq!(
        ctx.dataset, cloned.dataset,
        "CAL-14 FALSIFIED: Clone should preserve dataset name"
    );
    // Previously the only assertion in the module without a FALSIFIED message.
    assert!(
        cloned.has_stats("layer0"),
        "CAL-14 FALSIFIED: Clone should preserve layer stats"
    );
    // The key alone surviving is not enough: the cloned layer's stats should
    // still report the width they were created with.
    assert_eq!(
        cloned.get_stats("layer0").map(|s| s.input_features()),
        Some(10),
        "CAL-14 FALSIFIED: Clone should preserve layer stats contents"
    );
}
/// CAL-15: the `Debug` rendering of a context names the type and includes the
/// dataset string.
#[test]
fn test_calibration_context_debug() {
    let rendered = format!("{:?}", CalibrationContext::new("debug_test".to_string()));

    let expectations = [
        (
            "CalibrationContext",
            "CAL-15 FALSIFIED: Debug should show type name",
        ),
        ("debug_test", "CAL-15 FALSIFIED: Debug should show dataset"),
    ];
    for (needle, message) in expectations {
        assert!(rendered.contains(needle), "{}", message);
    }
}
/// CAL-16: the `Debug` rendering of `ActivationStats` names the type.
#[test]
fn test_activation_stats_debug() {
    let rendered = format!("{:?}", ActivationStats::new(4));
    assert!(
        rendered.contains("ActivationStats"),
        "CAL-16 FALSIFIED: Debug should show type name"
    );
}