1use ganesh::{
2 algorithms::{
3 gradient::{
4 Adam, AdamConfig, ConjugateGradient, ConjugateGradientConfig, GradientStatus,
5 LBFGSBConfig, TrustRegion, TrustRegionConfig, LBFGSB,
6 },
7 gradient_free::{
8 nelder_mead::NelderMeadInit, CMAESConfig, CMAESInit, DifferentialEvolution,
9 DifferentialEvolutionConfig, DifferentialEvolutionInit, GradientFreeStatus, NelderMead,
10 NelderMeadConfig, CMAES,
11 },
12 mcmc::{aies::AIESInit, ess::ESSInit, AIESConfig, ESSConfig, EnsembleStatus, AIES, ESS},
13 particles::{PSOConfig, SwarmStatus, PSO},
14 },
15 core::{Callbacks, CtrlCAbortSignal, MCMCSummary, MinimizationSummary},
16 python::{
17 PyAIESOptions, PyAdamOptions, PyCMAESOptions, PyConjugateGradientOptions,
18 PyDifferentialEvolutionOptions, PyESSOptions, PyLBFGSBOptions, PyNelderMeadOptions,
19 PyPSOOptions, PyTrustRegionOptions,
20 },
21 traits::{Algorithm, CostFunction, Gradient, LogDensity, Status, SupportsParameterNames},
22 DVector,
23};
24use laddu_core::{f64, validate_free_parameter_len, LadduError, LadduResult};
25use laddu_extensions::{
26 likelihood::LikelihoodTerm, optimize::MaybeThreadPool, LikelihoodTermObserver,
27};
28use pyo3::{
29 exceptions::{PyTypeError, PyValueError},
30 prelude::*,
31 types::PyList,
32 PyErr,
33};
34
35use crate::extensions::callbacks::{
36 MCMCObserver, MCMCTerminator, MinimizationObserver, MinimizationTerminator,
37};
38
39fn run_minimizer<A, P, S>(
40 problem: &P,
41 num_threads: usize,
42 init: A::Init,
43 config: A::Config,
44 callbacks: Callbacks<A, P, S, MaybeThreadPool, LadduError, A::Config>,
45) -> LadduResult<MinimizationSummary>
46where
47 A: Algorithm<P, S, MaybeThreadPool, LadduError, Summary = MinimizationSummary> + Default,
48 P: LikelihoodTerm,
49 S: Status,
50{
51 let mtp = MaybeThreadPool::new(num_threads);
52 A::default().process(
53 problem,
54 &mtp,
55 init,
56 config,
57 callbacks.with_observer(LikelihoodTermObserver),
58 )
59}
60
61fn run_mcmc_algorithm<A, P>(
62 problem: &P,
63 num_threads: usize,
64 init: A::Init,
65 config: A::Config,
66 callbacks: Callbacks<A, P, EnsembleStatus, MaybeThreadPool, LadduError, A::Config>,
67) -> LadduResult<MCMCSummary>
68where
69 A: Algorithm<P, EnsembleStatus, MaybeThreadPool, LadduError, Summary = MCMCSummary> + Default,
70 P: LikelihoodTerm,
71{
72 let mtp = MaybeThreadPool::new(num_threads);
73 A::default().process(
74 problem,
75 &mtp,
76 init,
77 config,
78 callbacks.with_observer(LikelihoodTermObserver),
79 )
80}
81
82fn validate_mcmc_parameter_len(walkers: &[Vec<f64>], expected_len: usize) -> PyResult<()> {
83 for walker in walkers {
84 validate_free_parameter_len(walker.len(), expected_len)
85 .map_err(|err| PyValueError::new_err(err.to_string()))?;
86 }
87 Ok(())
88}
89
/// Canonicalizes a user-supplied algorithm name so it can be matched against
/// the fixed method keys.
///
/// The name is lowercased, and every hyphen, underscore and whitespace
/// character is removed. For example, `"L-BFGS-B"` becomes `"lbfgsb"`,
/// `" Nelder Mead "` becomes `"neldermead"`, and `"trust_region"` becomes
/// `"trustregion"`.
///
/// Underscores are stripped as well as hyphens so that Python-style
/// identifiers (`"nelder_mead"`) are accepted alongside `"nelder-mead"`.
/// Every name that was accepted before still normalizes to the same key.
fn normalize_method_name(method: &str) -> String {
    method
        .to_lowercase()
        .chars()
        .filter(|&c| c != '-' && c != '_' && !c.is_whitespace())
        .collect()
}
97
98fn extract_minimization_observers(
99 observers: Option<Bound<'_, PyAny>>,
100) -> PyResult<Vec<MinimizationObserver>> {
101 if let Some(observers) = observers {
102 if let Ok(observers) = observers.cast::<PyList>() {
103 observers
104 .into_iter()
105 .map(|observer| {
106 observer.extract::<MinimizationObserver>().map_err(|_| {
107 PyValueError::new_err(
108 "The observers must be either a single MinimizationObserver or a list of MinimizationObservers.",
109 )
110 })
111 })
112 .collect()
113 } else if let Ok(observer) = observers.extract::<MinimizationObserver>() {
114 Ok(vec![observer])
115 } else {
116 Err(PyValueError::new_err(
117 "The observers must be either a single MinimizationObserver or a list of MinimizationObservers.",
118 ))
119 }
120 } else {
121 Ok(vec![])
122 }
123}
124
125fn extract_minimization_terminators(
126 terminators: Option<Bound<'_, PyAny>>,
127) -> PyResult<Vec<MinimizationTerminator>> {
128 if let Some(terminators) = terminators {
129 if let Ok(terminators) = terminators.cast::<PyList>() {
130 terminators
131 .into_iter()
132 .map(|terminator| {
133 terminator.extract::<MinimizationTerminator>().map_err(|_| {
134 PyValueError::new_err(
135 "The terminators must be either a single MinimizationTerminator or a list of MinimizationTerminators.",
136 )
137 })
138 })
139 .collect()
140 } else if let Ok(terminator) = terminators.extract::<MinimizationTerminator>() {
141 Ok(vec![terminator])
142 } else {
143 Err(PyValueError::new_err(
144 "The terminators must be either a single MinimizationTerminator or a list of MinimizationTerminators.",
145 ))
146 }
147 } else {
148 Ok(vec![])
149 }
150}
151
152fn extract_mcmc_observers(observers: Option<Bound<'_, PyAny>>) -> PyResult<Vec<MCMCObserver>> {
153 if let Some(observers) = observers {
154 if let Ok(observers) = observers.cast::<PyList>() {
155 observers
156 .into_iter()
157 .map(|observer| {
158 observer.extract::<MCMCObserver>().map_err(|_| {
159 PyValueError::new_err(
160 "The observers must be either a single MCMCObserver or a list of MCMCObservers.",
161 )
162 })
163 })
164 .collect()
165 } else if let Ok(observer) = observers.extract::<MCMCObserver>() {
166 Ok(vec![observer])
167 } else {
168 Err(PyValueError::new_err(
169 "The observers must be either a single MCMCObserver or a list of MCMCObservers.",
170 ))
171 }
172 } else {
173 Ok(vec![])
174 }
175}
176
177fn extract_mcmc_terminators(
178 terminators: Option<Bound<'_, PyAny>>,
179) -> PyResult<Vec<MCMCTerminator>> {
180 if let Some(terminators) = terminators {
181 if let Ok(terminators) = terminators.cast::<PyList>() {
182 terminators
183 .into_iter()
184 .map(|terminator| terminator.extract::<MCMCTerminator>().map_err(|_| {
185 PyValueError::new_err(
186 "The terminators must be either a single MCMCTerminator or a list of MCMCTerminators.",
187 )
188 }))
189 .collect()
190 } else if let Ok(terminator) = terminators.extract::<MCMCTerminator>() {
191 Ok(vec![terminator])
192 } else {
193 Err(PyValueError::new_err(
194 "The terminators must be either a single MCMCTerminator or a list of MCMCTerminators.",
195 ))
196 }
197 } else {
198 Ok(vec![])
199 }
200}
201
202#[allow(clippy::too_many_arguments)]
203pub(crate) fn minimize_from_python<P>(
204 problem: &P,
205 p0: &Bound<'_, PyAny>,
206 n_free: usize,
207 parameter_names: &[String],
208 method: String,
209 config: Option<&Bound<'_, PyAny>>,
210 options: Option<&Bound<'_, PyAny>>,
211 observers: Option<Bound<'_, PyAny>>,
212 terminators: Option<Bound<'_, PyAny>>,
213 threads: usize,
214) -> PyResult<MinimizationSummary>
215where
216 P: Gradient<MaybeThreadPool, LadduError>
217 + CostFunction<MaybeThreadPool, LadduError>
218 + LikelihoodTerm,
219{
220 let observers = extract_minimization_observers(observers)?;
221 let terminators = extract_minimization_terminators(terminators)?;
222 let method = normalize_method_name(&method);
223
224 match method.as_str() {
225 "lbfgsb" => {
226 let init = p0.extract::<Vec<f64>>()?;
227 validate_free_parameter_len(init.len(), n_free)?;
228 let mut config = config
229 .map(|c| c.extract::<LBFGSBConfig>())
230 .transpose()
231 .map_err(|err| PyTypeError::new_err(err.to_string()))?
232 .unwrap_or_default();
233 if config.get_parameter_names_mut().is_none() {
234 config = config.with_parameter_names(parameter_names);
235 }
236 let parsed_options = options
237 .map(|opt| opt.extract::<PyLBFGSBOptions>())
238 .transpose()
239 .map_err(|err| PyTypeError::new_err(err.to_string()))?
240 .unwrap_or_default();
241 let mut callbacks = parsed_options
242 .build_callbacks()
243 .map_err(|err| PyValueError::new_err(err.to_string()))?;
244 for observer in observers {
245 callbacks = callbacks.with_observer(observer);
246 }
247 for terminator in terminators {
248 callbacks = callbacks.with_terminator(terminator);
249 }
250 callbacks = callbacks.with_terminator(CtrlCAbortSignal::new());
251 run_minimizer::<LBFGSB, _, GradientStatus>(
252 problem,
253 threads,
254 DVector::from_vec(init),
255 config,
256 callbacks,
257 )
258 .map_err(PyErr::from)
259 }
260 "adam" => {
261 let init = p0.extract::<Vec<f64>>()?;
262 validate_free_parameter_len(init.len(), n_free)?;
263 let mut config = config
264 .map(|c| c.extract::<AdamConfig>())
265 .transpose()
266 .map_err(|err| PyTypeError::new_err(err.to_string()))?
267 .unwrap_or_default();
268 if config.get_parameter_names_mut().is_none() {
269 config = config.with_parameter_names(parameter_names);
270 }
271 let parsed_options = options
272 .map(|opt| opt.extract::<PyAdamOptions>())
273 .transpose()
274 .map_err(|err| PyTypeError::new_err(err.to_string()))?
275 .unwrap_or_default();
276 let mut callbacks = parsed_options.build_callbacks();
277 for observer in observers {
278 callbacks = callbacks.with_observer(observer);
279 }
280 for terminator in terminators {
281 callbacks = callbacks.with_terminator(terminator);
282 }
283 callbacks = callbacks.with_terminator(CtrlCAbortSignal::new());
284 run_minimizer::<Adam, _, GradientStatus>(
285 problem,
286 threads,
287 DVector::from_vec(init),
288 config,
289 callbacks,
290 )
291 .map_err(PyErr::from)
292 }
293 "conjugategradient" => {
294 let init = p0.extract::<Vec<f64>>()?;
295 validate_free_parameter_len(init.len(), n_free)?;
296 let mut config = config
297 .map(|c| c.extract::<ConjugateGradientConfig>())
298 .transpose()
299 .map_err(|err| PyTypeError::new_err(err.to_string()))?
300 .unwrap_or_default();
301 if config.get_parameter_names_mut().is_none() {
302 config = config.with_parameter_names(parameter_names);
303 }
304 let parsed_options = options
305 .map(|opt| opt.extract::<PyConjugateGradientOptions>())
306 .transpose()
307 .map_err(|err| PyTypeError::new_err(err.to_string()))?
308 .unwrap_or_default();
309 let mut callbacks = parsed_options.build_callbacks()?;
310 for observer in observers {
311 callbacks = callbacks.with_observer(observer);
312 }
313 for terminator in terminators {
314 callbacks = callbacks.with_terminator(terminator);
315 }
316 callbacks = callbacks.with_terminator(CtrlCAbortSignal::new());
317 run_minimizer::<ConjugateGradient, _, GradientStatus>(
318 problem,
319 threads,
320 DVector::from_vec(init),
321 config,
322 callbacks,
323 )
324 .map_err(PyErr::from)
325 }
326 "trustregion" => {
327 let init = p0.extract::<Vec<f64>>()?;
328 validate_free_parameter_len(init.len(), n_free)?;
329 let mut config = config
330 .map(|c| c.extract::<TrustRegionConfig>())
331 .transpose()
332 .map_err(|err| PyTypeError::new_err(err.to_string()))?
333 .unwrap_or_default();
334 if config.get_parameter_names_mut().is_none() {
335 config = config.with_parameter_names(parameter_names);
336 }
337 let parsed_options = options
338 .map(|opt| opt.extract::<PyTrustRegionOptions>())
339 .transpose()
340 .map_err(|err| PyTypeError::new_err(err.to_string()))?
341 .unwrap_or_default();
342 let mut callbacks = parsed_options.build_callbacks()?;
343 for observer in observers {
344 callbacks = callbacks.with_observer(observer);
345 }
346 for terminator in terminators {
347 callbacks = callbacks.with_terminator(terminator);
348 }
349 callbacks = callbacks.with_terminator(CtrlCAbortSignal::new());
350 run_minimizer::<TrustRegion, _, GradientStatus>(
351 problem,
352 threads,
353 DVector::from_vec(init),
354 config,
355 callbacks,
356 )
357 .map_err(PyErr::from)
358 }
359 "cmaes" => {
360 let init = p0.extract::<CMAESInit>()?;
361 let mut config = config
362 .map(|c| c.extract::<CMAESConfig>())
363 .transpose()
364 .map_err(|err| PyTypeError::new_err(err.to_string()))?
365 .unwrap_or_default();
366 if config.get_parameter_names_mut().is_none() {
367 config = config.with_parameter_names(parameter_names);
368 }
369 let parsed_options = options
370 .map(|opt| opt.extract::<PyCMAESOptions>())
371 .transpose()
372 .map_err(|err| PyTypeError::new_err(err.to_string()))?
373 .unwrap_or_default();
374 let mut callbacks = parsed_options.build_callbacks();
375 for observer in observers {
376 callbacks = callbacks.with_observer(observer);
377 }
378 for terminator in terminators {
379 callbacks = callbacks.with_terminator(terminator);
380 }
381 callbacks = callbacks.with_terminator(CtrlCAbortSignal::new());
382 run_minimizer::<CMAES, _, GradientFreeStatus>(problem, threads, init, config, callbacks)
383 .map_err(PyErr::from)
384 }
385 "differential-evolution" => {
386 let init = p0.extract::<DifferentialEvolutionInit>()?;
387 let mut config = config
388 .map(|c| c.extract::<DifferentialEvolutionConfig>())
389 .transpose()
390 .map_err(|err| PyTypeError::new_err(err.to_string()))?
391 .unwrap_or_default();
392 if config.get_parameter_names_mut().is_none() {
393 config = config.with_parameter_names(parameter_names);
394 }
395 let parsed_options = options
396 .map(|opt| opt.extract::<PyDifferentialEvolutionOptions>())
397 .transpose()
398 .map_err(|err| PyTypeError::new_err(err.to_string()))?
399 .unwrap_or_default();
400 let mut callbacks = parsed_options.build_callbacks();
401 for observer in observers {
402 callbacks = callbacks.with_observer(observer);
403 }
404 for terminator in terminators {
405 callbacks = callbacks.with_terminator(terminator);
406 }
407 callbacks = callbacks.with_terminator(CtrlCAbortSignal::new());
408 run_minimizer::<DifferentialEvolution, _, GradientFreeStatus>(
409 problem, threads, init, config, callbacks,
410 )
411 .map_err(PyErr::from)
412 }
413 "neldermead" => {
414 let init = if let Ok(init) = p0.extract::<NelderMeadInit>() {
415 init
416 } else if let Ok(init) = p0.extract::<Vec<f64>>() {
417 validate_free_parameter_len(init.len(), n_free)?;
418 NelderMeadInit::new(init)
419 } else {
420 return Err(PyTypeError::new_err(
421 "p0 for method 'nelder-mead' must be NelderMeadInit-compatible or a point",
422 ));
423 };
424 let mut config = config
425 .map(|c| c.extract::<NelderMeadConfig>())
426 .transpose()
427 .map_err(|err| PyTypeError::new_err(err.to_string()))?
428 .unwrap_or_default();
429 if config.get_parameter_names_mut().is_none() {
430 config = config.with_parameter_names(parameter_names);
431 }
432 let parsed_options = options
433 .map(|opt| opt.extract::<PyNelderMeadOptions>())
434 .transpose()
435 .map_err(|err| PyTypeError::new_err(err.to_string()))?
436 .unwrap_or_default();
437 let mut callbacks = parsed_options.build_callbacks();
438 for observer in observers {
439 callbacks = callbacks.with_observer(observer);
440 }
441 for terminator in terminators {
442 callbacks = callbacks.with_terminator(terminator);
443 }
444 callbacks = callbacks.with_terminator(CtrlCAbortSignal::new());
445 run_minimizer::<NelderMead, _, GradientFreeStatus>(
446 problem, threads, init, config, callbacks,
447 )
448 .map_err(PyErr::from)
449 }
450 "pso" => {
451 let init = if let Ok(init) = p0.extract::<ganesh::algorithms::particles::Swarm>() {
452 init
453 } else if let Ok(positions) = p0.extract::<Vec<Vec<f64>>>() {
454 validate_mcmc_parameter_len(&positions, n_free)?;
455 ganesh::algorithms::particles::Swarm::new(
456 ganesh::algorithms::particles::SwarmPositionInitializer::Custom(
457 positions.into_iter().map(DVector::from_vec).collect(),
458 ),
459 )
460 } else {
461 return Err(PyTypeError::new_err(
462 "p0 for method 'pso' must be PSOInit-compatible or a position matrix",
463 ));
464 };
465 let mut config = config
466 .map(|c| c.extract::<PSOConfig>())
467 .transpose()
468 .map_err(|err| PyTypeError::new_err(err.to_string()))?
469 .unwrap_or_default();
470 if config.get_parameter_names_mut().is_none() {
471 config = config.with_parameter_names(parameter_names);
472 }
473 let parsed_options = options
474 .map(|opt| opt.extract::<PyPSOOptions>())
475 .transpose()
476 .map_err(|err| PyTypeError::new_err(err.to_string()))?
477 .unwrap_or_default();
478 let mut callbacks = parsed_options.build_callbacks();
479 for observer in observers {
480 callbacks = callbacks.with_observer(observer);
481 }
482 for terminator in terminators {
483 callbacks = callbacks.with_terminator(terminator);
484 }
485 callbacks = callbacks.with_terminator(CtrlCAbortSignal::new());
486 run_minimizer::<PSO, _, SwarmStatus>(problem, threads, init, config, callbacks)
487 .map_err(PyErr::from)
488 }
489 _ => Err(PyValueError::new_err(format!(
490 "Invalid minimizer: {method}"
491 ))),
492 }
493}
494
495#[allow(clippy::too_many_arguments)]
496pub(crate) fn mcmc_from_python<P>(
497 problem: &P,
498 p0: &Bound<'_, PyAny>,
499 n_free: usize,
500 parameter_names: &[String],
501 method: String,
502 config: Option<&Bound<'_, PyAny>>,
503 options: Option<&Bound<'_, PyAny>>,
504 observers: Option<Bound<'_, PyAny>>,
505 terminators: Option<Bound<'_, PyAny>>,
506 threads: usize,
507) -> PyResult<MCMCSummary>
508where
509 P: LogDensity<MaybeThreadPool, LadduError> + LikelihoodTerm,
510{
511 let observers = extract_mcmc_observers(observers)?;
512 let terminators = extract_mcmc_terminators(terminators)?;
513 let method = normalize_method_name(&method);
514
515 match method.as_str() {
516 "aies" => {
517 let mut config = config
518 .map(|c| c.extract::<AIESConfig>())
519 .transpose()
520 .map_err(|err| PyTypeError::new_err(err.to_string()))?
521 .unwrap_or_default();
522 if config.get_parameter_names_mut().is_none() {
523 config = config.with_parameter_names(parameter_names);
524 }
525 let parsed_options = options
526 .map(|opt| opt.extract::<PyAIESOptions>())
527 .transpose()
528 .map_err(|err| PyTypeError::new_err(err.to_string()))?
529 .unwrap_or_default();
530 let init = if let Ok(init) = p0.extract::<AIESInit>() {
531 init
532 } else if let Ok(walkers) = p0.extract::<Vec<Vec<f64>>>() {
533 validate_mcmc_parameter_len(&walkers, n_free)?;
534 AIESInit::new(walkers.into_iter().map(DVector::from_vec).collect())
535 .map_err(|err| PyValueError::new_err(err.to_string()))?
536 } else {
537 return Err(PyTypeError::new_err(
538 "p0 for method 'aies' must be AIESInit-compatible or a walker matrix",
539 ));
540 };
541 let mut callbacks = parsed_options.build_callbacks();
542 for observer in observers {
543 callbacks = callbacks.with_observer(observer);
544 }
545 for terminator in terminators {
546 callbacks = callbacks.with_terminator(terminator);
547 }
548 callbacks = callbacks.with_terminator(CtrlCAbortSignal::new());
549 run_mcmc_algorithm::<AIES, _>(problem, threads, init, config, callbacks)
550 .map_err(PyErr::from)
551 }
552 "ess" => {
553 let mut config = config
554 .map(|c| c.extract::<ESSConfig>())
555 .transpose()
556 .map_err(|err| PyTypeError::new_err(err.to_string()))?
557 .unwrap_or_default();
558 if config.get_parameter_names_mut().is_none() {
559 config = config.with_parameter_names(parameter_names);
560 }
561 let parsed_options = options
562 .map(|opt| opt.extract::<PyESSOptions>())
563 .transpose()
564 .map_err(|err| PyTypeError::new_err(err.to_string()))?
565 .unwrap_or_default();
566 let init = if let Ok(init) = p0.extract::<ESSInit>() {
567 init
568 } else if let Ok(walkers) = p0.extract::<Vec<Vec<f64>>>() {
569 validate_mcmc_parameter_len(&walkers, n_free)?;
570 ESSInit::new(walkers.into_iter().map(DVector::from_vec).collect())
571 .map_err(|err| PyValueError::new_err(err.to_string()))?
572 } else {
573 return Err(PyTypeError::new_err(
574 "p0 for method 'ess' must be ESSInit-compatible or a walker matrix",
575 ));
576 };
577 let mut callbacks = parsed_options.build_callbacks();
578 for observer in observers {
579 callbacks = callbacks.with_observer(observer);
580 }
581 for terminator in terminators {
582 callbacks = callbacks.with_terminator(terminator);
583 }
584 callbacks = callbacks.with_terminator(CtrlCAbortSignal::new());
585 run_mcmc_algorithm::<ESS, _>(problem, threads, init, config, callbacks)
586 .map_err(PyErr::from)
587 }
588 _ => Err(PyValueError::new_err(format!(
589 "Invalid MCMC algorithm: {method}"
590 ))),
591 }
592}