// laddu_core/data.rs

1use accurate::{sum::Klein, traits::*};
2use arrow::array::Float32Array;
3use arrow::record_batch::RecordBatch;
4use auto_ops::impl_op_ex;
5use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
6use serde::{Deserialize, Serialize};
7use std::ops::{Deref, DerefMut, Index, IndexMut};
8use std::path::Path;
9use std::sync::Arc;
10use std::{fmt::Display, fs::File};
11
12#[cfg(feature = "rayon")]
13use rayon::prelude::*;
14
15#[cfg(feature = "mpi")]
16use mpi::{datatype::PartitionMut, topology::SimpleCommunicator, traits::*};
17
18#[cfg(feature = "mpi")]
19use crate::mpi::LadduMPI;
20
21use crate::utils::get_bin_edges;
22use crate::{
23    utils::{
24        variables::{Variable, VariableExpression},
25        vectors::{Vec3, Vec4},
26    },
27    Float, LadduError,
28};
29
30const P4_PREFIX: &str = "p4_";
31const AUX_PREFIX: &str = "aux_";
32
33/// An event that can be used to test the implementation of an
34/// [`Amplitude`](crate::amplitudes::Amplitude). This particular event contains the reaction
35/// $`\gamma p \to K_S^0 K_S^0 p`$ with a polarized photon beam.
36pub fn test_event() -> Event {
37    use crate::utils::vectors::*;
38    Event {
39        p4s: vec![
40            Vec3::new(0.0, 0.0, 8.747).with_mass(0.0),         // beam
41            Vec3::new(0.119, 0.374, 0.222).with_mass(1.007),   // "proton"
42            Vec3::new(-0.112, 0.293, 3.081).with_mass(0.498),  // "kaon"
43            Vec3::new(-0.007, -0.667, 5.446).with_mass(0.498), // "kaon"
44        ],
45        aux: vec![Vec3::new(0.385, 0.022, 0.000)],
46        weight: 0.48,
47    }
48}
49
50/// An dataset that can be used to test the implementation of an
51/// [`Amplitude`](crate::amplitudes::Amplitude). This particular dataset contains a singular
52/// [`Event`] generated from [`test_event`].
53pub fn test_dataset() -> Dataset {
54    Dataset::new(vec![Arc::new(test_event())])
55}
56
/// A single event in a [`Dataset`] containing all the relevant particle information.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Event {
    /// A list of four-momenta, one per particle in the event.
    pub p4s: Vec<Vec4>,
    /// A list of auxiliary vectors which can be used to store data like particle polarization.
    pub aux: Vec<Vec3>,
    /// The weight given to the event (used in weighted sums and fits).
    pub weight: Float,
}
67
68impl Display for Event {
69    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70        writeln!(f, "Event:")?;
71        writeln!(f, "  p4s:")?;
72        for p4 in &self.p4s {
73            writeln!(f, "    {}", p4.to_p4_string())?;
74        }
75        writeln!(f, "  eps:")?;
76        for eps_vec in &self.aux {
77            writeln!(f, "    [{}, {}, {}]", eps_vec.x, eps_vec.y, eps_vec.z)?;
78        }
79        writeln!(f, "  weight:")?;
80        writeln!(f, "    {}", self.weight)?;
81        Ok(())
82    }
83}
84
85impl Event {
86    /// Return a four-momentum from the sum of four-momenta at the given indices in the [`Event`].
87    pub fn get_p4_sum<T: AsRef<[usize]>>(&self, indices: T) -> Vec4 {
88        indices.as_ref().iter().map(|i| self.p4s[*i]).sum::<Vec4>()
89    }
90    /// Boost all the four-momenta in the [`Event`] to the rest frame of the given set of
91    /// four-momenta by indices.
92    pub fn boost_to_rest_frame_of<T: AsRef<[usize]>>(&self, indices: T) -> Self {
93        let frame = self.get_p4_sum(indices);
94        Event {
95            p4s: self
96                .p4s
97                .iter()
98                .map(|p4| p4.boost(&(-frame.beta())))
99                .collect(),
100            aux: self.aux.clone(),
101            weight: self.weight,
102        }
103    }
104    /// Evaluate a [`Variable`] on an [`Event`].
105    pub fn evaluate<V: Variable>(&self, variable: &V) -> Float {
106        variable.value(self)
107    }
108}
109
/// A collection of [`Event`]s.
///
/// Under MPI, each rank holds only its own partition of the events; use the
/// feature-gated accessors rather than touching `events` directly in that case.
#[derive(Debug, Clone, Default)]
pub struct Dataset {
    /// The [`Event`]s contained in the [`Dataset`]
    pub events: Vec<Arc<Event>>,
}
116
impl Dataset {
    /// Get a reference to the [`Event`] at the given index in the [`Dataset`] (non-MPI
    /// version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just index into a [`Dataset`]
    /// as if it were any other [`Vec`]:
    ///
    /// ```ignore
    /// let ds: Dataset = Dataset::new(events);
    /// let event_0 = ds[0];
    /// ```
    pub fn index_local(&self, index: usize) -> &Event {
        &self.events[index]
    }

