1use crate::{
2 likelihoods::{LikelihoodTerm, StochasticNLL},
3 LikelihoodEvaluator, NLL,
4};
5use ganesh::{
6 algorithms::{
7 gradient::{Adam, AdamConfig, GradientStatus, LBFGSBConfig, LBFGSB},
8 gradient_free::{GradientFreeStatus, NelderMead, NelderMeadConfig},
9 mcmc::{AIESConfig, ESSConfig, EnsembleStatus, AIES, ESS},
10 particles::{PSOConfig, SwarmStatus, PSO},
11 },
12 core::{summary::HasParameterNames, Callbacks, MCMCSummary, MinimizationSummary},
13 traits::{Algorithm, CostFunction, Gradient, LogDensity, Observer, Status},
14};
15use laddu_core::{Float, LadduError};
16use nalgebra::DVector;
17#[cfg(feature = "rayon")]
18use rayon::{ThreadPool, ThreadPoolBuilder};
19
/// Selects a minimization algorithm together with its configuration, its callbacks, and the
/// number of threads used while evaluating the objective.
///
/// `P` is the problem type the callbacks operate on (for example `NLL` or `StochasticNLL`).
pub enum MinimizationSettings<P> {
    /// Use the gradient-based `LBFGSB` minimizer.
    LBFGSB {
        /// Algorithm configuration.
        config: LBFGSBConfig,
        /// Observers and terminators supplied by the caller.
        callbacks: Callbacks<LBFGSB, P, GradientStatus, MaybeThreadPool, LadduError, LBFGSBConfig>,
        /// Size of the `rayon` thread pool built by `Threadable::get_pool` (`0` lets rayon
        /// choose); unused without the `rayon` feature.
        num_threads: usize,
    },
    /// Use the gradient-based `Adam` minimizer.
    Adam {
        /// Algorithm configuration.
        config: AdamConfig,
        /// Observers and terminators supplied by the caller.
        callbacks: Callbacks<Adam, P, GradientStatus, MaybeThreadPool, LadduError, AdamConfig>,
        /// Size of the `rayon` thread pool built by `Threadable::get_pool` (`0` lets rayon
        /// choose); unused without the `rayon` feature.
        num_threads: usize,
    },
    /// Use the gradient-free `NelderMead` minimizer.
    NelderMead {
        /// Algorithm configuration.
        config: NelderMeadConfig,
        /// Observers and terminators supplied by the caller.
        callbacks: Callbacks<
            NelderMead,
            P,
            GradientFreeStatus,
            MaybeThreadPool,
            LadduError,
            NelderMeadConfig,
        >,
        /// Size of the `rayon` thread pool built by `Threadable::get_pool` (`0` lets rayon
        /// choose); unused without the `rayon` feature.
        num_threads: usize,
    },
    /// Use the particle-swarm (`PSO`) minimizer.
    PSO {
        /// Algorithm configuration (including the swarm).
        config: PSOConfig,
        /// Observers and terminators supplied by the caller.
        callbacks: Callbacks<PSO, P, SwarmStatus, MaybeThreadPool, LadduError, PSOConfig>,
        /// Size of the `rayon` thread pool built by `Threadable::get_pool` (`0` lets rayon
        /// choose); unused without the `rayon` feature.
        num_threads: usize,
    },
}
66
/// Selects an MCMC ensemble sampler together with its configuration, its callbacks, and the
/// number of threads used while evaluating the log-density.
///
/// `P` is the problem type the callbacks operate on (for example `NLL` or `StochasticNLL`).
pub enum MCMCSettings<P> {
    /// Run the `AIES` ensemble sampler.
    AIES {
        /// Sampler configuration.
        config: AIESConfig,
        /// Observers and terminators supplied by the caller.
        callbacks: Callbacks<AIES, P, EnsembleStatus, MaybeThreadPool, LadduError, AIESConfig>,
        /// Size of the `rayon` thread pool built by `Threadable::get_pool` (`0` lets rayon
        /// choose); unused without the `rayon` feature.
        num_threads: usize,
    },
    /// Run the `ESS` ensemble sampler.
    ESS {
        /// Sampler configuration.
        config: ESSConfig,
        /// Observers and terminators supplied by the caller.
        callbacks: Callbacks<ESS, P, EnsembleStatus, MaybeThreadPool, LadduError, ESSConfig>,
        /// Size of the `rayon` thread pool built by `Threadable::get_pool` (`0` lets rayon
        /// choose); unused without the `rayon` feature.
        num_threads: usize,
    },
}
88
/// User argument handed to every cost, gradient, and log-density evaluation in this module.
///
/// With the `rayon` feature it owns the thread pool the evaluations run in; without that
/// feature the struct has no fields.
#[derive(Debug)]
pub struct MaybeThreadPool {
    /// Pool used to run likelihood evaluations (present only with the `rayon` feature).
    #[cfg(feature = "rayon")]
    pub thread_pool: ThreadPool,
}
/// Settings types that can produce the [`MaybeThreadPool`] passed to `ganesh` algorithms.
pub trait Threadable {
    /// Builds the [`MaybeThreadPool`] described by these settings.
    ///
    /// # Errors
    ///
    /// Returns a [`LadduError`] if the pool cannot be constructed.
    fn get_pool(&self) -> Result<MaybeThreadPool, LadduError>;
}
105impl<P> Threadable for MinimizationSettings<P> {
106 fn get_pool(&self) -> Result<MaybeThreadPool, LadduError> {
107 #[cfg(feature = "rayon")]
108 {
109 Ok(MaybeThreadPool {
110 thread_pool: ThreadPoolBuilder::new()
111 .num_threads(match self {
112 Self::LBFGSB {
113 config: _,
114 callbacks: _,
115 num_threads,
116 }
117 | Self::Adam {
118 config: _,
119 callbacks: _,
120 num_threads,
121 }
122 | Self::NelderMead {
123 config: _,
124 callbacks: _,
125 num_threads,
126 }
127 | Self::PSO {
128 config: _,
129 callbacks: _,
130 num_threads,
131 } => *num_threads,
132 })
133 .build()
134 .map_err(LadduError::from)?,
135 })
136 }
137 #[cfg(not(feature = "rayon"))]
138 {
139 Ok(MaybeThreadPool {})
140 }
141 }
142}
143
144impl<P> Threadable for MCMCSettings<P> {
145 fn get_pool(&self) -> Result<MaybeThreadPool, LadduError> {
146 #[cfg(feature = "rayon")]
147 {
148 Ok(MaybeThreadPool {
149 thread_pool: ThreadPoolBuilder::new()
150 .num_threads(match self {
151 Self::AIES {
152 config: _,
153 callbacks: _,
154 num_threads,
155 }
156 | Self::ESS {
157 config: _,
158 callbacks: _,
159 num_threads,
160 } => *num_threads,
161 })
162 .build()
163 .map_err(LadduError::from)?,
164 })
165 }
166 #[cfg(not(feature = "rayon"))]
167 {
168 Ok(MaybeThreadPool {})
169 }
170 }
171}
172
/// Observer added to the callbacks of every minimization and MCMC run in this module.
///
/// Whenever `ganesh` invokes it, it calls `update()` on the problem (bounded by
/// `LikelihoodTerm`) and ignores every other observer input.
#[derive(Copy, Clone)]
struct LikelihoodTermObserver;
175impl<A, P, S, U, E, C> Observer<A, P, S, U, E, C> for LikelihoodTermObserver
176where
177 A: Algorithm<P, S, U, E, Config = C>,
178 P: LikelihoodTerm,
179 S: Status,
180{
181 fn observe(
182 &mut self,
183 _current_step: usize,
184 _algorithm: &A,
185 problem: &P,
186 _status: &S,
187 _args: &U,
188 _config: &C,
189 ) {
190 problem.update();
191 }
192}
193
194impl CostFunction<MaybeThreadPool, LadduError> for NLL {
195 fn evaluate(
196 &self,
197 parameters: &DVector<Float>,
198 args: &MaybeThreadPool,
199 ) -> Result<Float, LadduError> {
200 #[cfg(feature = "rayon")]
201 {
202 Ok(args
203 .thread_pool
204 .install(|| LikelihoodTerm::evaluate(self, parameters.into())))
205 }
206 #[cfg(not(feature = "rayon"))]
207 {
208 Ok(LikelihoodTerm::evaluate(self, parameters.into()))
209 }
210 }
211}
212impl Gradient<MaybeThreadPool, LadduError> for NLL {
213 fn gradient(
214 &self,
215 parameters: &DVector<Float>,
216 args: &MaybeThreadPool,
217 ) -> Result<DVector<Float>, LadduError> {
218 #[cfg(feature = "rayon")]
219 {
220 Ok(args
221 .thread_pool
222 .install(|| LikelihoodTerm::evaluate_gradient(self, parameters.into())))
223 }
224 #[cfg(not(feature = "rayon"))]
225 {
226 Ok(LikelihoodTerm::evaluate_gradient(self, parameters.into()))
227 }
228 }
229}
230impl LogDensity<MaybeThreadPool, LadduError> for NLL {
231 fn log_density(
232 &self,
233 parameters: &DVector<Float>,
234 args: &MaybeThreadPool,
235 ) -> Result<Float, LadduError> {
236 #[cfg(feature = "rayon")]
237 {
238 Ok(-args
239 .thread_pool
240 .install(|| LikelihoodTerm::evaluate(self, parameters.into())))
241 }
242 #[cfg(not(feature = "rayon"))]
243 {
244 Ok(-LikelihoodTerm::evaluate(self, parameters.into()))
245 }
246 }
247}
248
249impl NLL {
250 pub fn minimize(
257 &self,
258 settings: MinimizationSettings<Self>,
259 ) -> Result<MinimizationSummary, LadduError> {
260 let mtp = settings.get_pool()?;
261 Ok(match settings {
262 MinimizationSettings::LBFGSB {
263 config,
264 callbacks,
265 num_threads: _,
266 } => LBFGSB::default().process(
267 self,
268 &mtp,
269 config,
270 callbacks.with_observer(LikelihoodTermObserver),
271 ),
272 MinimizationSettings::Adam {
273 config,
274 callbacks,
275 num_threads: _,
276 } => Adam::default().process(
277 self,
278 &mtp,
279 config,
280 callbacks.with_observer(LikelihoodTermObserver),
281 ),
282 MinimizationSettings::NelderMead {
283 config,
284 callbacks,
285 num_threads: _,
286 } => NelderMead::default().process(
287 self,
288 &mtp,
289 config,
290 callbacks.with_observer(LikelihoodTermObserver),
291 ),
292 MinimizationSettings::PSO {
293 config,
294 callbacks,
295 num_threads: _,
296 } => PSO::default().process(
297 self,
298 &mtp,
299 config,
300 callbacks.with_observer(LikelihoodTermObserver),
301 ),
302 }?
303 .with_parameter_names(self.parameters()))
304 }
305
306 pub fn mcmc(&self, settings: MCMCSettings<Self>) -> Result<MCMCSummary, LadduError> {
313 let mtp = settings.get_pool()?;
314 Ok(match settings {
315 MCMCSettings::AIES {
316 config,
317 callbacks,
318 num_threads: _,
319 } => AIES::default().process(
320 self,
321 &mtp,
322 config,
323 callbacks.with_observer(LikelihoodTermObserver),
324 ),
325 MCMCSettings::ESS {
326 config,
327 callbacks,
328 num_threads: _,
329 } => ESS::default().process(
330 self,
331 &mtp,
332 config,
333 callbacks.with_observer(LikelihoodTermObserver),
334 ),
335 }?
336 .with_parameter_names(self.parameters()))
337 }
338}
339
340impl CostFunction<MaybeThreadPool, LadduError> for StochasticNLL {
341 fn evaluate(
342 &self,
343 parameters: &DVector<Float>,
344 args: &MaybeThreadPool,
345 ) -> Result<Float, LadduError> {
346 #[cfg(feature = "rayon")]
347 {
348 Ok(args
349 .thread_pool
350 .install(|| LikelihoodTerm::evaluate(self, parameters.into())))
351 }
352 #[cfg(not(feature = "rayon"))]
353 {
354 Ok(LikelihoodTerm::evaluate(self, parameters.into()))
355 }
356 }
357}
358impl Gradient<MaybeThreadPool, LadduError> for StochasticNLL {
359 fn gradient(
360 &self,
361 parameters: &DVector<Float>,
362 args: &MaybeThreadPool,
363 ) -> Result<DVector<Float>, LadduError> {
364 #[cfg(feature = "rayon")]
365 {
366 Ok(args
367 .thread_pool
368 .install(|| LikelihoodTerm::evaluate_gradient(self, parameters.into())))
369 }
370 #[cfg(not(feature = "rayon"))]
371 {
372 Ok(LikelihoodTerm::evaluate_gradient(self, parameters.into()))
373 }
374 }
375}
376impl LogDensity<MaybeThreadPool, LadduError> for StochasticNLL {
377 fn log_density(
378 &self,
379 parameters: &DVector<Float>,
380 args: &MaybeThreadPool,
381 ) -> Result<Float, LadduError> {
382 #[cfg(feature = "rayon")]
383 {
384 Ok(-args
385 .thread_pool
386 .install(|| LikelihoodTerm::evaluate(self, parameters.into())))
387 }
388 #[cfg(not(feature = "rayon"))]
389 {
390 Ok(-LikelihoodTerm::evaluate(self, parameters.into()))
391 }
392 }
393}
394
395impl StochasticNLL {
396 pub fn minimize(
403 &self,
404 settings: MinimizationSettings<Self>,
405 ) -> Result<MinimizationSummary, LadduError> {
406 let mtp = settings.get_pool()?;
407 Ok(match settings {
408 MinimizationSettings::LBFGSB {
409 config,
410 callbacks,
411 num_threads: _,
412 } => LBFGSB::default().process(
413 self,
414 &mtp,
415 config,
416 callbacks.with_observer(LikelihoodTermObserver),
417 ),
418 MinimizationSettings::Adam {
419 config,
420 callbacks,
421 num_threads: _,
422 } => Adam::default().process(
423 self,
424 &mtp,
425 config,
426 callbacks.with_observer(LikelihoodTermObserver),
427 ),
428 MinimizationSettings::NelderMead {
429 config,
430 callbacks,
431 num_threads: _,
432 } => NelderMead::default().process(
433 self,
434 &mtp,
435 config,
436 callbacks.with_observer(LikelihoodTermObserver),
437 ),
438 MinimizationSettings::PSO {
439 config,
440 callbacks,
441 num_threads: _,
442 } => PSO::default().process(
443 self,
444 &mtp,
445 config,
446 callbacks.with_observer(LikelihoodTermObserver),
447 ),
448 }?
449 .with_parameter_names(self.parameters()))
450 }
451
452 pub fn mcmc(&self, settings: MCMCSettings<Self>) -> Result<MCMCSummary, LadduError> {
459 let mtp = settings.get_pool()?;
460 Ok(match settings {
461 MCMCSettings::AIES {
462 config,
463 callbacks,
464 num_threads: _,
465 } => AIES::default().process(
466 self,
467 &mtp,
468 config,
469 callbacks.with_observer(LikelihoodTermObserver),
470 ),
471 MCMCSettings::ESS {
472 config,
473 callbacks,
474 num_threads: _,
475 } => ESS::default().process(
476 self,
477 &mtp,
478 config,
479 callbacks.with_observer(LikelihoodTermObserver),
480 ),
481 }?
482 .with_parameter_names(self.parameters()))
483 }
484}
485
486impl CostFunction<MaybeThreadPool, LadduError> for LikelihoodEvaluator {
487 fn evaluate(
488 &self,
489 parameters: &DVector<Float>,
490 args: &MaybeThreadPool,
491 ) -> Result<Float, LadduError> {
492 #[cfg(feature = "rayon")]
493 {
494 Ok(args
495 .thread_pool
496 .install(|| LikelihoodTerm::evaluate(self, parameters.into())))
497 }
498 #[cfg(not(feature = "rayon"))]
499 {
500 Ok(LikelihoodTerm::evaluate(self, parameters.into()))
501 }
502 }
503}
504impl Gradient<MaybeThreadPool, LadduError> for LikelihoodEvaluator {
505 fn gradient(
506 &self,
507 parameters: &DVector<Float>,
508 args: &MaybeThreadPool,
509 ) -> Result<DVector<Float>, LadduError> {
510 #[cfg(feature = "rayon")]
511 {
512 Ok(args
513 .thread_pool
514 .install(|| LikelihoodTerm::evaluate_gradient(self, parameters.into())))
515 }
516 #[cfg(not(feature = "rayon"))]
517 {
518 Ok(LikelihoodTerm::evaluate_gradient(self, parameters.into()))
519 }
520 }
521}
522impl LogDensity<MaybeThreadPool, LadduError> for LikelihoodEvaluator {
523 fn log_density(
524 &self,
525 parameters: &DVector<Float>,
526 args: &MaybeThreadPool,
527 ) -> Result<Float, LadduError> {
528 #[cfg(feature = "rayon")]
529 {
530 Ok(-args
531 .thread_pool
532 .install(|| LikelihoodTerm::evaluate(self, parameters.into())))
533 }
534 #[cfg(not(feature = "rayon"))]
535 {
536 Ok(-LikelihoodTerm::evaluate(self, parameters.into()))
537 }
538 }
539}
540
541impl LikelihoodEvaluator {
542 pub fn minimize(
549 &self,
550 settings: MinimizationSettings<Self>,
551 ) -> Result<MinimizationSummary, LadduError> {
552 let mtp = settings.get_pool()?;
553 match settings {
554 MinimizationSettings::LBFGSB {
555 config,
556 callbacks,
557 num_threads: _,
558 } => LBFGSB::default().process(
559 self,
560 &mtp,
561 config,
562 callbacks.with_observer(LikelihoodTermObserver),
563 ),
564 MinimizationSettings::Adam {
565 config,
566 callbacks,
567 num_threads: _,
568 } => Adam::default().process(
569 self,
570 &mtp,
571 config,
572 callbacks.with_observer(LikelihoodTermObserver),
573 ),
574 MinimizationSettings::NelderMead {
575 config,
576 callbacks,
577 num_threads: _,
578 } => NelderMead::default().process(
579 self,
580 &mtp,
581 config,
582 callbacks.with_observer(LikelihoodTermObserver),
583 ),
584 MinimizationSettings::PSO {
585 config,
586 callbacks,
587 num_threads: _,
588 } => PSO::default().process(
589 self,
590 &mtp,
591 config,
592 callbacks.with_observer(LikelihoodTermObserver),
593 ),
594 }
595 }
596
597 pub fn mcmc(&self, settings: MCMCSettings<Self>) -> Result<MCMCSummary, LadduError> {
604 let mtp = settings.get_pool()?;
605 match settings {
606 MCMCSettings::AIES {
607 config,
608 callbacks,
609 num_threads: _,
610 } => AIES::default().process(
611 self,
612 &mtp,
613 config,
614 callbacks.with_observer(LikelihoodTermObserver),
615 ),
616 MCMCSettings::ESS {
617 config,
618 callbacks,
619 num_threads: _,
620 } => ESS::default().process(
621 self,
622 &mtp,
623 config,
624 callbacks.with_observer(LikelihoodTermObserver),
625 ),
626 }
627 }
628}
629
630#[cfg(feature = "python")]
632pub mod py_ganesh {
633 use std::{ops::ControlFlow, sync::Arc};
634
635 use super::*;
636
637 use ganesh::{
638 algorithms::{
639 gradient::{
640 adam::AdamEMATerminator,
641 lbfgsb::{
642 LBFGSBErrorMode, LBFGSBFTerminator, LBFGSBGTerminator, LBFGSBInfNormGTerminator,
643 },
644 },
645 gradient_free::nelder_mead::{
646 NelderMeadFTerminator, NelderMeadXTerminator, SimplexConstructionMethod,
647 SimplexExpansionMethod,
648 },
649 line_search::{HagerZhangLineSearch, MoreThuenteLineSearch, StrongWolfeLineSearch},
650 mcmc::{
651 integrated_autocorrelation_times, AIESMove, AutocorrelationTerminator, ESSMove,
652 Walker,
653 },
654 particles::{
655 Swarm, SwarmBoundaryMethod, SwarmParticle, SwarmPositionInitializer, SwarmTopology,
656 SwarmUpdateMethod, SwarmVelocityInitializer,
657 },
658 },
659 core::{Bounds, CtrlCAbortSignal, DebugObserver, MaxSteps},
660 traits::{Observer, Status, SupportsBounds, SupportsTransform, Terminator},
661 };
662 use laddu_core::{Float, LadduError, ReadWrite};
663 use nalgebra::DMatrix;
664 use numpy::{PyArray1, PyArray2, PyArray3, ToPyArray};
665 use parking_lot::Mutex;
666 use pyo3::{
667 exceptions::{PyTypeError, PyValueError},
668 prelude::*,
669 types::{PyBytes, PyDict, PyList},
670 };
671
    /// Construction of a Rust value from a Python keyword dictionary.
    ///
    /// `A` carries any extra, non-dictionary input needed for construction. The config
    /// impls in this module use a `Vec<Float>` of starting parameter values. `A` defaults
    /// to `()`.
    pub trait FromPyArgs<A = ()>: Sized {
        /// Builds `Self` from `args` and the options found in `d`.
        ///
        /// # Errors
        ///
        /// Implementations return a Python exception when a value cannot be extracted, a
        /// required key is missing, or an option string is not recognized.
        fn from_pyargs(args: &A, d: &Bound<PyDict>) -> PyResult<Self>;
    }
677 impl FromPyArgs<Vec<Float>> for LBFGSBConfig {
678 fn from_pyargs(args: &Vec<Float>, d: &Bound<PyDict>) -> PyResult<Self> {
679 let mut config = LBFGSBConfig::new(args);
680 if let Some(m) = d.get_item("m")? {
681 let m_int = m.extract()?;
682 config = config.with_memory_limit(m_int);
683 }
684 if let Some(flag) = d.get_item("skip_hessian")? {
685 if flag.extract()? {
686 config = config.with_error_mode(LBFGSBErrorMode::Skip);
687 }
688 }
689 if let Some(linesearch_dict) = d.get_item("line_search")? {
690 config = config.with_line_search(StrongWolfeLineSearch::from_pyargs(
691 &(),
692 &linesearch_dict.extract()?,
693 )?);
694 }
695 Ok(config)
696 }
697 }
698 impl FromPyArgs for StrongWolfeLineSearch {
699 fn from_pyargs(_args: &(), d: &Bound<PyDict>) -> PyResult<Self> {
700 if let Some(method) = d.get_item("method")? {
701 match method
702 .extract::<String>()?
703 .to_lowercase()
704 .trim()
705 .replace("-", "")
706 .replace(" ", "")
707 .as_str()
708 {
709 "morethuente" => {
710 let mut line_search = MoreThuenteLineSearch::default();
711 if let Some(max_iterations) = d.get_item("max_iterations")? {
712 line_search =
713 line_search.with_max_iterations(max_iterations.extract()?);
714 }
715 if let Some(max_zoom) = d.get_item("max_zoom")? {
716 line_search = line_search.with_max_zoom(max_zoom.extract()?);
717 }
718 match (d.get_item("c1")?, d.get_item("c2")?) {
719 (Some(c1), Some(c2)) => {
720 line_search = line_search.with_c1_c2(c1.extract()?, c2.extract()?);
721 }
722 (Some(c1), None) => {
723 line_search = line_search.with_c1(c1.extract()?);
724 }
725 (None, Some(c2)) => {
726 line_search = line_search.with_c2(c2.extract()?);
727 }
728 (None, None) => {}
729 }
730 Ok(StrongWolfeLineSearch::MoreThuente(line_search))
731 }
732 "hagerzhang" => {
733 let mut line_search = HagerZhangLineSearch::default();
734 if let Some(max_iterations) = d.get_item("max_iterations")? {
735 line_search =
736 line_search.with_max_iterations(max_iterations.extract()?);
737 }
738 if let Some(max_bisects) = d.get_item("max_bisects")? {
739 line_search = line_search.with_max_bisects(max_bisects.extract()?);
740 }
741 match (d.get_item("delta")?, d.get_item("sigma")?) {
742 (Some(delta), Some(sigma)) => {
743 line_search = line_search
744 .with_delta_sigma(delta.extract()?, sigma.extract()?);
745 }
746 (Some(delta), None) => {
747 line_search = line_search.with_delta(delta.extract()?);
748 }
749 (None, Some(sigma)) => {
750 line_search = line_search.with_sigma(sigma.extract()?);
751 }
752 (None, None) => {}
753 }
754 if let Some(epsilon) = d.get_item("epsilon")? {
755 line_search = line_search.with_epsilon(epsilon.extract()?);
756 }
757 if let Some(theta) = d.get_item("theta")? {
758 line_search = line_search.with_theta(theta.extract()?);
759 }
760 if let Some(gamma) = d.get_item("gamma")? {
761 line_search = line_search.with_gamma(gamma.extract()?);
762 }
763 Ok(StrongWolfeLineSearch::HagerZhang(line_search))
764 }
765 _ => Err(PyTypeError::new_err(format!(
766 "Invalid line search method: {}",
767 method
768 ))),
769 }
770 } else {
771 Err(PyTypeError::new_err("Line search method not specified"))
772 }
773 }
774 }
775 impl<P> FromPyArgs
776 for Callbacks<LBFGSB, P, GradientStatus, MaybeThreadPool, LadduError, LBFGSBConfig>
777 where
778 P: Gradient<MaybeThreadPool, LadduError>,
779 {
780 fn from_pyargs(_args: &(), d: &Bound<PyDict>) -> PyResult<Self> {
781 let mut callbacks = Callbacks::empty();
782 if let Some(eps_f) = d.get_item("eps_f")? {
783 if let Some(eps_abs) = eps_f.extract()? {
784 callbacks = callbacks.with_terminator(LBFGSBFTerminator { eps_abs });
785 } else {
786 callbacks = callbacks.with_terminator(LBFGSBFTerminator::default());
787 }
788 } else {
789 callbacks = callbacks.with_terminator(LBFGSBFTerminator::default());
790 }
791 if let Some(eps_g) = d.get_item("eps_g")? {
792 if let Some(eps_abs) = eps_g.extract()? {
793 callbacks = callbacks.with_terminator(LBFGSBGTerminator { eps_abs });
794 } else {
795 callbacks = callbacks.with_terminator(LBFGSBGTerminator::default());
796 }
797 } else {
798 callbacks = callbacks.with_terminator(LBFGSBGTerminator::default());
799 }
800 if let Some(eps_norm_g) = d.get_item("eps_norm_g")? {
801 if let Some(eps_abs) = eps_norm_g.extract()? {
802 callbacks = callbacks.with_terminator(LBFGSBInfNormGTerminator { eps_abs });
803 } else {
804 callbacks = callbacks.with_terminator(LBFGSBInfNormGTerminator::default());
805 }
806 } else {
807 callbacks = callbacks.with_terminator(LBFGSBInfNormGTerminator::default());
808 }
809 Ok(callbacks)
810 }
811 }
812 impl FromPyArgs<Vec<Float>> for AdamConfig {
813 fn from_pyargs(args: &Vec<Float>, d: &Bound<PyDict>) -> PyResult<Self> {
814 let mut config = AdamConfig::new(args);
815 if let Some(alpha) = d.get_item("alpha")? {
816 config = config.with_alpha(alpha.extract()?);
817 }
818 if let Some(beta_1) = d.get_item("beta_1")? {
819 config = config.with_beta_1(beta_1.extract()?);
820 }
821 if let Some(beta_2) = d.get_item("beta_2")? {
822 config = config.with_beta_2(beta_2.extract()?);
823 }
824 if let Some(epsilon) = d.get_item("epsilon")? {
825 config = config.with_epsilon(epsilon.extract()?);
826 }
827 Ok(config)
828 }
829 }
830 impl<P> FromPyArgs for Callbacks<Adam, P, GradientStatus, MaybeThreadPool, LadduError, AdamConfig>
831 where
832 P: Gradient<MaybeThreadPool, LadduError>,
833 {
834 fn from_pyargs(_args: &(), d: &Bound<PyDict>) -> PyResult<Self> {
835 let mut callbacks = Callbacks::empty();
836 let mut term = AdamEMATerminator::default();
837 if let Some(beta_c) = d.get_item("beta_c")? {
838 term.beta_c = beta_c.extract()?;
839 }
840 if let Some(eps_loss) = d.get_item("eps_loss")? {
841 term.eps_loss = eps_loss.extract()?;
842 }
843 if let Some(patience) = d.get_item("patience")? {
844 term.patience = patience.extract()?;
845 }
846 callbacks = callbacks.with_terminator(term);
847 Ok(callbacks)
848 }
849 }
850 impl FromPyArgs<Vec<Float>> for NelderMeadConfig {
851 fn from_pyargs(args: &Vec<Float>, d: &Bound<PyDict>) -> PyResult<Self> {
852 let construction_method = SimplexConstructionMethod::from_pyargs(args, d)?;
853 let mut config = NelderMeadConfig::new_with_method(construction_method);
854 if let Some(alpha) = d.get_item("alpha")? {
855 config = config.with_alpha(alpha.extract()?);
856 }
857 if let Some(beta) = d.get_item("beta")? {
858 config = config.with_beta(beta.extract()?);
859 }
860 if let Some(gamma) = d.get_item("gamma")? {
861 config = config.with_gamma(gamma.extract()?);
862 }
863 if let Some(delta) = d.get_item("delta")? {
864 config = config.with_delta(delta.extract()?);
865 }
866 if let Some(adaptive) = d.get_item("adaptive")? {
867 if adaptive.extract()? {
868 config = config.with_adaptive(args.len());
869 }
870 }
871 if let Some(expansion_method) = d.get_item("expansion_method")? {
872 match expansion_method
873 .extract::<String>()?
874 .to_lowercase()
875 .trim()
876 .replace("-", "")
877 .replace(" ", "")
878 .as_str()
879 {
880 "greedyminimization" => {
881 config = config
882 .with_expansion_method(SimplexExpansionMethod::GreedyMinimization);
883 Ok(())
884 }
885 "greedyexpansion" => {
886 config = config
887 .with_expansion_method(SimplexExpansionMethod::GreedyMinimization);
888 Ok(())
889 }
890 _ => Err(PyValueError::new_err(format!(
891 "Invalid expansion method: {}",
892 expansion_method
893 ))),
894 }?
895 }
896 Ok(config)
897 }
898 }
899 impl FromPyArgs<Vec<Float>> for SimplexConstructionMethod {
900 fn from_pyargs(args: &Vec<Float>, d: &Bound<PyDict>) -> PyResult<Self> {
901 if let Some(simplex_construction_method) = d.get_item("simplex_construction_method")? {
902 match simplex_construction_method
903 .extract::<String>()?
904 .to_lowercase()
905 .trim()
906 .replace("-", "")
907 .replace(" ", "")
908 .as_str()
909 {
910 "scaledorthogonal" => {
911 let orthogonal_multiplier = d
912 .get_item("orthogonal_multiplier")?
913 .map(|v| v.extract())
914 .transpose()?
915 .unwrap_or(1.05);
916 let orthogonal_zero_step = d
917 .get_item("orthogonal_zero_step")?
918 .map(|v| v.extract())
919 .transpose()?
920 .unwrap_or(0.00025);
921 return Ok(SimplexConstructionMethod::custom_scaled_orthogonal(
922 args,
923 orthogonal_multiplier,
924 orthogonal_zero_step,
925 ));
926 }
927 "orthogonal" => {
928 let simplex_size = d
929 .get_item("simplex_size")?
930 .map(|v| v.extract())
931 .transpose()?
932 .unwrap_or(1.0);
933 return Ok(SimplexConstructionMethod::custom_orthogonal(
934 args,
935 simplex_size,
936 ));
937 }
938 "custom" => {
939 if let Some(other_simplex_points) = d.get_item("simplex")? {
940 let mut simplex = Vec::with_capacity(args.len() + 1);
941 simplex[0] = DVector::from_vec(args.clone());
942 let others = other_simplex_points.extract::<Vec<Vec<Float>>>()?; if others.len() != args.len() {
944 return Err(PyValueError::new_err(format!(
945 "Expected {} additional simplex points, got {}.",
946 args.len(),
947 others.len()
948 )));
949 }
950 simplex.extend(others.iter().map(|x| DVector::from_vec(x.clone())));
951 return Ok(SimplexConstructionMethod::custom(simplex));
952 } else {
953 return Err(PyValueError::new_err("Simplex must be specified when using the 'custom' simplex_construction_method."));
954 }
955 }
956 _ => {
957 return Err(PyValueError::new_err(format!(
958 "Invalid simplex_construction_method: {}",
959 simplex_construction_method
960 )))
961 }
962 }
963 } else {
964 Ok(SimplexConstructionMethod::scaled_orthogonal(args))
965 }
966 }
967 }
968 impl<P> FromPyArgs
969 for Callbacks<
970 NelderMead,
971 P,
972 GradientFreeStatus,
973 MaybeThreadPool,
974 LadduError,
975 NelderMeadConfig,
976 >
977 where
978 P: CostFunction<MaybeThreadPool, LadduError>,
979 {
980 fn from_pyargs(_args: &(), d: &Bound<PyDict>) -> PyResult<Self> {
981 let mut callbacks = Callbacks::empty();
982 let eps_f = if let Some(eps_f) = d.get_item("eps_f")? {
983 eps_f.extract()?
984 } else {
985 Float::EPSILON.powf(0.25)
986 };
987 if let Some(f_term) = d.get_item("f_terminator")? {
988 match f_term
989 .extract::<String>()?
990 .to_lowercase()
991 .trim()
992 .replace("-", "")
993 .replace(" ", "")
994 .as_str()
995 {
996 "amoeba" => {
997 callbacks = callbacks
998 .with_terminator(NelderMeadFTerminator::Amoeba { eps_rel: eps_f });
999 }
1000 "absolute" => {
1001 callbacks = callbacks
1002 .with_terminator(NelderMeadFTerminator::Absolute { eps_abs: eps_f });
1003 }
1004 "stddev" => {
1005 callbacks = callbacks
1006 .with_terminator(NelderMeadFTerminator::StdDev { eps_abs: eps_f });
1007 }
1008 _ => Err(PyValueError::new_err(format!(
1009 "Invalid f_terminator: {}",
1010 f_term
1011 )))?,
1012 }
1013 } else {
1014 callbacks =
1015 callbacks.with_terminator(NelderMeadFTerminator::StdDev { eps_abs: eps_f });
1016 }
1017 let eps_x = if let Some(eps_x) = d.get_item("eps_x")? {
1018 eps_x.extract()?
1019 } else {
1020 Float::EPSILON.powf(0.25)
1021 };
1022 if let Some(x_term) = d.get_item("x_terminator")? {
1023 match x_term
1024 .extract::<String>()?
1025 .to_lowercase()
1026 .trim()
1027 .replace("-", "")
1028 .replace(" ", "")
1029 .as_str()
1030 {
1031 "diameter" => {
1032 callbacks = callbacks
1033 .with_terminator(NelderMeadXTerminator::Diameter { eps_abs: eps_x });
1034 }
1035 "higham" => {
1036 callbacks = callbacks
1037 .with_terminator(NelderMeadXTerminator::Higham { eps_rel: eps_x });
1038 }
1039 "rowan" => {
1040 callbacks = callbacks
1041 .with_terminator(NelderMeadXTerminator::Rowan { eps_rel: eps_x });
1042 }
1043 "singer" => {
1044 callbacks = callbacks
1045 .with_terminator(NelderMeadXTerminator::Singer { eps_rel: eps_x });
1046 }
1047 _ => Err(PyValueError::new_err(format!(
1048 "Invalid x_terminator: {}",
1049 x_term
1050 )))?,
1051 }
1052 } else {
1053 callbacks =
1054 callbacks.with_terminator(NelderMeadXTerminator::Singer { eps_rel: eps_x });
1055 }
1056 Ok(callbacks)
1057 }
1058 }
1059 impl FromPyArgs<Vec<Float>> for SwarmPositionInitializer {
1060 fn from_pyargs(args: &Vec<Float>, d: &Bound<PyDict>) -> PyResult<Self> {
1061 if let Some(swarm_position_initializer) = d.get_item("swarm_position_initializer")? {
1062 match swarm_position_initializer
1063 .extract::<String>()?
1064 .to_lowercase()
1065 .trim()
1066 .replace("-", "")
1067 .replace(" ", "")
1068 .as_str()
1069 {
1070 "randominlimits" => {
1071 if let (Some(swarm_position_bounds), Some(swarm_size)) = (
1072 d.get_item("swarm_position_bounds")?,
1073 d.get_item("swarm_size")?,
1074 ) {
1075 return Ok(SwarmPositionInitializer::RandomInLimits {
1076 bounds: swarm_position_bounds.extract()?,
1077 n_particles: swarm_size.extract()?,
1078 });
1079 } else {
1080 return Err(PyValueError::new_err("The swarm_position_bounds and swarm_size must be specified when using the 'randominlimits' swarm_position_initializer."));
1081 }
1082 }
1083 "latinhypercube" => {
1084 if let (Some(swarm_position_bounds), Some(swarm_size)) = (
1085 d.get_item("swarm_position_bounds")?,
1086 d.get_item("swarm_size")?,
1087 ) {
1088 return Ok(SwarmPositionInitializer::LatinHypercube {
1089 bounds: swarm_position_bounds.extract()?,
1090 n_particles: swarm_size.extract()?,
1091 });
1092 } else {
1093 return Err(PyValueError::new_err("The swarm_position_bounds and swarm_size must be specified when using the 'latinhypercube' swarm_position_initializer."));
1094 }
1095 }
1096 "custom" => {
1097 if let Some(swarm) = d.get_item("swarm")? {
1098 return Ok(SwarmPositionInitializer::Custom(
1099 swarm
1100 .extract::<Vec<Vec<Float>>>()?
1101 .iter()
1102 .chain(vec![args].into_iter())
1103 .map(|x| DVector::from_vec(x.clone()))
1104 .collect(), ));
1106 } else {
1107 return Err(PyValueError::new_err("The swarm must be specified when using the 'custom' swarm_position_initializer."));
1108 }
1109 }
1110 _ => {
1111 return Err(PyValueError::new_err(format!(
1112 "Invalid swarm_position_initializer: {}",
1113 swarm_position_initializer
1114 )));
1115 }
1116 }
1117 } else {
1118 return Err(PyValueError::new_err(
1119 "The swarm_position_initializer must be specified for the PSO algorithm.",
1120 ));
1121 }
1122 }
1123 }
1124 impl FromPyArgs<Vec<Float>> for Swarm {
1125 fn from_pyargs(args: &Vec<Float>, d: &Bound<PyDict>) -> PyResult<Self> {
1126 let swarm_position_initializer = SwarmPositionInitializer::from_pyargs(args, d)?;
1127 let mut swarm = Swarm::new(swarm_position_initializer);
1128 if let Some(swarm_topology_str) = d.get_item("swarm_topology")? {
1129 match swarm_topology_str
1130 .extract::<String>()?
1131 .to_lowercase()
1132 .trim()
1133 .replace("-", "")
1134 .replace(" ", "")
1135 .as_str()
1136 {
1137 "global" => {
1138 swarm = swarm.with_topology(SwarmTopology::Global);
1139 }
1140 "ring" => {
1141 swarm = swarm.with_topology(SwarmTopology::Ring);
1142 }
1143 _ => {
1144 return Err(PyValueError::new_err(format!(
1145 "Invalid swarm_topology: {}",
1146 swarm_topology_str
1147 )))
1148 }
1149 }
1150 }
1151 if let Some(swarm_update_method_str) = d.get_item("swarm_update_method")? {
1152 match swarm_update_method_str
1153 .extract::<String>()?
1154 .to_lowercase()
1155 .trim()
1156 .replace("-", "")
1157 .replace(" ", "")
1158 .as_str()
1159 {
1160 "sync" | "synchronous" => {
1161 swarm = swarm.with_update_method(SwarmUpdateMethod::Synchronous);
1162 }
1163 "async" | "asynchronous" => {
1164 swarm = swarm.with_update_method(SwarmUpdateMethod::Asynchronous);
1165 }
1166 _ => {
1167 return Err(PyValueError::new_err(format!(
1168 "Invalid swarm_update_method: {}",
1169 swarm_update_method_str
1170 )))
1171 }
1172 }
1173 }
1174 if let Some(swarm_boundary_method_str) = d.get_item("swarm_boundary_method")? {
1175 match swarm_boundary_method_str
1176 .extract::<String>()?
1177 .to_lowercase()
1178 .trim()
1179 .replace("-", "")
1180 .replace(" ", "")
1181 .as_str()
1182 {
1183 "inf" => {
1184 swarm = swarm.with_boundary_method(SwarmBoundaryMethod::Inf);
1185 }
1186 "shr" => {
1187 swarm = swarm.with_boundary_method(SwarmBoundaryMethod::Shr);
1188 }
1189 _ => {
1190 return Err(PyValueError::new_err(format!(
1191 "Invalid swarm_boundary_method: {}",
1192 swarm_boundary_method_str
1193 )))
1194 }
1195 }
1196 }
1197 if let Some(swarm_velocity_bounds) = d.get_item("swarm_velocity_bounds")? {
1198 swarm = swarm.with_velocity_initializer(SwarmVelocityInitializer::RandomInLimits(
1199 swarm_velocity_bounds.extract()?,
1200 ));
1201 }
1202 Ok(swarm)
1203 }
1204 }
1205 impl FromPyArgs<Vec<Float>> for PSOConfig {
1206 fn from_pyargs(args: &Vec<Float>, d: &Bound<PyDict>) -> PyResult<Self> {
1207 let swarm = Swarm::from_pyargs(args, d)?;
1208 let mut config = PSOConfig::new(swarm);
1209 if let Some(omega) = d.get_item("omega")? {
1210 config = config.with_omega(omega.extract()?);
1211 }
1212 if let Some(c1) = d.get_item("c1")? {
1213 config = config.with_c1(c1.extract()?);
1214 }
1215 if let Some(c2) = d.get_item("c2")? {
1216 config = config.with_c2(c2.extract()?);
1217 }
1218 Ok(config)
1219 }
1220 }
impl<P> FromPyArgs<Vec<Float>> for MinimizationSettings<P>
where
    P: Gradient<MaybeThreadPool, LadduError>,
{
    /// Build [`MinimizationSettings`] from the keyword arguments of a Python minimization call.
    ///
    /// `args` is the starting point and is forwarded unchanged to the selected algorithm's config
    /// parser. Recognized keys in `d`:
    /// - `"method"` (required): `lbfgsb`, `adam`, `neldermead`, or `pso`. It is matched after
    ///   lowercasing, trimming, and removing `-` and spaces.
    /// - `"bounds"`: a list of `(lower, upper)` pairs, either of which may be `None`.
    /// - `"threads"`: thread count stored in the settings (defaults to `0`).
    /// - `"debug"`: if true, attach a `DebugObserver` (rejected for PSO).
    /// - `"observers"` / `"terminators"`: a single Python object or a list of them.
    /// - `"max_steps"`: if present, adds a `MaxSteps` terminator.
    /// - `"settings"`: a dict of algorithm-specific options (defaults to an empty dict).
    ///
    /// Every successfully built callback set ends with a `CtrlCAbortSignal` terminator.
    ///
    /// # Errors
    /// Returns a `PyValueError` if no method is given, the method is unknown, or `debug` is
    /// requested for PSO. Extraction errors for any key are propagated.
    fn from_pyargs(args: &Vec<Float>, d: &Bound<PyDict>) -> PyResult<Self> {
        // Each `(Option<Float>, Option<Float>)` pair is converted with ganesh's `Bound::from`;
        // presumably `None` means "unbounded on that side" (TODO confirm in ganesh).
        let bounds: Option<Vec<ganesh::traits::boundlike::Bound>> = d
            .get_item("bounds")?
            .map(|bounds| bounds.extract::<Vec<(Option<Float>, Option<Float>)>>())
            .transpose()?
            .map(|bounds| {
                bounds
                    .into_iter()
                    .map(ganesh::traits::boundlike::Bound::from)
                    .collect()
            });
        let num_threads = d
            .get_item("threads")?
            .map(|t| t.extract())
            .transpose()?
            .unwrap_or(0);
        let add_debug = d
            .get_item("debug")?
            .map(|d| d.extract())
            .transpose()?
            .unwrap_or(false);
        // Accept either a list of observers or a single observer object. Note that extracting a
        // `MinimizationObserver` wraps any Python object, so the error branches here are only
        // reached if that extraction ever becomes stricter.
        let observers = if let Some(observers) = d.get_item("observers")? {
            if let Ok(observers) = observers.downcast::<PyList>() {
                observers.into_iter().map(|observer| {
                    if let Ok(observer) = observer.extract::<MinimizationObserver>() {
                        Ok(observer)
                    } else {
                        Err(PyValueError::new_err("The observers must be either a single MinimizationObserver or a list of MinimizationObservers."))
                    }
                }).collect::<PyResult<Vec<MinimizationObserver>>>()?
            } else if let Ok(observer) = observers.extract::<MinimizationObserver>() {
                vec![observer]
            } else {
                return Err(PyValueError::new_err("The observers must be either a single MinimizationObserver or a list of MinimizationObservers."));
            }
        } else {
            vec![]
        };
        // Same list-or-single handling for user-defined terminators.
        let terminators = if let Some(terminators) = d.get_item("terminators")? {
            if let Ok(terminators) = terminators.downcast::<PyList>() {
                terminators.into_iter().map(|terminator| {
                    if let Ok(terminator) = terminator.extract::<MinimizationTerminator>() {
                        Ok(terminator)
                    } else {
                        Err(PyValueError::new_err("The terminators must be either a single MinimizationTerminator or a list of MinimizationTerminators."))
                    }
                }).collect::<PyResult<Vec<MinimizationTerminator>>>()?
            } else if let Ok(terminator) = terminators.extract::<MinimizationTerminator>() {
                vec![terminator]
            } else {
                return Err(PyValueError::new_err("The terminators must be either a single MinimizationTerminator or a list of MinimizationTerminators."));
            }
        } else {
            vec![]
        };
        let max_steps: Option<usize> = d
            .get_item("max_steps")?
            .map(|ms| ms.extract())
            .transpose()?;
        // Algorithm-specific options; an empty dict lets the config parsers fall back to their
        // own defaults.
        let settings: Bound<PyDict> = d
            .get_item("settings")?
            .map(|settings| settings.extract())
            .transpose()?
            .unwrap_or_else(|| PyDict::new(d.py()));
        if let Some(method) = d.get_item("method")? {
            // Normalize the method name so e.g. "L-BFGS-B" and "Nelder Mead" are accepted.
            match method
                .extract::<String>()?
                .to_lowercase()
                .trim()
                .replace("-", "")
                .replace(" ", "")
                .as_str()
            {
                "lbfgsb" => {
                    let mut config = LBFGSBConfig::from_pyargs(args, &settings)?;
                    // L-BFGS-B takes the bounds directly rather than through a transform.
                    if let Some(bounds) = bounds {
                        config = config.with_bounds(bounds);
                    }
                    let mut callbacks = Callbacks::from_pyargs(&(), &settings)?;
                    if add_debug {
                        callbacks = callbacks.with_observer(DebugObserver);
                    }
                    if let Some(max_steps) = max_steps {
                        callbacks = callbacks.with_terminator(MaxSteps(max_steps));
                    }
                    for observer in observers {
                        callbacks = callbacks.with_observer(observer);
                    }
                    for terminator in terminators {
                        callbacks = callbacks.with_terminator(terminator);
                    }
                    // Always allow the run to be interrupted with Ctrl-C.
                    callbacks = callbacks.with_terminator(CtrlCAbortSignal::new());
                    Ok(MinimizationSettings::LBFGSB {
                        config,
                        callbacks,
                        num_threads,
                    })
                }
                "adam" => {
                    let mut config = AdamConfig::from_pyargs(args, &settings)?;
                    // Adam applies bounds through a parameter transform.
                    if let Some(bounds) = bounds {
                        config = config.with_transform(&Bounds::from(bounds))
                    }
                    let mut callbacks = Callbacks::from_pyargs(&(), &settings)?;
                    if add_debug {
                        callbacks = callbacks.with_observer(DebugObserver);
                    }
                    if let Some(max_steps) = max_steps {
                        callbacks = callbacks.with_terminator(MaxSteps(max_steps));
                    }
                    for observer in observers {
                        callbacks = callbacks.with_observer(observer);
                    }
                    for terminator in terminators {
                        callbacks = callbacks.with_terminator(terminator);
                    }
                    callbacks = callbacks.with_terminator(CtrlCAbortSignal::new());
                    Ok(MinimizationSettings::Adam {
                        config,
                        callbacks,
                        num_threads,
                    })
                }
                "neldermead" => {
                    let mut config = NelderMeadConfig::from_pyargs(args, &settings)?;
                    if let Some(bounds) = bounds {
                        config = config.with_bounds(bounds);
                    }
                    let mut callbacks = Callbacks::from_pyargs(&(), &settings)?;
                    if add_debug {
                        callbacks = callbacks.with_observer(DebugObserver);
                    }
                    if let Some(max_steps) = max_steps {
                        callbacks = callbacks.with_terminator(MaxSteps(max_steps));
                    }
                    for observer in observers {
                        callbacks = callbacks.with_observer(observer);
                    }
                    for terminator in terminators {
                        callbacks = callbacks.with_terminator(terminator);
                    }
                    callbacks = callbacks.with_terminator(CtrlCAbortSignal::new());
                    Ok(MinimizationSettings::NelderMead {
                        config,
                        callbacks,
                        num_threads,
                    })
                }
                "pso" => {
                    let mut config = PSOConfig::from_pyargs(args, &settings)?;
                    // PSO supports both modes: `settings["use_transform"]` selects a transform,
                    // otherwise the bounds are applied directly.
                    if let Some(bounds) = bounds {
                        if let Some(use_transform) = settings.get_item("use_transform")? {
                            if use_transform.extract()? {
                                config = config.with_transform(&Bounds::from(bounds))
                            } else {
                                config = config.with_bounds(bounds)
                            }
                        } else {
                            config = config.with_bounds(bounds)
                        }
                    }
                    // NOTE(review): unlike the other minimizers, PSO starts from empty callbacks
                    // rather than `Callbacks::from_pyargs(&(), &settings)`; confirm this is
                    // intentional.
                    let mut callbacks = Callbacks::empty();
                    if add_debug {
                        return Err(PyValueError::new_err(
                            "The debug setting is not yet supported for PSO",
                        ));
                    }
                    if let Some(max_steps) = max_steps {
                        callbacks = callbacks.with_terminator(MaxSteps(max_steps));
                    }
                    for observer in observers {
                        callbacks = callbacks.with_observer(observer);
                    }
                    for terminator in terminators {
                        callbacks = callbacks.with_terminator(terminator);
                    }
                    callbacks = callbacks.with_terminator(CtrlCAbortSignal::new());
                    Ok(MinimizationSettings::PSO {
                        config,
                        callbacks,
                        num_threads,
                    })
                }
                _ => Err(PyValueError::new_err(format!(
                    "Invalid minimizer: {}",
                    method
                ))),
            }
        } else {
            Err(PyValueError::new_err("No method specified"))
        }
    }
}
1420
1421 impl FromPyArgs<Vec<DVector<Float>>> for AIESConfig {
1422 fn from_pyargs(args: &Vec<DVector<Float>>, d: &Bound<PyDict>) -> PyResult<Self> {
1423 let mut config = AIESConfig::new(args.to_vec());
1424 if let Some(moves) = d.get_item("moves")? {
1425 let moves_list = moves.downcast::<PyList>()?;
1426 let mut aies_moves = vec![];
1427 for mcmc_move in moves_list {
1428 if let Ok(default_move) = mcmc_move.extract::<(String, Float)>() {
1429 match default_move
1430 .0
1431 .to_lowercase()
1432 .trim()
1433 .replace("-", "")
1434 .replace(" ", "")
1435 .as_str()
1436 {
1437 "stretch" => aies_moves.push(AIESMove::stretch(default_move.1)),
1438 "walk" => aies_moves.push(AIESMove::walk(default_move.1)),
1439 _ => {
1440 return Err(PyValueError::new_err(format!(
1441 "Invalid AIES move: {}",
1442 default_move.0
1443 )))
1444 }
1445 }
1446 } else if let Ok(custom_move) =
1447 mcmc_move.extract::<(String, Bound<PyDict>, Float)>()
1448 {
1449 match custom_move
1450 .0
1451 .to_lowercase()
1452 .trim()
1453 .replace("-", "")
1454 .replace(" ", "")
1455 .as_str()
1456 {
1457 "stretch" => aies_moves.push((
1458 AIESMove::Stretch {
1459 a: custom_move
1460 .1
1461 .get_item("a")?
1462 .map(|val| val.extract())
1463 .transpose()?
1464 .unwrap_or(2.0),
1465 },
1466 custom_move.2,
1467 )),
1468 "walk" => aies_moves.push(AIESMove::walk(custom_move.2)),
1469 _ => {
1470 return Err(PyValueError::new_err(format!(
1471 "Invalid AIES move: {}",
1472 custom_move.0
1473 )))
1474 }
1475 }
1476 } else {
1477 return Err(PyValueError::new_err("The 'moves' argument must be a list of (str, float) or (str, dict, float) tuples!"));
1478 }
1479 }
1480 config = config.with_moves(aies_moves);
1481 }
1482 Ok(config)
1483 }
1484 }
1485
1486 impl FromPyArgs<Vec<DVector<Float>>> for ESSConfig {
1487 fn from_pyargs(args: &Vec<DVector<Float>>, d: &Bound<PyDict>) -> PyResult<Self> {
1488 let mut config = ESSConfig::new(args.to_vec());
1489 if let Some(moves) = d.get_item("moves")? {
1490 let moves_list = moves.downcast::<PyList>()?;
1491 let mut ess_moves = vec![];
1492 for mcmc_move in moves_list {
1493 if let Ok(default_move) = mcmc_move.extract::<(String, Float)>() {
1494 match default_move
1495 .0
1496 .to_lowercase()
1497 .trim()
1498 .replace("-", "")
1499 .replace(" ", "")
1500 .as_str()
1501 {
1502 "differential" => ess_moves.push(ESSMove::differential(default_move.1)),
1503 "gaussian" => ess_moves.push(ESSMove::gaussian(default_move.1)),
1504 "global" => {
1505 ess_moves.push(ESSMove::global(default_move.1, None, None, None))
1506 }
1507 _ => {
1508 return Err(PyValueError::new_err(format!(
1509 "Invalid ESS move: {}",
1510 default_move.0
1511 )))
1512 }
1513 }
1514 } else if let Ok(custom_move) =
1515 mcmc_move.extract::<(String, Bound<PyDict>, Float)>()
1516 {
1517 match custom_move
1518 .0
1519 .to_lowercase()
1520 .trim()
1521 .replace("-", "")
1522 .replace(" ", "")
1523 .as_str()
1524 {
1525 "differential" => ess_moves.push(ESSMove::differential(custom_move.2)),
1526 "gaussian" => ess_moves.push(ESSMove::gaussian(custom_move.2)),
1527 "global" => ess_moves.push(ESSMove::global(
1528 custom_move.2,
1529 custom_move
1530 .1
1531 .get_item("scale")?
1532 .map(|value| value.extract())
1533 .transpose()?,
1534 custom_move
1535 .1
1536 .get_item("rescale_cov")?
1537 .map(|value| value.extract())
1538 .transpose()?,
1539 custom_move
1540 .1
1541 .get_item("n_components")?
1542 .map(|value| value.extract())
1543 .transpose()?,
1544 )),
1545 _ => {
1546 return Err(PyValueError::new_err(format!(
1547 "Invalid ESS move: {}",
1548 custom_move.0
1549 )))
1550 }
1551 }
1552 } else {
1553 return Err(PyValueError::new_err("The 'moves' argument must be a list of (str, float) or (str, dict, float) tuples!"));
1554 }
1555 }
1556 config = config.with_moves(ess_moves)
1557 }
1558 if let Some(n_adaptive) = d.get_item("n_adaptive")? {
1559 config = config.with_n_adaptive(n_adaptive.extract()?);
1560 }
1561 if let Some(mu) = d.get_item("mu")? {
1562 config = config.with_mu(mu.extract()?);
1563 }
1564 if let Some(max_steps) = d.get_item("max_steps")? {
1565 config = config.with_max_steps(max_steps.extract()?);
1566 }
1567 Ok(config)
1568 }
1569 }
1570
impl<P> FromPyArgs<Vec<DVector<Float>>> for MCMCSettings<P>
where
    P: LogDensity<MaybeThreadPool, LadduError>,
{
    /// Build [`MCMCSettings`] from the keyword arguments of a Python MCMC call.
    ///
    /// `args` holds the initial walker positions and is forwarded to the selected config parser.
    /// Recognized keys in `d`:
    /// - `"method"` (required): `aies` or `ess`. It is matched after lowercasing, trimming, and
    ///   removing `-` and spaces.
    /// - `"bounds"`: a list of `(lower, upper)` pairs (either may be `None`), always applied as
    ///   a parameter transform.
    /// - `"threads"` (default `0`), `"debug"` (default `false`), and `"max_steps"`.
    /// - `"observers"`: a single Python object or a list of them.
    /// - `"terminators"`: a single object or a list. Each entry is first tried as a
    ///   `PyAutocorrelationTerminator`, then as a user-defined `MCMCTerminator`.
    /// - `"settings"`: a dict of algorithm-specific options (defaults to an empty dict).
    ///
    /// Every successfully built callback set ends with a `CtrlCAbortSignal` terminator.
    ///
    /// # Errors
    /// Returns a `PyValueError` if no method is given or the method is unknown. Extraction errors
    /// for any key are propagated.
    fn from_pyargs(args: &Vec<DVector<Float>>, d: &Bound<PyDict>) -> PyResult<Self> {
        // Each `(Option<Float>, Option<Float>)` pair is converted with ganesh's `Bound::from`;
        // presumably `None` means "unbounded on that side" (TODO confirm in ganesh).
        let bounds: Option<Vec<ganesh::traits::boundlike::Bound>> = d
            .get_item("bounds")?
            .map(|bounds| bounds.extract::<Vec<(Option<Float>, Option<Float>)>>())
            .transpose()?
            .map(|bounds| {
                bounds
                    .into_iter()
                    .map(ganesh::traits::boundlike::Bound::from)
                    .collect()
            });
        let num_threads = d
            .get_item("threads")?
            .map(|t| t.extract())
            .transpose()?
            .unwrap_or(0);
        let add_debug = d
            .get_item("debug")?
            .map(|d| d.extract())
            .transpose()?
            .unwrap_or(false);
        // Accept either a list of observers or a single observer object.
        let observers = if let Some(observers) = d.get_item("observers")? {
            if let Ok(observers) = observers.downcast::<PyList>() {
                observers.into_iter().map(|observer| {
                    if let Ok(observer) = observer.extract::<MCMCObserver>() {
                        Ok(observer)
                    } else {
                        Err(PyValueError::new_err("The observers must be either a single MCMCObserver or a list of MCMCObservers."))
                    }
                }).collect::<PyResult<Vec<MCMCObserver>>>()?
            } else if let Ok(observer) = observers.extract::<MCMCObserver>() {
                vec![observer]
            } else {
                return Err(PyValueError::new_err("The observers must be either a single MCMCObserver or a list of MCMCObservers."));
            }
        } else {
            vec![]
        };
        // Built-in autocorrelation terminators are recognized first, so they are not wrapped
        // as generic user-defined Python terminators.
        let terminators = if let Some(terminators) = d.get_item("terminators")? {
            if let Ok(terminators) = terminators.downcast::<PyList>() {
                terminators
                    .into_iter()
                    .map(|terminator| {
                        if let Ok(terminator) =
                            terminator.extract::<PyAutocorrelationTerminator>()
                        {
                            Ok(PythonMCMCTerminator::Autocorrelation(terminator))
                        }
                        else if let Ok(terminator) = terminator.extract::<MCMCTerminator>() {
                            Ok(PythonMCMCTerminator::UserDefined(terminator))
                        } else {
                            Err(PyValueError::new_err("The terminators must be either a single MCMCTerminator or a list of MCMCTerminators."))
                        }
                    })
                    .collect::<PyResult<Vec<PythonMCMCTerminator>>>()?
            } else if let Ok(terminator) = terminators.extract::<PyAutocorrelationTerminator>()
            {
                vec![PythonMCMCTerminator::Autocorrelation(terminator)]
            } else if let Ok(terminator) = terminators.extract::<MCMCTerminator>() {
                vec![PythonMCMCTerminator::UserDefined(terminator)]
            } else {
                return Err(PyValueError::new_err("The terminators must be either a single MCMCTerminator or a list of MCMCTerminators."));
            }
        } else {
            vec![]
        };
        let max_steps: Option<usize> = d
            .get_item("max_steps")?
            .map(|ms| ms.extract())
            .transpose()?;
        // Algorithm-specific options; an empty dict lets the config parsers use their defaults.
        let settings: Bound<PyDict> = d
            .get_item("settings")?
            .map(|settings| settings.extract())
            .transpose()?
            .unwrap_or_else(|| PyDict::new(d.py()));
        if let Some(method) = d.get_item("method")? {
            match method
                .extract::<String>()?
                .to_lowercase()
                .trim()
                .replace("-", "")
                .replace(" ", "")
                .as_str()
            {
                "aies" => {
                    let mut config = AIESConfig::from_pyargs(args, &settings)?;
                    if let Some(bounds) = bounds {
                        config = config.with_transform(&Bounds::from(bounds))
                    }
                    let mut callbacks = Callbacks::empty();
                    if add_debug {
                        callbacks = callbacks.with_observer(DebugObserver);
                    }
                    if let Some(max_steps) = max_steps {
                        callbacks = callbacks.with_terminator(MaxSteps(max_steps));
                    }
                    for observer in observers {
                        callbacks = callbacks.with_observer(observer);
                    }
                    for terminator in terminators {
                        callbacks = callbacks.with_terminator(terminator);
                    }
                    // Always allow the run to be interrupted with Ctrl-C.
                    callbacks = callbacks.with_terminator(CtrlCAbortSignal::new());
                    Ok(MCMCSettings::AIES {
                        config,
                        callbacks,
                        num_threads,
                    })
                }
                "ess" => {
                    let mut config = ESSConfig::from_pyargs(args, &settings)?;
                    if let Some(bounds) = bounds {
                        config = config.with_transform(&Bounds::from(bounds))
                    }
                    let mut callbacks = Callbacks::empty();
                    if add_debug {
                        callbacks = callbacks.with_observer(DebugObserver);
                    }
                    if let Some(max_steps) = max_steps {
                        callbacks = callbacks.with_terminator(MaxSteps(max_steps));
                    }
                    for observer in observers {
                        callbacks = callbacks.with_observer(observer);
                    }
                    for terminator in terminators {
                        callbacks = callbacks.with_terminator(terminator);
                    }
                    callbacks = callbacks.with_terminator(CtrlCAbortSignal::new());
                    Ok(MCMCSettings::ESS {
                        config,
                        callbacks,
                        num_threads,
                    })
                }
                _ => Err(PyValueError::new_err(format!(
                    "Invalid MCMC algorithm: {}",
                    method
                ))),
            }
        } else {
            Err(PyValueError::new_err("No method specified"))
        }
    }
}
1719
/// Shared handle to a minimizer's status, tagged by the kind of algorithm that produced it.
///
/// Each variant wraps the status in `Arc<Mutex<_>>`, so a Python-facing `PyMinimizationStatus`
/// can read it while Rust code keeps another handle.
enum MinimizationStatus {
    /// Status of a gradient-based minimizer (L-BFGS-B or Adam in this module).
    GradientStatus(Arc<Mutex<GradientStatus>>),
    /// Status of a gradient-free minimizer (Nelder-Mead in this module).
    GradientFreeStatus(Arc<Mutex<GradientFreeStatus>>),
    /// Status of a particle swarm optimizer.
    SwarmStatus(Arc<Mutex<SwarmStatus>>),
}
impl MinimizationStatus {
    /// Current best parameter vector as a NumPy array. For a swarm this is the global best
    /// position.
    fn x<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<Float>> {
        match self {
            Self::GradientStatus(gradient_status) => {
                gradient_status.lock().x.as_slice().to_pyarray(py)
            }
            Self::GradientFreeStatus(gradient_free_status) => {
                gradient_free_status.lock().x.as_slice().to_pyarray(py)
            }
            Self::SwarmStatus(swarm_status) => {
                swarm_status.lock().gbest.x.as_slice().to_pyarray(py)
            }
        }
    }
    /// Objective value at the current best point.
    ///
    /// NOTE(review): for a swarm this unwraps `gbest.fx` and panics if it is `None`; presumably
    /// it is always set once the swarm has been evaluated — confirm.
    fn fx(&self) -> Float {
        match self {
            Self::GradientStatus(gradient_status) => gradient_status.lock().fx,
            Self::GradientFreeStatus(gradient_free_status) => gradient_free_status.lock().fx,
            Self::SwarmStatus(swarm_status) => swarm_status.lock().gbest.fx.unwrap(),
        }
    }
    /// The status message reported by the algorithm.
    fn message(&self) -> String {
        match self {
            Self::GradientStatus(gradient_status) => {
                gradient_status.lock().message().to_string()
            }
            Self::GradientFreeStatus(gradient_free_status) => {
                gradient_free_status.lock().message().to_string()
            }
            Self::SwarmStatus(swarm_status) => swarm_status.lock().message().to_string(),
        }
    }
    /// Parameter uncertainties, if the algorithm stored them. Always `None` for a swarm.
    fn err<'py>(&self, py: Python<'py>) -> Option<Bound<'py, PyArray1<Float>>> {
        match self {
            Self::GradientStatus(gradient_status) => gradient_status.lock().err.clone(),
            Self::GradientFreeStatus(gradient_free_status) => {
                gradient_free_status.lock().err.clone()
            }
            Self::SwarmStatus(_) => None,
        }
        .map(|e| e.as_slice().to_pyarray(py))
    }
    /// Number of objective-function evaluations so far.
    fn n_f_evals(&self) -> usize {
        match self {
            Self::GradientStatus(gradient_status) => gradient_status.lock().n_f_evals,
            Self::GradientFreeStatus(gradient_free_status) => {
                gradient_free_status.lock().n_f_evals
            }
            Self::SwarmStatus(swarm_status) => swarm_status.lock().n_f_evals,
        }
    }
    /// Number of gradient evaluations so far. Reported as `0` for gradient-free and swarm
    /// algorithms.
    fn n_g_evals(&self) -> usize {
        match self {
            Self::GradientStatus(gradient_status) => gradient_status.lock().n_g_evals,
            Self::GradientFreeStatus(_) => 0,
            Self::SwarmStatus(_) => 0,
        }
    }
    /// Covariance matrix, if the algorithm stored one. Always `None` for a swarm.
    fn cov<'py>(&self, py: Python<'py>) -> Option<Bound<'py, PyArray2<Float>>> {
        match self {
            Self::GradientStatus(gradient_status) => gradient_status.lock().cov.clone(),
            Self::GradientFreeStatus(gradient_free_status) => {
                gradient_free_status.lock().cov.clone()
            }
            Self::SwarmStatus(_) => None,
        }
        .map(|cov| cov.to_pyarray(py))
    }
    /// Hessian matrix, if the algorithm stored one. Always `None` for a swarm.
    fn hess<'py>(&self, py: Python<'py>) -> Option<Bound<'py, PyArray2<Float>>> {
        match self {
            Self::GradientStatus(gradient_status) => gradient_status.lock().hess.clone(),
            Self::GradientFreeStatus(gradient_free_status) => {
                gradient_free_status.lock().hess.clone()
            }
            Self::SwarmStatus(_) => None,
        }
        .map(|hess| hess.to_pyarray(py))
    }
    /// Whether the algorithm reports convergence.
    fn converged(&self) -> bool {
        match self {
            Self::GradientStatus(gradient_status) => gradient_status.lock().converged(),
            Self::GradientFreeStatus(gradient_free_status) => {
                gradient_free_status.lock().converged()
            }
            Self::SwarmStatus(swarm_status) => swarm_status.lock().converged(),
        }
    }
    /// A copy of the particle swarm, or `None` for non-swarm algorithms.
    fn swarm(&self) -> Option<PySwarm> {
        match self {
            Self::GradientStatus(_) | Self::GradientFreeStatus(_) => None,
            Self::SwarmStatus(swarm_status) => Some(PySwarm(swarm_status.lock().swarm.clone())),
        }
    }
}
1819
1820 #[pyclass(name = "Swarm", module = "laddu")]
1823 pub struct PySwarm(Swarm);
1824
1825 #[pymethods]
1826 impl PySwarm {
1827 #[getter]
1834 fn particles(&self) -> Vec<PySwarmParticle> {
1835 self.0
1836 .get_particles()
1837 .into_iter()
1838 .map(PySwarmParticle)
1839 .collect()
1840 }
1841 }
1842
#[pyclass(name = "SwarmParticle", module = "laddu")]
/// A single particle of a particle swarm.
pub struct PySwarmParticle(SwarmParticle);

