use crate::{Result, TensorError};
/// Numeric precision targets for simulated automatic mixed-precision casting.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AutoCastDtype {
    /// Full IEEE binary32 precision — casting is the identity.
    Float32,
    /// IEEE binary16 — small dynamic range (max finite value 65504).
    Float16,
    /// bfloat16 — f32 exponent range with a truncated mantissa.
    BFloat16,
}
/// Lightweight autocast context that simulates mixed-precision casts on CPU
/// by round-tripping f32 values through the target representation.
#[derive(Debug, Clone)]
pub struct AutoCast {
    /// Precision used when casting is enabled.
    pub dtype: AutoCastDtype,
    /// When false, all cast operations are identity pass-throughs.
    pub enabled: bool,
}
impl AutoCast {
    /// Builds an enabled autocast context targeting `dtype`.
    pub fn new(dtype: AutoCastDtype) -> Self {
        Self {
            dtype,
            enabled: true,
        }
    }

    /// Enabled context with the default FP16 target.
    pub fn enabled() -> Self {
        Self::new(AutoCastDtype::Float16)
    }

    /// Disabled context: casting is the identity and the dtype is FP32.
    pub fn disabled() -> Self {
        Self {
            dtype: AutoCastDtype::Float32,
            enabled: false,
        }
    }

    /// Quantises every element of `data` to the active precision,
    /// returning a copy of the input when the context is disabled.
    pub fn cast_input(&self, data: &[f32]) -> Vec<f32> {
        // Element-wise delegation keeps the scalar and slice paths in sync.
        data.iter().map(|&x| self.cast_scalar(x)).collect()
    }

    /// Quantises a single value to the active precision
    /// (identity when disabled or when the target is FP32).
    pub fn cast_scalar(&self, x: f32) -> f32 {
        match (self.enabled, self.dtype) {
            (false, _) | (true, AutoCastDtype::Float32) => x,
            (true, AutoCastDtype::Float16) => simulate_f16(x),
            (true, AutoCastDtype::BFloat16) => simulate_bf16(x),
        }
    }

    /// True when `x` cannot be held as a finite FP16 value: non-finite
    /// input, or a magnitude beyond the FP16 maximum of 65504.
    #[inline]
    pub fn would_overflow_fp16(x: f32) -> bool {
        const F16_MAX: f32 = 65504.0;
        // Negated conjunction: NaN fails `abs() <= F16_MAX`, so it reports
        // overflow, exactly like the `!is_finite() || abs() > max` form.
        !(x.is_finite() && x.abs() <= F16_MAX)
    }
}
pub fn simulate_f16(x: f32) -> f32 {
use half::f16;
f16::from_f32(x).to_f32()
}
pub fn simulate_bf16(x: f32) -> f32 {
use half::bf16;
bf16::from_f32(x).to_f32()
}
/// Serializable snapshot of a `GradScaler`'s dynamic-scaling state,
/// produced by `state_dict` and consumed by `load_state_dict`.
#[derive(Debug, Clone, PartialEq)]
pub struct ScalerState {
    /// Current loss-scale factor.
    pub scale: f32,
    /// Multiplier applied after `growth_interval` clean steps.
    pub growth_factor: f32,
    /// Multiplier applied when an overflow is detected.
    pub backoff_factor: f32,
    /// Number of consecutive clean steps required before growing the scale.
    pub growth_interval: u32,
    /// Clean-step counter since the last overflow (or last growth).
    pub steps_since_overflow: u32,
}
/// Dynamic loss scaler for mixed-precision training: scales the loss up to
/// keep small gradients representable, then backs the scale off whenever
/// non-finite gradients are observed.
#[derive(Debug, Clone)]
pub struct GradScaler {
    /// Current loss-scale factor.
    pub scale: f32,
    /// Multiplier applied after `growth_interval` consecutive clean steps.
    pub growth_factor: f32,
    /// Multiplier applied to the scale on overflow.
    pub backoff_factor: f32,
    /// Clean steps required before the scale is grown.
    pub growth_interval: u32,
    /// When false, all scaling operations are identity pass-throughs.
    pub enabled: bool,
    // Private: reset to 0 on overflow and after each growth event.
    steps_since_overflow: u32,
    /// Total overflows observed across the scaler's lifetime.
    pub overflow_count: u64,
    /// Total `step_update` calls (including skipped steps).
    pub step_count: u64,
}
impl GradScaler {
    /// Creates an enabled scaler with conventional defaults:
    /// growth factor 2.0, backoff factor 0.5, growth interval 2000 steps.
    pub fn new(init_scale: f32) -> Self {
        Self {
            scale: init_scale,
            growth_factor: 2.0,
            backoff_factor: 0.5,
            growth_interval: 2000,
            enabled: true,
            steps_since_overflow: 0,
            overflow_count: 0,
            step_count: 0,
        }
    }

