laddu_core/
data.rs

1use accurate::{sum::Klein, traits::*};
2use arrow::array::Float32Array;
3use arrow::record_batch::RecordBatch;
4use auto_ops::impl_op_ex;
5use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
6use serde::{Deserialize, Serialize};
7use std::ops::{Deref, DerefMut, Index, IndexMut};
8use std::path::Path;
9use std::sync::Arc;
10use std::{fmt::Display, fs::File};
11
12#[cfg(feature = "rayon")]
13use rayon::prelude::*;
14
15#[cfg(feature = "mpi")]
16use mpi::{datatype::PartitionMut, topology::SimpleCommunicator, traits::*};
17
18#[cfg(feature = "mpi")]
19use crate::mpi::LadduMPI;
20
21use crate::utils::get_bin_edges;
22use crate::{
23    utils::{
24        variables::Variable,
25        vectors::{Vec3, Vec4},
26    },
27    Float, LadduError,
28};
29
/// Column-name prefix for four-momentum components in Parquet files (e.g. `p4_0_E`, `p4_0_Px`).
const P4_PREFIX: &str = "p4_";
/// Column-name prefix for auxiliary-vector components in Parquet files (e.g. `aux_0_x`).
const AUX_PREFIX: &str = "aux_";
32
33/// An event that can be used to test the implementation of an
34/// [`Amplitude`](crate::amplitudes::Amplitude). This particular event contains the reaction
35/// $`\gamma p \to K_S^0 K_S^0 p`$ with a polarized photon beam.
36pub fn test_event() -> Event {
37    use crate::utils::vectors::*;
38    Event {
39        p4s: vec![
40            Vec3::new(0.0, 0.0, 8.747).with_mass(0.0),         // beam
41            Vec3::new(0.119, 0.374, 0.222).with_mass(1.007),   // "proton"
42            Vec3::new(-0.112, 0.293, 3.081).with_mass(0.498),  // "kaon"
43            Vec3::new(-0.007, -0.667, 5.446).with_mass(0.498), // "kaon"
44        ],
45        aux: vec![Vec3::new(0.385, 0.022, 0.000)],
46        weight: 0.48,
47    }
48}
49
50/// An dataset that can be used to test the implementation of an
51/// [`Amplitude`](crate::amplitudes::Amplitude). This particular dataset contains a singular
52/// [`Event`] generated from [`test_event`].
53pub fn test_dataset() -> Dataset {
54    Dataset::new(vec![Arc::new(test_event())])
55}
56
/// A single event in a [`Dataset`] containing all the relevant particle information.
///
/// Events are (de)serializable via `serde`, which is used to ship them between
/// MPI ranks (see [`Dataset::index_mpi`] when the `mpi` feature is enabled).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Event {
    /// A list of four-momenta for each particle.
    pub p4s: Vec<Vec4>,
    /// A list of auxiliary vectors which can be used to store data like particle polarization.
    pub aux: Vec<Vec3>,
    /// The weight given to the event.
    pub weight: Float,
}
67
68impl Display for Event {
69    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70        writeln!(f, "Event:")?;
71        writeln!(f, "  p4s:")?;
72        for p4 in &self.p4s {
73            writeln!(f, "    {}", p4.to_p4_string())?;
74        }
75        writeln!(f, "  eps:")?;
76        for eps_vec in &self.aux {
77            writeln!(f, "    [{}, {}, {}]", eps_vec.x, eps_vec.y, eps_vec.z)?;
78        }
79        writeln!(f, "  weight:")?;
80        writeln!(f, "    {}", self.weight)?;
81        Ok(())
82    }
83}
84
85impl Event {
86    /// Return a four-momentum from the sum of four-momenta at the given indices in the [`Event`].
87    pub fn get_p4_sum<T: AsRef<[usize]>>(&self, indices: T) -> Vec4 {
88        indices.as_ref().iter().map(|i| self.p4s[*i]).sum::<Vec4>()
89    }
90}
91
/// A collection of [`Event`]s.
///
/// Events are stored behind [`Arc`]s so derived datasets (filters, bins, bootstraps)
/// share them without deep copies. When the `mpi` feature is active, each rank
/// holds only its own partition of the full event list (see [`Dataset::new_mpi`]).
#[derive(Debug, Clone, Default)]
pub struct Dataset {
    /// The [`Event`]s contained in the [`Dataset`]
    pub events: Vec<Arc<Event>>,
}
98
impl Dataset {
    /// Get a reference to the [`Event`] at the given index in the [`Dataset`] (non-MPI
    /// version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just index into a [`Dataset`]
    /// as if it were any other [`Vec`]:
    ///
    /// ```ignore
    /// let ds: Dataset = Dataset::new(events);
    /// let event_0 = ds[0];
    /// ```
    pub fn index_local(&self, index: usize) -> &Event {
        &self.events[index]
    }

