//! Differentiable structured sparsity — a novel AxonML feature.
//!
//! 1147 lines. `SparseLinear` (Linear layer with a learnable soft-threshold
//! pruning mask — the mask is differentiable via straight-through estimation,
//! enabling end-to-end learning of which weights to prune). `GroupSparsity`
//! (L2,1 regularization that drives entire rows/columns to zero for structured
//! pruning). `LotteryTicket` (implements the lottery ticket hypothesis:
//! train → prune → rewind to init → retrain the sparse subnetwork, with
//! iterative magnitude pruning and score tracking).
//!
//! # File
//! `crates/axonml-nn/src/layers/sparse.rs`
//!
//! # Author
//! Andrew Jewell Sr. — AutomataNexus LLC
//! ORCID: 0009-0005-2158-7060
//!
//! # Updated
//! April 14, 2026 11:15 PM EST
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

26use std::collections::HashMap;
27
28use axonml_autograd::Variable;
29use axonml_tensor::Tensor;
30
31use crate::init::{constant, kaiming_uniform, zeros};
32use crate::module::Module;
33use crate::parameter::Parameter;
34
35// =============================================================================
36// Constants
37// =============================================================================
38
/// Temperature for the sigmoid soft thresholding.
/// Higher values produce a sharper (more binary) mask.
/// Applied as `sigmoid((|w| - threshold) * TEMPERATURE)` in `SparseLinear`.
const TEMPERATURE: f32 = 10.0;

