Skip to main content

entrenar/autograd/precision/
config.rs

1//! Configuration for mixed-precision training.
2
3use super::Precision;
4
/// Number of consecutive successful steps required before the loss scale is
/// grown; every preset constructor in this module starts from this value.
const DEFAULT_SCALE_GROWTH_INTERVAL: usize = 2_000;
/// Starting loss scale for fp16 mixed-precision training, i.e. 2^16.
const FP16_INITIAL_LOSS_SCALE: f32 = (1u32 << 16) as f32;
9
10/// Configuration for mixed-precision training
11#[derive(Debug, Clone)]
12pub struct MixedPrecisionConfig {
13    /// Precision for activations and gradients
14    pub compute_precision: Precision,
15    /// Precision for master weights (always fp32 recommended)
16    pub weight_precision: Precision,
17    /// Initial loss scale factor
18    pub initial_scale: f32,
19    /// Factor to increase scale by on successful step
20    pub scale_growth_factor: f32,
21    /// Factor to decrease scale by on overflow
22    pub scale_backoff_factor: f32,
23    /// Number of successful steps before increasing scale
24    pub scale_growth_interval: usize,
25    /// Whether to use dynamic loss scaling
26    pub dynamic_scaling: bool,
27}
28
29impl MixedPrecisionConfig {
30    /// Create fp32 config (no mixed precision)
31    pub fn fp32() -> Self {
32        Self {
33            compute_precision: Precision::Fp32,
34            weight_precision: Precision::Fp32,
35            initial_scale: 1.0,
36            scale_growth_factor: 2.0,
37            scale_backoff_factor: 0.5,
38            scale_growth_interval: DEFAULT_SCALE_GROWTH_INTERVAL,
39            dynamic_scaling: false,
40        }
41    }
42
43    /// Create fp16 mixed-precision config
44    pub fn fp16() -> Self {
45        Self {
46            compute_precision: Precision::Fp16,
47            weight_precision: Precision::Fp32,
48            initial_scale: FP16_INITIAL_LOSS_SCALE,
49            scale_growth_factor: 2.0,
50            scale_backoff_factor: 0.5,
51            scale_growth_interval: DEFAULT_SCALE_GROWTH_INTERVAL,
52            dynamic_scaling: true,
53        }
54    }
55
56    /// Create bf16 mixed-precision config
57    pub fn bf16() -> Self {
58        Self {
59            compute_precision: Precision::Bf16,
60            weight_precision: Precision::Fp32,
61            initial_scale: 1.0, // bf16 has larger dynamic range, less scaling needed
62            scale_growth_factor: 2.0,
63            scale_backoff_factor: 0.5,
64            scale_growth_interval: DEFAULT_SCALE_GROWTH_INTERVAL,
65            dynamic_scaling: false, // Often not needed for bf16
66        }
67    }
68
69    /// Check if mixed precision is enabled
70    pub fn is_mixed(&self) -> bool {
71        self.compute_precision.is_reduced()
72    }
73
74    /// Set initial loss scale
75    pub fn with_initial_scale(mut self, scale: f32) -> Self {
76        self.initial_scale = scale;
77        self
78    }
79
80    /// Enable/disable dynamic scaling
81    pub fn with_dynamic_scaling(mut self, enabled: bool) -> Self {
82        self.dynamic_scaling = enabled;
83        self
84    }
85}
86
87impl Default for MixedPrecisionConfig {
88    fn default() -> Self {
89        Self::fp32()
90    }
91}