Skip to main content

laddu_core/
data.rs

1#[cfg(feature = "mpi")]
2use crate::mpi::LadduMPI;
3use accurate::{sum::Klein, traits::*};
4use auto_ops::impl_op_ex;
5#[cfg(feature = "mpi")]
6use mpi::{datatype::PartitionMut, topology::SimpleCommunicator, traits::*};
7#[cfg(feature = "rayon")]
8use rayon::prelude::*;
9use serde::{Deserialize, Serialize};
10use std::{
11    fmt::Display,
12    ops::{Deref, DerefMut, Index, IndexMut},
13    sync::Arc,
14};
15
16#[cfg(feature = "mpi")]
17type WorldHandle = SimpleCommunicator;
18#[cfg(not(feature = "mpi"))]
19type WorldHandle = ();
20
21#[cfg(feature = "mpi")]
22// Chosen from local two-rank probes: 512 matched or beat smaller chunks
23// while keeping the fetched-event cache modest.
24const DEFAULT_MPI_EVENT_FETCH_CHUNK_SIZE: usize = 512;
25#[cfg(feature = "mpi")]
26const MPI_EVENT_FETCH_CHUNK_SIZE_ENV: &str = "LADDU_MPI_EVENT_FETCH_CHUNK_SIZE";
27
28use crate::utils::get_bin_edges;
29use crate::{
30    utils::{
31        variables::{IntoP4Selection, P4Selection, Variable, VariableExpression},
32        vectors::Vec4,
33    },
34    LadduError, LadduResult,
35};
36use indexmap::{IndexMap, IndexSet};
37
38/// Dataset I/O implementations and shared ingestion helpers.
39pub mod io;
40
41/// An event that can be used to test the implementation of an
42/// [`Amplitude`](crate::amplitudes::Amplitude). This particular event contains the reaction
43/// $`\gamma p \to K_S^0 K_S^0 p`$ with a polarized photon beam.
44pub fn test_event() -> EventData {
45    use crate::utils::vectors::*;
46    let pol_magnitude = 0.38562805;
47    let pol_angle = 0.05708078;
48    EventData {
49        p4s: vec![
50            Vec3::new(0.0, 0.0, 8.747).with_mass(0.0),         // beam
51            Vec3::new(0.119, 0.374, 0.222).with_mass(1.007),   // "proton"
52            Vec3::new(-0.112, 0.293, 3.081).with_mass(0.498),  // "kaon"
53            Vec3::new(-0.007, -0.667, 5.446).with_mass(0.498), // "kaon"
54        ],
55        aux: vec![pol_magnitude, pol_angle],
56        weight: 0.48,
57    }
58}
59
60/// Particle names used by [`test_dataset`].
61pub const TEST_P4_NAMES: &[&str] = &["beam", "proton", "kshort1", "kshort2"];
62/// Auxiliary scalar names used by [`test_dataset`].
63pub const TEST_AUX_NAMES: &[&str] = &["pol_magnitude", "pol_angle"];
64
65/// A dataset that can be used to test the implementation of an
66/// [`Amplitude`](crate::amplitudes::Amplitude). This particular dataset contains a single
67/// [`EventData`] generated from [`test_event`].
68pub fn test_dataset() -> Dataset {
69    let metadata = Arc::new(
70        DatasetMetadata::new(
71            TEST_P4_NAMES.iter().map(|s| (*s).to_string()).collect(),
72            TEST_AUX_NAMES.iter().map(|s| (*s).to_string()).collect(),
73        )
74        .expect("Test metadata should be valid"),
75    );
76    Dataset::new_with_metadata(vec![Arc::new(test_event())], metadata)
77}
78
79/// Raw event data in a [`Dataset`] containing all particle and auxiliary information.
80///
81/// An [`EventData`] instance owns the list of four-momenta (`p4s`), auxiliary scalars (`aux`),
82/// and weight recorded for a particular collision event. Use [`Event`] when you need a
83/// metadata-aware view with name-based helpers.
84#[derive(Debug, Clone, Default, Serialize, Deserialize)]
85pub struct EventData {
86    /// A list of four-momenta for each particle.
87    pub p4s: Vec<Vec4>,
88    /// A list of auxiliary scalar values associated with the event.
89    pub aux: Vec<f64>,
90    /// The weight given to the event.
91    pub weight: f64,
92}
93
94impl Display for EventData {
95    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
96        writeln!(f, "Event:")?;
97        writeln!(f, "  p4s:")?;
98        for p4 in &self.p4s {
99            writeln!(f, "    {}", p4.to_p4_string())?;
100        }
101        writeln!(f, "  aux:")?;
102        for (idx, value) in self.aux.iter().enumerate() {
103            writeln!(f, "    aux[{idx}]: {value}")?;
104        }
105        writeln!(f, "  weight:")?;
106        writeln!(f, "    {}", self.weight)?;
107        Ok(())
108    }
109}
110
111impl EventData {
112    /// Return a four-momentum from the sum of four-momenta at the given indices in the [`EventData`].
113    pub fn get_p4_sum<T: AsRef<[usize]>>(&self, indices: T) -> Vec4 {
114        indices.as_ref().iter().map(|i| self.p4s[*i]).sum::<Vec4>()
115    }
116    /// Boost all the four-momenta in the [`EventData`] to the rest frame of the given set of
117    /// four-momenta by indices.
118    pub fn boost_to_rest_frame_of<T: AsRef<[usize]>>(&self, indices: T) -> Self {
119        let frame = self.get_p4_sum(indices);
120        EventData {
121            p4s: self
122                .p4s
123                .iter()
124                .map(|p4| p4.boost(&(-frame.beta())))
125                .collect(),
126            aux: self.aux.clone(),
127            weight: self.weight,
128        }
129    }
130}
131
132#[allow(dead_code)]
133#[derive(Debug, Clone, Default)]
134struct ColumnarP4Column {
135    px: Vec<f64>,
136    py: Vec<f64>,
137    pz: Vec<f64>,
138    e: Vec<f64>,
139}
140
141#[allow(dead_code)]
142impl ColumnarP4Column {
143    fn with_capacity(capacity: usize) -> Self {
144        Self {
145            px: Vec::with_capacity(capacity),
146            py: Vec::with_capacity(capacity),
147            pz: Vec::with_capacity(capacity),
148            e: Vec::with_capacity(capacity),
149        }
150    }
151
152    fn push(&mut self, p4: Vec4) {
153        self.px.push(p4.x);
154        self.py.push(p4.y);
155        self.pz.push(p4.z);
156        self.e.push(p4.t);
157    }
158
159    fn get(&self, event_index: usize) -> Vec4 {
160        Vec4::new(
161            self.px[event_index],
162            self.py[event_index],
163            self.pz[event_index],
164            self.e[event_index],
165        )
166    }
167}
168
169/// Columnar dataset storage used by [`Dataset`].
170#[derive(Debug, Default)]
171pub(crate) struct DatasetStorage {
172    metadata: Arc<DatasetMetadata>,
173    p4: Vec<ColumnarP4Column>,
174    aux: Vec<Vec<f64>>,
175    weights: Vec<f64>,
176}
177
178impl Clone for DatasetStorage {
179    fn clone(&self) -> Self {
180        Self {
181            metadata: self.metadata.clone(),
182            p4: self.p4.clone(),
183            aux: self.aux.clone(),
184            weights: self.weights.clone(),
185        }
186    }
187}
188
189impl DatasetStorage {
190    /// Convert this columnar dataset back to a row-event dataset.
191    pub(crate) fn to_dataset(&self) -> Dataset {
192        let events = (0..self.n_events())
193            .map(|event_index| Arc::new(self.event_data(event_index)))
194            .collect::<Vec<_>>();
195        #[cfg(not(feature = "mpi"))]
196        let dataset = Dataset::new_local(events, self.metadata.clone());
197        #[cfg(feature = "mpi")]
198        let mut dataset = Dataset::new_local(events, self.metadata.clone());
199        #[cfg(feature = "mpi")]
200        {
201            if let Some(world) = crate::mpi::get_world() {
202                dataset.set_cached_global_event_count_from_world(&world);
203                dataset.set_cached_global_weighted_sum_from_world(&world);
204            }
205        }
206        dataset
207    }
208
209    /// Access metadata.
210    pub(crate) fn metadata(&self) -> &DatasetMetadata {
211        &self.metadata
212    }
213
214    /// Number of local events.
215    pub(crate) fn n_events(&self) -> usize {
216        self.weights.len()
217    }
218
219    /// Retrieve a p4 value by row and p4 index.
220    pub(crate) fn p4(&self, event_index: usize, p4_index: usize) -> Vec4 {
221        self.p4[p4_index].get(event_index)
222    }
223
224    /// Retrieve an aux value by row and aux index.
225    pub(crate) fn aux(&self, event_index: usize, aux_index: usize) -> f64 {
226        self.aux[aux_index][event_index]
227    }
228
229    /// Retrieve event weight by row index.
230    pub(crate) fn weight(&self, event_index: usize) -> f64 {
231        self.weights[event_index]
232    }
233
234    pub(crate) fn event_data(&self, event_index: usize) -> EventData {
235        let mut p4s = Vec::with_capacity(self.p4.len());
236        for p4_index in 0..self.p4.len() {
237            p4s.push(self.p4(event_index, p4_index));
238        }
239        let mut aux = Vec::with_capacity(self.aux.len());
240        for aux_index in 0..self.aux.len() {
241            aux.push(self.aux(event_index, aux_index));
242        }
243        EventData {
244            p4s,
245            aux,
246            weight: self.weight(event_index),
247        }
248    }
249
250    fn row_view(&self, event_index: usize) -> ColumnarEventView<'_> {
251        ColumnarEventView {
252            storage: self,
253            event_index,
254        }
255    }
256
257    #[allow(dead_code)]
258    pub(crate) fn for_each_named_event_local<F>(&self, mut op: F)
259    where
260        F: FnMut(usize, NamedEventView<'_>),
261    {
262        for event_index in 0..self.n_events() {
263            let row = self.row_view(event_index);
264            let view = NamedEventView {
265                row,
266                metadata: &self.metadata,
267            };
268            op(event_index, view);
269        }
270    }
271
272    pub(crate) fn event_view(&self, event_index: usize) -> NamedEventView<'_> {
273        let row = self.row_view(event_index);
274        NamedEventView {
275            row,
276            metadata: self.metadata(),
277        }
278    }
279}
280
281#[allow(dead_code)]
282#[derive(Debug)]
283struct ColumnarEventView<'a> {
284    storage: &'a DatasetStorage,
285    event_index: usize,
286}
287
288#[allow(dead_code)]
289impl ColumnarEventView<'_> {
290    fn p4(&self, p4_index: usize) -> Vec4 {
291        self.storage.p4(self.event_index, p4_index)
292    }
293
294    fn aux(&self, aux_index: usize) -> f64 {
295        self.storage.aux(self.event_index, aux_index)
296    }
297
298    fn weight(&self) -> f64 {
299        self.storage.weight(self.event_index)
300    }
301
302    fn get_p4_sum<T: AsRef<[usize]>>(&self, indices: T) -> Vec4 {
303        indices.as_ref().iter().map(|index| self.p4(*index)).sum()
304    }
305}
306
307/// A name-aware columnar event view over a single row in a dataset.
308#[derive(Debug)]
309pub struct NamedEventView<'a> {
310    row: ColumnarEventView<'a>,
311    metadata: &'a DatasetMetadata,
312}
313
314impl NamedEventView<'_> {
315    /// Retrieve a four-momentum by positional index.
316    pub fn p4_at(&self, p4_index: usize) -> Vec4 {
317        self.row.p4(p4_index)
318    }
319
320    /// Retrieve an auxiliary scalar by positional index.
321    pub fn aux_at(&self, aux_index: usize) -> f64 {
322        self.row.aux(aux_index)
323    }
324
325    /// Number of four-momenta in this event.
326    pub fn n_p4(&self) -> usize {
327        self.row.storage.p4.len()
328    }
329
330    /// Number of auxiliary values in this event.
331    pub fn n_aux(&self) -> usize {
332        self.row.storage.aux.len()
333    }
334
335    /// Retrieve a four-momentum by metadata name.
336    pub fn p4(&self, name: &str) -> Option<Vec4> {
337        let selection = self.metadata.p4_selection(name)?;
338        Some(
339            selection
340                .indices()
341                .iter()
342                .map(|index| self.row.p4(*index))
343                .sum(),
344        )
345    }
346
347    /// Retrieve an auxiliary scalar by metadata name.
348    pub fn aux(&self, name: &str) -> Option<f64> {
349        let index = self.metadata.aux_index(name)?;
350        Some(self.row.aux(index))
351    }
352
353    /// Retrieve event weight.
354    pub fn weight(&self) -> f64 {
355        self.row.weight()
356    }
357
358    /// Retrieve the sum of multiple four-momenta selected by name.
359    pub fn get_p4_sum<N>(&self, names: N) -> Option<Vec4>
360    where
361        N: IntoIterator,
362        N::Item: AsRef<str>,
363    {
364        names
365            .into_iter()
366            .map(|name| self.p4(name.as_ref()))
367            .collect::<Option<Vec<_>>>()
368            .map(|momenta| momenta.into_iter().sum())
369    }
370
371    /// Evaluate a [`Variable`] against this event.
372    pub fn evaluate<V: Variable>(&self, variable: &V) -> f64 {
373        variable.value(self)
374    }
375}
376
377/// A collection of [`EventData`].
378#[derive(Debug, Clone)]
379pub struct DatasetMetadata {
380    pub(crate) p4_names: Vec<String>,
381    pub(crate) aux_names: Vec<String>,
382    pub(crate) p4_lookup: IndexMap<String, usize>,
383    pub(crate) aux_lookup: IndexMap<String, usize>,
384    pub(crate) p4_selections: IndexMap<String, P4Selection>,
385}
386
387impl DatasetMetadata {
388    /// Construct metadata from explicit particle and auxiliary names.
389    pub fn new<P: Into<String>, A: Into<String>>(
390        p4_names: Vec<P>,
391        aux_names: Vec<A>,
392    ) -> LadduResult<Self> {
393        let mut p4_lookup = IndexMap::with_capacity(p4_names.len());
394        let mut aux_lookup = IndexMap::with_capacity(aux_names.len());
395        let mut p4_selections = IndexMap::with_capacity(p4_names.len());
396        let p4_names: Vec<String> = p4_names
397            .into_iter()
398            .enumerate()
399            .map(|(idx, name)| {
400                let name = name.into();
401                if p4_lookup.contains_key(&name) {
402                    return Err(LadduError::DuplicateName {
403                        category: "p4",
404                        name,
405                    });
406                }
407                p4_lookup.insert(name.clone(), idx);
408                p4_selections.insert(
409                    name.clone(),
410                    P4Selection::with_indices(vec![name.clone()], vec![idx]),
411                );
412                Ok(name)
413            })
414            .collect::<Result<_, _>>()?;
415        let aux_names: Vec<String> = aux_names
416            .into_iter()
417            .enumerate()
418            .map(|(idx, name)| {
419                let name = name.into();
420                if aux_lookup.contains_key(&name) {
421                    return Err(LadduError::DuplicateName {
422                        category: "aux",
423                        name,
424                    });
425                }
426                aux_lookup.insert(name.clone(), idx);
427                Ok(name)
428            })
429            .collect::<Result<_, _>>()?;
430        Ok(Self {
431            p4_names,
432            aux_names,
433            p4_lookup,
434            aux_lookup,
435            p4_selections,
436        })
437    }
438
439    /// Create metadata with no registered names.
440    pub fn empty() -> Self {
441        Self {
442            p4_names: Vec::new(),
443            aux_names: Vec::new(),
444            p4_lookup: IndexMap::new(),
445            aux_lookup: IndexMap::new(),
446            p4_selections: IndexMap::new(),
447        }
448    }
449
450    /// Resolve the index of a four-momentum by name.
451    pub fn p4_index(&self, name: &str) -> Option<usize> {
452        self.p4_lookup.get(name).copied()
453    }
454
455    /// Registered four-momentum names in declaration order.
456    pub fn p4_names(&self) -> &[String] {
457        &self.p4_names
458    }
459
460    /// Resolve the index of an auxiliary scalar by name.
461    pub fn aux_index(&self, name: &str) -> Option<usize> {
462        self.aux_lookup.get(name).copied()
463    }
464
465    /// Registered auxiliary scalar names in declaration order.
466    pub fn aux_names(&self) -> &[String] {
467        &self.aux_names
468    }
469
470    /// Look up a resolved four-momentum selection by name (canonical or alias).
471    pub fn p4_selection(&self, name: &str) -> Option<&P4Selection> {
472        self.p4_selections.get(name)
473    }
474
475    /// Register an alias mapping to one or more existing four-momenta.
476    pub fn add_p4_alias<N>(&mut self, alias: N, mut selection: P4Selection) -> LadduResult<()>
477    where
478        N: Into<String>,
479    {
480        let alias = alias.into();
481        if self.p4_selections.contains_key(&alias) {
482            return Err(LadduError::DuplicateName {
483                category: "alias",
484                name: alias,
485            });
486        }
487        selection.bind(self)?;
488        self.p4_selections.insert(alias, selection);
489        Ok(())
490    }
491
492    /// Register multiple aliases at once.
493    pub fn add_p4_aliases<I, N>(&mut self, entries: I) -> LadduResult<()>
494    where
495        I: IntoIterator<Item = (N, P4Selection)>,
496        N: Into<String>,
497    {
498        for (alias, selection) in entries {
499            self.add_p4_alias(alias, selection)?;
500        }
501        Ok(())
502    }
503
504    pub(crate) fn append_indices_for_name(
505        &self,
506        name: &str,
507        target: &mut Vec<usize>,
508    ) -> LadduResult<()> {
509        if let Some(selection) = self.p4_selections.get(name) {
510            target.extend_from_slice(selection.indices());
511            return Ok(());
512        }
513        Err(LadduError::UnknownName {
514            category: "p4",
515            name: name.to_string(),
516        })
517    }
518}
519
520impl Default for DatasetMetadata {
521    fn default() -> Self {
522        Self::empty()
523    }
524}
525
526/// A collection of events with optional metadata for name-based lookups.
527#[derive(Debug, Clone)]
528pub struct Dataset {
529    /// The [`EventData`] contained in the [`Dataset`]
530    events: Vec<Event>,
531    pub(crate) columnar: DatasetStorage,
532    pub(crate) metadata: Arc<DatasetMetadata>,
533    pub(crate) cached_local_weighted_sum: f64,
534    #[cfg(feature = "mpi")]
535    pub(crate) cached_global_event_count: usize,
536    #[cfg(feature = "mpi")]
537    pub(crate) cached_global_weighted_sum: f64,
538}
539
540/// Metadata-aware view of an [`EventData`] with name-based helpers.
541#[derive(Clone, Debug)]
542pub struct Event {
543    event: Arc<EventData>,
544    metadata: Arc<DatasetMetadata>,
545}
546
547impl Event {
548    /// Create a new metadata-aware event from raw data and dataset metadata.
549    pub fn new(event: Arc<EventData>, metadata: Arc<DatasetMetadata>) -> Self {
550        Self { event, metadata }
551    }
552
553    /// Borrow the raw [`EventData`].
554    pub fn data(&self) -> &EventData {
555        &self.event
556    }
557
558    /// Obtain a clone of the underlying [`EventData`] handle.
559    pub fn data_arc(&self) -> Arc<EventData> {
560        self.event.clone()
561    }
562
563    /// Return the four-momenta stored in this event keyed by their registered names.
564    pub fn p4s(&self) -> IndexMap<&str, Vec4> {
565        let mut map = IndexMap::with_capacity(self.metadata.p4_names.len());
566        for (idx, name) in self.metadata.p4_names.iter().enumerate() {
567            if let Some(p4) = self.event.p4s.get(idx) {
568                map.insert(name.as_str(), *p4);
569            }
570        }
571        map
572    }
573
574    /// Return the auxiliary scalars stored in this event keyed by their registered names.
575    pub fn aux(&self) -> IndexMap<&str, f64> {
576        let mut map = IndexMap::with_capacity(self.metadata.aux_names.len());
577        for (idx, name) in self.metadata.aux_names.iter().enumerate() {
578            if let Some(value) = self.event.aux.get(idx) {
579                map.insert(name.as_str(), *value);
580            }
581        }
582        map
583    }
584
585    /// Return the event weight.
586    pub fn weight(&self) -> f64 {
587        self.event.weight
588    }
589
590    /// Retrieve the dataset metadata attached to this event.
591    pub fn metadata(&self) -> &DatasetMetadata {
592        &self.metadata
593    }
594
595    /// Clone the metadata handle associated with this event.
596    pub fn metadata_arc(&self) -> Arc<DatasetMetadata> {
597        self.metadata.clone()
598    }
599
600    /// Retrieve a four-momentum (or aliased sum) by name.
601    pub fn p4(&self, name: &str) -> Option<Vec4> {
602        self.metadata
603            .p4_selection(name)
604            .map(|selection| selection.momentum(&self.event))
605    }
606
607    fn resolve_p4_indices<N>(&self, names: N) -> Vec<usize>
608    where
609        N: IntoIterator,
610        N::Item: AsRef<str>,
611    {
612        let mut indices = Vec::new();
613        for name in names {
614            let name_ref = name.as_ref();
615            if let Some(selection) = self.metadata.p4_selection(name_ref) {
616                indices.extend_from_slice(selection.indices());
617            } else {
618                panic!("Unknown particle name '{name}'", name = name_ref);
619            }
620        }
621        indices
622    }
623
624    /// Return a four-momentum formed by summing four-momenta with the specified names.
625    pub fn get_p4_sum<N>(&self, names: N) -> Vec4
626    where
627        N: IntoIterator,
628        N::Item: AsRef<str>,
629    {
630        let indices = self.resolve_p4_indices(names);
631        self.event.get_p4_sum(&indices)
632    }
633
634    /// Boost all four-momenta into the rest frame defined by the specified particle names.
635    pub fn boost_to_rest_frame_of<N>(&self, names: N) -> EventData
636    where
637        N: IntoIterator,
638        N::Item: AsRef<str>,
639    {
640        let indices = self.resolve_p4_indices(names);
641        self.event.boost_to_rest_frame_of(&indices)
642    }
643}
644
645impl Deref for Event {
646    type Target = EventData;
647
648    fn deref(&self) -> &Self::Target {
649        &self.event
650    }
651}
652
653impl AsRef<EventData> for Event {
654    fn as_ref(&self) -> &EventData {
655        self.data()
656    }
657}
658
659impl IntoIterator for Dataset {
660    type Item = Event;
661
662    type IntoIter = DatasetIntoIter;
663
664    fn into_iter(self) -> Self::IntoIter {
665        #[cfg(feature = "mpi")]
666        {
667            if let Some(world) = crate::mpi::get_world() {
668                // Cache total before moving fields out of self for MPI iteration.
669                let total = self.n_events();
670                return DatasetIntoIter::Mpi(DatasetMpiIntoIter {
671                    events: self.events,
672                    metadata: self.metadata,
673                    world,
674                    index: 0,
675                    total,
676                    cursor: MpiEventChunkCursor::for_iteration(total),
677                });
678            }
679        }
680        DatasetIntoIter::Local(self.events.into_iter())
681    }
682}
683
684fn shared_dataset_iter(dataset: Arc<Dataset>) -> DatasetArcIter {
685    #[cfg(feature = "mpi")]
686    {
687        if let Some(world) = crate::mpi::get_world() {
688            let total = dataset.n_events();
689            return DatasetArcIter::Mpi(DatasetArcMpiIter {
690                dataset,
691                world,
692                index: 0,
693                total,
694                cursor: MpiEventChunkCursor::for_iteration(total),
695            });
696        }
697    }
698    DatasetArcIter::Local { dataset, index: 0 }
699}
700
701/// Extension methods for shared [`Arc<Dataset>`] handles.
702pub trait SharedDatasetIterExt {
703    /// Build an iterator over a shared [`Arc<Dataset>`] without cloning the dataset contents.
704    fn shared_iter(&self) -> DatasetArcIter;
705
706    /// Alias for [`SharedDatasetIterExt::shared_iter`].
707    fn shared_iter_global(&self) -> DatasetArcIter;
708}
709
710impl SharedDatasetIterExt for Arc<Dataset> {
711    fn shared_iter(&self) -> DatasetArcIter {
712        shared_dataset_iter(self.clone())
713    }
714
715    fn shared_iter_global(&self) -> DatasetArcIter {
716        self.shared_iter()
717    }
718}
719
720impl Dataset {
721    /// Borrow locally stored events.
722    ///
723    /// When MPI is enabled, this slice contains only the current rank's event ownership.
724    pub fn events_local(&self) -> &[Event] {
725        &self.events
726    }
727
728    /// Collect all events into a [`Vec`] using the default global iteration semantics.
729    ///
730    /// When MPI is enabled, the returned vector is ordered like [`Dataset::iter`] and
731    /// may include remotely owned events fetched on demand.
732    pub fn events_global(&self) -> Vec<Event> {
733        self.iter_global().collect()
734    }
735
736    #[cfg(test)]
737    pub(crate) fn clear_events_local(&mut self) {
738        self.events.clear();
739    }
740
741    /// Iterate over all events in the dataset. When MPI is enabled, this will visit
742    /// every event across all ranks, fetching remote events on demand.
743    pub fn iter(&self) -> DatasetIter<'_> {
744        #[cfg(feature = "mpi")]
745        {
746            if let Some(world) = crate::mpi::get_world() {
747                let total = self.n_events();
748                return DatasetIter::Mpi(DatasetMpiIter {
749                    dataset: self,
750                    world,
751                    index: 0,
752                    total,
753                    cursor: MpiEventChunkCursor::for_iteration(total),
754                });
755            }
756        }
757        DatasetIter::Local(self.events.iter())
758    }
759
760    /// Alias for [`Dataset::iter`].
761    ///
762    /// This preserves dataset-wide ordering under MPI.
763    pub fn iter_global(&self) -> DatasetIter<'_> {
764        self.iter()
765    }
766
767    /// Borrow the dataset metadata used for name lookups.
768    pub fn metadata(&self) -> &DatasetMetadata {
769        &self.metadata
770    }
771
772    /// Clone the internal metadata handle for external consumers (e.g., language bindings).
773    pub fn metadata_arc(&self) -> Arc<DatasetMetadata> {
774        self.metadata.clone()
775    }
776
777    /// Names corresponding to stored four-momenta.
778    pub fn p4_names(&self) -> &[String] {
779        &self.metadata.p4_names
780    }
781
782    /// Names corresponding to stored auxiliary scalars.
783    pub fn aux_names(&self) -> &[String] {
784        &self.metadata.aux_names
785    }
786
787    /// Resolve the index of a four-momentum by name.
788    pub fn p4_index(&self, name: &str) -> Option<usize> {
789        self.metadata.p4_index(name)
790    }
791
792    /// Resolve the index of an auxiliary scalar by name.
793    pub fn aux_index(&self, name: &str) -> Option<usize> {
794        self.metadata.aux_index(name)
795    }
796
797    /// Borrow event data together with metadata-based helpers as an [`Event`] view.
798    pub fn named_event(&self, index: usize) -> LadduResult<Event> {
799        self.event(index)
800    }
801
802    /// Alias for [`Dataset::named_event`].
803    pub fn named_event_global(&self, index: usize) -> LadduResult<Event> {
804        self.named_event(index)
805    }
806
807    /// Retrieve a single event by index, returning `None` when out of range.
808    pub fn get_event(&self, index: usize) -> Option<Event> {
809        #[cfg(feature = "mpi")]
810        {
811            if let Some(world) = crate::mpi::get_world() {
812                let total = self.n_events();
813                if index >= total {
814                    return None;
815                }
816                return Some(fetch_event_mpi(self, index, &world, total));
817            }
818        }
819
820        self.events.get(index).cloned()
821    }
822
823    /// Alias for [`Dataset::get_event`].
824    ///
825    /// This preserves the default global indexing semantics under MPI.
826    pub fn get_event_global(&self, index: usize) -> Option<Event> {
827        self.get_event(index)
828    }
829
830    /// Retrieve a single event by index.
831    pub fn event(&self, index: usize) -> LadduResult<Event> {
832        self.get_event(index).ok_or_else(|| {
833            LadduError::Custom(format!(
834                "Dataset index out of bounds: index {index}, length {}",
835                self.n_events()
836            ))
837        })
838    }
839
840    /// Alias for [`Dataset::event`].
841    ///
842    /// This preserves the default global indexing semantics under MPI.
843    pub fn event_global(&self, index: usize) -> LadduResult<Event> {
844        self.event(index)
845    }
846
847    /// Retrieve a four-momentum by name for the event at `event_index`.
848    pub fn p4_by_name(&self, event_index: usize, name: &str) -> Option<Vec4> {
849        self.get_event(event_index).and_then(|event| event.p4(name))
850    }
851
852    /// Retrieve an auxiliary scalar by name for the event at `event_index`.
853    pub fn aux_by_name(&self, event_index: usize, name: &str) -> Option<f64> {
854        let idx = self.aux_index(name)?;
855        self.get_event(event_index)
856            .and_then(|event| event.aux.get(idx).copied())
857    }
858
859    /// Iterate over all local events as metadata-aware columnar views.
860    pub fn for_each_named_event_local<F>(&self, op: F)
861    where
862        F: FnMut(usize, NamedEventView<'_>),
863    {
864        self.columnar.for_each_named_event_local(op);
865    }
866
867    /// Retrieve a metadata-aware columnar event view by local index.
868    pub fn event_view(&self, event_index: usize) -> NamedEventView<'_> {
869        self.columnar.event_view(event_index)
870    }
871
872    /// Get a reference to the [`EventData`] at the given index in the [`Dataset`] (non-MPI
873    /// version).
874    ///
875    /// # Notes
876    ///
877    /// This method is not intended to be called in analyses but rather in writing methods
878    /// that have `mpi`-feature-gated versions. Most users should use [`Dataset::event`] instead:
879    ///
880    /// ```ignore
881    /// let ds: Dataset = Dataset::new(events);
882    /// let event_0 = ds.event(0)?;
883    /// ```
884    pub fn index_local(&self, index: usize) -> &Event {
885        &self.events[index]
886    }
887
888    #[cfg(feature = "mpi")]
889    fn partition(
890        events: Vec<Arc<EventData>>,
891        world: &SimpleCommunicator,
892    ) -> Vec<Vec<Arc<EventData>>> {
893        let partition = world.partition(events.len());
894        (0..partition.n_ranks())
895            .map(|rank| {
896                let range = partition.range_for_rank(rank);
897                events[range.clone()].to_vec()
898            })
899            .collect()
900    }
901}
902
903/// Iterator over a [`Dataset`].
904pub enum DatasetIter<'a> {
905    /// Iterator over locally available events.
906    Local(std::slice::Iter<'a, Event>),
907    #[cfg(feature = "mpi")]
908    /// Iterator that fetches events across MPI ranks.
909    Mpi(DatasetMpiIter<'a>),
910}
911
912impl<'a> Iterator for DatasetIter<'a> {
913    type Item = Event;
914
915    fn next(&mut self) -> Option<Self::Item> {
916        match self {
917            DatasetIter::Local(iter) => iter.next().cloned(),
918            #[cfg(feature = "mpi")]
919            DatasetIter::Mpi(iter) => iter.next(),
920        }
921    }
922}
923
924/// Owning iterator over a [`Dataset`].
925pub enum DatasetIntoIter {
926    /// Iterator over locally available events, consuming the dataset.
927    Local(std::vec::IntoIter<Event>),
928    #[cfg(feature = "mpi")]
929    /// Iterator that fetches events across MPI ranks, consuming the dataset.
930    Mpi(DatasetMpiIntoIter),
931}
932
933impl Iterator for DatasetIntoIter {
934    type Item = Event;
935
936    fn next(&mut self) -> Option<Self::Item> {
937        match self {
938            DatasetIntoIter::Local(iter) => iter.next(),
939            #[cfg(feature = "mpi")]
940            DatasetIntoIter::Mpi(iter) => iter.next(),
941        }
942    }
943}
944
945/// Iterator over a shared [`Arc<Dataset>`].
946pub enum DatasetArcIter {
947    /// Iterator over locally available events from a shared dataset handle.
948    Local {
949        /// Shared dataset handle.
950        dataset: Arc<Dataset>,
951        /// Next local event index to read.
952        index: usize,
953    },
954    #[cfg(feature = "mpi")]
955    /// Iterator that fetches events across MPI ranks from a shared dataset handle.
956    Mpi(DatasetArcMpiIter),
957}
958
959impl Iterator for DatasetArcIter {
960    type Item = Event;
961
962    fn next(&mut self) -> Option<Self::Item> {
963        match self {
964            DatasetArcIter::Local { dataset, index } => {
965                let event = dataset.events.get(*index).cloned();
966                *index += 1;
967                event
968            }
969            #[cfg(feature = "mpi")]
970            DatasetArcIter::Mpi(iter) => iter.next(),
971        }
972    }
973}
974
/// Iterator over a [`Dataset`] that fetches events across MPI ranks.
///
/// Each call to `next` goes through a chunk cursor that issues collective MPI calls, so every
/// rank in `world` must iterate together.
#[cfg(feature = "mpi")]
pub struct DatasetMpiIter<'a> {
    /// Dataset whose events are visited in global index order.
    dataset: &'a Dataset,
    /// Communicator used for the collective fetches.
    world: SimpleCommunicator,
    /// Global index of the next event to yield.
    index: usize,
    /// Global number of events; iteration ends once `index` reaches it.
    total: usize,
    /// Cache of the most recently fetched window of events.
    cursor: MpiEventChunkCursor,
}
984
/// Sliding window of consecutive global events fetched collectively from all MPI ranks.
///
/// Rather than issuing one collective per event, the cursor fetches up to `chunk_size` events
/// starting at the requested global index and serves later requests from that cached window.
#[cfg(feature = "mpi")]
#[derive(Debug, Clone)]
pub(crate) struct MpiEventChunkCursor {
    /// Global index of `events[0]`.
    chunk_start: usize,
    /// Number of events requested per collective fetch (always at least 1).
    chunk_size: usize,
    /// Cached events covering global indices `chunk_start..chunk_start + events.len()`.
    events: Vec<Event>,
}

