// oxibonsai_model/pruning.rs
//! Weight importance analysis and structured/unstructured pruning.
//!
//! Pruning reduces model size and inference cost by zeroing or removing
//! less-important weights. This module implements:
//! - Magnitude-based unstructured pruning
//! - Structured pruning (entire rows/columns)
//! - Importance scoring (L1, L2, gradient sensitivity approximation)
//! - Sparsity analysis and reporting

use crate::model_merge::WeightTensor;
use thiserror::Error;
12
// ──────────────────────────────────────────────────────────────────
// Errors
// ──────────────────────────────────────────────────────────────────

/// Errors that can occur during pruning operations.
#[derive(Debug, Error)]
pub enum PruningError {
    /// Requested sparsity falls outside the half-open interval `[0.0, 1.0)`
    /// (NaN is rejected as well).
    #[error("sparsity {0} must be in [0.0, 1.0)")]
    InvalidSparsity(f32),
    /// The named tensor has no elements, so there is nothing to prune.
    #[error("empty tensor: '{0}'")]
    EmptyTensor(String),
    /// Structured (row/column) pruning was requested on a tensor whose
    /// shape is not two-dimensional.
    #[error("structured pruning requires 2D tensor, got shape {0:?}")]
    NotTwoDimensional(Vec<usize>),
    /// `min_nonzero` exceeds the total number of elements available,
    /// so the safety floor cannot be honored at all.
    #[error("cannot prune below min_nonzero={0} with {1} total elements")]
    BelowMinNonzero(usize, usize),
}
29
// ──────────────────────────────────────────────────────────────────
// ImportanceMetric
// ──────────────────────────────────────────────────────────────────

/// How to compute weight importance.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ImportanceMetric {
    /// L1 norm of each weight (|w|).
    L1Magnitude,
    /// L2 norm of each weight (w^2).
    L2Magnitude,
    /// Taylor first-order approximation: |w * gradient|.
    /// Since we don't run gradients at inference time, uses |w| * |w| as a proxy.
    /// Note: element scores are therefore identical to `L2Magnitude` in this
    /// implementation.
    TaylorProxy,
    /// Random importance (for baseline/ablation).
    /// Uses a seeded LCG for reproducibility.
    Random { seed: u64 },
}
48
// ──────────────────────────────────────────────────────────────────
// PruningGranularity
// ──────────────────────────────────────────────────────────────────

/// Pruning granularity.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PruningGranularity {
    /// Zero individual weights, anywhere in the tensor.
    Unstructured,
    /// Zero entire rows (output neurons). Requires a 2D tensor.
    StructuredRow,
    /// Zero entire columns (input features). Requires a 2D tensor.
    StructuredColumn,
}
63
// ──────────────────────────────────────────────────────────────────
// PruningConfig
// ──────────────────────────────────────────────────────────────────

/// Configuration for a pruning pass.
#[derive(Debug, Clone)]
pub struct PruningConfig {
    /// Target fraction of zeros (0.0 - 1.0). Validated at prune time;
    /// values outside `[0.0, 1.0)` yield `PruningError::InvalidSparsity`.
    pub sparsity: f32,
    /// Which metric to use for computing importance scores.
    pub metric: ImportanceMetric,
    /// Whether to prune individual weights or entire rows/columns.
    pub granularity: PruningGranularity,
    /// Minimum number of non-zero elements to keep (safety floor).
    /// `PruningConfig::new` defaults this to 1.
    pub min_nonzero: usize,
}
80
81impl PruningConfig {
82    /// Create a new pruning config with the given parameters.
83    pub fn new(sparsity: f32, metric: ImportanceMetric, granularity: PruningGranularity) -> Self {
84        Self {
85            sparsity,
86            metric,
87            granularity,
88            min_nonzero: 1,
89        }
90    }
91
92    /// Convenience: unstructured L1-magnitude pruning at the given sparsity.
93    pub fn unstructured_l1(sparsity: f32) -> Self {
94        Self::new(
95            sparsity,
96            ImportanceMetric::L1Magnitude,
97            PruningGranularity::Unstructured,
98        )
99    }
100
101    /// Convenience: structured row pruning using L2 norm at the given sparsity.
102    pub fn structured_row_l2(sparsity: f32) -> Self {
103        Self::new(
104            sparsity,
105            ImportanceMetric::L2Magnitude,
106            PruningGranularity::StructuredRow,
107        )
108    }
109}
110
// ──────────────────────────────────────────────────────────────────
// ScoreStats
// ──────────────────────────────────────────────────────────────────

