use pyo3::prelude::*;
use pyo3::types::PyDict;
use scirs2_numpy::IntoPyArray;
use scirs2_core::{ndarray::ArrayView1, Array1};
use scirs2_optimize::global::{differential_evolution, DifferentialEvolutionOptions};
use scirs2_optimize::scalar::{minimize_scalar, Method as ScalarMethod, Options as ScalarOptions};
use scirs2_optimize::unconstrained::{minimize, Bounds, Method, Options};
#[pyfunction]
#[pyo3(signature = (fun, bracket, method="brent", options=None))]
fn minimize_scalar_py(
py: Python,
fun: &Bound<'_, PyAny>,
bracket: (f64, f64),
method: &str,
options: Option<&Bound<'_, PyDict>>,
) -> PyResult<Py<PyAny>> {
let maxiter = options
.and_then(|o| o.get_item("maxiter").ok().flatten())
.and_then(|v| v.extract().ok());
let tol = options
.and_then(|o| o.get_item("tol").ok().flatten())
.and_then(|v| v.extract().ok());
let fun_clone = fun.clone().unbind();
let f = move |x: f64| -> f64 {
Python::attach(|py| {
let result = fun_clone
.bind(py)
.call1((x,))
.unwrap_or_else(|_| py.None().into_bound(py));
result.extract::<f64>().unwrap_or(f64::NAN)
})
};
let scalar_method = match method {
"brent" => ScalarMethod::Brent,
"golden" => ScalarMethod::Golden,
"bounded" => ScalarMethod::Bounded,
_ => ScalarMethod::Brent,
};
let mut options = ScalarOptions::default();
if let Some(mi) = maxiter {
options.max_iter = mi;
}
if let Some(t) = tol {
options.xatol = t;
}
let result = minimize_scalar(f, Some(bracket), scalar_method, Some(options))
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
let dict = PyDict::new(py);
dict.set_item("x", result.x)?;
dict.set_item("fun", result.fun)?;
dict.set_item("success", result.success)?;
dict.set_item("nit", result.nit)?;
dict.set_item("nfev", result.function_evals)?;
Ok(dict.into())
}
/// Find a root of `fun` inside the bracket `[a, b]` with Brent's method.
///
/// `f(a)` and `f(b)` must not share a sign. Each iteration tries a secant or inverse
/// quadratic interpolation step and falls back to bisection. The root is kept bracketed
/// between the current best estimate `b` and the contrapoint `c`.
///
/// The search stops successfully in any of these cases:
/// * `f(b) == 0`;
/// * `|f(b)| < xtol`;
/// * the bracket half-width is at most `2 * EPSILON * |b| + xtol`.
///
/// Returns a dict with keys `x`, `fun`, `iterations` and `success`. When `maxiter`
/// iterations are exhausted without convergence, it also contains `message` and
/// `success` is false.
///
/// # Errors
///
/// * `ValueError` if `f(a)` or `f(b)` is NaN, or if they have the same sign.
/// * Any exception raised by `fun`, or by converting its result to float, propagates.
#[pyfunction]
#[pyo3(signature = (fun, a, b, xtol=1e-12, maxiter=100))]
fn brentq_py(
    py: Python,
    fun: &Bound<'_, PyAny>,
    a: f64,
    b: f64,
    xtol: f64,
    maxiter: usize,
) -> PyResult<Py<PyAny>> {
    // We already hold the GIL, so call `fun` directly and let its exceptions propagate
    // instead of masking them as NaN.
    let f = |x: f64| -> PyResult<f64> { fun.call1((x,))?.extract::<f64>() };

    let mut a = a;
    let mut b = b;
    let mut fa = f(a)?;
    let mut fb = f(b)?;
    if fa.is_nan() || fb.is_nan() {
        return Err(pyo3::exceptions::PyValueError::new_err(
            "f(a) and f(b) must not be NaN",
        ));
    }
    // Compare signs directly: `fa * fb > 0.0` can underflow to zero for tiny same-signed
    // values and wrongly accept an invalid bracket.
    if (fa > 0.0 && fb > 0.0) || (fa < 0.0 && fb < 0.0) {
        return Err(pyo3::exceptions::PyValueError::new_err(
            "f(a) and f(b) must have opposite signs",
        ));
    }
    // Start from the endpoint with the smaller residual.
    if fa.abs() < fb.abs() {
        std::mem::swap(&mut a, &mut b);
        std::mem::swap(&mut fa, &mut fb);
    }

    // Invariants: the root lies between `b` and the contrapoint `c`, `b` is the best
    // estimate so far, and `a` is the previous value of `b`. `d` is the latest step and
    // `e` the one before it.
    let mut c = a;
    let mut fc = fa;
    let mut d = b - a;
    let mut e = d;

    for iteration in 0..maxiter {
        // Keep `b` as the best estimate. Compare against the contrapoint `c`, not the
        // previous iterate `a`: `a` and `b` may lie on the same side of the root, and
        // swapping with `a` would lose the bracket.
        if fc.abs() < fb.abs() {
            a = b;
            b = c;
            c = a;
            fa = fb;
            fb = fc;
            fc = fa;
        }

        let tol = 2.0 * f64::EPSILON * b.abs() + xtol;
        let m = 0.5 * (c - b);
        if fb == 0.0 || fb.abs() < xtol || m.abs() <= tol {
            return brentq_result(py, b, fb, iteration, true);
        }

        // Try an interpolation step; fall back to bisection when it is not worthwhile.
        let mut bisect = true;
        if e.abs() >= tol && fa.abs() > fb.abs() {
            let s = fb / fa;
            let (p, q) = if a == c {
                // Only two distinct points: linear (secant) interpolation.
                (2.0 * m * s, 1.0 - s)
            } else {
                // Three distinct points: inverse quadratic interpolation.
                let q = fa / fc;
                let r = fb / fc;
                (
                    s * (2.0 * m * q * (q - r) - (b - a) * (r - 1.0)),
                    (q - 1.0) * (r - 1.0) * (s - 1.0),
                )
            };
            // Normalize so the step is `p / q` with `p >= 0`.
            let (p, q) = if p > 0.0 { (p, -q) } else { (-p, q) };
            // Accept only if the step stays inside the bracket and shrinks faster than
            // half the step before last.
            if 2.0 * p < 3.0 * m * q - (tol * q).abs() && p < (e * q / 2.0).abs() {
                e = d;
                d = p / q;
                bisect = false;
            }
        }
        if bisect {
            d = m;
            e = m;
        }

        a = b;
        fa = fb;
        // Always move by at least `tol` so the iteration makes progress.
        b += if d.abs() > tol {
            d
        } else if m > 0.0 {
            tol
        } else {
            -tol
        };
        fb = f(b)?;

        // If `b` ended up on the same side as `c`, the previous iterate `a` (which is on
        // the other side) becomes the new contrapoint.
        if (fb > 0.0) == (fc > 0.0) {
            c = a;
            fc = fa;
            d = b - a;
            e = d;
        }
    }

    brentq_result(py, b, fb, maxiter, false)
}

