// laddu_core/data.rs
1use accurate::{sum::Klein, traits::*};
2use arrow::array::Float32Array;
3use arrow::record_batch::RecordBatch;
4use auto_ops::impl_op_ex;
5use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
6use serde::{Deserialize, Serialize};
7use std::ops::{Deref, DerefMut, Index, IndexMut};
8use std::path::Path;
9use std::sync::Arc;
10use std::{fmt::Display, fs::File};
11
12#[cfg(feature = "rayon")]
13use rayon::prelude::*;
14
15#[cfg(feature = "mpi")]
16use mpi::{datatype::PartitionMut, topology::SimpleCommunicator, traits::*};
17
18#[cfg(feature = "mpi")]
19use crate::mpi::LadduMPI;
20
21use crate::utils::get_bin_edges;
22use crate::{
23    utils::{
24        variables::Variable,
25        vectors::{Vec3, Vec4},
26    },
27    Float, LadduError,
28};
29
30const P4_PREFIX: &str = "p4_";
31const AUX_PREFIX: &str = "aux_";
32
33/// An event that can be used to test the implementation of an
34/// [`Amplitude`](crate::amplitudes::Amplitude). This particular event contains the reaction
35/// $`\gamma p \to K_S^0 K_S^0 p`$ with a polarized photon beam.
36pub fn test_event() -> Event {
37    use crate::utils::vectors::*;
38    Event {
39        p4s: vec![
40            Vec3::new(0.0, 0.0, 8.747).with_mass(0.0),         // beam
41            Vec3::new(0.119, 0.374, 0.222).with_mass(1.007),   // "proton"
42            Vec3::new(-0.112, 0.293, 3.081).with_mass(0.498),  // "kaon"
43            Vec3::new(-0.007, -0.667, 5.446).with_mass(0.498), // "kaon"
44        ],
45        aux: vec![Vec3::new(0.385, 0.022, 0.000)],
46        weight: 0.48,
47    }
48}
49
50/// An dataset that can be used to test the implementation of an
51/// [`Amplitude`](crate::amplitudes::Amplitude). This particular dataset contains a singular
52/// [`Event`] generated from [`test_event`].
53pub fn test_dataset() -> Dataset {
54    Dataset::new(vec![Arc::new(test_event())])
55}
56
/// A single event in a [`Dataset`] containing all the relevant particle information.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Event {
    /// A list of four-momenta for each particle.
    pub p4s: Vec<Vec4>,
    /// A list of auxiliary vectors which can be used to store data like particle polarization.
    pub aux: Vec<Vec3>,
    /// The weight given to the event.
    ///
    /// Weighted sums over a [`Dataset`] are accumulated with compensated (Klein) summation
    /// elsewhere in this module, so weights of mixed sign are handled accurately.
    pub weight: Float,
}
67
68impl Display for Event {
69    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70        writeln!(f, "Event:")?;
71        writeln!(f, "  p4s:")?;
72        for p4 in &self.p4s {
73            writeln!(f, "    {}", p4.to_p4_string())?;
74        }
75        writeln!(f, "  eps:")?;
76        for eps_vec in &self.aux {
77            writeln!(f, "    [{}, {}, {}]", eps_vec.x, eps_vec.y, eps_vec.z)?;
78        }
79        writeln!(f, "  weight:")?;
80        writeln!(f, "    {}", self.weight)?;
81        Ok(())
82    }
83}
84
85impl Event {
86    /// Return a four-momentum from the sum of four-momenta at the given indices in the [`Event`].
87    pub fn get_p4_sum<T: AsRef<[usize]>>(&self, indices: T) -> Vec4 {
88        indices.as_ref().iter().map(|i| self.p4s[*i]).sum::<Vec4>()
89    }
90}
91
/// A collection of [`Event`]s.
///
/// When the `mpi` feature is active and an MPI world exists, each rank stores only its local
/// partition of the full event list (see [`Dataset::new`]); aggregate methods such as
/// [`Dataset::n_events`] account for this.
#[derive(Debug, Clone, Default)]
pub struct Dataset {
    /// The [`Event`]s contained in the [`Dataset`]
    pub events: Vec<Arc<Event>>,
}
98
impl Dataset {
    /// Get a reference to the [`Event`] at the given index in the [`Dataset`] (non-MPI
    /// version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just index into a [`Dataset`]
    /// as if it were any other [`Vec`]:
    ///
    /// ```ignore
    /// let ds: Dataset = Dataset::new(events);
    /// let event_0 = ds[0];
    /// ```
    pub fn index_local(&self, index: usize) -> &Event {
        &self.events[index]
    }

