// laddu_python/amplitudes.rs

1use crate::data::PyDataset;
2use laddu_core::{
3    amplitudes::{constant, parameter, Evaluator, Expression, ParameterLike, TestAmplitude},
4    f64, LadduError, ReadWrite,
5};
6use num::complex::Complex64;
7use numpy::{PyArray1, PyArray2};
8use pyo3::{
9    exceptions::PyTypeError,
10    prelude::*,
11    types::{PyBytes, PyList},
12};
13#[cfg(feature = "rayon")]
14use rayon::ThreadPoolBuilder;
15use std::collections::HashMap;
16
17/// A mathematical expression formed from amplitudes.
18///
19#[pyclass(name = "Expression", module = "laddu")]
20#[derive(Clone)]
21pub struct PyExpression(pub Expression);
22
23/// A convenience method to sum sequences of Expressions
24///
25#[pyfunction(name = "expr_sum")]
26pub fn py_expr_sum(terms: Vec<Bound<'_, PyAny>>) -> PyResult<PyExpression> {
27    if terms.is_empty() {
28        return Ok(PyExpression(Expression::zero()));
29    }
30    if terms.len() == 1 {
31        let term = &terms[0];
32        if let Ok(expression) = term.extract::<PyExpression>() {
33            return Ok(expression);
34        }
35        return Err(PyTypeError::new_err("Item is not a PyExpression"));
36    }
37    let mut iter = terms.iter();
38    let Some(first_term) = iter.next() else {
39        return Ok(PyExpression(Expression::zero()));
40    };
41    let PyExpression(mut summation) = first_term
42        .extract::<PyExpression>()
43        .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
44    for term in iter {
45        let PyExpression(expr) = term
46            .extract::<PyExpression>()
47            .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
48        summation = summation + expr;
49    }
50    Ok(PyExpression(summation))
51}
52
53/// A convenience method to multiply sequences of Expressions
54///
55#[pyfunction(name = "expr_product")]
56pub fn py_expr_product(terms: Vec<Bound<'_, PyAny>>) -> PyResult<PyExpression> {
57    if terms.is_empty() {
58        return Ok(PyExpression(Expression::one()));
59    }
60    if terms.len() == 1 {
61        let term = &terms[0];
62        if let Ok(expression) = term.extract::<PyExpression>() {
63            return Ok(expression);
64        }
65        return Err(PyTypeError::new_err("Item is not a PyExpression"));
66    }
67    let mut iter = terms.iter();
68    let Some(first_term) = iter.next() else {
69        return Ok(PyExpression(Expression::one()));
70    };
71    let PyExpression(mut product) = first_term
72        .extract::<PyExpression>()
73        .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
74    for term in iter {
75        let PyExpression(expr) = term
76            .extract::<PyExpression>()
77            .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
78        product = product * expr;
79    }
80    Ok(PyExpression(product))
81}
82
83/// A convenience class representing a zero-valued Expression
84///
85#[pyfunction(name = "Zero")]
86pub fn py_expr_zero() -> PyExpression {
87    PyExpression(Expression::zero())
88}
89
90/// A convenience class representing a unit-valued Expression
91///
92#[pyfunction(name = "One")]
93pub fn py_expr_one() -> PyExpression {
94    PyExpression(Expression::one())
95}
96
97#[pymethods]
98impl PyExpression {
99    /// The free parameters used by the Expression
100    ///
101    /// Returns
102    /// -------
103    /// parameters : list of str
104    ///     The list of parameter names
105    #[getter]
106    fn parameters(&self) -> Vec<String> {
107        self.0.parameters()
108    }
109    /// The free parameters used by the Expression
110    #[getter]
111    fn free_parameters(&self) -> Vec<String> {
112        self.0.free_parameters()
113    }
114    /// The fixed parameters used by the Expression
115    #[getter]
116    fn fixed_parameters(&self) -> Vec<String> {
117        self.0.fixed_parameters()
118    }
119    /// Number of free parameters
120    #[getter]
121    fn n_free(&self) -> usize {
122        self.0.n_free()
123    }
124    /// Number of fixed parameters
125    #[getter]
126    fn n_fixed(&self) -> usize {
127        self.0.n_fixed()
128    }
129    /// Total number of parameters
130    #[getter]
131    fn n_parameters(&self) -> usize {
132        self.0.n_parameters()
133    }
134    /// Load an Expression by precalculating each term over the given Dataset
135    ///
136    /// Parameters
137    /// ----------
138    /// dataset : Dataset
139    ///     The Dataset to use in precalculation
140    ///
141    /// Returns
142    /// -------
143    /// Evaluator
144    ///     An object that can be used to evaluate the `expression` over each event in the
145    ///     `dataset`
146    fn load(&self, dataset: &PyDataset) -> PyResult<PyEvaluator> {
147        Ok(PyEvaluator(self.0.load(&dataset.0)?))
148    }
149    /// The real part of a complex Expression
150    fn real(&self) -> PyExpression {
151        PyExpression(self.0.real())
152    }
153    /// The imaginary part of a complex Expression
154    fn imag(&self) -> PyExpression {
155        PyExpression(self.0.imag())
156    }
157    /// The complex conjugate of a complex Expression
158    fn conj(&self) -> PyExpression {
159        PyExpression(self.0.conj())
160    }
161    /// The norm-squared of a complex Expression
162    fn norm_sqr(&self) -> PyExpression {
163        PyExpression(self.0.norm_sqr())
164    }
165    /// Return a new Expression with the given parameter fixed
166    fn fix(&self, name: &str, value: f64) -> PyResult<PyExpression> {
167        Ok(PyExpression(self.0.fix(name, value)?))
168    }
169    /// Return a new Expression with the given parameter freed
170    fn free(&self, name: &str) -> PyResult<PyExpression> {
171        Ok(PyExpression(self.0.free(name)?))
172    }
173    /// Return a new Expression with a single parameter renamed
174    fn rename_parameter(&self, old: &str, new: &str) -> PyResult<PyExpression> {
175        Ok(PyExpression(self.0.rename_parameter(old, new)?))
176    }
177    /// Return a new Expression with several parameters renamed
178    fn rename_parameters(&self, mapping: HashMap<String, String>) -> PyResult<PyExpression> {
179        Ok(PyExpression(self.0.rename_parameters(&mapping)?))
180    }
181    fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
182        if let Ok(other_expr) = other.extract::<PyExpression>() {
183            Ok(PyExpression(self.0.clone() + other_expr.0))
184        } else if let Ok(other_int) = other.extract::<usize>() {
185            if other_int == 0 {
186                Ok(PyExpression(self.0.clone()))
187            } else {
188                Err(PyTypeError::new_err(
189                    "Addition with an integer for this type is only defined for 0",
190                ))
191            }
192        } else {
193            Err(PyTypeError::new_err("Unsupported operand type for +"))
194        }
195    }
196    fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
197        if let Ok(other_expr) = other.extract::<PyExpression>() {
198            Ok(PyExpression(other_expr.0 + self.0.clone()))
199        } else if let Ok(other_int) = other.extract::<usize>() {
200            if other_int == 0 {
201                Ok(PyExpression(self.0.clone()))
202            } else {
203                Err(PyTypeError::new_err(
204                    "Addition with an integer for this type is only defined for 0",
205                ))
206            }
207        } else {
208            Err(PyTypeError::new_err("Unsupported operand type for +"))
209        }
210    }
211    fn __sub__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
212        if let Ok(other_expr) = other.extract::<PyExpression>() {
213            Ok(PyExpression(self.0.clone() - other_expr.0))
214        } else {
215            Err(PyTypeError::new_err("Unsupported operand type for -"))
216        }
217    }
218    fn __rsub__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
219        if let Ok(other_expr) = other.extract::<PyExpression>() {
220            Ok(PyExpression(other_expr.0 - self.0.clone()))
221        } else {
222            Err(PyTypeError::new_err("Unsupported operand type for -"))
223        }
224    }
225    fn __mul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
226        if let Ok(other_expr) = other.extract::<PyExpression>() {
227            Ok(PyExpression(self.0.clone() * other_expr.0))
228        } else {
229            Err(PyTypeError::new_err("Unsupported operand type for *"))
230        }
231    }
232    fn __rmul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
233        if let Ok(other_expr) = other.extract::<PyExpression>() {
234            Ok(PyExpression(other_expr.0 * self.0.clone()))
235        } else {
236            Err(PyTypeError::new_err("Unsupported operand type for *"))
237        }
238    }
239    fn __truediv__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
240        if let Ok(other_expr) = other.extract::<PyExpression>() {
241            Ok(PyExpression(self.0.clone() / other_expr.0))
242        } else {
243            Err(PyTypeError::new_err("Unsupported operand type for /"))
244        }
245    }
246    fn __rtruediv__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
247        if let Ok(other_expr) = other.extract::<PyExpression>() {
248            Ok(PyExpression(other_expr.0 / self.0.clone()))
249        } else {
250            Err(PyTypeError::new_err("Unsupported operand type for /"))
251        }
252    }
253    fn __neg__(&self) -> PyExpression {
254        PyExpression(-self.0.clone())
255    }
256    fn __str__(&self) -> String {
257        format!("{}", self.0)
258    }
259    fn __repr__(&self) -> String {
260        format!("{:?}", self.0)
261    }
262
263    #[new]
264    fn new() -> Self {
265        Self(Expression::create_null())
266    }
267    fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
268        Ok(PyBytes::new(
269            py,
270            bincode::serde::encode_to_vec(&self.0, bincode::config::standard())
271                .map_err(LadduError::EncodeError)?
272                .as_slice(),
273        ))
274    }
275    fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
276        *self = Self(
277            bincode::serde::decode_from_slice(state.as_bytes(), bincode::config::standard())
278                .map_err(LadduError::DecodeError)?
279                .0,
280        );
281        Ok(())
282    }
283}
284
285/// A class which can be used to evaluate a stored Expression
286///
287/// See Also
288/// --------
289/// laddu.Expression.load
290///
291#[pyclass(name = "Evaluator", module = "laddu")]
292#[derive(Clone)]
293pub struct PyEvaluator(pub Evaluator);
294
295#[pymethods]
296impl PyEvaluator {
297    /// The free parameters used by the Evaluator
298    ///
299    /// Returns
300    /// -------
301    /// parameters : list of str
302    ///     The list of parameter names
303    ///
304    #[getter]
305    fn parameters(&self) -> Vec<String> {
306        self.0.parameters()
307    }
308    /// The free parameters used by the Evaluator
309    #[getter]
310    fn free_parameters(&self) -> Vec<String> {
311        self.0.free_parameters()
312    }
313    /// The fixed parameters used by the Evaluator
314    #[getter]
315    fn fixed_parameters(&self) -> Vec<String> {
316        self.0.fixed_parameters()
317    }
318    /// Number of free parameters
319    #[getter]
320    fn n_free(&self) -> usize {
321        self.0.n_free()
322    }
323    /// Number of fixed parameters
324    #[getter]
325    fn n_fixed(&self) -> usize {
326        self.0.n_fixed()
327    }
328    /// Total number of parameters
329    #[getter]
330    fn n_parameters(&self) -> usize {
331        self.0.n_parameters()
332    }
333    /// Return a new Evaluator with the given parameter fixed
334    fn fix(&self, name: &str, value: f64) -> PyResult<PyEvaluator> {
335        Ok(PyEvaluator(self.0.fix(name, value)?))
336    }
337    /// Return a new Evaluator with the given parameter freed
338    fn free(&self, name: &str) -> PyResult<PyEvaluator> {
339        Ok(PyEvaluator(self.0.free(name)?))
340    }
341    /// Return a new Evaluator with a single parameter renamed
342    fn rename_parameter(&self, old: &str, new: &str) -> PyResult<PyEvaluator> {
343        Ok(PyEvaluator(self.0.rename_parameter(old, new)?))
344    }
345    /// Return a new Evaluator with several parameters renamed
346    fn rename_parameters(&self, mapping: HashMap<String, String>) -> PyResult<PyEvaluator> {
347        Ok(PyEvaluator(self.0.rename_parameters(&mapping)?))
348    }
349    /// Activates Amplitudes in the Expression by name
350    ///
351    /// Parameters
352    /// ----------
353    /// arg : str or list of str
354    ///     Names of Amplitudes to be activated
355    ///
356    /// Raises
357    /// ------
358    /// TypeError
359    ///     If `arg` is not a str or list of str
360    /// ValueError
361    ///     If `arg` or any items of `arg` are not registered Amplitudes
362    /// strict : bool, default=True
363    ///     When ``True``, raise an error if any amplitude is missing. When ``False``,
364    ///     silently skip missing amplitudes.
365    #[pyo3(signature = (arg, *, strict=true))]
366    fn activate(&self, arg: &Bound<'_, PyAny>, strict: bool) -> PyResult<()> {
367        if let Ok(string_arg) = arg.extract::<String>() {
368            if strict {
369                self.0.activate_strict(&string_arg)?;
370            } else {
371                self.0.activate(&string_arg);
372            }
373        } else if let Ok(list_arg) = arg.cast::<PyList>() {
374            let vec: Vec<String> = list_arg.extract()?;
375            if strict {
376                self.0.activate_many_strict(&vec)?;
377            } else {
378                self.0.activate_many(&vec);
379            }
380        } else {
381            return Err(PyTypeError::new_err(
382                "Argument must be either a string or a list of strings",
383            ));
384        }
385        Ok(())
386    }
387    /// Activates all Amplitudes in the Expression
388    ///
389    fn activate_all(&self) {
390        self.0.activate_all();
391    }
392    /// Deactivates Amplitudes in the Expression by name
393    ///
394    /// Deactivated Amplitudes act as zeros in the Expression
395    ///
396    /// Parameters
397    /// ----------
398    /// arg : str or list of str
399    ///     Names of Amplitudes to be deactivated
400    ///
401    /// Raises
402    /// ------
403    /// TypeError
404    ///     If `arg` is not a str or list of str
405    /// ValueError
406    ///     If `arg` or any items of `arg` are not registered Amplitudes
407    /// strict : bool, default=True
408    ///     When ``True``, raise an error if any amplitude is missing. When ``False``,
409    ///     silently skip missing amplitudes.
410    #[pyo3(signature = (arg, *, strict=true))]
411    fn deactivate(&self, arg: &Bound<'_, PyAny>, strict: bool) -> PyResult<()> {
412        if let Ok(string_arg) = arg.extract::<String>() {
413            if strict {
414                self.0.deactivate_strict(&string_arg)?;
415            } else {
416                self.0.deactivate(&string_arg);
417            }
418        } else if let Ok(list_arg) = arg.cast::<PyList>() {
419            let vec: Vec<String> = list_arg.extract()?;
420            if strict {
421                self.0.deactivate_many_strict(&vec)?;
422            } else {
423                self.0.deactivate_many(&vec);
424            }
425        } else {
426            return Err(PyTypeError::new_err(
427                "Argument must be either a string or a list of strings",
428            ));
429        }
430        Ok(())
431    }
432    /// Deactivates all Amplitudes in the Expression
433    ///
434    fn deactivate_all(&self) {
435        self.0.deactivate_all();
436    }
437    /// Isolates Amplitudes in the Expression by name
438    ///
439    /// Activates the Amplitudes given in `arg` and deactivates the rest
440    ///
441    /// Parameters
442    /// ----------
443    /// arg : str or list of str
444    ///     Names of Amplitudes to be isolated
445    ///
446    /// Raises
447    /// ------
448    /// TypeError
449    ///     If `arg` is not a str or list of str
450    /// ValueError
451    ///     If `arg` or any items of `arg` are not registered Amplitudes
452    /// strict : bool, default=True
453    ///     When ``True``, raise an error if any amplitude is missing. When ``False``,
454    ///     silently skip missing amplitudes.
455    #[pyo3(signature = (arg, *, strict=true))]
456    fn isolate(&self, arg: &Bound<'_, PyAny>, strict: bool) -> PyResult<()> {
457        if let Ok(string_arg) = arg.extract::<String>() {
458            if strict {
459                self.0.isolate_strict(&string_arg)?;
460            } else {
461                self.0.isolate(&string_arg);
462            }
463        } else if let Ok(list_arg) = arg.cast::<PyList>() {
464            let vec: Vec<String> = list_arg.extract()?;
465            if strict {
466                self.0.isolate_many_strict(&vec)?;
467            } else {
468                self.0.isolate_many(&vec);
469            }
470        } else {
471            return Err(PyTypeError::new_err(
472                "Argument must be either a string or a list of strings",
473            ));
474        }
475        Ok(())
476    }
477    /// Evaluate the stored Expression over the stored Dataset
478    ///
479    /// Parameters
480    /// ----------
481    /// parameters : list of float
482    ///     The values to use for the free parameters
483    /// threads : int, optional
484    ///     The number of threads to use (setting this to None will use all available CPUs)
485    ///
486    /// Returns
487    /// -------
488    /// result : array_like
489    ///     A ``numpy`` array of complex values for each Event in the Dataset
490    ///
491    /// Raises
492    /// ------
493    /// Exception
494    ///     If there was an error building the thread pool
495    ///
496    #[pyo3(signature = (parameters, *, threads=None))]
497    #[cfg_attr(not(feature = "rayon"), allow(unused_variables))]
498    fn evaluate<'py>(
499        &self,
500        py: Python<'py>,
501        parameters: Vec<f64>,
502        threads: Option<usize>,
503    ) -> PyResult<Bound<'py, PyArray1<Complex64>>> {
504        #[cfg(feature = "rayon")]
505        {
506            Ok(PyArray1::from_slice(
507                py,
508                &ThreadPoolBuilder::new()
509                    .num_threads(threads.unwrap_or_else(num_cpus::get))
510                    .build()
511                    .map_err(LadduError::from)?
512                    .install(|| self.0.evaluate(&parameters)),
513            ))
514        }
515        #[cfg(not(feature = "rayon"))]
516        {
517            Ok(PyArray1::from_slice(py, &self.0.evaluate(&parameters)))
518        }
519    }
520    /// Evaluate the stored Expression over a subset of the stored Dataset
521    ///
522    /// Parameters
523    /// ----------
524    /// parameters : list of float
525    ///     The values to use for the free parameters
526    /// indices : list of int
527    ///     The indices of events to evaluate
528    /// threads : int, optional
529    ///     The number of threads to use (setting this to None will use all available CPUs)
530    ///
531    /// Returns
532    /// -------
533    /// result : array_like
534    ///     A ``numpy`` array of complex values for each indexed Event in the Dataset
535    ///
536    /// Raises
537    /// ------
538    /// Exception
539    ///     If there was an error building the thread pool
540    ///
541    #[pyo3(signature = (parameters, indices, *, threads=None))]
542    #[cfg_attr(not(feature = "rayon"), allow(unused_variables))]
543    fn evaluate_batch<'py>(
544        &self,
545        py: Python<'py>,
546        parameters: Vec<f64>,
547        indices: Vec<usize>,
548        threads: Option<usize>,
549    ) -> PyResult<Bound<'py, PyArray1<Complex64>>> {
550        #[cfg(feature = "rayon")]
551        {
552            Ok(PyArray1::from_slice(
553                py,
554                &ThreadPoolBuilder::new()
555                    .num_threads(threads.unwrap_or_else(num_cpus::get))
556                    .build()
557                    .map_err(LadduError::from)?
558                    .install(|| self.0.evaluate_batch(&parameters, &indices)),
559            ))
560        }
561        #[cfg(not(feature = "rayon"))]
562        {
563            Ok(PyArray1::from_slice(
564                py,
565                &self.0.evaluate_batch(&parameters, &indices),
566            ))
567        }
568    }
569    /// Evaluate the gradient of the stored Expression over the stored Dataset
570    ///
571    /// Parameters
572    /// ----------
573    /// parameters : list of float
574    ///     The values to use for the free parameters
575    /// threads : int, optional
576    ///     The number of threads to use (setting this to None will use all available CPUs)
577    ///
578    /// Returns
579    /// -------
580    /// result : array_like
581    ///     A ``numpy`` 2D array of complex values for each Event in the Dataset
582    ///
583    /// Raises
584    /// ------
585    /// Exception
586    ///     If there was an error building the thread pool or problem creating the resulting
587    ///     ``numpy`` array
588    ///
589    #[pyo3(signature = (parameters, *, threads=None))]
590    #[cfg_attr(not(feature = "rayon"), allow(unused_variables))]
591    fn evaluate_gradient<'py>(
592        &self,
593        py: Python<'py>,
594        parameters: Vec<f64>,
595        threads: Option<usize>,
596    ) -> PyResult<Bound<'py, PyArray2<Complex64>>> {
597        #[cfg(feature = "rayon")]
598        {
599            Ok(PyArray2::from_vec2(
600                py,
601                &ThreadPoolBuilder::new()
602                    .num_threads(threads.unwrap_or_else(num_cpus::get))
603                    .build()
604                    .map_err(LadduError::from)?
605                    .install(|| {
606                        self.0
607                            .evaluate_gradient(&parameters)
608                            .iter()
609                            .map(|grad| grad.data.as_vec().to_vec())
610                            .collect::<Vec<Vec<Complex64>>>()
611                    }),
612            )
613            .map_err(LadduError::NumpyError)?)
614        }
615        #[cfg(not(feature = "rayon"))]
616        {
617            Ok(PyArray2::from_vec2(
618                py,
619                &self
620                    .0
621                    .evaluate_gradient(&parameters)
622                    .iter()
623                    .map(|grad| grad.data.as_vec().to_vec())
624                    .collect::<Vec<Vec<Complex64>>>(),
625            )
626            .map_err(LadduError::NumpyError)?)
627        }
628    }
629    /// Evaluate the gradient of the stored Expression over a subset of the stored Dataset
630    ///
631    /// Parameters
632    /// ----------
633    /// parameters : list of float
634    ///     The values to use for the free parameters
635    /// indices : list of int
636    ///     The indices of events to evaluate
637    /// threads : int, optional
638    ///     The number of threads to use (setting this to None will use all available CPUs)
639    ///
640    /// Returns
641    /// -------
642    /// result : array_like
643    ///     A ``numpy`` 2D array of complex values for each indexed Event in the Dataset
644    ///
645    /// Raises
646    /// ------
647    /// Exception
648    ///     If there was an error building the thread pool or problem creating the resulting
649    ///     ``numpy`` array
650    ///
651    #[pyo3(signature = (parameters, indices, *, threads=None))]
652    #[cfg_attr(not(feature = "rayon"), allow(unused_variables))]
653    fn evaluate_gradient_batch<'py>(
654        &self,
655        py: Python<'py>,
656        parameters: Vec<f64>,
657        indices: Vec<usize>,
658        threads: Option<usize>,
659    ) -> PyResult<Bound<'py, PyArray2<Complex64>>> {
660        #[cfg(feature = "rayon")]
661        {
662            Ok(PyArray2::from_vec2(
663                py,
664                &ThreadPoolBuilder::new()
665                    .num_threads(threads.unwrap_or_else(num_cpus::get))
666                    .build()
667                    .map_err(LadduError::from)?
668                    .install(|| {
669                        self.0
670                            .evaluate_gradient_batch(&parameters, &indices)
671                            .iter()
672                            .map(|grad| grad.data.as_vec().to_vec())
673                            .collect::<Vec<Vec<Complex64>>>()
674                    }),
675            )
676            .map_err(LadduError::NumpyError)?)
677        }
678        #[cfg(not(feature = "rayon"))]
679        {
680            Ok(PyArray2::from_vec2(
681                py,
682                &self
683                    .0
684                    .evaluate_gradient_batch(&parameters, &indices)
685                    .iter()
686                    .map(|grad| grad.data.as_vec().to_vec())
687                    .collect::<Vec<Vec<Complex64>>>(),
688            )
689            .map_err(LadduError::NumpyError)?)
690        }
691    }
692}
693
694/// A class, typically used to allow Amplitudes to take either free parameters or constants as
695/// inputs
696///
697/// See Also
698/// --------
699/// laddu.parameter
700/// laddu.constant
701///
702#[pyclass(name = "ParameterLike", module = "laddu")]
703#[derive(Clone)]
704pub struct PyParameterLike(pub ParameterLike);
705
706/// A free parameter which floats during an optimization
707///
708/// Parameters
709/// ----------
710/// name : str
711///     The name of the free parameter
712///
713/// Returns
714/// -------
715/// laddu.ParameterLike
716///     An object that can be used as the input for many Amplitude constructors
717///
718/// Notes
719/// -----
720/// Two free parameters with the same name are shared in a fit
721///
722#[pyfunction(name = "parameter")]
723pub fn py_parameter(name: &str) -> PyParameterLike {
724    PyParameterLike(parameter(name))
725}
726
727/// A term which stays constant during an optimization
728///
729/// Parameters
730/// ----------
731/// name : str
732///     The name of the parameter
733/// value : float
734///     The numerical value of the constant
735///
736/// Returns
737/// -------
738/// laddu.ParameterLike
739///     An object that can be used as the input for many Amplitude constructors
740///
741#[pyfunction(name = "constant")]
742pub fn py_constant(name: &str, value: f64) -> PyParameterLike {
743    PyParameterLike(constant(name, value))
744}
745
746/// An amplitude used only for internal testing which evaluates `(p0 + i * p1) * event.p4s\[0\].e`.
747#[pyfunction(name = "TestAmplitude")]
748pub fn py_test_amplitude(
749    name: &str,
750    re: PyParameterLike,
751    im: PyParameterLike,
752) -> PyResult<PyExpression> {
753    Ok(PyExpression(TestAmplitude::new(name, re.0, im.0)?))
754}