// laddu_core/utils/variables.rs

1use dyn_clone::DynClone;
2use serde::{Deserialize, Serialize};
3use std::sync::Arc;
4
5#[cfg(feature = "rayon")]
6use rayon::prelude::*;
7
8#[cfg(feature = "mpi")]
9use crate::mpi::LadduMPI;
10use crate::{
11    data::{Dataset, Event},
12    utils::{
13        enums::{Channel, Frame},
14        vectors::Vec3,
15    },
16    Float, LadduError,
17};
18#[cfg(feature = "mpi")]
19use mpi::{datatype::PartitionMut, topology::SimpleCommunicator, traits::*};
20
21/// Standard methods for extracting some value out of an [`Event`].
22#[typetag::serde(tag = "type")]
23pub trait Variable: DynClone + Send + Sync {
24    /// This method takes an [`Event`] and extracts a single value (like the mass of a particle).
25    fn value(&self, event: &Event) -> Float;
26
27    /// This method distributes the [`Variable::value`] method over each [`Event`] in a
28    /// [`Dataset`] (non-MPI version).
29    ///
30    /// # Notes
31    ///
32    /// This method is not intended to be called in analyses but rather in writing methods
33    /// that have `mpi`-feature-gated versions. Most users should just call [`Variable::value_on`] instead.
34    fn value_on_local(&self, dataset: &Arc<Dataset>) -> Vec<Float> {
35        #[cfg(feature = "rayon")]
36        let local_values: Vec<Float> = dataset.events.par_iter().map(|e| self.value(e)).collect();
37        #[cfg(not(feature = "rayon"))]
38        let local_values: Vec<Float> = dataset.events.iter().map(|e| self.value(e)).collect();
39        local_values
40    }
41
42    /// This method distributes the [`Variable::value`] method over each [`Event`] in a
43    /// [`Dataset`] (MPI-compatible version).
44    ///
45    /// # Notes
46    ///
47    /// This method is not intended to be called in analyses but rather in writing methods
48    /// that have `mpi`-feature-gated versions. Most users should just call [`Variable::value_on`] instead.
49    #[cfg(feature = "mpi")]
50    fn value_on_mpi(&self, dataset: &Arc<Dataset>, world: &SimpleCommunicator) -> Vec<Float> {
51        let local_weights = self.value_on_local(dataset);
52        let n_events = dataset.n_events();
53        let mut buffer: Vec<Float> = vec![0.0; n_events];
54        let (counts, displs) = world.get_counts_displs(n_events);
55        {
56            let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
57            world.all_gather_varcount_into(&local_weights, &mut partitioned_buffer);
58        }
59        buffer
60    }
61
62    /// This method distributes the [`Variable::value`] method over each [`Event`] in a
63    /// [`Dataset`].
64    fn value_on(&self, dataset: &Arc<Dataset>) -> Vec<Float> {
65        #[cfg(feature = "mpi")]
66        {
67            if let Some(world) = crate::mpi::get_world() {
68                return self.value_on_mpi(dataset, &world);
69            }
70        }
71        self.value_on_local(dataset)
72    }
73}
74dyn_clone::clone_trait_object!(Variable);
75
76/// A struct for obtaining the mass of a particle by indexing the four-momenta of an event, adding
77/// together multiple four-momenta if more than one index is given.
78#[derive(Clone, Debug, Serialize, Deserialize)]
79pub struct Mass(Vec<usize>);
80impl Mass {
81    /// Create a new [`Mass`] from the sum of the four-momenta at the given indices in the
82    /// [`Event`]'s `p4s` field.
83    pub fn new<T: AsRef<[usize]>>(constituents: T) -> Self {
84        Self(constituents.as_ref().into())
85    }
86}
87#[typetag::serde]
88impl Variable for Mass {
89    fn value(&self, event: &Event) -> Float {
90        event.get_p4_sum(&self.0).m()
91    }
92}
93
94/// A struct for obtaining the $`\cos\theta`$ (cosine of the polar angle) of a decay product in
95/// a given reference frame of its parent resonance.
96#[derive(Clone, Debug, Serialize, Deserialize)]
97pub struct CosTheta {
98    beam: usize,
99    recoil: Vec<usize>,
100    daughter: Vec<usize>,
101    resonance: Vec<usize>,
102    frame: Frame,
103}
104impl CosTheta {
105    /// Construct the angle given the four-momentum indices for each specified particle. Fields
106    /// which can take lists of more than one index will add the relevant four-momenta to make a
107    /// new particle from the constituents. See [`Frame`] for options regarding the reference
108    /// frame.
109    pub fn new<T: AsRef<[usize]>, U: AsRef<[usize]>, V: AsRef<[usize]>>(
110        beam: usize,
111        recoil: T,
112        daughter: U,
113        resonance: V,
114        frame: Frame,
115    ) -> Self {
116        Self {
117            beam,
118            recoil: recoil.as_ref().into(),
119            daughter: daughter.as_ref().into(),
120            resonance: resonance.as_ref().into(),
121            frame,
122        }
123    }
124}
125impl Default for CosTheta {
126    fn default() -> Self {
127        Self {
128            beam: 0,
129            recoil: vec![1],
130            daughter: vec![2],
131            resonance: vec![2, 3],
132            frame: Frame::Helicity,
133        }
134    }
135}
136#[typetag::serde]
137impl Variable for CosTheta {
138    fn value(&self, event: &Event) -> Float {
139        let beam = event.p4s[self.beam];
140        let recoil = event.get_p4_sum(&self.recoil);
141        let daughter = event.get_p4_sum(&self.daughter);
142        let resonance = event.get_p4_sum(&self.resonance);
143        let daughter_res = daughter.boost(&-resonance.beta());
144        match self.frame {
145            Frame::Helicity => {
146                let recoil_res = recoil.boost(&-resonance.beta());
147                let z = -recoil_res.vec3().unit();
148                let y = beam.vec3().cross(&-recoil.vec3()).unit();
149                let x = y.cross(&z);
150                let angles = Vec3::new(
151                    daughter_res.vec3().dot(&x),
152                    daughter_res.vec3().dot(&y),
153                    daughter_res.vec3().dot(&z),
154                );
155                angles.costheta()
156            }
157            Frame::GottfriedJackson => {
158                let beam_res = beam.boost(&-resonance.beta());
159                let z = beam_res.vec3().unit();
160                let y = beam.vec3().cross(&-recoil.vec3()).unit();
161                let x = y.cross(&z);
162                let angles = Vec3::new(
163                    daughter_res.vec3().dot(&x),
164                    daughter_res.vec3().dot(&y),
165                    daughter_res.vec3().dot(&z),
166                );
167                angles.costheta()
168            }
169        }
170    }
171}
172
173/// A struct for obtaining the $`\phi`$ angle (azimuthal angle) of a decay product in a given
174/// reference frame of its parent resonance.
175#[derive(Clone, Debug, Serialize, Deserialize)]
176pub struct Phi {
177    beam: usize,
178    recoil: Vec<usize>,
179    daughter: Vec<usize>,
180    resonance: Vec<usize>,
181    frame: Frame,
182}
183impl Phi {
184    /// Construct the angle given the four-momentum indices for each specified particle. Fields
185    /// which can take lists of more than one index will add the relevant four-momenta to make a
186    /// new particle from the constituents. See [`Frame`] for options regarding the reference
187    /// frame.
188    pub fn new<T: AsRef<[usize]>, U: AsRef<[usize]>, V: AsRef<[usize]>>(
189        beam: usize,
190        recoil: T,
191        daughter: U,
192        resonance: V,
193        frame: Frame,
194    ) -> Self {
195        Self {
196            beam,
197            recoil: recoil.as_ref().into(),
198            daughter: daughter.as_ref().into(),
199            resonance: resonance.as_ref().into(),
200            frame,
201        }
202    }
203}
204impl Default for Phi {
205    fn default() -> Self {
206        Self {
207            beam: 0,
208            recoil: vec![1],
209            daughter: vec![2],
210            resonance: vec![2, 3],
211            frame: Frame::Helicity,
212        }
213    }
214}
215#[typetag::serde]
216impl Variable for Phi {
217    fn value(&self, event: &Event) -> Float {
218        let beam = event.p4s[self.beam];
219        let recoil = event.get_p4_sum(&self.recoil);
220        let daughter = event.get_p4_sum(&self.daughter);
221        let resonance = event.get_p4_sum(&self.resonance);
222        let daughter_res = daughter.boost(&-resonance.beta());
223        match self.frame {
224            Frame::Helicity => {
225                let recoil_res = recoil.boost(&-resonance.beta());
226                let z = -recoil_res.vec3().unit();
227                let y = beam.vec3().cross(&-recoil.vec3()).unit();
228                let x = y.cross(&z);
229                let angles = Vec3::new(
230                    daughter_res.vec3().dot(&x),
231                    daughter_res.vec3().dot(&y),
232                    daughter_res.vec3().dot(&z),
233                );
234                angles.phi()
235            }
236            Frame::GottfriedJackson => {
237                let beam_res = beam.boost(&-resonance.beta());
238                let z = beam_res.vec3().unit();
239                let y = beam.vec3().cross(&-recoil.vec3()).unit();
240                let x = y.cross(&z);
241                let angles = Vec3::new(
242                    daughter_res.vec3().dot(&x),
243                    daughter_res.vec3().dot(&y),
244                    daughter_res.vec3().dot(&z),
245                );
246                angles.phi()
247            }
248        }
249    }
250}
251
252/// A struct for obtaining both spherical angles at the same time.
253#[derive(Clone, Debug, Serialize, Deserialize)]
254pub struct Angles {
255    /// See [`CosTheta`].
256    pub costheta: CosTheta,
257    /// See [`Phi`].
258    pub phi: Phi,
259}
260
261impl Angles {
262    /// Construct the angles given the four-momentum indices for each specified particle. Fields
263    /// which can take lists of more than one index will add the relevant four-momenta to make a
264    /// new particle from the constituents. See [`Frame`] for options regarding the reference
265    /// frame.
266    pub fn new<T: AsRef<[usize]>, U: AsRef<[usize]>, V: AsRef<[usize]>>(
267        beam: usize,
268        recoil: T,
269        daughter: U,
270        resonance: V,
271        frame: Frame,
272    ) -> Self {
273        Self {
274            costheta: CosTheta::new(beam, &recoil, &daughter, &resonance, frame),
275            phi: Phi {
276                beam,
277                recoil: recoil.as_ref().into(),
278                daughter: daughter.as_ref().into(),
279                resonance: resonance.as_ref().into(),
280                frame,
281            },
282        }
283    }
284}
285
286/// A struct defining the polarization angle for a beam relative to the production plane.
287#[derive(Clone, Debug, Serialize, Deserialize)]
288pub struct PolAngle {
289    beam: usize,
290    recoil: Vec<usize>,
291    beam_polarization: usize,
292}
293impl PolAngle {
294    /// Constructs the polarization angle given the four-momentum indices for each specified
295    /// particle. Fields which can take lists of more than one index will add the relevant
296    /// four-momenta to make a new particle from the constituents.
297    pub fn new<T: AsRef<[usize]>>(beam: usize, recoil: T, beam_polarization: usize) -> Self {
298        Self {
299            beam,
300            recoil: recoil.as_ref().into(),
301            beam_polarization,
302        }
303    }
304}
305#[typetag::serde]
306impl Variable for PolAngle {
307    fn value(&self, event: &Event) -> Float {
308        let beam = event.p4s[self.beam];
309        let recoil = event.get_p4_sum(&self.recoil);
310        let y = beam.vec3().cross(&-recoil.vec3()).unit();
311        Float::atan2(
312            y.dot(&event.aux[self.beam_polarization]),
313            beam.vec3()
314                .unit()
315                .dot(&event.aux[self.beam_polarization].cross(&y)),
316        )
317    }
318}
319
320/// A struct defining the polarization magnitude for a beam relative to the production plane.
321#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize)]
322pub struct PolMagnitude {
323    beam_polarization: usize,
324}
325
326impl PolMagnitude {
327    /// Constructs the polarization magnitude given the four-momentum index for the beam.
328    pub fn new(beam_polarization: usize) -> Self {
329        Self { beam_polarization }
330    }
331}
332#[typetag::serde]
333impl Variable for PolMagnitude {
334    fn value(&self, event: &Event) -> Float {
335        event.aux[self.beam_polarization].mag()
336    }
337}
338
339/// A struct for obtaining both the polarization angle and magnitude at the same time.
340#[derive(Clone, Debug, Serialize, Deserialize)]
341pub struct Polarization {
342    /// See [`PolMagnitude`].
343    pub pol_magnitude: PolMagnitude,
344    /// See [`PolAngle`].
345    pub pol_angle: PolAngle,
346}
347
348impl Polarization {
349    /// Constructs the polarization angle and magnitude given the four-momentum indices for
350    /// the beam and target (recoil) particle. Fields which can take lists of more than one index will add
351    /// the relevant four-momenta to make a new particle from the constituents.
352    pub fn new<T: AsRef<[usize]>>(beam: usize, recoil: T, beam_polarization: usize) -> Self {
353        Self {
354            pol_magnitude: PolMagnitude::new(beam_polarization),
355            pol_angle: PolAngle::new(beam, recoil, beam_polarization),
356        }
357    }
358}
359
360/// A struct used to calculate Mandelstam variables ($`s`$, $`t`$, or $`u`$).
361///
362/// By convention, the metric is chosen to be $`(+---)`$ and the variables are defined as follows
363/// (ignoring factors of $`c`$):
364///
365/// $`s = (p_1 + p_2)^2 = (p_3 + p_4)^2`$
366///
367/// $`t = (p_1 - p_3)^2 = (p_4 - p_2)^2`$
368///
369/// $`u = (p_1 - p_4)^2 = (p_3 - p_2)^2`$
370#[derive(Clone, Debug, Serialize, Deserialize)]
371pub struct Mandelstam {
372    p1: Vec<usize>,
373    p2: Vec<usize>,
374    p3: Vec<usize>,
375    p4: Vec<usize>,
376    missing: Option<u8>,
377    channel: Channel,
378}
379impl Mandelstam {
380    /// Constructs the Mandelstam variable for the given `channel` and particles.
381    /// Fields which can take lists of more than one index will add
382    /// the relevant four-momenta to make a new particle from the constituents.
383    pub fn new<T, U, V, W>(p1: T, p2: U, p3: V, p4: W, channel: Channel) -> Result<Self, LadduError>
384    where
385        T: AsRef<[usize]>,
386        U: AsRef<[usize]>,
387        V: AsRef<[usize]>,
388        W: AsRef<[usize]>,
389    {
390        let mut missing = None;
391        if p1.as_ref().is_empty() {
392            missing = Some(1)
393        }
394        if p2.as_ref().is_empty() {
395            if missing.is_none() {
396                missing = Some(2)
397            } else {
398                return Err(LadduError::Custom("A maximum of one particle may be ommitted while constructing a Mandelstam variable!".to_string()));
399            }
400        }
401        if p3.as_ref().is_empty() {
402            if missing.is_none() {
403                missing = Some(3)
404            } else {
405                return Err(LadduError::Custom("A maximum of one particle may be ommitted while constructing a Mandelstam variable!".to_string()));
406            }
407        }
408        if p4.as_ref().is_empty() {
409            if missing.is_none() {
410                missing = Some(4)
411            } else {
412                return Err(LadduError::Custom("A maximum of one particle may be ommitted while constructing a Mandelstam variable!".to_string()));
413            }
414        }
415        Ok(Self {
416            p1: p1.as_ref().into(),
417            p2: p2.as_ref().into(),
418            p3: p3.as_ref().into(),
419            p4: p4.as_ref().into(),
420            missing,
421            channel,
422        })
423    }
424}
425
426#[typetag::serde]
427impl Variable for Mandelstam {
428    fn value(&self, event: &Event) -> Float {
429        match self.channel {
430            Channel::S => match self.missing {
431                None | Some(3) | Some(4) => {
432                    let p1 = event.get_p4_sum(&self.p1);
433                    let p2 = event.get_p4_sum(&self.p2);
434                    (p1 + p2).mag2()
435                }
436                Some(1) | Some(2) => {
437                    let p3 = event.get_p4_sum(&self.p3);
438                    let p4 = event.get_p4_sum(&self.p4);
439                    (p3 + p4).mag2()
440                }
441                _ => unreachable!(),
442            },
443            Channel::T => match self.missing {
444                None | Some(2) | Some(4) => {
445                    let p1 = event.get_p4_sum(&self.p1);
446                    let p3 = event.get_p4_sum(&self.p3);
447                    (p1 - p3).mag2()
448                }
449                Some(1) | Some(3) => {
450                    let p2 = event.get_p4_sum(&self.p2);
451                    let p4 = event.get_p4_sum(&self.p4);
452                    (p4 - p2).mag2()
453                }
454                _ => unreachable!(),
455            },
456            Channel::U => match self.missing {
457                None | Some(2) | Some(3) => {
458                    let p1 = event.get_p4_sum(&self.p1);
459                    let p4 = event.get_p4_sum(&self.p4);
460                    (p1 - p4).mag2()
461                }
462                Some(1) | Some(4) => {
463                    let p2 = event.get_p4_sum(&self.p2);
464                    let p3 = event.get_p4_sum(&self.p3);
465                    (p3 - p2).mag2()
466                }
467                _ => unreachable!(),
468            },
469        }
470    }
471}
472
#[cfg(test)]
mod tests {
    use approx::assert_relative_eq;