/// Choose how many events to fetch per collective call when iterating over `total` events.
///
/// If the `LADDU_MPI_EVENT_FETCH_CHUNK_SIZE` environment variable holds an unsigned integer
/// (surrounding whitespace is ignored), that value is used. Otherwise the built-in default is
/// used. Either way, the result is clamped to `1..=max(total, 1)`.
#[cfg(feature = "mpi")]
fn resolve_mpi_event_fetch_chunk_size(total: usize) -> usize {
    let upper = total.max(1);
    let requested = std::env::var_os(MPI_EVENT_FETCH_CHUNK_SIZE_ENV)
        .and_then(|raw| raw.to_str().and_then(|value| value.trim().parse::<usize>().ok()));
    match requested {
        Some(size) => size.clamp(1, upper),
        None => DEFAULT_MPI_EVENT_FETCH_CHUNK_SIZE.min(upper),
    }
}

#[cfg(feature = "mpi")]
impl MpiEventChunkCursor {
    /// Cursor sized for iterating over `total` events (see
    /// `resolve_mpi_event_fetch_chunk_size`).
    pub(crate) fn for_iteration(total: usize) -> Self {
        Self::new(resolve_mpi_event_fetch_chunk_size(total))
    }

    /// Empty cursor that fetches `chunk_size` events at a time (a size of 0 is treated as 1).
    pub(crate) fn new(chunk_size: usize) -> Self {
        Self {
            chunk_start: 0,
            chunk_size: chunk_size.max(1),
            events: Vec::new(),
        }
    }

    /// One past the last global index held in the cached window.
    fn chunk_end(&self) -> usize {
        self.chunk_start + self.events.len()
    }

    /// Whether `global_index` is served by the cached window.
    fn contains(&self, global_index: usize) -> bool {
        (self.chunk_start..self.chunk_end()).contains(&global_index)
    }

    /// Shared lookup logic: return the cached event or refill the window via `fetch`.
    ///
    /// `fetch(start, len)` must return the events for global indices starting at `start`; it is
    /// only called when `global_index` is in range but outside the cached window. Returns
    /// `None` when `global_index >= total` (without calling `fetch`).
    fn event_with<Fetch>(&mut self, global_index: usize, total: usize, fetch: Fetch) -> Option<Event>
    where
        Fetch: FnOnce(usize, usize) -> Vec<Event>,
    {
        if global_index >= total {
            return None;
        }
        if !self.contains(global_index) {
            self.events = fetch(global_index, self.chunk_size);
            self.chunk_start = global_index;
        }
        self.events.get(global_index - self.chunk_start).cloned()
    }

    /// Event at `global_index` of `dataset`, fetching a new window collectively if needed.
    ///
    /// All ranks must call this with the same arguments because a refill is a collective.
    pub(crate) fn event_for_dataset(
        &mut self,
        dataset: &Dataset,
        global_index: usize,
        world: &SimpleCommunicator,
        total: usize,
    ) -> Option<Event> {
        self.event_with(global_index, total, |start, len| {
            fetch_event_chunk_mpi(dataset, start, len, world, total)
        })
    }

    /// Event at `global_index` over the given local `events`, fetching a new window collectively
    /// if needed.
    ///
    /// All ranks must call this with the same arguments because a refill is a collective.
    pub(crate) fn event_for_events(
        &mut self,
        events: &[Event],
        metadata: &Arc<DatasetMetadata>,
        global_index: usize,
        world: &SimpleCommunicator,
        total: usize,
    ) -> Option<Event> {
        self.event_with(global_index, total, |start, len| {
            fetch_event_chunk_mpi_from_events(events, metadata, start, len, world, total)
        })
    }
}
1072
#[cfg(feature = "mpi")]
impl<'a> Iterator for DatasetMpiIter<'a> {
    type Item = Event;

    /// Look up the event at the current global index (collectively, via the cursor), then
    /// advance the index.
    fn next(&mut self) -> Option<Self::Item> {
        let Self {
            dataset,
            world,
            index,
            total,
            cursor,
        } = self;
        let fetched = cursor.event_for_dataset(*dataset, *index, world, *total);
        *index += 1;
        fetched
    }
}
1085
/// Iterator over a shared [`Arc<Dataset>`] that fetches events across MPI ranks.
///
/// Each call to `next` may issue collective MPI calls, so every rank must iterate together.
#[cfg(feature = "mpi")]
pub struct DatasetArcMpiIter {
    /// Shared handle to the dataset being visited in global index order.
    dataset: Arc<Dataset>,
    /// Communicator used for the collective fetches.
    world: SimpleCommunicator,
    /// Global index of the next event to yield.
    index: usize,
    /// Global number of events; iteration ends once `index` reaches it.
    total: usize,
    /// Cache of the most recently fetched window of events.
    cursor: MpiEventChunkCursor,
}

