use std::ffi::c_void;
use std::io::{Read, Write};
use std::ptr;
use flodl_sys as ffi;
use crate::autograd::Variable;
use crate::tensor::{DType, Result, Tensor};
use super::checkpoint::{write_f64_le, read_f64_le, write_i64_le, read_i64_le};
use super::optim::Stateful;
use super::parameter::Parameter;
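/// RAII guard that enables autocast (automatic mixed precision) while it is
/// alive; the previous autocast state is restored when the guard is dropped.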
pub struct AutocastGuard {
guard: *mut c_void,
}
impl AutocastGuard {
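    /// Enables autocast on the CUDA device with the given target dtype.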
pub fn new(dtype: DType) -> Self {
let guard = unsafe {
ffi::flodl_autocast_guard_new(ffi::FLODL_CUDA, dtype as i32)
};
AutocastGuard { guard }
}
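    /// Enables autocast on an explicitly specified device type (passed through
    /// to the flodl FFI as a raw device-type id) with the given target dtype.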
pub fn for_device(device_type: i32, dtype: DType) -> Self {
let guard = unsafe {
ffi::flodl_autocast_guard_new(device_type, dtype as i32)
};
AutocastGuard { guard }
}
}
impl Drop for AutocastGuard {
fn drop(&mut self) {
if !self.guard.is_null() {
unsafe { ffi::flodl_autocast_guard_delete(self.guard) };
self.guard = ptr::null_mut();
}
}
}
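/// Runs the closure `f` with autocast enabled for `dtype` on the CUDA device
/// and returns its result; the guard is dropped when the closure finishes, so
/// the previous autocast state is restored even on unwind.
///
/// Illustrative sketch (assumes a `model` with a `forward` method and an
/// `input` variable supplied by the caller):
///
/// ```ignore
/// let output = autocast(DType::Float16, || model.forward(&input));
/// ```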
pub fn autocast<F, R>(dtype: DType, f: F) -> R
where
F: FnOnce() -> R,
{
let _guard = AutocastGuard::new(dtype);
f()
}
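/// Returns whether autocast is currently enabled for the CUDA device.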
pub fn is_autocast_enabled() -> bool {
unsafe { ffi::flodl_is_autocast_enabled(ffi::FLODL_CUDA) != 0 }
}
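/// Casts the data tensor of every parameter to `dtype` in place. Parameters
/// already in the target dtype are skipped, and parameters whose cast fails
/// are left unchanged.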
pub fn cast_parameters(params: &[Parameter], dtype: DType) {
for p in params {
if p.variable.data().dtype() != dtype
&& let Ok(t) = p.variable.data().to_dtype(dtype)
{
p.variable.set_data(t);
}
}
}
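/// Dynamic loss scaler for mixed-precision training. The loss is multiplied
/// by a scale factor before backpropagation so that small gradients survive in
/// reduced precision; gradients are unscaled before the optimizer step, steps
/// whose gradients contain inf/NaN are skipped, and the scale factor is
/// adjusted over time via [`GradScaler::update`].
///
/// Illustrative training-loop sketch (assumes an optimizer `opt` whose
/// `step()` returns `Result<()>`, a `loss` variable produced under
/// [`autocast`], and a `backward()` method on `Variable`):
///
/// ```ignore
/// let mut scaler = GradScaler::new();
/// let scaled_loss = scaler.scale(&loss)?;
/// scaled_loss.backward()?;
/// scaler.step(&params, &mut || opt.step())?;
/// scaler.update();
/// ```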
pub struct GradScaler {
scale: f64,
growth: f64,
backoff: f64,
interval: i64,
steps_since_growth: i64,
found_inf: bool,
}
impl Default for GradScaler {
fn default() -> Self {
GradScaler {
scale: 65536.0,
growth: 2.0,
backoff: 0.5,
interval: 2000,
steps_since_growth: 0,
found_inf: false,
}
}
}
impl GradScaler {
pub fn new() -> Self {
Self::default()
}
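    /// Multiplies the loss by the current scale factor; call this before
    /// backpropagation so gradients are computed at the scaled magnitude.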
pub fn scale(&self, loss: &Variable) -> Result<Variable> {
loss.mul_scalar(self.scale)
}
pub fn scale_factor(&self) -> f64 {
self.scale
}
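    /// Unscales every parameter gradient by `1 / scale` and checks it for
    /// inf/NaN. If any gradient is non-finite, the step is skipped: `step_fn`
    /// is not called, no parameter gradients are modified, and `Ok(false)` is
    /// returned. Otherwise the unscaled gradients are written back, `step_fn`
    /// (typically the optimizer step) runs, and `Ok(true)` is returned. Call
    /// [`GradScaler::update`] afterwards in either case.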
    pub fn step(
        &mut self,
        params: &[Parameter],
        step_fn: &mut dyn FnMut() -> Result<()>,
    ) -> Result<bool> {
let inv_scale = 1.0 / self.scale;
let mut unscaled_grads: Vec<Option<Tensor>> = Vec::with_capacity(params.len());
for p in params {
if let Some(grad) = p.variable.grad() {
let unscaled = grad.mul_scalar(inv_scale)?;
if !unscaled.all_finite()? {
self.found_inf = true;
return Ok(false);
}
unscaled_grads.push(Some(unscaled));
} else {
unscaled_grads.push(None);
}
}
for (p, ug) in params.iter().zip(unscaled_grads) {
if let Some(g) = ug {
p.variable.set_grad(g);
}
}
step_fn()?;
Ok(true)
}
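    /// Adjusts the scale factor after a step: backs off (multiplies by
    /// `backoff`) if the last `step` found a non-finite gradient, otherwise
    /// grows the scale by `growth` once every `interval` successful steps.
    /// The inf/NaN flag is cleared afterwards.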
pub fn update(&mut self) {
if self.found_inf {
self.scale *= self.backoff;
self.steps_since_growth = 0;
} else {
self.steps_since_growth += 1;
if self.steps_since_growth >= self.interval {
self.scale *= self.growth;
self.steps_since_growth = 0;
}
}
self.found_inf = false;
}
}
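// Only the dynamic state (current scale and growth counter) is persisted;
// growth, backoff and interval are configuration and come from the constructor.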
impl Stateful for GradScaler {
fn save_state<W: Write>(&self, w: &mut W) -> Result<()> {
write_f64_le(w, self.scale)?;
write_i64_le(w, self.steps_since_growth)?;
Ok(())
}
fn load_state<R: Read>(&mut self, r: &mut R) -> Result<()> {
self.scale = read_f64_le(r)?;
self.steps_since_growth = read_i64_le(r)?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tensor::{Tensor, test_device, test_opts};
#[test]
fn test_autocast_guard_lifecycle() {
assert!(!is_autocast_enabled());
{
let _guard = AutocastGuard::new(DType::Float16);
assert!(is_autocast_enabled());
}
assert!(!is_autocast_enabled());
}
#[test]
fn test_autocast_closure() {
assert!(!is_autocast_enabled());
let was_enabled = autocast(DType::Float16, || {
is_autocast_enabled()
});
assert!(was_enabled);
assert!(!is_autocast_enabled());
}
#[test]
fn test_cast_parameters() {
let t = Tensor::ones(&[4], test_opts()).unwrap();
let p = Parameter {
variable: Variable::new(t, true),
name: "w".into(),
};
assert_eq!(p.variable.data().dtype(), DType::Float32);
cast_parameters(&[p.clone()], DType::Float64);
assert_eq!(p.variable.data().dtype(), DType::Float64);
}
#[test]
fn test_cast_parameters_noop_same_dtype() {
let t = Tensor::ones(&[4], test_opts()).unwrap();
let p = Parameter {
variable: Variable::new(t, true),
name: "w".into(),
};
cast_parameters(&[p.clone()], DType::Float32);
assert_eq!(p.variable.data().dtype(), DType::Float32);
}
#[test]
fn test_grad_scaler_defaults() {
let scaler = GradScaler::new();
assert_eq!(scaler.scale_factor(), 65536.0);
}
#[test]
fn test_grad_scaler_scale() {
let scaler = GradScaler::new();
let loss = Variable::new(
Tensor::from_f32(&[1.0], &[1], test_device()).unwrap(), true,
);
let scaled = scaler.scale(&loss).unwrap();
assert!((scaled.item().unwrap() - 65536.0).abs() < 1.0);
}
#[test]
fn test_grad_scaler_step_finite() {
let mut scaler = GradScaler::new();
let t = Tensor::from_f32(&[1.0, 2.0], &[2], test_device()).unwrap();
let p = Parameter {
variable: Variable::new(t, true),
name: "w".into(),
};
let grad = Tensor::from_f32(&[0.1, 0.2], &[2], test_device()).unwrap();
p.variable.set_grad(grad);
let mut stepped = false;
let ok = scaler.step(&[p.clone()], &mut || { stepped = true; Ok(()) }).unwrap();
assert!(ok);
assert!(stepped);
}
#[test]
fn test_grad_scaler_step_inf() {
let mut scaler = GradScaler::new();
let t = Tensor::from_f32(&[1.0], &[1], test_device()).unwrap();
let p = Parameter {
variable: Variable::new(t, true),
name: "w".into(),
};
let grad = Tensor::from_f32(&[f32::INFINITY], &[1], test_device()).unwrap();
p.variable.set_grad(grad);
let mut stepped = false;
let ok = scaler.step(&[p], &mut || { stepped = true; Ok(()) }).unwrap();
assert!(!ok, "step should be skipped on inf");
assert!(!stepped, "optimizer should not have stepped");
}
#[test]
fn test_grad_scaler_update_growth() {
let mut scaler = GradScaler {
scale: 100.0,
growth: 2.0,
backoff: 0.5,
interval: 3,
steps_since_growth: 2,
found_inf: false,
};
        scaler.update();
        assert_eq!(scaler.scale_factor(), 200.0);
}
#[test]
fn test_grad_scaler_update_backoff() {
let mut scaler = GradScaler::new();
let initial = scaler.scale_factor();
scaler.found_inf = true;
scaler.update();
assert_eq!(scaler.scale_factor(), initial * 0.5);
}
#[test]
fn test_grad_scaler_state_roundtrip() {
let mut scaler = GradScaler::new();
scaler.found_inf = true;
        scaler.update();
        scaler.update();
let mut buf = Vec::new();
scaler.save_state(&mut buf).unwrap();
let mut loaded = GradScaler::new();
loaded.load_state(&mut &buf[..]).unwrap();
assert_eq!(loaded.scale_factor(), scaler.scale_factor());
}
}