use burn::prelude::*;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Hyper-parameters of the [`RAdam`] optimizer.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct RAdamConfig {
    /// Base learning rate.
    pub lr: f64,
    /// Exponential decay rate of the first-moment (gradient average) estimate.
    pub beta1: f64,
    /// Exponential decay rate of the second-moment (squared-gradient) estimate.
    pub beta2: f64,
    /// Small constant added to the denominator of the adaptive step for
    /// numerical stability.
    pub epsilon: f64,
    /// Weight-decay coefficient applied directly to the parameters
    /// (decoupled from the moment estimates).
    pub weight_decay: f64,
}
impl Default for RAdamConfig {
    /// Adam-family defaults: `lr = 1e-3`, `(beta1, beta2) = (0.9, 0.999)`,
    /// `epsilon = 1e-8` and no weight decay.
    fn default() -> Self {
        let (beta1, beta2) = (0.9, 0.999);
        Self {
            weight_decay: 0.0,
            epsilon: 1e-8,
            beta2,
            beta1,
            lr: 1e-3,
        }
    }
}
impl RAdamConfig {
    /// Creates a configuration with learning rate `lr`; every other
    /// hyper-parameter takes its [`Default`] value.
    pub fn new(lr: f64) -> Self {
        Self {
            lr,
            ..Self::default()
        }
    }

    /// Returns the configuration with the first-moment decay rate replaced.
    #[must_use]
    pub fn with_beta1(self, beta1: f64) -> Self {
        Self { beta1, ..self }
    }

    /// Returns the configuration with the second-moment decay rate replaced.
    #[must_use]
    pub fn with_beta2(self, beta2: f64) -> Self {
        Self { beta2, ..self }
    }

    /// Returns the configuration with the stability constant replaced.
    #[must_use]
    pub fn with_epsilon(self, epsilon: f64) -> Self {
        Self { epsilon, ..self }
    }

