use std::{ops::ControlFlow, sync::Arc};
use ganesh::{
algorithms::{
gradient::GradientStatus,
gradient_free::GradientFreeStatus,
mcmc::{integrated_autocorrelation_times, EnsembleStatus},
particles::SwarmStatus,
},
python::{PyEnsembleStatus, PyGradientFreeStatus, PyGradientStatus, PySwarmStatus},
traits::{Algorithm, Observer, Terminator},
DVector,
};
use laddu_core::LadduError;
use laddu_extensions::optimize::MaybeThreadPool;
use numpy::{PyArray1, ToPyArray};
use pyo3::prelude::*;
// Python-facing counterpart of `std::ops::ControlFlow<()>`, exported to Python
// under the name `ControlFlow` in module `laddu`. Termination callbacks return
// one of these values, which is then converted via `From<PyControlFlow>`.
// (Written as plain comments so the Python-visible docstring is left unchanged.)
#[pyclass(eq, eq_int, name = "ControlFlow", module = "laddu", from_py_object)]
#[derive(PartialEq, Clone)]
pub enum PyControlFlow {
// Converted to `ControlFlow::Continue(())`.
Continue = 0,
// Converted to `ControlFlow::Break(())`.
Break = 1,
}
/// Converts the Python-facing [`PyControlFlow`] into the standard library's
/// [`ControlFlow`] with unit payloads.
impl From<PyControlFlow> for ControlFlow<()> {
    /// Maps each variant onto the identically named [`ControlFlow`] variant.
    fn from(flow: PyControlFlow) -> Self {
        match flow {
            PyControlFlow::Break => Self::Break(()),
            PyControlFlow::Continue => Self::Continue(()),
        }
    }
}
/// Reference-counted handle to an arbitrary Python object used as an observer
/// for minimization algorithms.
///
/// The `Observer` implementations in this file call the object's
/// `observe(step, status)` method. Nothing checks that the method exists when
/// the wrapper is created. Cloning only clones the inner `Arc`.
#[derive(Clone)]
pub struct MinimizationObserver(Arc<Py<PyAny>>);
impl<'a, 'py> FromPyObject<'a, 'py> for MinimizationObserver {
type Error = PyErr;
fn extract(ob: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
Ok(MinimizationObserver(Arc::new(ob.to_owned().unbind())))
}
}
/// Forwards the progress of gradient-based algorithms to a Python observer object.
impl<A, P, C> Observer<A, P, GradientStatus, MaybeThreadPool, LadduError, C>
    for MinimizationObserver
where
    A: Algorithm<P, GradientStatus, MaybeThreadPool, LadduError, Config = C>,
{
    /// Calls `observe(current_step, status)` on the wrapped Python object.
    ///
    /// The status handed to Python is a converted clone of `status`. If the
    /// Python call raises, the exception is printed and then discarded, so the
    /// algorithm keeps running.
    ///
    /// # Panics
    ///
    /// Panics if the Python status object cannot be created.
    fn observe(
        &mut self,
        current_step: usize,
        _algorithm: &A,
        _problem: &P,
        status: &GradientStatus,
        _args: &MaybeThreadPool,
        _config: &C,
    ) {
        Python::attach(|py| {
            let py_status = Py::new(py, PyGradientStatus::from(status.clone()))
                .expect("ganesh gradient status should construct");
            let outcome = self
                .0
                .bind(py)
                .call_method1("observe", (current_step, py_status));
            // Observer failures are reported but deliberately not propagated.
            if let Err(err) = outcome {
                err.print(py);
            }
        })
    }
}
/// Forwards the progress of gradient-free algorithms to a Python observer object.
impl<A, P, C> Observer<A, P, GradientFreeStatus, MaybeThreadPool, LadduError, C>
    for MinimizationObserver
