/// Configuration values for ternary processing.
///
/// Presets are available through `Default` and the `for_*` constructors, and
/// `TernaryConfig::validate` checks the numeric fields.
///
/// NOTE(review): the field names suggest GPU tiling/launch parameters plus
/// quantization settings, but the code consuming this struct is not in this
/// file — confirm the exact semantics there.
#[derive(Debug, Clone, Copy)]
pub struct TernaryConfig {
/// Sparsity cut-off; `validate` requires it to lie in `[0.0, 1.0]`.
/// NOTE(review): presumably the fraction of zero-valued weights — confirm.
pub sparsity_threshold: f32,
/// Tile size; `validate` requires a power of two. It is also the only input
/// to `shared_memory_bytes`.
pub tile_size: u32,
/// Block size; `validate` requires a multiple of 32 no larger than 1024.
/// NOTE(review): looks like threads per GPU block — confirm.
pub block_size: u32,
/// Switch for plane skipping; on by default, off in `for_dense_model`.
pub enable_plane_skipping: bool,
/// Switch for dim metadata; on by default, off in `for_dense_model`.
pub enable_dim_metadata: bool,
/// Metadata chunk size; `validate` requires a power of two.
pub metadata_chunk_size: u32,
/// Optional explicit quantization threshold. Every preset in this file
/// leaves it as `None`, and `validate` does not inspect it.
/// NOTE(review): presumably `None` defers to `calibration_method` — confirm.
pub quantization_threshold: Option<f32>,
/// Which calibration method to use; not inspected by `validate`.
pub calibration_method: CalibrationMethodConfig,
}
/// Selects a calibration method; `AbsMax` is the `Default` variant.
///
/// NOTE(review): nothing in this file checks the `f32` payloads, so their
/// meaning and valid ranges must be confirmed against the code that consumes
/// this enum.
#[derive(Debug, Clone, Copy, Default)]
pub enum CalibrationMethodConfig {
/// Parameterless variant, used by `Default` and by the dense-model preset.
/// NOTE(review): presumably calibrates from the absolute maximum — confirm.
#[default]
AbsMax,
/// Percentile-based variant. The sparse-model preset passes `99.5`, so the
/// payload looks like a percentage in `0..=100` — TODO confirm.
Percentile(f32),
/// Mean/standard-deviation-based variant.
/// NOTE(review): the payload is presumably a std-dev multiplier — confirm.
MeanStd(f32),
/// Manually supplied value.
/// NOTE(review): presumably the threshold or scale itself — confirm.
Manual(f32),
}
impl Default for TernaryConfig {
    /// Baseline configuration that the `..Self::default()` presets build upon.
    ///
    /// Sizes are 64 (tile), 256 (block) and 2048 (metadata chunk); the sparsity
    /// threshold is 0.8; both feature switches are on; no explicit quantization
    /// threshold is set; calibration uses `CalibrationMethodConfig`'s default.
    fn default() -> Self {
        // Resolved through `Default` so this stays in sync with the enum's
        // `#[default]` variant.
        let calibration_method = CalibrationMethodConfig::default();

        Self {
            // Sizes.
            tile_size: 64,
            block_size: 256,
            metadata_chunk_size: 2048,
            // Sparsity cut-off and feature switches.
            sparsity_threshold: 0.8,
            enable_plane_skipping: true,
            enable_dim_metadata: true,
            // Quantization.
            quantization_threshold: None,
            calibration_method,
        }
    }
}
impl TernaryConfig {
    /// Preset for sparse models.
    ///
    /// Compared with [`Default`]: `sparsity_threshold` is 0.90, `tile_size` is
    /// 128, `metadata_chunk_size` is 4096 and calibration is
    /// `Percentile(99.5)`. Plane skipping and dim metadata stay enabled.
    #[must_use]
    pub fn for_sparse_model() -> Self {
        Self {
            sparsity_threshold: 0.90,
            tile_size: 128,
            block_size: 256,
            enable_plane_skipping: true,
            enable_dim_metadata: true,
            metadata_chunk_size: 4096,
            quantization_threshold: None,
            calibration_method: CalibrationMethodConfig::Percentile(99.5),
        }
    }

    /// Preset for dense models.
    ///
    /// Compared with [`Default`]: `sparsity_threshold` is 0.5 and both plane
    /// skipping and dim metadata are switched off. Sizes and calibration match
    /// the defaults.
    #[must_use]
    pub fn for_dense_model() -> Self {
        Self {
            sparsity_threshold: 0.5,
            tile_size: 64,
            block_size: 256,
            enable_plane_skipping: false,
            enable_dim_metadata: false,
            metadata_chunk_size: 2048,
            quantization_threshold: None,
            calibration_method: CalibrationMethodConfig::AbsMax,
        }
    }

    /// Preset named for the RTX 5080: the defaults with `tile_size` raised to
    /// 128 (`block_size` is restated as 256).
    #[must_use]
    pub fn for_rtx_5080() -> Self {
        Self {
            tile_size: 128,
            block_size: 256,
            ..Self::default()
        }
    }

    /// Preset named for datacenter hardware: the defaults with `tile_size`
    /// raised to 256 (`block_size` is restated as 256).
    #[must_use]
    pub fn for_datacenter() -> Self {
        Self {
            tile_size: 256,
            block_size: 256,
            ..Self::default()
        }
    }

