use std::cell::RefCell;
use pyo3::exceptions::{PyRuntimeError, PyValueError};
use pyo3::prelude::*;
use crate::eos::{
alpha as eos_alpha_rs, d_alpha_d_tr as eos_d_alpha_rs, family_constants, h_departure_rt,
ln_phi_pure, s_departure_r, z_factor, CubicEos, EosError, PhaseId,
};
use crate::saturation::{d_psat_dt_antoine, psat_antoine, SatError};
use crate::types::Component;
use crate::virial::{
b_mix as virial_b_mix, h_departure_rt_virial, ln_phi_mix_virial, ln_phi_pure_virial,
pitzer_b, pitzer_b0, pitzer_b1, pitzer_d_b_d_t, s_departure_r_virial, z_factor_virial,
};
/// Return the version string of this compiled extension.
///
/// The value is fixed at compile time from the crate's `CARGO_PKG_VERSION`.
#[pyfunction]
fn version() -> &'static str {
    const CRATE_VERSION: &str = env!("CARGO_PKG_VERSION");
    CRATE_VERSION
}
#[pyfunction]
fn default_units_toml() -> &'static str {
vle_units::default_units_toml()
}
/// Return the real roots of a cubic with coefficients `a`, `b`, `c`, `d`.
///
/// Delegates to `crate::numerics::cubic::solve_real`. The coefficient order is
/// presumably highest degree first (`a*x^3 + b*x^2 + c*x + d`); confirm against
/// `solve_real`.
///
/// # Errors
/// Any error reported by the solver is raised as `ValueError` with its message.
#[pyfunction]
fn solve_cubic(a: f64, b: f64, c: f64, d: f64) -> PyResult<Vec<f64>> {
    match crate::numerics::cubic::solve_real(a, b, c, d) {
        Ok(roots) => Ok(roots),
        Err(err) => Err(PyValueError::new_err(err.to_string())),
    }
}
/// Find a root of the Python callable `f` on `[a, b]` with the crate's Brent solver.
///
/// `f` is called with one float and must return something extractable as a float.
/// The iteration is done by `crate::numerics::root_finding::brent`, with
/// tolerance `xtol` (default `1e-9`) and at most `max_iter` iterations (default `100`).
///
/// # Errors
/// * If `f` raises, or returns a non-float, that first Python exception is
///   re-raised unchanged. It takes precedence over any solver outcome.
/// * Otherwise, a solver failure is raised as `RuntimeError`.
#[pyfunction]
#[pyo3(signature = (f, a, b, xtol = 1e-9, max_iter = 100))]
fn brent(py: Python<'_>, f: PyObject, a: f64, b: f64, xtol: f64, max_iter: usize) -> PyResult<f64> {
    let first_py_error: RefCell<Option<PyErr>> = RefCell::new(None);
    let outcome = crate::numerics::root_finding::brent(
        |x| call_scalar_callback(py, &f, x, &first_py_error),
        a,
        b,
        xtol,
        max_iter,
    );
    // After a Python error the callback only yields NaN, so the solver's own
    // result is not informative; report the Python exception instead.
    match first_py_error.into_inner() {
        Some(py_err) => Err(py_err),
        None => outcome.map_err(|solver_err| PyRuntimeError::new_err(solver_err.to_string())),
    }
}
/// Find a root of the Python callable `f` on `[a, b]` with the crate's Illinois solver.
///
/// `f` is called with one float and must return something extractable as a float.
/// The iteration is done by `crate::numerics::root_finding::illinois`, with
/// tolerance `xtol` (default `1e-9`) and at most `max_iter` iterations (default `100`).
///
/// # Errors
/// * If `f` raises, or returns a non-float, that first Python exception is
///   re-raised unchanged. It takes precedence over any solver outcome.
/// * Otherwise, a solver failure is raised as `RuntimeError`.
#[pyfunction]
#[pyo3(signature = (f, a, b, xtol = 1e-9, max_iter = 100))]
fn illinois(
    py: Python<'_>,
    f: PyObject,
    a: f64,
    b: f64,
    xtol: f64,
    max_iter: usize,
) -> PyResult<f64> {
    let callback_error = RefCell::new(None::<PyErr>);
    let solved = crate::numerics::root_finding::illinois(
        |x| call_scalar_callback(py, &f, x, &callback_error),
        a,
        b,
        xtol,
        max_iter,
    );
    if let Some(py_err) = callback_error.into_inner() {
        Err(py_err)
    } else {
        solved.map_err(|err| PyRuntimeError::new_err(err.to_string()))
    }
}
/// Solve `f(x) = 0` from the starting point `x0` with the crate's Halley iteration.
///
/// `f_and_derivs` is called with one float and must return a 3-tuple of floats.
/// By the method's name these are presumably `(f(x), f'(x), f''(x))`; confirm the
/// order against `crate::numerics::halley::halley`. Defaults: `xtol = 1e-12`,
/// `max_iter = 50`.
///
/// # Errors
/// * The first exception raised by `f_and_derivs`, or a failure to extract the
///   3-tuple, is re-raised unchanged. After that point the callback is not invoked
///   again; the solver only receives NaNs until it stops.
/// * Otherwise, a solver failure is raised as `RuntimeError`.
#[pyfunction]
#[pyo3(signature = (f_and_derivs, x0, xtol = 1e-12, max_iter = 50))]
fn halley(
    py: Python<'_>,
    f_and_derivs: PyObject,
    x0: f64,
    xtol: f64,
    max_iter: usize,
) -> PyResult<f64> {
    const NAN3: (f64, f64, f64) = (f64::NAN, f64::NAN, f64::NAN);
    let err_cache: RefCell<Option<PyErr>> = RefCell::new(None);
    let result = crate::numerics::halley::halley(
        |x| {
            // Once Python has raised, stop calling back into it, as
            // `call_scalar_callback` does. This avoids repeating the callback's
            // side effects and doing pointless work; only the first error is reported.
            if err_cache.borrow().is_some() {
                return NAN3;
            }
            match f_and_derivs
                .call1(py, (x,))
                .and_then(|r| r.extract::<(f64, f64, f64)>(py))
            {
                Ok(triple) => triple,
                Err(e) => {
                    *err_cache.borrow_mut() = Some(e);
                    NAN3
                }
            }
        },
        x0,
        xtol,
        max_iter,
    );
    // A cached Python error explains any solver failure, so it wins.
    if let Some(e) = err_cache.into_inner() {
        return Err(e);
    }
    result.map_err(|e| PyRuntimeError::new_err(e.to_string()))
}
/// Solve the vector system `f(x) = 0` from `x0` with the crate's Broyden solver.
///
/// `f` is called with a list of floats and must return a list of floats. The
/// options are forwarded unchanged in a `BroydenConfig`: `xtol` (default `1e-8`),
/// `ftol` (default `1e-8`), `max_iter` (default `100`), `refresh_every`
/// (default `5`) and `fd_step` (default `1e-7`). By their names, `refresh_every`
/// and `fd_step` presumably control periodic finite-difference Jacobian refreshes;
/// confirm in `crate::numerics::broyden`.
///
/// # Errors
/// * The first exception raised by `f`, or a failure to extract a float list, is
///   re-raised unchanged. After that point `f` is not invoked again; the solver
///   only receives NaN vectors.
/// * `BroydenError::DimensionMismatch` is raised as `ValueError`.
/// * Any other solver failure is raised as `RuntimeError`.
#[pyfunction]
#[pyo3(signature = (f, x0, xtol = 1e-8, ftol = 1e-8, max_iter = 100, refresh_every = 5, fd_step = 1e-7))]
// The Python keyword signature needs every solver option as its own argument.
#[allow(clippy::too_many_arguments)]
fn broyden(
    py: Python<'_>,
    f: PyObject,
    x0: Vec<f64>,
    xtol: f64,
    ftol: f64,
    max_iter: usize,
    refresh_every: usize,
    fd_step: f64,
) -> PyResult<Vec<f64>> {
    let cfg = crate::numerics::broyden::BroydenConfig {
        xtol,
        ftol,
        max_iter,
        refresh_every,
        fd_step,
    };
    let err_cache: RefCell<Option<PyErr>> = RefCell::new(None);
    let result = crate::numerics::broyden::broyden(
        |x| {
            // Once Python has raised, stop calling back into it, as
            // `call_scalar_callback` does. Only the first error is reported.
            if err_cache.borrow().is_some() {
                return vec![f64::NAN; x.len()];
            }
            match f
                .call1(py, (x.to_vec(),))
                .and_then(|r| r.extract::<Vec<f64>>(py))
            {
                Ok(v) => v,
                Err(e) => {
                    *err_cache.borrow_mut() = Some(e);
                    vec![f64::NAN; x.len()]
                }
            }
        },
        &x0,
        cfg,
    );
    // A cached Python error explains any solver failure, so it wins.
    if let Some(e) = err_cache.into_inner() {
        return Err(e);
    }
    result.map_err(|e| match e {
        crate::numerics::broyden::BroydenError::DimensionMismatch { .. } => {
            PyValueError::new_err(e.to_string())
        }
        _ => PyRuntimeError::new_err(e.to_string()),
    })
}
/// Evaluate a Python scalar callback `f(x)` on behalf of a Rust solver.
///
/// Returns the callback's result converted to `f64`. If calling `f` raises, or its
/// result cannot be extracted as `f64`, the `PyErr` is stored in `err_cache` and
/// `NaN` is returned. Once `err_cache` holds an error, `f` is not called again and
/// every later evaluation returns `NaN`, so the first error is the one kept.
fn call_scalar_callback(
    py: Python<'_>,
    f: &PyObject,
    x: f64,
    err_cache: &RefCell<Option<PyErr>>,
) -> f64 {
    let already_failed = err_cache.borrow().is_some();
    if already_failed {
        return f64::NAN;
    }
    f.call1(py, (x,))
        .and_then(|ret| ret.extract::<f64>(py))
        .unwrap_or_else(|py_err| {
            *err_cache.borrow_mut() = Some(py_err);
            f64::NAN
        })
}
#[pyfunction]
fn sum_frac_residual(xs: Vec<f64>) -> f64 {
crate::numerics::utils::sum_frac_residual(&xs)
}
#[pyfunction]
fn norm_l1(xs: Vec<f64>) -> f64 {
crate::numerics::utils::norm_l1(&xs)
}
#[pyfunction]
fn norm_l2(xs: Vec<f64>) -> f64 {
crate::numerics::utils::norm_l2(&xs)
}
#[pyfunction]
fn norm_linf(xs: Vec<f64>) -> f64 {
crate::numerics::utils::norm_linf(&xs)
}
/// Build a [`Component`] that sets only the `tc`, `pc` and `omega` fields.
///
/// By the usual naming these are critical temperature, critical pressure and
/// acentric factor. Every other field keeps its `Component::default()` value.
fn comp_for_eos(tc: f64, pc: f64, omega: f64) -> Component {
    let defaults = Component::default();
    Component {
        omega,
        pc,
        tc,
        ..defaults
    }
}
/// Parse a caller-supplied phase label into a [`PhaseId`].
///
/// Matching is ASCII case-insensitive. `"vapor"`, `"v"` and `"gas"` map to
/// `PhaseId::Vapor`; `"liquid"` and `"l"` map to `PhaseId::Liquid`.
///
/// # Errors
/// Returns a Python `ValueError` for any other label. The message echoes the
/// label exactly as the caller passed it, not a normalised copy.
fn phase_from_str(s: &str) -> PyResult<PhaseId> {
    const VAPOR_ALIASES: &[&str] = &["vapor", "v", "gas"];
    const LIQUID_ALIASES: &[&str] = &["liquid", "l"];
    // Compare in place instead of allocating a lowercased copy of the input.
    let matches_any =
        |aliases: &[&str]| aliases.iter().any(|&alias| s.eq_ignore_ascii_case(alias));
    if matches_any(VAPOR_ALIASES) {
        Ok(PhaseId::Vapor)
    } else if matches_any(LIQUID_ALIASES) {
        Ok(PhaseId::Liquid)
    } else {
        Err(PyValueError::new_err(format!(
            "phase must be 'vapor' or 'liquid' (got {s:?})"
        )))
    }
}
/// Convert an [`EosError`] into the Python exception raised by the EOS wrappers.
///
/// `EosError::NotImplemented` becomes `NotImplementedError`; every other variant
/// becomes `RuntimeError`. The message is the error's `Display` text.
fn map_eos_err(e: EosError) -> PyErr {
    let message = e.to_string();
    if matches!(e, EosError::NotImplemented(_)) {
        pyo3::exceptions::PyNotImplementedError::new_err(message)
    } else {
        PyRuntimeError::new_err(message)
    }
}
/// Cubic-EOS alpha function at reduced temperature `tr` for acentric factor `omega`.
///
/// `tc` and `pc` are set to 1.0 placeholders because only `tr` and `omega` are
/// supplied. This presumes `eos::alpha` ignores them; confirm there.
#[pyfunction]
fn eos_alpha(eos: CubicEos, tr: f64, omega: f64) -> f64 {
    let placeholder = comp_for_eos(1.0, 1.0, omega);
    eos_alpha_rs(eos, tr, &placeholder)
}
/// Derivative of the cubic-EOS alpha function with respect to reduced temperature.
///
/// As in `eos_alpha`, `tc` and `pc` are 1.0 placeholders. This presumes
/// `eos::d_alpha_d_tr` depends only on `tr` and `omega`; confirm there.
#[pyfunction]
fn eos_d_alpha_d_tr(eos: CubicEos, tr: f64, omega: f64) -> f64 {
    let placeholder = comp_for_eos(1.0, 1.0, omega);
    eos_d_alpha_rs(eos, tr, &placeholder)
}
/// Return the family constants of `eos` as the tuple `(k1, k2, om_a, om_b)`.
///
/// The fields are read from `eos::family_constants`. `om_a`/`om_b` are presumably
/// the Omega_a/Omega_b coefficients; confirm there.
#[pyfunction]
fn eos_family_constants(eos: CubicEos) -> (f64, f64, f64, f64) {
    let constants = family_constants(eos);
    (constants.k1, constants.k2, constants.om_a, constants.om_b)
}
/// Compressibility factor of a pure component from a cubic EOS.
///
/// The component is built from `tc`, `pc` and `omega`. `phase` must be `"vapor"`,
/// `"v"`, `"gas"`, `"liquid"` or `"l"` (ASCII case-insensitive).
///
/// # Errors
/// * An unrecognised `phase` raises `ValueError`.
/// * EOS failures raise `NotImplementedError` for `EosError::NotImplemented` and
///   `RuntimeError` otherwise.
#[pyfunction]
fn eos_z_factor(
    eos: CubicEos,
    t: f64,
    p: f64,
    tc: f64,
    pc: f64,
    omega: f64,
    phase: &str,
) -> PyResult<f64> {
    let phase_id = phase_from_str(phase)?;
    z_factor(eos, t, p, &comp_for_eos(tc, pc, omega), phase_id).map_err(map_eos_err)
}
/// Natural log of the pure-component fugacity coefficient from a cubic EOS.
///
/// The component is built from `tc`, `pc` and `omega`. `phase` must be `"vapor"`,
/// `"v"`, `"gas"`, `"liquid"` or `"l"` (ASCII case-insensitive).
///
/// # Errors
/// * An unrecognised `phase` raises `ValueError`.
/// * EOS failures raise `NotImplementedError` for `EosError::NotImplemented` and
///   `RuntimeError` otherwise.
#[pyfunction]
fn eos_ln_phi_pure(
    eos: CubicEos,
    t: f64,
    p: f64,
    tc: f64,
    pc: f64,
    omega: f64,
    phase: &str,
) -> PyResult<f64> {
    let phase_id = phase_from_str(phase)?;
    ln_phi_pure(eos, t, p, &comp_for_eos(tc, pc, omega), phase_id).map_err(map_eos_err)
}
/// Enthalpy departure divided by RT for a pure component, from a cubic EOS.
///
/// The component is built from `tc`, `pc` and `omega`. `phase` must be `"vapor"`,
/// `"v"`, `"gas"`, `"liquid"` or `"l"` (ASCII case-insensitive).
///
/// # Errors
/// * An unrecognised `phase` raises `ValueError`.
/// * EOS failures raise `NotImplementedError` for `EosError::NotImplemented` and
///   `RuntimeError` otherwise.
#[pyfunction]
fn eos_h_departure_rt(
    eos: CubicEos,
    t: f64,
    p: f64,
    tc: f64,
    pc: f64,
    omega: f64,
    phase: &str,
) -> PyResult<f64> {
    let phase_id = phase_from_str(phase)?;
    h_departure_rt(eos, t, p, &comp_for_eos(tc, pc, omega), phase_id).map_err(map_eos_err)
}
/// Entropy departure divided by R for a pure component, from a cubic EOS.
///
/// The component is built from `tc`, `pc` and `omega`. `phase` must be `"vapor"`,
/// `"v"`, `"gas"`, `"liquid"` or `"l"` (ASCII case-insensitive).
///
/// # Errors
/// * An unrecognised `phase` raises `ValueError`.
/// * EOS failures raise `NotImplementedError` for `EosError::NotImplemented` and
///   `RuntimeError` otherwise.
#[pyfunction]
fn eos_s_departure_r(
    eos: CubicEos,
    t: f64,
    p: f64,
    tc: f64,
    pc: f64,
    omega: f64,
    phase: &str,
) -> PyResult<f64> {
    let phase_id = phase_from_str(phase)?;
    s_departure_r(eos, t, p, &comp_for_eos(tc, pc, omega), phase_id).map_err(map_eos_err)
}
/// Convert a [`SatError`] into the Python exception raised by the saturation wrappers.
///
/// `NotImplemented` becomes `NotImplementedError`. `BadCoefficients` and
/// `OutOfRange` both become `ValueError`. The message is the error's `Display` text.
fn map_sat_err(e: SatError) -> PyErr {
    let message = e.to_string();
    match e {
        SatError::NotImplemented(_) => pyo3::exceptions::PyNotImplementedError::new_err(message),
        SatError::BadCoefficients { .. } | SatError::OutOfRange(_) => {
            PyValueError::new_err(message)
        }
    }
}
/// Saturation pressure at temperature `t` from Antoine coefficients `coeffs`.
///
/// A `Component` with only `pc` and `psat_coeffs` set is passed to `psat_antoine`.
///
/// # Errors
/// `NotImplementedError` for `SatError::NotImplemented`; `ValueError` for bad
/// coefficients or an out-of-range input.
#[pyfunction]
fn antoine_psat(t: f64, pc: f64, coeffs: Vec<f64>) -> PyResult<f64> {
    let antoine_comp = Component {
        psat_coeffs: coeffs,
        pc,
        ..Default::default()
    };
    psat_antoine(&antoine_comp, t).map_err(map_sat_err)
}
/// Temperature derivative of the Antoine saturation pressure at `t`.
///
/// A `Component` with only `pc` and `psat_coeffs` set is passed to
/// `d_psat_dt_antoine`.
///
/// # Errors
/// `NotImplementedError` for `SatError::NotImplemented`; `ValueError` for bad
/// coefficients or an out-of-range input.
#[pyfunction]
fn antoine_d_psat_dt(t: f64, pc: f64, coeffs: Vec<f64>) -> PyResult<f64> {
    let antoine_comp = Component {
        psat_coeffs: coeffs,
        pc,
        ..Default::default()
    };
    d_psat_dt_antoine(&antoine_comp, t).map_err(map_sat_err)
}
/// Evaluate `virial::pitzer_b0` at reduced temperature `tr`.
///
/// By its name, this is the B0 term of the Pitzer second-virial correlation.
#[pyfunction]
fn virial_pitzer_b0(tr: f64) -> f64 {
    pitzer_b0(tr)
}
/// Evaluate `virial::pitzer_b1` at reduced temperature `tr`.
///
/// By its name, this is the B1 term of the Pitzer second-virial correlation.
#[pyfunction]
fn virial_pitzer_b1(tr: f64) -> f64 {
    pitzer_b1(tr)
}
/// Pure-component second virial coefficient at temperature `t`.
///
/// The component is built from `tc`, `pc` and `omega`, then `virial::pitzer_b`
/// is evaluated.
#[pyfunction]
fn virial_b_pure(tc: f64, pc: f64, omega: f64, t: f64) -> f64 {
    let component = comp_for_eos(tc, pc, omega);
    pitzer_b(&component, t)
}
/// Temperature derivative of the pure-component second virial coefficient at `t`.
///
/// The component is built from `tc`, `pc` and `omega`, then `virial::pitzer_d_b_d_t`
/// is evaluated.
#[pyfunction]
fn virial_d_b_d_t_pure(tc: f64, pc: f64, omega: f64, t: f64) -> f64 {
    let component = comp_for_eos(tc, pc, omega);
    pitzer_d_b_d_t(&component, t)
}
/// Compressibility factor of a pure component from the virial equation at `(t, p)`.
///
/// The component is built from `tc`, `pc` and `omega`.
#[pyfunction]
fn virial_z(tc: f64, pc: f64, omega: f64, t: f64, p: f64) -> f64 {
    let component = comp_for_eos(tc, pc, omega);
    z_factor_virial(&component, t, p)
}
/// Natural log of the pure-component fugacity coefficient from the virial equation.
///
/// The component is built from `tc`, `pc` and `omega` and evaluated at `(t, p)`.
#[pyfunction]
fn virial_ln_phi(tc: f64, pc: f64, omega: f64, t: f64, p: f64) -> f64 {
    let component = comp_for_eos(tc, pc, omega);
    ln_phi_pure_virial(&component, t, p)
}
/// Enthalpy departure divided by RT from the virial equation at `(t, p)`.
///
/// The component is built from `tc`, `pc` and `omega`.
#[pyfunction]
fn virial_h_dep_rt(tc: f64, pc: f64, omega: f64, t: f64, p: f64) -> f64 {
    let component = comp_for_eos(tc, pc, omega);
    h_departure_rt_virial(&component, t, p)
}
/// Entropy departure divided by R from the virial equation at `(t, p)`.
///
/// The component is built from `tc`, `pc` and `omega`.
#[pyfunction]
fn virial_s_dep_r(tc: f64, pc: f64, omega: f64, t: f64, p: f64) -> f64 {
    let component = comp_for_eos(tc, pc, omega);
    s_departure_r_virial(&component, t, p)
}
/// Mixture second virial coefficient at temperature `t`.
///
/// `tcs`, `pcs` and `omegas` describe one component per entry, and
/// `mole_fractions` gives that component's mole fraction. The four lists are
/// parallel and must have equal length.
///
/// # Errors
/// * Unequal list lengths raise `ValueError`; the message reports all four lengths.
/// * Errors from `virial::b_mix` are raised as `RuntimeError`.
#[pyfunction]
fn virial_b_mix_py(
    tcs: Vec<f64>,
    pcs: Vec<f64>,
    omegas: Vec<f64>,
    mole_fractions: Vec<f64>,
    t: f64,
) -> PyResult<f64> {
    let n = tcs.len();
    if pcs.len() != n || omegas.len() != n || mole_fractions.len() != n {
        return Err(PyValueError::new_err(format!(
            "tcs, pcs, omegas, mole_fractions must all have the same length \
             (got {}, {}, {}, {})",
            n,
            pcs.len(),
            omegas.len(),
            mole_fractions.len()
        )));
    }
    // Lengths are equal at this point, so zip pairs every component without truncation.
    let comps: Vec<Component> = tcs
        .iter()
        .zip(&pcs)
        .zip(&omegas)
        .map(|((&tc, &pc), &omega)| comp_for_eos(tc, pc, omega))
        .collect();
    virial_b_mix(&comps, &mole_fractions, t).map_err(|e| PyRuntimeError::new_err(e.to_string()))
}
/// Per-component log fugacity coefficients in a mixture from the virial equation.
///
/// `tcs`, `pcs` and `omegas` describe one component per entry, and
/// `mole_fractions` gives that component's mole fraction. The four lists are
/// parallel and must have equal length. The result comes from
/// `virial::ln_phi_mix_virial` evaluated at `(t, p)`.
///
/// # Errors
/// * Unequal list lengths raise `ValueError`; the message reports all four lengths.
/// * Errors from `ln_phi_mix_virial` are raised as `RuntimeError`.
#[pyfunction]
fn virial_ln_phi_mix(
    tcs: Vec<f64>,
    pcs: Vec<f64>,
    omegas: Vec<f64>,
    mole_fractions: Vec<f64>,
    t: f64,
    p: f64,
) -> PyResult<Vec<f64>> {
    let n = tcs.len();
    if pcs.len() != n || omegas.len() != n || mole_fractions.len() != n {
        return Err(PyValueError::new_err(format!(
            "tcs, pcs, omegas, mole_fractions must all have the same length \
             (got {}, {}, {}, {})",
            n,
            pcs.len(),
            omegas.len(),
            mole_fractions.len()
        )));
    }
    // Lengths are equal at this point, so zip pairs every component without truncation.
    let comps: Vec<Component> = tcs
        .iter()
        .zip(&pcs)
        .zip(&omegas)
        .map(|((&tc, &pc), &omega)| comp_for_eos(tc, pc, omega))
        .collect();
    ln_phi_mix_virial(&comps, &mole_fractions, t, p)
        .map_err(|e| PyRuntimeError::new_err(e.to_string()))
}
#[pymodule]
fn _engine(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(version, m)?)?;
m.add_function(wrap_pyfunction!(default_units_toml, m)?)?;
m.add_function(wrap_pyfunction!(solve_cubic, m)?)?;
m.add_function(wrap_pyfunction!(brent, m)?)?;
m.add_function(wrap_pyfunction!(illinois, m)?)?;
m.add_function(wrap_pyfunction!(halley, m)?)?;
m.add_function(wrap_pyfunction!(broyden, m)?)?;
m.add_function(wrap_pyfunction!(sum_frac_residual, m)?)?;
m.add_function(wrap_pyfunction!(norm_l1, m)?)?;
m.add_function(wrap_pyfunction!(norm_l2, m)?)?;
m.add_function(wrap_pyfunction!(norm_linf, m)?)?;
m.add_class::<crate::eos::CubicEos>()?;
m.add_class::<crate::activity::ActivityModel>()?;
m.add_class::<crate::mixing::MixingRule>()?;
m.add_class::<crate::saturation::SatPressureModel>()?;
m.add_class::<crate::eos::PhaseId>()?;
m.add_function(wrap_pyfunction!(eos_alpha, m)?)?;
m.add_function(wrap_pyfunction!(eos_d_alpha_d_tr, m)?)?;
m.add_function(wrap_pyfunction!(eos_family_constants, m)?)?;
m.add_function(wrap_pyfunction!(eos_z_factor, m)?)?;
m.add_function(wrap_pyfunction!(eos_ln_phi_pure, m)?)?;
m.add_function(wrap_pyfunction!(eos_h_departure_rt, m)?)?;
m.add_function(wrap_pyfunction!(eos_s_departure_r, m)?)?;
m.add_function(wrap_pyfunction!(antoine_psat, m)?)?;
m.add_function(wrap_pyfunction!(antoine_d_psat_dt, m)?)?;
m.add_function(wrap_pyfunction!(virial_pitzer_b0, m)?)?;
m.add_function(wrap_pyfunction!(virial_pitzer_b1, m)?)?;
m.add_function(wrap_pyfunction!(virial_b_pure, m)?)?;
m.add_function(wrap_pyfunction!(virial_d_b_d_t_pure, m)?)?;
m.add_function(wrap_pyfunction!(virial_z, m)?)?;
m.add_function(wrap_pyfunction!(virial_ln_phi, m)?)?;
m.add_function(wrap_pyfunction!(virial_h_dep_rt, m)?)?;
m.add_function(wrap_pyfunction!(virial_s_dep_r, m)?)?;
m.add_function(wrap_pyfunction!(virial_b_mix_py, m)?)?;
m.add_function(wrap_pyfunction!(virial_ln_phi_mix, m)?)?;
Ok(())
}