bacon_sci/roots/
mod.rs

1/* This file is part of bacon.
2 * Copyright (c) Wyatt Campbell.
3 *
4 * See repository LICENSE for information.
5 */
6
7use nalgebra::{ComplexField, Const, DimMin, RealField, SMatrix, SVector};
8use num_traits::{FromPrimitive, Zero};
9
10mod polynomial;
11pub use polynomial::*;
12
13/// Use the bisection method to solve for a zero of an equation.
14///
15/// This function takes an interval and uses a binary search to find
16/// where in that interval a function has a root. The signs of the function
17/// at each end of the interval must be different
18///
19/// # Returns
20/// Ok(root) when a root has been found, Err on failure
21///
22/// # Params
23/// `(left, right)` The starting interval. f(left) * f(right) > 0.0
24///
25/// `f` The function to find the root for
26///
27/// `tol` The tolerance of the relative error between iterations.
28///
29/// `n_max` The maximum number of iterations to use.
30///
31/// # Examples
32/// ```
33/// use bacon_sci::roots::bisection;
34///
35/// fn cubic(x: f64) -> f64 {
36///   x*x*x
37/// }
38/// //...
39/// fn example() {
40///   let solution = bisection((-1.0, 1.0), cubic, 0.001, 1000).unwrap();
41/// }
42/// ```
43pub fn bisection<N, F>(
44    (mut left, mut right): (N, N),
45    mut f: F,
46    tol: N,
47    n_max: usize,
48) -> Result<N, String>
49where
50    N: RealField + FromPrimitive + Copy,
51    F: FnMut(N) -> N,
52{
53    if left >= right {
54        return Err("Bisection: requirement: right > left".to_owned());
55    }
56
57    let mut n = 1;
58
59    let mut f_a = f(left);
60    if (f_a * f(right)).is_sign_positive() {
61        return Err("Bisection: requirement: Signs must be different".to_owned());
62    }
63
64    let half = N::from_f64(0.5).unwrap();
65
66    let mut half_interval = (left - right) * half;
67    let mut middle = left + half_interval;
68
69    if middle.abs() <= tol {
70        return Ok(middle);
71    }
72
73    while n <= n_max {
74        let f_p = f(middle);
75        if (f_p * f_a).is_sign_positive() {
76            left = middle;
77            f_a = f_p;
78        } else {
79            right = middle;
80        }
81
82        half_interval = (right - left) * half;
83
84        let middle_new = left + half_interval;
85
86        if (middle - middle_new).abs() / middle.abs() < tol || middle_new.abs() < tol {
87            return Ok(middle_new);
88        }
89
90        middle = middle_new;
91        n += 1;
92    }
93
94    Err("Bisection: Maximum iterations exceeded".to_owned())
95}
96
97/// Use steffenson's method to find a fixed point
98///
99/// Use steffenson's method to find a value x so that f(x) = x, given
100/// a starting point.
101///
102/// # Returns
103/// `Ok(x)` so that `f(x) - x < tol` on success, `Err` on failure
104///
105/// # Params
106/// `initial` inital guess for the fixed point
107///
108/// `f` Function to find the fixed point of
109///
110/// `tol` Tolerance from 0 to try and achieve
111///
112/// `n_max` maximum number of iterations
113///
114/// # Examples
115/// ```
116/// use bacon_sci::roots::steffensen;
117/// fn cosine(x: f64) -> f64 {
118///   x.cos()
119/// }
120/// //...
121/// fn example() -> Result<(), String> {
122///   let solution = steffensen(0.5f64, cosine, 0.0001, 1000)?;
123///   Ok(())
124/// }
125pub fn steffensen<N>(mut initial: N, f: fn(N) -> N, tol: N, n_max: usize) -> Result<N, String>
126where
127    N: RealField + FromPrimitive + Copy,
128{
129    let mut n = 0;
130
131    while n < n_max {
132        let guess = f(initial);
133        let new_guess = f(guess);
134        let diff = initial
135            - (guess - initial).powi(2) / (new_guess - N::from_f64(2.0).unwrap() * guess + initial);
136        if (diff - initial).abs() <= tol {
137            return Ok(diff);
138        }
139        initial = diff;
140        n += 1;
141    }
142
143    Err("Steffensen: Maximum number of iterations exceeded".to_owned())
144}
145
146/// Use Newton's method to find a root of a vector function.
147///
148/// Using a vector function and its derivative, find a root based on an initial guess
149/// using Newton's method.
150///
151/// # Returns
152/// `Ok(vec)` on success, where `vec` is a vector input for which the function is
153/// zero. `Err` on failure.
154///
155/// # Params
156/// `initial` Initial guess of the root. Should be near actual root. Slice since this
157/// function finds roots of vector functions.
158///
159/// `f` Vector function for which to find the root
160///
161/// `f_deriv` Derivative of `f`
162///
163/// `tol` tolerance for error between iterations of Newton's method
164///
165/// `n_max` Maximum number of iterations
166///
167/// # Examples
168/// ```
169/// use nalgebra::{SVector, SMatrix};
170/// use bacon_sci::roots::newton;
171/// fn cubic(x: &[f64]) -> SVector<f64, 1> {
172///   SVector::<f64, 1>::from_iterator(x.iter().map(|x| x.powi(3)))
173/// }
174///
175/// fn cubic_deriv(x: &[f64]) -> SMatrix<f64, 1, 1> {
176///  SMatrix::<f64, 1, 1>::from_iterator(x.iter().map(|x| 3.0*x.powi(2)))
177/// }
178/// //...
179/// fn example() {
180///   let solution = newton(&[0.1], cubic, cubic_deriv, 0.001, 1000).unwrap();
181/// }
182/// ```
183pub fn newton<N, F, G, const S: usize>(
184    initial: &[N],
185    mut f: F,
186    mut jac: G,
187    tol: <N as ComplexField>::RealField,
188    n_max: usize,
189) -> Result<SVector<N, S>, String>
190where
191    N: ComplexField + FromPrimitive + Copy,
192    <N as ComplexField>::RealField: FromPrimitive + Copy,
193    F: FnMut(&[N]) -> SVector<N, S>,
194    G: FnMut(&[N]) -> SMatrix<N, S, S>,
195    Const<S>: DimMin<Const<S>, Output = Const<S>>,
196{
197    let mut guess = SVector::<N, S>::from_column_slice(initial);
198    let mut norm = guess.dot(&guess).sqrt().abs();
199    let mut n = 0;
200
201    if norm <= tol {
202        return Ok(guess);
203    }
204
205    while n < n_max {
206        let f_val = -f(guess.as_slice());
207        let f_deriv_val = jac(guess.as_slice());
208        let lu = f_deriv_val.lu();
209        match lu.solve(&f_val) {
210            None => return Err("newton: failed to solve linear equation".to_owned()),
211            Some(adjustment) => {
212                let new_guess = guess + adjustment;
213                let new_norm = new_guess.dot(&new_guess).sqrt().abs();
214                if ((norm - new_norm) / norm).abs() <= tol || new_norm <= tol {
215                    return Ok(new_guess);
216                }
217
218                norm = new_norm;
219                guess = new_guess;
220                n += 1;
221            }
222        }
223    }
224
225    Err("Newton: Maximum iterations exceeded".to_owned())
226}
227
228fn jac_finite_diff<N, F, const S: usize>(
229    mut f: F,
230    x: &mut SVector<N, S>,
231    h: <N as ComplexField>::RealField,
232) -> SMatrix<N, S, S>
233where
234    N: ComplexField + FromPrimitive + Copy,
235    <N as ComplexField>::RealField: FromPrimitive + Copy,
236    F: FnMut(&[N]) -> SVector<N, S>,
237{
238    let mut mat = SMatrix::<N, S, S>::zero();
239    let h = N::from_real(h);
240    let denom = N::one() / (N::from_i32(2).unwrap() * h);
241
242    for col in 0..mat.row(0).len() {
243        x[col] += h;
244        let above = f(x.as_slice());
245        x[col] -= h;
246        x[col] -= h;
247        let below = f(x.as_slice());
248        x[col] += h;
249        let jac_col = (above + below) * denom;
250        for row in 0..mat.column(0).len() {
251            mat[(row, col)] = jac_col[row];
252        }
253    }
254
255    mat
256}
257
258/// Use secant method to find a root of a vector function.
259///
260/// Using a vector function and its derivative, find a root based on an initial guess
261/// and finite element differences using Broyden's method.
262///
263/// # Returns
264/// `Ok(vec)` on success, where `vec` is a vector input for which the function is
265/// zero. `Err` on failure.
266///
267/// # Params
268/// `initial` Initial guesses of the root. Should be near actual root. Slice since this
269/// function finds roots of vector functions.
270///
271/// `f` Vector function for which to find the root
272///
273/// `tol` tolerance for error between iterations of Newton's method
274///
275/// `n_max` Maximum number of iterations
276///
277/// # Examples
278/// ```
279/// use nalgebra::SVector;
280/// use bacon_sci::roots::secant;
281/// fn cubic(x: &[f64]) -> SVector<f64, 1> {
282///   SVector::<f64, 1>::from_iterator(x.iter().map(|x| x.powi(3)))
283/// }
284/// //...
285/// fn example() {
286///   let solution = secant(&[0.1], cubic, 0.1, 0.001, 1000).unwrap();
287/// }
288/// ```
289pub fn secant<N, F, const S: usize>(
290    initial: &[N],
291    mut func: F,
292    h: <N as ComplexField>::RealField,
293    tol: <N as ComplexField>::RealField,
294    n_max: usize,
295) -> Result<SVector<N, S>, String>
296where
297    N: ComplexField + FromPrimitive + Copy,
298    <N as ComplexField>::RealField: FromPrimitive + Copy,
299    F: FnMut(&[N]) -> SVector<N, S>,
300    Const<S>: DimMin<Const<S>, Output = Const<S>>,
301{
302    let mut n = 2;
303
304    let mut guess = SVector::<N, S>::from_column_slice(initial);
305    let mut func_eval = func(guess.as_slice());
306
307    let jac = jac_finite_diff(&mut func, &mut guess, h);
308    let lu = jac.lu();
309    let try_inv = lu.try_inverse();
310    let mut jac_inv = if let Some(inv) = try_inv {
311        inv
312    } else {
313        return Err("Secant: Can not inverse finite element difference jacobian".to_owned());
314    };
315
316    let mut shift = -jac_inv * func_eval;
317    guess += &shift;
318
319    while n < n_max {
320        let func_eval_last = func_eval;
321        func_eval = func(guess.as_slice());
322        let diff = func_eval - func_eval_last;
323        let adjustment = -jac_inv * diff;
324        let s_transpose = shift.transpose();
325        let p = (-s_transpose * adjustment)[(0, 0)];
326        let u = s_transpose * jac_inv;
327        jac_inv += (shift + adjustment) * u / p;
328        shift = -&jac_inv * func_eval;
329        guess += &shift;
330        if shift.norm().abs() <= tol {
331            return Ok(guess);
332        }
333        n += 1;
334    }
335
336    Err("Secant: Maximum iterations exceeded".to_owned())
337}
338
339/// Use Brent's method to find the root of a function
340///
341/// The initial guesses must bracket the root. That is, the function evaluations of
342/// the initial guesses must differ in sign.
343///
344/// # Examples
345/// ```
346/// use bacon_sci::roots::brent;
347/// fn cubic(x: f64) -> f64 {
348///     x.powi(3)
349/// }
350/// //...
351/// fn example() {
352///   let solution = brent((0.1, -0.1), cubic, 1e-5).unwrap();
353/// }
354/// ```
355pub fn brent<N, F>(initial: (N, N), mut f: F, tol: N) -> Result<N, String>
356where
357    N: RealField + FromPrimitive + Copy,
358    F: FnMut(N) -> N,
359{
360    if !tol.is_sign_positive() {
361        return Err("brent: tolerance must be positive".to_owned());
362    }
363
364    let mut left = initial.0;
365    let mut right = initial.1;
366    let mut f_left = f(left);
367    let mut f_right = f(right);
368
369    // Make a the maximum
370    if f_left.abs() < f_right.abs() {
371        std::mem::swap(&mut left, &mut right);
372        std::mem::swap(&mut f_left, &mut f_right);
373    }
374
375    if !(f_left * f_right).is_sign_negative() {
376        return Err("brent: initial guesses do not bracket root".to_owned());
377    }
378
379    let two = N::from_i32(2).unwrap();
380    let three = N::from_i32(3).unwrap();
381    let four = N::from_i32(4).unwrap();
382
383    let mut c = left;
384    let mut f_c = f_left;
385    let mut s = right - f_right * (right - left) / (f_right - f_left);
386    let mut f_s = f(s);
387    let mut mflag = true;
388    let mut d = c;
389
390    while !(f_right.abs() < tol || f_s.abs() < tol || (left - right).abs() < tol) {
391        if (f_left - f_c).abs() < tol && (f_right - f_c).abs() < tol {
392            s = (left * f_right * f_c) / ((f_left - f_right) * (f_left - f_c))
393                + (right * f_left * f_c) / ((f_right - f_left) * (f_right - f_c))
394                + (c * f_left * f_right) / ((f_c - f_left) * (f_c - f_right));
395        } else {
396            s = right - f_right * (right - left) / (f_right - f_left);
397        }
398
399        if !(s >= (three * left + right) / four && s <= right)
400            || (mflag && (s - right).abs() >= (right - c) / two)
401            || (!mflag && (s - right).abs() >= (c - d).abs() / two)
402            || (mflag && (right - c).abs() < tol)
403            || (!mflag && (c - d).abs() < tol)
404        {
405            s = (left + right) / two;
406            mflag = true;
407        } else {
408            mflag = false;
409        }
410
411        f_s = f(s);
412        d = c;
413        c = right;
414        f_c = f_right;
415        if (f_left * f_s).is_sign_negative() {
416            right = s;
417            f_right = f_s;
418        } else {
419            left = s;
420            f_left = f_s;
421        }
422
423        if f_left.abs() < f_right.abs() {
424            std::mem::swap(&mut left, &mut right);
425            std::mem::swap(&mut f_left, &mut f_right);
426        }
427    }
428
429    if f_s.abs() < tol {
430        Ok(s)
431    } else {
432        Ok(right)
433    }
434}
435
436/// Find the root of an equation using the ITP method.
437///
438/// The initial guess must bracket the root, that is the
439/// function evaluations must differ in sign between the
440/// two initial guesses. k_1 is a parameter in (0, infty).
441/// k_2 is a paramater in (1, 1 + golden_ratio). n_0 is a parameter
442/// in [0, infty). This method gives the worst case performance of the
443/// bisection method (which has the best worst case performance) with
444/// better average convergance.
445///
446/// # Examples
447/// ```
448/// use bacon_sci::roots::itp;
449/// fn cubic(x: f64) -> f64 {
450///     x.powi(3)
451/// }
452/// //...
453/// fn example() {
454///   let solution = itp((0.1, -0.1), cubic, 0.1, 2.0, 0.99, 1e-5).unwrap();
455/// }
456/// ```
457pub fn itp<N, F>(initial: (N, N), mut f: F, k_1: N, k_2: N, n_0: N, tol: N) -> Result<N, String>
458where
459    N: RealField + FromPrimitive + Copy,
460    F: FnMut(N) -> N,
461{
462    if !tol.is_sign_positive() {
463        return Err("itp: tolerance must be positive".to_owned());
464    }
465
466    if !k_1.is_sign_positive() {
467        return Err("itp: k_1 must be positive".to_owned());
468    }
469
470    if k_2 <= N::one() || k_2 >= (N::one() + N::from_f64(0.5 * (1.0 + 5.0_f64.sqrt())).unwrap()) {
471        return Err("itp: k_2 must be in (1, 1 + golden_ratio)".to_owned());
472    }
473
474    let mut left = initial.0;
475    let mut right = initial.1;
476    let mut f_left = f(left);
477    let mut f_right = f(right);
478
479    if !(f_left * f_right).is_sign_negative() {
480        return Err("itp: initial guesses must bracket root".to_owned());
481    }
482
483    if f_left.is_sign_positive() {
484        std::mem::swap(&mut left, &mut right);
485        std::mem::swap(&mut f_left, &mut f_right);
486    }
487
488    let two = N::from_i32(2).unwrap();
489
490    let n_half = ((right - left).abs() / (two * tol)).log2().ceil();
491    let n_max = n_half + n_0;
492    let mut j = 0;
493
494    while (right - left).abs() > two * tol {
495        let x_half = (left + right) / two;
496        let r = tol * two.powf(n_max + n_0 - N::from_i32(j).unwrap()) - (right - left) / two;
497        let x_f = (f_right * left - f_left * right) / (f_right - f_left);
498        let sigma = (x_half - x_f) / (x_half - x_f).abs();
499        let delta = k_1 * (right - left).powf(k_2);
500        let x_t = if delta <= (x_half - x_f).abs() {
501            x_f + sigma * delta
502        } else {
503            x_half
504        };
505        let x_itp = if (x_t - x_half).abs() <= r {
506            x_t
507        } else {
508            x_half - sigma * r
509        };
510        let f_itp = f(x_itp);
511        if f_itp.is_sign_positive() {
512            right = x_itp;
513            f_right = f_itp;
514        } else if f_itp.is_sign_negative() {
515            left = x_itp;
516            f_left = f_itp;
517        } else {
518            left = x_itp;
519            right = x_itp;
520        }
521        j += 1;
522    }
523
524    Ok((left + right) / two)
525}