use super::error::PruningError;
use crate::autograd::Tensor;
/// Summary of a width-pruning pass: how the hidden and intermediate
/// dimensions changed and which channel indices survived.
#[derive(Debug, Clone)]
pub struct WidthPruningResult {
    /// Hidden dimension before pruning.
    pub original_hidden_dim: usize,
    /// Hidden dimension after pruning.
    pub final_hidden_dim: usize,
    /// Intermediate (FFN) dimension before pruning.
    pub original_intermediate_dim: usize,
    /// Intermediate (FFN) dimension after pruning.
    pub final_intermediate_dim: usize,
    /// Indices of the hidden channels that were kept.
    pub hidden_channels_kept: Vec<usize>,
    /// Indices of the intermediate channels that were kept.
    pub intermediate_channels_kept: Vec<usize>,
}

impl WidthPruningResult {
    /// Bundles the pre/post dimensions and the surviving channel indices.
    #[must_use]
    pub fn new(
        original_hidden_dim: usize,
        final_hidden_dim: usize,
        original_intermediate_dim: usize,
        final_intermediate_dim: usize,
        hidden_channels_kept: Vec<usize>,
        intermediate_channels_kept: Vec<usize>,
    ) -> Self {
        Self {
            original_hidden_dim,
            final_hidden_dim,
            original_intermediate_dim,
            final_intermediate_dim,
            hidden_channels_kept,
            intermediate_channels_kept,
        }
    }

    /// Original / final hidden dimension; `INFINITY` when everything was pruned.
    #[must_use]
    pub fn hidden_compression_ratio(&self) -> f32 {
        Self::compression_ratio(self.original_hidden_dim, self.final_hidden_dim)
    }

    /// Original / final intermediate dimension; `INFINITY` when everything was pruned.
    #[must_use]
    pub fn intermediate_compression_ratio(&self) -> f32 {
        Self::compression_ratio(self.original_intermediate_dim, self.final_intermediate_dim)
    }

    /// Percentage of hidden channels removed (0.0 when the original dim is 0).
    #[must_use]
    pub fn hidden_removal_percentage(&self) -> f32 {
        Self::removal_percentage(self.original_hidden_dim, self.final_hidden_dim)
    }

    /// Percentage of intermediate channels removed (0.0 when the original dim is 0).
    #[must_use]
    pub fn intermediate_removal_percentage(&self) -> f32 {
        Self::removal_percentage(self.original_intermediate_dim, self.final_intermediate_dim)
    }

    /// Shared ratio helper; maps a zero denominator to `INFINITY`.
    fn compression_ratio(original: usize, pruned: usize) -> f32 {
        if pruned == 0 {
            f32::INFINITY
        } else {
            original as f32 / pruned as f32
        }
    }

    /// Shared removal-percentage helper. Uses `saturating_sub` so an
    /// inconsistent state (`final > original`, reachable because the fields
    /// are public) reports 0% removed instead of panicking on usize
    /// underflow in debug builds.
    fn removal_percentage(original: usize, pruned: usize) -> f32 {
        if original == 0 {
            0.0
        } else {
            let removed = original.saturating_sub(pruned);
            removed as f32 / original as f32 * 100.0
        }
    }
}
/// Per-channel importance scores (mean squared activation over a batch of
/// samples) for the hidden and intermediate dimensions, as produced by
/// `WidthPruner::compute_channel_importance`.
#[derive(Debug, Clone)]
pub struct ChannelImportance {
    /// One score per hidden channel; tensor length equals the hidden dim.
    pub hidden: Tensor,
    /// One score per intermediate channel; tensor length equals the intermediate dim.
    pub intermediate: Tensor,
    /// Number of activation samples the scores were averaged over.
    pub num_samples: usize,
}
impl ChannelImportance {
    /// Wraps the per-channel score tensors together with the number of
    /// samples they were averaged over.
    #[must_use]
    pub fn new(hidden: Tensor, intermediate: Tensor, num_samples: usize) -> Self {
        Self { hidden, intermediate, num_samples }
    }

    /// Number of hidden channels scored.
    #[must_use]
    pub fn hidden_dim(&self) -> usize {
        self.hidden.data().len()
    }

    /// Number of intermediate channels scored.
    #[must_use]
    pub fn intermediate_dim(&self) -> usize {
        self.intermediate.data().len()
    }

    /// Indices (returned in ascending order) of the `k` highest-scoring
    /// hidden channels.
    #[must_use]
    pub fn top_hidden_channels(&self, k: usize) -> Vec<usize> {
        top_k_indices(self.hidden.data(), k)
    }