    /// Map a global event `index` to `(owning_rank, local_index)` using the per-rank
    /// partition displacements `displs` (one starting offset per rank, in rank order).
    #[cfg(feature = "mpi")]
    fn get_rank_index(index: usize, displs: &[i32], world: &SimpleCommunicator) -> (i32, usize) {
        // The first displacement strictly greater than `index` belongs to the rank *after*
        // the owner, so step back one rank and offset into its partition.
        for (i, &displ) in displs.iter().enumerate() {
            if displ as usize > index {
                return (i as i32 - 1, index - displs[i - 1] as usize);
            }
        }
        // No displacement exceeded `index`: it lies in the last rank's partition.
        (
            world.size() - 1,
            index - displs[world.size() as usize - 1] as usize,
        )
    }

    /// Split `events` into per-rank chunks according to the world's counts/displacements;
    /// element `r` of the returned list is the slice of events owned by rank `r`.
    #[cfg(feature = "mpi")]
    fn partition(events: Vec<Arc<Event>>, world: &SimpleCommunicator) -> Vec<Vec<Arc<Event>>> {
        let (counts, displs) = world.get_counts_displs(events.len());
        counts
            .iter()
            .zip(displs.iter())
            .map(|(&count, &displ)| {
                events
                    .iter()
                    .skip(displ as usize)
                    .take(count as usize)
                    .cloned()
                    .collect()
            })
            .collect()
    }

    /// Get a reference to the [`Event`] at the given index in the [`Dataset`]
    /// (MPI-compatible version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just index into a [`Dataset`]
    /// as if it were any other [`Vec`]:
    ///
    /// ```ignore
    /// let ds: Dataset = Dataset::new(events);
    /// let event_0 = ds[0];
    /// ```
    #[cfg(feature = "mpi")]
    pub fn index_mpi(&self, index: usize, world: &SimpleCommunicator) -> &Event {
        // Find which rank owns the global `index` and where it sits in that rank's partition.
        let (_, displs) = world.get_counts_displs(self.n_events());
        let (owning_rank, local_index) = Dataset::get_rank_index(index, &displs, world);
        let mut serialized_event_buffer_len: usize = 0;
        let mut serialized_event_buffer: Vec<u8> = Vec::default();
        // Only the owning rank serializes the event; the length is broadcast first so the
        // other ranks can size their receive buffers.
        if world.rank() == owning_rank {
            let event = self.index_local(local_index);
            serialized_event_buffer = bincode::serialize(event).unwrap();
            serialized_event_buffer_len = serialized_event_buffer.len();
        }
        world
            .process_at_rank(owning_rank)
            .broadcast_into(&mut serialized_event_buffer_len);
        if world.rank() != owning_rank {
            serialized_event_buffer = vec![0; serialized_event_buffer_len];
        }
        world
            .process_at_rank(owning_rank)
            .broadcast_into(&mut serialized_event_buffer);
        let event: Event = bincode::deserialize(&serialized_event_buffer).unwrap();
        // NOTE(review): leaking the deserialized event gives it a 'static lifetime so a
        // reference can satisfy the `Index` trait, but this leaks one `Event` per call;
        // consider caching or returning an owned `Event` instead — verify intent upstream.
        Box::leak(Box::new(event))
    }
}
184
185impl Index<usize> for Dataset {
186    type Output = Event;
187
188    fn index(&self, index: usize) -> &Self::Output {
189        #[cfg(feature = "mpi")]
190        {
191            if let Some(world) = crate::mpi::get_world() {
192                return self.index_mpi(index, &world);
193            }
194        }
195        self.index_local(index)
196    }
197}
198
impl Dataset {
    /// Create a new [`Dataset`] from a list of [`Event`]s (non-MPI version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::new`] instead.
    pub fn new_local(events: Vec<Arc<Event>>) -> Self {
        Dataset { events }
    }

