//! Model pruning: strategies and utilities for sparsifying model weights.

#![allow(clippy::excessive_nesting)]
#![allow(unused_variables)]

use crate::tensor::Tensor;
use anyhow::{anyhow, Result};
use scirs2_core::random::*;
use std::collections::{HashMap, HashSet};
/// Configuration shared by all pruning strategies and pruners.
#[derive(Debug, Clone)]
pub struct PruningConfig {
    /// Fraction of parameters to prune, in `[0.0, 1.0]`.
    pub target_sparsity: f32,
    /// Whether to prune over several rounds instead of in one shot.
    pub iterative: bool,
    /// Number of rounds when `iterative` is set; must be at least 1.
    pub iterations: usize,
    /// Whether to fine-tune the model after pruning.
    pub fine_tune: bool,
    /// Names of layers that must never be pruned.
    pub exclude_layers: HashSet<String>,
    /// Optional magnitude cutoff; validated to be non-negative but not yet applied
    /// by the built-in strategies.
    pub magnitude_threshold: Option<f32>,
    /// Optional RNG seed for reproducible random pruning (currently unused).
    pub seed: Option<u64>,
}
impl Default for PruningConfig {
fn default() -> Self {
Self {
target_sparsity: 0.5,
iterative: false,
iterations: 1,
fine_tune: true,
exclude_layers: HashSet::new(),
magnitude_threshold: None,
seed: None,
}
}
}
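/// A weight-level pruning strategy: given a weight tensor and a config, produce either
/// the pruned tensor or a same-shaped binary mask (1.0 = keep, 0.0 = drop).
///
/// A minimal usage sketch, assuming only the `Tensor` helpers already used in this
/// module (`from_vec`, `data`, `shape`):
///
/// ```ignore
/// let strategy = MagnitudePruner::new(0.5);
/// let weights = Tensor::from_vec(vec![0.1, -0.8, 0.3, -0.2], &[2, 2])?;
/// let config = PruningConfig { target_sparsity: 0.5, ..Default::default() };
/// let pruned = strategy.prune_weights(&weights, &config)?; // two smallest |w| -> 0.0
/// let mask = strategy.get_mask(&weights, &config)?;
/// ```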
pub trait PruningStrategy: Send + Sync {
fn prune_weights(&self, weights: &Tensor, config: &PruningConfig) -> Result<Tensor>;
fn get_mask(&self, weights: &Tensor, config: &PruningConfig) -> Result<Tensor>;
fn name(&self) -> &str;
}
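/// The outcome of a model-level pruning run: the pruned model plus overall and
/// per-layer sparsity statistics.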
#[derive(Debug, Clone)]
pub struct PruningResult<M>
where
M: crate::traits::Model,
{
pub model: M,
pub sparsity: f32,
pub pruned_params: usize,
pub total_params: usize,
pub layer_sparsity: HashMap<String, f32>,
}
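/// Model-level pruning: consumes a model and returns a [`PruningResult`], or estimates
/// achievable sparsity without modifying anything.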
pub trait Pruner: Send + Sync {
fn prune<M>(&self, model: M, config: &PruningConfig) -> Result<PruningResult<M>>
where
M: crate::traits::Model + Clone;
fn estimate_pruning_potential<M>(
&self,
model: &M,
config: &PruningConfig,
) -> Result<PruningStats>
where
M: crate::traits::Model;
}
#[derive(Debug, Clone)]
pub struct PruningStats {
pub total_params: usize,
pub zero_params: usize,
pub sparsity: f32,
pub layer_stats: HashMap<String, LayerPruningStats>,
}
#[derive(Debug, Clone)]
pub struct LayerPruningStats {
pub total_params: usize,
pub zero_params: usize,
pub sparsity: f32,
}
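/// Unstructured magnitude pruning: globally ranks weights by absolute value and zeroes
/// the smallest until the configured sparsity is reached. Note that the value passed to
/// `new` is stored as `threshold`, but the pruning ratio actually comes from
/// `PruningConfig::target_sparsity`.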
pub struct MagnitudePruner {
#[allow(dead_code)]
threshold: f32,
}
impl MagnitudePruner {
pub fn new(sparsity: f32) -> Self {
Self {
threshold: sparsity,
}
}
}
impl PruningStrategy for MagnitudePruner {
fn prune_weights(&self, weights: &Tensor, config: &PruningConfig) -> Result<Tensor> {
let mask = self.get_mask(weights, config)?;
let pruned = weights
.data()?
.iter()
.zip(mask.data()?.iter())
.map(|(w, m)| if *m > 0.5 { *w } else { 0.0 })
.collect::<Vec<_>>();
Ok(Tensor::from_vec(pruned, &weights.shape())?)
}
fn get_mask(&self, weights: &Tensor, config: &PruningConfig) -> Result<Tensor> {
let data = weights.data()?;
        // Rank weights by absolute magnitude; `total_cmp` gives a total order and
        // avoids panicking when weights contain NaN.
        let mut abs_weights: Vec<(f32, usize)> =
            data.iter().enumerate().map(|(i, &w)| (w.abs(), i)).collect();
        abs_weights.sort_by(|a, b| a.0.total_cmp(&b.0));
let num_prune = (data.len() as f32 * config.target_sparsity) as usize;
let mut mask = vec![1.0; data.len()];
for i in 0..num_prune.min(abs_weights.len()) {
mask[abs_weights[i].1] = 0.0;
}
Ok(Tensor::from_vec(mask, &weights.shape())?)
}
fn name(&self) -> &str {
"MagnitudePruner"
}
}
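/// Structured pruning along one tensor dimension: whole slices are zeroed based on
/// their L2 norm, which keeps the surviving weights in hardware-friendly blocks.
///
/// A minimal sketch for a 2D weight matrix (`pruning_dim = 0` prunes whole rows):
///
/// ```ignore
/// let pruner = StructuredPruner::new(0);
/// let weights = Tensor::from_vec(vec![1.0, 1.0, 1.0, 0.1, 0.1, 0.1], &[2, 3])?;
/// let config = PruningConfig { target_sparsity: 0.5, ..Default::default() };
/// let pruned = pruner.prune_weights(&weights, &config)?; // zeroes the low-norm row
/// ```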
pub struct StructuredPruner {
pruning_dim: usize,
}
impl StructuredPruner {
pub fn new(pruning_dim: usize) -> Self {
Self { pruning_dim }
}
}
impl PruningStrategy for StructuredPruner {
fn prune_weights(&self, weights: &Tensor, config: &PruningConfig) -> Result<Tensor> {
let shape = weights.shape();
if shape.len() < 2 {
return Err(anyhow!("Structured pruning requires at least 2D tensors"));
}
let importance_scores = self.calculate_importance(weights)?;
let num_structures = shape[self.pruning_dim];
let num_prune = (num_structures as f32 * config.target_sparsity) as usize;
let mut indices: Vec<usize> = (0..num_structures).collect();
        // Sort structures by importance, least important first; `total_cmp` avoids
        // panicking when scores contain NaN.
        indices.sort_by(|&a, &b| importance_scores[a].total_cmp(&importance_scores[b]));
let pruned_indices: HashSet<_> = indices.iter().take(num_prune).cloned().collect();
let data = weights.data()?;
let mut pruned_data = Vec::with_capacity(data.len());
        // Stride of the pruned dimension in row-major layout: the product of all
        // trailing dimensions. Hoisted out of the loop instead of recomputing per element.
        let stride = shape.iter().skip(self.pruning_dim + 1).product::<usize>();
        for (i, &val) in data.iter().enumerate() {
            let structure_idx = (i / stride) % shape[self.pruning_dim];
            if pruned_indices.contains(&structure_idx) {
                pruned_data.push(0.0);
            } else {
                pruned_data.push(val);
            }
        }
Ok(Tensor::from_vec(pruned_data, &shape)?)
}
    fn get_mask(&self, weights: &Tensor, config: &PruningConfig) -> Result<Tensor> {
        // Derive the mask from the pruned tensor so it matches `prune_weights` exactly
        // (weights that were already exactly 0.0 are also reported as pruned).
        let pruned = self.prune_weights(weights, config)?;
        let mask: Vec<f32> =
            pruned.data()?.iter().map(|&v| if v == 0.0 { 0.0 } else { 1.0 }).collect();
        Ok(Tensor::from_vec(mask, &weights.shape())?)
    }
fn name(&self) -> &str {
"StructuredPruner"
}
}
impl StructuredPruner {
fn calculate_importance(&self, weights: &Tensor) -> Result<Vec<f32>> {
let shape = weights.shape();
let num_structures = shape[self.pruning_dim];
let mut importance = vec![0.0; num_structures];
        let data = weights.data()?;
        // Row-major layout: dims before `pruning_dim` form outer blocks, dims after it
        // form each contiguous structure; importance is the L2 norm over a structure.
        let structure_size = shape.iter().skip(self.pruning_dim + 1).product::<usize>();
        let structures_per_batch = shape.iter().take(self.pruning_dim).product::<usize>();
for (i, importance_ref) in importance.iter_mut().enumerate() {
let mut sum_sq = 0.0;
for j in 0..structures_per_batch {
for k in 0..structure_size {
let idx = j * num_structures * structure_size + i * structure_size + k;
if idx < data.len() {
sum_sq += data[idx] * data[idx];
}
}
}
*importance_ref = sum_sq.sqrt();
}
Ok(importance)
}
}
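/// Element-wise pruning with no structural constraint. With `random = true`, positions
/// are dropped uniformly at random (a common lower-bound baseline); otherwise the work
/// is delegated to [`MagnitudePruner`].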
pub struct UnstructuredPruner {
random: bool,
}
impl UnstructuredPruner {
pub fn new(random: bool) -> Self {
Self { random }
}
}
impl PruningStrategy for UnstructuredPruner {
fn prune_weights(&self, weights: &Tensor, config: &PruningConfig) -> Result<Tensor> {
let data = weights.data()?;
let num_prune = (data.len() as f32 * config.target_sparsity) as usize;
let mut pruned = data.to_vec();
        if self.random {
            // Fisher-Yates shuffle, then zero the first `num_prune` shuffled positions.
            let mut rng = thread_rng();
let mut indices: Vec<usize> = (0..data.len()).collect();
for i in (1..indices.len()).rev() {
let j = rng.random_range(0..=i);
indices.swap(i, j);
}
for i in 0..num_prune.min(indices.len()) {
pruned[indices[i]] = 0.0;
}
} else {
let magnitude_pruner = MagnitudePruner::new(config.target_sparsity);
return magnitude_pruner.prune_weights(weights, config);
}
Ok(Tensor::from_vec(pruned, &weights.shape())?)
}
    fn get_mask(&self, weights: &Tensor, config: &PruningConfig) -> Result<Tensor> {
        let data = weights.data()?;
        let num_prune = (data.len() as f32 * config.target_sparsity) as usize;
        let mut mask = vec![1.0; data.len()];
        if self.random {
            // Note: this draws a fresh shuffle, so the mask is not guaranteed to match
            // an earlier `prune_weights` call (`config.seed` is currently unused here).
            let mut rng = thread_rng();
            let mut indices: Vec<usize> = (0..data.len()).collect();
            for i in (1..indices.len()).rev() {
                let j = rng.random_range(0..=i);
                indices.swap(i, j);
            }
            for i in 0..num_prune.min(indices.len()) {
                mask[indices[i]] = 0.0;
            }
        } else {
            // Mirror `prune_weights`: non-random pruning delegates to magnitude pruning.
            return MagnitudePruner::new(config.target_sparsity).get_mask(weights, config);
        }
        Ok(Tensor::from_vec(mask, &weights.shape())?)
    }
fn name(&self) -> &str {
"UnstructuredPruner"
}
}
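/// Gradual pruning across a training run: the target sparsity ramps linearly from
/// `initial_sparsity` at `begin_step` to `final_sparsity` at `end_step`:
///
/// ```text
/// s(t) = s_i + (s_f - s_i) * (t - begin) / (end - begin)
/// ```
///
/// (Gradual schedules in the literature often use a cubic ramp instead; see
/// `PruningUtils::generate_pruning_schedule` for one.)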
pub struct GradualPruner {
initial_sparsity: f32,
final_sparsity: f32,
begin_step: usize,
end_step: usize,
#[allow(dead_code)]
frequency: usize,
}
impl GradualPruner {
pub fn new(
initial_sparsity: f32,
final_sparsity: f32,
begin_step: usize,
end_step: usize,
frequency: usize,
) -> Self {
Self {
initial_sparsity,
final_sparsity,
begin_step,
end_step,
frequency,
}
}
pub fn get_sparsity_at_step(&self, step: usize) -> f32 {
if step < self.begin_step {
return 0.0;
}
if step >= self.end_step {
return self.final_sparsity;
}
let progress = (step - self.begin_step) as f32 / (self.end_step - self.begin_step) as f32;
self.initial_sparsity + (self.final_sparsity - self.initial_sparsity) * progress
}
}
#[derive(Debug, Clone)]
pub enum PruningSchedule {
OneShot { step: usize },
Gradual {
begin_step: usize,
end_step: usize,
frequency: usize,
},
Iterative {
steps: Vec<usize>,
sparsities: Vec<f32>,
},
}
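/// Prunes whole channels (dimension 1 of a 4D NCHW weight tensor), ranked by the
/// chosen importance metric.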
pub struct ChannelPruner {
importance_metric: ChannelImportanceMetric,
}
#[derive(Debug, Clone)]
pub enum ChannelImportanceMetric {
L1Norm,
L2Norm,
MeanActivation,
GeometricMedian,
}
impl ChannelPruner {
pub fn new(metric: ChannelImportanceMetric) -> Self {
Self {
importance_metric: metric,
}
}
}
impl PruningStrategy for ChannelPruner {
fn prune_weights(&self, weights: &Tensor, config: &PruningConfig) -> Result<Tensor> {
let shape = weights.shape();
if shape.len() != 4 {
return Err(anyhow!("Channel pruning requires 4D tensors (NCHW format)"));
}
        let num_channels = shape[1];
        let channel_importance = self.calculate_channel_importance(weights)?;
let num_prune = (num_channels as f32 * config.target_sparsity) as usize;
let mut sorted_channels: Vec<(f32, usize)> =
channel_importance.iter().enumerate().map(|(i, &score)| (score, i)).collect();
        sorted_channels.sort_by(|a, b| a.0.total_cmp(&b.0));
let pruned_channels: HashSet<usize> =
sorted_channels.iter().take(num_prune).map(|(_, idx)| *idx).collect();
let data = weights.data()?;
let mut pruned_data = data.to_vec();
        let channel_size = shape[2] * shape[3];
        let batch_channel_size = num_channels * channel_size;
for batch in 0..shape[0] {
for channel in &pruned_channels {
let start_idx = batch * batch_channel_size + channel * channel_size;
let end_idx = start_idx + channel_size;
for i in start_idx..end_idx.min(pruned_data.len()) {
pruned_data[i] = 0.0;
}
}
}
Ok(Tensor::from_vec(pruned_data, &shape)?)
}
    fn get_mask(&self, weights: &Tensor, config: &PruningConfig) -> Result<Tensor> {
        let shape = weights.shape();
        if shape.len() != 4 {
            return Err(anyhow!("Channel pruning requires 4D tensors (NCHW format)"));
        }
        let num_channels = shape[1];
let channel_importance = self.calculate_channel_importance(weights)?;
let num_prune = (num_channels as f32 * config.target_sparsity) as usize;
let mut sorted_channels: Vec<(f32, usize)> =
channel_importance.iter().enumerate().map(|(i, &score)| (score, i)).collect();
        sorted_channels.sort_by(|a, b| a.0.total_cmp(&b.0));
let pruned_channels: HashSet<usize> =
sorted_channels.iter().take(num_prune).map(|(_, idx)| *idx).collect();
let data = weights.data()?;
let mut mask = vec![1.0; data.len()];
let channel_size = shape[2] * shape[3];
let batch_channel_size = num_channels * channel_size;
for batch in 0..shape[0] {
for channel in &pruned_channels {
let start_idx = batch * batch_channel_size + channel * channel_size;
let end_idx = start_idx + channel_size;
for i in start_idx..end_idx.min(mask.len()) {
mask[i] = 0.0;
}
}
}
Ok(Tensor::from_vec(mask, &shape)?)
}
fn name(&self) -> &str {
"ChannelPruner"
}
}
impl ChannelPruner {
fn calculate_channel_importance(&self, weights: &Tensor) -> Result<Vec<f32>> {
let shape = weights.shape();
let num_channels = shape[1];
let channel_size = shape[2] * shape[3];
let data = weights.data()?;
let mut importance = vec![0.0; num_channels];
for (channel, importance_ref) in importance.iter_mut().enumerate() {
let mut channel_score = 0.0;
let mut count = 0;
for batch in 0..shape[0] {
let start_idx = batch * num_channels * channel_size + channel * channel_size;
let end_idx = start_idx + channel_size;
for data_ref in data.iter().take(end_idx.min(data.len())).skip(start_idx) {
                    match self.importance_metric {
                        ChannelImportanceMetric::L1Norm => channel_score += data_ref.abs(),
                        ChannelImportanceMetric::L2Norm => channel_score += data_ref * data_ref,
                        // MeanActivation and GeometricMedian are approximated here by the
                        // mean absolute weight; exact versions need activation statistics.
                        ChannelImportanceMetric::MeanActivation => channel_score += data_ref.abs(),
                        ChannelImportanceMetric::GeometricMedian => channel_score += data_ref.abs(),
                    }
count += 1;
}
}
            *importance_ref = if count > 0 {
                match self.importance_metric {
                    ChannelImportanceMetric::L2Norm => (channel_score / count as f32).sqrt(),
                    _ => channel_score / count as f32,
                }
            } else {
                // Mirror `HeadPruner`: empty channels get zero importance instead of NaN.
                0.0
            };
}
Ok(importance)
}
}
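/// Prunes whole filters (dimension 0 of a 4D weight tensor), ranked by L1/L2 norm or a
/// weight-based APoZ proxy.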
pub struct FilterPruner {
importance_metric: FilterImportanceMetric,
}
#[derive(Debug, Clone)]
pub enum FilterImportanceMetric {
L1Norm,
L2Norm,
APoZ,
}
impl FilterPruner {
pub fn new(metric: FilterImportanceMetric) -> Self {
Self {
importance_metric: metric,
}
}
}
impl PruningStrategy for FilterPruner {
fn prune_weights(&self, weights: &Tensor, config: &PruningConfig) -> Result<Tensor> {
let shape = weights.shape();
if shape.len() != 4 {
return Err(anyhow!("Filter pruning requires 4D tensors (NCHW format)"));
}
        let num_filters = shape[0];
        let filter_importance = self.calculate_filter_importance(weights)?;
let num_prune = (num_filters as f32 * config.target_sparsity) as usize;
let mut sorted_filters: Vec<(f32, usize)> =
filter_importance.iter().enumerate().map(|(i, &score)| (score, i)).collect();
        sorted_filters.sort_by(|a, b| a.0.total_cmp(&b.0));
let pruned_filters: HashSet<usize> =
sorted_filters.iter().take(num_prune).map(|(_, idx)| *idx).collect();
let data = weights.data()?;
let mut pruned_data = data.to_vec();
let filter_size = shape[1] * shape[2] * shape[3];
for filter_idx in &pruned_filters {
let start_idx = filter_idx * filter_size;
let end_idx = start_idx + filter_size;
for i in start_idx..end_idx.min(pruned_data.len()) {
pruned_data[i] = 0.0;
}
}
Ok(Tensor::from_vec(pruned_data, &shape)?)
}
    fn get_mask(&self, weights: &Tensor, config: &PruningConfig) -> Result<Tensor> {
        let shape = weights.shape();
        if shape.len() != 4 {
            return Err(anyhow!("Filter pruning requires 4D tensors (NCHW format)"));
        }
        let num_filters = shape[0];
let filter_importance = self.calculate_filter_importance(weights)?;
let num_prune = (num_filters as f32 * config.target_sparsity) as usize;
let mut sorted_filters: Vec<(f32, usize)> =
filter_importance.iter().enumerate().map(|(i, &score)| (score, i)).collect();
        sorted_filters.sort_by(|a, b| a.0.total_cmp(&b.0));
let pruned_filters: HashSet<usize> =
sorted_filters.iter().take(num_prune).map(|(_, idx)| *idx).collect();
let data = weights.data()?;
let mut mask = vec![1.0; data.len()];
let filter_size = shape[1] * shape[2] * shape[3];
for filter_idx in &pruned_filters {
let start_idx = filter_idx * filter_size;
let end_idx = start_idx + filter_size;
for i in start_idx..end_idx.min(mask.len()) {
mask[i] = 0.0;
}
}
Ok(Tensor::from_vec(mask, &shape)?)
}
fn name(&self) -> &str {
"FilterPruner"
}
}
impl FilterPruner {
fn calculate_filter_importance(&self, weights: &Tensor) -> Result<Vec<f32>> {
let shape = weights.shape();
let num_filters = shape[0];
let filter_size = shape[1] * shape[2] * shape[3];
let data = weights.data()?;
let mut importance = vec![0.0; num_filters];
for (filter, importance_ref) in importance.iter_mut().enumerate() {
let start_idx = filter * filter_size;
let end_idx = start_idx + filter_size;
let mut filter_score = 0.0;
for data_ref in data.iter().take(end_idx.min(data.len())).skip(start_idx) {
                match self.importance_metric {
                    FilterImportanceMetric::L1Norm => filter_score += data_ref.abs(),
                    FilterImportanceMetric::L2Norm => filter_score += data_ref * data_ref,
                    // APoZ is defined over activations; this weight-only proxy counts
                    // the filter's exactly-zero weights.
                    FilterImportanceMetric::APoZ => {
                        filter_score += if *data_ref == 0.0 { 1.0 } else { 0.0 }
                    }
                }
            }
            *importance_ref = match self.importance_metric {
                FilterImportanceMetric::L2Norm => filter_score.sqrt(),
                // A high zero fraction means a *less* important filter, so invert the
                // score; otherwise the ascending sort would prune the densest filters.
                FilterImportanceMetric::APoZ => 1.0 - filter_score / filter_size as f32,
                _ => filter_score,
            };
}
Ok(importance)
}
}
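/// Prunes attention heads from 2D projection matrices laid out as
/// `[rows, num_heads * head_dim]`: each head owns a contiguous block of `head_dim`
/// columns, and entire blocks are zeroed based on their RMS magnitude.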
pub struct HeadPruner {
num_heads: usize,
head_dim: usize,
}
impl HeadPruner {
pub fn new(num_heads: usize, head_dim: usize) -> Self {
Self {
num_heads,
head_dim,
}
}
}
impl PruningStrategy for HeadPruner {
fn prune_weights(&self, weights: &Tensor, config: &PruningConfig) -> Result<Tensor> {
let shape = weights.shape();
if shape.len() != 2 {
return Err(anyhow!(
"Head pruning requires 2D tensors (attention weight matrices)"
));
}
let num_prune = (self.num_heads as f32 * config.target_sparsity) as usize;
let head_importance = self.calculate_head_importance(weights)?;
let mut sorted_heads: Vec<(f32, usize)> =
head_importance.iter().enumerate().map(|(i, &score)| (score, i)).collect();
        sorted_heads.sort_by(|a, b| a.0.total_cmp(&b.0));
let pruned_heads: HashSet<usize> =
sorted_heads.iter().take(num_prune).map(|(_, idx)| *idx).collect();
let data = weights.data()?;
let mut pruned_data = data.to_vec();
for head_idx in &pruned_heads {
let start_col = head_idx * self.head_dim;
let end_col = start_col + self.head_dim;
for row in 0..shape[0] {
for col in start_col..end_col.min(shape[1]) {
let idx = row * shape[1] + col;
if idx < pruned_data.len() {
pruned_data[idx] = 0.0;
}
}
}
}
Ok(Tensor::from_vec(pruned_data, &shape)?)
}
    fn get_mask(&self, weights: &Tensor, config: &PruningConfig) -> Result<Tensor> {
        let shape = weights.shape();
        if shape.len() != 2 {
            return Err(anyhow!(
                "Head pruning requires 2D tensors (attention weight matrices)"
            ));
        }
        let num_prune = (self.num_heads as f32 * config.target_sparsity) as usize;
let head_importance = self.calculate_head_importance(weights)?;
let mut sorted_heads: Vec<(f32, usize)> =
head_importance.iter().enumerate().map(|(i, &score)| (score, i)).collect();
        sorted_heads.sort_by(|a, b| a.0.total_cmp(&b.0));
let pruned_heads: HashSet<usize> =
sorted_heads.iter().take(num_prune).map(|(_, idx)| *idx).collect();
let data = weights.data()?;
let mut mask = vec![1.0; data.len()];
for head_idx in &pruned_heads {
let start_col = head_idx * self.head_dim;
let end_col = start_col + self.head_dim;
for row in 0..shape[0] {
for col in start_col..end_col.min(shape[1]) {
let idx = row * shape[1] + col;
if idx < mask.len() {
mask[idx] = 0.0;
}
}
}
}
Ok(Tensor::from_vec(mask, &shape)?)
}
fn name(&self) -> &str {
"HeadPruner"
}
}
impl HeadPruner {
fn calculate_head_importance(&self, weights: &Tensor) -> Result<Vec<f32>> {
let shape = weights.shape();
let data = weights.data()?;
let mut importance = vec![0.0; self.num_heads];
for (head, importance_ref) in importance.iter_mut().enumerate() {
let start_col = head * self.head_dim;
let end_col = start_col + self.head_dim;
let mut head_score = 0.0;
let mut count = 0;
for row in 0..shape[0] {
for col in start_col..end_col.min(shape[1]) {
let idx = row * shape[1] + col;
if idx < data.len() {
                        head_score += data[idx] * data[idx];
                        count += 1;
}
}
}
*importance_ref = if count > 0 { (head_score / count as f32).sqrt() } else { 0.0 };
}
Ok(importance)
}
}
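/// Ranks entire layers by importance so the least important ones can be dropped
/// wholesale; scores may be supplied directly or derived from built-in heuristics.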
pub struct LayerPruner {
layer_importance: HashMap<String, f32>,
}
impl Default for LayerPruner {
fn default() -> Self {
Self::new()
}
}
impl LayerPruner {
pub fn new() -> Self {
Self {
layer_importance: HashMap::new(),
}
}
pub fn with_importance_scores(scores: HashMap<String, f32>) -> Self {
Self {
layer_importance: scores,
}
}
pub fn analyze_model<M>(&mut self, model: &M) -> Result<()>
where
M: crate::traits::Model,
{
        let total_params = model.num_parameters();
        // Placeholder heuristics: real importance analysis would need activation or
        // gradient statistics; these scores sketch a typical transformer stack.
        let typical_layers = vec![
("embedding".to_string(), 0.8),
("attention_0".to_string(), 0.6),
("feedforward_0".to_string(), 0.4),
("attention_1".to_string(), 0.5),
("feedforward_1".to_string(), 0.3),
("output".to_string(), 0.9),
];
for (name, importance) in typical_layers {
self.layer_importance.insert(name, importance * total_params as f32);
}
Ok(())
}
pub fn get_pruning_candidates(&self, config: &PruningConfig) -> Result<Vec<String>> {
let total_layers = self.layer_importance.len();
let num_prune = (total_layers as f32 * config.target_sparsity) as usize;
let mut sorted_layers: Vec<(f32, String)> = self
.layer_importance
.iter()
.map(|(name, &score)| (score, name.clone()))
.collect();
        sorted_layers.sort_by(|a, b| a.0.total_cmp(&b.0));
        let pruned_layers: Vec<String> = sorted_layers
            .iter()
            // Filter exclusions *before* taking `num_prune`, so excluded layers do not
            // silently reduce the number of layers that get pruned.
            .filter(|(_, name)| !config.exclude_layers.contains(name))
            .take(num_prune)
            .map(|(_, name)| name.clone())
            .collect();
Ok(pruned_layers)
}
}
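/// Dispatches a pruning strategy per layer type (conv -> filter pruning, attention ->
/// head pruning, linear -> magnitude pruning), falling back to a default strategy for
/// anything unrecognized.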
pub struct AutomaticPruner {
strategies: HashMap<String, Box<dyn PruningStrategy>>,
default_strategy: Box<dyn PruningStrategy>,
}
impl AutomaticPruner {
pub fn new() -> Self {
let mut strategies = HashMap::new();
strategies.insert(
"conv".to_string(),
Box::new(FilterPruner::new(FilterImportanceMetric::L2Norm)) as Box<dyn PruningStrategy>,
);
        strategies.insert(
            "attention".to_string(),
            // Default head geometry (12 heads of dim 64) matches a BERT-base-sized model.
            Box::new(HeadPruner::new(12, 64)) as Box<dyn PruningStrategy>,
        );
strategies.insert(
"linear".to_string(),
Box::new(MagnitudePruner::new(0.5)) as Box<dyn PruningStrategy>,
);
let default_strategy = Box::new(MagnitudePruner::new(0.5));
Self {
strategies,
default_strategy,
}
}
pub fn with_strategy(mut self, layer_type: String, strategy: Box<dyn PruningStrategy>) -> Self {
self.strategies.insert(layer_type, strategy);
self
}
pub fn with_default_strategy(mut self, strategy: Box<dyn PruningStrategy>) -> Self {
self.default_strategy = strategy;
self
}
#[allow(dead_code)]
fn detect_layer_type(&self, layer_name: &str) -> String {
let name_lower = layer_name.to_lowercase();
if name_lower.contains("conv") {
"conv".to_string()
} else if name_lower.contains("attention") || name_lower.contains("attn") {
"attention".to_string()
} else if name_lower.contains("linear")
|| name_lower.contains("dense")
|| name_lower.contains("fc")
{
"linear".to_string()
} else if name_lower.contains("embed") {
"embedding".to_string()
} else {
"unknown".to_string()
}
}
}
impl Pruner for AutomaticPruner {
fn prune<M>(&self, model: M, config: &PruningConfig) -> Result<PruningResult<M>>
where
M: crate::traits::Model + Clone,
{
        let total_params = model.num_parameters();
        let estimated_pruned_params = (total_params as f32 * config.target_sparsity) as usize;
        let mut layer_sparsity = HashMap::new();
        // Simulated result: the model is returned unmodified, and per-layer sparsities
        // are scaled from rough fractions of the target until real weight access exists.
        let simulated_layers = vec![
            ("embedding", 0.2),
            ("attention", 0.4),
            ("feedforward", 0.6),
            ("output", 0.1),
        ];
        for (layer_type, base_sparsity) in simulated_layers {
            let actual_sparsity = (base_sparsity * config.target_sparsity).min(0.9);
            layer_sparsity.insert(layer_type.to_string(), actual_sparsity);
        }
        let overall_sparsity = config.target_sparsity;
        let pruned_model = model;
Ok(PruningResult {
model: pruned_model,
sparsity: overall_sparsity,
pruned_params: estimated_pruned_params,
total_params,
layer_sparsity,
})
}
fn estimate_pruning_potential<M>(
&self,
model: &M,
config: &PruningConfig,
) -> Result<PruningStats>
where
M: crate::traits::Model,
{
let total_params = model.num_parameters();
let estimated_zero_params = (total_params as f32 * config.target_sparsity) as usize;
        let mut layer_stats = HashMap::new();
        // Simulated breakdown: rough parameter fractions for a typical transformer.
        let simulated_layers = vec![
("embedding", 0.15),
("attention", 0.30),
("feedforward", 0.45),
("output", 0.05),
];
for (layer_name, param_fraction) in simulated_layers {
let layer_total = (total_params as f32 * param_fraction) as usize;
let layer_zeros = (layer_total as f32 * config.target_sparsity) as usize;
let layer_sparsity =
if layer_total > 0 { layer_zeros as f32 / layer_total as f32 } else { 0.0 };
layer_stats.insert(
layer_name.to_string(),
LayerPruningStats {
total_params: layer_total,
zero_params: layer_zeros,
sparsity: layer_sparsity,
},
);
}
let overall_sparsity = if total_params > 0 {
estimated_zero_params as f32 / total_params as f32
} else {
0.0
};
Ok(PruningStats {
total_params,
zero_params: estimated_zero_params,
sparsity: overall_sparsity,
layer_stats,
})
}
}
impl Default for AutomaticPruner {
fn default() -> Self {
Self::new()
}
}
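/// Stateless helpers for planning and validating pruning runs.
///
/// A small sketch of schedule generation (the cubic ramp implemented below):
///
/// ```ignore
/// let schedule = PruningUtils::generate_pruning_schedule(0.0, 0.8, 5);
/// assert_eq!(schedule.len(), 5); // monotonically ramps from 0.0 to 0.8
/// ```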
pub struct PruningUtils;
impl PruningUtils {
pub fn calculate_layer_sensitivities<M>(
model: &M,
_validation_data: &[Tensor],
) -> Result<HashMap<String, f32>>
where
M: crate::traits::Model,
{
let mut sensitivities = HashMap::new();
let _total_params = model.num_parameters();
        let typical_sensitivities = vec![
            ("embedding".to_string(), 0.95),
            ("attention".to_string(), 0.75),
            ("feedforward".to_string(), 0.50),
            ("output".to_string(), 0.90),
            ("classifier".to_string(), 0.90),
        ];
for (layer_name, sensitivity) in typical_sensitivities {
sensitivities.insert(layer_name, sensitivity);
}
Ok(sensitivities)
}
pub fn generate_pruning_schedule(
initial_sparsity: f32,
final_sparsity: f32,
num_steps: usize,
) -> Vec<f32> {
let mut schedule = Vec::new();
        for i in 0..num_steps {
            // Guard the single-step case, which would otherwise divide by zero (NaN).
            let progress = if num_steps > 1 {
                i as f32 / (num_steps - 1) as f32
            } else {
                1.0
            };
            // Cubic ramp: prune gently at first, then more aggressively near the end.
            let cubic_progress = progress * progress * progress;
            let sparsity = initial_sparsity + (final_sparsity - initial_sparsity) * cubic_progress;
            schedule.push(sparsity);
}
schedule
}
    pub fn estimate_compression_ratio(target_sparsity: f32, quantization_bits: Option<u8>) -> f32 {
        // Idealized ratio for a perfect sparse encoding; degenerates to `inf` at 100%.
        let sparsity_compression = 1.0 / (1.0 - target_sparsity);
        match quantization_bits {
            // Quantizing from f32 down to `bits` multiplies the ratio by 32 / bits.
            Some(bits) => sparsity_compression * (32.0 / bits as f32),
            None => sparsity_compression,
        }
    }
pub fn validate_config(config: &PruningConfig) -> Result<()> {
if config.target_sparsity < 0.0 || config.target_sparsity > 1.0 {
return Err(anyhow!("Target sparsity must be between 0.0 and 1.0"));
}
if config.iterations == 0 {
return Err(anyhow!("Number of iterations must be greater than 0"));
}
if let Some(threshold) = config.magnitude_threshold {
if threshold < 0.0 {
return Err(anyhow!("Magnitude threshold must be non-negative"));
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_pruning_config_default() {
let config = PruningConfig::default();
assert_eq!(config.target_sparsity, 0.5);
assert!(!config.iterative);
assert_eq!(config.iterations, 1);
assert!(config.fine_tune);
}
#[test]
fn test_magnitude_pruner() -> Result<()> {
let pruner = MagnitudePruner::new(0.5);
let weights = Tensor::from_vec(vec![0.1, -0.8, 0.3, -0.2, 0.9, -0.1], &[2, 3])?;
let config = PruningConfig {
target_sparsity: 0.5,
..Default::default()
};
let mask = pruner.get_mask(&weights, &config)?;
let mask_data = mask.data()?;
let zero_count = mask_data.iter().filter(|&&x| x == 0.0).count();
assert_eq!(zero_count, 3);
Ok(())
}
#[test]
fn test_pruning_utils_validation() {
let valid_config = PruningConfig::default();
assert!(PruningUtils::validate_config(&valid_config).is_ok());
        let invalid_config = PruningConfig {
            target_sparsity: 1.5,
            ..Default::default()
        };
assert!(PruningUtils::validate_config(&invalid_config).is_err());
}
#[test]
fn test_compression_ratio_estimation() {
let ratio = PruningUtils::estimate_compression_ratio(0.5, None);
assert_eq!(ratio, 2.0);
        let ratio_with_quant = PruningUtils::estimate_compression_ratio(0.5, Some(8));
        assert_eq!(ratio_with_quant, 8.0);
    }
#[test]
fn test_pruning_schedule() {
let schedule = PruningUtils::generate_pruning_schedule(0.0, 0.8, 5);
assert_eq!(schedule.len(), 5);
assert_eq!(schedule[0], 0.0);
assert_eq!(schedule[4], 0.8);
for i in 1..schedule.len() {
assert!(schedule[i] >= schedule[i - 1]);
}
}
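    // Additional checks for the schedule and structured paths; expected values follow
    // directly from the formulas implemented above.
    #[test]
    fn test_gradual_pruner_sparsity_schedule() {
        let pruner = GradualPruner::new(0.0, 0.8, 100, 200, 10);
        assert_eq!(pruner.get_sparsity_at_step(0), 0.0); // before the ramp begins
        assert_eq!(pruner.get_sparsity_at_step(150), 0.4); // halfway through the ramp
        assert_eq!(pruner.get_sparsity_at_step(200), 0.8); // clamped at the end
    }
    #[test]
    fn test_structured_pruner_prunes_low_norm_row() -> Result<()> {
        // Pruning dim 0 of a 2x3 matrix at 50% sparsity zeroes the row with the
        // smaller L2 norm and leaves the other row untouched.
        let pruner = StructuredPruner::new(0);
        let weights = Tensor::from_vec(vec![1.0, 1.0, 1.0, 0.1, 0.1, 0.1], &[2, 3])?;
        let config = PruningConfig {
            target_sparsity: 0.5,
            ..Default::default()
        };
        let pruned = pruner.prune_weights(&weights, &config)?;
        let vals: Vec<f32> = pruned.data()?.iter().copied().collect();
        assert_eq!(vals[0..3], [1.0, 1.0, 1.0]);
        assert_eq!(vals[3..6], [0.0, 0.0, 0.0]);
        Ok(())
    }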
}