Skip to main content

laddu_core/utils/
variables.rs

1use dyn_clone::DynClone;
2use serde::{Deserialize, Serialize};
3use std::fmt::{Debug, Display};
4
5#[cfg(feature = "rayon")]
6use rayon::prelude::*;
7
8#[cfg(feature = "mpi")]
9use crate::mpi::LadduMPI;
10use crate::{
11    data::{Dataset, DatasetMetadata, EventData, NamedEventView},
12    utils::{
13        enums::{Channel, Frame},
14        reaction::{Particle, Reaction},
15        vectors::{Vec3, Vec4},
16    },
17    LadduError, LadduResult,
18};
19
20use auto_ops::impl_op_ex;
21
22#[cfg(feature = "mpi")]
23use mpi::{datatype::PartitionMut, topology::SimpleCommunicator, traits::*};
24
25/// Standard methods for extracting some value from an event view.
26#[typetag::serde(tag = "type")]
27pub trait Variable: DynClone + Send + Sync + Debug + Display {
28    /// Bind the variable to dataset metadata so that any referenced names can be resolved to
29    /// concrete indices. Implementations that do not require metadata may keep the default
30    /// no-op.
31    fn bind(&mut self, _metadata: &DatasetMetadata) -> LadduResult<()> {
32        Ok(())
33    }
34
35    /// This method extracts a single value (like a mass) from an event access view.
36    fn value(&self, event: &NamedEventView<'_>) -> f64;
37
38    /// This method distributes [`Variable::value`] over each event in a [`Dataset`] (non-MPI version).
39    ///
40    /// # Notes
41    ///
42    /// This method is not intended to be called in analyses but rather in writing methods
43    /// that have `mpi`-feature-gated versions. Most users should just call [`Variable::value_on`] instead.
44    fn value_on_local(&self, dataset: &Dataset) -> LadduResult<Vec<f64>> {
45        let mut variable = dyn_clone::clone_box(self);
46        variable.bind(dataset.metadata())?;
47        #[cfg(feature = "rayon")]
48        let local_values: Vec<f64> = (0..dataset.n_events_local())
49            .into_par_iter()
50            .map(|event_index| {
51                let event = dataset.event_view(event_index);
52                variable.value(&event)
53            })
54            .collect();
55        #[cfg(not(feature = "rayon"))]
56        let local_values: Vec<f64> = (0..dataset.n_events_local())
57            .map(|event_index| {
58                let event = dataset.event_view(event_index);
59                variable.value(&event)
60            })
61            .collect();
62        Ok(local_values)
63    }
64
65    /// This method distributes the [`Variable::value`] method over each [`EventData`] in a
66    /// [`Dataset`] (MPI-compatible version).
67    ///
68    /// # Notes
69    ///
70    /// This method is not intended to be called in analyses but rather in writing methods
71    /// that have `mpi`-feature-gated versions. Most users should just call [`Variable::value_on`] instead.
72    #[cfg(feature = "mpi")]
73    fn value_on_mpi(&self, dataset: &Dataset, world: &SimpleCommunicator) -> LadduResult<Vec<f64>> {
74        let local_weights = self.value_on_local(dataset)?;
75        let n_events = dataset.n_events();
76        let mut buffer: Vec<f64> = vec![0.0; n_events];
77        let (counts, displs) = world.get_counts_displs(n_events);
78        {
79            let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
80            world.all_gather_varcount_into(&local_weights, &mut partitioned_buffer);
81        }
82        Ok(buffer)
83    }
84
85    /// This method distributes the [`Variable::value`] method over each [`EventData`] in a
86    /// [`Dataset`].
87    fn value_on(&self, dataset: &Dataset) -> LadduResult<Vec<f64>> {
88        #[cfg(feature = "mpi")]
89        {
90            if let Some(world) = crate::mpi::get_world() {
91                return self.value_on_mpi(dataset, &world);
92            }
93        }
94        self.value_on_local(dataset)
95    }
96
97    /// Create an [`VariableExpression`] that evaluates to `self == val`
98    fn eq(&self, val: f64) -> VariableExpression
99    where
100        Self: std::marker::Sized + 'static,
101    {
102        VariableExpression::Eq(dyn_clone::clone_box(self), val)
103    }
104
105    /// Create an [`VariableExpression`] that evaluates to `self < val`
106    fn lt(&self, val: f64) -> VariableExpression
107    where
108        Self: std::marker::Sized + 'static,
109    {
110        VariableExpression::Lt(dyn_clone::clone_box(self), val)
111    }
112
113    /// Create an [`VariableExpression`] that evaluates to `self > val`
114    fn gt(&self, val: f64) -> VariableExpression
115    where
116        Self: std::marker::Sized + 'static,
117    {
118        VariableExpression::Gt(dyn_clone::clone_box(self), val)
119    }
120
121    /// Create an [`VariableExpression`] that evaluates to `self >= val`
122    fn ge(&self, val: f64) -> VariableExpression
123    where
124        Self: std::marker::Sized + 'static,
125    {
126        self.gt(val).or(&self.eq(val))
127    }
128
129    /// Create an [`VariableExpression`] that evaluates to `self <= val`
130    fn le(&self, val: f64) -> VariableExpression
131    where
132        Self: std::marker::Sized + 'static,
133    {
134        self.lt(val).or(&self.eq(val))
135    }
136}
137dyn_clone::clone_trait_object!(Variable);
138
139/// Expressions which can be used to compare [`Variable`]s to [`f64`]s.
140#[derive(Clone, Debug)]
141pub enum VariableExpression {
142    /// Expression which is true when the variable is equal to the float.
143    Eq(Box<dyn Variable>, f64),
144    /// Expression which is true when the variable is less than the float.
145    Lt(Box<dyn Variable>, f64),
146    /// Expression which is true when the variable is greater than the float.
147    Gt(Box<dyn Variable>, f64),
148    /// Expression which is true when both inner expressions are true.
149    And(Box<VariableExpression>, Box<VariableExpression>),
150    /// Expression which is true when either inner expression is true.
151    Or(Box<VariableExpression>, Box<VariableExpression>),
152    /// Expression which is true when the inner expression is false.
153    Not(Box<VariableExpression>),
154}
155
156impl VariableExpression {
157    /// Construct an [`VariableExpression::And`] from the current expression and another.
158    pub fn and(&self, rhs: &VariableExpression) -> VariableExpression {
159        VariableExpression::And(Box::new(self.clone()), Box::new(rhs.clone()))
160    }
161
162    /// Construct an [`VariableExpression::Or`] from the current expression and another.
163    pub fn or(&self, rhs: &VariableExpression) -> VariableExpression {
164        VariableExpression::Or(Box::new(self.clone()), Box::new(rhs.clone()))
165    }
166
167    /// Comple the [`VariableExpression`] into a [`CompiledExpression`] bound to the supplied
168    /// metadata so that all variable references are resolved.
169    pub(crate) fn compile(
170        &self,
171        metadata: &DatasetMetadata,
172    ) -> LadduResult<CompiledVariableExpression> {
173        let mut compiled = compile_expression(self.clone());
174        compiled.bind(metadata)?;
175        Ok(compiled)
176    }
177}
178impl Display for VariableExpression {
179    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
180        match self {
181            VariableExpression::Eq(var, val) => {
182                write!(f, "({} == {})", var, val)
183            }
184            VariableExpression::Lt(var, val) => {
185                write!(f, "({} < {})", var, val)
186            }
187            VariableExpression::Gt(var, val) => {
188                write!(f, "({} > {})", var, val)
189            }
190            VariableExpression::And(lhs, rhs) => {
191                write!(f, "({} & {})", lhs, rhs)
192            }
193            VariableExpression::Or(lhs, rhs) => {
194                write!(f, "({} | {})", lhs, rhs)
195            }
196            VariableExpression::Not(inner) => {
197                write!(f, "!({})", inner)
198            }
199        }
200    }
201}
202
203/// A method which negates the given expression.
204pub fn not(expr: &VariableExpression) -> VariableExpression {
205    VariableExpression::Not(Box::new(expr.clone()))
206}
207
// Operator sugar for combining expressions: `&` delegates to `VariableExpression::and`.
// `#[rustfmt::skip]` keeps rustfmt from reflowing the one-line macro invocation.
#[rustfmt::skip]
impl_op_ex!(& |lhs: &VariableExpression, rhs: &VariableExpression| -> VariableExpression{ lhs.and(rhs) });
// `|` delegates to `VariableExpression::or`.
#[rustfmt::skip]
impl_op_ex!(| |lhs: &VariableExpression, rhs: &VariableExpression| -> VariableExpression{ lhs.or(rhs) });
// Unary `!` delegates to the free function `not`.
#[rustfmt::skip]
impl_op_ex!(! |exp: &VariableExpression| -> VariableExpression{ not(exp) });
214
/// One instruction of the postfix bytecode produced by [`compile_expression`].
///
/// `Push*` instructions evaluate the variable stored at the given index of the
/// [`CompiledVariableExpression`] variable table, compare it with the float, and push the
/// boolean result. The logical instructions pop their operand(s) and push the result.
#[derive(Debug)]
enum Opcode {
    /// Push `variables[index] == value`.
    PushEq(usize, f64),
    /// Push `variables[index] < value`.
    PushLt(usize, f64),
    /// Push `variables[index] > value`.
    PushGt(usize, f64),
    /// Pop two results and push their conjunction.
    And,
    /// Pop two results and push their disjunction.
    Or,
    /// Pop one result and push its negation.
    Not,
}
224
225pub(crate) struct CompiledVariableExpression {
226    bytecode: Vec<Opcode>,
227    variables: Vec<Box<dyn Variable>>,
228}
229
230impl CompiledVariableExpression {
231    pub fn bind(&mut self, metadata: &DatasetMetadata) -> LadduResult<()> {
232        for variable in &mut self.variables {
233            variable.bind(metadata)?;
234        }
235        Ok(())
236    }
237
238    /// Evaluate the [`CompiledExpression`] on a given named event view.
239    pub fn evaluate(&self, event: &NamedEventView<'_>) -> bool {
240        let mut stack = Vec::with_capacity(self.bytecode.len());
241
242        for op in &self.bytecode {
243            match op {
244                Opcode::PushEq(i, val) => stack.push(self.variables[*i].value(event) == *val),
245                Opcode::PushLt(i, val) => stack.push(self.variables[*i].value(event) < *val),
246                Opcode::PushGt(i, val) => stack.push(self.variables[*i].value(event) > *val),
247                Opcode::Not => {
248                    let a = stack.pop().unwrap();
249                    stack.push(!a);
250                }
251                Opcode::And => {
252                    let b = stack.pop().unwrap();
253                    let a = stack.pop().unwrap();
254                    stack.push(a && b);
255                }
256                Opcode::Or => {
257                    let b = stack.pop().unwrap();
258                    let a = stack.pop().unwrap();
259                    stack.push(a || b);
260                }
261            }
262        }
263
264        stack.pop().unwrap()
265    }
266}
267
268pub(crate) fn compile_expression(expr: VariableExpression) -> CompiledVariableExpression {
269    let mut bytecode = Vec::new();
270    let mut variables: Vec<Box<dyn Variable>> = Vec::new();
271
272    fn compile(
273        expr: VariableExpression,
274        bytecode: &mut Vec<Opcode>,
275        variables: &mut Vec<Box<dyn Variable>>,
276    ) {
277        match expr {
278            VariableExpression::Eq(var, val) => {
279                variables.push(var);
280                bytecode.push(Opcode::PushEq(variables.len() - 1, val));
281            }
282            VariableExpression::Lt(var, val) => {
283                variables.push(var);
284                bytecode.push(Opcode::PushLt(variables.len() - 1, val));
285            }
286            VariableExpression::Gt(var, val) => {
287                variables.push(var);
288                bytecode.push(Opcode::PushGt(variables.len() - 1, val));
289            }
290            VariableExpression::And(lhs, rhs) => {
291                compile(*lhs, bytecode, variables);
292                compile(*rhs, bytecode, variables);
293                bytecode.push(Opcode::And);
294            }
295            VariableExpression::Or(lhs, rhs) => {
296                compile(*lhs, bytecode, variables);
297                compile(*rhs, bytecode, variables);
298                bytecode.push(Opcode::Or);
299            }
300            VariableExpression::Not(inner) => {
301                compile(*inner, bytecode, variables);
302                bytecode.push(Opcode::Not);
303            }
304        }
305    }
306
307    compile(expr, &mut bytecode, &mut variables);
308
309    CompiledVariableExpression {
310        bytecode,
311        variables,
312    }
313}
314
/// Join metadata names into one comma-separated string (e.g. `"a, b"`), as used by the
/// `Display` implementations in this module.
fn names_to_string(names: &[String]) -> String {
    <[String]>::join(names, ", ")
}
318
319/// A reusable selection that may span one or more four-momentum names.
320///
321/// Instances are constructed from metadata-facing identifiers and later bound to
322/// column indices so that variable evaluators can resolve aliases or grouped
323/// particles efficiently.
324#[derive(Clone, Debug, Serialize, Deserialize)]
325pub struct P4Selection {
326    names: Vec<String>,
327    #[serde(skip, default)]
328    indices: Vec<usize>,
329}
330
331impl P4Selection {
332    fn new_many<I, S>(names: I) -> Self
333    where
334        I: IntoIterator<Item = S>,
335        S: Into<String>,
336    {
337        Self {
338            names: names.into_iter().map(Into::into).collect(),
339            indices: Vec::new(),
340        }
341    }
342
343    pub(crate) fn with_indices<I, S>(names: I, indices: Vec<usize>) -> Self
344    where
345        I: IntoIterator<Item = S>,
346        S: Into<String>,
347    {
348        Self {
349            names: names.into_iter().map(Into::into).collect(),
350            indices,
351        }
352    }
353
354    /// Returns the metadata names contributing to this selection.
355    pub fn names(&self) -> &[String] {
356        &self.names
357    }
358
359    pub(crate) fn bind(&mut self, metadata: &DatasetMetadata) -> LadduResult<()> {
360        let mut resolved = Vec::with_capacity(self.names.len());
361        for name in &self.names {
362            metadata.append_indices_for_name(name, &mut resolved)?;
363        }
364        self.indices = resolved;
365        Ok(())
366    }
367
368    /// The resolved column indices backing this selection.
369    pub fn indices(&self) -> &[usize] {
370        &self.indices
371    }
372
373    pub(crate) fn momentum(&self, event: &EventData) -> Vec4 {
374        event.get_p4_sum(self.indices())
375    }
376}
377
378/// Helper trait to convert common particle specifications into [`P4Selection`] instances.
379pub trait IntoP4Selection {
380    /// Convert the input into a [`P4Selection`].
381    fn into_selection(self) -> P4Selection;
382}
383
384impl IntoP4Selection for P4Selection {
385    fn into_selection(self) -> P4Selection {
386        self
387    }
388}
389
390impl IntoP4Selection for &P4Selection {
391    fn into_selection(self) -> P4Selection {
392        self.clone()
393    }
394}
395
396impl IntoP4Selection for String {
397    fn into_selection(self) -> P4Selection {
398        P4Selection::new_many(vec![self])
399    }
400}
401
402impl IntoP4Selection for &String {
403    fn into_selection(self) -> P4Selection {
404        P4Selection::new_many(vec![self.clone()])
405    }
406}
407
408impl IntoP4Selection for &str {
409    fn into_selection(self) -> P4Selection {
410        P4Selection::new_many(vec![self.to_string()])
411    }
412}
413
414impl<S> IntoP4Selection for Vec<S>
415where
416    S: Into<String>,
417{
418    fn into_selection(self) -> P4Selection {
419        P4Selection::new_many(self.into_iter().map(Into::into).collect::<Vec<_>>())
420    }
421}
422
423impl<S> IntoP4Selection for &[S]
424where
425    S: Clone + Into<String>,
426{
427    fn into_selection(self) -> P4Selection {
428        P4Selection::new_many(self.iter().cloned().map(Into::into).collect::<Vec<_>>())
429    }
430}
431
432impl<S, const N: usize> IntoP4Selection for [S; N]
433where
434    S: Into<String>,
435{
436    fn into_selection(self) -> P4Selection {
437        P4Selection::new_many(self.into_iter().map(Into::into).collect::<Vec<_>>())
438    }
439}
440
441impl<S, const N: usize> IntoP4Selection for &[S; N]
442where
443    S: Clone + Into<String>,
444{
445    fn into_selection(self) -> P4Selection {
446        P4Selection::new_many(self.iter().cloned().map(Into::into).collect::<Vec<_>>())
447    }
448}
449#[derive(Clone, Debug, Serialize, Deserialize)]
450struct AuxSelection {
451    name: String,
452    #[serde(skip, default)]
453    index: Option<usize>,
454}
455
456impl AuxSelection {
457    fn new<S: Into<String>>(name: S) -> Self {
458        Self {
459            name: name.into(),
460            index: None,
461        }
462    }
463
464    fn bind(&mut self, metadata: &DatasetMetadata) -> LadduResult<()> {
465        let idx = metadata
466            .aux_index(&self.name)
467            .ok_or_else(|| LadduError::UnknownName {
468                category: "aux",
469                name: self.name.clone(),
470            })?;
471        self.index = Some(idx);
472        Ok(())
473    }
474
475    fn index(&self) -> usize {
476        self.index.expect("AuxSelection must be bound before use")
477    }
478
479    fn name(&self) -> &str {
480        &self.name
481    }
482}
483
484/// Source for a mass variable.
485#[derive(Clone, Debug, Serialize, Deserialize)]
486enum MassSource {
487    Selection(P4Selection),
488    Reaction {
489        reaction: Box<Reaction>,
490        particle: Particle,
491    },
492}
493
494/// A struct for obtaining the invariant mass of a selected or reaction-defined particle.
495#[derive(Clone, Debug, Serialize, Deserialize)]
496pub struct Mass {
497    source: MassSource,
498}
499impl Mass {
500    /// Create a new [`Mass`] from the sum of the four-momenta identified by `constituents` in the
501    /// [`EventData`]'s `p4s` field.
502    pub fn new<C>(constituents: C) -> Self
503    where
504        C: IntoP4Selection,
505    {
506        Self {
507            source: MassSource::Selection(constituents.into_selection()),
508        }
509    }
510
511    /// Create a new [`Mass`] for a particle resolved through a [`Reaction`].
512    pub fn from_reaction(reaction: Reaction, particle: Particle) -> Self {
513        Self {
514            source: MassSource::Reaction {
515                reaction: Box::new(reaction),
516                particle,
517            },
518        }
519    }
520}
521impl Display for Mass {
522    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
523        match &self.source {
524            MassSource::Selection(constituents) => {
525                write!(
526                    f,
527                    "Mass(constituents=[{}])",
528                    names_to_string(constituents.names())
529                )
530            }
531            MassSource::Reaction { particle, .. } => write!(f, "Mass(particle={})", particle),
532        }
533    }
534}
535#[typetag::serde]
536impl Variable for Mass {
537    fn bind(&mut self, metadata: &DatasetMetadata) -> LadduResult<()> {
538        match &mut self.source {
539            MassSource::Selection(constituents) => constituents.bind(metadata),
540            MassSource::Reaction { .. } => Ok(()),
541        }
542    }
543    fn value(&self, event: &NamedEventView<'_>) -> f64 {
544        match &self.source {
545            MassSource::Selection(constituents) => constituents
546                .indices()
547                .iter()
548                .map(|index| event.p4_at(*index))
549                .sum::<Vec4>()
550                .m(),
551            MassSource::Reaction { reaction, particle } => reaction
552                .p4(event, particle)
553                .unwrap_or_else(|err| panic!("failed to evaluate reaction mass: {err}"))
554                .m(),
555        }
556    }
557}
558
559#[derive(Clone, Debug, Serialize, Deserialize)]
560struct AngleSource {
561    reaction: Box<Reaction>,
562    parent: Particle,
563    daughter: Particle,
564}
565
566impl AngleSource {
567    fn bind(&mut self, metadata: &DatasetMetadata) -> LadduResult<()> {
568        let _ = metadata;
569        Ok(())
570    }
571
572    fn costheta(&self, event: &NamedEventView<'_>, frame: Frame) -> f64 {
573        self.reaction
574            .angles_value(event, &self.parent, &self.daughter, frame)
575            .unwrap_or_else(|err| panic!("failed to evaluate reaction costheta: {err}"))
576            .costheta()
577    }
578
579    fn phi(&self, event: &NamedEventView<'_>, frame: Frame) -> f64 {
580        self.reaction
581            .angles_value(event, &self.parent, &self.daughter, frame)
582            .unwrap_or_else(|err| panic!("failed to evaluate reaction phi: {err}"))
583            .phi()
584    }
585}
586
587/// A struct for obtaining the $`\cos\theta`$ (cosine of the polar angle) of a decay product in
588/// a given reference frame of its parent resonance.
589#[derive(Clone, Debug, Serialize, Deserialize)]
590pub struct CosTheta {
591    source: AngleSource,
592    frame: Frame,
593}
594impl Display for CosTheta {
595    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
596        write!(
597            f,
598            "CosTheta(parent={}, daughter={}, frame={})",
599            self.source.parent, self.source.daughter, self.frame
600        )
601    }
602}
603impl CosTheta {
604    /// Construct an angle for a reaction daughter in the specified parent frame.
605    pub fn from_reaction(
606        reaction: Reaction,
607        parent: Particle,
608        daughter: Particle,
609        frame: Frame,
610    ) -> Self {
611        Self {
612            source: AngleSource {
613                reaction: Box::new(reaction),
614                parent,
615                daughter,
616            },
617            frame,
618        }
619    }
620}
621
622#[typetag::serde]
623impl Variable for CosTheta {
624    fn bind(&mut self, metadata: &DatasetMetadata) -> LadduResult<()> {
625        self.source.bind(metadata)
626    }
627    fn value(&self, event: &NamedEventView<'_>) -> f64 {
628        self.source.costheta(event, self.frame)
629    }
630}
631
632/// A struct for obtaining the $`\phi`$ angle (azimuthal angle) of a decay product in a given
633/// reference frame of its parent resonance.
634#[derive(Clone, Debug, Serialize, Deserialize)]
635pub struct Phi {
636    source: AngleSource,
637    frame: Frame,
638}
639impl Display for Phi {
640    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
641        write!(
642            f,
643            "Phi(parent={}, daughter={}, frame={})",
644            self.source.parent, self.source.daughter, self.frame
645        )
646    }
647}
648impl Phi {
649    /// Construct an angle for a reaction daughter in the specified parent frame.
650    pub fn from_reaction(
651        reaction: Reaction,
652        parent: Particle,
653        daughter: Particle,
654        frame: Frame,
655    ) -> Self {
656        Self {
657            source: AngleSource {
658                reaction: Box::new(reaction),
659                parent,
660                daughter,
661            },
662            frame,
663        }
664    }
665}
666#[typetag::serde]
667impl Variable for Phi {
668    fn bind(&mut self, metadata: &DatasetMetadata) -> LadduResult<()> {
669        self.source.bind(metadata)
670    }
671    fn value(&self, event: &NamedEventView<'_>) -> f64 {
672        self.source.phi(event, self.frame)
673    }
674}
675
676/// A struct for obtaining both spherical angles at the same time.
677#[derive(Clone, Debug, Serialize, Deserialize)]
678pub struct Angles {
679    /// See [`CosTheta`].
680    pub costheta: CosTheta,
681    /// See [`Phi`].
682    pub phi: Phi,
683}
684
685/// A pair of variables that define spherical decay angles.
686pub trait AngleVariables: Debug + Display + Send + Sync {
687    /// Return the variable used for `cos(theta)`.
688    fn costheta_variable(&self) -> Box<dyn Variable>;
689
690    /// Return the variable used for `phi`.
691    fn phi_variable(&self) -> Box<dyn Variable>;
692}
693
694impl AngleVariables for Angles {
695    fn costheta_variable(&self) -> Box<dyn Variable> {
696        Box::new(self.costheta.clone())
697    }
698
699    fn phi_variable(&self) -> Box<dyn Variable> {
700        Box::new(self.phi.clone())
701    }
702}
703
704impl Display for Angles {
705    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
706        write!(
707            f,
708            "Angles(parent={}, daughter={}, frame={})",
709            self.costheta.source.parent, self.costheta.source.daughter, self.costheta.frame
710        )
711    }
712}
713impl Angles {
714    /// Construct reaction-derived angle variables for a daughter in its parent frame.
715    pub fn from_reaction(
716        reaction: Reaction,
717        parent: Particle,
718        daughter: Particle,
719        frame: Frame,
720    ) -> Self {
721        let costheta =
722            CosTheta::from_reaction(reaction.clone(), parent.clone(), daughter.clone(), frame);
723        let phi = Phi::from_reaction(reaction, parent, daughter, frame);
724        Self { costheta, phi }
725    }
726}
727
728/// A struct defining the polarization angle for a beam relative to the production plane.
729#[derive(Clone, Debug, Serialize, Deserialize)]
730pub struct PolAngle {
731    reaction: Reaction,
732    angle_aux: AuxSelection,
733}
734impl Display for PolAngle {
735    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
736        write!(
737            f,
738            "PolAngle(reaction={:?}, angle_aux={})",
739            self.reaction.topology(),
740            self.angle_aux.name(),
741        )
742    }
743}
744impl PolAngle {
745    /// Constructs the polarization angle given a [`Reaction`] describing the production plane and
746    /// the auxiliary column storing the precomputed angle.
747    pub fn new<A>(reaction: Reaction, angle_aux: A) -> Self
748    where
749        A: Into<String>,
750    {
751        Self {
752            reaction,
753            angle_aux: AuxSelection::new(angle_aux.into()),
754        }
755    }
756}
757#[typetag::serde]
758impl Variable for PolAngle {
759    fn bind(&mut self, metadata: &DatasetMetadata) -> LadduResult<()> {
760        let _ = metadata;
761        self.angle_aux.bind(metadata)?;
762        Ok(())
763    }
764    fn value(&self, event: &NamedEventView<'_>) -> f64 {
765        let resolved = self
766            .reaction
767            .resolve_two_to_two(event)
768            .unwrap_or_else(|err| panic!("failed to evaluate polarization angle: {err}"));
769        let beam = resolved.p1();
770        let recoil = resolved.p4();
771        let pol_angle = event.aux_at(self.angle_aux.index());
772        let polarization = Vec3::new(pol_angle.cos(), pol_angle.sin(), 0.0);
773        let y = beam.vec3().cross(&-recoil.vec3()).unit();
774        let numerator = y.dot(&polarization);
775        let denominator = beam.vec3().unit().dot(&polarization.cross(&y));
776        f64::atan2(numerator, denominator)
777    }
778}
779
780/// A struct defining the polarization magnitude for a beam relative to the production plane.
781#[derive(Clone, Debug, Serialize, Deserialize)]
782pub struct PolMagnitude {
783    magnitude_aux: AuxSelection,
784}
785impl Display for PolMagnitude {
786    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
787        write!(
788            f,
789            "PolMagnitude(magnitude_aux={})",
790            self.magnitude_aux.name(),
791        )
792    }
793}
794impl PolMagnitude {
795    /// Constructs the polarization magnitude given the named auxiliary column containing the
796    /// magnitude value.
797    pub fn new<S: Into<String>>(magnitude_aux: S) -> Self {
798        Self {
799            magnitude_aux: AuxSelection::new(magnitude_aux.into()),
800        }
801    }
802}
803#[typetag::serde]
804impl Variable for PolMagnitude {
805    fn bind(&mut self, metadata: &DatasetMetadata) -> LadduResult<()> {
806        self.magnitude_aux.bind(metadata)
807    }
808    fn value(&self, event: &NamedEventView<'_>) -> f64 {
809        event.aux_at(self.magnitude_aux.index())
810    }
811}
812
813/// A struct for obtaining both the polarization angle and magnitude at the same time.
814#[derive(Clone, Debug, Serialize, Deserialize)]
815pub struct Polarization {
816    /// See [`PolMagnitude`].
817    pub pol_magnitude: PolMagnitude,
818    /// See [`PolAngle`].
819    pub pol_angle: PolAngle,
820}
821impl Display for Polarization {
822    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
823        write!(
824            f,
825            "Polarization(reaction={:?}, magnitude_aux={}, angle_aux={})",
826            self.pol_angle.reaction.topology(),
827            self.pol_magnitude.magnitude_aux.name(),
828            self.pol_angle.angle_aux.name(),
829        )
830    }
831}
832impl Polarization {
833    /// Constructs the polarization angle and magnitude given a [`Reaction`] and distinct
834    /// auxiliary columns for magnitude and angle.
835    ///
836    /// # Panics
837    ///
838    /// Panics if `magnitude_aux` and `angle_aux` refer to the same auxiliary column name.
839    pub fn new<M, A>(reaction: Reaction, magnitude_aux: M, angle_aux: A) -> Self
840    where
841        M: Into<String>,
842        A: Into<String>,
843    {
844        let magnitude_aux = magnitude_aux.into();
845        let angle_aux = angle_aux.into();
846        assert!(
847            magnitude_aux != angle_aux,
848            "Polarization magnitude and angle must reference distinct auxiliary columns"
849        );
850        Self {
851            pol_magnitude: PolMagnitude::new(magnitude_aux),
852            pol_angle: PolAngle::new(reaction, angle_aux),
853        }
854    }
855}
856
857/// A struct used to calculate Mandelstam variables ($`s`$, $`t`$, or $`u`$).
858///
859/// By convention, the metric is chosen to be $`(+---)`$ and the variables are defined as follows
860/// (ignoring factors of $`c`$):
861///
862/// $`s = (p_1 + p_2)^2 = (p_3 + p_4)^2`$
863///
864/// $`t = (p_1 - p_3)^2 = (p_4 - p_2)^2`$
865///
866/// $`u = (p_1 - p_4)^2 = (p_3 - p_2)^2`$
867#[derive(Clone, Debug, Serialize, Deserialize)]
868pub struct Mandelstam {
869    reaction: Reaction,
870    channel: Channel,
871}
872impl Display for Mandelstam {
873    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
874        write!(f, "Mandelstam(channel={})", self.channel)
875    }
876}
877impl Mandelstam {
878    /// Constructs the Mandelstam variable for the given `channel` using the supplied [`Reaction`].
879    pub fn new(reaction: Reaction, channel: Channel) -> Self {
880        Self { reaction, channel }
881    }
882}
883
884#[typetag::serde]
885impl Variable for Mandelstam {
886    fn bind(&mut self, metadata: &DatasetMetadata) -> LadduResult<()> {
887        let _ = metadata;
888        Ok(())
889    }
890    fn value(&self, event: &NamedEventView<'_>) -> f64 {
891        let resolved = self
892            .reaction
893            .resolve_two_to_two(event)
894            .unwrap_or_else(|err| panic!("failed to evaluate reaction Mandelstam: {err}"));
895        match self.channel {
896            Channel::S => resolved.s(),
897            Channel::T => resolved.t(),
898            Channel::U => resolved.u(),
899        }
900    }
901}
902
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::test_dataset;
    use approx::assert_relative_eq;

    /// Builds the reaction shared by these tests: `beam + target -> KK + proton`.
    ///
    /// `KK` is a composite of the two measured `K_S` candidates, and the target is declared with
    /// `Particle::missing`. Returns `(reaction, KK, K_S1, K_S2)`.
    fn reaction() -> (Reaction, Particle, Particle, Particle) {
        let beam = Particle::measured("beam", "beam");
        let target = Particle::missing("target");
        let kshort1 = Particle::measured("K_S1", "kshort1");
        let kshort2 = Particle::measured("K_S2", "kshort2");
        let kk = Particle::composite("KK", [&kshort1, &kshort2]).unwrap();
        let proton = Particle::measured("proton", "proton");
        let two_to_two = Reaction::two_to_two(&beam, &target, &kk, &proton).unwrap();
        (two_to_two, kk, kshort1, kshort2)
    }

    /// Binds `variable` against the metadata of `dataset` and evaluates it on event 0.
    fn first_event_value(variable: &mut dyn Variable, dataset: &Dataset) -> f64 {
        variable.bind(dataset.metadata()).unwrap();
        variable.value(&dataset.event_view(0))
    }

    #[test]
    fn test_mass_single_particle() {
        let dataset = test_dataset();
        let mut mass = Mass::new("proton");
        assert_relative_eq!(first_event_value(&mut mass, &dataset), 1.007);
    }

    #[test]
    fn test_mass_multiple_particles() {
        let dataset = test_dataset();
        let mut mass = Mass::new(["kshort1", "kshort2"]);
        assert_relative_eq!(first_event_value(&mut mass, &dataset), 1.3743786309153077);
    }

    #[test]
    fn test_mass_display() {
        let rendered = Mass::new(["kshort1", "kshort2"]).to_string();
        assert_eq!(rendered, "Mass(constituents=[kshort1, kshort2])");
    }

    #[test]
    fn test_costheta_helicity() {
        let dataset = test_dataset();
        let (reaction, kk, kshort1, _) = reaction();
        let decay = reaction.decay(&kk).unwrap();
        let mut costheta = decay.costheta(&kshort1, Frame::Helicity).unwrap();
        let value = first_event_value(&mut costheta, &dataset);
        assert_relative_eq!(value, -0.4611175068834202);
    }

    #[test]
    fn test_costheta_display() {
        let (reaction, kk, kshort1, _) = reaction();
        let decay = reaction.decay(&kk).unwrap();
        let rendered = decay
            .costheta(&kshort1, Frame::Helicity)
            .unwrap()
            .to_string();
        assert_eq!(rendered, "CosTheta(parent=KK, daughter=K_S1, frame=Helicity)");
    }

    #[test]
    fn test_phi_helicity() {
        let dataset = test_dataset();
        let (reaction, kk, kshort1, _) = reaction();
        let decay = reaction.decay(&kk).unwrap();
        let mut phi = decay.phi(&kshort1, Frame::Helicity).unwrap();
        let value = first_event_value(&mut phi, &dataset);
        assert_relative_eq!(value, -2.657462587335066);
    }

    #[test]
    fn test_phi_display() {
        let (reaction, kk, kshort1, _) = reaction();
        let decay = reaction.decay(&kk).unwrap();
        let rendered = decay.phi(&kshort1, Frame::Helicity).unwrap().to_string();
        assert_eq!(rendered, "Phi(parent=KK, daughter=K_S1, frame=Helicity)");
    }

    #[test]
    fn test_costheta_gottfried_jackson() {
        let dataset = test_dataset();
        let (reaction, kk, kshort1, _) = reaction();
        let decay = reaction.decay(&kk).unwrap();
        let mut costheta = decay.costheta(&kshort1, Frame::GottfriedJackson).unwrap();
        let value = first_event_value(&mut costheta, &dataset);
        assert_relative_eq!(value, 0.09198832278032032);
    }

    #[test]
    fn test_phi_gottfried_jackson() {
        let dataset = test_dataset();
        let (reaction, kk, kshort1, _) = reaction();
        let decay = reaction.decay(&kk).unwrap();
        let mut phi = decay.phi(&kshort1, Frame::GottfriedJackson).unwrap();
        let value = first_event_value(&mut phi, &dataset);
        assert_relative_eq!(value, -2.7139131991339056);
    }

    #[test]
    fn test_angles() {
        let dataset = test_dataset();
        let (reaction, kk, kshort1, _) = reaction();
        let decay = reaction.decay(&kk).unwrap();
        let mut angles = decay.angles(&kshort1, Frame::Helicity).unwrap();
        let costheta = first_event_value(&mut angles.costheta, &dataset);
        let phi = first_event_value(&mut angles.phi, &dataset);
        assert_relative_eq!(costheta, -0.4611175068834202);
        assert_relative_eq!(phi, -2.657462587335066);
    }

    #[test]
    fn test_angles_display() {
        let (reaction, kk, kshort1, _) = reaction();
        let decay = reaction.decay(&kk).unwrap();
        let rendered = decay.angles(&kshort1, Frame::Helicity).unwrap().to_string();
        assert_eq!(rendered, "Angles(parent=KK, daughter=K_S1, frame=Helicity)");
    }

    #[test]
    fn test_pol_angle() {
        let dataset = test_dataset();
        let (reaction, _, _, _) = reaction();
        let mut pol_angle = reaction.pol_angle("pol_angle");
        let value = first_event_value(&mut pol_angle, &dataset);
        assert_relative_eq!(value, 1.935929887818673);
    }

    #[test]
    fn test_pol_magnitude() {
        let dataset = test_dataset();
        let mut pol_magnitude = PolMagnitude::new("pol_magnitude");
        let value = first_event_value(&mut pol_magnitude, &dataset);
        assert_relative_eq!(value, 0.38562805);
    }

    #[test]
    fn test_pol_magnitude_display() {
        let rendered = PolMagnitude::new("pol_magnitude").to_string();
        assert_eq!(rendered, "PolMagnitude(magnitude_aux=pol_magnitude)");
    }

    #[test]
    fn test_polarization() {
        let dataset = test_dataset();
        let (reaction, _, _, _) = reaction();
        let mut polarization = reaction.polarization("pol_magnitude", "pol_angle");
        let angle = first_event_value(&mut polarization.pol_angle, &dataset);
        let magnitude = first_event_value(&mut polarization.pol_magnitude, &dataset);
        assert_relative_eq!(angle, 1.935929887818673);
        assert_relative_eq!(magnitude, 0.38562805);
    }

    #[test]
    fn test_mandelstam() {
        let dataset = test_dataset();
        let (reaction, _, _, _) = reaction();
        let event = dataset.event_view(0);
        let resolved = reaction.resolve_two_to_two(&event).unwrap();
        let expectations = [
            (Channel::S, resolved.s()),
            (Channel::T, resolved.t()),
            (Channel::U, resolved.u()),
        ];
        for (channel, expected) in expectations {
            let mut mandelstam = reaction.mandelstam(channel);
            assert_relative_eq!(first_event_value(&mut mandelstam, &dataset), expected);
        }
    }

    #[test]
    fn test_mandelstam_display() {
        let (reaction, _, _, _) = reaction();
        let rendered = reaction.mandelstam(Channel::S).to_string();
        assert_eq!(rendered, "Mandelstam(channel=s)");
    }

    #[test]
    fn test_variable_value_on() {
        let dataset = test_dataset();
        let values = Mass::new(["kshort1", "kshort2"]).value_on(&dataset).unwrap();
        assert_eq!(values.len(), 1);
        assert_relative_eq!(values[0], 1.3743786309153077);
    }
}