#[pymethods]
impl PySwarmParticle {
    #[getter]
    /// The particle's current position.
    ///
    /// Returns
    /// -------
    /// numpy.ndarray
    fn x<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<Float>> {
        self.0.position.x.as_slice().to_pyarray(py)
    }
    #[getter]
    /// The objective value at the particle's current position.
    ///
    /// NOTE(review): unwraps `position.fx` and panics if it is `None`; presumably it is set
    /// once the particle has been evaluated — confirm.
    ///
    /// Returns
    /// -------
    /// float
    fn fx(&self) -> Float {
        self.0.position.fx.unwrap()
    }
    #[getter]
    /// The best position this particle has visited.
    ///
    /// Returns
    /// -------
    /// numpy.ndarray
    fn x_best<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<Float>> {
        self.0.best.x.as_slice().to_pyarray(py)
    }
    #[getter]
    /// The objective value at the particle's best position.
    ///
    /// NOTE(review): unwraps `best.fx` and panics if it is `None` — confirm it is always set.
    ///
    /// Returns
    /// -------
    /// float
    fn fx_best(&self) -> Float {
        self.0.best.fx.unwrap()
    }
    #[getter]
    /// The particle's current velocity.
    ///
    /// Returns
    /// -------
    /// numpy.ndarray
    fn velocity<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<Float>> {
        self.0.velocity.as_slice().to_pyarray(py)
    }
}
1901
#[pyclass(name = "MinimizationStatus", module = "laddu")]
/// The in-progress status of a minimizer, as passed to Python observers and terminators.
pub struct PyMinimizationStatus(MinimizationStatus);

