use std::collections::HashMap;
use axonml_autograd::Variable;
use axonml_tensor::Tensor;
use crate::init::{constant, kaiming_uniform, zeros};
use crate::module::Module;
use crate::parameter::Parameter;
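/// Steepness of the sigmoid used to relax the hard pruning mask into a differentiable soft mask.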
const TEMPERATURE: f32 = 10.0;
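/// Initial value for the learnable pruning threshold(s).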
const DEFAULT_THRESHOLD: f32 = 0.01;
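/// Linear layer with learnable pruning thresholds.
///
/// The forward pass gates the weight matrix with a soft mask,
/// `sigmoid((|w| - threshold) * TEMPERATURE)`, so sparsity can emerge during
/// training. In structured mode one threshold is learned per output feature
/// (pruning whole rows); in unstructured mode one threshold is learned per
/// weight entry. `hard_mask`, `density`, and `hard_prune` use the binary
/// (hard) version of the mask.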
pub struct SparseLinear {
pub weight: Parameter,
pub bias: Option<Parameter>,
pub threshold: Parameter,
in_features: usize,
out_features: usize,
structured: bool,
}
impl SparseLinear {
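/// Creates a structured-sparsity layer (one threshold per output feature) with a bias.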
pub fn new(in_features: usize, out_features: usize) -> Self {
Self::build(in_features, out_features, true, true)
}
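/// Creates an unstructured-sparsity layer (one threshold per weight entry) with a bias.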
pub fn unstructured(in_features: usize, out_features: usize) -> Self {
Self::build(in_features, out_features, false, true)
}
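/// Creates a structured-sparsity layer with or without a bias.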
pub fn with_bias(in_features: usize, out_features: usize, bias: bool) -> Self {
Self::build(in_features, out_features, true, bias)
}
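/// Shared constructor: Kaiming-uniform weight, zero bias (if enabled), and
/// thresholds initialised to `DEFAULT_THRESHOLD`.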
fn build(in_features: usize, out_features: usize, structured: bool, bias: bool) -> Self {
let weight_data = kaiming_uniform(out_features, in_features);
let weight = Parameter::named("weight", weight_data, true);
let bias_param = if bias {
let bias_data = zeros(&[out_features]);
Some(Parameter::named("bias", bias_data, true))
} else {
None
};
let threshold_data = if structured {
constant(&[out_features], DEFAULT_THRESHOLD)
} else {
constant(&[out_features, in_features], DEFAULT_THRESHOLD)
};
let threshold = Parameter::named("threshold", threshold_data, true);
Self {
weight,
bias: bias_param,
threshold,
in_features,
out_features,
structured,
}
}
pub fn in_features(&self) -> usize {
self.in_features
}
pub fn out_features(&self) -> usize {
self.out_features
}
pub fn is_structured(&self) -> bool {
self.structured
}
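/// Binary pruning mask of shape `[out_features, in_features]`:
/// 1.0 where `|w| >= threshold`, 0.0 elsewhere.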
fn hard_mask(&self) -> Tensor<f32> {
let weight_data = self.weight.data();
let threshold_data = self.threshold.data();
let w_vec = weight_data.to_vec();
let t_vec = threshold_data.to_vec();
let mask_vec: Vec<f32> = if self.structured {
w_vec
.iter()
.enumerate()
.map(|(idx, &w)| {
let out_idx = idx / self.in_features;
let t = t_vec[out_idx];
if w.abs() >= t { 1.0 } else { 0.0 }
})
.collect()
} else {
w_vec
.iter()
.zip(t_vec.iter())
.map(|(&w, &t)| if w.abs() >= t { 1.0 } else { 0.0 })
.collect()
};
Tensor::from_vec(mask_vec, &[self.out_features, self.in_features])
.expect("tensor creation failed")
}
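/// Fraction of weights kept by the hard mask, in `[0, 1]`.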
pub fn density(&self) -> f32 {
let mask = self.hard_mask();
let mask_vec = mask.to_vec();
let total = mask_vec.len() as f32;
let active: f32 = mask_vec.iter().sum();
active / total
}
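/// Fraction of weights removed by the hard mask; equals `1.0 - density()`.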
pub fn sparsity(&self) -> f32 {
1.0 - self.density()
}
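/// Number of weights kept by the hard mask.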
pub fn num_active(&self) -> usize {
let mask = self.hard_mask();
let mask_vec = mask.to_vec();
mask_vec.iter().filter(|&&v| v > 0.5).count()
}
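/// Permanently zeroes the weights removed by the hard mask and resets all
/// thresholds to zero, so later forward passes do not prune any further.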
pub fn hard_prune(&mut self) {
let mask = self.hard_mask();
let weight_data = self.weight.data();
let w_vec = weight_data.to_vec();
let m_vec = mask.to_vec();
let pruned: Vec<f32> = w_vec
.iter()
.zip(m_vec.iter())
.map(|(&w, &m)| w * m)
.collect();
let new_weight = Tensor::from_vec(pruned, &[self.out_features, self.in_features])
.expect("tensor creation failed");
self.weight.update_data(new_weight);
let zero_threshold = if self.structured {
zeros(&[self.out_features])
} else {
zeros(&[self.out_features, self.in_features])
};
self.threshold.update_data(zero_threshold);
}
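/// Overwrites every threshold with `value`.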
pub fn reset_threshold(&mut self, value: f32) {
let new_threshold = if self.structured {
constant(&[self.out_features], value)
} else {
constant(&[self.out_features, self.in_features], value)
};
self.threshold.update_data(new_threshold);
}
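/// Returns the weight with the hard mask applied, without modifying the layer.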
pub fn effective_weight(&self) -> Tensor<f32> {
let mask = self.hard_mask();
let weight_data = self.weight.data();
let w_vec = weight_data.to_vec();
let m_vec = mask.to_vec();
let effective: Vec<f32> = w_vec
.iter()
.zip(m_vec.iter())
.map(|(&w, &m)| w * m)
.collect();
Tensor::from_vec(effective, &[self.out_features, self.in_features])
.expect("tensor creation failed")
}
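/// Differentiable relaxation of the hard mask:
/// `sigmoid((|w| - threshold) * TEMPERATURE)`.
///
/// Note: the mask is built from raw tensor data and wrapped in a `Variable`
/// with `requires_grad = false`, so the weights receive gradient through the
/// element-wise product in `forward`, but the thresholds are not updated
/// through this path.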
fn compute_soft_mask(&self, weight_var: &Variable) -> Variable {
let weight_data = weight_var.data();
let threshold_data = self.threshold.data();
let w_vec = weight_data.to_vec();
let t_vec = threshold_data.to_vec();
let mask_vec: Vec<f32> = if self.structured {
w_vec
.iter()
.enumerate()
.map(|(idx, &w)| {
let out_idx = idx / self.in_features;
let t = t_vec[out_idx];
let x = (w.abs() - t) * TEMPERATURE;
1.0 / (1.0 + (-x).exp())
})
.collect()
} else {
w_vec
.iter()
.zip(t_vec.iter())
.map(|(&w, &t)| {
let x = (w.abs() - t) * TEMPERATURE;
1.0 / (1.0 + (-x).exp())
})
.collect()
};
let mask_tensor = Tensor::from_vec(mask_vec, &[self.out_features, self.in_features])
.expect("tensor creation failed");
Variable::new(mask_tensor, false)
}
}
impl Module for SparseLinear {
fn forward(&self, input: &Variable) -> Variable {
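// Flatten any leading batch dimensions so the matmul sees a [total_batch, in_features] input.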
let input_shape = input.shape();
let batch_dims: Vec<usize> = input_shape[..input_shape.len() - 1].to_vec();
let total_batch: usize = batch_dims.iter().product();
let input_2d = if input_shape.len() > 2 {
input.reshape(&[total_batch, self.in_features])
} else {
input.clone()
};
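// Gate the weight with the differentiable soft mask before the usual x * W^T (+ bias).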
let weight_var = self.weight.variable();
let mask = self.compute_soft_mask(&weight_var);
let effective_weight = weight_var.mul_var(&mask);
let weight_t = effective_weight.transpose(0, 1);
let mut output = input_2d.matmul(&weight_t);
if let Some(ref bias) = self.bias {
let bias_var = bias.variable();
output = output.add_var(&bias_var);
}
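// Restore the original leading batch dimensions on the output.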
if input_shape.len() > 2 {
let mut output_shape: Vec<usize> = batch_dims;
output_shape.push(self.out_features);
output.reshape(&output_shape)
} else {
output
}
}
fn parameters(&self) -> Vec<Parameter> {
let mut params = vec![self.weight.clone(), self.threshold.clone()];
if let Some(ref bias) = self.bias {
params.push(bias.clone());
}
params
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
let mut params = HashMap::new();
params.insert("weight".to_string(), self.weight.clone());
params.insert("threshold".to_string(), self.threshold.clone());
if let Some(ref bias) = self.bias {
params.insert("bias".to_string(), bias.clone());
}
params
}
fn name(&self) -> &'static str {
"SparseLinear"
}
}
impl std::fmt::Debug for SparseLinear {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SparseLinear")
.field("in_features", &self.in_features)
.field("out_features", &self.out_features)
.field("bias", &self.bias.is_some())
.field("structured", &self.structured)
.field("density", &self.density())
.finish()
}
}
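/// Group-lasso regulariser.
///
/// The flattened weight is split into consecutive groups of `group_size`
/// elements and the penalty is `lambda * sum_g ||w_g||_2`, which encourages
/// whole groups to be driven towards zero.
///
/// A minimal usage sketch (illustrative only, not a doctest; the surrounding
/// training loop is assumed):
///
/// ```ignore
/// let reg = GroupSparsity::new(1e-3, 16);
/// let penalty = reg.penalty(&layer.weight.variable());
/// // add `penalty` to the task loss before calling backward()
/// ```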
pub struct GroupSparsity {
lambda: f32,
group_size: usize,
}
impl GroupSparsity {
pub fn new(lambda: f32, group_size: usize) -> Self {
assert!(group_size > 0, "group_size must be positive");
Self { lambda, group_size }
}
pub fn lambda(&self) -> f32 {
self.lambda
}
pub fn group_size(&self) -> usize {
self.group_size
}
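/// Computes `lambda * sum_g ||w_g||_2` over consecutive groups of
/// `group_size` elements of the flattened weight (the last group may be
/// shorter). The result is returned as a non-gradient scalar `Variable`.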
pub fn penalty(&self, weight: &Variable) -> Variable {
let weight_data = weight.data();
let w_vec = weight_data.to_vec();
let total = w_vec.len();
let num_groups = total.div_ceil(self.group_size);
let mut group_norm_sum = 0.0f32;
for g in 0..num_groups {
let start = g * self.group_size;
let end = (start + self.group_size).min(total);
let group = &w_vec[start..end];
let l2_norm: f32 = group.iter().map(|&x| x * x).sum::<f32>().sqrt();
group_norm_sum += l2_norm;
}
let penalty_val = self.lambda * group_norm_sum;
let penalty_tensor =
Tensor::from_vec(vec![penalty_val], &[1]).expect("tensor creation failed");
Variable::new(penalty_tensor, false)
}
}
impl std::fmt::Debug for GroupSparsity {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("GroupSparsity")
.field("lambda", &self.lambda)
.field("group_size", &self.group_size)
.finish()
}
}
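/// Snapshot of parameter values for lottery-ticket-style rewinding.
///
/// Parameters are keyed by name, falling back to `param_{index}` for unnamed
/// parameters, so snapshot and rewind should see the parameters in a
/// consistent order.
///
/// A minimal usage sketch (illustrative only, not a doctest; deriving the
/// masks is assumed to happen elsewhere):
///
/// ```ignore
/// let ticket = LotteryTicket::snapshot(&layer.parameters());
/// // ... train, then derive binary masks for the parameters ...
/// ticket.rewind_with_mask(&layer.parameters(), &masks);
/// ```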
pub struct LotteryTicket {
initial_weights: HashMap<String, Tensor<f32>>,
}
impl LotteryTicket {
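/// Copies the current data of every parameter, keyed by its name
/// (or `param_{i}` for unnamed parameters).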
pub fn snapshot(params: &[Parameter]) -> Self {
let mut initial_weights = HashMap::new();
for (i, param) in params.iter().enumerate() {
let key = if param.name().is_empty() {
format!("param_{}", i)
} else {
param.name().to_string()
};
initial_weights.insert(key, param.data());
}
Self { initial_weights }
}
pub fn num_saved(&self) -> usize {
self.initial_weights.len()
}
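/// Restores every parameter whose key has a saved snapshot back to its
/// snapshotted data.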
pub fn rewind(&self, params: &[Parameter]) {
for (i, param) in params.iter().enumerate() {
let key = if param.name().is_empty() {
format!("param_{}", i)
} else {
param.name().to_string()
};
if let Some(initial) = self.initial_weights.get(&key) {
param.update_data(initial.clone());
}
}
}
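/// Restores snapshotted values where the corresponding mask entry is `> 0.5`
/// and zeroes the rest. Panics if `params` and `masks` differ in length.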
pub fn rewind_with_mask(&self, params: &[Parameter], masks: &[Tensor<f32>]) {
assert_eq!(
params.len(),
masks.len(),
"Number of parameters and masks must match"
);
for (i, (param, mask)) in params.iter().zip(masks.iter()).enumerate() {
let key = if param.name().is_empty() {
format!("param_{}", i)
} else {
param.name().to_string()
};
if let Some(initial) = self.initial_weights.get(&key) {
let init_vec = initial.to_vec();
let mask_vec = mask.to_vec();
let rewound: Vec<f32> = init_vec
.iter()
.zip(mask_vec.iter())
.map(|(&w, &m)| if m > 0.5 { w } else { 0.0 })
.collect();
let shape = param.shape();
let new_data = Tensor::from_vec(rewound, &shape).expect("tensor creation failed");
param.update_data(new_data);
}
}
}
}
impl std::fmt::Debug for LotteryTicket {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("LotteryTicket")
.field("num_saved", &self.initial_weights.len())
.finish()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_sparse_linear_creation_structured() {
let layer = SparseLinear::new(10, 5);
assert_eq!(layer.in_features(), 10);
assert_eq!(layer.out_features(), 5);
assert!(layer.is_structured());
assert!(layer.bias.is_some());
}
#[test]
fn test_sparse_linear_creation_unstructured() {
let layer = SparseLinear::unstructured(10, 5);
assert_eq!(layer.in_features(), 10);
assert_eq!(layer.out_features(), 5);
assert!(!layer.is_structured());
assert!(layer.bias.is_some());
}
#[test]
fn test_sparse_linear_no_bias() {
let layer = SparseLinear::with_bias(10, 5, false);
assert!(layer.bias.is_none());
}
#[test]
fn test_sparse_linear_forward_shape() {
let layer = SparseLinear::new(4, 3);
let input = Variable::new(
Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 4]).expect("tensor creation failed"),
false,
);
let output = layer.forward(&input);
assert_eq!(output.shape(), vec![1, 3]);
}
#[test]
fn test_sparse_linear_forward_batch() {
let layer = SparseLinear::new(4, 3);
let input = Variable::new(
Tensor::from_vec(vec![1.0; 12], &[3, 4]).expect("tensor creation failed"),
false,
);
let output = layer.forward(&input);
assert_eq!(output.shape(), vec![3, 3]);
}
#[test]
fn test_sparse_linear_forward_no_bias() {
let layer = SparseLinear::with_bias(4, 3, false);
let input = Variable::new(
Tensor::from_vec(vec![1.0; 8], &[2, 4]).expect("tensor creation failed"),
false,
);
let output = layer.forward(&input);
assert_eq!(output.shape(), vec![2, 3]);
}
#[test]
fn test_sparse_linear_density_initial() {
let layer = SparseLinear::new(100, 50);
let density = layer.density();
assert!(
density > 0.9,
"Initial density should be high, got {}",
density
);
}
#[test]
fn test_sparse_linear_sparsity_initial() {
let layer = SparseLinear::new(100, 50);
let sparsity = layer.sparsity();
assert!(
sparsity < 0.1,
"Initial sparsity should be low, got {}",
sparsity
);
assert!((layer.density() + layer.sparsity() - 1.0).abs() < 1e-6);
}
#[test]
fn test_sparse_linear_num_active() {
let layer = SparseLinear::new(10, 5);
let active = layer.num_active();
let total = 10 * 5;
assert!(active <= total);
assert!(active > 0);
}
#[test]
fn test_sparse_linear_high_threshold_more_sparsity() {
let mut layer = SparseLinear::new(100, 50);
let density_low_thresh = layer.density();
layer.reset_threshold(10.0);
let density_high_thresh = layer.density();
assert!(
density_high_thresh < density_low_thresh,
"Higher threshold should reduce density: low_thresh={}, high_thresh={}",
density_low_thresh,
density_high_thresh
);
}
#[test]
fn test_sparse_linear_low_threshold_dense() {
let mut layer = SparseLinear::new(100, 50);
layer.reset_threshold(0.0);
let density = layer.density();
assert!(
(density - 1.0).abs() < 1e-6,
"Zero threshold should give density=1.0, got {}",
density
);
}
#[test]
fn test_sparse_linear_soft_mask_values_in_range() {
let layer = SparseLinear::new(10, 5);
let weight_var = layer.weight.variable();
let mask = layer.compute_soft_mask(&weight_var);
let mask_vec = mask.data().to_vec();
for &v in &mask_vec {
assert!(
(0.0..=1.0).contains(&v),
"Soft mask value {} not in [0, 1]",
v
);
}
}
#[test]
fn test_sparse_linear_hard_prune() {
let mut layer = SparseLinear::new(10, 5);
layer.reset_threshold(0.5);
let pre_prune_density = layer.density();
layer.hard_prune();
let weight_data = layer.weight.data();
let w_vec = weight_data.to_vec();
let zeros_count = w_vec.iter().filter(|&&v| v == 0.0).count();
let expected_zeros = ((1.0 - pre_prune_density) * (10 * 5) as f32).round() as usize;
assert_eq!(
zeros_count, expected_zeros,
"Hard prune should zero out pruned weights"
);
}
#[test]
fn test_sparse_linear_hard_prune_threshold_reset() {
let mut layer = SparseLinear::new(10, 5);
layer.reset_threshold(0.5);
layer.hard_prune();
let t_vec = layer.threshold.data().to_vec();
assert!(
t_vec.iter().all(|&v| v == 0.0),
"Thresholds should be zero after hard_prune"
);
}
#[test]
fn test_sparse_linear_effective_weight() {
let layer = SparseLinear::new(10, 5);
let ew = layer.effective_weight();
assert_eq!(ew.shape(), &[5, 10]);
}
#[test]
fn test_sparse_linear_effective_weight_matches_hard_prune() {
let mut layer = SparseLinear::new(10, 5);
layer.reset_threshold(0.3);
let effective = layer.effective_weight();
layer.hard_prune();
let pruned = layer.weight.data();
let e_vec = effective.to_vec();
let p_vec = pruned.to_vec();
for (e, p) in e_vec.iter().zip(p_vec.iter()) {
assert!(
(e - p).abs() < 1e-6,
"effective_weight and hard_prune should match"
);
}
}
#[test]
fn test_sparse_linear_parameters_include_threshold() {
let layer = SparseLinear::new(10, 5);
let params = layer.parameters();
assert_eq!(params.len(), 3);
let named = layer.named_parameters();
assert!(named.contains_key("threshold"));
assert!(named.contains_key("weight"));
assert!(named.contains_key("bias"));
}
#[test]
fn test_sparse_linear_parameters_no_bias() {
let layer = SparseLinear::with_bias(10, 5, false);
let params = layer.parameters();
assert_eq!(params.len(), 2);
}
#[test]
fn test_sparse_linear_module_name() {
let layer = SparseLinear::new(10, 5);
assert_eq!(layer.name(), "SparseLinear");
}
#[test]
fn test_sparse_linear_debug() {
let layer = SparseLinear::new(10, 5);
let debug_str = format!("{:?}", layer);
assert!(debug_str.contains("SparseLinear"));
assert!(debug_str.contains("in_features: 10"));
assert!(debug_str.contains("out_features: 5"));
}
#[test]
fn test_sparse_linear_reset_threshold() {
let mut layer = SparseLinear::new(10, 5);
layer.reset_threshold(0.5);
let t_vec = layer.threshold.data().to_vec();
assert!(t_vec.iter().all(|&v| (v - 0.5).abs() < 1e-6));
}
#[test]
fn test_sparse_linear_unstructured_threshold_shape() {
let layer = SparseLinear::unstructured(10, 5);
assert_eq!(layer.threshold.shape(), vec![5, 10]);
}
#[test]
fn test_sparse_linear_structured_threshold_shape() {
let layer = SparseLinear::new(10, 5);
assert_eq!(layer.threshold.shape(), vec![5]);
}
#[test]
fn test_sparse_linear_unstructured_forward() {
let layer = SparseLinear::unstructured(4, 3);
let input = Variable::new(
Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[2, 4])
.expect("tensor creation failed"),
false,
);
let output = layer.forward(&input);
assert_eq!(output.shape(), vec![2, 3]);
}
#[test]
fn test_group_sparsity_creation() {
let reg = GroupSparsity::new(0.001, 10);
assert!((reg.lambda() - 0.001).abs() < 1e-8);
assert_eq!(reg.group_size(), 10);
}
#[test]
fn test_group_sparsity_penalty_non_negative() {
let reg = GroupSparsity::new(0.01, 4);
let weight = Variable::new(
Tensor::from_vec(vec![1.0, -2.0, 3.0, -4.0, 5.0, -6.0, 7.0, -8.0], &[2, 4])
.expect("tensor creation failed"),
true,
);
let penalty = reg.penalty(&weight);
let penalty_val = penalty.data().to_vec()[0];
assert!(
penalty_val >= 0.0,
"Penalty should be non-negative, got {}",
penalty_val
);
}
#[test]
fn test_group_sparsity_zero_weights_zero_penalty() {
let reg = GroupSparsity::new(0.01, 4);
let weight = Variable::new(
Tensor::from_vec(vec![0.0; 8], &[2, 4]).expect("tensor creation failed"),
true,
);
let penalty = reg.penalty(&weight);
let penalty_val = penalty.data().to_vec()[0];
assert!(
(penalty_val).abs() < 1e-6,
"Zero weights should give zero penalty, got {}",
penalty_val
);
}
#[test]
fn test_group_sparsity_scales_with_lambda() {
let reg_small = GroupSparsity::new(0.001, 4);
let reg_large = GroupSparsity::new(0.01, 4);
let weight = Variable::new(
Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 4]).expect("tensor creation failed"),
true,
);
let penalty_small = reg_small.penalty(&weight).data().to_vec()[0];
let penalty_large = reg_large.penalty(&weight).data().to_vec()[0];
assert!(
penalty_large > penalty_small,
"Larger lambda should give larger penalty: small={}, large={}",
penalty_small,
penalty_large
);
let ratio = penalty_large / penalty_small;
assert!(
(ratio - 10.0).abs() < 1e-4,
"Penalty should scale linearly with lambda, ratio={}",
ratio
);
}
#[test]
fn test_group_sparsity_debug() {
let reg = GroupSparsity::new(0.001, 10);
let debug_str = format!("{:?}", reg);
assert!(debug_str.contains("GroupSparsity"));
assert!(debug_str.contains("lambda"));
}
#[test]
#[should_panic(expected = "group_size must be positive")]
fn test_group_sparsity_zero_group_size_panics() {
let _reg = GroupSparsity::new(0.01, 0);
}
#[test]
fn test_lottery_ticket_snapshot() {
let layer = SparseLinear::new(10, 5);
let params = layer.parameters();
let ticket = LotteryTicket::snapshot(&params);
assert_eq!(ticket.num_saved(), params.len());
}
#[test]
fn test_lottery_ticket_rewind() {
let layer = SparseLinear::new(10, 5);
let params = layer.parameters();
let initial_weight = params[0].data().to_vec();
let ticket = LotteryTicket::snapshot(&params);
let new_data = Tensor::from_vec(vec![99.0; 50], &[5, 10]).expect("tensor creation failed");
params[0].update_data(new_data);
let modified_weight = params[0].data().to_vec();
assert_ne!(modified_weight, initial_weight);
ticket.rewind(&params);
let rewound_weight = params[0].data().to_vec();
assert_eq!(rewound_weight, initial_weight);
}
#[test]
fn test_lottery_ticket_rewind_preserves_shapes() {
let layer = SparseLinear::new(10, 5);
let params = layer.parameters();
let initial_shapes: Vec<Vec<usize>> = params.iter().map(|p| p.shape()).collect();
let ticket = LotteryTicket::snapshot(&params);
let new_data = Tensor::from_vec(vec![0.0; 50], &[5, 10]).expect("tensor creation failed");
params[0].update_data(new_data);
ticket.rewind(&params);
let rewound_shapes: Vec<Vec<usize>> = params.iter().map(|p| p.shape()).collect();
assert_eq!(initial_shapes, rewound_shapes);
}
#[test]
fn test_lottery_ticket_rewind_with_mask() {
let data =
Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).expect("tensor creation failed");
let param = Parameter::named("weight", data, true);
let params = vec![param];
let ticket = LotteryTicket::snapshot(&params);
let new_data = Tensor::from_vec(vec![10.0, 20.0, 30.0, 40.0], &[2, 2])
.expect("tensor creation failed");
params[0].update_data(new_data);
let mask =
Tensor::from_vec(vec![1.0, 1.0, 0.0, 0.0], &[2, 2]).expect("tensor creation failed");
ticket.rewind_with_mask(&params, &[mask]);
let result = params[0].data().to_vec();
assert_eq!(
result,
vec![1.0, 2.0, 0.0, 0.0],
"Masked weights should be zero, unmasked should be initial values"
);
}
#[test]
fn test_lottery_ticket_debug() {
let layer = SparseLinear::new(10, 5);
let ticket = LotteryTicket::snapshot(&layer.parameters());
let debug_str = format!("{:?}", ticket);
assert!(debug_str.contains("LotteryTicket"));
assert!(debug_str.contains("num_saved"));
}
#[test]
fn test_integration_sparse_linear_with_group_sparsity() {
let layer = SparseLinear::new(8, 4);
let input = Variable::new(
Tensor::from_vec(vec![1.0; 16], &[2, 8]).expect("tensor creation failed"),
false,
);
let output = layer.forward(&input);
assert_eq!(output.shape(), vec![2, 4]);
let reg = GroupSparsity::new(0.001, 8);
let weight_var = layer.weight.variable();
let penalty = reg.penalty(&weight_var);
let penalty_val = penalty.data().to_vec()[0];
assert!(
penalty_val > 0.0,
"Penalty should be positive for non-zero weights"
);
}
#[test]
fn test_integration_lottery_ticket_with_pruning() {
let mut layer = SparseLinear::new(8, 4);
let ticket = LotteryTicket::snapshot(&layer.parameters());
let new_weight = Tensor::from_vec(vec![0.5; 32], &[4, 8]).expect("tensor creation failed");
layer.weight.update_data(new_weight);
layer.reset_threshold(0.3);
let mask = layer.hard_mask();
let weight_param = vec![layer.weight.clone()];
ticket.rewind_with_mask(&weight_param, &[mask]);
assert_eq!(layer.weight.shape(), vec![4, 8]);
}
#[test]
fn test_num_parameters_sparse_linear() {
let layer = SparseLinear::new(10, 5);
assert_eq!(layer.num_parameters(), 60);
}
}