#[cfg(feature = "mpi")]
impl Iterator for DatasetArcMpiIter {
    type Item = Event;

    /// Look up the event at the current global index (collectively, via the cursor), then
    /// advance the index.
    fn next(&mut self) -> Option<Self::Item> {
        let Self {
            dataset,
            world,
            index,
            total,
            cursor,
        } = self;
        let fetched = cursor.event_for_dataset(&**dataset, *index, world, *total);
        *index += 1;
        fetched
    }
}
1108
/// Owning iterator over a [`Dataset`] that fetches events across MPI ranks.
///
/// Each call to `next` may issue collective MPI calls, so every rank must iterate together.
#[cfg(feature = "mpi")]
pub struct DatasetMpiIntoIter {
    /// This rank's local events, taken from the consumed dataset.
    events: Vec<Event>,
    /// Metadata attached to events received from other ranks.
    metadata: Arc<DatasetMetadata>,
    /// Communicator used for the collective fetches.
    world: SimpleCommunicator,
    /// Global index of the next event to yield.
    index: usize,
    /// Global number of events; iteration ends once `index` reaches it.
    total: usize,
    /// Cache of the most recently fetched window of events.
    cursor: MpiEventChunkCursor,
}

#[cfg(feature = "mpi")]
impl Iterator for DatasetMpiIntoIter {
    type Item = Event;

