// laddu_core/data.rs

1use accurate::{sum::Klein, traits::*};
2use arrow::array::Float32Array;
3use arrow::record_batch::RecordBatch;
4use auto_ops::impl_op_ex;
5use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
6use serde::{Deserialize, Serialize};
7use std::ops::{Deref, DerefMut, Index, IndexMut};
8use std::path::Path;
9use std::sync::Arc;
10use std::{fmt::Display, fs::File};
11
12#[cfg(feature = "rayon")]
13use rayon::prelude::*;
14
15#[cfg(feature = "mpi")]
16use mpi::{datatype::PartitionMut, topology::SimpleCommunicator, traits::*};
17
18#[cfg(feature = "mpi")]
19use crate::mpi::LadduMPI;
20
21use crate::utils::get_bin_edges;
22use crate::{
23    utils::{
24        variables::{Variable, VariableExpression},
25        vectors::{Vec3, Vec4},
26    },
27    Float, LadduError,
28};
29
30const P4_PREFIX: &str = "p4_";
31const AUX_PREFIX: &str = "aux_";
32
33/// An event that can be used to test the implementation of an
34/// [`Amplitude`](crate::amplitudes::Amplitude). This particular event contains the reaction
35/// $`\gamma p \to K_S^0 K_S^0 p`$ with a polarized photon beam.
36pub fn test_event() -> Event {
37    use crate::utils::vectors::*;
38    Event {
39        p4s: vec![
40            Vec3::new(0.0, 0.0, 8.747).with_mass(0.0),         // beam
41            Vec3::new(0.119, 0.374, 0.222).with_mass(1.007),   // "proton"
42            Vec3::new(-0.112, 0.293, 3.081).with_mass(0.498),  // "kaon"
43            Vec3::new(-0.007, -0.667, 5.446).with_mass(0.498), // "kaon"
44        ],
45        aux: vec![Vec3::new(0.385, 0.022, 0.000)],
46        weight: 0.48,
47    }
48}
49
50/// An dataset that can be used to test the implementation of an
51/// [`Amplitude`](crate::amplitudes::Amplitude). This particular dataset contains a singular
52/// [`Event`] generated from [`test_event`].
53pub fn test_dataset() -> Dataset {
54    Dataset::new(vec![Arc::new(test_event())])
55}
56
/// A single event in a [`Dataset`] containing all the relevant particle information.
// NOTE(review): `Event` is serialized with `bincode` in the MPI indexing path
// (`Dataset::index_mpi`); bincode encodes struct fields positionally, so
// reordering these fields changes the wire format — do not reorder.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Event {
    /// A list of four-momenta for each particle.
    pub p4s: Vec<Vec4>,
    /// A list of auxiliary vectors which can be used to store data like particle polarization.
    pub aux: Vec<Vec3>,
    /// The weight given to the event.
    pub weight: Float,
}
67
68impl Display for Event {
69    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70        writeln!(f, "Event:")?;
71        writeln!(f, "  p4s:")?;
72        for p4 in &self.p4s {
73            writeln!(f, "    {}", p4.to_p4_string())?;
74        }
75        writeln!(f, "  eps:")?;
76        for eps_vec in &self.aux {
77            writeln!(f, "    [{}, {}, {}]", eps_vec.x, eps_vec.y, eps_vec.z)?;
78        }
79        writeln!(f, "  weight:")?;
80        writeln!(f, "    {}", self.weight)?;
81        Ok(())
82    }
83}
84
85impl Event {
86    /// Return a four-momentum from the sum of four-momenta at the given indices in the [`Event`].
87    pub fn get_p4_sum<T: AsRef<[usize]>>(&self, indices: T) -> Vec4 {
88        indices.as_ref().iter().map(|i| self.p4s[*i]).sum::<Vec4>()
89    }
90    /// Boost all the four-momenta in the [`Event`] to the rest frame of the given set of
91    /// four-momenta by indices.
92    pub fn boost_to_rest_frame_of<T: AsRef<[usize]>>(&self, indices: T) -> Self {
93        let frame = self.get_p4_sum(indices);
94        Event {
95            p4s: self
96                .p4s
97                .iter()
98                .map(|p4| p4.boost(&(-frame.beta())))
99                .collect(),
100            aux: self.aux.clone(),
101            weight: self.weight,
102        }
103    }
104    /// Evaluate a [`Variable`] on an [`Event`].
105    pub fn evaluate<V: Variable>(&self, variable: &V) -> Float {
106        variable.value(self)
107    }
108}
109
/// A collection of [`Event`]s.
#[derive(Debug, Clone, Default)]
pub struct Dataset {
    /// The [`Event`]s contained in the [`Dataset`]
    // NOTE: when the `mpi` feature is active and an MPI world exists, this holds only
    // the events assigned to the current rank (see `Dataset::new_mpi`).
    pub events: Vec<Arc<Event>>,
}
116
impl Dataset {
    /// Get a reference to the [`Event`] at the given index in the [`Dataset`] (non-MPI
    /// version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just index into a [`Dataset`]
    /// as if it were any other [`Vec`]:
    ///
    /// ```ignore
    /// let ds: Dataset = Dataset::new(events);
    /// let event_0 = ds[0];
    /// ```
    pub fn index_local(&self, index: usize) -> &Event {
        &self.events[index]
    }