/// Summary statistics over importance scores.
#[derive(Debug, Clone)]
pub struct ScoreStats {
    /// Smallest score.
    pub min: f32,
    /// Largest score.
    pub max: f32,
    /// Arithmetic mean of all scores.
    pub mean: f32,
    /// Median score (average of the two middle values for even counts).
    pub median: f32,
    /// Population standard deviation (variance divides by n, not n - 1).
    pub std_dev: f32,
}
124
// ──────────────────────────────────────────────────────────────────
// ImportanceScores
// ──────────────────────────────────────────────────────────────────

/// Importance score for each element (or row/column for structured).
#[derive(Debug, Clone)]
pub struct ImportanceScores {
    /// One score per element (unstructured) or row/column (structured).
    pub scores: Vec<f32>,
    /// Score below which elements are pruned.
    /// `compute_importance` leaves this at 0.0 — no pruning decision made.
    pub threshold: f32,
    /// The metric used to generate these scores.
    pub metric: ImportanceMetric,
}
139
140impl ImportanceScores {
141    /// Fraction of scores at or below the threshold.
142    pub fn sparsity(&self) -> f32 {
143        if self.scores.is_empty() {
144            return 0.0;
145        }
146        let below = self.scores.iter().filter(|&&s| s <= self.threshold).count();
147        below as f32 / self.scores.len() as f32
148    }
149
150    /// Return the top-k scores in descending order.
151    pub fn top_k(&self, k: usize) -> Vec<f32> {
152        let mut sorted = self.scores.clone();
153        sorted.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
154        sorted.truncate(k);
155        sorted
156    }
157
158    /// Compute summary statistics over all scores.
159    pub fn stats(&self) -> ScoreStats {
160        if self.scores.is_empty() {
161            return ScoreStats {
162                min: 0.0,
163                max: 0.0,
164                mean: 0.0,
165                median: 0.0,
166                std_dev: 0.0,
167            };
168        }
169
170        let n = self.scores.len();
171        let min = self.scores.iter().cloned().fold(f32::INFINITY, f32::min);
172        let max = self
173            .scores
174            .iter()
175            .cloned()
176            .fold(f32::NEG_INFINITY, f32::max);
177        let mean = self.scores.iter().sum::<f32>() / n as f32;
178
179        let variance = self.scores.iter().map(|s| (s - mean).powi(2)).sum::<f32>() / n as f32;
180        let std_dev = variance.sqrt();
181
182        let mut sorted = self.scores.clone();
183        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
184        let median = if n % 2 == 0 {
185            (sorted[n / 2 - 1] + sorted[n / 2]) / 2.0
186        } else {
187            sorted[n / 2]
188        };
189
190        ScoreStats {
191            min,
192            max,
193            mean,
194            median,
195            std_dev,
196        }
197    }
198}
199
// ──────────────────────────────────────────────────────────────────
// SparsityReport
// ──────────────────────────────────────────────────────────────────