    use crate::data::{test_dataset, test_event};

    use super::*;
    #[test]
    fn test_mass_single_particle() {
        let event = test_event();
        let mass = Mass::new([1]);
        assert_relative_eq!(mass.value(&event), 1.007);
    }

    #[test]
    fn test_mass_multiple_particles() {
        let event = test_event();
        let mass = Mass::new([2, 3]);
        assert_relative_eq!(
            mass.value(&event),
            1.37437863,
            epsilon = Float::EPSILON.sqrt()
        );
    }

    #[test]
    fn test_costheta_helicity() {
        let event = test_event();
        let costheta = CosTheta::new(0, [1], [2], [2, 3], Frame::Helicity);
        assert_relative_eq!(
            costheta.value(&event),
            -0.4611175,
            epsilon = Float::EPSILON.sqrt()
        );
    }

    #[test]
    fn test_phi_helicity() {
        let event = test_event();
        let phi = Phi::new(0, [1], [2], [2, 3], Frame::Helicity);
        assert_relative_eq!(
            phi.value(&event),
            -2.65746258,
            epsilon = Float::EPSILON.sqrt()
        );
    }

    #[test]
    fn test_costheta_gottfried_jackson() {
        let event = test_event();
        let costheta = CosTheta::new(0, [1], [2], [2, 3], Frame::GottfriedJackson);
        assert_relative_eq!(
            costheta.value(&event),
            0.09198832,
            epsilon = Float::EPSILON.sqrt()
        );
    }

