1#[cfg(feature = "mpi")]
2use crate::mpi::LadduMPI;
3use accurate::{sum::Klein, traits::*};
4use auto_ops::impl_op_ex;
5#[cfg(feature = "mpi")]
6use mpi::{datatype::PartitionMut, topology::SimpleCommunicator, traits::*};
7#[cfg(feature = "rayon")]
8use rayon::prelude::*;
9use serde::{Deserialize, Serialize};
10use std::{
11 fmt::Display,
12 ops::{Deref, DerefMut, Index, IndexMut},
13 sync::Arc,
14};
15
/// Alias for `SimpleCommunicator` when the `mpi` feature is enabled.
#[cfg(feature = "mpi")]
type WorldHandle = SimpleCommunicator;
/// Unit placeholder used for the same alias when the `mpi` feature is disabled.
#[cfg(not(feature = "mpi"))]
type WorldHandle = ();
20
/// Chunk size used by `resolve_mpi_event_fetch_chunk_size` when the environment
/// variable named by `MPI_EVENT_FETCH_CHUNK_SIZE_ENV` is unset or does not parse
/// as a `usize`.
#[cfg(feature = "mpi")]
const DEFAULT_MPI_EVENT_FETCH_CHUNK_SIZE: usize = 512;
/// Name of the environment variable read by `resolve_mpi_event_fetch_chunk_size`
/// to override the default chunk size.
#[cfg(feature = "mpi")]
const MPI_EVENT_FETCH_CHUNK_SIZE_ENV: &str = "LADDU_MPI_EVENT_FETCH_CHUNK_SIZE";
27
28use crate::utils::get_bin_edges;
29use crate::{
30 utils::{
31 variables::{IntoP4Selection, P4Selection, Variable, VariableExpression},
32 vectors::Vec4,
33 },
34 LadduError, LadduResult,
35};
36use indexmap::{IndexMap, IndexSet};
37
38pub mod io;
40
41pub fn test_event() -> EventData {
45 use crate::utils::vectors::*;
46 let pol_magnitude = 0.38562805;
47 let pol_angle = 0.05708078;
48 EventData {
49 p4s: vec![
50 Vec3::new(0.0, 0.0, 8.747).with_mass(0.0), Vec3::new(0.119, 0.374, 0.222).with_mass(1.007), Vec3::new(-0.112, 0.293, 3.081).with_mass(0.498), Vec3::new(-0.007, -0.667, 5.446).with_mass(0.498), ],
55 aux: vec![pol_magnitude, pol_angle],
56 weight: 0.48,
57 }
58}
59
/// Names paired with the four-momenta of [`test_event`] by [`test_dataset`].
pub const TEST_P4_NAMES: &[&str] = &["beam", "proton", "kshort1", "kshort2"];
/// Names paired with the auxiliary values of [`test_event`] by [`test_dataset`].
pub const TEST_AUX_NAMES: &[&str] = &["pol_magnitude", "pol_angle"];
64
65pub fn test_dataset() -> Dataset {
69 let metadata = Arc::new(
70 DatasetMetadata::new(
71 TEST_P4_NAMES.iter().map(|s| (*s).to_string()).collect(),
72 TEST_AUX_NAMES.iter().map(|s| (*s).to_string()).collect(),
73 )
74 .expect("Test metadata should be valid"),
75 );
76 Dataset::new_with_metadata(vec![Arc::new(test_event())], metadata)
77}
78
/// Raw data for a single event: four-momenta, auxiliary scalars and a weight.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct EventData {
    /// Four-momenta of the event, addressed by position.
    pub p4s: Vec<Vec4>,
    /// Auxiliary scalar values of the event, addressed by position.
    pub aux: Vec<f64>,
    /// Weight of the event.
    pub weight: f64,
}
93
94impl Display for EventData {
95 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
96 writeln!(f, "Event:")?;
97 writeln!(f, " p4s:")?;
98 for p4 in &self.p4s {
99 writeln!(f, " {}", p4.to_p4_string())?;
100 }
101 writeln!(f, " aux:")?;
102 for (idx, value) in self.aux.iter().enumerate() {
103 writeln!(f, " aux[{idx}]: {value}")?;
104 }
105 writeln!(f, " weight:")?;
106 writeln!(f, " {}", self.weight)?;
107 Ok(())
108 }
109}
110
111impl EventData {
112 pub fn get_p4_sum<T: AsRef<[usize]>>(&self, indices: T) -> Vec4 {
114 indices.as_ref().iter().map(|i| self.p4s[*i]).sum::<Vec4>()
115 }
116 pub fn boost_to_rest_frame_of<T: AsRef<[usize]>>(&self, indices: T) -> Self {
119 let frame = self.get_p4_sum(indices);
120 EventData {
121 p4s: self
122 .p4s
123 .iter()
124 .map(|p4| p4.boost(&(-frame.beta())))
125 .collect(),
126 aux: self.aux.clone(),
127 weight: self.weight,
128 }
129 }
130}
131
/// Column-oriented storage for one four-momentum slot: each component is kept
/// in its own vector, with one entry per pushed four-momentum.
#[allow(dead_code)]
#[derive(Debug, Clone, Default)]
struct ColumnarP4Column {
    /// `x` components of the pushed four-momenta.
    px: Vec<f64>,
    /// `y` components of the pushed four-momenta.
    py: Vec<f64>,
    /// `z` components of the pushed four-momenta.
    pz: Vec<f64>,
    /// `t` components of the pushed four-momenta.
    e: Vec<f64>,
}
140
141#[allow(dead_code)]
142impl ColumnarP4Column {
143 fn with_capacity(capacity: usize) -> Self {
144 Self {
145 px: Vec::with_capacity(capacity),
146 py: Vec::with_capacity(capacity),
147 pz: Vec::with_capacity(capacity),
148 e: Vec::with_capacity(capacity),
149 }
150 }
151
152 fn push(&mut self, p4: Vec4) {
153 self.px.push(p4.x);
154 self.py.push(p4.y);
155 self.pz.push(p4.z);
156 self.e.push(p4.t);
157 }
158
159 fn get(&self, event_index: usize) -> Vec4 {
160 Vec4::new(
161 self.px[event_index],
162 self.py[event_index],
163 self.pz[event_index],
164 self.e[event_index],
165 )
166 }
167}
168
169#[derive(Debug, Default)]
171pub(crate) struct DatasetStorage {
172 metadata: Arc<DatasetMetadata>,
173 p4: Vec<ColumnarP4Column>,
174 aux: Vec<Vec<f64>>,
175 weights: Vec<f64>,
176}
177
178impl Clone for DatasetStorage {
179 fn clone(&self) -> Self {
180 Self {
181 metadata: self.metadata.clone(),
182 p4: self.p4.clone(),
183 aux: self.aux.clone(),
184 weights: self.weights.clone(),
185 }
186 }
187}
188
189impl DatasetStorage {
190 pub(crate) fn to_dataset(&self) -> Dataset {
192 let events = (0..self.n_events())
193 .map(|event_index| Arc::new(self.event_data(event_index)))
194 .collect::<Vec<_>>();
195 #[cfg(not(feature = "mpi"))]
196 let dataset = Dataset::new_local(events, self.metadata.clone());
197 #[cfg(feature = "mpi")]
198 let mut dataset = Dataset::new_local(events, self.metadata.clone());
199 #[cfg(feature = "mpi")]
200 {
201 if let Some(world) = crate::mpi::get_world() {
202 dataset.set_cached_global_event_count_from_world(&world);
203 dataset.set_cached_global_weighted_sum_from_world(&world);
204 }
205 }
206 dataset
207 }
208
209 pub(crate) fn metadata(&self) -> &DatasetMetadata {
211 &self.metadata
212 }
213
214 pub(crate) fn n_events(&self) -> usize {
216 self.weights.len()
217 }
218
219 pub(crate) fn p4(&self, event_index: usize, p4_index: usize) -> Vec4 {
221 self.p4[p4_index].get(event_index)
222 }
223
224 pub(crate) fn aux(&self, event_index: usize, aux_index: usize) -> f64 {
226 self.aux[aux_index][event_index]
227 }
228
229 pub(crate) fn weight(&self, event_index: usize) -> f64 {
231 self.weights[event_index]
232 }
233
234 pub(crate) fn event_data(&self, event_index: usize) -> EventData {
235 let mut p4s = Vec::with_capacity(self.p4.len());
236 for p4_index in 0..self.p4.len() {
237 p4s.push(self.p4(event_index, p4_index));
238 }
239 let mut aux = Vec::with_capacity(self.aux.len());
240 for aux_index in 0..self.aux.len() {
241 aux.push(self.aux(event_index, aux_index));
242 }
243 EventData {
244 p4s,
245 aux,
246 weight: self.weight(event_index),
247 }
248 }
249
250 fn row_view(&self, event_index: usize) -> ColumnarEventView<'_> {
251 ColumnarEventView {
252 storage: self,
253 event_index,
254 }
255 }
256
257 #[allow(dead_code)]
258 pub(crate) fn for_each_named_event_local<F>(&self, mut op: F)
259 where
260 F: FnMut(usize, NamedEventView<'_>),
261 {
262 for event_index in 0..self.n_events() {
263 let row = self.row_view(event_index);
264 let view = NamedEventView {
265 row,
266 metadata: &self.metadata,
267 };
268 op(event_index, view);
269 }
270 }
271
272 pub(crate) fn event_view(&self, event_index: usize) -> NamedEventView<'_> {
273 let row = self.row_view(event_index);
274 NamedEventView {
275 row,
276 metadata: self.metadata(),
277 }
278 }
279}
280
/// Borrowed view of a single row (event) of a [`DatasetStorage`], addressed by
/// column index.
#[allow(dead_code)]
#[derive(Debug)]
struct ColumnarEventView<'a> {
    /// Storage the row is read from.
    storage: &'a DatasetStorage,
    /// Index of the event within `storage`.
    event_index: usize,
}
287
288#[allow(dead_code)]
289impl ColumnarEventView<'_> {
290 fn p4(&self, p4_index: usize) -> Vec4 {
291 self.storage.p4(self.event_index, p4_index)
292 }
293
294 fn aux(&self, aux_index: usize) -> f64 {
295 self.storage.aux(self.event_index, aux_index)
296 }
297
298 fn weight(&self) -> f64 {
299 self.storage.weight(self.event_index)
300 }
301
302 fn get_p4_sum<T: AsRef<[usize]>>(&self, indices: T) -> Vec4 {
303 indices.as_ref().iter().map(|index| self.p4(*index)).sum()
304 }
305}
306
/// Borrowed view of a single stored event that can be queried either by column
/// index or by the names registered in the dataset metadata.
#[derive(Debug)]
pub struct NamedEventView<'a> {
    /// Index-addressed view of the underlying row.
    row: ColumnarEventView<'a>,
    /// Metadata used to resolve names to indices and selections.
    metadata: &'a DatasetMetadata,
}
313
314impl NamedEventView<'_> {
315 pub fn p4_at(&self, p4_index: usize) -> Vec4 {
317 self.row.p4(p4_index)
318 }
319
320 pub fn aux_at(&self, aux_index: usize) -> f64 {
322 self.row.aux(aux_index)
323 }
324
325 pub fn n_p4(&self) -> usize {
327 self.row.storage.p4.len()
328 }
329
330 pub fn n_aux(&self) -> usize {
332 self.row.storage.aux.len()
333 }
334
335 pub fn p4(&self, name: &str) -> Option<Vec4> {
337 let selection = self.metadata.p4_selection(name)?;
338 Some(
339 selection
340 .indices()
341 .iter()
342 .map(|index| self.row.p4(*index))
343 .sum(),
344 )
345 }
346
347 pub fn aux(&self, name: &str) -> Option<f64> {
349 let index = self.metadata.aux_index(name)?;
350 Some(self.row.aux(index))
351 }
352
353 pub fn weight(&self) -> f64 {
355 self.row.weight()
356 }
357
358 pub fn get_p4_sum<N>(&self, names: N) -> Option<Vec4>
360 where
361 N: IntoIterator,
362 N::Item: AsRef<str>,
363 {
364 names
365 .into_iter()
366 .map(|name| self.p4(name.as_ref()))
367 .collect::<Option<Vec<_>>>()
368 .map(|momenta| momenta.into_iter().sum())
369 }
370
371 pub fn evaluate<V: Variable>(&self, variable: &V) -> f64 {
373 variable.value(self)
374 }
375}
376
/// Names and name-to-index lookups describing the columns of a dataset.
#[derive(Debug, Clone)]
pub struct DatasetMetadata {
    /// Four-momentum names, in column order.
    pub(crate) p4_names: Vec<String>,
    /// Auxiliary value names, in column order.
    pub(crate) aux_names: Vec<String>,
    /// Maps each four-momentum name to its column index.
    pub(crate) p4_lookup: IndexMap<String, usize>,
    /// Maps each auxiliary name to its column index.
    pub(crate) aux_lookup: IndexMap<String, usize>,
    /// Maps each four-momentum name or registered alias to a selection of
    /// column indices.
    pub(crate) p4_selections: IndexMap<String, P4Selection>,
}
386
387impl DatasetMetadata {
388 pub fn new<P: Into<String>, A: Into<String>>(
390 p4_names: Vec<P>,
391 aux_names: Vec<A>,
392 ) -> LadduResult<Self> {
393 let mut p4_lookup = IndexMap::with_capacity(p4_names.len());
394 let mut aux_lookup = IndexMap::with_capacity(aux_names.len());
395 let mut p4_selections = IndexMap::with_capacity(p4_names.len());
396 let p4_names: Vec<String> = p4_names
397 .into_iter()
398 .enumerate()
399 .map(|(idx, name)| {
400 let name = name.into();
401 if p4_lookup.contains_key(&name) {
402 return Err(LadduError::DuplicateName {
403 category: "p4",
404 name,
405 });
406 }
407 p4_lookup.insert(name.clone(), idx);
408 p4_selections.insert(
409 name.clone(),
410 P4Selection::with_indices(vec![name.clone()], vec![idx]),
411 );
412 Ok(name)
413 })
414 .collect::<Result<_, _>>()?;
415 let aux_names: Vec<String> = aux_names
416 .into_iter()
417 .enumerate()
418 .map(|(idx, name)| {
419 let name = name.into();
420 if aux_lookup.contains_key(&name) {
421 return Err(LadduError::DuplicateName {
422 category: "aux",
423 name,
424 });
425 }
426 aux_lookup.insert(name.clone(), idx);
427 Ok(name)
428 })
429 .collect::<Result<_, _>>()?;
430 Ok(Self {
431 p4_names,
432 aux_names,
433 p4_lookup,
434 aux_lookup,
435 p4_selections,
436 })
437 }
438
439 pub fn empty() -> Self {
441 Self {
442 p4_names: Vec::new(),
443 aux_names: Vec::new(),
444 p4_lookup: IndexMap::new(),
445 aux_lookup: IndexMap::new(),
446 p4_selections: IndexMap::new(),
447 }
448 }
449
450 pub fn p4_index(&self, name: &str) -> Option<usize> {
452 self.p4_lookup.get(name).copied()
453 }
454
455 pub fn p4_names(&self) -> &[String] {
457 &self.p4_names
458 }
459
460 pub fn aux_index(&self, name: &str) -> Option<usize> {
462 self.aux_lookup.get(name).copied()
463 }
464
465 pub fn aux_names(&self) -> &[String] {
467 &self.aux_names
468 }
469
470 pub fn p4_selection(&self, name: &str) -> Option<&P4Selection> {
472 self.p4_selections.get(name)
473 }
474
475 pub fn add_p4_alias<N>(&mut self, alias: N, mut selection: P4Selection) -> LadduResult<()>
477 where
478 N: Into<String>,
479 {
480 let alias = alias.into();
481 if self.p4_selections.contains_key(&alias) {
482 return Err(LadduError::DuplicateName {
483 category: "alias",
484 name: alias,
485 });
486 }
487 selection.bind(self)?;
488 self.p4_selections.insert(alias, selection);
489 Ok(())
490 }
491
492 pub fn add_p4_aliases<I, N>(&mut self, entries: I) -> LadduResult<()>
494 where
495 I: IntoIterator<Item = (N, P4Selection)>,
496 N: Into<String>,
497 {
498 for (alias, selection) in entries {
499 self.add_p4_alias(alias, selection)?;
500 }
501 Ok(())
502 }
503
504 pub(crate) fn append_indices_for_name(
505 &self,
506 name: &str,
507 target: &mut Vec<usize>,
508 ) -> LadduResult<()> {
509 if let Some(selection) = self.p4_selections.get(name) {
510 target.extend_from_slice(selection.indices());
511 return Ok(());
512 }
513 Err(LadduError::UnknownName {
514 category: "p4",
515 name: name.to_string(),
516 })
517 }
518}
519
520impl Default for DatasetMetadata {
521 fn default() -> Self {
522 Self::empty()
523 }
524}
525
/// A collection of events together with shared metadata and a columnar copy of
/// the same data.
#[derive(Debug, Clone)]
pub struct Dataset {
    /// Events held by this instance, each wrapped with the shared metadata.
    events: Vec<Event>,
    /// Column-oriented copy of `events`.
    pub(crate) columnar: DatasetStorage,
    /// Metadata shared by all events of the dataset.
    pub(crate) metadata: Arc<DatasetMetadata>,
    /// Sum of the weights of the events held by this instance.
    pub(crate) cached_local_weighted_sum: f64,
    /// Event count summed across the MPI world; equals the local count until
    /// refreshed from a world.
    #[cfg(feature = "mpi")]
    pub(crate) cached_global_event_count: usize,
    /// Weighted sum combined across the MPI world; equals the local sum until
    /// refreshed from a world.
    #[cfg(feature = "mpi")]
    pub(crate) cached_global_weighted_sum: f64,
}
539
/// An event's data paired with the metadata needed to address it by name.
/// Cloning only clones the two `Arc` handles.
#[derive(Clone, Debug)]
pub struct Event {
    /// Shared handle to the raw event data.
    event: Arc<EventData>,
    /// Shared handle to the dataset metadata.
    metadata: Arc<DatasetMetadata>,
}
546
547impl Event {
548 pub fn new(event: Arc<EventData>, metadata: Arc<DatasetMetadata>) -> Self {
550 Self { event, metadata }
551 }
552
553 pub fn data(&self) -> &EventData {
555 &self.event
556 }
557
558 pub fn data_arc(&self) -> Arc<EventData> {
560 self.event.clone()
561 }
562
563 pub fn p4s(&self) -> IndexMap<&str, Vec4> {
565 let mut map = IndexMap::with_capacity(self.metadata.p4_names.len());
566 for (idx, name) in self.metadata.p4_names.iter().enumerate() {
567 if let Some(p4) = self.event.p4s.get(idx) {
568 map.insert(name.as_str(), *p4);
569 }
570 }
571 map
572 }
573
574 pub fn aux(&self) -> IndexMap<&str, f64> {
576 let mut map = IndexMap::with_capacity(self.metadata.aux_names.len());
577 for (idx, name) in self.metadata.aux_names.iter().enumerate() {
578 if let Some(value) = self.event.aux.get(idx) {
579 map.insert(name.as_str(), *value);
580 }
581 }
582 map
583 }
584
585 pub fn weight(&self) -> f64 {
587 self.event.weight
588 }
589
590 pub fn metadata(&self) -> &DatasetMetadata {
592 &self.metadata
593 }
594
595 pub fn metadata_arc(&self) -> Arc<DatasetMetadata> {
597 self.metadata.clone()
598 }
599
600 pub fn p4(&self, name: &str) -> Option<Vec4> {
602 self.metadata
603 .p4_selection(name)
604 .map(|selection| selection.momentum(&self.event))
605 }
606
607 fn resolve_p4_indices<N>(&self, names: N) -> Vec<usize>
608 where
609 N: IntoIterator,
610 N::Item: AsRef<str>,
611 {
612 let mut indices = Vec::new();
613 for name in names {
614 let name_ref = name.as_ref();
615 if let Some(selection) = self.metadata.p4_selection(name_ref) {
616 indices.extend_from_slice(selection.indices());
617 } else {
618 panic!("Unknown particle name '{name}'", name = name_ref);
619 }
620 }
621 indices
622 }
623
624 pub fn get_p4_sum<N>(&self, names: N) -> Vec4
626 where
627 N: IntoIterator,
628 N::Item: AsRef<str>,
629 {
630 let indices = self.resolve_p4_indices(names);
631 self.event.get_p4_sum(&indices)
632 }
633
634 pub fn boost_to_rest_frame_of<N>(&self, names: N) -> EventData
636 where
637 N: IntoIterator,
638 N::Item: AsRef<str>,
639 {
640 let indices = self.resolve_p4_indices(names);
641 self.event.boost_to_rest_frame_of(&indices)
642 }
643}
644
645impl Deref for Event {
646 type Target = EventData;
647
648 fn deref(&self) -> &Self::Target {
649 &self.event
650 }
651}
652
653impl AsRef<EventData> for Event {
654 fn as_ref(&self) -> &EventData {
655 self.data()
656 }
657}
658
659impl IntoIterator for Dataset {
660 type Item = Event;
661
662 type IntoIter = DatasetIntoIter;
663
664 fn into_iter(self) -> Self::IntoIter {
665 #[cfg(feature = "mpi")]
666 {
667 if let Some(world) = crate::mpi::get_world() {
668 let total = self.n_events();
670 return DatasetIntoIter::Mpi(DatasetMpiIntoIter {
671 events: self.events,
672 metadata: self.metadata,
673 world,
674 index: 0,
675 total,
676 cursor: MpiEventChunkCursor::for_iteration(total),
677 });
678 }
679 }
680 DatasetIntoIter::Local(self.events.into_iter())
681 }
682}
683
684fn shared_dataset_iter(dataset: Arc<Dataset>) -> DatasetArcIter {
685 #[cfg(feature = "mpi")]
686 {
687 if let Some(world) = crate::mpi::get_world() {
688 let total = dataset.n_events();
689 return DatasetArcIter::Mpi(DatasetArcMpiIter {
690 dataset,
691 world,
692 index: 0,
693 total,
694 cursor: MpiEventChunkCursor::for_iteration(total),
695 });
696 }
697 }
698 DatasetArcIter::Local { dataset, index: 0 }
699}
700
/// Extension trait providing iterators that return a [`DatasetArcIter`]
/// yielding owned [`Event`]s.
pub trait SharedDatasetIterExt {
    /// Returns an iterator over the dataset's events.
    fn shared_iter(&self) -> DatasetArcIter;