/// Build the result dictionary returned by [`brentq_py`].
///
/// Adds `message` only for the non-converged case.
fn brentq_result(
    py: Python<'_>,
    x: f64,
    fx: f64,
    iterations: usize,
    success: bool,
) -> PyResult<Py<PyAny>> {
    let dict = PyDict::new(py);
    dict.set_item("x", x)?;
    dict.set_item("fun", fx)?;
    dict.set_item("iterations", iterations)?;
    dict.set_item("success", success)?;
    if !success {
        dict.set_item("message", "Maximum iterations reached")?;
    }
    Ok(dict.into())
}
#[pyfunction]
#[pyo3(signature = (fun, x0, method="bfgs", options=None, bounds=None))]
fn minimize_py(
py: Python,
fun: &Bound<'_, PyAny>,
x0: Vec<f64>,
method: &str,
options: Option<&Bound<'_, PyDict>>,
bounds: Option<Vec<(f64, f64)>>,
) -> PyResult<Py<PyAny>> {
let opt_method = match method.to_lowercase().as_str() {
"nelder-mead" | "neldermead" => Method::NelderMead,
"powell" => Method::Powell,
"cg" | "conjugate-gradient" => Method::CG,
"bfgs" => Method::BFGS,
"lbfgs" | "l-bfgs" => Method::LBFGS,
"lbfgsb" | "l-bfgs-b" => Method::LBFGSB,
"newton-cg" => Method::NewtonCG,
"trust-ncg" => Method::TrustNCG,
"sr1" => Method::SR1,
"dfp" => Method::DFP,
_ => Method::BFGS, };
let maxiter = options
.and_then(|o| o.get_item("maxiter").ok().flatten())
.and_then(|v| v.extract().ok());
let ftol = options
.and_then(|o| o.get_item("ftol").ok().flatten())
.and_then(|v| v.extract().ok());
let gtol = options
.and_then(|o| o.get_item("gtol").ok().flatten())
.and_then(|v| v.extract().ok());
let mut opt_options = Options::default();
if let Some(mi) = maxiter {
opt_options.max_iter = mi;
}
if let Some(ft) = ftol {
opt_options.ftol = ft;
}
if let Some(gt) = gtol {
opt_options.gtol = gt;
}
if let Some(b) = bounds {
let n = x0.len();
let mut lower = vec![None; n];
let mut upper = vec![None; n];
for (i, (l, u)) in b.iter().enumerate() {
if i < n {
lower[i] = Some(*l);
upper[i] = Some(*u);
}
}
opt_options.bounds = Some(Bounds { lower, upper });
}
let fun_arc = std::sync::Arc::new(fun.clone().unbind());
let f = move |x: &ArrayView1<f64>| -> f64 {
let fun_clone = fun_arc.clone();
Python::attach(|py| {
let x_vec: Vec<f64> = x.to_vec();
let result = fun_clone
.bind(py)
.call1((x_vec,))
.unwrap_or_else(|_| py.None().into_bound(py));
result.extract::<f64>().unwrap_or(f64::NAN)
})
};
let result = minimize(f, &x0, opt_method, Some(opt_options))
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
let dict = PyDict::new(py);
dict.set_item("x", result.x.into_pyarray(py).unbind())?;
dict.set_item("fun", result.fun)?;
dict.set_item("success", result.success)?;
dict.set_item("message", result.message)?;
dict.set_item("nit", result.nit)?;
dict.set_item("nfev", result.func_evals)?;
Ok(dict.into())
}
#[pyfunction]
#[pyo3(signature = (fun, bounds, options=None))]
fn differential_evolution_py(
py: Python,
fun: &Bound<'_, PyAny>,
bounds: Vec<(f64, f64)>,
options: Option<&Bound<'_, PyDict>>,
) -> PyResult<Py<PyAny>> {
let maxiter = options
.and_then(|o| o.get_item("maxiter").ok().flatten())
.and_then(|v| v.extract().ok());
let popsize = options
.and_then(|o| o.get_item("popsize").ok().flatten())
.and_then(|v| v.extract().ok());
let tol = options
.and_then(|o| o.get_item("tol").ok().flatten())
.and_then(|v| v.extract().ok());
let seed = options
.and_then(|o| o.get_item("seed").ok().flatten())
.and_then(|v| v.extract().ok());
let fun_arc = std::sync::Arc::new(fun.clone().unbind());
let f = move |x: &ArrayView1<f64>| -> f64 {
let fun_clone = fun_arc.clone();
Python::attach(|py| {
let x_vec: Vec<f64> = x.to_vec();
let result = fun_clone
.bind(py)
.call1((x_vec,))
.unwrap_or_else(|_| py.None().into_bound(py));
result.extract::<f64>().unwrap_or(f64::NAN)
})
};
let mut de_options = DifferentialEvolutionOptions::default();
if let Some(mi) = maxiter {
de_options.maxiter = mi;
}
if let Some(ps) = popsize {
de_options.popsize = ps;
}
if let Some(t) = tol {
de_options.tol = t;
}
if let Some(s) = seed {
de_options.seed = Some(s);
}
let result = differential_evolution(f, bounds.to_vec(), Some(de_options), None)
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
let dict = PyDict::new(py);
dict.set_item("x", result.x.into_pyarray(py).unbind())?;
dict.set_item("fun", result.fun)?;
dict.set_item("success", result.success)?;
dict.set_item("message", result.message)?;
dict.set_item("nit", result.nit)?;
dict.set_item("nfev", result.func_evals)?;
Ok(dict.into())
}
#[pyfunction]
#[pyo3(signature = (f, xdata, ydata, p0=None, method="lm", maxfev=1000))]
fn curve_fit_py(
py: Python,
f: &Bound<'_, PyAny>,
xdata: Vec<f64>,
ydata: Vec<f64>,
p0: Option<Vec<f64>>,
method: &str,
maxfev: usize,
) -> PyResult<Py<PyAny>> {
use scirs2_optimize::least_squares::{least_squares, Method as LSMethod, Options as LSOptions};
if xdata.len() != ydata.len() {
return Err(pyo3::exceptions::PyValueError::new_err(
"xdata and ydata must have the same length",
));
}
let n_data = xdata.len();
let params_init = p0.unwrap_or_else(|| vec![1.0; 2]);
let ls_method = match method.to_lowercase().as_str() {
"lm" => LSMethod::LevenbergMarquardt,
"trf" => LSMethod::TrustRegionReflective,
"dogbox" => LSMethod::Dogbox,
_ => LSMethod::LevenbergMarquardt,
};
let xdata_clone = xdata.clone();
let ydata_clone = ydata.clone();
let f_arc = std::sync::Arc::new(f.clone().unbind());
let residual_fn = move |params: &[f64], _data: &[f64]| -> Array1<f64> {
let f_clone = f_arc.clone();
let xdata_ref = &xdata_clone;
let ydata_ref = &ydata_clone;
Python::attach(|py| {
let mut residuals = Vec::with_capacity(n_data);
for i in 0..n_data {
let mut args = vec![xdata_ref[i]];
args.extend_from_slice(params);
let tuple = pyo3::types::PyTuple::new(py, &args)
.unwrap_or_else(|_| pyo3::types::PyTuple::empty(py));
let f_val: f64 = f_clone
.bind(py)
.call1(tuple)
.unwrap_or_else(|_| py.None().into_bound(py))
.extract()
.unwrap_or(f64::NAN);
residuals.push(ydata_ref[i] - f_val);
}
Array1::from_vec(residuals)
})
};
let options = LSOptions {
max_nfev: Some(maxfev),
..Default::default()
};
let empty_data = Array1::from_vec(vec![]);
let result = least_squares(
residual_fn,
&Array1::from_vec(params_init),
ls_method,
None::<fn(&[f64], &[f64]) -> scirs2_core::ndarray::Array2<f64>>, &empty_data, Some(options),
)
.map_err(|e| {
pyo3::exceptions::PyRuntimeError::new_err(format!("Curve fitting failed: {}", e))
})?;
let dict = PyDict::new(py);
dict.set_item("popt", result.x.into_pyarray(py).unbind())?;
dict.set_item("success", result.success)?;
dict.set_item("nfev", result.nfev)?;
dict.set_item("message", result.message)?;
Ok(dict.into())
}
pub fn register_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(minimize_py, m)?)?;
m.add_function(wrap_pyfunction!(minimize_scalar_py, m)?)?;
m.add_function(wrap_pyfunction!(brentq_py, m)?)?;
m.add_function(wrap_pyfunction!(differential_evolution_py, m)?)?;
m.add_function(wrap_pyfunction!(curve_fit_py, m)?)?;
Ok(())
}