    #[test]
    fn test_phi_gottfried_jackson() {
        let event = test_event();
        let phi = Phi::new(0, [1], [2], [2, 3], Frame::GottfriedJackson);
        assert_relative_eq!(
            phi.value(&event),
            -2.71391319,
            epsilon = Float::EPSILON.sqrt()
        );
    }

    #[test]
    fn test_costheta_phi_default_match_explicit() {
        // The documented defaults are beam 0, recoil [1], daughter [2], resonance [2, 3],
        // helicity frame.
        let event = test_event();
        assert_relative_eq!(
            CosTheta::default().value(&event),
            CosTheta::new(0, [1], [2], [2, 3], Frame::Helicity).value(&event)
        );
        assert_relative_eq!(
            Phi::default().value(&event),
            Phi::new(0, [1], [2], [2, 3], Frame::Helicity).value(&event)
        );
    }

    #[test]
    fn test_angles() {
        let event = test_event();
        let angles = Angles::new(0, [1], [2], [2, 3], Frame::Helicity);
        assert_relative_eq!(
            angles.costheta.value(&event),
            -0.4611175,
            epsilon = Float::EPSILON.sqrt()
        );
        assert_relative_eq!(
            angles.phi.value(&event),
            -2.65746258,
            epsilon = Float::EPSILON.sqrt()
        );
    }

