rusty_rootsearch/
lib.rs

1use std::{env, fmt::Display, ops::{Sub, Div}};
2// use std::{sync::mpsc::{Sender, Receiver, channel}, thread::{Thread,spawn, JoinHandle}};
3use num_dual::{DualNumFloat,Dual32};
4
5pub trait Derivable<T> where T: DualNumFloat {
6    fn execute_derivative(&self) -> Self;
7    fn zeroth_derivative(&self) -> T;
8    fn first_derivative(&self) -> T;
9}
10
11pub trait Coerceable<T> where T: DualNumFloat{
12    fn coerce_to(&self) -> T;
13    fn coerce_from(value: T) -> Self;
14}
15
16impl Derivable<f32> for Dual32 {
17    fn execute_derivative(&self) -> Self {
18        return self.derivative()
19    }
20    fn zeroth_derivative(&self) -> f32 {
21        return self.re
22    }
23    fn first_derivative(&self) -> f32 {
24        return self.eps
25    }
26}
27
28impl <T: DualNumFloat> Coerceable<T> for Dual32 {
29    fn coerce_to(&self) -> T {
30        return T::from(self.re).unwrap()
31    }
32    fn coerce_from(value: T) -> Self {
33        return Dual32::from_re(value.to_f32().unwrap())
34    }
35}
36
37pub struct NewtonOptions<T> where T: DualNumFloat {
38    pub guess: T,
39    pub patience: i32,
40    pub tolerance: T
41}
42
43pub struct BisectionOptions<T> where T: DualNumFloat {
44    pub lower: T,
45    pub upper: T,
46    pub resolution: i32
47}
48
49pub struct RootSearchOptions<T> where T: DualNumFloat {
50    pub patience: i32,
51    pub tolerance: T,
52    pub lower: T,
53    pub upper: T,
54    pub resolution: i32
55}
56
57pub struct NewtonResult<T> where T: DualNumFloat {
58    pub root: Option<T>,
59    pub iterations: i32
60}
61
62pub struct BisectionResult<T> where T: DualNumFloat {
63    pub lower: T,
64    pub upper: T,
65}
66
67pub struct RootSearchResult<T> where T: DualNumFloat {
68    pub roots: Vec<T>,
69    pub bisections: Vec<BisectionResult<T>>,
70}
71
72fn newton<'a, F, N, T>(f: F, opts: NewtonOptions<T>) -> NewtonResult<T>
73where
74    F: Fn(N) -> N + Send + Sync + 'a,
75    N: Derivable<T> + Coerceable<T> + Display + Copy,
76    T: DualNumFloat
77{
78    let mut current: T = opts.guess;
79    let mut count = 0;
80    let debug_env = env::var("DEBUG");
81    let debug = match debug_env {
82        Ok(val) => val == "true",
83        Err(_) => false
84    };
85    loop {
86        count += 1;
87        let x = N::coerce_from(current).execute_derivative();
88        let z = f(x);
89        let next = x.zeroth_derivative() - z.zeroth_derivative() / z.first_derivative();
90        let diff = next - current;
91        if diff.abs() < opts.tolerance {
92            if debug {
93                println!("Found root at: {}", next);
94            }
95            return NewtonResult{
96                root: Some(next),
97                iterations: count
98            };
99        } else {
100            if count > opts.patience {
101                if debug {
102                    println!("Failed to find root with initial guess of {}", opts.guess);
103                    println!("Last iteration was: {}", current);
104                    println!("Try updating the initial guess or increasing the tolerance or patience");
105                }
106                return NewtonResult{
107                    root: None,
108                    iterations: count
109                };
110            }
111            current = next;
112        }
113    }
114}
115
116fn find_bisections<F, N, T>(f: F, opts: BisectionOptions<T>) -> Vec<BisectionResult<T>>
117where
118    F: Fn(N) -> N + Sync + Send + Copy,
119    N: Derivable<T> + Coerceable<T> + Display + Copy + Sub + Div,
120    T: DualNumFloat
121{
122    let step = (opts.upper - opts.lower) / T::from(opts.resolution).unwrap() + T::epsilon();
123    // Add off-set to step to deal with roots at middle of lower and upper range
124    let mut values: Vec<BisectionResult<T>> = Vec::new();
125
126    for i in 0..opts.resolution {
127        let a = opts.lower + step * T::from(i).unwrap();
128        let b = opts.lower + step * T::from(i+1).unwrap();
129        let fa = f(N::coerce_from(a));
130        let fb = f(N::coerce_from(b));
131        let pos2neg = fa.zeroth_derivative() > T::zero() && fb.zeroth_derivative() < T::zero();
132        let neg2pos = fa.zeroth_derivative() < T::zero() && fb.zeroth_derivative() > T::zero();
133        if pos2neg || neg2pos {
134            values.push(BisectionResult{lower: a, upper: b});
135        }
136    };
137    values
138}
139
140pub fn root_search<F, N, T>(f: F, opts: RootSearchOptions<T>) -> RootSearchResult<T>
141where
142    F: Fn(N) -> N + Sync + Send + Copy,
143    N: Derivable<T> + Coerceable<T> + Display + Copy + Sub + Div,
144    T: DualNumFloat
145{
146    if opts.lower > opts.upper {
147        panic!("Lower bound must be greater than upper bound")
148    }
149    if opts.lower == opts.upper {
150        panic!("Bounds cannot be the same")
151    }
152    let bisections = find_bisections(f, BisectionOptions{
153        lower: opts.lower,
154        upper: opts.upper,
155        resolution: opts.resolution
156    });
157    let mut roots: Vec<T> = Vec::new();
158    for bisection in &bisections {
159        let res = T::from(100).unwrap();
160        let step = (bisection.upper - bisection.lower) / res;
161        for i in 0..res.to_i32().unwrap() {
162            let guess = bisection.lower + (T::from(i).unwrap() * step);
163            let res = newton(f, NewtonOptions{
164                guess: guess,
165                patience: opts.patience,
166                tolerance: opts.tolerance
167            });
168            if res.root.is_none() {
169                break;
170            }
171            let root = res.root.unwrap();
172            if bisection.lower < root && root < bisection.upper {
173                roots.push(root);
174                break;
175            }
176        }
177
178    }
179    RootSearchResult{roots, bisections}
180}
181
#[cfg(test)]
mod tests {
    use super::*;
    use num_dual::{Dual32, DualNum};

