use crate::error::{RusTorchError, RusTorchResult};
use crate::optim::common::{AdamConfig, AdamState, AdamUtils, AdamVariant, GenericAdamOptimizer};
use crate::optim::Optimizer;
use crate::tensor::Tensor;
use std::collections::HashMap;
#[derive(Debug, Clone)]
pub struct RAdamVariant {
rho_inf_cache: Option<f32>,
rectification_threshold: f32,
}
impl RAdamVariant {
pub fn new() -> Self {
Self {
rho_inf_cache: None,
rectification_threshold: 4.0,
}
}
pub fn with_threshold(threshold: f32) -> Self {
Self {
rho_inf_cache: None,
rectification_threshold: threshold,
}
}
fn get_rho_inf(&mut self, beta2: f32) -> f32 {
if let Some(cached) = self.rho_inf_cache {
cached
} else {
let rho_inf = 2.0 / (1.0 - beta2) - 1.0;
self.rho_inf_cache = Some(rho_inf);
rho_inf
}
}
fn compute_rho_t(&mut self, beta2: f32, step: usize) -> f32 {
let rho_inf = self.get_rho_inf(beta2);
let beta2_t = beta2.powi(step as i32);
rho_inf - 2.0 * (step as f32) * beta2_t / (1.0 - beta2_t)
}
fn should_rectify(&mut self, beta2: f32, step: usize) -> bool {
self.compute_rho_t(beta2, step) > self.rectification_threshold
}
fn compute_rectification_term(&mut self, beta2: f32, step: usize) -> f32 {
let rho_inf = self.get_rho_inf(beta2);
let rho_t = self.compute_rho_t(beta2, step);
let rho_inf_minus_4 = rho_inf - 4.0;
let rho_inf_minus_2 = rho_inf - 2.0;
let rho_t_minus_4 = rho_t - 4.0;
let rho_t_minus_2 = rho_t - 2.0;
let numerator = rho_inf_minus_4 * rho_inf_minus_2 * rho_t;
let denominator = rho_inf * rho_t_minus_4 * rho_t_minus_2;
(numerator / denominator).sqrt()
}
}
impl Default for RAdamVariant {
fn default() -> Self {
Self::new()
}
}
impl AdamVariant for RAdamVariant {
fn compute_specialized_update(
&self,
state: &mut AdamState,
grad: &Tensor<f32>,
config: &AdamConfig,
step: usize,
) -> Tensor<f32> {
let mut variant_copy = self.clone();
AdamUtils::update_momentum(&mut state.momentum, grad, config.beta1);
AdamUtils::update_velocity(&mut state.velocity, grad, config.beta2);
let bias_correction1 = AdamUtils::bias_correction1(config.beta1, step);
let bias_corrected_momentum =
AdamUtils::apply_bias_correction(&state.momentum, bias_correction1);
if variant_copy.should_rectify(config.beta2, step) {
let bias_correction2 = AdamUtils::bias_correction2(config.beta2, step);
let bias_corrected_velocity =
AdamUtils::apply_bias_correction(&state.velocity, bias_correction2);
let rectification_term = variant_copy.compute_rectification_term(config.beta2, step);
let adam_update = AdamUtils::compute_adam_update(
&bias_corrected_momentum,
&bias_corrected_velocity,
config.eps,
);
&adam_update * rectification_term
} else {
bias_corrected_momentum
}
}
fn optimizer_name(&self) -> &'static str {
"RAdam"
}
fn validate_specific_config(&self, _config: &AdamConfig) -> RusTorchResult<()> {
if self.rectification_threshold <= 0.0 {
return Err(RusTorchError::InvalidParameters {
operation: "RAdam optimizer".to_string(),
message: "Rectification threshold must be positive".to_string(),
});
}
Ok(())
}
fn additional_config_fields(&self) -> HashMap<String, f32> {
let mut fields = HashMap::new();
fields.insert(
"rectification_threshold".to_string(),
self.rectification_threshold,
);
if let Some(rho_inf) = self.rho_inf_cache {
fields.insert("rho_inf_cache".to_string(), rho_inf);
}
fields
}
fn load_additional_config(&mut self, config: &HashMap<String, f32>) {
if let Some(&threshold) = config.get("rectification_threshold") {
self.rectification_threshold = threshold;
}
if let Some(&rho_inf) = config.get("rho_inf_cache") {
self.rho_inf_cache = Some(rho_inf);
}
}
}
pub type RAdam = GenericAdamOptimizer<RAdamVariant>;
impl RAdam {
pub fn new(
learning_rate: f32,
beta1: f32,
beta2: f32,
eps: f32,
weight_decay: f32,
) -> RusTorchResult<Self> {
let config = AdamConfig {
learning_rate,
beta1,
beta2,
eps,
weight_decay,
};
let variant = RAdamVariant::new();
GenericAdamOptimizer::from_config_variant(config, variant)
}
pub fn default_params(learning_rate: f32) -> RusTorchResult<Self> {
let config = AdamConfig::radam(learning_rate);
let variant = RAdamVariant::new();
GenericAdamOptimizer::from_config_variant(config, variant)
}
pub fn with_weight_decay(learning_rate: f32, weight_decay: f32) -> RusTorchResult<Self> {
let config = AdamConfig::radam(learning_rate).with_weight_decay(weight_decay);
let variant = RAdamVariant::new();
GenericAdamOptimizer::from_config_variant(config, variant)
}
pub fn with_threshold(learning_rate: f32, threshold: f32) -> RusTorchResult<Self> {
let config = AdamConfig::radam(learning_rate);
let variant = RAdamVariant::with_threshold(threshold);
GenericAdamOptimizer::from_config_variant(config, variant)
}
pub fn get_param_state(&self, param_id: usize) -> Option<(&Tensor<f32>, &Tensor<f32>)> {
self.get_state(param_id).map(|s| (&s.momentum, &s.velocity))
}
pub fn radam_config(&self) -> (f32, Option<f32>) {
let variant = self.variant();
(variant.rectification_threshold, variant.rho_inf_cache)
}
pub fn is_rectification_enabled(&self, step: usize) -> bool {
let mut variant_copy = self.variant().clone();
let config = self.config();
variant_copy.should_rectify(config.beta2, step)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_radam_creation() {
let optimizer = RAdam::default_params(0.001).unwrap();
assert_eq!(optimizer.learning_rate(), 0.001);
assert_eq!(optimizer.config().beta1, 0.9);
assert_eq!(optimizer.config().beta2, 0.999);
}
#[test]
fn test_radam_with_weight_decay() {
let optimizer = RAdam::with_weight_decay(0.001, 0.1).unwrap();
assert_eq!(optimizer.config().weight_decay, 0.1);
}
#[test]
fn test_radam_step() {
let mut optimizer = RAdam::default_params(0.1).unwrap();
let param = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]);
let grad = Tensor::from_vec(vec![0.1, 0.2, 0.3], vec![3]);
optimizer.step(¶m, &grad);
let param_data = param.as_slice().unwrap();
assert!(param_data[0] < 1.0); assert!(param_data[1] < 2.0);
assert!(param_data[2] < 3.0);
}
#[test]
fn test_radam_variant_caching() {
let mut variant = RAdamVariant::new();
let rho_inf1 = variant.get_rho_inf(0.999);
assert!((rho_inf1 - 1999.0).abs() < 1.0);
let rho_inf2 = variant.get_rho_inf(0.999);
assert_eq!(rho_inf1, rho_inf2);
assert!(variant.rho_inf_cache.is_some());
}
#[test]
fn test_variance_rectification() {
let optimizer = RAdam::default_params(0.001).unwrap();
assert!(!optimizer.is_rectification_enabled(1));
assert!(!optimizer.is_rectification_enabled(2));
assert!(optimizer.is_rectification_enabled(1000));
assert!(optimizer.is_rectification_enabled(10000));
}
#[test]
fn test_radam_with_custom_threshold() {
let optimizer = RAdam::with_threshold(0.001, 2.0).unwrap();
let (threshold, _) = optimizer.radam_config();
assert_eq!(threshold, 2.0);
assert!(optimizer.is_rectification_enabled(100));
}
#[test]
fn test_radam_variant_rectification_term() {
let mut variant = RAdamVariant::new();
let rect_term = variant.compute_rectification_term(0.999, 10000);
assert!(rect_term > 0.0);
assert!(rect_term < 5.0);
}
#[test]
fn test_radam_fallback_to_momentum() {
let mut optimizer = RAdam::default_params(0.001).unwrap();
let param = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]);
let grad = Tensor::from_vec(vec![0.1, 0.2, 0.3], vec![3]);
let original_param = param.clone();
optimizer.step(¶m, &grad);
let param_data = param.as_slice().unwrap();
let orig_data = original_param.as_slice().unwrap();
for (new_val, orig_val) in param_data.iter().zip(orig_data.iter()) {
assert!(new_val != orig_val); }
}
#[test]
fn test_radam_state_dict() {
let optimizer = RAdam::default_params(0.001).unwrap();
let state_dict = optimizer.state_dict();
assert_eq!(state_dict.get("learning_rate"), Some(&0.001));
assert_eq!(state_dict.get("beta1"), Some(&0.9));
assert_eq!(state_dict.get("rectification_threshold"), Some(&4.0));
}
#[test]
fn test_radam_variant_validation() {
let variant = RAdamVariant::with_threshold(-1.0);
let config = AdamConfig::radam(0.001);
assert!(variant.validate_specific_config(&config).is_err());
let valid_variant = RAdamVariant::with_threshold(4.0);
assert!(valid_variant.validate_specific_config(&config).is_ok());
}
}