// aprender/pruning/mask/mod.rs
//! Sparsity mask representation and pattern validation.
//!
//! # Toyota Way: Poka-Yoke
//! Masks validate shape compatibility and binary values at construction,
//! preventing invalid operations downstream.
//!
//! # References
//! - Zhou, A., et al. (2021). Learning N:M fine-grained structured sparse networks. ICLR.
//! - Mishra, A., et al. (2021). Accelerating sparse deep neural networks.

use super::error::PruningError;
use crate::autograd::Tensor;

/// Structural constraint on which weights may be pruned.
///
/// Each pattern trades flexibility against hardware acceleration:
/// unstructured masks prune anywhere but need sparse kernels for speedup,
/// while structured patterns map directly onto dense-hardware shortcuts.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SparsityPattern {
    /// No structural constraint - any element can be pruned.
    ///
    /// Maximum flexibility but requires sparse hardware for speedup.
    Unstructured,

    /// N:M sparsity - in every M consecutive elements, exactly N are non-zero.
    ///
    /// # Hardware Support
    /// - 2:4 sparsity: NVIDIA Ampere (A100, RTX 30xx) - 2x speedup
    /// - 4:8 sparsity: Future hardware
    NM {
        /// Number of non-zero elements per group
        n: usize,
        /// Group size
        m: usize,
    },

    /// Block sparsity - entire blocks of size (height, width) are pruned together.
    ///
    /// Useful for structured pruning where entire neurons or filters are removed.
    Block {
        /// Block height
        height: usize,
        /// Block width
        width: usize,
    },

    /// Row sparsity - entire rows (output channels) pruned.
    ///
    /// Equivalent to pruning entire output neurons.
    Row,

    /// Column sparsity - entire columns (input channels) pruned.
    ///
    /// Equivalent to removing input features.
    Column,
}

