1use std::{ops::ControlFlow, sync::Arc};
2
3use ganesh::{
4 algorithms::{
5 gradient::GradientStatus,
6 gradient_free::GradientFreeStatus,
7 mcmc::{integrated_autocorrelation_times, EnsembleStatus},
8 particles::SwarmStatus,
9 },
10 python::{PyEnsembleStatus, PyGradientFreeStatus, PyGradientStatus, PySwarmStatus},
11 traits::{Algorithm, Observer, Terminator},
12 DVector,
13};
14use laddu_core::LadduError;
15use laddu_extensions::optimize::MaybeThreadPool;
16use numpy::{PyArray1, ToPyArray};
17use pyo3::prelude::*;
18#[pyclass(eq, eq_int, name = "ControlFlow", module = "laddu", from_py_object)]
21#[derive(PartialEq, Clone)]
22pub enum PyControlFlow {
23 Continue = 0,
25 Break = 1,
27}
28
29impl From<PyControlFlow> for ControlFlow<()> {
30 fn from(v: PyControlFlow) -> Self {
31 match v {
32 PyControlFlow::Continue => ControlFlow::Continue(()),
33 PyControlFlow::Break => ControlFlow::Break(()),
34 }
35 }
36}
37
38#[derive(Clone)]
43pub struct MinimizationObserver(Arc<Py<PyAny>>);
44impl<'a, 'py> FromPyObject<'a, 'py> for MinimizationObserver {
45 type Error = PyErr;
46 fn extract(ob: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
47 Ok(MinimizationObserver(Arc::new(ob.to_owned().unbind())))
48 }
49}
50impl<A, P, C> Observer<A, P, GradientStatus, MaybeThreadPool, LadduError, C>
51 for MinimizationObserver
52where
53 A: Algorithm<P, GradientStatus, MaybeThreadPool, LadduError, Config = C>,
54{
55 fn observe(
56 &mut self,
57 current_step: usize,
58 _algorithm: &A,
59 _problem: &P,
60 status: &GradientStatus,
61 _args: &MaybeThreadPool,
62 _config: &C,
63 ) {
64 Python::attach(|py| {
65 if let Err(err) = self.0.bind(py).call_method1(
66 "observe",
67 (
68 current_step,
69 Py::new(py, PyGradientStatus::from(status.clone()))
70 .expect("ganesh gradient status should construct"),
71 ),
72 ) {
73 err.print(py);
74 }
75 })
76 }
77}
78impl<A, P, C> Observer<A, P, GradientFreeStatus, MaybeThreadPool, LadduError, C>
79 for MinimizationObserver
80where
81 A: Algorithm<P, GradientFreeStatus, MaybeThreadPool, LadduError, Config = C>,
82{
83 fn observe(
84 &mut self,
85 current_step: usize,
86 _algorithm: &A,
87 _problem: &P,
88 status: &GradientFreeStatus,
89 _args: &MaybeThreadPool,
90 _config: &C,
91 ) {
92 Python::attach(|py| {
93 if let Err(err) = self.0.bind(py).call_method1(
94 "observe",
95 (
96 current_step,
97 Py::new(py, PyGradientFreeStatus::from(status.clone()))
98 .expect("ganesh gradient-free status should construct"),
99 ),
100 ) {
101 err.print(py);
102 }
103 })
104 }
105}
106impl<A, P, C> Observer<A, P, SwarmStatus, MaybeThreadPool, LadduError, C> for MinimizationObserver
107where
108 A: Algorithm<P, SwarmStatus, MaybeThreadPool, LadduError, Config = C>,
109{
110 fn observe(
111 &mut self,
112 current_step: usize,
113 _algorithm: &A,
114 _problem: &P,
115 status: &SwarmStatus,
116 _args: &MaybeThreadPool,
117 _config: &C,
118 ) {
119 Python::attach(|py| -> PyResult<()> {
120 if let Err(err) = self.0.bind(py).call_method1(
121 "observe",
122 (
123 current_step,
124 Py::new(py, PySwarmStatus::from(status.clone()))?,
125 ),
126 ) {
127 err.print(py);
128 }
129 Ok(())
130 })
131 .expect("call to 'observe' has failed!")
132 }
133}
134
135#[derive(Clone)]
141pub struct MinimizationTerminator(Arc<Py<PyAny>>);
142
143impl<'a, 'py> FromPyObject<'a, 'py> for MinimizationTerminator {
144 type Error = PyErr;
145 fn extract(ob: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
146 Ok(MinimizationTerminator(Arc::new(ob.to_owned().unbind())))
147 }
148}
149
150impl<A, P, C> Terminator<A, P, GradientStatus, MaybeThreadPool, LadduError, C>
151 for MinimizationTerminator
152where
153 A: Algorithm<P, GradientStatus, MaybeThreadPool, LadduError, Config = C>,
154{
155 fn check_for_termination(
156 &mut self,
157 current_step: usize,
158 _algorithm: &mut A,
159 _problem: &P,
160 status: &mut GradientStatus,
161 _args: &MaybeThreadPool,
162 _config: &C,
163 ) -> ControlFlow<()> {
164 Python::attach(|py| -> PyResult<ControlFlow<()>> {
165 let py_status = Py::new(py, PyGradientStatus::from(status.clone()))?;
166 let ret = self
167 .0
168 .bind(py)
169 .call_method1("check_for_termination", (current_step, py_status))?;
170 let cf: PyControlFlow = ret.extract()?;
171 Ok(cf.into())
172 })
173 .expect("call to 'check_for_termination' has failed!")
174 }
175}
176impl<A, P, C> Terminator<A, P, GradientFreeStatus, MaybeThreadPool, LadduError, C>
177 for MinimizationTerminator
178where
179 A: Algorithm<P, GradientFreeStatus, MaybeThreadPool, LadduError, Config = C>,
180{
181 fn check_for_termination(
182 &mut self,
183 current_step: usize,
184 _algorithm: &mut A,
185 _problem: &P,
186 status: &mut GradientFreeStatus,
187 _args: &MaybeThreadPool,
188 _config: &C,
189 ) -> ControlFlow<()> {
190 Python::attach(|py| -> PyResult<ControlFlow<()>> {
191 let py_status = Py::new(py, PyGradientFreeStatus::from(status.clone()))?;
192 let ret = self
193 .0
194 .bind(py)
195 .call_method1("check_for_termination", (current_step, py_status))?;
196 let cf: PyControlFlow = ret.extract()?;
197 Ok(cf.into())
198 })
199 .expect("call to 'check_for_termination' has failed!")
200 }
201}
202impl<A, P, C> Terminator<A, P, SwarmStatus, MaybeThreadPool, LadduError, C>
203 for MinimizationTerminator
204where
205 A: Algorithm<P, SwarmStatus, MaybeThreadPool, LadduError, Config = C>,
206{
207 fn check_for_termination(
208 &mut self,
209 current_step: usize,
210 _algorithm: &mut A,
211 _problem: &P,
212 status: &mut SwarmStatus,
213 _args: &MaybeThreadPool,
214 _config: &C,
215 ) -> ControlFlow<()> {
216 Python::attach(|py| -> PyResult<ControlFlow<()>> {
217 let py_status = Py::new(py, PySwarmStatus::from(status.clone()))?;
218 let ret = self
219 .0
220 .bind(py)
221 .call_method1("check_for_termination", (current_step, py_status))?;
222 let cf: PyControlFlow = ret.extract()?;
223 Ok(cf.into())
224 })
225 .expect("call to 'check_for_termination' has failed!")
226 }
227}
228
229#[derive(Clone)]
234pub struct MCMCObserver(Arc<Py<PyAny>>);
235
236impl<'a, 'py> FromPyObject<'a, 'py> for MCMCObserver {
237 type Error = PyErr;
238 fn extract(ob: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
239 Ok(MCMCObserver(Arc::new(ob.to_owned().unbind())))
240 }
241}
242
243impl<A, P, C> Observer<A, P, EnsembleStatus, MaybeThreadPool, LadduError, C> for MCMCObserver
244where
245 A: Algorithm<P, EnsembleStatus, MaybeThreadPool, LadduError, Config = C>,
246{
247 fn observe(
248 &mut self,
249 current_step: usize,
250 _algorithm: &A,
251 _problem: &P,
252 status: &EnsembleStatus,
253 _args: &MaybeThreadPool,
254 _config: &C,
255 ) {
256 Python::attach(|py| -> PyResult<()> {
257 if let Err(err) = self.0.bind(py).call_method1(
258 "observe",
259 (
260 current_step,
261 Py::new(py, PyEnsembleStatus::from(status.clone()))?,
262 ),
263 ) {
264 err.print(py);
265 }
266 Ok(())
267 })
268 .expect("call to 'observe' has failed!")
269 }
270}
271
272#[derive(Clone)]
278pub struct MCMCTerminator(Arc<Py<PyAny>>);
279
280impl<'a, 'py> FromPyObject<'a, 'py> for MCMCTerminator {
281 type Error = PyErr;
282 fn extract(ob: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
283 Ok(MCMCTerminator(Arc::new(ob.to_owned().unbind())))
284 }
285}
286
287impl<A, P, C> Terminator<A, P, EnsembleStatus, MaybeThreadPool, LadduError, C> for MCMCTerminator
288where
289 A: Algorithm<P, EnsembleStatus, MaybeThreadPool, LadduError, Config = C>,
290{
291 fn check_for_termination(
292 &mut self,
293 current_step: usize,
294 _algorithm: &mut A,
295 _problem: &P,
296 status: &mut EnsembleStatus,
297 _args: &MaybeThreadPool,
298 _config: &C,
299 ) -> ControlFlow<()> {
300 Python::attach(|py| -> PyResult<ControlFlow<()>> {
301 let py_status = Py::new(py, PyEnsembleStatus::from(status.clone()))?;
302 let ret = self
303 .0
304 .bind(py)
305 .call_method1("check_for_termination", (current_step, py_status))?;
306 let cf: PyControlFlow = ret.extract()?;
307 Ok(cf.into())
308 })
309 .expect("call to 'check_for_termination' has failed!")
310 }
311}
312
313#[pyfunction(name = "integrated_autocorrelation_times")]
342#[pyo3(signature = (samples, *, c=None))]
343pub fn py_integrated_autocorrelation_times<'py>(
344 py: Python<'py>,
345 samples: Vec<Vec<Vec<f64>>>,
346 c: Option<f64>,
347) -> Bound<'py, PyArray1<f64>> {
348 let samples: Vec<Vec<DVector<f64>>> = samples
349 .into_iter()
350 .map(|walker| walker.into_iter().map(DVector::from_vec).collect())
351 .collect();
352 integrated_autocorrelation_times(samples, c)
353 .as_slice()
354 .to_pyarray(py)
355}