use super::error::PruningError;
use crate::autograd::Tensor;
/// Outcome of a depth-pruning pass: which layers were dropped and how deep
/// the model is afterwards.
#[derive(Debug, Clone)]
pub struct DepthPruningResult {
    /// `(layer_index, f32)` pairs for the removed layers; the `f32` is
    /// presumably the importance score that got the layer selected — confirm
    /// with the producer.
    pub removed_layers: Vec<(usize, f32)>,
    /// Number of layers before pruning.
    pub original_depth: usize,
    /// Number of layers after pruning (as computed by [`DepthPruningResult::new`],
    /// this saturates at zero).
    pub final_depth: usize,
}

impl DepthPruningResult {
    /// Builds a result, deriving `final_depth` as `original_depth` minus the
    /// number of removed layers, clamped at zero rather than underflowing.
    #[must_use]
    pub fn new(removed_layers: Vec<(usize, f32)>, original_depth: usize) -> Self {
        let removed_count = removed_layers.len();
        Self {
            final_depth: original_depth.saturating_sub(removed_count),
            original_depth,
            removed_layers,
        }
    }

    /// Ratio `original_depth / final_depth`.
    ///
    /// Returns `f32::INFINITY` when no layers remain, so the division by zero
    /// never happens.
    #[must_use]
    pub fn compression_ratio(&self) -> f32 {
        match self.final_depth {
            0 => f32::INFINITY,
            remaining => self.original_depth as f32 / remaining as f32,
        }
    }

    /// Percentage of the original layers that were removed, in `[0, 100]` for
    /// consistent inputs.
    ///
    /// Returns `0.0` for a model that had no layers to begin with.
    #[must_use]
    pub fn removal_percentage(&self) -> f32 {
        if self.original_depth == 0 {
            return 0.0;
        }
        let removed = self.removed_layers.len() as f32;
        let total = self.original_depth as f32;
        removed / total * 100.0
    }
}
/// Per-layer importance scores, stored as `(layer_index, score)` pairs.
///
/// Lower scores mean a layer is less important: that is the order used by
/// [`BlockImportanceScores::sorted_by_importance`] and
/// [`BlockImportanceScores::least_important`].
#[derive(Debug, Clone)]
pub struct BlockImportanceScores {
    /// `(layer_index, score)` pairs, in the order they were supplied.
    pub scores: Vec<(usize, f32)>,
    /// Sample count supplied by the caller (presumably the number of inputs
    /// the scores were aggregated over — confirm with the producer).
    pub num_samples: usize,
}

impl BlockImportanceScores {
    /// Wraps precomputed scores without reordering or validating them.
    #[must_use]
    pub fn new(scores: Vec<(usize, f32)>, num_samples: usize) -> Self {
        Self {
            scores,
            num_samples,
        }
    }

    /// Returns the score recorded for `layer_idx`, or `None` if it is absent.
    ///
    /// If the index appears more than once, the first entry wins.
    #[must_use]
    pub fn get(&self, layer_idx: usize) -> Option<f32> {
        self.scores
            .iter()
            .find(|(idx, _)| *idx == layer_idx)
            .map(|(_, score)| *score)
    }

    /// Returns a copy of the scores sorted ascending (least important first).
    ///
    /// The sort is stable, so tied scores keep their original relative order.
    /// NaN scores are placed after every other value. That way a layer whose
    /// importance could not be determined is never picked as "least
    /// important", and the comparator is a genuine total order. Treating NaN
    /// as equal to everything is not transitive, which leaves the result of
    /// `sort_by` unspecified and may make it panic on recent toolchains.
    #[must_use]
    pub fn sorted_by_importance(&self) -> Vec<(usize, f32)> {
        let mut sorted = self.scores.clone();
        sorted.sort_by(|a, b| Self::cmp_scores(a.1, b.1));
        sorted
    }

    /// Returns the `n` least important `(layer_index, score)` pairs, or all of
    /// them if there are fewer than `n`.
    #[must_use]
    pub fn least_important(&self, n: usize) -> Vec<(usize, f32)> {
        let mut sorted = self.sorted_by_importance();
        sorted.truncate(n);
        sorted
    }

