use crate::error::{NeuralError, Result};
use crate::layers::Layer;
use scirs2_core::ndarray::{Array, ArrayD, IxDyn, ScalarOperand};
use scirs2_core::numeric::{Float, FromPrimitive, NumAssign};
use std::collections::HashMap;
use std::fmt::Debug;
use std::marker::PhantomData;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, RwLock};
/// Strategy used to manage the loss scale during mixed-precision training.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum LossScalingStrategy {
/// Keep the loss scale fixed at its configured initial value.
Fixed,
/// Adjust the scale based on observed gradient overflows (the default).
#[default]
Dynamic,
/// Automatic scaling; exact behavior is decided by the consumer of this enum.
Automatic,
}
/// Numeric precision regime for model compute.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum PrecisionMode {
/// Full 32-bit floating point everywhere.
FP32,
/// Mixed precision using IEEE binary16 (FP16) compute (the default).
#[default]
FP16Mixed,
/// Mixed precision using bfloat16 compute.
BF16Mixed,
/// Let the implementation choose a precision mode.
Automatic,
}
/// Operations that can be forced to run in FP32 even under mixed precision.
///
/// Listed in [`MixedPrecisionConfig::fp32_operations`] and queried via
/// [`MixedPrecisionConfig::should_use_fp32`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FP32Operation {
/// Loss value computation.
LossComputation,
/// Softmax activation.
Softmax,
/// Layer normalization.
LayerNorm,
/// Batch-normalization statistics.
BatchNormStats,
/// Accumulation of gradients.
GradientAccumulation,
/// Optimizer parameter updates.
OptimizerUpdate,
/// Embedding lookups.
Embedding,
/// Reduction operations (sums, means, ...).
Reductions,
}
/// Configuration for automatic mixed-precision (AMP) training.
///
/// Invariants are enforced by [`MixedPrecisionConfig::validate`]; construct
/// via [`MixedPrecisionConfig::builder`] to get validation on `build()`.
#[derive(Debug, Clone)]
pub struct MixedPrecisionConfig {
/// Master switch; when false, scaling/unscaling helpers become no-ops.
pub enabled: bool,
/// Which compute precision regime to use.
pub precision_mode: PrecisionMode,
/// How the loss scale is managed over time.
pub loss_scaling_strategy: LossScalingStrategy,
/// Starting loss scale (must be positive).
pub initial_loss_scale: f64,
/// Multiplier applied when growing the scale (must be > 1.0).
pub growth_factor: f64,
/// Multiplier applied after an overflow (must be in (0.0, 1.0)).
pub backoff_factor: f64,
/// Overflow-free steps required before the scale grows (must be >= 1).
pub growth_interval: usize,
/// Lower bound for the loss scale.
pub min_loss_scale: f64,
/// Upper bound for the loss scale (must be >= `min_loss_scale`).
pub max_loss_scale: f64,
/// Operations that are always executed in FP32.
pub fp32_operations: Vec<FP32Operation>,
/// Whether model parameters should be cast to the compute precision.
/// NOTE(review): not read anywhere in this module — confirm external users.
pub cast_model_type: bool,
/// Whether to use memory-efficient gradient storage.
/// NOTE(review): not read anywhere in this module — confirm external users.
pub memory_efficient_gradients: bool,
/// Whether gradient checkpointing is enabled.
/// NOTE(review): not read anywhere in this module — confirm external users.
pub gradient_checkpointing: bool,
/// After this many back-to-back overflows the scaler backs off a second time.
pub max_consecutive_overflows: usize,
/// Optional global-norm gradient clipping threshold.
pub grad_clip_threshold: Option<f64>,
/// Emit diagnostic messages to stderr on loss-scale changes.
pub log_statistics: bool,
}
impl Default for MixedPrecisionConfig {
    /// Conservative defaults: mixed precision disabled until explicitly
    /// enabled, FP16 compute with dynamic loss scaling starting at 2^16,
    /// doubling every 2000 clean steps and halving on overflow.
    fn default() -> Self {
        // Numerically sensitive operations that stay in FP32 by default.
        let fp32_ops = vec![
            FP32Operation::LossComputation,
            FP32Operation::Softmax,
            FP32Operation::LayerNorm,
            FP32Operation::BatchNormStats,
            FP32Operation::GradientAccumulation,
            FP32Operation::OptimizerUpdate,
        ];
        Self {
            enabled: false,
            precision_mode: PrecisionMode::FP16Mixed,
            loss_scaling_strategy: LossScalingStrategy::Dynamic,
            initial_loss_scale: 65536.0, // 2^16
            growth_factor: 2.0,
            backoff_factor: 0.5,
            growth_interval: 2000,
            min_loss_scale: 1.0,
            max_loss_scale: 2.0_f64.powi(24),
            fp32_operations: fp32_ops,
            cast_model_type: true,
            memory_efficient_gradients: true,
            gradient_checkpointing: false,
            max_consecutive_overflows: 5,
            grad_clip_threshold: None,
            log_statistics: false,
        }
    }
}
impl MixedPrecisionConfig {
    /// Start building a configuration from the defaults.
    pub fn builder() -> MixedPrecisionConfigBuilder {
        MixedPrecisionConfigBuilder::new()
    }

    /// Returns `true` if `operation` must be executed in full FP32 precision.
    pub fn should_use_fp32(&self, operation: FP32Operation) -> bool {
        self.fp32_operations.contains(&operation)
    }

    /// Check all configuration invariants.
    ///
    /// # Errors
    /// Returns `NeuralError::ConfigError` describing the first violated
    /// invariant: positive initial scale, growth factor > 1, backoff factor
    /// in (0, 1), positive minimum scale, min <= max, growth interval >= 1.
    pub fn validate(&self) -> Result<()> {
        if self.initial_loss_scale <= 0.0 {
            return Err(NeuralError::ConfigError(
                "Initial loss scale must be positive".to_string(),
            ));
        }
        if self.growth_factor <= 1.0 {
            return Err(NeuralError::ConfigError(
                "Growth factor must be greater than 1.0".to_string(),
            ));
        }
        if self.backoff_factor <= 0.0 || self.backoff_factor >= 1.0 {
            return Err(NeuralError::ConfigError(
                "Backoff factor must be in (0.0, 1.0)".to_string(),
            ));
        }
        // Bug fix: previously a non-positive minimum was accepted. Because the
        // backoff path clamps the scale with `max(min_loss_scale)`, a minimum
        // of 0 (or below) lets the loss scale collapse to 0, which makes the
        // inverse scale used by `unscale_gradients` non-finite.
        if self.min_loss_scale <= 0.0 {
            return Err(NeuralError::ConfigError(
                "Minimum loss scale must be positive".to_string(),
            ));
        }
        if self.min_loss_scale > self.max_loss_scale {
            return Err(NeuralError::ConfigError(
                "Minimum loss scale cannot exceed maximum loss scale".to_string(),
            ));
        }
        if self.growth_interval == 0 {
            return Err(NeuralError::ConfigError(
                "Growth interval must be at least 1".to_string(),
            ));
        }
        Ok(())
    }
}
/// Fluent builder for [`MixedPrecisionConfig`]; `build()` runs validation.
#[derive(Debug, Clone)]
pub struct MixedPrecisionConfigBuilder {
/// Configuration under construction, seeded from `MixedPrecisionConfig::default()`.
config: MixedPrecisionConfig,
}
impl MixedPrecisionConfigBuilder {
    /// Start from the default configuration.
    pub fn new() -> Self {
        Self {
            config: MixedPrecisionConfig::default(),
        }
    }

    /// Enable or disable mixed precision entirely.
    pub fn enabled(mut self, enabled: bool) -> Self {
        self.config.enabled = enabled;
        self
    }

    /// Select the compute precision regime.
    pub fn precision_mode(mut self, mode: PrecisionMode) -> Self {
        self.config.precision_mode = mode;
        self
    }

    /// Select the loss-scaling strategy.
    pub fn loss_scaling_strategy(mut self, strategy: LossScalingStrategy) -> Self {
        self.config.loss_scaling_strategy = strategy;
        self
    }

    /// Set the starting loss scale (validated as positive on `build`).
    pub fn initial_loss_scale(mut self, scale: f64) -> Self {
        self.config.initial_loss_scale = scale;
        self
    }

