laddu_core/data.rs

use accurate::{sum::Klein, traits::*};
use arrow::{
    array::{Float32Array, Float64Array},
    datatypes::{DataType, Field, Schema},
    record_batch::RecordBatch,
};
use auto_ops::impl_op_ex;
use parking_lot::Mutex;
use parquet::arrow::{arrow_reader::ParquetRecordBatchReaderBuilder, ArrowWriter};
use serde::{Deserialize, Serialize};
use std::ops::{Deref, DerefMut, Index, IndexMut};
use std::path::Path;
use std::{fmt::Display, fs::File};
use std::{path::PathBuf, sync::Arc};

use oxyroot::{Branch, Named, ReaderTree, RootFile, WriterTree};

#[cfg(feature = "rayon")]
use rayon::prelude::*;

#[cfg(feature = "mpi")]
use mpi::{datatype::PartitionMut, topology::SimpleCommunicator, traits::*};

#[cfg(feature = "mpi")]
use crate::mpi::LadduMPI;

#[cfg(feature = "mpi")]
type WorldHandle = SimpleCommunicator;
#[cfg(not(feature = "mpi"))]
type WorldHandle = ();

use crate::utils::get_bin_edges;
use crate::{
    utils::{
        variables::{IntoP4Selection, P4Selection, Variable, VariableExpression},
        vectors::Vec4,
    },
    LadduError, LadduResult,
};
use indexmap::{IndexMap, IndexSet};

/// An event that can be used to test the implementation of an
/// [`Amplitude`](crate::amplitudes::Amplitude). This particular event contains the reaction
/// $`\gamma p \to K_S^0 K_S^0 p`$ with a polarized photon beam.
pub fn test_event() -> EventData {
    use crate::utils::vectors::*;
    let pol_magnitude = 0.38562805;
    let pol_angle = 0.05708078;
    EventData {
        p4s: vec![
            Vec3::new(0.0, 0.0, 8.747).with_mass(0.0),         // beam
            Vec3::new(0.119, 0.374, 0.222).with_mass(1.007),   // "proton"
            Vec3::new(-0.112, 0.293, 3.081).with_mass(0.498),  // "kaon"
            Vec3::new(-0.007, -0.667, 5.446).with_mass(0.498), // "kaon"
        ],
        aux: vec![pol_magnitude, pol_angle],
        weight: 0.48,
    }
}

/// Particle names used by [`test_dataset`].
pub const TEST_P4_NAMES: &[&str] = &["beam", "proton", "kshort1", "kshort2"];
/// Auxiliary scalar names used by [`test_dataset`].
pub const TEST_AUX_NAMES: &[&str] = &["pol_magnitude", "pol_angle"];

/// A dataset that can be used to test the implementation of an
/// [`Amplitude`](crate::amplitudes::Amplitude). This particular dataset contains a single
/// [`EventData`] generated from [`test_event`].
pub fn test_dataset() -> Dataset {
    let metadata = Arc::new(
        DatasetMetadata::new(
            TEST_P4_NAMES.iter().map(|s| (*s).to_string()).collect(),
            TEST_AUX_NAMES.iter().map(|s| (*s).to_string()).collect(),
        )
        .expect("Test metadata should be valid"),
    );
    Dataset::new_with_metadata(vec![Arc::new(test_event())], metadata)
}

/// Raw event data in a [`Dataset`] containing all particle and auxiliary information.
///
/// An [`EventData`] instance owns the list of four-momenta (`p4s`), auxiliary scalars (`aux`),
/// and weight recorded for a particular collision event. Use [`Event`] when you need a
/// metadata-aware view with name-based helpers.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct EventData {
    /// A list of four-momenta for each particle.
    pub p4s: Vec<Vec4>,
    /// A list of auxiliary scalar values associated with the event.
    pub aux: Vec<f64>,
    /// The weight given to the event.
    pub weight: f64,
}

impl Display for EventData {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Event:")?;
        writeln!(f, "  p4s:")?;
        for p4 in &self.p4s {
            writeln!(f, "    {}", p4.to_p4_string())?;
        }
        writeln!(f, "  aux:")?;
        for (idx, value) in self.aux.iter().enumerate() {
            writeln!(f, "    aux[{idx}]: {value}")?;
        }
        writeln!(f, "  weight:")?;
        writeln!(f, "    {}", self.weight)?;
        Ok(())
    }
}

impl EventData {
    /// Return a four-momentum from the sum of four-momenta at the given indices in the [`EventData`].
    pub fn get_p4_sum<T: AsRef<[usize]>>(&self, indices: T) -> Vec4 {
        indices.as_ref().iter().map(|i| self.p4s[*i]).sum::<Vec4>()
    }
    /// Boost all the four-momenta in the [`EventData`] to the rest frame of the given set of
    /// four-momenta by indices.
    pub fn boost_to_rest_frame_of<T: AsRef<[usize]>>(&self, indices: T) -> Self {
        let frame = self.get_p4_sum(indices);
        EventData {
            p4s: self
                .p4s
                .iter()
                .map(|p4| p4.boost(&(-frame.beta())))
                .collect(),
            aux: self.aux.clone(),
            weight: self.weight,
        }
    }
    /// Evaluate a [`Variable`] on an [`EventData`].
    pub fn evaluate<V: Variable>(&self, variable: &V) -> f64 {
        variable.value(self)
    }
}

/// Metadata describing the named four-momenta, auxiliary scalars, and four-momentum aliases
/// registered for a [`Dataset`].
#[derive(Debug, Clone)]
pub struct DatasetMetadata {
    pub(crate) p4_names: Vec<String>,
    pub(crate) aux_names: Vec<String>,
    pub(crate) p4_lookup: IndexMap<String, usize>,
    pub(crate) aux_lookup: IndexMap<String, usize>,
    pub(crate) p4_selections: IndexMap<String, P4Selection>,
}

impl DatasetMetadata {
    /// Construct metadata from explicit particle and auxiliary names.
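    ///
    /// # Example
    ///
    /// A minimal sketch (the names here are illustrative):
    ///
    /// ```ignore
    /// let metadata = DatasetMetadata::new(
    ///     vec!["beam", "proton", "kshort1", "kshort2"],
    ///     vec!["pol_magnitude", "pol_angle"],
    /// )?;
    /// assert_eq!(metadata.p4_index("proton"), Some(1));
    /// ```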
    pub fn new<P: Into<String>, A: Into<String>>(
        p4_names: Vec<P>,
        aux_names: Vec<A>,
    ) -> LadduResult<Self> {
        let mut p4_lookup = IndexMap::with_capacity(p4_names.len());
        let mut aux_lookup = IndexMap::with_capacity(aux_names.len());
        let mut p4_selections = IndexMap::with_capacity(p4_names.len());
        let p4_names: Vec<String> = p4_names
            .into_iter()
            .enumerate()
            .map(|(idx, name)| {
                let name = name.into();
                if p4_lookup.contains_key(&name) {
                    return Err(LadduError::DuplicateName {
                        category: "p4",
                        name,
                    });
                }
                p4_lookup.insert(name.clone(), idx);
                p4_selections.insert(
                    name.clone(),
                    P4Selection::with_indices(vec![name.clone()], vec![idx]),
                );
                Ok(name)
            })
            .collect::<Result<_, _>>()?;
        let aux_names: Vec<String> = aux_names
            .into_iter()
            .enumerate()
            .map(|(idx, name)| {
                let name = name.into();
                if aux_lookup.contains_key(&name) {
                    return Err(LadduError::DuplicateName {
                        category: "aux",
                        name,
                    });
                }
                aux_lookup.insert(name.clone(), idx);
                Ok(name)
            })
            .collect::<Result<_, _>>()?;
        Ok(Self {
            p4_names,
            aux_names,
            p4_lookup,
            aux_lookup,
            p4_selections,
        })
    }

    /// Create metadata with no registered names.
    pub fn empty() -> Self {
        Self {
            p4_names: Vec::new(),
            aux_names: Vec::new(),
            p4_lookup: IndexMap::new(),
            aux_lookup: IndexMap::new(),
            p4_selections: IndexMap::new(),
        }
    }

    /// Resolve the index of a four-momentum by name.
    pub fn p4_index(&self, name: &str) -> Option<usize> {
        self.p4_lookup.get(name).copied()
    }

    /// Registered four-momentum names in declaration order.
    pub fn p4_names(&self) -> &[String] {
        &self.p4_names
    }

    /// Resolve the index of an auxiliary scalar by name.
    pub fn aux_index(&self, name: &str) -> Option<usize> {
        self.aux_lookup.get(name).copied()
    }

    /// Registered auxiliary scalar names in declaration order.
    pub fn aux_names(&self) -> &[String] {
        &self.aux_names
    }

    /// Look up a resolved four-momentum selection by name (canonical or alias).
    pub fn p4_selection(&self, name: &str) -> Option<&P4Selection> {
        self.p4_selections.get(name)
    }

    /// Register an alias mapping to one or more existing four-momenta.
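    ///
    /// # Example
    ///
    /// A sketch assuming `"kshort1"` and `"kshort2"` are registered at indices 2 and 3:
    ///
    /// ```ignore
    /// metadata.add_p4_alias(
    ///     "kshort_pair",
    ///     P4Selection::with_indices(
    ///         vec!["kshort1".to_string(), "kshort2".to_string()],
    ///         vec![2, 3],
    ///     ),
    /// )?;
    /// ```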
    pub fn add_p4_alias<N>(&mut self, alias: N, mut selection: P4Selection) -> LadduResult<()>
    where
        N: Into<String>,
    {
        let alias = alias.into();
        if self.p4_selections.contains_key(&alias) {
            return Err(LadduError::DuplicateName {
                category: "alias",
                name: alias,
            });
        }
        selection.bind(self)?;
        self.p4_selections.insert(alias, selection);
        Ok(())
    }

    /// Register multiple aliases at once.
    pub fn add_p4_aliases<I, N>(&mut self, entries: I) -> LadduResult<()>
    where
        I: IntoIterator<Item = (N, P4Selection)>,
        N: Into<String>,
    {
        for (alias, selection) in entries {
            self.add_p4_alias(alias, selection)?;
        }
        Ok(())
    }

    pub(crate) fn append_indices_for_name(
        &self,
        name: &str,
        target: &mut Vec<usize>,
    ) -> LadduResult<()> {
        if let Some(selection) = self.p4_selections.get(name) {
            target.extend_from_slice(selection.indices());
            return Ok(());
        }
        Err(LadduError::UnknownName {
            category: "p4",
            name: name.to_string(),
        })
    }
}

impl Default for DatasetMetadata {
    fn default() -> Self {
        Self::empty()
    }
}

/// A collection of [`EventData`] with optional metadata for name-based lookups.
#[derive(Debug, Clone)]
pub struct Dataset {
    /// The [`Event`]s contained in the [`Dataset`].
    pub events: Vec<Event>,
    pub(crate) metadata: Arc<DatasetMetadata>,
}

/// Metadata-aware view of an [`EventData`] with name-based helpers.
#[derive(Clone, Debug)]
pub struct Event {
    event: Arc<EventData>,
    metadata: Arc<DatasetMetadata>,
}

impl Event {
    /// Create a new metadata-aware event from raw data and dataset metadata.
    pub fn new(event: Arc<EventData>, metadata: Arc<DatasetMetadata>) -> Self {
        Self { event, metadata }
    }

    /// Borrow the raw [`EventData`].
    pub fn data(&self) -> &EventData {
        &self.event
    }

    /// Obtain a clone of the underlying [`EventData`] handle.
    pub fn data_arc(&self) -> Arc<EventData> {
        self.event.clone()
    }

    /// Return the four-momenta stored in this event keyed by their registered names.
    pub fn p4s(&self) -> IndexMap<&str, Vec4> {
        let mut map = IndexMap::with_capacity(self.metadata.p4_names.len());
        for (idx, name) in self.metadata.p4_names.iter().enumerate() {
            if let Some(p4) = self.event.p4s.get(idx) {
                map.insert(name.as_str(), *p4);
            }
        }
        map
    }

    /// Return the auxiliary scalars stored in this event keyed by their registered names.
    pub fn aux(&self) -> IndexMap<&str, f64> {
        let mut map = IndexMap::with_capacity(self.metadata.aux_names.len());
        for (idx, name) in self.metadata.aux_names.iter().enumerate() {
            if let Some(value) = self.event.aux.get(idx) {
                map.insert(name.as_str(), *value);
            }
        }
        map
    }

    /// Return the event weight.
    pub fn weight(&self) -> f64 {
        self.event.weight
    }

    /// Retrieve the dataset metadata attached to this event.
    pub fn metadata(&self) -> &DatasetMetadata {
        &self.metadata
    }

    /// Clone the metadata handle associated with this event.
    pub fn metadata_arc(&self) -> Arc<DatasetMetadata> {
        self.metadata.clone()
    }

    /// Retrieve a four-momentum (or aliased sum) by name.
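    ///
    /// ```ignore
    /// // A sketch using the names registered by `test_dataset`.
    /// if let Some(beam) = event.p4("beam") {
    ///     println!("{}", beam.to_p4_string());
    /// }
    /// ```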
    pub fn p4(&self, name: &str) -> Option<Vec4> {
        self.metadata
            .p4_selection(name)
            .map(|selection| selection.momentum(&self.event))
    }

    fn resolve_p4_indices<N>(&self, names: N) -> Vec<usize>
    where
        N: IntoIterator,
        N::Item: AsRef<str>,
    {
        let mut indices = Vec::new();
        for name in names {
            let name_ref = name.as_ref();
            if let Some(selection) = self.metadata.p4_selection(name_ref) {
                indices.extend_from_slice(selection.indices());
            } else {
                panic!("Unknown particle name '{name}'", name = name_ref);
            }
        }
        indices
    }

    /// Return a four-momentum formed by summing four-momenta with the specified names.
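    ///
    /// For example, a sketch summing the two kaon momenta from [`test_dataset`]:
    ///
    /// ```ignore
    /// let resonance = event.get_p4_sum(["kshort1", "kshort2"]);
    /// ```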
    pub fn get_p4_sum<N>(&self, names: N) -> Vec4
    where
        N: IntoIterator,
        N::Item: AsRef<str>,
    {
        let indices = self.resolve_p4_indices(names);
        self.event.get_p4_sum(&indices)
    }

    /// Boost all four-momenta into the rest frame defined by the specified particle names.
    pub fn boost_to_rest_frame_of<N>(&self, names: N) -> EventData
    where
        N: IntoIterator,
        N::Item: AsRef<str>,
    {
        let indices = self.resolve_p4_indices(names);
        self.event.boost_to_rest_frame_of(&indices)
    }

    /// Evaluate a [`Variable`] over this event.
    pub fn evaluate<V: Variable>(&self, variable: &V) -> f64 {
        self.event.evaluate(variable)
    }
}

impl Deref for Event {
    type Target = EventData;

    fn deref(&self) -> &Self::Target {
        &self.event
    }
}

impl AsRef<EventData> for Event {
    fn as_ref(&self) -> &EventData {
        self.data()
    }
}

impl IntoIterator for Dataset {
    type Item = Event;

    type IntoIter = DatasetIntoIter;

    fn into_iter(self) -> Self::IntoIter {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                // Cache total before moving fields out of self for MPI iteration.
                let total = self.n_events();
                return DatasetIntoIter::Mpi(DatasetMpiIntoIter {
                    events: self.events,
                    metadata: self.metadata,
                    world,
                    index: 0,
                    total,
                });
            }
        }
        DatasetIntoIter::Local(self.events.into_iter())
    }
}

impl Dataset {
    /// Iterate over all events in the dataset. When MPI is enabled, this will visit
    /// every event across all ranks, fetching remote events on demand.
    pub fn iter(&self) -> DatasetIter<'_> {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return DatasetIter::Mpi(DatasetMpiIter {
                    dataset: self,
                    world,
                    index: 0,
                    total: self.n_events(),
                });
            }
        }
        DatasetIter::Local(self.events.iter())
    }

    /// Borrow the dataset metadata used for name lookups.
    pub fn metadata(&self) -> &DatasetMetadata {
        &self.metadata
    }

    /// Clone the internal metadata handle for external consumers (e.g., language bindings).
    pub fn metadata_arc(&self) -> Arc<DatasetMetadata> {
        self.metadata.clone()
    }

    /// Names corresponding to stored four-momenta.
    pub fn p4_names(&self) -> &[String] {
        &self.metadata.p4_names
    }

    /// Names corresponding to stored auxiliary scalars.
    pub fn aux_names(&self) -> &[String] {
        &self.metadata.aux_names
    }

    /// Resolve the index of a four-momentum by name.
    pub fn p4_index(&self, name: &str) -> Option<usize> {
        self.metadata.p4_index(name)
    }

    /// Resolve the index of an auxiliary scalar by name.
    pub fn aux_index(&self, name: &str) -> Option<usize> {
        self.metadata.aux_index(name)
    }

    /// Borrow event data together with metadata-based helpers as an [`Event`] view.
    pub fn named_event(&self, index: usize) -> Event {
        self.events[index].clone()
    }

    /// Retrieve a four-momentum by name for the event at `event_index`.
    pub fn p4_by_name(&self, event_index: usize, name: &str) -> Option<Vec4> {
        self.events
            .get(event_index)
            .and_then(|event| event.p4(name))
    }

    /// Retrieve an auxiliary scalar by name for the event at `event_index`.
    pub fn aux_by_name(&self, event_index: usize, name: &str) -> Option<f64> {
        let idx = self.aux_index(name)?;
        self.events
            .get(event_index)
            .and_then(|event| event.aux.get(idx))
            .copied()
    }

    /// Get a reference to the [`Event`] at the given index in the [`Dataset`] (non-MPI
    /// version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just index into a [`Dataset`]
    /// as if it were any other [`Vec`]:
    ///
    /// ```ignore
    /// let ds: Dataset = Dataset::new(events);
    /// let event_0 = &ds[0];
    /// ```
    pub fn index_local(&self, index: usize) -> &Event {
        &self.events[index]
    }

    #[cfg(feature = "mpi")]
    fn partition(
        events: Vec<Arc<EventData>>,
        world: &SimpleCommunicator,
    ) -> Vec<Vec<Arc<EventData>>> {
        let (counts, displs) = world.get_counts_displs(events.len());
        counts
            .iter()
            .zip(displs.iter())
            .map(|(&count, &displ)| {
                events
                    .iter()
                    .skip(displ as usize)
                    .take(count as usize)
                    .cloned()
                    .collect()
            })
            .collect()
    }

    /// Get a reference to the [`Event`] at the given index in the [`Dataset`]
    /// (MPI-compatible version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just index into a [`Dataset`]
    /// as if it were any other [`Vec`]:
    ///
    /// ```ignore
    /// let ds: Dataset = Dataset::new(events);
    /// let event_0 = &ds[0];
    /// ```
    #[cfg(feature = "mpi")]
    pub fn index_mpi(&self, index: usize, world: &SimpleCommunicator) -> &Event {
        let total = self.n_events();
        let event = fetch_event_mpi(self, index, world, total);
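        // NOTE: the fetched event is leaked to satisfy the `&Event` return type
        // required by the `Index` trait; each MPI-indexed access allocates and
        // never frees.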
        Box::leak(Box::new(event))
    }
}

impl Index<usize> for Dataset {
    type Output = Event;

    fn index(&self, index: usize) -> &Self::Output {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.index_mpi(index, &world);
            }
        }
        self.index_local(index)
    }
}

/// Iterator over a [`Dataset`].
pub enum DatasetIter<'a> {
    /// Iterator over locally available events.
    Local(std::slice::Iter<'a, Event>),
    #[cfg(feature = "mpi")]
    /// Iterator that fetches events across MPI ranks.
    Mpi(DatasetMpiIter<'a>),
}

impl<'a> Iterator for DatasetIter<'a> {
    type Item = Event;

    fn next(&mut self) -> Option<Self::Item> {
        match self {
            DatasetIter::Local(iter) => iter.next().cloned(),
            #[cfg(feature = "mpi")]
            DatasetIter::Mpi(iter) => iter.next(),
        }
    }
}

/// Owning iterator over a [`Dataset`].
pub enum DatasetIntoIter {
    /// Iterator over locally available events, consuming the dataset.
    Local(std::vec::IntoIter<Event>),
    #[cfg(feature = "mpi")]
    /// Iterator that fetches events across MPI ranks, consuming the dataset.
    Mpi(DatasetMpiIntoIter),
}

impl Iterator for DatasetIntoIter {
    type Item = Event;

    fn next(&mut self) -> Option<Self::Item> {
        match self {
            DatasetIntoIter::Local(iter) => iter.next(),
            #[cfg(feature = "mpi")]
            DatasetIntoIter::Mpi(iter) => iter.next(),
        }
    }
}

#[cfg(feature = "mpi")]
/// Iterator over a [`Dataset`] that fetches events across MPI ranks.
pub struct DatasetMpiIter<'a> {
    dataset: &'a Dataset,
    world: SimpleCommunicator,
    index: usize,
    total: usize,
}

#[cfg(feature = "mpi")]
impl<'a> Iterator for DatasetMpiIter<'a> {
    type Item = Event;

    fn next(&mut self) -> Option<Self::Item> {
        if self.index >= self.total {
            return None;
        }
        let event = fetch_event_mpi(self.dataset, self.index, &self.world, self.total);
        self.index += 1;
        Some(event)
    }
}

#[cfg(feature = "mpi")]
/// Owning iterator over a [`Dataset`] that fetches events across MPI ranks.
pub struct DatasetMpiIntoIter {
    events: Vec<Event>,
    metadata: Arc<DatasetMetadata>,
    world: SimpleCommunicator,
    index: usize,
    total: usize,
}

#[cfg(feature = "mpi")]
impl Iterator for DatasetMpiIntoIter {
    type Item = Event;

    fn next(&mut self) -> Option<Self::Item> {
        if self.index >= self.total {
            return None;
        }
        let event = fetch_event_mpi_from_events(
            &self.events,
            &self.metadata,
            self.index,
            &self.world,
            self.total,
        );
        self.index += 1;
        Some(event)
    }
}

#[cfg(feature = "mpi")]
fn fetch_event_mpi(
    dataset: &Dataset,
    global_index: usize,
    world: &SimpleCommunicator,
    total: usize,
) -> Event {
    fetch_event_mpi_generic(
        global_index,
        total,
        world,
        &dataset.metadata,
        |local_index| dataset.index_local(local_index),
    )
}

#[cfg(feature = "mpi")]
fn fetch_event_mpi_from_events(
    events: &[Event],
    metadata: &Arc<DatasetMetadata>,
    global_index: usize,
    world: &SimpleCommunicator,
    total: usize,
) -> Event {
    fetch_event_mpi_generic(global_index, total, world, metadata, |local_index| {
        &events[local_index]
    })
}

#[cfg(feature = "mpi")]
fn fetch_event_mpi_generic<'a, F>(
    global_index: usize,
    total: usize,
    world: &SimpleCommunicator,
    metadata: &Arc<DatasetMetadata>,
    local_event: F,
) -> Event
where
    F: Fn(usize) -> &'a Event,
{
    let (owning_rank, local_index) = world.owner_of_global_index(global_index, total);
    let mut serialized_event_buffer_len: usize = 0;
    let mut serialized_event_buffer: Vec<u8> = Vec::default();
    let config = bincode::config::standard();
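    // Two-phase broadcast: the owning rank first broadcasts the serialized
    // length so the other ranks can size their receive buffers, then it
    // broadcasts the payload itself.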
    if world.rank() == owning_rank {
        let event = local_event(local_index);
        serialized_event_buffer = bincode::serde::encode_to_vec(event.data(), config).unwrap();
        serialized_event_buffer_len = serialized_event_buffer.len();
    }
    world
        .process_at_rank(owning_rank)
        .broadcast_into(&mut serialized_event_buffer_len);
    if world.rank() != owning_rank {
        serialized_event_buffer = vec![0; serialized_event_buffer_len];
    }
    world
        .process_at_rank(owning_rank)
        .broadcast_into(&mut serialized_event_buffer);

    if world.rank() == owning_rank {
        local_event(local_index).clone()
    } else {
        let (event, _): (EventData, usize) =
            bincode::serde::decode_from_slice(&serialized_event_buffer[..], config).unwrap();
        Event::new(Arc::new(event), metadata.clone())
    }
}

impl Dataset {
    /// Create a new [`Dataset`] from a list of [`EventData`] (non-MPI version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::new`] instead.
    pub fn new_local(events: Vec<Arc<EventData>>, metadata: Arc<DatasetMetadata>) -> Self {
        let wrapped_events = events
            .into_iter()
            .map(|event| Event::new(event, metadata.clone()))
            .collect();
        Dataset {
            events: wrapped_events,
            metadata,
        }
    }

    /// Create a new [`Dataset`] from a list of [`EventData`] (MPI-compatible version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::new`] instead.
    #[cfg(feature = "mpi")]
    pub fn new_mpi(
        events: Vec<Arc<EventData>>,
        metadata: Arc<DatasetMetadata>,
        world: &SimpleCommunicator,
    ) -> Self {
        let partitions = Dataset::partition(events, world);
        let local = partitions[world.rank() as usize]
            .iter()
            .cloned()
            .map(|event| Event::new(event, metadata.clone()))
            .collect();
        Dataset {
            events: local,
            metadata,
        }
    }

    /// Create a new [`Dataset`] from a list of [`EventData`].
    ///
    /// This method is preferred for external use because it dispatches to the proper MPI
    /// construction path when the `mpi` feature is active. Constructing a [`Dataset`] manually
    /// is possible, but may cause issues when interfacing with MPI and should be avoided unless
    /// you know what you are doing.
    pub fn new(events: Vec<Arc<EventData>>) -> Self {
        Dataset::new_with_metadata(events, Arc::new(DatasetMetadata::default()))
    }

    /// Create a dataset with explicit metadata for name-based lookups.
    pub fn new_with_metadata(events: Vec<Arc<EventData>>, metadata: Arc<DatasetMetadata>) -> Self {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return Dataset::new_mpi(events, metadata, &world);
            }
        }
        Dataset::new_local(events, metadata)
    }

    /// The number of [`EventData`]s in the [`Dataset`] (non-MPI version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::n_events`] instead.
    pub fn n_events_local(&self) -> usize {
        self.events.len()
    }

    /// The number of [`EventData`]s in the [`Dataset`] (MPI-compatible version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::n_events`] instead.
    #[cfg(feature = "mpi")]
    pub fn n_events_mpi(&self, world: &SimpleCommunicator) -> usize {
        let mut n_events_partitioned: Vec<usize> = vec![0; world.size() as usize];
        let n_events_local = self.n_events_local();
        world.all_gather_into(&n_events_local, &mut n_events_partitioned);
        n_events_partitioned.iter().sum()
    }

    /// The number of [`EventData`]s in the [`Dataset`].
    pub fn n_events(&self) -> usize {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.n_events_mpi(&world);
            }
        }
        self.n_events_local()
    }
}

impl Dataset {
    /// Extract a list of weights over each [`EventData`] in the [`Dataset`] (non-MPI version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::weights`] instead.
    pub fn weights_local(&self) -> Vec<f64> {
        #[cfg(feature = "rayon")]
        return self.events.par_iter().map(|e| e.weight).collect();
        #[cfg(not(feature = "rayon"))]
        return self.events.iter().map(|e| e.weight).collect();
    }

    /// Extract a list of weights over each [`EventData`] in the [`Dataset`] (MPI-compatible version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::weights`] instead.
    #[cfg(feature = "mpi")]
    pub fn weights_mpi(&self, world: &SimpleCommunicator) -> Vec<f64> {
        let local_weights = self.weights_local();
        let n_events = self.n_events();
        let mut buffer: Vec<f64> = vec![0.0; n_events];
        let (counts, displs) = world.get_counts_displs(n_events);
        {
            let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
            world.all_gather_varcount_into(&local_weights, &mut partitioned_buffer);
        }
        buffer
    }

    /// Extract a list of weights over each [`EventData`] in the [`Dataset`].
    pub fn weights(&self) -> Vec<f64> {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.weights_mpi(&world);
            }
        }
        self.weights_local()
    }

    /// Returns the sum of the weights for each [`EventData`] in the [`Dataset`] (non-MPI version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::n_events_weighted`] instead.
    pub fn n_events_weighted_local(&self) -> f64 {
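        // Klein (compensated) summation reduces floating-point round-off when
        // accumulating many weights.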
        #[cfg(feature = "rayon")]
        return self
            .events
            .par_iter()
            .map(|e| e.weight)
            .parallel_sum_with_accumulator::<Klein<f64>>();
        #[cfg(not(feature = "rayon"))]
        return self.events.iter().map(|e| e.weight).sum();
    }
    /// Returns the sum of the weights for each [`EventData`] in the [`Dataset`] (MPI-compatible version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::n_events_weighted`] instead.
    #[cfg(feature = "mpi")]
    pub fn n_events_weighted_mpi(&self, world: &SimpleCommunicator) -> f64 {
        let mut n_events_weighted_partitioned: Vec<f64> = vec![0.0; world.size() as usize];
        let n_events_weighted_local = self.n_events_weighted_local();
        world.all_gather_into(&n_events_weighted_local, &mut n_events_weighted_partitioned);
        #[cfg(feature = "rayon")]
        return n_events_weighted_partitioned
            .into_par_iter()
            .parallel_sum_with_accumulator::<Klein<f64>>();
        #[cfg(not(feature = "rayon"))]
        return n_events_weighted_partitioned.iter().sum();
    }

    /// Returns the sum of the weights for each [`EventData`] in the [`Dataset`].
    pub fn n_events_weighted(&self) -> f64 {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.n_events_weighted_mpi(&world);
            }
        }
        self.n_events_weighted_local()
    }

    /// Generate a new dataset of the same length by resampling the events in the original dataset
    /// with replacement. This can be used to perform error analysis via the bootstrap method (non-MPI version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::bootstrap`] instead.
    pub fn bootstrap_local(&self, seed: usize) -> Arc<Dataset> {
        let mut rng = fastrand::Rng::with_seed(seed as u64);
        let mut indices: Vec<usize> = (0..self.n_events())
            .map(|_| rng.usize(0..self.n_events()))
            .collect::<Vec<usize>>();
        indices.sort();
        #[cfg(feature = "rayon")]
        let bootstrapped_events: Vec<Arc<EventData>> = indices
            .into_par_iter()
            .map(|idx| self.events[idx].data_arc())
            .collect();
        #[cfg(not(feature = "rayon"))]
        let bootstrapped_events: Vec<Arc<EventData>> = indices
            .into_iter()
            .map(|idx| self.events[idx].data_arc())
            .collect();
        Arc::new(Dataset::new_with_metadata(
            bootstrapped_events,
            self.metadata.clone(),
        ))
    }

    /// Generate a new dataset of the same length by resampling the events in the original dataset
    /// with replacement. This can be used to perform error analysis via the bootstrap method (MPI-compatible version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::bootstrap`] instead.
    #[cfg(feature = "mpi")]
    pub fn bootstrap_mpi(&self, seed: usize, world: &SimpleCommunicator) -> Arc<Dataset> {
        let n_events = self.n_events();
        let mut indices: Vec<usize> = vec![0; n_events];
        if world.is_root() {
            let mut rng = fastrand::Rng::with_seed(seed as u64);
            indices = (0..n_events)
                .map(|_| rng.usize(0..n_events))
                .collect::<Vec<usize>>();
            indices.sort();
        }
        world.process_at_root().broadcast_into(&mut indices);
        let local_indices: Vec<usize> = indices
            .into_iter()
            .filter_map(|idx| {
                let (owning_rank, local_index) = world.owner_of_global_index(idx, n_events);
                if world.rank() == owning_rank {
                    Some(local_index)
                } else {
                    None
                }
            })
            .collect();
        // `local_indices` only contains indices owned by the current rank, translating them into
        // local indices on the events vector.
        #[cfg(feature = "rayon")]
        let bootstrapped_events: Vec<Arc<EventData>> = local_indices
            .into_par_iter()
            .map(|idx| self.events[idx].data_arc())
            .collect();
        #[cfg(not(feature = "rayon"))]
        let bootstrapped_events: Vec<Arc<EventData>> = local_indices
            .into_iter()
            .map(|idx| self.events[idx].data_arc())
            .collect();
        Arc::new(Dataset::new_with_metadata(
            bootstrapped_events,
            self.metadata.clone(),
        ))
    }

    /// Generate a new dataset of the same length by resampling the events in the original dataset
    /// with replacement. This can be used to perform error analysis via the bootstrap method.
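    ///
    /// A typical replica loop might look like this sketch (`n_replicas` is illustrative):
    ///
    /// ```ignore
    /// for seed in 0..n_replicas {
    ///     let resampled = dataset.bootstrap(seed);
    ///     // fit `resampled` and record the result...
    /// }
    /// ```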
    pub fn bootstrap(&self, seed: usize) -> Arc<Dataset> {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.bootstrap_mpi(seed, &world);
            }
        }
        self.bootstrap_local(seed)
    }

    /// Filter the [`Dataset`] by a given [`VariableExpression`], selecting events for which
    /// the expression returns `true`.
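    ///
    /// ```ignore
    /// // A sketch; `expr` is a previously constructed `VariableExpression`.
    /// let selected = dataset.filter(&expr)?;
    /// assert!(selected.n_events() <= dataset.n_events());
    /// ```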
    pub fn filter(&self, expression: &VariableExpression) -> LadduResult<Arc<Dataset>> {
        let compiled = expression.compile(&self.metadata)?;
        #[cfg(feature = "rayon")]
        let filtered_events: Vec<Arc<EventData>> = self
            .events
            .par_iter()
            .filter(|event| compiled.evaluate(event.as_ref()))
            .map(|event| event.data_arc())
            .collect();
        #[cfg(not(feature = "rayon"))]
        let filtered_events: Vec<Arc<EventData>> = self
            .events
            .iter()
            .filter(|event| compiled.evaluate(event.as_ref()))
            .map(|event| event.data_arc())
            .collect();
        Ok(Arc::new(Dataset::new_with_metadata(
            filtered_events,
            self.metadata.clone(),
        )))
    }

    /// Bin a [`Dataset`] by the value of the given [`Variable`] into a number of `bins` within the
    /// given `range`.
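    ///
    /// For instance, a sketch binning by an already-constructed mass [`Variable`]
    /// (`mass_variable` is illustrative):
    ///
    /// ```ignore
    /// let binned = dataset.bin_by(mass_variable, 40, (1.0, 2.0))?;
    /// ```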
    pub fn bin_by<V>(
        &self,
        mut variable: V,
        bins: usize,
        range: (f64, f64),
    ) -> LadduResult<BinnedDataset>
    where
        V: Variable,
    {
        variable.bind(self.metadata())?;
        let bin_width = (range.1 - range.0) / bins as f64;
        let bin_edges = get_bin_edges(bins, range);
        let variable = variable;
        #[cfg(feature = "rayon")]
        let evaluated: Vec<(usize, Arc<EventData>)> = self
            .events
            .par_iter()
            .filter_map(|event| {
                let value = variable.value(event.as_ref());
                if value >= range.0 && value < range.1 {
                    let bin_index = ((value - range.0) / bin_width) as usize;
                    let bin_index = bin_index.min(bins - 1);
                    Some((bin_index, event.data_arc()))
                } else {
                    None
                }
            })
            .collect();
        #[cfg(not(feature = "rayon"))]
        let evaluated: Vec<(usize, Arc<EventData>)> = self
            .events
            .iter()
            .filter_map(|event| {
                let value = variable.value(event.as_ref());
                if value >= range.0 && value < range.1 {
                    let bin_index = ((value - range.0) / bin_width) as usize;
                    let bin_index = bin_index.min(bins - 1);
                    Some((bin_index, event.data_arc()))
                } else {
                    None
                }
            })
            .collect();
        let mut binned_events: Vec<Vec<Arc<EventData>>> = vec![Vec::default(); bins];
        for (bin_index, event) in evaluated {
            binned_events[bin_index].push(event.clone());
        }
        #[cfg(feature = "rayon")]
        let datasets: Vec<Arc<Dataset>> = binned_events
            .into_par_iter()
            .map(|events| Arc::new(Dataset::new_with_metadata(events, self.metadata.clone())))
            .collect();
        #[cfg(not(feature = "rayon"))]
        let datasets: Vec<Arc<Dataset>> = binned_events
            .into_iter()
            .map(|events| Arc::new(Dataset::new_with_metadata(events, self.metadata.clone())))
            .collect();
        Ok(BinnedDataset {
            datasets,
            edges: bin_edges,
        })
    }

    /// Boost all the four-momenta in all [`EventData`]s to the rest frame of the given set of
    /// four-momenta identified by name.
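    ///
    /// ```ignore
    /// // A sketch: boost every event into the rest frame of the two-kaon system.
    /// let com_dataset = dataset.boost_to_rest_frame_of(&["kshort1", "kshort2"]);
    /// ```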
    pub fn boost_to_rest_frame_of<S>(&self, names: &[S]) -> Arc<Dataset>
    where
        S: AsRef<str>,
    {
        let mut indices: Vec<usize> = Vec::new();
        for name in names {
            let name_ref = name.as_ref();
            if let Some(selection) = self.metadata.p4_selection(name_ref) {
                indices.extend_from_slice(selection.indices());
            } else {
                panic!("Unknown particle name '{name}'", name = name_ref);
            }
        }
        #[cfg(feature = "rayon")]
        let boosted_events: Vec<Arc<EventData>> = self
            .events
            .par_iter()
            .map(|event| Arc::new(event.data().boost_to_rest_frame_of(&indices)))
            .collect();
        #[cfg(not(feature = "rayon"))]
        let boosted_events: Vec<Arc<EventData>> = self
            .events
            .iter()
            .map(|event| Arc::new(event.data().boost_to_rest_frame_of(&indices)))
            .collect();
        Arc::new(Dataset::new_with_metadata(
            boosted_events,
            self.metadata.clone(),
        ))
    }

    /// Evaluate a [`Variable`] on every event in the [`Dataset`].
    pub fn evaluate<V: Variable>(&self, variable: &V) -> LadduResult<Vec<f64>> {
        variable.value_on(self)
    }

    /// Load a dataset from a Parquet file, constructing a [`Dataset`] using named columns.
    ///
    /// Each entry in `p4_names` identifies a particle; the loader expects columns named
    /// ``{identifier}_px``, ``{identifier}_py``, ``{identifier}_pz``, and ``{identifier}_e`` for
    /// each identifier (component suffixes are matched case-insensitively). Auxiliary scalars are
    /// listed explicitly via `aux_names` and stored in the same order. Columns may be encoded as
    /// `Float32` or `Float64` and are promoted to `f64` on load.
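    ///
    /// # Example
    ///
    /// A sketch; the path and names are illustrative:
    ///
    /// ```ignore
    /// let options = DatasetReadOptions::new()
    ///     .p4_names(["beam", "proton", "kshort1", "kshort2"])
    ///     .aux_names(["pol_magnitude", "pol_angle"]);
    /// let dataset = Dataset::from_parquet("data/events.parquet", &options)?;
    /// ```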
    pub fn from_parquet(
        file_path: &str,
        options: &DatasetReadOptions,
    ) -> LadduResult<Arc<Dataset>> {
        let path = Self::canonicalize_dataset_path(file_path)?;
        Self::open_parquet_impl(path, options)
    }

    /// Load a dataset from a ROOT TTree using the oxyroot backend.
    ///
    /// When reading ROOT files containing multiple trees, set [`DatasetReadOptions::tree`]
    /// to select one.
    pub fn from_root(file_path: &str, options: &DatasetReadOptions) -> LadduResult<Arc<Dataset>> {
        let path = Self::canonicalize_dataset_path(file_path)?;
        Self::open_root_impl(path, options)
    }

    /// Write the dataset to a Parquet file.
    pub fn to_parquet(&self, file_path: &str, options: &DatasetWriteOptions) -> LadduResult<()> {
        let path = Self::expand_output_path(file_path)?;
        self.write_parquet_impl(path, options)
    }

    /// Write the dataset to a ROOT TTree using the oxyroot backend.
    pub fn to_root(&self, file_path: &str, options: &DatasetWriteOptions) -> LadduResult<()> {
        let path = Self::expand_output_path(file_path)?;
        self.write_root_impl(path, options)
    }

    fn canonicalize_dataset_path(file_path: &str) -> LadduResult<PathBuf> {
        Ok(Path::new(&*shellexpand::full(file_path)?).canonicalize()?)
    }

    fn expand_output_path(file_path: &str) -> LadduResult<PathBuf> {
        Ok(PathBuf::from(&*shellexpand::full(file_path)?))
    }

    fn open_parquet_impl(
        file_path: PathBuf,
        options: &DatasetReadOptions,
    ) -> LadduResult<Arc<Dataset>> {
        let (detected_p4_names, detected_aux_names) = detect_columns(&file_path)?;
        let metadata = options.resolve_metadata(detected_p4_names, detected_aux_names)?;
        let file = File::open(file_path)?;
        let builder = ParquetRecordBatchReaderBuilder::try_new(file)?;
        let reader = builder.build()?;
        let batches: Vec<RecordBatch> = reader.collect::<Result<Vec<_>, _>>()?;

        #[cfg(feature = "rayon")]
        let events: Vec<Arc<EventData>> = {
            let batch_events: Vec<LadduResult<Vec<Arc<EventData>>>> = batches
                .into_par_iter()
                .map(|batch| record_batch_to_events(batch, &metadata.p4_names, &metadata.aux_names))
                .collect();
            let mut events = Vec::new();
            for batch in batch_events {
                let mut batch = batch?;
                events.append(&mut batch);
            }
            events
        };

        #[cfg(not(feature = "rayon"))]
        let events: Vec<Arc<EventData>> = {
            batches
                .into_iter()
                .map(|batch| record_batch_to_events(batch, &metadata.p4_names, &metadata.aux_names))
                .collect::<LadduResult<Vec<_>>>()?
                .into_iter()
                .flatten()
                .collect()
        };

        Ok(Arc::new(Dataset::new_with_metadata(events, metadata)))
    }

    fn open_root_impl(
        file_path: PathBuf,
        options: &DatasetReadOptions,
    ) -> LadduResult<Arc<Dataset>> {
        let mut file = RootFile::open(&file_path).map_err(|err| {
            LadduError::Custom(format!(
                "Failed to open ROOT file '{}': {err}",
                file_path.display()
            ))
        })?;

        let (tree, tree_name) = resolve_root_tree(&mut file, options.tree.as_deref())?;

        let branches: Vec<&Branch> = tree.branches().collect();
        let mut lookup: BranchLookup<'_> = IndexMap::new();
        for &branch in &branches {
            if let Some(kind) = branch_scalar_kind(branch) {
                lookup.insert(branch.name(), (kind, branch));
            }
        }

        if lookup.is_empty() {
            return Err(LadduError::Custom(format!(
                "No float or double branches found in ROOT tree '{tree_name}'"
            )));
        }

        let column_names: Vec<&str> = lookup.keys().copied().collect();
        let (detected_p4_names, detected_aux_names) = infer_p4_and_aux_names(&column_names);
        let metadata = options.resolve_metadata(detected_p4_names, detected_aux_names)?;

        struct RootP4Columns {
            px: Vec<f64>,
            py: Vec<f64>,
            pz: Vec<f64>,
            e: Vec<f64>,
        }

        // TODO: do all reads in parallel if possible to match parquet impl
        let mut p4_columns = Vec::with_capacity(metadata.p4_names.len());
        for name in &metadata.p4_names {
            let logical = format!("{name}_px");
            let px = read_branch_values_from_candidates(
                &lookup,
                &component_candidates(name, "px"),
                &logical,
            )?;

            let logical = format!("{name}_py");
            let py = read_branch_values_from_candidates(
                &lookup,
                &component_candidates(name, "py"),
                &logical,
            )?;

            let logical = format!("{name}_pz");
            let pz = read_branch_values_from_candidates(
                &lookup,
                &component_candidates(name, "pz"),
                &logical,
            )?;

            let logical = format!("{name}_e");
            let e = read_branch_values_from_candidates(
                &lookup,
                &component_candidates(name, "e"),
                &logical,
            )?;

            p4_columns.push(RootP4Columns { px, py, pz, e });
        }

        let mut aux_columns = Vec::with_capacity(metadata.aux_names.len());
        for name in &metadata.aux_names {
            let values = read_branch_values(&lookup, name)?;
            aux_columns.push(values);
        }

        let n_events = if let Some(first) = p4_columns.first() {
            first.px.len()
        } else if let Some(first) = aux_columns.first() {
            first.len()
        } else {
            return Err(LadduError::Custom(
                "Unable to determine event count; dataset has no four-momentum or auxiliary columns".to_string(),
            ));
        };

        let weight_values = match read_branch_values_optional(&lookup, "weight")? {
            Some(values) => {
                if values.len() != n_events {
                    return Err(LadduError::Custom(format!(
                        "Column 'weight' has {} entries but expected {}",
                        values.len(),
                        n_events
                    )));
                }
                values
            }
            None => vec![1.0; n_events],
        };

        let mut events = Vec::with_capacity(n_events);
        for row in 0..n_events {
            let mut p4s = Vec::with_capacity(p4_columns.len());
            for columns in &p4_columns {
                p4s.push(Vec4::new(
                    columns.px[row],
                    columns.py[row],
                    columns.pz[row],
                    columns.e[row],
                ));
            }

            let mut aux = Vec::with_capacity(aux_columns.len());
            for column in &aux_columns {
                aux.push(column[row]);
            }

            let event = EventData {
                p4s,
                aux,
                weight: weight_values[row],
            };
            events.push(Arc::new(event));
        }

        Ok(Arc::new(Dataset::new_with_metadata(events, metadata)))
    }

    fn write_parquet_impl(
        &self,
        file_path: PathBuf,
        options: &DatasetWriteOptions,
    ) -> LadduResult<()> {
        let batch_size = options.batch_size.max(1);
        let precision = options.precision;
        let schema = Arc::new(build_parquet_schema(&self.metadata, precision));

        #[cfg(feature = "mpi")]
        let is_root = crate::mpi::get_world()
            .as_ref()
            .map_or(true, |world| world.rank() == 0);
        #[cfg(not(feature = "mpi"))]
        let is_root = true;

        let mut writer: Option<ArrowWriter<File>> = None;
        if is_root {
            let file = File::create(&file_path)?;
            writer = Some(
                ArrowWriter::try_new(file, schema.clone(), None).map_err(|err| {
                    LadduError::Custom(format!("Failed to create Parquet writer: {err}"))
                })?,
            );
        }

        let mut iter = self.iter();
        loop {
            let mut buffers =
                ColumnBuffers::new(self.metadata.p4_names.len(), self.metadata.aux_names.len());
            let mut rows = 0usize;

            while rows < batch_size {
                match iter.next() {
                    Some(event) => {
                        if is_root {
                            buffers.push_event(&event);
                        }
                        rows += 1;
                    }
                    None => break,
                }
            }

            if rows == 0 {
                break;
            }

            if let Some(writer) = writer.as_mut() {
                let batch = buffers
                    .into_record_batch(schema.clone(), precision)
                    .map_err(|err| {
                        LadduError::Custom(format!("Failed to build Parquet batch: {err}"))
                    })?;
                writer.write(&batch).map_err(|err| {
                    LadduError::Custom(format!("Failed to write Parquet batch: {err}"))
                })?;
            }
        }

        if let Some(writer) = writer {
            writer.close().map_err(|err| {
                LadduError::Custom(format!("Failed to finalise Parquet file: {err}"))
            })?;
        }

        Ok(())
    }

    fn write_root_impl(
        &self,
        file_path: PathBuf,
        options: &DatasetWriteOptions,
    ) -> LadduResult<()> {
        let tree_name = options.tree.clone().unwrap_or_else(|| "events".to_string());
        let branch_count = self.metadata.p4_names.len() * 4 + self.metadata.aux_names.len() + 1; // +weight

        #[cfg(feature = "mpi")]
        let mut world_opt = crate::mpi::get_world();
        #[cfg(feature = "mpi")]
        let is_root = world_opt.as_ref().map_or(true, |world| world.rank() == 0);
        #[cfg(not(feature = "mpi"))]
        let is_root = true;

        #[cfg(feature = "mpi")]
        let world: Option<WorldHandle> = world_opt.take();
        #[cfg(not(feature = "mpi"))]
        let world: Option<WorldHandle> = None;

        let total_events = self.n_events();
        let dataset_arc = Arc::new(self.clone());

        match options.precision {
            FloatPrecision::F64 => self.write_root_with_type::<f64>(
                dataset_arc,
                world,
                is_root,
                &file_path,
                &tree_name,
                branch_count,
                total_events,
            ),
            FloatPrecision::F32 => self.write_root_with_type::<f32>(
                dataset_arc,
                world,
                is_root,
                &file_path,
                &tree_name,
                branch_count,
                total_events,
            ),
        }
    }
}

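// Two [`Dataset`]s sharing the same metadata can be concatenated with `+`; a
// sketch: `let combined = &data_a + &data_b;` (both operands may be borrowed
// because `impl_op_ex` also generates the by-reference implementations).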
1479impl_op_ex!(+ |a: &Dataset, b: &Dataset| -> Dataset {
1480    debug_assert_eq!(a.metadata.p4_names, b.metadata.p4_names);
1481    debug_assert_eq!(a.metadata.aux_names, b.metadata.aux_names);
1482    Dataset {
1483        events: a
1484            .events
1485            .iter()
1486            .chain(b.events.iter())
1487            .cloned()
1488            .collect(),
1489        metadata: a.metadata.clone(),
1490    }
1491});
1492
1493/// Options for reading a [`Dataset`] from a file.
1494///
1495/// # See Also
1496/// [`Dataset::from_parquet`], [`Dataset::from_root`]
1497#[derive(Default, Clone)]
1498pub struct DatasetReadOptions {
1499    /// Particle names to read from the data file.
1500    pub p4_names: Option<Vec<String>>,
1501    /// Auxiliary scalar names to read from the data file.
1502    pub aux_names: Option<Vec<String>>,
1503    /// Name of the tree to read when loading ROOT files. When absent and the file contains a
1504    /// single tree, it will be selected automatically.
1505    pub tree: Option<String>,
1506    /// Optional aliases mapping logical names to selections of four-momenta.
1507    pub aliases: IndexMap<String, P4Selection>,
1508}
1509
1510/// Precision for writing floating-point columns.
1511#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
1512pub enum FloatPrecision {
1513    /// 32-bit floats.
1514    F32,
1515    /// 64-bit floats.
1516    #[default]
1517    F64,
1518}
1519
1520/// Options for writing a [`Dataset`] to disk.
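///
/// # Example
///
/// A minimal sketch of writing a dataset as 32-bit floats (the output paths are
/// illustrative):
///
/// ```ignore
/// let options = DatasetWriteOptions::default()
///     .batch_size(5_000)
///     .precision(FloatPrecision::F32)
///     .tree("events");
/// dataset.to_parquet("out.parquet", &options)?;
/// dataset.to_root("out.root", &options)?;
/// ```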
1521#[derive(Clone, Debug)]
1522pub struct DatasetWriteOptions {
1523    /// Number of events to include in each batch when writing.
1524    pub batch_size: usize,
1525    /// Floating-point precision to use for persisted columns.
1526    pub precision: FloatPrecision,
1527    /// Tree name to use when writing ROOT files.
1528    pub tree: Option<String>,
1529}
1530
1531impl Default for DatasetWriteOptions {
1532    fn default() -> Self {
1533        Self {
1534            batch_size: DEFAULT_WRITE_BATCH_SIZE,
1535            precision: FloatPrecision::default(),
1536            tree: None,
1537        }
1538    }
1539}
1540
1541impl DatasetWriteOptions {
1542    /// Override the batch size used for writing; defaults to 10_000.
1543    pub fn batch_size(mut self, batch_size: usize) -> Self {
1544        self.batch_size = batch_size;
1545        self
1546    }
1547
1548    /// Select the floating-point precision for persisted columns.
1549    pub fn precision(mut self, precision: FloatPrecision) -> Self {
1550        self.precision = precision;
1551        self
1552    }
1553
    /// Set the ROOT tree name (defaults to "events").
1555    pub fn tree<S: Into<String>>(mut self, name: S) -> Self {
1556        self.tree = Some(name.into());
1557        self
1558    }
}

impl DatasetReadOptions {
1561    /// Create a new [`Default`] set of [`DatasetReadOptions`].
1562    pub fn new() -> Self {
1563        Self::default()
1564    }
1565
    /// If provided, the specified particles will be read from the data file (assuming columns with
    /// the required suffixes are present, i.e. `<particle>_px`, `<particle>_py`, `<particle>_pz`,
    /// and `<particle>_e`). Otherwise, every particle for which all four suffixed columns are
    /// present will be read.
1568    pub fn p4_names<I, S>(mut self, names: I) -> Self
1569    where
1570        I: IntoIterator<Item = S>,
1571        S: AsRef<str>,
1572    {
1573        self.p4_names = Some(names.into_iter().map(|s| s.as_ref().to_string()).collect());
1574        self
1575    }
1576
    /// If provided, the specified columns will be read as auxiliary scalars. Otherwise, every
    /// floating-point column that is not consumed as a four-momentum component (and is not the
    /// `weight` column) will be used.
1580    pub fn aux_names<I, S>(mut self, names: I) -> Self
1581    where
1582        I: IntoIterator<Item = S>,
1583        S: AsRef<str>,
1584    {
1585        self.aux_names = Some(names.into_iter().map(|s| s.as_ref().to_string()).collect());
1586        self
1587    }
1588
1589    /// Select the tree to read when opening ROOT files.
1590    pub fn tree<S>(mut self, name: S) -> Self
1591    where
1592        S: AsRef<str>,
1593    {
1594        self.tree = Some(name.as_ref().to_string());
1595        self
1596    }
1597
1598    /// Register an alias for one or more existing four-momenta.
1599    pub fn alias<N, S>(mut self, name: N, selection: S) -> Self
1600    where
1601        N: Into<String>,
1602        S: IntoP4Selection,
1603    {
1604        self.aliases.insert(name.into(), selection.into_selection());
1605        self
1606    }
1607
1608    /// Register multiple aliases for four-momenta selections.
1609    pub fn aliases<I, N, S>(mut self, aliases: I) -> Self
1610    where
1611        I: IntoIterator<Item = (N, S)>,
1612        N: Into<String>,
1613        S: IntoP4Selection,
1614    {
1615        for (name, selection) in aliases {
1616            self = self.alias(name, selection);
1617        }
1618        self
1619    }
1620
1621    fn resolve_metadata(
1622        &self,
1623        detected_p4_names: Vec<String>,
1624        detected_aux_names: Vec<String>,
1625    ) -> LadduResult<Arc<DatasetMetadata>> {
1626        let p4_names_vec = self.p4_names.clone().unwrap_or(detected_p4_names);
1627        let aux_names_vec = self.aux_names.clone().unwrap_or(detected_aux_names);
1628
1629        let mut metadata = DatasetMetadata::new(p4_names_vec, aux_names_vec)?;
1630        if !self.aliases.is_empty() {
1631            metadata.add_p4_aliases(self.aliases.clone())?;
1632        }
1633        Ok(Arc::new(metadata))
1634    }
1635}
1636
1637const P4_COMPONENT_SUFFIXES: [&str; 4] = ["_px", "_py", "_pz", "_e"];
1638const DEFAULT_WRITE_BATCH_SIZE: usize = 10_000;
1639
1640fn detect_columns(file_path: &PathBuf) -> LadduResult<(Vec<String>, Vec<String>)> {
1641    let file = File::open(file_path)?;
1642    let builder = ParquetRecordBatchReaderBuilder::try_new(file)?;
1643    let schema = builder.schema();
1644    let float_cols: Vec<&str> = schema
1645        .fields()
1646        .iter()
1647        .filter(|f| matches!(f.data_type(), DataType::Float32 | DataType::Float64))
1648        .map(|f| f.name().as_str())
1649        .collect();
1650    Ok(infer_p4_and_aux_names(&float_cols))
1651}
1652
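/// Groups floating-point columns by the four component suffixes: any prefix that has all of
/// `_px`, `_py`, `_pz`, and `_e` becomes a particle name, and every remaining float column
/// (except `weight`) becomes an auxiliary scalar. For example, columns
/// `beam_px, beam_py, beam_pz, beam_e, pol_angle` yield one particle (`beam`) and one aux
/// scalar (`pol_angle`).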
1653fn infer_p4_and_aux_names(float_cols: &[&str]) -> (Vec<String>, Vec<String>) {
1654    let suffix_set: IndexSet<&str> = P4_COMPONENT_SUFFIXES.iter().copied().collect();
1655    let mut groups: IndexMap<&str, IndexSet<&str>> = IndexMap::new();
1656    for col in float_cols {
1657        for suffix in &suffix_set {
1658            if let Some(prefix) = col.strip_suffix(suffix) {
1659                groups.entry(prefix).or_default().insert(*suffix);
1660            }
1661        }
1662    }
1663
1664    let mut p4_names: Vec<String> = Vec::new();
1665    let mut p4_columns: IndexSet<String> = IndexSet::new();
1666    for (prefix, suffixes) in &groups {
1667        if suffixes.len() == suffix_set.len() {
1668            p4_names.push((*prefix).to_string());
1669            for suffix in &suffix_set {
1670                p4_columns.insert(format!("{prefix}{suffix}"));
1671            }
1672        }
1673    }
1674
1675    let mut aux_names: Vec<String> = Vec::new();
1676    for col in float_cols {
1677        if p4_columns.contains(*col) {
1678            continue;
1679        }
1680        if col.eq_ignore_ascii_case("weight") {
1681            continue;
1682        }
1683        aux_names.push((*col).to_string());
1684    }
1685
1686    (p4_names, aux_names)
1687}
1688
1689type BranchLookup<'a> = IndexMap<&'a str, (RootScalarKind, &'a Branch)>;
1690
1691#[derive(Clone, Copy)]
1692enum RootScalarKind {
1693    F32,
1694    F64,
1695}
1696
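/// Classifies a scalar ROOT branch by its item type name (`float*` variants map to
/// [`RootScalarKind::F32`], `double*` variants to [`RootScalarKind::F64`]); vector-valued
/// and unrecognized types return `None`.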
1697fn branch_scalar_kind(branch: &Branch) -> Option<RootScalarKind> {
1698    let type_name = branch.item_type_name();
1699    let lower = type_name.to_ascii_lowercase();
1700    if lower.contains("vector") {
1701        return None;
1702    }
1703    match lower.as_str() {
1704        "float" | "float_t" | "float32_t" => Some(RootScalarKind::F32),
1705        "double" | "double_t" | "double32_t" => Some(RootScalarKind::F64),
1706        _ => None,
1707    }
1708}
1709
1710fn read_branch_values<'a>(lookup: &BranchLookup<'a>, column_name: &str) -> LadduResult<Vec<f64>> {
1711    let (kind, branch) =
1712        lookup
1713            .get(column_name)
1714            .copied()
1715            .ok_or_else(|| LadduError::MissingColumn {
1716                name: column_name.to_string(),
1717            })?;
1718
1719    let values = match kind {
1720        RootScalarKind::F32 => branch
1721            .as_iter::<f32>()
1722            .map_err(|err| map_root_error(&format!("Failed to read branch '{column_name}'"), err))?
1723            .map(|value| value as f64)
1724            .collect(),
1725        RootScalarKind::F64 => branch
1726            .as_iter::<f64>()
1727            .map_err(|err| map_root_error(&format!("Failed to read branch '{column_name}'"), err))?
1728            .collect(),
1729    };
1730    Ok(values)
1731}
1732
1733fn read_branch_values_optional<'a>(
1734    lookup: &BranchLookup<'a>,
1735    column_name: &str,
1736) -> LadduResult<Option<Vec<f64>>> {
1737    if lookup.contains_key(column_name) {
1738        read_branch_values(lookup, column_name).map(Some)
1739    } else {
1740        Ok(None)
1741    }
1742}
1743
1744fn read_branch_values_from_candidates<'a>(
1745    lookup: &BranchLookup<'a>,
1746    candidates: &[String],
1747    logical_name: &str,
1748) -> LadduResult<Vec<f64>> {
1749    for candidate in candidates {
1750        if lookup.contains_key(candidate.as_str()) {
1751            return read_branch_values(lookup, candidate);
1752        }
1753    }
1754    Err(LadduError::MissingColumn {
1755        name: logical_name.to_string(),
1756    })
1757}
1758
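/// Resolves which TTree to read: the requested name if given, otherwise the file's single
/// TTree; multiple trees without an explicit request are an error.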
1759fn resolve_root_tree(
1760    file: &mut RootFile,
1761    requested: Option<&str>,
1762) -> LadduResult<(ReaderTree, String)> {
1763    if let Some(name) = requested {
1764        let tree = file
1765            .get_tree(name)
1766            .map_err(|err| map_root_error(&format!("Failed to open ROOT tree '{name}'"), err))?;
1767        return Ok((tree, name.to_string()));
1768    }
1769
1770    let tree_names: Vec<String> = file
1771        .keys()
1772        .into_iter()
1773        .filter(|key| key.class_name() == "TTree")
1774        .map(|key| key.name().to_string())
1775        .collect();
1776
1777    if tree_names.is_empty() {
1778        return Err(LadduError::Custom(
1779            "ROOT file does not contain any TTrees".to_string(),
1780        ));
1781    }
1782
1783    if tree_names.len() > 1 {
1784        return Err(LadduError::Custom(format!(
1785            "Multiple TTrees found ({:?}); specify DatasetReadOptions::tree to disambiguate",
1786            tree_names
1787        )));
1788    }
1789
1790    let selected = &tree_names[0];
1791    let tree = file
1792        .get_tree(selected)
1793        .map_err(|err| map_root_error(&format!("Failed to open ROOT tree '{selected}'"), err))?;
1794    Ok((tree, selected.clone()))
1795}
1796
1797fn map_root_error<E: std::fmt::Display>(context: &str, err: E) -> LadduError {
1798    LadduError::Custom(format!("{context}: {err}")) // NOTE: the oxyroot error type is not public
1799}
1800
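/// Borrowed view over an f32 or f64 Arrow column whose values are read back as f64.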
1801#[derive(Clone, Copy)]
1802enum FloatColumn<'a> {
1803    F32(&'a Float32Array),
1804    F64(&'a Float64Array),
1805}
1806
1807impl<'a> FloatColumn<'a> {
1808    fn value(&self, row: usize) -> f64 {
1809        match self {
1810            Self::F32(array) => array.value(row) as f64,
1811            Self::F64(array) => array.value(row),
1812        }
1813    }
1814}
1815
1816struct P4Columns<'a> {
1817    px: FloatColumn<'a>,
1818    py: FloatColumn<'a>,
1819    pz: FloatColumn<'a>,
1820    e: FloatColumn<'a>,
1821}
1822
1823fn prepare_float_column<'a>(batch: &'a RecordBatch, name: &str) -> LadduResult<FloatColumn<'a>> {
1824    prepare_float_column_from_candidates(batch, &[name.to_string()], name)
1825}
1826
1827fn prepare_p4_columns<'a>(batch: &'a RecordBatch, name: &str) -> LadduResult<P4Columns<'a>> {
1828    Ok(P4Columns {
1829        px: prepare_float_column_from_candidates(
1830            batch,
1831            &component_candidates(name, "px"),
1832            &format!("{name}_px"),
1833        )?,
1834        py: prepare_float_column_from_candidates(
1835            batch,
1836            &component_candidates(name, "py"),
1837            &format!("{name}_py"),
1838        )?,
1839        pz: prepare_float_column_from_candidates(
1840            batch,
1841            &component_candidates(name, "pz"),
1842            &format!("{name}_pz"),
1843        )?,
1844        e: prepare_float_column_from_candidates(
1845            batch,
1846            &component_candidates(name, "e"),
1847            &format!("{name}_e"),
1848        )?,
1849    })
1850}
1851
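/// Builds the ordered list of candidate column names for one four-momentum component,
/// trying the lowercase suffix first (`name_px`), then the capitalized (`name_Px`) and
/// fully uppercase (`name_PX`) variants.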
1852fn component_candidates(name: &str, suffix: &str) -> Vec<String> {
1853    let mut candidates = Vec::with_capacity(3);
1854    let base = format!("{name}_{suffix}");
1855    candidates.push(base.clone());
1856
1857    let mut capitalized = suffix.to_string();
1858    if let Some(first) = capitalized.get_mut(0..1) {
1859        first.make_ascii_uppercase();
1860    }
1861    if capitalized != suffix {
1862        candidates.push(format!("{name}_{capitalized}"));
1863    }
1864
1865    let upper = suffix.to_ascii_uppercase();
1866    if upper != suffix && upper != capitalized {
1867        candidates.push(format!("{name}_{upper}"));
1868    }
1869
1870    candidates
1871}
1872
1873fn find_float_column_from_candidates<'a>(
1874    batch: &'a RecordBatch,
1875    candidates: &[String],
1876) -> LadduResult<Option<FloatColumn<'a>>> {
1877    use arrow::datatypes::DataType;
1878
1879    for candidate in candidates {
1880        if let Some(column) = batch.column_by_name(candidate) {
1881            return match column.data_type() {
1882                DataType::Float32 => Ok(Some(FloatColumn::F32(
1883                    column
1884                        .as_any()
1885                        .downcast_ref::<Float32Array>()
1886                        .expect("Column advertised as Float32 but could not be downcast"),
1887                ))),
1888                DataType::Float64 => Ok(Some(FloatColumn::F64(
1889                    column
1890                        .as_any()
1891                        .downcast_ref::<Float64Array>()
1892                        .expect("Column advertised as Float64 but could not be downcast"),
1893                ))),
1894                other => {
1895                    return Err(LadduError::InvalidColumnType {
1896                        name: candidate.clone(),
1897                        datatype: other.to_string(),
1898                    })
1899                }
1900            };
1901        }
1902    }
1903    Ok(None)
1904}
1905
1906fn prepare_float_column_from_candidates<'a>(
1907    batch: &'a RecordBatch,
1908    candidates: &[String],
1909    logical_name: &str,
1910) -> LadduResult<FloatColumn<'a>> {
1911    find_float_column_from_candidates(batch, candidates)?.ok_or_else(|| LadduError::MissingColumn {
1912        name: logical_name.to_string(),
1913    })
1914}
1915
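/// Converts one Arrow record batch into owned events. Column lookups (including the case
/// variants from `component_candidates`) are resolved once per batch, and rows without a
/// `weight` column default to a weight of 1.0.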
1916fn record_batch_to_events(
1917    batch: RecordBatch,
1918    p4_names: &[String],
1919    aux_names: &[String],
1920) -> LadduResult<Vec<Arc<EventData>>> {
1921    let batch_ref = &batch;
1922    let p4_columns: Vec<P4Columns<'_>> = p4_names
1923        .iter()
1924        .map(|name| prepare_p4_columns(batch_ref, name))
1925        .collect::<Result<_, _>>()?;
1926
1927    let aux_columns: Vec<FloatColumn<'_>> = aux_names
1928        .iter()
1929        .map(|name| prepare_float_column(batch_ref, name))
1930        .collect::<Result<_, _>>()?;
1931
1932    let weight_column = find_float_column_from_candidates(batch_ref, &["weight".to_string()])?;
1933
1934    let mut events = Vec::with_capacity(batch_ref.num_rows());
1935    for row in 0..batch_ref.num_rows() {
1936        let mut p4s = Vec::with_capacity(p4_columns.len());
1937        for columns in &p4_columns {
1938            let px = columns.px.value(row);
1939            let py = columns.py.value(row);
1940            let pz = columns.pz.value(row);
1941            let e = columns.e.value(row);
1942            p4s.push(Vec4::new(px, py, pz, e));
1943        }
1944
1945        let mut aux = Vec::with_capacity(aux_columns.len());
1946        for column in &aux_columns {
1947            aux.push(column.value(row));
1948        }
1949
1950        let event = EventData {
1951            p4s,
1952            aux,
1953            weight: weight_column
1954                .as_ref()
1955                .map(|column| column.value(row))
1956                .unwrap_or(1.0),
1957        };
1958        events.push(Arc::new(event));
1959    }
1960    Ok(events)
1961}
1962
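/// Column-major staging buffers holding one write batch before it is converted into an
/// Arrow [`RecordBatch`].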
1963struct ColumnBuffers {
1964    p4: Vec<P4Buffer>,
1965    aux: Vec<Vec<f64>>,
1966    weight: Vec<f64>,
1967}
1968
1969impl ColumnBuffers {
1970    fn new(n_p4: usize, n_aux: usize) -> Self {
1971        let p4 = (0..n_p4).map(|_| P4Buffer::default()).collect();
1972        let aux = vec![Vec::new(); n_aux];
1973        Self {
1974            p4,
1975            aux,
1976            weight: Vec::new(),
1977        }
1978    }
1979
1980    fn push_event(&mut self, event: &Event) {
1981        for (buffer, p4) in self.p4.iter_mut().zip(event.p4s.iter()) {
1982            buffer.px.push(p4.x);
1983            buffer.py.push(p4.y);
1984            buffer.pz.push(p4.z);
1985            buffer.e.push(p4.t);
1986        }
1987
1988        for (buffer, value) in self.aux.iter_mut().zip(event.aux.iter()) {
1989            buffer.push(*value);
1990        }
1991
1992        self.weight.push(event.weight);
1993    }
1994
1995    fn into_record_batch(
1996        self,
1997        schema: Arc<Schema>,
1998        precision: FloatPrecision,
1999    ) -> arrow::error::Result<RecordBatch> {
2000        let mut columns: Vec<arrow::array::ArrayRef> = Vec::new();
2001
2002        match precision {
2003            FloatPrecision::F64 => {
2004                for buffer in &self.p4 {
2005                    columns.push(Arc::new(Float64Array::from(buffer.px.clone())));
2006                    columns.push(Arc::new(Float64Array::from(buffer.py.clone())));
2007                    columns.push(Arc::new(Float64Array::from(buffer.pz.clone())));
2008                    columns.push(Arc::new(Float64Array::from(buffer.e.clone())));
2009                }
2010
2011                for buffer in &self.aux {
2012                    columns.push(Arc::new(Float64Array::from(buffer.clone())));
2013                }
2014
2015                columns.push(Arc::new(Float64Array::from(self.weight)));
2016            }
2017            FloatPrecision::F32 => {
2018                for buffer in &self.p4 {
2019                    columns.push(Arc::new(Float32Array::from(
2020                        buffer.px.iter().map(|v| *v as f32).collect::<Vec<_>>(),
2021                    )));
2022                    columns.push(Arc::new(Float32Array::from(
2023                        buffer.py.iter().map(|v| *v as f32).collect::<Vec<_>>(),
2024                    )));
2025                    columns.push(Arc::new(Float32Array::from(
2026                        buffer.pz.iter().map(|v| *v as f32).collect::<Vec<_>>(),
2027                    )));
2028                    columns.push(Arc::new(Float32Array::from(
2029                        buffer.e.iter().map(|v| *v as f32).collect::<Vec<_>>(),
2030                    )));
2031                }
2032
2033                for buffer in &self.aux {
2034                    columns.push(Arc::new(Float32Array::from(
2035                        buffer.iter().map(|v| *v as f32).collect::<Vec<_>>(),
2036                    )));
2037                }
2038
2039                columns.push(Arc::new(Float32Array::from(
2040                    self.weight.iter().map(|v| *v as f32).collect::<Vec<_>>(),
2041                )));
2042            }
2043        }
2044
2045        RecordBatch::try_new(schema, columns)
2046    }
2047}
2048
2049#[derive(Default)]
2050struct P4Buffer {
2051    px: Vec<f64>,
2052    py: Vec<f64>,
2053    pz: Vec<f64>,
2054    e: Vec<f64>,
2055}
2056
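/// Builds the flat Parquet schema: four non-nullable component columns per particle
/// (`<name>_px` through `<name>_e`), one column per auxiliary scalar, and a trailing
/// `weight` column, all in the requested precision.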
2057fn build_parquet_schema(metadata: &DatasetMetadata, precision: FloatPrecision) -> Schema {
2058    let dtype = match precision {
2059        FloatPrecision::F64 => DataType::Float64,
2060        FloatPrecision::F32 => DataType::Float32,
2061    };
2062
2063    let mut fields = Vec::new();
2064    for name in &metadata.p4_names {
2065        for suffix in P4_COMPONENT_SUFFIXES {
2066            fields.push(Field::new(format!("{name}{suffix}"), dtype.clone(), false));
2067        }
2068    }
2069
2070    for name in &metadata.aux_names {
2071        fields.push(Field::new(name.clone(), dtype.clone(), false));
2072    }
2073
2074    fields.push(Field::new("weight", dtype, false));
2075    Schema::new(fields)
2076}
2077
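/// Narrowing conversion used to emit ROOT branches in either precision from the f64
/// values stored in each event.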
2078trait FromF64 {
2079    fn from_f64(value: f64) -> Self;
2080}
2081
2082impl FromF64 for f64 {
2083    fn from_f64(value: f64) -> Self {
2084        value
2085    }
2086}
2087
2088impl FromF64 for f32 {
2089    fn from_f64(value: f64) -> Self {
2090        value as f32
2091    }
2092}
2093
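/// Shared event source for the per-branch column iterators. Every branch iterator
/// requests the same event index in turn, so the fetcher caches the current event and
/// counts down `remaining` from `branch_count`, ensuring each index is fetched (possibly
/// over MPI) exactly once per row.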
2094struct SharedEventFetcher {
2095    dataset: Arc<Dataset>,
2096    world: Option<WorldHandle>,
2097    total: usize,
2098    branch_count: usize,
2099    current_index: Option<usize>,
2100    current_event: Option<Event>,
2101    remaining: usize,
2102}
2103
2104impl SharedEventFetcher {
2105    fn new(
2106        dataset: Arc<Dataset>,
2107        world: Option<WorldHandle>,
2108        total: usize,
2109        branch_count: usize,
2110    ) -> Self {
2111        Self {
2112            dataset,
2113            world,
2114            total,
2115            branch_count,
2116            current_index: None,
2117            current_event: None,
2118            remaining: 0,
2119        }
2120    }
2121
2122    fn event_for_index(&mut self, index: usize) -> Option<Event> {
2123        if index >= self.total {
2124            return None;
2125        }
2126
2127        let refresh_needed = match self.current_index {
2128            None => true,
2129            Some(current) => current != index || self.remaining == 0,
2130        };
2131
2132        if refresh_needed {
2133            let event =
2134                fetch_event_for_index(&self.dataset, index, self.total, self.world.as_ref());
2135            self.current_index = Some(index);
2136            self.remaining = self.branch_count;
2137            self.current_event = Some(event);
2138        }
2139
        let event = self.current_event.clone();
2141        if self.remaining > 0 {
2142            self.remaining -= 1;
2143        }
2144        if self.remaining == 0 {
2145            // Drop the cached event so the next request fetches the next index.
2146            self.current_event = None;
2147        }
2148        event
2149    }
2150}
2151
2152enum ColumnKind {
2153    Px(usize),
2154    Py(usize),
2155    Pz(usize),
2156    E(usize),
2157    Aux(usize),
2158    Weight,
2159}
2160
2161struct ColumnIterator<T> {
2162    fetcher: Arc<Mutex<SharedEventFetcher>>,
2163    index: usize,
2164    kind: ColumnKind,
2165    _marker: std::marker::PhantomData<T>,
2166}
2167
2168impl<T> ColumnIterator<T> {
2169    fn new(fetcher: Arc<Mutex<SharedEventFetcher>>, kind: ColumnKind) -> Self {
2170        Self {
2171            fetcher,
2172            index: 0,
2173            kind,
2174            _marker: std::marker::PhantomData,
2175        }
2176    }
2177}
2178
2179impl<T> Iterator for ColumnIterator<T>
2180where
2181    T: FromF64,
2182{
2183    type Item = T;
2184
2185    fn next(&mut self) -> Option<Self::Item> {
2186        let mut fetcher = self.fetcher.lock();
2187        let event = fetcher.event_for_index(self.index)?;
2188        self.index += 1;
2189
2190        match self.kind {
2191            ColumnKind::Px(idx) => event.p4s.get(idx).map(|p4| T::from_f64(p4.x)),
2192            ColumnKind::Py(idx) => event.p4s.get(idx).map(|p4| T::from_f64(p4.y)),
2193            ColumnKind::Pz(idx) => event.p4s.get(idx).map(|p4| T::from_f64(p4.z)),
2194            ColumnKind::E(idx) => event.p4s.get(idx).map(|p4| T::from_f64(p4.t)),
2195            ColumnKind::Aux(idx) => event.aux.get(idx).map(|value| T::from_f64(*value)),
2196            ColumnKind::Weight => Some(T::from_f64(event.weight)),
2197        }
2198    }
2199}
2200
2201fn build_root_column_iterators<T>(
2202    dataset: Arc<Dataset>,
2203    world: Option<WorldHandle>,
2204    branch_count: usize,
2205    total: usize,
2206) -> Vec<(String, ColumnIterator<T>)>
2207where
2208    T: FromF64,
2209{
2210    let fetcher = Arc::new(Mutex::new(SharedEventFetcher::new(
2211        dataset,
2212        world,
2213        total,
2214        branch_count,
2215    )));
2216
2217    let p4_names: Vec<String> = fetcher.lock().dataset.metadata.p4_names.clone();
2218    let aux_names: Vec<String> = fetcher.lock().dataset.metadata.aux_names.clone();
2219
2220    let mut iterators = Vec::new();
2221
2222    for (idx, name) in p4_names.iter().enumerate() {
2223        iterators.push((
2224            format!("{name}_px"),
2225            ColumnIterator::new(fetcher.clone(), ColumnKind::Px(idx)),
2226        ));
2227        iterators.push((
2228            format!("{name}_py"),
2229            ColumnIterator::new(fetcher.clone(), ColumnKind::Py(idx)),
2230        ));
2231        iterators.push((
2232            format!("{name}_pz"),
2233            ColumnIterator::new(fetcher.clone(), ColumnKind::Pz(idx)),
2234        ));
2235        iterators.push((
2236            format!("{name}_e"),
2237            ColumnIterator::new(fetcher.clone(), ColumnKind::E(idx)),
2238        ));
2239    }
2240
2241    for (idx, name) in aux_names.iter().enumerate() {
2242        iterators.push((
2243            name.clone(),
2244            ColumnIterator::new(fetcher.clone(), ColumnKind::Aux(idx)),
2245        ));
2246    }
2247
2248    iterators.push((
2249        "weight".to_string(),
2250        ColumnIterator::new(fetcher, ColumnKind::Weight),
2251    ));
2252
2253    iterators
2254}
2255
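/// Advances every column iterator through all events without writing anything. Non-root
/// MPI ranks call this so that their side of each distributed event fetch still runs
/// while only the root rank writes the file.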
2256fn drain_column_iterators<T>(iterators: &mut [(String, ColumnIterator<T>)], n_events: usize)
2257where
2258    T: FromF64,
2259{
2260    for _ in 0..n_events {
2261        for (_name, iterator) in iterators.iter_mut() {
2262            let _ = iterator.next();
2263        }
2264    }
2265}
2266
2267fn fetch_event_for_index(
2268    dataset: &Dataset,
2269    index: usize,
2270    total: usize,
2271    world: Option<&WorldHandle>,
2272) -> Event {
2273    let _ = total;
2274    let _ = world;
2275    #[cfg(feature = "mpi")]
2276    {
2277        if let Some(world) = world {
2278            return fetch_event_mpi(dataset, index, world, total);
2279        }
2280    }
2281
2282    dataset.index_local(index).clone()
2283}
2284
2285impl Dataset {
2286    #[allow(clippy::too_many_arguments)]
2287    fn write_root_with_type<T>(
2288        &self,
2289        dataset: Arc<Dataset>,
2290        world: Option<WorldHandle>,
2291        is_root: bool,
2292        file_path: &Path,
2293        tree_name: &str,
2294        branch_count: usize,
2295        total_events: usize,
2296    ) -> LadduResult<()>
2297    where
2298        T: FromF64 + oxyroot::Marshaler + 'static,
2299    {
2300        let mut iterators =
2301            build_root_column_iterators::<T>(dataset, world, branch_count, total_events);
2302
2303        if is_root {
2304            let mut file = RootFile::create(file_path).map_err(|err| {
2305                LadduError::Custom(format!(
2306                    "Failed to create ROOT file '{}': {err}",
2307                    file_path.display()
2308                ))
2309            })?;
2310
2311            let mut tree = WriterTree::new(tree_name);
2312            for (name, iterator) in iterators {
2313                tree.new_branch(name, iterator);
2314            }
2315
2316            tree.write(&mut file).map_err(|err| {
2317                LadduError::Custom(format!(
2318                    "Failed to write ROOT tree '{tree_name}' to '{}': {err}",
2319                    file_path.display()
2320                ))
2321            })?;
2322
2323            file.close().map_err(|err| {
2324                LadduError::Custom(format!(
2325                    "Failed to close ROOT file '{}': {err}",
2326                    file_path.display()
2327                ))
2328            })?;
2329        } else {
2330            drain_column_iterators(&mut iterators, total_events);
2331        }
2332
2333        Ok(())
2334    }
2335}
2336
2337/// A list of [`Dataset`]s formed by binning [`EventData`] by some [`Variable`].
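///
/// # Example
///
/// A sketch of binning by a variable and inspecting the result (`beam_energy` stands in
/// for any [`Variable`] implementation):
///
/// ```ignore
/// let binned = dataset.bin_by(beam_energy, 10, (0.0, 3.0))?;
/// assert_eq!(binned.n_bins(), 10);
/// for bin in binned.iter() {
///     println!("{} weighted events", bin.n_events_weighted());
/// }
/// ```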
2338pub struct BinnedDataset {
2339    datasets: Vec<Arc<Dataset>>,
2340    edges: Vec<f64>,
2341}
2342
2343impl Index<usize> for BinnedDataset {
2344    type Output = Arc<Dataset>;
2345
2346    fn index(&self, index: usize) -> &Self::Output {
2347        &self.datasets[index]
2348    }
2349}
2350
2351impl IndexMut<usize> for BinnedDataset {
2352    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
2353        &mut self.datasets[index]
2354    }
2355}
2356
2357impl Deref for BinnedDataset {
2358    type Target = Vec<Arc<Dataset>>;
2359
2360    fn deref(&self) -> &Self::Target {
2361        &self.datasets
2362    }
2363}
2364
2365impl DerefMut for BinnedDataset {
2366    fn deref_mut(&mut self) -> &mut Self::Target {
2367        &mut self.datasets
2368    }
2369}
2370
2371impl BinnedDataset {
2372    /// The number of bins in the [`BinnedDataset`].
2373    pub fn n_bins(&self) -> usize {
2374        self.datasets.len()
2375    }
2376
2377    /// Returns a list of the bin edges that were used to form the [`BinnedDataset`].
2378    pub fn edges(&self) -> Vec<f64> {
2379        self.edges.clone()
2380    }
2381
2382    /// Returns the range that was used to form the [`BinnedDataset`].
2383    pub fn range(&self) -> (f64, f64) {
2384        (self.edges[0], self.edges[self.n_bins()])
2385    }
2386}
2387
2388#[cfg(test)]
2389mod tests {
2390    use crate::Mass;
2391
2392    use super::*;
2393    use crate::utils::vectors::Vec3;
2394    use approx::{assert_relative_eq, assert_relative_ne};
2396    use serde::{Deserialize, Serialize};
2397    use std::{
2398        env, fs,
2399        path::{Path, PathBuf},
2400    };
2401
2402    fn test_data_path(file: &str) -> PathBuf {
2403        Path::new(env!("CARGO_MANIFEST_DIR"))
2404            .join("test_data")
2405            .join(file)
2406    }
2407
2408    fn open_test_dataset(file: &str, options: DatasetReadOptions) -> Arc<Dataset> {
2409        let path = test_data_path(file);
2410        let path_str = path.to_str().expect("test data path should be valid UTF-8");
2411        let ext = path
2412            .extension()
2413            .and_then(|ext| ext.to_str())
2414            .unwrap_or_default()
2415            .to_ascii_lowercase();
2416        match ext.as_str() {
2417            "parquet" => Dataset::from_parquet(path_str, &options),
2418            "root" => Dataset::from_root(path_str, &options),
2419            other => panic!("Unsupported extension in test data: {other}"),
2420        }
2421        .expect("dataset should open")
2422    }
2423
2424    fn make_temp_dir() -> PathBuf {
2425        let dir = env::temp_dir().join(format!("laddu_test_{}", fastrand::u64(..)));
2426        fs::create_dir(&dir).expect("temp dir should be created");
2427        dir
2428    }
2429
2430    fn assert_events_close(left: &Event, right: &Event, p4_names: &[&str], aux_names: &[&str]) {
2431        for name in p4_names {
2432            let lp4 = left
2433                .p4(name)
2434                .unwrap_or_else(|| panic!("missing p4 '{name}' in left dataset"));
2435            let rp4 = right
2436                .p4(name)
2437                .unwrap_or_else(|| panic!("missing p4 '{name}' in right dataset"));
2438            assert_relative_eq!(lp4.px(), rp4.px(), epsilon = 1e-9);
2439            assert_relative_eq!(lp4.py(), rp4.py(), epsilon = 1e-9);
2440            assert_relative_eq!(lp4.pz(), rp4.pz(), epsilon = 1e-9);
2441            assert_relative_eq!(lp4.e(), rp4.e(), epsilon = 1e-9);
2442        }
2443        let left_aux = left.aux();
2444        let right_aux = right.aux();
2445        for name in aux_names {
2446            let laux = left_aux
2447                .get(name)
2448                .copied()
2449                .unwrap_or_else(|| panic!("missing aux '{name}' in left dataset"));
2450            let raux = right_aux
2451                .get(name)
2452                .copied()
2453                .unwrap_or_else(|| panic!("missing aux '{name}' in right dataset"));
2454            assert_relative_eq!(laux, raux, epsilon = 1e-9);
2455        }
2456        assert_relative_eq!(left.weight(), right.weight(), epsilon = 1e-9);
2457    }
2458
2459    fn assert_datasets_close(
2460        left: &Arc<Dataset>,
2461        right: &Arc<Dataset>,
2462        p4_names: &[&str],
2463        aux_names: &[&str],
2464    ) {
2465        assert_eq!(left.n_events(), right.n_events());
2466        for idx in 0..left.n_events() {
2467            let levent = &left[idx];
2468            let revent = &right[idx];
2469            assert_events_close(levent, revent, p4_names, aux_names);
2470        }
2471    }
2472
2473    #[test]
2474    fn test_from_parquet_auto_matches_explicit_names() {
2475        let auto = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
2476        let explicit_options = DatasetReadOptions::new()
2477            .p4_names(TEST_P4_NAMES)
2478            .aux_names(TEST_AUX_NAMES);
2479        let explicit = open_test_dataset("data_f32.parquet", explicit_options);
2480
2481        let mut detected_p4: Vec<&str> = auto.p4_names().iter().map(String::as_str).collect();
2482        detected_p4.sort_unstable();
2483        let mut expected_p4 = TEST_P4_NAMES.to_vec();
2484        expected_p4.sort_unstable();
2485        assert_eq!(detected_p4, expected_p4);
2486        let mut detected_aux: Vec<&str> = auto.aux_names().iter().map(String::as_str).collect();
2487        detected_aux.sort_unstable();
2488        let mut expected_aux = TEST_AUX_NAMES.to_vec();
2489        expected_aux.sort_unstable();
2490        assert_eq!(detected_aux, expected_aux);
2491        assert_datasets_close(&auto, &explicit, TEST_P4_NAMES, TEST_AUX_NAMES);
2492    }
2493
2494    #[test]
2495    fn test_from_parquet_with_aliases() {
2496        let dataset = open_test_dataset(
2497            "data_f32.parquet",
2498            DatasetReadOptions::new().alias("resonance", ["kshort1", "kshort2"]),
2499        );
2500        let event = dataset.named_event(0);
2501        let alias_vec = event.p4("resonance").expect("alias vector");
2502        let expected = event.get_p4_sum(["kshort1", "kshort2"]);
2503        assert_relative_eq!(alias_vec.px(), expected.px(), epsilon = 1e-9);
2504        assert_relative_eq!(alias_vec.py(), expected.py(), epsilon = 1e-9);
2505        assert_relative_eq!(alias_vec.pz(), expected.pz(), epsilon = 1e-9);
2506        assert_relative_eq!(alias_vec.e(), expected.e(), epsilon = 1e-9);
2507    }
2508
2509    #[test]
2510    fn test_from_parquet_f64_matches_f32() {
2511        let f32_ds = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
2512        let f64_ds = open_test_dataset("data_f64.parquet", DatasetReadOptions::new());
2513        assert_datasets_close(&f64_ds, &f32_ds, TEST_P4_NAMES, TEST_AUX_NAMES);
2514    }
2515
2516    #[test]
2517    fn test_from_root_detects_columns_and_matches_parquet() {
2518        let parquet = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
2519        let root_auto = open_test_dataset("data_f32.root", DatasetReadOptions::new());
2520        let mut detected_p4: Vec<&str> = root_auto.p4_names().iter().map(String::as_str).collect();
2521        detected_p4.sort_unstable();
2522        let mut expected_p4 = TEST_P4_NAMES.to_vec();
2523        expected_p4.sort_unstable();
2524        assert_eq!(detected_p4, expected_p4);
2525        let mut detected_aux: Vec<&str> =
2526            root_auto.aux_names().iter().map(String::as_str).collect();
2527        detected_aux.sort_unstable();
2528        let mut expected_aux = TEST_AUX_NAMES.to_vec();
2529        expected_aux.sort_unstable();
2530        assert_eq!(detected_aux, expected_aux);
2531        let root_named_options = DatasetReadOptions::new()
2532            .p4_names(TEST_P4_NAMES)
2533            .aux_names(TEST_AUX_NAMES);
2534        let root_named = open_test_dataset("data_f32.root", root_named_options);
2535        assert_datasets_close(&root_auto, &root_named, TEST_P4_NAMES, TEST_AUX_NAMES);
2536        assert_datasets_close(&root_auto, &parquet, TEST_P4_NAMES, TEST_AUX_NAMES);
2537    }
2538
2539    #[test]
2540    fn test_from_root_f64_matches_parquet() {
2541        let parquet = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
2542        let root_f64 = open_test_dataset("data_f64.root", DatasetReadOptions::new());
2543        assert_datasets_close(&root_f64, &parquet, TEST_P4_NAMES, TEST_AUX_NAMES);
    }

    #[test]
2546    fn test_event_creation() {
2547        let event = test_event();
2548        assert_eq!(event.p4s.len(), 4);
2549        assert_eq!(event.aux.len(), 2);
        assert_relative_eq!(event.weight, 0.48);
2551    }
2552
2553    #[test]
2554    fn test_event_p4_sum() {
2555        let event = test_event();
2556        let sum = event.get_p4_sum([2, 3]);
2557        assert_relative_eq!(sum.px(), event.p4s[2].px() + event.p4s[3].px());
2558        assert_relative_eq!(sum.py(), event.p4s[2].py() + event.p4s[3].py());
2559        assert_relative_eq!(sum.pz(), event.p4s[2].pz() + event.p4s[3].pz());
2560        assert_relative_eq!(sum.e(), event.p4s[2].e() + event.p4s[3].e());
2561    }
2562
2563    #[test]
2564    fn test_event_boost() {
2565        let event = test_event();
2566        let event_boosted = event.boost_to_rest_frame_of([1, 2, 3]);
2567        let p4_sum = event_boosted.get_p4_sum([1, 2, 3]);
2568        assert_relative_eq!(p4_sum.px(), 0.0);
2569        assert_relative_eq!(p4_sum.py(), 0.0);
2570        assert_relative_eq!(p4_sum.pz(), 0.0, epsilon = f64::EPSILON.sqrt());
2571    }
2572
2573    #[test]
2574    fn test_event_evaluate() {
2575        let event = test_event();
2576        let mut mass = Mass::new(["proton"]);
2577        mass.bind(
2578            &DatasetMetadata::new(
2579                TEST_P4_NAMES.iter().map(|s| (*s).to_string()).collect(),
2580                TEST_AUX_NAMES.iter().map(|s| (*s).to_string()).collect(),
2581            )
2582            .expect("metadata"),
2583        )
2584        .unwrap();
2585        assert_relative_eq!(event.evaluate(&mass), 1.007);
2586    }
2587
2588    #[test]
2589    fn test_dataset_size_check() {
2590        let dataset = Dataset::new(Vec::new());
2591        assert_eq!(dataset.n_events(), 0);
2592        let dataset = Dataset::new(vec![Arc::new(test_event())]);
2593        assert_eq!(dataset.n_events(), 1);
2594    }
2595
2596    #[test]
2597    fn test_dataset_sum() {
2598        let dataset = test_dataset();
2599        let metadata = dataset.metadata_arc();
2600        let dataset2 = Dataset::new_with_metadata(
2601            vec![Arc::new(EventData {
2602                p4s: test_event().p4s,
2603                aux: test_event().aux,
2604                weight: 0.52,
2605            })],
2606            metadata.clone(),
2607        );
2608        let dataset_sum = &dataset + &dataset2;
2609        assert_eq!(dataset_sum[0].weight, dataset[0].weight);
2610        assert_eq!(dataset_sum[1].weight, dataset2[0].weight);
2611    }
2612
2613    #[test]
2614    fn test_dataset_weights() {
2615        let dataset = Dataset::new(vec![
2616            Arc::new(test_event()),
2617            Arc::new(EventData {
2618                p4s: test_event().p4s,
2619                aux: test_event().aux,
2620                weight: 0.52,
2621            }),
2622        ]);
2623        let weights = dataset.weights();
2624        assert_eq!(weights.len(), 2);
2625        assert_relative_eq!(weights[0], 0.48);
2626        assert_relative_eq!(weights[1], 0.52);
2627        assert_relative_eq!(dataset.n_events_weighted(), 1.0);
2628    }
2629
2630    #[test]
2631    fn test_dataset_filtering() {
2632        let metadata = Arc::new(
2633            DatasetMetadata::new(vec!["beam"], Vec::<String>::new())
2634                .expect("metadata should be valid"),
2635        );
2636        let events = vec![
2637            Arc::new(EventData {
2638                p4s: vec![Vec3::new(0.0, 0.0, 5.0).with_mass(0.0)],
2639                aux: vec![],
2640                weight: 1.0,
2641            }),
2642            Arc::new(EventData {
2643                p4s: vec![Vec3::new(0.0, 0.0, 5.0).with_mass(0.5)],
2644                aux: vec![],
2645                weight: 1.0,
2646            }),
2647            Arc::new(EventData {
2648                p4s: vec![Vec3::new(0.0, 0.0, 5.0).with_mass(1.1)],
2649                // HACK: using 1.0 messes with this test because the eventual computation gives a mass
2650                // slightly less than 1.0
2651                aux: vec![],
2652                weight: 1.0,
2653            }),
2654        ];
2655        let dataset = Dataset::new_with_metadata(events, metadata);
2656
2657        let metadata = dataset.metadata_arc();
2658        let mut mass = Mass::new(["beam"]);
2659        mass.bind(metadata.as_ref()).unwrap();
2660        let expression = mass.gt(0.0).and(&mass.lt(1.0));
2661
2662        let filtered = dataset.filter(&expression).unwrap();
2663        assert_eq!(filtered.n_events(), 1);
2664        assert_relative_eq!(mass.value(&filtered[0]), 0.5);
2665    }
2666
2667    #[test]
2668    fn test_dataset_boost() {
2669        let dataset = test_dataset();
2670        let dataset_boosted = dataset.boost_to_rest_frame_of(&["proton", "kshort1", "kshort2"]);
2671        let p4_sum = dataset_boosted[0].get_p4_sum(["proton", "kshort1", "kshort2"]);
2672        assert_relative_eq!(p4_sum.px(), 0.0);
2673        assert_relative_eq!(p4_sum.py(), 0.0);
2674        assert_relative_eq!(p4_sum.pz(), 0.0, epsilon = f64::EPSILON.sqrt());
2675    }
2676
2677    #[test]
2678    fn test_named_event_view() {
2679        let dataset = test_dataset();
2680        let view = dataset.named_event(0);
2681
2682        assert_relative_eq!(view.weight(), dataset[0].weight);
2683        let beam = view.p4("beam").expect("beam p4");
2684        assert_relative_eq!(beam.px(), dataset[0].p4s[0].px());
2685        assert_relative_eq!(beam.e(), dataset[0].p4s[0].e());
2686
2687        let summed = view.get_p4_sum(["kshort1", "kshort2"]);
2688        assert_relative_eq!(summed.e(), dataset[0].p4s[2].e() + dataset[0].p4s[3].e());
2689
2690        let aux_angle = view.aux().get("pol_angle").copied().expect("pol angle");
2691        assert_relative_eq!(aux_angle, dataset[0].aux[1]);
2692
2693        let metadata = dataset.metadata_arc();
2694        let boosted = view.boost_to_rest_frame_of(["proton", "kshort1", "kshort2"]);
2695        let boosted_event = Event::new(Arc::new(boosted), metadata);
2696        let boosted_sum = boosted_event.get_p4_sum(["proton", "kshort1", "kshort2"]);
2697        assert_relative_eq!(boosted_sum.px(), 0.0);
2698    }
2699
2700    #[test]
2701    fn test_dataset_evaluate() {
2702        let dataset = test_dataset();
2703        let mass = Mass::new(["proton"]);
2704        assert_relative_eq!(dataset.evaluate(&mass).unwrap()[0], 1.007);
2705    }
2706
2707    #[test]
2708    fn test_dataset_metadata_rejects_duplicate_names() {
2709        let err = DatasetMetadata::new(vec!["beam", "beam"], Vec::<String>::new());
2710        assert!(matches!(
2711            err,
2712            Err(LadduError::DuplicateName { category, .. }) if category == "p4"
2713        ));
2714        let err = DatasetMetadata::new(
2715            vec!["beam"],
2716            vec!["pol_angle".to_string(), "pol_angle".to_string()],
2717        );
2718        assert!(matches!(
2719            err,
2720            Err(LadduError::DuplicateName { category, .. }) if category == "aux"
2721        ));
2722    }
2723
2724    #[test]
2725    fn test_dataset_lookup_by_name() {
2726        let dataset = test_dataset();
2727        let proton = dataset.p4_by_name(0, "proton").expect("proton p4");
2728        let proton_idx = dataset.metadata().p4_index("proton").unwrap();
2729        assert_relative_eq!(proton.e(), dataset[0].p4s[proton_idx].e());
2730        assert!(dataset.p4_by_name(0, "unknown").is_none());
2731        let angle = dataset.aux_by_name(0, "pol_angle").expect("pol_angle");
2732        assert_relative_eq!(angle, dataset[0].aux[1]);
2733        assert!(dataset.aux_by_name(0, "missing").is_none());
2734    }
2735
2736    #[test]
2737    fn test_binned_dataset() {
2738        let dataset = Dataset::new(vec![
2739            Arc::new(EventData {
2740                p4s: vec![Vec3::new(0.0, 0.0, 1.0).with_mass(1.0)],
2741                aux: vec![],
2742                weight: 1.0,
2743            }),
2744            Arc::new(EventData {
2745                p4s: vec![Vec3::new(0.0, 0.0, 2.0).with_mass(2.0)],
2746                aux: vec![],
2747                weight: 2.0,
2748            }),
2749        ]);
2750
2751        #[derive(Clone, Serialize, Deserialize, Debug)]
2752        struct BeamEnergy;
2753        impl Display for BeamEnergy {
2754            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2755                write!(f, "BeamEnergy")
2756            }
2757        }
2758        #[typetag::serde]
2759        impl Variable for BeamEnergy {
2760            fn value(&self, event: &EventData) -> f64 {
2761                event.p4s[0].e()
2762            }
2763        }
2764        assert_eq!(BeamEnergy.to_string(), "BeamEnergy");
2765
2766        // Test binning by first particle energy
2767        let binned = dataset.bin_by(BeamEnergy, 2, (0.0, 3.0)).unwrap();
2768
2769        assert_eq!(binned.n_bins(), 2);
2770        assert_eq!(binned.edges().len(), 3);
2771        assert_relative_eq!(binned.edges()[0], 0.0);
2772        assert_relative_eq!(binned.edges()[2], 3.0);
2773        assert_eq!(binned[0].n_events(), 1);
2774        assert_relative_eq!(binned[0].n_events_weighted(), 1.0);
2775        assert_eq!(binned[1].n_events(), 1);
2776        assert_relative_eq!(binned[1].n_events_weighted(), 2.0);
2777    }
2778
2779    #[test]
2780    fn test_dataset_bootstrap() {
2781        let metadata = test_dataset().metadata_arc();
2782        let dataset = Dataset::new_with_metadata(
2783            vec![
2784                Arc::new(test_event()),
2785                Arc::new(EventData {
2786                    p4s: test_event().p4s.clone(),
2787                    aux: test_event().aux.clone(),
2788                    weight: 1.0,
2789                }),
2790            ],
2791            metadata,
2792        );
2793        assert_relative_ne!(dataset[0].weight, dataset[1].weight);
2794
2795        let bootstrapped = dataset.bootstrap(43);
2796        assert_eq!(bootstrapped.n_events(), dataset.n_events());
2797        assert_relative_eq!(bootstrapped[0].weight, bootstrapped[1].weight);
2798
2799        // Test empty dataset bootstrap
2800        let empty_dataset = Dataset::new(Vec::new());
2801        let empty_bootstrap = empty_dataset.bootstrap(43);
2802        assert_eq!(empty_bootstrap.n_events(), 0);
2803    }
2804
2805    #[test]
2806    fn test_dataset_iteration_returns_events() {
2807        let dataset = test_dataset();
2808        let mut weights = Vec::new();
2809        for event in dataset.iter() {
2810            weights.push(event.weight());
2811        }
2812        assert_eq!(weights.len(), dataset.n_events());
2813        assert_relative_eq!(weights[0], dataset[0].weight);
2814    }
2815
2816    #[test]
2817    fn test_dataset_into_iter_returns_events() {
2818        let dataset = test_dataset();
2819        let weights: Vec<f64> = dataset.into_iter().map(|event| event.weight()).collect();
2820        assert_eq!(weights.len(), 1);
2821        assert_relative_eq!(weights[0], test_event().weight);
    }

    #[test]
2824    fn test_event_display() {
2825        let event = test_event();
2826        let display_string = format!("{}", event);
2827        assert!(display_string.contains("Event:"));
2828        assert!(display_string.contains("p4s:"));
2829        assert!(display_string.contains("aux:"));
2830        assert!(display_string.contains("aux[0]: 0.38562805"));
2831        assert!(display_string.contains("aux[1]: 0.05708078"));
2832        assert!(display_string.contains("weight:"));
2833    }
2834
2835    #[test]
2836    fn test_name_based_access() {
2837        let metadata =
2838            Arc::new(DatasetMetadata::new(vec!["beam", "target"], vec!["pol_angle"]).unwrap());
2839        let event = Arc::new(EventData {
2840            p4s: vec![Vec4::new(0.0, 0.0, 1.0, 1.0), Vec4::new(0.1, 0.2, 0.3, 0.5)],
2841            aux: vec![0.42],
2842            weight: 1.0,
2843        });
2844        let dataset = Dataset::new_with_metadata(vec![event], metadata);
2845        let beam = dataset.p4_by_name(0, "beam").unwrap();
2846        assert_relative_eq!(beam.px(), 0.0);
2847        assert_relative_eq!(beam.py(), 0.0);
2848        assert_relative_eq!(beam.pz(), 1.0);
2849        assert_relative_eq!(beam.e(), 1.0);
2850        assert_relative_eq!(dataset.aux_by_name(0, "pol_angle").unwrap(), 0.42);
2851        assert!(dataset.p4_by_name(0, "missing").is_none());
2852        assert!(dataset.aux_by_name(0, "missing").is_none());
2853    }
2854
2855    #[test]
2856    fn test_parquet_roundtrip_to_tempfile() {
2857        let dataset = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
2858        let dir = make_temp_dir();
2859        let path = dir.join("roundtrip.parquet");
2860        let path_str = path.to_str().expect("path should be valid UTF-8");
2861
2862        dataset
2863            .to_parquet(path_str, &DatasetWriteOptions::default())
2864            .expect("writing parquet should succeed");
2865        let reopened = Dataset::from_parquet(path_str, &DatasetReadOptions::new())
2866            .expect("parquet roundtrip should reopen");
2867
2868        assert_datasets_close(&dataset, &reopened, TEST_P4_NAMES, TEST_AUX_NAMES);
2869        fs::remove_dir_all(&dir).expect("temp dir cleanup should succeed");
2870    }
2871
2872    #[test]
2873    fn test_root_roundtrip_to_tempfile() {
2874        let dataset = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
2875        let dir = make_temp_dir();
2876        let path = dir.join("roundtrip.root");
2877        let path_str = path.to_str().expect("path should be valid UTF-8");
2878
2879        dataset
2880            .to_root(path_str, &DatasetWriteOptions::default())
2881            .expect("writing root should succeed");
2882        let reopened = Dataset::from_root(path_str, &DatasetReadOptions::new())
2883            .expect("root roundtrip should reopen");
2884
2885        assert_datasets_close(&dataset, &reopened, TEST_P4_NAMES, TEST_AUX_NAMES);
2886        fs::remove_dir_all(&dir).expect("temp dir cleanup should succeed");
2887    }
2888}