use scirs2_core::ndarray::Array2;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::error::TrainResult;
/// Numeric precision used for forward/backward computation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PrecisionMode {
    /// 32-bit IEEE 754 single precision.
    FP32,
    /// 16-bit IEEE 754 half precision (narrow exponent range).
    FP16,
    /// bfloat16: FP32-sized exponent with a truncated 7-bit mantissa.
    BF16,
}
impl PrecisionMode {
    /// Storage size of one element in bytes for this precision.
    pub fn bytes_per_element(&self) -> usize {
        match self {
            PrecisionMode::FP32 => 4,
            PrecisionMode::FP16 => 2,
            PrecisionMode::BF16 => 2,
        }
    }

    /// Memory-footprint reduction factor relative to FP32 (FP32 == 1.0).
    pub fn memory_reduction(&self) -> f32 {
        match self {
            PrecisionMode::FP32 => 1.0,
            PrecisionMode::FP16 => 2.0,
            PrecisionMode::BF16 => 2.0,
        }
    }

    /// Approximate (min, max) finite representable values for this precision.
    pub fn numerical_range(&self) -> (f32, f32) {
        match self {
            // f32::MAX is ~3.4028235e38; 3.4e38 is the conventional rounding.
            PrecisionMode::FP32 => (-3.4e38, 3.4e38),
            // Bug fix: IEEE 754 binary16's largest finite value is exactly
            // 65504, not 6.55e4 (= 65500). This also matches the clamp used
            // by the FP16 simulation elsewhere in this file.
            PrecisionMode::FP16 => (-65504.0, 65504.0),
            // bfloat16 max finite value is ~3.3895e38.
            PrecisionMode::BF16 => (-3.39e38, 3.39e38),
        }
    }
}
/// Loss-scaling strategy used to keep small gradients representable when
/// training in reduced precision.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LossScaler {
    /// No scaling (effective scale factor 1.0).
    None,
    /// Fixed scale factor.
    Static { scale: f32 },
    /// Scale that grows after a run of overflow-free steps and shrinks when
    /// an overflow is detected.
    Dynamic {
        /// Current scale factor.
        scale: f32,
        /// Multiplier applied after `growth_interval` overflow-free steps.
        growth_factor: f32,
        /// Multiplier applied when an overflow is detected.
        backoff_factor: f32,
        /// Number of consecutive clean steps required before growing.
        growth_interval: usize,
        /// Clean steps since the last overflow (reset after each growth).
        steps_since_overflow: usize,
    },
}
impl LossScaler {
    /// Builds a scaler with a fixed scale factor.
    pub fn static_scale(scale: f32) -> Self {
        Self::Static { scale }
    }

    /// Builds a dynamic scaler that multiplies by `growth_factor` every
    /// `growth_interval` overflow-free steps and halves on overflow.
    pub fn dynamic(initial_scale: f32, growth_factor: f32, growth_interval: usize) -> Self {
        Self::Dynamic {
            scale: initial_scale,
            growth_factor,
            backoff_factor: 0.5,
            growth_interval,
            steps_since_overflow: 0,
        }
    }

    /// Current scale factor (1.0 when scaling is disabled).
    pub fn get_scale(&self) -> f32 {
        match self {
            Self::None => 1.0,
            Self::Static { scale } | Self::Dynamic { scale, .. } => *scale,
        }
    }

    /// Multiplies the loss by the current scale.
    pub fn scale_loss(&self, loss: f32) -> f32 {
        self.get_scale() * loss
    }

    /// Divides gradients in place by the current scale, undoing `scale_loss`.
    pub fn unscale_gradients(&self, gradients: &mut Array2<f32>) {
        let factor = self.get_scale();
        if factor != 1.0 {
            *gradients /= factor;
        }
    }

