use crate::dtype::Float;
use crate::error::{FerrotorchError, FerrotorchResult};
use crate::storage::TensorStorage;
use crate::tensor::Tensor;
pub fn magnitude_prune<T: Float>(
weights: &Tensor<T>,
sparsity: f64,
) -> FerrotorchResult<Tensor<T>> {
if !(0.0..1.0).contains(&sparsity) {
return Err(FerrotorchError::InvalidArgument {
message: format!("sparsity must be in [0, 1), got {sparsity}"),
});
}
let data = weights.data()?;
let numel = data.len();
let n_prune = ((numel as f64) * sparsity).round() as usize;
if n_prune == 0 {
return Tensor::from_storage(
TensorStorage::cpu(data.to_vec()),
weights.shape().to_vec(),
weights.requires_grad(),
);
}
let mut magnitudes: Vec<T> = data.iter().map(|&v| v.abs()).collect();
magnitudes.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
let threshold = magnitudes[n_prune - 1];
let pruned: Vec<T> = data
.iter()
.map(|&v| {
if v.abs() <= threshold {
<T as num_traits::Zero>::zero()
} else {
v
}
})
.collect();
Tensor::from_storage(
TensorStorage::cpu(pruned),
weights.shape().to_vec(),
weights.requires_grad(),
)
}
pub fn apply_2_4_mask<T: Float>(weights: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
let data = weights.data()?;
let mut result = data.to_vec();
let groups = result.len() / 4;
for g in 0..groups {
let base = g * 4;
let group = &mut result[base..base + 4];
let mut idx_mag: Vec<(usize, T)> = group
.iter()
.enumerate()
.map(|(i, &v)| (i, v.abs()))
.collect();
idx_mag.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
group[idx_mag[0].0] = <T as num_traits::Zero>::zero();
group[idx_mag[1].0] = <T as num_traits::Zero>::zero();
}
Tensor::from_storage(
TensorStorage::cpu(result),
weights.shape().to_vec(),
weights.requires_grad(),
)
}
pub fn sparsity_ratio<T: Float>(tensor: &Tensor<T>) -> FerrotorchResult<f64> {
let data = tensor.data()?;
let zeros = data
.iter()
.filter(|&&v| v == <T as num_traits::Zero>::zero())
.count();
Ok(zeros as f64 / data.len() as f64)
}
#[cfg(test)]
mod tests {
use super::*;
fn make_tensor(data: Vec<f32>, shape: Vec<usize>) -> Tensor<f32> {
Tensor::from_storage(TensorStorage::cpu(data), shape, false).unwrap()
}
fn make_tensor_rg(data: Vec<f32>, shape: Vec<usize>) -> Tensor<f32> {
Tensor::from_storage(TensorStorage::cpu(data), shape, true).unwrap()
}
#[test]
fn test_magnitude_prune_50_percent() {
let t = make_tensor(vec![1.0, -4.0, 2.0, -3.0], vec![4]);
let pruned = magnitude_prune(&t, 0.5).unwrap();
let d = pruned.data().unwrap();
assert_eq!(d[0], 0.0); assert_eq!(d[1], -4.0); assert_eq!(d[2], 0.0); assert_eq!(d[3], -3.0); }
#[test]
fn test_magnitude_prune_zero_sparsity() {
let t = make_tensor(vec![1.0, 2.0, 3.0, 4.0], vec![4]);
let pruned = magnitude_prune(&t, 0.0).unwrap();
let d = pruned.data().unwrap();
assert_eq!(d, &[1.0, 2.0, 3.0, 4.0]);
}
#[test]
fn test_magnitude_prune_invalid_sparsity() {
let t = make_tensor(vec![1.0], vec![1]);
assert!(magnitude_prune(&t, 1.0).is_err());
assert!(magnitude_prune(&t, -0.1).is_err());
}
#[test]
fn test_magnitude_prune_nan_no_panic() {
let t = make_tensor(vec![1.0, f32::NAN, 3.0, f32::NAN, 2.0, 4.0], vec![6]);
let result = magnitude_prune(&t, 0.5);
assert!(result.is_ok());
}
#[test]
fn test_apply_2_4_mask_nan_no_panic() {
let t = make_tensor(
vec![1.0, f32::NAN, 3.0, f32::NAN, 2.0, 4.0, 0.5, 0.1],
vec![8],
);
let result = apply_2_4_mask(&t);
assert!(result.is_ok());
}
#[test]
fn test_apply_2_4_mask_basic() {
let t = make_tensor(vec![1.0, -4.0, 2.0, -3.0, 0.5, 0.1, 0.9, 0.8], vec![8]);
let masked = apply_2_4_mask(&t).unwrap();
let d = masked.data().unwrap();
assert_eq!(d[0], 0.0);
assert_eq!(d[1], -4.0);
assert_eq!(d[2], 0.0);
assert_eq!(d[3], -3.0);
assert_eq!(d[4], 0.0);
assert_eq!(d[5], 0.0);
assert_eq!(d[6], 0.9);
assert_eq!(d[7], 0.8);
}
#[test]
fn test_apply_2_4_mask_preserves_requires_grad() {
let t = make_tensor_rg(vec![1.0, 2.0, 3.0, 4.0], vec![4]);
assert!(t.requires_grad());
let masked = apply_2_4_mask(&t).unwrap();
assert!(
masked.requires_grad(),
"apply_2_4_mask must propagate requires_grad"
);
}
#[test]
fn test_sparsity_ratio() {
let t = make_tensor(vec![0.0, 1.0, 0.0, 2.0], vec![4]);
let ratio = sparsity_ratio(&t).unwrap();
assert!((ratio - 0.5).abs() < 1e-10);
}
}