//! Mixed-precision training utilities: autocast scoping and dynamic
//! gradient (loss) scaling.

use std::ffi::c_void;
use std::io::{Read, Write};
use std::ptr;
use flodl_sys as ffi;
use crate::autograd::Variable;
use crate::tensor::{DType, Result, Tensor};
use super::checkpoint::{write_f64_le, read_f64_le, write_i64_le, read_i64_le};
use super::optim::Stateful;
use super::parameter::Parameter;
/// RAII guard that enables autocast (automatic mixed precision) for as long
/// as it is alive; the previous autocast state is restored on drop.
pub struct AutocastGuard {
    guard: *mut c_void,
}
impl AutocastGuard {
    /// Enables autocast for CUDA with the given compute dtype.
    pub fn new(dtype: DType) -> Self {
        let guard = unsafe { ffi::flodl_autocast_guard_new(ffi::FLODL_CUDA, dtype as i32) };
        AutocastGuard { guard }
    }

    /// Enables autocast for an arbitrary device type with the given compute
    /// dtype, for backends other than the CUDA default.
    pub fn for_device(device_type: i32, dtype: DType) -> Self {
        let guard = unsafe { ffi::flodl_autocast_guard_new(device_type, dtype as i32) };
        AutocastGuard { guard }
    }
}
impl Drop for AutocastGuard {
    fn drop(&mut self) {
        if !self.guard.is_null() {
            // Restore the previous autocast state and free the FFI handle.
            unsafe { ffi::flodl_autocast_guard_delete(self.guard) };
            self.guard = ptr::null_mut();
        }
    }
}
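
// A minimal scoping sketch for the raw guard. `model.forward`, `batch`, and
// the `DType::F16` variant are illustrative assumptions about the
// surrounding crate, not confirmed by this module:
//
//     {
//         let _amp = AutocastGuard::new(DType::F16);
//         let logits = model.forward(&batch)?; // runs under autocast
//     } // guard dropped here; previous autocast state restored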
/// Runs `f` with autocast enabled for CUDA, restoring the previous autocast
/// state when `f` returns (or unwinds, since the guard is dropped either way).
pub fn autocast<F, R>(dtype: DType, f: F) -> R
where
    F: FnOnce() -> R,
{
    let _guard = AutocastGuard::new(dtype);
    f()
}
/// Returns `true` if autocast is currently enabled for CUDA.
pub fn is_autocast_enabled() -> bool {
    unsafe { ffi::flodl_is_autocast_enabled(ffi::FLODL_CUDA) != 0 }
}
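
// A usage sketch for the closure form, assuming a `DType::F16` variant and a
// `matmul` op on `Variable` (both illustrative, not confirmed API):
//
//     let out = autocast(DType::F16, || {
//         debug_assert!(is_autocast_enabled());
//         a.matmul(&b) // eligible ops run in reduced precision
//     })?;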
/// Casts every parameter whose tensor is not already of `dtype` in place.
/// Parameters that fail to convert are silently left unchanged.
pub fn cast_parameters(params: &[Parameter], dtype: DType) {
    for p in params {
        if p.variable.data().dtype() != dtype {
            if let Ok(t) = p.variable.data().to_dtype(dtype) {
                p.variable.set_data(t);
            }
        }
    }
}
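
// A sketch of casting a model's parameters before pure low-precision
// training (`model.parameters()` is an assumed accessor, not confirmed by
// this module; note this converts weights in place rather than keeping
// f32 master copies):
//
//     let params = model.parameters();
//     cast_parameters(&params, DType::F16);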
/// Dynamic loss scaler for mixed-precision training.
///
/// Losses are multiplied by `scale` before backward so that small gradients
/// survive reduced precision; gradients are unscaled again before the
/// optimizer step. The scale grows by `growth` after `interval` consecutive
/// finite steps and shrinks by `backoff` whenever a non-finite gradient is
/// found.
pub struct GradScaler {
    /// Current loss-scale factor.
    scale: f64,
    /// Multiplier applied after `interval` consecutive finite steps.
    growth: f64,
    /// Multiplier applied when a non-finite gradient is detected.
    backoff: f64,
    /// Number of consecutive finite steps required before growing.
    interval: i64,
    steps_since_growth: i64,
    found_inf: bool,
}
impl Default for GradScaler {
    fn default() -> Self {
        // Defaults mirror the common AMP choices: start at 2^16, halve on
        // overflow, and double after 2000 consecutive finite steps.
        GradScaler {
            scale: 65536.0,
            growth: 2.0,
            backoff: 0.5,
            interval: 2000,
            steps_since_growth: 0,
            found_inf: false,
        }
    }
}
impl GradScaler {
    pub fn new() -> Self {
        Self::default()
    }

    /// Multiplies `loss` by the current scale factor; call this before
    /// running backward.
    pub fn scale(&self, loss: &Variable) -> Result<Variable> {
        loss.mul_scalar(self.scale)
    }

    /// Returns the current scale factor.
    pub fn scale_factor(&self) -> f64 {
        self.scale
    }
    /// Unscales all gradients and, if every gradient is finite, runs
    /// `step_fn` (typically the optimizer step). Returns `Ok(false)` and
    /// skips `step_fn` if any gradient contains an inf/NaN; the overflow is
    /// recorded so the next call to [`update`](Self::update) backs off the
    /// scale.
    pub fn step(&mut self, params: &[Parameter], step_fn: &mut dyn FnMut() -> Result<()>) -> Result<bool> {
        let inv_scale = 1.0 / self.scale;
        // Unscale into a scratch buffer first so that parameter gradients
        // are only overwritten once all of them are known to be finite.
        let mut unscaled_grads: Vec<Option<Tensor>> = Vec::with_capacity(params.len());
        for p in params {
            if let Some(grad) = p.variable.grad() {
                let unscaled = grad.mul_scalar(inv_scale)?;
                if !unscaled.all_finite()? {
                    self.found_inf = true;
                    return Ok(false);
                }
                unscaled_grads.push(Some(unscaled));
            } else {
                unscaled_grads.push(None);
            }
        }
        for (p, ug) in params.iter().zip(unscaled_grads) {
            if let Some(g) = ug {
                p.variable.set_grad(g);
            }
        }
        step_fn()?;
        Ok(true)
    }
    /// Adjusts the scale after a step: backs off on overflow, grows after
    /// `interval` consecutive finite steps, and clears the overflow flag.
    /// Call once per iteration, after [`step`](Self::step).
    pub fn update(&mut self) {
        if self.found_inf {
            self.scale *= self.backoff;
            self.steps_since_growth = 0;
        } else {
            self.steps_since_growth += 1;
            if self.steps_since_growth >= self.interval {
                self.scale *= self.growth;
                self.steps_since_growth = 0;
            }
        }
        self.found_inf = false;
    }
}
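
// A typical training-step sketch. `model`, `batch`, `optimizer`, and
// `backward()` are assumptions about the surrounding crate, not confirmed
// by this module:
//
//     let mut scaler = GradScaler::new();
//     let loss = model.forward(&batch)?;
//     let scaled = scaler.scale(&loss)?;  // scale before backward
//     scaled.backward()?;                 // gradients are now scaled
//     let stepped = scaler.step(&params, &mut || optimizer.step())?;
//     scaler.update();                    // grow or back off the scale
//     if !stepped {
//         // Overflow: this batch was skipped and the scale backed off.
//     }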
impl Stateful for GradScaler {
    // Only the mutable state (current scale and growth counter) is
    // serialized; `growth`, `backoff`, and `interval` are configuration
    // and come from the constructor.
    fn save_state<W: Write>(&self, w: &mut W) -> Result<()> {
        write_f64_le(w, self.scale)?;
        write_i64_le(w, self.steps_since_growth)?;
        Ok(())
    }

    fn load_state<R: Read>(&mut self, r: &mut R) -> Result<()> {
        self.scale = read_f64_le(r)?;
        self.steps_since_growth = read_i64_le(r)?;
        Ok(())
    }
}
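
// A minimal save/restore roundtrip sketch using an in-memory buffer
// (std-only: `Vec<u8>` implements `Write`, and `Cursor` provides `Read`
// over the written bytes):
//
//     use std::io::Cursor;
//
//     let mut buf: Vec<u8> = Vec::new();
//     scaler.save_state(&mut buf)?;
//     let mut restored = GradScaler::new();
//     restored.load_state(&mut Cursor::new(buf))?;
//     assert_eq!(restored.scale_factor(), scaler.scale_factor());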