    /// Create a new [`Dataset`] from a list of [`Event`]s (MPI-compatible version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::new`] instead.
    #[cfg(feature = "mpi")]
    pub fn new_mpi(events: Vec<Arc<Event>>, world: &SimpleCommunicator) -> Self {
        // NOTE(review): this builds the partition for *every* rank and then clones only this
        // rank's slice; partitioning just the local range would avoid the extra work.
        Dataset {
            events: Dataset::partition(events, world)[world.rank() as usize].clone(),
        }
    }

    /// Create a new [`Dataset`] from a list of [`Event`]s.
    ///
    /// This method is prefered for external use because it contains proper MPI construction
    /// methods. Constructing a [`Dataset`] manually is possible, but may cause issues when
    /// interfacing with MPI and should be avoided unless you know what you are doing.
    pub fn new(events: Vec<Arc<Event>>) -> Self {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return Dataset::new_mpi(events, &world);
            }
        }
        Dataset::new_local(events)
    }

    /// The number of [`Event`]s in the [`Dataset`] (non-MPI version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::n_events`] instead.
    pub fn n_events_local(&self) -> usize {
        self.events.len()
    }

    /// The number of [`Event`]s in the [`Dataset`] (MPI-compatible version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::n_events`] instead.
    #[cfg(feature = "mpi")]
    pub fn n_events_mpi(&self, world: &SimpleCommunicator) -> usize {
        // Gather every rank's local count, then total them to get the global event count.
        let mut n_events_partitioned: Vec<usize> = vec![0; world.size() as usize];
        let n_events_local = self.n_events_local();
        world.all_gather_into(&n_events_local, &mut n_events_partitioned);
        n_events_partitioned.iter().sum()
    }

    /// The number of [`Event`]s in the [`Dataset`].
    pub fn n_events(&self) -> usize {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.n_events_mpi(&world);
            }
        }
        self.n_events_local()
    }
}
273
impl Dataset {
    /// Extract a list of weights over each [`Event`] in the [`Dataset`] (non-MPI version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::weights`] instead.
    pub fn weights_local(&self) -> Vec<Float> {
        #[cfg(feature = "rayon")]
        return self.events.par_iter().map(|e| e.weight).collect();
        #[cfg(not(feature = "rayon"))]
        return self.events.iter().map(|e| e.weight).collect();
    }

    /// Extract a list of weights over each [`Event`] in the [`Dataset`] (MPI-compatible version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::weights`] instead.
    #[cfg(feature = "mpi")]
    pub fn weights_mpi(&self, world: &SimpleCommunicator) -> Vec<Float> {
        // Gather each rank's local weights into a single buffer laid out by the same
        // counts/displacements used to partition the dataset.
        let local_weights = self.weights_local();
        let n_events = self.n_events();
        let mut buffer: Vec<Float> = vec![0.0; n_events];
        let (counts, displs) = world.get_counts_displs(n_events);
        {
            let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
            world.all_gather_varcount_into(&local_weights, &mut partitioned_buffer);
        }
        buffer
    }

