1use super::{PruningConfig, PruningSchedule, PruningStats};
8use crate::error::{MlError, Result};
9use std::cmp::Ordering;
10use tracing::{debug, info, warn};
11
/// A named weight tensor stored as a flat `f32` buffer plus its logical shape.
#[derive(Debug, Clone)]
pub struct WeightTensor {
    /// Flattened weight values; length should equal the product of `shape`
    /// (enforced by `validate`, not by construction).
    pub data: Vec<f32>,
    /// Logical dimensions of the tensor.
    pub shape: Vec<usize>,
    /// Identifier used for logging and mask association.
    pub name: String,
}
25
26impl WeightTensor {
27 #[must_use]
29 pub fn new(data: Vec<f32>, shape: Vec<usize>, name: String) -> Self {
30 Self { data, shape, name }
31 }
32
33 #[must_use]
35 pub fn numel(&self) -> usize {
36 self.data.len()
37 }
38
39 #[must_use]
41 pub fn is_empty(&self) -> bool {
42 self.data.is_empty()
43 }
44
45 pub fn validate(&self) -> Result<()> {
50 let expected_len: usize = self.shape.iter().product();
51 if expected_len != self.data.len() {
52 return Err(MlError::InvalidConfig(format!(
53 "Shape {:?} expects {} elements but got {}",
54 self.shape,
55 expected_len,
56 self.data.len()
57 )));
58 }
59 Ok(())
60 }
61
62 #[must_use]
64 pub fn sparsity(&self) -> f32 {
65 if self.data.is_empty() {
66 return 0.0;
67 }
68 let zero_count = self.data.iter().filter(|&&w| w == 0.0).count();
69 zero_count as f32 / self.data.len() as f32
70 }
71
72 #[must_use]
74 pub fn l1_norm(&self) -> f32 {
75 self.data.iter().map(|w| w.abs()).sum()
76 }
77
78 #[must_use]
80 pub fn l2_norm(&self) -> f32 {
81 self.data.iter().map(|w| w * w).sum::<f32>().sqrt()
82 }
83
84 #[must_use]
86 pub fn statistics(&self) -> WeightStatistics {
87 if self.data.is_empty() {
88 return WeightStatistics::default();
89 }
90
91 let mut sorted = self.data.clone();
92 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
93
94 let min = sorted.first().copied().unwrap_or(0.0);
95 let max = sorted.last().copied().unwrap_or(0.0);
96 let mean = self.data.iter().sum::<f32>() / self.data.len() as f32;
97
98 let variance =
99 self.data.iter().map(|w| (w - mean).powi(2)).sum::<f32>() / self.data.len() as f32;
100 let std = variance.sqrt();
101
102 let median_idx = self.data.len() / 2;
103 let median = if self.data.len() % 2 == 0 {
104 (sorted
105 .get(median_idx.saturating_sub(1))
106 .copied()
107 .unwrap_or(0.0)
108 + sorted.get(median_idx).copied().unwrap_or(0.0))
109 / 2.0
110 } else {
111 sorted.get(median_idx).copied().unwrap_or(0.0)
112 };
113
114 WeightStatistics {
115 min,
116 max,
117 mean,
118 std,
119 median,
120 sparsity: self.sparsity(),
121 }
122 }
123}
124
/// Summary statistics computed over a weight tensor's values.
#[derive(Debug, Clone, Default)]
pub struct WeightStatistics {
    /// Smallest weight value.
    pub min: f32,
    /// Largest weight value.
    pub max: f32,
    /// Arithmetic mean of all weights.
    pub mean: f32,
    /// Population standard deviation.
    pub std: f32,
    /// Median weight value.
    pub median: f32,
    /// Fraction of weights that are exactly zero.
    pub sparsity: f32,
}
141
/// Boolean keep/prune mask aligned element-wise with a weight tensor.
#[derive(Debug, Clone)]
pub struct PruningMask {
    /// `true` keeps the corresponding weight, `false` prunes (zeroes) it.
    pub mask: Vec<bool>,
    /// Shape of the tensor this mask applies to.
    pub shape: Vec<usize>,
    /// Optional name of the associated tensor.
    pub name: Option<String>,
}
155
156impl PruningMask {
157 #[must_use]
159 pub fn new(mask: Vec<bool>, shape: Vec<usize>) -> Self {
160 Self {
161 mask,
162 shape,
163 name: None,
164 }
165 }
166
167 #[must_use]
169 pub fn with_name(mask: Vec<bool>, shape: Vec<usize>, name: String) -> Self {
170 Self {
171 mask,
172 shape,
173 name: Some(name),
174 }
175 }
176
177 #[must_use]
179 pub fn ones(shape: &[usize]) -> Self {
180 let size: usize = shape.iter().product();
181 Self {
182 mask: vec![true; size],
183 shape: shape.to_vec(),
184 name: None,
185 }
186 }
187
188 #[must_use]
190 pub fn zeros(shape: &[usize]) -> Self {
191 let size: usize = shape.iter().product();
192 Self {
193 mask: vec![false; size],
194 shape: shape.to_vec(),
195 name: None,
196 }
197 }
198
199 #[must_use]
201 pub fn numel(&self) -> usize {
202 self.mask.len()
203 }
204
205 #[must_use]
207 pub fn num_kept(&self) -> usize {
208 self.mask.iter().filter(|&&m| m).count()
209 }
210
211 #[must_use]
213 pub fn num_pruned(&self) -> usize {
214 self.mask.iter().filter(|&&m| !m).count()
215 }
216
217 #[must_use]
219 pub fn sparsity(&self) -> f32 {
220 if self.mask.is_empty() {
221 return 0.0;
222 }
223 self.num_pruned() as f32 / self.mask.len() as f32
224 }
225
226 pub fn and(&self, other: &PruningMask) -> Result<PruningMask> {
231 if self.mask.len() != other.mask.len() {
232 return Err(MlError::InvalidConfig(format!(
233 "Mask sizes don't match: {} vs {}",
234 self.mask.len(),
235 other.mask.len()
236 )));
237 }
238
239 let combined: Vec<bool> = self
240 .mask
241 .iter()
242 .zip(other.mask.iter())
243 .map(|(&a, &b)| a && b)
244 .collect();
245
246 Ok(PruningMask::new(combined, self.shape.clone()))
247 }
248
249 pub fn or(&self, other: &PruningMask) -> Result<PruningMask> {
254 if self.mask.len() != other.mask.len() {
255 return Err(MlError::InvalidConfig(format!(
256 "Mask sizes don't match: {} vs {}",
257 self.mask.len(),
258 other.mask.len()
259 )));
260 }
261
262 let combined: Vec<bool> = self
263 .mask
264 .iter()
265 .zip(other.mask.iter())
266 .map(|(&a, &b)| a || b)
267 .collect();
268
269 Ok(PruningMask::new(combined, self.shape.clone()))
270 }
271
272 #[must_use]
274 pub fn invert(&self) -> PruningMask {
275 PruningMask::new(self.mask.iter().map(|&m| !m).collect(), self.shape.clone())
276 }
277
278 pub fn apply(&self, weights: &WeightTensor) -> Result<WeightTensor> {
283 if self.mask.len() != weights.data.len() {
284 return Err(MlError::InvalidConfig(format!(
285 "Mask size {} doesn't match weight size {}",
286 self.mask.len(),
287 weights.data.len()
288 )));
289 }
290
291 let pruned_data: Vec<f32> = weights
292 .data
293 .iter()
294 .zip(self.mask.iter())
295 .map(|(&w, &keep)| if keep { w } else { 0.0 })
296 .collect();
297
298 Ok(WeightTensor::new(
299 pruned_data,
300 weights.shape.clone(),
301 weights.name.clone(),
302 ))
303 }
304}
305
/// Strategy used to score the importance of individual weights.
///
/// All payloads are integral, so `Eq`/`Hash` are derivable alongside
/// `PartialEq` (avoids clippy's `derive_partial_eq_without_eq`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum ImportanceMethod {
    /// Absolute magnitude |w| (the default).
    #[default]
    L1Norm,
    /// Squared magnitude w².
    L2Norm,
    /// |w · g| using supplied gradients; falls back to L1 without them.
    GradientWeighted,
    /// First-order Taylor score |w · g|, or |w · g · a| when activations
    /// of matching length are provided.
    TaylorExpansion,
    /// Uniform pseudo-random scores from a seeded internal generator
    /// (useful as a baseline/ablation).
    Random {
        /// Seed for the internal pseudo-random generator.
        seed: u64,
    },
    /// Movement-style scoring; currently approximated by magnitude in this
    /// module (proper movement pruning needs weight-change history).
    Movement,
    /// Diagonal Fisher information approximation g².
    Fisher,
}
328
/// Gradient (and optionally activation) data used by gradient-aware scoring.
#[derive(Debug, Clone)]
pub struct GradientInfo {
    /// Per-weight gradients, aligned with the tensor's flat `data` buffer.
    pub gradients: Vec<f32>,
    /// Optional per-weight activations for Taylor-style scoring.
    pub activations: Option<Vec<f32>>,
}
337
338impl GradientInfo {
339 #[must_use]
341 pub fn new(gradients: Vec<f32>) -> Self {
342 Self {
343 gradients,
344 activations: None,
345 }
346 }
347
348 #[must_use]
350 pub fn with_activations(gradients: Vec<f32>, activations: Vec<f32>) -> Self {
351 Self {
352 gradients,
353 activations: Some(activations),
354 }
355 }
356}
357
/// Bookkeeping for Lottery Ticket Hypothesis style weight rewinding.
#[derive(Debug, Clone)]
pub struct LotteryTicketState {
    /// Snapshot of the weights taken when tracking was enabled.
    pub initial_weights: Vec<WeightTensor>,
    /// Masks from the most recent pruning round (one per layer).
    pub masks: Vec<PruningMask>,
    /// Number of completed prune/update rounds.
    pub iteration: usize,
    /// Overall sparsity recorded after each round.
    pub sparsity_history: Vec<f32>,
    /// Whether lottery-ticket tracking is active.
    pub enabled: bool,
}
376
377impl LotteryTicketState {
378 #[must_use]
380 pub fn new(initial_weights: Vec<WeightTensor>) -> Self {
381 let num_layers = initial_weights.len();
382 Self {
383 initial_weights,
384 masks: Vec::with_capacity(num_layers),
385 iteration: 0,
386 sparsity_history: Vec::new(),
387 enabled: true,
388 }
389 }
390
391 pub fn rewind(&self) -> Vec<WeightTensor> {
395 if self.masks.is_empty() {
396 return self.initial_weights.clone();
397 }
398
399 self.initial_weights
400 .iter()
401 .zip(self.masks.iter())
402 .map(|(weights, mask)| {
403 mask.apply(weights).unwrap_or_else(|_| weights.clone())
405 })
406 .collect()
407 }
408
409 pub fn update_masks(&mut self, new_masks: Vec<PruningMask>, sparsity: f32) {
411 self.masks = new_masks;
412 self.iteration += 1;
413 self.sparsity_history.push(sparsity);
414 }
415
416 #[must_use]
418 pub fn current_sparsity(&self) -> f32 {
419 if self.masks.is_empty() {
420 return 0.0;
421 }
422
423 let total_pruned: usize = self.masks.iter().map(|m| m.num_pruned()).sum();
424 let total_elements: usize = self.masks.iter().map(|m| m.numel()).sum();
425
426 if total_elements == 0 {
427 0.0
428 } else {
429 total_pruned as f32 / total_elements as f32
430 }
431 }
432}
433
/// How a keep/prune mask is derived from importance scores.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum MaskCreationMode {
    /// Keep weights with importance at or above the given cutoff.
    Threshold(f32),
    /// Prune this fraction of weights, ranked across all tensors together.
    GlobalPercentage(f32),
    /// Prune this fraction of weights independently within each tensor.
    LayerWisePercentage(f32),
    /// Keep only the `k` highest-importance weights overall.
    TopK(usize),
    /// Keep only the `k` highest-importance weights within each tensor.
    TopKPerLayer(usize),
}
448
449impl Default for MaskCreationMode {
450 fn default() -> Self {
451 Self::GlobalPercentage(0.5)
452 }
453}
454
/// Hook invoked between pruning iterations to recover accuracy.
///
/// Implementors receive the freshly pruned weights plus the masks that
/// produced them and return the (fine-tuned) weights for the next round.
pub trait FineTuneCallback: Send + Sync {
    /// Fine-tunes `weights` under the given `masks`.
    ///
    /// `iteration` is the zero-based pruning round and `sparsity` the
    /// overall fraction of weights currently pruned.
    ///
    /// # Errors
    /// Propagates any training failure from the implementation.
    fn fine_tune(
        &mut self,
        weights: Vec<WeightTensor>,
        masks: &[PruningMask],
        iteration: usize,
        sparsity: f32,
    ) -> Result<Vec<WeightTensor>>;