    /// Look up the event at the current global index (collectively, via the cursor), then
    /// advance the index.
    fn next(&mut self) -> Option<Self::Item> {
        let Self {
            events,
            metadata,
            world,
            index,
            total,
            cursor,
        } = self;
        let fetched = cursor.event_for_events(events, metadata, *index, world, *total);
        *index += 1;
        fetched
    }
}
1136
/// Fetch the single event at `global_index` of `dataset` from whichever rank owns it.
///
/// Collective: every rank in `world` must call this with the same `global_index` and `total`.
#[cfg(feature = "mpi")]
fn fetch_event_mpi(
    dataset: &Dataset,
    global_index: usize,
    world: &SimpleCommunicator,
    total: usize,
) -> Event {
    let metadata = &dataset.metadata;
    fetch_event_mpi_generic(global_index, total, world, metadata, move |local_index| {
        dataset.index_local(local_index)
    })
}
1152
/// Collectively fetch up to `len` events of `dataset`, starting at global index `start`.
#[cfg(feature = "mpi")]
fn fetch_event_chunk_mpi(
    dataset: &Dataset,
    start: usize,
    len: usize,
    world: &SimpleCommunicator,
    total: usize,
) -> Vec<Event> {
    let metadata = &dataset.metadata;
    fetch_event_chunk_mpi_generic(start, len, total, world, metadata, move |local_index| {
        dataset.index_local(local_index)
    })
}
1165
/// Collectively fetch up to `len` events starting at global index `start`, where this rank's
/// share of the data is the slice `events`.
#[cfg(feature = "mpi")]
fn fetch_event_chunk_mpi_from_events(
    events: &[Event],
    metadata: &Arc<DatasetMetadata>,
    start: usize,
    len: usize,
    world: &SimpleCommunicator,
    total: usize,
) -> Vec<Event> {
    let local_lookup = move |local_index: usize| &events[local_index];
    fetch_event_chunk_mpi_generic(start, len, total, world, metadata, local_lookup)
}
1179
/// Broadcast the event at `global_index` from its owning rank to every rank.
///
/// The owner serializes the event with `bitcode`. It then broadcasts the payload length, then
/// the payload. Non-owners decode the payload and wrap it with `metadata`; the owner returns a
/// clone of its local event. Collective: all ranks must call this with the same `global_index`
/// and `total`.
#[cfg(feature = "mpi")]
fn fetch_event_mpi_generic<'a, F>(
    global_index: usize,
    total: usize,
    world: &SimpleCommunicator,
    metadata: &Arc<DatasetMetadata>,
    local_event: F,
) -> Event
where
    F: Fn(usize) -> &'a Event,
{
    let (owner, local_index) = world.owner_of_global_index(global_index, total);
    let is_owner = world.rank() == owner;
    let root = world.process_at_rank(owner);

    // Only the owner has real bytes; everyone else starts empty.
    let mut payload: Vec<u8> = if is_owner {
        bitcode::serialize(local_event(local_index).data()).unwrap()
    } else {
        Vec::default()
    };

    // First broadcast: payload length, so receivers can size their buffers.
    let mut payload_len: usize = payload.len();
    root.broadcast_into(&mut payload_len);
    if !is_owner {
        payload = vec![0; payload_len];
    }

    // Second broadcast: the payload itself.
    root.broadcast_into(&mut payload);

    if is_owner {
        local_event(local_index).clone()
    } else {
        let decoded: EventData = bitcode::deserialize(&payload[..]).unwrap();
        Event::new(Arc::new(decoded), metadata.clone())
    }
}
1216
/// Collectively fetch the events with global indices `start..min(start + len, total)`.
///
/// Each rank serializes the part of the requested window it owns, according to
/// `world.partition(total)`. The byte streams are all-gathered, and every rank decodes them in
/// rank order. All ranks in `world` must call this with the same `start`, `len`, and `total`.
///
/// Returns an empty vector, without communicating, when `len == 0` or `start >= total`.
///
/// # Panics
///
/// Panics in either of these cases:
/// - a rank's share of the window exceeds `i32::MAX` events or bytes (MPI counts are `i32`);
/// - serialization or deserialization fails.
#[cfg(feature = "mpi")]
fn fetch_event_chunk_mpi_generic<'a, F>(
    start: usize,
    len: usize,
    total: usize,
    world: &SimpleCommunicator,
    metadata: &Arc<DatasetMetadata>,
    local_event: F,
) -> Vec<Event>
where
    F: Fn(usize) -> &'a Event,
{
    if len == 0 || start >= total {
        return Vec::new();
    }

    // MPI counts and displacements are `i32`; refuse to truncate them silently.
    let to_mpi_count = |value: usize, what: &str| -> i32 {
        assert!(
            value <= i32::MAX as usize,
            "MPI event chunk {} ({}) exceeds i32::MAX; lower LADDU_MPI_EVENT_FETCH_CHUNK_SIZE",
            what,
            value
        );
        value as i32
    };

    let end = start.saturating_add(len).min(total);
    let partition = world.partition(total);
    let local_range = partition.range_for_rank(world.rank() as usize);
    // Intersect the requested window with the global range owned by this rank.
    let owned_start = start.max(local_range.start);
    let owned_end = end.min(local_range.end);
    let local_indices = if owned_start < owned_end {
        (owned_start - local_range.start)..(owned_end - local_range.start)
    } else {
        0..0
    };

    let local_events: Vec<EventData> = local_indices
        .map(|local_index| local_event(local_index).data().clone())
        .collect();
    let serialized_local = if local_events.is_empty() {
        Vec::new()
    } else {
        bitcode::serialize(&local_events).expect("failed to serialize an MPI event chunk")
    };
    let local_event_count = to_mpi_count(local_events.len(), "event count");
    let local_byte_count = to_mpi_count(serialized_local.len(), "byte count");

    let n_ranks = world.size() as usize;
    let mut gathered_event_counts = vec![0_i32; n_ranks];
    let mut gathered_byte_counts = vec![0_i32; n_ranks];
    world.all_gather_into(&local_event_count, &mut gathered_event_counts);
    world.all_gather_into(&local_byte_count, &mut gathered_byte_counts);

    // Exclusive prefix sum of the byte counts: each rank's offset in the gathered buffer.
    let mut gathered_byte_displs = vec![0_i32; n_ranks];
    for rank in 1..n_ranks {
        gathered_byte_displs[rank] = gathered_byte_displs[rank - 1]
            .checked_add(gathered_byte_counts[rank - 1])
            .expect("MPI event chunk byte displacement overflowed i32");
    }
    let gathered_bytes = world.all_gather_with_counts(
        &serialized_local,
        &gathered_byte_counts,
        &gathered_byte_displs,
    );

    let mut events = Vec::with_capacity(end - start);
    for rank in 0..n_ranks {
        let event_count = gathered_event_counts[rank] as usize;
        if event_count == 0 {
            continue;
        }
        let byte_start = gathered_byte_displs[rank] as usize;
        let byte_end = byte_start + gathered_byte_counts[rank] as usize;
        let decoded: Vec<EventData> = bitcode::deserialize(&gathered_bytes[byte_start..byte_end])
            .expect("failed to deserialize an MPI event chunk");
        debug_assert_eq!(decoded.len(), event_count);
        events.extend(
            decoded
                .into_iter()
                .map(|event| Event::new(Arc::new(event), metadata.clone())),
        );
    }

    events
}
1292
1293impl Dataset {
1294    #[cfg(feature = "mpi")]
1295    pub(crate) fn set_cached_global_event_count_from_world(&mut self, world: &SimpleCommunicator) {
1296        let local_count = self.n_events_local();
1297        let mut global_count = 0usize;
1298        world.all_reduce_into(
1299            &local_count,
1300            &mut global_count,
1301            mpi::collective::SystemOperation::sum(),
1302        );
1303        self.cached_global_event_count = global_count;
1304    }
1305
1306    #[cfg(feature = "mpi")]
1307    pub(crate) fn set_cached_global_weighted_sum_from_world(&mut self, world: &SimpleCommunicator) {
1308        let mut weighted_sums = vec![0.0_f64; world.size() as usize];
1309        world.all_gather_into(&self.cached_local_weighted_sum, &mut weighted_sums);
1310        #[cfg(feature = "rayon")]
1311        {
1312            self.cached_global_weighted_sum = weighted_sums
1313                .into_par_iter()
1314                .parallel_sum_with_accumulator::<Klein<f64>>();
1315        }
1316        #[cfg(not(feature = "rayon"))]
1317        {
1318            self.cached_global_weighted_sum = weighted_sums
1319                .into_iter()
1320                .sum_with_accumulator::<Klein<f64>>();
1321        }
1322    }
1323
1324    fn columnar_from_wrapped_events(
1325        events: &[Event],
1326        metadata: Arc<DatasetMetadata>,
1327    ) -> LadduResult<DatasetStorage> {
1328        let n_events = events.len();
1329        let (n_p4, n_aux) = match events.first() {
1330            Some(first) => (first.p4s.len(), first.aux.len()),
1331            None => (metadata.p4_names.len(), metadata.aux_names.len()),
1332        };
1333        let mut p4 = (0..n_p4)
1334            .map(|_| ColumnarP4Column::with_capacity(n_events))
1335            .collect::<Vec<_>>();
1336        let mut aux = (0..n_aux)
1337            .map(|_| Vec::with_capacity(n_events))
1338            .collect::<Vec<_>>();
1339        let mut weights = Vec::with_capacity(n_events);
1340        for (event_index, event) in events.iter().enumerate() {
1341            if event.p4s.len() != n_p4 || event.aux.len() != n_aux {
1342                return Err(LadduError::Custom(format!(
1343                    "Ragged dataset shape at event {event_index}: expected ({n_p4} p4, {n_aux} aux), got ({} p4, {} aux)",
1344                    event.p4s.len(),
1345                    event.aux.len()
1346                )));
1347            }
1348            for (column, value) in p4.iter_mut().zip(event.p4s.iter()) {
1349                column.push(*value);
1350            }
1351            for (column, value) in aux.iter_mut().zip(event.aux.iter()) {
1352                column.push(*value);
1353            }
1354            weights.push(event.weight);
1355        }
1356        Ok(DatasetStorage {
1357            metadata,
1358            p4,
1359            aux,
1360            weights,
1361        })
1362    }
1363
1364    /// Create a new [`Dataset`] from a list of [`EventData`] (non-MPI version).
1365    ///
1366    /// # Notes
1367    ///
1368    /// This method is not intended to be called in analyses but rather in writing methods
1369    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::new`] instead.
1370    pub fn new_local(events: Vec<Arc<EventData>>, metadata: Arc<DatasetMetadata>) -> Self {
1371        let wrapped_events = events
1372            .into_iter()
1373            .map(|event| Event::new(event, metadata.clone()))
1374            .collect::<Vec<_>>();
1375        #[cfg(feature = "mpi")]
1376        let local_count = wrapped_events.len();
1377        let columnar = Self::columnar_from_wrapped_events(&wrapped_events, metadata.clone())
1378            .expect("Dataset requires rectangular p4/aux columns for canonical columnar storage");
1379        #[cfg(feature = "rayon")]
1380        let local_weighted_sum = columnar
1381            .weights
1382            .par_iter()
1383            .copied()
1384            .parallel_sum_with_accumulator::<Klein<f64>>();
1385        #[cfg(not(feature = "rayon"))]
1386        let local_weighted_sum = columnar
1387            .weights
1388            .iter()
1389            .copied()
1390            .sum_with_accumulator::<Klein<f64>>();
1391        Dataset {
1392            events: wrapped_events,
1393            columnar,
1394            metadata,
1395            cached_local_weighted_sum: local_weighted_sum,
1396            #[cfg(feature = "mpi")]
1397            cached_global_event_count: local_count,
1398            #[cfg(feature = "mpi")]
1399            cached_global_weighted_sum: local_weighted_sum,
1400        }
1401    }
1402
1403    /// Create a new [`Dataset`] from a list of [`EventData`] (MPI-compatible version).
1404    ///
1405    /// # Notes
1406    ///
1407    /// This method is not intended to be called in analyses but rather in writing methods
1408    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::new`] instead.
1409    #[cfg(feature = "mpi")]
1410    pub fn new_mpi(
1411        events: Vec<Arc<EventData>>,
1412        metadata: Arc<DatasetMetadata>,
1413        world: &SimpleCommunicator,
1414    ) -> Self {
1415        let partitions = Dataset::partition(events, world);
1416        let local: Vec<Event> = partitions[world.rank() as usize]
1417            .iter()
1418            .cloned()
1419            .map(|event| Event::new(event, metadata.clone()))
1420            .collect();
1421        let columnar = Self::columnar_from_wrapped_events(&local, metadata.clone())
1422            .expect("Dataset requires rectangular p4/aux columns for canonical columnar storage");
1423        #[cfg(feature = "rayon")]
1424        let local_weighted_sum = columnar
1425            .weights
1426            .par_iter()
1427            .copied()
1428            .parallel_sum_with_accumulator::<Klein<f64>>();
1429        #[cfg(not(feature = "rayon"))]
1430        let local_weighted_sum = columnar
1431            .weights
1432            .iter()
1433            .copied()
1434            .sum_with_accumulator::<Klein<f64>>();
1435        let mut dataset = Dataset {
1436            events: local,
1437            columnar,
1438            metadata,
1439            cached_local_weighted_sum: local_weighted_sum,
1440            cached_global_event_count: 0,
1441            cached_global_weighted_sum: local_weighted_sum,
1442        };
1443        dataset.set_cached_global_event_count_from_world(world);
1444        dataset.set_cached_global_weighted_sum_from_world(world);
1445        dataset
1446    }
1447
1448    /// Create a new [`Dataset`] from a list of [`EventData`].
1449    ///
1450    /// This method is prefered for external use because it contains proper MPI construction
1451    /// methods. Constructing a [`Dataset`] manually is possible, but may cause issues when
1452    /// interfacing with MPI and should be avoided unless you know what you are doing.
1453    pub fn new(events: Vec<Arc<EventData>>) -> Self {
1454        Dataset::new_with_metadata(events, Arc::new(DatasetMetadata::default()))
1455    }
1456
1457    /// Create a dataset with explicit metadata for name-based lookups.
1458    /// Create a dataset with explicit metadata for name-based lookups.
1459    pub fn new_with_metadata(events: Vec<Arc<EventData>>, metadata: Arc<DatasetMetadata>) -> Self {
1460        #[cfg(feature = "mpi")]
1461        {
1462            if let Some(world) = crate::mpi::get_world() {
1463                return Dataset::new_mpi(events, metadata, &world);
1464            }
1465        }
1466        Dataset::new_local(events, metadata)
1467    }
1468
1469    /// The number of [`EventData`]s in the [`Dataset`] (non-MPI version).
1470    ///
1471    /// # Notes
1472    ///
1473    /// This method is not intended to be called in analyses but rather in writing methods
1474    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::n_events`] instead.
1475    pub fn n_events_local(&self) -> usize {
1476        self.columnar.n_events()
1477    }
1478
1479    /// The number of [`EventData`]s in the [`Dataset`] (MPI-compatible version).
1480    ///
1481    /// # Notes
1482    ///
1483    /// This method is not intended to be called in analyses but rather in writing methods
1484    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::n_events`] instead.
1485    #[cfg(feature = "mpi")]
1486    pub fn n_events_mpi(&self, _world: &SimpleCommunicator) -> usize {
1487        self.cached_global_event_count
1488    }
1489
1490    /// The number of [`EventData`]s in the [`Dataset`].
1491    pub fn n_events(&self) -> usize {
1492        #[cfg(feature = "mpi")]
1493        {
1494            if let Some(world) = crate::mpi::get_world() {
1495                return self.n_events_mpi(&world);
1496            }
1497        }
1498        self.n_events_local()
1499    }
1500
1501    /// Alias for [`Dataset::n_events`].
1502    ///
1503    /// This returns the global event count under MPI.
1504    pub fn n_events_global(&self) -> usize {
1505        self.n_events()
1506    }
1507}
1508
1509impl Dataset {
1510    /// Extract a list of weights over each [`EventData`] in the [`Dataset`] (non-MPI version).
1511    ///
1512    /// # Notes
1513    ///
1514    /// This method is not intended to be called in analyses but rather in writing methods
1515    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::weights`] instead.
1516    pub fn weights_local(&self) -> Vec<f64> {
1517        self.columnar.weights.clone()
1518    }
1519
1520    /// Extract a list of weights over each [`EventData`] in the [`Dataset`] (MPI-compatible version).
1521    ///
1522    /// # Notes
1523    ///
1524    /// This method is not intended to be called in analyses but rather in writing methods
1525    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::weights`] instead.
1526    #[cfg(feature = "mpi")]
1527    pub fn weights_mpi(&self, world: &SimpleCommunicator) -> Vec<f64> {
1528        let local_weights = self.weights_local();
1529        let n_events = self.n_events();
1530        let mut buffer: Vec<f64> = vec![0.0; n_events];
1531        let (counts, displs) = world.get_counts_displs(n_events);
1532        {
1533            // NOTE: gather is required because this API returns full global event weights.
1534            // Use all-reduce only for scalar/vector aggregate values.
1535            let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
1536            world.all_gather_varcount_into(&local_weights, &mut partitioned_buffer);
1537        }
1538        buffer
1539    }
1540
1541    /// Extract a list of weights over each [`EventData`] in the [`Dataset`].
1542    pub fn weights(&self) -> Vec<f64> {
1543        #[cfg(feature = "mpi")]
1544        {
1545            if let Some(world) = crate::mpi::get_world() {
1546                return self.weights_mpi(&world);
1547            }
1548        }
1549        self.weights_local()
1550    }
1551
1552    /// Alias for [`Dataset::weights`].
1553    ///
1554    /// This returns the global weight vector in dataset order under MPI.
1555    pub fn weights_global(&self) -> Vec<f64> {
1556        self.weights()
1557    }
1558
1559    /// Returns the sum of the weights for each [`EventData`] in the [`Dataset`] (non-MPI version).
1560    ///
1561    /// # Notes
1562    ///
1563    /// This method is not intended to be called in analyses but rather in writing methods
1564    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::n_events_weighted`] instead.
1565    pub fn n_events_weighted_local(&self) -> f64 {
1566        self.cached_local_weighted_sum
1567    }
1568    /// Returns the sum of the weights for each [`EventData`] in the [`Dataset`] (MPI-compatible version).
1569    ///
1570    /// # Notes
1571    ///
1572    /// This method is not intended to be called in analyses but rather in writing methods
1573    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::n_events_weighted`] instead.
1574    #[cfg(feature = "mpi")]
1575    pub fn n_events_weighted_mpi(&self, _world: &SimpleCommunicator) -> f64 {
1576        self.cached_global_weighted_sum
1577    }
1578
1579    /// Returns the sum of the weights for each [`EventData`] in the [`Dataset`].
1580    pub fn n_events_weighted(&self) -> f64 {
1581        #[cfg(feature = "mpi")]
1582        {
1583            if let Some(world) = crate::mpi::get_world() {
1584                return self.n_events_weighted_mpi(&world);
1585            }
1586        }
1587        self.n_events_weighted_local()
1588    }
1589
1590    /// Alias for [`Dataset::n_events_weighted`].
1591    ///
1592    /// This returns the global weighted event count under MPI.
1593    pub fn n_events_weighted_global(&self) -> f64 {
1594        self.n_events_weighted()
1595    }
1596
1597    /// Generate a new dataset with the same length by resampling the events in the original datset
1598    /// with replacement. This can be used to perform error analysis via the bootstrap method. (non-MPI version).
1599    ///
1600    /// # Notes
1601    ///
1602    /// This method is not intended to be called in analyses but rather in writing methods
1603    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::bootstrap`] instead.
1604    pub fn bootstrap_local(&self, seed: usize) -> Arc<Dataset> {
1605        let mut rng = fastrand::Rng::with_seed(seed as u64);
1606        let mut indices: Vec<usize> = (0..self.n_events())
1607            .map(|_| rng.usize(0..self.n_events()))
1608            .collect::<Vec<usize>>();
1609        indices.sort();
1610        #[cfg(feature = "rayon")]
1611        let bootstrapped_events: Vec<Arc<EventData>> = indices
1612            .into_par_iter()
1613            .map(|idx| self.events[idx].data_arc())
1614            .collect();
1615        #[cfg(not(feature = "rayon"))]
1616        let bootstrapped_events: Vec<Arc<EventData>> = indices
1617            .into_iter()
1618            .map(|idx| self.events[idx].data_arc())
1619            .collect();
1620        Arc::new(Dataset::new_with_metadata(
1621            bootstrapped_events,
1622            self.metadata.clone(),
1623        ))
1624    }
1625
1626    /// Generate a new dataset with the same length by resampling the events in the original datset
1627    /// with replacement. This can be used to perform error analysis via the bootstrap method. (MPI-compatible version).
1628    ///
1629    /// # Notes
1630    ///
1631    /// This method is not intended to be called in analyses but rather in writing methods
1632    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::bootstrap`] instead.
1633    #[cfg(feature = "mpi")]
1634    pub fn bootstrap_mpi(&self, seed: usize, world: &SimpleCommunicator) -> Arc<Dataset> {
1635        let n_events = self.n_events();
1636        let mut indices: Vec<usize> = vec![0; n_events];
1637        if world.is_root() {
1638            let mut rng = fastrand::Rng::with_seed(seed as u64);
1639            indices = (0..n_events)
1640                .map(|_| rng.usize(0..n_events))
1641                .collect::<Vec<usize>>();
1642            indices.sort();
1643        }
1644        world.process_at_root().broadcast_into(&mut indices);
1645        let local_indices: Vec<usize> = indices
1646            .into_iter()
1647            .filter_map(|idx| {
1648                let (owning_rank, local_index) = world.owner_of_global_index(idx, n_events);
1649                if world.rank() == owning_rank {
1650                    Some(local_index)
1651                } else {
1652                    None
1653                }
1654            })
1655            .collect();
1656        // `local_indices` only contains indices owned by the current rank, translating them into
1657        // local indices on the events vector.
1658        #[cfg(feature = "rayon")]
1659        let bootstrapped_events: Vec<Arc<EventData>> = local_indices
1660            .into_par_iter()
1661            .map(|idx| self.events[idx].data_arc())
1662            .collect();
1663        #[cfg(not(feature = "rayon"))]
1664        let bootstrapped_events: Vec<Arc<EventData>> = local_indices
1665            .into_iter()
1666            .map(|idx| self.events[idx].data_arc())
1667            .collect();
1668        Arc::new(Dataset::new_with_metadata(
1669            bootstrapped_events,
1670            self.metadata.clone(),
1671        ))
1672    }
1673
1674    /// Generate a new dataset with the same length by resampling the events in the original datset
1675    /// with replacement. This can be used to perform error analysis via the bootstrap method.
1676    pub fn bootstrap(&self, seed: usize) -> Arc<Dataset> {
1677        #[cfg(feature = "mpi")]
1678        {
1679            if let Some(world) = crate::mpi::get_world() {
1680                return self.bootstrap_mpi(seed, &world);
1681            }
1682        }
1683        self.bootstrap_local(seed)
1684    }
1685
1686    /// Filter the [`Dataset`] by a given [`VariableExpression`], selecting events for which
1687    /// the expression returns `true`.
1688    pub fn filter(&self, expression: &VariableExpression) -> LadduResult<Arc<Dataset>> {
1689        let compiled = expression.compile(&self.metadata)?;
1690        #[cfg(feature = "rayon")]
1691        let filtered_events: Vec<Arc<EventData>> = (0..self.n_events_local())
1692            .into_par_iter()
1693            .filter_map(|event_index| {
1694                let event = self.event_view(event_index);
1695                compiled
1696                    .evaluate(&event)
1697                    .then(|| self.events[event_index].data_arc())
1698            })
1699            .collect();
1700        #[cfg(not(feature = "rayon"))]
1701        let filtered_events: Vec<Arc<EventData>> = (0..self.n_events_local())
1702            .into_iter()
1703            .filter_map(|event_index| {
1704                let event = self.event_view(event_index);
1705                compiled
1706                    .evaluate(&event)
1707                    .then(|| self.events[event_index].data_arc())
1708            })
1709            .collect();
1710        Ok(Arc::new(Dataset::new_with_metadata(
1711            filtered_events,
1712            self.metadata.clone(),
1713        )))
1714    }
1715
1716    /// Bin a [`Dataset`] by the value of the given [`Variable`] into a number of `bins` within the
1717    /// given `range`.
1718    pub fn bin_by<V>(
1719        &self,
1720        mut variable: V,
1721        bins: usize,
1722        range: (f64, f64),
1723    ) -> LadduResult<BinnedDataset>
1724    where
1725        V: Variable,
1726    {
1727        variable.bind(self.metadata())?;
1728        let bin_width = (range.1 - range.0) / bins as f64;
1729        let bin_edges = get_bin_edges(bins, range);
1730        let variable = variable;
1731        #[cfg(feature = "rayon")]
1732        let evaluated: Vec<(usize, Arc<EventData>)> = (0..self.n_events_local())
1733            .into_par_iter()
1734            .filter_map(|event| {
1735                let value = variable.value(&self.event_view(event));
1736                if value >= range.0 && value < range.1 {
1737                    let bin_index = ((value - range.0) / bin_width) as usize;
1738                    let bin_index = bin_index.min(bins - 1);
1739                    Some((bin_index, self.events[event].data_arc()))
1740                } else {
1741                    None
1742                }
1743            })
1744            .collect();
1745        #[cfg(not(feature = "rayon"))]
1746        let evaluated: Vec<(usize, Arc<EventData>)> = (0..self.n_events_local())
1747            .into_iter()
1748            .filter_map(|event| {
1749                let value = variable.value(&self.event_view(event));
1750                if value >= range.0 && value < range.1 {
1751                    let bin_index = ((value - range.0) / bin_width) as usize;
1752                    let bin_index = bin_index.min(bins - 1);
1753                    Some((bin_index, self.events[event].data_arc()))
1754                } else {
1755                    None
1756                }
1757            })
1758            .collect();
1759        let mut binned_events: Vec<Vec<Arc<EventData>>> = vec![Vec::default(); bins];
1760        for (bin_index, event) in evaluated {
1761            binned_events[bin_index].push(event.clone());
1762        }
1763        #[cfg(feature = "rayon")]
1764        let datasets: Vec<Arc<Dataset>> = binned_events
1765            .into_par_iter()
1766            .map(|events| Arc::new(Dataset::new_with_metadata(events, self.metadata.clone())))
1767            .collect();
1768        #[cfg(not(feature = "rayon"))]
1769        let datasets: Vec<Arc<Dataset>> = binned_events
1770            .into_iter()
1771            .map(|events| Arc::new(Dataset::new_with_metadata(events, self.metadata.clone())))
1772            .collect();
1773        Ok(BinnedDataset {
1774            datasets,
1775            edges: bin_edges,
1776        })
1777    }
1778
1779    /// Boost all the four-momenta in all [`EventData`]s to the rest frame of the given set of
1780    /// four-momenta identified by name.
1781    pub fn boost_to_rest_frame_of<S>(&self, names: &[S]) -> Arc<Dataset>
1782    where
1783        S: AsRef<str>,
1784    {
1785        let mut indices: Vec<usize> = Vec::new();
1786        for name in names {
1787            let name_ref = name.as_ref();
1788            if let Some(selection) = self.metadata.p4_selection(name_ref) {
1789                indices.extend_from_slice(selection.indices());
1790            } else {
1791                panic!("Unknown particle name '{name}'", name = name_ref);
1792            }
1793        }
1794        #[cfg(feature = "rayon")]
1795        let boosted_events: Vec<Arc<EventData>> = self
1796            .events
1797            .par_iter()
1798            .map(|event| Arc::new(event.data().boost_to_rest_frame_of(&indices)))
1799            .collect();
1800        #[cfg(not(feature = "rayon"))]
1801        let boosted_events: Vec<Arc<EventData>> = self
1802            .events
1803            .iter()
1804            .map(|event| Arc::new(event.data().boost_to_rest_frame_of(&indices)))
1805            .collect();
1806        Arc::new(Dataset::new_with_metadata(
1807            boosted_events,
1808            self.metadata.clone(),
1809        ))
1810    }
1811    /// Evaluate a [`Variable`] on every event in the [`Dataset`].
1812    pub fn evaluate<V: Variable>(&self, variable: &V) -> LadduResult<Vec<f64>> {
1813        variable.value_on(self)
1814    }
1815}
1816
1817#[cfg(test)]
1818pub(crate) use io::write_parquet_storage;
1819pub use io::{
1820    read_parquet, read_parquet_chunks, read_parquet_chunks_with_options, read_root, write_parquet,
1821    write_root,
1822};
1823#[cfg(test)]
1824pub(crate) use io::{read_parquet_storage, read_root_storage};
1825
1826impl_op_ex!(+ |a: &Dataset, b: &Dataset| -> Dataset {
1827    debug_assert_eq!(a.metadata.p4_names, b.metadata.p4_names);
1828    debug_assert_eq!(a.metadata.aux_names, b.metadata.aux_names);
1829    let events = a
1830        .events
1831        .iter()
1832        .chain(b.events.iter())
1833        .map(Event::data_arc)
1834        .collect::<Vec<_>>();
1835    Dataset::new_with_metadata(events, a.metadata.clone())
1836});
1837
1838/// Incrementally builds a [`Dataset`] from chunked dataset reads.
1839#[derive(Default)]
1840pub struct DatasetChunkBuilder {
1841    metadata: Option<Arc<DatasetMetadata>>,
1842    events: Vec<Arc<EventData>>,
1843}
1844
1845impl DatasetChunkBuilder {
1846    /// Create an empty chunk builder.
1847    pub fn new() -> Self {
1848        Self::default()
1849    }
1850
1851    /// Append a dataset chunk.
1852    pub fn push_chunk(&mut self, chunk: &Dataset) -> LadduResult<()> {
1853        if let Some(existing) = &self.metadata {
1854            if existing.p4_names != chunk.metadata.p4_names
1855                || existing.aux_names != chunk.metadata.aux_names
1856            {
1857                return Err(LadduError::Custom(
1858                    "Dataset chunk metadata does not match previous chunks".to_string(),
1859                ));
1860            }
1861        } else {
1862            self.metadata = Some(chunk.metadata.clone());
1863        }
1864        self.events
1865            .extend(chunk.events_local().iter().map(Event::data_arc));
1866        Ok(())
1867    }
1868
1869    /// Finish building a dataset from all received chunks.
1870    pub fn finish(self) -> Arc<Dataset> {
1871        let metadata = self
1872            .metadata
1873            .unwrap_or_else(|| Arc::new(DatasetMetadata::empty()));
1874        Arc::new(Dataset::new_with_metadata(self.events, metadata))
1875    }
1876}
1877
1878/// Fold over chunked datasets without materializing a full dataset.
1879pub fn try_fold_dataset_chunks<I, T, F>(chunks: I, init: T, mut op: F) -> LadduResult<T>
1880where
1881    I: IntoIterator<Item = LadduResult<Arc<Dataset>>>,
1882    F: FnMut(T, &Dataset) -> LadduResult<T>,
1883{
1884    let mut acc = init;
1885    for chunk in chunks {
1886        let chunk = chunk?;
1887        acc = op(acc, &chunk)?;
1888    }
1889    Ok(acc)
1890}
1891
/// Options for reading a [`Dataset`] from a file.
///
/// Build these with [`DatasetReadOptions::new`] and the builder-style methods; any field left
/// unset falls back to auto-detection or the reader's defaults.
///
/// # See Also
/// [`read_parquet`], [`read_root`]
#[derive(Default, Clone)]
pub struct DatasetReadOptions {
    /// Particle names to read from the data file. When `None`, the names detected in the file
    /// are used.
    pub p4_names: Option<Vec<String>>,
    /// Auxiliary scalar names to read from the data file. When `None`, the names detected in
    /// the file are used.
    pub aux_names: Option<Vec<String>>,
    /// Name of the tree to read when loading ROOT files. When absent and the file contains a
    /// single tree, it will be selected automatically.
    pub tree: Option<String>,
    /// Optional aliases mapping logical names to selections of four-momenta. When non-empty,
    /// they are registered on the resolved metadata after the particle names are known.
    pub aliases: IndexMap<String, P4Selection>,
    /// Preferred chunk size for chunked read APIs. The [`DatasetReadOptions::chunk_size`]
    /// builder clamps values below 1 to 1; assigning this field directly bypasses that clamp.
    // NOTE(review): presumably chunked readers fall back to `DEFAULT_READ_CHUNK_SIZE` when this
    // is `None` — confirm in the `io` module.
    pub chunk_size: Option<usize>,
}
1910
/// Precision for writing floating-point columns.
///
/// Defaults to [`FloatPrecision::F64`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum FloatPrecision {
    /// 32-bit (single-precision) floats.
    F32,
    /// 64-bit (double-precision) floats; this is the default.
    #[default]
    F64,
}
1920
1921/// Options for writing a [`Dataset`] to disk.
1922#[derive(Clone, Debug)]
1923pub struct DatasetWriteOptions {
1924    /// Number of events to include in each batch when writing.
1925    pub batch_size: usize,
1926    /// Floating-point precision to use for persisted columns.
1927    pub precision: FloatPrecision,
1928    /// Tree name to use when writing ROOT files.
1929    pub tree: Option<String>,
1930}
1931
1932impl Default for DatasetWriteOptions {
1933    fn default() -> Self {
1934        Self {
1935            batch_size: DEFAULT_WRITE_BATCH_SIZE,
1936            precision: FloatPrecision::default(),
1937            tree: None,
1938        }
1939    }
1940}
1941
1942impl DatasetWriteOptions {
1943    /// Override the batch size used for writing; defaults to 10_000.
1944    pub fn batch_size(mut self, batch_size: usize) -> Self {
1945        self.batch_size = batch_size;
1946        self
1947    }
1948
1949    /// Select the floating-point precision for persisted columns.
1950    pub fn precision(mut self, precision: FloatPrecision) -> Self {
1951        self.precision = precision;
1952        self
1953    }
1954
1955    /// Set the ROOT tree name (defaults to \"events\").
1956    pub fn tree<S: Into<String>>(mut self, name: S) -> Self {
1957        self.tree = Some(name.into());
1958        self
1959    }
1960}
1961impl DatasetReadOptions {
1962    /// Create a new [`Default`] set of [`DatasetReadOptions`].
1963    pub fn new() -> Self {
1964        Self::default()
1965    }
1966
1967    /// If provided, the specified particles will be read from the data file (assuming columns with
1968    /// required suffixes are present, i.e. `<particle>_px`, `<particle>_py`, `<particle>_pz`, and `<particle>_e`). Otherwise, all valid columns with these suffixes will be read.
1969    pub fn p4_names<I, S>(mut self, names: I) -> Self
1970    where
1971        I: IntoIterator<Item = S>,
1972        S: AsRef<str>,
1973    {
1974        self.p4_names = Some(names.into_iter().map(|s| s.as_ref().to_string()).collect());
1975        self
1976    }
1977
1978    /// If provided, the specified columns will be read as auxiliary scalars. Otherwise, all valid
1979    /// columns which do not satisfy the conditions required to be read as four-momenta will be
1980    /// used.
1981    pub fn aux_names<I, S>(mut self, names: I) -> Self
1982    where
1983        I: IntoIterator<Item = S>,
1984        S: AsRef<str>,
1985    {
1986        self.aux_names = Some(names.into_iter().map(|s| s.as_ref().to_string()).collect());
1987        self
1988    }
1989
1990    /// Select the tree to read when opening ROOT files.
1991    pub fn tree<S>(mut self, name: S) -> Self
1992    where
1993        S: AsRef<str>,
1994    {
1995        self.tree = Some(name.as_ref().to_string());
1996        self
1997    }
1998
1999    /// Register an alias for one or more existing four-momenta.
2000    pub fn alias<N, S>(mut self, name: N, selection: S) -> Self
2001    where
2002        N: Into<String>,
2003        S: IntoP4Selection,
2004    {
2005        self.aliases.insert(name.into(), selection.into_selection());
2006        self
2007    }
2008
2009    /// Register multiple aliases for four-momenta selections.
2010    pub fn aliases<I, N, S>(mut self, aliases: I) -> Self
2011    where
2012        I: IntoIterator<Item = (N, S)>,
2013        N: Into<String>,
2014        S: IntoP4Selection,
2015    {
2016        for (name, selection) in aliases {
2017            self = self.alias(name, selection);
2018        }
2019        self
2020    }
2021
2022    /// Set the chunk size used by chunked read APIs; values below 1 are clamped to 1.
2023    pub fn chunk_size(mut self, chunk_size: usize) -> Self {
2024        self.chunk_size = Some(chunk_size.max(1));
2025        self
2026    }
2027
2028    fn resolve_metadata(
2029        &self,
2030        detected_p4_names: Vec<String>,
2031        detected_aux_names: Vec<String>,
2032    ) -> LadduResult<Arc<DatasetMetadata>> {
2033        let p4_names_vec = self.p4_names.clone().unwrap_or(detected_p4_names);
2034        let aux_names_vec = self.aux_names.clone().unwrap_or(detected_aux_names);
2035
2036        let mut metadata = DatasetMetadata::new(p4_names_vec, aux_names_vec)?;
2037        if !self.aliases.is_empty() {
2038            metadata.add_p4_aliases(self.aliases.clone())?;
2039        }
2040        Ok(Arc::new(metadata))
2041    }
2042}
2043
/// Default number of events per write batch, used by `DatasetWriteOptions::default`.
const DEFAULT_WRITE_BATCH_SIZE: usize = 10_000;
/// Default chunk size for chunked dataset reads.
// NOTE(review): presumably used by the `io` readers when `DatasetReadOptions::chunk_size` is
// `None` — confirm at the use sites.
pub(crate) const DEFAULT_READ_CHUNK_SIZE: usize = 10_000;
2046
2047/// A list of [`Dataset`]s formed by binning [`EventData`] by some [`Variable`].
2048pub struct BinnedDataset {
2049    datasets: Vec<Arc<Dataset>>,
2050    edges: Vec<f64>,
2051}
2052
2053impl Index<usize> for BinnedDataset {
2054    type Output = Arc<Dataset>;
2055
2056    fn index(&self, index: usize) -> &Self::Output {
2057        &self.datasets[index]
2058    }
2059}
2060
2061impl IndexMut<usize> for BinnedDataset {
2062    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
2063        &mut self.datasets[index]
2064    }
2065}
2066
2067impl Deref for BinnedDataset {
2068    type Target = Vec<Arc<Dataset>>;
2069
2070    fn deref(&self) -> &Self::Target {
2071        &self.datasets
2072    }
2073}
2074
2075impl DerefMut for BinnedDataset {
2076    fn deref_mut(&mut self) -> &mut Self::Target {
2077        &mut self.datasets
2078    }
2079}
2080
2081impl BinnedDataset {
2082    /// The number of bins in the [`BinnedDataset`].
2083    pub fn n_bins(&self) -> usize {
2084        self.datasets.len()
2085    }
2086
2087    /// Returns a list of the bin edges that were used to form the [`BinnedDataset`].
2088    pub fn edges(&self) -> Vec<f64> {
2089        self.edges.clone()
2090    }
2091
2092    /// Returns the range that was used to form the [`BinnedDataset`].
2093    pub fn range(&self) -> (f64, f64) {
2094        (self.edges[0], self.edges[self.n_bins()])
2095    }
2096}
2097
2098#[cfg(test)]
2099mod tests {
2100    use crate::Mass;
2101
2102    use super::*;
2103    #[cfg(feature = "mpi")]
2104    use crate::mpi::{finalize_mpi, get_world, use_mpi};
2105    use crate::utils::vectors::Vec3;
2106    use approx::{assert_relative_eq, assert_relative_ne};
2107    use fastrand;
2108    #[cfg(feature = "mpi")]
2109    use mpi_test::mpi_test;
2110    use serde::{Deserialize, Serialize};
2111    use std::{
2112        env, fs,
2113        path::{Path, PathBuf},
2114    };
2115
2116    fn test_data_path(file: &str) -> PathBuf {
2117        Path::new(env!("CARGO_MANIFEST_DIR"))
2118            .join("test_data")
2119            .join(file)
2120    }
2121
2122    fn open_test_dataset(file: &str, options: DatasetReadOptions) -> Arc<Dataset> {
2123        let path = test_data_path(file);
2124        let path_str = path.to_str().expect("test data path should be valid UTF-8");
2125        let ext = path
2126            .extension()
2127            .and_then(|ext| ext.to_str())
2128            .unwrap_or_default()
2129            .to_ascii_lowercase();
2130        match ext.as_str() {
2131            "parquet" => read_parquet(path_str, &options),
2132            "root" => read_root(path_str, &options),
2133            other => panic!("Unsupported extension in test data: {other}"),
2134        }
2135        .expect("dataset should open")
2136    }
2137
2138    fn make_temp_dir() -> PathBuf {
2139        let dir = env::temp_dir().join(format!("laddu_test_{}", fastrand::u64(..)));
2140        fs::create_dir(&dir).expect("temp dir should be created");
2141        dir
2142    }
2143
2144    #[cfg(feature = "mpi")]
2145    fn mpi_chunk_test_dataset(n_events: usize) -> Dataset {
2146        let metadata = test_dataset().metadata_arc();
2147        let base = test_event();
2148        let events = (0..n_events)
2149            .map(|index| {
2150                let mut event = base.clone();
2151                event.p4s[0] =
2152                    Vec3::new(index as f64 * 0.1, 0.0, 8.747 + index as f64 * 0.01).with_mass(0.0);
2153                event.aux[0] += index as f64;
2154                event.aux[1] += index as f64 * 0.5;
2155                event.weight = 1.0 + index as f64;
2156                Arc::new(event)
2157            })
2158            .collect();
2159        Dataset::new_with_metadata(events, metadata)
2160    }
2161
2162    fn assert_events_close(left: &Event, right: &Event, p4_names: &[&str], aux_names: &[&str]) {
2163        for name in p4_names {
2164            let lp4 = left
2165                .p4(name)
2166                .unwrap_or_else(|| panic!("missing p4 '{name}' in left dataset"));
2167            let rp4 = right
2168                .p4(name)
2169                .unwrap_or_else(|| panic!("missing p4 '{name}' in right dataset"));
2170            assert_relative_eq!(lp4.px(), rp4.px(), epsilon = 1e-9);
2171            assert_relative_eq!(lp4.py(), rp4.py(), epsilon = 1e-9);
2172            assert_relative_eq!(lp4.pz(), rp4.pz(), epsilon = 1e-9);
2173            assert_relative_eq!(lp4.e(), rp4.e(), epsilon = 1e-9);
2174        }
2175        let left_aux = left.aux();
2176        let right_aux = right.aux();
2177        for name in aux_names {
2178            let laux = left_aux
2179                .get(name)
2180                .copied()
2181                .unwrap_or_else(|| panic!("missing aux '{name}' in left dataset"));
2182            let raux = right_aux
2183                .get(name)
2184                .copied()
2185                .unwrap_or_else(|| panic!("missing aux '{name}' in right dataset"));
2186            assert_relative_eq!(laux, raux, epsilon = 1e-9);
2187        }
2188        assert_relative_eq!(left.weight(), right.weight(), epsilon = 1e-9);
2189    }
2190
2191    fn assert_datasets_close(
2192        left: &Arc<Dataset>,
2193        right: &Arc<Dataset>,
2194        p4_names: &[&str],
2195        aux_names: &[&str],
2196    ) {
2197        assert_eq!(left.n_events(), right.n_events());
2198        for idx in 0..left.n_events() {
2199            let Ok(levent) = left.event(idx) else {
2200                panic!("left dataset missing event at index {idx}");
2201            };
2202            let Ok(revent) = right.event(idx) else {
2203                panic!("right dataset missing event at index {idx}");
2204            };
2205            assert_events_close(&levent, &revent, p4_names, aux_names);
2206        }
2207    }
2208
2209    fn assert_dataset_columnar_close(left: &DatasetStorage, right: &DatasetStorage) {
2210        assert_eq!(left.n_events(), right.n_events());
2211        assert_eq!(left.metadata().p4_names(), right.metadata().p4_names());
2212        assert_eq!(left.metadata().aux_names(), right.metadata().aux_names());
2213        for event_index in 0..left.n_events() {
2214            for p4_index in 0..left.metadata().p4_names().len() {
2215                let lp4 = left.p4(event_index, p4_index);
2216                let rp4 = right.p4(event_index, p4_index);
2217                assert_relative_eq!(lp4.px(), rp4.px(), epsilon = 1e-12);
2218                assert_relative_eq!(lp4.py(), rp4.py(), epsilon = 1e-12);
2219                assert_relative_eq!(lp4.pz(), rp4.pz(), epsilon = 1e-12);
2220                assert_relative_eq!(lp4.e(), rp4.e(), epsilon = 1e-12);
2221            }
2222            for aux_index in 0..left.metadata().aux_names().len() {
2223                let l = left.aux(event_index, aux_index);
2224                let r = right.aux(event_index, aux_index);
2225                assert_relative_eq!(l, r, epsilon = 1e-12);
2226            }
2227            let lw = left.weight(event_index);
2228            let rw = right.weight(event_index);
2229            assert_relative_eq!(lw, rw, epsilon = 1e-12);
2230        }
2231    }
2232
2233    #[test]
2234    fn test_from_parquet_auto_matches_explicit_names() {
2235        let auto = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
2236        let explicit_options = DatasetReadOptions::new()
2237            .p4_names(TEST_P4_NAMES)
2238            .aux_names(TEST_AUX_NAMES);
2239        let explicit = open_test_dataset("data_f32.parquet", explicit_options);
2240
2241        let mut detected_p4: Vec<&str> = auto.p4_names().iter().map(String::as_str).collect();
2242        detected_p4.sort_unstable();
2243        let mut expected_p4 = TEST_P4_NAMES.to_vec();
2244        expected_p4.sort_unstable();
2245        assert_eq!(detected_p4, expected_p4);
2246        let mut detected_aux: Vec<&str> = auto.aux_names().iter().map(String::as_str).collect();
2247        detected_aux.sort_unstable();
2248        let mut expected_aux = TEST_AUX_NAMES.to_vec();
2249        expected_aux.sort_unstable();
2250        assert_eq!(detected_aux, expected_aux);
2251        assert_datasets_close(&auto, &explicit, TEST_P4_NAMES, TEST_AUX_NAMES);
2252    }
2253
2254    #[test]
2255    fn test_from_parquet_with_aliases() {
2256        let dataset = open_test_dataset(
2257            "data_f32.parquet",
2258            DatasetReadOptions::new().alias("resonance", ["kshort1", "kshort2"]),
2259        );
2260        let event = dataset.named_event(0).expect("event should exist");
2261        let alias_vec = event.p4("resonance").expect("alias vector");
2262        let expected = event.get_p4_sum(["kshort1", "kshort2"]);
2263        assert_relative_eq!(alias_vec.px(), expected.px(), epsilon = 1e-9);
2264        assert_relative_eq!(alias_vec.py(), expected.py(), epsilon = 1e-9);
2265        assert_relative_eq!(alias_vec.pz(), expected.pz(), epsilon = 1e-9);
2266        assert_relative_eq!(alias_vec.e(), expected.e(), epsilon = 1e-9);
2267    }
2268
2269    #[test]
2270    fn test_from_parquet_alias_resolution_parity_auto_vs_explicit() {
2271        let auto = open_test_dataset(
2272            "data_f32.parquet",
2273            DatasetReadOptions::new().alias("resonance", ["kshort1", "kshort2"]),
2274        );
2275        let explicit = open_test_dataset(
2276            "data_f32.parquet",
2277            DatasetReadOptions::new()
2278                .p4_names(TEST_P4_NAMES)
2279                .aux_names(TEST_AUX_NAMES)
2280                .alias("resonance", ["kshort1", "kshort2"]),
2281        );
2282
2283        assert_datasets_close(&auto, &explicit, TEST_P4_NAMES, TEST_AUX_NAMES);
2284        for event_index in 0..auto.n_events() {
2285            let auto_event = auto
2286                .named_event(event_index)
2287                .expect("auto parquet event should exist");
2288            let explicit_event = explicit
2289                .named_event(event_index)
2290                .expect("explicit parquet event should exist");
2291
2292            let auto_alias = auto_event
2293                .p4("resonance")
2294                .expect("auto alias should resolve");
2295            let explicit_alias = explicit_event
2296                .p4("resonance")
2297                .expect("explicit alias should resolve");
2298            let auto_expected = auto_event.get_p4_sum(["kshort1", "kshort2"]);
2299            let explicit_expected = explicit_event.get_p4_sum(["kshort1", "kshort2"]);
2300
2301            assert_relative_eq!(auto_alias.px(), auto_expected.px(), epsilon = 1e-9);
2302            assert_relative_eq!(auto_alias.py(), auto_expected.py(), epsilon = 1e-9);
2303            assert_relative_eq!(auto_alias.pz(), auto_expected.pz(), epsilon = 1e-9);
2304            assert_relative_eq!(auto_alias.e(), auto_expected.e(), epsilon = 1e-9);
2305
2306            assert_relative_eq!(explicit_alias.px(), explicit_expected.px(), epsilon = 1e-9);
2307            assert_relative_eq!(explicit_alias.py(), explicit_expected.py(), epsilon = 1e-9);
2308            assert_relative_eq!(explicit_alias.pz(), explicit_expected.pz(), epsilon = 1e-9);
2309            assert_relative_eq!(explicit_alias.e(), explicit_expected.e(), epsilon = 1e-9);
2310        }
2311    }
2312
2313    #[test]
2314    fn test_from_parquet_f64_matches_f32() {
2315        let f32_ds = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
2316        let f64_ds = open_test_dataset("data_f64.parquet", DatasetReadOptions::new());
2317        assert_datasets_close(&f64_ds, &f32_ds, TEST_P4_NAMES, TEST_AUX_NAMES);
2318    }
2319
2320    #[test]
2321    fn test_from_root_detects_columns_and_matches_parquet() {
2322        let parquet = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
2323        let root_auto = open_test_dataset("data_f32.root", DatasetReadOptions::new());
2324        let mut detected_p4: Vec<&str> = root_auto.p4_names().iter().map(String::as_str).collect();
2325        detected_p4.sort_unstable();
2326        let mut expected_p4 = TEST_P4_NAMES.to_vec();
2327        expected_p4.sort_unstable();
2328        assert_eq!(detected_p4, expected_p4);
2329        let mut detected_aux: Vec<&str> =
2330            root_auto.aux_names().iter().map(String::as_str).collect();
2331        detected_aux.sort_unstable();
2332        let mut expected_aux = TEST_AUX_NAMES.to_vec();
2333        expected_aux.sort_unstable();
2334        assert_eq!(detected_aux, expected_aux);
2335        let root_named_options = DatasetReadOptions::new()
2336            .p4_names(TEST_P4_NAMES)
2337            .aux_names(TEST_AUX_NAMES);
2338        let root_named = open_test_dataset("data_f32.root", root_named_options);
2339        assert_datasets_close(&root_auto, &root_named, TEST_P4_NAMES, TEST_AUX_NAMES);
2340        assert_datasets_close(&root_auto, &parquet, TEST_P4_NAMES, TEST_AUX_NAMES);
2341    }
2342
2343    #[cfg(feature = "mpi")]
2344    #[mpi_test(np = [2])]
2345    fn test_from_root_metadata_matches_non_mpi_under_mpi() {
2346        let reference_auto = open_test_dataset("data_f32.root", DatasetReadOptions::new());
2347        let explicit_options = DatasetReadOptions::new()
2348            .p4_names(TEST_P4_NAMES)
2349            .aux_names(TEST_AUX_NAMES);
2350        let reference_explicit = open_test_dataset("data_f32.root", explicit_options.clone());
2351
2352        use_mpi(true);
2353        let local_auto = open_test_dataset("data_f32.root", DatasetReadOptions::new());
2354        let local_explicit = open_test_dataset("data_f32.root", explicit_options);
2355
2356        assert_eq!(local_auto.p4_names(), reference_auto.p4_names());
2357        assert_eq!(local_auto.aux_names(), reference_auto.aux_names());
2358        assert_eq!(local_explicit.p4_names(), reference_explicit.p4_names());
2359        assert_eq!(local_explicit.aux_names(), reference_explicit.aux_names());
2360        assert_eq!(local_auto.p4_names(), local_explicit.p4_names());
2361        assert_eq!(local_auto.aux_names(), local_explicit.aux_names());
2362
2363        for name in local_auto.p4_names() {
2364            let local_auto_selection = local_auto
2365                .metadata()
2366                .p4_selection(name)
2367                .expect("local auto canonical p4 selection should exist");
2368            let reference_auto_selection = reference_auto
2369                .metadata()
2370                .p4_selection(name)
2371                .expect("reference auto canonical p4 selection should exist");
2372            let local_explicit_selection = local_explicit
2373                .metadata()
2374                .p4_selection(name)
2375                .expect("local explicit canonical p4 selection should exist");
2376            assert_eq!(
2377                local_auto_selection.names(),
2378                reference_auto_selection.names()
2379            );
2380            assert_eq!(
2381                local_auto_selection.indices(),
2382                reference_auto_selection.indices()
2383            );
2384            assert_eq!(
2385                local_explicit_selection.names(),
2386                reference_auto_selection.names()
2387            );
2388            assert_eq!(
2389                local_explicit_selection.indices(),
2390                reference_auto_selection.indices()
2391            );
2392        }
2393
2394        finalize_mpi();
2395    }
2396
2397    #[test]
2398    fn test_from_root_f64_matches_parquet() {
2399        let parquet = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
2400        let root_f64 = open_test_dataset("data_f64.root", DatasetReadOptions::new());
2401        assert_datasets_close(&root_f64, &parquet, TEST_P4_NAMES, TEST_AUX_NAMES);
2402    }
2403
2404    #[cfg(feature = "mpi")]
2405    #[mpi_test(np = [2])]
2406    fn test_from_root_alias_resolution_matches_non_mpi_under_mpi() {
2407        let alias_options = DatasetReadOptions::new().alias("resonance", ["kshort1", "kshort2"]);
2408        let explicit_alias_options = DatasetReadOptions::new()
2409            .p4_names(TEST_P4_NAMES)
2410            .aux_names(TEST_AUX_NAMES)
2411            .alias("resonance", ["kshort1", "kshort2"]);
2412        let reference_auto = open_test_dataset("data_f32.root", alias_options.clone());
2413        let reference_explicit = open_test_dataset("data_f32.root", explicit_alias_options.clone());
2414
2415        use_mpi(true);
2416        let world = get_world().expect("MPI world should be initialized");
2417        let local_auto = open_test_dataset("data_f32.root", alias_options);
2418        let local_explicit = open_test_dataset("data_f32.root", explicit_alias_options);
2419
2420        let local_auto_alias = local_auto
2421            .metadata()
2422            .p4_selection("resonance")
2423            .expect("local auto alias should exist");
2424        let local_explicit_alias = local_explicit
2425            .metadata()
2426            .p4_selection("resonance")
2427            .expect("local explicit alias should exist");
2428        let reference_alias = reference_auto
2429            .metadata()
2430            .p4_selection("resonance")
2431            .expect("reference alias should exist");
2432        let reference_explicit_alias = reference_explicit
2433            .metadata()
2434            .p4_selection("resonance")
2435            .expect("reference explicit alias should exist");
2436        assert_eq!(local_auto_alias.names(), reference_alias.names());
2437        assert_eq!(local_auto_alias.indices(), reference_alias.indices());
2438        assert_eq!(
2439            local_explicit_alias.names(),
2440            reference_explicit_alias.names()
2441        );
2442        assert_eq!(
2443            local_explicit_alias.indices(),
2444            reference_explicit_alias.indices()
2445        );
2446        assert_eq!(local_auto_alias.names(), local_explicit_alias.names());
2447        assert_eq!(local_auto_alias.indices(), local_explicit_alias.indices());
2448
2449        let partition = world.partition(reference_auto.n_events());
2450        let local_range = partition.range_for_rank(world.rank() as usize);
2451        assert_eq!(local_auto.n_events_local(), local_range.len());
2452        assert_eq!(local_explicit.n_events_local(), local_range.len());
2453
2454        for (local_index, global_index) in local_range.enumerate() {
2455            let local_auto_event = local_auto.event_view(local_index);
2456            let local_explicit_event = local_explicit.event_view(local_index);
2457            let reference_event = reference_auto.event_view(global_index);
2458            let reference_explicit_event = reference_explicit.event_view(global_index);
2459
2460            let local_auto_value = local_auto_event
2461                .p4("resonance")
2462                .expect("local auto alias should resolve");
2463            let local_explicit_value = local_explicit_event
2464                .p4("resonance")
2465                .expect("local explicit alias should resolve");
2466            let reference_value = reference_event
2467                .p4("resonance")
2468                .expect("reference alias should resolve");
2469            let reference_explicit_value = reference_explicit_event
2470                .p4("resonance")
2471                .expect("reference explicit alias should resolve");
2472
2473            assert_relative_eq!(local_auto_value.px(), reference_value.px(), epsilon = 1e-9);
2474            assert_relative_eq!(local_auto_value.py(), reference_value.py(), epsilon = 1e-9);
2475            assert_relative_eq!(local_auto_value.pz(), reference_value.pz(), epsilon = 1e-9);
2476            assert_relative_eq!(local_auto_value.e(), reference_value.e(), epsilon = 1e-9);
2477
2478            assert_relative_eq!(
2479                local_explicit_value.px(),
2480                reference_explicit_value.px(),
2481                epsilon = 1e-9
2482            );
2483            assert_relative_eq!(
2484                local_explicit_value.py(),
2485                reference_explicit_value.py(),
2486                epsilon = 1e-9
2487            );
2488            assert_relative_eq!(
2489                local_explicit_value.pz(),
2490                reference_explicit_value.pz(),
2491                epsilon = 1e-9
2492            );
2493            assert_relative_eq!(
2494                local_explicit_value.e(),
2495                reference_explicit_value.e(),
2496                epsilon = 1e-9
2497            );
2498        }
2499
2500        finalize_mpi();
2501    }
2502
2503    #[test]
2504    fn test_event_creation() {
2505        let event = test_event();
2506        assert_eq!(event.p4s.len(), 4);
2507        assert_eq!(event.aux.len(), 2);
2508        assert_relative_eq!(event.weight, 0.48)
2509    }
2510
2511    #[test]
2512    fn test_event_p4_sum() {
2513        let event = test_event();
2514        let sum = event.get_p4_sum([2, 3]);
2515        assert_relative_eq!(sum.px(), event.p4s[2].px() + event.p4s[3].px());
2516        assert_relative_eq!(sum.py(), event.p4s[2].py() + event.p4s[3].py());
2517        assert_relative_eq!(sum.pz(), event.p4s[2].pz() + event.p4s[3].pz());
2518        assert_relative_eq!(sum.e(), event.p4s[2].e() + event.p4s[3].e());
2519    }
2520
2521    #[test]
2522    fn test_event_boost() {
2523        let event = test_event();
2524        let event_boosted = event.boost_to_rest_frame_of([1, 2, 3]);
2525        let p4_sum = event_boosted.get_p4_sum([1, 2, 3]);
2526        assert_relative_eq!(p4_sum.px(), 0.0);
2527        assert_relative_eq!(p4_sum.py(), 0.0);
2528        assert_relative_eq!(p4_sum.pz(), 0.0, epsilon = f64::EPSILON.sqrt());
2529    }
2530
2531    #[test]
2532    fn test_named_event_view_evaluate() {
2533        let dataset = test_dataset();
2534        let event = dataset.event_view(0);
2535        let mut mass = Mass::new(["proton"]);
2536        mass.bind(dataset.metadata()).unwrap();
2537        assert_relative_eq!(event.evaluate(&mass), 1.007);
2538    }
2539
2540    #[test]
2541    fn test_dataset_size_check() {
2542        let dataset = Dataset::new(Vec::new());
2543        assert_eq!(dataset.n_events(), 0);
2544        let dataset = Dataset::new(vec![Arc::new(test_event())]);
2545        assert_eq!(dataset.n_events(), 1);
2546    }
2547
2548    #[test]
2549    fn test_dataset_sum() {
2550        let dataset = test_dataset();
2551        let metadata = dataset.metadata_arc();
2552        let dataset2 = Dataset::new_with_metadata(
2553            vec![Arc::new(EventData {
2554                p4s: test_event().p4s,
2555                aux: test_event().aux,
2556                weight: 0.52,
2557            })],
2558            metadata.clone(),
2559        );
2560        let dataset_sum = &dataset + &dataset2;
2561        assert_eq!(
2562            dataset_sum.event(0).expect("event should exist").weight,
2563            dataset.event(0).expect("event should exist").weight
2564        );
2565        assert_eq!(
2566            dataset_sum.event(1).expect("event should exist").weight,
2567            dataset2.event(0).expect("event should exist").weight
2568        );
2569    }
2570
2571    #[test]
2572    fn test_dataset_weights() {
2573        let dataset = Dataset::new(vec![
2574            Arc::new(test_event()),
2575            Arc::new(EventData {
2576                p4s: test_event().p4s,
2577                aux: test_event().aux,
2578                weight: 0.52,
2579            }),
2580        ]);
2581        let weights = dataset.weights();
2582        assert_eq!(weights.len(), 2);
2583        assert_relative_eq!(weights[0], 0.48);
2584        assert_relative_eq!(weights[1], 0.52);
2585        assert_relative_eq!(dataset.n_events_weighted(), 1.0);
2586    }
2587
2588    #[test]
2589    #[should_panic(
2590        expected = "Dataset requires rectangular p4/aux columns for canonical columnar storage"
2591    )]
2592    fn test_dataset_rejects_ragged_rows_at_construction() {
2593        let _ = Dataset::new(vec![
2594            Arc::new(EventData {
2595                p4s: vec![Vec4::new(0.0, 0.0, 1.0, 1.0)],
2596                aux: vec![0.1],
2597                weight: 1.0,
2598            }),
2599            Arc::new(EventData {
2600                p4s: vec![],
2601                aux: vec![0.2, 0.3],
2602                weight: 2.0,
2603            }),
2604        ]);
2605    }
2606
2607    #[test]
2608    fn test_dataset_filtering() {
2609        let metadata = Arc::new(
2610            DatasetMetadata::new(vec!["beam"], Vec::<String>::new())
2611                .expect("metadata should be valid"),
2612        );
2613        let events = vec![
2614            Arc::new(EventData {
2615                p4s: vec![Vec3::new(0.0, 0.0, 5.0).with_mass(0.0)],
2616                aux: vec![],
2617                weight: 1.0,
2618            }),
2619            Arc::new(EventData {
2620                p4s: vec![Vec3::new(0.0, 0.0, 5.0).with_mass(0.5)],
2621                aux: vec![],
2622                weight: 1.0,
2623            }),
2624            Arc::new(EventData {
2625                p4s: vec![Vec3::new(0.0, 0.0, 5.0).with_mass(1.1)],
2626                // HACK: using 1.0 messes with this test because the eventual computation gives a mass
2627                // slightly less than 1.0
2628                aux: vec![],
2629                weight: 1.0,
2630            }),
2631        ];
2632        let dataset = Dataset::new_with_metadata(events, metadata);
2633
2634        let metadata = dataset.metadata_arc();
2635        let mut mass = Mass::new(["beam"]);
2636        mass.bind(metadata.as_ref()).unwrap();
2637        let expression = mass.gt(0.0).and(&mass.lt(1.0));
2638
2639        let filtered = dataset.filter(&expression).unwrap();
2640        assert_eq!(filtered.n_events(), 1);
2641        assert_relative_eq!(mass.value(&filtered.event_view(0)), 0.5);
2642    }
2643
2644    #[test]
2645    fn test_dataset_boost() {
2646        let dataset = test_dataset();
2647        let dataset_boosted = dataset.boost_to_rest_frame_of(&["proton", "kshort1", "kshort2"]);
2648        let p4_sum = dataset_boosted
2649            .event(0)
2650            .expect("event should exist")
2651            .get_p4_sum(["proton", "kshort1", "kshort2"]);
2652        assert_relative_eq!(p4_sum.px(), 0.0);
2653        assert_relative_eq!(p4_sum.py(), 0.0);
2654        assert_relative_eq!(p4_sum.pz(), 0.0, epsilon = f64::EPSILON.sqrt());
2655    }
2656
2657    #[test]
2658    fn test_named_event_view() {
2659        let dataset = test_dataset();
2660        let view = dataset.named_event(0).expect("event should exist");
2661        let dataset_event = dataset.event(0).expect("event should exist");
2662        assert_relative_eq!(view.weight(), dataset_event.weight);
2663        let beam = view.p4("beam").expect("beam p4");
2664        assert_relative_eq!(beam.px(), dataset_event.p4s[0].px());
2665        assert_relative_eq!(beam.e(), dataset_event.p4s[0].e());
2666
2667        let summed = view.get_p4_sum(["kshort1", "kshort2"]);
2668        assert_relative_eq!(
2669            summed.e(),
2670            dataset_event.p4s[2].e() + dataset_event.p4s[3].e()
2671        );
2672
2673        let aux_angle = view.aux().get("pol_angle").copied().expect("pol angle");
2674        assert_relative_eq!(aux_angle, dataset_event.aux[1]);
2675
2676        let metadata = dataset.metadata_arc();
2677        let boosted = view.boost_to_rest_frame_of(["proton", "kshort1", "kshort2"]);
2678        let boosted_event = Event::new(Arc::new(boosted), metadata);
2679        let boosted_sum = boosted_event.get_p4_sum(["proton", "kshort1", "kshort2"]);
2680        assert_relative_eq!(boosted_sum.px(), 0.0);
2681    }
2682
2683    #[test]
2684    fn test_dataset_evaluate() {
2685        let dataset = test_dataset();
2686        let mass = Mass::new(["proton"]);
2687        assert_relative_eq!(dataset.evaluate(&mass).unwrap()[0], 1.007);
2688    }
2689
2690    #[test]
2691    fn test_dataset_metadata_rejects_duplicate_names() {
2692        let err = DatasetMetadata::new(vec!["beam", "beam"], Vec::<String>::new());
2693        assert!(matches!(
2694            err,
2695            Err(LadduError::DuplicateName { category, .. }) if category == "p4"
2696        ));
2697        let err = DatasetMetadata::new(
2698            vec!["beam"],
2699            vec!["pol_angle".to_string(), "pol_angle".to_string()],
2700        );
2701        assert!(matches!(
2702            err,
2703            Err(LadduError::DuplicateName { category, .. }) if category == "aux"
2704        ));
2705    }
2706
2707    #[test]
2708    fn test_dataset_lookup_by_name() {
2709        let dataset = test_dataset();
2710        let proton = dataset.p4_by_name(0, "proton").expect("proton p4");
2711        let proton_idx = dataset.metadata().p4_index("proton").unwrap();
2712        assert_relative_eq!(
2713            proton.e(),
2714            dataset.event(0).expect("event should exist").p4s[proton_idx].e()
2715        );
2716        assert!(dataset.p4_by_name(0, "unknown").is_none());
2717        let angle = dataset.aux_by_name(0, "pol_angle").expect("pol_angle");
2718        assert_relative_eq!(angle, dataset.event(0).expect("event should exist").aux[1]);
2719        assert!(dataset.aux_by_name(0, "missing").is_none());
2720    }
2721
2722    #[test]
2723    fn test_binned_dataset() {
2724        let dataset = Dataset::new(vec![
2725            Arc::new(EventData {
2726                p4s: vec![Vec3::new(0.0, 0.0, 1.0).with_mass(1.0)],
2727                aux: vec![],
2728                weight: 1.0,
2729            }),
2730            Arc::new(EventData {
2731                p4s: vec![Vec3::new(0.0, 0.0, 2.0).with_mass(2.0)],
2732                aux: vec![],
2733                weight: 2.0,
2734            }),
2735        ]);
2736
2737        #[derive(Clone, Serialize, Deserialize, Debug)]
2738        struct BeamEnergy;
2739        impl Display for BeamEnergy {
2740            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2741                write!(f, "BeamEnergy")
2742            }
2743        }
2744        #[typetag::serde]
2745        impl Variable for BeamEnergy {
2746            fn value(&self, event: &NamedEventView<'_>) -> f64 {
2747                event.p4_at(0).e()
2748            }
2749        }
2750        assert_eq!(BeamEnergy.to_string(), "BeamEnergy");
2751
2752        // Test binning by first particle energy
2753        let binned = dataset.bin_by(BeamEnergy, 2, (0.0, 3.0)).unwrap();
2754
2755        assert_eq!(binned.n_bins(), 2);
2756        assert_eq!(binned.edges().len(), 3);
2757        assert_relative_eq!(binned.edges()[0], 0.0);
2758        assert_relative_eq!(binned.edges()[2], 3.0);
2759        assert_eq!(binned[0].n_events(), 1);
2760        assert_relative_eq!(binned[0].n_events_weighted(), 1.0);
2761        assert_eq!(binned[1].n_events(), 1);
2762        assert_relative_eq!(binned[1].n_events_weighted(), 2.0);
2763    }
2764
2765    #[test]
2766    fn test_dataset_bootstrap() {
2767        let metadata = test_dataset().metadata_arc();
2768        let dataset = Dataset::new_with_metadata(
2769            vec![
2770                Arc::new(test_event()),
2771                Arc::new(EventData {
2772                    p4s: test_event().p4s.clone(),
2773                    aux: test_event().aux.clone(),
2774                    weight: 1.0,
2775                }),
2776            ],
2777            metadata,
2778        );
2779        assert_relative_ne!(
2780            dataset.event(0).expect("event should exist").weight,
2781            dataset.event(1).expect("event should exist").weight
2782        );
2783
2784        let bootstrapped = dataset.bootstrap(43);
2785        assert_eq!(bootstrapped.n_events(), dataset.n_events());
2786        assert_relative_eq!(
2787            bootstrapped.event(0).expect("event should exist").weight,
2788            bootstrapped.event(1).expect("event should exist").weight
2789        );
2790
2791        // Test empty dataset bootstrap
2792        let empty_dataset = Dataset::new(Vec::new());
2793        let empty_bootstrap = empty_dataset.bootstrap(43);
2794        assert_eq!(empty_bootstrap.n_events(), 0);
2795    }
2796
2797    fn assert_weight_cache_matches_local_events(dataset: &Dataset) {
2798        #[cfg(feature = "rayon")]
2799        let expected = dataset
2800            .events_local()
2801            .par_iter()
2802            .map(|event| event.weight)
2803            .parallel_sum_with_accumulator::<Klein<f64>>();
2804        #[cfg(not(feature = "rayon"))]
2805        let expected = dataset
2806            .events_local()
2807            .iter()
2808            .map(|event| event.weight)
2809            .sum_with_accumulator::<Klein<f64>>();
2810        assert_relative_eq!(dataset.cached_local_weighted_sum, expected);
2811        assert_relative_eq!(dataset.n_events_weighted_local(), expected);
2812    }
2813
2814    #[test]
2815    fn test_weight_cache_recomputed_for_dataset_transforms() {
2816        let metadata = Arc::new(
2817            DatasetMetadata::new(vec!["beam"], Vec::<String>::new())
2818                .expect("metadata should be valid"),
2819        );
2820        let dataset = Dataset::new_with_metadata(
2821            vec![
2822                Arc::new(EventData {
2823                    p4s: vec![Vec3::new(0.0, 0.0, 1.0).with_mass(0.0)],
2824                    aux: vec![],
2825                    weight: 1.0,
2826                }),
2827                Arc::new(EventData {
2828                    p4s: vec![Vec3::new(0.0, 0.0, 2.0).with_mass(0.0)],
2829                    aux: vec![],
2830                    weight: 2.0,
2831                }),
2832                Arc::new(EventData {
2833                    p4s: vec![Vec3::new(0.0, 0.0, 3.0).with_mass(0.0)],
2834                    aux: vec![],
2835                    weight: 3.0,
2836                }),
2837            ],
2838            metadata,
2839        );
2840        assert_weight_cache_matches_local_events(&dataset);
2841
2842        let filtered = dataset.filter(&Mass::new(["beam"]).gt(0.0)).unwrap();
2843        assert_weight_cache_matches_local_events(&filtered);
2844
2845        let bootstrapped = dataset.bootstrap(7);
2846        assert_weight_cache_matches_local_events(&bootstrapped);
2847
2848        let boosted = dataset.boost_to_rest_frame_of(&["beam"]);
2849        assert_weight_cache_matches_local_events(&boosted);
2850
2851        let combined = &dataset + &dataset;
2852        assert_weight_cache_matches_local_events(&combined);
2853    }
2854
2855    #[test]
2856    fn test_dataset_iteration_returns_events() {
2857        let dataset = test_dataset();
2858        let mut weights = Vec::new();
2859        for event in dataset.iter() {
2860            weights.push(event.weight());
2861        }
2862        assert_eq!(weights.len(), dataset.n_events());
2863        assert_relative_eq!(
2864            weights[0],
2865            dataset.event(0).expect("event should exist").weight
2866        );
2867    }
2868
2869    #[test]
2870    fn test_dataset_into_iter_returns_events() {
2871        let dataset = test_dataset();
2872        let weights: Vec<f64> = dataset.into_iter().map(|event| event.weight()).collect();
2873        assert_eq!(weights.len(), 1);
2874        assert_relative_eq!(weights[0], test_event().weight);
2875    }
2876
2877    #[test]
2878    fn test_dataset_arc_into_iter_returns_events() {
2879        let dataset = Arc::new(test_dataset());
2880        let weights: Vec<f64> = dataset.shared_iter().map(|event| event.weight()).collect();
2881        assert_eq!(weights.len(), 1);
2882        assert_relative_eq!(weights[0], test_event().weight);
2883    }
2884
2885    #[test]
2886    fn test_dataset_get_event_local_reuses_underlying_data() {
2887        let dataset = test_dataset();
2888        let first = dataset.get_event(0).expect("event should exist");
2889        let second = dataset.get_event(0).expect("event should exist");
2890        assert!(Arc::ptr_eq(&first.data_arc(), &second.data_arc()));
2891    }
2892
2893    #[test]
2894    fn test_dataset_event_out_of_bounds_is_error() {
2895        let dataset = test_dataset();
2896        assert!(dataset.event(99).is_err());
2897        assert!(dataset.get_event(99).is_none());
2898    }
2899
2900    #[cfg(feature = "mpi")]
2901    fn event_iteration_signature<I>(iter: I) -> (usize, f64, f64, f64)
2902    where
2903        I: IntoIterator<Item = Event>,
2904    {
2905        let mut count = 0usize;
2906        let mut weight_signature = 0.0;
2907        let mut beam_signature = 0.0;
2908        let mut aux_signature = 0.0;
2909
2910        for (index, event) in iter.into_iter().enumerate() {
2911            let position = (index + 1) as f64;
2912            count += 1;
2913            weight_signature += position * event.weight();
2914            beam_signature += position * event.p4("beam").expect("beam should exist").e();
2915            aux_signature += position
2916                * event
2917                    .aux()
2918                    .get("pol_angle")
2919                    .copied()
2920                    .expect("pol_angle should exist");
2921        }
2922
2923        (count, weight_signature, beam_signature, aux_signature)
2924    }
2925
2926    #[cfg(feature = "mpi")]
2927    fn read_resident_rss_kb() -> Option<u64> {
2928        #[cfg(target_os = "linux")]
2929        {
2930            let status = fs::read_to_string("/proc/self/status").ok()?;
2931            let vm_rss = status
2932                .lines()
2933                .find(|line| line.starts_with("VmRSS:"))?
2934                .split_whitespace()
2935                .nth(1)?;
2936            vm_rss.parse::<u64>().ok()
2937        }
2938
2939        #[cfg(not(target_os = "linux"))]
2940        {
2941            None
2942        }
2943    }
2944
2945    #[test]
2946    fn test_dataset_event_stress_local_repeated_access() {
2947        let metadata = test_dataset().metadata_arc();
2948        let base = test_event();
2949        let mut events = Vec::new();
2950        for idx in 0..8 {
2951            events.push(Arc::new(EventData {
2952                p4s: base.p4s.clone(),
2953                aux: base.aux.clone(),
2954                weight: 1.0 + idx as f64,
2955            }));
2956        }
2957        let dataset = Dataset::new_with_metadata(events, metadata);
2958        let baseline: Vec<f64> = (0..dataset.n_events())
2959            .map(|index| dataset.event(index).expect("event should exist").weight())
2960            .collect();
2961
2962        for _ in 0..250 {
2963            for (index, expected_weight) in baseline.iter().enumerate() {
2964                let event = dataset.event(index).expect("event should exist");
2965                assert_relative_eq!(event.weight(), *expected_weight);
2966            }
2967        }
2968    }
2969
2970    #[cfg(feature = "mpi")]
2971    #[mpi_test(np = [2])]
2972    fn test_dataset_event_mpi_repeated_access_is_stable() {
2973        use_mpi(true);
2974        assert!(get_world().is_some(), "MPI world should be initialized");
2975
2976        let dataset = test_dataset();
2977        for _ in 0..32 {
2978            let first = dataset.event(0).expect("event should exist");
2979            let second = dataset.event(0).expect("event should exist");
2980            assert_relative_eq!(first.weight(), second.weight());
2981        }
2982        finalize_mpi();
2983    }
2984
2985    #[cfg(feature = "mpi")]
2986    #[mpi_test(np = [2])]
2987    fn test_dataset_event_stress_mpi_repeated_access() {
2988        use_mpi(true);
2989        assert!(get_world().is_some(), "MPI world should be initialized");
2990
2991        let metadata = test_dataset().metadata_arc();
2992        let base = test_event();
2993        let mut events = Vec::new();
2994        for idx in 0..8 {
2995            events.push(Arc::new(EventData {
2996                p4s: base.p4s.clone(),
2997                aux: base.aux.clone(),
2998                weight: 1.0 + idx as f64,
2999            }));
3000        }
3001        let dataset = Dataset::new_with_metadata(events, metadata);
3002
3003        let baseline: Vec<f64> = (0..dataset.n_events())
3004            .map(|index| dataset.event(index).expect("event should exist").weight())
3005            .collect();
3006
3007        for _ in 0..120 {
3008            for (index, expected_weight) in baseline.iter().enumerate() {
3009                let event = dataset.event(index).expect("event should exist");
3010                assert_relative_eq!(event.weight(), *expected_weight);
3011            }
3012        }
3013        finalize_mpi();
3014    }
3015
3016    #[cfg(feature = "mpi")]
3017    #[mpi_test(np = [2])]
3018    fn test_dataset_iter_stress_mpi_repeated_passes() {
3019        use_mpi(true);
3020        assert!(get_world().is_some(), "MPI world should be initialized");
3021
3022        let metadata = test_dataset().metadata_arc();
3023        let base = test_event();
3024        let mut events = Vec::new();
3025        for idx in 0..8 {
3026            events.push(Arc::new(EventData {
3027                p4s: base.p4s.clone(),
3028                aux: base.aux.clone(),
3029                weight: 1.0 + idx as f64,
3030            }));
3031        }
3032        let dataset = Dataset::new_with_metadata(events, metadata);
3033        let baseline: Vec<f64> = dataset.iter().map(|event| event.weight()).collect();
3034
3035        for _ in 0..80 {
3036            let current: Vec<f64> = dataset.iter().map(|event| event.weight()).collect();
3037            assert_eq!(current.len(), baseline.len());
3038            for (current_weight, expected_weight) in current.iter().zip(baseline.iter()) {
3039                assert_relative_eq!(*current_weight, *expected_weight);
3040            }
3041        }
3042        finalize_mpi();
3043    }
3044
3045    #[cfg(feature = "mpi")]
3046    #[mpi_test(np = [2])]
3047    fn test_dataset_arc_into_iter_stress_mpi_repeated_passes() {
3048        use_mpi(true);
3049        assert!(get_world().is_some(), "MPI world should be initialized");
3050
3051        let metadata = test_dataset().metadata_arc();
3052        let base = test_event();
3053        let mut events = Vec::new();
3054        for idx in 0..8 {
3055            events.push(Arc::new(EventData {
3056                p4s: base.p4s.clone(),
3057                aux: base.aux.clone(),
3058                weight: 1.0 + idx as f64,
3059            }));
3060        }
3061        let dataset = Arc::new(Dataset::new_with_metadata(events, metadata));
3062        let baseline: Vec<f64> = dataset.shared_iter().map(|event| event.weight()).collect();
3063
3064        for _ in 0..80 {
3065            let current: Vec<f64> = dataset.shared_iter().map(|event| event.weight()).collect();
3066            assert_eq!(current.len(), baseline.len());
3067            for (current_weight, expected_weight) in current.iter().zip(baseline.iter()) {
3068                assert_relative_eq!(*current_weight, *expected_weight);
3069            }
3070        }
3071        finalize_mpi();
3072    }
3073
    // Long-running MPI iteration soak test.
    //
    // Walks the dataset from `mpi_chunk_test_dataset(8_192)` repeatedly through both `iter()`
    // and `shared_iter()`. Every pass must reproduce the order-sensitive signature of the first
    // pass. Wherever `read_resident_rss_kb` returns samples, resident memory must also stay
    // within fixed bounds once the warm-up passes are over.
    #[cfg(feature = "mpi")]
    #[mpi_test(np = [2])]
    fn test_dataset_iteration_long_running_mpi_repeated_passes() {
        use_mpi(true);
        assert!(get_world().is_some(), "MPI world should be initialized");

        let dataset = Arc::new(mpi_chunk_test_dataset(8_192));
        let baseline_iter = event_iteration_signature(dataset.iter());
        let baseline_shared = event_iteration_signature(dataset.shared_iter());
        // Both iteration entry points must visit the same events in the same order.
        assert_eq!(baseline_iter, baseline_shared);
        let mut post_warmup_rss_kb = Vec::new();

        for pass_index in 0..48 {
            let current_iter = event_iteration_signature(dataset.iter());
            let current_shared = event_iteration_signature(dataset.shared_iter());
            assert_eq!(current_iter, baseline_iter);
            assert_eq!(current_shared, baseline_shared);
            // Passes 0..=6 are treated as warm-up and not sampled, so one-time allocations
            // (caches, buffers) are not counted as growth.
            if pass_index >= 7 {
                if let Some(rss_kb) = read_resident_rss_kb() {
                    post_warmup_rss_kb.push(rss_kb);
                }
            }
        }

        // With no samples (read_resident_rss_kb returned None every time) the memory checks are
        // skipped.
        if let Some((&first_rss_kb, rest_rss_kb)) = post_warmup_rss_kb.split_first() {
            let last_rss_kb = *rest_rss_kb.last().unwrap_or(&first_rss_kb);
            let min_rss_kb = post_warmup_rss_kb
                .iter()
                .copied()
                .min()
                .expect("post-warmup RSS sample should exist");
            let max_rss_kb = post_warmup_rss_kb
                .iter()
                .copied()
                .max()
                .expect("post-warmup RSS sample should exist");
            // Two checks share a 32 MiB budget:
            // - net drift from the first to the last sample;
            // - the full min..max range, which also catches a transient spike that later recedes.
            // The drift check runs first so a steady leak reports the more specific message.
            const MAX_POST_WARMUP_RSS_GROWTH_KB: u64 = 32 * 1024;
            const MAX_POST_WARMUP_RSS_SPREAD_KB: u64 = 32 * 1024;
            assert!(
                last_rss_kb.saturating_sub(first_rss_kb) <= MAX_POST_WARMUP_RSS_GROWTH_KB,
                "post-warmup RSS grew by {} KiB (first={} KiB, last={} KiB)",
                last_rss_kb.saturating_sub(first_rss_kb),
                first_rss_kb,
                last_rss_kb
            );
            assert!(
                max_rss_kb.saturating_sub(min_rss_kb) <= MAX_POST_WARMUP_RSS_SPREAD_KB,
                "post-warmup RSS spread was {} KiB (min={} KiB, max={} KiB)",
                max_rss_kb.saturating_sub(min_rss_kb),
                min_rss_kb,
                max_rss_kb
            );
        }

        finalize_mpi();
    }