    /// Map a global event `index` to `(owning_rank, local_index)` using the partition
    /// displacements `displs` (one starting offset per rank, in ascending order).
    #[cfg(feature = "mpi")]
    fn get_rank_index(index: usize, displs: &[i32], world: &SimpleCommunicator) -> (i32, usize) {
        // The first displacement strictly greater than `index` belongs to the rank
        // *after* the owner, so return the previous rank and the offset within it.
        for (i, &displ) in displs.iter().enumerate() {
            if displ as usize > index {
                return (i as i32 - 1, index - displs[i - 1] as usize);
            }
        }
        // No displacement exceeded `index`: the last rank owns it.
        (
            world.size() - 1,
            index - displs[world.size() as usize - 1] as usize,
        )
    }

    /// Split `events` into `world.size()` contiguous chunks according to the
    /// counts/displacements computed for this communicator. Only `Arc` handles are
    /// cloned, not the underlying [`Event`]s.
    #[cfg(feature = "mpi")]
    fn partition(events: Vec<Arc<Event>>, world: &SimpleCommunicator) -> Vec<Vec<Arc<Event>>> {
        let (counts, displs) = world.get_counts_displs(events.len());
        counts
            .iter()
            .zip(displs.iter())
            .map(|(&count, &displ)| {
                events
                    .iter()
                    .skip(displ as usize)
                    .take(count as usize)
                    .cloned()
                    .collect()
            })
            .collect()
    }

    /// Get a reference to the [`Event`] at the given index in the [`Dataset`]
    /// (MPI-compatible version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just index into a [`Dataset`]
    /// as if it were any other [`Vec`]:
    ///
    /// ```ignore
    /// let ds: Dataset = Dataset::new(events);
    /// let event_0 = ds[0];
    /// ```
    #[cfg(feature = "mpi")]
    pub fn index_mpi(&self, index: usize, world: &SimpleCommunicator) -> &Event {
        // Find which rank owns the requested global index.
        let (_, displs) = world.get_counts_displs(self.n_events());
        let (owning_rank, local_index) = Dataset::get_rank_index(index, &displs, world);
        let mut serialized_event_buffer_len: usize = 0;
        let mut serialized_event_buffer: Vec<u8> = Vec::default();
        let config = bincode::config::standard();
        // Only the owner serializes the event; the buffer size is broadcast first so
        // the other ranks can allocate a receive buffer of the right length.
        if world.rank() == owning_rank {
            let event = self.index_local(local_index);
            serialized_event_buffer = bincode::serde::encode_to_vec(event, config).unwrap();
            serialized_event_buffer_len = serialized_event_buffer.len();
        }
        world
            .process_at_rank(owning_rank)
            .broadcast_into(&mut serialized_event_buffer_len);
        if world.rank() != owning_rank {
            serialized_event_buffer = vec![0; serialized_event_buffer_len];
        }
        // Broadcast the serialized bytes, then deserialize on every rank.
        world
            .process_at_rank(owning_rank)
            .broadcast_into(&mut serialized_event_buffer);
        let (event, _): (Event, usize) =
            bincode::serde::decode_from_slice(&serialized_event_buffer[..], config).unwrap();
        // NOTE(review): `Box::leak` deliberately leaks the deserialized event to satisfy
        // the `&Event` return type required by `Index`. Every call leaks one `Event`, so
        // repeated indexing under MPI accumulates memory — acceptable for occasional
        // access, but worth revisiting (e.g. returning an owned `Event` instead).
        Box::leak(Box::new(event))
    }
}
186
187impl Index<usize> for Dataset {
188    type Output = Event;
189
190    fn index(&self, index: usize) -> &Self::Output {
191        #[cfg(feature = "mpi")]
192        {
193            if let Some(world) = crate::mpi::get_world() {
194                return self.index_mpi(index, &world);
195            }
196        }
197        self.index_local(index)
198    }
199}
200
impl Dataset {
    /// Create a new [`Dataset`] from a list of [`Event`]s (non-MPI version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::new`] instead.
    pub fn new_local(events: Vec<Arc<Event>>) -> Self {
        Dataset { events }
    }

