mod autocast;
mod dtype_utils;
mod grad_scaler;
mod optimizer_wrapper;
pub use autocast::{autocast, maybe_autocast_f32, AutocastContext, AutocastMode};
pub use dtype_utils::{
cast_bf16_to_fp32, cast_tensor, cast_to_bf16, cast_to_fp16, cast_to_fp32, MixedPrecisionTensor,
};
pub use grad_scaler::{GradScaler, ScalerState, ScalerStats, StepResult};
pub use optimizer_wrapper::{AMPOptimizer, ParamGroup, TrainingStats};
use crate::dtype::DType;
use std::sync::{Arc, RwLock};
/// Configuration for automatic mixed precision (AMP).
///
/// `AMPConfig::default()` yields an FP16 configuration with dynamic loss
/// scaling; [`AMPConfig::bf16`] and [`AMPConfig::fp16_static`] build presets
/// with dynamic loss scaling turned off.
///
/// The type derives `Debug`, `Clone` and `PartialEq` so a configuration can be
/// logged, duplicated and compared without copying every field by hand.
#[derive(Debug, Clone, PartialEq)]
pub struct AMPConfig {
    /// Enable flag carried with the configuration.
    ///
    /// NOTE(review): `enable_amp` tracks a separate flag in the global state
    /// and stores this field untouched, so the two can disagree — confirm
    /// which one consumers are expected to read.
    pub enabled: bool,
    /// Data type selected for mixed-precision work.
    pub dtype: DType,
    /// Initial scale value (presumably the starting loss scale used by
    /// `GradScaler` — confirm against that module).
    pub init_scale: f32,
    /// Multiplier used when the scale grows (semantics presumably defined by
    /// `GradScaler` — confirm).
    pub growth_factor: f32,
    /// Multiplier used when the scale backs off (semantics presumably defined
    /// by `GradScaler` — confirm).
    pub backoff_factor: f32,
    /// Interval associated with scale growth (presumably counted in steps —
    /// confirm against `GradScaler`).
    pub growth_interval: usize,
    /// Whether the loss scale is adjusted dynamically rather than kept fixed.
    pub dynamic_loss_scaling: bool,
}
impl Default for AMPConfig {
    /// Builds the stock configuration: enabled, `DType::Float16`, dynamic
    /// loss scaling on, and the numeric defaults named below.
    fn default() -> Self {
        // 65536 == 2^16.
        const INIT_SCALE: f32 = 65536.0;
        const GROWTH_FACTOR: f32 = 2.0;
        const BACKOFF_FACTOR: f32 = 0.5;
        const GROWTH_INTERVAL: usize = 2000;

        let init_scale = INIT_SCALE;
        let growth_factor = GROWTH_FACTOR;
        let backoff_factor = BACKOFF_FACTOR;
        let growth_interval = GROWTH_INTERVAL;

        Self {
            enabled: true,
            dtype: DType::Float16,
            init_scale,
            growth_factor,
            backoff_factor,
            growth_interval,
            dynamic_loss_scaling: true,
        }
    }
}
impl AMPConfig {
pub fn bf16() -> Self {
Self {
dtype: DType::BFloat16,
init_scale: 1.0,
dynamic_loss_scaling: false,
..Default::default()
}
}
pub fn fp16_static(scale: f32) -> Self {
Self {
dtype: DType::Float16,
init_scale: scale,
dynamic_loss_scaling: false,
..Default::default()
}
}
}
// Module-wide AMP state. `lazy_static!` initialises it on first access from
// `AMPState::default()`. `enable_amp` and `disable_amp` take the write lock;
// `is_amp_enabled` and `get_amp_config` take the read lock.
lazy_static::lazy_static! {
static ref AMP_STATE: Arc<RwLock<AMPState>> = Arc::new(RwLock::new(AMPState::default()));
}
/// Mutable AMP state stored behind the module's global lock.
struct AMPState {
    /// Global on/off switch, flipped by `enable_amp` / `disable_amp` and
    /// reported by `is_amp_enabled`.
    enabled: bool,

    /// Autocast mode slot. Within this file it is only set by `Default`
    /// (to `AutocastMode::None`); the leading underscore marks it as
    /// intentionally unused here.
    _autocast_mode: AutocastMode,

