use super::utilities::{dot_product, tensor_add, tensor_scalar_mul};
use torsh_core::{Result as TorshResult, TorshError};
use torsh_tensor::Tensor;
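
/// Strategy for choosing the step size along a search direction.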
#[derive(Debug, Clone, Copy)]
pub enum LineSearchMethod {
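    /// Backtracking search enforcing the Armijo sufficient-decrease condition.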
Backtracking,
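    /// Exact minimization of the objective along the search direction.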
Exact,
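    /// Search satisfying the (weak) Wolfe conditions.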
Wolfe,
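    /// Search satisfying the strong Wolfe conditions.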
StrongWolfe,
}
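
/// Parameters for [`backtracking_line_search`].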
#[derive(Debug, Clone)]
pub struct BacktrackingParams {
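    /// Initial trial step size; a full step of 1.0 is the usual choice.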
pub alpha0: f32,
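    /// Sufficient-decrease constant in the Armijo condition, with 0 < c1 < 1.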
pub c1: f32,
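    /// Multiplicative shrink factor applied after each rejected step, with 0 < rho < 1.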
pub rho: f32,
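    /// Maximum number of step-size reductions before giving up.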
pub max_iter: usize,
}
impl Default for BacktrackingParams {
fn default() -> Self {
Self {
alpha0: 1.0,
c1: 1e-4,
rho: 0.5,
max_iter: 50,
}
}
}
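
/// Backtracking (Armijo) line search.
///
/// Starting from `alpha0`, the trial step is repeatedly multiplied by `rho`
/// until the sufficient-decrease condition
///
/// ```text
/// f(x + alpha * p) <= f(x) + c1 * alpha * grad_f(x) . p
/// ```
///
/// holds. If the iteration budget runs out, the last (smallest) trial step
/// is returned.
///
/// # Errors
///
/// Returns [`TorshError::InvalidArgument`] if `p` is not a descent direction,
/// i.e. if `grad_f(x) . p >= 0`.
///
/// # Examples
///
/// A minimal sketch for `f(x) = x . x` (so `grad_f(x) = 2x`). The
/// `Tensor::from_vec` constructor is assumed for illustration; adapt it to
/// the actual `torsh_tensor` API:
///
/// ```ignore
/// let x = Tensor::from_vec(vec![1.0_f32, 2.0], &[2])?;
/// let p = Tensor::from_vec(vec![-2.0_f32, -4.0], &[2])?; // p = -grad_f(x)
/// let f = |x: &Tensor| dot_product(x, x);
/// let g = |x: &Tensor| tensor_scalar_mul(x, 2.0);
/// let alpha = backtracking_line_search(f, g, &x, &p, None)?;
/// assert!(alpha > 0.0 && alpha <= 1.0);
/// ```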
pub fn backtracking_line_search<F, G>(
objective: F,
gradient: G,
x: &Tensor,
p: &Tensor,
params: Option<BacktrackingParams>,
) -> TorshResult<f32>
where
F: Fn(&Tensor) -> TorshResult<f32>,
G: Fn(&Tensor) -> TorshResult<Tensor>,
{
let params = params.unwrap_or_default();
    let f0 = objective(x)?;
    let grad0 = gradient(x)?;
    // Directional derivative of f at x along p: grad_f(x) . p.
    let directional_deriv = dot_product(&grad0, p)?;
    // The Armijo condition is only meaningful for descent directions.
    if directional_deriv >= 0.0 {
        return Err(TorshError::InvalidArgument(
            "Search direction is not a descent direction".to_string(),
        ));
    }
    let mut alpha = params.alpha0;
    for _ in 0..params.max_iter {
        let x_new = tensor_add(x, &tensor_scalar_mul(p, alpha)?)?;
        let f_new = objective(&x_new)?;
        // Armijo sufficient-decrease condition:
        // f(x + alpha * p) <= f(x) + c1 * alpha * grad_f(x) . p.
        if f_new <= f0 + params.c1 * alpha * directional_deriv {
            return Ok(alpha);
        }
        alpha *= params.rho;
    }
    // Iteration budget exhausted; return the last (smallest) trial step.
    Ok(alpha)
}
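
/// Parameters for [`wolfe_line_search`].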
#[derive(Debug, Clone)]
pub struct WolfeParams {
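    /// Sufficient-decrease constant; 1e-4 is the conventional choice.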
pub c1: f32,
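    /// Curvature constant with c1 < c2 < 1; 0.9 is conventional for quasi-Newton methods.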
pub c2: f32,
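    /// Initial trial step size.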
pub alpha0: f32,
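    /// Upper bound on the step size.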
pub alpha_max: f32,
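    /// Maximum number of bracketing/bisection iterations.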
pub max_iter: usize,
}
impl Default for WolfeParams {
fn default() -> Self {
Self {
c1: 1e-4,
c2: 0.9,
alpha0: 1.0,
alpha_max: 100.0,
max_iter: 20,
}
}
}
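
/// Line search enforcing the strong Wolfe conditions.
///
/// A step `alpha` is accepted once it satisfies both the sufficient-decrease
/// condition and the strong curvature condition:
///
/// ```text
/// f(x + alpha * p) <= f(x) + c1 * alpha * grad_f(x) . p
/// |grad_f(x + alpha * p) . p| <= c2 * |grad_f(x) . p|
/// ```
///
/// The search expands the step geometrically until a bracket around an
/// acceptable step is established, then bisects that bracket. If the
/// iteration budget runs out, the current trial step is returned.
///
/// # Errors
///
/// Returns [`TorshError::InvalidArgument`] if `p` is not a descent direction.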
pub fn wolfe_line_search<F, G>(
objective: F,
gradient: G,
x: &Tensor,
p: &Tensor,
params: Option<WolfeParams>,
) -> TorshResult<f32>
where
F: Fn(&Tensor) -> TorshResult<f32>,
G: Fn(&Tensor) -> TorshResult<Tensor>,
{
let params = params.unwrap_or_default();
    let f0 = objective(x)?;
    let grad0 = gradient(x)?;
    // Directional derivative of f at x along p: grad_f(x) . p.
    let directional_deriv0 = dot_product(&grad0, p)?;
    // The Wolfe conditions are only meaningful for descent directions.
    if directional_deriv0 >= 0.0 {
        return Err(TorshError::InvalidArgument(
            "Search direction is not a descent direction".to_string(),
        ));
    }
    // Bracket [alpha_lo, alpha_hi] around an acceptable step. `bracketed`
    // records whether a finite upper bound has actually been established;
    // until then the step is expanded rather than bisected.
    let mut alpha_lo = 0.0_f32;
    let mut alpha_hi = params.alpha_max;
    let mut bracketed = false;
    let mut alpha = params.alpha0;
    for _ in 0..params.max_iter {
        let x_new = tensor_add(x, &tensor_scalar_mul(p, alpha)?)?;
        let f_new = objective(&x_new)?;
        // Sufficient decrease failed: the step is too long.
        if f_new > f0 + params.c1 * alpha * directional_deriv0 {
            alpha_hi = alpha;
            bracketed = true;
            alpha = 0.5 * (alpha_lo + alpha_hi);
            continue;
        }
        let grad_new = gradient(&x_new)?;
        let directional_deriv_new = dot_product(&grad_new, p)?;
        // Strong curvature condition:
        // |grad_f(x + alpha * p) . p| <= c2 * |grad_f(x) . p|.
        if directional_deriv_new.abs() <= -params.c2 * directional_deriv0 {
            return Ok(alpha);
        }
        if directional_deriv_new >= 0.0 {
            // The slope turned non-negative, so the minimizer along p has
            // been overshot: shrink the bracket from above.
            alpha_hi = alpha;
            bracketed = true;
        } else {
            // The slope is still too negative: the step is too short.
            alpha_lo = alpha;
        }
        alpha = if bracketed {
            0.5 * (alpha_lo + alpha_hi)
        } else {
            // No upper bound established yet; expand geometrically,
            // never past alpha_max.
            (2.0 * alpha).min(params.alpha_max)
        };
    }
    // Iteration budget exhausted; return the current trial step.
    Ok(alpha)
}