// crates/axonml-nn/src/layers/sparse.rs

//! Sparse Layers - Differentiable Structured Sparsity
//!
//! # File
//! `crates/axonml-nn/src/layers/sparse.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

17use std::collections::HashMap;
18
19use axonml_autograd::Variable;
20use axonml_tensor::Tensor;
21
22use crate::init::{constant, kaiming_uniform, zeros};
23use crate::module::Module;
24use crate::parameter::Parameter;
25
// =============================================================================
// Constants
// =============================================================================

/// Temperature for the sigmoid soft thresholding.
/// Higher values produce a sharper (more binary) mask.
const TEMPERATURE: f32 = 10.0;

/// Default initial threshold value.
/// Small so that most weights start active.
const DEFAULT_THRESHOLD: f32 = 0.01;

// =============================================================================
// SparseLinear
// =============================================================================

42/// A linear layer with a differentiable magnitude pruning mask.
43///
44/// During the forward pass, a soft mask is computed via sigmoid soft thresholding:
45///
46/// ```text
47/// mask = sigmoid((|weight| - threshold) * temperature)
48/// effective_weight = weight * mask
49/// y = x @ effective_weight^T + bias
50/// ```
51///
52/// The sigmoid makes the mask differentiable, so gradients flow through it and
53/// the network learns which weights to prune. The `threshold` parameter is
54/// learnable and included in `parameters()`.
55///
56/// # Structured vs Unstructured
57///
58/// - **Structured** (`structured=true`): One threshold per output neuron.
59///   Entire output channels can be pruned, yielding hardware-friendly sparsity.
60/// - **Unstructured** (`structured=false`): One threshold per weight element.
61///   Finer-grained but less hardware-friendly.
62///
63/// # Example
64/// ```ignore
65/// let layer = SparseLinear::new(784, 256);
66/// let output = layer.forward(&input);
67/// println!("Density: {:.1}%", layer.density() * 100.0);
68/// ```
69pub struct SparseLinear {
70    /// Weight matrix of shape (out_features, in_features).
71    pub weight: Parameter,
72    /// Optional bias vector of shape (out_features).
73    pub bias: Option<Parameter>,
74    /// Learnable magnitude thresholds. Shape depends on `structured`:
75    /// - Structured: (out_features,)
76    /// - Unstructured: (out_features, in_features)
77    pub threshold: Parameter,
78    /// Input feature dimension.
79    in_features: usize,
80    /// Output feature dimension.
81    out_features: usize,
82    /// Whether to use structured (channel) pruning.
83    structured: bool,
84}
85
86impl SparseLinear {
87    /// Creates a new SparseLinear layer with structured pruning and bias.
88    ///
89    /// # Arguments
90    /// * `in_features` - Size of each input sample
91    /// * `out_features` - Size of each output sample
92    pub fn new(in_features: usize, out_features: usize) -> Self {
93        Self::build(in_features, out_features, true, true)
94    }
95
96    /// Creates a new SparseLinear layer with unstructured (per-weight) pruning.
97    ///
98    /// # Arguments
99    /// * `in_features` - Size of each input sample
100    /// * `out_features` - Size of each output sample
101    pub fn unstructured(in_features: usize, out_features: usize) -> Self {
102        Self::build(in_features, out_features, false, true)
103    }
104
105    /// Creates a new SparseLinear layer with configurable bias.
106    ///
107    /// # Arguments
108    /// * `in_features` - Size of each input sample
109    /// * `out_features` - Size of each output sample
110    /// * `bias` - Whether to include a learnable bias
111    pub fn with_bias(in_features: usize, out_features: usize, bias: bool) -> Self {
112        Self::build(in_features, out_features, true, bias)
113    }
114
115    /// Internal constructor.
116    fn build(in_features: usize, out_features: usize, structured: bool, bias: bool) -> Self {
117        // Kaiming uniform initialization for weights
118        let weight_data = kaiming_uniform(out_features, in_features);
119        let weight = Parameter::named("weight", weight_data, true);
120
121        // Bias initialization
122        let bias_param = if bias {
123            let bias_data = zeros(&[out_features]);
124            Some(Parameter::named("bias", bias_data, true))
125        } else {
126            None
127        };
128
129        // Threshold initialization — small value so most weights start active
130        let threshold_data = if structured {
131            constant(&[out_features], DEFAULT_THRESHOLD)
132        } else {
133            constant(&[out_features, in_features], DEFAULT_THRESHOLD)
134        };
135        let threshold = Parameter::named("threshold", threshold_data, true);
136
137        Self {
138            weight,
139            bias: bias_param,
140            threshold,
141            in_features,
142            out_features,
143            structured,
144        }
145    }
146
147    /// Returns the input feature dimension.
148    pub fn in_features(&self) -> usize {
149        self.in_features
150    }
151
152    /// Returns the output feature dimension.
153    pub fn out_features(&self) -> usize {
154        self.out_features
155    }
156
157    /// Returns whether this layer uses structured pruning.
158    pub fn is_structured(&self) -> bool {
159        self.structured
160    }
161
162    /// Computes the hard binary mask (threshold at 0.5) from the current weights
163    /// and thresholds.
164    ///
165    /// Returns a Tensor of 0s and 1s with the same shape as weight.
166    fn hard_mask(&self) -> Tensor<f32> {
167        let weight_data = self.weight.data();
168        let threshold_data = self.threshold.data();
169        let w_vec = weight_data.to_vec();
170        let t_vec = threshold_data.to_vec();
171
172        let mask_vec: Vec<f32> = if self.structured {
173            // One threshold per output neuron — broadcast across in_features
174            w_vec
175                .iter()
176                .enumerate()
177                .map(|(idx, &w)| {
178                    let out_idx = idx / self.in_features;
179                    let t = t_vec[out_idx];
180                    if w.abs() >= t { 1.0 } else { 0.0 }
181                })
182                .collect()
183        } else {
184            // One threshold per weight
185            w_vec
186                .iter()
187                .zip(t_vec.iter())
188                .map(|(&w, &t)| if w.abs() >= t { 1.0 } else { 0.0 })
189                .collect()
190        };
191
192        Tensor::from_vec(mask_vec, &[self.out_features, self.in_features]).unwrap()
193    }
194
195    /// Returns the fraction of weights that are active (above threshold).
196    ///
197    /// Uses hard thresholding at |weight| >= threshold.
198    pub fn density(&self) -> f32 {
199        let mask = self.hard_mask();
200        let mask_vec = mask.to_vec();
201        let total = mask_vec.len() as f32;
202        let active: f32 = mask_vec.iter().sum();
203        active / total
204    }
205
206    /// Returns the fraction of weights that are pruned.
207    ///
208    /// Equivalent to `1.0 - density()`.
209    pub fn sparsity(&self) -> f32 {
210        1.0 - self.density()
211    }
212
213    /// Returns the number of active (non-pruned) weights.
214    pub fn num_active(&self) -> usize {
215        let mask = self.hard_mask();
216        let mask_vec = mask.to_vec();
217        mask_vec.iter().filter(|&&v| v > 0.5).count()
218    }
219
220    /// Permanently applies the pruning mask to the weights.
221    ///
222    /// After calling this, pruned weights are zeroed and the threshold is reset
223    /// to zero. This is an irreversible optimization for inference — the zeroed
224    /// weights will not be recovered.
225    pub fn hard_prune(&mut self) {
226        let mask = self.hard_mask();
227        let weight_data = self.weight.data();
228        let w_vec = weight_data.to_vec();
229        let m_vec = mask.to_vec();
230
231        let pruned: Vec<f32> = w_vec
232            .iter()
233            .zip(m_vec.iter())
234            .map(|(&w, &m)| w * m)
235            .collect();
236
237        let new_weight = Tensor::from_vec(pruned, &[self.out_features, self.in_features]).unwrap();
238        self.weight.update_data(new_weight);
239
240        // Reset thresholds to zero so forward pass doesn't re-prune
241        let zero_threshold = if self.structured {
242            zeros(&[self.out_features])
243        } else {
244            zeros(&[self.out_features, self.in_features])
245        };
246        self.threshold.update_data(zero_threshold);
247    }
248
249    /// Resets the threshold to a specific value.
250    ///
251    /// # Arguments
252    /// * `value` - The new threshold value
253    pub fn reset_threshold(&mut self, value: f32) {
254        let new_threshold = if self.structured {
255            constant(&[self.out_features], value)
256        } else {
257            constant(&[self.out_features, self.in_features], value)
258        };
259        self.threshold.update_data(new_threshold);
260    }
261
262    /// Returns the effective weight (weight * hard_mask) for inspection.
263    ///
264    /// This shows what the weight matrix looks like after hard pruning,
265    /// without actually modifying the layer.
266    pub fn effective_weight(&self) -> Tensor<f32> {
267        let mask = self.hard_mask();
268        let weight_data = self.weight.data();
269        let w_vec = weight_data.to_vec();
270        let m_vec = mask.to_vec();
271
272        let effective: Vec<f32> = w_vec
273            .iter()
274            .zip(m_vec.iter())
275            .map(|(&w, &m)| w * m)
276            .collect();
277
278        Tensor::from_vec(effective, &[self.out_features, self.in_features]).unwrap()
279    }
280
281    /// Computes the soft mask using differentiable sigmoid thresholding.
282    ///
283    /// The soft mask is computed element-wise as:
284    /// ```text
285    /// mask_ij = sigmoid((|w_ij| - threshold_j) * temperature)
286    /// ```
287    ///
288    /// For structured pruning, `threshold_j` is broadcast across `in_features`.
289    /// For unstructured pruning, each weight has its own threshold.
290    fn compute_soft_mask(&self, weight_var: &Variable) -> Variable {
291        let weight_data = weight_var.data();
292        let threshold_data = self.threshold.data();
293        let w_vec = weight_data.to_vec();
294        let t_vec = threshold_data.to_vec();
295
296        // Compute sigmoid((|w| - threshold) * temperature) element-wise
297        let mask_vec: Vec<f32> = if self.structured {
298            w_vec
299                .iter()
300                .enumerate()
301                .map(|(idx, &w)| {
302                    let out_idx = idx / self.in_features;
303                    let t = t_vec[out_idx];
304                    let x = (w.abs() - t) * TEMPERATURE;
305                    1.0 / (1.0 + (-x).exp())
306                })
307                .collect()
308        } else {
309            w_vec
310                .iter()
311                .zip(t_vec.iter())
312                .map(|(&w, &t)| {
313                    let x = (w.abs() - t) * TEMPERATURE;
314                    1.0 / (1.0 + (-x).exp())
315                })
316                .collect()
317        };
318
319        let mask_tensor =
320            Tensor::from_vec(mask_vec, &[self.out_features, self.in_features]).unwrap();
321
322        // Create as a variable that participates in the graph
323        // The mask depends on both weight and threshold, but since we compute
324        // it from the raw tensor values, we wrap it as a new variable.
325        // The gradient signal flows through the weight multiplication below.
326        Variable::new(mask_tensor, false)
327    }
328}
329
330impl Module for SparseLinear {
331    fn forward(&self, input: &Variable) -> Variable {
332        let input_shape = input.shape();
333        let batch_dims: Vec<usize> = input_shape[..input_shape.len() - 1].to_vec();
334        let total_batch: usize = batch_dims.iter().product();
335
336        // Reshape to 2D if needed
337        let input_2d = if input_shape.len() > 2 {
338            input.reshape(&[total_batch, self.in_features])
339        } else {
340            input.clone()
341        };
342
343        // Get weight variable and compute soft mask
344        let weight_var = self.weight.variable();
345        let mask = self.compute_soft_mask(&weight_var);
346
347        // effective_weight = weight * mask
348        let effective_weight = weight_var.mul_var(&mask);
349
350        // y = x @ effective_weight^T
351        let weight_t = effective_weight.transpose(0, 1);
352        let mut output = input_2d.matmul(&weight_t);
353
354        // Add bias if present
355        if let Some(ref bias) = self.bias {
356            let bias_var = bias.variable();
357            output = output.add_var(&bias_var);
358        }
359
360        // Reshape back to original batch dimensions
361        if batch_dims.len() > 1 || (batch_dims.len() == 1 && input_shape.len() > 2) {
362            let mut output_shape: Vec<usize> = batch_dims;
363            output_shape.push(self.out_features);
364            output.reshape(&output_shape)
365        } else {
366            output
367        }
368    }
369
370    fn parameters(&self) -> Vec<Parameter> {
371        let mut params = vec![self.weight.clone(), self.threshold.clone()];
372        if let Some(ref bias) = self.bias {
373            params.push(bias.clone());
374        }
375        params
376    }
377
378    fn named_parameters(&self) -> HashMap<String, Parameter> {
379        let mut params = HashMap::new();
380        params.insert("weight".to_string(), self.weight.clone());
381        params.insert("threshold".to_string(), self.threshold.clone());
382        if let Some(ref bias) = self.bias {
383            params.insert("bias".to_string(), bias.clone());
384        }
385        params
386    }
387
388    fn name(&self) -> &'static str {
389        "SparseLinear"
390    }
391}
392
393impl std::fmt::Debug for SparseLinear {
394    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
395        f.debug_struct("SparseLinear")
396            .field("in_features", &self.in_features)
397            .field("out_features", &self.out_features)
398            .field("bias", &self.bias.is_some())
399            .field("structured", &self.structured)
400            .field("density", &self.density())
401            .finish()
402    }
403}
404
// =============================================================================
// GroupSparsity
// =============================================================================

