from typing import Iterable, List, Union
import numpy as np
from ..autodiff import GradManager
from ..functional import full_like
from ..functional.math import _check_non_finite
from ..tensor import Tensor
class GradScaler:
    """Dynamic loss/gradient scaler for mixed-precision training.

    The scaler multiplies the output gradients fed into the backward pass by
    ``scale_factor`` and later divides the resulting parameter gradients by
    the same factor.  The factor is adapted over time by :meth:`update`.

    Args:
        init_scale: initial value of ``scale_factor``.
        growth_factor: ``scale_factor`` is multiplied by this value once
            ``growth_interval`` consecutive :meth:`update` calls have seen no
            non-finite gradient.
        backoff_factor: ``scale_factor`` is multiplied by this value when an
            :meth:`update` follows an :meth:`unscale` that reported
            non-finite gradients.
        growth_interval: number of consecutive clean updates required before
            growing the scale.  ``0`` disables dynamic scaling: :meth:`update`
            becomes a no-op and :meth:`unscale` skips the non-finite check.
    """

    def __init__(
        self,
        init_scale: float = 2.0 ** 4,
        growth_factor: float = 2.0,
        backoff_factor: float = 0.5,
        growth_interval: int = 2000,
    ):
        self.scale_factor = float(init_scale)
        self.growth_factor = float(growth_factor)
        self.backoff_factor = float(backoff_factor)
        self.growth_interval = growth_interval
        # Number of consecutive update() calls without non-finite gradients.
        self._growth_tracker = 0
        # Set by unscale(), consumed and cleared by update().
        self._found_non_finite = False

    def backward(
        self,
        gm: GradManager,
        y: Union[Tensor, List[Tensor]] = None,
        dy: Union[Tensor, List[Tensor]] = None,
        *,
        unscale_grad: bool = True,
        update_scale: Union[bool, str] = "if_unscale_grad"
    ):
        """Run ``gm.backward`` with the output gradients scaled.

        Args:
            gm: manager whose ``backward(y=..., dy=...)`` is invoked and whose
                ``attached_tensors()`` are unscaled afterwards.
            y: a tensor, a list/tuple of tensors, or ``None`` (treated as an
                empty list).
            dy: output gradient(s) for ``y``.  Each one is multiplied by
                ``scale_factor``; when ``None``, a tensor shaped like each
                ``y`` and filled with ``scale_factor`` is used instead.
            unscale_grad: call :meth:`unscale` on ``gm.attached_tensors()``
                after the backward pass.  Pass ``False`` to leave gradients
                scaled (e.g. while accumulating).
            update_scale: call :meth:`update` afterwards.  The default
                sentinel ``"if_unscale_grad"`` means "same as
                ``unscale_grad``"; an explicit boolean is honoured as given.
        """
        # Resolve the sentinel.  Left as-is, the non-empty string is truthy
        # and would trigger update() even when gradients were not unscaled.
        if isinstance(update_scale, str) and update_scale == "if_unscale_grad":
            update_scale = unscale_grad

        if y is None:
            ys = []
        elif isinstance(y, (tuple, list)):
            ys = y
        else:
            ys = [y]

        if dy is None:
            dys = [full_like(y_, self.scale_factor) for y_ in ys]
        elif isinstance(dy, (tuple, list)):
            dys = [dy_ * self.scale_factor for dy_ in dy]
        else:
            dys = [dy * self.scale_factor]

        gm.backward(y=ys, dy=dys)

        if unscale_grad:
            self.unscale(gm.attached_tensors())
        if update_scale:
            self.update()

    def unscale(self, grad_tensors: Iterable[Tensor]):
        """Undo the scaling on the ``grad`` of every tensor in ``grad_tensors``.

        Entries that are ``None`` or that carry no gradient are skipped.

        With ``growth_interval == 0`` each gradient is multiplied in place by
        ``1 / scale_factor``.  Otherwise the gradients are handed to
        :meth:`_check_gradients` together with the inverse scale; if that
        reports non-finite values, the flag consumed by :meth:`update` is set
        and every gradient is reset to ``None``.

        Returns:
            ``self``, to allow chaining.
        """
        # Materialise once: the argument may be a one-shot iterator and is
        # needed for both the check and the clearing pass.  The same filter
        # is applied on every path so ``None`` grads never reach the checker.
        with_grad = [
            tensor
            for tensor in grad_tensors
            if tensor is not None and getattr(tensor, "grad", None) is not None
        ]
        if not with_grad:
            # Nothing to unscale or check.
            return self

        inv_scale = 1.0 / self.scale_factor

        if self.growth_interval == 0:
            inv_scale_tensor = Tensor(inv_scale)
            for tensor in with_grad:
                tensor.grad *= inv_scale_tensor
            return self

        # NOTE(review): this branch never rescales the grads itself; the
        # unscaling is presumably done in place by _check_non_finite, which
        # receives the inverse scale -- confirm against its implementation.
        if self._check_gradients([tensor.grad for tensor in with_grad], inv_scale):
            self._found_non_finite = True
            for tensor in with_grad:
                tensor.grad = None
        return self

    def _check_gradients(self, grad, scale):
        """Delegate to ``_check_non_finite(grad, scale)`` and return its result.

        The result is used as a truth value by :meth:`unscale`.
        """
        return _check_non_finite(grad, scale)

    def update(self, new_scale: float = None):
        """Adapt ``scale_factor`` after a step.

        Does nothing when ``growth_interval == 0`` (even if ``new_scale`` is
        given).  Otherwise:

        * if ``new_scale`` is provided, it replaces ``scale_factor``;
        * else, if the last :meth:`unscale` found non-finite gradients, the
          scale is multiplied by ``backoff_factor`` and the growth counter is
          reset;
        * else the growth counter is incremented and, once it reaches
          ``growth_interval``, the scale is multiplied by ``growth_factor``
          and the counter is reset.

        The non-finite flag is cleared in every non-disabled case.
        """
        if self.growth_interval == 0:
            return

        if new_scale is not None:
            self.scale_factor = float(new_scale)
        elif self._found_non_finite:
            self.scale_factor *= self.backoff_factor
            self._growth_tracker = 0
        else:
            self._growth_tracker += 1
            if self._growth_tracker >= self.growth_interval:
                self.scale_factor *= self.growth_factor
                self._growth_tracker = 0

        self._found_non_finite = False

    def state_dict(self):
        """Return the scaler's configuration and growth counter as a dict.

        The transient non-finite flag is not included.
        """
        return {
            "scale_factor": self.scale_factor,
            "growth_factor": self.growth_factor,
            "backoff_factor": self.backoff_factor,
            "growth_interval": self.growth_interval,
            "_growth_tracker": self._growth_tracker,
        }

    def load_state_dict(self, state):
        """Restore the fields produced by :meth:`state_dict`.

        Raises:
            KeyError: if ``state`` lacks any of the expected keys.
        """
        self.scale_factor = state["scale_factor"]
        self.growth_factor = state["growth_factor"]
        self.backoff_factor = state["backoff_factor"]
        self.growth_interval = state["growth_interval"]
        self._growth_tracker = state["_growth_tracker"]