gosh_fire/base.rs
1// [[file:../fire.note::*imports][imports:1]]
2use crate::common::*;
3
4use vecfx::*;
5// imports:1 ends here
6
7// [[file:../fire.note::*input/output][input/output:1]]
/// Evaluated function value and gradient
#[derive(Debug, Clone)]
pub struct Output {
    pub fx: f64,
    pub gx: Vec<f64>,
}

/// Borrowed position vector handed to the objective function.
pub type Input<'a> = &'a [f64];

impl Output {
    /// Create an `Output` for an `n`-dimensional problem, with the value
    /// and every gradient component initialized to NaN to mark them as
    /// not-yet-evaluated.
    fn new(n: usize) -> Self {
        // Use the `f64::NAN` associated constant; the `std::f64::NAN`
        // module constant it replaced is deprecated.
        Self {
            fx: f64::NAN,
            gx: vec![f64::NAN; n],
        }
    }
}
27// input/output:1 ends here
28
29// [[file:../fire.note::*trait][trait:1]]
/// A trait for evaluating value and gradient of objective function
pub trait EvaluateFunction<U> {
    /// Evaluate the objective function at position `input`, writing the
    /// function value and gradient into `output`. On success returns a
    /// value of the user-chosen type `U` (stored by `Problem` as
    /// `user_data`).
    fn evaluate(&mut self, input: Input, output: &mut Output) -> Result<U>;
}
34
35// [2020-11-27 Fri] Will trigger conflicting implementation error
36//
37// impl<T> EvaluateFunction<()> for T
38// where
39// T: FnMut(&[f64], &mut [f64]) -> f64,
40// {
41// fn evaluate(&mut self, input: Input, output: &mut Output) -> Result<()> {
42// let fx = (self)(input, &mut output.gx);
43// output.fx = fx;
44// Ok(())
45// }
46// }
47
48impl<T, U> EvaluateFunction<U> for T
49where
50 T: FnMut(Input, &mut Output) -> Result<U>,
51{
52 fn evaluate(&mut self, input: Input, output: &mut Output) -> Result<U> {
53 let user_data = (self)(input, output)?;
54 Ok(user_data)
55 }
56}
57// trait:1 ends here
58
59// [[file:../fire.note::*problem][problem:1]]
/// An objective-function problem with cached evaluation results.
pub struct Problem<'a, U> {
    // input position (current point)
    x: Vec<f64>,
    // callback function for evaluation of value and gradient
    f: Box<dyn EvaluateFunction<U> + 'a>,

    // evaluated function value and gradient at `x`; `None` means the
    // cache is stale and the next access will trigger re-evaluation
    out: Option<Output>,

    // threshold below which a line step is treated as a no-op
    epsilon: f64,
    // number of function evaluations performed so far
    neval: usize,

    // cache previous position and function value
    x_prev: Option<Vec<f64>>,
    out_prev: Option<Output>,

    // returned data from user defined closure
    pub user_data: Option<U>,
}
79
80impl<'a, U> Problem<'a, U> {
81 /// The number of function calls made
82 pub fn ncalls(&self) -> usize {
83 self.neval
84 }
85
86 /// Return function value at current position.
87 ///
88 /// The function will be evaluated when necessary.
89 pub fn value(&mut self) -> f64 {
90 // found cached value?
91 if self.out.is_none() {
92 self.eval().expect("eval error");
93 }
94 self.out.as_ref().expect("no out").fx
95 }
96
97 /// Return function value at previous point
98 pub fn value_prev(&self) -> f64 {
99 self.out_prev.as_ref().expect("not evaluated yet").fx
100 }
101
102 /// Return a reference to function gradient at previous point
103 pub fn gradient_prev(&self) -> &[f64] {
104 &self.out_prev.as_ref().expect("not evaluated yet").gx
105 }
106
107 /// Return a reference to function gradient at current position.
108 ///
109 /// The function will be evaluated when necessary.
110 pub fn gradient(&mut self) -> &[f64] {
111 // found cached value?
112 if self.out.is_none() {
113 self.eval().expect("eval error");
114 }
115 &self.out.as_ref().expect("no out").gx
116 }
117
118 /// Return a reference to current position vector.
119 pub fn position(&self) -> &[f64] {
120 &self.x
121 }
122
123 /// Revert to previous point
124 pub fn revert(&mut self) {
125 self.x = self.x_prev.clone().expect("not evaluated yet");
126 self.out = self.out_prev.clone();
127 }
128}
129
130/// Core input/output methods
131impl<'a, U> Problem<'a, U> {
132 /// Construct a CachedProblem
133 ///
134 /// # Parameters
135 ///
136 /// * x: initial position
137 /// * f: a closure for function evaluation of value and gradient.
138 pub fn new(x: Vec<f64>, f: impl EvaluateFunction<U> + 'a) -> Self {
139 Self {
140 neval: 0,
141 epsilon: 1e-8,
142 out: None,
143 x_prev: x.clone().into(),
144 out_prev: None,
145 user_data: None,
146
147 f: Box::new(f),
148 x,
149 }
150 }
151
152 /// Update position `x` at a prescribed displacement and step size.
153 ///
154 /// x += step * displ
155 pub fn take_line_step(&mut self, displ: &[f64], step: f64) {
156 // position changed
157 if step * displ.vec2norm() > self.epsilon {
158 // update position vector with displacement
159 self.x.vecadd(displ, step);
160
161 // invalidate function output and update cached previous point
162 // FIXME: review required
163 self.out = None;
164 // self.out_prev = self.out.take();
165 // self.x_prev = self.out_prev.as_ref().map(|_| self.x.clone());
166 }
167 }
168
169 /// evaluate function value and gradient at current position
170 fn eval(&mut self) -> Result<()> {
171 // evaluate function and save returned value from user defined closure
172 let n = self.x.len();
173 let mut out = self.out.take().unwrap_or(Output::new(n));
174 self.user_data = self.f.evaluate(&self.x, &mut out)?.into();
175
176 // FIXME: review required
177 // update cached previous point
178 self.out_prev = out.clone().into();
179 // self.x_prev = self.out_prev.as_ref().map(|_| self.x.clone());
180 self.x_prev = self.x.clone().into();
181
182 // update function value and gradient
183 self.out = out.into();
184 self.neval += 1;
185
186 Ok(())
187 }
188}
189// problem:1 ends here
190
191// [[file:../fire.note::*progress][progress:1]]
/// Progress data produced in each minimization iterations, useful for progress monitor.
#[derive(Debug, Clone)]
pub struct Progress<T> {
    /// Current gradient vector norm.
    pub gnorm: f64,

    /// Current value of the objective function.
    pub fx: f64,

    /// The number of function calls made so far.
    pub ncalls: usize,

    /// The extra data returned from user defined closure for objective function
    /// evaluation.
    pub extra: T,
}
208// progress:1 ends here