1use std::collections::HashMap;
18
19use axonml_autograd::Variable;
20use axonml_tensor::Tensor;
21
22use crate::init::{constant, kaiming_uniform, zeros};
23use crate::module::Module;
24use crate::parameter::Parameter;
25
/// Steepness of the sigmoid used by the soft pruning mask; larger values
/// push the soft mask toward a hard 0/1 step around the threshold.
const TEMPERATURE: f32 = 10.0;

/// Initial value for the learnable pruning thresholds.
const DEFAULT_THRESHOLD: f32 = 0.01;
37
/// Linear layer with learnable pruning thresholds: weights whose magnitude
/// falls below their threshold are masked out (softly during `forward`,
/// hard in `hard_mask`/`hard_prune`).
pub struct SparseLinear {
    /// Weight matrix, shape `[out_features, in_features]`.
    pub weight: Parameter,
    /// Optional bias vector, shape `[out_features]`.
    pub bias: Option<Parameter>,
    /// Learnable pruning threshold: shape `[out_features]` when structured
    /// (one threshold per output row), `[out_features, in_features]` otherwise.
    pub threshold: Parameter,
    // Size of the input's last dimension.
    in_features: usize,
    // Number of output features.
    out_features: usize,
    // true => per-row thresholds; false => per-weight thresholds.
    structured: bool,
}
85
86impl SparseLinear {
87 pub fn new(in_features: usize, out_features: usize) -> Self {
93 Self::build(in_features, out_features, true, true)
94 }
95
96 pub fn unstructured(in_features: usize, out_features: usize) -> Self {
102 Self::build(in_features, out_features, false, true)
103 }
104
105 pub fn with_bias(in_features: usize, out_features: usize, bias: bool) -> Self {
112 Self::build(in_features, out_features, true, bias)
113 }
114
115 fn build(in_features: usize, out_features: usize, structured: bool, bias: bool) -> Self {
117 let weight_data = kaiming_uniform(out_features, in_features);
119 let weight = Parameter::named("weight", weight_data, true);
120
121 let bias_param = if bias {
123 let bias_data = zeros(&[out_features]);
124 Some(Parameter::named("bias", bias_data, true))
125 } else {
126 None
127 };
128
129 let threshold_data = if structured {
131 constant(&[out_features], DEFAULT_THRESHOLD)
132 } else {
133 constant(&[out_features, in_features], DEFAULT_THRESHOLD)
134 };
135 let threshold = Parameter::named("threshold", threshold_data, true);
136
137 Self {
138 weight,
139 bias: bias_param,
140 threshold,
141 in_features,
142 out_features,
143 structured,
144 }
145 }
146
147 pub fn in_features(&self) -> usize {
149 self.in_features
150 }
151
152 pub fn out_features(&self) -> usize {
154 self.out_features
155 }
156
157 pub fn is_structured(&self) -> bool {
159 self.structured
160 }
161
162 fn hard_mask(&self) -> Tensor<f32> {
167 let weight_data = self.weight.data();
168 let threshold_data = self.threshold.data();
169 let w_vec = weight_data.to_vec();
170 let t_vec = threshold_data.to_vec();
171
172 let mask_vec: Vec<f32> = if self.structured {
173 w_vec
175 .iter()
176 .enumerate()
177 .map(|(idx, &w)| {
178 let out_idx = idx / self.in_features;
179 let t = t_vec[out_idx];
180 if w.abs() >= t { 1.0 } else { 0.0 }
181 })
182 .collect()
183 } else {
184 w_vec
186 .iter()
187 .zip(t_vec.iter())
188 .map(|(&w, &t)| if w.abs() >= t { 1.0 } else { 0.0 })
189 .collect()
190 };
191
192 Tensor::from_vec(mask_vec, &[self.out_features, self.in_features]).unwrap()
193 }
194
195 pub fn density(&self) -> f32 {
199 let mask = self.hard_mask();
200 let mask_vec = mask.to_vec();
201 let total = mask_vec.len() as f32;
202 let active: f32 = mask_vec.iter().sum();
203 active / total
204 }
205
206 pub fn sparsity(&self) -> f32 {
210 1.0 - self.density()
211 }
212
213 pub fn num_active(&self) -> usize {
215 let mask = self.hard_mask();
216 let mask_vec = mask.to_vec();
217 mask_vec.iter().filter(|&&v| v > 0.5).count()
218 }
219
220 pub fn hard_prune(&mut self) {
226 let mask = self.hard_mask();
227 let weight_data = self.weight.data();
228 let w_vec = weight_data.to_vec();
229 let m_vec = mask.to_vec();
230
231 let pruned: Vec<f32> = w_vec
232 .iter()
233 .zip(m_vec.iter())
234 .map(|(&w, &m)| w * m)
235 .collect();
236
237 let new_weight = Tensor::from_vec(pruned, &[self.out_features, self.in_features]).unwrap();
238 self.weight.update_data(new_weight);
239
240 let zero_threshold = if self.structured {
242 zeros(&[self.out_features])
243 } else {
244 zeros(&[self.out_features, self.in_features])
245 };
246 self.threshold.update_data(zero_threshold);
247 }
248
249 pub fn reset_threshold(&mut self, value: f32) {
254 let new_threshold = if self.structured {
255 constant(&[self.out_features], value)
256 } else {
257 constant(&[self.out_features, self.in_features], value)
258 };
259 self.threshold.update_data(new_threshold);
260 }
261
262 pub fn effective_weight(&self) -> Tensor<f32> {
267 let mask = self.hard_mask();
268 let weight_data = self.weight.data();
269 let w_vec = weight_data.to_vec();
270 let m_vec = mask.to_vec();
271
272 let effective: Vec<f32> = w_vec
273 .iter()
274 .zip(m_vec.iter())
275 .map(|(&w, &m)| w * m)
276 .collect();
277
278 Tensor::from_vec(effective, &[self.out_features, self.in_features]).unwrap()
279 }
280
281 fn compute_soft_mask(&self, weight_var: &Variable) -> Variable {
291 let weight_data = weight_var.data();
292 let threshold_data = self.threshold.data();
293 let w_vec = weight_data.to_vec();
294 let t_vec = threshold_data.to_vec();
295
296 let mask_vec: Vec<f32> = if self.structured {
298 w_vec
299 .iter()
300 .enumerate()
301 .map(|(idx, &w)| {
302 let out_idx = idx / self.in_features;
303 let t = t_vec[out_idx];
304 let x = (w.abs() - t) * TEMPERATURE;
305 1.0 / (1.0 + (-x).exp())
306 })
307 .collect()
308 } else {
309 w_vec
310 .iter()
311 .zip(t_vec.iter())
312 .map(|(&w, &t)| {
313 let x = (w.abs() - t) * TEMPERATURE;
314 1.0 / (1.0 + (-x).exp())
315 })
316 .collect()
317 };
318
319 let mask_tensor =
320 Tensor::from_vec(mask_vec, &[self.out_features, self.in_features]).unwrap();
321
322 Variable::new(mask_tensor, false)
327 }
328}
329
330impl Module for SparseLinear {
331 fn forward(&self, input: &Variable) -> Variable {
332 let input_shape = input.shape();
333 let batch_dims: Vec<usize> = input_shape[..input_shape.len() - 1].to_vec();
334 let total_batch: usize = batch_dims.iter().product();
335
336 let input_2d = if input_shape.len() > 2 {
338 input.reshape(&[total_batch, self.in_features])
339 } else {
340 input.clone()
341 };
342
343 let weight_var = self.weight.variable();
345 let mask = self.compute_soft_mask(&weight_var);
346
347 let effective_weight = weight_var.mul_var(&mask);
349
350 let weight_t = effective_weight.transpose(0, 1);
352 let mut output = input_2d.matmul(&weight_t);
353
354 if let Some(ref bias) = self.bias {
356 let bias_var = bias.variable();
357 output = output.add_var(&bias_var);
358 }
359
360 if batch_dims.len() > 1 || (batch_dims.len() == 1 && input_shape.len() > 2) {
362 let mut output_shape: Vec<usize> = batch_dims;
363 output_shape.push(self.out_features);
364 output.reshape(&output_shape)
365 } else {
366 output
367 }
368 }
369
370 fn parameters(&self) -> Vec<Parameter> {
371 let mut params = vec![self.weight.clone(), self.threshold.clone()];
372 if let Some(ref bias) = self.bias {
373 params.push(bias.clone());
374 }
375 params
376 }
377
378 fn named_parameters(&self) -> HashMap<String, Parameter> {
379 let mut params = HashMap::new();
380 params.insert("weight".to_string(), self.weight.clone());
381 params.insert("threshold".to_string(), self.threshold.clone());
382 if let Some(ref bias) = self.bias {
383 params.insert("bias".to_string(), bias.clone());
384 }
385 params
386 }
387
388 fn name(&self) -> &'static str {
389 "SparseLinear"
390 }
391}
392
393impl std::fmt::Debug for SparseLinear {
394 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
395 f.debug_struct("SparseLinear")
396 .field("in_features", &self.in_features)
397 .field("out_features", &self.out_features)
398 .field("bias", &self.bias.is_some())
399 .field("structured", &self.structured)
400 .field("density", &self.density())
401 .finish()
402 }
403}
404
/// Group-lasso regularizer: sums the L2 norm of fixed-size groups of the
/// flattened weight vector, scaled by `lambda`.
pub struct GroupSparsity {
    // Regularization strength multiplier.
    lambda: f32,
    // Number of consecutive weights per group (last group may be shorter).
    group_size: usize,
}
433
434impl GroupSparsity {
435 pub fn new(lambda: f32, group_size: usize) -> Self {
441 assert!(group_size > 0, "group_size must be positive");
442 Self { lambda, group_size }
443 }
444
445 pub fn lambda(&self) -> f32 {
447 self.lambda
448 }
449
450 pub fn group_size(&self) -> usize {
452 self.group_size
453 }
454
455 pub fn penalty(&self, weight: &Variable) -> Variable {
465 let weight_data = weight.data();
466 let w_vec = weight_data.to_vec();
467 let total = w_vec.len();
468
469 let num_groups = total.div_ceil(self.group_size);
471
472 let mut group_norm_sum = 0.0f32;
474 for g in 0..num_groups {
475 let start = g * self.group_size;
476 let end = (start + self.group_size).min(total);
477 let group = &w_vec[start..end];
478
479 let l2_norm: f32 = group.iter().map(|&x| x * x).sum::<f32>().sqrt();
480 group_norm_sum += l2_norm;
481 }
482
483 let penalty_val = self.lambda * group_norm_sum;
484 let penalty_tensor = Tensor::from_vec(vec![penalty_val], &[1]).unwrap();
485
486 Variable::new(penalty_tensor, false)
491 }
492}
493
// Manual Debug impl; both fields are plain scalars.
impl std::fmt::Debug for GroupSparsity {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GroupSparsity")
            .field("lambda", &self.lambda)
            .field("group_size", &self.group_size)
            .finish()
    }
}
502
/// Snapshot of initial parameter tensors, keyed by parameter name (or a
/// positional `param_{i}` fallback for unnamed parameters), used to rewind
/// weights back to initialization after pruning.
pub struct LotteryTicket {
    // Saved copies of each parameter's data at snapshot time.
    initial_weights: HashMap<String, Tensor<f32>>,
}
539
540impl LotteryTicket {
541 pub fn snapshot(params: &[Parameter]) -> Self {
546 let mut initial_weights = HashMap::new();
547 for (i, param) in params.iter().enumerate() {
548 let key = if param.name().is_empty() {
549 format!("param_{}", i)
550 } else {
551 param.name().to_string()
552 };
553 initial_weights.insert(key, param.data());
554 }
555 Self { initial_weights }
556 }
557
558 pub fn num_saved(&self) -> usize {
560 self.initial_weights.len()
561 }
562
563 pub fn rewind(&self, params: &[Parameter]) {
568 for (i, param) in params.iter().enumerate() {
569 let key = if param.name().is_empty() {
570 format!("param_{}", i)
571 } else {
572 param.name().to_string()
573 };
574 if let Some(initial) = self.initial_weights.get(&key) {
575 param.update_data(initial.clone());
576 }
577 }
578 }
579
580 pub fn rewind_with_mask(&self, params: &[Parameter], masks: &[Tensor<f32>]) {
589 assert_eq!(
590 params.len(),
591 masks.len(),
592 "Number of parameters and masks must match"
593 );
594
595 for (i, (param, mask)) in params.iter().zip(masks.iter()).enumerate() {
596 let key = if param.name().is_empty() {
597 format!("param_{}", i)
598 } else {
599 param.name().to_string()
600 };
601
602 if let Some(initial) = self.initial_weights.get(&key) {
603 let init_vec = initial.to_vec();
604 let mask_vec = mask.to_vec();
605
606 let rewound: Vec<f32> = init_vec
607 .iter()
608 .zip(mask_vec.iter())
609 .map(|(&w, &m)| if m > 0.5 { w } else { 0.0 })
610 .collect();
611
612 let shape = param.shape();
613 let new_data = Tensor::from_vec(rewound, &shape).unwrap();
614 param.update_data(new_data);
615 }
616 }
617 }
618}
619
// Debug shows only the snapshot count, not the (potentially large) tensors.
impl std::fmt::Debug for LotteryTicket {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LotteryTicket")
            .field("num_saved", &self.initial_weights.len())
            .finish()
    }
}
627
#[cfg(test)]
mod tests {
    use super::*;

