//! Schedule-free optimizer in the style of Defazio et al., "The Road Less
//! Scheduled" (2024): instead of a learning-rate schedule, it maintains an
//! averaged iterate `x` and a base iterate `z`, and evaluates gradients at
//! the interpolation `y = (1 - beta) * z + beta * x`.
use candle_core::{Result, Tensor, Var};
/// Hyperparameters for the schedule-free optimizer.
#[derive(Debug, Clone)]
pub struct ParamsScheduleFree {
    /// Base learning rate.
    pub lr: f64,
    /// Number of steps over which the learning rate is linearly warmed up.
    pub warmup_steps: usize,
    /// Decoupled weight-decay coefficient, applied at the gradient point `y`.
    pub weight_decay: f64,
    /// Interpolation coefficient between the base iterate `z` and the
    /// averaged iterate `x`.
    pub beta: f64,
}
impl Default for ParamsScheduleFree {
    fn default() -> Self {
        Self {
            lr: 0.0025,
            warmup_steps: 0,
            weight_decay: 0.0,
            beta: 0.9,
        }
    }
}
/// Schedule-free optimizer state.
///
/// Between calls, `vars` hold the averaged iterate `x`, except between
/// `pre_step`/`train` and `step`, when they temporarily hold the gradient
/// point `y`.
pub struct ScheduleFreeOptimizer {
    /// Model parameters being optimized.
    vars: Vec<Var>,
    /// Base iterates `z`, one per variable.
    pub z: Vec<Tensor>,
    /// Number of `step` calls taken so far.
    step: usize,
    params: ParamsScheduleFree,
}
impl ScheduleFreeOptimizer {
    /// Builds the optimizer, initializing each base iterate `z` as a copy of
    /// the corresponding variable, so that `x_0 = z_0`.
    pub fn new(vars: Vec<Var>, params: ParamsScheduleFree) -> Result<Self> {
        let mut z = Vec::with_capacity(vars.len());
        for var in &vars {
            z.push(var.as_tensor().copy()?);
        }
        Ok(Self {
            vars,
            z,
            step: 0,
            params,
        })
    }
    /// Moves the variables to the gradient point
    /// `y = (1 - beta) * z + beta * x` before training forward/backward
    /// passes. Equivalent to [`Self::pre_step`].
    pub fn train(&self) -> Result<()> {
        self.pre_step()
    }
    /// Prepares the variables for evaluation. This is a no-op here because
    /// `step` already writes the averaged iterate `x` back into the
    /// variables; it exists to mirror the train/eval switching of
    /// schedule-free implementations. Do not call it between `pre_step` and
    /// `step`, when the variables temporarily hold `y`.
    pub fn eval(&self) -> Result<()> {
        Ok(())
    }
    /// Sets each variable to the gradient point
    /// `y_i = (1 - beta) * z_i + beta * x_i`. Gradients for the following
    /// `step` call must be computed at these values.
    pub fn pre_step(&self) -> Result<()> {
        let b = self.params.beta;
        let one_minus_b = 1.0 - b;
        for (i, var) in self.vars.iter().enumerate() {
            let z_i = &self.z[i];
            let x_i = var.as_tensor();
            let y_i = ((z_i * one_minus_b)? + (x_i * b)?)?.detach();
            var.set(&y_i)?;
        }
        Ok(())
    }
    /// Applies one update from gradients evaluated at `y` (see
    /// [`Self::pre_step`]), then writes the averaged iterate `x` back into
    /// the variables: `z <- z - lr_k * (g + weight_decay * y)` and
    /// `x <- (1 - c) * x + c * z` with `c = 1 / (k + 1)`.
    pub fn step(&mut self, grads: &[Tensor]) -> Result<()> {
        self.step += 1;
        let k = self.step as f64;
        let b = self.params.beta;
        let wd = self.params.weight_decay;
        // Linear learning-rate warmup over the first `warmup_steps` steps.
        let lr = if self.params.warmup_steps > 0 {
            self.params.lr * (k / self.params.warmup_steps as f64).min(1.0)
        } else {
            self.params.lr
        };
        for (i, var) in self.vars.iter().enumerate() {
            if let Some(grad) = grads.get(i) {
                let z_i = &self.z[i];
                let y_i = var.as_tensor();
                // Recover the averaged iterate from y:
                // x = (y - (1 - beta) * z) / beta.
                let x_old = ((y_i - (z_i * (1.0 - b))?)? / b)?;
                // Decoupled weight decay, applied at y as in the
                // schedule-free reference implementation.
                let g = if wd > 0.0 {
                    (grad + (y_i * wd)?)?
                } else {
                    grad.clone()
                };
                // Base iterate update: z <- z - lr_k * g.
                let z_new = (z_i - (&g * lr)?)?.detach();
                // Running average: x <- (1 - c) * x + c * z, c = 1 / (k + 1).
                // (The reference implementation weights this average by
                // squared learning rates during warmup; this port keeps the
                // simpler uniform average.)
                let c = 1.0 / (k + 1.0);
                let x_new = ((x_old * (1.0 - c))? + (&z_new * c)?)?.detach();
                self.z[i] = z_new;
                var.set(&x_new)?;
            }
        }
        Ok(())
    }
    /// Returns the base learning rate.
    pub fn learning_rate(&self) -> f64 {
        self.params.lr
    }

    /// Sets the base learning rate.
    pub fn set_learning_rate(&mut self, lr: f64) {
        self.params.lr = lr;
    }
}
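
// A minimal smoke test of the pre_step -> step protocol on the quadratic
// f(w) = |w|^2, whose gradient 2w is written in closed form so the sketch
// stays self-contained. The learning rate and iteration count are arbitrary
// illustrative choices, not tuned values.
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::Device;

    #[test]
    fn quadratic_descent() -> Result<()> {
        let dev = Device::Cpu;
        let w = Var::from_tensor(&Tensor::new(&[1.0f64, -2.0], &dev)?)?;
        let params = ParamsScheduleFree {
            lr: 0.1,
            ..Default::default()
        };
        let mut opt = ScheduleFreeOptimizer::new(vec![w.clone()], params)?;
        for _ in 0..100 {
            // Move the weights to y, evaluate the gradient there, then step.
            opt.pre_step()?;
            let grad = (w.as_tensor() * 2.0)?;
            opt.step(&[grad])?;
        }
        // After `step`, the variables hold the averaged iterate x.
        opt.eval()?;
        let x = w.as_tensor().to_vec1::<f64>()?;
        assert!(x.iter().all(|v| v.abs() < 1.0));
        Ok(())
    }
}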