Skip to main content

laddu_core/utils/
variables.rs

1use dyn_clone::DynClone;
2use serde::{Deserialize, Serialize};
3use std::fmt::{Debug, Display};
4
5#[cfg(feature = "rayon")]
6use rayon::prelude::*;
7
8#[cfg(feature = "mpi")]
9use crate::mpi::LadduMPI;
10use crate::{
11    data::{Dataset, DatasetMetadata, EventData, NamedEventView},
12    utils::{
13        enums::{Channel, Frame},
14        vectors::{Vec3, Vec4},
15    },
16    LadduError, LadduResult,
17};
18
19use auto_ops::impl_op_ex;
20
21#[cfg(feature = "mpi")]
22use mpi::{datatype::PartitionMut, topology::SimpleCommunicator, traits::*};
23
24/// Standard methods for extracting some value from an event view.
25#[typetag::serde(tag = "type")]
26pub trait Variable: DynClone + Send + Sync + Debug + Display {
27    /// Bind the variable to dataset metadata so that any referenced names can be resolved to
28    /// concrete indices. Implementations that do not require metadata may keep the default
29    /// no-op.
30    fn bind(&mut self, _metadata: &DatasetMetadata) -> LadduResult<()> {
31        Ok(())
32    }
33
34    /// This method extracts a single value (like a mass) from an event access view.
35    fn value(&self, event: &NamedEventView<'_>) -> f64;
36
37    /// This method distributes [`Variable::value`] over each event in a [`Dataset`] (non-MPI version).
38    ///
39    /// # Notes
40    ///
41    /// This method is not intended to be called in analyses but rather in writing methods
42    /// that have `mpi`-feature-gated versions. Most users should just call [`Variable::value_on`] instead.
43    fn value_on_local(&self, dataset: &Dataset) -> LadduResult<Vec<f64>> {
44        let mut variable = dyn_clone::clone_box(self);
45        variable.bind(dataset.metadata())?;
46        #[cfg(feature = "rayon")]
47        let local_values: Vec<f64> = (0..dataset.n_events_local())
48            .into_par_iter()
49            .map(|event_index| {
50                let event = dataset.event_view(event_index);
51                variable.value(&event)
52            })
53            .collect();
54        #[cfg(not(feature = "rayon"))]
55        let local_values: Vec<f64> = (0..dataset.n_events_local())
56            .map(|event_index| {
57                let event = dataset.event_view(event_index);
58                variable.value(&event)
59            })
60            .collect();
61        Ok(local_values)
62    }
63
64    /// This method distributes the [`Variable::value`] method over each [`EventData`] in a
65    /// [`Dataset`] (MPI-compatible version).
66    ///
67    /// # Notes
68    ///
69    /// This method is not intended to be called in analyses but rather in writing methods
70    /// that have `mpi`-feature-gated versions. Most users should just call [`Variable::value_on`] instead.
71    #[cfg(feature = "mpi")]
72    fn value_on_mpi(&self, dataset: &Dataset, world: &SimpleCommunicator) -> LadduResult<Vec<f64>> {
73        let local_weights = self.value_on_local(dataset)?;
74        let n_events = dataset.n_events();
75        let mut buffer: Vec<f64> = vec![0.0; n_events];
76        let (counts, displs) = world.get_counts_displs(n_events);
77        {
78            let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
79            world.all_gather_varcount_into(&local_weights, &mut partitioned_buffer);
80        }
81        Ok(buffer)
82    }
83
84    /// This method distributes the [`Variable::value`] method over each [`EventData`] in a
85    /// [`Dataset`].
86    fn value_on(&self, dataset: &Dataset) -> LadduResult<Vec<f64>> {
87        #[cfg(feature = "mpi")]
88        {
89            if let Some(world) = crate::mpi::get_world() {
90                return self.value_on_mpi(dataset, &world);
91            }
92        }
93        self.value_on_local(dataset)
94    }
95
96    /// Create an [`VariableExpression`] that evaluates to `self == val`
97    fn eq(&self, val: f64) -> VariableExpression
98    where
99        Self: std::marker::Sized + 'static,
100    {
101        VariableExpression::Eq(dyn_clone::clone_box(self), val)
102    }
103
104    /// Create an [`VariableExpression`] that evaluates to `self < val`
105    fn lt(&self, val: f64) -> VariableExpression
106    where
107        Self: std::marker::Sized + 'static,
108    {
109        VariableExpression::Lt(dyn_clone::clone_box(self), val)
110    }
111
112    /// Create an [`VariableExpression`] that evaluates to `self > val`
113    fn gt(&self, val: f64) -> VariableExpression
114    where
115        Self: std::marker::Sized + 'static,
116    {
117        VariableExpression::Gt(dyn_clone::clone_box(self), val)
118    }
119
120    /// Create an [`VariableExpression`] that evaluates to `self >= val`
121    fn ge(&self, val: f64) -> VariableExpression
122    where
123        Self: std::marker::Sized + 'static,
124    {
125        self.gt(val).or(&self.eq(val))
126    }
127
128    /// Create an [`VariableExpression`] that evaluates to `self <= val`
129    fn le(&self, val: f64) -> VariableExpression
130    where
131        Self: std::marker::Sized + 'static,
132    {
133        self.lt(val).or(&self.eq(val))
134    }
135}
136dyn_clone::clone_trait_object!(Variable);
137
138/// Expressions which can be used to compare [`Variable`]s to [`f64`]s.
139#[derive(Clone, Debug)]
140pub enum VariableExpression {
141    /// Expression which is true when the variable is equal to the float.
142    Eq(Box<dyn Variable>, f64),
143    /// Expression which is true when the variable is less than the float.
144    Lt(Box<dyn Variable>, f64),
145    /// Expression which is true when the variable is greater than the float.
146    Gt(Box<dyn Variable>, f64),
147    /// Expression which is true when both inner expressions are true.
148    And(Box<VariableExpression>, Box<VariableExpression>),
149    /// Expression which is true when either inner expression is true.
150    Or(Box<VariableExpression>, Box<VariableExpression>),
151    /// Expression which is true when the inner expression is false.
152    Not(Box<VariableExpression>),
153}
154
155impl VariableExpression {
156    /// Construct an [`VariableExpression::And`] from the current expression and another.
157    pub fn and(&self, rhs: &VariableExpression) -> VariableExpression {
158        VariableExpression::And(Box::new(self.clone()), Box::new(rhs.clone()))
159    }
160
161    /// Construct an [`VariableExpression::Or`] from the current expression and another.
162    pub fn or(&self, rhs: &VariableExpression) -> VariableExpression {
163        VariableExpression::Or(Box::new(self.clone()), Box::new(rhs.clone()))
164    }
165
166    /// Comple the [`VariableExpression`] into a [`CompiledExpression`] bound to the supplied
167    /// metadata so that all variable references are resolved.
168    pub(crate) fn compile(&self, metadata: &DatasetMetadata) -> LadduResult<CompiledExpression> {
169        let mut compiled = compile_expression(self.clone());
170        compiled.bind(metadata)?;
171        Ok(compiled)
172    }
173}
174impl Display for VariableExpression {
175    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
176        match self {
177            VariableExpression::Eq(var, val) => {
178                write!(f, "({} == {})", var, val)
179            }
180            VariableExpression::Lt(var, val) => {
181                write!(f, "({} < {})", var, val)
182            }
183            VariableExpression::Gt(var, val) => {
184                write!(f, "({} > {})", var, val)
185            }
186            VariableExpression::And(lhs, rhs) => {
187                write!(f, "({} & {})", lhs, rhs)
188            }
189            VariableExpression::Or(lhs, rhs) => {
190                write!(f, "({} | {})", lhs, rhs)
191            }
192            VariableExpression::Not(inner) => {
193                write!(f, "!({})", inner)
194            }
195        }
196    }
197}
198
199/// A method which negates the given expression.
200pub fn not(expr: &VariableExpression) -> VariableExpression {
201    VariableExpression::Not(Box::new(expr.clone()))
202}
203
204#[rustfmt::skip]
205impl_op_ex!(& |lhs: &VariableExpression, rhs: &VariableExpression| -> VariableExpression{ lhs.and(rhs) });
206#[rustfmt::skip]
207impl_op_ex!(| |lhs: &VariableExpression, rhs: &VariableExpression| -> VariableExpression{ lhs.or(rhs) });
208#[rustfmt::skip]
209impl_op_ex!(! |exp: &VariableExpression| -> VariableExpression{ not(exp) });
210
211#[derive(Debug)]
212enum Opcode {
213    PushEq(usize, f64),
214    PushLt(usize, f64),
215    PushGt(usize, f64),
216    And,
217    Or,
218    Not,
219}
220
221pub(crate) struct CompiledExpression {
222    bytecode: Vec<Opcode>,
223    variables: Vec<Box<dyn Variable>>,
224}
225
226impl CompiledExpression {
227    pub fn bind(&mut self, metadata: &DatasetMetadata) -> LadduResult<()> {
228        for variable in &mut self.variables {
229            variable.bind(metadata)?;
230        }
231        Ok(())
232    }
233
234    /// Evaluate the [`CompiledExpression`] on a given named event view.
235    pub fn evaluate(&self, event: &NamedEventView<'_>) -> bool {
236        let mut stack = Vec::with_capacity(self.bytecode.len());
237
238        for op in &self.bytecode {
239            match op {
240                Opcode::PushEq(i, val) => stack.push(self.variables[*i].value(event) == *val),
241                Opcode::PushLt(i, val) => stack.push(self.variables[*i].value(event) < *val),
242                Opcode::PushGt(i, val) => stack.push(self.variables[*i].value(event) > *val),
243                Opcode::Not => {
244                    let a = stack.pop().unwrap();
245                    stack.push(!a);
246                }
247                Opcode::And => {
248                    let b = stack.pop().unwrap();
249                    let a = stack.pop().unwrap();
250                    stack.push(a && b);
251                }
252                Opcode::Or => {
253                    let b = stack.pop().unwrap();
254                    let a = stack.pop().unwrap();
255                    stack.push(a || b);
256                }
257            }
258        }
259
260        stack.pop().unwrap()
261    }
262}
263
264pub(crate) fn compile_expression(expr: VariableExpression) -> CompiledExpression {
265    let mut bytecode = Vec::new();
266    let mut variables: Vec<Box<dyn Variable>> = Vec::new();
267
268    fn compile(
269        expr: VariableExpression,
270        bytecode: &mut Vec<Opcode>,
271        variables: &mut Vec<Box<dyn Variable>>,
272    ) {
273        match expr {
274            VariableExpression::Eq(var, val) => {
275                variables.push(var);
276                bytecode.push(Opcode::PushEq(variables.len() - 1, val));
277            }
278            VariableExpression::Lt(var, val) => {
279                variables.push(var);
280                bytecode.push(Opcode::PushLt(variables.len() - 1, val));
281            }
282            VariableExpression::Gt(var, val) => {
283                variables.push(var);
284                bytecode.push(Opcode::PushGt(variables.len() - 1, val));
285            }
286            VariableExpression::And(lhs, rhs) => {
287                compile(*lhs, bytecode, variables);
288                compile(*rhs, bytecode, variables);
289                bytecode.push(Opcode::And);
290            }
291            VariableExpression::Or(lhs, rhs) => {
292                compile(*lhs, bytecode, variables);
293                compile(*rhs, bytecode, variables);
294                bytecode.push(Opcode::Or);
295            }
296            VariableExpression::Not(inner) => {
297                compile(*inner, bytecode, variables);
298                bytecode.push(Opcode::Not);
299            }
300        }
301    }
302
303    compile(expr, &mut bytecode, &mut variables);
304
305    CompiledExpression {
306        bytecode,
307        variables,
308    }
309}
310
/// Render names as a single comma-separated list (`"a, b, c"`); an empty slice yields `""`.
fn names_to_string(names: &[String]) -> String {
    let mut rendered = String::new();
    for (position, name) in names.iter().enumerate() {
        if position > 0 {
            rendered.push_str(", ");
        }
        rendered.push_str(name);
    }
    rendered
}
314
315/// A reusable selection that may span one or more four-momentum names.
316///
317/// Instances are constructed from metadata-facing identifiers and later bound to
318/// column indices so that variable evaluators can resolve aliases or grouped
319/// particles efficiently.
320#[derive(Clone, Debug, Serialize, Deserialize)]
321pub struct P4Selection {
322    names: Vec<String>,
323    #[serde(skip, default)]
324    indices: Vec<usize>,
325}
326
327impl P4Selection {
328    fn new_many<I, S>(names: I) -> Self
329    where
330        I: IntoIterator<Item = S>,
331        S: Into<String>,
332    {
333        Self {
334            names: names.into_iter().map(Into::into).collect(),
335            indices: Vec::new(),
336        }
337    }
338
339    pub(crate) fn with_indices<I, S>(names: I, indices: Vec<usize>) -> Self
340    where
341        I: IntoIterator<Item = S>,
342        S: Into<String>,
343    {
344        Self {
345            names: names.into_iter().map(Into::into).collect(),
346            indices,
347        }
348    }
349
350    /// Returns the metadata names contributing to this selection.
351    pub fn names(&self) -> &[String] {
352        &self.names
353    }
354
355    pub(crate) fn bind(&mut self, metadata: &DatasetMetadata) -> LadduResult<()> {
356        let mut resolved = Vec::with_capacity(self.names.len());
357        for name in &self.names {
358            metadata.append_indices_for_name(name, &mut resolved)?;
359        }
360        self.indices = resolved;
361        Ok(())
362    }
363
364    /// The resolved column indices backing this selection.
365    pub fn indices(&self) -> &[usize] {
366        &self.indices
367    }
368
369    pub(crate) fn momentum(&self, event: &EventData) -> Vec4 {
370        event.get_p4_sum(self.indices())
371    }
372}
373
374/// Helper trait to convert common particle specifications into [`P4Selection`] instances.
375pub trait IntoP4Selection {
376    /// Convert the input into a [`P4Selection`].
377    fn into_selection(self) -> P4Selection;
378}
379
380impl IntoP4Selection for P4Selection {
381    fn into_selection(self) -> P4Selection {
382        self
383    }
384}
385
386impl IntoP4Selection for &P4Selection {
387    fn into_selection(self) -> P4Selection {
388        self.clone()
389    }
390}
391
392impl IntoP4Selection for String {
393    fn into_selection(self) -> P4Selection {
394        P4Selection::new_many(vec![self])
395    }
396}
397
398impl IntoP4Selection for &String {
399    fn into_selection(self) -> P4Selection {
400        P4Selection::new_many(vec![self.clone()])
401    }
402}
403
404impl IntoP4Selection for &str {
405    fn into_selection(self) -> P4Selection {
406        P4Selection::new_many(vec![self.to_string()])
407    }
408}
409
410impl<S> IntoP4Selection for Vec<S>
411where
412    S: Into<String>,
413{
414    fn into_selection(self) -> P4Selection {
415        P4Selection::new_many(self.into_iter().map(Into::into).collect::<Vec<_>>())
416    }
417}
418
419impl<S> IntoP4Selection for &[S]
420where
421    S: Clone + Into<String>,
422{
423    fn into_selection(self) -> P4Selection {
424        P4Selection::new_many(self.iter().cloned().map(Into::into).collect::<Vec<_>>())
425    }
426}
427
428impl<S, const N: usize> IntoP4Selection for [S; N]
429where
430    S: Into<String>,
431{
432    fn into_selection(self) -> P4Selection {
433        P4Selection::new_many(self.into_iter().map(Into::into).collect::<Vec<_>>())
434    }
435}
436
437impl<S, const N: usize> IntoP4Selection for &[S; N]
438where
439    S: Clone + Into<String>,
440{
441    fn into_selection(self) -> P4Selection {
442        P4Selection::new_many(self.iter().cloned().map(Into::into).collect::<Vec<_>>())
443    }
444}
445
446/// A reusable 2-to-2 reaction description shared by several kinematic variables.
447///
448/// A topology records the four canonical vertices $`k_1 + k_2 \to k_3 + k_4`$.
449/// When one vertex is omitted, it is reconstructed by enforcing four-momentum
450/// conservation, which is unambiguous in that frame. Use [`Topology::com_boost_vector`]
451/// and the `*_com` helpers to access particles in the center-of-momentum frame.
452///
453/// ```text
454/// k1  k3
455///  ╲  ╱
456///   â•­â•®
457///   ╰╯
458///  ╱  ╲
459/// k2  k4
460/// ```
461///
462/// Note that variables are typically designed to use $`k_1`$ as the incoming beam, $`k_2`$ as a
463/// target, $`k_3`$ as some resonance, and $`k_4`$ as the recoiling target particle, but this
464/// notation should be extensible to any 2-to-2 reaction.
465#[derive(Clone, Debug, Serialize, Deserialize)]
466pub enum Topology {
467    /// All four vertices are explicitly provided.
468    Full {
469        /// First incoming vertex.
470        k1: P4Selection,
471        /// Second incoming vertex.
472        k2: P4Selection,
473        /// First outgoing vertex.
474        k3: P4Selection,
475        /// Second outgoing vertex.
476        k4: P4Selection,
477    },
478    /// The first incoming vertex (`k1`) is reconstructed.
479    MissingK1 {
480        /// Second incoming vertex.
481        k2: P4Selection,
482        /// First outgoing vertex.
483        k3: P4Selection,
484        /// Second outgoing vertex.
485        k4: P4Selection,
486    },
487    /// The second incoming vertex (`k2`) is reconstructed.
488    MissingK2 {
489        /// First incoming vertex.
490        k1: P4Selection,
491        /// First outgoing vertex.
492        k3: P4Selection,
493        /// Second outgoing vertex.
494        k4: P4Selection,
495    },
496    /// The first outgoing vertex (`k3`) is reconstructed.
497    MissingK3 {
498        /// First incoming vertex.
499        k1: P4Selection,
500        /// Second incoming vertex.
501        k2: P4Selection,
502        /// Second outgoing vertex.
503        k4: P4Selection,
504    },
505    /// The second outgoing vertex (`k4`) is reconstructed.
506    MissingK4 {
507        /// First incoming vertex.
508        k1: P4Selection,
509        /// Second incoming vertex.
510        k2: P4Selection,
511        /// First outgoing vertex.
512        k3: P4Selection,
513    },
514}
515
516impl Topology {
517    /// Construct a topology with all four vertices explicitly defined.
518    pub fn new<K1, K2, K3, K4>(k1: K1, k2: K2, k3: K3, k4: K4) -> Self
519    where
520        K1: IntoP4Selection,
521        K2: IntoP4Selection,
522        K3: IntoP4Selection,
523        K4: IntoP4Selection,
524    {
525        Self::Full {
526            k1: k1.into_selection(),
527            k2: k2.into_selection(),
528            k3: k3.into_selection(),
529            k4: k4.into_selection(),
530        }
531    }
532
533    /// Construct a topology when the first incoming vertex (`k1`) is omitted.
534    pub fn missing_k1<K2, K3, K4>(k2: K2, k3: K3, k4: K4) -> Self
535    where
536        K2: IntoP4Selection,
537        K3: IntoP4Selection,
538        K4: IntoP4Selection,
539    {
540        Self::MissingK1 {
541            k2: k2.into_selection(),
542            k3: k3.into_selection(),
543            k4: k4.into_selection(),
544        }
545    }
546
547    /// Construct a topology when the second incoming vertex (`k2`) is omitted.
548    pub fn missing_k2<K1, K3, K4>(k1: K1, k3: K3, k4: K4) -> Self
549    where
550        K1: IntoP4Selection,
551        K3: IntoP4Selection,
552        K4: IntoP4Selection,
553    {
554        Self::MissingK2 {
555            k1: k1.into_selection(),
556            k3: k3.into_selection(),
557            k4: k4.into_selection(),
558        }
559    }
560
561    /// Construct a topology when the first outgoing vertex (`k3`) is omitted.
562    pub fn missing_k3<K1, K2, K4>(k1: K1, k2: K2, k4: K4) -> Self
563    where
564        K1: IntoP4Selection,
565        K2: IntoP4Selection,
566        K4: IntoP4Selection,
567    {
568        Self::MissingK3 {
569            k1: k1.into_selection(),
570            k2: k2.into_selection(),
571            k4: k4.into_selection(),
572        }
573    }
574
575    /// Construct a topology when the second outgoing vertex (`k4`) is omitted.
576    pub fn missing_k4<K1, K2, K3>(k1: K1, k2: K2, k3: K3) -> Self
577    where
578        K1: IntoP4Selection,
579        K2: IntoP4Selection,
580        K3: IntoP4Selection,
581    {
582        Self::MissingK4 {
583            k1: k1.into_selection(),
584            k2: k2.into_selection(),
585            k3: k3.into_selection(),
586        }
587    }
588
589    /// Bind every vertex to dataset metadata so the particle names resolve to indices.
590    pub fn bind(&mut self, metadata: &DatasetMetadata) -> LadduResult<()> {
591        match self {
592            Topology::Full { k1, k2, k3, k4 } => {
593                k1.bind(metadata)?;
594                k2.bind(metadata)?;
595                k3.bind(metadata)?;
596                k4.bind(metadata)?;
597            }
598            Topology::MissingK1 { k2, k3, k4 } => {
599                k2.bind(metadata)?;
600                k3.bind(metadata)?;
601                k4.bind(metadata)?;
602            }
603            Topology::MissingK2 { k1, k3, k4 } => {
604                k1.bind(metadata)?;
605                k3.bind(metadata)?;
606                k4.bind(metadata)?;
607            }
608            Topology::MissingK3 { k1, k2, k4 } => {
609                k1.bind(metadata)?;
610                k2.bind(metadata)?;
611                k4.bind(metadata)?;
612            }
613            Topology::MissingK4 { k1, k2, k3 } => {
614                k1.bind(metadata)?;
615                k2.bind(metadata)?;
616                k3.bind(metadata)?;
617            }
618        }
619        Ok(())
620    }
621
622    /// Return the velocity vector that boosts lab-frame momenta into the diagram's
623    /// center-of-momentum frame.
624    pub fn com_boost_vector(&self, event: &EventData) -> Vec3 {
625        match self {
626            Topology::Full { k3, k4, .. }
627            | Topology::MissingK1 { k3, k4, .. }
628            | Topology::MissingK2 { k3, k4, .. } => {
629                -(k3.momentum(event) + k4.momentum(event)).beta()
630            }
631            Topology::MissingK3 { k1, k2, .. } | Topology::MissingK4 { k1, k2, .. } => {
632                -(k1.momentum(event) + k2.momentum(event)).beta()
633            }
634        }
635    }
636
637    /// Convenience helper returning the beam four-momentum (`k1`).
638    pub fn k1(&self, event: &EventData) -> Vec4 {
639        match self {
640            Topology::Full { k1, .. }
641            | Topology::MissingK2 { k1, .. }
642            | Topology::MissingK3 { k1, .. }
643            | Topology::MissingK4 { k1, .. } => k1.momentum(event),
644            Topology::MissingK1 { k2, k3, k4 } => {
645                k3.momentum(event) + k4.momentum(event) - k2.momentum(event)
646            }
647        }
648    }
649
650    /// Convenience helper returning the target four-momentum (`k2`).
651    pub fn k2(&self, event: &EventData) -> Vec4 {
652        match self {
653            Topology::Full { k2, .. }
654            | Topology::MissingK1 { k2, .. }
655            | Topology::MissingK3 { k2, .. }
656            | Topology::MissingK4 { k2, .. } => k2.momentum(event),
657            Topology::MissingK2 { k1, k3, k4 } => {
658                k3.momentum(event) + k4.momentum(event) - k1.momentum(event)
659            }
660        }
661    }
662
663    /// Convenience helper returning the resonance four-momentum (`k3`).
664    pub fn k3(&self, event: &EventData) -> Vec4 {
665        match self {
666            Topology::Full { k3, .. }
667            | Topology::MissingK1 { k3, .. }
668            | Topology::MissingK2 { k3, .. }
669            | Topology::MissingK4 { k3, .. } => k3.momentum(event),
670            Topology::MissingK3 { k1, k2, k4 } => {
671                k1.momentum(event) + k2.momentum(event) - k4.momentum(event)
672            }
673        }
674    }
675
676    /// Convenience helper returning the recoil four-momentum (`k4`).
677    pub fn k4(&self, event: &EventData) -> Vec4 {
678        match self {
679            Topology::Full { k4, .. }
680            | Topology::MissingK1 { k4, .. }
681            | Topology::MissingK2 { k4, .. }
682            | Topology::MissingK3 { k4, .. } => k4.momentum(event),
683            Topology::MissingK4 { k1, k2, k3 } => {
684                k1.momentum(event) + k2.momentum(event) - k3.momentum(event)
685            }
686        }
687    }
688
689    /// Beam four-momentum (`k1`) expressed in the center-of-momentum frame.
690    pub fn k1_com(&self, event: &EventData) -> Vec4 {
691        self.k1(event).boost(&self.com_boost_vector(event))
692    }
693
694    /// Target four-momentum (`k2`) expressed in the center-of-momentum frame.
695    pub fn k2_com(&self, event: &EventData) -> Vec4 {
696        self.k2(event).boost(&self.com_boost_vector(event))
697    }
698
699    /// Resonance four-momentum (`k3`) expressed in the center-of-momentum frame.
700    pub fn k3_com(&self, event: &EventData) -> Vec4 {
701        self.k3(event).boost(&self.com_boost_vector(event))
702    }
703
704    /// Recoil four-momentum (`k4`) expressed in the center-of-momentum frame.
705    pub fn k4_com(&self, event: &EventData) -> Vec4 {
706        self.k4(event).boost(&self.com_boost_vector(event))
707    }
708
709    /// Returns the resolved names for `k1` if it was explicitly provided.
710    pub fn k1_names(&self) -> Option<&[String]> {
711        match self {
712            Topology::Full { k1, .. }
713            | Topology::MissingK2 { k1, .. }
714            | Topology::MissingK3 { k1, .. }
715            | Topology::MissingK4 { k1, .. } => Some(k1.names()),
716            Topology::MissingK1 { .. } => None,
717        }
718    }
719
720    /// Returns the resolved names for `k2` if it was explicitly provided.
721    pub fn k2_names(&self) -> Option<&[String]> {
722        match self {
723            Topology::Full { k2, .. }
724            | Topology::MissingK1 { k2, .. }
725            | Topology::MissingK3 { k2, .. }
726            | Topology::MissingK4 { k2, .. } => Some(k2.names()),
727            Topology::MissingK2 { .. } => None,
728        }
729    }
730
731    /// Returns the resolved names for `k3` if it was explicitly provided.
732    pub fn k3_names(&self) -> Option<&[String]> {
733        match self {
734            Topology::Full { k3, .. }
735            | Topology::MissingK1 { k3, .. }
736            | Topology::MissingK2 { k3, .. }
737            | Topology::MissingK4 { k3, .. } => Some(k3.names()),
738            Topology::MissingK3 { .. } => None,
739        }
740    }
741
742    /// Returns the resolved names for `k4` if it was explicitly provided.
743    pub fn k4_names(&self) -> Option<&[String]> {
744        match self {
745            Topology::Full { k4, .. }
746            | Topology::MissingK1 { k4, .. }
747            | Topology::MissingK2 { k4, .. }
748            | Topology::MissingK3 { k4, .. } => Some(k4.names()),
749            Topology::MissingK4 { .. } => None,
750        }
751    }
752}
753
754impl Display for Topology {
755    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
756        write!(
757            f,
758            "Topology(k1=[{}], k2=[{}], k3=[{}], k4=[{}])",
759            format_topology_names(self.k1_names()),
760            format_topology_names(self.k2_names()),
761            format_topology_names(self.k3_names()),
762            format_topology_names(self.k4_names())
763        )
764    }
765}
766
/// Render one vertex's names for [`Topology`]'s `Display` output.
///
/// An explicit vertex becomes a comma-separated list (empty if no names were given). A
/// reconstructed vertex (`None`) becomes `<reconstructed>`.
fn format_topology_names(names: Option<&[String]>) -> String {
    names.map_or_else(|| "<reconstructed>".to_string(), |names| names.join(", "))
}
774
775#[derive(Clone, Debug, Serialize, Deserialize)]
776struct AuxSelection {
777    name: String,
778    #[serde(skip, default)]
779    index: Option<usize>,
780}
781
782impl AuxSelection {
783    fn new<S: Into<String>>(name: S) -> Self {
784        Self {
785            name: name.into(),
786            index: None,
787        }
788    }
789
790    fn bind(&mut self, metadata: &DatasetMetadata) -> LadduResult<()> {
791        let idx = metadata
792            .aux_index(&self.name)
793            .ok_or_else(|| LadduError::UnknownName {
794                category: "aux",
795                name: self.name.clone(),
796            })?;
797        self.index = Some(idx);
798        Ok(())
799    }
800
801    fn index(&self) -> usize {
802        self.index.expect("AuxSelection must be bound before use")
803    }
804
805    fn name(&self) -> &str {
806        &self.name
807    }
808}
809
810/// A struct for obtaining the mass of a particle by indexing the four-momenta of an event, adding
811/// together multiple four-momenta if more than one entry is given.
812#[derive(Clone, Debug, Serialize, Deserialize)]
813pub struct Mass {
814    constituents: P4Selection,
815}
816impl Mass {
817    /// Create a new [`Mass`] from the sum of the four-momenta identified by `constituents` in the
818    /// [`EventData`]'s `p4s` field.
819    pub fn new<C>(constituents: C) -> Self
820    where
821        C: IntoP4Selection,
822    {
823        Self {
824            constituents: constituents.into_selection(),
825        }
826    }
827}
828impl Display for Mass {
829    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
830        write!(
831            f,
832            "Mass(constituents=[{}])",
833            names_to_string(self.constituents.names())
834        )
835    }
836}
837#[typetag::serde]
838impl Variable for Mass {
839    fn bind(&mut self, metadata: &DatasetMetadata) -> LadduResult<()> {
840        self.constituents.bind(metadata)
841    }
842    fn value(&self, event: &NamedEventView<'_>) -> f64 {
843        self.constituents
844            .indices()
845            .iter()
846            .map(|index| event.p4_at(*index))
847            .sum::<Vec4>()
848            .m()
849    }
850}
851
852/// A struct for obtaining the $`\cos\theta`$ (cosine of the polar angle) of a decay product in
853/// a given reference frame of its parent resonance.
854#[derive(Clone, Debug, Serialize, Deserialize)]
855pub struct CosTheta {
856    topology: Topology,
857    daughter: P4Selection,
858    frame: Frame,
859}
860impl Display for CosTheta {
861    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
862        write!(
863            f,
864            "CosTheta(topology={}, daughter=[{}], frame={})",
865            self.topology,
866            names_to_string(self.daughter.names()),
867            self.frame
868        )
869    }
870}
871impl CosTheta {
872    /// Construct the angle given a [`Topology`] describing the production kinematics along with a
873    /// decay daughter of the `k3` resonance. See [`Frame`] for options regarding the reference
874    /// frame.
875    pub fn new<D>(topology: Topology, daughter: D, frame: Frame) -> Self
876    where
877        D: IntoP4Selection,
878    {
879        Self {
880            topology,
881            daughter: daughter.into_selection(),
882            frame,
883        }
884    }
885}
886
887#[typetag::serde]
888impl Variable for CosTheta {
889    fn bind(&mut self, metadata: &DatasetMetadata) -> LadduResult<()> {
890        self.topology.bind(metadata)?;
891        self.daughter.bind(metadata)?;
892        Ok(())
893    }
894    fn value(&self, event: &NamedEventView<'_>) -> f64 {
895        let p4_sum = |indices: &[usize]| {
896            indices
897                .iter()
898                .map(|index| event.p4_at(*index))
899                .sum::<Vec4>()
900        };
901        let k1 = match &self.topology {
902            Topology::Full { k1, .. }
903            | Topology::MissingK2 { k1, .. }
904            | Topology::MissingK3 { k1, .. }
905            | Topology::MissingK4 { k1, .. } => p4_sum(k1.indices()),
906            Topology::MissingK1 { k2, k3, k4 } => {
907                p4_sum(k3.indices()) + p4_sum(k4.indices()) - p4_sum(k2.indices())
908            }
909        };
910        let k3 = match &self.topology {
911            Topology::Full { k3, .. }
912            | Topology::MissingK1 { k3, .. }
913            | Topology::MissingK2 { k3, .. }
914            | Topology::MissingK4 { k3, .. } => p4_sum(k3.indices()),
915            Topology::MissingK3 { k1, k2, k4 } => {
916                p4_sum(k1.indices()) + p4_sum(k2.indices()) - p4_sum(k4.indices())
917            }
918        };
919        let k4 = match &self.topology {
920            Topology::Full { k4, .. }
921            | Topology::MissingK1 { k4, .. }
922            | Topology::MissingK2 { k4, .. }
923            | Topology::MissingK3 { k4, .. } => p4_sum(k4.indices()),
924            Topology::MissingK4 { k1, k2, k3 } => {
925                p4_sum(k1.indices()) + p4_sum(k2.indices()) - p4_sum(k3.indices())
926            }
927        };
928        let com_boost = match &self.topology {
929            Topology::Full { k3, k4, .. }
930            | Topology::MissingK1 { k3, k4, .. }
931            | Topology::MissingK2 { k3, k4, .. } => {
932                -(p4_sum(k3.indices()) + p4_sum(k4.indices())).beta()
933            }
934            Topology::MissingK3 { k1, k2, .. } | Topology::MissingK4 { k1, k2, .. } => {
935                -(p4_sum(k1.indices()) + p4_sum(k2.indices())).beta()
936            }
937        };
938        let beam = k1.boost(&com_boost);
939        let resonance = k3.boost(&com_boost);
940        let daughter = p4_sum(self.daughter.indices()).boost(&com_boost);
941        let daughter_res = daughter.boost(&-resonance.beta());
942        let plane_normal = beam.vec3().cross(&resonance.vec3()).unit();
943        let z = match self.frame {
944            Frame::Helicity => {
945                let recoil_res = k4.boost(&com_boost).boost(&-resonance.beta());
946                (-recoil_res.vec3()).unit()
947            }
948            Frame::GottfriedJackson => beam.boost(&-resonance.beta()).vec3().unit(),
949        };
950        let x = plane_normal.cross(&z).unit();
951        let y = z.cross(&x).unit();
952        Vec3::new(
953            daughter_res.vec3().dot(&x),
954            daughter_res.vec3().dot(&y),
955            daughter_res.vec3().dot(&z),
956        )
957        .costheta()
958    }
959}
960
/// A struct for obtaining the $`\phi`$ angle (azimuthal angle) of a decay product in a given
/// reference frame of its parent resonance.
///
/// The parent resonance is the `k3` vertex of the [`Topology`]. The daughter selection is summed
/// into a single four-momentum and boosted into the resonance rest frame before the angle is
/// taken.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Phi {
    /// Production kinematics; `k3` is treated as the decaying resonance.
    topology: Topology,
    /// Four-momenta summed to form the decay product whose angle is measured.
    daughter: P4Selection,
    /// Reference frame ([`Frame::Helicity`] or [`Frame::GottfriedJackson`]) defining the `z` axis.
    frame: Frame,
}
969impl Display for Phi {
970    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
971        write!(
972            f,
973            "Phi(topology={}, daughter=[{}], frame={})",
974            self.topology,
975            names_to_string(self.daughter.names()),
976            self.frame
977        )
978    }
979}
980impl Phi {
981    /// Construct the angle given a [`Topology`] describing the production kinematics along with a
982    /// daughter of the resonance defined by `k3`. See [`Frame`] for options regarding the
983    /// reference frame.
984    pub fn new<D>(topology: Topology, daughter: D, frame: Frame) -> Self
985    where
986        D: IntoP4Selection,
987    {
988        Self {
989            topology,
990            daughter: daughter.into_selection(),
991            frame,
992        }
993    }
994}
995#[typetag::serde]
996impl Variable for Phi {
997    fn bind(&mut self, metadata: &DatasetMetadata) -> LadduResult<()> {
998        self.topology.bind(metadata)?;
999        self.daughter.bind(metadata)?;
1000        Ok(())
1001    }
1002    fn value(&self, event: &NamedEventView<'_>) -> f64 {
1003        let p4_sum = |indices: &[usize]| {
1004            indices
1005                .iter()
1006                .map(|index| event.p4_at(*index))
1007                .sum::<Vec4>()
1008        };
1009        let k1 = match &self.topology {
1010            Topology::Full { k1, .. }
1011            | Topology::MissingK2 { k1, .. }
1012            | Topology::MissingK3 { k1, .. }
1013            | Topology::MissingK4 { k1, .. } => p4_sum(k1.indices()),
1014            Topology::MissingK1 { k2, k3, k4 } => {
1015                p4_sum(k3.indices()) + p4_sum(k4.indices()) - p4_sum(k2.indices())
1016            }
1017        };
1018        let k3 = match &self.topology {
1019            Topology::Full { k3, .. }
1020            | Topology::MissingK1 { k3, .. }
1021            | Topology::MissingK2 { k3, .. }
1022            | Topology::MissingK4 { k3, .. } => p4_sum(k3.indices()),
1023            Topology::MissingK3 { k1, k2, k4 } => {
1024                p4_sum(k1.indices()) + p4_sum(k2.indices()) - p4_sum(k4.indices())
1025            }
1026        };
1027        let k4 = match &self.topology {
1028            Topology::Full { k4, .. }
1029            | Topology::MissingK1 { k4, .. }
1030            | Topology::MissingK2 { k4, .. }
1031            | Topology::MissingK3 { k4, .. } => p4_sum(k4.indices()),
1032            Topology::MissingK4 { k1, k2, k3 } => {
1033                p4_sum(k1.indices()) + p4_sum(k2.indices()) - p4_sum(k3.indices())
1034            }
1035        };
1036        let com_boost = match &self.topology {
1037            Topology::Full { k3, k4, .. }
1038            | Topology::MissingK1 { k3, k4, .. }
1039            | Topology::MissingK2 { k3, k4, .. } => {
1040                -(p4_sum(k3.indices()) + p4_sum(k4.indices())).beta()
1041            }
1042            Topology::MissingK3 { k1, k2, .. } | Topology::MissingK4 { k1, k2, .. } => {
1043                -(p4_sum(k1.indices()) + p4_sum(k2.indices())).beta()
1044            }
1045        };
1046        let beam = k1.boost(&com_boost);
1047        let resonance = k3.boost(&com_boost);
1048        let daughter = p4_sum(self.daughter.indices()).boost(&com_boost);
1049        let daughter_res = daughter.boost(&-resonance.beta());
1050        let plane_normal = beam.vec3().cross(&resonance.vec3()).unit();
1051        let z = match self.frame {
1052            Frame::Helicity => {
1053                let recoil_res = k4.boost(&com_boost).boost(&-resonance.beta());
1054                (-recoil_res.vec3()).unit()
1055            }
1056            Frame::GottfriedJackson => beam.boost(&-resonance.beta()).vec3().unit(),
1057        };
1058        let x = plane_normal.cross(&z).unit();
1059        let y = z.cross(&x).unit();
1060        Vec3::new(
1061            daughter_res.vec3().dot(&x),
1062            daughter_res.vec3().dot(&y),
1063            daughter_res.vec3().dot(&z),
1064        )
1065        .phi()
1066    }
1067}
1068
/// A struct for obtaining both spherical angles at the same time.
///
/// [`Angles::new`] builds both members from the same topology, daughter selection, and frame.
/// The [`Display`] output reports only the configuration stored in `costheta`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Angles {
    /// See [`CosTheta`].
    pub costheta: CosTheta,
    /// See [`Phi`].
    pub phi: Phi,
}
1077
1078impl Display for Angles {
1079    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1080        write!(
1081            f,
1082            "Angles(topology={}, daughter=[{}], frame={})",
1083            self.costheta.topology,
1084            names_to_string(self.costheta.daughter.names()),
1085            self.costheta.frame
1086        )
1087    }
1088}
1089impl Angles {
1090    /// Construct the angles given a [`Topology`] along with the daughter selection.
1091    /// See [`Frame`] for options regarding the reference frame.
1092    pub fn new<D>(topology: Topology, daughter: D, frame: Frame) -> Self
1093    where
1094        D: IntoP4Selection,
1095    {
1096        let daughter_vertex = daughter.into_selection();
1097        let costheta = CosTheta::new(topology.clone(), daughter_vertex.clone(), frame);
1098        let phi = Phi::new(topology, daughter_vertex, frame);
1099        Self { costheta, phi }
1100    }
1101}
1102
/// A struct defining the polarization angle for a beam relative to the production plane.
///
/// The production plane is spanned by the beam (`k1`) and recoil (`k4`) vertices of the
/// [`Topology`]. Momenta are used as stored in the event; no boost is applied.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PolAngle {
    /// Supplies the beam (`k1`) and recoil (`k4`) momenta; either may be reconstructed.
    topology: Topology,
    /// Auxiliary column with the polarization direction angle. It is passed to `f64::cos` and
    /// `f64::sin`, so it is interpreted in radians.
    angle_aux: AuxSelection,
}
1109impl Display for PolAngle {
1110    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1111        write!(
1112            f,
1113            "PolAngle(topology={}, angle_aux={})",
1114            self.topology,
1115            self.angle_aux.name(),
1116        )
1117    }
1118}
1119impl PolAngle {
1120    /// Constructs the polarization angle given a [`Topology`] describing the production plane and
1121    /// the auxiliary column storing the precomputed angle.
1122    pub fn new<A>(topology: Topology, angle_aux: A) -> Self
1123    where
1124        A: Into<String>,
1125    {
1126        Self {
1127            topology,
1128            angle_aux: AuxSelection::new(angle_aux.into()),
1129        }
1130    }
1131}
1132#[typetag::serde]
1133impl Variable for PolAngle {
1134    fn bind(&mut self, metadata: &DatasetMetadata) -> LadduResult<()> {
1135        self.topology.bind(metadata)?;
1136        self.angle_aux.bind(metadata)?;
1137        Ok(())
1138    }
1139    fn value(&self, event: &NamedEventView<'_>) -> f64 {
1140        let p4_sum = |indices: &[usize]| {
1141            indices
1142                .iter()
1143                .map(|index| event.p4_at(*index))
1144                .sum::<Vec4>()
1145        };
1146        let beam = match &self.topology {
1147            Topology::Full { k1, .. }
1148            | Topology::MissingK2 { k1, .. }
1149            | Topology::MissingK3 { k1, .. }
1150            | Topology::MissingK4 { k1, .. } => p4_sum(k1.indices()),
1151            Topology::MissingK1 { k2, k3, k4 } => {
1152                p4_sum(k3.indices()) + p4_sum(k4.indices()) - p4_sum(k2.indices())
1153            }
1154        };
1155        let recoil = match &self.topology {
1156            Topology::Full { k4, .. }
1157            | Topology::MissingK1 { k4, .. }
1158            | Topology::MissingK2 { k4, .. }
1159            | Topology::MissingK3 { k4, .. } => p4_sum(k4.indices()),
1160            Topology::MissingK4 { k1, k2, k3 } => {
1161                p4_sum(k1.indices()) + p4_sum(k2.indices()) - p4_sum(k3.indices())
1162            }
1163        };
1164        let pol_angle = event.aux_at(self.angle_aux.index());
1165        let polarization = Vec3::new(pol_angle.cos(), pol_angle.sin(), 0.0);
1166        let y = beam.vec3().cross(&-recoil.vec3()).unit();
1167        let numerator = y.dot(&polarization);
1168        let denominator = beam.vec3().unit().dot(&polarization.cross(&y));
1169        f64::atan2(numerator, denominator)
1170    }
1171}
1172
/// A struct defining the polarization magnitude for a beam relative to the production plane.
///
/// The magnitude is not computed from kinematics; it is read directly from an auxiliary column.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PolMagnitude {
    /// Auxiliary column whose per-event value is returned unchanged by [`Variable::value`].
    magnitude_aux: AuxSelection,
}
1178impl Display for PolMagnitude {
1179    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1180        write!(
1181            f,
1182            "PolMagnitude(magnitude_aux={})",
1183            self.magnitude_aux.name(),
1184        )
1185    }
1186}
1187impl PolMagnitude {
1188    /// Constructs the polarization magnitude given the named auxiliary column containing the
1189    /// magnitude value.
1190    pub fn new<S: Into<String>>(magnitude_aux: S) -> Self {
1191        Self {
1192            magnitude_aux: AuxSelection::new(magnitude_aux.into()),
1193        }
1194    }
1195}
1196#[typetag::serde]
1197impl Variable for PolMagnitude {
1198    fn bind(&mut self, metadata: &DatasetMetadata) -> LadduResult<()> {
1199        self.magnitude_aux.bind(metadata)
1200    }
1201    fn value(&self, event: &NamedEventView<'_>) -> f64 {
1202        event.aux_at(self.magnitude_aux.index())
1203    }
1204}
1205
/// A struct for obtaining both the polarization angle and magnitude at the same time.
///
/// [`Polarization::new`] requires the two members to read from distinct auxiliary columns.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Polarization {
    /// See [`PolMagnitude`].
    pub pol_magnitude: PolMagnitude,
    /// See [`PolAngle`].
    pub pol_angle: PolAngle,
}
1214impl Display for Polarization {
1215    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1216        write!(
1217            f,
1218            "Polarization(topology={}, magnitude_aux={}, angle_aux={})",
1219            self.pol_angle.topology,
1220            self.pol_magnitude.magnitude_aux.name(),
1221            self.pol_angle.angle_aux.name(),
1222        )
1223    }
1224}
1225impl Polarization {
1226    /// Constructs the polarization angle and magnitude given a [`Topology`] and distinct
1227    /// auxiliary columns for magnitude and angle.
1228    ///
1229    /// # Panics
1230    ///
1231    /// Panics if `magnitude_aux` and `angle_aux` refer to the same auxiliary column name.
1232    pub fn new<M, A>(topology: Topology, magnitude_aux: M, angle_aux: A) -> Self
1233    where
1234        M: Into<String>,
1235        A: Into<String>,
1236    {
1237        let magnitude_aux = magnitude_aux.into();
1238        let angle_aux = angle_aux.into();
1239        assert!(
1240            magnitude_aux != angle_aux,
1241            "Polarization magnitude and angle must reference distinct auxiliary columns"
1242        );
1243        Self {
1244            pol_magnitude: PolMagnitude::new(magnitude_aux),
1245            pol_angle: PolAngle::new(topology, angle_aux),
1246        }
1247    }
1248}
1249
/// A struct used to calculate Mandelstam variables ($`s`$, $`t`$, or $`u`$).
///
/// By convention, the metric is chosen to be $`(+---)`$ and the variables are defined as follows
/// (ignoring factors of $`c`$):
///
/// $`s = (p_1 + p_2)^2 = (p_3 + p_4)^2`$
///
/// $`t = (p_1 - p_3)^2 = (p_4 - p_2)^2`$
///
/// $`u = (p_1 - p_4)^2 = (p_3 - p_2)^2`$
///
/// Here $`p_i`$ is the four-momentum of vertex `ki` of the [`Topology`]. A vertex that the
/// topology marks as missing is reconstructed from the other three by four-momentum
/// conservation.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Mandelstam {
    /// Supplies the four vertex momenta `k1`..`k4`.
    topology: Topology,
    /// Selects which of $`s`$, $`t`$, or $`u`$ is computed.
    channel: Channel,
}
1265impl Display for Mandelstam {
1266    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1267        write!(
1268            f,
1269            "Mandelstam(topology={}, channel={})",
1270            self.topology, self.channel,
1271        )
1272    }
1273}
1274impl Mandelstam {
1275    /// Constructs the Mandelstam variable for the given `channel` using the supplied [`Topology`].
1276    pub fn new(topology: Topology, channel: Channel) -> Self {
1277        Self { topology, channel }
1278    }
1279}
1280
1281#[typetag::serde]
1282impl Variable for Mandelstam {
1283    fn bind(&mut self, metadata: &DatasetMetadata) -> LadduResult<()> {
1284        self.topology.bind(metadata)
1285    }
1286    fn value(&self, event: &NamedEventView<'_>) -> f64 {
1287        let p4_sum = |indices: &[usize]| {
1288            indices
1289                .iter()
1290                .map(|index| event.p4_at(*index))
1291                .sum::<Vec4>()
1292        };
1293        let k1 = match &self.topology {
1294            Topology::Full { k1, .. }
1295            | Topology::MissingK2 { k1, .. }
1296            | Topology::MissingK3 { k1, .. }
1297            | Topology::MissingK4 { k1, .. } => p4_sum(k1.indices()),
1298            Topology::MissingK1 { k2, k3, k4 } => {
1299                p4_sum(k3.indices()) + p4_sum(k4.indices()) - p4_sum(k2.indices())
1300            }
1301        };
1302        match self.channel {
1303            Channel::S => {
1304                let k2 = match &self.topology {
1305                    Topology::Full { k2, .. }
1306                    | Topology::MissingK1 { k2, .. }
1307                    | Topology::MissingK3 { k2, .. }
1308                    | Topology::MissingK4 { k2, .. } => p4_sum(k2.indices()),
1309                    Topology::MissingK2 { k1, k3, k4 } => {
1310                        p4_sum(k3.indices()) + p4_sum(k4.indices()) - p4_sum(k1.indices())
1311                    }
1312                };
1313                (k1 + k2).mag2()
1314            }
1315            Channel::T => {
1316                let k3 = match &self.topology {
1317                    Topology::Full { k3, .. }
1318                    | Topology::MissingK1 { k3, .. }
1319                    | Topology::MissingK2 { k3, .. }
1320                    | Topology::MissingK4 { k3, .. } => p4_sum(k3.indices()),
1321                    Topology::MissingK3 { k1, k2, k4 } => {
1322                        p4_sum(k1.indices()) + p4_sum(k2.indices()) - p4_sum(k4.indices())
1323                    }
1324                };
1325                (k1 - k3).mag2()
1326            }
1327            Channel::U => {
1328                let k4 = match &self.topology {
1329                    Topology::Full { k4, .. }
1330                    | Topology::MissingK1 { k4, .. }
1331                    | Topology::MissingK2 { k4, .. }
1332                    | Topology::MissingK3 { k4, .. } => p4_sum(k4.indices()),
1333                    Topology::MissingK4 { k1, k2, k3 } => {
1334                        p4_sum(k1.indices()) + p4_sum(k2.indices()) - p4_sum(k3.indices())
1335                    }
1336                };
1337                (k1 - k4).mag2()
1338            }
1339        }
1340    }
1341}
1342
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::{test_dataset, test_event};
    use approx::assert_relative_eq;

    // Metadata naming four p4 columns (and no aux columns), in the same order in which
    // `topology_test_event` stores its four-momenta.
    fn topology_test_metadata() -> DatasetMetadata {
        DatasetMetadata::new(
            vec!["beam", "target", "resonance", "recoil"],
            Vec::<String>::new(),
        )
        .expect("topology metadata should be valid")
    }

    // A unit-weight event whose fourth momentum is fixed by conservation (p4 = p1 + p2 - p3).
    // Any single missing vertex can therefore be reconstructed (up to rounding) from the other
    // three. p1 and p2 differ only in the sign of their third component.
    fn topology_test_event() -> EventData {
        let p1 = Vec4::new(0.0, 0.0, 3.0, 3.5);
        let p2 = Vec4::new(0.0, 0.0, -3.0, 3.5);
        let p3 = Vec4::new(0.5, -0.25, 1.0, 1.9);
        let p4 = p1 + p2 - p3;
        EventData {
            p4s: vec![p1, p2, p3, p4],
            aux: vec![],
            weight: 1.0,
        }
    }

    // Topology used with the `test_dataset` fixture: `beam` and a reconstructed k2 vertex
    // produce the (`kshort1`, `kshort2`) system plus a `proton` recoil.
    fn reaction_topology() -> Topology {
        Topology::missing_k2("beam", ["kshort1", "kshort2"], "proton")
    }

    // The `&["particle3"]` argument is borrowed on purpose: this test checks that borrowed
    // arrays are accepted. Clippy would otherwise flag it as a needless borrow.
    #[test]
    #[allow(clippy::needless_borrows_for_generic_args)]
    fn test_topology_accepts_varied_inputs() {
        let topo = Topology::new(
            "particle1",
            ["particle2a", "particle2b"],
            &["particle3"],
            "particle4".to_string(),
        );
        assert_eq!(
            topo.k1_names()
                .unwrap()
                .iter()
                .map(String::as_str)
                .collect::<Vec<_>>(),
            vec!["particle1"]
        );
        assert_eq!(
            topo.k2_names()
                .unwrap()
                .iter()
                .map(String::as_str)
                .collect::<Vec<_>>(),
            vec!["particle2a", "particle2b"]
        );
        let missing = Topology::missing_k2("particle1", vec!["particle3"], "particle4");
        assert!(missing.k2_names().is_none());
        assert!(missing.to_string().contains("<reconstructed>"));
    }

    // Each missing-vertex constructor must report no names for that vertex and must rebuild
    // the corresponding stored four-momentum of `topology_test_event`.
    #[test]
    fn test_topology_reconstructs_missing_vertices() {
        let metadata = topology_test_metadata();
        let event = topology_test_event();

        let mut full = Topology::new("beam", "target", "resonance", "recoil");
        full.bind(&metadata).unwrap();
        assert_relative_eq!(full.k3(&event), event.p4s[2], epsilon = 1e-12);

        let mut missing_k1 = Topology::missing_k1("target", "resonance", "recoil");
        missing_k1.bind(&metadata).unwrap();
        assert!(missing_k1.k1_names().is_none());
        assert_relative_eq!(missing_k1.k1(&event), event.p4s[0], epsilon = 1e-12);

        let mut missing_k2 = Topology::missing_k2("beam", "resonance", "recoil");
        missing_k2.bind(&metadata).unwrap();
        assert!(missing_k2.k2_names().is_none());
        assert_relative_eq!(missing_k2.k2(&event), event.p4s[1], epsilon = 1e-12);

        let mut missing_k3 = Topology::missing_k3("beam", "target", "recoil");
        missing_k3.bind(&metadata).unwrap();
        assert!(missing_k3.k3_names().is_none());
        assert_relative_eq!(missing_k3.k3(&event), event.p4s[2], epsilon = 1e-12);

        let mut missing_k4 = Topology::missing_k4("beam", "target", "resonance");
        missing_k4.bind(&metadata).unwrap();
        assert!(missing_k4.k4_names().is_none());
        assert_relative_eq!(missing_k4.k4(&event), event.p4s[3], epsilon = 1e-12);
    }

    // Each `k*_com` helper must equal the corresponding lab vector boosted by
    // `com_boost_vector`.
    #[test]
    fn test_topology_com_helpers_match_manual_boost() {
        let metadata = topology_test_metadata();
        let event = topology_test_event();
        let mut topo = Topology::new("beam", "target", "resonance", "recoil");
        topo.bind(&metadata).unwrap();
        let beta = topo.com_boost_vector(&event);
        assert_relative_eq!(topo.k1_com(&event), topo.k1(&event).boost(&beta));
        assert_relative_eq!(topo.k2_com(&event), topo.k2(&event).boost(&beta));
        assert_relative_eq!(topo.k3_com(&event), topo.k3(&event).boost(&beta));
        assert_relative_eq!(topo.k4_com(&event), topo.k4(&event).boost(&beta));
    }

    #[test]
    fn test_mass_single_particle() {
        let dataset = test_dataset();
        let mut mass = Mass::new("proton");
        mass.bind(dataset.metadata()).unwrap();
        let event = dataset.event_view(0);
        assert_relative_eq!(mass.value(&event), 1.007);
    }

    #[test]
    fn test_mass_multiple_particles() {
        let dataset = test_dataset();
        let mut mass = Mass::new(["kshort1", "kshort2"]);
        mass.bind(dataset.metadata()).unwrap();
        let event = dataset.event_view(0);
        assert_relative_eq!(mass.value(&event), 1.3743786309153077);
    }

    #[test]
    fn test_mass_display() {
        let mass = Mass::new(["kshort1", "kshort2"]);
        assert_eq!(mass.to_string(), "Mass(constituents=[kshort1, kshort2])");
    }

    // The angle tests below pin reference values for the first event of `test_dataset`.
    #[test]
    fn test_costheta_helicity() {
        let dataset = test_dataset();
        let mut costheta = CosTheta::new(reaction_topology(), "kshort1", Frame::Helicity);
        costheta.bind(dataset.metadata()).unwrap();
        let event = dataset.event_view(0);
        assert_relative_eq!(costheta.value(&event), -0.4611175068834238, epsilon = 1e-12);
    }

    #[test]
    fn test_costheta_display() {
        let costheta = CosTheta::new(reaction_topology(), "kshort1", Frame::Helicity);
        assert_eq!(
            costheta.to_string(),
            "CosTheta(topology=Topology(k1=[beam], k2=[<reconstructed>], k3=[kshort1, kshort2], k4=[proton]), daughter=[kshort1], frame=Helicity)"
        );
    }

    #[test]
    fn test_phi_helicity() {
        let dataset = test_dataset();
        let mut phi = Phi::new(reaction_topology(), "kshort1", Frame::Helicity);
        phi.bind(dataset.metadata()).unwrap();
        let event = dataset.event_view(0);
        assert_relative_eq!(phi.value(&event), -2.657462587335066, epsilon = 1e-12);
    }

    #[test]
    fn test_phi_display() {
        let phi = Phi::new(reaction_topology(), "kshort1", Frame::Helicity);
        assert_eq!(
            phi.to_string(),
            "Phi(topology=Topology(k1=[beam], k2=[<reconstructed>], k3=[kshort1, kshort2], k4=[proton]), daughter=[kshort1], frame=Helicity)"
        );
    }

    #[test]
    fn test_costheta_gottfried_jackson() {
        let dataset = test_dataset();
        let mut costheta = CosTheta::new(reaction_topology(), "kshort1", Frame::GottfriedJackson);
        costheta.bind(dataset.metadata()).unwrap();
        let event = dataset.event_view(0);
        assert_relative_eq!(costheta.value(&event), 0.09198832278031577, epsilon = 1e-12);
    }

    #[test]
    fn test_phi_gottfried_jackson() {
        let dataset = test_dataset();
        let mut phi = Phi::new(reaction_topology(), "kshort1", Frame::GottfriedJackson);
        phi.bind(dataset.metadata()).unwrap();
        let event = dataset.event_view(0);
        assert_relative_eq!(phi.value(&event), -2.713913199133907, epsilon = 1e-12);
    }

    // `Angles` must agree with the standalone helicity-frame `CosTheta`/`Phi` values above.
    #[test]
    fn test_angles() {
        let dataset = test_dataset();
        let mut angles = Angles::new(reaction_topology(), "kshort1", Frame::Helicity);
        angles.costheta.bind(dataset.metadata()).unwrap();
        angles.phi.bind(dataset.metadata()).unwrap();
        let event = dataset.event_view(0);
        assert_relative_eq!(
            angles.costheta.value(&event),
            -0.4611175068834238,
            epsilon = 1e-12
        );
        assert_relative_eq!(
            angles.phi.value(&event),
            -2.657462587335066,
            epsilon = 1e-12
        );
    }

    #[test]
    fn test_angles_display() {
        let angles = Angles::new(reaction_topology(), "kshort1", Frame::Helicity);
        assert_eq!(
            angles.to_string(),
            "Angles(topology=Topology(k1=[beam], k2=[<reconstructed>], k3=[kshort1, kshort2], k4=[proton]), daughter=[kshort1], frame=Helicity)"
        );
    }

    #[test]
    fn test_pol_angle() {
        let dataset = test_dataset();
        let mut pol_angle = PolAngle::new(reaction_topology(), "pol_angle");
        pol_angle.bind(dataset.metadata()).unwrap();
        let event = dataset.event_view(0);
        assert_relative_eq!(pol_angle.value(&event), 1.935929887818673);
    }

    #[test]
    fn test_pol_angle_display() {
        let pol_angle = PolAngle::new(reaction_topology(), "pol_angle");
        assert_eq!(
            pol_angle.to_string(),
            "PolAngle(topology=Topology(k1=[beam], k2=[<reconstructed>], k3=[kshort1, kshort2], k4=[proton]), angle_aux=pol_angle)"
        );
    }

    #[test]
    fn test_pol_magnitude() {
        let dataset = test_dataset();
        let mut pol_magnitude = PolMagnitude::new("pol_magnitude");
        pol_magnitude.bind(dataset.metadata()).unwrap();
        let event = dataset.event_view(0);
        assert_relative_eq!(pol_magnitude.value(&event), 0.38562805);
    }

    #[test]
    fn test_pol_magnitude_display() {
        let pol_magnitude = PolMagnitude::new("pol_magnitude");
        assert_eq!(
            pol_magnitude.to_string(),
            "PolMagnitude(magnitude_aux=pol_magnitude)"
        );
    }

    #[test]
    fn test_polarization() {
        let dataset = test_dataset();
        let mut polarization = Polarization::new(reaction_topology(), "pol_magnitude", "pol_angle");
        polarization.pol_angle.bind(dataset.metadata()).unwrap();
        polarization.pol_magnitude.bind(dataset.metadata()).unwrap();
        let event = dataset.event_view(0);
        assert_relative_eq!(polarization.pol_angle.value(&event), 1.935929887818673);
        assert_relative_eq!(polarization.pol_magnitude.value(&event), 0.38562805);
    }

    #[test]
    fn test_polarization_display() {
        let polarization = Polarization::new(reaction_topology(), "pol_magnitude", "pol_angle");
        assert_eq!(
            polarization.to_string(),
            "Polarization(topology=Topology(k1=[beam], k2=[<reconstructed>], k3=[kshort1, kshort2], k4=[proton]), magnitude_aux=pol_magnitude, angle_aux=pol_angle)"
        );
    }

    #[test]
    fn test_mandelstam() {
        let dataset = test_dataset();
        let metadata = dataset.metadata();
        let mut s = Mandelstam::new(reaction_topology(), Channel::S);
        let mut t = Mandelstam::new(reaction_topology(), Channel::T);
        let mut u = Mandelstam::new(reaction_topology(), Channel::U);
        for variable in [&mut s, &mut t, &mut u] {
            variable.bind(metadata).unwrap();
        }
        let event = dataset.event_view(0);
        assert_relative_eq!(s.value(&event), 18.504011052120063);
        assert_relative_eq!(t.value(&event), -0.19222859969898076);
        assert_relative_eq!(u.value(&event), -14.404198931464428);
        // Cross-check against the equivalent forms of the definitions, built from the
        // (reconstructed) k2 and the measured k3/k4 vertices.
        let mut direct_topology = reaction_topology();
        direct_topology.bind(metadata).unwrap();
        let event_data = test_event();
        let k2 = direct_topology.k2(&event_data);
        let k3 = direct_topology.k3(&event_data);
        let k4 = direct_topology.k4(&event_data);
        assert_relative_eq!(s.value(&event), (k3 + k4).mag2());
        assert_relative_eq!(t.value(&event), (k2 - k4).mag2());
        assert_relative_eq!(u.value(&event), (k3 - k2).mag2());
        // s + t + u equals the sum of the four squared masses, so subtracting three of them
        // should leave the squared mass of the remaining (reconstructed k2) vertex.
        // NOTE(review): this assumes p4 indices 0, 1, and 2..=3 of `test_event` are the beam,
        // proton, and K_S pair respectively — confirm against `test_event`.
        let m2_beam = test_event().get_p4_sum([0]).m2();
        let m2_recoil = test_event().get_p4_sum([1]).m2();
        let m2_res = test_event().get_p4_sum([2, 3]).m2();
        assert_relative_eq!(
            s.value(&event) + t.value(&event) + u.value(&event) - m2_beam - m2_recoil - m2_res,
            1.00,
            epsilon = 1e-2
        );
        // Note: not very accurate, but considering the values in test_event only go to about 3
        // decimal places, this is probably okay
    }

    #[test]
    fn test_mandelstam_display() {
        let s = Mandelstam::new(reaction_topology(), Channel::S);
        assert_eq!(
            s.to_string(),
            "Mandelstam(topology=Topology(k1=[beam], k2=[<reconstructed>], k3=[kshort1, kshort2], k4=[proton]), channel=s)"
        );
    }

    // `mass` is never bound explicitly here. Dataset-level evaluation is expected to bind a
    // clone of the variable itself (as `Variable::value_on_local` does) — confirm for the MPI path.
    #[test]
    fn test_variable_value_on() {
        let dataset = test_dataset();
        let mass = Mass::new(["kshort1", "kshort2"]);

        let values = mass.value_on(&dataset).unwrap();
        assert_eq!(values.len(), 1);
        assert_relative_eq!(values[0], 1.3743786309153077);
    }
}