/// Analyze sparsity of a tensor.
#[derive(Debug, Clone)]
pub struct SparsityReport {
    /// Name of the tensor this report describes.
    pub name: String,
    /// Total number of elements in the tensor.
    pub total_params: usize,
    /// Number of elements that are exactly non-zero.
    pub nonzero_params: usize,
    /// Fraction of zeros: `1 - nonzero/total` (0.0 for an empty tensor).
    pub sparsity: f32,
    /// Tensor shape, copied from the source tensor.
    pub shape: Vec<usize>,
}
213
214impl SparsityReport {
215    /// Compute a sparsity report for the given tensor.
216    pub fn compute(tensor: &WeightTensor) -> Self {
217        let total_params = tensor.data.len();
218        let nonzero_params = tensor.data.iter().filter(|&&x| x != 0.0).count();
219        let sparsity = if total_params == 0 {
220            0.0
221        } else {
222            1.0 - nonzero_params as f32 / total_params as f32
223        };
224        Self {
225            name: tensor.name.clone(),
226            total_params,
227            nonzero_params,
228            sparsity,
229            shape: tensor.shape.clone(),
230        }
231    }
232
233    /// Fraction of zeros — same as `sparsity`.
234    pub fn zero_fraction(&self) -> f32 {
235        self.sparsity
236    }
237
238    /// Fraction of non-zero elements.
239    pub fn density(&self) -> f32 {
240        1.0 - self.sparsity
241    }
242
243    /// Human-readable one-line summary.
244    pub fn summary(&self) -> String {
245        format!(
246            "tensor='{}' shape={:?} total={} nonzero={} sparsity={:.4}",
247            self.name, self.shape, self.total_params, self.nonzero_params, self.sparsity,
248        )
249    }
250}
251
252// ──────────────────────────────────────────────────────────────────
253// ModelSparsitySummary
254// ──────────────────────────────────────────────────────────────────
255
256/// Aggregate sparsity across all layers.
257pub struct ModelSparsitySummary {
258    pub layer_reports: Vec<SparsityReport>,
259    pub total_params: usize,
260    pub total_nonzero: usize,
261    pub overall_sparsity: f32,
262}
263
264impl ModelSparsitySummary {
265    /// Build a summary from a slice of weight tensors.
266    pub fn from_model(tensors: &[WeightTensor]) -> Self {
267        let layer_reports: Vec<SparsityReport> =
268            tensors.iter().map(SparsityReport::compute).collect();
269        let total_params: usize = layer_reports.iter().map(|r| r.total_params).sum();
270        let total_nonzero: usize = layer_reports.iter().map(|r| r.nonzero_params).sum();
271        let overall_sparsity = if total_params == 0 {
272            0.0
273        } else {
274            1.0 - total_nonzero as f32 / total_params as f32
275        };
276        Self {
277            layer_reports,
278            total_params,
279            total_nonzero,
280            overall_sparsity,
281        }
282    }
283
284    /// Human-readable summary of the entire model's sparsity.
285    pub fn summary(&self) -> String {
286        format!(
287            "layers={} total_params={} total_nonzero={} overall_sparsity={:.4}",
288            self.layer_reports.len(),
289            self.total_params,
290            self.total_nonzero,
291            self.overall_sparsity,
292        )
293    }
294}
295
296// ──────────────────────────────────────────────────────────────────
297// Public API functions
298// ──────────────────────────────────────────────────────────────────
299
300/// Compute importance scores for a weight tensor.
301///
302/// For unstructured metrics (L1, L2, TaylorProxy, Random), one score per element.
303/// The threshold field is set to 0.0 (no pruning decision is made here).
304pub fn compute_importance(tensor: &WeightTensor, metric: ImportanceMetric) -> ImportanceScores {
305    let scores = match metric {
306        ImportanceMetric::L1Magnitude => tensor.data.iter().map(|x| x.abs()).collect(),
307        ImportanceMetric::L2Magnitude => tensor.data.iter().map(|x| x * x).collect(),
308        ImportanceMetric::TaylorProxy => tensor.data.iter().map(|x| x * x).collect(),
309        ImportanceMetric::Random { seed } => {
310            let mut state = seed;
311            tensor.data.iter().map(|_| lcg_next(&mut state)).collect()
312        }
313    };
314    ImportanceScores {
315        scores,
316        threshold: 0.0,
317        metric,
318    }
319}
320
321/// Prune a tensor: zero out low-importance weights.
322///
323/// Returns the pruned tensor and a mask (1.0 = kept, 0.0 = pruned).
324pub fn prune_tensor(
325    tensor: &WeightTensor,
326    config: &PruningConfig,
327) -> Result<(WeightTensor, Vec<f32>), PruningError> {
328    let mut cloned = tensor.clone();
329    let mask = prune_tensor_inplace(&mut cloned, config)?;
330    Ok((cloned, mask))
331}
332
333/// Prune a tensor in-place, returning only the mask.
334pub fn prune_tensor_inplace(
335    tensor: &mut WeightTensor,
336    config: &PruningConfig,
337) -> Result<Vec<f32>, PruningError> {
338    validate_sparsity(config.sparsity)?;
339
340    let n = tensor.data.len();
341    if n == 0 {
342        return Err(PruningError::EmptyTensor(tensor.name.clone()));
343    }
344
345    match config.granularity {
346        PruningGranularity::Unstructured => prune_unstructured(tensor, config),
347        PruningGranularity::StructuredRow => prune_structured(tensor, config, true),
348        PruningGranularity::StructuredColumn => prune_structured(tensor, config, false),
349    }
350}
351
352/// Prune a full model (all tensors) with a shared sparsity config.
353pub fn prune_model(
354    tensors: &[WeightTensor],
355    config: &PruningConfig,
356) -> Result<Vec<WeightTensor>, PruningError> {
357    tensors
358        .iter()
359        .map(|t| {
360            let (pruned, _mask) = prune_tensor(t, config)?;
361            Ok(pruned)
362        })
363        .collect()
364}
365
366/// Compute sparsity reports for all tensors in a model.
367pub fn model_sparsity_report(tensors: &[WeightTensor]) -> Vec<SparsityReport> {
368    tensors.iter().map(SparsityReport::compute).collect()
369}
370
// ──────────────────────────────────────────────────────────────────
// Internal helpers
// ──────────────────────────────────────────────────────────────────