#[pymethods]
impl PyMinimizationStatus {
    #[getter]
    /// The current best parameter values.
    ///
    /// Returns
    /// -------
    /// numpy.ndarray
    fn x<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<Float>> {
        self.0.x(py)
    }
    #[getter]
    /// The objective value at the current best point.
    ///
    /// Returns
    /// -------
    /// float
    fn fx(&self) -> Float {
        self.0.fx()
    }
    #[getter]
    /// The algorithm's status message.
    ///
    /// Returns
    /// -------
    /// str
    fn message(&self) -> String {
        self.0.message()
    }
    #[getter]
    /// Parameter uncertainties, if available (always ``None`` for PSO).
    ///
    /// Returns
    /// -------
    /// numpy.ndarray or None
    fn err<'py>(&self, py: Python<'py>) -> Option<Bound<'py, PyArray1<Float>>> {
        self.0.err(py)
    }
    #[getter]
    /// Number of objective-function evaluations so far.
    ///
    /// Returns
    /// -------
    /// int
    fn n_f_evals(&self) -> usize {
        self.0.n_f_evals()
    }
    #[getter]
    /// Number of gradient evaluations so far (``0`` for gradient-free algorithms).
    ///
    /// Returns
    /// -------
    /// int
    fn n_g_evals(&self) -> usize {
        self.0.n_g_evals()
    }
    #[getter]
    /// Covariance matrix, if available (always ``None`` for PSO).
    ///
    /// Returns
    /// -------
    /// numpy.ndarray or None
    fn cov<'py>(&self, py: Python<'py>) -> Option<Bound<'py, PyArray2<Float>>> {
        self.0.cov(py)
    }
    #[getter]
    /// Hessian matrix, if available (always ``None`` for PSO).
    ///
    /// Returns
    /// -------
    /// numpy.ndarray or None
    fn hess<'py>(&self, py: Python<'py>) -> Option<Bound<'py, PyArray2<Float>>> {
        self.0.hess(py)
    }
    #[getter]
    /// Whether the algorithm reports convergence.
    fn converged(&self) -> bool {
        self.0.converged()
    }
    #[getter]
    /// A copy of the particle swarm, or ``None`` for non-swarm algorithms.
    ///
    /// Returns
    /// -------
    /// Swarm or None
    fn swarm(&self) -> Option<PySwarm> {
        self.0.swarm()
    }
}
2004
2005 #[pyclass(name = "MinimizationSummary", module = "laddu")]
2008 #[derive(Clone)]
2009 pub struct PyMinimizationSummary(pub MinimizationSummary);
2010
2011 #[pymethods]
2012 impl PyMinimizationSummary {
2013 #[getter]
2020 fn bounds(&self) -> Option<Vec<(Float, Float)>> {
2021 self.0
2022 .clone()
2023 .bounds
2024 .map(|bs| bs.iter().map(|b| b.0.as_floats()).collect())
2025 }
2026 #[getter]
2033 fn parameter_names(&self) -> Vec<String> {
2034 self.0.parameter_names.clone().unwrap_or_default()
2035 }
2036 #[getter]
2043 fn message(&self) -> String {
2044 self.0.message.clone()
2045 }
2046 #[getter]
2053 fn x0<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<Float>> {
2054 self.0.x0.as_slice().to_pyarray(py)
2055 }
2056 #[getter]
2063 fn x<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<Float>> {
2064 self.0.x.as_slice().to_pyarray(py)
2065 }
2066 #[getter]
2073 fn std<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<Float>> {
2074 self.0.std.as_slice().to_pyarray(py)
2075 }
2076 #[getter]
2083 fn fx(&self) -> Float {
2084 self.0.fx
2085 }
2086 #[getter]
2093 fn cost_evals(&self) -> usize {
2094 self.0.cost_evals
2095 }
2096 #[getter]
2103 fn gradient_evals(&self) -> usize {
2104 self.0.gradient_evals
2105 }
2106 #[getter]
2113 fn converged(&self) -> bool {
2114 self.0.converged
2115 }
2116 #[getter]
2123 fn covariance<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray2<Float>> {
2124 self.0.covariance.to_pyarray(py)
2125 }
2126 fn __str__(&self) -> String {
2127 self.0.to_string()
2128 }
2129 #[new]
2130 fn new() -> Self {
2131 Self(MinimizationSummary::create_null())
2132 }
2133 fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
2134 Ok(PyBytes::new(
2135 py,
2136 bincode::serde::encode_to_vec(&self.0, bincode::config::standard())
2137 .map_err(LadduError::EncodeError)?
2138 .as_slice(),
2139 ))
2140 }
2141 fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
2142 *self = Self(
2143 bincode::serde::decode_from_slice(state.as_bytes(), bincode::config::standard())
2144 .map_err(LadduError::DecodeError)?
2145 .0,
2146 );
2147 Ok(())
2148 }
2149 }
2150
#[pyclass(eq, eq_int, name = "ControlFlow", module = "laddu")]
/// Value a Python terminator's ``check_for_termination`` method must return.
#[derive(PartialEq, Clone)]
pub enum PyControlFlow {
    /// Keep running the algorithm.
    Continue = 0,
    /// Stop the algorithm.
    Break = 1,
}

