eqsolver/solvers/single_variable/
newton.rs

1use crate::{SolverError, SolverResult, DEFAULT_ITERMAX, DEFAULT_TOL};
2use num_traits::Float;
3use std::ops::Fn;
4
5/// # Newton-Raphson
6///
7/// Newton solves an equation `f(x) = 0` given the function `f` and its derivative `df` as closures that takes a `Float` and outputs a `Float`.
8/// This function uses the Newton-Raphson's method ([Wikipedia](https://en.wikipedia.org/wiki/Newton%27s_method)).
9///
10/// **Default Tolerance:** `1e-6`
11///
12/// **Default Max Iterations:** `50`
13///
14/// ## Examples
15///
16/// ### A solution exists
17///
18/// ```
19/// // Want to solve x in cos(x) = sin(x). This is equivalent to solving x in cos(x) - sin(x) = 0.
20/// use eqsolver::single_variable::Newton;
21/// let f = |x: f64| x.cos() - x.sin();
22/// let df = |x: f64| -x.sin() - x.cos(); // Derivative of f
23///
24/// // Solve with Newton's Method. Error is less than 1E-6. Starting guess is around 0.8.
25/// let solution = Newton::new(f, df)
26///     .with_tol(1e-6)
27///     .solve(0.8)
28///     .unwrap();
29/// assert!((solution - std::f64::consts::FRAC_PI_4).abs() <= 1e-6); // Exactly x = pi/4
30/// ```
31///
32/// ### A solution does not exist
33///
34/// ```
35/// use eqsolver::{single_variable::Newton, SolverError};
36/// let f = |x: f64| x*x + 1.;
37/// let df = |x: f64| 2.*x;
38///
39/// // Solve with Newton's Method. Error is less than 1E-6. Starting guess is around 1.
40/// let solution = Newton::new(f, df).solve(1.);
41/// assert_eq!(solution.err().unwrap(), SolverError::NotANumber); // No solution, will diverge
42/// ```
43pub struct Newton<T, F, D> {
44    f: F,
45    df: D,
46    tolerance: T,
47    iter_max: usize,
48}
49
50impl<T, F, D> Newton<T, F, D>
51where
52    T: Float,
53    F: Fn(T) -> T,
54    D: Fn(T) -> T,
55{
56    /// Set up the solver
57    ///
58    /// Instantiates the solver using the given closure representing the function `f` to find roots for. This function also takes `f`'s derivative `df`
59    pub fn new(f: F, df: D) -> Self {
60        Self {
61            f,
62            df,
63            tolerance: T::from(DEFAULT_TOL).unwrap(),
64            iter_max: DEFAULT_ITERMAX,
65        }
66    }
67
68    /// Updates the solver's tolerance (Magnitude of Error).
69    ///
70    /// **Default Tolerance:** `1e-6`
71    ///
72    /// ## Examples
73    /// ```
74    /// use eqsolver::single_variable::Newton;
75    /// let f = |x: f64| x*x - 2.; // Solve x^2 = 2
76    /// let df = |x: f64| 2.*x; // Derivative of f
77    /// let solution = Newton::new(f, df)
78    ///     .with_tol(1e-12)
79    ///     .solve(1.4)
80    ///     .unwrap();
81    /// assert!((solution - 2_f64.sqrt()).abs() <= 1e-12);
82    /// ```
83    pub fn with_tol(&mut self, tol: T) -> &mut Self {
84        self.tolerance = tol;
85        self
86    }
87
88    /// Updates the solver's amount of iterations done before terminating the iteration
89    ///
90    /// **Default Max Iterations:** `50`
91    ///
92    /// ## Examples
93    /// ```
94    /// use eqsolver::{single_variable::Newton, SolverError};
95    ///
96    /// let f = |x: f64| x.powf(-x); // Solve x^-x = 0
97    /// let df = |x: f64| -x.powf(-x) * (1. + x.ln()); // Derivative of f
98    /// let solution = Newton::new(f, df)
99    ///     .with_itermax(20)
100    ///     .solve(1.); // Solver will terminate after 20 iterations
101    /// assert_eq!(solution.err().unwrap(), SolverError::MaxIterReached);
102    /// ```
103    pub fn with_itermax(&mut self, max: usize) -> &mut Self {
104        self.iter_max = max;
105        self
106    }
107
108    /// Solves for `x` in `f(x) = 0` where `f` is the stored function.
109    ///
110    /// ## Examples
111    /// ```
112    /// use eqsolver::{DEFAULT_TOL, single_variable::Newton};
113    /// let f = |x: f64| x*x - 2.; // Solve x^2 = 2
114    /// let df = |x: f64| 2.*x; // Derivative of f
115    /// let solution = Newton::new(f, df)
116    ///     .solve(1.4)
117    ///     .unwrap();
118    /// assert!((solution - 2_f64.sqrt()).abs() <= DEFAULT_TOL); // Default Tolerance = 1e-6
119    /// ```
120    pub fn solve(&self, mut x0: T) -> SolverResult<T> {
121        let mut dx = T::max_value(); // We assume error is infinite at the start
122        let mut iter = 1;
123
124        // Newton-Raphson's Iteration
125        while dx.abs() > self.tolerance && iter <= self.iter_max {
126            dx = (self.f)(x0) / (self.df)(x0);
127            x0 = x0 - dx;
128            iter += 1;
129        }
130
131        if iter >= self.iter_max {
132            return Err(SolverError::MaxIterReached);
133        }
134
135        if x0.is_nan() {
136            return Err(SolverError::NotANumber);
137        }
138
139        Ok(x0)
140    }
141}