laddu_core/
data.rs

1use accurate::{sum::Klein, traits::*};
2use arrow::array::Float32Array;
3use arrow::record_batch::RecordBatch;
4use auto_ops::impl_op_ex;
5use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
6use serde::{Deserialize, Serialize};
7use std::ops::{Deref, DerefMut, Index, IndexMut};
8use std::path::Path;
9use std::sync::Arc;
10use std::{fmt::Display, fs::File};
11
12#[cfg(feature = "rayon")]
13use rayon::prelude::*;
14
15#[cfg(feature = "mpi")]
16use mpi::{datatype::PartitionMut, topology::SimpleCommunicator, traits::*};
17
18#[cfg(feature = "mpi")]
19use crate::mpi::LadduMPI;
20
21use crate::utils::get_bin_edges;
22use crate::{
23    utils::{
24        variables::Variable,
25        vectors::{Vec3, Vec4},
26    },
27    Float, LadduError,
28};
29
30/// An event that can be used to test the implementation of an
31/// [`Amplitude`](crate::amplitudes::Amplitude). This particular event contains the reaction
32/// $`\gamma p \to K_S^0 K_S^0 p`$ with a polarized photon beam.
33pub fn test_event() -> Event {
34    use crate::utils::vectors::*;
35    Event {
36        p4s: vec![
37            Vec3::new(0.0, 0.0, 8.747).with_mass(0.0),         // beam
38            Vec3::new(0.119, 0.374, 0.222).with_mass(1.007),   // "proton"
39            Vec3::new(-0.112, 0.293, 3.081).with_mass(0.498),  // "kaon"
40            Vec3::new(-0.007, -0.667, 5.446).with_mass(0.498), // "kaon"
41        ],
42        eps: vec![Vec3::new(0.385, 0.022, 0.000)],
43        weight: 0.48,
44    }
45}
46
47/// An dataset that can be used to test the implementation of an
48/// [`Amplitude`](crate::amplitudes::Amplitude). This particular dataset contains a singular
49/// [`Event`] generated from [`test_event`].
50pub fn test_dataset() -> Dataset {
51    Dataset::new(vec![Arc::new(test_event())])
52}
53
54/// A single event in a [`Dataset`] containing all the relevant particle information.
55#[derive(Debug, Clone, Default, Serialize, Deserialize)]
56pub struct Event {
57    /// A list of four-momenta for each particle.
58    pub p4s: Vec<Vec4>,
59    /// A list of polarization vectors for each particle.
60    pub eps: Vec<Vec3>,
61    /// The weight given to the event.
62    pub weight: Float,
63}
64
65impl Display for Event {
66    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67        writeln!(f, "Event:")?;
68        writeln!(f, "  p4s:")?;
69        for p4 in &self.p4s {
70            writeln!(f, "    {}", p4.to_p4_string())?;
71        }
72        writeln!(f, "  eps:")?;
73        for eps_vec in &self.eps {
74            writeln!(f, "    [{}, {}, {}]", eps_vec.x, eps_vec.y, eps_vec.z)?;
75        }
76        writeln!(f, "  weight:")?;
77        writeln!(f, "    {}", self.weight)?;
78        Ok(())
79    }
80}
81
82impl Event {
83    /// Return a four-momentum from the sum of four-momenta at the given indices in the [`Event`].
84    pub fn get_p4_sum<T: AsRef<[usize]>>(&self, indices: T) -> Vec4 {
85        indices.as_ref().iter().map(|i| self.p4s[*i]).sum::<Vec4>()
86    }
87}
88
89/// A collection of [`Event`]s.
90#[derive(Debug, Clone, Default)]
91pub struct Dataset {
92    /// The [`Event`]s contained in the [`Dataset`]
93    pub events: Vec<Arc<Event>>,
94}
95
96impl Dataset {
97    /// Get a reference to the [`Event`] at the given index in the [`Dataset`] (non-MPI
98    /// version).
99    ///
100    /// # Notes
101    ///
102    /// This method is not intended to be called in analyses but rather in writing methods
103    /// that have `mpi`-feature-gated versions. Most users should just index into a [`Dataset`]
104    /// as if it were any other [`Vec`]:
105    ///
106    /// ```ignore
107    /// let ds: Dataset = Dataset::new(events);
108    /// let event_0 = ds[0];
109    /// ```
110    pub fn index_local(&self, index: usize) -> &Event {
111        &self.events[index]
112    }
113
114    #[cfg(feature = "mpi")]
115    fn get_rank_index(index: usize, displs: &[i32], world: &SimpleCommunicator) -> (i32, usize) {
116        for (i, &displ) in displs.iter().enumerate() {
117            if displ as usize > index {
118                return (i as i32 - 1, index - displs[i - 1] as usize);
119            }
120        }
121        (
122            world.size() - 1,
123            index - displs[world.size() as usize - 1] as usize,
124        )
125    }
126
127    #[cfg(feature = "mpi")]
128    fn partition(events: Vec<Arc<Event>>, world: &SimpleCommunicator) -> Vec<Vec<Arc<Event>>> {
129        let (counts, displs) = world.get_counts_displs(events.len());
130        counts
131            .iter()
132            .zip(displs.iter())
133            .map(|(&count, &displ)| {
134                events
135                    .iter()
136                    .skip(displ as usize)
137                    .take(count as usize)
138                    .cloned()
139                    .collect()
140            })
141            .collect()
142    }
143
144    /// Get a reference to the [`Event`] at the given index in the [`Dataset`]
145    /// (MPI-compatible version).
146    ///
147    /// # Notes
148    ///
149    /// This method is not intended to be called in analyses but rather in writing methods
150    /// that have `mpi`-feature-gated versions. Most users should just index into a [`Dataset`]
151    /// as if it were any other [`Vec`]:
152    ///
153    /// ```ignore
154    /// let ds: Dataset = Dataset::new(events);
155    /// let event_0 = ds[0];
156    /// ```
157    #[cfg(feature = "mpi")]
158    pub fn index_mpi(&self, index: usize, world: &SimpleCommunicator) -> &Event {
159        let (_, displs) = world.get_counts_displs(self.n_events());
160        let (owning_rank, local_index) = Dataset::get_rank_index(index, &displs, world);
161        let mut serialized_event_buffer_len: usize = 0;
162        let mut serialized_event_buffer: Vec<u8> = Vec::default();
163        if world.rank() == owning_rank {
164            let event = self.index_local(local_index);
165            serialized_event_buffer = bincode::serialize(event).unwrap();
166            serialized_event_buffer_len = serialized_event_buffer.len();
167        }
168        world
169            .process_at_rank(owning_rank)
170            .broadcast_into(&mut serialized_event_buffer_len);
171        if world.rank() != owning_rank {
172            serialized_event_buffer = vec![0; serialized_event_buffer_len];
173        }
174        world
175            .process_at_rank(owning_rank)
176            .broadcast_into(&mut serialized_event_buffer);
177        let event: Event = bincode::deserialize(&serialized_event_buffer).unwrap();
178        Box::leak(Box::new(event))
179    }
180}
181
182impl Index<usize> for Dataset {
183    type Output = Event;
184
185    fn index(&self, index: usize) -> &Self::Output {
186        #[cfg(feature = "mpi")]
187        {
188            if let Some(world) = crate::mpi::get_world() {
189                return self.index_mpi(index, &world);
190            }
191        }
192        self.index_local(index)
193    }
194}
195
196impl Dataset {
197    /// Create a new [`Dataset`] from a list of [`Event`]s (non-MPI version).
198    ///
199    /// # Notes
200    ///
201    /// This method is not intended to be called in analyses but rather in writing methods
202    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::new`] instead.
203    pub fn new_local(events: Vec<Arc<Event>>) -> Self {
204        Dataset { events }
205    }
206
207    /// Create a new [`Dataset`] from a list of [`Event`]s (MPI-compatible version).
208    ///
209    /// # Notes
210    ///
211    /// This method is not intended to be called in analyses but rather in writing methods
212    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::new`] instead.
213    #[cfg(feature = "mpi")]
214    pub fn new_mpi(events: Vec<Arc<Event>>, world: &SimpleCommunicator) -> Self {
215        Dataset {
216            events: Dataset::partition(events, world)[world.rank() as usize].clone(),
217        }
218    }
219
220    /// Create a new [`Dataset`] from a list of [`Event`]s.
221    ///
222    /// This method is prefered for external use because it contains proper MPI construction
223    /// methods. Constructing a [`Dataset`] manually is possible, but may cause issues when
224    /// interfacing with MPI and should be avoided unless you know what you are doing.
225    pub fn new(events: Vec<Arc<Event>>) -> Self {
226        #[cfg(feature = "mpi")]
227        {
228            if let Some(world) = crate::mpi::get_world() {
229                return Dataset::new_mpi(events, &world);
230            }
231        }
232        Dataset::new_local(events)
233    }
234
235    /// The number of [`Event`]s in the [`Dataset`] (non-MPI version).
236    ///
237    /// # Notes
238    ///
239    /// This method is not intended to be called in analyses but rather in writing methods
240    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::n_events`] instead.
241    pub fn n_events_local(&self) -> usize {
242        self.events.len()
243    }
244
245    /// The number of [`Event`]s in the [`Dataset`] (MPI-compatible version).
246    ///
247    /// # Notes
248    ///
249    /// This method is not intended to be called in analyses but rather in writing methods
250    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::n_events`] instead.
251    #[cfg(feature = "mpi")]
252    pub fn n_events_mpi(&self, world: &SimpleCommunicator) -> usize {
253        let mut n_events_partitioned: Vec<usize> = vec![0; world.size() as usize];
254        let n_events_local = self.n_events_local();
255        world.all_gather_into(&n_events_local, &mut n_events_partitioned);
256        n_events_partitioned.iter().sum()
257    }
258
259    /// The number of [`Event`]s in the [`Dataset`].
260    pub fn n_events(&self) -> usize {
261        #[cfg(feature = "mpi")]
262        {
263            if let Some(world) = crate::mpi::get_world() {
264                return self.n_events_mpi(&world);
265            }
266        }
267        self.n_events_local()
268    }
269}
270
271impl Dataset {
272    /// Extract a list of weights over each [`Event`] in the [`Dataset`] (non-MPI version).
273    ///
274    /// # Notes
275    ///
276    /// This method is not intended to be called in analyses but rather in writing methods
277    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::weights`] instead.
278    pub fn weights_local(&self) -> Vec<Float> {
279        #[cfg(feature = "rayon")]
280        return self.events.par_iter().map(|e| e.weight).collect();
281        #[cfg(not(feature = "rayon"))]
282        return self.events.iter().map(|e| e.weight).collect();
283    }
284
285    /// Extract a list of weights over each [`Event`] in the [`Dataset`] (MPI-compatible version).
286    ///
287    /// # Notes
288    ///
289    /// This method is not intended to be called in analyses but rather in writing methods
290    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::weights`] instead.
291    #[cfg(feature = "mpi")]
292    pub fn weights_mpi(&self, world: &SimpleCommunicator) -> Vec<Float> {
293        let local_weights = self.weights_local();
294        let n_events = self.n_events();
295        let mut buffer: Vec<Float> = vec![0.0; n_events];
296        let (counts, displs) = world.get_counts_displs(n_events);
297        {
298            let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
299            world.all_gather_varcount_into(&local_weights, &mut partitioned_buffer);
300        }
301        buffer
302    }
303
304    /// Extract a list of weights over each [`Event`] in the [`Dataset`].
305    pub fn weights(&self) -> Vec<Float> {
306        #[cfg(feature = "mpi")]
307        {
308            if let Some(world) = crate::mpi::get_world() {
309                return self.weights_mpi(&world);
310            }
311        }
312        self.weights_local()
313    }
314
315    /// Returns the sum of the weights for each [`Event`] in the [`Dataset`] (non-MPI version).
316    ///
317    /// # Notes
318    ///
319    /// This method is not intended to be called in analyses but rather in writing methods
320    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::n_events_weighted`] instead.
321    pub fn n_events_weighted_local(&self) -> Float {
322        #[cfg(feature = "rayon")]
323        return self
324            .events
325            .par_iter()
326            .map(|e| e.weight)
327            .parallel_sum_with_accumulator::<Klein<Float>>();
328        #[cfg(not(feature = "rayon"))]
329        return self.events.iter().map(|e| e.weight).sum();
330    }
331    /// Returns the sum of the weights for each [`Event`] in the [`Dataset`] (MPI-compatible version).
332    ///
333    /// # Notes
334    ///
335    /// This method is not intended to be called in analyses but rather in writing methods
336    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::n_events_weighted`] instead.
337    #[cfg(feature = "mpi")]
338    pub fn n_events_weighted_mpi(&self, world: &SimpleCommunicator) -> Float {
339        let mut n_events_weighted_partitioned: Vec<Float> = vec![0.0; world.size() as usize];
340        let n_events_weighted_local = self.n_events_weighted_local();
341        world.all_gather_into(&n_events_weighted_local, &mut n_events_weighted_partitioned);
342        #[cfg(feature = "rayon")]
343        return n_events_weighted_partitioned
344            .into_par_iter()
345            .parallel_sum_with_accumulator::<Klein<Float>>();
346        #[cfg(not(feature = "rayon"))]
347        return n_events_weighted_partitioned.iter().sum();
348    }
349
350    /// Returns the sum of the weights for each [`Event`] in the [`Dataset`].
351    pub fn n_events_weighted(&self) -> Float {
352        #[cfg(feature = "mpi")]
353        {
354            if let Some(world) = crate::mpi::get_world() {
355                return self.n_events_weighted_mpi(&world);
356            }
357        }
358        self.n_events_weighted_local()
359    }
360
361    /// Generate a new dataset with the same length by resampling the events in the original datset
362    /// with replacement. This can be used to perform error analysis via the bootstrap method. (non-MPI version).
363    ///
364    /// # Notes
365    ///
366    /// This method is not intended to be called in analyses but rather in writing methods
367    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::bootstrap`] instead.
368    pub fn bootstrap_local(&self, seed: usize) -> Arc<Dataset> {
369        let mut rng = fastrand::Rng::with_seed(seed as u64);
370        let mut indices: Vec<usize> = (0..self.n_events())
371            .map(|_| rng.usize(0..self.n_events()))
372            .collect::<Vec<usize>>();
373        indices.sort();
374        #[cfg(feature = "rayon")]
375        let bootstrapped_events: Vec<Arc<Event>> = indices
376            .into_par_iter()
377            .map(|idx| self.events[idx].clone())
378            .collect();
379        #[cfg(not(feature = "rayon"))]
380        let bootstrapped_events: Vec<Arc<Event>> = indices
381            .into_iter()
382            .map(|idx| self.events[idx].clone())
383            .collect();
384        Arc::new(Dataset {
385            events: bootstrapped_events,
386        })
387    }
388
389    /// Generate a new dataset with the same length by resampling the events in the original datset
390    /// with replacement. This can be used to perform error analysis via the bootstrap method. (MPI-compatible version).
391    ///
392    /// # Notes
393    ///
394    /// This method is not intended to be called in analyses but rather in writing methods
395    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::bootstrap`] instead.
396    #[cfg(feature = "mpi")]
397    pub fn bootstrap_mpi(&self, seed: usize, world: &SimpleCommunicator) -> Arc<Dataset> {
398        let n_events = self.n_events();
399        let mut indices: Vec<usize> = vec![0; n_events];
400        if world.is_root() {
401            let mut rng = fastrand::Rng::with_seed(seed as u64);
402            indices = (0..n_events)
403                .map(|_| rng.usize(0..n_events))
404                .collect::<Vec<usize>>();
405            indices.sort();
406        }
407        world.process_at_root().broadcast_into(&mut indices);
408        #[cfg(feature = "rayon")]
409        let bootstrapped_events: Vec<Arc<Event>> = indices
410            .into_par_iter()
411            .map(|idx| self.events[idx].clone())
412            .collect();
413        #[cfg(not(feature = "rayon"))]
414        let bootstrapped_events: Vec<Arc<Event>> = indices
415            .into_iter()
416            .map(|idx| self.events[idx].clone())
417            .collect();
418        Arc::new(Dataset {
419            events: bootstrapped_events,
420        })
421    }
422
423    /// Generate a new dataset with the same length by resampling the events in the original datset
424    /// with replacement. This can be used to perform error analysis via the bootstrap method.
425    pub fn bootstrap(&self, seed: usize) -> Arc<Dataset> {
426        #[cfg(feature = "mpi")]
427        {
428            if let Some(world) = crate::mpi::get_world() {
429                return self.bootstrap_mpi(seed, &world);
430            }
431        }
432        self.bootstrap_local(seed)
433    }
434
435    /// Filter the [`Dataset`] by a given `predicate`, selecting events for which the predicate
436    /// returns `true`.
437    pub fn filter<P>(&self, predicate: P) -> Arc<Dataset>
438    where
439        P: Fn(&Event) -> bool + Send + Sync,
440    {
441        #[cfg(feature = "rayon")]
442        let filtered_events = self
443            .events
444            .par_iter()
445            .filter(|e| predicate(e))
446            .cloned()
447            .collect();
448        #[cfg(not(feature = "rayon"))]
449        let filtered_events = self
450            .events
451            .iter()
452            .filter(|e| predicate(e))
453            .cloned()
454            .collect();
455        Arc::new(Dataset {
456            events: filtered_events,
457        })
458    }
459
460    /// Bin a [`Dataset`] by the value of the given [`Variable`] into a number of `bins` within the
461    /// given `range`.
462    pub fn bin_by<V>(&self, variable: V, bins: usize, range: (Float, Float)) -> BinnedDataset
463    where
464        V: Variable,
465    {
466        let bin_width = (range.1 - range.0) / bins as Float;
467        let bin_edges = get_bin_edges(bins, range);
468        #[cfg(feature = "rayon")]
469        let evaluated: Vec<(usize, &Arc<Event>)> = self
470            .events
471            .par_iter()
472            .filter_map(|event| {
473                let value = variable.value(event.as_ref());
474                if value >= range.0 && value < range.1 {
475                    let bin_index = ((value - range.0) / bin_width) as usize;
476                    let bin_index = bin_index.min(bins - 1);
477                    Some((bin_index, event))
478                } else {
479                    None
480                }
481            })
482            .collect();
483        #[cfg(not(feature = "rayon"))]
484        let evaluated: Vec<(usize, &Arc<Event>)> = self
485            .events
486            .iter()
487            .filter_map(|event| {
488                let value = variable.value(event.as_ref());
489                if value >= range.0 && value < range.1 {
490                    let bin_index = ((value - range.0) / bin_width) as usize;
491                    let bin_index = bin_index.min(bins - 1);
492                    Some((bin_index, event))
493                } else {
494                    None
495                }
496            })
497            .collect();
498        let mut binned_events: Vec<Vec<Arc<Event>>> = vec![Vec::default(); bins];
499        for (bin_index, event) in evaluated {
500            binned_events[bin_index].push(event.clone());
501        }
502        BinnedDataset {
503            #[cfg(feature = "rayon")]
504            datasets: binned_events
505                .into_par_iter()
506                .map(|events| Arc::new(Dataset { events }))
507                .collect(),
508            #[cfg(not(feature = "rayon"))]
509            datasets: binned_events
510                .into_iter()
511                .map(|events| Arc::new(Dataset { events }))
512                .collect(),
513            edges: bin_edges,
514        }
515    }
516}
517
518impl_op_ex!(+ |a: &Dataset, b: &Dataset| ->  Dataset { Dataset { events: a.events.iter().chain(b.events.iter()).cloned().collect() }});
519
520fn batch_to_event(batch: &RecordBatch, row: usize) -> Event {
521    let mut p4s = Vec::new();
522    let mut eps = Vec::new();
523
524    let p4_count = batch
525        .schema()
526        .fields()
527        .iter()
528        .filter(|field| field.name().starts_with("p4_"))
529        .count()
530        / 4;
531    let eps_count = batch
532        .schema()
533        .fields()
534        .iter()
535        .filter(|field| field.name().starts_with("eps_"))
536        .count()
537        / 3;
538
539    for i in 0..p4_count {
540        let e = batch
541            .column_by_name(&format!("p4_{}_E", i))
542            .unwrap()
543            .as_any()
544            .downcast_ref::<Float32Array>()
545            .unwrap()
546            .value(row) as Float;
547        let px = batch
548            .column_by_name(&format!("p4_{}_Px", i))
549            .unwrap()
550            .as_any()
551            .downcast_ref::<Float32Array>()
552            .unwrap()
553            .value(row) as Float;
554        let py = batch
555            .column_by_name(&format!("p4_{}_Py", i))
556            .unwrap()
557            .as_any()
558            .downcast_ref::<Float32Array>()
559            .unwrap()
560            .value(row) as Float;
561        let pz = batch
562            .column_by_name(&format!("p4_{}_Pz", i))
563            .unwrap()
564            .as_any()
565            .downcast_ref::<Float32Array>()
566            .unwrap()
567            .value(row) as Float;
568        p4s.push(Vec4::new(px, py, pz, e));
569    }
570
571    // TODO: insert empty vectors if not provided
572    for i in 0..eps_count {
573        let x = batch
574            .column_by_name(&format!("eps_{}_x", i))
575            .unwrap()
576            .as_any()
577            .downcast_ref::<Float32Array>()
578            .unwrap()
579            .value(row) as Float;
580        let y = batch
581            .column_by_name(&format!("eps_{}_y", i))
582            .unwrap()
583            .as_any()
584            .downcast_ref::<Float32Array>()
585            .unwrap()
586            .value(row) as Float;
587        let z = batch
588            .column_by_name(&format!("eps_{}_z", i))
589            .unwrap()
590            .as_any()
591            .downcast_ref::<Float32Array>()
592            .unwrap()
593            .value(row) as Float;
594        eps.push(Vec3::new(x, y, z));
595    }
596
597    let weight = batch
598        .column(19)
599        .as_any()
600        .downcast_ref::<Float32Array>()
601        .unwrap()
602        .value(row) as Float;
603
604    Event { p4s, eps, weight }
605}
606
607/// Open a Parquet file and read the data into a [`Dataset`].
608pub fn open<T: AsRef<str>>(file_path: T) -> Result<Arc<Dataset>, LadduError> {
609    // TODO: make this read in directly to MPI ranks
610    let file_path = Path::new(&*shellexpand::full(file_path.as_ref())?).canonicalize()?;
611    let file = File::open(file_path)?;
612    let builder = ParquetRecordBatchReaderBuilder::try_new(file)?;
613    let reader = builder.build()?;
614    let batches: Vec<RecordBatch> = reader.collect::<Result<Vec<_>, _>>()?;
615
616    #[cfg(feature = "rayon")]
617    let events: Vec<Arc<Event>> = batches
618        .into_par_iter()
619        .flat_map(|batch| {
620            let num_rows = batch.num_rows();
621            let mut local_events = Vec::with_capacity(num_rows);
622
623            // Process each row in the batch
624            for row in 0..num_rows {
625                let event = batch_to_event(&batch, row);
626                local_events.push(Arc::new(event));
627            }
628            local_events
629        })
630        .collect();
631    #[cfg(not(feature = "rayon"))]
632    let events: Vec<Arc<Event>> = batches
633        .into_iter()
634        .flat_map(|batch| {
635            let num_rows = batch.num_rows();
636            let mut local_events = Vec::with_capacity(num_rows);
637
638            // Process each row in the batch
639            for row in 0..num_rows {
640                let event = batch_to_event(&batch, row);
641                local_events.push(Arc::new(event));
642            }
643            local_events
644        })
645        .collect();
646    Ok(Arc::new(Dataset::new(events)))
647}
648
649/// A list of [`Dataset`]s formed by binning [`Event`]s by some [`Variable`].
650pub struct BinnedDataset {
651    datasets: Vec<Arc<Dataset>>,
652    edges: Vec<Float>,
653}
654
655impl Index<usize> for BinnedDataset {
656    type Output = Arc<Dataset>;
657
658    fn index(&self, index: usize) -> &Self::Output {
659        &self.datasets[index]
660    }
661}
662
663impl IndexMut<usize> for BinnedDataset {
664    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
665        &mut self.datasets[index]
666    }
667}
668
669impl Deref for BinnedDataset {
670    type Target = Vec<Arc<Dataset>>;
671
672    fn deref(&self) -> &Self::Target {
673        &self.datasets
674    }
675}
676
677impl DerefMut for BinnedDataset {
678    fn deref_mut(&mut self) -> &mut Self::Target {
679        &mut self.datasets
680    }
681}
682
683impl BinnedDataset {
684    /// The number of bins in the [`BinnedDataset`].
685    pub fn n_bins(&self) -> usize {
686        self.datasets.len()
687    }
688
689    /// Returns a list of the bin edges that were used to form the [`BinnedDataset`].
690    pub fn edges(&self) -> Vec<Float> {
691        self.edges.clone()
692    }
693
694    /// Returns the range that was used to form the [`BinnedDataset`].
695    pub fn range(&self) -> (Float, Float) {
696        (self.edges[0], self.edges[self.n_bins()])
697    }
698}
699
#[cfg(test)]
mod tests {
    use super::*;
    use approx::{assert_relative_eq, assert_relative_ne};
    use serde::{Deserialize, Serialize};

