use crate::error::{MathError, MathResult};
use crate::solvers::{brent, SolverConfig, SolverResult};
/// Finds a root of `f` with a monitored Newton–Raphson iteration, falling back to
/// Brent's method when Newton cannot be trusted.
///
/// Newton is tried first from `initial_guess` using the analytic derivative `df`
/// (at most 20 iterations; see `newton_with_monitoring`). Its result is accepted when it
/// succeeds and, if `bounds` is given, the root lies in `[min(a, b), max(a, b)]`.
/// Otherwise Brent's method is run on `bounds`. If no bounds were given, it runs on an
/// interval found by expanding outward from `initial_guess`.
///
/// # Errors
///
/// Returns the error reported by `brent` on the fallback path. Returns
/// `MathError::invalid_input` if Newton failed, no `bounds` were given, and no
/// sign-changing interval could be located.
pub fn hybrid<F, DF>(
    f: F,
    df: DF,
    initial_guess: f64,
    bounds: Option<(f64, f64)>,
    config: &SolverConfig,
) -> MathResult<SolverResult>
where
    F: Fn(f64) -> f64,
    DF: Fn(f64) -> f64,
{
    // Without bounds any root is acceptable. With bounds, the root must lie inside them;
    // either ordering of the pair is accepted.
    let within_bounds = |x: f64| match bounds {
        Some((a, b)) => a.min(b) <= x && x <= a.max(b),
        None => true,
    };

    match newton_with_monitoring(&f, &df, initial_guess, config) {
        Ok(result) if within_bounds(result.root) => Ok(result),
        // Newton failed, or converged to a root the caller excluded via `bounds`.
        _ => match bounds {
            Some((a, b)) => brent(&f, a, b, config),
            None => match find_bracket(&f, initial_guess) {
                Some((a, b)) => brent(&f, a, b, config),
                None => Err(MathError::invalid_input(
                    "Newton-Raphson failed and could not find bracketing interval for Brent",
                )),
            },
        },
    }
}
/// Runs a monitored Newton–Raphson iteration starting at `initial_guess`.
///
/// At most `min(config.max_iterations, 20)` iterations are performed. Success is reported
/// in two cases:
/// - `|f(x)| < config.tolerance` at the start of an iteration. `iterations` counts the
///   steps taken so far and `residual` is the signed `f(x)`.
/// - A Newton step shorter than `config.tolerance` has just been applied. `residual` is
///   then the signed `f` evaluated at the updated point.
///
/// # Errors
///
/// Gives up early so the caller can switch to a bracketing method:
/// - `MathError::invalid_input` if `|f(x)|` more than doubles on three consecutive
///   iterations, if a step exceeds `1e10` in magnitude, or if an iterate becomes non-finite.
/// - `MathError::DivisionByZero` if `|f'(x)| < 1e-15`.
/// - `MathError::convergence_failed` if the iteration budget runs out.
fn newton_with_monitoring<F, DF>(
    f: &F,
    df: &DF,
    initial_guess: f64,
    config: &SolverConfig,
) -> MathResult<SolverResult>
where
    F: Fn(f64) -> f64,
    DF: Fn(f64) -> f64,
{
    // Consecutive residual blow-ups tolerated before declaring divergence.
    const DIVERGENCE_LIMIT: u32 = 3;
    // Derivatives smaller than this are treated as zero.
    const MIN_SLOPE: f64 = 1e-15;
    // Steps larger than this are rejected outright.
    const MAX_STEP: f64 = 1e10;

    let iteration_cap = config.max_iterations.min(20);
    let mut current = initial_guess;
    let mut last_abs = f64::MAX;
    let mut strikes: u32 = 0;

    for iter in 0..iteration_cap {
        let value = f(current);
        let abs_value = value.abs();
        if abs_value < config.tolerance {
            return Ok(SolverResult {
                root: current,
                iterations: iter,
                residual: value,
            });
        }

        // A residual that more than doubles counts as a strike; any other iteration clears them.
        strikes = if abs_value > last_abs * 2.0 { strikes + 1 } else { 0 };
        if strikes >= DIVERGENCE_LIMIT {
            return Err(MathError::invalid_input("Newton-Raphson diverging"));
        }
        last_abs = abs_value;

        let slope = df(current);
        if slope.abs() < MIN_SLOPE {
            return Err(MathError::DivisionByZero { value: slope });
        }

        let delta = value / slope;
        if delta.abs() > MAX_STEP {
            return Err(MathError::invalid_input("Newton step too large"));
        }

        current -= delta;
        if !current.is_finite() {
            return Err(MathError::invalid_input("Newton produced non-finite value"));
        }

        if delta.abs() < config.tolerance {
            return Ok(SolverResult {
                root: current,
                iterations: iter + 1,
                residual: f(current),
            });
        }
    }

    Err(MathError::convergence_failed(iteration_cap, f(current).abs()))
}
/// Searches outward from `initial_guess` for an interval over which `f` changes sign.
///
/// Two probes start at the guess, or at `-1.0` / `1.0` when the guess is within `1e-10`
/// of zero. Each round moves them apart by a width that starts at `0.1` and doubles.
/// The first candidate whose endpoint values multiply to a negative number is returned.
/// Candidates are checked in this order: `(left, initial_guess)`, then
/// `(initial_guess, right)`, then `(left, right)`.
///
/// Returns `None` if no such interval is found before the width would exceed `1e6`
/// (or after 50 rounds, whichever comes first).
fn find_bracket<F>(f: &F, initial_guess: f64) -> Option<(f64, f64)>
where
    F: Fn(f64) -> f64,
{
    let near_zero = initial_guess.abs() < 1e-10;
    let (mut left, mut right) = if near_zero {
        (-1.0, 1.0)
    } else {
        (initial_guess, initial_guess)
    };
    let f_guess = f(initial_guess);
    let opposite = |u: f64, v: f64| u * v < 0.0;

    // 0.1, 0.2, 0.4, ... — stop after 50 widths or once a width exceeds 1e6.
    let widths = std::iter::successors(Some(0.1_f64), |w| Some(w * 2.0))
        .take(50)
        .take_while(|&w| w <= 1e6);

    for width in widths {
        left -= width;
        right += width;
        let f_left = f(left);
        let f_right = f(right);

        if opposite(f_left, f_guess) {
            return Some((left, initial_guess));
        } else if opposite(f_right, f_guess) {
            return Some((initial_guess, right));
        } else if opposite(f_left, f_right) {
            return Some((left, right));
        }
    }

    None
}
/// Same as [`hybrid`], but approximates `f'` with a central finite difference.
///
/// The difference step is `1e-8` for `|x| <= 1` and `1e-8 * |x|` beyond that. With a
/// fixed `1e-8` step, `x ± h` rounds back to `x` once `|x|` exceeds about `1.3e8`
/// (2^27), so the estimated slope is zero. The slope is already badly quantized well
/// before that point. The quotient divides by the spacing actually represented in `f64`
/// rather than the nominal `2h`.
///
/// # Errors
///
/// Same as [`hybrid`].
pub fn hybrid_numerical<F>(
    f: F,
    initial_guess: f64,
    bounds: Option<(f64, f64)>,
    config: &SolverConfig,
) -> MathResult<SolverResult>
where
    F: Fn(f64) -> f64,
{
    // Base step; scaled by |x| when |x| > 1 so the sample points stay distinct.
    const H_BASE: f64 = 1e-8;
    let df = |x: f64| {
        let h = H_BASE * x.abs().max(1.0);
        let (x_hi, x_lo) = (x + h, x - h);
        // Use the realized spacing: x ± h is itself rounded, so it need not be exactly 2h.
        (f(x_hi) - f(x_lo)) / (x_hi - x_lo)
    };
    hybrid(&f, df, initial_guess, bounds, config)
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_sqrt_2() {
        let f = |x: f64| x * x - 2.0;
        let df = |x: f64| 2.0 * x;
        let result = hybrid(f, df, 1.5, Some((1.0, 2.0)), &SolverConfig::default()).unwrap();
        assert_relative_eq!(result.root, std::f64::consts::SQRT_2, epsilon = 1e-10);
    }

    #[test]
    fn test_cubic() {
        let f = |x: f64| x * x * x - x - 2.0;
        let df = |x: f64| 3.0 * x * x - 1.0;
        let result = hybrid(f, df, 1.5, Some((1.0, 2.0)), &SolverConfig::default()).unwrap();
        assert!(f(result.root).abs() < 1e-10);
    }

    #[test]
    fn test_fallback_to_brent() {
        let f = |x: f64| x * x * x - 2.0 * x - 5.0;
        let df = |x: f64| 3.0 * x * x - 2.0;
        let result = hybrid(f, df, 0.0, Some((1.0, 3.0)), &SolverConfig::default()).unwrap();
        assert!(f(result.root).abs() < 1e-10);
    }

    #[test]
    fn test_numerical_derivative() {
        let f = |x: f64| x * x - 2.0;
        let result = hybrid_numerical(f, 1.5, Some((1.0, 2.0)), &SolverConfig::default()).unwrap();
        assert_relative_eq!(result.root, std::f64::consts::SQRT_2, epsilon = 1e-8);
    }

    #[test]
    fn test_auto_bracket_finding() {
        // Newton converges directly from 1.5, so this covers the `bounds == None` success
        // path only; `test_auto_bracket_after_newton_failure` exercises the bracket search.
        let f = |x: f64| x * x - 2.0;
        let df = |x: f64| 2.0 * x;
        let result = hybrid(f, df, 1.5, None, &SolverConfig::default()).unwrap();
        assert_relative_eq!(result.root, std::f64::consts::SQRT_2, epsilon = 1e-10);
    }

    #[test]
    fn test_auto_bracket_after_newton_failure() {
        // df(0) == 0, so Newton bails out on its first step. The bracket search around 0.0
        // must then find a sign change (on the negative side) for Brent.
        let f = |x: f64| x * x - 2.0;
        let df = |x: f64| 2.0 * x;
        assert!(matches!(
            newton_with_monitoring(&f, &df, 0.0, &SolverConfig::default()),
            Err(MathError::DivisionByZero { .. })
        ));
        let result = hybrid(f, df, 0.0, None, &SolverConfig::default()).unwrap();
        assert!(result.root < 0.0);
        assert_relative_eq!(result.root, -std::f64::consts::SQRT_2, epsilon = 1e-6);
    }

    #[test]
    fn test_find_bracket_brackets_sign_change() {
        let f = |x: f64| x * x - 2.0;
        let (a, b) = find_bracket(&f, 0.0).expect("x^2 - 2 changes sign near 0");
        assert!(a < b);
        assert!(f(a) * f(b) < 0.0);
    }

    #[test]
    fn test_find_bracket_none_without_sign_change() {
        let f = |x: f64| x * x + 1.0;
        assert!(find_bracket(&f, 0.5).is_none());
    }

    #[test]
    fn test_ytm_like_calculation() {
        let target_price = 95.0;
        let coupon = 5.0;
        let face = 100.0;
        let years = 5;
        let price_from_yield = |y: f64| {
            let mut pv = 0.0;
            for t in 1..=years {
                pv += coupon / (1.0 + y).powi(t);
            }
            pv += face / (1.0 + y).powi(years);
            pv - target_price
        };
        let d_price_from_yield = |y: f64| {
            let mut dpv = 0.0;
            for t in 1..=years {
                dpv -= (t as f64) * coupon / (1.0 + y).powi(t + 1);
            }
            dpv -= (years as f64) * face / (1.0 + y).powi(years + 1);
            dpv
        };
        let result = hybrid(
            price_from_yield,
            d_price_from_yield,
            0.05,
            Some((0.0, 0.20)),
            &SolverConfig::default(),
        )
        .unwrap();
        assert!(price_from_yield(result.root).abs() < 1e-10);
        assert!(result.root > 0.05);
    }
}