    /// Records the outcome of a step and adjusts the scale (dynamic only).
    /// Returns `true` when the optimizer should apply the step, `false`
    /// when the step must be skipped because of an overflow.
    pub fn update(&mut self, overflow_detected: bool) -> bool {
        match self {
            Self::Dynamic {
                scale,
                growth_factor,
                backoff_factor,
                growth_interval,
                steps_since_overflow,
            } => {
                if overflow_detected {
                    // Back off and restart the overflow-free streak.
                    *scale *= *backoff_factor;
                    *steps_since_overflow = 0;
                    return false;
                }
                *steps_since_overflow += 1;
                if *steps_since_overflow >= *growth_interval {
                    *scale *= *growth_factor;
                    *steps_since_overflow = 0;
                }
                true
            }
            // Static/None scalers never adjust; step unless we overflowed.
            _ => !overflow_detected,
        }
    }
}
/// Coordinates mixed-precision training: working-precision casts, loss
/// scaling, overflow tracking, and FP32 master weight copies.
pub struct MixedPrecisionTrainer {
    /// Working precision for forward/backward computation.
    mode: PrecisionMode,
    /// Loss-scaling strategy applied around backprop.
    scaler: LossScaler,
    /// Full-precision master copies of registered weights, keyed by name.
    master_weights: HashMap<String, Array2<f32>>,
    /// Step/overflow counters accumulated across training.
    stats: MixedPrecisionStats,
}
impl MixedPrecisionTrainer {
    /// Creates a trainer with the given working precision and loss scaler.
    pub fn new(mode: PrecisionMode, scaler: LossScaler) -> Self {
        Self {
            mode,
            scaler,
            master_weights: HashMap::new(),
            stats: MixedPrecisionStats::default(),
        }
    }

    /// Registers an FP32 master copy of a weight tensor under `name`.
    /// Re-registering the same name replaces the previous copy.
    pub fn register_weights(&mut self, name: String, weights: Array2<f32>) {
        self.master_weights.insert(name, weights);
    }

    /// Returns `weights` cast to the working precision (identity for FP32).
    pub fn cast_to_working_precision(&self, weights: &Array2<f32>) -> Array2<f32> {
        match self.mode {
            PrecisionMode::FP32 => weights.clone(),
            PrecisionMode::FP16 => self.simulate_fp16(weights),
            PrecisionMode::BF16 => self.simulate_bf16(weights),
        }
    }

    /// Simulates FP16 storage in f32: clamp to the FP16 finite range
    /// (+/-65504) and quantize to a fixed 2^-10 grid.
    /// NOTE(review): real FP16 has an exponent-dependent quantization step;
    /// this fixed-step model is only a coarse approximation.
    fn simulate_fp16(&self, weights: &Array2<f32>) -> Array2<f32> {
        weights.mapv(|x| {
            let clamped = x.clamp(-65504.0, 65504.0);
            let scale = 2.0_f32.powi(10);
            (clamped * scale).round() / scale
        })
    }

    /// Simulates BF16 storage in f32: quantize to a fixed 2^-7 grid.
    /// No clamping is applied (BF16 shares FP32's exponent range).
    /// NOTE(review): same fixed-step approximation caveat as `simulate_fp16`.
    fn simulate_bf16(&self, weights: &Array2<f32>) -> Array2<f32> {
        weights.mapv(|x| {
            let scale = 2.0_f32.powi(7);
            (x * scale).round() / scale
        })
    }

    /// Scales the loss before backpropagation and counts the step.
    pub fn scale_loss(&mut self, loss: f32) -> f32 {
        self.stats.total_steps += 1;
        self.scaler.scale_loss(loss)
    }

    /// Unscales all gradients in place and checks them for overflow.
    ///
    /// Returns `(should_step, overflow)`: `should_step` is `false` when the
    /// optimizer step must be skipped; `overflow` reports whether any
    /// gradient contained a non-finite value.
    pub fn unscale_and_check_gradients(
        &mut self,
        gradients: &mut HashMap<String, Array2<f32>>,
    ) -> TrainResult<(bool, bool)> {
        // Non-finite values survive division, so checking before unscaling
        // detects the same overflows.
        let overflow = gradients
            .values()
            .any(|grad| grad.iter().any(|&x| !x.is_finite()));
        if overflow {
            self.stats.overflow_steps += 1;
        } else {
            // Bug fix: successful_steps was never incremented anywhere, so
            // MixedPrecisionStats::success_rate() always reported 0.
            self.stats.successful_steps += 1;
        }
        for grad in gradients.values_mut() {
            self.scaler.unscale_gradients(grad);
        }
        let should_step = self.scaler.update(overflow);
        Ok((should_step, overflow))
    }

    /// Applies additive updates to the registered master weights.
    /// Updates for unregistered names are silently ignored.
    pub fn update_master_weights(&mut self, updates: &HashMap<String, Array2<f32>>) {
        for (name, update) in updates {
            if let Some(master) = self.master_weights.get_mut(name) {
                // In-place add; avoids cloning the whole master tensor.
                *master += update;
            }
        }
    }

