use std;
use errors::*;
use prelude::*;
use problem::ArgminProblem;
use result::ArgminResult;
use termination::TerminationReason;
/// Nelder-Mead (downhill simplex) solver.
///
/// The four coefficients control the simplex moves performed in each iteration
/// (reflection, expansion, contraction, shrink); they can be changed with the
/// builder-style setters of `NelderMead`.
pub struct NelderMead<'a> {
    /// Reflection coefficient.
    alpha: f64,
    /// Expansion coefficient.
    gamma: f64,
    /// Contraction coefficient.
    rho: f64,
    /// Shrink coefficient.
    sigma: f64,
    /// Iteration limit used by the termination check.
    max_iters: u64,
    /// Simplex and bookkeeping; `None` until the solver has been initialized.
    state: Option<NelderMeadState<'a>>,
}
/// One vertex of the simplex: a parameter vector together with its cost.
#[derive(Debug, Clone)]
struct NelderMeadParam {
    /// Coordinates of the vertex.
    param: Vec<f64>,
    /// Cost function value at `param`.
    cost: f64,
}
/// Mutable solver state: the problem, the current simplex and the iteration count.
struct NelderMeadState<'a> {
    /// Number of iterations performed so far.
    iter: u64,
    /// Problem whose cost function is evaluated at the simplex vertices.
    problem: &'a ArgminProblem<'a, Vec<f64>, f64, Vec<f64>>,
    /// Vertices of the simplex together with their costs.
    param_vecs: Vec<NelderMeadParam>,
}
impl<'a> NelderMeadState<'a> {
pub fn new(
problem: &'a ArgminProblem<'a, Vec<f64>, f64, Vec<f64>>,
param_vecs: Vec<NelderMeadParam>,
) -> Self {
NelderMeadState {
problem: problem,
param_vecs: param_vecs,
iter: 0_u64,
}
}
}
impl<'a> NelderMead<'a> {
    /// Creates a solver with reflection `alpha = 1.0`, expansion `gamma = 2.0`,
    /// contraction `rho = 0.5`, shrink `sigma = 0.5` and no practical iteration
    /// limit (`max_iters = u64::MAX`). The solver holds no state until it is initialized.
    pub fn new() -> Self {
        NelderMead {
            max_iters: std::u64::MAX,
            alpha: 1.0,
            gamma: 2.0,
            rho: 0.5,
            sigma: 0.5,
            state: None,
        }
    }

    /// Sets the maximum number of iterations (default: `u64::MAX`).
    pub fn max_iters(&mut self, max_iters: u64) -> &mut Self {
        self.max_iters = max_iters;
        self
    }

    /// Sets the reflection coefficient `alpha` (default: 1.0).
    pub fn alpha(&mut self, alpha: f64) -> &mut Self {
        self.alpha = alpha;
        self
    }

    /// Sets the expansion coefficient `gamma` (default: 2.0).
    pub fn gamma(&mut self, gamma: f64) -> &mut Self {
        self.gamma = gamma;
        self
    }

    /// Sets the contraction coefficient `rho` (default: 0.5).
    pub fn rho(&mut self, rho: f64) -> &mut Self {
        self.rho = rho;
        self
    }

    /// Sets the shrink coefficient `sigma` (default: 0.5).
    pub fn sigma(&mut self, sigma: f64) -> &mut Self {
        self.sigma = sigma;
        self
    }

    /// Sorts the simplex vertices by ascending cost: index 0 is the best vertex,
    /// the last index the worst one. Vertices whose cost is NaN end up last.
    fn sort_param_vecs(&self, state: &mut NelderMeadState) {
        state
            .param_vecs
            .sort_by(|a, b| Self::compare_costs(a.cost, b.cost));
    }

    /// Total order on costs: numeric order for regular values, NaN after every
    /// number.
    ///
    /// Mapping NaN to `Equal` would not be a total order (NaN would be "equal" to
    /// both 1.0 and 2.0), which leaves the result of `sort_by` unspecified. Placing
    /// NaN last makes a vertex with an undefined cost the worst one, so the
    /// simplex moves away from it.
    fn compare_costs(a: f64, b: f64) -> std::cmp::Ordering {
        match (a.is_nan(), b.is_nan()) {
            (false, false) => a.partial_cmp(&b).unwrap_or(std::cmp::Ordering::Equal),
            (false, true) => std::cmp::Ordering::Less,
            (true, false) => std::cmp::Ordering::Greater,
            (true, true) => std::cmp::Ordering::Equal,
        }
    }

    /// Returns the centroid of all vertices except the last (worst, once sorted) one.
    ///
    /// Panics if the simplex is empty.
    fn calculate_centroid(&self, state: &NelderMeadState) -> Vec<f64> {
        let num_param = state.param_vecs.len() - 1;
        let mut x0 = vec![0.0_f64; state.param_vecs[0].param.len()];
        for vertex in &state.param_vecs[..num_param] {
            for (acc, x) in x0.iter_mut().zip(vertex.param.iter()) {
                *acc += *x;
            }
        }
        let n = num_param as f64;
        for c in x0.iter_mut() {
            *c /= n;
        }
        x0
    }

    /// Reflects `x` through the centroid `x0`: `x0 + alpha * (x0 - x)`.
    fn reflect(&self, x0: &[f64], x: &[f64]) -> Vec<f64> {
        x0.iter()
            .zip(x.iter())
            .map(|(a, b)| a + self.alpha * (a - b))
            .collect()
    }

    /// Moves from the centroid `x0` past `x` by the expansion factor:
    /// `x0 + gamma * (x - x0)`.
    fn expand(&self, x0: &[f64], x: &[f64]) -> Vec<f64> {
        x0.iter()
            .zip(x.iter())
            .map(|(a, b)| a + self.gamma * (b - a))
            .collect()
    }

