entrenar/autograd/precision/
scaler.rs1use super::MixedPrecisionConfig;
4
/// Default number of consecutive finite-gradient steps between scale growths,
/// used by `GradScaler::new`.
const DEFAULT_SCALE_GROWTH_INTERVAL: usize = 2_000;
7
/// Loss scaler for mixed-precision training.
///
/// `scale_loss` multiplies a loss by `scale`; `unscale_grad` and
/// `unscale_and_check` divide gradients by it. When `dynamic` is set,
/// `update` grows the scale after `growth_interval` consecutive steps with
/// finite gradients and backs it off whenever a step reports non-finite ones.
#[derive(Debug)]
pub struct GradScaler {
    /// Current multiplier applied to losses and divided out of gradients.
    scale: f32,
    /// Multiplier applied to `scale` after `growth_interval` good steps.
    growth_factor: f32,
    /// Multiplier applied to `scale` after a step with non-finite gradients.
    backoff_factor: f32,
    /// Consecutive good steps required before the scale grows.
    pub(crate) growth_interval: usize,
    /// Good steps counted since the last growth or backoff.
    steps_since_growth: usize,
    /// When false, `update` returns without changing the scale or counters.
    dynamic: bool,
    /// Steps reported to `update` with non-finite gradients (dynamic mode only).
    overflow_count: usize,
    /// Steps reported to `update` with finite gradients (dynamic mode only).
    successful_steps: usize,
}
30
31impl GradScaler {
32 pub fn new(initial_scale: f32) -> Self {
34 Self {
35 scale: initial_scale,
36 growth_factor: 2.0,
37 backoff_factor: 0.5,
38 growth_interval: DEFAULT_SCALE_GROWTH_INTERVAL,
39 steps_since_growth: 0,
40 dynamic: true,
41 overflow_count: 0,
42 successful_steps: 0,
43 }
44 }
45
46 pub fn from_config(config: &MixedPrecisionConfig) -> Self {
48 Self {
49 scale: config.initial_scale,
50 growth_factor: config.scale_growth_factor,
51 backoff_factor: config.scale_backoff_factor,
52 growth_interval: config.scale_growth_interval,
53 steps_since_growth: 0,
54 dynamic: config.dynamic_scaling,
55 overflow_count: 0,
56 successful_steps: 0,
57 }
58 }
59
60 pub fn scale(&self) -> f32 {
62 self.scale
63 }
64
65 pub fn scale_loss(&self, loss: f32) -> f32 {
67 loss * self.scale
68 }
69
70 pub fn unscale_grad(&self, grad: f32) -> f32 {
72 grad / self.scale
73 }
74
75 pub fn unscale_and_check(&self, grads: &mut [f32]) -> bool {
79 let inv_scale = 1.0 / self.scale;
80 let mut has_overflow = false;
81
82 for grad in grads.iter_mut() {
83 *grad *= inv_scale;
84 if !grad.is_finite() {
85 has_overflow = true;
86 }
87 }
88
89 !has_overflow
90 }
91
92 pub fn update(&mut self, grads_valid: bool) {
96 contract_pre_update!();
97 if !self.dynamic {
98 return;
99 }
100
101 if grads_valid {
102 self.successful_steps += 1;
103 self.steps_since_growth += 1;
104
105 if self.steps_since_growth >= self.growth_interval {
107 self.scale *= self.growth_factor;
108 self.steps_since_growth = 0;
109 }
110 } else {
111 self.overflow_count += 1;
113 self.scale *= self.backoff_factor;
114 self.steps_since_growth = 0;
115
116 self.scale = self.scale.max(1.0);
118 }
119 }
120
121 pub fn overflow_count(&self) -> usize {
123 self.overflow_count
124 }
125
126 pub fn successful_steps(&self) -> usize {
128 self.successful_steps
129 }
130
131 pub fn is_dynamic(&self) -> bool {
133 self.dynamic
134 }
135
136 pub fn set_dynamic(&mut self, enabled: bool) {
138 self.dynamic = enabled;
139 }
140}
141
142impl Default for GradScaler {
143 fn default() -> Self {
144 Self::new(65536.0)
145 }
146}