    /// Returns the configuration with the weight-decay coefficient replaced.
    #[must_use]
    pub fn with_weight_decay(self, weight_decay: f64) -> Self {
        Self {
            weight_decay,
            ..self
        }
    }
}
/// Moment buffers RAdam keeps for one parameter, one entry per element.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct RAdamParamState {
    /// Exponential moving average of the gradient (first moment).
    pub m: Vec<f32>,
    /// Exponential moving average of the squared gradient (second moment).
    pub v: Vec<f32>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RAdam {
config: RAdamConfig,
step: usize,
states: HashMap<String, RAdamParamState>,
}
impl RAdam {
pub fn new(config: RAdamConfig) -> Self {
Self {
config,
step: 0,
states: HashMap::new(),
}
}
pub fn lr(&self) -> f64 {
self.config.lr
}
pub fn set_lr(&mut self, lr: f64) {
self.config.lr = lr;
}
pub fn step(&self) -> usize {
self.step
}
fn rho_inf(&self) -> f64 {
2.0 / (1.0 - self.config.beta2) - 1.0
}
pub fn step_param<B: Backend>(
&mut self,
param_id: &str,
param: Tensor<B, 1>,
grad: Tensor<B, 1>,
) -> Tensor<B, 1> {
let device = param.device();
self.step += 1;
let t = self.step as f64;
let beta1 = self.config.beta1;
let beta2 = self.config.beta2;
let lr = self.config.lr;
let eps = self.config.epsilon;
let wd = self.config.weight_decay;
let rho_inf = self.rho_inf();
let param_data: Vec<f32> = param.clone().into_data().to_vec().unwrap();
let grad_data: Vec<f32> = grad.into_data().to_vec().unwrap();
let n = param_data.len();
let state = self.states.entry(param_id.to_string()).or_insert_with(|| {
RAdamParamState {
m: vec![0.0; n],
v: vec![0.0; n],
}
});
for i in 0..n {
state.m[i] = (beta1 as f32) * state.m[i] + (1.0 - beta1 as f32) * grad_data[i];
state.v[i] =
(beta2 as f32) * state.v[i] + (1.0 - beta2 as f32) * grad_data[i] * grad_data[i];
}
let beta1_t = beta1.powi(t as i32);
let beta2_t = beta2.powi(t as i32);
let m_hat: Vec<f32> = state.m.iter().map(|&m| m / (1.0 - beta1_t as f32)).collect();
let rho_t = rho_inf - 2.0 * t * beta2_t / (1.0 - beta2_t);
let updated: Vec<f32> = if rho_t > 5.0 {
let v_hat: Vec<f32> = state.v.iter().map(|&v| v / (1.0 - beta2_t as f32)).collect();
let r = ((rho_t - 4.0) * (rho_t - 2.0) * rho_inf
/ ((rho_inf - 4.0) * (rho_inf - 2.0) * rho_t))
.sqrt();
param_data
.iter()
.zip(m_hat.iter())
.zip(v_hat.iter())
.map(|((&p, &m), &v)| {
let update = (r as f32) * m / (v.sqrt() + eps as f32);
p - (lr as f32) * update - (wd as f32) * p
})
.collect()
} else {
param_data
.iter()
.zip(m_hat.iter())
.map(|(&p, &m)| p - (lr as f32) * m - (wd as f32) * p)
.collect()
};
Tensor::from_floats(updated.as_slice(), &device)
}
}
/// Hyper-parameters of the [`Ranger`] optimizer (RAdam + Lookahead).
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct RangerConfig {
    /// Settings of the inner RAdam optimizer that produces the fast weights.
    pub radam: RAdamConfig,
    /// Fraction of the fast-minus-slow difference added to the slow weights
    /// at each Lookahead synchronisation.
    pub alpha: f64,
    /// Number of fast steps between Lookahead synchronisations.
    pub k: usize,
}
impl Default for RangerConfig {
    /// Default RAdam settings wrapped in Lookahead with `alpha = 0.5` and
    /// `k = 6`.
    fn default() -> Self {
        let radam = RAdamConfig::default();
        Self {
            k: 6,
            alpha: 0.5,
            radam,
        }
    }
}
impl RangerConfig {
    /// Creates a configuration with learning rate `lr`; every other RAdam
    /// and Lookahead hyper-parameter takes its default value.
    pub fn new(lr: f64) -> Self {
        Self {
            radam: RAdamConfig::new(lr),
            ..Default::default()
        }
    }

    /// Sets the Lookahead interpolation factor `alpha`.
    #[must_use]
    pub fn with_alpha(mut self, alpha: f64) -> Self {
        self.alpha = alpha;
        self
    }

    /// Sets `k`, the number of fast steps between Lookahead synchronisations.
    #[must_use]
    pub fn with_k(mut self, k: usize) -> Self {
        self.k = k;
        self
    }

    /// Sets the weight-decay coefficient of the inner RAdam optimizer.
    #[must_use]
    pub fn with_weight_decay(mut self, weight_decay: f64) -> Self {
        self.radam.weight_decay = weight_decay;
        self
    }

    /// Sets the first-moment decay rate of the inner RAdam optimizer.
    #[must_use]
    pub fn with_beta1(mut self, beta1: f64) -> Self {
        self.radam.beta1 = beta1;
        self
    }

    /// Sets the second-moment decay rate of the inner RAdam optimizer.
    #[must_use]
    pub fn with_beta2(mut self, beta2: f64) -> Self {
        self.radam.beta2 = beta2;
        self
    }