    /// Moves from the centroid `x0` towards `x` by the contraction factor:
    /// `x0 + rho * (x - x0)`.
    fn contract(&self, x0: &[f64], x: &[f64]) -> Vec<f64> {
        x0.iter()
            .zip(x.iter())
            .map(|(a, b)| a + self.rho * (b - a))
            .collect()
    }

    /// Pulls every vertex except the first one towards the first vertex
    /// (`x_i = x_0 + sigma * (x_i - x_0)`) and re-evaluates the cost of each moved
    /// vertex. The first vertex is the best one when the simplex is sorted.
    fn shrink(&self, state: &mut NelderMeadState) {
        let problem = state.problem;
        let sigma = self.sigma;
        if let Some((best, rest)) = state.param_vecs.split_first_mut() {
            for vertex in rest.iter_mut() {
                vertex.param = best
                    .param
                    .iter()
                    .zip(vertex.param.iter())
                    .map(|(b, x)| b + sigma * (x - b))
                    .collect();
                vertex.cost = (problem.cost_function)(&vertex.param);
            }
        }
    }
}
impl<'a> ArgminSolver<'a> for NelderMead<'a> {
    type Parameter = Vec<f64>;
    type CostValue = f64;
    type Hessian = Vec<f64>;
    type StartingPoints = Vec<Self::Parameter>;
    type ProblemDefinition = &'a ArgminProblem<'a, Self::Parameter, Self::CostValue, Self::Hessian>;

    /// Builds the initial simplex.
    ///
    /// Every starting point becomes one vertex and is evaluated with the problem's
    /// cost function; the vertices are then sorted by cost. Nelder-Mead normally
    /// uses `n + 1` starting points for an `n`-dimensional problem; `next_iter`
    /// needs at least two.
    fn init(
        &mut self,
        problem: Self::ProblemDefinition,
        param_vecs: &Self::StartingPoints,
    ) -> Result<()> {
        let params: Vec<NelderMeadParam> = param_vecs
            .iter()
            .map(|param| NelderMeadParam {
                param: param.to_vec(),
                cost: (problem.cost_function)(param),
            })
            .collect();
        let mut state = NelderMeadState::new(problem, params);
        self.sort_param_vecs(&mut state);
        self.state = Some(state);
        Ok(())
    }

    /// Performs one Nelder-Mead iteration.
    ///
    /// The worst vertex is reflected through the centroid of the other vertices.
    /// Depending on the cost of the reflected point, the worst vertex is replaced by
    /// the expanded, reflected or contracted point. If even the contracted point is
    /// not better than the worst vertex, the whole simplex is shrunk towards the
    /// best vertex.
    ///
    /// # Panics
    ///
    /// Panics if the solver has not been initialized or if the simplex has fewer
    /// than two vertices.
    fn next_iter(&mut self) -> Result<ArgminResult<Self::Parameter, Self::CostValue>> {
        let mut state = self.state.take().unwrap();
        self.sort_param_vecs(&mut state);

        // Vertices are sorted by ascending cost. The indices are derived from the
        // number of vertices (not the parameter dimension): `worst_idx - 1` is the
        // second-worst vertex, and a 1-D problem with two vertices works too.
        let worst_idx = state.param_vecs.len() - 1;
        let best_cost = state.param_vecs[0].cost;
        let second_worst_cost = state.param_vecs[worst_idx - 1].cost;
        let worst_cost = state.param_vecs[worst_idx].cost;

        let x0 = self.calculate_centroid(&state);
        let xr = self.reflect(&x0, &state.param_vecs[worst_idx].param);
        let xr_cost = (state.problem.cost_function)(&xr);

        if xr_cost < best_cost {
            // Expansion: the reflected point beats the best vertex; try going further.
            let xe = self.expand(&x0, &xr);
            let xe_cost = (state.problem.cost_function)(&xe);
            state.param_vecs[worst_idx] = if xe_cost < xr_cost {
                NelderMeadParam {
                    param: xe,
                    cost: xe_cost,
                }
            } else {
                NelderMeadParam {
                    param: xr,
                    cost: xr_cost,
                }
            };
        } else if xr_cost < second_worst_cost {
            // Reflection: best <= f(xr) < second worst.
            state.param_vecs[worst_idx] = NelderMeadParam {
                param: xr,
                cost: xr_cost,
            };
        } else {
            // Contraction: the reflected point is no better than the second-worst
            // vertex (this also covers a NaN cost).
            let xc = self.contract(&x0, &state.param_vecs[worst_idx].param);
            let xc_cost = (state.problem.cost_function)(&xc);
            if xc_cost < worst_cost {
                state.param_vecs[worst_idx] = NelderMeadParam {
                    param: xc,
                    cost: xc_cost,
                };
            } else {
                // Shrink: contraction did not improve the worst vertex.
                self.shrink(&mut state);
            }
        }

        state.iter += 1;
        self.sort_param_vecs(&mut state);
        let param = state.param_vecs[0].clone();
        let mut out = ArgminResult::new(param.param, param.cost, state.iter);
        self.state = Some(state);
        out.set_termination_reason(self.terminate());
        Ok(out)
    }

    make_terminate!(self,
        self.state.as_ref().unwrap().iter >= self.max_iters, TerminationReason::MaxItersReached;
        self.state.as_ref().unwrap().param_vecs[0].cost <= self.state.as_ref().unwrap().problem.target_cost, TerminationReason::TargetCostReached;
    );
    make_run!(
        Self::ProblemDefinition,
        Self::StartingPoints,
        Self::Parameter,
        Self::CostValue
    );
}
impl<'a> Default for NelderMead<'a> {
fn default() -> Self {
Self::new()
}
}