// laddu_core/utils/variables.rs

1use dyn_clone::DynClone;
2use serde::{Deserialize, Serialize};
3use std::sync::Arc;
4
5#[cfg(feature = "rayon")]
6use rayon::prelude::*;
7
8#[cfg(feature = "mpi")]
9use crate::mpi::LadduMPI;
10use crate::{
11    data::{Dataset, Event},
12    utils::{
13        enums::{Channel, Frame},
14        vectors::Vec3,
15    },
16    Float, LadduError,
17};
18#[cfg(feature = "mpi")]
19use mpi::{datatype::PartitionMut, topology::SimpleCommunicator, traits::*};
20
21/// Standard methods for extracting some value out of an [`Event`].
22#[typetag::serde(tag = "type")]
23pub trait Variable: DynClone + Send + Sync {
24    /// This method takes an [`Event`] and extracts a single value (like the mass of a particle).
25    fn value(&self, event: &Event) -> Float;
26
27    /// This method distributes the [`Variable::value`] method over each [`Event`] in a
28    /// [`Dataset`] (non-MPI version).
29    ///
30    /// # Notes
31    ///
32    /// This method is not intended to be called in analyses but rather in writing methods
33    /// that have `mpi`-feature-gated versions. Most users should just call [`Variable::value_on`] instead.
34    fn value_on_local(&self, dataset: &Arc<Dataset>) -> Vec<Float> {
35        #[cfg(feature = "rayon")]
36        let local_values: Vec<Float> = dataset.events.par_iter().map(|e| self.value(e)).collect();
37        #[cfg(not(feature = "rayon"))]
38        let local_values: Vec<Float> = dataset.events.iter().map(|e| self.value(e)).collect();
39        local_values
40    }
41
42    /// This method distributes the [`Variable::value`] method over each [`Event`] in a
43    /// [`Dataset`] (MPI-compatible version).
44    ///
45    /// # Notes
46    ///
47    /// This method is not intended to be called in analyses but rather in writing methods
48    /// that have `mpi`-feature-gated versions. Most users should just call [`Variable::value_on`] instead.
49    #[cfg(feature = "mpi")]
50    fn value_on_mpi(&self, dataset: &Arc<Dataset>, world: &SimpleCommunicator) -> Vec<Float> {
51        let local_weights = self.value_on_local(dataset);
52        let n_events = dataset.n_events();
53        let mut buffer: Vec<Float> = vec![0.0; n_events];
54        let (counts, displs) = world.get_counts_displs(n_events);
55        {
56            let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
57            world.all_gather_varcount_into(&local_weights, &mut partitioned_buffer);
58        }
59        buffer
60    }
61
62    /// This method distributes the [`Variable::value`] method over each [`Event`] in a
63    /// [`Dataset`].
64    fn value_on(&self, dataset: &Arc<Dataset>) -> Vec<Float> {
65        #[cfg(feature = "mpi")]
66        {
67            if let Some(world) = crate::mpi::get_world() {
68                return self.value_on_mpi(dataset, &world);
69            }
70        }
71        self.value_on_local(dataset)
72    }
73}
74dyn_clone::clone_trait_object!(Variable);
75
76/// A struct for obtaining the mass of a particle by indexing the four-momenta of an event, adding
77/// together multiple four-momenta if more than one index is given.
78#[derive(Clone, Debug, Serialize, Deserialize)]
79pub struct Mass(Vec<usize>);
80impl Mass {
81    /// Create a new [`Mass`] from the sum of the four-momenta at the given indices in the
82    /// [`Event`]'s `p4s` field.
83    pub fn new<T: AsRef<[usize]>>(constituents: T) -> Self {
84        Self(constituents.as_ref().into())
85    }
86}
87#[typetag::serde]
88impl Variable for Mass {
89    fn value(&self, event: &Event) -> Float {
90        event.get_p4_sum(&self.0).m()
91    }
92}
93
94/// A struct for obtaining the $`\cos\theta`$ (cosine of the polar angle) of a decay product in
95/// a given reference frame of its parent resonance.
96#[derive(Clone, Debug, Serialize, Deserialize)]
97pub struct CosTheta {
98    beam: usize,
99    recoil: Vec<usize>,
100    daughter: Vec<usize>,
101    resonance: Vec<usize>,
102    frame: Frame,
103}
104impl CosTheta {
105    /// Construct the angle given the four-momentum indices for each specified particle. Fields
106    /// which can take lists of more than one index will add the relevant four-momenta to make a
107    /// new particle from the constituents. See [`Frame`] for options regarding the reference
108    /// frame.
109    pub fn new<T: AsRef<[usize]>, U: AsRef<[usize]>, V: AsRef<[usize]>>(
110        beam: usize,
111        recoil: T,
112        daughter: U,
113        resonance: V,
114        frame: Frame,
115    ) -> Self {
116        Self {
117            beam,
118            recoil: recoil.as_ref().into(),
119            daughter: daughter.as_ref().into(),
120            resonance: resonance.as_ref().into(),
121            frame,
122        }
123    }
124}
125impl Default for CosTheta {
126    fn default() -> Self {
127        Self {
128            beam: 0,
129            recoil: vec![1],
130            daughter: vec![2],
131            resonance: vec![2, 3],
132            frame: Frame::Helicity,
133        }
134    }
135}
136#[typetag::serde]
137impl Variable for CosTheta {
138    fn value(&self, event: &Event) -> Float {
139        let beam = event.p4s[self.beam];
140        let recoil = event.get_p4_sum(&self.recoil);
141        let daughter = event.get_p4_sum(&self.daughter);
142        let resonance = event.get_p4_sum(&self.resonance);
143        let daughter_res = daughter.boost(&-resonance.beta());
144        match self.frame {
145            Frame::Helicity => {
146                let recoil_res = recoil.boost(&-resonance.beta());
147                let z = -recoil_res.vec3().unit();
148                let y = beam.vec3().cross(&-recoil.vec3()).unit();
149                let x = y.cross(&z);
150                let angles = Vec3::new(
151                    daughter_res.vec3().dot(&x),
152                    daughter_res.vec3().dot(&y),
153                    daughter_res.vec3().dot(&z),
154                );
155                angles.costheta()
156            }
157            Frame::GottfriedJackson => {
158                let beam_res = beam.boost(&-resonance.beta());
159                let z = beam_res.vec3().unit();
160                let y = beam.vec3().cross(&-recoil.vec3()).unit();
161                let x = y.cross(&z);
162                let angles = Vec3::new(
163                    daughter_res.vec3().dot(&x),
164                    daughter_res.vec3().dot(&y),
165                    daughter_res.vec3().dot(&z),
166                );
167                angles.costheta()
168            }
169        }
170    }
171}
172
173/// A struct for obtaining the $`\phi`$ angle (azimuthal angle) of a decay product in a given
174/// reference frame of its parent resonance.
175#[derive(Clone, Debug, Serialize, Deserialize)]
176pub struct Phi {
177    beam: usize,
178    recoil: Vec<usize>,
179    daughter: Vec<usize>,
180    resonance: Vec<usize>,
181    frame: Frame,
182}
183impl Phi {
184    /// Construct the angle given the four-momentum indices for each specified particle. Fields
185    /// which can take lists of more than one index will add the relevant four-momenta to make a
186    /// new particle from the constituents. See [`Frame`] for options regarding the reference
187    /// frame.
188    pub fn new<T: AsRef<[usize]>, U: AsRef<[usize]>, V: AsRef<[usize]>>(
189        beam: usize,
190        recoil: T,
191        daughter: U,
192        resonance: V,
193        frame: Frame,
194    ) -> Self {
195        Self {
196            beam,
197            recoil: recoil.as_ref().into(),
198            daughter: daughter.as_ref().into(),
199            resonance: resonance.as_ref().into(),
200            frame,
201        }
202    }
203}
204impl Default for Phi {
205    fn default() -> Self {
206        Self {
207            beam: 0,
208            recoil: vec![1],
209            daughter: vec![2],
210            resonance: vec![2, 3],
211            frame: Frame::Helicity,
212        }
213    }
214}
215#[typetag::serde]
216impl Variable for Phi {
217    fn value(&self, event: &Event) -> Float {
218        let beam = event.p4s[self.beam];
219        let recoil = event.get_p4_sum(&self.recoil);
220        let daughter = event.get_p4_sum(&self.daughter);
221        let resonance = event.get_p4_sum(&self.resonance);
222        let daughter_res = daughter.boost(&-resonance.beta());
223        match self.frame {
224            Frame::Helicity => {
225                let recoil_res = recoil.boost(&-resonance.beta());
226                let z = -recoil_res.vec3().unit();
227                let y = beam.vec3().cross(&-recoil.vec3()).unit();
228                let x = y.cross(&z);
229                let angles = Vec3::new(
230                    daughter_res.vec3().dot(&x),
231                    daughter_res.vec3().dot(&y),
232                    daughter_res.vec3().dot(&z),
233                );
234                angles.phi()
235            }
236            Frame::GottfriedJackson => {
237                let beam_res = beam.boost(&-resonance.beta());
238                let z = beam_res.vec3().unit();
239                let y = beam.vec3().cross(&-recoil.vec3()).unit();
240                let x = y.cross(&z);
241                let angles = Vec3::new(
242                    daughter_res.vec3().dot(&x),
243                    daughter_res.vec3().dot(&y),
244                    daughter_res.vec3().dot(&z),
245                );
246                angles.phi()
247            }
248        }
249    }
250}
251
252/// A struct for obtaining both spherical angles at the same time.
253#[derive(Clone, Debug, Serialize, Deserialize)]
254pub struct Angles {
255    /// See [`CosTheta`].
256    pub costheta: CosTheta,
257    /// See [`Phi`].
258    pub phi: Phi,
259}
260
261impl Angles {
262    /// Construct the angles given the four-momentum indices for each specified particle. Fields
263    /// which can take lists of more than one index will add the relevant four-momenta to make a
264    /// new particle from the constituents. See [`Frame`] for options regarding the reference
265    /// frame.
266    pub fn new<T: AsRef<[usize]>, U: AsRef<[usize]>, V: AsRef<[usize]>>(
267        beam: usize,
268        recoil: T,
269        daughter: U,
270        resonance: V,
271        frame: Frame,
272    ) -> Self {
273        Self {
274            costheta: CosTheta::new(beam, &recoil, &daughter, &resonance, frame),
275            phi: Phi {
276                beam,
277                recoil: recoil.as_ref().into(),
278                daughter: daughter.as_ref().into(),
279                resonance: resonance.as_ref().into(),
280                frame,
281            },
282        }
283    }
284}
285
286/// A struct defining the polarization angle for a beam relative to the production plane.
287#[derive(Clone, Debug, Serialize, Deserialize)]
288pub struct PolAngle {
289    beam: usize,
290    recoil: Vec<usize>,
291}
292impl PolAngle {
293    /// Constructs the polarization angle given the four-momentum indices for each specified
294    /// particle. Fields which can take lists of more than one index will add the relevant
295    /// four-momenta to make a new particle from the constituents.
296    pub fn new<T: AsRef<[usize]>>(beam: usize, recoil: T) -> Self {
297        Self {
298            beam,
299            recoil: recoil.as_ref().into(),
300        }
301    }
302}
303#[typetag::serde]
304impl Variable for PolAngle {
305    fn value(&self, event: &Event) -> Float {
306        let beam = event.p4s[self.beam];
307        let recoil = event.get_p4_sum(&self.recoil);
308        let y = beam.vec3().cross(&-recoil.vec3()).unit();
309        Float::atan2(
310            y.dot(&event.eps[self.beam]),
311            beam.vec3().unit().dot(&event.eps[self.beam].cross(&y)),
312        )
313    }
314}
315
316/// A struct defining the polarization magnitude for a beam relative to the production plane.
317#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize)]
318pub struct PolMagnitude {
319    beam: usize,
320}
321
322impl PolMagnitude {
323    /// Constructs the polarization magnitude given the four-momentum index for the beam.
324    pub fn new(beam: usize) -> Self {
325        Self { beam }
326    }
327}
328#[typetag::serde]
329impl Variable for PolMagnitude {
330    fn value(&self, event: &Event) -> Float {
331        event.eps[self.beam].mag()
332    }
333}
334
335/// A struct for obtaining both the polarization angle and magnitude at the same time.
336#[derive(Clone, Debug, Serialize, Deserialize)]
337pub struct Polarization {
338    /// See [`PolMagnitude`].
339    pub pol_magnitude: PolMagnitude,
340    /// See [`PolAngle`].
341    pub pol_angle: PolAngle,
342}
343
344impl Polarization {
345    /// Constructs the polarization angle and magnitude given the four-momentum indices for
346    /// the beam and target (recoil) particle. Fields which can take lists of more than one index will add
347    /// the relevant four-momenta to make a new particle from the constituents.
348    pub fn new<T: AsRef<[usize]>>(beam: usize, recoil: T) -> Self {
349        Self {
350            pol_magnitude: PolMagnitude::new(beam),
351            pol_angle: PolAngle::new(beam, recoil),
352        }
353    }
354}
355
356/// A struct used to calculate Mandelstam variables ($`s`$, $`t`$, or $`u`$).
357///
358/// By convention, the metric is chosen to be $`(+---)`$ and the variables are defined as follows
359/// (ignoring factors of $`c`$):
360///
361/// $`s = (p_1 + p_2)^2 = (p_3 + p_4)^2`$
362///
363/// $`t = (p_1 - p_3)^2 = (p_4 - p_2)^2`$
364///
365/// $`u = (p_1 - p_4)^2 = (p_3 - p_2)^2`$
366#[derive(Clone, Debug, Serialize, Deserialize)]
367pub struct Mandelstam {
368    p1: Vec<usize>,
369    p2: Vec<usize>,
370    p3: Vec<usize>,
371    p4: Vec<usize>,
372    missing: Option<u8>,
373    channel: Channel,
374}
375impl Mandelstam {
376    /// Constructs the Mandelstam variable for the given `channel` and particles.
377    /// Fields which can take lists of more than one index will add
378    /// the relevant four-momenta to make a new particle from the constituents.
379    pub fn new<T, U, V, W>(p1: T, p2: U, p3: V, p4: W, channel: Channel) -> Result<Self, LadduError>
380    where
381        T: AsRef<[usize]>,
382        U: AsRef<[usize]>,
383        V: AsRef<[usize]>,
384        W: AsRef<[usize]>,
385    {
386        let mut missing = None;
387        if p1.as_ref().is_empty() {
388            missing = Some(1)
389        }
390        if p2.as_ref().is_empty() {
391            if missing.is_none() {
392                missing = Some(2)
393            } else {
394                return Err(LadduError::Custom("A maximum of one particle may be ommitted while constructing a Mandelstam variable!".to_string()));
395            }
396        }
397        if p3.as_ref().is_empty() {
398            if missing.is_none() {
399                missing = Some(3)
400            } else {
401                return Err(LadduError::Custom("A maximum of one particle may be ommitted while constructing a Mandelstam variable!".to_string()));
402            }
403        }
404        if p4.as_ref().is_empty() {
405            if missing.is_none() {
406                missing = Some(4)
407            } else {
408                return Err(LadduError::Custom("A maximum of one particle may be ommitted while constructing a Mandelstam variable!".to_string()));
409            }
410        }
411        Ok(Self {
412            p1: p1.as_ref().into(),
413            p2: p2.as_ref().into(),
414            p3: p3.as_ref().into(),
415            p4: p4.as_ref().into(),
416            missing,
417            channel,
418        })
419    }
420}
421
422#[typetag::serde]
423impl Variable for Mandelstam {
424    fn value(&self, event: &Event) -> Float {
425        match self.channel {
426            Channel::S => match self.missing {
427                None | Some(3) | Some(4) => {
428                    let p1 = event.get_p4_sum(&self.p1);
429                    let p2 = event.get_p4_sum(&self.p2);
430                    (p1 + p2).mag2()
431                }
432                Some(1) | Some(2) => {
433                    let p3 = event.get_p4_sum(&self.p3);
434                    let p4 = event.get_p4_sum(&self.p4);
435                    (p3 + p4).mag2()
436                }
437                _ => unreachable!(),
438            },
439            Channel::T => match self.missing {
440                None | Some(2) | Some(4) => {
441                    let p1 = event.get_p4_sum(&self.p1);
442                    let p3 = event.get_p4_sum(&self.p3);
443                    (p1 - p3).mag2()
444                }
445                Some(1) | Some(3) => {
446                    let p2 = event.get_p4_sum(&self.p2);
447                    let p4 = event.get_p4_sum(&self.p4);
448                    (p4 - p2).mag2()
449                }
450                _ => unreachable!(),
451            },
452            Channel::U => match self.missing {
453                None | Some(2) | Some(3) => {
454                    let p1 = event.get_p4_sum(&self.p1);
455                    let p4 = event.get_p4_sum(&self.p4);
456                    (p1 - p4).mag2()
457                }
458                Some(1) | Some(4) => {
459                    let p2 = event.get_p4_sum(&self.p2);
460                    let p3 = event.get_p4_sum(&self.p3);
461                    (p3 - p2).mag2()
462                }
463                _ => unreachable!(),
464            },
465        }
466    }
467}
468
#[cfg(test)]
mod tests {
    use approx::assert_relative_eq;