    /// Total order on scores: ascending by value, NaN last, `-0.0 == 0.0`.
    fn cmp_scores(a: f32, b: f32) -> std::cmp::Ordering {
        use std::cmp::Ordering;
        match (a.is_nan(), b.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            // Neither is NaN, so `partial_cmp` always yields `Some`.
            (false, false) => a.partial_cmp(&b).unwrap_or(Ordering::Equal),
        }
    }
}
/// Removes whole layers from a model based on per-layer importance scores.
///
/// A layer's importance is `1 - cos(input, output)`. Layers whose output is
/// nearly identical to their input therefore score close to zero and are the
/// first candidates for removal.
#[derive(Debug, Clone)]
pub struct DepthPruner {
    /// How many layers to remove.
    num_layers_to_remove: usize,
    /// Iterative-mode flag; only exposed via [`DepthPruner::is_iterative`] here
    /// (presumably selects iterative removal elsewhere — confirm).
    iterative: bool,
    /// Minimum number of layers that must remain after pruning.
    min_layers: usize,
}

impl DepthPruner {
    /// Creates a pruner that removes `num_layers_to_remove` layers, with
    /// `iterative = true` and `min_layers = 1`.
    #[must_use]
    pub fn new(num_layers_to_remove: usize) -> Self {
        Self {
            num_layers_to_remove,
            iterative: true,
            min_layers: 1,
        }
    }

    /// Builder-style setter for the iterative flag.
    #[must_use]
    pub fn with_iterative(mut self, iterative: bool) -> Self {
        self.iterative = iterative;
        self
    }

    /// Builder-style setter for the minimum number of layers that must survive.
    #[must_use]
    pub fn with_min_layers(mut self, min_layers: usize) -> Self {
        self.min_layers = min_layers;
        self
    }

    /// Number of layers this pruner is configured to remove.
    #[must_use]
    pub fn num_layers_to_remove(&self) -> usize {
        self.num_layers_to_remove
    }

    /// Whether the iterative flag is set.
    #[must_use]
    pub fn is_iterative(&self) -> bool {
        self.iterative
    }

    /// Cosine similarity of two tensors, computed over their flattened data.
    ///
    /// Returns `1.0` when both tensors are empty, or when both are numerically
    /// all zeros (every element `< 1e-10` in magnitude). Two "nothing" tensors
    /// are treated as identical.
    ///
    /// # Errors
    ///
    /// [`PruningError::ShapeMismatch`] if the tensors hold a different number
    /// of elements. Only element counts are compared: tensors with equal
    /// counts but different shapes are compared element-wise in data order.
    pub fn cosine_similarity(a: &Tensor, b: &Tensor) -> Result<f32, PruningError> {
        let a_data = a.data();
        let b_data = b.data();
        if a_data.len() != b_data.len() {
            return Err(PruningError::ShapeMismatch {
                expected: a.shape().to_vec(),
                got: b.shape().to_vec(),
            });
        }
        if a_data.is_empty() {
            return Ok(1.0);
        }
        // Zero vectors have no direction, so their cosine is undefined. Check
        // this before calling the similarity helper so the degenerate
        // zero-norm case never reaches it.
        let both_zero =
            a_data.iter().all(|&x| x.abs() < 1e-10) && b_data.iter().all(|&x| x.abs() < 1e-10);
        if both_zero {
            return Ok(1.0);
        }
        Ok(crate::nn::functional::cosine_similarity_slice(a_data, b_data))
    }

    /// Importance of a single layer: `1 - cosine_similarity(input, output)`.
    ///
    /// A layer that barely changes its input scores near zero.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`DepthPruner::cosine_similarity`].
    pub fn compute_layer_importance(input: &Tensor, output: &Tensor) -> Result<f32, PruningError> {
        Self::cosine_similarity(input, output).map(|cos_sim| 1.0 - cos_sim)
    }

