// laddu_python/amplitudes.rs

1use crate::data::PyDataset;
2use bincode::{deserialize, serialize};
3use laddu_core::{
4    amplitudes::{
5        constant, parameter, Amplitude, AmplitudeID, Evaluator, Expression, Manager, Model,
6        ParameterLike,
7    },
8    traits::ReadWrite,
9    Complex, Float, LadduError,
10};
11use numpy::{PyArray1, PyArray2};
12use pyo3::{
13    exceptions::PyTypeError,
14    prelude::*,
15    types::{PyBytes, PyList},
16};
17#[cfg(feature = "rayon")]
18use rayon::ThreadPoolBuilder;
19
/// An object which holds a registered ``Amplitude``
///
/// Instances are returned by ``Manager.register`` and can be combined with each other (or
/// with ``Expression`` objects) using ``+`` and ``*`` to build an ``Expression``.
///
/// See Also
/// --------
/// laddu.Manager.register
///
#[pyclass(name = "AmplitudeID", module = "laddu")]
// `Clone` lets pyo3 extract this type by value, which `Manager.model` relies on.
#[derive(Clone)]
pub struct PyAmplitudeID(AmplitudeID);
29
/// A mathematical expression formed from AmplitudeIDs
///
/// Expressions are produced by combining ``AmplitudeID`` and ``Expression`` objects with
/// ``+`` and ``*``, or by calling ``real``, ``imag`` or ``norm_sqr`` on them.
///
#[pyclass(name = "Expression", module = "laddu")]
// `Clone` lets pyo3 extract this type by value, which the arithmetic dunder methods rely on.
#[derive(Clone)]
pub struct PyExpression(Expression);
35
36#[pymethods]
37impl PyAmplitudeID {
38    /// The real part of a complex Amplitude
39    ///
40    /// Returns
41    /// -------
42    /// Expression
43    ///     The real part of the given Amplitude
44    ///
45    fn real(&self) -> PyExpression {
46        PyExpression(self.0.real())
47    }
48    /// The imaginary part of a complex Amplitude
49    ///
50    /// Returns
51    /// -------
52    /// Expression
53    ///     The imaginary part of the given Amplitude
54    ///
55    fn imag(&self) -> PyExpression {
56        PyExpression(self.0.imag())
57    }
58    /// The norm-squared of a complex Amplitude
59    ///
60    /// This is computed as :math:`AA^*` where :math:`A^*` is the complex conjugate
61    ///
62    /// Returns
63    /// -------
64    /// Expression
65    ///     The norm-squared of the given Amplitude
66    ///
67    fn norm_sqr(&self) -> PyExpression {
68        PyExpression(self.0.norm_sqr())
69    }
70    fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
71        if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
72            Ok(PyExpression(self.0.clone() + other_aid.0.clone()))
73        } else if let Ok(other_expr) = other.extract::<PyExpression>() {
74            Ok(PyExpression(self.0.clone() + other_expr.0.clone()))
75        } else if let Ok(other_int) = other.extract::<usize>() {
76            if other_int == 0 {
77                Ok(PyExpression(Expression::Amp(self.0.clone())))
78            } else {
79                Err(PyTypeError::new_err(
80                    "Addition with an integer for this type is only defined for 0",
81                ))
82            }
83        } else {
84            Err(PyTypeError::new_err("Unsupported operand type for +"))
85        }
86    }
87    fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
88        if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
89            Ok(PyExpression(other_aid.0.clone() + self.0.clone()))
90        } else if let Ok(other_expr) = other.extract::<PyExpression>() {
91            Ok(PyExpression(other_expr.0.clone() + self.0.clone()))
92        } else if let Ok(other_int) = other.extract::<usize>() {
93            if other_int == 0 {
94                Ok(PyExpression(Expression::Amp(self.0.clone())))
95            } else {
96                Err(PyTypeError::new_err(
97                    "Addition with an integer for this type is only defined for 0",
98                ))
99            }
100        } else {
101            Err(PyTypeError::new_err("Unsupported operand type for +"))
102        }
103    }
104    fn __mul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
105        if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
106            Ok(PyExpression(self.0.clone() * other_aid.0.clone()))
107        } else if let Ok(other_expr) = other.extract::<PyExpression>() {
108            Ok(PyExpression(self.0.clone() * other_expr.0.clone()))
109        } else {
110            Err(PyTypeError::new_err("Unsupported operand type for *"))
111        }
112    }
113    fn __rmul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
114        if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
115            Ok(PyExpression(other_aid.0.clone() * self.0.clone()))
116        } else if let Ok(other_expr) = other.extract::<PyExpression>() {
117            Ok(PyExpression(other_expr.0.clone() * self.0.clone()))
118        } else {
119            Err(PyTypeError::new_err("Unsupported operand type for *"))
120        }
121    }
122    fn __str__(&self) -> String {
123        format!("{}", self.0)
124    }
125    fn __repr__(&self) -> String {
126        format!("{:?}", self.0)
127    }
128}
129
130#[pymethods]
131impl PyExpression {
132    /// The real part of a complex Expression
133    ///
134    /// Returns
135    /// -------
136    /// Expression
137    ///     The real part of the given Expression
138    ///
139    fn real(&self) -> PyExpression {
140        PyExpression(self.0.real())
141    }
142    /// The imaginary part of a complex Expression
143    ///
144    /// Returns
145    /// -------
146    /// Expression
147    ///     The imaginary part of the given Expression
148    ///
149    fn imag(&self) -> PyExpression {
150        PyExpression(self.0.imag())
151    }
152    /// The norm-squared of a complex Expression
153    ///
154    /// This is computed as :math:`AA^*` where :math:`A^*` is the complex conjugate
155    ///
156    /// Returns
157    /// -------
158    /// Expression
159    ///     The norm-squared of the given Expression
160    ///
161    fn norm_sqr(&self) -> PyExpression {
162        PyExpression(self.0.norm_sqr())
163    }
164    fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
165        if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
166            Ok(PyExpression(self.0.clone() + other_aid.0.clone()))
167        } else if let Ok(other_expr) = other.extract::<PyExpression>() {
168            Ok(PyExpression(self.0.clone() + other_expr.0.clone()))
169        } else if let Ok(other_int) = other.extract::<usize>() {
170            if other_int == 0 {
171                Ok(PyExpression(self.0.clone()))
172            } else {
173                Err(PyTypeError::new_err(
174                    "Addition with an integer for this type is only defined for 0",
175                ))
176            }
177        } else {
178            Err(PyTypeError::new_err("Unsupported operand type for +"))
179        }
180    }
181    fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
182        if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
183            Ok(PyExpression(other_aid.0.clone() + self.0.clone()))
184        } else if let Ok(other_expr) = other.extract::<PyExpression>() {
185            Ok(PyExpression(other_expr.0.clone() + self.0.clone()))
186        } else if let Ok(other_int) = other.extract::<usize>() {
187            if other_int == 0 {
188                Ok(PyExpression(self.0.clone()))
189            } else {
190                Err(PyTypeError::new_err(
191                    "Addition with an integer for this type is only defined for 0",
192                ))
193            }
194        } else {
195            Err(PyTypeError::new_err("Unsupported operand type for +"))
196        }
197    }
198    fn __mul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
199        if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
200            Ok(PyExpression(self.0.clone() * other_aid.0.clone()))
201        } else if let Ok(other_expr) = other.extract::<PyExpression>() {
202            Ok(PyExpression(self.0.clone() * other_expr.0.clone()))
203        } else {
204            Err(PyTypeError::new_err("Unsupported operand type for *"))
205        }
206    }
207    fn __rmul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
208        if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
209            Ok(PyExpression(other_aid.0.clone() * self.0.clone()))
210        } else if let Ok(other_expr) = other.extract::<PyExpression>() {
211            Ok(PyExpression(other_expr.0.clone() * self.0.clone()))
212        } else {
213            Err(PyTypeError::new_err("Unsupported operand type for *"))
214        }
215    }
216    fn __str__(&self) -> String {
217        format!("{}", self.0)
218    }
219    fn __repr__(&self) -> String {
220        format!("{:?}", self.0)
221    }
222}
223
/// A class which can be used to register Amplitudes and store precalculated data
///
/// Amplitudes are added with ``Manager.register``. A ``Model`` is then built from an
/// ``Expression`` (or a single ``AmplitudeID``) of registered Amplitudes with
/// ``Manager.model``.
///
#[pyclass(name = "Manager", module = "laddu")]
pub struct PyManager(Manager);
228
229#[pymethods]
230impl PyManager {
231    #[new]
232    fn new() -> Self {
233        Self(Manager::default())
234    }
235    /// The free parameters used by the Manager
236    ///
237    /// Returns
238    /// -------
239    /// parameters : list of str
240    ///     The list of parameter names
241    ///
242    #[getter]
243    fn parameters(&self) -> Vec<String> {
244        self.0.parameters()
245    }
246    /// Register an Amplitude with the Manager
247    ///
248    /// Parameters
249    /// ----------
250    /// amplitude : Amplitude
251    ///     The Amplitude to register
252    ///
253    /// Returns
254    /// -------
255    /// AmplitudeID
256    ///     A reference to the registered `amplitude` that can be used to form complex
257    ///     Expressions
258    ///
259    /// Raises
260    /// ------
261    /// ValueError
262    ///     If the name of the ``amplitude`` has already been registered
263    ///
264    fn register(&mut self, amplitude: &PyAmplitude) -> PyResult<PyAmplitudeID> {
265        Ok(PyAmplitudeID(self.0.register(amplitude.0.clone())?))
266    }
267    /// Generate a Model from the given expression made of registered Amplitudes
268    ///
269    /// Parameters
270    /// ----------
271    /// expression : Expression or AmplitudeID
272    ///     The expression to use in precalculation
273    ///
274    /// Returns
275    /// -------
276    /// Model
277    ///     An object which represents the underlying mathematical model and can be loaded with
278    ///     a Dataset
279    ///
280    /// Raises
281    /// ------
282    /// TypeError
283    ///     If the expression is not convertable to a Model
284    ///
285    /// Notes
286    /// -----
287    /// While the given `expression` will be the one evaluated in the end, all registered
288    /// Amplitudes will be loaded, and all of their parameters will be included in the final
289    /// expression. These parameters will have no effect on evaluation, but they must be
290    /// included in function calls.
291    ///
292    fn model(&self, expression: &Bound<'_, PyAny>) -> PyResult<PyModel> {
293        let expression = if let Ok(expression) = expression.extract::<PyExpression>() {
294            Ok(expression.0)
295        } else if let Ok(aid) = expression.extract::<PyAmplitudeID>() {
296            Ok(Expression::Amp(aid.0))
297        } else {
298            Err(PyTypeError::new_err(
299                "'expression' must either by an Expression or AmplitudeID",
300            ))
301        }?;
302        Ok(PyModel(self.0.model(&expression)))
303    }
304}
305
/// A class which represents a model composed of registered Amplitudes
///
/// Models are created with ``Manager.model`` and loaded with a ``Dataset`` through
/// ``Model.load`` to produce an ``Evaluator``.
///
#[pyclass(name = "Model", module = "laddu")]
pub struct PyModel(pub Model);
310
311#[pymethods]
312impl PyModel {
313    /// The free parameters used by the Manager
314    ///
315    /// Returns
316    /// -------
317    /// parameters : list of str
318    ///     The list of parameter names
319    ///
320    #[getter]
321    fn parameters(&self) -> Vec<String> {
322        self.0.parameters()
323    }
324    /// Load a Model by precalculating each term over the given Dataset
325    ///
326    /// Parameters
327    /// ----------
328    /// dataset : Dataset
329    ///     The Dataset to use in precalculation
330    ///
331    /// Returns
332    /// -------
333    /// Evaluator
334    ///     An object that can be used to evaluate the `expression` over each event in the
335    ///     `dataset`
336    ///
337    /// Notes
338    /// -----
339    /// While the given `expression` will be the one evaluated in the end, all registered
340    /// Amplitudes will be loaded, and all of their parameters will be included in the final
341    /// expression. These parameters will have no effect on evaluation, but they must be
342    /// included in function calls.
343    ///
344    fn load(&self, dataset: &PyDataset) -> PyEvaluator {
345        PyEvaluator(self.0.load(&dataset.0))
346    }
347    /// Save the Model to a file
348    ///
349    /// Parameters
350    /// ----------
351    /// path : str
352    ///     The path of the new file (overwrites if the file exists!)
353    ///
354    /// Raises
355    /// ------
356    /// IOError
357    ///     If anything fails when trying to write the file
358    ///
359    fn save_as(&self, path: &str) -> PyResult<()> {
360        self.0.save_as(path)?;
361        Ok(())
362    }
363    /// Load a Model from a file
364    ///
365    /// Parameters
366    /// ----------
367    /// path : str
368    ///     The path of the existing fit file
369    ///
370    /// Returns
371    /// -------
372    /// Model
373    ///     The model contained in the file
374    ///
375    /// Raises
376    /// ------
377    /// IOError
378    ///     If anything fails when trying to read the file
379    ///
380    #[staticmethod]
381    fn load_from(path: &str) -> PyResult<Self> {
382        Ok(PyModel(Model::load_from(path)?))
383    }
384    #[new]
385    fn new() -> Self {
386        PyModel(Model::create_null())
387    }
388    fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
389        Ok(PyBytes::new(
390            py,
391            serialize(&self.0)
392                .map_err(LadduError::SerdeError)?
393                .as_slice(),
394        ))
395    }
396    fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
397        *self = PyModel(deserialize(state.as_bytes()).map_err(LadduError::SerdeError)?);
398        Ok(())
399    }
400}
401
/// An Amplitude which can be registered by a Manager
///
/// See Also
/// --------
/// laddu.Manager
///
#[pyclass(name = "Amplitude", module = "laddu")]
// `Manager.register` takes this by reference and clones the boxed Amplitude, so the Python
// object is not consumed by registration.
pub struct PyAmplitude(pub Box<dyn Amplitude>);
410
411/// A class which can be used to evaluate a stored Expression
412///
413/// See Also
414/// --------
415/// laddu.Manager.load
416///
417#[pyclass(name = "Evaluator", module = "laddu")]
418#[derive(Clone)]
419pub struct PyEvaluator(pub Evaluator);
420
421#[pymethods]
422impl PyEvaluator {
423    /// The free parameters used by the Evaluator
424    ///
425    /// Returns
426    /// -------
427    /// parameters : list of str
428    ///     The list of parameter names
429    ///
430    #[getter]
431    fn parameters(&self) -> Vec<String> {
432        self.0.parameters()
433    }
434    /// Activates Amplitudes in the Expression by name
435    ///
436    /// Parameters
437    /// ----------
438    /// arg : str or list of str
439    ///     Names of Amplitudes to be activated
440    ///
441    /// Raises
442    /// ------
443    /// TypeError
444    ///     If `arg` is not a str or list of str
445    /// ValueError
446    ///     If `arg` or any items of `arg` are not registered Amplitudes
447    ///
448    fn activate(&self, arg: &Bound<'_, PyAny>) -> PyResult<()> {
449        if let Ok(string_arg) = arg.extract::<String>() {
450            self.0.activate(&string_arg)?;
451        } else if let Ok(list_arg) = arg.downcast::<PyList>() {
452            let vec: Vec<String> = list_arg.extract()?;
453            self.0.activate_many(&vec)?;
454        } else {
455            return Err(PyTypeError::new_err(
456                "Argument must be either a string or a list of strings",
457            ));
458        }
459        Ok(())
460    }
461    /// Activates all Amplitudes in the Expression
462    ///
463    fn activate_all(&self) {
464        self.0.activate_all();
465    }
466    /// Deactivates Amplitudes in the Expression by name
467    ///
468    /// Deactivated Amplitudes act as zeros in the Expression
469    ///
470    /// Parameters
471    /// ----------
472    /// arg : str or list of str
473    ///     Names of Amplitudes to be deactivated
474    ///
475    /// Raises
476    /// ------
477    /// TypeError
478    ///     If `arg` is not a str or list of str
479    /// ValueError
480    ///     If `arg` or any items of `arg` are not registered Amplitudes
481    ///
482    fn deactivate(&self, arg: &Bound<'_, PyAny>) -> PyResult<()> {
483        if let Ok(string_arg) = arg.extract::<String>() {
484            self.0.deactivate(&string_arg)?;
485        } else if let Ok(list_arg) = arg.downcast::<PyList>() {
486            let vec: Vec<String> = list_arg.extract()?;
487            self.0.deactivate_many(&vec)?;
488        } else {
489            return Err(PyTypeError::new_err(
490                "Argument must be either a string or a list of strings",
491            ));
492        }
493        Ok(())
494    }
495    /// Deactivates all Amplitudes in the Expression
496    ///
497    fn deactivate_all(&self) {
498        self.0.deactivate_all();
499    }
500    /// Isolates Amplitudes in the Expression by name
501    ///
502    /// Activates the Amplitudes given in `arg` and deactivates the rest
503    ///
504    /// Parameters
505    /// ----------
506    /// arg : str or list of str
507    ///     Names of Amplitudes to be isolated
508    ///
509    /// Raises
510    /// ------
511    /// TypeError
512    ///     If `arg` is not a str or list of str
513    /// ValueError
514    ///     If `arg` or any items of `arg` are not registered Amplitudes
515    ///
516    fn isolate(&self, arg: &Bound<'_, PyAny>) -> PyResult<()> {
517        if let Ok(string_arg) = arg.extract::<String>() {
518            self.0.isolate(&string_arg)?;
519        } else if let Ok(list_arg) = arg.downcast::<PyList>() {
520            let vec: Vec<String> = list_arg.extract()?;
521            self.0.isolate_many(&vec)?;
522        } else {
523            return Err(PyTypeError::new_err(
524                "Argument must be either a string or a list of strings",
525            ));
526        }
527        Ok(())
528    }
529    /// Evaluate the stored Expression over the stored Dataset
530    ///
531    /// Parameters
532    /// ----------
533    /// parameters : list of float
534    ///     The values to use for the free parameters
535    /// threads : int, optional
536    ///     The number of threads to use (setting this to None will use all available CPUs)
537    ///
538    /// Returns
539    /// -------
540    /// result : array_like
541    ///     A ``numpy`` array of complex values for each Event in the Dataset
542    ///
543    /// Raises
544    /// ------
545    /// Exception
546    ///     If there was an error building the thread pool
547    ///
548    #[pyo3(signature = (parameters, *, threads=None))]
549    fn evaluate<'py>(
550        &self,
551        py: Python<'py>,
552        parameters: Vec<Float>,
553        threads: Option<usize>,
554    ) -> PyResult<Bound<'py, PyArray1<Complex<Float>>>> {
555        #[cfg(feature = "rayon")]
556        {
557            Ok(PyArray1::from_slice(
558                py,
559                &ThreadPoolBuilder::new()
560                    .num_threads(threads.unwrap_or_else(num_cpus::get))
561                    .build()
562                    .map_err(LadduError::from)?
563                    .install(|| self.0.evaluate(&parameters)),
564            ))
565        }
566        #[cfg(not(feature = "rayon"))]
567        {
568            Ok(PyArray1::from_slice(py, &self.0.evaluate(&parameters)))
569        }
570    }
571    /// Evaluate the gradient of the stored Expression over the stored Dataset
572    ///
573    /// Parameters
574    /// ----------
575    /// parameters : list of float
576    ///     The values to use for the free parameters
577    /// threads : int, optional
578    ///     The number of threads to use (setting this to None will use all available CPUs)
579    ///
580    /// Returns
581    /// -------
582    /// result : array_like
583    ///     A ``numpy`` 2D array of complex values for each Event in the Dataset
584    ///
585    /// Raises
586    /// ------
587    /// Exception
588    ///     If there was an error building the thread pool or problem creating the resulting
589    ///     ``numpy`` array
590    ///
591    #[pyo3(signature = (parameters, *, threads=None))]
592    fn evaluate_gradient<'py>(
593        &self,
594        py: Python<'py>,
595        parameters: Vec<Float>,
596        threads: Option<usize>,
597    ) -> PyResult<Bound<'py, PyArray2<Complex<Float>>>> {
598        #[cfg(feature = "rayon")]
599        {
600            Ok(PyArray2::from_vec2(
601                py,
602                &ThreadPoolBuilder::new()
603                    .num_threads(threads.unwrap_or_else(num_cpus::get))
604                    .build()
605                    .map_err(LadduError::from)?
606                    .install(|| {
607                        self.0
608                            .evaluate_gradient(&parameters)
609                            .iter()
610                            .map(|grad| grad.data.as_vec().to_vec())
611                            .collect::<Vec<Vec<Complex<Float>>>>()
612                    }),
613            )
614            .map_err(LadduError::NumpyError)?)
615        }
616        #[cfg(not(feature = "rayon"))]
617        {
618            Ok(PyArray2::from_vec2(
619                py,
620                &self
621                    .0
622                    .evaluate_gradient(&parameters)
623                    .iter()
624                    .map(|grad| grad.data.as_vec().to_vec())
625                    .collect::<Vec<Vec<Complex<Float>>>>(),
626            )
627            .map_err(LadduError::NumpyError)?)
628        }
629    }
630}
631
/// A class, typically used to allow Amplitudes to take either free parameters or constants as
/// inputs
///
/// Instances are created with ``laddu.parameter`` (a named free parameter) or
/// ``laddu.constant`` (a fixed numerical value).
///
/// See Also
/// --------
/// laddu.parameter
/// laddu.constant
///
#[pyclass(name = "ParameterLike", module = "laddu")]
#[derive(Clone)]
pub struct PyParameterLike(pub ParameterLike);
643
644/// A free parameter which floats during an optimization
645///
646/// Parameters
647/// ----------
648/// name : str
649///     The name of the free parameter
650///
651/// Returns
652/// -------
653/// laddu.ParameterLike
654///     An object that can be used as the input for many Amplitude constructors
655///
656/// Notes
657/// -----
658/// Two free parameters with the same name are shared in a fit
659///
660#[pyfunction(name = "parameter")]
661pub fn py_parameter(name: &str) -> PyParameterLike {
662    PyParameterLike(parameter(name))
663}
664
665/// A term which stays constant during an optimization
666///
667/// Parameters
668/// ----------
669/// value : float
670///     The numerical value of the constant
671///
672/// Returns
673/// -------
674/// laddu.ParameterLike
675///     An object that can be used as the input for many Amplitude constructors
676///
677#[pyfunction(name = "constant")]
678pub fn py_constant(value: Float) -> PyParameterLike {
679    PyParameterLike(constant(value))
680}