    #[test]
    fn test_angles_gottfried_jackson() {
        let event = test_event();
        let angles = Angles::new(0, [1], [2], [2, 3], Frame::GottfriedJackson);
        assert_relative_eq!(
            angles.costheta.value(&event),
            0.09198832,
            epsilon = Float::EPSILON.sqrt()
        );
        assert_relative_eq!(
            angles.phi.value(&event),
            -2.71391319,
            epsilon = Float::EPSILON.sqrt()
        );
    }

    #[test]
    fn test_pol_angle() {
        let event = test_event();
        let pol_angle = PolAngle::new(0, vec![1], 0);
        assert_relative_eq!(
            pol_angle.value(&event),
            1.93592989,
            epsilon = Float::EPSILON.sqrt()
        );
    }

    #[test]
    fn test_pol_magnitude() {
        let event = test_event();
        let pol_magnitude = PolMagnitude::new(0);
        assert_relative_eq!(
            pol_magnitude.value(&event),
            0.38562805,
            epsilon = Float::EPSILON.sqrt()
        );
    }

    #[test]
    fn test_polarization() {
        let event = test_event();
        let polarization = Polarization::new(0, vec![1], 0);
        assert_relative_eq!(
            polarization.pol_angle.value(&event),
            1.93592989,
            epsilon = Float::EPSILON.sqrt()
        );
        assert_relative_eq!(
            polarization.pol_magnitude.value(&event),
            0.38562805,
            epsilon = Float::EPSILON.sqrt()
        );
    }