    #[cfg(feature = "mpi")]
    // Map a global event index to `(owning_rank, local_index)` using the displacement
    // table from `get_counts_displs` (`displs[i]` is the global index of the first
    // event owned by rank `i`).
    fn get_rank_index(index: usize, displs: &[i32], world: &SimpleCommunicator) -> (i32, usize) {
        // The first displacement strictly greater than `index` belongs to the rank
        // *after* the owner, hence `i - 1` below.
        for (i, &displ) in displs.iter().enumerate() {
            if displ as usize > index {
                return (i as i32 - 1, index - displs[i - 1] as usize);
            }
        }
        // No displacement exceeded `index`: it falls in the last rank's partition.
        (
            world.size() - 1,
            index - displs[world.size() as usize - 1] as usize,
        )
    }

    #[cfg(feature = "mpi")]
    // Split `events` into one contiguous chunk per MPI rank, following the
    // counts/displacements layout from `get_counts_displs`.
    fn partition(events: Vec<Arc<Event>>, world: &SimpleCommunicator) -> Vec<Vec<Arc<Event>>> {
        let (counts, displs) = world.get_counts_displs(events.len());
        counts
            .iter()
            .zip(displs.iter())
            .map(|(&count, &displ)| {
                events
                    .iter()
                    .skip(displ as usize)
                    .take(count as usize)
                    .cloned()
                    .collect()
            })
            .collect()
    }

    /// Get a reference to the [`Event`] at the given index in the [`Dataset`]
    /// (MPI-compatible version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just index into a [`Dataset`]
    /// as if it were any other [`Vec`]:
    ///
    /// ```ignore
    /// let ds: Dataset = Dataset::new(events);
    /// let event_0 = ds[0];
    /// ```
    #[cfg(feature = "mpi")]
    pub fn index_mpi(&self, index: usize, world: &SimpleCommunicator) -> &Event {
        // Work out which rank owns the requested global index.
        let (_, displs) = world.get_counts_displs(self.n_events());
        let (owning_rank, local_index) = Dataset::get_rank_index(index, &displs, world);
        let mut serialized_event_buffer_len: usize = 0;
        let mut serialized_event_buffer: Vec<u8> = Vec::default();
        let config = bincode::config::standard();
        // Only the owning rank has the event; it serializes it for broadcast.
        if world.rank() == owning_rank {
            let event = self.index_local(local_index);
            serialized_event_buffer = bincode::serde::encode_to_vec(event, config).unwrap();
            serialized_event_buffer_len = serialized_event_buffer.len();
        }
        // Broadcast the buffer length first so the other ranks can allocate space.
        world
            .process_at_rank(owning_rank)
            .broadcast_into(&mut serialized_event_buffer_len);
        if world.rank() != owning_rank {
            serialized_event_buffer = vec![0; serialized_event_buffer_len];
        }
        world
            .process_at_rank(owning_rank)
            .broadcast_into(&mut serialized_event_buffer);
        let (event, _): (Event, usize) =
            bincode::serde::decode_from_slice(&serialized_event_buffer[..], config).unwrap();
        // NOTE(review): `Box::leak` intentionally leaks the deserialized event so a
        // `&Event` can be returned to match the non-MPI signature. Every call leaks
        // one `Event` — avoid calling this in a hot loop.
        Box::leak(Box::new(event))
    }
}
204
impl Index<usize> for Dataset {
    type Output = Event;