    /// Indices (returned in ascending order) of the `k` highest-scoring
    /// intermediate channels.
    #[must_use]
    pub fn top_intermediate_channels(&self, k: usize) -> Vec<usize> {
        top_k_indices(self.intermediate.data(), k)
    }
}
/// Returns the indices of the `k` largest values in `data`, sorted ascending.
///
/// Ordering uses `f32::total_cmp`, a total order, instead of
/// `partial_cmp(..).unwrap_or(Equal)` — the latter yields an inconsistent
/// comparator when NaN is present, making the sort result unspecified.
/// Exact-value ties are broken in favor of the lower index, so the
/// selection is fully deterministic. If `k >= data.len()`, every index is
/// returned.
fn top_k_indices(data: &[f32], k: usize) -> Vec<usize> {
    let mut indexed: Vec<(usize, f32)> = data.iter().copied().enumerate().collect();
    // Descending by value, then ascending by index; the explicit index
    // tie-break lets us use the faster non-allocating unstable sort.
    indexed.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
    let mut result: Vec<usize> = indexed.into_iter().take(k).map(|(idx, _)| idx).collect();
    result.sort_unstable();
    result
}
/// Width (structured) pruner: shrinks a model's hidden and intermediate
/// dimensions to fixed target sizes by keeping only the highest-importance
/// channels.
#[derive(Debug, Clone)]
pub struct WidthPruner {
    /// Hidden width to keep; must be divisible by `num_attention_heads`
    /// (enforced by `validate`).
    target_hidden_dim: usize,
    /// Intermediate (FFN) width to keep.
    target_intermediate_dim: usize,
    /// Attention-head count used for the divisibility check and for
    /// computing the per-head dimension after pruning.
    num_attention_heads: usize,
}
impl WidthPruner {
    /// Creates a pruner targeting the given hidden and intermediate widths.
    ///
    /// Divisibility of `target_hidden_dim` by `num_attention_heads` is not
    /// checked here — call [`Self::validate`] before pruning.
    #[must_use]
    pub fn new(
        target_hidden_dim: usize,
        target_intermediate_dim: usize,
        num_attention_heads: usize,
    ) -> Self {
        Self {
            target_hidden_dim,
            target_intermediate_dim,
            num_attention_heads,
        }
    }

    /// Hidden width the model will have after pruning.
    #[must_use]
    pub fn target_hidden_dim(&self) -> usize {
        self.target_hidden_dim
    }

    /// Intermediate (FFN) width the model will have after pruning.
    #[must_use]
    pub fn target_intermediate_dim(&self) -> usize {
        self.target_intermediate_dim
    }

    /// Attention-head count the pruned hidden dim must divide into.
    #[must_use]
    pub fn num_attention_heads(&self) -> usize {
        self.num_attention_heads
    }

    /// Checks that the configured targets are achievable for a model with
    /// the given original dimensions.
    ///
    /// # Errors
    /// - [`PruningError::InvalidPattern`] if `target_hidden_dim` is not a
    ///   multiple of `num_attention_heads`.
    /// - [`PruningError::InvalidSparsity`] if either target exceeds the
    ///   corresponding original dimension.
    pub fn validate(
        &self,
        original_hidden_dim: usize,
        original_intermediate_dim: usize,
    ) -> Result<(), PruningError> {
        if !self
            .target_hidden_dim
            .is_multiple_of(self.num_attention_heads)
        {
            return Err(PruningError::InvalidPattern {
                message: format!(
                    "target_hidden_dim ({}) must be divisible by num_attention_heads ({})",
                    self.target_hidden_dim, self.num_attention_heads
                ),
            });
        }
        if self.target_hidden_dim > original_hidden_dim {
            return Err(PruningError::InvalidSparsity {
                value: self.target_hidden_dim as f32,
                constraint: format!(
                    "target_hidden_dim ({}) exceeds original ({})",
                    self.target_hidden_dim, original_hidden_dim
                ),
            });
        }
        if self.target_intermediate_dim > original_intermediate_dim {
            return Err(PruningError::InvalidSparsity {
                value: self.target_intermediate_dim as f32,
                constraint: format!(
                    "target_intermediate_dim ({}) exceeds original ({})",
                    self.target_intermediate_dim, original_intermediate_dim
                ),
            });
        }
        Ok(())
    }

