1use crate::model_merge::WeightTensor;
11use thiserror::Error;
12
/// Errors that can arise while validating configuration or pruning tensors.
#[derive(Debug, Error)]
pub enum PruningError {
    /// Target sparsity must lie in the half-open interval [0.0, 1.0).
    #[error("sparsity {0} must be in [0.0, 1.0)")]
    InvalidSparsity(f32),
    /// The named tensor contains no elements, so there is nothing to prune.
    #[error("empty tensor: '{0}'")]
    EmptyTensor(String),
    /// Structured (row/column) pruning only supports matrices; carries the
    /// offending shape.
    #[error("structured pruning requires 2D tensor, got shape {0:?}")]
    NotTwoDimensional(Vec<usize>),
    /// `min_nonzero` (first field) exceeds the tensor's total element count
    /// (second field), so the floor can never be honored.
    #[error("cannot prune below min_nonzero={0} with {1} total elements")]
    BelowMinNonzero(usize, usize),
}
29
/// Strategy used to score each weight (or row/column) before pruning;
/// lower-scored entries are pruned first.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ImportanceMetric {
    /// Absolute value |w| per element (L1 norm per unit in structured mode).
    L1Magnitude,
    /// Squared value w² per element (L2 norm per unit in structured mode).
    L2Magnitude,
    /// Proxy for a first-order Taylor criterion; scored as w² in this module
    /// since no gradient information is available here.
    TaylorProxy,
    /// Seeded pseudo-random scores (LCG); useful as a pruning baseline.
    Random { seed: u64 },
}
48
/// Granularity at which weights are removed.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PruningGranularity {
    /// Individual elements are zeroed anywhere in the tensor.
    Unstructured,
    /// Entire rows of a 2D tensor are zeroed.
    StructuredRow,
    /// Entire columns of a 2D tensor are zeroed.
    StructuredColumn,
}
63
/// Configuration for a single pruning pass.
#[derive(Debug, Clone)]
pub struct PruningConfig {
    /// Target fraction of elements (or rows/columns) to zero, in [0.0, 1.0).
    pub sparsity: f32,
    /// How importance is scored.
    pub metric: ImportanceMetric,
    /// Whether to prune individual elements, whole rows, or whole columns.
    pub granularity: PruningGranularity,
    /// Lower bound on the number of nonzero elements kept after pruning.
    pub min_nonzero: usize,
}
80
81impl PruningConfig {
82 pub fn new(sparsity: f32, metric: ImportanceMetric, granularity: PruningGranularity) -> Self {
84 Self {
85 sparsity,
86 metric,
87 granularity,
88 min_nonzero: 1,
89 }
90 }
91
92 pub fn unstructured_l1(sparsity: f32) -> Self {
94 Self::new(
95 sparsity,
96 ImportanceMetric::L1Magnitude,
97 PruningGranularity::Unstructured,
98 )
99 }
100
101 pub fn structured_row_l2(sparsity: f32) -> Self {
103 Self::new(
104 sparsity,
105 ImportanceMetric::L2Magnitude,
106 PruningGranularity::StructuredRow,
107 )
108 }
109}
110
/// Summary statistics over a set of importance scores.
#[derive(Debug, Clone)]
pub struct ScoreStats {
    /// Smallest score.
    pub min: f32,
    /// Largest score.
    pub max: f32,
    /// Arithmetic mean.
    pub mean: f32,
    /// Median; the midpoint of the two central values for even-length sets.
    pub median: f32,
    /// Population standard deviation (variance divides by N, not N-1).
    pub std_dev: f32,
}
124
/// Per-element importance scores for one tensor, plus a pruning threshold.
#[derive(Debug, Clone)]
pub struct ImportanceScores {
    /// One score per tensor element, in the tensor's flattened order.
    pub scores: Vec<f32>,
    /// Scores at or below this value count as pruned (see `sparsity()`).
    /// Initialized to 0.0 by `compute_importance`.
    pub threshold: f32,
    /// The metric that produced these scores.
    pub metric: ImportanceMetric,
}
139
140impl ImportanceScores {
141 pub fn sparsity(&self) -> f32 {
143 if self.scores.is_empty() {
144 return 0.0;
145 }
146 let below = self.scores.iter().filter(|&&s| s <= self.threshold).count();
147 below as f32 / self.scores.len() as f32
148 }
149
150 pub fn top_k(&self, k: usize) -> Vec<f32> {
152 let mut sorted = self.scores.clone();
153 sorted.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
154 sorted.truncate(k);
155 sorted
156 }
157
158 pub fn stats(&self) -> ScoreStats {
160 if self.scores.is_empty() {
161 return ScoreStats {
162 min: 0.0,
163 max: 0.0,
164 mean: 0.0,
165 median: 0.0,
166 std_dev: 0.0,
167 };
168 }
169
170 let n = self.scores.len();
171 let min = self.scores.iter().cloned().fold(f32::INFINITY, f32::min);
172 let max = self
173 .scores
174 .iter()
175 .cloned()
176 .fold(f32::NEG_INFINITY, f32::max);
177 let mean = self.scores.iter().sum::<f32>() / n as f32;
178
179 let variance = self.scores.iter().map(|s| (s - mean).powi(2)).sum::<f32>() / n as f32;
180 let std_dev = variance.sqrt();
181
182 let mut sorted = self.scores.clone();
183 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
184 let median = if n % 2 == 0 {
185 (sorted[n / 2 - 1] + sorted[n / 2]) / 2.0
186 } else {
187 sorted[n / 2]
188 };
189
190 ScoreStats {
191 min,
192 max,
193 mean,
194 median,
195 std_dev,
196 }
197 }
198}
199
/// Sparsity statistics for a single named tensor.
#[derive(Debug, Clone)]
pub struct SparsityReport {
    /// Tensor name, copied from the source tensor.
    pub name: String,
    /// Total number of elements.
    pub total_params: usize,
    /// Number of elements that are not exactly 0.0.
    pub nonzero_params: usize,
    /// Fraction of elements that are zero; 0.0 for an empty tensor.
    pub sparsity: f32,
    /// Tensor shape, copied from the source tensor.
    pub shape: Vec<usize>,
}
213
214impl SparsityReport {
215 pub fn compute(tensor: &WeightTensor) -> Self {
217 let total_params = tensor.data.len();
218 let nonzero_params = tensor.data.iter().filter(|&&x| x != 0.0).count();
219 let sparsity = if total_params == 0 {
220 0.0
221 } else {
222 1.0 - nonzero_params as f32 / total_params as f32
223 };
224 Self {
225 name: tensor.name.clone(),
226 total_params,
227 nonzero_params,
228 sparsity,
229 shape: tensor.shape.clone(),
230 }
231 }
232
233 pub fn zero_fraction(&self) -> f32 {
235 self.sparsity
236 }
237
238 pub fn density(&self) -> f32 {
240 1.0 - self.sparsity
241 }
242
243 pub fn summary(&self) -> String {
245 format!(
246 "tensor='{}' shape={:?} total={} nonzero={} sparsity={:.4}",
247 self.name, self.shape, self.total_params, self.nonzero_params, self.sparsity,
248 )
249 }
250}
251
252pub struct ModelSparsitySummary {
258 pub layer_reports: Vec<SparsityReport>,
259 pub total_params: usize,
260 pub total_nonzero: usize,
261 pub overall_sparsity: f32,
262}
263
264impl ModelSparsitySummary {
265 pub fn from_model(tensors: &[WeightTensor]) -> Self {
267 let layer_reports: Vec<SparsityReport> =
268 tensors.iter().map(SparsityReport::compute).collect();
269 let total_params: usize = layer_reports.iter().map(|r| r.total_params).sum();
270 let total_nonzero: usize = layer_reports.iter().map(|r| r.nonzero_params).sum();
271 let overall_sparsity = if total_params == 0 {
272 0.0
273 } else {
274 1.0 - total_nonzero as f32 / total_params as f32
275 };
276 Self {
277 layer_reports,
278 total_params,
279 total_nonzero,
280 overall_sparsity,
281 }
282 }
283
284 pub fn summary(&self) -> String {
286 format!(
287 "layers={} total_params={} total_nonzero={} overall_sparsity={:.4}",
288 self.layer_reports.len(),
289 self.total_params,
290 self.total_nonzero,
291 self.overall_sparsity,
292 )
293 }
294}
295
296pub fn compute_importance(tensor: &WeightTensor, metric: ImportanceMetric) -> ImportanceScores {
305 let scores = match metric {
306 ImportanceMetric::L1Magnitude => tensor.data.iter().map(|x| x.abs()).collect(),
307 ImportanceMetric::L2Magnitude => tensor.data.iter().map(|x| x * x).collect(),
308 ImportanceMetric::TaylorProxy => tensor.data.iter().map(|x| x * x).collect(),
309 ImportanceMetric::Random { seed } => {
310 let mut state = seed;
311 tensor.data.iter().map(|_| lcg_next(&mut state)).collect()
312 }
313 };
314 ImportanceScores {
315 scores,
316 threshold: 0.0,
317 metric,
318 }
319}
320
321pub fn prune_tensor(
325 tensor: &WeightTensor,
326 config: &PruningConfig,
327) -> Result<(WeightTensor, Vec<f32>), PruningError> {
328 let mut cloned = tensor.clone();
329 let mask = prune_tensor_inplace(&mut cloned, config)?;
330 Ok((cloned, mask))
331}
332
333pub fn prune_tensor_inplace(
335 tensor: &mut WeightTensor,
336 config: &PruningConfig,
337) -> Result<Vec<f32>, PruningError> {
338 validate_sparsity(config.sparsity)?;
339
340 let n = tensor.data.len();
341 if n == 0 {
342 return Err(PruningError::EmptyTensor(tensor.name.clone()));
343 }
344
345 match config.granularity {
346 PruningGranularity::Unstructured => prune_unstructured(tensor, config),
347 PruningGranularity::StructuredRow => prune_structured(tensor, config, true),
348 PruningGranularity::StructuredColumn => prune_structured(tensor, config, false),
349 }
350}
351
352pub fn prune_model(
354 tensors: &[WeightTensor],
355 config: &PruningConfig,
356) -> Result<Vec<WeightTensor>, PruningError> {
357 tensors
358 .iter()
359 .map(|t| {
360 let (pruned, _mask) = prune_tensor(t, config)?;
361 Ok(pruned)
362 })
363 .collect()
364}
365
366pub fn model_sparsity_report(tensors: &[WeightTensor]) -> Vec<SparsityReport> {
368 tensors.iter().map(SparsityReport::compute).collect()
369}
370
/// Advances a 64-bit linear congruential generator (Knuth's MMIX constants)
/// and returns a uniform sample in the half-open interval [0.0, 1.0).
///
/// Fix: the previous implementation divided a 32-bit draw by
/// `u32::MAX as f32 + 1.0`. Both `u32::MAX` and values near it round to 2^32
/// when cast to f32, so draws near the top of the range produced exactly 1.0,
/// escaping the intended half-open interval. Using the top 24 bits — which
/// fit an f32 mantissa exactly — divided by 2^24 keeps every result strictly
/// below 1.0 with no rounding.
#[inline]
fn lcg_next(state: &mut u64) -> f32 {
    *state = state
        .wrapping_mul(6_364_136_223_846_793_005)
        .wrapping_add(1_442_695_040_888_963_407);
    // High bits of an LCG have the best statistical quality; 24 bits are
    // exactly representable in f32, so bits / 2^24 is exact and < 1.0.
    let bits = (*state >> 40) as u32;
    bits as f32 / (1u32 << 24) as f32
}
384
385fn validate_sparsity(sparsity: f32) -> Result<(), PruningError> {
386 if !(0.0..1.0).contains(&sparsity) {
387 return Err(PruningError::InvalidSparsity(sparsity));
388 }
389 Ok(())
390}
391
392fn compute_element_scores(data: &[f32], metric: ImportanceMetric) -> Vec<f32> {
394 match metric {
395 ImportanceMetric::L1Magnitude => data.iter().map(|x| x.abs()).collect(),
396 ImportanceMetric::L2Magnitude => data.iter().map(|x| x * x).collect(),
397 ImportanceMetric::TaylorProxy => data.iter().map(|x| x * x).collect(),
398 ImportanceMetric::Random { seed } => {
399 let mut state = seed;
400 data.iter().map(|_| lcg_next(&mut state)).collect()
401 }
402 }
403}
404
405fn prune_unstructured(
407 tensor: &mut WeightTensor,
408 config: &PruningConfig,
409) -> Result<Vec<f32>, PruningError> {
410 let n = tensor.data.len();
411 let scores = compute_element_scores(&tensor.data, config.metric);
412
413 let num_to_prune = (config.sparsity * n as f32).floor() as usize;
415 let max_to_prune = n.saturating_sub(config.min_nonzero);
417 if config.min_nonzero > n {
418 return Err(PruningError::BelowMinNonzero(config.min_nonzero, n));
419 }
420 let num_to_prune = num_to_prune.min(max_to_prune);
421
422 if num_to_prune == 0 {
423 return Ok(vec![1.0f32; n]);
425 }
426
427 let mut indexed: Vec<(usize, f32)> = scores.iter().cloned().enumerate().collect();
429 indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
430
431 let threshold = indexed[num_to_prune - 1].1;
432
433 let mut mask = vec![1.0f32; n];
435 let mut pruned_count = 0usize;
436 for (orig_idx, score) in &indexed {
437 if pruned_count >= num_to_prune {
438 break;
439 }
440 if *score <= threshold {
441 mask[*orig_idx] = 0.0;
442 tensor.data[*orig_idx] = 0.0;
443 pruned_count += 1;
444 }
445 }
446
447 Ok(mask)
448}
449
/// Zeroes entire rows (`prune_rows == true`) or entire columns of a 2D
/// tensor, removing the units with the lowest aggregate importance.
///
/// Returns a flat mask aligned with `tensor.data` (1.0 = kept, 0.0 = pruned).
/// Assumes row-major layout — element (r, c) at index `r * cols + c` — as the
/// slicing below relies on it; confirm this matches `WeightTensor`'s layout.
///
/// # Errors
/// - `NotTwoDimensional` unless `tensor.shape` has exactly two dimensions.
/// - `BelowMinNonzero` if `config.min_nonzero` exceeds the element count.
fn prune_structured(
    tensor: &mut WeightTensor,
    config: &PruningConfig,
    prune_rows: bool,
) -> Result<Vec<f32>, PruningError> {
    if tensor.shape.len() != 2 {
        return Err(PruningError::NotTwoDimensional(tensor.shape.clone()));
    }

    let rows = tensor.shape[0];
    let cols = tensor.shape[1];
    // A "unit" is one row or one column; `unit_size` is its element count.
    let (num_units, unit_size) = if prune_rows {
        (rows, cols)
    } else {
        (cols, rows)
    };

    // Aggregate one importance score per unit.
    let unit_scores: Vec<f32> = (0..num_units)
        .map(|u| {
            // Copy the unit's elements out: rows are contiguous slices,
            // columns are gathered with stride `cols`.
            let slice: Vec<f32> = if prune_rows {
                tensor.data[u * cols..(u + 1) * cols].to_vec()
            } else {
                (0..rows).map(|r| tensor.data[r * cols + u]).collect()
            };
            match config.metric {
                // L1 norm of the unit.
                ImportanceMetric::L1Magnitude => slice.iter().map(|x| x.abs()).sum::<f32>(),
                // L2 norm of the unit.
                ImportanceMetric::L2Magnitude => slice.iter().map(|x| x * x).sum::<f32>().sqrt(),
                // Gradient-free Taylor proxy: same as the L2 norm here.
                ImportanceMetric::TaylorProxy => slice.iter().map(|x| x * x).sum::<f32>().sqrt(),
                // One deterministic pseudo-random draw per unit, seeded by
                // the configured seed offset by the unit index.
                ImportanceMetric::Random { seed } => {
                    let mut state = seed.wrapping_add(u as u64);
                    lcg_next(&mut state)
                }
            }
        })
        .collect();

    // Cap the prune count so at least ceil(min_nonzero / unit_size) units
    // survive — enough elements to satisfy `min_nonzero` when the surviving
    // units are fully dense.
    let num_to_prune = (config.sparsity * num_units as f32).floor() as usize;
    let max_to_prune = num_units.saturating_sub(config.min_nonzero.div_ceil(unit_size));
    if config.min_nonzero > num_units * unit_size {
        return Err(PruningError::BelowMinNonzero(
            config.min_nonzero,
            num_units * unit_size,
        ));
    }
    let num_to_prune = num_to_prune.min(max_to_prune);

    if num_to_prune == 0 {
        return Ok(vec![1.0f32; tensor.data.len()]);
    }

    // Sort units by ascending importance (stable; NaN pairs compare Equal).
    let mut indexed: Vec<(usize, f32)> = unit_scores.iter().cloned().enumerate().collect();
    indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));

    // The `num_to_prune` least-important units are dropped.
    let mut units_to_prune = std::collections::HashSet::new();
    for (unit_idx, _score) in indexed.iter().take(num_to_prune) {
        units_to_prune.insert(*unit_idx);
    }

    let total = tensor.data.len();
    let mut mask = vec![1.0f32; total];

    // Zero every element belonging to a pruned unit. Flat index -> unit id:
    // `idx / cols` selects the row, `idx % cols` selects the column.
    for (idx, slot) in mask.iter_mut().enumerate().take(total) {
        let unit = if prune_rows { idx / cols } else { idx % cols };
        if units_to_prune.contains(&unit) {
            *slot = 0.0;
            tensor.data[idx] = 0.0;
        }
    }

    Ok(mask)
}
527
#[cfg(test)]
mod tests {
    use super::*;