    /// Set the multiplier used when growing the scale (must be > 1.0).
    pub fn growth_factor(mut self, factor: f64) -> Self {
        self.config.growth_factor = factor;
        self
    }

    /// Set the multiplier applied after an overflow (must be in (0, 1)).
    pub fn backoff_factor(mut self, factor: f64) -> Self {
        self.config.backoff_factor = factor;
        self
    }

    /// Set how many clean steps are needed before the scale grows.
    pub fn growth_interval(mut self, interval: usize) -> Self {
        self.config.growth_interval = interval;
        self
    }

    /// Set the lower bound for the loss scale.
    pub fn min_loss_scale(mut self, scale: f64) -> Self {
        self.config.min_loss_scale = scale;
        self
    }

    /// Set the upper bound for the loss scale.
    pub fn max_loss_scale(mut self, scale: f64) -> Self {
        self.config.max_loss_scale = scale;
        self
    }

    /// Force an additional operation to run in FP32 (duplicates are ignored).
    pub fn add_fp32_operation(mut self, operation: FP32Operation) -> Self {
        if !self.config.fp32_operations.contains(&operation) {
            self.config.fp32_operations.push(operation);
        }
        self
    }

    /// Enable global-norm gradient clipping at `threshold`.
    pub fn grad_clip_threshold(mut self, threshold: f64) -> Self {
        self.config.grad_clip_threshold = Some(threshold);
        self
    }

    /// Toggle memory-efficient gradient storage.
    pub fn memory_efficient_gradients(mut self, enabled: bool) -> Self {
        self.config.memory_efficient_gradients = enabled;
        self
    }

    /// Toggle gradient checkpointing.
    pub fn gradient_checkpointing(mut self, enabled: bool) -> Self {
        self.config.gradient_checkpointing = enabled;
        self
    }

    /// Toggle casting model parameters to the compute precision.
    /// (Added: this field previously had no builder setter.)
    pub fn cast_model_type(mut self, enabled: bool) -> Self {
        self.config.cast_model_type = enabled;
        self
    }

    /// Set how many consecutive overflows trigger aggressive backoff.
    /// (Added: this field previously had no builder setter.)
    pub fn max_consecutive_overflows(mut self, count: usize) -> Self {
        self.config.max_consecutive_overflows = count;
        self
    }

    /// Toggle diagnostic logging of loss-scale changes.
    pub fn log_statistics(mut self, enabled: bool) -> Self {
        self.config.log_statistics = enabled;
        self
    }

    /// Validate and return the finished configuration.
    ///
    /// # Errors
    /// Propagates any failure from [`MixedPrecisionConfig::validate`].
    pub fn build(self) -> Result<MixedPrecisionConfig> {
        self.config.validate()?;
        Ok(self.config)
    }
}
impl Default for MixedPrecisionConfigBuilder {
/// Equivalent to [`MixedPrecisionConfigBuilder::new`].
fn default() -> Self {
Self::new()
}
}
/// IEEE 754 binary16 value stored as its raw 16-bit pattern.
///
/// Provides bit-level constants, predicates, and conversions to and from
/// `f32`/`f64`. The `f32 -> f16` conversion truncates excess mantissa bits
/// (no round-to-nearest); `f16 -> f32` is exact.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Half(u16);

impl Half {
    /// Positive zero.
    pub const ZERO: Half = Half(0);
    /// The value 1.0.
    pub const ONE: Half = Half(0x3c00);
    /// Largest finite value (65504.0).
    pub const MAX: Half = Half(0x7bff);
    /// Smallest positive normal value (2^-14).
    pub const MIN_POSITIVE: Half = Half(0x0400);
    /// Positive infinity.
    pub const INFINITY: Half = Half(0x7c00);
    /// Negative infinity.
    pub const NEG_INFINITY: Half = Half(0xfc00);
    /// A quiet NaN.
    pub const NAN: Half = Half(0x7e00);

    /// Reinterpret a raw bit pattern as a half-precision value.
    pub const fn from_bits(bits: u16) -> Self {
        Half(bits)
    }

    /// Return the raw bit pattern.
    pub const fn to_bits(self) -> u16 {
        self.0
    }

    /// Convert an `f32` to binary16, truncating excess mantissa bits.
    ///
    /// Values too large for f16 saturate to infinity; values too small
    /// underflow to (signed) zero; NaN maps to a quiet NaN.
    pub fn from_f32(value: f32) -> Self {
        let src = value.to_bits();
        let sign16 = ((src >> 31) as u16) << 15;
        let exp32 = ((src >> 23) & 0xff) as i32;
        let mantissa = src & 0x007f_ffff;

        // f32 exponent field all ones: infinity or NaN.
        if exp32 == 0xff {
            let payload = if mantissa == 0 { 0x7c00 } else { 0x7e00 };
            return Half(sign16 | payload);
        }

        // Re-bias the exponent from f32 (bias 127) to f16 (bias 15).
        let exp16 = exp32 - 127 + 15;

        if exp16 <= 0 {
            // Result is subnormal in f16, or underflows entirely.
            if exp16 < -10 {
                return Half(sign16);
            }
            // Restore the hidden bit, then shift the 24-bit significand
            // into the 10-bit subnormal mantissa position.
            let hidden = mantissa | 0x0080_0000;
            let shift = (1 - exp16) + 13;
            return Half(sign16 | (hidden >> shift) as u16);
        }

        if exp16 >= 31 {
            // Exponent overflow: saturate to infinity.
            return Half(sign16 | 0x7c00);
        }

        Half(sign16 | ((exp16 as u16) << 10) | (mantissa >> 13) as u16)
    }

    /// Convert this binary16 value to `f32` (every f16 is exactly
    /// representable in f32).
    pub fn to_f32(self) -> f32 {
        let bits = u32::from(self.0);
        let sign32 = (bits >> 15) << 31;
        let exp16 = (bits >> 10) & 0x1f;
        let mantissa = bits & 0x03ff;

        match exp16 {
            // Zero or subnormal.
            0 => {
                if mantissa == 0 {
                    return f32::from_bits(sign32);
                }
                // Normalize the subnormal: shift left until the hidden bit
                // (bit 10) appears, adjusting the exponent accordingly.
                let mut m = mantissa;
                let mut exp = -14i32;
                while m & 0x0400 == 0 {
                    m <<= 1;
                    exp -= 1;
                }
                let exp32 = (exp + 127) as u32;
                f32::from_bits(sign32 | (exp32 << 23) | ((m & 0x03ff) << 13))
            }
            // Infinity or NaN.
            0x1f => {
                let payload = if mantissa == 0 { 0x7f80_0000 } else { 0x7fc0_0000 };
                f32::from_bits(sign32 | payload)
            }
            // Normal number: re-bias the exponent and widen the mantissa.
            _ => {
                let exp32 = (exp16 as i32 - 15 + 127) as u32;
                f32::from_bits(sign32 | (exp32 << 23) | (mantissa << 13))
            }
        }
    }

    /// True when the exponent field is saturated and the mantissa is non-zero.
    pub fn is_nan(self) -> bool {
        (self.0 & 0x7c00) == 0x7c00 && (self.0 & 0x03ff) != 0
    }

    /// True for positive or negative infinity.
    pub fn is_infinite(self) -> bool {
        (self.0 & 0x7fff) == 0x7c00
    }

    /// True when the value is neither infinite nor NaN.
    pub fn is_finite(self) -> bool {
        (self.0 & 0x7c00) != 0x7c00
    }

    /// True for positive or negative zero.
    pub fn is_zero(self) -> bool {
        (self.0 & 0x7fff) == 0
    }
}

impl From<f32> for Half {
    fn from(value: f32) -> Self {
        Half::from_f32(value)
    }
}

impl From<Half> for f32 {
    fn from(value: Half) -> Self {
        value.to_f32()
    }
}

impl From<f64> for Half {
    /// Converts through `f32`, so precision beyond f32 is lost first.
    fn from(value: f64) -> Self {
        Half::from_f32(value as f32)
    }
}

