Skip to main content

laddu_core/data/
dataset.rs

1use std::{
2    borrow::Cow,
3    ops::{Deref, DerefMut, Index, IndexMut},
4    sync::Arc,
5};
6
7use accurate::{sum::Klein, traits::*};
8use auto_ops::impl_op_ex;
9#[cfg(feature = "mpi")]
10use mpi::{datatype::PartitionMut, topology::SimpleCommunicator, traits::*};
11#[cfg(feature = "rayon")]
12use rayon::prelude::*;
13
14use super::{
15    event::{test_event, ColumnarP4Column, DatasetStorage, Event, EventData, OwnedEvent},
16    metadata::DatasetMetadata,
17};
18#[cfg(feature = "mpi")]
19use crate::mpi::LadduMPI;
20
21#[cfg(feature = "mpi")]
22pub(crate) type WorldHandle = SimpleCommunicator;
23#[cfg(not(feature = "mpi"))]
24pub(crate) type WorldHandle = ();
25
26#[cfg(feature = "mpi")]
27// Chosen from local two-rank probes: 512 matched or beat smaller chunks
28// while keeping the fetched-event cache modest.
29const DEFAULT_MPI_EVENT_FETCH_CHUNK_SIZE: usize = 512;
30#[cfg(feature = "mpi")]
31const MPI_EVENT_FETCH_CHUNK_SIZE_ENV: &str = "LADDU_MPI_EVENT_FETCH_CHUNK_SIZE";
32
33use indexmap::IndexMap;
34
35use crate::{
36    math::get_bin_edges,
37    variables::{IntoP4Selection, P4Selection, Variable, VariableExpression},
38    vectors::Vec4,
39    LadduError, LadduResult,
40};
41
42const TEST_P4_NAMES: &[&str] = &["beam", "proton", "kshort1", "kshort2"];
43const TEST_AUX_NAMES: &[&str] = &["pol_magnitude", "pol_angle"];
44
45fn local_weighted_sum(weights: &[f64]) -> f64 {
46    #[cfg(feature = "rayon")]
47    {
48        weights
49            .par_iter()
50            .copied()
51            .parallel_sum_with_accumulator::<Klein<f64>>()
52    }
53    #[cfg(not(feature = "rayon"))]
54    {
55        weights.iter().copied().sum_with_accumulator::<Klein<f64>>()
56    }
57}
58
59/// A dataset that can be used to test the implementation of an
60/// [`Amplitude`](crate::amplitude::Amplitude). This particular dataset contains a single
61/// [`EventData`] generated from [`test_event`].
62pub fn test_dataset() -> Dataset {
63    let metadata = Arc::new(
64        DatasetMetadata::new(
65            TEST_P4_NAMES.iter().map(|s| (*s).to_string()).collect(),
66            TEST_AUX_NAMES.iter().map(|s| (*s).to_string()).collect(),
67        )
68        .expect("Test metadata should be valid"),
69    );
70    Dataset::new_with_metadata(vec![Arc::new(test_event())], metadata)
71}
72
/// A collection of events with optional metadata for name-based lookups.
///
/// Events live in a shared [`DatasetStorage`]; `rows` decides which storage rows this
/// dataset exposes, so an index-based view can share storage with the dataset it came from.
#[derive(Debug, Clone)]
pub struct Dataset {
    /// Shared columnar event storage.
    pub(crate) columnar: Arc<DatasetStorage>,
    /// Mapping from logical (local) event indices to rows of `columnar`.
    rows: RowSelection,
    /// Shared metadata used to resolve four-momentum and auxiliary names.
    pub(crate) metadata: Arc<DatasetMetadata>,
    /// Cached sum of the weights of the locally selected rows.
    pub(crate) cached_local_weighted_sum: f64,
    /// Cached event count; updated through `set_cached_global_event_count_from_world` when
    /// an MPI layout and world are available.
    #[cfg(feature = "mpi")]
    pub(crate) cached_global_event_count: usize,
    /// Cached weighted sum; starts as the local sum and is updated through
    /// `set_cached_global_weighted_sum_from_world` when an MPI layout and world are available.
    #[cfg(feature = "mpi")]
    pub(crate) cached_global_weighted_sum: f64,
    /// How events are assigned to MPI ranks. `None` makes iteration and event lookup use
    /// the local code paths.
    #[cfg(feature = "mpi")]
    pub(crate) mpi_layout: Option<MpiDatasetLayout>,
}
87
/// Which rows of the backing storage a dataset exposes, and in what order.
#[derive(Debug, Clone)]
enum RowSelection {
    /// Every storage row, in storage order.
    Identity,
    /// An explicit list of storage rows; logical index `i` maps to `indices[i]`.
    Indices(Arc<[usize]>),
}

impl RowSelection {
    /// Number of logical rows, given that the backing storage holds `storage_len` rows.
    fn len(&self, storage_len: usize) -> usize {
        if let Self::Indices(selected) = self {
            selected.len()
        } else {
            storage_len
        }
    }

    /// Whether this selection exposes the storage unchanged.
    const fn is_identity(&self) -> bool {
        match self {
            Self::Identity => true,
            Self::Indices(_) => false,
        }
    }

    /// Translate a logical row index into a storage row index.
    ///
    /// Panics if `logical_index` is out of range for an explicit index list.
    fn physical_index(&self, logical_index: usize) -> usize {
        match self {
            Self::Indices(selected) => selected[logical_index],
            Self::Identity => logical_index,
        }
    }
}
113
/// How a dataset's events are assigned to MPI ranks.
#[cfg(feature = "mpi")]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum MpiDatasetLayout {
    /// Ownership and local ranges are delegated to the communicator's
    /// `owner_of_global_index` and `partition` helpers.
    Canonical,
    /// Global index `i` lives on rank `i % size` at local index `i / size`.
    RoundRobin,
    /// Each rank owns a contiguous run of global indices, ordered by rank, whose length is
    /// that rank's local event count (gathered from all ranks when needed).
    Derived,
}
121
122#[cfg(feature = "mpi")]
123impl MpiDatasetLayout {
124    fn owner_of(
125        self,
126        global_index: usize,
127        total: usize,
128        local_len: usize,
129        world: &SimpleCommunicator,
130    ) -> (i32, usize) {
131        match self {
132            Self::Canonical => world.owner_of_global_index(global_index, total),
133            Self::RoundRobin => {
134                let size = world.size() as usize;
135                ((global_index % size) as i32, global_index / size)
136            }
137            Self::Derived => {
138                let counts = gather_local_event_counts(local_len, world);
139                let mut start = 0usize;
140                for (rank, count) in counts.into_iter().enumerate() {
141                    let end = start + count;
142                    if global_index < end {
143                        return (rank as i32, global_index - start);
144                    }
145                    start = end;
146                }
147                debug_assert!(
148                    global_index < total,
149                    "validated derived global event index should be in range"
150                );
151                (world.rank(), 0)
152            }
153        }
154    }
155
156    fn local_range(
157        self,
158        total: usize,
159        local_len: usize,
160        world: &SimpleCommunicator,
161    ) -> std::ops::Range<usize> {
162        match self {
163            Self::Canonical => world.partition(total).range_for_rank(world.rank() as usize),
164            Self::RoundRobin => 0..local_len_for_round_robin(total, world),
165            Self::Derived => {
166                let counts = gather_local_event_counts(local_len, world);
167                let start = counts
168                    .iter()
169                    .take(world.rank() as usize)
170                    .copied()
171                    .sum::<usize>();
172                start..start + counts[world.rank() as usize]
173            }
174        }
175    }
176
177    fn local_indices_for_range(
178        self,
179        start: usize,
180        end: usize,
181        total: usize,
182        local_len: usize,
183        world: &SimpleCommunicator,
184    ) -> Vec<usize> {
185        match self {
186            Self::Canonical => {
187                let local_range = self.local_range(total, local_len, world);
188                let owned_start = start.max(local_range.start);
189                let owned_end = end.min(local_range.end);
190                if owned_start < owned_end {
191                    (owned_start - local_range.start..owned_end - local_range.start).collect()
192                } else {
193                    Vec::new()
194                }
195            }
196            Self::RoundRobin => {
197                let rank = world.rank() as usize;
198                let size = world.size() as usize;
199                (start..end)
200                    .filter_map(|global_index| {
201                        if global_index % size == rank {
202                            Some(global_index / size)
203                        } else {
204                            None
205                        }
206                    })
207                    .filter(|local_index| *local_index < local_len)
208                    .collect()
209            }
210            Self::Derived => {
211                let counts = gather_local_event_counts(local_len, world);
212                let local_start = counts
213                    .iter()
214                    .take(world.rank() as usize)
215                    .copied()
216                    .sum::<usize>();
217                let local_end = local_start + local_len;
218                let owned_start = start.max(local_start);
219                let owned_end = end.min(local_end);
220                if owned_start < owned_end {
221                    (owned_start - local_start..owned_end - local_start).collect()
222                } else {
223                    Vec::new()
224                }
225            }
226        }
227    }
228}
229
/// All-gather every rank's `local_len`; entry `r` of the result is rank `r`'s value.
#[cfg(feature = "mpi")]
fn gather_local_event_counts(local_len: usize, world: &SimpleCommunicator) -> Vec<usize> {
    let n_ranks = world.size() as usize;
    let mut per_rank = vec![0usize; n_ranks];
    world.all_gather_into(&local_len, &mut per_rank);
    per_rank
}
236
/// Number of events the calling rank holds when `total` events are dealt out round-robin,
/// i.e. how many `i < total` satisfy `i % size == rank`.
#[cfg(feature = "mpi")]
fn local_len_for_round_robin(total: usize, world: &SimpleCommunicator) -> usize {
    let rank = world.rank() as usize;
    let size = world.size() as usize;
    if total <= rank {
        return 0;
    }
    // The first owned index is `rank`; every further one is `size` later.
    let beyond_first = total - rank - 1;
    beyond_first / size + 1
}
247
248fn shared_dataset_iter(dataset: Arc<Dataset>) -> DatasetArcIter {
249    #[cfg(feature = "mpi")]
250    {
251        if let Some(world) = crate::mpi::get_world() {
252            if let Some(layout) = dataset.mpi_layout {
253                let total = dataset.n_events();
254                return DatasetArcIter::Mpi(DatasetArcMpiIter {
255                    dataset,
256                    world,
257                    index: 0,
258                    total,
259                    cursor: MpiEventChunkCursor::for_iteration(total),
260                    layout,
261                });
262            }
263        }
264    }
265    DatasetArcIter::Local { dataset, index: 0 }
266}
267
/// Extension methods for shared [`Arc<Dataset>`] handles.
pub trait SharedDatasetIterExt {
    /// Build an iterator over a shared [`Arc<Dataset>`] without cloning the dataset contents.
    ///
    /// The returned [`DatasetArcIter`] yields [`OwnedEvent`]s.
    fn shared_iter(&self) -> DatasetArcIter;