    /// Extract a list of weights over each [`Event`] in the [`Dataset`].
    pub fn weights(&self) -> Vec<Float> {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.weights_mpi(&world);
            }
        }
        self.weights_local()
    }

    /// Returns the sum of the weights for each [`Event`] in the [`Dataset`] (non-MPI version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::n_events_weighted`] instead.
    pub fn n_events_weighted_local(&self) -> Float {
        // Klein summation compensates for floating-point cancellation when weights vary in
        // magnitude or sign.
        #[cfg(feature = "rayon")]
        return self
            .events
            .par_iter()
            .map(|e| e.weight)
            .parallel_sum_with_accumulator::<Klein<Float>>();
        #[cfg(not(feature = "rayon"))]
        return self.events.iter().map(|e| e.weight).sum();
    }
    /// Returns the sum of the weights for each [`Event`] in the [`Dataset`] (MPI-compatible version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::n_events_weighted`] instead.
    #[cfg(feature = "mpi")]
    pub fn n_events_weighted_mpi(&self, world: &SimpleCommunicator) -> Float {
        // Gather each rank's partial sum, then reduce the per-rank sums to the global total.
        let mut n_events_weighted_partitioned: Vec<Float> = vec![0.0; world.size() as usize];
        let n_events_weighted_local = self.n_events_weighted_local();
        world.all_gather_into(&n_events_weighted_local, &mut n_events_weighted_partitioned);
        #[cfg(feature = "rayon")]
        return n_events_weighted_partitioned
            .into_par_iter()
            .parallel_sum_with_accumulator::<Klein<Float>>();
        #[cfg(not(feature = "rayon"))]
        return n_events_weighted_partitioned.iter().sum();
    }

    /// Returns the sum of the weights for each [`Event`] in the [`Dataset`].
    pub fn n_events_weighted(&self) -> Float {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.n_events_weighted_mpi(&world);
            }
        }
        self.n_events_weighted_local()
    }

    /// Generate a new dataset with the same length by resampling the events in the original datset
    /// with replacement. This can be used to perform error analysis via the bootstrap method. (non-MPI version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::bootstrap`] instead.
    pub fn bootstrap_local(&self, seed: usize) -> Arc<Dataset> {
        // Draw indices with replacement; sorting them makes the resampled dataset's memory
        // access pattern sequential and the result deterministic for a given seed.
        let mut rng = fastrand::Rng::with_seed(seed as u64);
        let mut indices: Vec<usize> = (0..self.n_events())
            .map(|_| rng.usize(0..self.n_events()))
            .collect::<Vec<usize>>();
        indices.sort();
        #[cfg(feature = "rayon")]
        let bootstrapped_events: Vec<Arc<Event>> = indices
            .into_par_iter()
            .map(|idx| self.events[idx].clone())
            .collect();
        #[cfg(not(feature = "rayon"))]
        let bootstrapped_events: Vec<Arc<Event>> = indices
            .into_iter()
            .map(|idx| self.events[idx].clone())
            .collect();
        Arc::new(Dataset {
            events: bootstrapped_events,
        })
    }

    /// Generate a new dataset with the same length by resampling the events in the original datset
    /// with replacement. This can be used to perform error analysis via the bootstrap method. (MPI-compatible version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::bootstrap`] instead.
    #[cfg(feature = "mpi")]
    pub fn bootstrap_mpi(&self, seed: usize, world: &SimpleCommunicator) -> Arc<Dataset> {
        // The root rank draws the resampling indices so every rank sees the same sample,
        // then broadcasts them to the rest of the world.
        let n_events = self.n_events();
        let mut indices: Vec<usize> = vec![0; n_events];
        if world.is_root() {
            let mut rng = fastrand::Rng::with_seed(seed as u64);
            indices = (0..n_events)
                .map(|_| rng.usize(0..n_events))
                .collect::<Vec<usize>>();
            indices.sort();
        }
        world.process_at_root().broadcast_into(&mut indices);
        #[cfg(feature = "rayon")]
        let bootstrapped_events: Vec<Arc<Event>> = indices
            .into_par_iter()
            .map(|idx| self.events[idx].clone())
            .collect();
        #[cfg(not(feature = "rayon"))]
        let bootstrapped_events: Vec<Arc<Event>> = indices
            .into_iter()
            .map(|idx| self.events[idx].clone())
            .collect();
        Arc::new(Dataset {
            events: bootstrapped_events,
        })
    }

    /// Generate a new dataset with the same length by resampling the events in the original datset
    /// with replacement. This can be used to perform error analysis via the bootstrap method.
    pub fn bootstrap(&self, seed: usize) -> Arc<Dataset> {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.bootstrap_mpi(seed, &world);
            }
        }
        self.bootstrap_local(seed)
    }

    /// Filter the [`Dataset`] by a given `predicate`, selecting events for which the predicate
    /// returns `true`.
    pub fn filter<P>(&self, predicate: P) -> Arc<Dataset>
    where
        P: Fn(&Event) -> bool + Send + Sync,
    {
        #[cfg(feature = "rayon")]
        let filtered_events = self
            .events
            .par_iter()
            .filter(|e| predicate(e))
            .cloned()
            .collect();
        #[cfg(not(feature = "rayon"))]
        let filtered_events = self
            .events
            .iter()
            .filter(|e| predicate(e))
            .cloned()
            .collect();
        Arc::new(Dataset {
            events: filtered_events,
        })
    }

    /// Bin a [`Dataset`] by the value of the given [`Variable`] into a number of `bins` within the
    /// given `range`.
    ///
    /// Events whose value falls outside `range` are dropped. Values exactly at the upper edge
    /// are excluded (`value < range.1`).
    ///
    /// # Panics
    ///
    /// NOTE(review): `bins` is assumed to be at least 1; `bins == 0` underflows in
    /// `bin_index.min(bins - 1)` when an event lands in range — confirm callers never pass 0.
    pub fn bin_by<V>(&self, variable: V, bins: usize, range: (Float, Float)) -> BinnedDataset
    where
        V: Variable,
    {
        let bin_width = (range.1 - range.0) / bins as Float;
        let bin_edges = get_bin_edges(bins, range);
        // First pass: evaluate the variable for each event and assign it a bin index,
        // dropping events outside the range.
        #[cfg(feature = "rayon")]
        let evaluated: Vec<(usize, &Arc<Event>)> = self
            .events
            .par_iter()
            .filter_map(|event| {
                let value = variable.value(event.as_ref());
                if value >= range.0 && value < range.1 {
                    let bin_index = ((value - range.0) / bin_width) as usize;
                    // Clamp to the last bin to guard against floating-point edge effects.
                    let bin_index = bin_index.min(bins - 1);
                    Some((bin_index, event))
                } else {
                    None
                }
            })
            .collect();
        #[cfg(not(feature = "rayon"))]
        let evaluated: Vec<(usize, &Arc<Event>)> = self
            .events
            .iter()
            .filter_map(|event| {
                let value = variable.value(event.as_ref());
                if value >= range.0 && value < range.1 {
                    let bin_index = ((value - range.0) / bin_width) as usize;
                    // Clamp to the last bin to guard against floating-point edge effects.
                    let bin_index = bin_index.min(bins - 1);
                    Some((bin_index, event))
                } else {
                    None
                }
            })
            .collect();
        // Second pass: scatter events into their bins (sequential to preserve order).
        let mut binned_events: Vec<Vec<Arc<Event>>> = vec![Vec::default(); bins];
        for (bin_index, event) in evaluated {
            binned_events[bin_index].push(event.clone());
        }
        BinnedDataset {
            #[cfg(feature = "rayon")]
            datasets: binned_events
                .into_par_iter()
                .map(|events| Arc::new(Dataset { events }))
                .collect(),
            #[cfg(not(feature = "rayon"))]
            datasets: binned_events
                .into_iter()
                .map(|events| Arc::new(Dataset { events }))
                .collect(),
            edges: bin_edges,
        }
    }
}
520
// `+` concatenates two datasets: `a`'s events followed by `b`'s (Arc clones, no deep copy).
impl_op_ex!(+ |a: &Dataset, b: &Dataset| ->  Dataset { Dataset { events: a.events.iter().chain(b.events.iter()).cloned().collect() }});
522
523fn batch_to_event(batch: &RecordBatch, row: usize) -> Event {
524    let mut p4s = Vec::new();
525    let mut aux = Vec::new();
526
527    let p4_count = batch
528        .schema()
529        .fields()
530        .iter()
531        .filter(|field| field.name().starts_with(P4_PREFIX))
532        .count()
533        / 4;
534    let aux_count = batch
535        .schema()
536        .fields()
537        .iter()
538        .filter(|field| field.name().starts_with(AUX_PREFIX))
539        .count()
540        / 3;
541
542    for i in 0..p4_count {
543        let e = batch
544            .column_by_name(&format!("{}{}_E", P4_PREFIX, i))
545            .unwrap()
546            .as_any()
547            .downcast_ref::<Float32Array>()
548            .unwrap()
549            .value(row) as Float;
550        let px = batch
551            .column_by_name(&format!("{}{}_Px", P4_PREFIX, i))
552            .unwrap()
553            .as_any()
554            .downcast_ref::<Float32Array>()
555            .unwrap()
556            .value(row) as Float;
557        let py = batch
558            .column_by_name(&format!("{}{}_Py", P4_PREFIX, i))
559            .unwrap()
560            .as_any()
561            .downcast_ref::<Float32Array>()
562            .unwrap()
563            .value(row) as Float;
564        let pz = batch
565            .column_by_name(&format!("{}{}_Pz", P4_PREFIX, i))
566            .unwrap()
567            .as_any()
568            .downcast_ref::<Float32Array>()
569            .unwrap()
570            .value(row) as Float;
571        p4s.push(Vec4::new(px, py, pz, e));
572    }
573
574    // TODO: insert empty vectors if not provided
575    for i in 0..aux_count {
576        let x = batch
577            .column_by_name(&format!("{}{}_x", AUX_PREFIX, i))
578            .unwrap()
579            .as_any()
580            .downcast_ref::<Float32Array>()
581            .unwrap()
582            .value(row) as Float;
583        let y = batch
584            .column_by_name(&format!("{}{}_y", AUX_PREFIX, i))
585            .unwrap()
586            .as_any()
587            .downcast_ref::<Float32Array>()
588            .unwrap()
589            .value(row) as Float;
590        let z = batch
591            .column_by_name(&format!("{}{}_z", AUX_PREFIX, i))
592            .unwrap()
593            .as_any()
594            .downcast_ref::<Float32Array>()
595            .unwrap()
596            .value(row) as Float;
597        aux.push(Vec3::new(x, y, z));
598    }
599
600    let weight = batch
601        .column(19)
602        .as_any()
603        .downcast_ref::<Float32Array>()
604        .unwrap()
605        .value(row) as Float;
606
607    Event { p4s, aux, weight }
608}
609
610/// Open a Parquet file and read the data into a [`Dataset`].
611pub fn open<T: AsRef<str>>(file_path: T) -> Result<Arc<Dataset>, LadduError> {
612    // TODO: make this read in directly to MPI ranks
613    let file_path = Path::new(&*shellexpand::full(file_path.as_ref())?).canonicalize()?;
614    let file = File::open(file_path)?;
615    let builder = ParquetRecordBatchReaderBuilder::try_new(file)?;
616    let reader = builder.build()?;
617    let batches: Vec<RecordBatch> = reader.collect::<Result<Vec<_>, _>>()?;
618
619    #[cfg(feature = "rayon")]
620    let events: Vec<Arc<Event>> = batches
621        .into_par_iter()
622        .flat_map(|batch| {
623            let num_rows = batch.num_rows();
624            let mut local_events = Vec::with_capacity(num_rows);
625
626            // Process each row in the batch
627            for row in 0..num_rows {
628                let event = batch_to_event(&batch, row);
629                local_events.push(Arc::new(event));
630            }
631            local_events
632        })
633        .collect();
634    #[cfg(not(feature = "rayon"))]
635    let events: Vec<Arc<Event>> = batches
636        .into_iter()
637        .flat_map(|batch| {
638            let num_rows = batch.num_rows();
639            let mut local_events = Vec::with_capacity(num_rows);
640
641            // Process each row in the batch
642            for row in 0..num_rows {
643                let event = batch_to_event(&batch, row);
644                local_events.push(Arc::new(event));
645            }
646            local_events
647        })
648        .collect();
649    Ok(Arc::new(Dataset::new(events)))
650}
651
/// A list of [`Dataset`]s formed by binning [`Event`]s by some [`Variable`].
pub struct BinnedDataset {
    /// One [`Dataset`] per bin, in ascending bin order.
    datasets: Vec<Arc<Dataset>>,
    /// The bin edges used to form the bins (one more edge than there are bins).
    edges: Vec<Float>,
}
657
impl Index<usize> for BinnedDataset {
    type Output = Arc<Dataset>;