    // Global indexing: when an MPI world is active this broadcasts the owning rank's
    // event to every rank (and leaks one copy per call — see `index_mpi`); otherwise
    // it is a plain `Vec` index.
    fn index(&self, index: usize) -> &Self::Output {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.index_mpi(index, &world);
            }
        }
        self.index_local(index)
    }
}
218
219impl Dataset {
220    /// Create a new [`Dataset`] from a list of [`Event`]s (non-MPI version).
221    ///
222    /// # Notes
223    ///
224    /// This method is not intended to be called in analyses but rather in writing methods
225    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::new`] instead.
226    pub fn new_local(events: Vec<Arc<Event>>) -> Self {
227        Dataset { events }
228    }
229
230    /// Create a new [`Dataset`] from a list of [`Event`]s (MPI-compatible version).
231    ///
232    /// # Notes
233    ///
234    /// This method is not intended to be called in analyses but rather in writing methods
235    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::new`] instead.
236    #[cfg(feature = "mpi")]
237    pub fn new_mpi(events: Vec<Arc<Event>>, world: &SimpleCommunicator) -> Self {
238        Dataset {
239            events: Dataset::partition(events, world)[world.rank() as usize].clone(),
240        }
241    }
242
243    /// Create a new [`Dataset`] from a list of [`Event`]s.
244    ///
245    /// This method is prefered for external use because it contains proper MPI construction
246    /// methods. Constructing a [`Dataset`] manually is possible, but may cause issues when
247    /// interfacing with MPI and should be avoided unless you know what you are doing.
248    pub fn new(events: Vec<Arc<Event>>) -> Self {
249        #[cfg(feature = "mpi")]
250        {
251            if let Some(world) = crate::mpi::get_world() {
252                return Dataset::new_mpi(events, &world);
253            }
254        }
255        Dataset::new_local(events)
256    }
257
258    /// The number of [`Event`]s in the [`Dataset`] (non-MPI version).
259    ///
260    /// # Notes
261    ///
262    /// This method is not intended to be called in analyses but rather in writing methods
263    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::n_events`] instead.
264    pub fn n_events_local(&self) -> usize {
265        self.events.len()
266    }
267
268    /// The number of [`Event`]s in the [`Dataset`] (MPI-compatible version).
269    ///
270    /// # Notes
271    ///
272    /// This method is not intended to be called in analyses but rather in writing methods
273    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::n_events`] instead.
274    #[cfg(feature = "mpi")]
275    pub fn n_events_mpi(&self, world: &SimpleCommunicator) -> usize {
276        let mut n_events_partitioned: Vec<usize> = vec![0; world.size() as usize];
277        let n_events_local = self.n_events_local();
278        world.all_gather_into(&n_events_local, &mut n_events_partitioned);
279        n_events_partitioned.iter().sum()
280    }
281
282    /// The number of [`Event`]s in the [`Dataset`].
283    pub fn n_events(&self) -> usize {
284        #[cfg(feature = "mpi")]
285        {
286            if let Some(world) = crate::mpi::get_world() {
287                return self.n_events_mpi(&world);
288            }
289        }
290        self.n_events_local()
291    }
292}
293
impl Dataset {
    /// Extract a list of weights over each [`Event`] in the [`Dataset`] (non-MPI version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::weights`] instead.
    pub fn weights_local(&self) -> Vec<Float> {
        #[cfg(feature = "rayon")]
        return self.events.par_iter().map(|e| e.weight).collect();
        #[cfg(not(feature = "rayon"))]
        return self.events.iter().map(|e| e.weight).collect();
    }

    /// Extract a list of weights over each [`Event`] in the [`Dataset`] (MPI-compatible version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::weights`] instead.
    #[cfg(feature = "mpi")]
    pub fn weights_mpi(&self, world: &SimpleCommunicator) -> Vec<Float> {
        let local_weights = self.weights_local();
        let n_events = self.n_events();
        let mut buffer: Vec<Float> = vec![0.0; n_events];
        let (counts, displs) = world.get_counts_displs(n_events);
        {
            // Gather every rank's local weights into the global buffer using the
            // counts/displacements layout; the inner scope ends the `PartitionMut`
            // borrow before `buffer` is returned.
            let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
            world.all_gather_varcount_into(&local_weights, &mut partitioned_buffer);
        }
        buffer
    }

