Skip to main content

laddu_python/
amplitudes.rs

1use crate::data::PyDataset;
2use laddu_core::{
3    amplitudes::{Evaluator, Expression, Parameter, TestAmplitude},
4    CompiledExpression, LadduError, LadduResult, ReadWrite, ThreadPoolManager,
5};
6use num::complex::Complex64;
7use numpy::{PyArray1, PyArray2};
8use pyo3::{
9    exceptions::PyTypeError,
10    prelude::*,
11    types::{PyBytes, PyList, PyTuple},
12};
13use std::collections::HashMap;
14
15fn install_with_threads<R: Send>(
16    threads: Option<usize>,
17    op: impl FnOnce() -> R + Send,
18) -> LadduResult<R> {
19    ThreadPoolManager::shared().install(threads, op)
20}
21
22/// A mathematical expression formed from amplitudes.
23///
24#[pyclass(name = "Expression", module = "laddu", skip_from_py_object)]
25#[derive(Clone)]
26pub struct PyExpression(pub Expression);
27
28impl<'py> FromPyObject<'_, 'py> for PyExpression {
29    type Error = PyErr;
30
31    fn extract(obj: Borrowed<'_, 'py, PyAny>) -> Result<Self, Self::Error> {
32        if let Ok(obj) = obj.cast::<PyExpression>() {
33            Ok(obj.borrow().clone())
34        } else if let Ok(obj) = obj.extract::<f64>() {
35            Ok(Self(obj.into()))
36        } else if let Ok(obj) = obj.extract::<Complex64>() {
37            Ok(Self(obj.into()))
38        } else {
39            Err(PyTypeError::new_err("Failed to extract Expression"))
40        }
41    }
42}
43
44/// A convenience method to sum sequences of Expressions
45///
46#[pyfunction(name = "expr_sum")]
47pub fn py_expr_sum(terms: Vec<Bound<'_, PyAny>>) -> PyResult<PyExpression> {
48    if terms.is_empty() {
49        return Ok(PyExpression(Expression::zero()));
50    }
51    if terms.len() == 1 {
52        let term = &terms[0];
53        if let Ok(expression) = term.extract::<PyExpression>() {
54            return Ok(expression);
55        }
56        return Err(PyTypeError::new_err("Item is not a PyExpression"));
57    }
58    let mut iter = terms.iter();
59    let Some(first_term) = iter.next() else {
60        return Ok(PyExpression(Expression::zero()));
61    };
62    let PyExpression(mut summation) = first_term
63        .extract::<PyExpression>()
64        .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
65    for term in iter {
66        let PyExpression(expr) = term
67            .extract::<PyExpression>()
68            .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
69        summation = summation + expr;
70    }
71    Ok(PyExpression(summation))
72}
73
74/// A convenience method to multiply sequences of Expressions
75///
76#[pyfunction(name = "expr_product")]
77pub fn py_expr_product(terms: Vec<Bound<'_, PyAny>>) -> PyResult<PyExpression> {
78    if terms.is_empty() {
79        return Ok(PyExpression(Expression::one()));
80    }
81    if terms.len() == 1 {
82        let term = &terms[0];
83        if let Ok(expression) = term.extract::<PyExpression>() {
84            return Ok(expression);
85        }
86        return Err(PyTypeError::new_err("Item is not a PyExpression"));
87    }
88    let mut iter = terms.iter();
89    let Some(first_term) = iter.next() else {
90        return Ok(PyExpression(Expression::one()));
91    };
92    let PyExpression(mut product) = first_term
93        .extract::<PyExpression>()
94        .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
95    for term in iter {
96        let PyExpression(expr) = term
97            .extract::<PyExpression>()
98            .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
99        product = product * expr;
100    }
101    Ok(PyExpression(product))
102}
103
104/// A convenience class representing a zero-valued Expression
105///
106#[pyfunction(name = "Zero")]
107pub fn py_expr_zero() -> PyExpression {
108    PyExpression(Expression::zero())
109}
110
111/// A convenience class representing a unit-valued Expression
112///
113#[pyfunction(name = "One")]
114pub fn py_expr_one() -> PyExpression {
115    PyExpression(Expression::one())
116}
117
118#[pymethods]
119impl PyExpression {
120    /// The free parameters used by the Expression
121    ///
122    /// Returns
123    /// -------
124    /// parameters : tuple of str
125    ///     The tuple of parameter names
126    #[getter]
127    fn parameters<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
128        PyTuple::new(py, self.0.parameters())
129    }
130    /// The free parameters used by the Expression
131    #[getter]
132    fn free_parameters<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
133        PyTuple::new(py, self.0.free_parameters())
134    }
135    /// The fixed parameters used by the Expression
136    #[getter]
137    fn fixed_parameters<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
138        PyTuple::new(py, self.0.fixed_parameters())
139    }
140    /// Number of free parameters
141    #[getter]
142    fn n_free(&self) -> usize {
143        self.0.n_free()
144    }
145    /// Number of fixed parameters
146    #[getter]
147    fn n_fixed(&self) -> usize {
148        self.0.n_fixed()
149    }
150    /// Total number of parameters
151    #[getter]
152    fn n_parameters(&self) -> usize {
153        self.0.n_parameters()
154    }
155    /// Load an Expression by precalculating each term over the given Dataset
156    ///
157    /// Parameters
158    /// ----------
159    /// dataset : Dataset
160    ///     The Dataset to use in precalculation
161    ///
162    /// Returns
163    /// -------
164    /// Evaluator
165    ///     An object that can be used to evaluate the `expression` over each event in the
166    ///     `dataset`
167    fn load(&self, dataset: &PyDataset) -> PyResult<PyEvaluator> {
168        Ok(PyEvaluator(self.0.load(&dataset.0)?))
169    }
170    /// The real part of a complex Expression
171    fn real(&self) -> PyExpression {
172        PyExpression(self.0.real())
173    }
174    /// The imaginary part of a complex Expression
175    fn imag(&self) -> PyExpression {
176        PyExpression(self.0.imag())
177    }
178    /// The complex conjugate of a complex Expression
179    fn conj(&self) -> PyExpression {
180        PyExpression(self.0.conj())
181    }
182    /// The norm-squared of a complex Expression
183    fn norm_sqr(&self) -> PyExpression {
184        PyExpression(self.0.norm_sqr())
185    }
186    /// The square root of an Expression
187    fn sqrt(&self) -> PyExpression {
188        PyExpression(self.0.sqrt())
189    }
190    /// Raise an Expression to an int, float, or Expression power
191    fn power(&self, power: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
192        if let Ok(value) = power.extract::<i32>() {
193            Ok(PyExpression(self.0.powi(value)))
194        } else if let Ok(value) = power.extract::<f64>() {
195            Ok(PyExpression(self.0.powf(value)))
196        } else if let Ok(expression) = power.extract::<PyExpression>() {
197            Ok(PyExpression(self.0.pow(&expression.0)))
198        } else {
199            Err(PyTypeError::new_err(
200                "power must be an int, float, or Expression",
201            ))
202        }
203    }
204    /// The exponential of an Expression
205    fn exp(&self) -> PyExpression {
206        PyExpression(self.0.exp())
207    }
208    /// The sine of an Expression
209    fn sin(&self) -> PyExpression {
210        PyExpression(self.0.sin())
211    }
212    /// The cosine of an Expression
213    fn cos(&self) -> PyExpression {
214        PyExpression(self.0.cos())
215    }
216    /// The natural logarithm of an Expression
217    fn log(&self) -> PyExpression {
218        PyExpression(self.0.log())
219    }
220    /// The complex phase factor exp(i * expression)
221    fn cis(&self) -> PyExpression {
222        PyExpression(self.0.cis())
223    }
224    /// Fix a parameter used by this Expression.
225    fn fix_parameter(&self, name: &str, value: f64) -> PyResult<()> {
226        Ok(self.0.fix_parameter(name, value)?)
227    }
228    /// Mark a parameter used by this Expression as free.
229    fn free_parameter(&self, name: &str) -> PyResult<()> {
230        Ok(self.0.free_parameter(name)?)
231    }
232    /// Rename a single parameter used by this Expression.
233    fn rename_parameter(&mut self, old: &str, new: &str) -> PyResult<()> {
234        Ok(self.0.rename_parameter(old, new)?)
235    }
236    /// Rename several parameters used by this Expression.
237    fn rename_parameters(&mut self, mapping: HashMap<String, String>) -> PyResult<()> {
238        Ok(self.0.rename_parameters(&mapping)?)
239    }
240    /// Return a tree-like diagnostic view of the compiled Expression.
241    #[getter]
242    fn compiled_expression(&self) -> PyCompiledExpression {
243        PyCompiledExpression(self.0.compiled_expression())
244    }
245    fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
246        if let Ok(other_expr) = other.extract::<PyExpression>() {
247            Ok(PyExpression(self.0.clone() + other_expr.0))
248        } else {
249            Err(PyTypeError::new_err("Unsupported operand type for +"))
250        }
251    }
252    fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
253        if let Ok(other_expr) = other.extract::<PyExpression>() {
254            Ok(PyExpression(other_expr.0 + self.0.clone()))
255        } else {
256            Err(PyTypeError::new_err("Unsupported operand type for +"))
257        }
258    }
259    fn __sub__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
260        if let Ok(other_expr) = other.extract::<PyExpression>() {
261            Ok(PyExpression(self.0.clone() - other_expr.0))
262        } else {
263            Err(PyTypeError::new_err("Unsupported operand type for -"))
264        }
265    }
266    fn __rsub__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
267        if let Ok(other_expr) = other.extract::<PyExpression>() {
268            Ok(PyExpression(other_expr.0 - self.0.clone()))
269        } else {
270            Err(PyTypeError::new_err("Unsupported operand type for -"))
271        }
272    }
273    fn __mul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
274        if let Ok(other_expr) = other.extract::<PyExpression>() {
275            Ok(PyExpression(self.0.clone() * other_expr.0))
276        } else {
277            Err(PyTypeError::new_err("Unsupported operand type for *"))
278        }
279    }
280    fn __rmul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
281        if let Ok(other_expr) = other.extract::<PyExpression>() {
282            Ok(PyExpression(other_expr.0 * self.0.clone()))
283        } else {
284            Err(PyTypeError::new_err("Unsupported operand type for *"))
285        }
286    }
287    fn __truediv__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
288        if let Ok(other_expr) = other.extract::<PyExpression>() {
289            Ok(PyExpression(self.0.clone() / other_expr.0))
290        } else {
291            Err(PyTypeError::new_err("Unsupported operand type for /"))
292        }
293    }
294    fn __rtruediv__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
295        if let Ok(other_expr) = other.extract::<PyExpression>() {
296            Ok(PyExpression(other_expr.0 / self.0.clone()))
297        } else {
298            Err(PyTypeError::new_err("Unsupported operand type for /"))
299        }
300    }
301    fn __neg__(&self) -> PyExpression {
302        PyExpression(-self.0.clone())
303    }
304    fn __str__(&self) -> String {
305        format!("{}", self.0)
306    }
307    fn __repr__(&self) -> String {
308        format!("{:?}", self.0)
309    }
310
311    #[new]
312    fn new() -> Self {
313        Self(Expression::create_null())
314    }
315    fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
316        Ok(PyBytes::new(
317            py,
318            serde_pickle::to_vec(&self.0, serde_pickle::SerOptions::new())
319                .map_err(LadduError::PickleError)?
320                .as_slice(),
321        ))
322    }
323    fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
324        *self = Self(
325            serde_pickle::from_slice(state.as_bytes(), serde_pickle::DeOptions::new())
326                .map_err(LadduError::PickleError)?,
327        );
328        Ok(())
329    }
330}
331
332/// A class which can be used to evaluate a stored Expression
333///
334/// See Also
335/// --------
336/// laddu.Expression.load
337///
338#[pyclass(name = "Evaluator", module = "laddu", from_py_object)]
339#[derive(Clone)]
340pub struct PyEvaluator(pub Evaluator);
341
342#[pymethods]
343impl PyEvaluator {
344    /// The free parameters used by the Evaluator
345    ///
346    /// Returns
347    /// -------
348    /// parameters : tuple of str
349    ///     The tuple of parameter names
350    ///
351    #[getter]
352    fn parameters<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
353        PyTuple::new(py, self.0.parameters())
354    }
355    /// The free parameters used by the Evaluator
356    #[getter]
357    fn free_parameters<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
358        PyTuple::new(py, self.0.free_parameters())
359    }
360    /// The fixed parameters used by the Evaluator
361    #[getter]
362    fn fixed_parameters<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
363        PyTuple::new(py, self.0.fixed_parameters())
364    }
365    /// Number of free parameters
366    #[getter]
367    fn n_free(&self) -> usize {
368        self.0.n_free()
369    }
370    /// Number of fixed parameters
371    #[getter]
372    fn n_fixed(&self) -> usize {
373        self.0.n_fixed()
374    }
375    /// Total number of parameters
376    #[getter]
377    fn n_parameters(&self) -> usize {
378        self.0.n_parameters()
379    }
380    /// Fix a parameter used by this Evaluator.
381    fn fix_parameter(&self, name: &str, value: f64) -> PyResult<()> {
382        Ok(self.0.fix_parameter(name, value)?)
383    }
384    /// Mark a parameter used by this Evaluator as free.
385    fn free_parameter(&self, name: &str) -> PyResult<()> {
386        Ok(self.0.free_parameter(name)?)
387    }
388    /// Rename a single parameter used by this Evaluator.
389    fn rename_parameter(&self, old: &str, new: &str) -> PyResult<()> {
390        Ok(self.0.rename_parameter(old, new)?)
391    }
392    /// Rename several parameters used by this Evaluator.
393    fn rename_parameters(&self, mapping: HashMap<String, String>) -> PyResult<()> {
394        Ok(self.0.rename_parameters(&mapping)?)
395    }
396    /// Activates Amplitudes in the Expression by name or glob selector
397    ///
398    /// Parameters
399    /// ----------
400    /// arg : str or list of str
401    ///     Names or ``*``/``?`` glob selectors of Amplitudes to be activated
402    ///
403    /// Raises
404    /// ------
405    /// TypeError
406    ///     If `arg` is not a str or list of str
407    /// ValueError
408    ///     If `arg` or any items of `arg` are not registered Amplitudes
409    /// strict : bool, default=True
410    ///     When ``True``, raise an error if any selector matches no amplitudes. When
411    ///     ``False``, silently skip selectors with no matches.
412    #[pyo3(signature = (arg, *, strict=true))]
413    fn activate(&self, arg: &Bound<'_, PyAny>, strict: bool) -> PyResult<()> {
414        if let Ok(string_arg) = arg.extract::<String>() {
415            if strict {
416                self.0.activate_strict(&string_arg)?;
417            } else {
418                self.0.activate(&string_arg);
419            }
420        } else if let Ok(list_arg) = arg.cast::<PyList>() {
421            let vec: Vec<String> = list_arg.extract()?;
422            if strict {
423                self.0.activate_many_strict(&vec)?;
424            } else {
425                self.0.activate_many(&vec);
426            }
427        } else {
428            return Err(PyTypeError::new_err(
429                "Argument must be either a string or a list of strings",
430            ));
431        }
432        Ok(())
433    }
434    /// Activates all Amplitudes in the Expression
435    ///
436    fn activate_all(&self) {
437        self.0.activate_all();
438    }
439    /// Deactivates Amplitudes in the Expression by name or glob selector
440    ///
441    /// Deactivated Amplitudes act as zeros in the Expression
442    ///
443    /// Parameters
444    /// ----------
445    /// arg : str or list of str
446    ///     Names or ``*``/``?`` glob selectors of Amplitudes to be deactivated
447    ///
448    /// Raises
449    /// ------
450    /// TypeError
451    ///     If `arg` is not a str or list of str
452    /// ValueError
453    ///     If `arg` or any items of `arg` are not registered Amplitudes
454    /// strict : bool, default=True
455    ///     When ``True``, raise an error if any selector matches no amplitudes. When
456    ///     ``False``, silently skip selectors with no matches.
457    #[pyo3(signature = (arg, *, strict=true))]
458    fn deactivate(&self, arg: &Bound<'_, PyAny>, strict: bool) -> PyResult<()> {
459        if let Ok(string_arg) = arg.extract::<String>() {
460            if strict {
461                self.0.deactivate_strict(&string_arg)?;
462            } else {
463                self.0.deactivate(&string_arg);
464            }
465        } else if let Ok(list_arg) = arg.cast::<PyList>() {
466            let vec: Vec<String> = list_arg.extract()?;
467            if strict {
468                self.0.deactivate_many_strict(&vec)?;
469            } else {
470                self.0.deactivate_many(&vec);
471            }
472        } else {
473            return Err(PyTypeError::new_err(
474                "Argument must be either a string or a list of strings",
475            ));
476        }
477        Ok(())
478    }
479    /// Deactivates all Amplitudes in the Expression
480    ///
481    fn deactivate_all(&self) {
482        self.0.deactivate_all();
483    }
484    /// Isolates Amplitudes in the Expression by name or glob selector
485    ///
486    /// Activates the Amplitudes given in `arg` and deactivates the rest
487    ///
488    /// Parameters
489    /// ----------
490    /// arg : str or list of str
491    ///     Names or ``*``/``?`` glob selectors of Amplitudes to be isolated
492    ///
493    /// Raises
494    /// ------
495    /// TypeError
496    ///     If `arg` is not a str or list of str
497    /// ValueError
498    ///     If `arg` or any items of `arg` are not registered Amplitudes
499    /// strict : bool, default=True
500    ///     When ``True``, raise an error if any selector matches no amplitudes. When
501    ///     ``False``, silently skip selectors with no matches.
502    #[pyo3(signature = (arg, *, strict=true))]
503    fn isolate(&self, arg: &Bound<'_, PyAny>, strict: bool) -> PyResult<()> {
504        if let Ok(string_arg) = arg.extract::<String>() {
505            if strict {
506                self.0.isolate_strict(&string_arg)?;
507            } else {
508                self.0.isolate(&string_arg);
509            }
510        } else if let Ok(list_arg) = arg.cast::<PyList>() {
511            let vec: Vec<String> = list_arg.extract()?;
512            if strict {
513                self.0.isolate_many_strict(&vec)?;
514            } else {
515                self.0.isolate_many(&vec);
516            }
517        } else {
518            return Err(PyTypeError::new_err(
519                "Argument must be either a string or a list of strings",
520            ));
521        }
522        Ok(())
523    }
524
525    /// Return the current active-amplitude mask.
526    #[getter]
527    fn active_mask(&self) -> Vec<bool> {
528        self.0.active_mask()
529    }
530
531    /// Apply an active-amplitude mask.
532    fn set_active_mask(&self, mask: Vec<bool>) -> PyResult<()> {
533        self.0.set_active_mask(&mask)?;
534        Ok(())
535    }
536
537    /// Return a tree-like diagnostic view of the compiled Expression.
538    #[getter]
539    fn compiled_expression(&self) -> PyCompiledExpression {
540        PyCompiledExpression(self.0.compiled_expression())
541    }
542
543    /// Return the Expression represented by this Evaluator.
544    #[getter]
545    fn expression(&self) -> PyExpression {
546        PyExpression(self.0.expression())
547    }
548
549    /// Evaluate the stored Expression over the stored Dataset
550    ///
551    /// Parameters
552    /// ----------
553    /// parameters : list of float
554    ///     The values to use for the free parameters
555    /// threads : int, optional
556    ///     The number of threads to use (setting this to ``None`` or ``0`` uses the current
557    ///     global or context-managed default; any positive value overrides that default for
558    ///     this call only)
559    ///
560    /// Returns
561    /// -------
562    /// result : array_like
563    ///     A ``numpy`` array of complex values for each Event in the Dataset
564    ///
565    /// Raises
566    /// ------
567    /// Exception
568    ///     If there was an error building the thread pool
569    ///
570    #[pyo3(signature = (parameters, *, threads=None))]
571    fn evaluate<'py>(
572        &self,
573        py: Python<'py>,
574        parameters: Vec<f64>,
575        threads: Option<usize>,
576    ) -> PyResult<Bound<'py, PyArray1<Complex64>>> {
577        let values = install_with_threads(threads, || self.0.evaluate(&parameters))?;
578        Ok(PyArray1::from_slice(py, &values?))
579    }
580    /// Evaluate the stored Expression over a subset of the stored Dataset
581    ///
582    /// Parameters
583    /// ----------
584    /// parameters : list of float
585    ///     The values to use for the free parameters
586    /// indices : list of int
587    ///     The indices of events to evaluate
588    /// threads : int, optional
589    ///     The number of threads to use (setting this to ``None`` or ``0`` uses the current
590    ///     global or context-managed default; any positive value overrides that default for
591    ///     this call only)
592    ///
593    /// Returns
594    /// -------
595    /// result : array_like
596    ///     A ``numpy`` array of complex values for each indexed Event in the Dataset
597    ///
598    /// Raises
599    /// ------
600    /// Exception
601    ///     If there was an error building the thread pool
602    ///
603    #[pyo3(signature = (parameters, indices, *, threads=None))]
604    fn evaluate_batch<'py>(
605        &self,
606        py: Python<'py>,
607        parameters: Vec<f64>,
608        indices: Vec<usize>,
609        threads: Option<usize>,
610    ) -> PyResult<Bound<'py, PyArray1<Complex64>>> {
611        let values =
612            install_with_threads(threads, || self.0.evaluate_batch(&parameters, &indices))?;
613        Ok(PyArray1::from_slice(py, &values?))
614    }
615    /// Evaluate the gradient of the stored Expression over the stored Dataset
616    ///
617    /// Parameters
618    /// ----------
619    /// parameters : list of float
620    ///     The values to use for the free parameters
621    /// threads : int, optional
622    ///     The number of threads to use (setting this to ``None`` or ``0`` uses the current
623    ///     global or context-managed default; any positive value overrides that default for
624    ///     this call only)
625    ///
626    /// Returns
627    /// -------
628    /// result : array_like
629    ///     A ``numpy`` 2D array of complex values for each Event in the Dataset
630    ///
631    /// Raises
632    /// ------
633    /// Exception
634    ///     If there was an error building the thread pool or problem creating the resulting
635    ///     ``numpy`` array
636    ///
637    #[pyo3(signature = (parameters, *, threads=None))]
638    fn evaluate_gradient<'py>(
639        &self,
640        py: Python<'py>,
641        parameters: Vec<f64>,
642        threads: Option<usize>,
643    ) -> PyResult<Bound<'py, PyArray2<Complex64>>> {
644        let gradients: LadduResult<_> = install_with_threads(threads, || {
645            Ok(self
646                .0
647                .evaluate_gradient(&parameters)?
648                .iter()
649                .map(|grad| grad.data.as_vec().to_vec())
650                .collect::<Vec<Vec<Complex64>>>())
651        })?;
652        Ok(PyArray2::from_vec2(py, &gradients?).map_err(LadduError::NumpyError)?)
653    }
654    /// Evaluate the gradient of the stored Expression over a subset of the stored Dataset
655    ///
656    /// Parameters
657    /// ----------
658    /// parameters : list of float
659    ///     The values to use for the free parameters
660    /// indices : list of int
661    ///     The indices of events to evaluate
662    /// threads : int, optional
663    ///     The number of threads to use (setting this to ``None`` or ``0`` uses the current
664    ///     global or context-managed default; any positive value overrides that default for
665    ///     this call only)
666    ///
667    /// Returns
668    /// -------
669    /// result : array_like
670    ///     A ``numpy`` 2D array of complex values for each indexed Event in the Dataset
671    ///
672    /// Raises
673    /// ------
674    /// Exception
675    ///     If there was an error building the thread pool or problem creating the resulting
676    ///     ``numpy`` array
677    ///
678    #[pyo3(signature = (parameters, indices, *, threads=None))]
679    fn evaluate_gradient_batch<'py>(
680        &self,
681        py: Python<'py>,
682        parameters: Vec<f64>,
683        indices: Vec<usize>,
684        threads: Option<usize>,
685    ) -> PyResult<Bound<'py, PyArray2<Complex64>>> {
686        let gradients: LadduResult<_> = install_with_threads(threads, || {
687            Ok(self
688                .0
689                .evaluate_gradient_batch(&parameters, &indices)?
690                .iter()
691                .map(|grad| grad.data.as_vec().to_vec())
692                .collect::<Vec<Vec<Complex64>>>())
693        })?;
694        Ok(PyArray2::from_vec2(py, &gradients?).map_err(LadduError::NumpyError)?)
695    }
696}
697
698/// A class which can be used to display the compiled form of an Expression
699///
700/// Notes
701/// -----
702/// This should not be used for anything other than diagnostic purposes.
703///
704#[pyclass(name = "CompiledExpression", module = "laddu", from_py_object)]
705#[derive(Clone)]
706pub struct PyCompiledExpression(pub CompiledExpression);
707
708#[pymethods]
709impl PyCompiledExpression {
710    fn __str__(&self) -> String {
711        format!("{}", self.0)
712    }
713    fn __repr__(&self) -> String {
714        format!("{:?}", self.0)
715    }
716}
717
718#[pyclass(name = "Parameter", module = "laddu", from_py_object)]
719#[derive(Clone)]
720pub struct PyParameter(pub Parameter);
721
722#[pymethods]
723impl PyParameter {
724    #[getter]
725    fn name(&self) -> String {
726        self.0.name()
727    }
728    #[getter]
729    fn fixed(&self) -> Option<f64> {
730        self.0.fixed()
731    }
732    #[getter]
733    fn initial(&self) -> Option<f64> {
734        self.0.initial()
735    }
736    #[getter]
737    fn bounds(&self) -> (Option<f64>, Option<f64>) {
738        self.0.bounds()
739    }
740    #[getter]
741    fn unit(&self) -> Option<String> {
742        self.0.unit()
743    }
744    #[getter]
745    fn latex(&self) -> Option<String> {
746        self.0.latex()
747    }
748    #[getter]
749    fn description(&self) -> Option<String> {
750        self.0.description()
751    }
752}
753
754/// A free parameter which floats during an optimization
755///
756/// Parameters
757/// ----------
758/// name : str
759///     The name of the free parameter
760/// fixed : float, optional
761///     If specified, the parameter will be fixed to this value
762/// initial : float, optional
763///     If specified, the parameter will always be initialized to this value
764/// bounds : tuple of (float or None, float or None)
765///     Specify the lower and upper bounds for the parameter (None corresponds to no bound)
766/// unit : str, optional
767///     Optional unit string which may be used to annotate the parameter
768/// latex : str, optional
769///     Optional LaTeX representation of the parameter
770/// description : str, optional
771///     Optional description of the parameter
772///
773/// Returns
774/// -------
775/// laddu.Parameter
776///     An object that can be used as the input for many Amplitude constructors
777///
778/// Notes
779/// -----
780/// Two free parameters with the same name are shared in a fit.
781///
782/// Attempting to set both the fixed and initial value will result in an overwrite (both will be
783/// set to the "fixed" value).
784///
785#[pyfunction(name = "parameter", signature = (name, fixed=None, *, initial=None, bounds=(None, None), unit=None, latex=None, description=None))]
786pub fn py_parameter(
787    name: &str,
788    fixed: Option<f64>,
789    initial: Option<f64>,
790    bounds: (Option<f64>, Option<f64>),
791    unit: Option<&str>,
792    latex: Option<&str>,
793    description: Option<&str>,
794) -> PyParameter {
795    let par = Parameter::new(name);
796    if let Some(value) = initial {
797        par.set_initial(value);
798    }
799    if let Some(value) = fixed {
800        par.set_fixed_value(Some(value)); // TODO: make this all consistent
801    }
802    par.set_bounds(bounds.0, bounds.1);
803    if let Some(unit) = unit {
804        par.set_unit(unit);
805    }
806    if let Some(latex) = latex {
807        par.set_latex(latex);
808    }
809    if let Some(description) = description {
810        par.set_description(description);
811    }
812    PyParameter(par)
813}
814
815/// An amplitude used only for internal testing which evaluates `(p0 + i * p1) * event.p4s\[0\].e`.
816#[pyfunction(name = "TestAmplitude")]
817pub fn py_test_amplitude(name: &str, re: PyParameter, im: PyParameter) -> PyResult<PyExpression> {
818    Ok(PyExpression(TestAmplitude::new(name, re.0, im.0)?))
819}