    /// Companion of `shared_iter`; the implementation for `Arc<Dataset>` in
    /// this file forwards to `shared_iter`.
    fn shared_iter_global(&self) -> DatasetArcIter;
}
709
710impl SharedDatasetIterExt for Arc<Dataset> {
711 fn shared_iter(&self) -> DatasetArcIter {
712 shared_dataset_iter(self.clone())
713 }
714
715 fn shared_iter_global(&self) -> DatasetArcIter {
716 self.shared_iter()
717 }
718}
719
720impl Dataset {
721 pub fn events_local(&self) -> &[Event] {
725 &self.events
726 }
727
728 pub fn events_global(&self) -> Vec<Event> {
733 self.iter_global().collect()
734 }
735
736 #[cfg(test)]
737 pub(crate) fn clear_events_local(&mut self) {
738 self.events.clear();
739 }
740
741 pub fn iter(&self) -> DatasetIter<'_> {
744 #[cfg(feature = "mpi")]
745 {
746 if let Some(world) = crate::mpi::get_world() {
747 let total = self.n_events();
748 return DatasetIter::Mpi(DatasetMpiIter {
749 dataset: self,
750 world,
751 index: 0,
752 total,
753 cursor: MpiEventChunkCursor::for_iteration(total),
754 });
755 }
756 }
757 DatasetIter::Local(self.events.iter())
758 }
759
760 pub fn iter_global(&self) -> DatasetIter<'_> {
764 self.iter()
765 }
766
767 pub fn metadata(&self) -> &DatasetMetadata {
769 &self.metadata
770 }
771
772 pub fn metadata_arc(&self) -> Arc<DatasetMetadata> {
774 self.metadata.clone()
775 }
776
777 pub fn p4_names(&self) -> &[String] {
779 &self.metadata.p4_names
780 }
781
782 pub fn aux_names(&self) -> &[String] {
784 &self.metadata.aux_names
785 }
786
787 pub fn p4_index(&self, name: &str) -> Option<usize> {
789 self.metadata.p4_index(name)
790 }
791
792 pub fn aux_index(&self, name: &str) -> Option<usize> {
794 self.metadata.aux_index(name)
795 }
796
797 pub fn named_event(&self, index: usize) -> LadduResult<Event> {
799 self.event(index)
800 }
801
802 pub fn named_event_global(&self, index: usize) -> LadduResult<Event> {
804 self.named_event(index)
805 }
806
807 pub fn get_event(&self, index: usize) -> Option<Event> {
809 #[cfg(feature = "mpi")]
810 {
811 if let Some(world) = crate::mpi::get_world() {
812 let total = self.n_events();
813 if index >= total {
814 return None;
815 }
816 return Some(fetch_event_mpi(self, index, &world, total));
817 }
818 }
819
820 self.events.get(index).cloned()
821 }
822
823 pub fn get_event_global(&self, index: usize) -> Option<Event> {
827 self.get_event(index)
828 }
829
830 pub fn event(&self, index: usize) -> LadduResult<Event> {
832 self.get_event(index).ok_or_else(|| {
833 LadduError::Custom(format!(
834 "Dataset index out of bounds: index {index}, length {}",
835 self.n_events()
836 ))
837 })
838 }
839
840 pub fn event_global(&self, index: usize) -> LadduResult<Event> {
844 self.event(index)
845 }
846
847 pub fn p4_by_name(&self, event_index: usize, name: &str) -> Option<Vec4> {
849 self.get_event(event_index).and_then(|event| event.p4(name))
850 }
851
852 pub fn aux_by_name(&self, event_index: usize, name: &str) -> Option<f64> {
854 let idx = self.aux_index(name)?;
855 self.get_event(event_index)
856 .and_then(|event| event.aux.get(idx).copied())
857 }
858
859 pub fn for_each_named_event_local<F>(&self, op: F)
861 where
862 F: FnMut(usize, NamedEventView<'_>),
863 {
864 self.columnar.for_each_named_event_local(op);
865 }
866
867 pub fn event_view(&self, event_index: usize) -> NamedEventView<'_> {
869 self.columnar.event_view(event_index)
870 }
871
872 pub fn index_local(&self, index: usize) -> &Event {
885 &self.events[index]
886 }
887
888 #[cfg(feature = "mpi")]
889 fn partition(
890 events: Vec<Arc<EventData>>,
891 world: &SimpleCommunicator,
892 ) -> Vec<Vec<Arc<EventData>>> {
893 let partition = world.partition(events.len());
894 (0..partition.n_ranks())
895 .map(|rank| {
896 let range = partition.range_for_rank(rank);
897 events[range.clone()].to_vec()
898 })
899 .collect()
900 }
901}
902
/// Borrowing iterator over a [`Dataset`], yielding owned [`Event`] handles.
pub enum DatasetIter<'a> {
    /// Iterates directly over the stored events.
    Local(std::slice::Iter<'a, Event>),
    /// Iterates through the MPI-backed iterator.
    #[cfg(feature = "mpi")]
    Mpi(DatasetMpiIter<'a>),
}
911
912impl<'a> Iterator for DatasetIter<'a> {
913 type Item = Event;
914
915 fn next(&mut self) -> Option<Self::Item> {
916 match self {
917 DatasetIter::Local(iter) => iter.next().cloned(),
918 #[cfg(feature = "mpi")]
919 DatasetIter::Mpi(iter) => iter.next(),
920 }
921 }
922}
923
/// Consuming iterator over a [`Dataset`].
pub enum DatasetIntoIter {
    /// Iterates directly over the owned events.
    Local(std::vec::IntoIter<Event>),
    /// Iterates through the MPI-backed iterator.
    #[cfg(feature = "mpi")]
    Mpi(DatasetMpiIntoIter),
}
932
933impl Iterator for DatasetIntoIter {
934 type Item = Event;
935
936 fn next(&mut self) -> Option<Self::Item> {
937 match self {
938 DatasetIntoIter::Local(iter) => iter.next(),
939 #[cfg(feature = "mpi")]
940 DatasetIntoIter::Mpi(iter) => iter.next(),
941 }
942 }
943}
944
/// Iterator over a [`Dataset`] held through an `Arc`.
pub enum DatasetArcIter {
    /// Iterates directly over the events stored in the dataset.
    Local {
        /// Shared handle to the dataset being iterated.
        dataset: Arc<Dataset>,
        /// Position of the next event to yield.
        index: usize,
    },
    /// Iterates through the MPI-backed iterator.
    #[cfg(feature = "mpi")]
    Mpi(DatasetArcMpiIter),
}
958
959impl Iterator for DatasetArcIter {
960 type Item = Event;
961
962 fn next(&mut self) -> Option<Self::Item> {
963 match self {
964 DatasetArcIter::Local { dataset, index } => {
965 let event = dataset.events.get(*index).cloned();
966 *index += 1;
967 event
968 }
969 #[cfg(feature = "mpi")]
970 DatasetArcIter::Mpi(iter) => iter.next(),
971 }
972 }
973}
974
/// MPI-backed borrowing iterator over a [`Dataset`].
#[cfg(feature = "mpi")]
pub struct DatasetMpiIter<'a> {
    /// Dataset whose events are iterated.
    dataset: &'a Dataset,
    /// Communicator used to fetch events.
    world: SimpleCommunicator,
    /// Index of the next event to yield.
    index: usize,
    /// Value of `n_events()` captured when the iterator was created.
    total: usize,
    /// Cache of the most recently fetched chunk of events.
    cursor: MpiEventChunkCursor,
}
984
/// Cache holding one contiguous chunk of fetched events.
#[cfg(feature = "mpi")]
#[derive(Debug, Clone)]
pub(crate) struct MpiEventChunkCursor {
    /// Index of the first event in `events`.
    chunk_start: usize,
    /// Number of events requested per fetch (at least 1).
    chunk_size: usize,
    /// Events of the currently cached chunk.
    events: Vec<Event>,
}
992
/// Determines the chunk size for fetching events, bounded to `1..=max(total, 1)`.
///
/// The value comes from the environment variable named by
/// `MPI_EVENT_FETCH_CHUNK_SIZE_ENV` when it is set and parses as a `usize`;
/// otherwise `DEFAULT_MPI_EVENT_FETCH_CHUNK_SIZE` is used.
#[cfg(feature = "mpi")]
fn resolve_mpi_event_fetch_chunk_size(total: usize) -> usize {
    let upper_bound = total.max(1);
    let requested = std::env::var_os(MPI_EVENT_FETCH_CHUNK_SIZE_ENV)
        .and_then(|raw| raw.to_str().and_then(|text| text.parse::<usize>().ok()));
    match requested {
        // `upper_bound` is at least 1, so the clamp range is always valid.
        Some(parsed) => parsed.clamp(1, upper_bound),
        None => DEFAULT_MPI_EVENT_FETCH_CHUNK_SIZE.min(upper_bound),
    }
}
1003
#[cfg(feature = "mpi")]
impl MpiEventChunkCursor {
    /// Creates an empty cursor whose chunk size is resolved for a dataset of
    /// `total` events by `resolve_mpi_event_fetch_chunk_size`.
    pub(crate) fn for_iteration(total: usize) -> Self {
        let chunk_size = resolve_mpi_event_fetch_chunk_size(total);
        Self::new(chunk_size)
    }
}
1010
#[cfg(feature = "mpi")]
impl MpiEventChunkCursor {
    /// Creates an empty cursor; a `chunk_size` of 0 is raised to 1.
    pub(crate) fn new(chunk_size: usize) -> Self {
        let chunk_size = chunk_size.max(1);
        Self {
            chunk_start: 0,
            chunk_size,
            events: Vec::new(),
        }
    }

    /// Index one past the last cached event.
    fn chunk_end(&self) -> usize {
        self.chunk_start + self.events.len()
    }

    /// Whether `global_index` falls inside the currently cached chunk.
    fn contains(&self, global_index: usize) -> bool {
        (self.chunk_start..self.chunk_end()).contains(&global_index)
    }

    /// Returns the event at `global_index`, refilling the cache through
    /// `fetch_event_chunk_mpi` when the index is outside the cached chunk.
    ///
    /// Returns `None` when `global_index >= total`, without touching the cache.
    pub(crate) fn event_for_dataset(
        &mut self,
        dataset: &Dataset,
        global_index: usize,
        world: &SimpleCommunicator,
        total: usize,
    ) -> Option<Event> {
        if global_index >= total {
            return None;
        }
        let cache_miss = !self.contains(global_index);
        if cache_miss {
            // A refilled chunk always begins at the requested index.
            self.chunk_start = global_index;
            self.events =
                fetch_event_chunk_mpi(dataset, global_index, self.chunk_size, world, total);
        }
        let offset = global_index - self.chunk_start;
        self.events.get(offset).cloned()
    }

    /// Same as [`MpiEventChunkCursor::event_for_dataset`], but refills the
    /// cache from a slice of events and separate metadata through
    /// `fetch_event_chunk_mpi_from_events`.
    pub(crate) fn event_for_events(
        &mut self,
        events: &[Event],
        metadata: &Arc<DatasetMetadata>,
        global_index: usize,
        world: &SimpleCommunicator,
        total: usize,
    ) -> Option<Event> {
        if global_index >= total {
            return None;
        }
        let cache_miss = !self.contains(global_index);
        if cache_miss {
            // A refilled chunk always begins at the requested index.
            self.chunk_start = global_index;
            self.events = fetch_event_chunk_mpi_from_events(
                events,
                metadata,
                global_index,
                self.chunk_size,
                world,
                total,
            );
        }
        let offset = global_index - self.chunk_start;
        self.events.get(offset).cloned()
    }
}
1072
#[cfg(feature = "mpi")]
impl<'a> Iterator for DatasetMpiIter<'a> {
    type Item = Event;

    /// Fetches the event at the current index through the chunk cursor and
    /// advances the index.
    fn next(&mut self) -> Option<Event> {
        let position = self.index;
        let fetched = self
            .cursor
            .event_for_dataset(self.dataset, position, &self.world, self.total);
        self.index = position + 1;
        fetched
    }
}
1085
/// MPI-backed iterator over a [`Dataset`] held through an `Arc`.
#[cfg(feature = "mpi")]
pub struct DatasetArcMpiIter {
    /// Shared handle to the dataset whose events are iterated.
    dataset: Arc<Dataset>,
    /// Communicator used to fetch events.
    world: SimpleCommunicator,
    /// Index of the next event to yield.
    index: usize,
    /// Value of `n_events()` captured when the iterator was created.
    total: usize,
    /// Cache of the most recently fetched chunk of events.
    cursor: MpiEventChunkCursor,
}
1095
#[cfg(feature = "mpi")]
impl Iterator for DatasetArcMpiIter {
    type Item = Event;

    /// Fetches the event at the current index through the chunk cursor and
    /// advances the index.
    fn next(&mut self) -> Option<Event> {
        let position = self.index;
        let fetched = self
            .cursor
            .event_for_dataset(&self.dataset, position, &self.world, self.total);
        self.index = position + 1;
        fetched
    }
}
1108
/// MPI-backed consuming iterator over a [`Dataset`].
#[cfg(feature = "mpi")]
pub struct DatasetMpiIntoIter {
    /// Events taken from the consumed dataset.
    events: Vec<Event>,
    /// Metadata taken from the consumed dataset.
    metadata: Arc<DatasetMetadata>,
    /// Communicator used to fetch events.
    world: SimpleCommunicator,
    /// Index of the next event to yield.
    index: usize,
    /// Value of `n_events()` captured when the iterator was created.
    total: usize,
    /// Cache of the most recently fetched chunk of events.
    cursor: MpiEventChunkCursor,
}
1119
#[cfg(feature = "mpi")]
impl Iterator for DatasetMpiIntoIter {
    type Item = Event;

