Neural network pruning: importance scoring, sparsity masks, and compression.
This module provides importance scoring and sparsity mask generation
for neural network pruning, following the specification in
docs/specifications/advanced-pruning.md.
§Toyota Way Principles
- Jidoka: All numerical operations validate for NaN/Inf
- Poka-Yoke: Type-safe patterns prevent invalid configurations
- Genchi Genbutsu: Calibration uses real activation data
§Example
use aprender::pruning::{MagnitudeImportance, Importance};
use aprender::nn::Linear;
let layer = Linear::new(512, 256);
let importance = MagnitudeImportance::l2();
let scores = importance.compute(&layer, None).expect("valid layer weights");
println!("Importance stats: min={}, max={}", scores.stats.min, scores.stats.max);
§References
- Han, S., et al. (2015). Learning both weights and connections. NeurIPS.
- Sun, M., et al. (2023). A simple and effective pruning approach. arXiv:2306.11695.
- Frantar, E., & Alistarh, D. (2023). SparseGPT. ICML.
- Frankle, J., & Carbin, M. (2018). The Lottery Ticket Hypothesis. arXiv:1803.03635.
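As context for the Wanda entries below: the Wanda score (Sun et al., 2023) is the elementwise product of weight magnitude and the L2 norm of the corresponding input activations. The following dependency-free sketch illustrates the formula; it is not the crate's `WandaImportance` implementation, and all names in it are hypothetical.

```rust
// Hypothetical sketch of the Wanda score (Sun et al., 2023):
// score[i][j] = |w[i][j]| * ||x[j]||_2, where ||x[j]||_2 is the L2 norm
// of input feature j collected from calibration activations.
fn wanda_scores(weights: &[Vec<f32>], activation_norms: &[f32]) -> Vec<Vec<f32>> {
    weights
        .iter()
        .map(|row| {
            row.iter()
                .zip(activation_norms)
                .map(|(w, x)| w.abs() * x)
                .collect()
        })
        .collect()
}

fn main() {
    // One output row, three input features.
    let weights = vec![vec![0.5, -2.0, 1.0]];
    let norms = vec![4.0, 0.1, 1.0]; // feature 0 has large activations
    let scores = wanda_scores(&weights, &norms);
    // The small weight 0.5 outranks the large weight -2.0 because its
    // input feature is far more active.
    assert_eq!(scores[0], vec![2.0, 0.2, 1.0]);
}
```

This is why Wanda requires calibration data (see `CalibrationContext` below): the activation norms come from real inputs, not from the weights alone.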
Structs§
- ActivationStats - Per-layer activation statistics.
- BlockImportanceScores - Block Importance (BI) scores for all layers.
- BlockSparseTensor - Block sparse tensor representation.
- COOTensor - Coordinate (COO) sparse tensor representation.
- CSRTensor - Compressed Sparse Row (CSR) representation.
- CalibrationContext - Calibration context holding activation statistics for all layers.
- ChannelImportance - Channel importance scores.
- DependencyGraph - Dependency graph for a neural network.
- DepthPruner - Minitron Depth Pruner: layer removal based on Block Importance.
- DepthPruningResult - Result of depth (layer) pruning operation.
- GraphEdge - An edge in the dependency graph.
- GraphNode - A node in the dependency graph representing a layer or operation.
- ImportanceScores - Importance scores with metadata.
- ImportanceStats - Statistical summary of importance scores.
- LotteryTicketConfig - Configuration for Lottery Ticket pruning.
- LotteryTicketPruner - Lottery Ticket Hypothesis pruner.
- LotteryTicketPrunerBuilder - Builder for LotteryTicketPruner.
- MagnitudeImportance - Magnitude-based importance estimator.
- MagnitudePruner - Simple magnitude-based pruner.
- PruningPlan - Pruning plan that ensures consistency across dependent layers.
- PruningResult - Result of a pruning operation with diagnostics.
- SparseGPTImportance - SparseGPT importance estimator using Hessian-based saliency.
- SparsityMask - Sparsity mask with validation.
- WandaImportance - Wanda (Weights and Activations) importance estimator.
- WandaPruner - Wanda-based pruner.
- WidthPruner - Minitron Width Pruner: channel removal based on activation importance.
- WidthPruningResult - Result of width (channel) pruning operation.
- WinningTicket - A “winning ticket”: the sparse subnetwork found by the Lottery Ticket Hypothesis.
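The sparse tensor structs above (COO, CSR, block sparse) trade layout complexity for storage savings once a mask has been applied. As a point of reference, here is a minimal, dependency-free sketch of the CSR layout; the `CsrSketch` type and `dense_to_csr` helper are hypothetical and are not the crate's `CSRTensor` API.

```rust
// Hypothetical sketch of the Compressed Sparse Row (CSR) layout:
// `values` holds the nonzeros in row-major order, `col_indices` their
// columns, and `row_ptr` marks where each row's nonzeros begin in `values`.
struct CsrSketch {
    values: Vec<f32>,
    col_indices: Vec<usize>,
    row_ptr: Vec<usize>, // length = number of rows + 1
}

fn dense_to_csr(dense: &[Vec<f32>]) -> CsrSketch {
    let mut values = Vec::new();
    let mut col_indices = Vec::new();
    let mut row_ptr = vec![0];
    for row in dense {
        for (j, &v) in row.iter().enumerate() {
            if v != 0.0 {
                values.push(v);
                col_indices.push(j);
            }
        }
        // After each row, record how many nonzeros have been seen so far.
        row_ptr.push(values.len());
    }
    CsrSketch { values, col_indices, row_ptr }
}

fn main() {
    // 2x4 matrix with 50% sparsity.
    let dense = vec![vec![0.0, 1.5, 0.0, 2.0], vec![3.0, 0.0, 0.0, 4.0]];
    let csr = dense_to_csr(&dense);
    assert_eq!(csr.values, vec![1.5, 2.0, 3.0, 4.0]);
    assert_eq!(csr.col_indices, vec![1, 3, 0, 3]);
    assert_eq!(csr.row_ptr, vec![0, 2, 4]);
    println!("nonzeros: {}", csr.values.len());
}
```

CSR pays off when sparsity is high enough that indices plus values occupy less memory than the dense tensor, which is why `sparsify` (listed under Functions) returns a format-specific representation rather than a zero-filled dense tensor.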
Enums§
- DependencyType - Type of dependency between layers.
- NodeType - Type of layer/operation node.
- NormType - Norm type for magnitude computation.
- PruningError - Pruning operation errors with detailed context.
- RewindStrategy - Strategy for rewinding weights after pruning.
- SparseFormat - Sparse tensor format enumeration.
- SparseTensor - Unified sparse tensor type.
- SparsityPattern - Sparsity pattern constraints.
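One common constraint covered by sparsity patterns is N:M structured sparsity, where every group of M consecutive weights keeps at most N nonzeros (2:4 is the pattern accelerated by NVIDIA sparse tensor cores). The sketch below illustrates the rule on a flat slice; it is a hypothetical standalone function, not the crate's `generate_nm_mask`.

```rust
// Hypothetical sketch of an N:M structured mask: within every group of M
// consecutive weights, keep the N with the largest magnitude and drop the rest.
fn nm_mask(weights: &[f32], n: usize, m: usize) -> Vec<bool> {
    let mut mask = vec![false; weights.len()];
    for (g, group) in weights.chunks(m).enumerate() {
        // Rank positions within the group by descending magnitude.
        let mut idx: Vec<usize> = (0..group.len()).collect();
        idx.sort_by(|&a, &b| group[b].abs().partial_cmp(&group[a].abs()).unwrap());
        for &i in idx.iter().take(n) {
            mask[g * m + i] = true; // keep the top-N entries of this group
        }
    }
    mask
}

fn main() {
    let w = [0.1, -0.9, 0.4, 0.05, 0.7, 0.2, -0.3, 0.8];
    let mask = nm_mask(&w, 2, 4);
    // Each group of 4 keeps exactly its 2 largest-magnitude weights.
    assert_eq!(mask, vec![false, true, true, false, true, false, false, true]);
}
```

Unlike unstructured pruning, the kept/dropped budget here is enforced per group, so the resulting 50% sparsity is evenly distributed and hardware-friendly.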
Traits§
- Importance - Core trait for importance estimation algorithms.
- Pruner - High-level pruning interface.
Functions§
- generate_block_mask - Generate a block sparsity mask.
- generate_column_mask - Generate a column sparsity mask.
- generate_nm_mask - Generate an N:M structured sparsity mask.
- generate_row_mask - Generate a row sparsity mask.
- generate_unstructured_mask - Generate an unstructured sparsity mask based on importance scores.
- propagate_channel_pruning - Propagate channel pruning through a dependency graph.
- prune_module - Convenience function to prune a module with a single call.
- sparsify - Apply a sparsity mask to a tensor and return its sparse representation.
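To make the unstructured case concrete: an unstructured mask keeps the globally highest-scoring weights and drops the lowest fraction, with no positional constraint. The following is a minimal sketch of that semantics under hypothetical names (`unstructured_mask` is not the crate's `generate_unstructured_mask`), using magnitude scores in the spirit of `MagnitudeImportance` and Han et al. (2015).

```rust
// Hypothetical sketch of unstructured pruning: rank all weights by an
// importance score (here |w|) and mask out the lowest `sparsity` fraction.
fn unstructured_mask(scores: &[f32], sparsity: f32) -> Vec<bool> {
    let n_prune = (scores.len() as f32 * sparsity).round() as usize;
    // Indices sorted by ascending score: the front of this list is pruned.
    let mut idx: Vec<usize> = (0..scores.len()).collect();
    idx.sort_by(|&a, &b| scores[a].partial_cmp(&scores[b]).unwrap());
    let mut mask = vec![true; scores.len()];
    for &i in idx.iter().take(n_prune) {
        mask[i] = false; // drop the lowest-importance entries
    }
    mask
}

fn main() {
    // Scores are absolute weight magnitudes.
    let scores: Vec<f32> = [0.3f32, -1.2, 0.05, 0.9].iter().map(|w| w.abs()).collect();
    let mask = unstructured_mask(&scores, 0.5);
    // 50% sparsity removes the two smallest magnitudes (0.05 and 0.3).
    assert_eq!(mask, vec![false, true, false, true]);
}
```

A real implementation must also handle NaN/Inf scores (the Jidoka principle above), which the `partial_cmp(...).unwrap()` in this sketch deliberately leaves out for brevity.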