use serde::{Deserialize, Serialize};
/// Structural sparsity pattern configuration.
///
/// Serialized with serde as an internally tagged enum. The variant name is stored in a
/// `"type"` field (`"unstructured"`, `"nm"`, `"block"`, `"row"`, `"column"`), and any
/// variant fields appear next to it, e.g. `{"type": "nm", "n": 2, "m": 4}`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum SparsityPatternConfig {
    /// Pattern with no declared structure (presumably arbitrary element-wise sparsity —
    /// confirm with the consumer). This is the `Default` variant.
    #[default]
    Unstructured,
    /// N:M pattern over groups of `m` elements.
    // The explicit rename pins the serialized tag to "nm", independent of how
    // `rename_all = "snake_case"` would convert the identifier `NM`.
    #[serde(rename = "nm")]
    NM {
        /// Group parameter `n`; `theoretical_sparsity` treats it as the number of
        /// elements kept per group (it computes `1 - n / m`).
        n: usize,
        /// Group size `m`.
        m: usize,
    },
    /// Block-structured pattern.
    Block {
        /// Block height (units not shown here — presumably elements; confirm).
        height: usize,
        /// Block width (units not shown here — presumably elements; confirm).
        width: usize,
    },
    /// Row-wise pattern. NOTE(review): exact semantics (e.g. whole-row pruning) are
    /// defined by the consumer of this config — confirm.
    Row,
    /// Column-wise pattern. NOTE(review): exact semantics (e.g. whole-column pruning) are
    /// defined by the consumer of this config — confirm.
    Column,
}
impl SparsityPatternConfig {
    /// N:M pattern with `n = 2`, `m = 4` (often written "2:4").
    #[must_use]
    pub const fn nm_2_4() -> Self {
        SparsityPatternConfig::NM { n: 2, m: 4 }
    }

    /// N:M pattern with `n = 4`, `m = 8` (often written "4:8").
    #[must_use]
    pub const fn nm_4_8() -> Self {
        SparsityPatternConfig::NM { n: 4, m: 8 }
    }

    /// Fraction of elements that the pattern alone guarantees to be zero, in `[0.0, 1.0]`.
    ///
    /// - `NM { n, m }`: `1 - n / m`, treating `n` as the number of elements kept in each
    ///   group of `m`. Both `nm_2_4()` and `nm_4_8()` therefore give `0.5`.
    /// - Degenerate `NM` configs (`m == 0`, or `n >= m`) return `0.0`. Without this guard
    ///   the formula would produce NaN, `-inf`, or a negative fraction.
    /// - `Unstructured`, `Block`, `Row`, `Column`: `0.0`, because the configuration does
    ///   not fix how many elements are zeroed.
    #[must_use]
    pub fn theoretical_sparsity(&self) -> f32 {
        match *self {
            // `n < m` also implies `m > 0`, so the division is well defined and the result
            // lies in (0, 1].
            SparsityPatternConfig::NM { n, m } if n < m => 1.0 - (n as f32 / m as f32),
            // m == 0 or n >= m: nothing is guaranteed to be pruned.
            SparsityPatternConfig::NM { .. } => 0.0,
            SparsityPatternConfig::Unstructured
            | SparsityPatternConfig::Block { .. }
            | SparsityPatternConfig::Row
            | SparsityPatternConfig::Column => 0.0,
        }
    }
}