impl From<Half> for f64 {
    fn from(value: Half) -> Self {
        value.to_f32() as f64
    }
}
/// Tensor that can hold its data in FP32, FP16 (as raw bit patterns), or both.
///
/// FP16 storage keeps `u16` binary16 bit patterns (see [`Half`]); conversion
/// is explicit and lossy in the FP32 -> FP16 direction.
#[derive(Debug, Clone)]
pub struct MixedPrecisionTensor<F: Float + Debug + ScalarOperand> {
/// Full-precision data, if materialized.
fp32_data: Option<ArrayD<F>>,
/// Half-precision data as raw binary16 bit patterns, if materialized.
fp16_bits: Option<ArrayD<u16>>,
/// Shape shared by both representations.
shape: Vec<usize>,
/// Which representation is considered current.
precision: PrecisionMode,
}
impl<F: Float + Debug + ScalarOperand + FromPrimitive> MixedPrecisionTensor<F> {
    /// Wrap an FP32 array as a mixed-precision tensor.
    pub fn from_fp32(data: ArrayD<F>) -> Self {
        let shape = data.shape().to_vec();
        Self {
            fp32_data: Some(data),
            fp16_bits: None,
            shape,
            precision: PrecisionMode::FP32,
        }
    }

    /// Convert to FP16, making it the current representation.
    ///
    /// If an FP16 copy already exists it is reused instead of re-converted.
    ///
    /// # Errors
    /// `InvalidState` if no FP32 data is available to convert from.
    pub fn to_fp16(&mut self) -> Result<()> {
        if self.fp16_bits.is_some() {
            // Bug fix: the early return previously left `precision` unchanged,
            // so FP32 -> FP16 -> FP32 -> FP16 reported FP32 even though the
            // FP16 representation was (re)selected.
            self.precision = PrecisionMode::FP16Mixed;
            return Ok(());
        }
        let fp32 = self
            .fp32_data
            .as_ref()
            .ok_or_else(|| NeuralError::InvalidState("No FP32 data available".to_string()))?;
        let fp16_data: Vec<u16> = fp32
            .iter()
            .map(|&x| {
                let f32_val = x.to_f64().unwrap_or(0.0) as f32;
                Half::from_f32(f32_val).to_bits()
            })
            .collect();
        self.fp16_bits = Some(
            ArrayD::from_shape_vec(IxDyn(&self.shape), fp16_data).map_err(|e| {
                NeuralError::ShapeMismatch(format!("FP16 conversion failed: {}", e))
            })?,
        );
        self.precision = PrecisionMode::FP16Mixed;
        Ok(())
    }

    /// Convert back to FP32, making it the current representation.
    ///
    /// Note: when FP16 is current, the FP32 data is re-derived from the
    /// (lossy) FP16 bits, intentionally simulating the precision loss.
    ///
    /// # Errors
    /// `InvalidState` if no FP16 data is available to convert from.
    pub fn to_fp32(&mut self) -> Result<()> {
        if self.fp32_data.is_some() && self.precision == PrecisionMode::FP32 {
            return Ok(());
        }
        let fp16 = self
            .fp16_bits
            .as_ref()
            .ok_or_else(|| NeuralError::InvalidState("No FP16 data available".to_string()))?;
        let fp32_data: Vec<F> = fp16
            .iter()
            .map(|&bits| {
                let f32_val = Half::from_bits(bits).to_f32();
                F::from(f32_val).unwrap_or_else(F::zero)
            })
            .collect();
        self.fp32_data = Some(
            ArrayD::from_shape_vec(IxDyn(&self.shape), fp32_data).map_err(|e| {
                NeuralError::ShapeMismatch(format!("FP32 conversion failed: {}", e))
            })?,
        );
        self.precision = PrecisionMode::FP32;
        Ok(())
    }

    /// Ensure FP32 data exists and borrow it.
    pub fn get_fp32(&mut self) -> Result<&ArrayD<F>> {
        self.to_fp32()?;
        self.fp32_data
            .as_ref()
            .ok_or_else(|| NeuralError::InvalidState("FP32 data not available".to_string()))
    }

    /// Currently selected representation.
    pub fn precision(&self) -> PrecisionMode {
        self.precision
    }

    /// Shape of the tensor (shared by both representations).
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Drop the FP16 copy to free memory.
    pub fn free_fp16(&mut self) {
        self.fp16_bits = None;
    }

    /// Drop the FP32 copy to free memory.
    pub fn free_fp32(&mut self) {
        self.fp32_data = None;
    }

    /// True while at least one representation is still held.
    pub fn is_valid(&self) -> bool {
        self.fp32_data.is_some() || self.fp16_bits.is_some()
    }
}
/// Snapshot of [`GradScaler`] bookkeeping counters.
#[derive(Debug, Clone, Default)]
pub struct GradScalerStats {
/// Current loss scale.
pub loss_scale: f64,
/// Total `update` calls observed.
pub total_steps: u64,
/// Steps on which a gradient overflow was detected.
pub overflow_steps: u64,
/// Overflow-free steps since the scale last grew.
pub steps_since_increase: u64,
/// Number of times the scale was increased.
pub num_scale_increases: u64,
/// Number of times the scale was decreased.
pub num_scale_decreases: u64,
/// Current run of back-to-back overflow steps.
pub consecutive_overflows: u64,
}
/// Dynamic loss scaler for mixed-precision training.
///
/// Scales losses before backward, detects non-finite gradients when
/// unscaling, and grows/shrinks the scale per the configured policy.
/// All state is interior-mutable so a shared reference can drive steps.
#[derive(Debug)]
pub struct GradScaler {
/// Policy parameters (growth/backoff factors, bounds, logging, ...).
config: MixedPrecisionConfig,
/// Current loss scale.
scale: Arc<RwLock<f64>>,
/// Overflow-free steps accumulated toward the next scale increase.
growth_tracker: AtomicU64,
/// Set by `unscale_gradients` when a non-finite gradient was seen.
found_inf: AtomicBool,
/// Current run of consecutive overflow steps.
consecutive_overflows: AtomicU64,
/// Aggregated statistics for reporting.
stats: Arc<RwLock<GradScalerStats>>,
}
impl GradScaler {
    /// Create a scaler from a validated configuration.
    ///
    /// # Errors
    /// Returns a `ConfigError` when `config` fails validation.
    pub fn new(config: MixedPrecisionConfig) -> Result<Self> {
        config.validate()?;
        let initial_scale = config.initial_loss_scale;
        Ok(Self {
            config,
            scale: Arc::new(RwLock::new(initial_scale)),
            growth_tracker: AtomicU64::new(0),
            found_inf: AtomicBool::new(false),
            consecutive_overflows: AtomicU64::new(0),
            stats: Arc::new(RwLock::new(GradScalerStats {
                loss_scale: initial_scale,
                ..Default::default()
            })),
        })
    }

    /// Current loss scale (poisoned locks are recovered, not propagated).
    pub fn get_scale(&self) -> f64 {
        *self.scale.read().unwrap_or_else(|e| e.into_inner())
    }

    /// Multiply `loss` by the current scale; identity when scaling is disabled.
    ///
    /// # Errors
    /// `ComputationError` if the scale cannot be represented in `F`.
    pub fn scale_loss<F: Float + Debug + FromPrimitive>(&self, loss: F) -> Result<F> {
        if !self.config.enabled {
            return Ok(loss);
        }
        let scale = self.get_scale();
        let scale_f = F::from(scale).ok_or_else(|| {
            NeuralError::ComputationError("Failed to convert scale to loss type".to_string())
        })?;
        Ok(loss * scale_f)
    }

    /// Divide all gradients in place by the current scale and record whether
    /// any element is non-finite. Returns `true` if an overflow was detected;
    /// the result is also latched for the next `update` call.
    pub fn unscale_gradients<F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign>(
        &self,
        gradients: &mut [ArrayD<F>],
    ) -> Result<bool> {
        if !self.config.enabled {
            return Ok(false);
        }
        let scale = self.get_scale();
        let inv_scale = F::from(1.0 / scale).ok_or_else(|| {
            NeuralError::ComputationError("Failed to compute inverse scale".to_string())
        })?;
        let mut found_inf = false;
        for grad in gradients.iter_mut() {
            for val in grad.iter_mut() {
                *val *= inv_scale;
                // Values not representable as f64 are treated as overflow.
                let f64_val = val.to_f64().unwrap_or(f64::NAN);
                if !f64_val.is_finite() {
                    found_inf = true;
                }
            }
        }
        // Latch the result so `update` can react to it.
        self.found_inf.store(found_inf, Ordering::SeqCst);
        Ok(found_inf)
    }

