// laddu_core/data.rs

1use accurate::{sum::Klein, traits::*};
2use arrow::array::Float32Array;
3use arrow::record_batch::RecordBatch;
4use auto_ops::impl_op_ex;
5use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
6use serde::{Deserialize, Serialize};
7use std::ops::{Deref, DerefMut, Index, IndexMut};
8use std::path::Path;
9use std::sync::Arc;
10use std::{fmt::Display, fs::File};
11
12#[cfg(feature = "rayon")]
13use rayon::prelude::*;
14
15#[cfg(feature = "mpi")]
16use mpi::{datatype::PartitionMut, topology::SimpleCommunicator, traits::*};
17
18#[cfg(feature = "mpi")]
19use crate::mpi::LadduMPI;
20
21use crate::utils::get_bin_edges;
22use crate::{
23    utils::{
24        variables::Variable,
25        vectors::{Vec3, Vec4},
26    },
27    Float, LadduError,
28};
29
/// Column-name prefix identifying four-momentum components (`p4_{i}_{E,Px,Py,Pz}`) in data files.
const P4_PREFIX: &str = "p4_";
/// Column-name prefix identifying auxiliary-vector components (`aux_{i}_{x,y,z}`) in data files.
const AUX_PREFIX: &str = "aux_";
32
33/// An event that can be used to test the implementation of an
34/// [`Amplitude`](crate::amplitudes::Amplitude). This particular event contains the reaction
35/// $`\gamma p \to K_S^0 K_S^0 p`$ with a polarized photon beam.
36pub fn test_event() -> Event {
37    use crate::utils::vectors::*;
38    Event {
39        p4s: vec![
40            Vec3::new(0.0, 0.0, 8.747).with_mass(0.0),         // beam
41            Vec3::new(0.119, 0.374, 0.222).with_mass(1.007),   // "proton"
42            Vec3::new(-0.112, 0.293, 3.081).with_mass(0.498),  // "kaon"
43            Vec3::new(-0.007, -0.667, 5.446).with_mass(0.498), // "kaon"
44        ],
45        aux: vec![Vec3::new(0.385, 0.022, 0.000)],
46        weight: 0.48,
47    }
48}
49
50/// An dataset that can be used to test the implementation of an
51/// [`Amplitude`](crate::amplitudes::Amplitude). This particular dataset contains a singular
52/// [`Event`] generated from [`test_event`].
53pub fn test_dataset() -> Dataset {
54    Dataset::new(vec![Arc::new(test_event())])
55}
56
/// A single event in a [`Dataset`] containing all the relevant particle information.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Event {
    /// A list of four-momenta for each particle.
    pub p4s: Vec<Vec4>,
    /// A list of auxiliary vectors which can be used to store data like particle polarization.
    pub aux: Vec<Vec3>,
    /// The weight given to the event.
    pub weight: Float,
}
67
68impl Display for Event {
69    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70        writeln!(f, "Event:")?;
71        writeln!(f, "  p4s:")?;
72        for p4 in &self.p4s {
73            writeln!(f, "    {}", p4.to_p4_string())?;
74        }
75        writeln!(f, "  eps:")?;
76        for eps_vec in &self.aux {
77            writeln!(f, "    [{}, {}, {}]", eps_vec.x, eps_vec.y, eps_vec.z)?;
78        }
79        writeln!(f, "  weight:")?;
80        writeln!(f, "    {}", self.weight)?;
81        Ok(())
82    }
83}
84
85impl Event {
86    /// Return a four-momentum from the sum of four-momenta at the given indices in the [`Event`].
87    pub fn get_p4_sum<T: AsRef<[usize]>>(&self, indices: T) -> Vec4 {
88        indices.as_ref().iter().map(|i| self.p4s[*i]).sum::<Vec4>()
89    }
90    /// Boost all the four-momenta in the [`Event`] to the rest frame of the given set of
91    /// four-momenta by indices.
92    pub fn boost_to_rest_frame_of<T: AsRef<[usize]>>(&self, indices: T) -> Self {
93        let frame = self.get_p4_sum(indices);
94        Event {
95            p4s: self
96                .p4s
97                .iter()
98                .map(|p4| p4.boost(&(-frame.beta())))
99                .collect(),
100            aux: self.aux.clone(),
101            weight: self.weight,
102        }
103    }
104}
105
/// A collection of [`Event`]s.
///
/// Events are stored behind [`Arc`]s so that derived datasets (filters, bootstraps, bins)
/// can share them cheaply.
#[derive(Debug, Clone, Default)]
pub struct Dataset {
    /// The [`Event`]s contained in the [`Dataset`]
    pub events: Vec<Arc<Event>>,
}
112
impl Dataset {
    /// Get a reference to the [`Event`] at the given index in the [`Dataset`] (non-MPI
    /// version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just index into a [`Dataset`]
    /// as if it were any other [`Vec`]:
    ///
    /// ```ignore
    /// let ds: Dataset = Dataset::new(events);
    /// let event_0 = ds[0];
    /// ```
    pub fn index_local(&self, index: usize) -> &Event {
        &self.events[index]
    }

