use super::line_search::{backtracking_line_search, wolfe_line_search, LineSearchMethod};
use super::utilities::{
tensor_add, tensor_elementwise_div, tensor_elementwise_mul, tensor_full_like, tensor_norm,
tensor_scalar_mul, tensor_sqrt, tensor_sub, tensor_zeros_like,
};
use torsh_core::Result as TorshResult;
use torsh_tensor::Tensor;
/// Configuration for [`gradient_descent`].
#[derive(Debug, Clone)]
pub struct GradientDescentParams {
/// Fixed step size used when `line_search` is `None` (or an unhandled variant).
pub learning_rate: f32,
/// Maximum number of descent iterations.
pub max_iter: usize,
/// Convergence threshold on the gradient norm.
pub tolerance: f32,
/// Optional line-search strategy for choosing the step length each iteration.
pub line_search: Option<LineSearchMethod>,
}
impl Default for GradientDescentParams {
fn default() -> Self {
Self {
learning_rate: 0.01,
max_iter: 1000,
tolerance: 1e-6,
line_search: Some(LineSearchMethod::Backtracking),
}
}
}
/// Minimize `objective` via steepest descent starting from `x0`.
///
/// Each iteration moves along the negative gradient; the step length comes
/// from the configured line search, or from the fixed `learning_rate` when no
/// line search is selected. Iteration stops when the gradient norm drops
/// below `tolerance` or `max_iter` iterations have run.
///
/// Returns the final iterate together with the objective value recorded at
/// the start of every iteration performed.
///
/// # Errors
/// Propagates any error from `objective`, `gradient`, the line search, or
/// the tensor arithmetic helpers.
pub fn gradient_descent<F, G>(
    objective: F,
    gradient: G,
    x0: &Tensor,
    params: Option<GradientDescentParams>,
) -> TorshResult<(Tensor, Vec<f32>)>
where
    F: Fn(&Tensor) -> TorshResult<f32>,
    G: Fn(&Tensor) -> TorshResult<Tensor>,
{
    let params = params.unwrap_or_default();
    let mut x = x0.clone();
    // Preallocate: at most one objective value is recorded per iteration.
    let mut objective_values = Vec::with_capacity(params.max_iter);
    for iter in 0..params.max_iter {
        let f_val = objective(&x)?;
        objective_values.push(f_val);
        let grad = gradient(&x)?;
        let grad_norm = tensor_norm(&grad)?;
        // Converged: gradient is (numerically) zero.
        if grad_norm < params.tolerance {
            break;
        }
        // Steepest-descent direction: p = -∇f(x).
        let p = tensor_scalar_mul(&grad, -1.0)?;
        let alpha = match params.line_search {
            Some(LineSearchMethod::Backtracking) => {
                backtracking_line_search(&objective, &gradient, &x, &p, None)?
            }
            Some(LineSearchMethod::Wolfe) => {
                wolfe_line_search(&objective, &gradient, &x, &p, None)?
            }
            // `None` — and any line-search variant not handled above —
            // falls back to the fixed learning rate.
            _ => params.learning_rate,
        };
        // x ← x + α·p
        x = tensor_add(&x, &tensor_scalar_mul(&p, alpha)?)?;
        // Progress report every 100 iterations (including iteration 0).
        if iter % 100 == 0 {
            println!(
                "Iteration {}: f = {:.6e}, |∇f| = {:.6e}, α = {:.6e}",
                iter, f_val, grad_norm, alpha
            );
        }
    }
    Ok((x, objective_values))
}
/// Configuration for [`momentum_gradient_descent`].
#[derive(Debug, Clone)]
pub struct MomentumParams {
/// Step size applied to the accumulated velocity each iteration.
pub learning_rate: f32,
/// Decay factor for the velocity accumulator (classical momentum coefficient).
pub momentum: f32,
/// Maximum number of descent iterations.
pub max_iter: usize,
/// Convergence threshold on the gradient norm.
pub tolerance: f32,
}
impl Default for MomentumParams {
fn default() -> Self {
Self {
learning_rate: 0.01,
momentum: 0.9,
max_iter: 1000,
tolerance: 1e-6,
}
}
}
/// Minimize `objective` with classical (heavy-ball) momentum starting from `x0`.
///
/// Update rule per iteration:
///   v ← μ·v + ∇f(x)
///   x ← x − η·v
/// where μ is `momentum` and η is `learning_rate`. Iteration stops when the
/// gradient norm drops below `tolerance` or `max_iter` iterations have run.
///
/// Returns the final iterate together with the objective value recorded at
/// the start of every iteration performed.
///
/// # Errors
/// Propagates any error from `objective`, `gradient`, or the tensor helpers.
pub fn momentum_gradient_descent<F, G>(
    objective: F,
    gradient: G,
    x0: &Tensor,
    params: Option<MomentumParams>,
) -> TorshResult<(Tensor, Vec<f32>)>
where
    F: Fn(&Tensor) -> TorshResult<f32>,
    G: Fn(&Tensor) -> TorshResult<Tensor>,
{
    let params = params.unwrap_or_default();
    let mut x = x0.clone();
    // Velocity accumulator, initialized to zero.
    let mut v = tensor_zeros_like(&x)?;
    // Preallocate: at most one objective value is recorded per iteration.
    let mut objective_values = Vec::with_capacity(params.max_iter);
    for iter in 0..params.max_iter {
        let f_val = objective(&x)?;
        objective_values.push(f_val);
        let grad = gradient(&x)?;
        let grad_norm = tensor_norm(&grad)?;
        // Converged: gradient is (numerically) zero.
        if grad_norm < params.tolerance {
            break;
        }
        // v ← μ·v + ∇f(x)
        v = tensor_add(&tensor_scalar_mul(&v, params.momentum)?, &grad)?;
        // x ← x − η·v
        x = tensor_sub(&x, &tensor_scalar_mul(&v, params.learning_rate)?)?;
        // Progress report every 100 iterations (including iteration 0).
        if iter % 100 == 0 {
            println!(
                "Iteration {}: f = {:.6e}, |∇f| = {:.6e}",
                iter, f_val, grad_norm
            );
        }
    }
    Ok((x, objective_values))
}
/// Configuration for [`adam_optimizer`].
#[derive(Debug, Clone)]
pub struct AdamParams {
/// Step size applied to the bias-corrected update each iteration.
pub learning_rate: f32,
/// Exponential decay rate for the first-moment (mean) estimate.
pub beta1: f32,
/// Exponential decay rate for the second-moment (uncentered variance) estimate.
pub beta2: f32,
/// Small constant added to the denominator for numerical stability.
pub epsilon: f32,
/// Maximum number of iterations.
pub max_iter: usize,
/// Convergence threshold on the gradient norm.
pub tolerance: f32,
}
impl Default for AdamParams {
fn default() -> Self {
Self {
learning_rate: 0.001,
beta1: 0.9,
beta2: 0.999,
epsilon: 1e-8,
max_iter: 1000,
tolerance: 1e-6,
}
}
}
/// Minimize `objective` with the Adam optimizer starting from `x0`.
///
/// Per iteration (t counted from 1):
///   m ← β₁·m + (1−β₁)·g          (first moment)
///   v ← β₂·v + (1−β₂)·g⊙g        (second moment)
///   m̂ = m / (1 − β₁ᵗ), v̂ = v / (1 − β₂ᵗ)   (bias correction)
///   x ← x − η · m̂ / (√v̂ + ε)
/// Iteration stops when the gradient norm drops below `tolerance` or
/// `max_iter` iterations have run.
///
/// Returns the final iterate together with the objective value recorded at
/// the start of every iteration performed.
///
/// # Errors
/// Propagates any error from `objective`, `gradient`, or the tensor helpers.
pub fn adam_optimizer<F, G>(
    objective: F,
    gradient: G,
    x0: &Tensor,
    params: Option<AdamParams>,
) -> TorshResult<(Tensor, Vec<f32>)>
where
    F: Fn(&Tensor) -> TorshResult<f32>,
    G: Fn(&Tensor) -> TorshResult<Tensor>,
{
    let params = params.unwrap_or_default();
    let mut x = x0.clone();
    // First- and second-moment accumulators, initialized to zero.
    let mut m = tensor_zeros_like(&x)?;
    let mut v = tensor_zeros_like(&x)?;
    // Hoisted out of the loop: the ε tensor is loop-invariant because every
    // update is elementwise, so x keeps the shape of x0 throughout.
    let epsilon_tensor = tensor_full_like(&x, params.epsilon)?;
    // Preallocate: at most one objective value is recorded per iteration.
    let mut objective_values = Vec::with_capacity(params.max_iter);
    for iter in 0..params.max_iter {
        // Adam's time step is 1-based so the bias-correction terms are nonzero.
        let t = (iter + 1) as f32;
        let f_val = objective(&x)?;
        objective_values.push(f_val);
        let grad = gradient(&x)?;
        let grad_norm = tensor_norm(&grad)?;
        // Converged: gradient is (numerically) zero.
        if grad_norm < params.tolerance {
            break;
        }
        // m ← β₁·m + (1−β₁)·g
        m = tensor_add(
            &tensor_scalar_mul(&m, params.beta1)?,
            &tensor_scalar_mul(&grad, 1.0 - params.beta1)?,
        )?;
        // v ← β₂·v + (1−β₂)·g⊙g
        let grad_squared = tensor_elementwise_mul(&grad, &grad)?;
        v = tensor_add(
            &tensor_scalar_mul(&v, params.beta2)?,
            &tensor_scalar_mul(&grad_squared, 1.0 - params.beta2)?,
        )?;
        // Bias-corrected moment estimates.
        let m_hat = tensor_scalar_mul(&m, 1.0 / (1.0 - params.beta1.powf(t)))?;
        let v_hat = tensor_scalar_mul(&v, 1.0 / (1.0 - params.beta2.powf(t)))?;
        // x ← x − η · m̂ / (√v̂ + ε)
        let denominator = tensor_add(&tensor_sqrt(&v_hat)?, &epsilon_tensor)?;
        let update = tensor_elementwise_div(&m_hat, &denominator)?;
        x = tensor_sub(&x, &tensor_scalar_mul(&update, params.learning_rate)?)?;
        // Progress report every 100 iterations (including iteration 0).
        if iter % 100 == 0 {
            println!(
                "Iteration {}: f = {:.6e}, |∇f| = {:.6e}",
                iter, f_val, grad_norm
            );
        }
    }
    Ok((x, objective_values))
}