use std::collections::HashMap;

use axonml_autograd::Variable;
use axonml_tensor::Tensor;

use crate::init::{constant, kaiming_uniform, zeros};
use crate::module::Module;
use crate::parameter::Parameter;

/// Sharpness of the sigmoid used by `compute_soft_mask`; larger values push
/// the soft mask closer to a hard 0/1 step.
const TEMPERATURE: f32 = 10.0;

/// Initial value for the learnable pruning thresholds.
const DEFAULT_THRESHOLD: f32 = 0.01;

/// Linear layer with learnable magnitude thresholds that induce sparsity:
/// one threshold per output row when structured, one per weight otherwise.
pub struct SparseLinear {
    /// Weight matrix of shape `[out_features, in_features]`.
    pub weight: Parameter,
    /// Optional bias of shape `[out_features]`.
    pub bias: Option<Parameter>,
    /// Learnable pruning threshold(s).
    pub threshold: Parameter,
    in_features: usize,
    out_features: usize,
    structured: bool,
}

impl SparseLinear {
    /// Creates a structured layer (one threshold per output row) with bias.
    pub fn new(in_features: usize, out_features: usize) -> Self {
        Self::build(in_features, out_features, true, true)
    }

    /// Creates an unstructured layer (one threshold per weight) with bias.
    pub fn unstructured(in_features: usize, out_features: usize) -> Self {
        Self::build(in_features, out_features, false, true)
    }

    /// Creates a structured layer with an optional bias.
    pub fn with_bias(in_features: usize, out_features: usize, bias: bool) -> Self {
        Self::build(in_features, out_features, true, bias)
    }

    fn build(in_features: usize, out_features: usize, structured: bool, bias: bool) -> Self {
        let weight_data = kaiming_uniform(out_features, in_features);
        let weight = Parameter::named("weight", weight_data, true);

        let bias_param = if bias {
            let bias_data = zeros(&[out_features]);
            Some(Parameter::named("bias", bias_data, true))
        } else {
            None
        };

        // One threshold per output row when structured, one per weight otherwise.
        let threshold_data = if structured {
            constant(&[out_features], DEFAULT_THRESHOLD)
        } else {
            constant(&[out_features, in_features], DEFAULT_THRESHOLD)
        };
        let threshold = Parameter::named("threshold", threshold_data, true);

        Self {
            weight,
            bias: bias_param,
            threshold,
            in_features,
            out_features,
            structured,
        }
    }

    /// Returns the input dimension.
    pub fn in_features(&self) -> usize {
        self.in_features
    }

    /// Returns the output dimension.
    pub fn out_features(&self) -> usize {
        self.out_features
    }

    /// Returns `true` if the layer uses per-row (structured) thresholds.
    pub fn is_structured(&self) -> bool {
        self.structured
    }

    /// Computes the binary keep mask: 1.0 where `|w| >= threshold`, else 0.0.
    fn hard_mask(&self) -> Tensor<f32> {
        let weight_data = self.weight.data();
        let threshold_data = self.threshold.data();
        let w_vec = weight_data.to_vec();
        let t_vec = threshold_data.to_vec();

        let mask_vec: Vec<f32> = if self.structured {
            // One threshold per output row.
            w_vec
                .iter()
                .enumerate()
                .map(|(idx, &w)| {
                    let out_idx = idx / self.in_features;
                    let t = t_vec[out_idx];
                    if w.abs() >= t { 1.0 } else { 0.0 }
                })
                .collect()
        } else {
            // One threshold per weight.
            w_vec
                .iter()
                .zip(t_vec.iter())
                .map(|(&w, &t)| if w.abs() >= t { 1.0 } else { 0.0 })
                .collect()
        };

        Tensor::from_vec(mask_vec, &[self.out_features, self.in_features])
            .expect("tensor creation failed")
    }

    /// Fraction of weights currently above threshold (kept), in `[0, 1]`.
    pub fn density(&self) -> f32 {
        let mask = self.hard_mask();
        let mask_vec = mask.to_vec();
        let total = mask_vec.len() as f32;
        let active: f32 = mask_vec.iter().sum();
        active / total
    }

    /// Fraction of weights currently below threshold; `1.0 - density()`.
    pub fn sparsity(&self) -> f32 {
        1.0 - self.density()
    }

    /// Number of weights currently above threshold.
    pub fn num_active(&self) -> usize {
        let mask = self.hard_mask();
        let mask_vec = mask.to_vec();
        mask_vec.iter().filter(|&&v| v > 0.5).count()
    }

    /// Permanently zeroes all weights below threshold, then resets the
    /// thresholds to zero so no further weights are masked.
    pub fn hard_prune(&mut self) {
        let mask = self.hard_mask();
        let weight_data = self.weight.data();
        let w_vec = weight_data.to_vec();
        let m_vec = mask.to_vec();

        let pruned: Vec<f32> = w_vec
            .iter()
            .zip(m_vec.iter())
            .map(|(&w, &m)| w * m)
            .collect();

        let new_weight = Tensor::from_vec(pruned, &[self.out_features, self.in_features])
            .expect("tensor creation failed");
        self.weight.update_data(new_weight);

        let zero_threshold = if self.structured {
            zeros(&[self.out_features])
        } else {
            zeros(&[self.out_features, self.in_features])
        };
        self.threshold.update_data(zero_threshold);
    }

    /// Sets every threshold to `value`.
    pub fn reset_threshold(&mut self, value: f32) {
        let new_threshold = if self.structured {
            constant(&[self.out_features], value)
        } else {
            constant(&[self.out_features, self.in_features], value)
        };
        self.threshold.update_data(new_threshold);
    }

    /// Returns the weights with sub-threshold entries zeroed, without
    /// modifying the layer.
    pub fn effective_weight(&self) -> Tensor<f32> {
        let mask = self.hard_mask();
        let weight_data = self.weight.data();
        let w_vec = weight_data.to_vec();
        let m_vec = mask.to_vec();

        let effective: Vec<f32> = w_vec
            .iter()
            .zip(m_vec.iter())
            .map(|(&w, &m)| w * m)
            .collect();

        Tensor::from_vec(effective, &[self.out_features, self.in_features])
            .expect("tensor creation failed")
    }

    /// Computes a differentiable relaxation of the hard mask:
    /// `sigmoid((|w| - t) * TEMPERATURE)`. The mask is returned as a
    /// non-trainable `Variable`, so gradients flow through the weights it
    /// multiplies, not through the mask itself.
    fn compute_soft_mask(&self, weight_var: &Variable) -> Variable {
        let weight_data = weight_var.data();
        let threshold_data = self.threshold.data();
        let w_vec = weight_data.to_vec();
        let t_vec = threshold_data.to_vec();

        let mask_vec: Vec<f32> = if self.structured {
            w_vec
                .iter()
                .enumerate()
                .map(|(idx, &w)| {
                    let out_idx = idx / self.in_features;
                    let t = t_vec[out_idx];
                    let x = (w.abs() - t) * TEMPERATURE;
                    1.0 / (1.0 + (-x).exp())
                })
                .collect()
        } else {
            w_vec
                .iter()
                .zip(t_vec.iter())
                .map(|(&w, &t)| {
                    let x = (w.abs() - t) * TEMPERATURE;
                    1.0 / (1.0 + (-x).exp())
                })
                .collect()
        };

        let mask_tensor = Tensor::from_vec(mask_vec, &[self.out_features, self.in_features])
            .expect("tensor creation failed");

        Variable::new(mask_tensor, false)
    }
}

impl Module for SparseLinear {
    fn forward(&self, input: &Variable) -> Variable {
        let input_shape = input.shape();
        let batch_dims: Vec<usize> = input_shape[..input_shape.len() - 1].to_vec();
        let total_batch: usize = batch_dims.iter().product();

        // Flatten leading batch dimensions so the matmul is always 2-D.
        let input_2d = if input_shape.len() > 2 {
            input.reshape(&[total_batch, self.in_features])
        } else {
            input.clone()
        };

        let weight_var = self.weight.variable();
        let mask = self.compute_soft_mask(&weight_var);

        // Apply the soft mask, then compute `x @ (w * mask)^T`.
        let effective_weight = weight_var.mul_var(&mask);

        let weight_t = effective_weight.transpose(0, 1);
        let mut output = input_2d.matmul(&weight_t);

        if let Some(ref bias) = self.bias {
            let bias_var = bias.variable();
            output = output.add_var(&bias_var);
        }

        // Restore the original batch dimensions for inputs of rank > 2.
        if input_shape.len() > 2 {
            let mut output_shape: Vec<usize> = batch_dims;
            output_shape.push(self.out_features);
            output.reshape(&output_shape)
        } else {
            output
        }
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = vec![self.weight.clone(), self.threshold.clone()];
        if let Some(ref bias) = self.bias {
            params.push(bias.clone());
        }
        params
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("weight".to_string(), self.weight.clone());
        params.insert("threshold".to_string(), self.threshold.clone());
        if let Some(ref bias) = self.bias {
            params.insert("bias".to_string(), bias.clone());
        }
        params
    }

    fn name(&self) -> &'static str {
        "SparseLinear"
    }
}

impl std::fmt::Debug for SparseLinear {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SparseLinear")
            .field("in_features", &self.in_features)
            .field("out_features", &self.out_features)
            .field("bias", &self.bias.is_some())
            .field("structured", &self.structured)
            .field("density", &self.density())
            .finish()
    }
}

/// Group-lasso regularizer: an L2-of-groups penalty that drives whole groups
/// of consecutive weights toward zero together.
pub struct GroupSparsity {
    /// Regularization strength.
    lambda: f32,
    /// Number of consecutive weights per group.
    group_size: usize,
}

impl GroupSparsity {
    /// Creates a regularizer with strength `lambda` over groups of
    /// `group_size` consecutive weights.
    pub fn new(lambda: f32, group_size: usize) -> Self {
        assert!(group_size > 0, "group_size must be positive");
        Self { lambda, group_size }
    }

    /// Returns the regularization strength.
    pub fn lambda(&self) -> f32 {
        self.lambda
    }

    /// Returns the group size.
    pub fn group_size(&self) -> usize {
        self.group_size
    }

    /// Computes `lambda * sum_g ||w_g||_2` over consecutive groups of the
    /// flattened weights; the final group may be smaller than `group_size`.
    pub fn penalty(&self, weight: &Variable) -> Variable {
        let weight_data = weight.data();
        let w_vec = weight_data.to_vec();
        let total = w_vec.len();

        let num_groups = total.div_ceil(self.group_size);

        let mut group_norm_sum = 0.0f32;
        for g in 0..num_groups {
            let start = g * self.group_size;
            let end = (start + self.group_size).min(total);
            let group = &w_vec[start..end];

            let l2_norm: f32 = group.iter().map(|&x| x * x).sum::<f32>().sqrt();
            group_norm_sum += l2_norm;
        }

        let penalty_val = self.lambda * group_norm_sum;
        let penalty_tensor =
            Tensor::from_vec(vec![penalty_val], &[1]).expect("tensor creation failed");

        Variable::new(penalty_tensor, false)
    }
}

impl std::fmt::Debug for GroupSparsity {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GroupSparsity")
            .field("lambda", &self.lambda)
            .field("group_size", &self.group_size)
            .finish()
    }
}

/// Snapshot of initial parameter values for lottery-ticket style rewinding:
/// after pruning, surviving weights are reset to their saved initial values.
pub struct LotteryTicket {
    /// Initial tensors keyed by parameter name (or `param_{index}` if unnamed).
    initial_weights: HashMap<String, Tensor<f32>>,
}

impl LotteryTicket {
    /// Saves a copy of each parameter's current data.
    pub fn snapshot(params: &[Parameter]) -> Self {
        let mut initial_weights = HashMap::new();
        for (i, param) in params.iter().enumerate() {
            let key = if param.name().is_empty() {
                format!("param_{}", i)
            } else {
                param.name().to_string()
            };
            initial_weights.insert(key, param.data());
        }
        Self { initial_weights }
    }

    /// Returns the number of saved tensors.
    pub fn num_saved(&self) -> usize {
        self.initial_weights.len()
    }

    /// Restores every matching parameter to its saved initial data.
    pub fn rewind(&self, params: &[Parameter]) {
        for (i, param) in params.iter().enumerate() {
            let key = if param.name().is_empty() {
                format!("param_{}", i)
            } else {
                param.name().to_string()
            };
            if let Some(initial) = self.initial_weights.get(&key) {
                param.update_data(initial.clone());
            }
        }
    }

    /// Restores parameters to their initial values where `mask > 0.5` and
    /// zeroes them elsewhere.
    pub fn rewind_with_mask(&self, params: &[Parameter], masks: &[Tensor<f32>]) {
        assert_eq!(
            params.len(),
            masks.len(),
            "Number of parameters and masks must match"
        );

        for (i, (param, mask)) in params.iter().zip(masks.iter()).enumerate() {
            let key = if param.name().is_empty() {
                format!("param_{}", i)
            } else {
                param.name().to_string()
            };

            if let Some(initial) = self.initial_weights.get(&key) {
                let init_vec = initial.to_vec();
                let mask_vec = mask.to_vec();

                let rewound: Vec<f32> = init_vec
                    .iter()
                    .zip(mask_vec.iter())
                    .map(|(&w, &m)| if m > 0.5 { w } else { 0.0 })
                    .collect();

                let shape = param.shape();
                let new_data = Tensor::from_vec(rewound, &shape).expect("tensor creation failed");
                param.update_data(new_data);
            }
        }
    }
}

impl std::fmt::Debug for LotteryTicket {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LotteryTicket")
            .field("num_saved", &self.initial_weights.len())
            .finish()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sparse_linear_creation_structured() {
        let layer = SparseLinear::new(10, 5);
        assert_eq!(layer.in_features(), 10);
        assert_eq!(layer.out_features(), 5);
        assert!(layer.is_structured());
        assert!(layer.bias.is_some());
    }

    #[test]
    fn test_sparse_linear_creation_unstructured() {
        let layer = SparseLinear::unstructured(10, 5);
        assert_eq!(layer.in_features(), 10);
        assert_eq!(layer.out_features(), 5);
        assert!(!layer.is_structured());
        assert!(layer.bias.is_some());
    }

    #[test]
    fn test_sparse_linear_no_bias() {
        let layer = SparseLinear::with_bias(10, 5, false);
        assert!(layer.bias.is_none());
    }

    #[test]
    fn test_sparse_linear_forward_shape() {
        let layer = SparseLinear::new(4, 3);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 4]).expect("tensor creation failed"),
            false,
        );
        let output = layer.forward(&input);
        assert_eq!(output.shape(), vec![1, 3]);
    }

    #[test]
    fn test_sparse_linear_forward_batch() {
        let layer = SparseLinear::new(4, 3);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 12], &[3, 4]).expect("tensor creation failed"),
            false,
        );
        let output = layer.forward(&input);
        assert_eq!(output.shape(), vec![3, 3]);
    }

    #[test]
    fn test_sparse_linear_forward_no_bias() {
        let layer = SparseLinear::with_bias(4, 3, false);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 8], &[2, 4]).expect("tensor creation failed"),
            false,
        );
        let output = layer.forward(&input);
        assert_eq!(output.shape(), vec![2, 3]);
    }
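
    // An illustrative addition, not part of the original suite: a sketch
    // exercising the rank-3 path in `forward`, where leading batch dimensions
    // are flattened for the matmul and restored afterwards. Assumes
    // `Tensor::from_vec` accepts rank-3 shapes.
    #[test]
    fn test_sparse_linear_forward_3d_input() {
        let layer = SparseLinear::new(4, 3);
        // Input of shape [2, 3, 4]: two batches of three 4-dimensional rows.
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 24], &[2, 3, 4]).expect("tensor creation failed"),
            false,
        );
        let output = layer.forward(&input);
        // The leading [2, 3] batch dimensions should be preserved.
        assert_eq!(output.shape(), vec![2, 3, 3]);
    }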

    #[test]
    fn test_sparse_linear_density_initial() {
        let layer = SparseLinear::new(100, 50);
        let density = layer.density();
        assert!(
            density > 0.9,
            "Initial density should be high, got {}",
            density
        );
    }

    #[test]
    fn test_sparse_linear_sparsity_initial() {
        let layer = SparseLinear::new(100, 50);
        let sparsity = layer.sparsity();
        assert!(
            sparsity < 0.1,
            "Initial sparsity should be low, got {}",
            sparsity
        );
        assert!((layer.density() + layer.sparsity() - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_sparse_linear_num_active() {
        let layer = SparseLinear::new(10, 5);
        let active = layer.num_active();
        let total = 10 * 5;
        assert!(active <= total);
        assert!(active > 0);
    }
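
    // A small consistency sketch (an illustrative addition): `density` and
    // `num_active` are both derived from the same hard mask, so the counted
    // fraction should equal the reported density.
    #[test]
    fn test_sparse_linear_num_active_matches_density() {
        let layer = SparseLinear::new(10, 5);
        let fraction = layer.num_active() as f32 / 50.0;
        assert!((fraction - layer.density()).abs() < 1e-6);
    }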

    #[test]
    fn test_sparse_linear_high_threshold_more_sparsity() {
        let mut layer = SparseLinear::new(100, 50);
        let density_low_thresh = layer.density();

        layer.reset_threshold(10.0);
        let density_high_thresh = layer.density();

        assert!(
            density_high_thresh < density_low_thresh,
            "Higher threshold should reduce density: low_thresh={}, high_thresh={}",
            density_low_thresh,
            density_high_thresh
        );
    }

    #[test]
    fn test_sparse_linear_low_threshold_dense() {
        let mut layer = SparseLinear::new(100, 50);
        layer.reset_threshold(0.0);
        let density = layer.density();
        assert!(
            (density - 1.0).abs() < 1e-6,
            "Zero threshold should give density=1.0, got {}",
            density
        );
    }

    #[test]
    fn test_sparse_linear_soft_mask_values_in_range() {
        let layer = SparseLinear::new(10, 5);
        let weight_var = layer.weight.variable();
        let mask = layer.compute_soft_mask(&weight_var);
        let mask_vec = mask.data().to_vec();

        for &v in &mask_vec {
            assert!(v >= 0.0 && v <= 1.0, "Soft mask value {} not in [0, 1]", v);
        }
    }

    #[test]
    fn test_sparse_linear_hard_prune() {
        let mut layer = SparseLinear::new(10, 5);
        layer.reset_threshold(0.5);

        let pre_prune_density = layer.density();
        layer.hard_prune();

        let weight_data = layer.weight.data();
        let w_vec = weight_data.to_vec();
        let zeros_count = w_vec.iter().filter(|&&v| v == 0.0).count();

        let expected_zeros = ((1.0 - pre_prune_density) * (10 * 5) as f32).round() as usize;
        assert_eq!(
            zeros_count, expected_zeros,
            "Hard prune should zero out pruned weights"
        );
    }

    #[test]
    fn test_sparse_linear_hard_prune_threshold_reset() {
        let mut layer = SparseLinear::new(10, 5);
        layer.reset_threshold(0.5);
        layer.hard_prune();

        let t_vec = layer.threshold.data().to_vec();
        assert!(
            t_vec.iter().all(|&v| v == 0.0),
            "Thresholds should be zero after hard_prune"
        );
    }

    #[test]
    fn test_sparse_linear_effective_weight() {
        let layer = SparseLinear::new(10, 5);
        let ew = layer.effective_weight();
        assert_eq!(ew.shape(), &[5, 10]);
    }

    #[test]
    fn test_sparse_linear_effective_weight_matches_hard_prune() {
        let mut layer = SparseLinear::new(10, 5);
        layer.reset_threshold(0.3);

        let effective = layer.effective_weight();
        layer.hard_prune();
        let pruned = layer.weight.data();

        let e_vec = effective.to_vec();
        let p_vec = pruned.to_vec();
        for (e, p) in e_vec.iter().zip(p_vec.iter()) {
            assert!(
                (e - p).abs() < 1e-6,
                "effective_weight and hard_prune should match"
            );
        }
    }

    #[test]
    fn test_sparse_linear_parameters_include_threshold() {
        let layer = SparseLinear::new(10, 5);
        let params = layer.parameters();
        assert_eq!(params.len(), 3);

        let named = layer.named_parameters();
        assert!(named.contains_key("threshold"));
        assert!(named.contains_key("weight"));
        assert!(named.contains_key("bias"));
    }

    #[test]
    fn test_sparse_linear_parameters_no_bias() {
        let layer = SparseLinear::with_bias(10, 5, false);
        let params = layer.parameters();
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_sparse_linear_module_name() {
        let layer = SparseLinear::new(10, 5);
        assert_eq!(layer.name(), "SparseLinear");
    }

    #[test]
    fn test_sparse_linear_debug() {
        let layer = SparseLinear::new(10, 5);
        let debug_str = format!("{:?}", layer);
        assert!(debug_str.contains("SparseLinear"));
        assert!(debug_str.contains("in_features: 10"));
        assert!(debug_str.contains("out_features: 5"));
    }

    #[test]
    fn test_sparse_linear_reset_threshold() {
        let mut layer = SparseLinear::new(10, 5);
        layer.reset_threshold(0.5);
        let t_vec = layer.threshold.data().to_vec();
        assert!(t_vec.iter().all(|&v| (v - 0.5).abs() < 1e-6));
    }

    #[test]
    fn test_sparse_linear_unstructured_threshold_shape() {
        let layer = SparseLinear::unstructured(10, 5);
        assert_eq!(layer.threshold.shape(), vec![5, 10]);
    }

    #[test]
    fn test_sparse_linear_structured_threshold_shape() {
        let layer = SparseLinear::new(10, 5);
        assert_eq!(layer.threshold.shape(), vec![5]);
    }

    #[test]
    fn test_sparse_linear_unstructured_forward() {
        let layer = SparseLinear::unstructured(4, 3);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[2, 4])
                .expect("tensor creation failed"),
            false,
        );
        let output = layer.forward(&input);
        assert_eq!(output.shape(), vec![2, 3]);
    }
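
    // A worked sketch of the soft-mask formula sigmoid((|w| - t) * TEMPERATURE),
    // added for illustration. With TEMPERATURE = 10 and t = 0.5, a weight
    // exactly at the threshold maps to 0.5, a weight one unit above saturates
    // toward 1, and a weight half a unit below saturates toward 0.
    #[test]
    fn test_sparse_linear_soft_mask_known_values() {
        let mut layer = SparseLinear::new(2, 2);
        layer.reset_threshold(0.5);
        let weights =
            Tensor::from_vec(vec![0.5, 1.5, 0.0, 0.5], &[2, 2]).expect("tensor creation failed");
        layer.weight.update_data(weights);

        let weight_var = layer.weight.variable();
        let mask = layer.compute_soft_mask(&weight_var);
        let m = mask.data().to_vec();

        assert!((m[0] - 0.5).abs() < 1e-6); // |w| == t        -> sigmoid(0)  = 0.5
        assert!(m[1] > 0.999); // |w| - t = 1.0   -> sigmoid(10) ~ 1
        assert!(m[2] < 0.01); // |w| - t = -0.5  -> sigmoid(-5) ~ 0.0067
        assert!((m[3] - 0.5).abs() < 1e-6);
    }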

    #[test]
    fn test_group_sparsity_creation() {
        let reg = GroupSparsity::new(0.001, 10);
        assert!((reg.lambda() - 0.001).abs() < 1e-8);
        assert_eq!(reg.group_size(), 10);
    }

    #[test]
    fn test_group_sparsity_penalty_non_negative() {
        let reg = GroupSparsity::new(0.01, 4);
        let weight = Variable::new(
            Tensor::from_vec(vec![1.0, -2.0, 3.0, -4.0, 5.0, -6.0, 7.0, -8.0], &[2, 4])
                .expect("tensor creation failed"),
            true,
        );
        let penalty = reg.penalty(&weight);
        let penalty_val = penalty.data().to_vec()[0];
        assert!(
            penalty_val >= 0.0,
            "Penalty should be non-negative, got {}",
            penalty_val
        );
    }

    #[test]
    fn test_group_sparsity_zero_weights_zero_penalty() {
        let reg = GroupSparsity::new(0.01, 4);
        let weight = Variable::new(
            Tensor::from_vec(vec![0.0; 8], &[2, 4]).expect("tensor creation failed"),
            true,
        );
        let penalty = reg.penalty(&weight);
        let penalty_val = penalty.data().to_vec()[0];
        assert!(
            penalty_val.abs() < 1e-6,
            "Zero weights should give zero penalty, got {}",
            penalty_val
        );
    }

    #[test]
    fn test_group_sparsity_scales_with_lambda() {
        let reg_small = GroupSparsity::new(0.001, 4);
        let reg_large = GroupSparsity::new(0.01, 4);
        let weight = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 4]).expect("tensor creation failed"),
            true,
        );

        let penalty_small = reg_small.penalty(&weight).data().to_vec()[0];
        let penalty_large = reg_large.penalty(&weight).data().to_vec()[0];

        assert!(
            penalty_large > penalty_small,
            "Larger lambda should give larger penalty: small={}, large={}",
            penalty_small,
            penalty_large
        );

        let ratio = penalty_large / penalty_small;
        assert!(
            (ratio - 10.0).abs() < 1e-4,
            "Penalty should scale linearly with lambda, ratio={}",
            ratio
        );
    }

    #[test]
    fn test_group_sparsity_debug() {
        let reg = GroupSparsity::new(0.001, 10);
        let debug_str = format!("{:?}", reg);
        assert!(debug_str.contains("GroupSparsity"));
        assert!(debug_str.contains("lambda"));
    }

    #[test]
    #[should_panic(expected = "group_size must be positive")]
    fn test_group_sparsity_zero_group_size_panics() {
        let _reg = GroupSparsity::new(0.01, 0);
    }
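
    // A worked-numbers sketch of the group-lasso arithmetic (an illustrative
    // addition): with lambda = 0.1 and group_size = 2, the weights
    // [3, 4, 0, 0] split into groups [3, 4] and [0, 0] with L2 norms 5 and 0,
    // so the penalty is 0.1 * (5 + 0) = 0.5.
    #[test]
    fn test_group_sparsity_known_value() {
        let reg = GroupSparsity::new(0.1, 2);
        let weight = Variable::new(
            Tensor::from_vec(vec![3.0, 4.0, 0.0, 0.0], &[2, 2]).expect("tensor creation failed"),
            true,
        );
        let penalty_val = reg.penalty(&weight).data().to_vec()[0];
        assert!((penalty_val - 0.5).abs() < 1e-6);
    }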

    #[test]
    fn test_lottery_ticket_snapshot() {
        let layer = SparseLinear::new(10, 5);
        let params = layer.parameters();
        let ticket = LotteryTicket::snapshot(&params);
        assert_eq!(ticket.num_saved(), params.len());
    }

    #[test]
    fn test_lottery_ticket_rewind() {
        let layer = SparseLinear::new(10, 5);
        let params = layer.parameters();
        let initial_weight = params[0].data().to_vec();

        let ticket = LotteryTicket::snapshot(&params);

        let new_data = Tensor::from_vec(vec![99.0; 50], &[5, 10]).expect("tensor creation failed");
        params[0].update_data(new_data);

        let modified_weight = params[0].data().to_vec();
        assert_ne!(modified_weight, initial_weight);

        ticket.rewind(&params);

        let rewound_weight = params[0].data().to_vec();
        assert_eq!(rewound_weight, initial_weight);
    }
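
    // A sketch of the fallback key path in `snapshot`/`rewind`: parameters
    // whose name is empty are keyed as `param_{index}`. This test is an
    // illustrative addition and assumes `Parameter::named` accepts an empty
    // name string.
    #[test]
    fn test_lottery_ticket_unnamed_param_fallback_key() {
        let data = Tensor::from_vec(vec![1.0, 2.0], &[2]).expect("tensor creation failed");
        let params = vec![Parameter::named("", data, true)];

        let ticket = LotteryTicket::snapshot(&params);
        assert_eq!(ticket.num_saved(), 1);

        let new_data = Tensor::from_vec(vec![9.0, 9.0], &[2]).expect("tensor creation failed");
        params[0].update_data(new_data);

        // Rewind should find the saved tensor under the `param_0` fallback key.
        ticket.rewind(&params);
        assert_eq!(params[0].data().to_vec(), vec![1.0, 2.0]);
    }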

    #[test]
    fn test_lottery_ticket_rewind_preserves_shapes() {
        let layer = SparseLinear::new(10, 5);
        let params = layer.parameters();
        let initial_shapes: Vec<Vec<usize>> = params.iter().map(|p| p.shape()).collect();

        let ticket = LotteryTicket::snapshot(&params);

        let new_data = Tensor::from_vec(vec![0.0; 50], &[5, 10]).expect("tensor creation failed");
        params[0].update_data(new_data);

        ticket.rewind(&params);

        let rewound_shapes: Vec<Vec<usize>> = params.iter().map(|p| p.shape()).collect();
        assert_eq!(initial_shapes, rewound_shapes);
    }

    #[test]
    fn test_lottery_ticket_rewind_with_mask() {
        let data =
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).expect("tensor creation failed");
        let param = Parameter::named("weight", data, true);
        let params = vec![param];

        let ticket = LotteryTicket::snapshot(&params);

        let new_data = Tensor::from_vec(vec![10.0, 20.0, 30.0, 40.0], &[2, 2])
            .expect("tensor creation failed");
        params[0].update_data(new_data);

        let mask =
            Tensor::from_vec(vec![1.0, 1.0, 0.0, 0.0], &[2, 2]).expect("tensor creation failed");
        ticket.rewind_with_mask(&params, &[mask]);

        let result = params[0].data().to_vec();
        assert_eq!(
            result,
            vec![1.0, 2.0, 0.0, 0.0],
            "Masked weights should be zero, unmasked should be initial values"
        );
    }

    #[test]
    fn test_lottery_ticket_debug() {
        let layer = SparseLinear::new(10, 5);
        let ticket = LotteryTicket::snapshot(&layer.parameters());
        let debug_str = format!("{:?}", ticket);
        assert!(debug_str.contains("LotteryTicket"));
        assert!(debug_str.contains("num_saved"));
    }
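
    // A sketch (illustrative addition) of the guard in `rewind_with_mask`:
    // passing a different number of masks than parameters should trip the
    // `assert_eq!` with the message below.
    #[test]
    #[should_panic(expected = "Number of parameters and masks must match")]
    fn test_lottery_ticket_rewind_with_mask_length_mismatch_panics() {
        let data = Tensor::from_vec(vec![1.0, 2.0], &[2]).expect("tensor creation failed");
        let params = vec![Parameter::named("weight", data, true)];
        let ticket = LotteryTicket::snapshot(&params);

        // One parameter but zero masks: lengths differ, so this panics.
        ticket.rewind_with_mask(&params, &[]);
    }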

    #[test]
    fn test_integration_sparse_linear_with_group_sparsity() {
        let layer = SparseLinear::new(8, 4);

        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 16], &[2, 8]).expect("tensor creation failed"),
            false,
        );
        let output = layer.forward(&input);
        assert_eq!(output.shape(), vec![2, 4]);

        let reg = GroupSparsity::new(0.001, 8);
        let weight_var = layer.weight.variable();
        let penalty = reg.penalty(&weight_var);
        let penalty_val = penalty.data().to_vec()[0];
        assert!(
            penalty_val > 0.0,
            "Penalty should be positive for non-zero weights"
        );
    }

    #[test]
    fn test_integration_lottery_ticket_with_pruning() {
        let mut layer = SparseLinear::new(8, 4);
        let ticket = LotteryTicket::snapshot(&layer.parameters());

        let new_weight = Tensor::from_vec(vec![0.5; 32], &[4, 8]).expect("tensor creation failed");
        layer.weight.update_data(new_weight);

        layer.reset_threshold(0.3);

        let mask = layer.hard_mask();

        let weight_param = vec![layer.weight.clone()];
        ticket.rewind_with_mask(&weight_param, &[mask]);

        assert_eq!(layer.weight.shape(), vec![4, 8]);
    }

    #[test]
    fn test_num_parameters_sparse_linear() {
        // weight (5 * 10) + bias (5) + threshold (5) = 60.
        let layer = SparseLinear::new(10, 5);
        assert_eq!(layer.num_parameters(), 60);
    }
}