use scivex_core::Float;
use crate::error::Result;
use crate::layer::Layer;
use crate::variable::Variable;
/// Converts a [`Variable`] with element type `S` into one with element type `D`.
///
/// The tensor obtained from `v.data()` is converted element-wise via
/// `cast::<D>()` and wrapped in a new [`Variable`]. The new variable copies
/// `v`'s `requires_grad` flag.
///
/// NOTE(review): the result is constructed with `Variable::new`, so it is
/// presumably a fresh leaf that does not share `v`'s autograd history. Confirm
/// this against `Variable::new`. Gradients flowing back to the source
/// precision must then be copied explicitly (see `AmpConfig::sync_grads`).
pub fn cast_variable<S: Float, D: Float>(v: &Variable<S>) -> Variable<D> {
    let converted = v.data().cast::<D>();
    Variable::new(converted, v.requires_grad())
}
pub fn cast_params<S: Float, D: Float>(params: &[Variable<S>]) -> Vec<Variable<D>> {
params.iter().map(cast_variable).collect()
}
/// Zero-sized marker describing a mixed-precision setup.
///
/// Parameters are held in the `Master` float type, while computation uses the
/// `Compute` float type. The struct stores no runtime data (only
/// `PhantomData` fields). All functionality is exposed as associated
/// functions selected by the two type arguments, e.g.
/// `AmpConfig::<f64, f32>::sync_grads(&master, &compute)`.
pub struct AmpConfig<Master: Float, Compute: Float> {
    _master: std::marker::PhantomData<Master>,
    _compute: std::marker::PhantomData<Compute>,
}

impl<Master: Float, Compute: Float> AmpConfig<Master, Compute> {
    /// Creates the (stateless) configuration marker.
    #[must_use]
    pub fn new() -> Self {
        Self {
            _master: std::marker::PhantomData,
            _compute: std::marker::PhantomData,
        }
    }

    /// Casts master-precision parameters to compute precision.
    ///
    /// Delegates to [`cast_params`]. Each output keeps its input's
    /// `requires_grad` flag, and the output order matches `params`.
    #[must_use]
    pub fn to_compute(params: &[Variable<Master>]) -> Vec<Variable<Compute>> {
        cast_params(params)
    }

    /// Casts a loss value computed in compute precision back to master precision.
    ///
    /// Delegates to [`cast_variable`].
    #[must_use]
    pub fn loss_to_master(loss: &Variable<Compute>) -> Variable<Master> {
        cast_variable(loss)
    }

    /// Copies gradients from compute-precision parameters onto their
    /// master-precision counterparts.
    ///
    /// The two slices are matched pairwise by index. For each pair where the
    /// compute parameter has a gradient, that gradient is cast to `Master` and
    /// set on the master parameter.
    ///
    /// # Panics
    ///
    /// Panics if `master_params` and `compute_params` differ in length. A
    /// mismatch means the lists no longer correspond one-to-one. Silently
    /// zipping them would leave trailing master parameters without an updated
    /// gradient.
    pub fn sync_grads(master_params: &[Variable<Master>], compute_params: &[Variable<Compute>]) {
        assert_eq!(
            master_params.len(),
            compute_params.len(),
            "AmpConfig::sync_grads: master and compute parameter lists must have the same length"
        );
        for (master, compute) in master_params.iter().zip(compute_params) {
            // Parameters that received no gradient in the compute pass are skipped.
            // NOTE(review): the master's existing gradient (if any) is then left as-is,
            // which may be stale from a previous step — confirm callers clear grads.
            if let Some(grad) = compute.grad() {
                master.set_grad(grad.cast::<Master>());
            }
        }
    }
}

impl<Master: Float, Compute: Float> Default for AmpConfig<Master, Compute> {
    /// Equivalent to [`AmpConfig::new`]. Implemented by hand so that no
    /// `Default` bound is imposed on `Master` or `Compute`.
    fn default() -> Self {
        Self::new()
    }
}
/// Runs `layer`'s forward pass on `input` and returns its output.
///
/// Any error from `layer.forward` is propagated unchanged.
///
/// NOTE(review): despite the name, no precision casting or loss scaling
/// happens here. Input and output share the element type `T`. Presumably
/// callers convert parameters and inputs beforehand (e.g. with
/// `AmpConfig::to_compute`); confirm this at the call sites.
pub fn amp_forward<T: Float>(layer: &dyn Layer<T>, input: &Variable<T>) -> Result<Variable<T>> {
    let output = layer.forward(input)?;
    Ok(output)
}
#[cfg(test)]
mod tests {
    use super::*;
    use scivex_core::Tensor;

    /// Casting f64 -> f32 preserves values (within f32 precision) and the
    /// `requires_grad` flag.
    #[test]
    fn test_cast_variable_f64_to_f32() {
        let v = Variable::new(
            Tensor::from_vec(vec![1.0_f64, 2.0, 3.0], vec![3]).unwrap(),
            true,
        );
        let casted: Variable<f32> = cast_variable(&v);
        let data = casted.data();
        let s = data.as_slice();
        assert!((s[0] - 1.0).abs() < 1e-6);
        assert!((s[1] - 2.0).abs() < 1e-6);
        assert!((s[2] - 3.0).abs() < 1e-6);
        assert!(casted.requires_grad());
    }

    /// `cast_params` keeps order, shapes and each parameter's `requires_grad` flag.
    #[test]
    fn test_cast_params() {
        let params: Vec<Variable<f64>> = vec![
            Variable::new(Tensor::ones(vec![2, 3]), true),
            Variable::new(Tensor::zeros(vec![3]), false),
        ];
        let casted: Vec<Variable<f32>> = cast_params(&params);
        assert_eq!(casted.len(), 2);
        assert_eq!(casted[0].shape(), vec![2, 3]);
        assert_eq!(casted[1].shape(), vec![3]);
        assert!(casted[0].requires_grad());
        assert!(!casted[1].requires_grad());
    }

    /// A gradient set on the compute (f32) copy is cast and copied onto the
    /// master (f64) parameter.
    #[test]
    fn test_amp_config_sync_grads() {
        let master = vec![Variable::new(
            Tensor::from_vec(vec![1.0_f64, 2.0], vec![2]).unwrap(),
            true,
        )];
        let compute = vec![Variable::new(
            Tensor::from_vec(vec![1.0_f32, 2.0], vec![2]).unwrap(),
            true,
        )];
        compute[0].set_grad(Tensor::from_vec(vec![0.1_f32, 0.2], vec![2]).unwrap());
        AmpConfig::<f64, f32>::sync_grads(&master, &compute);
        let grad = master[0]
            .grad()
            .expect("master should have grad after sync");
        // Tolerance accounts for the f32 -> f64 widening of 0.1 / 0.2.
        assert!((grad.as_slice()[0] - 0.1).abs() < 1e-5);
        assert!((grad.as_slice()[1] - 0.2).abs() < 1e-5);
    }
}