use super::*;
/// MSK-31: `generate_block_mask` called with `2, 2, 0.5` on a `[3, 5]` score
/// tensor must return an error.
#[test]
fn test_generate_block_mask_not_divisible() {
    let scores = Tensor::new(&[1.0; 15], &[3, 5]);
    assert!(
        generate_block_mask(&scores, 2, 2, 0.5).is_err(),
        "MSK-31 FALSIFIED: non-divisible shape should error"
    );
}
/// MSK-32: `generate_block_mask` must return an error when the score tensor
/// has the one-dimensional shape `[4]`.
#[test]
fn test_generate_block_mask_1d_rejected() {
    let scores = Tensor::new(&[1.0, 2.0, 3.0, 4.0], &[4]);
    assert!(
        generate_block_mask(&scores, 2, 2, 0.5).is_err(),
        "MSK-32 FALSIFIED: 1D tensor should be rejected"
    );
}
/// MSK-33: `generate_block_mask` must return an error for a sparsity of
/// `-0.1` and for a sparsity of `1.1`.
#[test]
fn test_generate_block_mask_invalid_sparsity() {
    let scores = Tensor::new(&[1.0; 16], &[4, 4]);

    let below_range = generate_block_mask(&scores, 2, 2, -0.1);
    assert!(
        below_range.is_err(),
        "MSK-33 FALSIFIED: negative sparsity should error"
    );

    let above_range = generate_block_mask(&scores, 2, 2, 1.1);
    assert!(
        above_range.is_err(),
        "MSK-33 FALSIFIED: sparsity >1 should error"
    );
}
/// MSK-34: a row mask generated at sparsity `0.5` from a `[4, 2]` score tensor
/// must report 50% sparsity, and both entries of every row must be equal.
#[test]
fn test_generate_row_mask_basic() {
    let scores = Tensor::new(
        &[
            1.0, 1.0, // row 0
            2.0, 2.0, // row 1
            3.0, 3.0, // row 2
            4.0, 4.0, // row 3
        ],
        &[4, 2],
    );
    let mask = generate_row_mask(&scores, 0.5).unwrap();

    let deviation = (mask.sparsity() - 0.5).abs();
    assert!(
        deviation < 1e-6,
        "MSK-34 FALSIFIED: should achieve 50% sparsity"
    );

    // Each row occupies two consecutive entries of the flat mask data.
    let data = mask.tensor().data();
    for r in 0..4 {
        let left = data[r * 2];
        let right = data[r * 2 + 1];
        assert!(
            (right - left).abs() < 1e-6,
            "MSK-34 FALSIFIED: row {} should be uniform",
            r
        );
    }
}
/// MSK-35: with scores `[1, 1]` in row 0 and `[10, 10]` in row 1, a row mask
/// at sparsity `0.5` must zero row 0 and keep row 1.
#[test]
fn test_generate_row_mask_keeps_highest() {
    let scores = Tensor::new(
        &[
            1.0, 1.0, // row 0: low scores
            10.0, 10.0, // row 1: high scores
        ],
        &[2, 2],
    );
    let mask = generate_row_mask(&scores, 0.5).unwrap();
    let data = mask.tensor().data();

    // Expected mask value and failure message for each flat index.
    let expected = [
        (0.0, "MSK-35 FALSIFIED: row 0 should be pruned"),
        (0.0, "MSK-35 FALSIFIED: row 0 should be pruned"),
        (1.0, "MSK-35 FALSIFIED: row 1 should be kept"),
        (1.0, "MSK-35 FALSIFIED: row 1 should be kept"),
    ];
    for (idx, &(want, msg)) in expected.iter().enumerate() {
        assert_eq!(data[idx], want, "{}", msg);
    }
}
/// MSK-36: `generate_row_mask` must return an error when the score tensor has
/// the one-dimensional shape `[3]`.
#[test]
fn test_generate_row_mask_1d_rejected() {
    let scores = Tensor::new(&[1.0, 2.0, 3.0], &[3]);
    assert!(
        generate_row_mask(&scores, 0.5).is_err(),
        "MSK-36 FALSIFIED: 1D tensor should be rejected"
    );
}
/// MSK-37: a column mask generated at sparsity `0.5` from a `[2, 4]` score
/// tensor must report 50% sparsity, and both entries of every column must be
/// equal.
#[test]
fn test_generate_column_mask_basic() {
    let scores = Tensor::new(
        &[
            1.0, 2.0, 3.0, 4.0, // row 0
            1.0, 2.0, 3.0, 4.0, // row 1
        ],
        &[2, 4],
    );
    let mask = generate_column_mask(&scores, 0.5).unwrap();

    let deviation = (mask.sparsity() - 0.5).abs();
    assert!(
        deviation < 1e-6,
        "MSK-37 FALSIFIED: should achieve 50% sparsity"
    );

    // Column `c` is at flat index `c` in the first row and `4 + c` in the second.
    let data = mask.tensor().data();
    for c in 0..4 {
        let top = data[c];
        let bottom = data[4 + c];
        assert!(
            (bottom - top).abs() < 1e-6,
            "MSK-37 FALSIFIED: column {} should be uniform",
            c
        );
    }
}
/// MSK-38: with score `1` in column 0 and `10` in column 1 of both rows, a
/// column mask at sparsity `0.5` must zero column 0 and keep column 1.
#[test]
fn test_generate_column_mask_keeps_highest() {
    let scores = Tensor::new(
        &[
            1.0, 10.0, // row 0
            1.0, 10.0, // row 1
        ],
        &[2, 2],
    );
    let mask = generate_column_mask(&scores, 0.5).unwrap();
    let data = mask.tensor().data();

    // Expected mask value and failure message for each flat index.
    let expected = [
        (0.0, "MSK-38 FALSIFIED: col 0 should be pruned"),
        (1.0, "MSK-38 FALSIFIED: col 1 should be kept"),
        (0.0, "MSK-38 FALSIFIED: col 0 should be pruned"),
        (1.0, "MSK-38 FALSIFIED: col 1 should be kept"),
    ];
    for (idx, &(want, msg)) in expected.iter().enumerate() {
        assert_eq!(data[idx], want, "{}", msg);
    }
}
/// MSK-39: `generate_column_mask` must return an error when the score tensor
/// has the one-dimensional shape `[3]`.
#[test]
fn test_generate_column_mask_1d_rejected() {
    let scores = Tensor::new(&[1.0, 2.0, 3.0], &[3]);
    assert!(
        generate_column_mask(&scores, 0.5).is_err(),
        "MSK-39 FALSIFIED: 1D tensor should be rejected"
    );
}
/// MSK-40: `generate_column_mask` must return an error for a sparsity of `-0.1`.
#[test]
fn test_generate_column_mask_invalid_sparsity() {
    let scores = Tensor::new(&[1.0; 4], &[2, 2]);
    let outcome = generate_column_mask(&scores, -0.1);
    assert!(
        outcome.is_err(),
        "MSK-40 FALSIFIED: negative sparsity should error"
    );
}
/// MSK-41: `theoretical_sparsity` must return `None` for both the `Row` and
/// the `Column` patterns.
#[test]
fn test_sparsity_pattern_theoretical_sparsity_row_col() {
    assert!(
        SparsityPattern::Row.theoretical_sparsity().is_none(),
        "MSK-41 FALSIFIED: Row should return None"
    );
    assert!(
        SparsityPattern::Column.theoretical_sparsity().is_none(),
        "MSK-41 FALSIFIED: Column should return None"
    );
}
/// `SparsityMask::tensor` must expose the shape and data the mask was built from.
#[test]
fn test_sparsity_mask_tensor() {
    let mask = SparsityMask::new(
        Tensor::new(&[1.0, 0.0, 0.0, 1.0], &[2, 2]),
        SparsityPattern::Unstructured,
    )
    .unwrap();

    let tensor = mask.tensor();
    assert_eq!(tensor.shape(), &[2, 2]);
    assert_eq!(tensor.data(), &[1.0, 0.0, 0.0, 1.0]);
}
/// `SparsityMask::pattern` must return the pattern passed to `SparsityMask::new`.
#[test]
fn test_sparsity_mask_pattern() {
    let mask = SparsityMask::new(
        Tensor::new(&[1.0, 0.0, 0.0, 1.0], &[2, 2]),
        SparsityPattern::Unstructured,
    )
    .unwrap();

    assert_eq!(mask.pattern(), SparsityPattern::Unstructured);
}
/// MSK-42: `generate_row_mask` must return an error for a sparsity of `1.5`.
#[test]
fn test_generate_row_mask_invalid_sparsity_positive() {
    let scores = Tensor::new(&[1.0; 4], &[2, 2]);
    let outcome = generate_row_mask(&scores, 1.5);
    assert!(
        outcome.is_err(),
        "MSK-42 FALSIFIED: sparsity > 1.0 should error"
    );
}
/// Cloning an `NM` pattern must yield a value equal to the original.
#[test]
fn test_sparsity_pattern_clone() {
    let original = SparsityPattern::NM { n: 2, m: 4 };
    let duplicate = original.clone();
    assert_eq!(original, duplicate);
}
/// The `Debug` rendering of a `Block` pattern must contain the text `Block`.
#[test]
fn test_sparsity_pattern_debug() {
    let rendered = format!(
        "{:?}",
        SparsityPattern::Block {
            height: 2,
            width: 2,
        }
    );
    assert!(rendered.contains("Block"));
}
/// `SparsityPattern` must be `Copy`: the source binding is still used after
/// being assigned to a second binding, and the two compare equal.
#[test]
fn test_sparsity_pattern_copy() {
    let source = SparsityPattern::Unstructured;
    // Plain assignment; `source` remaining usable below is what proves `Copy`.
    let duplicate = source;
    assert_eq!(source, duplicate);
}
/// `SparsityMask::shape` must report the shape of the tensor the mask wraps.
#[test]
fn test_sparsity_mask_shape() {
    let mask = SparsityMask::new(
        Tensor::new(&[1.0; 6], &[2, 3]),
        SparsityPattern::Unstructured,
    )
    .unwrap();

    assert_eq!(mask.shape(), &[2, 3]);
}
/// Building a `Block { height: 2, width: 2 }` mask from a tensor of shape
/// `[4]` must fail.
#[test]
fn test_block_validate_1d_mask() {
    let pattern = SparsityPattern::Block {
        height: 2,
        width: 2,
    };
    let flat_mask = Tensor::new(&[1.0, 0.0, 1.0, 0.0], &[4]);

    let outcome = SparsityMask::new(flat_mask, pattern);
    assert!(outcome.is_err(), "Block sparsity should require 2D mask");
}
/// Building a `Block { height: 2, width: 2 }` mask from a `[3, 3]` tensor must
/// fail.
#[test]
fn test_block_validate_non_divisible_shape() {
    let pattern = SparsityPattern::Block {
        height: 2,
        width: 2,
    };
    let odd_mask = Tensor::new(&[1.0; 9], &[3, 3]);

    let outcome = SparsityMask::new(odd_mask, pattern);
    assert!(outcome.is_err(), "Block should error on non-divisible shape");
}
/// Building a `Block { height: 2, width: 2 }` mask from a `[2, 2]` tensor whose
/// entries mix `1.0` and `0.0` must fail.
#[test]
fn test_block_validate_non_uniform_block() {
    let pattern = SparsityPattern::Block {
        height: 2,
        width: 2,
    };
    let mixed_mask = Tensor::new(&[1.0, 0.0, 1.0, 1.0], &[2, 2]);

    let outcome = SparsityMask::new(mixed_mask, pattern);
    assert!(outcome.is_err(), "Block should error on non-uniform block");
}
/// Building a `Row` mask from a tensor of shape `[4]` must fail.
#[test]
fn test_row_validate_1d_mask() {
    let flat_mask = Tensor::new(&[1.0, 0.0, 1.0, 0.0], &[4]);
    assert!(
        SparsityMask::new(flat_mask, SparsityPattern::Row).is_err(),
        "Row sparsity should require 2D mask"
    );
}
/// Building a `Row` mask from a `[2, 2]` tensor whose first two entries are
/// `1.0, 0.0` must fail.
#[test]
fn test_row_validate_non_uniform_row() {
    let mixed_mask = Tensor::new(&[1.0, 0.0, 1.0, 1.0], &[2, 2]);
    assert!(
        SparsityMask::new(mixed_mask, SparsityPattern::Row).is_err(),
        "Row should error on non-uniform row"
    );
}
/// Building a `Column` mask from a tensor of shape `[4]` must fail.
#[test]
fn test_column_validate_1d_mask() {
    let flat_mask = Tensor::new(&[1.0, 0.0, 1.0, 0.0], &[4]);
    assert!(
        SparsityMask::new(flat_mask, SparsityPattern::Column).is_err(),
        "Column sparsity should require 2D mask"
    );
}
/// Building a `Column` mask from the `[2, 2]` tensor `[1, 1, 0, 1]` must fail.
#[test]
fn test_column_validate_non_uniform_column() {
    let mixed_mask = Tensor::new(&[1.0, 1.0, 0.0, 1.0], &[2, 2]);
    assert!(
        SparsityMask::new(mixed_mask, SparsityPattern::Column).is_err(),
        "Column should error on non-uniform column"
    );
}
/// An `Unstructured` mask built from an empty tensor of shape `[0]` must be
/// accepted and must report a sparsity of zero.
#[test]
fn test_sparsity_mask_empty_data() {
    let built = SparsityMask::new(Tensor::new(&[], &[0]), SparsityPattern::Unstructured);
    assert!(built.is_ok());

    let mask = built.unwrap();
    // Subtracting 0.0 leaves the value unchanged, so compare the magnitude directly.
    assert!(mask.sparsity().abs() < 1e-6);
}
/// `SparsityPattern::default()` must be `Unstructured`.
#[test]
fn test_sparsity_pattern_default() {
    assert_eq!(SparsityPattern::default(), SparsityPattern::Unstructured);
}
/// `generate_nm_mask` called with `1, 0` (M = 0) must return an error.
#[test]
fn test_generate_nm_mask_m_zero() {
    let scores = Tensor::new(&[1.0; 4], &[4]);
    assert!(generate_nm_mask(&scores, 1, 0).is_err(), "M=0 should error");
}
/// `generate_nm_mask` called with `3, 2` (N greater than M) must return an error.
#[test]
fn test_generate_nm_mask_n_greater_than_m() {
    let scores = Tensor::new(&[1.0; 4], &[4]);
    assert!(
        generate_nm_mask(&scores, 3, 2).is_err(),
        "N > M should error"
    );
}
/// `generate_nm_mask` called with `2, 4` on a 5-element tensor must return an
/// error, since 5 is not a multiple of M = 4.
#[test]
fn test_generate_nm_mask_not_divisible_coverage() {
    let scores = Tensor::new(&[1.0; 5], &[5]);
    assert!(
        generate_nm_mask(&scores, 2, 4).is_err(),
        "Length not divisible by M should error"
    );
}
/// `generate_unstructured_mask` must succeed on an empty score tensor of
/// shape `[0]`.
#[test]
fn test_generate_unstructured_mask_empty() {
    let scores = Tensor::new(&[], &[0]);
    assert!(generate_unstructured_mask(&scores, 0.5).is_ok());
}
/// An all-ones `[4, 4]` tensor must be accepted as a
/// `Block { height: 2, width: 2 }` mask.
#[test]
fn test_block_validate_valid_uniform() {
    let pattern = SparsityPattern::Block {
        height: 2,
        width: 2,
    };
    let all_kept = Tensor::new(&[1.0; 16], &[4, 4]);

    let outcome = SparsityMask::new(all_kept, pattern);
    assert!(outcome.is_ok(), "Uniform blocks should be valid");
}
/// A `[2, 3]` tensor whose first row is all `1.0` and second row all `0.0`
/// must be accepted as a `Row` mask.
#[test]
fn test_row_validate_valid_uniform() {
    let uniform_rows = Tensor::new(
        &[
            1.0, 1.0, 1.0, // row 0
            0.0, 0.0, 0.0, // row 1
        ],
        &[2, 3],
    );
    assert!(
        SparsityMask::new(uniform_rows, SparsityPattern::Row).is_ok(),
        "Uniform rows should be valid"
    );
}