//! aprender-core 0.29.2
//!
//! Next-generation machine learning library in pure Rust
pub(crate) use super::*;

// ==========================================================================
// FALSIFICATION: ActivationStats initialization
// ==========================================================================
#[test]
fn test_activation_stats_new() {
    // CAL-01: a freshly constructed ActivationStats must start empty —
    // zero observed samples and an all-zero norm buffer of the requested width.
    let fresh = ActivationStats::new(4);

    assert_eq!(fresh.count, 0, "CAL-01 FALSIFIED: initial count should be 0");

    let norms = fresh.input_norms.data();
    assert_eq!(
        norms.len(),
        4,
        "CAL-01 FALSIFIED: input_norms should have 4 elements"
    );
    assert!(
        norms.iter().all(|&v| v == 0.0),
        "CAL-01 FALSIFIED: initial norms should be all zeros"
    );
}

// ==========================================================================
// FALSIFICATION: Empty batch doesn't corrupt stats
// ==========================================================================
#[test]
fn test_activation_stats_empty_batch_noop() {
    // CAL-02: feeding a zero-row batch must leave previously accumulated
    // statistics completely untouched (no count bump, no norm corruption).
    let mut stats = ActivationStats::new(4);

    // Seed the accumulator with one real sample, then snapshot its state.
    stats.update(&Tensor::new(&[1.0, 2.0, 3.0, 4.0], &[1, 4]));
    let saved_count = stats.count;
    let saved_norms = stats.input_norms.data().to_vec();

    // Push an empty batch (batch dimension of zero).
    stats.update(&Tensor::new(&[], &[0, 4]));

    assert_eq!(
        stats.count, saved_count,
        "CAL-02 FALSIFIED: empty batch should not change count"
    );
    assert_eq!(
        stats.input_norms.data(),
        &saved_norms[..],
        "CAL-02 FALSIFIED: empty batch should not change norms"
    );
}

// ==========================================================================
// FALSIFICATION: CalibrationContext layer lookup
// ==========================================================================
#[test]
fn test_calibration_context_get_stats() {
    // CAL-03: get_stats returns Some for a registered layer name and
    // None for anything that was never added.
    let mut ctx = CalibrationContext::new("test_dataset".to_string());
    ctx.add_layer_stats("model.layer0".to_string(), ActivationStats::new(512));

    assert!(
        ctx.get_stats("model.layer0").is_some(),
        "CAL-03 FALSIFIED: should find added layer"
    );
    assert!(
        ctx.get_stats("nonexistent").is_none(),
        "CAL-03 FALSIFIED: should not find non-existent layer"
    );
}

// ==========================================================================
// FALSIFICATION: Missing stats returns error
// ==========================================================================
#[test]
fn test_calibration_context_missing_stats_error() {
    // CAL-04: require_stats on an unknown layer must fail, and the error
    // variant must carry the requested layer name back to the caller.
    let ctx = CalibrationContext::new("test".to_string());

    let result = ctx.require_stats("missing_layer");
    assert!(
        result.is_err(),
        "CAL-04 FALSIFIED: should error on missing stats"
    );

    if let PruningError::MissingActivationStats { layer } = result.unwrap_err() {
        assert_eq!(
            layer, "missing_layer",
            "CAL-04 FALSIFIED: error should contain layer name"
        );
    } else {
        panic!("CAL-04 FALSIFIED: expected MissingActivationStats error");
    }
}

// ==========================================================================
// FALSIFICATION: Input norm computation (single sample)
// ==========================================================================
#[test]
fn test_activation_stats_single_sample() {
    // CAL-05: with exactly one observed sample the running RMS statistic
    // reduces to the absolute value of each feature:
    //   norm[i] = sqrt(x[i]^2 / 1) = |x[i]|
    // so the sample [3.0, 4.0] must yield norms [3.0, 4.0].
    let mut stats = ActivationStats::new(2);
    stats.update(&Tensor::new(&[3.0, 4.0], &[1, 2]));

    assert_eq!(stats.count, 1, "CAL-05 FALSIFIED: count should be 1");

    let norms = stats.input_norms.data();
    let delta0 = (norms[0] - 3.0).abs();
    assert!(
        delta0 < 1e-5,
        "CAL-05 FALSIFIED: norm[0] should be 3.0, got {}",
        norms[0]
    );
    let delta1 = (norms[1] - 4.0).abs();
    assert!(
        delta1 < 1e-5,
        "CAL-05 FALSIFIED: norm[1] should be 4.0, got {}",
        norms[1]
    );
}

