use std::cell::RefCell;
use pyo3::exceptions::{PyRuntimeError, PyValueError};
use pyo3::prelude::*;
/// Return the version string of this extension crate.
///
/// The value is baked in at compile time from Cargo's `CARGO_PKG_VERSION`.
#[pyfunction]
fn version() -> &'static str {
    const CRATE_VERSION: &str = env!("CARGO_PKG_VERSION");
    CRATE_VERSION
}
/// Return the default units configuration as TOML text.
///
/// Thin wrapper that forwards to `vle_units::default_units_toml()`; the returned
/// string is `'static`, so no copy is made on the Rust side.
#[pyfunction]
fn default_units_toml() -> &'static str {
vle_units::default_units_toml()
}
/// Return the real roots of the cubic with coefficients `a`, `b`, `c`, `d`.
///
/// Delegates to `crate::numerics::cubic::solve_real`; by the usual convention the
/// polynomial is presumably `a*x^3 + b*x^2 + c*x + d` (confirm against the solver).
///
/// # Errors
/// Any error reported by the solver is raised in Python as `ValueError`, carrying the
/// error's `Display` text.
#[pyfunction]
fn solve_cubic(a: f64, b: f64, c: f64, d: f64) -> PyResult<Vec<f64>> {
    let roots = crate::numerics::cubic::solve_real(a, b, c, d)
        .map_err(|err| PyValueError::new_err(err.to_string()))?;
    Ok(roots)
}
/// Find a root of the Python callable `f` using Brent's method.
///
/// `a` and `b` are forwarded to `crate::numerics::root_finding::brent` as the
/// interval endpoints (presumably a sign-changing bracket — confirm against the solver).
/// `f` is called with a float and must return something convertible to a float.
///
/// # Errors
/// - The first exception raised by `f` (or raised while converting its result to a
///   float) is re-raised unchanged. After that, `f` is not called again.
/// - Any solver failure is raised as `RuntimeError` with the error's `Display` text.
#[pyfunction]
#[pyo3(signature = (f, a, b, xtol = 1e-9, max_iter = 100))]
fn brent(py: Python<'_>, f: PyObject, a: f64, b: f64, xtol: f64, max_iter: usize) -> PyResult<f64> {
    // Holds the first Python exception raised by the callback, if any.
    let pending_error = RefCell::new(None::<PyErr>);
    let outcome = crate::numerics::root_finding::brent(
        |x| call_scalar_callback(py, &f, x, &pending_error),
        a,
        b,
        xtol,
        max_iter,
    );
    // A Python-side failure takes precedence over whatever the solver reported,
    // since the solver only ever saw NaN for the failed evaluation.
    match pending_error.into_inner() {
        Some(err) => Err(err),
        None => outcome.map_err(|e| PyRuntimeError::new_err(e.to_string())),
    }
}
/// Find a root of the Python callable `f` using the Illinois (modified regula falsi) method.
///
/// `a` and `b` are forwarded to `crate::numerics::root_finding::illinois` as the
/// interval endpoints (presumably a sign-changing bracket — confirm against the solver).
/// `f` is called with a float and must return something convertible to a float.
///
/// # Errors
/// - The first exception raised by `f` (or raised while converting its result to a
///   float) is re-raised unchanged. After that, `f` is not called again.
/// - Any solver failure is raised as `RuntimeError` with the error's `Display` text.
#[pyfunction]
#[pyo3(signature = (f, a, b, xtol = 1e-9, max_iter = 100))]
fn illinois(
    py: Python<'_>,
    f: PyObject,
    a: f64,
    b: f64,
    xtol: f64,
    max_iter: usize,
) -> PyResult<f64> {
    let first_py_error = RefCell::new(None::<PyErr>);
    let solved = crate::numerics::root_finding::illinois(
        |x| call_scalar_callback(py, &f, x, &first_py_error),
        a,
        b,
        xtol,
        max_iter,
    );
    // No captured Python error -> report the solver outcome; otherwise re-raise it.
    first_py_error
        .into_inner()
        .map_or_else(|| solved.map_err(|e| PyRuntimeError::new_err(e.to_string())), Err)
}
/// Find a root with Halley's method, driven by a Python callback.
///
/// `f_and_derivs(x)` is called with a float and must return a 3-tuple of floats.
/// Presumably this is `(f(x), f'(x), f''(x))`, as Halley's method requires (confirm against
/// `crate::numerics::halley::halley`). `x0` is the starting point. `xtol` and `max_iter` are
/// forwarded to the solver unchanged.
///
/// # Errors
/// - The first exception raised by `f_and_derivs` (or raised while converting its result
///   to `(float, float, float)`) is re-raised unchanged. Once that happens, the callback is
///   not invoked again for the rest of the solve.
/// - Any solver failure is raised as `RuntimeError` with the error's `Display` text.
#[pyfunction]
#[pyo3(signature = (f_and_derivs, x0, xtol = 1e-12, max_iter = 50))]
fn halley(
    py: Python<'_>,
    f_and_derivs: PyObject,
    x0: f64,
    xtol: f64,
    max_iter: usize,
) -> PyResult<f64> {
    let err_cache: RefCell<Option<PyErr>> = RefCell::new(None);
    let result = crate::numerics::halley::halley(
        |x| {
            // After the first failure, stop calling back into Python. Otherwise later
            // iterations would keep running user code whose exceptions are discarded.
            // The same guard is used by `call_scalar_callback`.
            if err_cache.borrow().is_some() {
                return (f64::NAN, f64::NAN, f64::NAN);
            }
            match f_and_derivs
                .call1(py, (x,))
                .and_then(|r| r.extract::<(f64, f64, f64)>(py))
            {
                Ok(triple) => triple,
                Err(e) => {
                    *err_cache.borrow_mut() = Some(e);
                    // NaN lets the solver wind down; the cached error wins below.
                    (f64::NAN, f64::NAN, f64::NAN)
                }
            }
        },
        x0,
        xtol,
        max_iter,
    );
    if let Some(e) = err_cache.into_inner() {
        return Err(e);
    }
    result.map_err(|e| PyRuntimeError::new_err(e.to_string()))
}
/// Solve the nonlinear system `f(x) = 0` with Broyden's method, driven by a Python callback.
///
/// `f` is called with the current iterate as a Python list of floats and must return a
/// sequence convertible to a list of floats. `x0` is the starting point. `xtol`, `ftol`,
/// `max_iter`, `refresh_every` and `fd_step` are forwarded unchanged in a
/// `crate::numerics::broyden::BroydenConfig`; their exact semantics are defined by the solver.
///
/// # Errors
/// - The first exception raised by `f` (or raised while converting its result to a list
///   of floats) is re-raised unchanged. Once that happens, `f` is not invoked again for
///   the rest of the solve.
/// - `BroydenError::DimensionMismatch` is raised as `ValueError`.
/// - Any other solver error is raised as `RuntimeError`.
///
/// All errors raised by the solver carry the error's `Display` text.
#[pyfunction]
#[pyo3(signature = (f, x0, xtol = 1e-8, ftol = 1e-8, max_iter = 100, refresh_every = 5, fd_step = 1e-7))]
// Each solver knob is a separate Python keyword argument, so the arity is intentional.
#[allow(clippy::too_many_arguments)]
fn broyden(
    py: Python<'_>,
    f: PyObject,
    x0: Vec<f64>,
    xtol: f64,
    ftol: f64,
    max_iter: usize,
    refresh_every: usize,
    fd_step: f64,
) -> PyResult<Vec<f64>> {
    let cfg = crate::numerics::broyden::BroydenConfig {
        xtol,
        ftol,
        max_iter,
        refresh_every,
        fd_step,
    };
    let err_cache: RefCell<Option<PyErr>> = RefCell::new(None);
    let result = crate::numerics::broyden::broyden(
        |x| {
            // After the first failure, stop calling back into Python. The same guard is
            // used by `call_scalar_callback`; later exceptions would otherwise be discarded.
            if err_cache.borrow().is_some() {
                return vec![f64::NAN; x.len()];
            }
            match f
                .call1(py, (x.to_vec(),))
                .and_then(|r| r.extract::<Vec<f64>>(py))
            {
                Ok(v) => v,
                Err(e) => {
                    *err_cache.borrow_mut() = Some(e);
                    // NaN residual lets the solver wind down; the cached error wins below.
                    vec![f64::NAN; x.len()]
                }
            }
        },
        &x0,
        cfg,
    );
    if let Some(e) = err_cache.into_inner() {
        return Err(e);
    }
    result.map_err(|e| match e {
        crate::numerics::broyden::BroydenError::DimensionMismatch { .. } => {
            PyValueError::new_err(e.to_string())
        }
        _ => PyRuntimeError::new_err(e.to_string()),
    })
}
/// Evaluate the Python scalar callback `f` at `x` on behalf of a Rust solver.
///
/// Returns the callback's result converted to `f64`. If the call or the conversion
/// fails, the `PyErr` is stored in `err_cache` and `NaN` is returned. Once an error
/// has been stored, `f` is not called again and every later call returns `NaN`
/// immediately, so `err_cache` always holds the *first* failure.
fn call_scalar_callback(
    py: Python<'_>,
    f: &PyObject,
    x: f64,
    err_cache: &RefCell<Option<PyErr>>,
) -> f64 {
    let already_failed = err_cache.borrow().is_some();
    if already_failed {
        return f64::NAN;
    }
    f.call1(py, (x,))
        .and_then(|value| value.extract::<f64>(py))
        .unwrap_or_else(|err| {
            *err_cache.borrow_mut() = Some(err);
            f64::NAN
        })
}
/// Python wrapper around `crate::numerics::utils::sum_frac_residual`.
///
/// Judging by the name, this is presumably the deviation of `sum(xs)` from 1 for a set of
/// fractions; confirm the exact definition against the utility.
#[pyfunction]
fn sum_frac_residual(xs: Vec<f64>) -> f64 {
    crate::numerics::utils::sum_frac_residual(xs.as_slice())
}
/// Python wrapper around `crate::numerics::utils::norm_l1` (by name, the L1 norm of `xs`).
#[pyfunction]
fn norm_l1(xs: Vec<f64>) -> f64 {
    crate::numerics::utils::norm_l1(xs.as_slice())
}
/// Python wrapper around `crate::numerics::utils::norm_l2` (by name, the Euclidean norm of `xs`).
#[pyfunction]
fn norm_l2(xs: Vec<f64>) -> f64 {
    crate::numerics::utils::norm_l2(xs.as_slice())
}
/// Python wrapper around `crate::numerics::utils::norm_linf` (by name, the max-norm of `xs`).
#[pyfunction]
fn norm_linf(xs: Vec<f64>) -> f64 {
    crate::numerics::utils::norm_linf(xs.as_slice())
}
/// Initializer for the `_engine` Python extension module.
///
/// Registers the free functions and then the model classes. Any registration
/// failure aborts initialization with the corresponding `PyErr`.
#[pymodule]
fn _engine(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
    // Module-level functions, registered in this order.
    let functions = [
        wrap_pyfunction!(version, m)?,
        wrap_pyfunction!(default_units_toml, m)?,
        wrap_pyfunction!(solve_cubic, m)?,
        wrap_pyfunction!(brent, m)?,
        wrap_pyfunction!(illinois, m)?,
        wrap_pyfunction!(halley, m)?,
        wrap_pyfunction!(broyden, m)?,
        wrap_pyfunction!(sum_frac_residual, m)?,
        wrap_pyfunction!(norm_l1, m)?,
        wrap_pyfunction!(norm_l2, m)?,
        wrap_pyfunction!(norm_linf, m)?,
    ];
    for function in functions {
        m.add_function(function)?;
    }

    // Model classes exposed as Python types.
    m.add_class::<crate::eos::CubicEos>()?;
    m.add_class::<crate::activity::ActivityModel>()?;
    m.add_class::<crate::mixing::MixingRule>()?;
    m.add_class::<crate::saturation::SatPressureModel>()?;
    Ok(())
}