    /// Alias for [`SharedDatasetIterExt::shared_iter`].
    fn shared_iter_global(&self) -> DatasetArcIter;
}
276
277impl SharedDatasetIterExt for Arc<Dataset> {
278    fn shared_iter(&self) -> DatasetArcIter {
279        shared_dataset_iter(self.clone())
280    }
281
282    fn shared_iter_global(&self) -> DatasetArcIter {
283        self.shared_iter()
284    }
285}
286
287impl Dataset {
288    fn from_columnar_storage(
289        columnar: DatasetStorage,
290        metadata: Arc<DatasetMetadata>,
291        rows: RowSelection,
292    ) -> Self {
293        #[cfg(feature = "mpi")]
294        let local_count = rows.len(columnar.n_events());
295        let local_weighted_sum = Self::weighted_sum_for_rows(&columnar, &rows);
296        Dataset {
297            columnar: Arc::new(columnar),
298            rows,
299            metadata,
300            cached_local_weighted_sum: local_weighted_sum,
301            #[cfg(feature = "mpi")]
302            cached_global_event_count: local_count,
303            #[cfg(feature = "mpi")]
304            cached_global_weighted_sum: local_weighted_sum,
305            #[cfg(feature = "mpi")]
306            mpi_layout: None,
307        }
308    }
309
310    fn weighted_sum_for_rows(columnar: &DatasetStorage, rows: &RowSelection) -> f64 {
311        match rows {
312            RowSelection::Identity => local_weighted_sum(&columnar.weights),
313            RowSelection::Indices(indices) => {
314                #[cfg(feature = "rayon")]
315                {
316                    indices
317                        .par_iter()
318                        .map(|index| columnar.weight(*index))
319                        .parallel_sum_with_accumulator::<Klein<f64>>()
320                }
321                #[cfg(not(feature = "rayon"))]
322                {
323                    indices
324                        .iter()
325                        .map(|index| columnar.weight(*index))
326                        .sum_with_accumulator::<Klein<f64>>()
327                }
328            }
329        }
330    }
331
332    fn indexed_local_view<I>(&self, indices: I) -> Arc<Dataset>
333    where
334        I: IntoIterator<Item = usize>,
335    {
336        let rows = RowSelection::Indices(indices.into_iter().collect::<Vec<_>>().into());
337        let local_weighted_sum = Self::weighted_sum_for_rows(&self.columnar, &rows);
338        let dataset = Dataset {
339            columnar: self.columnar.clone(),
340            rows,
341            metadata: self.metadata.clone(),
342            cached_local_weighted_sum: local_weighted_sum,
343            #[cfg(feature = "mpi")]
344            cached_global_event_count: 0,
345            #[cfg(feature = "mpi")]
346            cached_global_weighted_sum: local_weighted_sum,
347            #[cfg(feature = "mpi")]
348            mpi_layout: self.mpi_layout,
349        };
350        #[cfg(feature = "mpi")]
351        {
352            let mut dataset = dataset;
353            if dataset.mpi_layout.is_some() {
354                dataset.mpi_layout = Some(MpiDatasetLayout::Derived);
355                if let Some(world) = crate::mpi::get_world() {
356                    dataset.set_cached_global_event_count_from_world(&world);
357                    dataset.set_cached_global_weighted_sum_from_world(&world);
358                }
359            }
360            Arc::new(dataset)
361        }
362        #[cfg(not(feature = "mpi"))]
363        {
364            Arc::new(dataset)
365        }
366    }
367
368    fn ensure_mutable_storage(&self, operation: &str) -> LadduResult<()> {
369        if self.rows.is_identity() {
370            Ok(())
371        } else {
372            Err(LadduError::Custom(format!(
373                "Cannot {operation} on a filtered or bootstrapped dataset view; materialize it first"
374            )))
375        }
376    }
377
378    /// Iterate over locally stored events as borrowed [`Event`] views.
379    pub fn events_local(&self) -> impl Iterator<Item = Event<'_>> {
380        DatasetViewIter {
381            dataset: self,
382            index: 0,
383        }
384    }
385
386    /// Iterate over all events using the default global iteration semantics.
387    ///
388    /// When MPI is enabled, the iterator is ordered like [`Dataset::event_global`] and may
389    /// fetch remotely owned events in chunks.
390    pub fn events_global(&self) -> DatasetGlobalIter<'_> {
391        let total = self.n_events();
392        #[cfg(feature = "mpi")]
393        {
394            if let (Some(world), Some(layout)) = (crate::mpi::get_world(), self.mpi_layout) {
395                return DatasetGlobalIter {
396                    dataset: self,
397                    index: 0,
398                    total,
399                    world: Some(world),
400                    cursor: Some(MpiEventChunkCursor::for_iteration(total)),
401                    layout: Some(layout),
402                };
403            }
404        }
405        DatasetGlobalIter {
406            dataset: self,
407            index: 0,
408            total,
409            #[cfg(feature = "mpi")]
410            world: None,
411            #[cfg(feature = "mpi")]
412            cursor: None,
413            #[cfg(feature = "mpi")]
414            layout: None,
415        }
416    }
417
418    fn refresh_local_weight_cache(&mut self) {
419        self.cached_local_weighted_sum = Self::weighted_sum_for_rows(&self.columnar, &self.rows);
420        #[cfg(feature = "mpi")]
421        {
422            self.cached_global_weighted_sum = self.cached_local_weighted_sum;
423            self.cached_global_event_count = self.n_events_local();
424            if self.mpi_layout.is_some() {
425                if let Some(world) = crate::mpi::get_world() {
426                    self.set_cached_global_event_count_from_world(&world);
427                    self.set_cached_global_weighted_sum_from_world(&world);
428                }
429            }
430        }
431    }
432
433    #[cfg(test)]
434    pub(crate) fn clear_events_local(&mut self) {
435        self.ensure_mutable_storage("clear local events")
436            .expect("test datasets should be materialized");
437        let columnar = Arc::make_mut(&mut self.columnar);
438        for column in &mut columnar.p4 {
439            column.px.clear();
440            column.py.clear();
441            column.pz.clear();
442            column.e.clear();
443        }
444        for column in &mut columnar.aux {
445            column.clear();
446        }
447        columnar.weights.clear();
448        self.refresh_local_weight_cache();
449    }
450
451    /// Borrow the dataset metadata used for name lookups.
452    pub fn metadata(&self) -> &DatasetMetadata {
453        &self.metadata
454    }
455
456    /// Clone the internal metadata handle for external consumers (e.g., language bindings).
457    pub fn metadata_arc(&self) -> Arc<DatasetMetadata> {
458        self.metadata.clone()
459    }
460
461    /// Names corresponding to stored four-momenta.
462    pub fn p4_names(&self) -> &[String] {
463        &self.metadata.p4_names
464    }
465
466    /// Names corresponding to stored auxiliary scalars.
467    pub fn aux_names(&self) -> &[String] {
468        &self.metadata.aux_names
469    }
470
471    /// Resolve the index of a four-momentum by name.
472    pub fn p4_index(&self, name: &str) -> Option<usize> {
473        self.metadata.p4_index(name)
474    }
475
476    /// Resolve the index of an auxiliary scalar by name.
477    pub fn aux_index(&self, name: &str) -> Option<usize> {
478        self.metadata.aux_index(name)
479    }
480
481    fn event_global_opt(&self, index: usize) -> LadduResult<Option<OwnedEvent>> {
482        #[cfg(feature = "mpi")]
483        {
484            if let (Some(world), Some(_)) = (crate::mpi::get_world(), self.mpi_layout) {
485                let total = self.n_events();
486                if index >= total {
487                    return Ok(None);
488                }
489                return self.fetch_event_mpi(index, &world, total).map(Some);
490            }
491        }
492
493        Ok((index < self.n_events_local())
494            .then(|| OwnedEvent::new(self.event_data_arc_local(index), self.metadata.clone())))
495    }
496
497    /// Retrieve a single owned event by global index.
498    pub fn event_global(&self, index: usize) -> LadduResult<OwnedEvent> {
499        self.event_global_opt(index)?.ok_or_else(|| {
500            LadduError::Custom(format!(
501                "Dataset index out of bounds: index {index}, length {}",
502                self.n_events()
503            ))
504        })
505    }
506
507    /// Borrow a locally stored event by local rank index.
508    pub fn event_local(&self, event_index: usize) -> LadduResult<Event<'_>> {
509        if event_index >= self.n_events_local() {
510            return Err(LadduError::Custom(format!(
511                "Dataset local index out of bounds: index {event_index}, length {}",
512                self.n_events_local()
513            )));
514        }
515        Ok(self.event_view(event_index))
516    }
517
518    /// Retrieve a four-momentum by name for the event at `event_index`.
519    pub fn p4_by_name(&self, event_index: usize, name: &str) -> Option<Vec4> {
520        self.event_global_opt(event_index)
521            .ok()
522            .flatten()
523            .and_then(|event| event.p4(name))
524    }
525
526    /// Retrieve an auxiliary scalar by name for the event at `event_index`.
527    pub fn aux_by_name(&self, event_index: usize, name: &str) -> Option<f64> {
528        let idx = self.aux_index(name)?;
529        self.event_global_opt(event_index)
530            .ok()
531            .flatten()
532            .and_then(|event| event.aux.get(idx).copied())
533    }
534
535    pub(crate) fn event_view(&self, event_index: usize) -> Event<'_> {
536        self.columnar
537            .event_view(self.rows.physical_index(event_index))
538    }
539
540    /// Get a reference to the [`EventData`] at the given index in the [`Dataset`] (non-MPI
541    /// version).
542    ///
543    /// # Notes
544    ///
545    /// This method is not intended to be called in analyses but rather in writing methods
546    /// that have `mpi`-feature-gated versions. Most users should use [`Dataset::event`] instead:
547    ///
548    /// ```ignore
549    /// let ds: Dataset = Dataset::new(events);
550    /// let event_0 = ds.event_global(0)?;
551    /// ```
552    pub(crate) fn event_data_arc_local(&self, index: usize) -> Arc<EventData> {
553        Arc::new(self.columnar.event_data(self.rows.physical_index(index)))
554    }
555
556    pub(crate) fn local_event_data_arcs(&self) -> Vec<Arc<EventData>> {
557        (0..self.n_events_local())
558            .map(|index| self.event_data_arc_local(index))
559            .collect()
560    }
561
562    pub(crate) fn local_storage_for_export(&self) -> LadduResult<Cow<'_, DatasetStorage>> {
563        if self.rows.is_identity() {
564            Ok(Cow::Borrowed(self.columnar.as_ref()))
565        } else {
566            Ok(Cow::Owned(Self::columnar_from_events(
567                &self.local_event_data_arcs(),
568                self.metadata.clone(),
569            )?))
570        }
571    }
572
573    pub(crate) fn local_weight_cache_key(&self) -> (usize, usize) {
574        match &self.rows {
575            RowSelection::Identity => (
576                self.columnar.weights.as_ptr() as usize,
577                self.n_events_local(),
578            ),
579            RowSelection::Indices(indices) => (indices.as_ptr() as usize, indices.len()),
580        }
581    }
582
583    #[cfg(feature = "mpi")]
584    fn partition(
585        events: Vec<Arc<EventData>>,
586        world: &SimpleCommunicator,
587    ) -> Vec<Vec<Arc<EventData>>> {
588        let partition = world.partition(events.len());
589        (0..partition.n_ranks())
590            .map(|rank| {
591                let range = partition.range_for_rank(rank);
592                events[range.clone()].to_vec()
593            })
594            .collect()
595    }
596}
597
/// Iterator over local borrowed event views in a [`Dataset`].
pub(crate) struct DatasetViewIter<'a> {
    /// Dataset whose local events are being visited.
    dataset: &'a Dataset,
    /// Local index of the next event to yield.
    index: usize,
}
603
604impl<'a> Iterator for DatasetViewIter<'a> {
605    type Item = Event<'a>;
606
607    fn next(&mut self) -> Option<Self::Item> {
608        if self.index >= self.dataset.n_events_local() {
609            return None;
610        }
611        let event = self.dataset.event_view(self.index);
612        self.index += 1;
613        Some(event)
614    }
615}
616
/// Iterator over global owned events in a [`Dataset`].
pub struct DatasetGlobalIter<'a> {
    /// Dataset being iterated.
    dataset: &'a Dataset,
    /// Index of the next event to yield.
    index: usize,
    /// Event count captured from `n_events()` when the iterator was created.
    total: usize,
    /// MPI world; `Some` only when a world and a dataset layout were both available at
    /// creation.
    #[cfg(feature = "mpi")]
    world: Option<SimpleCommunicator>,
    /// Chunk cursor used to fetch events across ranks; populated together with `world`.
    #[cfg(feature = "mpi")]
    cursor: Option<MpiEventChunkCursor>,
    /// Dataset layout captured at creation; populated together with `world`.
    #[cfg(feature = "mpi")]
    layout: Option<MpiDatasetLayout>,
}
629
630impl Iterator for DatasetGlobalIter<'_> {
631    type Item = OwnedEvent;
632
633    fn next(&mut self) -> Option<Self::Item> {
634        if self.index >= self.total {
635            return None;
636        }
637        let index = self.index;
638        self.index += 1;
639
640        #[cfg(feature = "mpi")]
641        {
642            if let (Some(world), Some(cursor), Some(layout)) =
643                (&self.world, &mut self.cursor, self.layout)
644            {
645                return cursor
646                    .event_for_dataset(self.dataset, index, world, self.total, layout)
647                    .ok()
648                    .flatten();
649            }
650        }
651
652        self.dataset.event_global_opt(index).ok().flatten()
653    }
654}
655
/// Iterator over a shared [`Arc<Dataset>`].
///
/// Both variants yield [`OwnedEvent`]s.
pub enum DatasetArcIter {
    /// Iterator over locally available events from a shared dataset handle.
    Local {
        /// Shared dataset handle.
        dataset: Arc<Dataset>,
        /// Next local event index to read.
        index: usize,
    },
    #[cfg(feature = "mpi")]
    /// Iterator that fetches events across MPI ranks from a shared dataset handle.
    Mpi(DatasetArcMpiIter),
}
669
670impl Iterator for DatasetArcIter {
671    type Item = OwnedEvent;
672
673    fn next(&mut self) -> Option<Self::Item> {
674        match self {
675            DatasetArcIter::Local { dataset, index } => {
676                let event = (*index < dataset.n_events_local()).then(|| {
677                    OwnedEvent::new(
678                        dataset.event_data_arc_local(*index),
679                        dataset.metadata.clone(),
680                    )
681                });
682                *index += 1;
683                event
684            }
685            #[cfg(feature = "mpi")]
686            DatasetArcIter::Mpi(iter) => iter.next(),
687        }
688    }
689}
690
/// Cache of one chunk of consecutively indexed events fetched across MPI ranks.
#[cfg(feature = "mpi")]
#[derive(Debug, Clone)]
pub(crate) struct MpiEventChunkCursor {
    /// Global index of the first event in `cached_events`.
    chunk_start: usize,
    /// Number of events requested per fetch (at least one).
    chunk_size: usize,
    /// Events of the most recently fetched chunk, starting at `chunk_start`.
    cached_events: Vec<OwnedEvent>,
}
698
/// Choose the chunk size used when fetching events across ranks.
///
/// A value of the `MPI_EVENT_FETCH_CHUNK_SIZE_ENV` environment variable that parses as
/// `usize` takes precedence; otherwise (unset, not valid UTF-8, or not a number) the
/// default is used. The result is always at least one and at most `max(total, 1)`.
#[cfg(feature = "mpi")]
pub(crate) fn resolve_mpi_event_fetch_chunk_size(total: usize) -> usize {
    let upper = total.max(1);
    let requested = std::env::var_os(MPI_EVENT_FETCH_CHUNK_SIZE_ENV)
        .and_then(|raw| raw.to_str().and_then(|text| text.parse::<usize>().ok()));
    match requested {
        Some(size) => size.clamp(1, upper),
        None => DEFAULT_MPI_EVENT_FETCH_CHUNK_SIZE.min(upper),
    }
}
709
/// Which family of columns a collective column addition targets.
#[cfg(feature = "mpi")]
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
enum ColumnMutationKind {
    /// A four-momentum column.
    P4,
    /// An auxiliary scalar column.
    Aux,
}
716
/// Per-rank status that ranks exchange to validate a collective column addition.
#[cfg(feature = "mpi")]
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
struct ColumnMutationStatus {
    /// Column family the rank is adding to.
    kind: ColumnMutationKind,
    /// Column name the rank is adding.
    name: String,
    /// Whether the rank reported that its column has the expected local length.
    len_ok: bool,
    /// Whether the name was rejected as already present in the rank's metadata.
    duplicate: bool,
}
725
#[cfg(feature = "mpi")]
impl MpiEventChunkCursor {
    /// Create a cursor for iterating over `total` events, with the chunk size chosen by
    /// [`resolve_mpi_event_fetch_chunk_size`].
    pub(crate) fn for_iteration(total: usize) -> Self {
        let chunk_size = resolve_mpi_event_fetch_chunk_size(total);
        Self::new(chunk_size)
    }
}
732
#[cfg(feature = "mpi")]
impl MpiEventChunkCursor {
    /// Create an empty cursor that fetches `chunk_size` events at a time (minimum one).
    pub(crate) fn new(chunk_size: usize) -> Self {
        let chunk_size = chunk_size.max(1);
        Self {
            chunk_start: 0,
            chunk_size,
            cached_events: Vec::new(),
        }
    }

