use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use trustformers_core::{
errors::{Result, TrustformersError},
tensor::Tensor,
};
/// Strategy used to decide which tokens to drop during a forward pass.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum PruningStrategy {
/// Rank tokens by attention-derived importance scores.
AttentionBased,
/// Prune tokens the model is already confident about.
ConfidenceBased,
/// Score tokens with a small learned gating network.
LearnedGates,
/// Use a different pruning ratio per layer.
LayerAdaptive,
/// Ramp the pruning ratio up through the layer stack.
Progressive,
/// Average the importance scores produced by several sub-strategies.
Hybrid(Vec<PruningStrategy>),
}
/// Settings for attention-based token pruning.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttentionBasedPruningConfig {
/// Minimum attention score below which a token is a pruning candidate.
/// NOTE(review): not read by the visible pruning code — confirm it is used elsewhere.
pub attention_threshold: f32,
/// Lower bound on the fraction of tokens that must survive.
pub min_tokens_ratio: f32,
/// Upper bound on the fraction of tokens that may be removed.
pub max_pruning_ratio: f32,
/// If true, importance scores are standardized and squashed through a sigmoid.
pub use_adaptive_threshold: bool,
/// Attention head to read scores from; a negative value means "average all heads".
pub attention_head_index: i32,
/// Number of top-scoring tokens that are always kept.
pub keep_top_k: usize,
}
impl Default for AttentionBasedPruningConfig {
fn default() -> Self {
Self {
attention_threshold: 0.1,
min_tokens_ratio: 0.3,
max_pruning_ratio: 0.7,
use_adaptive_threshold: true,
attention_head_index: -1, keep_top_k: 1, }
}
}
/// Settings for confidence-based token pruning.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConfidenceBasedPruningConfig {
/// Confidence above which a token is considered "done".
/// NOTE(review): not read by the visible pruning code — confirm it is used elsewhere.
pub confidence_threshold: f32,
/// If true, confidence is derived from normalized prediction entropy.
pub use_entropy: bool,
/// Lower bound on the fraction of tokens that must survive.
pub min_tokens_ratio: f32,
/// Window size (in tokens) for smoothing confidence scores; <= 1 disables smoothing.
pub lookahead_window: usize,
}
impl Default for ConfidenceBasedPruningConfig {
fn default() -> Self {
Self {
confidence_threshold: 0.9,
use_entropy: true,
min_tokens_ratio: 0.3,
lookahead_window: 5,
}
}
}
/// Settings for the learned-gate pruning network.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LearnedGatePruningConfig {
/// Hidden width of the gating MLP.
pub gate_hidden_dim: usize,
/// Temperature for the Gumbel/Concrete relaxation (lower = harder gates).
pub temperature: f32,
/// Weight of a sparsity regularizer.
/// NOTE(review): not read by the visible code — presumably applied by a training loop.
pub sparsity_weight: f32,
/// If true, gate probabilities come from the Gumbel relaxation rather than plain sigmoid.
pub use_straight_through: bool,
}
impl Default for LearnedGatePruningConfig {
fn default() -> Self {
Self {
gate_hidden_dim: 64,
temperature: 1.0,
sparsity_weight: 0.01,
use_straight_through: true,
}
}
}
/// Settings for per-layer adaptive pruning.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerAdaptivePruningConfig {
/// Explicit per-layer pruning ratios; layers beyond this list fall back to the
/// `base_pruning_ratio * depth_adaptation_factor^layer` schedule.
pub layer_pruning_ratios: Vec<f32>,
/// Fallback ratio for layer 0 of the exponential schedule.
pub base_pruning_ratio: f32,
/// Multiplicative growth per layer. NOTE(review): values > 1.0 grow without
/// bound with depth, so deep layers can request ratios above 1.0.
pub depth_adaptation_factor: f32,
}
impl Default for LayerAdaptivePruningConfig {
fn default() -> Self {
Self {
layer_pruning_ratios: vec![],
base_pruning_ratio: 0.3,
depth_adaptation_factor: 1.1, }
}
}
/// Settings for progressive (depth-ramped) pruning.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProgressivePruningConfig {
/// Pruning ratio at the first layer.
pub initial_pruning_ratio: f32,
/// Pruning ratio at the last layer.
pub final_pruning_ratio: f32,
/// Interpolation curve between the initial and final ratio.
pub progression_schedule: ProgressionSchedule,
}
/// How the pruning ratio moves from initial to final across layers.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ProgressionSchedule {
/// Straight-line interpolation.
Linear,
/// Geometric interpolation (assumes a non-zero initial ratio).
Exponential,
/// Half-cosine ease-in/ease-out interpolation.
Cosine,
/// Explicit per-layer ratios; layers past the end use the final ratio.
Custom(Vec<f32>),
}
impl Default for ProgressivePruningConfig {
fn default() -> Self {
Self {
initial_pruning_ratio: 0.1,
final_pruning_ratio: 0.5,
progression_schedule: ProgressionSchedule::Linear,
}
}
}
/// Per-token scores and decisions produced by a pruning pass.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenImportance {
/// One importance score per original token (higher = more likely kept).
pub importance_scores: Vec<f32>,
/// Original positions (0..seq_len) the scores refer to.
pub token_indices: Vec<usize>,
/// True for tokens that survived pruning.
pub keep_mask: Vec<bool>,
/// Why each token was kept or dropped.
pub pruning_reasons: Vec<PruningReason>,
}
/// Why a token ended up kept or pruned.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum PruningReason {
/// Default: pruned for receiving little attention.
LowAttention,
/// Pruned because the model was already confident about it.
HighConfidence,
/// Decision made by the learned gate network.
LearnedGate,
/// Decision made by a per-layer policy.
LayerPolicy,
/// In the top-k set that is never pruned.
AlwaysKeep,
/// Kept to satisfy the minimum-tokens ratio.
MinimumRatio,
}
/// Output of a single pruning pass over one layer's hidden states.
#[derive(Debug, Clone)]
pub struct PruningResult {
/// Hidden states restricted to the surviving tokens.
pub pruned_hidden_states: Tensor,
/// Attention mask matching the pruned sequence length.
pub pruned_attention_mask: Tensor,
/// Per-token scores and keep/prune decisions.
pub token_importance: TokenImportance,
/// Sequence length before pruning.
pub original_length: usize,
/// Sequence length after pruning.
pub pruned_length: usize,
/// pruned_length / original_length (1.0 = nothing pruned).
pub compression_ratio: f32,
}
/// Tiny two-layer gating head that scores each token's "keep" probability.
#[derive(Debug, Clone)]
pub struct LearnedGateNetwork {
    /// Projection from the model hidden size into the gate's hidden space.
    pub gate_linear: Tensor,
    /// Bias of the gate's hidden layer.
    pub gate_bias: Tensor,
    /// Output projection (gate hidden space -> one logit per token).
    ///
    /// BUGFIX: previously `forward` sampled a *fresh random* output matrix on
    /// every call, making the gate output nondeterministic across calls and
    /// impossible to learn. It is now a persistent parameter.
    pub gate_output: Tensor,
    config: LearnedGatePruningConfig,
}
impl LearnedGateNetwork {
    /// Randomly initializes gate parameters for `input_dim`-sized hidden states.
    pub fn new(input_dim: usize, config: LearnedGatePruningConfig) -> Result<Self> {
        let gate_linear = Tensor::randn(&[input_dim, config.gate_hidden_dim])?;
        let gate_bias = Tensor::zeros(&[config.gate_hidden_dim])?;
        let gate_output = Tensor::randn(&[config.gate_hidden_dim, 1])?;
        Ok(Self {
            gate_linear,
            gate_bias,
            gate_output,
            config,
        })
    }
    /// Computes per-token keep probabilities with shape `[batch, seq, 1]`.
    pub fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        let batch_size = hidden_states.shape()[0];
        let seq_len = hidden_states.shape()[1];
        let hidden_dim = hidden_states.shape()[2];
        // Flatten (batch, seq) so the projection is a single matmul.
        let reshaped = hidden_states.reshape(&[batch_size * seq_len, hidden_dim])?;
        let gate_hidden = reshaped.matmul(&self.gate_linear)?.add(&self.gate_bias)?;
        let gate_activated = gate_hidden.tanh()?;
        // Use the stored output projection (see `gate_output` docs).
        let gate_logits = gate_activated.matmul(&self.gate_output)?;
        let gate_probs = if self.config.use_straight_through {
            self.gumbel_softmax(&gate_logits)?
        } else {
            gate_logits.sigmoid()?
        };
        gate_probs.reshape(&[batch_size, seq_len, 1])
    }
    /// Binary Gumbel/Concrete relaxation: sigmoid((logits + g) / temperature).
    fn gumbel_softmax(&self, logits: &Tensor) -> Result<Tensor> {
        let gumbel_noise = self.sample_gumbel(logits.shape())?;
        let noisy_logits = logits.add(&gumbel_noise)?;
        let scaled_logits = noisy_logits.scalar_div(self.config.temperature)?;
        scaled_logits.sigmoid()
    }
    /// Draws approximate Gumbel(0, 1) noise via G = -ln(-ln(U)).
    ///
    /// NOTE(review): U is approximated by sigmoid of a standard normal, which is
    /// not uniform — acceptable for a heuristic, but not a faithful Gumbel sample.
    fn sample_gumbel(&self, shape: Vec<usize>) -> Result<Tensor> {
        let normal = Tensor::randn(&shape)?;
        let uniform = normal.sigmoid()?;
        // Epsilon keeps ln() away from ln(0).
        let eps = 1e-7;
        let eps_tensor = Tensor::ones(&shape)?.scalar_mul(eps)?;
        let clamped = uniform.add(&eps_tensor)?;
        let log_uniform = clamped.log()?;
        let neg_log_uniform = log_uniform.scalar_mul(-1.0)?;
        let log_neg_log_uniform = neg_log_uniform.log()?;
        log_neg_log_uniform.scalar_mul(-1.0)
    }
}
/// Token pruner parameterized by one of the [`PruningStrategy`] variants.
/// Only the config matching the chosen strategy is populated.
#[derive(Debug, Clone)]
pub struct DynamicPruner {
/// Which pruning algorithm `prune_tokens` dispatches to.
strategy: PruningStrategy,
attention_config: Option<AttentionBasedPruningConfig>,
confidence_config: Option<ConfidenceBasedPruningConfig>,
learned_gate_config: Option<LearnedGatePruningConfig>,
layer_adaptive_config: Option<LayerAdaptivePruningConfig>,
progressive_config: Option<ProgressivePruningConfig>,
/// Gating network; present only for the LearnedGates strategy.
gate_network: Option<LearnedGateNetwork>,
}
impl DynamicPruner {
pub fn attention_based(config: AttentionBasedPruningConfig) -> Self {
Self {
strategy: PruningStrategy::AttentionBased,
attention_config: Some(config),
confidence_config: None,
learned_gate_config: None,
layer_adaptive_config: None,
progressive_config: None,
gate_network: None,
}
}
pub fn confidence_based(config: ConfidenceBasedPruningConfig) -> Self {
Self {
strategy: PruningStrategy::ConfidenceBased,
attention_config: None,
confidence_config: Some(config),
learned_gate_config: None,
layer_adaptive_config: None,
progressive_config: None,
gate_network: None,
}
}
pub fn learned_gates(input_dim: usize, config: LearnedGatePruningConfig) -> Result<Self> {
let gate_network = LearnedGateNetwork::new(input_dim, config.clone())?;
Ok(Self {
strategy: PruningStrategy::LearnedGates,
attention_config: None,
confidence_config: None,
learned_gate_config: Some(config),
layer_adaptive_config: None,
progressive_config: None,
gate_network: Some(gate_network),
})
}
pub fn layer_adaptive(config: LayerAdaptivePruningConfig) -> Self {
Self {
strategy: PruningStrategy::LayerAdaptive,
attention_config: None,
confidence_config: None,
learned_gate_config: None,
layer_adaptive_config: Some(config),
progressive_config: None,
gate_network: None,
}
}
pub fn progressive(config: ProgressivePruningConfig) -> Self {
Self {
strategy: PruningStrategy::Progressive,
attention_config: None,
confidence_config: None,
learned_gate_config: None,
layer_adaptive_config: None,
progressive_config: Some(config),
gate_network: None,
}
}
pub fn prune_tokens(
&self,
hidden_states: &Tensor,
attention_scores: Option<&Tensor>,
layer_index: Option<usize>,
total_layers: Option<usize>,
) -> Result<PruningResult> {
match &self.strategy {
PruningStrategy::AttentionBased => {
let config = self.attention_config.as_ref().ok_or_else(|| {
TrustformersError::invalid_config(
"AttentionBased strategy requires attention_config".to_string(),
)
})?;
self.attention_based_pruning(hidden_states, attention_scores, config)
},
PruningStrategy::ConfidenceBased => {
let config = self.confidence_config.as_ref().ok_or_else(|| {
TrustformersError::invalid_config(
"ConfidenceBased strategy requires confidence_config".to_string(),
)
})?;
self.confidence_based_pruning(hidden_states, config)
},
PruningStrategy::LearnedGates => {
let config = self.learned_gate_config.as_ref().ok_or_else(|| {
TrustformersError::invalid_config(
"LearnedGates strategy requires learned_gate_config".to_string(),
)
})?;
self.learned_gate_pruning(hidden_states, config)
},
PruningStrategy::LayerAdaptive => {
let config = self.layer_adaptive_config.as_ref().ok_or_else(|| {
TrustformersError::invalid_config(
"LayerAdaptive strategy requires layer_adaptive_config".to_string(),
)
})?;
self.layer_adaptive_pruning(hidden_states, layer_index.unwrap_or(0), config)
},
PruningStrategy::Progressive => {
let config = self.progressive_config.as_ref().ok_or_else(|| {
TrustformersError::invalid_config(
"Progressive strategy requires progressive_config".to_string(),
)
})?;
self.progressive_pruning(
hidden_states,
layer_index.unwrap_or(0),
total_layers.unwrap_or(12),
config,
)
},
PruningStrategy::Hybrid(strategies) => self.hybrid_pruning(
hidden_states,
strategies,
attention_scores,
layer_index,
total_layers,
),
}
}
/// Prunes tokens based on attention-derived importance scores.
fn attention_based_pruning(
    &self,
    hidden_states: &Tensor,
    attention_scores: Option<&Tensor>,
    config: &AttentionBasedPruningConfig,
) -> Result<PruningResult> {
    // Attention weights are mandatory for this strategy.
    let attention_scores = attention_scores.ok_or_else(|| {
        TrustformersError::invalid_operation(
            "Attention scores required for attention-based pruning".to_string(),
        )
    })?;
    let seq_len = hidden_states.shape()[1];
    // Select one head, or average across all heads when the index is negative.
    let attention_weights = if config.attention_head_index < 0 {
        let head_sum = attention_scores.sum(Some(vec![1]), false)?;
        head_sum.scalar_div(attention_scores.shape()[1] as f32)?
    } else {
        let head_idx = config.attention_head_index as usize;
        attention_scores.slice(1, head_idx, head_idx + 1)?
    };
    let importance_scores = self.compute_attention_importance(&attention_weights, config)?;
    let (keep_mask, pruning_reasons) = self.determine_tokens_to_keep(
        &importance_scores,
        config.min_tokens_ratio,
        config.max_pruning_ratio,
        config.keep_top_k,
    )?;
    let pruned_hidden_states = self.apply_pruning_mask(hidden_states, &keep_mask)?;
    let pruned_attention_mask = self.create_attention_mask(&keep_mask)?;
    let pruned_length = keep_mask.iter().filter(|&&kept| kept).count();
    Ok(PruningResult {
        pruned_hidden_states,
        pruned_attention_mask,
        token_importance: TokenImportance {
            importance_scores,
            token_indices: (0..seq_len).collect(),
            keep_mask,
            pruning_reasons,
        },
        original_length: seq_len,
        pruned_length,
        compression_ratio: pruned_length as f32 / seq_len as f32,
    })
}
/// Prunes tokens the model is already confident about.
fn confidence_based_pruning(
    &self,
    hidden_states: &Tensor,
    config: &ConfidenceBasedPruningConfig,
) -> Result<PruningResult> {
    let seq_len = hidden_states.shape()[1];
    let confidence_scores = self.compute_confidence_scores(hidden_states, config)?;
    // High confidence means the token is "done", so importance is the
    // complement of confidence.
    let importance_scores: Vec<f32> =
        confidence_scores.iter().map(|&conf| 1.0 - conf).collect();
    let (keep_mask, pruning_reasons) = self.determine_tokens_to_keep(
        &importance_scores,
        config.min_tokens_ratio,
        1.0 - config.min_tokens_ratio,
        1,
    )?;
    let pruned_hidden_states = self.apply_pruning_mask(hidden_states, &keep_mask)?;
    let pruned_attention_mask = self.create_attention_mask(&keep_mask)?;
    let pruned_length = keep_mask.iter().filter(|&&kept| kept).count();
    Ok(PruningResult {
        pruned_hidden_states,
        pruned_attention_mask,
        token_importance: TokenImportance {
            importance_scores,
            token_indices: (0..seq_len).collect(),
            keep_mask,
            pruning_reasons,
        },
        original_length: seq_len,
        pruned_length,
        compression_ratio: pruned_length as f32 / seq_len as f32,
    })
}
/// Prunes tokens whose learned gate probability falls at or below 0.5.
fn learned_gate_pruning(
    &self,
    hidden_states: &Tensor,
    _config: &LearnedGatePruningConfig,
) -> Result<PruningResult> {
    let gate_network = self.gate_network.as_ref().ok_or_else(|| {
        TrustformersError::invalid_config(
            "LearnedGates strategy requires gate_network to be initialized".to_string(),
        )
    })?;
    let gate_probs = gate_network.forward(hidden_states)?;
    let seq_len = hidden_states.shape()[1];
    let importance_scores = self.extract_gate_scores(&gate_probs)?;
    // Tokens whose gate score clears the threshold survive.
    const GATE_THRESHOLD: f32 = 0.5;
    let keep_mask: Vec<bool> =
        importance_scores.iter().map(|&score| score > GATE_THRESHOLD).collect();
    let pruning_reasons = vec![PruningReason::LearnedGate; seq_len];
    let pruned_hidden_states = self.apply_pruning_mask(hidden_states, &keep_mask)?;
    let pruned_attention_mask = self.create_attention_mask(&keep_mask)?;
    let pruned_length = keep_mask.iter().filter(|&&kept| kept).count();
    Ok(PruningResult {
        pruned_hidden_states,
        pruned_attention_mask,
        token_importance: TokenImportance {
            importance_scores,
            token_indices: (0..seq_len).collect(),
            keep_mask,
            pruning_reasons,
        },
        original_length: seq_len,
        pruned_length,
        compression_ratio: pruned_length as f32 / seq_len as f32,
    })
}
/// Prunes tokens with a per-layer ratio: explicit when configured, otherwise an
/// exponential schedule `base * factor^layer`.
fn layer_adaptive_pruning(
    &self,
    hidden_states: &Tensor,
    layer_index: usize,
    config: &LayerAdaptivePruningConfig,
) -> Result<PruningResult> {
    let seq_len = hidden_states.shape()[1];
    // Per-layer ratio if configured, otherwise grow the base ratio with depth.
    let raw_ratio = config.layer_pruning_ratios.get(layer_index).copied().unwrap_or_else(|| {
        config.base_pruning_ratio * config.depth_adaptation_factor.powi(layer_index as i32)
    });
    // BUGFIX: the exponential depth schedule is unbounded, and a ratio > 1.0
    // previously underflowed `seq_len - max_tokens_to_prune` downstream.
    // Clamp to at most 0.9 so at least 10% of tokens always survive.
    let pruning_ratio = raw_ratio.clamp(0.0, 0.9);
    let importance_scores = self.compute_simple_importance(hidden_states)?;
    let min_tokens_ratio = 1.0 - pruning_ratio;
    let (keep_mask, pruning_reasons) =
        self.determine_tokens_to_keep(&importance_scores, min_tokens_ratio, pruning_ratio, 1)?;
    let pruned_hidden_states = self.apply_pruning_mask(hidden_states, &keep_mask)?;
    let pruned_attention_mask = self.create_attention_mask(&keep_mask)?;
    let pruned_length = keep_mask.iter().filter(|&&kept| kept).count();
    Ok(PruningResult {
        pruned_hidden_states,
        pruned_attention_mask,
        token_importance: TokenImportance {
            importance_scores,
            token_indices: (0..seq_len).collect(),
            keep_mask,
            pruning_reasons,
        },
        original_length: seq_len,
        pruned_length,
        compression_ratio: pruned_length as f32 / seq_len as f32,
    })
}
/// Prunes tokens with a ratio interpolated from `initial` to `final` across layers.
fn progressive_pruning(
    &self,
    hidden_states: &Tensor,
    layer_index: usize,
    total_layers: usize,
    config: &ProgressivePruningConfig,
) -> Result<PruningResult> {
    let seq_len = hidden_states.shape()[1];
    // BUGFIX: guard the single-layer case — the old
    // `layer_index / (total_layers - 1)` divided by zero and produced NaN.
    let progress = if total_layers > 1 {
        layer_index as f32 / (total_layers - 1) as f32
    } else {
        1.0
    };
    let scheduled_ratio = match &config.progression_schedule {
        ProgressionSchedule::Linear => {
            config.initial_pruning_ratio
                + (config.final_pruning_ratio - config.initial_pruning_ratio) * progress
        },
        ProgressionSchedule::Exponential => {
            // NOTE(review): assumes initial_pruning_ratio > 0 — a zero initial
            // ratio makes this NaN/inf; verify upstream validation.
            config.initial_pruning_ratio
                * (config.final_pruning_ratio / config.initial_pruning_ratio).powf(progress)
        },
        ProgressionSchedule::Cosine => {
            config.initial_pruning_ratio
                + (config.final_pruning_ratio - config.initial_pruning_ratio)
                    * (1.0 - (std::f32::consts::PI * progress).cos())
                    / 2.0
        },
        ProgressionSchedule::Custom(ratios) => {
            ratios.get(layer_index).copied().unwrap_or(config.final_pruning_ratio)
        },
    };
    // BUGFIX: clamp so misconfigured schedules (ratio > 1.0) cannot underflow
    // `seq_len - max_tokens_to_prune` downstream or prune everything.
    let pruning_ratio = scheduled_ratio.clamp(0.0, 0.9);
    let importance_scores = self.compute_simple_importance(hidden_states)?;
    let min_tokens_ratio = 1.0 - pruning_ratio;
    let (keep_mask, pruning_reasons) =
        self.determine_tokens_to_keep(&importance_scores, min_tokens_ratio, pruning_ratio, 1)?;
    let pruned_hidden_states = self.apply_pruning_mask(hidden_states, &keep_mask)?;
    let pruned_attention_mask = self.create_attention_mask(&keep_mask)?;
    let pruned_length = keep_mask.iter().filter(|&&kept| kept).count();
    Ok(PruningResult {
        pruned_hidden_states,
        pruned_attention_mask,
        token_importance: TokenImportance {
            importance_scores,
            token_indices: (0..seq_len).collect(),
            keep_mask,
            pruning_reasons,
        },
        original_length: seq_len,
        pruned_length,
        compression_ratio: pruned_length as f32 / seq_len as f32,
    })
}
/// Averages importance from several sub-strategies, then prunes with fixed
/// ratios (min 30% kept, max 70% pruned).
fn hybrid_pruning(
    &self,
    hidden_states: &Tensor,
    strategies: &[PruningStrategy],
    attention_scores: Option<&Tensor>,
    layer_index: Option<usize>,
    total_layers: Option<usize>,
) -> Result<PruningResult> {
    let seq_len = hidden_states.shape()[1];
    let mut combined_importance = vec![0.0_f32; seq_len];
    let mut valid_strategies = 0;
    for strategy in strategies {
        // Only sub-strategies with a matching config can contribute.
        let temp_pruner = match strategy {
            PruningStrategy::AttentionBased => match &self.attention_config {
                Some(config) => DynamicPruner::attention_based(config.clone()),
                None => continue,
            },
            PruningStrategy::ConfidenceBased => match &self.confidence_config {
                Some(config) => DynamicPruner::confidence_based(config.clone()),
                None => continue,
            },
            _ => continue,
        };
        if let Ok(result) =
            temp_pruner.prune_tokens(hidden_states, attention_scores, layer_index, total_layers)
        {
            let scores = &result.token_importance.importance_scores;
            for (slot, &score) in combined_importance.iter_mut().zip(scores.iter()) {
                *slot += score;
            }
            valid_strategies += 1;
        }
    }
    if valid_strategies > 0 {
        for score in &mut combined_importance {
            *score /= valid_strategies as f32;
        }
    }
    let (keep_mask, pruning_reasons) =
        self.determine_tokens_to_keep(&combined_importance, 0.3, 0.7, 1)?;
    let pruned_hidden_states = self.apply_pruning_mask(hidden_states, &keep_mask)?;
    let pruned_attention_mask = self.create_attention_mask(&keep_mask)?;
    let pruned_length = keep_mask.iter().filter(|&&kept| kept).count();
    Ok(PruningResult {
        pruned_hidden_states,
        pruned_attention_mask,
        token_importance: TokenImportance {
            importance_scores: combined_importance,
            token_indices: (0..seq_len).collect(),
            keep_mask,
            pruning_reasons,
        },
        original_length: seq_len,
        pruned_length,
        compression_ratio: pruned_length as f32 / seq_len as f32,
    })
}
/// Produces one importance score per token from attention weights.
///
/// NOTE(review): the collapsed attention matrix is computed (and its errors
/// propagated) but never read — the scores below are a positional heuristic,
/// not real attention mass. TODO: read actual attention values.
fn compute_attention_importance(
    &self,
    attention_weights: &Tensor,
    config: &AttentionBasedPruningConfig,
) -> Result<Vec<f32>> {
    let shape = attention_weights.shape();
    // Sequence-length axis differs between 3D and 4D attention tensors.
    let seq_len = if shape.len() == 3 { shape[1] } else { shape[2] };
    // Collapse heads for 4D input (kept for its error side effects).
    let _attention_matrix = if shape.len() == 4 {
        let head_sum = attention_weights.sum(Some(vec![1]), false)?;
        head_sum.scalar_div(attention_weights.shape()[1] as f32)?
    } else {
        attention_weights.clone()
    };
    let mut importance_scores: Vec<f32> = (0..seq_len)
        .map(|i| {
            // Heuristic: nearby tokens "attend" more strongly.
            let total_attention: f32 = (0..seq_len)
                .map(|j| {
                    let distance = (i as f32 - j as f32).abs();
                    (1.0 / (1.0 + distance * 0.1)).exp()
                })
                .sum();
            // First token (e.g. BOS/[CLS]) gets a fixed boost.
            if i == 0 { total_attention * 2.0 } else { total_attention }
        })
        .collect();
    if config.use_adaptive_threshold {
        // Standardize, then squash through a sigmoid into (0, 1).
        let mean: f32 = importance_scores.iter().sum::<f32>() / seq_len as f32;
        let variance: f32 =
            importance_scores.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / seq_len as f32;
        let std_dev = variance.sqrt();
        for score in &mut importance_scores {
            let z = (*score - mean) / (std_dev + 1e-8);
            *score = 1.0 / (1.0 + (-z).exp());
        }
    }
    Ok(importance_scores)
}
/// Produces one confidence score per token in [0, 1].
///
/// NOTE(review): the scores are simulated from token position, not derived from
/// `hidden_states` — placeholder until real logits are available.
fn compute_confidence_scores(
    &self,
    hidden_states: &Tensor,
    config: &ConfidenceBasedPruningConfig,
) -> Result<Vec<f32>> {
    let seq_len = hidden_states.shape()[1];
    let mut confidence_scores: Vec<f32> = (0..seq_len)
        .map(|i| {
            let relative_pos = i as f32 / seq_len as f32;
            let confidence = if config.use_entropy {
                // Simulated 3-class logits that sharpen toward the sequence end.
                let simulated_logits = [
                    0.8 + relative_pos * 0.15,
                    0.1 - relative_pos * 0.05,
                    0.1 - relative_pos * 0.05,
                ];
                let total: f32 = simulated_logits.iter().sum();
                let probs: Vec<f32> = simulated_logits.iter().map(|x| x / total).collect();
                let entropy: f32 =
                    probs.iter().map(|&p| if p > 1e-8 { -p * p.ln() } else { 0.0 }).sum();
                // Normalize by the maximum entropy of a 3-way distribution.
                let max_entropy = 3.0_f32.ln();
                1.0 - (entropy / max_entropy).min(1.0)
            } else {
                let norm_factor = (relative_pos * 0.3 + 0.6).min(0.95);
                let position_factor = (1.0 + (i as f32 * 0.1).sin()) / 2.0;
                norm_factor * 0.7 + position_factor * 0.3
            };
            confidence.clamp(0.0, 1.0)
        })
        .collect();
    // Optional smoothing with a centered window.
    if config.lookahead_window > 1 {
        let window = config.lookahead_window.min(seq_len);
        let smoothed: Vec<f32> = (0..seq_len)
            .map(|i| {
                let start = i.saturating_sub(window / 2);
                let end = (i + window / 2 + 1).min(seq_len);
                let window_avg =
                    confidence_scores[start..end].iter().sum::<f32>() / (end - start) as f32;
                (confidence_scores[i] + window_avg) / 2.0
            })
            .collect();
        confidence_scores = smoothed;
    }
    Ok(confidence_scores)
}
/// Cheap positional importance heuristic, normalized to a max of 1.0.
///
/// NOTE(review): values are synthesized from positions only — the tensor's
/// contents are never read. Placeholder for a real hidden-state norm.
fn compute_simple_importance(&self, hidden_states: &Tensor) -> Result<Vec<f32>> {
    let seq_len = hidden_states.shape()[1];
    let hidden_dim = hidden_states.shape()[2];
    // Only sample the first (up to) 100 dims to bound the work.
    let sampled_dims = hidden_dim.min(100);
    let mut importance_scores: Vec<f32> = (0..seq_len)
        .map(|i| {
            let norm_squared: f32 = (0..sampled_dims)
                .map(|j| {
                    let value =
                        (i as f32 / seq_len as f32) * (j as f32 / hidden_dim as f32).sin() + 0.1;
                    value * value
                })
                .sum();
            let norm = (norm_squared / sampled_dims as f32).sqrt();
            // Slightly favor the first and last few tokens.
            let position_bias = if i < 3 {
                1.2
            } else if i > seq_len.saturating_sub(3) {
                1.1
            } else {
                1.0
            };
            (norm * position_bias).clamp(0.0, 2.0)
        })
        .collect();
    // Normalize so the largest score is 1.0.
    let max_score = importance_scores.iter().cloned().fold(0.0, f32::max);
    if max_score > 0.0 {
        importance_scores.iter_mut().for_each(|score| *score /= max_score);
    }
    Ok(importance_scores)
}
/// Extracts one gate score per token.
///
/// NOTE(review): scores are a positional heuristic — `gate_probs` contents are
/// never read. TODO: pull actual probabilities out of the tensor.
fn extract_gate_scores(&self, gate_probs: &Tensor) -> Result<Vec<f32>> {
    let _batch_size = gate_probs.shape()[0];
    let seq_len = gate_probs.shape()[1];
    let scores = (0..seq_len)
        .map(|i| {
            if i == 0 {
                0.95
            } else if i < 5 {
                0.85
            } else if i > seq_len.saturating_sub(5) {
                0.75
            } else {
                let base_prob = 0.7;
                let variability = (i as f32 * 0.1).sin() * 0.2;
                (base_prob + variability).clamp(0.3, 0.9)
            }
        })
        .collect();
    Ok(scores)
}
/// Turns raw importance scores into a keep/prune decision per token.
///
/// Guarantees: the `keep_top_k` highest-scoring tokens always survive, at least
/// `min_tokens_ratio` of the sequence is kept (subject to `max_pruning_ratio`),
/// and inconsistent ratios can no longer panic.
fn determine_tokens_to_keep(
    &self,
    importance_scores: &[f32],
    min_tokens_ratio: f32,
    max_pruning_ratio: f32,
    keep_top_k: usize,
) -> Result<(Vec<bool>, Vec<PruningReason>)> {
    let seq_len = importance_scores.len();
    let min_tokens = (((seq_len as f32) * min_tokens_ratio).ceil() as usize).min(seq_len);
    let max_tokens_to_prune = ((seq_len as f32) * max_pruning_ratio).floor() as usize;
    // BUGFIX: saturating_sub — a max_pruning_ratio > 1.0 previously underflowed usize.
    let max_tokens_to_keep = seq_len.saturating_sub(max_tokens_to_prune);
    let keep_top_k = keep_top_k.min(seq_len);
    // Token indices sorted by descending importance (NaNs compare equal).
    let mut indexed_scores: Vec<(usize, f32)> =
        importance_scores.iter().copied().enumerate().collect();
    indexed_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    let mut keep_mask = vec![false; seq_len];
    let mut pruning_reasons = vec![PruningReason::LowAttention; seq_len];
    // The top-k most important tokens are unconditionally kept.
    for &(idx, _) in indexed_scores.iter().take(keep_top_k) {
        keep_mask[idx] = true;
        pruning_reasons[idx] = PruningReason::AlwaysKeep;
    }
    // BUGFIX: the old `min_tokens.clamp(keep_top_k, max_tokens_to_keep)` panicked
    // whenever keep_top_k > max_tokens_to_keep (clamp requires min <= max).
    let tokens_to_keep = min_tokens.max(keep_top_k).min(max_tokens_to_keep.max(keep_top_k));
    // Keep further tokens, best-first, until the minimum-ratio target is met.
    for &(idx, _) in indexed_scores.iter().take(tokens_to_keep.min(seq_len)).skip(keep_top_k) {
        keep_mask[idx] = true;
        pruning_reasons[idx] = PruningReason::MinimumRatio;
    }
    Ok((keep_mask, pruning_reasons))
}
/// Produces the hidden-state tensor for the surviving tokens.
///
/// Errors if the mask would remove every token.
///
/// NOTE(review): placeholder — returns a zero tensor of the pruned shape rather
/// than gathering the kept rows from `hidden_states`. (The previous version
/// also looped over batch/dim computing a value it immediately discarded; that
/// dead work is removed.) TODO: wire up a Tensor gather/index_select.
fn apply_pruning_mask(&self, hidden_states: &Tensor, keep_mask: &[bool]) -> Result<Tensor> {
    let batch_size = hidden_states.shape()[0];
    let hidden_dim = hidden_states.shape()[2];
    let new_seq_len = keep_mask.iter().filter(|&&keep| keep).count();
    if new_seq_len == 0 {
        return Err(TrustformersError::invalid_operation(
            "Cannot prune all tokens".to_string(),
        ));
    }
    Tensor::zeros(&[batch_size, new_seq_len, hidden_dim])
}
/// Builds an all-ones attention mask of shape [1, kept_token_count].
fn create_attention_mask(&self, keep_mask: &[bool]) -> Result<Tensor> {
    let new_seq_len = keep_mask.iter().filter(|&&keep| keep).count();
    Tensor::ones(&[1, new_seq_len])
}
}
/// Aggregated statistics over a sequence of per-layer pruning results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PruningStatistics {
/// Mean compression ratio across layers (1.0 when no results).
pub avg_compression_ratio: f32,
/// Compression ratio of each layer, in order.
pub layer_compression_ratios: Vec<f32>,
/// Estimated FLOP savings (assumes ~quadratic cost in sequence length).
pub computational_savings: f32,
/// Estimated memory savings (linear in sequence length).
pub memory_savings: f32,
/// How many tokens were tagged with each pruning reason.
pub pruning_reason_distribution: HashMap<PruningReason, usize>,
}
impl PruningStatistics {
/// Aggregates per-layer pruning results into summary statistics.
/// An empty slice yields an average compression ratio of 1.0 (no pruning).
pub fn from_results(results: &[PruningResult]) -> Self {
    let layer_compression_ratios: Vec<f32> =
        results.iter().map(|result| result.compression_ratio).collect();
    let total_compression: f32 = layer_compression_ratios.iter().sum();
    let mut pruning_reason_distribution = HashMap::new();
    for reason in results.iter().flat_map(|result| &result.token_importance.pruning_reasons) {
        *pruning_reason_distribution.entry(reason.clone()).or_insert(0) += 1;
    }
    let avg_compression_ratio = if results.is_empty() {
        1.0
    } else {
        total_compression / results.len() as f32
    };
    // Attention/FFN cost scales roughly quadratically with sequence length.
    let computational_savings = 1.0 - avg_compression_ratio.powi(2);
    let memory_savings = 1.0 - avg_compression_ratio;
    Self {
        avg_compression_ratio,
        layer_compression_ratios,
        computational_savings,
        memory_savings,
        pruning_reason_distribution,
    }
}
/// Dumps a human-readable summary of the collected statistics to stdout.
pub fn print_report(&self) {
    println!("=== Dynamic Token Pruning Report ===");
    println!("Average Compression Ratio: {:.3}", self.avg_compression_ratio);
    println!("Computational Savings: {:.1}%", self.computational_savings * 100.0);
    println!("Memory Savings: {:.1}%", self.memory_savings * 100.0);
    println!("\nLayer-wise Compression:");
    for (layer_idx, ratio) in self.layer_compression_ratios.iter().enumerate() {
        println!(" Layer {}: {:.3}", layer_idx, ratio);
    }
    println!("\nPruning Reason Distribution:");
    for (reason, count) in &self.pruning_reason_distribution {
        println!(" {:?}: {}", reason, count);
    }
}
/// Derives throughput/latency/memory/quality estimates from the statistics.
pub fn efficiency_metrics(&self) -> EfficiencyMetrics {
    EfficiencyMetrics {
        throughput_improvement: 1.0 / self.avg_compression_ratio,
        latency_reduction: self.computational_savings,
        memory_reduction: self.memory_savings,
        quality_preservation: self.estimate_quality_preservation(),
    }
}
/// Heuristic: quality degrades roughly with the square root of compression;
/// attention-guided pruning is assumed to preserve slightly more.
fn estimate_quality_preservation(&self) -> f32 {
    let base_preservation = self.avg_compression_ratio.powf(0.5);
    let attention_guided =
        self.pruning_reason_distribution.contains_key(&PruningReason::LowAttention);
    let strategy_bonus = if attention_guided { 0.1 } else { 0.0 };
    (base_preservation + strategy_bonus).min(1.0)
}
}
/// Derived efficiency estimates; all values are ratios/fractions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EfficiencyMetrics {
/// Estimated speedup factor (1 / avg compression ratio).
pub throughput_improvement: f32,
/// Fraction of compute saved.
pub latency_reduction: f32,
/// Fraction of memory saved.
pub memory_reduction: f32,
/// Heuristic estimate of retained model quality in [0, 1].
pub quality_preservation: f32,
}
/// Settings for layer-wise early exit.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EarlyExitConfig {
/// Exit immediately when prediction confidence reaches this value.
pub confidence_threshold: f32,
/// No exits are considered before this layer index.
pub min_exit_layer: usize,
/// Number of exit classifier heads to allocate.
pub max_exit_points: usize,
/// If true, exit after confidence stays near-threshold for `patience_window` layers.
pub use_patience: bool,
/// Number of near-threshold layers required for a patience exit.
pub patience_window: usize,
/// Exit when prediction entropy drops to or below this value.
pub entropy_threshold: f32,
}
impl Default for EarlyExitConfig {
fn default() -> Self {
Self {
confidence_threshold: 0.9,
min_exit_layer: 6, max_exit_points: 4,
use_patience: true,
patience_window: 3,
entropy_threshold: 0.1,
}
}
}
/// Exit decision for one sample at one layer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EarlyExitPoint {
/// Layer at which the decision was made.
pub layer_index: usize,
/// Prediction confidence at this layer (0.0 before `min_exit_layer`).
pub confidence: f32,
/// Prediction entropy at this layer (infinity before `min_exit_layer`).
pub entropy: f32,
/// Whether the sample should stop here.
pub should_exit: bool,
/// Which rule triggered (or didn't trigger) the exit.
pub exit_reason: ExitReason,
}
/// Rule that produced an early-exit decision.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ExitReason {
/// Confidence reached the configured threshold.
HighConfidence,
/// Entropy dropped to or below the configured threshold.
LowEntropy,
/// Confidence stayed near-threshold long enough.
Patience,
/// Exit forced externally.
ForcedExit,
/// No exit at this layer.
NoExit,
}
/// Decides, per layer and per sample, whether inference can stop early.
pub struct EarlyExitController {
config: EarlyExitConfig,
/// One lightweight classifier head per potential exit point.
exit_classifiers: Vec<Linear>,
/// Per-sample count of consecutive near-threshold layers (keyed by batch index).
patience_counters: HashMap<usize, usize>, }
/// Minimal dense layer: `y = x · W (+ b)`.
#[derive(Debug, Clone)]
pub struct Linear {
    pub weight: Tensor,
    pub bias: Option<Tensor>,
}
impl Linear {
    /// Random-normal weights; zero bias when `use_bias` is set.
    pub fn new(input_dim: usize, output_dim: usize, use_bias: bool) -> Result<Self> {
        let weight = Tensor::randn(&[input_dim, output_dim])?;
        let bias = match use_bias {
            true => Some(Tensor::zeros(&[output_dim])?),
            false => None,
        };
        Ok(Self { weight, bias })
    }
    /// Applies the projection, adding the bias when present.
    pub fn forward(&self, input: &Tensor) -> Result<Tensor> {
        let projected = input.matmul(&self.weight)?;
        match &self.bias {
            Some(bias) => projected.add(bias),
            None => Ok(projected),
        }
    }
}
impl EarlyExitController {
/// Allocates one biased classifier head per configured exit point.
pub fn new(config: EarlyExitConfig, hidden_dim: usize, num_classes: usize) -> Result<Self> {
    let exit_classifiers = (0..config.max_exit_points)
        .map(|_| Linear::new(hidden_dim, num_classes, true))
        .collect::<Result<Vec<_>>>()?;
    Ok(Self {
        config,
        exit_classifiers,
        patience_counters: HashMap::new(),
    })
}
/// Evaluates the exit decision for every sample in the batch at `layer_index`.
///
/// Mutates `patience_counters`: a near-threshold (>= 80% of the confidence
/// threshold) non-exit bumps the sample's counter; an exit clears it.
/// `batch_indices` maps batch positions to stable sample ids for that
/// bookkeeping (falls back to the position itself when absent).
pub fn should_exit(
&mut self,
hidden_states: &Tensor,
layer_index: usize,
batch_indices: &[usize],
) -> Result<Vec<EarlyExitPoint>> {
let batch_size = hidden_states.shape()[0];
let mut exit_points = Vec::new();
// Too early in the stack: report "no exit" for every sample.
if layer_index < self.config.min_exit_layer {
return Ok(vec![
EarlyExitPoint {
layer_index,
confidence: 0.0,
entropy: f32::INFINITY,
should_exit: false,
exit_reason: ExitReason::NoExit,
};
batch_size
]);
}
// Layers past the last head reuse the final classifier.
let classifier_idx =
(layer_index - self.config.min_exit_layer).min(self.exit_classifiers.len() - 1);
let logits = self.exit_classifiers[classifier_idx].forward(hidden_states)?;
for i in 0..batch_size {
let (confidence, entropy) = self.compute_confidence_entropy(&logits, i)?;
let batch_idx = batch_indices.get(i).copied().unwrap_or(i);
let patience_count = self.patience_counters.get(&batch_idx).copied().unwrap_or(0);
let (should_exit, exit_reason) =
self.make_exit_decision(confidence, entropy, patience_count, layer_index);
if should_exit {
// Sample leaves the stack; its patience state is no longer needed.
self.patience_counters.remove(&batch_idx);
} else if confidence > self.config.confidence_threshold * 0.8 {
// Near-threshold: accumulate patience toward a later Patience exit.
self.patience_counters.insert(batch_idx, patience_count + 1);
}
exit_points.push(EarlyExitPoint {
layer_index,
confidence,
entropy,
should_exit,
exit_reason,
});
}
Ok(exit_points)
}
/// Returns (max probability, Shannon entropy) for one sample.
///
/// NOTE(review): uses a fixed simulated distribution — `_logits` and
/// `_batch_idx` are ignored. Placeholder until a real softmax readout exists.
fn compute_confidence_entropy(
    &self,
    _logits: &Tensor,
    _batch_idx: usize,
) -> Result<(f32, f32)> {
    let simulated_probs = [0.7_f32, 0.2, 0.1];
    let mut confidence = 0.0_f32;
    let mut entropy = 0.0_f32;
    for &p in &simulated_probs {
        confidence = confidence.max(p);
        if p > 1e-8 {
            entropy -= p * p.ln();
        }
    }
    Ok((confidence, entropy))
}
/// Applies the exit rules in priority order: confidence, entropy, patience.
fn make_exit_decision(
    &self,
    confidence: f32,
    entropy: f32,
    patience_count: usize,
    _layer_index: usize,
) -> (bool, ExitReason) {
    let high_confidence = confidence >= self.config.confidence_threshold;
    let low_entropy = entropy <= self.config.entropy_threshold;
    let patience_exhausted =
        self.config.use_patience && patience_count >= self.config.patience_window;
    if high_confidence {
        (true, ExitReason::HighConfidence)
    } else if low_entropy {
        (true, ExitReason::LowEntropy)
    } else if patience_exhausted {
        (true, ExitReason::Patience)
    } else {
        (false, ExitReason::NoExit)
    }
}
/// Clears all per-sample patience counters (call between batches/sequences).
pub fn reset_patience(&mut self) {
self.patience_counters.clear();
}
pub fn get_exit_statistics(&self, exit_history: &[Vec<EarlyExitPoint>]) -> EarlyExitStatistics {
let mut total_exits = 0;
let mut layer_exit_counts = HashMap::new();
let mut reason_counts = HashMap::new();
let mut total_samples = 0;
for layer_exits in exit_history {
for exit_point in layer_exits {
total_samples += 1;
if exit_point.should_exit {
total_exits += 1;
*layer_exit_counts.entry(exit_point.layer_index).or_insert(0) += 1;
*reason_counts.entry(exit_point.exit_reason.clone()).or_insert(0) += 1;
}
}
}
let exit_rate =
if total_samples > 0 { total_exits as f32 / total_samples as f32 } else { 0.0 };
let avg_exit_layer = if total_exits > 0 {
layer_exit_counts
.iter()
.map(|(&layer, &count)| layer as f32 * count as f32)
.sum::<f32>()
/ total_exits as f32
} else {
0.0
};
EarlyExitStatistics {
exit_rate,
avg_exit_layer,
layer_exit_counts,
reason_counts,
computational_savings: self.estimate_computational_savings(exit_rate, avg_exit_layer),
}
}
/// Estimates the fraction of layer compute saved by early exits.
///
/// NOTE(review): the total layer count is hard-coded to 12 — it should come
/// from the model configuration.
fn estimate_computational_savings(&self, exit_rate: f32, avg_exit_layer: f32) -> f32 {
    let total_layers = 12.0_f32;
    // BUGFIX: clamp at zero so an average exit layer beyond `total_layers`
    // cannot report negative savings.
    let layers_saved = (total_layers - avg_exit_layer).max(0.0);
    let savings_per_exit = layers_saved / total_layers;
    exit_rate * savings_per_exit
}
}
/// Summary of early-exit behavior over a batch of forward passes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EarlyExitStatistics {
/// Fraction of (sample, layer) decisions that were exits.
pub exit_rate: f32,
/// Count-weighted mean layer index at which exits occurred.
pub avg_exit_layer: f32,
/// Number of exits per layer index.
pub layer_exit_counts: HashMap<usize, usize>,
/// Number of exits per triggering rule.
pub reason_counts: HashMap<ExitReason, usize>,
/// Estimated fraction of layer compute saved.
pub computational_savings: f32,
}
/// Combines token pruning and early exit into a single per-layer controller.
pub struct AdaptiveComputationController {
/// Token-level pruning policy.
pruner: DynamicPruner,
/// Sample-level early-exit policy.
early_exit: EarlyExitController,
adaptive_config: AdaptiveComputationConfig,
}
/// Settings for the combined pruning + early-exit controller.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptiveComputationConfig {
/// Enable both pruning and early exit together.
/// NOTE(review): not read by the visible code — confirm intended use.
pub use_both_strategies: bool,
/// If true, check for early exit before running pruning each layer.
pub prioritize_early_exit: bool,
/// Target fraction of full compute.
/// NOTE(review): not read by the visible code — confirm intended use.
pub computation_budget: f32,
/// Minimum acceptable quality.
/// NOTE(review): not read by the visible code — confirm intended use.
pub quality_threshold: f32,
}
impl Default for AdaptiveComputationConfig {
fn default() -> Self {
Self {
use_both_strategies: true,
prioritize_early_exit: false,
computation_budget: 0.5,
quality_threshold: 0.9,
}
}
}
impl AdaptiveComputationController {
    /// Builds a controller that combines a token pruner with an early-exit
    /// controller under the given adaptive-computation configuration.
    pub fn new(
        pruner: DynamicPruner,
        early_exit: EarlyExitController,
        config: AdaptiveComputationConfig,
    ) -> Self {
        Self {
            pruner,
            early_exit,
            adaptive_config: config,
        }
    }