59impl SparsityPattern {
60    /// Check if this pattern configuration is valid.
61    ///
62    /// # Returns
63    /// `true` if the pattern parameters are valid.
64    #[must_use]
65    pub fn is_valid(&self) -> bool {
66        match self {
67            SparsityPattern::NM { n, m } => *n <= *m && *m > 0,
68            SparsityPattern::Block { height, width } => *height > 0 && *width > 0,
69            _ => true,
70        }
71    }
72
73    /// Get the theoretical sparsity for this pattern.
74    ///
75    /// # Returns
76    /// Sparsity ratio (0.0 = dense, 1.0 = fully sparse)
77    #[must_use]
78    pub fn theoretical_sparsity(&self) -> Option<f32> {
79        match self {
80            SparsityPattern::NM { n, m } => Some(1.0 - (*n as f32 / *m as f32)),
81            _ => None, // Variable sparsity for other patterns
82        }
83    }
84
85    /// Validate a mask tensor against this pattern's constraints.
86    ///
87    /// # Arguments
88    /// * `mask` - Binary mask tensor to validate
89    ///
90    /// # Returns
91    /// `Ok(())` if valid, `Err(PruningError::InvalidPattern)` if not.
92    pub fn validate(&self, mask: &Tensor) -> Result<(), PruningError> {
93        match self {
94            SparsityPattern::Unstructured => Ok(()),
95            SparsityPattern::NM { n, m } => validate_nm(mask, *n, *m),
96            SparsityPattern::Block { height, width } => validate_block(mask, *height, *width),
97            SparsityPattern::Row => validate_row(mask),
98            SparsityPattern::Column => validate_column(mask),
99        }
100    }
101}
102
103impl Default for SparsityPattern {
104    fn default() -> Self {
105        SparsityPattern::Unstructured
106    }
107}
108
109/// Sparsity mask with validation.
110///
111/// # Toyota Way: Poka-Yoke
112/// The mask validates binary values and pattern constraints at construction,
113/// preventing invalid masks from being created.
114///
115/// # Invariants
116/// - All values are exactly 0.0 or 1.0
117/// - Pattern constraints are satisfied
118/// - Sparsity is precomputed and cached
119#[derive(Debug, Clone)]
120pub struct SparsityMask {
121    /// Binary mask tensor (1 = keep, 0 = prune)
122    mask: Tensor,
123    /// Pattern used to generate this mask
124    pattern: SparsityPattern,
125    /// Cached sparsity ratio
126    sparsity: f32,
127}
128
129impl SparsityMask {
130    /// Create a new mask with validation.
131    ///
132    /// # Arguments
133    /// * `mask` - Binary tensor with values in {0.0, 1.0}
134    /// * `pattern` - Sparsity pattern constraint
135    ///
136    /// # Returns
137    /// * `Ok(SparsityMask)` - Valid mask
138    /// * `Err(PruningError::InvalidMask)` - If values are not binary
139    /// * `Err(PruningError::InvalidPattern)` - If pattern constraints violated
140    pub fn new(mask: Tensor, pattern: SparsityPattern) -> Result<Self, PruningError> {
141        // Validate binary values
142        for &v in mask.data() {
143            if (v - 0.0).abs() > 1e-6 && (v - 1.0).abs() > 1e-6 {
144                return Err(PruningError::InvalidMask {
145                    reason: format!("Mask contains non-binary value: {v}"),
146                });
147            }
148        }
149
150        // Validate pattern constraints
151        pattern.validate(&mask)?;
152
153        // Compute sparsity (fraction of zeros)
154        let data = mask.data();
155        let sparsity = if data.is_empty() {
156            0.0
157        } else {
158            let zeros = data.iter().filter(|&&v| v < 0.5).count();
159            zeros as f32 / data.len() as f32
160        };
161
162        Ok(Self {
163            mask,
164            pattern,
165            sparsity,
166        })
167    }
168
169    /// Create an all-ones (dense) mask.
170    ///
171    /// # Arguments
172    /// * `shape` - Shape of the mask
173    #[must_use]
174    pub fn dense(shape: &[usize]) -> Self {
175        let mask = Tensor::ones(shape);
176        Self {
177            mask,
178            pattern: SparsityPattern::Unstructured,
179            sparsity: 0.0,
180        }
181    }
182
183    /// Get the sparsity ratio (0.0 = dense, 1.0 = all zeros).
184    #[must_use]
185    pub fn sparsity(&self) -> f32 {
186        self.sparsity
187    }
188
189    /// Get the pattern used for this mask.
190    #[must_use]
191    pub fn pattern(&self) -> SparsityPattern {
192        self.pattern
193    }
194
195    /// Get the underlying mask tensor.
196    #[must_use]
197    pub fn tensor(&self) -> &Tensor {
198        &self.mask
199    }
200
201    /// Get the shape of the mask.
202    #[must_use]
203    pub fn shape(&self) -> &[usize] {
204        self.mask.shape()
205    }
206
207    /// Apply mask to weights in-place.
208    ///
209    /// # Arguments
210    /// * `weights` - Tensor to apply mask to (modified in-place)
211    ///
212    /// # Returns
213    /// * `Ok(())` - Mask applied successfully
214    /// * `Err(PruningError::ShapeMismatch)` - If shapes don't match
215    ///
216    /// # Toyota Way: Poka-Yoke
217    /// Shape validation prevents applying mask to wrong tensor.
218    pub fn apply(&self, weights: &mut Tensor) -> Result<(), PruningError> {
219        if weights.shape() != self.mask.shape() {
220            return Err(PruningError::ShapeMismatch {
221                expected: self.mask.shape().to_vec(),
222                got: weights.shape().to_vec(),
223            });
224        }
225
226        // Element-wise multiplication
227        let mask_data = self.mask.data();
228        let weight_data = weights.data_mut();
229        for (w, &m) in weight_data.iter_mut().zip(mask_data.iter()) {
230            *w *= m;
231        }
232
233        Ok(())
234    }
235
236    /// Count the number of non-zero elements.
237    #[must_use]
238    pub fn nnz(&self) -> usize {
239        self.mask.data().iter().filter(|&&v| v > 0.5).count()
240    }
241
242    /// Count the number of zero elements.
243    #[must_use]
244    pub fn num_zeros(&self) -> usize {
245        self.mask.data().iter().filter(|&&v| v < 0.5).count()
246    }
247}
248
249/// Generate an unstructured sparsity mask based on importance scores.
250///
251/// # Arguments
252/// * `scores` - Importance scores tensor
253/// * `target_sparsity` - Fraction of weights to prune (0.0 to 1.0)
254///
255/// # Returns
256/// Mask where lowest-importance weights are set to 0.
257pub fn generate_unstructured_mask(
258    scores: &Tensor,
259    target_sparsity: f32,
260) -> Result<SparsityMask, PruningError> {
261    if !(0.0..=1.0).contains(&target_sparsity) {
262        return Err(PruningError::InvalidSparsity {
263            value: target_sparsity,
264            constraint: "must be between 0.0 and 1.0".to_string(),
265        });
266    }
267
268    let data = scores.data();
269    if data.is_empty() {
270        return SparsityMask::new(Tensor::new(&[], &[0]), SparsityPattern::Unstructured);
271    }
272
273    // Find threshold for target sparsity
274    let num_prune = (data.len() as f32 * target_sparsity) as usize;
275
276    // Sort scores to find threshold
277    let mut sorted: Vec<f32> = data.to_vec();
278    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
279
280    let threshold = if num_prune == 0 {
281        f32::NEG_INFINITY
282    } else if num_prune >= sorted.len() {
283        f32::INFINITY
284    } else {
285        sorted[num_prune - 1]
286    };
287
288    // Generate mask (1 = keep, 0 = prune)
289    let mask_data: Vec<f32> = data
290        .iter()
291        .map(|&v| if v > threshold { 1.0 } else { 0.0 })
292        .collect();
293
294    SparsityMask::new(
295        Tensor::new(&mask_data, scores.shape()),
296        SparsityPattern::Unstructured,
297    )
298}
299
300/// Validate N:M sparsity pattern: every M consecutive elements must have exactly N non-zeros.
301fn validate_nm(mask: &Tensor, n: usize, m: usize) -> Result<(), PruningError> {
302    let data = mask.data();
303    if data.len() % m != 0 {
304        return Err(PruningError::InvalidPattern {
305            message: format!("Tensor length {} not divisible by M={}", data.len(), m),
306        });
307    }
308    for (i, chunk) in data.chunks(m).enumerate() {
309        let nnz = chunk.iter().filter(|&&v| v > 0.5).count();
310        if nnz != n {
311            return Err(PruningError::InvalidPattern {
312                message: format!(
313                    "Group {} has {} non-zeros, expected {} (N:M = {}:{})",
314                    i, nnz, n, n, m
315                ),
316            });
317        }
318    }
319    Ok(())
320}
321
322/// Require that the mask is 2D and return (rows, cols).
323fn require_2d(mask: &Tensor, pattern_name: &str) -> Result<(usize, usize), PruningError> {
324    let shape = mask.shape();
325    if shape.len() != 2 {
326        return Err(PruningError::InvalidPattern {
327            message: format!(
328                "{pattern_name} pattern requires 2D tensor, got {}D",
329                shape.len()
330            ),
331        });
332    }
333    Ok((shape[0], shape[1]))
334}
335
336/// Check that a single block is uniform (all values equal to `first`).
337fn check_block_uniform(
338    data: &[f32],
339    br: usize,
340    bc: usize,
341    height: usize,
342    width: usize,
343    cols: usize,
344) -> Result<(), PruningError> {
345    let first = data[br * height * cols + bc * width];
346    for r in 0..height {
347        for c in 0..width {
348            let val = data[(br * height + r) * cols + bc * width + c];
349            if (val - first).abs() > 1e-6 {
350                return Err(PruningError::InvalidPattern {
351                    message: format!("Block ({br}, {bc}) is not uniform: found {val} and {first}"),
352                });
353            }
354        }
355    }
356    Ok(())
357}
358
359/// Validate block sparsity: each block must be uniform (all 0s or all 1s).
360fn validate_block(mask: &Tensor, height: usize, width: usize) -> Result<(), PruningError> {
361    let (rows, cols) = require_2d(mask, "Block")?;
362    if rows % height != 0 || cols % width != 0 {
363        return Err(PruningError::InvalidPattern {
364            message: format!("Shape [{rows}, {cols}] not divisible by block [{height}, {width}]"),
365        });
366    }
367    let data = mask.data();
368    for br in 0..(rows / height) {
369        for bc in 0..(cols / width) {
370            check_block_uniform(data, br, bc, height, width, cols)?;
371        }
372    }
373    Ok(())
374}
375
376/// Validate row sparsity: each row must be uniform (all 0s or all 1s).
377fn validate_row(mask: &Tensor) -> Result<(), PruningError> {
378    let (rows, cols) = require_2d(mask, "Row")?;
379    let data = mask.data();
380    for r in 0..rows {
381        let first = data[r * cols];
382        for c in 1..cols {
383            if (data[r * cols + c] - first).abs() > 1e-6 {
384                return Err(PruningError::InvalidPattern {
385                    message: format!("Row {r} is not uniform"),
386                });
387            }
388        }
389    }
390    Ok(())
391}
392
393/// Validate column sparsity: each column must be uniform (all 0s or all 1s).
394fn validate_column(mask: &Tensor) -> Result<(), PruningError> {
395    let (rows, cols) = require_2d(mask, "Column")?;
396    let data = mask.data();
397    for c in 0..cols {
398        let first = data[c];
399        for r in 1..rows {
400            if (data[r * cols + c] - first).abs() > 1e-6 {
401                return Err(PruningError::InvalidPattern {
402                    message: format!("Column {c} is not uniform"),
403                });
404            }
405        }
406    }
407    Ok(())
408}
// NOTE(review): a stray `include!("mask.rs")` appeared here in the extracted
// source. The mask module's items are already defined inline above, so the
// include would duplicate every definition; it looks like an extraction
// artifact and has been disabled. Confirm against the repository layout.
// include!("mask.rs");