Skip to main content

laddu_python/extensions/
optimize.rs

1use ganesh::{
2    algorithms::{
3        gradient::{
4            Adam, AdamConfig, ConjugateGradient, ConjugateGradientConfig, GradientStatus,
5            LBFGSBConfig, TrustRegion, TrustRegionConfig, LBFGSB,
6        },
7        gradient_free::{
8            nelder_mead::NelderMeadInit, CMAESConfig, CMAESInit, DifferentialEvolution,
9            DifferentialEvolutionConfig, DifferentialEvolutionInit, GradientFreeStatus, NelderMead,
10            NelderMeadConfig, CMAES,
11        },
12        mcmc::{aies::AIESInit, ess::ESSInit, AIESConfig, ESSConfig, EnsembleStatus, AIES, ESS},
13        particles::{PSOConfig, SwarmStatus, PSO},
14    },
15    core::{Callbacks, CtrlCAbortSignal, MCMCSummary, MinimizationSummary},
16    python::{
17        PyAIESOptions, PyAdamOptions, PyCMAESOptions, PyConjugateGradientOptions,
18        PyDifferentialEvolutionOptions, PyESSOptions, PyLBFGSBOptions, PyNelderMeadOptions,
19        PyPSOOptions, PyTrustRegionOptions,
20    },
21    traits::{Algorithm, CostFunction, Gradient, LogDensity, Status, SupportsParameterNames},
22    DVector,
23};
24use laddu_core::{f64, validate_free_parameter_len, LadduError, LadduResult};
25use laddu_extensions::{
26    likelihood::LikelihoodTerm, optimize::MaybeThreadPool, LikelihoodTermObserver,
27};
28use pyo3::{
29    exceptions::{PyTypeError, PyValueError},
30    prelude::*,
31    types::PyList,
32    PyErr,
33};
34
35use crate::extensions::callbacks::{
36    MCMCObserver, MCMCTerminator, MinimizationObserver, MinimizationTerminator,
37};
38
39fn run_minimizer<A, P, S>(
40    problem: &P,
41    num_threads: usize,
42    init: A::Init,
43    config: A::Config,
44    callbacks: Callbacks<A, P, S, MaybeThreadPool, LadduError, A::Config>,
45) -> LadduResult<MinimizationSummary>
46where
47    A: Algorithm<P, S, MaybeThreadPool, LadduError, Summary = MinimizationSummary> + Default,
48    P: LikelihoodTerm,
49    S: Status,
50{
51    let mtp = MaybeThreadPool::new(num_threads);
52    A::default().process(
53        problem,
54        &mtp,
55        init,
56        config,
57        callbacks.with_observer(LikelihoodTermObserver),
58    )
59}
60
61fn run_mcmc_algorithm<A, P>(
62    problem: &P,
63    num_threads: usize,
64    init: A::Init,
65    config: A::Config,
66    callbacks: Callbacks<A, P, EnsembleStatus, MaybeThreadPool, LadduError, A::Config>,
67) -> LadduResult<MCMCSummary>
68where
69    A: Algorithm<P, EnsembleStatus, MaybeThreadPool, LadduError, Summary = MCMCSummary> + Default,
70    P: LikelihoodTerm,
71{
72    let mtp = MaybeThreadPool::new(num_threads);
73    A::default().process(
74        problem,
75        &mtp,
76        init,
77        config,
78        callbacks.with_observer(LikelihoodTermObserver),
79    )
80}
81
82fn validate_mcmc_parameter_len(walkers: &[Vec<f64>], expected_len: usize) -> PyResult<()> {
83    for walker in walkers {
84        validate_free_parameter_len(walker.len(), expected_len)
85            .map_err(|err| PyValueError::new_err(err.to_string()))?;
86    }
87    Ok(())
88}
89
/// Canonicalizes a user-supplied algorithm name for matching.
///
/// The name is lowercased, surrounding whitespace is trimmed, and every `-` and ASCII space
/// is removed, so `"L-BFGS-B"` and `"Nelder Mead"` become `"lbfgsb"` and `"neldermead"`.
/// Other characters (underscores, interior tabs, ...) are kept as-is.
fn normalize_method_name(method: &str) -> String {
    let lowered = method.to_lowercase();
    lowered
        .trim()
        .chars()
        .filter(|&ch| ch != '-' && ch != ' ')
        .collect()
}
97
98fn extract_minimization_observers(
99    observers: Option<Bound<'_, PyAny>>,
100) -> PyResult<Vec<MinimizationObserver>> {
101    if let Some(observers) = observers {
102        if let Ok(observers) = observers.cast::<PyList>() {
103            observers
104                    .into_iter()
105                    .map(|observer| {
106                        observer.extract::<MinimizationObserver>().map_err(|_| {
107                            PyValueError::new_err(
108                                "The observers must be either a single MinimizationObserver or a list of MinimizationObservers.",
109                            )
110                        })
111                    })
112                    .collect()
113        } else if let Ok(observer) = observers.extract::<MinimizationObserver>() {
114            Ok(vec![observer])
115        } else {
116            Err(PyValueError::new_err(
117                "The observers must be either a single MinimizationObserver or a list of MinimizationObservers.",
118            ))
119        }
120    } else {
121        Ok(vec![])
122    }
123}
124
125fn extract_minimization_terminators(
126    terminators: Option<Bound<'_, PyAny>>,
127) -> PyResult<Vec<MinimizationTerminator>> {
128    if let Some(terminators) = terminators {
129        if let Ok(terminators) = terminators.cast::<PyList>() {
130            terminators
131                    .into_iter()
132                    .map(|terminator| {
133                        terminator.extract::<MinimizationTerminator>().map_err(|_| {
134                            PyValueError::new_err(
135                                "The terminators must be either a single MinimizationTerminator or a list of MinimizationTerminators.",
136                            )
137                        })
138                    })
139                    .collect()
140        } else if let Ok(terminator) = terminators.extract::<MinimizationTerminator>() {
141            Ok(vec![terminator])
142        } else {
143            Err(PyValueError::new_err(
144                "The terminators must be either a single MinimizationTerminator or a list of MinimizationTerminators.",
145            ))
146        }
147    } else {
148        Ok(vec![])
149    }
150}
151
152fn extract_mcmc_observers(observers: Option<Bound<'_, PyAny>>) -> PyResult<Vec<MCMCObserver>> {
153    if let Some(observers) = observers {
154        if let Ok(observers) = observers.cast::<PyList>() {
155            observers
156                    .into_iter()
157                    .map(|observer| {
158                        observer.extract::<MCMCObserver>().map_err(|_| {
159                            PyValueError::new_err(
160                                "The observers must be either a single MCMCObserver or a list of MCMCObservers.",
161                            )
162                        })
163                    })
164                    .collect()
165        } else if let Ok(observer) = observers.extract::<MCMCObserver>() {
166            Ok(vec![observer])
167        } else {
168            Err(PyValueError::new_err(
169                "The observers must be either a single MCMCObserver or a list of MCMCObservers.",
170            ))
171        }
172    } else {
173        Ok(vec![])
174    }
175}
176
177fn extract_mcmc_terminators(
178    terminators: Option<Bound<'_, PyAny>>,
179) -> PyResult<Vec<MCMCTerminator>> {
180    if let Some(terminators) = terminators {
181        if let Ok(terminators) = terminators.cast::<PyList>() {
182            terminators
183                    .into_iter()
184                    .map(|terminator| terminator.extract::<MCMCTerminator>().map_err(|_| {
185                        PyValueError::new_err(
186                            "The terminators must be either a single MCMCTerminator or a list of MCMCTerminators.",
187                        )
188                    }))
189                    .collect()
190        } else if let Ok(terminator) = terminators.extract::<MCMCTerminator>() {
191            Ok(vec![terminator])
192        } else {
193            Err(PyValueError::new_err(
194                "The terminators must be either a single MCMCTerminator or a list of MCMCTerminators.",
195            ))
196        }
197    } else {
198        Ok(vec![])
199    }
200}
201
202#[allow(clippy::too_many_arguments)]
203pub(crate) fn minimize_from_python<P>(
204    problem: &P,
205    p0: &Bound<'_, PyAny>,
206    n_free: usize,
207    parameter_names: &[String],
208    method: String,
209    config: Option<&Bound<'_, PyAny>>,
210    options: Option<&Bound<'_, PyAny>>,
211    observers: Option<Bound<'_, PyAny>>,
212    terminators: Option<Bound<'_, PyAny>>,
213    threads: usize,
214) -> PyResult<MinimizationSummary>
215where
216    P: Gradient<MaybeThreadPool, LadduError>
217        + CostFunction<MaybeThreadPool, LadduError>
218        + LikelihoodTerm,
219{
220    let observers = extract_minimization_observers(observers)?;
221    let terminators = extract_minimization_terminators(terminators)?;
222    let method = normalize_method_name(&method);
223
224    match method.as_str() {
225        "lbfgsb" => {
226            let init = p0.extract::<Vec<f64>>()?;
227            validate_free_parameter_len(init.len(), n_free)?;
228            let mut config = config
229                .map(|c| c.extract::<LBFGSBConfig>())
230                .transpose()
231                .map_err(|err| PyTypeError::new_err(err.to_string()))?
232                .unwrap_or_default();
233            if config.get_parameter_names_mut().is_none() {
234                config = config.with_parameter_names(parameter_names);
235            }
236            let parsed_options = options
237                .map(|opt| opt.extract::<PyLBFGSBOptions>())
238                .transpose()
239                .map_err(|err| PyTypeError::new_err(err.to_string()))?
240                .unwrap_or_default();
241            let mut callbacks = parsed_options
242                .build_callbacks()
243                .map_err(|err| PyValueError::new_err(err.to_string()))?;
244            for observer in observers {
245                callbacks = callbacks.with_observer(observer);
246            }
247            for terminator in terminators {
248                callbacks = callbacks.with_terminator(terminator);
249            }
250            callbacks = callbacks.with_terminator(CtrlCAbortSignal::new());
251            run_minimizer::<LBFGSB, _, GradientStatus>(
252                problem,
253                threads,
254                DVector::from_vec(init),
255                config,
256                callbacks,
257            )
258            .map_err(PyErr::from)
259        }
260        "adam" => {
261            let init = p0.extract::<Vec<f64>>()?;
262            validate_free_parameter_len(init.len(), n_free)?;
263            let mut config = config
264                .map(|c| c.extract::<AdamConfig>())
265                .transpose()
266                .map_err(|err| PyTypeError::new_err(err.to_string()))?
267                .unwrap_or_default();
268            if config.get_parameter_names_mut().is_none() {
269                config = config.with_parameter_names(parameter_names);
270            }
271            let parsed_options = options
272                .map(|opt| opt.extract::<PyAdamOptions>())
273                .transpose()
274                .map_err(|err| PyTypeError::new_err(err.to_string()))?
275                .unwrap_or_default();
276            let mut callbacks = parsed_options.build_callbacks();
277            for observer in observers {
278                callbacks = callbacks.with_observer(observer);
279            }
280            for terminator in terminators {
281                callbacks = callbacks.with_terminator(terminator);
282            }
283            callbacks = callbacks.with_terminator(CtrlCAbortSignal::new());
284            run_minimizer::<Adam, _, GradientStatus>(
285                problem,
286                threads,
287                DVector::from_vec(init),
288                config,
289                callbacks,
290            )
291            .map_err(PyErr::from)
292        }
293        "conjugategradient" => {
294            let init = p0.extract::<Vec<f64>>()?;
295            validate_free_parameter_len(init.len(), n_free)?;
296            let mut config = config
297                .map(|c| c.extract::<ConjugateGradientConfig>())
298                .transpose()
299                .map_err(|err| PyTypeError::new_err(err.to_string()))?
300                .unwrap_or_default();
301            if config.get_parameter_names_mut().is_none() {
302                config = config.with_parameter_names(parameter_names);
303            }
304            let parsed_options = options
305                .map(|opt| opt.extract::<PyConjugateGradientOptions>())
306                .transpose()
307                .map_err(|err| PyTypeError::new_err(err.to_string()))?
308                .unwrap_or_default();
309            let mut callbacks = parsed_options.build_callbacks()?;
310            for observer in observers {
311                callbacks = callbacks.with_observer(observer);
312            }
313            for terminator in terminators {
314                callbacks = callbacks.with_terminator(terminator);
315            }
316            callbacks = callbacks.with_terminator(CtrlCAbortSignal::new());
317            run_minimizer::<ConjugateGradient, _, GradientStatus>(
318                problem,
319                threads,
320                DVector::from_vec(init),
321                config,
322                callbacks,
323            )
324            .map_err(PyErr::from)
325        }
326        "trustregion" => {
327            let init = p0.extract::<Vec<f64>>()?;
328            validate_free_parameter_len(init.len(), n_free)?;
329            let mut config = config
330                .map(|c| c.extract::<TrustRegionConfig>())
331                .transpose()
332                .map_err(|err| PyTypeError::new_err(err.to_string()))?
333                .unwrap_or_default();
334            if config.get_parameter_names_mut().is_none() {
335                config = config.with_parameter_names(parameter_names);
336            }
337            let parsed_options = options
338                .map(|opt| opt.extract::<PyTrustRegionOptions>())
339                .transpose()
340                .map_err(|err| PyTypeError::new_err(err.to_string()))?
341                .unwrap_or_default();
342            let mut callbacks = parsed_options.build_callbacks()?;
343            for observer in observers {
344                callbacks = callbacks.with_observer(observer);
345            }
346            for terminator in terminators {
347                callbacks = callbacks.with_terminator(terminator);
348            }
349            callbacks = callbacks.with_terminator(CtrlCAbortSignal::new());
350            run_minimizer::<TrustRegion, _, GradientStatus>(
351                problem,
352                threads,
353                DVector::from_vec(init),
354                config,
355                callbacks,
356            )
357            .map_err(PyErr::from)
358        }
359        "cmaes" => {
360            let init = p0.extract::<CMAESInit>()?;
361            let mut config = config
362                .map(|c| c.extract::<CMAESConfig>())
363                .transpose()
364                .map_err(|err| PyTypeError::new_err(err.to_string()))?
365                .unwrap_or_default();
366            if config.get_parameter_names_mut().is_none() {
367                config = config.with_parameter_names(parameter_names);
368            }
369            let parsed_options = options
370                .map(|opt| opt.extract::<PyCMAESOptions>())
371                .transpose()
372                .map_err(|err| PyTypeError::new_err(err.to_string()))?
373                .unwrap_or_default();
374            let mut callbacks = parsed_options.build_callbacks();
375            for observer in observers {
376                callbacks = callbacks.with_observer(observer);
377            }
378            for terminator in terminators {
379                callbacks = callbacks.with_terminator(terminator);
380            }
381            callbacks = callbacks.with_terminator(CtrlCAbortSignal::new());
382            run_minimizer::<CMAES, _, GradientFreeStatus>(problem, threads, init, config, callbacks)
383                .map_err(PyErr::from)
384        }
385        "differential-evolution" => {
386            let init = p0.extract::<DifferentialEvolutionInit>()?;
387            let mut config = config
388                .map(|c| c.extract::<DifferentialEvolutionConfig>())
389                .transpose()
390                .map_err(|err| PyTypeError::new_err(err.to_string()))?
391                .unwrap_or_default();
392            if config.get_parameter_names_mut().is_none() {
393                config = config.with_parameter_names(parameter_names);
394            }
395            let parsed_options = options
396                .map(|opt| opt.extract::<PyDifferentialEvolutionOptions>())
397                .transpose()
398                .map_err(|err| PyTypeError::new_err(err.to_string()))?
399                .unwrap_or_default();
400            let mut callbacks = parsed_options.build_callbacks();
401            for observer in observers {
402                callbacks = callbacks.with_observer(observer);
403            }
404            for terminator in terminators {
405                callbacks = callbacks.with_terminator(terminator);
406            }
407            callbacks = callbacks.with_terminator(CtrlCAbortSignal::new());
408            run_minimizer::<DifferentialEvolution, _, GradientFreeStatus>(
409                problem, threads, init, config, callbacks,
410            )
411            .map_err(PyErr::from)
412        }
413        "neldermead" => {
414            let init = if let Ok(init) = p0.extract::<NelderMeadInit>() {
415                init
416            } else if let Ok(init) = p0.extract::<Vec<f64>>() {
417                validate_free_parameter_len(init.len(), n_free)?;
418                NelderMeadInit::new(init)
419            } else {
420                return Err(PyTypeError::new_err(
421                    "p0 for method 'nelder-mead' must be NelderMeadInit-compatible or a point",
422                ));
423            };
424            let mut config = config
425                .map(|c| c.extract::<NelderMeadConfig>())
426                .transpose()
427                .map_err(|err| PyTypeError::new_err(err.to_string()))?
428                .unwrap_or_default();
429            if config.get_parameter_names_mut().is_none() {
430                config = config.with_parameter_names(parameter_names);
431            }
432            let parsed_options = options
433                .map(|opt| opt.extract::<PyNelderMeadOptions>())
434                .transpose()
435                .map_err(|err| PyTypeError::new_err(err.to_string()))?
436                .unwrap_or_default();
437            let mut callbacks = parsed_options.build_callbacks();
438            for observer in observers {
439                callbacks = callbacks.with_observer(observer);
440            }
441            for terminator in terminators {
442                callbacks = callbacks.with_terminator(terminator);
443            }
444            callbacks = callbacks.with_terminator(CtrlCAbortSignal::new());
445            run_minimizer::<NelderMead, _, GradientFreeStatus>(
446                problem, threads, init, config, callbacks,
447            )
448            .map_err(PyErr::from)
449        }
450        "pso" => {
451            let init = if let Ok(init) = p0.extract::<ganesh::algorithms::particles::Swarm>() {
452                init
453            } else if let Ok(positions) = p0.extract::<Vec<Vec<f64>>>() {
454                validate_mcmc_parameter_len(&positions, n_free)?;
455                ganesh::algorithms::particles::Swarm::new(
456                    ganesh::algorithms::particles::SwarmPositionInitializer::Custom(
457                        positions.into_iter().map(DVector::from_vec).collect(),
458                    ),
459                )
460            } else {
461                return Err(PyTypeError::new_err(
462                    "p0 for method 'pso' must be PSOInit-compatible or a position matrix",
463                ));
464            };
465            let mut config = config
466                .map(|c| c.extract::<PSOConfig>())
467                .transpose()
468                .map_err(|err| PyTypeError::new_err(err.to_string()))?
469                .unwrap_or_default();
470            if config.get_parameter_names_mut().is_none() {
471                config = config.with_parameter_names(parameter_names);
472            }
473            let parsed_options = options
474                .map(|opt| opt.extract::<PyPSOOptions>())
475                .transpose()
476                .map_err(|err| PyTypeError::new_err(err.to_string()))?
477                .unwrap_or_default();
478            let mut callbacks = parsed_options.build_callbacks();
479            for observer in observers {
480                callbacks = callbacks.with_observer(observer);
481            }
482            for terminator in terminators {
483                callbacks = callbacks.with_terminator(terminator);
484            }
485            callbacks = callbacks.with_terminator(CtrlCAbortSignal::new());
486            run_minimizer::<PSO, _, SwarmStatus>(problem, threads, init, config, callbacks)
487                .map_err(PyErr::from)
488        }
489        _ => Err(PyValueError::new_err(format!(
490            "Invalid minimizer: {method}"
491        ))),
492    }
493}
494
495#[allow(clippy::too_many_arguments)]
496pub(crate) fn mcmc_from_python<P>(
497    problem: &P,
498    p0: &Bound<'_, PyAny>,
499    n_free: usize,
500    parameter_names: &[String],
501    method: String,
502    config: Option<&Bound<'_, PyAny>>,
503    options: Option<&Bound<'_, PyAny>>,
504    observers: Option<Bound<'_, PyAny>>,
505    terminators: Option<Bound<'_, PyAny>>,
506    threads: usize,
507) -> PyResult<MCMCSummary>
508where
509    P: LogDensity<MaybeThreadPool, LadduError> + LikelihoodTerm,
510{
511    let observers = extract_mcmc_observers(observers)?;
512    let terminators = extract_mcmc_terminators(terminators)?;
513    let method = normalize_method_name(&method);
514
515    match method.as_str() {
516        "aies" => {
517            let mut config = config
518                .map(|c| c.extract::<AIESConfig>())
519                .transpose()
520                .map_err(|err| PyTypeError::new_err(err.to_string()))?
521                .unwrap_or_default();
522            if config.get_parameter_names_mut().is_none() {
523                config = config.with_parameter_names(parameter_names);
524            }
525            let parsed_options = options
526                .map(|opt| opt.extract::<PyAIESOptions>())
527                .transpose()
528                .map_err(|err| PyTypeError::new_err(err.to_string()))?
529                .unwrap_or_default();
530            let init = if let Ok(init) = p0.extract::<AIESInit>() {
531                init
532            } else if let Ok(walkers) = p0.extract::<Vec<Vec<f64>>>() {
533                validate_mcmc_parameter_len(&walkers, n_free)?;
534                AIESInit::new(walkers.into_iter().map(DVector::from_vec).collect())
535                    .map_err(|err| PyValueError::new_err(err.to_string()))?
536            } else {
537                return Err(PyTypeError::new_err(
538                    "p0 for method 'aies' must be AIESInit-compatible or a walker matrix",
539                ));
540            };
541            let mut callbacks = parsed_options.build_callbacks();
542            for observer in observers {
543                callbacks = callbacks.with_observer(observer);
544            }
545            for terminator in terminators {
546                callbacks = callbacks.with_terminator(terminator);
547            }
548            callbacks = callbacks.with_terminator(CtrlCAbortSignal::new());
549            run_mcmc_algorithm::<AIES, _>(problem, threads, init, config, callbacks)
550                .map_err(PyErr::from)
551        }
552        "ess" => {
553            let mut config = config
554                .map(|c| c.extract::<ESSConfig>())
555                .transpose()
556                .map_err(|err| PyTypeError::new_err(err.to_string()))?
557                .unwrap_or_default();
558            if config.get_parameter_names_mut().is_none() {
559                config = config.with_parameter_names(parameter_names);
560            }
561            let parsed_options = options
562                .map(|opt| opt.extract::<PyESSOptions>())
563                .transpose()
564                .map_err(|err| PyTypeError::new_err(err.to_string()))?
565                .unwrap_or_default();
566            let init = if let Ok(init) = p0.extract::<ESSInit>() {
567                init
568            } else if let Ok(walkers) = p0.extract::<Vec<Vec<f64>>>() {
569                validate_mcmc_parameter_len(&walkers, n_free)?;
570                ESSInit::new(walkers.into_iter().map(DVector::from_vec).collect())
571                    .map_err(|err| PyValueError::new_err(err.to_string()))?
572            } else {
573                return Err(PyTypeError::new_err(
574                    "p0 for method 'ess' must be ESSInit-compatible or a walker matrix",
575                ));
576            };
577            let mut callbacks = parsed_options.build_callbacks();
578            for observer in observers {
579                callbacks = callbacks.with_observer(observer);
580            }
581            for terminator in terminators {
582                callbacks = callbacks.with_terminator(terminator);
583            }
584            callbacks = callbacks.with_terminator(CtrlCAbortSignal::new());
585            run_mcmc_algorithm::<ESS, _>(problem, threads, init, config, callbacks)
586                .map_err(PyErr::from)
587        }
588        _ => Err(PyValueError::new_err(format!(
589            "Invalid MCMC algorithm: {method}"
590        ))),
591    }
592}