    use crate::data::{test_dataset, test_event};

    use super::*;

    #[test]
    fn test_mass_single_particle() {
        let event = test_event();
        let mass = Mass::new([1]);
        assert_relative_eq!(mass.value(&event), 1.007);
    }

    #[test]
    fn test_mass_multiple_particles() {
        let event = test_event();
        let mass = Mass::new([2, 3]);
        assert_relative_eq!(
            mass.value(&event),
            1.37437863,
            epsilon = Float::EPSILON.sqrt()
        );
    }

    #[test]
    fn test_costheta_helicity() {
        let event = test_event();
        let costheta = CosTheta::new(0, [1], [2], [2, 3], Frame::Helicity);
        assert_relative_eq!(
            costheta.value(&event),
            -0.4611175,
            epsilon = Float::EPSILON.sqrt()
        );
    }

    #[test]
    fn test_phi_helicity() {
        let event = test_event();
        let phi = Phi::new(0, [1], [2], [2, 3], Frame::Helicity);
        assert_relative_eq!(
            phi.value(&event),
            -2.65746258,
            epsilon = Float::EPSILON.sqrt()
        );
    }

    #[test]
    fn test_costheta_gottfried_jackson() {
        let event = test_event();
        let costheta = CosTheta::new(0, [1], [2], [2, 3], Frame::GottfriedJackson);
        assert_relative_eq!(
            costheta.value(&event),
            0.09198832,
            epsilon = Float::EPSILON.sqrt()
        );
    }