    /// Map a global event `index` onto `(owning_rank, local_index)` using the
    /// displacement table `displs` (each entry is the global index of the first event
    /// owned by that rank).
    #[cfg(feature = "mpi")]
    fn get_rank_index(index: usize, displs: &[i32], world: &SimpleCommunicator) -> (i32, usize) {
        // The first displacement strictly greater than `index` belongs to the rank
        // *after* the owner. `displs[0]` is 0, so `i >= 1` whenever this branch fires
        // and the `i - 1` access is in bounds.
        for (i, &displ) in displs.iter().enumerate() {
            if displ as usize > index {
                return (i as i32 - 1, index - displs[i - 1] as usize);
            }
        }
        // No displacement exceeded `index`, so the last rank owns it.
        (
            world.size() - 1,
            index - displs[world.size() as usize - 1] as usize,
        )
    }

    /// Return the counts, displacements, and local indices for the current MPI rank.
    ///
    /// This method is useful for processing scalar values over a batch of [`Event`]s rather than
    /// the entire dataset.
    #[cfg(feature = "mpi")]
    pub fn get_counts_displs_locals_from_indices(
        &self,
        indices: &[usize],
        world: &SimpleCommunicator,
    ) -> (Vec<i32>, Vec<i32>, Vec<usize>) {
        let mut counts = vec![0i32; world.size() as usize];
        let mut displs = vec![0i32; world.size() as usize];
        let (_, global_displs) = world.get_counts_displs(self.n_events());
        // Translate each global index into (owning rank, index local to that rank).
        let owning_rank_locals: Vec<(i32, usize)> = indices
            .iter()
            .map(|i| Dataset::get_rank_index(*i, &global_displs, world))
            .collect();
        // Group the local indices by the rank that owns them, preserving input order.
        let mut locals_by_rank = vec![Vec::new(); world.size() as usize];
        for &(r, li) in owning_rank_locals.iter() {
            locals_by_rank[r as usize].push(li);
        }
        // Counts/displacements for a varcount gather: each rank contributes as many
        // scalars as indices it owns, packed contiguously in rank order.
        for rank in 0..world.size() as usize {
            counts[rank] = locals_by_rank[rank].len() as i32;
            displs[rank] = if rank == 0 {
                0
            } else {
                displs[rank - 1] + counts[rank - 1]
            };
        }
        (
            counts,
            displs,
            locals_by_rank[world.rank() as usize].clone(),
        )
    }

    /// Return the counts, displacements, and local indices for the current MPI rank, flattened to
    /// account for vectors of a given internal length.
    ///
    /// This method is useful for processing vector values over a batch of [`Event`]s rather than
    /// the entire dataset.
    #[cfg(feature = "mpi")]
    pub fn get_flattened_counts_displs_locals_from_indices(
        &self,
        indices: &[usize],
        internal_len: usize,
        world: &SimpleCommunicator,
    ) -> (Vec<i32>, Vec<i32>, Vec<usize>) {
        let mut counts = vec![0i32; world.size() as usize];
        let mut displs = vec![0i32; world.size() as usize];
        let (_, global_displs) = world.get_counts_displs(self.n_events());
        let owning_rank_locals: Vec<(i32, usize)> = indices
            .iter()
            .map(|i| Dataset::get_rank_index(*i, &global_displs, world))
            .collect();
        let mut locals_by_rank = vec![Vec::new(); world.size() as usize];
        for &(r, li) in owning_rank_locals.iter() {
            locals_by_rank[r as usize].push(li);
        }
        for rank in 0..world.size() as usize {
            // Same as the scalar version, but each owned index contributes
            // `internal_len` contiguous elements to the gathered buffer.
            counts[rank] = (locals_by_rank[rank].len() * internal_len) as i32;
            displs[rank] = if rank == 0 {
                0
            } else {
                displs[rank - 1] + counts[rank - 1]
            };
        }
        (
            counts,
            displs,
            locals_by_rank[world.rank() as usize].clone(),
        )
    }

    /// Split `events` into one contiguous chunk per MPI rank according to the world's
    /// counts/displacements. Returns the full table (every rank's chunk), not just the
    /// caller's.
    #[cfg(feature = "mpi")]
    fn partition(events: Vec<Arc<Event>>, world: &SimpleCommunicator) -> Vec<Vec<Arc<Event>>> {
        let (counts, displs) = world.get_counts_displs(events.len());
        counts
            .iter()
            .zip(displs.iter())
            .map(|(&count, &displ)| {
                events
                    .iter()
                    .skip(displ as usize)
                    .take(count as usize)
                    .cloned()
                    .collect()
            })
            .collect()
    }

