use crate::utils::{
convergence_data::ConvergenceData,
enums::{SolverError, TerminationReason},
perturb::perturb_real,
solver_settings::SolverSettings,
termination::{is_vtol_satisfied, is_xatol_satisfied},
};
use core::f64;
use std::sync::LazyLock;
pub static DEFAULT_NEWTON_SOLVER_SETTINGS: LazyLock<SolverSettings> =
LazyLock::new(|| SolverSettings {
max_iter: Some(200),
xatol: Some(1e-10),
..Default::default()
});
/// Finds a root of a univariate, scalar-valued function using Newton's method.
///
/// # Arguments
///
/// * `f` - Function whose root is sought.
/// * `df` - Derivative of `f`.
/// * `x0` - Initial guess for the root.
/// * `solver_settings` - Solver settings. When `None`, [`DEFAULT_NEWTON_SOLVER_SETTINGS`] is
///   used. When settings are supplied but their `max_iter` is `None`, the iteration cap from
///   [`DEFAULT_NEWTON_SOLVER_SETTINGS`] is used instead of panicking.
/// * `convergence_data` - Optional record that is filled in as the solver runs: every accepted
///   iterate and its function/derivative values are appended to `x_all`/`f_all`/`df_all`, and
///   the `n_iter`, `n_feval`, and `n_deval` counters are incremented.
///
/// # Returns
///
/// The last iterate reached. The current implementation always returns `Ok`.
///
/// # Notes
///
/// * If the derivative is exactly zero at the current iterate, the iterate is perturbed once
///   with `perturb_real` and re-evaluated. These extra evaluations are counted in
///   `n_feval`/`n_deval`, but the perturbed point is not appended to the iterate history. If
///   the derivative is still zero, the solver stops with
///   `TerminationReason::ZeroDerivative` and returns the perturbed point.
/// * The step-tolerance and value-tolerance checks are delegated to `is_xatol_satisfied` and
///   `is_vtol_satisfied`, which also receive the convergence data (presumably to record the
///   termination reason -- confirm against their definitions).
/// * Hitting the iteration cap does not set a termination reason in this function.
pub fn root_newton(
    f: &impl Fn(f64) -> f64,
    df: &impl Fn(f64) -> f64,
    x0: f64,
    solver_settings: Option<&SolverSettings>,
    mut convergence_data: Option<&mut ConvergenceData>,
) -> Result<f64, SolverError> {
    let solver_settings: &SolverSettings =
        solver_settings.unwrap_or(&DEFAULT_NEWTON_SOLVER_SETTINGS);

    // Caller-supplied settings may leave `max_iter` unset (e.g. when built from
    // `SolverSettings::default()`); fall back to the Newton default instead of panicking.
    let max_iter = solver_settings
        .max_iter
        .or(DEFAULT_NEWTON_SOLVER_SETTINGS.max_iter)
        .expect("DEFAULT_NEWTON_SOLVER_SETTINGS always specifies `max_iter`");

    // Evaluate and record the initial guess.
    let mut x_curr = x0;
    let mut f_curr = f(x_curr);
    let mut df_curr = df(x_curr);
    if let Some(data) = convergence_data.as_deref_mut() {
        data.x_all.push(x_curr);
        data.f_all.push(f_curr);
        data.df_all.push(df_curr);
        data.n_feval += 1;
        data.n_deval += 1;
    }

    // The initial guess may already be good enough.
    if is_vtol_satisfied(f_curr, solver_settings, convergence_data.as_deref_mut()) {
        return Ok(x_curr);
    }

    for _ in 0..max_iter {
        // The Newton update divides by the derivative, so a stationary point must be escaped
        // before stepping. Nudge the iterate once and re-evaluate.
        if df_curr == 0.0 {
            x_curr = perturb_real(x_curr);
            f_curr = f(x_curr);
            df_curr = df(x_curr);
            if let Some(data) = convergence_data.as_deref_mut() {
                data.n_feval += 1;
                data.n_deval += 1;
            }

            // Still stationary after the perturbation: give up.
            if df_curr == 0.0 {
                if let Some(data) = convergence_data.as_deref_mut() {
                    data.termination_reason = TerminationReason::ZeroDerivative;
                }
                break;
            }
        }

        // Newton update and evaluation at the new iterate.
        let x_next = x_curr - (f_curr / df_curr);
        let f_next = f(x_next);
        let df_next = df(x_next);
        if let Some(data) = convergence_data.as_deref_mut() {
            data.x_all.push(x_next);
            data.f_all.push(f_next);
            data.df_all.push(df_next);
            data.n_iter += 1;
            data.n_feval += 1;
            data.n_deval += 1;
        }

        // The step tolerance is checked first; the value tolerance is only consulted when the
        // step tolerance is not satisfied (short-circuit), so the helpers are called in the
        // same order and under the same conditions as two separate checks would be.
        let converged = is_xatol_satisfied(
            x_curr,
            x_next,
            solver_settings,
            convergence_data.as_deref_mut(),
        ) || is_vtol_satisfied(f_next, solver_settings, convergence_data.as_deref_mut());

        // The new iterate is accepted whether or not the solver has converged.
        x_curr = x_next;
        if converged {
            break;
        }
        f_curr = f_next;
        df_curr = df_next;
    }

    Ok(x_curr)
}
/// Finds a root of a univariate, scalar-valued function using Newton's method, without any
/// settings struct or convergence bookkeeping.
///
/// # Arguments
///
/// * `f` - Function whose root is sought.
/// * `df` - Derivative of `f`.
/// * `x0` - Initial guess for the root.
/// * `xatol` - Absolute step tolerance; the solver stops once `|x_next - x_curr| <= xatol`.
///   Defaults to `1e-10` when `None`.
///
/// # Returns
///
/// The last iterate reached, either because the step tolerance was satisfied, because 200
/// iterations were performed, or because the derivative was still exactly zero after
/// perturbing the iterate once with `perturb_real`.
///
/// # Notes
///
/// `f` and `df` are evaluated once per iterate, at the start of each iteration. Nothing is
/// evaluated at the iterate that satisfies the step tolerance, since only the step size decides
/// termination.
pub fn root_newton_fast(
    f: &impl Fn(f64) -> f64,
    df: &impl Fn(f64) -> f64,
    x0: f64,
    xatol: Option<f64>,
) -> f64 {
    /// Step tolerance used when the caller does not provide one.
    const DEFAULT_XATOL: f64 = 1e-10;
    /// Hard cap on the number of Newton iterations.
    const MAX_ITER: u32 = 200;

    let xatol = xatol.unwrap_or(DEFAULT_XATOL);
    let mut x_curr = x0;

    for _ in 0..MAX_ITER {
        let mut f_curr = f(x_curr);
        let mut df_curr = df(x_curr);

        // The Newton update divides by the derivative, so a stationary point must be escaped
        // before stepping. Nudge the iterate once; if it is still stationary, give up.
        if df_curr == 0.0 {
            x_curr = perturb_real(x_curr);
            f_curr = f(x_curr);
            df_curr = df(x_curr);
            if df_curr == 0.0 {
                break;
            }
        }

        let x_next = x_curr - (f_curr / df_curr);
        let converged = (x_next - x_curr).abs() <= xatol;

        // The new iterate is accepted whether or not the solver has converged.
        x_curr = x_next;
        if converged {
            break;
        }
    }

    x_curr
}
#[cfg(test)]
mod tests {
    use super::*;
    use numtest::*;