    /// Extract a list of weights over each [`Event`] in the [`Dataset`].
    pub fn weights(&self) -> Vec<Float> {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.weights_mpi(&world);
            }
        }
        self.weights_local()
    }

    /// Returns the sum of the weights for each [`Event`] in the [`Dataset`] (non-MPI version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::n_events_weighted`] instead.
    pub fn n_events_weighted_local(&self) -> Float {
        // The rayon path sums with a Klein (compensated summation) accumulator from
        // the `accurate` crate to reduce floating-point error; the serial path is a
        // plain sum.
        #[cfg(feature = "rayon")]
        return self
            .events
            .par_iter()
            .map(|e| e.weight)
            .parallel_sum_with_accumulator::<Klein<Float>>();
        #[cfg(not(feature = "rayon"))]
        return self.events.iter().map(|e| e.weight).sum();
    }
    /// Returns the sum of the weights for each [`Event`] in the [`Dataset`] (MPI-compatible version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::n_events_weighted`] instead.
    #[cfg(feature = "mpi")]
    pub fn n_events_weighted_mpi(&self, world: &SimpleCommunicator) -> Float {
        // Gather each rank's local weighted count, then total the per-rank sums.
        let mut n_events_weighted_partitioned: Vec<Float> = vec![0.0; world.size() as usize];
        let n_events_weighted_local = self.n_events_weighted_local();
        world.all_gather_into(&n_events_weighted_local, &mut n_events_weighted_partitioned);
        #[cfg(feature = "rayon")]
        return n_events_weighted_partitioned
            .into_par_iter()
            .parallel_sum_with_accumulator::<Klein<Float>>();
        #[cfg(not(feature = "rayon"))]
        return n_events_weighted_partitioned.iter().sum();
    }

    /// Returns the sum of the weights for each [`Event`] in the [`Dataset`].
    pub fn n_events_weighted(&self) -> Float {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.n_events_weighted_mpi(&world);
            }
        }
        self.n_events_weighted_local()
    }

    /// Generate a new dataset with the same length by resampling the events in the original datset
    /// with replacement. This can be used to perform error analysis via the bootstrap method. (non-MPI version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::bootstrap`] instead.
    pub fn bootstrap_local(&self, seed: usize) -> Arc<Dataset> {
        // Deterministic for a given seed: the same seed reproduces the same resample.
        let mut rng = fastrand::Rng::with_seed(seed as u64);
        let mut indices: Vec<usize> = (0..self.n_events())
            .map(|_| rng.usize(0..self.n_events()))
            .collect::<Vec<usize>>();
        indices.sort();
        #[cfg(feature = "rayon")]
        let bootstrapped_events: Vec<Arc<Event>> = indices
            .into_par_iter()
            .map(|idx| self.events[idx].clone())
            .collect();
        #[cfg(not(feature = "rayon"))]
        let bootstrapped_events: Vec<Arc<Event>> = indices
            .into_iter()
            .map(|idx| self.events[idx].clone())
            .collect();
        Arc::new(Dataset {
            events: bootstrapped_events,
        })
    }

    /// Generate a new dataset with the same length by resampling the events in the original datset
    /// with replacement. This can be used to perform error analysis via the bootstrap method. (MPI-compatible version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::bootstrap`] instead.
    #[cfg(feature = "mpi")]
    pub fn bootstrap_mpi(&self, seed: usize, world: &SimpleCommunicator) -> Arc<Dataset> {
        let n_events = self.n_events();
        let mut indices: Vec<usize> = vec![0; n_events];
        // The root rank draws the global resample so every rank agrees on it.
        if world.is_root() {
            let mut rng = fastrand::Rng::with_seed(seed as u64);
            indices = (0..n_events)
                .map(|_| rng.usize(0..n_events))
                .collect::<Vec<usize>>();
            indices.sort();
        }
        world.process_at_root().broadcast_into(&mut indices);
        let (_, displs) = world.get_counts_displs(self.n_events());
        let local_indices: Vec<usize> = indices
            .into_iter()
            .filter_map(|idx| {
                let (owning_rank, local_index) = Dataset::get_rank_index(idx, &displs, world);
                if world.rank() == owning_rank {
                    Some(local_index)
                } else {
                    None
                }
            })
            .collect();
        // `local_indices` only contains indices owned by the current rank, translating them into
        // local indices on the events vector.
        #[cfg(feature = "rayon")]
        let bootstrapped_events: Vec<Arc<Event>> = local_indices
            .into_par_iter()
            .map(|idx| self.events[idx].clone())
            .collect();
        #[cfg(not(feature = "rayon"))]
        let bootstrapped_events: Vec<Arc<Event>> = local_indices
            .into_iter()
            .map(|idx| self.events[idx].clone())
            .collect();
        Arc::new(Dataset {
            events: bootstrapped_events,
        })
    }

    /// Generate a new dataset with the same length by resampling the events in the original datset
    /// with replacement. This can be used to perform error analysis via the bootstrap method.
    pub fn bootstrap(&self, seed: usize) -> Arc<Dataset> {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.bootstrap_mpi(seed, &world);
            }
        }
        self.bootstrap_local(seed)
    }

    /// Filter the [`Dataset`] by a given [`VariableExpression`], selecting events for which
    /// the expression returns `true`.
    pub fn filter(&self, expression: &VariableExpression) -> Arc<Dataset> {
        // Compile the expression once, then evaluate it per event.
        let compiled = expression.compile();
        #[cfg(feature = "rayon")]
        let filtered_events = self
            .events
            .par_iter()
            .filter(|e| compiled.evaluate(e))
            .cloned()
            .collect();
        #[cfg(not(feature = "rayon"))]
        let filtered_events = self
            .events
            .iter()
            .filter(|e| compiled.evaluate(e))
            .cloned()
            .collect();
        Arc::new(Dataset {
            events: filtered_events,
        })
    }

    /// Bin a [`Dataset`] by the value of the given [`Variable`] into a number of `bins` within the
    /// given `range`.
    pub fn bin_by<V>(&self, variable: V, bins: usize, range: (Float, Float)) -> BinnedDataset
    where
        V: Variable,
    {
        // NOTE(review): `bins == 0` would underflow `bins - 1` below (panic in debug
        // builds) and make `bin_width` infinite — callers are assumed to pass `bins >= 1`.
        let bin_width = (range.1 - range.0) / bins as Float;
        let bin_edges = get_bin_edges(bins, range);
        // Events outside [range.0, range.1) are dropped; `min(bins - 1)` guards
        // against floating-point rounding pushing a value into a one-past-the-end bin.
        #[cfg(feature = "rayon")]
        let evaluated: Vec<(usize, &Arc<Event>)> = self
            .events
            .par_iter()
            .filter_map(|event| {
                let value = variable.value(event.as_ref());
                if value >= range.0 && value < range.1 {
                    let bin_index = ((value - range.0) / bin_width) as usize;
                    let bin_index = bin_index.min(bins - 1);
                    Some((bin_index, event))
                } else {
                    None
                }
            })
            .collect();
        #[cfg(not(feature = "rayon"))]
        let evaluated: Vec<(usize, &Arc<Event>)> = self
            .events
            .iter()
            .filter_map(|event| {
                let value = variable.value(event.as_ref());
                if value >= range.0 && value < range.1 {
                    let bin_index = ((value - range.0) / bin_width) as usize;
                    let bin_index = bin_index.min(bins - 1);
                    Some((bin_index, event))
                } else {
                    None
                }
            })
            .collect();
        // Scatter the (bin, event) pairs into per-bin event lists.
        let mut binned_events: Vec<Vec<Arc<Event>>> = vec![Vec::default(); bins];
        for (bin_index, event) in evaluated {
            binned_events[bin_index].push(event.clone());
        }
        BinnedDataset {
            #[cfg(feature = "rayon")]
            datasets: binned_events
                .into_par_iter()
                .map(|events| Arc::new(Dataset { events }))
                .collect(),
            #[cfg(not(feature = "rayon"))]
            datasets: binned_events
                .into_iter()
                .map(|events| Arc::new(Dataset { events }))
                .collect(),
            edges: bin_edges,
        }
    }

    /// Boost all the four-momenta in all [`Event`]s to the rest frame of the given set of
    /// four-momenta by indices.
    pub fn boost_to_rest_frame_of<T: AsRef<[usize]> + Sync>(&self, indices: T) -> Arc<Dataset> {
        #[cfg(feature = "rayon")]
        {
            Arc::new(Dataset {
                events: self
                    .events
                    .par_iter()
                    .map(|event| Arc::new(event.boost_to_rest_frame_of(indices.as_ref())))
                    .collect(),
            })
        }
        #[cfg(not(feature = "rayon"))]
        {
            Arc::new(Dataset {
                events: self
                    .events
                    .iter()
                    .map(|event| Arc::new(event.boost_to_rest_frame_of(indices.as_ref())))
                    .collect(),
            })
        }
    }
    /// Evaluate a [`Variable`] on every event in the [`Dataset`].
    pub fn evaluate<V: Variable>(&self, variable: &V) -> Vec<Float> {
        variable.value_on(self)
    }
}
581
582impl_op_ex!(+ |a: &Dataset, b: &Dataset| ->  Dataset { Dataset { events: a.events.iter().chain(b.events.iter()).cloned().collect() }});
583
584fn batch_to_event(batch: &RecordBatch, row: usize) -> Event {
585    let mut p4s = Vec::new();
586    let mut aux = Vec::new();
587
588    let p4_count = batch
589        .schema()
590        .fields()
591        .iter()
592        .filter(|field| field.name().starts_with(P4_PREFIX))
593        .count()
594        / 4;
595    let aux_count = batch
596        .schema()
597        .fields()
598        .iter()
599        .filter(|field| field.name().starts_with(AUX_PREFIX))
600        .count()
601        / 3;
602
603    for i in 0..p4_count {
604        let e = batch
605            .column_by_name(&format!("{}{}_E", P4_PREFIX, i))
606            .unwrap()
607            .as_any()
608            .downcast_ref::<Float32Array>()
609            .unwrap()
610            .value(row) as Float;
611        let px = batch
612            .column_by_name(&format!("{}{}_Px", P4_PREFIX, i))
613            .unwrap()
614            .as_any()
615            .downcast_ref::<Float32Array>()
616            .unwrap()
617            .value(row) as Float;
618        let py = batch
619            .column_by_name(&format!("{}{}_Py", P4_PREFIX, i))
620            .unwrap()
621            .as_any()
622            .downcast_ref::<Float32Array>()
623            .unwrap()
624            .value(row) as Float;
625        let pz = batch
626            .column_by_name(&format!("{}{}_Pz", P4_PREFIX, i))
627            .unwrap()
628            .as_any()
629            .downcast_ref::<Float32Array>()
630            .unwrap()
631            .value(row) as Float;
632        p4s.push(Vec4::new(px, py, pz, e));
633    }
634
635    // TODO: insert empty vectors if not provided
636    for i in 0..aux_count {
637        let x = batch
638            .column_by_name(&format!("{}{}_x", AUX_PREFIX, i))
639            .unwrap()
640            .as_any()
641            .downcast_ref::<Float32Array>()
642            .unwrap()
643            .value(row) as Float;
644        let y = batch
645            .column_by_name(&format!("{}{}_y", AUX_PREFIX, i))
646            .unwrap()
647            .as_any()
648            .downcast_ref::<Float32Array>()
649            .unwrap()
650            .value(row) as Float;
651        let z = batch
652            .column_by_name(&format!("{}{}_z", AUX_PREFIX, i))
653            .unwrap()
654            .as_any()
655            .downcast_ref::<Float32Array>()
656            .unwrap()
657            .value(row) as Float;
658        aux.push(Vec3::new(x, y, z));
659    }
660
661    let weight = batch
662        .column(19)
663        .as_any()
664        .downcast_ref::<Float32Array>()
665        .unwrap()
666        .value(row) as Float;
667
668    Event { p4s, aux, weight }
669}
670
671/// Open a Parquet file and read the data into a [`Dataset`].
672pub fn open<T: AsRef<str>>(file_path: T) -> Result<Arc<Dataset>, LadduError> {
673    // TODO: make this read in directly to MPI ranks
674    let file_path = Path::new(&*shellexpand::full(file_path.as_ref())?).canonicalize()?;
675    let file = File::open(file_path)?;
676    let builder = ParquetRecordBatchReaderBuilder::try_new(file)?;
677    let reader = builder.build()?;
678    let batches: Vec<RecordBatch> = reader.collect::<Result<Vec<_>, _>>()?;
679
680    #[cfg(feature = "rayon")]
681    let events: Vec<Arc<Event>> = batches
682        .into_par_iter()
683        .flat_map(|batch| {
684            let num_rows = batch.num_rows();
685            let mut local_events = Vec::with_capacity(num_rows);
686
687            // Process each row in the batch
688            for row in 0..num_rows {
689                let event = batch_to_event(&batch, row);
690                local_events.push(Arc::new(event));
691            }
692            local_events
693        })
694        .collect();
695    #[cfg(not(feature = "rayon"))]
696    let events: Vec<Arc<Event>> = batches
697        .into_iter()
698        .flat_map(|batch| {
699            let num_rows = batch.num_rows();
700            let mut local_events = Vec::with_capacity(num_rows);
701
702            // Process each row in the batch
703            for row in 0..num_rows {
704                let event = batch_to_event(&batch, row);
705                local_events.push(Arc::new(event));
706            }
707            local_events
708        })
709        .collect();
710    Ok(Arc::new(Dataset::new(events)))
711}
712
713/// Open a Parquet file and read the data into a [`Dataset`]. This method boosts each event to the
714/// rest frame of the four-momenta at the given indices.
715pub fn open_boosted_to_rest_frame_of<T: AsRef<str>, I: AsRef<[usize]> + Sync>(
716    file_path: T,
717    indices: I,
718) -> Result<Arc<Dataset>, LadduError> {
719    // TODO: make this read in directly to MPI ranks
720    let file_path = Path::new(&*shellexpand::full(file_path.as_ref())?).canonicalize()?;
721    let file = File::open(file_path)?;
722    let builder = ParquetRecordBatchReaderBuilder::try_new(file)?;
723    let reader = builder.build()?;
724    let batches: Vec<RecordBatch> = reader.collect::<Result<Vec<_>, _>>()?;
725
726    #[cfg(feature = "rayon")]
727    let events: Vec<Arc<Event>> = batches
728        .into_par_iter()
729        .flat_map(|batch| {
730            let num_rows = batch.num_rows();
731            let mut local_events = Vec::with_capacity(num_rows);
732
733            // Process each row in the batch
734            for row in 0..num_rows {
735                let mut event = batch_to_event(&batch, row);
736                event = event.boost_to_rest_frame_of(indices.as_ref());
737                local_events.push(Arc::new(event));
738            }
739            local_events
740        })
741        .collect();
742    #[cfg(not(feature = "rayon"))]
743    let events: Vec<Arc<Event>> = batches
744        .into_iter()
745        .flat_map(|batch| {
746            let num_rows = batch.num_rows();
747            let mut local_events = Vec::with_capacity(num_rows);
748
749            // Process each row in the batch
750            for row in 0..num_rows {
751                let mut event = batch_to_event(&batch, row);
752                event = event.boost_to_rest_frame_of(indices.as_ref());
753                local_events.push(Arc::new(event));
754            }
755            local_events
756        })
757        .collect();
758    Ok(Arc::new(Dataset::new(events)))
759}
760
/// A list of [`Dataset`]s formed by binning [`Event`]s by some [`Variable`].
pub struct BinnedDataset {
    // One dataset per bin, in ascending bin order (see `Dataset::bin_by`).
    datasets: Vec<Arc<Dataset>>,
    // Bin edges used to construct the binning; `range()` reads `edges[n_bins]`, so
    // this holds `n_bins + 1` entries.
    edges: Vec<Float>,
}
766
767impl Index<usize> for BinnedDataset {
768    type Output = Arc<Dataset>;
769
770    fn index(&self, index: usize) -> &Self::Output {
771        &self.datasets[index]
772    }
773}
774
775impl IndexMut<usize> for BinnedDataset {
776    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
777        &mut self.datasets[index]
778    }
779}
780
781impl Deref for BinnedDataset {
782    type Target = Vec<Arc<Dataset>>;
783
784    fn deref(&self) -> &Self::Target {
785        &self.datasets
786    }
787}
788
789impl DerefMut for BinnedDataset {
790    fn deref_mut(&mut self) -> &mut Self::Target {
791        &mut self.datasets
792    }
793}
794
795impl BinnedDataset {
796    /// The number of bins in the [`BinnedDataset`].
797    pub fn n_bins(&self) -> usize {
798        self.datasets.len()
799    }
800
801    /// Returns a list of the bin edges that were used to form the [`BinnedDataset`].
802    pub fn edges(&self) -> Vec<Float> {
803        self.edges.clone()
804    }
805
806    /// Returns the range that was used to form the [`BinnedDataset`].
807    pub fn range(&self) -> (Float, Float) {
808        (self.edges[0], self.edges[self.n_bins()])
809    }
810}
811
812#[cfg(test)]
813mod tests {
814    use crate::Mass;
815
816    use super::*;
817    use approx::{assert_relative_eq, assert_relative_ne};
818    use serde::{Deserialize, Serialize};
    #[test]
    fn test_event_creation() {
        // The canonical test event has 4 four-momenta, 1 auxiliary vector, and weight 0.48.
        let event = test_event();
        assert_eq!(event.p4s.len(), 4);
        assert_eq!(event.aux.len(), 1);
        assert_relative_eq!(event.weight, 0.48)
    }
826
    #[test]
    fn test_event_p4_sum() {
        // Summing the p4s at indices [2, 3] should add the four-momenta componentwise.
        let event = test_event();
        let sum = event.get_p4_sum([2, 3]);
        assert_relative_eq!(sum.px(), event.p4s[2].px() + event.p4s[3].px());
        assert_relative_eq!(sum.py(), event.p4s[2].py() + event.p4s[3].py());
        assert_relative_eq!(sum.pz(), event.p4s[2].pz() + event.p4s[3].pz());
        assert_relative_eq!(sum.e(), event.p4s[2].e() + event.p4s[3].e());
    }
836
837    #[test]
838    fn test_event_boost() {
839        let event = test_event();
840        let event_boosted = event.boost_to_rest_frame_of([1, 2, 3]);
841        let p4_sum = event_boosted.get_p4_sum([1, 2, 3]);
842        assert_relative_eq!(p4_sum.px(), 0.0, epsilon = Float::EPSILON.sqrt());
843        assert_relative_eq!(p4_sum.py(), 0.0, epsilon = Float::EPSILON.sqrt());
844        assert_relative_eq!(p4_sum.pz(), 0.0, epsilon = Float::EPSILON.sqrt());
845    }
846
847    #[test]
848    fn test_event_evaluate() {
849        let event = test_event();
850        let mass = Mass::new([1]);
851        assert_relative_eq!(event.evaluate(&mass), 1.007);
852    }
853
854    #[test]
855    fn test_dataset_size_check() {
856        let mut dataset = Dataset::default();
857        assert_eq!(dataset.n_events(), 0);
858        dataset.events.push(Arc::new(test_event()));
859        assert_eq!(dataset.n_events(), 1);
860    }
861
862    #[test]
863    fn test_dataset_sum() {
864        let dataset = test_dataset();
865        let dataset2 = Dataset::new(vec![Arc::new(Event {
866            p4s: test_event().p4s,
867            aux: test_event().aux,
868            weight: 0.52,
869        })]);
870        let dataset_sum = &dataset + &dataset2;
871        assert_eq!(dataset_sum[0].weight, dataset[0].weight);
872        assert_eq!(dataset_sum[1].weight, dataset2[0].weight);
873    }
874
875    #[test]
876    fn test_dataset_weights() {
877        let mut dataset = Dataset::default();
878        dataset.events.push(Arc::new(test_event()));
879        dataset.events.push(Arc::new(Event {
880            p4s: test_event().p4s,
881            aux: test_event().aux,
882            weight: 0.52,
883        }));
884        let weights = dataset.weights();
885        assert_eq!(weights.len(), 2);
886        assert_relative_eq!(weights[0], 0.48);
887        assert_relative_eq!(weights[1], 0.52);
888        assert_relative_eq!(dataset.n_events_weighted(), 1.0);
889    }
890
891    #[test]
892    fn test_dataset_filtering() {
893        let mut dataset = Dataset::default();
894        dataset.events.push(Arc::new(Event {
895            p4s: vec![Vec3::new(0.0, 0.0, 5.0).with_mass(0.0)],
896            aux: vec![],
897            weight: 1.0,
898        }));
899        dataset.events.push(Arc::new(Event {
900            p4s: vec![Vec3::new(0.0, 0.0, 5.0).with_mass(0.5)],
901            aux: vec![],
902            weight: 1.0,
903        }));
904        dataset.events.push(Arc::new(Event {
905            p4s: vec![Vec3::new(0.0, 0.0, 5.0).with_mass(1.1)],
906            // HACK: using 1.0 messes with this test because the eventual computation gives a mass
907            // slightly less than 1.0
908            aux: vec![],
909            weight: 1.0,
910        }));
911
912        let mass = Mass::new([0]);
913        let expression = mass.gt(0.0).and(&mass.lt(1.0));
914
915        let filtered = dataset.filter(&expression);
916        assert_eq!(filtered.n_events(), 1);
917        assert_relative_eq!(
918            mass.value(&filtered[0]),
919            0.5,
920            epsilon = Float::EPSILON.sqrt()
921        );
922    }
923
924    #[test]
925    fn test_dataset_boost() {
926        let dataset = test_dataset();
927        let dataset_boosted = dataset.boost_to_rest_frame_of([1, 2, 3]);
928        let p4_sum = dataset_boosted[0].get_p4_sum([1, 2, 3]);
929        assert_relative_eq!(p4_sum.px(), 0.0, epsilon = Float::EPSILON.sqrt());
930        assert_relative_eq!(p4_sum.py(), 0.0, epsilon = Float::EPSILON.sqrt());
931        assert_relative_eq!(p4_sum.pz(), 0.0, epsilon = Float::EPSILON.sqrt());
932    }
933
934    #[test]
935    fn test_dataset_evaluate() {
936        let dataset = test_dataset();
937        let mass = Mass::new([1]);
938        assert_relative_eq!(dataset.evaluate(&mass)[0], 1.007);
939    }
940
941    #[test]
942    fn test_binned_dataset() {
943        let dataset = Dataset::new(vec![
944            Arc::new(Event {
945                p4s: vec![Vec3::new(0.0, 0.0, 1.0).with_mass(1.0)],
946                aux: vec![],
947                weight: 1.0,
948            }),
949            Arc::new(Event {
950                p4s: vec![Vec3::new(0.0, 0.0, 2.0).with_mass(2.0)],
951                aux: vec![],
952                weight: 2.0,
953            }),
954        ]);
955
956        #[derive(Clone, Serialize, Deserialize, Debug)]
957        struct BeamEnergy;
958        impl Display for BeamEnergy {
959            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
960                write!(f, "BeamEnergy")
961            }
962        }
963        #[typetag::serde]
964        impl Variable for BeamEnergy {
965            fn value(&self, event: &Event) -> Float {
966                event.p4s[0].e()
967            }
968        }
969        assert_eq!(BeamEnergy.to_string(), "BeamEnergy");
970
971        // Test binning by first particle energy
972        let binned = dataset.bin_by(BeamEnergy, 2, (0.0, 3.0));
973
974        assert_eq!(binned.n_bins(), 2);
975        assert_eq!(binned.edges().len(), 3);
976        assert_relative_eq!(binned.edges()[0], 0.0);
977        assert_relative_eq!(binned.edges()[2], 3.0);
978        assert_eq!(binned[0].n_events(), 1);
979        assert_relative_eq!(binned[0].n_events_weighted(), 1.0);
980        assert_eq!(binned[1].n_events(), 1);
981        assert_relative_eq!(binned[1].n_events_weighted(), 2.0);
982    }
983
984    #[test]
985    fn test_dataset_bootstrap() {
986        let mut dataset = test_dataset();
987        dataset.events.push(Arc::new(Event {
988            p4s: test_event().p4s.clone(),
989            aux: test_event().aux.clone(),
990            weight: 1.0,
991        }));
992        assert_relative_ne!(dataset[0].weight, dataset[1].weight);
993
994        let bootstrapped = dataset.bootstrap(43);
995        assert_eq!(bootstrapped.n_events(), dataset.n_events());
996        assert_relative_eq!(bootstrapped[0].weight, bootstrapped[1].weight);
997
998        // Test empty dataset bootstrap
999        let empty_dataset = Dataset::default();
1000        let empty_bootstrap = empty_dataset.bootstrap(43);
1001        assert_eq!(empty_bootstrap.n_events(), 0);
1002    }
1003    #[test]
1004    fn test_event_display() {
1005        let event = test_event();
1006        let display_string = format!("{}", event);
1007        assert_eq!(
1008            display_string,
1009            "Event:\n  p4s:\n    [e = 8.74700; p = (0.00000, 0.00000, 8.74700); m = 0.00000]\n    [e = 1.10334; p = (0.11900, 0.37400, 0.22200); m = 1.00700]\n    [e = 3.13671; p = (-0.11200, 0.29300, 3.08100); m = 0.49800]\n    [e = 5.50925; p = (-0.00700, -0.66700, 5.44600); m = 0.49800]\n  eps:\n    [0.385, 0.022, 0]\n  weight:\n    0.48\n"
1010        );
1011    }
1012}