#[cfg(feature = "gpu")]
use crate::tensor::TensorStorage;
use crate::{DType, Result, Tensor, TensorError};
use std::collections::HashMap;
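/// Mixed-precision optimization levels, following the O0-O3 convention
/// popularized by NVIDIA Apex: O0 is pure FP32, O1 casts whitelisted
/// operations to half precision while keeping FP32 master weights, O2 runs
/// mostly in half precision, and O3 is full half precision.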
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum OptimizationLevel {
O0,
#[default]
O1,
O2,
O3,
}
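/// Configuration for automatic mixed-precision training: autocast behavior,
/// dynamic loss scaling, FP32 fallback lists, and gradient clipping.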
#[derive(Debug, Clone)]
pub struct MixedPrecisionConfig {
pub enabled: bool,
pub opt_level: OptimizationLevel,
pub loss_scale: f32,
pub dynamic_loss_scaling: bool,
pub min_loss_scale: f32,
pub max_loss_scale: f32,
pub scale_growth_interval: usize,
pub scale_growth_factor: f32,
pub scale_backoff_factor: f32,
pub fp32_operations: Vec<String>,
pub fp16_blacklist: Vec<String>,
pub keep_master_weights: bool,
pub enable_gradient_clipping: bool,
pub gradient_clip_norm: f32,
}
impl Default for MixedPrecisionConfig {
fn default() -> Self {
Self {
enabled: false,
opt_level: OptimizationLevel::O1,
            loss_scale: 65536.0,
            dynamic_loss_scaling: true,
min_loss_scale: 1.0,
            max_loss_scale: 65536.0 * 65536.0,
            scale_growth_interval: 2000,
scale_growth_factor: 2.0,
scale_backoff_factor: 0.5,
fp32_operations: vec![
"softmax".to_string(),
"log_softmax".to_string(),
"cross_entropy".to_string(),
"batch_norm".to_string(),
"layer_norm".to_string(),
],
fp16_blacklist: vec![
"exp".to_string(),
"log".to_string(),
"sqrt".to_string(),
"pow".to_string(),
],
keep_master_weights: true,
enable_gradient_clipping: false,
gradient_clip_norm: 1.0,
}
}
}
impl MixedPrecisionConfig {
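    /// O0 preset: mixed precision disabled; everything runs in FP32.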
pub fn o0() -> Self {
Self {
enabled: false,
opt_level: OptimizationLevel::O0,
..Default::default()
}
}
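    /// O1 preset: conservative mixed precision with dynamic loss scaling and
    /// FP32 master weights.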
pub fn o1() -> Self {
Self {
enabled: true,
opt_level: OptimizationLevel::O1,
loss_scale: 65536.0,
keep_master_weights: true,
..Default::default()
}
}
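    /// O2 preset: more aggressive mixed precision; only normalization ops are
    /// kept in FP32.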
pub fn o2() -> Self {
Self {
enabled: true,
opt_level: OptimizationLevel::O2,
loss_scale: 32768.0,
keep_master_weights: true,
fp32_operations: vec!["batch_norm".to_string(), "layer_norm".to_string()],
..Default::default()
}
}
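    /// O3 preset: full half precision with no FP32 fallbacks or master
    /// weights; gradient clipping is enabled to compensate.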
pub fn o3() -> Self {
Self {
enabled: true,
opt_level: OptimizationLevel::O3,
loss_scale: 16384.0,
keep_master_weights: false,
fp32_operations: vec![],
enable_gradient_clipping: true,
gradient_clip_norm: 1.0,
..Default::default()
}
}
}
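/// Running counters describing loss-scale behavior over a training run.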
#[derive(Debug, Clone, Default)]
pub struct MixedPrecisionStatistics {
pub total_steps: usize,
pub overflow_steps: usize,
pub scale_increases: usize,
pub scale_decreases: usize,
pub cumulative_scale: f64,
pub min_scale_reached: f32,
pub max_scale_reached: f32,
}
impl MixedPrecisionStatistics {
pub fn overflow_rate(&self) -> f64 {
if self.total_steps == 0 {
0.0
} else {
self.overflow_steps as f64 / self.total_steps as f64
}
}
pub fn average_scale(&self) -> f64 {
if self.total_steps == 0 {
0.0
} else {
self.cumulative_scale / self.total_steps as f64
}
}
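    /// Fraction of steps on which the loss scale changed (increases plus
    /// decreases); lower values indicate a more stable scale.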
pub fn scale_stability(&self) -> f64 {
if self.total_steps == 0 {
0.0
} else {
(self.scale_increases + self.scale_decreases) as f64 / self.total_steps as f64
}
}
}
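/// Runtime state for mixed-precision training: the active configuration, the
/// current loss scale, overflow tracking, and accumulated statistics.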
#[derive(Debug)]
pub struct MixedPrecisionState {
pub config: MixedPrecisionConfig,
pub current_loss_scale: f32,
pub steps_since_overflow: usize,
pub overflow_detected: bool,
pub autocast_stack: Vec<DType>,
pub statistics: MixedPrecisionStatistics,
}
impl MixedPrecisionState {
pub fn new(config: MixedPrecisionConfig) -> Self {
let loss_scale = config.loss_scale;
let stats = MixedPrecisionStatistics {
min_scale_reached: loss_scale,
max_scale_reached: loss_scale,
..Default::default()
};
Self {
config,
current_loss_scale: loss_scale,
steps_since_overflow: 0,
overflow_detected: false,
autocast_stack: Vec::new(),
statistics: stats,
}
}
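    /// Updates the dynamic loss scale after a training step. On overflow the
    /// scale is multiplied by `scale_backoff_factor` (clamped to
    /// `min_loss_scale`); after `scale_growth_interval` consecutive clean
    /// steps it is multiplied by `scale_growth_factor` (clamped to
    /// `max_loss_scale`). With the defaults, the scale halves on each
    /// overflow and doubles after 2000 overflow-free steps.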
pub fn update_loss_scale(&mut self, has_overflow: bool) {
self.statistics.total_steps += 1;
self.statistics.cumulative_scale += self.current_loss_scale as f64;
if !self.config.dynamic_loss_scaling {
return;
}
if has_overflow {
let new_scale = (self.current_loss_scale * self.config.scale_backoff_factor)
.max(self.config.min_loss_scale);
self.current_loss_scale = new_scale;
self.steps_since_overflow = 0;
self.overflow_detected = true;
self.statistics.overflow_steps += 1;
self.statistics.scale_decreases += 1;
self.statistics.min_scale_reached = self.statistics.min_scale_reached.min(new_scale);
} else {
self.steps_since_overflow += 1;
self.overflow_detected = false;
if self.steps_since_overflow >= self.config.scale_growth_interval {
let new_scale = (self.current_loss_scale * self.config.scale_growth_factor)
.min(self.config.max_loss_scale);
self.current_loss_scale = new_scale;
self.steps_since_overflow = 0;
self.statistics.scale_increases += 1;
self.statistics.max_scale_reached =
self.statistics.max_scale_reached.max(new_scale);
}
}
}
pub fn get_statistics(&self) -> MixedPrecisionStatistics {
self.statistics.clone()
}
pub fn reset_statistics(&mut self) {
self.statistics = MixedPrecisionStatistics::default();
self.statistics.min_scale_reached = self.current_loss_scale;
self.statistics.max_scale_reached = self.current_loss_scale;
}
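    /// Returns `true` if any gradient tensor contains an infinite or NaN
    /// value.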
pub fn check_gradient_overflow<T>(&self, gradients: &[Tensor<T>]) -> bool
where
T: scirs2_core::num_traits::Float
+ Send
+ Sync
+ 'static
+ bytemuck::Pod
+ bytemuck::Zeroable,
{
for grad in gradients {
if has_inf_or_nan(grad) {
return true;
}
}
false
}
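    /// Multiplies the loss by the current loss scale so that small gradients
    /// survive the half-precision backward pass. A no-op when mixed precision
    /// is disabled.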
pub fn scale_loss<T>(&self, loss: &Tensor<T>) -> Result<Tensor<T>>
where
T: Clone
+ Default
+ scirs2_core::num_traits::Float
+ scirs2_core::num_traits::Zero
+ scirs2_core::num_traits::One
+ Send
+ Sync
+ 'static
+ bytemuck::Pod
+ bytemuck::Zeroable,
{
if !self.config.enabled {
return Ok(loss.clone());
}
let scale_value =
T::from(self.current_loss_scale).ok_or_else(|| TensorError::ComputeError {
operation: "scale_loss".to_string(),
details: format!(
"Failed to convert loss scale {} to target type",
self.current_loss_scale
),
retry_possible: false,
context: None,
})?;
let scale_tensor = Tensor::from_scalar(scale_value);
loss.mul(&scale_tensor)
}
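    /// Divides each gradient by the current loss scale, restoring true
    /// gradient magnitudes before the optimizer step. A no-op when mixed
    /// precision is disabled.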
pub fn unscale_gradients<T>(&self, gradients: &mut [Tensor<T>]) -> Result<()>
where
T: Clone
+ Default
+ scirs2_core::num_traits::Float
+ scirs2_core::num_traits::Zero
+ scirs2_core::num_traits::One
+ Send
+ Sync
+ 'static
+ bytemuck::Pod
+ bytemuck::Zeroable,
{
if !self.config.enabled {
return Ok(());
}
let scale_factor =
T::from(1.0 / self.current_loss_scale).ok_or_else(|| TensorError::ComputeError {
operation: "unscale_gradients".to_string(),
details: format!(
"Failed to convert unscale factor {} to target type",
1.0 / self.current_loss_scale
),
retry_possible: false,
context: None,
})?;
let scale_tensor = Tensor::from_scalar(scale_factor);
for grad in gradients.iter_mut() {
*grad = grad.mul(&scale_tensor)?;
}
Ok(())
}
}
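/// Controls automatic dtype selection for operations inside an autocast
/// region. Numerically sensitive operations (softmax, normalization,
/// exp/log/sqrt) are pinned to FP32; other float operations are cast to the
/// target dtype.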
pub struct AutocastContext {
enabled: bool,
target_dtype: DType,
operation_overrides: HashMap<String, DType>,
}
impl AutocastContext {
pub fn new(enabled: bool, target_dtype: DType) -> Self {
let mut operation_overrides = HashMap::new();
let fp32_ops = vec![
"softmax",
"log_softmax",
"cross_entropy",
"batch_norm",
"layer_norm",
"group_norm",
"exp",
"log",
"sqrt",
];
for op in fp32_ops {
operation_overrides.insert(op.to_string(), DType::Float32);
}
Self {
enabled,
target_dtype,
operation_overrides,
}
}
pub fn get_operation_dtype(&self, operation_name: &str, default_dtype: DType) -> DType {
if !self.enabled {
return default_dtype;
}
if let Some(&override_dtype) = self.operation_overrides.get(operation_name) {
return override_dtype;
}
match default_dtype {
DType::Float32 => self.target_dtype,
            DType::Float64 => DType::Float32,
            _ => default_dtype,
        }
}
pub fn should_autocast(&self, operation_name: &str, input_dtype: DType) -> bool {
self.enabled
&& matches!(input_dtype, DType::Float32 | DType::Float64)
&& !self.operation_overrides.contains_key(operation_name)
}
pub fn is_enabled(&self) -> bool {
self.enabled
}
}
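/// Converts an `f32` tensor element-wise to IEEE 754 half precision (`f16`).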
pub fn to_half_f32(input: &Tensor<f32>) -> Result<Tensor<crate::half_precision::f16>> {
use crate::half_precision::f16;
let data = input.as_slice().ok_or_else(|| {
TensorError::unsupported_operation_simple(
"Cannot access tensor data for conversion".to_string(),
)
})?;
let f16_data: Vec<f16> = data.iter().map(|&v| f16::from_f32(v)).collect();
Tensor::from_vec(f16_data, input.shape().dims())
}
pub fn to_half_f64(input: &Tensor<f64>) -> Result<Tensor<crate::half_precision::f16>> {
use crate::half_precision::f16;
let data = input.as_slice().ok_or_else(|| {
TensorError::unsupported_operation_simple(
"Cannot access tensor data for conversion".to_string(),
)
})?;
let f16_data: Vec<f16> = data.iter().map(|&v| f16::from_f32(v as f32)).collect();
Tensor::from_vec(f16_data, input.shape().dims())
}
pub fn to_bfloat16_f32(input: &Tensor<f32>) -> Result<Tensor<crate::half_precision::bf16>> {
use crate::half_precision::bf16;
let data = input.as_slice().ok_or_else(|| {
TensorError::unsupported_operation_simple(
"Cannot access tensor data for conversion".to_string(),
)
})?;
let bf16_data: Vec<bf16> = data.iter().map(|&v| bf16::from_f32(v)).collect();
Tensor::from_vec(bf16_data, input.shape().dims())
}
pub fn to_bfloat16_f64(input: &Tensor<f64>) -> Result<Tensor<crate::half_precision::bf16>> {
use crate::half_precision::bf16;
let data = input.as_slice().ok_or_else(|| {
TensorError::unsupported_operation_simple(
"Cannot access tensor data for conversion".to_string(),
)
})?;
let bf16_data: Vec<bf16> = data.iter().map(|&v| bf16::from_f32(v as f32)).collect();
Tensor::from_vec(bf16_data, input.shape().dims())
}
pub fn to_half<T>(_input: &Tensor<T>) -> Result<Tensor<crate::half_precision::f16>>
where
T: Clone + Send + Sync + 'static,
{
    Err(TensorError::unsupported_operation_simple(
        "Generic conversion to half precision is not implemented - use to_half_f32 or to_half_f64"
            .to_string(),
    ))
}
pub fn from_half_f32(input: &Tensor<crate::half_precision::f16>) -> Result<Tensor<f32>> {
let data = input.as_slice().ok_or_else(|| {
TensorError::unsupported_operation_simple(
"Cannot access tensor data for conversion".to_string(),
)
})?;
let f32_data: Vec<f32> = data.iter().map(|&v| v.to_f32()).collect();
Tensor::from_vec(f32_data, input.shape().dims())
}
pub fn from_half_f64(input: &Tensor<crate::half_precision::f16>) -> Result<Tensor<f64>> {
let data = input.as_slice().ok_or_else(|| {
TensorError::unsupported_operation_simple(
"Cannot access tensor data for conversion".to_string(),
)
})?;
let f64_data: Vec<f64> = data.iter().map(|&v| v.to_f32() as f64).collect();
Tensor::from_vec(f64_data, input.shape().dims())
}
pub fn from_bfloat16_f32(input: &Tensor<crate::half_precision::bf16>) -> Result<Tensor<f32>> {
let data = input.as_slice().ok_or_else(|| {
TensorError::unsupported_operation_simple(
"Cannot access tensor data for conversion".to_string(),
)
})?;
let f32_data: Vec<f32> = data.iter().map(|&v| v.to_f32()).collect();
Tensor::from_vec(f32_data, input.shape().dims())
}
pub fn from_bfloat16_f64(input: &Tensor<crate::half_precision::bf16>) -> Result<Tensor<f64>> {
let data = input.as_slice().ok_or_else(|| {
TensorError::unsupported_operation_simple(
"Cannot access tensor data for conversion".to_string(),
)
})?;
let f64_data: Vec<f64> = data.iter().map(|&v| v.to_f32() as f64).collect();
Tensor::from_vec(f64_data, input.shape().dims())
}
pub fn from_half<T>(_input: &Tensor<crate::half_precision::f16>) -> Result<Tensor<T>>
where
T: Clone + Send + Sync + 'static,
{
    Err(TensorError::unsupported_operation_simple(
        "Generic conversion from half precision is not implemented - use from_half_f32 or from_half_f64"
            .to_string(),
    ))
}
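/// Loss scaler with dynamic scale adjustment and overflow detection.
///
/// A typical training-step sketch (illustrative only: `loss`,
/// `compute_gradients`, and `apply_update` are hypothetical placeholders,
/// not part of this crate):
///
/// ```ignore
/// let mut scaler = GradientScaler::new(MixedPrecisionConfig::o1());
/// let scaled_loss = scaler.scale(&loss)?;
/// let mut gradients = compute_gradients(&scaled_loss);
/// // Ok(false) means an inf/NaN overflow was detected and the optimizer
/// // step should be skipped; the loss scale has already been reduced.
/// if scaler.unscale_gradients_and_check(&mut gradients)? {
///     apply_update(&gradients);
/// }
/// ```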
#[derive(Debug)]
pub struct GradientScaler {
state: MixedPrecisionState,
}
impl GradientScaler {
pub fn new(config: MixedPrecisionConfig) -> Self {
Self {
state: MixedPrecisionState::new(config),
}
}
pub fn scale<T>(&self, loss: &Tensor<T>) -> Result<Tensor<T>>
where
T: Clone
+ Default
+ scirs2_core::num_traits::Float
+ scirs2_core::num_traits::Zero
+ scirs2_core::num_traits::One
+ Send
+ Sync
+ 'static
+ bytemuck::Pod
+ bytemuck::Zeroable,
{
self.state.scale_loss(loss)
}
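    /// Checks the scaled gradients for inf/NaN, unscales them in place when
    /// they are finite, and updates the loss scale. Returns `Ok(true)` when
    /// the optimizer step may proceed and `Ok(false)` when it should be
    /// skipped because an overflow was detected.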
pub fn unscale_gradients_and_check<T>(&mut self, gradients: &mut [Tensor<T>]) -> Result<bool>
where
T: Clone
+ Default
+ scirs2_core::num_traits::Float
+ scirs2_core::num_traits::Zero
+ scirs2_core::num_traits::One
+ Send
+ Sync
+ 'static
+ bytemuck::Pod
+ bytemuck::Zeroable,
{
let has_overflow = self.state.check_gradient_overflow(gradients);
if has_overflow {
self.state.update_loss_scale(true);
return Ok(false);
}
self.state.unscale_gradients(gradients)?;
self.state.update_loss_scale(false);
Ok(true)
}
pub fn get_scale(&self) -> f32 {
self.state.current_loss_scale
}
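    /// Replaces the active configuration. If the configured base loss scale
    /// changed, the current scale is reset to the new value.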
pub fn update_config(&mut self, config: MixedPrecisionConfig) {
if config.loss_scale != self.state.config.loss_scale {
self.state.current_loss_scale = config.loss_scale;
}
self.state.config = config;
}
pub fn is_enabled(&self) -> bool {
self.state.config.enabled
}
pub fn get_config(&self) -> MixedPrecisionConfig {
self.state.config.clone()
}
pub fn get_statistics(&self) -> MixedPrecisionStatistics {
self.state.get_statistics()
}
pub fn reset(&mut self) {
self.state.current_loss_scale = self.state.config.loss_scale;
self.state.steps_since_overflow = 0;
self.state.overflow_detected = false;
self.state.reset_statistics();
}
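    /// Clips gradients in place to `gradient_clip_norm` by global L2 norm and
    /// returns the pre-clip norm; returns 0.0 when clipping is disabled.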
pub fn clip_gradients(&self, gradients: &mut [f32]) -> f32 {
if !self.state.config.enable_gradient_clipping {
return 0.0;
}
let global_norm = gradients.iter().map(|&g| g * g).sum::<f32>().sqrt();
if global_norm > self.state.config.gradient_clip_norm {
let clip_coef = self.state.config.gradient_clip_norm / global_norm;
for grad in gradients.iter_mut() {
*grad *= clip_coef;
}
}
global_norm
}
}
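/// Stores FP32 master copies of model weights so that small optimizer updates
/// are not lost to half-precision rounding. Updates are applied to the master
/// copies, which can be copied back to the working weights via
/// `copy_to_working`.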
#[derive(Debug, Clone)]
pub struct MasterWeightsManager {
master_weights: HashMap<String, Vec<f32>>,
enabled: bool,
}
impl MasterWeightsManager {
pub fn new(enabled: bool) -> Self {
Self {
master_weights: HashMap::new(),
enabled,
}
}
pub fn store(&mut self, name: String, weights: Vec<f32>) {
if !self.enabled {
return;
}
self.master_weights.insert(name, weights);
}
pub fn retrieve(&self, name: &str) -> Option<&Vec<f32>> {
self.master_weights.get(name)
}
pub fn update_from_gradients(&mut self, name: &str, gradients: &[f32], learning_rate: f32) {
if !self.enabled {
return;
}
if let Some(master) = self.master_weights.get_mut(name) {
for (weight, &grad) in master.iter_mut().zip(gradients.iter()) {
*weight -= learning_rate * grad;
}
}
}
pub fn copy_to_working(&self, name: &str) -> Option<Vec<f32>> {
self.master_weights.get(name).cloned()
}
pub fn clear(&mut self) {
self.master_weights.clear();
}
pub fn count(&self) -> usize {
self.master_weights.len()
}
pub fn is_enabled(&self) -> bool {
self.enabled
}
}
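/// Returns `true` if any element of the tensor is infinite or NaN. CPU
/// tensors are scanned directly; with the `gpu` feature enabled, GPU-resident
/// tensors use an inf/NaN reduction kernel.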
fn has_inf_or_nan<T>(tensor: &Tensor<T>) -> bool
where
T: scirs2_core::num_traits::Float + Send + Sync + 'static + bytemuck::Pod + bytemuck::Zeroable,
{
if let Some(data) = tensor.as_slice() {
return data
.iter()
.any(|&value| value.is_infinite() || value.is_nan());
}
#[cfg(feature = "gpu")]
{
if tensor.device().is_gpu() {
            use crate::gpu::ReductionOp;
if let TensorStorage::Gpu(ref gpu_buffer) = tensor.storage {
match crate::gpu::ops::execute_reduction_op(
gpu_buffer,
ReductionOp::InfNanDetection,
                    None,
                ) {
Ok(result_buffer) => {
match result_buffer.to_cpu() {
Ok(result_data) => {
if !result_data.is_empty() {
return !result_data[0].is_zero();
}
}
Err(_) => {
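                                // Host transfer failed; report no overflow
                                // rather than aborting the step.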
return false;
}
}
}
Err(_) => {
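                        // GPU reduction failed; report no overflow rather
                        // than aborting the step.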
return false;
}
}
}
}
false
}
#[cfg(not(feature = "gpu"))]
{
false
}
}
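/// Creates an autocast context targeting FP16.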
pub fn enable_autocast() -> AutocastContext {
AutocastContext::new(true, DType::Float16)
}
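/// Creates an autocast context targeting bfloat16.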
pub fn enable_autocast_bfloat16() -> AutocastContext {
AutocastContext::new(true, DType::BFloat16)
}
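/// Creates a disabled autocast context; operations keep their default dtypes.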
pub fn disable_autocast() -> AutocastContext {
AutocastContext::new(false, DType::Float32)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_mixed_precision_config() {
let config = MixedPrecisionConfig::default();
assert!(!config.enabled);
assert_eq!(config.loss_scale, 65536.0);
assert!(config.dynamic_loss_scaling);
}
#[test]
fn test_autocast_context() {
let ctx = enable_autocast();
assert!(ctx.enabled);
assert_eq!(
ctx.get_operation_dtype("softmax", DType::Float32),
DType::Float32
);
assert_eq!(
ctx.get_operation_dtype("conv2d", DType::Float32),
DType::Float16
);
let ctx_bf16 = enable_autocast_bfloat16();
assert_eq!(
ctx_bf16.get_operation_dtype("conv2d", DType::Float32),
DType::BFloat16
);
}
#[test]
fn test_gradient_scaler() {
let config = MixedPrecisionConfig {
enabled: true,
..Default::default()
};
let scaler = GradientScaler::new(config);
assert_eq!(scaler.get_scale(), 65536.0);
}
#[test]
fn test_half_precision_conversions() {
use crate::Tensor;
let f32_tensor = Tensor::<f32>::from_vec(vec![1.0, 2.5, -std::f32::consts::PI], &[3])
.expect("test: from_vec should succeed");
let f16_tensor = to_half_f32(&f32_tensor).expect("test: to_half_f32 should succeed");
let f32_back = from_half_f32(&f16_tensor).expect("test: from_half_f32 should succeed");
let original_data = f32_tensor.as_slice().expect("tensor should be contiguous");
let converted_data = f32_back.as_slice().expect("tensor should be contiguous");
for (orig, conv) in original_data.iter().zip(converted_data.iter()) {
assert!(
(orig - conv).abs() < 0.01,
"f16 conversion error too large: {} vs {}",
orig,
conv
);
}
let bf16_tensor =
to_bfloat16_f32(&f32_tensor).expect("test: to_bfloat16_f32 should succeed");
let f32_back_bf16 =
from_bfloat16_f32(&bf16_tensor).expect("test: from_bfloat16_f32 should succeed");
let bf16_data = f32_back_bf16
.as_slice()
.expect("tensor should be contiguous");
for (orig, conv) in original_data.iter().zip(bf16_data.iter()) {
assert!(
(orig - conv).abs() < 0.001,
"bf16 conversion error too large: {} vs {}",
orig,
conv
);
}
}
#[test]
fn test_mixed_precision_dtype_mapping() {
use crate::DType;
assert_eq!(DType::Float16.size(), 2);
assert_eq!(DType::BFloat16.size(), 2);
assert_eq!(DType::Float32.size(), 4);
assert_eq!(DType::Float16.name(), "float16");
assert_eq!(DType::BFloat16.name(), "bfloat16");
}
#[test]
fn test_optimization_levels() {
let o0 = MixedPrecisionConfig::o0();
assert_eq!(o0.opt_level, OptimizationLevel::O0);
assert!(!o0.enabled);
let o1 = MixedPrecisionConfig::o1();
assert_eq!(o1.opt_level, OptimizationLevel::O1);
assert!(o1.enabled);
assert!(o1.keep_master_weights);
let o2 = MixedPrecisionConfig::o2();
assert_eq!(o2.opt_level, OptimizationLevel::O2);
assert!(o2.enabled);
assert_eq!(o2.fp32_operations.len(), 2);
let o3 = MixedPrecisionConfig::o3();
assert_eq!(o3.opt_level, OptimizationLevel::O3);
assert!(o3.enabled);
assert!(!o3.keep_master_weights);
assert!(o3.enable_gradient_clipping);
        assert_eq!(o3.fp32_operations.len(), 0);
    }
#[test]
fn test_mixed_precision_statistics() {
let stats = MixedPrecisionStatistics {
total_steps: 100,
overflow_steps: 5,
scale_increases: 10,
scale_decreases: 5,
            cumulative_scale: 6553600.0,
            ..Default::default()
};
assert_eq!(stats.overflow_rate(), 0.05);
assert_eq!(stats.average_scale(), 65536.0);
assert_eq!(stats.scale_stability(), 0.15);
}
#[test]
fn test_loss_scale_update_with_statistics() {
let config = MixedPrecisionConfig {
enabled: true,
dynamic_loss_scaling: true,
loss_scale: 1024.0,
scale_growth_factor: 2.0,
scale_backoff_factor: 0.5,
scale_growth_interval: 2,
..Default::default()
};
let mut state = MixedPrecisionState::new(config);
let initial_scale = state.current_loss_scale;
state.update_loss_scale(true);
assert_eq!(state.current_loss_scale, initial_scale * 0.5);
assert_eq!(state.statistics.overflow_steps, 1);
assert_eq!(state.statistics.scale_decreases, 1);
assert_eq!(state.statistics.min_scale_reached, initial_scale * 0.5);
state.update_loss_scale(false);
state.update_loss_scale(false);
        assert_eq!(state.current_loss_scale, initial_scale);
        assert_eq!(state.statistics.scale_increases, 1);
}
#[test]
fn test_gradient_clipping() {
let config = MixedPrecisionConfig {
enabled: true,
enable_gradient_clipping: true,
gradient_clip_norm: 1.0,
..Default::default()
};
let scaler = GradientScaler::new(config);
let mut gradients = vec![0.6, 0.8];
let norm = scaler.clip_gradients(&mut gradients);
assert!((norm - 1.0).abs() < 1e-5);
assert!((gradients[0] - 0.6).abs() < 1e-5);
let mut large_gradients = vec![3.0, 4.0];
let large_norm = scaler.clip_gradients(&mut large_gradients);
assert!((large_norm - 5.0).abs() < 1e-5);
let clipped_norm: f32 = large_gradients.iter().map(|&g| g * g).sum::<f32>().sqrt();
assert!((clipped_norm - 1.0).abs() < 1e-5);
}
#[test]
fn test_master_weights_manager() {
let mut manager = MasterWeightsManager::new(true);
manager.store("layer1".to_string(), vec![1.0, 2.0, 3.0]);
manager.store("layer2".to_string(), vec![4.0, 5.0]);
assert_eq!(manager.count(), 2);
assert!(manager.is_enabled());
let weights = manager.retrieve("layer1");
assert_eq!(weights, Some(&vec![1.0, 2.0, 3.0]));
manager.update_from_gradients("layer1", &[0.1, 0.2, 0.3], 0.1);
let updated = manager
.retrieve("layer1")
.expect("test: retrieve should succeed");
        assert!((updated[0] - 0.99).abs() < 1e-5);
        assert!((updated[1] - 1.98).abs() < 1e-5);
        assert!((updated[2] - 2.97).abs() < 1e-5);
let working_copy = manager.copy_to_working("layer1");
assert!(working_copy.is_some());
manager.clear();
assert_eq!(manager.count(), 0);
}
#[test]
fn test_scaler_reset() {
let config = MixedPrecisionConfig {
enabled: true,
loss_scale: 1024.0,
..Default::default()
};
let mut scaler = GradientScaler::new(config);
scaler.state.current_loss_scale = 512.0;
scaler.state.steps_since_overflow = 10;
scaler.state.statistics.total_steps = 100;
scaler.reset();
assert_eq!(scaler.get_scale(), 1024.0);
assert_eq!(scaler.state.steps_since_overflow, 0);
assert_eq!(scaler.state.statistics.total_steps, 0);
}
#[test]
fn test_get_scaler_statistics() {
let config = MixedPrecisionConfig {
enabled: true,
..Default::default()
};
let mut scaler = GradientScaler::new(config);
scaler.state.update_loss_scale(false);
scaler.state.update_loss_scale(true);
let stats = scaler.get_statistics();
assert_eq!(stats.total_steps, 2);
assert_eq!(stats.overflow_steps, 1);
}
#[test]
fn test_disabled_master_weights() {
let mut manager = MasterWeightsManager::new(false);
manager.store("layer1".to_string(), vec![1.0, 2.0]);
assert_eq!(manager.count(), 0);
assert!(!manager.is_enabled());
}
}