    /// Get a reference to the [`Event`] at the given index in the [`Dataset`]
    /// (MPI-compatible version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just index into a [`Dataset`]
    /// as if it were any other [`Vec`]:
    ///
    /// ```ignore
    /// let ds: Dataset = Dataset::new(events);
    /// let event_0 = ds[0];
    /// ```
    #[cfg(feature = "mpi")]
    pub fn index_mpi(&self, index: usize, world: &SimpleCommunicator) -> &Event {
        let (_, displs) = world.get_counts_displs(self.n_events());
        let (owning_rank, local_index) = Dataset::get_rank_index(index, &displs, world);
        let mut serialized_event_buffer_len: usize = 0;
        let mut serialized_event_buffer: Vec<u8> = Vec::default();
        let config = bincode::config::standard();
        // Only the owning rank has the event; serialize it there, then broadcast the
        // byte length so other ranks can size their receive buffers.
        if world.rank() == owning_rank {
            let event = self.index_local(local_index);
            serialized_event_buffer = bincode::serde::encode_to_vec(event, config).unwrap();
            serialized_event_buffer_len = serialized_event_buffer.len();
        }
        world
            .process_at_rank(owning_rank)
            .broadcast_into(&mut serialized_event_buffer_len);
        if world.rank() != owning_rank {
            serialized_event_buffer = vec![0; serialized_event_buffer_len];
        }
        world
            .process_at_rank(owning_rank)
            .broadcast_into(&mut serialized_event_buffer);
        let (event, _): (Event, usize) =
            bincode::serde::decode_from_slice(&serialized_event_buffer[..], config).unwrap();
        // NOTE(review): `Box::leak` satisfies the `&Event` return type required by the
        // `Index` trait, but it heap-allocates an `Event` that is never reclaimed on
        // *every* call — repeated indexing under MPI leaks memory. Consider a caching
        // scheme or an owned-return accessor for hot paths.
        Box::leak(Box::new(event))
    }
}
278
279impl Index<usize> for Dataset {
280    type Output = Event;
281
282    fn index(&self, index: usize) -> &Self::Output {
283        #[cfg(feature = "mpi")]
284        {
285            if let Some(world) = crate::mpi::get_world() {
286                return self.index_mpi(index, &world);
287            }
288        }
289        self.index_local(index)
290    }
291}
292
293impl Dataset {
294    /// Create a new [`Dataset`] from a list of [`Event`]s (non-MPI version).
295    ///
296    /// # Notes
297    ///
298    /// This method is not intended to be called in analyses but rather in writing methods
299    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::new`] instead.
300    pub fn new_local(events: Vec<Arc<Event>>) -> Self {
301        Dataset { events }
302    }
303
304    /// Create a new [`Dataset`] from a list of [`Event`]s (MPI-compatible version).
305    ///
306    /// # Notes
307    ///
308    /// This method is not intended to be called in analyses but rather in writing methods
309    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::new`] instead.
310    #[cfg(feature = "mpi")]
311    pub fn new_mpi(events: Vec<Arc<Event>>, world: &SimpleCommunicator) -> Self {
312        Dataset {
313            events: Dataset::partition(events, world)[world.rank() as usize].clone(),
314        }
315    }
316
317    /// Create a new [`Dataset`] from a list of [`Event`]s.
318    ///
319    /// This method is prefered for external use because it contains proper MPI construction
320    /// methods. Constructing a [`Dataset`] manually is possible, but may cause issues when
321    /// interfacing with MPI and should be avoided unless you know what you are doing.
322    pub fn new(events: Vec<Arc<Event>>) -> Self {
323        #[cfg(feature = "mpi")]
324        {
325            if let Some(world) = crate::mpi::get_world() {
326                return Dataset::new_mpi(events, &world);
327            }
328        }
329        Dataset::new_local(events)
330    }
331
332    /// The number of [`Event`]s in the [`Dataset`] (non-MPI version).
333    ///
334    /// # Notes
335    ///
336    /// This method is not intended to be called in analyses but rather in writing methods
337    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::n_events`] instead.
338    pub fn n_events_local(&self) -> usize {
339        self.events.len()
340    }
341
342    /// The number of [`Event`]s in the [`Dataset`] (MPI-compatible version).
343    ///
344    /// # Notes
345    ///
346    /// This method is not intended to be called in analyses but rather in writing methods
347    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::n_events`] instead.
348    #[cfg(feature = "mpi")]
349    pub fn n_events_mpi(&self, world: &SimpleCommunicator) -> usize {
350        let mut n_events_partitioned: Vec<usize> = vec![0; world.size() as usize];
351        let n_events_local = self.n_events_local();
352        world.all_gather_into(&n_events_local, &mut n_events_partitioned);
353        n_events_partitioned.iter().sum()
354    }
355
356    /// The number of [`Event`]s in the [`Dataset`].
357    pub fn n_events(&self) -> usize {
358        #[cfg(feature = "mpi")]
359        {
360            if let Some(world) = crate::mpi::get_world() {
361                return self.n_events_mpi(&world);
362            }
363        }
364        self.n_events_local()
365    }
366}
367
impl Dataset {
    /// Extract a list of weights over each [`Event`] in the [`Dataset`] (non-MPI version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::weights`] instead.
    pub fn weights_local(&self) -> Vec<Float> {
        // Exactly one of these cfg-gated returns is compiled in.
        #[cfg(feature = "rayon")]
        return self.events.par_iter().map(|e| e.weight).collect();
        #[cfg(not(feature = "rayon"))]
        return self.events.iter().map(|e| e.weight).collect();
    }

    /// Extract a list of weights over each [`Event`] in the [`Dataset`] (MPI-compatible version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::weights`] instead.
    #[cfg(feature = "mpi")]
    pub fn weights_mpi(&self, world: &SimpleCommunicator) -> Vec<Float> {
        let local_weights = self.weights_local();
        // Global event count: the gathered buffer holds every rank's weights.
        let n_events = self.n_events();
        let mut buffer: Vec<Float> = vec![0.0; n_events];
        let (counts, displs) = world.get_counts_displs(n_events);
        {
            // Scoped so the mutable partition view is dropped before `buffer` is returned.
            let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
            world.all_gather_varcount_into(&local_weights, &mut partitioned_buffer);
        }
        buffer
    }