    /// Create a new [`Dataset`] from a list of [`Event`]s (MPI-compatible version).
    ///
    /// Each rank keeps only its own contiguous partition of `events`.
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::new`] instead.
    #[cfg(feature = "mpi")]
    pub fn new_mpi(events: Vec<Arc<Event>>, world: &SimpleCommunicator) -> Self {
        Dataset {
            // Partition the full list across all ranks, then keep this rank's chunk.
            events: Dataset::partition(events, world)[world.rank() as usize].clone(),
        }
    }

    /// Create a new [`Dataset`] from a list of [`Event`]s.
    ///
    /// This method is preferred for external use because it contains proper MPI construction
    /// methods. Constructing a [`Dataset`] manually is possible, but may cause issues when
    /// interfacing with MPI and should be avoided unless you know what you are doing.
    pub fn new(events: Vec<Arc<Event>>) -> Self {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return Dataset::new_mpi(events, &world);
            }
        }
        Dataset::new_local(events)
    }

    /// The number of [`Event`]s in the [`Dataset`] (non-MPI version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::n_events`] instead.
    pub fn n_events_local(&self) -> usize {
        self.events.len()
    }

    /// The number of [`Event`]s in the [`Dataset`] (MPI-compatible version).
    ///
    /// Returns the *global* event count: every rank contributes its local count via
    /// an all-gather, and the counts are summed.
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::n_events`] instead.
    #[cfg(feature = "mpi")]
    pub fn n_events_mpi(&self, world: &SimpleCommunicator) -> usize {
        let mut n_events_partitioned: Vec<usize> = vec![0; world.size() as usize];
        let n_events_local = self.n_events_local();
        world.all_gather_into(&n_events_local, &mut n_events_partitioned);
        n_events_partitioned.iter().sum()
    }

    /// The number of [`Event`]s in the [`Dataset`].
    pub fn n_events(&self) -> usize {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.n_events_mpi(&world);
            }
        }
        self.n_events_local()
    }
}
275
impl Dataset {
    /// Extract a list of weights over each [`Event`] in the [`Dataset`] (non-MPI version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::weights`] instead.
    pub fn weights_local(&self) -> Vec<Float> {
        #[cfg(feature = "rayon")]
        return self.events.par_iter().map(|e| e.weight).collect();
        #[cfg(not(feature = "rayon"))]
        return self.events.iter().map(|e| e.weight).collect();
    }

    /// Extract a list of weights over each [`Event`] in the [`Dataset`] (MPI-compatible version).
    ///
    /// Gathers every rank's local weights into a single, globally-ordered buffer that is
    /// returned on all ranks.
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::weights`] instead.
    #[cfg(feature = "mpi")]
    pub fn weights_mpi(&self, world: &SimpleCommunicator) -> Vec<Float> {
        let local_weights = self.weights_local();
        // `n_events()` is the global count here, so the buffer holds every rank's weights.
        let n_events = self.n_events();
        let mut buffer: Vec<Float> = vec![0.0; n_events];
        let (counts, displs) = world.get_counts_displs(n_events);
        {
            // Scoped so the partition borrow of `buffer` ends before we return it.
            let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
            world.all_gather_varcount_into(&local_weights, &mut partitioned_buffer);
        }
        buffer
    }

    /// Extract a list of weights over each [`Event`] in the [`Dataset`].
    pub fn weights(&self) -> Vec<Float> {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.weights_mpi(&world);
            }
        }
        self.weights_local()
    }

    /// Returns the sum of the weights for each [`Event`] in the [`Dataset`] (non-MPI version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::n_events_weighted`] instead.
    pub fn n_events_weighted_local(&self) -> Float {
        // The rayon path uses a compensated (Klein) accumulator to reduce
        // floating-point summation error on large datasets.
        #[cfg(feature = "rayon")]
        return self
            .events
            .par_iter()
            .map(|e| e.weight)
            .parallel_sum_with_accumulator::<Klein<Float>>();
        #[cfg(not(feature = "rayon"))]
        return self.events.iter().map(|e| e.weight).sum();
    }
    /// Returns the sum of the weights for each [`Event`] in the [`Dataset`] (MPI-compatible version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::n_events_weighted`] instead.
    #[cfg(feature = "mpi")]
    pub fn n_events_weighted_mpi(&self, world: &SimpleCommunicator) -> Float {
        // Gather one partial sum per rank, then reduce them locally.
        let mut n_events_weighted_partitioned: Vec<Float> = vec![0.0; world.size() as usize];
        let n_events_weighted_local = self.n_events_weighted_local();
        world.all_gather_into(&n_events_weighted_local, &mut n_events_weighted_partitioned);
        #[cfg(feature = "rayon")]
        return n_events_weighted_partitioned
            .into_par_iter()
            .parallel_sum_with_accumulator::<Klein<Float>>();
        #[cfg(not(feature = "rayon"))]
        return n_events_weighted_partitioned.iter().sum();
    }

    /// Returns the sum of the weights for each [`Event`] in the [`Dataset`].
    pub fn n_events_weighted(&self) -> Float {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.n_events_weighted_mpi(&world);
            }
        }
        self.n_events_weighted_local()
    }

    /// Generate a new dataset with the same length by resampling the events in the original dataset
    /// with replacement. This can be used to perform error analysis via the bootstrap method. (non-MPI version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::bootstrap`] instead.
    pub fn bootstrap_local(&self, seed: usize) -> Arc<Dataset> {
        let mut rng = fastrand::Rng::with_seed(seed as u64);
        // Draw n indices uniformly with replacement; sorting keeps the resampled
        // events in dataset order (repeated picks end up adjacent).
        let mut indices: Vec<usize> = (0..self.n_events())
            .map(|_| rng.usize(0..self.n_events()))
            .collect::<Vec<usize>>();
        indices.sort();
        #[cfg(feature = "rayon")]
        let bootstrapped_events: Vec<Arc<Event>> = indices
            .into_par_iter()
            .map(|idx| self.events[idx].clone())
            .collect();
        #[cfg(not(feature = "rayon"))]
        let bootstrapped_events: Vec<Arc<Event>> = indices
            .into_iter()
            .map(|idx| self.events[idx].clone())
            .collect();
        Arc::new(Dataset {
            events: bootstrapped_events,
        })
    }

    /// Generate a new dataset with the same length by resampling the events in the original dataset
    /// with replacement. This can be used to perform error analysis via the bootstrap method. (MPI-compatible version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::bootstrap`] instead.
    #[cfg(feature = "mpi")]
    pub fn bootstrap_mpi(&self, seed: usize, world: &SimpleCommunicator) -> Arc<Dataset> {
        let n_events = self.n_events();
        let mut indices: Vec<usize> = vec![0; n_events];
        // The root rank draws the resample indices so every rank uses the same sample;
        // the indices are then broadcast to all ranks.
        if world.is_root() {
            let mut rng = fastrand::Rng::with_seed(seed as u64);
            indices = (0..n_events)
                .map(|_| rng.usize(0..n_events))
                .collect::<Vec<usize>>();
            indices.sort();
        }
        world.process_at_root().broadcast_into(&mut indices);
        #[cfg(feature = "rayon")]
        let bootstrapped_events: Vec<Arc<Event>> = indices
            .into_par_iter()
            .map(|idx| self.events[idx].clone())
            .collect();
        #[cfg(not(feature = "rayon"))]
        let bootstrapped_events: Vec<Arc<Event>> = indices
            .into_iter()
            .map(|idx| self.events[idx].clone())
            .collect();
        Arc::new(Dataset {
            events: bootstrapped_events,
        })
    }

    /// Generate a new dataset with the same length by resampling the events in the original dataset
    /// with replacement. This can be used to perform error analysis via the bootstrap method.
    pub fn bootstrap(&self, seed: usize) -> Arc<Dataset> {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.bootstrap_mpi(seed, &world);
            }
        }
        self.bootstrap_local(seed)
    }

    /// Filter the [`Dataset`] by a given `predicate`, selecting events for which the predicate
    /// returns `true`.
    pub fn filter<P>(&self, predicate: P) -> Arc<Dataset>
    where
        P: Fn(&Event) -> bool + Send + Sync,
    {
        #[cfg(feature = "rayon")]
        let filtered_events = self
            .events
            .par_iter()
            .filter(|e| predicate(e))
            .cloned()
            .collect();
        #[cfg(not(feature = "rayon"))]
        let filtered_events = self
            .events
            .iter()
            .filter(|e| predicate(e))
            .cloned()
            .collect();
        Arc::new(Dataset {
            events: filtered_events,
        })
    }

    /// Bin a [`Dataset`] by the value of the given [`Variable`] into a number of `bins` within the
    /// given `range`.
    ///
    /// Events whose value falls outside `range` are dropped. The upper edge is exclusive,
    /// except that a value exactly at the top of the last bin is clamped into it.
    ///
    /// NOTE(review): `bins` is assumed to be nonzero — `bins - 1` underflows for
    /// `bins == 0`; confirm callers never pass zero.
    pub fn bin_by<V>(&self, variable: V, bins: usize, range: (Float, Float)) -> BinnedDataset
    where
        V: Variable,
    {
        let bin_width = (range.1 - range.0) / bins as Float;
        let bin_edges = get_bin_edges(bins, range);
        // First pass: evaluate the variable per event and assign a bin index.
        #[cfg(feature = "rayon")]
        let evaluated: Vec<(usize, &Arc<Event>)> = self
            .events
            .par_iter()
            .filter_map(|event| {
                let value = variable.value(event.as_ref());
                if value >= range.0 && value < range.1 {
                    let bin_index = ((value - range.0) / bin_width) as usize;
                    // Clamp in case floating-point rounding lands on `bins`.
                    let bin_index = bin_index.min(bins - 1);
                    Some((bin_index, event))
                } else {
                    None
                }
            })
            .collect();
        #[cfg(not(feature = "rayon"))]
        let evaluated: Vec<(usize, &Arc<Event>)> = self
            .events
            .iter()
            .filter_map(|event| {
                let value = variable.value(event.as_ref());
                if value >= range.0 && value < range.1 {
                    let bin_index = ((value - range.0) / bin_width) as usize;
                    // Clamp in case floating-point rounding lands on `bins`.
                    let bin_index = bin_index.min(bins - 1);
                    Some((bin_index, event))
                } else {
                    None
                }
            })
            .collect();
        // Second pass: group the events into their bins (sequential to preserve order).
        let mut binned_events: Vec<Vec<Arc<Event>>> = vec![Vec::default(); bins];
        for (bin_index, event) in evaluated {
            binned_events[bin_index].push(event.clone());
        }
        BinnedDataset {
            #[cfg(feature = "rayon")]
            datasets: binned_events
                .into_par_iter()
                .map(|events| Arc::new(Dataset { events }))
                .collect(),
            #[cfg(not(feature = "rayon"))]
            datasets: binned_events
                .into_iter()
                .map(|events| Arc::new(Dataset { events }))
                .collect(),
            edges: bin_edges,
        }
    }
}
522
// Concatenation: `&a + &b` yields a new `Dataset` whose events are a's followed by
// b's (cloning `Arc` handles, not the underlying `Event`s).
impl_op_ex!(+ |a: &Dataset, b: &Dataset| ->  Dataset { Dataset { events: a.events.iter().chain(b.events.iter()).cloned().collect() }});
524
525fn batch_to_event(batch: &RecordBatch, row: usize) -> Event {
526    let mut p4s = Vec::new();
527    let mut aux = Vec::new();
528
529    let p4_count = batch
530        .schema()
531        .fields()
532        .iter()
533        .filter(|field| field.name().starts_with(P4_PREFIX))
534        .count()
535        / 4;
536    let aux_count = batch
537        .schema()
538        .fields()
539        .iter()
540        .filter(|field| field.name().starts_with(AUX_PREFIX))
541        .count()
542        / 3;
543
544    for i in 0..p4_count {
545        let e = batch
546            .column_by_name(&format!("{}{}_E", P4_PREFIX, i))
547            .unwrap()
548            .as_any()
549            .downcast_ref::<Float32Array>()
550            .unwrap()
551            .value(row) as Float;
552        let px = batch
553            .column_by_name(&format!("{}{}_Px", P4_PREFIX, i))
554            .unwrap()
555            .as_any()
556            .downcast_ref::<Float32Array>()
557            .unwrap()
558            .value(row) as Float;
559        let py = batch
560            .column_by_name(&format!("{}{}_Py", P4_PREFIX, i))
561            .unwrap()
562            .as_any()
563            .downcast_ref::<Float32Array>()
564            .unwrap()
565            .value(row) as Float;
566        let pz = batch
567            .column_by_name(&format!("{}{}_Pz", P4_PREFIX, i))
568            .unwrap()
569            .as_any()
570            .downcast_ref::<Float32Array>()
571            .unwrap()
572            .value(row) as Float;
573        p4s.push(Vec4::new(px, py, pz, e));
574    }
575
576    // TODO: insert empty vectors if not provided
577    for i in 0..aux_count {
578        let x = batch
579            .column_by_name(&format!("{}{}_x", AUX_PREFIX, i))
580            .unwrap()
581            .as_any()
582            .downcast_ref::<Float32Array>()
583            .unwrap()
584            .value(row) as Float;
585        let y = batch
586            .column_by_name(&format!("{}{}_y", AUX_PREFIX, i))
587            .unwrap()
588            .as_any()
589            .downcast_ref::<Float32Array>()
590            .unwrap()
591            .value(row) as Float;
592        let z = batch
593            .column_by_name(&format!("{}{}_z", AUX_PREFIX, i))
594            .unwrap()
595            .as_any()
596            .downcast_ref::<Float32Array>()
597            .unwrap()
598            .value(row) as Float;
599        aux.push(Vec3::new(x, y, z));
600    }
601
602    let weight = batch
603        .column(19)
604        .as_any()
605        .downcast_ref::<Float32Array>()
606        .unwrap()
607        .value(row) as Float;
608
609    Event { p4s, aux, weight }
610}
611
612/// Open a Parquet file and read the data into a [`Dataset`].
613pub fn open<T: AsRef<str>>(file_path: T) -> Result<Arc<Dataset>, LadduError> {
614    // TODO: make this read in directly to MPI ranks
615    let file_path = Path::new(&*shellexpand::full(file_path.as_ref())?).canonicalize()?;
616    let file = File::open(file_path)?;
617    let builder = ParquetRecordBatchReaderBuilder::try_new(file)?;
618    let reader = builder.build()?;
619    let batches: Vec<RecordBatch> = reader.collect::<Result<Vec<_>, _>>()?;
620
621    #[cfg(feature = "rayon")]
622    let events: Vec<Arc<Event>> = batches
623        .into_par_iter()
624        .flat_map(|batch| {
625            let num_rows = batch.num_rows();
626            let mut local_events = Vec::with_capacity(num_rows);
627
628            // Process each row in the batch
629            for row in 0..num_rows {
630                let event = batch_to_event(&batch, row);
631                local_events.push(Arc::new(event));
632            }
633            local_events
634        })
635        .collect();
636    #[cfg(not(feature = "rayon"))]
637    let events: Vec<Arc<Event>> = batches
638        .into_iter()
639        .flat_map(|batch| {
640            let num_rows = batch.num_rows();
641            let mut local_events = Vec::with_capacity(num_rows);
642
643            // Process each row in the batch
644            for row in 0..num_rows {
645                let event = batch_to_event(&batch, row);
646                local_events.push(Arc::new(event));
647            }
648            local_events
649        })
650        .collect();
651    Ok(Arc::new(Dataset::new(events)))
652}
653
/// A list of [`Dataset`]s formed by binning [`Event`]s by some [`Variable`].
pub struct BinnedDataset {
    // One dataset per bin, in ascending bin order.
    datasets: Vec<Arc<Dataset>>,
    // The bin edges used to form the bins (always one more edge than bins).
    edges: Vec<Float>,
}
659
// Index by bin number to get the dataset for that bin.
impl Index<usize> for BinnedDataset {
    type Output = Arc<Dataset>;

