laddu_core/utils/
variables.rs

1use dyn_clone::DynClone;
2use serde::{Deserialize, Serialize};
3use std::fmt::{Debug, Display};
4
5#[cfg(feature = "rayon")]
6use rayon::prelude::*;
7
8#[cfg(feature = "mpi")]
9use crate::mpi::LadduMPI;
10use crate::{
11    data::{Dataset, Event},
12    utils::{
13        enums::{Channel, Frame},
14        vectors::Vec3,
15    },
16    Float, LadduError,
17};
18
19use auto_ops::impl_op_ex;
20
21#[cfg(feature = "mpi")]
22use mpi::{datatype::PartitionMut, topology::SimpleCommunicator, traits::*};
23
24/// Standard methods for extracting some value out of an [`Event`].
25#[typetag::serde(tag = "type")]
26pub trait Variable: DynClone + Send + Sync + Debug + Display {
27    /// This method takes an [`Event`] and extracts a single value (like the mass of a particle).
28    fn value(&self, event: &Event) -> Float;
29
30    /// This method distributes the [`Variable::value`] method over each [`Event`] in a
31    /// [`Dataset`] (non-MPI version).
32    ///
33    /// # Notes
34    ///
35    /// This method is not intended to be called in analyses but rather in writing methods
36    /// that have `mpi`-feature-gated versions. Most users should just call [`Variable::value_on`] instead.
37    fn value_on_local(&self, dataset: &Dataset) -> Vec<Float> {
38        #[cfg(feature = "rayon")]
39        let local_values: Vec<Float> = dataset.events.par_iter().map(|e| self.value(e)).collect();
40        #[cfg(not(feature = "rayon"))]
41        let local_values: Vec<Float> = dataset.events.iter().map(|e| self.value(e)).collect();
42        local_values
43    }
44
45    /// This method distributes the [`Variable::value`] method over each [`Event`] in a
46    /// [`Dataset`] (MPI-compatible version).
47    ///
48    /// # Notes
49    ///
50    /// This method is not intended to be called in analyses but rather in writing methods
51    /// that have `mpi`-feature-gated versions. Most users should just call [`Variable::value_on`] instead.
52    #[cfg(feature = "mpi")]
53    fn value_on_mpi(&self, dataset: &Dataset, world: &SimpleCommunicator) -> Vec<Float> {
54        let local_weights = self.value_on_local(dataset);
55        let n_events = dataset.n_events();
56        let mut buffer: Vec<Float> = vec![0.0; n_events];
57        let (counts, displs) = world.get_counts_displs(n_events);
58        {
59            let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
60            world.all_gather_varcount_into(&local_weights, &mut partitioned_buffer);
61        }
62        buffer
63    }
64
65    /// This method distributes the [`Variable::value`] method over each [`Event`] in a
66    /// [`Dataset`].
67    fn value_on(&self, dataset: &Dataset) -> Vec<Float> {
68        #[cfg(feature = "mpi")]
69        {
70            if let Some(world) = crate::mpi::get_world() {
71                return self.value_on_mpi(dataset, &world);
72            }
73        }
74        self.value_on_local(dataset)
75    }
76
77    /// Create an [`VariableExpression`] that evaluates to `self == val`
78    fn eq(&self, val: Float) -> VariableExpression
79    where
80        Self: std::marker::Sized + 'static,
81    {
82        VariableExpression::Eq(dyn_clone::clone_box(self), val)
83    }
84
85    /// Create an [`VariableExpression`] that evaluates to `self < val`
86    fn lt(&self, val: Float) -> VariableExpression
87    where
88        Self: std::marker::Sized + 'static,
89    {
90        VariableExpression::Lt(dyn_clone::clone_box(self), val)
91    }
92
93    /// Create an [`VariableExpression`] that evaluates to `self > val`
94    fn gt(&self, val: Float) -> VariableExpression
95    where
96        Self: std::marker::Sized + 'static,
97    {
98        VariableExpression::Gt(dyn_clone::clone_box(self), val)
99    }
100
101    /// Create an [`VariableExpression`] that evaluates to `self >= val`
102    fn ge(&self, val: Float) -> VariableExpression
103    where
104        Self: std::marker::Sized + 'static,
105    {
106        self.gt(val).or(&self.eq(val))
107    }
108
109    /// Create an [`VariableExpression`] that evaluates to `self <= val`
110    fn le(&self, val: Float) -> VariableExpression
111    where
112        Self: std::marker::Sized + 'static,
113    {
114        self.lt(val).or(&self.eq(val))
115    }
116}
117dyn_clone::clone_trait_object!(Variable);
118
119/// Expressions which can be used to compare [`Variable`]s to [`Float`]s.
120#[derive(Clone, Debug)]
121pub enum VariableExpression {
122    /// Expression which is true when the variable is equal to the float.
123    Eq(Box<dyn Variable>, Float),
124    /// Expression which is true when the variable is less than the float.
125    Lt(Box<dyn Variable>, Float),
126    /// Expression which is true when the variable is greater than the float.
127    Gt(Box<dyn Variable>, Float),
128    /// Expression which is true when both inner expressions are true.
129    And(Box<VariableExpression>, Box<VariableExpression>),
130    /// Expression which is true when either inner expression is true.
131    Or(Box<VariableExpression>, Box<VariableExpression>),
132    /// Expression which is true when the inner expression is false.
133    Not(Box<VariableExpression>),
134}
135
136impl VariableExpression {
137    /// Construct an [`VariableExpression::And`] from the current expression and another.
138    pub fn and(&self, rhs: &VariableExpression) -> VariableExpression {
139        VariableExpression::And(Box::new(self.clone()), Box::new(rhs.clone()))
140    }
141
142    /// Construct an [`VariableExpression::Or`] from the current expression and another.
143    pub fn or(&self, rhs: &VariableExpression) -> VariableExpression {
144        VariableExpression::Or(Box::new(self.clone()), Box::new(rhs.clone()))
145    }
146
147    /// Comple the [`VariableExpression`] into a [`CompiledExpression`].
148    pub(crate) fn compile(&self) -> CompiledExpression {
149        compile_expression(self.clone())
150    }
151}
152impl Display for VariableExpression {
153    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
154        match self {
155            VariableExpression::Eq(var, val) => {
156                write!(f, "({} == {})", var, val)
157            }
158            VariableExpression::Lt(var, val) => {
159                write!(f, "({} < {})", var, val)
160            }
161            VariableExpression::Gt(var, val) => {
162                write!(f, "({} > {})", var, val)
163            }
164            VariableExpression::And(lhs, rhs) => {
165                write!(f, "({} & {})", lhs, rhs)
166            }
167            VariableExpression::Or(lhs, rhs) => {
168                write!(f, "({} | {})", lhs, rhs)
169            }
170            VariableExpression::Not(inner) => {
171                write!(f, "!({})", inner)
172            }
173        }
174    }
175}
176
177/// A method which negates the given expression.
178pub fn not(expr: &VariableExpression) -> VariableExpression {
179    VariableExpression::Not(Box::new(expr.clone()))
180}
181
182#[rustfmt::skip]
183impl_op_ex!(& |lhs: &VariableExpression, rhs: &VariableExpression| -> VariableExpression{ lhs.and(rhs) });
184#[rustfmt::skip]
185impl_op_ex!(| |lhs: &VariableExpression, rhs: &VariableExpression| -> VariableExpression{ lhs.or(rhs) });
186#[rustfmt::skip]
187impl_op_ex!(! |exp: &VariableExpression| -> VariableExpression{ not(exp) });
188
189#[derive(Debug)]
190enum Opcode {
191    PushEq(usize, Float),
192    PushLt(usize, Float),
193    PushGt(usize, Float),
194    And,
195    Or,
196    Not,
197}
198
199pub(crate) struct CompiledExpression {
200    bytecode: Vec<Opcode>,
201    variables: Vec<Box<dyn Variable>>,
202}
203
204impl CompiledExpression {
205    /// Evaluate the [`CompiledExpression`] on a given [`Event`].
206    pub fn evaluate(&self, event: &Event) -> bool {
207        let mut stack = Vec::with_capacity(self.bytecode.len());
208
209        for op in &self.bytecode {
210            match op {
211                Opcode::PushEq(i, val) => stack.push(self.variables[*i].value(event) == *val),
212                Opcode::PushLt(i, val) => stack.push(self.variables[*i].value(event) < *val),
213                Opcode::PushGt(i, val) => stack.push(self.variables[*i].value(event) > *val),
214                Opcode::Not => {
215                    let a = stack.pop().unwrap();
216                    stack.push(!a);
217                }
218                Opcode::And => {
219                    let b = stack.pop().unwrap();
220                    let a = stack.pop().unwrap();
221                    stack.push(a && b);
222                }
223                Opcode::Or => {
224                    let b = stack.pop().unwrap();
225                    let a = stack.pop().unwrap();
226                    stack.push(a || b);
227                }
228            }
229        }
230
231        stack.pop().unwrap()
232    }
233}
234
235pub(crate) fn compile_expression(expr: VariableExpression) -> CompiledExpression {
236    let mut bytecode = Vec::new();
237    let mut variables: Vec<Box<dyn Variable>> = Vec::new();
238
239    fn compile(
240        expr: VariableExpression,
241        bytecode: &mut Vec<Opcode>,
242        variables: &mut Vec<Box<dyn Variable>>,
243    ) {
244        match expr {
245            VariableExpression::Eq(var, val) => {
246                variables.push(var);
247                bytecode.push(Opcode::PushEq(variables.len() - 1, val));
248            }
249            VariableExpression::Lt(var, val) => {
250                variables.push(var);
251                bytecode.push(Opcode::PushLt(variables.len() - 1, val));
252            }
253            VariableExpression::Gt(var, val) => {
254                variables.push(var);
255                bytecode.push(Opcode::PushGt(variables.len() - 1, val));
256            }
257            VariableExpression::And(lhs, rhs) => {
258                compile(*lhs, bytecode, variables);
259                compile(*rhs, bytecode, variables);
260                bytecode.push(Opcode::And);
261            }
262            VariableExpression::Or(lhs, rhs) => {
263                compile(*lhs, bytecode, variables);
264                compile(*rhs, bytecode, variables);
265                bytecode.push(Opcode::Or);
266            }
267            VariableExpression::Not(inner) => {
268                compile(*inner, bytecode, variables);
269                bytecode.push(Opcode::Not);
270            }
271        }
272    }
273
274    compile(expr, &mut bytecode, &mut variables);
275
276    CompiledExpression {
277        bytecode,
278        variables,
279    }
280}
281
/// Return an ascending copy of `indices`, leaving the input untouched.
///
/// Duplicate indices are preserved.
fn sort_indices<T: AsRef<[usize]>>(indices: T) -> Vec<usize> {
    let mut indices = indices.as_ref().to_vec();
    // Equal `usize` values are indistinguishable, so an unstable sort gives exactly the same
    // result as a stable one. It also avoids the stable sort's auxiliary allocation.
    indices.sort_unstable();
    indices
}
287
/// Render a list of indices as a comma-separated string, e.g. `[2, 3]` becomes `"2, 3"`.
///
/// An empty list renders as the empty string.
fn indices_to_string<T: AsRef<[usize]>>(indices: T) -> String {
    let mut rendered = String::new();
    for (position, index) in indices.as_ref().iter().enumerate() {
        if position > 0 {
            rendered.push_str(", ");
        }
        rendered.push_str(&index.to_string());
    }
    rendered
}
296
297/// A struct for obtaining the mass of a particle by indexing the four-momenta of an event, adding
298/// together multiple four-momenta if more than one index is given.
299#[derive(Clone, Debug, Serialize, Deserialize)]
300pub struct Mass(Vec<usize>);
301impl Mass {
302    /// Create a new [`Mass`] from the sum of the four-momenta at the given indices in the
303    /// [`Event`]'s `p4s` field.
304    pub fn new<T: AsRef<[usize]>>(constituents: T) -> Self {
305        Self(sort_indices(constituents))
306    }
307}
308impl Display for Mass {
309    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
310        write!(f, "Mass(constituents=[{}])", indices_to_string(&self.0))
311    }
312}
313#[typetag::serde]
314impl Variable for Mass {
315    fn value(&self, event: &Event) -> Float {
316        event.get_p4_sum(&self.0).m()
317    }
318}
319
320/// A struct for obtaining the $`\cos\theta`$ (cosine of the polar angle) of a decay product in
321/// a given reference frame of its parent resonance.
322#[derive(Clone, Debug, Serialize, Deserialize)]
323pub struct CosTheta {
324    beam: usize,
325    recoil: Vec<usize>,
326    daughter: Vec<usize>,
327    resonance: Vec<usize>,
328    frame: Frame,
329}
330impl Display for CosTheta {
331    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
332        write!(
333            f,
334            "CosTheta(beam={}, recoil=[{}], daughter=[{}], resonance=[{}], frame={})",
335            self.beam,
336            indices_to_string(&self.recoil),
337            indices_to_string(&self.daughter),
338            indices_to_string(&self.resonance),
339            self.frame
340        )
341    }
342}
343impl CosTheta {
344    /// Construct the angle given the four-momentum indices for each specified particle. Fields
345    /// which can take lists of more than one index will add the relevant four-momenta to make a
346    /// new particle from the constituents. See [`Frame`] for options regarding the reference
347    /// frame.
348    pub fn new<T: AsRef<[usize]>, U: AsRef<[usize]>, V: AsRef<[usize]>>(
349        beam: usize,
350        recoil: T,
351        daughter: U,
352        resonance: V,
353        frame: Frame,
354    ) -> Self {
355        Self {
356            beam,
357            recoil: recoil.as_ref().into(),
358            daughter: daughter.as_ref().into(),
359            resonance: resonance.as_ref().into(),
360            frame,
361        }
362    }
363}
364impl Default for CosTheta {
365    fn default() -> Self {
366        Self {
367            beam: 0,
368            recoil: vec![1],
369            daughter: vec![2],
370            resonance: vec![2, 3],
371            frame: Frame::Helicity,
372        }
373    }
374}
375#[typetag::serde]
376impl Variable for CosTheta {
377    fn value(&self, event: &Event) -> Float {
378        let beam = event.p4s[self.beam];
379        let recoil = event.get_p4_sum(&self.recoil);
380        let daughter = event.get_p4_sum(&self.daughter);
381        let resonance = event.get_p4_sum(&self.resonance);
382        let daughter_res = daughter.boost(&-resonance.beta());
383        match self.frame {
384            Frame::Helicity => {
385                let recoil_res = recoil.boost(&-resonance.beta());
386                let z = -recoil_res.vec3().unit();
387                let y = beam.vec3().cross(&-recoil.vec3()).unit();
388                let x = y.cross(&z);
389                let angles = Vec3::new(
390                    daughter_res.vec3().dot(&x),
391                    daughter_res.vec3().dot(&y),
392                    daughter_res.vec3().dot(&z),
393                );
394                angles.costheta()
395            }
396            Frame::GottfriedJackson => {
397                let beam_res = beam.boost(&-resonance.beta());
398                let z = beam_res.vec3().unit();
399                let y = beam.vec3().cross(&-recoil.vec3()).unit();
400                let x = y.cross(&z);
401                let angles = Vec3::new(
402                    daughter_res.vec3().dot(&x),
403                    daughter_res.vec3().dot(&y),
404                    daughter_res.vec3().dot(&z),
405                );
406                angles.costheta()
407            }
408        }
409    }
410}
411
412/// A struct for obtaining the $`\phi`$ angle (azimuthal angle) of a decay product in a given
413/// reference frame of its parent resonance.
414#[derive(Clone, Debug, Serialize, Deserialize)]
415pub struct Phi {
416    beam: usize,
417    recoil: Vec<usize>,
418    daughter: Vec<usize>,
419    resonance: Vec<usize>,
420    frame: Frame,
421}
422impl Display for Phi {
423    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
424        write!(
425            f,
426            "Phi(beam={}, recoil=[{}], daughter=[{}], resonance=[{}], frame={})",
427            self.beam,
428            indices_to_string(&self.recoil),
429            indices_to_string(&self.daughter),
430            indices_to_string(&self.resonance),
431            self.frame
432        )
433    }
434}
435impl Phi {
436    /// Construct the angle given the four-momentum indices for each specified particle. Fields
437    /// which can take lists of more than one index will add the relevant four-momenta to make a
438    /// new particle from the constituents. See [`Frame`] for options regarding the reference
439    /// frame.
440    pub fn new<T: AsRef<[usize]>, U: AsRef<[usize]>, V: AsRef<[usize]>>(
441        beam: usize,
442        recoil: T,
443        daughter: U,
444        resonance: V,
445        frame: Frame,
446    ) -> Self {
447        Self {
448            beam,
449            recoil: recoil.as_ref().into(),
450            daughter: daughter.as_ref().into(),
451            resonance: resonance.as_ref().into(),
452            frame,
453        }
454    }
455}
456impl Default for Phi {
457    fn default() -> Self {
458        Self {
459            beam: 0,
460            recoil: vec![1],
461            daughter: vec![2],
462            resonance: vec![2, 3],
463            frame: Frame::Helicity,
464        }
465    }
466}
467#[typetag::serde]
468impl Variable for Phi {
469    fn value(&self, event: &Event) -> Float {
470        let beam = event.p4s[self.beam];
471        let recoil = event.get_p4_sum(&self.recoil);
472        let daughter = event.get_p4_sum(&self.daughter);
473        let resonance = event.get_p4_sum(&self.resonance);
474        let daughter_res = daughter.boost(&-resonance.beta());
475        match self.frame {
476            Frame::Helicity => {
477                let recoil_res = recoil.boost(&-resonance.beta());
478                let z = -recoil_res.vec3().unit();
479                let y = beam.vec3().cross(&-recoil.vec3()).unit();
480                let x = y.cross(&z);
481                let angles = Vec3::new(
482                    daughter_res.vec3().dot(&x),
483                    daughter_res.vec3().dot(&y),
484                    daughter_res.vec3().dot(&z),
485                );
486                angles.phi()
487            }
488            Frame::GottfriedJackson => {
489                let beam_res = beam.boost(&-resonance.beta());
490                let z = beam_res.vec3().unit();
491                let y = beam.vec3().cross(&-recoil.vec3()).unit();
492                let x = y.cross(&z);
493                let angles = Vec3::new(
494                    daughter_res.vec3().dot(&x),
495                    daughter_res.vec3().dot(&y),
496                    daughter_res.vec3().dot(&z),
497                );
498                angles.phi()
499            }
500        }
501    }
502}
503
504/// A struct for obtaining both spherical angles at the same time.
505#[derive(Clone, Debug, Serialize, Deserialize)]
506pub struct Angles {
507    /// See [`CosTheta`].
508    pub costheta: CosTheta,
509    /// See [`Phi`].
510    pub phi: Phi,
511}
512
513impl Display for Angles {
514    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
515        write!(
516            f,
517            "Angles(beam={}, recoil=[{}], daughter=[{}], resonance=[{}], frame={})",
518            self.costheta.beam,
519            indices_to_string(&self.costheta.recoil),
520            indices_to_string(&self.costheta.daughter),
521            indices_to_string(&self.costheta.resonance),
522            self.costheta.frame
523        )
524    }
525}
526impl Angles {
527    /// Construct the angles given the four-momentum indices for each specified particle. Fields
528    /// which can take lists of more than one index will add the relevant four-momenta to make a
529    /// new particle from the constituents. See [`Frame`] for options regarding the reference
530    /// frame.
531    pub fn new<T: AsRef<[usize]>, U: AsRef<[usize]>, V: AsRef<[usize]>>(
532        beam: usize,
533        recoil: T,
534        daughter: U,
535        resonance: V,
536        frame: Frame,
537    ) -> Self {
538        Self {
539            costheta: CosTheta::new(beam, &recoil, &daughter, &resonance, frame),
540            phi: Phi {
541                beam,
542                recoil: recoil.as_ref().into(),
543                daughter: daughter.as_ref().into(),
544                resonance: resonance.as_ref().into(),
545                frame,
546            },
547        }
548    }
549}
550
551/// A struct defining the polarization angle for a beam relative to the production plane.
552#[derive(Clone, Debug, Serialize, Deserialize)]
553pub struct PolAngle {
554    beam: usize,
555    recoil: Vec<usize>,
556    beam_polarization: usize,
557}
558impl Display for PolAngle {
559    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
560        write!(
561            f,
562            "PolAngle(beam={}, recoil=[{}], beam_polarization={})",
563            self.beam,
564            indices_to_string(&self.recoil),
565            self.beam_polarization,
566        )
567    }
568}
569impl PolAngle {
570    /// Constructs the polarization angle given the four-momentum indices for each specified
571    /// particle. Fields which can take lists of more than one index will add the relevant
572    /// four-momenta to make a new particle from the constituents.
573    pub fn new<T: AsRef<[usize]>>(beam: usize, recoil: T, beam_polarization: usize) -> Self {
574        Self {
575            beam,
576            recoil: recoil.as_ref().into(),
577            beam_polarization,
578        }
579    }
580}
581#[typetag::serde]
582impl Variable for PolAngle {
583    fn value(&self, event: &Event) -> Float {
584        let beam = event.p4s[self.beam];
585        let recoil = event.get_p4_sum(&self.recoil);
586        let y = beam.vec3().cross(&-recoil.vec3()).unit();
587        Float::atan2(
588            y.dot(&event.aux[self.beam_polarization]),
589            beam.vec3()
590                .unit()
591                .dot(&event.aux[self.beam_polarization].cross(&y)),
592        )
593    }
594}
595
596/// A struct defining the polarization magnitude for a beam relative to the production plane.
597#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize)]
598pub struct PolMagnitude {
599    beam_polarization: usize,
600}
601impl Display for PolMagnitude {
602    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
603        write!(
604            f,
605            "PolMagnitude(beam_polarization={})",
606            self.beam_polarization,
607        )
608    }
609}
610impl PolMagnitude {
611    /// Constructs the polarization magnitude given the four-momentum index for the beam.
612    pub fn new(beam_polarization: usize) -> Self {
613        Self { beam_polarization }
614    }
615}
616#[typetag::serde]
617impl Variable for PolMagnitude {
618    fn value(&self, event: &Event) -> Float {
619        event.aux[self.beam_polarization].mag()
620    }
621}
622
623/// A struct for obtaining both the polarization angle and magnitude at the same time.
624#[derive(Clone, Debug, Serialize, Deserialize)]
625pub struct Polarization {
626    /// See [`PolMagnitude`].
627    pub pol_magnitude: PolMagnitude,
628    /// See [`PolAngle`].
629    pub pol_angle: PolAngle,
630}
631impl Display for Polarization {
632    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
633        write!(
634            f,
635            "Polarization(beam={}, recoil=[{}], beam_polarization={})",
636            self.pol_angle.beam,
637            indices_to_string(&self.pol_angle.recoil),
638            self.pol_angle.beam_polarization,
639        )
640    }
641}
642impl Polarization {
643    /// Constructs the polarization angle and magnitude given the four-momentum indices for
644    /// the beam and target (recoil) particle. Fields which can take lists of more than one index will add
645    /// the relevant four-momenta to make a new particle from the constituents.
646    pub fn new<T: AsRef<[usize]>>(beam: usize, recoil: T, beam_polarization: usize) -> Self {
647        Self {
648            pol_magnitude: PolMagnitude::new(beam_polarization),
649            pol_angle: PolAngle::new(beam, recoil, beam_polarization),
650        }
651    }
652}
653
654/// A struct used to calculate Mandelstam variables ($`s`$, $`t`$, or $`u`$).
655///
656/// By convention, the metric is chosen to be $`(+---)`$ and the variables are defined as follows
657/// (ignoring factors of $`c`$):
658///
659/// $`s = (p_1 + p_2)^2 = (p_3 + p_4)^2`$
660///
661/// $`t = (p_1 - p_3)^2 = (p_4 - p_2)^2`$
662///
663/// $`u = (p_1 - p_4)^2 = (p_3 - p_2)^2`$
664#[derive(Clone, Debug, Serialize, Deserialize)]
665pub struct Mandelstam {
666    p1: Vec<usize>,
667    p2: Vec<usize>,
668    p3: Vec<usize>,
669    p4: Vec<usize>,
670    missing: Option<u8>,
671    channel: Channel,
672}
673impl Display for Mandelstam {
674    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
675        write!(
676            f,
677            "Mandelstam(p1=[{}], p2=[{}], p3=[{}], p4=[{}], channel={})",
678            indices_to_string(&self.p1),
679            indices_to_string(&self.p2),
680            indices_to_string(&self.p3),
681            indices_to_string(&self.p4),
682            self.channel,
683        )
684    }
685}
686impl Mandelstam {
687    /// Constructs the Mandelstam variable for the given `channel` and particles.
688    /// Fields which can take lists of more than one index will add
689    /// the relevant four-momenta to make a new particle from the constituents.
690    pub fn new<T, U, V, W>(p1: T, p2: U, p3: V, p4: W, channel: Channel) -> Result<Self, LadduError>
691    where
692        T: AsRef<[usize]>,
693        U: AsRef<[usize]>,
694        V: AsRef<[usize]>,
695        W: AsRef<[usize]>,
696    {
697        let mut missing = None;
698        if p1.as_ref().is_empty() {
699            missing = Some(1)
700        }
701        if p2.as_ref().is_empty() {
702            if missing.is_none() {
703                missing = Some(2)
704            } else {
705                return Err(LadduError::Custom("A maximum of one particle may be ommitted while constructing a Mandelstam variable!".to_string()));
706            }
707        }
708        if p3.as_ref().is_empty() {
709            if missing.is_none() {
710                missing = Some(3)
711            } else {
712                return Err(LadduError::Custom("A maximum of one particle may be ommitted while constructing a Mandelstam variable!".to_string()));
713            }
714        }
715        if p4.as_ref().is_empty() {
716            if missing.is_none() {
717                missing = Some(4)
718            } else {
719                return Err(LadduError::Custom("A maximum of one particle may be ommitted while constructing a Mandelstam variable!".to_string()));
720            }
721        }
722        Ok(Self {
723            p1: p1.as_ref().into(),
724            p2: p2.as_ref().into(),
725            p3: p3.as_ref().into(),
726            p4: p4.as_ref().into(),
727            missing,
728            channel,
729        })
730    }
731}
732
733#[typetag::serde]
734impl Variable for Mandelstam {
735    fn value(&self, event: &Event) -> Float {
736        match self.channel {
737            Channel::S => match self.missing {
738                None | Some(3) | Some(4) => {
739                    let p1 = event.get_p4_sum(&self.p1);
740                    let p2 = event.get_p4_sum(&self.p2);
741                    (p1 + p2).mag2()
742                }
743                Some(1) | Some(2) => {
744                    let p3 = event.get_p4_sum(&self.p3);
745                    let p4 = event.get_p4_sum(&self.p4);
746                    (p3 + p4).mag2()
747                }
748                _ => unreachable!(),
749            },
750            Channel::T => match self.missing {
751                None | Some(2) | Some(4) => {
752                    let p1 = event.get_p4_sum(&self.p1);
753                    let p3 = event.get_p4_sum(&self.p3);
754                    (p1 - p3).mag2()
755                }
756                Some(1) | Some(3) => {
757                    let p2 = event.get_p4_sum(&self.p2);
758                    let p4 = event.get_p4_sum(&self.p4);
759                    (p4 - p2).mag2()
760                }
761                _ => unreachable!(),
762            },
763            Channel::U => match self.missing {
764                None | Some(2) | Some(3) => {
765                    let p1 = event.get_p4_sum(&self.p1);
766                    let p4 = event.get_p4_sum(&self.p4);
767                    (p1 - p4).mag2()
768                }
769                Some(1) | Some(4) => {
770                    let p2 = event.get_p4_sum(&self.p2);
771                    let p3 = event.get_p4_sum(&self.p3);
772                    (p3 - p2).mag2()
773                }
774                _ => unreachable!(),
775            },
776        }
777    }
778}
779
780#[cfg(test)]
781mod tests {
782    use super::*;
783    use crate::data::{test_dataset, test_event};
784    use approx::assert_relative_eq;
785
786    #[test]
787    fn test_mass_single_particle() {
788        let event = test_event();
789        let mass = Mass::new([1]);
790        assert_relative_eq!(mass.value(&event), 1.007);
791    }
792
793    #[test]
794    fn test_mass_multiple_particles() {
795        let event = test_event();
796        let mass = Mass::new([2, 3]);
797        assert_relative_eq!(
798            mass.value(&event),
799            1.37437863,
800            epsilon = Float::EPSILON.sqrt()
801        );
802    }
803
804    #[test]
805    fn test_mass_display() {
806        let mass = Mass::new([2, 3]);
807        assert_eq!(mass.to_string(), "Mass(constituents=[2, 3])");
808    }
809
810    #[test]
811    fn test_costheta_helicity() {
812        let event = test_event();
813        let costheta = CosTheta::new(0, [1], [2], [2, 3], Frame::Helicity);
814        assert_relative_eq!(
815            costheta.value(&event),
816            -0.4611175,
817            epsilon = Float::EPSILON.sqrt()
818        );
819    }
820
821    #[test]
822    fn test_costheta_display() {
823        let costheta = CosTheta::new(0, [1], [2], [2, 3], Frame::Helicity);
824        assert_eq!(
825            costheta.to_string(),
826            "CosTheta(beam=0, recoil=[1], daughter=[2], resonance=[2, 3], frame=Helicity)"
827        );
828    }
829
830    #[test]
831    fn test_phi_helicity() {
832        let event = test_event();
833        let phi = Phi::new(0, [1], [2], [2, 3], Frame::Helicity);
834        assert_relative_eq!(
835            phi.value(&event),
836            -2.65746258,
837            epsilon = Float::EPSILON.sqrt()
838        );
839    }
840
841    #[test]
842    fn test_phi_display() {
843        let phi = Phi::new(0, [1], [2], [2, 3], Frame::Helicity);
844        assert_eq!(
845            phi.to_string(),
846            "Phi(beam=0, recoil=[1], daughter=[2], resonance=[2, 3], frame=Helicity)"
847        );
848    }
849
850    #[test]
851    fn test_costheta_gottfried_jackson() {
852        let event = test_event();
853        let costheta = CosTheta::new(0, [1], [2], [2, 3], Frame::GottfriedJackson);
854        assert_relative_eq!(
855            costheta.value(&event),
856            0.09198832,
857            epsilon = Float::EPSILON.sqrt()
858        );
859    }
860
861    #[test]
862    fn test_phi_gottfried_jackson() {
863        let event = test_event();
864        let phi = Phi::new(0, [1], [2], [2, 3], Frame::GottfriedJackson);
865        assert_relative_eq!(
866            phi.value(&event),
867            -2.71391319,
868            epsilon = Float::EPSILON.sqrt()
869        );
870    }
871
872    #[test]
873    fn test_angles() {
874        let event = test_event();
875        let angles = Angles::new(0, [1], [2], [2, 3], Frame::Helicity);
876        assert_relative_eq!(
877            angles.costheta.value(&event),
878            -0.4611175,
879            epsilon = Float::EPSILON.sqrt()
880        );
881        assert_relative_eq!(
882            angles.phi.value(&event),
883            -2.65746258,
884            epsilon = Float::EPSILON.sqrt()
885        );
886    }
887
888    #[test]
889    fn test_angles_display() {
890        let angles = Angles::new(0, [1], [2], [2, 3], Frame::Helicity);
891        assert_eq!(
892            angles.to_string(),
893            "Angles(beam=0, recoil=[1], daughter=[2], resonance=[2, 3], frame=Helicity)"
894        );
895    }
896
897    #[test]
898    fn test_pol_angle() {
899        let event = test_event();
900        let pol_angle = PolAngle::new(0, vec![1], 0);
901        assert_relative_eq!(
902            pol_angle.value(&event),
903            1.93592989,
904            epsilon = Float::EPSILON.sqrt()
905        );
906    }
907
908    #[test]
909    fn test_pol_angle_display() {
910        let pol_angle = PolAngle::new(0, vec![1], 0);
911        assert_eq!(
912            pol_angle.to_string(),
913            "PolAngle(beam=0, recoil=[1], beam_polarization=0)"
914        );
915    }
916
917    #[test]
918    fn test_pol_magnitude() {
919        let event = test_event();
920        let pol_magnitude = PolMagnitude::new(0);
921        assert_relative_eq!(
922            pol_magnitude.value(&event),
923            0.38562805,
924            epsilon = Float::EPSILON.sqrt()
925        );
926    }
927
928    #[test]
929    fn test_pol_magnitude_display() {
930        let pol_magnitude = PolMagnitude::new(0);
931        assert_eq!(
932            pol_magnitude.to_string(),
933            "PolMagnitude(beam_polarization=0)"
934        );
935    }
936
937    #[test]
938    fn test_polarization() {
939        let event = test_event();
940        let polarization = Polarization::new(0, vec![1], 0);
941        assert_relative_eq!(
942            polarization.pol_angle.value(&event),
943            1.93592989,
944            epsilon = Float::EPSILON.sqrt()
945        );
946        assert_relative_eq!(
947            polarization.pol_magnitude.value(&event),
948            0.38562805,
949            epsilon = Float::EPSILON.sqrt()
950        );
951    }
952
953    #[test]
954    fn test_polarization_display() {
955        let polarization = Polarization::new(0, vec![1], 0);
956        assert_eq!(
957            polarization.to_string(),
958            "Polarization(beam=0, recoil=[1], beam_polarization=0)"
959        );
960    }
961
962    #[test]
963    fn test_mandelstam() {
964        let event = test_event();
965        let s = Mandelstam::new([0], [], [2, 3], [1], Channel::S).unwrap();
966        let t = Mandelstam::new([0], [], [2, 3], [1], Channel::T).unwrap();
967        let u = Mandelstam::new([0], [], [2, 3], [1], Channel::U).unwrap();
968        let sp = Mandelstam::new([], [0], [1], [2, 3], Channel::S).unwrap();
969        let tp = Mandelstam::new([], [0], [1], [2, 3], Channel::T).unwrap();
970        let up = Mandelstam::new([], [0], [1], [2, 3], Channel::U).unwrap();
971        assert_relative_eq!(
972            s.value(&event),
973            18.50401105,
974            epsilon = Float::EPSILON.sqrt()
975        );
976        assert_relative_eq!(s.value(&event), sp.value(&event),);
977        assert_relative_eq!(
978            t.value(&event),
979            -0.19222859,
980            epsilon = Float::EPSILON.sqrt()
981        );
982        assert_relative_eq!(t.value(&event), tp.value(&event),);
983        assert_relative_eq!(
984            u.value(&event),
985            -14.40419893,
986            epsilon = Float::EPSILON.sqrt()
987        );
988        assert_relative_eq!(u.value(&event), up.value(&event),);
989        let m2_beam = test_event().get_p4_sum([0]).m2();
990        let m2_recoil = test_event().get_p4_sum([1]).m2();
991        let m2_res = test_event().get_p4_sum([2, 3]).m2();
992        assert_relative_eq!(
993            s.value(&event) + t.value(&event) + u.value(&event) - m2_beam - m2_recoil - m2_res,
994            1.00,
995            epsilon = 1e-2
996        );
997        // Note: not very accurate, but considering the values in test_event only go to about 3
998        // decimal places, this is probably okay
999    }
1000
1001    #[test]
1002    fn test_mandelstam_display() {
1003        let s = Mandelstam::new([0], [], [2, 3], [1], Channel::S).unwrap();
1004        assert_eq!(
1005            s.to_string(),
1006            "Mandelstam(p1=[0], p2=[], p3=[2, 3], p4=[1], channel=s)"
1007        );
1008    }
1009
1010    #[test]
1011    fn test_variable_value_on() {
1012        let dataset = test_dataset();
1013        let mass = Mass::new(vec![2, 3]);
1014
1015        let values = mass.value_on(&dataset);
1016        assert_eq!(values.len(), 1);
1017        assert_relative_eq!(values[0], 1.37437863, epsilon = Float::EPSILON.sqrt());
1018    }
1019}