// numerics_rs/root_finding/mod.rs

1use std::fmt::Debug;
2
3mod bisection;
4mod builder;
5mod newton_raphson;
6mod secant;
7
8pub use builder::RootFinderBuilder;
9
/// Identifies a root-finding algorithm.
///
/// NOTE(review): this file only declares the `bisection`, `newton_raphson` and
/// `secant` submodules; confirm how `Brent` and `InverseQuadraticInterpolation`
/// are handled (e.g. by the builder) before relying on them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RootFindingMethod {
    /// Bisection method.
    Bisection,
    /// Brent's method.
    Brent,
    /// Secant method.
    Secant,
    /// Inverse quadratic interpolation.
    InverseQuadraticInterpolation,
    /// Newton–Raphson method.
    NewtonRaphson,
}
/// Signature shared by the target function `f(x)` and its derivative `f'(x)`.
///
/// NOTE(review): the alias carries no lifetime parameter, so per Rust's default
/// trait-object lifetime rules this is `dyn Fn(f64) -> f64 + 'static`; closures
/// that borrow local variables are therefore rejected.
type F = dyn Fn(f64) -> f64;

/// Drives a [`RootFinder`] through its iterations.
///
/// On every iteration the decorator:
/// 1. evaluates the target function (and the derivative, if one was supplied) at
///    the arguments requested by the wrapped root finder,
/// 2. optionally prints them,
/// 3. asks the root finder whether to stop, and
/// 4. otherwise enforces the iteration cap and requests the next arguments.
pub struct RootFindingIterationDecorator<'a> {
    /// The target function `f(x)`.
    function: &'a F,
    /// The derivative `f'(x)`, if available.
    derivative: Option<&'a F>,
    /// Number of the iteration currently being evaluated, starting at 1.
    num_it: usize,
    /// Maximum number of iterations evaluated before giving up.
    max_iterations: usize,
    /// When `true`, each iteration's arguments and evaluations are printed to stdout.
    log_convergence: bool,
    /// The algorithm that proposes arguments and decides when to stop.
    root_finder: Box<dyn RootFinder + 'a>,
}

impl<'a> RootFindingIterationDecorator<'a> {
    /// Creates a decorator around `root_finder`.
    ///
    /// * `function` – the target function `f(x)`.
    /// * `derivative` – the derivative `f'(x)`. When `None`, the root finder
    ///   receives an empty `dfx` slice.
    /// * `root_finder` – the algorithm being driven.
    /// * `max_iterations` – maximum number of iterations evaluated by
    ///   [`find_root`](Self::find_root) before it gives up.
    /// * `log_convergence` – print per-iteration diagnostics to stdout.
    pub fn new(
        function: &'a F,
        derivative: Option<&'a F>,
        root_finder: Box<dyn RootFinder + 'a>,
        max_iterations: usize,
        log_convergence: bool,
    ) -> Self {
        Self {
            function,
            derivative,
            num_it: 1,
            max_iterations,
            log_convergence,
            root_finder,
        }
    }

    /// Runs the wrapped root finder until it reports a result or the iteration cap is hit.
    ///
    /// The iteration counter restarts at 1 on every call, so the same decorator
    /// can be run more than once.
    ///
    /// # Errors
    ///
    /// * Any error returned by the root finder's [`RootFinder::should_stop`].
    /// * `"Maximum iterations reached without convergence."` once `max_iterations`
    ///   iterations have been evaluated without the root finder stopping. A cap of
    ///   0 behaves like a cap of 1.
    pub fn find_root(&mut self) -> Result<f64, String> {
        // Each call is a fresh run; a stale counter would cut a second run short.
        self.num_it = 1;

        // Copy the (Copy) function references out so the closures below never
        // borrow `self` while `root_finder` is borrowed mutably.
        let function = self.function;
        let derivative = self.derivative;
        let rf = &mut self.root_finder;

        let mut args = rf.get_init_args();
        loop {
            let fx: Vec<f64> = args.iter().map(|&x| function(x)).collect();
            // `dfx` stays empty when no derivative was supplied.
            let dfx: Vec<f64> = match derivative {
                Some(df) => args.iter().map(|&x| df(x)).collect(),
                None => Vec::new(),
            };

            //TODO: Add time logging as well
            if self.log_convergence {
                println!(
                    "Iteration {}: x = {:?}, fx = {:?}, dfx = {:?}",
                    self.num_it, args, fx, dfx
                );
            }

            if let Some(res) = rf.should_stop(&fx, &dfx) {
                return res;
            }
            // `>=` rather than `==`: `num_it` starts at 1, so with `==` a cap of 0
            // would never be reached and the loop would run forever.
            if self.num_it >= self.max_iterations {
                return Err("Maximum iterations reached without convergence.".to_string());
            }
            self.num_it += 1;
            args = rf.get_next_args(&fx, &dfx);
        }
    }
}

/// A root-finding algorithm driven step by step by [`RootFindingIterationDecorator`].
pub trait RootFinder {
    /// Returns the arguments at which `f` (and `f'`, if supplied) are evaluated
    /// on the first iteration.
    fn get_init_args(&mut self) -> Box<[f64]>;

    /// Returns the arguments for the next iteration.
    ///
    /// `fx` (and `dfx`, which is empty when no derivative was supplied) hold
    /// the evaluations at the previously returned arguments, in the same order.
    fn get_next_args(&mut self, fx: &[f64], dfx: &[f64]) -> Box<[f64]>;

    /// Inspects the current iteration's evaluations.
    ///
    /// Returning `Some(result)` ends the search and `result` is handed back to
    /// the caller unchanged. Returning `None` continues to the next iteration.
    fn should_stop(&self, fx: &[f64], dfx: &[f64]) -> Option<Result<f64, String>>;
}