    /// Working precision mode.
    pub fn mode(&self) -> PrecisionMode {
        self.mode
    }

    /// Current loss-scale factor.
    pub fn current_scale(&self) -> f32 {
        self.scaler.get_scale()
    }

    /// Accumulated step/overflow statistics.
    pub fn stats(&self) -> &MixedPrecisionStats {
        &self.stats
    }

    /// Resets all statistics to zero.
    pub fn reset_stats(&mut self) {
        self.stats = MixedPrecisionStats::default();
    }
}
/// Counters describing mixed-precision training progress.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct MixedPrecisionStats {
    /// Total number of training steps seen (incremented per scaled loss).
    pub total_steps: usize,
    /// Steps where a non-finite gradient was detected.
    pub overflow_steps: usize,
    /// Steps that completed without overflow.
    pub successful_steps: usize,
}
impl MixedPrecisionStats {
    /// Fraction of steps that overflowed; 0.0 before any step was taken.
    pub fn overflow_rate(&self) -> f32 {
        self.fraction_of_total(self.overflow_steps)
    }

    /// Fraction of steps that succeeded; 0.0 before any step was taken.
    pub fn success_rate(&self) -> f32 {
        self.fraction_of_total(self.successful_steps)
    }

    /// Shared rate computation, guarding against division by zero.
    fn fraction_of_total(&self, count: usize) -> f32 {
        if self.total_steps == 0 {
            0.0
        } else {
            count as f32 / self.total_steps as f32
        }
    }
}
/// Thin convenience wrapper around `LossScaler` with an on/off switch.
pub struct GradientScaler {
    /// Underlying scaling strategy.
    scaler: LossScaler,
    /// When false, scale/unscale are no-ops.
    enabled: bool,
}
impl GradientScaler {
    /// Creates a scaler. When `enabled`, uses dynamic scaling starting at
    /// 2^15 with growth factor 2.0 every 2000 overflow-free steps;
    /// otherwise scaling is a no-op.
    pub fn new(enabled: bool) -> Self {
        if enabled {
            Self {
                scaler: LossScaler::dynamic(2.0_f32.powi(15), 2.0, 2000),
                enabled: true,
            }
        } else {
            Self {
                scaler: LossScaler::None,
                enabled: false,
            }
        }
    }

    /// Wraps an explicit scaler; `enabled` gates whether it is applied.
    pub fn with_scaler(scaler: LossScaler, enabled: bool) -> Self {
        Self { scaler, enabled }
    }

    /// Scales the loss, or returns it unchanged when disabled.
    pub fn scale(&self, loss: f32) -> f32 {
        if !self.enabled {
            return loss;
        }
        self.scaler.scale_loss(loss)
    }

    /// Unscales gradients in place; no-op when disabled.
    pub fn unscale(&self, gradients: &mut Array2<f32>) {
        if !self.enabled {
            return;
        }
        self.scaler.unscale_gradients(gradients);
    }

    /// Updates the underlying scaler after a step; returns whether the
    /// optimizer should apply the step.
    pub fn step(&mut self, overflow_detected: bool) -> bool {
        match self.enabled {
            true => self.scaler.update(overflow_detected),
            false => !overflow_detected,
        }
    }

    /// Current scale factor of the underlying scaler.
    pub fn get_scale(&self) -> f32 {
        self.scaler.get_scale()
    }
}
/// Scoped cast helper: converts tensors to a target precision when enabled.
pub struct AutocastContext {
    /// When false, `cast` returns an unmodified copy.
    enabled: bool,
    /// Target precision for casts.
    mode: PrecisionMode,
}
impl AutocastContext {
    /// Creates an autocast context for the given precision.
    pub fn new(enabled: bool, mode: PrecisionMode) -> Self {
        Self { enabled, mode }
    }

    /// Whether autocasting is active.
    pub fn is_enabled(&self) -> bool {
        self.enabled
    }

    /// Target precision of this context.
    pub fn mode(&self) -> PrecisionMode {
        self.mode
    }

    /// Casts a tensor through the configured precision. Returns an
    /// unmodified copy when the context is disabled or the mode is FP32.
    pub fn cast(&self, tensor: &Array2<f32>) -> Array2<f32> {
        if self.enabled {
            match self.mode {
                PrecisionMode::FP16 => self.simulate_fp16(tensor),
                PrecisionMode::BF16 => self.simulate_bf16(tensor),
                PrecisionMode::FP32 => tensor.clone(),
            }
        } else {
            tensor.clone()
        }
    }

