pub(crate) use super::*;
/// MSK-01: a 2:4 N:M pattern (N < M) is a valid configuration.
#[test]
fn test_sparsity_pattern_nm_valid() {
    let two_of_four = SparsityPattern::NM { n: 2, m: 4 };
    assert!(two_of_four.is_valid(), "MSK-01 FALSIFIED: 2:4 pattern should be valid");
}
/// MSK-02: the boundary case N == M is still accepted as a valid pattern.
#[test]
fn test_sparsity_pattern_nm_n_equals_m_valid() {
    let valid = SparsityPattern::NM { n: 4, m: 4 }.is_valid();
    assert!(valid, "MSK-02 FALSIFIED: N=M pattern should be valid");
}
/// MSK-03: an N:M pattern with N greater than M must be reported as invalid.
#[test]
fn test_sparsity_pattern_nm_invalid_n_gt_m() {
    let oversubscribed = SparsityPattern::NM { n: 5, m: 4 };
    let valid = oversubscribed.is_valid();
    assert!(!valid, "MSK-03 FALSIFIED: N>M pattern should be invalid");
}
/// MSK-04: the lower boundary N = 0 is accepted as a valid N:M pattern.
#[test]
fn test_sparsity_pattern_nm_n_zero_valid() {
    let zero_of_four = SparsityPattern::NM { n: 0, m: 4 };
    assert!(zero_of_four.is_valid(), "MSK-04 FALSIFIED: N=0 pattern should be valid");
}
/// MSK-05: a 2x2 block pattern (non-zero height and width) is valid.
#[test]
fn test_sparsity_pattern_block_valid() {
    let two_by_two = SparsityPattern::Block { height: 2, width: 2 };
    let valid = two_by_two.is_valid();
    assert!(valid, "MSK-05 FALSIFIED: block pattern should be valid");
}
/// MSK-06: a block pattern whose height is zero must be reported as invalid.
#[test]
fn test_sparsity_pattern_block_invalid_zero() {
    let degenerate = SparsityPattern::Block { height: 0, width: 2 };
    assert!(!degenerate.is_valid(), "MSK-06 FALSIFIED: zero height should be invalid");
}
/// MSK-07: the theoretical sparsity of an N:M pattern matches the expected
/// fraction for 2:4 (0.5) and 1:4 (0.75), within a 1e-6 tolerance.
#[test]
fn test_sparsity_pattern_theoretical_sparsity() {
    let cases = [
        (2, 0.5, "MSK-07 FALSIFIED: 2:4 has 50% sparsity"),
        (1, 0.75, "MSK-07 FALSIFIED: 1:4 has 75% sparsity"),
    ];
    for &(kept, expected, message) in cases.iter() {
        let actual = SparsityPattern::NM { n: kept, m: 4 }
            .theoretical_sparsity()
            .unwrap();
        assert!((actual - expected).abs() < 1e-6, "{}", message);
    }
}
/// MSK-08: a tensor containing only 0.0 and 1.0 is accepted as an unstructured mask.
#[test]
fn test_sparsity_mask_accepts_binary() {
    let outcome = SparsityMask::new(
        Tensor::new(&[1.0, 0.0, 1.0, 0.0], &[4]),
        SparsityPattern::Unstructured,
    );
    assert!(outcome.is_ok(), "MSK-08 FALSIFIED: binary mask should be accepted");
}
/// MSK-09: a fractional entry (0.5) makes mask construction fail, and the
/// failure must be reported as `PruningError::InvalidMask`.
#[test]
fn test_sparsity_mask_rejects_non_binary() {
    let fractional = Tensor::new(&[1.0, 0.5, 1.0, 0.0], &[4]);
    match SparsityMask::new(fractional, SparsityPattern::Unstructured) {
        Ok(_) => panic!("MSK-09 FALSIFIED: non-binary mask should be rejected"),
        Err(PruningError::InvalidMask { .. }) => {}
        Err(_) => panic!("MSK-09 FALSIFIED: Expected InvalidMask error"),
    }
}
/// MSK-10: a negative entry (-1.0) is not binary, so mask construction must fail.
#[test]
fn test_sparsity_mask_rejects_negative_values() {
    let outcome = SparsityMask::new(
        Tensor::new(&[1.0, -1.0, 1.0, 0.0], &[4]),
        SparsityPattern::Unstructured,
    );
    assert!(outcome.is_err(), "MSK-10 FALSIFIED: negative values should be rejected");
}
/// MSK-11: a mask with two of four entries zeroed reports sparsity 0.5.
#[test]
fn test_sparsity_mask_computes_sparsity_correctly() {
    let half_zero = Tensor::new(&[1.0, 0.0, 1.0, 0.0], &[4]);
    let observed = SparsityMask::new(half_zero, SparsityPattern::Unstructured)
        .unwrap()
        .sparsity();
    assert!((observed - 0.5).abs() < 1e-6, "MSK-11 FALSIFIED: sparsity should be 0.5");
}
/// MSK-12: a mask with every entry set to 1.0 reports zero sparsity.
#[test]
fn test_sparsity_mask_all_ones_zero_sparsity() {
    let mask = SparsityMask::new(Tensor::new(&[1.0; 4], &[4]), SparsityPattern::Unstructured)
        .unwrap();
    // `x - 0.0` is exactly `x`, so comparing `|x|` to the tolerance is equivalent.
    let observed = mask.sparsity();
    assert!(observed.abs() < 1e-6, "MSK-12 FALSIFIED: all-ones mask has 0% sparsity");
}
/// MSK-13: a mask with every entry set to 0.0 reports sparsity 1.0.
#[test]
fn test_sparsity_mask_all_zeros_full_sparsity() {
    let mask = SparsityMask::new(Tensor::new(&[0.0; 4], &[4]), SparsityPattern::Unstructured)
        .unwrap();
    let observed = mask.sparsity();
    assert!(
        (observed - 1.0).abs() < 1e-6,
        "MSK-13 FALSIFIED: all-zeros mask has 100% sparsity"
    );
}
/// MSK-14: applying a 2x2 mask to same-shaped weights succeeds and zeroes
/// exactly the positions where the mask is 0.0.
#[test]
fn test_mask_apply_correct_shape() {
    let mask = SparsityMask::new(
        Tensor::new(&[1.0, 0.0, 1.0, 0.0], &[2, 2]),
        SparsityPattern::Unstructured,
    )
    .unwrap();
    let mut weights = Tensor::new(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

    let applied = mask.apply(&mut weights);
    assert!(applied.is_ok(), "MSK-14 FALSIFIED: should apply successfully");

    let expected = &[1.0, 0.0, 3.0, 0.0];
    assert_eq!(weights.data(), expected, "MSK-14 FALSIFIED: weights should be masked");
}
/// MSK-15: applying a length-4 mask to length-3 weights fails with
/// `PruningError::ShapeMismatch`, reporting the mask shape as `expected`
/// and the weight shape as `got`.
#[test]
fn test_mask_apply_wrong_shape_fails() {
    let mask = SparsityMask::new(
        Tensor::new(&[1.0, 0.0, 1.0, 0.0], &[4]),
        SparsityPattern::Unstructured,
    )
    .unwrap();
    let mut too_short = Tensor::new(&[1.0, 2.0, 3.0], &[3]);

    match mask.apply(&mut too_short) {
        Ok(_) => panic!("MSK-15 FALSIFIED: wrong shape should fail"),
        Err(PruningError::ShapeMismatch { expected, got }) => {
            assert_eq!(expected, vec![4]);
            assert_eq!(got, vec![3]);
        }
        Err(_) => panic!("MSK-15 FALSIFIED: Expected ShapeMismatch error"),
    }
}
/// MSK-16: applying a mask twice leaves the weights identical to applying it once.
///
/// The single-application result is also pinned to the expected masked
/// values. Otherwise a no-op `apply` would trivially satisfy the
/// idempotence comparison.
#[test]
fn test_mask_apply_idempotent() {
    let mask_data = Tensor::new(&[1.0, 0.0, 1.0, 0.0], &[4]);
    let mask = SparsityMask::new(mask_data, SparsityPattern::Unstructured).unwrap();
    let mut applied_twice = Tensor::new(&[1.0, 2.0, 3.0, 4.0], &[4]);
    let mut applied_once = Tensor::new(&[1.0, 2.0, 3.0, 4.0], &[4]);

    mask.apply(&mut applied_twice).unwrap();
    mask.apply(&mut applied_twice).unwrap();
    mask.apply(&mut applied_once).unwrap();

    assert_eq!(
        applied_once.data(),
        &[1.0, 0.0, 3.0, 0.0],
        "MSK-16 FALSIFIED: a single application should zero the masked weights"
    );
    assert_eq!(
        applied_twice.data(),
        applied_once.data(),
        "MSK-16 FALSIFIED: mask application should be idempotent"
    );
}
/// MSK-17: an 8-element mask whose two groups of four each keep exactly two
/// entries is accepted under a 2:4 pattern.
#[test]
fn test_nm_mask_validates_structure() {
    let two_per_group = Tensor::new(
        &[
            1.0, 1.0, 0.0, 0.0, // group 0: two kept
            0.0, 1.0, 1.0, 0.0, // group 1: two kept
        ],
        &[8],
    );
    let outcome = SparsityMask::new(two_per_group, SparsityPattern::NM { n: 2, m: 4 });
    assert!(outcome.is_ok(), "MSK-17 FALSIFIED: valid 2:4 mask should be accepted");
}
/// MSK-18: a mask whose first group of four keeps three entries violates 2:4
/// and must be rejected.
#[test]
fn test_nm_mask_rejects_invalid_structure() {
    let overfull_group = Tensor::new(
        &[
            1.0, 1.0, 1.0, 0.0, // group 0: three kept (violates 2:4)
            0.0, 1.0, 1.0, 0.0, // group 1: two kept
        ],
        &[8],
    );
    let outcome = SparsityMask::new(overfull_group, SparsityPattern::NM { n: 2, m: 4 });
    assert!(
        outcome.is_err(),
        "MSK-18 FALSIFIED: invalid N:M structure should be rejected"
    );
}
/// MSK-19: requesting 50% unstructured sparsity over four scores produces a
/// mask whose measured sparsity is 0.5 (within 1e-6).
#[test]
fn test_generate_unstructured_mask_basic() {
    let scores = Tensor::new(&[0.1, 0.4, 0.2, 0.3], &[4]);
    let achieved = generate_unstructured_mask(&scores, 0.5).unwrap().sparsity();
    assert!(
        (achieved - 0.5).abs() < 1e-6,
        "MSK-19 FALSIFIED: should achieve ~50% sparsity"
    );
}
/// MSK-20: a 0.0 sparsity target produces a mask that prunes nothing.
#[test]
fn test_generate_unstructured_mask_zero_sparsity() {
    let scores = Tensor::new(&[0.1, 0.2, 0.3, 0.4], &[4]);
    let achieved = generate_unstructured_mask(&scores, 0.0).unwrap().sparsity();
    // `x - 0.0` is exactly `x`, so the tolerance check reduces to `|x|`.
    assert!(achieved.abs() < 1e-6, "MSK-20 FALSIFIED: 0% sparsity should keep all");
}
/// MSK-21: a 1.0 sparsity target produces a mask that prunes every entry.
#[test]
fn test_generate_unstructured_mask_full_sparsity() {
    let scores = Tensor::new(&[0.1, 0.2, 0.3, 0.4], &[4]);
    let achieved = generate_unstructured_mask(&scores, 1.0).unwrap().sparsity();
    assert!(
        (achieved - 1.0).abs() < 1e-6,
        "MSK-21 FALSIFIED: 100% sparsity should prune all"
    );
}
/// MSK-22: sparsity targets outside [0, 1] (here -0.1 and 1.1) are rejected.
#[test]
fn test_generate_unstructured_mask_invalid_sparsity() {
    let scores = Tensor::new(&[0.1, 0.2], &[2]);
    let out_of_range = [
        (-0.1, "MSK-22 FALSIFIED: negative sparsity should error"),
        (1.1, "MSK-22 FALSIFIED: sparsity >1 should error"),
    ];
    for &(target, message) in out_of_range.iter() {
        assert!(generate_unstructured_mask(&scores, target).is_err(), "{}", message);
    }
}
/// MSK-23: a generated 2:4 mask over eight scores keeps exactly two entries
/// in each of its two groups of four.
#[test]
fn test_generate_nm_mask_2_4() {
    let scores = Tensor::new(&[0.1, 0.4, 0.2, 0.3, 0.5, 0.1, 0.3, 0.2], &[8]);
    let mask = generate_nm_mask(&scores, 2, 4).unwrap();
    let data = mask.tensor().data();
    // Without this length check, an empty mask would skip the loop below and
    // pass vacuously, and a truncated mask would only check a short tail group.
    assert_eq!(
        data.len(),
        8,
        "MSK-23 FALSIFIED: mask should cover all 8 scores"
    );
    for group in data.chunks_exact(4) {
        let ones = group.iter().filter(|&&v| v > 0.5).count();
        assert_eq!(
            ones, 2,
            "MSK-23 FALSIFIED: each group should have 2 non-zeros"
        );
    }
}
/// MSK-24: within a single group of four, a 2:4 mask keeps the two highest
/// scores (0.4 and 0.3) and prunes the two lowest (0.1 and 0.2).
#[test]
fn test_generate_nm_mask_keeps_top_n() {
    let scores = Tensor::new(&[0.1, 0.4, 0.2, 0.3], &[4]);
    let mask = generate_nm_mask(&scores, 2, 4).unwrap();
    let data = mask.tensor().data();
    let expectations = [
        (0.0, "MSK-24 FALSIFIED: 0.1 should be pruned"),
        (1.0, "MSK-24 FALSIFIED: 0.4 should be kept"),
        (0.0, "MSK-24 FALSIFIED: 0.2 should be pruned"),
        (1.0, "MSK-24 FALSIFIED: 0.3 should be kept"),
    ];
    for (position, &(want, message)) in expectations.iter().enumerate() {
        assert_eq!(data[position], want, "{}", message);
    }
}
/// MSK-25: asking for 5 kept entries per group of 4 is impossible and must error.
#[test]
fn test_generate_nm_mask_invalid_n_gt_m() {
    let scores = Tensor::new(&[0.1, 0.2, 0.3, 0.4], &[4]);
    let failed = generate_nm_mask(&scores, 5, 4).is_err();
    assert!(failed, "MSK-25 FALSIFIED: N>M should error");
}
/// MSK-26: a score tensor whose length (3) is not a multiple of M (4) is rejected.
#[test]
fn test_generate_nm_mask_not_divisible() {
    let ragged = Tensor::new(&[0.1, 0.2, 0.3], &[3]);
    let outcome = generate_nm_mask(&ragged, 2, 4);
    assert!(
        outcome.is_err(),
        "MSK-26 FALSIFIED: non-divisible length should error"
    );
}
/// MSK-27: for a five-element mask with three ones, `nnz` is 3 and
/// `num_zeros` is 2.
#[test]
fn test_mask_nnz() {
    let three_ones = Tensor::new(&[1.0, 0.0, 1.0, 0.0, 1.0], &[5]);
    let mask = SparsityMask::new(three_ones, SparsityPattern::Unstructured).unwrap();

    let kept = mask.nnz();
    assert_eq!(kept, 3, "MSK-27 FALSIFIED: nnz should be 3");

    let pruned = mask.num_zeros();
    assert_eq!(pruned, 2, "MSK-27 FALSIFIED: num_zeros should be 2");
}
/// MSK-28: `SparsityMask::dense(&[3, 4])` has shape [3, 4], exactly zero
/// sparsity, and all 12 entries non-zero.
#[test]
fn test_dense_mask() {
    let dense = SparsityMask::dense(&[3, 4]);
    assert_eq!(dense.shape(), &[3, 4], "MSK-28 FALSIFIED: shape should match");
    assert_eq!(dense.sparsity(), 0.0, "MSK-28 FALSIFIED: dense mask has 0 sparsity");
    assert_eq!(dense.nnz(), 12, "MSK-28 FALSIFIED: all elements non-zero");
}
/// MSK-29: on a 4x4 score grid made of four constant 2x2 blocks (scores 1-4),
/// a 50% block-sparsity target yields a mask with half its entries pruned.
#[test]
fn test_generate_block_mask_basic() {
    let scores = Tensor::new(
        &[
            1.0, 1.0, 2.0, 2.0, // row 0
            1.0, 1.0, 2.0, 2.0, // row 1
            3.0, 3.0, 4.0, 4.0, // row 2
            3.0, 3.0, 4.0, 4.0, // row 3
        ],
        &[4, 4],
    );
    let achieved = generate_block_mask(&scores, 2, 2, 0.5).unwrap().sparsity();
    assert!(
        (achieved - 0.5).abs() < 1e-6,
        "MSK-29 FALSIFIED: should achieve 50% sparsity"
    );
}
/// MSK-30: with uniform scores on a 4x4 grid, every 2x2 block of the generated
/// mask holds a single value, i.e. whole blocks are either kept or pruned.
#[test]
fn test_generate_block_mask_uniform_blocks() {
    // Geometry of the row-major 4x4 grid split into 2x2 blocks.
    const BLOCK: usize = 2;
    const COLS: usize = 4;

    let scores = Tensor::new(&[1.0; 16], &[4, 4]);
    let mask = generate_block_mask(&scores, 2, 2, 0.5).unwrap();
    let data = mask.tensor().data();

    for br in 0..COLS / BLOCK {
        for bc in 0..COLS / BLOCK {
            let origin = data[(br * BLOCK) * COLS + bc * BLOCK];
            let uniform = (0..BLOCK)
                .flat_map(|r| (0..BLOCK).map(move |c| (r, c)))
                .all(|(r, c)| {
                    let flat = (br * BLOCK + r) * COLS + (bc * BLOCK + c);
                    (data[flat] - origin).abs() < 1e-6
                });
            assert!(
                uniform,
                "MSK-30 FALSIFIED: block ({}, {}) should be uniform",
                br, bc
            );
        }
    }
}
// Additional test modules, loaded from sibling files via explicit `#[path]`
// attributes rather than the default module-file lookup.
#[path = "tests_block_mask.rs"]
mod tests_block_mask;
#[path = "tests_column_validate.rs"]
mod tests_column_validate;