    /// Cubic with a triple root at `x = 1`.
    fn cubic(x: f64) -> f64 {
        (x - 1.0).powi(3)
    }

    /// Derivative of [`cubic`].
    fn cubic_derivative(x: f64) -> f64 {
        3.0 * (x - 1.0).powi(2)
    }

    /// Parabola whose roots are `+/- sqrt(612)`.
    fn shifted_square(x: f64) -> f64 {
        x.powi(2) - 612.0
    }

    /// Derivative of [`shifted_square`].
    fn shifted_square_derivative(x: f64) -> f64 {
        2.0 * x
    }

    /// Runs [`root_newton`] (and [`root_newton_fast`]) on one problem and checks the root, the
    /// residual, and the recorded diagnostics.
    ///
    /// `root_tol` and `value_tol` default to `1e-10`. The fast solver's result is only compared
    /// against the full solver's when the default settings are in use.
    // One flat argument list keeps every test case a single, table-like call.
    #[allow(clippy::too_many_arguments)]
    fn root_newton_test_helper(
        f: &impl Fn(f64) -> f64,
        df: &impl Fn(f64) -> f64,
        x0: f64,
        solver_settings: Option<&SolverSettings>,
        x_exp: f64,
        root_tol: Option<f64>,
        value_tol: Option<f64>,
        n_iter_exp: u32,
        n_feval_exp: u32,
        n_deval_exp: u32,
        reason_exp: TerminationReason,
    ) {
        let root_tol = root_tol.unwrap_or(1e-10);
        let value_tol = value_tol.unwrap_or(1e-10);

        // Forward the step tolerance (if any) so the fast solver sees the same one.
        let xatol: Option<f64> = solver_settings.and_then(|settings| settings.xatol);

        let mut convergence_data = ConvergenceData::default();
        let x = root_newton(f, df, x0, solver_settings, Some(&mut convergence_data)).unwrap();
        let x_fast = root_newton_fast(f, df, x0, xatol);

        // Solution quality.
        assert_equal_to_atol!(x, x_exp, root_tol);
        assert_equal_to_atol!(f(x), 0.0, value_tol);
        if solver_settings.is_none() {
            assert_eq!(x, x_fast);
        }

        // Diagnostics.
        assert_eq!(convergence_data.n_iter, n_iter_exp);
        assert_eq!(convergence_data.n_feval, n_feval_exp);
        assert_eq!(convergence_data.n_deval, n_deval_exp);
        assert_eq!(convergence_data.termination_reason, reason_exp);
    }