    /// FP16 simulation in f32: clamp to +/-65504, round to a 2^-10 grid.
    fn simulate_fp16(&self, tensor: &Array2<f32>) -> Array2<f32> {
        // 2^10 = 1024; fixed-step approximation of FP16's 10 mantissa bits.
        const STEP: f32 = 1024.0;
        tensor.mapv(|v| (v.clamp(-65504.0, 65504.0) * STEP).round() / STEP)
    }

    /// BF16 simulation in f32: round to a 2^-7 grid. No clamping needed
    /// since BF16 shares FP32's exponent range.
    fn simulate_bf16(&self, tensor: &Array2<f32>) -> Array2<f32> {
        // 2^7 = 128; fixed-step approximation of BF16's 7 mantissa bits.
        const STEP: f32 = 128.0;
        tensor.mapv(|v| (v * STEP).round() / STEP)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Byte width and memory-reduction factor for each precision mode.
    #[test]
    fn test_precision_mode_properties() {
        assert_eq!(PrecisionMode::FP32.bytes_per_element(), 4);
        assert_eq!(PrecisionMode::FP16.bytes_per_element(), 2);
        assert_eq!(PrecisionMode::BF16.bytes_per_element(), 2);
        assert_eq!(PrecisionMode::FP16.memory_reduction(), 2.0);
        assert_eq!(PrecisionMode::BF16.memory_reduction(), 2.0);
    }

    /// A static scaler multiplies the loss by its fixed factor.
    #[test]
    fn test_static_loss_scaler() {
        let scaler = LossScaler::static_scale(1024.0);
        assert_eq!(scaler.get_scale(), 1024.0);
        let loss = 0.5;
        let scaled = scaler.scale_loss(loss);
        assert_eq!(scaled, 512.0);
    }

    /// Dynamic scaler doubles after `growth_interval` clean steps and
    /// backs off (x0.5) on overflow.
    #[test]
    fn test_dynamic_loss_scaler() {
        let mut scaler = LossScaler::dynamic(1000.0, 2.0, 3);
        assert_eq!(scaler.get_scale(), 1000.0);
        assert!(scaler.update(false));
        assert!(scaler.update(false));
        assert!(scaler.update(false));
        assert_eq!(scaler.get_scale(), 2000.0);
        assert!(!scaler.update(true));
        assert_eq!(scaler.get_scale(), 1000.0);
    }

    /// Unscaling divides every gradient element by the scale factor.
    #[test]
    fn test_gradient_unscaling() {
        let mut gradients =
            Array2::from_shape_vec((2, 2), vec![100.0, 200.0, 300.0, 400.0]).expect("unwrap");
        let scaler = LossScaler::static_scale(10.0);
        scaler.unscale_gradients(&mut gradients);
        assert_eq!(gradients[[0, 0]], 10.0);
        assert_eq!(gradients[[0, 1]], 20.0);
        assert_eq!(gradients[[1, 0]], 30.0);
        assert_eq!(gradients[[1, 1]], 40.0);
    }

    /// Trainer scales the loss and counts the step.
    #[test]
    fn test_mixed_precision_trainer() {
        let mut trainer =
            MixedPrecisionTrainer::new(PrecisionMode::FP16, LossScaler::static_scale(100.0));
        let loss = 0.5;
        let scaled_loss = trainer.scale_loss(loss);
        assert_eq!(scaled_loss, 50.0);
        assert_eq!(trainer.stats().total_steps, 1);
    }

    /// FP16 cast quantizes values and clamps to the FP16 finite range.
    #[test]
    fn test_fp16_simulation() {
        let trainer = MixedPrecisionTrainer::new(PrecisionMode::FP16, LossScaler::None);
        let weights = Array2::from_shape_vec((2, 2), vec![1.234_567, 100000.0, -100000.0, 0.0001])
            .expect("unwrap");
        let fp16_weights = trainer.cast_to_working_precision(&weights);
        // Quantized away from the original value, but still within (1.0, 2.0).
        assert_ne!(fp16_weights[[0, 0]], 1.234_567);
        assert!(fp16_weights[[0, 0]] > 1.0 && fp16_weights[[0, 0]] < 2.0);
        assert!(fp16_weights[[0, 1]] <= 65504.0);
        assert!(fp16_weights[[1, 0]] >= -65504.0);
    }

    /// BF16 cast quantizes values (coarser mantissa than FP32).
    #[test]
    fn test_bf16_simulation() {
        let trainer = MixedPrecisionTrainer::new(PrecisionMode::BF16, LossScaler::None);
        let weights =
            Array2::from_shape_vec((2, 2), vec![1.234_567, 100.5, -50.25, 0.125]).expect("unwrap");
        let bf16_weights = trainer.cast_to_working_precision(&weights);
        assert_ne!(bf16_weights[[0, 0]], 1.234_567);
    }

    /// Non-finite gradients are detected; the step is skipped and counted.
    #[test]
    fn test_overflow_detection() {
        let mut trainer =
            MixedPrecisionTrainer::new(PrecisionMode::FP16, LossScaler::dynamic(1000.0, 2.0, 100));
        let mut gradients = HashMap::new();
        gradients.insert(
            "layer1".to_string(),
            Array2::from_shape_vec((2, 2), vec![f32::INFINITY, 1.0, 2.0, 3.0]).expect("unwrap"),
        );
        let (should_step, overflow) = trainer
            .unscale_and_check_gradients(&mut gradients)
            .expect("unwrap");
        assert!(!should_step);
        assert!(overflow);
        assert_eq!(trainer.stats().overflow_steps, 1);
    }

    /// Enabled GradientScaler scales losses up and gradients down.
    #[test]
    fn test_gradient_scaler() {
        let scaler = GradientScaler::new(true);
        let loss = 1.0;
        let scaled = scaler.scale(loss);
        assert!(scaled > loss);
        let mut grads = Array2::from_shape_vec((2, 2), vec![1000.0; 4]).expect("unwrap");
        scaler.unscale(&mut grads);
        assert!(grads[[0, 0]] < 1000.0);
    }

    /// Enabled autocast quantizes tensors to the target precision.
    #[test]
    fn test_autocast_context() {
        let ctx = AutocastContext::new(true, PrecisionMode::FP16);
        assert!(ctx.is_enabled());
        assert_eq!(ctx.mode(), PrecisionMode::FP16);
        let tensor = Array2::from_shape_vec((2, 2), vec![1.234_567; 4]).expect("unwrap");
        let casted = ctx.cast(&tensor);
        assert_ne!(casted[[0, 0]], 1.234_567);
    }

    /// Disabled autocast returns the tensor unchanged.
    #[test]
    fn test_autocast_disabled() {
        let ctx = AutocastContext::new(false, PrecisionMode::FP16);
        assert!(!ctx.is_enabled());
        let tensor = Array2::from_shape_vec((2, 2), vec![1.234_567; 4]).expect("unwrap");
        let casted = ctx.cast(&tensor);
        assert_eq!(casted, tensor);
    }

    /// Master-weight updates are applied additively to registered tensors.
    #[test]
    fn test_master_weights_update() {
        let mut trainer = MixedPrecisionTrainer::new(PrecisionMode::FP16, LossScaler::None);
        let weights = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).expect("unwrap");
        trainer.register_weights("layer1".to_string(), weights.clone());
        let mut updates = HashMap::new();
        updates.insert(
            "layer1".to_string(),
            Array2::from_shape_vec((2, 2), vec![0.1, 0.1, 0.1, 0.1]).expect("unwrap"),
        );
        trainer.update_master_weights(&updates);
        let master = &trainer.master_weights["layer1"];
        assert_relative_eq!(master[[0, 0]], 1.1, epsilon = 1e-6);
    }

    /// Rates are computed against the total step count.
    #[test]
    fn test_mixed_precision_stats() {
        let stats = MixedPrecisionStats {
            total_steps: 100,
            overflow_steps: 5,
            successful_steps: 95,
        };
        assert_eq!(stats.overflow_rate(), 0.05);
        assert_eq!(stats.success_rate(), 0.95);
    }

    /// Growth only happens once the full interval of clean steps is reached.
    #[test]
    fn test_loss_scaler_growth() {
        let mut scaler = LossScaler::dynamic(1000.0, 2.0, 2);
        assert!(scaler.update(false));
        assert_eq!(scaler.get_scale(), 1000.0);
        assert!(scaler.update(false));
        assert_eq!(scaler.get_scale(), 2000.0);
    }
}