    /// Read-only scan for non-finite gradient values (no state is modified).
    pub fn check_gradients_for_overflow<F: Float + Debug + ScalarOperand>(
        &self,
        gradients: &[ArrayD<F>],
    ) -> bool {
        for grad in gradients {
            for val in grad.iter() {
                let f64_val = val.to_f64().unwrap_or(f64::NAN);
                if !f64_val.is_finite() {
                    return true;
                }
            }
        }
        false
    }

    /// Clip gradients by global L2 norm when a threshold is configured.
    ///
    /// Returns `Ok(None)` when clipping is disabled, otherwise the pre-clip
    /// global norm.
    ///
    /// # Errors
    /// `ComputationError` if the clip factor cannot be represented in `F`.
    pub fn clip_gradients<F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign>(
        &self,
        gradients: &mut [ArrayD<F>],
    ) -> Result<Option<f64>> {
        let threshold = match self.config.grad_clip_threshold {
            Some(t) => t,
            None => return Ok(None),
        };
        // Accumulate the squared norm in f64 regardless of F's precision.
        let mut global_norm_sq = 0.0_f64;
        for grad in gradients.iter() {
            for val in grad.iter() {
                let f = val.to_f64().unwrap_or(0.0);
                global_norm_sq += f * f;
            }
        }
        let global_norm = global_norm_sq.sqrt();
        if global_norm > threshold {
            let clip_factor = F::from(threshold / global_norm).ok_or_else(|| {
                NeuralError::ComputationError("Failed to compute clip factor".to_string())
            })?;
            for grad in gradients.iter_mut() {
                for val in grad.iter_mut() {
                    *val *= clip_factor;
                }
            }
        }
        Ok(Some(global_norm))
    }

    /// Advance the scaler by one optimizer step.
    ///
    /// Returns `Ok(true)` when the step should be skipped because the last
    /// `unscale_gradients` call flagged an overflow; the scale is then backed
    /// off (twice, once `max_consecutive_overflows` is reached). On clean
    /// steps the scale grows every `growth_interval` steps. The latched
    /// overflow flag is consumed either way.
    pub fn update(&self) -> Result<bool> {
        if !self.config.enabled {
            return Ok(false);
        }
        let found_inf = self.found_inf.load(Ordering::SeqCst);
        // Lock order: scale before stats (kept consistent with `reset`).
        let mut scale = self.scale.write().unwrap_or_else(|e| e.into_inner());
        let mut stats = self.stats.write().unwrap_or_else(|e| e.into_inner());
        stats.total_steps += 1;
        if found_inf {
            // Bug fix: capture the pre-backoff scale for logging. The old
            // value was previously reconstructed as `new / backoff_factor`,
            // which reports the wrong number once the new scale has been
            // clamped to `min_loss_scale`.
            let old_scale = *scale;
            *scale *= self.config.backoff_factor;
            *scale = scale.max(self.config.min_loss_scale);
            self.growth_tracker.store(0, Ordering::SeqCst);
            let consec = self.consecutive_overflows.fetch_add(1, Ordering::SeqCst) + 1;
            stats.overflow_steps += 1;
            stats.num_scale_decreases += 1;
            stats.consecutive_overflows = consec;
            stats.steps_since_increase = 0;
            if self.config.log_statistics {
                eprintln!(
                    "[GradScaler] Overflow detected. Scale: {:.2} -> {:.2}, consecutive: {}",
                    old_scale, *scale, consec
                );
            }
            if consec >= self.config.max_consecutive_overflows as u64 {
                // Persistent overflows: back off a second time this step.
                *scale *= self.config.backoff_factor;
                *scale = scale.max(self.config.min_loss_scale);
                if self.config.log_statistics {
                    eprintln!(
                        "[GradScaler] Aggressive backoff due to {} consecutive overflows. Scale: {:.2}",
                        consec, *scale
                    );
                }
            }
            stats.loss_scale = *scale;
            self.found_inf.store(false, Ordering::SeqCst);
            return Ok(true);
        }
        self.consecutive_overflows.store(0, Ordering::SeqCst);
        let growth_count = self.growth_tracker.fetch_add(1, Ordering::SeqCst) + 1;
        stats.steps_since_increase = growth_count;
        if growth_count >= self.config.growth_interval as u64 {
            let old_scale = *scale;
            *scale *= self.config.growth_factor;
            *scale = scale.min(self.config.max_loss_scale);
            self.growth_tracker.store(0, Ordering::SeqCst);
            stats.num_scale_increases += 1;
            stats.steps_since_increase = 0;
            if self.config.log_statistics && *scale != old_scale {
                eprintln!(
                    "[GradScaler] Scale increased: {:.2} -> {:.2}",
                    old_scale, *scale
                );
            }
        }
        stats.loss_scale = *scale;
        stats.consecutive_overflows = 0;
        self.found_inf.store(false, Ordering::SeqCst);
        Ok(false)
    }

    /// Snapshot of the statistics.
    pub fn get_stats(&self) -> GradScalerStats {
        self.stats.read().unwrap_or_else(|e| e.into_inner()).clone()
    }

    /// Restore the scaler to its freshly-constructed state.
    pub fn reset(&self) {
        let mut scale = self.scale.write().unwrap_or_else(|e| e.into_inner());
        *scale = self.config.initial_loss_scale;
        self.growth_tracker.store(0, Ordering::SeqCst);
        self.found_inf.store(false, Ordering::SeqCst);
        self.consecutive_overflows.store(0, Ordering::SeqCst);
        let mut stats = self.stats.write().unwrap_or_else(|e| e.into_inner());
        *stats = GradScalerStats {
            loss_scale: self.config.initial_loss_scale,
            ..Default::default()
        };
    }

    /// Whether scaling is active for this scaler.
    pub fn is_enabled(&self) -> bool {
        self.config.enabled
    }
}
/// FP32 master copies of parameters plus FP16 compute copies (stored as raw
/// binary16 bit patterns) that are re-derived from the masters after updates.
#[derive(Debug)]
pub struct MasterWeights<F: Float + Debug + ScalarOperand> {
/// Full-precision source of truth, keyed by parameter name.
master_weights: HashMap<String, ArrayD<F>>,
/// Half-precision copies stored as raw binary16 bits.
compute_weights: HashMap<String, ArrayD<u16>>,
/// When true, compute copies are refreshed on every master update.
sync_on_update: bool,
/// Precision mode this store was created for.
precision_mode: PrecisionMode,
}
impl<F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign> MasterWeights<F> {
    /// Create an empty store; compute copies are kept in sync on update.
    pub fn new(precision_mode: PrecisionMode) -> Self {
        Self {
            master_weights: HashMap::new(),
            compute_weights: HashMap::new(),
            sync_on_update: true,
            precision_mode,
        }
    }

    /// Register a parameter tensor and immediately derive its compute copy.
    pub fn register(&mut self, name: &str, weights: ArrayD<F>) -> Result<()> {
        self.master_weights.insert(name.to_string(), weights);
        self.sync_to_compute(name)?;
        Ok(())
    }

    /// Re-derive the FP16 compute copy of `name` from its master copy.
    fn sync_to_compute(&mut self, name: &str) -> Result<()> {
        let master = self
            .master_weights
            .get(name)
            .ok_or_else(|| NeuralError::InvalidArgument(format!("Weight '{}' not found", name)))?;
        let fp16_data: Vec<u16> = master
            .iter()
            .map(|&x| {
                let f32_val = x.to_f64().unwrap_or(0.0) as f32;
                Half::from_f32(f32_val).to_bits()
            })
            .collect();
        let compute = ArrayD::from_shape_vec(IxDyn(master.shape()), fp16_data)
            .map_err(|e| NeuralError::ShapeMismatch(format!("Sync failed: {}", e)))?;
        self.compute_weights.insert(name.to_string(), compute);
        Ok(())
    }

    /// FP16 compute copy for `name`, if registered.
    pub fn get_compute_weights(&self, name: &str) -> Option<&ArrayD<u16>> {
        self.compute_weights.get(name)
    }