    /// Extract a list of weights over each [`Event`] in the [`Dataset`].
    pub fn weights(&self) -> Vec<Float> {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.weights_mpi(&world);
            }
        }
        self.weights_local()
    }

    /// Returns the sum of the weights for each [`Event`] in the [`Dataset`] (non-MPI version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::n_events_weighted`] instead.
    pub fn n_events_weighted_local(&self) -> Float {
        // Klein accumulation compensates floating-point error in the parallel sum.
        #[cfg(feature = "rayon")]
        return self
            .events
            .par_iter()
            .map(|e| e.weight)
            .parallel_sum_with_accumulator::<Klein<Float>>();
        #[cfg(not(feature = "rayon"))]
        return self.events.iter().map(|e| e.weight).sum();
    }
    /// Returns the sum of the weights for each [`Event`] in the [`Dataset`] (MPI-compatible version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::n_events_weighted`] instead.
    #[cfg(feature = "mpi")]
    pub fn n_events_weighted_mpi(&self, world: &SimpleCommunicator) -> Float {
        // Gather each rank's partial weighted sum, then reduce them locally.
        let mut n_events_weighted_partitioned: Vec<Float> = vec![0.0; world.size() as usize];
        let n_events_weighted_local = self.n_events_weighted_local();
        world.all_gather_into(&n_events_weighted_local, &mut n_events_weighted_partitioned);
        #[cfg(feature = "rayon")]
        return n_events_weighted_partitioned
            .into_par_iter()
            .parallel_sum_with_accumulator::<Klein<Float>>();
        #[cfg(not(feature = "rayon"))]
        return n_events_weighted_partitioned.iter().sum();
    }

    /// Returns the sum of the weights for each [`Event`] in the [`Dataset`].
    pub fn n_events_weighted(&self) -> Float {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.n_events_weighted_mpi(&world);
            }
        }
        self.n_events_weighted_local()
    }

    /// Generate a new dataset with the same length by resampling the events in the original datset
    /// with replacement. This can be used to perform error analysis via the bootstrap method. (non-MPI version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::bootstrap`] instead.
    pub fn bootstrap_local(&self, seed: usize) -> Arc<Dataset> {
        // Fixed seed makes the resampling reproducible.
        let mut rng = fastrand::Rng::with_seed(seed as u64);
        let mut indices: Vec<usize> = (0..self.n_events())
            .map(|_| rng.usize(0..self.n_events()))
            .collect::<Vec<usize>>();
        // Sorting keeps the resampled event order deterministic and matches the
        // ordering produced by the MPI version.
        indices.sort();
        #[cfg(feature = "rayon")]
        let bootstrapped_events: Vec<Arc<Event>> = indices
            .into_par_iter()
            .map(|idx| self.events[idx].clone())
            .collect();
        #[cfg(not(feature = "rayon"))]
        let bootstrapped_events: Vec<Arc<Event>> = indices
            .into_iter()
            .map(|idx| self.events[idx].clone())
            .collect();
        Arc::new(Dataset {
            events: bootstrapped_events,
        })
    }

    /// Generate a new dataset with the same length by resampling the events in the original datset
    /// with replacement. This can be used to perform error analysis via the bootstrap method. (MPI-compatible version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::bootstrap`] instead.
    #[cfg(feature = "mpi")]
    pub fn bootstrap_mpi(&self, seed: usize, world: &SimpleCommunicator) -> Arc<Dataset> {
        let n_events = self.n_events();
        let mut indices: Vec<usize> = vec![0; n_events];
        // The root rank draws the full set of resampled global indices so every rank
        // agrees on the same bootstrap sample, then broadcasts them.
        if world.is_root() {
            let mut rng = fastrand::Rng::with_seed(seed as u64);
            indices = (0..n_events)
                .map(|_| rng.usize(0..n_events))
                .collect::<Vec<usize>>();
            indices.sort();
        }
        world.process_at_root().broadcast_into(&mut indices);
        let (_, displs) = world.get_counts_displs(self.n_events());
        let local_indices: Vec<usize> = indices
            .into_iter()
            .filter_map(|idx| {
                let (owning_rank, local_index) = Dataset::get_rank_index(idx, &displs, world);
                if world.rank() == owning_rank {
                    Some(local_index)
                } else {
                    None
                }
            })
            .collect();
        // `local_indices` only contains indices owned by the current rank, translating them into
        // local indices on the events vector.
        #[cfg(feature = "rayon")]
        let bootstrapped_events: Vec<Arc<Event>> = local_indices
            .into_par_iter()
            .map(|idx| self.events[idx].clone())
            .collect();
        #[cfg(not(feature = "rayon"))]
        let bootstrapped_events: Vec<Arc<Event>> = local_indices
            .into_iter()
            .map(|idx| self.events[idx].clone())
            .collect();
        Arc::new(Dataset {
            events: bootstrapped_events,
        })
    }

    /// Generate a new dataset with the same length by resampling the events in the original datset
    /// with replacement. This can be used to perform error analysis via the bootstrap method.
    pub fn bootstrap(&self, seed: usize) -> Arc<Dataset> {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.bootstrap_mpi(seed, &world);
            }
        }
        self.bootstrap_local(seed)
    }

    /// Filter the [`Dataset`] by a given [`VariableExpression`], selecting events for which
    /// the expression returns `true`.
    pub fn filter(&self, expression: &VariableExpression) -> Arc<Dataset> {
        // Compile once, evaluate per event.
        let compiled = expression.compile();
        #[cfg(feature = "rayon")]
        let filtered_events = self
            .events
            .par_iter()
            .filter(|e| compiled.evaluate(e))
            .cloned()
            .collect();
        #[cfg(not(feature = "rayon"))]
        let filtered_events = self
            .events
            .iter()
            .filter(|e| compiled.evaluate(e))
            .cloned()
            .collect();
        Arc::new(Dataset {
            events: filtered_events,
        })
    }

    /// Bin a [`Dataset`] by the value of the given [`Variable`] into a number of `bins` within the
    /// given `range`.
    ///
    /// Events whose value falls outside `range` are dropped. NOTE(review): assumes
    /// `bins >= 1`; with `bins == 0` the `bins - 1` clamp below would underflow —
    /// confirm callers validate this.
    pub fn bin_by<V>(&self, variable: V, bins: usize, range: (Float, Float)) -> BinnedDataset
    where
        V: Variable,
    {
        let bin_width = (range.1 - range.0) / bins as Float;
        let bin_edges = get_bin_edges(bins, range);
        #[cfg(feature = "rayon")]
        let evaluated: Vec<(usize, &Arc<Event>)> = self
            .events
            .par_iter()
            .filter_map(|event| {
                let value = variable.value(event.as_ref());
                if value >= range.0 && value < range.1 {
                    let bin_index = ((value - range.0) / bin_width) as usize;
                    // Clamp to the last bin to guard against floating-point rounding.
                    let bin_index = bin_index.min(bins - 1);
                    Some((bin_index, event))
                } else {
                    None
                }
            })
            .collect();
        #[cfg(not(feature = "rayon"))]
        let evaluated: Vec<(usize, &Arc<Event>)> = self
            .events
            .iter()
            .filter_map(|event| {
                let value = variable.value(event.as_ref());
                if value >= range.0 && value < range.1 {
                    let bin_index = ((value - range.0) / bin_width) as usize;
                    // Clamp to the last bin to guard against floating-point rounding.
                    let bin_index = bin_index.min(bins - 1);
                    Some((bin_index, event))
                } else {
                    None
                }
            })
            .collect();
        // Scatter the (bin, event) pairs into per-bin event lists.
        let mut binned_events: Vec<Vec<Arc<Event>>> = vec![Vec::default(); bins];
        for (bin_index, event) in evaluated {
            binned_events[bin_index].push(event.clone());
        }
        BinnedDataset {
            #[cfg(feature = "rayon")]
            datasets: binned_events
                .into_par_iter()
                .map(|events| Arc::new(Dataset { events }))
                .collect(),
            #[cfg(not(feature = "rayon"))]
            datasets: binned_events
                .into_iter()
                .map(|events| Arc::new(Dataset { events }))
                .collect(),
            edges: bin_edges,
        }
    }

    /// Boost all the four-momenta in all [`Event`]s to the rest frame of the given set of
    /// four-momenta by indices.
    pub fn boost_to_rest_frame_of<T: AsRef<[usize]> + Sync>(&self, indices: T) -> Arc<Dataset> {
        #[cfg(feature = "rayon")]
        {
            Arc::new(Dataset {
                events: self
                    .events
                    .par_iter()
                    .map(|event| Arc::new(event.boost_to_rest_frame_of(indices.as_ref())))
                    .collect(),
            })
        }
        #[cfg(not(feature = "rayon"))]
        {
            Arc::new(Dataset {
                events: self
                    .events
                    .iter()
                    .map(|event| Arc::new(event.boost_to_rest_frame_of(indices.as_ref())))
                    .collect(),
            })
        }
    }
    /// Evaluate a [`Variable`] on every event in the [`Dataset`].
    pub fn evaluate<V: Variable>(&self, variable: &V) -> Vec<Float> {
        variable.value_on(self)
    }
}
655
656impl_op_ex!(+ |a: &Dataset, b: &Dataset| ->  Dataset { Dataset { events: a.events.iter().chain(b.events.iter()).cloned().collect() }});
657
658fn batch_to_event(batch: &RecordBatch, row: usize) -> Event {
659    let mut p4s = Vec::new();
660    let mut aux = Vec::new();
661
662    let p4_count = batch
663        .schema()
664        .fields()
665        .iter()
666        .filter(|field| field.name().starts_with(P4_PREFIX))
667        .count()
668        / 4;
669    let aux_count = batch
670        .schema()
671        .fields()
672        .iter()
673        .filter(|field| field.name().starts_with(AUX_PREFIX))
674        .count()
675        / 3;
676
677    for i in 0..p4_count {
678        let e = batch
679            .column_by_name(&format!("{}{}_E", P4_PREFIX, i))
680            .unwrap()
681            .as_any()
682            .downcast_ref::<Float32Array>()
683            .unwrap()
684            .value(row) as Float;
685        let px = batch
686            .column_by_name(&format!("{}{}_Px", P4_PREFIX, i))
687            .unwrap()
688            .as_any()
689            .downcast_ref::<Float32Array>()
690            .unwrap()
691            .value(row) as Float;
692        let py = batch
693            .column_by_name(&format!("{}{}_Py", P4_PREFIX, i))
694            .unwrap()
695            .as_any()
696            .downcast_ref::<Float32Array>()
697            .unwrap()
698            .value(row) as Float;
699        let pz = batch
700            .column_by_name(&format!("{}{}_Pz", P4_PREFIX, i))
701            .unwrap()
702            .as_any()
703            .downcast_ref::<Float32Array>()
704            .unwrap()
705            .value(row) as Float;
706        p4s.push(Vec4::new(px, py, pz, e));
707    }
708
709    // TODO: insert empty vectors if not provided
710    for i in 0..aux_count {
711        let x = batch
712            .column_by_name(&format!("{}{}_x", AUX_PREFIX, i))
713            .unwrap()
714            .as_any()
715            .downcast_ref::<Float32Array>()
716            .unwrap()
717            .value(row) as Float;
718        let y = batch
719            .column_by_name(&format!("{}{}_y", AUX_PREFIX, i))
720            .unwrap()
721            .as_any()
722            .downcast_ref::<Float32Array>()
723            .unwrap()
724            .value(row) as Float;
725        let z = batch
726            .column_by_name(&format!("{}{}_z", AUX_PREFIX, i))
727            .unwrap()
728            .as_any()
729            .downcast_ref::<Float32Array>()
730            .unwrap()
731            .value(row) as Float;
732        aux.push(Vec3::new(x, y, z));
733    }
734
735    let weight = batch
736        .column(19)
737        .as_any()
738        .downcast_ref::<Float32Array>()
739        .unwrap()
740        .value(row) as Float;
741
742    Event { p4s, aux, weight }
743}
744
745/// Open a Parquet file and read the data into a [`Dataset`].
746pub fn open<T: AsRef<str>>(file_path: T) -> Result<Arc<Dataset>, LadduError> {
747    // TODO: make this read in directly to MPI ranks
748    let file_path = Path::new(&*shellexpand::full(file_path.as_ref())?).canonicalize()?;
749    let file = File::open(file_path)?;
750    let builder = ParquetRecordBatchReaderBuilder::try_new(file)?;
751    let reader = builder.build()?;
752    let batches: Vec<RecordBatch> = reader.collect::<Result<Vec<_>, _>>()?;
753
754    #[cfg(feature = "rayon")]
755    let events: Vec<Arc<Event>> = batches
756        .into_par_iter()
757        .flat_map(|batch| {
758            let num_rows = batch.num_rows();
759            let mut local_events = Vec::with_capacity(num_rows);
760
761            // Process each row in the batch
762            for row in 0..num_rows {
763                let event = batch_to_event(&batch, row);
764                local_events.push(Arc::new(event));
765            }
766            local_events
767        })
768        .collect();
769    #[cfg(not(feature = "rayon"))]
770    let events: Vec<Arc<Event>> = batches
771        .into_iter()
772        .flat_map(|batch| {
773            let num_rows = batch.num_rows();
774            let mut local_events = Vec::with_capacity(num_rows);
775
776            // Process each row in the batch
777            for row in 0..num_rows {
778                let event = batch_to_event(&batch, row);
779                local_events.push(Arc::new(event));
780            }
781            local_events
782        })
783        .collect();
784    Ok(Arc::new(Dataset::new(events)))
785}
786
787/// Open a Parquet file and read the data into a [`Dataset`]. This method boosts each event to the
788/// rest frame of the four-momenta at the given indices.
789pub fn open_boosted_to_rest_frame_of<T: AsRef<str>, I: AsRef<[usize]> + Sync>(
790    file_path: T,
791    indices: I,
792) -> Result<Arc<Dataset>, LadduError> {
793    // TODO: make this read in directly to MPI ranks
794    let file_path = Path::new(&*shellexpand::full(file_path.as_ref())?).canonicalize()?;
795    let file = File::open(file_path)?;
796    let builder = ParquetRecordBatchReaderBuilder::try_new(file)?;
797    let reader = builder.build()?;
798    let batches: Vec<RecordBatch> = reader.collect::<Result<Vec<_>, _>>()?;
799
800    #[cfg(feature = "rayon")]
801    let events: Vec<Arc<Event>> = batches
802        .into_par_iter()
803        .flat_map(|batch| {
804            let num_rows = batch.num_rows();
805            let mut local_events = Vec::with_capacity(num_rows);
806
807            // Process each row in the batch
808            for row in 0..num_rows {
809                let mut event = batch_to_event(&batch, row);
810                event = event.boost_to_rest_frame_of(indices.as_ref());
811                local_events.push(Arc::new(event));
812            }
813            local_events
814        })
815        .collect();
816    #[cfg(not(feature = "rayon"))]
817    let events: Vec<Arc<Event>> = batches
818        .into_iter()
819        .flat_map(|batch| {
820            let num_rows = batch.num_rows();
821            let mut local_events = Vec::with_capacity(num_rows);
822
823            // Process each row in the batch
824            for row in 0..num_rows {
825                let mut event = batch_to_event(&batch, row);
826                event = event.boost_to_rest_frame_of(indices.as_ref());
827                local_events.push(Arc::new(event));
828            }
829            local_events
830        })
831        .collect();
832    Ok(Arc::new(Dataset::new(events)))
833}
834
/// A list of [`Dataset`]s formed by binning [`Event`]s by some [`Variable`].
pub struct BinnedDataset {
    // One Dataset per bin, in ascending bin order.
    datasets: Vec<Arc<Dataset>>,
    // Bin edges used to form the binning; expected to hold `datasets.len() + 1`
    // entries (see `range`, which indexes `edges[n_bins()]`).
    edges: Vec<Float>,
}
840
841impl Index<usize> for BinnedDataset {
842    type Output = Arc<Dataset>;
843
844    fn index(&self, index: usize) -> &Self::Output {
845        &self.datasets[index]
846    }
847}
848
849impl IndexMut<usize> for BinnedDataset {
850    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
851        &mut self.datasets[index]
852    }
853}
854
impl Deref for BinnedDataset {
    type Target = Vec<Arc<Dataset>>;