    /// Sets the stability constant of the inner RAdam optimizer. This
    /// mirrors [`RAdamConfig::with_epsilon`], so every RAdam hyper-parameter
    /// can be configured through the Ranger builder.
    #[must_use]
    pub fn with_epsilon(mut self, epsilon: f64) -> Self {
        self.radam.epsilon = epsilon;
        self
    }
}
/// Bundle of the per-parameter pieces of Ranger state: the inner RAdam
/// moment buffers together with the Lookahead slow weights.
///
/// NOTE(review): `Ranger` in this module keeps these pieces in separate maps
/// and never constructs this type. Confirm whether external code relies on it.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct RangerParamState {
    /// Moment buffers of the inner RAdam optimizer.
    pub radam_state: RAdamParamState,
    /// Lookahead slow weights, one entry per parameter element.
    pub slow_weights: Vec<f32>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Ranger {
config: RangerConfig,
radam: RAdam,
slow_weights: HashMap<String, Vec<f32>>,
sync_counter: usize,
}
impl Ranger {
pub fn new(config: RangerConfig) -> Self {
let radam = RAdam::new(config.radam.clone());
Self {
config,
radam,
slow_weights: HashMap::new(),
sync_counter: 0,
}
}
pub fn lr(&self) -> f64 {
self.radam.lr()
}
pub fn set_lr(&mut self, lr: f64) {
self.radam.set_lr(lr);
}
pub fn step(&self) -> usize {
self.radam.step()
}
pub fn step_param<B: Backend>(
&mut self,
param_id: &str,
param: Tensor<B, 1>,
grad: Tensor<B, 1>,
) -> Tensor<B, 1> {
let device = param.device();
let fast_weights = self.radam.step_param(param_id, param, grad);
let fast_data: Vec<f32> = fast_weights.clone().into_data().to_vec().unwrap();
if !self.slow_weights.contains_key(param_id) {
self.slow_weights
.insert(param_id.to_string(), fast_data.clone());
}
self.sync_counter += 1;
if self.sync_counter >= self.config.k {
self.sync_counter = 0;
let slow = self.slow_weights.get_mut(param_id).unwrap();
let alpha = self.config.alpha as f32;
for i in 0..slow.len() {
slow[i] += alpha * (fast_data[i] - slow[i]);
}
Tensor::from_floats(slow.as_slice(), &device)
} else {
fast_weights
}
}
pub fn sync_lookahead(&mut self) {
self.sync_counter = self.config.k;
}
}
/// Serializable pair of a learning rate and a step counter.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct OptimizerState {
    /// Learning rate.
    pub lr: f64,
    /// Step counter.
    pub step: usize,
}
impl Default for OptimizerState {
    /// Starts at step 0 with a learning rate of `1e-3`.
    fn default() -> Self {
        let lr = 1e-3;
        Self { step: 0, lr }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use burn_ndarray::NdArray;

    type TestBackend = NdArray;

    /// Builds a 1-D test tensor on the default NdArray device.
    fn tensor(values: &[f32]) -> Tensor<TestBackend, 1> {
        Tensor::from_floats(values, &Default::default())
    }

    /// Returns whether `actual` is within 1e-10 of `expected`.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < 1e-10
    }

    #[test]
    fn test_radam_config() {
        let config = RAdamConfig::new(1e-3)
            .with_beta1(0.9)
            .with_beta2(0.999)
            .with_weight_decay(0.01);
        assert!(close(config.lr, 1e-3));
        assert!(close(config.beta1, 0.9));
        assert!(close(config.beta2, 0.999));
        assert!(close(config.weight_decay, 0.01));
    }

    #[test]
    fn test_radam_step() {
        let initial = [1.0_f32, 2.0, 3.0];
        let mut optimizer = RAdam::new(RAdamConfig::new(0.1));
        let updated: Vec<f32> = optimizer
            .step_param("test", tensor(&initial), tensor(&[0.1; 3]))
            .into_data()
            .to_vec()
            .unwrap();
        // A positive gradient must move every parameter downwards.
        for (i, &before) in initial.iter().enumerate() {
            assert!(updated[i] < before);
        }
    }

    #[test]
    fn test_ranger_config() {
        let config = RangerConfig::new(1e-3).with_alpha(0.5).with_k(6);
        assert!(close(config.radam.lr, 1e-3));
        assert!(close(config.alpha, 0.5));
        assert_eq!(config.k, 6);
    }

    #[test]
    fn test_ranger_step() {
        let mut optimizer = Ranger::new(RangerConfig::new(0.1));
        (0..10).for_each(|_| {
            let _ = optimizer.step_param("test", tensor(&[1.0, 2.0, 3.0]), tensor(&[0.1, 0.1, 0.1]));
        });
        assert_eq!(optimizer.step(), 10);
    }
}