    /// FP32 master copy for `name`, if registered.
    pub fn get_master_weights(&self, name: &str) -> Option<&ArrayD<F>> {
        self.master_weights.get(name)
    }

    /// Apply a plain SGD step `w -= learning_rate * g` to the master copy of
    /// `name`, then refresh its compute copy when `sync_on_update` is set.
    ///
    /// # Errors
    /// `InvalidArgument` for unknown names; `ShapeMismatch` when the gradient
    /// shape differs from the weight shape (previously a mismatched gradient
    /// was silently truncated by `zip`).
    pub fn update_master_weights(
        &mut self,
        name: &str,
        gradients: &ArrayD<F>,
        learning_rate: F,
    ) -> Result<()> {
        let master = self
            .master_weights
            .get_mut(name)
            .ok_or_else(|| NeuralError::InvalidArgument(format!("Weight '{}' not found", name)))?;
        if master.shape() != gradients.shape() {
            return Err(NeuralError::ShapeMismatch(format!(
                "Gradient shape {:?} does not match weight shape {:?} for '{}'",
                gradients.shape(),
                master.shape(),
                name
            )));
        }
        for (w, g) in master.iter_mut().zip(gradients.iter()) {
            *w -= learning_rate * *g;
        }
        if self.sync_on_update {
            self.sync_to_compute(name)?;
        }
        Ok(())
    }

    /// Refresh every compute copy from its master.
    pub fn sync_all(&mut self) -> Result<()> {
        let names: Vec<String> = self.master_weights.keys().cloned().collect();
        for name in names {
            self.sync_to_compute(&name)?;
        }
        Ok(())
    }

    /// Names of all registered parameters (arbitrary order).
    pub fn weight_names(&self) -> Vec<&String> {
        self.master_weights.keys().collect()
    }

    /// Whether `name` has been registered.
    pub fn contains(&self, name: &str) -> bool {
        self.master_weights.contains_key(name)
    }

    /// Number of registered parameters.
    pub fn len(&self) -> usize {
        self.master_weights.len()
    }

    /// True when no parameters are registered.
    pub fn is_empty(&self) -> bool {
        self.master_weights.is_empty()
    }