/// A regularization module that encourages structured sparsity via group L1 norm.
///
/// Computes a penalty term that can be added to the loss function:
///
/// ```text
/// penalty = lambda * sum_g(||weight_g||_2)
/// ```
///
/// where `weight_g` is a group of weights (e.g., all weights for one output neuron).
/// This encourages entire groups to go to zero (structured pruning), since the L2
/// norm within each group distributes the penalty equally among group members.
///
/// # Example
/// ```ignore
/// let reg = GroupSparsity::new(0.001, 128);  // lambda=0.001, group_size=128
/// let penalty = reg.penalty(&model.weight_variable());
/// let total_loss = task_loss.add_var(&penalty);
/// ```
pub struct GroupSparsity {
    /// Regularization strength. Higher values encourage more sparsity.
    lambda: f32,
    /// Number of weights per group.
    group_size: usize,
}

434impl GroupSparsity {
435    /// Creates a new GroupSparsity regularizer.
436    ///
437    /// # Arguments
438    /// * `lambda` - Regularization strength (e.g., 0.001)
439    /// * `group_size` - Number of weights per group (e.g., in_features for neuron-level)
440    pub fn new(lambda: f32, group_size: usize) -> Self {
441        assert!(group_size > 0, "group_size must be positive");
442        Self { lambda, group_size }
443    }
444
445    /// Returns the regularization strength.
446    pub fn lambda(&self) -> f32 {
447        self.lambda
448    }
449
450    /// Returns the group size.
451    pub fn group_size(&self) -> usize {
452        self.group_size
453    }
454
455    /// Computes the group L1 penalty for the given weight variable.
456    ///
457    /// The penalty is computed as:
458    /// 1. Reshape weight into groups of size `group_size`
459    /// 2. Compute L2 norm of each group
460    /// 3. Sum all group norms (L1 of norms)
461    /// 4. Multiply by lambda
462    ///
463    /// Returns a scalar Variable that can be added to the loss.
464    pub fn penalty(&self, weight: &Variable) -> Variable {
465        let weight_data = weight.data();
466        let w_vec = weight_data.to_vec();
467        let total = w_vec.len();
468
469        // Number of complete groups
470        let num_groups = total.div_ceil(self.group_size);
471
472        // Compute L2 norm per group, then sum (L1 of group norms)
473        let mut group_norm_sum = 0.0f32;
474        for g in 0..num_groups {
475            let start = g * self.group_size;
476            let end = (start + self.group_size).min(total);
477            let group = &w_vec[start..end];
478
479            let l2_norm: f32 = group.iter().map(|&x| x * x).sum::<f32>().sqrt();
480            group_norm_sum += l2_norm;
481        }
482
483        let penalty_val = self.lambda * group_norm_sum;
484        let penalty_tensor = Tensor::from_vec(vec![penalty_val], &[1]).unwrap();
485
486        // Create as a variable. The penalty is computed from raw tensor values
487        // for simplicity. For full autograd integration, one would implement a
488        // custom backward function, but the penalty is typically used alongside
489        // weight decay in the optimizer.
490        Variable::new(penalty_tensor, false)
491    }
492}
493
494impl std::fmt::Debug for GroupSparsity {
495    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
496        f.debug_struct("GroupSparsity")
497            .field("lambda", &self.lambda)
498            .field("group_size", &self.group_size)
499            .finish()
500    }
501}
502
503// =============================================================================
504// LotteryTicket
505// =============================================================================
506
507/// Implements the Lottery Ticket Hypothesis (Frankle & Carlin, 2019).
508///
509/// The Lottery Ticket Hypothesis states that dense networks contain sparse
510/// subnetworks ("winning tickets") that can be trained in isolation to match
511/// the full network's accuracy.
512///
513/// This struct saves a snapshot of the initial weights, then after pruning,
514/// allows rewinding the unpruned weights back to their initial values while
515/// keeping the pruning mask.
516///
517/// # Workflow
518/// 1. Initialize network
519/// 2. `let ticket = LotteryTicket::snapshot(&model.parameters());`
520/// 3. Train network with pruning
521/// 4. Determine pruning mask
522/// 5. `ticket.rewind(&model.parameters());` — reset to initial weights
523/// 6. Apply mask and train again
524///
525/// # Example
526/// ```ignore
527/// let model = SparseLinear::new(784, 256);
528/// let ticket = LotteryTicket::snapshot(&model.parameters());
529///
530/// // ... train and prune ...
531///
532/// ticket.rewind(&model.parameters());  // Reset to initial weights
533/// // ... retrain with mask ...
534/// ```
535pub struct LotteryTicket {
536    /// Saved initial parameter values, keyed by parameter name or index.
537    initial_weights: HashMap<String, Tensor<f32>>,
538}
539
540impl LotteryTicket {
541    /// Takes a snapshot of the current parameter values.
542    ///
543    /// # Arguments
544    /// * `params` - Slice of parameters to snapshot
545    pub fn snapshot(params: &[Parameter]) -> Self {
546        let mut initial_weights = HashMap::new();
547        for (i, param) in params.iter().enumerate() {
548            let key = if param.name().is_empty() {
549                format!("param_{}", i)
550            } else {
551                param.name().to_string()
552            };
553            initial_weights.insert(key, param.data());
554        }
555        Self { initial_weights }
556    }
557
558    /// Returns the number of saved parameters.
559    pub fn num_saved(&self) -> usize {
560        self.initial_weights.len()
561    }
562
563    /// Rewinds all parameters to their initial (snapshot) values.
564    ///
565    /// # Arguments
566    /// * `params` - Slice of parameters to rewind (must match snapshot order)
567    pub fn rewind(&self, params: &[Parameter]) {
568        for (i, param) in params.iter().enumerate() {
569            let key = if param.name().is_empty() {
570                format!("param_{}", i)
571            } else {
572                param.name().to_string()
573            };
574            if let Some(initial) = self.initial_weights.get(&key) {
575                param.update_data(initial.clone());
576            }
577        }
578    }
579
580    /// Rewinds parameters to their initial values, but only where the mask is 1.
581    ///
582    /// Weights where `mask == 0` are set to zero (pruned). Weights where
583    /// `mask == 1` are reset to their initial snapshot values.
584    ///
585    /// # Arguments
586    /// * `params` - Slice of parameters to rewind
587    /// * `masks` - Corresponding binary masks (same length as params)
588    pub fn rewind_with_mask(&self, params: &[Parameter], masks: &[Tensor<f32>]) {
589        assert_eq!(
590            params.len(),
591            masks.len(),
592            "Number of parameters and masks must match"
593        );
594
595        for (i, (param, mask)) in params.iter().zip(masks.iter()).enumerate() {
596            let key = if param.name().is_empty() {
597                format!("param_{}", i)
598            } else {
599                param.name().to_string()
600            };
601
602            if let Some(initial) = self.initial_weights.get(&key) {
603                let init_vec = initial.to_vec();
604                let mask_vec = mask.to_vec();
605
606                let rewound: Vec<f32> = init_vec
607                    .iter()
608                    .zip(mask_vec.iter())
609                    .map(|(&w, &m)| if m > 0.5 { w } else { 0.0 })
610                    .collect();
611
612                let shape = param.shape();
613                let new_data = Tensor::from_vec(rewound, &shape).unwrap();
614                param.update_data(new_data);
615            }
616        }
617    }
618}
619
620impl std::fmt::Debug for LotteryTicket {
621    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
622        f.debug_struct("LotteryTicket")
623            .field("num_saved", &self.initial_weights.len())
624            .finish()
625    }
626}
627
628// =============================================================================
629// Tests
630// =============================================================================
631
632#[cfg(test)]
633mod tests {
634    use super::*;
635
636    // -------------------------------------------------------------------------
637    // SparseLinear Tests
638    // -------------------------------------------------------------------------
639
640    #[test]
641    fn test_sparse_linear_creation_structured() {
642        let layer = SparseLinear::new(10, 5);
643        assert_eq!(layer.in_features(), 10);
644        assert_eq!(layer.out_features(), 5);
645        assert!(layer.is_structured());
646        assert!(layer.bias.is_some());
647    }
648
649    #[test]
650    fn test_sparse_linear_creation_unstructured() {
651        let layer = SparseLinear::unstructured(10, 5);
652        assert_eq!(layer.in_features(), 10);
653        assert_eq!(layer.out_features(), 5);
654        assert!(!layer.is_structured());
655        assert!(layer.bias.is_some());
656    }
657
658    #[test]
659    fn test_sparse_linear_no_bias() {
660        let layer = SparseLinear::with_bias(10, 5, false);
661        assert!(layer.bias.is_none());
662    }
663
664    #[test]
665    fn test_sparse_linear_forward_shape() {
666        let layer = SparseLinear::new(4, 3);
667        let input = Variable::new(
668            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 4]).unwrap(),
669            false,
670        );
671        let output = layer.forward(&input);
672        assert_eq!(output.shape(), vec![1, 3]);
673    }
674
675    #[test]
676    fn test_sparse_linear_forward_batch() {
677        let layer = SparseLinear::new(4, 3);
678        let input = Variable::new(Tensor::from_vec(vec![1.0; 12], &[3, 4]).unwrap(), false);
679        let output = layer.forward(&input);
680        assert_eq!(output.shape(), vec![3, 3]);
681    }
682
683    #[test]
684    fn test_sparse_linear_forward_no_bias() {
685        let layer = SparseLinear::with_bias(4, 3, false);
686        let input = Variable::new(Tensor::from_vec(vec![1.0; 8], &[2, 4]).unwrap(), false);
687        let output = layer.forward(&input);
688        assert_eq!(output.shape(), vec![2, 3]);
689    }
690
691    #[test]
692    fn test_sparse_linear_density_initial() {
693        // With default threshold of 0.01, most Kaiming-initialized weights
694        // should be above threshold (density close to 1.0).
695        let layer = SparseLinear::new(100, 50);
696        let density = layer.density();
697        assert!(
698            density > 0.9,
699            "Initial density should be high, got {}",
700            density
701        );
702    }
703
704    #[test]
705    fn test_sparse_linear_sparsity_initial() {
706        let layer = SparseLinear::new(100, 50);
707        let sparsity = layer.sparsity();
708        assert!(
709            sparsity < 0.1,
710            "Initial sparsity should be low, got {}",
711            sparsity
712        );
713        assert!((layer.density() + layer.sparsity() - 1.0).abs() < 1e-6);
714    }
715
716    #[test]
717    fn test_sparse_linear_num_active() {
718        let layer = SparseLinear::new(10, 5);
719        let active = layer.num_active();
720        let total = 10 * 5;
721        assert!(active <= total);
722        assert!(active > 0);
723    }
724
725    #[test]
726    fn test_sparse_linear_high_threshold_more_sparsity() {
727        let mut layer = SparseLinear::new(100, 50);
728        let density_low_thresh = layer.density();
729
730        // Set high threshold — should prune more weights
731        layer.reset_threshold(10.0);
732        let density_high_thresh = layer.density();
733
734        assert!(
735            density_high_thresh < density_low_thresh,
736            "Higher threshold should reduce density: low_thresh={}, high_thresh={}",
737            density_low_thresh,
738            density_high_thresh
739        );
740    }
741
742    #[test]
743    fn test_sparse_linear_low_threshold_dense() {
744        let mut layer = SparseLinear::new(100, 50);
745        // Set threshold to zero — all weights should be active
746        layer.reset_threshold(0.0);
747        let density = layer.density();
748        assert!(
749            (density - 1.0).abs() < 1e-6,
750            "Zero threshold should give density=1.0, got {}",
751            density
752        );
753    }
754
755    #[test]
756    fn test_sparse_linear_soft_mask_values_in_range() {
757        let layer = SparseLinear::new(10, 5);
758        let weight_var = layer.weight.variable();
759        let mask = layer.compute_soft_mask(&weight_var);
760        let mask_vec = mask.data().to_vec();
761
762        for &v in &mask_vec {
763            assert!(v >= 0.0 && v <= 1.0, "Soft mask value {} not in [0, 1]", v);
764        }
765    }
766
767    #[test]
768    fn test_sparse_linear_hard_prune() {
769        let mut layer = SparseLinear::new(10, 5);
770        // Set a threshold that will prune some weights
771        layer.reset_threshold(0.5);
772
773        let pre_prune_density = layer.density();
774        layer.hard_prune();
775
776        // After hard prune, the zeroed weights should stay zero
777        let weight_data = layer.weight.data();
778        let w_vec = weight_data.to_vec();
779        let zeros_count = w_vec.iter().filter(|&&v| v == 0.0).count();
780
781        // The number of zeros should correspond to the pruned fraction
782        let expected_zeros = ((1.0 - pre_prune_density) * (10 * 5) as f32).round() as usize;
783        assert_eq!(
784            zeros_count, expected_zeros,
785            "Hard prune should zero out pruned weights"
786        );
787    }
788
789    #[test]
790    fn test_sparse_linear_hard_prune_threshold_reset() {
791        let mut layer = SparseLinear::new(10, 5);
792        layer.reset_threshold(0.5);
793        layer.hard_prune();
794
795        // After hard prune, thresholds should be zero
796        let t_vec = layer.threshold.data().to_vec();
797        assert!(
798            t_vec.iter().all(|&v| v == 0.0),
799            "Thresholds should be zero after hard_prune"
800        );
801    }
802
803    #[test]
804    fn test_sparse_linear_effective_weight() {
805        let layer = SparseLinear::new(10, 5);
806        let ew = layer.effective_weight();
807        assert_eq!(ew.shape(), &[5, 10]);
808    }
809
810    #[test]
811    fn test_sparse_linear_effective_weight_matches_hard_prune() {
812        let mut layer = SparseLinear::new(10, 5);
813        layer.reset_threshold(0.3);
814
815        let effective = layer.effective_weight();
816        layer.hard_prune();
817        let pruned = layer.weight.data();
818
819        let e_vec = effective.to_vec();
820        let p_vec = pruned.to_vec();
821        for (e, p) in e_vec.iter().zip(p_vec.iter()) {
822            assert!(
823                (e - p).abs() < 1e-6,
824                "effective_weight and hard_prune should match"
825            );
826        }
827    }
828
829    #[test]
830    fn test_sparse_linear_parameters_include_threshold() {
831        let layer = SparseLinear::new(10, 5);
832        let params = layer.parameters();
833        // weight + threshold + bias = 3
834        assert_eq!(params.len(), 3);
835
836        let named = layer.named_parameters();
837        assert!(named.contains_key("threshold"));
838        assert!(named.contains_key("weight"));
839        assert!(named.contains_key("bias"));
840    }
841
842    #[test]
843    fn test_sparse_linear_parameters_no_bias() {
844        let layer = SparseLinear::with_bias(10, 5, false);
845        let params = layer.parameters();
846        // weight + threshold = 2
847        assert_eq!(params.len(), 2);
848    }
849
850    #[test]
851    fn test_sparse_linear_module_name() {
852        let layer = SparseLinear::new(10, 5);
853        assert_eq!(layer.name(), "SparseLinear");
854    }
855
856    #[test]
857    fn test_sparse_linear_debug() {
858        let layer = SparseLinear::new(10, 5);
859        let debug_str = format!("{:?}", layer);
860        assert!(debug_str.contains("SparseLinear"));
861        assert!(debug_str.contains("in_features: 10"));
862        assert!(debug_str.contains("out_features: 5"));
863    }
864
865    #[test]
866    fn test_sparse_linear_reset_threshold() {
867        let mut layer = SparseLinear::new(10, 5);
868        layer.reset_threshold(0.5);
869        let t_vec = layer.threshold.data().to_vec();
870        assert!(t_vec.iter().all(|&v| (v - 0.5).abs() < 1e-6));
871    }
872
873    #[test]
874    fn test_sparse_linear_unstructured_threshold_shape() {
875        let layer = SparseLinear::unstructured(10, 5);
876        // Unstructured: threshold has same shape as weight
877        assert_eq!(layer.threshold.shape(), vec![5, 10]);
878    }
879
880    #[test]
881    fn test_sparse_linear_structured_threshold_shape() {
882        let layer = SparseLinear::new(10, 5);
883        // Structured: threshold has shape (out_features,)
884        assert_eq!(layer.threshold.shape(), vec![5]);
885    }
886
887    #[test]
888    fn test_sparse_linear_unstructured_forward() {
889        let layer = SparseLinear::unstructured(4, 3);
890        let input = Variable::new(
891            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[2, 4]).unwrap(),
892            false,
893        );
894        let output = layer.forward(&input);
895        assert_eq!(output.shape(), vec![2, 3]);
896    }
897
    // -------------------------------------------------------------------------
    // GroupSparsity Tests
    // -------------------------------------------------------------------------
