Skip to main content

entrenar/prune/config/
pattern.rs

1//! Sparsity pattern configuration.
2
3use serde::{Deserialize, Serialize};
4
/// Sparsity pattern selection.
///
/// Describes which structural constraint a pruning mask has to satisfy.
/// The enum is (de)serialized by serde as an internally tagged value: the
/// variant is named by a `type` field using snake_case names (`unstructured`,
/// `nm`, `block`, `row`, `column`), and the fields of struct-like variants sit
/// next to that tag.
///
/// [`Default`] yields [`SparsityPatternConfig::Unstructured`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum SparsityPatternConfig {
    /// Unstructured sparsity - any weight can be pruned.
    ///
    /// This is the default variant.
    #[default]
    Unstructured,

    /// N:M structured sparsity (e.g., 2:4 for NVIDIA Ampere).
    ///
    /// Explicitly renamed to `nm` for serialization, overriding the name the
    /// container-level `rename_all = "snake_case"` would otherwise derive.
    #[serde(rename = "nm")]
    NM {
        /// Number of non-zero elements per group.
        n: usize,
        /// Group size.
        m: usize,
    },

    /// Block sparsity - entire blocks pruned together.
    Block {
        /// Block height.
        height: usize,
        /// Block width.
        width: usize,
    },

    /// Row sparsity - entire output channels pruned.
    Row,

    /// Column sparsity - entire input channels pruned.
    Column,
}
36
37impl SparsityPatternConfig {
38    /// Create 2:4 sparsity pattern for NVIDIA Ampere.
39    pub fn nm_2_4() -> Self {
40        SparsityPatternConfig::NM { n: 2, m: 4 }
41    }
42
43    /// Create 4:8 sparsity pattern.
44    pub fn nm_4_8() -> Self {
45        SparsityPatternConfig::NM { n: 4, m: 8 }
46    }
47
48    /// Get the theoretical sparsity for this pattern.
49    pub fn theoretical_sparsity(&self) -> f32 {
50        match self {
51            SparsityPatternConfig::Unstructured => 0.0, // Variable
52            SparsityPatternConfig::NM { n, m } => 1.0 - (*n as f32 / *m as f32),
53            SparsityPatternConfig::Block { .. } => 0.0, // Variable
54            SparsityPatternConfig::Row => 0.0,          // Variable
55            SparsityPatternConfig::Column => 0.0,       // Variable
56        }
57    }
58}