Skip to main content

oxigdal_ml/optimization/pruning/
unstructured.rs

1//! Unstructured (element-wise) pruning implementation
2//!
3//! This module provides the core types and algorithms for unstructured pruning,
4//! including weight tensors, pruning masks, importance computation methods,
5//! and the `UnstructuredPruner` orchestrator.
6
7use super::{PruningConfig, PruningSchedule, PruningStats};
8use crate::error::{MlError, Result};
9use std::cmp::Ordering;
10use tracing::{debug, info, warn};
11
/// A named, shaped tensor of `f32` weights that pruning operates on.
///
/// Values are stored flat in row-major order; `shape` records the logical
/// dimensions (for example `[out_channels, in_channels, kernel_h, kernel_w]`)
/// and `name` identifies the layer the weights belong to, so the same type
/// serves both layer-wise and global pruning.
#[derive(Debug, Clone)]
pub struct WeightTensor {
    /// Flat weight values in row-major order
    pub data: Vec<f32>,
    /// Logical dimensions of the tensor
    pub shape: Vec<usize>,
    /// Identifier of the owning layer
    pub name: String,
}
25
26impl WeightTensor {
27    /// Creates a new weight tensor
28    #[must_use]
29    pub fn new(data: Vec<f32>, shape: Vec<usize>, name: String) -> Self {
30        Self { data, shape, name }
31    }
32
33    /// Returns the total number of elements
34    #[must_use]
35    pub fn numel(&self) -> usize {
36        self.data.len()
37    }
38
39    /// Returns true if the tensor is empty
40    #[must_use]
41    pub fn is_empty(&self) -> bool {
42        self.data.is_empty()
43    }
44
45    /// Validates that shape matches data length
46    ///
47    /// # Errors
48    /// Returns an error if shape product does not match data length
49    pub fn validate(&self) -> Result<()> {
50        let expected_len: usize = self.shape.iter().product();
51        if expected_len != self.data.len() {
52            return Err(MlError::InvalidConfig(format!(
53                "Shape {:?} expects {} elements but got {}",
54                self.shape,
55                expected_len,
56                self.data.len()
57            )));
58        }
59        Ok(())
60    }
61
62    /// Computes sparsity (fraction of zero weights)
63    #[must_use]
64    pub fn sparsity(&self) -> f32 {
65        if self.data.is_empty() {
66            return 0.0;
67        }
68        let zero_count = self.data.iter().filter(|&&w| w == 0.0).count();
69        zero_count as f32 / self.data.len() as f32
70    }
71
72    /// Returns the L1 norm (sum of absolute values)
73    #[must_use]
74    pub fn l1_norm(&self) -> f32 {
75        self.data.iter().map(|w| w.abs()).sum()
76    }
77
78    /// Returns the L2 norm (Euclidean norm)
79    #[must_use]
80    pub fn l2_norm(&self) -> f32 {
81        self.data.iter().map(|w| w * w).sum::<f32>().sqrt()
82    }
83
84    /// Returns statistics about the weight distribution
85    #[must_use]
86    pub fn statistics(&self) -> WeightStatistics {
87        if self.data.is_empty() {
88            return WeightStatistics::default();
89        }
90
91        let mut sorted = self.data.clone();
92        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
93
94        let min = sorted.first().copied().unwrap_or(0.0);
95        let max = sorted.last().copied().unwrap_or(0.0);
96        let mean = self.data.iter().sum::<f32>() / self.data.len() as f32;
97
98        let variance =
99            self.data.iter().map(|w| (w - mean).powi(2)).sum::<f32>() / self.data.len() as f32;
100        let std = variance.sqrt();
101
102        let median_idx = self.data.len() / 2;
103        let median = if self.data.len() % 2 == 0 {
104            (sorted
105                .get(median_idx.saturating_sub(1))
106                .copied()
107                .unwrap_or(0.0)
108                + sorted.get(median_idx).copied().unwrap_or(0.0))
109                / 2.0
110        } else {
111            sorted.get(median_idx).copied().unwrap_or(0.0)
112        };
113
114        WeightStatistics {
115            min,
116            max,
117            mean,
118            std,
119            median,
120            sparsity: self.sparsity(),
121        }
122    }
123}
124
/// Summary of how the values in a weight tensor are distributed.
///
/// `Default` gives all-zero statistics, which is what an empty tensor reports.
#[derive(Debug, Clone, Default)]
pub struct WeightStatistics {
    /// Smallest weight
    pub min: f32,
    /// Largest weight
    pub max: f32,
    /// Arithmetic mean of the weights
    pub mean: f32,
    /// Standard deviation of the weights
    pub std: f32,
    /// Middle weight (average of the two middle weights for even counts)
    pub median: f32,
    /// Share of weights that are exactly zero
    pub sparsity: f32,
}
141
/// Element-wise keep/prune decision for a weight tensor.
///
/// A `true` entry means the corresponding weight survives; a `false` entry
/// means it is pruned, i.e. set to zero when the mask is applied.
#[derive(Debug, Clone)]
pub struct PruningMask {
    /// Per-element flags: `true` keeps the weight, `false` prunes it
    pub mask: Vec<bool>,
    /// Dimensions of the tensor this mask belongs to
    pub shape: Vec<usize>,
    /// Name of the layer the mask was built for, if known
    pub name: Option<String>,
}
155
156impl PruningMask {
157    /// Creates a new pruning mask
158    #[must_use]
159    pub fn new(mask: Vec<bool>, shape: Vec<usize>) -> Self {
160        Self {
161            mask,
162            shape,
163            name: None,
164        }
165    }
166
167    /// Creates a new pruning mask with a name
168    #[must_use]
169    pub fn with_name(mask: Vec<bool>, shape: Vec<usize>, name: String) -> Self {
170        Self {
171            mask,
172            shape,
173            name: Some(name),
174        }
175    }
176
177    /// Creates an all-ones mask (keep everything)
178    #[must_use]
179    pub fn ones(shape: &[usize]) -> Self {
180        let size: usize = shape.iter().product();
181        Self {
182            mask: vec![true; size],
183            shape: shape.to_vec(),
184            name: None,
185        }
186    }
187
188    /// Creates an all-zeros mask (prune everything)
189    #[must_use]
190    pub fn zeros(shape: &[usize]) -> Self {
191        let size: usize = shape.iter().product();
192        Self {
193            mask: vec![false; size],
194            shape: shape.to_vec(),
195            name: None,
196        }
197    }
198
199    /// Returns the number of elements in the mask
200    #[must_use]
201    pub fn numel(&self) -> usize {
202        self.mask.len()
203    }
204
205    /// Returns the number of weights kept (not pruned)
206    #[must_use]
207    pub fn num_kept(&self) -> usize {
208        self.mask.iter().filter(|&&m| m).count()
209    }
210
211    /// Returns the number of weights pruned
212    #[must_use]
213    pub fn num_pruned(&self) -> usize {
214        self.mask.iter().filter(|&&m| !m).count()
215    }
216
217    /// Returns the sparsity (fraction of pruned weights)
218    #[must_use]
219    pub fn sparsity(&self) -> f32 {
220        if self.mask.is_empty() {
221            return 0.0;
222        }
223        self.num_pruned() as f32 / self.mask.len() as f32
224    }
225
226    /// Combines two masks with logical AND (both must keep)
227    ///
228    /// # Errors
229    /// Returns an error if mask sizes don't match
230    pub fn and(&self, other: &PruningMask) -> Result<PruningMask> {
231        if self.mask.len() != other.mask.len() {
232            return Err(MlError::InvalidConfig(format!(
233                "Mask sizes don't match: {} vs {}",
234                self.mask.len(),
235                other.mask.len()
236            )));
237        }
238
239        let combined: Vec<bool> = self
240            .mask
241            .iter()
242            .zip(other.mask.iter())
243            .map(|(&a, &b)| a && b)
244            .collect();
245
246        Ok(PruningMask::new(combined, self.shape.clone()))
247    }
248
249    /// Combines two masks with logical OR (either keeps)
250    ///
251    /// # Errors
252    /// Returns an error if mask sizes don't match
253    pub fn or(&self, other: &PruningMask) -> Result<PruningMask> {
254        if self.mask.len() != other.mask.len() {
255            return Err(MlError::InvalidConfig(format!(
256                "Mask sizes don't match: {} vs {}",
257                self.mask.len(),
258                other.mask.len()
259            )));
260        }
261
262        let combined: Vec<bool> = self
263            .mask
264            .iter()
265            .zip(other.mask.iter())
266            .map(|(&a, &b)| a || b)
267            .collect();
268
269        Ok(PruningMask::new(combined, self.shape.clone()))
270    }
271
272    /// Inverts the mask
273    #[must_use]
274    pub fn invert(&self) -> PruningMask {
275        PruningMask::new(self.mask.iter().map(|&m| !m).collect(), self.shape.clone())
276    }
277
278    /// Applies the mask to a weight tensor, zeroing pruned weights
279    ///
280    /// # Errors
281    /// Returns an error if sizes don't match
282    pub fn apply(&self, weights: &WeightTensor) -> Result<WeightTensor> {
283        if self.mask.len() != weights.data.len() {
284            return Err(MlError::InvalidConfig(format!(
285                "Mask size {} doesn't match weight size {}",
286                self.mask.len(),
287                weights.data.len()
288            )));
289        }
290
291        let pruned_data: Vec<f32> = weights
292            .data
293            .iter()
294            .zip(self.mask.iter())
295            .map(|(&w, &keep)| if keep { w } else { 0.0 })
296            .collect();
297
298        Ok(WeightTensor::new(
299            pruned_data,
300            weights.shape.clone(),
301            weights.name.clone(),
302        ))
303    }
304}
305
/// Strategy used to score how important each individual weight is.
///
/// Higher scores mean "more worth keeping"; the lowest-scoring weights are
/// the ones removed by unstructured pruning.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum ImportanceMethod {
    /// Magnitude score `|w|` (the default)
    #[default]
    L1Norm,
    /// Squared magnitude `w^2`
    L2Norm,
    /// Weight times gradient, `|w * g|`
    GradientWeighted,
    /// Taylor-expansion score: `|w * g * a|` with activations, otherwise the
    /// first-order term `|w * g|`
    TaylorExpansion,
    /// Pseudo-random scores, useful as a comparison baseline
    Random {
        /// Seed that makes the random scores reproducible
        seed: u64,
    },
    /// Movement pruning: scores based on how weights change while training
    Movement,
    /// Fisher-information approximation using the squared gradient `g^2`
    Fisher,
}
328
/// Per-weight gradient data consumed by the gradient-based importance methods.
#[derive(Debug, Clone)]
pub struct GradientInfo {
    /// One gradient value per weight
    pub gradients: Vec<f32>,
    /// Activation values per weight, used only by the Taylor-expansion score
    pub activations: Option<Vec<f32>>,
}
337
338impl GradientInfo {
339    /// Creates gradient info with gradients only
340    #[must_use]
341    pub fn new(gradients: Vec<f32>) -> Self {
342        Self {
343            gradients,
344            activations: None,
345        }
346    }
347
348    /// Creates gradient info with gradients and activations
349    #[must_use]
350    pub fn with_activations(gradients: Vec<f32>, activations: Vec<f32>) -> Self {
351        Self {
352            gradients,
353            activations: Some(activations),
354        }
355    }
356}
357
358/// State for Lottery Ticket Hypothesis rewinding
359///
360/// The Lottery Ticket Hypothesis (Frankle & Carlin, 2019) suggests that
361/// randomly initialized networks contain sparse subnetworks ("winning tickets")
362/// that can achieve comparable accuracy when trained in isolation.
363#[derive(Debug, Clone)]
364pub struct LotteryTicketState {
365    /// Initial weights before any training (for rewinding)
366    pub initial_weights: Vec<WeightTensor>,
367    /// Current pruning masks learned through training
368    pub masks: Vec<PruningMask>,
369    /// Current pruning iteration
370    pub iteration: usize,
371    /// Sparsity at each iteration
372    pub sparsity_history: Vec<f32>,
373    /// Whether rewinding is enabled
374    pub enabled: bool,
375}
376
377impl LotteryTicketState {
378    /// Creates a new lottery ticket state
379    #[must_use]
380    pub fn new(initial_weights: Vec<WeightTensor>) -> Self {
381        let num_layers = initial_weights.len();
382        Self {
383            initial_weights,
384            masks: Vec::with_capacity(num_layers),
385            iteration: 0,
386            sparsity_history: Vec::new(),
387            enabled: true,
388        }
389    }
390
391    /// Rewinds weights to initial values while applying current masks
392    ///
393    /// Returns the rewound weights with masks applied
394    pub fn rewind(&self) -> Vec<WeightTensor> {
395        if self.masks.is_empty() {
396            return self.initial_weights.clone();
397        }
398
399        self.initial_weights
400            .iter()
401            .zip(self.masks.iter())
402            .map(|(weights, mask)| {
403                // Apply mask, falling back to original if apply fails
404                mask.apply(weights).unwrap_or_else(|_| weights.clone())
405            })
406            .collect()
407    }
408
409    /// Updates masks after a pruning iteration
410    pub fn update_masks(&mut self, new_masks: Vec<PruningMask>, sparsity: f32) {
411        self.masks = new_masks;
412        self.iteration += 1;
413        self.sparsity_history.push(sparsity);
414    }
415
416    /// Returns the current overall sparsity
417    #[must_use]
418    pub fn current_sparsity(&self) -> f32 {
419        if self.masks.is_empty() {
420            return 0.0;
421        }
422
423        let total_pruned: usize = self.masks.iter().map(|m| m.num_pruned()).sum();
424        let total_elements: usize = self.masks.iter().map(|m| m.numel()).sum();
425
426        if total_elements == 0 {
427            0.0
428        } else {
429            total_pruned as f32 / total_elements as f32
430        }
431    }
432}
433
/// Rule that turns importance scores into a keep/prune mask.
///
/// The default is [`MaskCreationMode::GlobalPercentage`] with 50% sparsity.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum MaskCreationMode {
    /// Keep weights whose score reaches the given threshold; prune the rest
    Threshold(f32),
    /// Prune this fraction of all weights, ranked across every tensor
    GlobalPercentage(f32),
    /// Prune this fraction of weights inside each tensor separately
    LayerWisePercentage(f32),
    /// Keep only the k highest-scoring weights across every tensor
    TopK(usize),
    /// Keep only the k highest-scoring weights inside each tensor
    TopKPerLayer(usize),
}

