laddu_python/
amplitudes.rs

1use crate::data::PyDataset;
2use laddu_core::{
3    amplitudes::{
4        constant, parameter, Amplitude, AmplitudeID, Evaluator, Expression, Manager, Model,
5        ParameterLike,
6    },
7    traits::ReadWrite,
8    Complex, Float, LadduError,
9};
10use numpy::{PyArray1, PyArray2};
11use pyo3::{
12    exceptions::PyTypeError,
13    prelude::*,
14    types::{PyBytes, PyList},
15};
16#[cfg(feature = "rayon")]
17use rayon::ThreadPoolBuilder;
18
/// An object which holds a reference to an ``Amplitude`` registered with a ``Manager``
///
/// Instances are returned by ``Manager.register``. They can be combined with other
/// AmplitudeIDs or Expressions using ``+`` and ``*``, or transformed with ``real``, ``imag``
/// and ``norm_sqr``; each of these operations produces an ``Expression``.
///
/// See Also
/// --------
/// laddu.Manager.register
///
#[pyclass(name = "AmplitudeID", module = "laddu")]
#[derive(Clone)]
pub struct PyAmplitudeID(AmplitudeID);
28
/// A mathematical expression formed from AmplitudeIDs
///
/// Expressions support ``+`` and ``*`` with other Expressions and AmplitudeIDs, and can be
/// turned into a ``Model`` with ``Manager.model``.
///
#[pyclass(name = "Expression", module = "laddu")]
#[derive(Clone)]
pub struct PyExpression(Expression);
34
35/// A convenience method to sum sequences of Amplitudes
36///
37#[pyfunction(name = "amplitude_sum")]
38pub fn py_amplitude_sum(terms: Vec<Bound<'_, PyAny>>) -> PyResult<PyExpression> {
39    if terms.is_empty() {
40        return Ok(PyExpression(Expression::Zero));
41    }
42    if terms.len() == 1 {
43        let term = &terms[0];
44        if let Ok(expression) = term.extract::<PyExpression>() {
45            return Ok(expression);
46        }
47        if let Ok(py_amplitude_id) = term.extract::<PyAmplitudeID>() {
48            return Ok(PyExpression(Expression::Amp(py_amplitude_id.0)));
49        }
50        return Err(PyTypeError::new_err(
51            "Item is neither a PyExpression nor a PyAmplitudeID",
52        ));
53    }
54    let mut iter = terms.iter();
55    let Some(first_term) = iter.next() else {
56        return Ok(PyExpression(Expression::Zero));
57    };
58    if let Ok(first_expression) = first_term.extract::<PyExpression>() {
59        let mut summation = first_expression.clone();
60        for term in iter {
61            summation = summation.__add__(term)?;
62        }
63        return Ok(summation);
64    }
65    if let Ok(first_amplitude_id) = first_term.extract::<PyAmplitudeID>() {
66        let mut summation = PyExpression(Expression::Amp(first_amplitude_id.0));
67        for term in iter {
68            summation = summation.__add__(term)?;
69        }
70        return Ok(summation);
71    }
72    Err(PyTypeError::new_err(
73        "Elements must be PyExpression or PyAmplitudeID",
74    ))
75}
76
77/// A convenience method to multiply sequences of Amplitudes
78///
79#[pyfunction(name = "amplitude_product")]
80pub fn py_amplitude_product(terms: Vec<Bound<'_, PyAny>>) -> PyResult<PyExpression> {
81    if terms.is_empty() {
82        return Ok(PyExpression(Expression::One));
83    }
84    if terms.len() == 1 {
85        let term = &terms[0];
86        if let Ok(expression) = term.extract::<PyExpression>() {
87            return Ok(expression);
88        }
89        if let Ok(py_amplitude_id) = term.extract::<PyAmplitudeID>() {
90            return Ok(PyExpression(Expression::Amp(py_amplitude_id.0)));
91        }
92        return Err(PyTypeError::new_err(
93            "Item is neither a PyExpression nor a PyAmplitudeID",
94        ));
95    }
96    let mut iter = terms.iter();
97    let Some(first_term) = iter.next() else {
98        return Ok(PyExpression(Expression::One));
99    };
100    if let Ok(first_expression) = first_term.extract::<PyExpression>() {
101        let mut product = first_expression.clone();
102        for term in iter {
103            product = product.__mul__(term)?;
104        }
105        return Ok(product);
106    }
107    if let Ok(first_amplitude_id) = first_term.extract::<PyAmplitudeID>() {
108        let mut product = PyExpression(Expression::Amp(first_amplitude_id.0));
109        for term in iter {
110            product = product.__mul__(term)?;
111        }
112        return Ok(product);
113    }
114    Err(PyTypeError::new_err(
115        "Elements must be PyExpression or PyAmplitudeID",
116    ))
117}
118
119/// A convenience class representing a zero-valued Expression
120///
121#[pyfunction(name = "AmplitudeZero")]
122pub fn py_amplitude_zero() -> PyExpression {
123    PyExpression(Expression::Zero)
124}
125
126/// A convenience class representing a unit-valued Expression
127///
128#[pyfunction(name = "AmplitudeOne")]
129pub fn py_amplitude_one() -> PyExpression {
130    PyExpression(Expression::One)
131}
132
133#[pymethods]
134impl PyAmplitudeID {
135    /// The real part of a complex Amplitude
136    ///
137    /// Returns
138    /// -------
139    /// Expression
140    ///     The real part of the given Amplitude
141    ///
142    fn real(&self) -> PyExpression {
143        PyExpression(self.0.real())
144    }
145    /// The imaginary part of a complex Amplitude
146    ///
147    /// Returns
148    /// -------
149    /// Expression
150    ///     The imaginary part of the given Amplitude
151    ///
152    fn imag(&self) -> PyExpression {
153        PyExpression(self.0.imag())
154    }
155    /// The norm-squared of a complex Amplitude
156    ///
157    /// This is computed as :math:`AA^*` where :math:`A^*` is the complex conjugate
158    ///
159    /// Returns
160    /// -------
161    /// Expression
162    ///     The norm-squared of the given Amplitude
163    ///
164    fn norm_sqr(&self) -> PyExpression {
165        PyExpression(self.0.norm_sqr())
166    }
167    fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
168        if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
169            Ok(PyExpression(self.0.clone() + other_aid.0.clone()))
170        } else if let Ok(other_expr) = other.extract::<PyExpression>() {
171            Ok(PyExpression(self.0.clone() + other_expr.0.clone()))
172        } else if let Ok(other_int) = other.extract::<usize>() {
173            if other_int == 0 {
174                Ok(PyExpression(Expression::Amp(self.0.clone())))
175            } else {
176                Err(PyTypeError::new_err(
177                    "Addition with an integer for this type is only defined for 0",
178                ))
179            }
180        } else {
181            Err(PyTypeError::new_err("Unsupported operand type for +"))
182        }
183    }
184    fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
185        if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
186            Ok(PyExpression(other_aid.0.clone() + self.0.clone()))
187        } else if let Ok(other_expr) = other.extract::<PyExpression>() {
188            Ok(PyExpression(other_expr.0.clone() + self.0.clone()))
189        } else if let Ok(other_int) = other.extract::<usize>() {
190            if other_int == 0 {
191                Ok(PyExpression(Expression::Amp(self.0.clone())))
192            } else {
193                Err(PyTypeError::new_err(
194                    "Addition with an integer for this type is only defined for 0",
195                ))
196            }
197        } else {
198            Err(PyTypeError::new_err("Unsupported operand type for +"))
199        }
200    }
201    fn __mul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
202        if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
203            Ok(PyExpression(self.0.clone() * other_aid.0.clone()))
204        } else if let Ok(other_expr) = other.extract::<PyExpression>() {
205            Ok(PyExpression(self.0.clone() * other_expr.0.clone()))
206        } else {
207            Err(PyTypeError::new_err("Unsupported operand type for *"))
208        }
209    }
210    fn __rmul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
211        if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
212            Ok(PyExpression(other_aid.0.clone() * self.0.clone()))
213        } else if let Ok(other_expr) = other.extract::<PyExpression>() {
214            Ok(PyExpression(other_expr.0.clone() * self.0.clone()))
215        } else {
216            Err(PyTypeError::new_err("Unsupported operand type for *"))
217        }
218    }
219    fn __str__(&self) -> String {
220        format!("{}", self.0)
221    }
222    fn __repr__(&self) -> String {
223        format!("{:?}", self.0)
224    }
225}
226
227#[pymethods]
228impl PyExpression {
229    /// The real part of a complex Expression
230    ///
231    /// Returns
232    /// -------
233    /// Expression
234    ///     The real part of the given Expression
235    ///
236    fn real(&self) -> PyExpression {
237        PyExpression(self.0.real())
238    }
239    /// The imaginary part of a complex Expression
240    ///
241    /// Returns
242    /// -------
243    /// Expression
244    ///     The imaginary part of the given Expression
245    ///
246    fn imag(&self) -> PyExpression {
247        PyExpression(self.0.imag())
248    }
249    /// The norm-squared of a complex Expression
250    ///
251    /// This is computed as :math:`AA^*` where :math:`A^*` is the complex conjugate
252    ///
253    /// Returns
254    /// -------
255    /// Expression
256    ///     The norm-squared of the given Expression
257    ///
258    fn norm_sqr(&self) -> PyExpression {
259        PyExpression(self.0.norm_sqr())
260    }
261    fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
262        if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
263            Ok(PyExpression(self.0.clone() + other_aid.0.clone()))
264        } else if let Ok(other_expr) = other.extract::<PyExpression>() {
265            Ok(PyExpression(self.0.clone() + other_expr.0.clone()))
266        } else if let Ok(other_int) = other.extract::<usize>() {
267            if other_int == 0 {
268                Ok(PyExpression(self.0.clone()))
269            } else {
270                Err(PyTypeError::new_err(
271                    "Addition with an integer for this type is only defined for 0",
272                ))
273            }
274        } else {
275            Err(PyTypeError::new_err("Unsupported operand type for +"))
276        }
277    }
278    fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
279        if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
280            Ok(PyExpression(other_aid.0.clone() + self.0.clone()))
281        } else if let Ok(other_expr) = other.extract::<PyExpression>() {
282            Ok(PyExpression(other_expr.0.clone() + self.0.clone()))
283        } else if let Ok(other_int) = other.extract::<usize>() {
284            if other_int == 0 {
285                Ok(PyExpression(self.0.clone()))
286            } else {
287                Err(PyTypeError::new_err(
288                    "Addition with an integer for this type is only defined for 0",
289                ))
290            }
291        } else {
292            Err(PyTypeError::new_err("Unsupported operand type for +"))
293        }
294    }
295    fn __mul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
296        if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
297            Ok(PyExpression(self.0.clone() * other_aid.0.clone()))
298        } else if let Ok(other_expr) = other.extract::<PyExpression>() {
299            Ok(PyExpression(self.0.clone() * other_expr.0.clone()))
300        } else {
301            Err(PyTypeError::new_err("Unsupported operand type for *"))
302        }
303    }
304    fn __rmul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
305        if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
306            Ok(PyExpression(other_aid.0.clone() * self.0.clone()))
307        } else if let Ok(other_expr) = other.extract::<PyExpression>() {
308            Ok(PyExpression(other_expr.0.clone() * self.0.clone()))
309        } else {
310            Err(PyTypeError::new_err("Unsupported operand type for *"))
311        }
312    }
313    fn __str__(&self) -> String {
314        format!("{}", self.0)
315    }
316    fn __repr__(&self) -> String {
317        format!("{:?}", self.0)
318    }
319}
320
/// A class which registers Amplitudes and builds Models from Expressions of them
///
/// Amplitudes are added with ``register`` (which returns an ``AmplitudeID``), and
/// ``model`` turns an Expression of registered Amplitudes into a ``Model``.
///
#[pyclass(name = "Manager", module = "laddu")]
pub struct PyManager(Manager);
325
326#[pymethods]
327impl PyManager {
328    #[new]
329    fn new() -> Self {
330        Self(Manager::default())
331    }
332    /// The free parameters used by the Manager
333    ///
334    /// Returns
335    /// -------
336    /// parameters : list of str
337    ///     The list of parameter names
338    ///
339    #[getter]
340    fn parameters(&self) -> Vec<String> {
341        self.0.parameters()
342    }
343    /// Register an Amplitude with the Manager
344    ///
345    /// Parameters
346    /// ----------
347    /// amplitude : Amplitude
348    ///     The Amplitude to register
349    ///
350    /// Returns
351    /// -------
352    /// AmplitudeID
353    ///     A reference to the registered `amplitude` that can be used to form complex
354    ///     Expressions
355    ///
356    /// Raises
357    /// ------
358    /// ValueError
359    ///     If the name of the ``amplitude`` has already been registered
360    ///
361    fn register(&mut self, amplitude: &PyAmplitude) -> PyResult<PyAmplitudeID> {
362        Ok(PyAmplitudeID(self.0.register(amplitude.0.clone())?))
363    }
364    /// Generate a Model from the given expression made of registered Amplitudes
365    ///
366    /// Parameters
367    /// ----------
368    /// expression : Expression or AmplitudeID
369    ///     The expression to use in precalculation
370    ///
371    /// Returns
372    /// -------
373    /// Model
374    ///     An object which represents the underlying mathematical model and can be loaded with
375    ///     a Dataset
376    ///
377    /// Raises
378    /// ------
379    /// TypeError
380    ///     If the expression is not convertable to a Model
381    ///
382    /// Notes
383    /// -----
384    /// While the given `expression` will be the one evaluated in the end, all registered
385    /// Amplitudes will be loaded, and all of their parameters will be included in the final
386    /// expression. These parameters will have no effect on evaluation, but they must be
387    /// included in function calls.
388    ///
389    fn model(&self, expression: &Bound<'_, PyAny>) -> PyResult<PyModel> {
390        let expression = if let Ok(expression) = expression.extract::<PyExpression>() {
391            Ok(expression.0)
392        } else if let Ok(aid) = expression.extract::<PyAmplitudeID>() {
393            Ok(Expression::Amp(aid.0))
394        } else {
395            Err(PyTypeError::new_err(
396                "'expression' must either by an Expression or AmplitudeID",
397            ))
398        }?;
399        Ok(PyModel(self.0.model(&expression)))
400    }
401}
402
/// A class which represents a model composed of registered Amplitudes
///
/// Models are created by ``Manager.model``, can be saved to and loaded from files, and
/// support pickling. ``Model.load`` binds a Model to a Dataset and returns an ``Evaluator``.
///
#[pyclass(name = "Model", module = "laddu")]
pub struct PyModel(pub Model);
407
408#[pymethods]
409impl PyModel {
410    /// The free parameters used by the Manager
411    ///
412    /// Returns
413    /// -------
414    /// parameters : list of str
415    ///     The list of parameter names
416    ///
417    #[getter]
418    fn parameters(&self) -> Vec<String> {
419        self.0.parameters()
420    }
421    /// Load a Model by precalculating each term over the given Dataset
422    ///
423    /// Parameters
424    /// ----------
425    /// dataset : Dataset
426    ///     The Dataset to use in precalculation
427    ///
428    /// Returns
429    /// -------
430    /// Evaluator
431    ///     An object that can be used to evaluate the `expression` over each event in the
432    ///     `dataset`
433    ///
434    /// Notes
435    /// -----
436    /// While the given `expression` will be the one evaluated in the end, all registered
437    /// Amplitudes will be loaded, and all of their parameters will be included in the final
438    /// expression. These parameters will have no effect on evaluation, but they must be
439    /// included in function calls.
440    ///
441    fn load(&self, dataset: &PyDataset) -> PyEvaluator {
442        PyEvaluator(self.0.load(&dataset.0))
443    }
444    /// Save the Model to a file
445    ///
446    /// Parameters
447    /// ----------
448    /// path : str
449    ///     The path of the new file (overwrites if the file exists!)
450    ///
451    /// Raises
452    /// ------
453    /// IOError
454    ///     If anything fails when trying to write the file
455    ///
456    fn save_as(&self, path: &str) -> PyResult<()> {
457        self.0.save_as(path)?;
458        Ok(())
459    }
460    /// Load a Model from a file
461    ///
462    /// Parameters
463    /// ----------
464    /// path : str
465    ///     The path of the existing fit file
466    ///
467    /// Returns
468    /// -------
469    /// Model
470    ///     The model contained in the file
471    ///
472    /// Raises
473    /// ------
474    /// IOError
475    ///     If anything fails when trying to read the file
476    ///
477    #[staticmethod]
478    fn load_from(path: &str) -> PyResult<Self> {
479        Ok(PyModel(Model::load_from(path)?))
480    }
481    #[new]
482    fn new() -> Self {
483        PyModel(Model::create_null())
484    }
485    fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
486        Ok(PyBytes::new(
487            py,
488            bincode::serde::encode_to_vec(&self.0, bincode::config::standard())
489                .map_err(LadduError::EncodeError)?
490                .as_slice(),
491        ))
492    }
493    fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
494        *self = PyModel(
495            bincode::serde::decode_from_slice(state.as_bytes(), bincode::config::standard())
496                .map_err(LadduError::DecodeError)?
497                .0,
498        );
499        Ok(())
500    }
501}
502
/// An Amplitude which can be registered by a Manager
///
/// Wraps a boxed Rust ``Amplitude`` implementation. ``Manager.register`` stores a clone of
/// the wrapped value, so the same Python object may be registered again later.
///
/// See Also
/// --------
/// laddu.Manager
///
#[pyclass(name = "Amplitude", module = "laddu")]
pub struct PyAmplitude(pub Box<dyn Amplitude>);
511
/// A class which can be used to evaluate a stored Expression
///
/// Evaluators are returned by ``Model.load`` and hold the Model's expression together with
/// the Dataset it was loaded against.
///
/// See Also
/// --------
/// laddu.Model.load
///
#[pyclass(name = "Evaluator", module = "laddu")]
#[derive(Clone)]
pub struct PyEvaluator(pub Evaluator);
521
522#[pymethods]
523impl PyEvaluator {
524    /// The free parameters used by the Evaluator
525    ///
526    /// Returns
527    /// -------
528    /// parameters : list of str
529    ///     The list of parameter names
530    ///
531    #[getter]
532    fn parameters(&self) -> Vec<String> {
533        self.0.parameters()
534    }
535    /// Activates Amplitudes in the Expression by name
536    ///
537    /// Parameters
538    /// ----------
539    /// arg : str or list of str
540    ///     Names of Amplitudes to be activated
541    ///
542    /// Raises
543    /// ------
544    /// TypeError
545    ///     If `arg` is not a str or list of str
546    /// ValueError
547    ///     If `arg` or any items of `arg` are not registered Amplitudes
548    ///
549    fn activate(&self, arg: &Bound<'_, PyAny>) -> PyResult<()> {
550        if let Ok(string_arg) = arg.extract::<String>() {
551            self.0.activate(&string_arg)?;
552        } else if let Ok(list_arg) = arg.downcast::<PyList>() {
553            let vec: Vec<String> = list_arg.extract()?;
554            self.0.activate_many(&vec)?;
555        } else {
556            return Err(PyTypeError::new_err(
557                "Argument must be either a string or a list of strings",
558            ));
559        }
560        Ok(())
561    }
562    /// Activates all Amplitudes in the Expression
563    ///
564    fn activate_all(&self) {
565        self.0.activate_all();
566    }
567    /// Deactivates Amplitudes in the Expression by name
568    ///
569    /// Deactivated Amplitudes act as zeros in the Expression
570    ///
571    /// Parameters
572    /// ----------
573    /// arg : str or list of str
574    ///     Names of Amplitudes to be deactivated
575    ///
576    /// Raises
577    /// ------
578    /// TypeError
579    ///     If `arg` is not a str or list of str
580    /// ValueError
581    ///     If `arg` or any items of `arg` are not registered Amplitudes
582    ///
583    fn deactivate(&self, arg: &Bound<'_, PyAny>) -> PyResult<()> {
584        if let Ok(string_arg) = arg.extract::<String>() {
585            self.0.deactivate(&string_arg)?;
586        } else if let Ok(list_arg) = arg.downcast::<PyList>() {
587            let vec: Vec<String> = list_arg.extract()?;
588            self.0.deactivate_many(&vec)?;
589        } else {
590            return Err(PyTypeError::new_err(
591                "Argument must be either a string or a list of strings",
592            ));
593        }
594        Ok(())
595    }
596    /// Deactivates all Amplitudes in the Expression
597    ///
598    fn deactivate_all(&self) {
599        self.0.deactivate_all();
600    }
601    /// Isolates Amplitudes in the Expression by name
602    ///
603    /// Activates the Amplitudes given in `arg` and deactivates the rest
604    ///
605    /// Parameters
606    /// ----------
607    /// arg : str or list of str
608    ///     Names of Amplitudes to be isolated
609    ///
610    /// Raises
611    /// ------
612    /// TypeError
613    ///     If `arg` is not a str or list of str
614    /// ValueError
615    ///     If `arg` or any items of `arg` are not registered Amplitudes
616    ///
617    fn isolate(&self, arg: &Bound<'_, PyAny>) -> PyResult<()> {
618        if let Ok(string_arg) = arg.extract::<String>() {
619            self.0.isolate(&string_arg)?;
620        } else if let Ok(list_arg) = arg.downcast::<PyList>() {
621            let vec: Vec<String> = list_arg.extract()?;
622            self.0.isolate_many(&vec)?;
623        } else {
624            return Err(PyTypeError::new_err(
625                "Argument must be either a string or a list of strings",
626            ));
627        }
628        Ok(())
629    }
630    /// Evaluate the stored Expression over the stored Dataset
631    ///
632    /// Parameters
633    /// ----------
634    /// parameters : list of float
635    ///     The values to use for the free parameters
636    /// threads : int, optional
637    ///     The number of threads to use (setting this to None will use all available CPUs)
638    ///
639    /// Returns
640    /// -------
641    /// result : array_like
642    ///     A ``numpy`` array of complex values for each Event in the Dataset
643    ///
644    /// Raises
645    /// ------
646    /// Exception
647    ///     If there was an error building the thread pool
648    ///
649    #[pyo3(signature = (parameters, *, threads=None))]
650    fn evaluate<'py>(
651        &self,
652        py: Python<'py>,
653        parameters: Vec<Float>,
654        threads: Option<usize>,
655    ) -> PyResult<Bound<'py, PyArray1<Complex<Float>>>> {
656        #[cfg(feature = "rayon")]
657        {
658            Ok(PyArray1::from_slice(
659                py,
660                &ThreadPoolBuilder::new()
661                    .num_threads(threads.unwrap_or_else(num_cpus::get))
662                    .build()
663                    .map_err(LadduError::from)?
664                    .install(|| self.0.evaluate(&parameters)),
665            ))
666        }
667        #[cfg(not(feature = "rayon"))]
668        {
669            Ok(PyArray1::from_slice(py, &self.0.evaluate(&parameters)))
670        }
671    }
672    /// Evaluate the gradient of the stored Expression over the stored Dataset
673    ///
674    /// Parameters
675    /// ----------
676    /// parameters : list of float
677    ///     The values to use for the free parameters
678    /// threads : int, optional
679    ///     The number of threads to use (setting this to None will use all available CPUs)
680    ///
681    /// Returns
682    /// -------
683    /// result : array_like
684    ///     A ``numpy`` 2D array of complex values for each Event in the Dataset
685    ///
686    /// Raises
687    /// ------
688    /// Exception
689    ///     If there was an error building the thread pool or problem creating the resulting
690    ///     ``numpy`` array
691    ///
692    #[pyo3(signature = (parameters, *, threads=None))]
693    fn evaluate_gradient<'py>(
694        &self,
695        py: Python<'py>,
696        parameters: Vec<Float>,
697        threads: Option<usize>,
698    ) -> PyResult<Bound<'py, PyArray2<Complex<Float>>>> {
699        #[cfg(feature = "rayon")]
700        {
701            Ok(PyArray2::from_vec2(
702                py,
703                &ThreadPoolBuilder::new()
704                    .num_threads(threads.unwrap_or_else(num_cpus::get))
705                    .build()
706                    .map_err(LadduError::from)?
707                    .install(|| {
708                        self.0
709                            .evaluate_gradient(&parameters)
710                            .iter()
711                            .map(|grad| grad.data.as_vec().to_vec())
712                            .collect::<Vec<Vec<Complex<Float>>>>()
713                    }),
714            )
715            .map_err(LadduError::NumpyError)?)
716        }
717        #[cfg(not(feature = "rayon"))]
718        {
719            Ok(PyArray2::from_vec2(
720                py,
721                &self
722                    .0
723                    .evaluate_gradient(&parameters)
724                    .iter()
725                    .map(|grad| grad.data.as_vec().to_vec())
726                    .collect::<Vec<Vec<Complex<Float>>>>(),
727            )
728            .map_err(LadduError::NumpyError)?)
729        }
730    }
731}
732
/// A class, typically used to allow Amplitudes to take either free parameters or constants as
/// inputs
///
/// Instances are created with ``laddu.parameter`` (a named free parameter) or
/// ``laddu.constant`` (a fixed numerical value).
///
/// See Also
/// --------
/// laddu.parameter
/// laddu.constant
///
#[pyclass(name = "ParameterLike", module = "laddu")]
#[derive(Clone)]
pub struct PyParameterLike(pub ParameterLike);
744
745/// A free parameter which floats during an optimization
746///
747/// Parameters
748/// ----------
749/// name : str
750///     The name of the free parameter
751///
752/// Returns
753/// -------
754/// laddu.ParameterLike
755///     An object that can be used as the input for many Amplitude constructors
756///
757/// Notes
758/// -----
759/// Two free parameters with the same name are shared in a fit
760///
761#[pyfunction(name = "parameter")]
762pub fn py_parameter(name: &str) -> PyParameterLike {
763    PyParameterLike(parameter(name))
764}
765
766/// A term which stays constant during an optimization
767///
768/// Parameters
769/// ----------
770/// value : float
771///     The numerical value of the constant
772///
773/// Returns
774/// -------
775/// laddu.ParameterLike
776///     An object that can be used as the input for many Amplitude constructors
777///
778#[pyfunction(name = "constant")]
779pub fn py_constant(value: Float) -> PyParameterLike {
780    PyParameterLike(constant(value))
781}