use std::collections::{HashMap, HashSet};
use crate::error::Result;
use crate::optimizer::grad_scaler::{GradScaler, UnscaleResult};
use crate::trainer::config::{MixedPrecisionConfig, TrainingConfig, TrainingMetrics};
use crate::trainer::simple::SimpleTrainer;
use numr::autograd::GradStore;
use numr::dtype::DType;
use numr::ops::{BinaryOps, ReduceOps, ScalarOps, TypeConversionOps, UnaryOps};
use crate::ops::FusedOptimizerOps;
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::{Tensor, TensorId};
/// Mixed-precision (AMP) trainer: keeps F32 "master" copies of all
/// parameters, hands lower-precision casts to the forward pass, and runs
/// the optimizer step in F32 via the wrapped [`SimpleTrainer`].
pub struct AmpTrainer<R: Runtime<DType = DType>> {
// Inner trainer that performs the actual optimizer update in F32.
trainer: SimpleTrainer<R>,
// Authoritative F32 copies of every parameter, keyed by tensor id.
master_params: HashMap<TensorId, Tensor<R>>,
// Dtype used for forward/backward compute (e.g. F16/BF16), unless a
// tensor is listed in `fp32_overrides`.
compute_dtype: DType,
// Loss scaler for overflow-safe low-precision gradients; `None` when the
// configured strategy disables scaling.
grad_scaler: Option<GradScaler>,
// Tensor ids that must stay F32 in `compute_params` regardless of
// `compute_dtype` (per-layer precision policy).
fp32_overrides: HashSet<TensorId>,
}
impl<R: Runtime<DType = DType>> AmpTrainer<R> {
    /// Creates an AMP trainer from a training config, a mixed-precision
    /// config, and the initial model parameters.
    ///
    /// Each parameter is stored as an F32 master copy: F32 inputs are moved
    /// in unchanged, anything else is cast up to F32.
    ///
    /// # Errors
    ///
    /// Propagates failures from [`SimpleTrainer::new`], dtype casting, or
    /// grad-scaler construction.
    pub fn new<C>(
        config: TrainingConfig,
        amp_config: MixedPrecisionConfig,
        initial_params: HashMap<TensorId, Tensor<R>>,
        client: &C,
    ) -> Result<Self>
    where
        C: RuntimeClient<R> + TypeConversionOps<R>,
    {
        let trainer = SimpleTrainer::new(config)?;
        let mut master_params = HashMap::with_capacity(initial_params.len());
        for (id, param) in initial_params {
            let master = if param.dtype() == DType::F32 {
                // Already F32: take ownership directly, no copy needed.
                param
            } else {
                // Fix: argument was mangled to `¶m`; it must be `&param`.
                client.cast(&param, DType::F32)?
            };
            master_params.insert(id, master);
        }
        let grad_scaler = amp_config.loss_scale.to_grad_scaler()?;
        Ok(Self {
            trainer,
            master_params,
            compute_dtype: amp_config.compute_dtype,
            grad_scaler,
            fp32_overrides: HashSet::new(),
        })
    }

    /// Returns compute-dtype copies of the master parameters for the
    /// forward pass. Tensors in the FP32 override set stay F32; masters
    /// already at the target dtype are cloned without a cast.
    ///
    /// # Errors
    ///
    /// Propagates any dtype-cast failure from the client.
    pub fn compute_params<C>(&self, client: &C) -> Result<HashMap<TensorId, Tensor<R>>>
    where
        C: RuntimeClient<R> + TypeConversionOps<R>,
    {
        let mut compute = HashMap::with_capacity(self.master_params.len());
        for (&id, master) in &self.master_params {
            let target_dtype = if self.fp32_overrides.contains(&id) {
                DType::F32
            } else {
                self.compute_dtype
            };
            let param = if master.dtype() == target_dtype {
                master.clone()
            } else {
                client.cast(master, target_dtype)?
            };
            compute.insert(id, param);
        }
        Ok(compute)
    }

