laddu_python/utils/
variables.rs

1use crate::{
2    data::{PyDataset, PyEvent},
3    utils::vectors::{PyVec3, PyVec4},
4};
5use laddu_core::{
6    data::{Dataset, DatasetMetadata, EventData},
7    traits::Variable,
8    utils::variables::{
9        Angles, CosTheta, IntoP4Selection, Mandelstam, Mass, P4Selection, Phi, PolAngle,
10        PolMagnitude, Polarization, Topology, VariableExpression,
11    },
12    LadduResult,
13};
14use numpy::PyArray1;
15use pyo3::{exceptions::PyValueError, prelude::*};
16use serde::{Deserialize, Serialize};
17use std::{
18    fmt::{Debug, Display},
19    sync::Arc,
20};
21
/// Any single-valued laddu Variable that can be received from Python.
///
/// PyO3 tries the variants in declaration order when extracting this enum; every variant is
/// `transparent`, so the Python object must itself be an instance of the matching pyclass.
/// Keep the variant order stable when editing. Composite helpers such as `Angles` and
/// `Polarization` are intentionally not variants here.
#[derive(FromPyObject, Clone, Serialize, Deserialize)]
pub enum PyVariable {
    /// A `laddu.Mass` instance.
    #[pyo3(transparent)]
    Mass(PyMass),
    /// A `laddu.CosTheta` instance.
    #[pyo3(transparent)]
    CosTheta(PyCosTheta),
    /// A `laddu.Phi` instance.
    #[pyo3(transparent)]
    Phi(PyPhi),
    /// A `laddu.PolAngle` instance.
    #[pyo3(transparent)]
    PolAngle(PyPolAngle),
    /// A `laddu.PolMagnitude` instance.
    #[pyo3(transparent)]
    PolMagnitude(PyPolMagnitude),
    /// A `laddu.Mandelstam` instance.
    #[pyo3(transparent)]
    Mandelstam(PyMandelstam),
}
37
38impl Debug for PyVariable {
39    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40        match self {
41            Self::Mass(v) => write!(f, "{:?}", v.0),
42            Self::CosTheta(v) => write!(f, "{:?}", v.0),
43            Self::Phi(v) => write!(f, "{:?}", v.0),
44            Self::PolAngle(v) => write!(f, "{:?}", v.0),
45            Self::PolMagnitude(v) => write!(f, "{:?}", v.0),
46            Self::Mandelstam(v) => write!(f, "{:?}", v.0),
47        }
48    }
49}
50impl Display for PyVariable {
51    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52        match self {
53            Self::Mass(v) => write!(f, "{}", v.0),
54            Self::CosTheta(v) => write!(f, "{}", v.0),
55            Self::Phi(v) => write!(f, "{}", v.0),
56            Self::PolAngle(v) => write!(f, "{}", v.0),
57            Self::PolMagnitude(v) => write!(f, "{}", v.0),
58            Self::Mandelstam(v) => write!(f, "{}", v.0),
59        }
60    }
61}
62
63impl PyVariable {
64    pub(crate) fn bind_in_place(&mut self, metadata: &DatasetMetadata) -> PyResult<()> {
65        match self {
66            Self::Mass(mass) => mass.0.bind(metadata).map_err(PyErr::from),
67            Self::CosTheta(cos_theta) => cos_theta.0.bind(metadata).map_err(PyErr::from),
68            Self::Phi(phi) => phi.0.bind(metadata).map_err(PyErr::from),
69            Self::PolAngle(pol_angle) => pol_angle.0.bind(metadata).map_err(PyErr::from),
70            Self::PolMagnitude(pol_magnitude) => {
71                pol_magnitude.0.bind(metadata).map_err(PyErr::from)
72            }
73            Self::Mandelstam(mandelstam) => mandelstam.0.bind(metadata).map_err(PyErr::from),
74        }
75    }
76
77    pub(crate) fn bound(&self, metadata: &DatasetMetadata) -> PyResult<Self> {
78        let mut cloned = self.clone();
79        cloned.bind_in_place(metadata)?;
80        Ok(cloned)
81    }
82
83    pub(crate) fn evaluate_event(&self, event: &Arc<EventData>) -> PyResult<f64> {
84        Ok(self.value(event.as_ref()))
85    }
86}
87
/// A boolean expression over Variables
///
/// Instances are produced by comparing a Variable with a float (``==``, ``<``, ``>``,
/// ``<=``, ``>=``) and can be combined with ``&`` and ``|`` or negated with ``~``.
///
#[pyclass(name = "VariableExpression", module = "laddu")]
pub struct PyVariableExpression(pub VariableExpression);
90
91#[pymethods]
92impl PyVariableExpression {
93    fn __and__(&self, rhs: &PyVariableExpression) -> PyVariableExpression {
94        PyVariableExpression(self.0.clone() & rhs.0.clone())
95    }
96    fn __or__(&self, rhs: &PyVariableExpression) -> PyVariableExpression {
97        PyVariableExpression(self.0.clone() | rhs.0.clone())
98    }
99    fn __invert__(&self) -> PyVariableExpression {
100        PyVariableExpression(!self.0.clone())
101    }
102    fn __str__(&self) -> String {
103        format!("{}", self.0)
104    }
105}
106
/// A particle selection received from Python: either one particle name or a list of names.
///
/// PyO3 tries the variants in declaration order, so a plain `str` is matched before a list.
#[derive(Clone, FromPyObject)]
pub enum PyP4SelectionInput {
    /// A single particle name.
    #[pyo3(transparent)]
    Name(String),
    /// A list of particle names to be combined.
    #[pyo3(transparent)]
    Names(Vec<String>),
}
114
115impl PyP4SelectionInput {
116    fn into_selection(self) -> P4Selection {
117        match self {
118            PyP4SelectionInput::Name(name) => name.into_selection(),
119            PyP4SelectionInput::Names(names) => names.into_selection(),
120        }
121    }
122}
123
/// A reusable 2-to-2 reaction description shared by multiple Variables.
///
/// Parameters
/// ----------
/// k1, k2, k3, k4 : str or list of str
///     Particle name(s) defining each of the four momenta of the reaction. A list of names
///     selects a combination of particles.
///
/// Notes
/// -----
/// Use ``Topology.missing_k1`` through ``Topology.missing_k4`` to build a topology in which
/// one of the four momenta is not given explicitly.
///
#[pyclass(name = "Topology", module = "laddu")]
#[derive(Clone, Serialize, Deserialize)]
pub struct PyTopology(pub Topology);
128
129#[pymethods]
130impl PyTopology {
131    #[new]
132    fn new(
133        k1: PyP4SelectionInput,
134        k2: PyP4SelectionInput,
135        k3: PyP4SelectionInput,
136        k4: PyP4SelectionInput,
137    ) -> Self {
138        Self(Topology::new(
139            k1.into_selection(),
140            k2.into_selection(),
141            k3.into_selection(),
142            k4.into_selection(),
143        ))
144    }
145
146    #[staticmethod]
147    fn missing_k1(k2: PyP4SelectionInput, k3: PyP4SelectionInput, k4: PyP4SelectionInput) -> Self {
148        Self(Topology::missing_k1(
149            k2.into_selection(),
150            k3.into_selection(),
151            k4.into_selection(),
152        ))
153    }
154
155    #[staticmethod]
156    fn missing_k2(k1: PyP4SelectionInput, k3: PyP4SelectionInput, k4: PyP4SelectionInput) -> Self {
157        Self(Topology::missing_k2(
158            k1.into_selection(),
159            k3.into_selection(),
160            k4.into_selection(),
161        ))
162    }
163
164    #[staticmethod]
165    fn missing_k3(k1: PyP4SelectionInput, k2: PyP4SelectionInput, k4: PyP4SelectionInput) -> Self {
166        Self(Topology::missing_k3(
167            k1.into_selection(),
168            k2.into_selection(),
169            k4.into_selection(),
170        ))
171    }
172
173    #[staticmethod]
174    fn missing_k4(k1: PyP4SelectionInput, k2: PyP4SelectionInput, k3: PyP4SelectionInput) -> Self {
175        Self(Topology::missing_k4(
176            k1.into_selection(),
177            k2.into_selection(),
178            k3.into_selection(),
179        ))
180    }
181
182    fn k1_names(&self) -> Option<Vec<String>> {
183        self.0.k1_names().map(|names| names.to_vec())
184    }
185
186    fn k2_names(&self) -> Option<Vec<String>> {
187        self.0.k2_names().map(|names| names.to_vec())
188    }
189
190    fn k3_names(&self) -> Option<Vec<String>> {
191        self.0.k3_names().map(|names| names.to_vec())
192    }
193
194    fn k4_names(&self) -> Option<Vec<String>> {
195        self.0.k4_names().map(|names| names.to_vec())
196    }
197
198    fn com_boost_vector(&self, event: &PyEvent) -> PyResult<PyVec3> {
199        let (topology, event_data) = self.topology_for_event(event)?;
200        Ok(PyVec3(topology.com_boost_vector(event_data)))
201    }
202
203    fn k1(&self, event: &PyEvent) -> PyResult<PyVec4> {
204        let (topology, event_data) = self.topology_for_event(event)?;
205        Ok(PyVec4(topology.k1(event_data)))
206    }
207
208    fn k2(&self, event: &PyEvent) -> PyResult<PyVec4> {
209        let (topology, event_data) = self.topology_for_event(event)?;
210        Ok(PyVec4(topology.k2(event_data)))
211    }
212
213    fn k3(&self, event: &PyEvent) -> PyResult<PyVec4> {
214        let (topology, event_data) = self.topology_for_event(event)?;
215        Ok(PyVec4(topology.k3(event_data)))
216    }
217
218    fn k4(&self, event: &PyEvent) -> PyResult<PyVec4> {
219        let (topology, event_data) = self.topology_for_event(event)?;
220        Ok(PyVec4(topology.k4(event_data)))
221    }
222
223    fn k1_com(&self, event: &PyEvent) -> PyResult<PyVec4> {
224        let (topology, event_data) = self.topology_for_event(event)?;
225        Ok(PyVec4(topology.k1_com(event_data)))
226    }
227
228    fn k2_com(&self, event: &PyEvent) -> PyResult<PyVec4> {
229        let (topology, event_data) = self.topology_for_event(event)?;
230        Ok(PyVec4(topology.k2_com(event_data)))
231    }
232
233    fn k3_com(&self, event: &PyEvent) -> PyResult<PyVec4> {
234        let (topology, event_data) = self.topology_for_event(event)?;
235        Ok(PyVec4(topology.k3_com(event_data)))
236    }
237
238    fn k4_com(&self, event: &PyEvent) -> PyResult<PyVec4> {
239        let (topology, event_data) = self.topology_for_event(event)?;
240        Ok(PyVec4(topology.k4_com(event_data)))
241    }
242
243    fn __repr__(&self) -> String {
244        format!("{:?}", self.0)
245    }
246    fn __str__(&self) -> String {
247        format!("{}", self.0)
248    }
249}
250
251impl PyTopology {
252    fn topology_for_event<'event>(
253        &self,
254        event: &'event PyEvent,
255    ) -> PyResult<(Topology, &'event EventData)> {
256        let metadata = event.metadata_opt().ok_or_else(|| {
257            PyValueError::new_err(
258                "This event is not associated with metadata; supply `p4_names`/`aux_names` when constructing it or evaluate via a Dataset.",
259            )
260        })?;
261        let mut topology = self.0.clone();
262        topology.bind(metadata).map_err(PyErr::from)?;
263        Ok((topology, event.event.data()))
264    }
265}
266
267/// The invariant mass of an arbitrary combination of constituent particles in an Event
268///
269/// This variable is calculated by summing up the 4-momenta of each particle listed by index in
270/// `constituents` and taking the invariant magnitude of the resulting 4-vector.
271///
272/// Parameters
273/// ----------
274/// constituents : str or list of str
275///     Particle names to combine when constructing the final four-momentum
276///
277/// See Also
278/// --------
279/// laddu.utils.vectors.Vec4.m
280///
281#[pyclass(name = "Mass", module = "laddu")]
282#[derive(Clone, Serialize, Deserialize)]
283pub struct PyMass(pub Mass);
284
285#[pymethods]
286impl PyMass {
287    #[new]
288    fn new(constituents: PyP4SelectionInput) -> Self {
289        Self(Mass::new(constituents.into_selection()))
290    }
291    /// The value of this Variable for the given Event
292    ///
293    /// Parameters
294    /// ----------
295    /// event : Event
296    ///     The Event upon which the Variable is calculated
297    ///
298    /// Returns
299    /// -------
300    /// value : float
301    ///     The value of the Variable for the given `event`
302    ///
303    fn value(&self, event: &PyEvent) -> PyResult<f64> {
304        let metadata = event
305            .metadata_opt()
306            .ok_or_else(|| PyValueError::new_err(
307                "This event is not associated with metadata; supply `p4_names`/`aux_names` when constructing it or evaluate via a Dataset.",
308            ))?;
309        let mut variable = self.0.clone();
310        variable.bind(metadata).map_err(PyErr::from)?;
311        Ok(variable.value(event.event.data()))
312    }
313    /// All values of this Variable on the given Dataset
314    ///
315    /// Parameters
316    /// ----------
317    /// dataset : Dataset
318    ///     The Dataset upon which the Variable is calculated
319    ///
320    /// Returns
321    /// -------
322    /// values : array_like
323    ///     The values of the Variable for each Event in the given `dataset`
324    ///
325    fn value_on<'py>(
326        &self,
327        py: Python<'py>,
328        dataset: &PyDataset,
329    ) -> PyResult<Bound<'py, PyArray1<f64>>> {
330        let values = self.0.value_on(&dataset.0).map_err(PyErr::from)?;
331        Ok(PyArray1::from_vec(py, values))
332    }
333    fn __eq__(&self, value: f64) -> PyVariableExpression {
334        PyVariableExpression(self.0.eq(value))
335    }
336    fn __lt__(&self, value: f64) -> PyVariableExpression {
337        PyVariableExpression(self.0.lt(value))
338    }
339    fn __gt__(&self, value: f64) -> PyVariableExpression {
340        PyVariableExpression(self.0.gt(value))
341    }
342    fn __le__(&self, value: f64) -> PyVariableExpression {
343        PyVariableExpression(self.0.le(value))
344    }
345    fn __ge__(&self, value: f64) -> PyVariableExpression {
346        PyVariableExpression(self.0.ge(value))
347    }
348    fn __repr__(&self) -> String {
349        format!("{:?}", self.0)
350    }
351    fn __str__(&self) -> String {
352        format!("{}", self.0)
353    }
354}
355
356/// The cosine of the polar decay angle in the rest frame of the given `resonance`
357///
358/// This Variable is calculated by forming the given frame (helicity or Gottfried-Jackson) and
359/// calculating the spherical angles according to one of the decaying `daughter` particles.
360///
361/// The helicity frame is defined in terms of the following Cartesian axes in the rest frame of
362/// the `resonance`:
363///
364/// .. math:: \hat{z} \propto -\vec{p}'_{\text{recoil}}
365/// .. math:: \hat{y} \propto \vec{p}_{\text{beam}} \times (-\vec{p}_{\text{recoil}})
366/// .. math:: \hat{x} = \hat{y} \times \hat{z}
367///
368/// where primed vectors are in the rest frame of the `resonance` and unprimed vectors are in
369/// the center-of-momentum frame.
370///
371/// The Gottfried-Jackson frame differs only in the definition of :math:`\hat{z}`:
372///
373/// .. math:: \hat{z} \propto \vec{p}'_{\text{beam}}
374///
375/// Parameters
376/// ----------
377/// topology : laddu.Topology
378///     Topology describing the 2-to-2 production kinematics in the center-of-momentum frame.
379/// daughter : list of str
380///     Names of particles which are combined to form one of the decay products of the
381///     resonance associated with ``k3`` of the topology.
382/// frame : {'Helicity', 'HX', 'HEL', 'GottfriedJackson', 'Gottfried Jackson', 'GJ', 'Gottfried-Jackson'}
383///     The frame to use in the  calculation
384///
385/// Raises
386/// ------
387/// ValueError
388///     If `frame` is not one of the valid options
389///
390/// See Also
391/// --------
392/// laddu.utils.vectors.Vec3.costheta
393///
394#[pyclass(name = "CosTheta", module = "laddu")]
395#[derive(Clone, Serialize, Deserialize)]
396pub struct PyCosTheta(pub CosTheta);
397
398#[pymethods]
399impl PyCosTheta {
400    #[new]
401    #[pyo3(signature=(topology, daughter, frame="Helicity"))]
402    fn new(topology: PyTopology, daughter: PyP4SelectionInput, frame: &str) -> PyResult<Self> {
403        Ok(Self(CosTheta::new(
404            topology.0.clone(),
405            daughter.into_selection(),
406            frame.parse()?,
407        )))
408    }
409    /// The value of this Variable for the given Event
410    ///
411    /// Parameters
412    /// ----------
413    /// event : Event
414    ///     The Event upon which the Variable is calculated
415    ///
416    /// Returns
417    /// -------
418    /// value : float
419    ///     The value of the Variable for the given `event`
420    ///
421    fn value(&self, event: &PyEvent) -> PyResult<f64> {
422        let metadata = event
423            .metadata_opt()
424            .ok_or_else(|| PyValueError::new_err(
425                "This event is not associated with metadata; supply `p4_names`/`aux_names` when constructing it or evaluate via a Dataset.",
426            ))?;
427        let mut variable = self.0.clone();
428        variable.bind(metadata).map_err(PyErr::from)?;
429        Ok(variable.value(event.event.data()))
430    }
431    /// All values of this Variable on the given Dataset
432    ///
433    /// Parameters
434    /// ----------
435    /// dataset : Dataset
436    ///     The Dataset upon which the Variable is calculated
437    ///
438    /// Returns
439    /// -------
440    /// values : array_like
441    ///     The values of the Variable for each Event in the given `dataset`
442    ///
443    fn value_on<'py>(
444        &self,
445        py: Python<'py>,
446        dataset: &PyDataset,
447    ) -> PyResult<Bound<'py, PyArray1<f64>>> {
448        let values = self.0.value_on(&dataset.0).map_err(PyErr::from)?;
449        Ok(PyArray1::from_vec(py, values))
450    }
451    fn __eq__(&self, value: f64) -> PyVariableExpression {
452        PyVariableExpression(self.0.eq(value))
453    }
454    fn __lt__(&self, value: f64) -> PyVariableExpression {
455        PyVariableExpression(self.0.lt(value))
456    }
457    fn __gt__(&self, value: f64) -> PyVariableExpression {
458        PyVariableExpression(self.0.gt(value))
459    }
460    fn __le__(&self, value: f64) -> PyVariableExpression {
461        PyVariableExpression(self.0.le(value))
462    }
463    fn __ge__(&self, value: f64) -> PyVariableExpression {
464        PyVariableExpression(self.0.ge(value))
465    }
466    fn __repr__(&self) -> String {
467        format!("{:?}", self.0)
468    }
469    fn __str__(&self) -> String {
470        format!("{}", self.0)
471    }
472}
473
474/// The aziumuthal decay angle in the rest frame of the given `resonance`
475///
476/// This Variable is calculated by forming the given frame (helicity or Gottfried-Jackson) and
477/// calculating the spherical angles according to one of the decaying `daughter` particles.
478///
479/// The helicity frame is defined in terms of the following Cartesian axes in the rest frame of
480/// the `resonance`:
481///
482/// .. math:: \hat{z} \propto -\vec{p}'_{\text{recoil}}
483/// .. math:: \hat{y} \propto \vec{p}_{\text{beam}} \times (-\vec{p}_{\text{recoil}})
484/// .. math:: \hat{x} = \hat{y} \times \hat{z}
485///
486/// where primed vectors are in the rest frame of the `resonance` and unprimed vectors are in
487/// the center-of-momentum frame.
488///
489/// The Gottfried-Jackson frame differs only in the definition of :math:`\hat{z}`:
490///
491/// .. math:: \hat{z} \propto \vec{p}'_{\text{beam}}
492///
493/// Parameters
494/// ----------
495/// topology : laddu.Topology
496///     Topology describing the 2-to-2 production kinematics in the center-of-momentum frame.
497/// daughter : list of str
498///     Names of particles which are combined to form one of the decay products of the
499///     resonance associated with ``k3`` of the topology.
500/// frame : {'Helicity', 'HX', 'HEL', 'GottfriedJackson', 'Gottfried Jackson', 'GJ', 'Gottfried-Jackson'}
501///     The frame to use in the  calculation
502///
503/// Raises
504/// ------
505/// ValueError
506///     If `frame` is not one of the valid options
507///
508///
509/// See Also
510/// --------
511/// laddu.utils.vectors.Vec3.phi
512///
513#[pyclass(name = "Phi", module = "laddu")]
514#[derive(Clone, Serialize, Deserialize)]
515pub struct PyPhi(pub Phi);
516
517#[pymethods]
518impl PyPhi {
519    #[new]
520    #[pyo3(signature=(topology, daughter, frame="Helicity"))]
521    fn new(topology: PyTopology, daughter: PyP4SelectionInput, frame: &str) -> PyResult<Self> {
522        Ok(Self(Phi::new(
523            topology.0.clone(),
524            daughter.into_selection(),
525            frame.parse()?,
526        )))
527    }
528    /// The value of this Variable for the given Event
529    ///
530    /// Parameters
531    /// ----------
532    /// event : Event
533    ///     The Event upon which the Variable is calculated
534    ///
535    /// Returns
536    /// -------
537    /// value : float
538    ///     The value of the Variable for the given `event`
539    ///
540    fn value(&self, event: &PyEvent) -> PyResult<f64> {
541        let metadata = event
542            .metadata_opt()
543            .ok_or_else(|| PyValueError::new_err(
544                "This event is not associated with metadata; supply `p4_names`/`aux_names` when constructing it or evaluate via a Dataset.",
545            ))?;
546        let mut variable = self.0.clone();
547        variable.bind(metadata).map_err(PyErr::from)?;
548        Ok(variable.value(event.event.data()))
549    }
550    /// All values of this Variable on the given Dataset
551    ///
552    /// Parameters
553    /// ----------
554    /// dataset : Dataset
555    ///     The Dataset upon which the Variable is calculated
556    ///
557    /// Returns
558    /// -------
559    /// values : array_like
560    ///     The values of the Variable for each Event in the given `dataset`
561    ///
562    fn value_on<'py>(
563        &self,
564        py: Python<'py>,
565        dataset: &PyDataset,
566    ) -> PyResult<Bound<'py, PyArray1<f64>>> {
567        let values = self.0.value_on(&dataset.0).map_err(PyErr::from)?;
568        Ok(PyArray1::from_vec(py, values))
569    }
570    fn __eq__(&self, value: f64) -> PyVariableExpression {
571        PyVariableExpression(self.0.eq(value))
572    }
573    fn __lt__(&self, value: f64) -> PyVariableExpression {
574        PyVariableExpression(self.0.lt(value))
575    }
576    fn __gt__(&self, value: f64) -> PyVariableExpression {
577        PyVariableExpression(self.0.gt(value))
578    }
579    fn __le__(&self, value: f64) -> PyVariableExpression {
580        PyVariableExpression(self.0.le(value))
581    }
582    fn __ge__(&self, value: f64) -> PyVariableExpression {
583        PyVariableExpression(self.0.ge(value))
584    }
585    fn __repr__(&self) -> String {
586        format!("{:?}", self.0)
587    }
588    fn __str__(&self) -> String {
589        format!("{}", self.0)
590    }
591}
592
593/// A Variable used to define both spherical decay angles in the given frame
594///
595/// This class combines ``laddu.CosTheta`` and ``laddu.Phi`` into a single
596/// object
597///
598/// Parameters
599/// ----------
600/// topology : laddu.Topology
601///     Topology describing the 2-to-2 production kinematics in the center-of-momentum frame.
602/// daughter : list of str
603///     Names of particles which are combined to form one of the decay products of the
604///     resonance associated with ``k3`` of the topology.
605/// frame : {'Helicity', 'HX', 'HEL', 'GottfriedJackson', 'Gottfried Jackson', 'GJ', 'Gottfried-Jackson'}
606///     The frame to use in the  calculation
607///
608/// Raises
609/// ------
610/// ValueError
611///     If `frame` is not one of the valid options
612///
613/// See Also
614/// --------
615/// laddu.CosTheta
616/// laddu.Phi
617///
618#[pyclass(name = "Angles", module = "laddu")]
619#[derive(Clone)]
620pub struct PyAngles(pub Angles);
621#[pymethods]
622impl PyAngles {
623    #[new]
624    #[pyo3(signature=(topology, daughter, frame="Helicity"))]
625    fn new(topology: PyTopology, daughter: PyP4SelectionInput, frame: &str) -> PyResult<Self> {
626        Ok(Self(Angles::new(
627            topology.0.clone(),
628            daughter.into_selection(),
629            frame.parse()?,
630        )))
631    }
632    /// The Variable representing the cosine of the polar spherical decay angle
633    ///
634    /// Returns
635    /// -------
636    /// CosTheta
637    ///
638    #[getter]
639    fn costheta(&self) -> PyCosTheta {
640        PyCosTheta(self.0.costheta.clone())
641    }
642    // The Variable representing the polar azimuthal decay angle
643    //
644    // Returns
645    // -------
646    // Phi
647    //
648    #[getter]
649    fn phi(&self) -> PyPhi {
650        PyPhi(self.0.phi.clone())
651    }
652    fn __repr__(&self) -> String {
653        format!("{:?}", self.0)
654    }
655    fn __str__(&self) -> String {
656        format!("{}", self.0)
657    }
658}
659
/// The polar angle of the given polarization vector with respect to the production plane
///
/// The `beam` and `recoil` particles described by the ``topology`` define the plane of
/// production, and this Variable describes the polar angle of the `beam` polarization
/// relative to this plane
///
/// Parameters
/// ----------
/// topology : laddu.Topology
///     Topology describing the 2-to-2 production kinematics in the center-of-momentum frame.
/// pol_angle : str
///     Name of the auxiliary scalar column storing the polarization angle in radians
///
#[pyclass(name = "PolAngle", module = "laddu")]
#[derive(Clone, Serialize, Deserialize)]
pub struct PyPolAngle(pub PolAngle);
675
676#[pymethods]
677impl PyPolAngle {
678    #[new]
679    fn new(topology: PyTopology, pol_angle: String) -> Self {
680        Self(PolAngle::new(topology.0.clone(), pol_angle))
681    }
682    /// The value of this Variable for the given Event
683    ///
684    /// Parameters
685    /// ----------
686    /// event : Event
687    ///     The Event upon which the Variable is calculated
688    ///
689    /// Returns
690    /// -------
691    /// value : float
692    ///     The value of the Variable for the given `event`
693    ///
694    fn value(&self, event: &PyEvent) -> PyResult<f64> {
695        let metadata = event
696            .metadata_opt()
697            .ok_or_else(|| PyValueError::new_err(
698                "This event is not associated with metadata; supply `p4_names`/`aux_names` when constructing it or evaluate via a Dataset.",
699            ))?;
700        let mut variable = self.0.clone();
701        variable.bind(metadata).map_err(PyErr::from)?;
702        Ok(variable.value(event.event.data()))
703    }
704    /// All values of this Variable on the given Dataset
705    ///
706    /// Parameters
707    /// ----------
708    /// dataset : Dataset
709    ///     The Dataset upon which the Variable is calculated
710    ///
711    /// Returns
712    /// -------
713    /// values : array_like
714    ///     The values of the Variable for each Event in the given `dataset`
715    ///
716    fn value_on<'py>(
717        &self,
718        py: Python<'py>,
719        dataset: &PyDataset,
720    ) -> PyResult<Bound<'py, PyArray1<f64>>> {
721        let values = self.0.value_on(&dataset.0).map_err(PyErr::from)?;
722        Ok(PyArray1::from_vec(py, values))
723    }
724    fn __eq__(&self, value: f64) -> PyVariableExpression {
725        PyVariableExpression(self.0.eq(value))
726    }
727    fn __lt__(&self, value: f64) -> PyVariableExpression {
728        PyVariableExpression(self.0.lt(value))
729    }
730    fn __gt__(&self, value: f64) -> PyVariableExpression {
731        PyVariableExpression(self.0.gt(value))
732    }
733    fn __le__(&self, value: f64) -> PyVariableExpression {
734        PyVariableExpression(self.0.le(value))
735    }
736    fn __ge__(&self, value: f64) -> PyVariableExpression {
737        PyVariableExpression(self.0.ge(value))
738    }
739    fn __repr__(&self) -> String {
740        format!("{:?}", self.0)
741    }
742    fn __str__(&self) -> String {
743        format!("{}", self.0)
744    }
745}
746
747/// The magnitude of the given particle's polarization vector
748///
749/// This Variable simply represents the magnitude of the polarization vector of the particle
750/// with the index `beam`
751///
752/// Parameters
753/// ----------
754/// pol_magnitude : str
755///     Name of the auxiliary scalar column storing the magnitude of the polarization vector
756///
757/// See Also
758/// --------
759/// laddu.utils.vectors.Vec3.mag
760///
761#[pyclass(name = "PolMagnitude", module = "laddu")]
762#[derive(Clone, Serialize, Deserialize)]
763pub struct PyPolMagnitude(pub PolMagnitude);
764
765#[pymethods]
766impl PyPolMagnitude {
767    #[new]
768    fn new(pol_magnitude: String) -> Self {
769        Self(PolMagnitude::new(pol_magnitude))
770    }
771    /// The value of this Variable for the given Event
772    ///
773    /// Parameters
774    /// ----------
775    /// event : Event
776    ///     The Event upon which the Variable is calculated
777    ///
778    /// Returns
779    /// -------
780    /// value : float
781    ///     The value of the Variable for the given `event`
782    ///
783    fn value(&self, event: &PyEvent) -> PyResult<f64> {
784        let metadata = event
785            .metadata_opt()
786            .ok_or_else(|| PyValueError::new_err(
787                "This event is not associated with metadata; supply `p4_names`/`aux_names` when constructing it or evaluate via a Dataset.",
788            ))?;
789        let mut variable = self.0.clone();
790        variable.bind(metadata).map_err(PyErr::from)?;
791        Ok(variable.value(event.event.data()))
792    }
793    /// All values of this Variable on the given Dataset
794    ///
795    /// Parameters
796    /// ----------
797    /// dataset : Dataset
798    ///     The Dataset upon which the Variable is calculated
799    ///
800    /// Returns
801    /// -------
802    /// values : array_like
803    ///     The values of the Variable for each Event in the given `dataset`
804    ///
805    fn value_on<'py>(
806        &self,
807        py: Python<'py>,
808        dataset: &PyDataset,
809    ) -> PyResult<Bound<'py, PyArray1<f64>>> {
810        let values = self.0.value_on(&dataset.0).map_err(PyErr::from)?;
811        Ok(PyArray1::from_vec(py, values))
812    }
813    fn __eq__(&self, value: f64) -> PyVariableExpression {
814        PyVariableExpression(self.0.eq(value))
815    }
816    fn __lt__(&self, value: f64) -> PyVariableExpression {
817        PyVariableExpression(self.0.lt(value))
818    }
819    fn __gt__(&self, value: f64) -> PyVariableExpression {
820        PyVariableExpression(self.0.gt(value))
821    }
822    fn __le__(&self, value: f64) -> PyVariableExpression {
823        PyVariableExpression(self.0.le(value))
824    }
825    fn __ge__(&self, value: f64) -> PyVariableExpression {
826        PyVariableExpression(self.0.ge(value))
827    }
828    fn __repr__(&self) -> String {
829        format!("{:?}", self.0)
830    }
831    fn __str__(&self) -> String {
832        format!("{}", self.0)
833    }
834}
835
836/// A Variable used to define both the polarization angle and magnitude of the given particle``
837///
838/// This class combines ``laddu.PolAngle`` and ``laddu.PolMagnitude`` into a single
839/// object
840///
841/// Parameters
842/// ----------
843/// topology : laddu.Topology
844///     Topology describing the 2-to-2 production kinematics in the center-of-momentum frame.
845/// pol_magnitude : str
846///     Name of the auxiliary scalar storing the polarization magnitude
847/// pol_angle : str
848///     Name of the auxiliary scalar storing the polarization angle in radians
849///
850/// See Also
851/// --------
852/// laddu.PolAngle
853/// laddu.PolMagnitude
854///
855#[pyclass(name = "Polarization", module = "laddu")]
856#[derive(Clone)]
857pub struct PyPolarization(pub Polarization);
858#[pymethods]
859impl PyPolarization {
860    #[new]
861    #[pyo3(signature=(topology, *, pol_magnitude, pol_angle))]
862    fn new(topology: PyTopology, pol_magnitude: String, pol_angle: String) -> PyResult<Self> {
863        if pol_magnitude == pol_angle {
864            return Err(PyValueError::new_err(
865                "`pol_magnitude` and `pol_angle` must reference distinct auxiliary columns",
866            ));
867        }
868        let polarization = Polarization::new(topology.0.clone(), pol_magnitude, pol_angle);
869        Ok(PyPolarization(polarization))
870    }
871    /// The Variable representing the magnitude of the polarization vector
872    ///
873    /// Returns
874    /// -------
875    /// PolMagnitude
876    ///
877    #[getter]
878    fn pol_magnitude(&self) -> PyPolMagnitude {
879        PyPolMagnitude(self.0.pol_magnitude.clone())
880    }
881    /// The Variable representing the polar angle of the polarization vector
882    ///
883    /// Returns
884    /// -------
885    /// PolAngle
886    ///
887    #[getter]
888    fn pol_angle(&self) -> PyPolAngle {
889        PyPolAngle(self.0.pol_angle.clone())
890    }
891    fn __repr__(&self) -> String {
892        format!("{:?}", self.0)
893    }
894    fn __str__(&self) -> String {
895        format!("{}", self.0)
896    }
897}
898
899/// Mandelstam variables s, t, and u
900///
901/// By convention, the metric is chosen to be :math:`(+---)` and the variables are defined as follows
902/// (ignoring factors of :math:`c`):
903///
904/// .. math:: s = (p_1 + p_2)^2 = (p_3 + p_4)^2
905///
906/// .. math:: t = (p_1 - p_3)^2 = (p_4 - p_2)^2
907///
908/// .. math:: u = (p_1 - p_4)^2 = (p_3 - p_2)^2
909///
910/// Parameters
911/// ----------
912/// topology : laddu.Topology
913///     Topology describing the 2-to-2 kinematics whose Mandelstam channels should be evaluated.
914/// channel: {'s', 't', 'u', 'S', 'T', 'U'}
915///     The Mandelstam channel to calculate
916///
917/// Raises
918/// ------
919/// Exception
920///     If more than one particle list is empty
921/// ValueError
922///     If `channel` is not one of the valid options
923///
924/// Notes
925/// -----
926/// At most one of the input particles may be omitted by using an empty list. This will cause
927/// the calculation to use whichever equality listed above does not contain that particle.
928///
929/// By default, the first equality is used if no particle lists are empty.
930///
931#[pyclass(name = "Mandelstam", module = "laddu")]
932#[derive(Clone, Serialize, Deserialize)]
933pub struct PyMandelstam(pub Mandelstam);
934
935#[pymethods]
936impl PyMandelstam {
937    #[new]
938    fn new(topology: PyTopology, channel: &str) -> PyResult<Self> {
939        Ok(Self(Mandelstam::new(topology.0.clone(), channel.parse()?)))
940    }
941    /// The value of this Variable for the given Event
942    ///
943    /// Parameters
944    /// ----------
945    /// event : Event
946    ///     The Event upon which the Variable is calculated
947    ///
948    /// Returns
949    /// -------
950    /// value : float
951    ///     The value of the Variable for the given `event`
952    ///
953    fn value(&self, event: &PyEvent) -> PyResult<f64> {
954        let metadata = event
955            .metadata_opt()
956            .ok_or_else(|| PyValueError::new_err(
957                "This event is not associated with metadata; supply `p4_names`/`aux_names` when constructing it or evaluate via a Dataset.",
958            ))?;
959        let mut variable = self.0.clone();
960        variable.bind(metadata).map_err(PyErr::from)?;
961        Ok(variable.value(event.event.data()))
962    }
963    /// All values of this Variable on the given Dataset
964    ///
965    /// Parameters
966    /// ----------
967    /// dataset : Dataset
968    ///     The Dataset upon which the Variable is calculated
969    ///
970    /// Returns
971    /// -------
972    /// values : array_like
973    ///     The values of the Variable for each Event in the given `dataset`
974    ///
975    fn value_on<'py>(
976        &self,
977        py: Python<'py>,
978        dataset: &PyDataset,
979    ) -> PyResult<Bound<'py, PyArray1<f64>>> {
980        let values = self.0.value_on(&dataset.0).map_err(PyErr::from)?;
981        Ok(PyArray1::from_vec(py, values))
982    }
983    fn __eq__(&self, value: f64) -> PyVariableExpression {
984        PyVariableExpression(self.0.eq(value))
985    }
986    fn __lt__(&self, value: f64) -> PyVariableExpression {
987        PyVariableExpression(self.0.lt(value))
988    }
989    fn __gt__(&self, value: f64) -> PyVariableExpression {
990        PyVariableExpression(self.0.gt(value))
991    }
992    fn __le__(&self, value: f64) -> PyVariableExpression {
993        PyVariableExpression(self.0.le(value))
994    }
995    fn __ge__(&self, value: f64) -> PyVariableExpression {
996        PyVariableExpression(self.0.ge(value))
997    }
998    fn __repr__(&self) -> String {
999        format!("{:?}", self.0)
1000    }
1001    fn __str__(&self) -> String {
1002        format!("{}", self.0)
1003    }
1004}
1005
1006#[typetag::serde]
1007impl Variable for PyVariable {
1008    fn bind(&mut self, metadata: &DatasetMetadata) -> LadduResult<()> {
1009        match self {
1010            PyVariable::Mass(mass) => mass.0.bind(metadata),
1011            PyVariable::CosTheta(cos_theta) => cos_theta.0.bind(metadata),
1012            PyVariable::Phi(phi) => phi.0.bind(metadata),
1013            PyVariable::PolAngle(pol_angle) => pol_angle.0.bind(metadata),
1014            PyVariable::PolMagnitude(pol_magnitude) => pol_magnitude.0.bind(metadata),
1015            PyVariable::Mandelstam(mandelstam) => mandelstam.0.bind(metadata),
1016        }
1017    }
1018
1019    fn value_on(&self, dataset: &Dataset) -> LadduResult<Vec<f64>> {
1020        match self {
1021            PyVariable::Mass(mass) => mass.0.value_on(dataset),
1022            PyVariable::CosTheta(cos_theta) => cos_theta.0.value_on(dataset),
1023            PyVariable::Phi(phi) => phi.0.value_on(dataset),
1024            PyVariable::PolAngle(pol_angle) => pol_angle.0.value_on(dataset),
1025            PyVariable::PolMagnitude(pol_magnitude) => pol_magnitude.0.value_on(dataset),
1026            PyVariable::Mandelstam(mandelstam) => mandelstam.0.value_on(dataset),
1027        }
1028    }
1029
1030    fn value(&self, event: &EventData) -> f64 {
1031        match self {
1032            PyVariable::Mass(mass) => mass.0.value(event),
1033            PyVariable::CosTheta(cos_theta) => cos_theta.0.value(event),
1034            PyVariable::Phi(phi) => phi.0.value(event),
1035            PyVariable::PolAngle(pol_angle) => pol_angle.0.value(event),
1036            PyVariable::PolMagnitude(pol_magnitude) => pol_magnitude.0.value(event),
1037            PyVariable::Mandelstam(mandelstam) => mandelstam.0.value(event),
1038        }
1039    }
1040}