// gosh_fire/base.rs

1// [[file:../fire.note::*imports][imports:1]]
2use crate::common::*;
3
4use vecfx::*;
5// imports:1 ends here
6
7// [[file:../fire.note::*input/output][input/output:1]]
/// Evaluated function value and gradient
#[derive(Debug, Clone)]
pub struct Output {
    /// Function value at the evaluated position.
    pub fx: f64,
    /// Gradient at the evaluated position; when allocated via `Output::new(n)`
    /// it holds `n` entries.
    pub gx: Vec<f64>,
}

/// Borrowed position vector handed to the objective function.
pub type Input<'a> = &'a [f64];

impl Output {
    /// Allocate an output buffer for `n` variables.
    ///
    /// `fx` and every entry of `gx` start out as NaN, so any value the
    /// evaluation callback leaves unwritten is recognizable as NaN rather
    /// than reading as a plausible number.
    fn new(n: usize) -> Self {
        // `f64::NAN` (associated constant) replaces the legacy
        // `std::f64::NAN` module constant, which is planned for deprecation.
        Self {
            fx: f64::NAN,
            gx: vec![f64::NAN; n],
        }
    }
}
27// input/output:1 ends here
28
29// [[file:../fire.note::*trait][trait:1]]
30/// A trait for evaluating value and gradient of objective function
31pub trait EvaluateFunction<U> {
32    fn evaluate(&mut self, input: Input, output: &mut Output) -> Result<U>;
33}
34
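// [2020-11-27 Fri] Disabled: enabling the impl below triggers a conflicting implementation error (E0119) with the blanket impl for `FnMut(Input, &mut Output) -> Result<U>`, because the compiler cannot rule out a type implementing both closure signatures (they overlap at `U = ()`).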
36//
37// impl<T> EvaluateFunction<()> for T
38// where
39//     T: FnMut(&[f64], &mut [f64]) -> f64,
40// {
41//     fn evaluate(&mut self, input: Input, output: &mut Output) -> Result<()> {
42//         let fx = (self)(input, &mut output.gx);
43//         output.fx = fx;
44//         Ok(())
45//     }
46// }
47
48impl<T, U> EvaluateFunction<U> for T
49where
50    T: FnMut(Input, &mut Output) -> Result<U>,
51{
52    fn evaluate(&mut self, input: Input, output: &mut Output) -> Result<U> {
53        let user_data = (self)(input, output)?;
54        Ok(user_data)
55    }
56}
57// trait:1 ends here
58
59// [[file:../fire.note::*problem][problem:1]]
60pub struct Problem<'a, U> {
61    // input position
62    x: Vec<f64>,
63    // callback function for evaluation
64    f: Box<dyn EvaluateFunction<U> + 'a>,
65
66    // evaluated function value and gradient
67    out: Option<Output>,
68
69    epsilon: f64,
70    neval: usize,
71
72    // cache previous position and function value
73    x_prev: Option<Vec<f64>>,
74    out_prev: Option<Output>,
75
76    // returned data from user defined closure
77    pub user_data: Option<U>,
78}
79
80impl<'a, U> Problem<'a, U> {
81    /// The number of function calls made
82    pub fn ncalls(&self) -> usize {
83        self.neval
84    }
85
86    /// Return function value at current position.
87    ///
88    /// The function will be evaluated when necessary.
89    pub fn value(&mut self) -> f64 {
90        // found cached value?
91        if self.out.is_none() {
92            self.eval().expect("eval error");
93        }
94        self.out.as_ref().expect("no out").fx
95    }
96
97    /// Return function value at previous point
98    pub fn value_prev(&self) -> f64 {
99        self.out_prev.as_ref().expect("not evaluated yet").fx
100    }
101
102    /// Return a reference to function gradient at previous point
103    pub fn gradient_prev(&self) -> &[f64] {
104        &self.out_prev.as_ref().expect("not evaluated yet").gx
105    }
106
107    /// Return a reference to function gradient at current position.
108    ///
109    /// The function will be evaluated when necessary.
110    pub fn gradient(&mut self) -> &[f64] {
111        // found cached value?
112        if self.out.is_none() {
113            self.eval().expect("eval error");
114        }
115        &self.out.as_ref().expect("no out").gx
116    }
117
118    /// Return a reference to current position vector.
119    pub fn position(&self) -> &[f64] {
120        &self.x
121    }
122
123    /// Revert to previous point
124    pub fn revert(&mut self) {
125        self.x = self.x_prev.clone().expect("not evaluated yet");
126        self.out = self.out_prev.clone();
127    }
128}
129
130/// Core input/output methods
131impl<'a, U> Problem<'a, U> {
132    /// Construct a CachedProblem
133    ///
134    /// # Parameters
135    ///
136    /// * x: initial position
137    /// * f: a closure for function evaluation of value and gradient.
138    pub fn new(x: Vec<f64>, f: impl EvaluateFunction<U> + 'a) -> Self {
139        Self {
140            neval: 0,
141            epsilon: 1e-8,
142            out: None,
143            x_prev: x.clone().into(),
144            out_prev: None,
145            user_data: None,
146
147            f: Box::new(f),
148            x,
149        }
150    }
151
152    /// Update position `x` at a prescribed displacement and step size.
153    ///
154    /// x += step * displ
155    pub fn take_line_step(&mut self, displ: &[f64], step: f64) {
156        // position changed
157        if step * displ.vec2norm() > self.epsilon {
158            // update position vector with displacement
159            self.x.vecadd(displ, step);
160
161            // invalidate function output and update cached previous point
162            // FIXME: review required
163            self.out = None;
164            // self.out_prev = self.out.take();
165            // self.x_prev = self.out_prev.as_ref().map(|_| self.x.clone());
166        }
167    }
168
169    /// evaluate function value and gradient at current position
170    fn eval(&mut self) -> Result<()> {
171        // evaluate function and save returned value from user defined closure
172        let n = self.x.len();
173        let mut out = self.out.take().unwrap_or(Output::new(n));
174        self.user_data = self.f.evaluate(&self.x, &mut out)?.into();
175
176        // FIXME: review required
177        // update cached previous point
178        self.out_prev = out.clone().into();
179        // self.x_prev = self.out_prev.as_ref().map(|_| self.x.clone());
180        self.x_prev = self.x.clone().into();
181
182        // update function value and gradient
183        self.out = out.into();
184        self.neval += 1;
185
186        Ok(())
187    }
188}
189// problem:1 ends here
190
191// [[file:../fire.note::*progress][progress:1]]
/// Snapshot of the minimizer state after an iteration, suitable for
/// progress monitoring.
#[derive(Clone, Debug)]
pub struct Progress<T> {
    /// Norm of the gradient vector at the current position.
    pub gnorm: f64,

    /// Objective function value at the current position.
    pub fx: f64,

    /// Number of objective function evaluations made so far.
    pub ncalls: usize,

    /// Extra data handed back by the user defined closure during objective
    /// function evaluation.
    pub extra: T,
}
208// progress:1 ends here