    /// A copy of [`test_event`] that differs only in its weight.
    fn reweighted_test_event(weight: Float) -> Event {
        Event {
            weight,
            ..test_event()
        }
    }

    #[test]
    fn test_event_creation() {
        let Event { p4s, eps, weight } = test_event();
        assert_eq!(p4s.len(), 4);
        assert_eq!(eps.len(), 1);
        assert_relative_eq!(weight, 0.48)
    }

    #[test]
    fn test_event_p4_sum() {
        let event = test_event();
        let (kaon_a, kaon_b) = (event.p4s[2], event.p4s[3]);
        let sum = event.get_p4_sum([2, 3]);
        assert_relative_eq!(sum.px(), kaon_a.px() + kaon_b.px());
        assert_relative_eq!(sum.py(), kaon_a.py() + kaon_b.py());
        assert_relative_eq!(sum.pz(), kaon_a.pz() + kaon_b.pz());
        assert_relative_eq!(sum.e(), kaon_a.e() + kaon_b.e());
    }

    #[test]
    fn test_dataset_size_check() {
        let mut dataset = Dataset::default();
        assert_eq!(dataset.n_events(), 0);
        dataset.events.push(Arc::new(test_event()));
        assert_eq!(dataset.n_events(), 1);
    }

    #[test]
    fn test_dataset_sum() {
        let left = test_dataset();
        let right = Dataset::new(vec![Arc::new(reweighted_test_event(0.52))]);
        let combined = &left + &right;
        assert_eq!(combined[0].weight, left[0].weight);
        assert_eq!(combined[1].weight, right[0].weight);
    }

