use crate::ast::{BinaryOp, Equation, Expression, Variable};
use crate::resolution_path::{Operation, ResolutionPath};
use std::collections::HashMap;
/// Error conditions produced by the numerical solvers in this module.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub enum NumericalError {
/// The iteration limit was exhausted before meeting the tolerance.
NoConvergence,
/// A derivative vanished, a denominator degenerated, or an iterate
/// became non-finite.
Unstable,
/// The starting point was unusable for the chosen method.
InvalidInitialGuess,
/// Evaluating the expression (or its derivative) returned no value.
EvaluationFailed(String),
/// Any other failure, described by the contained message.
Other(String),
}
impl std::fmt::Display for NumericalError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
NumericalError::NoConvergence => write!(f, "Failed to converge within iteration limit"),
NumericalError::Unstable => write!(f, "Numerical instability detected"),
NumericalError::InvalidInitialGuess => write!(f, "Invalid initial guess"),
NumericalError::EvaluationFailed(msg) => {
write!(f, "Function evaluation failed: {}", msg)
}
NumericalError::Other(msg) => write!(f, "{}", msg),
}
}
}
// Marker impl so NumericalError works with `?` and `Box<dyn Error>`.
impl std::error::Error for NumericalError {}
/// Convenience alias used by every solver in this module.
pub type NumericalResult<T> = Result<T, NumericalError>;
/// Tunable parameters shared by all solvers in this module.
#[derive(Debug, Clone)]
pub struct NumericalConfig {
/// Hard cap on solver iterations before giving up with `NoConvergence`.
pub max_iterations: usize,
/// Convergence threshold applied to residuals, step sizes and intervals.
pub tolerance: f64,
/// Starting point for iterative methods; solvers fall back to their own
/// default (1.0 or 0.0) when `None`.
pub initial_guess: Option<f64>,
/// Step width for finite-difference approximations.
/// NOTE(review): `compute_derivative_fd` takes `h` as an argument and no
/// visible code reads this field — confirm it is still wired up.
pub step_size: f64,
}
impl Default for NumericalConfig {
fn default() -> Self {
Self {
max_iterations: 1000,
tolerance: 1e-10,
initial_guess: None,
step_size: 1e-6,
}
}
}
/// Outcome of a scalar root-finding run.
#[derive(Debug, Clone, PartialEq)]
pub struct NumericalSolution {
/// The root estimate.
pub value: f64,
/// Number of iterations actually performed.
pub iterations: usize,
/// Magnitude of f at the last evaluated iterate.
pub residual: f64,
/// Whether a convergence criterion was met; the solvers in this file
/// return an error instead of a non-converged solution.
pub converged: bool,
}
/// Scalar root finder using Newton-Raphson iteration with a symbolically
/// computed derivative (see `NewtonRaphson::solve`).
#[derive(Debug)]
pub struct NewtonRaphson {
config: NumericalConfig,
}
impl NewtonRaphson {
    /// Creates a solver with an explicit configuration.
    pub fn new(config: NumericalConfig) -> Self {
        Self { config }
    }
    /// Creates a solver using `NumericalConfig::default()`.
    pub fn with_default_config() -> Self {
        Self {
            config: NumericalConfig::default(),
        }
    }
    /// Solves `equation` for `variable` by Newton-Raphson iteration.
    ///
    /// Forms f(x) = left - right, differentiates it symbolically, then
    /// iterates x_{n+1} = x_n - f(x_n)/f'(x_n) until either |f(x)| or the
    /// step size |Δx| drops below the configured tolerance. Progress is
    /// recorded on the returned `ResolutionPath` (every 10th iteration).
    ///
    /// # Errors
    /// - `EvaluationFailed` if f or f' cannot be evaluated at an iterate.
    /// - `Unstable` if f'(x) is near zero or an iterate becomes non-finite.
    /// - `NoConvergence` if the iteration limit is exhausted.
    pub fn solve(
        &self,
        equation: &Equation,
        variable: &Variable,
    ) -> NumericalResult<(NumericalSolution, ResolutionPath)> {
        use crate::resolution_path::ResolutionPathBuilder;
        let f = Expression::Binary(
            crate::ast::BinaryOp::Sub,
            Box::new(equation.left.clone()),
            Box::new(equation.right.clone()),
        );
        let f_prime = f.differentiate(&variable.name);
        let mut x = self.config.initial_guess.unwrap_or(1.0);
        let mut path = ResolutionPathBuilder::new(f.clone());
        // Fixed: the original literal contained mojibake ("xâ‚€") from a bad
        // encoding round-trip; it is meant to read "x₀".
        path = path.step(
            Operation::NumericalApproximation,
            format!(
                "Starting Newton-Raphson method with initial guess x₀ = {}",
                x
            ),
            Expression::Float(x),
        );
        path = path.step(
            Operation::NumericalApproximation,
            format!("Using symbolic derivative: f'(x) = {}", f_prime),
            f_prime.clone(),
        );
        let mut converged = false;
        let mut iterations = 0;
        let mut residual = 0.0;
        for i in 0..self.config.max_iterations {
            iterations = i + 1;
            let mut vars = HashMap::new();
            vars.insert(variable.name.clone(), x);
            let fx = f.evaluate(&vars).ok_or_else(|| {
                NumericalError::EvaluationFailed(format!(
                    "Failed to evaluate function at x = {}",
                    x
                ))
            })?;
            residual = fx.abs();
            if residual < self.config.tolerance {
                converged = true;
                path = path.step(
                    Operation::NumericalApproximation,
                    format!(
                        "Converged: |f(x)| = {} < {}",
                        residual, self.config.tolerance
                    ),
                    Expression::Float(x),
                );
                break;
            }
            let derivative = f_prime.evaluate(&vars).ok_or_else(|| {
                NumericalError::EvaluationFailed(format!(
                    "Failed to evaluate derivative at x = {}",
                    x
                ))
            })?;
            // A vanishing derivative would blow up the Newton step.
            if derivative.abs() < 1e-14 {
                return Err(NumericalError::Unstable);
            }
            let x_next = x - fx / derivative;
            if !x_next.is_finite() {
                return Err(NumericalError::Unstable);
            }
            if i % 10 == 0 || i == self.config.max_iterations - 1 {
                path = path.step(
                    Operation::NumericalApproximation,
                    format!(
                        "Iteration {}: x = {}, f(x) = {}, f'(x) = {}",
                        iterations, x_next, fx, derivative
                    ),
                    Expression::Float(x_next),
                );
            }
            // Fixed: measure the step BEFORE overwriting x. The original code
            // assigned x = x_next first and then logged (x_next - x).abs(),
            // so the convergence message always printed |Δx| = 0.
            let dx = (x_next - x).abs();
            x = x_next;
            if dx < self.config.tolerance {
                converged = true;
                path = path.step(
                    Operation::NumericalApproximation,
                    format!("Converged: |Δx| = {} < {}", dx, self.config.tolerance),
                    Expression::Float(x),
                );
                break;
            }
        }
        if !converged {
            return Err(NumericalError::NoConvergence);
        }
        let solution = NumericalSolution {
            value: x,
            iterations,
            residual,
            converged,
        };
        let final_path = path.finish(Expression::Float(x));
        Ok((solution, final_path))
    }
}
/// Scalar root finder using the secant method (derivative-free Newton).
#[derive(Debug)]
pub struct SecantMethod {
    config: NumericalConfig,
}
impl SecantMethod {
    /// Creates a solver with an explicit configuration.
    pub fn new(config: NumericalConfig) -> Self {
        SecantMethod { config }
    }
    /// Creates a solver using `NumericalConfig::default()`.
    pub fn with_default_config() -> Self {
        SecantMethod {
            config: NumericalConfig::default(),
        }
    }
    /// Solves `equation` for `variable` with the secant method, starting
    /// from the two iterates in `initial_points`.
    pub fn solve(
        &self,
        equation: &Equation,
        variable: &Variable,
        initial_points: (f64, f64),
    ) -> NumericalResult<(NumericalSolution, ResolutionPath)> {
        use crate::resolution_path::ResolutionPathBuilder;
        let (f, eval) = secant_make_eval(equation, variable);
        let mut x_prev = initial_points.0;
        let mut x_curr = initial_points.1;
        let mut f_prev = eval(x_prev)?;
        let mut f_curr = eval(x_curr)?;
        let mut path = ResolutionPathBuilder::new(f).step(
            Operation::NumericalApproximation,
            format!(
                "Starting secant: x0={x_prev}, x1={x_curr}: f(x0)={f_prev:.6e}, f(x1)={f_curr:.6e}"
            ),
            Expression::Float(x_curr),
        );
        // The iteration loop mutates the two-point window in place and
        // appends its own intermediate path steps.
        let (solution, x_final) = secant_iterate(
            &eval,
            &mut path,
            &mut x_prev,
            &mut x_curr,
            &mut f_prev,
            &mut f_curr,
            self.config.max_iterations,
            self.config.tolerance,
        )?;
        let finished = path.step(
            Operation::NumericalApproximation,
            format!("Converged: x={x_final:.15}, |f(x)|={:.6e}", f_curr.abs()),
            Expression::Float(x_final),
        );
        Ok((solution, finished.finish(Expression::Float(x_final))))
    }
}
#[derive(Debug)]
pub struct BisectionMethod {
config: NumericalConfig,
}
impl BisectionMethod {
pub fn new(config: NumericalConfig) -> Self {
Self { config }
}
pub fn with_default_config() -> Self {
Self {
config: NumericalConfig::default(),
}
}
pub fn solve(
&self,
equation: &Equation,
variable: &Variable,
interval: (f64, f64),
) -> NumericalResult<(NumericalSolution, ResolutionPath)> {
use crate::resolution_path::ResolutionPathBuilder;
let f = Expression::Binary(
crate::ast::BinaryOp::Sub,
Box::new(equation.left.clone()),
Box::new(equation.right.clone()),
);
let mut a = interval.0;
let mut b = interval.1;
if a > b {
std::mem::swap(&mut a, &mut b);
}
let mut path = ResolutionPathBuilder::new(f.clone());
let mut vars = HashMap::new();
vars.insert(variable.name.clone(), a);
let fa = f.evaluate(&vars).ok_or_else(|| {
NumericalError::EvaluationFailed(format!("Failed to evaluate function at x = {}", a))
})?;
vars.insert(variable.name.clone(), b);
let fb = f.evaluate(&vars).ok_or_else(|| {
NumericalError::EvaluationFailed(format!("Failed to evaluate function at x = {}", b))
})?;
if fa * fb > 0.0 {
return Err(NumericalError::Other(format!(
"Bisection requires f(a) and f(b) to have opposite signs. f({}) = {}, f({}) = {}",
a, fa, b, fb
)));
}
path = path.step(
Operation::NumericalApproximation,
format!(
"Starting bisection method on interval [{}, {}]. f({}) = {}, f({}) = {}",
a, b, a, fa, b, fb
),
Expression::Float((a + b) / 2.0),
);
let mut iterations = 0;
let mut c = (a + b) / 2.0;
let mut fc = 0.0;
for i in 0..self.config.max_iterations {
iterations = i + 1;
c = (a + b) / 2.0;
vars.insert(variable.name.clone(), c);
fc = f.evaluate(&vars).ok_or_else(|| {
NumericalError::EvaluationFailed(format!(
"Failed to evaluate function at x = {}",
c
))
})?;
if fc.abs() < self.config.tolerance {
path = path.step(
Operation::NumericalApproximation,
format!(
"Converged: |f({})| = {} < {}",
c,
fc.abs(),
self.config.tolerance
),
Expression::Float(c),
);
break;
}
if (b - a) / 2.0 < self.config.tolerance {
path = path.step(
Operation::NumericalApproximation,
format!(
"Converged: interval width {} < {}",
(b - a) / 2.0,
self.config.tolerance
),
Expression::Float(c),
);
break;
}
vars.insert(variable.name.clone(), a);
let fa_curr = f.evaluate(&vars).ok_or_else(|| {
NumericalError::EvaluationFailed(format!(
"Failed to evaluate function at x = {}",
a
))
})?;
if fa_curr * fc < 0.0 {
b = c;
} else {
a = c;
}
if i % 10 == 0 || i == self.config.max_iterations - 1 {
path = path.step(
Operation::NumericalApproximation,
format!(
"Iteration {}: interval = [{}, {}], midpoint = {}, f(midpoint) = {}",
iterations, a, b, c, fc
),
Expression::Float(c),
);
}
}
let solution = NumericalSolution {
value: c,
iterations,
residual: fc.abs(),
converged: fc.abs() < self.config.tolerance || (b - a) / 2.0 < self.config.tolerance,
};
if !solution.converged {
return Err(NumericalError::NoConvergence);
}
let final_path = path.finish(Expression::Float(c));
Ok((solution, final_path))
}
}
/// Root finder implementing Brent's method: combines bisection, secant and
/// inverse quadratic interpolation (see the `brent_*` helpers below).
#[derive(Debug)]
pub struct BrentsMethod {
config: NumericalConfig,
}
impl BrentsMethod {
/// Creates a solver with an explicit configuration.
pub fn new(config: NumericalConfig) -> Self {
Self { config }
}
/// Creates a solver using `NumericalConfig::default()`.
pub fn with_default_config() -> Self {
Self {
config: NumericalConfig::default(),
}
}
/// Solves `equation` for `variable` on a bracketing `interval`.
///
/// f = left - right must change sign over the interval. Converges when
/// |f(b)| or the bracket width drops below the configured tolerance.
///
/// # Errors
/// - `Other` if the interval does not bracket a sign change.
/// - `EvaluationFailed` if the function cannot be evaluated.
/// - `NoConvergence` if the iteration limit is reached.
pub fn solve(
&self,
equation: &Equation,
variable: &Variable,
interval: (f64, f64),
) -> NumericalResult<(NumericalSolution, ResolutionPath)> {
use crate::resolution_path::ResolutionPathBuilder;
let (f, eval) = brent_make_eval(equation, variable);
// After init, b is the better estimate: |f(b)| <= |f(a)|.
let (mut a, mut b, mut fa, mut fb) = brent_init_bracket(interval, &eval)?;
let mut path = ResolutionPathBuilder::new(f);
path = path.step(
Operation::NumericalApproximation,
format!("Starting Brent on [{a}, {b}]: f(a)={fa:.6e}, f(b)={fb:.6e}"),
Expression::Float(b),
);
// Seed the history: c mirrors a, both step records span the bracket.
let mut st = BrentState {
c: a,
fc: fa,
d: b - a,
e: b - a,
mflag: true,
};
let mut iterations = 0;
let mut converged = false;
for i in 0..self.config.max_iterations {
iterations = i + 1;
// b always carries the best estimate, so test its residual first.
if fb.abs() < self.config.tolerance {
converged = true;
break;
}
// Pick interpolation or forced bisection for the next trial point.
let (s, bisected) = brent_next_point(a, b, fa, fb, &st, self.config.tolerance);
st.mflag = bisected;
let fs = eval(s)?;
// Log every 10th iteration to keep the resolution path compact.
if i % 10 == 0 {
let method = if bisected { "bisect" } else { "interpolate" };
path = path.step(
Operation::NumericalApproximation,
format!("Iter {iterations}: x={s:.10}, f(x)={fs:.6e} [{method}]"),
Expression::Float(s),
);
}
brent_update_bracket(&mut a, &mut b, &mut st, &mut fa, &mut fb, s, fs);
if (b - a).abs() < self.config.tolerance {
converged = true;
break;
}
}
if !converged {
return Err(NumericalError::NoConvergence);
}
path = path.step(
Operation::NumericalApproximation,
format!("Converged: x={b:.15}, |f(x)|={:.6e}", fb.abs()),
Expression::Float(b),
);
let sol = NumericalSolution {
value: b,
iterations,
residual: fb.abs(),
converged,
};
Ok((sol, path.finish(Expression::Float(b))))
}
}
/// Fixed-step gradient-descent minimizer using symbolic gradients.
#[derive(Debug)]
pub struct GradientDescent {
    config: NumericalConfig,
    learning_rate: f64,
}
impl GradientDescent {
    /// Creates an optimizer with the given configuration and learning rate.
    pub fn new(config: NumericalConfig, learning_rate: f64) -> Self {
        GradientDescent {
            config,
            learning_rate,
        }
    }
    /// Minimizes `expression` over `variables` by gradient descent.
    ///
    /// Returns the variable assignment reached once the objective value
    /// changes by less than the configured tolerance between iterations.
    pub fn minimize(
        &self,
        expression: &Expression,
        variables: &[Variable],
    ) -> NumericalResult<HashMap<Variable, f64>> {
        if variables.is_empty() {
            return Err(NumericalError::Other(
                "No variables to optimize".to_string(),
            ));
        }
        // One symbolic partial derivative per variable, computed up front.
        let gradients: Vec<Expression> = variables
            .iter()
            .map(|v| expression.differentiate(&v.name))
            .collect();
        let start = self.config.initial_guess.unwrap_or(1.0);
        let mut values: HashMap<String, f64> = variables
            .iter()
            .map(|v| (v.name.clone(), start))
            .collect();
        let mut last_objective = f64::INFINITY;
        for _ in 0..self.config.max_iterations {
            let objective = expression.evaluate(&values).ok_or_else(|| {
                NumericalError::EvaluationFailed("Failed to evaluate objective".to_string())
            })?;
            if (last_objective - objective).abs() < self.config.tolerance {
                return Ok(variables
                    .iter()
                    .map(|v| (v.clone(), values[&v.name]))
                    .collect());
            }
            last_objective = objective;
            for (var, grad_expr) in variables.iter().zip(gradients.iter()) {
                let grad = grad_expr.evaluate(&values).ok_or_else(|| {
                    NumericalError::EvaluationFailed(format!(
                        "Failed to evaluate gradient for {}",
                        var.name
                    ))
                })?;
                if !grad.is_finite() {
                    return Err(NumericalError::Unstable);
                }
                let updated = values[&var.name] - self.learning_rate * grad;
                values.insert(var.name.clone(), updated);
            }
        }
        Err(NumericalError::NoConvergence)
    }
}
/// Damped Gauss-Newton (Levenberg-Marquardt) least-squares solver for a
/// system of equations in several variables.
#[derive(Debug)]
pub struct LevenbergMarquardt {
config: NumericalConfig,
}
impl LevenbergMarquardt {
/// Creates a solver with an explicit configuration.
pub fn new(config: NumericalConfig) -> Self {
Self { config }
}
/// Creates a solver using `NumericalConfig::default()`.
pub fn with_default_config() -> Self {
Self {
config: NumericalConfig::default(),
}
}
/// Minimizes the sum of squared residuals (left - right) over `variables`.
///
/// Builds the residual vector and its symbolic Jacobian, then iterates
/// damped normal-equation steps (JᵀJ + λI)δ = -Jᵀr, halving λ after an
/// accepted step and doubling it after a rejected one.
///
/// # Errors
/// - `Other` for empty inputs or a singular damped normal matrix.
/// - `Unstable` when a residual evaluates to a non-finite value.
/// - `NoConvergence` when the iteration limit is exhausted.
pub fn solve_least_squares(
&self,
equations: &[Equation],
variables: &[Variable],
) -> NumericalResult<HashMap<Variable, f64>> {
if equations.is_empty() || variables.is_empty() {
return Err(NumericalError::Other(
"Need at least one equation and one variable".to_string(),
));
}
let n = variables.len();
// Residual r_i = left_i - right_i for each equation.
let residuals: Vec<Expression> = equations
.iter()
.map(|eq| {
Expression::Binary(
BinaryOp::Sub,
Box::new(eq.left.clone()),
Box::new(eq.right.clone()),
)
})
.collect();
// Symbolic Jacobian: J[i][j] = ∂r_i/∂x_j.
let jacobian: Vec<Vec<Expression>> = residuals
.iter()
.map(|r| variables.iter().map(|v| r.differentiate(&v.name)).collect())
.collect();
let mut vals: HashMap<String, f64> = variables
.iter()
.map(|v| (v.name.clone(), self.config.initial_guess.unwrap_or(0.0)))
.collect();
// Damping factor: large λ behaves like gradient descent, small like
// plain Gauss-Newton.
let mut lambda = 1e-3_f64;
for _iter in 0..self.config.max_iterations {
let r_vals: Vec<f64> = residuals
.iter()
.map(|r| r.evaluate(&vals).unwrap_or(f64::NAN))
.collect();
if r_vals.iter().any(|v| !v.is_finite()) {
return Err(NumericalError::Unstable);
}
let cost: f64 = r_vals.iter().map(|v| v * v).sum();
// Converged when the summed squared residual is below tolerance².
if cost < self.config.tolerance * self.config.tolerance {
return Ok(variables
.iter()
.map(|v| (v.clone(), vals[&v.name]))
.collect());
}
// NOTE(review): a failed Jacobian evaluation silently becomes 0.0
// here, while failed residuals abort with Unstable — confirm intended.
let j_vals: Vec<Vec<f64>> = jacobian
.iter()
.map(|row| {
row.iter()
.map(|e| e.evaluate(&vals).unwrap_or(0.0))
.collect()
})
.collect();
// Accumulate the normal equations: jtj = JᵀJ, jtr = Jᵀr.
let mut jtj = vec![vec![0.0; n]; n];
let mut jtr = vec![0.0; n];
for (i, j_row) in j_vals.iter().enumerate() {
for j in 0..n {
jtr[j] += j_row[j] * r_vals[i];
for k in 0..n {
jtj[j][k] += j_row[j] * j_row[k];
}
}
}
// Damping: add λ to the diagonal before solving.
for j in 0..n {
jtj[j][j] += lambda;
}
let delta = solve_linear_nxn(&jtj, &jtr.iter().map(|v| -v).collect::<Vec<_>>())
.ok_or_else(|| NumericalError::Other("Singular matrix".to_string()))?;
// Trial point x + δ, accepted only if it lowers the cost.
let mut trial = vals.clone();
for (j, var) in variables.iter().enumerate() {
*trial.get_mut(&var.name).unwrap() += delta[j];
}
let trial_cost: f64 = residuals
.iter()
.map(|r| r.evaluate(&trial).unwrap_or(f64::NAN).powi(2))
.sum();
if trial_cost < cost {
vals = trial;
lambda *= 0.5;
} else {
lambda *= 2.0;
}
// Stop when the largest step component is below tolerance.
// NOTE(review): this fires even when the step was just rejected — confirm.
if delta.iter().map(|d| d.abs()).fold(0.0_f64, f64::max) < self.config.tolerance {
return Ok(variables
.iter()
.map(|v| (v.clone(), vals[&v.name]))
.collect());
}
}
Err(NumericalError::NoConvergence)
}
}
/// Solves the dense linear system `a * x = b` by Gaussian elimination with
/// partial pivoting.
///
/// `a` is an n×n row-major matrix and `b` the right-hand side of length n.
/// Returns `None` when the matrix is numerically singular (a pivot smaller
/// than 1e-15 in magnitude, or a non-finite pivot).
fn solve_linear_nxn(a: &[Vec<f64>], b: &[f64]) -> Option<Vec<f64>> {
    let n = b.len();
    // Build the augmented matrix [a | b].
    let mut aug: Vec<Vec<f64>> = a
        .iter()
        .enumerate()
        .map(|(i, row)| {
            let mut r = row.clone();
            r.push(b[i]);
            r
        })
        .collect();
    for col in 0..n {
        // Partial pivoting: bring the largest-magnitude entry to the
        // diagonal. `total_cmp` is NaN-safe, unlike the previous
        // `partial_cmp(..).unwrap()`, which panicked when an upstream
        // evaluation fed a NaN coefficient into the matrix.
        let max_row =
            (col..n).max_by(|&i, &j| aug[i][col].abs().total_cmp(&aug[j][col].abs()))?;
        aug.swap(col, max_row);
        let pivot = aug[col][col];
        // Treat NaN/infinite pivots as singular rather than propagating them.
        if !pivot.is_finite() || pivot.abs() < 1e-15 {
            return None;
        }
        // Eliminate the column below the pivot.
        for row in (col + 1)..n {
            let factor = aug[row][col] / pivot;
            for j in col..=n {
                let val = aug[col][j];
                aug[row][j] -= factor * val;
            }
        }
    }
    // Back substitution on the upper-triangular system.
    let mut x = vec![0.0; n];
    for i in (0..n).rev() {
        let mut sum = aug[i][n];
        for j in (i + 1)..n {
            sum -= aug[i][j] * x[j];
        }
        x[i] = sum / aug[i][i];
    }
    Some(x)
}
/// Meta-solver that tries several root-finding strategies in sequence.
#[derive(Debug)]
pub struct SmartNumericalSolver {
    config: NumericalConfig,
}
impl SmartNumericalSolver {
    /// Creates a solver with an explicit configuration.
    pub fn new(config: NumericalConfig) -> Self {
        Self { config }
    }
    /// Creates a solver using `NumericalConfig::default()`.
    pub fn with_default_config() -> Self {
        Self {
            config: NumericalConfig::default(),
        }
    }
    /// Solves `equation` for `variable`, trying strategies in order:
    /// 1. Newton-Raphson from the configured initial guess (if any);
    /// 2. bisection on a bracket found around the guess;
    /// 3. Newton-Raphson from a fixed list of fallback guesses;
    /// 4. bisection on brackets found around fixed fallback centers.
    ///
    /// # Errors
    /// Returns `NoConvergence` when every strategy fails.
    pub fn solve(
        &self,
        equation: &Equation,
        variable: &Variable,
    ) -> NumericalResult<(NumericalSolution, ResolutionPath)> {
        let f = Expression::Binary(
            crate::ast::BinaryOp::Sub,
            Box::new(equation.left.clone()),
            Box::new(equation.right.clone()),
        );
        if self.config.initial_guess.is_some() {
            let newton = NewtonRaphson::new(self.config.clone());
            if let Ok(result) = newton.solve(equation, variable) {
                return Ok(result);
            }
        }
        let initial_guess = self.config.initial_guess.unwrap_or(1.0);
        if let Some((a, b)) = bracket_root(&f, variable, initial_guess, 10000.0) {
            let bisection = BisectionMethod::new(self.config.clone());
            if let Ok(result) = bisection.solve(equation, variable, (a, b)) {
                return Ok(result);
            }
        }
        // Fixed candidate lists are plain arrays: no heap allocation needed
        // (clippy::useless_vec).
        for guess in [0.0, 1.0, -1.0, 10.0, -10.0, 100.0, -100.0] {
            let mut config = self.config.clone();
            config.initial_guess = Some(guess);
            let newton = NewtonRaphson::new(config);
            if let Ok(result) = newton.solve(equation, variable) {
                return Ok(result);
            }
        }
        for center in [0.0, 1.0, -1.0, 10.0, -10.0, 100.0] {
            if let Some((a, b)) = bracket_root(&f, variable, center, 10000.0) {
                let bisection = BisectionMethod::new(self.config.clone());
                if let Ok(result) = bisection.solve(equation, variable, (a, b)) {
                    return Ok(result);
                }
            }
        }
        Err(NumericalError::NoConvergence)
    }
    /// Solves by bisection on a caller-supplied bracketing interval.
    pub fn solve_with_interval(
        &self,
        equation: &Equation,
        variable: &Variable,
        interval: (f64, f64),
    ) -> NumericalResult<(NumericalSolution, ResolutionPath)> {
        BisectionMethod::new(self.config.clone()).solve(equation, variable, interval)
    }
}
/// Expression evaluator holding a set of variable bindings.
/// NOTE(review): `evaluate` is still a stub; the bindings are stored but
/// not yet consumed anywhere visible in this file.
pub struct Evaluator {
variables: HashMap<Variable, f64>,
}
impl Evaluator {
pub fn new() -> Self {
Self {
variables: HashMap::new(),
}
}
pub fn with_variables(variables: HashMap<Variable, f64>) -> Self {
Self { variables }
}
pub fn set_variable(&mut self, var: Variable, value: f64) {
self.variables.insert(var, value);
}
pub fn evaluate(&self, _expression: &Expression) -> Result<f64, String> {
Err("Not yet implemented".to_string())
}
}
impl Default for Evaluator {
/// Equivalent to `Evaluator::new()`: an evaluator with no bindings.
fn default() -> Self {
Self::new()
}
}
#[allow(dead_code)]
/// Central finite-difference approximation of d(expr)/d(variable) at `x`
/// with step width `h`: (f(x+h) - f(x-h)) / (2h).
///
/// # Errors
/// `EvaluationFailed` if either sample point cannot be evaluated,
/// `Unstable` if the resulting slope is not finite.
fn compute_derivative_fd(
    expr: &Expression,
    variable: &Variable,
    x: f64,
    h: f64,
) -> NumericalResult<f64> {
    let mut vars = HashMap::new();
    // Shared sampler: binds the variable and evaluates the expression.
    let mut eval_at = |point: f64| -> NumericalResult<f64> {
        vars.insert(variable.name.clone(), point);
        expr.evaluate(&vars).ok_or_else(|| {
            NumericalError::EvaluationFailed(format!(
                "Failed to evaluate function at x = {}",
                point
            ))
        })
    };
    let f_plus = eval_at(x + h)?;
    let f_minus = eval_at(x - h)?;
    let slope = (f_plus - f_minus) / (2.0 * h);
    if slope.is_finite() {
        Ok(slope)
    } else {
        Err(NumericalError::Unstable)
    }
}
/// Searches for an interval around `center` where `expr` changes sign.
///
/// Tries progressively wider symmetric ranges (clamped to `max_range`) and
/// several sub-intervals of each; returns the first `(a, b)` found with
/// f(a) * f(b) < 0, or `None` when no bracket is found.
fn bracket_root(
    expr: &Expression,
    variable: &Variable,
    center: f64,
    max_range: f64,
) -> Option<(f64, f64)> {
    let mut vars = HashMap::new();
    for scale in [1.0_f64, 10.0, 100.0, 1000.0] {
        let range = scale.min(max_range);
        for offset in [0.0, range / 4.0, range / 2.0, 3.0 * range / 4.0] {
            let a = center - range + offset;
            let b = center + range - offset;
            vars.insert(variable.name.clone(), a);
            // Fixed: a failed evaluation (e.g. outside the function's
            // domain) used to abort the whole search via `?`; it now only
            // disqualifies this candidate interval.
            let fa = match expr.evaluate(&vars) {
                Some(v) => v,
                None => continue,
            };
            vars.insert(variable.name.clone(), b);
            let fb = match expr.evaluate(&vars) {
                Some(v) => v,
                None => continue,
            };
            if fa * fb < 0.0 {
                return Some((a, b));
            }
        }
    }
    None
}
/// Mutable bookkeeping carried between Brent iterations.
struct BrentState {
/// Previous value of `b` — the third point used for inverse quadratic
/// interpolation.
c: f64,
/// f(c).
fc: f64,
/// Step record from the iteration before last. NOTE(review): written in
/// `brent_update_bracket` but never read — the textbook algorithm uses
/// `d` in the acceptance tests; confirm whether `e` was meant there.
d: f64,
/// Most recent bracket width, used in the step-acceptance tests.
e: f64,
/// True when the previous step was a bisection.
mflag: bool,
}
/// Builds the residual expression f = left - right for `equation`, plus a
/// closure that evaluates it at a given value of `variable`.
fn brent_make_eval(
    equation: &Equation,
    variable: &Variable,
) -> (Expression, impl Fn(f64) -> NumericalResult<f64>) {
    let residual = Expression::Binary(
        crate::ast::BinaryOp::Sub,
        Box::new(equation.left.clone()),
        Box::new(equation.right.clone()),
    );
    let captured = residual.clone();
    let name = variable.name.clone();
    let eval = move |xv: f64| -> NumericalResult<f64> {
        let mut bindings = HashMap::new();
        bindings.insert(name.clone(), xv);
        captured.evaluate(&bindings).ok_or_else(|| {
            NumericalError::EvaluationFailed(format!("Failed to evaluate at x = {xv}"))
        })
    };
    (residual, eval)
}
/// Evaluates the endpoints of `interval` and validates the bracket.
///
/// Returns `(a, b, fa, fb)` ordered so that |f(b)| <= |f(a)|, i.e. `b`
/// holds the better current root estimate, as Brent's method requires.
///
/// # Errors
/// `Other` if the endpoints do not straddle a sign change.
fn brent_init_bracket(
    interval: (f64, f64),
    eval: &impl Fn(f64) -> NumericalResult<f64>,
) -> NumericalResult<(f64, f64, f64, f64)> {
    let (mut a, mut b) = interval;
    let mut fa = eval(a)?;
    let mut fb = eval(b)?;
    if fa * fb > 0.0 {
        return Err(NumericalError::Other(format!(
            "Brent's method requires f(a) and f(b) to have opposite signs. f({}) = {}, f({}) = {}",
            a, fa, b, fb
        )));
    }
    // Swap so that b carries the smaller residual.
    if fb.abs() > fa.abs() {
        std::mem::swap(&mut a, &mut b);
        std::mem::swap(&mut fa, &mut fb);
    }
    Ok((a, b, fa, fb))
}
/// Picks the next trial point: the interpolated candidate when acceptable,
/// otherwise the bisection midpoint. Returns `(point, used_bisection)`.
fn brent_next_point(a: f64, b: f64, fa: f64, fb: f64, st: &BrentState, tol: f64) -> (f64, bool) {
let s = brent_interpolation_step(a, b, st.c, fa, fb, st.fc, tol);
let bisect_mid = (a + b) / 2.0;
// The interpolated point must lie between (3a + b)/4 and b; the branch
// below handles either ordering of those bounds.
let lo = (3.0 * a + b) / 4.0;
let hi = b;
let in_bracket = if lo <= hi {
lo <= s && s <= hi
} else {
hi <= s && s <= lo
};
// Acceptance tests in the spirit of Brent: reject interpolation when it
// falls outside the trusted sub-interval or fails to shrink fast enough
// relative to earlier steps, and force a bisection instead.
// NOTE(review): the textbook formulation uses the older step `d` in the
// last two tests; this version uses `e` — confirm intended.
let reject = !in_bracket
|| (st.mflag && (s - b).abs() >= (b - st.c).abs() / 2.0)
|| (!st.mflag && (s - b).abs() >= st.e.abs() / 2.0)
|| (st.mflag && (b - st.c).abs() < tol)
|| (!st.mflag && st.e.abs() < tol);
if reject {
(bisect_mid, true)
} else {
(s, false)
}
}
/// Folds the accepted trial point `s_final` back into the bracket.
///
/// Shifts the step history (d <- e, e <- current width), records the old
/// `b` as the third point `c`, replaces whichever endpoint shares the sign
/// of f(s), and re-orders so that |f(b)| <= |f(a)| again.
fn brent_update_bracket(
a: &mut f64,
b: &mut f64,
st: &mut BrentState,
fa: &mut f64,
fb: &mut f64,
s_final: f64,
fs: f64,
) {
// Order matters: capture the history before the endpoints change.
st.d = st.e;
st.e = *b - *a;
st.c = *b;
st.fc = *fb;
if *fa * fs < 0.0 {
*b = s_final;
*fb = fs;
} else {
*a = s_final;
*fa = fs;
}
// Restore the invariant that b is the better estimate.
if fa.abs() < fb.abs() {
std::mem::swap(a, b);
std::mem::swap(fa, fb);
}
}
/// Proposes the next Brent trial point from the three current points.
///
/// Prefers inverse quadratic interpolation when all three residuals are
/// distinct, falls back to a secant step, and finally to the extrapolation
/// `b + 2(b - a)` when the secant denominator is degenerate.
fn brent_interpolation_step(a: f64, b: f64, c: f64, fa: f64, fb: f64, fc: f64, tol: f64) -> f64 {
    if (fa - fc).abs() > tol && (fb - fc).abs() > tol {
        // Inverse quadratic interpolation through (a, fa), (b, fb), (c, fc).
        let r = fb / fc;
        let s = fb / fa;
        let t = fa / fc;
        let numer = s * (t * (r - t) * (c - b) - (1.0 - r) * (b - a));
        let denom = (t - 1.0) * (r - 1.0) * (s - 1.0);
        if denom.abs() > tol {
            return b + numer / denom;
        }
    }
    // Secant step through (a, fa) and (b, fb).
    let slope_denom = fb - fa;
    if slope_denom.abs() > tol {
        return b - fb * (b - a) / slope_denom;
    }
    // Degenerate residuals: extrapolate past b.
    b + 2.0 * (b - a)
}
/// Builds the residual expression f = left - right for `equation`, plus a
/// closure that evaluates it at a given value of `variable`.
fn secant_make_eval(
    equation: &Equation,
    variable: &Variable,
) -> (Expression, impl Fn(f64) -> NumericalResult<f64>) {
    let residual = Expression::Binary(
        crate::ast::BinaryOp::Sub,
        Box::new(equation.left.clone()),
        Box::new(equation.right.clone()),
    );
    let captured = residual.clone();
    let name = variable.name.clone();
    let eval = move |xv: f64| -> NumericalResult<f64> {
        let mut bindings = HashMap::new();
        bindings.insert(name.clone(), xv);
        captured.evaluate(&bindings).ok_or_else(|| {
            NumericalError::EvaluationFailed(format!("Failed to evaluate at x = {xv}"))
        })
    };
    (residual, eval)
}
/// Computes one secant-method update from the last two iterates.
///
/// # Errors
/// `Unstable` when f(x_curr) - f(x_prev) is too close to zero for the
/// step to be meaningful.
fn secant_step(x_prev: f64, x_curr: f64, f_prev: f64, f_curr: f64) -> NumericalResult<f64> {
    let df = f_curr - f_prev;
    if df.abs() < f64::EPSILON * 10.0 {
        return Err(NumericalError::Unstable);
    }
    Ok(x_curr - f_curr * (x_curr - x_prev) / df)
}
#[allow(clippy::too_many_arguments)]
/// Runs the secant iteration loop, mutating the two-point window in place.
///
/// On success returns the solution summary and the final x value; the
/// caller can read the final residual out of `f_curr`.
///
/// # Errors
/// `Unstable` on a degenerate secant denominator, `EvaluationFailed` when
/// the function cannot be evaluated, `NoConvergence` at the iteration limit.
fn secant_iterate(
eval: &impl Fn(f64) -> NumericalResult<f64>,
path: &mut crate::resolution_path::ResolutionPathBuilder,
x_prev: &mut f64,
x_curr: &mut f64,
f_prev: &mut f64,
f_curr: &mut f64,
max_iterations: usize,
tolerance: f64,
) -> NumericalResult<(NumericalSolution, f64)> {
let mut iterations = 0_usize;
let mut converged = false;
for i in 0..max_iterations {
iterations = i + 1;
if f_curr.abs() < tolerance {
converged = true;
break;
}
let x_next = secant_step(*x_prev, *x_curr, *f_prev, *f_curr)?;
// Log every 10th iteration to keep the path compact. The builder's
// `step` consumes self, so it is cloned out of the &mut slot first.
if i % 10 == 0 {
*path = (*path).clone().step(
Operation::NumericalApproximation,
format!(
"Iter {iterations}: x={x_next:.10}, f(x_curr)={:.6e}",
f_curr
),
Expression::Float(x_next),
);
}
let f_next = eval(x_next)?;
// Slide the two-point window forward.
*x_prev = *x_curr;
*f_prev = *f_curr;
*x_curr = x_next;
*f_curr = f_next;
if (*x_curr - *x_prev).abs() < tolerance {
converged = true;
break;
}
}
if !converged {
return Err(NumericalError::NoConvergence);
}
let solution = NumericalSolution {
value: *x_curr,
iterations,
residual: f_curr.abs(),
converged,
};
Ok((solution, *x_curr))
}