3130
3131    #[cfg(feature = "mpi")]
3132    #[mpi_test(np = [2])]
3133    fn test_fetch_event_chunk_mpi_matches_single_event_fetches() {
3134        use_mpi(true);
3135        let world = get_world().expect("MPI world should be initialized");
3136
3137        let dataset = mpi_chunk_test_dataset(8);
3138        let chunk = fetch_event_chunk_mpi(&dataset, 1, 5, &world, dataset.n_events());
3139
3140        assert_eq!(chunk.len(), 5);
3141        for (offset, event) in chunk.iter().enumerate() {
3142            let baseline = dataset
3143                .event(1 + offset)
3144                .expect("chunk baseline event should exist");
3145            assert_events_close(event, &baseline, TEST_P4_NAMES, TEST_AUX_NAMES);
3146        }
3147
3148        assert!(
3149            fetch_event_chunk_mpi(&dataset, dataset.n_events(), 4, &world, dataset.n_events())
3150                .is_empty()
3151        );
3152        finalize_mpi();
3153    }
3154
3155    #[cfg(feature = "mpi")]
3156    #[mpi_test(np = [2])]
3157    fn test_fetch_event_chunk_mpi_truncates_at_dataset_end() {
3158        use_mpi(true);
3159        let world = get_world().expect("MPI world should be initialized");
3160
3161        let dataset = mpi_chunk_test_dataset(8);
3162        let chunk = fetch_event_chunk_mpi(&dataset, 6, 10, &world, dataset.n_events());
3163
3164        assert_eq!(chunk.len(), 2);
3165        for (offset, event) in chunk.iter().enumerate() {
3166            let baseline = dataset
3167                .event(6 + offset)
3168                .expect("truncated chunk baseline event should exist");
3169            assert_events_close(event, &baseline, TEST_P4_NAMES, TEST_AUX_NAMES);
3170        }
3171        finalize_mpi();
3172    }
3173
3174    #[cfg(feature = "mpi")]
3175    #[mpi_test(np = [2])]
3176    fn test_mpi_event_chunk_cursor_reuses_cached_chunk_for_dataset_and_events() {
3177        use_mpi(true);
3178        let world = get_world().expect("MPI world should be initialized");
3179
3180        let dataset = mpi_chunk_test_dataset(9);
3181        let total = dataset.n_events();
3182        let metadata = dataset.metadata_arc();
3183
3184        let mut dataset_cursor = MpiEventChunkCursor::new(3);
3185        for index in 0..total {
3186            let actual = dataset_cursor
3187                .event_for_dataset(&dataset, index, &world, total)
3188                .expect("dataset cursor event should exist");
3189            let expected = dataset.event(index).expect("baseline event should exist");
3190            assert_events_close(&actual, &expected, TEST_P4_NAMES, TEST_AUX_NAMES);
3191        }
3192        assert!(dataset_cursor
3193            .event_for_dataset(&dataset, total, &world, total)
3194            .is_none());
3195
3196        let mut events_cursor = MpiEventChunkCursor::new(4);
3197        for index in 0..total {
3198            let actual = events_cursor
3199                .event_for_events(dataset.events_local(), &metadata, index, &world, total)
3200                .expect("events cursor event should exist");
3201            let expected = dataset.event(index).expect("baseline event should exist");
3202            assert_events_close(&actual, &expected, TEST_P4_NAMES, TEST_AUX_NAMES);
3203        }
3204        finalize_mpi();
3205    }
3206
3207    #[cfg(feature = "mpi")]
3208    #[test]
3209    #[ignore = "developer probe for MPI iteration chunk-size tuning"]
3210    fn probe_mpi_iteration_chunk_size() {
3211        use std::time::Instant;
3212
3213        use_mpi(true);
3214        let Some(world) = get_world() else {
3215            finalize_mpi();
3216            return;
3217        };
3218
3219        let dataset = mpi_chunk_test_dataset(32_768);
3220        let total = dataset.n_events();
3221        let chunk_sizes = [64_usize, 128, 256, 512, 1024];
3222        if world.rank() == 0 {
3223            println!("probe=iteration");
3224        }
3225        for chunk_size in chunk_sizes {
3226            let started = Instant::now();
3227            let mut checksum = 0.0;
3228            for _ in 0..8 {
3229                let mut cursor = MpiEventChunkCursor::new(chunk_size);
3230                for index in 0..total {
3231                    let event = cursor
3232                        .event_for_dataset(&dataset, index, &world, total)
3233                        .expect("cursor event should exist");
3234                    checksum += event.weight() + event.p4("beam").expect("beam should exist").e();
3235                }
3236            }
3237            if world.rank() == 0 {
3238                println!(
3239                    "probe=iteration chunk_size={} elapsed_sec={:.6} checksum={:.6}",
3240                    chunk_size,
3241                    started.elapsed().as_secs_f64(),
3242                    checksum,
3243                );
3244            }
3245        }
3246        finalize_mpi();
3247    }
3248
3249    #[cfg(feature = "mpi")]
3250    #[test]
3251    #[ignore = "developer probe for MPI ROOT write chunk-size tuning"]
3252    fn probe_mpi_root_write_chunk_size() {
3253        use std::time::Instant;
3254
3255        use_mpi(true);
3256        let Some(world) = get_world() else {
3257            finalize_mpi();
3258            return;
3259        };
3260
3261        let dataset = Arc::new(mpi_chunk_test_dataset(32_768));
3262        let chunk_sizes = [64_usize, 128, 256, 512, 1024];
3263        if world.rank() == 0 {
3264            println!("probe=root_write");
3265        }
3266        for chunk_size in chunk_sizes {
3267            let dir = make_temp_dir();
3268            let path = dir.join(format!("mpi_chunk_probe_{chunk_size}.root"));
3269            let path_str = path.to_str().expect("probe path should be valid UTF-8");
3270            let started = Instant::now();
3271            for _ in 0..4 {
3272                io::write_root_with_chunk_size_for_test(
3273                    &dataset,
3274                    path_str,
3275                    &DatasetWriteOptions::default(),
3276                    chunk_size,
3277                )
3278                .expect("probe root write should succeed");
3279            }
3280
3281            if world.rank() == 0 {
3282                println!(
3283                    "probe=root_write chunk_size={} elapsed_sec={:.6}",
3284                    chunk_size,
3285                    started.elapsed().as_secs_f64(),
3286                );
3287                fs::remove_dir_all(&dir).expect("probe temp dir cleanup should succeed");
3288            }
3289        }
3290        finalize_mpi();
3291    }
3292
3293    #[test]
3294    fn test_event_display() {
3295        let event = test_event();
3296        let display_string = format!("{}", event);
3297        assert!(display_string.contains("Event:"));
3298        assert!(display_string.contains("p4s:"));
3299        assert!(display_string.contains("aux:"));
3300        assert!(display_string.contains("aux[0]: 0.38562805"));
3301        assert!(display_string.contains("aux[1]: 0.05708078"));
3302        assert!(display_string.contains("weight:"));
3303    }
3304
3305    #[test]
3306    fn test_name_based_access() {
3307        let metadata =
3308            Arc::new(DatasetMetadata::new(vec!["beam", "target"], vec!["pol_angle"]).unwrap());
3309        let event = Arc::new(EventData {
3310            p4s: vec![Vec4::new(0.0, 0.0, 1.0, 1.0), Vec4::new(0.1, 0.2, 0.3, 0.5)],
3311            aux: vec![0.42],
3312            weight: 1.0,
3313        });
3314        let dataset = Dataset::new_with_metadata(vec![event], metadata);
3315        let beam = dataset.p4_by_name(0, "beam").unwrap();
3316        assert_relative_eq!(beam.px(), 0.0);
3317        assert_relative_eq!(beam.py(), 0.0);
3318        assert_relative_eq!(beam.pz(), 1.0);
3319        assert_relative_eq!(beam.e(), 1.0);
3320        assert_relative_eq!(dataset.aux_by_name(0, "pol_angle").unwrap(), 0.42);
3321        assert!(dataset.p4_by_name(0, "missing").is_none());
3322        assert!(dataset.aux_by_name(0, "missing").is_none());
3323    }
3324
3325    #[test]
3326    fn test_parquet_roundtrip_to_tempfile() {
3327        let dataset = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
3328        let dir = make_temp_dir();
3329        let path = dir.join("roundtrip.parquet");
3330        let path_str = path.to_str().expect("path should be valid UTF-8");
3331
3332        write_parquet(&dataset, path_str, &DatasetWriteOptions::default())
3333            .expect("writing parquet should succeed");
3334        let reopened = read_parquet(path_str, &DatasetReadOptions::new())
3335            .expect("parquet roundtrip should reopen");
3336
3337        assert_datasets_close(&dataset, &reopened, TEST_P4_NAMES, TEST_AUX_NAMES);
3338        fs::remove_dir_all(&dir).expect("temp dir cleanup should succeed");
3339    }
3340
3341    #[test]
3342    fn test_parquet_roundtrip_incremental_small_batches() {
3343        let dataset = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
3344        let dir = make_temp_dir();
3345        let path = dir.join("roundtrip_small_batches.parquet");
3346        let path_str = path.to_str().expect("path should be valid UTF-8");
3347
3348        let write_options = DatasetWriteOptions::default().batch_size(1);
3349        write_parquet(&dataset, path_str, &write_options)
3350            .expect("writing parquet in small batches should succeed");
3351        let reopened = read_parquet(path_str, &DatasetReadOptions::new())
3352            .expect("parquet roundtrip should reopen");
3353
3354        assert_datasets_close(&dataset, &reopened, TEST_P4_NAMES, TEST_AUX_NAMES);
3355        fs::remove_dir_all(&dir).expect("temp dir cleanup should succeed");
3356    }
3357
3358    #[test]
3359    fn test_parquet_read_order_is_deterministic_across_repeated_reads() {
3360        let dataset = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
3361        let dir = make_temp_dir();
3362        let path = dir.join("deterministic_order.parquet");
3363        let path_str = path.to_str().expect("path should be valid UTF-8");
3364
3365        // Force many parquet batches so order stability is verified under incremental reads.
3366        let write_options = DatasetWriteOptions::default().batch_size(1);
3367        write_parquet(&dataset, path_str, &write_options)
3368            .expect("writing parquet in small batches should succeed");
3369
3370        let first = read_parquet(path_str, &DatasetReadOptions::new())
3371            .expect("first parquet read should succeed");
3372        let second = read_parquet(path_str, &DatasetReadOptions::new())
3373            .expect("second parquet read should succeed");
3374
3375        assert_eq!(first.n_events(), second.n_events());
3376        assert_eq!(first.n_events(), dataset.n_events());
3377        for event_index in 0..dataset.n_events() {
3378            let source = dataset
3379                .event(event_index)
3380                .expect("source event should exist");
3381            let first_event = first
3382                .event(event_index)
3383                .expect("first read event should exist");
3384            let second_event = second
3385                .event(event_index)
3386                .expect("second read event should exist");
3387            assert_events_close(&source, &first_event, TEST_P4_NAMES, TEST_AUX_NAMES);
3388            assert_events_close(&source, &second_event, TEST_P4_NAMES, TEST_AUX_NAMES);
3389        }
3390
3391        fs::remove_dir_all(&dir).expect("temp dir cleanup should succeed");
3392    }
3393
3394    #[test]
3395    fn test_parquet_storage_roundtrip_to_tempfile() {
3396        let source_path = test_data_path("data_f32.parquet");
3397        let source_path_str = source_path.to_str().expect("path should be valid UTF-8");
3398        let dataset_columnar = read_parquet_storage(source_path_str, &DatasetReadOptions::new())
3399            .expect("columnar load");
3400        let dir = make_temp_dir();
3401        let path = dir.join("roundtrip_columnar.parquet");
3402        let path_str = path.to_str().expect("path should be valid UTF-8");
3403
3404        write_parquet_storage(&dataset_columnar, path_str, &DatasetWriteOptions::default())
3405            .expect("writing columnar parquet should succeed");
3406        let reopened = read_parquet_storage(path_str, &DatasetReadOptions::new())
3407            .expect("columnar roundtrip reopen");
3408
3409        assert_dataset_columnar_close(&dataset_columnar, &reopened);
3410        fs::remove_dir_all(&dir).expect("temp dir cleanup should succeed");
3411    }
3412
3413    #[test]
3414    fn test_root_storage_matches_parquet_storage() {
3415        let root_path = test_data_path("data_f32.root");
3416        let root_path_str = root_path.to_str().expect("path should be valid UTF-8");
3417        let parquet_path = test_data_path("data_f32.parquet");
3418        let parquet_path_str = parquet_path.to_str().expect("path should be valid UTF-8");
3419
3420        let from_root = read_root_storage(root_path_str, &DatasetReadOptions::new())
3421            .expect("root columnar load should work");
3422        let from_parquet = read_parquet_storage(parquet_path_str, &DatasetReadOptions::new())
3423            .expect("parquet columnar load should work");
3424        assert_dataset_columnar_close(&from_root, &from_parquet);
3425    }
3426
3427    #[test]
3428    fn test_root_storage_repeated_reads_are_stable() {
3429        let root_path = test_data_path("data_f32.root");
3430        let root_path_str = root_path.to_str().expect("path should be valid UTF-8");
3431        let first = read_root_storage(root_path_str, &DatasetReadOptions::new())
3432            .expect("first root columnar load should work");
3433        let second = read_root_storage(root_path_str, &DatasetReadOptions::new())
3434            .expect("second root columnar load should work");
3435        assert_dataset_columnar_close(&first, &second);
3436    }
3437
3438    #[cfg(feature = "mpi")]
3439    #[mpi_test(np = [2])]
3440    fn test_root_storage_reads_rank_local_entry_ranges_under_mpi() {
3441        let root_path = test_data_path("data_f32.root");
3442        let root_path_str = root_path.to_str().expect("path should be valid UTF-8");
3443        let full = read_root_storage(root_path_str, &DatasetReadOptions::new())
3444            .expect("full root columnar load should work");
3445        let total = full.n_events();
3446
3447        use_mpi(true);
3448        let world = get_world().expect("MPI world should be initialized");
3449        let partition = world.partition(total);
3450        let local_range = partition.range_for_rank(world.rank() as usize);
3451
3452        let local = read_root_storage(root_path_str, &DatasetReadOptions::new())
3453            .expect("rank-local root columnar load should work");
3454        assert_eq!(local.n_events(), local_range.len());
3455
3456        for (local_index, global_index) in local_range.clone().enumerate() {
3457            for p4_index in 0..full.metadata().p4_names().len() {
3458                let expected = full.p4(global_index, p4_index);
3459                let actual = local.p4(local_index, p4_index);
3460                assert_relative_eq!(actual.px(), expected.px(), epsilon = 1e-12);
3461                assert_relative_eq!(actual.py(), expected.py(), epsilon = 1e-12);
3462                assert_relative_eq!(actual.pz(), expected.pz(), epsilon = 1e-12);
3463                assert_relative_eq!(actual.e(), expected.e(), epsilon = 1e-12);
3464            }
3465            for aux_index in 0..full.metadata().aux_names().len() {
3466                assert_relative_eq!(
3467                    local.aux(local_index, aux_index),
3468                    full.aux(global_index, aux_index),
3469                    epsilon = 1e-12
3470                );
3471            }
3472            assert_relative_eq!(
3473                local.weight(local_index),
3474                full.weight(global_index),
3475                epsilon = 1e-12
3476            );
3477        }
3478
3479        let local_dataset = local.to_dataset();
3480        assert_eq!(local_dataset.n_events_local(), local_range.len());
3481        assert_eq!(local_dataset.n_events(), total);
3482        finalize_mpi();
3483    }
3484
3485    #[test]
3486    fn test_root_roundtrip_to_tempfile() {
3487        let dataset = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
3488        let dir = make_temp_dir();
3489        let path = dir.join("roundtrip.root");
3490        let path_str = path.to_str().expect("path should be valid UTF-8");
3491
3492        write_root(&dataset, path_str, &DatasetWriteOptions::default())
3493            .expect("writing root should succeed");
3494        let reopened =
3495            read_root(path_str, &DatasetReadOptions::new()).expect("root roundtrip should reopen");
3496
3497        assert_datasets_close(&dataset, &reopened, TEST_P4_NAMES, TEST_AUX_NAMES);
3498        fs::remove_dir_all(&dir).expect("temp dir cleanup should succeed");
3499    }
3500
3501    #[cfg(feature = "mpi")]
3502    #[mpi_test(np = [2])]
3503    fn test_root_roundtrip_to_tempfile_mpi() {
3504        let reference = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
3505        use_mpi(true);
3506        let world = get_world().expect("MPI world should be initialized");
3507        let is_root = world.is_root();
3508
3509        let dataset = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
3510        let path = env::temp_dir().join("laddu_mpi_root_roundtrip.root");
3511        let path_str = path.to_str().expect("path should be valid UTF-8");
3512
3513        if world.is_root() && path.exists() {
3514            fs::remove_file(&path).expect("stale mpi root file cleanup should succeed");
3515        }
3516        world.barrier();
3517
3518        write_root(&dataset, path_str, &DatasetWriteOptions::default())
3519            .expect("writing root with mpi should succeed");
3520        world.barrier();
3521        world.barrier();
3522        finalize_mpi();
3523
3524        if is_root {
3525            let reopened = read_root(path_str, &DatasetReadOptions::new())
3526                .expect("root roundtrip should reopen");
3527            assert_datasets_close(&reference, &reopened, TEST_P4_NAMES, TEST_AUX_NAMES);
3528            if path.exists() {
3529                fs::remove_file(&path).expect("mpi root roundtrip cleanup should succeed");
3530            }
3531        }
3532    }
3533
3534    #[cfg(feature = "mpi")]
3535    #[mpi_test(np = [2])]
3536    fn test_root_output_is_deterministic_under_mpi() {
3537        use_mpi(true);
3538        let world = get_world().expect("MPI world should be initialized");
3539
3540        let dataset = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
3541        let first_path = env::temp_dir().join("laddu_mpi_root_determinism_first.root");
3542        let second_path = env::temp_dir().join("laddu_mpi_root_determinism_second.root");
3543        let first_path_str = first_path.to_str().expect("path should be valid UTF-8");
3544        let second_path_str = second_path.to_str().expect("path should be valid UTF-8");
3545
3546        if world.is_root() {
3547            for path in [&first_path, &second_path] {
3548                if path.exists() {
3549                    fs::remove_file(path).expect("stale mpi root file cleanup should succeed");
3550                }
3551            }
3552        }
3553        world.barrier();
3554
3555        write_root(&dataset, first_path_str, &DatasetWriteOptions::default())
3556            .expect("first mpi root write should succeed");
3557        world.barrier();
3558        write_root(&dataset, second_path_str, &DatasetWriteOptions::default())
3559            .expect("second mpi root write should succeed");
3560        world.barrier();
3561
3562        let first = read_root_storage(first_path_str, &DatasetReadOptions::new())
3563            .expect("first mpi root output should reopen");
3564        let second = read_root_storage(second_path_str, &DatasetReadOptions::new())
3565            .expect("second mpi root output should reopen");
3566        assert_dataset_columnar_close(&first, &second);
3567
3568        world.barrier();
3569        if world.is_root() {
3570            for path in [&first_path, &second_path] {
3571                if path.exists() {
3572                    fs::remove_file(path).expect("mpi root determinism cleanup should succeed");
3573                }
3574            }
3575        }
3576        finalize_mpi();
3577    }
3578
3579    #[cfg(feature = "mpi")]
3580    #[mpi_test(np = [2])]
3581    fn test_root_output_matches_between_mpi_and_non_mpi_writes() {
3582        let cpu_dataset = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
3583        let mpi_path = env::temp_dir().join("laddu_root_mpi_reference.root");
3584        let mpi_path_str = mpi_path.to_str().expect("path should be valid UTF-8");
3585
3586        use_mpi(true);
3587        let world = get_world().expect("MPI world should be initialized");
3588        let is_root = world.is_root();
3589        let mpi_dataset = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
3590
3591        if is_root && mpi_path.exists() {
3592            fs::remove_file(&mpi_path).expect("stale root file cleanup should succeed");
3593        }
3594        world.barrier();
3595        write_root(&mpi_dataset, mpi_path_str, &DatasetWriteOptions::default())
3596            .expect("mpi root write should succeed");
3597        world.barrier();
3598        world.barrier();
3599        finalize_mpi();
3600
3601        if is_root {
3602            let cpu_dir = make_temp_dir();
3603            let cpu_path = cpu_dir.join("laddu_root_cpu_reference.root");
3604            let cpu_path_str = cpu_path.to_str().expect("path should be valid UTF-8");
3605            write_root(&cpu_dataset, cpu_path_str, &DatasetWriteOptions::default())
3606                .expect("non-mpi root write should succeed");
3607
3608            let cpu_output = read_root_storage(cpu_path_str, &DatasetReadOptions::new())
3609                .expect("non-mpi root output should reopen");
3610            let mpi_output = read_root_storage(mpi_path_str, &DatasetReadOptions::new())
3611                .expect("mpi root output should reopen");
3612            assert_dataset_columnar_close(&cpu_output, &mpi_output);
3613
3614            fs::remove_dir_all(&cpu_dir).expect("root comparison temp dir cleanup should succeed");
3615            if mpi_path.exists() {
3616                fs::remove_file(&mpi_path).expect("root comparison cleanup should succeed");
3617            }
3618        }
3619    }
3620
3621    #[test]
3622    fn test_root_local_column_buffers_match_columnar_storage() {
3623        let dataset = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
3624        let buffers = io::build_root_local_column_buffers::<f64>(&dataset.columnar);
3625        let expected_names = dataset
3626            .p4_names()
3627            .iter()
3628            .flat_map(|name| {
3629                io::P4_COMPONENT_SUFFIXES
3630                    .iter()
3631                    .map(move |suffix| format!("{name}{suffix}"))
3632            })
3633            .chain(dataset.aux_names().iter().cloned())
3634            .chain(std::iter::once("weight".to_string()))
3635            .collect::<Vec<_>>();
3636        let expected_values = dataset
3637            .columnar
3638            .p4
3639            .iter()
3640            .flat_map(|p4| [p4.px.clone(), p4.py.clone(), p4.pz.clone(), p4.e.clone()])
3641            .chain(dataset.columnar.aux.clone())
3642            .chain(std::iter::once(dataset.columnar.weights.clone()))
3643            .collect::<Vec<_>>();
3644        assert_eq!(
3645            buffers
3646                .iter()
3647                .map(|(name, _)| name.as_str())
3648                .collect::<Vec<_>>(),
3649            expected_names
3650        );
3651        assert_eq!(
3652            buffers
3653                .into_iter()
3654                .map(|(_, values)| values)
3655                .collect::<Vec<_>>(),
3656            expected_values
3657        );
3658    }
3659
3660    #[test]
3661    fn test_root_local_column_buffers_convert_precision() {
3662        let dataset = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
3663        let buffers = io::build_root_local_column_buffers::<f32>(&dataset.columnar);
3664        let expected_values = dataset
3665            .columnar
3666            .p4
3667            .iter()
3668            .flat_map(|p4| {
3669                [
3670                    p4.px.iter().map(|value| *value as f32).collect::<Vec<_>>(),
3671                    p4.py.iter().map(|value| *value as f32).collect::<Vec<_>>(),
3672                    p4.pz.iter().map(|value| *value as f32).collect::<Vec<_>>(),
3673                    p4.e.iter().map(|value| *value as f32).collect::<Vec<_>>(),
3674                ]
3675            })
3676            .chain(
3677                dataset
3678                    .columnar
3679                    .aux
3680                    .iter()
3681                    .map(|aux| aux.iter().map(|value| *value as f32).collect::<Vec<_>>()),
3682            )
3683            .chain(std::iter::once(
3684                dataset
3685                    .columnar
3686                    .weights
3687                    .iter()
3688                    .map(|value| *value as f32)
3689                    .collect::<Vec<_>>(),
3690            ))
3691            .collect::<Vec<_>>();
3692
3693        assert_eq!(
3694            buffers
3695                .into_iter()
3696                .map(|(_, values)| values)
3697                .collect::<Vec<_>>(),
3698            expected_values
3699        );
3700    }
3701
3702    #[test]
3703    fn test_parquet_chunk_iterator_matches_full_read() {
3704        let path = test_data_path("data_f32.parquet");
3705        let path_str = path.to_str().expect("path should be valid UTF-8");
3706        let options = DatasetReadOptions::new();
3707        let full = read_parquet(path_str, &options).expect("full parquet read should work");
3708        let chunks =
3709            read_parquet_chunks(path_str, &options, 17).expect("chunk iterator should open");
3710
3711        let mut global_idx = 0usize;
3712        for chunk in chunks {
3713            let chunk = chunk.expect("chunk read should succeed");
3714            for local_idx in 0..chunk.n_events_local() {
3715                let left = full
3716                    .event(global_idx)
3717                    .expect("full dataset event should exist");
3718                let right = chunk
3719                    .event(local_idx)
3720                    .expect("chunk dataset event should exist");
3721                assert_events_close(&left, &right, TEST_P4_NAMES, TEST_AUX_NAMES);
3722                global_idx += 1;
3723            }
3724        }
3725
3726        assert_eq!(global_idx, full.n_events());
3727    }
3728
3729    #[cfg(feature = "mpi")]
3730    #[mpi_test(np = [2])]
3731    fn test_parquet_chunk_iterator_respects_mpi_partition() {
3732        let path = test_data_path("data_f32.parquet");
3733        let path_str = path.to_str().expect("path should be valid UTF-8");
3734        let options = DatasetReadOptions::new();
3735        let reference =
3736            read_parquet(path_str, &options).expect("reference parquet read should work");
3737
3738        use_mpi(true);
3739        let world = get_world().expect("MPI world should be initialized");
3740        let partition = world.partition(reference.n_events());
3741        let local_range = partition.range_for_rank(world.rank() as usize);
3742        let chunks =
3743            read_parquet_chunks(path_str, &options, 17).expect("chunk iterator should open");
3744
3745        let mut local_idx = 0usize;
3746        for chunk in chunks {
3747            let chunk = chunk.expect("chunk read should succeed");
3748            assert!(chunk.n_events_local() <= 17);
3749            for chunk_idx in 0..chunk.n_events_local() {
3750                let expected = reference
3751                    .event(local_range.start + local_idx)
3752                    .expect("reference event should exist");
3753                let actual = chunk.event(chunk_idx).expect("chunk event should exist");
3754                assert_events_close(&expected, &actual, TEST_P4_NAMES, TEST_AUX_NAMES);
3755                local_idx += 1;
3756            }
3757        }
3758
3759        assert_eq!(local_idx, local_range.len());
3760        let mut gathered_counts = vec![0usize; world.size() as usize];
3761        world.all_gather_into(&local_idx, &mut gathered_counts);
3762        assert_eq!(
3763            gathered_counts.into_iter().sum::<usize>(),
3764            reference.n_events()
3765        );
3766        finalize_mpi();
3767    }
3768
3769    #[test]
3770    fn test_parquet_chunk_iterator_with_options_chunk_size_one() {
3771        let path = test_data_path("data_f32.parquet");
3772        let path_str = path.to_str().expect("path should be valid UTF-8");
3773        let options = DatasetReadOptions::new().chunk_size(1);
3774        let full = read_parquet(path_str, &DatasetReadOptions::new())
3775            .expect("full parquet read should work");
3776        let chunks = read_parquet_chunks_with_options(path_str, &options)
3777            .expect("chunk iterator should open");
3778        let mut event_count = 0usize;
3779        let mut chunk_count = 0usize;
3780
3781        for chunk in chunks {
3782            let chunk = chunk.expect("chunk read should succeed");
3783            chunk_count += 1;
3784            assert_eq!(chunk.n_events_local(), 1);
3785            event_count += chunk.n_events_local();
3786        }
3787
3788        assert_eq!(event_count, full.n_events());
3789        assert_eq!(chunk_count, full.n_events());
3790    }
3791
3792    #[test]
3793    fn test_parquet_chunk_iterator_with_options_large_chunk_size() {
3794        let path = test_data_path("data_f32.parquet");
3795        let path_str = path.to_str().expect("path should be valid UTF-8");
3796        let full = read_parquet(path_str, &DatasetReadOptions::new())
3797            .expect("full parquet read should work");
3798        let options = DatasetReadOptions::new().chunk_size(full.n_events() + 100);
3799        let chunks = read_parquet_chunks_with_options(path_str, &options)
3800            .expect("chunk iterator should open");
3801        let chunk_vec = chunks
3802            .collect::<LadduResult<Vec<_>>>()
3803            .expect("all chunk reads should succeed");
3804
3805        assert_eq!(chunk_vec.len(), 1);
3806        assert_eq!(chunk_vec[0].n_events_local(), full.n_events());
3807    }
3808
3809    #[test]
3810    fn test_dataset_chunk_builder_matches_full_parquet_read() {
3811        let path = test_data_path("data_f32.parquet");
3812        let path_str = path.to_str().expect("path should be valid UTF-8");
3813        let options = DatasetReadOptions::new().chunk_size(13);
3814        let full = read_parquet(path_str, &DatasetReadOptions::new())
3815            .expect("full parquet read should work");
3816        let chunks = read_parquet_chunks_with_options(path_str, &options)
3817            .expect("chunk iterator should open");
3818
3819        let mut builder = DatasetChunkBuilder::new();
3820        for chunk in chunks {
3821            let chunk = chunk.expect("chunk read should succeed");
3822            builder.push_chunk(&chunk).expect("chunk should append");
3823        }
3824        let rebuilt = builder.finish();
3825
3826        assert_datasets_close(&full, &rebuilt, TEST_P4_NAMES, TEST_AUX_NAMES);
3827    }
3828
3829    #[test]
3830    fn test_try_fold_dataset_chunks_matches_full_weight_sum() {
3831        let path = test_data_path("data_f32.parquet");
3832        let path_str = path.to_str().expect("path should be valid UTF-8");
3833        let full = read_parquet(path_str, &DatasetReadOptions::new())
3834            .expect("full parquet read should work");
3835        let chunks = read_parquet_chunks(path_str, &DatasetReadOptions::new(), 11)
3836            .expect("chunk iterator should open");
3837
3838        let folded = try_fold_dataset_chunks(chunks, 0.0_f64, |acc, chunk| {
3839            Ok(acc + chunk.n_events_weighted_local())
3840        })
3841        .expect("chunk fold should succeed");
3842
3843        assert_relative_eq!(folded, full.n_events_weighted_local(), epsilon = 1e-9);
3844    }
3845}