use crate::error::OptimizeError;
use crate::stochastic::{
clip_gradients, generate_batch_indices, update_learning_rate, DataProvider,
LearningRateSchedule, StochasticGradientFunction,
};
use crate::unconstrained::result::OptimizeResult;
use scirs2_core::ndarray::Array1;
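/// Options for the AdamW optimizer (Adam with decoupled weight decay).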
#[derive(Debug, Clone)]
pub struct AdamWOptions {
    /// Base learning rate (step size).
    pub learning_rate: f64,
    /// Exponential decay rate for the first-moment (mean) estimate.
    pub beta1: f64,
    /// Exponential decay rate for the second-moment (uncentered variance) estimate.
    pub beta2: f64,
    /// Small constant added to the denominator for numerical stability.
    pub epsilon: f64,
    /// Weight decay coefficient.
    pub weight_decay: f64,
    /// Maximum number of iterations.
    pub max_iter: usize,
    /// Convergence tolerance on the change in full-data loss.
    pub tol: f64,
    /// Learning rate schedule applied over the run.
    pub lr_schedule: LearningRateSchedule,
    /// Optional gradient clipping threshold.
    pub gradient_clip: Option<f64>,
    /// Mini-batch size; defaults to `min(32, num_samples / 10)` when `None`.
    pub batch_size: Option<usize>,
    /// If true, apply weight decay directly to the parameters (AdamW);
    /// otherwise fold it into the gradient as an L2 penalty.
    pub decouple_weight_decay: bool,
}
impl Default for AdamWOptions {
fn default() -> Self {
Self {
learning_rate: 0.001,
beta1: 0.9,
beta2: 0.999,
epsilon: 1e-8,
weight_decay: 0.01,
max_iter: 1000,
tol: 1e-6,
lr_schedule: LearningRateSchedule::Constant,
gradient_clip: None,
batch_size: None,
decouple_weight_decay: true,
}
}
}
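/// Minimizes an objective with AdamW using stochastic mini-batch gradients.
///
/// `grad_func` supplies batch gradients and full-data loss values, `data_provider`
/// yields the training samples, and `options` configures step sizes, moment decay
/// rates, weight decay, and stopping criteria. Convergence is declared when the
/// full-data loss changes by less than `options.tol` between evaluations.
///
/// A minimal usage sketch (the gradient-function value and input data are
/// illustrative; see the tests below for a concrete `StochasticGradientFunction` impl):
///
/// ```ignore
/// let options = AdamWOptions { learning_rate: 0.01, ..Default::default() };
/// let provider = Box::new(InMemoryDataProvider::new(data));
/// let result = minimize_adamw(my_grad_func, x0, provider, options)?;
/// ```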
#[allow(dead_code)]
pub fn minimize_adamw<F>(
mut grad_func: F,
mut x: Array1<f64>,
data_provider: Box<dyn DataProvider>,
options: AdamWOptions,
) -> Result<OptimizeResult<f64>, OptimizeError>
where
F: StochasticGradientFunction,
{
let mut func_evals = 0;
let mut _grad_evals = 0;
let num_samples = data_provider.num_samples();
    // Default batch size: min(32, num_samples / 10), clamped to at least one sample.
    let batch_size = options
        .batch_size
        .unwrap_or_else(|| 32.min(num_samples / 10).max(1));
    let actual_batch_size = batch_size.min(num_samples).max(1);
    // First and second moment estimates (Adam state).
    let mut m: Array1<f64> = Array1::zeros(x.len());
    let mut v: Array1<f64> = Array1::zeros(x.len());
let mut best_x = x.clone();
let mut best_f = f64::INFINITY;
let mut prev_loss = f64::INFINITY;
let mut stagnant_iterations = 0;
println!("Starting AdamW optimization:");
println!(" Parameters: {}", x.len());
println!(" Dataset size: {}", num_samples);
println!(" Batch size: {}", actual_batch_size);
println!(" Initial learning rate: {}", options.learning_rate);
println!(" Beta1: {}, Beta2: {}", options.beta1, options.beta2);
println!(" Weight decay: {}", options.weight_decay);
println!(" Decoupled: {}", options.decouple_weight_decay);
#[allow(clippy::explicit_counter_loop)]
for iteration in 0..options.max_iter {
let current_lr = update_learning_rate(
options.learning_rate,
iteration,
options.max_iter,
&options.lr_schedule,
);
let batch_indices = if actual_batch_size < num_samples {
generate_batch_indices(num_samples, actual_batch_size, true)
} else {
(0..num_samples).collect()
};
let batch_data = data_provider.get_batch(&batch_indices);
let mut gradient = grad_func.compute_gradient(&x.view(), &batch_data);
_grad_evals += 1;
if let Some(clip_threshold) = options.gradient_clip {
clip_gradients(&mut gradient, clip_threshold);
}
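        // Weight decay: decoupled (true AdamW) shrinks the parameters directly,
        // while coupled decay folds an L2 penalty into the gradient (classic Adam + L2).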
if options.decouple_weight_decay && options.weight_decay > 0.0 {
x = &x * (1.0 - current_lr * options.weight_decay);
} else if options.weight_decay > 0.0 {
gradient = &gradient + &x * options.weight_decay;
}
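        // Adam moment updates with bias correction: exponential moving averages of the
        // gradient (m) and squared gradient (v), rescaled to counteract their zero init.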
m = &m * options.beta1 + &gradient * (1.0 - options.beta1);
let gradient_sq = gradient.mapv(|g| g * g);
v = &v * options.beta2 + &gradient_sq * (1.0 - options.beta2);
let bias_correction1 = 1.0 - options.beta1.powi((iteration + 1) as i32);
let m_hat = &m / bias_correction1;
let bias_correction2 = 1.0 - options.beta2.powi((iteration + 1) as i32);
let v_hat = &v / bias_correction2;
let denominator = v_hat.mapv(|v: f64| v.sqrt() + options.epsilon);
let gradient_update = &m_hat / &denominator * current_lr;
x = &x - &gradient_update;
if iteration % 10 == 0 || iteration == options.max_iter - 1 {
let full_data = data_provider.get_full_data();
let current_loss = grad_func.compute_value(&x.view(), &full_data);
func_evals += 1;
if current_loss < best_f {
best_f = current_loss;
best_x = x.clone();
stagnant_iterations = 0;
} else {
stagnant_iterations += 1;
}
if iteration % 100 == 0 {
let grad_norm = gradient.mapv(|g| g * g).sum().sqrt();
let param_norm = x.mapv(|p| p * p).sum().sqrt();
let m_norm = m_hat.mapv(|g: f64| g * g).sum().sqrt();
                let v_mean = v_hat.mean().unwrap_or(0.0);
println!(" Iteration {}: loss = {:.6e}, |grad| = {:.3e}, |param| = {:.3e}, |m| = {:.3e}, <v> = {:.3e}, lr = {:.3e}",
iteration, current_loss, grad_norm, param_norm, m_norm, v_mean, current_lr);
}
let loss_change = (prev_loss - current_loss).abs();
if loss_change < options.tol {
return Ok(OptimizeResult {
x: best_x,
fun: best_f,
nit: iteration,
func_evals,
nfev: func_evals,
success: true,
message: format!(
"AdamW converged: loss change {:.2e} < {:.2e}",
loss_change, options.tol
),
jacobian: Some(gradient),
hessian: None,
});
}
prev_loss = current_loss;
if stagnant_iterations > 50 {
return Ok(OptimizeResult {
x: best_x,
fun: best_f,
nit: iteration,
func_evals,
nfev: func_evals,
success: false,
message: "AdamW stopped due to stagnation".to_string(),
jacobian: Some(gradient),
hessian: None,
});
}
}
}
let full_data = data_provider.get_full_data();
let final_loss = grad_func.compute_value(&best_x.view(), &full_data);
func_evals += 1;
Ok(OptimizeResult {
x: best_x,
fun: final_loss.min(best_f),
nit: options.max_iter,
func_evals,
nfev: func_evals,
success: false,
message: "AdamW reached maximum iterations".to_string(),
jacobian: None,
hessian: None,
})
}
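/// AdamW with warm restarts: runs repeated cosine-annealing cycles, resetting the
/// learning rate to `options.learning_rate` at the start of each cycle.
///
/// `t_initial` is the length of the first cycle in iterations, `t_mult` scales each
/// subsequent cycle's length, and `eta_min` is the learning-rate floor within a cycle.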
#[allow(dead_code)]
pub fn minimize_adamw_cosine_restarts<F>(
mut grad_func: F,
mut x: Array1<f64>,
data_provider: Box<dyn DataProvider>,
options: AdamWOptions,
t_initial: usize,
t_mult: f64,
eta_min: f64,
) -> Result<OptimizeResult<f64>, OptimizeError>
where
F: StochasticGradientFunction,
{
let mut current_cycle_length = t_initial;
let mut cycle_start = 0;
let mut restart_count = 0;
let initial_lr = options.learning_rate;
let total_max_iter = options.max_iter;
let mut global_best_x = x.clone();
let mut global_best_f = f64::INFINITY;
while cycle_start < total_max_iter {
let cycle_end = (cycle_start + current_cycle_length).min(total_max_iter);
println!(
"Starting restart {} (cycle {}-{}, length {})",
restart_count, cycle_start, cycle_end, current_cycle_length
);
let mut cycle_options = options.clone();
cycle_options.lr_schedule = LearningRateSchedule::CosineAnnealing;
cycle_options.max_iter = cycle_end - cycle_start;
cycle_options.learning_rate = initial_lr;
let cycle_result = minimize_adamw_cycle(
&mut grad_func,
x.clone(),
data_provider.as_ref(),
&cycle_options,
initial_lr,
eta_min,
cycle_start,
)?;
if cycle_result.fun < global_best_f {
global_best_f = cycle_result.fun;
global_best_x = cycle_result.x.clone();
}
        x = cycle_result.x;
        cycle_start = cycle_end;
        // Grow the next cycle by t_mult, keeping it at least one iteration long.
        current_cycle_length = ((current_cycle_length as f64 * t_mult) as usize).max(1);
restart_count += 1;
if global_best_f < options.tol {
break;
}
}
Ok(OptimizeResult {
x: global_best_x,
fun: global_best_f,
nit: cycle_start,
        func_evals: 0,
        nfev: 0,
success: global_best_f < options.tol,
message: format!(
"AdamW with cosine restarts completed ({} restarts)",
restart_count
),
jacobian: None,
hessian: None,
})
}
#[allow(dead_code)]
fn minimize_adamw_cycle<F>(
grad_func: &mut F,
mut x: Array1<f64>,
data_provider: &dyn DataProvider,
options: &AdamWOptions,
lr_max: f64,
lr_min: f64,
cycle_offset: usize,
) -> Result<OptimizeResult<f64>, OptimizeError>
where
F: StochasticGradientFunction,
{
let mut m: Array1<f64> = Array1::zeros(x.len());
let mut v: Array1<f64> = Array1::zeros(x.len());
let mut best_x = x.clone();
let mut best_f = f64::INFINITY;
let num_samples = data_provider.num_samples();
    // Same batch-size heuristic as `minimize_adamw`, clamped to at least one sample.
    let batch_size = options
        .batch_size
        .unwrap_or_else(|| 32.min(num_samples / 10).max(1));
    let actual_batch_size = batch_size.min(num_samples).max(1);
#[allow(clippy::explicit_counter_loop)]
for iteration in 0..options.max_iter {
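        // Cosine annealing within the cycle: lr_min + 0.5 * (lr_max - lr_min) * (1 + cos(pi * progress)),
        // i.e. lr_max at the start of the cycle, lr_min at the end.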
let progress = iteration as f64 / options.max_iter as f64;
let current_lr =
lr_min + 0.5 * (lr_max - lr_min) * (1.0 + (std::f64::consts::PI * progress).cos());
let batch_indices = if actual_batch_size < num_samples {
generate_batch_indices(num_samples, actual_batch_size, true)
} else {
(0..num_samples).collect()
};
let batch_data = data_provider.get_batch(&batch_indices);
let mut gradient = grad_func.compute_gradient(&x.view(), &batch_data);
if let Some(clip_threshold) = options.gradient_clip {
clip_gradients(&mut gradient, clip_threshold);
}
        if options.decouple_weight_decay && options.weight_decay > 0.0 {
            // Decoupled weight decay (AdamW): shrink parameters directly.
            x = &x * (1.0 - current_lr * options.weight_decay);
        } else if options.weight_decay > 0.0 {
            // Coupled weight decay: fold the L2 penalty into the gradient,
            // mirroring the behavior of `minimize_adamw`.
            gradient = &gradient + &x * options.weight_decay;
        }
m = &m * options.beta1 + &gradient * (1.0 - options.beta1);
let gradient_sq = gradient.mapv(|g| g * g);
v = &v * options.beta2 + &gradient_sq * (1.0 - options.beta2);
let global_step = cycle_offset + iteration + 1;
let bias_correction1 = 1.0 - options.beta1.powi(global_step as i32);
let bias_correction2 = 1.0 - options.beta2.powi(global_step as i32);
let m_hat = &m / bias_correction1;
let v_hat = &v / bias_correction2;
let denominator = v_hat.mapv(|v: f64| v.sqrt() + options.epsilon);
let update = &m_hat / &denominator * current_lr;
x = &x - &update;
if iteration % 10 == 0 {
let full_data = data_provider.get_full_data();
let current_loss = grad_func.compute_value(&x.view(), &full_data);
if current_loss < best_f {
best_f = current_loss;
best_x = x.clone();
}
}
}
Ok(OptimizeResult {
x: best_x,
fun: best_f,
nit: options.max_iter,
func_evals: 0,
nfev: 0,
success: false,
message: "Cycle completed".to_string(),
jacobian: None,
hessian: None,
})
}
#[cfg(test)]
mod tests {
use super::*;
use crate::stochastic::InMemoryDataProvider;
use approx::assert_abs_diff_eq;
use scirs2_core::ndarray::ArrayView1;
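    /// Simple convex test fixture: f(x) = sum(x_i^2) with gradient 2 * x (batch data ignored).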
struct QuadraticFunction;
impl StochasticGradientFunction for QuadraticFunction {
fn compute_gradient(&mut self, x: &ArrayView1<f64>, _batchdata: &[f64]) -> Array1<f64> {
x.mapv(|xi| 2.0 * xi)
}
fn compute_value(&mut self, x: &ArrayView1<f64>, _batchdata: &[f64]) -> f64 {
x.mapv(|xi| xi * xi).sum()
}
}
#[test]
fn test_adamw_quadratic() {
let grad_func = QuadraticFunction;
let x0 = Array1::from_vec(vec![1.0, 2.0, -1.5]);
let data_provider = Box::new(InMemoryDataProvider::new(vec![1.0; 100]));
let options = AdamWOptions {
learning_rate: 0.1,
max_iter: 200,
tol: 1e-6,
..Default::default()
};
let result =
minimize_adamw(grad_func, x0, data_provider, options).expect("Operation failed");
assert!(result.success || result.fun < 1e-4);
for &xi in result.x.iter() {
assert_abs_diff_eq!(xi, 0.0, epsilon = 1e-2);
}
}
#[test]
fn test_adamw_weight_decay() {
let grad_func = QuadraticFunction;
let x0 = Array1::from_vec(vec![1.0, -1.0]);
let data_provider = Box::new(InMemoryDataProvider::new(vec![1.0; 50]));
let options = AdamWOptions {
learning_rate: 0.1,
weight_decay: 0.01,
max_iter: 100,
batch_size: Some(10),
tol: 1e-6,
..Default::default()
};
let result =
minimize_adamw(grad_func, x0, data_provider, options).expect("Operation failed");
assert!(result.success || result.fun < 1e-4);
}
#[test]
fn test_adamw_decoupled_vs_coupled() {
let x0 = Array1::from_vec(vec![2.0, -2.0]);
let data_provider1 = Box::new(InMemoryDataProvider::new(vec![1.0; 100]));
let data_provider2 = Box::new(InMemoryDataProvider::new(vec![1.0; 100]));
let options_decoupled = AdamWOptions {
learning_rate: 0.01,
weight_decay: 0.1,
decouple_weight_decay: true,
max_iter: 500,
tol: 1e-4,
..Default::default()
};
let grad_func1 = QuadraticFunction;
let result_decoupled =
minimize_adamw(grad_func1, x0.clone(), data_provider1, options_decoupled)
.expect("Operation failed");
let options_coupled = AdamWOptions {
learning_rate: 0.01,
weight_decay: 0.1,
decouple_weight_decay: false,
            max_iter: 500,
            tol: 1e-4,
..Default::default()
};
let grad_func2 = QuadraticFunction;
let result_coupled = minimize_adamw(grad_func2, x0, data_provider2, options_coupled)
.expect("Operation failed");
assert!(result_decoupled.fun < 1.0);
assert!(result_coupled.fun < 1.0);
}
#[test]
fn test_adamw_cosine_restarts() {
let grad_func = QuadraticFunction;
let x0 = Array1::from_vec(vec![3.0, -3.0]);
let data_provider = Box::new(InMemoryDataProvider::new(vec![1.0; 100]));
let options = AdamWOptions {
learning_rate: 0.1,
max_iter: 500,
tol: 1e-4,
..Default::default()
};
        let result = minimize_adamw_cosine_restarts(
            grad_func,
            x0,
            data_provider,
            options,
            50,   // t_initial: length of the first cosine cycle
            1.5,  // t_mult: cycle length growth factor
            1e-6, // eta_min: minimum learning rate within a cycle
        )
        .expect("Operation failed");
        assert!(result.fun < 10.0);
    }
#[test]
fn test_adamw_gradient_clipping() {
let grad_func = QuadraticFunction;
        let x0 = Array1::from_vec(vec![10.0, -10.0]);
        let data_provider = Box::new(InMemoryDataProvider::new(vec![1.0; 50]));
let options = AdamWOptions {
            learning_rate: 0.1,
            max_iter: 1000,
            gradient_clip: Some(1.0),
            tol: 1e-4,
..Default::default()
};
let result =
minimize_adamw(grad_func, x0, data_provider, options).expect("Operation failed");
assert!(result.success || result.fun < 1e-1);
}
}