where
    A: Algorithm<P, GradientFreeStatus, MaybeThreadPool, LadduError, Config = C>,
{
    /// Calls `observe(current_step, status)` on the wrapped Python object.
    ///
    /// The status handed to Python is a converted clone of `status`. If the
    /// Python call raises, the exception is printed and then discarded, so the
    /// algorithm keeps running.
    ///
    /// # Panics
    ///
    /// Panics if the Python status object cannot be created.
    fn observe(
        &mut self,
        current_step: usize,
        _algorithm: &A,
        _problem: &P,
        status: &GradientFreeStatus,
        _args: &MaybeThreadPool,
        _config: &C,
    ) {
        Python::attach(|py| {
            let py_status = Py::new(py, PyGradientFreeStatus::from(status.clone()))
                .expect("ganesh gradient-free status should construct");
            let outcome = self
                .0
                .bind(py)
                .call_method1("observe", (current_step, py_status));
            // Observer failures are reported but deliberately not propagated.
            if let Err(err) = outcome {
                err.print(py);
            }
        })
    }
}
/// Forwards the progress of swarm-based algorithms to a Python observer object.
impl<A, P, C> Observer<A, P, SwarmStatus, MaybeThreadPool, LadduError, C> for MinimizationObserver
where
    A: Algorithm<P, SwarmStatus, MaybeThreadPool, LadduError, Config = C>,
{
    /// Calls `observe(current_step, status)` on the wrapped Python object.
    ///
    /// The status handed to Python is a converted clone of `status`. If the
    /// Python call raises, the exception is printed and then discarded, so the
    /// algorithm keeps running.
    ///
    /// # Panics
    ///
    /// Panics if the Python status object cannot be created. This is the only
    /// failure that can abort the call, so the panic message names it directly
    /// rather than blaming the `observe` callback.
    fn observe(
        &mut self,
        current_step: usize,
        _algorithm: &A,
        _problem: &P,
        status: &SwarmStatus,
        _args: &MaybeThreadPool,
        _config: &C,
    ) {
        Python::attach(|py| {
            let py_status = Py::new(py, PySwarmStatus::from(status.clone()))
                .expect("ganesh swarm status should construct");
            // Observer failures are reported but deliberately not propagated.
            if let Err(err) = self
                .0
                .bind(py)
                .call_method1("observe", (current_step, py_status))
            {
                err.print(py);
            }
        })
    }
}
/// Reference-counted handle to an arbitrary Python object used as a
/// termination criterion for minimization algorithms.
///
/// The `Terminator` implementations in this file call the object's
/// `check_for_termination(step, status)` method. Nothing checks that the method
/// exists when the wrapper is created. Cloning only clones the inner `Arc`.
#[derive(Clone)]
pub struct MinimizationTerminator(Arc<Py<PyAny>>);
impl<'a, 'py> FromPyObject<'a, 'py> for MinimizationTerminator {
type Error = PyErr;
fn extract(ob: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
Ok(MinimizationTerminator(Arc::new(ob.to_owned().unbind())))
}
}
/// Lets a Python object decide when a gradient-based algorithm should stop.
impl<A, P, C> Terminator<A, P, GradientStatus, MaybeThreadPool, LadduError, C>
    for MinimizationTerminator
where
    A: Algorithm<P, GradientStatus, MaybeThreadPool, LadduError, Config = C>,
{
    /// Calls `check_for_termination(current_step, status)` on the wrapped Python
    /// object and converts the returned `ControlFlow` value.
    ///
    /// Python receives a converted clone of `status`; `status` itself is never
    /// written to here.
    ///
    /// # Panics
    ///
    /// Panics if the status object cannot be created, if the Python call
    /// raises, or if the returned value cannot be extracted as a
    /// [`PyControlFlow`].
    fn check_for_termination(
        &mut self,
        current_step: usize,
        _algorithm: &mut A,
        _problem: &P,
        status: &mut GradientStatus,
        _args: &MaybeThreadPool,
        _config: &C,
    ) -> ControlFlow<()> {
        let decision = Python::attach(|py| -> PyResult<PyControlFlow> {
            let py_status = Py::new(py, PyGradientStatus::from(status.clone()))?;
            let callback = self.0.bind(py);
            let reply =
                callback.call_method1("check_for_termination", (current_step, py_status))?;
            let flow: PyControlFlow = reply.extract()?;
            Ok(flow)
        });
        decision
            .expect("call to 'check_for_termination' has failed!")
            .into()
    }
}
/// Lets a Python object decide when a gradient-free algorithm should stop.
impl<A, P, C> Terminator<A, P, GradientFreeStatus, MaybeThreadPool, LadduError, C>
    for MinimizationTerminator
