laddu_python/
amplitudes.rs

1use crate::data::PyDataset;
2use laddu_core::{
3    amplitudes::{
4        constant, parameter, Amplitude, AmplitudeID, Evaluator, Expression, Manager, Model,
5        ParameterLike, TestAmplitude,
6    },
7    traits::ReadWrite,
8    Float, LadduError,
9};
10use num::Complex;
11use numpy::{PyArray1, PyArray2};
12use pyo3::{
13    exceptions::PyTypeError,
14    prelude::*,
15    types::{PyBytes, PyList},
16};
17#[cfg(feature = "rayon")]
18use rayon::ThreadPoolBuilder;
19
/// An object which holds a registered ``Amplitude``
///
/// AmplitudeIDs are returned by ``Manager.register`` and can be combined with other
/// AmplitudeIDs or Expressions using ``+``, ``-``, ``*``, ``/`` and unary ``-`` to build an
/// Expression.
///
/// See Also
/// --------
/// laddu.Manager.register
///
#[pyclass(name = "AmplitudeID", module = "laddu")]
// `Clone` lets PyO3 extract this class by value (`extract::<PyAmplitudeID>()`), which the
// expression-building functions in this module rely on.
#[derive(Clone)]
pub struct PyAmplitudeID(AmplitudeID);
29
/// A mathematical expression formed from AmplitudeIDs
///
/// Expressions are built by combining AmplitudeIDs and other Expressions with ``+``, ``-``,
/// ``*``, ``/`` and unary ``-``, and can be turned into a Model with ``Manager.model``.
///
#[pyclass(name = "Expression", module = "laddu")]
// `Clone` lets PyO3 extract this class by value (`extract::<PyExpression>()`), as the
// arithmetic methods and `Manager.model` do.
#[derive(Clone)]
pub struct PyExpression(Expression);
35
36/// A convenience method to sum sequences of Amplitudes
37///
38#[pyfunction(name = "amplitude_sum")]
39pub fn py_amplitude_sum(terms: Vec<Bound<'_, PyAny>>) -> PyResult<PyExpression> {
40    if terms.is_empty() {
41        return Ok(PyExpression(Expression::Zero));
42    }
43    if terms.len() == 1 {
44        let term = &terms[0];
45        if let Ok(expression) = term.extract::<PyExpression>() {
46            return Ok(expression);
47        }
48        if let Ok(py_amplitude_id) = term.extract::<PyAmplitudeID>() {
49            return Ok(PyExpression(Expression::Amp(py_amplitude_id.0)));
50        }
51        return Err(PyTypeError::new_err(
52            "Item is neither a PyExpression nor a PyAmplitudeID",
53        ));
54    }
55    let mut iter = terms.iter();
56    let Some(first_term) = iter.next() else {
57        return Ok(PyExpression(Expression::Zero));
58    };
59    if let Ok(first_expression) = first_term.extract::<PyExpression>() {
60        let mut summation = first_expression.clone();
61        for term in iter {
62            summation = summation.__add__(term)?;
63        }
64        return Ok(summation);
65    }
66    if let Ok(first_amplitude_id) = first_term.extract::<PyAmplitudeID>() {
67        let mut summation = PyExpression(Expression::Amp(first_amplitude_id.0));
68        for term in iter {
69            summation = summation.__add__(term)?;
70        }
71        return Ok(summation);
72    }
73    Err(PyTypeError::new_err(
74        "Elements must be PyExpression or PyAmplitudeID",
75    ))
76}
77
78/// A convenience method to multiply sequences of Amplitudes
79///
80#[pyfunction(name = "amplitude_product")]
81pub fn py_amplitude_product(terms: Vec<Bound<'_, PyAny>>) -> PyResult<PyExpression> {
82    if terms.is_empty() {
83        return Ok(PyExpression(Expression::One));
84    }
85    if terms.len() == 1 {
86        let term = &terms[0];
87        if let Ok(expression) = term.extract::<PyExpression>() {
88            return Ok(expression);
89        }
90        if let Ok(py_amplitude_id) = term.extract::<PyAmplitudeID>() {
91            return Ok(PyExpression(Expression::Amp(py_amplitude_id.0)));
92        }
93        return Err(PyTypeError::new_err(
94            "Item is neither a PyExpression nor a PyAmplitudeID",
95        ));
96    }
97    let mut iter = terms.iter();
98    let Some(first_term) = iter.next() else {
99        return Ok(PyExpression(Expression::One));
100    };
101    if let Ok(first_expression) = first_term.extract::<PyExpression>() {
102        let mut product = first_expression.clone();
103        for term in iter {
104            product = product.__mul__(term)?;
105        }
106        return Ok(product);
107    }
108    if let Ok(first_amplitude_id) = first_term.extract::<PyAmplitudeID>() {
109        let mut product = PyExpression(Expression::Amp(first_amplitude_id.0));
110        for term in iter {
111            product = product.__mul__(term)?;
112        }
113        return Ok(product);
114    }
115    Err(PyTypeError::new_err(
116        "Elements must be PyExpression or PyAmplitudeID",
117    ))
118}
119
120/// A convenience class representing a zero-valued Expression
121///
122#[pyfunction(name = "AmplitudeZero")]
123pub fn py_amplitude_zero() -> PyExpression {
124    PyExpression(Expression::Zero)
125}
126
127/// A convenience class representing a unit-valued Expression
128///
129#[pyfunction(name = "AmplitudeOne")]
130pub fn py_amplitude_one() -> PyExpression {
131    PyExpression(Expression::One)
132}
133
134#[pymethods]
135impl PyAmplitudeID {
136    /// The real part of a complex Amplitude
137    ///
138    /// Returns
139    /// -------
140    /// Expression
141    ///     The real part of the given Amplitude
142    ///
143    fn real(&self) -> PyExpression {
144        PyExpression(self.0.real())
145    }
146    /// The imaginary part of a complex Amplitude
147    ///
148    /// Returns
149    /// -------
150    /// Expression
151    ///     The imaginary part of the given Amplitude
152    ///
153    fn imag(&self) -> PyExpression {
154        PyExpression(self.0.imag())
155    }
156    /// The complex conjugate of a complex Amplitude
157    ///
158    /// Returns
159    /// -------
160    /// Expression
161    ///     The complex conjugate of the given Amplitude
162    ///
163    fn conj(&self) -> PyExpression {
164        PyExpression(self.0.conj())
165    }
166    /// The norm-squared of a complex Amplitude
167    ///
168    /// This is computed as :math:`AA^*` where :math:`A^*` is the complex conjugate
169    ///
170    /// Returns
171    /// -------
172    /// Expression
173    ///     The norm-squared of the given Amplitude
174    ///
175    fn norm_sqr(&self) -> PyExpression {
176        PyExpression(self.0.norm_sqr())
177    }
178    fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
179        if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
180            Ok(PyExpression(self.0.clone() + other_aid.0.clone()))
181        } else if let Ok(other_expr) = other.extract::<PyExpression>() {
182            Ok(PyExpression(self.0.clone() + other_expr.0.clone()))
183        } else if let Ok(other_int) = other.extract::<usize>() {
184            if other_int == 0 {
185                Ok(PyExpression(Expression::Amp(self.0.clone())))
186            } else {
187                Err(PyTypeError::new_err(
188                    "Addition with an integer for this type is only defined for 0",
189                ))
190            }
191        } else {
192            Err(PyTypeError::new_err("Unsupported operand type for +"))
193        }
194    }
195    fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
196        if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
197            Ok(PyExpression(other_aid.0.clone() + self.0.clone()))
198        } else if let Ok(other_expr) = other.extract::<PyExpression>() {
199            Ok(PyExpression(other_expr.0.clone() + self.0.clone()))
200        } else if let Ok(other_int) = other.extract::<usize>() {
201            if other_int == 0 {
202                Ok(PyExpression(Expression::Amp(self.0.clone())))
203            } else {
204                Err(PyTypeError::new_err(
205                    "Addition with an integer for this type is only defined for 0",
206                ))
207            }
208        } else {
209            Err(PyTypeError::new_err("Unsupported operand type for +"))
210        }
211    }
212    fn __sub__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
213        if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
214            Ok(PyExpression(self.0.clone() - other_aid.0.clone()))
215        } else if let Ok(other_expr) = other.extract::<PyExpression>() {
216            Ok(PyExpression(self.0.clone() - other_expr.0.clone()))
217        } else {
218            Err(PyTypeError::new_err("Unsupported operand type for -"))
219        }
220    }
221    fn __rsub__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
222        if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
223            Ok(PyExpression(other_aid.0.clone() - self.0.clone()))
224        } else if let Ok(other_expr) = other.extract::<PyExpression>() {
225            Ok(PyExpression(other_expr.0.clone() - self.0.clone()))
226        } else {
227            Err(PyTypeError::new_err("Unsupported operand type for -"))
228        }
229    }
230    fn __mul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
231        if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
232            Ok(PyExpression(self.0.clone() * other_aid.0.clone()))
233        } else if let Ok(other_expr) = other.extract::<PyExpression>() {
234            Ok(PyExpression(self.0.clone() * other_expr.0.clone()))
235        } else {
236            Err(PyTypeError::new_err("Unsupported operand type for *"))
237        }
238    }
239    fn __rmul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
240        if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
241            Ok(PyExpression(other_aid.0.clone() * self.0.clone()))
242        } else if let Ok(other_expr) = other.extract::<PyExpression>() {
243            Ok(PyExpression(other_expr.0.clone() * self.0.clone()))
244        } else {
245            Err(PyTypeError::new_err("Unsupported operand type for *"))
246        }
247    }
248    fn __truediv__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
249        if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
250            Ok(PyExpression(self.0.clone() / other_aid.0.clone()))
251        } else if let Ok(other_expr) = other.extract::<PyExpression>() {
252            Ok(PyExpression(self.0.clone() / other_expr.0.clone()))
253        } else {
254            Err(PyTypeError::new_err("Unsupported operand type for /"))
255        }
256    }
257    fn __rtruediv__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
258        if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
259            Ok(PyExpression(other_aid.0.clone() / self.0.clone()))
260        } else if let Ok(other_expr) = other.extract::<PyExpression>() {
261            Ok(PyExpression(other_expr.0.clone() / self.0.clone()))
262        } else {
263            Err(PyTypeError::new_err("Unsupported operand type for /"))
264        }
265    }
266    fn __neg__(&self) -> PyExpression {
267        PyExpression(-self.0.clone())
268    }
269    fn __str__(&self) -> String {
270        format!("{}", self.0)
271    }
272    fn __repr__(&self) -> String {
273        format!("{:?}", self.0)
274    }
275}
276
277#[pymethods]
278impl PyExpression {
279    /// The real part of a complex Expression
280    ///
281    /// Returns
282    /// -------
283    /// Expression
284    ///     The real part of the given Expression
285    ///
286    fn real(&self) -> PyExpression {
287        PyExpression(self.0.real())
288    }
289    /// The imaginary part of a complex Expression
290    ///
291    /// Returns
292    /// -------
293    /// Expression
294    ///     The imaginary part of the given Expression
295    ///
296    fn imag(&self) -> PyExpression {
297        PyExpression(self.0.imag())
298    }
299    /// The complex conjugate of a complex Expression
300    ///
301    /// Returns
302    /// -------
303    /// Expression
304    ///     The complex conjugate of the given Expression
305    ///
306    fn conj(&self) -> PyExpression {
307        PyExpression(self.0.conj())
308    }
309    /// The norm-squared of a complex Expression
310    ///
311    /// This is computed as :math:`AA^*` where :math:`A^*` is the complex conjugate
312    ///
313    /// Returns
314    /// -------
315    /// Expression
316    ///     The norm-squared of the given Expression
317    ///
318    fn norm_sqr(&self) -> PyExpression {
319        PyExpression(self.0.norm_sqr())
320    }
321    fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
322        if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
323            Ok(PyExpression(self.0.clone() + other_aid.0.clone()))
324        } else if let Ok(other_expr) = other.extract::<PyExpression>() {
325            Ok(PyExpression(self.0.clone() + other_expr.0.clone()))
326        } else if let Ok(other_int) = other.extract::<usize>() {
327            if other_int == 0 {
328                Ok(PyExpression(self.0.clone()))
329            } else {
330                Err(PyTypeError::new_err(
331                    "Addition with an integer for this type is only defined for 0",
332                ))
333            }
334        } else {
335            Err(PyTypeError::new_err("Unsupported operand type for +"))
336        }
337    }
338    fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
339        if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
340            Ok(PyExpression(other_aid.0.clone() + self.0.clone()))
341        } else if let Ok(other_expr) = other.extract::<PyExpression>() {
342            Ok(PyExpression(other_expr.0.clone() + self.0.clone()))
343        } else if let Ok(other_int) = other.extract::<usize>() {
344            if other_int == 0 {
345                Ok(PyExpression(self.0.clone()))
346            } else {
347                Err(PyTypeError::new_err(
348                    "Addition with an integer for this type is only defined for 0",
349                ))
350            }
351        } else {
352            Err(PyTypeError::new_err("Unsupported operand type for +"))
353        }
354    }
355    fn __sub__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
356        if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
357            Ok(PyExpression(self.0.clone() - other_aid.0.clone()))
358        } else if let Ok(other_expr) = other.extract::<PyExpression>() {
359            Ok(PyExpression(self.0.clone() - other_expr.0.clone()))
360        } else {
361            Err(PyTypeError::new_err("Unsupported operand type for -"))
362        }
363    }
364    fn __rsub__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
365        if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
366            Ok(PyExpression(other_aid.0.clone() - self.0.clone()))
367        } else if let Ok(other_expr) = other.extract::<PyExpression>() {
368            Ok(PyExpression(other_expr.0.clone() - self.0.clone()))
369        } else {
370            Err(PyTypeError::new_err("Unsupported operand type for -"))
371        }
372    }
373    fn __mul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
374        if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
375            Ok(PyExpression(self.0.clone() * other_aid.0.clone()))
376        } else if let Ok(other_expr) = other.extract::<PyExpression>() {
377            Ok(PyExpression(self.0.clone() * other_expr.0.clone()))
378        } else {
379            Err(PyTypeError::new_err("Unsupported operand type for *"))
380        }
381    }
382    fn __rmul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
383        if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
384            Ok(PyExpression(other_aid.0.clone() * self.0.clone()))
385        } else if let Ok(other_expr) = other.extract::<PyExpression>() {
386            Ok(PyExpression(other_expr.0.clone() * self.0.clone()))
387        } else {
388            Err(PyTypeError::new_err("Unsupported operand type for *"))
389        }
390    }
391    fn __truediv__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
392        if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
393            Ok(PyExpression(self.0.clone() / other_aid.0.clone()))
394        } else if let Ok(other_expr) = other.extract::<PyExpression>() {
395            Ok(PyExpression(self.0.clone() / other_expr.0.clone()))
396        } else {
397            Err(PyTypeError::new_err("Unsupported operand type for /"))
398        }
399    }
400    fn __rtruediv__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
401        if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
402            Ok(PyExpression(other_aid.0.clone() / self.0.clone()))
403        } else if let Ok(other_expr) = other.extract::<PyExpression>() {
404            Ok(PyExpression(other_expr.0.clone() / self.0.clone()))
405        } else {
406            Err(PyTypeError::new_err("Unsupported operand type for /"))
407        }
408    }
409    fn __neg__(&self) -> PyExpression {
410        PyExpression(-self.0.clone())
411    }
412    fn __str__(&self) -> String {
413        format!("{}", self.0)
414    }
415    fn __repr__(&self) -> String {
416        format!("{:?}", self.0)
417    }
418}
419
420/// A class which can be used to register Amplitudes and store precalculated data
421///
422#[pyclass(name = "Manager", module = "laddu")]
423pub struct PyManager(Manager);
424
425#[pymethods]
426impl PyManager {
427    #[new]
428    fn new() -> Self {
429        Self(Manager::default())
430    }
431    /// The free parameters used by the Manager
432    ///
433    /// Returns
434    /// -------
435    /// parameters : list of str
436    ///     The list of parameter names
437    ///
438    #[getter]
439    fn parameters(&self) -> Vec<String> {
440        self.0.parameters()
441    }
442    /// Register an Amplitude with the Manager
443    ///
444    /// Parameters
445    /// ----------
446    /// amplitude : Amplitude
447    ///     The Amplitude to register
448    ///
449    /// Returns
450    /// -------
451    /// AmplitudeID
452    ///     A reference to the registered `amplitude` that can be used to form complex
453    ///     Expressions
454    ///
455    /// Raises
456    /// ------
457    /// ValueError
458    ///     If the name of the ``amplitude`` has already been registered
459    ///
460    fn register(&mut self, amplitude: &PyAmplitude) -> PyResult<PyAmplitudeID> {
461        Ok(PyAmplitudeID(self.0.register(amplitude.0.clone())?))
462    }
463    /// Generate a Model from the given expression made of registered Amplitudes
464    ///
465    /// Parameters
466    /// ----------
467    /// expression : Expression or AmplitudeID
468    ///     The expression to use in precalculation
469    ///
470    /// Returns
471    /// -------
472    /// Model
473    ///     An object which represents the underlying mathematical model and can be loaded with
474    ///     a Dataset
475    ///
476    /// Raises
477    /// ------
478    /// TypeError
479    ///     If the expression is not convertable to a Model
480    ///
481    /// Notes
482    /// -----
483    /// While the given `expression` will be the one evaluated in the end, all registered
484    /// Amplitudes will be loaded, and all of their parameters will be included in the final
485    /// expression. These parameters will have no effect on evaluation, but they must be
486    /// included in function calls.
487    ///
488    fn model(&self, expression: &Bound<'_, PyAny>) -> PyResult<PyModel> {
489        let expression = if let Ok(expression) = expression.extract::<PyExpression>() {
490            Ok(expression.0)
491        } else if let Ok(aid) = expression.extract::<PyAmplitudeID>() {
492            Ok(Expression::Amp(aid.0))
493        } else {
494            Err(PyTypeError::new_err(
495                "'expression' must either by an Expression or AmplitudeID",
496            ))
497        }?;
498        Ok(PyModel(self.0.model(&expression)))
499    }
500}
501
502/// A class which represents a model composed of registered Amplitudes
503///
504#[pyclass(name = "Model", module = "laddu")]
505pub struct PyModel(pub Model);
506
507#[pymethods]
508impl PyModel {
509    /// The free parameters used by the Manager
510    ///
511    /// Returns
512    /// -------
513    /// parameters : list of str
514    ///     The list of parameter names
515    ///
516    #[getter]
517    fn parameters(&self) -> Vec<String> {
518        self.0.parameters()
519    }
520    /// Load a Model by precalculating each term over the given Dataset
521    ///
522    /// Parameters
523    /// ----------
524    /// dataset : Dataset
525    ///     The Dataset to use in precalculation
526    ///
527    /// Returns
528    /// -------
529    /// Evaluator
530    ///     An object that can be used to evaluate the `expression` over each event in the
531    ///     `dataset`
532    ///
533    /// Notes
534    /// -----
535    /// While the given `expression` will be the one evaluated in the end, all registered
536    /// Amplitudes will be loaded, and all of their parameters will be included in the final
537    /// expression. These parameters will have no effect on evaluation, but they must be
538    /// included in function calls.
539    ///
540    fn load(&self, dataset: &PyDataset) -> PyEvaluator {
541        PyEvaluator(self.0.load(&dataset.0))
542    }
543    #[new]
544    fn new() -> Self {
545        Self(Model::create_null())
546    }
547    fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
548        Ok(PyBytes::new(
549            py,
550            bincode::serde::encode_to_vec(&self.0, bincode::config::standard())
551                .map_err(LadduError::EncodeError)?
552                .as_slice(),
553        ))
554    }
555    fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
556        *self = PyModel(
557            bincode::serde::decode_from_slice(state.as_bytes(), bincode::config::standard())
558                .map_err(LadduError::DecodeError)?
559                .0,
560        );
561        Ok(())
562    }
563}
564
/// An Amplitude which can be registered by a Manager
///
/// See Also
/// --------
/// laddu.Manager
///
#[pyclass(name = "Amplitude", module = "laddu")]
// Type-erased wrapper so any Rust `Amplitude` implementation can be handed to Python.
// There is no `#[derive(Clone)]`: `Manager.register` takes it by reference and clones the
// inner box (`amplitude.0.clone()`), so `Box<dyn Amplitude>` must itself be cloneable.
// NOTE(review): presumably laddu_core provides that impl (e.g. via dyn-clone); confirm.
pub struct PyAmplitude(pub Box<dyn Amplitude>);
573
574/// A class which can be used to evaluate a stored Expression
575///
576/// See Also
577/// --------
578/// laddu.Manager.load
579///
580#[pyclass(name = "Evaluator", module = "laddu")]
581#[derive(Clone)]
582pub struct PyEvaluator(pub Evaluator);
583
584#[pymethods]
585impl PyEvaluator {
586    /// The free parameters used by the Evaluator
587    ///
588    /// Returns
589    /// -------
590    /// parameters : list of str
591    ///     The list of parameter names
592    ///
593    #[getter]
594    fn parameters(&self) -> Vec<String> {
595        self.0.parameters()
596    }
597    /// Activates Amplitudes in the Expression by name
598    ///
599    /// Parameters
600    /// ----------
601    /// arg : str or list of str
602    ///     Names of Amplitudes to be activated
603    ///
604    /// Raises
605    /// ------
606    /// TypeError
607    ///     If `arg` is not a str or list of str
608    /// ValueError
609    ///     If `arg` or any items of `arg` are not registered Amplitudes
610    ///
611    fn activate(&self, arg: &Bound<'_, PyAny>) -> PyResult<()> {
612        if let Ok(string_arg) = arg.extract::<String>() {
613            self.0.activate(&string_arg)?;
614        } else if let Ok(list_arg) = arg.downcast::<PyList>() {
615            let vec: Vec<String> = list_arg.extract()?;
616            self.0.activate_many(&vec)?;
617        } else {
618            return Err(PyTypeError::new_err(
619                "Argument must be either a string or a list of strings",
620            ));
621        }
622        Ok(())
623    }
624    /// Activates all Amplitudes in the Expression
625    ///
626    fn activate_all(&self) {
627        self.0.activate_all();
628    }
629    /// Deactivates Amplitudes in the Expression by name
630    ///
631    /// Deactivated Amplitudes act as zeros in the Expression
632    ///
633    /// Parameters
634    /// ----------
635    /// arg : str or list of str
636    ///     Names of Amplitudes to be deactivated
637    ///
638    /// Raises
639    /// ------
640    /// TypeError
641    ///     If `arg` is not a str or list of str
642    /// ValueError
643    ///     If `arg` or any items of `arg` are not registered Amplitudes
644    ///
645    fn deactivate(&self, arg: &Bound<'_, PyAny>) -> PyResult<()> {
646        if let Ok(string_arg) = arg.extract::<String>() {
647            self.0.deactivate(&string_arg)?;
648        } else if let Ok(list_arg) = arg.downcast::<PyList>() {
649            let vec: Vec<String> = list_arg.extract()?;
650            self.0.deactivate_many(&vec)?;
651        } else {
652            return Err(PyTypeError::new_err(
653                "Argument must be either a string or a list of strings",
654            ));
655        }
656        Ok(())
657    }
658    /// Deactivates all Amplitudes in the Expression
659    ///
660    fn deactivate_all(&self) {
661        self.0.deactivate_all();
662    }
663    /// Isolates Amplitudes in the Expression by name
664    ///
665    /// Activates the Amplitudes given in `arg` and deactivates the rest
666    ///
667    /// Parameters
668    /// ----------
669    /// arg : str or list of str
670    ///     Names of Amplitudes to be isolated
671    ///
672    /// Raises
673    /// ------
674    /// TypeError
675    ///     If `arg` is not a str or list of str
676    /// ValueError
677    ///     If `arg` or any items of `arg` are not registered Amplitudes
678    ///
679    fn isolate(&self, arg: &Bound<'_, PyAny>) -> PyResult<()> {
680        if let Ok(string_arg) = arg.extract::<String>() {
681            self.0.isolate(&string_arg)?;
682        } else if let Ok(list_arg) = arg.downcast::<PyList>() {
683            let vec: Vec<String> = list_arg.extract()?;
684            self.0.isolate_many(&vec)?;
685        } else {
686            return Err(PyTypeError::new_err(
687                "Argument must be either a string or a list of strings",
688            ));
689        }
690        Ok(())
691    }
692    /// Evaluate the stored Expression over the stored Dataset
693    ///
694    /// Parameters
695    /// ----------
696    /// parameters : list of float
697    ///     The values to use for the free parameters
698    /// threads : int, optional
699    ///     The number of threads to use (setting this to None will use all available CPUs)
700    ///
701    /// Returns
702    /// -------
703    /// result : array_like
704    ///     A ``numpy`` array of complex values for each Event in the Dataset
705    ///
706    /// Raises
707    /// ------
708    /// Exception
709    ///     If there was an error building the thread pool
710    ///
711    #[pyo3(signature = (parameters, *, threads=None))]
712    fn evaluate<'py>(
713        &self,
714        py: Python<'py>,
715        parameters: Vec<Float>,
716        threads: Option<usize>,
717    ) -> PyResult<Bound<'py, PyArray1<Complex<Float>>>> {
718        #[cfg(feature = "rayon")]
719        {
720            Ok(PyArray1::from_slice(
721                py,
722                &ThreadPoolBuilder::new()
723                    .num_threads(threads.unwrap_or_else(num_cpus::get))
724                    .build()
725                    .map_err(LadduError::from)?
726                    .install(|| self.0.evaluate(&parameters)),
727            ))
728        }
729        #[cfg(not(feature = "rayon"))]
730        {
731            Ok(PyArray1::from_slice(py, &self.0.evaluate(&parameters)))
732        }
733    }
734    /// Evaluate the stored Expression over a subset of the stored Dataset
735    ///
736    /// Parameters
737    /// ----------
738    /// parameters : list of float
739    ///     The values to use for the free parameters
740    /// indices : list of int
741    ///     The indices of events to evaluate
742    /// threads : int, optional
743    ///     The number of threads to use (setting this to None will use all available CPUs)
744    ///
745    /// Returns
746    /// -------
747    /// result : array_like
748    ///     A ``numpy`` array of complex values for each indexed Event in the Dataset
749    ///
750    /// Raises
751    /// ------
752    /// Exception
753    ///     If there was an error building the thread pool
754    ///
755    #[pyo3(signature = (parameters, indices, *, threads=None))]
756    fn evaluate_batch<'py>(
757        &self,
758        py: Python<'py>,
759        parameters: Vec<Float>,
760        indices: Vec<usize>,
761        threads: Option<usize>,
762    ) -> PyResult<Bound<'py, PyArray1<Complex<Float>>>> {
763        #[cfg(feature = "rayon")]
764        {
765            Ok(PyArray1::from_slice(
766                py,
767                &ThreadPoolBuilder::new()
768                    .num_threads(threads.unwrap_or_else(num_cpus::get))
769                    .build()
770                    .map_err(LadduError::from)?
771                    .install(|| self.0.evaluate_batch(&parameters, &indices)),
772            ))
773        }
774        #[cfg(not(feature = "rayon"))]
775        {
776            Ok(PyArray1::from_slice(
777                py,
778                &self.0.evaluate_batch(&parameters, &indices),
779            ))
780        }
781    }
782    /// Evaluate the gradient of the stored Expression over the stored Dataset
783    ///
784    /// Parameters
785    /// ----------
786    /// parameters : list of float
787    ///     The values to use for the free parameters
788    /// threads : int, optional
789    ///     The number of threads to use (setting this to None will use all available CPUs)
790    ///
791    /// Returns
792    /// -------
793    /// result : array_like
794    ///     A ``numpy`` 2D array of complex values for each Event in the Dataset
795    ///
796    /// Raises
797    /// ------
798    /// Exception
799    ///     If there was an error building the thread pool or problem creating the resulting
800    ///     ``numpy`` array
801    ///
802    #[pyo3(signature = (parameters, *, threads=None))]
803    fn evaluate_gradient<'py>(
804        &self,
805        py: Python<'py>,
806        parameters: Vec<Float>,
807        threads: Option<usize>,
808    ) -> PyResult<Bound<'py, PyArray2<Complex<Float>>>> {
809        #[cfg(feature = "rayon")]
810        {
811            Ok(PyArray2::from_vec2(
812                py,
813                &ThreadPoolBuilder::new()
814                    .num_threads(threads.unwrap_or_else(num_cpus::get))
815                    .build()
816                    .map_err(LadduError::from)?
817                    .install(|| {
818                        self.0
819                            .evaluate_gradient(&parameters)
820                            .iter()
821                            .map(|grad| grad.data.as_vec().to_vec())
822                            .collect::<Vec<Vec<Complex<Float>>>>()
823                    }),
824            )
825            .map_err(LadduError::NumpyError)?)
826        }
827        #[cfg(not(feature = "rayon"))]
828        {
829            Ok(PyArray2::from_vec2(
830                py,
831                &self
832                    .0
833                    .evaluate_gradient(&parameters)
834                    .iter()
835                    .map(|grad| grad.data.as_vec().to_vec())
836                    .collect::<Vec<Vec<Complex<Float>>>>(),
837            )
838            .map_err(LadduError::NumpyError)?)
839        }
840    }
841    /// Evaluate the gradient of the stored Expression over a subset of the stored Dataset
842    ///
843    /// Parameters
844    /// ----------
845    /// parameters : list of float
846    ///     The values to use for the free parameters
847    /// indices : list of int
848    ///     The indices of events to evaluate
849    /// threads : int, optional
850    ///     The number of threads to use (setting this to None will use all available CPUs)
851    ///
852    /// Returns
853    /// -------
854    /// result : array_like
855    ///     A ``numpy`` 2D array of complex values for each indexed Event in the Dataset
856    ///
857    /// Raises
858    /// ------
859    /// Exception
860    ///     If there was an error building the thread pool or problem creating the resulting
861    ///     ``numpy`` array
862    ///
863    #[pyo3(signature = (parameters, indices, *, threads=None))]
864    fn evaluate_gradient_batch<'py>(
865        &self,
866        py: Python<'py>,
867        parameters: Vec<Float>,
868        indices: Vec<usize>,
869        threads: Option<usize>,
870    ) -> PyResult<Bound<'py, PyArray2<Complex<Float>>>> {
871        #[cfg(feature = "rayon")]
872        {
873            Ok(PyArray2::from_vec2(
874                py,
875                &ThreadPoolBuilder::new()
876                    .num_threads(threads.unwrap_or_else(num_cpus::get))
877                    .build()
878                    .map_err(LadduError::from)?
879                    .install(|| {
880                        self.0
881                            .evaluate_gradient_batch(&parameters, &indices)
882                            .iter()
883                            .map(|grad| grad.data.as_vec().to_vec())
884                            .collect::<Vec<Vec<Complex<Float>>>>()
885                    }),
886            )
887            .map_err(LadduError::NumpyError)?)
888        }
889        #[cfg(not(feature = "rayon"))]
890        {
891            Ok(PyArray2::from_vec2(
892                py,
893                &self
894                    .0
895                    .evaluate_gradient_batch(&parameters, &indices)
896                    .iter()
897                    .map(|grad| grad.data.as_vec().to_vec())
898                    .collect::<Vec<Vec<Complex<Float>>>>(),
899            )
900            .map_err(LadduError::NumpyError)?)
901        }
902    }
903}
904
/// A class, typically used to allow Amplitudes to take either free parameters or constants as
/// inputs
///
/// Instances are created with ``laddu.parameter`` (a free parameter) or ``laddu.constant``
/// (a fixed value).
///
/// See Also
/// --------
/// laddu.parameter
/// laddu.constant
///
#[pyclass(name = "ParameterLike", module = "laddu")]
// `Clone` is required for by-value extraction, e.g. the `re`/`im` arguments of
// `py_test_amplitude`.
#[derive(Clone)]
pub struct PyParameterLike(pub ParameterLike);
916
917/// A free parameter which floats during an optimization
918///
919/// Parameters
920/// ----------
921/// name : str
922///     The name of the free parameter
923///
924/// Returns
925/// -------
926/// laddu.ParameterLike
927///     An object that can be used as the input for many Amplitude constructors
928///
929/// Notes
930/// -----
931/// Two free parameters with the same name are shared in a fit
932///
933#[pyfunction(name = "parameter")]
934pub fn py_parameter(name: &str) -> PyParameterLike {
935    PyParameterLike(parameter(name))
936}
937
938/// A term which stays constant during an optimization
939///
940/// Parameters
941/// ----------
942/// value : float
943///     The numerical value of the constant
944///
945/// Returns
946/// -------
947/// laddu.ParameterLike
948///     An object that can be used as the input for many Amplitude constructors
949///
950#[pyfunction(name = "constant")]
951pub fn py_constant(value: Float) -> PyParameterLike {
952    PyParameterLike(constant(value))
953}
954
955/// A amplitude used only for internal testing which evaluates (p0 + i * p1) * event.p4s[0].e
956#[pyfunction(name = "TestAmplitude")]
957pub fn py_test_amplitude(name: &str, re: PyParameterLike, im: PyParameterLike) -> PyAmplitude {
958    PyAmplitude(TestAmplitude::new(name, re.0, im.0))
959}