Skip to main content

oxicuda_quant/pruning/
structured.rs

1//! # Structured Pruning
2//!
3//! Removes entire structural units — channels, filters, or attention heads —
4//! rather than individual weights.  Structured pruning produces weight matrices
5//! with rows or columns of zeros that can be physically removed, yielding
6//! real hardware speedups (unlike unstructured sparsity which requires
7//! special sparse kernels).
8//!
9//! ## Granularities
10//!
11//! | Granularity | Unit removed        | Layout assumption                   |
12//! |-------------|---------------------|-------------------------------------|
13//! | `Channel`   | Output channel      | `[n_out, n_in]` row-major           |
14//! | `Filter`    | Convolutional filter| `[n_filters, filter_size]` flat     |
15//! | `Head`      | Attention head      | `[n_heads × head_dim, ...]`         |
16
17use crate::error::{QuantError, QuantResult};
18use crate::pruning::mask::SparseMask;
19
20// ─── Granularity ─────────────────────────────────────────────────────────────
21
/// Structural unit to remove during pruning.
///
/// Every variant describes a weight tensor as a sequence of contiguous,
/// equally sized units in row-major order; pruning zeroes whole units at a
/// time.  The type is `Copy + Eq + Hash`, so it can be used as a map key
/// (e.g. to cache masks per layer configuration).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PruneGranularity {
    /// Prune entire output channels (rows) of a weight matrix `[n_out, n_in]`.
    Channel {
        /// Number of output channels (rows of the weight matrix).
        n_out: usize,
        /// Number of input channels (columns of the weight matrix).
        n_in: usize,
    },
    /// Prune entire convolutional filters, each of length `filter_size`.
    Filter {
        /// Number of filters (= output channels for a conv layer).
        n_filters: usize,
        /// Flattened size of each filter (in_channels × kH × kW).
        filter_size: usize,
    },
    /// Prune entire attention heads in a projection matrix.
    ///
    /// Head `h` owns the `head_dim` consecutive rows starting at row
    /// `h * head_dim`.
    Head {
        /// Number of attention heads.
        n_heads: usize,
        /// Dimension per head.
        head_dim: usize,
    },
}
47
48// ─── StructuredPruner ────────────────────────────────────────────────────────
49
/// Removes structural units based on L2 norm importance.
///
/// Each unit (channel / filter / head) receives a scalar importance score
/// equal to the L2 norm of its weights.  The `ceil(n_units * target_sparsity)`
/// units with the smallest scores are pruned, i.e. the bottom
/// `target_sparsity` fraction of units, rounded up to a whole unit.
#[derive(Debug, Clone)]
pub struct StructuredPruner {
    /// Fraction of structural units to prune ∈ [0, 1).
    ///
    /// [`StructuredPruner::new`] enforces this range.  The field is public, so
    /// a directly constructed pruner can hold any value; `compute_mask` then
    /// reports [`QuantError::AllZeroPruning`] whenever the rounded-up count
    /// would remove every unit.
    pub target_sparsity: f32,
    /// Structural granularity (which kind of unit is removed and the weight layout).
    pub granularity: PruneGranularity,
}
62
63impl StructuredPruner {
64    /// Create a new structured pruner.
65    ///
66    /// # Panics
67    ///
68    /// Panics if `target_sparsity` is not in `[0, 1)`.
69    #[must_use]
70    pub fn new(target_sparsity: f32, granularity: PruneGranularity) -> Self {
71        assert!(
72            (0.0..1.0).contains(&target_sparsity),
73            "target_sparsity must be in [0, 1), got {target_sparsity}"
74        );
75        Self {
76            target_sparsity,
77            granularity,
78        }
79    }
80
81    /// Compute the pruning mask for `weights`.
82    ///
83    /// Returns a flat mask the same length as `weights`.
84    ///
85    /// # Errors
86    ///
87    /// * [`QuantError::EmptyInput`]        — `weights` is empty.
88    /// * [`QuantError::DimensionMismatch`] — `weights.len()` is inconsistent with the granularity.
89    /// * [`QuantError::AllZeroPruning`]    — all units would be pruned.
90    pub fn compute_mask(&self, weights: &[f32]) -> QuantResult<SparseMask> {
91        if weights.is_empty() {
92            return Err(QuantError::EmptyInput("StructuredPruner::compute_mask"));
93        }
94        match self.granularity {
95            PruneGranularity::Channel { n_out, n_in } => self.channel_mask(weights, n_out, n_in),
96            PruneGranularity::Filter {
97                n_filters,
98                filter_size,
99            } => self.channel_mask(weights, n_filters, filter_size),
100            PruneGranularity::Head { n_heads, head_dim } => {
101                self.channel_mask(weights, n_heads, head_dim)
102            }
103        }
104    }
105
106    /// Apply structured pruning in-place.
107    ///
108    /// Returns the mask applied.
109    ///
110    /// # Errors
111    ///
112    /// Propagates errors from [`compute_mask`](Self::compute_mask).
113    pub fn prune(&self, weights: &mut [f32]) -> QuantResult<SparseMask> {
114        let mask = self.compute_mask(weights)?;
115        mask.apply_in_place(weights);
116        Ok(mask)
117    }
118
119    /// Compute per-unit L2 norms for a matrix `[n_units, unit_size]`.
120    ///
121    /// Returns a `Vec<f32>` of length `n_units`.
122    #[must_use]
123    pub fn unit_l2_norms(weights: &[f32], n_units: usize, unit_size: usize) -> Vec<f32> {
124        (0..n_units)
125            .map(|u| {
126                let base = u * unit_size;
127                weights[base..base + unit_size]
128                    .iter()
129                    .map(|&w| w * w)
130                    .sum::<f32>()
131                    .sqrt()
132            })
133            .collect()
134    }
135
136    /// Return the indices of units to prune, sorted ascending by L2 norm.
137    ///
138    /// Prunes the `n_prune` units with the smallest L2 norm.
139    #[must_use]
140    pub fn pruned_unit_indices(norms: &[f32], n_prune: usize) -> Vec<usize> {
141        let mut idx: Vec<usize> = (0..norms.len()).collect();
142        idx.sort_unstable_by(|&a, &b| {
143            norms[a]
144                .partial_cmp(&norms[b])
145                .unwrap_or(std::cmp::Ordering::Equal)
146        });
147        idx[..n_prune].to_vec()
148    }
149
150    // ── Private ───────────────────────────────────────────────────────────────
151
152    fn channel_mask(
153        &self,
154        weights: &[f32],
155        n_units: usize,
156        unit_size: usize,
157    ) -> QuantResult<SparseMask> {
158        let expected = n_units * unit_size;
159        if weights.len() != expected {
160            return Err(QuantError::DimensionMismatch {
161                expected,
162                got: weights.len(),
163            });
164        }
165
166        let n_prune = ((n_units as f32) * self.target_sparsity).ceil() as usize;
167        if n_prune >= n_units {
168            return Err(QuantError::AllZeroPruning {
169                threshold: self.target_sparsity,
170                n: n_units,
171            });
172        }
173
174        let norms = Self::unit_l2_norms(weights, n_units, unit_size);
175        let pruned = Self::pruned_unit_indices(&norms, n_prune);
176
177        let mut mask = vec![true; weights.len()];
178        for u in pruned {
179            let base = u * unit_size;
180            mask[base..base + unit_size].fill(false);
181        }
182        Ok(SparseMask { mask })
183    }
184}
185
186// ─── Tests ───────────────────────────────────────────────────────────────────
187
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn channel_prune_50_percent() {
        // 4 output channels × 4 input channels = 16 weights.
        // Channels 0, 1 have small norm; channels 2, 3 have large norm.
        let n_out = 4;
        let n_in = 4;
        let mut w = vec![0.0_f32; n_out * n_in];
        // channel 2: all 1.0 → norm = 2.0
        for j in 0..n_in {
            w[2 * n_in + j] = 1.0;
        }
        // channel 3: all 2.0 → norm = 4.0
        for j in 0..n_in {
            w[3 * n_in + j] = 2.0;
        }

        let p = StructuredPruner::new(0.5, PruneGranularity::Channel { n_out, n_in });
        let mask = p.compute_mask(&w).unwrap();
        // Sparsity = 2 channels / 4 channels = 0.5 in channel count
        // = 8 weights / 16 weights = 0.5 sparsity
        assert_abs_diff_eq!(mask.sparsity(), 0.5, epsilon = 0.01);
        // Channels 0, 1 (near-zero norm) should be pruned.
        for j in 0..n_in {
            assert!(!mask.mask[j], "ch 0 elem {j} should be pruned");
            assert!(!mask.mask[n_in + j], "ch 1 elem {j} should be pruned");
            assert!(mask.mask[2 * n_in + j], "ch 2 elem {j} should be active");
            assert!(mask.mask[3 * n_in + j], "ch 3 elem {j} should be active");
        }
    }

    #[test]
    fn prune_in_place_zeroes_low_norm_units() {
        let n_out = 3;
        let n_in = 2;
        // ceil(3 × 0.33) = 1 channel pruned.
        let p = StructuredPruner::new(0.33, PruneGranularity::Channel { n_out, n_in });
        let mut w = vec![
            0.01_f32, 0.01, // channel 0: near zero
            1.0, 1.0, // channel 1: large norm
            0.5, 0.5, // channel 2: medium norm
        ];
        let mask = p.prune(&mut w).unwrap();
        // Channel 0 should be zeroed.
        assert_abs_diff_eq!(w[0], 0.0, epsilon = 1e-7);
        assert_abs_diff_eq!(w[1], 0.0, epsilon = 1e-7);
        // Surviving channels must be left bit-for-bit untouched.
        assert_eq!(&w[2..], &[1.0, 1.0, 0.5, 0.5]);
        assert_eq!(mask.mask, vec![false, false, true, true, true, true]);
    }

    #[test]
    fn unit_l2_norms_correct() {
        let w = vec![3.0_f32, 4.0, 0.0, 1.0]; // [2 units, 2 each]
        let norms = StructuredPruner::unit_l2_norms(&w, 2, 2);
        assert_abs_diff_eq!(norms[0], 5.0, epsilon = 1e-5); // sqrt(9+16)=5
        assert_abs_diff_eq!(norms[1], 1.0, epsilon = 1e-5); // sqrt(0+1)=1
    }

    #[test]
    fn pruned_unit_indices_ascending_by_norm() {
        let idx = StructuredPruner::pruned_unit_indices(&[3.0, 1.0, 2.0, 0.5], 3);
        assert_eq!(idx, vec![3, 1, 2]);
    }

    #[test]
    fn filter_pruning_behaves_as_channel() {
        let n_filters = 4;
        let filter_size = 9;
        let p = StructuredPruner::new(
            0.25,
            PruneGranularity::Filter {
                n_filters,
                filter_size,
            },
        );
        let mut w = vec![0.0_f32; n_filters * filter_size];
        // Give filter 2 large values.
        for k in 0..filter_size {
            w[2 * filter_size + k] = 1.0;
        }
        let mask = p.compute_mask(&w).unwrap();
        assert_eq!(mask.len(), n_filters * filter_size);
        // ceil(4 × 0.25) = 1 filter = 9 of 36 weights pruned.
        assert_abs_diff_eq!(mask.sparsity(), 0.25, epsilon = 1e-6);
        // The only non-zero filter must survive.
        for k in 0..filter_size {
            assert!(
                mask.mask[2 * filter_size + k],
                "filter 2 elem {k} should be active"
            );
        }
        // Whole filters are removed: each filter is either fully kept or fully pruned.
        for f in 0..n_filters {
            let unit = &mask.mask[f * filter_size..(f + 1) * filter_size];
            assert!(
                unit.iter().all(|&m| m == unit[0]),
                "filter {f} is only partially pruned"
            );
        }
    }

    #[test]
    fn head_pruning_on_flat_head_vector() {
        // 4 heads × head_dim 2; heads 1 (norm 0) and 3 (norm 0.1) are the weakest.
        let p = StructuredPruner::new(
            0.5,
            PruneGranularity::Head {
                n_heads: 4,
                head_dim: 2,
            },
        );
        let w = vec![1.0_f32, 1.0, 0.0, 0.0, 3.0, 4.0, 0.1, 0.0];
        let mask = p.compute_mask(&w).unwrap();
        assert_eq!(
            mask.mask,
            vec![true, true, false, false, true, true, false, false]
        );
    }

    #[test]
    fn zero_sparsity_keeps_everything() {
        let p = StructuredPruner::new(0.0, PruneGranularity::Channel { n_out: 2, n_in: 2 });
        let mut w = vec![1.0_f32, 2.0, 3.0, 4.0];
        let mask = p.prune(&mut w).unwrap();
        assert!(mask.mask.iter().all(|&m| m));
        assert_eq!(w, vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    #[should_panic(expected = "target_sparsity must be in [0, 1)")]
    fn new_rejects_full_sparsity() {
        let _ = StructuredPruner::new(1.0, PruneGranularity::Channel { n_out: 2, n_in: 2 });
    }

    #[test]
    fn dimension_mismatch_error() {
        let p = StructuredPruner::new(0.5, PruneGranularity::Channel { n_out: 4, n_in: 4 });
        let w = vec![0.5_f32; 12]; // 3×4, not 4×4
        assert!(matches!(
            p.compute_mask(&w),
            Err(QuantError::DimensionMismatch { .. })
        ));
    }

    #[test]
    fn empty_input_error() {
        let p = StructuredPruner::new(0.5, PruneGranularity::Channel { n_out: 4, n_in: 4 });
        assert!(matches!(
            p.compute_mask(&[]),
            Err(QuantError::EmptyInput(_))
        ));
    }

    #[test]
    fn all_zero_pruning_error() {
        // Built directly (bypassing `new`) so the out-of-range sparsity reaches
        // `compute_mask`: ceil(2 × 0.99) = 2 = every channel.
        let p = StructuredPruner {
            target_sparsity: 0.99,
            granularity: PruneGranularity::Channel { n_out: 2, n_in: 4 },
        };
        let w = vec![1.0_f32; 8];
        assert!(matches!(
            p.compute_mask(&w),
            Err(QuantError::AllZeroPruning { .. })
        ));
    }
}