// laddu_core/data.rs

use accurate::{sum::Klein, traits::*};
use arrow::{
    array::{Float32Array, Float64Array},
    datatypes::{DataType, Field, Schema},
    record_batch::RecordBatch,
};
use auto_ops::impl_op_ex;
use parking_lot::Mutex;
use parquet::arrow::{arrow_reader::ParquetRecordBatchReaderBuilder, ArrowWriter};
#[cfg(feature = "mpi")]
use parquet::file::metadata::ParquetMetaData;
use serde::{Deserialize, Serialize};
use std::ops::{Deref, DerefMut, Index, IndexMut};
use std::path::Path;
use std::{fmt::Display, fs::File};
use std::{path::PathBuf, sync::Arc};

use oxyroot::{Branch, Named, ReaderTree, RootFile, WriterTree};

#[cfg(feature = "rayon")]
use rayon::prelude::*;

#[cfg(feature = "mpi")]
use mpi::{datatype::PartitionMut, topology::SimpleCommunicator, traits::*};

#[cfg(feature = "mpi")]
use crate::mpi::LadduMPI;

#[cfg(feature = "mpi")]
type WorldHandle = SimpleCommunicator;
#[cfg(not(feature = "mpi"))]
type WorldHandle = ();

use crate::utils::get_bin_edges;
use crate::{
    utils::{
        variables::{IntoP4Selection, P4Selection, Variable, VariableExpression},
        vectors::Vec4,
    },
    LadduError, LadduResult,
};
use indexmap::{IndexMap, IndexSet};

/// An event that can be used to test the implementation of an
/// [`Amplitude`](crate::amplitudes::Amplitude). This particular event contains the reaction
/// $`\gamma p \to K_S^0 K_S^0 p`$ with a polarized photon beam.
pub fn test_event() -> EventData {
    use crate::utils::vectors::*;
    let pol_magnitude = 0.38562805;
    let pol_angle = 0.05708078;
    EventData {
        p4s: vec![
            Vec3::new(0.0, 0.0, 8.747).with_mass(0.0),         // beam
            Vec3::new(0.119, 0.374, 0.222).with_mass(1.007),   // "proton"
            Vec3::new(-0.112, 0.293, 3.081).with_mass(0.498),  // "kaon"
            Vec3::new(-0.007, -0.667, 5.446).with_mass(0.498), // "kaon"
        ],
        aux: vec![pol_magnitude, pol_angle],
        weight: 0.48,
    }
}

/// Particle names used by [`test_dataset`].
pub const TEST_P4_NAMES: &[&str] = &["beam", "proton", "kshort1", "kshort2"];
/// Auxiliary scalar names used by [`test_dataset`].
pub const TEST_AUX_NAMES: &[&str] = &["pol_magnitude", "pol_angle"];

/// A dataset that can be used to test the implementation of an
/// [`Amplitude`](crate::amplitudes::Amplitude). This particular dataset contains a single
/// [`EventData`] generated from [`test_event`].
pub fn test_dataset() -> Dataset {
    let metadata = Arc::new(
        DatasetMetadata::new(
            TEST_P4_NAMES.iter().map(|s| (*s).to_string()).collect(),
            TEST_AUX_NAMES.iter().map(|s| (*s).to_string()).collect(),
        )
        .expect("Test metadata should be valid"),
    );
    Dataset::new_with_metadata(vec![Arc::new(test_event())], metadata)
}

/// Raw event data in a [`Dataset`] containing all particle and auxiliary information.
///
/// An [`EventData`] instance owns the list of four-momenta (`p4s`), auxiliary scalars (`aux`),
/// and weight recorded for a particular collision event. Use [`Event`] when you need a
/// metadata-aware view with name-based helpers.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct EventData {
    /// A list of four-momenta for each particle.
    pub p4s: Vec<Vec4>,
    /// A list of auxiliary scalar values associated with the event.
    pub aux: Vec<f64>,
    /// The weight given to the event.
    pub weight: f64,
}

impl Display for EventData {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Event:")?;
        writeln!(f, "  p4s:")?;
        for p4 in &self.p4s {
            writeln!(f, "    {}", p4.to_p4_string())?;
        }
        writeln!(f, "  aux:")?;
        for (idx, value) in self.aux.iter().enumerate() {
            writeln!(f, "    aux[{idx}]: {value}")?;
        }
        writeln!(f, "  weight:")?;
        writeln!(f, "    {}", self.weight)?;
        Ok(())
    }
}

impl EventData {
    /// Return a four-momentum from the sum of four-momenta at the given indices in the [`EventData`].
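    ///
    /// # Example
    ///
    /// A minimal sketch using the event from [`test_event`]:
    ///
    /// ```ignore
    /// let event = test_event();
    /// // Sum the two kaon four-momenta (indices 2 and 3).
    /// let kk = event.get_p4_sum([2, 3]);
    /// ```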
    pub fn get_p4_sum<T: AsRef<[usize]>>(&self, indices: T) -> Vec4 {
        indices.as_ref().iter().map(|i| self.p4s[*i]).sum::<Vec4>()
    }
    /// Boost all the four-momenta in the [`EventData`] to the rest frame of the given set of
    /// four-momenta by indices.
    pub fn boost_to_rest_frame_of<T: AsRef<[usize]>>(&self, indices: T) -> Self {
        let frame = self.get_p4_sum(indices);
        EventData {
            p4s: self
                .p4s
                .iter()
                .map(|p4| p4.boost(&(-frame.beta())))
                .collect(),
            aux: self.aux.clone(),
            weight: self.weight,
        }
    }
    /// Evaluate a [`Variable`] on an [`EventData`].
    pub fn evaluate<V: Variable>(&self, variable: &V) -> f64 {
        variable.value(self)
    }
}

/// Name-based metadata for a [`Dataset`], mapping registered particle and auxiliary names to
/// their storage indices.
#[derive(Debug, Clone)]
pub struct DatasetMetadata {
    pub(crate) p4_names: Vec<String>,
    pub(crate) aux_names: Vec<String>,
    pub(crate) p4_lookup: IndexMap<String, usize>,
    pub(crate) aux_lookup: IndexMap<String, usize>,
    pub(crate) p4_selections: IndexMap<String, P4Selection>,
}

impl DatasetMetadata {
    /// Construct metadata from explicit particle and auxiliary names.
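    ///
    /// # Example
    ///
    /// A minimal sketch; duplicate names yield a [`LadduError::DuplicateName`]:
    ///
    /// ```ignore
    /// let metadata = DatasetMetadata::new(
    ///     vec!["beam", "proton"],
    ///     vec!["pol_magnitude", "pol_angle"],
    /// )?;
    /// assert_eq!(metadata.p4_index("proton"), Some(1));
    /// ```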
    pub fn new<P: Into<String>, A: Into<String>>(
        p4_names: Vec<P>,
        aux_names: Vec<A>,
    ) -> LadduResult<Self> {
        let mut p4_lookup = IndexMap::with_capacity(p4_names.len());
        let mut aux_lookup = IndexMap::with_capacity(aux_names.len());
        let mut p4_selections = IndexMap::with_capacity(p4_names.len());
        let p4_names: Vec<String> = p4_names
            .into_iter()
            .enumerate()
            .map(|(idx, name)| {
                let name = name.into();
                if p4_lookup.contains_key(&name) {
                    return Err(LadduError::DuplicateName {
                        category: "p4",
                        name,
                    });
                }
                p4_lookup.insert(name.clone(), idx);
                p4_selections.insert(
                    name.clone(),
                    P4Selection::with_indices(vec![name.clone()], vec![idx]),
                );
                Ok(name)
            })
            .collect::<Result<_, _>>()?;
        let aux_names: Vec<String> = aux_names
            .into_iter()
            .enumerate()
            .map(|(idx, name)| {
                let name = name.into();
                if aux_lookup.contains_key(&name) {
                    return Err(LadduError::DuplicateName {
                        category: "aux",
                        name,
                    });
                }
                aux_lookup.insert(name.clone(), idx);
                Ok(name)
            })
            .collect::<Result<_, _>>()?;
        Ok(Self {
            p4_names,
            aux_names,
            p4_lookup,
            aux_lookup,
            p4_selections,
        })
    }

    /// Create metadata with no registered names.
    pub fn empty() -> Self {
        Self {
            p4_names: Vec::new(),
            aux_names: Vec::new(),
            p4_lookup: IndexMap::new(),
            aux_lookup: IndexMap::new(),
            p4_selections: IndexMap::new(),
        }
    }

    /// Resolve the index of a four-momentum by name.
    pub fn p4_index(&self, name: &str) -> Option<usize> {
        self.p4_lookup.get(name).copied()
    }

    /// Registered four-momentum names in declaration order.
    pub fn p4_names(&self) -> &[String] {
        &self.p4_names
    }

    /// Resolve the index of an auxiliary scalar by name.
    pub fn aux_index(&self, name: &str) -> Option<usize> {
        self.aux_lookup.get(name).copied()
    }

    /// Registered auxiliary scalar names in declaration order.
    pub fn aux_names(&self) -> &[String] {
        &self.aux_names
    }

    /// Look up a resolved four-momentum selection by name (canonical or alias).
    pub fn p4_selection(&self, name: &str) -> Option<&P4Selection> {
        self.p4_selections.get(name)
    }

    /// Register an alias mapping to one or more existing four-momenta.
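    ///
    /// # Example
    ///
    /// A hedged sketch; `selection` stands in for a [`P4Selection`] covering the two
    /// kaons, built through whatever constructor your call site uses:
    ///
    /// ```ignore
    /// metadata.add_p4_alias("kshorts", selection)?;
    /// assert!(metadata.p4_selection("kshorts").is_some());
    /// ```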
    pub fn add_p4_alias<N>(&mut self, alias: N, mut selection: P4Selection) -> LadduResult<()>
    where
        N: Into<String>,
    {
        let alias = alias.into();
        if self.p4_selections.contains_key(&alias) {
            return Err(LadduError::DuplicateName {
                category: "alias",
                name: alias,
            });
        }
        selection.bind(self)?;
        self.p4_selections.insert(alias, selection);
        Ok(())
    }

    /// Register multiple aliases at once.
    pub fn add_p4_aliases<I, N>(&mut self, entries: I) -> LadduResult<()>
    where
        I: IntoIterator<Item = (N, P4Selection)>,
        N: Into<String>,
    {
        for (alias, selection) in entries {
            self.add_p4_alias(alias, selection)?;
        }
        Ok(())
    }

    pub(crate) fn append_indices_for_name(
        &self,
        name: &str,
        target: &mut Vec<usize>,
    ) -> LadduResult<()> {
        if let Some(selection) = self.p4_selections.get(name) {
            target.extend_from_slice(selection.indices());
            return Ok(());
        }
        Err(LadduError::UnknownName {
            category: "p4",
            name: name.to_string(),
        })
    }
}

impl Default for DatasetMetadata {
    fn default() -> Self {
        Self::empty()
    }
}

/// A collection of [`EventData`] with optional metadata for name-based lookups.
#[derive(Debug, Clone)]
pub struct Dataset {
    /// The [`EventData`] contained in the [`Dataset`]
    pub events: Vec<Event>,
    pub(crate) metadata: Arc<DatasetMetadata>,
}

/// Metadata-aware view of an [`EventData`] with name-based helpers.
#[derive(Clone, Debug)]
pub struct Event {
    event: Arc<EventData>,
    metadata: Arc<DatasetMetadata>,
}

impl Event {
    /// Create a new metadata-aware event from raw data and dataset metadata.
    pub fn new(event: Arc<EventData>, metadata: Arc<DatasetMetadata>) -> Self {
        Self { event, metadata }
    }

    /// Borrow the raw [`EventData`].
    pub fn data(&self) -> &EventData {
        &self.event
    }

    /// Obtain a clone of the underlying [`EventData`] handle.
    pub fn data_arc(&self) -> Arc<EventData> {
        self.event.clone()
    }

    /// Return the four-momenta stored in this event keyed by their registered names.
    pub fn p4s(&self) -> IndexMap<&str, Vec4> {
        let mut map = IndexMap::with_capacity(self.metadata.p4_names.len());
        for (idx, name) in self.metadata.p4_names.iter().enumerate() {
            if let Some(p4) = self.event.p4s.get(idx) {
                map.insert(name.as_str(), *p4);
            }
        }
        map
    }

    /// Return the auxiliary scalars stored in this event keyed by their registered names.
    pub fn aux(&self) -> IndexMap<&str, f64> {
        let mut map = IndexMap::with_capacity(self.metadata.aux_names.len());
        for (idx, name) in self.metadata.aux_names.iter().enumerate() {
            if let Some(value) = self.event.aux.get(idx) {
                map.insert(name.as_str(), *value);
            }
        }
        map
    }

    /// Return the event weight.
    pub fn weight(&self) -> f64 {
        self.event.weight
    }

    /// Retrieve the dataset metadata attached to this event.
    pub fn metadata(&self) -> &DatasetMetadata {
        &self.metadata
    }

    /// Clone the metadata handle associated with this event.
    pub fn metadata_arc(&self) -> Arc<DatasetMetadata> {
        self.metadata.clone()
    }

    /// Retrieve a four-momentum (or aliased sum) by name.
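    ///
    /// # Example
    ///
    /// A minimal sketch using the names registered by [`test_dataset`]:
    ///
    /// ```ignore
    /// let dataset = test_dataset();
    /// let event = dataset.named_event(0);
    /// let beam = event.p4("beam").expect("'beam' is registered");
    /// ```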
    pub fn p4(&self, name: &str) -> Option<Vec4> {
        self.metadata
            .p4_selection(name)
            .map(|selection| selection.momentum(&self.event))
    }

    fn resolve_p4_indices<N>(&self, names: N) -> Vec<usize>
    where
        N: IntoIterator,
        N::Item: AsRef<str>,
    {
        let mut indices = Vec::new();
        for name in names {
            let name_ref = name.as_ref();
            if let Some(selection) = self.metadata.p4_selection(name_ref) {
                indices.extend_from_slice(selection.indices());
            } else {
                panic!("Unknown particle name '{name}'", name = name_ref);
            }
        }
        indices
    }

    /// Return a four-momentum formed by summing four-momenta with the specified names.
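    ///
    /// # Example
    ///
    /// A minimal sketch summing the two kaons from [`test_dataset`]:
    ///
    /// ```ignore
    /// let event = test_dataset().named_event(0);
    /// let kk = event.get_p4_sum(["kshort1", "kshort2"]);
    /// ```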
    pub fn get_p4_sum<N>(&self, names: N) -> Vec4
    where
        N: IntoIterator,
        N::Item: AsRef<str>,
    {
        let indices = self.resolve_p4_indices(names);
        self.event.get_p4_sum(&indices)
    }

    /// Boost all four-momenta into the rest frame defined by the specified particle names.
    pub fn boost_to_rest_frame_of<N>(&self, names: N) -> EventData
    where
        N: IntoIterator,
        N::Item: AsRef<str>,
    {
        let indices = self.resolve_p4_indices(names);
        self.event.boost_to_rest_frame_of(&indices)
    }

    /// Evaluate a [`Variable`] over this event.
    pub fn evaluate<V: Variable>(&self, variable: &V) -> f64 {
        self.event.evaluate(variable)
    }
}

impl Deref for Event {
    type Target = EventData;

    fn deref(&self) -> &Self::Target {
        &self.event
    }
}

impl AsRef<EventData> for Event {
    fn as_ref(&self) -> &EventData {
        self.data()
    }
}

impl IntoIterator for Dataset {
    type Item = Event;

    type IntoIter = DatasetIntoIter;

    fn into_iter(self) -> Self::IntoIter {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                // Cache total before moving fields out of self for MPI iteration.
                let total = self.n_events();
                return DatasetIntoIter::Mpi(DatasetMpiIntoIter {
                    events: self.events,
                    metadata: self.metadata,
                    world,
                    index: 0,
                    total,
                });
            }
        }
        DatasetIntoIter::Local(self.events.into_iter())
    }
}

impl Dataset {
    /// Iterate over all events in the dataset. When MPI is enabled, this will visit
    /// every event across all ranks, fetching remote events on demand.
    pub fn iter(&self) -> DatasetIter<'_> {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return DatasetIter::Mpi(DatasetMpiIter {
                    dataset: self,
                    world,
                    index: 0,
                    total: self.n_events(),
                });
            }
        }
        DatasetIter::Local(self.events.iter())
    }

    /// Borrow the dataset metadata used for name lookups.
    pub fn metadata(&self) -> &DatasetMetadata {
        &self.metadata
    }

    /// Clone the internal metadata handle for external consumers (e.g., language bindings).
    pub fn metadata_arc(&self) -> Arc<DatasetMetadata> {
        self.metadata.clone()
    }

    /// Names corresponding to stored four-momenta.
    pub fn p4_names(&self) -> &[String] {
        &self.metadata.p4_names
    }

    /// Names corresponding to stored auxiliary scalars.
    pub fn aux_names(&self) -> &[String] {
        &self.metadata.aux_names
    }

    /// Resolve the index of a four-momentum by name.
    pub fn p4_index(&self, name: &str) -> Option<usize> {
        self.metadata.p4_index(name)
    }

    /// Resolve the index of an auxiliary scalar by name.
    pub fn aux_index(&self, name: &str) -> Option<usize> {
        self.metadata.aux_index(name)
    }

    /// Borrow event data together with metadata-based helpers as an [`Event`] view.
    pub fn named_event(&self, index: usize) -> Event {
        self.events[index].clone()
    }

    /// Retrieve a four-momentum by name for the event at `event_index`.
    pub fn p4_by_name(&self, event_index: usize, name: &str) -> Option<Vec4> {
        self.events
            .get(event_index)
            .and_then(|event| event.p4(name))
    }

    /// Retrieve an auxiliary scalar by name for the event at `event_index`.
    pub fn aux_by_name(&self, event_index: usize, name: &str) -> Option<f64> {
        let idx = self.aux_index(name)?;
        self.events
            .get(event_index)
            .and_then(|event| event.aux.get(idx))
            .copied()
    }

    /// Get a reference to the [`EventData`] at the given index in the [`Dataset`] (non-MPI
    /// version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just index into a [`Dataset`]
    /// as if it were any other [`Vec`]:
    ///
    /// ```ignore
    /// let ds: Dataset = Dataset::new(events);
    /// let event_0 = &ds[0];
    /// ```
    pub fn index_local(&self, index: usize) -> &Event {
        &self.events[index]
    }

    #[cfg(feature = "mpi")]
    fn partition(
        events: Vec<Arc<EventData>>,
        world: &SimpleCommunicator,
    ) -> Vec<Vec<Arc<EventData>>> {
        let partition = world.partition(events.len());
        (0..partition.n_ranks())
            .map(|rank| {
                let range = partition.range_for_rank(rank);
                events[range.clone()].iter().cloned().collect()
            })
            .collect()
    }

    /// Get a reference to the [`EventData`] at the given index in the [`Dataset`]
    /// (MPI-compatible version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just index into a [`Dataset`]
    /// as if it were any other [`Vec`]:
    ///
    /// ```ignore
    /// let ds: Dataset = Dataset::new(events);
    /// let event_0 = &ds[0];
    /// ```
557    #[cfg(feature = "mpi")]
558    pub fn index_mpi(&self, index: usize, world: &SimpleCommunicator) -> &Event {
559        let total = self.n_events();
560        let event = fetch_event_mpi(self, index, world, total);
561        Box::leak(Box::new(event))
562    }
}

impl Index<usize> for Dataset {
    type Output = Event;

    fn index(&self, index: usize) -> &Self::Output {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.index_mpi(index, &world);
            }
        }
        self.index_local(index)
    }
}

/// Iterator over a [`Dataset`].
pub enum DatasetIter<'a> {
    /// Iterator over locally available events.
    Local(std::slice::Iter<'a, Event>),
    #[cfg(feature = "mpi")]
    /// Iterator that fetches events across MPI ranks.
    Mpi(DatasetMpiIter<'a>),
}

impl<'a> Iterator for DatasetIter<'a> {
    type Item = Event;

    fn next(&mut self) -> Option<Self::Item> {
        match self {
            DatasetIter::Local(iter) => iter.next().cloned(),
            #[cfg(feature = "mpi")]
            DatasetIter::Mpi(iter) => iter.next(),
        }
    }
}

/// Owning iterator over a [`Dataset`].
pub enum DatasetIntoIter {
    /// Iterator over locally available events, consuming the dataset.
    Local(std::vec::IntoIter<Event>),
    #[cfg(feature = "mpi")]
    /// Iterator that fetches events across MPI ranks, consuming the dataset.
    Mpi(DatasetMpiIntoIter),
}

impl Iterator for DatasetIntoIter {
    type Item = Event;

    fn next(&mut self) -> Option<Self::Item> {
        match self {
            DatasetIntoIter::Local(iter) => iter.next(),
            #[cfg(feature = "mpi")]
            DatasetIntoIter::Mpi(iter) => iter.next(),
        }
    }
}

#[cfg(feature = "mpi")]
/// Iterator over a [`Dataset`] that fetches events across MPI ranks.
pub struct DatasetMpiIter<'a> {
    dataset: &'a Dataset,
    world: SimpleCommunicator,
    index: usize,
    total: usize,
}

#[cfg(feature = "mpi")]
impl<'a> Iterator for DatasetMpiIter<'a> {
    type Item = Event;

    fn next(&mut self) -> Option<Self::Item> {
        if self.index >= self.total {
            return None;
        }
        let event = fetch_event_mpi(self.dataset, self.index, &self.world, self.total);
        self.index += 1;
        Some(event)
    }
}

#[cfg(feature = "mpi")]
/// Owning iterator over a [`Dataset`] that fetches events across MPI ranks.
pub struct DatasetMpiIntoIter {
    events: Vec<Event>,
    metadata: Arc<DatasetMetadata>,
    world: SimpleCommunicator,
    index: usize,
    total: usize,
}

#[cfg(feature = "mpi")]
impl Iterator for DatasetMpiIntoIter {
    type Item = Event;

    fn next(&mut self) -> Option<Self::Item> {
        if self.index >= self.total {
            return None;
        }
        let event = fetch_event_mpi_from_events(
            &self.events,
            &self.metadata,
            self.index,
            &self.world,
            self.total,
        );
        self.index += 1;
        Some(event)
    }
}

#[cfg(feature = "mpi")]
fn fetch_event_mpi(
    dataset: &Dataset,
    global_index: usize,
    world: &SimpleCommunicator,
    total: usize,
) -> Event {
    fetch_event_mpi_generic(
        global_index,
        total,
        world,
        &dataset.metadata,
        |local_index| dataset.index_local(local_index),
    )
}

#[cfg(feature = "mpi")]
fn fetch_event_mpi_from_events(
    events: &[Event],
    metadata: &Arc<DatasetMetadata>,
    global_index: usize,
    world: &SimpleCommunicator,
    total: usize,
) -> Event {
    fetch_event_mpi_generic(global_index, total, world, metadata, |local_index| {
        &events[local_index]
    })
}

703#[cfg(feature = "mpi")]
704fn fetch_event_mpi_generic<'a, F>(
705    global_index: usize,
706    total: usize,
707    world: &SimpleCommunicator,
708    metadata: &Arc<DatasetMetadata>,
709    local_event: F,
710) -> Event
711where
712    F: Fn(usize) -> &'a Event,
713{
714    let (owning_rank, local_index) = world.owner_of_global_index(global_index, total);
715    let mut serialized_event_buffer_len: usize = 0;
716    let mut serialized_event_buffer: Vec<u8> = Vec::default();
717    let config = bincode::config::standard();
718    if world.rank() == owning_rank {
719        let event = local_event(local_index);
720        serialized_event_buffer = bincode::serde::encode_to_vec(event.data(), config).unwrap();
721        serialized_event_buffer_len = serialized_event_buffer.len();
722    }
723    world
724        .process_at_rank(owning_rank)
725        .broadcast_into(&mut serialized_event_buffer_len);
726    if world.rank() != owning_rank {
727        serialized_event_buffer = vec![0; serialized_event_buffer_len];
728    }
729    world
730        .process_at_rank(owning_rank)
731        .broadcast_into(&mut serialized_event_buffer);
732
733    if world.rank() == owning_rank {
734        local_event(local_index).clone()
735    } else {
736        let (event, _): (EventData, usize) =
737            bincode::serde::decode_from_slice(&serialized_event_buffer[..], config).unwrap();
738        Event::new(Arc::new(event), metadata.clone())
739    }
740}

impl Dataset {
    /// Create a new [`Dataset`] from a list of [`EventData`] (non-MPI version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::new`] instead.
    pub fn new_local(events: Vec<Arc<EventData>>, metadata: Arc<DatasetMetadata>) -> Self {
        let wrapped_events = events
            .into_iter()
            .map(|event| Event::new(event, metadata.clone()))
            .collect();
        Dataset {
            events: wrapped_events,
            metadata,
        }
    }

    /// Create a new [`Dataset`] from a list of [`EventData`] (MPI-compatible version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::new`] instead.
    #[cfg(feature = "mpi")]
    pub fn new_mpi(
        events: Vec<Arc<EventData>>,
        metadata: Arc<DatasetMetadata>,
        world: &SimpleCommunicator,
    ) -> Self {
        let partitions = Dataset::partition(events, world);
        let local = partitions[world.rank() as usize]
            .iter()
            .cloned()
            .map(|event| Event::new(event, metadata.clone()))
            .collect();
        Dataset {
            events: local,
            metadata,
        }
    }

    /// Create a new [`Dataset`] from a list of [`EventData`].
    ///
    /// This method is preferred for external use because it contains proper MPI construction
    /// methods. Constructing a [`Dataset`] manually is possible, but may cause issues when
    /// interfacing with MPI and should be avoided unless you know what you are doing.
    pub fn new(events: Vec<Arc<EventData>>) -> Self {
        Dataset::new_with_metadata(events, Arc::new(DatasetMetadata::default()))
    }

    /// Create a dataset with explicit metadata for name-based lookups.
    pub fn new_with_metadata(events: Vec<Arc<EventData>>, metadata: Arc<DatasetMetadata>) -> Self {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return Dataset::new_mpi(events, metadata, &world);
            }
        }
        Dataset::new_local(events, metadata)
    }

    /// The number of [`EventData`]s in the [`Dataset`] (non-MPI version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::n_events`] instead.
    pub fn n_events_local(&self) -> usize {
        self.events.len()
    }

    /// The number of [`EventData`]s in the [`Dataset`] (MPI-compatible version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::n_events`] instead.
    #[cfg(feature = "mpi")]
    pub fn n_events_mpi(&self, world: &SimpleCommunicator) -> usize {
        let mut n_events_partitioned: Vec<usize> = vec![0; world.size() as usize];
        let n_events_local = self.n_events_local();
        world.all_gather_into(&n_events_local, &mut n_events_partitioned);
        n_events_partitioned.iter().sum()
    }

    /// The number of [`EventData`]s in the [`Dataset`].
    pub fn n_events(&self) -> usize {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.n_events_mpi(&world);
            }
        }
        self.n_events_local()
    }
}

impl Dataset {
    /// Extract a list of weights over each [`EventData`] in the [`Dataset`] (non-MPI version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::weights`] instead.
    pub fn weights_local(&self) -> Vec<f64> {
        #[cfg(feature = "rayon")]
        return self.events.par_iter().map(|e| e.weight).collect();
        #[cfg(not(feature = "rayon"))]
        return self.events.iter().map(|e| e.weight).collect();
    }

    /// Extract a list of weights over each [`EventData`] in the [`Dataset`] (MPI-compatible version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::weights`] instead.
    #[cfg(feature = "mpi")]
    pub fn weights_mpi(&self, world: &SimpleCommunicator) -> Vec<f64> {
        let local_weights = self.weights_local();
        let n_events = self.n_events();
        let mut buffer: Vec<f64> = vec![0.0; n_events];
        let (counts, displs) = world.get_counts_displs(n_events);
        {
            let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
            world.all_gather_varcount_into(&local_weights, &mut partitioned_buffer);
        }
        buffer
    }

    /// Extract a list of weights over each [`EventData`] in the [`Dataset`].
    pub fn weights(&self) -> Vec<f64> {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.weights_mpi(&world);
            }
        }
        self.weights_local()
    }

    /// Returns the sum of the weights for each [`EventData`] in the [`Dataset`] (non-MPI version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::n_events_weighted`] instead.
    pub fn n_events_weighted_local(&self) -> f64 {
        #[cfg(feature = "rayon")]
        return self
            .events
            .par_iter()
            .map(|e| e.weight)
            .parallel_sum_with_accumulator::<Klein<f64>>();
        #[cfg(not(feature = "rayon"))]
        return self.events.iter().map(|e| e.weight).sum();
    }
    /// Returns the sum of the weights for each [`EventData`] in the [`Dataset`] (MPI-compatible version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::n_events_weighted`] instead.
    #[cfg(feature = "mpi")]
    pub fn n_events_weighted_mpi(&self, world: &SimpleCommunicator) -> f64 {
        let mut n_events_weighted_partitioned: Vec<f64> = vec![0.0; world.size() as usize];
        let n_events_weighted_local = self.n_events_weighted_local();
        world.all_gather_into(&n_events_weighted_local, &mut n_events_weighted_partitioned);
        #[cfg(feature = "rayon")]
        return n_events_weighted_partitioned
            .into_par_iter()
            .parallel_sum_with_accumulator::<Klein<f64>>();
        #[cfg(not(feature = "rayon"))]
        return n_events_weighted_partitioned.iter().sum();
    }

    /// Returns the sum of the weights for each [`EventData`] in the [`Dataset`].
    pub fn n_events_weighted(&self) -> f64 {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.n_events_weighted_mpi(&world);
            }
        }
        self.n_events_weighted_local()
    }

    /// Generate a new dataset with the same length by resampling the events in the original dataset
    /// with replacement. This can be used to perform error analysis via the bootstrap method (non-MPI version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::bootstrap`] instead.
    pub fn bootstrap_local(&self, seed: usize) -> Arc<Dataset> {
        let mut rng = fastrand::Rng::with_seed(seed as u64);
        let mut indices: Vec<usize> = (0..self.n_events())
            .map(|_| rng.usize(0..self.n_events()))
            .collect::<Vec<usize>>();
        indices.sort();
        #[cfg(feature = "rayon")]
        let bootstrapped_events: Vec<Arc<EventData>> = indices
            .into_par_iter()
            .map(|idx| self.events[idx].data_arc())
            .collect();
        #[cfg(not(feature = "rayon"))]
        let bootstrapped_events: Vec<Arc<EventData>> = indices
            .into_iter()
            .map(|idx| self.events[idx].data_arc())
            .collect();
        Arc::new(Dataset::new_with_metadata(
            bootstrapped_events,
            self.metadata.clone(),
        ))
    }

    /// Generate a new dataset with the same length by resampling the events in the original dataset
    /// with replacement. This can be used to perform error analysis via the bootstrap method (MPI-compatible version).
    ///
    /// # Notes
    ///
    /// This method is not intended to be called in analyses but rather in writing methods
    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::bootstrap`] instead.
    #[cfg(feature = "mpi")]
    pub fn bootstrap_mpi(&self, seed: usize, world: &SimpleCommunicator) -> Arc<Dataset> {
        let n_events = self.n_events();
        let mut indices: Vec<usize> = vec![0; n_events];
        if world.is_root() {
            let mut rng = fastrand::Rng::with_seed(seed as u64);
            indices = (0..n_events)
                .map(|_| rng.usize(0..n_events))
                .collect::<Vec<usize>>();
            indices.sort();
        }
        world.process_at_root().broadcast_into(&mut indices);
        let local_indices: Vec<usize> = indices
            .into_iter()
            .filter_map(|idx| {
                let (owning_rank, local_index) = world.owner_of_global_index(idx, n_events);
                if world.rank() == owning_rank {
                    Some(local_index)
                } else {
                    None
                }
            })
            .collect();
        // `local_indices` only contains indices owned by the current rank, translating them into
        // local indices on the events vector.
        #[cfg(feature = "rayon")]
        let bootstrapped_events: Vec<Arc<EventData>> = local_indices
            .into_par_iter()
            .map(|idx| self.events[idx].data_arc())
            .collect();
        #[cfg(not(feature = "rayon"))]
        let bootstrapped_events: Vec<Arc<EventData>> = local_indices
            .into_iter()
            .map(|idx| self.events[idx].data_arc())
            .collect();
        Arc::new(Dataset::new_with_metadata(
            bootstrapped_events,
            self.metadata.clone(),
        ))
    }

    /// Generate a new dataset with the same length by resampling the events in the original dataset
    /// with replacement. This can be used to perform error analysis via the bootstrap method.
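    ///
    /// # Example
    ///
    /// A minimal sketch; a fixed seed makes the resampling reproducible:
    ///
    /// ```ignore
    /// let resampled = dataset.bootstrap(42);
    /// assert_eq!(resampled.n_events(), dataset.n_events());
    /// ```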
    pub fn bootstrap(&self, seed: usize) -> Arc<Dataset> {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.bootstrap_mpi(seed, &world);
            }
        }
        self.bootstrap_local(seed)
    }

    /// Filter the [`Dataset`] by a given [`VariableExpression`], selecting events for which
    /// the expression returns `true`.
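    ///
    /// # Example
    ///
    /// A hedged sketch; `expr` stands in for any previously constructed
    /// [`VariableExpression`]:
    ///
    /// ```ignore
    /// let subset = dataset.filter(&expr)?;
    /// assert!(subset.n_events() <= dataset.n_events());
    /// ```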
    pub fn filter(&self, expression: &VariableExpression) -> LadduResult<Arc<Dataset>> {
        let compiled = expression.compile(&self.metadata)?;
        #[cfg(feature = "rayon")]
        let filtered_events: Vec<Arc<EventData>> = self
            .events
            .par_iter()
            .filter(|event| compiled.evaluate(event.as_ref()))
            .map(|event| event.data_arc())
            .collect();
        #[cfg(not(feature = "rayon"))]
        let filtered_events: Vec<Arc<EventData>> = self
            .events
            .iter()
            .filter(|event| compiled.evaluate(event.as_ref()))
            .map(|event| event.data_arc())
            .collect();
        Ok(Arc::new(Dataset::new_with_metadata(
            filtered_events,
            self.metadata.clone(),
        )))
    }

    /// Bin a [`Dataset`] by the value of the given [`Variable`] into a number of `bins` within the
    /// given `range`.
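    ///
    /// # Example
    ///
    /// A hedged sketch; `mass` stands in for any [`Variable`] implementation:
    ///
    /// ```ignore
    /// // Split the dataset into 50 sub-datasets spanning [1.0, 2.0).
    /// let binned = dataset.bin_by(mass, 50, (1.0, 2.0))?;
    /// ```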
    pub fn bin_by<V>(
        &self,
        mut variable: V,
        bins: usize,
        range: (f64, f64),
    ) -> LadduResult<BinnedDataset>
    where
        V: Variable,
    {
        variable.bind(self.metadata())?;
        let bin_width = (range.1 - range.0) / bins as f64;
        let bin_edges = get_bin_edges(bins, range);
        let variable = variable;
        #[cfg(feature = "rayon")]
        let evaluated: Vec<(usize, Arc<EventData>)> = self
            .events
            .par_iter()
            .filter_map(|event| {
                let value = variable.value(event.as_ref());
                if value >= range.0 && value < range.1 {
                    let bin_index = ((value - range.0) / bin_width) as usize;
                    let bin_index = bin_index.min(bins - 1);
                    Some((bin_index, event.data_arc()))
                } else {
                    None
                }
            })
            .collect();
        #[cfg(not(feature = "rayon"))]
        let evaluated: Vec<(usize, Arc<EventData>)> = self
            .events
            .iter()
            .filter_map(|event| {
                let value = variable.value(event.as_ref());
                if value >= range.0 && value < range.1 {
                    let bin_index = ((value - range.0) / bin_width) as usize;
                    let bin_index = bin_index.min(bins - 1);
                    Some((bin_index, event.data_arc()))
                } else {
                    None
                }
            })
            .collect();
        let mut binned_events: Vec<Vec<Arc<EventData>>> = vec![Vec::default(); bins];
        for (bin_index, event) in evaluated {
            binned_events[bin_index].push(event.clone());
        }
        #[cfg(feature = "rayon")]
        let datasets: Vec<Arc<Dataset>> = binned_events
            .into_par_iter()
            .map(|events| Arc::new(Dataset::new_with_metadata(events, self.metadata.clone())))
            .collect();
        #[cfg(not(feature = "rayon"))]
        let datasets: Vec<Arc<Dataset>> = binned_events
            .into_iter()
            .map(|events| Arc::new(Dataset::new_with_metadata(events, self.metadata.clone())))
            .collect();
        Ok(BinnedDataset {
            datasets,
            edges: bin_edges,
        })
    }

    /// Boost all the four-momenta in all [`EventData`]s to the rest frame of the given set of
    /// four-momenta identified by name.
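    ///
    /// # Example
    ///
    /// A minimal sketch using names registered by [`test_dataset`]:
    ///
    /// ```ignore
    /// let boosted = dataset.boost_to_rest_frame_of(&["kshort1", "kshort2"]);
    /// ```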
    pub fn boost_to_rest_frame_of<S>(&self, names: &[S]) -> Arc<Dataset>
    where
        S: AsRef<str>,
    {
        let mut indices: Vec<usize> = Vec::new();
        for name in names {
            let name_ref = name.as_ref();
            if let Some(selection) = self.metadata.p4_selection(name_ref) {
                indices.extend_from_slice(selection.indices());
            } else {
                panic!("Unknown particle name '{name}'", name = name_ref);
            }
        }
        #[cfg(feature = "rayon")]
        let boosted_events: Vec<Arc<EventData>> = self
            .events
            .par_iter()
            .map(|event| Arc::new(event.data().boost_to_rest_frame_of(&indices)))
            .collect();
        #[cfg(not(feature = "rayon"))]
        let boosted_events: Vec<Arc<EventData>> = self
            .events
            .iter()
            .map(|event| Arc::new(event.data().boost_to_rest_frame_of(&indices)))
            .collect();
        Arc::new(Dataset::new_with_metadata(
            boosted_events,
            self.metadata.clone(),
        ))
    }
    /// Evaluate a [`Variable`] on every event in the [`Dataset`].
    pub fn evaluate<V: Variable>(&self, variable: &V) -> LadduResult<Vec<f64>> {
        variable.value_on(self)
    }

    fn write_parquet_impl(
        &self,
        file_path: PathBuf,
        options: &DatasetWriteOptions,
    ) -> LadduResult<()> {
        let batch_size = options.batch_size.max(1);
        let precision = options.precision;
        let schema = Arc::new(build_parquet_schema(&self.metadata, precision));

        #[cfg(feature = "mpi")]
        let is_root = crate::mpi::get_world()
            .as_ref()
            .map_or(true, |world| world.rank() == 0);
        #[cfg(not(feature = "mpi"))]
        let is_root = true;

        let mut writer: Option<ArrowWriter<File>> = None;
        if is_root {
            let file = File::create(&file_path)?;
            writer = Some(
                ArrowWriter::try_new(file, schema.clone(), None).map_err(|err| {
                    LadduError::Custom(format!("Failed to create Parquet writer: {err}"))
                })?,
            );
        }

        let mut iter = self.iter();
        loop {
            let mut buffers =
                ColumnBuffers::new(self.metadata.p4_names.len(), self.metadata.aux_names.len());
            let mut rows = 0usize;

            while rows < batch_size {
                match iter.next() {
                    Some(event) => {
                        if is_root {
                            buffers.push_event(&event);
                        }
                        rows += 1;
                    }
                    None => break,
                }
            }

            if rows == 0 {
                break;
            }

            if let Some(writer) = writer.as_mut() {
                let batch = buffers
                    .into_record_batch(schema.clone(), precision)
                    .map_err(|err| {
                        LadduError::Custom(format!("Failed to build Parquet batch: {err}"))
                    })?;
                writer.write(&batch).map_err(|err| {
                    LadduError::Custom(format!("Failed to write Parquet batch: {err}"))
                })?;
            }
        }

        if let Some(writer) = writer {
            writer.close().map_err(|err| {
                LadduError::Custom(format!("Failed to finalise Parquet file: {err}"))
            })?;
        }

        Ok(())
    }

    fn write_root_impl(
        &self,
        file_path: PathBuf,
        options: &DatasetWriteOptions,
    ) -> LadduResult<()> {
        let tree_name = options.tree.clone().unwrap_or_else(|| "events".to_string());
        let branch_count = self.metadata.p4_names.len() * 4 + self.metadata.aux_names.len() + 1; // +weight

        #[cfg(feature = "mpi")]
        let mut world_opt = crate::mpi::get_world();
        #[cfg(feature = "mpi")]
        let is_root = world_opt.as_ref().map_or(true, |world| world.rank() == 0);
        #[cfg(not(feature = "mpi"))]
        let is_root = true;

        #[cfg(feature = "mpi")]
        let world: Option<WorldHandle> = world_opt.take();
        #[cfg(not(feature = "mpi"))]
        let world: Option<WorldHandle> = None;

        let total_events = self.n_events();
        let dataset_arc = Arc::new(self.clone());

        match options.precision {
            FloatPrecision::F64 => self.write_root_with_type::<f64>(
                dataset_arc,
                world,
                is_root,
                &file_path,
                &tree_name,
                branch_count,
                total_events,
            ),
            FloatPrecision::F32 => self.write_root_with_type::<f32>(
                dataset_arc,
                world,
                is_root,
                &file_path,
                &tree_name,
                branch_count,
                total_events,
            ),
        }
    }
}

fn canonicalize_dataset_path(file_path: &str) -> LadduResult<PathBuf> {
    Ok(Path::new(&*shellexpand::full(file_path)?).canonicalize()?)
}

fn expand_output_path(file_path: &str) -> LadduResult<PathBuf> {
    Ok(PathBuf::from(&*shellexpand::full(file_path)?))
}

/// Load a [`Dataset`] from a Parquet file.
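///
/// # Example
///
/// A minimal sketch; the path is illustrative:
///
/// ```ignore
/// let dataset = read_parquet("data.parquet", &DatasetReadOptions::new())?;
/// println!("loaded {} events", dataset.n_events());
/// ```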
pub fn read_parquet(file_path: &str, options: &DatasetReadOptions) -> LadduResult<Arc<Dataset>> {
    let path = canonicalize_dataset_path(file_path)?;
    let (detected_p4_names, detected_aux_names) = detect_columns(&path)?;
    let metadata = options.resolve_metadata(detected_p4_names, detected_aux_names)?;
    let file = File::open(path)?;
    let builder = ParquetRecordBatchReaderBuilder::try_new(file)?;

    #[cfg(feature = "mpi")]
    {
        if let Some(world) = crate::mpi::get_world() {
            return read_parquet_mpi(builder, metadata, &world);
        }
    }

    read_parquet_local(builder, metadata)
}

fn read_parquet_local(
    builder: ParquetRecordBatchReaderBuilder<File>,
    metadata: Arc<DatasetMetadata>,
) -> LadduResult<Arc<Dataset>> {
    let reader = builder.build()?;
    let batches: Vec<RecordBatch> = reader.collect::<Result<Vec<_>, _>>()?;
    let events = batches_to_events(batches, metadata.as_ref())?;
    Ok(Arc::new(Dataset::new_with_metadata(events, metadata)))
}

#[cfg(feature = "mpi")]
fn read_parquet_mpi(
    mut builder: ParquetRecordBatchReaderBuilder<File>,
    metadata: Arc<DatasetMetadata>,
    world: &SimpleCommunicator,
) -> LadduResult<Arc<Dataset>> {
    let parquet_metadata = builder.metadata().clone();
    let total_rows = parquet_metadata.file_metadata().num_rows() as usize;
    if total_rows == 0 {
        return Ok(Arc::new(Dataset::new_local(Vec::new(), metadata)));
    }

    let partition = world.partition(total_rows);
    let rank = world.rank() as usize;
    let local_range = partition.range_for_rank(rank);
    let local_start = local_range.start;
    let local_end = local_range.end;
    if local_start == local_end {
        return Ok(Arc::new(Dataset::new_local(Vec::new(), metadata)));
    }

    let (row_groups, first_row_start) =
        row_groups_for_range(&parquet_metadata, local_start, local_end);
    if !row_groups.is_empty() {
        builder = builder.with_row_groups(row_groups);
    }

    let reader = builder.build()?;
    let batches: Vec<RecordBatch> = reader.collect::<Result<Vec<_>, _>>()?;
    let mut events = batches_to_events(batches, metadata.as_ref())?;

    let drop_front = local_start.saturating_sub(first_row_start);
    if drop_front > 0 {
        events.drain(0..drop_front);
    }
    let expected_local = local_end - local_start;
    if events.len() > expected_local {
        events.truncate(expected_local);
    }
    if events.len() != expected_local {
        return Err(LadduError::Custom(format!(
            "Loaded {} rows on rank {} but expected {}",
            events.len(),
            rank,
            expected_local
        )));
    }

    Ok(Arc::new(Dataset::new_local(events, metadata)))
}

#[cfg(feature = "mpi")]
fn row_groups_for_range(
    metadata: &Arc<ParquetMetaData>,
    start: usize,
    end: usize,
) -> (Vec<usize>, usize) {
    let mut selected = Vec::new();
    let mut first_row_start = start;
    let mut offset = 0usize;

    for (idx, row_group) in metadata.row_groups().iter().enumerate() {
        let group_start = offset;
        let rows = row_group.num_rows() as usize;
        let group_end = group_start + rows;
        offset = group_end;

        if group_end <= start {
            continue;
        }
        if group_start >= end {
            break;
        }
        if selected.is_empty() {
            first_row_start = group_start;
        }
        selected.push(idx);
        if group_end >= end {
            break;
        }
    }

    (selected, first_row_start)
}

fn batches_to_events(
    batches: Vec<RecordBatch>,
    metadata: &DatasetMetadata,
) -> LadduResult<Vec<Arc<EventData>>> {
    #[cfg(feature = "rayon")]
    {
        let batch_events: Vec<LadduResult<Vec<Arc<EventData>>>> = batches
            .into_par_iter()
            .map(|batch| record_batch_to_events(batch, &metadata.p4_names, &metadata.aux_names))
            .collect();
        let mut events = Vec::new();
        for batch in batch_events {
            let mut batch = batch?;
            events.append(&mut batch);
        }
        Ok(events)
    }

    #[cfg(not(feature = "rayon"))]
    {
        Ok(batches
            .into_iter()
            .map(|batch| record_batch_to_events(batch, &metadata.p4_names, &metadata.aux_names))
            .collect::<LadduResult<Vec<_>>>()?
            .into_iter()
            .flatten()
            .collect())
    }
}

/// Load a [`Dataset`] from a ROOT TTree using the oxyroot backend.
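///
/// # Example
///
/// A minimal sketch; the `tree` field selects a TTree when the file holds several:
///
/// ```ignore
/// let mut options = DatasetReadOptions::new();
/// options.tree = Some("events".to_string());
/// let dataset = read_root("data.root", &options)?;
/// ```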
1413pub fn read_root(file_path: &str, options: &DatasetReadOptions) -> LadduResult<Arc<Dataset>> {
1414    let path = canonicalize_dataset_path(file_path)?;
1415    let mut file = RootFile::open(&path).map_err(|err| {
1416        LadduError::Custom(format!(
1417            "Failed to open ROOT file '{}': {err}",
1418            path.display()
1419        ))
1420    })?;
1421
1422    let (tree, tree_name) = resolve_root_tree(&mut file, options.tree.as_deref())?;
1423
1424    let branches: Vec<&Branch> = tree.branches().collect();
1425    let mut lookup: BranchLookup<'_> = IndexMap::new();
1426    for &branch in &branches {
1427        if let Some(kind) = branch_scalar_kind(branch) {
1428            lookup.insert(branch.name(), (kind, branch));
1429        }
1430    }
1431
1432    if lookup.is_empty() {
1433        return Err(LadduError::Custom(format!(
1434            "No float or double branches found in ROOT tree '{tree_name}'"
1435        )));
1436    }
1437
1438    let column_names: Vec<&str> = lookup.keys().copied().collect();
1439    let (detected_p4_names, detected_aux_names) = infer_p4_and_aux_names(&column_names);
1440    let metadata = options.resolve_metadata(detected_p4_names, detected_aux_names)?;
1441
1442    struct RootP4Columns {
1443        px: Vec<f64>,
1444        py: Vec<f64>,
1445        pz: Vec<f64>,
1446        e: Vec<f64>,
1447    }
1448
1449    // TODO: do all reads in parallel if possible to match parquet impl
1450    let mut p4_columns = Vec::with_capacity(metadata.p4_names.len());
1451    for name in &metadata.p4_names {
1452        let logical = format!("{name}_px");
1453        let px = read_branch_values_from_candidates(
1454            &lookup,
1455            &component_candidates(name, "px"),
1456            &logical,
1457        )?;
1458
1459        let logical = format!("{name}_py");
1460        let py = read_branch_values_from_candidates(
1461            &lookup,
1462            &component_candidates(name, "py"),
1463            &logical,
1464        )?;
1465
1466        let logical = format!("{name}_pz");
1467        let pz = read_branch_values_from_candidates(
1468            &lookup,
1469            &component_candidates(name, "pz"),
1470            &logical,
1471        )?;
1472
1473        let logical = format!("{name}_e");
1474        let e = read_branch_values_from_candidates(
1475            &lookup,
1476            &component_candidates(name, "e"),
1477            &logical,
1478        )?;
1479
1480        p4_columns.push(RootP4Columns { px, py, pz, e });
1481    }
1482
1483    let mut aux_columns = Vec::with_capacity(metadata.aux_names.len());
1484    for name in &metadata.aux_names {
1485        let values = read_branch_values(&lookup, name)?;
1486        aux_columns.push(values);
1487    }
1488
1489    let n_events = if let Some(first) = p4_columns.first() {
1490        first.px.len()
1491    } else if let Some(first) = aux_columns.first() {
1492        first.len()
1493    } else {
1494        return Err(LadduError::Custom(
1495            "Unable to determine event count; dataset has no four-momentum or auxiliary columns"
1496                .to_string(),
1497        ));
1498    };
1499
1500    let weight_values = match read_branch_values_optional(&lookup, "weight")? {
1501        Some(values) => {
1502            if values.len() != n_events {
1503                return Err(LadduError::Custom(format!(
1504                    "Column 'weight' has {} entries but expected {}",
1505                    values.len(),
1506                    n_events
1507                )));
1508            }
1509            values
1510        }
1511        None => vec![1.0; n_events],
1512    };
1513
1514    let mut events = Vec::with_capacity(n_events);
1515    for row in 0..n_events {
1516        let mut p4s = Vec::with_capacity(p4_columns.len());
1517        for columns in &p4_columns {
1518            p4s.push(Vec4::new(
1519                columns.px[row],
1520                columns.py[row],
1521                columns.pz[row],
1522                columns.e[row],
1523            ));
1524        }
1525
1526        let mut aux = Vec::with_capacity(aux_columns.len());
1527        for column in &aux_columns {
1528            aux.push(column[row]);
1529        }
1530
1531        let event = EventData {
1532            p4s,
1533            aux,
1534            weight: weight_values[row],
1535        };
1536        events.push(Arc::new(event));
1537    }
1538
1539    Ok(Arc::new(Dataset::new_with_metadata(events, metadata)))
1540}
1541
1542/// Persist a [`Dataset`] to a Parquet file.
1543pub fn write_parquet(
1544    dataset: &Dataset,
1545    file_path: &str,
1546    options: &DatasetWriteOptions,
1547) -> LadduResult<()> {
1548    let path = expand_output_path(file_path)?;
1549    dataset.write_parquet_impl(path, options)
1550}
1551
1552/// Persist a [`Dataset`] to a ROOT file using the oxyroot backend.
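///
/// # Example
///
/// A minimal sketch with placeholder file and tree names; when no tree name is
/// set, the writer falls back to "events":
///
/// ```ignore
/// write_root(
///     &dataset,
///     "events.root",
///     &DatasetWriteOptions::default().tree("kin"),
/// )?;
/// ```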
1553pub fn write_root(
1554    dataset: &Dataset,
1555    file_path: &str,
1556    options: &DatasetWriteOptions,
1557) -> LadduResult<()> {
1558    let path = expand_output_path(file_path)?;
1559    dataset.write_root_impl(path, options)
1560}
1561
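// `&a + &b` (or the owned variants) concatenates the event lists of two
// datasets and reuses the left-hand metadata. Matching p4/aux names are only
// checked with `debug_assert_eq!`, so callers must ensure both datasets share
// a schema.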
1562impl_op_ex!(+ |a: &Dataset, b: &Dataset| -> Dataset {
1563    debug_assert_eq!(a.metadata.p4_names, b.metadata.p4_names);
1564    debug_assert_eq!(a.metadata.aux_names, b.metadata.aux_names);
1565    Dataset {
1566        events: a
1567            .events
1568            .iter()
1569            .chain(b.events.iter())
1570            .cloned()
1571            .collect(),
1572        metadata: a.metadata.clone(),
1573    }
1574});
1575
1576/// Options for reading a [`Dataset`] from a file.
1577///
1578/// # See Also
1579/// [`read_parquet`], [`read_root`]
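///
/// # Example
///
/// A minimal sketch of the builder API; the particle, aux, tree, and alias
/// names are placeholders:
///
/// ```ignore
/// let options = DatasetReadOptions::new()
///     .p4_names(["beam", "proton", "kshort1", "kshort2"])
///     .aux_names(["pol_angle"])
///     .tree("events")
///     .alias("resonance", ["kshort1", "kshort2"]);
/// ```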
1580#[derive(Default, Clone)]
1581pub struct DatasetReadOptions {
1582    /// Particle names to read from the data file.
1583    pub p4_names: Option<Vec<String>>,
1584    /// Auxiliary scalar names to read from the data file.
1585    pub aux_names: Option<Vec<String>>,
1586    /// Name of the tree to read when loading ROOT files. When absent and the file contains a
1587    /// single tree, it will be selected automatically.
1588    pub tree: Option<String>,
1589    /// Optional aliases mapping logical names to selections of four-momenta.
1590    pub aliases: IndexMap<String, P4Selection>,
1591}
1592
1593/// Precision for writing floating-point columns.
1594#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
1595pub enum FloatPrecision {
1596    /// 32-bit floats.
1597    F32,
1598    /// 64-bit floats.
1599    #[default]
1600    F64,
1601}
1602
1603/// Options for writing a [`Dataset`] to disk.
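///
/// # Example
///
/// A minimal sketch of the builder API:
///
/// ```ignore
/// let options = DatasetWriteOptions::default()
///     .batch_size(50_000)
///     .precision(FloatPrecision::F32)
///     .tree("events");
/// ```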
1604#[derive(Clone, Debug)]
1605pub struct DatasetWriteOptions {
1606    /// Number of events to include in each batch when writing.
1607    pub batch_size: usize,
1608    /// Floating-point precision to use for persisted columns.
1609    pub precision: FloatPrecision,
1610    /// Tree name to use when writing ROOT files.
1611    pub tree: Option<String>,
1612}
1613
1614impl Default for DatasetWriteOptions {
1615    fn default() -> Self {
1616        Self {
1617            batch_size: DEFAULT_WRITE_BATCH_SIZE,
1618            precision: FloatPrecision::default(),
1619            tree: None,
1620        }
1621    }
1622}
1623
1624impl DatasetWriteOptions {
1625    /// Override the batch size used for writing; defaults to 10_000.
1626    pub fn batch_size(mut self, batch_size: usize) -> Self {
1627        self.batch_size = batch_size;
1628        self
1629    }
1630
1631    /// Select the floating-point precision for persisted columns.
1632    pub fn precision(mut self, precision: FloatPrecision) -> Self {
1633        self.precision = precision;
1634        self
1635    }
1636
1637    /// Set the ROOT tree name (defaults to "events").
1638    pub fn tree<S: Into<String>>(mut self, name: S) -> Self {
1639        self.tree = Some(name.into());
1640        self
1641    }
1642}

1643impl DatasetReadOptions {
1644    /// Create a new [`DatasetReadOptions`] with default values.
1645    pub fn new() -> Self {
1646        Self::default()
1647    }
1648
1649    /// If provided, the specified particles will be read from the data file (assuming columns with
1650    /// required suffixes are present, i.e. `<particle>_px`, `<particle>_py`, `<particle>_pz`, and `<particle>_e`). Otherwise, all valid columns with these suffixes will be read.
1651    pub fn p4_names<I, S>(mut self, names: I) -> Self
1652    where
1653        I: IntoIterator<Item = S>,
1654        S: AsRef<str>,
1655    {
1656        self.p4_names = Some(names.into_iter().map(|s| s.as_ref().to_string()).collect());
1657        self
1658    }
1659
1660    /// If provided, the specified columns will be read as auxiliary scalars. Otherwise, all valid
1661    /// float columns that are not consumed as four-momentum components (and are not named
1662    /// `weight`) will be used.
1663    pub fn aux_names<I, S>(mut self, names: I) -> Self
1664    where
1665        I: IntoIterator<Item = S>,
1666        S: AsRef<str>,
1667    {
1668        self.aux_names = Some(names.into_iter().map(|s| s.as_ref().to_string()).collect());
1669        self
1670    }
1671
1672    /// Select the tree to read when opening ROOT files.
1673    pub fn tree<S>(mut self, name: S) -> Self
1674    where
1675        S: AsRef<str>,
1676    {
1677        self.tree = Some(name.as_ref().to_string());
1678        self
1679    }
1680
1681    /// Register an alias for one or more existing four-momenta.
1682    pub fn alias<N, S>(mut self, name: N, selection: S) -> Self
1683    where
1684        N: Into<String>,
1685        S: IntoP4Selection,
1686    {
1687        self.aliases.insert(name.into(), selection.into_selection());
1688        self
1689    }
1690
1691    /// Register multiple aliases for four-momentum selections.
1692    pub fn aliases<I, N, S>(mut self, aliases: I) -> Self
1693    where
1694        I: IntoIterator<Item = (N, S)>,
1695        N: Into<String>,
1696        S: IntoP4Selection,
1697    {
1698        for (name, selection) in aliases {
1699            self = self.alias(name, selection);
1700        }
1701        self
1702    }
1703
1704    fn resolve_metadata(
1705        &self,
1706        detected_p4_names: Vec<String>,
1707        detected_aux_names: Vec<String>,
1708    ) -> LadduResult<Arc<DatasetMetadata>> {
1709        let p4_names_vec = self.p4_names.clone().unwrap_or(detected_p4_names);
1710        let aux_names_vec = self.aux_names.clone().unwrap_or(detected_aux_names);
1711
1712        let mut metadata = DatasetMetadata::new(p4_names_vec, aux_names_vec)?;
1713        if !self.aliases.is_empty() {
1714            metadata.add_p4_aliases(self.aliases.clone())?;
1715        }
1716        Ok(Arc::new(metadata))
1717    }
1718}
1719
1720const P4_COMPONENT_SUFFIXES: [&str; 4] = ["_px", "_py", "_pz", "_e"];
1721const DEFAULT_WRITE_BATCH_SIZE: usize = 10_000;
1722
1723fn detect_columns(file_path: &PathBuf) -> LadduResult<(Vec<String>, Vec<String>)> {
1724    let file = File::open(file_path)?;
1725    let builder = ParquetRecordBatchReaderBuilder::try_new(file)?;
1726    let schema = builder.schema();
1727    let float_cols: Vec<&str> = schema
1728        .fields()
1729        .iter()
1730        .filter(|f| matches!(f.data_type(), DataType::Float32 | DataType::Float64))
1731        .map(|f| f.name().as_str())
1732        .collect();
1733    Ok(infer_p4_and_aux_names(&float_cols))
1734}
1735
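// Example: given the float columns ["beam_px", "beam_py", "beam_pz", "beam_e",
// "pol_angle", "weight"], this infers p4 names ["beam"] and aux names
// ["pol_angle"]. A prefix only counts as a four-momentum when all four
// component suffixes are present, and "weight" is reserved for event weights.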
1736fn infer_p4_and_aux_names(float_cols: &[&str]) -> (Vec<String>, Vec<String>) {
1737    let suffix_set: IndexSet<&str> = P4_COMPONENT_SUFFIXES.iter().copied().collect();
1738    let mut groups: IndexMap<&str, IndexSet<&str>> = IndexMap::new();
1739    for col in float_cols {
1740        for suffix in &suffix_set {
1741            if let Some(prefix) = col.strip_suffix(suffix) {
1742                groups.entry(prefix).or_default().insert(*suffix);
1743            }
1744        }
1745    }
1746
1747    let mut p4_names: Vec<String> = Vec::new();
1748    let mut p4_columns: IndexSet<String> = IndexSet::new();
1749    for (prefix, suffixes) in &groups {
1750        if suffixes.len() == suffix_set.len() {
1751            p4_names.push((*prefix).to_string());
1752            for suffix in &suffix_set {
1753                p4_columns.insert(format!("{prefix}{suffix}"));
1754            }
1755        }
1756    }
1757
1758    let mut aux_names: Vec<String> = Vec::new();
1759    for col in float_cols {
1760        if p4_columns.contains(*col) {
1761            continue;
1762        }
1763        if col.eq_ignore_ascii_case("weight") {
1764            continue;
1765        }
1766        aux_names.push((*col).to_string());
1767    }
1768
1769    (p4_names, aux_names)
1770}
1771
1772type BranchLookup<'a> = IndexMap<&'a str, (RootScalarKind, &'a Branch)>;
1773
1774#[derive(Clone, Copy)]
1775enum RootScalarKind {
1776    F32,
1777    F64,
1778}
1779
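// Classify a branch by its ROOT item type name. Vector-valued branches and
// anything other than a single float/double are rejected with `None`.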
1780fn branch_scalar_kind(branch: &Branch) -> Option<RootScalarKind> {
1781    let type_name = branch.item_type_name();
1782    let lower = type_name.to_ascii_lowercase();
1783    if lower.contains("vector") {
1784        return None;
1785    }
1786    match lower.as_str() {
1787        "float" | "float_t" | "float32_t" => Some(RootScalarKind::F32),
1788        "double" | "double_t" | "double32_t" => Some(RootScalarKind::F64),
1789        _ => None,
1790    }
1791}
1792
1793fn read_branch_values<'a>(lookup: &BranchLookup<'a>, column_name: &str) -> LadduResult<Vec<f64>> {
1794    let (kind, branch) =
1795        lookup
1796            .get(column_name)
1797            .copied()
1798            .ok_or_else(|| LadduError::MissingColumn {
1799                name: column_name.to_string(),
1800            })?;
1801
1802    let values = match kind {
1803        RootScalarKind::F32 => branch
1804            .as_iter::<f32>()
1805            .map_err(|err| map_root_error(&format!("Failed to read branch '{column_name}'"), err))?
1806            .map(|value| value as f64)
1807            .collect(),
1808        RootScalarKind::F64 => branch
1809            .as_iter::<f64>()
1810            .map_err(|err| map_root_error(&format!("Failed to read branch '{column_name}'"), err))?
1811            .collect(),
1812    };
1813    Ok(values)
1814}
1815
1816fn read_branch_values_optional<'a>(
1817    lookup: &BranchLookup<'a>,
1818    column_name: &str,
1819) -> LadduResult<Option<Vec<f64>>> {
1820    if lookup.contains_key(column_name) {
1821        read_branch_values(lookup, column_name).map(Some)
1822    } else {
1823        Ok(None)
1824    }
1825}
1826
1827fn read_branch_values_from_candidates<'a>(
1828    lookup: &BranchLookup<'a>,
1829    candidates: &[String],
1830    logical_name: &str,
1831) -> LadduResult<Vec<f64>> {
1832    for candidate in candidates {
1833        if lookup.contains_key(candidate.as_str()) {
1834            return read_branch_values(lookup, candidate);
1835        }
1836    }
1837    Err(LadduError::MissingColumn {
1838        name: logical_name.to_string(),
1839    })
1840}
1841
1842fn resolve_root_tree(
1843    file: &mut RootFile,
1844    requested: Option<&str>,
1845) -> LadduResult<(ReaderTree, String)> {
1846    if let Some(name) = requested {
1847        let tree = file
1848            .get_tree(name)
1849            .map_err(|err| map_root_error(&format!("Failed to open ROOT tree '{name}'"), err))?;
1850        return Ok((tree, name.to_string()));
1851    }
1852
1853    let tree_names: Vec<String> = file
1854        .keys()
1855        .into_iter()
1856        .filter(|key| key.class_name() == "TTree")
1857        .map(|key| key.name().to_string())
1858        .collect();
1859
1860    if tree_names.is_empty() {
1861        return Err(LadduError::Custom(
1862            "ROOT file does not contain any TTrees".to_string(),
1863        ));
1864    }
1865
1866    if tree_names.len() > 1 {
1867        return Err(LadduError::Custom(format!(
1868            "Multiple TTrees found ({:?}); specify DatasetReadOptions::tree to disambiguate",
1869            tree_names
1870        )));
1871    }
1872
1873    let selected = &tree_names[0];
1874    let tree = file
1875        .get_tree(selected)
1876        .map_err(|err| map_root_error(&format!("Failed to open ROOT tree '{selected}'"), err))?;
1877    Ok((tree, selected.clone()))
1878}
1879
1880fn map_root_error<E: std::fmt::Display>(context: &str, err: E) -> LadduError {
1881    LadduError::Custom(format!("{context}: {err}")) // NOTE: the oxyroot error type is not public
1882}
1883
1884#[derive(Clone, Copy)]
1885enum FloatColumn<'a> {
1886    F32(&'a Float32Array),
1887    F64(&'a Float64Array),
1888}
1889
1890impl<'a> FloatColumn<'a> {
1891    fn value(&self, row: usize) -> f64 {
1892        match self {
1893            Self::F32(array) => array.value(row) as f64,
1894            Self::F64(array) => array.value(row),
1895        }
1896    }
1897}
1898
1899struct P4Columns<'a> {
1900    px: FloatColumn<'a>,
1901    py: FloatColumn<'a>,
1902    pz: FloatColumn<'a>,
1903    e: FloatColumn<'a>,
1904}
1905
1906fn prepare_float_column<'a>(batch: &'a RecordBatch, name: &str) -> LadduResult<FloatColumn<'a>> {
1907    prepare_float_column_from_candidates(batch, &[name.to_string()], name)
1908}
1909
1910fn prepare_p4_columns<'a>(batch: &'a RecordBatch, name: &str) -> LadduResult<P4Columns<'a>> {
1911    Ok(P4Columns {
1912        px: prepare_float_column_from_candidates(
1913            batch,
1914            &component_candidates(name, "px"),
1915            &format!("{name}_px"),
1916        )?,
1917        py: prepare_float_column_from_candidates(
1918            batch,
1919            &component_candidates(name, "py"),
1920            &format!("{name}_py"),
1921        )?,
1922        pz: prepare_float_column_from_candidates(
1923            batch,
1924            &component_candidates(name, "pz"),
1925            &format!("{name}_pz"),
1926        )?,
1927        e: prepare_float_column_from_candidates(
1928            batch,
1929            &component_candidates(name, "e"),
1930            &format!("{name}_e"),
1931        )?,
1932    })
1933}
1934
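// Candidate spellings for a four-momentum component column. For example,
// component_candidates("beam", "px") yields ["beam_px", "beam_Px", "beam_PX"];
// for "e", the capitalized and uppercase forms coincide, so only two
// candidates are produced.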
1935fn component_candidates(name: &str, suffix: &str) -> Vec<String> {
1936    let mut candidates = Vec::with_capacity(3);
1937    let base = format!("{name}_{suffix}");
1938    candidates.push(base.clone());
1939
1940    let mut capitalized = suffix.to_string();
1941    if let Some(first) = capitalized.get_mut(0..1) {
1942        first.make_ascii_uppercase();
1943    }
1944    if capitalized != suffix {
1945        candidates.push(format!("{name}_{capitalized}"));
1946    }
1947
1948    let upper = suffix.to_ascii_uppercase();
1949    if upper != suffix && upper != capitalized {
1950        candidates.push(format!("{name}_{upper}"));
1951    }
1952
1953    candidates
1954}
1955
1956fn find_float_column_from_candidates<'a>(
1957    batch: &'a RecordBatch,
1958    candidates: &[String],
1959) -> LadduResult<Option<FloatColumn<'a>>> {
1962    for candidate in candidates {
1963        if let Some(column) = batch.column_by_name(candidate) {
1964            return match column.data_type() {
1965                DataType::Float32 => Ok(Some(FloatColumn::F32(
1966                    column
1967                        .as_any()
1968                        .downcast_ref::<Float32Array>()
1969                        .expect("Column advertised as Float32 but could not be downcast"),
1970                ))),
1971                DataType::Float64 => Ok(Some(FloatColumn::F64(
1972                    column
1973                        .as_any()
1974                        .downcast_ref::<Float64Array>()
1975                        .expect("Column advertised as Float64 but could not be downcast"),
1976                ))),
1977                other => Err(LadduError::InvalidColumnType {
1978                    name: candidate.clone(),
1979                    datatype: other.to_string(),
1980                }),
1983            };
1984        }
1985    }
1986    Ok(None)
1987}
1988
1989fn prepare_float_column_from_candidates<'a>(
1990    batch: &'a RecordBatch,
1991    candidates: &[String],
1992    logical_name: &str,
1993) -> LadduResult<FloatColumn<'a>> {
1994    find_float_column_from_candidates(batch, candidates)?.ok_or_else(|| LadduError::MissingColumn {
1995        name: logical_name.to_string(),
1996    })
1997}
1998
1999fn record_batch_to_events(
2000    batch: RecordBatch,
2001    p4_names: &[String],
2002    aux_names: &[String],
2003) -> LadduResult<Vec<Arc<EventData>>> {
2004    let batch_ref = &batch;
2005    let p4_columns: Vec<P4Columns<'_>> = p4_names
2006        .iter()
2007        .map(|name| prepare_p4_columns(batch_ref, name))
2008        .collect::<Result<_, _>>()?;
2009
2010    let aux_columns: Vec<FloatColumn<'_>> = aux_names
2011        .iter()
2012        .map(|name| prepare_float_column(batch_ref, name))
2013        .collect::<Result<_, _>>()?;
2014
2015    let weight_column = find_float_column_from_candidates(batch_ref, &["weight".to_string()])?;
2016
2017    let mut events = Vec::with_capacity(batch_ref.num_rows());
2018    for row in 0..batch_ref.num_rows() {
2019        let mut p4s = Vec::with_capacity(p4_columns.len());
2020        for columns in &p4_columns {
2021            let px = columns.px.value(row);
2022            let py = columns.py.value(row);
2023            let pz = columns.pz.value(row);
2024            let e = columns.e.value(row);
2025            p4s.push(Vec4::new(px, py, pz, e));
2026        }
2027
2028        let mut aux = Vec::with_capacity(aux_columns.len());
2029        for column in &aux_columns {
2030            aux.push(column.value(row));
2031        }
2032
2033        let event = EventData {
2034            p4s,
2035            aux,
2036            weight: weight_column
2037                .as_ref()
2038                .map(|column| column.value(row))
2039                .unwrap_or(1.0),
2040        };
2041        events.push(Arc::new(event));
2042    }
2043    Ok(events)
2044}
2045
2046struct ColumnBuffers {
2047    p4: Vec<P4Buffer>,
2048    aux: Vec<Vec<f64>>,
2049    weight: Vec<f64>,
2050}
2051
2052impl ColumnBuffers {
2053    fn new(n_p4: usize, n_aux: usize) -> Self {
2054        let p4 = (0..n_p4).map(|_| P4Buffer::default()).collect();
2055        let aux = vec![Vec::new(); n_aux];
2056        Self {
2057            p4,
2058            aux,
2059            weight: Vec::new(),
2060        }
2061    }
2062
2063    fn push_event(&mut self, event: &Event) {
2064        for (buffer, p4) in self.p4.iter_mut().zip(event.p4s.iter()) {
2065            buffer.px.push(p4.x);
2066            buffer.py.push(p4.y);
2067            buffer.pz.push(p4.z);
2068            buffer.e.push(p4.t);
2069        }
2070
2071        for (buffer, value) in self.aux.iter_mut().zip(event.aux.iter()) {
2072            buffer.push(*value);
2073        }
2074
2075        self.weight.push(event.weight);
2076    }
2077
2078    fn into_record_batch(
2079        self,
2080        schema: Arc<Schema>,
2081        precision: FloatPrecision,
2082    ) -> arrow::error::Result<RecordBatch> {
2083        let mut columns: Vec<arrow::array::ArrayRef> = Vec::new();
2084
2085        match precision {
2086            FloatPrecision::F64 => {
2087                for buffer in &self.p4 {
2088                    columns.push(Arc::new(Float64Array::from(buffer.px.clone())));
2089                    columns.push(Arc::new(Float64Array::from(buffer.py.clone())));
2090                    columns.push(Arc::new(Float64Array::from(buffer.pz.clone())));
2091                    columns.push(Arc::new(Float64Array::from(buffer.e.clone())));
2092                }
2093
2094                for buffer in &self.aux {
2095                    columns.push(Arc::new(Float64Array::from(buffer.clone())));
2096                }
2097
2098                columns.push(Arc::new(Float64Array::from(self.weight)));
2099            }
2100            FloatPrecision::F32 => {
2101                for buffer in &self.p4 {
2102                    columns.push(Arc::new(Float32Array::from(
2103                        buffer.px.iter().map(|v| *v as f32).collect::<Vec<_>>(),
2104                    )));
2105                    columns.push(Arc::new(Float32Array::from(
2106                        buffer.py.iter().map(|v| *v as f32).collect::<Vec<_>>(),
2107                    )));
2108                    columns.push(Arc::new(Float32Array::from(
2109                        buffer.pz.iter().map(|v| *v as f32).collect::<Vec<_>>(),
2110                    )));
2111                    columns.push(Arc::new(Float32Array::from(
2112                        buffer.e.iter().map(|v| *v as f32).collect::<Vec<_>>(),
2113                    )));
2114                }
2115
2116                for buffer in &self.aux {
2117                    columns.push(Arc::new(Float32Array::from(
2118                        buffer.iter().map(|v| *v as f32).collect::<Vec<_>>(),
2119                    )));
2120                }
2121
2122                columns.push(Arc::new(Float32Array::from(
2123                    self.weight.iter().map(|v| *v as f32).collect::<Vec<_>>(),
2124                )));
2125            }
2126        }
2127
2128        RecordBatch::try_new(schema, columns)
2129    }
2130}
2131
2132#[derive(Default)]
2133struct P4Buffer {
2134    px: Vec<f64>,
2135    py: Vec<f64>,
2136    pz: Vec<f64>,
2137    e: Vec<f64>,
2138}
2139
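// Persisted column order: the four component columns for each p4 name
// (`<name>_px`, `<name>_py`, `<name>_pz`, `<name>_e`), then one column per
// aux scalar, then a trailing "weight" column; all fields are non-nullable.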
2140fn build_parquet_schema(metadata: &DatasetMetadata, precision: FloatPrecision) -> Schema {
2141    let dtype = match precision {
2142        FloatPrecision::F64 => DataType::Float64,
2143        FloatPrecision::F32 => DataType::Float32,
2144    };
2145
2146    let mut fields = Vec::new();
2147    for name in &metadata.p4_names {
2148        for suffix in P4_COMPONENT_SUFFIXES {
2149            fields.push(Field::new(format!("{name}{suffix}"), dtype.clone(), false));
2150        }
2151    }
2152
2153    for name in &metadata.aux_names {
2154        fields.push(Field::new(name.clone(), dtype.clone(), false));
2155    }
2156
2157    fields.push(Field::new("weight", dtype, false));
2158    Schema::new(fields)
2159}
2160
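// In-memory values are always f64; this helper lets the ROOT writer emit
// either f32 or f64 branches from the same column iterators.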
2161trait FromF64 {
2162    fn from_f64(value: f64) -> Self;
2163}
2164
2165impl FromF64 for f64 {
2166    fn from_f64(value: f64) -> Self {
2167        value
2168    }
2169}
2170
2171impl FromF64 for f32 {
2172    fn from_f64(value: f64) -> Self {
2173        value as f32
2174    }
2175}
2176
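// Shared cursor used when writing ROOT branches: each output column is
// drained by its own ColumnIterator, but all of them walk the same events in
// lockstep. The fetcher caches the event for the current row and hands it out
// `branch_count` times before releasing it, so each event is fetched (locally
// or over MPI) exactly once per row.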
2177struct SharedEventFetcher {
2178    dataset: Arc<Dataset>,
2179    world: Option<WorldHandle>,
2180    total: usize,
2181    branch_count: usize,
2182    current_index: Option<usize>,
2183    current_event: Option<Event>,
2184    remaining: usize,
2185}
2186
2187impl SharedEventFetcher {
2188    fn new(
2189        dataset: Arc<Dataset>,
2190        world: Option<WorldHandle>,
2191        total: usize,
2192        branch_count: usize,
2193    ) -> Self {
2194        Self {
2195            dataset,
2196            world,
2197            total,
2198            branch_count,
2199            current_index: None,
2200            current_event: None,
2201            remaining: 0,
2202        }
2203    }
2204
2205    fn event_for_index(&mut self, index: usize) -> Option<Event> {
2206        if index >= self.total {
2207            return None;
2208        }
2209
2210        let refresh_needed = match self.current_index {
2211            None => true,
2212            Some(current) => current != index || self.remaining == 0,
2213        };
2214
2215        if refresh_needed {
2216            let event =
2217                fetch_event_for_index(&self.dataset, index, self.total, self.world.as_ref());
2218            self.current_index = Some(index);
2219            self.remaining = self.branch_count;
2220            self.current_event = Some(event);
2221        }
2222
2223        let event = self.current_event.clone();
2224        if self.remaining > 0 {
2225            self.remaining -= 1;
2226        }
2227        if self.remaining == 0 {
2228            // Drop the cached event so the next request fetches the next index.
2229            self.current_event = None;
2230        }
2231        event
2232    }
2233}
2234
2235enum ColumnKind {
2236    Px(usize),
2237    Py(usize),
2238    Pz(usize),
2239    E(usize),
2240    Aux(usize),
2241    Weight,
2242}
2243
2244struct ColumnIterator<T> {
2245    fetcher: Arc<Mutex<SharedEventFetcher>>,
2246    index: usize,
2247    kind: ColumnKind,
2248    _marker: std::marker::PhantomData<T>,
2249}
2250
2251impl<T> ColumnIterator<T> {
2252    fn new(fetcher: Arc<Mutex<SharedEventFetcher>>, kind: ColumnKind) -> Self {
2253        Self {
2254            fetcher,
2255            index: 0,
2256            kind,
2257            _marker: std::marker::PhantomData,
2258        }
2259    }
2260}
2261
2262impl<T> Iterator for ColumnIterator<T>
2263where
2264    T: FromF64,
2265{
2266    type Item = T;
2267
2268    fn next(&mut self) -> Option<Self::Item> {
2269        let mut fetcher = self.fetcher.lock();
2270        let event = fetcher.event_for_index(self.index)?;
2271        self.index += 1;
2272
2273        match self.kind {
2274            ColumnKind::Px(idx) => event.p4s.get(idx).map(|p4| T::from_f64(p4.x)),
2275            ColumnKind::Py(idx) => event.p4s.get(idx).map(|p4| T::from_f64(p4.y)),
2276            ColumnKind::Pz(idx) => event.p4s.get(idx).map(|p4| T::from_f64(p4.z)),
2277            ColumnKind::E(idx) => event.p4s.get(idx).map(|p4| T::from_f64(p4.t)),
2278            ColumnKind::Aux(idx) => event.aux.get(idx).map(|value| T::from_f64(*value)),
2279            ColumnKind::Weight => Some(T::from_f64(event.weight)),
2280        }
2281    }
2282}
2283
2284fn build_root_column_iterators<T>(
2285    dataset: Arc<Dataset>,
2286    world: Option<WorldHandle>,
2287    branch_count: usize,
2288    total: usize,
2289) -> Vec<(String, ColumnIterator<T>)>
2290where
2291    T: FromF64,
2292{
2293    let fetcher = Arc::new(Mutex::new(SharedEventFetcher::new(
2294        dataset,
2295        world,
2296        total,
2297        branch_count,
2298    )));
2299
2300    let p4_names: Vec<String> = fetcher.lock().dataset.metadata.p4_names.clone();
2301    let aux_names: Vec<String> = fetcher.lock().dataset.metadata.aux_names.clone();
2302
2303    let mut iterators = Vec::new();
2304
2305    for (idx, name) in p4_names.iter().enumerate() {
2306        iterators.push((
2307            format!("{name}_px"),
2308            ColumnIterator::new(fetcher.clone(), ColumnKind::Px(idx)),
2309        ));
2310        iterators.push((
2311            format!("{name}_py"),
2312            ColumnIterator::new(fetcher.clone(), ColumnKind::Py(idx)),
2313        ));
2314        iterators.push((
2315            format!("{name}_pz"),
2316            ColumnIterator::new(fetcher.clone(), ColumnKind::Pz(idx)),
2317        ));
2318        iterators.push((
2319            format!("{name}_e"),
2320            ColumnIterator::new(fetcher.clone(), ColumnKind::E(idx)),
2321        ));
2322    }
2323
2324    for (idx, name) in aux_names.iter().enumerate() {
2325        iterators.push((
2326            name.clone(),
2327            ColumnIterator::new(fetcher.clone(), ColumnKind::Aux(idx)),
2328        ));
2329    }
2330
2331    iterators.push((
2332        "weight".to_string(),
2333        ColumnIterator::new(fetcher, ColumnKind::Weight),
2334    ));
2335
2336    iterators
2337}
2338
2339fn drain_column_iterators<T>(iterators: &mut [(String, ColumnIterator<T>)], n_events: usize)
2340where
2341    T: FromF64,
2342{
2343    for _ in 0..n_events {
2344        for (_name, iterator) in iterators.iter_mut() {
2345            let _ = iterator.next();
2346        }
2347    }
2348}
2349
2350fn fetch_event_for_index(
2351    dataset: &Dataset,
2352    index: usize,
2353    total: usize,
2354    world: Option<&WorldHandle>,
2355) -> Event {
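    // The `let _` bindings below mark the parameters as used so that builds
    // without the "mpi" feature compile without unused-variable warnings.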
2356    let _ = total;
2357    let _ = world;
2358    #[cfg(feature = "mpi")]
2359    {
2360        if let Some(world) = world {
2361            return fetch_event_mpi(dataset, index, world, total);
2362        }
2363    }
2364
2365    dataset.index_local(index).clone()
2366}
2367
2368impl Dataset {
2369    #[allow(clippy::too_many_arguments)]
2370    fn write_root_with_type<T>(
2371        &self,
2372        dataset: Arc<Dataset>,
2373        world: Option<WorldHandle>,
2374        is_root: bool,
2375        file_path: &Path,
2376        tree_name: &str,
2377        branch_count: usize,
2378        total_events: usize,
2379    ) -> LadduResult<()>
2380    where
2381        T: FromF64 + oxyroot::Marshaler + 'static,
2382    {
2383        let mut iterators =
2384            build_root_column_iterators::<T>(dataset, world, branch_count, total_events);
2385
2386        if is_root {
2387            let mut file = RootFile::create(file_path).map_err(|err| {
2388                LadduError::Custom(format!(
2389                    "Failed to create ROOT file '{}': {err}",
2390                    file_path.display()
2391                ))
2392            })?;
2393
2394            let mut tree = WriterTree::new(tree_name);
2395            for (name, iterator) in iterators {
2396                tree.new_branch(name, iterator);
2397            }
2398
2399            tree.write(&mut file).map_err(|err| {
2400                LadduError::Custom(format!(
2401                    "Failed to write ROOT tree '{tree_name}' to '{}': {err}",
2402                    file_path.display()
2403                ))
2404            })?;
2405
2406            file.close().map_err(|err| {
2407                LadduError::Custom(format!(
2408                    "Failed to close ROOT file '{}': {err}",
2409                    file_path.display()
2410                ))
2411            })?;
2412        } else {
2413            drain_column_iterators(&mut iterators, total_events);
2414        }
2415
2416        Ok(())
2417    }
2418}
2419
2420/// A list of [`Dataset`]s formed by binning [`EventData`] by some [`Variable`].
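///
/// # Example
///
/// A minimal sketch, assuming `dataset` is an existing [`Dataset`] and
/// `BeamEnergy` is a user-defined [`Variable`] (see the tests below for a
/// complete definition):
///
/// ```ignore
/// let binned = dataset.bin_by(BeamEnergy, 2, (0.0, 3.0))?;
/// assert_eq!(binned.n_bins(), 2);
/// assert_eq!(binned.edges().len(), 3); // n_bins + 1 edges
/// ```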
2421pub struct BinnedDataset {
2422    datasets: Vec<Arc<Dataset>>,
2423    edges: Vec<f64>,
2424}
2425
2426impl Index<usize> for BinnedDataset {
2427    type Output = Arc<Dataset>;
2428
2429    fn index(&self, index: usize) -> &Self::Output {
2430        &self.datasets[index]
2431    }
2432}
2433
2434impl IndexMut<usize> for BinnedDataset {
2435    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
2436        &mut self.datasets[index]
2437    }
2438}
2439
2440impl Deref for BinnedDataset {
2441    type Target = Vec<Arc<Dataset>>;
2442
2443    fn deref(&self) -> &Self::Target {
2444        &self.datasets
2445    }
2446}
2447
2448impl DerefMut for BinnedDataset {
2449    fn deref_mut(&mut self) -> &mut Self::Target {
2450        &mut self.datasets
2451    }
2452}
2453
2454impl BinnedDataset {
2455    /// The number of bins in the [`BinnedDataset`].
2456    pub fn n_bins(&self) -> usize {
2457        self.datasets.len()
2458    }
2459
2460    /// Returns a list of the bin edges that were used to form the [`BinnedDataset`].
2461    pub fn edges(&self) -> Vec<f64> {
2462        self.edges.clone()
2463    }
2464
2465    /// Returns the range that was used to form the [`BinnedDataset`].
2466    pub fn range(&self) -> (f64, f64) {
2467        (self.edges[0], self.edges[self.n_bins()])
2468    }
2469}
2470
2471#[cfg(test)]
2472mod tests {
2473    use crate::Mass;
2474
2475    use super::*;
2476    use crate::utils::vectors::Vec3;
2477    use approx::{assert_relative_eq, assert_relative_ne};
2478    use fastrand;
2479    use serde::{Deserialize, Serialize};
2480    use std::{
2481        env, fs,
2482        path::{Path, PathBuf},
2483    };
2484
2485    fn test_data_path(file: &str) -> PathBuf {
2486        Path::new(env!("CARGO_MANIFEST_DIR"))
2487            .join("test_data")
2488            .join(file)
2489    }
2490
2491    fn open_test_dataset(file: &str, options: DatasetReadOptions) -> Arc<Dataset> {
2492        let path = test_data_path(file);
2493        let path_str = path.to_str().expect("test data path should be valid UTF-8");
2494        let ext = path
2495            .extension()
2496            .and_then(|ext| ext.to_str())
2497            .unwrap_or_default()
2498            .to_ascii_lowercase();
2499        match ext.as_str() {
2500            "parquet" => read_parquet(path_str, &options),
2501            "root" => read_root(path_str, &options),
2502            other => panic!("Unsupported extension in test data: {other}"),
2503        }
2504        .expect("dataset should open")
2505    }
2506
2507    fn make_temp_dir() -> PathBuf {
2508        let dir = env::temp_dir().join(format!("laddu_test_{}", fastrand::u64(..)));
2509        fs::create_dir(&dir).expect("temp dir should be created");
2510        dir
2511    }
2512
2513    fn assert_events_close(left: &Event, right: &Event, p4_names: &[&str], aux_names: &[&str]) {
2514        for name in p4_names {
2515            let lp4 = left
2516                .p4(name)
2517                .unwrap_or_else(|| panic!("missing p4 '{name}' in left dataset"));
2518            let rp4 = right
2519                .p4(name)
2520                .unwrap_or_else(|| panic!("missing p4 '{name}' in right dataset"));
2521            assert_relative_eq!(lp4.px(), rp4.px(), epsilon = 1e-9);
2522            assert_relative_eq!(lp4.py(), rp4.py(), epsilon = 1e-9);
2523            assert_relative_eq!(lp4.pz(), rp4.pz(), epsilon = 1e-9);
2524            assert_relative_eq!(lp4.e(), rp4.e(), epsilon = 1e-9);
2525        }
2526        let left_aux = left.aux();
2527        let right_aux = right.aux();
2528        for name in aux_names {
2529            let laux = left_aux
2530                .get(name)
2531                .copied()
2532                .unwrap_or_else(|| panic!("missing aux '{name}' in left dataset"));
2533            let raux = right_aux
2534                .get(name)
2535                .copied()
2536                .unwrap_or_else(|| panic!("missing aux '{name}' in right dataset"));
2537            assert_relative_eq!(laux, raux, epsilon = 1e-9);
2538        }
2539        assert_relative_eq!(left.weight(), right.weight(), epsilon = 1e-9);
2540    }
2541
2542    fn assert_datasets_close(
2543        left: &Arc<Dataset>,
2544        right: &Arc<Dataset>,
2545        p4_names: &[&str],
2546        aux_names: &[&str],
2547    ) {
2548        assert_eq!(left.n_events(), right.n_events());
2549        for idx in 0..left.n_events() {
2550            let levent = &left[idx];
2551            let revent = &right[idx];
2552            assert_events_close(levent, revent, p4_names, aux_names);
2553        }
2554    }
2555
2556    #[test]
2557    fn test_from_parquet_auto_matches_explicit_names() {
2558        let auto = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
2559        let explicit_options = DatasetReadOptions::new()
2560            .p4_names(TEST_P4_NAMES)
2561            .aux_names(TEST_AUX_NAMES);
2562        let explicit = open_test_dataset("data_f32.parquet", explicit_options);
2563
2564        let mut detected_p4: Vec<&str> = auto.p4_names().iter().map(String::as_str).collect();
2565        detected_p4.sort_unstable();
2566        let mut expected_p4 = TEST_P4_NAMES.to_vec();
2567        expected_p4.sort_unstable();
2568        assert_eq!(detected_p4, expected_p4);
2569        let mut detected_aux: Vec<&str> = auto.aux_names().iter().map(String::as_str).collect();
2570        detected_aux.sort_unstable();
2571        let mut expected_aux = TEST_AUX_NAMES.to_vec();
2572        expected_aux.sort_unstable();
2573        assert_eq!(detected_aux, expected_aux);
2574        assert_datasets_close(&auto, &explicit, TEST_P4_NAMES, TEST_AUX_NAMES);
2575    }
2576
2577    #[test]
2578    fn test_from_parquet_with_aliases() {
2579        let dataset = open_test_dataset(
2580            "data_f32.parquet",
2581            DatasetReadOptions::new().alias("resonance", ["kshort1", "kshort2"]),
2582        );
2583        let event = dataset.named_event(0);
2584        let alias_vec = event.p4("resonance").expect("alias vector");
2585        let expected = event.get_p4_sum(["kshort1", "kshort2"]);
2586        assert_relative_eq!(alias_vec.px(), expected.px(), epsilon = 1e-9);
2587        assert_relative_eq!(alias_vec.py(), expected.py(), epsilon = 1e-9);
2588        assert_relative_eq!(alias_vec.pz(), expected.pz(), epsilon = 1e-9);
2589        assert_relative_eq!(alias_vec.e(), expected.e(), epsilon = 1e-9);
2590    }
2591
2592    #[test]
2593    fn test_from_parquet_f64_matches_f32() {
2594        let f32_ds = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
2595        let f64_ds = open_test_dataset("data_f64.parquet", DatasetReadOptions::new());
2596        assert_datasets_close(&f64_ds, &f32_ds, TEST_P4_NAMES, TEST_AUX_NAMES);
2597    }
2598
2599    #[test]
2600    fn test_from_root_detects_columns_and_matches_parquet() {
2601        let parquet = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
2602        let root_auto = open_test_dataset("data_f32.root", DatasetReadOptions::new());
2603        let mut detected_p4: Vec<&str> = root_auto.p4_names().iter().map(String::as_str).collect();
2604        detected_p4.sort_unstable();
2605        let mut expected_p4 = TEST_P4_NAMES.to_vec();
2606        expected_p4.sort_unstable();
2607        assert_eq!(detected_p4, expected_p4);
2608        let mut detected_aux: Vec<&str> =
2609            root_auto.aux_names().iter().map(String::as_str).collect();
2610        detected_aux.sort_unstable();
2611        let mut expected_aux = TEST_AUX_NAMES.to_vec();
2612        expected_aux.sort_unstable();
2613        assert_eq!(detected_aux, expected_aux);
2614        let root_named_options = DatasetReadOptions::new()
2615            .p4_names(TEST_P4_NAMES)
2616            .aux_names(TEST_AUX_NAMES);
2617        let root_named = open_test_dataset("data_f32.root", root_named_options);
2618        assert_datasets_close(&root_auto, &root_named, TEST_P4_NAMES, TEST_AUX_NAMES);
2619        assert_datasets_close(&root_auto, &parquet, TEST_P4_NAMES, TEST_AUX_NAMES);
2620    }
2621
2622    #[test]
2623    fn test_from_root_f64_matches_parquet() {
2624        let parquet = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
2625        let root_f64 = open_test_dataset("data_f64.root", DatasetReadOptions::new());
2626        assert_datasets_close(&root_f64, &parquet, TEST_P4_NAMES, TEST_AUX_NAMES);
2627    }

2628    #[test]
2629    fn test_event_creation() {
2630        let event = test_event();
2631        assert_eq!(event.p4s.len(), 4);
2632        assert_eq!(event.aux.len(), 2);
2633        assert_relative_eq!(event.weight, 0.48)
2634    }
2635
2636    #[test]
2637    fn test_event_p4_sum() {
2638        let event = test_event();
2639        let sum = event.get_p4_sum([2, 3]);
2640        assert_relative_eq!(sum.px(), event.p4s[2].px() + event.p4s[3].px());
2641        assert_relative_eq!(sum.py(), event.p4s[2].py() + event.p4s[3].py());
2642        assert_relative_eq!(sum.pz(), event.p4s[2].pz() + event.p4s[3].pz());
2643        assert_relative_eq!(sum.e(), event.p4s[2].e() + event.p4s[3].e());
2644    }
2645
2646    #[test]
2647    fn test_event_boost() {
2648        let event = test_event();
2649        let event_boosted = event.boost_to_rest_frame_of([1, 2, 3]);
2650        let p4_sum = event_boosted.get_p4_sum([1, 2, 3]);
2651        assert_relative_eq!(p4_sum.px(), 0.0);
2652        assert_relative_eq!(p4_sum.py(), 0.0);
2653        assert_relative_eq!(p4_sum.pz(), 0.0, epsilon = f64::EPSILON.sqrt());
2654    }
2655
2656    #[test]
2657    fn test_event_evaluate() {
2658        let event = test_event();
2659        let mut mass = Mass::new(["proton"]);
2660        mass.bind(
2661            &DatasetMetadata::new(
2662                TEST_P4_NAMES.iter().map(|s| (*s).to_string()).collect(),
2663                TEST_AUX_NAMES.iter().map(|s| (*s).to_string()).collect(),
2664            )
2665            .expect("metadata"),
2666        )
2667        .unwrap();
2668        assert_relative_eq!(event.evaluate(&mass), 1.007);
2669    }
2670
2671    #[test]
2672    fn test_dataset_size_check() {
2673        let dataset = Dataset::new(Vec::new());
2674        assert_eq!(dataset.n_events(), 0);
2675        let dataset = Dataset::new(vec![Arc::new(test_event())]);
2676        assert_eq!(dataset.n_events(), 1);
2677    }
2678
2679    #[test]
2680    fn test_dataset_sum() {
2681        let dataset = test_dataset();
2682        let metadata = dataset.metadata_arc();
2683        let dataset2 = Dataset::new_with_metadata(
2684            vec![Arc::new(EventData {
2685                p4s: test_event().p4s,
2686                aux: test_event().aux,
2687                weight: 0.52,
2688            })],
2689            metadata.clone(),
2690        );
2691        let dataset_sum = &dataset + &dataset2;
2692        assert_eq!(dataset_sum[0].weight, dataset[0].weight);
2693        assert_eq!(dataset_sum[1].weight, dataset2[0].weight);
2694    }
2695
2696    #[test]
2697    fn test_dataset_weights() {
2698        let dataset = Dataset::new(vec![
2699            Arc::new(test_event()),
2700            Arc::new(EventData {
2701                p4s: test_event().p4s,
2702                aux: test_event().aux,
2703                weight: 0.52,
2704            }),
2705        ]);
2706        let weights = dataset.weights();
2707        assert_eq!(weights.len(), 2);
2708        assert_relative_eq!(weights[0], 0.48);
2709        assert_relative_eq!(weights[1], 0.52);
2710        assert_relative_eq!(dataset.n_events_weighted(), 1.0);
2711    }
2712
2713    #[test]
2714    fn test_dataset_filtering() {
2715        let metadata = Arc::new(
2716            DatasetMetadata::new(vec!["beam"], Vec::<String>::new())
2717                .expect("metadata should be valid"),
2718        );
2719        let events = vec![
2720            Arc::new(EventData {
2721                p4s: vec![Vec3::new(0.0, 0.0, 5.0).with_mass(0.0)],
2722                aux: vec![],
2723                weight: 1.0,
2724            }),
2725            Arc::new(EventData {
2726                p4s: vec![Vec3::new(0.0, 0.0, 5.0).with_mass(0.5)],
2727                aux: vec![],
2728                weight: 1.0,
2729            }),
2730            Arc::new(EventData {
2731                p4s: vec![Vec3::new(0.0, 0.0, 5.0).with_mass(1.1)],
2732                // HACK: a mass of 1.0 would break this test: the mass reconstructed from the stored
2733                // four-momentum comes out slightly below 1.0 and would pass the `lt(1.0)` cut
2734                aux: vec![],
2735                weight: 1.0,
2736            }),
2737        ];
2738        let dataset = Dataset::new_with_metadata(events, metadata);
2739
2740        let metadata = dataset.metadata_arc();
2741        let mut mass = Mass::new(["beam"]);
2742        mass.bind(metadata.as_ref()).unwrap();
2743        let expression = mass.gt(0.0).and(&mass.lt(1.0));
2744
2745        let filtered = dataset.filter(&expression).unwrap();
2746        assert_eq!(filtered.n_events(), 1);
2747        assert_relative_eq!(mass.value(&filtered[0]), 0.5);
2748    }
2749
2750    #[test]
2751    fn test_dataset_boost() {
2752        let dataset = test_dataset();
2753        let dataset_boosted = dataset.boost_to_rest_frame_of(&["proton", "kshort1", "kshort2"]);
2754        let p4_sum = dataset_boosted[0].get_p4_sum(["proton", "kshort1", "kshort2"]);
2755        assert_relative_eq!(p4_sum.px(), 0.0);
2756        assert_relative_eq!(p4_sum.py(), 0.0);
2757        assert_relative_eq!(p4_sum.pz(), 0.0, epsilon = f64::EPSILON.sqrt());
2758    }
2759
2760    #[test]
2761    fn test_named_event_view() {
2762        let dataset = test_dataset();
2763        let view = dataset.named_event(0);
2764
2765        assert_relative_eq!(view.weight(), dataset[0].weight);
2766        let beam = view.p4("beam").expect("beam p4");
2767        assert_relative_eq!(beam.px(), dataset[0].p4s[0].px());
2768        assert_relative_eq!(beam.e(), dataset[0].p4s[0].e());
2769
2770        let summed = view.get_p4_sum(["kshort1", "kshort2"]);
2771        assert_relative_eq!(summed.e(), dataset[0].p4s[2].e() + dataset[0].p4s[3].e());
2772
2773        let aux_angle = view.aux().get("pol_angle").copied().expect("pol angle");
2774        assert_relative_eq!(aux_angle, dataset[0].aux[1]);
2775
2776        let metadata = dataset.metadata_arc();
2777        let boosted = view.boost_to_rest_frame_of(["proton", "kshort1", "kshort2"]);
2778        let boosted_event = Event::new(Arc::new(boosted), metadata);
2779        let boosted_sum = boosted_event.get_p4_sum(["proton", "kshort1", "kshort2"]);
2780        assert_relative_eq!(boosted_sum.px(), 0.0);
2781    }
2782
2783    #[test]
2784    fn test_dataset_evaluate() {
2785        let dataset = test_dataset();
2786        let mass = Mass::new(["proton"]);
2787        assert_relative_eq!(dataset.evaluate(&mass).unwrap()[0], 1.007);
2788    }
2789
2790    #[test]
2791    fn test_dataset_metadata_rejects_duplicate_names() {
2792        let err = DatasetMetadata::new(vec!["beam", "beam"], Vec::<String>::new());
2793        assert!(matches!(
2794            err,
2795            Err(LadduError::DuplicateName { category, .. }) if category == "p4"
2796        ));
2797        let err = DatasetMetadata::new(
2798            vec!["beam"],
2799            vec!["pol_angle".to_string(), "pol_angle".to_string()],
2800        );
2801        assert!(matches!(
2802            err,
2803            Err(LadduError::DuplicateName { category, .. }) if category == "aux"
2804        ));
2805    }
2806
2807    #[test]
2808    fn test_dataset_lookup_by_name() {
2809        let dataset = test_dataset();
2810        let proton = dataset.p4_by_name(0, "proton").expect("proton p4");
2811        let proton_idx = dataset.metadata().p4_index("proton").unwrap();
2812        assert_relative_eq!(proton.e(), dataset[0].p4s[proton_idx].e());
2813        assert!(dataset.p4_by_name(0, "unknown").is_none());
2814        let angle = dataset.aux_by_name(0, "pol_angle").expect("pol_angle");
2815        assert_relative_eq!(angle, dataset[0].aux[1]);
2816        assert!(dataset.aux_by_name(0, "missing").is_none());
2817    }
2818
2819    #[test]
2820    fn test_binned_dataset() {
2821        let dataset = Dataset::new(vec![
2822            Arc::new(EventData {
2823                p4s: vec![Vec3::new(0.0, 0.0, 1.0).with_mass(1.0)],
2824                aux: vec![],
2825                weight: 1.0,
2826            }),
2827            Arc::new(EventData {
2828                p4s: vec![Vec3::new(0.0, 0.0, 2.0).with_mass(2.0)],
2829                aux: vec![],
2830                weight: 2.0,
2831            }),
2832        ]);
2833
2834        #[derive(Clone, Serialize, Deserialize, Debug)]
2835        struct BeamEnergy;
2836        impl Display for BeamEnergy {
2837            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2838                write!(f, "BeamEnergy")
2839            }
2840        }
2841        #[typetag::serde]
2842        impl Variable for BeamEnergy {
2843            fn value(&self, event: &EventData) -> f64 {
2844                event.p4s[0].e()
2845            }
2846        }
2847        assert_eq!(BeamEnergy.to_string(), "BeamEnergy");
2848
2849        // Test binning by first particle energy
2850        let binned = dataset.bin_by(BeamEnergy, 2, (0.0, 3.0)).unwrap();
2851
2852        assert_eq!(binned.n_bins(), 2);
2853        assert_eq!(binned.edges().len(), 3);
2854        assert_relative_eq!(binned.edges()[0], 0.0);
2855        assert_relative_eq!(binned.edges()[2], 3.0);
2856        assert_eq!(binned[0].n_events(), 1);
2857        assert_relative_eq!(binned[0].n_events_weighted(), 1.0);
2858        assert_eq!(binned[1].n_events(), 1);
2859        assert_relative_eq!(binned[1].n_events_weighted(), 2.0);
2860    }
2861
2862    #[test]
2863    fn test_dataset_bootstrap() {
2864        let metadata = test_dataset().metadata_arc();
2865        let dataset = Dataset::new_with_metadata(
2866            vec![
2867                Arc::new(test_event()),
2868                Arc::new(EventData {
2869                    p4s: test_event().p4s.clone(),
2870                    aux: test_event().aux.clone(),
2871                    weight: 1.0,
2872                }),
2873            ],
2874            metadata,
2875        );
2876        assert_relative_ne!(dataset[0].weight, dataset[1].weight);
2877
2878        let bootstrapped = dataset.bootstrap(43);
2879        assert_eq!(bootstrapped.n_events(), dataset.n_events());
2880        assert_relative_eq!(bootstrapped[0].weight, bootstrapped[1].weight);
2881
2882        // Test empty dataset bootstrap
2883        let empty_dataset = Dataset::new(Vec::new());
2884        let empty_bootstrap = empty_dataset.bootstrap(43);
2885        assert_eq!(empty_bootstrap.n_events(), 0);
2886    }
2887
2888    #[test]
2889    fn test_dataset_iteration_returns_events() {
2890        let dataset = test_dataset();
2891        let mut weights = Vec::new();
2892        for event in dataset.iter() {
2893            weights.push(event.weight());
2894        }
2895        assert_eq!(weights.len(), dataset.n_events());
2896        assert_relative_eq!(weights[0], dataset[0].weight);
2897    }
2898
2899    #[test]
2900    fn test_dataset_into_iter_returns_events() {
2901        let dataset = test_dataset();
2902        let weights: Vec<f64> = dataset.into_iter().map(|event| event.weight()).collect();
2903        assert_eq!(weights.len(), 1);
2904        assert_relative_eq!(weights[0], test_event().weight);
2905    }

2906    #[test]
2907    fn test_event_display() {
2908        let event = test_event();
2909        let display_string = format!("{}", event);
2910        assert!(display_string.contains("Event:"));
2911        assert!(display_string.contains("p4s:"));
2912        assert!(display_string.contains("aux:"));
2913        assert!(display_string.contains("aux[0]: 0.38562805"));
2914        assert!(display_string.contains("aux[1]: 0.05708078"));
2915        assert!(display_string.contains("weight:"));
2916    }
2917
2918    #[test]
2919    fn test_name_based_access() {
2920        let metadata =
2921            Arc::new(DatasetMetadata::new(vec!["beam", "target"], vec!["pol_angle"]).unwrap());
2922        let event = Arc::new(EventData {
2923            p4s: vec![Vec4::new(0.0, 0.0, 1.0, 1.0), Vec4::new(0.1, 0.2, 0.3, 0.5)],
2924            aux: vec![0.42],
2925            weight: 1.0,
2926        });
2927        let dataset = Dataset::new_with_metadata(vec![event], metadata);
2928        let beam = dataset.p4_by_name(0, "beam").unwrap();
2929        assert_relative_eq!(beam.px(), 0.0);
2930        assert_relative_eq!(beam.py(), 0.0);
2931        assert_relative_eq!(beam.pz(), 1.0);
2932        assert_relative_eq!(beam.e(), 1.0);
2933        assert_relative_eq!(dataset.aux_by_name(0, "pol_angle").unwrap(), 0.42);
2934        assert!(dataset.p4_by_name(0, "missing").is_none());
2935        assert!(dataset.aux_by_name(0, "missing").is_none());
2936    }
2937
2938    #[test]
2939    fn test_parquet_roundtrip_to_tempfile() {
2940        let dataset = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
2941        let dir = make_temp_dir();
2942        let path = dir.join("roundtrip.parquet");
2943        let path_str = path.to_str().expect("path should be valid UTF-8");
2944
2945        write_parquet(&dataset, path_str, &DatasetWriteOptions::default())
2946            .expect("writing parquet should succeed");
2947        let reopened = read_parquet(path_str, &DatasetReadOptions::new())
2948            .expect("parquet roundtrip should reopen");
2949
2950        assert_datasets_close(&dataset, &reopened, TEST_P4_NAMES, TEST_AUX_NAMES);
2951        fs::remove_dir_all(&dir).expect("temp dir cleanup should succeed");
2952    }
2953
2954    #[test]
2955    fn test_root_roundtrip_to_tempfile() {
2956        let dataset = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
2957        let dir = make_temp_dir();
2958        let path = dir.join("roundtrip.root");
2959        let path_str = path.to_str().expect("path should be valid UTF-8");
2960
2961        write_root(&dataset, path_str, &DatasetWriteOptions::default())
2962            .expect("writing root should succeed");
2963        let reopened =
2964            read_root(path_str, &DatasetReadOptions::new()).expect("root roundtrip should reopen");
2965
2966        assert_datasets_close(&dataset, &reopened, TEST_P4_NAMES, TEST_AUX_NAMES);
2967        fs::remove_dir_all(&dir).expect("temp dir cleanup should succeed");
2968    }
2969}