    #[test]
    fn test_phi_gottfried_jackson() {
        let event = test_event();
        let phi = Phi::new(0, [1], [2], [2, 3], Frame::GottfriedJackson);
        assert_relative_eq!(
            phi.value(&event),
            -2.71391319,
            epsilon = Float::EPSILON.sqrt()
        );
    }

    #[test]
    fn test_angles() {
        let event = test_event();
        let angles = Angles::new(0, [1], [2], [2, 3], Frame::Helicity);
        assert_relative_eq!(
            angles.costheta.value(&event),
            -0.4611175,
            epsilon = Float::EPSILON.sqrt()
        );
        assert_relative_eq!(
            angles.phi.value(&event),
            -2.65746258,
            epsilon = Float::EPSILON.sqrt()
        );
    }

    #[test]
    fn test_pol_angle() {
        let event = test_event();
        let pol_angle = PolAngle::new(0, vec![1]);
        assert_relative_eq!(
            pol_angle.value(&event),
            1.93592989,
            epsilon = Float::EPSILON.sqrt()
        );
    }

    #[test]
    fn test_pol_magnitude() {
        let event = test_event();
        let pol_magnitude = PolMagnitude::new(0);
        assert_relative_eq!(
            pol_magnitude.value(&event),
            0.38562805,
            epsilon = Float::EPSILON.sqrt()
        );
    }