    /// Creates an enabled scaler with explicit dynamic-scaling parameters.
    ///
    /// # Errors
    ///
    /// Returns `TensorError::InvalidArgument` when:
    /// - `init_scale` is not a positive finite value,
    /// - `growth_factor` is not a finite value > 1.0,
    /// - `backoff_factor` is not strictly inside (0.0, 1.0), or
    /// - `interval` is 0.
    pub fn with_config(
        init_scale: f32,
        growth_factor: f32,
        backoff_factor: f32,
        interval: u32,
    ) -> Result<Self> {
        if !init_scale.is_finite() || init_scale <= 0.0 {
            return Err(TensorError::InvalidArgument {
                operation: "GradScaler::with_config".to_string(),
                reason: format!(
                    "init_scale must be a positive finite value, got {}",
                    init_scale
                ),
                context: None,
            });
        }
        // BUG FIX: the previous check was `growth_factor <= 1.0`, which is
        // false for NaN (every IEEE 754 comparison with NaN is false), so a
        // NaN growth factor was silently accepted. The negated positive form
        // rejects NaN, and the is_finite() clause also rejects +Inf.
        if !(growth_factor.is_finite() && growth_factor > 1.0) {
            return Err(TensorError::InvalidArgument {
                operation: "GradScaler::with_config".to_string(),
                reason: format!(
                    "growth_factor must be a finite value > 1.0, got {}",
                    growth_factor
                ),
                context: None,
            });
        }
        // Same NaN-rejecting form: NaN fails both comparisons, so the negated
        // conjunction turns it into an error instead of an accept. Non-finite
        // values are excluded automatically by the open-interval bounds.
        if !(backoff_factor > 0.0 && backoff_factor < 1.0) {
            return Err(TensorError::InvalidArgument {
                operation: "GradScaler::with_config".to_string(),
                reason: format!(
                    "backoff_factor must be in (0.0, 1.0), got {}",
                    backoff_factor
                ),
                context: None,
            });
        }
        if interval == 0 {
            return Err(TensorError::InvalidArgument {
                operation: "GradScaler::with_config".to_string(),
                reason: "growth_interval must be >= 1".to_string(),
                context: None,
            });
        }
        Ok(Self {
            scale: init_scale,
            growth_factor,
            backoff_factor,
            growth_interval: interval,
            enabled: true,
            steps_since_overflow: 0,
            overflow_count: 0,
            step_count: 0,
        })
    }

    /// Disabled scaler: scaling/unscaling are identity and every step proceeds.
    pub fn disabled() -> Self {
        Self {
            scale: 1.0,
            growth_factor: 2.0,
            backoff_factor: 0.5,
            growth_interval: 2000,
            enabled: false,
            steps_since_overflow: 0,
            overflow_count: 0,
            step_count: 0,
        }
    }

    /// Multiplies the loss by the current scale (identity when disabled).
    #[inline]
    pub fn scale_loss(&self, loss: f32) -> f32 {
        if self.enabled {
            loss * self.scale
        } else {
            loss
        }
    }

    /// Scales gradients in place by the current scale factor.
    pub fn scale_gradients(&self, grads: &mut [f32]) {
        if !self.enabled {
            return;
        }
        let s = self.scale;
        for g in grads.iter_mut() {
            *g *= s;
        }
    }

    /// Divides gradients in place by the current scale factor.
    ///
    /// A zero or non-finite scale cannot be inverted meaningfully, so every
    /// gradient is poisoned with NaN; overflow detection downstream will then
    /// cause the optimizer step to be skipped.
    pub fn unscale_gradients(&self, grads: &mut [f32]) {
        if !self.enabled {
            return;
        }
        if self.scale == 0.0 || !self.scale.is_finite() {
            for g in grads.iter_mut() {
                *g = f32::NAN;
            }
            return;
        }
        let inv = 1.0 / self.scale;
        for g in grads.iter_mut() {
            *g *= inv;
        }
    }