    /// Returns the [`Dataset`] for bin `index`; panics if `index >= n_bins()`.
    fn index(&self, index: usize) -> &Self::Output {
        &self.datasets[index]
    }
}
665
impl IndexMut<usize> for BinnedDataset {
    /// Returns a mutable reference to the [`Dataset`] for bin `index`; panics if out of bounds.
    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
        &mut self.datasets[index]
    }
}
671
// Deref to the inner `Vec` so slice/`Vec` methods (iter, len, ...) work directly on a
// `BinnedDataset`.
impl Deref for BinnedDataset {
    type Target = Vec<Arc<Dataset>>;

    fn deref(&self) -> &Self::Target {
        &self.datasets
    }
}
679
// Mutable counterpart of the `Deref` impl above.
impl DerefMut for BinnedDataset {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.datasets
    }
}
685
686impl BinnedDataset {
687    /// The number of bins in the [`BinnedDataset`].
688    pub fn n_bins(&self) -> usize {
689        self.datasets.len()
690    }
691
692    /// Returns a list of the bin edges that were used to form the [`BinnedDataset`].
693    pub fn edges(&self) -> Vec<Float> {
694        self.edges.clone()
695    }
696
697    /// Returns the range that was used to form the [`BinnedDataset`].
698    pub fn range(&self) -> (Float, Float) {
699        (self.edges[0], self.edges[self.n_bins()])
700    }
701}
702
#[cfg(test)]
mod tests {
    // These tests exercise only the non-MPI, single-process code paths.
    use super::*;
    use approx::{assert_relative_eq, assert_relative_ne};
    use serde::{Deserialize, Serialize};
    #[test]
    fn test_event_creation() {
        let event = test_event();
        assert_eq!(event.p4s.len(), 4);
        assert_eq!(event.aux.len(), 1);
        assert_relative_eq!(event.weight, 0.48)
    }

