1use std::sync::Arc;
2
3use fastrand::Rng;
4use ganesh::{
5 algorithms::LBFGSB,
6 mcmc::{aies::WeightedAIESMove, ess::WeightedESSMove, MCMCAlgorithm, MCMCObserver, AIES, ESS},
7 observers::{DebugMCMCObserver, DebugObserver},
8 Algorithm, Observer, Status,
9};
10use laddu_core::{Ensemble, LadduError};
11use parking_lot::RwLock;
12#[cfg(feature = "rayon")]
13use rayon::ThreadPool;
14
15struct VerboseObserver {
16 show_step: bool,
17 show_x: bool,
18 show_fx: bool,
19}
20impl VerboseObserver {
21 fn build(self) -> Arc<RwLock<Self>> {
22 Arc::new(RwLock::new(self))
23 }
24}
25
26pub struct MinimizerOptions {
28 #[cfg(feature = "rayon")]
29 pub(crate) algorithm: Box<dyn ganesh::Algorithm<ThreadPool, LadduError>>,
30 #[cfg(not(feature = "rayon"))]
31 pub(crate) algorithm: Box<dyn ganesh::Algorithm<(), LadduError>>,
32 #[cfg(feature = "rayon")]
33 pub(crate) observers: Vec<Arc<RwLock<dyn Observer<ThreadPool>>>>,
34 #[cfg(not(feature = "rayon"))]
35 pub(crate) observers: Vec<Arc<RwLock<dyn Observer<()>>>>,
36 pub(crate) max_steps: usize,
37 #[cfg(feature = "rayon")]
38 pub(crate) threads: usize,
39}
40
41impl Default for MinimizerOptions {
42 fn default() -> Self {
43 Self {
44 algorithm: Box::new(LBFGSB::default()),
45 observers: Default::default(),
46 max_steps: 4000,
47 #[cfg(all(feature = "rayon", feature = "num_cpus"))]
48 threads: num_cpus::get(),
49 #[cfg(all(feature = "rayon", not(feature = "num_cpus")))]
50 threads: 0,
51 }
52 }
53}
54
55impl MinimizerOptions {
56 pub fn debug(self) -> Self {
58 let mut observers = self.observers;
59 observers.push(DebugObserver::build());
60 Self {
61 algorithm: self.algorithm,
62 observers,
63 max_steps: self.max_steps,
64 #[cfg(feature = "rayon")]
65 threads: self.threads,
66 }
67 }
68 pub fn verbose(self, show_step: bool, show_x: bool, show_fx: bool) -> Self {
70 let mut observers = self.observers;
71 observers.push(
72 VerboseObserver {
73 show_step,
74 show_x,
75 show_fx,
76 }
77 .build(),
78 );
79 Self {
80 algorithm: self.algorithm,
81 observers,
82 max_steps: self.max_steps,
83 #[cfg(feature = "rayon")]
84 threads: self.threads,
85 }
86 }
87 #[cfg(feature = "rayon")]
90 pub fn with_algorithm<A: Algorithm<ThreadPool, LadduError> + 'static>(
91 self,
92 algorithm: A,
93 ) -> Self {
94 Self {
95 algorithm: Box::new(algorithm),
96 observers: self.observers,
97 max_steps: self.max_steps,
98 threads: self.threads,
99 }
100 }
101
102 #[cfg(not(feature = "rayon"))]
105 pub fn with_algorithm<A: Algorithm<(), LadduError> + 'static>(self, algorithm: A) -> Self {
106 Self {
107 algorithm: Box::new(algorithm),
108 observers: self.observers,
109 max_steps: self.max_steps,
110 }
111 }
112 #[cfg(feature = "rayon")]
114 pub fn with_observer(self, observer: Arc<RwLock<dyn Observer<ThreadPool>>>) -> Self {
115 let mut observers = self.observers;
116 observers.push(observer.clone());
117 Self {
118 algorithm: self.algorithm,
119 observers,
120 max_steps: self.max_steps,
121 threads: self.threads,
122 }
123 }
124 #[cfg(not(feature = "rayon"))]
126 pub fn with_observer(self, observer: Arc<RwLock<dyn Observer<()>>>) -> Self {
127 let mut observers = self.observers;
128 observers.push(observer.clone());
129 Self {
130 algorithm: self.algorithm,
131 observers,
132 max_steps: self.max_steps,
133 }
134 }
135
136 pub fn with_max_steps(self, max_steps: usize) -> Self {
138 Self {
139 algorithm: self.algorithm,
140 observers: self.observers,
141 max_steps,
142 #[cfg(feature = "rayon")]
143 threads: self.threads,
144 }
145 }
146
147 #[cfg(feature = "rayon")]
149 pub fn with_threads(self, threads: usize) -> Self {
150 Self {
151 algorithm: self.algorithm,
152 observers: self.observers,
153 max_steps: self.max_steps,
154 threads,
155 }
156 }
157}
158
#[cfg(feature = "rayon")]
impl Observer<ThreadPool> for VerboseObserver {
    /// Print the enabled progress lines. Always returns `false` (the run is never stopped
    /// by this observer).
    fn callback(&mut self, step: usize, status: &mut Status, _user_data: &mut ThreadPool) -> bool {
        let Self {
            show_step,
            show_x,
            show_fx,
        } = *self;
        if show_step {
            println!("Step: {}", step);
        }
        if show_x {
            println!("Current Best Position: {}", status.x.transpose());
        }
        if show_fx {
            println!("Current Best Value: {}", status.fx);
        }
        false
    }
}
174
175impl Observer<()> for VerboseObserver {
176 fn callback(&mut self, step: usize, status: &mut Status, _user_data: &mut ()) -> bool {
177 if self.show_step {
178 println!("Step: {}", step);
179 }
180 if self.show_x {
181 println!("Current Best Position: {}", status.x.transpose());
182 }
183 if self.show_fx {
184 println!("Current Best Value: {}", status.fx);
185 }
186 false
187 }
188}
189
190struct VerboseMCMCObserver;
191impl VerboseMCMCObserver {
192 fn build() -> Arc<RwLock<Self>> {
193 Arc::new(RwLock::new(Self))
194 }
195}
196
#[cfg(feature = "rayon")]
impl MCMCObserver<ThreadPool> for VerboseMCMCObserver {
    /// Print the step index; always returns `false`.
    fn callback(
        &mut self,
        step: usize,
        _ensemble: &mut Ensemble,
        _thread_pool: &mut ThreadPool,
    ) -> bool {
        let stop = false;
        println!("Step: {}", step);
        stop
    }
}
209
210impl MCMCObserver<()> for VerboseMCMCObserver {
211 fn callback(&mut self, step: usize, _ensemble: &mut Ensemble, _user_data: &mut ()) -> bool {
212 println!("Step: {}", step);
213 false
214 }
215}
216
217pub struct MCMCOptions {
219 #[cfg(feature = "rayon")]
220 pub(crate) algorithm: Box<dyn MCMCAlgorithm<ThreadPool, LadduError>>,
221 #[cfg(not(feature = "rayon"))]
222 pub(crate) algorithm: Box<dyn MCMCAlgorithm<(), LadduError>>,
223 #[cfg(feature = "rayon")]
224 pub(crate) observers: Vec<Arc<RwLock<dyn MCMCObserver<ThreadPool>>>>,
225 #[cfg(not(feature = "rayon"))]
226 pub(crate) observers: Vec<Arc<RwLock<dyn MCMCObserver<()>>>>,
227 #[cfg(feature = "rayon")]
228 pub(crate) threads: usize,
229}
230
231impl MCMCOptions {
232 pub fn new_ess<T: AsRef<[WeightedESSMove]>>(moves: T, rng: Rng) -> Self {
234 Self {
235 algorithm: Box::new(ESS::new(moves, rng).with_n_adaptive(100)),
236 observers: Default::default(),
237 #[cfg(all(feature = "rayon", feature = "num_cpus"))]
238 threads: num_cpus::get(),
239 #[cfg(all(feature = "rayon", not(feature = "num_cpus")))]
240 threads: 0,
241 }
242 }
243 pub fn new_aies<T: AsRef<[WeightedAIESMove]>>(moves: T, rng: Rng) -> Self {
245 Self {
246 algorithm: Box::new(AIES::new(moves, rng)),
247 observers: Default::default(),
248 #[cfg(all(feature = "rayon", feature = "num_cpus"))]
249 threads: num_cpus::get(),
250 #[cfg(all(feature = "rayon", not(feature = "num_cpus")))]
251 threads: 0,
252 }
253 }
254 pub fn debug(self) -> Self {
256 let mut observers = self.observers;
257 observers.push(DebugMCMCObserver::build());
258 Self {
259 algorithm: self.algorithm,
260 observers,
261 #[cfg(feature = "rayon")]
262 threads: self.threads,
263 }
264 }
265 pub fn verbose(self) -> Self {
267 let mut observers = self.observers;
268 observers.push(VerboseMCMCObserver::build());
269 Self {
270 algorithm: self.algorithm,
271 observers,
272 #[cfg(feature = "rayon")]
273 threads: self.threads,
274 }
275 }
276 #[cfg(feature = "rayon")]
278 pub fn with_algorithm<A: MCMCAlgorithm<ThreadPool, LadduError> + 'static>(
279 self,
280 algorithm: A,
281 ) -> Self {
282 Self {
283 algorithm: Box::new(algorithm),
284 observers: self.observers,
285 threads: self.threads,
286 }
287 }
288 #[cfg(not(feature = "rayon"))]
290 pub fn with_algorithm<A: MCMCAlgorithm<(), LadduError> + 'static>(self, algorithm: A) -> Self {
291 Self {
292 algorithm: Box::new(algorithm),
293 observers: self.observers,
294 }
295 }
296 #[cfg(feature = "rayon")]
297 pub fn with_observer(self, observer: Arc<RwLock<dyn MCMCObserver<ThreadPool>>>) -> Self {
299 let mut observers = self.observers;
300 observers.push(observer.clone());
301 Self {
302 algorithm: self.algorithm,
303 observers,
304 threads: self.threads,
305 }
306 }
307 #[cfg(not(feature = "rayon"))]
308 pub fn with_observer(self, observer: Arc<RwLock<dyn MCMCObserver<()>>>) -> Self {
310 let mut observers = self.observers;
311 observers.push(observer.clone());
312 Self {
313 algorithm: self.algorithm,
314 observers,
315 }
316 }
317
318 #[cfg(feature = "rayon")]
320 pub fn with_threads(self, threads: usize) -> Self {
321 Self {
322 algorithm: self.algorithm,
323 observers: self.observers,
324 threads,
325 }
326 }
327}
328
329#[cfg(feature = "python")]
331pub mod py_ganesh {
332 use super::*;
333 use std::sync::Arc;
334
335 use bincode::{deserialize, serialize};
336 use fastrand::Rng;
337 use ganesh::{
338 algorithms::{
339 lbfgsb::{LBFGSBFTerminator, LBFGSBGTerminator},
340 nelder_mead::{NelderMeadFTerminator, NelderMeadXTerminator, SimplexExpansionMethod},
341 NelderMead, LBFGSB,
342 },
343 mcmc::{
344 aies::WeightedAIESMove, ess::WeightedESSMove, integrated_autocorrelation_times,
345 AIESMove, ESSMove, MCMCObserver, AIES, ESS,
346 },
347 observers::AutocorrelationObserver,
348 Observer, Status,
349 };
350 use laddu_core::{DVector, Ensemble, Float, LadduError, ReadWrite};
351 use laddu_python::GetStrExtractObj;
352 use numpy::{PyArray1, PyArray2, PyArray3};
353 use parking_lot::RwLock;
354 use pyo3::{
355 exceptions::{PyTypeError, PyValueError},
356 prelude::*,
357 types::{PyBytes, PyDict, PyList, PyTuple},
358 };
359
360 #[pyclass]
362 #[pyo3(name = "Observer")]
363 pub struct PyObserver(Py<PyAny>);
364
365 #[pymethods]
366 impl PyObserver {
367 #[new]
368 fn new(observer: Py<PyAny>) -> Self {
369 Self(observer)
370 }
371 }
372
373 #[pyclass]
375 #[pyo3(name = "MCMCObserver")]
376 pub struct PyMCMCObserver(Py<PyAny>);
377
378 #[pymethods]
379 impl PyMCMCObserver {
380 #[new]
381 fn new(observer: Py<PyAny>) -> Self {
382 Self(observer)
383 }
384 }
385
386 #[pyclass(name = "Status", module = "laddu")]
389 #[derive(Clone)]
390 pub struct PyStatus(pub Status);
391 #[pymethods]
392 impl PyStatus {
393 #[getter]
400 fn x<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<Float>> {
401 PyArray1::from_slice(py, self.0.x.as_slice())
402 }
403 #[getter]
410 fn err<'py>(&self, py: Python<'py>) -> Option<Bound<'py, PyArray1<Float>>> {
411 self.0
412 .err
413 .clone()
414 .map(|err| PyArray1::from_slice(py, err.as_slice()))
415 }
416 #[getter]
423 fn x0<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<Float>> {
424 PyArray1::from_slice(py, self.0.x0.as_slice())
425 }
426 #[getter]
433 fn fx(&self) -> Float {
434 self.0.fx
435 }
436 #[getter]
448 fn cov<'py>(&self, py: Python<'py>) -> PyResult<Option<Bound<'py, PyArray2<Float>>>> {
449 self.0
450 .cov
451 .clone()
452 .map(|cov| {
453 Ok(PyArray2::from_vec2(
454 py,
455 &cov.row_iter()
456 .map(|row| row.iter().cloned().collect())
457 .collect::<Vec<Vec<Float>>>(),
458 )
459 .map_err(LadduError::NumpyError)?)
460 })
461 .transpose()
462 }
463 #[getter]
475 fn hess<'py>(&self, py: Python<'py>) -> PyResult<Option<Bound<'py, PyArray2<Float>>>> {
476 self.0
477 .hess
478 .clone()
479 .map(|hess| {
480 Ok(PyArray2::from_vec2(
481 py,
482 &hess
483 .row_iter()
484 .map(|row| row.iter().cloned().collect())
485 .collect::<Vec<Vec<Float>>>(),
486 )
487 .map_err(LadduError::NumpyError)?)
488 })
489 .transpose()
490 }
491 #[getter]
498 fn message(&self) -> String {
499 self.0.message.clone()
500 }
501 #[getter]
508 fn converged(&self) -> bool {
509 self.0.converged
510 }
511 #[getter]
518 fn bounds(&self) -> Option<Vec<PyBound>> {
519 self.0
520 .bounds
521 .clone()
522 .map(|bounds| bounds.iter().map(|bound| PyBound(*bound)).collect())
523 }
524 #[getter]
531 fn n_f_evals(&self) -> usize {
532 self.0.n_f_evals
533 }
534 #[getter]
541 fn n_g_evals(&self) -> usize {
542 self.0.n_g_evals
543 }
544 fn __str__(&self) -> String {
545 self.0.to_string()
546 }
547 fn __repr__(&self) -> String {
548 format!("{:?}", self.0)
549 }
550 fn save_as(&self, path: &str) -> PyResult<()> {
563 self.0.save_as(path)?;
564 Ok(())
565 }
566 #[staticmethod]
584 fn load_from(path: &str) -> PyResult<Self> {
585 Ok(PyStatus(Status::load_from(path)?))
586 }
587 #[new]
588 fn new() -> Self {
589 PyStatus(Status::create_null())
590 }
591 fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
592 Ok(PyBytes::new(
593 py,
594 serialize(&self.0)
595 .map_err(LadduError::SerdeError)?
596 .as_slice(),
597 ))
598 }
599 fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
600 *self = PyStatus(deserialize(state.as_bytes()).map_err(LadduError::SerdeError)?);
601 Ok(())
602 }
603 fn as_dict<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
615 let dict = PyDict::new(py);
616 dict.set_item("x", self.x(py))?;
617 dict.set_item("err", self.err(py))?;
618 dict.set_item("x0", self.x0(py))?;
619 dict.set_item("fx", self.fx())?;
620 dict.set_item("cov", self.cov(py)?)?;
621 dict.set_item("hess", self.hess(py)?)?;
622 dict.set_item("message", self.message())?;
623 dict.set_item("converged", self.converged())?;
624 dict.set_item("bounds", self.bounds())?;
625 dict.set_item("n_f_evals", self.n_f_evals())?;
626 dict.set_item("n_g_evals", self.n_g_evals())?;
627 Ok(dict)
628 }
629 }
630
631 #[pyclass(name = "Ensemble", module = "laddu")]
634 #[derive(Clone)]
635 pub struct PyEnsemble(pub Ensemble);
636 #[pymethods]
637 impl PyEnsemble {
638 #[getter]
640 fn dimension(&self) -> (usize, usize, usize) {
641 self.0.dimension()
642 }
643 #[pyo3(signature = (*, burn = 0, thin = 1))]
664 fn get_chain<'py>(
665 &self,
666 py: Python<'py>,
667 burn: Option<usize>,
668 thin: Option<usize>,
669 ) -> PyResult<Bound<'py, PyArray3<Float>>> {
670 let chain = self.0.get_chain(burn, thin);
671 Ok(PyArray3::from_vec3(
672 py,
673 &chain
674 .iter()
675 .map(|walker| {
676 walker
677 .iter()
678 .map(|step| step.data.as_vec().to_vec())
679 .collect()
680 })
681 .collect::<Vec<_>>(),
682 )
683 .map_err(LadduError::NumpyError)?)
684 }
685 #[pyo3(signature = (*, burn = 0, thin = 1))]
706 fn get_flat_chain<'py>(
707 &self,
708 py: Python<'py>,
709 burn: Option<usize>,
710 thin: Option<usize>,
711 ) -> PyResult<Bound<'py, PyArray2<Float>>> {
712 let chain = self.0.get_flat_chain(burn, thin);
713 Ok(PyArray2::from_vec2(
714 py,
715 &chain
716 .iter()
717 .map(|step| step.data.as_vec().to_vec())
718 .collect::<Vec<_>>(),
719 )
720 .map_err(LadduError::NumpyError)?)
721 }
722 fn save_as(&self, path: &str) -> PyResult<()> {
735 self.0.save_as(path)?;
736 Ok(())
737 }
738 #[staticmethod]
756 fn load_from(path: &str) -> PyResult<Self> {
757 Ok(PyEnsemble(Ensemble::load_from(path)?))
758 }
759 #[new]
760 fn new() -> Self {
761 PyEnsemble(Ensemble::create_null())
762 }
763 fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
764 Ok(PyBytes::new(
765 py,
766 serialize(&self.0)
767 .map_err(LadduError::SerdeError)?
768 .as_slice(),
769 ))
770 }
771 fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
772 *self = PyEnsemble(deserialize(state.as_bytes()).map_err(LadduError::SerdeError)?);
773 Ok(())
774 }
775 #[pyo3(signature = (*, c=7.0, burn=0, thin=1))]
789 fn get_integrated_autocorrelation_times<'py>(
790 &self,
791 py: Python<'py>,
792 c: Option<Float>,
793 burn: Option<usize>,
794 thin: Option<usize>,
795 ) -> Bound<'py, PyArray1<Float>> {
796 PyArray1::from_slice(
797 py,
798 self.0
799 .get_integrated_autocorrelation_times(c, burn, thin)
800 .as_slice(),
801 )
802 }
803 }
804
805 #[pyfunction(name = "integrated_autocorrelation_times")]
818 #[pyo3(signature = (x, *, c=7.0))]
819 pub fn py_integrated_autocorrelation_times(
820 py: Python<'_>,
821 x: Vec<Vec<Vec<Float>>>,
822 c: Option<Float>,
823 ) -> Bound<'_, PyArray1<Float>> {
824 let x: Vec<Vec<DVector<Float>>> = x
825 .into_iter()
826 .map(|y| y.into_iter().map(DVector::from_vec).collect())
827 .collect();
828 PyArray1::from_slice(py, integrated_autocorrelation_times(x, c).as_slice())
829 }
830
831 #[pyclass(name = "AutocorrelationObserver", module = "laddu")]
852 pub struct PyAutocorrelationObserver(Arc<RwLock<AutocorrelationObserver>>);
853
854 #[pymethods]
855 impl PyAutocorrelationObserver {
856 #[new]
857 #[pyo3(signature = (*, n_check=50, n_taus_threshold=50, dtau_threshold=0.01, discard=0.5, terminate=true, c=7.0, verbose=false))]
858 fn new(
859 n_check: usize,
860 n_taus_threshold: usize,
861 dtau_threshold: Float,
862 discard: Float,
863 terminate: bool,
864 c: Float,
865 verbose: bool,
866 ) -> Self {
867 Self(
868 AutocorrelationObserver::default()
869 .with_n_check(n_check)
870 .with_n_taus_threshold(n_taus_threshold)
871 .with_dtau_threshold(dtau_threshold)
872 .with_discard(discard)
873 .with_terminate(terminate)
874 .with_sokal_window(c)
875 .with_verbose(verbose)
876 .build(),
877 )
878 }
879 #[getter]
882 fn taus<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<Float>> {
883 let taus = self.0.read().taus.clone();
884 PyArray1::from_vec(py, taus)
885 }
886 }
887
888 #[pyclass]
891 #[derive(Clone)]
892 #[pyo3(name = "Bound")]
893 pub struct PyBound(laddu_core::Bound);
894 #[pymethods]
895 impl PyBound {
896 #[getter]
903 fn lower(&self) -> Float {
904 self.0.lower()
905 }
906 #[getter]
913 fn upper(&self) -> Float {
914 self.0.upper()
915 }
916 }
917
918 impl Observer<()> for PyObserver {
919 fn callback(&mut self, step: usize, status: &mut Status, _user_data: &mut ()) -> bool {
920 let (new_status, result) = Python::with_gil(|py| {
921 let res = self
922 .0
923 .bind(py)
924 .call_method(
925 "callback",
926 (step, PyStatus(status.clone())),
927 None,
928 )
929 .expect("Observer does not have a \"callback(step: int, status: laddu.Status) -> tuple[laddu.Status, bool]\" method!");
930 let res_tuple = res
931 .downcast::<PyTuple>()
932 .expect("\"callback\" method should return a \"tuple[laddu.Status, bool]\"!");
933 let new_status = res_tuple
934 .get_item(0)
935 .expect("\"callback\" method should return a \"tuple[laddu.Status, bool]\"!")
936 .extract::<PyStatus>()
937 .expect("The first item returned from \"callback\" must be a \"laddu.Status\"!")
938 .0;
939 let result = res_tuple
940 .get_item(1)
941 .expect("\"callback\" method should return a \"tuple[laddu.Status, bool]\"!")
942 .extract::<bool>()
943 .expect("The second item returned from \"callback\" must be a \"bool\"!");
944 (new_status, result)
945 });
946 *status = new_status;
947 result
948 }
949 }
950
951 #[cfg(feature = "rayon")]
952 impl Observer<ThreadPool> for PyObserver {
953 fn callback(
954 &mut self,
955 step: usize,
956 status: &mut Status,
957 _thread_pool: &mut ThreadPool,
958 ) -> bool {
959 let (new_status, result) = Python::with_gil(|py| {
960 let res = self
961 .0
962 .bind(py)
963 .call_method(
964 "callback",
965 (step, PyStatus(status.clone())),
966 None,
967 )
968 .expect("Observer does not have a \"callback(step: int, status: laddu.Status) -> tuple[laddu.Status, bool]\" method!");
969 let res_tuple = res
970 .downcast::<PyTuple>()
971 .expect("\"callback\" method should return a \"tuple[laddu.Status, bool]\"!");
972 let new_status = res_tuple
973 .get_item(0)
974 .expect("\"callback\" method should return a \"tuple[laddu.Status, bool]\"!")
975 .extract::<PyStatus>()
976 .expect("The first item returned from \"callback\" must be a \"laddu.Status\"!")
977 .0;
978 let result = res_tuple
979 .get_item(1)
980 .expect("\"callback\" method should return a \"tuple[laddu.Status, bool]\"!")
981 .extract::<bool>()
982 .expect("The second item returned from \"callback\" must be a \"bool\"!");
983 (new_status, result)
984 });
985 *status = new_status;
986 result
987 }
988 }
989 impl FromPyObject<'_> for PyObserver {
990 fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
991 Ok(PyObserver(ob.clone().into()))
992 }
993 }
994 impl MCMCObserver<()> for PyMCMCObserver {
995 fn callback(&mut self, step: usize, ensemble: &mut Ensemble, _user_data: &mut ()) -> bool {
996 let (new_ensemble, result) = Python::with_gil(|py| {
997 let res = self
998 .0
999 .bind(py)
1000 .call_method(
1001 "callback",
1002 (step, PyEnsemble(ensemble.clone())),
1003 None,
1004 )
1005 .expect("MCMCObserver does not have a \"callback(step: int, status: laddu.Ensemble) -> tuple[laddu.Ensemble, bool]\" method!");
1006 let res_tuple = res
1007 .downcast::<PyTuple>()
1008 .expect("\"callback\" method should return a \"tuple[laddu.Status, bool]\"!");
1009 let new_status = res_tuple
1010 .get_item(0)
1011 .expect("\"callback\" method should return a \"tuple[laddu.Status, bool]\"!")
1012 .extract::<PyEnsemble>()
1013 .expect(
1014 "The first item returned from \"callback\" must be a \"laddu.Ensemble\"!",
1015 )
1016 .0;
1017 let result = res_tuple
1018 .get_item(1)
1019 .expect("\"callback\" method should return a \"tuple[laddu.Status, bool]\"!")
1020 .extract::<bool>()
1021 .expect("The second item returned from \"callback\" must be a \"bool\"!");
1022 (new_status, result)
1023 });
1024 *ensemble = new_ensemble;
1025 result
1026 }
1027 }
1028 #[cfg(feature = "rayon")]
1029 impl MCMCObserver<ThreadPool> for PyMCMCObserver {
1030 fn callback(
1031 &mut self,
1032 step: usize,
1033 ensemble: &mut Ensemble,
1034 _thread_pool: &mut ThreadPool,
1035 ) -> bool {
1036 let (new_ensemble, result) = Python::with_gil(|py| {
1037 let res = self
1038 .0
1039 .bind(py)
1040 .call_method(
1041 "callback",
1042 (step, PyEnsemble(ensemble.clone())),
1043 None,
1044 )
1045 .expect("MCMCObserver does not have a \"callback(step: int, status: laddu.Ensemble) -> tuple[laddu.Ensemble, bool]\" method!");
1046 let res_tuple = res
1047 .downcast::<PyTuple>()
1048 .expect("\"callback\" method should return a \"tuple[laddu.Status, bool]\"!");
1049 let new_status = res_tuple
1050 .get_item(0)
1051 .expect("\"callback\" method should return a \"tuple[laddu.Status, bool]\"!")
1052 .extract::<PyEnsemble>()
1053 .expect(
1054 "The first item returned from \"callback\" must be a \"laddu.Ensemble\"!",
1055 )
1056 .0;
1057 let result = res_tuple
1058 .get_item(1)
1059 .expect("\"callback\" method should return a \"tuple[laddu.Status, bool]\"!")
1060 .extract::<bool>()
1061 .expect("The second item returned from \"callback\" must be a \"bool\"!");
1062 (new_status, result)
1063 });
1064 *ensemble = new_ensemble;
1065 result
1066 }
1067 }
1068
1069 impl FromPyObject<'_> for PyMCMCObserver {
1070 fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
1071 Ok(PyMCMCObserver(ob.clone().into()))
1072 }
1073 }
1074
/// Build [`MinimizerOptions`] from the arguments of a Python-facing minimization call.
///
/// * `n_parameters` - number of free parameters, used by the adaptive Nelder-Mead variant.
/// * `method` - `"lbfgsb"` or `"nelder_mead"` (case-insensitive).
/// * `max_steps` - maximum number of algorithm steps.
/// * `debug` - append a debug observer.
/// * `verbose` - append a printing observer controlled by `show_step`/`show_x`/`show_fx`.
/// * `kwargs` - optional settings. Recognized keys:
///   - `show_step`, `show_x`, `show_fx`
///   - `tol_x_rel`, `tol_x_abs`, `tol_f_rel`, `tol_f_abs`, `tol_g_abs`, `g_tolerance`
///   - `skip_hessian`, `adaptive`, `alpha`, `beta`, `gamma`, `delta`
///   - `simplex_expansion_method`, `nelder_mead_f_terminator`, `nelder_mead_x_terminator`
///   - `threads` (with `rayon`), `observers`
///
/// A missing `kwargs` dict behaves exactly like an empty one. In particular, `method` is
/// always honoured and validated.
///
/// # Errors
/// * `PyValueError` for an unknown method, terminator, or simplex expansion method.
/// * `PyTypeError` for a malformed `observers` argument.
/// * Propagated extraction errors for mistyped keyword values.
#[cfg(feature = "python")]
pub(crate) fn py_parse_minimizer_options(
    n_parameters: usize,
    method: &str,
    max_steps: usize,
    debug: bool,
    verbose: bool,
    kwargs: Option<&Bound<'_, PyDict>>,
) -> PyResult<MinimizerOptions> {
    use ganesh::algorithms::lbfgsb::LBFGSBErrorMode;

    // Optional keyword lookup that treats a missing `kwargs` dict like a dict without the key.
    // Type errors from extraction are propagated with `?`.
    macro_rules! kwarg {
        ($key:literal, $ty:ty) => {
            kwargs
                .map(|kw| kw.get_extract::<$ty>($key))
                .transpose()?
                .flatten()
        };
    }

    let show_step = kwarg!("show_step", bool).unwrap_or(true);
    let show_x = kwarg!("show_x", bool).unwrap_or(true);
    let show_fx = kwarg!("show_fx", bool).unwrap_or(true);
    let tol_x_rel = kwarg!("tol_x_rel", Float).unwrap_or(Float::EPSILON);
    let tol_x_abs = kwarg!("tol_x_abs", Float).unwrap_or(Float::EPSILON);
    let tol_f_rel = kwarg!("tol_f_rel", Float).unwrap_or(Float::EPSILON);
    let tol_f_abs = kwarg!("tol_f_abs", Float).unwrap_or(Float::EPSILON);
    let tol_g_abs = kwarg!("tol_g_abs", Float).unwrap_or(Float::cbrt(Float::EPSILON));
    let g_tolerance = kwarg!("g_tolerance", Float).unwrap_or(1e-5);
    let skip_hessian = kwarg!("skip_hessian", bool).unwrap_or(false);
    let adaptive = kwarg!("adaptive", bool).unwrap_or(false);
    let alpha = kwarg!("alpha", Float);
    let beta = kwarg!("beta", Float);
    let gamma = kwarg!("gamma", Float);
    let delta = kwarg!("delta", Float);
    let simplex_expansion_method = kwarg!("simplex_expansion_method", String)
        .unwrap_or_else(|| "greedy minimization".into());
    let nelder_mead_f_terminator =
        kwarg!("nelder_mead_f_terminator", String).unwrap_or_else(|| "stddev".into());
    let nelder_mead_x_terminator =
        kwarg!("nelder_mead_x_terminator", String).unwrap_or_else(|| "singer".into());
    // Kept lenient as before: an unextractable `threads` value falls back to `num_cpus::get`.
    #[cfg(feature = "rayon")]
    let threads = kwargs
        .and_then(|kw| kw.get_extract::<usize>("threads").unwrap_or(None))
        .unwrap_or_else(num_cpus::get);

    let mut options = MinimizerOptions::default();

    if let Some(kwargs) = kwargs {
        if let Ok(Some(observer_arg)) = kwargs.get_item("observers") {
            // `observers=None` means "no custom observers"; wrapping `None` would panic later
            // when its (nonexistent) `callback` method is called.
            if !observer_arg.is_none() {
                if let Ok(observer_list) = observer_arg.downcast::<PyList>() {
                    for item in observer_list.iter() {
                        let observer = item.extract::<PyObserver>()?;
                        options = options.with_observer(Arc::new(RwLock::new(observer)));
                    }
                } else if let Ok(single_observer) = observer_arg.extract::<PyObserver>() {
                    options = options.with_observer(Arc::new(RwLock::new(single_observer)));
                } else {
                    return Err(PyTypeError::new_err("The keyword argument \"observers\" must either be a single Observer or a list of Observers!"));
                }
            }
        }
    }

    // The algorithm is chosen regardless of whether `kwargs` was given.
    match method.to_lowercase().as_str() {
        "lbfgsb" => {
            let mut lbfgsb = LBFGSB::default()
                .with_terminator_f(LBFGSBFTerminator { tol_f_abs })
                .with_terminator_g(LBFGSBGTerminator { tol_g_abs })
                .with_g_tolerance(g_tolerance);
            if skip_hessian {
                lbfgsb = lbfgsb.with_error_mode(LBFGSBErrorMode::Skip);
            }
            options = options.with_algorithm(lbfgsb)
        }
        "nelder_mead" => {
            let terminator_f = match nelder_mead_f_terminator.as_str() {
                "amoeba" => NelderMeadFTerminator::Amoeba { tol_f_rel },
                "absolute" => NelderMeadFTerminator::Absolute { tol_f_abs },
                "stddev" => NelderMeadFTerminator::StdDev { tol_f_abs },
                "none" => NelderMeadFTerminator::None,
                _ => {
                    return Err(PyValueError::new_err(format!(
                        "Invalid \"nelder_mead_f_terminator\": \"{}\"",
                        nelder_mead_f_terminator
                    )))
                }
            };
            let terminator_x = match nelder_mead_x_terminator.as_str() {
                "diameter" => NelderMeadXTerminator::Diameter { tol_x_abs },
                "higham" => NelderMeadXTerminator::Higham { tol_x_rel },
                "rowan" => NelderMeadXTerminator::Rowan { tol_x_rel },
                "singer" => NelderMeadXTerminator::Singer { tol_x_rel },
                "none" => NelderMeadXTerminator::None,
                _ => {
                    return Err(PyValueError::new_err(format!(
                        "Invalid \"nelder_mead_x_terminator\": \"{}\"",
                        nelder_mead_x_terminator
                    )))
                }
            };
            let simplex_expansion_method = match simplex_expansion_method.as_str() {
                "greedy minimization" => SimplexExpansionMethod::GreedyMinimization,
                "greedy expansion" => SimplexExpansionMethod::GreedyExpansion,
                _ => {
                    return Err(PyValueError::new_err(format!(
                        "Invalid \"simplex_expansion_method\": \"{}\"",
                        simplex_expansion_method
                    )))
                }
            };
            let mut nelder_mead = NelderMead::default()
                .with_terminator_f(terminator_f)
                .with_terminator_x(terminator_x)
                .with_expansion_method(simplex_expansion_method);
            if adaptive {
                nelder_mead = nelder_mead.with_adaptive(n_parameters);
            }
            if let Some(alpha) = alpha {
                nelder_mead = nelder_mead.with_alpha(alpha);
            }
            if let Some(beta) = beta {
                nelder_mead = nelder_mead.with_beta(beta);
            }
            if let Some(gamma) = gamma {
                nelder_mead = nelder_mead.with_gamma(gamma);
            }
            if let Some(delta) = delta {
                nelder_mead = nelder_mead.with_delta(delta);
            }
            if skip_hessian {
                nelder_mead = nelder_mead.with_no_error_calculation();
            }
            options = options.with_algorithm(nelder_mead)
        }
        _ => {
            return Err(PyValueError::new_err(format!(
                "Invalid \"method\": \"{}\"",
                method
            )))
        }
    }
    #[cfg(feature = "rayon")]
    {
        options = options.with_threads(threads);
    }
    if debug {
        options = options.debug();
    }
    if verbose {
        options = options.verbose(show_step, show_x, show_fx);
    }
    options = options.with_max_steps(max_steps);
    Ok(options)
}
1238
/// Build [`MCMCOptions`] from the arguments of a Python-facing MCMC call.
///
/// * `method` - `"ess"` or `"aies"` (case-insensitive).
/// * `debug` / `verbose` - append the debug / step-printing observers.
/// * `kwargs` - optional settings. Recognized keys:
///   - `n_adaptive`, `mu`, `max_ess_steps` (ESS only)
///   - `ess_moves`: list of `(name, weight)` tuples
///   - `aies_moves`: list of `(name, weight)` or `((name, hyperparameter), weight)` tuples
///   - `threads` (with `rayon`), `observers`
/// * `rng` - random number generator handed to the sampler.
///
/// A missing `kwargs` dict behaves exactly like an empty one. In particular, `method` is
/// always honoured and validated.
///
/// # Errors
/// * `PyValueError` for an unknown method or move name, or a hyperparameter on a move
///   other than `stretch`.
/// * `PyTypeError` for malformed move entries or `observers`.
/// * Propagated extraction errors for mistyped keyword values.
#[cfg(feature = "python")]
pub(crate) fn py_parse_mcmc_options(
    method: &str,
    debug: bool,
    verbose: bool,
    kwargs: Option<&Bound<'_, PyDict>>,
    rng: Rng,
) -> PyResult<MCMCOptions> {
    // Optional keyword lookup that treats a missing `kwargs` dict like a dict without the key.
    // Type errors from extraction are propagated with `?`.
    macro_rules! kwarg {
        ($key:literal, $ty:ty) => {
            kwargs
                .map(|kw| kw.get_extract::<$ty>($key))
                .transpose()?
                .flatten()
        };
    }

    let default_ess_moves = [ESSMove::differential(0.9), ESSMove::gaussian(0.1)];
    let default_aies_moves = [AIESMove::stretch(0.9), AIESMove::walk(0.1)];
    let n_adaptive = kwarg!("n_adaptive", usize).unwrap_or(100);
    let mu = kwarg!("mu", Float).unwrap_or(1.0);
    let max_ess_steps = kwarg!("max_ess_steps", usize).unwrap_or(10000);
    // Kept lenient as before: an unextractable `threads` value falls back to `num_cpus::get`.
    #[cfg(feature = "rayon")]
    let threads = kwargs
        .and_then(|kw| kw.get_extract::<usize>("threads").unwrap_or(None))
        .unwrap_or_else(num_cpus::get);

    let mut options = MCMCOptions::new_ess(default_ess_moves, rng.clone());
    let mut ess_moves: Vec<WeightedESSMove> = Vec::new();
    let mut aies_moves: Vec<WeightedAIESMove> = Vec::new();

    if let Some(kwargs) = kwargs {
        if let Ok(Some(ess_move_list_arg)) = kwargs.get_item("ess_moves") {
            if let Ok(ess_move_list) = ess_move_list_arg.downcast::<PyList>() {
                for item in ess_move_list.iter() {
                    let item_tuple = item.downcast::<PyTuple>()?;
                    let move_name = item_tuple.get_item(0)?.extract::<String>()?;
                    let move_weight = item_tuple.get_item(1)?.extract::<Float>()?;
                    match move_name.to_lowercase().as_ref() {
                        "differential" => ess_moves.push(ESSMove::differential(move_weight)),
                        "gaussian" => ess_moves.push(ESSMove::gaussian(move_weight)),
                        _ => {
                            return Err(PyValueError::new_err(format!(
                                "Unknown ESS move type: {}",
                                move_name
                            )))
                        }
                    }
                }
            }
        }
        if let Ok(Some(aies_move_list_arg)) = kwargs.get_item("aies_moves") {
            if let Ok(aies_move_list) = aies_move_list_arg.downcast::<PyList>() {
                for item in aies_move_list.iter() {
                    let item_tuple = item.downcast::<PyTuple>()?;
                    let move_spec = item_tuple.get_item(0)?;
                    if let Ok(move_name) = move_spec.extract::<String>() {
                        let move_weight = item_tuple.get_item(1)?.extract::<Float>()?;
                        match move_name.to_lowercase().as_ref() {
                            "stretch" => aies_moves.push(AIESMove::stretch(move_weight)),
                            "walk" => aies_moves.push(AIESMove::walk(move_weight)),
                            _ => {
                                return Err(PyValueError::new_err(format!(
                                    "Unknown AIES move type: {}",
                                    move_name
                                )))
                            }
                        }
                    } else if let Ok(move_params) = move_spec.downcast::<PyTuple>() {
                        let move_name = move_params.get_item(0)?.extract::<String>()?;
                        let move_weight = item_tuple.get_item(1)?.extract::<Float>()?;
                        if move_name.to_lowercase() == "stretch" {
                            let a = move_params.get_item(1)?.extract::<Float>()?;
                            aies_moves.push((AIESMove::Stretch { a }, move_weight))
                        } else {
                            return Err(PyValueError::new_err(
                                "Only the 'stretch' move has a hyperparameter",
                            ));
                        }
                    } else {
                        // Previously such entries were silently dropped.
                        return Err(PyTypeError::new_err(
                            "Each AIES move must be a (name, weight) or ((name, hyperparameter), weight) tuple",
                        ));
                    }
                }
            }
        }
        if let Ok(Some(observer_arg)) = kwargs.get_item("observers") {
            // `observers=None` means "no custom observers"; wrapping `None` would panic later
            // when its (nonexistent) `callback` method is called.
            if !observer_arg.is_none() {
                if let Ok(observer_list) = observer_arg.downcast::<PyList>() {
                    for item in observer_list.iter() {
                        if let Ok(observer) = item.downcast::<PyAutocorrelationObserver>() {
                            options = options.with_observer(observer.borrow().0.clone());
                        } else if let Ok(observer) = item.extract::<PyMCMCObserver>() {
                            options = options.with_observer(Arc::new(RwLock::new(observer)));
                        }
                    }
                } else if let Ok(single_observer) =
                    observer_arg.downcast::<PyAutocorrelationObserver>()
                {
                    options = options.with_observer(single_observer.borrow().0.clone());
                } else if let Ok(single_observer) = observer_arg.extract::<PyMCMCObserver>() {
                    options = options.with_observer(Arc::new(RwLock::new(single_observer)));
                } else {
                    return Err(PyTypeError::new_err("The keyword argument \"observers\" must either be a single MCMCObserver or a list of MCMCObservers!"));
                }
            }
        }
    }
    if ess_moves.is_empty() {
        ess_moves = default_ess_moves.to_vec();
    }
    if aies_moves.is_empty() {
        aies_moves = default_aies_moves.to_vec();
    }

    // The sampler is chosen regardless of whether `kwargs` was given.
    match method.to_lowercase().as_ref() {
        "ess" => {
            options = options.with_algorithm(
                ESS::new(ess_moves, rng)
                    .with_mu(mu)
                    .with_n_adaptive(n_adaptive)
                    .with_max_steps(max_ess_steps),
            )
        }
        "aies" => options = options.with_algorithm(AIES::new(aies_moves, rng)),
        _ => {
            return Err(PyValueError::new_err(format!(
                "Invalid \"method\": \"{}\"",
                method
            )))
        }
    }
    #[cfg(feature = "rayon")]
    {
        options = options.with_threads(threads);
    }
    if debug {
        options = options.debug();
    }
    if verbose {
        options = options.verbose();
    }
    Ok(options)
}
1379}