/// Default initial threshold value.
/// Small so that most weights start active.
/// Used for both structured and unstructured threshold tensors in `build`.
const DEFAULT_THRESHOLD: f32 = 0.01;
46
47// =============================================================================
48// SparseLinear
49// =============================================================================
50
51/// A linear layer with a differentiable magnitude pruning mask.
52///
53/// During the forward pass, a soft mask is computed via sigmoid soft thresholding:
54///
55/// ```text
56/// mask = sigmoid((|weight| - threshold) * temperature)
57/// effective_weight = weight * mask
58/// y = x @ effective_weight^T + bias
59/// ```
60///
61/// The sigmoid makes the mask differentiable, so gradients flow through it and
62/// the network learns which weights to prune. The `threshold` parameter is
63/// learnable and included in `parameters()`.
64///
65/// # Structured vs Unstructured
66///
67/// - **Structured** (`structured=true`): One threshold per output neuron.
68///   Entire output channels can be pruned, yielding hardware-friendly sparsity.
69/// - **Unstructured** (`structured=false`): One threshold per weight element.
70///   Finer-grained but less hardware-friendly.
71///
72/// # Example
73/// ```ignore
74/// let layer = SparseLinear::new(784, 256);
75/// let output = layer.forward(&input);
76/// println!("Density: {:.1}%", layer.density() * 100.0);
77/// ```
pub struct SparseLinear {
    /// Weight matrix of shape (out_features, in_features).
    pub weight: Parameter,
    /// Optional bias vector of shape (out_features).
    pub bias: Option<Parameter>,
    /// Learnable magnitude thresholds. Shape depends on `structured`:
    /// - Structured: (out_features,)
    /// - Unstructured: (out_features, in_features)
    ///
    /// Included in `parameters()` so optimizers can update it.
    pub threshold: Parameter,
    /// Input feature dimension.
    in_features: usize,
    /// Output feature dimension.
    out_features: usize,
    /// Whether to use structured (channel) pruning.
    /// Fixed at construction time; see `new` vs `unstructured`.
    structured: bool,
}
94
95impl SparseLinear {
96    /// Creates a new SparseLinear layer with structured pruning and bias.
97    ///
98    /// # Arguments
99    /// * `in_features` - Size of each input sample
100    /// * `out_features` - Size of each output sample
101    pub fn new(in_features: usize, out_features: usize) -> Self {
102        Self::build(in_features, out_features, true, true)
103    }
104
105    /// Creates a new SparseLinear layer with unstructured (per-weight) pruning.
106    ///
107    /// # Arguments
108    /// * `in_features` - Size of each input sample
109    /// * `out_features` - Size of each output sample
110    pub fn unstructured(in_features: usize, out_features: usize) -> Self {
111        Self::build(in_features, out_features, false, true)
112    }
113
114    /// Creates a new SparseLinear layer with configurable bias.
115    ///
116    /// # Arguments
117    /// * `in_features` - Size of each input sample
118    /// * `out_features` - Size of each output sample
119    /// * `bias` - Whether to include a learnable bias
120    pub fn with_bias(in_features: usize, out_features: usize, bias: bool) -> Self {
121        Self::build(in_features, out_features, true, bias)
122    }
123
124    /// Internal constructor.
125    fn build(in_features: usize, out_features: usize, structured: bool, bias: bool) -> Self {
126        // Kaiming uniform initialization for weights
127        let weight_data = kaiming_uniform(out_features, in_features);
128        let weight = Parameter::named("weight", weight_data, true);
129
130        // Bias initialization
131        let bias_param = if bias {
132            let bias_data = zeros(&[out_features]);
133            Some(Parameter::named("bias", bias_data, true))
134        } else {
135            None
136        };
137
138        // Threshold initialization — small value so most weights start active
139        let threshold_data = if structured {
140            constant(&[out_features], DEFAULT_THRESHOLD)
141        } else {
142            constant(&[out_features, in_features], DEFAULT_THRESHOLD)
143        };
144        let threshold = Parameter::named("threshold", threshold_data, true);
145
146        Self {
147            weight,
148            bias: bias_param,
149            threshold,
150            in_features,
151            out_features,
152            structured,
153        }
154    }
155
156    /// Returns the input feature dimension.
157    pub fn in_features(&self) -> usize {
158        self.in_features
159    }
160
161    /// Returns the output feature dimension.
162    pub fn out_features(&self) -> usize {
163        self.out_features
164    }
165
166    /// Returns whether this layer uses structured pruning.
167    pub fn is_structured(&self) -> bool {
168        self.structured
169    }
170
171    /// Computes the hard binary mask (threshold at 0.5) from the current weights
172    /// and thresholds.
173    ///
174    /// Returns a Tensor of 0s and 1s with the same shape as weight.
175    fn hard_mask(&self) -> Tensor<f32> {
176        let weight_data = self.weight.data();
177        let threshold_data = self.threshold.data();
178        let w_vec = weight_data.to_vec();
179        let t_vec = threshold_data.to_vec();
180
181        let mask_vec: Vec<f32> = if self.structured {
182            // One threshold per output neuron — broadcast across in_features
183            w_vec
184                .iter()
185                .enumerate()
186                .map(|(idx, &w)| {
187                    let out_idx = idx / self.in_features;
188                    let t = t_vec[out_idx];
189                    if w.abs() >= t { 1.0 } else { 0.0 }
190                })
191                .collect()
192        } else {
193            // One threshold per weight
194            w_vec
195                .iter()
196                .zip(t_vec.iter())
197                .map(|(&w, &t)| if w.abs() >= t { 1.0 } else { 0.0 })
198                .collect()
199        };
200
201        Tensor::from_vec(mask_vec, &[self.out_features, self.in_features])
202            .expect("tensor creation failed")
203    }
204
205    /// Returns the fraction of weights that are active (above threshold).
206    ///
207    /// Uses hard thresholding at |weight| >= threshold.
208    pub fn density(&self) -> f32 {
209        let mask = self.hard_mask();
210        let mask_vec = mask.to_vec();
211        let total = mask_vec.len() as f32;
212        let active: f32 = mask_vec.iter().sum();
213        active / total
214    }
215
216    /// Returns the fraction of weights that are pruned.
217    ///
218    /// Equivalent to `1.0 - density()`.
219    pub fn sparsity(&self) -> f32 {
220        1.0 - self.density()
221    }
222
223    /// Returns the number of active (non-pruned) weights.
224    pub fn num_active(&self) -> usize {
225        let mask = self.hard_mask();
226        let mask_vec = mask.to_vec();
227        mask_vec.iter().filter(|&&v| v > 0.5).count()
228    }
229
230    /// Permanently applies the pruning mask to the weights.
231    ///
232    /// After calling this, pruned weights are zeroed and the threshold is reset
233    /// to zero. This is an irreversible optimization for inference — the zeroed
234    /// weights will not be recovered.
235    pub fn hard_prune(&mut self) {
236        let mask = self.hard_mask();
237        let weight_data = self.weight.data();
238        let w_vec = weight_data.to_vec();
239        let m_vec = mask.to_vec();
240
241        let pruned: Vec<f32> = w_vec
242            .iter()
243            .zip(m_vec.iter())
244            .map(|(&w, &m)| w * m)
245            .collect();
246
247        let new_weight = Tensor::from_vec(pruned, &[self.out_features, self.in_features])
248            .expect("tensor creation failed");
249        self.weight.update_data(new_weight);
250
251        // Reset thresholds to zero so forward pass doesn't re-prune
252        let zero_threshold = if self.structured {
253            zeros(&[self.out_features])
254        } else {
255            zeros(&[self.out_features, self.in_features])
256        };
257        self.threshold.update_data(zero_threshold);
258    }
259
260    /// Resets the threshold to a specific value.
261    ///
262    /// # Arguments
263    /// * `value` - The new threshold value
264    pub fn reset_threshold(&mut self, value: f32) {
265        let new_threshold = if self.structured {
266            constant(&[self.out_features], value)
267        } else {
268            constant(&[self.out_features, self.in_features], value)
269        };
270        self.threshold.update_data(new_threshold);
271    }
272
273    /// Returns the effective weight (weight * hard_mask) for inspection.
274    ///
275    /// This shows what the weight matrix looks like after hard pruning,
276    /// without actually modifying the layer.
277    pub fn effective_weight(&self) -> Tensor<f32> {
278        let mask = self.hard_mask();
279        let weight_data = self.weight.data();
280        let w_vec = weight_data.to_vec();
281        let m_vec = mask.to_vec();
282
283        let effective: Vec<f32> = w_vec
284            .iter()
285            .zip(m_vec.iter())
286            .map(|(&w, &m)| w * m)
287            .collect();
288
289        Tensor::from_vec(effective, &[self.out_features, self.in_features])
290            .expect("tensor creation failed")
291    }
292
293    /// Computes the soft mask using differentiable sigmoid thresholding.
294    ///
295    /// The soft mask is computed element-wise as:
296    /// ```text
297    /// mask_ij = sigmoid((|w_ij| - threshold_j) * temperature)
298    /// ```
299    ///
300    /// For structured pruning, `threshold_j` is broadcast across `in_features`.
301    /// For unstructured pruning, each weight has its own threshold.
302    fn compute_soft_mask(&self, weight_var: &Variable) -> Variable {
303        let weight_data = weight_var.data();
304        let threshold_data = self.threshold.data();
305        let w_vec = weight_data.to_vec();
306        let t_vec = threshold_data.to_vec();
307
308        // Compute sigmoid((|w| - threshold) * temperature) element-wise
309        let mask_vec: Vec<f32> = if self.structured {
310            w_vec
311                .iter()
312                .enumerate()
313                .map(|(idx, &w)| {
314                    let out_idx = idx / self.in_features;
315                    let t = t_vec[out_idx];
316                    let x = (w.abs() - t) * TEMPERATURE;
317                    1.0 / (1.0 + (-x).exp())
318                })
319                .collect()
320        } else {
321            w_vec
322                .iter()
323                .zip(t_vec.iter())
324                .map(|(&w, &t)| {
325                    let x = (w.abs() - t) * TEMPERATURE;
326                    1.0 / (1.0 + (-x).exp())
327                })
328                .collect()
329        };
330
331        let mask_tensor = Tensor::from_vec(mask_vec, &[self.out_features, self.in_features])
332            .expect("tensor creation failed");
333
334        // Create as a variable that participates in the graph
335        // The mask depends on both weight and threshold, but since we compute
336        // it from the raw tensor values, we wrap it as a new variable.
337        // The gradient signal flows through the weight multiplication below.
338        Variable::new(mask_tensor, false)
339    }
340}
341
342impl Module for SparseLinear {
343    fn forward(&self, input: &Variable) -> Variable {
344        let input_shape = input.shape();
345        let batch_dims: Vec<usize> = input_shape[..input_shape.len() - 1].to_vec();
346        let total_batch: usize = batch_dims.iter().product();
347
348        // Reshape to 2D if needed
349        let input_2d = if input_shape.len() > 2 {
350            input.reshape(&[total_batch, self.in_features])
351        } else {
352            input.clone()
353        };
354
355        // Get weight variable and compute soft mask
356        let weight_var = self.weight.variable();
357        let mask = self.compute_soft_mask(&weight_var);
358
359        // effective_weight = weight * mask
360        let effective_weight = weight_var.mul_var(&mask);
361
362        // y = x @ effective_weight^T
363        let weight_t = effective_weight.transpose(0, 1);
364        let mut output = input_2d.matmul(&weight_t);
365
366        // Add bias if present
367        if let Some(ref bias) = self.bias {
368            let bias_var = bias.variable();
369            output = output.add_var(&bias_var);
370        }
371
372        // Reshape back to original batch dimensions
373        if batch_dims.len() > 1 || (batch_dims.len() == 1 && input_shape.len() > 2) {
374            let mut output_shape: Vec<usize> = batch_dims;
375            output_shape.push(self.out_features);
376            output.reshape(&output_shape)
377        } else {
378            output
379        }
380    }
381
382    fn parameters(&self) -> Vec<Parameter> {
383        let mut params = vec![self.weight.clone(), self.threshold.clone()];
384        if let Some(ref bias) = self.bias {
385            params.push(bias.clone());
386        }
387        params
388    }
389
390    fn named_parameters(&self) -> HashMap<String, Parameter> {
391        let mut params = HashMap::new();
392        params.insert("weight".to_string(), self.weight.clone());
393        params.insert("threshold".to_string(), self.threshold.clone());
394        if let Some(ref bias) = self.bias {
395            params.insert("bias".to_string(), bias.clone());
396        }
397        params
398    }
399
400    fn name(&self) -> &'static str {
401        "SparseLinear"
402    }
403}
404
impl std::fmt::Debug for SparseLinear {
    /// Formats a compact summary of the layer. Note that `density` is
    /// recomputed from the current weights/thresholds on every call.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SparseLinear")
            .field("in_features", &self.in_features)
            .field("out_features", &self.out_features)
            .field("bias", &self.bias.is_some())
            .field("structured", &self.structured)
            .field("density", &self.density())
            .finish()
    }
}
416
417// =============================================================================
418// GroupSparsity
419// =============================================================================
420
421/// A regularization module that encourages structured sparsity via group L1 norm.
422///
423/// Computes a penalty term that can be added to the loss function:
424///
425/// ```text
426/// penalty = lambda * sum_g(||weight_g||_2)
427/// ```
428///
429/// where `weight_g` is a group of weights (e.g., all weights for one output neuron).
430/// This encourages entire groups to go to zero (structured pruning), since the L2
431/// norm within each group distributes the penalty equally among group members.
432///
433/// # Example
434/// ```ignore
435/// let reg = GroupSparsity::new(0.001, 128);  // lambda=0.001, group_size=128
436/// let penalty = reg.penalty(&model.weight_variable());
437/// let total_loss = task_loss.add_var(&penalty);
438/// ```
pub struct GroupSparsity {
    /// Regularization strength. Higher values encourage more sparsity.
    /// Exposed read-only via `lambda()`.
    lambda: f32,
    /// Number of weights per group. The final group may be smaller when the
    /// weight count is not an exact multiple. Exposed via `group_size()`.
    group_size: usize,
}
445
446impl GroupSparsity {
447    /// Creates a new GroupSparsity regularizer.
448    ///
449    /// # Arguments
450    /// * `lambda` - Regularization strength (e.g., 0.001)
451    /// * `group_size` - Number of weights per group (e.g., in_features for neuron-level)
452    pub fn new(lambda: f32, group_size: usize) -> Self {
453        assert!(group_size > 0, "group_size must be positive");
454        Self { lambda, group_size }
455    }
456
457    /// Returns the regularization strength.
458    pub fn lambda(&self) -> f32 {
459        self.lambda
460    }
461
462    /// Returns the group size.
463    pub fn group_size(&self) -> usize {
464        self.group_size
465    }
466
467    /// Computes the group L1 penalty for the given weight variable.
468    ///
469    /// The penalty is computed as:
470    /// 1. Reshape weight into groups of size `group_size`
471    /// 2. Compute L2 norm of each group
472    /// 3. Sum all group norms (L1 of norms)
473    /// 4. Multiply by lambda
474    ///
475    /// Returns a scalar Variable that can be added to the loss.
476    pub fn penalty(&self, weight: &Variable) -> Variable {
477        let weight_data = weight.data();
478        let w_vec = weight_data.to_vec();
479        let total = w_vec.len();
480
481        // Number of complete groups
482        let num_groups = total.div_ceil(self.group_size);
483
484        // Compute L2 norm per group, then sum (L1 of group norms)
485        let mut group_norm_sum = 0.0f32;
486        for g in 0..num_groups {
487            let start = g * self.group_size;
488            let end = (start + self.group_size).min(total);
489            let group = &w_vec[start..end];
490
491            let l2_norm: f32 = group.iter().map(|&x| x * x).sum::<f32>().sqrt();
492            group_norm_sum += l2_norm;
493        }
494
495        let penalty_val = self.lambda * group_norm_sum;
496        let penalty_tensor =
497            Tensor::from_vec(vec![penalty_val], &[1]).expect("tensor creation failed");
498
499        // Create as a variable. The penalty is computed from raw tensor values
500        // for simplicity. For full autograd integration, one would implement a
501        // custom backward function, but the penalty is typically used alongside
502        // weight decay in the optimizer.
503        Variable::new(penalty_tensor, false)
504    }
505}
506
impl std::fmt::Debug for GroupSparsity {
    /// Formats the regularizer's configuration (lambda and group size).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GroupSparsity")
            .field("lambda", &self.lambda)
            .field("group_size", &self.group_size)
            .finish()
    }
}
515
516// =============================================================================
517// LotteryTicket
518// =============================================================================
519
/// Implements the Lottery Ticket Hypothesis (Frankle & Carbin, 2019).
521///
522/// The Lottery Ticket Hypothesis states that dense networks contain sparse
523/// subnetworks ("winning tickets") that can be trained in isolation to match
524/// the full network's accuracy.
525///
526/// This struct saves a snapshot of the initial weights, then after pruning,
527/// allows rewinding the unpruned weights back to their initial values while
528/// keeping the pruning mask.
529///
530/// # Workflow
531/// 1. Initialize network
532/// 2. `let ticket = LotteryTicket::snapshot(&model.parameters());`
533/// 3. Train network with pruning
534/// 4. Determine pruning mask
535/// 5. `ticket.rewind(&model.parameters());` — reset to initial weights
536/// 6. Apply mask and train again
537///
538/// # Example
539/// ```ignore
540/// let model = SparseLinear::new(784, 256);
541/// let ticket = LotteryTicket::snapshot(&model.parameters());
542///
543/// // ... train and prune ...
544///
545/// ticket.rewind(&model.parameters());  // Reset to initial weights
546/// // ... retrain with mask ...
547/// ```
pub struct LotteryTicket {
    /// Saved initial parameter values, keyed by parameter name or — for
    /// unnamed parameters — by the positional key `param_{index}`.
    initial_weights: HashMap<String, Tensor<f32>>,
}
552
553impl LotteryTicket {
554    /// Takes a snapshot of the current parameter values.
555    ///
556    /// # Arguments
557    /// * `params` - Slice of parameters to snapshot
558    pub fn snapshot(params: &[Parameter]) -> Self {
559        let mut initial_weights = HashMap::new();
560        for (i, param) in params.iter().enumerate() {
561            let key = if param.name().is_empty() {
562                format!("param_{}", i)
563            } else {
564                param.name().to_string()
565            };
566            initial_weights.insert(key, param.data());
567        }
568        Self { initial_weights }
569    }
570
571    /// Returns the number of saved parameters.
572    pub fn num_saved(&self) -> usize {
573        self.initial_weights.len()
574    }
575
576    /// Rewinds all parameters to their initial (snapshot) values.
577    ///
578    /// # Arguments
579    /// * `params` - Slice of parameters to rewind (must match snapshot order)
580    pub fn rewind(&self, params: &[Parameter]) {
581        for (i, param) in params.iter().enumerate() {
582            let key = if param.name().is_empty() {
583                format!("param_{}", i)
584            } else {
585                param.name().to_string()
586            };
587            if let Some(initial) = self.initial_weights.get(&key) {
588                param.update_data(initial.clone());
589            }
590        }
591    }
592
593    /// Rewinds parameters to their initial values, but only where the mask is 1.
594    ///
595    /// Weights where `mask == 0` are set to zero (pruned). Weights where
596    /// `mask == 1` are reset to their initial snapshot values.
597    ///
598    /// # Arguments
599    /// * `params` - Slice of parameters to rewind
600    /// * `masks` - Corresponding binary masks (same length as params)
601    pub fn rewind_with_mask(&self, params: &[Parameter], masks: &[Tensor<f32>]) {
602        assert_eq!(
603            params.len(),
604            masks.len(),
605            "Number of parameters and masks must match"
606        );
607
608        for (i, (param, mask)) in params.iter().zip(masks.iter()).enumerate() {
609            let key = if param.name().is_empty() {
610                format!("param_{}", i)
611            } else {
612                param.name().to_string()
613            };
614
615            if let Some(initial) = self.initial_weights.get(&key) {
616                let init_vec = initial.to_vec();
617                let mask_vec = mask.to_vec();
618
619                let rewound: Vec<f32> = init_vec
620                    .iter()
621                    .zip(mask_vec.iter())
622                    .map(|(&w, &m)| if m > 0.5 { w } else { 0.0 })
623                    .collect();
624
625                let shape = param.shape();
626                let new_data = Tensor::from_vec(rewound, &shape).expect("tensor creation failed");
627                param.update_data(new_data);
628            }
629        }
630    }
631}
632
impl std::fmt::Debug for LotteryTicket {
    /// Reports only the count of snapshotted tensors; the tensors themselves
    /// are omitted to keep the output small.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LotteryTicket")
            .field("num_saved", &self.initial_weights.len())
            .finish()
    }
}
640
641// =============================================================================
642// Tests
643// =============================================================================
644
645#[cfg(test)]
646mod tests {
647    use super::*;
648
649    // -------------------------------------------------------------------------
650    // SparseLinear Tests
651    // -------------------------------------------------------------------------
652
653    #[test]
654    fn test_sparse_linear_creation_structured() {
655        let layer = SparseLinear::new(10, 5);
656        assert_eq!(layer.in_features(), 10);
657        assert_eq!(layer.out_features(), 5);
658        assert!(layer.is_structured());
659        assert!(layer.bias.is_some());
660    }
661
662    #[test]
663    fn test_sparse_linear_creation_unstructured() {
664        let layer = SparseLinear::unstructured(10, 5);
665        assert_eq!(layer.in_features(), 10);
666        assert_eq!(layer.out_features(), 5);
667        assert!(!layer.is_structured());
668        assert!(layer.bias.is_some());
669    }
670
671    #[test]
672    fn test_sparse_linear_no_bias() {
673        let layer = SparseLinear::with_bias(10, 5, false);
674        assert!(layer.bias.is_none());
675    }
676
677    #[test]
678    fn test_sparse_linear_forward_shape() {
679        let layer = SparseLinear::new(4, 3);
680        let input = Variable::new(
681            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 4]).expect("tensor creation failed"),
682            false,
683        );
684        let output = layer.forward(&input);
685        assert_eq!(output.shape(), vec![1, 3]);
686    }
687
688    #[test]
689    fn test_sparse_linear_forward_batch() {
690        let layer = SparseLinear::new(4, 3);
691        let input = Variable::new(
692            Tensor::from_vec(vec![1.0; 12], &[3, 4]).expect("tensor creation failed"),
693            false,
694        );
695        let output = layer.forward(&input);
696        assert_eq!(output.shape(), vec![3, 3]);
697    }
698
699    #[test]
700    fn test_sparse_linear_forward_no_bias() {
701        let layer = SparseLinear::with_bias(4, 3, false);
702        let input = Variable::new(
703            Tensor::from_vec(vec![1.0; 8], &[2, 4]).expect("tensor creation failed"),
704            false,
705        );
706        let output = layer.forward(&input);
707        assert_eq!(output.shape(), vec![2, 3]);
708    }
709
710    #[test]
711    fn test_sparse_linear_density_initial() {
712        // With default threshold of 0.01, most Kaiming-initialized weights
713        // should be above threshold (density close to 1.0).
714        let layer = SparseLinear::new(100, 50);
715        let density = layer.density();
716        assert!(
717            density > 0.9,
718            "Initial density should be high, got {}",
719            density
720        );
721    }
722
723    #[test]
724    fn test_sparse_linear_sparsity_initial() {
725        let layer = SparseLinear::new(100, 50);
726        let sparsity = layer.sparsity();
727        assert!(
728            sparsity < 0.1,
729            "Initial sparsity should be low, got {}",
730            sparsity
731        );
732        assert!((layer.density() + layer.sparsity() - 1.0).abs() < 1e-6);
733    }
734
735    #[test]
736    fn test_sparse_linear_num_active() {
737        let layer = SparseLinear::new(10, 5);
738        let active = layer.num_active();
739        let total = 10 * 5;
740        assert!(active <= total);
741        assert!(active > 0);
742    }
743
744    #[test]
745    fn test_sparse_linear_high_threshold_more_sparsity() {
746        let mut layer = SparseLinear::new(100, 50);
747        let density_low_thresh = layer.density();
748
749        // Set high threshold — should prune more weights
750        layer.reset_threshold(10.0);
751        let density_high_thresh = layer.density();
752
753        assert!(
754            density_high_thresh < density_low_thresh,
755            "Higher threshold should reduce density: low_thresh={}, high_thresh={}",
756            density_low_thresh,
757            density_high_thresh
758        );
759    }
760
761    #[test]
762    fn test_sparse_linear_low_threshold_dense() {
763        let mut layer = SparseLinear::new(100, 50);
764        // Set threshold to zero — all weights should be active
765        layer.reset_threshold(0.0);
766        let density = layer.density();
767        assert!(
768            (density - 1.0).abs() < 1e-6,
769            "Zero threshold should give density=1.0, got {}",
770            density
771        );
772    }
773
774    #[test]
775    fn test_sparse_linear_soft_mask_values_in_range() {
776        let layer = SparseLinear::new(10, 5);
777        let weight_var = layer.weight.variable();
778        let mask = layer.compute_soft_mask(&weight_var);
779        let mask_vec = mask.data().to_vec();
780
781        for &v in &mask_vec {
782            assert!(
783                (0.0..=1.0).contains(&v),
784                "Soft mask value {} not in [0, 1]",
785                v
786            );
787        }
788    }
789
790    #[test]
791    fn test_sparse_linear_hard_prune() {
792        let mut layer = SparseLinear::new(10, 5);
793        // Set a threshold that will prune some weights
794        layer.reset_threshold(0.5);
795
796        let pre_prune_density = layer.density();
797        layer.hard_prune();
798
799        // After hard prune, the zeroed weights should stay zero
800        let weight_data = layer.weight.data();
801        let w_vec = weight_data.to_vec();
802        let zeros_count = w_vec.iter().filter(|&&v| v == 0.0).count();
803
804        // The number of zeros should correspond to the pruned fraction
805        let expected_zeros = ((1.0 - pre_prune_density) * (10 * 5) as f32).round() as usize;
806        assert_eq!(
807            zeros_count, expected_zeros,
808            "Hard prune should zero out pruned weights"
809        );
810    }
811
812    #[test]
813    fn test_sparse_linear_hard_prune_threshold_reset() {
814        let mut layer = SparseLinear::new(10, 5);
815        layer.reset_threshold(0.5);
816        layer.hard_prune();
817
818        // After hard prune, thresholds should be zero
819        let t_vec = layer.threshold.data().to_vec();
820        assert!(
821            t_vec.iter().all(|&v| v == 0.0),
822            "Thresholds should be zero after hard_prune"
823        );
824    }
825
826    #[test]
827    fn test_sparse_linear_effective_weight() {
828        let layer = SparseLinear::new(10, 5);
829        let ew = layer.effective_weight();
830        assert_eq!(ew.shape(), &[5, 10]);
831    }
832
833    #[test]
834    fn test_sparse_linear_effective_weight_matches_hard_prune() {
835        let mut layer = SparseLinear::new(10, 5);
836        layer.reset_threshold(0.3);
837
838        let effective = layer.effective_weight();
839        layer.hard_prune();
840        let pruned = layer.weight.data();
841
842        let e_vec = effective.to_vec();
843        let p_vec = pruned.to_vec();
844        for (e, p) in e_vec.iter().zip(p_vec.iter()) {
845            assert!(
846                (e - p).abs() < 1e-6,
847                "effective_weight and hard_prune should match"
848            );
849        }
850    }
851
852    #[test]
853    fn test_sparse_linear_parameters_include_threshold() {
854        let layer = SparseLinear::new(10, 5);
855        let params = layer.parameters();
856        // weight + threshold + bias = 3
857        assert_eq!(params.len(), 3);
858
859        let named = layer.named_parameters();
860        assert!(named.contains_key("threshold"));
861        assert!(named.contains_key("weight"));
862        assert!(named.contains_key("bias"));
863    }
864
865    #[test]
866    fn test_sparse_linear_parameters_no_bias() {
867        let layer = SparseLinear::with_bias(10, 5, false);
868        let params = layer.parameters();
869        // weight + threshold = 2
870        assert_eq!(params.len(), 2);
871    }
872
873    #[test]
874    fn test_sparse_linear_module_name() {
875        let layer = SparseLinear::new(10, 5);
876        assert_eq!(layer.name(), "SparseLinear");
877    }
878
879    #[test]
880    fn test_sparse_linear_debug() {
881        let layer = SparseLinear::new(10, 5);
882        let debug_str = format!("{:?}", layer);
883        assert!(debug_str.contains("SparseLinear"));
884        assert!(debug_str.contains("in_features: 10"));
885        assert!(debug_str.contains("out_features: 5"));
886    }
887
888    #[test]
889    fn test_sparse_linear_reset_threshold() {
890        let mut layer = SparseLinear::new(10, 5);
891        layer.reset_threshold(0.5);
892        let t_vec = layer.threshold.data().to_vec();
893        assert!(t_vec.iter().all(|&v| (v - 0.5).abs() < 1e-6));
894    }
895
896    #[test]
897    fn test_sparse_linear_unstructured_threshold_shape() {
898        let layer = SparseLinear::unstructured(10, 5);
899        // Unstructured: threshold has same shape as weight
900        assert_eq!(layer.threshold.shape(), vec![5, 10]);
901    }
902
903    #[test]
904    fn test_sparse_linear_structured_threshold_shape() {
905        let layer = SparseLinear::new(10, 5);
906        // Structured: threshold has shape (out_features,)
907        assert_eq!(layer.threshold.shape(), vec![5]);
908    }
909
910    #[test]
911    fn test_sparse_linear_unstructured_forward() {
912        let layer = SparseLinear::unstructured(4, 3);
913        let input = Variable::new(
914            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[2, 4])
915                .expect("tensor creation failed"),
916            false,
917        );
918        let output = layer.forward(&input);
919        assert_eq!(output.shape(), vec![2, 3]);
920    }
921
922    // -------------------------------------------------------------------------
923    // GroupSparsity Tests
924    // -------------------------------------------------------------------------
925
926    #[test]
927    fn test_group_sparsity_creation() {
928        let reg = GroupSparsity::new(0.001, 10);
929        assert!((reg.lambda() - 0.001).abs() < 1e-8);
930        assert_eq!(reg.group_size(), 10);
931    }
932
933    #[test]
934    fn test_group_sparsity_penalty_non_negative() {
935        let reg = GroupSparsity::new(0.01, 4);
936        let weight = Variable::new(
937            Tensor::from_vec(vec![1.0, -2.0, 3.0, -4.0, 5.0, -6.0, 7.0, -8.0], &[2, 4])
938                .expect("tensor creation failed"),
939            true,
940        );
941        let penalty = reg.penalty(&weight);
942        let penalty_val = penalty.data().to_vec()[0];
943        assert!(
944            penalty_val >= 0.0,
945            "Penalty should be non-negative, got {}",
946            penalty_val
947        );
948    }
949
950    #[test]
951    fn test_group_sparsity_zero_weights_zero_penalty() {
952        let reg = GroupSparsity::new(0.01, 4);
953        let weight = Variable::new(
954            Tensor::from_vec(vec![0.0; 8], &[2, 4]).expect("tensor creation failed"),
955            true,
956        );
957        let penalty = reg.penalty(&weight);
958        let penalty_val = penalty.data().to_vec()[0];
959        assert!(
960            (penalty_val).abs() < 1e-6,
961            "Zero weights should give zero penalty, got {}",
962            penalty_val
963        );
964    }
965
966    #[test]
967    fn test_group_sparsity_scales_with_lambda() {
968        let reg_small = GroupSparsity::new(0.001, 4);
969        let reg_large = GroupSparsity::new(0.01, 4);
970        let weight = Variable::new(
971            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 4]).expect("tensor creation failed"),
972            true,
973        );
974
975        let penalty_small = reg_small.penalty(&weight).data().to_vec()[0];
976        let penalty_large = reg_large.penalty(&weight).data().to_vec()[0];
977
978        assert!(
979            penalty_large > penalty_small,
980            "Larger lambda should give larger penalty: small={}, large={}",
981            penalty_small,
982            penalty_large
983        );
984
985        // Should scale linearly with lambda
986        let ratio = penalty_large / penalty_small;
987        assert!(
988            (ratio - 10.0).abs() < 1e-4,
989            "Penalty should scale linearly with lambda, ratio={}",
990            ratio
991        );
992    }
993
994    #[test]
995    fn test_group_sparsity_debug() {
996        let reg = GroupSparsity::new(0.001, 10);
997        let debug_str = format!("{:?}", reg);
998        assert!(debug_str.contains("GroupSparsity"));
999        assert!(debug_str.contains("lambda"));
1000    }
1001
1002    #[test]
1003    #[should_panic(expected = "group_size must be positive")]
1004    fn test_group_sparsity_zero_group_size_panics() {
1005        let _reg = GroupSparsity::new(0.01, 0);
1006    }
1007
1008    // -------------------------------------------------------------------------
1009    // LotteryTicket Tests
1010    // -------------------------------------------------------------------------
1011
1012    #[test]
1013    fn test_lottery_ticket_snapshot() {
1014        let layer = SparseLinear::new(10, 5);
1015        let params = layer.parameters();
1016        let ticket = LotteryTicket::snapshot(&params);
1017        assert_eq!(ticket.num_saved(), params.len());
1018    }
1019
1020    #[test]
1021    fn test_lottery_ticket_rewind() {
1022        let layer = SparseLinear::new(10, 5);
1023        let params = layer.parameters();
1024        let initial_weight = params[0].data().to_vec();
1025
1026        let ticket = LotteryTicket::snapshot(&params);
1027
1028        // Modify the weight
1029        let new_data = Tensor::from_vec(vec![99.0; 50], &[5, 10]).expect("tensor creation failed");
1030        params[0].update_data(new_data);
1031
1032        // Verify it changed
1033        let modified_weight = params[0].data().to_vec();
1034        assert_ne!(modified_weight, initial_weight);
1035
1036        // Rewind
1037        ticket.rewind(&params);
1038
1039        // Verify it's back to initial
1040        let rewound_weight = params[0].data().to_vec();
1041        assert_eq!(rewound_weight, initial_weight);
1042    }
1043
1044    #[test]
1045    fn test_lottery_ticket_rewind_preserves_shapes() {
1046        let layer = SparseLinear::new(10, 5);
1047        let params = layer.parameters();
1048        let initial_shapes: Vec<Vec<usize>> = params.iter().map(|p| p.shape()).collect();
1049
1050        let ticket = LotteryTicket::snapshot(&params);
1051
1052        // Modify weight data (same shape)
1053        let new_data = Tensor::from_vec(vec![0.0; 50], &[5, 10]).expect("tensor creation failed");
1054        params[0].update_data(new_data);
1055
1056        ticket.rewind(&params);
1057
1058        let rewound_shapes: Vec<Vec<usize>> = params.iter().map(|p| p.shape()).collect();
1059        assert_eq!(initial_shapes, rewound_shapes);
1060    }
1061
1062    #[test]
1063    fn test_lottery_ticket_rewind_with_mask() {
1064        let data =
1065            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).expect("tensor creation failed");
1066        let param = Parameter::named("weight", data, true);
1067        let params = vec![param];
1068
1069        let ticket = LotteryTicket::snapshot(&params);
1070
1071        // Modify the parameter
1072        let new_data = Tensor::from_vec(vec![10.0, 20.0, 30.0, 40.0], &[2, 2])
1073            .expect("tensor creation failed");
1074        params[0].update_data(new_data);
1075
1076        // Mask: keep first two, prune last two
1077        let mask =
1078            Tensor::from_vec(vec![1.0, 1.0, 0.0, 0.0], &[2, 2]).expect("tensor creation failed");
1079        ticket.rewind_with_mask(&params, &[mask]);
1080
1081        let result = params[0].data().to_vec();
1082        assert_eq!(
1083            result,
1084            vec![1.0, 2.0, 0.0, 0.0],
1085            "Masked weights should be zero, unmasked should be initial values"
1086        );
1087    }
1088
1089    #[test]
1090    fn test_lottery_ticket_debug() {
1091        let layer = SparseLinear::new(10, 5);
1092        let ticket = LotteryTicket::snapshot(&layer.parameters());
1093        let debug_str = format!("{:?}", ticket);
1094        assert!(debug_str.contains("LotteryTicket"));
1095        assert!(debug_str.contains("num_saved"));
1096    }
1097
1098    // -------------------------------------------------------------------------
1099    // Integration Tests
1100    // -------------------------------------------------------------------------
1101
1102    #[test]
1103    fn test_integration_sparse_linear_with_group_sparsity() {
1104        // Create a SparseLinear layer
1105        let layer = SparseLinear::new(8, 4);
1106
1107        // Forward pass
1108        let input = Variable::new(
1109            Tensor::from_vec(vec![1.0; 16], &[2, 8]).expect("tensor creation failed"),
1110            false,
1111        );
1112        let output = layer.forward(&input);
1113        assert_eq!(output.shape(), vec![2, 4]);
1114
1115        // Compute group sparsity penalty on the weights
1116        let reg = GroupSparsity::new(0.001, 8); // group_size = in_features
1117        let weight_var = layer.weight.variable();
1118        let penalty = reg.penalty(&weight_var);
1119        let penalty_val = penalty.data().to_vec()[0];
1120        assert!(
1121            penalty_val > 0.0,
1122            "Penalty should be positive for non-zero weights"
1123        );
1124    }
1125
1126    #[test]
1127    fn test_integration_lottery_ticket_with_pruning() {
1128        // 1. Create layer and snapshot
1129        let mut layer = SparseLinear::new(8, 4);
1130        let ticket = LotteryTicket::snapshot(&layer.parameters());
1131
1132        // 2. Simulate training (modify weights)
1133        let new_weight = Tensor::from_vec(vec![0.5; 32], &[4, 8]).expect("tensor creation failed");
1134        layer.weight.update_data(new_weight);
1135
1136        // 3. Set threshold to prune some weights
1137        layer.reset_threshold(0.3);
1138
1139        // 4. Get the effective weight mask
1140        let mask = layer.hard_mask();
1141
1142        // 5. Rewind to initial weights with mask
1143        let weight_param = vec![layer.weight.clone()];
1144        ticket.rewind_with_mask(&weight_param, &[mask]);
1145
1146        // Verify shape is preserved
1147        assert_eq!(layer.weight.shape(), vec![4, 8]);
1148    }
1149
1150    #[test]
1151    fn test_num_parameters_sparse_linear() {
1152        let layer = SparseLinear::new(10, 5);
1153        // weight: 50 + threshold: 5 + bias: 5 = 60
1154        assert_eq!(layer.num_parameters(), 60);
1155    }
1156}