    // Lets a BinnedDataset be used wherever a `&Vec<Arc<Dataset>>` is expected
    // (iteration, `len()`, slice methods, ...).
    fn deref(&self) -> &Self::Target {
        &self.datasets
    }
}
862
impl DerefMut for BinnedDataset {
    // Mutable counterpart of `Deref`: exposes the inner Vec's mutating API
    // (push, retain, ...) directly on a BinnedDataset.
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.datasets
    }
}
868
869impl BinnedDataset {
870    /// The number of bins in the [`BinnedDataset`].
871    pub fn n_bins(&self) -> usize {
872        self.datasets.len()
873    }
874
875    /// Returns a list of the bin edges that were used to form the [`BinnedDataset`].
876    pub fn edges(&self) -> Vec<Float> {
877        self.edges.clone()
878    }
879
880    /// Returns the range that was used to form the [`BinnedDataset`].
881    pub fn range(&self) -> (Float, Float) {
882        (self.edges[0], self.edges[self.n_bins()])
883    }
884}
885
#[cfg(test)]
mod tests {
    use crate::Mass;

    use super::*;
    use approx::{assert_relative_eq, assert_relative_ne};
    use serde::{Deserialize, Serialize};
    // The canonical test event should carry four four-momenta, one auxiliary
    // vector, and a weight of 0.48 (matching `test_event`).
    #[test]
    fn test_event_creation() {
        let event = test_event();
        assert_eq!(event.p4s.len(), 4);
        assert_eq!(event.aux.len(), 1);
        assert_relative_eq!(event.weight, 0.48)
    }