    /// True when any gradient is NaN or ±Inf.
    #[inline]
    pub fn check_overflow(&self, grads: &[f32]) -> bool {
        grads.iter().any(|&g| !g.is_finite())
    }

    /// Inspects gradients and updates the dynamic scale for this step.
    ///
    /// Returns `true` when the optimizer step should proceed. On overflow the
    /// scale is backed off (floored at 1.0) and `false` is returned; after
    /// `growth_interval` consecutive clean steps the scale is grown (capped
    /// at `f32::MAX / 2.0` so it stays finite).
    pub fn step_update(&mut self, grads: &[f32]) -> bool {
        self.step_count += 1;
        if !self.enabled {
            return true;
        }
        if self.check_overflow(grads) {
            self.overflow_count += 1;
            self.steps_since_overflow = 0;
            self.scale = (self.scale * self.backoff_factor).max(1.0_f32);
            return false;
        }
        self.steps_since_overflow += 1;
        if self.steps_since_overflow >= self.growth_interval {
            self.scale = (self.scale * self.growth_factor).min(f32::MAX / 2.0);
            self.steps_since_overflow = 0;
        }
        true
    }

    /// Current loss-scale factor.
    #[inline]
    pub fn get_scale(&self) -> f32 {
        self.scale
    }

    /// Snapshots the dynamic-scaling state for checkpointing.
    pub fn state_dict(&self) -> ScalerState {
        ScalerState {
            scale: self.scale,
            growth_factor: self.growth_factor,
            backoff_factor: self.backoff_factor,
            growth_interval: self.growth_interval,
            steps_since_overflow: self.steps_since_overflow,
        }
    }

    /// Restores dynamic-scaling state from a checkpoint snapshot.
    ///
    /// NOTE(review): the state is applied without re-validation, so a
    /// hand-crafted `ScalerState` can bypass `with_config`'s checks.
    pub fn load_state_dict(&mut self, state: ScalerState) {
        self.scale = state.scale;
        self.growth_factor = state.growth_factor;
        self.backoff_factor = state.backoff_factor;
        self.growth_interval = state.growth_interval;
        self.steps_since_overflow = state.steps_since_overflow;
    }
}
/// Quantises each element of `data` through FP16 and back to f32,
/// preserving length and order.
pub fn f32_to_f16_roundtrip(data: &[f32]) -> Vec<f32> {
    let mut quantised = Vec::with_capacity(data.len());
    for &value in data {
        quantised.push(simulate_f16(value));
    }
    quantised
}
/// Mean relative error introduced by an FP16 round-trip over `data`.
///
/// Each term is |x - q| / (|x| + EPSILON), with EPSILON guarding the
/// division when x is zero. Returns 0.0 for an empty slice.
pub fn f16_precision_error(data: &[f32]) -> f32 {
    if data.is_empty() {
        return 0.0;
    }
    let mut total = 0.0_f32;
    for &x in data {
        let quantised = simulate_f16(x);
        total += (x - quantised).abs() / (x.abs() + f32::EPSILON);
    }
    total / data.len() as f32
}
/// Returns true when every gradient is finite (no NaN or ±Inf).
/// An empty slice is vacuously finite.
#[inline]
pub fn grads_are_finite(grads: &[f32]) -> bool {
    !grads.iter().any(|g| !g.is_finite())
}
/// Clips `grads` in place to an L2 norm of at most `max_norm`.
///
/// Returns the pre-clip L2 norm. An empty slice or a non-positive
/// `max_norm` is a no-op that returns 0.0.
pub fn clip_grad_norm(grads: &mut [f32], max_norm: f32) -> f32 {
    if max_norm <= 0.0 || grads.is_empty() {
        return 0.0;
    }
    let norm = grads.iter().fold(0.0_f32, |acc, &g| acc + g * g).sqrt();
    if norm > max_norm {
        let factor = max_norm / norm;
        grads.iter_mut().for_each(|g| *g *= factor);
    }
    norm
}
// Unit tests for the mixed-precision helpers: FP16/BF16 simulation,
// overflow prediction, the dynamic GradScaler, and gradient utilities.
#[cfg(test)]
mod tests {
    use super::*;