    fn index(&self, index: usize) -> &Self::Output {
        &self.datasets[index]
    }
}

impl IndexMut<usize> for BinnedDataset {
    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
        &mut self.datasets[index]
    }
}

// Deref to the inner `Vec` so `BinnedDataset` can be used like a slice of bins
// (iteration, `len()`, etc.).
impl Deref for BinnedDataset {
    type Target = Vec<Arc<Dataset>>;

    fn deref(&self) -> &Self::Target {
        &self.datasets
    }
}

impl DerefMut for BinnedDataset {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.datasets
    }
}
687
688impl BinnedDataset {
689    /// The number of bins in the [`BinnedDataset`].
690    pub fn n_bins(&self) -> usize {
691        self.datasets.len()
692    }
693
694    /// Returns a list of the bin edges that were used to form the [`BinnedDataset`].
695    pub fn edges(&self) -> Vec<Float> {
696        self.edges.clone()
697    }
698
699    /// Returns the range that was used to form the [`BinnedDataset`].
700    pub fn range(&self) -> (Float, Float) {
701        (self.edges[0], self.edges[self.n_bins()])
702    }
703}
704
#[cfg(test)]
mod tests {
    use super::*;
    use approx::{assert_relative_eq, assert_relative_ne};
    use serde::{Deserialize, Serialize};
    // Sanity-check the canonical test event: 4 four-momenta, 1 aux vector, weight 0.48.
    #[test]
    fn test_event_creation() {
        let event = test_event();
        assert_eq!(event.p4s.len(), 4);
        assert_eq!(event.aux.len(), 1);
        assert_relative_eq!(event.weight, 0.48)
    }