    /// Runs one optimizer step against the F32 master parameters.
    ///
    /// If a grad scaler is active, gradients are unscaled first; on
    /// overflow the step is skipped (`Ok(None)`) and the loss scale is
    /// backed off. Gradients are cast to F32 before the inner trainer
    /// update, and a successful step lets the scaler grow/hold its scale.
    ///
    /// # Errors
    ///
    /// Propagates unscale, cast, or inner-trainer failures, and reports a
    /// `TrainingError` if a gradient goes missing mid-iteration.
    pub fn step<C>(
        &mut self,
        client: &C,
        grads: GradStore<R>,
        loss_value: f64,
    ) -> Result<Option<TrainingMetrics>>
    where
        C: RuntimeClient<R>
            + BinaryOps<R>
            + UnaryOps<R>
            + ScalarOps<R>
            + ReduceOps<R>
            + FusedOptimizerOps<R>
            + TypeConversionOps<R>,
    {
        let grads = if let Some(ref scaler) = self.grad_scaler {
            match scaler.unscale_grads(client, grads)? {
                UnscaleResult::Ok(unscaled) => unscaled,
                UnscaleResult::Overflow => {
                    // Non-finite gradients: drop this step and shrink the
                    // scale. Re-matched mutably because `scaler` above is an
                    // immutable borrow of the same Option.
                    if let Some(ref mut s) = self.grad_scaler {
                        s.update_scale(true);
                    }
                    return Ok(None);
                }
            }
        } else {
            grads
        };
        // Optimizer math always runs in F32, regardless of compute dtype.
        let mut fp32_grads = GradStore::new();
        let ids: Vec<TensorId> = grads.keys().copied().collect();
        for id in ids {
            let grad = grads
                .get(id)
                .ok_or_else(|| crate::error::Error::TrainingError {
                    reason: format!("missing gradient for tensor {id:?}"),
                })?;
            let fp32_grad = if grad.dtype() == DType::F32 {
                grad.clone()
            } else {
                client.cast(grad, DType::F32)?
            };
            fp32_grads.insert(id, fp32_grad);
        }
        let result = self
            .trainer
            .step(client, &mut self.master_params, fp32_grads, loss_value)?;
        if let Some(ref mut scaler) = self.grad_scaler {
            // Step succeeded: allow the scaler to grow (or hold) its scale.
            scaler.update_scale(false);
        }
        Ok(result)
    }

    /// Replaces the entire FP32 override set.
    pub fn set_fp32_overrides(&mut self, ids: HashSet<TensorId>) {
        self.fp32_overrides = ids;
    }

    /// Forces a single tensor to stay F32 in `compute_params`.
    pub fn add_fp32_override(&mut self, id: TensorId) {
        self.fp32_overrides.insert(id);
    }

    /// Removes a tensor from the FP32 override set (no-op if absent).
    pub fn remove_fp32_override(&mut self, id: &TensorId) {
        self.fp32_overrides.remove(id);
    }

    /// Borrows the current FP32 override set.
    pub fn fp32_overrides(&self) -> &HashSet<TensorId> {
        &self.fp32_overrides
    }

    /// Current loss scale; 1.0 when no scaler is configured.
    pub fn loss_scale(&self) -> f64 {
        self.grad_scaler.as_ref().map_or(1.0, |s| s.scale())
    }

    /// Scales a loss value for backward; identity when no scaler is set.
    pub fn scale_loss(&self, loss: f64) -> f64 {
        self.grad_scaler
            .as_ref()
            .map_or(loss, |s| s.scale_loss(loss))
    }

    /// Borrows the F32 master parameters.
    pub fn master_params(&self) -> &HashMap<TensorId, Tensor<R>> {
        &self.master_params
    }

    /// Dtype used for forward/backward compute.
    pub fn compute_dtype(&self) -> DType {
        self.compute_dtype
    }