    #[test]
    fn test_dataset_weights() {
        let dataset = Dataset {
            events: vec![
                Arc::new(test_event()),
                Arc::new(reweighted_test_event(0.52)),
            ],
        };
        let weights = dataset.weights();
        assert_eq!(weights.len(), 2);
        assert_relative_eq!(weights[0], 0.48);
        assert_relative_eq!(weights[1], 0.52);
        assert_relative_eq!(dataset.n_events_weighted(), 1.0);
    }

    #[test]
    fn test_dataset_filtering() {
        let two_particle_event = Event {
            p4s: vec![
                Vec3::new(0.0, 0.0, 5.0).with_mass(0.0),
                Vec3::new(0.0, 0.0, 1.0).with_mass(1.0),
            ],
            eps: Vec::new(),
            weight: 1.0,
        };
        let mut dataset = test_dataset();
        dataset.events.push(Arc::new(two_particle_event));

        let filtered = dataset.filter(|event| event.p4s.len() == 2);
        assert_eq!(filtered.n_events(), 1);
        assert_eq!(filtered[0].p4s.len(), 2);
    }

    #[test]
    fn test_binned_dataset() {
        #[derive(Clone, Serialize, Deserialize)]
        struct BeamEnergy;
        #[typetag::serde]
        impl Variable for BeamEnergy {
            fn value(&self, event: &Event) -> Float {
                event.p4s[0].e()
            }
        }

        // Single-particle event whose momentum, mass, and weight all equal `scale`.
        let single_particle = |scale: Float| {
            Arc::new(Event {
                p4s: vec![Vec3::new(0.0, 0.0, scale).with_mass(scale)],
                eps: vec![],
                weight: scale,
            })
        };
        let dataset = Dataset::new(vec![single_particle(1.0), single_particle(2.0)]);

        // Test binning by first particle energy
        let binned = dataset.bin_by(BeamEnergy, 2, (0.0, 3.0));

        assert_eq!(binned.n_bins(), 2);
        let edges = binned.edges();
        assert_eq!(edges.len(), 3);
        assert_relative_eq!(edges[0], 0.0);
        assert_relative_eq!(edges[2], 3.0);
        let expected: [(usize, Float); 2] = [(0, 1.0), (1, 2.0)];
        for (bin, weight) in expected {
            assert_eq!(binned[bin].n_events(), 1);
            assert_relative_eq!(binned[bin].n_events_weighted(), weight);
        }
    }

