Skip to main content

laddu_python/extensions/
callbacks.rs

1use std::{ops::ControlFlow, sync::Arc};
2
3use ganesh::{
4    algorithms::{
5        gradient::GradientStatus,
6        gradient_free::GradientFreeStatus,
7        mcmc::{integrated_autocorrelation_times, EnsembleStatus},
8        particles::SwarmStatus,
9    },
10    python::{PyEnsembleStatus, PyGradientFreeStatus, PyGradientStatus, PySwarmStatus},
11    traits::{Algorithm, Observer, Terminator},
12    DVector,
13};
14use laddu_core::LadduError;
15use laddu_extensions::optimize::MaybeThreadPool;
16use numpy::{PyArray1, ToPyArray};
17use pyo3::prelude::*;
18/// An enum used by a terminator to continue or stop an algorithm.
19///
20#[pyclass(eq, eq_int, name = "ControlFlow", module = "laddu", from_py_object)]
21#[derive(PartialEq, Clone)]
22pub enum PyControlFlow {
23    /// Continue running the algorithm.
24    Continue = 0,
25    /// Terminate the algorithm.
26    Break = 1,
27}
28
29impl From<PyControlFlow> for ControlFlow<()> {
30    fn from(v: PyControlFlow) -> Self {
31        match v {
32            PyControlFlow::Continue => ControlFlow::Continue(()),
33            PyControlFlow::Break => ControlFlow::Break(()),
34        }
35    }
36}
37
38/// An [`Observer`] which can be used to monitor the progress of a minimization.
39///
40/// This should be paired with a Python object which has an `observe` method
41/// that takes the current step and a method-specific minimization status object.
42#[derive(Clone)]
43pub struct MinimizationObserver(Arc<Py<PyAny>>);
44impl<'a, 'py> FromPyObject<'a, 'py> for MinimizationObserver {
45    type Error = PyErr;
46    fn extract(ob: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
47        Ok(MinimizationObserver(Arc::new(ob.to_owned().unbind())))
48    }
49}
50impl<A, P, C> Observer<A, P, GradientStatus, MaybeThreadPool, LadduError, C>
51    for MinimizationObserver
52where
53    A: Algorithm<P, GradientStatus, MaybeThreadPool, LadduError, Config = C>,
54{
55    fn observe(
56        &mut self,
57        current_step: usize,
58        _algorithm: &A,
59        _problem: &P,
60        status: &GradientStatus,
61        _args: &MaybeThreadPool,
62        _config: &C,
63    ) {
64        Python::attach(|py| {
65            if let Err(err) = self.0.bind(py).call_method1(
66                "observe",
67                (
68                    current_step,
69                    Py::new(py, PyGradientStatus::from(status.clone()))
70                        .expect("ganesh gradient status should construct"),
71                ),
72            ) {
73                err.print(py);
74            }
75        })
76    }
77}
78impl<A, P, C> Observer<A, P, GradientFreeStatus, MaybeThreadPool, LadduError, C>
79    for MinimizationObserver
80where
81    A: Algorithm<P, GradientFreeStatus, MaybeThreadPool, LadduError, Config = C>,
82{
83    fn observe(
84        &mut self,
85        current_step: usize,
86        _algorithm: &A,
87        _problem: &P,
88        status: &GradientFreeStatus,
89        _args: &MaybeThreadPool,
90        _config: &C,
91    ) {
92        Python::attach(|py| {
93            if let Err(err) = self.0.bind(py).call_method1(
94                "observe",
95                (
96                    current_step,
97                    Py::new(py, PyGradientFreeStatus::from(status.clone()))
98                        .expect("ganesh gradient-free status should construct"),
99                ),
100            ) {
101                err.print(py);
102            }
103        })
104    }
105}
106impl<A, P, C> Observer<A, P, SwarmStatus, MaybeThreadPool, LadduError, C> for MinimizationObserver
107where
108    A: Algorithm<P, SwarmStatus, MaybeThreadPool, LadduError, Config = C>,
109{
110    fn observe(
111        &mut self,
112        current_step: usize,
113        _algorithm: &A,
114        _problem: &P,
115        status: &SwarmStatus,
116        _args: &MaybeThreadPool,
117        _config: &C,
118    ) {
119        Python::attach(|py| -> PyResult<()> {
120            if let Err(err) = self.0.bind(py).call_method1(
121                "observe",
122                (
123                    current_step,
124                    Py::new(py, PySwarmStatus::from(status.clone()))?,
125                ),
126            ) {
127                err.print(py);
128            }
129            Ok(())
130        })
131        .expect("call to 'observe' has failed!")
132    }
133}
134
135/// An [`Terminator`] which can be used to monitor the progress of a minimization.
136///
137/// This should be paired with a Python object which has an `check_for_termination` method
138/// that takes the current step and a method-specific minimization status object and returns a
139/// [`PyControlFlow`].
140#[derive(Clone)]
141pub struct MinimizationTerminator(Arc<Py<PyAny>>);
142
143impl<'a, 'py> FromPyObject<'a, 'py> for MinimizationTerminator {
144    type Error = PyErr;
145    fn extract(ob: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
146        Ok(MinimizationTerminator(Arc::new(ob.to_owned().unbind())))
147    }
148}
149
150impl<A, P, C> Terminator<A, P, GradientStatus, MaybeThreadPool, LadduError, C>
151    for MinimizationTerminator
152where
153    A: Algorithm<P, GradientStatus, MaybeThreadPool, LadduError, Config = C>,
154{
155    fn check_for_termination(
156        &mut self,
157        current_step: usize,
158        _algorithm: &mut A,
159        _problem: &P,
160        status: &mut GradientStatus,
161        _args: &MaybeThreadPool,
162        _config: &C,
163    ) -> ControlFlow<()> {
164        Python::attach(|py| -> PyResult<ControlFlow<()>> {
165            let py_status = Py::new(py, PyGradientStatus::from(status.clone()))?;
166            let ret = self
167                .0
168                .bind(py)
169                .call_method1("check_for_termination", (current_step, py_status))?;
170            let cf: PyControlFlow = ret.extract()?;
171            Ok(cf.into())
172        })
173        .expect("call to 'check_for_termination' has failed!")
174    }
175}
176impl<A, P, C> Terminator<A, P, GradientFreeStatus, MaybeThreadPool, LadduError, C>
177    for MinimizationTerminator
178where
179    A: Algorithm<P, GradientFreeStatus, MaybeThreadPool, LadduError, Config = C>,
180{
181    fn check_for_termination(
182        &mut self,
183        current_step: usize,
184        _algorithm: &mut A,
185        _problem: &P,
186        status: &mut GradientFreeStatus,
187        _args: &MaybeThreadPool,
188        _config: &C,
189    ) -> ControlFlow<()> {
190        Python::attach(|py| -> PyResult<ControlFlow<()>> {
191            let py_status = Py::new(py, PyGradientFreeStatus::from(status.clone()))?;
192            let ret = self
193                .0
194                .bind(py)
195                .call_method1("check_for_termination", (current_step, py_status))?;
196            let cf: PyControlFlow = ret.extract()?;
197            Ok(cf.into())
198        })
199        .expect("call to 'check_for_termination' has failed!")
200    }
201}
202impl<A, P, C> Terminator<A, P, SwarmStatus, MaybeThreadPool, LadduError, C>
203    for MinimizationTerminator
204where
205    A: Algorithm<P, SwarmStatus, MaybeThreadPool, LadduError, Config = C>,
206{
207    fn check_for_termination(
208        &mut self,
209        current_step: usize,
210        _algorithm: &mut A,
211        _problem: &P,
212        status: &mut SwarmStatus,
213        _args: &MaybeThreadPool,
214        _config: &C,
215    ) -> ControlFlow<()> {
216        Python::attach(|py| -> PyResult<ControlFlow<()>> {
217            let py_status = Py::new(py, PySwarmStatus::from(status.clone()))?;
218            let ret = self
219                .0
220                .bind(py)
221                .call_method1("check_for_termination", (current_step, py_status))?;
222            let cf: PyControlFlow = ret.extract()?;
223            Ok(cf.into())
224        })
225        .expect("call to 'check_for_termination' has failed!")
226    }
227}
228
229/// An [`Observer`] which can be used to monitor the progress of an MCMC algorithm.
230///
231/// This should be paired with a Python object which has an `observe` method
232/// that takes the current step and a `ganesh.EnsembleStatus` object.
233#[derive(Clone)]
234pub struct MCMCObserver(Arc<Py<PyAny>>);
235
236impl<'a, 'py> FromPyObject<'a, 'py> for MCMCObserver {
237    type Error = PyErr;
238    fn extract(ob: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
239        Ok(MCMCObserver(Arc::new(ob.to_owned().unbind())))
240    }
241}
242
243impl<A, P, C> Observer<A, P, EnsembleStatus, MaybeThreadPool, LadduError, C> for MCMCObserver
244where
245    A: Algorithm<P, EnsembleStatus, MaybeThreadPool, LadduError, Config = C>,
246{
247    fn observe(
248        &mut self,
249        current_step: usize,
250        _algorithm: &A,
251        _problem: &P,
252        status: &EnsembleStatus,
253        _args: &MaybeThreadPool,
254        _config: &C,
255    ) {
256        Python::attach(|py| -> PyResult<()> {
257            if let Err(err) = self.0.bind(py).call_method1(
258                "observe",
259                (
260                    current_step,
261                    Py::new(py, PyEnsembleStatus::from(status.clone()))?,
262                ),
263            ) {
264                err.print(py);
265            }
266            Ok(())
267        })
268        .expect("call to 'observe' has failed!")
269    }
270}
271
272/// A [`Terminator`] which can be used to monitor the progress of an MCMC algorithm.
273///
274/// This should be paired with a Python object which has an `check_for_termination` method
275/// that takes the current step and a `ganesh.EnsembleStatus` object and returns a
276/// [`PyControlFlow`].
277#[derive(Clone)]
278pub struct MCMCTerminator(Arc<Py<PyAny>>);
279
280impl<'a, 'py> FromPyObject<'a, 'py> for MCMCTerminator {
281    type Error = PyErr;
282    fn extract(ob: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
283        Ok(MCMCTerminator(Arc::new(ob.to_owned().unbind())))
284    }
285}
286
287impl<A, P, C> Terminator<A, P, EnsembleStatus, MaybeThreadPool, LadduError, C> for MCMCTerminator
288where
289    A: Algorithm<P, EnsembleStatus, MaybeThreadPool, LadduError, Config = C>,
290{
291    fn check_for_termination(
292        &mut self,
293        current_step: usize,
294        _algorithm: &mut A,
295        _problem: &P,
296        status: &mut EnsembleStatus,
297        _args: &MaybeThreadPool,
298        _config: &C,
299    ) -> ControlFlow<()> {
300        Python::attach(|py| -> PyResult<ControlFlow<()>> {
301            let py_status = Py::new(py, PyEnsembleStatus::from(status.clone()))?;
302            let ret = self
303                .0
304                .bind(py)
305                .call_method1("check_for_termination", (current_step, py_status))?;
306            let cf: PyControlFlow = ret.extract()?;
307            Ok(cf.into())
308        })
309        .expect("call to 'check_for_termination' has failed!")
310    }
311}
312
313/// Calculate the integrated autocorrelation time for each parameter according to
314/// Karamanis & Beutler (2021).
315///
316/// Parameters
317/// ----------
318/// samples : array_like
319///     An array of dimension ``(n_walkers, n_steps, n_parameters)``.
320/// c : float, default = 7.0
321///     The time window for Sokal's autowindowing function. If omitted, the
322///     default window size of 7.0 is used.
323///
324/// Returns
325/// -------
326/// array of shape (n_parameters,)
327///
328/// Examples
329/// --------
330/// >>> import numpy as np
331/// >>> from laddu import integrated_autocorrelation_times
332/// >>> samples = np.random.randn(4, 16, 2).tolist()
333/// >>> integrated_autocorrelation_times(samples).shape
334/// (2,)
335///
336/// References
337/// ----------
338/// Karamanis, M. & Beutler, F. (2021). *Ensemble slice sampling*. Stat. Comput. 31(5). <https://doi.org/10.1007/s11222-021-10038-2>
339///
340/// Sokal, A. (1997). *Monte Carlo Methods in Statistical Mechanics: Foundations and New Algorithms*. NATO ASI Series, 131–192. <https://doi.org/10.1007/978-1-4899-0319-8_6>
341#[pyfunction(name = "integrated_autocorrelation_times")]
342#[pyo3(signature = (samples, *, c=None))]
343pub fn py_integrated_autocorrelation_times<'py>(
344    py: Python<'py>,
345    samples: Vec<Vec<Vec<f64>>>,
346    c: Option<f64>,
347) -> Bound<'py, PyArray1<f64>> {
348    let samples: Vec<Vec<DVector<f64>>> = samples
349        .into_iter()
350        .map(|walker| walker.into_iter().map(DVector::from_vec).collect())
351        .collect();
352    integrated_autocorrelation_times(samples, c)
353        .as_slice()
354        .to_pyarray(py)
355}