use serde::{Deserialize, Serialize};
/// Numeric precision a model can be stored or executed in.
///
/// Variants are declared in ascending order, so the derived `Ord` ranks
/// `INT4 < INT8 < BF16 < FP16 < FP32`. `BF16` and `FP16` are both 16 bits
/// wide; `BF16` sorts lower purely because it is declared first.
///
/// `FP16` is the `Default` variant.
#[derive(
    Clone, Copy, Debug, Default, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd, Serialize,
)]
pub enum Precision {
    /// 4-bit integer format.
    INT4,
    /// 8-bit integer format.
    INT8,
    /// 16-bit `BF16` format.
    BF16,
    /// 16-bit `FP16` format; the default precision.
    #[default]
    FP16,
    /// 32-bit `FP32` format.
    FP32,
}
impl Precision {
    /// Width of a single stored value, in bits.
    pub fn bits(&self) -> u32 {
        match *self {
            Precision::INT4 => 4,
            Precision::INT8 => 8,
            Precision::BF16 | Precision::FP16 => 16,
            Precision::FP32 => 32,
        }
    }

    /// Width of a single stored value, in bytes.
    ///
    /// The result is fractional for sub-byte formats (`0.5` for `INT4`).
    pub fn bytes(&self) -> f32 {
        const BITS_PER_BYTE: f32 = 8.0;
        let width = self.bits() as f32;
        width / BITS_PER_BYTE
    }

    /// Storage size relative to a 32-bit value (`1.0` for `FP32`, `0.125` for `INT4`).
    pub fn vram_ratio(&self) -> f32 {
        const REFERENCE_BITS: f32 = 32.0;
        let width = self.bits() as f32;
        width / REFERENCE_BITS
    }

    /// Hard-coded per-variant speed multiplier; `FP32` maps to `1.0`.
    ///
    /// NOTE(review): presumably an estimated throughput gain over FP32 — the
    /// origin of these figures is not visible in this file.
    pub fn speedup_factor(&self) -> f32 {
        match *self {
            Precision::FP32 => 1.0,
            Precision::FP16 => 2.0,
            Precision::BF16 => 1.8,
            Precision::INT8 => 2.5,
            Precision::INT4 => 4.0,
        }
    }

    /// Hard-coded per-variant quality multiplier in `(0, 1]`; `FP32` maps to `1.0`.
    ///
    /// NOTE(review): presumably an estimate of retained output quality — the
    /// origin of these figures is not visible in this file.
    pub fn quality_factor(&self) -> f32 {
        match *self {
            Precision::FP32 => 1.0,
            Precision::FP16 => 0.998,
            Precision::BF16 => 0.995,
            Precision::INT8 => 0.97,
            Precision::INT4 => 0.92,
        }
    }

    /// Returns `true` for `FP32`, `FP16` and `BF16`, and `false` for the
    /// integer formats `INT8` and `INT4`.
    pub fn is_lossless(&self) -> bool {
        match *self {
            Precision::FP32 | Precision::FP16 | Precision::BF16 => true,
            Precision::INT8 | Precision::INT4 => false,
        }
    }

    /// Parses a precision name via the type's `FromStr` implementation,
    /// returning `None` when the name is not recognised.
    pub fn parse(s: &str) -> Option<Self> {
        <Self as std::str::FromStr>::from_str(s).ok()
    }
}
impl std::str::FromStr for Precision {
    type Err = ();

    /// Parses a precision from its canonical name or one of its aliases.
    ///
    /// Matching is case-insensitive: the input is upper-cased before being
    /// compared against the alias table. Surrounding whitespace is not
    /// stripped. Unrecognised input yields `Err(())`.
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        // Each entry pairs the accepted upper-case spellings with their variant.
        const ALIASES: &[(&[&str], Precision)] = &[
            (&["INT4", "I4", "4BIT"], Precision::INT4),
            (&["INT8", "I8", "8BIT"], Precision::INT8),
            (&["BF16", "BFLOAT16"], Precision::BF16),
            (&["FP16", "FLOAT16", "F16", "HALF"], Precision::FP16),
            (&["FP32", "FLOAT32", "F32", "FULL"], Precision::FP32),
        ];