    /// Number of fine-tuning epochs the implementation runs per call.
    fn epochs(&self) -> usize;
}
480
/// A `FineTuneCallback` that performs no training at all.
pub struct NoOpFineTune;

impl FineTuneCallback for NoOpFineTune {
    /// Returns the pruned weights unchanged.
    fn fine_tune(
        &mut self,
        weights: Vec<WeightTensor>,
        _masks: &[PruningMask],
        _iteration: usize,
        _sparsity: f32,
    ) -> Result<Vec<WeightTensor>> {
        Ok(weights)
    }

    /// No epochs are run.
    fn epochs(&self) -> usize {
        0
    }
}
499
/// Magnitude/gradient-based unstructured pruner over flat weight tensors.
///
/// Holds the pruning configuration, the importance-scoring strategy, the
/// most recently produced masks, and optional lottery-ticket bookkeeping.
pub struct UnstructuredPruner {
    /// User-supplied pruning configuration (target sparsity, schedule, ...).
    config: PruningConfig,
    /// Strategy used to score each weight's importance.
    importance_method: ImportanceMethod,
    /// Masks from the most recent pruning call (one per tensor).
    masks: Vec<PruningMask>,
    /// Present only while lottery-ticket tracking is enabled.
    lottery_ticket_state: Option<LotteryTicketState>,
    /// How masks are derived from importance scores.
    mask_mode: MaskCreationMode,
    /// Number of completed rounds in `iterative_prune`.
    current_iteration: usize,
    /// State of the internal LCG used by `ImportanceMethod::Random`.
    rng_state: u64,
}
521
522impl UnstructuredPruner {
523 #[must_use]
525 pub fn new(config: PruningConfig, importance_method: ImportanceMethod) -> Self {
526 let seed = match importance_method {
527 ImportanceMethod::Random { seed } => seed,
528 _ => 42,
529 };
530 let sparsity_target = config.sparsity_target;
531
532 Self {
533 config,
534 importance_method,
535 masks: Vec::new(),
536 lottery_ticket_state: None,
537 mask_mode: MaskCreationMode::GlobalPercentage(sparsity_target),
538 current_iteration: 0,
539 rng_state: seed,
540 }
541 }
542
543 #[must_use]
545 pub fn with_mask_mode(mut self, mode: MaskCreationMode) -> Self {
546 self.mask_mode = mode;
547 self
548 }
549
550 pub fn enable_lottery_ticket(&mut self, initial_weights: Vec<WeightTensor>) {
552 self.lottery_ticket_state = Some(LotteryTicketState::new(initial_weights));
553 }
554
555 pub fn disable_lottery_ticket(&mut self) {
557 self.lottery_ticket_state = None;
558 }
559
560 #[must_use]
562 pub fn masks(&self) -> &[PruningMask] {
563 &self.masks
564 }
565
566 #[must_use]
568 pub fn current_iteration(&self) -> usize {
569 self.current_iteration
570 }
571
572 #[must_use]
574 pub fn lottery_ticket_state(&self) -> Option<&LotteryTicketState> {
575 self.lottery_ticket_state.as_ref()
576 }
577
578 #[must_use]
580 pub fn rewind_to_initial(&self) -> Option<Vec<WeightTensor>> {
581 self.lottery_ticket_state
582 .as_ref()
583 .map(|state| state.rewind())
584 }
585
586 fn next_random(&mut self) -> f32 {
588 self.rng_state = self.rng_state.wrapping_mul(1103515245).wrapping_add(12345) % (1u64 << 31);
590 (self.rng_state as f32) / ((1u64 << 31) as f32)
591 }
592
593 pub fn compute_importance(
602 &mut self,
603 weights: &WeightTensor,
604 gradient_info: Option<&GradientInfo>,
605 ) -> Vec<f32> {
606 match self.importance_method {
607 ImportanceMethod::L1Norm => weights.data.iter().map(|w| w.abs()).collect(),
608 ImportanceMethod::L2Norm => weights.data.iter().map(|w| w * w).collect(),
609 ImportanceMethod::GradientWeighted => {
610 if let Some(info) = gradient_info {
611 if info.gradients.len() == weights.data.len() {
612 weights
613 .data
614 .iter()
615 .zip(info.gradients.iter())
616 .map(|(w, g)| (w * g).abs())
617 .collect()
618 } else {
619 warn!(
620 "Gradient size mismatch, falling back to L1 norm. \
621 Weights: {}, Gradients: {}",
622 weights.data.len(),
623 info.gradients.len()
624 );
625 weights.data.iter().map(|w| w.abs()).collect()
626 }
627 } else {
628 warn!("No gradient info provided, falling back to L1 norm");
629 weights.data.iter().map(|w| w.abs()).collect()
630 }
631 }
632 ImportanceMethod::TaylorExpansion => {
633 if let Some(info) = gradient_info {
634 if info.gradients.len() == weights.data.len() {
635 if let Some(ref activations) = info.activations {
636 if activations.len() == weights.data.len() {
637 weights
639 .data
640 .iter()
641 .zip(info.gradients.iter())
642 .zip(activations.iter())
643 .map(|((w, g), a)| (w * g * a).abs())
644 .collect()
645 } else {
646 weights
648 .data
649 .iter()
650 .zip(info.gradients.iter())
651 .map(|(w, g)| (w * g).abs())
652 .collect()
653 }
654 } else {
655 weights
657 .data
658 .iter()
659 .zip(info.gradients.iter())
660 .map(|(w, g)| (w * g).abs())
661 .collect()
662 }
663 } else {
664 warn!("Gradient size mismatch, falling back to L1 norm");
665 weights.data.iter().map(|w| w.abs()).collect()
666 }
667 } else {
668 warn!("No gradient info for Taylor, falling back to L1 norm");
669 weights.data.iter().map(|w| w.abs()).collect()
670 }
671 }
672 ImportanceMethod::Random { .. } => (0..weights.data.len())
673 .map(|_| self.next_random())
674 .collect(),
675 ImportanceMethod::Movement => {
676 weights.data.iter().map(|w| w.abs()).collect()
679 }
680 ImportanceMethod::Fisher => {
681 if let Some(info) = gradient_info {
683 if info.gradients.len() == weights.data.len() {
684 info.gradients.iter().map(|g| g * g).collect()
685 } else {
686 warn!("Gradient size mismatch for Fisher, falling back to L1");
687 weights.data.iter().map(|w| w.abs()).collect()
688 }
689 } else {
690 warn!("No gradient info for Fisher, falling back to L1 norm");
691 weights.data.iter().map(|w| w.abs()).collect()
692 }
693 }
694 }
695 }
696
697 pub fn create_mask(&self, importance: &[f32], shape: &[usize]) -> PruningMask {
703 let num_weights = importance.len();
704 if num_weights == 0 {
705 return PruningMask::new(Vec::new(), shape.to_vec());
706 }
707
708 match self.mask_mode {
709 MaskCreationMode::Threshold(threshold) => {
710 let mask: Vec<bool> = importance.iter().map(|&s| s >= threshold).collect();
711 PruningMask::new(mask, shape.to_vec())
712 }
713 MaskCreationMode::GlobalPercentage(sparsity)
714 | MaskCreationMode::LayerWisePercentage(sparsity) => {
715 let num_to_prune =
716 ((num_weights as f32 * sparsity).round() as usize).min(num_weights);
717
718 let mut indexed: Vec<(usize, f32)> = importance
720 .iter()
721 .enumerate()
722 .map(|(i, &s)| (i, s))
723 .collect();
724
725 indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
727
728 let mut mask = vec![true; num_weights];
730 for (idx, _) in indexed.iter().take(num_to_prune) {
731 mask[*idx] = false;
732 }
733
734 PruningMask::new(mask, shape.to_vec())
735 }
736 MaskCreationMode::TopK(k) | MaskCreationMode::TopKPerLayer(k) => {
737 let num_to_keep = k.min(num_weights);
738
739 let mut indexed: Vec<(usize, f32)> = importance
741 .iter()
742 .enumerate()
743 .map(|(i, &s)| (i, s))
744 .collect();
745
746 indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
748
749 let mut mask = vec![false; num_weights];
751 for (idx, _) in indexed.iter().take(num_to_keep) {
752 mask[*idx] = true;
753 }
754
755 PruningMask::new(mask, shape.to_vec())
756 }
757 }
758 }
759
760 pub fn prune_tensor(&mut self, weights: &WeightTensor) -> Result<(WeightTensor, PruningMask)> {
771 self.prune_tensor_with_gradients(weights, None)
772 }
773
774 pub fn prune_tensor_with_gradients(
786 &mut self,
787 weights: &WeightTensor,
788 gradient_info: Option<&GradientInfo>,
789 ) -> Result<(WeightTensor, PruningMask)> {
790 weights.validate()?;
792
793 if weights.is_empty() {
794 return Ok((
795 weights.clone(),
796 PruningMask::new(Vec::new(), weights.shape.clone()),
797 ));
798 }
799
800 let importance = self.compute_importance(weights, gradient_info);
802
803 let mask = self.create_mask(&importance, &weights.shape);
805
806 let pruned = mask.apply(weights)?;
808
809 debug!(
810 "Pruned tensor '{}': {:.1}% sparsity ({} -> {} non-zero)",
811 weights.name,
812 mask.sparsity() * 100.0,
813 weights.numel(),
814 mask.num_kept()
815 );
816
817 Ok((pruned, mask))
818 }
819
820 pub fn prune_tensors_global(
834 &mut self,
835 tensors: &[WeightTensor],
836 ) -> Result<(Vec<WeightTensor>, Vec<PruningMask>)> {
837 self.prune_tensors_global_with_gradients(tensors, &[])
838 }
839
840 pub fn prune_tensors_global_with_gradients(
852 &mut self,
853 tensors: &[WeightTensor],
854 gradient_infos: &[GradientInfo],
855 ) -> Result<(Vec<WeightTensor>, Vec<PruningMask>)> {
856 if tensors.is_empty() {
857 return Ok((Vec::new(), Vec::new()));
858 }
859
860 for tensor in tensors {
862 tensor.validate()?;
863 }
864
865 let mut all_scores: Vec<(usize, usize, f32)> = Vec::new();
867
868 for (tensor_idx, tensor) in tensors.iter().enumerate() {
869 let gradient_info = gradient_infos.get(tensor_idx);
870 let importance = self.compute_importance(tensor, gradient_info);
871
872 for (elem_idx, &score) in importance.iter().enumerate() {
873 all_scores.push((tensor_idx, elem_idx, score));
874 }
875 }
876
877 all_scores.sort_by(|a, b| a.2.partial_cmp(&b.2).unwrap_or(Ordering::Equal));
879
880 let total_weights = all_scores.len();
882 let num_to_prune = match self.mask_mode {
883 MaskCreationMode::GlobalPercentage(sparsity) => {
884 ((total_weights as f32 * sparsity).round() as usize).min(total_weights)
885 }
886 MaskCreationMode::TopK(k) => total_weights.saturating_sub(k),
887 MaskCreationMode::Threshold(threshold) => {
888 all_scores.iter().filter(|(_, _, s)| *s < threshold).count()
889 }
890 MaskCreationMode::LayerWisePercentage(_) | MaskCreationMode::TopKPerLayer(_) => {
891 return self.prune_tensors_layerwise_with_gradients(tensors, gradient_infos);
893 }
894 };
895
896 let mut masks: Vec<Vec<bool>> = tensors.iter().map(|t| vec![true; t.data.len()]).collect();
898
899 for (tensor_idx, elem_idx, _) in all_scores.iter().take(num_to_prune) {
901 if let Some(mask) = masks.get_mut(*tensor_idx) {
902 if let Some(elem) = mask.get_mut(*elem_idx) {
903 *elem = false;
904 }
905 }
906 }
907
908 let mut result_tensors = Vec::with_capacity(tensors.len());
910 let mut result_masks = Vec::with_capacity(tensors.len());
911
912 for (tensor, mask_vec) in tensors.iter().zip(masks) {
913 let mask = PruningMask::with_name(mask_vec, tensor.shape.clone(), tensor.name.clone());
914 let pruned = mask.apply(tensor)?;
915 result_tensors.push(pruned);
916 result_masks.push(mask);
917 }
918
919 self.masks = result_masks.clone();
921
922 let overall_sparsity = self.current_sparsity();
924
925 if let Some(ref mut lts) = self.lottery_ticket_state {
927 lts.update_masks(result_masks.clone(), overall_sparsity);
928 }
929
930 info!(
931 "Global pruning complete: {:.1}% overall sparsity ({} tensors)",
932 overall_sparsity * 100.0,
933 tensors.len()
934 );
935
936 Ok((result_tensors, result_masks))
937 }
938
939 pub fn prune_tensors_layerwise(
952 &mut self,
953 tensors: &[WeightTensor],
954 ) -> Result<(Vec<WeightTensor>, Vec<PruningMask>)> {
955 self.prune_tensors_layerwise_with_gradients(tensors, &[])
956 }
957
958 pub fn prune_tensors_layerwise_with_gradients(
970 &mut self,
971 tensors: &[WeightTensor],
972 gradient_infos: &[GradientInfo],
973 ) -> Result<(Vec<WeightTensor>, Vec<PruningMask>)> {
974 let mut result_tensors = Vec::with_capacity(tensors.len());
975 let mut result_masks = Vec::with_capacity(tensors.len());
976
977 for (i, tensor) in tensors.iter().enumerate() {
978 let gradient_info = gradient_infos.get(i);
979 let (pruned, mask) = self.prune_tensor_with_gradients(tensor, gradient_info)?;
980 result_tensors.push(pruned);
981 result_masks.push(mask);
982 }
983
984 self.masks = result_masks.clone();
986
987 let overall_sparsity = self.current_sparsity();
989
990 if let Some(ref mut lts) = self.lottery_ticket_state {
992 lts.update_masks(result_masks.clone(), overall_sparsity);
993 }
994
995 info!(
996 "Layer-wise pruning complete: {:.1}% overall sparsity ({} tensors)",
997 overall_sparsity * 100.0,
998 tensors.len()
999 );
1000
1001 Ok((result_tensors, result_masks))
1002 }
1003
1004 #[must_use]
1006 pub fn current_sparsity(&self) -> f32 {
1007 if self.masks.is_empty() {
1008 return 0.0;
1009 }
1010
1011 let total_pruned: usize = self.masks.iter().map(|m| m.num_pruned()).sum();
1012 let total_elements: usize = self.masks.iter().map(|m| m.numel()).sum();
1013
1014 if total_elements == 0 {
1015 0.0
1016 } else {
1017 total_pruned as f32 / total_elements as f32
1018 }
1019 }
1020
1021 pub fn iterative_prune<F: FineTuneCallback>(
1033 &mut self,
1034 initial_weights: Vec<WeightTensor>,
1035 callback: &mut F,
1036 ) -> Result<(Vec<WeightTensor>, Vec<PruningMask>)> {
1037 let iterations = match self.config.schedule {
1038 PruningSchedule::Iterative { iterations } => iterations,
1039 PruningSchedule::Polynomial { steps, .. } => steps,
1040 PruningSchedule::OneShot => 1,
1041 };
1042
1043 let mut current_weights = initial_weights;
1044
1045 for i in 0..iterations {
1046 let target_sparsity = match self.config.schedule {
1048 PruningSchedule::Polynomial {
1049 initial_sparsity,
1050 final_sparsity,
1051 steps,
1052 } => {
1053 let t = i as f32;
1054 let total = steps as f32;
1055 let s_i = initial_sparsity as f32 / 100.0;
1056 let s_f = final_sparsity as f32 / 100.0;
1057 s_f + (s_i - s_f) * (1.0 - t / total).powi(3)
1058 }
1059 PruningSchedule::Iterative { iterations: n } => {
1060 self.config.sparsity_target * ((i + 1) as f32 / n as f32)
1061 }
1062 PruningSchedule::OneShot => self.config.sparsity_target,
1063 };
1064
1065 self.mask_mode = MaskCreationMode::GlobalPercentage(target_sparsity);
1067
1068 info!(
1069 "Iteration {}/{}: target sparsity {:.1}%",
1070 i + 1,
1071 iterations,
1072 target_sparsity * 100.0
1073 );
1074
1075 let (pruned, masks) = self.prune_tensors_global(¤t_weights)?;
1077 self.current_iteration = i + 1;
1078
1079 current_weights = if self.config.fine_tune && i < iterations - 1 {
1081 let actual_sparsity = self.current_sparsity();
1082 callback.fine_tune(pruned, &masks, i, actual_sparsity)?
1083 } else {
1084 pruned
1085 };
1086 }
1087
1088 let final_masks = self.masks.clone();
1089 Ok((current_weights, final_masks))
1090 }
1091
1092 #[must_use]
1094 pub fn compute_stats(&self, original_tensors: &[WeightTensor]) -> PruningStats {
1095 let original_params: usize = original_tensors.iter().map(|t| t.numel()).sum();
1096
1097 let pruned_params = if self.masks.is_empty() {
1098 original_params
1099 } else {
1100 self.masks.iter().map(|m| m.num_kept()).sum()
1101 };
1102
1103 let actual_sparsity = self.current_sparsity();
1104
1105 PruningStats {
1106 original_params,
1107 pruned_params,
1108 actual_sparsity,
1109 }
1110 }
1111}