use crate::core::{ArgminFloat, Error, Gradient, IterState, Problem, Solver, KV};
use argmin_math::ArgminScaledSub;
#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
/// Landweber iteration solver.
///
/// Starting from an initial parameter vector `x_0`, every iteration performs the update
///
/// `x_{k+1} = x_k - omega * ∇f(x_k)`
///
/// where `∇f` is the gradient supplied by the problem and `omega` is a fixed step length.
/// No validation is performed on `omega`; choosing a suitable value is the caller's
/// responsibility.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct Landweber<F> {
    /// Step length applied to the gradient in every update.
    omega: F,
}
impl<F> Landweber<F> {
pub fn new(omega: F) -> Self {
Landweber { omega }
}
}
/// Landweber iteration as an argmin [`Solver`].
///
/// Each call to `next_iter` replaces the current parameter vector `x_k` with
/// `x_k - omega * grad`, where `grad` is the problem's gradient at `x_k`.
/// The update is computed via [`ArgminScaledSub::scaled_sub`].
impl<O, F, P, G> Solver<O, IterState<P, G, (), (), (), F>> for Landweber<F>
where
    O: Gradient<Param = P, Gradient = G>,
    P: Clone + ArgminScaledSub<G, F, P>,
    F: ArgminFloat,
{
    fn name(&self) -> &str {
        "Landweber"
    }

    /// Performs a single Landweber update on the parameter vector held in `state`.
    ///
    /// Returns the updated state and no key-value log entries.
    ///
    /// # Errors
    ///
    /// Returns a `NotInitialized` error if `state` contains no parameter vector.
    /// Any error returned by the problem's gradient evaluation is propagated.
    fn next_iter(
        &mut self,
        problem: &mut Problem<O>,
        mut state: IterState<P, G, (), (), (), F>,
    ) -> Result<(IterState<P, G, (), (), (), F>, Option<KV>), Error> {
        // Move the parameter out of the state; the updated one is stored back below.
        let param = state.take_param().ok_or_else(argmin_error_closure!(
            NotInitialized,
            concat!(
                "`Landweber` requires an initial parameter vector. ",
                "Please provide an initial guess via `Executor`s `configure` method."
            )
        ))?;
        // Fixed: the argument was corrupted to `¶m` (mis-encoded `&param`).
        let grad = problem.gradient(&param)?;
        // x_{k+1} = x_k - omega * grad
        let new_param = param.scaled_sub(&self.omega, &grad);
        Ok((state.param(new_param), None))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{test_utils::TestProblem, ArgminError, State};
    use approx::assert_relative_eq;

    test_trait_impl!(landweber, Landweber<f64>);

    /// The constructor must store the step length bit-for-bit.
    #[test]
    fn test_new() {
        let step: f64 = 0.5;
        let solver = Landweber::new(step);
        assert_eq!(solver.omega.to_ne_bytes(), step.to_ne_bytes());
    }

    /// Running an iteration without an initial parameter must yield `NotInitialized`.
    #[test]
    fn test_next_iter_param_not_initialized() {
        let mut solver = Landweber::new(0.5f64);
        let mut problem = Problem::new(TestProblem::new());
        let outcome = solver.next_iter(&mut problem, IterState::new());
        assert_error!(
            outcome,
            ArgminError,
            concat!(
                "Not initialized: \"`Landweber` requires an initial parameter vector. ",
                "Please provide an initial guess via `Executor`s `configure` method.\""
            )
        );
    }

    /// One iteration from [2, 4] with omega = 0.5 must land on [1, 2].
    #[test]
    fn test_next_iter() {
        let mut solver = Landweber::new(0.5f64);
        let mut problem = Problem::new(TestProblem::new());
        let initial = IterState::new().param(vec![2.0, 4.0]);

        let (next_state, kv) = solver.next_iter(&mut problem, initial).unwrap();
        assert!(kv.is_none());

        let updated = next_state.get_param().unwrap();
        assert_relative_eq!(updated[0], 1.0, epsilon = f64::EPSILON);
        assert_relative_eq!(updated[1], 2.0, epsilon = f64::EPSILON);
    }
}