        let upper = s.to_uppercase();
        ALIASES
            .iter()
            .find(|(names, _)| names.contains(&upper.as_str()))
            .map(|&(_, precision)| precision)
            .ok_or(())
    }
}
impl std::fmt::Display for Precision {
    /// Writes the canonical upper-case name of the variant (e.g. `FP16`).
    ///
    /// The name is emitted through `Formatter::pad`, so width, fill and
    /// alignment flags such as `{:>6}` or `{:<8}` are honoured (and a
    /// precision such as `{:.2}` truncates the name). A plain `{}` produces
    /// exactly the bare name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match *self {
            Precision::INT4 => "INT4",
            Precision::INT8 => "INT8",
            Precision::BF16 => "BF16",
            Precision::FP16 => "FP16",
            Precision::FP32 => "FP32",
        };
        // `pad` applies the caller's formatting flags; `write!` would ignore them.
        f.pad(name)
    }
}
/// Describes which precisions a device can run and how much memory it has.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct PrecisionCapabilities {
    /// Precisions the device is able to run; queried by `supports` and
    /// `best_supported`.
    pub supported: Vec<Precision>,
    /// Precision used as the fallback when `best_supported` finds no
    /// supported precision at or below the requested ceiling.
    pub native: Precision,
    /// Flag recording whether the device is considered to have INT4 tensor cores.
    pub int4_tensor_cores: bool,
    /// Flag recording whether the device is considered to have INT8 tensor cores.
    pub int8_tensor_cores: bool,
    /// Available memory in MB, where `fits_model` treats one MB as
    /// 1024 * 1024 bytes.
    pub vram_mb: u64,
    /// Optional pair of version numbers; `None` for the CPU profile.
    /// NOTE(review): presumably the GPU compute capability as (major, minor) — confirm.
    pub compute_capability: Option<(u32, u32)>,
}
impl PrecisionCapabilities {
    /// Profile that supports every [`Precision`], runs `FP16` natively, has
    /// both tensor-core flags set and reports compute capability `(8, 6)`.
    ///
    /// `vram_mb` is the device memory in MB.
    pub fn modern_gpu(vram_mb: u64) -> Self {
        Self {
            supported: vec![
                Precision::INT4,
                Precision::INT8,
                Precision::BF16,
                Precision::FP16,
                Precision::FP32,
            ],
            native: Precision::FP16,
            int4_tensor_cores: true,
            int8_tensor_cores: true,
            vram_mb,
            compute_capability: Some((8, 6)),
        }
    }

    /// Profile that supports only `FP16` and `FP32`, runs `FP32` natively,
    /// has neither tensor-core flag set and reports compute capability `(6, 1)`.
    ///
    /// `vram_mb` is the device memory in MB.
    pub fn legacy_gpu(vram_mb: u64) -> Self {
        Self {
            supported: vec![Precision::FP16, Precision::FP32],
            native: Precision::FP32,
            int4_tensor_cores: false,
            int8_tensor_cores: false,
            vram_mb,
            compute_capability: Some((6, 1)),
        }
    }

    /// Profile that supports `INT8` and `FP32`, runs `FP32` natively and has
    /// no compute capability.
    ///
    /// `memory_mb` is stored in the `vram_mb` field.
    pub fn cpu(memory_mb: u64) -> Self {
        Self {
            supported: vec![Precision::INT8, Precision::FP32],
            native: Precision::FP32,
            int4_tensor_cores: false,
            int8_tensor_cores: false,
            vram_mb: memory_mb,
            compute_capability: None,
        }
    }

    /// Returns `true` when `precision` is listed in `supported`.
    pub fn supports(&self, precision: Precision) -> bool {
        self.supported.contains(&precision)
    }

