use super::Precision;
/// Value assigned to `scale_growth_interval` by every built-in preset.
const DEFAULT_SCALE_GROWTH_INTERVAL: usize = 2_000;

/// Initial loss scale assigned by the FP16 preset (2^16).
const FP16_INITIAL_LOSS_SCALE: f32 = 65_536.0;
/// Precision selection plus loss-scale parameters, stored as plain public fields.
///
/// The struct performs no validation of its own; values are kept exactly as
/// assigned.
#[derive(Debug, Clone)]
pub struct MixedPrecisionConfig {
    /// Precision chosen for computation. `is_mixed` reports whether this
    /// value is a reduced precision.
    pub compute_precision: Precision,

    /// Precision chosen for weights. Every built-in preset sets this to
    /// `Precision::Fp32`.
    pub weight_precision: Precision,

    /// Starting value of the loss scale.
    pub initial_scale: f32,

    /// Factor used when the scale grows.
    /// NOTE(review): presumably a multiplier applied by the scaler that
    /// consumes this config — confirm against that code.
    pub scale_growth_factor: f32,

    /// Factor used when the scale backs off.
    /// NOTE(review): presumably a multiplier applied after an overflow —
    /// confirm against the consuming scaler.
    pub scale_backoff_factor: f32,

    /// Interval between scale growth attempts.
    /// NOTE(review): unit (steps? successful steps?) is not visible here —
    /// confirm against the consuming scaler.
    pub scale_growth_interval: usize,

    /// Whether dynamic scaling is enabled.
    pub dynamic_scaling: bool,
}
impl MixedPrecisionConfig {
pub fn fp32() -> Self {
Self {
compute_precision: Precision::Fp32,
weight_precision: Precision::Fp32,
initial_scale: 1.0,
scale_growth_factor: 2.0,
scale_backoff_factor: 0.5,
scale_growth_interval: DEFAULT_SCALE_GROWTH_INTERVAL,
dynamic_scaling: false,
}
}
pub fn fp16() -> Self {
Self {
compute_precision: Precision::Fp16,
weight_precision: Precision::Fp32,
initial_scale: FP16_INITIAL_LOSS_SCALE,
scale_growth_factor: 2.0,
scale_backoff_factor: 0.5,
scale_growth_interval: DEFAULT_SCALE_GROWTH_INTERVAL,
dynamic_scaling: true,
}
}
pub fn bf16() -> Self {
Self {
compute_precision: Precision::Bf16,
weight_precision: Precision::Fp32,
initial_scale: 1.0, scale_growth_factor: 2.0,
scale_backoff_factor: 0.5,
scale_growth_interval: DEFAULT_SCALE_GROWTH_INTERVAL,
dynamic_scaling: false, }
}
pub fn is_mixed(&self) -> bool {
self.compute_precision.is_reduced()
}
pub fn with_initial_scale(mut self, scale: f32) -> Self {
self.initial_scale = scale;
self
}
pub fn with_dynamic_scaling(mut self, enabled: bool) -> Self {
self.dynamic_scaling = enabled;
self
}
}
impl Default for MixedPrecisionConfig {
fn default() -> Self {
Self::fp32()
}
}