where
    A: Algorithm<P, GradientFreeStatus, MaybeThreadPool, LadduError, Config = C>,
{
    /// Calls `check_for_termination(current_step, status)` on the wrapped Python
    /// object and converts the returned `ControlFlow` value.
    ///
    /// Python receives a converted clone of `status`; `status` itself is never
    /// written to here.
    ///
    /// # Panics
    ///
    /// Panics if the status object cannot be created, if the Python call
    /// raises, or if the returned value cannot be extracted as a
    /// [`PyControlFlow`].
    fn check_for_termination(
        &mut self,
        current_step: usize,
        _algorithm: &mut A,
        _problem: &P,
        status: &mut GradientFreeStatus,
        _args: &MaybeThreadPool,
        _config: &C,
    ) -> ControlFlow<()> {
        let decision = Python::attach(|py| -> PyResult<PyControlFlow> {
            let py_status = Py::new(py, PyGradientFreeStatus::from(status.clone()))?;
            let callback = self.0.bind(py);
            let reply =
                callback.call_method1("check_for_termination", (current_step, py_status))?;
            let flow: PyControlFlow = reply.extract()?;
            Ok(flow)
        });
        decision
            .expect("call to 'check_for_termination' has failed!")
            .into()
    }
}
/// Lets a Python object decide when a swarm-based algorithm should stop.
impl<A, P, C> Terminator<A, P, SwarmStatus, MaybeThreadPool, LadduError, C>
    for MinimizationTerminator
