entrenar/prune/config/pattern.rs
1//! Sparsity pattern configuration.
2
3use serde::{Deserialize, Serialize};
4
/// Sparsity pattern selection.
///
/// Describes *which* weights are eligible to be pruned together. The default
/// variant is [`SparsityPatternConfig::Unstructured`].
///
/// # Serialization
///
/// The enum is internally tagged: the variant name is stored under the
/// `type` key, spelled in `snake_case` (`unstructured`, `block`, `row`,
/// `column`). The N:M variant is explicitly renamed to `nm`, so it is
/// written as e.g. `{ "type": "nm", "n": 2, "m": 4 }`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum SparsityPatternConfig {
    /// Unstructured sparsity - any weight can be pruned.
    ///
    /// This is the `Default` variant.
    #[default]
    Unstructured,

    /// N:M structured sparsity (e.g., 2:4 for NVIDIA Ampere).
    ///
    /// Within every group of `m` elements, `n` are kept non-zero. The type
    /// itself does not enforce `n <= m` or `m > 0`.
    // Explicit rename: `rename_all = "snake_case"` would otherwise not yield the tag `nm`.
    #[serde(rename = "nm")]
    NM {
        /// Number of non-zero elements per group.
        n: usize,
        /// Group size.
        m: usize,
    },

    /// Block sparsity - entire blocks pruned together.
    Block {
        /// Block height.
        height: usize,
        /// Block width.
        width: usize,
    },

    /// Row sparsity - entire output channels pruned.
    Row,

    /// Column sparsity - entire input channels pruned.
    Column,
}
36
37impl SparsityPatternConfig {
38 /// Create 2:4 sparsity pattern for NVIDIA Ampere.
39 pub fn nm_2_4() -> Self {
40 SparsityPatternConfig::NM { n: 2, m: 4 }
41 }
42
43 /// Create 4:8 sparsity pattern.
44 pub fn nm_4_8() -> Self {
45 SparsityPatternConfig::NM { n: 4, m: 8 }
46 }
47
48 /// Get the theoretical sparsity for this pattern.
49 pub fn theoretical_sparsity(&self) -> f32 {
50 match self {
51 SparsityPatternConfig::Unstructured => 0.0, // Variable
52 SparsityPatternConfig::NM { n, m } => 1.0 - (*n as f32 / *m as f32),
53 SparsityPatternConfig::Block { .. } => 0.0, // Variable
54 SparsityPatternConfig::Row => 0.0, // Variable
55 SparsityPatternConfig::Column => 0.0, // Variable
56 }
57 }
58}