    /// Runs [`root_newton`] with the given settings, checks that the returned root is the last
    /// recorded iterate, and hands back the recorded history.
    fn solve_recording_iterates(
        f: &impl Fn(f64) -> f64,
        df: &impl Fn(f64) -> f64,
        x0: f64,
        solver_settings: &SolverSettings,
    ) -> ConvergenceData {
        let mut history = ConvergenceData::default();
        let root = root_newton(f, df, x0, Some(solver_settings), Some(&mut history)).unwrap();
        assert_eq!(root, *history.x_all.last().unwrap());
        history
    }

    #[test]
    fn test_root_newton_at_root() {
        root_newton_test_helper(
            &cubic,
            &cubic_derivative,
            1.0,
            None,
            1.0,
            Some(1e-13),
            None,
            1,
            3,
            3,
            TerminationReason::AbsoluteStepToleranceSatisfied,
        );
    }

    #[test]
    fn test_root_newton_near_root() {
        root_newton_test_helper(
            &cubic,
            &cubic_derivative,
            1.5,
            None,
            1.0,
            Some(1e-9),
            None,
            54,
            55,
            55,
            TerminationReason::AbsoluteStepToleranceSatisfied,
        );
    }

    #[test]
    fn test_root_newton_starting_at_stationary_point() {
        root_newton_test_helper(
            &|x: f64| x.powi(2) - 1.0,
            &|x: f64| 2.0 * x,
            0.0,
            None,
            1.0,
            Some(1e-10),
            None,
            50,
            52,
            52,
            TerminationReason::AbsoluteStepToleranceSatisfied,
        );
    }

    #[test]
    fn test_root_newton_constant_function() {
        // No root exists; the loose tolerances only check that the solver stays near `x0`.
        root_newton_test_helper(
            &|_x: f64| 1.0,
            &|_x: f64| 0.0,
            1.5,
            None,
            1.0,
            Some(0.51),
            Some(1.0),
            0,
            2,
            2,
            TerminationReason::ZeroDerivative,
        );
    }

    #[test]
    fn test_root_newton_xatol() {
        let solver_settings = SolverSettings {
            xatol: Some(0.001),
            ..DEFAULT_NEWTON_SOLVER_SETTINGS.clone()
        };
        root_newton_test_helper(
            &cubic,
            &cubic_derivative,
            1.5,
            Some(&solver_settings),
            1.0,
            Some(0.003),
            Some(1e-7),
            14,
            15,
            15,
            TerminationReason::AbsoluteStepToleranceSatisfied,
        );
    }

    #[test]
    fn test_root_newton_vtol() {
        let solver_settings = SolverSettings {
            vtol: Some(0.001),
            ..DEFAULT_NEWTON_SOLVER_SETTINGS.clone()
        };
        root_newton_test_helper(
            &cubic,
            &cubic_derivative,
            1.5,
            Some(&solver_settings),
            1.0,
            Some(0.15),
            Some(0.001),
            4,
            5,
            5,
            TerminationReason::ValueToleranceSatisfied,
        );
    }