    /// Scores every layer from its recorded input and output activations.
    ///
    /// Inputs and outputs are paired by position, and that position is used as
    /// the layer index in the returned scores. Empty input yields empty scores
    /// with `num_samples == 0`; otherwise `num_samples` is `1`.
    ///
    /// # Errors
    ///
    /// [`PruningError::ShapeMismatch`] if the two slices differ in length.
    /// Otherwise, the first error returned by
    /// [`DepthPruner::compute_layer_importance`].
    pub fn compute_block_importance(
        &self,
        layer_inputs: &[Tensor],
        layer_outputs: &[Tensor],
    ) -> Result<BlockImportanceScores, PruningError> {
        if layer_inputs.len() != layer_outputs.len() {
            return Err(PruningError::ShapeMismatch {
                expected: vec![layer_inputs.len()],
                got: vec![layer_outputs.len()],
            });
        }
        if layer_inputs.is_empty() {
            return Ok(BlockImportanceScores::new(vec![], 0));
        }
        let scores = layer_inputs
            .iter()
            .zip(layer_outputs)
            .enumerate()
            .map(|(idx, (input, output))| {
                Self::compute_layer_importance(input, output).map(|bi| (idx, bi))
            })
            .collect::<Result<Vec<_>, _>>()?;
        Ok(BlockImportanceScores::new(scores, 1))
    }

    /// Picks which layers to drop: the `num_layers_to_remove` lowest-scoring
    /// entries of `scores`, returned as layer indices in descending order.
    ///
    /// If `scores` holds fewer entries than requested, only those are returned.
    ///
    /// # Errors
    ///
    /// [`PruningError::InvalidSparsity`] if removing `num_layers_to_remove`
    /// layers from `num_layers` would leave fewer than `min_layers`.
    pub fn select_layers_to_remove(
        &self,
        scores: &BlockImportanceScores,
        num_layers: usize,
    ) -> Result<Vec<usize>, PruningError> {
        let max_removable = num_layers.saturating_sub(self.min_layers);
        if self.num_layers_to_remove > max_removable {
            return Err(PruningError::InvalidSparsity {
                value: self.num_layers_to_remove as f32,
                constraint: format!(
                    "Cannot remove {} layers from {} total (min {} required, max removable: {})",
                    self.num_layers_to_remove, num_layers, self.min_layers, max_removable
                ),
            });
        }
        let mut to_remove: Vec<usize> = scores
            .least_important(self.num_layers_to_remove)
            .into_iter()
            .map(|(idx, _)| idx)
            .collect();
        // Descending order — presumably so a caller can delete layers
        // back-to-front without shifting indices still pending removal
        // (confirm with callers).
        to_remove.sort_unstable_by(|a, b| b.cmp(a));
        Ok(to_remove)
    }

    /// Checks that this pruner's settings are feasible for a model with
    /// `num_layers` layers.
    ///
    /// # Errors
    ///
    /// [`PruningError::InvalidSparsity`] in two cases: the model already has
    /// fewer than `min_layers` layers, or removing `num_layers_to_remove`
    /// layers would leave fewer than `min_layers`.
    pub fn validate(&self, num_layers: usize) -> Result<(), PruningError> {
        if num_layers < self.min_layers {
            return Err(PruningError::InvalidSparsity {
                value: num_layers as f32,
                constraint: format!(
                    "Model has {} layers but minimum is {}",
                    num_layers, self.min_layers
                ),
            });
        }
        let max_removable = num_layers.saturating_sub(self.min_layers);
        if self.num_layers_to_remove > max_removable {
            return Err(PruningError::InvalidSparsity {
                value: self.num_layers_to_remove as f32,
                constraint: format!(
                    "Cannot remove {} layers from {} (max removable: {})",
                    self.num_layers_to_remove, num_layers, max_removable
                ),
            });
        }
        Ok(())
    }
}

impl Default for DepthPruner {
    /// A pruner that removes no layers (`DepthPruner::new(0)`).
    fn default() -> Self {
        Self::new(0)
    }
}
#[cfg(test)]
#[path = "depth_tests.rs"]
mod tests;