    /// Fetches the event at the current index through the chunk cursor and
    /// advances the index.
    fn next(&mut self) -> Option<Event> {
        let position = self.index;
        let fetched = self.cursor.event_for_events(
            &self.events,
            &self.metadata,
            position,
            &self.world,
            self.total,
        );
        self.index = position + 1;
        fetched
    }
}
1136
/// Fetches the event at `global_index` for a [`Dataset`], reading events held
/// by the calling rank through [`Dataset::index_local`].
#[cfg(feature = "mpi")]
fn fetch_event_mpi(
    dataset: &Dataset,
    global_index: usize,
    world: &SimpleCommunicator,
    total: usize,
) -> Event {
    let local_lookup = |local_index: usize| dataset.index_local(local_index);
    fetch_event_mpi_generic(global_index, total, world, &dataset.metadata, local_lookup)
}
1152
/// Fetches up to `len` events starting at global index `start` for a
/// [`Dataset`], reading events held by the calling rank through
/// [`Dataset::index_local`].
#[cfg(feature = "mpi")]
fn fetch_event_chunk_mpi(
    dataset: &Dataset,
    start: usize,
    len: usize,
    world: &SimpleCommunicator,
    total: usize,
) -> Vec<Event> {
    let local_lookup = |local_index: usize| dataset.index_local(local_index);
    fetch_event_chunk_mpi_generic(start, len, total, world, &dataset.metadata, local_lookup)
}
1165
/// Fetches up to `len` events starting at global index `start`, reading events
/// held by the calling rank directly from the `events` slice.
///
/// # Panics
///
/// Panics if a local index computed for this rank is out of bounds for `events`.
#[cfg(feature = "mpi")]
fn fetch_event_chunk_mpi_from_events(
    events: &[Event],
    metadata: &Arc<DatasetMetadata>,
    start: usize,
    len: usize,
    world: &SimpleCommunicator,
    total: usize,
) -> Vec<Event> {
    let local_lookup = |local_index: usize| &events[local_index];
    fetch_event_chunk_mpi_generic(start, len, total, world, metadata, local_lookup)
}
1179
/// Fetches one event by global index and returns it on every rank.
///
/// The owning rank serializes the event, broadcasts the payload length and
/// then the payload itself. The owner returns a clone of its own event, while
/// the other ranks deserialize the payload and attach `metadata`.
///
/// # Panics
///
/// Panics if serialization or deserialization fails.
#[cfg(feature = "mpi")]
fn fetch_event_mpi_generic<'a, F>(
    global_index: usize,
    total: usize,
    world: &SimpleCommunicator,
    metadata: &Arc<DatasetMetadata>,
    local_event: F,
) -> Event
where
    F: Fn(usize) -> &'a Event,
{
    let (owning_rank, local_index) = world.owner_of_global_index(global_index, total);
    let is_owner = world.rank() == owning_rank;

    // Only the owner has data to send; the other ranks start with an empty payload.
    let mut payload: Vec<u8> = if is_owner {
        bitcode::serialize(local_event(local_index).data()).unwrap()
    } else {
        Vec::default()
    };
    let mut payload_len: usize = payload.len();

    // First broadcast: the payload length, so receivers can size their buffer.
    world
        .process_at_rank(owning_rank)
        .broadcast_into(&mut payload_len);
    if !is_owner {
        payload = vec![0; payload_len];
    }
    // Second broadcast: the serialized event itself.
    world
        .process_at_rank(owning_rank)
        .broadcast_into(&mut payload);

    if is_owner {
        local_event(local_index).clone()
    } else {
        let decoded: EventData = bitcode::deserialize(&payload[..]).unwrap();
        Event::new(Arc::new(decoded), metadata.clone())
    }
}
1216
/// Gathers the events with global indices `start..min(start + len, total)` and
/// returns them on every rank.
///
/// Each rank serializes the part of the requested window that lies in its own
/// partition range. Per-rank event counts and byte counts are exchanged, the
/// serialized bytes are gathered with those counts, and each non-empty
/// per-rank payload is decoded in rank order and wrapped with `metadata`.
///
/// Returns an empty vector when `len == 0` or `start >= total`.
///
/// # Panics
///
/// Panics if serialization or deserialization fails, or if an event count, a
/// byte count, or the accumulated byte total does not fit in the `i32` used
/// for the gathered counts and displacements (the previous `as i32` casts
/// silently truncated such values).
#[cfg(feature = "mpi")]
#[allow(dead_code)]
fn fetch_event_chunk_mpi_generic<'a, F>(
    start: usize,
    len: usize,
    total: usize,
    world: &SimpleCommunicator,
    metadata: &Arc<DatasetMetadata>,
    local_event: F,
) -> Vec<Event>
where
    F: Fn(usize) -> &'a Event,
{
    if len == 0 || start >= total {
        return Vec::new();
    }

    // Saturate so an oversized `len` cannot overflow before clamping to `total`.
    let end = start.saturating_add(len).min(total);
    let partition = world.partition(total);
    let local_range = partition.range_for_rank(world.rank() as usize);
    // Intersect the requested window with the range held by this rank and
    // translate it to local indices.
    let owned_start = start.max(local_range.start);
    let owned_end = end.min(local_range.end);
    let local_indices = if owned_start < owned_end {
        (owned_start - local_range.start)..(owned_end - local_range.start)
    } else {
        0..0
    };

    let local_events: Vec<EventData> = local_indices
        .map(|local_index| local_event(local_index).data().clone())
        .collect();
    let local_event_count = <i32 as std::convert::TryFrom<usize>>::try_from(local_events.len())
        .expect("number of events in an MPI fetch chunk must fit in an i32 count");

    // Ranks without events in the window contribute zero bytes.
    let serialized_local = if local_events.is_empty() {
        Vec::new()
    } else {
        bitcode::serialize(&local_events).unwrap()
    };
    let local_byte_count = <i32 as std::convert::TryFrom<usize>>::try_from(serialized_local.len())
        .expect("serialized MPI fetch chunk must fit in an i32 byte count");

    let n_ranks = world.size() as usize;
    let mut gathered_event_counts = vec![0_i32; n_ranks];
    let mut gathered_byte_counts = vec![0_i32; n_ranks];
    world.all_gather_into(&local_event_count, &mut gathered_event_counts);
    world.all_gather_into(&local_byte_count, &mut gathered_byte_counts);

    // Displacements are the running sum of the byte counts; the checked
    // addition also validates the grand total.
    let mut gathered_byte_displs = Vec::with_capacity(gathered_byte_counts.len());
    let mut running_bytes = 0_i32;
    for &byte_count in &gathered_byte_counts {
        gathered_byte_displs.push(running_bytes);
        running_bytes = running_bytes
            .checked_add(byte_count)
            .expect("total serialized MPI fetch chunk size must fit in an i32 displacement");
    }
    let gathered_bytes = world.all_gather_with_counts(
        &serialized_local,
        &gathered_byte_counts,
        &gathered_byte_displs,
    );

    let mut events = Vec::with_capacity(end - start);
    for rank in 0..n_ranks {
        if gathered_event_counts[rank] == 0 {
            continue;
        }
        let byte_start = gathered_byte_displs[rank] as usize;
        let byte_end = byte_start + gathered_byte_counts[rank] as usize;
        let decoded: Vec<EventData> =
            bitcode::deserialize(&gathered_bytes[byte_start..byte_end]).unwrap();
        debug_assert_eq!(decoded.len(), gathered_event_counts[rank] as usize);
        events.extend(
            decoded
                .into_iter()
                .map(|event| Event::new(Arc::new(event), metadata.clone())),
        );
    }

    events
}
1292
1293impl Dataset {
1294 #[cfg(feature = "mpi")]
1295 pub(crate) fn set_cached_global_event_count_from_world(&mut self, world: &SimpleCommunicator) {
1296 let local_count = self.n_events_local();
1297 let mut global_count = 0usize;
1298 world.all_reduce_into(
1299 &local_count,
1300 &mut global_count,
1301 mpi::collective::SystemOperation::sum(),
1302 );
1303 self.cached_global_event_count = global_count;
1304 }
1305
1306 #[cfg(feature = "mpi")]
1307 pub(crate) fn set_cached_global_weighted_sum_from_world(&mut self, world: &SimpleCommunicator) {
1308 let mut weighted_sums = vec![0.0_f64; world.size() as usize];
1309 world.all_gather_into(&self.cached_local_weighted_sum, &mut weighted_sums);
1310 #[cfg(feature = "rayon")]
1311 {
1312 self.cached_global_weighted_sum = weighted_sums
1313 .into_par_iter()
1314 .parallel_sum_with_accumulator::<Klein<f64>>();
1315 }
1316 #[cfg(not(feature = "rayon"))]
1317 {
1318 self.cached_global_weighted_sum = weighted_sums
1319 .into_iter()
1320 .sum_with_accumulator::<Klein<f64>>();
1321 }
1322 }
1323
1324 fn columnar_from_wrapped_events(
1325 events: &[Event],
1326 metadata: Arc<DatasetMetadata>,
1327 ) -> LadduResult<DatasetStorage> {
1328 let n_events = events.len();
1329 let (n_p4, n_aux) = match events.first() {
1330 Some(first) => (first.p4s.len(), first.aux.len()),
1331 None => (metadata.p4_names.len(), metadata.aux_names.len()),
1332 };
1333 let mut p4 = (0..n_p4)
1334 .map(|_| ColumnarP4Column::with_capacity(n_events))
1335 .collect::<Vec<_>>();
1336 let mut aux = (0..n_aux)
1337 .map(|_| Vec::with_capacity(n_events))
1338 .collect::<Vec<_>>();
1339 let mut weights = Vec::with_capacity(n_events);
1340 for (event_index, event) in events.iter().enumerate() {
1341 if event.p4s.len() != n_p4 || event.aux.len() != n_aux {
1342 return Err(LadduError::Custom(format!(
1343 "Ragged dataset shape at event {event_index}: expected ({n_p4} p4, {n_aux} aux), got ({} p4, {} aux)",
1344 event.p4s.len(),
1345 event.aux.len()
1346 )));
1347 }
1348 for (column, value) in p4.iter_mut().zip(event.p4s.iter()) {
1349 column.push(*value);
1350 }
1351 for (column, value) in aux.iter_mut().zip(event.aux.iter()) {
1352 column.push(*value);
1353 }
1354 weights.push(event.weight);
1355 }
1356 Ok(DatasetStorage {
1357 metadata,
1358 p4,
1359 aux,
1360 weights,
1361 })
1362 }
1363
1364 pub fn new_local(events: Vec<Arc<EventData>>, metadata: Arc<DatasetMetadata>) -> Self {
1371 let wrapped_events = events
1372 .into_iter()
1373 .map(|event| Event::new(event, metadata.clone()))
1374 .collect::<Vec<_>>();
1375 #[cfg(feature = "mpi")]
1376 let local_count = wrapped_events.len();
1377 let columnar = Self::columnar_from_wrapped_events(&wrapped_events, metadata.clone())
1378 .expect("Dataset requires rectangular p4/aux columns for canonical columnar storage");
1379 #[cfg(feature = "rayon")]
1380 let local_weighted_sum = columnar
1381 .weights
1382 .par_iter()
1383 .copied()
1384 .parallel_sum_with_accumulator::<Klein<f64>>();
1385 #[cfg(not(feature = "rayon"))]
1386 let local_weighted_sum = columnar
1387 .weights
1388 .iter()
1389 .copied()
1390 .sum_with_accumulator::<Klein<f64>>();
1391 Dataset {
1392 events: wrapped_events,
1393 columnar,
1394 metadata,
1395 cached_local_weighted_sum: local_weighted_sum,
1396 #[cfg(feature = "mpi")]
1397 cached_global_event_count: local_count,
1398 #[cfg(feature = "mpi")]
1399 cached_global_weighted_sum: local_weighted_sum,
1400 }
1401 }
1402
1403 #[cfg(feature = "mpi")]
1410 pub fn new_mpi(
1411 events: Vec<Arc<EventData>>,
1412 metadata: Arc<DatasetMetadata>,
1413 world: &SimpleCommunicator,
1414 ) -> Self {
1415 let partitions = Dataset::partition(events, world);
1416 let local: Vec<Event> = partitions[world.rank() as usize]
1417 .iter()
1418 .cloned()
1419 .map(|event| Event::new(event, metadata.clone()))
1420 .collect();
1421 let columnar = Self::columnar_from_wrapped_events(&local, metadata.clone())
1422 .expect("Dataset requires rectangular p4/aux columns for canonical columnar storage");
1423 #[cfg(feature = "rayon")]
1424 let local_weighted_sum = columnar
1425 .weights
1426 .par_iter()
1427 .copied()
1428 .parallel_sum_with_accumulator::<Klein<f64>>();
1429 #[cfg(not(feature = "rayon"))]
1430 let local_weighted_sum = columnar
1431 .weights
1432 .iter()
1433 .copied()
1434 .sum_with_accumulator::<Klein<f64>>();
1435 let mut dataset = Dataset {
1436 events: local,
1437 columnar,
1438 metadata,
1439 cached_local_weighted_sum: local_weighted_sum,
1440 cached_global_event_count: 0,
1441 cached_global_weighted_sum: local_weighted_sum,
1442 };
1443 dataset.set_cached_global_event_count_from_world(world);
1444 dataset.set_cached_global_weighted_sum_from_world(world);
1445 dataset
1446 }
1447
    /// Construct a dataset from `events` using default metadata.
    ///
    /// Delegates to [`Dataset::new_with_metadata`] with `DatasetMetadata::default()`.
    pub fn new(events: Vec<Arc<EventData>>) -> Self {
        Dataset::new_with_metadata(events, Arc::new(DatasetMetadata::default()))
    }
1456
    /// Construct a dataset from `events` and explicit `metadata`.
    ///
    /// With the `mpi` feature enabled and an MPI world available from `crate::mpi::get_world`,
    /// this dispatches to `Dataset::new_mpi`; otherwise it falls back to `Dataset::new_local`.
    pub fn new_with_metadata(events: Vec<Arc<EventData>>, metadata: Arc<DatasetMetadata>) -> Self {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return Dataset::new_mpi(events, metadata, &world);
            }
        }
        Dataset::new_local(events, metadata)
    }
1468
    /// Number of events held in this dataset's own columnar storage.
    pub fn n_events_local(&self) -> usize {
        self.columnar.n_events()
    }
1478
    /// Cached global event count.
    ///
    /// The communicator argument is not used; the value is read from
    /// `cached_global_event_count`.
    #[cfg(feature = "mpi")]
    pub fn n_events_mpi(&self, _world: &SimpleCommunicator) -> usize {
        self.cached_global_event_count
    }
1489
    /// Number of events in the dataset.
    ///
    /// With the `mpi` feature enabled and an MPI world available, this returns the result of
    /// `n_events_mpi`; otherwise it returns `n_events_local`.
    pub fn n_events(&self) -> usize {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.n_events_mpi(&world);
            }
        }
        self.n_events_local()
    }
1500
    /// Alias for [`Dataset::n_events`].
    pub fn n_events_global(&self) -> usize {
        self.n_events()
    }
1507}
1508
1509impl Dataset {
    /// Return a copy of the weights held in this dataset's own columnar storage.
    pub fn weights_local(&self) -> Vec<f64> {
        self.columnar.weights.clone()
    }
1519
    /// Gather the weights of every rank in `world` into one vector.
    ///
    /// The output buffer has `self.n_events()` entries. Per-rank counts and displacements come
    /// from `world.get_counts_displs`, and the local weights are exchanged with a
    /// variable-count all-gather.
    #[cfg(feature = "mpi")]
    pub fn weights_mpi(&self, world: &SimpleCommunicator) -> Vec<f64> {
        let local_weights = self.weights_local();
        let n_events = self.n_events();
        let mut buffer: Vec<f64> = vec![0.0; n_events];
        let (counts, displs) = world.get_counts_displs(n_events);
        {
            // Scope the partitioned view so its mutable borrow of `buffer` ends before
            // `buffer` is returned.
            let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
            world.all_gather_varcount_into(&local_weights, &mut partitioned_buffer);
        }
        buffer
    }
1540
    /// Return the event weights.
    ///
    /// With the `mpi` feature enabled and an MPI world available, this returns the result of
    /// `weights_mpi`; otherwise it returns `weights_local`.
    pub fn weights(&self) -> Vec<f64> {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.weights_mpi(&world);
            }
        }
        self.weights_local()
    }
1551
    /// Alias for [`Dataset::weights`].
    pub fn weights_global(&self) -> Vec<f64> {
        self.weights()
    }
1558
    /// Cached sum of the weights held in this dataset's own storage.
    pub fn n_events_weighted_local(&self) -> f64 {
        self.cached_local_weighted_sum
    }
    /// Cached global sum of weights.
    ///
    /// The communicator argument is not used; the value is read from
    /// `cached_global_weighted_sum`.
    #[cfg(feature = "mpi")]
    pub fn n_events_weighted_mpi(&self, _world: &SimpleCommunicator) -> f64 {
        self.cached_global_weighted_sum
    }
1578
    /// Sum of the event weights.
    ///
    /// With the `mpi` feature enabled and an MPI world available, this returns the result of
    /// `n_events_weighted_mpi`; otherwise it returns `n_events_weighted_local`.
    pub fn n_events_weighted(&self) -> f64 {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.n_events_weighted_mpi(&world);
            }
        }
        self.n_events_weighted_local()
    }
1589
    /// Alias for [`Dataset::n_events_weighted`].
    pub fn n_events_weighted_global(&self) -> f64 {
        self.n_events_weighted()
    }