    /// Configuration most recently handed to `enable_amp`, or the default
    /// configuration if it was never called.
    config: AMPConfig,
}
impl Default for AMPState {
    /// Initial state: AMP switched off, `AutocastMode::None`, and the default
    /// `AMPConfig`.
    fn default() -> Self {
        let config = AMPConfig::default();
        Self {
            enabled: false,
            _autocast_mode: AutocastMode::None,
            config,
        }
    }
}
/// Switches AMP on globally and replaces the stored configuration with
/// `config`.
///
/// Only the global flag is set to `true`; `config.enabled` is stored exactly
/// as passed in.
///
/// # Panics
///
/// Panics if the global state lock is poisoned.
pub fn enable_amp(config: AMPConfig) {
    let mut guard = AMP_STATE.write().unwrap();
    let state = &mut *guard;
    state.enabled = true;
    state.config = config;
}
/// Switches AMP off globally. The stored configuration is left as it is.
///
/// # Panics
///
/// Panics if the global state lock is poisoned.
pub fn disable_amp() {
    // The write guard is a temporary and is released at the end of the statement.
    AMP_STATE.write().unwrap().enabled = false;
}
/// Returns the current value of the global AMP on/off flag.
///
/// # Panics
///
/// Panics if the global state lock is poisoned.
pub fn is_amp_enabled() -> bool {
    let guard = AMP_STATE.read().unwrap();
    guard.enabled
}
/// Returns a field-by-field copy of the configuration currently stored in the
/// global AMP state.
///
/// The returned value is independent of the global state; later calls to
/// `enable_amp` do not affect it.
///
/// # Panics
///
/// Panics if the global state lock is poisoned.
pub fn get_amp_config() -> AMPConfig {
    let guard = AMP_STATE.read().unwrap();

    // Exhaustive destructuring: adding a field to `AMPConfig` makes this fail
    // to compile instead of silently dropping the new field.
    let AMPConfig {
        enabled,
        dtype,
        init_scale,
        growth_factor,
        backoff_factor,
        growth_interval,
        dynamic_loss_scaling,
    } = &guard.config;

    AMPConfig {
        enabled: *enabled,
        dtype: *dtype,
        init_scale: *init_scale,
        growth_factor: *growth_factor,
        backoff_factor: *backoff_factor,
        growth_interval: *growth_interval,
        dynamic_loss_scaling: *dynamic_loss_scaling,
    }
}
/// Helpers for choosing the precision in which to run an operation.
pub mod utils {
    use super::*;

    /// Reports whether the operation called `op_name` may run in reduced
    /// precision.
    ///
    /// Returns `false` only for the operations this module keeps in FP32
    /// (listed in the pattern below); every other name, including unknown or
    /// empty ones, returns `true`. Matching is exact and case-sensitive.
    pub fn should_use_reduced_precision(op_name: &str) -> bool {
        let keep_in_fp32 = matches!(
            op_name,
            "softmax"
                | "log_softmax"
                | "cross_entropy"
                | "nll_loss"
                | "batch_norm"
                | "layer_norm"
        );
        !keep_in_fp32
    }

    /// Picks a dtype from the support probes in this module: BF16 first,
    /// then FP16, then FP32 when neither is reported.
    pub fn get_optimal_dtype() -> DType {
        match (has_bf16_support(), has_fp16_support()) {
            (true, _) => DType::BFloat16,
            (false, true) => DType::Float16,
            (false, false) => DType::Float32,
        }
    }

    /// Reports FP16 support. Currently always `true`.
    pub fn has_fp16_support() -> bool {
        true
    }

    /// Reports BF16 support, decided at compile time from the target
    /// architecture: `true` on `aarch64`, `false` everywhere else.
    pub fn has_bf16_support() -> bool {
        cfg!(target_arch = "aarch64")
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration is FP16 with a 65536 initial scale and
    /// dynamic loss scaling switched on.
    #[test]
    fn test_amp_config_default() {
        let AMPConfig {
            dtype,
            init_scale,
            dynamic_loss_scaling,
            ..
        } = AMPConfig::default();

        assert_eq!(dtype, DType::Float16);
        assert_eq!(init_scale, 65536.0);
        assert!(dynamic_loss_scaling);
    }

    /// The BF16 preset uses a unit scale and no dynamic loss scaling.
    #[test]
    fn test_amp_config_bf16() {
        let AMPConfig {
            dtype,
            init_scale,
            dynamic_loss_scaling,
            ..
        } = AMPConfig::bf16();

        assert_eq!(dtype, DType::BFloat16);
        assert_eq!(init_scale, 1.0);
        assert!(!dynamic_loss_scaling);
    }

    /// The global flag follows `enable_amp` / `disable_amp`.
    #[test]
    fn test_amp_state() {
        // Start from a known-off state, since the flag is process-global.
        disable_amp();
        assert!(!is_amp_enabled());

        enable_amp(AMPConfig::default());
        assert!(is_amp_enabled());

        // Leave the global state switched off again.
        disable_amp();
        assert!(!is_amp_enabled());
    }

    /// Compute-heavy ops may use reduced precision; listed FP32 ops may not.
    #[test]
    fn test_should_use_reduced_precision() {
        for op in ["matmul", "conv2d"].iter() {
            assert!(utils::should_use_reduced_precision(op));
        }
        for op in ["softmax", "batch_norm"].iter() {
            assert!(!utils::should_use_reduced_precision(op));
        }
    }
}