bacon_sci_1/optimize/
mod.rs

1use crate::polynomial::Polynomial;
2use nalgebra::{
3    allocator::Allocator, ComplexField, DMatrix, DVector, DefaultAllocator, DimName, RealField,
4    VectorN,
5};
6use num_traits::{FromPrimitive, One, Zero};
7
8/// Linear least-squares regression
9pub fn linear_fit<N: ComplexField>(xs: &[N], ys: &[N]) -> Result<Polynomial<N>, String> {
10    if xs.len() != ys.len() {
11        return Err("linear_fit: xs length does not match ys length".to_owned());
12    }
13
14    let mut sum_x = N::zero();
15    let mut sum_y = N::zero();
16    let mut sum_x_sq = N::zero();
17    let mut sum_y_sq = N::zero();
18    let mut sum_xy = N::zero();
19
20    for (ind, x) in xs.iter().enumerate() {
21        sum_x += *x;
22        sum_y += ys[ind];
23        sum_x_sq += x.powi(2);
24        sum_y_sq += ys[ind].powi(2);
25        sum_xy += ys[ind] * *x;
26    }
27
28    let m = N::from_usize(xs.len()).unwrap();
29    let denom = m * sum_x_sq - sum_x.powi(2);
30    let a = (m * sum_xy - sum_x * sum_y) / denom;
31    let b = (sum_x_sq * sum_y - sum_xy * sum_x) / denom;
32
33    Ok(polynomial![a, b])
34}
35
36// Compute the J matrix for LM using finite differences, 3 point formula
37fn jac_finite_differences<N: ComplexField, V: DimName, F: FnMut(N, &VectorN<N, V>) -> N>(
38    mut f: F,
39    xs: &[N],
40    params: &mut VectorN<N, V>,
41    mat: &mut DMatrix<N>,
42    h: N::RealField,
43) where
44    DefaultAllocator: Allocator<N, V>,
45{
46    let h = N::from_real(h);
47    let denom = N::one() / (N::from_i32(2).unwrap() * h);
48    for row in 0..mat.column(0).len() {
49        for col in 0..mat.row(0).len() {
50            params[col] += h;
51            let above = f(xs[row], &params);
52            params[col] -= h;
53            params[col] -= h;
54            let below = f(xs[row], &params);
55            mat[(row, col)] = denom * (above + below);
56            params[col] += h;
57        }
58    }
59}
60
61// Compute the J matrix for LM using analytic formula
62fn jac_analytic<N: ComplexField, V: DimName, F: FnMut(N, &VectorN<N, V>) -> VectorN<N, V>>(
63    mut jac: F,
64    xs: &[N],
65    params: &mut VectorN<N, V>,
66    mat: &mut DMatrix<N>,
67) where
68    DefaultAllocator: Allocator<N, V>,
69{
70    for row in 0..mat.column(0).len() {
71        let deriv = jac(xs[row], &params);
72        for col in 0..mat.row(0).len() {
73            mat[(row, col)] = deriv[col];
74        }
75    }
76}
77
/// Settings for the Levenberg-Marquardt curve fitting routines
/// (`curve_fit` and `curve_fit_jac`).
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub struct CurveFitParams<N: ComplexField> {
    /// Starting damping factor. The diagonal of `J^T J` is scaled by
    /// `1 + damping` before each linear solve.
    pub damping: N::RealField,
    /// Iteration stops once the sum of squared residuals changes by no more
    /// than this amount between two consecutive iterations.
    pub tolerance: N::RealField,
    /// Step size for the finite-difference jacobian (read by `curve_fit`
    /// only).
    pub h: N::RealField,
    /// Factor by which the damping is multiplied while searching for the
    /// initial damping, and divided when the less-damped step gives a
    /// smaller residual.
    pub damping_mult: N::RealField,
}
86
87impl<N: ComplexField> Default for CurveFitParams<N> {
88    fn default() -> Self {
89        CurveFitParams {
90            damping: N::RealField::from_f64(2.0).unwrap(),
91            tolerance: N::RealField::from_f64(1e-5).unwrap(),
92            h: N::RealField::from_f64(0.1).unwrap(),
93            damping_mult: N::RealField::from_f64(1.5).unwrap(),
94        }
95    }
96}
97
98/// Fit a curve using the Levenberg-Marquardt algorithm.
99///
100/// Uses finite differences of h to calculate the jacobian. If jacobian
101/// can be found analytically, then use curve_fit_jac. Keeps iterating until
102/// the differences between the sum of the square residuals of two iterations
103/// is under tol.
104pub fn curve_fit<N: ComplexField, V: DimName, F: FnMut(N, &VectorN<N, V>) -> N>(
105    mut f: F,
106    xs: &[N],
107    ys: &[N],
108    initial: &[N],
109    params: &CurveFitParams<N>,
110) -> Result<VectorN<N, V>, String>
111where
112    DefaultAllocator: Allocator<N, V>,
113{
114    let tol = params.tolerance;
115    let mut damping = params.damping;
116    let h = params.h;
117    let damping_mult = params.damping_mult;
118
119    if !tol.is_sign_positive() {
120        return Err("curve_fit: tol must be positive".to_owned());
121    }
122
123    if !h.is_sign_positive() {
124        return Err("curve_fit: h must be positive".to_owned());
125    }
126
127    if !damping.is_sign_positive() {
128        return Err("curve_fit: damping must be positive".to_owned());
129    }
130
131    if xs.len() != ys.len() {
132        return Err("curve_fit: xs length must match ys length".to_owned());
133    }
134
135    let mut params = VectorN::<N, V>::from_column_slice(initial);
136    let ys = DVector::<N>::from_column_slice(ys);
137    let mut jac: DMatrix<N> = DMatrix::identity(xs.len(), params.len());
138    jac_finite_differences(&mut f, xs, &mut params, &mut jac, h);
139    let mut jac_transpose = jac.transpose();
140
141    // Get the initial sum of square residuals
142    let mut resid = Vec::with_capacity(xs.len());
143    for (ind, &x) in xs.iter().enumerate() {
144        resid.push(ys[ind] - f(x, &params));
145    }
146    let sum_sq_initial: N::RealField = resid
147        .iter()
148        .map(|&r| r.modulus_squared())
149        .fold(N::RealField::zero(), |acc, r| acc + r);
150
151    // Get initial factor
152    let mut sum_sq = sum_sq_initial + N::RealField::one();
153    let mut damping_tmp = damping / damping_mult;
154    let mut j = 0;
155    let mut evaluation: DVector<N> =
156        DVector::from_iterator(xs.len(), xs.iter().map(|&x| f(x, &params)));
157    while sum_sq > sum_sq_initial && j < 1000 {
158        damping_tmp *= damping_mult;
159        let diff = &ys - &evaluation;
160        let mut b = &jac_transpose * &diff;
161        // Always square
162        let mut multiplied = &jac_transpose * &jac;
163        for i in 0..multiplied.row(0).len() {
164            multiplied[(i, i)] *= N::one() + N::from_real(damping_tmp);
165        }
166        let lu = multiplied.clone().lu();
167        let solved = lu.solve_mut(&mut b);
168        if !solved {
169            return Err("curve_fit: unable to solve linear equation".to_owned());
170        }
171        params += &b;
172        evaluation = DVector::from_iterator(xs.len(), xs.iter().map(|&x| f(x, &params)));
173        let diff = &ys - &evaluation;
174        sum_sq = diff
175            .iter()
176            .map(|&r| r.modulus_squared())
177            .fold(N::RealField::zero(), |acc, r| acc + r);
178        j += 1;
179        jac_finite_differences(&mut f, xs, &mut params, &mut jac, h);
180        jac_transpose = jac.transpose();
181    }
182    if j != 1000 {
183        damping = damping_tmp;
184    }
185
186    let mut last_sum_sq = sum_sq;
187    sum_sq += N::RealField::from_i32(2).unwrap() * tol;
188    while (last_sum_sq - sum_sq).abs() > tol {
189        last_sum_sq = sum_sq;
190        // Get right side of iteration equation
191        let diff = &ys - &evaluation;
192        let mut b = &jac_transpose * &diff;
193        let mut b_div = b.clone();
194        // Get left side of equation
195        let mut multiplied = &jac_transpose * &jac;
196        let mut multiplied_div = multiplied.clone();
197        for i in 0..multiplied.row(0).len() {
198            multiplied[(i, i)] *= N::one() + N::from_real(damping);
199        }
200        // Solve equation with LU w/ partial pivoting first
201        let lu = multiplied.clone().lu();
202        let solved = lu.solve_mut(&mut b);
203        if !solved {
204            return Err("curve_fit: unable to solve linear equation".to_owned());
205        }
206        let new_params = &params + &b;
207
208        // Now solve for damping / damping_mult
209        for i in 0..multiplied_div.row(0).len() {
210            multiplied_div[(i, i)] *= N::one() + N::from_real(damping / damping_mult);
211        }
212        let lu = multiplied_div.clone().lu();
213        let solved = lu.solve_mut(&mut b_div);
214        if !solved {
215            let lu = multiplied_div.clone().full_piv_lu();
216            let solved = lu.solve_mut(&mut b_div);
217            if !solved {
218                let qr = multiplied_div.qr();
219                let solved = qr.solve_mut(&mut b_div);
220                if !solved {
221                    return Err("curve_fit: unable to solve linear equation".to_owned());
222                }
223            }
224        }
225        let new_params_div = &params + &b_div;
226
227        // get residuals for each of the new solutions
228        evaluation = DVector::from_iterator(xs.len(), xs.iter().map(|&x| f(x, &new_params)));
229        let evaluation_div =
230            DVector::from_iterator(xs.len(), xs.iter().map(|&x| f(x, &new_params_div)));
231        let diff = &ys - &evaluation;
232        let diff_div = &ys - &evaluation_div;
233
234        let resid: N::RealField = diff
235            .iter()
236            .map(|&r| r.modulus_squared())
237            .fold(N::RealField::zero(), |acc, r| acc + r);
238        let resid_div: N::RealField = diff_div
239            .iter()
240            .map(|&r| r.modulus_squared())
241            .fold(N::RealField::zero(), |acc, r| acc + r);
242
243        if resid_div < resid {
244            damping /= damping_mult;
245            evaluation = evaluation_div;
246            params = new_params_div;
247            sum_sq = resid_div;
248        } else {
249            params = new_params;
250            sum_sq = resid;
251        }
252
253        jac_finite_differences(&mut f, xs, &mut params, &mut jac, h);
254        jac_transpose = jac.transpose();
255    }
256
257    Ok(params)
258}
259
260/// Fit a curve using the Levenberg-Marquardt algorithm.
261///
262/// Uses an analytic jacobian.Keeps iterating until
263/// the differences between the sum of the square residuals of two iterations
264/// is under tol. Jacobian should be a function that returns a column vector
265/// where jacobian[i] is the partial derivative of f with respect to param[i].
266pub fn curve_fit_jac<
267    N: ComplexField,
268    V: DimName,
269    F: FnMut(N, &VectorN<N, V>) -> N,
270    G: FnMut(N, &VectorN<N, V>) -> VectorN<N, V>,
271>(
272    mut f: F,
273    xs: &[N],
274    ys: &[N],
275    initial: &[N],
276    mut jacobian: G,
277    params: &CurveFitParams<N>,
278) -> Result<VectorN<N, V>, String>
279where
280    DefaultAllocator: Allocator<N, V>,
281{
282    let tol = params.tolerance;
283    let mut damping = params.damping;
284    let damping_mult = params.damping_mult;
285
286    if !tol.is_sign_positive() {
287        return Err("curve_fit_jac: tol must be positive".to_owned());
288    }
289
290    if !damping.is_sign_positive() {
291        return Err("curve_fit_jac: damping must be positive".to_owned());
292    }
293
294    if xs.len() != ys.len() {
295        return Err("curve_fit_jac: xs length must match ys length".to_owned());
296    }
297
298    let mut params = VectorN::<N, V>::from_column_slice(initial);
299    let ys = DVector::<N>::from_column_slice(ys);
300    let mut jac: DMatrix<N> = DMatrix::identity(xs.len(), params.len());
301    jac_analytic(&mut jacobian, xs, &mut params, &mut jac);
302    let mut jac_transpose = jac.transpose();
303
304    // Get the initial sum of square residuals
305    let mut resid = Vec::with_capacity(xs.len());
306    for (ind, &x) in xs.iter().enumerate() {
307        resid.push(ys[ind] - f(x, &params));
308    }
309    let sum_sq_initial: N::RealField = resid
310        .iter()
311        .map(|&r| r.modulus_squared())
312        .fold(N::RealField::zero(), |acc, r| acc + r);
313
314    // Get initial factor
315    let mut sum_sq = sum_sq_initial + N::RealField::one();
316    let mut damping_tmp = damping / damping_mult;
317    let mut j = 0;
318    let mut evaluation: DVector<N> =
319        DVector::from_iterator(xs.len(), xs.iter().map(|&x| f(x, &params)));
320    while sum_sq > sum_sq_initial && j < 1000 {
321        damping_tmp *= damping_mult;
322        let diff = &ys - &evaluation;
323        let mut b = &jac_transpose * &diff;
324        // Always square
325        let mut multiplied = &jac_transpose * &jac;
326        for i in 0..multiplied.row(0).len() {
327            multiplied[(i, i)] *= N::one() + N::from_real(damping_tmp);
328        }
329        let lu = multiplied.clone().lu();
330        let solved = lu.solve_mut(&mut b);
331        if !solved {
332            let lu = multiplied.clone().full_piv_lu();
333            let solved = lu.solve_mut(&mut b);
334            if !solved {
335                let qr = multiplied.qr();
336                let solved = qr.solve_mut(&mut b);
337                if !solved {
338                    return Err("curve_fit_jac: unable to solve linear equation".to_owned());
339                }
340            }
341        }
342        params += &b;
343        evaluation = DVector::from_iterator(xs.len(), xs.iter().map(|&x| f(x, &params)));
344        let diff = &ys - &evaluation;
345        sum_sq = diff
346            .iter()
347            .map(|&r| r.modulus_squared())
348            .fold(N::RealField::zero(), |acc, r| acc + r);
349        j += 1;
350        jac_analytic(&mut jacobian, xs, &mut params, &mut jac);
351        jac_transpose = jac.transpose();
352    }
353    if j != 1000 {
354        damping = damping_tmp;
355    }
356
357    let mut last_sum_sq = sum_sq;
358    sum_sq += N::RealField::from_i32(2).unwrap() * tol;
359    while (last_sum_sq - sum_sq).abs() > tol {
360        last_sum_sq = sum_sq;
361        // Get right side of iteration equation
362        let diff = &ys - &evaluation;
363        let mut b = &jac_transpose * &diff;
364        let mut b_div = b.clone();
365        // Get left side of equation
366        let mut multiplied = &jac_transpose * &jac;
367        let mut multiplied_div = multiplied.clone();
368        for i in 0..multiplied.row(0).len() {
369            multiplied[(i, i)] *= N::one() + N::from_real(damping);
370        }
371        // Solve equation with LU w/ partial pivoting first
372        // Then try LU w/ Full pivoting, QR
373        let lu = multiplied.clone().lu();
374        let solved = lu.solve_mut(&mut b);
375        if !solved {
376            let lu = multiplied.clone().full_piv_lu();
377            let solved = lu.solve_mut(&mut b);
378            if !solved {
379                let qr = multiplied.qr();
380                let solved = qr.solve_mut(&mut b);
381                if !solved {
382                    return Err("curve_fit_jac: unable to solve linear equation".to_owned());
383                }
384            }
385        }
386        let new_params = &params + &b;
387
388        // Now solve for damping / damping_mult
389        for i in 0..multiplied_div.row(0).len() {
390            multiplied_div[(i, i)] *= N::one() + N::from_real(damping / damping_mult);
391        }
392        let lu = multiplied_div.clone().lu();
393        let solved = lu.solve_mut(&mut b_div);
394        if !solved {
395            let lu = multiplied_div.clone().full_piv_lu();
396            let solved = lu.solve_mut(&mut b_div);
397            if !solved {
398                let qr = multiplied_div.qr();
399                let solved = qr.solve_mut(&mut b_div);
400                if !solved {
401                    return Err("curve_fit_jac: unable to solve linear equation".to_owned());
402                }
403            }
404        }
405        let new_params_div = &params + &b_div;
406
407        // get residuals for each of the new solutions
408        evaluation = DVector::from_iterator(xs.len(), xs.iter().map(|&x| f(x, &new_params)));
409        let evaluation_div =
410            DVector::from_iterator(xs.len(), xs.iter().map(|&x| f(x, &new_params_div)));
411        let diff = &ys - &evaluation;
412        let diff_div = &ys - &evaluation_div;
413
414        let resid: N::RealField = diff
415            .iter()
416            .map(|&r| r.modulus_squared())
417            .fold(N::RealField::zero(), |acc, r| acc + r);
418        let resid_div: N::RealField = diff_div
419            .iter()
420            .map(|&r| r.modulus_squared())
421            .fold(N::RealField::zero(), |acc, r| acc + r);
422
423        if resid_div < resid {
424            damping /= damping_mult;
425            evaluation = evaluation_div;
426            params = new_params_div;
427            sum_sq = resid_div;
428        } else {
429            params = new_params;
430            sum_sq = resid;
431        }
432
433        jac_analytic(&mut jacobian, xs, &mut params, &mut jac);
434        jac_transpose = jac.transpose();
435    }
436
437    Ok(params)
438}