//! A sparse-aware Adam optimizer for `tch`. Dense gradients get the standard
//! Adam update; sparse gradients (e.g. from embedding lookups) only touch the
//! moment estimates and parameters at the rows that actually received
//! gradient, keeping the per-step cost proportional to the batch rather than
//! the full table.
use std::sync::{Arc, Mutex};

use tch::nn::{VarStore, Variables};
use tch::{no_grad, Device, Kind, Tensor};
/// Per-variable Adam state: exponential moving averages of the gradient
/// (first moment) and of the squared gradient (second moment), plus the step
/// counter used for bias correction.
struct Buffer {
    pub first_moment: Tensor,
    pub second_moment: Tensor,
    idx: usize,
}

impl Buffer {
    pub fn new(size: &[i64], device: Device) -> Buffer {
        Buffer {
            // Allocate the moments on the same device as the variable they
            // track, so the optimizer also works for GPU-resident models.
            first_moment: Tensor::zeros(size, (Kind::Float, device)),
            second_moment: Tensor::zeros(size, (Kind::Float, device)),
            // Start at 1 so the first `inc` returns t = 1, matching Adam's
            // bias-correction terms 1 - beta^t.
            idx: 1,
        }
    }

    /// Returns the current step count, then advances it.
    pub fn inc(&mut self) -> usize {
        let old_val = self.idx;
        self.idx += 1;
        old_val
    }
}
/// Adam with native support for sparse gradients. `force_sparse` converts
/// dense gradients to sparse form before the update, which can pay off when a
/// gradient is dense in representation but mostly zero in practice.
pub struct SparseAdam {
    lr: f64,
    beta1: f64,
    beta2: f64,
    eps: f64,
    force_sparse: bool,
    /// Shared handle to the `VarStore`'s variables; locking it gives access
    /// to the trainable tensors and their gradients.
    vars: Arc<Mutex<Variables>>,
    /// One moment buffer per trainable variable, in the same order as
    /// `trainable_variables`.
    buffers: Vec<Buffer>,
}
impl SparseAdam {
    pub fn new(
        vs: &VarStore,
        lr: f64,
        beta1: f64,
        beta2: f64,
        eps: f64,
        force_sparse: bool,
    ) -> SparseAdam {
        let vars = vs.variables_.clone();
        // One zero-initialized moment buffer per trainable variable,
        // matching its shape and device.
        let buffers = vars
            .lock()
            .unwrap()
            .trainable_variables
            .iter()
            .map(|x| Buffer::new(&x.tensor.size(), x.tensor.device()))
            .collect();
        SparseAdam { lr, beta1, beta2, eps, force_sparse, vars, buffers }
    }
    /// Performs a single optimization step. The update runs under `no_grad`
    /// so the parameter writes are not recorded by autograd.
    pub fn step(&mut self) {
        no_grad(|| self._step());
    }
    /// The update itself; assumes it is already running under `no_grad`.
    fn _step(&mut self) {
        let mut vars = self.vars.lock().unwrap();
        for (var, buffer) in vars.trainable_variables.iter_mut().zip(&mut self.buffers) {
            let mut grad = var.tensor.grad();
            let buffer_idx = buffer.inc();
            // Standard Adam bias correction: 1 - beta^t at step t.
            let bias_correction1 = 1.0 - self.beta1.powi(buffer_idx as i32);
            let bias_correction2 = 1.0 - self.beta2.powi(buffer_idx as i32);
            if grad.is_sparse() || self.force_sparse {
                if !grad.is_sparse() {
                    grad = grad.to_sparse_sparse_dim(1);
                }
                // Coalesce so every row index appears at most once, then
                // update the moments only at the rows that received gradient.
                let grad = grad.coalesce();
                // `squeeze_dim(0)` turns the [1, nnz] index tensor into
                // [nnz] without also collapsing the nnz dimension when
                // nnz == 1, which a bare `squeeze()` would do.
                let indices = grad.indices().squeeze_dim(0);
                let values = grad.values();
                // EMA restricted to the observed rows, written as an
                // additive delta: m <- m + (1 - beta1) * (g - m), which
                // equals beta1 * m + (1 - beta1) * g.
                let update_first_moment =
                    (1.0 - self.beta1) * (&values - buffer.first_moment.index_select(0, &indices));
                let update_second_moment = (1.0 - self.beta2)
                    * (&values * &values - buffer.second_moment.index_select(0, &indices));
                let _ = buffer.first_moment.index_add_(0, &indices, &update_first_moment);
                let _ = buffer.second_moment.index_add_(0, &indices, &update_second_moment);
                // Parameter update at the observed rows:
                // p <- p - lr * m_hat / (sqrt(v_hat) + eps).
                let part1 =
                    buffer.first_moment.index_select(0, &indices) * (-self.lr / bias_correction1);
                let part2 = (buffer.second_moment.index_select(0, &indices) / bias_correction2)
                    .sqrt()
                    + self.eps;
                let _ = var.tensor.index_add_(0, &indices, &(part1 / part2));
            } else {
                // Dense path: the usual in-place Adam update.
                buffer.first_moment *= self.beta1;
                buffer.first_moment += (1.0 - self.beta1) * &grad;
                buffer.second_moment *= self.beta2;
                // addcmul_ adds scaled_grad * scaled_grad, i.e.
                // (1 - beta2) * grad^2, completing the second-moment EMA.
                let scaled_grad = grad * (1.0 - self.beta2).sqrt();
                let _ = buffer.second_moment.addcmul_(&scaled_grad, &scaled_grad);
                let part1 = &buffer.first_moment * (-self.lr / bias_correction1);
                let part2 = (&buffer.second_moment / bias_correction2).sqrt() + self.eps;
                let _ = var.tensor.addcdiv_(&part1, &part2);
            }
        }
    }
    /// Clears the gradients of all trainable variables.
    pub fn zero_grad(&mut self) {
        let mut vars = self.vars.lock().unwrap();
        for var in &mut vars.trainable_variables {
            var.tensor.zero_grad();
        }
    }
}
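
// A minimal usage sketch, written as a test. It assumes an embedding layer
// configured with `sparse: true` so its backward pass produces sparse
// gradients, which is the main use case for this optimizer; the table size
// and hyper-parameters below are illustrative, not prescriptive. (In older
// `tch` releases, `Tensor::from_slice` is named `Tensor::of_slice`.)
#[cfg(test)]
mod tests {
    use super::SparseAdam;
    use tch::nn::{self, Module, VarStore};
    use tch::{Device, Kind, Tensor};

    #[test]
    fn updates_only_the_rows_seen_in_the_batch() {
        let vs = VarStore::new(Device::Cpu);
        let config = nn::EmbeddingConfig { sparse: true, ..Default::default() };
        let emb = nn::embedding(vs.root(), 10, 4, config);
        let mut opt = SparseAdam::new(&vs, 1e-2, 0.9, 0.999, 1e-8, false);

        let before = emb.ws.detach().copy();
        // Look up rows 1 and 3 only; their gradient is a sparse tensor.
        let idx = Tensor::from_slice(&[1i64, 3]);
        let loss = emb.forward(&idx).sum(Kind::Float);
        loss.backward();
        opt.step();
        opt.zero_grad();

        // Rows that appeared in the batch moved...
        assert!(!emb.ws.get(1).allclose(&before.get(1), 1e-12, 1e-12, false));
        // ...while untouched rows stayed exactly put.
        assert!(emb.ws.get(0).allclose(&before.get(0), 1e-12, 1e-12, false));
    }
}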