    // `get_p4_sum` should add components of the selected four-momenta elementwise.
    #[test]
    fn test_event_p4_sum() {
        let event = test_event();
        let sum = event.get_p4_sum([2, 3]);
        assert_relative_eq!(sum.px(), event.p4s[2].px() + event.p4s[3].px());
        assert_relative_eq!(sum.py(), event.p4s[2].py() + event.p4s[3].py());
        assert_relative_eq!(sum.pz(), event.p4s[2].pz() + event.p4s[3].pz());
        assert_relative_eq!(sum.e(), event.p4s[2].e() + event.p4s[3].e());
    }

    #[test]
    fn test_dataset_size_check() {
        let mut dataset = Dataset::default();
        assert_eq!(dataset.n_events(), 0);
        dataset.events.push(Arc::new(test_event()));
        assert_eq!(dataset.n_events(), 1);
    }

    // `&a + &b` concatenates datasets in order (see the `impl_op_ex!` above).
    #[test]
    fn test_dataset_sum() {
        let dataset = test_dataset();
        let dataset2 = Dataset::new(vec![Arc::new(Event {
            p4s: test_event().p4s,
            aux: test_event().aux,
            weight: 0.52,
        })]);
        let dataset_sum = &dataset + &dataset2;
        assert_eq!(dataset_sum[0].weight, dataset[0].weight);
        assert_eq!(dataset_sum[1].weight, dataset2[0].weight);
    }