    /// Map a global event `index` to `(owning_rank, local_index)` given each rank's
    /// partition displacement offsets (`displs`).
    #[cfg(feature = "mpi")]
    fn get_rank_index(index: usize, displs: &[i32], world: &SimpleCommunicator) -> (i32, usize) {
        // The first displacement strictly greater than `index` belongs to the rank
        // *after* the owner, so the owner is rank `i - 1`.
        for (i, &displ) in displs.iter().enumerate() {
            if displ as usize > index {
                return (i as i32 - 1, index - displs[i - 1] as usize);
            }
        }
        // No displacement exceeded `index`, so it falls in the last rank's partition.
        (
            world.size() - 1,
            index - displs[world.size() as usize - 1] as usize,
        )
    }

    /// Split `events` into one contiguous chunk per MPI rank using the counts and
    /// displacements prescribed by the communicator.
    #[cfg(feature = "mpi")]
    fn partition(events: Vec<Arc<Event>>, world: &SimpleCommunicator) -> Vec<Vec<Arc<Event>>> {
        let (counts, displs) = world.get_counts_displs(events.len());
        counts
            .iter()
            .zip(displs.iter())
            .map(|(&count, &displ)| {
                events
                    .iter()
                    .skip(displ as usize)
                    .take(count as usize)
                    .cloned()
                    .collect()
            })
            .collect()
    }

    /// Get a reference to the [`Event`] at the given index in the [`Dataset`]
    /// (MPI-compatible version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just index into a [`Dataset`]
    /// as if it were any other [`Vec`]:
    ///
    /// ```ignore
    /// let ds: Dataset = Dataset::new(events);
    /// let event_0 = ds[0];
    /// ```
    #[cfg(feature = "mpi")]
    pub fn index_mpi(&self, index: usize, world: &SimpleCommunicator) -> &Event {
        // Locate the rank that owns the requested global index.
        let (_, displs) = world.get_counts_displs(self.n_events());
        let (owning_rank, local_index) = Dataset::get_rank_index(index, &displs, world);
        let mut serialized_event_buffer_len: usize = 0;
        let mut serialized_event_buffer: Vec<u8> = Vec::default();
        let config = bincode::config::standard();
        // Only the owning rank has the event; it serializes it once.
        if world.rank() == owning_rank {
            let event = self.index_local(local_index);
            serialized_event_buffer = bincode::serde::encode_to_vec(event, config).unwrap();
            serialized_event_buffer_len = serialized_event_buffer.len();
        }
        // Broadcast the buffer length first so the other ranks can size their buffers.
        world
            .process_at_rank(owning_rank)
            .broadcast_into(&mut serialized_event_buffer_len);
        if world.rank() != owning_rank {
            serialized_event_buffer = vec![0; serialized_event_buffer_len];
        }
        // Broadcast the serialized event itself, then decode it on every rank.
        world
            .process_at_rank(owning_rank)
            .broadcast_into(&mut serialized_event_buffer);
        let (event, _): (Event, usize) =
            bincode::serde::decode_from_slice(&serialized_event_buffer[..], config).unwrap();
        // NOTE(review): the decoded event is intentionally leaked to satisfy the `&Event`
        // return type; every call allocates an `Event` that is never freed.
        Box::leak(Box::new(event))
    }
}
200
impl Index<usize> for Dataset {
    type Output = Event;

