use crate::common::*;
use vecfx::*;
/// Result buffer filled in by an objective-function evaluation.
///
/// `fx` is what `Problem::value` reports and `gx` is what
/// `Problem::gradient` reports.
#[derive(Clone, Debug)]
pub struct Output {
    /// Function value at the evaluated point.
    pub fx: f64,
    /// Gradient components at the evaluated point.
    pub gx: Vec<f64>,
}

/// Borrowed coordinates handed to the objective function.
pub type Input<'a> = &'a [f64];

impl Output {
    /// Creates an output holding `n` gradient slots.
    ///
    /// Both the value and every gradient slot start out as NaN, so data
    /// that was never written by an evaluation is easy to spot.
    fn new(n: usize) -> Self {
        let gx = vec![f64::NAN; n];
        Self { fx: f64::NAN, gx }
    }
}
/// An objective function that can be evaluated at a given point.
///
/// `U` is the type of the extra, caller-defined data an evaluation hands
/// back on success.
pub trait EvaluateFunction<U> {
    /// Evaluates the function at `input`.
    ///
    /// Results are reported through `output`; the caller reads `fx` and
    /// `gx` from it after the call returns.
    ///
    /// # Errors
    ///
    /// Returns whatever error the implementation reports.
    fn evaluate(&mut self, input: Input<'_>, output: &mut Output) -> Result<U>;
}
/// Lets any `FnMut(Input, &mut Output) -> Result<U>` closure or function be
/// used wherever an [`EvaluateFunction`] is expected.
impl<T, U> EvaluateFunction<U> for T
where
    T: FnMut(Input<'_>, &mut Output) -> Result<U>,
{
    /// Calls the wrapped callable and passes its result through unchanged.
    fn evaluate(&mut self, input: Input<'_>, output: &mut Output) -> Result<U> {
        self(input, output)
    }
}
/// An objective function bundled with the current position and cached
/// evaluation results.
///
/// `'a` bounds the lifetime of the boxed evaluator; `U` is the type of the
/// extra data an evaluation returns.
pub struct Problem<'a, U> {
    /// Current position.
    x: Vec<f64>,
    /// The function being evaluated.
    f: Box<dyn EvaluateFunction<U> + 'a>,
    /// Cached output at `x`; `None` until evaluated, and again after the
    /// position has been moved.
    out: Option<Output>,
    /// Displacements whose length does not exceed this are not applied.
    epsilon: f64,
    /// Number of successful evaluations so far.
    neval: usize,
    /// Position of the most recent evaluation (the starting point before
    /// any evaluation has happened).
    x_prev: Option<Vec<f64>>,
    /// Output of the most recent evaluation.
    out_prev: Option<Output>,
    /// Data returned by the most recent successful evaluation.
    pub user_data: Option<U>,
}
impl<'a, U> Problem<'a, U> {
pub fn ncalls(&self) -> usize {
self.neval
}
pub fn value(&mut self) -> f64 {
if self.out.is_none() {
self.eval().expect("eval error");
}
self.out.as_ref().expect("no out").fx
}
pub fn value_prev(&self) -> f64 {
self.out_prev.as_ref().expect("not evaluated yet").fx
}
pub fn gradient_prev(&self) -> &[f64] {
&self.out_prev.as_ref().expect("not evaluated yet").gx
}
pub fn gradient(&mut self) -> &[f64] {
if self.out.is_none() {
self.eval().expect("eval error");
}
&self.out.as_ref().expect("no out").gx
}
pub fn position(&self) -> &[f64] {
&self.x
}
pub fn revert(&mut self) {
self.x = self.x_prev.clone().expect("not evaluated yet");
self.out = self.out_prev.clone();
}
}
impl<'a, U> Problem<'a, U> {
    /// Creates a problem starting at `x` for the function `f`.
    ///
    /// Nothing is evaluated here; the first call that needs a value or a
    /// gradient triggers the evaluation.
    pub fn new(x: Vec<f64>, f: impl EvaluateFunction<U> + 'a) -> Self {
        // The starting point doubles as the initial "previous" position.
        let x_prev = Some(x.clone());
        Self {
            x,
            f: Box::new(f),
            out: None,
            epsilon: 1e-8,
            neval: 0,
            x_prev,
            out_prev: None,
            user_data: None,
        }
    }

    /// Moves the position by `step * displ` and drops the cached output.
    ///
    /// The move is applied only when `step * displ.vec2norm()` is greater
    /// than the internal threshold; otherwise neither the position nor the
    /// cached output is touched.
    // NOTE(review): assuming `vec2norm` is a non-negative norm, a negative
    // `step` never passes this check and is ignored as well — confirm that
    // is intended.
    pub fn take_line_step(&mut self, displ: &[f64], step: f64) {
        let move_len = step * displ.vec2norm();
        if move_len > self.epsilon {
            self.x.vecadd(displ, step);
            // The cached output no longer matches the new position.
            self.out = None;
        }
    }

    /// Evaluates the function at the current position and records the
    /// result as both the current and the most recent output.
    ///
    /// # Errors
    ///
    /// Propagates the evaluator's error. In that case nothing is recorded
    /// and the evaluation counter is not incremented.
    fn eval(&mut self) -> Result<()> {
        // Reuse the cached buffer when there is one, otherwise start from a
        // NaN-filled output sized to the position.
        let mut out = match self.out.take() {
            Some(cached) => cached,
            None => Output::new(self.x.len()),
        };

        let user_data = self.f.evaluate(&self.x, &mut out)?;
        self.user_data = Some(user_data);

        // Remember this point as the last evaluated one.
        self.x_prev = Some(self.x.clone());
        self.out_prev = Some(out.clone());

        self.out = Some(out);
        self.neval += 1;
        Ok(())
    }
}
/// Snapshot of optimization progress, carrying caller-defined extra data.
#[derive(Clone, Debug)]
pub struct Progress<T> {
    /// Presumably the norm of the gradient — confirm against the code that
    /// fills this in.
    pub gnorm: f64,
    /// Presumably the function value — confirm against the code that fills
    /// this in.
    pub fx: f64,
    /// Presumably the number of function evaluations so far — confirm
    /// against the code that fills this in.
    pub ncalls: usize,
    /// Extra data of the caller's choosing.
    pub extra: T,
}