    // ---------- SparseLinear: construction ----------

    #[test]
    fn test_sparse_linear_creation_structured() {
        let layer = SparseLinear::new(10, 5);
        assert_eq!(layer.in_features(), 10);
        assert_eq!(layer.out_features(), 5);
        assert!(layer.is_structured());
        assert!(layer.bias.is_some());
    }

    #[test]
    fn test_sparse_linear_creation_unstructured() {
        let layer = SparseLinear::unstructured(10, 5);
        assert_eq!(layer.in_features(), 10);
        assert_eq!(layer.out_features(), 5);
        assert!(!layer.is_structured());
        assert!(layer.bias.is_some());
    }

    #[test]
    fn test_sparse_linear_no_bias() {
        let layer = SparseLinear::with_bias(10, 5, false);
        assert!(layer.bias.is_none());
    }

    // ---------- SparseLinear: forward-pass shapes ----------

    #[test]
    fn test_sparse_linear_forward_shape() {
        let layer = SparseLinear::new(4, 3);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 4]).unwrap(),
            false,
        );
        let output = layer.forward(&input);
        assert_eq!(output.shape(), vec![1, 3]);
    }

    #[test]
    fn test_sparse_linear_forward_batch() {
        let layer = SparseLinear::new(4, 3);
        let input = Variable::new(Tensor::from_vec(vec![1.0; 12], &[3, 4]).unwrap(), false);
        let output = layer.forward(&input);
        assert_eq!(output.shape(), vec![3, 3]);
    }

    #[test]
    fn test_sparse_linear_forward_no_bias() {
        let layer = SparseLinear::with_bias(4, 3, false);
        let input = Variable::new(Tensor::from_vec(vec![1.0; 8], &[2, 4]).unwrap(), false);
        let output = layer.forward(&input);
        assert_eq!(output.shape(), vec![2, 3]);
    }

    // ---------- SparseLinear: density / sparsity statistics ----------

    #[test]
    fn test_sparse_linear_density_initial() {
        // Fresh Kaiming weights should mostly exceed the small default
        // threshold, so the layer starts out nearly dense.
        let layer = SparseLinear::new(100, 50);
        let density = layer.density();
        assert!(
            density > 0.9,
            "Initial density should be high, got {}",
            density
        );
    }

    #[test]
    fn test_sparse_linear_sparsity_initial() {
        let layer = SparseLinear::new(100, 50);
        let sparsity = layer.sparsity();
        assert!(
            sparsity < 0.1,
            "Initial sparsity should be low, got {}",
            sparsity
        );
        // density and sparsity are complements by definition.
        assert!((layer.density() + layer.sparsity() - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_sparse_linear_num_active() {
        let layer = SparseLinear::new(10, 5);
        let active = layer.num_active();
        let total = 10 * 5;
        assert!(active <= total);
        assert!(active > 0);
    }

    #[test]
    fn test_sparse_linear_high_threshold_more_sparsity() {
        let mut layer = SparseLinear::new(100, 50);
        let density_low_thresh = layer.density();

        // A threshold of 10.0 dwarfs Kaiming-scale weights.
        layer.reset_threshold(10.0);
        let density_high_thresh = layer.density();

        assert!(
            density_high_thresh < density_low_thresh,
            "Higher threshold should reduce density: low_thresh={}, high_thresh={}",
            density_low_thresh,
            density_high_thresh
        );
    }

    #[test]
    fn test_sparse_linear_low_threshold_dense() {
        let mut layer = SparseLinear::new(100, 50);
        // |w| >= 0.0 always holds, so nothing is masked.
        layer.reset_threshold(0.0);
        let density = layer.density();
        assert!(
            (density - 1.0).abs() < 1e-6,
            "Zero threshold should give density=1.0, got {}",
            density
        );
    }

    #[test]
    fn test_sparse_linear_soft_mask_values_in_range() {
        let layer = SparseLinear::new(10, 5);
        let weight_var = layer.weight.variable();
        let mask = layer.compute_soft_mask(&weight_var);
        let mask_vec = mask.data().to_vec();

        // Sigmoid output is always within [0, 1].
        for &v in &mask_vec {
            assert!(v >= 0.0 && v <= 1.0, "Soft mask value {} not in [0, 1]", v);
        }
    }

    // ---------- SparseLinear: pruning ----------

    #[test]
    fn test_sparse_linear_hard_prune() {
        let mut layer = SparseLinear::new(10, 5);
        layer.reset_threshold(0.5);

        let pre_prune_density = layer.density();
        layer.hard_prune();

        let weight_data = layer.weight.data();
        let w_vec = weight_data.to_vec();
        let zeros_count = w_vec.iter().filter(|&&v| v == 0.0).count();

        // Zeroed count should match what the pre-prune density predicted.
        let expected_zeros = ((1.0 - pre_prune_density) * (10 * 5) as f32).round() as usize;
        assert_eq!(
            zeros_count, expected_zeros,
            "Hard prune should zero out pruned weights"
        );
    }

    #[test]
    fn test_sparse_linear_hard_prune_threshold_reset() {
        let mut layer = SparseLinear::new(10, 5);
        layer.reset_threshold(0.5);
        layer.hard_prune();

        let t_vec = layer.threshold.data().to_vec();
        assert!(
            t_vec.iter().all(|&v| v == 0.0),
            "Thresholds should be zero after hard_prune"
        );
    }

    #[test]
    fn test_sparse_linear_effective_weight() {
        let layer = SparseLinear::new(10, 5);
        let ew = layer.effective_weight();
        assert_eq!(ew.shape(), &[5, 10]);
    }

    #[test]
    fn test_sparse_linear_effective_weight_matches_hard_prune() {
        let mut layer = SparseLinear::new(10, 5);
        layer.reset_threshold(0.3);

        // effective_weight (non-mutating) and hard_prune (mutating) must
        // agree on the resulting weight values.
        let effective = layer.effective_weight();
        layer.hard_prune();
        let pruned = layer.weight.data();

        let e_vec = effective.to_vec();
        let p_vec = pruned.to_vec();
        for (e, p) in e_vec.iter().zip(p_vec.iter()) {
            assert!(
                (e - p).abs() < 1e-6,
                "effective_weight and hard_prune should match"
            );
        }
    }

    // ---------- SparseLinear: Module trait ----------

    #[test]
    fn test_sparse_linear_parameters_include_threshold() {
        let layer = SparseLinear::new(10, 5);
        let params = layer.parameters();
        // weight + threshold + bias
        assert_eq!(params.len(), 3);

        let named = layer.named_parameters();
        assert!(named.contains_key("threshold"));
        assert!(named.contains_key("weight"));
        assert!(named.contains_key("bias"));
    }

    #[test]
    fn test_sparse_linear_parameters_no_bias() {
        let layer = SparseLinear::with_bias(10, 5, false);
        let params = layer.parameters();
        // weight + threshold only
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_sparse_linear_module_name() {
        let layer = SparseLinear::new(10, 5);
        assert_eq!(layer.name(), "SparseLinear");
    }

    #[test]
    fn test_sparse_linear_debug() {
        let layer = SparseLinear::new(10, 5);
        let debug_str = format!("{:?}", layer);
        assert!(debug_str.contains("SparseLinear"));
        assert!(debug_str.contains("in_features: 10"));
        assert!(debug_str.contains("out_features: 5"));
    }

    #[test]
    fn test_sparse_linear_reset_threshold() {
        let mut layer = SparseLinear::new(10, 5);
        layer.reset_threshold(0.5);
        let t_vec = layer.threshold.data().to_vec();
        assert!(t_vec.iter().all(|&v| (v - 0.5).abs() < 1e-6));
    }

    #[test]
    fn test_sparse_linear_unstructured_threshold_shape() {
        // Unstructured: one threshold per weight.
        let layer = SparseLinear::unstructured(10, 5);
        assert_eq!(layer.threshold.shape(), vec![5, 10]);
    }

    #[test]
    fn test_sparse_linear_structured_threshold_shape() {
        // Structured: one threshold per output row.
        let layer = SparseLinear::new(10, 5);
        assert_eq!(layer.threshold.shape(), vec![5]);
    }

    #[test]
    fn test_sparse_linear_unstructured_forward() {
        let layer = SparseLinear::unstructured(4, 3);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[2, 4]).unwrap(),
            false,
        );
        let output = layer.forward(&input);
        assert_eq!(output.shape(), vec![2, 3]);
    }

    // ---------- GroupSparsity ----------

    #[test]
    fn test_group_sparsity_creation() {
        let reg = GroupSparsity::new(0.001, 10);
        assert!((reg.lambda() - 0.001).abs() < 1e-8);
        assert_eq!(reg.group_size(), 10);
    }

    #[test]
    fn test_group_sparsity_penalty_non_negative() {
        let reg = GroupSparsity::new(0.01, 4);
        let weight = Variable::new(
            Tensor::from_vec(vec![1.0, -2.0, 3.0, -4.0, 5.0, -6.0, 7.0, -8.0], &[2, 4]).unwrap(),
            true,
        );
        let penalty = reg.penalty(&weight);
        let penalty_val = penalty.data().to_vec()[0];
        assert!(
            penalty_val >= 0.0,
            "Penalty should be non-negative, got {}",
            penalty_val
        );
    }

    #[test]
    fn test_group_sparsity_zero_weights_zero_penalty() {
        let reg = GroupSparsity::new(0.01, 4);
        let weight = Variable::new(Tensor::from_vec(vec![0.0; 8], &[2, 4]).unwrap(), true);
        let penalty = reg.penalty(&weight);
        let penalty_val = penalty.data().to_vec()[0];
        assert!(
            (penalty_val).abs() < 1e-6,
            "Zero weights should give zero penalty, got {}",
            penalty_val
        );
    }

    #[test]
    fn test_group_sparsity_scales_with_lambda() {
        let reg_small = GroupSparsity::new(0.001, 4);
        let reg_large = GroupSparsity::new(0.01, 4);
        let weight = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 4]).unwrap(),
            true,
        );

        let penalty_small = reg_small.penalty(&weight).data().to_vec()[0];
        let penalty_large = reg_large.penalty(&weight).data().to_vec()[0];

        assert!(
            penalty_large > penalty_small,
            "Larger lambda should give larger penalty: small={}, large={}",
            penalty_small,
            penalty_large
        );

        // 0.01 / 0.001 => the penalties should differ by a factor of 10.
        let ratio = penalty_large / penalty_small;
        assert!(
            (ratio - 10.0).abs() < 1e-4,
            "Penalty should scale linearly with lambda, ratio={}",
            ratio
        );
    }

    #[test]
    fn test_group_sparsity_debug() {
        let reg = GroupSparsity::new(0.001, 10);
        let debug_str = format!("{:?}", reg);
        assert!(debug_str.contains("GroupSparsity"));
        assert!(debug_str.contains("lambda"));
    }

    #[test]
    #[should_panic(expected = "group_size must be positive")]
    fn test_group_sparsity_zero_group_size_panics() {
        let _reg = GroupSparsity::new(0.01, 0);
    }

    // ---------- LotteryTicket ----------

    #[test]
    fn test_lottery_ticket_snapshot() {
        let layer = SparseLinear::new(10, 5);
        let params = layer.parameters();
        let ticket = LotteryTicket::snapshot(&params);
        assert_eq!(ticket.num_saved(), params.len());
    }

    #[test]
    fn test_lottery_ticket_rewind() {
        let layer = SparseLinear::new(10, 5);
        let params = layer.parameters();
        let initial_weight = params[0].data().to_vec();

        let ticket = LotteryTicket::snapshot(&params);

        // Overwrite the weight, then confirm rewind restores it exactly.
        let new_data = Tensor::from_vec(vec![99.0; 50], &[5, 10]).unwrap();
        params[0].update_data(new_data);

        let modified_weight = params[0].data().to_vec();
        assert_ne!(modified_weight, initial_weight);

        ticket.rewind(&params);

        let rewound_weight = params[0].data().to_vec();
        assert_eq!(rewound_weight, initial_weight);
    }

    #[test]
    fn test_lottery_ticket_rewind_preserves_shapes() {
        let layer = SparseLinear::new(10, 5);
        let params = layer.parameters();
        let initial_shapes: Vec<Vec<usize>> = params.iter().map(|p| p.shape()).collect();

        let ticket = LotteryTicket::snapshot(&params);

        let new_data = Tensor::from_vec(vec![0.0; 50], &[5, 10]).unwrap();
        params[0].update_data(new_data);

        ticket.rewind(&params);

        let rewound_shapes: Vec<Vec<usize>> = params.iter().map(|p| p.shape()).collect();
        assert_eq!(initial_shapes, rewound_shapes);
    }

    #[test]
    fn test_lottery_ticket_rewind_with_mask() {
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let param = Parameter::named("weight", data, true);
        let params = vec![param];

        let ticket = LotteryTicket::snapshot(&params);

        let new_data = Tensor::from_vec(vec![10.0, 20.0, 30.0, 40.0], &[2, 2]).unwrap();
        params[0].update_data(new_data);

        // Keep first two entries, drop last two.
        let mask = Tensor::from_vec(vec![1.0, 1.0, 0.0, 0.0], &[2, 2]).unwrap();
        ticket.rewind_with_mask(&params, &[mask]);

        let result = params[0].data().to_vec();
        assert_eq!(
            result,
            vec![1.0, 2.0, 0.0, 0.0],
            "Masked weights should be zero, unmasked should be initial values"
        );
    }

    #[test]
    fn test_lottery_ticket_debug() {
        let layer = SparseLinear::new(10, 5);
        let ticket = LotteryTicket::snapshot(&layer.parameters());
        let debug_str = format!("{:?}", ticket);
        assert!(debug_str.contains("LotteryTicket"));
        assert!(debug_str.contains("num_saved"));
    }

    // ---------- Integration ----------

    #[test]
    fn test_integration_sparse_linear_with_group_sparsity() {
        let layer = SparseLinear::new(8, 4);

        let input = Variable::new(Tensor::from_vec(vec![1.0; 16], &[2, 8]).unwrap(), false);
        let output = layer.forward(&input);
        assert_eq!(output.shape(), vec![2, 4]);

        let reg = GroupSparsity::new(0.001, 8);
        let weight_var = layer.weight.variable();
        let penalty = reg.penalty(&weight_var);
        let penalty_val = penalty.data().to_vec()[0];
        assert!(
            penalty_val > 0.0,
            "Penalty should be positive for non-zero weights"
        );
    }

    #[test]
    fn test_integration_lottery_ticket_with_pruning() {
        let mut layer = SparseLinear::new(8, 4);
        let ticket = LotteryTicket::snapshot(&layer.parameters());

        let new_weight = Tensor::from_vec(vec![0.5; 32], &[4, 8]).unwrap();
        layer.weight.update_data(new_weight);

        layer.reset_threshold(0.3);

        let mask = layer.hard_mask();

        let weight_param = vec![layer.weight.clone()];
        ticket.rewind_with_mask(&weight_param, &[mask]);

        assert_eq!(layer.weight.shape(), vec![4, 8]);
    }

    #[test]
    fn test_num_parameters_sparse_linear() {
        // weight (50) + bias (5) + threshold (5) = 60.
        let layer = SparseLinear::new(10, 5);
        assert_eq!(layer.num_parameters(), 60);
    }
}