    // Dispatches to the MPI-aware lookup when an MPI world is active, otherwise to a
    // plain local vector index.
    fn index(&self, index: usize) -> &Self::Output {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.index_mpi(index, &world);
            }
        }
        self.index_local(index)
    }
}
214
215impl Dataset {
216    /// Create a new [`Dataset`] from a list of [`Event`]s (non-MPI version).
217    ///
218    /// # Notes
219    ///
220    /// This method is not intended to be called in analyses but rather in writing methods
221    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::new`] instead.
222    pub fn new_local(events: Vec<Arc<Event>>) -> Self {
223        Dataset { events }
224    }
225
226    /// Create a new [`Dataset`] from a list of [`Event`]s (MPI-compatible version).
227    ///
228    /// # Notes
229    ///
230    /// This method is not intended to be called in analyses but rather in writing methods
231    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::new`] instead.
232    #[cfg(feature = "mpi")]
233    pub fn new_mpi(events: Vec<Arc<Event>>, world: &SimpleCommunicator) -> Self {
234        Dataset {
235            events: Dataset::partition(events, world)[world.rank() as usize].clone(),
236        }
237    }
238
239    /// Create a new [`Dataset`] from a list of [`Event`]s.
240    ///
241    /// This method is prefered for external use because it contains proper MPI construction
242    /// methods. Constructing a [`Dataset`] manually is possible, but may cause issues when
243    /// interfacing with MPI and should be avoided unless you know what you are doing.
244    pub fn new(events: Vec<Arc<Event>>) -> Self {
245        #[cfg(feature = "mpi")]
246        {
247            if let Some(world) = crate::mpi::get_world() {
248                return Dataset::new_mpi(events, &world);
249            }
250        }
251        Dataset::new_local(events)
252    }
253
254    /// The number of [`Event`]s in the [`Dataset`] (non-MPI version).
255    ///
256    /// # Notes
257    ///
258    /// This method is not intended to be called in analyses but rather in writing methods
259    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::n_events`] instead.
260    pub fn n_events_local(&self) -> usize {
261        self.events.len()
262    }
263
264    /// The number of [`Event`]s in the [`Dataset`] (MPI-compatible version).
265    ///
266    /// # Notes
267    ///
268    /// This method is not intended to be called in analyses but rather in writing methods
269    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::n_events`] instead.
270    #[cfg(feature = "mpi")]
271    pub fn n_events_mpi(&self, world: &SimpleCommunicator) -> usize {
272        let mut n_events_partitioned: Vec<usize> = vec![0; world.size() as usize];
273        let n_events_local = self.n_events_local();
274        world.all_gather_into(&n_events_local, &mut n_events_partitioned);
275        n_events_partitioned.iter().sum()
276    }
277
278    /// The number of [`Event`]s in the [`Dataset`].
279    pub fn n_events(&self) -> usize {
280        #[cfg(feature = "mpi")]
281        {
282            if let Some(world) = crate::mpi::get_world() {
283                return self.n_events_mpi(&world);
284            }
285        }
286        self.n_events_local()
287    }
288}
289
impl Dataset {
    /// Extract a list of weights over each [`Event`] in the [`Dataset`] (non-MPI version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::weights`] instead.
    pub fn weights_local(&self) -> Vec<Float> {
        #[cfg(feature = "rayon")]
        return self.events.par_iter().map(|e| e.weight).collect();
        #[cfg(not(feature = "rayon"))]
        return self.events.iter().map(|e| e.weight).collect();
    }

    /// Extract a list of weights over each [`Event`] in the [`Dataset`] (MPI-compatible version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::weights`] instead.
    #[cfg(feature = "mpi")]
    pub fn weights_mpi(&self, world: &SimpleCommunicator) -> Vec<Float> {
        // Gather every rank's local weights into one global, rank-ordered buffer.
        let local_weights = self.weights_local();
        let n_events = self.n_events();
        let mut buffer: Vec<Float> = vec![0.0; n_events];
        let (counts, displs) = world.get_counts_displs(n_events);
        {
            // Scoped so the `PartitionMut` borrow of `buffer` ends before we return it.
            let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
            world.all_gather_varcount_into(&local_weights, &mut partitioned_buffer);
        }
        buffer
    }