    // `get_p4_sum` should add the selected four-momenta component-wise.
    #[test]
    fn test_event_p4_sum() {
        let event = test_event();
        let sum = event.get_p4_sum([2, 3]);
        assert_relative_eq!(sum.px(), event.p4s[2].px() + event.p4s[3].px());
        assert_relative_eq!(sum.py(), event.p4s[2].py() + event.p4s[3].py());
        assert_relative_eq!(sum.pz(), event.p4s[2].pz() + event.p4s[3].pz());
        assert_relative_eq!(sum.e(), event.p4s[2].e() + event.p4s[3].e());
    }

    // Boosting to the rest frame of particles 1-3 should zero their summed
    // three-momentum (within sqrt-epsilon tolerance).
    #[test]
    fn test_event_boost() {
        let event = test_event();
        let event_boosted = event.boost_to_rest_frame_of([1, 2, 3]);
        let p4_sum = event_boosted.get_p4_sum([1, 2, 3]);
        assert_relative_eq!(p4_sum.px(), 0.0, epsilon = Float::EPSILON.sqrt());
        assert_relative_eq!(p4_sum.py(), 0.0, epsilon = Float::EPSILON.sqrt());
        assert_relative_eq!(p4_sum.pz(), 0.0, epsilon = Float::EPSILON.sqrt());
    }