    /// One past the global index of the last cached event.
    fn chunk_end(&self) -> usize {
        self.chunk_start + self.cached_events.len()
    }

    /// Whether the event at `global_index` is in the cached chunk.
    fn contains(&self, global_index: usize) -> bool {
        (self.chunk_start..self.chunk_end()).contains(&global_index)
    }

    /// Return the event at `global_index`, fetching a new chunk that starts at that index
    /// when it is not cached.
    ///
    /// Returns `Ok(None)` when `global_index >= total`, or when the fetched chunk does not
    /// reach the requested index.
    ///
    /// # Errors
    ///
    /// Propagates errors from `Dataset::fetch_event_chunk_mpi`.
    pub(crate) fn event_for_dataset(
        &mut self,
        dataset: &Dataset,
        global_index: usize,
        world: &SimpleCommunicator,
        total: usize,
        layout: MpiDatasetLayout,
    ) -> LadduResult<Option<OwnedEvent>> {
        if global_index >= total {
            return Ok(None);
        }
        let cache_miss = !self.contains(global_index);
        if cache_miss {
            // The chunk start is moved before fetching, as the fetch begins at this index.
            self.chunk_start = global_index;
            self.cached_events = dataset.fetch_event_chunk_mpi(
                global_index,
                self.chunk_size,
                world,
                total,
                layout,
            )?;
        }
        let offset = global_index - self.chunk_start;
        Ok(self.cached_events.get(offset).cloned())
    }
}
778
#[cfg(feature = "mpi")]
/// Iterator over a shared [`Arc<Dataset>`] that fetches events across MPI ranks.
pub struct DatasetArcMpiIter {
    /// Shared dataset handle.
    dataset: Arc<Dataset>,
    /// MPI world used for fetching.
    world: SimpleCommunicator,
    /// Global index of the next event to request.
    index: usize,
    /// Event count captured from `n_events()` when the iterator was created.
    total: usize,
    /// Chunk cursor caching the most recently fetched events.
    cursor: MpiEventChunkCursor,
    /// Dataset layout captured when the iterator was created.
    layout: MpiDatasetLayout,
}
789
#[cfg(feature = "mpi")]
impl Iterator for DatasetArcMpiIter {
    type Item = OwnedEvent;

