1use std::{
2 borrow::Cow,
3 ops::{Deref, DerefMut, Index, IndexMut},
4 sync::Arc,
5};
6
7use accurate::{sum::Klein, traits::*};
8use auto_ops::impl_op_ex;
9#[cfg(feature = "mpi")]
10use mpi::{datatype::PartitionMut, topology::SimpleCommunicator, traits::*};
11#[cfg(feature = "rayon")]
12use rayon::prelude::*;
13
14use super::{
15 event::{test_event, ColumnarP4Column, DatasetStorage, Event, EventData, OwnedEvent},
16 metadata::DatasetMetadata,
17};
18#[cfg(feature = "mpi")]
19use crate::mpi::LadduMPI;
20
/// Communicator handle type used by MPI-aware code; with the `mpi` feature enabled this is
/// [`SimpleCommunicator`].
#[cfg(feature = "mpi")]
pub(crate) type WorldHandle = SimpleCommunicator;
/// Placeholder handle type (the unit type) used when the `mpi` feature is disabled.
#[cfg(not(feature = "mpi"))]
pub(crate) type WorldHandle = ();
25
/// Chunk size used for MPI event fetches when the environment variable named by
/// `MPI_EVENT_FETCH_CHUNK_SIZE_ENV` is unset or cannot be parsed as a `usize`.
#[cfg(feature = "mpi")]
const DEFAULT_MPI_EVENT_FETCH_CHUNK_SIZE: usize = 512;
/// Name of the environment variable that overrides the MPI event fetch chunk size.
#[cfg(feature = "mpi")]
const MPI_EVENT_FETCH_CHUNK_SIZE_ENV: &str = "LADDU_MPI_EVENT_FETCH_CHUNK_SIZE";
32
33use indexmap::IndexMap;
34
35use crate::{
36 math::get_bin_edges,
37 variables::{IntoP4Selection, P4Selection, Variable, VariableExpression},
38 vectors::Vec4,
39 LadduError, LadduResult,
40};
41
/// Four-momentum column names used to build the metadata of [`test_dataset`].
const TEST_P4_NAMES: &[&str] = &["beam", "proton", "kshort1", "kshort2"];
/// Auxiliary column names used to build the metadata of [`test_dataset`].
const TEST_AUX_NAMES: &[&str] = &["pol_magnitude", "pol_angle"];
44
45fn local_weighted_sum(weights: &[f64]) -> f64 {
46 #[cfg(feature = "rayon")]
47 {
48 weights
49 .par_iter()
50 .copied()
51 .parallel_sum_with_accumulator::<Klein<f64>>()
52 }
53 #[cfg(not(feature = "rayon"))]
54 {
55 weights.iter().copied().sum_with_accumulator::<Klein<f64>>()
56 }
57}
58
59pub fn test_dataset() -> Dataset {
63 let metadata = Arc::new(
64 DatasetMetadata::new(
65 TEST_P4_NAMES.iter().map(|s| (*s).to_string()).collect(),
66 TEST_AUX_NAMES.iter().map(|s| (*s).to_string()).collect(),
67 )
68 .expect("Test metadata should be valid"),
69 );
70 Dataset::new_with_metadata(vec![Arc::new(test_event())], metadata)
71}
72
/// A collection of events stored column-wise, optionally viewed through a row selection.
#[derive(Debug, Clone)]
pub struct Dataset {
    /// Shared columnar storage backing this dataset.
    pub(crate) columnar: Arc<DatasetStorage>,
    /// Maps this dataset's logical row indices onto rows of `columnar`.
    rows: RowSelection,
    /// Shared metadata (p4 and aux column names) for this dataset.
    pub(crate) metadata: Arc<DatasetMetadata>,
    /// Cached sum of the weights of the locally selected rows.
    pub(crate) cached_local_weighted_sum: f64,
    /// Cached event count; holds the sum over all ranks once
    /// `set_cached_global_event_count_from_world` has run.
    #[cfg(feature = "mpi")]
    pub(crate) cached_global_event_count: usize,
    /// Cached weighted sum; holds the sum over all ranks once
    /// `set_cached_global_weighted_sum_from_world` has run.
    #[cfg(feature = "mpi")]
    pub(crate) cached_global_weighted_sum: f64,
    /// How events are distributed over MPI ranks, or `None` when the dataset has no global
    /// layout.
    #[cfg(feature = "mpi")]
    pub(crate) mpi_layout: Option<MpiDatasetLayout>,
}
87
/// Mapping from a dataset's logical row numbers to rows of its backing storage.
#[derive(Debug, Clone)]
enum RowSelection {
    /// Every storage row is selected, in storage order.
    Identity,
    /// Only the listed storage rows are selected, in the listed order.
    Indices(Arc<[usize]>),
}

impl RowSelection {
    /// Number of selected rows, given that the backing storage holds `storage_len` rows.
    fn len(&self, storage_len: usize) -> usize {
        if let Self::Indices(selected) = self {
            selected.len()
        } else {
            storage_len
        }
    }

    /// Returns `true` when the selection is [`RowSelection::Identity`].
    const fn is_identity(&self) -> bool {
        match self {
            Self::Identity => true,
            Self::Indices(_) => false,
        }
    }

    /// Translates a logical row number into a storage row number.
    ///
    /// # Panics
    ///
    /// Panics if the selection is [`RowSelection::Indices`] and `logical_index` is out of
    /// bounds for the index list.
    fn physical_index(&self, logical_index: usize) -> usize {
        match self {
            Self::Indices(selected) => selected[logical_index],
            Self::Identity => logical_index,
        }
    }
}
113
/// Describes how a dataset's global event indices are assigned to MPI ranks.
#[cfg(feature = "mpi")]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum MpiDatasetLayout {
    /// Ownership follows the communicator's own partition of the total event count.
    Canonical,
    /// Global index `i` belongs to rank `i % size` at local index `i / size`.
    RoundRobin,
    /// Ranks own consecutive global ranges, in rank order, sized by each rank's local event
    /// count.
    Derived,
}
121
122#[cfg(feature = "mpi")]
123impl MpiDatasetLayout {
124 fn owner_of(
125 self,
126 global_index: usize,
127 total: usize,
128 local_len: usize,
129 world: &SimpleCommunicator,
130 ) -> (i32, usize) {
131 match self {
132 Self::Canonical => world.owner_of_global_index(global_index, total),
133 Self::RoundRobin => {
134 let size = world.size() as usize;
135 ((global_index % size) as i32, global_index / size)
136 }
137 Self::Derived => {
138 let counts = gather_local_event_counts(local_len, world);
139 let mut start = 0usize;
140 for (rank, count) in counts.into_iter().enumerate() {
141 let end = start + count;
142 if global_index < end {
143 return (rank as i32, global_index - start);
144 }
145 start = end;
146 }
147 debug_assert!(
148 global_index < total,
149 "validated derived global event index should be in range"
150 );
151 (world.rank(), 0)
152 }
153 }
154 }
155
156 fn local_range(
157 self,
158 total: usize,
159 local_len: usize,
160 world: &SimpleCommunicator,
161 ) -> std::ops::Range<usize> {
162 match self {
163 Self::Canonical => world.partition(total).range_for_rank(world.rank() as usize),
164 Self::RoundRobin => 0..local_len_for_round_robin(total, world),
165 Self::Derived => {
166 let counts = gather_local_event_counts(local_len, world);
167 let start = counts
168 .iter()
169 .take(world.rank() as usize)
170 .copied()
171 .sum::<usize>();
172 start..start + counts[world.rank() as usize]
173 }
174 }
175 }
176
177 fn local_indices_for_range(
178 self,
179 start: usize,
180 end: usize,
181 total: usize,
182 local_len: usize,
183 world: &SimpleCommunicator,
184 ) -> Vec<usize> {
185 match self {
186 Self::Canonical => {
187 let local_range = self.local_range(total, local_len, world);
188 let owned_start = start.max(local_range.start);
189 let owned_end = end.min(local_range.end);
190 if owned_start < owned_end {
191 (owned_start - local_range.start..owned_end - local_range.start).collect()
192 } else {
193 Vec::new()
194 }
195 }
196 Self::RoundRobin => {
197 let rank = world.rank() as usize;
198 let size = world.size() as usize;
199 (start..end)
200 .filter_map(|global_index| {
201 if global_index % size == rank {
202 Some(global_index / size)
203 } else {
204 None
205 }
206 })
207 .filter(|local_index| *local_index < local_len)
208 .collect()
209 }
210 Self::Derived => {
211 let counts = gather_local_event_counts(local_len, world);
212 let local_start = counts
213 .iter()
214 .take(world.rank() as usize)
215 .copied()
216 .sum::<usize>();
217 let local_end = local_start + local_len;
218 let owned_start = start.max(local_start);
219 let owned_end = end.min(local_end);
220 if owned_start < owned_end {
221 (owned_start - local_start..owned_end - local_start).collect()
222 } else {
223 Vec::new()
224 }
225 }
226 }
227 }
228}
229
/// Gathers `local_len` from every rank of `world`, returning one entry per rank in rank order.
#[cfg(feature = "mpi")]
fn gather_local_event_counts(local_len: usize, world: &SimpleCommunicator) -> Vec<usize> {
    let n_ranks = world.size() as usize;
    let mut per_rank = vec![0usize; n_ranks];
    world.all_gather_into(&local_len, &mut per_rank);
    per_rank
}
236
/// Counts the global indices in `0..total` that round-robin assignment gives to the calling
/// rank, i.e. those congruent to the rank modulo the communicator size.
#[cfg(feature = "mpi")]
fn local_len_for_round_robin(total: usize, world: &SimpleCommunicator) -> usize {
    let rank = world.rank() as usize;
    let size = world.size() as usize;
    // `total - (rank + 1)` underflows exactly when this rank receives no events.
    match total.checked_sub(rank + 1) {
        Some(beyond_first) => beyond_first / size + 1,
        None => 0,
    }
}
247
248fn shared_dataset_iter(dataset: Arc<Dataset>) -> DatasetArcIter {
249 #[cfg(feature = "mpi")]
250 {
251 if let Some(world) = crate::mpi::get_world() {
252 if let Some(layout) = dataset.mpi_layout {
253 let total = dataset.n_events();
254 return DatasetArcIter::Mpi(DatasetArcMpiIter {
255 dataset,
256 world,
257 index: 0,
258 total,
259 cursor: MpiEventChunkCursor::for_iteration(total),
260 layout,
261 });
262 }
263 }
264 }
265 DatasetArcIter::Local { dataset, index: 0 }
266}
267
/// Extension trait providing owning event iterators ([`DatasetArcIter`]) for shared datasets.
pub trait SharedDatasetIterExt {
    /// Returns an iterator that yields owned events.
    fn shared_iter(&self) -> DatasetArcIter;

