use std;
use ndarray::{Array1, Array2};
use errors::*;
use prelude::*;
use operator::ArgminOperator;
use result::ArgminResult;
use termination::TerminationReason;
/// Landweber iteration solver.
///
/// Each step of `next_iter` updates the parameter vector as
/// `x <- x - omega * A^T (A x - y)`, using the `apply` / `apply_transpose`
/// methods of the attached problem operator.
pub struct Landweber<'a> {
    /// Per-iteration state; `None` until `init` has been called.
    state: Option<LandweberState<'a>>,
    /// Factor that scales the back-projected residual in each update step.
    omega: f64,
    /// Iteration count at which the solver stops with `MaxItersReached`.
    max_iters: u64,
}
/// Mutable bookkeeping carried between Landweber iterations.
struct LandweberState<'a> {
    /// Number of completed iterations.
    iter: u64,
    /// Latest iterate of the parameter vector.
    param: Array1<f64>,
    /// Euclidean norm of the most recently computed residual; NaN before the
    /// first iteration has run.
    norm: f64,
    /// Problem definition: provides `apply`, `apply_transpose`, the data `y`
    /// and the `target_cost` used for termination.
    operator: &'a ArgminOperator<'a>,
}
impl<'a> LandweberState<'a> {
    /// Builds the initial state for `operator` starting at `param`.
    ///
    /// No iterations have been performed yet, so the counter is zero and the
    /// residual norm is NaN (no residual has been computed).
    pub fn new(operator: &'a ArgminOperator<'a>, param: Array1<f64>) -> Self {
        let iter = 0;
        let norm = std::f64::NAN;
        LandweberState {
            iter,
            norm,
            param,
            operator,
        }
    }
}
impl<'a> Landweber<'a> {
    /// Creates a Landweber solver that uses `omega` to scale each update step.
    ///
    /// The iteration limit starts at `u64::MAX`, and no problem is attached
    /// until `init` is called.
    pub fn new(omega: f64) -> Self {
        let max_iters = std::u64::MAX;
        Landweber {
            state: None,
            max_iters,
            omega,
        }
    }

    /// Sets the maximum number of iterations; returns `self` for chaining.
    pub fn max_iters(&mut self, max_iters: u64) -> &mut Self {
        self.max_iters = max_iters;
        self
    }
}
impl<'a> ArgminSolver<'a> for Landweber<'a> {
    type Parameter = Array1<f64>;
    type CostValue = f64;
    type Hessian = Array2<f64>;
    type StartingPoints = Self::Parameter;
    type ProblemDefinition = &'a ArgminOperator<'a>;

    /// Attaches the problem `operator` and (re)starts the solver from
    /// `init_param`, with the iteration counter at zero and a NaN residual norm.
    fn init(
        &mut self,
        operator: Self::ProblemDefinition,
        init_param: &Self::StartingPoints,
    ) -> Result<()> {
        self.state = Some(LandweberState::new(operator, init_param.clone()));
        Ok(())
    }

    /// Performs one Landweber step: `x <- x - omega * A^T (A x - y)`.
    ///
    /// The cost stored in the returned result is the Euclidean norm of the
    /// residual `A x - y` at the iterate *before* this step, because that is
    /// the residual that drives the update. It therefore lags the returned
    /// parameter by one iteration. The target-cost termination check below
    /// uses the same value.
    ///
    /// # Panics
    ///
    /// Panics if `init` has not been called first.
    fn next_iter(&mut self) -> Result<ArgminResult<Self::Parameter, Self::CostValue>> {
        let mut state = self
            .state
            .take()
            .expect("Landweber: `init` must be called before `next_iter`");
        // Residual of the current iterate. Borrowing the parameter directly
        // avoids copying the whole vector on every iteration.
        let diff = state.operator.apply(&state.param) - state.operator.y;
        state.param = state.param - self.omega * state.operator.apply_transpose(&diff);
        state.iter += 1;
        state.norm = diff.iter().map(|a| a * a).sum::<f64>().sqrt();
        let mut out = ArgminResult::new(state.param.clone(), state.norm, state.iter);
        // Restore the state before `terminate`, which reads it through `self.state`.
        self.state = Some(state);
        out.set_termination_reason(self.terminate());
        Ok(out)
    }

    // Stop when the iteration limit is reached, or when the last residual norm
    // falls to the operator's target cost.
    make_terminate!(self,
        self.state.as_ref().unwrap().iter >= self.max_iters, TerminationReason::MaxItersReached;
        self.state.as_ref().unwrap().norm <= self.state.as_ref().unwrap().operator.target_cost, TerminationReason::TargetCostReached;
    );

    make_run!(
        Self::ProblemDefinition,
        Self::StartingPoints,
        Self::Parameter,
        Self::CostValue
    );
}
impl<'a> Default for Landweber<'a> {
fn default() -> Self {
Self::new(1.0)
}
}