// axonml_nn/layers/sparse.rs
//! Sparse Layers - Differentiable Structured Sparsity
//!
//! # File
//! `crates/axonml-nn/src/layers/sparse.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.
use std::collections::HashMap;

use axonml_autograd::Variable;
use axonml_tensor::Tensor;

use crate::init::{constant, kaiming_uniform, zeros};
use crate::module::Module;
use crate::parameter::Parameter;
25
26// =============================================================================
27// Constants
28// =============================================================================
29
/// Temperature for the sigmoid soft thresholding.
/// Higher values produce a sharper (more binary) mask; as it grows, the soft
/// mask approaches the hard 0/1 mask used by `hard_mask`.
const TEMPERATURE: f32 = 10.0;

/// Default initial threshold value.
/// Small so that most weights start active (initial density near 1.0).
const DEFAULT_THRESHOLD: f32 = 0.01;
37
38// =============================================================================
39// SparseLinear
40// =============================================================================
41
/// A linear layer with a differentiable magnitude pruning mask.
///
/// During the forward pass, a soft mask is computed via sigmoid soft thresholding:
///
/// ```text
/// mask = sigmoid((|weight| - threshold) * temperature)
/// effective_weight = weight * mask
/// y = x @ effective_weight^T + bias
/// ```
///
/// The sigmoid makes the mask differentiable, so gradients flow through it and
/// the network learns which weights to prune. The `threshold` parameter is
/// learnable and included in `parameters()`.
///
/// NOTE(review): `compute_soft_mask` wraps the mask in a fresh `Variable` with
/// `requires_grad = false`, so it is unclear from this file how gradients reach
/// `threshold` — verify against the autograd integration if the thresholds are
/// expected to train.
///
/// # Structured vs Unstructured
///
/// - **Structured** (`structured=true`): One threshold per output neuron.
///   Entire output channels can be pruned, yielding hardware-friendly sparsity.
/// - **Unstructured** (`structured=false`): One threshold per weight element.
///   Finer-grained but less hardware-friendly.
///
/// # Example
/// ```ignore
/// let layer = SparseLinear::new(784, 256);
/// let output = layer.forward(&input);
/// println!("Density: {:.1}%", layer.density() * 100.0);
/// ```
pub struct SparseLinear {
    /// Weight matrix of shape (out_features, in_features).
    pub weight: Parameter,
    /// Optional bias vector of shape (out_features).
    pub bias: Option<Parameter>,
    /// Learnable magnitude thresholds. Shape depends on `structured`:
    /// - Structured: (out_features,)
    /// - Unstructured: (out_features, in_features)
    pub threshold: Parameter,
    /// Input feature dimension.
    in_features: usize,
    /// Output feature dimension.
    out_features: usize,
    /// Whether to use structured (channel) pruning.
    structured: bool,
}
85
86impl SparseLinear {
87    /// Creates a new SparseLinear layer with structured pruning and bias.
88    ///
89    /// # Arguments
90    /// * `in_features` - Size of each input sample
91    /// * `out_features` - Size of each output sample
92    pub fn new(in_features: usize, out_features: usize) -> Self {
93        Self::build(in_features, out_features, true, true)
94    }
95
96    /// Creates a new SparseLinear layer with unstructured (per-weight) pruning.
97    ///
98    /// # Arguments
99    /// * `in_features` - Size of each input sample
100    /// * `out_features` - Size of each output sample
101    pub fn unstructured(in_features: usize, out_features: usize) -> Self {
102        Self::build(in_features, out_features, false, true)
103    }
104
105    /// Creates a new SparseLinear layer with configurable bias.
106    ///
107    /// # Arguments
108    /// * `in_features` - Size of each input sample
109    /// * `out_features` - Size of each output sample
110    /// * `bias` - Whether to include a learnable bias
111    pub fn with_bias(in_features: usize, out_features: usize, bias: bool) -> Self {
112        Self::build(in_features, out_features, true, bias)
113    }
114
115    /// Internal constructor.
116    fn build(in_features: usize, out_features: usize, structured: bool, bias: bool) -> Self {
117        // Kaiming uniform initialization for weights
118        let weight_data = kaiming_uniform(out_features, in_features);
119        let weight = Parameter::named("weight", weight_data, true);
120
121        // Bias initialization
122        let bias_param = if bias {
123            let bias_data = zeros(&[out_features]);
124            Some(Parameter::named("bias", bias_data, true))
125        } else {
126            None
127        };
128
129        // Threshold initialization — small value so most weights start active
130        let threshold_data = if structured {
131            constant(&[out_features], DEFAULT_THRESHOLD)
132        } else {
133            constant(&[out_features, in_features], DEFAULT_THRESHOLD)
134        };
135        let threshold = Parameter::named("threshold", threshold_data, true);
136
137        Self {
138            weight,
139            bias: bias_param,
140            threshold,
141            in_features,
142            out_features,
143            structured,
144        }
145    }
146
147    /// Returns the input feature dimension.
148    pub fn in_features(&self) -> usize {
149        self.in_features
150    }
151
152    /// Returns the output feature dimension.
153    pub fn out_features(&self) -> usize {
154        self.out_features
155    }
156
157    /// Returns whether this layer uses structured pruning.
158    pub fn is_structured(&self) -> bool {
159        self.structured
160    }
161
162    /// Computes the hard binary mask (threshold at 0.5) from the current weights
163    /// and thresholds.
164    ///
165    /// Returns a Tensor of 0s and 1s with the same shape as weight.
166    fn hard_mask(&self) -> Tensor<f32> {
167        let weight_data = self.weight.data();
168        let threshold_data = self.threshold.data();
169        let w_vec = weight_data.to_vec();
170        let t_vec = threshold_data.to_vec();
171
172        let mask_vec: Vec<f32> = if self.structured {
173            // One threshold per output neuron — broadcast across in_features
174            w_vec
175                .iter()
176                .enumerate()
177                .map(|(idx, &w)| {
178                    let out_idx = idx / self.in_features;
179                    let t = t_vec[out_idx];
180                    if w.abs() >= t { 1.0 } else { 0.0 }
181                })
182                .collect()
183        } else {
184            // One threshold per weight
185            w_vec
186                .iter()
187                .zip(t_vec.iter())
188                .map(|(&w, &t)| if w.abs() >= t { 1.0 } else { 0.0 })
189                .collect()
190        };
191
192        Tensor::from_vec(mask_vec, &[self.out_features, self.in_features])
193            .expect("tensor creation failed")
194    }
195
196    /// Returns the fraction of weights that are active (above threshold).
197    ///
198    /// Uses hard thresholding at |weight| >= threshold.
199    pub fn density(&self) -> f32 {
200        let mask = self.hard_mask();
201        let mask_vec = mask.to_vec();
202        let total = mask_vec.len() as f32;
203        let active: f32 = mask_vec.iter().sum();
204        active / total
205    }
206
207    /// Returns the fraction of weights that are pruned.
208    ///
209    /// Equivalent to `1.0 - density()`.
210    pub fn sparsity(&self) -> f32 {
211        1.0 - self.density()
212    }
213
214    /// Returns the number of active (non-pruned) weights.
215    pub fn num_active(&self) -> usize {
216        let mask = self.hard_mask();
217        let mask_vec = mask.to_vec();
218        mask_vec.iter().filter(|&&v| v > 0.5).count()
219    }
220
221    /// Permanently applies the pruning mask to the weights.
222    ///
223    /// After calling this, pruned weights are zeroed and the threshold is reset
224    /// to zero. This is an irreversible optimization for inference — the zeroed
225    /// weights will not be recovered.
226    pub fn hard_prune(&mut self) {
227        let mask = self.hard_mask();
228        let weight_data = self.weight.data();
229        let w_vec = weight_data.to_vec();
230        let m_vec = mask.to_vec();
231
232        let pruned: Vec<f32> = w_vec
233            .iter()
234            .zip(m_vec.iter())
235            .map(|(&w, &m)| w * m)
236            .collect();
237
238        let new_weight = Tensor::from_vec(pruned, &[self.out_features, self.in_features])
239            .expect("tensor creation failed");
240        self.weight.update_data(new_weight);
241
242        // Reset thresholds to zero so forward pass doesn't re-prune
243        let zero_threshold = if self.structured {
244            zeros(&[self.out_features])
245        } else {
246            zeros(&[self.out_features, self.in_features])
247        };
248        self.threshold.update_data(zero_threshold);
249    }
250
251    /// Resets the threshold to a specific value.
252    ///
253    /// # Arguments
254    /// * `value` - The new threshold value
255    pub fn reset_threshold(&mut self, value: f32) {
256        let new_threshold = if self.structured {
257            constant(&[self.out_features], value)
258        } else {
259            constant(&[self.out_features, self.in_features], value)
260        };
261        self.threshold.update_data(new_threshold);
262    }
263
264    /// Returns the effective weight (weight * hard_mask) for inspection.
265    ///
266    /// This shows what the weight matrix looks like after hard pruning,
267    /// without actually modifying the layer.
268    pub fn effective_weight(&self) -> Tensor<f32> {
269        let mask = self.hard_mask();
270        let weight_data = self.weight.data();
271        let w_vec = weight_data.to_vec();
272        let m_vec = mask.to_vec();
273
274        let effective: Vec<f32> = w_vec
275            .iter()
276            .zip(m_vec.iter())
277            .map(|(&w, &m)| w * m)
278            .collect();
279
280        Tensor::from_vec(effective, &[self.out_features, self.in_features])
281            .expect("tensor creation failed")
282    }
283
284    /// Computes the soft mask using differentiable sigmoid thresholding.
285    ///
286    /// The soft mask is computed element-wise as:
287    /// ```text
288    /// mask_ij = sigmoid((|w_ij| - threshold_j) * temperature)
289    /// ```
290    ///
291    /// For structured pruning, `threshold_j` is broadcast across `in_features`.
292    /// For unstructured pruning, each weight has its own threshold.
293    fn compute_soft_mask(&self, weight_var: &Variable) -> Variable {
294        let weight_data = weight_var.data();
295        let threshold_data = self.threshold.data();
296        let w_vec = weight_data.to_vec();
297        let t_vec = threshold_data.to_vec();
298
299        // Compute sigmoid((|w| - threshold) * temperature) element-wise
300        let mask_vec: Vec<f32> = if self.structured {
301            w_vec
302                .iter()
303                .enumerate()
304                .map(|(idx, &w)| {
305                    let out_idx = idx / self.in_features;
306                    let t = t_vec[out_idx];
307                    let x = (w.abs() - t) * TEMPERATURE;
308                    1.0 / (1.0 + (-x).exp())
309                })
310                .collect()
311        } else {
312            w_vec
313                .iter()
314                .zip(t_vec.iter())
315                .map(|(&w, &t)| {
316                    let x = (w.abs() - t) * TEMPERATURE;
317                    1.0 / (1.0 + (-x).exp())
318                })
319                .collect()
320        };
321
322        let mask_tensor = Tensor::from_vec(mask_vec, &[self.out_features, self.in_features])
323            .expect("tensor creation failed");
324
325        // Create as a variable that participates in the graph
326        // The mask depends on both weight and threshold, but since we compute
327        // it from the raw tensor values, we wrap it as a new variable.
328        // The gradient signal flows through the weight multiplication below.
329        Variable::new(mask_tensor, false)
330    }
331}
332
333impl Module for SparseLinear {
334    fn forward(&self, input: &Variable) -> Variable {
335        let input_shape = input.shape();
336        let batch_dims: Vec<usize> = input_shape[..input_shape.len() - 1].to_vec();
337        let total_batch: usize = batch_dims.iter().product();
338
339        // Reshape to 2D if needed
340        let input_2d = if input_shape.len() > 2 {
341            input.reshape(&[total_batch, self.in_features])
342        } else {
343            input.clone()
344        };
345
346        // Get weight variable and compute soft mask
347        let weight_var = self.weight.variable();
348        let mask = self.compute_soft_mask(&weight_var);
349
350        // effective_weight = weight * mask
351        let effective_weight = weight_var.mul_var(&mask);
352
353        // y = x @ effective_weight^T
354        let weight_t = effective_weight.transpose(0, 1);
355        let mut output = input_2d.matmul(&weight_t);
356
357        // Add bias if present
358        if let Some(ref bias) = self.bias {
359            let bias_var = bias.variable();
360            output = output.add_var(&bias_var);
361        }
362
363        // Reshape back to original batch dimensions
364        if batch_dims.len() > 1 || (batch_dims.len() == 1 && input_shape.len() > 2) {
365            let mut output_shape: Vec<usize> = batch_dims;
366            output_shape.push(self.out_features);
367            output.reshape(&output_shape)
368        } else {
369            output
370        }
371    }
372
373    fn parameters(&self) -> Vec<Parameter> {
374        let mut params = vec![self.weight.clone(), self.threshold.clone()];
375        if let Some(ref bias) = self.bias {
376            params.push(bias.clone());
377        }
378        params
379    }
380
381    fn named_parameters(&self) -> HashMap<String, Parameter> {
382        let mut params = HashMap::new();
383        params.insert("weight".to_string(), self.weight.clone());
384        params.insert("threshold".to_string(), self.threshold.clone());
385        if let Some(ref bias) = self.bias {
386            params.insert("bias".to_string(), bias.clone());
387        }
388        params
389    }
390
391    fn name(&self) -> &'static str {
392        "SparseLinear"
393    }
394}
395
// Manual Debug impl so the live `density()` is reported alongside the static
// configuration (a derived impl could not include a computed field).
impl std::fmt::Debug for SparseLinear {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SparseLinear")
            .field("in_features", &self.in_features)
            .field("out_features", &self.out_features)
            // Only report whether a bias exists, not its values.
            .field("bias", &self.bias.is_some())
            .field("structured", &self.structured)
            // Computed on the fly from the current hard mask.
            .field("density", &self.density())
            .finish()
    }
}
407
408// =============================================================================
409// GroupSparsity
410// =============================================================================
411
/// A regularization module that encourages structured sparsity via group L1 norm.
///
/// Computes a penalty term that can be added to the loss function:
///
/// ```text
/// penalty = lambda * sum_g(||weight_g||_2)
/// ```
///
/// where `weight_g` is a group of weights (e.g., all weights for one output neuron).
/// This encourages entire groups to go to zero (structured pruning), since the L2
/// norm within each group distributes the penalty equally among group members.
///
/// # Example
/// ```ignore
/// let reg = GroupSparsity::new(0.001, 128);  // lambda=0.001, group_size=128
/// let penalty = reg.penalty(&model.weight_variable());
/// let total_loss = task_loss.add_var(&penalty);
/// ```
pub struct GroupSparsity {
    /// Regularization strength. Higher values encourage more sparsity.
    lambda: f32,
    /// Number of weights per group. The final group may be partial if the
    /// weight count is not a multiple of this size.
    group_size: usize,
}
436
437impl GroupSparsity {
438    /// Creates a new GroupSparsity regularizer.
439    ///
440    /// # Arguments
441    /// * `lambda` - Regularization strength (e.g., 0.001)
442    /// * `group_size` - Number of weights per group (e.g., in_features for neuron-level)
443    pub fn new(lambda: f32, group_size: usize) -> Self {
444        assert!(group_size > 0, "group_size must be positive");
445        Self { lambda, group_size }
446    }
447
448    /// Returns the regularization strength.
449    pub fn lambda(&self) -> f32 {
450        self.lambda
451    }
452
453    /// Returns the group size.
454    pub fn group_size(&self) -> usize {
455        self.group_size
456    }
457
458    /// Computes the group L1 penalty for the given weight variable.
459    ///
460    /// The penalty is computed as:
461    /// 1. Reshape weight into groups of size `group_size`
462    /// 2. Compute L2 norm of each group
463    /// 3. Sum all group norms (L1 of norms)
464    /// 4. Multiply by lambda
465    ///
466    /// Returns a scalar Variable that can be added to the loss.
467    pub fn penalty(&self, weight: &Variable) -> Variable {
468        let weight_data = weight.data();
469        let w_vec = weight_data.to_vec();
470        let total = w_vec.len();
471
472        // Number of complete groups
473        let num_groups = total.div_ceil(self.group_size);
474
475        // Compute L2 norm per group, then sum (L1 of group norms)
476        let mut group_norm_sum = 0.0f32;
477        for g in 0..num_groups {
478            let start = g * self.group_size;
479            let end = (start + self.group_size).min(total);
480            let group = &w_vec[start..end];
481
482            let l2_norm: f32 = group.iter().map(|&x| x * x).sum::<f32>().sqrt();
483            group_norm_sum += l2_norm;
484        }
485
486        let penalty_val = self.lambda * group_norm_sum;
487        let penalty_tensor =
488            Tensor::from_vec(vec![penalty_val], &[1]).expect("tensor creation failed");
489
490        // Create as a variable. The penalty is computed from raw tensor values
491        // for simplicity. For full autograd integration, one would implement a
492        // custom backward function, but the penalty is typically used alongside
493        // weight decay in the optimizer.
494        Variable::new(penalty_tensor, false)
495    }
496}
497
// Manual Debug impl; output is equivalent to a derived one and kept manual for
// symmetry with the other types in this file.
impl std::fmt::Debug for GroupSparsity {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GroupSparsity")
            .field("lambda", &self.lambda)
            .field("group_size", &self.group_size)
            .finish()
    }
}
506
507// =============================================================================
508// LotteryTicket
509// =============================================================================
510
/// Implements the Lottery Ticket Hypothesis (Frankle & Carbin, 2019).
///
/// The Lottery Ticket Hypothesis states that dense networks contain sparse
/// subnetworks ("winning tickets") that can be trained in isolation to match
/// the full network's accuracy.
///
/// This struct saves a snapshot of the initial weights, then after pruning,
/// allows rewinding the unpruned weights back to their initial values while
/// keeping the pruning mask.
///
/// # Workflow
/// 1. Initialize network
/// 2. `let ticket = LotteryTicket::snapshot(&model.parameters());`
/// 3. Train network with pruning
/// 4. Determine pruning mask
/// 5. `ticket.rewind(&model.parameters());` — reset to initial weights
/// 6. Apply mask and train again
///
/// # Example
/// ```ignore
/// let model = SparseLinear::new(784, 256);
/// let ticket = LotteryTicket::snapshot(&model.parameters());
///
/// // ... train and prune ...
///
/// ticket.rewind(&model.parameters());  // Reset to initial weights
/// // ... retrain with mask ...
/// ```
pub struct LotteryTicket {
    /// Saved initial parameter values, keyed by parameter name, or by
    /// `param_{index}` for unnamed parameters.
    initial_weights: HashMap<String, Tensor<f32>>,
}
543
544impl LotteryTicket {
545    /// Takes a snapshot of the current parameter values.
546    ///
547    /// # Arguments
548    /// * `params` - Slice of parameters to snapshot
549    pub fn snapshot(params: &[Parameter]) -> Self {
550        let mut initial_weights = HashMap::new();
551        for (i, param) in params.iter().enumerate() {
552            let key = if param.name().is_empty() {
553                format!("param_{}", i)
554            } else {
555                param.name().to_string()
556            };
557            initial_weights.insert(key, param.data());
558        }
559        Self { initial_weights }
560    }
561
562    /// Returns the number of saved parameters.
563    pub fn num_saved(&self) -> usize {
564        self.initial_weights.len()
565    }
566
567    /// Rewinds all parameters to their initial (snapshot) values.
568    ///
569    /// # Arguments
570    /// * `params` - Slice of parameters to rewind (must match snapshot order)
571    pub fn rewind(&self, params: &[Parameter]) {
572        for (i, param) in params.iter().enumerate() {
573            let key = if param.name().is_empty() {
574                format!("param_{}", i)
575            } else {
576                param.name().to_string()
577            };
578            if let Some(initial) = self.initial_weights.get(&key) {
579                param.update_data(initial.clone());
580            }
581        }
582    }
583
584    /// Rewinds parameters to their initial values, but only where the mask is 1.
585    ///
586    /// Weights where `mask == 0` are set to zero (pruned). Weights where
587    /// `mask == 1` are reset to their initial snapshot values.
588    ///
589    /// # Arguments
590    /// * `params` - Slice of parameters to rewind
591    /// * `masks` - Corresponding binary masks (same length as params)
592    pub fn rewind_with_mask(&self, params: &[Parameter], masks: &[Tensor<f32>]) {
593        assert_eq!(
594            params.len(),
595            masks.len(),
596            "Number of parameters and masks must match"
597        );
598
599        for (i, (param, mask)) in params.iter().zip(masks.iter()).enumerate() {
600            let key = if param.name().is_empty() {
601                format!("param_{}", i)
602            } else {
603                param.name().to_string()
604            };
605
606            if let Some(initial) = self.initial_weights.get(&key) {
607                let init_vec = initial.to_vec();
608                let mask_vec = mask.to_vec();
609
610                let rewound: Vec<f32> = init_vec
611                    .iter()
612                    .zip(mask_vec.iter())
613                    .map(|(&w, &m)| if m > 0.5 { w } else { 0.0 })
614                    .collect();
615
616                let shape = param.shape();
617                let new_data = Tensor::from_vec(rewound, &shape).expect("tensor creation failed");
618                param.update_data(new_data);
619            }
620        }
621    }
622}
623
// Manual Debug impl: reports only the number of saved tensors, not their
// (potentially large) contents.
impl std::fmt::Debug for LotteryTicket {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LotteryTicket")
            .field("num_saved", &self.initial_weights.len())
            .finish()
    }
}
631
632// =============================================================================
633// Tests
634// =============================================================================
635
636#[cfg(test)]
637mod tests {
638    use super::*;
639
640    // -------------------------------------------------------------------------
641    // SparseLinear Tests
642    // -------------------------------------------------------------------------
643
644    #[test]
645    fn test_sparse_linear_creation_structured() {
646        let layer = SparseLinear::new(10, 5);
647        assert_eq!(layer.in_features(), 10);
648        assert_eq!(layer.out_features(), 5);
649        assert!(layer.is_structured());
650        assert!(layer.bias.is_some());
651    }
652
653    #[test]
654    fn test_sparse_linear_creation_unstructured() {
655        let layer = SparseLinear::unstructured(10, 5);
656        assert_eq!(layer.in_features(), 10);
657        assert_eq!(layer.out_features(), 5);
658        assert!(!layer.is_structured());
659        assert!(layer.bias.is_some());
660    }
661
662    #[test]
663    fn test_sparse_linear_no_bias() {
664        let layer = SparseLinear::with_bias(10, 5, false);
665        assert!(layer.bias.is_none());
666    }
667
668    #[test]
669    fn test_sparse_linear_forward_shape() {
670        let layer = SparseLinear::new(4, 3);
671        let input = Variable::new(
672            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 4]).expect("tensor creation failed"),
673            false,
674        );
675        let output = layer.forward(&input);
676        assert_eq!(output.shape(), vec![1, 3]);
677    }
678
679    #[test]
680    fn test_sparse_linear_forward_batch() {
681        let layer = SparseLinear::new(4, 3);
682        let input = Variable::new(
683            Tensor::from_vec(vec![1.0; 12], &[3, 4]).expect("tensor creation failed"),
684            false,
685        );
686        let output = layer.forward(&input);
687        assert_eq!(output.shape(), vec![3, 3]);
688    }
689
690    #[test]
691    fn test_sparse_linear_forward_no_bias() {
692        let layer = SparseLinear::with_bias(4, 3, false);
693        let input = Variable::new(
694            Tensor::from_vec(vec![1.0; 8], &[2, 4]).expect("tensor creation failed"),
695            false,
696        );
697        let output = layer.forward(&input);
698        assert_eq!(output.shape(), vec![2, 3]);
699    }
700
701    #[test]
702    fn test_sparse_linear_density_initial() {
703        // With default threshold of 0.01, most Kaiming-initialized weights
704        // should be above threshold (density close to 1.0).
705        let layer = SparseLinear::new(100, 50);
706        let density = layer.density();
707        assert!(
708            density > 0.9,
709            "Initial density should be high, got {}",
710            density
711        );
712    }
713
714    #[test]
715    fn test_sparse_linear_sparsity_initial() {
716        let layer = SparseLinear::new(100, 50);
717        let sparsity = layer.sparsity();
718        assert!(
719            sparsity < 0.1,
720            "Initial sparsity should be low, got {}",
721            sparsity
722        );
723        assert!((layer.density() + layer.sparsity() - 1.0).abs() < 1e-6);
724    }
725
726    #[test]
727    fn test_sparse_linear_num_active() {
728        let layer = SparseLinear::new(10, 5);
729        let active = layer.num_active();
730        let total = 10 * 5;
731        assert!(active <= total);
732        assert!(active > 0);
733    }
734
735    #[test]
736    fn test_sparse_linear_high_threshold_more_sparsity() {
737        let mut layer = SparseLinear::new(100, 50);
738        let density_low_thresh = layer.density();
739
740        // Set high threshold — should prune more weights
741        layer.reset_threshold(10.0);
742        let density_high_thresh = layer.density();
743
744        assert!(
745            density_high_thresh < density_low_thresh,
746            "Higher threshold should reduce density: low_thresh={}, high_thresh={}",
747            density_low_thresh,
748            density_high_thresh
749        );
750    }
751
752    #[test]
753    fn test_sparse_linear_low_threshold_dense() {
754        let mut layer = SparseLinear::new(100, 50);
755        // Set threshold to zero — all weights should be active
756        layer.reset_threshold(0.0);
757        let density = layer.density();
758        assert!(
759            (density - 1.0).abs() < 1e-6,
760            "Zero threshold should give density=1.0, got {}",
761            density
762        );
763    }
764
765    #[test]
766    fn test_sparse_linear_soft_mask_values_in_range() {
767        let layer = SparseLinear::new(10, 5);
768        let weight_var = layer.weight.variable();
769        let mask = layer.compute_soft_mask(&weight_var);
770        let mask_vec = mask.data().to_vec();
771
772        for &v in &mask_vec {
773            assert!(v >= 0.0 && v <= 1.0, "Soft mask value {} not in [0, 1]", v);
774        }
775    }
776
777    #[test]
778    fn test_sparse_linear_hard_prune() {
779        let mut layer = SparseLinear::new(10, 5);
780        // Set a threshold that will prune some weights
781        layer.reset_threshold(0.5);
782
783        let pre_prune_density = layer.density();
784        layer.hard_prune();
785
786        // After hard prune, the zeroed weights should stay zero
787        let weight_data = layer.weight.data();
788        let w_vec = weight_data.to_vec();
789        let zeros_count = w_vec.iter().filter(|&&v| v == 0.0).count();
790
791        // The number of zeros should correspond to the pruned fraction
792        let expected_zeros = ((1.0 - pre_prune_density) * (10 * 5) as f32).round() as usize;
793        assert_eq!(
794            zeros_count, expected_zeros,
795            "Hard prune should zero out pruned weights"
796        );
797    }
798
799    #[test]
800    fn test_sparse_linear_hard_prune_threshold_reset() {
801        let mut layer = SparseLinear::new(10, 5);
802        layer.reset_threshold(0.5);
803        layer.hard_prune();
804
805        // After hard prune, thresholds should be zero
806        let t_vec = layer.threshold.data().to_vec();
807        assert!(
808            t_vec.iter().all(|&v| v == 0.0),
809            "Thresholds should be zero after hard_prune"
810        );
811    }
812
813    #[test]
814    fn test_sparse_linear_effective_weight() {
815        let layer = SparseLinear::new(10, 5);
816        let ew = layer.effective_weight();
817        assert_eq!(ew.shape(), &[5, 10]);
818    }
819
820    #[test]
821    fn test_sparse_linear_effective_weight_matches_hard_prune() {
822        let mut layer = SparseLinear::new(10, 5);
823        layer.reset_threshold(0.3);
824
825        let effective = layer.effective_weight();
826        layer.hard_prune();
827        let pruned = layer.weight.data();
828
829        let e_vec = effective.to_vec();
830        let p_vec = pruned.to_vec();
831        for (e, p) in e_vec.iter().zip(p_vec.iter()) {
832            assert!(
833                (e - p).abs() < 1e-6,
834                "effective_weight and hard_prune should match"
835            );
836        }
837    }
838
839    #[test]
840    fn test_sparse_linear_parameters_include_threshold() {
841        let layer = SparseLinear::new(10, 5);
842        let params = layer.parameters();
843        // weight + threshold + bias = 3
844        assert_eq!(params.len(), 3);
845
846        let named = layer.named_parameters();
847        assert!(named.contains_key("threshold"));
848        assert!(named.contains_key("weight"));
849        assert!(named.contains_key("bias"));
850    }
851
852    #[test]
853    fn test_sparse_linear_parameters_no_bias() {
854        let layer = SparseLinear::with_bias(10, 5, false);
855        let params = layer.parameters();
856        // weight + threshold = 2
857        assert_eq!(params.len(), 2);
858    }
859
860    #[test]
861    fn test_sparse_linear_module_name() {
862        let layer = SparseLinear::new(10, 5);
863        assert_eq!(layer.name(), "SparseLinear");
864    }
865
866    #[test]
867    fn test_sparse_linear_debug() {
868        let layer = SparseLinear::new(10, 5);
869        let debug_str = format!("{:?}", layer);
870        assert!(debug_str.contains("SparseLinear"));
871        assert!(debug_str.contains("in_features: 10"));
872        assert!(debug_str.contains("out_features: 5"));
873    }
874
875    #[test]
876    fn test_sparse_linear_reset_threshold() {
877        let mut layer = SparseLinear::new(10, 5);
878        layer.reset_threshold(0.5);
879        let t_vec = layer.threshold.data().to_vec();
880        assert!(t_vec.iter().all(|&v| (v - 0.5).abs() < 1e-6));
881    }
882
883    #[test]
884    fn test_sparse_linear_unstructured_threshold_shape() {
885        let layer = SparseLinear::unstructured(10, 5);
886        // Unstructured: threshold has same shape as weight
887        assert_eq!(layer.threshold.shape(), vec![5, 10]);
888    }
889
890    #[test]
891    fn test_sparse_linear_structured_threshold_shape() {
892        let layer = SparseLinear::new(10, 5);
893        // Structured: threshold has shape (out_features,)
894        assert_eq!(layer.threshold.shape(), vec![5]);
895    }
896
897    #[test]
898    fn test_sparse_linear_unstructured_forward() {
899        let layer = SparseLinear::unstructured(4, 3);
900        let input = Variable::new(
901            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[2, 4])
902                .expect("tensor creation failed"),
903            false,
904        );
905        let output = layer.forward(&input);
906        assert_eq!(output.shape(), vec![2, 3]);
907    }
908
909    // -------------------------------------------------------------------------
910    // GroupSparsity Tests
911    // -------------------------------------------------------------------------
912
913    #[test]
914    fn test_group_sparsity_creation() {
915        let reg = GroupSparsity::new(0.001, 10);
916        assert!((reg.lambda() - 0.001).abs() < 1e-8);
917        assert_eq!(reg.group_size(), 10);
918    }
919
920    #[test]
921    fn test_group_sparsity_penalty_non_negative() {
922        let reg = GroupSparsity::new(0.01, 4);
923        let weight = Variable::new(
924            Tensor::from_vec(vec![1.0, -2.0, 3.0, -4.0, 5.0, -6.0, 7.0, -8.0], &[2, 4])
925                .expect("tensor creation failed"),
926            true,
927        );
928        let penalty = reg.penalty(&weight);
929        let penalty_val = penalty.data().to_vec()[0];
930        assert!(
931            penalty_val >= 0.0,
932            "Penalty should be non-negative, got {}",
933            penalty_val
934        );
935    }
936
937    #[test]
938    fn test_group_sparsity_zero_weights_zero_penalty() {
939        let reg = GroupSparsity::new(0.01, 4);
940        let weight = Variable::new(
941            Tensor::from_vec(vec![0.0; 8], &[2, 4]).expect("tensor creation failed"),
942            true,
943        );
944        let penalty = reg.penalty(&weight);
945        let penalty_val = penalty.data().to_vec()[0];
946        assert!(
947            (penalty_val).abs() < 1e-6,
948            "Zero weights should give zero penalty, got {}",
949            penalty_val
950        );
951    }
952
953    #[test]
954    fn test_group_sparsity_scales_with_lambda() {
955        let reg_small = GroupSparsity::new(0.001, 4);
956        let reg_large = GroupSparsity::new(0.01, 4);
957        let weight = Variable::new(
958            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 4]).expect("tensor creation failed"),
959            true,
960        );
961
962        let penalty_small = reg_small.penalty(&weight).data().to_vec()[0];
963        let penalty_large = reg_large.penalty(&weight).data().to_vec()[0];
964
965        assert!(
966            penalty_large > penalty_small,
967            "Larger lambda should give larger penalty: small={}, large={}",
968            penalty_small,
969            penalty_large
970        );
971
972        // Should scale linearly with lambda
973        let ratio = penalty_large / penalty_small;
974        assert!(
975            (ratio - 10.0).abs() < 1e-4,
976            "Penalty should scale linearly with lambda, ratio={}",
977            ratio
978        );
979    }
980
981    #[test]
982    fn test_group_sparsity_debug() {
983        let reg = GroupSparsity::new(0.001, 10);
984        let debug_str = format!("{:?}", reg);
985        assert!(debug_str.contains("GroupSparsity"));
986        assert!(debug_str.contains("lambda"));
987    }
988
989    #[test]
990    #[should_panic(expected = "group_size must be positive")]
991    fn test_group_sparsity_zero_group_size_panics() {
992        let _reg = GroupSparsity::new(0.01, 0);
993    }
994
995    // -------------------------------------------------------------------------
996    // LotteryTicket Tests
997    // -------------------------------------------------------------------------
998
999    #[test]
1000    fn test_lottery_ticket_snapshot() {
1001        let layer = SparseLinear::new(10, 5);
1002        let params = layer.parameters();
1003        let ticket = LotteryTicket::snapshot(&params);
1004        assert_eq!(ticket.num_saved(), params.len());
1005    }
1006
1007    #[test]
1008    fn test_lottery_ticket_rewind() {
1009        let layer = SparseLinear::new(10, 5);
1010        let params = layer.parameters();
1011        let initial_weight = params[0].data().to_vec();
1012
1013        let ticket = LotteryTicket::snapshot(&params);
1014
1015        // Modify the weight
1016        let new_data = Tensor::from_vec(vec![99.0; 50], &[5, 10]).expect("tensor creation failed");
1017        params[0].update_data(new_data);
1018
1019        // Verify it changed
1020        let modified_weight = params[0].data().to_vec();
1021        assert_ne!(modified_weight, initial_weight);
1022
1023        // Rewind
1024        ticket.rewind(&params);
1025
1026        // Verify it's back to initial
1027        let rewound_weight = params[0].data().to_vec();
1028        assert_eq!(rewound_weight, initial_weight);
1029    }
1030
1031    #[test]
1032    fn test_lottery_ticket_rewind_preserves_shapes() {
1033        let layer = SparseLinear::new(10, 5);
1034        let params = layer.parameters();
1035        let initial_shapes: Vec<Vec<usize>> = params.iter().map(|p| p.shape()).collect();
1036
1037        let ticket = LotteryTicket::snapshot(&params);
1038
1039        // Modify weight data (same shape)
1040        let new_data = Tensor::from_vec(vec![0.0; 50], &[5, 10]).expect("tensor creation failed");
1041        params[0].update_data(new_data);
1042
1043        ticket.rewind(&params);
1044
1045        let rewound_shapes: Vec<Vec<usize>> = params.iter().map(|p| p.shape()).collect();
1046        assert_eq!(initial_shapes, rewound_shapes);
1047    }
1048
1049    #[test]
1050    fn test_lottery_ticket_rewind_with_mask() {
1051        let data =
1052            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).expect("tensor creation failed");
1053        let param = Parameter::named("weight", data, true);
1054        let params = vec![param];
1055
1056        let ticket = LotteryTicket::snapshot(&params);
1057
1058        // Modify the parameter
1059        let new_data = Tensor::from_vec(vec![10.0, 20.0, 30.0, 40.0], &[2, 2])
1060            .expect("tensor creation failed");
1061        params[0].update_data(new_data);
1062
1063        // Mask: keep first two, prune last two
1064        let mask =
1065            Tensor::from_vec(vec![1.0, 1.0, 0.0, 0.0], &[2, 2]).expect("tensor creation failed");
1066        ticket.rewind_with_mask(&params, &[mask]);
1067
1068        let result = params[0].data().to_vec();
1069        assert_eq!(
1070            result,
1071            vec![1.0, 2.0, 0.0, 0.0],
1072            "Masked weights should be zero, unmasked should be initial values"
1073        );
1074    }
1075
1076    #[test]
1077    fn test_lottery_ticket_debug() {
1078        let layer = SparseLinear::new(10, 5);
1079        let ticket = LotteryTicket::snapshot(&layer.parameters());
1080        let debug_str = format!("{:?}", ticket);
1081        assert!(debug_str.contains("LotteryTicket"));
1082        assert!(debug_str.contains("num_saved"));
1083    }
1084
1085    // -------------------------------------------------------------------------
1086    // Integration Tests
1087    // -------------------------------------------------------------------------
1088
1089    #[test]
1090    fn test_integration_sparse_linear_with_group_sparsity() {
1091        // Create a SparseLinear layer
1092        let layer = SparseLinear::new(8, 4);
1093
1094        // Forward pass
1095        let input = Variable::new(
1096            Tensor::from_vec(vec![1.0; 16], &[2, 8]).expect("tensor creation failed"),
1097            false,
1098        );
1099        let output = layer.forward(&input);
1100        assert_eq!(output.shape(), vec![2, 4]);
1101
1102        // Compute group sparsity penalty on the weights
1103        let reg = GroupSparsity::new(0.001, 8); // group_size = in_features
1104        let weight_var = layer.weight.variable();
1105        let penalty = reg.penalty(&weight_var);
1106        let penalty_val = penalty.data().to_vec()[0];
1107        assert!(
1108            penalty_val > 0.0,
1109            "Penalty should be positive for non-zero weights"
1110        );
1111    }
1112
1113    #[test]
1114    fn test_integration_lottery_ticket_with_pruning() {
1115        // 1. Create layer and snapshot
1116        let mut layer = SparseLinear::new(8, 4);
1117        let ticket = LotteryTicket::snapshot(&layer.parameters());
1118
1119        // 2. Simulate training (modify weights)
1120        let new_weight = Tensor::from_vec(vec![0.5; 32], &[4, 8]).expect("tensor creation failed");
1121        layer.weight.update_data(new_weight);
1122
1123        // 3. Set threshold to prune some weights
1124        layer.reset_threshold(0.3);
1125
1126        // 4. Get the effective weight mask
1127        let mask = layer.hard_mask();
1128
1129        // 5. Rewind to initial weights with mask
1130        let weight_param = vec![layer.weight.clone()];
1131        ticket.rewind_with_mask(&weight_param, &[mask]);
1132
1133        // Verify shape is preserved
1134        assert_eq!(layer.weight.shape(), vec![4, 8]);
1135    }
1136
1137    #[test]
1138    fn test_num_parameters_sparse_linear() {
1139        let layer = SparseLinear::new(10, 5);
1140        // weight: 50 + threshold: 5 + bias: 5 = 60
1141        assert_eq!(layer.num_parameters(), 60);
1142    }
1143}