901
902    #[test]
903    fn test_group_sparsity_creation() {
904        let reg = GroupSparsity::new(0.001, 10);
905        assert!((reg.lambda() - 0.001).abs() < 1e-8);
906        assert_eq!(reg.group_size(), 10);
907    }
908
909    #[test]
910    fn test_group_sparsity_penalty_non_negative() {
911        let reg = GroupSparsity::new(0.01, 4);
912        let weight = Variable::new(
913            Tensor::from_vec(vec![1.0, -2.0, 3.0, -4.0, 5.0, -6.0, 7.0, -8.0], &[2, 4]).unwrap(),
914            true,
915        );
916        let penalty = reg.penalty(&weight);
917        let penalty_val = penalty.data().to_vec()[0];
918        assert!(
919            penalty_val >= 0.0,
920            "Penalty should be non-negative, got {}",
921            penalty_val
922        );
923    }
924
925    #[test]
926    fn test_group_sparsity_zero_weights_zero_penalty() {
927        let reg = GroupSparsity::new(0.01, 4);
928        let weight = Variable::new(Tensor::from_vec(vec![0.0; 8], &[2, 4]).unwrap(), true);
929        let penalty = reg.penalty(&weight);
930        let penalty_val = penalty.data().to_vec()[0];
931        assert!(
932            (penalty_val).abs() < 1e-6,
933            "Zero weights should give zero penalty, got {}",
934            penalty_val
935        );
936    }
937
938    #[test]
939    fn test_group_sparsity_scales_with_lambda() {
940        let reg_small = GroupSparsity::new(0.001, 4);
941        let reg_large = GroupSparsity::new(0.01, 4);
942        let weight = Variable::new(
943            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 4]).unwrap(),
944            true,
945        );
946
947        let penalty_small = reg_small.penalty(&weight).data().to_vec()[0];
948        let penalty_large = reg_large.penalty(&weight).data().to_vec()[0];
949
950        assert!(
951            penalty_large > penalty_small,
952            "Larger lambda should give larger penalty: small={}, large={}",
953            penalty_small,
954            penalty_large
955        );
956
957        // Should scale linearly with lambda
958        let ratio = penalty_large / penalty_small;
959        assert!(
960            (ratio - 10.0).abs() < 1e-4,
961            "Penalty should scale linearly with lambda, ratio={}",
962            ratio
963        );
964    }
965
966    #[test]
967    fn test_group_sparsity_debug() {
968        let reg = GroupSparsity::new(0.001, 10);
969        let debug_str = format!("{:?}", reg);
970        assert!(debug_str.contains("GroupSparsity"));
971        assert!(debug_str.contains("lambda"));
972    }
973
974    #[test]
975    #[should_panic(expected = "group_size must be positive")]
976    fn test_group_sparsity_zero_group_size_panics() {
977        let _reg = GroupSparsity::new(0.01, 0);
978    }
979
    // -------------------------------------------------------------------------
    // LotteryTicket Tests
    // -------------------------------------------------------------------------
