kizzasi_core/pruning.rs

//! # Structured Pruning
//!
//! Advanced pruning algorithms for model compression with structure preservation.
//!
//! ## Features
//!
//! - **Magnitude Pruning**: Remove weights below a threshold
//! - **Structured Pruning**: Remove entire channels, filters, or heads
//! - **L1/L2 Norm Pruning**: Importance scoring based on norms
//! - **Gradient-based Pruning**: Use gradient information for importance
//! - **Progressive Pruning**: Gradual pruning with retraining
//! - **Group Lasso**: Structured sparsity via regularization
//!
//! ## References
//!
//! - "Learning both Weights and Connections for Efficient Neural Networks" (Han et al., 2015)
//! - "Pruning Filters for Efficient ConvNets" (Li et al., 2017)
//! - "The Lottery Ticket Hypothesis" (Frankle & Carbin, 2019)

use crate::{CoreError, CoreResult};
use scirs2_core::ndarray::Array2;
#[allow(unused_imports)]
use scirs2_core::ndarray::Axis; // Used in tests
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// Pruning strategy
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PruningStrategy {
    /// Unstructured magnitude-based pruning
    Magnitude,
    /// L1 norm-based structured pruning
    L1Norm,
    /// L2 norm-based structured pruning
    L2Norm,
    /// Gradient magnitude-based pruning
    Gradient,
    /// Random pruning (baseline)
    Random,
}

/// Pruning granularity for structured pruning
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PruningGranularity {
    /// Unstructured (individual weights)
    Unstructured,
    /// Channel-wise (entire input/output channels)
    Channel,
    /// Filter-wise (entire filters in convolutions)
    Filter,
    /// Head-wise (entire attention heads)
    Head,
    /// Block-wise (entire transformer/SSM blocks)
    Block,
}

/// Pruning configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PruningConfig {
    /// Pruning strategy
    pub strategy: PruningStrategy,
    /// Pruning granularity
    pub granularity: PruningGranularity,
    /// Target sparsity ratio in `[0, 1)`
    pub target_sparsity: f32,
    /// Whether to use a global threshold (vs per-layer)
    pub global_threshold: bool,
    /// Number of pruning iterations (for progressive pruning)
    pub num_iterations: usize,
    /// Whether to keep pruned weights for recovery
    pub keep_pruned_weights: bool,
}

impl Default for PruningConfig {
    fn default() -> Self {
        Self {
            strategy: PruningStrategy::Magnitude,
            granularity: PruningGranularity::Unstructured,
            target_sparsity: 0.5,
            global_threshold: false,
            num_iterations: 1,
            keep_pruned_weights: false,
        }
    }
}

impl PruningConfig {
    /// Create a new pruning configuration
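    ///
    /// # Example
    ///
    /// A builder-style sketch (`ignore`d doctest; the `kizzasi_core::pruning`
    /// path is an assumption about the crate layout):
    ///
    /// ```ignore
    /// use kizzasi_core::pruning::{PruningConfig, PruningGranularity, PruningStrategy};
    ///
    /// // 70% channel sparsity by L1 norm, reached over 4 iterations.
    /// let config = PruningConfig::new(PruningStrategy::L1Norm, 0.7)
    ///     .with_granularity(PruningGranularity::Channel)
    ///     .with_iterations(4);
    /// assert!(config.validate().is_ok());
    /// ```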
    pub fn new(strategy: PruningStrategy, target_sparsity: f32) -> Self {
        Self {
            strategy,
            target_sparsity,
            ..Default::default()
        }
    }

    /// Set granularity
    pub fn with_granularity(mut self, granularity: PruningGranularity) -> Self {
        self.granularity = granularity;
        self
    }

    /// Enable global thresholding
    pub fn with_global_threshold(mut self) -> Self {
        self.global_threshold = true;
        self
    }

    /// Set number of iterations for progressive pruning
    pub fn with_iterations(mut self, num_iterations: usize) -> Self {
        self.num_iterations = num_iterations;
        self
    }

    /// Keep pruned weights
    pub fn with_keep_weights(mut self) -> Self {
        self.keep_pruned_weights = true;
        self
    }

    /// Validate configuration
    pub fn validate(&self) -> CoreResult<()> {
        if self.target_sparsity < 0.0 || self.target_sparsity >= 1.0 {
            return Err(CoreError::InvalidConfig(
                "target_sparsity must be in [0, 1)".into(),
            ));
        }
        if self.num_iterations == 0 {
            return Err(CoreError::InvalidConfig(
                "num_iterations must be > 0".into(),
            ));
        }
        Ok(())
    }
}

/// Pruning mask for a weight tensor
#[derive(Debug, Clone)]
pub struct PruningMask {
    /// Binary mask (1 = keep, 0 = prune)
    pub mask: Array2<f32>,
    /// Pruned weights (if keep_pruned_weights is enabled)
    pub pruned_weights: Option<Array2<f32>>,
    /// Sparsity ratio achieved
    pub sparsity: f32,
}

impl PruningMask {
    /// Create a new pruning mask
    pub fn new(mask: Array2<f32>) -> Self {
        let total = mask.len();
        let zeros = mask.iter().filter(|&&x| x == 0.0).count();
        let sparsity = zeros as f32 / total as f32;

        Self {
            mask,
            pruned_weights: None,
            sparsity,
        }
    }

    /// Apply mask to weights
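    /// (elementwise multiplication by the binary mask).
    ///
    /// # Example
    ///
    /// A tiny sketch (`ignore`d doctest; the import path is an assumption):
    ///
    /// ```ignore
    /// use kizzasi_core::pruning::PruningMask;
    /// use scirs2_core::ndarray::Array2;
    ///
    /// // Keep only the first row.
    /// let mask = PruningMask::new(Array2::from_shape_fn((2, 2), |(i, _)| {
    ///     if i == 0 { 1.0 } else { 0.0 }
    /// }));
    /// let pruned = mask.apply(&Array2::ones((2, 2)));
    /// assert_eq!(pruned[[0, 0]], 1.0);
    /// assert_eq!(pruned[[1, 1]], 0.0);
    /// ```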
    pub fn apply(&self, weights: &Array2<f32>) -> Array2<f32> {
        weights * &self.mask
    }

    /// Count remaining (non-zero) parameters
    pub fn num_parameters(&self) -> usize {
        self.mask.iter().filter(|&&x| x != 0.0).count()
    }

    /// Get compression ratio
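    /// (`1 / (1 - sparsity)`, with the denominator floored to avoid division
    /// by zero: e.g. 75% sparsity gives a 4x ratio).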
    pub fn compression_ratio(&self) -> f32 {
        1.0 / (1.0 - self.sparsity).max(1e-6)
    }
}

/// Structured pruner for neural network layers
pub struct StructuredPruner {
    config: PruningConfig,
    masks: HashMap<String, PruningMask>,
}

impl StructuredPruner {
    /// Create a new structured pruner
    pub fn new(config: PruningConfig) -> CoreResult<Self> {
        config.validate()?;
        Ok(Self {
            config,
            masks: HashMap::new(),
        })
    }

    /// Prune a 2D weight matrix
    pub fn prune(&mut self, name: &str, weights: &Array2<f32>) -> CoreResult<PruningMask> {
        let mask = match self.config.granularity {
            PruningGranularity::Unstructured => self.prune_unstructured(weights)?,
            PruningGranularity::Channel => self.prune_channels(weights)?,
            PruningGranularity::Filter => self.prune_filters(weights)?,
            _ => {
                return Err(CoreError::InvalidConfig(format!(
                    "Granularity {:?} not yet implemented for 2D tensors",
                    self.config.granularity
                )))
            }
        };

        self.masks.insert(name.to_string(), mask.clone());
        Ok(mask)
    }

    /// Unstructured pruning (individual weights)
    fn prune_unstructured(&self, weights: &Array2<f32>) -> CoreResult<PruningMask> {
        let importance = self.compute_importance(weights)?;
        let threshold = self.compute_threshold(&importance)?;

        // Importance scores are already non-negative, so compare directly.
        let mask = importance.mapv(|v| if v >= threshold { 1.0 } else { 0.0 });
        Ok(PruningMask::new(mask))
    }

    /// Channel-wise pruning (prune entire output channels)
    fn prune_channels(&self, weights: &Array2<f32>) -> CoreResult<PruningMask> {
        let (out_channels, _in_features) = weights.dim();

        // Compute an importance score for each output channel (row)
        let mut channel_importance = Vec::with_capacity(out_channels);
        for channel_idx in 0..out_channels {
            let channel = weights.row(channel_idx);
            let importance = match self.config.strategy {
                PruningStrategy::L1Norm => channel.iter().map(|x| x.abs()).sum::<f32>(),
                PruningStrategy::L2Norm => channel.iter().map(|x| x.powi(2)).sum::<f32>().sqrt(),
                PruningStrategy::Magnitude => {
                    channel.iter().map(|x| x.abs()).sum::<f32>() / channel.len() as f32
                }
                _ => channel.iter().map(|x| x.abs()).sum::<f32>(),
            };
            channel_importance.push((channel_idx, importance));
        }

        // Sort by importance (ascending); total_cmp avoids a panic on NaN scores
        channel_importance.sort_by(|a, b| a.1.total_cmp(&b.1));

        // Determine how many channels to prune
        let num_to_prune = (out_channels as f32 * self.config.target_sparsity) as usize;

        // Zero out the rows of the least important channels
        let mut mask = Array2::ones(weights.dim());
        for &(channel_idx, _) in channel_importance.iter().take(num_to_prune) {
            mask.row_mut(channel_idx).fill(0.0);
        }

        Ok(PruningMask::new(mask))
    }

    /// Filter-wise pruning (similar to channel pruning for conv layers)
    fn prune_filters(&self, weights: &Array2<f32>) -> CoreResult<PruningMask> {
        // For 2D weights, treat filters as output channels
        self.prune_channels(weights)
    }

    /// Compute importance scores for weights
    fn compute_importance(&self, weights: &Array2<f32>) -> CoreResult<Array2<f32>> {
        let importance = match self.config.strategy {
            PruningStrategy::Magnitude => weights.mapv(|x| x.abs()),
            PruningStrategy::L1Norm => weights.mapv(|x| x.abs()),
            PruningStrategy::L2Norm => weights.mapv(|x| x.powi(2)),
            PruningStrategy::Random => {
                // Random importance for the baseline strategy
                use scirs2_core::random::thread_rng;
                let mut rng = thread_rng();
                Array2::from_shape_fn(weights.dim(), |_| rng.random::<f32>())
            }
            PruningStrategy::Gradient => {
                // Would need gradient information; use magnitude as a fallback
                weights.mapv(|x| x.abs())
            }
        };

        Ok(importance)
    }

    /// Compute the pruning threshold at the target-sparsity percentile
    fn compute_threshold(&self, importance: &Array2<f32>) -> CoreResult<f32> {
        // Flatten and sort importance values (total_cmp avoids a NaN panic)
        let mut values: Vec<f32> = importance.iter().copied().collect();
        values.sort_by(|a, b| a.total_cmp(b));

        // Find the threshold at the target sparsity percentile
        let threshold_idx = (values.len() as f32 * self.config.target_sparsity) as usize;
        let threshold = values.get(threshold_idx).copied().unwrap_or(0.0);

        Ok(threshold)
    }

    /// Progressive pruning over multiple iterations
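    ///
    /// Sparsity is scheduled linearly toward the target: with a 0.6 target and
    /// 3 iterations, the cumulative per-iteration targets are 0.2, 0.4, and 0.6.
    ///
    /// # Example
    ///
    /// A sketch of the call pattern (`ignore`d doctest; paths are assumptions):
    ///
    /// ```ignore
    /// use kizzasi_core::pruning::{PruningConfig, PruningStrategy, StructuredPruner};
    /// use scirs2_core::ndarray::Array2;
    ///
    /// let config = PruningConfig::new(PruningStrategy::Magnitude, 0.6).with_iterations(3);
    /// let mut pruner = StructuredPruner::new(config).unwrap();
    /// let weights = Array2::from_shape_fn((8, 8), |(i, j)| (i + j) as f32);
    ///
    /// let masks = pruner.prune_progressive("layer", &weights).unwrap();
    /// assert_eq!(masks.len(), 3);
    /// // Each mask is at least as sparse as the one before it.
    /// assert!(masks[2].sparsity >= masks[0].sparsity);
    /// ```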
    pub fn prune_progressive(
        &mut self,
        name: &str,
        weights: &Array2<f32>,
    ) -> CoreResult<Vec<PruningMask>> {
        let mut masks = Vec::with_capacity(self.config.num_iterations);

        let mut current_weights = weights.clone();
        for iter in 0..self.config.num_iterations {
            // Use the cumulative sparsity target for this iteration. Weights
            // pruned earlier are zero and always fall below the threshold, so
            // they stay pruned; a fixed per-iteration fraction would mostly
            // re-select those zeros and never approach the overall target.
            let cumulative_sparsity = self.config.target_sparsity * (iter + 1) as f32
                / self.config.num_iterations as f32;
            let iter_config = PruningConfig {
                target_sparsity: cumulative_sparsity,
                ..self.config.clone()
            };

            let mut iter_pruner = StructuredPruner::new(iter_config)?;
            let mask = iter_pruner.prune(&format!("{}_{}", name, iter), &current_weights)?;

            // Apply the mask so the next iteration sees pruned weights as zero
            current_weights = mask.apply(&current_weights);
            masks.push(mask);
        }

        // Store the final (most sparse) mask for this layer
        if let Some(final_mask) = masks.last() {
            self.masks.insert(name.to_string(), final_mask.clone());
        }

        Ok(masks)
    }

    /// Get mask for a layer
    pub fn get_mask(&self, name: &str) -> Option<&PruningMask> {
        self.masks.get(name)
    }

    /// Get all masks
    pub fn masks(&self) -> &HashMap<String, PruningMask> {
        &self.masks
    }

    /// Compute global sparsity across all pruned layers
    pub fn global_sparsity(&self) -> f32 {
        if self.masks.is_empty() {
            return 0.0;
        }

        let total_params: usize = self.masks.values().map(|m| m.mask.len()).sum();
        let pruned_params: usize = self
            .masks
            .values()
            .map(|m| m.mask.iter().filter(|&&x| x == 0.0).count())
            .sum();

        pruned_params as f32 / total_params as f32
    }

    /// Get compression ratio across all layers
    pub fn global_compression_ratio(&self) -> f32 {
        let sparsity = self.global_sparsity();
        1.0 / (1.0 - sparsity).max(1e-6)
    }
}

/// Gradient-based pruning (requires gradient information)
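///
/// Importance is scored as `|weight * accumulated_gradient|`, a first-order
/// estimate of each weight's contribution to the loss.
///
/// # Example
///
/// A training-loop sketch (`ignore`d doctest; the gradients are placeholders
/// for a real backward pass, and the paths are assumptions):
///
/// ```ignore
/// use kizzasi_core::pruning::{GradientPruner, PruningConfig, PruningStrategy};
/// use scirs2_core::ndarray::Array2;
///
/// let config = PruningConfig::new(PruningStrategy::Gradient, 0.5);
/// let mut pruner = GradientPruner::new(config).unwrap();
///
/// let weights = Array2::from_shape_fn((4, 4), |(i, j)| (i + j) as f32);
/// for _step in 0..10 {
///     let grad = Array2::ones((4, 4)); // stand-in for a real gradient
///     pruner.accumulate_gradient("layer", &grad);
/// }
/// let mask = pruner.prune_with_gradients("layer", &weights).unwrap();
/// assert!(mask.sparsity > 0.0);
/// ```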
pub struct GradientPruner {
    pruner: StructuredPruner,
    /// Gradient accumulation for importance scoring
    gradient_accumulator: HashMap<String, Array2<f32>>,
}

impl GradientPruner {
    /// Create a new gradient-based pruner
    pub fn new(config: PruningConfig) -> CoreResult<Self> {
        Ok(Self {
            pruner: StructuredPruner::new(config)?,
            gradient_accumulator: HashMap::new(),
        })
    }

    /// Accumulate gradient for a layer
    pub fn accumulate_gradient(&mut self, name: &str, gradient: &Array2<f32>) {
        let acc = self
            .gradient_accumulator
            .entry(name.to_string())
            .or_insert_with(|| Array2::zeros(gradient.dim()));
        *acc += gradient;
    }

    /// Prune using accumulated gradients
    pub fn prune_with_gradients(
        &mut self,
        name: &str,
        weights: &Array2<f32>,
    ) -> CoreResult<PruningMask> {
        // Get the accumulated gradients for this layer
        let gradients = self
            .gradient_accumulator
            .get(name)
            .ok_or_else(|| CoreError::InvalidConfig("No gradients accumulated".into()))?;

        // Compute importance as |weight * gradient|
        let importance = weights * gradients;
        let importance = importance.mapv(|x| x.abs());

        // Threshold the importance scores to build the mask
        let threshold = self.compute_gradient_threshold(&importance)?;
        let mask = importance.mapv(|v| if v >= threshold { 1.0 } else { 0.0 });

        let pruning_mask = PruningMask::new(mask);
        self.pruner
            .masks
            .insert(name.to_string(), pruning_mask.clone());

        Ok(pruning_mask)
    }

    /// Compute threshold from gradient-based importance
    fn compute_gradient_threshold(&self, importance: &Array2<f32>) -> CoreResult<f32> {
        // total_cmp avoids a panic on NaN importance values
        let mut values: Vec<f32> = importance.iter().copied().collect();
        values.sort_by(|a, b| a.total_cmp(b));

        let threshold_idx = (values.len() as f32 * self.pruner.config.target_sparsity) as usize;
        Ok(values.get(threshold_idx).copied().unwrap_or(0.0))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pruning_config() {
        let config = PruningConfig::new(PruningStrategy::Magnitude, 0.5);
        assert_eq!(config.strategy, PruningStrategy::Magnitude);
        assert_eq!(config.target_sparsity, 0.5);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_pruning_config_validation() {
        let mut config = PruningConfig::new(PruningStrategy::Magnitude, 1.5);
        assert!(config.validate().is_err());

        config.target_sparsity = -0.1;
        assert!(config.validate().is_err());

        config.target_sparsity = 0.5;
        config.num_iterations = 0;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_unstructured_pruning() {
        let config = PruningConfig::new(PruningStrategy::Magnitude, 0.5);
        let mut pruner = StructuredPruner::new(config).unwrap();

        // Create weights with a uniform spread of magnitudes for predictable pruning
        let weights = Array2::from_shape_fn((10, 10), |(i, j)| ((i * 10 + j) as f32) * 0.01);

        let mask = pruner.prune("layer1", &weights).unwrap();
        // Should be close to 50% sparsity; allow some tolerance
        assert!(
            mask.sparsity >= 0.45 && mask.sparsity <= 0.55,
            "Expected sparsity ~0.5, got {}",
            mask.sparsity
        );
    }

    #[test]
    fn test_channel_pruning() {
        let config = PruningConfig::new(PruningStrategy::L2Norm, 0.5)
            .with_granularity(PruningGranularity::Channel);
        let mut pruner = StructuredPruner::new(config).unwrap();

        // First 4 channels are more important than the last 4
        let weights = Array2::from_shape_fn((8, 16), |(i, _j)| if i < 4 { 1.0 } else { 0.1 });

        let mask = pruner.prune("layer1", &weights).unwrap();

        // Check that entire channels (rows) are either kept or pruned
        for row in mask.mask.axis_iter(Axis(0)) {
            let sum: f32 = row.sum();
            assert!(sum == 0.0 || sum == row.len() as f32);
        }
    }

    #[test]
    fn test_pruning_mask_apply() {
        let mask_data = Array2::from_shape_fn((4, 4), |(i, j)| if i == j { 1.0 } else { 0.0 });
        let mask = PruningMask::new(mask_data);

        let weights = Array2::ones((4, 4));
        let pruned = mask.apply(&weights);

        // Should only keep diagonal elements
        for i in 0..4 {
            for j in 0..4 {
                if i == j {
                    assert_eq!(pruned[[i, j]], 1.0);
                } else {
                    assert_eq!(pruned[[i, j]], 0.0);
                }
            }
        }
    }
509    #[test]
510    fn test_progressive_pruning() {
511        let config = PruningConfig::new(PruningStrategy::Magnitude, 0.6).with_iterations(3);
512        let mut pruner = StructuredPruner::new(config).unwrap();
513
514        let weights = Array2::from_shape_fn((8, 8), |(i, j)| (i as f32 + j as f32) * 0.1);
515
516        let masks = pruner.prune_progressive("layer1", &weights).unwrap();
517        assert_eq!(masks.len(), 3);
518
519        // Sparsity should increase with iterations
520        for i in 1..masks.len() {
521            assert!(masks[i].sparsity >= masks[i - 1].sparsity);
522        }
523    }
524
525    #[test]
526    fn test_compression_ratio() {
527        let mask = PruningMask::new(Array2::from_shape_fn((10, 10), |(i, j)| {
528            if i + j < 5 {
529                1.0
530            } else {
531                0.0
532            }
533        }));
534
535        let ratio = mask.compression_ratio();
536        assert!(ratio > 1.0); // Should have compression
537        assert!(ratio < 10.0); // But not too extreme
538    }
539
540    #[test]
541    fn test_global_sparsity() {
542        let config = PruningConfig::new(PruningStrategy::Magnitude, 0.5);
543        let mut pruner = StructuredPruner::new(config).unwrap();
544
545        let weights1 = Array2::from_shape_fn((4, 4), |(i, j)| (i + j) as f32);
546        let weights2 = Array2::from_shape_fn((4, 4), |(i, j)| (i * j) as f32);
547
548        pruner.prune("layer1", &weights1).unwrap();
549        pruner.prune("layer2", &weights2).unwrap();
550
551        let global_sparsity = pruner.global_sparsity();
552        assert!((0.4..=0.6).contains(&global_sparsity));
553    }
554
555    #[test]
556    fn test_gradient_pruner_accumulation() {
557        let config = PruningConfig::new(PruningStrategy::Gradient, 0.5);
558        let mut pruner = GradientPruner::new(config).unwrap();
559
560        let gradient1 = Array2::ones((4, 4));
561        let gradient2 = Array2::ones((4, 4)) * 2.0;
562
563        pruner.accumulate_gradient("layer1", &gradient1);
564        pruner.accumulate_gradient("layer1", &gradient2);
565
566        let accumulated = &pruner.gradient_accumulator["layer1"];
567        assert_eq!(accumulated[[0, 0]], 3.0);
568    }
569
570    #[test]
571    fn test_random_pruning() {
572        let config = PruningConfig::new(PruningStrategy::Random, 0.5);
573        let mut pruner = StructuredPruner::new(config).unwrap();
574
575        let weights = Array2::ones((10, 10));
576        let mask = pruner.prune("layer1", &weights).unwrap();
577
578        // Should achieve approximately 50% sparsity
579        assert!(mask.sparsity >= 0.4 && mask.sparsity <= 0.6);
580    }
581}