    /// Yield the event at the current global index and advance.
    ///
    /// Yields `None` past the end of the dataset and also when a fetch fails; the index
    /// advances on every call.
    fn next(&mut self) -> Option<Self::Item> {
        let fetched = self.cursor.event_for_dataset(
            &self.dataset,
            self.index,
            &self.world,
            self.total,
            self.layout,
        );
        self.index += 1;
        fetched.unwrap_or(None)
    }
}
810
811impl Dataset {
812    #[cfg(feature = "mpi")]
813    fn validate_global_column_add(
814        &self,
815        kind: ColumnMutationKind,
816        name: &str,
817        len_ok: bool,
818    ) -> LadduResult<()> {
819        let Some(world) = crate::mpi::get_world() else {
820            return Ok(());
821        };
822        let duplicate = match kind {
823            ColumnMutationKind::P4 => self.metadata.ensure_new_p4_name(name).is_err(),
824            ColumnMutationKind::Aux => self.metadata.ensure_new_aux_name(name).is_err(),
825        };
826        let local_status = ColumnMutationStatus {
827            kind,
828            name: name.to_string(),
829            len_ok,
830            duplicate,
831        };
832        let serialized = bitcode::serialize(&local_status)?;
833        let local_byte_count = serialized.len() as i32;
834        let mut byte_counts = vec![0_i32; world.size() as usize];
835        world.all_gather_into(&local_byte_count, &mut byte_counts);
836        let mut byte_displs = vec![0_i32; byte_counts.len()];
837        for index in 1..byte_displs.len() {
838            byte_displs[index] = byte_displs[index - 1] + byte_counts[index - 1];
839        }
840        let gathered_bytes = world.all_gather_with_counts(&serialized, &byte_counts, &byte_displs);
841        let mut statuses = Vec::with_capacity(world.size() as usize);
842        for rank in 0..world.size() as usize {
843            let start = byte_displs[rank] as usize;
844            let end = start + byte_counts[rank] as usize;
845            statuses.push(bitcode::deserialize::<ColumnMutationStatus>(
846                &gathered_bytes[start..end],
847            )?);
848        }
849        for (rank, status) in statuses.iter().enumerate() {
850            if status.kind != kind {
851                return Err(LadduError::Custom(format!(
852                    "Collective dataset column add mismatch: rank {rank} used {:?}, expected {:?}",
853                    status.kind, kind
854                )));
855            }
856            if status.name != name {
857                return Err(LadduError::Custom(format!(
858                    "Collective dataset column add mismatch: rank {rank} used name '{}', expected '{name}'",
859                    status.name
860                )));
861            }
862            if !status.len_ok {
863                return Err(LadduError::Custom(format!(
864                    "Collective dataset column add mismatch: rank {rank} provided a column with the wrong local length"
865                )));
866            }
867            if status.duplicate {
868                let category = match kind {
869                    ColumnMutationKind::P4 => "p4",
870                    ColumnMutationKind::Aux => "aux",
871                };
872                return Err(LadduError::DuplicateName {
873                    category,
874                    name: name.to_string(),
875                });
876            }
877        }
878        Ok(())
879    }
880
881    #[cfg(feature = "mpi")]
882    fn fetch_event_mpi(
883        &self,
884        global_index: usize,
885        world: &SimpleCommunicator,
886        total: usize,
887    ) -> LadduResult<OwnedEvent> {
888        let layout = self.mpi_layout.ok_or_else(|| {
889            LadduError::Custom(
890                "global MPI event fetch requires a global dataset layout".to_string(),
891            )
892        })?;
893        let (owning_rank, local_index) =
894            layout.owner_of(global_index, total, self.n_events_local(), world);
895        let mut serialized_event_buffer_len: usize = 0;
896        let mut serialized_event_buffer: Vec<u8> = Vec::default();
897        if world.rank() == owning_rank {
898            let event = self
899                .columnar
900                .event_data(self.rows.physical_index(local_index));
901            serialized_event_buffer = bitcode::serialize(&event)?;
902            serialized_event_buffer_len = serialized_event_buffer.len();
903        }
904        world
905            .process_at_rank(owning_rank)
906            .broadcast_into(&mut serialized_event_buffer_len);
907        if world.rank() != owning_rank {
908            serialized_event_buffer = vec![0; serialized_event_buffer_len];
909        }
910        world
911            .process_at_rank(owning_rank)
912            .broadcast_into(&mut serialized_event_buffer);
913
914        if world.rank() == owning_rank {
915            Ok(OwnedEvent::new(
916                Arc::new(
917                    self.columnar
918                        .event_data(self.rows.physical_index(local_index)),
919                ),
920                self.metadata.clone(),
921            ))
922        } else {
923            let event: EventData = bitcode::deserialize(&serialized_event_buffer[..])?;
924            Ok(OwnedEvent::new(Arc::new(event), self.metadata.clone()))
925        }
926    }
927
928    #[cfg(feature = "mpi")]
929    pub(crate) fn fetch_event_chunk_mpi(
930        &self,
931        start: usize,
932        len: usize,
933        world: &SimpleCommunicator,
934        total: usize,
935        layout: MpiDatasetLayout,
936    ) -> LadduResult<Vec<OwnedEvent>> {
937        if len == 0 || start >= total {
938            return Ok(Vec::new());
939        }
940
941        let end = (start + len).min(total);
942        let local_indices =
943            layout.local_indices_for_range(start, end, total, self.n_events_local(), world);
944
945        let local_events: Vec<EventData> = local_indices
946            .into_iter()
947            .map(|local_index| {
948                self.columnar
949                    .event_data(self.rows.physical_index(local_index))
950            })
951            .collect();
952        let local_event_count = local_events.len() as i32;
953
954        let serialized_local = if local_events.is_empty() {
955            Vec::new()
956        } else {
957            bitcode::serialize(&local_events)?
958        };
959        let local_byte_count = serialized_local.len() as i32;
960
961        let mut gathered_event_counts = vec![0_i32; world.size() as usize];
962        let mut gathered_byte_counts = vec![0_i32; world.size() as usize];
963        world.all_gather_into(&local_event_count, &mut gathered_event_counts);
964        world.all_gather_into(&local_byte_count, &mut gathered_byte_counts);
965
966        let mut gathered_byte_displs = vec![0_i32; gathered_byte_counts.len()];
967        for index in 1..gathered_byte_displs.len() {
968            gathered_byte_displs[index] =
969                gathered_byte_displs[index - 1] + gathered_byte_counts[index - 1];
970        }
971        let gathered_bytes = world.all_gather_with_counts(
972            &serialized_local,
973            &gathered_byte_counts,
974            &gathered_byte_displs,
975        );
976
977        let mut events_by_rank = vec![Vec::new(); world.size() as usize];
978        for rank in 0..world.size() as usize {
979            if gathered_event_counts[rank] == 0 {
980                continue;
981            }
982            let byte_start = gathered_byte_displs[rank] as usize;
983            let byte_end = byte_start + gathered_byte_counts[rank] as usize;
984            let decoded: Vec<EventData> =
985                bitcode::deserialize(&gathered_bytes[byte_start..byte_end])?;
986            debug_assert_eq!(decoded.len(), gathered_event_counts[rank] as usize);
987            events_by_rank[rank] = decoded
988                .into_iter()
989                .map(|event| OwnedEvent::new(Arc::new(event), self.metadata.clone()))
990                .collect();
991        }
992
993        let mut offsets = vec![0usize; world.size() as usize];
994        let mut events = Vec::with_capacity(end - start);
995        for global_index in start..end {
996            let (owning_rank, _) =
997                layout.owner_of(global_index, total, self.n_events_local(), world);
998            let rank = owning_rank as usize;
999            let offset = offsets[rank];
1000            events.push(events_by_rank[rank][offset].clone());
1001            offsets[rank] += 1;
1002        }
1003        Ok(events)
1004    }
1005
1006    #[cfg(feature = "mpi")]
1007    pub(crate) fn set_cached_global_event_count_from_world(&mut self, world: &SimpleCommunicator) {
1008        let local_count = self.n_events_local();
1009        let mut global_count = 0usize;
1010        world.all_reduce_into(
1011            &local_count,
1012            &mut global_count,
1013            mpi::collective::SystemOperation::sum(),
1014        );
1015        self.cached_global_event_count = global_count;
1016    }
1017
1018    #[cfg(feature = "mpi")]
1019    pub(crate) fn set_cached_global_weighted_sum_from_world(&mut self, world: &SimpleCommunicator) {
1020        let mut weighted_sums = vec![0.0_f64; world.size() as usize];
1021        world.all_gather_into(&self.cached_local_weighted_sum, &mut weighted_sums);
1022        #[cfg(feature = "rayon")]
1023        {
1024            self.cached_global_weighted_sum = weighted_sums
1025                .into_par_iter()
1026                .parallel_sum_with_accumulator::<Klein<f64>>();
1027        }
1028        #[cfg(not(feature = "rayon"))]
1029        {
1030            self.cached_global_weighted_sum = weighted_sums
1031                .into_iter()
1032                .sum_with_accumulator::<Klein<f64>>();
1033        }
1034    }
1035
1036    fn columnar_from_events(
1037        events: &[Arc<EventData>],
1038        metadata: Arc<DatasetMetadata>,
1039    ) -> LadduResult<DatasetStorage> {
1040        let n_events = events.len();
1041        let (n_p4, n_aux) = match events.first() {
1042            Some(first) => (first.p4s.len(), first.aux.len()),
1043            None => (metadata.p4_names.len(), metadata.aux_names.len()),
1044        };
1045        let mut p4 = (0..n_p4)
1046            .map(|_| ColumnarP4Column::with_capacity(n_events))
1047            .collect::<Vec<_>>();
1048        let mut aux = (0..n_aux)
1049            .map(|_| Vec::with_capacity(n_events))
1050            .collect::<Vec<_>>();
1051        let mut weights = Vec::with_capacity(n_events);
1052        for (event_index, event) in events.iter().enumerate() {
1053            if event.p4s.len() != n_p4 || event.aux.len() != n_aux {
1054                return Err(LadduError::Custom(format!(
1055                    "Ragged dataset shape at event {event_index}: expected ({n_p4} p4, {n_aux} aux), got ({} p4, {} aux)",
1056                    event.p4s.len(),
1057                    event.aux.len()
1058                )));
1059            }
1060            for (column, value) in p4.iter_mut().zip(event.p4s.iter()) {
1061                column.push(*value);
1062            }
1063            for (column, value) in aux.iter_mut().zip(event.aux.iter()) {
1064                column.push(*value);
1065            }
1066            weights.push(event.weight);
1067        }
1068        Ok(DatasetStorage {
1069            metadata,
1070            p4,
1071            aux,
1072            weights,
1073        })
1074    }
1075
1076    /// Create a new [`Dataset`] from a list of [`EventData`] (non-MPI version).
1077    ///
1078    /// # Notes
1079    ///
1080    /// This method is not intended to be called in analyses but rather in writing methods
1081    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::new`] instead.
1082    pub fn new_local(events: Vec<Arc<EventData>>, metadata: Arc<DatasetMetadata>) -> Self {
1083        let columnar = Self::columnar_from_events(&events, metadata.clone())
1084            .expect("Dataset requires rectangular p4/aux columns for canonical columnar storage");
1085        Self::from_columnar_storage(columnar, metadata, RowSelection::Identity)
1086    }
1087
1088    /// Create an empty local dataset with explicit metadata.
1089    ///
1090    /// The returned dataset is valid immediately and can be extended with
1091    /// [`Dataset::push_event_local`] or [`Dataset::push_event_named_local`].
1092    ///
1093    /// Under MPI, pushed rows are stored only on the rank that performs the
1094    /// push. Use [`Dataset::push_event_global`] for collective single-copy
1095    /// appends.
1096    pub fn empty_local(metadata: DatasetMetadata) -> Self {
1097        let metadata = Arc::new(metadata);
1098        #[cfg(feature = "mpi")]
1099        {
1100            if crate::mpi::get_world().is_some() {
1101                let dataset = Dataset {
1102                    columnar: Arc::new(DatasetStorage::empty_with_capacity(metadata.clone(), 0)),
1103                    rows: RowSelection::Identity,
1104                    metadata,
1105                    cached_local_weighted_sum: 0.0,
1106                    cached_global_event_count: 0,
1107                    cached_global_weighted_sum: 0.0,
1108                    mpi_layout: None,
1109                };
1110                return dataset;
1111            }
1112        }
1113        Dataset {
1114            columnar: Arc::new(DatasetStorage::empty_with_capacity(metadata.clone(), 0)),
1115            rows: RowSelection::Identity,
1116            metadata,
1117            cached_local_weighted_sum: 0.0,
1118            #[cfg(feature = "mpi")]
1119            cached_global_event_count: 0,
1120            #[cfg(feature = "mpi")]
1121            cached_global_weighted_sum: 0.0,
1122            #[cfg(feature = "mpi")]
1123            mpi_layout: None,
1124        }
1125    }
1126
1127    /// Create a local dataset from ordered four-momentum columns, auxiliary columns, and weights.
1128    ///
1129    /// `p4_columns` and `aux_columns` must be ordered to match the supplied metadata. Each
1130    /// column must have the same length as `weights`.
1131    pub fn from_columns_local(
1132        metadata: DatasetMetadata,
1133        p4_columns: Vec<Vec<Vec4>>,
1134        aux_columns: Vec<Vec<f64>>,
1135        weights: Vec<f64>,
1136    ) -> LadduResult<Self> {
1137        let n_events = weights.len();
1138        if p4_columns.len() != metadata.p4_names().len() {
1139            return Err(LadduError::Custom(format!(
1140                "Expected {} p4 columns, got {}",
1141                metadata.p4_names().len(),
1142                p4_columns.len()
1143            )));
1144        }
1145        if aux_columns.len() != metadata.aux_names().len() {
1146            return Err(LadduError::Custom(format!(
1147                "Expected {} aux columns, got {}",
1148                metadata.aux_names().len(),
1149                aux_columns.len()
1150            )));
1151        }
1152        for (index, column) in p4_columns.iter().enumerate() {
1153            if column.len() != n_events {
1154                return Err(LadduError::Custom(format!(
1155                    "P4 column {index} length {} does not match weight length {n_events}",
1156                    column.len()
1157                )));
1158            }
1159        }
1160        for (index, column) in aux_columns.iter().enumerate() {
1161            if column.len() != n_events {
1162                return Err(LadduError::Custom(format!(
1163                    "Aux column {index} length {} does not match weight length {n_events}",
1164                    column.len()
1165                )));
1166            }
1167        }
1168
1169        let events = (0..n_events)
1170            .map(|event_index| {
1171                Arc::new(EventData {
1172                    p4s: p4_columns
1173                        .iter()
1174                        .map(|column| column[event_index])
1175                        .collect(),
1176                    aux: aux_columns
1177                        .iter()
1178                        .map(|column| column[event_index])
1179                        .collect(),
1180                    weight: weights[event_index],
1181                })
1182            })
1183            .collect();
1184        Ok(Dataset::new_local(events, Arc::new(metadata)))
1185    }
1186
1187    /// Create a global dataset from ordered columns.
1188    ///
1189    /// Under MPI, every rank must pass the same global columns. The rows are
1190    /// partitioned across ranks using laddu's canonical contiguous partition.
1191    pub fn from_columns_global(
1192        metadata: DatasetMetadata,
1193        p4_columns: Vec<Vec<Vec4>>,
1194        aux_columns: Vec<Vec<f64>>,
1195        weights: Vec<f64>,
1196    ) -> LadduResult<Self> {
1197        let dataset = Self::from_columns_local(metadata, p4_columns, aux_columns, weights)?;
1198        let events = dataset.local_event_data_arcs();
1199        Ok(Dataset::new_with_metadata(events, dataset.metadata))
1200    }
1201
1202    /// Create a new [`Dataset`] from a list of [`EventData`] (MPI-compatible version).
1203    ///
1204    /// # Notes
1205    ///
1206    /// This method is not intended to be called in analyses but rather in writing methods
1207    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::new`] instead.
1208    #[cfg(feature = "mpi")]
1209    pub fn new_mpi(
1210        events: Vec<Arc<EventData>>,
1211        metadata: Arc<DatasetMetadata>,
1212        world: &SimpleCommunicator,
1213    ) -> Self {
1214        let partitions = Dataset::partition(events, world);
1215        let local = &partitions[world.rank() as usize];
1216        let columnar = Self::columnar_from_events(local, metadata.clone())
1217            .expect("Dataset requires rectangular p4/aux columns for canonical columnar storage");
1218        let local_weighted_sum = local_weighted_sum(&columnar.weights);
1219        let mut dataset = Dataset {
1220            columnar: Arc::new(columnar),
1221            rows: RowSelection::Identity,
1222            metadata,
1223            cached_local_weighted_sum: local_weighted_sum,
1224            cached_global_event_count: 0,
1225            cached_global_weighted_sum: local_weighted_sum,
1226            mpi_layout: Some(MpiDatasetLayout::Canonical),
1227        };
1228        dataset.set_cached_global_event_count_from_world(world);
1229        dataset.set_cached_global_weighted_sum_from_world(world);
1230        dataset
1231    }
1232
1233    /// Create a new [`Dataset`] from a list of [`EventData`].
1234    ///
1235    /// This method is prefered for external use because it contains proper MPI construction
1236    /// methods. Constructing a [`Dataset`] manually is possible, but may cause issues when
1237    /// interfacing with MPI and should be avoided unless you know what you are doing.
1238    pub fn new(events: Vec<Arc<EventData>>) -> Self {
1239        Dataset::new_with_metadata(events, Arc::new(DatasetMetadata::default()))
1240    }
1241
1242    /// Create a dataset with explicit metadata for name-based lookups.
1243    /// Create a dataset with explicit metadata for name-based lookups.
1244    pub fn new_with_metadata(events: Vec<Arc<EventData>>, metadata: Arc<DatasetMetadata>) -> Self {
1245        #[cfg(feature = "mpi")]
1246        {
1247            if let Some(world) = crate::mpi::get_world() {
1248                return Dataset::new_mpi(events, metadata, &world);
1249            }
1250        }
1251        Dataset::new_local(events, metadata)
1252    }
1253
1254    fn push_event_data_local(&mut self, event_data: Arc<EventData>) -> LadduResult<()> {
1255        self.ensure_mutable_storage("push events")?;
1256        Arc::make_mut(&mut self.columnar).push_event_data(&event_data);
1257        self.refresh_local_weight_cache();
1258        Ok(())
1259    }
1260
1261    fn replace_metadata(&mut self, metadata: DatasetMetadata) {
1262        let metadata = Arc::new(metadata);
1263        self.metadata = metadata.clone();
1264        Arc::make_mut(&mut self.columnar).set_metadata(metadata);
1265    }
1266
1267    fn validate_p4_column_len(&self, name: &str, len: usize) -> LadduResult<()> {
1268        if len != self.n_events_local() {
1269            return Err(LadduError::LengthMismatch {
1270                context: format!("P4 column '{name}'"),
1271                expected: self.n_events_local(),
1272                actual: len,
1273            });
1274        }
1275        Ok(())
1276    }
1277
1278    fn validate_aux_column_len(&self, name: &str, len: usize) -> LadduResult<()> {
1279        if len != self.n_events_local() {
1280            return Err(LadduError::LengthMismatch {
1281                context: format!("Aux column '{name}'"),
1282                expected: self.n_events_local(),
1283                actual: len,
1284            });
1285        }
1286        Ok(())
1287    }
1288
1289    fn add_p4_column_unchecked(&mut self, name: String, values: Vec<Vec4>) -> LadduResult<()> {
1290        let mut metadata = (*self.metadata).clone();
1291        metadata.add_p4_name(name)?;
1292        Arc::make_mut(&mut self.columnar).push_p4_column(values);
1293        self.replace_metadata(metadata);
1294        Ok(())
1295    }
1296
1297    fn add_aux_column_unchecked(&mut self, name: String, values: Vec<f64>) -> LadduResult<()> {
1298        let mut metadata = (*self.metadata).clone();
1299        metadata.add_aux_name(name)?;
1300        Arc::make_mut(&mut self.columnar).push_aux_column(values);
1301        self.replace_metadata(metadata);
1302        Ok(())
1303    }
1304
1305    /// Add a four-momentum column to the current rank only.
1306    ///
1307    /// This method is non-collective. Under MPI it is only valid for datasets
1308    /// without an MPI layout; use [`Dataset::add_p4_column_global`] for shared
1309    /// MPI datasets.
1310    pub fn add_p4_column_local<N, V>(&mut self, name: N, values: V) -> LadduResult<()>
1311    where
1312        N: Into<String>,
1313        V: IntoIterator<Item = Vec4>,
1314    {
1315        self.ensure_mutable_storage("add a p4 column")?;
1316        #[cfg(feature = "mpi")]
1317        {
1318            if self.mpi_layout.is_some() {
1319                return Err(LadduError::Custom(
1320                    "Cannot add a local p4 column to an MPI dataset; use add_p4_column_global"
1321                        .to_string(),
1322                ));
1323            }
1324        }
1325        let name = name.into();
1326        let values = values.into_iter().collect::<Vec<_>>();
1327        self.metadata.ensure_new_p4_name(&name)?;
1328        self.validate_p4_column_len(&name, values.len())?;
1329        self.add_p4_column_unchecked(name, values)
1330    }
1331
1332    /// Add an auxiliary scalar column to the current rank only.
1333    ///
1334    /// This method is non-collective. Under MPI it is only valid for datasets
1335    /// without an MPI layout; use [`Dataset::add_aux_column_global`] for shared
1336    /// MPI datasets.
1337    pub fn add_aux_column_local<N, V>(&mut self, name: N, values: V) -> LadduResult<()>
1338    where
1339        N: Into<String>,
1340        V: IntoIterator<Item = f64>,
1341    {
1342        self.ensure_mutable_storage("add an aux column")?;
1343        #[cfg(feature = "mpi")]
1344        {
1345            if self.mpi_layout.is_some() {
1346                return Err(LadduError::Custom(
1347                    "Cannot add a local aux column to an MPI dataset; use add_aux_column_global"
1348                        .to_string(),
1349                ));
1350            }
1351        }
1352        let name = name.into();
1353        let values = values.into_iter().collect::<Vec<_>>();
1354        self.metadata.ensure_new_aux_name(&name)?;
1355        self.validate_aux_column_len(&name, values.len())?;
1356        self.add_aux_column_unchecked(name, values)
1357    }
1358
1359    /// Add a four-momentum column collectively across all MPI ranks.
1360    ///
1361    /// Under MPI, every rank must call this method in the same order with the
1362    /// same column name. Each rank supplies values for its local events only.
1363    pub fn add_p4_column_global<N, V>(&mut self, name: N, values: V) -> LadduResult<()>
1364    where
1365        N: Into<String>,
1366        V: IntoIterator<Item = Vec4>,
1367    {
1368        self.ensure_mutable_storage("add a p4 column")?;
1369        let name = name.into();
1370        let values = values.into_iter().collect::<Vec<_>>();
1371        #[cfg(feature = "mpi")]
1372        {
1373            if crate::mpi::get_world().is_some() {
1374                self.validate_global_column_add(
1375                    ColumnMutationKind::P4,
1376                    &name,
1377                    values.len() == self.n_events_local(),
1378                )?;
1379                self.metadata.ensure_new_p4_name(&name)?;
1380                self.validate_p4_column_len(&name, values.len())?;
1381                return self.add_p4_column_unchecked(name, values);
1382            }
1383        }
1384        self.add_p4_column_local(name, values)
1385    }
1386
1387    /// Add an auxiliary scalar column collectively across all MPI ranks.
1388    ///
1389    /// Under MPI, every rank must call this method in the same order with the
1390    /// same column name. Each rank supplies values for its local events only.
1391    pub fn add_aux_column_global<N, V>(&mut self, name: N, values: V) -> LadduResult<()>
1392    where
1393        N: Into<String>,
1394        V: IntoIterator<Item = f64>,
1395    {
1396        self.ensure_mutable_storage("add an aux column")?;
1397        let name = name.into();
1398        let values = values.into_iter().collect::<Vec<_>>();
1399        #[cfg(feature = "mpi")]
1400        {
1401            if crate::mpi::get_world().is_some() {
1402                self.validate_global_column_add(
1403                    ColumnMutationKind::Aux,
1404                    &name,
1405                    values.len() == self.n_events_local(),
1406                )?;
1407                self.metadata.ensure_new_aux_name(&name)?;
1408                self.validate_aux_column_len(&name, values.len())?;
1409                return self.add_aux_column_unchecked(name, values);
1410            }
1411        }
1412        self.add_aux_column_local(name, values)
1413    }
1414
1415    /// Append one ordered event row to the current rank.
1416    ///
1417    /// `p4s` and `aux` must be ordered to match [`Dataset::p4_names`] and
1418    /// [`Dataset::aux_names`].
1419    ///
1420    /// Under MPI, this method performs no communication beyond refreshing
1421    /// cached global counts. Calling it on every rank appends one row per rank.
1422    pub fn push_event_local<P, A>(&mut self, p4s: P, aux: A, weight: f64) -> LadduResult<()>
1423    where
1424        P: IntoIterator<Item = Vec4>,
1425        A: IntoIterator<Item = f64>,
1426    {
1427        self.ensure_mutable_storage("push events")?;
1428        #[cfg(feature = "mpi")]
1429        {
1430            if self.mpi_layout == Some(MpiDatasetLayout::RoundRobin) && self.n_events() > 0 {
1431                return Err(LadduError::Custom(
1432                    "Cannot push local events into a round-robin global dataset".to_string(),
1433                ));
1434            }
1435            self.mpi_layout = None;
1436        }
1437        let p4s = p4s.into_iter().collect::<Vec<_>>();
1438        let aux = aux.into_iter().collect::<Vec<_>>();
1439        if p4s.len() != self.metadata.p4_names().len() {
1440            return Err(LadduError::Custom(format!(
1441                "Expected {} p4 values, got {}",
1442                self.metadata.p4_names().len(),
1443                p4s.len()
1444            )));
1445        }
1446        if aux.len() != self.metadata.aux_names().len() {
1447            return Err(LadduError::Custom(format!(
1448                "Expected {} aux values, got {}",
1449                self.metadata.aux_names().len(),
1450                aux.len()
1451            )));
1452        }
1453
1454        let event_data = Arc::new(EventData { p4s, aux, weight });
1455        self.push_event_data_local(event_data)
1456    }
1457
1458    /// Append one ordered event row collectively as a single global event.
1459    ///
1460    /// Under MPI, this method is collective. Exactly one rank stores the event,
1461    /// selected by `next_global_index % n_ranks`; non-owning ranks ignore their
1462    /// supplied row values. All ranks must call this method in the same order.
1463    pub fn push_event_global<P, A>(&mut self, p4s: P, aux: A, weight: f64) -> LadduResult<()>
1464    where
1465        P: IntoIterator<Item = Vec4>,
1466        A: IntoIterator<Item = f64>,
1467    {
1468        self.ensure_mutable_storage("push events")?;
1469        let p4s = p4s.into_iter().collect::<Vec<_>>();
1470        let aux = aux.into_iter().collect::<Vec<_>>();
1471        if p4s.len() != self.metadata.p4_names().len() {
1472            return Err(LadduError::Custom(format!(
1473                "Expected {} p4 values, got {}",
1474                self.metadata.p4_names().len(),
1475                p4s.len()
1476            )));
1477        }
1478        if aux.len() != self.metadata.aux_names().len() {
1479            return Err(LadduError::Custom(format!(
1480                "Expected {} aux values, got {}",
1481                self.metadata.aux_names().len(),
1482                aux.len()
1483            )));
1484        }
1485
1486        #[cfg(feature = "mpi")]
1487        {
1488            if let Some(world) = crate::mpi::get_world() {
1489                if self.mpi_layout != Some(MpiDatasetLayout::RoundRobin) && self.n_events() > 0 {
1490                    return Err(LadduError::Custom(
1491                        "Cannot push round-robin global events into a non-empty local/canonical dataset"
1492                            .to_string(),
1493                    ));
1494                }
1495                self.mpi_layout = Some(MpiDatasetLayout::RoundRobin);
1496                let global_index = self.n_events();
1497                if global_index % world.size() as usize == world.rank() as usize {
1498                    self.push_event_data_local(Arc::new(EventData { p4s, aux, weight }))?;
1499                } else {
1500                    self.refresh_local_weight_cache();
1501                }
1502                return Ok(());
1503            }
1504        }
1505
1506        self.push_event_data_local(Arc::new(EventData { p4s, aux, weight }))
1507    }
1508
1509    /// Append one named event row to the current rank.
1510    ///
1511    /// The supplied p4 and aux names must exactly match this dataset's metadata, regardless of
1512    /// order. Duplicate, missing, and unknown names are rejected.
1513    pub fn push_event_named_local<P, PN, A, AN>(
1514        &mut self,
1515        p4s: P,
1516        aux: A,
1517        weight: f64,
1518    ) -> LadduResult<()>
1519    where
1520        P: IntoIterator<Item = (PN, Vec4)>,
1521        PN: AsRef<str>,
1522        A: IntoIterator<Item = (AN, f64)>,
1523        AN: AsRef<str>,
1524    {
1525        let mut ordered_p4s = vec![None; self.metadata.p4_names().len()];
1526        for (name, p4) in p4s {
1527            let name = name.as_ref();
1528            let index = self
1529                .metadata
1530                .p4_index(name)
1531                .ok_or_else(|| LadduError::UnknownName {
1532                    category: "p4",
1533                    name: name.to_string(),
1534                })?;
1535            if ordered_p4s[index].replace(p4).is_some() {
1536                return Err(LadduError::DuplicateName {
1537                    category: "p4",
1538                    name: name.to_string(),
1539                });
1540            }
1541        }
1542        let mut ordered_aux = vec![None; self.metadata.aux_names().len()];
1543        for (name, value) in aux {
1544            let name = name.as_ref();
1545            let index = self
1546                .metadata
1547                .aux_index(name)
1548                .ok_or_else(|| LadduError::UnknownName {
1549                    category: "aux",
1550                    name: name.to_string(),
1551                })?;
1552            if ordered_aux[index].replace(value).is_some() {
1553                return Err(LadduError::DuplicateName {
1554                    category: "aux",
1555                    name: name.to_string(),
1556                });
1557            }
1558        }
1559
1560        let p4s = ordered_p4s
1561            .into_iter()
1562            .enumerate()
1563            .map(|(index, value)| {
1564                value.ok_or_else(|| {
1565                    LadduError::Custom(format!(
1566                        "Missing p4 value for '{}'",
1567                        self.metadata.p4_names()[index]
1568                    ))
1569                })
1570            })
1571            .collect::<LadduResult<Vec<_>>>()?;
1572        let aux = ordered_aux
1573            .into_iter()
1574            .enumerate()
1575            .map(|(index, value)| {
1576                value.ok_or_else(|| {
1577                    LadduError::Custom(format!(
1578                        "Missing aux value for '{}'",
1579                        self.metadata.aux_names()[index]
1580                    ))
1581                })
1582            })
1583            .collect::<LadduResult<Vec<_>>>()?;
1584
1585        self.push_event_local(p4s, aux, weight)
1586    }
1587
1588    /// Append one named event row collectively as a single global event.
1589    ///
1590    /// Under MPI, this method is collective. Exactly one rank stores the event,
1591    /// selected by `next_global_index % n_ranks`; non-owning ranks ignore their
1592    /// supplied row values. All ranks must call this method in the same order.
1593    pub fn push_event_named_global<P, PN, A, AN>(
1594        &mut self,
1595        p4s: P,
1596        aux: A,
1597        weight: f64,
1598    ) -> LadduResult<()>
1599    where
1600        P: IntoIterator<Item = (PN, Vec4)>,
1601        PN: AsRef<str>,
1602        A: IntoIterator<Item = (AN, f64)>,
1603        AN: AsRef<str>,
1604    {
1605        let mut ordered_p4s = vec![None; self.metadata.p4_names().len()];
1606        for (name, p4) in p4s {
1607            let name = name.as_ref();
1608            let index = self
1609                .metadata
1610                .p4_index(name)
1611                .ok_or_else(|| LadduError::UnknownName {
1612                    category: "p4",
1613                    name: name.to_string(),
1614                })?;
1615            if ordered_p4s[index].replace(p4).is_some() {
1616                return Err(LadduError::DuplicateName {
1617                    category: "p4",
1618                    name: name.to_string(),
1619                });
1620            }
1621        }
1622        let mut ordered_aux = vec![None; self.metadata.aux_names().len()];
1623        for (name, value) in aux {
1624            let name = name.as_ref();
1625            let index = self
1626                .metadata
1627                .aux_index(name)
1628                .ok_or_else(|| LadduError::UnknownName {
1629                    category: "aux",
1630                    name: name.to_string(),
1631                })?;
1632            if ordered_aux[index].replace(value).is_some() {
1633                return Err(LadduError::DuplicateName {
1634                    category: "aux",
1635                    name: name.to_string(),
1636                });
1637            }
1638        }
1639
1640        let p4s = ordered_p4s
1641            .into_iter()
1642            .enumerate()
1643            .map(|(index, value)| {
1644                value.ok_or_else(|| {
1645                    LadduError::Custom(format!(
1646                        "Missing p4 value for '{}'",
1647                        self.metadata.p4_names()[index]
1648                    ))
1649                })
1650            })
1651            .collect::<LadduResult<Vec<_>>>()?;
1652        let aux = ordered_aux
1653            .into_iter()
1654            .enumerate()
1655            .map(|(index, value)| {
1656                value.ok_or_else(|| {
1657                    LadduError::Custom(format!(
1658                        "Missing aux value for '{}'",
1659                        self.metadata.aux_names()[index]
1660                    ))
1661                })
1662            })
1663            .collect::<LadduResult<Vec<_>>>()?;
1664
1665        self.push_event_global(p4s, aux, weight)
1666    }
1667
1668    /// The number of [`EventData`]s in the [`Dataset`] (non-MPI version).
1669    ///
1670    /// # Notes
1671    ///
1672    /// This method is not intended to be called in analyses but rather in writing methods
1673    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::n_events`] instead.
1674    pub fn n_events_local(&self) -> usize {
1675        self.rows.len(self.columnar.n_events())
1676    }
1677
1678    /// The number of [`EventData`]s in the [`Dataset`] (MPI-compatible version).
1679    ///
1680    /// # Notes
1681    ///
1682    /// This method is not intended to be called in analyses but rather in writing methods
1683    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::n_events`] instead.
1684    #[cfg(feature = "mpi")]
1685    pub fn n_events_mpi(&self, _world: &SimpleCommunicator) -> usize {
1686        self.cached_global_event_count
1687    }
1688
1689    /// The number of [`EventData`]s in the [`Dataset`].
1690    pub fn n_events(&self) -> usize {
1691        #[cfg(feature = "mpi")]
1692        {
1693            if self.mpi_layout.is_some() {
1694                if let Some(world) = crate::mpi::get_world() {
1695                    return self.n_events_mpi(&world);
1696                }
1697            }
1698        }
1699        self.n_events_local()
1700    }
1701
1702    /// Alias for [`Dataset::n_events`].
1703    ///
1704    /// This returns the global event count under MPI.
1705    pub fn n_events_global(&self) -> usize {
1706        self.n_events()
1707    }
1708}
1709
1710impl Dataset {
1711    /// Extract a list of weights over each [`EventData`] in the [`Dataset`] (non-MPI version).
1712    ///
1713    /// # Notes
1714    ///
1715    /// This method is not intended to be called in analyses but rather in writing methods
1716    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::weights`] instead.
1717    pub fn weights_local(&self) -> Vec<f64> {
1718        match &self.rows {
1719            RowSelection::Identity => self.columnar.weights.clone(),
1720            RowSelection::Indices(indices) => indices
1721                .iter()
1722                .map(|index| self.columnar.weight(*index))
1723                .collect(),
1724        }
1725    }
1726
1727    /// Extract a list of weights over each [`EventData`] in the [`Dataset`] (MPI-compatible version).
1728    ///
1729    /// # Notes
1730    ///
1731    /// This method is not intended to be called in analyses but rather in writing methods
1732    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::weights`] instead.
1733    #[cfg(feature = "mpi")]
1734    pub fn weights_mpi(&self, world: &SimpleCommunicator) -> Vec<f64> {
1735        if matches!(
1736            self.mpi_layout,
1737            Some(MpiDatasetLayout::RoundRobin | MpiDatasetLayout::Derived)
1738        ) {
1739            return self.events_global().map(|event| event.weight()).collect();
1740        }
1741        let local_weights = self.weights_local();
1742        let n_events = self.n_events();
1743        let mut buffer: Vec<f64> = vec![0.0; n_events];
1744        let (counts, displs) = world.get_counts_displs(n_events);
1745        {
1746            // NOTE: gather is required because this API returns full global event weights.
1747            // Use all-reduce only for scalar/vector aggregate values.
1748            let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
1749            world.all_gather_varcount_into(&local_weights, &mut partitioned_buffer);
1750        }
1751        buffer
1752    }
1753
1754    /// Extract a list of weights over each [`EventData`] in the [`Dataset`].
1755    pub fn weights(&self) -> Vec<f64> {
1756        #[cfg(feature = "mpi")]
1757        {
1758            if self.mpi_layout.is_some() {
1759                if let Some(world) = crate::mpi::get_world() {
1760                    return self.weights_mpi(&world);
1761                }
1762            }
1763        }
1764        self.weights_local()
1765    }
1766
1767    /// Alias for [`Dataset::weights`].
1768    ///
1769    /// This returns the global weight vector in dataset order under MPI.
1770    pub fn weights_global(&self) -> Vec<f64> {
1771        self.weights()
1772    }
1773
1774    /// Returns the sum of the weights for each [`EventData`] in the [`Dataset`] (non-MPI version).
1775    ///
1776    /// # Notes
1777    ///
1778    /// This method is not intended to be called in analyses but rather in writing methods
1779    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::n_events_weighted`] instead.
1780    pub fn n_events_weighted_local(&self) -> f64 {
1781        self.cached_local_weighted_sum
1782    }
1783    /// Returns the sum of the weights for each [`EventData`] in the [`Dataset`] (MPI-compatible version).
1784    ///
1785    /// # Notes
1786    ///
1787    /// This method is not intended to be called in analyses but rather in writing methods
1788    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::n_events_weighted`] instead.
1789    #[cfg(feature = "mpi")]
1790    pub fn n_events_weighted_mpi(&self, _world: &SimpleCommunicator) -> f64 {
1791        self.cached_global_weighted_sum
1792    }
1793
1794    /// Returns the sum of the weights for each [`EventData`] in the [`Dataset`].
1795    pub fn n_events_weighted(&self) -> f64 {
1796        #[cfg(feature = "mpi")]
1797        {
1798            if self.mpi_layout.is_some() {
1799                if let Some(world) = crate::mpi::get_world() {
1800                    return self.n_events_weighted_mpi(&world);
1801                }
1802            }
1803        }
1804        self.n_events_weighted_local()
1805    }
1806
1807    /// Alias for [`Dataset::n_events_weighted`].
1808    ///
1809    /// This returns the global weighted event count under MPI.
1810    pub fn n_events_weighted_global(&self) -> f64 {
1811        self.n_events_weighted()
1812    }
1813
1814    /// Generate a new dataset with the same length by resampling the events in the original datset
1815    /// with replacement. This can be used to perform error analysis via the bootstrap method. (non-MPI version).
1816    ///
1817    /// # Notes
1818    ///
1819    /// This method is not intended to be called in analyses but rather in writing methods
1820    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::bootstrap`] instead.
1821    pub fn bootstrap_local(&self, seed: usize) -> Arc<Dataset> {
1822        let mut rng = fastrand::Rng::with_seed(seed as u64);
1823        let n_events = self.n_events_local();
1824        let mut indices: Vec<usize> = (0..n_events)
1825            .map(|_| rng.usize(0..n_events))
1826            .collect::<Vec<usize>>();
1827        indices.sort();
1828        self.indexed_local_view(
1829            indices
1830                .into_iter()
1831                .map(|index| self.rows.physical_index(index)),
1832        )
1833    }
1834
1835    /// Generate a new dataset with the same length by resampling the events in the original datset
1836    /// with replacement. This can be used to perform error analysis via the bootstrap method. (MPI-compatible version).
1837    ///
1838    /// # Notes
1839    ///
1840    /// This method is not intended to be called in analyses but rather in writing methods
1841    /// that have `mpi`-feature-gated versions. Most users should just call [`Dataset::bootstrap`] instead.
1842    #[cfg(feature = "mpi")]
1843    pub fn bootstrap_mpi(&self, seed: usize, world: &SimpleCommunicator) -> Arc<Dataset> {
1844        let n_events = self.n_events();
1845        let mut indices: Vec<usize> = vec![0; n_events];
1846        if world.is_root() {
1847            let mut rng = fastrand::Rng::with_seed(seed as u64);
1848            indices = (0..n_events)
1849                .map(|_| rng.usize(0..n_events))
1850                .collect::<Vec<usize>>();
1851            indices.sort();
1852        }
1853        world.process_at_root().broadcast_into(&mut indices);
1854        let local_indices: Vec<usize> = indices
1855            .into_iter()
1856            .filter_map(|idx| {
1857                let (owning_rank, local_index) = world.owner_of_global_index(idx, n_events);
1858                if world.rank() == owning_rank {
1859                    Some(local_index)
1860                } else {
1861                    None
1862                }
1863            })
1864            .collect();
1865        self.indexed_local_view(
1866            local_indices
1867                .into_iter()
1868                .map(|index| self.rows.physical_index(index)),
1869        )
1870    }
1871
1872    /// Generate a new dataset with the same length by resampling the events in the original datset
1873    /// with replacement. This can be used to perform error analysis via the bootstrap method.
1874    pub fn bootstrap(&self, seed: usize) -> Arc<Dataset> {
1875        #[cfg(feature = "mpi")]
1876        {
1877            if let Some(world) = crate::mpi::get_world() {
1878                return self.bootstrap_mpi(seed, &world);
1879            }
1880        }
1881        self.bootstrap_local(seed)
1882    }
1883
1884    /// Filter the [`Dataset`] by a given [`VariableExpression`], selecting events for which
1885    /// the expression returns `true`.
1886    pub fn filter(&self, expression: &VariableExpression) -> LadduResult<Arc<Dataset>> {
1887        let compiled = expression.compile(&self.metadata)?;
1888        #[cfg(feature = "rayon")]
1889        let filtered_indices: Vec<usize> = (0..self.n_events_local())
1890            .into_par_iter()
1891            .filter_map(|event_index| {
1892                let event = self.event_view(event_index);
1893                compiled
1894                    .evaluate(&event)
1895                    .then(|| self.rows.physical_index(event_index))
1896            })
1897            .collect();
1898        #[cfg(not(feature = "rayon"))]
1899        let filtered_indices: Vec<usize> = (0..self.n_events_local())
1900            .into_iter()
1901            .filter_map(|event_index| {
1902                let event = self.event_view(event_index);
1903                compiled
1904                    .evaluate(&event)
1905                    .then(|| self.rows.physical_index(event_index))
1906            })
1907            .collect();
1908        Ok(self.indexed_local_view(filtered_indices))
1909    }
1910
1911    /// Bin a [`Dataset`] by the value of the given [`Variable`] into a number of `bins` within the
1912    /// given `range`.
1913    pub fn bin_by<V>(
1914        &self,
1915        mut variable: V,
1916        bins: usize,
1917        range: (f64, f64),
1918    ) -> LadduResult<BinnedDataset>
1919    where
1920        V: Variable,
1921    {
1922        variable.bind(self.metadata())?;
1923        let bin_width = (range.1 - range.0) / bins as f64;
1924        let bin_edges = get_bin_edges(bins, range);
1925        let variable = variable;
1926        #[cfg(feature = "rayon")]
1927        let evaluated: Vec<(usize, usize)> = (0..self.n_events_local())
1928            .into_par_iter()
1929            .filter_map(|event| {
1930                let value = variable.value(&self.event_view(event));
1931                if value >= range.0 && value < range.1 {
1932                    let bin_index = ((value - range.0) / bin_width) as usize;
1933                    let bin_index = bin_index.min(bins - 1);
1934                    Some((bin_index, self.rows.physical_index(event)))
1935                } else {
1936                    None
1937                }
1938            })
1939            .collect();
1940        #[cfg(not(feature = "rayon"))]
1941        let evaluated: Vec<(usize, usize)> = (0..self.n_events_local())
1942            .into_iter()
1943            .filter_map(|event| {
1944                let value = variable.value(&self.event_view(event));
1945                if value >= range.0 && value < range.1 {
1946                    let bin_index = ((value - range.0) / bin_width) as usize;
1947                    let bin_index = bin_index.min(bins - 1);
1948                    Some((bin_index, self.rows.physical_index(event)))
1949                } else {
1950                    None
1951                }
1952            })
1953            .collect();
1954        let mut binned_indices: Vec<Vec<usize>> = vec![Vec::default(); bins];
1955        for (bin_index, index) in evaluated {
1956            binned_indices[bin_index].push(index);
1957        }
1958        #[cfg(feature = "rayon")]
1959        let datasets: Vec<Arc<Dataset>> = binned_indices
1960            .into_par_iter()
1961            .map(|indices| self.indexed_local_view(indices))
1962            .collect();
1963        #[cfg(not(feature = "rayon"))]
1964        let datasets: Vec<Arc<Dataset>> = binned_indices
1965            .into_iter()
1966            .map(|indices| self.indexed_local_view(indices))
1967            .collect();
1968        Ok(BinnedDataset {
1969            datasets,
1970            edges: bin_edges,
1971        })
1972    }
1973
1974    /// Boost all the four-momenta in all [`EventData`]s to the rest frame of the given set of
1975    /// four-momenta identified by name.
1976    pub fn boost_to_rest_frame_of<S>(&self, names: &[S]) -> Arc<Dataset>
1977    where
1978        S: AsRef<str>,
1979    {
1980        let mut indices: Vec<usize> = Vec::new();
1981        for name in names {
1982            let name_ref = name.as_ref();
1983            if let Some(selection) = self.metadata.p4_selection(name_ref) {
1984                indices.extend_from_slice(selection.indices());
1985            } else {
1986                panic!("Unknown particle name '{name}'", name = name_ref);
1987            }
1988        }
1989        #[cfg(feature = "rayon")]
1990        let boosted_events: Vec<Arc<EventData>> = self
1991            .local_event_data_arcs()
1992            .into_par_iter()
1993            .map(|event| Arc::new(event.boost_to_rest_frame_of(&indices)))
1994            .collect();
1995        #[cfg(not(feature = "rayon"))]
1996        let boosted_events: Vec<Arc<EventData>> = self
1997            .local_event_data_arcs()
1998            .into_iter()
1999            .map(|event| Arc::new(event.boost_to_rest_frame_of(&indices)))
2000            .collect();
2001        Arc::new(Dataset::new_with_metadata(
2002            boosted_events,
2003            self.metadata.clone(),
2004        ))
2005    }
2006    /// Evaluate a [`Variable`] on every event in the [`Dataset`].
2007    pub fn evaluate<V: Variable>(&self, variable: &V) -> LadduResult<Vec<f64>> {
2008        variable.value_on(self)
2009    }
2010}
2011
2012#[cfg(test)]
2013pub(crate) use super::io::write_parquet_storage;
2014pub use super::io::{
2015    read_parquet, read_parquet_chunks, read_parquet_chunks_with_options, read_root, write_parquet,
2016    write_root,
2017};
2018#[cfg(test)]
2019pub(crate) use super::io::{read_parquet_storage, read_root_storage};
2020
2021impl_op_ex!(+ |a: &Dataset, b: &Dataset| -> Dataset {
2022    debug_assert_eq!(a.metadata.p4_names, b.metadata.p4_names);
2023    debug_assert_eq!(a.metadata.aux_names, b.metadata.aux_names);
2024    let events = a
2025        .local_event_data_arcs()
2026        .into_iter()
2027        .chain(b.local_event_data_arcs())
2028        .collect::<Vec<_>>();
2029    Dataset::new_with_metadata(events, a.metadata.clone())
2030});
2031
2032/// Incrementally builds a [`Dataset`] from chunked dataset reads.
2033#[derive(Default)]
2034pub struct DatasetChunkBuilder {
2035    metadata: Option<Arc<DatasetMetadata>>,
2036    events: Vec<Arc<EventData>>,
2037}
2038
2039impl DatasetChunkBuilder {
2040    /// Create an empty chunk builder.
2041    pub fn new() -> Self {
2042        Self::default()
2043    }
2044
2045    /// Append a dataset chunk.
2046    pub fn push_chunk(&mut self, chunk: &Dataset) -> LadduResult<()> {
2047        if let Some(existing) = &self.metadata {
2048            if existing.p4_names != chunk.metadata.p4_names
2049                || existing.aux_names != chunk.metadata.aux_names
2050            {
2051                return Err(LadduError::Custom(
2052                    "Dataset chunk metadata does not match previous chunks".to_string(),
2053                ));
2054            }
2055        } else {
2056            self.metadata = Some(chunk.metadata.clone());
2057        }
2058        self.events.extend(chunk.local_event_data_arcs());
2059        Ok(())
2060    }
2061
2062    /// Finish building a dataset from all received chunks.
2063    pub fn finish(self) -> Arc<Dataset> {
2064        let metadata = self
2065            .metadata
2066            .unwrap_or_else(|| Arc::new(DatasetMetadata::empty()));
2067        Arc::new(Dataset::new_with_metadata(self.events, metadata))
2068    }
2069}
2070
2071/// Fold over chunked datasets without materializing a full dataset.
2072pub fn try_fold_dataset_chunks<I, T, F>(chunks: I, init: T, mut op: F) -> LadduResult<T>
2073where
2074    I: IntoIterator<Item = LadduResult<Arc<Dataset>>>,
2075    F: FnMut(T, &Dataset) -> LadduResult<T>,
2076{
2077    let mut acc = init;
2078    for chunk in chunks {
2079        let chunk = chunk?;
2080        acc = op(acc, &chunk)?;
2081    }
2082    Ok(acc)
2083}
2084
2085/// Options for reading a [`Dataset`] from a file.
2086///
2087/// # See Also
2088/// [`read_parquet`], [`read_root`]
2089#[derive(Default, Clone)]
2090pub struct DatasetReadOptions {
2091    /// Particle names to read from the data file.
2092    pub p4_names: Option<Vec<String>>,
2093    /// Auxiliary scalar names to read from the data file.
2094    pub aux_names: Option<Vec<String>>,
2095    /// Name of the tree to read when loading ROOT files. When absent and the file contains a
2096    /// single tree, it will be selected automatically.
2097    pub tree: Option<String>,
2098    /// Optional aliases mapping logical names to selections of four-momenta.
2099    pub aliases: IndexMap<String, P4Selection>,
2100    /// Preferred chunk size for chunked read APIs.
2101    pub chunk_size: Option<usize>,
2102}
2103
/// Floating-point width used when writing floating-point columns.
///
/// The default is [`FloatPrecision::F64`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum FloatPrecision {
    /// Write 32-bit floats.
    F32,
    /// Write 64-bit floats (the default).
    #[default]
    F64,
}
2113
2114/// Options for writing a [`Dataset`] to disk.
2115#[derive(Clone, Debug)]
2116pub struct DatasetWriteOptions {
2117    /// Number of events to include in each batch when writing.
2118    pub batch_size: usize,
2119    /// Floating-point precision to use for persisted columns.
2120    pub precision: FloatPrecision,
2121    /// Tree name to use when writing ROOT files.
2122    pub tree: Option<String>,
2123}
2124
2125impl Default for DatasetWriteOptions {
2126    fn default() -> Self {
2127        Self {
2128            batch_size: DEFAULT_WRITE_BATCH_SIZE,
2129            precision: FloatPrecision::default(),
2130            tree: None,
2131        }
2132    }
2133}
2134
2135impl DatasetWriteOptions {
2136    /// Override the batch size used for writing; defaults to 10_000.
2137    pub fn batch_size(mut self, batch_size: usize) -> Self {
2138        self.batch_size = batch_size;
2139        self
2140    }
2141
2142    /// Select the floating-point precision for persisted columns.
2143    pub fn precision(mut self, precision: FloatPrecision) -> Self {
2144        self.precision = precision;
2145        self
2146    }
2147
2148    /// Set the ROOT tree name (defaults to \"events\").
2149    pub fn tree<S: Into<String>>(mut self, name: S) -> Self {
2150        self.tree = Some(name.into());
2151        self
2152    }
2153}
2154impl DatasetReadOptions {
2155    /// Create a new [`Default`] set of [`DatasetReadOptions`].
2156    pub fn new() -> Self {
2157        Self::default()
2158    }
2159
2160    /// If provided, the specified particles will be read from the data file (assuming columns with
2161    /// required suffixes are present, i.e. `<particle>_px`, `<particle>_py`, `<particle>_pz`, and `<particle>_e`). Otherwise, all valid columns with these suffixes will be read.
2162    pub fn p4_names<I, S>(mut self, names: I) -> Self
2163    where
2164        I: IntoIterator<Item = S>,
2165        S: AsRef<str>,
2166    {
2167        self.p4_names = Some(names.into_iter().map(|s| s.as_ref().to_string()).collect());
2168        self
2169    }
2170
2171    /// If provided, the specified columns will be read as auxiliary scalars. Otherwise, all valid
2172    /// columns which do not satisfy the conditions required to be read as four-momenta will be
2173    /// used.
2174    pub fn aux_names<I, S>(mut self, names: I) -> Self
2175    where
2176        I: IntoIterator<Item = S>,
2177        S: AsRef<str>,
2178    {
2179        self.aux_names = Some(names.into_iter().map(|s| s.as_ref().to_string()).collect());
2180        self
2181    }
2182
2183    /// Select the tree to read when opening ROOT files.
2184    pub fn tree<S>(mut self, name: S) -> Self
2185    where
2186        S: AsRef<str>,
2187    {
2188        self.tree = Some(name.as_ref().to_string());
2189        self
2190    }
2191
2192    /// Register an alias for one or more existing four-momenta.
2193    pub fn alias<N, S>(mut self, name: N, selection: S) -> Self
2194    where
2195        N: Into<String>,
2196        S: IntoP4Selection,
2197    {
2198        self.aliases.insert(name.into(), selection.into_selection());
2199        self
2200    }
2201
2202    /// Register multiple aliases for four-momenta selections.
2203    pub fn aliases<I, N, S>(mut self, aliases: I) -> Self
2204    where
2205        I: IntoIterator<Item = (N, S)>,
2206        N: Into<String>,
2207        S: IntoP4Selection,
2208    {
2209        for (name, selection) in aliases {
2210            self = self.alias(name, selection);
2211        }
2212        self
2213    }
2214
2215    /// Set the chunk size used by chunked read APIs; values below 1 are clamped to 1.
2216    pub fn chunk_size(mut self, chunk_size: usize) -> Self {
2217        self.chunk_size = Some(chunk_size.max(1));
2218        self
2219    }
2220
2221    pub(crate) fn resolve_metadata(
2222        &self,
2223        detected_p4_names: Vec<String>,
2224        detected_aux_names: Vec<String>,
2225    ) -> LadduResult<Arc<DatasetMetadata>> {
2226        let p4_names_vec = self.p4_names.clone().unwrap_or(detected_p4_names);
2227        let aux_names_vec = self.aux_names.clone().unwrap_or(detected_aux_names);
2228
2229        let mut metadata = DatasetMetadata::new(p4_names_vec, aux_names_vec)?;
2230        if !self.aliases.is_empty() {
2231            metadata.add_p4_aliases(self.aliases.clone())?;
2232        }
2233        Ok(Arc::new(metadata))
2234    }
2235}
2236
/// Batch size used by [`DatasetWriteOptions::default`].
const DEFAULT_WRITE_BATCH_SIZE: usize = 10_000;
/// Default chunk size for reads; presumably the fallback when
/// [`DatasetReadOptions::chunk_size`] is unset (its use is not visible here; confirm in the
/// I/O module).
pub(crate) const DEFAULT_READ_CHUNK_SIZE: usize = 10_000;
2239
2240/// A list of [`Dataset`]s formed by binning [`EventData`] by some [`Variable`].
2241pub struct BinnedDataset {
2242    datasets: Vec<Arc<Dataset>>,
2243    edges: Vec<f64>,
2244}
2245
2246impl Index<usize> for BinnedDataset {
2247    type Output = Arc<Dataset>;
2248
2249    fn index(&self, index: usize) -> &Self::Output {
2250        &self.datasets[index]
2251    }
2252}
2253
2254impl IndexMut<usize> for BinnedDataset {
2255    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
2256        &mut self.datasets[index]
2257    }
2258}
2259
2260impl Deref for BinnedDataset {
2261    type Target = Vec<Arc<Dataset>>;
2262
2263    fn deref(&self) -> &Self::Target {
2264        &self.datasets
2265    }
2266}
2267
2268impl DerefMut for BinnedDataset {
2269    fn deref_mut(&mut self) -> &mut Self::Target {
2270        &mut self.datasets
2271    }
2272}
2273
2274impl BinnedDataset {
2275    /// The number of bins in the [`BinnedDataset`].
2276    pub fn n_bins(&self) -> usize {
2277        self.datasets.len()
2278    }
2279
2280    /// Returns a list of the bin edges that were used to form the [`BinnedDataset`].
2281    pub fn edges(&self) -> Vec<f64> {
2282        self.edges.clone()
2283    }
2284
2285    /// Returns the range that was used to form the [`BinnedDataset`].
2286    pub fn range(&self) -> (f64, f64) {
2287        (self.edges[0], self.edges[self.n_bins()])
2288    }
2289}