laddu_extensions/ganesh_ext.rs

use crate::{
    likelihoods::{LikelihoodTerm, StochasticNLL},
    LikelihoodEvaluator, NLL,
};
use ganesh::{
    algorithms::{
        gradient::{Adam, AdamConfig, GradientStatus, LBFGSBConfig, LBFGSB},
        gradient_free::{GradientFreeStatus, NelderMead, NelderMeadConfig},
        mcmc::{AIESConfig, ESSConfig, EnsembleStatus, AIES, ESS},
        particles::{PSOConfig, SwarmStatus, PSO},
    },
    core::{summary::HasParameterNames, Callbacks, MCMCSummary, MinimizationSummary},
    traits::{Algorithm, CostFunction, Gradient, LogDensity, Observer, Status},
};
use laddu_core::{Float, LadduError};
use nalgebra::DVector;
#[cfg(feature = "rayon")]
use rayon::{ThreadPool, ThreadPoolBuilder};

/// A settings wrapper for minimization algorithms
pub enum MinimizationSettings<P> {
    /// Settings for the L-BFGS-B algorithm
    LBFGSB {
        /// The configuration struct
        config: LBFGSBConfig,
        /// Callbacks to apply to the algorithm
        callbacks: Callbacks<LBFGSB, P, GradientStatus, MaybeThreadPool, LadduError, LBFGSBConfig>,
        /// The number of threads to use (0 will run in single-threaded mode)
        num_threads: usize,
    },
    /// Settings for the Adam algorithm
    Adam {
        /// The configuration struct
        config: AdamConfig,
        /// Callbacks to apply to the algorithm
        callbacks: Callbacks<Adam, P, GradientStatus, MaybeThreadPool, LadduError, AdamConfig>,
        /// The number of threads to use (0 will run in single-threaded mode)
        num_threads: usize,
    },
    /// Settings for the Nelder-Mead algorithm
    NelderMead {
        /// The configuration struct
        config: NelderMeadConfig,
        /// Callbacks to apply to the algorithm
        callbacks: Callbacks<
            NelderMead,
            P,
            GradientFreeStatus,
            MaybeThreadPool,
            LadduError,
            NelderMeadConfig,
        >,
        /// The number of threads to use (0 will run in single-threaded mode)
        num_threads: usize,
    },
    /// Settings for the Particle Swarm Optimization algorithm
    PSO {
        /// The configuration struct
        config: PSOConfig,
        /// Callbacks to apply to the algorithm
        callbacks: Callbacks<PSO, P, SwarmStatus, MaybeThreadPool, LadduError, PSOConfig>,
        /// The number of threads to use (0 will run in single-threaded mode)
        num_threads: usize,
    },
}
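// A minimal usage sketch (assumption: illustrative starting values and a bare callback
// set). `LBFGSBConfig::new` and `Callbacks::empty` are used the same way in the Python
// bindings below, and `NLL::minimize` is defined later in this file.
//
//     let x0: Vec<Float> = vec![1.0, 2.0];
//     let settings = MinimizationSettings::LBFGSB {
//         config: LBFGSBConfig::new(&x0),
//         callbacks: Callbacks::empty(),
//         num_threads: 0, // single-threaded
//     };
//     let summary = nll.minimize(settings)?;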

/// A settings wrapper for MCMC algorithms
pub enum MCMCSettings<P> {
    /// Settings for the Affine Invariant Ensemble Sampler
    AIES {
        /// The configuration struct
        config: AIESConfig,
        /// Callbacks to apply to the algorithm
        callbacks: Callbacks<AIES, P, EnsembleStatus, MaybeThreadPool, LadduError, AIESConfig>,
        /// The number of threads to use (0 will run in single-threaded mode)
        num_threads: usize,
    },
    /// Settings for the Ensemble Slice Sampler
    ESS {
        /// The configuration struct
        config: ESSConfig,
        /// Callbacks to apply to the algorithm
        callbacks: Callbacks<ESS, P, EnsembleStatus, MaybeThreadPool, LadduError, ESSConfig>,
        /// The number of threads to use (0 will run in single-threaded mode)
        num_threads: usize,
    },
}

/// A wrapper struct which conditionally contains a thread pool if the rayon feature is enabled.
#[derive(Debug)]
pub struct MaybeThreadPool {
    #[cfg(feature = "rayon")]
    /// The underlying [`ThreadPool`]
    pub thread_pool: ThreadPool,
}
/// A trait for objects which can be used in multithreaded contexts
pub trait Threadable {
    /// Returns a [`MaybeThreadPool`] associated with the object.
    ///
    /// # Errors
    ///
    /// This will return an error if there is any problem constructing the underlying thread pool.
    fn get_pool(&self) -> Result<MaybeThreadPool, LadduError>;
}
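// Usage sketch (mirroring the `minimize`/`mcmc` methods below): the settings build the
// pool once up front, and the evaluation traits run their work inside it when the
// `rayon` feature is enabled.
//
//     let mtp = settings.get_pool()?;
//     // later, inside CostFunction::evaluate (rayon builds only):
//     // args.thread_pool.install(|| ...)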
impl<P> Threadable for MinimizationSettings<P> {
    fn get_pool(&self) -> Result<MaybeThreadPool, LadduError> {
        #[cfg(feature = "rayon")]
        {
            Ok(MaybeThreadPool {
                thread_pool: ThreadPoolBuilder::new()
                    .num_threads(match self {
                        Self::LBFGSB {
                            config: _,
                            callbacks: _,
                            num_threads,
                        }
                        | Self::Adam {
                            config: _,
                            callbacks: _,
                            num_threads,
                        }
                        | Self::NelderMead {
                            config: _,
                            callbacks: _,
                            num_threads,
                        }
                        | Self::PSO {
                            config: _,
                            callbacks: _,
                            num_threads,
                        } => *num_threads,
                    })
                    .build()
                    .map_err(LadduError::from)?,
            })
        }
        #[cfg(not(feature = "rayon"))]
        {
            Ok(MaybeThreadPool {})
        }
    }
}

impl<P> Threadable for MCMCSettings<P> {
    fn get_pool(&self) -> Result<MaybeThreadPool, LadduError> {
        #[cfg(feature = "rayon")]
        {
            Ok(MaybeThreadPool {
                thread_pool: ThreadPoolBuilder::new()
                    .num_threads(match self {
                        Self::AIES {
                            config: _,
                            callbacks: _,
                            num_threads,
                        }
                        | Self::ESS {
                            config: _,
                            callbacks: _,
                            num_threads,
                        } => *num_threads,
                    })
                    .build()
                    .map_err(LadduError::from)?,
            })
        }
        #[cfg(not(feature = "rayon"))]
        {
            Ok(MaybeThreadPool {})
        }
    }
}

#[derive(Copy, Clone)]
struct LikelihoodTermObserver;
impl<A, P, S, U, E, C> Observer<A, P, S, U, E, C> for LikelihoodTermObserver
where
    A: Algorithm<P, S, U, E, Config = C>,
    P: LikelihoodTerm,
    S: Status,
{
    fn observe(
        &mut self,
        _current_step: usize,
        _algorithm: &A,
        problem: &P,
        _status: &S,
        _args: &U,
        _config: &C,
    ) {
        problem.update();
    }
}

impl CostFunction<MaybeThreadPool, LadduError> for NLL {
    fn evaluate(
        &self,
        parameters: &DVector<Float>,
        args: &MaybeThreadPool,
    ) -> Result<Float, LadduError> {
        #[cfg(feature = "rayon")]
        {
            Ok(args
                .thread_pool
                .install(|| LikelihoodTerm::evaluate(self, parameters.into())))
        }
        #[cfg(not(feature = "rayon"))]
        {
            Ok(LikelihoodTerm::evaluate(self, parameters.into()))
        }
    }
}
impl Gradient<MaybeThreadPool, LadduError> for NLL {
    fn gradient(
        &self,
        parameters: &DVector<Float>,
        args: &MaybeThreadPool,
    ) -> Result<DVector<Float>, LadduError> {
        #[cfg(feature = "rayon")]
        {
            Ok(args
                .thread_pool
                .install(|| LikelihoodTerm::evaluate_gradient(self, parameters.into())))
        }
        #[cfg(not(feature = "rayon"))]
        {
            Ok(LikelihoodTerm::evaluate_gradient(self, parameters.into()))
        }
    }
}
impl LogDensity<MaybeThreadPool, LadduError> for NLL {
    fn log_density(
        &self,
        parameters: &DVector<Float>,
        args: &MaybeThreadPool,
    ) -> Result<Float, LadduError> {
        #[cfg(feature = "rayon")]
        {
            Ok(-args
                .thread_pool
                .install(|| LikelihoodTerm::evaluate(self, parameters.into())))
        }
        #[cfg(not(feature = "rayon"))]
        {
            Ok(-LikelihoodTerm::evaluate(self, parameters.into()))
        }
    }
}
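// Note on the impls above: the sampling target is the negative of the cost function,
// i.e. log_density(theta) = -NLL(theta), so MCMC draws are proportional to the
// likelihood itself (no additional prior term is applied here).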

impl NLL {
    /// Minimize the [`NLL`] with the algorithm given in the [`MinimizationSettings`].
    ///
    /// # Errors
    ///
    /// This method may return an error if there was any problem constructing thread pools or
    /// evaluating the underlying model.
    pub fn minimize(
        &self,
        settings: MinimizationSettings<Self>,
    ) -> Result<MinimizationSummary, LadduError> {
        let mtp = settings.get_pool()?;
        Ok(match settings {
            MinimizationSettings::LBFGSB {
                config,
                callbacks,
                num_threads: _,
            } => LBFGSB::default().process(
                self,
                &mtp,
                config,
                callbacks.with_observer(LikelihoodTermObserver),
            ),
            MinimizationSettings::Adam {
                config,
                callbacks,
                num_threads: _,
            } => Adam::default().process(
                self,
                &mtp,
                config,
                callbacks.with_observer(LikelihoodTermObserver),
            ),
            MinimizationSettings::NelderMead {
                config,
                callbacks,
                num_threads: _,
            } => NelderMead::default().process(
                self,
                &mtp,
                config,
                callbacks.with_observer(LikelihoodTermObserver),
            ),
            MinimizationSettings::PSO {
                config,
                callbacks,
                num_threads: _,
            } => PSO::default().process(
                self,
                &mtp,
                config,
                callbacks.with_observer(LikelihoodTermObserver),
            ),
        }?
        .with_parameter_names(self.parameters()))
    }

    /// Run an MCMC sampling algorithm over the [`NLL`] with the given [`MCMCSettings`].
    ///
    /// # Errors
    ///
    /// This method may return an error if there was any problem constructing thread pools or
    /// evaluating the underlying model.
    pub fn mcmc(&self, settings: MCMCSettings<Self>) -> Result<MCMCSummary, LadduError> {
        let mtp = settings.get_pool()?;
        Ok(match settings {
            MCMCSettings::AIES {
                config,
                callbacks,
                num_threads: _,
            } => AIES::default().process(
                self,
                &mtp,
                config,
                callbacks.with_observer(LikelihoodTermObserver),
            ),
            MCMCSettings::ESS {
                config,
                callbacks,
                num_threads: _,
            } => ESS::default().process(
                self,
                &mtp,
                config,
                callbacks.with_observer(LikelihoodTermObserver),
            ),
        }?
        .with_parameter_names(self.parameters()))
    }
}

impl CostFunction<MaybeThreadPool, LadduError> for StochasticNLL {
    fn evaluate(
        &self,
        parameters: &DVector<Float>,
        args: &MaybeThreadPool,
    ) -> Result<Float, LadduError> {
        #[cfg(feature = "rayon")]
        {
            Ok(args
                .thread_pool
                .install(|| LikelihoodTerm::evaluate(self, parameters.into())))
        }
        #[cfg(not(feature = "rayon"))]
        {
            Ok(LikelihoodTerm::evaluate(self, parameters.into()))
        }
    }
}
impl Gradient<MaybeThreadPool, LadduError> for StochasticNLL {
    fn gradient(
        &self,
        parameters: &DVector<Float>,
        args: &MaybeThreadPool,
    ) -> Result<DVector<Float>, LadduError> {
        #[cfg(feature = "rayon")]
        {
            Ok(args
                .thread_pool
                .install(|| LikelihoodTerm::evaluate_gradient(self, parameters.into())))
        }
        #[cfg(not(feature = "rayon"))]
        {
            Ok(LikelihoodTerm::evaluate_gradient(self, parameters.into()))
        }
    }
}
impl LogDensity<MaybeThreadPool, LadduError> for StochasticNLL {
    fn log_density(
        &self,
        parameters: &DVector<Float>,
        args: &MaybeThreadPool,
    ) -> Result<Float, LadduError> {
        #[cfg(feature = "rayon")]
        {
            Ok(-args
                .thread_pool
                .install(|| LikelihoodTerm::evaluate(self, parameters.into())))
        }
        #[cfg(not(feature = "rayon"))]
        {
            Ok(-LikelihoodTerm::evaluate(self, parameters.into()))
        }
    }
}

impl StochasticNLL {
    /// Minimize the [`StochasticNLL`] with the algorithm given in the [`MinimizationSettings`].
    ///
    /// # Errors
    ///
    /// This method may return an error if there was any problem constructing thread pools or
    /// evaluating the underlying model.
    pub fn minimize(
        &self,
        settings: MinimizationSettings<Self>,
    ) -> Result<MinimizationSummary, LadduError> {
        let mtp = settings.get_pool()?;
        Ok(match settings {
            MinimizationSettings::LBFGSB {
                config,
                callbacks,
                num_threads: _,
            } => LBFGSB::default().process(
                self,
                &mtp,
                config,
                callbacks.with_observer(LikelihoodTermObserver),
            ),
            MinimizationSettings::Adam {
                config,
                callbacks,
                num_threads: _,
            } => Adam::default().process(
                self,
                &mtp,
                config,
                callbacks.with_observer(LikelihoodTermObserver),
            ),
            MinimizationSettings::NelderMead {
                config,
                callbacks,
                num_threads: _,
            } => NelderMead::default().process(
                self,
                &mtp,
                config,
                callbacks.with_observer(LikelihoodTermObserver),
            ),
            MinimizationSettings::PSO {
                config,
                callbacks,
                num_threads: _,
            } => PSO::default().process(
                self,
                &mtp,
                config,
                callbacks.with_observer(LikelihoodTermObserver),
            ),
        }?
        .with_parameter_names(self.parameters()))
    }

    /// Run an MCMC sampling algorithm over the [`StochasticNLL`] with the given [`MCMCSettings`].
    ///
    /// # Errors
    ///
    /// This method may return an error if there was any problem constructing thread pools or
    /// evaluating the underlying model.
    pub fn mcmc(&self, settings: MCMCSettings<Self>) -> Result<MCMCSummary, LadduError> {
        let mtp = settings.get_pool()?;
        Ok(match settings {
            MCMCSettings::AIES {
                config,
                callbacks,
                num_threads: _,
            } => AIES::default().process(
                self,
                &mtp,
                config,
                callbacks.with_observer(LikelihoodTermObserver),
            ),
            MCMCSettings::ESS {
                config,
                callbacks,
                num_threads: _,
            } => ESS::default().process(
                self,
                &mtp,
                config,
                callbacks.with_observer(LikelihoodTermObserver),
            ),
        }?
        .with_parameter_names(self.parameters()))
    }
}

impl CostFunction<MaybeThreadPool, LadduError> for LikelihoodEvaluator {
    fn evaluate(
        &self,
        parameters: &DVector<Float>,
        args: &MaybeThreadPool,
    ) -> Result<Float, LadduError> {
        #[cfg(feature = "rayon")]
        {
            Ok(args
                .thread_pool
                .install(|| LikelihoodTerm::evaluate(self, parameters.into())))
        }
        #[cfg(not(feature = "rayon"))]
        {
            Ok(LikelihoodTerm::evaluate(self, parameters.into()))
        }
    }
}
impl Gradient<MaybeThreadPool, LadduError> for LikelihoodEvaluator {
    fn gradient(
        &self,
        parameters: &DVector<Float>,
        args: &MaybeThreadPool,
    ) -> Result<DVector<Float>, LadduError> {
        #[cfg(feature = "rayon")]
        {
            Ok(args
                .thread_pool
                .install(|| LikelihoodTerm::evaluate_gradient(self, parameters.into())))
        }
        #[cfg(not(feature = "rayon"))]
        {
            Ok(LikelihoodTerm::evaluate_gradient(self, parameters.into()))
        }
    }
}
impl LogDensity<MaybeThreadPool, LadduError> for LikelihoodEvaluator {
    fn log_density(
        &self,
        parameters: &DVector<Float>,
        args: &MaybeThreadPool,
    ) -> Result<Float, LadduError> {
        #[cfg(feature = "rayon")]
        {
            Ok(-args
                .thread_pool
                .install(|| LikelihoodTerm::evaluate(self, parameters.into())))
        }
        #[cfg(not(feature = "rayon"))]
        {
            Ok(-LikelihoodTerm::evaluate(self, parameters.into()))
        }
    }
}

impl LikelihoodEvaluator {
    /// Minimize the [`LikelihoodEvaluator`] with the algorithm given in the [`MinimizationSettings`].
    ///
    /// # Errors
    ///
    /// This method may return an error if there was any problem constructing thread pools or
    /// evaluating the underlying model.
    pub fn minimize(
        &self,
        settings: MinimizationSettings<Self>,
    ) -> Result<MinimizationSummary, LadduError> {
        let mtp = settings.get_pool()?;
        match settings {
            MinimizationSettings::LBFGSB {
                config,
                callbacks,
                num_threads: _,
            } => LBFGSB::default().process(
                self,
                &mtp,
                config,
                callbacks.with_observer(LikelihoodTermObserver),
            ),
            MinimizationSettings::Adam {
                config,
                callbacks,
                num_threads: _,
            } => Adam::default().process(
                self,
                &mtp,
                config,
                callbacks.with_observer(LikelihoodTermObserver),
            ),
            MinimizationSettings::NelderMead {
                config,
                callbacks,
                num_threads: _,
            } => NelderMead::default().process(
                self,
                &mtp,
                config,
                callbacks.with_observer(LikelihoodTermObserver),
            ),
            MinimizationSettings::PSO {
                config,
                callbacks,
                num_threads: _,
            } => PSO::default().process(
                self,
                &mtp,
                config,
                callbacks.with_observer(LikelihoodTermObserver),
            ),
        }
    }

    /// Run an MCMC sampling algorithm over the [`LikelihoodEvaluator`] with the given [`MCMCSettings`].
    ///
    /// # Errors
    ///
    /// This method may return an error if there was any problem constructing thread pools or
    /// evaluating the underlying model.
    pub fn mcmc(&self, settings: MCMCSettings<Self>) -> Result<MCMCSummary, LadduError> {
        let mtp = settings.get_pool()?;
        match settings {
            MCMCSettings::AIES {
                config,
                callbacks,
                num_threads: _,
            } => AIES::default().process(
                self,
                &mtp,
                config,
                callbacks.with_observer(LikelihoodTermObserver),
            ),
            MCMCSettings::ESS {
                config,
                callbacks,
                num_threads: _,
            } => ESS::default().process(
                self,
                &mtp,
                config,
                callbacks.with_observer(LikelihoodTermObserver),
            ),
        }
    }
}

/// Python bindings for the [`ganesh`] crate
#[cfg(feature = "python")]
pub mod py_ganesh {
    use std::{ops::ControlFlow, sync::Arc};

    use super::*;

    use ganesh::{
        algorithms::{
            gradient::{
                adam::AdamEMATerminator,
                lbfgsb::{
                    LBFGSBErrorMode, LBFGSBFTerminator, LBFGSBGTerminator, LBFGSBInfNormGTerminator,
                },
            },
            gradient_free::nelder_mead::{
                NelderMeadFTerminator, NelderMeadXTerminator, SimplexConstructionMethod,
                SimplexExpansionMethod,
            },
            line_search::{HagerZhangLineSearch, MoreThuenteLineSearch, StrongWolfeLineSearch},
            mcmc::{
                integrated_autocorrelation_times, AIESMove, AutocorrelationTerminator, ESSMove,
                Walker,
            },
            particles::{
                Swarm, SwarmBoundaryMethod, SwarmParticle, SwarmPositionInitializer, SwarmTopology,
                SwarmUpdateMethod, SwarmVelocityInitializer,
            },
        },
        core::{Bounds, CtrlCAbortSignal, DebugObserver, MaxSteps},
        traits::{Observer, Status, SupportsBounds, SupportsTransform, Terminator},
    };
    use laddu_core::{Float, LadduError, ReadWrite};
    use nalgebra::DMatrix;
    use numpy::{PyArray1, PyArray2, PyArray3, ToPyArray};
    use parking_lot::Mutex;
    use pyo3::{
        exceptions::{PyTypeError, PyValueError},
        prelude::*,
        types::{PyBytes, PyDict, PyList},
    };

    /// A helper trait for parsing Python arguments.
    pub trait FromPyArgs<A = ()>: Sized {
        /// Convert the given Python arguments into a [`Self`].
        fn from_pyargs(args: &A, d: &Bound<PyDict>) -> PyResult<Self>;
    }
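    // Usage sketch (hypothetical key values): the `MinimizationSettings` parser further
    // below pulls a nested `settings` dict out of the top-level options and hands it to
    // these per-algorithm parsers.
    //
    //     Python::with_gil(|py| -> PyResult<()> {
    //         let kwargs = PyDict::new(py);
    //         kwargs.set_item("m", 10)?;              // L-BFGS-B memory limit
    //         kwargs.set_item("skip_hessian", true)?;
    //         let _config = LBFGSBConfig::from_pyargs(&vec![1.0, 2.0], &kwargs)?;
    //         Ok(())
    //     })?;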
    impl FromPyArgs<Vec<Float>> for LBFGSBConfig {
        fn from_pyargs(args: &Vec<Float>, d: &Bound<PyDict>) -> PyResult<Self> {
            let mut config = LBFGSBConfig::new(args);
            if let Some(m) = d.get_item("m")? {
                let m_int = m.extract()?;
                config = config.with_memory_limit(m_int);
            }
            if let Some(flag) = d.get_item("skip_hessian")? {
                if flag.extract()? {
                    config = config.with_error_mode(LBFGSBErrorMode::Skip);
                }
            }
            if let Some(linesearch_dict) = d.get_item("line_search")? {
                config = config.with_line_search(StrongWolfeLineSearch::from_pyargs(
                    &(),
                    &linesearch_dict.extract()?,
                )?);
            }
            Ok(config)
        }
    }
    impl FromPyArgs for StrongWolfeLineSearch {
        fn from_pyargs(_args: &(), d: &Bound<PyDict>) -> PyResult<Self> {
            if let Some(method) = d.get_item("method")? {
                match method
                    .extract::<String>()?
                    .to_lowercase()
                    .trim()
                    .replace("-", "")
                    .replace(" ", "")
                    .as_str()
                {
                    "morethuente" => {
                        let mut line_search = MoreThuenteLineSearch::default();
                        if let Some(max_iterations) = d.get_item("max_iterations")? {
                            line_search =
                                line_search.with_max_iterations(max_iterations.extract()?);
                        }
                        if let Some(max_zoom) = d.get_item("max_zoom")? {
                            line_search = line_search.with_max_zoom(max_zoom.extract()?);
                        }
                        match (d.get_item("c1")?, d.get_item("c2")?) {
                            (Some(c1), Some(c2)) => {
                                line_search = line_search.with_c1_c2(c1.extract()?, c2.extract()?);
                            }
                            (Some(c1), None) => {
                                line_search = line_search.with_c1(c1.extract()?);
                            }
                            (None, Some(c2)) => {
                                line_search = line_search.with_c2(c2.extract()?);
                            }
                            (None, None) => {}
                        }
                        Ok(StrongWolfeLineSearch::MoreThuente(line_search))
                    }
                    "hagerzhang" => {
                        let mut line_search = HagerZhangLineSearch::default();
                        if let Some(max_iterations) = d.get_item("max_iterations")? {
                            line_search =
                                line_search.with_max_iterations(max_iterations.extract()?);
                        }
                        if let Some(max_bisects) = d.get_item("max_bisects")? {
                            line_search = line_search.with_max_bisects(max_bisects.extract()?);
                        }
                        match (d.get_item("delta")?, d.get_item("sigma")?) {
                            (Some(delta), Some(sigma)) => {
                                line_search = line_search
                                    .with_delta_sigma(delta.extract()?, sigma.extract()?);
                            }
                            (Some(delta), None) => {
                                line_search = line_search.with_delta(delta.extract()?);
                            }
                            (None, Some(sigma)) => {
                                line_search = line_search.with_sigma(sigma.extract()?);
                            }
                            (None, None) => {}
                        }
                        if let Some(epsilon) = d.get_item("epsilon")? {
                            line_search = line_search.with_epsilon(epsilon.extract()?);
                        }
                        if let Some(theta) = d.get_item("theta")? {
                            line_search = line_search.with_theta(theta.extract()?);
                        }
                        if let Some(gamma) = d.get_item("gamma")? {
                            line_search = line_search.with_gamma(gamma.extract()?);
                        }
                        Ok(StrongWolfeLineSearch::HagerZhang(line_search))
                    }
                    _ => Err(PyTypeError::new_err(format!(
                        "Invalid line search method: {}",
                        method
                    ))),
                }
            } else {
                Err(PyTypeError::new_err("Line search method not specified"))
            }
        }
    }
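    // Recognized `line_search` keys (as parsed above): "method" is required and selects
    // "more-thuente" or "hager-zhang" (case, dash, and space insensitive); the remaining
    // keys are method-specific: "max_iterations", "max_zoom", "c1"/"c2" for More-Thuente,
    // and "max_iterations", "max_bisects", "delta"/"sigma", "epsilon", "theta", "gamma"
    // for Hager-Zhang.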
    impl<P> FromPyArgs
        for Callbacks<LBFGSB, P, GradientStatus, MaybeThreadPool, LadduError, LBFGSBConfig>
    where
        P: Gradient<MaybeThreadPool, LadduError>,
    {
        fn from_pyargs(_args: &(), d: &Bound<PyDict>) -> PyResult<Self> {
            let mut callbacks = Callbacks::empty();
            if let Some(eps_f) = d.get_item("eps_f")? {
                if let Some(eps_abs) = eps_f.extract()? {
                    callbacks = callbacks.with_terminator(LBFGSBFTerminator { eps_abs });
                } else {
                    callbacks = callbacks.with_terminator(LBFGSBFTerminator::default());
                }
            } else {
                callbacks = callbacks.with_terminator(LBFGSBFTerminator::default());
            }
            if let Some(eps_g) = d.get_item("eps_g")? {
                if let Some(eps_abs) = eps_g.extract()? {
                    callbacks = callbacks.with_terminator(LBFGSBGTerminator { eps_abs });
                } else {
                    callbacks = callbacks.with_terminator(LBFGSBGTerminator::default());
                }
            } else {
                callbacks = callbacks.with_terminator(LBFGSBGTerminator::default());
            }
            if let Some(eps_norm_g) = d.get_item("eps_norm_g")? {
                if let Some(eps_abs) = eps_norm_g.extract()? {
                    callbacks = callbacks.with_terminator(LBFGSBInfNormGTerminator { eps_abs });
                } else {
                    callbacks = callbacks.with_terminator(LBFGSBInfNormGTerminator::default());
                }
            } else {
                callbacks = callbacks.with_terminator(LBFGSBInfNormGTerminator::default());
            }
            Ok(callbacks)
        }
    }
    impl FromPyArgs<Vec<Float>> for AdamConfig {
        fn from_pyargs(args: &Vec<Float>, d: &Bound<PyDict>) -> PyResult<Self> {
            let mut config = AdamConfig::new(args);
            if let Some(alpha) = d.get_item("alpha")? {
                config = config.with_alpha(alpha.extract()?);
            }
            if let Some(beta_1) = d.get_item("beta_1")? {
                config = config.with_beta_1(beta_1.extract()?);
            }
            if let Some(beta_2) = d.get_item("beta_2")? {
                config = config.with_beta_2(beta_2.extract()?);
            }
            if let Some(epsilon) = d.get_item("epsilon")? {
                config = config.with_epsilon(epsilon.extract()?);
            }
            Ok(config)
        }
    }
    impl<P> FromPyArgs for Callbacks<Adam, P, GradientStatus, MaybeThreadPool, LadduError, AdamConfig>
    where
        P: Gradient<MaybeThreadPool, LadduError>,
    {
        fn from_pyargs(_args: &(), d: &Bound<PyDict>) -> PyResult<Self> {
            let mut callbacks = Callbacks::empty();
            let mut term = AdamEMATerminator::default();
            if let Some(beta_c) = d.get_item("beta_c")? {
                term.beta_c = beta_c.extract()?;
            }
            if let Some(eps_loss) = d.get_item("eps_loss")? {
                term.eps_loss = eps_loss.extract()?;
            }
            if let Some(patience) = d.get_item("patience")? {
                term.patience = patience.extract()?;
            }
            callbacks = callbacks.with_terminator(term);
            Ok(callbacks)
        }
    }
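    // Recognized Adam keys (as parsed above): "alpha", "beta_1", "beta_2", and "epsilon"
    // configure the optimizer, while "beta_c", "eps_loss", and "patience" tune the EMA
    // loss terminator.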
    impl FromPyArgs<Vec<Float>> for NelderMeadConfig {
        fn from_pyargs(args: &Vec<Float>, d: &Bound<PyDict>) -> PyResult<Self> {
            let construction_method = SimplexConstructionMethod::from_pyargs(args, d)?;
            let mut config = NelderMeadConfig::new_with_method(construction_method);
            if let Some(alpha) = d.get_item("alpha")? {
                config = config.with_alpha(alpha.extract()?);
            }
            if let Some(beta) = d.get_item("beta")? {
                config = config.with_beta(beta.extract()?);
            }
            if let Some(gamma) = d.get_item("gamma")? {
                config = config.with_gamma(gamma.extract()?);
            }
            if let Some(delta) = d.get_item("delta")? {
                config = config.with_delta(delta.extract()?);
            }
            if let Some(adaptive) = d.get_item("adaptive")? {
                if adaptive.extract()? {
                    config = config.with_adaptive(args.len());
                }
            }
            if let Some(expansion_method) = d.get_item("expansion_method")? {
                match expansion_method
                    .extract::<String>()?
                    .to_lowercase()
                    .trim()
                    .replace("-", "")
                    .replace(" ", "")
                    .as_str()
                {
                    "greedyminimization" => {
                        config = config
                            .with_expansion_method(SimplexExpansionMethod::GreedyMinimization);
                        Ok(())
                    }
                    }
                    "greedyexpansion" => {
                        config = config
                            .with_expansion_method(SimplexExpansionMethod::GreedyExpansion);
                        Ok(())
                    }
                    _ => Err(PyValueError::new_err(format!(
                        "Invalid expansion method: {}",
                        expansion_method
                    ))),
                }?
            }
            Ok(config)
        }
    }
    impl FromPyArgs<Vec<Float>> for SimplexConstructionMethod {
        fn from_pyargs(args: &Vec<Float>, d: &Bound<PyDict>) -> PyResult<Self> {
            if let Some(simplex_construction_method) = d.get_item("simplex_construction_method")? {
                match simplex_construction_method
                    .extract::<String>()?
                    .to_lowercase()
                    .trim()
                    .replace("-", "")
                    .replace(" ", "")
                    .as_str()
                {
                    "scaledorthogonal" => {
                        let orthogonal_multiplier = d
                            .get_item("orthogonal_multiplier")?
                            .map(|v| v.extract())
                            .transpose()?
                            .unwrap_or(1.05);
                        let orthogonal_zero_step = d
                            .get_item("orthogonal_zero_step")?
                            .map(|v| v.extract())
                            .transpose()?
                            .unwrap_or(0.00025);
                        return Ok(SimplexConstructionMethod::custom_scaled_orthogonal(
                            args,
                            orthogonal_multiplier,
                            orthogonal_zero_step,
                        ));
                    }
                    "orthogonal" => {
                        let simplex_size = d
                            .get_item("simplex_size")?
                            .map(|v| v.extract())
                            .transpose()?
                            .unwrap_or(1.0);
                        return Ok(SimplexConstructionMethod::custom_orthogonal(
                            args,
                            simplex_size,
                        ));
                    }
                    "custom" => {
                        if let Some(other_simplex_points) = d.get_item("simplex")? {
                            let mut simplex = Vec::with_capacity(args.len() + 1);
                            simplex.push(DVector::from_vec(args.clone()));
                            let others = other_simplex_points.extract::<Vec<Vec<Float>>>()?; // TODO: numpy arrays
                            if others.len() != args.len() {
                                return Err(PyValueError::new_err(format!(
                                    "Expected {} additional simplex points, got {}.",
                                    args.len(),
                                    others.len()
                                )));
                            }
                            simplex.extend(others.iter().map(|x| DVector::from_vec(x.clone())));
                            return Ok(SimplexConstructionMethod::custom(simplex));
                        } else {
                            return Err(PyValueError::new_err("Simplex must be specified when using the 'custom' simplex_construction_method."));
                        }
                    }
                    _ => {
                        return Err(PyValueError::new_err(format!(
                            "Invalid simplex_construction_method: {}",
                            simplex_construction_method
                        )))
                    }
                }
            } else {
                Ok(SimplexConstructionMethod::scaled_orthogonal(args))
            }
        }
    }
    impl<P> FromPyArgs
        for Callbacks<
            NelderMead,
            P,
            GradientFreeStatus,
            MaybeThreadPool,
            LadduError,
            NelderMeadConfig,
        >
    where
        P: CostFunction<MaybeThreadPool, LadduError>,
    {
        fn from_pyargs(_args: &(), d: &Bound<PyDict>) -> PyResult<Self> {
            let mut callbacks = Callbacks::empty();
            let eps_f = if let Some(eps_f) = d.get_item("eps_f")? {
                eps_f.extract()?
            } else {
                Float::EPSILON.powf(0.25)
            };
            if let Some(f_term) = d.get_item("f_terminator")? {
                match f_term
                    .extract::<String>()?
                    .to_lowercase()
                    .trim()
                    .replace("-", "")
                    .replace(" ", "")
                    .as_str()
                {
                    "amoeba" => {
                        callbacks = callbacks
                            .with_terminator(NelderMeadFTerminator::Amoeba { eps_rel: eps_f });
                    }
                    "absolute" => {
                        callbacks = callbacks
                            .with_terminator(NelderMeadFTerminator::Absolute { eps_abs: eps_f });
                    }
                    "stddev" => {
                        callbacks = callbacks
                            .with_terminator(NelderMeadFTerminator::StdDev { eps_abs: eps_f });
                    }
                    _ => Err(PyValueError::new_err(format!(
                        "Invalid f_terminator: {}",
                        f_term
                    )))?,
                }
            } else {
                callbacks =
                    callbacks.with_terminator(NelderMeadFTerminator::StdDev { eps_abs: eps_f });
            }
            let eps_x = if let Some(eps_x) = d.get_item("eps_x")? {
                eps_x.extract()?
            } else {
                Float::EPSILON.powf(0.25)
            };
            if let Some(x_term) = d.get_item("x_terminator")? {
                match x_term
                    .extract::<String>()?
                    .to_lowercase()
                    .trim()
                    .replace("-", "")
                    .replace(" ", "")
                    .as_str()
                {
                    "diameter" => {
                        callbacks = callbacks
                            .with_terminator(NelderMeadXTerminator::Diameter { eps_abs: eps_x });
                    }
                    "higham" => {
                        callbacks = callbacks
                            .with_terminator(NelderMeadXTerminator::Higham { eps_rel: eps_x });
                    }
                    "rowan" => {
                        callbacks = callbacks
                            .with_terminator(NelderMeadXTerminator::Rowan { eps_rel: eps_x });
                    }
                    "singer" => {
                        callbacks = callbacks
                            .with_terminator(NelderMeadXTerminator::Singer { eps_rel: eps_x });
                    }
                    _ => Err(PyValueError::new_err(format!(
                        "Invalid x_terminator: {}",
                        x_term
                    )))?,
                }
            } else {
                callbacks =
                    callbacks.with_terminator(NelderMeadXTerminator::Singer { eps_rel: eps_x });
            }
            Ok(callbacks)
        }
    }
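    // Recognized Nelder-Mead keys (as parsed in the three impls above): the coefficients
    // "alpha", "beta", "gamma", "delta" and the "adaptive" flag; "expansion_method"
    // ("greedy-minimization" or "greedy-expansion"); "simplex_construction_method"
    // ("scaled-orthogonal", "orthogonal", or "custom" with a "simplex" list) together
    // with "orthogonal_multiplier", "orthogonal_zero_step", or "simplex_size"; and the
    // terminator settings "eps_f"/"f_terminator" and "eps_x"/"x_terminator".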
    impl FromPyArgs<Vec<Float>> for SwarmPositionInitializer {
        fn from_pyargs(args: &Vec<Float>, d: &Bound<PyDict>) -> PyResult<Self> {
            if let Some(swarm_position_initializer) = d.get_item("swarm_position_initializer")? {
                match swarm_position_initializer
                    .extract::<String>()?
                    .to_lowercase()
                    .trim()
                    .replace("-", "")
                    .replace(" ", "")
                    .as_str()
                {
                    "randominlimits" => {
                        if let (Some(swarm_position_bounds), Some(swarm_size)) = (
                            d.get_item("swarm_position_bounds")?,
                            d.get_item("swarm_size")?,
                        ) {
                            return Ok(SwarmPositionInitializer::RandomInLimits {
                                bounds: swarm_position_bounds.extract()?,
                                n_particles: swarm_size.extract()?,
                            });
                        } else {
                            return Err(PyValueError::new_err("The swarm_position_bounds and swarm_size must be specified when using the 'randominlimits' swarm_position_initializer."));
                        }
                    }
                    "latinhypercube" => {
                        if let (Some(swarm_position_bounds), Some(swarm_size)) = (
                            d.get_item("swarm_position_bounds")?,
                            d.get_item("swarm_size")?,
                        ) {
                            return Ok(SwarmPositionInitializer::LatinHypercube {
                                bounds: swarm_position_bounds.extract()?,
                                n_particles: swarm_size.extract()?,
                            });
                        } else {
                            return Err(PyValueError::new_err("The swarm_position_bounds and swarm_size must be specified when using the 'latinhypercube' swarm_position_initializer."));
                        }
                    }
                    "custom" => {
                        if let Some(swarm) = d.get_item("swarm")? {
                            return Ok(SwarmPositionInitializer::Custom(
                                swarm
                                    .extract::<Vec<Vec<Float>>>()?
                                    .iter()
                                    .chain(vec![args].into_iter())
                                    .map(|x| DVector::from_vec(x.clone()))
                                    .collect(), // TODO: numpy arrays?
                            ));
                        } else {
                            return Err(PyValueError::new_err("The swarm must be specified when using the 'custom' swarm_position_initializer."));
                        }
                    }
                    _ => {
                        return Err(PyValueError::new_err(format!(
                            "Invalid swarm_position_initializer: {}",
                            swarm_position_initializer
                        )));
                    }
                }
            } else {
                return Err(PyValueError::new_err(
                    "The swarm_position_initializer must be specified for the PSO algorithm.",
                ));
            }
        }
    }
    impl FromPyArgs<Vec<Float>> for Swarm {
        fn from_pyargs(args: &Vec<Float>, d: &Bound<PyDict>) -> PyResult<Self> {
            let swarm_position_initializer = SwarmPositionInitializer::from_pyargs(args, d)?;
            let mut swarm = Swarm::new(swarm_position_initializer);
            if let Some(swarm_topology_str) = d.get_item("swarm_topology")? {
                match swarm_topology_str
                    .extract::<String>()?
                    .to_lowercase()
                    .trim()
                    .replace("-", "")
                    .replace(" ", "")
                    .as_str()
                {
                    "global" => {
                        swarm = swarm.with_topology(SwarmTopology::Global);
                    }
                    "ring" => {
                        swarm = swarm.with_topology(SwarmTopology::Ring);
                    }
                    _ => {
                        return Err(PyValueError::new_err(format!(
                            "Invalid swarm_topology: {}",
                            swarm_topology_str
                        )))
                    }
                }
            }
            if let Some(swarm_update_method_str) = d.get_item("swarm_update_method")? {
                match swarm_update_method_str
                    .extract::<String>()?
                    .to_lowercase()
                    .trim()
                    .replace("-", "")
                    .replace(" ", "")
                    .as_str()
                {
                    "sync" | "synchronous" => {
                        swarm = swarm.with_update_method(SwarmUpdateMethod::Synchronous);
                    }
                    "async" | "asynchronous" => {
                        swarm = swarm.with_update_method(SwarmUpdateMethod::Asynchronous);
                    }
                    _ => {
                        return Err(PyValueError::new_err(format!(
                            "Invalid swarm_update_method: {}",
                            swarm_update_method_str
                        )))
                    }
                }
            }
            if let Some(swarm_boundary_method_str) = d.get_item("swarm_boundary_method")? {
                match swarm_boundary_method_str
                    .extract::<String>()?
                    .to_lowercase()
                    .trim()
                    .replace("-", "")
                    .replace(" ", "")
                    .as_str()
                {
                    "inf" => {
                        swarm = swarm.with_boundary_method(SwarmBoundaryMethod::Inf);
                    }
                    "shr" => {
                        swarm = swarm.with_boundary_method(SwarmBoundaryMethod::Shr);
                    }
                    _ => {
                        return Err(PyValueError::new_err(format!(
                            "Invalid swarm_boundary_method: {}",
                            swarm_boundary_method_str
                        )))
                    }
                }
            }
            if let Some(swarm_velocity_bounds) = d.get_item("swarm_velocity_bounds")? {
                swarm = swarm.with_velocity_initializer(SwarmVelocityInitializer::RandomInLimits(
                    swarm_velocity_bounds.extract()?,
                ));
            }
            Ok(swarm)
        }
    }
    impl FromPyArgs<Vec<Float>> for PSOConfig {
        fn from_pyargs(args: &Vec<Float>, d: &Bound<PyDict>) -> PyResult<Self> {
            let swarm = Swarm::from_pyargs(args, d)?;
            let mut config = PSOConfig::new(swarm);
            if let Some(omega) = d.get_item("omega")? {
                config = config.with_omega(omega.extract()?);
            }
            if let Some(c1) = d.get_item("c1")? {
                config = config.with_c1(c1.extract()?);
            }
            if let Some(c2) = d.get_item("c2")? {
                config = config.with_c2(c2.extract()?);
            }
            Ok(config)
        }
    }
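    // Recognized PSO keys (as parsed in the three impls above): "swarm_position_initializer"
    // ("random-in-limits", "latin-hypercube", or "custom" with a "swarm" list) plus
    // "swarm_position_bounds" and "swarm_size" where required; "swarm_topology"
    // ("global"/"ring"), "swarm_update_method" ("sync"/"async"), "swarm_boundary_method"
    // ("inf"/"shr"), "swarm_velocity_bounds"; and the hyperparameters "omega", "c1", "c2".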
    impl<P> FromPyArgs<Vec<Float>> for MinimizationSettings<P>
    where
        P: Gradient<MaybeThreadPool, LadduError>,
    {
        fn from_pyargs(args: &Vec<Float>, d: &Bound<PyDict>) -> PyResult<Self> {
            let bounds: Option<Vec<ganesh::traits::boundlike::Bound>> = d
                .get_item("bounds")?
                .map(|bounds| bounds.extract::<Vec<(Option<Float>, Option<Float>)>>())
                .transpose()?
                .map(|bounds| {
                    bounds
                        .into_iter()
                        .map(ganesh::traits::boundlike::Bound::from)
                        .collect()
                });
            let num_threads = d
                .get_item("threads")?
                .map(|t| t.extract())
                .transpose()?
                .unwrap_or(0);
            let add_debug = d
                .get_item("debug")?
                .map(|d| d.extract())
                .transpose()?
                .unwrap_or(false);
            let observers = if let Some(observers) = d.get_item("observers")? {
                if let Ok(observers) = observers.downcast::<PyList>() {
                    observers.into_iter().map(|observer| {
                        if let Ok(observer) = observer.extract::<MinimizationObserver>() {
                            Ok(observer)
                        } else {
                            Err(PyValueError::new_err("The observers must be either a single MinimizationObserver or a list of MinimizationObservers."))
                        }
                    }).collect::<PyResult<Vec<MinimizationObserver>>>()?
                } else if let Ok(observer) = observers.extract::<MinimizationObserver>() {
                    vec![observer]
                } else {
                    return Err(PyValueError::new_err("The observers must be either a single MinimizationObserver or a list of MinimizationObservers."));
                }
            } else {
                vec![]
            };
            let terminators = if let Some(terminators) = d.get_item("terminators")? {
                if let Ok(terminators) = terminators.downcast::<PyList>() {
                    terminators.into_iter().map(|terminator| {
                        if let Ok(terminator) = terminator.extract::<MinimizationTerminator>() {
                            Ok(terminator)
                        } else {
                            Err(PyValueError::new_err("The terminators must be either a single MinimizationTerminator or a list of MinimizationTerminators."))
                        }
                    }).collect::<PyResult<Vec<MinimizationTerminator>>>()?
                } else if let Ok(terminator) = terminators.extract::<MinimizationTerminator>() {
                    vec![terminator]
                } else {
                    return Err(PyValueError::new_err("The terminators must be either a single MinimizationTerminator or a list of MinimizationTerminators."));
                }
            } else {
                vec![]
            };
            let max_steps: Option<usize> = d
                .get_item("max_steps")?
                .map(|ms| ms.extract())
                .transpose()?;
            let settings: Bound<PyDict> = d
                .get_item("settings")?
                .map(|settings| settings.extract())
                .transpose()?
                .unwrap_or_else(|| PyDict::new(d.py()));
            if let Some(method) = d.get_item("method")? {
                match method
                    .extract::<String>()?
                    .to_lowercase()
                    .trim()
                    .replace("-", "")
                    .replace(" ", "")
                    .as_str()
                {
                    "lbfgsb" => {
                        let mut config = LBFGSBConfig::from_pyargs(args, &settings)?;
                        if let Some(bounds) = bounds {
                            config = config.with_bounds(bounds);
                        }
                        let mut callbacks = Callbacks::from_pyargs(&(), &settings)?;
                        if add_debug {
                            callbacks = callbacks.with_observer(DebugObserver);
                        }
                        if let Some(max_steps) = max_steps {
                            callbacks = callbacks.with_terminator(MaxSteps(max_steps));
                        }
                        for observer in observers {
                            callbacks = callbacks.with_observer(observer);
                        }
                        for terminator in terminators {
                            callbacks = callbacks.with_terminator(terminator);
                        }
                        callbacks = callbacks.with_terminator(CtrlCAbortSignal::new());
                        Ok(MinimizationSettings::LBFGSB {
                            config,
                            callbacks,
                            num_threads,
                        })
                    }
                    "adam" => {
                        let mut config = AdamConfig::from_pyargs(args, &settings)?;
                        if let Some(bounds) = bounds {
                            config = config.with_transform(&Bounds::from(bounds))
                        }
                        let mut callbacks = Callbacks::from_pyargs(&(), &settings)?;
                        if add_debug {
                            callbacks = callbacks.with_observer(DebugObserver);
                        }
                        if let Some(max_steps) = max_steps {
                            callbacks = callbacks.with_terminator(MaxSteps(max_steps));
                        }
                        for observer in observers {
                            callbacks = callbacks.with_observer(observer);
                        }
                        for terminator in terminators {
                            callbacks = callbacks.with_terminator(terminator);
                        }
                        callbacks = callbacks.with_terminator(CtrlCAbortSignal::new());
                        Ok(MinimizationSettings::Adam {
                            config,
                            callbacks,
                            num_threads,
                        })
                    }
                    "neldermead" => {
                        let mut config = NelderMeadConfig::from_pyargs(args, &settings)?;
                        if let Some(bounds) = bounds {
                            config = config.with_bounds(bounds);
                        }
                        let mut callbacks = Callbacks::from_pyargs(&(), &settings)?;
                        if add_debug {
                            callbacks = callbacks.with_observer(DebugObserver);
                        }
                        if let Some(max_steps) = max_steps {
                            callbacks = callbacks.with_terminator(MaxSteps(max_steps));
                        }
                        for observer in observers {
                            callbacks = callbacks.with_observer(observer);
                        }
                        for terminator in terminators {
                            callbacks = callbacks.with_terminator(terminator);
                        }
                        callbacks = callbacks.with_terminator(CtrlCAbortSignal::new());
                        Ok(MinimizationSettings::NelderMead {
                            config,
                            callbacks,
                            num_threads,
                        })
                    }
                    "pso" => {
                        let mut config = PSOConfig::from_pyargs(args, &settings)?;
                        if let Some(bounds) = bounds {
                            if let Some(use_transform) = settings.get_item("use_transform")? {
                                if use_transform.extract()? {
                                    config = config.with_transform(&Bounds::from(bounds))
                                } else {
                                    config = config.with_bounds(bounds)
                                }
                            } else {
                                config = config.with_bounds(bounds)
                            }
                        }
                        let mut callbacks = Callbacks::empty();
                        if add_debug {
                            return Err(PyValueError::new_err(
                                "The debug setting is not yet supported for PSO",
                            ));
                            // callbacks = callbacks.with_observer(DebugObserver);
                            // TODO: SwarmStatus needs to impl Debug
                        }
                        if let Some(max_steps) = max_steps {
                            callbacks = callbacks.with_terminator(MaxSteps(max_steps));
                        }
                        for observer in observers {
                            callbacks = callbacks.with_observer(observer);
                        }
                        for terminator in terminators {
                            callbacks = callbacks.with_terminator(terminator);
                        }
                        callbacks = callbacks.with_terminator(CtrlCAbortSignal::new());
                        Ok(MinimizationSettings::PSO {
                            config,
                            callbacks,
1407                            num_threads,
1408                        })
1409                    }
1410                    _ => Err(PyValueError::new_err(format!(
1411                        "Invalid minimizer: {}",
1412                        method
1413                    ))),
1414                }
1415            } else {
1416                Err(PyValueError::new_err("No method specified"))
1417            }
1418        }
1419    }
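    // For reference, a sketch of the Python-side options dictionary parsed above. Only keys
    // that appear in the `get_item` calls are shown; the values are illustrative, not
    // exhaustive. Because the method string is lowercased and stripped of "-" and " ",
    // "L-BFGS-B", "lbfgsb", and "Nelder Mead" are all accepted names:
    //
    //     options = {
    //         "method": "L-BFGS-B",                 # or "adam", "nelder-mead", "pso"
    //         "max_steps": 1000,                    # wrapped in a MaxSteps terminator
    //         "settings": {"use_transform": True},  # algorithm-specific; "use_transform" is PSO-only
    //     }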
1420
1421    impl FromPyArgs<Vec<DVector<Float>>> for AIESConfig {
1422        fn from_pyargs(args: &Vec<DVector<Float>>, d: &Bound<PyDict>) -> PyResult<Self> {
1423            let mut config = AIESConfig::new(args.to_vec());
1424            if let Some(moves) = d.get_item("moves")? {
1425                let moves_list = moves.downcast::<PyList>()?;
1426                let mut aies_moves = vec![];
1427                for mcmc_move in moves_list {
1428                    if let Ok(default_move) = mcmc_move.extract::<(String, Float)>() {
1429                        match default_move
1430                            .0
1431                            .to_lowercase()
1432                            .trim()
1433                            .replace("-", "")
1434                            .replace(" ", "")
1435                            .as_str()
1436                        {
1437                            "stretch" => aies_moves.push(AIESMove::stretch(default_move.1)),
1438                            "walk" => aies_moves.push(AIESMove::walk(default_move.1)),
1439                            _ => {
1440                                return Err(PyValueError::new_err(format!(
1441                                    "Invalid AIES move: {}",
1442                                    default_move.0
1443                                )))
1444                            }
1445                        }
1446                    } else if let Ok(custom_move) =
1447                        mcmc_move.extract::<(String, Bound<PyDict>, Float)>()
1448                    {
1449                        match custom_move
1450                            .0
1451                            .to_lowercase()
1452                            .trim()
1453                            .replace("-", "")
1454                            .replace(" ", "")
1455                            .as_str()
1456                        {
1457                            "stretch" => aies_moves.push((
1458                                AIESMove::Stretch {
1459                                    a: custom_move
1460                                        .1
1461                                        .get_item("a")?
1462                                        .map(|val| val.extract())
1463                                        .transpose()?
1464                                        .unwrap_or(2.0),
1465                                },
1466                                custom_move.2,
1467                            )),
1468                            "walk" => aies_moves.push(AIESMove::walk(custom_move.2)),
1469                            _ => {
1470                                return Err(PyValueError::new_err(format!(
1471                                    "Invalid AIES move: {}",
1472                                    custom_move.0
1473                                )))
1474                            }
1475                        }
1476                    } else {
1477                        return Err(PyValueError::new_err("The 'moves' argument must be a list of (str, float) or (str, dict, float) tuples!"));
1478                    }
1479                }
1480                config = config.with_moves(aies_moves);
1481            }
1482            Ok(config)
1483        }
1484    }
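    // A sketch of the "moves" entry parsed above: a Python list of (name, weight) or
    // (name, options_dict, weight) tuples, where the name is "stretch" or "walk"
    // (case- and separator-insensitive) and the stretch scale `a` defaults to 2.0 when omitted:
    //
    //     moves = [("stretch", 0.7), ("walk", 0.3)]
    //     moves = [("stretch", {"a": 2.5}, 1.0)]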
1485
1486    impl FromPyArgs<Vec<DVector<Float>>> for ESSConfig {
1487        fn from_pyargs(args: &Vec<DVector<Float>>, d: &Bound<PyDict>) -> PyResult<Self> {
1488            let mut config = ESSConfig::new(args.to_vec());
1489            if let Some(moves) = d.get_item("moves")? {
1490                let moves_list = moves.downcast::<PyList>()?;
1491                let mut ess_moves = vec![];
1492                for mcmc_move in moves_list {
1493                    if let Ok(default_move) = mcmc_move.extract::<(String, Float)>() {
1494                        match default_move
1495                            .0
1496                            .to_lowercase()
1497                            .trim()
1498                            .replace("-", "")
1499                            .replace(" ", "")
1500                            .as_str()
1501                        {
1502                            "differential" => ess_moves.push(ESSMove::differential(default_move.1)),
1503                            "gaussian" => ess_moves.push(ESSMove::gaussian(default_move.1)),
1504                            "global" => {
1505                                ess_moves.push(ESSMove::global(default_move.1, None, None, None))
1506                            }
1507                            _ => {
1508                                return Err(PyValueError::new_err(format!(
1509                                    "Invalid ESS move: {}",
1510                                    default_move.0
1511                                )))
1512                            }
1513                        }
1514                    } else if let Ok(custom_move) =
1515                        mcmc_move.extract::<(String, Bound<PyDict>, Float)>()
1516                    {
1517                        match custom_move
1518                            .0
1519                            .to_lowercase()
1520                            .trim()
1521                            .replace("-", "")
1522                            .replace(" ", "")
1523                            .as_str()
1524                        {
1525                            "differential" => ess_moves.push(ESSMove::differential(custom_move.2)),
1526                            "gaussian" => ess_moves.push(ESSMove::gaussian(custom_move.2)),
1527                            "global" => ess_moves.push(ESSMove::global(
1528                                custom_move.2,
1529                                custom_move
1530                                    .1
1531                                    .get_item("scale")?
1532                                    .map(|value| value.extract())
1533                                    .transpose()?,
1534                                custom_move
1535                                    .1
1536                                    .get_item("rescale_cov")?
1537                                    .map(|value| value.extract())
1538                                    .transpose()?,
1539                                custom_move
1540                                    .1
1541                                    .get_item("n_components")?
1542                                    .map(|value| value.extract())
1543                                    .transpose()?,
1544                            )),
1545                            _ => {
1546                                return Err(PyValueError::new_err(format!(
1547                                    "Invalid ESS move: {}",
1548                                    custom_move.0
1549                                )))
1550                            }
1551                        }
1552                    } else {
1553                        return Err(PyValueError::new_err("The 'moves' argument must be a list of (str, float) or (str, dict, float) tuples!"));
1554                    }
1555                }
1556                config = config.with_moves(ess_moves)
1557            }
1558            if let Some(n_adaptive) = d.get_item("n_adaptive")? {
1559                config = config.with_n_adaptive(n_adaptive.extract()?);
1560            }
1561            if let Some(mu) = d.get_item("mu")? {
1562                config = config.with_mu(mu.extract()?);
1563            }
1564            if let Some(max_steps) = d.get_item("max_steps")? {
1565                config = config.with_max_steps(max_steps.extract()?);
1566            }
1567            Ok(config)
1568        }
1569    }
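    // A sketch of the ESS options parsed above. The "moves" list accepts "differential",
    // "gaussian", or "global" entries, and "global" optionally takes a dict with "scale",
    // "rescale_cov", and "n_components"; the top-level keys "n_adaptive", "mu", and
    // "max_steps" are forwarded to the configuration. Values here are illustrative only:
    //
    //     settings = {
    //         "moves": [("differential", 0.8), ("global", {"n_components": 3}, 0.2)],
    //         "n_adaptive": 100,
    //     }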
1570
1571    impl<P> FromPyArgs<Vec<DVector<Float>>> for MCMCSettings<P>
1572    where
1573        P: LogDensity<MaybeThreadPool, LadduError>,
1574    {
1575        fn from_pyargs(args: &Vec<DVector<Float>>, d: &Bound<PyDict>) -> PyResult<Self> {
1576            let bounds: Option<Vec<ganesh::traits::boundlike::Bound>> = d
1577                .get_item("bounds")?
1578                .map(|bounds| bounds.extract::<Vec<(Option<Float>, Option<Float>)>>())
1579                .transpose()?
1580                .map(|bounds| {
1581                    bounds
1582                        .into_iter()
1583                        .map(ganesh::traits::boundlike::Bound::from)
1584                        .collect()
1585                });
1586            let num_threads = d
1587                .get_item("threads")?
1588                .map(|t| t.extract())
1589                .transpose()?
1590                .unwrap_or(0);
1591            let add_debug = d
1592                .get_item("debug")?
1593                .map(|d| d.extract())
1594                .transpose()?
1595                .unwrap_or(false);
1596            let observers = if let Some(observers) = d.get_item("observers")? {
1597                if let Ok(observers) = observers.downcast::<PyList>() {
1598                    observers.into_iter().map(|observer| {
1599                        if let Ok(observer) = observer.extract::<MCMCObserver>() {
1600                            Ok(observer)
1601                        } else {
1602                            Err(PyValueError::new_err("The observers must be either a single MCMCObserver or a list of MCMCObservers."))
1603                        }
1604                    }).collect::<PyResult<Vec<MCMCObserver>>>()?
1605                } else if let Ok(observer) = observers.extract::<MCMCObserver>() {
1606                    vec![observer]
1607                } else {
1608                    return Err(PyValueError::new_err("The observers must be either a single MCMCObserver or a list of MCMCObservers."));
1609                }
1610            } else {
1611                vec![]
1612            };
1613            let terminators = if let Some(terminators) = d.get_item("terminators")? {
1614                if let Ok(terminators) = terminators.downcast::<PyList>() {
1615                    terminators
1616                        .into_iter()
1617                        .map(|terminator| {
1618                            if let Ok(terminator) =
1619                                terminator.extract::<PyAutocorrelationTerminator>()
1620                            {
1621                                Ok(PythonMCMCTerminator::Autocorrelation(terminator))
1622                            }
1623                            else if let Ok(terminator) = terminator.extract::<MCMCTerminator>() {
1624                                Ok(PythonMCMCTerminator::UserDefined(terminator))
1625                            } else {
1626                                Err(PyValueError::new_err("The terminators must be either a single MCMCTerminator or a list of MCMCTerminators."))
1627                            }
1628                        })
1629                        .collect::<PyResult<Vec<PythonMCMCTerminator>>>()?
1630                } else if let Ok(terminator) = terminators.extract::<PyAutocorrelationTerminator>()
1631                {
1632                    vec![PythonMCMCTerminator::Autocorrelation(terminator)]
1633                } else if let Ok(terminator) = terminators.extract::<MCMCTerminator>() {
1634                    vec![PythonMCMCTerminator::UserDefined(terminator)]
1635                } else {
1636                    return Err(PyValueError::new_err("The terminators must be either a single MCMCTerminator or a list of MCMCTerminators."));
1637                }
1638            } else {
1639                vec![]
1640            };
1641            let max_steps: Option<usize> = d
1642                .get_item("max_steps")?
1643                .map(|ms| ms.extract())
1644                .transpose()?;
1645            let settings: Bound<PyDict> = d
1646                .get_item("settings")?
1647                .map(|settings| settings.extract())
1648                .transpose()?
1649                .unwrap_or_else(|| PyDict::new(d.py()));
1650            if let Some(method) = d.get_item("method")? {
1651                match method
1652                    .extract::<String>()?
1653                    .to_lowercase()
1654                    .trim()
1655                    .replace("-", "")
1656                    .replace(" ", "")
1657                    .as_str()
1658                {
1659                    "aies" => {
1660                        let mut config = AIESConfig::from_pyargs(args, &settings)?;
1661                        if let Some(bounds) = bounds {
1662                            config = config.with_transform(&Bounds::from(bounds))
1663                        }
1664                        let mut callbacks = Callbacks::empty();
1665                        if add_debug {
1666                            callbacks = callbacks.with_observer(DebugObserver);
1667                        }
1668                        if let Some(max_steps) = max_steps {
1669                            callbacks = callbacks.with_terminator(MaxSteps(max_steps));
1670                        }
1671                        for observer in observers {
1672                            callbacks = callbacks.with_observer(observer);
1673                        }
1674                        for terminator in terminators {
1675                            callbacks = callbacks.with_terminator(terminator);
1676                        }
1677                        callbacks = callbacks.with_terminator(CtrlCAbortSignal::new());
1678                        Ok(MCMCSettings::AIES {
1679                            config,
1680                            callbacks,
1681                            num_threads,
1682                        })
1683                    }
1684                    "ess" => {
1685                        let mut config = ESSConfig::from_pyargs(args, &settings)?;
1686                        if let Some(bounds) = bounds {
1687                            config = config.with_transform(&Bounds::from(bounds))
1688                        }
1689                        let mut callbacks = Callbacks::empty();
1690                        if add_debug {
1691                            callbacks = callbacks.with_observer(DebugObserver);
1692                        }
1693                        if let Some(max_steps) = max_steps {
1694                            callbacks = callbacks.with_terminator(MaxSteps(max_steps));
1695                        }
1696                        for observer in observers {
1697                            callbacks = callbacks.with_observer(observer);
1698                        }
1699                        for terminator in terminators {
1700                            callbacks = callbacks.with_terminator(terminator);
1701                        }
1702                        callbacks = callbacks.with_terminator(CtrlCAbortSignal::new());
1703                        Ok(MCMCSettings::ESS {
1704                            config,
1705                            callbacks,
1706                            num_threads,
1707                        })
1708                    }
1709                    _ => Err(PyValueError::new_err(format!(
1710                        "Invalid MCMC algorithm: {}",
1711                        method
1712                    ))),
1713                }
1714            } else {
1715                Err(PyValueError::new_err("No method specified"))
1716            }
1717        }
1718    }
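    // A sketch of the full MCMC options dictionary parsed above (key names taken from the
    // `get_item` calls; the observer/terminator values are illustrative placeholders):
    //
    //     mcmc_options = {
    //         "method": "AIES",                     # or "ESS"
    //         "bounds": [(0.0, 1.0), (None, None)], # applied through a bounds transform
    //         "threads": 4,                         # defaults to 0 when omitted
    //         "debug": False,
    //         "max_steps": 1000,
    //         "observers": [my_observer],           # a single MCMCObserver or a list of them
    //         "terminators": [my_terminator],       # MCMCTerminator(s) or an autocorrelation terminator
    //         "settings": {"moves": [("stretch", 1.0)]},
    //     }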
1719
1720    enum MinimizationStatus {
1721        GradientStatus(Arc<Mutex<GradientStatus>>),
1722        GradientFreeStatus(Arc<Mutex<GradientFreeStatus>>),
1723        SwarmStatus(Arc<Mutex<SwarmStatus>>),
1724    }
1725    impl MinimizationStatus {
1726        fn x<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<Float>> {
1727            match self {
1728                Self::GradientStatus(gradient_status) => {
1729                    gradient_status.lock().x.as_slice().to_pyarray(py)
1730                }
1731                Self::GradientFreeStatus(gradient_free_status) => {
1732                    gradient_free_status.lock().x.as_slice().to_pyarray(py)
1733                }
1734                Self::SwarmStatus(swarm_status) => {
1735                    swarm_status.lock().gbest.x.as_slice().to_pyarray(py)
1736                }
1737            }
1738        }
1739        fn fx(&self) -> Float {
1740            match self {
1741                Self::GradientStatus(gradient_status) => gradient_status.lock().fx,
1742                Self::GradientFreeStatus(gradient_free_status) => gradient_free_status.lock().fx,
1743                Self::SwarmStatus(swarm_status) => swarm_status.lock().gbest.fx.unwrap(),
1744            }
1745        }
1746        fn message(&self) -> String {
1747            match self {
1748                Self::GradientStatus(gradient_status) => {
1749                    gradient_status.lock().message().to_string()
1750                }
1751                Self::GradientFreeStatus(gradient_free_status) => {
1752                    gradient_free_status.lock().message().to_string()
1753                }
1754                Self::SwarmStatus(swarm_status) => swarm_status.lock().message().to_string(),
1755            }
1756        }
1757        fn err<'py>(&self, py: Python<'py>) -> Option<Bound<'py, PyArray1<Float>>> {
1758            match self {
1759                Self::GradientStatus(gradient_status) => gradient_status.lock().err.clone(),
1760                Self::GradientFreeStatus(gradient_free_status) => {
1761                    gradient_free_status.lock().err.clone()
1762                }
1763                Self::SwarmStatus(_) => None,
1764            }
1765            .map(|e| e.as_slice().to_pyarray(py))
1766        }
1767        fn n_f_evals(&self) -> usize {
1768            match self {
1769                Self::GradientStatus(gradient_status) => gradient_status.lock().n_f_evals,
1770                Self::GradientFreeStatus(gradient_free_status) => {
1771                    gradient_free_status.lock().n_f_evals
1772                }
1773                Self::SwarmStatus(swarm_status) => swarm_status.lock().n_f_evals,
1774            }
1775        }
1776        fn n_g_evals(&self) -> usize {
1777            match self {
1778                Self::GradientStatus(gradient_status) => gradient_status.lock().n_g_evals,
1779                Self::GradientFreeStatus(_) => 0,
1780                Self::SwarmStatus(_) => 0,
1781            }
1782        }
1783        fn cov<'py>(&self, py: Python<'py>) -> Option<Bound<'py, PyArray2<Float>>> {
1784            match self {
1785                Self::GradientStatus(gradient_status) => gradient_status.lock().cov.clone(),
1786                Self::GradientFreeStatus(gradient_free_status) => {
1787                    gradient_free_status.lock().cov.clone()
1788                }
1789                Self::SwarmStatus(_) => None,
1790            }
1791            .map(|cov| cov.to_pyarray(py))
1792        }
1793        fn hess<'py>(&self, py: Python<'py>) -> Option<Bound<'py, PyArray2<Float>>> {
1794            match self {
1795                Self::GradientStatus(gradient_status) => gradient_status.lock().hess.clone(),
1796                Self::GradientFreeStatus(gradient_free_status) => {
1797                    gradient_free_status.lock().hess.clone()
1798                }
1799                Self::SwarmStatus(_) => None,
1800            }
1801            .map(|hess| hess.to_pyarray(py))
1802        }
1803        fn converged(&self) -> bool {
1804            match self {
1805                Self::GradientStatus(gradient_status) => gradient_status.lock().converged(),
1806                Self::GradientFreeStatus(gradient_free_status) => {
1807                    gradient_free_status.lock().converged()
1808                }
1809                Self::SwarmStatus(swarm_status) => swarm_status.lock().converged(),
1810            }
1811        }
1812        fn swarm(&self) -> Option<PySwarm> {
1813            match self {
1814                Self::GradientStatus(_) | Self::GradientFreeStatus(_) => None,
1815                Self::SwarmStatus(swarm_status) => Some(PySwarm(swarm_status.lock().swarm.clone())),
1816            }
1817        }
1818    }
1819
1820    /// A swarm of particles used in particle swarm optimization.
1821    ///
1822    #[pyclass(name = "Swarm", module = "laddu")]
1823    pub struct PySwarm(Swarm);
1824
1825    #[pymethods]
1826    impl PySwarm {
1827        /// The particles in the swarm.
1828        ///
1829        /// Returns
1830        /// -------
1831        /// list of SwarmParticle
1832        ///
1833        #[getter]
1834        fn particles(&self) -> Vec<PySwarmParticle> {
1835            self.0
1836                .get_particles()
1837                .into_iter()
1838                .map(PySwarmParticle)
1839                .collect()
1840        }
1841    }
1842
1843    /// A particle in a swarm used in particle swarm optimization.
1844    ///
1845    #[pyclass(name = "SwarmParticle", module = "laddu")]
1846    pub struct PySwarmParticle(SwarmParticle);
1847
1848    #[pymethods]
1849    impl PySwarmParticle {
1850        /// The position of the particle.
1851        ///
1852        /// Returns
1853        /// -------
1854        /// array_like
1855        ///
1856        #[getter]
1857        fn x<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<Float>> {
1858            self.0.position.x.as_slice().to_pyarray(py)
1859        }
1860        /// The evaluation of the objective function at the particle's position.
1861        ///
1862        /// Returns
1863        /// -------
1864        /// float
1865        ///
1866        #[getter]
1867        fn fx(&self) -> Float {
1868            self.0.position.fx.unwrap()
1869        }
1870        /// The best position found by the particle.
1871        ///
1872        /// Returns
1873        /// -------
1874        /// array_like
1875        ///
1876        #[getter]
1877        fn x_best<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<Float>> {
1878            self.0.best.x.as_slice().to_pyarray(py)
1879        }
1880        /// The evaluation of the objective function at the particle's best position.
1881        ///
1882        /// Returns
1883        /// -------
1884        /// float
1885        ///
1886        #[getter]
1887        fn fx_best(&self) -> Float {
1888            self.0.best.fx.unwrap()
1889        }
1890        /// The velocity vector of the particle.
1891        ///
1892        /// Returns
1893        /// -------
1894        /// array_like
1895        ///
1896        #[getter]
1897        fn velocity<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<Float>> {
1898            self.0.velocity.as_slice().to_pyarray(py)
1899        }
1900    }
1901
1902    /// The intermediate status used to inform the user of the current state of a minimization algorithm.
1903    ///
1904    #[pyclass(name = "MinimizationStatus", module = "laddu")]
1905    pub struct PyMinimizationStatus(MinimizationStatus);
1906
1907    #[pymethods]
1908    impl PyMinimizationStatus {
1909        /// The current best position of the minimizer.
1910        ///
1911        /// Returns
1912        /// -------
1913        /// array_like
1914        ///
1915        #[getter]
1916        fn x<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<Float>> {
1917            self.0.x(py)
1918        }
1919        /// The current value of the objective function at the best position.
1920        ///
1921        /// Returns
1922        /// -------
1923        /// float
1924        ///
1925        #[getter]
1926        fn fx(&self) -> Float {
1927            self.0.fx()
1928        }
1929        /// A message indicating the current state of the minimization.
1930        ///
1931        /// Returns
1932        /// -------
1933        /// str
1934        ///
1935        #[getter]
1936        fn message(&self) -> String {
1937            self.0.message()
1938        }
1939        /// The current error estimate at the best position. May be None for algorithms that do not estimate errors.
1940        ///
1941        /// Returns
1942        /// -------
1943        /// array_like or None
1944        ///
1945        #[getter]
1946        fn err<'py>(&self, py: Python<'py>) -> Option<Bound<'py, PyArray1<Float>>> {
1947            self.0.err(py)
1948        }
1949        /// The number of objective function evaluations performed.
1950        ///
1951        /// Returns
1952        /// -------
1953        /// int
1954        ///
1955        #[getter]
1956        fn n_f_evals(&self) -> usize {
1957            self.0.n_f_evals()
1958        }
1959        /// The number of gradient function evaluations performed.
1960        ///
1961        /// Returns
1962        /// -------
1963        /// int
1964        ///
1965        #[getter]
1966        fn n_g_evals(&self) -> usize {
1967            self.0.n_g_evals()
1968        }
1969        /// The covariance matrix of the best position. May be None for algorithms that do not estimate errors.
1970        ///
1971        /// Returns
1972        /// -------
1973        /// array_like or None
1974        ///
1975        #[getter]
1976        fn cov<'py>(&self, py: Python<'py>) -> Option<Bound<'py, PyArray2<Float>>> {
1977            self.0.cov(py)
1978        }
1979        /// The Hessian matrix of the best position. May be None for algorithms that do not estimate errors.
1980        ///
1981        /// Returns
1982        /// -------
1983        /// array_like or None
1984        ///
1985        #[getter]
1986        fn hess<'py>(&self, py: Python<'py>) -> Option<Bound<'py, PyArray2<Float>>> {
1987            self.0.hess(py)
1988        }
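        /// True if the minimization algorithm has converged.
        ///
        /// Returns
        /// -------
        /// bool
        ///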
1989        #[getter]
1990        fn converged(&self) -> bool {
1991            self.0.converged()
1992        }
1993        /// The swarm of particles used in swarm-based optimization algorithms. May be None for algorithms that do not use a swarm.
1994        ///
1995        /// Returns
1996        /// -------
1997        /// Swarm or None
1998        ///
1999        #[getter]
2000        fn swarm(&self) -> Option<PySwarm> {
2001            self.0.swarm()
2002        }
2003    }
2004
2005    /// A summary of the results of a minimization.
2006    ///
2007    #[pyclass(name = "MinimizationSummary", module = "laddu")]
2008    #[derive(Clone)]
2009    pub struct PyMinimizationSummary(pub MinimizationSummary);
2010
2011    #[pymethods]
2012    impl PyMinimizationSummary {
2013        /// Bounds which were used during the minimization.
2014        ///
2015        /// Returns
2016        /// -------
2017        /// list of tuple of floats or None
2018        ///
2019        #[getter]
2020        fn bounds(&self) -> Option<Vec<(Float, Float)>> {
2021            self.0
2022                .clone()
2023                .bounds
2024                .map(|bs| bs.iter().map(|b| b.0.as_floats()).collect())
2025        }
2026        /// Names of each parameter used in the minimization.
2027        ///
2028        /// Returns
2029        /// -------
2030        /// list of str
2031        ///
2032        #[getter]
2033        fn parameter_names(&self) -> Vec<String> {
2034            self.0.parameter_names.clone().unwrap_or_default()
2035        }
2036        /// The status message reported at the end of the minimization.
2037        ///
2038        /// Returns
2039        /// -------
2040        /// str
2041        ///
2042        #[getter]
2043        fn message(&self) -> String {
2044            self.0.message.clone()
2045        }
2046        /// The starting position of the minimizer.
2047        ///
2048        /// Returns
2049        /// -------
2050        /// array_like
2051        ///
2052        #[getter]
2053        fn x0<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<Float>> {
2054            self.0.x0.as_slice().to_pyarray(py)
2055        }
2056        /// The best position found by the minimizer.
2057        ///
2058        /// Returns
2059        /// -------
2060        /// array_like
2061        ///
2062        #[getter]
2063        fn x<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<Float>> {
2064            self.0.x.as_slice().to_pyarray(py)
2065        }
2066        /// The uncertainty associated with each parameter (may be zeros if no uncertainty was estimated, or NaN if the covariance matrix was not positive definite).
2067        ///
2068        /// Returns
2069        /// -------
2070        /// array_like
2071        ///
2072        #[getter]
2073        fn std<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<Float>> {
2074            self.0.std.as_slice().to_pyarray(py)
2075        }
2076        /// The value of the objective function at the best position found by the minimizer.
2077        ///
2078        /// Returns
2079        /// -------
2080        /// float
2081        ///
2082        #[getter]
2083        fn fx(&self) -> Float {
2084            self.0.fx
2085        }
2086        /// The number of objective function evaluations performed.
2087        ///
2088        /// Returns
2089        /// -------
2090        /// int
2091        ///
2092        #[getter]
2093        fn cost_evals(&self) -> usize {
2094            self.0.cost_evals
2095        }
2096        /// The number of gradient function evaluations performed.
2097        ///
2098        /// Returns
2099        /// -------
2100        /// int
2101        ///
2102        #[getter]
2103        fn gradient_evals(&self) -> usize {
2104            self.0.gradient_evals
2105        }
2106        /// True if the minimization algorithm has converged.
2107        ///
2108        /// Returns
2109        /// -------
2110        /// bool
2111        ///
2112        #[getter]
2113        fn converged(&self) -> bool {
2114            self.0.converged
2115        }
2116        /// The covariance matrix of the best position (may contain zeros if the algorithm did not estimate a Hessian).
2117        ///
2118        /// Returns
2119        /// -------
2120        /// array_like
2121        ///
2122        #[getter]
2123        fn covariance<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray2<Float>> {
2124            self.0.covariance.to_pyarray(py)
2125        }
2126        fn __str__(&self) -> String {
2127            self.0.to_string()
2128        }
2129        #[new]
2130        fn new() -> Self {
2131            Self(MinimizationSummary::create_null())
2132        }
2133        fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
2134            Ok(PyBytes::new(
2135                py,
2136                bincode::serde::encode_to_vec(&self.0, bincode::config::standard())
2137                    .map_err(LadduError::EncodeError)?
2138                    .as_slice(),
2139            ))
2140        }
2141        fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
2142            *self = Self(
2143                bincode::serde::decode_from_slice(state.as_bytes(), bincode::config::standard())
2144                    .map_err(LadduError::DecodeError)?
2145                    .0,
2146            );
2147            Ok(())
2148        }
2149    }
2150
2151    /// An enum returned by a terminator to indicate whether an algorithm should continue or stop.
2152    ///
2153    #[pyclass(eq, eq_int, name = "ControlFlow", module = "laddu")]
2154    #[derive(PartialEq, Clone)]
2155    pub enum PyControlFlow {
2156        /// Continue running the algorithm.
2157        Continue = 0,
2158        /// Terminate the algorithm.
2159        Break = 1,
2160    }
2161
2162    impl From<PyControlFlow> for ControlFlow<()> {
2163        fn from(v: PyControlFlow) -> Self {
2164            match v {
2165                PyControlFlow::Continue => ControlFlow::Continue(()),
2166                PyControlFlow::Break => ControlFlow::Break(()),
2167            }
2168        }
2169    }
2170
2171    /// An [`Observer`] which can be used to monitor the progress of a minimization.
2172    ///
2173    /// This should be paired with a Python object which has an `observe` method
2174    /// that takes the current step and a [`PyMinimizationStatus`] as arguments.
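    ///
    /// A minimal sketch of such an object (the class name is illustrative):
    ///
    /// ```python
    /// class PrintObserver:
    ///     def observe(self, step, status):
    ///         # `status` is a MinimizationStatus
    ///         print(f"step {step}: fx = {status.fx}")
    /// ```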
2175    #[derive(Clone)]
2176    pub struct MinimizationObserver(Arc<Py<PyAny>>);
2177    impl FromPyObject<'_> for MinimizationObserver {
2178        fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
2179            Ok(MinimizationObserver(Arc::new(ob.clone().unbind())))
2180        }
2181    }
2182    impl<A, P, C> Observer<A, P, GradientStatus, MaybeThreadPool, LadduError, C>
2183        for MinimizationObserver
2184    where
2185        A: Algorithm<P, GradientStatus, MaybeThreadPool, LadduError, Config = C>,
2186    {
2187        fn observe(
2188            &mut self,
2189            current_step: usize,
2190            _algorithm: &A,
2191            _problem: &P,
2192            status: &GradientStatus,
2193            _args: &MaybeThreadPool,
2194            _config: &C,
2195        ) {
2196            Python::attach(|py| {
2197                self.0
2198                    .bind(py)
2199                    .call_method1(
2200                        "observe",
2201                        (
2202                            current_step,
2203                            PyMinimizationStatus(MinimizationStatus::GradientStatus(Arc::new(
2204                                Mutex::new(status.clone()),
2205                            ))),
2206                        ),
2207                    )
2208                    .expect("Error calling observe");
2209            })
2210        }
2211    }
2212    impl<A, P, C> Observer<A, P, GradientFreeStatus, MaybeThreadPool, LadduError, C>
2213        for MinimizationObserver
2214    where
2215        A: Algorithm<P, GradientFreeStatus, MaybeThreadPool, LadduError, Config = C>,
2216    {
2217        fn observe(
2218            &mut self,
2219            current_step: usize,
2220            _algorithm: &A,
2221            _problem: &P,
2222            status: &GradientFreeStatus,
2223            _args: &MaybeThreadPool,
2224            _config: &C,
2225        ) {
2226            Python::attach(|py| {
2227                self.0
2228                    .bind(py)
2229                    .call_method1(
2230                        "observe",
2231                        (
2232                            current_step,
2233                            PyMinimizationStatus(MinimizationStatus::GradientFreeStatus(Arc::new(
2234                                Mutex::new(status.clone()),
2235                            ))),
2236                        ),
2237                    )
2238                    .expect("Error calling observe");
2239            })
2240        }
2241    }
2242    impl<A, P, C> Observer<A, P, SwarmStatus, MaybeThreadPool, LadduError, C> for MinimizationObserver
2243    where
2244        A: Algorithm<P, SwarmStatus, MaybeThreadPool, LadduError, Config = C>,
2245    {
2246        fn observe(
2247            &mut self,
2248            current_step: usize,
2249            _algorithm: &A,
2250            _problem: &P,
2251            status: &SwarmStatus,
2252            _args: &MaybeThreadPool,
2253            _config: &C,
2254        ) {
2255            Python::attach(|py| {
2256                self.0
2257                    .bind(py)
2258                    .call_method1(
2259                        "observe",
2260                        (
2261                            current_step,
2262                            PyMinimizationStatus(MinimizationStatus::SwarmStatus(Arc::new(
2263                                Mutex::new(status.clone()),
2264                            ))),
2265                        ),
2266                    )
2267                    .expect("Error calling observe");
2268            })
2269        }
2270    }
2271
2272    /// A [`Terminator`] which can be used to stop a minimization early based on its progress.
2273    ///
2274    /// This should be paired with a Python object which has a `check_for_termination` method
2275    /// that takes the current step and a [`PyMinimizationStatus`] as arguments and returns a
2276    /// [`PyControlFlow`].
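    ///
    /// A minimal sketch of such an object (the class name and tolerance are illustrative,
    /// and `ControlFlow` is assumed to be importable from the `laddu` package):
    ///
    /// ```python
    /// from laddu import ControlFlow
    ///
    /// class StopWhenSmall:
    ///     def check_for_termination(self, step, status):
    ///         return ControlFlow.Break if status.fx < 1e-6 else ControlFlow.Continue
    /// ```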
2277    #[derive(Clone)]
2278    pub struct MinimizationTerminator(Arc<Py<PyAny>>);
2279
2280    impl FromPyObject<'_> for MinimizationTerminator {
2281        fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
2282            Ok(MinimizationTerminator(Arc::new(ob.clone().unbind())))
2283        }
2284    }
2285
2286    impl<A, P, C> Terminator<A, P, GradientStatus, MaybeThreadPool, LadduError, C>
2287        for MinimizationTerminator
2288    where
2289        A: Algorithm<P, GradientStatus, MaybeThreadPool, LadduError, Config = C>,
2290    {
2291        fn check_for_termination(
2292            &mut self,
2293            current_step: usize,
2294            _algorithm: &mut A,
2295            _problem: &P,
2296            status: &mut GradientStatus,
2297            _args: &MaybeThreadPool,
2298            _config: &C,
2299        ) -> ControlFlow<()> {
2300            Python::attach(|py| -> PyResult<ControlFlow<()>> {
2301                let wrapped_status = Arc::new(Mutex::new(std::mem::take(status)));
2302                let py_status = Py::new(
2303                    py,
2304                    PyMinimizationStatus(MinimizationStatus::GradientStatus(
2305                        wrapped_status.clone(),
2306                    )),
2307                )?;
2308                let ret = self
2309                    .0
2310                    .bind(py)
2311                    .call_method1("check_for_termination", (current_step, py_status))
2312                    .expect("Error calling check_for_termination");
2313                {
2314                    let mut guard = wrapped_status.lock();
2315                    std::mem::swap(status, &mut *guard);
2316                }
2317                let cf: PyControlFlow = ret.extract()?;
2318                Ok(cf.into())
2319            })
2320            .unwrap_or(ControlFlow::Continue(()))
2321        }
2322    }
2323    impl<A, P, C> Terminator<A, P, GradientFreeStatus, MaybeThreadPool, LadduError, C>
2324        for MinimizationTerminator
2325    where
2326        A: Algorithm<P, GradientFreeStatus, MaybeThreadPool, LadduError, Config = C>,
2327    {
2328        fn check_for_termination(
2329            &mut self,
2330            current_step: usize,
2331            _algorithm: &mut A,
2332            _problem: &P,
2333            status: &mut GradientFreeStatus,
2334            _args: &MaybeThreadPool,
2335            _config: &C,
2336        ) -> ControlFlow<()> {
2337            Python::attach(|py| -> PyResult<ControlFlow<()>> {
2338                let wrapped_status = Arc::new(Mutex::new(std::mem::take(status)));
2339                let py_status = Py::new(
2340                    py,
2341                    PyMinimizationStatus(MinimizationStatus::GradientFreeStatus(
2342                        wrapped_status.clone(),
2343                    )),
2344                )?;
2345                let ret = self
2346                    .0
2347                    .bind(py)
2348                    .call_method1("check_for_termination", (current_step, py_status))
2349                    .expect("Error calling check_for_termination");
2350                {
2351                    let mut guard = wrapped_status.lock();
2352                    std::mem::swap(status, &mut *guard);
2353                }
2354                let cf: PyControlFlow = ret.extract()?;
2355                Ok(cf.into())
2356            })
2357            .unwrap_or(ControlFlow::Continue(()))
2358        }
2359    }
2360    impl<A, P, C> Terminator<A, P, SwarmStatus, MaybeThreadPool, LadduError, C>
2361        for MinimizationTerminator
2362    where
2363        A: Algorithm<P, SwarmStatus, MaybeThreadPool, LadduError, Config = C>,
2364    {
2365        fn check_for_termination(
2366            &mut self,
2367            current_step: usize,
2368            _algorithm: &mut A,
2369            _problem: &P,
2370            status: &mut SwarmStatus,
2371            _args: &MaybeThreadPool,
2372            _config: &C,
2373        ) -> ControlFlow<()> {
2374            Python::attach(|py| -> PyResult<ControlFlow<()>> {
2375                let wrapped_status = Arc::new(Mutex::new(std::mem::take(status)));
2376                let py_status = Py::new(
2377                    py,
2378                    PyMinimizationStatus(MinimizationStatus::SwarmStatus(wrapped_status.clone())),
2379                )?;
2380                let ret = self
2381                    .0
2382                    .bind(py)
2383                    .call_method1("check_for_termination", (current_step, py_status))
2384                    .expect("Error calling check_for_termination");
2385                {
2386                    let mut guard = wrapped_status.lock();
2387                    std::mem::swap(status, &mut *guard);
2388                }
2389                let cf: PyControlFlow = ret.extract()?;
2390                Ok(cf.into())
2391            })
2392            .unwrap_or(ControlFlow::Continue(()))
2393        }
2394    }
2395
2396    /// A walker in an MCMC ensemble.
2397    ///
2398    #[pyclass(name = "Walker", module = "laddu")]
2399    pub struct PyWalker(pub Walker);
2400
2401    #[pymethods]
2402    impl PyWalker {
2403        /// The dimension of the walker's space `(n_steps, n_variables)`.
2404        ///
2405        /// Returns
2406        /// -------
2407        /// tuple of int
2408        #[getter]
2409        fn dimension(&self) -> (usize, usize) {
2410            self.0.dimension()
2411        }
2412        /// Retrieve the latest point and the latest objective value of the Walker.
2413        ///
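        /// Returns
        /// -------
        /// tuple of (array_like, float)
        ///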
2414        fn get_latest<'py>(&self, py: Python<'py>) -> (Bound<'py, PyArray1<Float>>, Float) {
2415            let point = self.0.get_latest();
2416            let lock = point.read();
2417            (lock.x.clone().as_slice().to_pyarray(py), lock.fx_checked())
2418        }
2419    }
2420
2421    /// The intermediate status used to inform the user of the current state of an MCMC algorithm.
2422    ///
2423    #[pyclass(name = "EnsembleStatus", module = "laddu")]
2424    pub struct PyEnsembleStatus(Arc<Mutex<EnsembleStatus>>);
2425
2426    #[pymethods]
2427    impl PyEnsembleStatus {
2428        /// A message indicating the current state of the MCMC sampling.
2429        ///
2430        /// Returns
2431        /// -------
2432        /// str
2433        ///
2434        #[getter]
2435        fn message(&self) -> String {
2436            self.0.lock().message().to_string()
2437        }
2438        /// The number of objective function evaluations performed.
2439        ///
2440        /// Returns
2441        /// -------
2442        /// int
2443        ///
2444        #[getter]
2445        fn n_f_evals(&self) -> usize {
2446            self.0.lock().n_f_evals
2447        }
2448        /// The number of gradient function evaluations performed.
2449        ///
2450        /// Returns
2451        /// -------
2452        /// int
2453        ///
2454        #[getter]
2455        fn n_g_evals(&self) -> usize {
2456            self.0.lock().n_g_evals
2457        }
2458        /// The walkers in the ensemble.
2459        ///
2460        /// Returns
2461        /// -------
2462        /// list of Walker
2463        ///
2464        #[getter]
2465        fn walkers(&self) -> Vec<PyWalker> {
2466            self.0
2467                .lock()
2468                .walkers
2469                .iter()
2470                .map(|w| PyWalker(w.clone()))
2471                .collect()
2472        }
2473        /// The dimension of the ensemble `(n_walkers, n_steps, n_variables)`.
2474        ///
2475        /// Returns
2476        /// -------
2477        /// tuple of int
2478        ///
2479        #[getter]
2480        fn dimension(&self) -> (usize, usize, usize) {
2481            self.0.lock().dimension()
2482        }
2483
2484        /// Retrieve the chain of the MCMC sampling.
2485        ///
2486        /// Parameters
2487        /// ----------
2488        /// burn : int, optional
2489        ///     The number of steps to discard from the beginning of the chain.
2490        /// thin : int, optional
2491        ///     The number of steps to skip between samples.
2492        ///
2493        /// Returns
2494        /// -------
2495        /// chain : array of shape (n_steps, n_variables, n_walkers)
2496        ///
2497        #[pyo3(signature = (*, burn = None, thin = None))]
2498        fn get_chain<'py>(
2499            &self,
2500            py: Python<'py>,
2501            burn: Option<usize>,
2502            thin: Option<usize>,
2503        ) -> PyResult<Bound<'py, PyArray3<Float>>> {
2504            let vec_chain: Vec<Vec<Vec<Float>>> = self
2505                .0
2506                .lock()
2507                .get_chain(burn, thin)
2508                .iter()
2509                .map(|steps| steps.iter().map(|p| p.as_slice().to_vec()).collect())
2510                .collect();
2511            Ok(PyArray3::from_vec3(py, &vec_chain)?)
2512        }
2513
2514        /// Retrieve the chain of the MCMC sampling, flattened over walkers.
2515        ///
2516        /// Parameters
2517        /// ----------
2518        /// burn : int, optional
2519        ///     The number of steps to discard from the beginning of the chain.
2520        /// thin : int, optional
2521        ///     The number of steps to skip between samples.
2522        ///
2523        /// Returns
2524        /// -------
2525        /// flat_chain : array of shape (n_steps * n_walkers, n_variables)
2526        ///
2527        #[pyo3(signature = (*, burn = None, thin = None))]
2528        fn get_flat_chain<'py>(
2529            &self,
2530            py: Python<'py>,
2531            burn: Option<usize>,
2532            thin: Option<usize>,
2533        ) -> Bound<'py, PyArray2<Float>> {
2534            DMatrix::from_columns(&self.0.lock().get_flat_chain(burn, thin))
2535                .transpose()
2536                .to_pyarray(py)
2537        }
2538    }
2539
2540    /// A summary of the results of an MCMC sampling.
2541    ///
2542    #[pyclass(name = "MCMCSummary", module = "laddu")]
2543    pub struct PyMCMCSummary(pub MCMCSummary);
2544
2545    #[pymethods]
2546    impl PyMCMCSummary {
2547        /// Bounds which were used during the MCMC sampling.
2548        ///
2549        /// Returns
2550        /// -------
2551        /// list of tuple of floats or None
2552        ///
2553        #[getter]
2554        fn bounds(&self) -> Option<Vec<(Float, Float)>> {
2555            self.0
2556                .clone()
2557                .bounds
2558                .map(|bs| bs.iter().map(|b| b.0.as_floats()).collect())
2559        }
2560        /// Names of each parameter used in the MCMC sampling.
2561        ///
2562        /// Returns
2563        /// -------
2564        /// list of str
2565        ///
2566        #[getter]
2567        fn parameter_names(&self) -> Vec<String> {
2568            self.0.parameter_names.clone().unwrap_or_default()
2569        }
2570        /// The status message reported at the end of the MCMC sampling.
2571        ///
2572        /// Returns
2573        /// -------
2574        /// str
2575        ///
2576        #[getter]
2577        fn message(&self) -> String {
2578            self.0.message.clone()
2579        }
2580        /// The number of objective function evaluations performed.
2581        ///
2582        /// Returns
2583        /// -------
2584        /// int
2585        ///
2586        #[getter]
2587        fn cost_evals(&self) -> usize {
2588            self.0.cost_evals
2589        }
2590        /// The number of gradient function evaluations performed.
2591        ///
2592        /// Returns
2593        /// -------
2594        /// int
2595        ///
2596        #[getter]
2597        fn gradient_evals(&self) -> usize {
2598            self.0.gradient_evals
2599        }
2600        /// True if the MCMC algorithm has converged.
2601        ///
2602        /// Returns
2603        /// -------
2604        /// bool
2605        ///
2606        #[getter]
2607        fn converged(&self) -> bool {
2608            self.0.converged
2609        }
2610        /// The dimension of the ensemble `(n_walkers, n_steps, n_variables)`.
2611        ///
2612        /// Returns
2613        /// -------
2614        /// tuple of int
2615        ///
2616        #[getter]
2617        fn dimension(&self) -> (usize, usize, usize) {
2618            self.0.dimension
2619        }
2620
2621        /// Retrieve the chain of the MCMC sampling.
2622        ///
2623        /// Parameters
2624        /// ----------
2625        /// burn : int, optional
2626        ///     The number of steps to discard from the beginning of the chain.
2627        /// thin : int, optional
2628        ///     The number of steps to skip between samples.
2629        ///
2630        /// Returns
2631        /// -------
2632        /// chain : array of shape (n_steps, n_variables, n_walkers)
2633        ///
2634        #[pyo3(signature = (*, burn = None, thin = None))]
2635        fn get_chain<'py>(
2636            &self,
2637            py: Python<'py>,
2638            burn: Option<usize>,
2639            thin: Option<usize>,
2640        ) -> PyResult<Bound<'py, PyArray3<Float>>> {
2641            let vec_chain: Vec<Vec<Vec<Float>>> = self
2642                .0
2643                .get_chain(burn, thin)
2644                .iter()
2645                .map(|steps| steps.iter().map(|p| p.as_slice().to_vec()).collect())
2646                .collect();
2647            Ok(PyArray3::from_vec3(py, &vec_chain)?)
2648        }
2649
2650        /// Retrieve the chain of the MCMC sampling, flattened over walkers.
2651        ///
2652        /// Parameters
2653        /// ----------
2654        /// burn : int, optional
2655        ///     The number of steps to discard from the beginning of the chain.
2656        /// thin : int, optional
2657        ///     The number of steps to skip between samples.
2658        ///
2659        /// Returns
2660        /// -------
2661        /// flat_chain : array of shape (n_steps * n_walkers, n_variables)
2662        ///
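        /// Examples
        /// --------
        /// A sketch of typical usage, assuming ``summary`` is an MCMCSummary returned by a
        /// finished sampling run:
        ///
        /// >>> flat = summary.get_flat_chain(burn=500, thin=10)
        /// >>> flat.shape  # (n_steps * n_walkers, n_variables)
        ///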
2663        #[pyo3(signature = (*, burn = None, thin = None))]
2664        fn get_flat_chain<'py>(
2665            &self,
2666            py: Python<'py>,
2667            burn: Option<usize>,
2668            thin: Option<usize>,
2669        ) -> Bound<'py, PyArray2<Float>> {
2670            DMatrix::from_columns(&self.0.get_flat_chain(burn, thin))
2671                .transpose()
2672                .to_pyarray(py)
2673        }
2674
2675        #[new]
2676        fn new() -> Self {
2677            Self(MCMCSummary::create_null())
2678        }
2679        fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
2680            Ok(PyBytes::new(
2681                py,
2682                bincode::serde::encode_to_vec(&self.0, bincode::config::standard())
2683                    .map_err(LadduError::EncodeError)?
2684                    .as_slice(),
2685            ))
2686        }
2687        fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
2688            *self = Self(
2689                bincode::serde::decode_from_slice(state.as_bytes(), bincode::config::standard())
2690                    .map_err(LadduError::DecodeError)?
2691                    .0,
2692            );
2693            Ok(())
2694        }
2695    }
2696
2697    /// An [`Observer`] which can be used to monitor the progress of an MCMC algorithm.
2698    ///
2699    /// This should be paired with a Python object which has an `observe` method
2700    /// that takes the current step and a [`PyEnsembleStatus`] as arguments.
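    ///
    /// A minimal sketch of such an object (the class name is illustrative):
    ///
    /// ```python
    /// class ChainLogger:
    ///     def observe(self, step, status):
    ///         # `status` is an EnsembleStatus
    ///         n_walkers, _, _ = status.dimension
    ///         print(f"step {step}: {n_walkers} walkers, {status.n_f_evals} evaluations")
    /// ```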
    #[derive(Clone)]
    pub struct MCMCObserver(Arc<Py<PyAny>>);

    impl FromPyObject<'_> for MCMCObserver {
        fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
            Ok(MCMCObserver(Arc::new(ob.clone().unbind())))
        }
    }

    impl<A, P, C> Observer<A, P, EnsembleStatus, MaybeThreadPool, LadduError, C> for MCMCObserver
    where
        A: Algorithm<P, EnsembleStatus, MaybeThreadPool, LadduError, Config = C>,
    {
        fn observe(
            &mut self,
            current_step: usize,
            _algorithm: &A,
            _problem: &P,
            status: &EnsembleStatus,
            _args: &MaybeThreadPool,
            _config: &C,
        ) {
            Python::attach(|py| {
                self.0
                    .bind(py)
                    .call_method1(
                        "observe",
                        (
                            current_step,
                            PyEnsembleStatus(Arc::new(Mutex::new(status.clone()))),
                        ),
                    )
                    .expect("Error calling observe");
            })
        }
    }

    #[derive(Clone)]
    enum PythonMCMCTerminator {
        UserDefined(MCMCTerminator),
        Autocorrelation(PyAutocorrelationTerminator),
    }

    impl<A, P, C> Terminator<A, P, EnsembleStatus, MaybeThreadPool, LadduError, C>
        for PythonMCMCTerminator
    where
        A: Algorithm<P, EnsembleStatus, MaybeThreadPool, LadduError, Config = C>,
    {
        fn check_for_termination(
            &mut self,
            current_step: usize,
            algorithm: &mut A,
            problem: &P,
            status: &mut EnsembleStatus,
            args: &MaybeThreadPool,
            config: &C,
        ) -> ControlFlow<()> {
            match self {
                Self::UserDefined(mcmcterminator) => mcmcterminator.check_for_termination(
                    current_step,
                    algorithm,
                    problem,
                    status,
                    args,
                    config,
                ),
                Self::Autocorrelation(py_autocorrelation_terminator) => {
                    py_autocorrelation_terminator.0.check_for_termination(
                        current_step,
                        algorithm,
                        problem,
                        status,
                        args,
                        config,
                    )
                }
            }
        }
    }

    /// A [`Terminator`] which can be used to monitor the progress of an MCMC algorithm.
    ///
    /// This should be paired with a Python object which has a `check_for_termination` method
    /// that takes the current step and a [`PyEnsembleStatus`] as arguments and returns a
    /// [`PyControlFlow`].
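    ///
    /// A minimal sketch of such a Python object (the class name is illustrative, and the return
    /// value is elided because the Python-side constructor for the control-flow wrapper is not
    /// shown in this file):
    ///
    /// ```python
    /// class MyTerminator:
    ///     def check_for_termination(self, step, status):
    ///         # Inspect `step` / `status` here and return the control-flow value that
    ///         # laddu expects (wrapped as `PyControlFlow` on the Rust side).
    ///         ...
    /// ```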
    #[derive(Clone)]
    pub struct MCMCTerminator(Arc<Py<PyAny>>);

    impl FromPyObject<'_> for MCMCTerminator {
        fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
            Ok(MCMCTerminator(Arc::new(ob.clone().unbind())))
        }
    }

    impl<A, P, C> Terminator<A, P, EnsembleStatus, MaybeThreadPool, LadduError, C> for MCMCTerminator
    where
        A: Algorithm<P, EnsembleStatus, MaybeThreadPool, LadduError, Config = C>,
    {
        fn check_for_termination(
            &mut self,
            current_step: usize,
            _algorithm: &mut A,
            _problem: &P,
            status: &mut EnsembleStatus,
            _args: &MaybeThreadPool,
            _config: &C,
        ) -> ControlFlow<()> {
            Python::attach(|py| -> PyResult<ControlFlow<()>> {
                let wrapped_status = Arc::new(Mutex::new(std::mem::take(status)));
                let py_status = Py::new(py, PyEnsembleStatus(wrapped_status.clone()))?;
                let ret = self
                    .0
                    .bind(py)
                    .call_method1("check_for_termination", (current_step, py_status))
                    .expect("Error calling check_for_termination");
                {
                    let mut guard = wrapped_status.lock();
                    std::mem::swap(status, &mut *guard);
                }
                let cf: PyControlFlow = ret.extract()?;
                Ok(cf.into())
            })
            .unwrap_or(ControlFlow::Continue(()))
        }
    }

    /// Calculate the integrated autocorrelation time for each parameter according to
    /// [Karamanis]_.
    ///
    /// Parameters
    /// ----------
    /// samples : array_like
    ///     An array of dimension ``(n_walkers, n_steps, n_parameters)``
    /// c : float, default=7.0
    ///     Set the time window for Sokal's autowindowing function [Sokal]_. If None, the default
    ///     window size of 7.0 is used.
    ///
    /// Returns
    /// -------
    /// taus : array_like
    ///     The integrated autocorrelation time for each parameter.
    ///
    /// .. rubric:: References
    ///
    /// .. [Karamanis] Karamanis, M., & Beutler, F. (2021). Ensemble slice sampling. Statistics and Computing, 31(5). https://doi.org/10.1007/s11222-021-10038-2
    /// .. [Sokal] Sokal, A. (1997). Monte Carlo Methods in Statistical Mechanics: Foundations and New Algorithms. In NATO ASI Series (pp. 131–192). Springer US. https://doi.org/10.1007/978-1-4899-0319-8_6
    ///
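    /// Example
    /// -------
    /// A minimal sketch using a synthetic chain; the random samples below stand in for a real
    /// chain of shape ``(n_walkers, n_steps, n_parameters)``, and the ``laddu`` import path is
    /// assumed from the neighbouring module declarations::
    ///
    ///     import numpy as np
    ///     from laddu import integrated_autocorrelation_times
    ///
    ///     samples = np.random.normal(size=(10, 500, 3))  # 10 walkers, 500 steps, 3 parameters
    ///     taus = integrated_autocorrelation_times(samples, c=7.0)  # one estimate per parameter
    ///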
    #[pyfunction(name = "integrated_autocorrelation_times")]
    #[pyo3(signature = (samples, *, c=None))]
    pub fn py_integrated_autocorrelation_times<'py>(
        py: Python<'py>,
        samples: Vec<Vec<Vec<Float>>>,
        c: Option<Float>,
    ) -> Bound<'py, PyArray1<Float>> {
        let samples: Vec<Vec<DVector<Float>>> = samples
            .into_iter()
            .map(|v| v.into_iter().map(DVector::from_vec).collect())
            .collect();
        integrated_autocorrelation_times(samples, c)
            .as_slice()
            .to_pyarray(py)
    }

    /// A terminator for MCMC algorithms that monitors autocorrelation according to [Karamanis]_.
    ///
    /// Parameters
    /// ----------
    /// n_check : int, default=50
    ///     Number of steps to take between autocorrelation checks.
    /// n_taus_threshold : int, default=50
    ///     Convergence may be achieved if the number of steps exceeds this value times the current
    ///     mean autocorrelation time.
    /// dtau_threshold : float, default=0.01
    ///     Convergence additionally requires the change in the mean autocorrelation time between
    ///     checks to fall below this value.
    /// discard : float, default=0.5
    ///     The fraction of the chain to discard when calculating autocorrelation times.
    /// terminate : bool, default=True
    ///     If set to False, the terminator will act like an observer and only store
    ///     autocorrelation times.
    /// sokal_window : float, default=None
    ///     Set the time window for Sokal's autowindowing function [Sokal]_. If None, the default
    ///     window size of 7.0 is used.
    /// verbose : bool, default=False
    ///     Print autocorrelation information at each check step.
    ///
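    /// Example
    /// -------
    /// A minimal sketch; the settings below are illustrative values, and how the terminator is
    /// attached to a sampler is not shown here since that depends on the sampling API in use::
    ///
    ///     from laddu import AutocorrelationTerminator
    ///
    ///     act = AutocorrelationTerminator(n_check=100, dtau_threshold=0.005, verbose=True)
    ///     # ... run an MCMC sampler with `act` supplied as a terminator ...
    ///     print(act.taus)  # autocorrelation times for each parameter
    ///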
    #[pyclass(name = "AutocorrelationTerminator", module = "laddu")]
    #[derive(Clone)]
    pub struct PyAutocorrelationTerminator(Arc<Mutex<AutocorrelationTerminator>>);

    #[pymethods]
    impl PyAutocorrelationTerminator {
        #[new]
        #[pyo3(signature = (*, n_check = 50, n_taus_threshold = 50, dtau_threshold = 0.01, discard = 0.5, terminate = true, sokal_window = None, verbose = false))]
        fn new(
            n_check: usize,
            n_taus_threshold: usize,
            dtau_threshold: Float,
            discard: Float,
            terminate: bool,
            sokal_window: Option<Float>,
            verbose: bool,
        ) -> Self {
            let mut act = AutocorrelationTerminator::default()
                .with_n_check(n_check)
                .with_n_taus_threshold(n_taus_threshold)
                .with_dtau_threshold(dtau_threshold)
                .with_discard(discard)
                .with_terminate(terminate)
                .with_verbose(verbose);
            if let Some(sokal_window) = sokal_window {
                act = act.with_sokal_window(sokal_window)
            }
            Self(act.build())
        }

        /// A list of autocorrelation times for each parameter.
        ///
        /// Returns
        /// -------
        /// taus : array_like
        ///
        #[getter]
        fn taus<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<Float>> {
            self.0.lock().taus.to_pyarray(py)
        }
    }
}