1596
1597 pub fn bootstrap_local(&self, seed: usize) -> Arc<Dataset> {
1605 let mut rng = fastrand::Rng::with_seed(seed as u64);
1606 let mut indices: Vec<usize> = (0..self.n_events())
1607 .map(|_| rng.usize(0..self.n_events()))
1608 .collect::<Vec<usize>>();
1609 indices.sort();
1610 #[cfg(feature = "rayon")]
1611 let bootstrapped_events: Vec<Arc<EventData>> = indices
1612 .into_par_iter()
1613 .map(|idx| self.events[idx].data_arc())
1614 .collect();
1615 #[cfg(not(feature = "rayon"))]
1616 let bootstrapped_events: Vec<Arc<EventData>> = indices
1617 .into_iter()
1618 .map(|idx| self.events[idx].data_arc())
1619 .collect();
1620 Arc::new(Dataset::new_with_metadata(
1621 bootstrapped_events,
1622 self.metadata.clone(),
1623 ))
1624 }
1625
1626 #[cfg(feature = "mpi")]
1634 pub fn bootstrap_mpi(&self, seed: usize, world: &SimpleCommunicator) -> Arc<Dataset> {
1635 let n_events = self.n_events();
1636 let mut indices: Vec<usize> = vec![0; n_events];
1637 if world.is_root() {
1638 let mut rng = fastrand::Rng::with_seed(seed as u64);
1639 indices = (0..n_events)
1640 .map(|_| rng.usize(0..n_events))
1641 .collect::<Vec<usize>>();
1642 indices.sort();
1643 }
1644 world.process_at_root().broadcast_into(&mut indices);
1645 let local_indices: Vec<usize> = indices
1646 .into_iter()
1647 .filter_map(|idx| {
1648 let (owning_rank, local_index) = world.owner_of_global_index(idx, n_events);
1649 if world.rank() == owning_rank {
1650 Some(local_index)
1651 } else {
1652 None
1653 }
1654 })
1655 .collect();
1656 #[cfg(feature = "rayon")]
1659 let bootstrapped_events: Vec<Arc<EventData>> = local_indices
1660 .into_par_iter()
1661 .map(|idx| self.events[idx].data_arc())
1662 .collect();
1663 #[cfg(not(feature = "rayon"))]
1664 let bootstrapped_events: Vec<Arc<EventData>> = local_indices
1665 .into_iter()
1666 .map(|idx| self.events[idx].data_arc())
1667 .collect();
1668 Arc::new(Dataset::new_with_metadata(
1669 bootstrapped_events,
1670 self.metadata.clone(),
1671 ))
1672 }
1673
    /// Draw a bootstrap resample of the dataset using `seed`.
    ///
    /// With the `mpi` feature enabled and an MPI world available, this dispatches to
    /// `bootstrap_mpi`; otherwise it calls `bootstrap_local`.
    pub fn bootstrap(&self, seed: usize) -> Arc<Dataset> {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.bootstrap_mpi(seed, &world);
            }
        }
        self.bootstrap_local(seed)
    }
1685
1686 pub fn filter(&self, expression: &VariableExpression) -> LadduResult<Arc<Dataset>> {
1689 let compiled = expression.compile(&self.metadata)?;
1690 #[cfg(feature = "rayon")]
1691 let filtered_events: Vec<Arc<EventData>> = (0..self.n_events_local())
1692 .into_par_iter()
1693 .filter_map(|event_index| {
1694 let event = self.event_view(event_index);
1695 compiled
1696 .evaluate(&event)
1697 .then(|| self.events[event_index].data_arc())
1698 })
1699 .collect();
1700 #[cfg(not(feature = "rayon"))]
1701 let filtered_events: Vec<Arc<EventData>> = (0..self.n_events_local())
1702 .into_iter()
1703 .filter_map(|event_index| {
1704 let event = self.event_view(event_index);
1705 compiled
1706 .evaluate(&event)
1707 .then(|| self.events[event_index].data_arc())
1708 })
1709 .collect();
1710 Ok(Arc::new(Dataset::new_with_metadata(
1711 filtered_events,
1712 self.metadata.clone(),
1713 )))
1714 }
1715
1716 pub fn bin_by<V>(
1719 &self,
1720 mut variable: V,
1721 bins: usize,
1722 range: (f64, f64),
1723 ) -> LadduResult<BinnedDataset>
1724 where
1725 V: Variable,
1726 {
1727 variable.bind(self.metadata())?;
1728 let bin_width = (range.1 - range.0) / bins as f64;
1729 let bin_edges = get_bin_edges(bins, range);
1730 let variable = variable;
1731 #[cfg(feature = "rayon")]
1732 let evaluated: Vec<(usize, Arc<EventData>)> = (0..self.n_events_local())
1733 .into_par_iter()
1734 .filter_map(|event| {
1735 let value = variable.value(&self.event_view(event));
1736 if value >= range.0 && value < range.1 {
1737 let bin_index = ((value - range.0) / bin_width) as usize;
1738 let bin_index = bin_index.min(bins - 1);
1739 Some((bin_index, self.events[event].data_arc()))
1740 } else {
1741 None
1742 }
1743 })
1744 .collect();
1745 #[cfg(not(feature = "rayon"))]
1746 let evaluated: Vec<(usize, Arc<EventData>)> = (0..self.n_events_local())
1747 .into_iter()
1748 .filter_map(|event| {
1749 let value = variable.value(&self.event_view(event));
1750 if value >= range.0 && value < range.1 {
1751 let bin_index = ((value - range.0) / bin_width) as usize;
1752 let bin_index = bin_index.min(bins - 1);
1753 Some((bin_index, self.events[event].data_arc()))
1754 } else {
1755 None
1756 }
1757 })
1758 .collect();
1759 let mut binned_events: Vec<Vec<Arc<EventData>>> = vec![Vec::default(); bins];
1760 for (bin_index, event) in evaluated {
1761 binned_events[bin_index].push(event.clone());
1762 }
1763 #[cfg(feature = "rayon")]
1764 let datasets: Vec<Arc<Dataset>> = binned_events
1765 .into_par_iter()
1766 .map(|events| Arc::new(Dataset::new_with_metadata(events, self.metadata.clone())))
1767 .collect();
1768 #[cfg(not(feature = "rayon"))]
1769 let datasets: Vec<Arc<Dataset>> = binned_events
1770 .into_iter()
1771 .map(|events| Arc::new(Dataset::new_with_metadata(events, self.metadata.clone())))
1772 .collect();
1773 Ok(BinnedDataset {
1774 datasets,
1775 edges: bin_edges,
1776 })
1777 }
1778
1779 pub fn boost_to_rest_frame_of<S>(&self, names: &[S]) -> Arc<Dataset>
1782 where
1783 S: AsRef<str>,
1784 {
1785 let mut indices: Vec<usize> = Vec::new();
1786 for name in names {
1787 let name_ref = name.as_ref();
1788 if let Some(selection) = self.metadata.p4_selection(name_ref) {
1789 indices.extend_from_slice(selection.indices());
1790 } else {
1791 panic!("Unknown particle name '{name}'", name = name_ref);
1792 }
1793 }
1794 #[cfg(feature = "rayon")]
1795 let boosted_events: Vec<Arc<EventData>> = self
1796 .events
1797 .par_iter()
1798 .map(|event| Arc::new(event.data().boost_to_rest_frame_of(&indices)))
1799 .collect();
1800 #[cfg(not(feature = "rayon"))]
1801 let boosted_events: Vec<Arc<EventData>> = self
1802 .events
1803 .iter()
1804 .map(|event| Arc::new(event.data().boost_to_rest_frame_of(&indices)))
1805 .collect();
1806 Arc::new(Dataset::new_with_metadata(
1807 boosted_events,
1808 self.metadata.clone(),
1809 ))
1810 }
    /// Evaluate `variable` on this dataset by delegating to `Variable::value_on`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `variable.value_on`.
    pub fn evaluate<V: Variable>(&self, variable: &V) -> LadduResult<Vec<f64>> {
        variable.value_on(self)
    }
1815}
1816
1817#[cfg(test)]
1818pub(crate) use io::write_parquet_storage;
1819pub use io::{
1820 read_parquet, read_parquet_chunks, read_parquet_chunks_with_options, read_root, write_parquet,
1821 write_root,
1822};
1823#[cfg(test)]
1824pub(crate) use io::{read_parquet_storage, read_root_storage};
1825
1826impl_op_ex!(+ |a: &Dataset, b: &Dataset| -> Dataset {
1827 debug_assert_eq!(a.metadata.p4_names, b.metadata.p4_names);
1828 debug_assert_eq!(a.metadata.aux_names, b.metadata.aux_names);
1829 let events = a
1830 .events
1831 .iter()
1832 .chain(b.events.iter())
1833 .map(Event::data_arc)
1834 .collect::<Vec<_>>();
1835 Dataset::new_with_metadata(events, a.metadata.clone())
1836});
1837
/// Accumulates the events of several dataset chunks so they can be merged into one
/// [`Dataset`].
#[derive(Default)]
pub struct DatasetChunkBuilder {
    /// Metadata of the first pushed chunk; `None` until a chunk has been pushed.
    metadata: Option<Arc<DatasetMetadata>>,
    /// Events collected from all pushed chunks, in push order.
    events: Vec<Arc<EventData>>,
}
1844
1845impl DatasetChunkBuilder {
1846 pub fn new() -> Self {
1848 Self::default()
1849 }
1850
1851 pub fn push_chunk(&mut self, chunk: &Dataset) -> LadduResult<()> {
1853 if let Some(existing) = &self.metadata {
1854 if existing.p4_names != chunk.metadata.p4_names
1855 || existing.aux_names != chunk.metadata.aux_names
1856 {
1857 return Err(LadduError::Custom(
1858 "Dataset chunk metadata does not match previous chunks".to_string(),
1859 ));
1860 }
1861 } else {
1862 self.metadata = Some(chunk.metadata.clone());
1863 }
1864 self.events
1865 .extend(chunk.events_local().iter().map(Event::data_arc));
1866 Ok(())
1867 }
1868
1869 pub fn finish(self) -> Arc<Dataset> {
1871 let metadata = self
1872 .metadata
1873 .unwrap_or_else(|| Arc::new(DatasetMetadata::empty()));
1874 Arc::new(Dataset::new_with_metadata(self.events, metadata))
1875 }
1876}
1877
1878pub fn try_fold_dataset_chunks<I, T, F>(chunks: I, init: T, mut op: F) -> LadduResult<T>
1880where
1881 I: IntoIterator<Item = LadduResult<Arc<Dataset>>>,
1882 F: FnMut(T, &Dataset) -> LadduResult<T>,
1883{
1884 let mut acc = init;
1885 for chunk in chunks {
1886 let chunk = chunk?;
1887 acc = op(acc, &chunk)?;
1888 }
1889 Ok(acc)
1890}
1891
/// Options that control how a dataset is read.
///
/// All fields default to "unset" (`None` or empty); the builder methods on this type set
/// them one at a time.
#[derive(Default, Clone)]
pub struct DatasetReadOptions {
    /// Explicit p4 names; when `None`, `resolve_metadata` uses the detected names instead.
    pub p4_names: Option<Vec<String>>,
    /// Explicit aux names; when `None`, `resolve_metadata` uses the detected names instead.
    pub aux_names: Option<Vec<String>>,
    /// Name of the tree to read.
    /// NOTE(review): presumably only relevant for ROOT input; confirm against the readers.
    pub tree: Option<String>,
    /// Alias name to p4 selection; added to the metadata in `resolve_metadata` when non-empty.
    pub aliases: IndexMap<String, P4Selection>,
    /// Chunk size; the `chunk_size` builder method stores a value of at least 1.
    /// NOTE(review): presumably the number of events per chunk in chunked reads; confirm.
    pub chunk_size: Option<usize>,
}
1910
/// Floating-point precision selector, used by [`DatasetWriteOptions`].
///
/// NOTE(review): presumably the width of the floating-point values written to disk; confirm
/// against the writers.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum FloatPrecision {
    /// 32-bit precision.
    F32,
    /// 64-bit precision (the default).
    #[default]
    F64,
}
1920
/// Options that control how a dataset is written.
#[derive(Clone, Debug)]
pub struct DatasetWriteOptions {
    /// Batch size; defaults to `DEFAULT_WRITE_BATCH_SIZE`.
    /// NOTE(review): presumably the number of events written per batch; confirm.
    pub batch_size: usize,
    /// Floating-point precision; defaults to `FloatPrecision::default()`.
    pub precision: FloatPrecision,
    /// Name of the tree to write; defaults to `None`.
    /// NOTE(review): presumably only relevant for ROOT output; confirm against the writers.
    pub tree: Option<String>,
}
1931
impl Default for DatasetWriteOptions {
    /// Default write options: `DEFAULT_WRITE_BATCH_SIZE` as batch size, the default
    /// [`FloatPrecision`], and no tree name.
    fn default() -> Self {
        Self {
            batch_size: DEFAULT_WRITE_BATCH_SIZE,
            precision: FloatPrecision::default(),
            tree: None,
        }
    }
}
1941
1942impl DatasetWriteOptions {
1943 pub fn batch_size(mut self, batch_size: usize) -> Self {
1945 self.batch_size = batch_size;
1946 self
1947 }
1948
1949 pub fn precision(mut self, precision: FloatPrecision) -> Self {
1951 self.precision = precision;
1952 self
1953 }
1954
1955 pub fn tree<S: Into<String>>(mut self, name: S) -> Self {
1957 self.tree = Some(name.into());
1958 self
1959 }
1960}
1961impl DatasetReadOptions {
1962 pub fn new() -> Self {
1964 Self::default()
1965 }
1966
1967 pub fn p4_names<I, S>(mut self, names: I) -> Self
1970 where
1971 I: IntoIterator<Item = S>,
1972 S: AsRef<str>,
1973 {
1974 self.p4_names = Some(names.into_iter().map(|s| s.as_ref().to_string()).collect());
1975 self
1976 }
1977
1978 pub fn aux_names<I, S>(mut self, names: I) -> Self
1982 where
1983 I: IntoIterator<Item = S>,
1984 S: AsRef<str>,
1985 {
1986 self.aux_names = Some(names.into_iter().map(|s| s.as_ref().to_string()).collect());
1987 self
1988 }
1989
1990 pub fn tree<S>(mut self, name: S) -> Self
1992 where
1993 S: AsRef<str>,
1994 {
1995 self.tree = Some(name.as_ref().to_string());
1996 self
1997 }
1998
1999 pub fn alias<N, S>(mut self, name: N, selection: S) -> Self
2001 where
2002 N: Into<String>,
2003 S: IntoP4Selection,
2004 {
2005 self.aliases.insert(name.into(), selection.into_selection());
2006 self
2007 }
2008
2009 pub fn aliases<I, N, S>(mut self, aliases: I) -> Self
2011 where
2012 I: IntoIterator<Item = (N, S)>,
2013 N: Into<String>,
2014 S: IntoP4Selection,
2015 {
2016 for (name, selection) in aliases {
2017 self = self.alias(name, selection);
2018 }
2019 self
2020 }
2021
2022 pub fn chunk_size(mut self, chunk_size: usize) -> Self {
2024 self.chunk_size = Some(chunk_size.max(1));
2025 self
2026 }
2027
2028 fn resolve_metadata(
2029 &self,
2030 detected_p4_names: Vec<String>,
2031 detected_aux_names: Vec<String>,
2032 ) -> LadduResult<Arc<DatasetMetadata>> {
2033 let p4_names_vec = self.p4_names.clone().unwrap_or(detected_p4_names);
2034 let aux_names_vec = self.aux_names.clone().unwrap_or(detected_aux_names);
2035
2036 let mut metadata = DatasetMetadata::new(p4_names_vec, aux_names_vec)?;
2037 if !self.aliases.is_empty() {
2038 metadata.add_p4_aliases(self.aliases.clone())?;
2039 }
2040 Ok(Arc::new(metadata))
2041 }
2042}
2043
/// Batch size used by `DatasetWriteOptions::default()`.
const DEFAULT_WRITE_BATCH_SIZE: usize = 10_000;
/// Default read chunk size.
/// NOTE(review): presumably the fallback when `DatasetReadOptions::chunk_size` is `None`;
/// confirm against the readers.
pub(crate) const DEFAULT_READ_CHUNK_SIZE: usize = 10_000;
2046
/// A collection of datasets, one per bin, together with the bin edges.
///
/// Values of this type are produced by `Dataset::bin_by`.
pub struct BinnedDataset {
    /// One dataset per bin, in bin order.
    datasets: Vec<Arc<Dataset>>,
    /// Bin edges as returned by `get_bin_edges` in `Dataset::bin_by`.
    edges: Vec<f64>,
}
2052
impl Index<usize> for BinnedDataset {
    type Output = Arc<Dataset>;

    /// Return the dataset of bin `index`; panics if `index` is out of bounds.
    fn index(&self, index: usize) -> &Self::Output {
        &self.datasets[index]
    }
}
2060
impl IndexMut<usize> for BinnedDataset {
    /// Return a mutable reference to the dataset of bin `index`; panics if `index` is out of
    /// bounds.
    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
        &mut self.datasets[index]
    }
}
2066
impl Deref for BinnedDataset {
    type Target = Vec<Arc<Dataset>>;