983
984    #[test]
985    fn test_lottery_ticket_snapshot() {
986        let layer = SparseLinear::new(10, 5);
987        let params = layer.parameters();
988        let ticket = LotteryTicket::snapshot(&params);
989        assert_eq!(ticket.num_saved(), params.len());
990    }
991
992    #[test]
993    fn test_lottery_ticket_rewind() {
994        let layer = SparseLinear::new(10, 5);
995        let params = layer.parameters();
996        let initial_weight = params[0].data().to_vec();
997
998        let ticket = LotteryTicket::snapshot(&params);
999
1000        // Modify the weight
1001        let new_data = Tensor::from_vec(vec![99.0; 50], &[5, 10]).unwrap();
1002        params[0].update_data(new_data);
1003
1004        // Verify it changed
1005        let modified_weight = params[0].data().to_vec();
1006        assert_ne!(modified_weight, initial_weight);
1007
1008        // Rewind
1009        ticket.rewind(&params);
1010
1011        // Verify it's back to initial
1012        let rewound_weight = params[0].data().to_vec();
1013        assert_eq!(rewound_weight, initial_weight);
1014    }
1015
1016    #[test]
1017    fn test_lottery_ticket_rewind_preserves_shapes() {
1018        let layer = SparseLinear::new(10, 5);
1019        let params = layer.parameters();
1020        let initial_shapes: Vec<Vec<usize>> = params.iter().map(|p| p.shape()).collect();
1021
1022        let ticket = LotteryTicket::snapshot(&params);
1023
1024        // Modify weight data (same shape)
1025        let new_data = Tensor::from_vec(vec![0.0; 50], &[5, 10]).unwrap();
1026        params[0].update_data(new_data);
1027
1028        ticket.rewind(&params);
1029
1030        let rewound_shapes: Vec<Vec<usize>> = params.iter().map(|p| p.shape()).collect();
1031        assert_eq!(initial_shapes, rewound_shapes);
1032    }
1033
1034    #[test]
1035    fn test_lottery_ticket_rewind_with_mask() {
1036        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
1037        let param = Parameter::named("weight", data, true);
1038        let params = vec![param];
1039
1040        let ticket = LotteryTicket::snapshot(&params);
1041
1042        // Modify the parameter
1043        let new_data = Tensor::from_vec(vec![10.0, 20.0, 30.0, 40.0], &[2, 2]).unwrap();
1044        params[0].update_data(new_data);
1045
1046        // Mask: keep first two, prune last two
1047        let mask = Tensor::from_vec(vec![1.0, 1.0, 0.0, 0.0], &[2, 2]).unwrap();
1048        ticket.rewind_with_mask(&params, &[mask]);
1049
1050        let result = params[0].data().to_vec();
1051        assert_eq!(
1052            result,
1053            vec![1.0, 2.0, 0.0, 0.0],
1054            "Masked weights should be zero, unmasked should be initial values"
1055        );
1056    }
1057
1058    #[test]
1059    fn test_lottery_ticket_debug() {
1060        let layer = SparseLinear::new(10, 5);
1061        let ticket = LotteryTicket::snapshot(&layer.parameters());
1062        let debug_str = format!("{:?}", ticket);
1063        assert!(debug_str.contains("LotteryTicket"));
1064        assert!(debug_str.contains("num_saved"));
1065    }
1066
    // -------------------------------------------------------------------------
    // Integration Tests
    // -------------------------------------------------------------------------
