1use std::{env, fmt::Display, ops::{Sub, Div}};
2use num_dual::{DualNumFloat,Dual32};
4
5pub trait Derivable<T> where T: DualNumFloat {
6 fn execute_derivative(&self) -> Self;
7 fn zeroth_derivative(&self) -> T;
8 fn first_derivative(&self) -> T;
9}
10
11pub trait Coerceable<T> where T: DualNumFloat{
12 fn coerce_to(&self) -> T;
13 fn coerce_from(value: T) -> Self;
14}
15
16impl Derivable<f32> for Dual32 {
17 fn execute_derivative(&self) -> Self {
18 return self.derivative()
19 }
20 fn zeroth_derivative(&self) -> f32 {
21 return self.re
22 }
23 fn first_derivative(&self) -> f32 {
24 return self.eps
25 }
26}
27
28impl <T: DualNumFloat> Coerceable<T> for Dual32 {
29 fn coerce_to(&self) -> T {
30 return T::from(self.re).unwrap()
31 }
32 fn coerce_from(value: T) -> Self {
33 return Dual32::from_re(value.to_f32().unwrap())
34 }
35}
36
37pub struct NewtonOptions<T> where T: DualNumFloat {
38 pub guess: T,
39 pub patience: i32,
40 pub tolerance: T
41}
42
43pub struct BisectionOptions<T> where T: DualNumFloat {
44 pub lower: T,
45 pub upper: T,
46 pub resolution: i32
47}
48
49pub struct RootSearchOptions<T> where T: DualNumFloat {
50 pub patience: i32,
51 pub tolerance: T,
52 pub lower: T,
53 pub upper: T,
54 pub resolution: i32
55}
56
57pub struct NewtonResult<T> where T: DualNumFloat {
58 pub root: Option<T>,
59 pub iterations: i32
60}
61
62pub struct BisectionResult<T> where T: DualNumFloat {
63 pub lower: T,
64 pub upper: T,
65}
66
67pub struct RootSearchResult<T> where T: DualNumFloat {
68 pub roots: Vec<T>,
69 pub bisections: Vec<BisectionResult<T>>,
70}
71
72fn newton<'a, F, N, T>(f: F, opts: NewtonOptions<T>) -> NewtonResult<T>
73where
74 F: Fn(N) -> N + Send + Sync + 'a,
75 N: Derivable<T> + Coerceable<T> + Display + Copy,
76 T: DualNumFloat
77{
78 let mut current: T = opts.guess;
79 let mut count = 0;
80 let debug_env = env::var("DEBUG");
81 let debug = match debug_env {
82 Ok(val) => val == "true",
83 Err(_) => false
84 };
85 loop {
86 count += 1;
87 let x = N::coerce_from(current).execute_derivative();
88 let z = f(x);
89 let next = x.zeroth_derivative() - z.zeroth_derivative() / z.first_derivative();
90 let diff = next - current;
91 if diff.abs() < opts.tolerance {
92 if debug {
93 println!("Found root at: {}", next);
94 }
95 return NewtonResult{
96 root: Some(next),
97 iterations: count
98 };
99 } else {
100 if count > opts.patience {
101 if debug {
102 println!("Failed to find root with initial guess of {}", opts.guess);
103 println!("Last iteration was: {}", current);
104 println!("Try updating the initial guess or increasing the tolerance or patience");
105 }
106 return NewtonResult{
107 root: None,
108 iterations: count
109 };
110 }
111 current = next;
112 }
113 }
114}
115
116fn find_bisections<F, N, T>(f: F, opts: BisectionOptions<T>) -> Vec<BisectionResult<T>>
117where
118 F: Fn(N) -> N + Sync + Send + Copy,
119 N: Derivable<T> + Coerceable<T> + Display + Copy + Sub + Div,
120 T: DualNumFloat
121{
122 let step = (opts.upper - opts.lower) / T::from(opts.resolution).unwrap() + T::epsilon();
123 let mut values: Vec<BisectionResult<T>> = Vec::new();
125
126 for i in 0..opts.resolution {
127 let a = opts.lower + step * T::from(i).unwrap();
128 let b = opts.lower + step * T::from(i+1).unwrap();
129 let fa = f(N::coerce_from(a));
130 let fb = f(N::coerce_from(b));
131 let pos2neg = fa.zeroth_derivative() > T::zero() && fb.zeroth_derivative() < T::zero();
132 let neg2pos = fa.zeroth_derivative() < T::zero() && fb.zeroth_derivative() > T::zero();
133 if pos2neg || neg2pos {
134 values.push(BisectionResult{lower: a, upper: b});
135 }
136 };
137 values
138}
139
140pub fn root_search<F, N, T>(f: F, opts: RootSearchOptions<T>) -> RootSearchResult<T>
141where
142 F: Fn(N) -> N + Sync + Send + Copy,
143 N: Derivable<T> + Coerceable<T> + Display + Copy + Sub + Div,
144 T: DualNumFloat
145{
146 if opts.lower > opts.upper {
147 panic!("Lower bound must be greater than upper bound")
148 }
149 if opts.lower == opts.upper {
150 panic!("Bounds cannot be the same")
151 }
152 let bisections = find_bisections(f, BisectionOptions{
153 lower: opts.lower,
154 upper: opts.upper,
155 resolution: opts.resolution
156 });
157 let mut roots: Vec<T> = Vec::new();
158 for bisection in &bisections {
159 let res = T::from(100).unwrap();
160 let step = (bisection.upper - bisection.lower) / res;
161 for i in 0..res.to_i32().unwrap() {
162 let guess = bisection.lower + (T::from(i).unwrap() * step);
163 let res = newton(f, NewtonOptions{
164 guess: guess,
165 patience: opts.patience,
166 tolerance: opts.tolerance
167 });
168 if res.root.is_none() {
169 break;
170 }
171 let root = res.root.unwrap();
172 if bisection.lower < root && root < bisection.upper {
173 roots.push(root);
174 break;
175 }
176 }
177
178 }
179 RootSearchResult{roots, bisections}
180}
181
#[cfg(test)]
mod tests {
    use super::*;
    use num_dual::{Dual32, DualNum};

