use super::wls::{NormalEquations, WlsResult, WlsSolverError};
use thiserror::Error;
/// Stopping controls for [`gauss_newton`].
///
/// Both tolerances are compared with a strict `<`, so a tolerance of exactly
/// `0.0` can never be satisfied.
#[derive(Clone, Copy, Debug)]
pub struct NonlinearOptions {
    /// Upper bound on the number of assemble/solve iterations; `gauss_newton`
    /// rejects `0`.
    pub max_iter: usize,
    /// Threshold on the largest relative parameter update of an iteration,
    /// where each update is scaled by `max(|parameter|, 1)`. Must be finite
    /// and non-negative.
    pub tol_rel: f64,
    /// Threshold on the relative change of the reduced chi² between two
    /// consecutive iterations. Must be finite and non-negative.
    pub tol_chi2_rel: f64,
}
impl Default for NonlinearOptions {
fn default() -> Self {
Self {
max_iter: 12,
tol_rel: 1e-9,
tol_chi2_rel: 1e-6,
}
}
}
/// Failure modes reported by [`gauss_newton`].
#[derive(Debug, Error)]
pub enum NonlinearError {
/// An error from the WLS solver layer, forwarded as-is (`transparent`) and
/// convertible from `WlsSolverError` via `?` thanks to `#[from]`.
/// `gauss_newton` also uses this variant to report invalid options, a
/// non-finite initial vector, or an assembler/parameter dimension mismatch.
#[error(transparent)]
Solver(#[from] WlsSolverError),
/// The iteration limit was reached without meeting the convergence
/// criteria; the payload is the `max_iter` that was in effect.
#[error("did not converge in {0} iterations")]
DidNotConverge(usize),
}
/// Outcome of a converged [`gauss_newton`] run.
#[derive(Clone, Debug)]
pub struct NonlinearReport {
    /// Parameter vector after the final update has been applied.
    pub parameters: Vec<f64>,
    /// Solver result of the final iteration (the one that met the
    /// convergence criteria).
    pub last: WlsResult,
    /// Number of iterations performed, counting from 1.
    pub iterations: usize,
}
pub fn gauss_newton<F>(
initial: Vec<f64>,
opts: NonlinearOptions,
mut assemble: F,
) -> Result<NonlinearReport, NonlinearError>
where
F: FnMut(&[f64]) -> Result<NormalEquations, NonlinearError>,
{
use super::wls::WlsSolverError;
if opts.max_iter == 0 {
return Err(NonlinearError::Solver(WlsSolverError::other(
"max_iter must be > 0",
)));
}
if !opts.tol_rel.is_finite() || opts.tol_rel < 0.0 {
return Err(NonlinearError::Solver(WlsSolverError::other(format!(
"tol_rel must be finite and ≥ 0 (got {})",
opts.tol_rel
))));
}
if !opts.tol_chi2_rel.is_finite() || opts.tol_chi2_rel < 0.0 {
return Err(NonlinearError::Solver(WlsSolverError::other(format!(
"tol_chi2_rel must be finite and ≥ 0 (got {})",
opts.tol_chi2_rel
))));
}
if let Some(bad) = initial.iter().copied().find(|v| !v.is_finite()) {
return Err(NonlinearError::Solver(WlsSolverError::other(format!(
"initial vector must be finite (got {bad})"
))));
}
let mut params = initial;
let mut last_chi2 = f64::INFINITY;
#[allow(unused_assignments)]
let mut last_result: Option<WlsResult> = None;
for it in 0..opts.max_iter {
let ne = assemble(¶ms)?;
if ne.n_params() != params.len() {
return Err(NonlinearError::Solver(WlsSolverError::other(format!(
"assembler returned {} parameters but initial vector has {} entries",
ne.n_params(),
params.len()
))));
}
let result = ne.solve()?;
let mut max_rel = 0.0_f64;
for (i, dp) in result.update.iter().enumerate() {
let scale = params[i].abs().max(1.0);
let rel = dp.abs() / scale;
if rel > max_rel {
max_rel = rel;
}
params[i] += dp;
}
let chi2 = result.reduced_chi2();
let chi2_rel = if last_chi2.is_finite() && last_chi2 > 0.0 {
(chi2 - last_chi2).abs() / last_chi2
} else {
f64::INFINITY
};
last_chi2 = chi2;
last_result = Some(result);
if max_rel < opts.tol_rel && chi2_rel < opts.tol_chi2_rel {
return Ok(NonlinearReport {
parameters: params,
last: last_result.unwrap(),
iterations: it + 1,
});
}
}
Err(NonlinearError::DidNotConverge(opts.max_iter))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pod::estimation::wls::NormalEquations;

    /// Strict baseline options shared by the tests below.
    fn default_opts() -> NonlinearOptions {
        NonlinearOptions {
            max_iter: 20,
            tol_rel: 1e-9,
            tol_chi2_rel: 1e-6,
        }
    }

    /// Assembler that returns a fresh one-parameter system with no rows; used
    /// where `gauss_newton` is expected to fail before the system is solved.
    fn empty_system(_params: &[f64]) -> Result<NormalEquations, NonlinearError> {
        Ok(NormalEquations::new(1))
    }

    /// Assembler for a single-parameter problem whose only row carries the
    /// residual `target - params[0]`.
    // NOTE(review): the `add_row` arguments look like (index, coefficient)
    // pairs, the residual and a sigma of 0.1 — confirm against the wls module.
    fn scalar_observation(
        target: f64,
    ) -> impl FnMut(&[f64]) -> Result<NormalEquations, NonlinearError> {
        move |params| {
            let mut ne = NormalEquations::new(1);
            ne.add_row(&[(0, 1.0)], target - params[0], 0.1).unwrap();
            Ok(ne)
        }
    }

    /// Asserts that the run is rejected with `NonlinearError::Solver`.
    fn assert_rejected(initial: Vec<f64>, opts: NonlinearOptions) {
        let outcome = gauss_newton(initial, opts, empty_system);
        assert!(matches!(outcome, Err(NonlinearError::Solver(_))));
    }

    #[test]
    fn converges_for_1d_linear_problem() {
        // The chi² tolerance is far looser than the baseline here, presumably
        // so that the parameter-update criterion decides convergence.
        let opts = NonlinearOptions {
            tol_chi2_rel: 2.0,
            ..default_opts()
        };
        let report = gauss_newton(vec![0.0], opts, scalar_observation(3.0)).unwrap();
        let estimate = report.parameters[0];
        assert!((estimate - 3.0).abs() < 1e-6);
        assert!(report.iterations >= 1);
    }

    #[test]
    fn max_iter_zero_is_error() {
        assert_rejected(
            vec![0.0],
            NonlinearOptions {
                max_iter: 0,
                ..default_opts()
            },
        );
    }

    #[test]
    fn non_finite_tol_rel_is_error() {
        assert_rejected(
            vec![0.0],
            NonlinearOptions {
                tol_rel: f64::NAN,
                ..default_opts()
            },
        );
    }

    #[test]
    fn negative_tol_chi2_is_error() {
        assert_rejected(
            vec![0.0],
            NonlinearOptions {
                tol_chi2_rel: -1.0,
                ..default_opts()
            },
        );
    }

    #[test]
    fn non_finite_initial_is_error() {
        assert_rejected(vec![f64::NAN], default_opts());
    }

    #[test]
    fn dimension_mismatch_is_error() {
        // Two initial parameters against a one-parameter system.
        assert_rejected(vec![0.0, 0.0], default_opts());
    }

    #[test]
    fn did_not_converge_after_max_iter() {
        // Zero tolerances can never be met because the comparisons are strict.
        let opts = NonlinearOptions {
            max_iter: 1,
            tol_rel: 0.0,
            tol_chi2_rel: 0.0,
        };
        let outcome = gauss_newton(vec![0.0], opts, scalar_observation(100.0));
        assert!(matches!(outcome, Err(NonlinearError::DidNotConverge(1))));
    }
}