    /// Extract a list of weights over each [`Event`] in the [`Dataset`].
    pub fn weights(&self) -> Vec<Float> {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.weights_mpi(&world);
            }
        }
        self.weights_local()
    }

    /// Returns the sum of the weights for each [`Event`] in the [`Dataset`] (non-MPI version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::n_events_weighted`] instead.
    pub fn n_events_weighted_local(&self) -> Float {
        // Klein summation keeps the parallel reduction numerically accurate.
        #[cfg(feature = "rayon")]
        return self
            .events
            .par_iter()
            .map(|e| e.weight)
            .parallel_sum_with_accumulator::<Klein<Float>>();
        #[cfg(not(feature = "rayon"))]
        return self.events.iter().map(|e| e.weight).sum();
    }
    /// Returns the sum of the weights for each [`Event`] in the [`Dataset`] (MPI-compatible version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::n_events_weighted`] instead.
    #[cfg(feature = "mpi")]
    pub fn n_events_weighted_mpi(&self, world: &SimpleCommunicator) -> Float {
        // Each rank contributes its local weighted count; sum the per-rank partial sums.
        let mut n_events_weighted_partitioned: Vec<Float> = vec![0.0; world.size() as usize];
        let n_events_weighted_local = self.n_events_weighted_local();
        world.all_gather_into(&n_events_weighted_local, &mut n_events_weighted_partitioned);
        #[cfg(feature = "rayon")]
        return n_events_weighted_partitioned
            .into_par_iter()
            .parallel_sum_with_accumulator::<Klein<Float>>();
        #[cfg(not(feature = "rayon"))]
        return n_events_weighted_partitioned.iter().sum();
    }

    /// Returns the sum of the weights for each [`Event`] in the [`Dataset`].
    pub fn n_events_weighted(&self) -> Float {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.n_events_weighted_mpi(&world);
            }
        }
        self.n_events_weighted_local()
    }

    /// Generate a new dataset with the same length by resampling the events in the original dataset
    /// with replacement. This can be used to perform error analysis via the bootstrap method. (non-MPI version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::bootstrap`] instead.
    pub fn bootstrap_local(&self, seed: usize) -> Arc<Dataset> {
        // Draw n indices uniformly with replacement from a seeded RNG for reproducibility.
        let mut rng = fastrand::Rng::with_seed(seed as u64);
        let mut indices: Vec<usize> = (0..self.n_events())
            .map(|_| rng.usize(0..self.n_events()))
            .collect::<Vec<usize>>();
        // Sorting gives deterministic event ordering in the resampled dataset.
        indices.sort();
        #[cfg(feature = "rayon")]
        let bootstrapped_events: Vec<Arc<Event>> = indices
            .into_par_iter()
            .map(|idx| self.events[idx].clone())
            .collect();
        #[cfg(not(feature = "rayon"))]
        let bootstrapped_events: Vec<Arc<Event>> = indices
            .into_iter()
            .map(|idx| self.events[idx].clone())
            .collect();
        Arc::new(Dataset {
            events: bootstrapped_events,
        })
    }

    /// Generate a new dataset with the same length by resampling the events in the original dataset
    /// with replacement. This can be used to perform error analysis via the bootstrap method. (MPI-compatible version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::bootstrap`] instead.
    #[cfg(feature = "mpi")]
    pub fn bootstrap_mpi(&self, seed: usize, world: &SimpleCommunicator) -> Arc<Dataset> {
        // The root rank draws the global resampling indices so every rank agrees on them.
        let n_events = self.n_events();
        let mut indices: Vec<usize> = vec![0; n_events];
        if world.is_root() {
            let mut rng = fastrand::Rng::with_seed(seed as u64);
            indices = (0..n_events)
                .map(|_| rng.usize(0..n_events))
                .collect::<Vec<usize>>();
            indices.sort();
        }
        world.process_at_root().broadcast_into(&mut indices);
        let (_, displs) = world.get_counts_displs(self.n_events());
        let local_indices: Vec<usize> = indices
            .into_iter()
            .filter_map(|idx| {
                let (owning_rank, local_index) = Dataset::get_rank_index(idx, &displs, world);
                if world.rank() == owning_rank {
                    Some(local_index)
                } else {
                    None
                }
            })
            .collect();
        // `local_indices` only contains indices owned by the current rank, translating them into
        // local indices on the events vector.
        #[cfg(feature = "rayon")]
        let bootstrapped_events: Vec<Arc<Event>> = local_indices
            .into_par_iter()
            .map(|idx| self.events[idx].clone())
            .collect();
        #[cfg(not(feature = "rayon"))]
        let bootstrapped_events: Vec<Arc<Event>> = local_indices
            .into_iter()
            .map(|idx| self.events[idx].clone())
            .collect();
        Arc::new(Dataset {
            events: bootstrapped_events,
        })
    }

    /// Generate a new dataset with the same length by resampling the events in the original dataset
    /// with replacement. This can be used to perform error analysis via the bootstrap method.
    pub fn bootstrap(&self, seed: usize) -> Arc<Dataset> {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.bootstrap_mpi(seed, &world);
            }
        }
        self.bootstrap_local(seed)
    }

    /// Filter the [`Dataset`] by a given `predicate`, selecting events for which the predicate
    /// returns `true`.
    pub fn filter<P>(&self, predicate: P) -> Arc<Dataset>
    where
        P: Fn(&Event) -> bool + Send + Sync,
    {
        #[cfg(feature = "rayon")]
        let filtered_events = self
            .events
            .par_iter()
            .filter(|e| predicate(e))
            .cloned()
            .collect();
        #[cfg(not(feature = "rayon"))]
        let filtered_events = self
            .events
            .iter()
            .filter(|e| predicate(e))
            .cloned()
            .collect();
        Arc::new(Dataset {
            events: filtered_events,
        })
    }

    /// Bin a [`Dataset`] by the value of the given [`Variable`] into a number of `bins` within the
    /// given `range`.
    ///
    /// Events whose variable value falls outside `range` are dropped.
    pub fn bin_by<V>(&self, variable: V, bins: usize, range: (Float, Float)) -> BinnedDataset
    where
        V: Variable,
    {
        let bin_width = (range.1 - range.0) / bins as Float;
        let bin_edges = get_bin_edges(bins, range);
        #[cfg(feature = "rayon")]
        let evaluated: Vec<(usize, &Arc<Event>)> = self
            .events
            .par_iter()
            .filter_map(|event| {
                let value = variable.value(event.as_ref());
                if value >= range.0 && value < range.1 {
                    let bin_index = ((value - range.0) / bin_width) as usize;
                    // Clamp to the last bin to guard against floating-point round-up at
                    // the upper edge.
                    let bin_index = bin_index.min(bins - 1);
                    Some((bin_index, event))
                } else {
                    None
                }
            })
            .collect();
        #[cfg(not(feature = "rayon"))]
        let evaluated: Vec<(usize, &Arc<Event>)> = self
            .events
            .iter()
            .filter_map(|event| {
                let value = variable.value(event.as_ref());
                if value >= range.0 && value < range.1 {
                    let bin_index = ((value - range.0) / bin_width) as usize;
                    // Clamp to the last bin to guard against floating-point round-up at
                    // the upper edge.
                    let bin_index = bin_index.min(bins - 1);
                    Some((bin_index, event))
                } else {
                    None
                }
            })
            .collect();
        // Distribute the (bin, event) pairs into per-bin event lists.
        let mut binned_events: Vec<Vec<Arc<Event>>> = vec![Vec::default(); bins];
        for (bin_index, event) in evaluated {
            binned_events[bin_index].push(event.clone());
        }
        BinnedDataset {
            #[cfg(feature = "rayon")]
            datasets: binned_events
                .into_par_iter()
                .map(|events| Arc::new(Dataset { events }))
                .collect(),
            #[cfg(not(feature = "rayon"))]
            datasets: binned_events
                .into_iter()
                .map(|events| Arc::new(Dataset { events }))
                .collect(),
            edges: bin_edges,
        }
    }

    /// Boost all the four-momenta in all [`Event`]s to the rest frame of the given set of
    /// four-momenta by indices.
    pub fn boost_to_rest_frame_of<T: AsRef<[usize]> + Sync>(&self, indices: T) -> Arc<Dataset> {
        #[cfg(feature = "rayon")]
        {
            Arc::new(Dataset {
                events: self
                    .events
                    .par_iter()
                    .map(|event| Arc::new(event.boost_to_rest_frame_of(indices.as_ref())))
                    .collect(),
            })
        }
        #[cfg(not(feature = "rayon"))]
        {
            Arc::new(Dataset {
                events: self
                    .events
                    .iter()
                    .map(|event| Arc::new(event.boost_to_rest_frame_of(indices.as_ref())))
                    .collect(),
            })
        }
    }
}
575
576impl_op_ex!(+ |a: &Dataset, b: &Dataset| ->  Dataset { Dataset { events: a.events.iter().chain(b.events.iter()).cloned().collect() }});
577
578fn batch_to_event(batch: &RecordBatch, row: usize) -> Event {
579    let mut p4s = Vec::new();
580    let mut aux = Vec::new();
581
582    let p4_count = batch
583        .schema()
584        .fields()
585        .iter()
586        .filter(|field| field.name().starts_with(P4_PREFIX))
587        .count()
588        / 4;
589    let aux_count = batch
590        .schema()
591        .fields()
592        .iter()
593        .filter(|field| field.name().starts_with(AUX_PREFIX))
594        .count()
595        / 3;
596
597    for i in 0..p4_count {
598        let e = batch
599            .column_by_name(&format!("{}{}_E", P4_PREFIX, i))
600            .unwrap()
601            .as_any()
602            .downcast_ref::<Float32Array>()
603            .unwrap()
604            .value(row) as Float;
605        let px = batch
606            .column_by_name(&format!("{}{}_Px", P4_PREFIX, i))
607            .unwrap()
608            .as_any()
609            .downcast_ref::<Float32Array>()
610            .unwrap()
611            .value(row) as Float;
612        let py = batch
613            .column_by_name(&format!("{}{}_Py", P4_PREFIX, i))
614            .unwrap()
615            .as_any()
616            .downcast_ref::<Float32Array>()
617            .unwrap()
618            .value(row) as Float;
619        let pz = batch
620            .column_by_name(&format!("{}{}_Pz", P4_PREFIX, i))
621            .unwrap()
622            .as_any()
623            .downcast_ref::<Float32Array>()
624            .unwrap()
625            .value(row) as Float;
626        p4s.push(Vec4::new(px, py, pz, e));
627    }
628
629    // TODO: insert empty vectors if not provided
630    for i in 0..aux_count {
631        let x = batch
632            .column_by_name(&format!("{}{}_x", AUX_PREFIX, i))
633            .unwrap()
634            .as_any()
635            .downcast_ref::<Float32Array>()
636            .unwrap()
637            .value(row) as Float;
638        let y = batch
639            .column_by_name(&format!("{}{}_y", AUX_PREFIX, i))
640            .unwrap()
641            .as_any()
642            .downcast_ref::<Float32Array>()
643            .unwrap()
644            .value(row) as Float;
645        let z = batch
646            .column_by_name(&format!("{}{}_z", AUX_PREFIX, i))
647            .unwrap()
648            .as_any()
649            .downcast_ref::<Float32Array>()
650            .unwrap()
651            .value(row) as Float;
652        aux.push(Vec3::new(x, y, z));
653    }
654
655    let weight = batch
656        .column(19)
657        .as_any()
658        .downcast_ref::<Float32Array>()
659        .unwrap()
660        .value(row) as Float;
661
662    Event { p4s, aux, weight }
663}
664
665/// Open a Parquet file and read the data into a [`Dataset`].
666pub fn open<T: AsRef<str>>(file_path: T) -> Result<Arc<Dataset>, LadduError> {
667    // TODO: make this read in directly to MPI ranks
668    let file_path = Path::new(&*shellexpand::full(file_path.as_ref())?).canonicalize()?;
669    let file = File::open(file_path)?;
670    let builder = ParquetRecordBatchReaderBuilder::try_new(file)?;
671    let reader = builder.build()?;
672    let batches: Vec<RecordBatch> = reader.collect::<Result<Vec<_>, _>>()?;
673
674    #[cfg(feature = "rayon")]
675    let events: Vec<Arc<Event>> = batches
676        .into_par_iter()
677        .flat_map(|batch| {
678            let num_rows = batch.num_rows();
679            let mut local_events = Vec::with_capacity(num_rows);
680
681            // Process each row in the batch
682            for row in 0..num_rows {
683                let event = batch_to_event(&batch, row);
684                local_events.push(Arc::new(event));
685            }
686            local_events
687        })
688        .collect();
689    #[cfg(not(feature = "rayon"))]
690    let events: Vec<Arc<Event>> = batches
691        .into_iter()
692        .flat_map(|batch| {
693            let num_rows = batch.num_rows();
694            let mut local_events = Vec::with_capacity(num_rows);
695
696            // Process each row in the batch
697            for row in 0..num_rows {
698                let event = batch_to_event(&batch, row);
699                local_events.push(Arc::new(event));
700            }
701            local_events
702        })
703        .collect();
704    Ok(Arc::new(Dataset::new(events)))
705}
706
707/// Open a Parquet file and read the data into a [`Dataset`]. This method boosts each event to the
708/// rest frame of the four-momenta at the given indices.
709pub fn open_boosted_to_rest_frame_of<T: AsRef<str>, I: AsRef<[usize]> + Sync>(
710    file_path: T,
711    indices: I,
712) -> Result<Arc<Dataset>, LadduError> {
713    // TODO: make this read in directly to MPI ranks
714    let file_path = Path::new(&*shellexpand::full(file_path.as_ref())?).canonicalize()?;
715    let file = File::open(file_path)?;
716    let builder = ParquetRecordBatchReaderBuilder::try_new(file)?;
717    let reader = builder.build()?;
718    let batches: Vec<RecordBatch> = reader.collect::<Result<Vec<_>, _>>()?;
719
720    #[cfg(feature = "rayon")]
721    let events: Vec<Arc<Event>> = batches
722        .into_par_iter()
723        .flat_map(|batch| {
724            let num_rows = batch.num_rows();
725            let mut local_events = Vec::with_capacity(num_rows);
726
727            // Process each row in the batch
728            for row in 0..num_rows {
729                let mut event = batch_to_event(&batch, row);
730                event = event.boost_to_rest_frame_of(indices.as_ref());
731                local_events.push(Arc::new(event));
732            }
733            local_events
734        })
735        .collect();
736    #[cfg(not(feature = "rayon"))]
737    let events: Vec<Arc<Event>> = batches
738        .into_iter()
739        .flat_map(|batch| {
740            let num_rows = batch.num_rows();
741            let mut local_events = Vec::with_capacity(num_rows);
742
743            // Process each row in the batch
744            for row in 0..num_rows {
745                let mut event = batch_to_event(&batch, row);
746                event = event.boost_to_rest_frame_of(indices.as_ref());
747                local_events.push(Arc::new(event));
748            }
749            local_events
750        })
751        .collect();
752    Ok(Arc::new(Dataset::new(events)))
753}
754
/// A list of [`Dataset`]s formed by binning [`Event`]s by some [`Variable`].
pub struct BinnedDataset {
    // One dataset per bin, in ascending bin order.
    datasets: Vec<Arc<Dataset>>,
    // The `n_bins + 1` bin edges used to partition the data.
    edges: Vec<Float>,
}
760
// Indexing a `BinnedDataset` yields the dataset for the given bin.
impl Index<usize> for BinnedDataset {
    type Output = Arc<Dataset>;

