use std::fmt::Debug;
use std::ops::Add;
use types::{Function, Function1};
/// Strategy for picking the next position along a search direction.
///
/// An implementor is handed the objective `function`, the `initial_position` the search starts
/// from and the `direction` to move along, and returns the position it selected.
pub trait LineSearch: Debug {
    /// Returns the position chosen by moving from `initial_position` along `direction`.
    ///
    /// `initial_position` and `direction` are coordinate slices; the result is a freshly
    /// allocated coordinate vector.
    fn search<F: Function1>(
        &self,
        function: &F,
        initial_position: &[f64],
        direction: &[f64],
    ) -> Vec<f64>;
}
/// Line search that always moves by the same step width along the search direction.
#[derive(Clone, Copy, Debug)]
pub struct FixedStepWidth {
    /// Factor the direction is multiplied by before it is added to the start position.
    fixed_step_width: f64,
}
impl FixedStepWidth {
pub fn new(fixed_step_width: f64) -> FixedStepWidth {
assert!(fixed_step_width > 0.0 && fixed_step_width.is_finite(),
"fixed_step_width must be greater than 0 and finite");
FixedStepWidth {
fixed_step_width
}
}
}
impl LineSearch for FixedStepWidth {
    /// Returns `initial_position + fixed_step_width * direction`, computed element-wise.
    ///
    /// The objective `_function` is never evaluated. Coordinates are paired with `zip`, so the
    /// result is only as long as the shorter of `initial_position` and `direction`.
    fn search<F>(&self, _function: &F, initial_position: &[f64], direction: &[f64]) -> Vec<f64>
    where
        F: Function,
    {
        let step = self.fixed_step_width;
        initial_position
            .iter()
            .zip(direction)
            .map(|(&x, &d)| x + step * d)
            .collect()
    }
}
/// Line search that tries a geometrically growing sequence of step widths and keeps the
/// sampled position with the lowest function value.
#[derive(Clone, Copy, Debug)]
pub struct ExactLineSearch {
    /// First step width that is tried.
    start_step_width: f64,
    /// Sampling ends as soon as the step width reaches or exceeds this value.
    stop_step_width: f64,
    /// Factor the step width is multiplied by after every sample.
    increase_factor: f64,
}
impl ExactLineSearch {
pub fn new(start_step_width: f64, stop_step_width: f64, increase_factor: f64) ->
ExactLineSearch
{
assert!(start_step_width > 0.0 && start_step_width.is_finite(),
"start_step_width must be greater than 0 and finite");
assert!(stop_step_width > start_step_width && stop_step_width.is_finite(),
"stop_step_width must be greater than start_step_width");
assert!(increase_factor > 1.0 && increase_factor.is_finite(),
"increase_factor must be greater than 1 and finite");
ExactLineSearch {
start_step_width,
stop_step_width,
increase_factor
}
}
}
impl LineSearch for ExactLineSearch {
    /// Samples `initial_position + step * direction` for `step = start_step_width`,
    /// `start_step_width * increase_factor`, ... and returns the sampled position with the
    /// smallest function value.
    ///
    /// `initial_position` itself takes part in the comparison, so it is returned when no sample
    /// is strictly better. On ties the earliest candidate wins. The first step width is always
    /// sampled; sampling ends once the grown step width is `>= stop_step_width`.
    fn search<F>(&self, function: &F, initial_position: &[f64], direction: &[f64]) -> Vec<f64>
    where
        F: Function1,
    {
        let mut best_value = function.value(initial_position);
        let mut best_position = initial_position.to_vec();
        let mut step = self.start_step_width;

        loop {
            let candidate: Vec<f64> = initial_position
                .iter()
                .zip(direction)
                .map(|(&x, &d)| x + step * d)
                .collect();
            let candidate_value = function.value(&candidate);

            // Strict comparison: an equally good later sample does not replace the current best.
            if candidate_value < best_value {
                best_value = candidate_value;
                best_position = candidate;
            }

            step *= self.increase_factor;
            if step >= self.stop_step_width {
                break best_position;
            }
        }
    }
}
/// Backtracking line search that shrinks the step width until the function value has dropped
/// by a required amount.
#[derive(Clone, Copy, Debug)]
pub struct ArmijoLineSearch {
    /// Scales how much decrease is required per unit of step width.
    control_parameter: f64,
    /// First step width that is tried.
    initial_step_width: f64,
    /// Factor the step width is multiplied by after a rejected trial.
    decay_factor: f64,
}
impl ArmijoLineSearch {
pub fn new(control_parameter: f64, initial_step_width: f64, decay_factor: f64) ->
ArmijoLineSearch
{
assert!(control_parameter > 0.0 && control_parameter < 1.0,
"control_parameter must be in range (0, 1)");
assert!(initial_step_width > 0.0 && initial_step_width.is_finite(),
"initial_step_width must be > 0 and finite");
assert!(decay_factor > 0.0 && decay_factor < 1.0, "decay_factor must be in range (0, 1)");
ArmijoLineSearch {
control_parameter,
initial_step_width,
decay_factor
}
}
}
impl LineSearch for ArmijoLineSearch {
    /// Backtracks from `initial_step_width`, multiplying the step width by `decay_factor` after
    /// every rejected trial, and returns the first position
    /// `initial_position + step * direction` whose function value satisfies
    ///
    /// `value <= initial_value - step * t`, with `t = -control_parameter * (gradient . direction)`.
    ///
    /// If the step width can no longer shrink (it has underflowed to zero or is stuck at a
    /// subnormal value) without the condition ever holding, for example because the objective
    /// returns NaN, `initial_position` is returned unchanged instead of looping forever.
    ///
    /// # Panics
    ///
    /// Panics if `t` is not greater than 0, i.e. if `direction` is not a descent direction
    /// according to the gradient at `initial_position` (this includes a NaN gradient).
    fn search<F>(&self, function: &F, initial_position: &[f64], direction: &[f64]) -> Vec<f64>
    where
        F: Function1,
    {
        let initial_value = function.value(initial_position);
        let gradient = function.gradient(initial_position);

        // Directional derivative along `direction`; negative for a descent direction.
        let m = gradient.iter().zip(direction).map(|(g, d)| g * d).fold(0.0, Add::add);
        let t = -self.control_parameter * m;
        assert!(
            t > 0.0,
            "direction must be a descent direction (negative directional derivative)"
        );

        let mut step_width = self.initial_step_width;
        loop {
            let position: Vec<_> = initial_position
                .iter()
                .zip(direction)
                .map(|(&x, &d)| x + step_width * d)
                .collect();
            let value = function.value(&position);
            if value <= initial_value - step_width * t {
                return position;
            }

            let next_step_width = step_width * self.decay_factor;
            // Once multiplying no longer shrinks the step, every further iteration would repeat
            // this one exactly; give up and stay at the start.
            if next_step_width >= step_width {
                return initial_position.to_vec();
            }
            step_width = next_step_width;
        }
    }
}