// bacon_sci/optimize/mod.rs

1use crate::polynomial::Polynomial;
2use nalgebra::{ComplexField, DMatrix, DVector, RealField, SVector};
3use num_traits::{FromPrimitive, One, Zero};
4
5/// Linear least-squares regression
6///
7/// # Errors
8/// Returns an error if the linear fit fails. (`xs.len() != ys.len()`)
9///
10/// # Panics
11/// Panics if a `usize` can not be transformed into the generic type.
12pub fn linear_fit<N>(xs: &[N], ys: &[N]) -> Result<Polynomial<N>, String>
13where
14    N: ComplexField + FromPrimitive + Copy,
15    <N as ComplexField>::RealField: FromPrimitive + Copy,
16{
17    if xs.len() != ys.len() {
18        return Err("linear_fit: xs length does not match ys length".to_owned());
19    }
20
21    let mut sum_x = N::zero();
22    let mut sum_y = N::zero();
23    let mut sum_x_sq = N::zero();
24    let mut sum_y_sq = N::zero();
25    let mut sum_xy = N::zero();
26
27    for (ind, x) in xs.iter().enumerate() {
28        sum_x += *x;
29        sum_y += ys[ind];
30        sum_x_sq += x.powi(2);
31        sum_y_sq += ys[ind].powi(2);
32        sum_xy += ys[ind] * *x;
33    }
34
35    let m = N::from_usize(xs.len()).unwrap();
36    let denom = m * sum_x_sq - sum_x.powi(2);
37    let a = (m * sum_xy - sum_x * sum_y) / denom;
38    let b = (sum_x_sq * sum_y - sum_xy * sum_x) / denom;
39
40    Ok(polynomial![a, b])
41}
42
43// Compute the J matrix for LM using finite differences, 3 point formula
44fn jac_finite_differences<N, F, const V: usize>(
45    mut f: F,
46    xs: &[N],
47    params: &mut SVector<N, V>,
48    mat: &mut DMatrix<N>,
49    h: N::RealField,
50) where
51    N: ComplexField + FromPrimitive + Copy,
52    F: FnMut(N, &SVector<N, V>) -> N,
53    <N as ComplexField>::RealField: FromPrimitive + Copy,
54{
55    let h = N::from_real(h);
56    let denom = N::one() / (N::from_i32(2).unwrap() * h);
57    for row in 0..mat.column(0).len() {
58        for col in 0..mat.row(0).len() {
59            params[col] += h;
60            let above = f(xs[row], params);
61            params[col] -= h;
62            params[col] -= h;
63            let below = f(xs[row], params);
64            mat[(row, col)] = denom * (above + below);
65            params[col] += h;
66        }
67    }
68}
69
70// Compute the J matrix for LM using analytic formula
71fn jac_analytic<N, F, const V: usize>(
72    mut jac: F,
73    xs: &[N],
74    params: &mut SVector<N, V>,
75    mat: &mut DMatrix<N>,
76) where
77    N: ComplexField + Copy,
78    F: FnMut(N, &SVector<N, V>) -> SVector<N, V>,
79{
80    for row in 0..mat.column(0).len() {
81        let deriv = jac(xs[row], params);
82        for col in 0..mat.row(0).len() {
83            mat[(row, col)] = deriv[col];
84        }
85    }
86}
87
88#[derive(Debug, Clone)]
89pub struct CurveFitParams<N: ComplexField> {
90    pub damping: N::RealField,
91    pub tolerance: N::RealField,
92    pub h: N::RealField,
93    pub damping_mult: N::RealField,
94}
95
96impl<N: ComplexField + FromPrimitive> Default for CurveFitParams<N> {
97    fn default() -> Self {
98        CurveFitParams {
99            damping: N::from_f64(2.0).unwrap().real(),
100            tolerance: N::from_f64(1e-5).unwrap().real(),
101            h: N::from_f64(0.1).unwrap().real(),
102            damping_mult: N::from_f64(1.5).unwrap().real(),
103        }
104    }
105}
106
107#[allow(clippy::too_many_arguments)]
108fn initial_residuals<N, F, const V: usize>(
109    xs: &[N],
110    ys: &DVector<N>,
111    damping: &mut N::RealField,
112    damping_mult: N::RealField,
113    h: N::RealField,
114    mut f: F,
115    jac: &mut DMatrix<N>,
116    jac_transpose: &mut DMatrix<N>,
117    mut params: SVector<N, V>,
118) -> Result<(N::RealField, DVector<N>), String>
119where
120    N: ComplexField + Copy + FromPrimitive,
121    <N as ComplexField>::RealField: Copy + FromPrimitive,
122    F: FnMut(N, &SVector<N, V>) -> N,
123{
124    let mut resid = Vec::with_capacity(xs.len());
125    for (ind, &x) in xs.iter().enumerate() {
126        resid.push(ys[ind] - f(x, &params));
127    }
128    let sum_sq_initial: N::RealField = resid
129        .iter()
130        .map(|&r| r.modulus_squared())
131        .fold(N::RealField::zero(), |acc, r| acc + r);
132
133    // Get initial factor
134    let mut sum_sq = sum_sq_initial + N::RealField::one();
135    let mut damping_tmp = *damping / damping_mult;
136    let mut j = 0;
137    let mut evaluation: DVector<N> =
138        DVector::from_iterator(xs.len(), xs.iter().map(|&x| f(x, &params)));
139    while sum_sq > sum_sq_initial && j < 1000 {
140        damping_tmp *= damping_mult;
141        let diff = ys - &evaluation;
142        let mut b = jac_transpose as &DMatrix<N> * &diff;
143        // Always square
144        let mut multiplied = jac_transpose as &DMatrix<N> * jac as &DMatrix<N>;
145        for i in 0..multiplied.row(0).len() {
146            multiplied[(i, i)] *= N::one() + N::from_real(damping_tmp);
147        }
148        let lu = multiplied.clone().lu();
149        let solved = lu.solve_mut(&mut b);
150        if !solved {
151            return Err("curve_fit: unable to solve linear equation".to_owned());
152        }
153        params += &b;
154        evaluation = DVector::from_iterator(xs.len(), xs.iter().map(|&x| f(x, &params)));
155        let diff = ys - &evaluation;
156        sum_sq = diff
157            .iter()
158            .map(|&r| r.modulus_squared())
159            .fold(N::RealField::zero(), |acc, r| acc + r);
160        j += 1;
161        jac_finite_differences(&mut f, xs, &mut params, jac, h);
162        *jac_transpose = jac.transpose();
163    }
164    if j != 1000 {
165        *damping = damping_tmp;
166    }
167    Ok((sum_sq, evaluation))
168}
169
170#[allow(clippy::too_many_arguments)]
171fn initial_residuals_exact<N, F, G, const V: usize>(
172    xs: &[N],
173    ys: &DVector<N>,
174    damping: &mut N::RealField,
175    damping_mult: N::RealField,
176    mut f: F,
177    mut jacobian: G,
178    jac: &mut DMatrix<N>,
179    jac_transpose: &mut DMatrix<N>,
180    mut params: SVector<N, V>,
181) -> Result<(N::RealField, DVector<N>), String>
182where
183    N: ComplexField + Copy + FromPrimitive,
184    <N as ComplexField>::RealField: Copy + FromPrimitive,
185    F: FnMut(N, &SVector<N, V>) -> N,
186    G: FnMut(N, &SVector<N, V>) -> SVector<N, V>,
187{
188    // Get the initial sum of square residuals
189    let mut resid = Vec::with_capacity(xs.len());
190    for (ind, &x) in xs.iter().enumerate() {
191        resid.push(ys[ind] - f(x, &params));
192    }
193    let sum_sq_initial: N::RealField = resid
194        .iter()
195        .map(|&r| r.modulus_squared())
196        .fold(N::RealField::zero(), |acc, r| acc + r);
197
198    // Get initial factor
199    let mut sum_sq = sum_sq_initial + N::RealField::one();
200    let mut damping_tmp = *damping / damping_mult;
201    let mut j = 0;
202    let mut evaluation: DVector<N> =
203        DVector::from_iterator(xs.len(), xs.iter().map(|&x| f(x, &params)));
204    while sum_sq > sum_sq_initial && j < 1000 {
205        damping_tmp *= damping_mult;
206        let diff = ys - &evaluation;
207        let mut b = jac_transpose as &DMatrix<N> * &diff;
208        // Always square
209        let mut multiplied = jac_transpose as &DMatrix<N> * jac as &DMatrix<N>;
210        for i in 0..multiplied.row(0).len() {
211            multiplied[(i, i)] *= N::one() + N::from_real(damping_tmp);
212        }
213        let lu = multiplied.clone().lu();
214        let solved = lu.solve_mut(&mut b);
215        if !solved {
216            let lu = multiplied.clone().full_piv_lu();
217            let full_lu_solved = lu.solve_mut(&mut b);
218            if !full_lu_solved {
219                let qr = multiplied.qr();
220                let qr_solved = qr.solve_mut(&mut b);
221                if !qr_solved {
222                    return Err("curve_fit_jac: unable to solve linear equation".to_owned());
223                }
224            }
225        }
226        params += &b;
227        evaluation = DVector::from_iterator(xs.len(), xs.iter().map(|&x| f(x, &params)));
228        let diff = ys - &evaluation;
229        sum_sq = diff
230            .iter()
231            .map(|&r| r.modulus_squared())
232            .fold(N::RealField::zero(), |acc, r| acc + r);
233        j += 1;
234        jac_analytic(&mut jacobian, xs, &mut params, jac);
235        *jac_transpose = jac.transpose();
236    }
237    if j != 1000 {
238        *damping = damping_tmp;
239    }
240
241    Ok((sum_sq, evaluation))
242}
243
244/// Fit a curve using the Levenberg-Marquardt algorithm.
245///
246/// Uses finite differences of h to calculate the jacobian. If jacobian
247/// can be found analytically, then use `curve_fit_jac`. Keeps iterating until
248/// the differences between the sum of the square residuals of two iterations
249/// is under tol.
250///
251/// # Errors
252/// Returns an error if curve fitting fails.
253///
254/// # Panics
255/// Panics if a u8 can not be converted to the generic type.
256pub fn curve_fit<N, F, const V: usize>(
257    mut f: F,
258    xs: &[N],
259    ys: &[N],
260    initial: &[N],
261    params: &CurveFitParams<N>,
262) -> Result<SVector<N, V>, String>
263where
264    N: ComplexField + FromPrimitive + Copy,
265    <N as ComplexField>::RealField: FromPrimitive + Copy,
266    F: FnMut(N, &SVector<N, V>) -> N,
267{
268    let tol = params.tolerance;
269    let mut damping = params.damping;
270    let h = params.h;
271    let damping_mult = params.damping_mult;
272
273    if !tol.is_sign_positive() {
274        return Err("curve_fit: tol must be positive".to_owned());
275    }
276
277    if !h.is_sign_positive() {
278        return Err("curve_fit: h must be positive".to_owned());
279    }
280
281    if !damping.is_sign_positive() {
282        return Err("curve_fit: damping must be positive".to_owned());
283    }
284
285    if xs.len() != ys.len() {
286        return Err("curve_fit: xs length must match ys length".to_owned());
287    }
288
289    let mut params = SVector::<N, V>::from_column_slice(initial);
290    let ys = DVector::<N>::from_column_slice(ys);
291    let mut jac: DMatrix<N> = DMatrix::identity(xs.len(), params.len());
292    jac_finite_differences(&mut f, xs, &mut params, &mut jac, h);
293    let mut jac_transpose = jac.transpose();
294
295    // Get the initial sum of square residuals
296    let (mut sum_sq, mut evaluation) = initial_residuals(
297        xs,
298        &ys,
299        &mut damping,
300        damping_mult,
301        h,
302        &mut f,
303        &mut jac,
304        &mut jac_transpose,
305        params,
306    )?;
307
308    let mut last_sum_sq = sum_sq;
309    sum_sq += N::from_u8(2).unwrap().real() * tol;
310    while (last_sum_sq - sum_sq).abs() > tol {
311        last_sum_sq = sum_sq;
312        // Get right side of iteration equation
313        let diff = &ys - &evaluation;
314        let mut b = &jac_transpose * &diff;
315        let mut b_div = b.clone();
316        // Get left side of equation
317        let mut multiplied = &jac_transpose * &jac;
318        let mut multiplied_div = multiplied.clone();
319        for i in 0..multiplied.row(0).len() {
320            multiplied[(i, i)] *= N::one() + N::from_real(damping);
321        }
322        // Solve equation with LU w/ partial pivoting first
323        let lu = multiplied.clone().lu();
324        let lu_solved = lu.solve_mut(&mut b);
325        if !lu_solved {
326            return Err("curve_fit: unable to solve linear equation".to_owned());
327        }
328        let new_params = params + &b;
329
330        // Now solve for damping / damping_mult
331        for i in 0..multiplied_div.row(0).len() {
332            multiplied_div[(i, i)] *= N::one() + N::from_real(damping / damping_mult);
333        }
334        let lu = multiplied_div.clone().lu();
335        let solved = lu.solve_mut(&mut b_div);
336        if !solved {
337            let lu = multiplied_div.clone().full_piv_lu();
338            let full_lu_solved = lu.solve_mut(&mut b_div);
339            if !full_lu_solved {
340                let qr = multiplied_div.qr();
341                let qr_solved = qr.solve_mut(&mut b_div);
342                if !qr_solved {
343                    return Err("curve_fit: unable to solve linear equation".to_owned());
344                }
345            }
346        }
347        let new_params_div = params + &b_div;
348
349        // get residuals for each of the new solutions
350        evaluation = DVector::from_iterator(xs.len(), xs.iter().map(|&x| f(x, &new_params)));
351        let evaluation_div =
352            DVector::from_iterator(xs.len(), xs.iter().map(|&x| f(x, &new_params_div)));
353        let diff = &ys - &evaluation;
354        let diff_div = &ys - &evaluation_div;
355
356        let resid: N::RealField = diff
357            .iter()
358            .map(|&r| r.modulus_squared())
359            .fold(N::RealField::zero(), |acc, r| acc + r);
360        let resid_div: N::RealField = diff_div
361            .iter()
362            .map(|&r| r.modulus_squared())
363            .fold(N::RealField::zero(), |acc, r| acc + r);
364
365        if resid_div < resid {
366            damping /= damping_mult;
367            evaluation = evaluation_div;
368            params = new_params_div;
369            sum_sq = resid_div;
370        } else {
371            params = new_params;
372            sum_sq = resid;
373        }
374
375        jac_finite_differences(&mut f, xs, &mut params, &mut jac, h);
376        jac_transpose = jac.transpose();
377    }
378
379    Ok(params)
380}
381
382/// Fit a curve using the Levenberg-Marquardt algorithm.
383///
384/// Uses an analytic jacobian.Keeps iterating until
385/// the differences between the sum of the square residuals of two iterations
386/// is under tol. Jacobian should be a function that returns a column vector
387/// where jacobian[i] is the partial derivative of f with respect to param[i].
388///
389/// # Errors
390/// Returns an error if curve fit fails.
391///
392/// # Panics
393/// Panics if a u8 can not be converted to the generic type.
394pub fn curve_fit_jac<N, F, G, const V: usize>(
395    mut f: F,
396    xs: &[N],
397    ys: &[N],
398    initial: &[N],
399    mut jacobian: G,
400    params: &CurveFitParams<N>,
401) -> Result<SVector<N, V>, String>
402where
403    N: ComplexField + FromPrimitive + Copy,
404    <N as ComplexField>::RealField: FromPrimitive + Copy,
405    F: FnMut(N, &SVector<N, V>) -> N,
406    G: FnMut(N, &SVector<N, V>) -> SVector<N, V>,
407{
408    let tol = params.tolerance;
409    let mut damping = params.damping;
410    let damping_mult = params.damping_mult;
411
412    if !tol.is_sign_positive() {
413        return Err("curve_fit_jac: tol must be positive".to_owned());
414    }
415
416    if !damping.is_sign_positive() {
417        return Err("curve_fit_jac: damping must be positive".to_owned());
418    }
419
420    if xs.len() != ys.len() {
421        return Err("curve_fit_jac: xs length must match ys length".to_owned());
422    }
423
424    let mut params = SVector::<N, V>::from_column_slice(initial);
425    let ys = DVector::<N>::from_column_slice(ys);
426    let mut jac: DMatrix<N> = DMatrix::identity(xs.len(), params.len());
427    jac_analytic(&mut jacobian, xs, &mut params, &mut jac);
428    let mut jac_transpose = jac.transpose();
429
430    let (mut sum_sq, mut evaluation) = initial_residuals_exact(
431        xs,
432        &ys,
433        &mut damping,
434        damping_mult,
435        &mut f,
436        &mut jacobian,
437        &mut jac,
438        &mut jac_transpose,
439        params,
440    )?;
441
442    let mut last_sum_sq = sum_sq;
443    sum_sq += N::from_u8(2).unwrap().real() * tol;
444    while (last_sum_sq - sum_sq).abs() > tol {
445        last_sum_sq = sum_sq;
446        // Get right side of iteration equation
447        let diff = &ys - &evaluation;
448        let mut b = &jac_transpose * &diff;
449        let mut b_div = b.clone();
450        // Get left side of equation
451        let mut multiplied = &jac_transpose * &jac;
452        let mut multiplied_div = multiplied.clone();
453        for i in 0..multiplied.row(0).len() {
454            multiplied[(i, i)] *= N::one() + N::from_real(damping);
455        }
456        // Solve equation with LU w/ partial pivoting first
457        // Then try LU w/ Full pivoting, QR
458        let lu = multiplied.clone().lu();
459        let lu_solved = lu.solve_mut(&mut b);
460        if !lu_solved {
461            let lu = multiplied.clone().full_piv_lu();
462            let full_lu_solved = lu.solve_mut(&mut b);
463            if !full_lu_solved {
464                let qr = multiplied.qr();
465                let qr_solved = qr.solve_mut(&mut b);
466                if !qr_solved {
467                    return Err("curve_fit_jac: unable to solve linear equation".to_owned());
468                }
469            }
470        }
471        let new_params = params + &b;
472
473        // Now solve for damping / damping_mult
474        for i in 0..multiplied_div.row(0).len() {
475            multiplied_div[(i, i)] *= N::one() + N::from_real(damping / damping_mult);
476        }
477        let lu = multiplied_div.clone().lu();
478        let solved = lu.solve_mut(&mut b_div);
479        if !solved {
480            let lu = multiplied_div.clone().full_piv_lu();
481            let full_lu_solved = lu.solve_mut(&mut b_div);
482            if !full_lu_solved {
483                let qr = multiplied_div.qr();
484                let qr_solved = qr.solve_mut(&mut b_div);
485                if !qr_solved {
486                    return Err("curve_fit_jac: unable to solve linear equation".to_owned());
487                }
488            }
489        }
490        let new_params_div = params + &b_div;
491
492        // get residuals for each of the new solutions
493        evaluation = DVector::from_iterator(xs.len(), xs.iter().map(|&x| f(x, &new_params)));
494        let evaluation_div =
495            DVector::from_iterator(xs.len(), xs.iter().map(|&x| f(x, &new_params_div)));
496        let diff = &ys - &evaluation;
497        let diff_div = &ys - &evaluation_div;
498
499        let resid: N::RealField = diff
500            .iter()
501            .map(|&r| r.modulus_squared())
502            .fold(N::RealField::zero(), |acc, r| acc + r);
503        let resid_div: N::RealField = diff_div
504            .iter()
505            .map(|&r| r.modulus_squared())
506            .fold(N::RealField::zero(), |acc, r| acc + r);
507
508        if resid_div < resid {
509            damping /= damping_mult;
510            evaluation = evaluation_div;
511            params = new_params_div;
512            sum_sq = resid_div;
513        } else {
514            params = new_params;
515            sum_sq = resid;
516        }
517
518        jac_analytic(&mut jacobian, xs, &mut params, &mut jac);
519        jac_transpose = jac.transpose();
520    }
521
522    Ok(params)
523}