    /// Precision mode this store was created for.
    pub fn precision_mode(&self) -> PrecisionMode {
        self.precision_mode
    }
}
/// Phase of the training step the AMP context is currently in.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AmpContextState {
/// Outside any AMP-managed region.
Normal,
/// Inside the forward pass.
Forward,
/// Inside the backward pass.
Backward,
/// Inside the optimizer update.
/// NOTE(review): never set anywhere in this module — confirm external users.
Optimization,
}
/// Coordinator tying together loss scaling, master weights, and the current
/// AMP context state for one model.
#[derive(Debug)]
pub struct AutoMixedPrecision<F: Float + Debug + ScalarOperand> {
/// Validated configuration.
config: MixedPrecisionConfig,
/// Loss scaler; present only when mixed precision is enabled.
scaler: Option<GradScaler>,
/// FP32 master weight store; present only when mixed precision is enabled.
master_weights: Option<MasterWeights<F>>,
/// Current phase of the training step.
context_state: Arc<RwLock<AmpContextState>>,
/// Set once a forward pass has been entered; cleared by `reset`.
active: AtomicBool,
/// Marker for the parameter element type.
_phantom: PhantomData<F>,
}
impl<F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync>
    AutoMixedPrecision<F>
{
    /// Build an AMP coordinator from a configuration.
    ///
    /// When mixed precision is enabled, a gradient scaler and a master-weight
    /// store are created; when disabled, both are absent and the scaling
    /// helpers become identity operations.
    ///
    /// # Errors
    /// Propagates configuration validation failures.
    pub fn new(config: MixedPrecisionConfig) -> Result<Self> {
        config.validate()?;
        let (scaler, master_weights) = if config.enabled {
            (
                Some(GradScaler::new(config.clone())?),
                Some(MasterWeights::new(config.precision_mode)),
            )
        } else {
            (None, None)
        };
        Ok(Self {
            config,
            scaler,
            master_weights,
            context_state: Arc::new(RwLock::new(AmpContextState::Normal)),
            active: AtomicBool::new(false),
            _phantom: PhantomData,
        })
    }

    /// Whether mixed precision is enabled.
    pub fn is_enabled(&self) -> bool {
        self.config.enabled
    }

    /// Store `next` as the current context state, recovering poisoned locks.
    fn set_state(&self, next: AmpContextState) {
        let mut guard = self
            .context_state
            .write()
            .unwrap_or_else(|e| e.into_inner());
        *guard = next;
    }

    /// Mark the start of a forward pass (no-op when disabled).
    pub fn enter_forward(&self) -> Result<()> {
        if self.config.enabled {
            self.set_state(AmpContextState::Forward);
            self.active.store(true, Ordering::SeqCst);
        }
        Ok(())
    }

    /// Mark the end of a forward pass (no-op when disabled).
    pub fn exit_forward(&self) -> Result<()> {
        if self.config.enabled {
            self.set_state(AmpContextState::Normal);
        }
        Ok(())
    }

    /// Mark the start of a backward pass (no-op when disabled).
    pub fn enter_backward(&self) -> Result<()> {
        if self.config.enabled {
            self.set_state(AmpContextState::Backward);
        }
        Ok(())
    }

    /// Mark the end of a backward pass (no-op when disabled).
    pub fn exit_backward(&self) -> Result<()> {
        if self.config.enabled {
            self.set_state(AmpContextState::Normal);
        }
        Ok(())
    }

    /// The current AMP context phase.
    pub fn context_state(&self) -> AmpContextState {
        *self.context_state.read().unwrap_or_else(|e| e.into_inner())
    }

    /// Scale `loss` via the scaler, or pass it through when disabled.
    pub fn scale_loss(&self, loss: F) -> Result<F> {
        match &self.scaler {
            Some(scaler) => scaler.scale_loss(loss),
            None => Ok(loss),
        }
    }

    /// Unscale gradients in place; returns `Ok(false)` when disabled.
    pub fn unscale_gradients(&self, gradients: &mut [ArrayD<F>]) -> Result<bool> {
        match &self.scaler {
            Some(scaler) => scaler.unscale_gradients(gradients),
            None => Ok(false),
        }
    }

    /// Advance the loss-scale schedule; returns `Ok(false)` when disabled.
    pub fn update_scaler(&self) -> Result<bool> {
        match &self.scaler {
            Some(scaler) => scaler.update(),
            None => Ok(false),
        }
    }

    /// Current loss scale, or 1.0 when no scaler exists.
    pub fn get_loss_scale(&self) -> f64 {
        self.scaler.as_ref().map_or(1.0, GradScaler::get_scale)
    }

    /// Register a parameter tensor with the master-weight store (no-op when
    /// mixed precision is disabled).
    pub fn register_weights(&mut self, name: &str, weights: ArrayD<F>) -> Result<()> {
        if let Some(mw) = self.master_weights.as_mut() {
            mw.register(name, weights)?;
        }
        Ok(())
    }

    /// Statistics from the scaler, when one exists.
    pub fn get_scaler_stats(&self) -> Option<GradScalerStats> {
        self.scaler.as_ref().map(GradScaler::get_stats)
    }

    /// Reset scaler, master weights, active flag, and context state.
    pub fn reset(&mut self) {
        if let Some(scaler) = self.scaler.as_ref() {
            scaler.reset();
        }
        if let Some(mw) = self.master_weights.as_mut() {
            *mw = MasterWeights::new(self.config.precision_mode);
        }
        self.active.store(false, Ordering::SeqCst);
        self.set_state(AmpContextState::Normal);
    }

    /// Convert a tensor to FP16 compute precision (raw binary16 bits).
    pub fn to_compute_precision(&self, tensor: &ArrayD<F>) -> Result<ArrayD<u16>> {
        let bits: Vec<u16> = tensor
            .iter()
            .map(|&x| Half::from_f32(x.to_f64().unwrap_or(0.0) as f32).to_bits())
            .collect();
        ArrayD::from_shape_vec(IxDyn(tensor.shape()), bits)
            .map_err(|e| NeuralError::ShapeMismatch(format!("Precision conversion failed: {}", e)))
    }

    /// Convert FP16 bit patterns back to the parameter element type.
    pub fn from_compute_precision(&self, tensor: &ArrayD<u16>) -> Result<ArrayD<F>> {
        let values: Vec<F> = tensor
            .iter()
            .map(|&b| F::from(Half::from_bits(b).to_f32()).unwrap_or_else(F::zero))
            .collect();
        ArrayD::from_shape_vec(IxDyn(tensor.shape()), values)
            .map_err(|e| NeuralError::ShapeMismatch(format!("Precision conversion failed: {}", e)))
    }
}
/// Outcome of one mixed-precision training step.
#[derive(Debug, Clone)]
pub struct MixedPrecisionStepResult<F: Float + Debug + NumAssign> {
/// Loss after multiplication by the loss scale.
pub scaled_loss: F,
/// Loss as produced by the loss function.
pub unscaled_loss: F,
/// Whether non-finite gradients were observed after unscaling.
pub overflow_detected: bool,
/// Whether the optimizer step was skipped because of the overflow.
pub step_skipped: bool,
/// Loss scale in effect after the scaler update.
pub loss_scale: f64,
/// Pre-clip global gradient norm, when clipping was performed.
pub grad_norm: Option<f64>,
}
/// High-level driver wrapping a model's forward/backward/optimize cycle with
/// AMP context management and loss scaling.
#[derive(Debug)]
pub struct MixedPrecisionTrainer<F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign> {
/// Copy of the configuration (also held inside `amp`).
config: MixedPrecisionConfig,
/// Mixed-precision coordinator.
amp: AutoMixedPrecision<F>,
/// Training steps attempted.
total_steps: AtomicU64,
/// Steps skipped due to gradient overflow.
skipped_steps: AtomicU64,
/// Train vs. eval mode flag.
training: AtomicBool,
}
impl<F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync>
MixedPrecisionTrainer<F>
{
/// Create a trainer around a fresh AMP coordinator.
///
/// # Errors
/// Propagates configuration validation failures.
pub fn new(config: MixedPrecisionConfig) -> Result<Self> {
let amp = AutoMixedPrecision::new(config.clone())?;
Ok(Self {
config,
amp,
total_steps: AtomicU64::new(0),
skipped_steps: AtomicU64::new(0),
training: AtomicBool::new(false),
})
}
/// Whether mixed precision is enabled.
pub fn is_enabled(&self) -> bool {
self.config.enabled
}
/// Switch to training mode.
pub fn train(&self) {
self.training.store(true, Ordering::SeqCst);
}
/// Switch to evaluation mode.
pub fn eval(&self) {
self.training.store(false, Ordering::SeqCst);
}
/// True when in training mode.
pub fn is_training(&self) -> bool {
self.training.load(Ordering::SeqCst)
}
/// Run a forward pass inside the AMP forward context.
///
/// The forward context is exited even when the model's forward pass fails;
/// the model's error (if any) is returned afterwards.
pub fn forward<L: Layer<F>>(&self, model: &L, input: &ArrayD<F>) -> Result<ArrayD<F>> {
self.amp.enter_forward()?;
let result = model.forward(input);
self.amp.exit_forward()?;
result
}
/// Scale a loss value by the current loss scale.
pub fn scale_loss(&self, loss: F) -> Result<F> {
self.amp.scale_loss(loss)
}
/// Execute one full training step: forward, loss, gradient unscaling,
/// scaler update, and (when no overflow occurred) the optimizer step.
///
/// NOTE(review): the learning rate passed to `optimizer_step` is
/// hard-coded to 0.001 below; a caller wanting a different rate must
/// close over it in `optimizer_step` and ignore the passed value —
/// confirm this is intended.
pub fn training_step<L: Layer<F>>(
&self,
model: &mut L,
input: &ArrayD<F>,
target: &ArrayD<F>,
loss_fn: impl Fn(&ArrayD<F>, &ArrayD<F>) -> Result<F>,
optimizer_step: impl FnOnce(&mut L, F) -> Result<()>,
) -> Result<MixedPrecisionStepResult<F>> {
self.total_steps.fetch_add(1, Ordering::SeqCst);
self.amp.enter_forward()?;
let output = model.forward(input)?;
self.amp.exit_forward()?;
let unscaled_loss = loss_fn(&output, target)?;
let scaled_loss = self.amp.scale_loss(unscaled_loss)?;
self.amp.enter_backward()?;
let mut gradients = model.gradients();
let overflow_detected = self.amp.unscale_gradients(&mut gradients)?;
// NOTE(review): the unscaled gradients are written back even when an
// overflow was detected; the optimizer step below is skipped in that
// case, but the model is left holding non-finite gradients — confirm
// downstream code tolerates this.
model.set_gradients(&gradients)?;
self.amp.exit_backward()?;
// `update_scaler` returns true exactly when the step must be skipped.
let step_skipped = self.amp.update_scaler()?;
if step_skipped {
self.skipped_steps.fetch_add(1, Ordering::SeqCst);
} else {
let lr = F::from(0.001).unwrap_or_else(F::zero); optimizer_step(model, lr)?;
}
Ok(MixedPrecisionStepResult {
scaled_loss,
unscaled_loss,
overflow_detected,
step_skipped,
loss_scale: self.amp.get_loss_scale(),
grad_norm: None,
})
}
/// Aggregate step counters plus the scaler's statistics.
pub fn get_stats(&self) -> TrainingStats {
let scaler_stats = self.amp.get_scaler_stats();
TrainingStats {
total_steps: self.total_steps.load(Ordering::SeqCst),
skipped_steps: self.skipped_steps.load(Ordering::SeqCst),
current_loss_scale: self.amp.get_loss_scale(),
scaler_stats,
}
}
/// Zero the step counters and reset the AMP state.
pub fn reset_stats(&mut self) {
self.total_steps.store(0, Ordering::SeqCst);
self.skipped_steps.store(0, Ordering::SeqCst);
self.amp.reset();
}
/// Shared access to the AMP coordinator.
pub fn amp(&self) -> &AutoMixedPrecision<F> {
&self.amp
}
/// Mutable access to the AMP coordinator.
pub fn amp_mut(&mut self) -> &mut AutoMixedPrecision<F> {
&mut self.amp
}
/// The active configuration.
pub fn config(&self) -> &MixedPrecisionConfig {
&self.config
}
}
/// Aggregate counters exposed by [`MixedPrecisionTrainer::get_stats`].
#[derive(Debug, Clone)]
pub struct TrainingStats {
/// Training steps attempted.
pub total_steps: u64,
/// Steps skipped due to gradient overflow.
pub skipped_steps: u64,
/// Loss scale at the time of the snapshot (1.0 when no scaler exists).
pub current_loss_scale: f64,
/// Detailed scaler statistics, when a scaler exists.
pub scaler_stats: Option<GradScalerStats>,
}
use crate::callbacks::{Callback, CallbackContext, CallbackTiming};
/// Training callback that unscales gradients and advances the loss-scale
/// schedule after every batch.
#[derive(Debug)]
pub struct MixedPrecisionCallback<
F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync,
> {
/// Callback configuration.
config: MixedPrecisionConfig,
/// Scaler owned by this callback (independent of any trainer-owned scaler).
scaler: GradScaler,
/// Loss scale captured at the start of the current epoch.
last_loss_scale: f64,
/// Overflows observed during the current epoch.
epoch_overflows: usize,
/// Marker for the parameter element type.
_phantom: PhantomData<F>,
}
impl<F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync>
    MixedPrecisionCallback<F>
{
    /// Create a callback wrapping a fresh gradient scaler for `config`.
    ///
    /// # Errors
    /// Propagates configuration validation failures from the scaler.
    pub fn new(config: MixedPrecisionConfig) -> Result<Self> {
        let initial_scale = config.initial_loss_scale;
        let scaler = GradScaler::new(config.clone())?;
        Ok(Self {
            config,
            scaler,
            last_loss_scale: initial_scale,
            epoch_overflows: 0,
            _phantom: PhantomData,
        })
    }

    /// Borrow the underlying gradient scaler.
    pub fn scaler(&self) -> &GradScaler {
        &self.scaler
    }

    /// Loss scale currently maintained by the scaler.
    pub fn loss_scale(&self) -> f64 {
        self.scaler.get_scale()
    }

    /// Snapshot of the scaler's statistics.
    pub fn get_stats(&self) -> GradScalerStats {
        self.scaler.get_stats()
    }
}
impl<
F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + std::fmt::Display + Send + Sync,
> Callback<F> for MixedPrecisionCallback<F>
{
/// Dispatch on the callback timing; a no-op when mixed precision is disabled.
fn on_event(&mut self, timing: CallbackTiming, context: &mut CallbackContext<F>) -> Result<()> {
if !self.config.enabled {
return Ok(());
}
match timing {
CallbackTiming::BeforeEpoch => {
// Start a fresh overflow count and remember the scale at epoch start.
self.epoch_overflows = 0;
self.last_loss_scale = self.scaler.get_scale();
}
CallbackTiming::AfterBatch => {
if let Some(model) = context.model.as_mut() {
let mut gradients = model.gradients();
let overflow = self.scaler.unscale_gradients(&mut gradients)?;
if overflow {
self.epoch_overflows += 1;
}
// Advance the loss-scale schedule regardless of overflow.
let _skipped = self.scaler.update()?;
// Only write back the unscaled gradients on clean batches.
// NOTE(review): on overflow the model keeps its still-scaled
// gradients — confirm the surrounding training loop skips the
// optimizer step in that case.
if !overflow {
model.set_gradients(&gradients)?;
}
}
}
CallbackTiming::AfterEpoch if self.config.log_statistics => {
let stats = self.scaler.get_stats();
eprintln!(
"[MixedPrecision] Epoch {} - Scale: {:.2}, Overflows: {}, Total skipped: {}",
context.epoch, stats.loss_scale, self.epoch_overflows, stats.overflow_steps
);
}
_ => {}
}
Ok(())
}
}
/// Returns `true` if any element of `tensor` is NaN or infinite.
///
/// Elements that cannot be converted to `f64` are treated as non-finite.
pub fn contains_inf_or_nan<F: Float + Debug + NumAssign>(tensor: &ArrayD<F>) -> bool {
    tensor
        .iter()
        .any(|value| value.to_f64().map_or(true, |v| !v.is_finite()))
}
/// L2 (Frobenius) norm of `tensor`, accumulated in `f64`.
///
/// Elements that cannot be converted to `f64` contribute zero.
pub fn tensor_norm<F: Float + Debug + NumAssign>(tensor: &ArrayD<F>) -> f64 {
    tensor
        .iter()
        .fold(0.0_f64, |acc, x| {
            let v = x.to_f64().unwrap_or(0.0);
            acc + v * v
        })
        .sqrt()
}
/// Global L2 norm across all tensors, accumulated in `f64`.
///
/// Elements that cannot be converted to `f64` contribute zero.
pub fn global_norm<F: Float + Debug + NumAssign>(tensors: &[ArrayD<F>]) -> f64 {
    let mut sum_sq = 0.0_f64;
    for tensor in tensors {
        for x in tensor.iter() {
            let v = x.to_f64().unwrap_or(0.0);
            sum_sq += v * v;
        }
    }
    sum_sq.sqrt()
}
/// Clamp every element of `tensor` into `[-max_value, max_value]` in place.
///
/// NaN elements pass through unchanged (both comparisons are false for NaN).
///
/// # Errors
/// `ComputationError` if either bound cannot be represented in `F`.
pub fn clip_by_value<F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign>(
    tensor: &mut ArrayD<F>,
    max_value: f64,
) -> Result<()> {
    let upper = F::from(max_value)
        .ok_or_else(|| NeuralError::ComputationError("Failed to convert max value".to_string()))?;
    let lower = F::from(-max_value)
        .ok_or_else(|| NeuralError::ComputationError("Failed to convert min value".to_string()))?;
    tensor.iter_mut().for_each(|v| {
        if *v > upper {
            *v = upper;
        } else if *v < lower {
            *v = lower;
        }
    });
    Ok(())
}
/// Scale all tensors in place so their global L2 norm does not exceed
/// `max_norm`; returns the pre-clip global norm.
///
/// # Errors
/// `ComputationError` if the scaling factor cannot be represented in `F`.
pub fn clip_by_global_norm<F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign>(
    tensors: &mut [ArrayD<F>],
    max_norm: f64,
) -> Result<f64> {
    let norm = global_norm(tensors);
    if norm > max_norm {
        let factor = F::from(max_norm / norm).ok_or_else(|| {
            NeuralError::ComputationError("Failed to compute clip scale".to_string())
        })?;
        for tensor in tensors.iter_mut() {
            tensor.iter_mut().for_each(|v| *v *= factor);
        }
    }
    Ok(norm)
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::Array;
#[test]
fn test_mixed_precision_config_builder() {
let config = MixedPrecisionConfig::builder()
.enabled(true)
.initial_loss_scale(1024.0)
.growth_factor(2.0)
.backoff_factor(0.5)
.growth_interval(100)
.build()
.expect("Config build should succeed");
assert!(config.enabled);
assert!((config.initial_loss_scale - 1024.0).abs() < 1e-10);
assert!((config.growth_factor - 2.0).abs() < 1e-10);
assert!((config.backoff_factor - 0.5).abs() < 1e-10);
assert_eq!(config.growth_interval, 100);
}
#[test]
fn test_config_validation() {
let result = MixedPrecisionConfig::builder()
.initial_loss_scale(-1.0)
.build();
assert!(result.is_err());
let result = MixedPrecisionConfig::builder().growth_factor(0.5).build();
assert!(result.is_err());
let result = MixedPrecisionConfig::builder().backoff_factor(1.5).build();
assert!(result.is_err());
}
#[test]
fn test_half_precision_conversion() {
let val = 1.5_f32;
let half = Half::from_f32(val);
let back = half.to_f32();
assert!((val - back).abs() < 0.01);
let val = -std::f32::consts::PI;
let half = Half::from_f32(val);
let back = half.to_f32();
assert!((val - back).abs() < 0.01);
let half = Half::from_f32(0.0);
assert!(half.is_zero());
let half = Half::from_f32(f32::INFINITY);
assert!(half.is_infinite());
let half = Half::from_f32(f32::NAN);
assert!(half.is_nan());
}
#[test]
fn test_grad_scaler_creation() {
let config = MixedPrecisionConfig::builder()
.enabled(true)
.initial_loss_scale(65536.0)
.build()
.expect("Config should be valid");
let scaler = GradScaler::new(config).expect("Scaler creation should succeed");
assert!((scaler.get_scale() - 65536.0).abs() < 1e-10);
}
#[test]
fn test_grad_scaler_scale_loss() {
let config = MixedPrecisionConfig::builder()
.enabled(true)
.initial_loss_scale(1024.0)
.build()
.expect("Config should be valid");
let scaler = GradScaler::new(config).expect("Scaler creation should succeed");
let loss = 0.5_f64;
let scaled = scaler.scale_loss(loss).expect("Scale should succeed");
assert!((scaled - 512.0).abs() < 1e-10);
}
#[test]
fn test_grad_scaler_unscale() {
let config = MixedPrecisionConfig::builder()
.enabled(true)
.initial_loss_scale(1024.0)
.build()
.expect("Config should be valid");
let scaler = GradScaler::new(config).expect("Scaler creation should succeed");
let mut gradients =
vec![
Array::from_shape_vec(vec![2, 2], vec![1024.0_f64, 2048.0, 512.0, 256.0])
.expect("Array creation should succeed")
.into_dyn(),
];
let overflow = scaler
.unscale_gradients(&mut gradients)
.expect("Unscale should succeed");
assert!(!overflow);
assert!((gradients[0][[0, 0]] - 1.0).abs() < 1e-10);
assert!((gradients[0][[0, 1]] - 2.0).abs() < 1e-10);
}
#[test]
fn test_grad_scaler_overflow_detection() {
let config = MixedPrecisionConfig::builder()
.enabled(true)
.initial_loss_scale(1024.0)
.build()
.expect("Config should be valid");
let scaler = GradScaler::new(config).expect("Scaler creation should succeed");
let mut gradients = vec![Array::from_shape_vec(vec![2], vec![f64::INFINITY, 1.0])
.expect("Array creation should succeed")
.into_dyn()];
let overflow = scaler
.unscale_gradients(&mut gradients)
.expect("Unscale should succeed");
assert!(overflow);
}
#[test]
fn test_grad_scaler_update_no_overflow() {
let config = MixedPrecisionConfig::builder()
.enabled(true)
.initial_loss_scale(1024.0)
.growth_interval(2)
.growth_factor(2.0)
.build()
.expect("Config should be valid");
let scaler = GradScaler::new(config).expect("Scaler creation should succeed");
let skipped = scaler.update().expect("Update should succeed");
assert!(!skipped);
let skipped = scaler.update().expect("Update should succeed");
assert!(!skipped);
assert!((scaler.get_scale() - 2048.0).abs() < 1e-10);
}
#[test]
fn test_grad_scaler_update_with_overflow() {
    // After an overflow is observed, update() should skip the step and
    // shrink the scale by backoff_factor: 1024 * 0.5 = 512.
    let config = MixedPrecisionConfig::builder()
        .enabled(true)
        .initial_loss_scale(1024.0)
        .backoff_factor(0.5)
        .build()
        .expect("Config should be valid");
    let scaler = GradScaler::new(config).expect("Scaler creation should succeed");
    let grad = ArrayD::from_shape_vec(IxDyn(&[1]), vec![f64::INFINITY])
        .expect("Array creation should succeed");
    let mut gradients = vec![grad];
    scaler
        .unscale_gradients(&mut gradients)
        .expect("Unscale should succeed");
    let skipped = scaler.update().expect("Update should succeed");
    assert!(skipped);
    assert!((scaler.get_scale() - 512.0).abs() < 1e-10);
}
#[test]
fn test_mixed_precision_tensor() {
    // Round-tripping FP32 -> FP16 -> FP32 should track the precision mode
    // and keep small integer values within FP16 tolerance.
    let data = ArrayD::from_shape_vec(IxDyn(&[2, 2]), vec![1.0_f64, 2.0, 3.0, 4.0])
        .expect("Array creation should succeed");
    let mut tensor = MixedPrecisionTensor::from_fp32(data);
    assert_eq!(tensor.precision(), PrecisionMode::FP32);
    tensor.to_fp16().expect("FP16 conversion should succeed");
    assert_eq!(tensor.precision(), PrecisionMode::FP16Mixed);
    tensor.to_fp32().expect("FP32 conversion should succeed");
    assert_eq!(tensor.precision(), PrecisionMode::FP32);
    let fp32 = tensor.get_fp32().expect("Get FP32 should succeed");
    assert!((fp32[[0, 0]] - 1.0).abs() < 0.01);
}
#[test]
fn test_master_weights() {
    // Registering a tensor should make both the FP32 master copy and the
    // lower-precision compute copy retrievable by name.
    let mut weights: MasterWeights<f64> = MasterWeights::new(PrecisionMode::FP16Mixed);
    let values = vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0];
    let w1 = ArrayD::from_shape_vec(IxDyn(&[3, 2]), values)
        .expect("Array creation should succeed");
    weights
        .register("layer1", w1)
        .expect("Register should succeed");
    assert!(weights.contains("layer1"));
    assert_eq!(weights.len(), 1);
    assert!(weights.get_master_weights("layer1").is_some());
    assert!(weights.get_compute_weights("layer1").is_some());
}
#[test]
fn test_auto_mixed_precision() {
    // Entering and exiting the forward context should toggle the AMP state
    // between Normal and Forward.
    let config = MixedPrecisionConfig::builder()
        .enabled(true)
        .initial_loss_scale(1024.0)
        .build()
        .expect("Config should be valid");
    let amp: AutoMixedPrecision<f64> =
        AutoMixedPrecision::new(config).expect("AMP creation should succeed");
    assert!(amp.is_enabled());
    assert_eq!(amp.context_state(), AmpContextState::Normal);
    amp.enter_forward().expect("Enter forward should succeed");
    assert_eq!(amp.context_state(), AmpContextState::Forward);
    amp.exit_forward().expect("Exit forward should succeed");
    assert_eq!(amp.context_state(), AmpContextState::Normal);
}
#[test]
fn test_mixed_precision_trainer_creation() {
    // A fresh trainer starts in eval mode; train()/eval() flip the flag.
    let config = MixedPrecisionConfig::builder()
        .enabled(true)
        .build()
        .expect("Config should be valid");
    let trainer: MixedPrecisionTrainer<f64> =
        MixedPrecisionTrainer::new(config).expect("Trainer creation should succeed");
    assert!(trainer.is_enabled());
    assert!(!trainer.is_training());
    trainer.train();
    assert!(trainer.is_training());
    trainer.eval();
    assert!(!trainer.is_training());
}
#[test]
fn test_utility_functions() {
    // Shared builder for 1-D dynamic arrays used throughout this test.
    let arr1d = |vals: Vec<f64>| -> ArrayD<f64> {
        Array::from_shape_vec(vec![vals.len()], vals)
            .expect("Array creation should succeed")
            .into_dyn()
    };
    // contains_inf_or_nan: finite values pass, NaN is detected.
    assert!(!contains_inf_or_nan(&arr1d(vec![1.0, 2.0])));
    assert!(contains_inf_or_nan(&arr1d(vec![1.0, f64::NAN])));
    // tensor_norm of a 3-4 vector is 5 (Euclidean norm).
    assert!((tensor_norm(&arr1d(vec![3.0, 4.0])) - 5.0).abs() < 1e-10);
    // global_norm over [1,2] and [2,0] is sqrt(1+4+4+0) = 3.
    let tensors = [arr1d(vec![1.0, 2.0]), arr1d(vec![2.0, 0.0])];
    assert!((global_norm(&tensors) - 3.0).abs() < 1e-10);
}
#[test]
fn test_clip_by_value() {
    // Clipping at 2.0 clamps out-of-range values to [-2, 2] and leaves
    // in-range values untouched.
    let mut tensor: ArrayD<f64> = Array::from_shape_vec(vec![4], vec![-5.0, -1.0, 1.0, 5.0])
        .expect("Array creation should succeed")
        .into_dyn();
    clip_by_value(&mut tensor, 2.0).expect("Clip should succeed");
    let expected = [-2.0, -1.0, 1.0, 2.0];
    for (i, want) in expected.iter().enumerate() {
        assert!((tensor[[i]] - want).abs() < 1e-10);
    }
}
#[test]
fn test_clip_by_global_norm() {
    // A [3, 4] vector has global norm 5; clipping to 2.5 rescales it so the
    // new global norm equals the threshold. The pre-clip norm is returned.
    let grad = ArrayD::from_shape_vec(IxDyn(&[2]), vec![3.0, 4.0])
        .expect("Array creation should succeed");
    let mut tensors = vec![grad];
    let original_norm = clip_by_global_norm(&mut tensors, 2.5).expect("Clip should succeed");
    assert!((original_norm - 5.0).abs() < 1e-10);
    assert!((global_norm(&tensors) - 2.5).abs() < 1e-10);
}
#[test]
fn test_fp32_operation_check() {
    // The default config keeps numerically sensitive operations in FP32;
    // embeddings are not part of the default FP32 list.
    let config = MixedPrecisionConfig::default();
    for op in [
        FP32Operation::LossComputation,
        FP32Operation::Softmax,
        FP32Operation::LayerNorm,
    ] {
        assert!(config.should_use_fp32(op));
    }
    assert!(!config.should_use_fp32(FP32Operation::Embedding));
}
#[test]
fn test_grad_scaler_reset() {
    // After an overflow shrinks the scale (1024 -> 512), reset() must
    // restore the configured initial loss scale.
    let config = MixedPrecisionConfig::builder()
        .enabled(true)
        .initial_loss_scale(1024.0)
        .backoff_factor(0.5)
        .build()
        .expect("Config should be valid");
    let scaler = GradScaler::new(config).expect("Scaler creation should succeed");
    let grad = ArrayD::from_shape_vec(IxDyn(&[1]), vec![f64::INFINITY])
        .expect("Array creation should succeed");
    let mut gradients = vec![grad];
    scaler
        .unscale_gradients(&mut gradients)
        .expect("Unscale should succeed");
    scaler.update().expect("Update should succeed");
    assert!((scaler.get_scale() - 512.0).abs() < 1e-10);
    scaler.reset();
    assert!((scaler.get_scale() - 1024.0).abs() < 1e-10);
}
}