    #[test]
    fn test_root_newton_iterates_1() {
        let solver_settings = SolverSettings {
            xatol: Some(1e-14),
            max_iter: Some(10),
            ..Default::default()
        };
        let history = solve_recording_iterates(
            &shifted_square,
            &shifted_square_derivative,
            1.0,
            &solver_settings,
        );
        assert_arrays_equal_to_decimal!(
            history.x_all,
            [
                1.0,
                306.5,
                154.2483686786297,
                79.10799786435472,
                43.42212868215148,
                28.758162428779126,
                25.019538536995714,
                24.74021067122501,
                24.738633803961573,
                24.738633753705965,
                24.73863375370596
            ],
            16
        );
        assert_arrays_equal_to_decimal!(
            history.f_all,
            [
                -6.11e2,
                93330.25,
                23180.559240018472,
                5646.07532610675,
                1273.4812592893222,
                215.03190628004336,
                13.977308604213704,
                0.07802405659595024,
                2.486510197741154e-6,
                1.1368683772161603e-13,
                -1.1368683772161603e-13
            ],
            16
        );
        assert_arrays_equal_to_decimal!(
            history.df_all,
            [
                2.0,
                613.0,
                308.4967373572594,
                158.21599572870943,
                86.84425736430296,
                57.51632485755825,
                50.03907707399143,
                49.48042134245002,
                49.477267607923146,
                49.47726750741193,
                49.47726750741192
            ],
            16
        );
    }

    #[test]
    fn test_root_newton_iterates_2() {
        let solver_settings = SolverSettings {
            xatol: Some(1e-14),
            max_iter: Some(5),
            ..Default::default()
        };
        let history = solve_recording_iterates(
            &shifted_square,
            &shifted_square_derivative,
            10.0,
            &solver_settings,
        );
        assert_arrays_equal_to_decimal!(
            history.x_all,
            [
                10.0,
                35.6,
                26.395505617977527,
                24.790635492455475,
                24.738688294075324,
                24.738633753766084
            ],
            16
        );
        assert_arrays_equal_to_decimal!(
            history.f_all,
            [
                -512.0,
                655.3600000000001,
                84.72271682868313,
                2.5756081197931735,
                2.698511419453098e-3,
                2.9746161089860834e-9
            ],
            16
        );
        assert_arrays_equal_to_decimal!(
            history.df_all,
            [
                20.0,
                71.2,
                52.791011235955054,
                49.58127098491095,
                49.47737658815065,
                49.47726750753217
            ],
            16
        );
    }

    #[test]
    fn test_root_newton_iterates_3() {
        let solver_settings = SolverSettings {
            xatol: Some(1e-14),
            max_iter: Some(4),
            ..Default::default()
        };
        let history = solve_recording_iterates(
            &shifted_square,
            &shifted_square_derivative,
            -20.0,
            &solver_settings,
        );
        assert_arrays_equal_to_decimal!(
            history.x_all,
            [
                -20.0,
                -25.3,
                -24.744861660079053,
                -24.738634537440753,
                -24.738633753705976
            ],
            16
        );
        assert_arrays_equal_to_decimal!(
            history.f_all,
            [
                -212.0,
                28.090000000000032,
                0.3081785764502456,
                3.877705648847041e-5,
                6.821210263296962e-13
            ],
            16
        );
        assert_arrays_equal_to_decimal!(
            history.df_all,
            [
                -40.0,
                -50.6,
                -49.489723320158106,
                -49.477269074881505,
                -49.47726750741195
            ],
            16
        );
    }

    #[test]
    fn test_root_newton_iterates_4() {
        let solver_settings = SolverSettings {
            xatol: Some(1e-14),
            max_iter: Some(6),
            ..Default::default()
        };
        let history = solve_recording_iterates(
            &|x: f64| x.powi(3) - x.cos(),
            &|x: f64| 3.0 * x.powi(2) + x.sin(),
            0.5,
            &solver_settings,
        );
        assert_arrays_equal_to_decimal!(
            history.x_all,
            [
                0.5,
                1.1121416370972725,
                0.9096726937368068,
                0.8672638182088165,
                0.8654771352982646,
                0.8654740331109566,
                0.8654740331016144
            ],
            16
        );
        assert_arrays_equal_to_decimal!(
            history.f_all,
            [
                -0.7525825618903728,
                0.9328201795040982,
                0.13875403935061037,
                0.005393998041341108,
                9.333106352094056e-6,
                2.8106295069108e-11,
                -2.220446049250313e-16
            ],
            16
        );
        assert_arrays_equal_to_decimal!(
            history.df_all,
            [
                1.229425538604203,
                4.6072259973390155,
                3.2718160437673296,
                3.019001306546996,
                3.0085566812522133e0,
                3.008538560967953,
                3.008538560913384
            ],
            16
        );
    }
}