    // Weights are returned in event order; 0.48 + 0.52 sums to 1.0.
    #[test]
    fn test_dataset_weights() {
        let mut dataset = Dataset::default();
        dataset.events.push(Arc::new(test_event()));
        dataset.events.push(Arc::new(Event {
            p4s: test_event().p4s,
            aux: test_event().aux,
            weight: 0.52,
        }));
        let weights = dataset.weights();
        assert_eq!(weights.len(), 2);
        assert_relative_eq!(weights[0], 0.48);
        assert_relative_eq!(weights[1], 0.52);
        assert_relative_eq!(dataset.n_events_weighted(), 1.0);
    }

    // `filter` keeps only events matching the predicate (here: 2-particle events).
    #[test]
    fn test_dataset_filtering() {
        let mut dataset = test_dataset();
        dataset.events.push(Arc::new(Event {
            p4s: vec![
                Vec3::new(0.0, 0.0, 5.0).with_mass(0.0),
                Vec3::new(0.0, 0.0, 1.0).with_mass(1.0),
            ],
            aux: vec![],
            weight: 1.0,
        }));

        let filtered = dataset.filter(|event| event.p4s.len() == 2);
        assert_eq!(filtered.n_events(), 1);
        assert_eq!(filtered[0].p4s.len(), 2);
    }

