1use std::collections::HashMap;
27
28use axonml_autograd::Variable;
29use axonml_tensor::Tensor;
30
31use crate::init::{constant, kaiming_uniform, zeros};
32use crate::module::Module;
33use crate::parameter::Parameter;
34
/// Steepness of the sigmoid used for the differentiable (soft) pruning mask;
/// larger values push the soft mask closer to a hard 0/1 step function.
const TEMPERATURE: f32 = 10.0;

/// Initial magnitude threshold below which weights are considered prunable.
const DEFAULT_THRESHOLD: f32 = 0.01;
46
/// Linear layer with learnable magnitude-based pruning thresholds.
///
/// Weights whose absolute value falls below the corresponding threshold are
/// masked out. In `structured` mode there is one threshold per output row;
/// in unstructured mode there is one threshold per individual weight.
pub struct SparseLinear {
    /// Dense weight matrix of shape `[out_features, in_features]`.
    pub weight: Parameter,
    /// Optional bias of shape `[out_features]`.
    pub bias: Option<Parameter>,
    /// Learnable pruning threshold(s): shape `[out_features]` when structured,
    /// `[out_features, in_features]` otherwise.
    pub threshold: Parameter,
    // Size of each input sample.
    in_features: usize,
    // Size of each output sample.
    out_features: usize,
    // Whether pruning is structured (per output row) or per weight.
    structured: bool,
}
94
95impl SparseLinear {
96 pub fn new(in_features: usize, out_features: usize) -> Self {
102 Self::build(in_features, out_features, true, true)
103 }
104
105 pub fn unstructured(in_features: usize, out_features: usize) -> Self {
111 Self::build(in_features, out_features, false, true)
112 }
113
114 pub fn with_bias(in_features: usize, out_features: usize, bias: bool) -> Self {
121 Self::build(in_features, out_features, true, bias)
122 }
123
124 fn build(in_features: usize, out_features: usize, structured: bool, bias: bool) -> Self {
126 let weight_data = kaiming_uniform(out_features, in_features);
128 let weight = Parameter::named("weight", weight_data, true);
129
130 let bias_param = if bias {
132 let bias_data = zeros(&[out_features]);
133 Some(Parameter::named("bias", bias_data, true))
134 } else {
135 None
136 };
137
138 let threshold_data = if structured {
140 constant(&[out_features], DEFAULT_THRESHOLD)
141 } else {
142 constant(&[out_features, in_features], DEFAULT_THRESHOLD)
143 };
144 let threshold = Parameter::named("threshold", threshold_data, true);
145
146 Self {
147 weight,
148 bias: bias_param,
149 threshold,
150 in_features,
151 out_features,
152 structured,
153 }
154 }
155
156 pub fn in_features(&self) -> usize {
158 self.in_features
159 }
160
161 pub fn out_features(&self) -> usize {
163 self.out_features
164 }
165
166 pub fn is_structured(&self) -> bool {
168 self.structured
169 }
170
171 fn hard_mask(&self) -> Tensor<f32> {
176 let weight_data = self.weight.data();
177 let threshold_data = self.threshold.data();
178 let w_vec = weight_data.to_vec();
179 let t_vec = threshold_data.to_vec();
180
181 let mask_vec: Vec<f32> = if self.structured {
182 w_vec
184 .iter()
185 .enumerate()
186 .map(|(idx, &w)| {
187 let out_idx = idx / self.in_features;
188 let t = t_vec[out_idx];
189 if w.abs() >= t { 1.0 } else { 0.0 }
190 })
191 .collect()
192 } else {
193 w_vec
195 .iter()
196 .zip(t_vec.iter())
197 .map(|(&w, &t)| if w.abs() >= t { 1.0 } else { 0.0 })
198 .collect()
199 };
200
201 Tensor::from_vec(mask_vec, &[self.out_features, self.in_features])
202 .expect("tensor creation failed")
203 }
204
205 pub fn density(&self) -> f32 {
209 let mask = self.hard_mask();
210 let mask_vec = mask.to_vec();
211 let total = mask_vec.len() as f32;
212 let active: f32 = mask_vec.iter().sum();
213 active / total
214 }
215
216 pub fn sparsity(&self) -> f32 {
220 1.0 - self.density()
221 }
222
223 pub fn num_active(&self) -> usize {
225 let mask = self.hard_mask();
226 let mask_vec = mask.to_vec();
227 mask_vec.iter().filter(|&&v| v > 0.5).count()
228 }
229
230 pub fn hard_prune(&mut self) {
236 let mask = self.hard_mask();
237 let weight_data = self.weight.data();
238 let w_vec = weight_data.to_vec();
239 let m_vec = mask.to_vec();
240
241 let pruned: Vec<f32> = w_vec
242 .iter()
243 .zip(m_vec.iter())
244 .map(|(&w, &m)| w * m)
245 .collect();
246
247 let new_weight = Tensor::from_vec(pruned, &[self.out_features, self.in_features])
248 .expect("tensor creation failed");
249 self.weight.update_data(new_weight);
250
251 let zero_threshold = if self.structured {
253 zeros(&[self.out_features])
254 } else {
255 zeros(&[self.out_features, self.in_features])
256 };
257 self.threshold.update_data(zero_threshold);
258 }
259
260 pub fn reset_threshold(&mut self, value: f32) {
265 let new_threshold = if self.structured {
266 constant(&[self.out_features], value)
267 } else {
268 constant(&[self.out_features, self.in_features], value)
269 };
270 self.threshold.update_data(new_threshold);
271 }
272
273 pub fn effective_weight(&self) -> Tensor<f32> {
278 let mask = self.hard_mask();
279 let weight_data = self.weight.data();
280 let w_vec = weight_data.to_vec();
281 let m_vec = mask.to_vec();
282
283 let effective: Vec<f32> = w_vec
284 .iter()
285 .zip(m_vec.iter())
286 .map(|(&w, &m)| w * m)
287 .collect();
288
289 Tensor::from_vec(effective, &[self.out_features, self.in_features])
290 .expect("tensor creation failed")
291 }
292
293 fn compute_soft_mask(&self, weight_var: &Variable) -> Variable {
303 let weight_data = weight_var.data();
304 let threshold_data = self.threshold.data();
305 let w_vec = weight_data.to_vec();
306 let t_vec = threshold_data.to_vec();
307
308 let mask_vec: Vec<f32> = if self.structured {
310 w_vec
311 .iter()
312 .enumerate()
313 .map(|(idx, &w)| {
314 let out_idx = idx / self.in_features;
315 let t = t_vec[out_idx];
316 let x = (w.abs() - t) * TEMPERATURE;
317 1.0 / (1.0 + (-x).exp())
318 })
319 .collect()
320 } else {
321 w_vec
322 .iter()
323 .zip(t_vec.iter())
324 .map(|(&w, &t)| {
325 let x = (w.abs() - t) * TEMPERATURE;
326 1.0 / (1.0 + (-x).exp())
327 })
328 .collect()
329 };
330
331 let mask_tensor = Tensor::from_vec(mask_vec, &[self.out_features, self.in_features])
332 .expect("tensor creation failed");
333
334 Variable::new(mask_tensor, false)
339 }
340}
341
342impl Module for SparseLinear {
343 fn forward(&self, input: &Variable) -> Variable {
344 let input_shape = input.shape();
345 let batch_dims: Vec<usize> = input_shape[..input_shape.len() - 1].to_vec();
346 let total_batch: usize = batch_dims.iter().product();
347
348 let input_2d = if input_shape.len() > 2 {
350 input.reshape(&[total_batch, self.in_features])
351 } else {
352 input.clone()
353 };
354
355 let weight_var = self.weight.variable();
357 let mask = self.compute_soft_mask(&weight_var);
358
359 let effective_weight = weight_var.mul_var(&mask);
361
362 let weight_t = effective_weight.transpose(0, 1);
364 let mut output = input_2d.matmul(&weight_t);
365
366 if let Some(ref bias) = self.bias {
368 let bias_var = bias.variable();
369 output = output.add_var(&bias_var);
370 }
371
372 if batch_dims.len() > 1 || (batch_dims.len() == 1 && input_shape.len() > 2) {
374 let mut output_shape: Vec<usize> = batch_dims;
375 output_shape.push(self.out_features);
376 output.reshape(&output_shape)
377 } else {
378 output
379 }
380 }
381
382 fn parameters(&self) -> Vec<Parameter> {
383 let mut params = vec![self.weight.clone(), self.threshold.clone()];
384 if let Some(ref bias) = self.bias {
385 params.push(bias.clone());
386 }
387 params
388 }
389
390 fn named_parameters(&self) -> HashMap<String, Parameter> {
391 let mut params = HashMap::new();
392 params.insert("weight".to_string(), self.weight.clone());
393 params.insert("threshold".to_string(), self.threshold.clone());
394 if let Some(ref bias) = self.bias {
395 params.insert("bias".to_string(), bias.clone());
396 }
397 params
398 }
399
400 fn name(&self) -> &'static str {
401 "SparseLinear"
402 }
403}
404
405impl std::fmt::Debug for SparseLinear {
406 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
407 f.debug_struct("SparseLinear")
408 .field("in_features", &self.in_features)
409 .field("out_features", &self.out_features)
410 .field("bias", &self.bias.is_some())
411 .field("structured", &self.structured)
412 .field("density", &self.density())
413 .finish()
414 }
415}
416
/// Group-lasso style regulariser: penalises the sum of L2 norms of
/// fixed-size contiguous groups of weights, encouraging entire groups
/// to shrink towards zero together.
pub struct GroupSparsity {
    // Regularisation strength multiplier applied to the summed group norms.
    lambda: f32,
    // Number of consecutive weights per group (the last group may be shorter).
    group_size: usize,
}
445
446impl GroupSparsity {
447 pub fn new(lambda: f32, group_size: usize) -> Self {
453 assert!(group_size > 0, "group_size must be positive");
454 Self { lambda, group_size }
455 }
456
457 pub fn lambda(&self) -> f32 {
459 self.lambda
460 }
461
462 pub fn group_size(&self) -> usize {
464 self.group_size
465 }
466
467 pub fn penalty(&self, weight: &Variable) -> Variable {
477 let weight_data = weight.data();
478 let w_vec = weight_data.to_vec();
479 let total = w_vec.len();
480
481 let num_groups = total.div_ceil(self.group_size);
483
484 let mut group_norm_sum = 0.0f32;
486 for g in 0..num_groups {
487 let start = g * self.group_size;
488 let end = (start + self.group_size).min(total);
489 let group = &w_vec[start..end];
490
491 let l2_norm: f32 = group.iter().map(|&x| x * x).sum::<f32>().sqrt();
492 group_norm_sum += l2_norm;
493 }
494
495 let penalty_val = self.lambda * group_norm_sum;
496 let penalty_tensor =
497 Tensor::from_vec(vec![penalty_val], &[1]).expect("tensor creation failed");
498
499 Variable::new(penalty_tensor, false)
504 }
505}
506
507impl std::fmt::Debug for GroupSparsity {
508 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
509 f.debug_struct("GroupSparsity")
510 .field("lambda", &self.lambda)
511 .field("group_size", &self.group_size)
512 .finish()
513 }
514}
515
/// Snapshot of parameter values for lottery-ticket style "rewinding":
/// captures initial weights so a pruned network can later be reset to them.
pub struct LotteryTicket {
    // Initial tensors keyed by parameter name, or `param_{index}` for
    // parameters with an empty name.
    initial_weights: HashMap<String, Tensor<f32>>,
}
552
553impl LotteryTicket {
554 pub fn snapshot(params: &[Parameter]) -> Self {
559 let mut initial_weights = HashMap::new();
560 for (i, param) in params.iter().enumerate() {
561 let key = if param.name().is_empty() {
562 format!("param_{}", i)
563 } else {
564 param.name().to_string()
565 };
566 initial_weights.insert(key, param.data());
567 }
568 Self { initial_weights }
569 }
570
571 pub fn num_saved(&self) -> usize {
573 self.initial_weights.len()
574 }
575
576 pub fn rewind(&self, params: &[Parameter]) {
581 for (i, param) in params.iter().enumerate() {
582 let key = if param.name().is_empty() {
583 format!("param_{}", i)
584 } else {
585 param.name().to_string()
586 };
587 if let Some(initial) = self.initial_weights.get(&key) {
588 param.update_data(initial.clone());
589 }
590 }
591 }
592
593 pub fn rewind_with_mask(&self, params: &[Parameter], masks: &[Tensor<f32>]) {
602 assert_eq!(
603 params.len(),
604 masks.len(),
605 "Number of parameters and masks must match"
606 );
607
608 for (i, (param, mask)) in params.iter().zip(masks.iter()).enumerate() {
609 let key = if param.name().is_empty() {
610 format!("param_{}", i)
611 } else {
612 param.name().to_string()
613 };
614
615 if let Some(initial) = self.initial_weights.get(&key) {
616 let init_vec = initial.to_vec();
617 let mask_vec = mask.to_vec();
618
619 let rewound: Vec<f32> = init_vec
620 .iter()
621 .zip(mask_vec.iter())
622 .map(|(&w, &m)| if m > 0.5 { w } else { 0.0 })
623 .collect();
624
625 let shape = param.shape();
626 let new_data = Tensor::from_vec(rewound, &shape).expect("tensor creation failed");
627 param.update_data(new_data);
628 }
629 }
630 }
631}
632
633impl std::fmt::Debug for LotteryTicket {
634 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
635 f.debug_struct("LotteryTicket")
636 .field("num_saved", &self.initial_weights.len())
637 .finish()
638 }
639}
640
#[cfg(test)]
mod tests {
    use super::*;

