use crate::dtype::DType;
use crate::shape::Shape;
#[cfg(not(feature = "std"))]
use alloc::{vec, vec::Vec};
/// Strategy used to decide which weights of a tensor get pruned.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PruningStrategy {
    /// Unstructured pruning: drop weights whose magnitude falls below the
    /// given percentile of all magnitudes.
    Magnitude { threshold_percentile: u8 },
    /// Structured pruning of rectangular blocks of the given (rows, cols) size.
    BlockWise { block_size: (usize, usize) },
    /// Structured pruning of whole channels.
    ChannelWise { channels_to_prune: usize },
    /// Structured pruning of whole attention heads.
    AttentionHead { heads_to_prune: usize },
    /// Movement pruning driven by a sensitivity score.
    Movement { sensitivity: f32 },
    /// Magnitude pruning with sparsity ramped from `initial_sparsity` to
    /// `final_sparsity` over the course of training.
    GradualMagnitude {
        initial_sparsity: f32,
        final_sparsity: f32,
    },
}
impl PruningStrategy {
    /// Fraction of weights (0.0..=1.0) this strategy is expected to remove.
    ///
    /// Strategies whose sparsity depends on runtime data report a neutral
    /// 0.5 placeholder estimate.
    pub fn expected_sparsity(&self) -> f32 {
        // Exhaustive match (no `_` arm) so adding a variant forces an
        // explicit decision about its expected sparsity.
        match self {
            Self::Magnitude {
                threshold_percentile,
            } => *threshold_percentile as f32 / 100.0,
            Self::GradualMagnitude { final_sparsity, .. } => *final_sparsity,
            // Data-dependent strategies: sparsity unknown up front.
            Self::BlockWise { .. }
            | Self::ChannelWise { .. }
            | Self::AttentionHead { .. }
            | Self::Movement { .. } => 0.5,
        }
    }
    /// Returns `true` for strategies that remove whole structures
    /// (blocks, channels, heads) rather than individual weights.
    pub fn is_structured(&self) -> bool {
        matches!(
            self,
            Self::BlockWise { .. } | Self::ChannelWise { .. } | Self::AttentionHead { .. }
        )
    }
}
/// Record of a completed pruning pass over a single tensor: which
/// elements/blocks/channels were removed and the resulting sparsity and
/// compression characteristics.
#[derive(Debug, Clone)]
pub struct PruningMetadata {
    strategy: PruningStrategy,
    // Kept sorted ascending (enforced by `with_indices`) so that
    // `is_element_pruned` can binary-search.
    pruned_indices: Option<Vec<usize>>,
    pruned_blocks: Option<Vec<(usize, usize)>>,
    pruned_channels: Option<Vec<usize>>,
    achieved_sparsity: f32,
    original_shape: Shape,
    threshold_value: Option<f32>,
    compression_ratio: f32,
}
impl PruningMetadata {
    /// Creates metadata for a pruning pass that achieved `achieved_sparsity`
    /// (fraction of elements removed, 0.0..=1.0) on `original_shape`.
    ///
    /// NOTE(review): `achieved_sparsity == 1.0` makes the stored compression
    /// ratio infinite (f32 division by zero); presumably a fully-pruned
    /// tensor never occurs — confirm with callers.
    pub fn new(strategy: PruningStrategy, original_shape: Shape, achieved_sparsity: f32) -> Self {
        // Dense-to-sparse size ratio: pruning half the weights doubles the
        // effective compression.
        let compression_ratio = 1.0 / (1.0 - achieved_sparsity);
        Self {
            strategy,
            pruned_indices: None,
            pruned_blocks: None,
            pruned_channels: None,
            achieved_sparsity,
            original_shape,
            threshold_value: None,
            compression_ratio,
        }
    }
    /// Attaches the flat indices of individually pruned elements.
    ///
    /// The indices are sorted ascending here so that `is_element_pruned`'s
    /// binary search is valid regardless of the caller's ordering.
    pub fn with_indices(mut self, mut indices: Vec<usize>) -> Self {
        indices.sort_unstable();
        self.pruned_indices = Some(indices);
        self
    }
    /// Attaches the coordinates of pruned blocks.
    pub fn with_blocks(mut self, blocks: Vec<(usize, usize)>) -> Self {
        self.pruned_blocks = Some(blocks);
        self
    }
    /// Attaches the indices of pruned channels.
    pub fn with_channels(mut self, channels: Vec<usize>) -> Self {
        self.pruned_channels = Some(channels);
        self
    }
    /// Attaches the magnitude threshold that was applied.
    pub fn with_threshold(mut self, threshold: f32) -> Self {
        self.threshold_value = Some(threshold);
        self
    }
    /// Strategy that produced this metadata.
    pub fn strategy(&self) -> PruningStrategy {
        self.strategy
    }
    /// Achieved sparsity (fraction of elements removed).
    pub fn sparsity(&self) -> f32 {
        self.achieved_sparsity
    }
    /// Dense-size / sparse-size ratio implied by the achieved sparsity.
    pub fn compression_ratio(&self) -> f32 {
        self.compression_ratio
    }
    /// Number of pruned entries recorded.
    ///
    /// Note: for block- or channel-wise pruning this is the number of
    /// recorded blocks/channels, not the number of underlying elements.
    pub fn num_pruned_elements(&self) -> usize {
        if let Some(ref indices) = self.pruned_indices {
            indices.len()
        } else if let Some(ref blocks) = self.pruned_blocks {
            blocks.len()
        } else if let Some(ref channels) = self.pruned_channels {
            channels.len()
        } else {
            0
        }
    }
    /// Whether the element at flat `index` was pruned.
    ///
    /// Only meaningful when indices were recorded via `with_indices`;
    /// block/channel-wise pruning always reports `false` here.
    pub fn is_element_pruned(&self, index: usize) -> bool {
        if let Some(ref indices) = self.pruned_indices {
            // Valid because `with_indices` sorts the list.
            indices.binary_search(&index).is_ok()
        } else {
            false
        }
    }
    /// Estimated bytes saved when storing the tensor at `dtype`, derived
    /// from the achieved sparsity (not from the recorded index lists).
    pub fn memory_savings(&self, dtype: DType) -> usize {
        let total_elements = self.original_shape.numel();
        let pruned_elements = (total_elements as f32 * self.achieved_sparsity) as usize;
        pruned_elements * dtype.size()
    }
}
/// Index-list compression schemes supported by the pruning pipeline.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CompressionEncoding {
    /// No compression; indices stored verbatim.
    Raw,
    /// Runs of consecutive indices stored as (start, length) pairs.
    RunLength,
    /// Indices stored as a base value plus successive differences.
    Delta,
    /// Entropy coding of the index stream.
    Huffman,
    /// Dense bit-per-position mask over an index range.
    Bitmap,
    /// Combination of the above, chosen per region.
    Hybrid,
}
impl CompressionEncoding {
    /// Rule-of-thumb compression ratio this encoding typically reaches.
    pub fn expected_compression_ratio(&self) -> f32 {
        match self {
            Self::Raw => 1.0,
            Self::Delta => 1.5,
            Self::RunLength => 2.0,
            Self::Huffman => 2.5,
            Self::Bitmap => 3.0,
            Self::Hybrid => 3.5,
        }
    }
    /// Whether this encoding only works on index lists sorted ascending.
    pub fn requires_sorted_indices(&self) -> bool {
        match self {
            Self::RunLength | Self::Delta | Self::Hybrid => true,
            Self::Raw | Self::Huffman | Self::Bitmap => false,
        }
    }
}
/// Run-length encoding of a pruned-index list: maximal runs of consecutive
/// indices stored as parallel start/length arrays.
#[derive(Debug, Clone)]
pub struct RunLengthEncoded {
    start_indices: Vec<usize>,
    run_lengths: Vec<usize>,
    total_elements: usize,
}
impl RunLengthEncoded {
    /// Encodes `indices` into runs of consecutive values.
    ///
    /// Intended for ascending index lists; any input decodes back to the
    /// identical sequence, but only ascending consecutive runs compress.
    pub fn encode(indices: &[usize]) -> Self {
        if indices.is_empty() {
            return Self {
                start_indices: vec![],
                run_lengths: vec![],
                total_elements: 0,
            };
        }
        let mut start_indices = Vec::new();
        let mut run_lengths = Vec::new();
        let mut current_start = indices[0];
        let mut current_length = 1;
        for i in 1..indices.len() {
            if indices[i] == indices[i - 1] + 1 {
                // Extends the current consecutive run.
                current_length += 1;
            } else {
                // Gap: close the current run and start a new one.
                start_indices.push(current_start);
                run_lengths.push(current_length);
                current_start = indices[i];
                current_length = 1;
            }
        }
        // Flush the final (always non-empty) run.
        start_indices.push(current_start);
        run_lengths.push(current_length);
        Self {
            start_indices,
            run_lengths,
            total_elements: indices.len(),
        }
    }
    /// Expands the runs back into the original index list.
    pub fn decode(&self) -> Vec<usize> {
        let mut indices = Vec::with_capacity(self.total_elements);
        for (&start, &length) in self.start_indices.iter().zip(self.run_lengths.iter()) {
            indices.extend(start..start + length);
        }
        indices
    }
    /// Ratio of raw index-list bytes to encoded bytes (1.0 when empty).
    pub fn compression_ratio(&self) -> f32 {
        if self.start_indices.is_empty() {
            return 1.0;
        }
        // core::mem (not std::mem) keeps this compatible with the crate's
        // no_std configuration (see the cfg-gated `alloc` import above).
        let original_size = self.total_elements * core::mem::size_of::<usize>();
        let compressed_size =
            (self.start_indices.len() + self.run_lengths.len()) * core::mem::size_of::<usize>();
        original_size as f32 / compressed_size as f32
    }
    /// Number of encoded runs.
    pub fn num_runs(&self) -> usize {
        self.start_indices.len()
    }
}
/// Delta encoding of a pruned-index list: the first index is stored as a
/// base and every following index as a signed 32-bit difference from its
/// predecessor.
#[derive(Debug, Clone)]
pub struct DeltaEncoded {
    base_index: usize,
    deltas: Vec<i32>,
    total_elements: usize,
}
impl DeltaEncoded {
    /// Encodes `indices` as a base index plus successive deltas.
    pub fn encode(indices: &[usize]) -> Self {
        if indices.is_empty() {
            return Self {
                base_index: 0,
                deltas: vec![],
                total_elements: 0,
            };
        }
        let base_index = indices[0];
        let mut deltas = Vec::with_capacity(indices.len() - 1);
        for pair in indices.windows(2) {
            let delta = pair[1] as i64 - pair[0] as i64;
            // Deltas are stored as i32 to halve storage; a gap outside the
            // i32 range would silently wrap, so catch that in debug builds.
            debug_assert!(
                i32::try_from(delta).is_ok(),
                "index gap exceeds i32 range"
            );
            deltas.push(delta as i32);
        }
        Self {
            base_index,
            deltas,
            total_elements: indices.len(),
        }
    }
    /// Reconstructs the original index list from the base and deltas.
    pub fn decode(&self) -> Vec<usize> {
        if self.total_elements == 0 {
            return vec![];
        }
        let mut indices = Vec::with_capacity(self.total_elements);
        indices.push(self.base_index);
        // Accumulate in i64 so negative deltas can't underflow usize math.
        let mut current = self.base_index as i64;
        for &delta in &self.deltas {
            current += delta as i64;
            indices.push(current as usize);
        }
        indices
    }
    /// Ratio of raw index-list bytes to encoded bytes (1.0 when empty).
    pub fn compression_ratio(&self) -> f32 {
        if self.total_elements == 0 {
            return 1.0;
        }
        // core::mem (not std::mem) keeps this compatible with the crate's
        // no_std configuration (see the cfg-gated `alloc` import above).
        let original_size = self.total_elements * core::mem::size_of::<usize>();
        let compressed_size =
            core::mem::size_of::<usize>() + self.deltas.len() * core::mem::size_of::<i32>();
        original_size as f32 / compressed_size as f32
    }
}
/// Bit-per-position mask over the half-open index range
/// `[start_index, start_index + num_elements)`, packed into u64 words.
#[derive(Debug, Clone)]
pub struct BitmapEncoded {
    start_index: usize,
    bitmap: Vec<u64>,
    num_elements: usize,
    num_set_bits: usize,
}
impl BitmapEncoded {
    /// Builds a bitmap covering `[start, end)`; indices outside that range
    /// are silently ignored.
    ///
    /// An inverted range (`end < start`) yields an empty bitmap instead of
    /// underflowing.
    pub fn encode(indices: &[usize], start: usize, end: usize) -> Self {
        // saturating_sub: an inverted range degrades to an empty bitmap
        // rather than a usize underflow (debug panic / huge release alloc).
        let num_elements = end.saturating_sub(start);
        // One u64 word per 64 positions, rounded up.
        let num_words = (num_elements + 63) / 64;
        let mut bitmap = vec![0u64; num_words];
        let mut num_set_bits = 0;
        for &idx in indices {
            if idx >= start && idx < end {
                let bit_pos = idx - start;
                bitmap[bit_pos / 64] |= 1u64 << (bit_pos % 64);
                num_set_bits += 1;
            }
        }
        Self {
            start_index: start,
            bitmap,
            num_elements,
            num_set_bits,
        }
    }
    /// Expands the bitmap back into a sorted list of set indices.
    pub fn decode(&self) -> Vec<usize> {
        let mut indices = Vec::with_capacity(self.num_set_bits);
        for (word_idx, &word) in self.bitmap.iter().enumerate() {
            if word == 0 {
                // Fast-skip fully clear words.
                continue;
            }
            for bit_idx in 0..64 {
                if (word & (1u64 << bit_idx)) != 0 {
                    let idx = self.start_index + word_idx * 64 + bit_idx;
                    // The last word may extend past the covered range.
                    if idx < self.start_index + self.num_elements {
                        indices.push(idx);
                    }
                }
            }
        }
        indices
    }
    /// Ratio of the raw index-list size to the bitmap size (1.0 when no
    /// bits are set).
    pub fn compression_ratio(&self) -> f32 {
        if self.num_set_bits == 0 {
            return 1.0;
        }
        // core::mem (not std::mem) keeps this compatible with the crate's
        // no_std configuration (see the cfg-gated `alloc` import above).
        let original_size = self.num_set_bits * core::mem::size_of::<usize>();
        let compressed_size =
            core::mem::size_of::<usize>() + self.bitmap.len() * core::mem::size_of::<u64>();
        original_size as f32 / compressed_size as f32
    }
    /// Fraction of covered positions that are set; 0.0 for an empty range
    /// (instead of NaN from 0/0).
    pub fn density(&self) -> f32 {
        if self.num_elements == 0 {
            return 0.0;
        }
        self.num_set_bits as f32 / self.num_elements as f32
    }
}
/// Summary of how well a compression encoding performed on a payload.
#[derive(Debug, Clone)]
pub struct CompressionAnalysis {
    pub original_size: usize,
    pub compressed_size: usize,
    pub compression_ratio: f32,
    pub space_savings: usize,
    pub encoding: CompressionEncoding,
    pub sparsity: f32,
    pub efficiency_score: u8,
}
impl CompressionAnalysis {
    /// Derives the ratio, savings, and efficiency score from the raw
    /// before/after sizes (in bytes).
    pub fn new(
        original_size: usize,
        compressed_size: usize,
        encoding: CompressionEncoding,
        sparsity: f32,
    ) -> Self {
        // Guard against division by zero for empty compressed payloads.
        let compression_ratio = if compressed_size > 0 {
            original_size as f32 / compressed_size as f32
        } else {
            1.0
        };
        let space_savings = original_size.saturating_sub(compressed_size);
        // Score the achieved ratio against the encoding's rule-of-thumb
        // maximum, capped at 100.
        let theoretical_max = encoding.expected_compression_ratio();
        let efficiency_score = ((compression_ratio / theoretical_max) * 100.0).min(100.0) as u8;
        Self {
            original_size,
            compressed_size,
            compression_ratio,
            space_savings,
            encoding,
            sparsity,
            efficiency_score,
        }
    }
    /// Whether compression beat raw storage by a meaningful margin (>10%).
    pub fn is_beneficial(&self) -> bool {
        self.compression_ratio > 1.1
    }
    /// Percentage of the original size that was saved.
    ///
    /// Returns 0.0 when `original_size` is zero (avoids NaN from 0/0).
    pub fn savings_percentage(&self) -> f32 {
        if self.original_size == 0 {
            return 0.0;
        }
        (self.space_savings as f32 / self.original_size as f32) * 100.0
    }
}
/// Heuristic chooser of a [`CompressionEncoding`] for a pruned-index list.
#[derive(Debug, Clone)]
pub struct CompressionSelector {
    sparsity_threshold: f32,
    preferred_encodings: Vec<CompressionEncoding>,
}
impl CompressionSelector {
    /// Builds a selector with a 0.3 sparsity threshold and the default
    /// encoding preference order.
    pub fn new() -> Self {
        Self {
            sparsity_threshold: 0.3,
            preferred_encodings: vec![
                CompressionEncoding::Hybrid,
                CompressionEncoding::Huffman,
                CompressionEncoding::Bitmap,
                CompressionEncoding::RunLength,
                CompressionEncoding::Delta,
            ],
        }
    }
    /// Overrides the minimum sparsity below which raw storage is used.
    pub fn with_sparsity_threshold(mut self, threshold: f32) -> Self {
        self.sparsity_threshold = threshold;
        self
    }
    /// Preference-ordered list of encodings this selector considers.
    pub fn preferred_encodings(&self) -> &[CompressionEncoding] {
        &self.preferred_encodings
    }
    /// Picks an encoding for `indices` drawn from a tensor of `total_size`
    /// elements, using simple structural heuristics.
    ///
    /// Note: the run/delta heuristics assume `indices` is sorted ascending.
    pub fn select_encoding(&self, indices: &[usize], total_size: usize) -> CompressionEncoding {
        if indices.is_empty() {
            return CompressionEncoding::Raw;
        }
        let sparsity = 1.0 - (indices.len() as f32 / total_size as f32);
        if sparsity < self.sparsity_threshold {
            // Too dense to be worth compressing.
            return CompressionEncoding::Raw;
        }
        if self.calculate_consecutive_ratio(indices) > 0.7 {
            // Mostly consecutive runs: run-length wins.
            return CompressionEncoding::RunLength;
        }
        if self.calculate_average_delta(indices) < 10.0 {
            // Small gaps: deltas fit in few bits.
            return CompressionEncoding::Delta;
        }
        if self.has_dense_regions(indices) {
            return CompressionEncoding::Bitmap;
        }
        CompressionEncoding::Hybrid
    }
    /// Fraction of adjacent index pairs that are exactly consecutive.
    fn calculate_consecutive_ratio(&self, indices: &[usize]) -> f32 {
        if indices.len() < 2 {
            return 0.0;
        }
        let consecutive = indices
            .windows(2)
            .filter(|pair| pair[1] == pair[0] + 1)
            .count();
        consecutive as f32 / (indices.len() - 1) as f32
    }
    /// Mean absolute gap between adjacent indices.
    fn calculate_average_delta(&self, indices: &[usize]) -> f32 {
        if indices.len() < 2 {
            return 0.0;
        }
        let total: i64 = indices
            .windows(2)
            .map(|pair| (pair[1] as i64 - pair[0] as i64).abs())
            .sum();
        total as f32 / (indices.len() - 1) as f32
    }
    /// True when the indices pack densely (>80%) into their min..max span.
    fn has_dense_regions(&self, indices: &[usize]) -> bool {
        if indices.len() < 10 {
            return false;
        }
        let lo = *indices.iter().min().expect("reduction should succeed");
        let hi = *indices.iter().max().expect("reduction should succeed");
        let span = hi - lo + 1;
        if span == 0 {
            return false;
        }
        indices.len() as f32 / span as f32 > 0.8
    }
}
impl Default for CompressionSelector {
fn default() -> Self {
Self::new()
}
}
/// Helpers for computing the magnitude cutoff below which weights get pruned.
#[derive(Debug, Clone)]
pub struct MagnitudeThresholdCalculator;
impl MagnitudeThresholdCalculator {
    /// Threshold equal to the `percentile`-th percentile of `|values|`.
    ///
    /// Returns 0.0 for empty input.
    ///
    /// # Panics
    /// Panics if any value is NaN.
    pub fn from_percentile(values: &[f32], percentile: u8) -> f32 {
        if values.is_empty() {
            return 0.0;
        }
        let mut magnitudes: Vec<f32> = values.iter().map(|v| v.abs()).collect();
        // sort_unstable_by: faster and allocation-free; stability is
        // irrelevant for plain magnitudes.
        magnitudes.sort_unstable_by(|a, b| {
            a.partial_cmp(b)
                .expect("absolute values should be comparable (no NaN)")
        });
        let index = ((percentile as f32 / 100.0) * magnitudes.len() as f32) as usize;
        // Clamp so percentile == 100 stays in bounds.
        magnitudes[index.min(magnitudes.len() - 1)]
    }
    /// Threshold that keeps exactly the `k` largest magnitudes: the k-th
    /// largest `|value|`. Returns 0.0 for empty input or `k == 0`.
    ///
    /// # Panics
    /// Panics if any value is NaN.
    pub fn from_top_k(values: &[f32], k: usize) -> f32 {
        if values.is_empty() || k == 0 {
            return 0.0;
        }
        let mut magnitudes: Vec<f32> = values.iter().map(|v| v.abs()).collect();
        // Descending sort so index k-1 is the k-th largest magnitude.
        magnitudes.sort_unstable_by(|a, b| {
            b.partial_cmp(a)
                .expect("absolute values should be comparable (no NaN)")
        });
        magnitudes[k.min(magnitudes.len()) - 1]
    }
    /// Threshold of `|mean| - num_std_dev * std_dev` over the raw values
    /// (population standard deviation).
    ///
    /// NOTE(review): the result can be negative (which would prune nothing
    /// under a `|w| < threshold` test) and it mixes the signed mean into the
    /// variance while taking `|mean|` for the offset — confirm this formula
    /// matches the intended policy.
    pub fn from_std_dev(values: &[f32], num_std_dev: f32) -> f32 {
        if values.is_empty() {
            return 0.0;
        }
        let n = values.len() as f32;
        let mean = values.iter().sum::<f32>() / n;
        let variance = values.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
        mean.abs() - num_std_dev * variance.sqrt()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pruning_strategy() {
        let magnitude = PruningStrategy::Magnitude {
            threshold_percentile: 50,
        };
        assert_eq!(magnitude.expected_sparsity(), 0.5);
        assert!(!magnitude.is_structured());

        let block_wise = PruningStrategy::BlockWise { block_size: (4, 4) };
        assert!(block_wise.is_structured());
    }

    #[test]
    fn test_run_length_encoding() {
        // Three runs: 0..=3, 10..=12, and the singleton 20.
        let original = vec![0, 1, 2, 3, 10, 11, 12, 20];
        let rle = RunLengthEncoded::encode(&original);
        assert_eq!(rle.num_runs(), 3);
        assert_eq!(rle.decode(), original);
        assert!(rle.compression_ratio() > 1.0);
    }

    #[test]
    fn test_delta_encoding() {
        // Constant stride of 5 between indices.
        let original = vec![5, 10, 15, 20, 25];
        let delta = DeltaEncoded::encode(&original);
        assert_eq!(delta.decode(), original);
        assert!(delta.compression_ratio() > 1.0);
    }

    #[test]
    fn test_bitmap_encoding() {
        let original = vec![0, 1, 3, 5, 7];
        let bitmap = BitmapEncoded::encode(&original, 0, 10);
        assert_eq!(bitmap.num_set_bits, 5);
        assert_eq!(bitmap.decode(), original);
        assert_eq!(bitmap.density(), 0.5);
    }

    #[test]
    fn test_compression_selector() {
        let selector = CompressionSelector::new();

        // Fully consecutive indices should pick run-length encoding.
        let consecutive: Vec<usize> = (0..10).collect();
        assert_eq!(
            selector.select_encoding(&consecutive, 100),
            CompressionEncoding::RunLength
        );

        // Small, mostly-regular gaps pick delta (or run-length).
        let small_deltas = vec![0, 1, 3, 4, 6, 7, 9, 10];
        assert!(matches!(
            selector.select_encoding(&small_deltas, 100),
            CompressionEncoding::Delta | CompressionEncoding::RunLength
        ));
    }

    #[test]
    fn test_magnitude_threshold() {
        let values: Vec<f32> = (1..=10).map(|v| v as f32).collect();
        let median = MagnitudeThresholdCalculator::from_percentile(&values, 50);
        assert!((median - 6.0).abs() < 0.1);
        let top3 = MagnitudeThresholdCalculator::from_top_k(&values, 3);
        assert!((top3 - 8.0).abs() < 0.1);
    }

    #[test]
    fn test_pruning_metadata() {
        let metadata = PruningMetadata::new(
            PruningStrategy::Magnitude {
                threshold_percentile: 50,
            },
            Shape::new(vec![10, 10]),
            0.5,
        )
        .with_indices(vec![0, 1, 2, 3, 4])
        .with_threshold(0.1);

        assert_eq!(metadata.sparsity(), 0.5);
        assert_eq!(metadata.compression_ratio(), 2.0);
        assert_eq!(metadata.num_pruned_elements(), 5);
        assert!(metadata.is_element_pruned(2));
        assert!(!metadata.is_element_pruned(10));
    }

    #[test]
    fn test_compression_analysis() {
        // 1000 -> 250 bytes: ratio 4.0, 75% saved.
        let analysis = CompressionAnalysis::new(1000, 250, CompressionEncoding::Huffman, 0.75);
        assert_eq!(analysis.compression_ratio, 4.0);
        assert_eq!(analysis.space_savings, 750);
        assert!(analysis.is_beneficial());
        assert_eq!(analysis.savings_percentage(), 75.0);
    }
}