    /// Expose the per-bin datasets as a `Vec`.
    fn deref(&self) -> &Self::Target {
        &self.datasets
    }
}
2074
impl DerefMut for BinnedDataset {
    /// Expose the per-bin datasets as a mutable `Vec`.
    ///
    /// Changing the length of the vector does not update `edges`.
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.datasets
    }
}
2080
2081impl BinnedDataset {
2082 pub fn n_bins(&self) -> usize {
2084 self.datasets.len()
2085 }
2086
2087 pub fn edges(&self) -> Vec<f64> {
2089 self.edges.clone()
2090 }
2091
2092 pub fn range(&self) -> (f64, f64) {
2094 (self.edges[0], self.edges[self.n_bins()])
2095 }
2096}
2097
2098#[cfg(test)]
2099mod tests {
2100 use crate::Mass;
2101
2102 use super::*;
2103 #[cfg(feature = "mpi")]
2104 use crate::mpi::{finalize_mpi, get_world, use_mpi};
2105 use crate::utils::vectors::Vec3;
2106 use approx::{assert_relative_eq, assert_relative_ne};
2107 use fastrand;
2108 #[cfg(feature = "mpi")]
2109 use mpi_test::mpi_test;
2110 use serde::{Deserialize, Serialize};
2111 use std::{
2112 env, fs,
2113 path::{Path, PathBuf},
2114 };
2115
2116 fn test_data_path(file: &str) -> PathBuf {
2117 Path::new(env!("CARGO_MANIFEST_DIR"))
2118 .join("test_data")
2119 .join(file)
2120 }
2121
2122 fn open_test_dataset(file: &str, options: DatasetReadOptions) -> Arc<Dataset> {
2123 let path = test_data_path(file);
2124 let path_str = path.to_str().expect("test data path should be valid UTF-8");
2125 let ext = path
2126 .extension()
2127 .and_then(|ext| ext.to_str())
2128 .unwrap_or_default()
2129 .to_ascii_lowercase();
2130 match ext.as_str() {
2131 "parquet" => read_parquet(path_str, &options),
2132 "root" => read_root(path_str, &options),
2133 other => panic!("Unsupported extension in test data: {other}"),
2134 }
2135 .expect("dataset should open")
2136 }
2137
2138 fn make_temp_dir() -> PathBuf {
2139 let dir = env::temp_dir().join(format!("laddu_test_{}", fastrand::u64(..)));
2140 fs::create_dir(&dir).expect("temp dir should be created");
2141 dir
2142 }
2143
2144 #[cfg(feature = "mpi")]
2145 fn mpi_chunk_test_dataset(n_events: usize) -> Dataset {
2146 let metadata = test_dataset().metadata_arc();
2147 let base = test_event();
2148 let events = (0..n_events)
2149 .map(|index| {
2150 let mut event = base.clone();
2151 event.p4s[0] =
2152 Vec3::new(index as f64 * 0.1, 0.0, 8.747 + index as f64 * 0.01).with_mass(0.0);
2153 event.aux[0] += index as f64;
2154 event.aux[1] += index as f64 * 0.5;
2155 event.weight = 1.0 + index as f64;
2156 Arc::new(event)
2157 })
2158 .collect();
2159 Dataset::new_with_metadata(events, metadata)
2160 }
2161
2162 fn assert_events_close(left: &Event, right: &Event, p4_names: &[&str], aux_names: &[&str]) {
2163 for name in p4_names {
2164 let lp4 = left
2165 .p4(name)
2166 .unwrap_or_else(|| panic!("missing p4 '{name}' in left dataset"));
2167 let rp4 = right
2168 .p4(name)
2169 .unwrap_or_else(|| panic!("missing p4 '{name}' in right dataset"));
2170 assert_relative_eq!(lp4.px(), rp4.px(), epsilon = 1e-9);
2171 assert_relative_eq!(lp4.py(), rp4.py(), epsilon = 1e-9);
2172 assert_relative_eq!(lp4.pz(), rp4.pz(), epsilon = 1e-9);
2173 assert_relative_eq!(lp4.e(), rp4.e(), epsilon = 1e-9);
2174 }
2175 let left_aux = left.aux();
2176 let right_aux = right.aux();
2177 for name in aux_names {
2178 let laux = left_aux
2179 .get(name)
2180 .copied()
2181 .unwrap_or_else(|| panic!("missing aux '{name}' in left dataset"));
2182 let raux = right_aux
2183 .get(name)
2184 .copied()
2185 .unwrap_or_else(|| panic!("missing aux '{name}' in right dataset"));
2186 assert_relative_eq!(laux, raux, epsilon = 1e-9);
2187 }
2188 assert_relative_eq!(left.weight(), right.weight(), epsilon = 1e-9);
2189 }
2190
2191 fn assert_datasets_close(
2192 left: &Arc<Dataset>,
2193 right: &Arc<Dataset>,
2194 p4_names: &[&str],
2195 aux_names: &[&str],
2196 ) {
2197 assert_eq!(left.n_events(), right.n_events());
2198 for idx in 0..left.n_events() {
2199 let Ok(levent) = left.event(idx) else {
2200 panic!("left dataset missing event at index {idx}");
2201 };
2202 let Ok(revent) = right.event(idx) else {
2203 panic!("right dataset missing event at index {idx}");
2204 };
2205 assert_events_close(&levent, &revent, p4_names, aux_names);
2206 }
2207 }
2208
2209 fn assert_dataset_columnar_close(left: &DatasetStorage, right: &DatasetStorage) {
2210 assert_eq!(left.n_events(), right.n_events());
2211 assert_eq!(left.metadata().p4_names(), right.metadata().p4_names());
2212 assert_eq!(left.metadata().aux_names(), right.metadata().aux_names());
2213 for event_index in 0..left.n_events() {
2214 for p4_index in 0..left.metadata().p4_names().len() {
2215 let lp4 = left.p4(event_index, p4_index);
2216 let rp4 = right.p4(event_index, p4_index);
2217 assert_relative_eq!(lp4.px(), rp4.px(), epsilon = 1e-12);
2218 assert_relative_eq!(lp4.py(), rp4.py(), epsilon = 1e-12);
2219 assert_relative_eq!(lp4.pz(), rp4.pz(), epsilon = 1e-12);
2220 assert_relative_eq!(lp4.e(), rp4.e(), epsilon = 1e-12);
2221 }
2222 for aux_index in 0..left.metadata().aux_names().len() {
2223 let l = left.aux(event_index, aux_index);
2224 let r = right.aux(event_index, aux_index);
2225 assert_relative_eq!(l, r, epsilon = 1e-12);
2226 }
2227 let lw = left.weight(event_index);
2228 let rw = right.weight(event_index);
2229 assert_relative_eq!(lw, rw, epsilon = 1e-12);
2230 }
2231 }
2232
2233 #[test]
2234 fn test_from_parquet_auto_matches_explicit_names() {
2235 let auto = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
2236 let explicit_options = DatasetReadOptions::new()
2237 .p4_names(TEST_P4_NAMES)
2238 .aux_names(TEST_AUX_NAMES);
2239 let explicit = open_test_dataset("data_f32.parquet", explicit_options);
2240
2241 let mut detected_p4: Vec<&str> = auto.p4_names().iter().map(String::as_str).collect();
2242 detected_p4.sort_unstable();
2243 let mut expected_p4 = TEST_P4_NAMES.to_vec();
2244 expected_p4.sort_unstable();
2245 assert_eq!(detected_p4, expected_p4);
2246 let mut detected_aux: Vec<&str> = auto.aux_names().iter().map(String::as_str).collect();
2247 detected_aux.sort_unstable();
2248 let mut expected_aux = TEST_AUX_NAMES.to_vec();
2249 expected_aux.sort_unstable();
2250 assert_eq!(detected_aux, expected_aux);
2251 assert_datasets_close(&auto, &explicit, TEST_P4_NAMES, TEST_AUX_NAMES);
2252 }
2253
    // An alias registered at read time must resolve to the sum of the aliased p4s.
    #[test]
    fn test_from_parquet_with_aliases() {
        let dataset = open_test_dataset(
            "data_f32.parquet",
            DatasetReadOptions::new().alias("resonance", ["kshort1", "kshort2"]),
        );
        let event = dataset.named_event(0).expect("event should exist");
        let alias_vec = event.p4("resonance").expect("alias vector");
        let expected = event.get_p4_sum(["kshort1", "kshort2"]);
        assert_relative_eq!(alias_vec.px(), expected.px(), epsilon = 1e-9);
        assert_relative_eq!(alias_vec.py(), expected.py(), epsilon = 1e-9);
        assert_relative_eq!(alias_vec.pz(), expected.pz(), epsilon = 1e-9);
        assert_relative_eq!(alias_vec.e(), expected.e(), epsilon = 1e-9);
    }
2268
    // Alias resolution must behave the same whether column names are auto-detected or given
    // explicitly: the datasets match, and in both the alias equals the sum of its parts.
    #[test]
    fn test_from_parquet_alias_resolution_parity_auto_vs_explicit() {
        let auto = open_test_dataset(
            "data_f32.parquet",
            DatasetReadOptions::new().alias("resonance", ["kshort1", "kshort2"]),
        );
        let explicit = open_test_dataset(
            "data_f32.parquet",
            DatasetReadOptions::new()
                .p4_names(TEST_P4_NAMES)
                .aux_names(TEST_AUX_NAMES)
                .alias("resonance", ["kshort1", "kshort2"]),
        );

        assert_datasets_close(&auto, &explicit, TEST_P4_NAMES, TEST_AUX_NAMES);
        for event_index in 0..auto.n_events() {
            let auto_event = auto
                .named_event(event_index)
                .expect("auto parquet event should exist");
            let explicit_event = explicit
                .named_event(event_index)
                .expect("explicit parquet event should exist");

            let auto_alias = auto_event
                .p4("resonance")
                .expect("auto alias should resolve");
            let explicit_alias = explicit_event
                .p4("resonance")
                .expect("explicit alias should resolve");
            let auto_expected = auto_event.get_p4_sum(["kshort1", "kshort2"]);
            let explicit_expected = explicit_event.get_p4_sum(["kshort1", "kshort2"]);

            // Alias vs. explicit sum in the auto-detected dataset.
            assert_relative_eq!(auto_alias.px(), auto_expected.px(), epsilon = 1e-9);
            assert_relative_eq!(auto_alias.py(), auto_expected.py(), epsilon = 1e-9);
            assert_relative_eq!(auto_alias.pz(), auto_expected.pz(), epsilon = 1e-9);
            assert_relative_eq!(auto_alias.e(), auto_expected.e(), epsilon = 1e-9);

            // Alias vs. explicit sum in the explicitly named dataset.
            assert_relative_eq!(explicit_alias.px(), explicit_expected.px(), epsilon = 1e-9);
            assert_relative_eq!(explicit_alias.py(), explicit_expected.py(), epsilon = 1e-9);
            assert_relative_eq!(explicit_alias.pz(), explicit_expected.pz(), epsilon = 1e-9);
            assert_relative_eq!(explicit_alias.e(), explicit_expected.e(), epsilon = 1e-9);
        }
    }
2312
    // The f64 and f32 parquet test files must contain matching events (within 1e-9).
    #[test]
    fn test_from_parquet_f64_matches_f32() {
        let f32_ds = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
        let f64_ds = open_test_dataset("data_f64.parquet", DatasetReadOptions::new());
        assert_datasets_close(&f64_ds, &f32_ds, TEST_P4_NAMES, TEST_AUX_NAMES);
    }
2319
2320 #[test]
2321 fn test_from_root_detects_columns_and_matches_parquet() {
2322 let parquet = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
2323 let root_auto = open_test_dataset("data_f32.root", DatasetReadOptions::new());
2324 let mut detected_p4: Vec<&str> = root_auto.p4_names().iter().map(String::as_str).collect();
2325 detected_p4.sort_unstable();
2326 let mut expected_p4 = TEST_P4_NAMES.to_vec();
2327 expected_p4.sort_unstable();
2328 assert_eq!(detected_p4, expected_p4);
2329 let mut detected_aux: Vec<&str> =
2330 root_auto.aux_names().iter().map(String::as_str).collect();
2331 detected_aux.sort_unstable();
2332 let mut expected_aux = TEST_AUX_NAMES.to_vec();
2333 expected_aux.sort_unstable();
2334 assert_eq!(detected_aux, expected_aux);
2335 let root_named_options = DatasetReadOptions::new()
2336 .p4_names(TEST_P4_NAMES)
2337 .aux_names(TEST_AUX_NAMES);
2338 let root_named = open_test_dataset("data_f32.root", root_named_options);
2339 assert_datasets_close(&root_auto, &root_named, TEST_P4_NAMES, TEST_AUX_NAMES);
2340 assert_datasets_close(&root_auto, &parquet, TEST_P4_NAMES, TEST_AUX_NAMES);
2341 }
2342
    // Metadata (names and canonical p4 selections) read from ROOT after `use_mpi(true)` must
    // match the metadata read before MPI was enabled, for both auto-detected and explicit
    // names.
    #[cfg(feature = "mpi")]
    #[mpi_test(np = [2])]
    fn test_from_root_metadata_matches_non_mpi_under_mpi() {
        // Reference datasets are read before MPI is switched on.
        let reference_auto = open_test_dataset("data_f32.root", DatasetReadOptions::new());
        let explicit_options = DatasetReadOptions::new()
            .p4_names(TEST_P4_NAMES)
            .aux_names(TEST_AUX_NAMES);
        let reference_explicit = open_test_dataset("data_f32.root", explicit_options.clone());

        use_mpi(true);
        let local_auto = open_test_dataset("data_f32.root", DatasetReadOptions::new());
        let local_explicit = open_test_dataset("data_f32.root", explicit_options);

        assert_eq!(local_auto.p4_names(), reference_auto.p4_names());
        assert_eq!(local_auto.aux_names(), reference_auto.aux_names());
        assert_eq!(local_explicit.p4_names(), reference_explicit.p4_names());
        assert_eq!(local_explicit.aux_names(), reference_explicit.aux_names());
        assert_eq!(local_auto.p4_names(), local_explicit.p4_names());
        assert_eq!(local_auto.aux_names(), local_explicit.aux_names());

        // Every canonical p4 name must resolve to the same selection in all three datasets.
        for name in local_auto.p4_names() {
            let local_auto_selection = local_auto
                .metadata()
                .p4_selection(name)
                .expect("local auto canonical p4 selection should exist");
            let reference_auto_selection = reference_auto
                .metadata()
                .p4_selection(name)
                .expect("reference auto canonical p4 selection should exist");
            let local_explicit_selection = local_explicit
                .metadata()
                .p4_selection(name)
                .expect("local explicit canonical p4 selection should exist");
            assert_eq!(
                local_auto_selection.names(),
                reference_auto_selection.names()
            );
            assert_eq!(
                local_auto_selection.indices(),
                reference_auto_selection.indices()
            );
            assert_eq!(
                local_explicit_selection.names(),
                reference_auto_selection.names()
            );
            assert_eq!(
                local_explicit_selection.indices(),
                reference_auto_selection.indices()
            );
        }

        finalize_mpi();
    }
2396
    // The f64 ROOT test file must contain the same events as the f32 parquet file
    // (within 1e-9).
    #[test]
    fn test_from_root_f64_matches_parquet() {
        let parquet = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
        let root_f64 = open_test_dataset("data_f64.root", DatasetReadOptions::new());
        assert_datasets_close(&root_f64, &parquet, TEST_P4_NAMES, TEST_AUX_NAMES);
    }
2403
    // Alias metadata and alias values read from ROOT after `use_mpi(true)` must match the
    // ones read before MPI was enabled, comparing each rank's local events against the
    // corresponding slice of the reference dataset.
    #[cfg(feature = "mpi")]
    #[mpi_test(np = [2])]
    fn test_from_root_alias_resolution_matches_non_mpi_under_mpi() {
        let alias_options = DatasetReadOptions::new().alias("resonance", ["kshort1", "kshort2"]);
        let explicit_alias_options = DatasetReadOptions::new()
            .p4_names(TEST_P4_NAMES)
            .aux_names(TEST_AUX_NAMES)
            .alias("resonance", ["kshort1", "kshort2"]);
        // Reference datasets are read before MPI is switched on.
        let reference_auto = open_test_dataset("data_f32.root", alias_options.clone());
        let reference_explicit = open_test_dataset("data_f32.root", explicit_alias_options.clone());

        use_mpi(true);
        let world = get_world().expect("MPI world should be initialized");
        let local_auto = open_test_dataset("data_f32.root", alias_options);
        let local_explicit = open_test_dataset("data_f32.root", explicit_alias_options);

        let local_auto_alias = local_auto
            .metadata()
            .p4_selection("resonance")
            .expect("local auto alias should exist");
        let local_explicit_alias = local_explicit
            .metadata()
            .p4_selection("resonance")
            .expect("local explicit alias should exist");
        let reference_alias = reference_auto
            .metadata()
            .p4_selection("resonance")
            .expect("reference alias should exist");
        let reference_explicit_alias = reference_explicit
            .metadata()
            .p4_selection("resonance")
            .expect("reference explicit alias should exist");
        assert_eq!(local_auto_alias.names(), reference_alias.names());
        assert_eq!(local_auto_alias.indices(), reference_alias.indices());
        assert_eq!(
            local_explicit_alias.names(),
            reference_explicit_alias.names()
        );
        assert_eq!(
            local_explicit_alias.indices(),
            reference_explicit_alias.indices()
        );
        assert_eq!(local_auto_alias.names(), local_explicit_alias.names());
        assert_eq!(local_auto_alias.indices(), local_explicit_alias.indices());

        // This rank's share of the reference dataset's events.
        let partition = world.partition(reference_auto.n_events());
        let local_range = partition.range_for_rank(world.rank() as usize);
        assert_eq!(local_auto.n_events_local(), local_range.len());
        assert_eq!(local_explicit.n_events_local(), local_range.len());

        // Local index i on this rank corresponds to the i-th global index of its range.
        for (local_index, global_index) in local_range.enumerate() {
            let local_auto_event = local_auto.event_view(local_index);
            let local_explicit_event = local_explicit.event_view(local_index);
            let reference_event = reference_auto.event_view(global_index);
            let reference_explicit_event = reference_explicit.event_view(global_index);

            let local_auto_value = local_auto_event
                .p4("resonance")
                .expect("local auto alias should resolve");
            let local_explicit_value = local_explicit_event
                .p4("resonance")
                .expect("local explicit alias should resolve");
            let reference_value = reference_event
                .p4("resonance")
                .expect("reference alias should resolve");
            let reference_explicit_value = reference_explicit_event
                .p4("resonance")
                .expect("reference explicit alias should resolve");

            assert_relative_eq!(local_auto_value.px(), reference_value.px(), epsilon = 1e-9);
            assert_relative_eq!(local_auto_value.py(), reference_value.py(), epsilon = 1e-9);
            assert_relative_eq!(local_auto_value.pz(), reference_value.pz(), epsilon = 1e-9);
            assert_relative_eq!(local_auto_value.e(), reference_value.e(), epsilon = 1e-9);

            assert_relative_eq!(
                local_explicit_value.px(),
                reference_explicit_value.px(),
                epsilon = 1e-9
            );
            assert_relative_eq!(
                local_explicit_value.py(),
                reference_explicit_value.py(),
                epsilon = 1e-9
            );
            assert_relative_eq!(
                local_explicit_value.pz(),
                reference_explicit_value.pz(),
                epsilon = 1e-9
            );
            assert_relative_eq!(
                local_explicit_value.e(),
                reference_explicit_value.e(),
                epsilon = 1e-9
            );
        }

        finalize_mpi();
    }
2502
2503 #[test]
2504 fn test_event_creation() {
2505 let event = test_event();
2506 assert_eq!(event.p4s.len(), 4);
2507 assert_eq!(event.aux.len(), 2);
2508 assert_relative_eq!(event.weight, 0.48)
2509 }
2510
2511 #[test]
2512 fn test_event_p4_sum() {
2513 let event = test_event();
2514 let sum = event.get_p4_sum([2, 3]);
2515 assert_relative_eq!(sum.px(), event.p4s[2].px() + event.p4s[3].px());
2516 assert_relative_eq!(sum.py(), event.p4s[2].py() + event.p4s[3].py());
2517 assert_relative_eq!(sum.pz(), event.p4s[2].pz() + event.p4s[3].pz());
2518 assert_relative_eq!(sum.e(), event.p4s[2].e() + event.p4s[3].e());
2519 }
2520
2521 #[test]
2522 fn test_event_boost() {
2523 let event = test_event();
2524 let event_boosted = event.boost_to_rest_frame_of([1, 2, 3]);
2525 let p4_sum = event_boosted.get_p4_sum([1, 2, 3]);
2526 assert_relative_eq!(p4_sum.px(), 0.0);
2527 assert_relative_eq!(p4_sum.py(), 0.0);
2528 assert_relative_eq!(p4_sum.pz(), 0.0, epsilon = f64::EPSILON.sqrt());
2529 }
2530
2531 #[test]
2532 fn test_named_event_view_evaluate() {
2533 let dataset = test_dataset();
2534 let event = dataset.event_view(0);
2535 let mut mass = Mass::new(["proton"]);
2536 mass.bind(dataset.metadata()).unwrap();
2537 assert_relative_eq!(event.evaluate(&mass), 1.007);
2538 }
2539
2540 #[test]
2541 fn test_dataset_size_check() {
2542 let dataset = Dataset::new(Vec::new());
2543 assert_eq!(dataset.n_events(), 0);
2544 let dataset = Dataset::new(vec![Arc::new(test_event())]);
2545 assert_eq!(dataset.n_events(), 1);
2546 }
2547
2548 #[test]
2549 fn test_dataset_sum() {
2550 let dataset = test_dataset();
2551 let metadata = dataset.metadata_arc();
2552 let dataset2 = Dataset::new_with_metadata(
2553 vec![Arc::new(EventData {
2554 p4s: test_event().p4s,
2555 aux: test_event().aux,
2556 weight: 0.52,
2557 })],
2558 metadata.clone(),
2559 );
2560 let dataset_sum = &dataset + &dataset2;
2561 assert_eq!(
2562 dataset_sum.event(0).expect("event should exist").weight,
2563 dataset.event(0).expect("event should exist").weight
2564 );
2565 assert_eq!(
2566 dataset_sum.event(1).expect("event should exist").weight,
2567 dataset2.event(0).expect("event should exist").weight
2568 );
2569 }
2570
2571 #[test]
2572 fn test_dataset_weights() {
2573 let dataset = Dataset::new(vec![
2574 Arc::new(test_event()),
2575 Arc::new(EventData {
2576 p4s: test_event().p4s,
2577 aux: test_event().aux,
2578 weight: 0.52,
2579 }),
2580 ]);
2581 let weights = dataset.weights();
2582 assert_eq!(weights.len(), 2);
2583 assert_relative_eq!(weights[0], 0.48);
2584 assert_relative_eq!(weights[1], 0.52);
2585 assert_relative_eq!(dataset.n_events_weighted(), 1.0);
2586 }
2587
2588 #[test]
2589 #[should_panic(
2590 expected = "Dataset requires rectangular p4/aux columns for canonical columnar storage"
2591 )]
2592 fn test_dataset_rejects_ragged_rows_at_construction() {
2593 let _ = Dataset::new(vec![
2594 Arc::new(EventData {
2595 p4s: vec![Vec4::new(0.0, 0.0, 1.0, 1.0)],
2596 aux: vec![0.1],
2597 weight: 1.0,
2598 }),
2599 Arc::new(EventData {
2600 p4s: vec![],
2601 aux: vec![0.2, 0.3],
2602 weight: 2.0,
2603 }),
2604 ]);
2605 }
2606
2607 #[test]
2608 fn test_dataset_filtering() {
2609 let metadata = Arc::new(
2610 DatasetMetadata::new(vec!["beam"], Vec::<String>::new())
2611 .expect("metadata should be valid"),
2612 );
2613 let events = vec![
2614 Arc::new(EventData {
2615 p4s: vec![Vec3::new(0.0, 0.0, 5.0).with_mass(0.0)],
2616 aux: vec![],
2617 weight: 1.0,
2618 }),
2619 Arc::new(EventData {
2620 p4s: vec![Vec3::new(0.0, 0.0, 5.0).with_mass(0.5)],
2621 aux: vec![],
2622 weight: 1.0,
2623 }),
2624 Arc::new(EventData {
2625 p4s: vec![Vec3::new(0.0, 0.0, 5.0).with_mass(1.1)],
2626 aux: vec![],
2629 weight: 1.0,
2630 }),
2631 ];
2632 let dataset = Dataset::new_with_metadata(events, metadata);
2633
2634 let metadata = dataset.metadata_arc();
2635 let mut mass = Mass::new(["beam"]);
2636 mass.bind(metadata.as_ref()).unwrap();
2637 let expression = mass.gt(0.0).and(&mass.lt(1.0));
2638
2639 let filtered = dataset.filter(&expression).unwrap();
2640 assert_eq!(filtered.n_events(), 1);
2641 assert_relative_eq!(mass.value(&filtered.event_view(0)), 0.5);
2642 }
2643
2644 #[test]
2645 fn test_dataset_boost() {
2646 let dataset = test_dataset();
2647 let dataset_boosted = dataset.boost_to_rest_frame_of(&["proton", "kshort1", "kshort2"]);
2648 let p4_sum = dataset_boosted
2649 .event(0)
2650 .expect("event should exist")
2651 .get_p4_sum(["proton", "kshort1", "kshort2"]);
2652 assert_relative_eq!(p4_sum.px(), 0.0);
2653 assert_relative_eq!(p4_sum.py(), 0.0);
2654 assert_relative_eq!(p4_sum.pz(), 0.0, epsilon = f64::EPSILON.sqrt());
2655 }
2656
2657 #[test]
2658 fn test_named_event_view() {
2659 let dataset = test_dataset();
2660 let view = dataset.named_event(0).expect("event should exist");
2661 let dataset_event = dataset.event(0).expect("event should exist");
2662 assert_relative_eq!(view.weight(), dataset_event.weight);
2663 let beam = view.p4("beam").expect("beam p4");
2664 assert_relative_eq!(beam.px(), dataset_event.p4s[0].px());
2665 assert_relative_eq!(beam.e(), dataset_event.p4s[0].e());
2666
2667 let summed = view.get_p4_sum(["kshort1", "kshort2"]);
2668 assert_relative_eq!(
2669 summed.e(),
2670 dataset_event.p4s[2].e() + dataset_event.p4s[3].e()
2671 );
2672
2673 let aux_angle = view.aux().get("pol_angle").copied().expect("pol angle");
2674 assert_relative_eq!(aux_angle, dataset_event.aux[1]);
2675
2676 let metadata = dataset.metadata_arc();
2677 let boosted = view.boost_to_rest_frame_of(["proton", "kshort1", "kshort2"]);
2678 let boosted_event = Event::new(Arc::new(boosted), metadata);
2679 let boosted_sum = boosted_event.get_p4_sum(["proton", "kshort1", "kshort2"]);
2680 assert_relative_eq!(boosted_sum.px(), 0.0);
2681 }
2682
2683 #[test]
2684 fn test_dataset_evaluate() {
2685 let dataset = test_dataset();
2686 let mass = Mass::new(["proton"]);
2687 assert_relative_eq!(dataset.evaluate(&mass).unwrap()[0], 1.007);
2688 }
2689
2690 #[test]
2691 fn test_dataset_metadata_rejects_duplicate_names() {
2692 let err = DatasetMetadata::new(vec!["beam", "beam"], Vec::<String>::new());
2693 assert!(matches!(
2694 err,
2695 Err(LadduError::DuplicateName { category, .. }) if category == "p4"
2696 ));
2697 let err = DatasetMetadata::new(
2698 vec!["beam"],
2699 vec!["pol_angle".to_string(), "pol_angle".to_string()],
2700 );
2701 assert!(matches!(
2702 err,
2703 Err(LadduError::DuplicateName { category, .. }) if category == "aux"
2704 ));
2705 }
2706
2707 #[test]
2708 fn test_dataset_lookup_by_name() {
2709 let dataset = test_dataset();
2710 let proton = dataset.p4_by_name(0, "proton").expect("proton p4");
2711 let proton_idx = dataset.metadata().p4_index("proton").unwrap();
2712 assert_relative_eq!(
2713 proton.e(),
2714 dataset.event(0).expect("event should exist").p4s[proton_idx].e()
2715 );
2716 assert!(dataset.p4_by_name(0, "unknown").is_none());
2717 let angle = dataset.aux_by_name(0, "pol_angle").expect("pol_angle");
2718 assert_relative_eq!(angle, dataset.event(0).expect("event should exist").aux[1]);
2719 assert!(dataset.aux_by_name(0, "missing").is_none());
2720 }
2721
2722 #[test]
2723 fn test_binned_dataset() {
2724 let dataset = Dataset::new(vec![
2725 Arc::new(EventData {
2726 p4s: vec![Vec3::new(0.0, 0.0, 1.0).with_mass(1.0)],
2727 aux: vec![],
2728 weight: 1.0,
2729 }),
2730 Arc::new(EventData {
2731 p4s: vec![Vec3::new(0.0, 0.0, 2.0).with_mass(2.0)],
2732 aux: vec![],
2733 weight: 2.0,
2734 }),
2735 ]);
2736
2737 #[derive(Clone, Serialize, Deserialize, Debug)]
2738 struct BeamEnergy;
2739 impl Display for BeamEnergy {
2740 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2741 write!(f, "BeamEnergy")
2742 }
2743 }
2744 #[typetag::serde]
2745 impl Variable for BeamEnergy {
2746 fn value(&self, event: &NamedEventView<'_>) -> f64 {
2747 event.p4_at(0).e()
2748 }
2749 }
2750 assert_eq!(BeamEnergy.to_string(), "BeamEnergy");
2751
2752 let binned = dataset.bin_by(BeamEnergy, 2, (0.0, 3.0)).unwrap();
2754
2755 assert_eq!(binned.n_bins(), 2);
2756 assert_eq!(binned.edges().len(), 3);
2757 assert_relative_eq!(binned.edges()[0], 0.0);
2758 assert_relative_eq!(binned.edges()[2], 3.0);
2759 assert_eq!(binned[0].n_events(), 1);
2760 assert_relative_eq!(binned[0].n_events_weighted(), 1.0);
2761 assert_eq!(binned[1].n_events(), 1);
2762 assert_relative_eq!(binned[1].n_events_weighted(), 2.0);
2763 }
2764
2765 #[test]
2766 fn test_dataset_bootstrap() {
2767 let metadata = test_dataset().metadata_arc();
2768 let dataset = Dataset::new_with_metadata(
2769 vec![
2770 Arc::new(test_event()),
2771 Arc::new(EventData {
2772 p4s: test_event().p4s.clone(),
2773 aux: test_event().aux.clone(),
2774 weight: 1.0,
2775 }),
2776 ],
2777 metadata,
2778 );
2779 assert_relative_ne!(
2780 dataset.event(0).expect("event should exist").weight,
2781 dataset.event(1).expect("event should exist").weight
2782 );
2783
2784 let bootstrapped = dataset.bootstrap(43);
2785 assert_eq!(bootstrapped.n_events(), dataset.n_events());
2786 assert_relative_eq!(
2787 bootstrapped.event(0).expect("event should exist").weight,
2788 bootstrapped.event(1).expect("event should exist").weight
2789 );
2790
2791 let empty_dataset = Dataset::new(Vec::new());
2793 let empty_bootstrap = empty_dataset.bootstrap(43);
2794 assert_eq!(empty_bootstrap.n_events(), 0);
2795 }
2796
2797 fn assert_weight_cache_matches_local_events(dataset: &Dataset) {
2798 #[cfg(feature = "rayon")]
2799 let expected = dataset
2800 .events_local()
2801 .par_iter()
2802 .map(|event| event.weight)
2803 .parallel_sum_with_accumulator::<Klein<f64>>();
2804 #[cfg(not(feature = "rayon"))]
2805 let expected = dataset
2806 .events_local()
2807 .iter()
2808 .map(|event| event.weight)
2809 .sum_with_accumulator::<Klein<f64>>();
2810 assert_relative_eq!(dataset.cached_local_weighted_sum, expected);
2811 assert_relative_eq!(dataset.n_events_weighted_local(), expected);
2812 }
2813
2814 #[test]
2815 fn test_weight_cache_recomputed_for_dataset_transforms() {
2816 let metadata = Arc::new(
2817 DatasetMetadata::new(vec!["beam"], Vec::<String>::new())
2818 .expect("metadata should be valid"),
2819 );
2820 let dataset = Dataset::new_with_metadata(
2821 vec![
2822 Arc::new(EventData {
2823 p4s: vec![Vec3::new(0.0, 0.0, 1.0).with_mass(0.0)],
2824 aux: vec![],
2825 weight: 1.0,
2826 }),
2827 Arc::new(EventData {
2828 p4s: vec![Vec3::new(0.0, 0.0, 2.0).with_mass(0.0)],
2829 aux: vec![],
2830 weight: 2.0,
2831 }),
2832 Arc::new(EventData {
2833 p4s: vec![Vec3::new(0.0, 0.0, 3.0).with_mass(0.0)],
2834 aux: vec![],
2835 weight: 3.0,
2836 }),
2837 ],
2838 metadata,
2839 );
2840 assert_weight_cache_matches_local_events(&dataset);
2841
2842 let filtered = dataset.filter(&Mass::new(["beam"]).gt(0.0)).unwrap();
2843 assert_weight_cache_matches_local_events(&filtered);
2844
2845 let bootstrapped = dataset.bootstrap(7);
2846 assert_weight_cache_matches_local_events(&bootstrapped);
2847
2848 let boosted = dataset.boost_to_rest_frame_of(&["beam"]);
2849 assert_weight_cache_matches_local_events(&boosted);
2850
2851 let combined = &dataset + &dataset;
2852 assert_weight_cache_matches_local_events(&combined);
2853 }
2854
2855 #[test]
2856 fn test_dataset_iteration_returns_events() {
2857 let dataset = test_dataset();
2858 let mut weights = Vec::new();
2859 for event in dataset.iter() {
2860 weights.push(event.weight());
2861 }
2862 assert_eq!(weights.len(), dataset.n_events());
2863 assert_relative_eq!(
2864 weights[0],
2865 dataset.event(0).expect("event should exist").weight
2866 );
2867 }
2868
2869 #[test]
2870 fn test_dataset_into_iter_returns_events() {
2871 let dataset = test_dataset();
2872 let weights: Vec<f64> = dataset.into_iter().map(|event| event.weight()).collect();
2873 assert_eq!(weights.len(), 1);
2874 assert_relative_eq!(weights[0], test_event().weight);
2875 }
2876
2877 #[test]
2878 fn test_dataset_arc_into_iter_returns_events() {
2879 let dataset = Arc::new(test_dataset());
2880 let weights: Vec<f64> = dataset.shared_iter().map(|event| event.weight()).collect();
2881 assert_eq!(weights.len(), 1);
2882 assert_relative_eq!(weights[0], test_event().weight);
2883 }
2884
2885 #[test]
2886 fn test_dataset_get_event_local_reuses_underlying_data() {
2887 let dataset = test_dataset();
2888 let first = dataset.get_event(0).expect("event should exist");
2889 let second = dataset.get_event(0).expect("event should exist");
2890 assert!(Arc::ptr_eq(&first.data_arc(), &second.data_arc()));
2891 }
2892
2893 #[test]
2894 fn test_dataset_event_out_of_bounds_is_error() {
2895 let dataset = test_dataset();
2896 assert!(dataset.event(99).is_err());
2897 assert!(dataset.get_event(99).is_none());
2898 }
2899
2900 #[cfg(feature = "mpi")]
2901 fn event_iteration_signature<I>(iter: I) -> (usize, f64, f64, f64)
2902 where
2903 I: IntoIterator<Item = Event>,
2904 {
2905 let mut count = 0usize;
2906 let mut weight_signature = 0.0;
2907 let mut beam_signature = 0.0;
2908 let mut aux_signature = 0.0;
2909
2910 for (index, event) in iter.into_iter().enumerate() {
2911 let position = (index + 1) as f64;
2912 count += 1;
2913 weight_signature += position * event.weight();
2914 beam_signature += position * event.p4("beam").expect("beam should exist").e();
2915 aux_signature += position
2916 * event
2917 .aux()
2918 .get("pol_angle")
2919 .copied()
2920 .expect("pol_angle should exist");
2921 }
2922
2923 (count, weight_signature, beam_signature, aux_signature)
2924 }
2925
2926 #[cfg(feature = "mpi")]
2927 fn read_resident_rss_kb() -> Option<u64> {
2928 #[cfg(target_os = "linux")]
2929 {
2930 let status = fs::read_to_string("/proc/self/status").ok()?;
2931 let vm_rss = status
2932 .lines()
2933 .find(|line| line.starts_with("VmRSS:"))?
2934 .split_whitespace()
2935 .nth(1)?;
2936 vm_rss.parse::<u64>().ok()
2937 }
2938
2939 #[cfg(not(target_os = "linux"))]
2940 {
2941 None
2942 }
2943 }
2944
2945 #[test]
2946 fn test_dataset_event_stress_local_repeated_access() {
2947 let metadata = test_dataset().metadata_arc();
2948 let base = test_event();
2949 let mut events = Vec::new();
2950 for idx in 0..8 {
2951 events.push(Arc::new(EventData {
2952 p4s: base.p4s.clone(),
2953 aux: base.aux.clone(),
2954 weight: 1.0 + idx as f64,
2955 }));
2956 }
2957 let dataset = Dataset::new_with_metadata(events, metadata);
2958 let baseline: Vec<f64> = (0..dataset.n_events())
2959 .map(|index| dataset.event(index).expect("event should exist").weight())
2960 .collect();
2961
2962 for _ in 0..250 {
2963 for (index, expected_weight) in baseline.iter().enumerate() {
2964 let event = dataset.event(index).expect("event should exist");
2965 assert_relative_eq!(event.weight(), *expected_weight);
2966 }
2967 }
2968 }
2969
2970 #[cfg(feature = "mpi")]
2971 #[mpi_test(np = [2])]
2972 fn test_dataset_event_mpi_repeated_access_is_stable() {
2973 use_mpi(true);
2974 assert!(get_world().is_some(), "MPI world should be initialized");
2975
2976 let dataset = test_dataset();
2977 for _ in 0..32 {
2978 let first = dataset.event(0).expect("event should exist");
2979 let second = dataset.event(0).expect("event should exist");
2980 assert_relative_eq!(first.weight(), second.weight());
2981 }
2982 finalize_mpi();
2983 }
2984
2985 #[cfg(feature = "mpi")]
2986 #[mpi_test(np = [2])]
2987 fn test_dataset_event_stress_mpi_repeated_access() {
2988 use_mpi(true);
2989 assert!(get_world().is_some(), "MPI world should be initialized");
2990
2991 let metadata = test_dataset().metadata_arc();
2992 let base = test_event();
2993 let mut events = Vec::new();
2994 for idx in 0..8 {
2995 events.push(Arc::new(EventData {
2996 p4s: base.p4s.clone(),
2997 aux: base.aux.clone(),
2998 weight: 1.0 + idx as f64,
2999 }));
3000 }
3001 let dataset = Dataset::new_with_metadata(events, metadata);
3002
3003 let baseline: Vec<f64> = (0..dataset.n_events())
3004 .map(|index| dataset.event(index).expect("event should exist").weight())
3005 .collect();
3006
3007 for _ in 0..120 {
3008 for (index, expected_weight) in baseline.iter().enumerate() {
3009 let event = dataset.event(index).expect("event should exist");
3010 assert_relative_eq!(event.weight(), *expected_weight);
3011 }
3012 }
3013 finalize_mpi();
3014 }
3015
3016 #[cfg(feature = "mpi")]
3017 #[mpi_test(np = [2])]
3018 fn test_dataset_iter_stress_mpi_repeated_passes() {
3019 use_mpi(true);
3020 assert!(get_world().is_some(), "MPI world should be initialized");
3021
3022 let metadata = test_dataset().metadata_arc();
3023 let base = test_event();
3024 let mut events = Vec::new();
3025 for idx in 0..8 {
3026 events.push(Arc::new(EventData {
3027 p4s: base.p4s.clone(),
3028 aux: base.aux.clone(),
3029 weight: 1.0 + idx as f64,
3030 }));
3031 }
3032 let dataset = Dataset::new_with_metadata(events, metadata);
3033 let baseline: Vec<f64> = dataset.iter().map(|event| event.weight()).collect();
3034
3035 for _ in 0..80 {
3036 let current: Vec<f64> = dataset.iter().map(|event| event.weight()).collect();
3037 assert_eq!(current.len(), baseline.len());
3038 for (current_weight, expected_weight) in current.iter().zip(baseline.iter()) {
3039 assert_relative_eq!(*current_weight, *expected_weight);
3040 }
3041 }
3042 finalize_mpi();
3043 }
3044
3045 #[cfg(feature = "mpi")]
3046 #[mpi_test(np = [2])]
3047 fn test_dataset_arc_into_iter_stress_mpi_repeated_passes() {
3048 use_mpi(true);
3049 assert!(get_world().is_some(), "MPI world should be initialized");
3050
3051 let metadata = test_dataset().metadata_arc();
3052 let base = test_event();
3053 let mut events = Vec::new();
3054 for idx in 0..8 {
3055 events.push(Arc::new(EventData {
3056 p4s: base.p4s.clone(),
3057 aux: base.aux.clone(),
3058 weight: 1.0 + idx as f64,
3059 }));
3060 }
3061 let dataset = Arc::new(Dataset::new_with_metadata(events, metadata));
3062 let baseline: Vec<f64> = dataset.shared_iter().map(|event| event.weight()).collect();
3063
3064 for _ in 0..80 {
3065 let current: Vec<f64> = dataset.shared_iter().map(|event| event.weight()).collect();
3066 assert_eq!(current.len(), baseline.len());
3067 for (current_weight, expected_weight) in current.iter().zip(baseline.iter()) {
3068 assert_relative_eq!(*current_weight, *expected_weight);
3069 }
3070 }
3071 finalize_mpi();
3072 }
3073
3074 #[cfg(feature = "mpi")]
3075 #[mpi_test(np = [2])]
3076 fn test_dataset_iteration_long_running_mpi_repeated_passes() {
3077 use_mpi(true);
3078 assert!(get_world().is_some(), "MPI world should be initialized");
3079
3080 let dataset = Arc::new(mpi_chunk_test_dataset(8_192));
3081 let baseline_iter = event_iteration_signature(dataset.iter());
3082 let baseline_shared = event_iteration_signature(dataset.shared_iter());
3083 assert_eq!(baseline_iter, baseline_shared);
3084 let mut post_warmup_rss_kb = Vec::new();
3085
3086 for pass_index in 0..48 {
3087 let current_iter = event_iteration_signature(dataset.iter());
3088 let current_shared = event_iteration_signature(dataset.shared_iter());
3089 assert_eq!(current_iter, baseline_iter);
3090 assert_eq!(current_shared, baseline_shared);
3091 if pass_index >= 7 {
3092 if let Some(rss_kb) = read_resident_rss_kb() {
3093 post_warmup_rss_kb.push(rss_kb);
3094 }
3095 }
3096 }
3097
3098 if let Some((&first_rss_kb, rest_rss_kb)) = post_warmup_rss_kb.split_first() {
3099 let last_rss_kb = *rest_rss_kb.last().unwrap_or(&first_rss_kb);
3100 let min_rss_kb = post_warmup_rss_kb
3101 .iter()
3102 .copied()
3103 .min()
3104 .expect("post-warmup RSS sample should exist");
3105 let max_rss_kb = post_warmup_rss_kb
3106 .iter()
3107 .copied()
3108 .max()
3109 .expect("post-warmup RSS sample should exist");
3110 const MAX_POST_WARMUP_RSS_GROWTH_KB: u64 = 32 * 1024;
3111 const MAX_POST_WARMUP_RSS_SPREAD_KB: u64 = 32 * 1024;
3112 assert!(
3113 last_rss_kb.saturating_sub(first_rss_kb) <= MAX_POST_WARMUP_RSS_GROWTH_KB,
3114 "post-warmup RSS grew by {} KiB (first={} KiB, last={} KiB)",
3115 last_rss_kb.saturating_sub(first_rss_kb),
3116 first_rss_kb,
3117 last_rss_kb
3118 );
3119 assert!(
3120 max_rss_kb.saturating_sub(min_rss_kb) <= MAX_POST_WARMUP_RSS_SPREAD_KB,
3121 "post-warmup RSS spread was {} KiB (min={} KiB, max={} KiB)",
3122 max_rss_kb.saturating_sub(min_rss_kb),
3123 min_rss_kb,
3124 max_rss_kb
3125 );
3126 }
3127
3128 finalize_mpi();
3129 }
3130
3131 #[cfg(feature = "mpi")]
3132 #[mpi_test(np = [2])]
3133 fn test_fetch_event_chunk_mpi_matches_single_event_fetches() {
3134 use_mpi(true);
3135 let world = get_world().expect("MPI world should be initialized");
3136
3137 let dataset = mpi_chunk_test_dataset(8);
3138 let chunk = fetch_event_chunk_mpi(&dataset, 1, 5, &world, dataset.n_events());
3139
3140 assert_eq!(chunk.len(), 5);
3141 for (offset, event) in chunk.iter().enumerate() {
3142 let baseline = dataset
3143 .event(1 + offset)
3144 .expect("chunk baseline event should exist");
3145 assert_events_close(event, &baseline, TEST_P4_NAMES, TEST_AUX_NAMES);
3146 }
3147
3148 assert!(
3149 fetch_event_chunk_mpi(&dataset, dataset.n_events(), 4, &world, dataset.n_events())
3150 .is_empty()
3151 );
3152 finalize_mpi();
3153 }
3154
3155 #[cfg(feature = "mpi")]
3156 #[mpi_test(np = [2])]
3157 fn test_fetch_event_chunk_mpi_truncates_at_dataset_end() {
3158 use_mpi(true);
3159 let world = get_world().expect("MPI world should be initialized");
3160
3161 let dataset = mpi_chunk_test_dataset(8);
3162 let chunk = fetch_event_chunk_mpi(&dataset, 6, 10, &world, dataset.n_events());
3163
3164 assert_eq!(chunk.len(), 2);
3165 for (offset, event) in chunk.iter().enumerate() {
3166 let baseline = dataset
3167 .event(6 + offset)
3168 .expect("truncated chunk baseline event should exist");
3169 assert_events_close(event, &baseline, TEST_P4_NAMES, TEST_AUX_NAMES);
3170 }
3171 finalize_mpi();
3172 }
3173
3174 #[cfg(feature = "mpi")]
3175 #[mpi_test(np = [2])]
3176 fn test_mpi_event_chunk_cursor_reuses_cached_chunk_for_dataset_and_events() {
3177 use_mpi(true);
3178 let world = get_world().expect("MPI world should be initialized");
3179
3180 let dataset = mpi_chunk_test_dataset(9);
3181 let total = dataset.n_events();
3182 let metadata = dataset.metadata_arc();
3183
3184 let mut dataset_cursor = MpiEventChunkCursor::new(3);
3185 for index in 0..total {
3186 let actual = dataset_cursor
3187 .event_for_dataset(&dataset, index, &world, total)
3188 .expect("dataset cursor event should exist");
3189 let expected = dataset.event(index).expect("baseline event should exist");
3190 assert_events_close(&actual, &expected, TEST_P4_NAMES, TEST_AUX_NAMES);
3191 }
3192 assert!(dataset_cursor
3193 .event_for_dataset(&dataset, total, &world, total)
3194 .is_none());
3195
3196 let mut events_cursor = MpiEventChunkCursor::new(4);
3197 for index in 0..total {
3198 let actual = events_cursor
3199 .event_for_events(dataset.events_local(), &metadata, index, &world, total)
3200 .expect("events cursor event should exist");
3201 let expected = dataset.event(index).expect("baseline event should exist");
3202 assert_events_close(&actual, &expected, TEST_P4_NAMES, TEST_AUX_NAMES);
3203 }
3204 finalize_mpi();
3205 }
3206
3207 #[cfg(feature = "mpi")]
3208 #[test]
3209 #[ignore = "developer probe for MPI iteration chunk-size tuning"]
3210 fn probe_mpi_iteration_chunk_size() {
3211 use std::time::Instant;
3212
3213 use_mpi(true);
3214 let Some(world) = get_world() else {
3215 finalize_mpi();
3216 return;
3217 };
3218
3219 let dataset = mpi_chunk_test_dataset(32_768);
3220 let total = dataset.n_events();
3221 let chunk_sizes = [64_usize, 128, 256, 512, 1024];
3222 if world.rank() == 0 {
3223 println!("probe=iteration");
3224 }
3225 for chunk_size in chunk_sizes {
3226 let started = Instant::now();
3227 let mut checksum = 0.0;
3228 for _ in 0..8 {
3229 let mut cursor = MpiEventChunkCursor::new(chunk_size);
3230 for index in 0..total {
3231 let event = cursor
3232 .event_for_dataset(&dataset, index, &world, total)
3233 .expect("cursor event should exist");
3234 checksum += event.weight() + event.p4("beam").expect("beam should exist").e();
3235 }
3236 }
3237 if world.rank() == 0 {
3238 println!(
3239 "probe=iteration chunk_size={} elapsed_sec={:.6} checksum={:.6}",
3240 chunk_size,
3241 started.elapsed().as_secs_f64(),
3242 checksum,
3243 );
3244 }
3245 }
3246 finalize_mpi();
3247 }
3248
3249 #[cfg(feature = "mpi")]
3250 #[test]
3251 #[ignore = "developer probe for MPI ROOT write chunk-size tuning"]
3252 fn probe_mpi_root_write_chunk_size() {
3253 use std::time::Instant;
3254
3255 use_mpi(true);
3256 let Some(world) = get_world() else {
3257 finalize_mpi();
3258 return;
3259 };
3260
3261 let dataset = Arc::new(mpi_chunk_test_dataset(32_768));
3262 let chunk_sizes = [64_usize, 128, 256, 512, 1024];
3263 if world.rank() == 0 {
3264 println!("probe=root_write");
3265 }
3266 for chunk_size in chunk_sizes {
3267 let dir = make_temp_dir();
3268 let path = dir.join(format!("mpi_chunk_probe_{chunk_size}.root"));
3269 let path_str = path.to_str().expect("probe path should be valid UTF-8");
3270 let started = Instant::now();
3271 for _ in 0..4 {
3272 io::write_root_with_chunk_size_for_test(
3273 &dataset,
3274 path_str,
3275 &DatasetWriteOptions::default(),
3276 chunk_size,
3277 )
3278 .expect("probe root write should succeed");
3279 }
3280
3281 if world.rank() == 0 {
3282 println!(
3283 "probe=root_write chunk_size={} elapsed_sec={:.6}",
3284 chunk_size,
3285 started.elapsed().as_secs_f64(),
3286 );
3287 fs::remove_dir_all(&dir).expect("probe temp dir cleanup should succeed");
3288 }
3289 }
3290 finalize_mpi();
3291 }
3292
3293 #[test]
3294 fn test_event_display() {
3295 let event = test_event();
3296 let display_string = format!("{}", event);
3297 assert!(display_string.contains("Event:"));
3298 assert!(display_string.contains("p4s:"));
3299 assert!(display_string.contains("aux:"));
3300 assert!(display_string.contains("aux[0]: 0.38562805"));
3301 assert!(display_string.contains("aux[1]: 0.05708078"));
3302 assert!(display_string.contains("weight:"));
3303 }
3304
3305 #[test]
3306 fn test_name_based_access() {
3307 let metadata =
3308 Arc::new(DatasetMetadata::new(vec!["beam", "target"], vec!["pol_angle"]).unwrap());
3309 let event = Arc::new(EventData {
3310 p4s: vec![Vec4::new(0.0, 0.0, 1.0, 1.0), Vec4::new(0.1, 0.2, 0.3, 0.5)],
3311 aux: vec![0.42],
3312 weight: 1.0,
3313 });
3314 let dataset = Dataset::new_with_metadata(vec![event], metadata);
3315 let beam = dataset.p4_by_name(0, "beam").unwrap();
3316 assert_relative_eq!(beam.px(), 0.0);
3317 assert_relative_eq!(beam.py(), 0.0);
3318 assert_relative_eq!(beam.pz(), 1.0);
3319 assert_relative_eq!(beam.e(), 1.0);
3320 assert_relative_eq!(dataset.aux_by_name(0, "pol_angle").unwrap(), 0.42);
3321 assert!(dataset.p4_by_name(0, "missing").is_none());
3322 assert!(dataset.aux_by_name(0, "missing").is_none());
3323 }
3324
3325 #[test]
3326 fn test_parquet_roundtrip_to_tempfile() {
3327 let dataset = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
3328 let dir = make_temp_dir();
3329 let path = dir.join("roundtrip.parquet");
3330 let path_str = path.to_str().expect("path should be valid UTF-8");
3331
3332 write_parquet(&dataset, path_str, &DatasetWriteOptions::default())
3333 .expect("writing parquet should succeed");
3334 let reopened = read_parquet(path_str, &DatasetReadOptions::new())
3335 .expect("parquet roundtrip should reopen");
3336
3337 assert_datasets_close(&dataset, &reopened, TEST_P4_NAMES, TEST_AUX_NAMES);
3338 fs::remove_dir_all(&dir).expect("temp dir cleanup should succeed");
3339 }
3340
3341 #[test]
3342 fn test_parquet_roundtrip_incremental_small_batches() {
3343 let dataset = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
3344 let dir = make_temp_dir();
3345 let path = dir.join("roundtrip_small_batches.parquet");
3346 let path_str = path.to_str().expect("path should be valid UTF-8");
3347
3348 let write_options = DatasetWriteOptions::default().batch_size(1);
3349 write_parquet(&dataset, path_str, &write_options)
3350 .expect("writing parquet in small batches should succeed");
3351 let reopened = read_parquet(path_str, &DatasetReadOptions::new())
3352 .expect("parquet roundtrip should reopen");
3353
3354 assert_datasets_close(&dataset, &reopened, TEST_P4_NAMES, TEST_AUX_NAMES);
3355 fs::remove_dir_all(&dir).expect("temp dir cleanup should succeed");
3356 }
3357
3358 #[test]
3359 fn test_parquet_read_order_is_deterministic_across_repeated_reads() {
3360 let dataset = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
3361 let dir = make_temp_dir();
3362 let path = dir.join("deterministic_order.parquet");
3363 let path_str = path.to_str().expect("path should be valid UTF-8");
3364
3365 let write_options = DatasetWriteOptions::default().batch_size(1);
3367 write_parquet(&dataset, path_str, &write_options)
3368 .expect("writing parquet in small batches should succeed");
3369
3370 let first = read_parquet(path_str, &DatasetReadOptions::new())
3371 .expect("first parquet read should succeed");
3372 let second = read_parquet(path_str, &DatasetReadOptions::new())
3373 .expect("second parquet read should succeed");
3374
3375 assert_eq!(first.n_events(), second.n_events());
3376 assert_eq!(first.n_events(), dataset.n_events());
3377 for event_index in 0..dataset.n_events() {
3378 let source = dataset
3379 .event(event_index)
3380 .expect("source event should exist");
3381 let first_event = first
3382 .event(event_index)
3383 .expect("first read event should exist");
3384 let second_event = second
3385 .event(event_index)
3386 .expect("second read event should exist");
3387 assert_events_close(&source, &first_event, TEST_P4_NAMES, TEST_AUX_NAMES);
3388 assert_events_close(&source, &second_event, TEST_P4_NAMES, TEST_AUX_NAMES);
3389 }
3390
3391 fs::remove_dir_all(&dir).expect("temp dir cleanup should succeed");
3392 }
3393
3394 #[test]
3395 fn test_parquet_storage_roundtrip_to_tempfile() {
3396 let source_path = test_data_path("data_f32.parquet");
3397 let source_path_str = source_path.to_str().expect("path should be valid UTF-8");
3398 let dataset_columnar = read_parquet_storage(source_path_str, &DatasetReadOptions::new())
3399 .expect("columnar load");
3400 let dir = make_temp_dir();
3401 let path = dir.join("roundtrip_columnar.parquet");
3402 let path_str = path.to_str().expect("path should be valid UTF-8");
3403
3404 write_parquet_storage(&dataset_columnar, path_str, &DatasetWriteOptions::default())
3405 .expect("writing columnar parquet should succeed");
3406 let reopened = read_parquet_storage(path_str, &DatasetReadOptions::new())
3407 .expect("columnar roundtrip reopen");
3408
3409 assert_dataset_columnar_close(&dataset_columnar, &reopened);
3410 fs::remove_dir_all(&dir).expect("temp dir cleanup should succeed");
3411 }
3412
3413 #[test]
3414 fn test_root_storage_matches_parquet_storage() {
3415 let root_path = test_data_path("data_f32.root");
3416 let root_path_str = root_path.to_str().expect("path should be valid UTF-8");
3417 let parquet_path = test_data_path("data_f32.parquet");
3418 let parquet_path_str = parquet_path.to_str().expect("path should be valid UTF-8");
3419
3420 let from_root = read_root_storage(root_path_str, &DatasetReadOptions::new())
3421 .expect("root columnar load should work");
3422 let from_parquet = read_parquet_storage(parquet_path_str, &DatasetReadOptions::new())
3423 .expect("parquet columnar load should work");
3424 assert_dataset_columnar_close(&from_root, &from_parquet);
3425 }
3426
3427 #[test]
3428 fn test_root_storage_repeated_reads_are_stable() {
3429 let root_path = test_data_path("data_f32.root");
3430 let root_path_str = root_path.to_str().expect("path should be valid UTF-8");
3431 let first = read_root_storage(root_path_str, &DatasetReadOptions::new())
3432 .expect("first root columnar load should work");
3433 let second = read_root_storage(root_path_str, &DatasetReadOptions::new())
3434 .expect("second root columnar load should work");
3435 assert_dataset_columnar_close(&first, &second);
3436 }
3437
3438 #[cfg(feature = "mpi")]
3439 #[mpi_test(np = [2])]
3440 fn test_root_storage_reads_rank_local_entry_ranges_under_mpi() {
3441 let root_path = test_data_path("data_f32.root");
3442 let root_path_str = root_path.to_str().expect("path should be valid UTF-8");
3443 let full = read_root_storage(root_path_str, &DatasetReadOptions::new())
3444 .expect("full root columnar load should work");
3445 let total = full.n_events();
3446
3447 use_mpi(true);
3448 let world = get_world().expect("MPI world should be initialized");
3449 let partition = world.partition(total);
3450 let local_range = partition.range_for_rank(world.rank() as usize);
3451
3452 let local = read_root_storage(root_path_str, &DatasetReadOptions::new())
3453 .expect("rank-local root columnar load should work");
3454 assert_eq!(local.n_events(), local_range.len());
3455
3456 for (local_index, global_index) in local_range.clone().enumerate() {
3457 for p4_index in 0..full.metadata().p4_names().len() {
3458 let expected = full.p4(global_index, p4_index);
3459 let actual = local.p4(local_index, p4_index);
3460 assert_relative_eq!(actual.px(), expected.px(), epsilon = 1e-12);
3461 assert_relative_eq!(actual.py(), expected.py(), epsilon = 1e-12);
3462 assert_relative_eq!(actual.pz(), expected.pz(), epsilon = 1e-12);
3463 assert_relative_eq!(actual.e(), expected.e(), epsilon = 1e-12);
3464 }
3465 for aux_index in 0..full.metadata().aux_names().len() {
3466 assert_relative_eq!(
3467 local.aux(local_index, aux_index),
3468 full.aux(global_index, aux_index),
3469 epsilon = 1e-12
3470 );
3471 }
3472 assert_relative_eq!(
3473 local.weight(local_index),
3474 full.weight(global_index),
3475 epsilon = 1e-12
3476 );
3477 }
3478
3479 let local_dataset = local.to_dataset();
3480 assert_eq!(local_dataset.n_events_local(), local_range.len());
3481 assert_eq!(local_dataset.n_events(), total);
3482 finalize_mpi();
3483 }
3484
3485 #[test]
3486 fn test_root_roundtrip_to_tempfile() {
3487 let dataset = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
3488 let dir = make_temp_dir();
3489 let path = dir.join("roundtrip.root");
3490 let path_str = path.to_str().expect("path should be valid UTF-8");
3491
3492 write_root(&dataset, path_str, &DatasetWriteOptions::default())
3493 .expect("writing root should succeed");
3494 let reopened =
3495 read_root(path_str, &DatasetReadOptions::new()).expect("root roundtrip should reopen");
3496
3497 assert_datasets_close(&dataset, &reopened, TEST_P4_NAMES, TEST_AUX_NAMES);
3498 fs::remove_dir_all(&dir).expect("temp dir cleanup should succeed");
3499 }
3500
3501 #[cfg(feature = "mpi")]
3502 #[mpi_test(np = [2])]
3503 fn test_root_roundtrip_to_tempfile_mpi() {
3504 let reference = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
3505 use_mpi(true);
3506 let world = get_world().expect("MPI world should be initialized");
3507 let is_root = world.is_root();
3508
3509 let dataset = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
3510 let path = env::temp_dir().join("laddu_mpi_root_roundtrip.root");
3511 let path_str = path.to_str().expect("path should be valid UTF-8");
3512
3513 if world.is_root() && path.exists() {
3514 fs::remove_file(&path).expect("stale mpi root file cleanup should succeed");
3515 }
3516 world.barrier();
3517
3518 write_root(&dataset, path_str, &DatasetWriteOptions::default())
3519 .expect("writing root with mpi should succeed");
3520 world.barrier();
3521 world.barrier();
3522 finalize_mpi();
3523
3524 if is_root {
3525 let reopened = read_root(path_str, &DatasetReadOptions::new())
3526 .expect("root roundtrip should reopen");
3527 assert_datasets_close(&reference, &reopened, TEST_P4_NAMES, TEST_AUX_NAMES);
3528 if path.exists() {
3529 fs::remove_file(&path).expect("mpi root roundtrip cleanup should succeed");
3530 }
3531 }
3532 }
3533
3534 #[cfg(feature = "mpi")]
3535 #[mpi_test(np = [2])]
3536 fn test_root_output_is_deterministic_under_mpi() {
3537 use_mpi(true);
3538 let world = get_world().expect("MPI world should be initialized");
3539
3540 let dataset = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
3541 let first_path = env::temp_dir().join("laddu_mpi_root_determinism_first.root");
3542 let second_path = env::temp_dir().join("laddu_mpi_root_determinism_second.root");
3543 let first_path_str = first_path.to_str().expect("path should be valid UTF-8");
3544 let second_path_str = second_path.to_str().expect("path should be valid UTF-8");
3545
3546 if world.is_root() {
3547 for path in [&first_path, &second_path] {
3548 if path.exists() {
3549 fs::remove_file(path).expect("stale mpi root file cleanup should succeed");
3550 }
3551 }
3552 }
3553 world.barrier();
3554
3555 write_root(&dataset, first_path_str, &DatasetWriteOptions::default())
3556 .expect("first mpi root write should succeed");
3557 world.barrier();
3558 write_root(&dataset, second_path_str, &DatasetWriteOptions::default())
3559 .expect("second mpi root write should succeed");
3560 world.barrier();
3561
3562 let first = read_root_storage(first_path_str, &DatasetReadOptions::new())
3563 .expect("first mpi root output should reopen");
3564 let second = read_root_storage(second_path_str, &DatasetReadOptions::new())
3565 .expect("second mpi root output should reopen");
3566 assert_dataset_columnar_close(&first, &second);
3567
3568 world.barrier();
3569 if world.is_root() {
3570 for path in [&first_path, &second_path] {
3571 if path.exists() {
3572 fs::remove_file(path).expect("mpi root determinism cleanup should succeed");
3573 }
3574 }
3575 }
3576 finalize_mpi();
3577 }
3578
3579 #[cfg(feature = "mpi")]
3580 #[mpi_test(np = [2])]
3581 fn test_root_output_matches_between_mpi_and_non_mpi_writes() {
3582 let cpu_dataset = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
3583 let mpi_path = env::temp_dir().join("laddu_root_mpi_reference.root");
3584 let mpi_path_str = mpi_path.to_str().expect("path should be valid UTF-8");
3585
3586 use_mpi(true);
3587 let world = get_world().expect("MPI world should be initialized");
3588 let is_root = world.is_root();
3589 let mpi_dataset = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
3590
3591 if is_root && mpi_path.exists() {
3592 fs::remove_file(&mpi_path).expect("stale root file cleanup should succeed");
3593 }
3594 world.barrier();
3595 write_root(&mpi_dataset, mpi_path_str, &DatasetWriteOptions::default())
3596 .expect("mpi root write should succeed");
3597 world.barrier();
3598 world.barrier();
3599 finalize_mpi();
3600
3601 if is_root {
3602 let cpu_dir = make_temp_dir();
3603 let cpu_path = cpu_dir.join("laddu_root_cpu_reference.root");
3604 let cpu_path_str = cpu_path.to_str().expect("path should be valid UTF-8");
3605 write_root(&cpu_dataset, cpu_path_str, &DatasetWriteOptions::default())
3606 .expect("non-mpi root write should succeed");
3607
3608 let cpu_output = read_root_storage(cpu_path_str, &DatasetReadOptions::new())
3609 .expect("non-mpi root output should reopen");
3610 let mpi_output = read_root_storage(mpi_path_str, &DatasetReadOptions::new())
3611 .expect("mpi root output should reopen");
3612 assert_dataset_columnar_close(&cpu_output, &mpi_output);
3613
3614 fs::remove_dir_all(&cpu_dir).expect("root comparison temp dir cleanup should succeed");
3615 if mpi_path.exists() {
3616 fs::remove_file(&mpi_path).expect("root comparison cleanup should succeed");
3617 }
3618 }
3619 }
3620
3621 #[test]
3622 fn test_root_local_column_buffers_match_columnar_storage() {
3623 let dataset = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
3624 let buffers = io::build_root_local_column_buffers::<f64>(&dataset.columnar);
3625 let expected_names = dataset
3626 .p4_names()
3627 .iter()
3628 .flat_map(|name| {
3629 io::P4_COMPONENT_SUFFIXES
3630 .iter()
3631 .map(move |suffix| format!("{name}{suffix}"))
3632 })
3633 .chain(dataset.aux_names().iter().cloned())
3634 .chain(std::iter::once("weight".to_string()))
3635 .collect::<Vec<_>>();
3636 let expected_values = dataset
3637 .columnar
3638 .p4
3639 .iter()
3640 .flat_map(|p4| [p4.px.clone(), p4.py.clone(), p4.pz.clone(), p4.e.clone()])
3641 .chain(dataset.columnar.aux.clone())
3642 .chain(std::iter::once(dataset.columnar.weights.clone()))
3643 .collect::<Vec<_>>();
3644 assert_eq!(
3645 buffers
3646 .iter()
3647 .map(|(name, _)| name.as_str())
3648 .collect::<Vec<_>>(),
3649 expected_names
3650 );
3651 assert_eq!(
3652 buffers
3653 .into_iter()
3654 .map(|(_, values)| values)
3655 .collect::<Vec<_>>(),
3656 expected_values
3657 );
3658 }
3659
3660 #[test]
3661 fn test_root_local_column_buffers_convert_precision() {
3662 let dataset = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
3663 let buffers = io::build_root_local_column_buffers::<f32>(&dataset.columnar);
3664 let expected_values = dataset
3665 .columnar
3666 .p4
3667 .iter()
3668 .flat_map(|p4| {
3669 [
3670 p4.px.iter().map(|value| *value as f32).collect::<Vec<_>>(),
3671 p4.py.iter().map(|value| *value as f32).collect::<Vec<_>>(),
3672 p4.pz.iter().map(|value| *value as f32).collect::<Vec<_>>(),
3673 p4.e.iter().map(|value| *value as f32).collect::<Vec<_>>(),
3674 ]
3675 })
3676 .chain(
3677 dataset
3678 .columnar
3679 .aux
3680 .iter()
3681 .map(|aux| aux.iter().map(|value| *value as f32).collect::<Vec<_>>()),
3682 )
3683 .chain(std::iter::once(
3684 dataset
3685 .columnar
3686 .weights
3687 .iter()
3688 .map(|value| *value as f32)
3689 .collect::<Vec<_>>(),
3690 ))
3691 .collect::<Vec<_>>();
3692
3693 assert_eq!(
3694 buffers
3695 .into_iter()
3696 .map(|(_, values)| values)
3697 .collect::<Vec<_>>(),
3698 expected_values
3699 );
3700 }
3701
3702 #[test]
3703 fn test_parquet_chunk_iterator_matches_full_read() {
3704 let path = test_data_path("data_f32.parquet");
3705 let path_str = path.to_str().expect("path should be valid UTF-8");
3706 let options = DatasetReadOptions::new();
3707 let full = read_parquet(path_str, &options).expect("full parquet read should work");
3708 let chunks =
3709 read_parquet_chunks(path_str, &options, 17).expect("chunk iterator should open");
3710
3711 let mut global_idx = 0usize;
3712 for chunk in chunks {
3713 let chunk = chunk.expect("chunk read should succeed");
3714 for local_idx in 0..chunk.n_events_local() {
3715 let left = full
3716 .event(global_idx)
3717 .expect("full dataset event should exist");
3718 let right = chunk
3719 .event(local_idx)
3720 .expect("chunk dataset event should exist");
3721 assert_events_close(&left, &right, TEST_P4_NAMES, TEST_AUX_NAMES);
3722 global_idx += 1;
3723 }
3724 }
3725
3726 assert_eq!(global_idx, full.n_events());
3727 }
3728
3729 #[cfg(feature = "mpi")]
3730 #[mpi_test(np = [2])]
3731 fn test_parquet_chunk_iterator_respects_mpi_partition() {
3732 let path = test_data_path("data_f32.parquet");
3733 let path_str = path.to_str().expect("path should be valid UTF-8");
3734 let options = DatasetReadOptions::new();
3735 let reference =
3736 read_parquet(path_str, &options).expect("reference parquet read should work");
3737
3738 use_mpi(true);
3739 let world = get_world().expect("MPI world should be initialized");
3740 let partition = world.partition(reference.n_events());
3741 let local_range = partition.range_for_rank(world.rank() as usize);
3742 let chunks =
3743 read_parquet_chunks(path_str, &options, 17).expect("chunk iterator should open");
3744
3745 let mut local_idx = 0usize;
3746 for chunk in chunks {
3747 let chunk = chunk.expect("chunk read should succeed");
3748 assert!(chunk.n_events_local() <= 17);
3749 for chunk_idx in 0..chunk.n_events_local() {
3750 let expected = reference
3751 .event(local_range.start + local_idx)
3752 .expect("reference event should exist");
3753 let actual = chunk.event(chunk_idx).expect("chunk event should exist");
3754 assert_events_close(&expected, &actual, TEST_P4_NAMES, TEST_AUX_NAMES);
3755 local_idx += 1;
3756 }
3757 }
3758
3759 assert_eq!(local_idx, local_range.len());
3760 let mut gathered_counts = vec![0usize; world.size() as usize];
3761 world.all_gather_into(&local_idx, &mut gathered_counts);
3762 assert_eq!(
3763 gathered_counts.into_iter().sum::<usize>(),
3764 reference.n_events()
3765 );
3766 finalize_mpi();
3767 }
3768
3769 #[test]
3770 fn test_parquet_chunk_iterator_with_options_chunk_size_one() {
3771 let path = test_data_path("data_f32.parquet");
3772 let path_str = path.to_str().expect("path should be valid UTF-8");
3773 let options = DatasetReadOptions::new().chunk_size(1);
3774 let full = read_parquet(path_str, &DatasetReadOptions::new())
3775 .expect("full parquet read should work");
3776 let chunks = read_parquet_chunks_with_options(path_str, &options)
3777 .expect("chunk iterator should open");
3778 let mut event_count = 0usize;
3779 let mut chunk_count = 0usize;
3780
3781 for chunk in chunks {
3782 let chunk = chunk.expect("chunk read should succeed");
3783 chunk_count += 1;
3784 assert_eq!(chunk.n_events_local(), 1);
3785 event_count += chunk.n_events_local();
3786 }
3787
3788 assert_eq!(event_count, full.n_events());
3789 assert_eq!(chunk_count, full.n_events());
3790 }
3791
3792 #[test]
3793 fn test_parquet_chunk_iterator_with_options_large_chunk_size() {
3794 let path = test_data_path("data_f32.parquet");
3795 let path_str = path.to_str().expect("path should be valid UTF-8");
3796 let full = read_parquet(path_str, &DatasetReadOptions::new())
3797 .expect("full parquet read should work");
3798 let options = DatasetReadOptions::new().chunk_size(full.n_events() + 100);
3799 let chunks = read_parquet_chunks_with_options(path_str, &options)
3800 .expect("chunk iterator should open");
3801 let chunk_vec = chunks
3802 .collect::<LadduResult<Vec<_>>>()
3803 .expect("all chunk reads should succeed");
3804
3805 assert_eq!(chunk_vec.len(), 1);
3806 assert_eq!(chunk_vec[0].n_events_local(), full.n_events());
3807 }
3808
3809 #[test]
3810 fn test_dataset_chunk_builder_matches_full_parquet_read() {
3811 let path = test_data_path("data_f32.parquet");
3812 let path_str = path.to_str().expect("path should be valid UTF-8");
3813 let options = DatasetReadOptions::new().chunk_size(13);
3814 let full = read_parquet(path_str, &DatasetReadOptions::new())
3815 .expect("full parquet read should work");
3816 let chunks = read_parquet_chunks_with_options(path_str, &options)
3817 .expect("chunk iterator should open");
3818
3819 let mut builder = DatasetChunkBuilder::new();
3820 for chunk in chunks {
3821 let chunk = chunk.expect("chunk read should succeed");
3822 builder.push_chunk(&chunk).expect("chunk should append");
3823 }
3824 let rebuilt = builder.finish();
3825
3826 assert_datasets_close(&full, &rebuilt, TEST_P4_NAMES, TEST_AUX_NAMES);
3827 }
3828
3829 #[test]
3830 fn test_try_fold_dataset_chunks_matches_full_weight_sum() {
3831 let path = test_data_path("data_f32.parquet");
3832 let path_str = path.to_str().expect("path should be valid UTF-8");
3833 let full = read_parquet(path_str, &DatasetReadOptions::new())
3834 .expect("full parquet read should work");
3835 let chunks = read_parquet_chunks(path_str, &DatasetReadOptions::new(), 11)
3836 .expect("chunk iterator should open");
3837
3838 let folded = try_fold_dataset_chunks(chunks, 0.0_f64, |acc, chunk| {
3839 Ok(acc + chunk.n_events_weighted_local())
3840 })
3841 .expect("chunk fold should succeed");
3842
3843 assert_relative_eq!(folded, full.n_events_weighted_local(), epsilon = 1e-9);
3844 }
3845}