    // --- FP16 / BF16 simulation ------------------------------------------

    #[test]
    fn test_simulate_f16_exact_representable_values() {
        assert_eq!(simulate_f16(1.0_f32), 1.0_f32);
        assert_eq!(simulate_f16(0.0_f32), 0.0_f32);
        assert_eq!(simulate_f16(-1.0_f32), -1.0_f32);
        assert_eq!(simulate_f16(2.0_f32), 2.0_f32);
    }

    #[test]
    fn test_simulate_f16_small_denormals_flush() {
        let tiny = 1.0e-8_f32;
        let result = simulate_f16(tiny);
        assert!(
            result.abs() < 1.0e-6_f32,
            "Expected tiny value to flush toward 0 in FP16, got {}",
            result
        );
    }

    #[test]
    fn test_simulate_f16_large_value_clamps_to_inf() {
        let huge = 1.0e10_f32;
        let result = simulate_f16(huge);
        assert!(
            result.is_infinite(),
            "Expected Inf for huge value, got {}",
            result
        );
    }

    #[test]
    fn test_simulate_f16_nan_stays_nan() {
        assert!(simulate_f16(f32::NAN).is_nan());
    }

    #[test]
    fn test_simulate_f16_precision_loss() {
        // Pi is not exactly representable in FP16: error is nonzero but small.
        let pi = std::f32::consts::PI;
        let approx = simulate_f16(pi);
        assert!((pi - approx).abs() < 0.01_f32);
        assert!((pi - approx).abs() > 0.0_f32);
    }

    #[test]
    fn test_simulate_bf16_larger_range_than_f16() {
        // 65504 is the FP16 max; BF16 keeps it finite with room to spare.
        let v = 65504.0_f32;
        let f16_result = simulate_f16(v);
        let bf16_result = simulate_bf16(v);
        assert!((f16_result - v).abs() < 1.0_f32);
        assert!(bf16_result.is_finite());
    }

    #[test]
    fn test_simulate_bf16_less_mantissa_precision_than_f16() {
        let x = 1.1_f32;
        let f16_err = (x - simulate_f16(x)).abs();
        let bf16_err = (x - simulate_bf16(x)).abs();
        assert!(
            bf16_err >= f16_err,
            "Expected bf16_err ({}) >= f16_err ({}) for x={}",
            bf16_err,
            f16_err,
            x
        );
    }

    #[test]
    fn test_simulate_bf16_does_not_overflow_large_f32() {
        let big = 1.0e30_f32;
        let result = simulate_bf16(big);
        assert!(result.is_finite(), "Expected finite result for {}; got {}", big, result);
    }

    // --- FP16 overflow prediction ----------------------------------------

    #[test]
    fn test_would_overflow_fp16_below_max() {
        assert!(!AutoCast::would_overflow_fp16(65504.0_f32));
        assert!(!AutoCast::would_overflow_fp16(-65504.0_f32));
        assert!(!AutoCast::would_overflow_fp16(1.0_f32));
        assert!(!AutoCast::would_overflow_fp16(0.0_f32));
    }

    #[test]
    fn test_would_overflow_fp16_above_max() {
        assert!(AutoCast::would_overflow_fp16(65505.0_f32));
        assert!(AutoCast::would_overflow_fp16(-65505.0_f32));
        assert!(AutoCast::would_overflow_fp16(1.0e10_f32));
    }

    #[test]
    fn test_would_overflow_fp16_special() {
        assert!(AutoCast::would_overflow_fp16(f32::INFINITY));
        assert!(AutoCast::would_overflow_fp16(f32::NEG_INFINITY));
        assert!(AutoCast::would_overflow_fp16(f32::NAN));
    }

    // --- GradScaler dynamic scaling --------------------------------------

    #[test]
    fn test_grad_scaler_scale_loss() {
        let scaler = GradScaler::new(1024.0);
        assert_eq!(scaler.scale_loss(2.0), 2048.0_f32);
        assert_eq!(scaler.scale_loss(0.0), 0.0_f32);
    }