    #[test]
    fn test_event_p4_sum() {
        // Summing the two kaons (indices 2 and 3) must add component-wise.
        let event = test_event();
        let sum = event.get_p4_sum([2, 3]);
        assert_relative_eq!(sum.px(), event.p4s[2].px() + event.p4s[3].px());
        assert_relative_eq!(sum.py(), event.p4s[2].py() + event.p4s[3].py());
        assert_relative_eq!(sum.pz(), event.p4s[2].pz() + event.p4s[3].pz());
        assert_relative_eq!(sum.e(), event.p4s[2].e() + event.p4s[3].e());
    }

    #[test]
    fn test_dataset_size_check() {
        let mut dataset = Dataset::default();
        assert_eq!(dataset.n_events(), 0);
        dataset.events.push(Arc::new(test_event()));
        assert_eq!(dataset.n_events(), 1);
    }

    #[test]
    fn test_dataset_sum() {
        // `+` concatenates: dataset's events come first, then dataset2's.
        let dataset = test_dataset();
        let dataset2 = Dataset::new(vec![Arc::new(Event {
            p4s: test_event().p4s,
            aux: test_event().aux,
            weight: 0.52,
        })]);
        let dataset_sum = &dataset + &dataset2;
        assert_eq!(dataset_sum[0].weight, dataset[0].weight);
        assert_eq!(dataset_sum[1].weight, dataset2[0].weight);
    }