    /// Returns an iterator that yields owned events.
    ///
    /// NOTE(review): the `Arc<Dataset>` implementation in this file simply forwards to
    /// [`SharedDatasetIterExt::shared_iter`]; confirm whether other implementors differ.
    fn shared_iter_global(&self) -> DatasetArcIter;
}
276
277impl SharedDatasetIterExt for Arc<Dataset> {
278 fn shared_iter(&self) -> DatasetArcIter {
279 shared_dataset_iter(self.clone())
280 }
281
282 fn shared_iter_global(&self) -> DatasetArcIter {
283 self.shared_iter()
284 }
285}
286
287impl Dataset {
288 fn from_columnar_storage(
289 columnar: DatasetStorage,
290 metadata: Arc<DatasetMetadata>,
291 rows: RowSelection,
292 ) -> Self {
293 #[cfg(feature = "mpi")]
294 let local_count = rows.len(columnar.n_events());
295 let local_weighted_sum = Self::weighted_sum_for_rows(&columnar, &rows);
296 Dataset {
297 columnar: Arc::new(columnar),
298 rows,
299 metadata,
300 cached_local_weighted_sum: local_weighted_sum,
301 #[cfg(feature = "mpi")]
302 cached_global_event_count: local_count,
303 #[cfg(feature = "mpi")]
304 cached_global_weighted_sum: local_weighted_sum,
305 #[cfg(feature = "mpi")]
306 mpi_layout: None,
307 }
308 }
309
310 fn weighted_sum_for_rows(columnar: &DatasetStorage, rows: &RowSelection) -> f64 {
311 match rows {
312 RowSelection::Identity => local_weighted_sum(&columnar.weights),
313 RowSelection::Indices(indices) => {
314 #[cfg(feature = "rayon")]
315 {
316 indices
317 .par_iter()
318 .map(|index| columnar.weight(*index))
319 .parallel_sum_with_accumulator::<Klein<f64>>()
320 }
321 #[cfg(not(feature = "rayon"))]
322 {
323 indices
324 .iter()
325 .map(|index| columnar.weight(*index))
326 .sum_with_accumulator::<Klein<f64>>()
327 }
328 }
329 }
330 }
331
332 fn indexed_local_view<I>(&self, indices: I) -> Arc<Dataset>
333 where
334 I: IntoIterator<Item = usize>,
335 {
336 let rows = RowSelection::Indices(indices.into_iter().collect::<Vec<_>>().into());
337 let local_weighted_sum = Self::weighted_sum_for_rows(&self.columnar, &rows);
338 let dataset = Dataset {
339 columnar: self.columnar.clone(),
340 rows,
341 metadata: self.metadata.clone(),
342 cached_local_weighted_sum: local_weighted_sum,
343 #[cfg(feature = "mpi")]
344 cached_global_event_count: 0,
345 #[cfg(feature = "mpi")]
346 cached_global_weighted_sum: local_weighted_sum,
347 #[cfg(feature = "mpi")]
348 mpi_layout: self.mpi_layout,
349 };
350 #[cfg(feature = "mpi")]
351 {
352 let mut dataset = dataset;
353 if dataset.mpi_layout.is_some() {
354 dataset.mpi_layout = Some(MpiDatasetLayout::Derived);
355 if let Some(world) = crate::mpi::get_world() {
356 dataset.set_cached_global_event_count_from_world(&world);
357 dataset.set_cached_global_weighted_sum_from_world(&world);
358 }
359 }
360 Arc::new(dataset)
361 }
362 #[cfg(not(feature = "mpi"))]
363 {
364 Arc::new(dataset)
365 }
366 }
367
368 fn ensure_mutable_storage(&self, operation: &str) -> LadduResult<()> {
369 if self.rows.is_identity() {
370 Ok(())
371 } else {
372 Err(LadduError::Custom(format!(
373 "Cannot {operation} on a filtered or bootstrapped dataset view; materialize it first"
374 )))
375 }
376 }
377
    /// Returns an iterator over borrowed views of the events stored locally, in logical row
    /// order.
    pub fn events_local(&self) -> impl Iterator<Item = Event<'_>> {
        DatasetViewIter {
            dataset: self,
            index: 0,
        }
    }
385
386 pub fn events_global(&self) -> DatasetGlobalIter<'_> {
391 let total = self.n_events();
392 #[cfg(feature = "mpi")]
393 {
394 if let (Some(world), Some(layout)) = (crate::mpi::get_world(), self.mpi_layout) {
395 return DatasetGlobalIter {
396 dataset: self,
397 index: 0,
398 total,
399 world: Some(world),
400 cursor: Some(MpiEventChunkCursor::for_iteration(total)),
401 layout: Some(layout),
402 };
403 }
404 }
405 DatasetGlobalIter {
406 dataset: self,
407 index: 0,
408 total,
409 #[cfg(feature = "mpi")]
410 world: None,
411 #[cfg(feature = "mpi")]
412 cursor: None,
413 #[cfg(feature = "mpi")]
414 layout: None,
415 }
416 }
417
418 fn refresh_local_weight_cache(&mut self) {
419 self.cached_local_weighted_sum = Self::weighted_sum_for_rows(&self.columnar, &self.rows);
420 #[cfg(feature = "mpi")]
421 {
422 self.cached_global_weighted_sum = self.cached_local_weighted_sum;
423 self.cached_global_event_count = self.n_events_local();
424 if self.mpi_layout.is_some() {
425 if let Some(world) = crate::mpi::get_world() {
426 self.set_cached_global_event_count_from_world(&world);
427 self.set_cached_global_weighted_sum_from_world(&world);
428 }
429 }
430 }
431 }
432
433 #[cfg(test)]
434 pub(crate) fn clear_events_local(&mut self) {
435 self.ensure_mutable_storage("clear local events")
436 .expect("test datasets should be materialized");
437 let columnar = Arc::make_mut(&mut self.columnar);
438 for column in &mut columnar.p4 {
439 column.px.clear();
440 column.py.clear();
441 column.pz.clear();
442 column.e.clear();
443 }
444 for column in &mut columnar.aux {
445 column.clear();
446 }
447 columnar.weights.clear();
448 self.refresh_local_weight_cache();
449 }
450
    /// Borrows the dataset's metadata.
    pub fn metadata(&self) -> &DatasetMetadata {
        &self.metadata
    }

    /// Returns a new shared handle to the dataset's metadata.
    pub fn metadata_arc(&self) -> Arc<DatasetMetadata> {
        self.metadata.clone()
    }

    /// Names of the four-momentum columns, as recorded in the metadata.
    pub fn p4_names(&self) -> &[String] {
        &self.metadata.p4_names
    }

    /// Names of the auxiliary columns, as recorded in the metadata.
    pub fn aux_names(&self) -> &[String] {
        &self.metadata.aux_names
    }

    /// Looks up the index of the four-momentum column called `name` in the metadata.
    pub fn p4_index(&self, name: &str) -> Option<usize> {
        self.metadata.p4_index(name)
    }

    /// Looks up the index of the auxiliary column called `name` in the metadata.
    pub fn aux_index(&self, name: &str) -> Option<usize> {
        self.metadata.aux_index(name)
    }