    /// Largest absolute error accepted when comparing a computed root with
    /// its analytic value. The solver is iterative and works in `f32`, so
    /// exact equality is not a reliable check.
    const ROOT_EPS: f32 = 1e-4;

    /// Sine, generic over the dual-number type so it can be differentiated.
    fn sine<D: DualNum<f32>>(x: D) -> D {
        x.sin()
    }

    /// Cosine, generic over the dual-number type so it can be differentiated.
    fn cosine<D: DualNum<f32>>(x: D) -> D {
        x.cos()
    }

    /// Asserts that `actual` is within `ROOT_EPS` of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < ROOT_EPS,
            "expected {} to be within {} of {}",
            actual,
            ROOT_EPS,
            expected
        );
    }

    /// Asserts that some element of `roots` is within `ROOT_EPS` of `expected`.
    fn assert_has_root(roots: &[f32], expected: f32) {
        assert!(
            roots.iter().any(|root| (root - expected).abs() < ROOT_EPS),
            "no root within {} of {} in {:?}",
            ROOT_EPS,
            expected,
            roots
        );
    }

    #[test]
    fn find_sine_root_newton() {
        let res = newton::<_, Dual32, f32>(&sine, NewtonOptions {
            guess: 2.0,
            patience: 1000,
            tolerance: 0.0001,
        });
        assert_close(res.root.unwrap(), std::f32::consts::PI);
    }

    #[test]
    fn find_cosine_root_newton() {
        let res = newton::<_, Dual32, f32>(&cosine, NewtonOptions {
            guess: 2.0,
            patience: 1000,
            tolerance: 0.0001,
        });
        assert_close(res.root.unwrap(), std::f32::consts::FRAC_PI_2);
    }

    #[test]
    fn find_sine_bisections() {
        let bisections = find_bisections::<_, Dual32, f32>(&sine, BisectionOptions {
            lower: -5.0,
            upper: 5.0,
            resolution: 1000,
        });
        for bisection in &bisections {
            println!("bisection: ({},{})", bisection.lower, bisection.upper)
        }
        assert_eq!(bisections.len(), 3)
    }

    #[test]
    fn find_cosine_bisections() {
        let bisections = find_bisections::<_, Dual32, f32>(&cosine, BisectionOptions {
            lower: -5.0,
            upper: 5.0,
            resolution: 1000,
        });
        for bisection in &bisections {
            println!("bisection: ({},{})", bisection.lower, bisection.upper)
        }
        assert_eq!(bisections.len(), 4)
    }

    #[test]
    fn find_sine_roots() {
        let res = root_search::<_, Dual32, f32>(&sine, RootSearchOptions {
            lower: -5.0,
            upper: 5.0,
            patience: 2000,
            tolerance: 0.0001,
            resolution: 1000,
        });
        for root in &res.roots {
            println!("root: {}", root);
        }
        assert_eq!(res.roots.len(), 3);
        assert_has_root(&res.roots, std::f32::consts::PI);
        assert_has_root(&res.roots, -std::f32::consts::PI);
        assert_has_root(&res.roots, 0.0);
    }

    #[test]
    fn find_cosine_roots() {
        let res = root_search::<_, Dual32, f32>(&cosine, RootSearchOptions {
            lower: -5.0,
            upper: 5.0,
            patience: 2000,
            tolerance: 0.0001,
            resolution: 1000,
        });
        for root in &res.roots {
            println!("root: {}", root);
        }
        assert_eq!(res.roots.len(), 4);
        assert_has_root(&res.roots, std::f32::consts::FRAC_PI_2);
        assert_has_root(&res.roots, -std::f32::consts::FRAC_PI_2);
        assert_has_root(&res.roots, std::f32::consts::FRAC_PI_2 * 3.0);
        assert_has_root(&res.roots, -std::f32::consts::FRAC_PI_2 * 3.0);
    }
}