advanced_algorithms/numerical/
newton_raphson.rs1use crate::{AlgorithmError, Result};
20
21pub fn find_root<F, G>(
42 f: F,
43 df: G,
44 x0: f64,
45 tolerance: f64,
46 max_iterations: usize,
47) -> Result<f64>
48where
49 F: Fn(f64) -> f64,
50 G: Fn(f64) -> f64,
51{
52 let mut x = x0;
53
54 for iteration in 0..max_iterations {
55 let fx = f(x);
56 let dfx = df(x);
57
58 if dfx.abs() < 1e-14 {
60 return Err(AlgorithmError::NumericalInstability(
61 format!("Derivative too close to zero at x = {}", x)
62 ));
63 }
64
65 let x_new = x - fx / dfx;
67
68 if x_new.is_nan() || x_new.is_infinite() {
70 return Err(AlgorithmError::ConvergenceFailure(
71 format!("Solution diverged at iteration {}", iteration)
72 ));
73 }
74
75 if (x_new - x).abs() < tolerance && fx.abs() < tolerance {
77 return Ok(x_new);
78 }
79
80 x = x_new;
81 }
82
83 Err(AlgorithmError::ConvergenceFailure(
84 format!("Failed to converge after {} iterations", max_iterations)
85 ))
86}
87
/// Tunable parameters for `NewtonRaphsonSolver`.
#[derive(Debug, Clone, PartialEq)]
pub struct NewtonRaphsonConfig {
    /// Convergence threshold; the solver requires both the step size and the residual
    /// `|f(x)|` to be strictly below this value.
    pub tolerance: f64,
    /// Maximum number of iterations the solver attempts before reporting failure.
    pub max_iterations: usize,
    /// When `true`, the solver records an `(x, f(x))` pair for every iteration.
    pub verbose: bool,
}
97
98impl Default for NewtonRaphsonConfig {
99 fn default() -> Self {
100 NewtonRaphsonConfig {
101 tolerance: 1e-10,
102 max_iterations: 100,
103 verbose: false,
104 }
105 }
106}
107
108pub struct NewtonRaphsonSolver<F, G>
110where
111 F: Fn(f64) -> f64,
112 G: Fn(f64) -> f64,
113{
114 f: F,
115 df: G,
116 config: NewtonRaphsonConfig,
117}
118
119impl<F, G> NewtonRaphsonSolver<F, G>
120where
121 F: Fn(f64) -> f64,
122 G: Fn(f64) -> f64,
123{
124 pub fn new(f: F, df: G) -> Self {
126 NewtonRaphsonSolver {
127 f,
128 df,
129 config: NewtonRaphsonConfig::default(),
130 }
131 }
132
133 pub fn with_config(mut self, config: NewtonRaphsonConfig) -> Self {
135 self.config = config;
136 self
137 }
138
139 pub fn solve(&self, x0: f64) -> Result<SolutionResult> {
141 let mut x = x0;
142 let mut history = Vec::new();
143
144 for iteration in 0..self.config.max_iterations {
145 let fx = (self.f)(x);
146 let dfx = (self.df)(x);
147
148 if self.config.verbose {
149 history.push((x, fx));
150 }
151
152 if dfx.abs() < 1e-14 {
153 return Err(AlgorithmError::NumericalInstability(
154 format!("Derivative too close to zero at x = {}", x)
155 ));
156 }
157
158 let x_new = x - fx / dfx;
159
160 if x_new.is_nan() || x_new.is_infinite() {
161 return Err(AlgorithmError::ConvergenceFailure(
162 format!("Solution diverged at iteration {}", iteration)
163 ));
164 }
165
166 if (x_new - x).abs() < self.config.tolerance && fx.abs() < self.config.tolerance {
167 return Ok(SolutionResult {
168 root: x_new,
169 iterations: iteration + 1,
170 final_error: fx.abs(),
171 history,
172 });
173 }
174
175 x = x_new;
176 }
177
178 Err(AlgorithmError::ConvergenceFailure(
179 format!("Failed to converge after {} iterations", self.config.max_iterations)
180 ))
181 }
182}
183
/// Outcome of a successful `NewtonRaphsonSolver::solve` call.
#[derive(Debug, Clone, PartialEq)]
pub struct SolutionResult {
    /// The iterate accepted as the root.
    pub root: f64,
    /// Number of iterations performed before the root was accepted (1-based).
    pub iterations: usize,
    /// Absolute residual `|f(x)|` at the iterate on which the convergence test passed;
    /// the solver does not re-evaluate `f` at `root`.
    pub final_error: f64,
    /// `(x, f(x))` pairs, one per iteration, recorded only when the solver's
    /// configuration has `verbose` set; empty otherwise.
    pub history: Vec<(f64, f64)>,
}
196
197pub fn find_roots_in_interval<F, G>(
212 f: F,
213 df: G,
214 start: f64,
215 end: f64,
216 num_points: usize,
217 tolerance: f64,
218) -> Vec<f64>
219where
220 F: Fn(f64) -> f64 + Copy,
221 G: Fn(f64) -> f64 + Copy,
222{
223 let step = (end - start) / (num_points as f64);
224 let mut roots = Vec::new();
225
226 for i in 0..num_points {
227 let x0 = start + (i as f64) * step;
228
229 if let Ok(root) = find_root(f, df, x0, tolerance, 100) {
230 let is_new = roots.iter().all(|&r: &f64| (r - root).abs() > tolerance * 10.0);
232
233 if is_new && root >= start && root <= end {
234 roots.push(root);
235 }
236 }
237 }
238
239 roots.sort_by(|a, b| a.partial_cmp(b).unwrap());
240 roots
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    /// sqrt(2) is the positive root of x^2 - 2.
    #[test]
    fn test_square_root() {
        let f = |x: f64| x * x - 2.0;
        let df = |x: f64| 2.0 * x;

        let root = find_root(f, df, 1.0, 1e-10, 100).unwrap();
        assert!((root - 2.0_f64.sqrt()).abs() < 1e-10);
    }

    /// The residual at the returned root of x^3 - x - 2 must be within tolerance.
    #[test]
    fn test_cubic() {
        let f = |x: f64| x * x * x - x - 2.0;
        let df = |x: f64| 3.0 * x * x - 1.0;

        let root = find_root(f, df, 1.5, 1e-10, 100).unwrap();
        assert!(f(root).abs() < 1e-10);
    }

    /// The struct-based solver finds the positive root of x^2 - 4.
    #[test]
    fn test_solver() {
        let f = |x: f64| x * x - 4.0;
        let df = |x: f64| 2.0 * x;

        let solver = NewtonRaphsonSolver::new(f, df);
        let result = solver.solve(1.0).unwrap();

        assert!((result.root - 2.0).abs() < 1e-10);
        assert!(result.iterations > 0);
        assert!(result.final_error < 1e-10);
        // History is only recorded when `verbose` is enabled.
        assert!(result.history.is_empty());
    }

    /// With `verbose` set, one `(x, f(x))` pair is recorded per iteration.
    #[test]
    fn test_solver_verbose_history() {
        let f = |x: f64| x * x - 4.0;
        let df = |x: f64| 2.0 * x;

        let config = NewtonRaphsonConfig {
            verbose: true,
            ..NewtonRaphsonConfig::default()
        };
        let result = NewtonRaphsonSolver::new(f, df)
            .with_config(config)
            .solve(1.0)
            .unwrap();

        assert_eq!(result.history.len(), result.iterations);
        assert_eq!(result.history[0], (1.0, -3.0));
    }

    /// Both roots of x^2 - 1 are found, de-duplicated and returned in ascending order.
    #[test]
    fn test_multiple_roots() {
        let f = |x: f64| x * x - 1.0;
        let df = |x: f64| 2.0 * x;

        let roots = find_roots_in_interval(f, df, -2.0, 2.0, 10, 1e-10);
        assert_eq!(roots.len(), 2);
        assert!((roots[0] + 1.0).abs() < 1e-9);
        assert!((roots[1] - 1.0).abs() < 1e-9);
    }

    /// A vanishing derivative away from a root is reported as a numerical instability.
    #[test]
    fn test_zero_derivative_is_reported() {
        let f = |x: f64| x * x - 1.0;
        let df = |x: f64| 2.0 * x;

        let outcome = find_root(f, df, 0.0, 1e-10, 100);
        assert!(matches!(
            outcome,
            Err(AlgorithmError::NumericalInstability(..))
        ));
    }

    /// Running out of iterations is reported as a convergence failure.
    #[test]
    fn test_iteration_budget_exhausted() {
        let f = |x: f64| x * x - 2.0;
        let df = |x: f64| 2.0 * x;

        let outcome = find_root(f, df, 1.0, 1e-10, 1);
        assert!(matches!(
            outcome,
            Err(AlgorithmError::ConvergenceFailure(..))
        ));
    }
}