480
481 fn event_global_opt(&self, index: usize) -> LadduResult<Option<OwnedEvent>> {
482 #[cfg(feature = "mpi")]
483 {
484 if let (Some(world), Some(_)) = (crate::mpi::get_world(), self.mpi_layout) {
485 let total = self.n_events();
486 if index >= total {
487 return Ok(None);
488 }
489 return self.fetch_event_mpi(index, &world, total).map(Some);
490 }
491 }
492
493 Ok((index < self.n_events_local())
494 .then(|| OwnedEvent::new(self.event_data_arc_local(index), self.metadata.clone())))
495 }
496
497 pub fn event_global(&self, index: usize) -> LadduResult<OwnedEvent> {
499 self.event_global_opt(index)?.ok_or_else(|| {
500 LadduError::Custom(format!(
501 "Dataset index out of bounds: index {index}, length {}",
502 self.n_events()
503 ))
504 })
505 }
506
507 pub fn event_local(&self, event_index: usize) -> LadduResult<Event<'_>> {
509 if event_index >= self.n_events_local() {
510 return Err(LadduError::Custom(format!(
511 "Dataset local index out of bounds: index {event_index}, length {}",
512 self.n_events_local()
513 )));
514 }
515 Ok(self.event_view(event_index))
516 }
517
518 pub fn p4_by_name(&self, event_index: usize, name: &str) -> Option<Vec4> {
520 self.event_global_opt(event_index)
521 .ok()
522 .flatten()
523 .and_then(|event| event.p4(name))
524 }
525
526 pub fn aux_by_name(&self, event_index: usize, name: &str) -> Option<f64> {
528 let idx = self.aux_index(name)?;
529 self.event_global_opt(event_index)
530 .ok()
531 .flatten()
532 .and_then(|event| event.aux.get(idx).copied())
533 }
534
535 pub(crate) fn event_view(&self, event_index: usize) -> Event<'_> {
536 self.columnar
537 .event_view(self.rows.physical_index(event_index))
538 }
539
540 pub(crate) fn event_data_arc_local(&self, index: usize) -> Arc<EventData> {
553 Arc::new(self.columnar.event_data(self.rows.physical_index(index)))
554 }
555
556 pub(crate) fn local_event_data_arcs(&self) -> Vec<Arc<EventData>> {
557 (0..self.n_events_local())
558 .map(|index| self.event_data_arc_local(index))
559 .collect()
560 }
561
562 pub(crate) fn local_storage_for_export(&self) -> LadduResult<Cow<'_, DatasetStorage>> {
563 if self.rows.is_identity() {
564 Ok(Cow::Borrowed(self.columnar.as_ref()))
565 } else {
566 Ok(Cow::Owned(Self::columnar_from_events(
567 &self.local_event_data_arcs(),
568 self.metadata.clone(),
569 )?))
570 }
571 }
572
573 pub(crate) fn local_weight_cache_key(&self) -> (usize, usize) {
574 match &self.rows {
575 RowSelection::Identity => (
576 self.columnar.weights.as_ptr() as usize,
577 self.n_events_local(),
578 ),
579 RowSelection::Indices(indices) => (indices.as_ptr() as usize, indices.len()),
580 }
581 }
582
583 #[cfg(feature = "mpi")]
584 fn partition(
585 events: Vec<Arc<EventData>>,
586 world: &SimpleCommunicator,
587 ) -> Vec<Vec<Arc<EventData>>> {
588 let partition = world.partition(events.len());
589 (0..partition.n_ranks())
590 .map(|rank| {
591 let range = partition.range_for_rank(rank);
592 events[range.clone()].to_vec()
593 })
594 .collect()
595 }
596}
597
598pub(crate) struct DatasetViewIter<'a> {
600 dataset: &'a Dataset,
601 index: usize,
602}
603
604impl<'a> Iterator for DatasetViewIter<'a> {
605 type Item = Event<'a>;
606
607 fn next(&mut self) -> Option<Self::Item> {
608 if self.index >= self.dataset.n_events_local() {
609 return None;
610 }
611 let event = self.dataset.event_view(self.index);
612 self.index += 1;
613 Some(event)
614 }
615}
616
617pub struct DatasetGlobalIter<'a> {
619 dataset: &'a Dataset,
620 index: usize,
621 total: usize,
622 #[cfg(feature = "mpi")]
623 world: Option<SimpleCommunicator>,
624 #[cfg(feature = "mpi")]
625 cursor: Option<MpiEventChunkCursor>,
626 #[cfg(feature = "mpi")]
627 layout: Option<MpiDatasetLayout>,
628}
629
630impl Iterator for DatasetGlobalIter<'_> {
631 type Item = OwnedEvent;
632
633 fn next(&mut self) -> Option<Self::Item> {
634 if self.index >= self.total {
635 return None;
636 }
637 let index = self.index;
638 self.index += 1;
639
640 #[cfg(feature = "mpi")]
641 {
642 if let (Some(world), Some(cursor), Some(layout)) =
643 (&self.world, &mut self.cursor, self.layout)
644 {
645 return cursor
646 .event_for_dataset(self.dataset, index, world, self.total, layout)
647 .ok()
648 .flatten();
649 }
650 }
651
652 self.dataset.event_global_opt(index).ok().flatten()
653 }
654}
655
656pub enum DatasetArcIter {
658 Local {
660 dataset: Arc<Dataset>,
662 index: usize,
664 },
665 #[cfg(feature = "mpi")]
666 Mpi(DatasetArcMpiIter),
668}
669
670impl Iterator for DatasetArcIter {
671 type Item = OwnedEvent;
672
673 fn next(&mut self) -> Option<Self::Item> {
674 match self {
675 DatasetArcIter::Local { dataset, index } => {
676 let event = (*index < dataset.n_events_local()).then(|| {
677 OwnedEvent::new(
678 dataset.event_data_arc_local(*index),
679 dataset.metadata.clone(),
680 )
681 });
682 *index += 1;
683 event
684 }
685 #[cfg(feature = "mpi")]
686 DatasetArcIter::Mpi(iter) => iter.next(),
687 }
688 }
689}
690
/// Cache of one contiguous chunk of globally indexed events, used to avoid one MPI exchange
/// per event during iteration.
#[cfg(feature = "mpi")]
#[derive(Debug, Clone)]
pub(crate) struct MpiEventChunkCursor {
    /// Global index of the first cached event.
    chunk_start: usize,
    /// Number of events requested per fetch (at least 1).
    chunk_size: usize,
    /// Events of the current chunk, in global index order starting at `chunk_start`.
    cached_events: Vec<OwnedEvent>,
}
698
/// Chooses the chunk size for MPI event fetches over a dataset of `total` events.
///
/// A value parsed from the environment variable named by `MPI_EVENT_FETCH_CHUNK_SIZE_ENV`
/// takes precedence over `DEFAULT_MPI_EVENT_FETCH_CHUNK_SIZE`; an unset, non-UTF-8 or
/// unparsable value is ignored. The result always lies in `1..=max(total, 1)`.
#[cfg(feature = "mpi")]
pub(crate) fn resolve_mpi_event_fetch_chunk_size(total: usize) -> usize {
    let upper = total.max(1);
    std::env::var_os(MPI_EVENT_FETCH_CHUNK_SIZE_ENV)
        .and_then(|raw| raw.to_str()?.parse::<usize>().ok())
        .map_or_else(
            || DEFAULT_MPI_EVENT_FETCH_CHUNK_SIZE.min(upper),
            |requested| requested.clamp(1, upper),
        )
}
709
/// Kind of column being added in a collective column-add operation.
#[cfg(feature = "mpi")]
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
enum ColumnMutationKind {
    /// A four-momentum column.
    P4,
    /// An auxiliary column.
    Aux,
}
716
/// One rank's view of a column-add request, exchanged between ranks so the request can be
/// validated collectively.
#[cfg(feature = "mpi")]
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
struct ColumnMutationStatus {
    /// Kind of column the rank is adding.
    kind: ColumnMutationKind,
    /// Name of the column the rank is adding.
    name: String,
    /// Whether the caller reported the rank's column length as valid.
    len_ok: bool,
    /// Whether the name was rejected as already present in the rank's metadata.
    duplicate: bool,
}
725
#[cfg(feature = "mpi")]
impl MpiEventChunkCursor {
    /// Creates a cursor for iterating over `total` events, sized by
    /// [`resolve_mpi_event_fetch_chunk_size`].
    pub(crate) fn for_iteration(total: usize) -> Self {
        let chunk_size = resolve_mpi_event_fetch_chunk_size(total);
        Self::new(chunk_size)
    }
}
732
#[cfg(feature = "mpi")]
impl MpiEventChunkCursor {
    /// Creates an empty cursor that fetches `chunk_size` events at a time (a `chunk_size` of
    /// 0 is treated as 1).
    pub(crate) fn new(chunk_size: usize) -> Self {
        Self {
            chunk_start: 0,
            chunk_size: chunk_size.max(1),
            cached_events: Vec::new(),
        }
    }

    /// Global index one past the last cached event.
    fn chunk_end(&self) -> usize {
        self.chunk_start + self.cached_events.len()
    }

    /// Whether the event at `global_index` is currently cached.
    fn contains(&self, global_index: usize) -> bool {
        global_index >= self.chunk_start && global_index < self.chunk_end()
    }

    /// Returns the event at `global_index`, fetching a new chunk starting at that index when
    /// it is not cached.
    ///
    /// Returns `Ok(None)` when `global_index >= total`. Fetching a chunk goes through
    /// `Dataset::fetch_event_chunk_mpi`, which performs MPI collectives.
    ///
    /// # Errors
    ///
    /// Propagates any error from the chunk fetch. The cursor is left untouched in that case,
    /// so the previously cached chunk remains valid for later lookups.
    pub(crate) fn event_for_dataset(
        &mut self,
        dataset: &Dataset,
        global_index: usize,
        world: &SimpleCommunicator,
        total: usize,
        layout: MpiDatasetLayout,
    ) -> LadduResult<Option<OwnedEvent>> {
        if global_index >= total {
            return Ok(None);
        }
        if !self.contains(global_index) {
            // Fetch before touching any state: updating `chunk_start` first would pair the new
            // start with the old events if the fetch failed, serving stale events afterwards.
            let fetched = dataset.fetch_event_chunk_mpi(
                global_index,
                self.chunk_size,
                world,
                total,
                layout,
            )?;
            self.chunk_start = global_index;
            self.cached_events = fetched;
        }
        Ok(self
            .cached_events
            .get(global_index - self.chunk_start)
            .cloned())
    }
}
778
/// Owning iterator over all events of a shared, MPI-distributed dataset.
#[cfg(feature = "mpi")]
pub struct DatasetArcMpiIter {
    /// Dataset being iterated.
    dataset: Arc<Dataset>,
    /// Communicator used for chunked fetches.
    world: SimpleCommunicator,
    /// Global index of the next event to yield.
    index: usize,
    /// Number of events to iterate over, captured when the iterator was created.
    total: usize,
    /// Chunk cache used to fetch events.
    cursor: MpiEventChunkCursor,
    /// Layout of the dataset across ranks.
    layout: MpiDatasetLayout,
}

#[cfg(feature = "mpi")]
impl Iterator for DatasetArcMpiIter {
    type Item = OwnedEvent;