impl Default for MaskCreationMode {
    fn default() -> Self {
        MaskCreationMode::GlobalPercentage(0.5)
    }
}
454
455/// Fine-tuning callback for iterative pruning
456///
457/// Implement this trait to perform fine-tuning between pruning iterations.
458pub trait FineTuneCallback: Send + Sync {
459    /// Called after each pruning iteration for fine-tuning
460    ///
461    /// # Arguments
462    /// * `weights` - Current pruned weights
463    /// * `masks` - Current pruning masks
464    /// * `iteration` - Current pruning iteration (0-indexed)
465    /// * `sparsity` - Current sparsity level
466    ///
467    /// # Returns
468    /// Fine-tuned weights
469    fn fine_tune(
470        &mut self,
471        weights: Vec<WeightTensor>,
472        masks: &[PruningMask],
473        iteration: usize,
474        sparsity: f32,
475    ) -> Result<Vec<WeightTensor>>;
476
477    /// Returns the number of fine-tuning epochs
478    fn epochs(&self) -> usize;
479}
480
/// Fine-tuning callback that does nothing: weights pass through unchanged and
/// zero epochs are reported. Use it to prune without any retraining.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoOpFineTune;
483
484impl FineTuneCallback for NoOpFineTune {
485    fn fine_tune(
486        &mut self,
487        weights: Vec<WeightTensor>,
488        _masks: &[PruningMask],
489        _iteration: usize,
490        _sparsity: f32,
491    ) -> Result<Vec<WeightTensor>> {
492        Ok(weights)
493    }
494
495    fn epochs(&self) -> usize {
496        0
497    }
498}
499
500/// Unstructured pruner for element-wise weight removal
501///
502/// This is the main orchestrator for unstructured pruning operations,
503/// supporting various importance methods, mask creation modes, and
504/// advanced features like lottery ticket rewinding.
505pub struct UnstructuredPruner {
506    /// Pruning configuration
507    config: PruningConfig,
508    /// Importance computation method
509    importance_method: ImportanceMethod,
510    /// Current pruning masks for each layer
511    masks: Vec<PruningMask>,
512    /// Lottery ticket state for rewinding
513    lottery_ticket_state: Option<LotteryTicketState>,
514    /// Mask creation mode
515    mask_mode: MaskCreationMode,
516    /// Current pruning iteration
517    current_iteration: usize,
518    /// RNG state for random pruning (simple LCG)
519    rng_state: u64,
520}
521
522impl UnstructuredPruner {
523    /// Creates a new unstructured pruner
524    #[must_use]
525    pub fn new(config: PruningConfig, importance_method: ImportanceMethod) -> Self {
526        let seed = match importance_method {
527            ImportanceMethod::Random { seed } => seed,
528            _ => 42,
529        };
530        let sparsity_target = config.sparsity_target;
531
532        Self {
533            config,
534            importance_method,
535            masks: Vec::new(),
536            lottery_ticket_state: None,
537            mask_mode: MaskCreationMode::GlobalPercentage(sparsity_target),
538            current_iteration: 0,
539            rng_state: seed,
540        }
541    }
542
543    /// Sets the mask creation mode
544    #[must_use]
545    pub fn with_mask_mode(mut self, mode: MaskCreationMode) -> Self {
546        self.mask_mode = mode;
547        self
548    }
549
550    /// Enables lottery ticket hypothesis support
551    pub fn enable_lottery_ticket(&mut self, initial_weights: Vec<WeightTensor>) {
552        self.lottery_ticket_state = Some(LotteryTicketState::new(initial_weights));
553    }
554
555    /// Disables lottery ticket support
556    pub fn disable_lottery_ticket(&mut self) {
557        self.lottery_ticket_state = None;
558    }
559
560    /// Returns the current masks
561    #[must_use]
562    pub fn masks(&self) -> &[PruningMask] {
563        &self.masks
564    }
565
566    /// Returns the current iteration
567    #[must_use]
568    pub fn current_iteration(&self) -> usize {
569        self.current_iteration
570    }
571
572    /// Returns the lottery ticket state if enabled
573    #[must_use]
574    pub fn lottery_ticket_state(&self) -> Option<&LotteryTicketState> {
575        self.lottery_ticket_state.as_ref()
576    }
577
578    /// Rewinds to initial weights with current masks (lottery ticket)
579    #[must_use]
580    pub fn rewind_to_initial(&self) -> Option<Vec<WeightTensor>> {
581        self.lottery_ticket_state
582            .as_ref()
583            .map(|state| state.rewind())
584    }
585
586    /// Generates a pseudo-random number (simple LCG)
587    fn next_random(&mut self) -> f32 {
588        // Linear Congruential Generator: a = 1103515245, c = 12345, m = 2^31
589        self.rng_state = self.rng_state.wrapping_mul(1103515245).wrapping_add(12345) % (1u64 << 31);
590        (self.rng_state as f32) / ((1u64 << 31) as f32)
591    }
592
593    /// Computes importance scores for a weight tensor
594    ///
595    /// # Arguments
596    /// * `weights` - The weight tensor to compute importance for
597    /// * `gradient_info` - Optional gradient information for gradient-based methods
598    ///
599    /// # Returns
600    /// Vector of importance scores (same length as weights)
601    pub fn compute_importance(
602        &mut self,
603        weights: &WeightTensor,
604        gradient_info: Option<&GradientInfo>,
605    ) -> Vec<f32> {
606        match self.importance_method {
607            ImportanceMethod::L1Norm => weights.data.iter().map(|w| w.abs()).collect(),
608            ImportanceMethod::L2Norm => weights.data.iter().map(|w| w * w).collect(),
609            ImportanceMethod::GradientWeighted => {
610                if let Some(info) = gradient_info {
611                    if info.gradients.len() == weights.data.len() {
612                        weights
613                            .data
614                            .iter()
615                            .zip(info.gradients.iter())
616                            .map(|(w, g)| (w * g).abs())
617                            .collect()
618                    } else {
619                        warn!(
620                            "Gradient size mismatch, falling back to L1 norm. \
621                            Weights: {}, Gradients: {}",
622                            weights.data.len(),
623                            info.gradients.len()
624                        );
625                        weights.data.iter().map(|w| w.abs()).collect()
626                    }
627                } else {
628                    warn!("No gradient info provided, falling back to L1 norm");
629                    weights.data.iter().map(|w| w.abs()).collect()
630                }
631            }
632            ImportanceMethod::TaylorExpansion => {
633                if let Some(info) = gradient_info {
634                    if info.gradients.len() == weights.data.len() {
635                        if let Some(ref activations) = info.activations {
636                            if activations.len() == weights.data.len() {
637                                // Full Taylor: |w * g * a|
638                                weights
639                                    .data
640                                    .iter()
641                                    .zip(info.gradients.iter())
642                                    .zip(activations.iter())
643                                    .map(|((w, g), a)| (w * g * a).abs())
644                                    .collect()
645                            } else {
646                                // First-order Taylor: |w * g|
647                                weights
648                                    .data
649                                    .iter()
650                                    .zip(info.gradients.iter())
651                                    .map(|(w, g)| (w * g).abs())
652                                    .collect()
653                            }
654                        } else {
655                            // First-order Taylor: |w * g|
656                            weights
657                                .data
658                                .iter()
659                                .zip(info.gradients.iter())
660                                .map(|(w, g)| (w * g).abs())
661                                .collect()
662                        }
663                    } else {
664                        warn!("Gradient size mismatch, falling back to L1 norm");
665                        weights.data.iter().map(|w| w.abs()).collect()
666                    }
667                } else {
668                    warn!("No gradient info for Taylor, falling back to L1 norm");
669                    weights.data.iter().map(|w| w.abs()).collect()
670                }
671            }
672            ImportanceMethod::Random { .. } => (0..weights.data.len())
673                .map(|_| self.next_random())
674                .collect(),
675            ImportanceMethod::Movement => {
676                // Movement pruning: importance based on weight magnitude change
677                // Without historical data, fall back to L1 norm
678                weights.data.iter().map(|w| w.abs()).collect()
679            }
680            ImportanceMethod::Fisher => {
681                // Fisher information: gradient squared
682                if let Some(info) = gradient_info {
683                    if info.gradients.len() == weights.data.len() {
684                        info.gradients.iter().map(|g| g * g).collect()
685                    } else {
686                        warn!("Gradient size mismatch for Fisher, falling back to L1");
687                        weights.data.iter().map(|w| w.abs()).collect()
688                    }
689                } else {
690                    warn!("No gradient info for Fisher, falling back to L1 norm");
691                    weights.data.iter().map(|w| w.abs()).collect()
692                }
693            }
694        }
695    }
696
697    /// Creates a pruning mask based on importance scores and mask mode
698    ///
699    /// # Arguments
700    /// * `importance` - Importance scores for each weight
701    /// * `shape` - Shape of the weight tensor
702    pub fn create_mask(&self, importance: &[f32], shape: &[usize]) -> PruningMask {
703        let num_weights = importance.len();
704        if num_weights == 0 {
705            return PruningMask::new(Vec::new(), shape.to_vec());
706        }
707
708        match self.mask_mode {
709            MaskCreationMode::Threshold(threshold) => {
710                let mask: Vec<bool> = importance.iter().map(|&s| s >= threshold).collect();
711                PruningMask::new(mask, shape.to_vec())
712            }
713            MaskCreationMode::GlobalPercentage(sparsity)
714            | MaskCreationMode::LayerWisePercentage(sparsity) => {
715                let num_to_prune =
716                    ((num_weights as f32 * sparsity).round() as usize).min(num_weights);
717
718                // Create indexed importance scores
719                let mut indexed: Vec<(usize, f32)> = importance
720                    .iter()
721                    .enumerate()
722                    .map(|(i, &s)| (i, s))
723                    .collect();
724
725                // Sort by importance (ascending - lowest importance first)
726                indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
727
728                // Create mask: prune the num_to_prune lowest importance weights
729                let mut mask = vec![true; num_weights];
730                for (idx, _) in indexed.iter().take(num_to_prune) {
731                    mask[*idx] = false;
732                }
733
734                PruningMask::new(mask, shape.to_vec())
735            }
736            MaskCreationMode::TopK(k) | MaskCreationMode::TopKPerLayer(k) => {
737                let num_to_keep = k.min(num_weights);
738
739                // Create indexed importance scores
740                let mut indexed: Vec<(usize, f32)> = importance
741                    .iter()
742                    .enumerate()
743                    .map(|(i, &s)| (i, s))
744                    .collect();
745
746                // Sort by importance (descending - highest importance first)
747                indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
748
749                // Create mask: keep only top-k
750                let mut mask = vec![false; num_weights];
751                for (idx, _) in indexed.iter().take(num_to_keep) {
752                    mask[*idx] = true;
753                }
754
755                PruningMask::new(mask, shape.to_vec())
756            }
757        }
758    }
759
760    /// Prunes a single weight tensor
761    ///
762    /// # Arguments
763    /// * `weights` - The weight tensor to prune
764    ///
765    /// # Returns
766    /// Tuple of (pruned weights, mask)
767    ///
768    /// # Errors
769    /// Returns an error if pruning fails
770    pub fn prune_tensor(&mut self, weights: &WeightTensor) -> Result<(WeightTensor, PruningMask)> {
771        self.prune_tensor_with_gradients(weights, None)
772    }
773
774    /// Prunes a single weight tensor with gradient information
775    ///
776    /// # Arguments
777    /// * `weights` - The weight tensor to prune
778    /// * `gradient_info` - Optional gradient information
779    ///
780    /// # Returns
781    /// Tuple of (pruned weights, mask)
782    ///
783    /// # Errors
784    /// Returns an error if pruning fails
785    pub fn prune_tensor_with_gradients(
786        &mut self,
787        weights: &WeightTensor,
788        gradient_info: Option<&GradientInfo>,
789    ) -> Result<(WeightTensor, PruningMask)> {
790        // Validate input
791        weights.validate()?;
792
793        if weights.is_empty() {
794            return Ok((
795                weights.clone(),
796                PruningMask::new(Vec::new(), weights.shape.clone()),
797            ));
798        }
799
800        // Compute importance scores
801        let importance = self.compute_importance(weights, gradient_info);
802
803        // Create mask
804        let mask = self.create_mask(&importance, &weights.shape);
805
806        // Apply mask
807        let pruned = mask.apply(weights)?;
808
809        debug!(
810            "Pruned tensor '{}': {:.1}% sparsity ({} -> {} non-zero)",
811            weights.name,
812            mask.sparsity() * 100.0,
813            weights.numel(),
814            mask.num_kept()
815        );
816
817        Ok((pruned, mask))
818    }
819
820    /// Prunes multiple weight tensors globally
821    ///
822    /// For global pruning, importance scores are computed across all tensors
823    /// and a single threshold is applied.
824    ///
825    /// # Arguments
826    /// * `tensors` - Weight tensors to prune
827    ///
828    /// # Returns
829    /// Tuple of (pruned tensors, masks)
830    ///
831    /// # Errors
832    /// Returns an error if pruning fails
833    pub fn prune_tensors_global(
834        &mut self,
835        tensors: &[WeightTensor],
836    ) -> Result<(Vec<WeightTensor>, Vec<PruningMask>)> {
837        self.prune_tensors_global_with_gradients(tensors, &[])
838    }
839
840    /// Prunes multiple weight tensors globally with gradient information
841    ///
842    /// # Arguments
843    /// * `tensors` - Weight tensors to prune
844    /// * `gradient_infos` - Gradient information for each tensor
845    ///
846    /// # Returns
847    /// Tuple of (pruned tensors, masks)
848    ///
849    /// # Errors
850    /// Returns an error if pruning fails
851    pub fn prune_tensors_global_with_gradients(
852        &mut self,
853        tensors: &[WeightTensor],
854        gradient_infos: &[GradientInfo],
855    ) -> Result<(Vec<WeightTensor>, Vec<PruningMask>)> {
856        if tensors.is_empty() {
857            return Ok((Vec::new(), Vec::new()));
858        }
859
860        // Validate all tensors
861        for tensor in tensors {
862            tensor.validate()?;
863        }
864
865        // Collect all importance scores with their tensor and element indices
866        let mut all_scores: Vec<(usize, usize, f32)> = Vec::new();
867
868        for (tensor_idx, tensor) in tensors.iter().enumerate() {
869            let gradient_info = gradient_infos.get(tensor_idx);
870            let importance = self.compute_importance(tensor, gradient_info);
871
872            for (elem_idx, &score) in importance.iter().enumerate() {
873                all_scores.push((tensor_idx, elem_idx, score));
874            }
875        }
876
877        // Sort by importance (ascending)
878        all_scores.sort_by(|a, b| a.2.partial_cmp(&b.2).unwrap_or(Ordering::Equal));
879
880        // Determine how many to prune based on mode
881        let total_weights = all_scores.len();
882        let num_to_prune = match self.mask_mode {
883            MaskCreationMode::GlobalPercentage(sparsity) => {
884                ((total_weights as f32 * sparsity).round() as usize).min(total_weights)
885            }
886            MaskCreationMode::TopK(k) => total_weights.saturating_sub(k),
887            MaskCreationMode::Threshold(threshold) => {
888                all_scores.iter().filter(|(_, _, s)| *s < threshold).count()
889            }
890            MaskCreationMode::LayerWisePercentage(_) | MaskCreationMode::TopKPerLayer(_) => {
891                // For layer-wise modes, fall back to per-tensor pruning
892                return self.prune_tensors_layerwise_with_gradients(tensors, gradient_infos);
893            }
894        };
895
896        // Create masks for each tensor
897        let mut masks: Vec<Vec<bool>> = tensors.iter().map(|t| vec![true; t.data.len()]).collect();
898
899        // Mark weights to prune
900        for (tensor_idx, elem_idx, _) in all_scores.iter().take(num_to_prune) {
901            if let Some(mask) = masks.get_mut(*tensor_idx) {
902                if let Some(elem) = mask.get_mut(*elem_idx) {
903                    *elem = false;
904                }
905            }
906        }
907
908        // Create PruningMask objects and apply to tensors
909        let mut result_tensors = Vec::with_capacity(tensors.len());
910        let mut result_masks = Vec::with_capacity(tensors.len());
911
912        for (tensor, mask_vec) in tensors.iter().zip(masks) {
913            let mask = PruningMask::with_name(mask_vec, tensor.shape.clone(), tensor.name.clone());
914            let pruned = mask.apply(tensor)?;
915            result_tensors.push(pruned);
916            result_masks.push(mask);
917        }
918
919        // Update internal masks
920        self.masks = result_masks.clone();
921
922        // Calculate sparsity before borrowing lottery_ticket_state mutably
923        let overall_sparsity = self.current_sparsity();
924
925        // Update lottery ticket state if enabled
926        if let Some(ref mut lts) = self.lottery_ticket_state {
927            lts.update_masks(result_masks.clone(), overall_sparsity);
928        }
929
930        info!(
931            "Global pruning complete: {:.1}% overall sparsity ({} tensors)",
932            overall_sparsity * 100.0,
933            tensors.len()
934        );
935
936        Ok((result_tensors, result_masks))
937    }
938
939    /// Prunes multiple weight tensors layer-wise
940    ///
941    /// Each tensor is pruned independently with the same sparsity target.
942    ///
943    /// # Arguments
944    /// * `tensors` - Weight tensors to prune
945    ///
946    /// # Returns
947    /// Tuple of (pruned tensors, masks)
948    ///
949    /// # Errors
950    /// Returns an error if pruning fails
951    pub fn prune_tensors_layerwise(
952        &mut self,
953        tensors: &[WeightTensor],
954    ) -> Result<(Vec<WeightTensor>, Vec<PruningMask>)> {
955        self.prune_tensors_layerwise_with_gradients(tensors, &[])
956    }
957
958    /// Prunes multiple weight tensors layer-wise with gradient information
959    ///
960    /// # Arguments
961    /// * `tensors` - Weight tensors to prune
962    /// * `gradient_infos` - Gradient information for each tensor
963    ///
964    /// # Returns
965    /// Tuple of (pruned tensors, masks)
966    ///
967    /// # Errors
968    /// Returns an error if pruning fails
969    pub fn prune_tensors_layerwise_with_gradients(
970        &mut self,
971        tensors: &[WeightTensor],
972        gradient_infos: &[GradientInfo],
973    ) -> Result<(Vec<WeightTensor>, Vec<PruningMask>)> {
974        let mut result_tensors = Vec::with_capacity(tensors.len());
975        let mut result_masks = Vec::with_capacity(tensors.len());
976
977        for (i, tensor) in tensors.iter().enumerate() {
978            let gradient_info = gradient_infos.get(i);
979            let (pruned, mask) = self.prune_tensor_with_gradients(tensor, gradient_info)?;
980            result_tensors.push(pruned);
981            result_masks.push(mask);
982        }
983
984        // Update internal masks
985        self.masks = result_masks.clone();
986
987        // Calculate sparsity before borrowing lottery_ticket_state mutably
988        let overall_sparsity = self.current_sparsity();
989
990        // Update lottery ticket state if enabled
991        if let Some(ref mut lts) = self.lottery_ticket_state {
992            lts.update_masks(result_masks.clone(), overall_sparsity);
993        }
994
995        info!(
996            "Layer-wise pruning complete: {:.1}% overall sparsity ({} tensors)",
997            overall_sparsity * 100.0,
998            tensors.len()
999        );
1000
1001        Ok((result_tensors, result_masks))
1002    }
1003
1004    /// Returns the current overall sparsity
1005    #[must_use]
1006    pub fn current_sparsity(&self) -> f32 {
1007        if self.masks.is_empty() {
1008            return 0.0;
1009        }
1010
1011        let total_pruned: usize = self.masks.iter().map(|m| m.num_pruned()).sum();
1012        let total_elements: usize = self.masks.iter().map(|m| m.numel()).sum();
1013
1014        if total_elements == 0 {
1015            0.0
1016        } else {
1017            total_pruned as f32 / total_elements as f32
1018        }
1019    }
1020
1021    /// Performs iterative pruning with fine-tuning
1022    ///
1023    /// # Arguments
1024    /// * `initial_weights` - Initial weight tensors
1025    /// * `callback` - Fine-tuning callback
1026    ///
1027    /// # Returns
1028    /// Final pruned weights and masks after all iterations
1029    ///
1030    /// # Errors
1031    /// Returns an error if pruning or fine-tuning fails
1032    pub fn iterative_prune<F: FineTuneCallback>(
1033        &mut self,
1034        initial_weights: Vec<WeightTensor>,
1035        callback: &mut F,
1036    ) -> Result<(Vec<WeightTensor>, Vec<PruningMask>)> {
1037        let iterations = match self.config.schedule {
1038            PruningSchedule::Iterative { iterations } => iterations,
1039            PruningSchedule::Polynomial { steps, .. } => steps,
1040            PruningSchedule::OneShot => 1,
1041        };
1042
1043        let mut current_weights = initial_weights;
1044
1045        for i in 0..iterations {
1046            // Compute current target sparsity
1047            let target_sparsity = match self.config.schedule {
1048                PruningSchedule::Polynomial {
1049                    initial_sparsity,
1050                    final_sparsity,
1051                    steps,
1052                } => {
1053                    let t = i as f32;
1054                    let total = steps as f32;
1055                    let s_i = initial_sparsity as f32 / 100.0;
1056                    let s_f = final_sparsity as f32 / 100.0;
1057                    s_f + (s_i - s_f) * (1.0 - t / total).powi(3)
1058                }
1059                PruningSchedule::Iterative { iterations: n } => {
1060                    self.config.sparsity_target * ((i + 1) as f32 / n as f32)
1061                }
1062                PruningSchedule::OneShot => self.config.sparsity_target,
1063            };
1064
1065            // Update mask mode with current sparsity target
1066            self.mask_mode = MaskCreationMode::GlobalPercentage(target_sparsity);
1067
1068            info!(
1069                "Iteration {}/{}: target sparsity {:.1}%",
1070                i + 1,
1071                iterations,
1072                target_sparsity * 100.0
1073            );
1074
1075            // Prune
1076            let (pruned, masks) = self.prune_tensors_global(&current_weights)?;
1077            self.current_iteration = i + 1;
1078
1079            // Fine-tune if not the last iteration (or if configured)
1080            current_weights = if self.config.fine_tune && i < iterations - 1 {
1081                let actual_sparsity = self.current_sparsity();
1082                callback.fine_tune(pruned, &masks, i, actual_sparsity)?
1083            } else {
1084                pruned
1085            };
1086        }
1087
1088        let final_masks = self.masks.clone();
1089        Ok((current_weights, final_masks))
1090    }
1091
1092    /// Computes pruning statistics
1093    #[must_use]
1094    pub fn compute_stats(&self, original_tensors: &[WeightTensor]) -> PruningStats {
1095        let original_params: usize = original_tensors.iter().map(|t| t.numel()).sum();
1096
1097        let pruned_params = if self.masks.is_empty() {
1098            original_params
1099        } else {
1100            self.masks.iter().map(|m| m.num_kept()).sum()
1101        };
1102
1103        let actual_sparsity = self.current_sparsity();
1104
1105        PruningStats {
1106            original_params,
1107            pruned_params,
1108            actual_sparsity,
1109        }
1110    }
1111}