    // ---------- SparseLinear: construction ----------

    #[test]
    fn test_sparse_linear_creation_structured() {
        let layer = SparseLinear::new(10, 5);
        assert_eq!(layer.in_features(), 10);
        assert_eq!(layer.out_features(), 5);
        assert!(layer.is_structured());
        assert!(layer.bias.is_some());
    }

    #[test]
    fn test_sparse_linear_creation_unstructured() {
        let layer = SparseLinear::unstructured(10, 5);
        assert_eq!(layer.in_features(), 10);
        assert_eq!(layer.out_features(), 5);
        assert!(!layer.is_structured());
        assert!(layer.bias.is_some());
    }

    #[test]
    fn test_sparse_linear_no_bias() {
        let layer = SparseLinear::with_bias(10, 5, false);
        assert!(layer.bias.is_none());
    }

    // ---------- SparseLinear: forward pass output shapes ----------

    #[test]
    fn test_sparse_linear_forward_shape() {
        let layer = SparseLinear::new(4, 3);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 4]).expect("tensor creation failed"),
            false,
        );
        let output = layer.forward(&input);
        assert_eq!(output.shape(), vec![1, 3]);
    }

    #[test]
    fn test_sparse_linear_forward_batch() {
        let layer = SparseLinear::new(4, 3);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 12], &[3, 4]).expect("tensor creation failed"),
            false,
        );
        let output = layer.forward(&input);
        assert_eq!(output.shape(), vec![3, 3]);
    }

    #[test]
    fn test_sparse_linear_forward_no_bias() {
        let layer = SparseLinear::with_bias(4, 3, false);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 8], &[2, 4]).expect("tensor creation failed"),
            false,
        );
        let output = layer.forward(&input);
        assert_eq!(output.shape(), vec![2, 3]);
    }

    // ---------- SparseLinear: density / sparsity statistics ----------

    #[test]
    fn test_sparse_linear_density_initial() {
        // Fresh layers use DEFAULT_THRESHOLD (0.01), so most weights survive.
        let layer = SparseLinear::new(100, 50);
        let density = layer.density();
        assert!(
            density > 0.9,
            "Initial density should be high, got {}",
            density
        );
    }

    #[test]
    fn test_sparse_linear_sparsity_initial() {
        let layer = SparseLinear::new(100, 50);
        let sparsity = layer.sparsity();
        assert!(
            sparsity < 0.1,
            "Initial sparsity should be low, got {}",
            sparsity
        );
        // density + sparsity must always be 1 by definition.
        assert!((layer.density() + layer.sparsity() - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_sparse_linear_num_active() {
        let layer = SparseLinear::new(10, 5);
        let active = layer.num_active();
        let total = 10 * 5;
        assert!(active <= total);
        assert!(active > 0);
    }

    // ---------- SparseLinear: threshold control ----------

    #[test]
    fn test_sparse_linear_high_threshold_more_sparsity() {
        let mut layer = SparseLinear::new(100, 50);
        let density_low_thresh = layer.density();

        // A threshold of 10.0 exceeds any Kaiming-initialised weight magnitude.
        layer.reset_threshold(10.0);
        let density_high_thresh = layer.density();

        assert!(
            density_high_thresh < density_low_thresh,
            "Higher threshold should reduce density: low_thresh={}, high_thresh={}",
            density_low_thresh,
            density_high_thresh
        );
    }

    #[test]
    fn test_sparse_linear_low_threshold_dense() {
        let mut layer = SparseLinear::new(100, 50);
        // |w| >= 0.0 holds for every weight, so nothing gets masked.
        layer.reset_threshold(0.0);
        let density = layer.density();
        assert!(
            (density - 1.0).abs() < 1e-6,
            "Zero threshold should give density=1.0, got {}",
            density
        );
    }

    #[test]
    fn test_sparse_linear_soft_mask_values_in_range() {
        // Sigmoid outputs must stay inside [0, 1].
        let layer = SparseLinear::new(10, 5);
        let weight_var = layer.weight.variable();
        let mask = layer.compute_soft_mask(&weight_var);
        let mask_vec = mask.data().to_vec();

        for &v in &mask_vec {
            assert!(
                (0.0..=1.0).contains(&v),
                "Soft mask value {} not in [0, 1]",
                v
            );
        }
    }

    // ---------- SparseLinear: permanent pruning ----------

    #[test]
    fn test_sparse_linear_hard_prune() {
        let mut layer = SparseLinear::new(10, 5);
        layer.reset_threshold(0.5);

        let pre_prune_density = layer.density();
        layer.hard_prune();

        let weight_data = layer.weight.data();
        let w_vec = weight_data.to_vec();
        let zeros_count = w_vec.iter().filter(|&&v| v == 0.0).count();

        // Every weight dropped by the mask must now be exactly 0.0.
        let expected_zeros = ((1.0 - pre_prune_density) * (10 * 5) as f32).round() as usize;
        assert_eq!(
            zeros_count, expected_zeros,
            "Hard prune should zero out pruned weights"
        );
    }

    #[test]
    fn test_sparse_linear_hard_prune_threshold_reset() {
        let mut layer = SparseLinear::new(10, 5);
        layer.reset_threshold(0.5);
        layer.hard_prune();

        // hard_prune must reset all thresholds so survivors stay unmasked.
        let t_vec = layer.threshold.data().to_vec();
        assert!(
            t_vec.iter().all(|&v| v == 0.0),
            "Thresholds should be zero after hard_prune"
        );
    }

    #[test]
    fn test_sparse_linear_effective_weight() {
        let layer = SparseLinear::new(10, 5);
        let ew = layer.effective_weight();
        assert_eq!(ew.shape(), &[5, 10]);
    }

    #[test]
    fn test_sparse_linear_effective_weight_matches_hard_prune() {
        // effective_weight is a non-destructive preview of hard_prune's result.
        let mut layer = SparseLinear::new(10, 5);
        layer.reset_threshold(0.3);

        let effective = layer.effective_weight();
        layer.hard_prune();
        let pruned = layer.weight.data();

        let e_vec = effective.to_vec();
        let p_vec = pruned.to_vec();
        for (e, p) in e_vec.iter().zip(p_vec.iter()) {
            assert!(
                (e - p).abs() < 1e-6,
                "effective_weight and hard_prune should match"
            );
        }
    }

    // ---------- SparseLinear: Module trait surface ----------

    #[test]
    fn test_sparse_linear_parameters_include_threshold() {
        let layer = SparseLinear::new(10, 5);
        let params = layer.parameters();
        // weight + threshold + bias
        assert_eq!(params.len(), 3);

        let named = layer.named_parameters();
        assert!(named.contains_key("threshold"));
        assert!(named.contains_key("weight"));
        assert!(named.contains_key("bias"));
    }

    #[test]
    fn test_sparse_linear_parameters_no_bias() {
        let layer = SparseLinear::with_bias(10, 5, false);
        let params = layer.parameters();
        // weight + threshold only
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_sparse_linear_module_name() {
        let layer = SparseLinear::new(10, 5);
        assert_eq!(layer.name(), "SparseLinear");
    }

    #[test]
    fn test_sparse_linear_debug() {
        let layer = SparseLinear::new(10, 5);
        let debug_str = format!("{:?}", layer);
        assert!(debug_str.contains("SparseLinear"));
        assert!(debug_str.contains("in_features: 10"));
        assert!(debug_str.contains("out_features: 5"));
    }

    #[test]
    fn test_sparse_linear_reset_threshold() {
        let mut layer = SparseLinear::new(10, 5);
        layer.reset_threshold(0.5);
        let t_vec = layer.threshold.data().to_vec();
        assert!(t_vec.iter().all(|&v| (v - 0.5).abs() < 1e-6));
    }

    #[test]
    fn test_sparse_linear_unstructured_threshold_shape() {
        // Unstructured: one threshold per weight -> full matrix shape.
        let layer = SparseLinear::unstructured(10, 5);
        assert_eq!(layer.threshold.shape(), vec![5, 10]);
    }

    #[test]
    fn test_sparse_linear_structured_threshold_shape() {
        // Structured: one threshold per output row.
        let layer = SparseLinear::new(10, 5);
        assert_eq!(layer.threshold.shape(), vec![5]);
    }

    #[test]
    fn test_sparse_linear_unstructured_forward() {
        let layer = SparseLinear::unstructured(4, 3);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[2, 4])
                .expect("tensor creation failed"),
            false,
        );
        let output = layer.forward(&input);
        assert_eq!(output.shape(), vec![2, 3]);
    }

    // ---------- GroupSparsity ----------

    #[test]
    fn test_group_sparsity_creation() {
        let reg = GroupSparsity::new(0.001, 10);
        assert!((reg.lambda() - 0.001).abs() < 1e-8);
        assert_eq!(reg.group_size(), 10);
    }

    #[test]
    fn test_group_sparsity_penalty_non_negative() {
        // Sum of L2 norms scaled by a positive lambda can never be negative.
        let reg = GroupSparsity::new(0.01, 4);
        let weight = Variable::new(
            Tensor::from_vec(vec![1.0, -2.0, 3.0, -4.0, 5.0, -6.0, 7.0, -8.0], &[2, 4])
                .expect("tensor creation failed"),
            true,
        );
        let penalty = reg.penalty(&weight);
        let penalty_val = penalty.data().to_vec()[0];
        assert!(
            penalty_val >= 0.0,
            "Penalty should be non-negative, got {}",
            penalty_val
        );
    }

    #[test]
    fn test_group_sparsity_zero_weights_zero_penalty() {
        let reg = GroupSparsity::new(0.01, 4);
        let weight = Variable::new(
            Tensor::from_vec(vec![0.0; 8], &[2, 4]).expect("tensor creation failed"),
            true,
        );
        let penalty = reg.penalty(&weight);
        let penalty_val = penalty.data().to_vec()[0];
        assert!(
            (penalty_val).abs() < 1e-6,
            "Zero weights should give zero penalty, got {}",
            penalty_val
        );
    }

    #[test]
    fn test_group_sparsity_scales_with_lambda() {
        let reg_small = GroupSparsity::new(0.001, 4);
        let reg_large = GroupSparsity::new(0.01, 4);
        let weight = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 4]).expect("tensor creation failed"),
            true,
        );

        let penalty_small = reg_small.penalty(&weight).data().to_vec()[0];
        let penalty_large = reg_large.penalty(&weight).data().to_vec()[0];

        assert!(
            penalty_large > penalty_small,
            "Larger lambda should give larger penalty: small={}, large={}",
            penalty_small,
            penalty_large
        );

        // Penalty is linear in lambda, so a 10x lambda gives a 10x penalty.
        let ratio = penalty_large / penalty_small;
        assert!(
            (ratio - 10.0).abs() < 1e-4,
            "Penalty should scale linearly with lambda, ratio={}",
            ratio
        );
    }

    #[test]
    fn test_group_sparsity_debug() {
        let reg = GroupSparsity::new(0.001, 10);
        let debug_str = format!("{:?}", reg);
        assert!(debug_str.contains("GroupSparsity"));
        assert!(debug_str.contains("lambda"));
    }

    #[test]
    #[should_panic(expected = "group_size must be positive")]
    fn test_group_sparsity_zero_group_size_panics() {
        let _reg = GroupSparsity::new(0.01, 0);
    }

    // ---------- LotteryTicket ----------

    #[test]
    fn test_lottery_ticket_snapshot() {
        let layer = SparseLinear::new(10, 5);
        let params = layer.parameters();
        let ticket = LotteryTicket::snapshot(&params);
        assert_eq!(ticket.num_saved(), params.len());
    }

    #[test]
    fn test_lottery_ticket_rewind() {
        let layer = SparseLinear::new(10, 5);
        let params = layer.parameters();
        let initial_weight = params[0].data().to_vec();

        let ticket = LotteryTicket::snapshot(&params);

        // Overwrite the weight, then verify rewind restores the snapshot.
        let new_data = Tensor::from_vec(vec![99.0; 50], &[5, 10]).expect("tensor creation failed");
        params[0].update_data(new_data);

        let modified_weight = params[0].data().to_vec();
        assert_ne!(modified_weight, initial_weight);

        ticket.rewind(&params);

        let rewound_weight = params[0].data().to_vec();
        assert_eq!(rewound_weight, initial_weight);
    }

    #[test]
    fn test_lottery_ticket_rewind_preserves_shapes() {
        let layer = SparseLinear::new(10, 5);
        let params = layer.parameters();
        let initial_shapes: Vec<Vec<usize>> = params.iter().map(|p| p.shape()).collect();

        let ticket = LotteryTicket::snapshot(&params);

        let new_data = Tensor::from_vec(vec![0.0; 50], &[5, 10]).expect("tensor creation failed");
        params[0].update_data(new_data);

        ticket.rewind(&params);

        let rewound_shapes: Vec<Vec<usize>> = params.iter().map(|p| p.shape()).collect();
        assert_eq!(initial_shapes, rewound_shapes);
    }

    #[test]
    fn test_lottery_ticket_rewind_with_mask() {
        let data =
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).expect("tensor creation failed");
        let param = Parameter::named("weight", data, true);
        let params = vec![param];

        let ticket = LotteryTicket::snapshot(&params);

        let new_data = Tensor::from_vec(vec![10.0, 20.0, 30.0, 40.0], &[2, 2])
            .expect("tensor creation failed");
        params[0].update_data(new_data);

        // Keep the first two entries, zero the last two.
        let mask =
            Tensor::from_vec(vec![1.0, 1.0, 0.0, 0.0], &[2, 2]).expect("tensor creation failed");
        ticket.rewind_with_mask(&params, &[mask]);

        let result = params[0].data().to_vec();
        assert_eq!(
            result,
            vec![1.0, 2.0, 0.0, 0.0],
            "Masked weights should be zero, unmasked should be initial values"
        );
    }

    #[test]
    fn test_lottery_ticket_debug() {
        let layer = SparseLinear::new(10, 5);
        let ticket = LotteryTicket::snapshot(&layer.parameters());
        let debug_str = format!("{:?}", ticket);
        assert!(debug_str.contains("LotteryTicket"));
        assert!(debug_str.contains("num_saved"));
    }

    // ---------- Integration ----------

    #[test]
    fn test_integration_sparse_linear_with_group_sparsity() {
        let layer = SparseLinear::new(8, 4);

        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 16], &[2, 8]).expect("tensor creation failed"),
            false,
        );
        let output = layer.forward(&input);
        assert_eq!(output.shape(), vec![2, 4]);

        let reg = GroupSparsity::new(0.001, 8);
        let weight_var = layer.weight.variable();
        let penalty = reg.penalty(&weight_var);
        let penalty_val = penalty.data().to_vec()[0];
        assert!(
            penalty_val > 0.0,
            "Penalty should be positive for non-zero weights"
        );
    }

    #[test]
    fn test_integration_lottery_ticket_with_pruning() {
        let mut layer = SparseLinear::new(8, 4);
        let ticket = LotteryTicket::snapshot(&layer.parameters());

        let new_weight = Tensor::from_vec(vec![0.5; 32], &[4, 8]).expect("tensor creation failed");
        layer.weight.update_data(new_weight);

        layer.reset_threshold(0.3);

        let mask = layer.hard_mask();

        let weight_param = vec![layer.weight.clone()];
        ticket.rewind_with_mask(&weight_param, &[mask]);

        assert_eq!(layer.weight.shape(), vec![4, 8]);
    }

    #[test]
    fn test_num_parameters_sparse_linear() {
        // 50 weights + 5 thresholds + 5 biases = 60 scalar parameters.
        let layer = SparseLinear::new(10, 5);
        assert_eq!(layer.num_parameters(), 60);
    }
}