use crate::{CoreError, CoreResult};
use scirs2_core::ndarray::Array2;
#[allow(unused_imports)]
use scirs2_core::ndarray::Axis;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// Criterion used to score weight importance during pruning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PruningStrategy {
    /// Prune weights with the smallest absolute value.
    Magnitude,
    /// Rank channels/filters by the L1 norm of their weights.
    L1Norm,
    /// Rank channels/filters by the L2 norm of their weights.
    L2Norm,
    /// Score weights by accumulated gradient information.
    Gradient,
    /// Prune uniformly at random (useful as a baseline).
    Random,
}

/// Granularity at which pruning masks are applied.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PruningGranularity {
    /// Individual weights are zeroed independently.
    Unstructured,
    /// Entire output channels (rows of a 2D weight matrix) are zeroed.
    Channel,
    /// Entire filters are zeroed.
    Filter,
    /// Entire attention heads are zeroed.
    Head,
    /// Contiguous weight blocks are zeroed.
    Block,
}

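/// Configuration for a pruning pass: the importance criterion, the mask
/// granularity, and how aggressively to prune.
///
/// A minimal builder-style sketch (marked `ignore` only because the
/// enclosing crate's import path is not known from this module):
///
/// ```ignore
/// let config = PruningConfig::new(PruningStrategy::L1Norm, 0.75)
///     .with_granularity(PruningGranularity::Channel)
///     .with_iterations(4);
/// assert!(config.validate().is_ok());
/// ```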
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PruningConfig {
    /// Importance criterion used to rank weights.
    pub strategy: PruningStrategy,
    /// Structural granularity of the mask.
    pub granularity: PruningGranularity,
    /// Fraction of parameters to zero out, in `[0, 1)`.
    pub target_sparsity: f32,
    /// Compute one threshold across all layers instead of per layer.
    pub global_threshold: bool,
    /// Number of rounds for progressive pruning.
    pub num_iterations: usize,
    /// Retain the original values of pruned weights for later restoration.
    pub keep_pruned_weights: bool,
}

impl Default for PruningConfig {
    fn default() -> Self {
        Self {
            strategy: PruningStrategy::Magnitude,
            granularity: PruningGranularity::Unstructured,
            target_sparsity: 0.5,
            global_threshold: false,
            num_iterations: 1,
            keep_pruned_weights: false,
        }
    }
}

impl PruningConfig {
    /// Creates a config with the given strategy and sparsity target, leaving
    /// the remaining fields at their defaults.
    pub fn new(strategy: PruningStrategy, target_sparsity: f32) -> Self {
        Self {
            strategy,
            target_sparsity,
            ..Default::default()
        }
    }

    /// Sets the pruning granularity.
    pub fn with_granularity(mut self, granularity: PruningGranularity) -> Self {
        self.granularity = granularity;
        self
    }

    /// Uses a single threshold across all layers instead of per-layer thresholds.
    pub fn with_global_threshold(mut self) -> Self {
        self.global_threshold = true;
        self
    }

    /// Sets the number of rounds for progressive pruning.
    pub fn with_iterations(mut self, num_iterations: usize) -> Self {
        self.num_iterations = num_iterations;
        self
    }

    /// Retains the original values of pruned weights on the mask.
    pub fn with_keep_weights(mut self) -> Self {
        self.keep_pruned_weights = true;
        self
    }

    /// Checks that the sparsity target and iteration count are usable.
    pub fn validate(&self) -> CoreResult<()> {
        if self.target_sparsity < 0.0 || self.target_sparsity >= 1.0 {
            return Err(CoreError::InvalidConfig(
                "target_sparsity must be in [0, 1)".into(),
            ));
        }
        if self.num_iterations == 0 {
            return Err(CoreError::InvalidConfig(
                "num_iterations must be > 0".into(),
            ));
        }
        Ok(())
    }
}

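/// A binary mask over a 2D weight matrix: `1.0` keeps a weight, `0.0`
/// prunes it.
///
/// A small sketch of building and inspecting a mask (marked `ignore`;
/// assumes `Array2` from `scirs2_core::ndarray` is in scope):
///
/// ```ignore
/// let mask = PruningMask::new(Array2::from_elem((2, 2), 1.0));
/// assert_eq!(mask.sparsity, 0.0);
/// assert_eq!(mask.compression_ratio(), 1.0);
/// ```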
#[derive(Debug, Clone)]
pub struct PruningMask {
    /// Binary keep/prune mask with the same shape as the weights.
    pub mask: Array2<f32>,
    /// Original values of pruned weights, if retained.
    pub pruned_weights: Option<Array2<f32>>,
    /// Fraction of entries that are zero.
    pub sparsity: f32,
}

impl PruningMask {
    /// Builds a mask and computes its sparsity from the zero entries.
    pub fn new(mask: Array2<f32>) -> Self {
        let total = mask.len();
        let zeros = mask.iter().filter(|&&x| x == 0.0).count();
        let sparsity = zeros as f32 / total as f32;

        Self {
            mask,
            pruned_weights: None,
            sparsity,
        }
    }

    /// Applies the mask element-wise, zeroing pruned weights.
    pub fn apply(&self, weights: &Array2<f32>) -> Array2<f32> {
        weights * &self.mask
    }

    /// Number of parameters that survive pruning.
    pub fn num_parameters(&self) -> usize {
        self.mask.iter().filter(|&&x| x != 0.0).count()
    }

    /// Dense-to-sparse compression ratio, `1 / (1 - sparsity)`, clamped to
    /// avoid division by zero at full sparsity.
    pub fn compression_ratio(&self) -> f32 {
        1.0 / (1.0 - self.sparsity).max(1e-6)
    }
}

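/// Applies the configured pruning strategy to individual weight matrices and
/// records a [`PruningMask`] per layer name.
///
/// A hedged end-to-end sketch (marked `ignore`; the `weights` array and the
/// import path are assumptions):
///
/// ```ignore
/// let config = PruningConfig::new(PruningStrategy::Magnitude, 0.5);
/// let mut pruner = StructuredPruner::new(config)?;
/// let mask = pruner.prune("fc1", &weights)?;
/// let sparse_weights = mask.apply(&weights);
/// println!("compression: {:.1}x", mask.compression_ratio());
/// ```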
pub struct StructuredPruner {
    config: PruningConfig,
    masks: HashMap<String, PruningMask>,
}

impl StructuredPruner {
    /// Creates a pruner after validating the configuration.
    pub fn new(config: PruningConfig) -> CoreResult<Self> {
        config.validate()?;
        Ok(Self {
            config,
            masks: HashMap::new(),
        })
    }

    /// Prunes a single weight matrix and stores the resulting mask under
    /// `name`.
    pub fn prune(&mut self, name: &str, weights: &Array2<f32>) -> CoreResult<PruningMask> {
        let mut mask = match self.config.granularity {
            PruningGranularity::Unstructured => self.prune_unstructured(weights)?,
            PruningGranularity::Channel => self.prune_channels(weights)?,
            PruningGranularity::Filter => self.prune_filters(weights)?,
            _ => {
                return Err(CoreError::InvalidConfig(format!(
                    "Granularity {:?} not yet implemented for 2D tensors",
                    self.config.granularity
                )))
            }
        };

        // Honor `keep_pruned_weights`: retain the values the mask zeroes out
        // so they can be restored later. Previously this flag had no effect.
        if self.config.keep_pruned_weights {
            let inverse = mask.mask.mapv(|m| 1.0 - m);
            mask.pruned_weights = Some(weights * &inverse);
        }

        self.masks.insert(name.to_string(), mask.clone());
        Ok(mask)
    }

    /// Computes per-element importance scores and zeroes everything below
    /// the sparsity-quantile threshold.
    fn prune_unstructured(&self, weights: &Array2<f32>) -> CoreResult<PruningMask> {
        let importance = self.compute_importance(weights)?;
        let threshold = self.compute_threshold(&importance)?;

        let mask = importance.mapv(|v| if v.abs() >= threshold { 1.0 } else { 0.0 });
        Ok(PruningMask::new(mask))
    }

    /// Scores each output channel (row) by the configured norm and zeroes
    /// the lowest-scoring fraction of rows.
    fn prune_channels(&self, weights: &Array2<f32>) -> CoreResult<PruningMask> {
        let (out_channels, _in_features) = weights.dim();

        let mut channel_importance = Vec::with_capacity(out_channels);
        for channel_idx in 0..out_channels {
            let channel = weights.row(channel_idx);
            let importance = match self.config.strategy {
                PruningStrategy::L1Norm => channel.iter().map(|x| x.abs()).sum::<f32>(),
                PruningStrategy::L2Norm => channel.iter().map(|x| x.powi(2)).sum::<f32>().sqrt(),
                PruningStrategy::Magnitude => {
                    channel.iter().map(|x| x.abs()).sum::<f32>() / channel.len() as f32
                }
                _ => channel.iter().map(|x| x.abs()).sum::<f32>(),
            };
            channel_importance.push((channel_idx, importance));
        }

        // `total_cmp` avoids the panic `partial_cmp(..).unwrap()` would hit
        // on NaN importance scores.
        channel_importance.sort_by(|a, b| a.1.total_cmp(&b.1));

        let num_to_prune = (out_channels as f32 * self.config.target_sparsity) as usize;

        let mut mask = Array2::ones(weights.dim());
        for &(channel_idx, _) in channel_importance.iter().take(num_to_prune) {
            mask.row_mut(channel_idx).fill(0.0);
        }

        Ok(PruningMask::new(mask))
    }

    /// For 2D tensors, filter pruning coincides with channel (row) pruning.
    fn prune_filters(&self, weights: &Array2<f32>) -> CoreResult<PruningMask> {
        self.prune_channels(weights)
    }

    /// Maps each weight to an importance score according to the strategy.
    fn compute_importance(&self, weights: &Array2<f32>) -> CoreResult<Array2<f32>> {
        let importance = match self.config.strategy {
            PruningStrategy::Magnitude => weights.mapv(|x| x.abs()),
            PruningStrategy::L1Norm => weights.mapv(|x| x.abs()),
            // Squared magnitude ranks identically to |x| for thresholding.
            PruningStrategy::L2Norm => weights.mapv(|x| x.powi(2)),
            PruningStrategy::Random => {
                use scirs2_core::random::thread_rng;
                let mut rng = thread_rng();
                Array2::from_shape_fn(weights.dim(), |_| rng.random::<f32>())
            }
            // Without accumulated gradients, fall back to magnitude; see
            // `GradientPruner` for gradient-aware scoring.
            PruningStrategy::Gradient => weights.mapv(|x| x.abs()),
        };

        Ok(importance)
    }

    /// Returns the importance value at the `target_sparsity` quantile;
    /// everything strictly below it gets pruned.
    fn compute_threshold(&self, importance: &Array2<f32>) -> CoreResult<f32> {
        let mut values: Vec<f32> = importance.iter().copied().collect();
        values.sort_by(|a, b| a.total_cmp(b));

        let threshold_idx = (values.len() as f32 * self.config.target_sparsity) as usize;
        let threshold = values.get(threshold_idx).copied().unwrap_or(0.0);

        Ok(threshold)
    }

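    /// Reaches `target_sparsity` gradually over `num_iterations` rounds,
    /// re-scoring the surviving weights after each round.
    ///
    /// A hedged sketch (marked `ignore`; paths and `weights` are assumed).
    /// With `target_sparsity = 0.6` and three iterations, the cumulative
    /// per-round targets are 0.2, 0.4, and 0.6:
    ///
    /// ```ignore
    /// let config = PruningConfig::new(PruningStrategy::Magnitude, 0.6).with_iterations(3);
    /// let mut pruner = StructuredPruner::new(config)?;
    /// let masks = pruner.prune_progressive("fc1", &weights)?;
    /// assert_eq!(masks.len(), 3);
    /// assert!(masks[2].sparsity >= masks[0].sparsity);
    /// ```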
    pub fn prune_progressive(
        &mut self,
        name: &str,
        weights: &Array2<f32>,
    ) -> CoreResult<Vec<PruningMask>> {
        let mut masks = Vec::with_capacity(self.config.num_iterations);

        let mut current_weights = weights.clone();
        for iter in 0..self.config.num_iterations {
            // Raise the cumulative target each round. Pruning a fixed
            // per-round fraction would stall, because already-zeroed weights
            // re-enter the threshold quantile on every pass.
            let cumulative_sparsity = self.config.target_sparsity * (iter + 1) as f32
                / self.config.num_iterations as f32;
            let iter_config = PruningConfig {
                target_sparsity: cumulative_sparsity,
                ..self.config.clone()
            };

            let mut iter_pruner = StructuredPruner::new(iter_config)?;
            let mask = iter_pruner.prune(&format!("{}_{}", name, iter), &current_weights)?;

            current_weights = mask.apply(&current_weights);
            masks.push(mask);
        }

        if let Some(final_mask) = masks.last() {
            self.masks.insert(name.to_string(), final_mask.clone());
        }

        Ok(masks)
    }

    /// Returns the stored mask for a layer, if one exists.
    pub fn get_mask(&self, name: &str) -> Option<&PruningMask> {
        self.masks.get(name)
    }

    /// All masks produced so far, keyed by layer name.
    pub fn masks(&self) -> &HashMap<String, PruningMask> {
        &self.masks
    }

    /// Overall sparsity across every pruned layer.
    pub fn global_sparsity(&self) -> f32 {
        if self.masks.is_empty() {
            return 0.0;
        }

        let total_params: usize = self.masks.values().map(|m| m.mask.len()).sum();
        let pruned_params: usize = self
            .masks
            .values()
            .map(|m| m.mask.iter().filter(|&&x| x == 0.0).count())
            .sum();

        pruned_params as f32 / total_params as f32
    }

    /// Overall compression ratio, `1 / (1 - global_sparsity)`.
    pub fn global_compression_ratio(&self) -> f32 {
        let sparsity = self.global_sparsity();
        1.0 / (1.0 - sparsity).max(1e-6)
    }
}

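/// Prunes using accumulated gradient information: importance is
/// `|weight * gradient|`, a first-order estimate of how much removing a
/// weight would perturb the loss.
///
/// A hedged usage sketch (marked `ignore`; `weights`, `batches`, and the
/// hypothetical `backward` helper are assumptions for illustration):
///
/// ```ignore
/// let config = PruningConfig::new(PruningStrategy::Gradient, 0.5);
/// let mut pruner = GradientPruner::new(config)?;
/// for batch in &batches {
///     let grads = backward(batch); // however gradients are produced
///     pruner.accumulate_gradient("fc1", &grads);
/// }
/// let mask = pruner.prune_with_gradients("fc1", &weights)?;
/// ```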
pub struct GradientPruner {
    pruner: StructuredPruner,
    gradient_accumulator: HashMap<String, Array2<f32>>,
}

impl GradientPruner {
    /// Creates a gradient-aware pruner after validating the configuration.
    pub fn new(config: PruningConfig) -> CoreResult<Self> {
        Ok(Self {
            pruner: StructuredPruner::new(config)?,
            gradient_accumulator: HashMap::new(),
        })
    }

    /// Adds a gradient sample into the running per-layer accumulator.
    pub fn accumulate_gradient(&mut self, name: &str, gradient: &Array2<f32>) {
        let acc = self
            .gradient_accumulator
            .entry(name.to_string())
            .or_insert_with(|| Array2::zeros(gradient.dim()));
        *acc += gradient;
    }

    /// Prunes `weights` using the accumulated gradients for `name` as the
    /// importance signal.
    pub fn prune_with_gradients(
        &mut self,
        name: &str,
        weights: &Array2<f32>,
    ) -> CoreResult<PruningMask> {
        let gradients = self
            .gradient_accumulator
            .get(name)
            .ok_or_else(|| CoreError::InvalidConfig("No gradients accumulated".into()))?;

        // First-order importance: |w * dL/dw|.
        let importance = weights * gradients;
        let importance = importance.mapv(|x| x.abs());

        let threshold = self.compute_gradient_threshold(&importance)?;
        let mask = importance.mapv(|v| if v >= threshold { 1.0 } else { 0.0 });

        let pruning_mask = PruningMask::new(mask);
        self.pruner
            .masks
            .insert(name.to_string(), pruning_mask.clone());

        Ok(pruning_mask)
    }

    /// Quantile threshold over gradient-based importance scores.
    fn compute_gradient_threshold(&self, importance: &Array2<f32>) -> CoreResult<f32> {
        let mut values: Vec<f32> = importance.iter().copied().collect();
        values.sort_by(|a, b| a.total_cmp(b));

        let threshold_idx = (values.len() as f32 * self.pruner.config.target_sparsity) as usize;
        Ok(values.get(threshold_idx).copied().unwrap_or(0.0))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pruning_config() {
        let config = PruningConfig::new(PruningStrategy::Magnitude, 0.5);
        assert_eq!(config.strategy, PruningStrategy::Magnitude);
        assert_eq!(config.target_sparsity, 0.5);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_pruning_config_validation() {
        let mut config = PruningConfig::new(PruningStrategy::Magnitude, 1.5);
        assert!(config.validate().is_err());

        config.target_sparsity = -0.1;
        assert!(config.validate().is_err());

        config.target_sparsity = 0.5;
        config.num_iterations = 0;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_unstructured_pruning() {
        let config = PruningConfig::new(PruningStrategy::Magnitude, 0.5);
        let mut pruner = StructuredPruner::new(config).unwrap();

        let weights = Array2::from_shape_fn((10, 10), |(i, j)| ((i * 10 + j) as f32) * 0.01);

        let mask = pruner.prune("layer1", &weights).unwrap();
        assert!(
            mask.sparsity >= 0.45 && mask.sparsity <= 0.55,
            "Expected sparsity ~0.5, got {}",
            mask.sparsity
        );
    }

    #[test]
    fn test_channel_pruning() {
        let config = PruningConfig::new(PruningStrategy::L2Norm, 0.5)
            .with_granularity(PruningGranularity::Channel);
        let mut pruner = StructuredPruner::new(config).unwrap();

        let weights = Array2::from_shape_fn((8, 16), |(i, _j)| if i < 4 { 1.0 } else { 0.1 });

        let mask = pruner.prune("layer1", &weights).unwrap();

        // Channel pruning must zero whole rows: each row is all-ones or all-zeros.
        for row in mask.mask.axis_iter(Axis(0)) {
            let sum: f32 = row.sum();
            assert!(sum == 0.0 || sum == row.len() as f32);
        }
    }

    #[test]
    fn test_pruning_mask_apply() {
        let mask_data = Array2::from_shape_fn((4, 4), |(i, j)| if i == j { 1.0 } else { 0.0 });
        let mask = PruningMask::new(mask_data);

        let weights = Array2::ones((4, 4));
        let pruned = mask.apply(&weights);

        for i in 0..4 {
            for j in 0..4 {
                if i == j {
                    assert_eq!(pruned[[i, j]], 1.0);
                } else {
                    assert_eq!(pruned[[i, j]], 0.0);
                }
            }
        }
    }

    #[test]
    fn test_progressive_pruning() {
        let config = PruningConfig::new(PruningStrategy::Magnitude, 0.6).with_iterations(3);
        let mut pruner = StructuredPruner::new(config).unwrap();

        let weights = Array2::from_shape_fn((8, 8), |(i, j)| (i as f32 + j as f32) * 0.1);

        let masks = pruner.prune_progressive("layer1", &weights).unwrap();
        assert_eq!(masks.len(), 3);

        // Sparsity must not decrease across progressive iterations.
        for i in 1..masks.len() {
            assert!(masks[i].sparsity >= masks[i - 1].sparsity);
        }
    }

    #[test]
    fn test_compression_ratio() {
        let mask = PruningMask::new(Array2::from_shape_fn((10, 10), |(i, j)| {
            if i + j < 5 {
                1.0
            } else {
                0.0
            }
        }));

        let ratio = mask.compression_ratio();
        assert!(ratio > 1.0);
        assert!(ratio < 10.0);
    }

    #[test]
    fn test_global_sparsity() {
        let config = PruningConfig::new(PruningStrategy::Magnitude, 0.5);
        let mut pruner = StructuredPruner::new(config).unwrap();

        let weights1 = Array2::from_shape_fn((4, 4), |(i, j)| (i + j) as f32);
        let weights2 = Array2::from_shape_fn((4, 4), |(i, j)| (i * j) as f32);

        pruner.prune("layer1", &weights1).unwrap();
        pruner.prune("layer2", &weights2).unwrap();

        let global_sparsity = pruner.global_sparsity();
        assert!((0.4..=0.6).contains(&global_sparsity));
    }

    #[test]
    fn test_gradient_pruner_accumulation() {
        let config = PruningConfig::new(PruningStrategy::Gradient, 0.5);
        let mut pruner = GradientPruner::new(config).unwrap();

        let gradient1 = Array2::ones((4, 4));
        let gradient2 = Array2::ones((4, 4)) * 2.0;

        pruner.accumulate_gradient("layer1", &gradient1);
        pruner.accumulate_gradient("layer1", &gradient2);

        let accumulated = &pruner.gradient_accumulator["layer1"];
        assert_eq!(accumulated[[0, 0]], 3.0);
    }

    #[test]
    fn test_random_pruning() {
        let config = PruningConfig::new(PruningStrategy::Random, 0.5);
        let mut pruner = StructuredPruner::new(config).unwrap();

        let weights = Array2::ones((10, 10));
        let mask = pruner.prune("layer1", &weights).unwrap();

        assert!(mask.sparsity >= 0.4 && mask.sparsity <= 0.6);
    }
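
    // A sketch exercising the `keep_pruned_weights` flag, under the
    // assumption (made explicit in `prune`) that pruned values are retained
    // on the returned mask as `pruned_weights`.
    #[test]
    fn test_keep_pruned_weights() {
        let config = PruningConfig::new(PruningStrategy::Magnitude, 0.5).with_keep_weights();
        let mut pruner = StructuredPruner::new(config).unwrap();

        let weights = Array2::from_shape_fn((10, 10), |(i, j)| ((i * 10 + j) as f32) * 0.01);
        let mask = pruner.prune("layer1", &weights).unwrap();

        let retained = mask.pruned_weights.as_ref().expect("pruned weights kept");
        for ((&m, &w), &p) in mask.mask.iter().zip(weights.iter()).zip(retained.iter()) {
            if m == 0.0 {
                // Pruned entries keep their original value in `pruned_weights`.
                assert_eq!(p, w);
            } else {
                assert_eq!(p, 0.0);
            }
        }
    }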
}