#![allow(dead_code)]
use super::super::common::{
traits::SparseLayer,
types::{SparseLayerConfig, SparseStats},
utils::SparseWeightGenerator,
};
use crate::{CsrTensor, SparseTensor, TorshResult};
use torsh_core::TorshError;
use torsh_tensor::{creation::randn, Tensor};
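/// A fully connected layer whose weight matrix is stored in CSR form.
///
/// Functionally equivalent to a dense linear layer computing `y = W x + b`,
/// but only the non-zero entries of `W` (a fraction `1 - sparsity` of the
/// matrix) are stored and multiplied.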
pub struct SparseLinear {
weight: CsrTensor,
bias: Option<Tensor>,
in_features: usize,
out_features: usize,
sparsity: f32,
training: bool,
config: SparseLayerConfig,
}
impl SparseLinear {
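    /// Creates a layer with randomly generated sparse weights.
    ///
    /// `sparsity` is the fraction of weights that are zero, in `[0.0, 1.0]`.
    ///
    /// # Example (sketch)
    ///
    /// ```ignore
    /// // Assumes a 1-D input created with `randn`, as in this module's tests.
    /// let layer = SparseLinear::new(784, 128, 0.9, true)?;
    /// let input = randn::<f32>(&[784])?;
    /// let output = layer.forward(&input)?;
    /// ```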
pub fn new(
in_features: usize,
out_features: usize,
sparsity: f32,
use_bias: bool,
) -> TorshResult<Self> {
let config = SparseLayerConfig::linear(sparsity);
Self::with_config(in_features, out_features, config, use_bias)
}
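    /// Creates a layer from an explicit `SparseLayerConfig`.
    ///
    /// The bias is allocated only when both the `use_bias` argument and
    /// `config.use_bias` are set.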
pub fn with_config(
in_features: usize,
out_features: usize,
config: SparseLayerConfig,
use_bias: bool,
) -> TorshResult<Self> {
config
.validate()
            .map_err(TorshError::InvalidArgument)?;
if !(0.0..=1.0).contains(&config.input_sparsity) {
return Err(TorshError::InvalidArgument(
"Sparsity must be between 0.0 and 1.0".to_string(),
));
}
let weight = SparseWeightGenerator::generate_sparse_weights(
out_features,
in_features,
config.input_sparsity,
)?;
let bias = if use_bias && config.use_bias {
Some(randn::<f32>(&[out_features])?)
} else {
None
};
Ok(Self {
weight,
bias,
in_features,
out_features,
sparsity: config.input_sparsity,
training: true,
config,
})
}
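    /// Wraps an existing CSR weight matrix (and optional bias) in a layer.
    ///
    /// The weight must be 2-D with shape `[out_features, in_features]`;
    /// the sparsity is inferred from the matrix's non-zero count.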
pub fn from_weight(weight: CsrTensor, bias: Option<Tensor>) -> TorshResult<Self> {
let shape = weight.shape();
if shape.ndim() != 2 {
return Err(TorshError::InvalidArgument(
"Weight matrix must be 2D".to_string(),
));
}
let out_features = shape.dims()[0];
let in_features = shape.dims()[1];
if let Some(ref bias_tensor) = bias {
if bias_tensor.shape().dims() != [out_features] {
return Err(TorshError::InvalidArgument(
"Bias dimension must match output features".to_string(),
));
}
}
let total_elements = out_features * in_features;
let nnz = weight.nnz();
let sparsity = 1.0 - (nnz as f32 / total_elements as f32);
let config = SparseLayerConfig::linear(sparsity);
Ok(Self {
weight,
bias,
in_features,
out_features,
sparsity,
training: true,
config,
})
}
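    /// Computes `y = W x (+ b)` for a 1-D input of shape `[in_features]`
    /// (via a sparse matrix-vector product) or a batched 2-D input of
    /// shape `[batch, in_features]` (via a sparse matrix product).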
pub fn forward(&self, input: &Tensor) -> TorshResult<Tensor> {
let input_shape = input.shape();
let _batch_size = if input_shape.ndim() == 1 {
1
} else if input_shape.ndim() == 2 {
input_shape.dims()[0]
} else {
return Err(TorshError::InvalidArgument(
"Input must be 1D or 2D tensor".to_string(),
));
};
let input_features = if input_shape.ndim() == 1 {
input_shape.dims()[0]
} else {
input_shape.dims()[1]
};
if input_features != self.in_features {
return Err(TorshError::InvalidArgument(format!(
"Input features {} don't match layer input features {}",
input_features, self.in_features
)));
}
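        // 1-D inputs use a sparse matrix-vector product; 2-D inputs go
        // through `matmul`. Note: `weight` has shape [out, in] while a
        // batched input is [batch, in], so this relies on the CSR `matmul`
        // handling that orientation; with a strict (m, k) x (k, n) matmul
        // the input would need to be transposed first.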
let output = if input_shape.ndim() == 1 {
self.weight.matvec(input)?
} else {
self.weight.matmul(input)?
};
if let Some(ref bias) = self.bias {
output.add_bias(bias)
} else {
Ok(output)
}
}
pub fn weight(&self) -> &CsrTensor {
&self.weight
}
pub fn weight_mut(&mut self) -> &mut CsrTensor {
&mut self.weight
}
pub fn bias(&self) -> Option<&Tensor> {
self.bias.as_ref()
}
pub fn bias_mut(&mut self) -> Option<&mut Tensor> {
self.bias.as_mut()
}
pub fn in_features(&self) -> usize {
self.in_features
}
pub fn out_features(&self) -> usize {
self.out_features
}
pub fn sparsity(&self) -> f32 {
self.sparsity
}
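    /// Magnitude-prunes the weights up to `target_sparsity`.
    ///
    /// A no-op when the layer is already at least that sparse; pruning
    /// never re-adds weights.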
pub fn prune_to_sparsity(&mut self, target_sparsity: f32) -> TorshResult<()> {
if !(0.0..=1.0).contains(&target_sparsity) {
return Err(TorshError::InvalidArgument(
"Target sparsity must be between 0.0 and 1.0".to_string(),
));
}
        if target_sparsity <= self.sparsity {
            return Ok(());
        }
let dense_weight = self.weight.to_dense()?;
self.weight = SparseWeightGenerator::prune_by_magnitude(&dense_weight, target_sparsity)?;
self.sparsity = target_sparsity;
Ok(())
}
pub fn config(&self) -> &SparseLayerConfig {
&self.config
}
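    /// Returns the output shape for a 1-D or 2-D input shape; returns an
    /// empty `Vec` for unsupported ranks.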
pub fn output_dimensions(&self, input_dims: &[usize]) -> Vec<usize> {
match input_dims.len() {
1 => vec![self.out_features],
2 => vec![input_dims[0], self.out_features],
            _ => vec![],
        }
}
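    /// Estimates the memory footprint of the layer versus an equivalent
    /// dense layer, counting CSR values, column indices, and row pointers
    /// plus the dense bias.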
pub fn memory_stats(&self) -> SparseMemoryStats {
let dense_params = self.in_features * self.out_features;
let sparse_params = self.weight.nnz();
let bias_params = self.bias.as_ref().map_or(0, |b| b.numel());
        let dense_memory = dense_params * std::mem::size_of::<f32>();
        // CSR storage: an f32 value and a usize column index per non-zero,
        // plus (rows + 1) usize row pointers, plus the dense bias if any.
        let sparse_memory = sparse_params
            * (std::mem::size_of::<f32>() + std::mem::size_of::<usize>())
            + (self.out_features + 1) * std::mem::size_of::<usize>()
            + bias_params * std::mem::size_of::<f32>();
SparseMemoryStats {
dense_parameters: dense_params,
sparse_parameters: sparse_params,
bias_parameters: bias_params,
dense_memory_bytes: dense_memory,
sparse_memory_bytes: sparse_memory,
memory_reduction: 1.0 - (sparse_memory as f32 / dense_memory as f32),
}
}
}
impl SparseLayer for SparseLinear {
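    /// Trait-object forward: densifies the sparse input, runs the regular
    /// dense `forward`, then re-sparsifies the result.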
fn forward(&self, input: &dyn SparseTensor) -> TorshResult<Box<dyn SparseTensor>> {
let dense_input = input.to_dense()?;
let output = self.forward(&dense_input)?;
let sparse_output = SparseWeightGenerator::dense_to_sparse(&output)?;
Ok(Box::new(sparse_output))
}
fn parameters(&self) -> Vec<&CsrTensor> {
vec![&self.weight]
}
fn parameters_mut(&mut self) -> Vec<&mut CsrTensor> {
vec![&mut self.weight]
}
fn layer_type(&self) -> &'static str {
"SparseLinear"
}
fn dimensions(&self) -> (Vec<usize>, Vec<usize>) {
let input_dims = vec![self.in_features];
let output_dims = vec![self.out_features];
(input_dims, output_dims)
}
fn sparsity_stats(&self) -> SparseStats {
let mut stats = SparseStats::new();
stats.update(&self.weight, true);
stats
}
fn train(&mut self, training: bool) {
self.training = training;
}
fn training(&self) -> bool {
self.training
}
}
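/// Memory accounting for a `SparseLinear` layer, comparing the CSR
/// representation against an equivalent dense weight matrix.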
#[derive(Debug, Clone)]
pub struct SparseMemoryStats {
pub dense_parameters: usize,
pub sparse_parameters: usize,
pub bias_parameters: usize,
pub dense_memory_bytes: usize,
pub sparse_memory_bytes: usize,
pub memory_reduction: f32,
}
impl SparseMemoryStats {
pub fn compression_ratio(&self) -> f32 {
self.dense_parameters as f32 / self.sparse_parameters as f32
}
pub fn efficiency_score(&self) -> f32 {
self.memory_reduction
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_sparse_linear_creation() {
let layer = SparseLinear::new(100, 50, 0.9, true);
assert!(layer.is_ok());
let layer = layer.expect("layer creation should succeed");
assert_eq!(layer.in_features(), 100);
assert_eq!(layer.out_features(), 50);
assert_eq!(layer.sparsity(), 0.9);
assert!(layer.bias().is_some());
}
#[test]
fn test_sparse_linear_dimensions() {
let layer = SparseLinear::new(784, 128, 0.8, false).expect("Sparse Linear should succeed");
let output_dims = layer.output_dimensions(&[784]);
assert_eq!(output_dims, vec![128]);
let output_dims = layer.output_dimensions(&[32, 784]);
assert_eq!(output_dims, vec![32, 128]);
}
#[test]
fn test_memory_stats() {
let layer = SparseLinear::new(100, 50, 0.9, true).expect("Sparse Linear should succeed");
let stats = layer.memory_stats();
assert_eq!(stats.dense_parameters, 5000);
assert!(stats.sparse_parameters < stats.dense_parameters);
assert_eq!(stats.bias_parameters, 50);
assert!(stats.memory_reduction > 0.0);
}
#[test]
fn test_sparsity_validation() {
let result = SparseLinear::new(10, 10, 1.5, true);
assert!(result.is_err());
let result = SparseLinear::new(10, 10, -0.1, true);
assert!(result.is_err());
}
#[test]
fn test_pruning() {
let mut layer =
SparseLinear::new(10, 10, 0.5, false).expect("Sparse Linear should succeed");
let initial_sparsity = layer.sparsity();
let result = layer.prune_to_sparsity(0.8);
assert!(result.is_ok());
assert!(layer.sparsity() > initial_sparsity);
}
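    #[test]
    fn test_forward_vector() {
        // Minimal forward smoke test (a sketch relying only on `randn` and
        // the shapes established above): a 1-D input of length `in_features`
        // should produce a 1-D output of length `out_features` via the
        // sparse matrix-vector product.
        let layer = SparseLinear::new(100, 50, 0.9, true).expect("Sparse Linear should succeed");
        let input = randn::<f32>(&[100]).expect("input creation should succeed");
        let output = layer.forward(&input).expect("forward should succeed");
        assert_eq!(output.shape().dims(), [50]);
    }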
}