    // Shorthand around `WeightTensor::new` for building test fixtures.
    fn make_tensor(name: &str, data: Vec<f32>, shape: Vec<usize>) -> WeightTensor {
        WeightTensor::new(name, data, shape)
    }

    #[test]
    fn lcg_values_in_unit_interval() {
        // Draw many samples and check they all stay inside the unit interval.
        let mut state = 12345u64;
        for _ in 0..1000 {
            let v = lcg_next(&mut state);
            assert!((0.0..=1.0).contains(&v));
        }
    }

    #[test]
    fn compute_importance_l1_basic() {
        // L1 importance is the element-wise absolute value.
        let t = make_tensor("w", vec![-2.0, 1.0, -0.5], vec![3]);
        let scores = compute_importance(&t, ImportanceMetric::L1Magnitude);
        assert!((scores.scores[0] - 2.0).abs() < 1e-6);
        assert!((scores.scores[1] - 1.0).abs() < 1e-6);
        assert!((scores.scores[2] - 0.5).abs() < 1e-6);
    }

    #[test]
    fn unstructured_prune_zeroes_smallest() {
        // With values 1..=10 and 30% sparsity, the three smallest are zeroed
        // and the largest survives.
        let data: Vec<f32> = (1..=10).map(|x| x as f32).collect();
        let t = make_tensor("w", data, vec![10]);
        let config = PruningConfig::unstructured_l1(0.3);
        let (pruned, mask) = prune_tensor(&t, &config).expect("prune ok");
        assert_eq!(pruned.data[0], 0.0);
        assert_eq!(pruned.data[1], 0.0);
        assert_eq!(pruned.data[2], 0.0);
        assert!(pruned.data[9] != 0.0);
        // Mask entries must be strictly binary.
        assert!(mask.iter().all(|&m| m == 0.0 || m == 1.0));
    }
}