    /// Returns the highest supported precision that does not exceed
    /// `max_precision`, using the ordering derived on [`Precision`].
    ///
    /// When no supported precision is at or below `max_precision`, the
    /// `native` precision is returned instead, which may be higher than
    /// `max_precision`.
    pub fn best_supported(&self, max_precision: Precision) -> Precision {
        self.supported
            .iter()
            .copied()
            .filter(|&candidate| candidate <= max_precision)
            .max()
            .unwrap_or(self.native)
    }

    /// Estimates the number of bytes needed to hold `model_params` values at
    /// the given `precision` (`model_params * bits / 8`, rounded down).
    ///
    /// The computation uses integer arithmetic so the result is exact for
    /// every parameter count; routing it through `f32` would silently round
    /// byte counts above 2^24. The result saturates at `u64::MAX` rather than
    /// overflowing.
    pub fn estimate_vram(&self, model_params: u64, precision: Precision) -> u64 {
        // Widen to u128 so `params * bits` cannot overflow before the division.
        let total_bits = u128::from(model_params) * u128::from(precision.bits());
        let total_bytes = total_bits / 8;
        total_bytes.min(u128::from(u64::MAX)) as u64
    }

    /// Returns `true` when the estimated requirement from
    /// [`estimate_vram`](Self::estimate_vram) fits within `vram_mb`.
    ///
    /// The comparison is done in bytes. Truncating the requirement to whole
    /// MB first would report a model as fitting even when it exceeds the
    /// budget by up to one MB minus a byte.
    pub fn fits_model(&self, model_params: u64, precision: Precision) -> bool {
        const BYTES_PER_MB: u64 = 1024 * 1024;
        let required_bytes = self.estimate_vram(model_params, precision);
        let available_bytes = self.vram_mb.saturating_mul(BYTES_PER_MB);
        required_bytes <= available_bytes
    }
}
impl Default for PrecisionCapabilities {
fn default() -> Self {
Self::modern_gpu(8192) }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The derived ordering ranks precisions from narrowest to widest.
    #[test]
    fn test_precision_ordering() {
        let ascending = [
            Precision::INT4,
            Precision::INT8,
            Precision::FP16,
            Precision::FP32,
        ];
        for pair in ascending.windows(2) {
            assert!(pair[0] < pair[1]);
        }
    }

    /// `bits` reports the storage width of each precision.
    #[test]
    fn test_precision_bits() {
        let expected = [
            (Precision::INT4, 4),
            (Precision::INT8, 8),
            (Precision::FP16, 16),
            (Precision::FP32, 32),
        ];
        for &(precision, width) in expected.iter() {
            assert_eq!(precision.bits(), width);
        }
    }

    /// `vram_ratio` is the width relative to 32 bits, within a small tolerance.
    #[test]
    fn test_vram_ratio() {
        let expected = [
            (Precision::INT4, 0.125),
            (Precision::INT8, 0.25),
            (Precision::FP16, 0.5),
            (Precision::FP32, 1.0),
        ];
        for &(precision, ratio) in expected.iter() {
            assert!((precision.vram_ratio() - ratio).abs() < 0.001);
        }
    }

    /// The modern GPU profile supports low-precision formats and INT4 tensor cores.
    #[test]
    fn test_capabilities() {
        let modern = PrecisionCapabilities::modern_gpu(12288);
        for &precision in [Precision::INT4, Precision::FP16].iter() {
            assert!(modern.supports(precision));
        }
        assert!(modern.int4_tensor_cores);
    }

    /// `best_supported` falls back to the native precision when nothing at or
    /// below the ceiling is supported, and otherwise picks the highest match.
    #[test]
    fn test_best_supported() {
        let legacy = PrecisionCapabilities::legacy_gpu(4096);
        let expected = [
            (Precision::INT4, Precision::FP32),
            (Precision::FP16, Precision::FP16),
        ];
        for &(ceiling, chosen) in expected.iter() {
            assert_eq!(legacy.best_supported(ceiling), chosen);
        }
    }
}