1use super::error::PruningError;
12use crate::autograd::Tensor;
13
/// Structural constraint that a pruning mask is checked against.
///
/// Mask elements are treated as non-zero when they are greater than 0.5 and two
/// elements are treated as equal when they differ by at most `1e-6`; both
/// conventions come from the validators that back [`SparsityPattern::validate`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SparsityPattern {
    /// No structural constraint: validation accepts any mask.
    Unstructured,

    /// N:M pattern: every consecutive group of `m` elements (in flat storage
    /// order) must contain exactly `n` non-zero elements.
    NM {
        /// Number of non-zero elements required in each group.
        n: usize,
        /// Group size; the mask length must be divisible by it.
        m: usize,
    },

    /// Block pattern: a 2D mask is tiled into `height` x `width` blocks and every
    /// block must hold a single repeated value.
    Block {
        /// Number of rows in one block; must divide the mask's row count.
        height: usize,
        /// Number of columns in one block; must divide the mask's column count.
        width: usize,
    },

    /// Row pattern: every row of a 2D mask must hold a single repeated value.
    Row,

    /// Column pattern: every column of a 2D mask must hold a single repeated value.
    Column,
}
58
59impl SparsityPattern {
60 #[must_use]
65 pub fn is_valid(&self) -> bool {
66 match self {
67 SparsityPattern::NM { n, m } => *n <= *m && *m > 0,
68 SparsityPattern::Block { height, width } => *height > 0 && *width > 0,
69 _ => true,
70 }
71 }
72
73 #[must_use]
78 pub fn theoretical_sparsity(&self) -> Option<f32> {
79 match self {
80 SparsityPattern::NM { n, m } => Some(1.0 - (*n as f32 / *m as f32)),
81 _ => None, }
83 }
84
85 pub fn validate(&self, mask: &Tensor) -> Result<(), PruningError> {
93 match self {
94 SparsityPattern::Unstructured => Ok(()),
95 SparsityPattern::NM { n, m } => validate_nm(mask, *n, *m),
96 SparsityPattern::Block { height, width } => validate_block(mask, *height, *width),
97 SparsityPattern::Row => validate_row(mask),
98 SparsityPattern::Column => validate_column(mask),
99 }
100 }
101}
102
103impl Default for SparsityPattern {
104 fn default() -> Self {
105 SparsityPattern::Unstructured
106 }
107}
108
/// A mask tensor bundled with the pattern it was built for and its sparsity.
#[derive(Debug, Clone)]
pub struct SparsityMask {
    /// The mask values; weights are multiplied element-wise by these in `apply`.
    mask: Tensor,
    /// The pattern recorded for this mask at construction time.
    pattern: SparsityPattern,
    /// Fraction of mask elements counted as zero, stored at construction time.
    sparsity: f32,
}
128
129impl SparsityMask {
130 pub fn new(mask: Tensor, pattern: SparsityPattern) -> Result<Self, PruningError> {
141 for &v in mask.data() {
143 if (v - 0.0).abs() > 1e-6 && (v - 1.0).abs() > 1e-6 {
144 return Err(PruningError::InvalidMask {
145 reason: format!("Mask contains non-binary value: {v}"),
146 });
147 }
148 }
149
150 pattern.validate(&mask)?;
152
153 let data = mask.data();
155 let sparsity = if data.is_empty() {
156 0.0
157 } else {
158 let zeros = data.iter().filter(|&&v| v < 0.5).count();
159 zeros as f32 / data.len() as f32
160 };
161
162 Ok(Self {
163 mask,
164 pattern,
165 sparsity,
166 })
167 }
168
169 #[must_use]
174 pub fn dense(shape: &[usize]) -> Self {
175 let mask = Tensor::ones(shape);
176 Self {
177 mask,
178 pattern: SparsityPattern::Unstructured,
179 sparsity: 0.0,
180 }
181 }
182
183 #[must_use]
185 pub fn sparsity(&self) -> f32 {
186 self.sparsity
187 }
188
189 #[must_use]
191 pub fn pattern(&self) -> SparsityPattern {
192 self.pattern
193 }
194
195 #[must_use]
197 pub fn tensor(&self) -> &Tensor {
198 &self.mask
199 }
200
201 #[must_use]
203 pub fn shape(&self) -> &[usize] {
204 self.mask.shape()
205 }
206
207 pub fn apply(&self, weights: &mut Tensor) -> Result<(), PruningError> {
219 if weights.shape() != self.mask.shape() {
220 return Err(PruningError::ShapeMismatch {
221 expected: self.mask.shape().to_vec(),
222 got: weights.shape().to_vec(),
223 });
224 }
225
226 let mask_data = self.mask.data();
228 let weight_data = weights.data_mut();
229 for (w, &m) in weight_data.iter_mut().zip(mask_data.iter()) {
230 *w *= m;
231 }
232
233 Ok(())
234 }
235
236 #[must_use]
238 pub fn nnz(&self) -> usize {
239 self.mask.data().iter().filter(|&&v| v > 0.5).count()
240 }
241
242 #[must_use]
244 pub fn num_zeros(&self) -> usize {
245 self.mask.data().iter().filter(|&&v| v < 0.5).count()
246 }
247}
248
249pub fn generate_unstructured_mask(
258 scores: &Tensor,
259 target_sparsity: f32,
260) -> Result<SparsityMask, PruningError> {
261 if !(0.0..=1.0).contains(&target_sparsity) {
262 return Err(PruningError::InvalidSparsity {
263 value: target_sparsity,
264 constraint: "must be between 0.0 and 1.0".to_string(),
265 });
266 }
267
268 let data = scores.data();
269 if data.is_empty() {
270 return SparsityMask::new(Tensor::new(&[], &[0]), SparsityPattern::Unstructured);
271 }
272
273 let num_prune = (data.len() as f32 * target_sparsity) as usize;
275
276 let mut sorted: Vec<f32> = data.to_vec();
278 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
279
280 let threshold = if num_prune == 0 {
281 f32::NEG_INFINITY
282 } else if num_prune >= sorted.len() {
283 f32::INFINITY
284 } else {
285 sorted[num_prune - 1]
286 };
287
288 let mask_data: Vec<f32> = data
290 .iter()
291 .map(|&v| if v > threshold { 1.0 } else { 0.0 })
292 .collect();
293
294 SparsityMask::new(
295 Tensor::new(&mask_data, scores.shape()),
296 SparsityPattern::Unstructured,
297 )
298}
299
300fn validate_nm(mask: &Tensor, n: usize, m: usize) -> Result<(), PruningError> {
302 let data = mask.data();
303 if data.len() % m != 0 {
304 return Err(PruningError::InvalidPattern {
305 message: format!("Tensor length {} not divisible by M={}", data.len(), m),
306 });
307 }
308 for (i, chunk) in data.chunks(m).enumerate() {
309 let nnz = chunk.iter().filter(|&&v| v > 0.5).count();
310 if nnz != n {
311 return Err(PruningError::InvalidPattern {
312 message: format!(
313 "Group {} has {} non-zeros, expected {} (N:M = {}:{})",
314 i, nnz, n, n, m
315 ),
316 });
317 }
318 }
319 Ok(())
320}
321
322fn require_2d(mask: &Tensor, pattern_name: &str) -> Result<(usize, usize), PruningError> {
324 let shape = mask.shape();
325 if shape.len() != 2 {
326 return Err(PruningError::InvalidPattern {
327 message: format!(
328 "{pattern_name} pattern requires 2D tensor, got {}D",
329 shape.len()
330 ),
331 });
332 }
333 Ok((shape[0], shape[1]))
334}
335
336fn check_block_uniform(
338 data: &[f32],
339 br: usize,
340 bc: usize,
341 height: usize,
342 width: usize,
343 cols: usize,
344) -> Result<(), PruningError> {
345 let first = data[br * height * cols + bc * width];
346 for r in 0..height {
347 for c in 0..width {
348 let val = data[(br * height + r) * cols + bc * width + c];
349 if (val - first).abs() > 1e-6 {
350 return Err(PruningError::InvalidPattern {
351 message: format!("Block ({br}, {bc}) is not uniform: found {val} and {first}"),
352 });
353 }
354 }
355 }
356 Ok(())
357}
358
359fn validate_block(mask: &Tensor, height: usize, width: usize) -> Result<(), PruningError> {
361 let (rows, cols) = require_2d(mask, "Block")?;
362 if rows % height != 0 || cols % width != 0 {
363 return Err(PruningError::InvalidPattern {
364 message: format!("Shape [{rows}, {cols}] not divisible by block [{height}, {width}]"),
365 });
366 }
367 let data = mask.data();
368 for br in 0..(rows / height) {
369 for bc in 0..(cols / width) {
370 check_block_uniform(data, br, bc, height, width, cols)?;
371 }
372 }
373 Ok(())
374}
375
376fn validate_row(mask: &Tensor) -> Result<(), PruningError> {
378 let (rows, cols) = require_2d(mask, "Row")?;
379 let data = mask.data();
380 for r in 0..rows {
381 let first = data[r * cols];
382 for c in 1..cols {
383 if (data[r * cols + c] - first).abs() > 1e-6 {
384 return Err(PruningError::InvalidPattern {
385 message: format!("Row {r} is not uniform"),
386 });
387 }
388 }
389 }
390 Ok(())
391}
392
393fn validate_column(mask: &Tensor) -> Result<(), PruningError> {
395 let (rows, cols) = require_2d(mask, "Column")?;
396 let data = mask.data();
397 for c in 0..cols {
398 let first = data[c];
399 for r in 1..rows {
400 if (data[r * cols + c] - first).abs() > 1e-6 {
401 return Err(PruningError::InvalidPattern {
402 message: format!("Column {c} is not uniform"),
403 });
404 }
405 }
406 }
407 Ok(())
408}
409
410include!("mask.rs");