    #[test]
    fn test_grad_scaler_disabled_scale_loss_is_identity() {
        let scaler = GradScaler::disabled();
        assert_eq!(scaler.scale_loss(3.14), 3.14_f32);
    }

    #[test]
    fn test_grad_scaler_grows_after_growth_interval() {
        // Growth interval of 3: the scale doubles on the third clean step.
        let mut scaler = GradScaler::with_config(1.0, 2.0, 0.5, 3).expect("valid config");
        let clean_grads = vec![0.1_f32, 0.2_f32];
        scaler.step_update(&clean_grads);
        scaler.step_update(&clean_grads);
        assert_eq!(scaler.get_scale(), 1.0_f32);
        scaler.step_update(&clean_grads);
        assert_eq!(scaler.get_scale(), 2.0_f32);
        assert_eq!(scaler.steps_since_overflow, 0);
    }

    #[test]
    fn test_grad_scaler_backs_off_on_overflow() {
        let mut scaler = GradScaler::new(1024.0);
        let bad_grads = vec![f32::NAN, 0.5_f32];
        let proceed = scaler.step_update(&bad_grads);
        assert!(!proceed, "Step should be skipped on overflow");
        assert_eq!(scaler.get_scale(), 512.0_f32);
        assert_eq!(scaler.overflow_count, 1);
        assert_eq!(scaler.steps_since_overflow, 0);
    }

    #[test]
    fn test_grad_scaler_backs_off_on_inf_gradient() {
        let mut scaler = GradScaler::new(4096.0);
        let bad_grads = vec![f32::INFINITY];
        let proceed = scaler.step_update(&bad_grads);
        assert!(!proceed);
        assert_eq!(scaler.get_scale(), 2048.0_f32);
    }

    #[test]
    fn test_check_overflow_detects_nan() {
        let scaler = GradScaler::new(1.0);
        assert!(scaler.check_overflow(&[f32::NAN]));
        assert!(scaler.check_overflow(&[1.0, f32::NAN, 2.0]));
        assert!(!scaler.check_overflow(&[1.0, 2.0, 3.0]));
    }

    #[test]
    fn test_check_overflow_detects_inf() {
        let scaler = GradScaler::new(1.0);
        assert!(scaler.check_overflow(&[f32::INFINITY]));
        assert!(scaler.check_overflow(&[f32::NEG_INFINITY]));
        assert!(!scaler.check_overflow(&[1.0, -1.0]));
    }

    // --- Gradient clipping ------------------------------------------------

    #[test]
    fn test_clip_grad_norm_no_clip_needed() {
        // Norm of (0.6, 0.8) is exactly 1.0, so nothing should change.
        let mut grads = vec![0.6_f32, 0.8_f32];
        let norm = clip_grad_norm(&mut grads, 1.0);
        assert!((norm - 1.0_f32).abs() < 1.0e-5);
        assert!((grads[0] - 0.6_f32).abs() < 1.0e-5);
        assert!((grads[1] - 0.8_f32).abs() < 1.0e-5);
    }

    #[test]
    fn test_clip_grad_norm_clips_large_gradients() {
        // (3, 4) has norm 5; after clipping to 1.0 the norm must be 1.
        let mut grads = vec![3.0_f32, 4.0_f32];
        let pre_clip_norm = clip_grad_norm(&mut grads, 1.0);
        assert!((pre_clip_norm - 5.0_f32).abs() < 1.0e-5);
        let post_norm: f32 = grads.iter().map(|&g| g * g).sum::<f32>().sqrt();
        assert!((post_norm - 1.0_f32).abs() < 1.0e-5);
    }

    #[test]
    fn test_clip_grad_norm_zero_max_is_noop() {
        let mut grads = vec![3.0_f32, 4.0_f32];
        let norm = clip_grad_norm(&mut grads, 0.0);
        assert_eq!(norm, 0.0_f32);
        assert_eq!(grads[0], 3.0_f32);
        assert_eq!(grads[1], 4.0_f32);
    }

    // --- Finiteness helpers -----------------------------------------------

    #[test]
    fn test_grads_are_finite_with_nan() {
        assert!(!grads_are_finite(&[1.0, f32::NAN]));
    }

    #[test]
    fn test_grads_are_finite_with_inf() {
        assert!(!grads_are_finite(&[f32::INFINITY, 1.0]));
        assert!(!grads_are_finite(&[f32::NEG_INFINITY]));
    }