    // Bin two events by beam energy (1.0 and 2.0) into two bins over (0, 3).
    #[test]
    fn test_binned_dataset() {
        let dataset = Dataset::new(vec![
            Arc::new(Event {
                p4s: vec![Vec3::new(0.0, 0.0, 1.0).with_mass(1.0)],
                aux: vec![],
                weight: 1.0,
            }),
            Arc::new(Event {
                p4s: vec![Vec3::new(0.0, 0.0, 2.0).with_mass(2.0)],
                aux: vec![],
                weight: 2.0,
            }),
        ]);

        // Minimal `Variable` implementation used only for this test.
        #[derive(Clone, Serialize, Deserialize, Debug)]
        struct BeamEnergy;
        impl Display for BeamEnergy {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, "BeamEnergy")
            }
        }
        #[typetag::serde]
        impl Variable for BeamEnergy {
            fn value(&self, event: &Event) -> Float {
                event.p4s[0].e()
            }
        }
        assert_eq!(BeamEnergy.to_string(), "BeamEnergy");

        // Test binning by first particle energy
        let binned = dataset.bin_by(BeamEnergy, 2, (0.0, 3.0));

        assert_eq!(binned.n_bins(), 2);
        assert_eq!(binned.edges().len(), 3);
        assert_relative_eq!(binned.edges()[0], 0.0);
        assert_relative_eq!(binned.edges()[2], 3.0);
        assert_eq!(binned[0].n_events(), 1);
        assert_relative_eq!(binned[0].n_events_weighted(), 1.0);
        assert_eq!(binned[1].n_events(), 1);
        assert_relative_eq!(binned[1].n_events_weighted(), 2.0);
    }

    // With seed 43 on a 2-event dataset, both resampled events come from the same
    // original index, so their weights match even though the originals differ.
    #[test]
    fn test_dataset_bootstrap() {
        let mut dataset = test_dataset();
        dataset.events.push(Arc::new(Event {
            p4s: test_event().p4s.clone(),
            aux: test_event().aux.clone(),
            weight: 1.0,
        }));
        assert_relative_ne!(dataset[0].weight, dataset[1].weight);

        let bootstrapped = dataset.bootstrap(43);
        assert_eq!(bootstrapped.n_events(), dataset.n_events());
        assert_relative_eq!(bootstrapped[0].weight, bootstrapped[1].weight);

        // Test empty dataset bootstrap
        let empty_dataset = Dataset::default();
        let empty_bootstrap = empty_dataset.bootstrap(43);
        assert_eq!(empty_bootstrap.n_events(), 0);
    }

    // Pins the exact `Display` output, including the historical `eps:` heading.
    #[test]
    fn test_event_display() {
        let event = test_event();
        let display_string = format!("{}", event);
        assert_eq!(
            display_string,
            "Event:\n  p4s:\n    [e = 8.74700; p = (0.00000, 0.00000, 8.74700); m = 0.00000]\n    [e = 1.10334; p = (0.11900, 0.37400, 0.22200); m = 1.00700]\n    [e = 3.13671; p = (-0.11200, 0.29300, 3.08100); m = 0.49800]\n    [e = 5.50925; p = (-0.00700, -0.66700, 5.44600); m = 0.49800]\n  eps:\n    [0.385, 0.022, 0]\n  weight:\n    0.48\n"
        );
    }
}