    // Evaluating a `Mass` variable on the event should recover the mass used to
    // construct particle 1 in `test_event` (1.007).
    #[test]
    fn test_event_evaluate() {
        let event = test_event();
        let mass = Mass::new([1]);
        assert_relative_eq!(event.evaluate(&mass), 1.007);
    }

    // `n_events` should track the number of events pushed into the dataset.
    #[test]
    fn test_dataset_size_check() {
        let mut dataset = Dataset::default();
        assert_eq!(dataset.n_events(), 0);
        dataset.events.push(Arc::new(test_event()));
        assert_eq!(dataset.n_events(), 1);
    }

    // Adding two datasets should concatenate their events in order.
    #[test]
    fn test_dataset_sum() {
        let dataset = test_dataset();
        let dataset2 = Dataset::new(vec![Arc::new(Event {
            p4s: test_event().p4s,
            aux: test_event().aux,
            weight: 0.52,
        })]);
        let dataset_sum = &dataset + &dataset2;
        assert_eq!(dataset_sum[0].weight, dataset[0].weight);
        assert_eq!(dataset_sum[1].weight, dataset2[0].weight);
    }

    // `weights` should return per-event weights and `n_events_weighted` their
    // sum (0.48 + 0.52 = 1.0).
    #[test]
    fn test_dataset_weights() {
        let mut dataset = Dataset::default();
        dataset.events.push(Arc::new(test_event()));
        dataset.events.push(Arc::new(Event {
            p4s: test_event().p4s,
            aux: test_event().aux,
            weight: 0.52,
        }));
        let weights = dataset.weights();
        assert_eq!(weights.len(), 2);
        assert_relative_eq!(weights[0], 0.48);
        assert_relative_eq!(weights[1], 0.52);
        assert_relative_eq!(dataset.n_events_weighted(), 1.0);
    }

