Skip to main content

laddu_python/
amplitudes.rs

1use crate::data::PyDataset;
2use laddu_core::{
3    amplitudes::{constant, parameter, Evaluator, Expression, ParameterLike, TestAmplitude},
4    f64, CompiledExpression, LadduError, LadduResult, ReadWrite, ThreadPoolManager,
5};
6use num::complex::Complex64;
7use numpy::{PyArray1, PyArray2};
8use pyo3::{
9    exceptions::PyTypeError,
10    prelude::*,
11    types::{PyBytes, PyList},
12};
13use std::collections::HashMap;
14
15fn install_with_threads<R: Send>(
16    threads: Option<usize>,
17    op: impl FnOnce() -> R + Send,
18) -> LadduResult<R> {
19    ThreadPoolManager::shared().install(threads, op)
20}
21
/// A mathematical expression formed from amplitudes.
///
/// Exposed to Python as ``Expression`` in the ``laddu`` module. The single
/// public field holds the wrapped core ``Expression``.
#[pyclass(name = "Expression", module = "laddu", from_py_object)]
#[derive(Clone)]
pub struct PyExpression(pub Expression);
27
28/// A convenience method to sum sequences of Expressions
29///
30#[pyfunction(name = "expr_sum")]
31pub fn py_expr_sum(terms: Vec<Bound<'_, PyAny>>) -> PyResult<PyExpression> {
32    if terms.is_empty() {
33        return Ok(PyExpression(Expression::zero()));
34    }
35    if terms.len() == 1 {
36        let term = &terms[0];
37        if let Ok(expression) = term.extract::<PyExpression>() {
38            return Ok(expression);
39        }
40        return Err(PyTypeError::new_err("Item is not a PyExpression"));
41    }
42    let mut iter = terms.iter();
43    let Some(first_term) = iter.next() else {
44        return Ok(PyExpression(Expression::zero()));
45    };
46    let PyExpression(mut summation) = first_term
47        .extract::<PyExpression>()
48        .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
49    for term in iter {
50        let PyExpression(expr) = term
51            .extract::<PyExpression>()
52            .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
53        summation = summation + expr;
54    }
55    Ok(PyExpression(summation))
56}
57
58/// A convenience method to multiply sequences of Expressions
59///
60#[pyfunction(name = "expr_product")]
61pub fn py_expr_product(terms: Vec<Bound<'_, PyAny>>) -> PyResult<PyExpression> {
62    if terms.is_empty() {
63        return Ok(PyExpression(Expression::one()));
64    }
65    if terms.len() == 1 {
66        let term = &terms[0];
67        if let Ok(expression) = term.extract::<PyExpression>() {
68            return Ok(expression);
69        }
70        return Err(PyTypeError::new_err("Item is not a PyExpression"));
71    }
72    let mut iter = terms.iter();
73    let Some(first_term) = iter.next() else {
74        return Ok(PyExpression(Expression::one()));
75    };
76    let PyExpression(mut product) = first_term
77        .extract::<PyExpression>()
78        .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
79    for term in iter {
80        let PyExpression(expr) = term
81            .extract::<PyExpression>()
82            .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
83        product = product * expr;
84    }
85    Ok(PyExpression(product))
86}
87
88/// A convenience class representing a zero-valued Expression
89///
90#[pyfunction(name = "Zero")]
91pub fn py_expr_zero() -> PyExpression {
92    PyExpression(Expression::zero())
93}
94
95/// A convenience class representing a unit-valued Expression
96///
97#[pyfunction(name = "One")]
98pub fn py_expr_one() -> PyExpression {
99    PyExpression(Expression::one())
100}
101
102#[pymethods]
103impl PyExpression {
104    /// The free parameters used by the Expression
105    ///
106    /// Returns
107    /// -------
108    /// parameters : list of str
109    ///     The list of parameter names
110    #[getter]
111    fn parameters(&self) -> Vec<String> {
112        self.0.parameters()
113    }
114    /// The free parameters used by the Expression
115    #[getter]
116    fn free_parameters(&self) -> Vec<String> {
117        self.0.free_parameters()
118    }
119    /// The fixed parameters used by the Expression
120    #[getter]
121    fn fixed_parameters(&self) -> Vec<String> {
122        self.0.fixed_parameters()
123    }
124    /// Number of free parameters
125    #[getter]
126    fn n_free(&self) -> usize {
127        self.0.n_free()
128    }
129    /// Number of fixed parameters
130    #[getter]
131    fn n_fixed(&self) -> usize {
132        self.0.n_fixed()
133    }
134    /// Total number of parameters
135    #[getter]
136    fn n_parameters(&self) -> usize {
137        self.0.n_parameters()
138    }
139    /// Load an Expression by precalculating each term over the given Dataset
140    ///
141    /// Parameters
142    /// ----------
143    /// dataset : Dataset
144    ///     The Dataset to use in precalculation
145    ///
146    /// Returns
147    /// -------
148    /// Evaluator
149    ///     An object that can be used to evaluate the `expression` over each event in the
150    ///     `dataset`
151    fn load(&self, dataset: &PyDataset) -> PyResult<PyEvaluator> {
152        Ok(PyEvaluator(self.0.load(&dataset.0)?))
153    }
154    /// The real part of a complex Expression
155    fn real(&self) -> PyExpression {
156        PyExpression(self.0.real())
157    }
158    /// The imaginary part of a complex Expression
159    fn imag(&self) -> PyExpression {
160        PyExpression(self.0.imag())
161    }
162    /// The complex conjugate of a complex Expression
163    fn conj(&self) -> PyExpression {
164        PyExpression(self.0.conj())
165    }
166    /// The norm-squared of a complex Expression
167    fn norm_sqr(&self) -> PyExpression {
168        PyExpression(self.0.norm_sqr())
169    }
170    /// The square root of an Expression
171    fn sqrt(&self) -> PyExpression {
172        PyExpression(self.0.sqrt())
173    }
174    /// Raise an Expression to an int, float, or Expression power
175    fn power(&self, power: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
176        if let Ok(expression) = power.extract::<PyExpression>() {
177            Ok(PyExpression(self.0.pow(&expression.0)))
178        } else if let Ok(value) = power.extract::<i32>() {
179            Ok(PyExpression(self.0.powi(value)))
180        } else if let Ok(value) = power.extract::<f64>() {
181            Ok(PyExpression(self.0.powf(value)))
182        } else {
183            Err(PyTypeError::new_err(
184                "power must be an int, float, or Expression",
185            ))
186        }
187    }
188    /// The exponential of an Expression
189    fn exp(&self) -> PyExpression {
190        PyExpression(self.0.exp())
191    }
192    /// The sine of an Expression
193    fn sin(&self) -> PyExpression {
194        PyExpression(self.0.sin())
195    }
196    /// The cosine of an Expression
197    fn cos(&self) -> PyExpression {
198        PyExpression(self.0.cos())
199    }
200    /// The natural logarithm of an Expression
201    fn log(&self) -> PyExpression {
202        PyExpression(self.0.log())
203    }
204    /// The complex phase factor exp(i * expression)
205    fn cis(&self) -> PyExpression {
206        PyExpression(self.0.cis())
207    }
208    /// Return a new Expression with the given parameter fixed
209    fn fix(&self, name: &str, value: f64) -> PyResult<PyExpression> {
210        Ok(PyExpression(self.0.fix(name, value)?))
211    }
212    /// Return a new Expression with the given parameter freed
213    fn free(&self, name: &str) -> PyResult<PyExpression> {
214        Ok(PyExpression(self.0.free(name)?))
215    }
216    /// Return a new Expression with a single parameter renamed
217    fn rename_parameter(&self, old: &str, new: &str) -> PyResult<PyExpression> {
218        Ok(PyExpression(self.0.rename_parameter(old, new)?))
219    }
220    /// Return a new Expression with several parameters renamed
221    fn rename_parameters(&self, mapping: HashMap<String, String>) -> PyResult<PyExpression> {
222        Ok(PyExpression(self.0.rename_parameters(&mapping)?))
223    }
224    /// Return a tree-like diagnostic view of the compiled Expression.
225    #[getter]
226    fn compiled_expression(&self) -> PyCompiledExpression {
227        PyCompiledExpression(self.0.compiled_expression())
228    }
229    fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
230        if let Ok(other_expr) = other.extract::<PyExpression>() {
231            Ok(PyExpression(self.0.clone() + other_expr.0))
232        } else if let Ok(other_int) = other.extract::<usize>() {
233            if other_int == 0 {
234                Ok(PyExpression(self.0.clone()))
235            } else {
236                Err(PyTypeError::new_err(
237                    "Addition with an integer for this type is only defined for 0",
238                ))
239            }
240        } else {
241            Err(PyTypeError::new_err("Unsupported operand type for +"))
242        }
243    }
244    fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
245        if let Ok(other_expr) = other.extract::<PyExpression>() {
246            Ok(PyExpression(other_expr.0 + self.0.clone()))
247        } else if let Ok(other_int) = other.extract::<usize>() {
248            if other_int == 0 {
249                Ok(PyExpression(self.0.clone()))
250            } else {
251                Err(PyTypeError::new_err(
252                    "Addition with an integer for this type is only defined for 0",
253                ))
254            }
255        } else {
256            Err(PyTypeError::new_err("Unsupported operand type for +"))
257        }
258    }
259    fn __sub__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
260        if let Ok(other_expr) = other.extract::<PyExpression>() {
261            Ok(PyExpression(self.0.clone() - other_expr.0))
262        } else {
263            Err(PyTypeError::new_err("Unsupported operand type for -"))
264        }
265    }
266    fn __rsub__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
267        if let Ok(other_expr) = other.extract::<PyExpression>() {
268            Ok(PyExpression(other_expr.0 - self.0.clone()))
269        } else {
270            Err(PyTypeError::new_err("Unsupported operand type for -"))
271        }
272    }
273    fn __mul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
274        if let Ok(other_expr) = other.extract::<PyExpression>() {
275            Ok(PyExpression(self.0.clone() * other_expr.0))
276        } else {
277            Err(PyTypeError::new_err("Unsupported operand type for *"))
278        }
279    }
280    fn __rmul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
281        if let Ok(other_expr) = other.extract::<PyExpression>() {
282            Ok(PyExpression(other_expr.0 * self.0.clone()))
283        } else {
284            Err(PyTypeError::new_err("Unsupported operand type for *"))
285        }
286    }
287    fn __truediv__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
288        if let Ok(other_expr) = other.extract::<PyExpression>() {
289            Ok(PyExpression(self.0.clone() / other_expr.0))
290        } else {
291            Err(PyTypeError::new_err("Unsupported operand type for /"))
292        }
293    }
294    fn __rtruediv__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
295        if let Ok(other_expr) = other.extract::<PyExpression>() {
296            Ok(PyExpression(other_expr.0 / self.0.clone()))
297        } else {
298            Err(PyTypeError::new_err("Unsupported operand type for /"))
299        }
300    }
301    fn __neg__(&self) -> PyExpression {
302        PyExpression(-self.0.clone())
303    }
304    fn __str__(&self) -> String {
305        format!("{}", self.0)
306    }
307    fn __repr__(&self) -> String {
308        format!("{:?}", self.0)
309    }
310
311    #[new]
312    fn new() -> Self {
313        Self(Expression::create_null())
314    }
315    fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
316        Ok(PyBytes::new(
317            py,
318            serde_pickle::to_vec(&self.0, serde_pickle::SerOptions::new())
319                .map_err(LadduError::PickleError)?
320                .as_slice(),
321        ))
322    }
323    fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
324        *self = Self(
325            serde_pickle::from_slice(state.as_bytes(), serde_pickle::DeOptions::new())
326                .map_err(LadduError::PickleError)?,
327        );
328        Ok(())
329    }
330}
331
/// A class which can be used to evaluate a stored Expression
///
/// Exposed to Python as ``Evaluator`` in the ``laddu`` module. The single
/// public field holds the wrapped core ``Evaluator``.
///
/// See Also
/// --------
/// laddu.Expression.load
///
#[pyclass(name = "Evaluator", module = "laddu", from_py_object)]
#[derive(Clone)]
pub struct PyEvaluator(pub Evaluator);
341
342#[pymethods]
343impl PyEvaluator {
344    /// The free parameters used by the Evaluator
345    ///
346    /// Returns
347    /// -------
348    /// parameters : list of str
349    ///     The list of parameter names
350    ///
351    #[getter]
352    fn parameters(&self) -> Vec<String> {
353        self.0.parameters()
354    }
355    /// The free parameters used by the Evaluator
356    #[getter]
357    fn free_parameters(&self) -> Vec<String> {
358        self.0.free_parameters()
359    }
360    /// The fixed parameters used by the Evaluator
361    #[getter]
362    fn fixed_parameters(&self) -> Vec<String> {
363        self.0.fixed_parameters()
364    }
365    /// Number of free parameters
366    #[getter]
367    fn n_free(&self) -> usize {
368        self.0.n_free()
369    }
370    /// Number of fixed parameters
371    #[getter]
372    fn n_fixed(&self) -> usize {
373        self.0.n_fixed()
374    }
375    /// Total number of parameters
376    #[getter]
377    fn n_parameters(&self) -> usize {
378        self.0.n_parameters()
379    }
380    /// Return a new Evaluator with the given parameter fixed
381    fn fix(&self, name: &str, value: f64) -> PyResult<PyEvaluator> {
382        Ok(PyEvaluator(self.0.fix(name, value)?))
383    }
384    /// Return a new Evaluator with the given parameter freed
385    fn free(&self, name: &str) -> PyResult<PyEvaluator> {
386        Ok(PyEvaluator(self.0.free(name)?))
387    }
388    /// Return a new Evaluator with a single parameter renamed
389    fn rename_parameter(&self, old: &str, new: &str) -> PyResult<PyEvaluator> {
390        Ok(PyEvaluator(self.0.rename_parameter(old, new)?))
391    }
392    /// Return a new Evaluator with several parameters renamed
393    fn rename_parameters(&self, mapping: HashMap<String, String>) -> PyResult<PyEvaluator> {
394        Ok(PyEvaluator(self.0.rename_parameters(&mapping)?))
395    }
396    /// Activates Amplitudes in the Expression by name or glob selector
397    ///
398    /// Parameters
399    /// ----------
400    /// arg : str or list of str
401    ///     Names or ``*``/``?`` glob selectors of Amplitudes to be activated
402    ///
403    /// Raises
404    /// ------
405    /// TypeError
406    ///     If `arg` is not a str or list of str
407    /// ValueError
408    ///     If `arg` or any items of `arg` are not registered Amplitudes
409    /// strict : bool, default=True
410    ///     When ``True``, raise an error if any selector matches no amplitudes. When
411    ///     ``False``, silently skip selectors with no matches.
412    #[pyo3(signature = (arg, *, strict=true))]
413    fn activate(&self, arg: &Bound<'_, PyAny>, strict: bool) -> PyResult<()> {
414        if let Ok(string_arg) = arg.extract::<String>() {
415            if strict {
416                self.0.activate_strict(&string_arg)?;
417            } else {
418                self.0.activate(&string_arg);
419            }
420        } else if let Ok(list_arg) = arg.cast::<PyList>() {
421            let vec: Vec<String> = list_arg.extract()?;
422            if strict {
423                self.0.activate_many_strict(&vec)?;
424            } else {
425                self.0.activate_many(&vec);
426            }
427        } else {
428            return Err(PyTypeError::new_err(
429                "Argument must be either a string or a list of strings",
430            ));
431        }
432        Ok(())
433    }
434    /// Activates all Amplitudes in the Expression
435    ///
436    fn activate_all(&self) {
437        self.0.activate_all();
438    }
439    /// Deactivates Amplitudes in the Expression by name or glob selector
440    ///
441    /// Deactivated Amplitudes act as zeros in the Expression
442    ///
443    /// Parameters
444    /// ----------
445    /// arg : str or list of str
446    ///     Names or ``*``/``?`` glob selectors of Amplitudes to be deactivated
447    ///
448    /// Raises
449    /// ------
450    /// TypeError
451    ///     If `arg` is not a str or list of str
452    /// ValueError
453    ///     If `arg` or any items of `arg` are not registered Amplitudes
454    /// strict : bool, default=True
455    ///     When ``True``, raise an error if any selector matches no amplitudes. When
456    ///     ``False``, silently skip selectors with no matches.
457    #[pyo3(signature = (arg, *, strict=true))]
458    fn deactivate(&self, arg: &Bound<'_, PyAny>, strict: bool) -> PyResult<()> {
459        if let Ok(string_arg) = arg.extract::<String>() {
460            if strict {
461                self.0.deactivate_strict(&string_arg)?;
462            } else {
463                self.0.deactivate(&string_arg);
464            }
465        } else if let Ok(list_arg) = arg.cast::<PyList>() {
466            let vec: Vec<String> = list_arg.extract()?;
467            if strict {
468                self.0.deactivate_many_strict(&vec)?;
469            } else {
470                self.0.deactivate_many(&vec);
471            }
472        } else {
473            return Err(PyTypeError::new_err(
474                "Argument must be either a string or a list of strings",
475            ));
476        }
477        Ok(())
478    }
479    /// Deactivates all Amplitudes in the Expression
480    ///
481    fn deactivate_all(&self) {
482        self.0.deactivate_all();
483    }
484    /// Isolates Amplitudes in the Expression by name or glob selector
485    ///
486    /// Activates the Amplitudes given in `arg` and deactivates the rest
487    ///
488    /// Parameters
489    /// ----------
490    /// arg : str or list of str
491    ///     Names or ``*``/``?`` glob selectors of Amplitudes to be isolated
492    ///
493    /// Raises
494    /// ------
495    /// TypeError
496    ///     If `arg` is not a str or list of str
497    /// ValueError
498    ///     If `arg` or any items of `arg` are not registered Amplitudes
499    /// strict : bool, default=True
500    ///     When ``True``, raise an error if any selector matches no amplitudes. When
501    ///     ``False``, silently skip selectors with no matches.
502    #[pyo3(signature = (arg, *, strict=true))]
503    fn isolate(&self, arg: &Bound<'_, PyAny>, strict: bool) -> PyResult<()> {
504        if let Ok(string_arg) = arg.extract::<String>() {
505            if strict {
506                self.0.isolate_strict(&string_arg)?;
507            } else {
508                self.0.isolate(&string_arg);
509            }
510        } else if let Ok(list_arg) = arg.cast::<PyList>() {
511            let vec: Vec<String> = list_arg.extract()?;
512            if strict {
513                self.0.isolate_many_strict(&vec)?;
514            } else {
515                self.0.isolate_many(&vec);
516            }
517        } else {
518            return Err(PyTypeError::new_err(
519                "Argument must be either a string or a list of strings",
520            ));
521        }
522        Ok(())
523    }
524
525    /// Return the current active-amplitude mask.
526    #[getter]
527    fn active_mask(&self) -> Vec<bool> {
528        self.0.active_mask()
529    }
530
531    /// Apply an active-amplitude mask.
532    fn set_active_mask(&self, mask: Vec<bool>) -> PyResult<()> {
533        self.0.set_active_mask(&mask)?;
534        Ok(())
535    }
536
537    /// Return a tree-like diagnostic view of the compiled Expression.
538    #[getter]
539    fn compiled_expression(&self) -> PyCompiledExpression {
540        PyCompiledExpression(self.0.compiled_expression())
541    }
542
543    /// Return the Expression represented by this Evaluator.
544    #[getter]
545    fn expression(&self) -> PyExpression {
546        PyExpression(self.0.expression())
547    }
548
549    /// Evaluate the stored Expression over the stored Dataset
550    ///
551    /// Parameters
552    /// ----------
553    /// parameters : list of float
554    ///     The values to use for the free parameters
555    /// threads : int, optional
556    ///     The number of threads to use (setting this to ``None`` or ``0`` uses the current
557    ///     global or context-managed default; any positive value overrides that default for
558    ///     this call only)
559    ///
560    /// Returns
561    /// -------
562    /// result : array_like
563    ///     A ``numpy`` array of complex values for each Event in the Dataset
564    ///
565    /// Raises
566    /// ------
567    /// Exception
568    ///     If there was an error building the thread pool
569    ///
570    #[pyo3(signature = (parameters, *, threads=None))]
571    fn evaluate<'py>(
572        &self,
573        py: Python<'py>,
574        parameters: Vec<f64>,
575        threads: Option<usize>,
576    ) -> PyResult<Bound<'py, PyArray1<Complex64>>> {
577        let values = install_with_threads(threads, || self.0.evaluate(&parameters))?;
578        Ok(PyArray1::from_slice(py, &values))
579    }
580    /// Evaluate the stored Expression over a subset of the stored Dataset
581    ///
582    /// Parameters
583    /// ----------
584    /// parameters : list of float
585    ///     The values to use for the free parameters
586    /// indices : list of int
587    ///     The indices of events to evaluate
588    /// threads : int, optional
589    ///     The number of threads to use (setting this to ``None`` or ``0`` uses the current
590    ///     global or context-managed default; any positive value overrides that default for
591    ///     this call only)
592    ///
593    /// Returns
594    /// -------
595    /// result : array_like
596    ///     A ``numpy`` array of complex values for each indexed Event in the Dataset
597    ///
598    /// Raises
599    /// ------
600    /// Exception
601    ///     If there was an error building the thread pool
602    ///
603    #[pyo3(signature = (parameters, indices, *, threads=None))]
604    fn evaluate_batch<'py>(
605        &self,
606        py: Python<'py>,
607        parameters: Vec<f64>,
608        indices: Vec<usize>,
609        threads: Option<usize>,
610    ) -> PyResult<Bound<'py, PyArray1<Complex64>>> {
611        let values =
612            install_with_threads(threads, || self.0.evaluate_batch(&parameters, &indices))?;
613        Ok(PyArray1::from_slice(py, &values))
614    }
615    /// Evaluate the gradient of the stored Expression over the stored Dataset
616    ///
617    /// Parameters
618    /// ----------
619    /// parameters : list of float
620    ///     The values to use for the free parameters
621    /// threads : int, optional
622    ///     The number of threads to use (setting this to ``None`` or ``0`` uses the current
623    ///     global or context-managed default; any positive value overrides that default for
624    ///     this call only)
625    ///
626    /// Returns
627    /// -------
628    /// result : array_like
629    ///     A ``numpy`` 2D array of complex values for each Event in the Dataset
630    ///
631    /// Raises
632    /// ------
633    /// Exception
634    ///     If there was an error building the thread pool or problem creating the resulting
635    ///     ``numpy`` array
636    ///
637    #[pyo3(signature = (parameters, *, threads=None))]
638    fn evaluate_gradient<'py>(
639        &self,
640        py: Python<'py>,
641        parameters: Vec<f64>,
642        threads: Option<usize>,
643    ) -> PyResult<Bound<'py, PyArray2<Complex64>>> {
644        let gradients = install_with_threads(threads, || {
645            self.0
646                .evaluate_gradient(&parameters)
647                .iter()
648                .map(|grad| grad.data.as_vec().to_vec())
649                .collect::<Vec<Vec<Complex64>>>()
650        })?;
651        Ok(PyArray2::from_vec2(py, &gradients).map_err(LadduError::NumpyError)?)
652    }
653    /// Evaluate the gradient of the stored Expression over a subset of the stored Dataset
654    ///
655    /// Parameters
656    /// ----------
657    /// parameters : list of float
658    ///     The values to use for the free parameters
659    /// indices : list of int
660    ///     The indices of events to evaluate
661    /// threads : int, optional
662    ///     The number of threads to use (setting this to ``None`` or ``0`` uses the current
663    ///     global or context-managed default; any positive value overrides that default for
664    ///     this call only)
665    ///
666    /// Returns
667    /// -------
668    /// result : array_like
669    ///     A ``numpy`` 2D array of complex values for each indexed Event in the Dataset
670    ///
671    /// Raises
672    /// ------
673    /// Exception
674    ///     If there was an error building the thread pool or problem creating the resulting
675    ///     ``numpy`` array
676    ///
677    #[pyo3(signature = (parameters, indices, *, threads=None))]
678    fn evaluate_gradient_batch<'py>(
679        &self,
680        py: Python<'py>,
681        parameters: Vec<f64>,
682        indices: Vec<usize>,
683        threads: Option<usize>,
684    ) -> PyResult<Bound<'py, PyArray2<Complex64>>> {
685        let gradients = install_with_threads(threads, || {
686            self.0
687                .evaluate_gradient_batch(&parameters, &indices)
688                .iter()
689                .map(|grad| grad.data.as_vec().to_vec())
690                .collect::<Vec<Vec<Complex64>>>()
691        })?;
692        Ok(PyArray2::from_vec2(py, &gradients).map_err(LadduError::NumpyError)?)
693    }
694}
695
/// A class which can be used to display the compiled form of an Expression
///
/// Exposed to Python as ``CompiledExpression`` in the ``laddu`` module. The
/// single public field holds the wrapped core ``CompiledExpression``.
///
/// Notes
/// -----
/// This should not be used for anything other than diagnostic purposes.
///
#[pyclass(name = "CompiledExpression", module = "laddu", from_py_object)]
#[derive(Clone)]
pub struct PyCompiledExpression(pub CompiledExpression);
705
706#[pymethods]
707impl PyCompiledExpression {
708    fn __str__(&self) -> String {
709        format!("{}", self.0)
710    }
711    fn __repr__(&self) -> String {
712        format!("{:?}", self.0)
713    }
714}
715
/// A class, typically used to allow Amplitudes to take either free parameters or constants as
/// inputs
///
/// Exposed to Python as ``ParameterLike`` in the ``laddu`` module. The single
/// public field holds the wrapped core ``ParameterLike``.
///
/// See Also
/// --------
/// laddu.parameter
/// laddu.constant
///
#[pyclass(name = "ParameterLike", module = "laddu", from_py_object)]
#[derive(Clone)]
pub struct PyParameterLike(pub ParameterLike);
727
728/// A free parameter which floats during an optimization
729///
730/// Parameters
731/// ----------
732/// name : str
733///     The name of the free parameter
734///
735/// Returns
736/// -------
737/// laddu.ParameterLike
738///     An object that can be used as the input for many Amplitude constructors
739///
740/// Notes
741/// -----
742/// Two free parameters with the same name are shared in a fit
743///
744#[pyfunction(name = "parameter")]
745pub fn py_parameter(name: &str) -> PyParameterLike {
746    PyParameterLike(parameter(name))
747}
748
749/// A term which stays constant during an optimization
750///
751/// Parameters
752/// ----------
753/// name : str
754///     The name of the parameter
755/// value : float
756///     The numerical value of the constant
757///
758/// Returns
759/// -------
760/// laddu.ParameterLike
761///     An object that can be used as the input for many Amplitude constructors
762///
763#[pyfunction(name = "constant")]
764pub fn py_constant(name: &str, value: f64) -> PyParameterLike {
765    PyParameterLike(constant(name, value))
766}
767
768/// An amplitude used only for internal testing which evaluates `(p0 + i * p1) * event.p4s\[0\].e`.
769#[pyfunction(name = "TestAmplitude")]
770pub fn py_test_amplitude(
771    name: &str,
772    re: PyParameterLike,
773    im: PyParameterLike,
774) -> PyResult<PyExpression> {
775    Ok(PyExpression(TestAmplitude::new(name, re.0, im.0)?))
776}