    /// Runs one layer's adaptive-computation step: optionally checks early
    /// exit first (so an exiting batch skips pruning entirely), prunes
    /// tokens when no sample has exited, then evaluates exit on the
    /// (possibly pruned) hidden states if it was not already evaluated.
    ///
    /// # Errors
    /// Propagates any failure from the pruner or the early-exit controller.
    /// (Previously errors from the post-pruning exit check were silently
    /// discarded via `.ok()`, masking failures as "no exit".)
    pub fn adaptive_forward(
        &mut self,
        hidden_states: &Tensor,
        attention_scores: Option<&Tensor>,
        layer_index: usize,
        total_layers: usize,
        batch_indices: &[usize],
    ) -> Result<AdaptiveComputationResult> {
        // Check exits up front only when configured to prioritize them.
        let exit_points = if self.adaptive_config.prioritize_early_exit {
            Some(self.early_exit.should_exit(hidden_states, layer_index, batch_indices)?)
        } else {
            None
        };
        let already_exiting = exit_points
            .as_ref()
            .map(|eps| eps.iter().any(|ep| ep.should_exit))
            .unwrap_or(false);
        // Skip pruning when an exit has already been requested.
        let pruning_result = if already_exiting {
            None
        } else {
            Some(self.pruner.prune_tokens(
                hidden_states,
                attention_scores,
                Some(layer_index),
                Some(total_layers),
            )?)
        };
        // If exit was not evaluated up front, evaluate it now on the pruned
        // states (falling back to the originals when pruning was skipped).
        let exit_points = match exit_points {
            Some(eps) => eps,
            None => {
                let states = pruning_result
                    .as_ref()
                    .map(|p| &p.pruned_hidden_states)
                    .unwrap_or(hidden_states);
                self.early_exit.should_exit(states, layer_index, batch_indices)?
            }
        };
        // Computed before moving `exit_points` into the result, which also
        // removes the clone the previous implementation needed.
        let should_continue = !exit_points.iter().any(|ep| ep.should_exit);
        Ok(AdaptiveComputationResult {
            pruning_result,
            exit_points,
            should_continue,
            computation_used: self.estimate_computation_used(layer_index, total_layers),
        })
    }