    // Filtering on 0 < mass < 1 should keep only the event built with mass 0.5.
    #[test]
    fn test_dataset_filtering() {
        let mut dataset = Dataset::default();
        dataset.events.push(Arc::new(Event {
            p4s: vec![Vec3::new(0.0, 0.0, 5.0).with_mass(0.0)],
            aux: vec![],
            weight: 1.0,
        }));
        dataset.events.push(Arc::new(Event {
            p4s: vec![Vec3::new(0.0, 0.0, 5.0).with_mass(0.5)],
            aux: vec![],
            weight: 1.0,
        }));
        dataset.events.push(Arc::new(Event {
            p4s: vec![Vec3::new(0.0, 0.0, 5.0).with_mass(1.1)],
            // HACK: using 1.0 messes with this test because the eventual computation gives a mass
            // slightly less than 1.0
            aux: vec![],
            weight: 1.0,
        }));

        let mass = Mass::new([0]);
        let expression = mass.gt(0.0).and(&mass.lt(1.0));

        let filtered = dataset.filter(&expression);
        assert_eq!(filtered.n_events(), 1);
        assert_relative_eq!(
            mass.value(&filtered[0]),
            0.5,
            epsilon = Float::EPSILON.sqrt()
        );
    }

    // Dataset-level boost should behave like the per-event boost: the summed
    // three-momentum of particles 1-3 vanishes afterwards.
    #[test]
    fn test_dataset_boost() {
        let dataset = test_dataset();
        let dataset_boosted = dataset.boost_to_rest_frame_of([1, 2, 3]);
        let p4_sum = dataset_boosted[0].get_p4_sum([1, 2, 3]);
        assert_relative_eq!(p4_sum.px(), 0.0, epsilon = Float::EPSILON.sqrt());
        assert_relative_eq!(p4_sum.py(), 0.0, epsilon = Float::EPSILON.sqrt());
        assert_relative_eq!(p4_sum.pz(), 0.0, epsilon = Float::EPSILON.sqrt());
    }

    // Dataset-level evaluate should return one value per event.
    #[test]
    fn test_dataset_evaluate() {
        let dataset = test_dataset();
        let mass = Mass::new([1]);
        assert_relative_eq!(dataset.evaluate(&mass)[0], 1.007);
    }

    // Binning two events (beam energies 1.0 and 2.0) into two bins over
    // (0.0, 3.0) should place one event in each bin and preserve weights.
    #[test]
    fn test_binned_dataset() {
        let dataset = Dataset::new(vec![
            Arc::new(Event {
                p4s: vec![Vec3::new(0.0, 0.0, 1.0).with_mass(1.0)],
                aux: vec![],
                weight: 1.0,
            }),
            Arc::new(Event {
                p4s: vec![Vec3::new(0.0, 0.0, 2.0).with_mass(2.0)],
                aux: vec![],
                weight: 2.0,
            }),
        ]);

        // Minimal ad-hoc Variable: the energy of the first four-momentum.
        #[derive(Clone, Serialize, Deserialize, Debug)]
        struct BeamEnergy;
        impl Display for BeamEnergy {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, "BeamEnergy")
            }
        }
        #[typetag::serde]
        impl Variable for BeamEnergy {
            fn value(&self, event: &Event) -> Float {
                event.p4s[0].e()
            }
        }
        assert_eq!(BeamEnergy.to_string(), "BeamEnergy");

        // Test binning by first particle energy
        let binned = dataset.bin_by(BeamEnergy, 2, (0.0, 3.0));

        assert_eq!(binned.n_bins(), 2);
        assert_eq!(binned.edges().len(), 3);
        assert_relative_eq!(binned.edges()[0], 0.0);
        assert_relative_eq!(binned.edges()[2], 3.0);
        assert_eq!(binned[0].n_events(), 1);
        assert_relative_eq!(binned[0].n_events_weighted(), 1.0);
        assert_eq!(binned[1].n_events(), 1);
        assert_relative_eq!(binned[1].n_events_weighted(), 2.0);
    }

    // Bootstrapping with seed 43 should preserve the event count; for this
    // two-event dataset it resamples the same event twice (equal weights),
    // and bootstrapping an empty dataset stays empty.
    #[test]
    fn test_dataset_bootstrap() {
        let mut dataset = test_dataset();
        dataset.events.push(Arc::new(Event {
            p4s: test_event().p4s.clone(),
            aux: test_event().aux.clone(),
            weight: 1.0,
        }));
        assert_relative_ne!(dataset[0].weight, dataset[1].weight);

        let bootstrapped = dataset.bootstrap(43);
        assert_eq!(bootstrapped.n_events(), dataset.n_events());
        assert_relative_eq!(bootstrapped[0].weight, bootstrapped[1].weight);

        // Test empty dataset bootstrap
        let empty_dataset = Dataset::default();
        let empty_bootstrap = empty_dataset.bootstrap(43);
        assert_eq!(empty_bootstrap.n_events(), 0);
    }
    // Pins the exact Display formatting of an event.
    // NOTE(review): the Display output labels the auxiliary vectors "eps:"
    // while the field is named `aux` — confirm the label is intentional.
    #[test]
    fn test_event_display() {
        let event = test_event();
        let display_string = format!("{}", event);
        assert_eq!(
            display_string,
            "Event:\n  p4s:\n    [e = 8.74700; p = (0.00000, 0.00000, 8.74700); m = 0.00000]\n    [e = 1.10334; p = (0.11900, 0.37400, 0.22200); m = 1.00700]\n    [e = 3.13671; p = (-0.11200, 0.29300, 3.08100); m = 0.49800]\n    [e = 5.50925; p = (-0.00700, -0.66700, 5.44600); m = 0.49800]\n  eps:\n    [0.385, 0.022, 0]\n  weight:\n    0.48\n"
        );
    }
}