    #[test]
    fn test_mandelstam() {
        let event = test_event();
        let s = Mandelstam::new([0], [], [2, 3], [1], Channel::S).unwrap();
        let t = Mandelstam::new([0], [], [2, 3], [1], Channel::T).unwrap();
        let u = Mandelstam::new([0], [], [2, 3], [1], Channel::U).unwrap();
        let sp = Mandelstam::new([], [0], [1], [2, 3], Channel::S).unwrap();
        let tp = Mandelstam::new([], [0], [1], [2, 3], Channel::T).unwrap();
        let up = Mandelstam::new([], [0], [1], [2, 3], Channel::U).unwrap();
        assert_relative_eq!(
            s.value(&event),
            18.50401105,
            epsilon = Float::EPSILON.sqrt()
        );
        assert_relative_eq!(s.value(&event), sp.value(&event),);
        assert_relative_eq!(
            t.value(&event),
            -0.19222859,
            epsilon = Float::EPSILON.sqrt()
        );
        assert_relative_eq!(t.value(&event), tp.value(&event),);
        assert_relative_eq!(
            u.value(&event),
            -14.40419893,
            epsilon = Float::EPSILON.sqrt()
        );
        assert_relative_eq!(u.value(&event), up.value(&event),);
        let m2_beam = test_event().get_p4_sum([0]).m2();
        let m2_recoil = test_event().get_p4_sum([1]).m2();
        let m2_res = test_event().get_p4_sum([2, 3]).m2();
        assert_relative_eq!(
            s.value(&event) + t.value(&event) + u.value(&event) - m2_beam - m2_recoil - m2_res,
            1.00,
            epsilon = 1e-2
        );
        // Note: not very accurate, but considering the values in test_event only go to about 3
        // decimal places, this is probably okay
    }

    #[test]
    fn test_mandelstam_rejects_multiple_missing() {
        // At most one of the four particles may be given as an empty index list.
        assert!(Mandelstam::new([], [], [2, 3], [1], Channel::S).is_err());
        assert!(Mandelstam::new([0], [1], [], [], Channel::T).is_err());
        assert!(Mandelstam::new([], [0], [], [1], Channel::U).is_err());
    }

    #[test]
    fn test_variable_value_on() {
        let dataset = Arc::new(test_dataset());
        let mass = Mass::new(vec![2, 3]);

        let values = mass.value_on(&dataset);
        assert_eq!(values.len(), 1);
        assert_relative_eq!(values[0], 1.37437863, epsilon = Float::EPSILON.sqrt());
    }

    #[test]
    fn test_variable_value_on_local() {
        let dataset = Arc::new(test_dataset());
        let mass = Mass::new([2, 3]);

        let values = mass.value_on_local(&dataset);
        assert_eq!(values.len(), dataset.events.len());
        assert_relative_eq!(values[0], 1.37437863, epsilon = Float::EPSILON.sqrt());
    }
}