1use crate::error::{QuantError, QuantResult};
18use crate::pruning::mask::SparseMask;
19
/// Structural unit that a [`StructuredPruner`] removes as a whole.
///
/// Every variant describes the weight buffer as `units × unit_size` elements in
/// which unit `u` is the contiguous run `u * unit_size..(u + 1) * unit_size`;
/// the variants differ only in naming.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PruneGranularity {
    /// Channel-wise pruning: `n_out` units of `n_in` contiguous weights each.
    Channel {
        /// Number of channels (units).
        n_out: usize,
        /// Number of weights per channel.
        n_in: usize,
    },
    /// Filter-wise pruning: `n_filters` units of `filter_size` contiguous weights each.
    Filter {
        /// Number of filters (units).
        n_filters: usize,
        /// Number of weights per filter.
        filter_size: usize,
    },
    /// Head-wise pruning: `n_heads` units of `head_dim` contiguous weights each.
    Head {
        /// Number of heads (units).
        n_heads: usize,
        /// Number of weights per head.
        head_dim: usize,
    },
}
47
48#[derive(Debug, Clone)]
56pub struct StructuredPruner {
57 pub target_sparsity: f32,
59 pub granularity: PruneGranularity,
61}
62
63impl StructuredPruner {
64 #[must_use]
70 pub fn new(target_sparsity: f32, granularity: PruneGranularity) -> Self {
71 assert!(
72 (0.0..1.0).contains(&target_sparsity),
73 "target_sparsity must be in [0, 1), got {target_sparsity}"
74 );
75 Self {
76 target_sparsity,
77 granularity,
78 }
79 }
80
81 pub fn compute_mask(&self, weights: &[f32]) -> QuantResult<SparseMask> {
91 if weights.is_empty() {
92 return Err(QuantError::EmptyInput("StructuredPruner::compute_mask"));
93 }
94 match self.granularity {
95 PruneGranularity::Channel { n_out, n_in } => self.channel_mask(weights, n_out, n_in),
96 PruneGranularity::Filter {
97 n_filters,
98 filter_size,
99 } => self.channel_mask(weights, n_filters, filter_size),
100 PruneGranularity::Head { n_heads, head_dim } => {
101 self.channel_mask(weights, n_heads, head_dim)
102 }
103 }
104 }
105
106 pub fn prune(&self, weights: &mut [f32]) -> QuantResult<SparseMask> {
114 let mask = self.compute_mask(weights)?;
115 mask.apply_in_place(weights);
116 Ok(mask)
117 }
118
119 #[must_use]
123 pub fn unit_l2_norms(weights: &[f32], n_units: usize, unit_size: usize) -> Vec<f32> {
124 (0..n_units)
125 .map(|u| {
126 let base = u * unit_size;
127 weights[base..base + unit_size]
128 .iter()
129 .map(|&w| w * w)
130 .sum::<f32>()
131 .sqrt()
132 })
133 .collect()
134 }
135
136 #[must_use]
140 pub fn pruned_unit_indices(norms: &[f32], n_prune: usize) -> Vec<usize> {
141 let mut idx: Vec<usize> = (0..norms.len()).collect();
142 idx.sort_unstable_by(|&a, &b| {
143 norms[a]
144 .partial_cmp(&norms[b])
145 .unwrap_or(std::cmp::Ordering::Equal)
146 });
147 idx[..n_prune].to_vec()
148 }
149
150 fn channel_mask(
153 &self,
154 weights: &[f32],
155 n_units: usize,
156 unit_size: usize,
157 ) -> QuantResult<SparseMask> {
158 let expected = n_units * unit_size;
159 if weights.len() != expected {
160 return Err(QuantError::DimensionMismatch {
161 expected,
162 got: weights.len(),
163 });
164 }
165
166 let n_prune = ((n_units as f32) * self.target_sparsity).ceil() as usize;
167 if n_prune >= n_units {
168 return Err(QuantError::AllZeroPruning {
169 threshold: self.target_sparsity,
170 n: n_units,
171 });
172 }
173
174 let norms = Self::unit_l2_norms(weights, n_units, unit_size);
175 let pruned = Self::pruned_unit_indices(&norms, n_prune);
176
177 let mut mask = vec![true; weights.len()];
178 for u in pruned {
179 let base = u * unit_size;
180 mask[base..base + unit_size].fill(false);
181 }
182 Ok(SparseMask { mask })
183 }
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn channel_prune_50_percent() {
        let n_out = 4;
        let n_in = 4;
        // Channels 0 and 1 stay all-zero; channel 2 is 1.0 and channel 3 is 2.0.
        let mut w = vec![0.0_f32; n_out * n_in];
        w[2 * n_in..3 * n_in].fill(1.0);
        w[3 * n_in..4 * n_in].fill(2.0);

        let p = StructuredPruner::new(0.5, PruneGranularity::Channel { n_out, n_in });
        let mask = p.compute_mask(&w).unwrap();
        assert_abs_diff_eq!(mask.sparsity(), 0.5, epsilon = 0.01);
        // The two lowest-norm channels are pruned; the other two survive.
        for j in 0..n_in {
            assert!(!mask.mask[j], "ch 0 elem {j} should be pruned");
            assert!(!mask.mask[n_in + j], "ch 1 elem {j} should be pruned");
            assert!(mask.mask[2 * n_in + j], "ch 2 elem {j} should be active");
            assert!(mask.mask[3 * n_in + j], "ch 3 elem {j} should be active");
        }
    }

    #[test]
    fn prune_in_place_zeroes_low_norm_units() {
        let n_out = 3;
        let n_in = 2;
        // ceil(3 * 0.33) = 1 channel is pruned: channel 0 has the smallest norm.
        let p = StructuredPruner::new(0.33, PruneGranularity::Channel { n_out, n_in });
        let mut w = vec![
            0.01_f32, 0.01, // channel 0
            1.0, 1.0, // channel 1
            0.5, 0.5, // channel 2
        ];
        let _mask = p.prune(&mut w).unwrap();
        assert_abs_diff_eq!(w[0], 0.0, epsilon = 1e-7);
        assert_abs_diff_eq!(w[1], 0.0, epsilon = 1e-7);
        // Surviving channels keep their original values.
        for (got, want) in w[2..].iter().zip([1.0_f32, 1.0, 0.5, 0.5]) {
            assert_abs_diff_eq!(*got, want, epsilon = 1e-7);
        }
    }

    #[test]
    fn unit_l2_norms_correct() {
        // Unit 0 = [3, 4] -> 5; unit 1 = [0, 1] -> 1.
        let w = vec![3.0_f32, 4.0, 0.0, 1.0];
        let norms = StructuredPruner::unit_l2_norms(&w, 2, 2);
        assert_abs_diff_eq!(norms[0], 5.0, epsilon = 1e-5);
        assert_abs_diff_eq!(norms[1], 1.0, epsilon = 1e-5);
    }

    #[test]
    fn pruned_unit_indices_orders_by_norm() {
        let idx = StructuredPruner::pruned_unit_indices(&[3.0, 1.0, 2.0, 4.0], 2);
        assert_eq!(idx, vec![1, 2]);
    }

    #[test]
    fn filter_pruning_behaves_as_channel() {
        let n_filters = 4;
        let filter_size = 9;
        // Only filter 2 carries signal, so it must survive 25% pruning.
        let mut w = vec![0.0_f32; n_filters * filter_size];
        w[2 * filter_size..3 * filter_size].fill(1.0);

        let filter = StructuredPruner::new(
            0.25,
            PruneGranularity::Filter {
                n_filters,
                filter_size,
            },
        );
        let mask = filter.compute_mask(&w).unwrap();
        assert_eq!(mask.len(), n_filters * filter_size);
        assert_abs_diff_eq!(mask.sparsity(), 0.25, epsilon = 0.01);
        assert!(
            mask.mask[2 * filter_size..3 * filter_size]
                .iter()
                .all(|&keep| keep),
            "filter 2 should be active"
        );

        // The same layout expressed as channels must produce the identical mask.
        let channel = StructuredPruner::new(
            0.25,
            PruneGranularity::Channel {
                n_out: n_filters,
                n_in: filter_size,
            },
        );
        assert_eq!(channel.compute_mask(&w).unwrap().mask, mask.mask);
    }

    #[test]
    fn zero_sparsity_keeps_everything() {
        let p = StructuredPruner::new(
            0.0,
            PruneGranularity::Head {
                n_heads: 2,
                head_dim: 3,
            },
        );
        let mask = p.compute_mask(&[1.0_f32; 6]).unwrap();
        assert!(mask.mask.iter().all(|&keep| keep));
    }

    #[test]
    fn dimension_mismatch_error() {
        let p = StructuredPruner::new(0.5, PruneGranularity::Channel { n_out: 4, n_in: 4 });
        // 12 elements cannot be 4 channels of 4.
        let w = vec![0.5_f32; 12];
        assert!(matches!(
            p.compute_mask(&w),
            Err(QuantError::DimensionMismatch { .. })
        ));
    }

    #[test]
    fn empty_input_error() {
        let p = StructuredPruner::new(0.5, PruneGranularity::Channel { n_out: 4, n_in: 4 });
        assert!(matches!(
            p.compute_mask(&[]),
            Err(QuantError::EmptyInput(_))
        ));
    }

    #[test]
    fn all_zero_pruning_error() {
        // Built directly to bypass `new`'s range check: ceil(2 * 0.99) = 2 = every unit.
        let p = StructuredPruner {
            target_sparsity: 0.99,
            granularity: PruneGranularity::Channel { n_out: 2, n_in: 4 },
        };
        let w = vec![1.0_f32; 8];
        assert!(matches!(
            p.compute_mask(&w),
            Err(QuantError::AllZeroPruning { .. })
        ));
    }
}