entrenar/autograd/precision/
config.rs1use super::Precision;
/// Default value for `MixedPrecisionConfig::scale_growth_interval`, used by
/// every built-in preset.
const DEFAULT_SCALE_GROWTH_INTERVAL: usize = 2000;

/// Initial loss scale used by the FP16 preset (2^16 = 65536).
const FP16_INITIAL_LOSS_SCALE: f32 = 65536.0;
10#[derive(Debug, Clone)]
12pub struct MixedPrecisionConfig {
13 pub compute_precision: Precision,
15 pub weight_precision: Precision,
17 pub initial_scale: f32,
19 pub scale_growth_factor: f32,
21 pub scale_backoff_factor: f32,
23 pub scale_growth_interval: usize,
25 pub dynamic_scaling: bool,
27}
28
29impl MixedPrecisionConfig {
30 pub fn fp32() -> Self {
32 Self {
33 compute_precision: Precision::Fp32,
34 weight_precision: Precision::Fp32,
35 initial_scale: 1.0,
36 scale_growth_factor: 2.0,
37 scale_backoff_factor: 0.5,
38 scale_growth_interval: DEFAULT_SCALE_GROWTH_INTERVAL,
39 dynamic_scaling: false,
40 }
41 }
42
43 pub fn fp16() -> Self {
45 Self {
46 compute_precision: Precision::Fp16,
47 weight_precision: Precision::Fp32,
48 initial_scale: FP16_INITIAL_LOSS_SCALE,
49 scale_growth_factor: 2.0,
50 scale_backoff_factor: 0.5,
51 scale_growth_interval: DEFAULT_SCALE_GROWTH_INTERVAL,
52 dynamic_scaling: true,
53 }
54 }
55
56 pub fn bf16() -> Self {
58 Self {
59 compute_precision: Precision::Bf16,
60 weight_precision: Precision::Fp32,
61 initial_scale: 1.0, scale_growth_factor: 2.0,
63 scale_backoff_factor: 0.5,
64 scale_growth_interval: DEFAULT_SCALE_GROWTH_INTERVAL,
65 dynamic_scaling: false, }
67 }
68
69 pub fn is_mixed(&self) -> bool {
71 self.compute_precision.is_reduced()
72 }
73
74 pub fn with_initial_scale(mut self, scale: f32) -> Self {
76 self.initial_scale = scale;
77 self
78 }
79
80 pub fn with_dynamic_scaling(mut self, enabled: bool) -> Self {
82 self.dynamic_scaling = enabled;
83 self
84 }
85}
86
87impl Default for MixedPrecisionConfig {
88 fn default() -> Self {
89 Self::fp32()
90 }
91}