Skip to main content

laddu_python/extensions/
likelihood.rs

1use std::collections::HashMap;
2
3use ganesh::python::IntoPySummary;
4use laddu_core::{validate_free_parameter_len, LadduError};
5use laddu_extensions::{
6    likelihood::{LikelihoodTerm, StochasticNLL},
7    LikelihoodExpression, LikelihoodScalar, NLL,
8};
9use numpy::{PyArray1, PyArray2, PyArray3};
10use pyo3::{
11    exceptions::{PyTypeError, PyValueError},
12    prelude::*,
13    types::{PyAny, PyList},
14    IntoPyObjectExt,
15};
16
17use crate::{
18    amplitudes::{PyCompiledExpression, PyEvaluator, PyExpression, PyParameterMap},
19    data::PyDataset,
20    extensions::{
21        install_laddu_with_threads,
22        optimize::{mcmc_from_python, minimize_from_python},
23    },
24};
25
26#[cfg_attr(coverage_nightly, coverage(off))]
27fn extract_subset_names(subset: Option<Bound<'_, PyAny>>) -> PyResult<Option<Vec<String>>> {
28    let Some(subset) = subset else {
29        return Ok(None);
30    };
31    if let Ok(string_arg) = subset.extract::<String>() {
32        Ok(Some(vec![string_arg]))
33    } else if let Ok(list_arg) = subset.extract::<Vec<String>>() {
34        Ok(Some(list_arg))
35    } else {
36        Err(PyTypeError::new_err(
37            "subset must be either a string or a list of strings",
38        ))
39    }
40}
41
42#[cfg_attr(coverage_nightly, coverage(off))]
43fn extract_subsets_arg(
44    subsets: Option<Bound<'_, PyAny>>,
45) -> PyResult<Option<Vec<Option<Vec<String>>>>> {
46    let Some(subsets) = subsets else {
47        return Ok(None);
48    };
49    subsets
50        .extract::<Vec<Option<Vec<String>>>>()
51        .map(Some)
52        .map_err(|_| {
53            PyTypeError::new_err(
54                "subsets must be a list whose items are either None or lists of strings",
55            )
56        })
57}
58
/// Python wrapper for [`LikelihoodExpression`].
///
/// Exposed to Python as ``laddu.LikelihoodExpression``. Values of this type are produced by
/// ``likelihood_sum``, ``likelihood_product``, ``LikelihoodZero``, ``LikelihoodOne``,
/// ``NLL.to_expression``, and by ``+`` / ``*`` between expressions. Functions in this module
/// obtain owned copies via ``extract::<PyLikelihoodExpression>()``, which is why the type
/// derives ``Clone`` (together with the ``from_py_object`` class option).
#[pyclass(name = "LikelihoodExpression", module = "laddu", from_py_object)]
#[derive(Clone)]
pub struct PyLikelihoodExpression(pub LikelihoodExpression);
63
64/// A convenience method to sum sequences of [`LikelihoodExpression`]s or identifiers.
65///
66/// Parameters
67/// ----------
68/// terms : sequence of LikelihoodExpression
69///     A non-empty sequence whose elements are summed. Single-element sequences are returned
70///     unchanged while empty sequences evaluate to ``LikelihoodZero``.
71///
72/// Returns
73/// -------
74/// LikelihoodExpression
75///     A new expression representing the sum of all inputs.
76///
77/// See Also
78/// --------
79/// likelihood_product
80/// LikelihoodZero
81///
82/// Examples
83/// --------
84/// >>> from laddu import LikelihoodScalar, likelihood_sum
85/// >>> expression = likelihood_sum([LikelihoodScalar('alpha')])
86/// >>> expression.evaluate([0.5])
87/// 0.5
88/// >>> likelihood_sum([]).evaluate([])
89/// 0.0
90///
91/// Notes
92/// -----
93/// When multiple inputs share the same parameter name, the value and fixed/free status from the
94/// earliest term in the sequence take precedence.
95#[cfg_attr(coverage_nightly, coverage(off))]
96#[pyfunction(name = "likelihood_sum")]
97pub fn py_likelihood_sum(terms: Vec<Bound<'_, PyAny>>) -> PyResult<PyLikelihoodExpression> {
98    if terms.is_empty() {
99        return Ok(PyLikelihoodExpression(LikelihoodExpression::zero()));
100    }
101    if terms.len() == 1 {
102        let term = &terms[0];
103        if let Ok(expression) = term.extract::<PyLikelihoodExpression>() {
104            return Ok(expression);
105        }
106        return Err(PyTypeError::new_err("Item is not a PyLikelihoodExpression"));
107    }
108    let mut iter = terms.iter();
109    let Some(first_term) = iter.next() else {
110        return Ok(PyLikelihoodExpression(LikelihoodExpression::zero()));
111    };
112    let PyLikelihoodExpression(mut summation) = first_term
113        .extract::<PyLikelihoodExpression>()
114        .map_err(|_| PyTypeError::new_err("Elements must be PyLikelihoodExpression"))?;
115    for term in iter {
116        let PyLikelihoodExpression(expr) = term
117            .extract::<PyLikelihoodExpression>()
118            .map_err(|_| PyTypeError::new_err("Elements must be PyLikelihoodExpression"))?;
119        summation = summation + expr;
120    }
121    Ok(PyLikelihoodExpression(summation))
122}
123
124/// A convenience method to multiply sequences of [`LikelihoodExpression`]s.
125///
126/// Parameters
127/// ----------
128/// terms : sequence of LikelihoodExpression
129///     A non-empty sequence whose elements are multiplied. Single-element sequences are returned
130///     unchanged while empty sequences evaluate to ``LikelihoodOne``.
131///
132/// Returns
133/// -------
134/// LikelihoodExpression
135///     A new expression representing the product of all inputs.
136///
137/// See Also
138/// --------
139/// likelihood_sum
140/// LikelihoodOne
141///
142/// Examples
143/// --------
144/// >>> from laddu import LikelihoodScalar, likelihood_product
145/// >>> expression = likelihood_product([LikelihoodScalar('alpha'), LikelihoodScalar('beta')])
146/// >>> expression.parameters
147/// ['alpha', 'beta']
148/// >>> expression.evaluate([2.0, 3.0])
149/// 6.0
150///
151/// Notes
152/// -----
153/// When parameters overlap between inputs, the parameter definition from the earliest term is used.
154#[cfg_attr(coverage_nightly, coverage(off))]
155#[pyfunction(name = "likelihood_product")]
156pub fn py_likelihood_product(terms: Vec<Bound<'_, PyAny>>) -> PyResult<PyLikelihoodExpression> {
157    if terms.is_empty() {
158        return Ok(PyLikelihoodExpression(LikelihoodExpression::one()));
159    }
160    if terms.len() == 1 {
161        let term = &terms[0];
162        if let Ok(expression) = term.extract::<PyLikelihoodExpression>() {
163            return Ok(expression);
164        }
165        return Err(PyTypeError::new_err("Item is not a PyLikelihoodExpression"));
166    }
167    let mut iter = terms.iter();
168    let Some(first_term) = iter.next() else {
169        return Ok(PyLikelihoodExpression(LikelihoodExpression::one()));
170    };
171    let PyLikelihoodExpression(mut product) = first_term
172        .extract::<PyLikelihoodExpression>()
173        .map_err(|_| PyTypeError::new_err("Elements must be PyLikelihoodExpression"))?;
174    for term in iter {
175        let PyLikelihoodExpression(expr) = term
176            .extract::<PyLikelihoodExpression>()
177            .map_err(|_| PyTypeError::new_err("Elements must be PyLikelihoodExpression"))?;
178        product = product * expr;
179    }
180    Ok(PyLikelihoodExpression(product))
181}
182
183/// A convenience constructor for a zero-valued [`LikelihoodExpression`].
184///
185/// Returns
186/// -------
187/// LikelihoodExpression
188///     An expression that evaluates to ``0`` for any parameter values.
189///
190/// See Also
191/// --------
192/// LikelihoodOne
193/// likelihood_sum
194///
195/// Examples
196/// --------
197/// >>> from laddu import LikelihoodZero
198/// >>> expression = LikelihoodZero()
199/// >>> expression.parameters
200/// []
201/// >>> expression.evaluate([])
202/// 0.0
203#[cfg_attr(coverage_nightly, coverage(off))]
204#[pyfunction(name = "LikelihoodZero")]
205pub fn py_likelihood_zero() -> PyLikelihoodExpression {
206    PyLikelihoodExpression(LikelihoodExpression::zero())
207}
208
209/// A convenience constructor for a unit-valued [`LikelihoodExpression`].
210///
211/// Returns
212/// -------
213/// LikelihoodExpression
214///     An expression that evaluates to ``1`` for any parameter values.
215///
216/// See Also
217/// --------
218/// LikelihoodZero
219/// likelihood_product
220///
221/// Examples
222/// --------
223/// >>> from laddu import LikelihoodOne
224/// >>> LikelihoodOne().evaluate([])
225/// 1.0
226#[cfg_attr(coverage_nightly, coverage(off))]
227#[pyfunction(name = "LikelihoodOne")]
228pub fn py_likelihood_one() -> PyLikelihoodExpression {
229    PyLikelihoodExpression(LikelihoodExpression::one())
230}
231
232#[cfg_attr(coverage_nightly, coverage(off))]
233#[pymethods]
234impl PyLikelihoodExpression {
235    /// Parameters referenced by the expression.
236    #[getter]
237    fn parameters(&self) -> PyParameterMap {
238        PyParameterMap(self.0.parameters())
239    }
240
241    /// Fix a parameter to a constant value.
242    ///
243    /// Parameters
244    /// ----------
245    /// name : str
246    ///     Name of the parameter.
247    /// value : float
248    ///     Value used during evaluation.
249    ///
250    fn fix_parameter(&self, name: &str, value: f64) -> PyResult<()> {
251        Ok(self.0.fix_parameter(name, value)?)
252    }
253
254    /// Free a parameter that was previously fixed.
255    ///
256    /// Parameters
257    /// ----------
258    /// name : str
259    ///     Name of the parameter.
260    ///
261    fn free_parameter(&self, name: &str) -> PyResult<()> {
262        Ok(self.0.free_parameter(name)?)
263    }
264
265    /// Rename a parameter.
266    ///
267    /// Parameters
268    /// ----------
269    /// old : str
270    ///     Current parameter name.
271    /// new : str
272    ///     Desired parameter name.
273    ///
274    fn rename_parameter(&self, old: &str, new: &str) -> PyResult<()> {
275        Ok(self.0.rename_parameter(old, new)?)
276    }
277
278    /// Rename multiple parameters at once.
279    ///
280    /// Parameters
281    /// ----------
282    /// mapping : dict[str, str]
283    ///     Mapping from old names to new names.
284    ///
285    fn rename_parameters(&self, mapping: HashMap<String, String>) -> PyResult<()> {
286        Ok(self.0.rename_parameters(&mapping)?)
287    }
288    fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyLikelihoodExpression> {
289        if let Ok(other_expr) = other.extract::<PyLikelihoodExpression>() {
290            Ok(PyLikelihoodExpression(
291                self.0.clone() + other_expr.0.clone(),
292            ))
293        } else if let Ok(int) = other.extract::<usize>() {
294            if int == 0 {
295                Ok(PyLikelihoodExpression(self.0.clone()))
296            } else {
297                Err(PyTypeError::new_err(
298                    "Addition with an integer for this type is only defined for 0",
299                ))
300            }
301        } else {
302            Err(PyTypeError::new_err("Unsupported operand type for +"))
303        }
304    }
305    fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyLikelihoodExpression> {
306        if let Ok(other_expr) = other.extract::<PyLikelihoodExpression>() {
307            Ok(PyLikelihoodExpression(
308                other_expr.0.clone() + self.0.clone(),
309            ))
310        } else if let Ok(int) = other.extract::<usize>() {
311            if int == 0 {
312                Ok(PyLikelihoodExpression(self.0.clone()))
313            } else {
314                Err(PyTypeError::new_err(
315                    "Addition with an integer for this type is only defined for 0",
316                ))
317            }
318        } else {
319            Err(PyTypeError::new_err("Unsupported operand type for +"))
320        }
321    }
322    fn __mul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyLikelihoodExpression> {
323        if let Ok(other_expr) = other.extract::<PyLikelihoodExpression>() {
324            Ok(PyLikelihoodExpression(
325                self.0.clone() * other_expr.0.clone(),
326            ))
327        } else {
328            Err(PyTypeError::new_err("Unsupported operand type for *"))
329        }
330    }
331    fn __rmul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyLikelihoodExpression> {
332        if let Ok(other_expr) = other.extract::<PyLikelihoodExpression>() {
333            Ok(PyLikelihoodExpression(
334                other_expr.0.clone() * self.0.clone(),
335            ))
336        } else {
337            Err(PyTypeError::new_err("Unsupported operand type for *"))
338        }
339    }
340    fn __str__(&self) -> String {
341        format!("{}", self.0)
342    }
343    fn __repr__(&self) -> String {
344        format!("{:?}", self.0)
345    }
346
347    /// Number of free parameters in the expression.
348    #[getter]
349    fn n_free(&self) -> usize {
350        self.0.n_free()
351    }
352
353    /// Number of fixed parameters in the expression.
354    #[getter]
355    fn n_fixed(&self) -> usize {
356        self.0.n_fixed()
357    }
358
359    /// Total number of parameters (free + fixed).
360    #[getter]
361    fn n_parameters(&self) -> usize {
362        self.0.n_parameters()
363    }
364
365    /// Evaluate the sum of all terms in the expression.
366    ///
367    /// Parameters
368    /// ----------
369    /// parameters : list of float
370    ///     Parameter values for the free parameters (length ``n_free``).
371    /// threads : int, optional
372    ///     The number of threads to use (setting this to ``None`` or ``0`` uses the current
373    ///     global or context-managed default; any positive value overrides that default for
374    ///     this call only)
375    ///
376    /// Returns
377    /// -------
378    /// result : float
379    ///     The total negative log-likelihood summed over all terms
380    ///
381    /// Raises
382    /// ------
383    /// Exception
384    ///     If there was an error building the thread pool
385    ///
386    #[pyo3(signature = (parameters, *, threads=None))]
387    fn evaluate(&self, parameters: Vec<f64>, threads: Option<usize>) -> PyResult<f64> {
388        validate_free_parameter_len(parameters.len(), self.0.n_free())?;
389        install_laddu_with_threads(threads, || self.0.evaluate(&parameters)).map_err(PyErr::from)
390    }
391    /// Evaluate the gradient of the sum of all terms in the expression.
392    ///
393    /// Parameters
394    /// ----------
395    /// parameters : list of float
396    ///     Parameter values for the free parameters (length ``n_free``).
397    /// threads : int, optional
398    ///     The number of threads to use (setting this to ``None`` or ``0`` uses the current
399    ///     global or context-managed default; any positive value overrides that default for
400    ///     this call only)
401    ///
402    /// Returns
403    /// -------
404    /// result : array_like
405    ///     A ``numpy`` array representing the gradient of the sum of all terms in the
406    ///     evaluator with length ``n_free``.
407    ///
408    /// Raises
409    /// ------
410    /// Exception
411    ///     If there was an error building the thread pool or problem creating the resulting
412    ///     ``numpy`` array
413    ///
414    #[pyo3(signature = (parameters, *, threads=None))]
415    fn evaluate_gradient<'py>(
416        &self,
417        py: Python<'py>,
418        parameters: Vec<f64>,
419        threads: Option<usize>,
420    ) -> PyResult<Bound<'py, PyArray1<f64>>> {
421        validate_free_parameter_len(parameters.len(), self.0.n_free())?;
422        let gradient =
423            install_laddu_with_threads(threads, || self.0.evaluate_gradient(&parameters))?;
424        Ok(PyArray1::from_slice(py, gradient.as_slice()))
425    }
426    #[cfg_attr(doctest, doc = "```ignore")]
427    /// Minimize the LikelihoodTerm with respect to the free parameters in the model
428    ///
429    /// This method "runs the fit". Given an initial position `p0`, this
430    /// method performs a minimization over the likelihood term, optimizing the model
431    /// over the stored signal data and Monte Carlo.
432    ///
433    /// Parameters
434    /// ----------
435    /// p0 : array_like or ganesh.NelderMeadInit or ganesh.PSOInit or ganesh.CMAESInit or ganesh.DifferentialEvolutionInit
436    ///     Initial state for the selected minimizer. Use a length-``n_free`` vector for
437    ///     ``lbfgsb``, ``adam``, ``conjugate-gradient``, and ``trust-region``; either a
438    ///     vector or ``ganesh.NelderMeadInit`` for ``nelder-mead``; a ``ganesh.CMAESInit``
439    ///     for ``cma-es``; a ``ganesh.DifferentialEvolutionInit`` for ``differential-evolution``,
440    ///     and either a 2D swarm array or ``ganesh.PSOInit`` for ``pso``.
441    /// method : {'lbfgsb', 'adam', 'conjugate-gradient', 'trust-region', 'nelder-mead', 'cma-es', 'differential-evolution', 'pso'}
442    ///     The minimization algorithm to use
443    /// config : ganesh config object, optional
444    ///     Method-specific Ganesh configuration, such as ``ganesh.LBFGSBConfig`` or
445    ///     ``ganesh.PSOConfig``. Bounds are configured here when supported by the method.
446    /// options : ganesh options object, optional
447    ///     Method-specific Ganesh options object controlling step limits, built-in observers,
448    ///     and built-in terminators.
449    /// observers : MinimizerObserver or list of MinimizerObserver, optional
450    ///     User-defined observers which are called at each step
451    /// terminators : MinimizerTerminator or list of MinimizerTerminator, optional
452    ///     User-defined terminators which are called at each step
453    /// threads : int, default=0
454    ///     The number of threads to use (setting this to ``0`` uses the current global or
455    ///     context-managed default; any positive value overrides that default for this call
456    ///     only)
457    ///
458    /// Returns
459    /// -------
460    /// MinimizationSummary
461    ///     A summary of the minimization algorithm at termination
462    ///
463    /// Raises
464    /// ------
465    /// Exception
466    ///     If there was an error building the thread pool
467    ///
468    /// Examples
469    /// --------
470    /// >>> import ganesh
471    /// >>> expression.minimize(
472    /// ...     [1.0],
473    /// ...     method='lbfgsb',
474    /// ...     options=ganesh.LBFGSBOptions(max_steps=150),
475    /// ... )  # doctest: +SKIP
476    ///
477    /// Notes
478    /// -----
479    /// ``config`` and ``options`` use Ganesh's Python API directly. For example, pass
480    /// ``ganesh.LBFGSBConfig(bounds=[...])`` for bounded L-BFGS-B, or
481    /// ``ganesh.AdamOptions(max_steps=500)`` to cap the number of Adam iterations.
482    ///
483    /// References
484    /// ----------
485    /// Gao, F. & Han, L. (2010). *Implementing the Nelder-Mead simplex algorithm with adaptive
486    /// parameters*. Comput. Optim. Appl. 51(1), 259–277. <https://doi.org/10.1007/s10589-010-9329-3>
487    ///
488    /// Lagarias, J. C., Reeds, J. A., Wright, M. H., & Wright, P. E. (1998). *Convergence Properties
489    /// of the Nelder–Mead Simplex Method in Low Dimensions*. SIAM J. Optim. 9(1), 112–147.
490    /// <https://doi.org/10.1137/S1052623496303470>
491    ///
492    /// Singer, S. & Singer, S. (2004). *Efficient Implementation of the Nelder–Mead Search Algorithm*.
493    /// Appl. Numer. Anal. & Comput. 1(2), 524–534. <https://doi.org/10.1002/anac.200410015>
494    ///
495    #[cfg_attr(doctest, doc = "```")]
496    #[pyo3(signature = (p0, *, method="lbfgsb".to_string(), config=None, options=None, observers=None, terminators=None, threads=0))]
497    #[allow(clippy::too_many_arguments)]
498    fn minimize<'py>(
499        &self,
500        py: Python<'py>,
501        p0: Bound<'_, PyAny>,
502        method: String,
503        config: Option<Bound<'_, PyAny>>,
504        options: Option<Bound<'_, PyAny>>,
505        observers: Option<Bound<'_, PyAny>>,
506        terminators: Option<Bound<'_, PyAny>>,
507        threads: usize,
508    ) -> PyResult<Bound<'py, PyAny>> {
509        let parameter_names = self.0.parameters().free().names();
510        minimize_from_python(
511            &self.0,
512            &p0,
513            self.0.n_free(),
514            &parameter_names,
515            method,
516            config.as_ref(),
517            options.as_ref(),
518            observers,
519            terminators,
520            threads,
521        )?
522        .to_py_class(py)
523    }
524    /// Run an MCMC algorithm on the free parameters of the LikelihoodTerm's model
525    ///
526    /// This method can be used to sample the underlying likelihood term given an initial
527    /// position for each walker `p0`.
528    ///
529    /// Parameters
530    /// ----------
531    /// p0 : array_like or ganesh.AIESInit or ganesh.ESSInit
532    ///     Initial sampler state. Use a 2D walker matrix with shape
533    ///     ``(n_walkers, n_parameters)`` for the common case, or pass an explicit Ganesh
534    ///     init object for method-specific initialization.
535    /// method : {'aies', 'ess'}
536    ///     The MCMC algorithm to use
537    /// config : ganesh config object, optional
538    ///     Method-specific Ganesh configuration, such as ``ganesh.AIESConfig`` or
539    ///     ``ganesh.ESSConfig``.
540    /// options : ganesh options object, optional
541    ///     Method-specific Ganesh options object controlling step limits, built-in observers,
542    ///     and built-in terminators.
543    /// observers : MCMCObserver or list of MCMCObserver, optional
544    ///     User-defined observers which are called at each step
545    /// terminators : MCMCTerminator or list of MCMCTerminator, optional
546    ///     User-defined terminators which are called at each step
547    /// threads : int, default=0
548    ///     The number of threads to use (setting this to ``0`` uses the current global or
549    ///     context-managed default; any positive value overrides that default for this call
550    ///     only)
551    ///
552    /// Returns
553    /// -------
554    /// MCMCSummary
555    ///     The status of the MCMC algorithm at termination
556    ///
557    /// Raises
558    /// ------
559    /// Exception
560    ///     If there was an error building the thread pool
561    ///
562    /// See Also
563    /// --------
564    /// NLL.mcmc
565    /// StochasticNLL.mcmc
566    ///
567    /// Examples
568    /// --------
569    /// >>> from laddu import LikelihoodScalar, likelihood_sum
570    /// >>> import ganesh
571    /// >>> expression = likelihood_sum([LikelihoodScalar('alpha')])
572    /// >>> summary = expression.mcmc(
573    /// ...     [[0.0], [0.4]],
574    /// ...     method='aies',
575    /// ...     options=ganesh.AIESOptions(max_steps=4),
576    /// ... )
577    /// >>> summary.dimension[2]
578    /// 1
579    /// >>> summary.chain(flat=True).shape[1]
580    /// 1
581    ///
582    /// Notes
583    /// -----
584    /// ``config`` and ``options`` use Ganesh's Python API directly. For example, custom
585    /// move mixes belong in ``ganesh.AIESConfig`` or ``ganesh.ESSConfig``, while
586    /// ``ganesh.AIESOptions(max_steps=...)`` and ``ganesh.ESSOptions(max_steps=...)`` control
587    /// run limits.
588    ///
589    /// References
590    /// ----------
591    /// Goodman, J. & Weare, J. (2010). *Ensemble samplers with affine invariance*. CAMCoS 5(1), 65–80. <https://doi.org/10.2140/camcos.2010.5.65>
592    ///
593    /// Karamanis, M. & Beutler, F. (2021). *Ensemble slice sampling*. Stat Comput 31(5). <https://doi.org/10.1007/s11222-021-10038-2>
594    ///
595    #[pyo3(signature = (p0, *, method="aies".to_string(), config=None, options=None, observers=None, terminators=None, threads=0))]
596    #[allow(clippy::too_many_arguments)]
597    fn mcmc<'py>(
598        &self,
599        py: Python<'py>,
600        p0: Bound<'_, PyAny>,
601        method: String,
602        config: Option<Bound<'_, PyAny>>,
603        options: Option<Bound<'_, PyAny>>,
604        observers: Option<Bound<'_, PyAny>>,
605        terminators: Option<Bound<'_, PyAny>>,
606        threads: usize,
607    ) -> PyResult<Bound<'py, PyAny>> {
608        let parameter_names = self.0.parameters().free().names();
609        mcmc_from_python(
610            &self.0,
611            &p0,
612            self.0.n_free(),
613            &parameter_names,
614            method,
615            config.as_ref(),
616            options.as_ref(),
617            observers,
618            terminators,
619            threads,
620        )?
621        .to_py_class(py)
622    }
623}
624
/// An (extended) negative log-likelihood evaluator.
///
/// Wraps a boxed [`NLL`] constructed from a model expression, a data dataset, and an
/// accepted Monte Carlo dataset. It can be evaluated directly, or converted into a
/// ``LikelihoodExpression`` (``to_expression``) or a ``StochasticNLL`` (``to_stochastic``).
#[pyclass(name = "NLL", module = "laddu", from_py_object)]
#[derive(Clone)]
pub struct PyNLL(pub Box<NLL>);
629
630#[cfg_attr(coverage_nightly, coverage(off))]
631#[pymethods]
632impl PyNLL {
633    #[new]
634    #[pyo3(signature = (expression, ds_data, ds_accmc, *, n_mc=None))]
635    fn new(
636        expression: &PyExpression,
637        ds_data: &PyDataset,
638        ds_accmc: &PyDataset,
639        n_mc: Option<f64>,
640    ) -> PyResult<Self> {
641        Ok(Self(NLL::new(
642            &expression.0,
643            &ds_data.0,
644            &ds_accmc.0,
645            n_mc,
646        )?))
647    }
648
649    #[getter]
650    fn data(&self) -> PyDataset {
651        PyDataset(self.0.data_evaluator.dataset.clone())
652    }
653
654    #[getter]
655    fn accmc(&self) -> PyDataset {
656        PyDataset(self.0.accmc_evaluator.dataset.clone())
657    }
658
659    #[getter]
660    fn data_evaluator(&self) -> PyEvaluator {
661        PyEvaluator(self.0.data_evaluator.clone())
662    }
663
664    #[getter]
665    fn accmc_evaluator(&self) -> PyEvaluator {
666        PyEvaluator(self.0.accmc_evaluator.clone())
667    }
668
669    #[getter]
670    fn expression(&self) -> PyExpression {
671        PyExpression(self.0.expression())
672    }
673
674    #[getter]
675    fn compiled_expression(&self) -> PyCompiledExpression {
676        PyCompiledExpression(self.0.compiled_expression())
677    }
678
679    #[pyo3(signature = (batch_size, *, seed=None))]
680    fn to_stochastic(&self, batch_size: usize, seed: Option<usize>) -> PyResult<PyStochasticNLL> {
681        Ok(PyStochasticNLL(self.0.to_stochastic(batch_size, seed)?))
682    }
683
684    fn to_expression(&self) -> PyResult<PyLikelihoodExpression> {
685        Ok(PyLikelihoodExpression(self.0.clone().into_expression()?))
686    }
687
688    #[getter]
689    fn parameters(&self) -> PyParameterMap {
690        PyParameterMap(self.0.parameters())
691    }
692
693    #[getter]
694    fn n_free(&self) -> usize {
695        self.0.n_free()
696    }
697
698    #[getter]
699    fn n_fixed(&self) -> usize {
700        self.0.n_fixed()
701    }
702
703    #[getter]
704    fn n_parameters(&self) -> usize {
705        self.0.n_parameters()
706    }
707
708    fn fix_parameter(&self, name: &str, value: f64) -> PyResult<()> {
709        Ok(self.0.fix_parameter(name, value)?)
710    }
711
712    fn free_parameter(&self, name: &str) -> PyResult<()> {
713        Ok(self.0.free_parameter(name)?)
714    }
715
716    fn rename_parameter(&self, old: &str, new: &str) -> PyResult<()> {
717        Ok(self.0.rename_parameter(old, new)?)
718    }
719
720    fn rename_parameters(&self, mapping: HashMap<String, String>) -> PyResult<()> {
721        Ok(self.0.rename_parameters(&mapping)?)
722    }
723
724    #[pyo3(signature = (arg, *, strict=true))]
725    fn activate(&self, arg: &Bound<'_, PyAny>, strict: bool) -> PyResult<()> {
726        if let Ok(string_arg) = arg.extract::<String>() {
727            if strict {
728                self.0.activate_strict(&string_arg)?;
729            } else {
730                self.0.activate(&string_arg);
731            }
732        } else if let Ok(list_arg) = arg.cast::<PyList>() {
733            let vec: Vec<String> = list_arg.extract()?;
734            if strict {
735                self.0.activate_many_strict(&vec)?;
736            } else {
737                self.0.activate_many(&vec);
738            }
739        } else {
740            return Err(PyTypeError::new_err(
741                "Argument must be either a string or a list of strings",
742            ));
743        }
744        Ok(())
745    }
746
747    fn activate_all(&self) {
748        self.0.activate_all();
749    }
750
751    #[pyo3(signature = (arg, *, strict=true))]
752    fn deactivate(&self, arg: &Bound<'_, PyAny>, strict: bool) -> PyResult<()> {
753        if let Ok(string_arg) = arg.extract::<String>() {
754            if strict {
755                self.0.deactivate_strict(&string_arg)?;
756            } else {
757                self.0.deactivate(&string_arg);
758            }
759        } else if let Ok(list_arg) = arg.cast::<PyList>() {
760            let vec: Vec<String> = list_arg.extract()?;
761            if strict {
762                self.0.deactivate_many_strict(&vec)?;
763            } else {
764                self.0.deactivate_many(&vec);
765            }
766        } else {
767            return Err(PyTypeError::new_err(
768                "Argument must be either a string or a list of strings",
769            ));
770        }
771        Ok(())
772    }
773
774    fn deactivate_all(&self) {
775        self.0.deactivate_all();
776    }
777
778    #[pyo3(signature = (arg, *, strict=true))]
779    fn isolate(&self, arg: &Bound<'_, PyAny>, strict: bool) -> PyResult<()> {
780        if let Ok(string_arg) = arg.extract::<String>() {
781            if strict {
782                self.0.isolate_strict(&string_arg)?;
783            } else {
784                self.0.isolate(&string_arg);
785            }
786        } else if let Ok(list_arg) = arg.cast::<PyList>() {
787            let vec: Vec<String> = list_arg.extract()?;
788            if strict {
789                self.0.isolate_many_strict(&vec)?;
790            } else {
791                self.0.isolate_many(&vec);
792            }
793        } else {
794            return Err(PyTypeError::new_err(
795                "Argument must be either a string or a list of strings",
796            ));
797        }
798        Ok(())
799    }
800
801    #[pyo3(signature = (parameters, *, threads=None))]
802    fn evaluate(&self, parameters: Vec<f64>, threads: Option<usize>) -> PyResult<f64> {
803        validate_free_parameter_len(parameters.len(), self.0.n_free())?;
804        install_laddu_with_threads(threads, || {
805            LikelihoodTerm::evaluate(self.0.as_ref(), &parameters)
806        })
807        .map_err(PyErr::from)
808    }
809
810    #[pyo3(signature = (parameters, *, threads=None))]
811    fn evaluate_gradient<'py>(
812        &self,
813        py: Python<'py>,
814        parameters: Vec<f64>,
815        threads: Option<usize>,
816    ) -> PyResult<Bound<'py, PyArray1<f64>>> {
817        validate_free_parameter_len(parameters.len(), self.0.n_free())?;
818        let gradient = install_laddu_with_threads(threads, || {
819            LikelihoodTerm::evaluate_gradient(self.0.as_ref(), &parameters)
820        })?;
821        Ok(PyArray1::from_slice(py, gradient.as_slice()))
822    }
823
824    #[allow(clippy::too_many_arguments)]
825    #[pyo3(signature = (
826        parameters,
827        *,
828        subset = None,
829        subsets = None,
830        strict = false,
831        mc_evaluator = None,
832        threads = None
833    ))]
834    fn project_weights<'py>(
835        &self,
836        py: Python<'py>,
837        parameters: Vec<f64>,
838        subset: Option<Bound<'_, PyAny>>,
839        subsets: Option<Bound<'_, PyAny>>,
840        strict: bool,
841        mc_evaluator: Option<PyEvaluator>,
842        threads: Option<usize>,
843    ) -> PyResult<Bound<'py, PyAny>> {
844        validate_free_parameter_len(parameters.len(), self.0.n_free())?;
845        if subset.is_some() && subsets.is_some() {
846            return Err(PyValueError::new_err(
847                "subset and subsets are mutually exclusive",
848            ));
849        }
850        let subset = extract_subset_names(subset)?;
851        let subsets = extract_subsets_arg(subsets)?;
852        let mc_evaluator = mc_evaluator.map(|pyeval| pyeval.0.clone());
853        match (subset, subsets) {
854            (Some(names), None) => {
855                let projection = install_laddu_with_threads(threads, || {
856                    if strict {
857                        self.0.project_weights_subset_strict(
858                            &parameters,
859                            &names,
860                            mc_evaluator.clone(),
861                        )
862                    } else {
863                        self.0
864                            .project_weights_subset(&parameters, &names, mc_evaluator.clone())
865                    }
866                })?;
867                Ok(PyArray1::from_slice(py, projection.as_slice()).into_any())
868            }
869            (None, Some(subsets)) => {
870                let projection = install_laddu_with_threads(threads, || {
871                    let mut rows = Vec::with_capacity(subsets.len());
872                    for subset in &subsets {
873                        let weights = match subset {
874                            Some(names) => {
875                                if strict {
876                                    self.0.project_weights_subset_strict(
877                                        &parameters,
878                                        names,
879                                        mc_evaluator.clone(),
880                                    )?
881                                } else {
882                                    self.0.project_weights_subset(
883                                        &parameters,
884                                        names,
885                                        mc_evaluator.clone(),
886                                    )?
887                                }
888                            }
889                            None => self.0.project_weights(&parameters, mc_evaluator.clone())?,
890                        };
891                        rows.push(weights);
892                    }
893                    Ok::<_, LadduError>(rows)
894                })?;
895                Ok(PyArray2::from_vec2(py, &projection)
896                    .map_err(LadduError::NumpyError)?
897                    .into_any())
898            }
899            (None, None) => {
900                let projection = install_laddu_with_threads(threads, || {
901                    self.0.project_weights(&parameters, mc_evaluator.clone())
902                })?;
903                Ok(PyArray1::from_slice(py, projection.as_slice()).into_any())
904            }
905            (Some(_), Some(_)) => unreachable!("checked above"),
906        }
907    }
908
909    #[allow(clippy::too_many_arguments)]
910    #[pyo3(signature = (
911        parameters,
912        *,
913        subset = None,
914        subsets = None,
915        strict = false,
916        mc_evaluator = None,
917        threads = None
918    ))]
919    fn project_weights_and_gradients<'py>(
920        &self,
921        py: Python<'py>,
922        parameters: Vec<f64>,
923        subset: Option<Bound<'_, PyAny>>,
924        subsets: Option<Bound<'_, PyAny>>,
925        strict: bool,
926        mc_evaluator: Option<PyEvaluator>,
927        threads: Option<usize>,
928    ) -> PyResult<Bound<'py, PyAny>> {
929        validate_free_parameter_len(parameters.len(), self.0.n_free())?;
930        if subset.is_some() && subsets.is_some() {
931            return Err(PyValueError::new_err(
932                "subset and subsets are mutually exclusive",
933            ));
934        }
935        let subset = extract_subset_names(subset)?;
936        let subsets = extract_subsets_arg(subsets)?;
937        let mc_evaluator = mc_evaluator.map(|pyeval| pyeval.0.clone());
938        match (subset, subsets) {
939            (Some(names), None) => {
940                let (weights, gradients) = install_laddu_with_threads(threads, || {
941                    if strict {
942                        self.0.project_weights_and_gradients_subset_strict(
943                            &parameters,
944                            &names,
945                            mc_evaluator.clone(),
946                        )
947                    } else {
948                        self.0.project_weights_and_gradients_subset(
949                            &parameters,
950                            &names,
951                            mc_evaluator.clone(),
952                        )
953                    }
954                })?;
955                let gradients = gradients
956                    .iter()
957                    .map(|gradient| gradient.as_slice().to_vec())
958                    .collect::<Vec<_>>();
959                (
960                    PyArray1::from_slice(py, weights.as_slice()),
961                    PyArray2::from_vec2(py, &gradients).map_err(LadduError::NumpyError)?,
962                )
963                    .into_bound_py_any(py)
964            }
965            (None, Some(subsets)) => {
966                let (weights, gradients) = install_laddu_with_threads(threads, || {
967                    let mut weight_rows = Vec::with_capacity(subsets.len());
968                    let mut gradient_rows = Vec::with_capacity(subsets.len());
969                    for subset in &subsets {
970                        let (subset_weights, subset_gradients) = match subset {
971                            Some(names) => {
972                                if strict {
973                                    self.0.project_weights_and_gradients_subset_strict(
974                                        &parameters,
975                                        names,
976                                        mc_evaluator.clone(),
977                                    )?
978                                } else {
979                                    self.0.project_weights_and_gradients_subset(
980                                        &parameters,
981                                        names,
982                                        mc_evaluator.clone(),
983                                    )?
984                                }
985                            }
986                            None => self
987                                .0
988                                .project_weights_and_gradients(&parameters, mc_evaluator.clone())?,
989                        };
990                        weight_rows.push(subset_weights);
991                        gradient_rows.push(
992                            subset_gradients
993                                .iter()
994                                .map(|gradient| gradient.as_slice().to_vec())
995                                .collect::<Vec<_>>(),
996                        );
997                    }
998                    Ok::<_, LadduError>((weight_rows, gradient_rows))
999                })?;
1000                (
1001                    PyArray2::from_vec2(py, &weights).map_err(LadduError::NumpyError)?,
1002                    PyArray3::from_vec3(py, &gradients).map_err(LadduError::NumpyError)?,
1003                )
1004                    .into_bound_py_any(py)
1005            }
1006            (None, None) => {
1007                let (weights, gradients) = install_laddu_with_threads(threads, || {
1008                    self.0
1009                        .project_weights_and_gradients(&parameters, mc_evaluator.clone())
1010                })?;
1011                let gradients = gradients
1012                    .iter()
1013                    .map(|gradient| gradient.as_slice().to_vec())
1014                    .collect::<Vec<_>>();
1015                (
1016                    PyArray1::from_slice(py, weights.as_slice()),
1017                    PyArray2::from_vec2(py, &gradients).map_err(LadduError::NumpyError)?,
1018                )
1019                    .into_bound_py_any(py)
1020            }
1021            (Some(_), Some(_)) => unreachable!("checked above"),
1022        }
1023    }
1024
1025    #[pyo3(signature = (p0, *, method="lbfgsb".to_string(), config=None, options=None, observers=None, terminators=None, threads=0))]
1026    #[allow(clippy::too_many_arguments)]
1027    fn minimize<'py>(
1028        &self,
1029        py: Python<'py>,
1030        p0: Bound<'_, PyAny>,
1031        method: String,
1032        config: Option<Bound<'_, PyAny>>,
1033        options: Option<Bound<'_, PyAny>>,
1034        observers: Option<Bound<'_, PyAny>>,
1035        terminators: Option<Bound<'_, PyAny>>,
1036        threads: usize,
1037    ) -> PyResult<Bound<'py, PyAny>> {
1038        let parameter_names = self.0.parameters().free().names();
1039        minimize_from_python(
1040            self.0.as_ref(),
1041            &p0,
1042            self.0.n_free(),
1043            &parameter_names,
1044            method,
1045            config.as_ref(),
1046            options.as_ref(),
1047            observers,
1048            terminators,
1049            threads,
1050        )?
1051        .to_py_class(py)
1052    }
1053
1054    #[pyo3(signature = (p0, *, method="aies".to_string(), config=None, options=None, observers=None, terminators=None, threads=0))]
1055    #[allow(clippy::too_many_arguments)]
1056    fn mcmc<'py>(
1057        &self,
1058        py: Python<'py>,
1059        p0: Bound<'_, PyAny>,
1060        method: String,
1061        config: Option<Bound<'_, PyAny>>,
1062        options: Option<Bound<'_, PyAny>>,
1063        observers: Option<Bound<'_, PyAny>>,
1064        terminators: Option<Bound<'_, PyAny>>,
1065        threads: usize,
1066    ) -> PyResult<Bound<'py, PyAny>> {
1067        let parameter_names = self.0.parameters().free().names();
1068        mcmc_from_python(
1069            self.0.as_ref(),
1070            &p0,
1071            self.0.n_free(),
1072            &parameter_names,
1073            method,
1074            config.as_ref(),
1075            options.as_ref(),
1076            observers,
1077            terminators,
1078            threads,
1079        )?
1080        .to_py_class(py)
1081    }
1082}
1083
1084/// A stochastic (extended) negative log-likelihood evaluator.
1085#[pyclass(name = "StochasticNLL", module = "laddu", skip_from_py_object)]
1086#[derive(Clone)]
1087pub struct PyStochasticNLL(pub StochasticNLL);
1088
1089#[cfg_attr(coverage_nightly, coverage(off))]
1090#[pymethods]
1091impl PyStochasticNLL {
1092    #[getter]
1093    fn nll(&self) -> PyNLL {
1094        PyNLL(Box::new(self.0.nll.clone()))
1095    }
1096
1097    #[getter]
1098    fn expression(&self) -> PyExpression {
1099        PyExpression(self.0.expression())
1100    }
1101
1102    #[getter]
1103    fn compiled_expression(&self) -> PyCompiledExpression {
1104        PyCompiledExpression(self.0.compiled_expression())
1105    }
1106
1107    #[pyo3(signature = (p0, *, method="lbfgsb".to_string(), config=None, options=None, observers=None, terminators=None, threads=0))]
1108    #[allow(clippy::too_many_arguments)]
1109    fn minimize<'py>(
1110        &self,
1111        py: Python<'py>,
1112        p0: Bound<'_, PyAny>,
1113        method: String,
1114        config: Option<Bound<'_, PyAny>>,
1115        options: Option<Bound<'_, PyAny>>,
1116        observers: Option<Bound<'_, PyAny>>,
1117        terminators: Option<Bound<'_, PyAny>>,
1118        threads: usize,
1119    ) -> PyResult<Bound<'py, PyAny>> {
1120        let parameter_names = self.0.parameters().free().names();
1121        minimize_from_python(
1122            &self.0,
1123            &p0,
1124            self.0.n_free(),
1125            &parameter_names,
1126            method,
1127            config.as_ref(),
1128            options.as_ref(),
1129            observers,
1130            terminators,
1131            threads,
1132        )?
1133        .to_py_class(py)
1134    }
1135
1136    #[pyo3(signature = (p0, *, method="aies".to_string(), config=None, options=None, observers=None, terminators=None, threads=0))]
1137    #[allow(clippy::too_many_arguments)]
1138    fn mcmc<'py>(
1139        &self,
1140        py: Python<'py>,
1141        p0: Bound<'_, PyAny>,
1142        method: String,
1143        config: Option<Bound<'_, PyAny>>,
1144        options: Option<Bound<'_, PyAny>>,
1145        observers: Option<Bound<'_, PyAny>>,
1146        terminators: Option<Bound<'_, PyAny>>,
1147        threads: usize,
1148    ) -> PyResult<Bound<'py, PyAny>> {
1149        let parameter_names = self.0.parameters().free().names();
1150        mcmc_from_python(
1151            &self.0,
1152            &p0,
1153            self.0.n_free(),
1154            &parameter_names,
1155            method,
1156            config.as_ref(),
1157            options.as_ref(),
1158            observers,
1159            terminators,
1160            threads,
1161        )?
1162        .to_py_class(py)
1163    }
1164}
1165
1166/// A parameterized scalar term which can be converted into a [`LikelihoodExpression`].
1167///
1168/// Parameters
1169/// ----------
1170/// name : str
1171///     The name of the new scalar parameter.
1172///
1173/// Returns
1174/// -------
1175/// LikelihoodExpression
1176///     A [`LikelihoodExpression`] representing a single free scaling parameter.
1177///
1178/// See Also
1179/// --------
1180/// likelihood_sum
1181/// likelihood_product
1182///
1183/// Examples
1184/// --------
1185/// >>> from laddu import LikelihoodScalar, likelihood_sum
1186/// >>> expr = likelihood_sum([LikelihoodScalar('alpha')])
1187/// >>> expr.evaluate([1.25])
1188/// 1.25
1189#[cfg_attr(coverage_nightly, coverage(off))]
1190#[pyfunction(name = "LikelihoodScalar")]
1191pub fn py_likelihood_scalar(name: String) -> PyResult<PyLikelihoodExpression> {
1192    Ok(PyLikelihoodExpression(LikelihoodScalar::new(name)?))
1193}