    #[test]
    fn test_grads_are_finite_all_finite() {
        assert!(grads_are_finite(&[1.0, -2.0, 0.001]));
        assert!(grads_are_finite(&[]));
    }

    #[test]
    fn test_f16_precision_error_small_for_moderate_values() {
        // Small integers are exactly representable in FP16.
        let data: Vec<f32> = (1..=100).map(|i| i as f32).collect();
        let err = f16_precision_error(&data);
        assert!(err < 1.0e-3_f32, "Relative error {} unexpectedly large", err);
    }

    #[test]
    fn test_f16_precision_error_empty_slice() {
        assert_eq!(f16_precision_error(&[]), 0.0_f32);
    }

    // --- State serialization ----------------------------------------------

    #[test]
    fn test_grad_scaler_state_dict_roundtrip() {
        let mut scaler = GradScaler::with_config(512.0, 1.5, 0.25, 100).expect("valid config");
        scaler.step_update(&[0.1_f32]);
        scaler.step_update(&[f32::NAN]);
        let state = scaler.state_dict();
        let mut scaler2 = GradScaler::new(1.0);
        scaler2.load_state_dict(state.clone());
        assert_eq!(scaler2.scale, state.scale);
        assert_eq!(scaler2.growth_factor, state.growth_factor);
        assert_eq!(scaler2.backoff_factor, state.backoff_factor);
        assert_eq!(scaler2.growth_interval, state.growth_interval);
        assert_eq!(scaler2.steps_since_overflow, state.steps_since_overflow);
    }

    #[test]
    fn test_with_config_rejects_bad_arguments() {
        assert!(GradScaler::with_config(0.0, 2.0, 0.5, 100).is_err());
        assert!(GradScaler::with_config(-1.0, 2.0, 0.5, 100).is_err());
        assert!(GradScaler::with_config(f32::NAN, 2.0, 0.5, 100).is_err());
        assert!(GradScaler::with_config(1.0, 1.0, 0.5, 100).is_err());
        assert!(GradScaler::with_config(1.0, 0.5, 0.5, 100).is_err());
        assert!(GradScaler::with_config(1.0, 2.0, 0.0, 100).is_err());
        assert!(GradScaler::with_config(1.0, 2.0, 1.0, 100).is_err());
        assert!(GradScaler::with_config(1.0, 2.0, 0.5, 0).is_err());
    }

    // --- AutoCast ----------------------------------------------------------

    #[test]
    fn test_autocast_disabled_is_identity() {
        let ctx = AutoCast::disabled();
        let data = vec![1.0_f32, 2.5_f32, -3.14_f32];
        let out = ctx.cast_input(&data);
        assert_eq!(out, data);
        assert_eq!(ctx.cast_scalar(1.1_f32), 1.1_f32);
    }

    #[test]
    fn test_autocast_f16_quantises_values() {
        let ctx = AutoCast::new(AutoCastDtype::Float16);
        let x = 1.1_f32;
        let cast = ctx.cast_scalar(x);
        assert!((cast - x).abs() > 0.0);
        assert!((cast - x).abs() < 0.01);
    }

    #[test]
    fn test_autocast_bf16_quantises_values() {
        let ctx = AutoCast::new(AutoCastDtype::BFloat16);
        let x = 1.1_f32;
        let cast = ctx.cast_scalar(x);
        assert!((cast - x).abs() > 0.0);
        assert!((cast - x).abs() < 0.01);
    }

    #[test]
    fn test_f32_to_f16_roundtrip_preserves_length() {
        let data = vec![1.0_f32, 2.0, 3.0, 4.0];
        let out = f32_to_f16_roundtrip(&data);
        assert_eq!(out.len(), data.len());
    }

    #[test]
    fn test_f32_to_f16_roundtrip_integer_values_exact() {
        // Powers of two are exactly representable in FP16.
        let data = vec![1.0_f32, 2.0, 4.0, 8.0, 16.0];
        let out = f32_to_f16_roundtrip(&data);
        for (orig, rounded) in data.iter().zip(out.iter()) {
            assert_eq!(orig, rounded, "Integer {} should round-trip exactly through FP16", orig);
        }
    }
}