laddu_python/
amplitudes.rs

1use crate::data::PyDataset;
2use laddu_core::{
3    amplitudes::{
4        constant, parameter, Amplitude, AmplitudeID, Evaluator, Expression, Manager, Model,
5        ParameterLike,
6    },
7    traits::ReadWrite,
8    Complex, Float, LadduError,
9};
10use numpy::{PyArray1, PyArray2};
11use pyo3::{
12    exceptions::PyTypeError,
13    prelude::*,
14    types::{PyBytes, PyList},
15};
16#[cfg(feature = "rayon")]
17use rayon::ThreadPoolBuilder;
18
19/// An object which holds a registered ``Amplitude``
20///
21/// See Also
22/// --------
23/// laddu.Manager.register
24///
25#[pyclass(name = "AmplitudeID", module = "laddu")]
26#[derive(Clone)]
27pub struct PyAmplitudeID(AmplitudeID);
28
29/// A mathematical expression formed from AmplitudeIDs
30///
31#[pyclass(name = "Expression", module = "laddu")]
32#[derive(Clone)]
33pub struct PyExpression(Expression);
34
35/// A convenience method to sum sequences of Amplitudes
36///
37#[pyfunction(name = "amplitude_sum")]
38pub fn py_amplitude_sum(terms: Vec<Bound<'_, PyAny>>) -> PyResult<PyExpression> {
39    if terms.is_empty() {
40        return Ok(PyExpression(Expression::Zero));
41    }
42    if terms.len() == 1 {
43        let term = &terms[0];
44        if let Ok(expression) = term.extract::<PyExpression>() {
45            return Ok(expression);
46        }
47        if let Ok(py_amplitude_id) = term.extract::<PyAmplitudeID>() {
48            return Ok(PyExpression(Expression::Amp(py_amplitude_id.0)));
49        }
50        return Err(PyTypeError::new_err(
51            "Item is neither a PyExpression nor a PyAmplitudeID",
52        ));
53    }
54    let mut iter = terms.iter();
55    let Some(first_term) = iter.next() else {
56        return Ok(PyExpression(Expression::Zero));
57    };
58    if let Ok(first_expression) = first_term.extract::<PyExpression>() {
59        let mut summation = first_expression.clone();
60        for term in iter {
61            summation = summation.__add__(term)?;
62        }
63        return Ok(summation);
64    }
65    if let Ok(first_amplitude_id) = first_term.extract::<PyAmplitudeID>() {
66        let mut summation = PyExpression(Expression::Amp(first_amplitude_id.0));
67        for term in iter {
68            summation = summation.__add__(term)?;
69        }
70        return Ok(summation);
71    }
72    Err(PyTypeError::new_err(
73        "Elements must be PyExpression or PyAmplitudeID",
74    ))
75}
76
77/// A convenience method to multiply sequences of Amplitudes
78///
79#[pyfunction(name = "amplitude_product")]
80pub fn py_amplitude_product(terms: Vec<Bound<'_, PyAny>>) -> PyResult<PyExpression> {
81    if terms.is_empty() {
82        return Ok(PyExpression(Expression::One));
83    }
84    if terms.len() == 1 {
85        let term = &terms[0];
86        if let Ok(expression) = term.extract::<PyExpression>() {
87            return Ok(expression);
88        }
89        if let Ok(py_amplitude_id) = term.extract::<PyAmplitudeID>() {
90            return Ok(PyExpression(Expression::Amp(py_amplitude_id.0)));
91        }
92        return Err(PyTypeError::new_err(
93            "Item is neither a PyExpression nor a PyAmplitudeID",
94        ));
95    }
96    let mut iter = terms.iter();
97    let Some(first_term) = iter.next() else {
98        return Ok(PyExpression(Expression::One));
99    };
100    if let Ok(first_expression) = first_term.extract::<PyExpression>() {
101        let mut product = first_expression.clone();
102        for term in iter {
103            product = product.__mul__(term)?;
104        }
105        return Ok(product);
106    }
107    if let Ok(first_amplitude_id) = first_term.extract::<PyAmplitudeID>() {
108        let mut product = PyExpression(Expression::Amp(first_amplitude_id.0));
109        for term in iter {
110            product = product.__mul__(term)?;
111        }
112        return Ok(product);
113    }
114    Err(PyTypeError::new_err(
115        "Elements must be PyExpression or PyAmplitudeID",
116    ))
117}
118
119/// A convenience class representing a zero-valued Expression
120///
121#[pyfunction(name = "AmplitudeZero")]
122pub fn py_amplitude_zero() -> PyExpression {
123    PyExpression(Expression::Zero)
124}
125
126/// A convenience class representing a unit-valued Expression
127///
128#[pyfunction(name = "AmplitudeOne")]
129pub fn py_amplitude_one() -> PyExpression {
130    PyExpression(Expression::One)
131}
132
133#[pymethods]
134impl PyAmplitudeID {
135    /// The real part of a complex Amplitude
136    ///
137    /// Returns
138    /// -------
139    /// Expression
140    ///     The real part of the given Amplitude
141    ///
142    fn real(&self) -> PyExpression {
143        PyExpression(self.0.real())
144    }
145    /// The imaginary part of a complex Amplitude
146    ///
147    /// Returns
148    /// -------
149    /// Expression
150    ///     The imaginary part of the given Amplitude
151    ///
152    fn imag(&self) -> PyExpression {
153        PyExpression(self.0.imag())
154    }
155    /// The complex conjugate of a complex Amplitude
156    ///
157    /// Returns
158    /// -------
159    /// Expression
160    ///     The complex conjugate of the given Amplitude
161    ///
162    fn conj(&self) -> PyExpression {
163        PyExpression(self.0.conj())
164    }
165    /// The norm-squared of a complex Amplitude
166    ///
167    /// This is computed as :math:`AA^*` where :math:`A^*` is the complex conjugate
168    ///
169    /// Returns
170    /// -------
171    /// Expression
172    ///     The norm-squared of the given Amplitude
173    ///
174    fn norm_sqr(&self) -> PyExpression {
175        PyExpression(self.0.norm_sqr())
176    }
177    fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
178        if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
179            Ok(PyExpression(self.0.clone() + other_aid.0.clone()))
180        } else if let Ok(other_expr) = other.extract::<PyExpression>() {
181            Ok(PyExpression(self.0.clone() + other_expr.0.clone()))
182        } else if let Ok(other_int) = other.extract::<usize>() {
183            if other_int == 0 {
184                Ok(PyExpression(Expression::Amp(self.0.clone())))
185            } else {
186                Err(PyTypeError::new_err(
187                    "Addition with an integer for this type is only defined for 0",
188                ))
189            }
190        } else {
191            Err(PyTypeError::new_err("Unsupported operand type for +"))
192        }
193    }
194    fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
195        if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
196            Ok(PyExpression(other_aid.0.clone() + self.0.clone()))
197        } else if let Ok(other_expr) = other.extract::<PyExpression>() {
198            Ok(PyExpression(other_expr.0.clone() + self.0.clone()))
199        } else if let Ok(other_int) = other.extract::<usize>() {
200            if other_int == 0 {
201                Ok(PyExpression(Expression::Amp(self.0.clone())))
202            } else {
203                Err(PyTypeError::new_err(
204                    "Addition with an integer for this type is only defined for 0",
205                ))
206            }
207        } else {
208            Err(PyTypeError::new_err("Unsupported operand type for +"))
209        }
210    }
211    fn __sub__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
212        if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
213            Ok(PyExpression(self.0.clone() - other_aid.0.clone()))
214        } else if let Ok(other_expr) = other.extract::<PyExpression>() {
215            Ok(PyExpression(self.0.clone() - other_expr.0.clone()))
216        } else {
217            Err(PyTypeError::new_err("Unsupported operand type for -"))
218        }
219    }
220    fn __rsub__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
221        if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
222            Ok(PyExpression(other_aid.0.clone() - self.0.clone()))
223        } else if let Ok(other_expr) = other.extract::<PyExpression>() {
224            Ok(PyExpression(other_expr.0.clone() - self.0.clone()))
225        } else {
226            Err(PyTypeError::new_err("Unsupported operand type for -"))
227        }
228    }
229    fn __mul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
230        if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
231            Ok(PyExpression(self.0.clone() * other_aid.0.clone()))
232        } else if let Ok(other_expr) = other.extract::<PyExpression>() {
233            Ok(PyExpression(self.0.clone() * other_expr.0.clone()))
234        } else {
235            Err(PyTypeError::new_err("Unsupported operand type for *"))
236        }
237    }
238    fn __rmul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
239        if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
240            Ok(PyExpression(other_aid.0.clone() * self.0.clone()))
241        } else if let Ok(other_expr) = other.extract::<PyExpression>() {
242            Ok(PyExpression(other_expr.0.clone() * self.0.clone()))
243        } else {
244            Err(PyTypeError::new_err("Unsupported operand type for *"))
245        }
246    }
247    fn __truediv__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
248        if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
249            Ok(PyExpression(self.0.clone() / other_aid.0.clone()))
250        } else if let Ok(other_expr) = other.extract::<PyExpression>() {
251            Ok(PyExpression(self.0.clone() / other_expr.0.clone()))
252        } else {
253            Err(PyTypeError::new_err("Unsupported operand type for /"))
254        }
255    }
256    fn __rtruediv__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
257        if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
258            Ok(PyExpression(other_aid.0.clone() / self.0.clone()))
259        } else if let Ok(other_expr) = other.extract::<PyExpression>() {
260            Ok(PyExpression(other_expr.0.clone() / self.0.clone()))
261        } else {
262            Err(PyTypeError::new_err("Unsupported operand type for /"))
263        }
264    }
265    fn __neg__(&self) -> PyExpression {
266        PyExpression(-self.0.clone())
267    }
268    fn __str__(&self) -> String {
269        format!("{}", self.0)
270    }
271    fn __repr__(&self) -> String {
272        format!("{:?}", self.0)
273    }
274}
275
276#[pymethods]
277impl PyExpression {
278    /// The real part of a complex Expression
279    ///
280    /// Returns
281    /// -------
282    /// Expression
283    ///     The real part of the given Expression
284    ///
285    fn real(&self) -> PyExpression {
286        PyExpression(self.0.real())
287    }
288    /// The imaginary part of a complex Expression
289    ///
290    /// Returns
291    /// -------
292    /// Expression
293    ///     The imaginary part of the given Expression
294    ///
295    fn imag(&self) -> PyExpression {
296        PyExpression(self.0.imag())
297    }
298    /// The complex conjugate of a complex Expression
299    ///
300    /// Returns
301    /// -------
302    /// Expression
303    ///     The complex conjugate of the given Expression
304    ///
305    fn conj(&self) -> PyExpression {
306        PyExpression(self.0.conj())
307    }
308    /// The norm-squared of a complex Expression
309    ///
310    /// This is computed as :math:`AA^*` where :math:`A^*` is the complex conjugate
311    ///
312    /// Returns
313    /// -------
314    /// Expression
315    ///     The norm-squared of the given Expression
316    ///
317    fn norm_sqr(&self) -> PyExpression {
318        PyExpression(self.0.norm_sqr())
319    }
320    fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
321        if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
322            Ok(PyExpression(self.0.clone() + other_aid.0.clone()))
323        } else if let Ok(other_expr) = other.extract::<PyExpression>() {
324            Ok(PyExpression(self.0.clone() + other_expr.0.clone()))
325        } else if let Ok(other_int) = other.extract::<usize>() {
326            if other_int == 0 {
327                Ok(PyExpression(self.0.clone()))
328            } else {
329                Err(PyTypeError::new_err(
330                    "Addition with an integer for this type is only defined for 0",
331                ))
332            }
333        } else {
334            Err(PyTypeError::new_err("Unsupported operand type for +"))
335        }
336    }
337    fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
338        if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
339            Ok(PyExpression(other_aid.0.clone() + self.0.clone()))
340        } else if let Ok(other_expr) = other.extract::<PyExpression>() {
341            Ok(PyExpression(other_expr.0.clone() + self.0.clone()))
342        } else if let Ok(other_int) = other.extract::<usize>() {
343            if other_int == 0 {
344                Ok(PyExpression(self.0.clone()))
345            } else {
346                Err(PyTypeError::new_err(
347                    "Addition with an integer for this type is only defined for 0",
348                ))
349            }
350        } else {
351            Err(PyTypeError::new_err("Unsupported operand type for +"))
352        }
353    }
354    fn __sub__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
355        if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
356            Ok(PyExpression(self.0.clone() - other_aid.0.clone()))
357        } else if let Ok(other_expr) = other.extract::<PyExpression>() {
358            Ok(PyExpression(self.0.clone() - other_expr.0.clone()))
359        } else {
360            Err(PyTypeError::new_err("Unsupported operand type for -"))
361        }
362    }
363    fn __rsub__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
364        if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
365            Ok(PyExpression(other_aid.0.clone() - self.0.clone()))
366        } else if let Ok(other_expr) = other.extract::<PyExpression>() {
367            Ok(PyExpression(other_expr.0.clone() - self.0.clone()))
368        } else {
369            Err(PyTypeError::new_err("Unsupported operand type for -"))
370        }
371    }
372    fn __mul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
373        if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
374            Ok(PyExpression(self.0.clone() * other_aid.0.clone()))
375        } else if let Ok(other_expr) = other.extract::<PyExpression>() {
376            Ok(PyExpression(self.0.clone() * other_expr.0.clone()))
377        } else {
378            Err(PyTypeError::new_err("Unsupported operand type for *"))
379        }
380    }
381    fn __rmul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
382        if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
383            Ok(PyExpression(other_aid.0.clone() * self.0.clone()))
384        } else if let Ok(other_expr) = other.extract::<PyExpression>() {
385            Ok(PyExpression(other_expr.0.clone() * self.0.clone()))
386        } else {
387            Err(PyTypeError::new_err("Unsupported operand type for *"))
388        }
389    }
390    fn __truediv__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
391        if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
392            Ok(PyExpression(self.0.clone() / other_aid.0.clone()))
393        } else if let Ok(other_expr) = other.extract::<PyExpression>() {
394            Ok(PyExpression(self.0.clone() / other_expr.0.clone()))
395        } else {
396            Err(PyTypeError::new_err("Unsupported operand type for /"))
397        }
398    }
399    fn __rtruediv__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
400        if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
401            Ok(PyExpression(other_aid.0.clone() / self.0.clone()))
402        } else if let Ok(other_expr) = other.extract::<PyExpression>() {
403            Ok(PyExpression(other_expr.0.clone() / self.0.clone()))
404        } else {
405            Err(PyTypeError::new_err("Unsupported operand type for /"))
406        }
407    }
408    fn __neg__(&self) -> PyExpression {
409        PyExpression(-self.0.clone())
410    }
411    fn __str__(&self) -> String {
412        format!("{}", self.0)
413    }
414    fn __repr__(&self) -> String {
415        format!("{:?}", self.0)
416    }
417}
418
419/// A class which can be used to register Amplitudes and store precalculated data
420///
421#[pyclass(name = "Manager", module = "laddu")]
422pub struct PyManager(Manager);
423
424#[pymethods]
425impl PyManager {
426    #[new]
427    fn new() -> Self {
428        Self(Manager::default())
429    }
430    /// The free parameters used by the Manager
431    ///
432    /// Returns
433    /// -------
434    /// parameters : list of str
435    ///     The list of parameter names
436    ///
437    #[getter]
438    fn parameters(&self) -> Vec<String> {
439        self.0.parameters()
440    }
441    /// Register an Amplitude with the Manager
442    ///
443    /// Parameters
444    /// ----------
445    /// amplitude : Amplitude
446    ///     The Amplitude to register
447    ///
448    /// Returns
449    /// -------
450    /// AmplitudeID
451    ///     A reference to the registered `amplitude` that can be used to form complex
452    ///     Expressions
453    ///
454    /// Raises
455    /// ------
456    /// ValueError
457    ///     If the name of the ``amplitude`` has already been registered
458    ///
459    fn register(&mut self, amplitude: &PyAmplitude) -> PyResult<PyAmplitudeID> {
460        Ok(PyAmplitudeID(self.0.register(amplitude.0.clone())?))
461    }
462    /// Generate a Model from the given expression made of registered Amplitudes
463    ///
464    /// Parameters
465    /// ----------
466    /// expression : Expression or AmplitudeID
467    ///     The expression to use in precalculation
468    ///
469    /// Returns
470    /// -------
471    /// Model
472    ///     An object which represents the underlying mathematical model and can be loaded with
473    ///     a Dataset
474    ///
475    /// Raises
476    /// ------
477    /// TypeError
478    ///     If the expression is not convertable to a Model
479    ///
480    /// Notes
481    /// -----
482    /// While the given `expression` will be the one evaluated in the end, all registered
483    /// Amplitudes will be loaded, and all of their parameters will be included in the final
484    /// expression. These parameters will have no effect on evaluation, but they must be
485    /// included in function calls.
486    ///
487    fn model(&self, expression: &Bound<'_, PyAny>) -> PyResult<PyModel> {
488        let expression = if let Ok(expression) = expression.extract::<PyExpression>() {
489            Ok(expression.0)
490        } else if let Ok(aid) = expression.extract::<PyAmplitudeID>() {
491            Ok(Expression::Amp(aid.0))
492        } else {
493            Err(PyTypeError::new_err(
494                "'expression' must either by an Expression or AmplitudeID",
495            ))
496        }?;
497        Ok(PyModel(self.0.model(&expression)))
498    }
499}
500
501/// A class which represents a model composed of registered Amplitudes
502///
503#[pyclass(name = "Model", module = "laddu")]
504pub struct PyModel(pub Model);
505
506#[pymethods]
507impl PyModel {
508    /// The free parameters used by the Manager
509    ///
510    /// Returns
511    /// -------
512    /// parameters : list of str
513    ///     The list of parameter names
514    ///
515    #[getter]
516    fn parameters(&self) -> Vec<String> {
517        self.0.parameters()
518    }
519    /// Load a Model by precalculating each term over the given Dataset
520    ///
521    /// Parameters
522    /// ----------
523    /// dataset : Dataset
524    ///     The Dataset to use in precalculation
525    ///
526    /// Returns
527    /// -------
528    /// Evaluator
529    ///     An object that can be used to evaluate the `expression` over each event in the
530    ///     `dataset`
531    ///
532    /// Notes
533    /// -----
534    /// While the given `expression` will be the one evaluated in the end, all registered
535    /// Amplitudes will be loaded, and all of their parameters will be included in the final
536    /// expression. These parameters will have no effect on evaluation, but they must be
537    /// included in function calls.
538    ///
539    fn load(&self, dataset: &PyDataset) -> PyEvaluator {
540        PyEvaluator(self.0.load(&dataset.0))
541    }
542    /// Save the Model to a file
543    ///
544    /// Parameters
545    /// ----------
546    /// path : str
547    ///     The path of the new file (overwrites if the file exists!)
548    ///
549    /// Raises
550    /// ------
551    /// IOError
552    ///     If anything fails when trying to write the file
553    ///
554    fn save_as(&self, path: &str) -> PyResult<()> {
555        self.0.save_as(path)?;
556        Ok(())
557    }
558    /// Load a Model from a file
559    ///
560    /// Parameters
561    /// ----------
562    /// path : str
563    ///     The path of the existing fit file
564    ///
565    /// Returns
566    /// -------
567    /// Model
568    ///     The model contained in the file
569    ///
570    /// Raises
571    /// ------
572    /// IOError
573    ///     If anything fails when trying to read the file
574    ///
575    #[staticmethod]
576    fn load_from(path: &str) -> PyResult<Self> {
577        Ok(PyModel(Model::load_from(path)?))
578    }
579    #[new]
580    fn new() -> Self {
581        PyModel(Model::create_null())
582    }
583    fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
584        Ok(PyBytes::new(
585            py,
586            bincode::serde::encode_to_vec(&self.0, bincode::config::standard())
587                .map_err(LadduError::EncodeError)?
588                .as_slice(),
589        ))
590    }
591    fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
592        *self = PyModel(
593            bincode::serde::decode_from_slice(state.as_bytes(), bincode::config::standard())
594                .map_err(LadduError::DecodeError)?
595                .0,
596        );
597        Ok(())
598    }
599}
600
601/// An Amplitude which can be registered by a Manager
602///
603/// See Also
604/// --------
605/// laddu.Manager
606///
607#[pyclass(name = "Amplitude", module = "laddu")]
608pub struct PyAmplitude(pub Box<dyn Amplitude>);
609
610/// A class which can be used to evaluate a stored Expression
611///
612/// See Also
613/// --------
614/// laddu.Manager.load
615///
616#[pyclass(name = "Evaluator", module = "laddu")]
617#[derive(Clone)]
618pub struct PyEvaluator(pub Evaluator);
619
620#[pymethods]
621impl PyEvaluator {
622    /// The free parameters used by the Evaluator
623    ///
624    /// Returns
625    /// -------
626    /// parameters : list of str
627    ///     The list of parameter names
628    ///
629    #[getter]
630    fn parameters(&self) -> Vec<String> {
631        self.0.parameters()
632    }
633    /// Activates Amplitudes in the Expression by name
634    ///
635    /// Parameters
636    /// ----------
637    /// arg : str or list of str
638    ///     Names of Amplitudes to be activated
639    ///
640    /// Raises
641    /// ------
642    /// TypeError
643    ///     If `arg` is not a str or list of str
644    /// ValueError
645    ///     If `arg` or any items of `arg` are not registered Amplitudes
646    ///
647    fn activate(&self, arg: &Bound<'_, PyAny>) -> PyResult<()> {
648        if let Ok(string_arg) = arg.extract::<String>() {
649            self.0.activate(&string_arg)?;
650        } else if let Ok(list_arg) = arg.downcast::<PyList>() {
651            let vec: Vec<String> = list_arg.extract()?;
652            self.0.activate_many(&vec)?;
653        } else {
654            return Err(PyTypeError::new_err(
655                "Argument must be either a string or a list of strings",
656            ));
657        }
658        Ok(())
659    }
660    /// Activates all Amplitudes in the Expression
661    ///
662    fn activate_all(&self) {
663        self.0.activate_all();
664    }
665    /// Deactivates Amplitudes in the Expression by name
666    ///
667    /// Deactivated Amplitudes act as zeros in the Expression
668    ///
669    /// Parameters
670    /// ----------
671    /// arg : str or list of str
672    ///     Names of Amplitudes to be deactivated
673    ///
674    /// Raises
675    /// ------
676    /// TypeError
677    ///     If `arg` is not a str or list of str
678    /// ValueError
679    ///     If `arg` or any items of `arg` are not registered Amplitudes
680    ///
681    fn deactivate(&self, arg: &Bound<'_, PyAny>) -> PyResult<()> {
682        if let Ok(string_arg) = arg.extract::<String>() {
683            self.0.deactivate(&string_arg)?;
684        } else if let Ok(list_arg) = arg.downcast::<PyList>() {
685            let vec: Vec<String> = list_arg.extract()?;
686            self.0.deactivate_many(&vec)?;
687        } else {
688            return Err(PyTypeError::new_err(
689                "Argument must be either a string or a list of strings",
690            ));
691        }
692        Ok(())
693    }
694    /// Deactivates all Amplitudes in the Expression
695    ///
696    fn deactivate_all(&self) {
697        self.0.deactivate_all();
698    }
699    /// Isolates Amplitudes in the Expression by name
700    ///
701    /// Activates the Amplitudes given in `arg` and deactivates the rest
702    ///
703    /// Parameters
704    /// ----------
705    /// arg : str or list of str
706    ///     Names of Amplitudes to be isolated
707    ///
708    /// Raises
709    /// ------
710    /// TypeError
711    ///     If `arg` is not a str or list of str
712    /// ValueError
713    ///     If `arg` or any items of `arg` are not registered Amplitudes
714    ///
715    fn isolate(&self, arg: &Bound<'_, PyAny>) -> PyResult<()> {
716        if let Ok(string_arg) = arg.extract::<String>() {
717            self.0.isolate(&string_arg)?;
718        } else if let Ok(list_arg) = arg.downcast::<PyList>() {
719            let vec: Vec<String> = list_arg.extract()?;
720            self.0.isolate_many(&vec)?;
721        } else {
722            return Err(PyTypeError::new_err(
723                "Argument must be either a string or a list of strings",
724            ));
725        }
726        Ok(())
727    }
728    /// Evaluate the stored Expression over the stored Dataset
729    ///
730    /// Parameters
731    /// ----------
732    /// parameters : list of float
733    ///     The values to use for the free parameters
734    /// threads : int, optional
735    ///     The number of threads to use (setting this to None will use all available CPUs)
736    ///
737    /// Returns
738    /// -------
739    /// result : array_like
740    ///     A ``numpy`` array of complex values for each Event in the Dataset
741    ///
742    /// Raises
743    /// ------
744    /// Exception
745    ///     If there was an error building the thread pool
746    ///
747    #[pyo3(signature = (parameters, *, threads=None))]
748    fn evaluate<'py>(
749        &self,
750        py: Python<'py>,
751        parameters: Vec<Float>,
752        threads: Option<usize>,
753    ) -> PyResult<Bound<'py, PyArray1<Complex<Float>>>> {
754        #[cfg(feature = "rayon")]
755        {
756            Ok(PyArray1::from_slice(
757                py,
758                &ThreadPoolBuilder::new()
759                    .num_threads(threads.unwrap_or_else(num_cpus::get))
760                    .build()
761                    .map_err(LadduError::from)?
762                    .install(|| self.0.evaluate(&parameters)),
763            ))
764        }
765        #[cfg(not(feature = "rayon"))]
766        {
767            Ok(PyArray1::from_slice(py, &self.0.evaluate(&parameters)))
768        }
769    }
770    /// Evaluate the gradient of the stored Expression over the stored Dataset
771    ///
772    /// Parameters
773    /// ----------
774    /// parameters : list of float
775    ///     The values to use for the free parameters
776    /// threads : int, optional
777    ///     The number of threads to use (setting this to None will use all available CPUs)
778    ///
779    /// Returns
780    /// -------
781    /// result : array_like
782    ///     A ``numpy`` 2D array of complex values for each Event in the Dataset
783    ///
784    /// Raises
785    /// ------
786    /// Exception
787    ///     If there was an error building the thread pool or problem creating the resulting
788    ///     ``numpy`` array
789    ///
790    #[pyo3(signature = (parameters, *, threads=None))]
791    fn evaluate_gradient<'py>(
792        &self,
793        py: Python<'py>,
794        parameters: Vec<Float>,
795        threads: Option<usize>,
796    ) -> PyResult<Bound<'py, PyArray2<Complex<Float>>>> {
797        #[cfg(feature = "rayon")]
798        {
799            Ok(PyArray2::from_vec2(
800                py,
801                &ThreadPoolBuilder::new()
802                    .num_threads(threads.unwrap_or_else(num_cpus::get))
803                    .build()
804                    .map_err(LadduError::from)?
805                    .install(|| {
806                        self.0
807                            .evaluate_gradient(&parameters)
808                            .iter()
809                            .map(|grad| grad.data.as_vec().to_vec())
810                            .collect::<Vec<Vec<Complex<Float>>>>()
811                    }),
812            )
813            .map_err(LadduError::NumpyError)?)
814        }
815        #[cfg(not(feature = "rayon"))]
816        {
817            Ok(PyArray2::from_vec2(
818                py,
819                &self
820                    .0
821                    .evaluate_gradient(&parameters)
822                    .iter()
823                    .map(|grad| grad.data.as_vec().to_vec())
824                    .collect::<Vec<Vec<Complex<Float>>>>(),
825            )
826            .map_err(LadduError::NumpyError)?)
827        }
828    }
829}
830
831/// A class, typically used to allow Amplitudes to take either free parameters or constants as
832/// inputs
833///
834/// See Also
835/// --------
836/// laddu.parameter
837/// laddu.constant
838///
839#[pyclass(name = "ParameterLike", module = "laddu")]
840#[derive(Clone)]
841pub struct PyParameterLike(pub ParameterLike);
842
843/// A free parameter which floats during an optimization
844///
845/// Parameters
846/// ----------
847/// name : str
848///     The name of the free parameter
849///
850/// Returns
851/// -------
852/// laddu.ParameterLike
853///     An object that can be used as the input for many Amplitude constructors
854///
855/// Notes
856/// -----
857/// Two free parameters with the same name are shared in a fit
858///
859#[pyfunction(name = "parameter")]
860pub fn py_parameter(name: &str) -> PyParameterLike {
861    PyParameterLike(parameter(name))
862}
863
864/// A term which stays constant during an optimization
865///
866/// Parameters
867/// ----------
868/// value : float
869///     The numerical value of the constant
870///
871/// Returns
872/// -------
873/// laddu.ParameterLike
874///     An object that can be used as the input for many Amplitude constructors
875///
876#[pyfunction(name = "constant")]
877pub fn py_constant(value: Float) -> PyParameterLike {
878    PyParameterLike(constant(value))
879}