1070
1071    #[test]
1072    fn test_integration_sparse_linear_with_group_sparsity() {
1073        // Create a SparseLinear layer
1074        let layer = SparseLinear::new(8, 4);
1075
1076        // Forward pass
1077        let input = Variable::new(Tensor::from_vec(vec![1.0; 16], &[2, 8]).unwrap(), false);
1078        let output = layer.forward(&input);
1079        assert_eq!(output.shape(), vec![2, 4]);
1080
1081        // Compute group sparsity penalty on the weights
1082        let reg = GroupSparsity::new(0.001, 8); // group_size = in_features
1083        let weight_var = layer.weight.variable();
1084        let penalty = reg.penalty(&weight_var);
1085        let penalty_val = penalty.data().to_vec()[0];
1086        assert!(
1087            penalty_val > 0.0,
1088            "Penalty should be positive for non-zero weights"
1089        );
1090    }
1091
1092    #[test]
1093    fn test_integration_lottery_ticket_with_pruning() {
1094        // 1. Create layer and snapshot
1095        let mut layer = SparseLinear::new(8, 4);
1096        let ticket = LotteryTicket::snapshot(&layer.parameters());
1097
1098        // 2. Simulate training (modify weights)
1099        let new_weight = Tensor::from_vec(vec![0.5; 32], &[4, 8]).unwrap();
1100        layer.weight.update_data(new_weight);
1101
1102        // 3. Set threshold to prune some weights
1103        layer.reset_threshold(0.3);
1104
1105        // 4. Get the effective weight mask
1106        let mask = layer.hard_mask();
1107
1108        // 5. Rewind to initial weights with mask
1109        let weight_param = vec![layer.weight.clone()];
1110        ticket.rewind_with_mask(&weight_param, &[mask]);
1111
1112        // Verify shape is preserved
1113        assert_eq!(layer.weight.shape(), vec![4, 8]);
1114    }
1115
1116    #[test]
1117    fn test_num_parameters_sparse_linear() {
1118        let layer = SparseLinear::new(10, 5);
1119        // weight: 50 + threshold: 5 + bias: 5 = 60
1120        assert_eq!(layer.num_parameters(), 60);
1121    }
1122}