/// Deterministic LCG producing values in `[0.0, 1.0)`.
///
/// Uses Knuth's MMIX multiplier/increment. The previous implementation
/// divided the top 32 state bits by `u32::MAX as f32 + 1.0`; both operands
/// round to 2^32 in f32 for large inputs, so it could return exactly 1.0,
/// violating the documented half-open interval. Using only the top 24 bits
/// (the width of an f32 significand) makes the conversion and division
/// exact, so the result is always strictly less than 1.0.
#[inline]
fn lcg_next(state: &mut u64) -> f32 {
    *state = state
        .wrapping_mul(6_364_136_223_846_793_005)
        .wrapping_add(1_442_695_040_888_963_407);
    let bits = (*state >> 40) as u32; // top 24 bits of the new state
    bits as f32 / (1u32 << 24) as f32
}
384
385fn validate_sparsity(sparsity: f32) -> Result<(), PruningError> {
386    if !(0.0..1.0).contains(&sparsity) {
387        return Err(PruningError::InvalidSparsity(sparsity));
388    }
389    Ok(())
390}
391
392/// Compute element-wise importance scores as a flat Vec<f32>.
393fn compute_element_scores(data: &[f32], metric: ImportanceMetric) -> Vec<f32> {
394    match metric {
395        ImportanceMetric::L1Magnitude => data.iter().map(|x| x.abs()).collect(),
396        ImportanceMetric::L2Magnitude => data.iter().map(|x| x * x).collect(),
397        ImportanceMetric::TaylorProxy => data.iter().map(|x| x * x).collect(),
398        ImportanceMetric::Random { seed } => {
399            let mut state = seed;
400            data.iter().map(|_| lcg_next(&mut state)).collect()
401        }
402    }
403}
404
405/// Unstructured pruning: zero individual elements below threshold.
406fn prune_unstructured(
407    tensor: &mut WeightTensor,
408    config: &PruningConfig,
409) -> Result<Vec<f32>, PruningError> {
410    let n = tensor.data.len();
411    let scores = compute_element_scores(&tensor.data, config.metric);
412
413    // Determine how many elements to prune
414    let num_to_prune = (config.sparsity * n as f32).floor() as usize;
415    // Ensure min_nonzero constraint
416    let max_to_prune = n.saturating_sub(config.min_nonzero);
417    if config.min_nonzero > n {
418        return Err(PruningError::BelowMinNonzero(config.min_nonzero, n));
419    }
420    let num_to_prune = num_to_prune.min(max_to_prune);
421
422    if num_to_prune == 0 {
423        // No pruning needed — full mask of ones
424        return Ok(vec![1.0f32; n]);
425    }
426
427    // Find threshold: sort scores to find the num_to_prune-th smallest
428    let mut indexed: Vec<(usize, f32)> = scores.iter().cloned().enumerate().collect();
429    indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
430
431    let threshold = indexed[num_to_prune - 1].1;
432
433    // Build mask: prune the num_to_prune lowest-scoring elements
434    let mut mask = vec![1.0f32; n];
435    let mut pruned_count = 0usize;
436    for (orig_idx, score) in &indexed {
437        if pruned_count >= num_to_prune {
438            break;
439        }
440        if *score <= threshold {
441            mask[*orig_idx] = 0.0;
442            tensor.data[*orig_idx] = 0.0;
443            pruned_count += 1;
444        }
445    }
446
447    Ok(mask)
448}
449
450/// Structured pruning: zero entire rows or columns.
451fn prune_structured(
452    tensor: &mut WeightTensor,
453    config: &PruningConfig,
454    prune_rows: bool,
455) -> Result<Vec<f32>, PruningError> {
456    if tensor.shape.len() != 2 {
457        return Err(PruningError::NotTwoDimensional(tensor.shape.clone()));
458    }
459
460    let rows = tensor.shape[0];
461    let cols = tensor.shape[1];
462    let (num_units, unit_size) = if prune_rows {
463        (rows, cols)
464    } else {
465        (cols, rows)
466    };
467
468    // Compute per-unit (row or column) importance score
469    let unit_scores: Vec<f32> = (0..num_units)
470        .map(|u| {
471            let slice: Vec<f32> = if prune_rows {
472                tensor.data[u * cols..(u + 1) * cols].to_vec()
473            } else {
474                // column: gather every cols-th element
475                (0..rows).map(|r| tensor.data[r * cols + u]).collect()
476            };
477            match config.metric {
478                ImportanceMetric::L1Magnitude => slice.iter().map(|x| x.abs()).sum::<f32>(),
479                ImportanceMetric::L2Magnitude => slice.iter().map(|x| x * x).sum::<f32>().sqrt(),
480                ImportanceMetric::TaylorProxy => slice.iter().map(|x| x * x).sum::<f32>().sqrt(),
481                ImportanceMetric::Random { seed } => {
482                    let mut state = seed.wrapping_add(u as u64);
483                    lcg_next(&mut state)
484                }
485            }
486        })
487        .collect();
488
489    let num_to_prune = (config.sparsity * num_units as f32).floor() as usize;
490    let max_to_prune = num_units.saturating_sub(config.min_nonzero.div_ceil(unit_size));
491    if config.min_nonzero > num_units * unit_size {
492        return Err(PruningError::BelowMinNonzero(
493            config.min_nonzero,
494            num_units * unit_size,
495        ));
496    }
497    let num_to_prune = num_to_prune.min(max_to_prune);
498
499    if num_to_prune == 0 {
500        return Ok(vec![1.0f32; tensor.data.len()]);
501    }
502
503    // Sort units by score ascending
504    let mut indexed: Vec<(usize, f32)> = unit_scores.iter().cloned().enumerate().collect();
505    indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
506
507    // Mark units to prune
508    let mut units_to_prune = std::collections::HashSet::new();
509    for (unit_idx, _score) in indexed.iter().take(num_to_prune) {
510        units_to_prune.insert(*unit_idx);
511    }
512
513    // Build mask and zero tensor
514    let total = tensor.data.len();
515    let mut mask = vec![1.0f32; total];
516
517    for (idx, slot) in mask.iter_mut().enumerate().take(total) {
518        let unit = if prune_rows { idx / cols } else { idx % cols };
519        if units_to_prune.contains(&unit) {
520            *slot = 0.0;
521            tensor.data[idx] = 0.0;
522        }
523    }
524
525    Ok(mask)
526}
527
// ──────────────────────────────────────────────────────────────────
// In-module smoke tests
// ──────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    // Shorthand to build a WeightTensor without repeating the constructor.
    fn make_tensor(name: &str, data: Vec<f32>, shape: Vec<usize>) -> WeightTensor {
        WeightTensor::new(name, data, shape)
    }

    #[test]
    fn lcg_values_in_unit_interval() {
        let mut state = 12345u64;
        for _ in 0..1000 {
            let v = lcg_next(&mut state);
            // NOTE(review): lcg_next documents the half-open interval
            // [0.0, 1.0); this inclusive upper bound is more lenient than
            // the contract — consider tightening to `v < 1.0`.
            assert!((0.0..=1.0).contains(&v));
        }
    }

    #[test]
    fn compute_importance_l1_basic() {
        // L1 importance is the absolute value of each weight.
        let t = make_tensor("w", vec![-2.0, 1.0, -0.5], vec![3]);
        let scores = compute_importance(&t, ImportanceMetric::L1Magnitude);
        assert!((scores.scores[0] - 2.0).abs() < 1e-6);
        assert!((scores.scores[1] - 1.0).abs() < 1e-6);
        assert!((scores.scores[2] - 0.5).abs() < 1e-6);
    }

    #[test]
    fn unstructured_prune_zeroes_smallest() {
        // Weights 1..=10 at 30% sparsity: the three smallest magnitudes
        // (1, 2, 3) must be zeroed; the largest must survive.
        let data: Vec<f32> = (1..=10).map(|x| x as f32).collect();
        let t = make_tensor("w", data, vec![10]);
        let config = PruningConfig::unstructured_l1(0.3);
        let (pruned, mask) = prune_tensor(&t, &config).expect("prune ok");
        // 3 lowest elements (1,2,3) should be zero
        assert_eq!(pruned.data[0], 0.0);
        assert_eq!(pruned.data[1], 0.0);
        assert_eq!(pruned.data[2], 0.0);
        assert!(pruned.data[9] != 0.0);
        assert!(mask.iter().all(|&m| m == 0.0 || m == 1.0));
    }
}