    fn index(&self, index: usize) -> &Self::Output {
        &self.datasets[index]
    }
}
768
// Mutable indexing allows replacing the dataset stored in a bin.
impl IndexMut<usize> for BinnedDataset {
    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
        &mut self.datasets[index]
    }
}
774
// Deref to the underlying Vec so slice/Vec methods (len, iter, ...) work directly.
impl Deref for BinnedDataset {
    type Target = Vec<Arc<Dataset>>;

    fn deref(&self) -> &Self::Target {
        &self.datasets
    }
}
782
// Mutable counterpart of the Deref impl above.
impl DerefMut for BinnedDataset {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.datasets
    }
}
788
789impl BinnedDataset {
790    /// The number of bins in the [`BinnedDataset`].
791    pub fn n_bins(&self) -> usize {
792        self.datasets.len()
793    }
794
795    /// Returns a list of the bin edges that were used to form the [`BinnedDataset`].
796    pub fn edges(&self) -> Vec<Float> {
797        self.edges.clone()
798    }
799
800    /// Returns the range that was used to form the [`BinnedDataset`].
801    pub fn range(&self) -> (Float, Float) {
802        (self.edges[0], self.edges[self.n_bins()])
803    }
804}
805
806#[cfg(test)]
807mod tests {
808    use super::*;
809    use approx::{assert_relative_eq, assert_relative_ne};
810    use serde::{Deserialize, Serialize};
    // The canned test event should contain four particles, one auxiliary (polarization)
    // vector, and the expected weight.
    #[test]
    fn test_event_creation() {
        let event = test_event();
        assert_eq!(event.p4s.len(), 4);
        assert_eq!(event.aux.len(), 1);
        assert_relative_eq!(event.weight, 0.48)
    }
818
    // `get_p4_sum` should add the selected four-momenta component-wise.
    #[test]
    fn test_event_p4_sum() {
        let event = test_event();
        let sum = event.get_p4_sum([2, 3]);
        assert_relative_eq!(sum.px(), event.p4s[2].px() + event.p4s[3].px());
        assert_relative_eq!(sum.py(), event.p4s[2].py() + event.p4s[3].py());
        assert_relative_eq!(sum.pz(), event.p4s[2].pz() + event.p4s[3].pz());
        assert_relative_eq!(sum.e(), event.p4s[2].e() + event.p4s[3].e());
    }
828
    // After boosting into the rest frame of particles [1, 2, 3], their summed
    // three-momentum must vanish (to within floating-point tolerance).
    #[test]
    fn test_event_boost() {
        let event = test_event();
        let event_boosted = event.boost_to_rest_frame_of([1, 2, 3]);
        let p4_sum = event_boosted.get_p4_sum([1, 2, 3]);
        assert_relative_eq!(p4_sum.px(), 0.0, epsilon = Float::EPSILON.sqrt());
        assert_relative_eq!(p4_sum.py(), 0.0, epsilon = Float::EPSILON.sqrt());
        assert_relative_eq!(p4_sum.pz(), 0.0, epsilon = Float::EPSILON.sqrt());
    }
838
839    #[test]
840    fn test_dataset_size_check() {
841        let mut dataset = Dataset::default();
842        assert_eq!(dataset.n_events(), 0);
843        dataset.events.push(Arc::new(test_event()));
844        assert_eq!(dataset.n_events(), 1);
845    }
846
847    #[test]
848    fn test_dataset_sum() {
849        let dataset = test_dataset();
850        let dataset2 = Dataset::new(vec![Arc::new(Event {
851            p4s: test_event().p4s,
852            aux: test_event().aux,
853            weight: 0.52,
854        })]);
855        let dataset_sum = &dataset + &dataset2;
856        assert_eq!(dataset_sum[0].weight, dataset[0].weight);
857        assert_eq!(dataset_sum[1].weight, dataset2[0].weight);
858    }
859
860    #[test]
861    fn test_dataset_weights() {
862        let mut dataset = Dataset::default();
863        dataset.events.push(Arc::new(test_event()));
864        dataset.events.push(Arc::new(Event {
865            p4s: test_event().p4s,
866            aux: test_event().aux,
867            weight: 0.52,
868        }));
869        let weights = dataset.weights();
870        assert_eq!(weights.len(), 2);
871        assert_relative_eq!(weights[0], 0.48);
872        assert_relative_eq!(weights[1], 0.52);
873        assert_relative_eq!(dataset.n_events_weighted(), 1.0);
874    }
875
876    #[test]
877    fn test_dataset_filtering() {
878        let mut dataset = test_dataset();
879        dataset.events.push(Arc::new(Event {
880            p4s: vec![
881                Vec3::new(0.0, 0.0, 5.0).with_mass(0.0),
882                Vec3::new(0.0, 0.0, 1.0).with_mass(1.0),
883            ],
884            aux: vec![],
885            weight: 1.0,
886        }));
887
888        let filtered = dataset.filter(|event| event.p4s.len() == 2);
889        assert_eq!(filtered.n_events(), 1);
890        assert_eq!(filtered[0].p4s.len(), 2);
891    }
892
893    #[test]
894    fn test_dataset_boost() {
895        let dataset = test_dataset();
896        let dataset_boosted = dataset.boost_to_rest_frame_of([1, 2, 3]);
897        let p4_sum = dataset_boosted[0].get_p4_sum([1, 2, 3]);
898        assert_relative_eq!(p4_sum.px(), 0.0, epsilon = Float::EPSILON.sqrt());
899        assert_relative_eq!(p4_sum.py(), 0.0, epsilon = Float::EPSILON.sqrt());
900        assert_relative_eq!(p4_sum.pz(), 0.0, epsilon = Float::EPSILON.sqrt());
901    }
902
903    #[test]
904    fn test_binned_dataset() {
905        let dataset = Dataset::new(vec![
906            Arc::new(Event {
907                p4s: vec![Vec3::new(0.0, 0.0, 1.0).with_mass(1.0)],
908                aux: vec![],
909                weight: 1.0,
910            }),
911            Arc::new(Event {
912                p4s: vec![Vec3::new(0.0, 0.0, 2.0).with_mass(2.0)],
913                aux: vec![],
914                weight: 2.0,
915            }),
916        ]);
917
918        #[derive(Clone, Serialize, Deserialize, Debug)]
919        struct BeamEnergy;
920        impl Display for BeamEnergy {
921            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
922                write!(f, "BeamEnergy")
923            }
924        }
925        #[typetag::serde]
926        impl Variable for BeamEnergy {
927            fn value(&self, event: &Event) -> Float {
928                event.p4s[0].e()
929            }
930        }
931        assert_eq!(BeamEnergy.to_string(), "BeamEnergy");
932
933        // Test binning by first particle energy
934        let binned = dataset.bin_by(BeamEnergy, 2, (0.0, 3.0));
935
936        assert_eq!(binned.n_bins(), 2);
937        assert_eq!(binned.edges().len(), 3);
938        assert_relative_eq!(binned.edges()[0], 0.0);
939        assert_relative_eq!(binned.edges()[2], 3.0);
940        assert_eq!(binned[0].n_events(), 1);
941        assert_relative_eq!(binned[0].n_events_weighted(), 1.0);
942        assert_eq!(binned[1].n_events(), 1);
943        assert_relative_eq!(binned[1].n_events_weighted(), 2.0);
944    }
945
946    #[test]
947    fn test_dataset_bootstrap() {
948        let mut dataset = test_dataset();
949        dataset.events.push(Arc::new(Event {
950            p4s: test_event().p4s.clone(),
951            aux: test_event().aux.clone(),
952            weight: 1.0,
953        }));
954        assert_relative_ne!(dataset[0].weight, dataset[1].weight);
955
956        let bootstrapped = dataset.bootstrap(43);
957        assert_eq!(bootstrapped.n_events(), dataset.n_events());
958        assert_relative_eq!(bootstrapped[0].weight, bootstrapped[1].weight);
959
960        // Test empty dataset bootstrap
961        let empty_dataset = Dataset::default();
962        let empty_bootstrap = empty_dataset.bootstrap(43);
963        assert_eq!(empty_bootstrap.n_events(), 0);
964    }
965    #[test]
966    fn test_event_display() {
967        let event = test_event();
968        let display_string = format!("{}", event);
969        assert_eq!(
970            display_string,
971            "Event:\n  p4s:\n    [e = 8.74700; p = (0.00000, 0.00000, 8.74700); m = 0.00000]\n    [e = 1.10334; p = (0.11900, 0.37400, 0.22200); m = 1.00700]\n    [e = 3.13671; p = (-0.11200, 0.29300, 3.08100); m = 0.49800]\n    [e = 5.50925; p = (-0.00700, -0.66700, 5.44600); m = 0.49800]\n  eps:\n    [0.385, 0.022, 0]\n  weight:\n    0.48\n"
972        );
973    }
974}