    #[test]
    fn test_dataset_bootstrap() {
        let mut dataset = test_dataset();
        dataset.events.push(Arc::new(reweighted_test_event(1.0)));
        assert_relative_ne!(dataset[0].weight, dataset[1].weight);

        let bootstrapped = dataset.bootstrap(43);
        assert_eq!(bootstrapped.n_events(), dataset.n_events());
        assert_relative_eq!(bootstrapped[0].weight, bootstrapped[1].weight);

        // Bootstrapping an empty dataset must give back an empty dataset
        let empty_bootstrap = Dataset::default().bootstrap(43);
        assert_eq!(empty_bootstrap.n_events(), 0);
    }

    #[test]
    fn test_event_display() {
        let rendered = format!("{}", test_event());
        assert_eq!(
            rendered,
            "Event:\n  p4s:\n    [e = 8.74700; p = (0.00000, 0.00000, 8.74700); m = 0.00000]\n    [e = 1.10334; p = (0.11900, 0.37400, 0.22200); m = 1.00700]\n    [e = 3.13671; p = (-0.11200, 0.29300, 3.08100); m = 0.49800]\n    [e = 5.50925; p = (-0.00700, -0.66700, 5.44600); m = 0.49800]\n  eps:\n    [0.385, 0.022, 0]\n  weight:\n    0.48\n"
        );
    }
}