    #[test]
    fn test_dataset_weights() {
        let mut dataset = Dataset::default();
        dataset.events.push(Arc::new(test_event()));
        dataset.events.push(Arc::new(Event {
            p4s: test_event().p4s,
            aux: test_event().aux,
            weight: 0.52,
        }));
        let weights = dataset.weights();
        assert_eq!(weights.len(), 2);
        assert_relative_eq!(weights[0], 0.48);
        assert_relative_eq!(weights[1], 0.52);
        // 0.48 + 0.52 = 1.0
        assert_relative_eq!(dataset.n_events_weighted(), 1.0);
    }

    #[test]
    fn test_dataset_filtering() {
        // Only the second (2-particle) event should survive the predicate.
        let mut dataset = test_dataset();
        dataset.events.push(Arc::new(Event {
            p4s: vec![
                Vec3::new(0.0, 0.0, 5.0).with_mass(0.0),
                Vec3::new(0.0, 0.0, 1.0).with_mass(1.0),
            ],
            aux: vec![],
            weight: 1.0,
        }));

        let filtered = dataset.filter(|event| event.p4s.len() == 2);
        assert_eq!(filtered.n_events(), 1);
        assert_eq!(filtered[0].p4s.len(), 2);
    }

    #[test]
    fn test_binned_dataset() {
        let dataset = Dataset::new(vec![
            Arc::new(Event {
                p4s: vec![Vec3::new(0.0, 0.0, 1.0).with_mass(1.0)],
                aux: vec![],
                weight: 1.0,
            }),
            Arc::new(Event {
                p4s: vec![Vec3::new(0.0, 0.0, 2.0).with_mass(2.0)],
                aux: vec![],
                weight: 2.0,
            }),
        ]);

        // Minimal Variable implementation: the first particle's energy.
        #[derive(Clone, Serialize, Deserialize, Debug)]
        struct BeamEnergy;
        impl Display for BeamEnergy {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, "BeamEnergy")
            }
        }
        #[typetag::serde]
        impl Variable for BeamEnergy {
            fn value(&self, event: &Event) -> Float {
                event.p4s[0].e()
            }
        }
        assert_eq!(BeamEnergy.to_string(), "BeamEnergy");

        // Test binning by first particle energy
        let binned = dataset.bin_by(BeamEnergy, 2, (0.0, 3.0));

        assert_eq!(binned.n_bins(), 2);
        assert_eq!(binned.edges().len(), 3);
        assert_relative_eq!(binned.edges()[0], 0.0);
        assert_relative_eq!(binned.edges()[2], 3.0);
        assert_eq!(binned[0].n_events(), 1);
        assert_relative_eq!(binned[0].n_events_weighted(), 1.0);
        assert_eq!(binned[1].n_events(), 1);
        assert_relative_eq!(binned[1].n_events_weighted(), 2.0);
    }

    #[test]
    fn test_dataset_bootstrap() {
        let mut dataset = test_dataset();
        dataset.events.push(Arc::new(Event {
            p4s: test_event().p4s.clone(),
            aux: test_event().aux.clone(),
            weight: 1.0,
        }));
        assert_relative_ne!(dataset[0].weight, dataset[1].weight);

        // With this seed the 2-event resample picks the same event twice.
        let bootstrapped = dataset.bootstrap(43);
        assert_eq!(bootstrapped.n_events(), dataset.n_events());
        assert_relative_eq!(bootstrapped[0].weight, bootstrapped[1].weight);

        // Test empty dataset bootstrap
        let empty_dataset = Dataset::default();
        let empty_bootstrap = empty_dataset.bootstrap(43);
        assert_eq!(empty_bootstrap.n_events(), 0);
    }

    #[test]
    fn test_event_display() {
        let event = test_event();
        let display_string = format!("{}", event);
        assert_eq!(
            display_string,
            "Event:\n  p4s:\n    [e = 8.74700; p = (0.00000, 0.00000, 8.74700); m = 0.00000]\n    [e = 1.10334; p = (0.11900, 0.37400, 0.22200); m = 1.00700]\n    [e = 3.13671; p = (-0.11200, 0.29300, 3.08100); m = 0.49800]\n    [e = 5.50925; p = (-0.00700, -0.66700, 5.44600); m = 0.49800]\n  eps:\n    [0.385, 0.022, 0]\n  weight:\n    0.48\n"
        );
    }
}