impl From<PyControlFlow> for ControlFlow<()> {
    /// Map the Python-facing enum onto `std::ops::ControlFlow<()>`.
    fn from(v: PyControlFlow) -> Self {
        match v {
            PyControlFlow::Continue => ControlFlow::Continue(()),
            PyControlFlow::Break => ControlFlow::Break(()),
        }
    }
}
2170
2171 #[derive(Clone)]
2176 pub struct MinimizationObserver(Arc<Py<PyAny>>);
2177 impl FromPyObject<'_> for MinimizationObserver {
2178 fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
2179 Ok(MinimizationObserver(Arc::new(ob.clone().unbind())))
2180 }
2181 }
2182 impl<A, P, C> Observer<A, P, GradientStatus, MaybeThreadPool, LadduError, C>
2183 for MinimizationObserver
2184 where
2185 A: Algorithm<P, GradientStatus, MaybeThreadPool, LadduError, Config = C>,
2186 {
2187 fn observe(
2188 &mut self,
2189 current_step: usize,
2190 _algorithm: &A,
2191 _problem: &P,
2192 status: &GradientStatus,
2193 _args: &MaybeThreadPool,
2194 _config: &C,
2195 ) {
2196 Python::attach(|py| {
2197 self.0
2198 .bind(py)
2199 .call_method1(
2200 "observe",
2201 (
2202 current_step,
2203 PyMinimizationStatus(MinimizationStatus::GradientStatus(Arc::new(
2204 Mutex::new(status.clone()),
2205 ))),
2206 ),
2207 )
2208 .expect("Error calling observe");
2209 })
2210 }
2211 }
2212 impl<A, P, C> Observer<A, P, GradientFreeStatus, MaybeThreadPool, LadduError, C>
2213 for MinimizationObserver
2214 where
2215 A: Algorithm<P, GradientFreeStatus, MaybeThreadPool, LadduError, Config = C>,
2216 {
2217 fn observe(
2218 &mut self,
2219 current_step: usize,
2220 _algorithm: &A,
2221 _problem: &P,
2222 status: &GradientFreeStatus,
2223 _args: &MaybeThreadPool,
2224 _config: &C,
2225 ) {
2226 Python::attach(|py| {
2227 self.0
2228 .bind(py)
2229 .call_method1(
2230 "observe",
2231 (
2232 current_step,
2233 PyMinimizationStatus(MinimizationStatus::GradientFreeStatus(Arc::new(
2234 Mutex::new(status.clone()),
2235 ))),
2236 ),
2237 )
2238 .expect("Error calling observe");
2239 })
2240 }
2241 }
2242 impl<A, P, C> Observer<A, P, SwarmStatus, MaybeThreadPool, LadduError, C> for MinimizationObserver
2243 where
2244 A: Algorithm<P, SwarmStatus, MaybeThreadPool, LadduError, Config = C>,
2245 {
2246 fn observe(
2247 &mut self,
2248 current_step: usize,
2249 _algorithm: &A,
2250 _problem: &P,
2251 status: &SwarmStatus,
2252 _args: &MaybeThreadPool,
2253 _config: &C,
2254 ) {
2255 Python::attach(|py| {
2256 self.0
2257 .bind(py)
2258 .call_method1(
2259 "observe",
2260 (
2261 current_step,
2262 PyMinimizationStatus(MinimizationStatus::SwarmStatus(Arc::new(
2263 Mutex::new(status.clone()),
2264 ))),
2265 ),
2266 )
2267 .expect("Error calling observe");
2268 })
2269 }
2270 }
2271
2272 #[derive(Clone)]
2278 pub struct MinimizationTerminator(Arc<Py<PyAny>>);
2279
2280 impl FromPyObject<'_> for MinimizationTerminator {
2281 fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
2282 Ok(MinimizationTerminator(Arc::new(ob.clone().unbind())))
2283 }
2284 }
2285
2286 impl<A, P, C> Terminator<A, P, GradientStatus, MaybeThreadPool, LadduError, C>
2287 for MinimizationTerminator
2288 where
2289 A: Algorithm<P, GradientStatus, MaybeThreadPool, LadduError, Config = C>,
2290 {
2291 fn check_for_termination(
2292 &mut self,
2293 current_step: usize,
2294 _algorithm: &mut A,
2295 _problem: &P,
2296 status: &mut GradientStatus,
2297 _args: &MaybeThreadPool,
2298 _config: &C,
2299 ) -> ControlFlow<()> {
2300 Python::attach(|py| -> PyResult<ControlFlow<()>> {
2301 let wrapped_status = Arc::new(Mutex::new(std::mem::take(status)));
2302 let py_status = Py::new(
2303 py,
2304 PyMinimizationStatus(MinimizationStatus::GradientStatus(
2305 wrapped_status.clone(),
2306 )),
2307 )?;
2308 let ret = self
2309 .0
2310 .bind(py)
2311 .call_method1("check_for_termination", (current_step, py_status))
2312 .expect("Error calling check_for_termination");
2313 {
2314 let mut guard = wrapped_status.lock();
2315 std::mem::swap(status, &mut *guard);
2316 }
2317 let cf: PyControlFlow = ret.extract()?;
2318 Ok(cf.into())
2319 })
2320 .unwrap_or(ControlFlow::Continue(()))
2321 }
2322 }
2323 impl<A, P, C> Terminator<A, P, GradientFreeStatus, MaybeThreadPool, LadduError, C>
2324 for MinimizationTerminator
2325 where
2326 A: Algorithm<P, GradientFreeStatus, MaybeThreadPool, LadduError, Config = C>,
2327 {
2328 fn check_for_termination(
2329 &mut self,
2330 current_step: usize,
2331 _algorithm: &mut A,
2332 _problem: &P,
2333 status: &mut GradientFreeStatus,
2334 _args: &MaybeThreadPool,
2335 _config: &C,
2336 ) -> ControlFlow<()> {
2337 Python::attach(|py| -> PyResult<ControlFlow<()>> {
2338 let wrapped_status = Arc::new(Mutex::new(std::mem::take(status)));
2339 let py_status = Py::new(
2340 py,
2341 PyMinimizationStatus(MinimizationStatus::GradientFreeStatus(
2342 wrapped_status.clone(),
2343 )),
2344 )?;
2345 let ret = self
2346 .0
2347 .bind(py)
2348 .call_method1("check_for_termination", (current_step, py_status))
2349 .expect("Error calling check_for_termination");
2350 {
2351 let mut guard = wrapped_status.lock();
2352 std::mem::swap(status, &mut *guard);
2353 }
2354 let cf: PyControlFlow = ret.extract()?;
2355 Ok(cf.into())
2356 })
2357 .unwrap_or(ControlFlow::Continue(()))
2358 }
2359 }
2360 impl<A, P, C> Terminator<A, P, SwarmStatus, MaybeThreadPool, LadduError, C>
2361 for MinimizationTerminator
2362 where
2363 A: Algorithm<P, SwarmStatus, MaybeThreadPool, LadduError, Config = C>,
2364 {
2365 fn check_for_termination(
2366 &mut self,
2367 current_step: usize,
2368 _algorithm: &mut A,
2369 _problem: &P,
2370 status: &mut SwarmStatus,
2371 _args: &MaybeThreadPool,
2372 _config: &C,
2373 ) -> ControlFlow<()> {
2374 Python::attach(|py| -> PyResult<ControlFlow<()>> {
2375 let wrapped_status = Arc::new(Mutex::new(std::mem::take(status)));
2376 let py_status = Py::new(
2377 py,
2378 PyMinimizationStatus(MinimizationStatus::SwarmStatus(wrapped_status.clone())),
2379 )?;
2380 let ret = self
2381 .0
2382 .bind(py)
2383 .call_method1("check_for_termination", (current_step, py_status))
2384 .expect("Error calling check_for_termination");
2385 {
2386 let mut guard = wrapped_status.lock();
2387 std::mem::swap(status, &mut *guard);
2388 }
2389 let cf: PyControlFlow = ret.extract()?;
2390 Ok(cf.into())
2391 })
2392 .unwrap_or(ControlFlow::Continue(()))
2393 }
2394 }
2395
2396 #[pyclass(name = "Walker", module = "laddu")]
2399 pub struct PyWalker(pub Walker);
2400
2401 #[pymethods]
2402 impl PyWalker {
2403 #[getter]
2409 fn dimension(&self) -> (usize, usize) {
2410 self.0.dimension()
2411 }
2412 fn get_latest<'py>(&self, py: Python<'py>) -> (Bound<'py, PyArray1<Float>>, Float) {
2415 let point = self.0.get_latest();
2416 let lock = point.read();
2417 (lock.x.clone().as_slice().to_pyarray(py), lock.fx_checked())
2418 }
2419 }
2420
2421 #[pyclass(name = "EnsembleStatus", module = "laddu")]
2424 pub struct PyEnsembleStatus(Arc<Mutex<EnsembleStatus>>);
2425
2426 #[pymethods]
2427 impl PyEnsembleStatus {
2428 #[getter]
2435 fn message(&self) -> String {
2436 self.0.lock().message().to_string()
2437 }
2438 #[getter]
2445 fn n_f_evals(&self) -> usize {
2446 self.0.lock().n_f_evals
2447 }
2448 #[getter]
2455 fn n_g_evals(&self) -> usize {
2456 self.0.lock().n_g_evals
2457 }
2458 #[getter]
2465 fn walkers(&self) -> Vec<PyWalker> {
2466 self.0
2467 .lock()
2468 .walkers
2469 .iter()
2470 .map(|w| PyWalker(w.clone()))
2471 .collect()
2472 }
2473 #[getter]
2480 fn dimension(&self) -> (usize, usize, usize) {
2481 self.0.lock().dimension()
2482 }
2483
2484 #[pyo3(signature = (*, burn = None, thin = None))]
2498 fn get_chain<'py>(
2499 &self,
2500 py: Python<'py>,
2501 burn: Option<usize>,
2502 thin: Option<usize>,
2503 ) -> PyResult<Bound<'py, PyArray3<Float>>> {
2504 let vec_chain: Vec<Vec<Vec<Float>>> = self
2505 .0
2506 .lock()
2507 .get_chain(burn, thin)
2508 .iter()
2509 .map(|steps| steps.iter().map(|p| p.as_slice().to_vec()).collect())
2510 .collect();
2511 Ok(PyArray3::from_vec3(py, &vec_chain)?)
2512 }
2513
2514 #[pyo3(signature = (*, burn = None, thin = None))]
2528 fn get_flat_chain<'py>(
2529 &self,
2530 py: Python<'py>,
2531 burn: Option<usize>,
2532 thin: Option<usize>,
2533 ) -> Bound<'py, PyArray2<Float>> {
2534 DMatrix::from_columns(&self.0.lock().get_flat_chain(burn, thin))
2535 .transpose()
2536 .to_pyarray(py)
2537 }
2538 }
2539
2540 #[pyclass(name = "MCMCSummary", module = "laddu")]
2543 pub struct PyMCMCSummary(pub MCMCSummary);
2544
2545 #[pymethods]
2546 impl PyMCMCSummary {
2547 #[getter]
2554 fn bounds(&self) -> Option<Vec<(Float, Float)>> {
2555 self.0
2556 .clone()
2557 .bounds
2558 .map(|bs| bs.iter().map(|b| b.0.as_floats()).collect())
2559 }
2560 #[getter]
2567 fn parameter_names(&self) -> Vec<String> {
2568 self.0.parameter_names.clone().unwrap_or_default()
2569 }
2570 #[getter]
2577 fn message(&self) -> String {
2578 self.0.message.clone()
2579 }
2580 #[getter]
2587 fn cost_evals(&self) -> usize {
2588 self.0.cost_evals
2589 }
2590 #[getter]
2597 fn gradient_evals(&self) -> usize {
2598 self.0.gradient_evals
2599 }
2600 #[getter]
2607 fn converged(&self) -> bool {
2608 self.0.converged
2609 }
2610 #[getter]
2617 fn dimension(&self) -> (usize, usize, usize) {
2618 self.0.dimension
2619 }
2620
2621 #[pyo3(signature = (*, burn = None, thin = None))]
2635 fn get_chain<'py>(
2636 &self,
2637 py: Python<'py>,
2638 burn: Option<usize>,
2639 thin: Option<usize>,
2640 ) -> PyResult<Bound<'py, PyArray3<Float>>> {
2641 let vec_chain: Vec<Vec<Vec<Float>>> = self
2642 .0
2643 .get_chain(burn, thin)
2644 .iter()
2645 .map(|steps| steps.iter().map(|p| p.as_slice().to_vec()).collect())
2646 .collect();
2647 Ok(PyArray3::from_vec3(py, &vec_chain)?)
2648 }
2649
2650 #[pyo3(signature = (*, burn = None, thin = None))]
2664 fn get_flat_chain<'py>(
2665 &self,
2666 py: Python<'py>,
2667 burn: Option<usize>,
2668 thin: Option<usize>,
2669 ) -> Bound<'py, PyArray2<Float>> {
2670 DMatrix::from_columns(&self.0.get_flat_chain(burn, thin))
2671 .transpose()
2672 .to_pyarray(py)
2673 }
2674
2675 #[new]
2676 fn new() -> Self {
2677 Self(MCMCSummary::create_null())
2678 }
2679 fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
2680 Ok(PyBytes::new(
2681 py,
2682 bincode::serde::encode_to_vec(&self.0, bincode::config::standard())
2683 .map_err(LadduError::EncodeError)?
2684 .as_slice(),
2685 ))
2686 }
2687 fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
2688 *self = Self(
2689 bincode::serde::decode_from_slice(state.as_bytes(), bincode::config::standard())
2690 .map_err(LadduError::DecodeError)?
2691 .0,
2692 );
2693 Ok(())
2694 }
2695 }
2696
2697 #[derive(Clone)]
2702 pub struct MCMCObserver(Arc<Py<PyAny>>);
2703
2704 impl FromPyObject<'_> for MCMCObserver {
2705 fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
2706 Ok(MCMCObserver(Arc::new(ob.clone().unbind())))
2707 }
2708 }
2709
2710 impl<A, P, C> Observer<A, P, EnsembleStatus, MaybeThreadPool, LadduError, C> for MCMCObserver
2711 where
2712 A: Algorithm<P, EnsembleStatus, MaybeThreadPool, LadduError, Config = C>,
2713 {
2714 fn observe(
2715 &mut self,
2716 current_step: usize,
2717 _algorithm: &A,
2718 _problem: &P,
2719 status: &EnsembleStatus,
2720 _args: &MaybeThreadPool,
2721 _config: &C,
2722 ) {
2723 Python::attach(|py| {
2724 self.0
2725 .bind(py)
2726 .call_method1(
2727 "observe",
2728 (
2729 current_step,
2730 PyEnsembleStatus(Arc::new(Mutex::new(status.clone()))),
2731 ),
2732 )
2733 .expect("Error calling observe");
2734 })
2735 }
2736 }
2737
2738 #[derive(Clone)]
2739 enum PythonMCMCTerminator {
2740 UserDefined(MCMCTerminator),
2741 Autocorrelation(PyAutocorrelationTerminator),
2742 }
2743
2744 impl<A, P, C> Terminator<A, P, EnsembleStatus, MaybeThreadPool, LadduError, C>
2745 for PythonMCMCTerminator
2746 where
2747 A: Algorithm<P, EnsembleStatus, MaybeThreadPool, LadduError, Config = C>,
2748 {
2749 fn check_for_termination(
2750 &mut self,
2751 current_step: usize,
2752 algorithm: &mut A,
2753 problem: &P,
2754 status: &mut EnsembleStatus,
2755 args: &MaybeThreadPool,
2756 config: &C,
2757 ) -> ControlFlow<()> {
2758 match self {
2759 Self::UserDefined(mcmcterminator) => mcmcterminator.check_for_termination(
2760 current_step,
2761 algorithm,
2762 problem,
2763 status,
2764 args,
2765 config,
2766 ),
2767 Self::Autocorrelation(py_autocorrelation_terminator) => {
2768 py_autocorrelation_terminator.0.check_for_termination(
2769 current_step,
2770 algorithm,
2771 problem,
2772 status,
2773 args,
2774 config,
2775 )
2776 }
2777 }
2778 }
2779 }
2780 #[derive(Clone)]
2786 pub struct MCMCTerminator(Arc<Py<PyAny>>);
2787
2788 impl FromPyObject<'_> for MCMCTerminator {
2789 fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
2790 Ok(MCMCTerminator(Arc::new(ob.clone().unbind())))
2791 }
2792 }
2793
2794 impl<A, P, C> Terminator<A, P, EnsembleStatus, MaybeThreadPool, LadduError, C> for MCMCTerminator
2795 where
2796 A: Algorithm<P, EnsembleStatus, MaybeThreadPool, LadduError, Config = C>,
2797 {
2798 fn check_for_termination(
2799 &mut self,
2800 current_step: usize,
2801 _algorithm: &mut A,
2802 _problem: &P,
2803 status: &mut EnsembleStatus,
2804 _args: &MaybeThreadPool,
2805 _config: &C,
2806 ) -> ControlFlow<()> {
2807 Python::attach(|py| -> PyResult<ControlFlow<()>> {
2808 let wrapped_status = Arc::new(Mutex::new(std::mem::take(status)));
2809 let py_status = Py::new(py, PyEnsembleStatus(wrapped_status.clone()))?;
2810 let ret = self
2811 .0
2812 .bind(py)
2813 .call_method1("check_for_termination", (current_step, py_status))
2814 .expect("Error calling check_for_termination");
2815 {
2816 let mut guard = wrapped_status.lock();
2817 std::mem::swap(status, &mut *guard);
2818 }
2819 let cf: PyControlFlow = ret.extract()?;
2820 Ok(cf.into())
2821 })
2822 .unwrap_or(ControlFlow::Continue(()))
2823 }
2824 }
2825
2826 #[pyfunction(name = "integrated_autocorrelation_times")]
2843 #[pyo3(signature = (samples, *, c=None))]
2844 pub fn py_integrated_autocorrelation_times<'py>(
2845 py: Python<'py>,
2846 samples: Vec<Vec<Vec<Float>>>,
2847 c: Option<Float>,
2848 ) -> Bound<'py, PyArray1<Float>> {
2849 let samples: Vec<Vec<DVector<Float>>> = samples
2850 .into_iter()
2851 .map(|v| v.into_iter().map(|p| DVector::from_vec(p)).collect())
2852 .collect();
2853 integrated_autocorrelation_times(samples, c)
2854 .as_slice()
2855 .to_pyarray(py)
2856 }
2857
2858 #[pyclass(name = "AutocorrelationTerminator", module = "laddu")]
2881 #[derive(Clone)]
2882 pub struct PyAutocorrelationTerminator(Arc<Mutex<AutocorrelationTerminator>>);
2883
2884 #[pymethods]
2885 impl PyAutocorrelationTerminator {
2886 #[new]
2887 #[pyo3(signature = (*, n_check = 50, n_taus_threshold = 50, dtau_threshold = 0.01, discard = 0.5, terminate = true, sokal_window = None, verbose = false))]
2888 fn new(
2889 n_check: usize,
2890 n_taus_threshold: usize,
2891 dtau_threshold: Float,
2892 discard: Float,
2893 terminate: bool,
2894 sokal_window: Option<Float>,
2895 verbose: bool,
2896 ) -> Self {
2897 let mut act = AutocorrelationTerminator::default()
2898 .with_n_check(n_check)
2899 .with_n_taus_threshold(n_taus_threshold)
2900 .with_dtau_threshold(dtau_threshold)
2901 .with_discard(discard)
2902 .with_terminate(terminate)
2903 .with_verbose(verbose);
2904 if let Some(sokal_window) = sokal_window {
2905 act = act.with_sokal_window(sokal_window)
2906 }
2907 Self(act.build())
2908 }
2909
2910 #[getter]
2917 fn taus<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<Float>> {
2918 self.0.lock().taus.to_pyarray(py)
2919 }
2920 }
2921}