    /// Number of optimizer steps taken by the inner trainer.
    pub fn global_step(&self) -> u64 {
        self.trainer.global_step()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::cpu_setup;
    use crate::trainer::config::LossScaleStrategy;
    use numr::autograd::{Var, backward, var_mean, var_mul, var_sub};
    use numr::ops::TypeConversionOps;
    use numr::runtime::cpu::CpuRuntime;
    use std::collections::HashSet;

    /// Twenty SGD-style steps on an MSE objective with F64 compute should
    /// cut the loss by at least 10x.
    #[test]
    fn test_amp_trainer_f64_compute_converges() {
        let (client, device) = cpu_setup();
        let target = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 0.0, 0.0, 1.0], &[2, 2], &device);
        let weights = Tensor::<CpuRuntime>::from_slice(&[0.0f32, 0.0, 0.0, 0.0], &[2, 2], &device);
        let weight_id = weights.id();
        let mut initial_params = HashMap::new();
        initial_params.insert(weight_id, weights);
        let config = TrainingConfig::default()
            .with_lr(0.1)
            .with_weight_decay(0.0)
            .with_max_grad_norm(None);
        let amp_config = MixedPrecisionConfig {
            compute_dtype: DType::F64,
            master_dtype: DType::F32,
            loss_scale: LossScaleStrategy::None,
        };
        let mut trainer = AmpTrainer::new(config, amp_config, initial_params, &client).unwrap();
        let mut losses = Vec::with_capacity(20);
        for _ in 0..20 {
            let compute_params = trainer.compute_params(&client).unwrap();
            let weight_tensor = compute_params.get(&weight_id).unwrap().clone();
            // Forward pass must receive the configured compute dtype.
            assert_eq!(weight_tensor.dtype(), DType::F64);
            let w = Var::with_id(weight_tensor, weight_id, true);
            let target_f64 = client.cast(&target, DType::F64).unwrap();
            let t = Var::new(target_f64, false);
            // MSE loss: mean((w - t)^2).
            let diff = var_sub(&w, &t, &client).unwrap();
            let squared = var_mul(&diff, &diff, &client).unwrap();
            let loss = var_mean(&squared, &[], false, &client).unwrap();
            let loss_val = loss.tensor().to_vec::<f64>()[0];
            losses.push(loss_val);
            let grads = backward(&loss, &client).unwrap();
            trainer.step(&client, grads, loss_val).unwrap();
        }
        let first_loss = losses[0];
        let last_loss = *losses.last().unwrap();
        assert!(
            last_loss < first_loss * 0.1,
            "loss should decrease: first={first_loss} last={last_loss}"
        );
    }

    /// bf16 needs no loss scaling (identity); fp16 defaults to 65536.
    #[test]
    fn test_amp_trainer_loss_scale() {
        let (client, device) = cpu_setup();
        let weight = Tensor::<CpuRuntime>::from_slice(&[1.0f32], &[1], &device);
        let mut params = HashMap::new();
        params.insert(weight.id(), weight);
        let bf16_trainer = AmpTrainer::<CpuRuntime>::new(
            TrainingConfig::default(),
            MixedPrecisionConfig::bf16(),
            params.clone(),
            &client,
        )
        .unwrap();
        assert_eq!(bf16_trainer.loss_scale(), 1.0);
        assert_eq!(bf16_trainer.scale_loss(2.5), 2.5);
        let fp16_trainer = AmpTrainer::<CpuRuntime>::new(
            TrainingConfig::default(),
            MixedPrecisionConfig::fp16(),
            params,
            &client,
        )
        .unwrap();
        assert_eq!(fp16_trainer.loss_scale(), 65536.0);
        assert_eq!(fp16_trainer.scale_loss(1.0), 65536.0);
    }

    /// An FP32 override pins one tensor to F32 while others follow the
    /// compute dtype; removing it restores the compute dtype.
    #[test]
    fn test_per_layer_precision_policy() {
        let (client, device) = cpu_setup();
        let first = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0], &[2], &device);
        let second = Tensor::<CpuRuntime>::from_slice(&[3.0f32, 4.0], &[2], &device);
        let id1 = first.id();
        let id2 = second.id();
        let mut params = HashMap::new();
        params.insert(id1, first);
        params.insert(id2, second);
        let amp_config = MixedPrecisionConfig {
            compute_dtype: DType::F64,
            master_dtype: DType::F32,
            loss_scale: LossScaleStrategy::None,
        };
        let mut trainer =
            AmpTrainer::<CpuRuntime>::new(TrainingConfig::default(), amp_config, params, &client)
                .unwrap();
        trainer.add_fp32_override(id1);
        let compute = trainer.compute_params(&client).unwrap();
        assert_eq!(compute.get(&id1).unwrap().dtype(), DType::F32);
        assert_eq!(compute.get(&id2).unwrap().dtype(), DType::F64);
        trainer.remove_fp32_override(&id1);
        let compute = trainer.compute_params(&client).unwrap();
        assert_eq!(compute.get(&id1).unwrap().dtype(), DType::F64);
        assert_eq!(compute.get(&id2).unwrap().dtype(), DType::F64);
    }

    /// Bulk-setting the override set applies to exactly the listed ids.
    #[test]
    fn test_fp32_override_set_bulk() {
        let (client, device) = cpu_setup();
        let a = Tensor::<CpuRuntime>::from_slice(&[1.0f32], &[1], &device);
        let b = Tensor::<CpuRuntime>::from_slice(&[2.0f32], &[1], &device);
        let c = Tensor::<CpuRuntime>::from_slice(&[3.0f32], &[1], &device);
        let id1 = a.id();
        let id2 = b.id();
        let id3 = c.id();
        let mut params = HashMap::new();
        params.insert(id1, a);
        params.insert(id2, b);
        params.insert(id3, c);
        let amp_config = MixedPrecisionConfig {
            compute_dtype: DType::F64,
            master_dtype: DType::F32,
            loss_scale: LossScaleStrategy::None,
        };
        let mut trainer =
            AmpTrainer::<CpuRuntime>::new(TrainingConfig::default(), amp_config, params, &client)
                .unwrap();
        let overrides: HashSet<TensorId> = [id1, id3].into_iter().collect();
        trainer.set_fp32_overrides(overrides);
        let compute = trainer.compute_params(&client).unwrap();
        assert_eq!(compute.get(&id1).unwrap().dtype(), DType::F32);
        assert_eq!(compute.get(&id2).unwrap().dtype(), DType::F64);
        assert_eq!(compute.get(&id3).unwrap().dtype(), DType::F32);
        assert_eq!(trainer.fp32_overrides().len(), 2);
    }

    /// Master copies are stored in F32 even under a bf16 compute config.
    #[test]
    fn test_amp_trainer_master_params_are_f32() {
        let (client, device) = cpu_setup();
        let weight = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0], &[2], &device);
        let id = weight.id();
        let mut params = HashMap::new();
        params.insert(id, weight);
        let trainer = AmpTrainer::<CpuRuntime>::new(
            TrainingConfig::default(),
            MixedPrecisionConfig::bf16(),
            params,
            &client,
        )
        .unwrap();
        assert_eq!(
            trainer.master_params().get(&id).unwrap().dtype(),
            DType::F32
        );
    }
}