//! Core traits for sparse neural-network components: optimizers, layers,
//! activations, pooling, normalization, format conversion, initialization,
//! pruning, and model-level sparsity analysis.

use crate::{CsrTensor, SparseTensor, TorshResult};
use std::collections::HashMap;
use torsh_tensor::Tensor;

/// An optimizer that updates sparse (CSR) parameters from sparse gradients.
pub trait SparseOptimizer {
    /// Applies one update step to `parameters` using the matching `gradients`.
    fn step(
        &mut self,
        parameters: &mut [&mut CsrTensor],
        gradients: &[&CsrTensor],
    ) -> TorshResult<()>;
    /// Clears accumulated gradient state; a no-op by default.
    fn zero_grad(&mut self) {}
    /// Returns the current learning rate.
    fn lr(&self) -> f32;
    /// Sets the learning rate.
    fn set_lr(&mut self, lr: f32);
    /// Serializes optimizer state (e.g. momentum buffers); empty by default.
    fn state_dict(&self) -> HashMap<String, Tensor> {
        HashMap::new()
    }
    /// Restores state produced by `state_dict`; a no-op by default.
    fn load_state_dict(&mut self, _state: HashMap<String, Tensor>) -> TorshResult<()> {
        Ok(())
    }
    /// Short identifier for the optimizer, e.g. `"sgd"`.
    fn name(&self) -> &'static str;
    /// Current hyperparameter values keyed by name.
    fn hyperparameters(&self) -> HashMap<String, f32>;
}
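
// A minimal usage sketch built only on the accessors above: exponential
// learning-rate decay through `lr`/`set_lr`. The decay policy (multiplicative
// factor, applied e.g. once per epoch) is an illustrative assumption, not
// part of the trait contract.
pub fn decay_lr(optimizer: &mut dyn SparseOptimizer, factor: f32) {
    debug_assert!(factor > 0.0 && factor <= 1.0, "factor should shrink the lr");
    let current = optimizer.lr();
    optimizer.set_lr(current * factor);
}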

/// A neural-network layer that consumes and produces sparse tensors.
pub trait SparseLayer {
    /// Runs the layer's forward pass.
    fn forward(&self, input: &dyn SparseTensor) -> TorshResult<Box<dyn SparseTensor>>;
    /// Immutable views of the layer's learnable sparse parameters.
    fn parameters(&self) -> Vec<&CsrTensor>;
    /// Mutable views of the layer's learnable sparse parameters.
    fn parameters_mut(&mut self) -> Vec<&mut CsrTensor>;
    /// Short identifier for the layer kind.
    fn layer_type(&self) -> &'static str;
    /// Returns the layer's input and output shapes as `(input_dims, output_dims)`.
    fn dimensions(&self) -> (Vec<usize>, Vec<usize>);
    /// Sparsity statistics for the layer's parameters.
    fn sparsity_stats(&self) -> super::types::SparseStats;
    /// Switches the layer between training and evaluation mode.
    fn train(&mut self, training: bool);
    /// Whether the layer is currently in training mode.
    fn training(&self) -> bool;
}
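
// Hypothetical helper (not part of the trait): chain `SparseLayer`s
// sequentially, feeding each layer's output to the next. Relies only on the
// `forward` method declared above.
pub fn forward_sequential(
    layers: &[&dyn SparseLayer],
    input: Box<dyn SparseTensor>,
) -> TorshResult<Box<dyn SparseTensor>> {
    let mut current = input;
    for layer in layers {
        current = layer.forward(&*current)?;
    }
    Ok(current)
}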

/// An element-wise activation function over sparse tensors.
pub trait SparseActivation {
    /// Applies the activation, returning a new sparse tensor.
    fn forward(&self, input: &dyn SparseTensor) -> TorshResult<Box<dyn SparseTensor>>;
    /// Applies the activation in place. Note: the default implementation is a
    /// fallback that runs `forward` and discards the result, so it does not
    /// actually mutate `input`; implementors should override it with a true
    /// in-place update.
    fn forward_inplace(&self, input: &mut dyn SparseTensor) -> TorshResult<()> {
        let _result = self.forward(input)?;
        Ok(())
    }
    /// Short identifier for the activation, e.g. `"relu"`.
    fn name(&self) -> &'static str;
    /// Whether the activation keeps the sparsity pattern unchanged
    /// (zero inputs stay zero at the same positions).
    fn preserves_sparsity(&self) -> bool;
    /// Whether the activation can zero out previously nonzero entries
    /// (e.g. ReLU on negative values), increasing sparsity.
    fn can_increase_sparsity(&self) -> bool;
}
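
// Illustrative dispatch helper: take the in-place path only when the
// activation preserves the sparsity pattern, otherwise allocate a new tensor
// via `forward`. The dispatch policy is an assumption of this sketch.
pub fn apply_activation(
    activation: &dyn SparseActivation,
    input: &mut dyn SparseTensor,
) -> TorshResult<Option<Box<dyn SparseTensor>>> {
    if activation.preserves_sparsity() {
        // Pattern is unchanged, so mutating values in place is safe.
        activation.forward_inplace(input)?;
        Ok(None)
    } else {
        // The pattern may change; build a new tensor instead.
        activation.forward(input).map(Some)
    }
}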

/// A 2D pooling operation over sparse tensors.
pub trait SparsePooling {
    /// Runs the pooling operation.
    fn forward(&self, input: &dyn SparseTensor) -> TorshResult<Box<dyn SparseTensor>>;
    /// Short identifier for the pooling kind, e.g. `"max"` or `"avg"`.
    fn operation_type(&self) -> &'static str;
    /// Pooling window size as `(height, width)`.
    fn kernel_size(&self) -> (usize, usize);
    /// Window stride as `(height, width)`.
    fn stride(&self) -> (usize, usize);
    /// Zero padding as `(height, width)`.
    fn padding(&self) -> (usize, usize);
    /// Computes the output shape for a given input shape.
    fn output_dimensions(&self, input_dims: &[usize]) -> Vec<usize>;
}
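
// Sketch of the shape arithmetic an `output_dimensions` implementation would
// typically perform on the last two (spatial) axes:
//   out = (in + 2 * padding - kernel) / stride + 1
// The trailing-(height, width) layout is an assumption of this example, not
// something the trait mandates.
pub fn pooled_output_dims(
    input_dims: &[usize],
    kernel: (usize, usize),
    stride: (usize, usize),
    padding: (usize, usize),
) -> Vec<usize> {
    let mut out = input_dims.to_vec();
    let n = out.len();
    if n >= 2 {
        out[n - 2] = (input_dims[n - 2] + 2 * padding.0 - kernel.0) / stride.0 + 1;
        out[n - 1] = (input_dims[n - 1] + 2 * padding.1 - kernel.1) / stride.1 + 1;
    }
    out
}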

/// A normalization layer (e.g. batch norm) adapted to sparse inputs.
pub trait SparseNormalization {
    /// Normalizes the input.
    fn forward(&self, input: &dyn SparseTensor) -> TorshResult<Box<dyn SparseTensor>>;
    /// Updates running statistics (typically called during training).
    fn update_stats(&mut self, input: &dyn SparseTensor) -> TorshResult<()>;
    /// Short identifier for the normalization kind, e.g. `"batch_norm"`.
    fn norm_type(&self) -> &'static str;
    /// Immutable views of the learnable (dense) parameters, e.g. scale and shift.
    fn learnable_parameters(&self) -> Vec<&Tensor>;
    /// Mutable views of the learnable (dense) parameters.
    fn learnable_parameters_mut(&mut self) -> Vec<&mut Tensor>;
}
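
// Hypothetical convenience wrapper: refresh running statistics only while
// training, then normalize. Threading the training flag this way is a choice
// of this sketch; the trait leaves that policy to the caller.
pub fn normalize(
    norm: &mut dyn SparseNormalization,
    input: &dyn SparseTensor,
    training: bool,
) -> TorshResult<Box<dyn SparseTensor>> {
    if training {
        norm.update_stats(input)?;
    }
    norm.forward(input)
}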

/// Conversion between sparse storage formats (CSR, COO, CSC).
pub trait SparseConverter {
    /// Converts to compressed sparse row format.
    fn to_csr(&self) -> TorshResult<CsrTensor>;
    /// Converts to coordinate format.
    fn to_coo(&self) -> TorshResult<crate::CooTensor>;
    /// Converts to compressed sparse column format.
    fn to_csc(&self) -> TorshResult<crate::CscTensor>;
    /// The format the tensor is currently stored in.
    fn current_format(&self) -> super::types::SparseFormat;
    /// Whether a conversion is required to reach `target_format`.
    fn needs_conversion(&self, target_format: super::types::SparseFormat) -> bool {
        self.current_format() != target_format
    }
}
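
// Usage sketch for the conversion trait: materialize a CSR copy only when the
// tensor is not already in CSR form. Assumes `SparseFormat` has a `Csr`
// variant, which this file does not show.
pub fn csr_if_needed(tensor: &dyn SparseConverter) -> TorshResult<Option<CsrTensor>> {
    if tensor.needs_conversion(super::types::SparseFormat::Csr) {
        tensor.to_csr().map(Some)
    } else {
        // Already CSR; skip the conversion cost.
        Ok(None)
    }
}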

/// A strategy for creating sparse weight tensors.
pub trait SparseInitializer {
    /// Creates a new sparse tensor of the given shape according to `config`.
    fn initialize(
        &self,
        shape: &[usize],
        config: &super::types::SparseInitConfig,
    ) -> TorshResult<CsrTensor>;
    /// Sparsifies an existing dense tensor to the given sparsity ratio
    /// (fraction of zero entries, in `[0, 1]`).
    fn from_dense(&self, dense: &Tensor, sparsity: f32) -> TorshResult<CsrTensor>;
    /// Short identifier for the strategy.
    fn strategy_name(&self) -> &'static str;
}
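
// Brief usage sketch: derive a sparse weight from a dense tensor at a target
// sparsity, with a sanity check on the ratio. Purely illustrative.
pub fn sparsify_dense(
    init: &dyn SparseInitializer,
    dense: &Tensor,
    sparsity: f32,
) -> TorshResult<CsrTensor> {
    debug_assert!((0.0..=1.0).contains(&sparsity), "sparsity is a fraction");
    init.from_dense(dense, sparsity)
}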

/// A strategy for pruning entries from a sparse tensor.
pub trait SparsePruner {
    /// Prunes `tensor` until it reaches `target_sparsity`.
    fn prune(&self, tensor: &CsrTensor, target_sparsity: f32) -> TorshResult<CsrTensor>;
    /// Prunes using gradient information in addition to the weights.
    fn prune_with_gradients(
        &self,
        tensor: &CsrTensor,
        gradients: &CsrTensor,
        target_sparsity: f32,
    ) -> TorshResult<CsrTensor>;
    /// Short identifier for the strategy, e.g. `"magnitude"`.
    fn pruning_strategy(&self) -> &'static str;
    /// Whether the pruner removes structured groups (rows, columns, blocks)
    /// rather than individual entries.
    fn is_structured(&self) -> bool;
}
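
// Sketch of gradual pruning built only on the trait: ramp the sparsity target
// linearly over `steps` rounds instead of pruning in one shot. The linear
// schedule is an illustrative choice; real schedules are often polynomial or
// tied to training epochs.
pub fn prune_gradually(
    pruner: &dyn SparsePruner,
    tensor: &CsrTensor,
    final_sparsity: f32,
    steps: usize,
) -> TorshResult<CsrTensor> {
    if steps <= 1 {
        return pruner.prune(tensor, final_sparsity);
    }
    // Prune once, then re-prune the previous result at increasing targets
    // (avoids assuming `CsrTensor: Clone`).
    let mut current = pruner.prune(tensor, final_sparsity / steps as f32)?;
    for step in 2..=steps {
        current = pruner.prune(&current, final_sparsity * step as f32 / steps as f32)?;
    }
    Ok(current)
}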

/// Analyzes sparsity across an entire model.
pub trait SparseAnalyzer {
    /// Aggregates per-layer sparsity statistics into a model-level summary.
    fn analyze_model_sparsity(&self, layers: &[&dyn SparseLayer]) -> ModelSparsityAnalysis;
    /// Suggests optimizations based on an analysis.
    fn recommend_optimizations(&self, analysis: &ModelSparsityAnalysis) -> Vec<String>;
    /// Estimates resource savings implied by an analysis.
    fn estimate_savings(&self, analysis: &ModelSparsityAnalysis) -> SavingsEstimate;
}

/// Model-level sparsity summary produced by a `SparseAnalyzer`.
#[derive(Debug, Clone)]
pub struct ModelSparsityAnalysis {
    /// Fraction of all parameters that are zero, in `[0, 1]`.
    pub overall_sparsity: f32,
    /// Per-layer sparsity ratios, in layer order.
    pub layer_sparsities: Vec<f32>,
    /// Total parameter count across all layers.
    pub total_parameters: usize,
    /// Number of parameters kept in sparse storage.
    pub sparse_parameters: usize,
    /// Estimated fractional memory reduction from sparsity.
    pub memory_reduction: f32,
    /// Estimated fractional FLOP reduction from sparsity.
    pub flops_reduction: f32,
}
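
// Illustrative baseline for `recommend_optimizations`: derive simple textual
// hints from the fields above. The thresholds are arbitrary example values,
// not tuned recommendations.
pub fn basic_recommendations(analysis: &ModelSparsityAnalysis) -> Vec<String> {
    let mut recs = Vec::new();
    if analysis.overall_sparsity < 0.5 {
        recs.push("Overall sparsity is below 50%; consider magnitude pruning.".to_string());
    }
    if analysis.memory_reduction < analysis.flops_reduction {
        recs.push(
            "Memory savings lag compute savings; a more compact storage format may help."
                .to_string(),
        );
    }
    recs
}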

/// Estimated resource savings from model sparsity, each as a fraction in `[0, 1]`.
#[derive(Debug, Clone)]
pub struct SavingsEstimate {
    pub memory_savings: f32,
    pub compute_savings: f32,
    pub energy_savings: f32,
    pub storage_savings: f32,
}

impl SavingsEstimate {
    /// Unweighted mean of the four savings components.
    pub fn efficiency_score(&self) -> f32 {
        (self.memory_savings + self.compute_savings + self.energy_savings + self.storage_savings)
            / 4.0
    }
}

/// Ready-to-use default implementations of the traits above.
pub mod defaults {
    use super::*;

    /// Random sparse initialization backed by `SparseWeightGenerator`.
    pub struct DefaultSparseInitializer;

    impl SparseInitializer for DefaultSparseInitializer {
        fn initialize(
            &self,
            shape: &[usize],
            config: &super::super::types::SparseInitConfig,
        ) -> TorshResult<CsrTensor> {
            // CSR generation is only defined for matrices.
            if shape.len() != 2 {
                return Err(crate::TorshError::InvalidArgument(
                    "Only 2D shapes supported".to_string(),
                ));
            }
            super::super::utils::SparseWeightGenerator::from_config(shape[0], shape[1], config)
        }

        fn from_dense(&self, dense: &Tensor, sparsity: f32) -> TorshResult<CsrTensor> {
            super::super::utils::SparseWeightGenerator::prune_by_magnitude(dense, sparsity)
        }

        fn strategy_name(&self) -> &'static str {
            "default_random"
        }
    }

    /// Unstructured pruning that removes the smallest-magnitude entries.
    pub struct MagnitudePruner;

    impl SparsePruner for MagnitudePruner {
        fn prune(&self, tensor: &CsrTensor, target_sparsity: f32) -> TorshResult<CsrTensor> {
            // Round-trips through a dense copy; fine for small tensors, but
            // costly for very large ones.
            let dense = tensor.to_dense()?;
            super::super::utils::SparseWeightGenerator::prune_by_magnitude(&dense, target_sparsity)
        }

        fn prune_with_gradients(
            &self,
            tensor: &CsrTensor,
            _gradients: &CsrTensor,
            target_sparsity: f32,
        ) -> TorshResult<CsrTensor> {
            // Magnitude pruning ignores gradients; delegate to plain pruning.
            self.prune(tensor, target_sparsity)
        }

        fn pruning_strategy(&self) -> &'static str {
            "magnitude"
        }

        fn is_structured(&self) -> bool {
            false
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_savings_estimate() {
        let estimate = SavingsEstimate {
            memory_savings: 0.8,
            compute_savings: 0.7,
            energy_savings: 0.75,
            storage_savings: 0.85,
        };
        // Compare with a tolerance: exact float equality is fragile since
        // 0.8, 0.7, and 0.85 are not exactly representable in f32.
        assert!((estimate.efficiency_score() - 0.775).abs() < 1e-6);
    }

    #[test]
    fn test_default_initializer() {
        let initializer = defaults::DefaultSparseInitializer;
        assert_eq!(initializer.strategy_name(), "default_random");
        let config = super::super::types::SparseInitConfig::default();
        let result = initializer.initialize(&[10, 10], &config);
        assert!(result.is_ok());
    }

    #[test]
    fn test_magnitude_pruner() {
        let pruner = defaults::MagnitudePruner;
        assert_eq!(pruner.pruning_strategy(), "magnitude");
        assert!(!pruner.is_structured());
    }
}