eqsolver/solvers/single_variable/newton.rs
1use crate::{SolverError, SolverResult, DEFAULT_ITERMAX, DEFAULT_TOL};
2use num_traits::Float;
3use std::ops::Fn;
4
/// # Newton-Raphson
///
/// `Newton` solves an equation `f(x) = 0` given the function `f` and its derivative `df` as
/// closures that take a `Float` and output a `Float`.
/// This solver uses the Newton-Raphson method ([Wikipedia](https://en.wikipedia.org/wiki/Newton%27s_method)).
///
/// **Default Tolerance:** `1e-6`
///
/// **Default Max Iterations:** `50`
///
/// ## Examples
///
/// ### A solution exists
///
/// ```
/// // Want to solve x in cos(x) = sin(x). This is equivalent to solving x in cos(x) - sin(x) = 0.
/// use eqsolver::single_variable::Newton;
/// let f = |x: f64| x.cos() - x.sin();
/// let df = |x: f64| -x.sin() - x.cos(); // Derivative of f
///
/// // Solve with Newton's Method. Error is less than 1E-6. Starting guess is 0.8.
/// let solution = Newton::new(f, df)
///     .with_tol(1e-6)
///     .solve(0.8)
///     .unwrap();
/// assert!((solution - std::f64::consts::FRAC_PI_4).abs() <= 1e-6); // Exactly x = pi/4
/// ```
///
/// ### A solution does not exist
///
/// ```
/// use eqsolver::{single_variable::Newton, SolverError};
/// let f = |x: f64| x*x + 1.;
/// let df = |x: f64| 2.*x;
///
/// // Solve with Newton's Method using the default settings, starting from x = 1.
/// let solution = Newton::new(f, df).solve(1.);
/// assert_eq!(solution.err().unwrap(), SolverError::NotANumber); // No solution, will diverge
/// ```
#[derive(Clone)]
pub struct Newton<T, F, D> {
    /// The function whose root is sought.
    f: F,
    /// The derivative of `f`.
    df: D,
    /// Iteration stops once a Newton step is no larger than this in magnitude.
    tolerance: T,
    /// Maximum number of Newton steps performed before giving up.
    iter_max: usize,
}
49
50impl<T, F, D> Newton<T, F, D>
51where
52 T: Float,
53 F: Fn(T) -> T,
54 D: Fn(T) -> T,
55{
56 /// Set up the solver
57 ///
58 /// Instantiates the solver using the given closure representing the function `f` to find roots for. This function also takes `f`'s derivative `df`
59 pub fn new(f: F, df: D) -> Self {
60 Self {
61 f,
62 df,
63 tolerance: T::from(DEFAULT_TOL).unwrap(),
64 iter_max: DEFAULT_ITERMAX,
65 }
66 }
67
68 /// Updates the solver's tolerance (Magnitude of Error).
69 ///
70 /// **Default Tolerance:** `1e-6`
71 ///
72 /// ## Examples
73 /// ```
74 /// use eqsolver::single_variable::Newton;
75 /// let f = |x: f64| x*x - 2.; // Solve x^2 = 2
76 /// let df = |x: f64| 2.*x; // Derivative of f
77 /// let solution = Newton::new(f, df)
78 /// .with_tol(1e-12)
79 /// .solve(1.4)
80 /// .unwrap();
81 /// assert!((solution - 2_f64.sqrt()).abs() <= 1e-12);
82 /// ```
83 pub fn with_tol(&mut self, tol: T) -> &mut Self {
84 self.tolerance = tol;
85 self
86 }
87
88 /// Updates the solver's amount of iterations done before terminating the iteration
89 ///
90 /// **Default Max Iterations:** `50`
91 ///
92 /// ## Examples
93 /// ```
94 /// use eqsolver::{single_variable::Newton, SolverError};
95 ///
96 /// let f = |x: f64| x.powf(-x); // Solve x^-x = 0
97 /// let df = |x: f64| -x.powf(-x) * (1. + x.ln()); // Derivative of f
98 /// let solution = Newton::new(f, df)
99 /// .with_itermax(20)
100 /// .solve(1.); // Solver will terminate after 20 iterations
101 /// assert_eq!(solution.err().unwrap(), SolverError::MaxIterReached);
102 /// ```
103 pub fn with_itermax(&mut self, max: usize) -> &mut Self {
104 self.iter_max = max;
105 self
106 }
107
108 /// Solves for `x` in `f(x) = 0` where `f` is the stored function.
109 ///
110 /// ## Examples
111 /// ```
112 /// use eqsolver::{DEFAULT_TOL, single_variable::Newton};
113 /// let f = |x: f64| x*x - 2.; // Solve x^2 = 2
114 /// let df = |x: f64| 2.*x; // Derivative of f
115 /// let solution = Newton::new(f, df)
116 /// .solve(1.4)
117 /// .unwrap();
118 /// assert!((solution - 2_f64.sqrt()).abs() <= DEFAULT_TOL); // Default Tolerance = 1e-6
119 /// ```
120 pub fn solve(&self, mut x0: T) -> SolverResult<T> {
121 let mut dx = T::max_value(); // We assume error is infinite at the start
122 let mut iter = 1;
123
124 // Newton-Raphson's Iteration
125 while dx.abs() > self.tolerance && iter <= self.iter_max {
126 dx = (self.f)(x0) / (self.df)(x0);
127 x0 = x0 - dx;
128 iter += 1;
129 }
130
131 if iter >= self.iter_max {
132 return Err(SolverError::MaxIterReached);
133 }
134
135 if x0.is_nan() {
136 return Err(SolverError::NotANumber);
137 }
138
139 Ok(x0)
140 }
141}