    /// Scores every channel by its mean squared activation over the batch.
    ///
    /// Both tensors must be 2-D with matching row counts:
    /// `hidden_activations` is `[num_samples, hidden_dim]` and
    /// `intermediate_activations` is `[num_samples, intermediate_dim]`.
    ///
    /// # Errors
    /// Returns [`PruningError::ShapeMismatch`] if either tensor is not 2-D
    /// or the sample counts disagree.
    pub fn compute_channel_importance(
        &self,
        hidden_activations: &Tensor,
        intermediate_activations: &Tensor,
    ) -> Result<ChannelImportance, PruningError> {
        let h_shape = hidden_activations.shape();
        let i_shape = intermediate_activations.shape();
        if h_shape.len() != 2 {
            return Err(PruningError::ShapeMismatch {
                // `[0, 0]` is a placeholder meaning "any 2-D shape".
                expected: vec![0, 0],
                got: h_shape.to_vec(),
            });
        }
        if i_shape.len() != 2 {
            return Err(PruningError::ShapeMismatch {
                expected: vec![0, 0],
                got: i_shape.to_vec(),
            });
        }
        if h_shape[0] != i_shape[0] {
            return Err(PruningError::ShapeMismatch {
                expected: vec![h_shape[0], i_shape[1]],
                got: vec![i_shape[0], i_shape[1]],
            });
        }
        let num_samples = h_shape[0];
        let hidden_dim = h_shape[1];
        let intermediate_dim = i_shape[1];
        // The per-channel computation is identical for both tensors, so it
        // is factored into a single helper instead of being duplicated.
        let hidden_importance =
            Self::mean_square_per_channel(hidden_activations.data(), num_samples, hidden_dim);
        let intermediate_importance = Self::mean_square_per_channel(
            intermediate_activations.data(),
            num_samples,
            intermediate_dim,
        );
        Ok(ChannelImportance::new(
            Tensor::new(&hidden_importance, &[hidden_dim]),
            Tensor::new(&intermediate_importance, &[intermediate_dim]),
            num_samples,
        ))
    }

    /// Column-wise mean of squares over a row-major `[num_samples, dim]`
    /// buffer. Returns all zeros when `num_samples` is 0. Iterating by
    /// whole rows avoids per-element bounds checks.
    fn mean_square_per_channel(data: &[f32], num_samples: usize, dim: usize) -> Vec<f32> {
        let mut importance = vec![0.0f32; dim];
        if dim > 0 {
            // Guard: `chunks_exact(0)` would panic.
            for row in data.chunks_exact(dim).take(num_samples) {
                for (acc, &val) in importance.iter_mut().zip(row) {
                    *acc += val * val;
                }
            }
        }
        if num_samples > 0 {
            for acc in &mut importance {
                *acc /= num_samples as f32;
            }
        }
        importance
    }

    /// Picks the channels to keep: validates the targets against the
    /// importance dims, then takes the top-k channels of each kind.
    ///
    /// Returns `(hidden_indices, intermediate_indices)`, each sorted
    /// ascending.
    ///
    /// # Errors
    /// Propagates any error from [`Self::validate`].
    pub fn select_channels_to_keep(
        &self,
        importance: &ChannelImportance,
    ) -> Result<(Vec<usize>, Vec<usize>), PruningError> {
        self.validate(importance.hidden_dim(), importance.intermediate_dim())?;
        let hidden_keep = importance.top_hidden_channels(self.target_hidden_dim);
        let intermediate_keep = importance.top_intermediate_channels(self.target_intermediate_dim);
        Ok((hidden_keep, intermediate_keep))
    }

    /// Builds a 0/1 mask of length `original_dim` with 1.0 at each kept
    /// channel index. Out-of-range indices are silently ignored.
    #[must_use]
    pub fn generate_hidden_mask(&self, original_dim: usize, channels_to_keep: &[usize]) -> Tensor {
        let mut mask = vec![0.0f32; original_dim];
        for &idx in channels_to_keep {
            if idx < original_dim {
                mask[idx] = 1.0;
            }
        }
        Tensor::new(&mask, &[original_dim])
    }

    /// Per-head dimension after pruning (`target_hidden_dim / heads`);
    /// 0 when there are no attention heads, avoiding a divide-by-zero.
    #[must_use]
    pub fn head_dim_after_pruning(&self) -> usize {
        if self.num_attention_heads == 0 {
            0
        } else {
            self.target_hidden_dim / self.num_attention_heads
        }
    }
}
impl Default for WidthPruner {
    /// Default configuration: zero target widths and a single attention
    /// head (so the divisibility check in `validate` trivially passes).
    fn default() -> Self {
        Self {
            target_hidden_dim: 0,
            target_intermediate_dim: 0,
            num_attention_heads: 1,
        }
    }
}
#[cfg(test)]
#[path = "width_tests.rs"]
mod tests;