1use nalgebra::{ComplexField, Const, DimMin, RealField, SMatrix, SVector};
8use num_traits::{FromPrimitive, Zero};
9
10mod polynomial;
11pub use polynomial::*;
12
13pub fn bisection<N, F>(
44 (mut left, mut right): (N, N),
45 mut f: F,
46 tol: N,
47 n_max: usize,
48) -> Result<N, String>
49where
50 N: RealField + FromPrimitive + Copy,
51 F: FnMut(N) -> N,
52{
53 if left >= right {
54 return Err("Bisection: requirement: right > left".to_owned());
55 }
56
57 let mut n = 1;
58
59 let mut f_a = f(left);
60 if (f_a * f(right)).is_sign_positive() {
61 return Err("Bisection: requirement: Signs must be different".to_owned());
62 }
63
64 let half = N::from_f64(0.5).unwrap();
65
66 let mut half_interval = (left - right) * half;
67 let mut middle = left + half_interval;
68
69 if middle.abs() <= tol {
70 return Ok(middle);
71 }
72
73 while n <= n_max {
74 let f_p = f(middle);
75 if (f_p * f_a).is_sign_positive() {
76 left = middle;
77 f_a = f_p;
78 } else {
79 right = middle;
80 }
81
82 half_interval = (right - left) * half;
83
84 let middle_new = left + half_interval;
85
86 if (middle - middle_new).abs() / middle.abs() < tol || middle_new.abs() < tol {
87 return Ok(middle_new);
88 }
89
90 middle = middle_new;
91 n += 1;
92 }
93
94 Err("Bisection: Maximum iterations exceeded".to_owned())
95}
96
97pub fn steffensen<N>(mut initial: N, f: fn(N) -> N, tol: N, n_max: usize) -> Result<N, String>
126where
127 N: RealField + FromPrimitive + Copy,
128{
129 let mut n = 0;
130
131 while n < n_max {
132 let guess = f(initial);
133 let new_guess = f(guess);
134 let diff = initial
135 - (guess - initial).powi(2) / (new_guess - N::from_f64(2.0).unwrap() * guess + initial);
136 if (diff - initial).abs() <= tol {
137 return Ok(diff);
138 }
139 initial = diff;
140 n += 1;
141 }
142
143 Err("Steffensen: Maximum number of iterations exceeded".to_owned())
144}
145
146pub fn newton<N, F, G, const S: usize>(
184 initial: &[N],
185 mut f: F,
186 mut jac: G,
187 tol: <N as ComplexField>::RealField,
188 n_max: usize,
189) -> Result<SVector<N, S>, String>
190where
191 N: ComplexField + FromPrimitive + Copy,
192 <N as ComplexField>::RealField: FromPrimitive + Copy,
193 F: FnMut(&[N]) -> SVector<N, S>,
194 G: FnMut(&[N]) -> SMatrix<N, S, S>,
195 Const<S>: DimMin<Const<S>, Output = Const<S>>,
196{
197 let mut guess = SVector::<N, S>::from_column_slice(initial);
198 let mut norm = guess.dot(&guess).sqrt().abs();
199 let mut n = 0;
200
201 if norm <= tol {
202 return Ok(guess);
203 }
204
205 while n < n_max {
206 let f_val = -f(guess.as_slice());
207 let f_deriv_val = jac(guess.as_slice());
208 let lu = f_deriv_val.lu();
209 match lu.solve(&f_val) {
210 None => return Err("newton: failed to solve linear equation".to_owned()),
211 Some(adjustment) => {
212 let new_guess = guess + adjustment;
213 let new_norm = new_guess.dot(&new_guess).sqrt().abs();
214 if ((norm - new_norm) / norm).abs() <= tol || new_norm <= tol {
215 return Ok(new_guess);
216 }
217
218 norm = new_norm;
219 guess = new_guess;
220 n += 1;
221 }
222 }
223 }
224
225 Err("Newton: Maximum iterations exceeded".to_owned())
226}
227
228fn jac_finite_diff<N, F, const S: usize>(
229 mut f: F,
230 x: &mut SVector<N, S>,
231 h: <N as ComplexField>::RealField,
232) -> SMatrix<N, S, S>
233where
234 N: ComplexField + FromPrimitive + Copy,
235 <N as ComplexField>::RealField: FromPrimitive + Copy,
236 F: FnMut(&[N]) -> SVector<N, S>,
237{
238 let mut mat = SMatrix::<N, S, S>::zero();
239 let h = N::from_real(h);
240 let denom = N::one() / (N::from_i32(2).unwrap() * h);
241
242 for col in 0..mat.row(0).len() {
243 x[col] += h;
244 let above = f(x.as_slice());
245 x[col] -= h;
246 x[col] -= h;
247 let below = f(x.as_slice());
248 x[col] += h;
249 let jac_col = (above + below) * denom;
250 for row in 0..mat.column(0).len() {
251 mat[(row, col)] = jac_col[row];
252 }
253 }
254
255 mat
256}
257
258pub fn secant<N, F, const S: usize>(
290 initial: &[N],
291 mut func: F,
292 h: <N as ComplexField>::RealField,
293 tol: <N as ComplexField>::RealField,
294 n_max: usize,
295) -> Result<SVector<N, S>, String>
296where
297 N: ComplexField + FromPrimitive + Copy,
298 <N as ComplexField>::RealField: FromPrimitive + Copy,
299 F: FnMut(&[N]) -> SVector<N, S>,
300 Const<S>: DimMin<Const<S>, Output = Const<S>>,
301{
302 let mut n = 2;
303
304 let mut guess = SVector::<N, S>::from_column_slice(initial);
305 let mut func_eval = func(guess.as_slice());
306
307 let jac = jac_finite_diff(&mut func, &mut guess, h);
308 let lu = jac.lu();
309 let try_inv = lu.try_inverse();
310 let mut jac_inv = if let Some(inv) = try_inv {
311 inv
312 } else {
313 return Err("Secant: Can not inverse finite element difference jacobian".to_owned());
314 };
315
316 let mut shift = -jac_inv * func_eval;
317 guess += &shift;
318
319 while n < n_max {
320 let func_eval_last = func_eval;
321 func_eval = func(guess.as_slice());
322 let diff = func_eval - func_eval_last;
323 let adjustment = -jac_inv * diff;
324 let s_transpose = shift.transpose();
325 let p = (-s_transpose * adjustment)[(0, 0)];
326 let u = s_transpose * jac_inv;
327 jac_inv += (shift + adjustment) * u / p;
328 shift = -&jac_inv * func_eval;
329 guess += &shift;
330 if shift.norm().abs() <= tol {
331 return Ok(guess);
332 }
333 n += 1;
334 }
335
336 Err("Secant: Maximum iterations exceeded".to_owned())
337}
338
339pub fn brent<N, F>(initial: (N, N), mut f: F, tol: N) -> Result<N, String>
356where
357 N: RealField + FromPrimitive + Copy,
358 F: FnMut(N) -> N,
359{
360 if !tol.is_sign_positive() {
361 return Err("brent: tolerance must be positive".to_owned());
362 }
363
364 let mut left = initial.0;
365 let mut right = initial.1;
366 let mut f_left = f(left);
367 let mut f_right = f(right);
368
369 if f_left.abs() < f_right.abs() {
371 std::mem::swap(&mut left, &mut right);
372 std::mem::swap(&mut f_left, &mut f_right);
373 }
374
375 if !(f_left * f_right).is_sign_negative() {
376 return Err("brent: initial guesses do not bracket root".to_owned());
377 }
378
379 let two = N::from_i32(2).unwrap();
380 let three = N::from_i32(3).unwrap();
381 let four = N::from_i32(4).unwrap();
382
383 let mut c = left;
384 let mut f_c = f_left;
385 let mut s = right - f_right * (right - left) / (f_right - f_left);
386 let mut f_s = f(s);
387 let mut mflag = true;
388 let mut d = c;
389
390 while !(f_right.abs() < tol || f_s.abs() < tol || (left - right).abs() < tol) {
391 if (f_left - f_c).abs() < tol && (f_right - f_c).abs() < tol {
392 s = (left * f_right * f_c) / ((f_left - f_right) * (f_left - f_c))
393 + (right * f_left * f_c) / ((f_right - f_left) * (f_right - f_c))
394 + (c * f_left * f_right) / ((f_c - f_left) * (f_c - f_right));
395 } else {
396 s = right - f_right * (right - left) / (f_right - f_left);
397 }
398
399 if !(s >= (three * left + right) / four && s <= right)
400 || (mflag && (s - right).abs() >= (right - c) / two)
401 || (!mflag && (s - right).abs() >= (c - d).abs() / two)
402 || (mflag && (right - c).abs() < tol)
403 || (!mflag && (c - d).abs() < tol)
404 {
405 s = (left + right) / two;
406 mflag = true;
407 } else {
408 mflag = false;
409 }
410
411 f_s = f(s);
412 d = c;
413 c = right;
414 f_c = f_right;
415 if (f_left * f_s).is_sign_negative() {
416 right = s;
417 f_right = f_s;
418 } else {
419 left = s;
420 f_left = f_s;
421 }
422
423 if f_left.abs() < f_right.abs() {
424 std::mem::swap(&mut left, &mut right);
425 std::mem::swap(&mut f_left, &mut f_right);
426 }
427 }
428
429 if f_s.abs() < tol {
430 Ok(s)
431 } else {
432 Ok(right)
433 }
434}
435
436pub fn itp<N, F>(initial: (N, N), mut f: F, k_1: N, k_2: N, n_0: N, tol: N) -> Result<N, String>
458where
459 N: RealField + FromPrimitive + Copy,
460 F: FnMut(N) -> N,
461{
462 if !tol.is_sign_positive() {
463 return Err("itp: tolerance must be positive".to_owned());
464 }
465
466 if !k_1.is_sign_positive() {
467 return Err("itp: k_1 must be positive".to_owned());
468 }
469
470 if k_2 <= N::one() || k_2 >= (N::one() + N::from_f64(0.5 * (1.0 + 5.0_f64.sqrt())).unwrap()) {
471 return Err("itp: k_2 must be in (1, 1 + golden_ratio)".to_owned());
472 }
473
474 let mut left = initial.0;
475 let mut right = initial.1;
476 let mut f_left = f(left);
477 let mut f_right = f(right);
478
479 if !(f_left * f_right).is_sign_negative() {
480 return Err("itp: initial guesses must bracket root".to_owned());
481 }
482
483 if f_left.is_sign_positive() {
484 std::mem::swap(&mut left, &mut right);
485 std::mem::swap(&mut f_left, &mut f_right);
486 }
487
488 let two = N::from_i32(2).unwrap();
489
490 let n_half = ((right - left).abs() / (two * tol)).log2().ceil();
491 let n_max = n_half + n_0;
492 let mut j = 0;
493
494 while (right - left).abs() > two * tol {
495 let x_half = (left + right) / two;
496 let r = tol * two.powf(n_max + n_0 - N::from_i32(j).unwrap()) - (right - left) / two;
497 let x_f = (f_right * left - f_left * right) / (f_right - f_left);
498 let sigma = (x_half - x_f) / (x_half - x_f).abs();
499 let delta = k_1 * (right - left).powf(k_2);
500 let x_t = if delta <= (x_half - x_f).abs() {
501 x_f + sigma * delta
502 } else {
503 x_half
504 };
505 let x_itp = if (x_t - x_half).abs() <= r {
506 x_t
507 } else {
508 x_half - sigma * r
509 };
510 let f_itp = f(x_itp);
511 if f_itp.is_sign_positive() {
512 right = x_itp;
513 f_right = f_itp;
514 } else if f_itp.is_sign_negative() {
515 left = x_itp;
516 f_left = f_itp;
517 } else {
518 left = x_itp;
519 right = x_itp;
520 }
521 j += 1;
522 }
523
524 Ok((left + right) / two)
525}