// ==========================================================================
// FALSIFICATION: Multi-batch Welford update
// ==========================================================================
#[test]
fn test_activation_stats_multi_batch() {
    // CAL-06: successive batches of different sizes must accumulate into a
    // single running sample count (2 samples, then 1 more => 3 total), and
    // the accumulated statistic must actually reflect the observed data.
    let mut stats = ActivationStats::new(2);

    // Batch 1: [[1, 2], [3, 4]] - 2 samples
    let batch1 = Tensor::new(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
    stats.update(&batch1);
    assert_eq!(
        stats.count, 2,
        "CAL-06 FALSIFIED: count should be 2 after first batch"
    );

    // Batch 2: [[5, 6]] - 1 sample
    let batch2 = Tensor::new(&[5.0, 6.0], &[1, 2]);
    stats.update(&batch2);
    assert_eq!(
        stats.count, 3,
        "CAL-06 FALSIFIED: count should be 3 after second batch"
    );

    // Previously this test only checked counts despite claiming to falsify
    // the multi-batch merge. Guard the statistic itself: after strictly
    // non-zero inputs, every per-feature norm must be finite and positive.
    // (Exact single-sample RMS values are pinned by CAL-05; this check is
    // deliberately implementation-agnostic about the exact merge formula.)
    assert!(
        stats
            .input_norms
            .data()
            .iter()
            .all(|&v| v.is_finite() && v > 0.0),
        "CAL-06 FALSIFIED: norms should be finite and positive after updates"
    );
}

// ==========================================================================
// FALSIFICATION: CalibrationContext dataset name
// ==========================================================================
#[test]
fn test_calibration_context_dataset() {
    // CAL-07: the constructor stores the dataset name verbatim and the
    // sample counter starts at zero.
    let ctx = CalibrationContext::new("c4".to_string());

    assert_eq!(ctx.dataset, "c4", "CAL-07 FALSIFIED: dataset should be 'c4'");
    assert_eq!(
        ctx.num_samples, 0,
        "CAL-07 FALSIFIED: initial samples should be 0"
    );
}

// ==========================================================================
// FALSIFICATION: has_stats helper
// ==========================================================================
#[test]
fn test_calibration_context_has_stats() {
    // CAL-08: has_stats flips from false to true once the layer is registered.
    let mut ctx = CalibrationContext::new("test".to_string());

    let before = ctx.has_stats("layer0");
    assert!(
        !before,
        "CAL-08 FALSIFIED: should not have stats initially"
    );

    ctx.add_layer_stats("layer0".to_string(), ActivationStats::new(10));

    let after = ctx.has_stats("layer0");
    assert!(
        after,
        "CAL-08 FALSIFIED: should have stats after adding"
    );
}

// ==========================================================================
// FALSIFICATION: layer_names helper
// ==========================================================================
#[test]
fn test_calibration_context_layer_names() {
    // CAL-09: layer_names reports every registered layer, no more, no less.
    let mut ctx = CalibrationContext::new("test".to_string());
    for (layer, width) in [("layer_a", 10), ("layer_b", 20)] {
        ctx.add_layer_stats(layer.to_string(), ActivationStats::new(width));
    }

    let names = ctx.layer_names();
    assert_eq!(
        names.len(),
        2,
        "CAL-09 FALSIFIED: should have 2 layer names"
    );
    assert!(
        names.contains(&"layer_a"),
        "CAL-09 FALSIFIED: should contain layer_a"
    );
    assert!(
        names.contains(&"layer_b"),
        "CAL-09 FALSIFIED: should contain layer_b"
    );
}

// ==========================================================================
// FALSIFICATION: increment_samples
// ==========================================================================
#[test]
fn test_calibration_context_increment_samples() {
    // CAL-10: increment_samples accumulates into num_samples rather than
    // overwriting it (10, then +5 => 15).
    let mut ctx = CalibrationContext::new("test".to_string());

    ctx.increment_samples(10);
    let after_first = ctx.num_samples;
    assert_eq!(after_first, 10, "CAL-10 FALSIFIED: should be 10");

    ctx.increment_samples(5);
    let after_second = ctx.num_samples;
    assert_eq!(after_second, 15, "CAL-10 FALSIFIED: should be 15");
}

// ==========================================================================
// FALSIFICATION: input_features helper
// ==========================================================================
#[test]
fn test_activation_stats_input_features() {
    // CAL-11: input_features mirrors the width that was passed to new().
    let width = 128;
    let stats = ActivationStats::new(width);
    assert_eq!(
        stats.input_features(),
        width,
        "CAL-11 FALSIFIED: input_features should be 128"
    );
}

// ==========================================================================
// FALSIFICATION: get_stats_mut helper
// ==========================================================================
#[test]
fn test_calibration_context_get_stats_mut() {
    // CAL-12: writes made through the mutable accessor must be visible via
    // the immutable one; lookups for unknown layers must return None.
    let mut ctx = CalibrationContext::new("test".to_string());
    ctx.add_layer_stats("layer0".to_string(), ActivationStats::new(10));

    // Mutate the stored stats through the mutable handle.
    let handle = ctx.get_stats_mut("layer0").unwrap();
    handle.count = 42;

    // Read it back through the shared accessor.
    let observed = ctx.get_stats("layer0").unwrap().count;
    assert_eq!(
        observed, 42,
        "CAL-12 FALSIFIED: Mutable access should allow modification"
    );

    assert!(
        ctx.get_stats_mut("nonexistent").is_none(),
        "CAL-12 FALSIFIED: Non-existent layer should return None"
    );
}

// ==========================================================================
// FALSIFICATION: Clone and Debug traits
// ==========================================================================
#[test]
fn test_activation_stats_clone() {
    // CAL-13: Clone must preserve both the sample count and the norm buffer.
    let mut original = ActivationStats::new(4);
    original.update(&Tensor::new(&[1.0, 2.0, 3.0, 4.0], &[1, 4]));

    let copy = original.clone();
    assert_eq!(
        original.count, copy.count,
        "CAL-13 FALSIFIED: Clone should preserve count"
    );
    assert_eq!(
        original.input_norms.data(),
        copy.input_norms.data(),
        "CAL-13 FALSIFIED: Clone should preserve input_norms"
    );
}

#[test]
fn test_calibration_context_clone() {
    // CAL-14: Clone must preserve the sample counter, the dataset name, and
    // every registered layer's statistics.
    let mut ctx = CalibrationContext::new("test".to_string());
    ctx.add_layer_stats("layer0".to_string(), ActivationStats::new(10));
    ctx.increment_samples(5);

    let cloned = ctx.clone();
    assert_eq!(
        ctx.num_samples, cloned.num_samples,
        "CAL-14 FALSIFIED: Clone should preserve sample count"
    );
    assert_eq!(
        ctx.dataset, cloned.dataset,
        "CAL-14 FALSIFIED: Clone should preserve dataset name"
    );
    // This assertion previously carried no falsification message, unlike
    // every other check in the suite.
    assert!(
        cloned.has_stats("layer0"),
        "CAL-14 FALSIFIED: Clone should preserve layer stats"
    );
}

#[test]
fn test_calibration_context_debug() {
    // CAL-15: Debug output must name the type and include the dataset field.
    let ctx = CalibrationContext::new("debug_test".to_string());
    let rendered = format!("{:?}", ctx);
    assert!(
        rendered.contains("CalibrationContext"),
        "CAL-15 FALSIFIED: Debug should show type name"
    );
    assert!(
        rendered.contains("debug_test"),
        "CAL-15 FALSIFIED: Debug should show dataset"
    );
}

#[test]
fn test_activation_stats_debug() {
    // CAL-16: Debug formatting must at minimum expose the type name.
    let stats = ActivationStats::new(4);
    let rendered = format!("{:?}", stats);
    assert!(
        rendered.contains("ActivationStats"),
        "CAL-16 FALSIFIED: Debug should show type name"
    );
}