    #[test]
    fn test_polarization() {
        let event = test_event();
        let polarization = Polarization::new(0, vec![1]);
        assert_relative_eq!(
            polarization.pol_angle.value(&event),
            1.93592989,
            epsilon = Float::EPSILON.sqrt()
        );
        assert_relative_eq!(
            polarization.pol_magnitude.value(&event),
            0.38562805,
            epsilon = Float::EPSILON.sqrt()
        );
    }

    #[test]
    fn test_mandelstam() {
        let event = test_event();
        let s = Mandelstam::new([0], [], [2, 3], [1], Channel::S).unwrap();
        let t = Mandelstam::new([0], [], [2, 3], [1], Channel::T).unwrap();
        let u = Mandelstam::new([0], [], [2, 3], [1], Channel::U).unwrap();
        let sp = Mandelstam::new([], [0], [1], [2, 3], Channel::S).unwrap();
        let tp = Mandelstam::new([], [0], [1], [2, 3], Channel::T).unwrap();
        let up = Mandelstam::new([], [0], [1], [2, 3], Channel::U).unwrap();
        assert_relative_eq!(
            s.value(&event),
            18.50401105,
            epsilon = Float::EPSILON.sqrt()
        );
        assert_relative_eq!(s.value(&event), sp.value(&event));
        assert_relative_eq!(
            t.value(&event),
            -0.19222859,
            epsilon = Float::EPSILON.sqrt()
        );
        assert_relative_eq!(t.value(&event), tp.value(&event));
        assert_relative_eq!(
            u.value(&event),
            -14.40419893,
            epsilon = Float::EPSILON.sqrt()
        );
        assert_relative_eq!(u.value(&event), up.value(&event));
        // s + t + u equals the sum of the four squared masses. Subtracting the beam, recoil and
        // resonance terms leaves the squared mass of the omitted particle p2, which is
        // implicitly reconstructed as (p3 + p4 - p1)^2.
        let m2_beam = event.get_p4_sum([0]).m2();
        let m2_recoil = event.get_p4_sum([1]).m2();
        let m2_res = event.get_p4_sum([2, 3]).m2();
        assert_relative_eq!(
            s.value(&event) + t.value(&event) + u.value(&event) - m2_beam - m2_recoil - m2_res,
            1.00,
            epsilon = 1e-2
        );
        // Note: not very accurate, but considering the values in test_event only go to about 3
        // decimal places, this is probably okay
    }

    #[test]
    fn test_variable_value_on() {
        let dataset = Arc::new(test_dataset());
        let mass = Mass::new(vec![2, 3]);

        let values = mass.value_on(&dataset);
        assert_eq!(values.len(), 1);
        assert_relative_eq!(values[0], 1.37437863, epsilon = Float::EPSILON.sqrt());
    }
}