#![allow(clippy::disallowed_methods)]
use aprender::nn::Linear;
use aprender::pruning::{
generate_nm_mask, generate_unstructured_mask, Importance, MagnitudeImportance,
};
#[allow(clippy::cognitive_complexity)]
fn main() {
println!("╔══════════════════════════════════════════════════════════════╗");
println!("║ Magnitude Pruning with Aprender ║");
println!("║ Prune neural networks by weight magnitude ║");
println!("╚══════════════════════════════════════════════════════════════╝\n");
println!("📊 Creating Linear Layer (16 → 8)");
let layer = Linear::new(16, 8);
let weights = layer.weight();
let total_params = weights.data().len();
println!(" Weight shape: {:?}", weights.shape());
println!(" Total parameters: {}\n", total_params);
println!("🔬 Computing L1 Magnitude Importance");
let l1_importance = MagnitudeImportance::l1();
let l1_scores = l1_importance.compute(&layer, None).expect("L1 importance");
println!(" Method: {}", l1_scores.method);
println!(" Stats:");
println!(" - Min: {:.6}", l1_scores.stats.min);
println!(" - Max: {:.6}", l1_scores.stats.max);
println!(" - Mean: {:.6}", l1_scores.stats.mean);
println!(" - Std: {:.6}\n", l1_scores.stats.std);
println!("🔬 Computing L2 Magnitude Importance");
let l2_importance = MagnitudeImportance::l2();
let l2_scores = l2_importance.compute(&layer, None).expect("L2 importance");
println!(" Method: {}", l2_scores.method);
println!(" Stats:");
println!(" - Min: {:.6}", l2_scores.stats.min);
println!(" - Max: {:.6}", l2_scores.stats.max);
println!(" - Mean: {:.6}", l2_scores.stats.mean);
println!(" - Std: {:.6}\n", l2_scores.stats.std);
println!("✂️ Generating Unstructured Mask (50% sparsity)");
let mask = generate_unstructured_mask(&l1_scores.values, 0.5).expect("Unstructured mask");
let sparsity = mask.sparsity();
let nonzeros = mask.nnz();
let zeros = mask.num_zeros();
println!(" Achieved sparsity: {:.1}%", sparsity * 100.0);
println!(" Non-zero weights: {}", nonzeros);
println!(" Pruned weights: {}\n", zeros);
println!("✂️ Generating 2:4 N:M Mask (50% structured sparsity)");
let nm_layer = Linear::new(8, 8); let nm_importance = MagnitudeImportance::l1();
match nm_importance.compute(&nm_layer, None) {
Ok(nm_scores) => match generate_nm_mask(&nm_scores.values, 2, 4) {
Ok(nm_mask) => {
println!(" Pattern: 2:4 (2 non-zeros per 4 elements)");
println!(" Achieved sparsity: {:.1}%", nm_mask.sparsity() * 100.0);
let mask_data = nm_mask.tensor().data();
let mut valid_groups = 0;
let mut total_groups = 0;
for chunk in mask_data.chunks(4) {
if chunk.len() == 4 {
let chunk_nonzeros: usize =
chunk.iter().map(|&v| if v > 0.5 { 1 } else { 0 }).sum();
if chunk_nonzeros == 2 {
valid_groups += 1;
}
total_groups += 1;
}
}
println!(" Valid 2:4 groups: {}/{}", valid_groups, total_groups);
}
Err(e) => println!(" N:M mask error: {}", e),
},
Err(e) => println!(" Importance error: {}", e),
}
println!("\n📉 Applying Mask to Weights");
let mut pruned_weights = weights.clone();
mask.apply(&mut pruned_weights).expect("Apply mask");
let zeros_after: usize = pruned_weights
.data()
.iter()
.filter(|&&v| v.abs() < 1e-10)
.count();
let pruned_len = pruned_weights.data().len();
println!(
" Zeros after pruning: {} ({:.1}%)",
zeros_after,
zeros_after as f32 / pruned_len as f32 * 100.0
);
println!("\n╔══════════════════════════════════════════════════════════════╗");
println!("║ Pruning Summary ║");
println!("╠══════════════════════════════════════════════════════════════╣");
println!(
"║ Original parameters: {:>6} ║",
total_params
);
println!(
"║ Pruned parameters: {:>6} ({:.0}% reduction) ║",
zeros_after,
sparsity * 100.0
);
println!(
"║ Remaining parameters: {:>6} ║",
total_params - zeros_after
);
println!("╚══════════════════════════════════════════════════════════════╝");
}