where
    A: Algorithm<P, SwarmStatus, MaybeThreadPool, LadduError, Config = C>,
{
    /// Calls `check_for_termination(current_step, status)` on the wrapped Python
    /// object and converts the returned `ControlFlow` value.
    ///
    /// Python receives a converted clone of `status`; `status` itself is never
    /// written to here.
    ///
    /// # Panics
    ///
    /// Panics if the status object cannot be created, if the Python call
    /// raises, or if the returned value cannot be extracted as a
    /// [`PyControlFlow`].
    fn check_for_termination(
        &mut self,
        current_step: usize,
        _algorithm: &mut A,
        _problem: &P,
        status: &mut SwarmStatus,
        _args: &MaybeThreadPool,
        _config: &C,
    ) -> ControlFlow<()> {
        let decision = Python::attach(|py| -> PyResult<PyControlFlow> {
            let py_status = Py::new(py, PySwarmStatus::from(status.clone()))?;
            let callback = self.0.bind(py);
            let reply =
                callback.call_method1("check_for_termination", (current_step, py_status))?;
            let flow: PyControlFlow = reply.extract()?;
            Ok(flow)
        });
        decision
            .expect("call to 'check_for_termination' has failed!")
            .into()
    }
}
/// Reference-counted handle to an arbitrary Python object used as an observer
/// for ensemble (MCMC) algorithms.
///
/// The `Observer` implementation in this file calls the object's
/// `observe(step, status)` method. Nothing checks that the method exists when
/// the wrapper is created. Cloning only clones the inner `Arc`.
#[derive(Clone)]
pub struct MCMCObserver(Arc<Py<PyAny>>);
impl<'a, 'py> FromPyObject<'a, 'py> for MCMCObserver {
type Error = PyErr;
fn extract(ob: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
Ok(MCMCObserver(Arc::new(ob.to_owned().unbind())))
}
}
/// Forwards the progress of ensemble (MCMC) algorithms to a Python observer object.
impl<A, P, C> Observer<A, P, EnsembleStatus, MaybeThreadPool, LadduError, C> for MCMCObserver
where
    A: Algorithm<P, EnsembleStatus, MaybeThreadPool, LadduError, Config = C>,
{
    /// Calls `observe(current_step, status)` on the wrapped Python object.
    ///
    /// The status handed to Python is a converted clone of `status`. If the
    /// Python call raises, the exception is printed and then discarded, so the
    /// algorithm keeps running.
    ///
    /// # Panics
    ///
    /// Panics if the Python status object cannot be created. This is the only
    /// failure that can abort the call, so the panic message names it directly
    /// rather than blaming the `observe` callback.
    fn observe(
        &mut self,
        current_step: usize,
        _algorithm: &A,
        _problem: &P,
        status: &EnsembleStatus,
        _args: &MaybeThreadPool,
        _config: &C,
    ) {
        Python::attach(|py| {
            let py_status = Py::new(py, PyEnsembleStatus::from(status.clone()))
                .expect("ganesh ensemble status should construct");
            // Observer failures are reported but deliberately not propagated.
            if let Err(err) = self
                .0
                .bind(py)
                .call_method1("observe", (current_step, py_status))
            {
                err.print(py);
            }
        })
    }
}
/// Reference-counted handle to an arbitrary Python object used as a
/// termination criterion for ensemble (MCMC) algorithms.
///
/// The `Terminator` implementation in this file calls the object's
/// `check_for_termination(step, status)` method. Nothing checks that the method
/// exists when the wrapper is created. Cloning only clones the inner `Arc`.
#[derive(Clone)]
pub struct MCMCTerminator(Arc<Py<PyAny>>);
impl<'a, 'py> FromPyObject<'a, 'py> for MCMCTerminator {
type Error = PyErr;
fn extract(ob: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
Ok(MCMCTerminator(Arc::new(ob.to_owned().unbind())))
}
}
/// Lets a Python object decide when an ensemble (MCMC) algorithm should stop.
impl<A, P, C> Terminator<A, P, EnsembleStatus, MaybeThreadPool, LadduError, C> for MCMCTerminator
where
    A: Algorithm<P, EnsembleStatus, MaybeThreadPool, LadduError, Config = C>,
{
    /// Calls `check_for_termination(current_step, status)` on the wrapped Python
    /// object and converts the returned `ControlFlow` value.
    ///
    /// Python receives a converted clone of `status`; `status` itself is never
    /// written to here.
    ///
    /// # Panics
    ///
    /// Panics if the status object cannot be created, if the Python call
    /// raises, or if the returned value cannot be extracted as a
    /// [`PyControlFlow`].
    fn check_for_termination(
        &mut self,
        current_step: usize,
        _algorithm: &mut A,
        _problem: &P,
        status: &mut EnsembleStatus,
        _args: &MaybeThreadPool,
        _config: &C,
    ) -> ControlFlow<()> {
        let decision = Python::attach(|py| -> PyResult<PyControlFlow> {
            let py_status = Py::new(py, PyEnsembleStatus::from(status.clone()))?;
            let callback = self.0.bind(py);
            let reply =
                callback.call_method1("check_for_termination", (current_step, py_status))?;
            let flow: PyControlFlow = reply.extract()?;
            Ok(flow)
        });
        decision
            .expect("call to 'check_for_termination' has failed!")
            .into()
    }
}
// Python entry point `integrated_autocorrelation_times(samples, *, c=None)`.
//
// Converts every innermost list of floats in `samples` into a `DVector<f64>`,
// passes the resulting nested vectors and the optional `c` to ganesh's
// `integrated_autocorrelation_times`, and copies the result into a new
// one-dimensional NumPy array.
// NOTE(review): the outer/inner nesting presumably means walkers/steps; confirm
// against the ganesh documentation.
// (Written as plain comments so the Python-visible docstring is left unchanged.)
#[pyfunction(name = "integrated_autocorrelation_times")]
#[pyo3(signature = (samples, *, c=None))]
pub fn py_integrated_autocorrelation_times<'py>(
    py: Python<'py>,
    samples: Vec<Vec<Vec<f64>>>,
    c: Option<f64>,
) -> Bound<'py, PyArray1<f64>> {
    let chains: Vec<Vec<DVector<f64>>> = samples
        .into_iter()
        .map(|chain| {
            chain
                .into_iter()
                .map(DVector::from_vec)
                .collect::<Vec<_>>()
        })
        .collect();
    let taus = integrated_autocorrelation_times(chains, c);
    taus.as_slice().to_pyarray(py)
}