use crate::python::array::PyArray;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::{PyAnyMethods, PyDict, PyTuple};
use std::cell::RefCell;
#[pyfunction]
fn minimize(
py: Python<'_>,
fun: Py<PyAny>,
x0: &PyArray,
method: Option<String>,
tol: Option<f64>,
) -> PyResult<Py<PyAny>> {
let x0_vec = x0.tolist();
if x0_vec.is_empty() {
return Err(PyValueError::new_err("x0 must not be empty"));
}
let method_str = method.as_deref().unwrap_or("nelder-mead").to_lowercase();
let captured_err: RefCell<Option<PyErr>> = RefCell::new(None);
let closure_result = match method_str.as_str() {
"nelder-mead" | "nm" | "neldermead" => {
let closure = |x: &[f64]| -> f64 {
Python::attach(|inner_py| {
let py_list: Vec<f64> = x.to_vec();
let args = match PyTuple::new(inner_py, [py_list.into_pyobject(inner_py).ok()])
{
Ok(t) => t,
Err(e) => {
*captured_err.borrow_mut() = Some(e);
return f64::NAN;
}
};
match fun.call(inner_py, args, None) {
Ok(result) => match result.extract::<f64>(inner_py) {
Ok(v) => v,
Err(e) => {
*captured_err.borrow_mut() = Some(e);
f64::NAN
}
},
Err(e) => {
*captured_err.borrow_mut() = Some(e);
f64::NAN
}
}
})
};
let mut config = crate::optimize::OptimizeConfig::<f64>::default();
if let Some(t) = tol {
config.ftol = t;
config.xtol = t;
}
crate::optimize::nelder_mead(closure, &x0_vec, Some(config))
}
"bfgs" => {
let eps = 1e-7_f64;
let closure = |x: &[f64]| -> f64 {
Python::attach(|inner_py| {
let py_list: Vec<f64> = x.to_vec();
let args = match PyTuple::new(inner_py, [py_list.into_pyobject(inner_py).ok()])
{
Ok(t) => t,
Err(e) => {
*captured_err.borrow_mut() = Some(e);
return f64::NAN;
}
};
match fun.call(inner_py, args, None) {
Ok(result) => match result.extract::<f64>(inner_py) {
Ok(v) => v,
Err(e) => {
*captured_err.borrow_mut() = Some(e);
f64::NAN
}
},
Err(e) => {
*captured_err.borrow_mut() = Some(e);
f64::NAN
}
}
})
};
let grad_closure = |x: &[f64]| -> Vec<f64> {
Python::attach(|inner_py| {
let n = x.len();
let mut g = vec![0.0_f64; n];
for i in 0..n {
let mut xph = x.to_vec();
let mut xmh = x.to_vec();
xph[i] += eps;
xmh[i] -= eps;
let call_val = |xv: Vec<f64>| -> f64 {
let args =
match PyTuple::new(inner_py, [xv.into_pyobject(inner_py).ok()]) {
Ok(t) => t,
Err(e) => {
*captured_err.borrow_mut() = Some(e);
return f64::NAN;
}
};
match fun.call(inner_py, args, None) {
Ok(result) => result.extract::<f64>(inner_py).unwrap_or(f64::NAN),
Err(e) => {
*captured_err.borrow_mut() = Some(e);
f64::NAN
}
}
};
g[i] = (call_val(xph) - call_val(xmh)) / (2.0 * eps);
}
g
})
};
let mut config = crate::optimize::OptimizeConfig::<f64>::default();
if let Some(t) = tol {
config.ftol = t;
config.gtol = t;
}
crate::optimize::bfgs(closure, grad_closure, &x0_vec, Some(config))
}
other => {
return Err(PyValueError::new_err(format!(
"Unknown method '{}'. Supported: 'nelder-mead', 'bfgs'",
other
)));
}
};
if let Some(err) = captured_err.into_inner() {
return Err(err);
}
let opt_result = closure_result.map_err(|e| PyValueError::new_err(e.to_string()))?;
let dict = PyDict::new(py);
dict.set_item("x", opt_result.x.into_pyobject(py)?.into_any().unbind())?;
dict.set_item("fun", opt_result.fun)?;
dict.set_item("success", opt_result.success)?;
dict.set_item("nit", opt_result.nit)?;
dict.set_item("nfev", opt_result.nfev)?;
dict.set_item("message", opt_result.message)?;
Ok(dict.into_any().unbind())
}
#[pyfunction]
fn root_scalar(fun: Py<PyAny>, bracket: (f64, f64), method: Option<String>) -> PyResult<f64> {
let (a, b) = bracket;
let method_str = method.as_deref().unwrap_or("brentq").to_lowercase();
let captured_err: RefCell<Option<PyErr>> = RefCell::new(None);
let closure = |x: f64| -> f64 {
Python::attach(|py| {
let args = match PyTuple::new(py, [x]) {
Ok(t) => t,
Err(e) => {
*captured_err.borrow_mut() = Some(e);
return f64::NAN;
}
};
match fun.call(py, args, None) {
Ok(result) => match result.extract::<f64>(py) {
Ok(v) => v,
Err(e) => {
*captured_err.borrow_mut() = Some(e);
f64::NAN
}
},
Err(e) => {
*captured_err.borrow_mut() = Some(e);
f64::NAN
}
}
})
};
let root_result = match method_str.as_str() {
"brentq" | "brent" => crate::roots::brentq(closure, a, b, 1e-10, 1000)
.map_err(|e| PyValueError::new_err(e.to_string()))?,
"bisect" | "bisection" => crate::roots::bisect(closure, a, b, 1e-10, 1000)
.map_err(|e| PyValueError::new_err(e.to_string()))?,
other => {
return Err(PyValueError::new_err(format!(
"Unknown method '{}'. Supported: 'brentq', 'bisect'",
other
)));
}
};
if let Some(err) = captured_err.into_inner() {
return Err(err);
}
Ok(root_result.root)
}
/// Creates the `optimize` submodule, adds `minimize` and `root_scalar` to it, and
/// attaches it to `m`.
///
/// The functions are wrapped against the submodule they are added to (not the
/// parent `m`), so their Python `__module__` refers to the module that actually
/// contains them.
///
/// # Errors
///
/// Propagates any `PyErr` raised while creating the submodule, wrapping the
/// functions, or attaching the submodule.
pub fn register(m: &Bound<'_, PyModule>) -> PyResult<()> {
    let optimize_module = PyModule::new(m.py(), "optimize")?;
    optimize_module.add_function(wrap_pyfunction!(minimize, &optimize_module)?)?;
    optimize_module.add_function(wrap_pyfunction!(root_scalar, &optimize_module)?)?;
    m.add_submodule(&optimize_module)?;
    Ok(())
}