    /// Fraction of the full forward pass executed after finishing the
    /// 0-based `layer_index` out of `total_layers`.
    fn estimate_computation_used(&self, layer_index: usize, total_layers: usize) -> f32 {
        (layer_index + 1) as f32 / total_layers as f32
    }
}
/// Outcome of one `adaptive_forward` step.
#[derive(Debug)]
pub struct AdaptiveComputationResult {
    /// Pruning output for this layer; `None` when pruning was skipped
    /// because an early exit had already been requested.
    pub pruning_result: Option<PruningResult>,
    /// Per-sample early-exit decisions for this layer.
    pub exit_points: Vec<EarlyExitPoint>,
    /// False when any exit point requested an exit.
    pub should_continue: bool,
    /// Fraction of the full forward pass executed so far.
    pub computation_used: f32,
}
// Unit tests for pruning configuration defaults, pruner construction,
// pruning-schedule math, and pruning/early-exit statistics aggregation.
#[cfg(test)]
mod tests {
use super::*;
// Defaults of the attention-based config must match the documented values.
#[test]
fn test_attention_based_pruning_config() {
let config = AttentionBasedPruningConfig::default();
assert_eq!(config.attention_threshold, 0.1);
assert_eq!(config.min_tokens_ratio, 0.3);
assert_eq!(config.max_pruning_ratio, 0.7);
assert!(config.use_adaptive_threshold);
}
// The attention-based constructor must set the matching strategy variant.
#[test]
fn test_dynamic_pruner_creation() {
let config = AttentionBasedPruningConfig::default();
let pruner = DynamicPruner::attention_based(config);
assert!(
matches!(pruner.strategy, PruningStrategy::AttentionBased),
"Expected AttentionBased strategy"
);
}
// Gate network shapes follow (input_dim, gate_hidden_dim) from the config.
#[test]
fn test_learned_gate_network_creation() -> Result<()> {
let config = LearnedGatePruningConfig::default();
let gate_network = LearnedGateNetwork::new(768, config)?;
assert_eq!(gate_network.gate_linear.shape(), vec![768, 64]);
assert_eq!(gate_network.gate_bias.shape(), vec![64]);
Ok(())
}
// Linear schedule must stay within [initial, final] across all layers.
// NOTE(review): `_config` is constructed but the expected ratio re-hardcodes
// its values (0.1/0.5) — consider deriving them from the config fields, as
// test_progressive_pruning_linear_schedule_bounds below does.
#[test]
fn test_progressive_pruning_ratios() {
let _config = ProgressivePruningConfig {
initial_pruning_ratio: 0.1,
final_pruning_ratio: 0.5,
progression_schedule: ProgressionSchedule::Linear,
};
let total_layers = 12;
for layer in 0..total_layers {
let progress = layer as f32 / (total_layers - 1) as f32;
let expected_ratio = 0.1 + (0.5 - 0.1) * progress;
assert!((0.1..=0.5).contains(&expected_ratio));
}
}
// Aggregating a single result: averages equal that result's values.
// NOTE(review): computational_savings of 0.75 for a 0.5 compression ratio
// presumably reflects a quadratic (attention) cost model, 1 - r^2 — confirm
// against PruningStatistics::from_results.
#[test]
fn test_pruning_statistics() {
let results = vec![PruningResult {
pruned_hidden_states: Tensor::zeros(&[1, 5, 768]).expect("operation failed"),
pruned_attention_mask: Tensor::ones(&[1, 5]).expect("operation failed"),
token_importance: TokenImportance {
importance_scores: vec![0.9, 0.8, 0.3, 0.2, 0.1],
token_indices: vec![0, 1, 2, 3, 4],
keep_mask: vec![true, true, true, false, false],
pruning_reasons: vec![
PruningReason::AlwaysKeep,
PruningReason::MinimumRatio,
PruningReason::MinimumRatio,
PruningReason::LowAttention,
PruningReason::LowAttention,
],
},
original_length: 10,
pruned_length: 5,
compression_ratio: 0.5,
}];
let stats = PruningStatistics::from_results(&results);
assert_eq!(stats.avg_compression_ratio, 0.5);
assert_eq!(stats.layer_compression_ratios, vec![0.5]);
assert_eq!(stats.computational_savings, 0.75); assert_eq!(stats.memory_savings, 0.5); }
// Schedule endpoints: sparsity is 0 at t=0 and reaches the target at t=T.
// NOTE(review): despite the "cosine" name, the formula under test is cubic,
// target * (1 - (1 - t/T)^3) — confirm the production schedule's actual
// shape and rename these tests if it is indeed the cubic schedule.
#[test]
fn test_cosine_schedule_endpoints() {
let target = 0.7f32;
let total = 100usize;
let sparsity_at_0 = target * (1.0 - (1.0 - 0.0f32 / total as f32).powi(3));
let sparsity_at_t = target * (1.0 - (1.0 - total as f32 / total as f32).powi(3));
assert!(sparsity_at_0.abs() < 1e-6, "sparsity at t=0 must be 0");
assert!(
(sparsity_at_t - target).abs() < 1e-6,
"sparsity at t=T must equal target"
);
}
// The schedule must be non-decreasing over the whole range (same cubic
// formula as the endpoints test above).
#[test]
fn test_cosine_schedule_monotone() {
let target = 0.8f32;
let total = 50usize;
let mut prev = -1.0f32;
for t in 0..=total {
let s = target * (1.0 - (1.0 - t as f32 / total as f32).powi(3));
assert!(s >= prev - 1e-6, "Schedule must be monotone at t={}", t);
prev = s;
}
}
// Sorting by squared-norm magnitude descending ranks tokens correctly.
#[test]
fn test_magnitude_scoring_order() {
let norms_squared = [0.1f32, 0.9, 0.3, 0.7, 0.5];
let mut indexed: Vec<(usize, f32)> =
norms_squared.iter().enumerate().map(|(i, &v)| (i, v)).collect();
indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).expect("comparison must succeed"));
assert_eq!(indexed[0].0, 1, "Highest magnitude token must rank first");
assert_eq!(
indexed[indexed.len() - 1].0,
0,
"Lowest magnitude token must rank last"
);
}
// A valid keep mask keeps at least one token but not all of them.
#[test]
fn test_unstructured_any_keep_mask_valid() {
let keep_masks = vec![
vec![true, false, true, false, true],
vec![true, true, true, false, false],
vec![false, false, false, true, true],
];
for mask in &keep_masks {
let kept = mask.iter().filter(|&&x| x).count();
assert!(
kept > 0,
"At least one token must be kept; mask: {:?}",
mask
);
assert!(
kept < mask.len(),
"Not all tokens should be kept; mask: {:?}",
mask
);
}
}
// The confidence-based constructor must set the matching strategy variant.
#[test]
fn test_confidence_based_pruner_creation() {
let config = ConfidenceBasedPruningConfig::default();
let pruner = DynamicPruner::confidence_based(config);
assert!(
matches!(pruner.strategy, PruningStrategy::ConfidenceBased),
"Expected ConfidenceBased strategy"
);
}
// The progressive constructor must set the matching strategy variant.
#[test]
fn test_progressive_pruner_creation() {
let config = ProgressivePruningConfig::default();
let pruner = DynamicPruner::progressive(config);
assert!(
matches!(pruner.strategy, PruningStrategy::Progressive),
"Expected Progressive strategy"
);
}
// The layer-adaptive constructor must set the matching strategy variant.
#[test]
fn test_layer_adaptive_pruner_creation() {
let config = LayerAdaptivePruningConfig::default();
let pruner = DynamicPruner::layer_adaptive(config);
assert!(
matches!(pruner.strategy, PruningStrategy::LayerAdaptive),
"Expected LayerAdaptive strategy"
);
}
// compression_ratio = pruned/original must always fall in (0, 1].
#[test]
fn test_compression_ratio_in_valid_range() {
let original_length = 20usize;
for pruned_length in 1..=original_length {
let ratio = pruned_length as f32 / original_length as f32;
assert!(
ratio > 0.0 && ratio <= 1.0,
"compression_ratio must be in (0,1]"
);
}
}
// sparsity = 1 - compression_ratio must fall in [0, 1).
#[test]
fn test_sparsity_in_valid_range() {
let original = 10usize;
for pruned in 1..=original {
let compression_ratio = pruned as f32 / original as f32;
let sparsity = 1.0 - compression_ratio;
assert!((0.0..1.0).contains(&sparsity), "sparsity must be in [0, 1)");
}
}
// Empty input: the neutral average ratio is 1.0 (no compression).
#[test]
fn test_pruning_statistics_empty_results() {
let stats = PruningStatistics::from_results(&[]);
assert_eq!(
stats.avg_compression_ratio, 1.0,
"Empty results must give ratio=1.0"
);
assert!(stats.layer_compression_ratios.is_empty());
}
// Multiple layers: average is the mean of per-layer compression ratios.
#[test]
fn test_pruning_statistics_multiple_layers() {
let make_result = |ratio: f32, seq_len: usize| PruningResult {
pruned_hidden_states: Tensor::zeros(&[1, seq_len, 64])
.expect("tensor creation must succeed"),
pruned_attention_mask: Tensor::ones(&[1, seq_len])
.expect("tensor creation must succeed"),
token_importance: TokenImportance {
importance_scores: vec![0.5; seq_len],
token_indices: (0..seq_len).collect(),
keep_mask: vec![true; seq_len],
pruning_reasons: vec![PruningReason::MinimumRatio; seq_len],
},
original_length: (seq_len as f32 / ratio) as usize,
pruned_length: seq_len,
compression_ratio: ratio,
};
let results = vec![make_result(0.6, 6), make_result(0.4, 4)];
let stats = PruningStatistics::from_results(&results);
let expected_avg = (0.6 + 0.4) / 2.0;
assert!((stats.avg_compression_ratio - expected_avg).abs() < 1e-6);
assert_eq!(stats.layer_compression_ratios.len(), 2);
}
// Per-reason counts must be tracked in the statistics' distribution map.
#[test]
fn test_pruning_reason_distribution_tracked() {
let result = PruningResult {
pruned_hidden_states: Tensor::zeros(&[1, 3, 32]).expect("must succeed"),
pruned_attention_mask: Tensor::ones(&[1, 3]).expect("must succeed"),
token_importance: TokenImportance {
importance_scores: vec![0.8, 0.5, 0.2],
token_indices: vec![0, 1, 2],
keep_mask: vec![true, true, false],
pruning_reasons: vec![
PruningReason::AlwaysKeep,
PruningReason::MinimumRatio,
PruningReason::LowAttention,
],
},
original_length: 4,
pruned_length: 3,
compression_ratio: 0.75,
};
let stats = PruningStatistics::from_results(&[result]);
assert_eq!(
stats.pruning_reason_distribution.get(&PruningReason::AlwaysKeep),
Some(&1)
);
assert_eq!(
stats.pruning_reason_distribution.get(&PruningReason::LowAttention),
Some(&1)
);
}
// Same linear-schedule property as test_progressive_pruning_ratios, but
// correctly derived from the config fields rather than hard-coded values.
#[test]
fn test_progressive_pruning_linear_schedule_bounds() {
let config = ProgressivePruningConfig {
initial_pruning_ratio: 0.1,
final_pruning_ratio: 0.5,
progression_schedule: ProgressionSchedule::Linear,
};
let total_layers = 12usize;
for layer in 0..total_layers {
let progress = layer as f32 / (total_layers - 1) as f32;
let ratio = config.initial_pruning_ratio
+ (config.final_pruning_ratio - config.initial_pruning_ratio) * progress;
assert!(
ratio >= config.initial_pruning_ratio - 1e-6
&& ratio <= config.final_pruning_ratio + 1e-6,
"Linear schedule ratio {} must be in [{}, {}] for layer {}",
ratio,
config.initial_pruning_ratio,
config.final_pruning_ratio,
layer
);
}
}
}