    /// Checks the numeric fields against their allowed values.
    ///
    /// `quantization_threshold` and `calibration_method` are not inspected.
    ///
    /// # Errors
    ///
    /// Returns the first failing check, in this order:
    /// - [`ConfigError::InvalidTileSize`] if `tile_size` is not a power of two
    ///   (zero is not a power of two);
    /// - [`ConfigError::InvalidBlockSize`] if `block_size` is zero, is not a
    ///   multiple of 32, or exceeds 1024;
    /// - [`ConfigError::InvalidChunkSize`] if `metadata_chunk_size` is not a
    ///   power of two;
    /// - [`ConfigError::InvalidSparsityThreshold`] if `sparsity_threshold` is
    ///   outside `[0.0, 1.0]` (NaN is outside the range and is rejected).
    pub fn validate(&self) -> Result<(), ConfigError> {
        if !self.tile_size.is_power_of_two() {
            return Err(ConfigError::InvalidTileSize(self.tile_size));
        }

        // `0.is_multiple_of(32)` is true, so zero needs its own check. Without
        // it an empty block would pass, unlike the other two size fields whose
        // power-of-two checks already rule out zero.
        let block_size_ok = self.block_size != 0
            && self.block_size.is_multiple_of(32)
            && self.block_size <= 1024;
        if !block_size_ok {
            return Err(ConfigError::InvalidBlockSize(self.block_size));
        }

        if !self.metadata_chunk_size.is_power_of_two() {
            return Err(ConfigError::InvalidChunkSize(self.metadata_chunk_size));
        }

        // `contains` is false for NaN, so NaN falls into the error branch.
        if !(0.0..=1.0).contains(&self.sparsity_threshold) {
            return Err(ConfigError::InvalidSparsityThreshold(
                self.sparsity_threshold,
            ));
        }

        Ok(())
    }

    /// Shared-memory requirement in bytes, computed as `2 * tile_size * 4`.
    ///
    /// The product saturates at `u32::MAX` instead of overflowing. `validate`
    /// accepts any power-of-two `tile_size`, including values large enough
    /// that plain multiplication would panic in debug builds and wrap to a
    /// small, misleading number in release builds.
    ///
    /// NOTE(review): the factors presumably stand for two buffers of 4-byte
    /// elements — confirm against the kernel that allocates this memory.
    #[must_use]
    pub const fn shared_memory_bytes(&self) -> u32 {
        self.tile_size.saturating_mul(2 * 4)
    }

    /// Number of 32-element words needed to cover `k_dim`, i.e. `k_dim / 32`
    /// rounded up.
    #[must_use]
    pub const fn k_words(k_dim: u32) -> u32 {
        k_dim.div_ceil(32)
    }
}
/// Reasons `TernaryConfig::validate` rejects a configuration. Each variant
/// carries the offending field value so it can be shown in the error message.
#[derive(Debug, Clone)]
pub enum ConfigError {
/// `tile_size` was rejected; the rule is that it must be a power of two.
InvalidTileSize(u32),
/// `block_size` was rejected; the `Display` text states the rule as a
/// multiple of 32 and at most 1024.
InvalidBlockSize(u32),
/// `metadata_chunk_size` was rejected; it must be a power of two.
InvalidChunkSize(u32),
/// `sparsity_threshold` was rejected; it must lie in `[0.0, 1.0]`.
InvalidSparsityThreshold(f32),
}
impl std::fmt::Display for ConfigError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::InvalidTileSize(v) => write!(f, "tile_size {v} must be power of 2"),
Self::InvalidBlockSize(v) => {
write!(f, "block_size {v} must be multiple of 32 and ≤1024")
}
Self::InvalidChunkSize(v) => write!(f, "metadata_chunk_size {v} must be power of 2"),
Self::InvalidSparsityThreshold(v) => {
write!(f, "sparsity_threshold {v} must be in [0.0, 1.0]")
}
}
}
}
// Empty impl: `ConfigError` only needs the trait's default methods. The
// `Debug` derive and the `Display` impl satisfy the trait's supertraits, and
// this makes the type usable wherever a `std::error::Error` is expected
// (e.g. `Box<dyn std::error::Error>`).
impl std::error::Error for ConfigError {}
#[cfg(test)]
mod tests {
    use super::*;

    /// The baseline configuration passes validation.
    #[test]
    fn test_default_config_valid() {
        assert!(TernaryConfig::default().validate().is_ok());
    }

    /// The sparse-model preset passes validation.
    #[test]
    fn test_sparse_config_valid() {
        assert!(TernaryConfig::for_sparse_model().validate().is_ok());
    }

    /// A tile size that is not a power of two is reported with its value.
    #[test]
    fn test_invalid_tile_size() {
        let outcome = TernaryConfig {
            tile_size: 65,
            ..TernaryConfig::default()
        }
        .validate();

        assert!(matches!(outcome, Err(ConfigError::InvalidTileSize(65))));
    }

    /// A block size that is not a multiple of 32 is reported with its value.
    #[test]
    fn test_invalid_block_size() {
        let outcome = TernaryConfig {
            block_size: 100,
            ..TernaryConfig::default()
        }
        .validate();

        assert!(matches!(outcome, Err(ConfigError::InvalidBlockSize(100))));
    }

    /// `k_words` is a ceiling division by 32.
    #[test]
    fn test_k_words_calculation() {
        // (k_dim, expected word count)
        let cases = [(32, 1), (33, 2), (64, 2), (128, 4)];

        for (k_dim, expected) in cases {
            assert_eq!(TernaryConfig::k_words(k_dim), expected);
        }
    }
}