    /// Absolute tolerance for comparing computed roots with exact values.
    ///
    /// Newton's method stops once successive estimates differ by less than the
    /// requested tolerance. Results are therefore close to, but not guaranteed
    /// to be bit-identical with, the analytic root.
    const ROOT_TOL: f32 = 1e-4;

    fn sine<D: DualNum<f32>>(x: D) -> D {
        x.sin()
    }

    fn cosine<D: DualNum<f32>>(x: D) -> D {
        x.cos()
    }

    /// Panics unless `actual` is within `ROOT_TOL` of `expected`.
    fn assert_close(expected: f32, actual: f32) {
        assert!(
            (expected - actual).abs() < ROOT_TOL,
            "expected a root near {}, got {}",
            expected,
            actual
        );
    }

    /// Panics unless some element of `roots` is within `ROOT_TOL` of `expected`.
    fn assert_has_root(roots: &[f32], expected: f32) {
        assert!(
            roots.iter().any(|root| (root - expected).abs() < ROOT_TOL),
            "expected a root near {} in {:?}",
            expected,
            roots
        );
    }

    #[test]
    fn find_sine_root_newton() {
        let res = newton::<_, Dual32, f32>(&sine, NewtonOptions {
            guess: 2.0,
            patience: 1000,
            tolerance: 0.0001,
        });
        assert_close(std::f32::consts::PI, res.root.expect("Newton should converge"));
    }

    #[test]
    fn find_cosine_root_newton() {
        let res = newton::<_, Dual32, f32>(&cosine, NewtonOptions {
            guess: 2.0,
            patience: 1000,
            tolerance: 0.0001,
        });
        assert_close(std::f32::consts::FRAC_PI_2, res.root.expect("Newton should converge"));
    }

    #[test]
    fn newton_gives_up_without_real_root() {
        // x^2 + 1 has no real root, and its derivative vanishes at the guess.
        fn shifted_parabola(x: Dual32) -> Dual32 {
            x * x + Dual32::from_re(1.0)
        }
        let res = newton::<_, Dual32, f32>(&shifted_parabola, NewtonOptions {
            guess: 0.0,
            patience: 50,
            tolerance: 0.0001,
        });
        assert!(res.root.is_none());
    }

    #[test]
    fn find_sine_bisections() {
        let bisections = find_bisections::<_, Dual32, f32>(&sine, BisectionOptions {
            lower: -5.0,
            upper: 5.0,
            resolution: 1000,
        });
        for bisection in &bisections {
            println!("bisection: ({},{})", bisection.lower, bisection.upper)
        }
        assert_eq!(bisections.len(), 3)
    }

    #[test]
    fn find_cosine_bisections() {
        let bisections = find_bisections::<_, Dual32, f32>(&cosine, BisectionOptions {
            lower: -5.0,
            upper: 5.0,
            resolution: 1000,
        });
        for bisection in &bisections {
            println!("bisection: ({},{})", bisection.lower, bisection.upper)
        }
        assert_eq!(bisections.len(), 4)
    }

    #[test]
    fn find_sine_roots() {
        let res = root_search::<_, Dual32, f32>(&sine, RootSearchOptions {
            lower: -5.0,
            upper: 5.0,
            patience: 2000,
            tolerance: 0.0001,
            resolution: 1000,
        });
        for root in &res.roots {
            println!("root: {}", root);
        }
        assert_eq!(res.roots.len(), 3);
        assert_has_root(&res.roots, std::f32::consts::PI);
        assert_has_root(&res.roots, -std::f32::consts::PI);
        assert_has_root(&res.roots, 0.0);
    }

    #[test]
    fn find_cosine_roots() {
        let res = root_search::<_, Dual32, f32>(&cosine, RootSearchOptions {
            lower: -5.0,
            upper: 5.0,
            patience: 2000,
            tolerance: 0.0001,
            resolution: 1000,
        });
        for root in &res.roots {
            println!("root: {}", root);
        }
        assert_eq!(res.roots.len(), 4);
        assert_has_root(&res.roots, std::f32::consts::FRAC_PI_2);
        assert_has_root(&res.roots, -std::f32::consts::FRAC_PI_2);
        assert_has_root(&res.roots, std::f32::consts::FRAC_PI_2 * 3.0);
        assert_has_root(&res.roots, -std::f32::consts::FRAC_PI_2 * 3.0);
    }
}