    /// Yields the next event; a fetch error ends the item with `None` rather than being
    /// reported.
    fn next(&mut self) -> Option<Self::Item> {
        let fetched = self.cursor.event_for_dataset(
            &self.dataset,
            self.index,
            &self.world,
            self.total,
            self.layout,
        );
        // The position advances on every call, including calls past the end.
        self.index += 1;
        fetched.unwrap_or(None)
    }
}
810
811impl Dataset {
812 #[cfg(feature = "mpi")]
813 fn validate_global_column_add(
814 &self,
815 kind: ColumnMutationKind,
816 name: &str,
817 len_ok: bool,
818 ) -> LadduResult<()> {
819 let Some(world) = crate::mpi::get_world() else {
820 return Ok(());
821 };
822 let duplicate = match kind {
823 ColumnMutationKind::P4 => self.metadata.ensure_new_p4_name(name).is_err(),
824 ColumnMutationKind::Aux => self.metadata.ensure_new_aux_name(name).is_err(),
825 };
826 let local_status = ColumnMutationStatus {
827 kind,
828 name: name.to_string(),
829 len_ok,
830 duplicate,
831 };
832 let serialized = bitcode::serialize(&local_status)?;
833 let local_byte_count = serialized.len() as i32;
834 let mut byte_counts = vec![0_i32; world.size() as usize];
835 world.all_gather_into(&local_byte_count, &mut byte_counts);
836 let mut byte_displs = vec![0_i32; byte_counts.len()];
837 for index in 1..byte_displs.len() {
838 byte_displs[index] = byte_displs[index - 1] + byte_counts[index - 1];
839 }
840 let gathered_bytes = world.all_gather_with_counts(&serialized, &byte_counts, &byte_displs);
841 let mut statuses = Vec::with_capacity(world.size() as usize);
842 for rank in 0..world.size() as usize {
843 let start = byte_displs[rank] as usize;
844 let end = start + byte_counts[rank] as usize;
845 statuses.push(bitcode::deserialize::<ColumnMutationStatus>(
846 &gathered_bytes[start..end],
847 )?);
848 }
849 for (rank, status) in statuses.iter().enumerate() {
850 if status.kind != kind {
851 return Err(LadduError::Custom(format!(
852 "Collective dataset column add mismatch: rank {rank} used {:?}, expected {:?}",
853 status.kind, kind
854 )));
855 }
856 if status.name != name {
857 return Err(LadduError::Custom(format!(
858 "Collective dataset column add mismatch: rank {rank} used name '{}', expected '{name}'",
859 status.name
860 )));
861 }
862 if !status.len_ok {
863 return Err(LadduError::Custom(format!(
864 "Collective dataset column add mismatch: rank {rank} provided a column with the wrong local length"
865 )));
866 }
867 if status.duplicate {
868 let category = match kind {
869 ColumnMutationKind::P4 => "p4",
870 ColumnMutationKind::Aux => "aux",
871 };
872 return Err(LadduError::DuplicateName {
873 category,
874 name: name.to_string(),
875 });
876 }
877 }
878 Ok(())
879 }
880
881 #[cfg(feature = "mpi")]
882 fn fetch_event_mpi(
883 &self,
884 global_index: usize,
885 world: &SimpleCommunicator,
886 total: usize,
887 ) -> LadduResult<OwnedEvent> {
888 let layout = self.mpi_layout.ok_or_else(|| {
889 LadduError::Custom(
890 "global MPI event fetch requires a global dataset layout".to_string(),
891 )
892 })?;
893 let (owning_rank, local_index) =
894 layout.owner_of(global_index, total, self.n_events_local(), world);
895 let mut serialized_event_buffer_len: usize = 0;
896 let mut serialized_event_buffer: Vec<u8> = Vec::default();
897 if world.rank() == owning_rank {
898 let event = self
899 .columnar
900 .event_data(self.rows.physical_index(local_index));
901 serialized_event_buffer = bitcode::serialize(&event)?;
902 serialized_event_buffer_len = serialized_event_buffer.len();
903 }
904 world
905 .process_at_rank(owning_rank)
906 .broadcast_into(&mut serialized_event_buffer_len);
907 if world.rank() != owning_rank {
908 serialized_event_buffer = vec![0; serialized_event_buffer_len];
909 }
910 world
911 .process_at_rank(owning_rank)
912 .broadcast_into(&mut serialized_event_buffer);
913
914 if world.rank() == owning_rank {
915 Ok(OwnedEvent::new(
916 Arc::new(
917 self.columnar
918 .event_data(self.rows.physical_index(local_index)),
919 ),
920 self.metadata.clone(),
921 ))
922 } else {
923 let event: EventData = bitcode::deserialize(&serialized_event_buffer[..])?;
924 Ok(OwnedEvent::new(Arc::new(event), self.metadata.clone()))
925 }
926 }
927
    /// Fetches the events with global indices `start..min(start + len, total)` on every rank.
    ///
    /// Each rank serializes the events it owns in that range, the serialized chunks are
    /// exchanged with an all-gather, and the result is reassembled in global index order.
    /// All ranks must call this together with the same arguments.
    ///
    /// Returns an empty vector when `len == 0` or `start >= total`.
    ///
    /// # Errors
    ///
    /// Propagates serialization and deserialization errors.
    #[cfg(feature = "mpi")]
    pub(crate) fn fetch_event_chunk_mpi(
        &self,
        start: usize,
        len: usize,
        world: &SimpleCommunicator,
        total: usize,
        layout: MpiDatasetLayout,
    ) -> LadduResult<Vec<OwnedEvent>> {
        if len == 0 || start >= total {
            return Ok(Vec::new());
        }

        // Clamp the requested range to the dataset.
        let end = (start + len).min(total);
        let local_indices =
            layout.local_indices_for_range(start, end, total, self.n_events_local(), world);

        // Materialize the events this rank contributes, in ascending local index order.
        let local_events: Vec<EventData> = local_indices
            .into_iter()
            .map(|local_index| {
                self.columnar
                    .event_data(self.rows.physical_index(local_index))
            })
            .collect();
        let local_event_count = local_events.len() as i32;

        // Ranks with nothing to contribute send zero bytes.
        let serialized_local = if local_events.is_empty() {
            Vec::new()
        } else {
            bitcode::serialize(&local_events)?
        };
        let local_byte_count = serialized_local.len() as i32;

        // Exchange per-rank event and byte counts so every rank can lay out the receive buffer.
        let mut gathered_event_counts = vec![0_i32; world.size() as usize];
        let mut gathered_byte_counts = vec![0_i32; world.size() as usize];
        world.all_gather_into(&local_event_count, &mut gathered_event_counts);
        world.all_gather_into(&local_byte_count, &mut gathered_byte_counts);

        // Displacements are the running sum of the byte counts of the lower ranks.
        let mut gathered_byte_displs = vec![0_i32; gathered_byte_counts.len()];
        for index in 1..gathered_byte_displs.len() {
            gathered_byte_displs[index] =
                gathered_byte_displs[index - 1] + gathered_byte_counts[index - 1];
        }
        let gathered_bytes = world.all_gather_with_counts(
            &serialized_local,
            &gathered_byte_counts,
            &gathered_byte_displs,
        );

        // Decode each rank's contribution; ranks that sent no events are skipped.
        let mut events_by_rank = vec![Vec::new(); world.size() as usize];
        for rank in 0..world.size() as usize {
            if gathered_event_counts[rank] == 0 {
                continue;
            }
            let byte_start = gathered_byte_displs[rank] as usize;
            let byte_end = byte_start + gathered_byte_counts[rank] as usize;
            let decoded: Vec<EventData> =
                bitcode::deserialize(&gathered_bytes[byte_start..byte_end])?;
            debug_assert_eq!(decoded.len(), gathered_event_counts[rank] as usize);
            events_by_rank[rank] = decoded
                .into_iter()
                .map(|event| OwnedEvent::new(Arc::new(event), self.metadata.clone()))
                .collect();
        }

        // Interleave the per-rank lists back into global index order; `offsets[rank]` counts
        // how many of that rank's events have been consumed so far.
        // NOTE(review): for the `Derived` layout `owner_of` performs an all-gather on every
        // call, so this loop issues one collective per event in the chunk.
        let mut offsets = vec![0usize; world.size() as usize];
        let mut events = Vec::with_capacity(end - start);
        for global_index in start..end {
            let (owning_rank, _) =
                layout.owner_of(global_index, total, self.n_events_local(), world);
            let rank = owning_rank as usize;
            let offset = offsets[rank];
            events.push(events_by_rank[rank][offset].clone());
            offsets[rank] += 1;
        }
        Ok(events)
    }
1005
1006 #[cfg(feature = "mpi")]
1007 pub(crate) fn set_cached_global_event_count_from_world(&mut self, world: &SimpleCommunicator) {
1008 let local_count = self.n_events_local();
1009 let mut global_count = 0usize;
1010 world.all_reduce_into(
1011 &local_count,
1012 &mut global_count,
1013 mpi::collective::SystemOperation::sum(),
1014 );
1015 self.cached_global_event_count = global_count;
1016 }
1017
1018 #[cfg(feature = "mpi")]
1019 pub(crate) fn set_cached_global_weighted_sum_from_world(&mut self, world: &SimpleCommunicator) {
1020 let mut weighted_sums = vec![0.0_f64; world.size() as usize];
1021 world.all_gather_into(&self.cached_local_weighted_sum, &mut weighted_sums);
1022 #[cfg(feature = "rayon")]
1023 {
1024 self.cached_global_weighted_sum = weighted_sums
1025 .into_par_iter()
1026 .parallel_sum_with_accumulator::<Klein<f64>>();
1027 }
1028 #[cfg(not(feature = "rayon"))]
1029 {
1030 self.cached_global_weighted_sum = weighted_sums
1031 .into_iter()
1032 .sum_with_accumulator::<Klein<f64>>();
1033 }
1034 }
1035
1036 fn columnar_from_events(
1037 events: &[Arc<EventData>],
1038 metadata: Arc<DatasetMetadata>,
1039 ) -> LadduResult<DatasetStorage> {
1040 let n_events = events.len();
1041 let (n_p4, n_aux) = match events.first() {
1042 Some(first) => (first.p4s.len(), first.aux.len()),
1043 None => (metadata.p4_names.len(), metadata.aux_names.len()),
1044 };
1045 let mut p4 = (0..n_p4)
1046 .map(|_| ColumnarP4Column::with_capacity(n_events))
1047 .collect::<Vec<_>>();
1048 let mut aux = (0..n_aux)
1049 .map(|_| Vec::with_capacity(n_events))
1050 .collect::<Vec<_>>();
1051 let mut weights = Vec::with_capacity(n_events);
1052 for (event_index, event) in events.iter().enumerate() {
1053 if event.p4s.len() != n_p4 || event.aux.len() != n_aux {
1054 return Err(LadduError::Custom(format!(
1055 "Ragged dataset shape at event {event_index}: expected ({n_p4} p4, {n_aux} aux), got ({} p4, {} aux)",
1056 event.p4s.len(),
1057 event.aux.len()
1058 )));
1059 }
1060 for (column, value) in p4.iter_mut().zip(event.p4s.iter()) {
1061 column.push(*value);
1062 }
1063 for (column, value) in aux.iter_mut().zip(event.aux.iter()) {
1064 column.push(*value);
1065 }
1066 weights.push(event.weight);
1067 }
1068 Ok(DatasetStorage {
1069 metadata,
1070 p4,
1071 aux,
1072 weights,
1073 })
1074 }
1075
1076 pub fn new_local(events: Vec<Arc<EventData>>, metadata: Arc<DatasetMetadata>) -> Self {
1083 let columnar = Self::columnar_from_events(&events, metadata.clone())
1084 .expect("Dataset requires rectangular p4/aux columns for canonical columnar storage");
1085 Self::from_columnar_storage(columnar, metadata, RowSelection::Identity)
1086 }
1087
1088 pub fn empty_local(metadata: DatasetMetadata) -> Self {
1097 let metadata = Arc::new(metadata);
1098 #[cfg(feature = "mpi")]
1099 {
1100 if crate::mpi::get_world().is_some() {
1101 let dataset = Dataset {
1102 columnar: Arc::new(DatasetStorage::empty_with_capacity(metadata.clone(), 0)),
1103 rows: RowSelection::Identity,
1104 metadata,
1105 cached_local_weighted_sum: 0.0,
1106 cached_global_event_count: 0,
1107 cached_global_weighted_sum: 0.0,
1108 mpi_layout: None,
1109 };
1110 return dataset;
1111 }
1112 }
1113 Dataset {
1114 columnar: Arc::new(DatasetStorage::empty_with_capacity(metadata.clone(), 0)),
1115 rows: RowSelection::Identity,
1116 metadata,
1117 cached_local_weighted_sum: 0.0,
1118 #[cfg(feature = "mpi")]
1119 cached_global_event_count: 0,
1120 #[cfg(feature = "mpi")]
1121 cached_global_weighted_sum: 0.0,
1122 #[cfg(feature = "mpi")]
1123 mpi_layout: None,
1124 }
1125 }
1126
1127 pub fn from_columns_local(
1132 metadata: DatasetMetadata,
1133 p4_columns: Vec<Vec<Vec4>>,
1134 aux_columns: Vec<Vec<f64>>,
1135 weights: Vec<f64>,
1136 ) -> LadduResult<Self> {
1137 let n_events = weights.len();
1138 if p4_columns.len() != metadata.p4_names().len() {
1139 return Err(LadduError::Custom(format!(
1140 "Expected {} p4 columns, got {}",
1141 metadata.p4_names().len(),
1142 p4_columns.len()
1143 )));
1144 }
1145 if aux_columns.len() != metadata.aux_names().len() {
1146 return Err(LadduError::Custom(format!(
1147 "Expected {} aux columns, got {}",
1148 metadata.aux_names().len(),
1149 aux_columns.len()
1150 )));
1151 }
1152 for (index, column) in p4_columns.iter().enumerate() {
1153 if column.len() != n_events {
1154 return Err(LadduError::Custom(format!(
1155 "P4 column {index} length {} does not match weight length {n_events}",
1156 column.len()
1157 )));
1158 }
1159 }
1160 for (index, column) in aux_columns.iter().enumerate() {
1161 if column.len() != n_events {
1162 return Err(LadduError::Custom(format!(
1163 "Aux column {index} length {} does not match weight length {n_events}",
1164 column.len()
1165 )));
1166 }
1167 }
1168
1169 let events = (0..n_events)
1170 .map(|event_index| {
1171 Arc::new(EventData {
1172 p4s: p4_columns
1173 .iter()
1174 .map(|column| column[event_index])
1175 .collect(),
1176 aux: aux_columns
1177 .iter()
1178 .map(|column| column[event_index])
1179 .collect(),
1180 weight: weights[event_index],
1181 })
1182 })
1183 .collect();
1184 Ok(Dataset::new_local(events, Arc::new(metadata)))
1185 }
1186
1187 pub fn from_columns_global(
1192 metadata: DatasetMetadata,
1193 p4_columns: Vec<Vec<Vec4>>,
1194 aux_columns: Vec<Vec<f64>>,
1195 weights: Vec<f64>,
1196 ) -> LadduResult<Self> {
1197 let dataset = Self::from_columns_local(metadata, p4_columns, aux_columns, weights)?;
1198 let events = dataset.local_event_data_arcs();
1199 Ok(Dataset::new_with_metadata(events, dataset.metadata))
1200 }
1201
1202 #[cfg(feature = "mpi")]
1209 pub fn new_mpi(
1210 events: Vec<Arc<EventData>>,
1211 metadata: Arc<DatasetMetadata>,
1212 world: &SimpleCommunicator,
1213 ) -> Self {
1214 let partitions = Dataset::partition(events, world);
1215 let local = &partitions[world.rank() as usize];
1216 let columnar = Self::columnar_from_events(local, metadata.clone())
1217 .expect("Dataset requires rectangular p4/aux columns for canonical columnar storage");
1218 let local_weighted_sum = local_weighted_sum(&columnar.weights);
1219 let mut dataset = Dataset {
1220 columnar: Arc::new(columnar),
1221 rows: RowSelection::Identity,
1222 metadata,
1223 cached_local_weighted_sum: local_weighted_sum,
1224 cached_global_event_count: 0,
1225 cached_global_weighted_sum: local_weighted_sum,
1226 mpi_layout: Some(MpiDatasetLayout::Canonical),
1227 };
1228 dataset.set_cached_global_event_count_from_world(world);
1229 dataset.set_cached_global_weighted_sum_from_world(world);
1230 dataset
1231 }
1232
1233 pub fn new(events: Vec<Arc<EventData>>) -> Self {
1239 Dataset::new_with_metadata(events, Arc::new(DatasetMetadata::default()))
1240 }
1241
1242 pub fn new_with_metadata(events: Vec<Arc<EventData>>, metadata: Arc<DatasetMetadata>) -> Self {
1245 #[cfg(feature = "mpi")]
1246 {
1247 if let Some(world) = crate::mpi::get_world() {
1248 return Dataset::new_mpi(events, metadata, &world);
1249 }
1250 }
1251 Dataset::new_local(events, metadata)
1252 }
1253
1254 fn push_event_data_local(&mut self, event_data: Arc<EventData>) -> LadduResult<()> {
1255 self.ensure_mutable_storage("push events")?;
1256 Arc::make_mut(&mut self.columnar).push_event_data(&event_data);
1257 self.refresh_local_weight_cache();
1258 Ok(())
1259 }
1260
1261 fn replace_metadata(&mut self, metadata: DatasetMetadata) {
1262 let metadata = Arc::new(metadata);
1263 self.metadata = metadata.clone();
1264 Arc::make_mut(&mut self.columnar).set_metadata(metadata);
1265 }
1266
1267 fn validate_p4_column_len(&self, name: &str, len: usize) -> LadduResult<()> {
1268 if len != self.n_events_local() {
1269 return Err(LadduError::LengthMismatch {
1270 context: format!("P4 column '{name}'"),
1271 expected: self.n_events_local(),
1272 actual: len,
1273 });
1274 }
1275 Ok(())
1276 }
1277
1278 fn validate_aux_column_len(&self, name: &str, len: usize) -> LadduResult<()> {
1279 if len != self.n_events_local() {
1280 return Err(LadduError::LengthMismatch {
1281 context: format!("Aux column '{name}'"),
1282 expected: self.n_events_local(),
1283 actual: len,
1284 });
1285 }
1286 Ok(())
1287 }
1288
1289 fn add_p4_column_unchecked(&mut self, name: String, values: Vec<Vec4>) -> LadduResult<()> {
1290 let mut metadata = (*self.metadata).clone();
1291 metadata.add_p4_name(name)?;
1292 Arc::make_mut(&mut self.columnar).push_p4_column(values);
1293 self.replace_metadata(metadata);
1294 Ok(())
1295 }
1296
1297 fn add_aux_column_unchecked(&mut self, name: String, values: Vec<f64>) -> LadduResult<()> {
1298 let mut metadata = (*self.metadata).clone();
1299 metadata.add_aux_name(name)?;
1300 Arc::make_mut(&mut self.columnar).push_aux_column(values);
1301 self.replace_metadata(metadata);
1302 Ok(())
1303 }
1304
1305 pub fn add_p4_column_local<N, V>(&mut self, name: N, values: V) -> LadduResult<()>
1311 where
1312 N: Into<String>,
1313 V: IntoIterator<Item = Vec4>,
1314 {
1315 self.ensure_mutable_storage("add a p4 column")?;
1316 #[cfg(feature = "mpi")]
1317 {
1318 if self.mpi_layout.is_some() {
1319 return Err(LadduError::Custom(
1320 "Cannot add a local p4 column to an MPI dataset; use add_p4_column_global"
1321 .to_string(),
1322 ));
1323 }
1324 }
1325 let name = name.into();
1326 let values = values.into_iter().collect::<Vec<_>>();
1327 self.metadata.ensure_new_p4_name(&name)?;
1328 self.validate_p4_column_len(&name, values.len())?;
1329 self.add_p4_column_unchecked(name, values)
1330 }
1331
1332 pub fn add_aux_column_local<N, V>(&mut self, name: N, values: V) -> LadduResult<()>
1338 where
1339 N: Into<String>,
1340 V: IntoIterator<Item = f64>,
1341 {
1342 self.ensure_mutable_storage("add an aux column")?;
1343 #[cfg(feature = "mpi")]
1344 {
1345 if self.mpi_layout.is_some() {
1346 return Err(LadduError::Custom(
1347 "Cannot add a local aux column to an MPI dataset; use add_aux_column_global"
1348 .to_string(),
1349 ));
1350 }
1351 }
1352 let name = name.into();
1353 let values = values.into_iter().collect::<Vec<_>>();
1354 self.metadata.ensure_new_aux_name(&name)?;
1355 self.validate_aux_column_len(&name, values.len())?;
1356 self.add_aux_column_unchecked(name, values)
1357 }
1358
1359 pub fn add_p4_column_global<N, V>(&mut self, name: N, values: V) -> LadduResult<()>
1364 where
1365 N: Into<String>,
1366 V: IntoIterator<Item = Vec4>,
1367 {
1368 self.ensure_mutable_storage("add a p4 column")?;
1369 let name = name.into();
1370 let values = values.into_iter().collect::<Vec<_>>();
1371 #[cfg(feature = "mpi")]
1372 {
1373 if crate::mpi::get_world().is_some() {
1374 self.validate_global_column_add(
1375 ColumnMutationKind::P4,
1376 &name,
1377 values.len() == self.n_events_local(),
1378 )?;
1379 self.metadata.ensure_new_p4_name(&name)?;
1380 self.validate_p4_column_len(&name, values.len())?;
1381 return self.add_p4_column_unchecked(name, values);
1382 }
1383 }
1384 self.add_p4_column_local(name, values)
1385 }
1386
1387 pub fn add_aux_column_global<N, V>(&mut self, name: N, values: V) -> LadduResult<()>
1392 where
1393 N: Into<String>,
1394 V: IntoIterator<Item = f64>,
1395 {
1396 self.ensure_mutable_storage("add an aux column")?;
1397 let name = name.into();
1398 let values = values.into_iter().collect::<Vec<_>>();
1399 #[cfg(feature = "mpi")]
1400 {
1401 if crate::mpi::get_world().is_some() {
1402 self.validate_global_column_add(
1403 ColumnMutationKind::Aux,
1404 &name,
1405 values.len() == self.n_events_local(),
1406 )?;
1407 self.metadata.ensure_new_aux_name(&name)?;
1408 self.validate_aux_column_len(&name, values.len())?;
1409 return self.add_aux_column_unchecked(name, values);
1410 }
1411 }
1412 self.add_aux_column_local(name, values)
1413 }
1414
1415 pub fn push_event_local<P, A>(&mut self, p4s: P, aux: A, weight: f64) -> LadduResult<()>
1423 where
1424 P: IntoIterator<Item = Vec4>,
1425 A: IntoIterator<Item = f64>,
1426 {
1427 self.ensure_mutable_storage("push events")?;
1428 #[cfg(feature = "mpi")]
1429 {
1430 if self.mpi_layout == Some(MpiDatasetLayout::RoundRobin) && self.n_events() > 0 {
1431 return Err(LadduError::Custom(
1432 "Cannot push local events into a round-robin global dataset".to_string(),
1433 ));
1434 }
1435 self.mpi_layout = None;
1436 }
1437 let p4s = p4s.into_iter().collect::<Vec<_>>();
1438 let aux = aux.into_iter().collect::<Vec<_>>();
1439 if p4s.len() != self.metadata.p4_names().len() {
1440 return Err(LadduError::Custom(format!(
1441 "Expected {} p4 values, got {}",
1442 self.metadata.p4_names().len(),
1443 p4s.len()
1444 )));
1445 }
1446 if aux.len() != self.metadata.aux_names().len() {
1447 return Err(LadduError::Custom(format!(
1448 "Expected {} aux values, got {}",
1449 self.metadata.aux_names().len(),
1450 aux.len()
1451 )));
1452 }
1453
1454 let event_data = Arc::new(EventData { p4s, aux, weight });
1455 self.push_event_data_local(event_data)
1456 }
1457
1458 pub fn push_event_global<P, A>(&mut self, p4s: P, aux: A, weight: f64) -> LadduResult<()>
1464 where
1465 P: IntoIterator<Item = Vec4>,
1466 A: IntoIterator<Item = f64>,
1467 {
1468 self.ensure_mutable_storage("push events")?;
1469 let p4s = p4s.into_iter().collect::<Vec<_>>();
1470 let aux = aux.into_iter().collect::<Vec<_>>();
1471 if p4s.len() != self.metadata.p4_names().len() {
1472 return Err(LadduError::Custom(format!(
1473 "Expected {} p4 values, got {}",
1474 self.metadata.p4_names().len(),
1475 p4s.len()
1476 )));
1477 }
1478 if aux.len() != self.metadata.aux_names().len() {
1479 return Err(LadduError::Custom(format!(
1480 "Expected {} aux values, got {}",
1481 self.metadata.aux_names().len(),
1482 aux.len()
1483 )));
1484 }
1485
1486 #[cfg(feature = "mpi")]
1487 {
1488 if let Some(world) = crate::mpi::get_world() {
1489 if self.mpi_layout != Some(MpiDatasetLayout::RoundRobin) && self.n_events() > 0 {
1490 return Err(LadduError::Custom(
1491 "Cannot push round-robin global events into a non-empty local/canonical dataset"
1492 .to_string(),
1493 ));
1494 }
1495 self.mpi_layout = Some(MpiDatasetLayout::RoundRobin);
1496 let global_index = self.n_events();
1497 if global_index % world.size() as usize == world.rank() as usize {
1498 self.push_event_data_local(Arc::new(EventData { p4s, aux, weight }))?;
1499 } else {
1500 self.refresh_local_weight_cache();
1501 }
1502 return Ok(());
1503 }
1504 }
1505
1506 self.push_event_data_local(Arc::new(EventData { p4s, aux, weight }))
1507 }
1508
1509 pub fn push_event_named_local<P, PN, A, AN>(
1514 &mut self,
1515 p4s: P,
1516 aux: A,
1517 weight: f64,
1518 ) -> LadduResult<()>
1519 where
1520 P: IntoIterator<Item = (PN, Vec4)>,
1521 PN: AsRef<str>,
1522 A: IntoIterator<Item = (AN, f64)>,
1523 AN: AsRef<str>,
1524 {
1525 let mut ordered_p4s = vec![None; self.metadata.p4_names().len()];
1526 for (name, p4) in p4s {
1527 let name = name.as_ref();
1528 let index = self
1529 .metadata
1530 .p4_index(name)
1531 .ok_or_else(|| LadduError::UnknownName {
1532 category: "p4",
1533 name: name.to_string(),
1534 })?;
1535 if ordered_p4s[index].replace(p4).is_some() {
1536 return Err(LadduError::DuplicateName {
1537 category: "p4",
1538 name: name.to_string(),
1539 });
1540 }
1541 }
1542 let mut ordered_aux = vec![None; self.metadata.aux_names().len()];
1543 for (name, value) in aux {
1544 let name = name.as_ref();
1545 let index = self
1546 .metadata
1547 .aux_index(name)
1548 .ok_or_else(|| LadduError::UnknownName {
1549 category: "aux",
1550 name: name.to_string(),
1551 })?;
1552 if ordered_aux[index].replace(value).is_some() {
1553 return Err(LadduError::DuplicateName {
1554 category: "aux",
1555 name: name.to_string(),
1556 });
1557 }
1558 }
1559
1560 let p4s = ordered_p4s
1561 .into_iter()
1562 .enumerate()
1563 .map(|(index, value)| {
1564 value.ok_or_else(|| {
1565 LadduError::Custom(format!(
1566 "Missing p4 value for '{}'",
1567 self.metadata.p4_names()[index]
1568 ))
1569 })
1570 })
1571 .collect::<LadduResult<Vec<_>>>()?;
1572 let aux = ordered_aux
1573 .into_iter()
1574 .enumerate()
1575 .map(|(index, value)| {
1576 value.ok_or_else(|| {
1577 LadduError::Custom(format!(
1578 "Missing aux value for '{}'",
1579 self.metadata.aux_names()[index]
1580 ))
1581 })
1582 })
1583 .collect::<LadduResult<Vec<_>>>()?;
1584
1585 self.push_event_local(p4s, aux, weight)
1586 }
1587
1588 pub fn push_event_named_global<P, PN, A, AN>(
1594 &mut self,
1595 p4s: P,
1596 aux: A,
1597 weight: f64,
1598 ) -> LadduResult<()>
1599 where
1600 P: IntoIterator<Item = (PN, Vec4)>,
1601 PN: AsRef<str>,
1602 A: IntoIterator<Item = (AN, f64)>,
1603 AN: AsRef<str>,
1604 {
1605 let mut ordered_p4s = vec![None; self.metadata.p4_names().len()];
1606 for (name, p4) in p4s {
1607 let name = name.as_ref();
1608 let index = self
1609 .metadata
1610 .p4_index(name)
1611 .ok_or_else(|| LadduError::UnknownName {
1612 category: "p4",
1613 name: name.to_string(),
1614 })?;
1615 if ordered_p4s[index].replace(p4).is_some() {
1616 return Err(LadduError::DuplicateName {
1617 category: "p4",
1618 name: name.to_string(),
1619 });
1620 }
1621 }
1622 let mut ordered_aux = vec![None; self.metadata.aux_names().len()];
1623 for (name, value) in aux {
1624 let name = name.as_ref();
1625 let index = self
1626 .metadata
1627 .aux_index(name)
1628 .ok_or_else(|| LadduError::UnknownName {
1629 category: "aux",
1630 name: name.to_string(),
1631 })?;
1632 if ordered_aux[index].replace(value).is_some() {
1633 return Err(LadduError::DuplicateName {
1634 category: "aux",
1635 name: name.to_string(),
1636 });
1637 }
1638 }
1639
1640 let p4s = ordered_p4s
1641 .into_iter()
1642 .enumerate()
1643 .map(|(index, value)| {
1644 value.ok_or_else(|| {
1645 LadduError::Custom(format!(
1646 "Missing p4 value for '{}'",
1647 self.metadata.p4_names()[index]
1648 ))
1649 })
1650 })
1651 .collect::<LadduResult<Vec<_>>>()?;
1652 let aux = ordered_aux
1653 .into_iter()
1654 .enumerate()
1655 .map(|(index, value)| {
1656 value.ok_or_else(|| {
1657 LadduError::Custom(format!(
1658 "Missing aux value for '{}'",
1659 self.metadata.aux_names()[index]
1660 ))
1661 })
1662 })
1663 .collect::<LadduResult<Vec<_>>>()?;
1664
1665 self.push_event_global(p4s, aux, weight)
1666 }
1667
1668 pub fn n_events_local(&self) -> usize {
1675 self.rows.len(self.columnar.n_events())
1676 }
1677
    #[cfg(feature = "mpi")]
    /// Returns the cached global event count.
    ///
    /// The `_world` argument is not used; the value comes from `cached_global_event_count`.
    pub fn n_events_mpi(&self, _world: &SimpleCommunicator) -> usize {
        self.cached_global_event_count
    }
1688
1689 pub fn n_events(&self) -> usize {
1691 #[cfg(feature = "mpi")]
1692 {
1693 if self.mpi_layout.is_some() {
1694 if let Some(world) = crate::mpi::get_world() {
1695 return self.n_events_mpi(&world);
1696 }
1697 }
1698 }
1699 self.n_events_local()
1700 }
1701
    /// Alias for [`Dataset::n_events`].
    pub fn n_events_global(&self) -> usize {
        self.n_events()
    }
1708}
1709
1710impl Dataset {
1711 pub fn weights_local(&self) -> Vec<f64> {
1718 match &self.rows {
1719 RowSelection::Identity => self.columnar.weights.clone(),
1720 RowSelection::Indices(indices) => indices
1721 .iter()
1722 .map(|index| self.columnar.weight(*index))
1723 .collect(),
1724 }
1725 }
1726
1727 #[cfg(feature = "mpi")]
1734 pub fn weights_mpi(&self, world: &SimpleCommunicator) -> Vec<f64> {
1735 if matches!(
1736 self.mpi_layout,
1737 Some(MpiDatasetLayout::RoundRobin | MpiDatasetLayout::Derived)
1738 ) {
1739 return self.events_global().map(|event| event.weight()).collect();
1740 }
1741 let local_weights = self.weights_local();
1742 let n_events = self.n_events();
1743 let mut buffer: Vec<f64> = vec![0.0; n_events];
1744 let (counts, displs) = world.get_counts_displs(n_events);
1745 {
1746 let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
1749 world.all_gather_varcount_into(&local_weights, &mut partitioned_buffer);
1750 }
1751 buffer
1752 }
1753
1754 pub fn weights(&self) -> Vec<f64> {
1756 #[cfg(feature = "mpi")]
1757 {
1758 if self.mpi_layout.is_some() {
1759 if let Some(world) = crate::mpi::get_world() {
1760 return self.weights_mpi(&world);
1761 }
1762 }
1763 }
1764 self.weights_local()
1765 }
1766
    /// Alias for [`Dataset::weights`].
    pub fn weights_global(&self) -> Vec<f64> {
        self.weights()
    }
1773
    /// Returns the cached weighted sum for the local events (`cached_local_weighted_sum`).
    ///
    /// No summation happens here; the value is whatever was last stored in the cache.
    pub fn n_events_weighted_local(&self) -> f64 {
        self.cached_local_weighted_sum
    }
    #[cfg(feature = "mpi")]
    /// Returns the cached global weighted sum.
    ///
    /// The `_world` argument is not used; the value comes from `cached_global_weighted_sum`.
    pub fn n_events_weighted_mpi(&self, _world: &SimpleCommunicator) -> f64 {
        self.cached_global_weighted_sum
    }
1793
1794 pub fn n_events_weighted(&self) -> f64 {
1796 #[cfg(feature = "mpi")]
1797 {
1798 if self.mpi_layout.is_some() {
1799 if let Some(world) = crate::mpi::get_world() {
1800 return self.n_events_weighted_mpi(&world);
1801 }
1802 }
1803 }
1804 self.n_events_weighted_local()
1805 }
1806
    /// Alias for [`Dataset::n_events_weighted`].
    pub fn n_events_weighted_global(&self) -> f64 {
        self.n_events_weighted()
    }
1813
1814 pub fn bootstrap_local(&self, seed: usize) -> Arc<Dataset> {
1822 let mut rng = fastrand::Rng::with_seed(seed as u64);
1823 let n_events = self.n_events_local();
1824 let mut indices: Vec<usize> = (0..n_events)
1825 .map(|_| rng.usize(0..n_events))
1826 .collect::<Vec<usize>>();
1827 indices.sort();
1828 self.indexed_local_view(
1829 indices
1830 .into_iter()
1831 .map(|index| self.rows.physical_index(index)),
1832 )
1833 }
1834
1835 #[cfg(feature = "mpi")]
1843 pub fn bootstrap_mpi(&self, seed: usize, world: &SimpleCommunicator) -> Arc<Dataset> {
1844 let n_events = self.n_events();
1845 let mut indices: Vec<usize> = vec![0; n_events];
1846 if world.is_root() {
1847 let mut rng = fastrand::Rng::with_seed(seed as u64);
1848 indices = (0..n_events)
1849 .map(|_| rng.usize(0..n_events))
1850 .collect::<Vec<usize>>();
1851 indices.sort();
1852 }
1853 world.process_at_root().broadcast_into(&mut indices);
1854 let local_indices: Vec<usize> = indices
1855 .into_iter()
1856 .filter_map(|idx| {
1857 let (owning_rank, local_index) = world.owner_of_global_index(idx, n_events);
1858 if world.rank() == owning_rank {
1859 Some(local_index)
1860 } else {
1861 None
1862 }
1863 })
1864 .collect();
1865 self.indexed_local_view(
1866 local_indices
1867 .into_iter()
1868 .map(|index| self.rows.physical_index(index)),
1869 )
1870 }
1871
1872 pub fn bootstrap(&self, seed: usize) -> Arc<Dataset> {
1875 #[cfg(feature = "mpi")]
1876 {
1877 if let Some(world) = crate::mpi::get_world() {
1878 return self.bootstrap_mpi(seed, &world);
1879 }
1880 }
1881 self.bootstrap_local(seed)
1882 }
1883
1884 pub fn filter(&self, expression: &VariableExpression) -> LadduResult<Arc<Dataset>> {
1887 let compiled = expression.compile(&self.metadata)?;
1888 #[cfg(feature = "rayon")]
1889 let filtered_indices: Vec<usize> = (0..self.n_events_local())
1890 .into_par_iter()
1891 .filter_map(|event_index| {
1892 let event = self.event_view(event_index);
1893 compiled
1894 .evaluate(&event)
1895 .then(|| self.rows.physical_index(event_index))
1896 })
1897 .collect();
1898 #[cfg(not(feature = "rayon"))]
1899 let filtered_indices: Vec<usize> = (0..self.n_events_local())
1900 .into_iter()
1901 .filter_map(|event_index| {
1902 let event = self.event_view(event_index);
1903 compiled
1904 .evaluate(&event)
1905 .then(|| self.rows.physical_index(event_index))
1906 })
1907 .collect();
1908 Ok(self.indexed_local_view(filtered_indices))
1909 }
1910
1911 pub fn bin_by<V>(
1914 &self,
1915 mut variable: V,
1916 bins: usize,
1917 range: (f64, f64),
1918 ) -> LadduResult<BinnedDataset>
1919 where
1920 V: Variable,
1921 {
1922 variable.bind(self.metadata())?;
1923 let bin_width = (range.1 - range.0) / bins as f64;
1924 let bin_edges = get_bin_edges(bins, range);
1925 let variable = variable;
1926 #[cfg(feature = "rayon")]
1927 let evaluated: Vec<(usize, usize)> = (0..self.n_events_local())
1928 .into_par_iter()
1929 .filter_map(|event| {
1930 let value = variable.value(&self.event_view(event));
1931 if value >= range.0 && value < range.1 {
1932 let bin_index = ((value - range.0) / bin_width) as usize;
1933 let bin_index = bin_index.min(bins - 1);
1934 Some((bin_index, self.rows.physical_index(event)))
1935 } else {
1936 None
1937 }
1938 })
1939 .collect();
1940 #[cfg(not(feature = "rayon"))]
1941 let evaluated: Vec<(usize, usize)> = (0..self.n_events_local())
1942 .into_iter()
1943 .filter_map(|event| {
1944 let value = variable.value(&self.event_view(event));
1945 if value >= range.0 && value < range.1 {
1946 let bin_index = ((value - range.0) / bin_width) as usize;
1947 let bin_index = bin_index.min(bins - 1);
1948 Some((bin_index, self.rows.physical_index(event)))
1949 } else {
1950 None
1951 }
1952 })
1953 .collect();
1954 let mut binned_indices: Vec<Vec<usize>> = vec![Vec::default(); bins];
1955 for (bin_index, index) in evaluated {
1956 binned_indices[bin_index].push(index);
1957 }
1958 #[cfg(feature = "rayon")]
1959 let datasets: Vec<Arc<Dataset>> = binned_indices
1960 .into_par_iter()
1961 .map(|indices| self.indexed_local_view(indices))
1962 .collect();
1963 #[cfg(not(feature = "rayon"))]
1964 let datasets: Vec<Arc<Dataset>> = binned_indices
1965 .into_iter()
1966 .map(|indices| self.indexed_local_view(indices))
1967 .collect();
1968 Ok(BinnedDataset {
1969 datasets,
1970 edges: bin_edges,
1971 })
1972 }
1973
1974 pub fn boost_to_rest_frame_of<S>(&self, names: &[S]) -> Arc<Dataset>
1977 where
1978 S: AsRef<str>,
1979 {
1980 let mut indices: Vec<usize> = Vec::new();
1981 for name in names {
1982 let name_ref = name.as_ref();
1983 if let Some(selection) = self.metadata.p4_selection(name_ref) {
1984 indices.extend_from_slice(selection.indices());
1985 } else {
1986 panic!("Unknown particle name '{name}'", name = name_ref);
1987 }
1988 }
1989 #[cfg(feature = "rayon")]
1990 let boosted_events: Vec<Arc<EventData>> = self
1991 .local_event_data_arcs()
1992 .into_par_iter()
1993 .map(|event| Arc::new(event.boost_to_rest_frame_of(&indices)))
1994 .collect();
1995 #[cfg(not(feature = "rayon"))]
1996 let boosted_events: Vec<Arc<EventData>> = self
1997 .local_event_data_arcs()
1998 .into_iter()
1999 .map(|event| Arc::new(event.boost_to_rest_frame_of(&indices)))
2000 .collect();
2001 Arc::new(Dataset::new_with_metadata(
2002 boosted_events,
2003 self.metadata.clone(),
2004 ))
2005 }
    /// Evaluates `variable` on this dataset by delegating to `Variable::value_on`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `value_on`.
    pub fn evaluate<V: Variable>(&self, variable: &V) -> LadduResult<Vec<f64>> {
        variable.value_on(self)
    }
2010}
2011
2012#[cfg(test)]
2013pub(crate) use super::io::write_parquet_storage;
2014pub use super::io::{
2015 read_parquet, read_parquet_chunks, read_parquet_chunks_with_options, read_root, write_parquet,
2016 write_root,
2017};
2018#[cfg(test)]
2019pub(crate) use super::io::{read_parquet_storage, read_root_storage};
2020
// Adds two datasets by concatenating their local events (`a` first, then `b`) and building a
// new dataset with `a`'s metadata through `Dataset::new_with_metadata`.
//
// The p4 and aux names are compared with `debug_assert_eq!` only, so mismatched metadata is
// not detected in builds without debug assertions.
impl_op_ex!(+ |a: &Dataset, b: &Dataset| -> Dataset {
    debug_assert_eq!(a.metadata.p4_names, b.metadata.p4_names);
    debug_assert_eq!(a.metadata.aux_names, b.metadata.aux_names);
    let events = a
        .local_event_data_arcs()
        .into_iter()
        .chain(b.local_event_data_arcs())
        .collect::<Vec<_>>();
    Dataset::new_with_metadata(events, a.metadata.clone())
});
2031
/// Accumulates the events of several dataset chunks into a single [`Dataset`].
#[derive(Default)]
pub struct DatasetChunkBuilder {
    // Metadata of the first chunk pushed; `None` until a chunk has been pushed.
    metadata: Option<Arc<DatasetMetadata>>,
    // Events collected from every pushed chunk, in push order.
    events: Vec<Arc<EventData>>,
}
2038
2039impl DatasetChunkBuilder {
2040 pub fn new() -> Self {
2042 Self::default()
2043 }
2044
2045 pub fn push_chunk(&mut self, chunk: &Dataset) -> LadduResult<()> {
2047 if let Some(existing) = &self.metadata {
2048 if existing.p4_names != chunk.metadata.p4_names
2049 || existing.aux_names != chunk.metadata.aux_names
2050 {
2051 return Err(LadduError::Custom(
2052 "Dataset chunk metadata does not match previous chunks".to_string(),
2053 ));
2054 }
2055 } else {
2056 self.metadata = Some(chunk.metadata.clone());
2057 }
2058 self.events.extend(chunk.local_event_data_arcs());
2059 Ok(())
2060 }
2061
2062 pub fn finish(self) -> Arc<Dataset> {
2064 let metadata = self
2065 .metadata
2066 .unwrap_or_else(|| Arc::new(DatasetMetadata::empty()));
2067 Arc::new(Dataset::new_with_metadata(self.events, metadata))
2068 }
2069}
2070
2071pub fn try_fold_dataset_chunks<I, T, F>(chunks: I, init: T, mut op: F) -> LadduResult<T>
2073where
2074 I: IntoIterator<Item = LadduResult<Arc<Dataset>>>,
2075 F: FnMut(T, &Dataset) -> LadduResult<T>,
2076{
2077 let mut acc = init;
2078 for chunk in chunks {
2079 let chunk = chunk?;
2080 acc = op(acc, &chunk)?;
2081 }
2082 Ok(acc)
2083}
2084
/// Options that control how a dataset is read.
#[derive(Default, Clone)]
pub struct DatasetReadOptions {
    /// Explicit p4 names; when `None`, `resolve_metadata` uses the detected names instead.
    pub p4_names: Option<Vec<String>>,
    /// Explicit aux names; when `None`, `resolve_metadata` uses the detected names instead.
    pub aux_names: Option<Vec<String>>,
    /// Optional tree name.
    // NOTE(review): presumably the name of the ROOT tree to read — confirm against the readers.
    pub tree: Option<String>,
    /// Named p4 selections that `resolve_metadata` registers on the metadata as aliases.
    pub aliases: IndexMap<String, P4Selection>,
    /// Optional chunk size; the `chunk_size` builder method never stores a value below 1.
    // NOTE(review): presumably the number of events per chunk in chunked reads — confirm.
    pub chunk_size: Option<usize>,
}
2103
/// Floating-point precision selector, used as the `precision` field of
/// [`DatasetWriteOptions`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum FloatPrecision {
    /// 32-bit precision.
    F32,
    /// 64-bit precision (the default).
    #[default]
    F64,
}
2113
/// Options that control how a dataset is written.
#[derive(Clone, Debug)]
pub struct DatasetWriteOptions {
    /// Batch size; defaults to `DEFAULT_WRITE_BATCH_SIZE`.
    // NOTE(review): presumably the number of events written per batch — confirm in the writers.
    pub batch_size: usize,
    /// Floating-point precision; defaults to `FloatPrecision::default()`.
    pub precision: FloatPrecision,
    /// Optional tree name; defaults to `None`.
    // NOTE(review): presumably the name of the ROOT tree to write — confirm in the writers.
    pub tree: Option<String>,
}
2124
/// Default write options: `DEFAULT_WRITE_BATCH_SIZE` as the batch size, the default
/// [`FloatPrecision`], and no tree name.
impl Default for DatasetWriteOptions {
    fn default() -> Self {
        Self {
            batch_size: DEFAULT_WRITE_BATCH_SIZE,
            precision: FloatPrecision::default(),
            tree: None,
        }
    }
}
2134
2135impl DatasetWriteOptions {
2136 pub fn batch_size(mut self, batch_size: usize) -> Self {
2138 self.batch_size = batch_size;
2139 self
2140 }
2141
2142 pub fn precision(mut self, precision: FloatPrecision) -> Self {
2144 self.precision = precision;
2145 self
2146 }
2147
2148 pub fn tree<S: Into<String>>(mut self, name: S) -> Self {
2150 self.tree = Some(name.into());
2151 self
2152 }
2153}
2154impl DatasetReadOptions {
2155 pub fn new() -> Self {
2157 Self::default()
2158 }
2159
2160 pub fn p4_names<I, S>(mut self, names: I) -> Self
2163 where
2164 I: IntoIterator<Item = S>,
2165 S: AsRef<str>,
2166 {
2167 self.p4_names = Some(names.into_iter().map(|s| s.as_ref().to_string()).collect());
2168 self
2169 }
2170
2171 pub fn aux_names<I, S>(mut self, names: I) -> Self
2175 where
2176 I: IntoIterator<Item = S>,
2177 S: AsRef<str>,
2178 {
2179 self.aux_names = Some(names.into_iter().map(|s| s.as_ref().to_string()).collect());
2180 self
2181 }
2182
2183 pub fn tree<S>(mut self, name: S) -> Self
2185 where
2186 S: AsRef<str>,
2187 {
2188 self.tree = Some(name.as_ref().to_string());
2189 self
2190 }
2191
2192 pub fn alias<N, S>(mut self, name: N, selection: S) -> Self
2194 where
2195 N: Into<String>,
2196 S: IntoP4Selection,
2197 {
2198 self.aliases.insert(name.into(), selection.into_selection());
2199 self
2200 }
2201
2202 pub fn aliases<I, N, S>(mut self, aliases: I) -> Self
2204 where
2205 I: IntoIterator<Item = (N, S)>,
2206 N: Into<String>,
2207 S: IntoP4Selection,
2208 {
2209 for (name, selection) in aliases {
2210 self = self.alias(name, selection);
2211 }
2212 self
2213 }
2214
2215 pub fn chunk_size(mut self, chunk_size: usize) -> Self {
2217 self.chunk_size = Some(chunk_size.max(1));
2218 self
2219 }
2220
2221 pub(crate) fn resolve_metadata(
2222 &self,
2223 detected_p4_names: Vec<String>,
2224 detected_aux_names: Vec<String>,
2225 ) -> LadduResult<Arc<DatasetMetadata>> {
2226 let p4_names_vec = self.p4_names.clone().unwrap_or(detected_p4_names);
2227 let aux_names_vec = self.aux_names.clone().unwrap_or(detected_aux_names);
2228
2229 let mut metadata = DatasetMetadata::new(p4_names_vec, aux_names_vec)?;
2230 if !self.aliases.is_empty() {
2231 metadata.add_p4_aliases(self.aliases.clone())?;
2232 }
2233 Ok(Arc::new(metadata))
2234 }
2235}
2236
// Batch size used by `DatasetWriteOptions::default()`.
const DEFAULT_WRITE_BATCH_SIZE: usize = 10_000;
// NOTE(review): presumably the chunk size readers fall back to when
// `DatasetReadOptions::chunk_size` is `None` — its use is not visible here; confirm.
pub(crate) const DEFAULT_READ_CHUNK_SIZE: usize = 10_000;
2239
/// A dataset split into bins, as produced by `Dataset::bin_by`.
pub struct BinnedDataset {
    // One dataset per bin, in bin order.
    datasets: Vec<Arc<Dataset>>,
    // Bin edges as returned by `get_bin_edges` in `Dataset::bin_by`.
    edges: Vec<f64>,
}
2245
/// Indexes the per-bin datasets by bin number.
///
/// Indexing panics when `index` is out of bounds for the list of datasets.
impl Index<usize> for BinnedDataset {
    type Output = Arc<Dataset>;

    fn index(&self, index: usize) -> &Self::Output {
        &self.datasets[index]
    }
}
2253
/// Mutably indexes the per-bin datasets by bin number.
///
/// Indexing panics when `index` is out of bounds for the list of datasets.
impl IndexMut<usize> for BinnedDataset {
    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
        &mut self.datasets[index]
    }
}
2259
/// Dereferences to the underlying list of per-bin datasets.
impl Deref for BinnedDataset {
    type Target = Vec<Arc<Dataset>>;

    fn deref(&self) -> &Self::Target {
        &self.datasets
    }
}
2267
/// Mutably dereferences to the underlying list of per-bin datasets.
///
/// Changing the length of the list through this handle does not update the stored edges.
impl DerefMut for BinnedDataset {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.datasets
    }
}
2273
2274impl BinnedDataset {
2275 pub fn n_bins(&self) -> usize {
2277 self.datasets.len()
2278 }
2279
2280 pub fn edges(&self) -> Vec<f64> {
2282 self.edges.clone()
2283 }
2284
2285 pub fn range(&self) -> (f64, f64) {
2287 (self.edges[0], self.edges[self.n_bins()])
2288 }
2289}