use accurate::{sum::Klein, traits::*};
use arrow::{
    array::{Float32Array, Float64Array},
    datatypes::{DataType, Field, Schema},
    record_batch::RecordBatch,
};
use auto_ops::impl_op_ex;
use parking_lot::Mutex;
use parquet::arrow::{arrow_reader::ParquetRecordBatchReaderBuilder, ArrowWriter};
#[cfg(feature = "mpi")]
use parquet::file::metadata::ParquetMetaData;
use serde::{Deserialize, Serialize};
use std::ops::{Deref, DerefMut, Index, IndexMut};
use std::path::Path;
use std::{fmt::Display, fs::File};
use std::{path::PathBuf, sync::Arc};

use oxyroot::{Branch, Named, ReaderTree, RootFile, WriterTree};

#[cfg(feature = "rayon")]
use rayon::prelude::*;

#[cfg(feature = "mpi")]
use mpi::{datatype::PartitionMut, topology::SimpleCommunicator, traits::*};

#[cfg(feature = "mpi")]
use crate::mpi::LadduMPI;

#[cfg(feature = "mpi")]
type WorldHandle = SimpleCommunicator;
#[cfg(not(feature = "mpi"))]
type WorldHandle = ();

use crate::utils::get_bin_edges;
use crate::{
    utils::{
        variables::{IntoP4Selection, P4Selection, Variable, VariableExpression},
        vectors::Vec4,
    },
    LadduError, LadduResult,
};
use indexmap::{IndexMap, IndexSet};

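/// Returns a single hand-written [`EventData`] used in unit tests: a beam photon, a
/// recoil proton, and two neutral kaons, plus two auxiliary polarization values. A
/// minimal usage sketch (marked `ignore` since paths depend on the crate layout):
///
/// ```ignore
/// let event = test_event();
/// assert_eq!(event.p4s.len(), 4);
/// assert_eq!(event.aux.len(), 2);
/// ```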
pub fn test_event() -> EventData {
    use crate::utils::vectors::*;
    let pol_magnitude = 0.38562805;
    let pol_angle = 0.05708078;
    EventData {
        p4s: vec![
            Vec3::new(0.0, 0.0, 8.747).with_mass(0.0),
            Vec3::new(0.119, 0.374, 0.222).with_mass(1.007),
            Vec3::new(-0.112, 0.293, 3.081).with_mass(0.498),
            Vec3::new(-0.007, -0.667, 5.446).with_mass(0.498),
        ],
        aux: vec![pol_magnitude, pol_angle],
        weight: 0.48,
    }
}

pub const TEST_P4_NAMES: &[&str] = &["beam", "proton", "kshort1", "kshort2"];
pub const TEST_AUX_NAMES: &[&str] = &["pol_magnitude", "pol_angle"];

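/// Returns a one-event [`Dataset`] built from [`test_event`], with the standard test
/// particle and auxiliary names bound as metadata.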
pub fn test_dataset() -> Dataset {
    let metadata = Arc::new(
        DatasetMetadata::new(
            TEST_P4_NAMES.iter().map(|s| (*s).to_string()).collect(),
            TEST_AUX_NAMES.iter().map(|s| (*s).to_string()).collect(),
        )
        .expect("Test metadata should be valid"),
    );
    Dataset::new_with_metadata(vec![Arc::new(test_event())], metadata)
}

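/// A single event: four-momenta, auxiliary scalars, and a weight. The position of each
/// entry in `p4s`/`aux` only acquires a name through [`DatasetMetadata`]. A construction
/// sketch (field values illustrative):
///
/// ```ignore
/// let event = EventData {
///     p4s: vec![Vec4::new(0.0, 0.0, 8.7, 8.7)],
///     aux: vec![0.5],
///     weight: 1.0,
/// };
/// ```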
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct EventData {
    pub p4s: Vec<Vec4>,
    pub aux: Vec<f64>,
    pub weight: f64,
}

impl Display for EventData {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Event:")?;
        writeln!(f, "  p4s:")?;
        for p4 in &self.p4s {
            writeln!(f, "    {}", p4.to_p4_string())?;
        }
        writeln!(f, "  aux:")?;
        for (idx, value) in self.aux.iter().enumerate() {
            writeln!(f, "    aux[{idx}]: {value}")?;
        }
        writeln!(f, "  weight:")?;
        writeln!(f, "    {}", self.weight)?;
        Ok(())
    }
}

impl EventData {
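    /// Sums the four-momenta at the given `indices`. Sketch:
    ///
    /// ```ignore
    /// let event = test_event();
    /// let resonance = event.get_p4_sum([2, 3]); // kshort1 + kshort2
    /// ```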
    pub fn get_p4_sum<T: AsRef<[usize]>>(&self, indices: T) -> Vec4 {
        indices.as_ref().iter().map(|i| self.p4s[*i]).sum::<Vec4>()
    }
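
    /// Returns a copy of the event with every four-momentum boosted into the rest frame
    /// of the summed momenta at `indices`; `aux` and `weight` are carried over unchanged.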
    pub fn boost_to_rest_frame_of<T: AsRef<[usize]>>(&self, indices: T) -> Self {
        let frame = self.get_p4_sum(indices);
        EventData {
            p4s: self
                .p4s
                .iter()
                .map(|p4| p4.boost(&(-frame.beta())))
                .collect(),
            aux: self.aux.clone(),
            weight: self.weight,
        }
    }

    pub fn evaluate<V: Variable>(&self, variable: &V) -> f64 {
        variable.value(self)
    }
}

#[derive(Debug, Clone)]
pub struct DatasetMetadata {
    pub(crate) p4_names: Vec<String>,
    pub(crate) aux_names: Vec<String>,
    pub(crate) p4_lookup: IndexMap<String, usize>,
    pub(crate) aux_lookup: IndexMap<String, usize>,
    pub(crate) p4_selections: IndexMap<String, P4Selection>,
}

impl DatasetMetadata {
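    /// Builds metadata from ordered particle and auxiliary names, rejecting duplicates
    /// within each list. A minimal sketch:
    ///
    /// ```ignore
    /// let meta = DatasetMetadata::new(vec!["beam", "proton"], vec!["pol_angle"])?;
    /// assert_eq!(meta.p4_index("proton"), Some(1));
    /// ```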
    pub fn new<P: Into<String>, A: Into<String>>(
        p4_names: Vec<P>,
        aux_names: Vec<A>,
    ) -> LadduResult<Self> {
        let mut p4_lookup = IndexMap::with_capacity(p4_names.len());
        let mut aux_lookup = IndexMap::with_capacity(aux_names.len());
        let mut p4_selections = IndexMap::with_capacity(p4_names.len());
        let p4_names: Vec<String> = p4_names
            .into_iter()
            .enumerate()
            .map(|(idx, name)| {
                let name = name.into();
                if p4_lookup.contains_key(&name) {
                    return Err(LadduError::DuplicateName {
                        category: "p4",
                        name,
                    });
                }
                p4_lookup.insert(name.clone(), idx);
                p4_selections.insert(
                    name.clone(),
                    P4Selection::with_indices(vec![name.clone()], vec![idx]),
                );
                Ok(name)
            })
            .collect::<Result<_, _>>()?;
        let aux_names: Vec<String> = aux_names
            .into_iter()
            .enumerate()
            .map(|(idx, name)| {
                let name = name.into();
                if aux_lookup.contains_key(&name) {
                    return Err(LadduError::DuplicateName {
                        category: "aux",
                        name,
                    });
                }
                aux_lookup.insert(name.clone(), idx);
                Ok(name)
            })
            .collect::<Result<_, _>>()?;
        Ok(Self {
            p4_names,
            aux_names,
            p4_lookup,
            aux_lookup,
            p4_selections,
        })
    }

    pub fn empty() -> Self {
        Self {
            p4_names: Vec::new(),
            aux_names: Vec::new(),
            p4_lookup: IndexMap::new(),
            aux_lookup: IndexMap::new(),
            p4_selections: IndexMap::new(),
        }
    }

    pub fn p4_index(&self, name: &str) -> Option<usize> {
        self.p4_lookup.get(name).copied()
    }

    pub fn p4_names(&self) -> &[String] {
        &self.p4_names
    }

    pub fn aux_index(&self, name: &str) -> Option<usize> {
        self.aux_lookup.get(name).copied()
    }

    pub fn aux_names(&self) -> &[String] {
        &self.aux_names
    }

    pub fn p4_selection(&self, name: &str) -> Option<&P4Selection> {
        self.p4_selections.get(name)
    }

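    /// Registers `alias` as a named selection over existing particles; the selection is
    /// bound against this metadata so unknown names fail early. A hedged sketch using
    /// the `with_indices` constructor seen above:
    ///
    /// ```ignore
    /// let pair = P4Selection::with_indices(
    ///     vec!["kshort1".to_string(), "kshort2".to_string()],
    ///     vec![2, 3],
    /// );
    /// metadata.add_p4_alias("kshort_pair", pair)?;
    /// ```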
    pub fn add_p4_alias<N>(&mut self, alias: N, mut selection: P4Selection) -> LadduResult<()>
    where
        N: Into<String>,
    {
        let alias = alias.into();
        if self.p4_selections.contains_key(&alias) {
            return Err(LadduError::DuplicateName {
                category: "alias",
                name: alias,
            });
        }
        selection.bind(self)?;
        self.p4_selections.insert(alias, selection);
        Ok(())
    }

    pub fn add_p4_aliases<I, N>(&mut self, entries: I) -> LadduResult<()>
    where
        I: IntoIterator<Item = (N, P4Selection)>,
        N: Into<String>,
    {
        for (alias, selection) in entries {
            self.add_p4_alias(alias, selection)?;
        }
        Ok(())
    }

    pub(crate) fn append_indices_for_name(
        &self,
        name: &str,
        target: &mut Vec<usize>,
    ) -> LadduResult<()> {
        if let Some(selection) = self.p4_selections.get(name) {
            target.extend_from_slice(selection.indices());
            return Ok(());
        }
        Err(LadduError::UnknownName {
            category: "p4",
            name: name.to_string(),
        })
    }
}

impl Default for DatasetMetadata {
    fn default() -> Self {
        Self::empty()
    }
}

#[derive(Debug, Clone)]
pub struct Dataset {
    pub events: Vec<Event>,
    pub(crate) metadata: Arc<DatasetMetadata>,
}

#[derive(Clone, Debug)]
pub struct Event {
    event: Arc<EventData>,
    metadata: Arc<DatasetMetadata>,
}

impl Event {
    pub fn new(event: Arc<EventData>, metadata: Arc<DatasetMetadata>) -> Self {
        Self { event, metadata }
    }

    pub fn data(&self) -> &EventData {
        &self.event
    }

    pub fn data_arc(&self) -> Arc<EventData> {
        self.event.clone()
    }

    pub fn p4s(&self) -> IndexMap<&str, Vec4> {
        let mut map = IndexMap::with_capacity(self.metadata.p4_names.len());
        for (idx, name) in self.metadata.p4_names.iter().enumerate() {
            if let Some(p4) = self.event.p4s.get(idx) {
                map.insert(name.as_str(), *p4);
            }
        }
        map
    }

    pub fn aux(&self) -> IndexMap<&str, f64> {
        let mut map = IndexMap::with_capacity(self.metadata.aux_names.len());
        for (idx, name) in self.metadata.aux_names.iter().enumerate() {
            if let Some(value) = self.event.aux.get(idx) {
                map.insert(name.as_str(), *value);
            }
        }
        map
    }

    pub fn weight(&self) -> f64 {
        self.event.weight
    }

    pub fn metadata(&self) -> &DatasetMetadata {
        &self.metadata
    }

    pub fn metadata_arc(&self) -> Arc<DatasetMetadata> {
        self.metadata.clone()
    }

    pub fn p4(&self, name: &str) -> Option<Vec4> {
        self.metadata
            .p4_selection(name)
            .map(|selection| selection.momentum(&self.event))
    }

    fn resolve_p4_indices<N>(&self, names: N) -> Vec<usize>
    where
        N: IntoIterator,
        N::Item: AsRef<str>,
    {
        let mut indices = Vec::new();
        for name in names {
            let name_ref = name.as_ref();
            if let Some(selection) = self.metadata.p4_selection(name_ref) {
                indices.extend_from_slice(selection.indices());
            } else {
                panic!("Unknown particle name '{name}'", name = name_ref);
            }
        }
        indices
    }

    pub fn get_p4_sum<N>(&self, names: N) -> Vec4
    where
        N: IntoIterator,
        N::Item: AsRef<str>,
    {
        let indices = self.resolve_p4_indices(names);
        self.event.get_p4_sum(&indices)
    }

    pub fn boost_to_rest_frame_of<N>(&self, names: N) -> EventData
    where
        N: IntoIterator,
        N::Item: AsRef<str>,
    {
        let indices = self.resolve_p4_indices(names);
        self.event.boost_to_rest_frame_of(&indices)
    }

    pub fn evaluate<V: Variable>(&self, variable: &V) -> f64 {
        self.event.evaluate(variable)
    }
}

impl Deref for Event {
    type Target = EventData;

    fn deref(&self) -> &Self::Target {
        &self.event
    }
}

impl AsRef<EventData> for Event {
    fn as_ref(&self) -> &EventData {
        self.data()
    }
}

impl IntoIterator for Dataset {
    type Item = Event;

    type IntoIter = DatasetIntoIter;

    fn into_iter(self) -> Self::IntoIter {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                let total = self.n_events();
                return DatasetIntoIter::Mpi(DatasetMpiIntoIter {
                    events: self.events,
                    metadata: self.metadata,
                    world,
                    index: 0,
                    total,
                });
            }
        }
        DatasetIntoIter::Local(self.events.into_iter())
    }
}

impl Dataset {
    pub fn iter(&self) -> DatasetIter<'_> {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return DatasetIter::Mpi(DatasetMpiIter {
                    dataset: self,
                    world,
                    index: 0,
                    total: self.n_events(),
                });
            }
        }
        DatasetIter::Local(self.events.iter())
    }

    pub fn metadata(&self) -> &DatasetMetadata {
        &self.metadata
    }

    pub fn metadata_arc(&self) -> Arc<DatasetMetadata> {
        self.metadata.clone()
    }

    pub fn p4_names(&self) -> &[String] {
        &self.metadata.p4_names
    }

    pub fn aux_names(&self) -> &[String] {
        &self.metadata.aux_names
    }

    pub fn p4_index(&self, name: &str) -> Option<usize> {
        self.metadata.p4_index(name)
    }

    pub fn aux_index(&self, name: &str) -> Option<usize> {
        self.metadata.aux_index(name)
    }

    pub fn named_event(&self, index: usize) -> Event {
        self.events[index].clone()
    }

    pub fn p4_by_name(&self, event_index: usize, name: &str) -> Option<Vec4> {
        self.events
            .get(event_index)
            .and_then(|event| event.p4(name))
    }

    pub fn aux_by_name(&self, event_index: usize, name: &str) -> Option<f64> {
        let idx = self.aux_index(name)?;
        self.events
            .get(event_index)
            .and_then(|event| event.aux.get(idx))
            .copied()
    }

    pub fn index_local(&self, index: usize) -> &Event {
        &self.events[index]
    }

    #[cfg(feature = "mpi")]
    fn partition(
        events: Vec<Arc<EventData>>,
        world: &SimpleCommunicator,
    ) -> Vec<Vec<Arc<EventData>>> {
        let partition = world.partition(events.len());
        (0..partition.n_ranks())
            .map(|rank| {
                let range = partition.range_for_rank(rank);
                events[range.clone()].iter().cloned().collect()
            })
            .collect()
    }

    #[cfg(feature = "mpi")]
    pub fn index_mpi(&self, index: usize, world: &SimpleCommunicator) -> &Event {
        let total = self.n_events();
        let event = fetch_event_mpi(self, index, world, total);
        // `Index` must return a reference, but the event may live on another rank, so the
        // fetched copy is intentionally leaked to produce a `'static` reference. Prefer
        // `iter()` over repeated indexing in MPI mode to avoid accumulating leaked events.
        Box::leak(Box::new(event))
    }
}

impl Index<usize> for Dataset {
    type Output = Event;

    fn index(&self, index: usize) -> &Self::Output {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.index_mpi(index, &world);
            }
        }
        self.index_local(index)
    }
}

pub enum DatasetIter<'a> {
    Local(std::slice::Iter<'a, Event>),
    #[cfg(feature = "mpi")]
    Mpi(DatasetMpiIter<'a>),
}

impl<'a> Iterator for DatasetIter<'a> {
    type Item = Event;

    fn next(&mut self) -> Option<Self::Item> {
        match self {
            DatasetIter::Local(iter) => iter.next().cloned(),
            #[cfg(feature = "mpi")]
            DatasetIter::Mpi(iter) => iter.next(),
        }
    }
}

pub enum DatasetIntoIter {
    Local(std::vec::IntoIter<Event>),
    #[cfg(feature = "mpi")]
    Mpi(DatasetMpiIntoIter),
}

impl Iterator for DatasetIntoIter {
    type Item = Event;

    fn next(&mut self) -> Option<Self::Item> {
        match self {
            DatasetIntoIter::Local(iter) => iter.next(),
            #[cfg(feature = "mpi")]
            DatasetIntoIter::Mpi(iter) => iter.next(),
        }
    }
}

#[cfg(feature = "mpi")]
pub struct DatasetMpiIter<'a> {
    dataset: &'a Dataset,
    world: SimpleCommunicator,
    index: usize,
    total: usize,
}

#[cfg(feature = "mpi")]
impl<'a> Iterator for DatasetMpiIter<'a> {
    type Item = Event;

    fn next(&mut self) -> Option<Self::Item> {
        if self.index >= self.total {
            return None;
        }
        let event = fetch_event_mpi(self.dataset, self.index, &self.world, self.total);
        self.index += 1;
        Some(event)
    }
}

#[cfg(feature = "mpi")]
pub struct DatasetMpiIntoIter {
    events: Vec<Event>,
    metadata: Arc<DatasetMetadata>,
    world: SimpleCommunicator,
    index: usize,
    total: usize,
}

#[cfg(feature = "mpi")]
impl Iterator for DatasetMpiIntoIter {
    type Item = Event;

    fn next(&mut self) -> Option<Self::Item> {
        if self.index >= self.total {
            return None;
        }
        let event = fetch_event_mpi_from_events(
            &self.events,
            &self.metadata,
            self.index,
            &self.world,
            self.total,
        );
        self.index += 1;
        Some(event)
    }
}

#[cfg(feature = "mpi")]
fn fetch_event_mpi(
    dataset: &Dataset,
    global_index: usize,
    world: &SimpleCommunicator,
    total: usize,
) -> Event {
    fetch_event_mpi_generic(
        global_index,
        total,
        world,
        &dataset.metadata,
        |local_index| dataset.index_local(local_index),
    )
}

#[cfg(feature = "mpi")]
fn fetch_event_mpi_from_events(
    events: &[Event],
    metadata: &Arc<DatasetMetadata>,
    global_index: usize,
    world: &SimpleCommunicator,
    total: usize,
) -> Event {
    fetch_event_mpi_generic(global_index, total, world, metadata, |local_index| {
        &events[local_index]
    })
}

/// Collectively fetches the event at `global_index` on every rank: the owning rank
/// serializes its local copy with `bincode`, broadcasts the buffer length and then the
/// buffer itself, and all other ranks deserialize the event from the received bytes.
#[cfg(feature = "mpi")]
fn fetch_event_mpi_generic<'a, F>(
    global_index: usize,
    total: usize,
    world: &SimpleCommunicator,
    metadata: &Arc<DatasetMetadata>,
    local_event: F,
) -> Event
where
    F: Fn(usize) -> &'a Event,
{
    let (owning_rank, local_index) = world.owner_of_global_index(global_index, total);
    let mut serialized_event_buffer_len: usize = 0;
    let mut serialized_event_buffer: Vec<u8> = Vec::default();
    let config = bincode::config::standard();
    if world.rank() == owning_rank {
        let event = local_event(local_index);
        serialized_event_buffer = bincode::serde::encode_to_vec(event.data(), config).unwrap();
        serialized_event_buffer_len = serialized_event_buffer.len();
    }
    world
        .process_at_rank(owning_rank)
        .broadcast_into(&mut serialized_event_buffer_len);
    if world.rank() != owning_rank {
        serialized_event_buffer = vec![0; serialized_event_buffer_len];
    }
    world
        .process_at_rank(owning_rank)
        .broadcast_into(&mut serialized_event_buffer);

    if world.rank() == owning_rank {
        local_event(local_index).clone()
    } else {
        let (event, _): (EventData, usize) =
            bincode::serde::decode_from_slice(&serialized_event_buffer[..], config).unwrap();
        Event::new(Arc::new(event), metadata.clone())
    }
}

impl Dataset {
    pub fn new_local(events: Vec<Arc<EventData>>, metadata: Arc<DatasetMetadata>) -> Self {
        let wrapped_events = events
            .into_iter()
            .map(|event| Event::new(event, metadata.clone()))
            .collect();
        Dataset {
            events: wrapped_events,
            metadata,
        }
    }

    #[cfg(feature = "mpi")]
    pub fn new_mpi(
        events: Vec<Arc<EventData>>,
        metadata: Arc<DatasetMetadata>,
        world: &SimpleCommunicator,
    ) -> Self {
        let partitions = Dataset::partition(events, world);
        let local = partitions[world.rank() as usize]
            .iter()
            .cloned()
            .map(|event| Event::new(event, metadata.clone()))
            .collect();
        Dataset {
            events: local,
            metadata,
        }
    }

    pub fn new(events: Vec<Arc<EventData>>) -> Self {
        Dataset::new_with_metadata(events, Arc::new(DatasetMetadata::default()))
    }

    pub fn new_with_metadata(events: Vec<Arc<EventData>>, metadata: Arc<DatasetMetadata>) -> Self {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return Dataset::new_mpi(events, metadata, &world);
            }
        }
        Dataset::new_local(events, metadata)
    }

    pub fn n_events_local(&self) -> usize {
        self.events.len()
    }

    #[cfg(feature = "mpi")]
    pub fn n_events_mpi(&self, world: &SimpleCommunicator) -> usize {
        let mut n_events_partitioned: Vec<usize> = vec![0; world.size() as usize];
        let n_events_local = self.n_events_local();
        world.all_gather_into(&n_events_local, &mut n_events_partitioned);
        n_events_partitioned.iter().sum()
    }

    pub fn n_events(&self) -> usize {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.n_events_mpi(&world);
            }
        }
        self.n_events_local()
    }
}

impl Dataset {
    pub fn weights_local(&self) -> Vec<f64> {
        #[cfg(feature = "rayon")]
        return self.events.par_iter().map(|e| e.weight).collect();
        #[cfg(not(feature = "rayon"))]
        return self.events.iter().map(|e| e.weight).collect();
    }

    #[cfg(feature = "mpi")]
    pub fn weights_mpi(&self, world: &SimpleCommunicator) -> Vec<f64> {
        let local_weights = self.weights_local();
        let n_events = self.n_events();
        let mut buffer: Vec<f64> = vec![0.0; n_events];
        let (counts, displs) = world.get_counts_displs(n_events);
        {
            let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
            world.all_gather_varcount_into(&local_weights, &mut partitioned_buffer);
        }
        buffer
    }

    pub fn weights(&self) -> Vec<f64> {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.weights_mpi(&world);
            }
        }
        self.weights_local()
    }

    pub fn n_events_weighted_local(&self) -> f64 {
        #[cfg(feature = "rayon")]
        return self
            .events
            .par_iter()
            .map(|e| e.weight)
            .parallel_sum_with_accumulator::<Klein<f64>>();
        #[cfg(not(feature = "rayon"))]
        return self.events.iter().map(|e| e.weight).sum();
    }

    #[cfg(feature = "mpi")]
    pub fn n_events_weighted_mpi(&self, world: &SimpleCommunicator) -> f64 {
        let mut n_events_weighted_partitioned: Vec<f64> = vec![0.0; world.size() as usize];
        let n_events_weighted_local = self.n_events_weighted_local();
        world.all_gather_into(&n_events_weighted_local, &mut n_events_weighted_partitioned);
        #[cfg(feature = "rayon")]
        return n_events_weighted_partitioned
            .into_par_iter()
            .parallel_sum_with_accumulator::<Klein<f64>>();
        #[cfg(not(feature = "rayon"))]
        return n_events_weighted_partitioned.iter().sum();
    }

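    /// Returns the sum of event weights over the whole (possibly MPI-distributed)
    /// dataset. With the `rayon` feature enabled, the sum uses a compensated Klein
    /// accumulator, which keeps large, mixed-sign weight totals numerically stable.
    ///
    /// ```ignore
    /// let effective_yield = dataset.n_events_weighted();
    /// ```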
    pub fn n_events_weighted(&self) -> f64 {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.n_events_weighted_mpi(&world);
            }
        }
        self.n_events_weighted_local()
    }

    pub fn bootstrap_local(&self, seed: usize) -> Arc<Dataset> {
        let mut rng = fastrand::Rng::with_seed(seed as u64);
        let mut indices: Vec<usize> = (0..self.n_events())
            .map(|_| rng.usize(0..self.n_events()))
            .collect::<Vec<usize>>();
        indices.sort();
        #[cfg(feature = "rayon")]
        let bootstrapped_events: Vec<Arc<EventData>> = indices
            .into_par_iter()
            .map(|idx| self.events[idx].data_arc())
            .collect();
        #[cfg(not(feature = "rayon"))]
        let bootstrapped_events: Vec<Arc<EventData>> = indices
            .into_iter()
            .map(|idx| self.events[idx].data_arc())
            .collect();
        Arc::new(Dataset::new_with_metadata(
            bootstrapped_events,
            self.metadata.clone(),
        ))
    }

    #[cfg(feature = "mpi")]
    pub fn bootstrap_mpi(&self, seed: usize, world: &SimpleCommunicator) -> Arc<Dataset> {
        let n_events = self.n_events();
        let mut indices: Vec<usize> = vec![0; n_events];
        if world.is_root() {
            let mut rng = fastrand::Rng::with_seed(seed as u64);
            indices = (0..n_events)
                .map(|_| rng.usize(0..n_events))
                .collect::<Vec<usize>>();
            indices.sort();
        }
        world.process_at_root().broadcast_into(&mut indices);
        let local_indices: Vec<usize> = indices
            .into_iter()
            .filter_map(|idx| {
                let (owning_rank, local_index) = world.owner_of_global_index(idx, n_events);
                if world.rank() == owning_rank {
                    Some(local_index)
                } else {
                    None
                }
            })
            .collect();
        #[cfg(feature = "rayon")]
        let bootstrapped_events: Vec<Arc<EventData>> = local_indices
            .into_par_iter()
            .map(|idx| self.events[idx].data_arc())
            .collect();
        #[cfg(not(feature = "rayon"))]
        let bootstrapped_events: Vec<Arc<EventData>> = local_indices
            .into_iter()
            .map(|idx| self.events[idx].data_arc())
            .collect();
        Arc::new(Dataset::new_with_metadata(
            bootstrapped_events,
            self.metadata.clone(),
        ))
    }

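    /// Resamples the dataset with replacement using `seed`; the drawn indices are
    /// sorted, so the result is deterministic for a given dataset and seed (including
    /// under MPI, where the root rank draws and broadcasts the indices). Sketch:
    ///
    /// ```ignore
    /// let resampled = dataset.bootstrap(42);
    /// assert_eq!(resampled.n_events(), dataset.n_events());
    /// ```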
    pub fn bootstrap(&self, seed: usize) -> Arc<Dataset> {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.bootstrap_mpi(seed, &world);
            }
        }
        self.bootstrap_local(seed)
    }

    pub fn filter(&self, expression: &VariableExpression) -> LadduResult<Arc<Dataset>> {
        let compiled = expression.compile(&self.metadata)?;
        #[cfg(feature = "rayon")]
        let filtered_events: Vec<Arc<EventData>> = self
            .events
            .par_iter()
            .filter(|event| compiled.evaluate(event.as_ref()))
            .map(|event| event.data_arc())
            .collect();
        #[cfg(not(feature = "rayon"))]
        let filtered_events: Vec<Arc<EventData>> = self
            .events
            .iter()
            .filter(|event| compiled.evaluate(event.as_ref()))
            .map(|event| event.data_arc())
            .collect();
        Ok(Arc::new(Dataset::new_with_metadata(
            filtered_events,
            self.metadata.clone(),
        )))
    }

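    /// Splits the dataset into `bins` equal-width bins of `variable` over the half-open
    /// `range` (events with values outside `[range.0, range.1)` are dropped). A hedged
    /// sketch, assuming some `mass` value implementing [`Variable`] is in scope:
    ///
    /// ```ignore
    /// let binned = dataset.bin_by(mass, 50, (1.0, 2.0))?;
    /// let lowest_bin = &binned[0]; // an Arc<Dataset> for the first bin
    /// ```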
    pub fn bin_by<V>(
        &self,
        mut variable: V,
        bins: usize,
        range: (f64, f64),
    ) -> LadduResult<BinnedDataset>
    where
        V: Variable,
    {
        variable.bind(self.metadata())?;
        let bin_width = (range.1 - range.0) / bins as f64;
        let bin_edges = get_bin_edges(bins, range);
        let variable = variable;
        #[cfg(feature = "rayon")]
        let evaluated: Vec<(usize, Arc<EventData>)> = self
            .events
            .par_iter()
            .filter_map(|event| {
                let value = variable.value(event.as_ref());
                if value >= range.0 && value < range.1 {
                    let bin_index = ((value - range.0) / bin_width) as usize;
                    let bin_index = bin_index.min(bins - 1);
                    Some((bin_index, event.data_arc()))
                } else {
                    None
                }
            })
            .collect();
        #[cfg(not(feature = "rayon"))]
        let evaluated: Vec<(usize, Arc<EventData>)> = self
            .events
            .iter()
            .filter_map(|event| {
                let value = variable.value(event.as_ref());
                if value >= range.0 && value < range.1 {
                    let bin_index = ((value - range.0) / bin_width) as usize;
                    let bin_index = bin_index.min(bins - 1);
                    Some((bin_index, event.data_arc()))
                } else {
                    None
                }
            })
            .collect();
        let mut binned_events: Vec<Vec<Arc<EventData>>> = vec![Vec::default(); bins];
        for (bin_index, event) in evaluated {
            binned_events[bin_index].push(event.clone());
        }
        #[cfg(feature = "rayon")]
        let datasets: Vec<Arc<Dataset>> = binned_events
            .into_par_iter()
            .map(|events| Arc::new(Dataset::new_with_metadata(events, self.metadata.clone())))
            .collect();
        #[cfg(not(feature = "rayon"))]
        let datasets: Vec<Arc<Dataset>> = binned_events
            .into_iter()
            .map(|events| Arc::new(Dataset::new_with_metadata(events, self.metadata.clone())))
            .collect();
        Ok(BinnedDataset {
            datasets,
            edges: bin_edges,
        })
    }

    pub fn boost_to_rest_frame_of<S>(&self, names: &[S]) -> Arc<Dataset>
    where
        S: AsRef<str>,
    {
        let mut indices: Vec<usize> = Vec::new();
        for name in names {
            let name_ref = name.as_ref();
            if let Some(selection) = self.metadata.p4_selection(name_ref) {
                indices.extend_from_slice(selection.indices());
            } else {
                panic!("Unknown particle name '{name}'", name = name_ref);
            }
        }
        #[cfg(feature = "rayon")]
        let boosted_events: Vec<Arc<EventData>> = self
            .events
            .par_iter()
            .map(|event| Arc::new(event.data().boost_to_rest_frame_of(&indices)))
            .collect();
        #[cfg(not(feature = "rayon"))]
        let boosted_events: Vec<Arc<EventData>> = self
            .events
            .iter()
            .map(|event| Arc::new(event.data().boost_to_rest_frame_of(&indices)))
            .collect();
        Arc::new(Dataset::new_with_metadata(
            boosted_events,
            self.metadata.clone(),
        ))
    }

    pub fn evaluate<V: Variable>(&self, variable: &V) -> LadduResult<Vec<f64>> {
        variable.value_on(self)
    }

    fn write_parquet_impl(
        &self,
        file_path: PathBuf,
        options: &DatasetWriteOptions,
    ) -> LadduResult<()> {
        let batch_size = options.batch_size.max(1);
        let precision = options.precision;
        let schema = Arc::new(build_parquet_schema(&self.metadata, precision));

        #[cfg(feature = "mpi")]
        let is_root = crate::mpi::get_world()
            .as_ref()
            .map_or(true, |world| world.rank() == 0);
        #[cfg(not(feature = "mpi"))]
        let is_root = true;

        let mut writer: Option<ArrowWriter<File>> = None;
        if is_root {
            let file = File::create(&file_path)?;
            writer = Some(
                ArrowWriter::try_new(file, schema.clone(), None).map_err(|err| {
                    LadduError::Custom(format!("Failed to create Parquet writer: {err}"))
                })?,
            );
        }

        let mut iter = self.iter();
        loop {
            let mut buffers =
                ColumnBuffers::new(self.metadata.p4_names.len(), self.metadata.aux_names.len());
            let mut rows = 0usize;

            // Every rank advances the iterator and counts rows (under MPI each `next()`
            // is a collective fetch), but only the root rank buffers the events.
            while rows < batch_size {
                match iter.next() {
                    Some(event) => {
                        if is_root {
                            buffers.push_event(&event);
                        }
                        rows += 1;
                    }
                    None => break,
                }
            }

            if rows == 0 {
                break;
            }

            if let Some(writer) = writer.as_mut() {
                let batch = buffers
                    .into_record_batch(schema.clone(), precision)
                    .map_err(|err| {
                        LadduError::Custom(format!("Failed to build Parquet batch: {err}"))
                    })?;
                writer.write(&batch).map_err(|err| {
                    LadduError::Custom(format!("Failed to write Parquet batch: {err}"))
                })?;
            }
        }

        if let Some(writer) = writer {
            writer.close().map_err(|err| {
                LadduError::Custom(format!("Failed to finalise Parquet file: {err}"))
            })?;
        }

        Ok(())
    }

    fn write_root_impl(
        &self,
        file_path: PathBuf,
        options: &DatasetWriteOptions,
    ) -> LadduResult<()> {
        let tree_name = options.tree.clone().unwrap_or_else(|| "events".to_string());
        let branch_count = self.metadata.p4_names.len() * 4 + self.metadata.aux_names.len() + 1;

        #[cfg(feature = "mpi")]
        let mut world_opt = crate::mpi::get_world();
        #[cfg(feature = "mpi")]
        let is_root = world_opt.as_ref().map_or(true, |world| world.rank() == 0);
        #[cfg(not(feature = "mpi"))]
        let is_root = true;

        #[cfg(feature = "mpi")]
        let world: Option<WorldHandle> = world_opt.take();
        #[cfg(not(feature = "mpi"))]
        let world: Option<WorldHandle> = None;

        let total_events = self.n_events();
        let dataset_arc = Arc::new(self.clone());

        match options.precision {
            FloatPrecision::F64 => self.write_root_with_type::<f64>(
                dataset_arc,
                world,
                is_root,
                &file_path,
                &tree_name,
                branch_count,
                total_events,
            ),
            FloatPrecision::F32 => self.write_root_with_type::<f32>(
                dataset_arc,
                world,
                is_root,
                &file_path,
                &tree_name,
                branch_count,
                total_events,
            ),
        }
    }
}

fn canonicalize_dataset_path(file_path: &str) -> LadduResult<PathBuf> {
    Ok(Path::new(&*shellexpand::full(file_path)?).canonicalize()?)
}

fn expand_output_path(file_path: &str) -> LadduResult<PathBuf> {
    Ok(PathBuf::from(&*shellexpand::full(file_path)?))
}

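/// Reads a Parquet file into a [`Dataset`], detecting `<name>_{px,py,pz,e}` float column
/// groups as four-momenta and the remaining float columns (except `weight`) as auxiliary
/// values. Under MPI, each rank reads only the row groups covering its own partition.
/// Sketch (file path illustrative):
///
/// ```ignore
/// let options = DatasetReadOptions::new().p4_names(["beam", "proton"]);
/// let dataset = read_parquet("data/events.parquet", &options)?;
/// ```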
pub fn read_parquet(file_path: &str, options: &DatasetReadOptions) -> LadduResult<Arc<Dataset>> {
    let path = canonicalize_dataset_path(file_path)?;
    let (detected_p4_names, detected_aux_names) = detect_columns(&path)?;
    let metadata = options.resolve_metadata(detected_p4_names, detected_aux_names)?;
    let file = File::open(path)?;
    let builder = ParquetRecordBatchReaderBuilder::try_new(file)?;

    #[cfg(feature = "mpi")]
    {
        if let Some(world) = crate::mpi::get_world() {
            return read_parquet_mpi(builder, metadata, &world);
        }
    }

    read_parquet_local(builder, metadata)
}

fn read_parquet_local(
    builder: ParquetRecordBatchReaderBuilder<File>,
    metadata: Arc<DatasetMetadata>,
) -> LadduResult<Arc<Dataset>> {
    let reader = builder.build()?;
    let batches: Vec<RecordBatch> = reader.collect::<Result<Vec<_>, _>>()?;
    let events = batches_to_events(batches, metadata.as_ref())?;
    Ok(Arc::new(Dataset::new_with_metadata(events, metadata)))
}

#[cfg(feature = "mpi")]
fn read_parquet_mpi(
    mut builder: ParquetRecordBatchReaderBuilder<File>,
    metadata: Arc<DatasetMetadata>,
    world: &SimpleCommunicator,
) -> LadduResult<Arc<Dataset>> {
    let parquet_metadata = builder.metadata().clone();
    let total_rows = parquet_metadata.file_metadata().num_rows() as usize;
    if total_rows == 0 {
        return Ok(Arc::new(Dataset::new_local(Vec::new(), metadata)));
    }

    let partition = world.partition(total_rows);
    let rank = world.rank() as usize;
    let local_range = partition.range_for_rank(rank);
    let local_start = local_range.start;
    let local_end = local_range.end;
    if local_start == local_end {
        return Ok(Arc::new(Dataset::new_local(Vec::new(), metadata)));
    }

    let (row_groups, first_row_start) =
        row_groups_for_range(&parquet_metadata, local_start, local_end);
    if !row_groups.is_empty() {
        builder = builder.with_row_groups(row_groups);
    }

    let reader = builder.build()?;
    let batches: Vec<RecordBatch> = reader.collect::<Result<Vec<_>, _>>()?;
    let mut events = batches_to_events(batches, metadata.as_ref())?;

    let drop_front = local_start.saturating_sub(first_row_start);
    if drop_front > 0 {
        events.drain(0..drop_front);
    }
    let expected_local = local_end - local_start;
    if events.len() > expected_local {
        events.truncate(expected_local);
    }
    if events.len() != expected_local {
        return Err(LadduError::Custom(format!(
            "Loaded {} rows on rank {} but expected {}",
            events.len(),
            rank,
            expected_local
        )));
    }

    Ok(Arc::new(Dataset::new_local(events, metadata)))
}

#[cfg(feature = "mpi")]
fn row_groups_for_range(
    metadata: &Arc<ParquetMetaData>,
    start: usize,
    end: usize,
) -> (Vec<usize>, usize) {
    let mut selected = Vec::new();
    let mut first_row_start = start;
    let mut offset = 0usize;

    for (idx, row_group) in metadata.row_groups().iter().enumerate() {
        let group_start = offset;
        let rows = row_group.num_rows() as usize;
        let group_end = group_start + rows;
        offset = group_end;

        if group_end <= start {
            continue;
        }
        if group_start >= end {
            break;
        }
        if selected.is_empty() {
            first_row_start = group_start;
        }
        selected.push(idx);
        if group_end >= end {
            break;
        }
    }

    (selected, first_row_start)
}

fn batches_to_events(
    batches: Vec<RecordBatch>,
    metadata: &DatasetMetadata,
) -> LadduResult<Vec<Arc<EventData>>> {
    #[cfg(feature = "rayon")]
    {
        let batch_events: Vec<LadduResult<Vec<Arc<EventData>>>> = batches
            .into_par_iter()
            .map(|batch| record_batch_to_events(batch, &metadata.p4_names, &metadata.aux_names))
            .collect();
        let mut events = Vec::new();
        for batch in batch_events {
            let mut batch = batch?;
            events.append(&mut batch);
        }
        Ok(events)
    }

    #[cfg(not(feature = "rayon"))]
    {
        Ok(batches
            .into_iter()
            .map(|batch| record_batch_to_events(batch, &metadata.p4_names, &metadata.aux_names))
            .collect::<LadduResult<Vec<_>>>()?
            .into_iter()
            .flatten()
            .collect())
    }
}

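/// Reads a ROOT `TTree` into a [`Dataset`]. Scalar `float`/`double` branches are grouped
/// into four-momenta using the same `<name>_{px,py,pz,e}` convention as Parquet; a
/// `weight` branch is optional and defaults to 1.0. Sketch (file path illustrative):
///
/// ```ignore
/// let options = DatasetReadOptions::new().tree("events");
/// let dataset = read_root("data/events.root", &options)?;
/// ```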
pub fn read_root(file_path: &str, options: &DatasetReadOptions) -> LadduResult<Arc<Dataset>> {
    let path = canonicalize_dataset_path(file_path)?;
    let mut file = RootFile::open(&path).map_err(|err| {
        LadduError::Custom(format!(
            "Failed to open ROOT file '{}': {err}",
            path.display()
        ))
    })?;

    let (tree, tree_name) = resolve_root_tree(&mut file, options.tree.as_deref())?;

    let branches: Vec<&Branch> = tree.branches().collect();
    let mut lookup: BranchLookup<'_> = IndexMap::new();
    for &branch in &branches {
        if let Some(kind) = branch_scalar_kind(branch) {
            lookup.insert(branch.name(), (kind, branch));
        }
    }

    if lookup.is_empty() {
        return Err(LadduError::Custom(format!(
            "No float or double branches found in ROOT tree '{tree_name}'"
        )));
    }

    let column_names: Vec<&str> = lookup.keys().copied().collect();
    let (detected_p4_names, detected_aux_names) = infer_p4_and_aux_names(&column_names);
    let metadata = options.resolve_metadata(detected_p4_names, detected_aux_names)?;

    struct RootP4Columns {
        px: Vec<f64>,
        py: Vec<f64>,
        pz: Vec<f64>,
        e: Vec<f64>,
    }

    let mut p4_columns = Vec::with_capacity(metadata.p4_names.len());
    for name in &metadata.p4_names {
        let logical = format!("{name}_px");
        let px = read_branch_values_from_candidates(
            &lookup,
            &component_candidates(name, "px"),
            &logical,
        )?;

        let logical = format!("{name}_py");
        let py = read_branch_values_from_candidates(
            &lookup,
            &component_candidates(name, "py"),
            &logical,
        )?;

        let logical = format!("{name}_pz");
        let pz = read_branch_values_from_candidates(
            &lookup,
            &component_candidates(name, "pz"),
            &logical,
        )?;

        let logical = format!("{name}_e");
        let e = read_branch_values_from_candidates(
            &lookup,
            &component_candidates(name, "e"),
            &logical,
        )?;

        p4_columns.push(RootP4Columns { px, py, pz, e });
    }

    let mut aux_columns = Vec::with_capacity(metadata.aux_names.len());
    for name in &metadata.aux_names {
        let values = read_branch_values(&lookup, name)?;
        aux_columns.push(values);
    }

    let n_events = if let Some(first) = p4_columns.first() {
        first.px.len()
    } else if let Some(first) = aux_columns.first() {
        first.len()
    } else {
        return Err(LadduError::Custom(
            "Unable to determine event count; dataset has no four-momentum or auxiliary columns"
                .to_string(),
        ));
    };

    let weight_values = match read_branch_values_optional(&lookup, "weight")? {
        Some(values) => {
            if values.len() != n_events {
                return Err(LadduError::Custom(format!(
                    "Column 'weight' has {} entries but expected {}",
                    values.len(),
                    n_events
                )));
            }
            values
        }
        None => vec![1.0; n_events],
    };

    let mut events = Vec::with_capacity(n_events);
    for row in 0..n_events {
        let mut p4s = Vec::with_capacity(p4_columns.len());
        for columns in &p4_columns {
            p4s.push(Vec4::new(
                columns.px[row],
                columns.py[row],
                columns.pz[row],
                columns.e[row],
            ));
        }

        let mut aux = Vec::with_capacity(aux_columns.len());
        for column in &aux_columns {
            aux.push(column[row]);
        }

        let event = EventData {
            p4s,
            aux,
            weight: weight_values[row],
        };
        events.push(Arc::new(event));
    }

    Ok(Arc::new(Dataset::new_with_metadata(events, metadata)))
}

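/// Writes `dataset` to a Parquet file. Under MPI only the root rank writes, but every
/// rank drives the iterator so the collective event fetches stay in lockstep. Sketch
/// (file path illustrative):
///
/// ```ignore
/// let options = DatasetWriteOptions::default()
///     .batch_size(4096)
///     .precision(FloatPrecision::F32);
/// write_parquet(&dataset, "out/events.parquet", &options)?;
/// ```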
pub fn write_parquet(
    dataset: &Dataset,
    file_path: &str,
    options: &DatasetWriteOptions,
) -> LadduResult<()> {
    let path = expand_output_path(file_path)?;
    dataset.write_parquet_impl(path, options)
}

pub fn write_root(
    dataset: &Dataset,
    file_path: &str,
    options: &DatasetWriteOptions,
) -> LadduResult<()> {
    let path = expand_output_path(file_path)?;
    dataset.write_root_impl(path, options)
}

impl_op_ex!(+ |a: &Dataset, b: &Dataset| -> Dataset {
    debug_assert_eq!(a.metadata.p4_names, b.metadata.p4_names);
    debug_assert_eq!(a.metadata.aux_names, b.metadata.aux_names);
    Dataset {
        events: a
            .events
            .iter()
            .chain(b.events.iter())
            .cloned()
            .collect(),
        metadata: a.metadata.clone(),
    }
});

#[derive(Default, Clone)]
pub struct DatasetReadOptions {
    pub p4_names: Option<Vec<String>>,
    pub aux_names: Option<Vec<String>>,
    pub tree: Option<String>,
    pub aliases: IndexMap<String, P4Selection>,
}

#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum FloatPrecision {
    F32,
    #[default]
    F64,
}

#[derive(Clone, Debug)]
pub struct DatasetWriteOptions {
    pub batch_size: usize,
    pub precision: FloatPrecision,
    pub tree: Option<String>,
}

impl Default for DatasetWriteOptions {
    fn default() -> Self {
        Self {
            batch_size: DEFAULT_WRITE_BATCH_SIZE,
            precision: FloatPrecision::default(),
            tree: None,
        }
    }
}

impl DatasetWriteOptions {
    pub fn batch_size(mut self, batch_size: usize) -> Self {
        self.batch_size = batch_size;
        self
    }

    pub fn precision(mut self, precision: FloatPrecision) -> Self {
        self.precision = precision;
        self
    }

    pub fn tree<S: Into<String>>(mut self, name: S) -> Self {
        self.tree = Some(name.into());
        self
    }
}

impl DatasetReadOptions {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn p4_names<I, S>(mut self, names: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        self.p4_names = Some(names.into_iter().map(|s| s.as_ref().to_string()).collect());
        self
    }

    pub fn aux_names<I, S>(mut self, names: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        self.aux_names = Some(names.into_iter().map(|s| s.as_ref().to_string()).collect());
        self
    }

    pub fn tree<S>(mut self, name: S) -> Self
    where
        S: AsRef<str>,
    {
        self.tree = Some(name.as_ref().to_string());
        self
    }

    pub fn alias<N, S>(mut self, name: N, selection: S) -> Self
    where
        N: Into<String>,
        S: IntoP4Selection,
    {
        self.aliases.insert(name.into(), selection.into_selection());
        self
    }

    pub fn aliases<I, N, S>(mut self, aliases: I) -> Self
    where
        I: IntoIterator<Item = (N, S)>,
        N: Into<String>,
        S: IntoP4Selection,
    {
        for (name, selection) in aliases {
            self = self.alias(name, selection);
        }
        self
    }

    fn resolve_metadata(
        &self,
        detected_p4_names: Vec<String>,
        detected_aux_names: Vec<String>,
    ) -> LadduResult<Arc<DatasetMetadata>> {
        let p4_names_vec = self.p4_names.clone().unwrap_or(detected_p4_names);
        let aux_names_vec = self.aux_names.clone().unwrap_or(detected_aux_names);

        let mut metadata = DatasetMetadata::new(p4_names_vec, aux_names_vec)?;
        if !self.aliases.is_empty() {
            metadata.add_p4_aliases(self.aliases.clone())?;
        }
        Ok(Arc::new(metadata))
    }
}

const P4_COMPONENT_SUFFIXES: [&str; 4] = ["_px", "_py", "_pz", "_e"];
const DEFAULT_WRITE_BATCH_SIZE: usize = 10_000;

fn detect_columns(file_path: &PathBuf) -> LadduResult<(Vec<String>, Vec<String>)> {
    let file = File::open(file_path)?;
    let builder = ParquetRecordBatchReaderBuilder::try_new(file)?;
    let schema = builder.schema();
    let float_cols: Vec<&str> = schema
        .fields()
        .iter()
        .filter(|f| matches!(f.data_type(), DataType::Float32 | DataType::Float64))
        .map(|f| f.name().as_str())
        .collect();
    Ok(infer_p4_and_aux_names(&float_cols))
}

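/// Groups float columns into four-momentum candidates by suffix: a prefix becomes a p4
/// name only when all four of `_px`, `_py`, `_pz`, `_e` are present; every remaining
/// float column except `weight` becomes an auxiliary name. For example,
/// `["beam_px", "beam_py", "beam_pz", "beam_e", "pol_angle"]` yields
/// `(["beam"], ["pol_angle"])`.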
fn infer_p4_and_aux_names(float_cols: &[&str]) -> (Vec<String>, Vec<String>) {
    let suffix_set: IndexSet<&str> = P4_COMPONENT_SUFFIXES.iter().copied().collect();
    let mut groups: IndexMap<&str, IndexSet<&str>> = IndexMap::new();
    for col in float_cols {
        for suffix in &suffix_set {
            if let Some(prefix) = col.strip_suffix(suffix) {
                groups.entry(prefix).or_default().insert(*suffix);
            }
        }
    }

    let mut p4_names: Vec<String> = Vec::new();
    let mut p4_columns: IndexSet<String> = IndexSet::new();
    for (prefix, suffixes) in &groups {
        if suffixes.len() == suffix_set.len() {
            p4_names.push((*prefix).to_string());
            for suffix in &suffix_set {
                p4_columns.insert(format!("{prefix}{suffix}"));
            }
        }
    }

    let mut aux_names: Vec<String> = Vec::new();
    for col in float_cols {
        if p4_columns.contains(*col) {
            continue;
        }
        if col.eq_ignore_ascii_case("weight") {
            continue;
        }
        aux_names.push((*col).to_string());
    }

    (p4_names, aux_names)
}

type BranchLookup<'a> = IndexMap<&'a str, (RootScalarKind, &'a Branch)>;

#[derive(Clone, Copy)]
enum RootScalarKind {
    F32,
    F64,
}

fn branch_scalar_kind(branch: &Branch) -> Option<RootScalarKind> {
    let type_name = branch.item_type_name();
    let lower = type_name.to_ascii_lowercase();
    if lower.contains("vector") {
        return None;
    }
    match lower.as_str() {
        "float" | "float_t" | "float32_t" => Some(RootScalarKind::F32),
        "double" | "double_t" | "double32_t" => Some(RootScalarKind::F64),
        _ => None,
    }
}

fn read_branch_values<'a>(lookup: &BranchLookup<'a>, column_name: &str) -> LadduResult<Vec<f64>> {
    let (kind, branch) =
        lookup
            .get(column_name)
            .copied()
            .ok_or_else(|| LadduError::MissingColumn {
                name: column_name.to_string(),
            })?;

    let values = match kind {
        RootScalarKind::F32 => branch
            .as_iter::<f32>()
            .map_err(|err| map_root_error(&format!("Failed to read branch '{column_name}'"), err))?
            .map(|value| value as f64)
            .collect(),
        RootScalarKind::F64 => branch
            .as_iter::<f64>()
            .map_err(|err| map_root_error(&format!("Failed to read branch '{column_name}'"), err))?
            .collect(),
    };
    Ok(values)
}

fn read_branch_values_optional<'a>(
    lookup: &BranchLookup<'a>,
    column_name: &str,
) -> LadduResult<Option<Vec<f64>>> {
    if lookup.contains_key(column_name) {
        read_branch_values(lookup, column_name).map(Some)
    } else {
        Ok(None)
    }
}

fn read_branch_values_from_candidates<'a>(
    lookup: &BranchLookup<'a>,
    candidates: &[String],
    logical_name: &str,
) -> LadduResult<Vec<f64>> {
    for candidate in candidates {
        if lookup.contains_key(candidate.as_str()) {
            return read_branch_values(lookup, candidate);
        }
    }
    Err(LadduError::MissingColumn {
        name: logical_name.to_string(),
    })
}

fn resolve_root_tree(
    file: &mut RootFile,
    requested: Option<&str>,
) -> LadduResult<(ReaderTree, String)> {
    if let Some(name) = requested {
        let tree = file
            .get_tree(name)
            .map_err(|err| map_root_error(&format!("Failed to open ROOT tree '{name}'"), err))?;
        return Ok((tree, name.to_string()));
    }

    let tree_names: Vec<String> = file
        .keys()
        .into_iter()
        .filter(|key| key.class_name() == "TTree")
        .map(|key| key.name().to_string())
        .collect();

    if tree_names.is_empty() {
        return Err(LadduError::Custom(
            "ROOT file does not contain any TTrees".to_string(),
        ));
    }

    if tree_names.len() > 1 {
        return Err(LadduError::Custom(format!(
            "Multiple TTrees found ({:?}); specify DatasetReadOptions::tree to disambiguate",
            tree_names
        )));
    }

    let selected = &tree_names[0];
    let tree = file
        .get_tree(selected)
        .map_err(|err| map_root_error(&format!("Failed to open ROOT tree '{selected}'"), err))?;
    Ok((tree, selected.clone()))
}

fn map_root_error<E: std::fmt::Display>(context: &str, err: E) -> LadduError {
    LadduError::Custom(format!("{context}: {err}"))
}

#[derive(Clone, Copy)]
enum FloatColumn<'a> {
    F32(&'a Float32Array),
    F64(&'a Float64Array),
}

impl<'a> FloatColumn<'a> {
    fn value(&self, row: usize) -> f64 {
        match self {
            Self::F32(array) => array.value(row) as f64,
            Self::F64(array) => array.value(row),
        }
    }
}

struct P4Columns<'a> {
    px: FloatColumn<'a>,
    py: FloatColumn<'a>,
    pz: FloatColumn<'a>,
    e: FloatColumn<'a>,
}

fn prepare_float_column<'a>(batch: &'a RecordBatch, name: &str) -> LadduResult<FloatColumn<'a>> {
    prepare_float_column_from_candidates(batch, &[name.to_string()], name)
}

fn prepare_p4_columns<'a>(batch: &'a RecordBatch, name: &str) -> LadduResult<P4Columns<'a>> {
    Ok(P4Columns {
        px: prepare_float_column_from_candidates(
            batch,
            &component_candidates(name, "px"),
            &format!("{name}_px"),
        )?,
        py: prepare_float_column_from_candidates(
            batch,
            &component_candidates(name, "py"),
            &format!("{name}_py"),
        )?,
        pz: prepare_float_column_from_candidates(
            batch,
            &component_candidates(name, "pz"),
            &format!("{name}_pz"),
        )?,
        e: prepare_float_column_from_candidates(
            batch,
            &component_candidates(name, "e"),
            &format!("{name}_e"),
        )?,
    })
}

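/// Builds candidate column/branch names for one p4 component, trying the lowercase,
/// capitalized, and uppercase suffix spellings in order:
/// `component_candidates("beam", "px")` yields `["beam_px", "beam_Px", "beam_PX"]`.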
fn component_candidates(name: &str, suffix: &str) -> Vec<String> {
    let mut candidates = Vec::with_capacity(3);
    let base = format!("{name}_{suffix}");
    candidates.push(base.clone());

    let mut capitalized = suffix.to_string();
    if let Some(first) = capitalized.get_mut(0..1) {
        first.make_ascii_uppercase();
    }
    if capitalized != suffix {
        candidates.push(format!("{name}_{capitalized}"));
    }

    let upper = suffix.to_ascii_uppercase();
    if upper != suffix && upper != capitalized {
        candidates.push(format!("{name}_{upper}"));
    }

    candidates
}

fn find_float_column_from_candidates<'a>(
    batch: &'a RecordBatch,
    candidates: &[String],
) -> LadduResult<Option<FloatColumn<'a>>> {
    use arrow::datatypes::DataType;

    for candidate in candidates {
        if let Some(column) = batch.column_by_name(candidate) {
            return match column.data_type() {
                DataType::Float32 => Ok(Some(FloatColumn::F32(
                    column
                        .as_any()
                        .downcast_ref::<Float32Array>()
                        .expect("Column advertised as Float32 but could not be downcast"),
                ))),
                DataType::Float64 => Ok(Some(FloatColumn::F64(
                    column
                        .as_any()
                        .downcast_ref::<Float64Array>()
                        .expect("Column advertised as Float64 but could not be downcast"),
                ))),
                other => {
                    return Err(LadduError::InvalidColumnType {
                        name: candidate.clone(),
                        datatype: other.to_string(),
                    })
                }
            };
        }
    }
    Ok(None)
}

fn prepare_float_column_from_candidates<'a>(
    batch: &'a RecordBatch,
    candidates: &[String],
    logical_name: &str,
) -> LadduResult<FloatColumn<'a>> {
    find_float_column_from_candidates(batch, candidates)?.ok_or_else(|| LadduError::MissingColumn {
        name: logical_name.to_string(),
    })
}

fn record_batch_to_events(
    batch: RecordBatch,
    p4_names: &[String],
    aux_names: &[String],
) -> LadduResult<Vec<Arc<EventData>>> {
    let batch_ref = &batch;
    let p4_columns: Vec<P4Columns<'_>> = p4_names
        .iter()
        .map(|name| prepare_p4_columns(batch_ref, name))
        .collect::<Result<_, _>>()?;

    let aux_columns: Vec<FloatColumn<'_>> = aux_names
        .iter()
        .map(|name| prepare_float_column(batch_ref, name))
        .collect::<Result<_, _>>()?;

    let weight_column = find_float_column_from_candidates(batch_ref, &["weight".to_string()])?;

    let mut events = Vec::with_capacity(batch_ref.num_rows());
    for row in 0..batch_ref.num_rows() {
        let mut p4s = Vec::with_capacity(p4_columns.len());
        for columns in &p4_columns {
            let px = columns.px.value(row);
            let py = columns.py.value(row);
            let pz = columns.pz.value(row);
            let e = columns.e.value(row);
            p4s.push(Vec4::new(px, py, pz, e));
        }

        let mut aux = Vec::with_capacity(aux_columns.len());
        for column in &aux_columns {
            aux.push(column.value(row));
        }

        let event = EventData {
            p4s,
            aux,
            weight: weight_column
                .as_ref()
                .map(|column| column.value(row))
                .unwrap_or(1.0),
        };
        events.push(Arc::new(event));
    }
    Ok(events)
}

struct ColumnBuffers {
    p4: Vec<P4Buffer>,
    aux: Vec<Vec<f64>>,
    weight: Vec<f64>,
}

impl ColumnBuffers {
    fn new(n_p4: usize, n_aux: usize) -> Self {
        let p4 = (0..n_p4).map(|_| P4Buffer::default()).collect();
        let aux = vec![Vec::new(); n_aux];
        Self {
            p4,
            aux,
            weight: Vec::new(),
        }
    }

    fn push_event(&mut self, event: &Event) {
        for (buffer, p4) in self.p4.iter_mut().zip(event.p4s.iter()) {
            buffer.px.push(p4.x);
            buffer.py.push(p4.y);
            buffer.pz.push(p4.z);
            buffer.e.push(p4.t);
        }

        for (buffer, value) in self.aux.iter_mut().zip(event.aux.iter()) {
            buffer.push(*value);
        }

        self.weight.push(event.weight);
    }

    fn into_record_batch(
        self,
        schema: Arc<Schema>,
        precision: FloatPrecision,
    ) -> arrow::error::Result<RecordBatch> {
        let mut columns: Vec<arrow::array::ArrayRef> = Vec::new();

        match precision {
            FloatPrecision::F64 => {
                for buffer in &self.p4 {
                    columns.push(Arc::new(Float64Array::from(buffer.px.clone())));
                    columns.push(Arc::new(Float64Array::from(buffer.py.clone())));
                    columns.push(Arc::new(Float64Array::from(buffer.pz.clone())));
                    columns.push(Arc::new(Float64Array::from(buffer.e.clone())));
                }

                for buffer in &self.aux {
                    columns.push(Arc::new(Float64Array::from(buffer.clone())));
                }

                columns.push(Arc::new(Float64Array::from(self.weight)));
            }
            FloatPrecision::F32 => {
                for buffer in &self.p4 {
                    columns.push(Arc::new(Float32Array::from(
                        buffer.px.iter().map(|v| *v as f32).collect::<Vec<_>>(),
                    )));
                    columns.push(Arc::new(Float32Array::from(
                        buffer.py.iter().map(|v| *v as f32).collect::<Vec<_>>(),
                    )));
                    columns.push(Arc::new(Float32Array::from(
                        buffer.pz.iter().map(|v| *v as f32).collect::<Vec<_>>(),
                    )));
                    columns.push(Arc::new(Float32Array::from(
                        buffer.e.iter().map(|v| *v as f32).collect::<Vec<_>>(),
                    )));
                }

                for buffer in &self.aux {
                    columns.push(Arc::new(Float32Array::from(
                        buffer.iter().map(|v| *v as f32).collect::<Vec<_>>(),
                    )));
                }

                columns.push(Arc::new(Float32Array::from(
                    self.weight.iter().map(|v| *v as f32).collect::<Vec<_>>(),
                )));
            }
        }

        RecordBatch::try_new(schema, columns)
    }
}

#[derive(Default)]
struct P4Buffer {
    px: Vec<f64>,
    py: Vec<f64>,
    pz: Vec<f64>,
    e: Vec<f64>,
}

fn build_parquet_schema(metadata: &DatasetMetadata, precision: FloatPrecision) -> Schema {
    let dtype = match precision {
        FloatPrecision::F64 => DataType::Float64,
        FloatPrecision::F32 => DataType::Float32,
    };

    let mut fields = Vec::new();
    for name in &metadata.p4_names {
        for suffix in P4_COMPONENT_SUFFIXES {
            fields.push(Field::new(format!("{name}{suffix}"), dtype.clone(), false));
        }
    }

    for name in &metadata.aux_names {
        fields.push(Field::new(name.clone(), dtype.clone(), false));
    }

    fields.push(Field::new("weight", dtype, false));
    Schema::new(fields)
}

trait FromF64 {
    fn from_f64(value: f64) -> Self;
}

impl FromF64 for f64 {
    fn from_f64(value: f64) -> Self {
        value
    }
}

impl FromF64 for f32 {
    fn from_f64(value: f64) -> Self {
        value as f32
    }
}

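/// Shared cursor used when writing ROOT trees: all output branches draw from the same
/// event stream, so each fetched event is cached and handed out once per branch
/// (`branch_count` times) before advancing. This keeps event fetches collective under
/// MPI, where each fetch may involve a broadcast.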
struct SharedEventFetcher {
    dataset: Arc<Dataset>,
    world: Option<WorldHandle>,
    total: usize,
    branch_count: usize,
    current_index: Option<usize>,
    current_event: Option<Event>,
    remaining: usize,
}

impl SharedEventFetcher {
    fn new(
        dataset: Arc<Dataset>,
        world: Option<WorldHandle>,
        total: usize,
        branch_count: usize,
    ) -> Self {
        Self {
            dataset,
            world,
            total,
            branch_count,
            current_index: None,
            current_event: None,
            remaining: 0,
        }
    }

    fn event_for_index(&mut self, index: usize) -> Option<Event> {
        if index >= self.total {
            return None;
        }

        let refresh_needed = match self.current_index {
            None => true,
            Some(current) => current != index || self.remaining == 0,
        };

        if refresh_needed {
            let event =
                fetch_event_for_index(&self.dataset, index, self.total, self.world.as_ref());
            self.current_index = Some(index);
            self.remaining = self.branch_count;
            self.current_event = Some(event);
        }

        let event = self.current_event.as_ref().cloned();
        if self.remaining > 0 {
            self.remaining -= 1;
        }
        if self.remaining == 0 {
            self.current_event = None;
        }
        event
    }
}

enum ColumnKind {
    Px(usize),
    Py(usize),
    Pz(usize),
    E(usize),
    Aux(usize),
    Weight,
}

struct ColumnIterator<T> {
    fetcher: Arc<Mutex<SharedEventFetcher>>,
    index: usize,
    kind: ColumnKind,
    _marker: std::marker::PhantomData<T>,
}

impl<T> ColumnIterator<T> {
    fn new(fetcher: Arc<Mutex<SharedEventFetcher>>, kind: ColumnKind) -> Self {
        Self {
            fetcher,
            index: 0,
            kind,
            _marker: std::marker::PhantomData,
        }
    }
}

impl<T> Iterator for ColumnIterator<T>
where
    T: FromF64,
{
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        let mut fetcher = self.fetcher.lock();
        let event = fetcher.event_for_index(self.index)?;
        self.index += 1;

        match self.kind {
            ColumnKind::Px(idx) => event.p4s.get(idx).map(|p4| T::from_f64(p4.x)),
            ColumnKind::Py(idx) => event.p4s.get(idx).map(|p4| T::from_f64(p4.y)),
            ColumnKind::Pz(idx) => event.p4s.get(idx).map(|p4| T::from_f64(p4.z)),
            ColumnKind::E(idx) => event.p4s.get(idx).map(|p4| T::from_f64(p4.t)),
            ColumnKind::Aux(idx) => event.aux.get(idx).map(|value| T::from_f64(*value)),
            ColumnKind::Weight => Some(T::from_f64(event.weight)),
        }
    }
}
2283
fn build_root_column_iterators<T>(
    dataset: Arc<Dataset>,
    world: Option<WorldHandle>,
    branch_count: usize,
    total: usize,
) -> Vec<(String, ColumnIterator<T>)>
where
    T: FromF64,
{
    let fetcher = Arc::new(Mutex::new(SharedEventFetcher::new(
        dataset,
        world,
        total,
        branch_count,
    )));

    let p4_names: Vec<String> = fetcher.lock().dataset.metadata.p4_names.clone();
    let aux_names: Vec<String> = fetcher.lock().dataset.metadata.aux_names.clone();

    let mut iterators = Vec::new();

    for (idx, name) in p4_names.iter().enumerate() {
        iterators.push((
            format!("{name}_px"),
            ColumnIterator::new(fetcher.clone(), ColumnKind::Px(idx)),
        ));
        iterators.push((
            format!("{name}_py"),
            ColumnIterator::new(fetcher.clone(), ColumnKind::Py(idx)),
        ));
        iterators.push((
            format!("{name}_pz"),
            ColumnIterator::new(fetcher.clone(), ColumnKind::Pz(idx)),
        ));
        iterators.push((
            format!("{name}_e"),
            ColumnIterator::new(fetcher.clone(), ColumnKind::E(idx)),
        ));
    }

    for (idx, name) in aux_names.iter().enumerate() {
        iterators.push((
            name.clone(),
            ColumnIterator::new(fetcher.clone(), ColumnKind::Aux(idx)),
        ));
    }

    iterators.push((
        "weight".to_string(),
        ColumnIterator::new(fetcher, ColumnKind::Weight),
    ));

    iterators
}

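/// Advances every column iterator through `n_events` rows without writing
/// anything, so that ranks which do not own the output file still drive the
/// event fetches backing the iterators.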
fn drain_column_iterators<T>(iterators: &mut [(String, ColumnIterator<T>)], n_events: usize)
where
    T: FromF64,
{
    for _ in 0..n_events {
        for (_name, iterator) in iterators.iter_mut() {
            let _ = iterator.next();
        }
    }
}

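/// Fetches the event at `index`, routing the lookup through MPI when a world
/// handle is present and falling back to a local index otherwise.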
fn fetch_event_for_index(
    dataset: &Dataset,
    index: usize,
    total: usize,
    world: Option<&WorldHandle>,
) -> Event {
    // Suppress unused-variable warnings when the `mpi` feature is disabled.
    let _ = total;
    let _ = world;
    #[cfg(feature = "mpi")]
    {
        if let Some(world) = world {
            return fetch_event_mpi(dataset, index, world, total);
        }
    }

    dataset.index_local(index).clone()
}

impl Dataset {
    /// Writes the dataset to a ROOT file with branch columns of type `T`. On
    /// ranks that do not own the file (`is_root == false`), the column
    /// iterators are drained instead so the underlying event fetches still
    /// advance in step.
    #[allow(clippy::too_many_arguments)]
    fn write_root_with_type<T>(
        &self,
        dataset: Arc<Dataset>,
        world: Option<WorldHandle>,
        is_root: bool,
        file_path: &Path,
        tree_name: &str,
        branch_count: usize,
        total_events: usize,
    ) -> LadduResult<()>
    where
        T: FromF64 + oxyroot::Marshaler + 'static,
    {
        let mut iterators =
            build_root_column_iterators::<T>(dataset, world, branch_count, total_events);

        if is_root {
            let mut file = RootFile::create(file_path).map_err(|err| {
                LadduError::Custom(format!(
                    "Failed to create ROOT file '{}': {err}",
                    file_path.display()
                ))
            })?;

            let mut tree = WriterTree::new(tree_name);
            for (name, iterator) in iterators {
                tree.new_branch(name, iterator);
            }

            tree.write(&mut file).map_err(|err| {
                LadduError::Custom(format!(
                    "Failed to write ROOT tree '{tree_name}' to '{}': {err}",
                    file_path.display()
                ))
            })?;

            file.close().map_err(|err| {
                LadduError::Custom(format!(
                    "Failed to close ROOT file '{}': {err}",
                    file_path.display()
                ))
            })?;
        } else {
            drain_column_iterators(&mut iterators, total_events);
        }

        Ok(())
    }
}

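/// A list of [`Dataset`]s produced by binning a parent dataset on the value
/// of a [`Variable`], together with the bin edges used to partition it.
///
/// A minimal usage sketch (assuming a dataset and a binning variable are
/// already in scope):
///
/// ```ignore
/// let binned = dataset.bin_by(variable, 10, (0.0, 2.0)).unwrap();
/// for (i, bin) in binned.iter().enumerate() {
///     println!("bin {i}: {} events", bin.n_events());
/// }
/// ```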
pub struct BinnedDataset {
    /// The dataset falling within each bin.
    datasets: Vec<Arc<Dataset>>,
    /// The `n_bins + 1` edges delimiting the bins.
    edges: Vec<f64>,
}

impl Index<usize> for BinnedDataset {
    type Output = Arc<Dataset>;

    fn index(&self, index: usize) -> &Self::Output {
        &self.datasets[index]
    }
}

impl IndexMut<usize> for BinnedDataset {
    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
        &mut self.datasets[index]
    }
}

impl Deref for BinnedDataset {
    type Target = Vec<Arc<Dataset>>;

    fn deref(&self) -> &Self::Target {
        &self.datasets
    }
}

impl DerefMut for BinnedDataset {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.datasets
    }
}

impl BinnedDataset {
    /// Returns the number of bins.
    pub fn n_bins(&self) -> usize {
        self.datasets.len()
    }

    /// Returns the bin edges (`n_bins() + 1` values).
    pub fn edges(&self) -> Vec<f64> {
        self.edges.clone()
    }

    /// Returns the range `(lowest edge, highest edge)` spanned by the bins.
    pub fn range(&self) -> (f64, f64) {
        (self.edges[0], self.edges[self.n_bins()])
    }
}

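// The tests below exercise the event/dataset accessors as well as parquet and
// ROOT round-trips against the fixtures in `test_data/`.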
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::vectors::Vec3;
    use crate::Mass;
    use approx::{assert_relative_eq, assert_relative_ne};
    use serde::{Deserialize, Serialize};
    use std::{
        env, fs,
        path::{Path, PathBuf},
    };

    fn test_data_path(file: &str) -> PathBuf {
        Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("test_data")
            .join(file)
    }

    fn open_test_dataset(file: &str, options: DatasetReadOptions) -> Arc<Dataset> {
        let path = test_data_path(file);
        let path_str = path.to_str().expect("test data path should be valid UTF-8");
        let ext = path
            .extension()
            .and_then(|ext| ext.to_str())
            .unwrap_or_default()
            .to_ascii_lowercase();
        match ext.as_str() {
            "parquet" => read_parquet(path_str, &options),
            "root" => read_root(path_str, &options),
            other => panic!("Unsupported extension in test data: {other}"),
        }
        .expect("dataset should open")
    }

    fn make_temp_dir() -> PathBuf {
        let dir = env::temp_dir().join(format!("laddu_test_{}", fastrand::u64(..)));
        fs::create_dir(&dir).expect("temp dir should be created");
        dir
    }

    fn assert_events_close(left: &Event, right: &Event, p4_names: &[&str], aux_names: &[&str]) {
        for name in p4_names {
            let lp4 = left
                .p4(name)
                .unwrap_or_else(|| panic!("missing p4 '{name}' in left dataset"));
            let rp4 = right
                .p4(name)
                .unwrap_or_else(|| panic!("missing p4 '{name}' in right dataset"));
            assert_relative_eq!(lp4.px(), rp4.px(), epsilon = 1e-9);
            assert_relative_eq!(lp4.py(), rp4.py(), epsilon = 1e-9);
            assert_relative_eq!(lp4.pz(), rp4.pz(), epsilon = 1e-9);
            assert_relative_eq!(lp4.e(), rp4.e(), epsilon = 1e-9);
        }
        let left_aux = left.aux();
        let right_aux = right.aux();
        for name in aux_names {
            let laux = left_aux
                .get(name)
                .copied()
                .unwrap_or_else(|| panic!("missing aux '{name}' in left dataset"));
            let raux = right_aux
                .get(name)
                .copied()
                .unwrap_or_else(|| panic!("missing aux '{name}' in right dataset"));
            assert_relative_eq!(laux, raux, epsilon = 1e-9);
        }
        assert_relative_eq!(left.weight(), right.weight(), epsilon = 1e-9);
    }

    fn assert_datasets_close(
        left: &Arc<Dataset>,
        right: &Arc<Dataset>,
        p4_names: &[&str],
        aux_names: &[&str],
    ) {
        assert_eq!(left.n_events(), right.n_events());
        for idx in 0..left.n_events() {
            let levent = &left[idx];
            let revent = &right[idx];
            assert_events_close(levent, revent, p4_names, aux_names);
        }
    }

    #[test]
    fn test_from_parquet_auto_matches_explicit_names() {
        let auto = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
        let explicit_options = DatasetReadOptions::new()
            .p4_names(TEST_P4_NAMES)
            .aux_names(TEST_AUX_NAMES);
        let explicit = open_test_dataset("data_f32.parquet", explicit_options);

        let mut detected_p4: Vec<&str> = auto.p4_names().iter().map(String::as_str).collect();
        detected_p4.sort_unstable();
        let mut expected_p4 = TEST_P4_NAMES.to_vec();
        expected_p4.sort_unstable();
        assert_eq!(detected_p4, expected_p4);
        let mut detected_aux: Vec<&str> = auto.aux_names().iter().map(String::as_str).collect();
        detected_aux.sort_unstable();
        let mut expected_aux = TEST_AUX_NAMES.to_vec();
        expected_aux.sort_unstable();
        assert_eq!(detected_aux, expected_aux);
        assert_datasets_close(&auto, &explicit, TEST_P4_NAMES, TEST_AUX_NAMES);
    }

    #[test]
    fn test_from_parquet_with_aliases() {
        let dataset = open_test_dataset(
            "data_f32.parquet",
            DatasetReadOptions::new().alias("resonance", ["kshort1", "kshort2"]),
        );
        let event = dataset.named_event(0);
        let alias_vec = event.p4("resonance").expect("alias vector");
        let expected = event.get_p4_sum(["kshort1", "kshort2"]);
        assert_relative_eq!(alias_vec.px(), expected.px(), epsilon = 1e-9);
        assert_relative_eq!(alias_vec.py(), expected.py(), epsilon = 1e-9);
        assert_relative_eq!(alias_vec.pz(), expected.pz(), epsilon = 1e-9);
        assert_relative_eq!(alias_vec.e(), expected.e(), epsilon = 1e-9);
    }

    #[test]
    fn test_from_parquet_f64_matches_f32() {
        let f32_ds = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
        let f64_ds = open_test_dataset("data_f64.parquet", DatasetReadOptions::new());
        assert_datasets_close(&f64_ds, &f32_ds, TEST_P4_NAMES, TEST_AUX_NAMES);
    }

    #[test]
    fn test_from_root_detects_columns_and_matches_parquet() {
        let parquet = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
        let root_auto = open_test_dataset("data_f32.root", DatasetReadOptions::new());
        let mut detected_p4: Vec<&str> = root_auto.p4_names().iter().map(String::as_str).collect();
        detected_p4.sort_unstable();
        let mut expected_p4 = TEST_P4_NAMES.to_vec();
        expected_p4.sort_unstable();
        assert_eq!(detected_p4, expected_p4);
        let mut detected_aux: Vec<&str> =
            root_auto.aux_names().iter().map(String::as_str).collect();
        detected_aux.sort_unstable();
        let mut expected_aux = TEST_AUX_NAMES.to_vec();
        expected_aux.sort_unstable();
        assert_eq!(detected_aux, expected_aux);
        let root_named_options = DatasetReadOptions::new()
            .p4_names(TEST_P4_NAMES)
            .aux_names(TEST_AUX_NAMES);
        let root_named = open_test_dataset("data_f32.root", root_named_options);
        assert_datasets_close(&root_auto, &root_named, TEST_P4_NAMES, TEST_AUX_NAMES);
        assert_datasets_close(&root_auto, &parquet, TEST_P4_NAMES, TEST_AUX_NAMES);
    }

    #[test]
    fn test_from_root_f64_matches_parquet() {
        let parquet = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
        let root_f64 = open_test_dataset("data_f64.root", DatasetReadOptions::new());
        assert_datasets_close(&root_f64, &parquet, TEST_P4_NAMES, TEST_AUX_NAMES);
    }

    #[test]
    fn test_event_creation() {
        let event = test_event();
        assert_eq!(event.p4s.len(), 4);
        assert_eq!(event.aux.len(), 2);
        assert_relative_eq!(event.weight, 0.48);
    }

    #[test]
    fn test_event_p4_sum() {
        let event = test_event();
        let sum = event.get_p4_sum([2, 3]);
        assert_relative_eq!(sum.px(), event.p4s[2].px() + event.p4s[3].px());
        assert_relative_eq!(sum.py(), event.p4s[2].py() + event.p4s[3].py());
        assert_relative_eq!(sum.pz(), event.p4s[2].pz() + event.p4s[3].pz());
        assert_relative_eq!(sum.e(), event.p4s[2].e() + event.p4s[3].e());
    }

    #[test]
    fn test_event_boost() {
        let event = test_event();
        let event_boosted = event.boost_to_rest_frame_of([1, 2, 3]);
        let p4_sum = event_boosted.get_p4_sum([1, 2, 3]);
        assert_relative_eq!(p4_sum.px(), 0.0);
        assert_relative_eq!(p4_sum.py(), 0.0);
        // The boost accumulates floating-point error along z, so compare with
        // a looser epsilon than the exact checks above.
        assert_relative_eq!(p4_sum.pz(), 0.0, epsilon = f64::EPSILON.sqrt());
    }

    #[test]
    fn test_event_evaluate() {
        let event = test_event();
        let mut mass = Mass::new(["proton"]);
        mass.bind(
            &DatasetMetadata::new(
                TEST_P4_NAMES.iter().map(|s| (*s).to_string()).collect(),
                TEST_AUX_NAMES.iter().map(|s| (*s).to_string()).collect(),
            )
            .expect("metadata"),
        )
        .unwrap();
        assert_relative_eq!(event.evaluate(&mass), 1.007);
    }

    #[test]
    fn test_dataset_size_check() {
        let dataset = Dataset::new(Vec::new());
        assert_eq!(dataset.n_events(), 0);
        let dataset = Dataset::new(vec![Arc::new(test_event())]);
        assert_eq!(dataset.n_events(), 1);
    }

    #[test]
    fn test_dataset_sum() {
        let dataset = test_dataset();
        let metadata = dataset.metadata_arc();
        let dataset2 = Dataset::new_with_metadata(
            vec![Arc::new(EventData {
                p4s: test_event().p4s,
                aux: test_event().aux,
                weight: 0.52,
            })],
            metadata.clone(),
        );
        let dataset_sum = &dataset + &dataset2;
        assert_eq!(dataset_sum[0].weight, dataset[0].weight);
        assert_eq!(dataset_sum[1].weight, dataset2[0].weight);
    }

    #[test]
    fn test_dataset_weights() {
        let dataset = Dataset::new(vec![
            Arc::new(test_event()),
            Arc::new(EventData {
                p4s: test_event().p4s,
                aux: test_event().aux,
                weight: 0.52,
            }),
        ]);
        let weights = dataset.weights();
        assert_eq!(weights.len(), 2);
        assert_relative_eq!(weights[0], 0.48);
        assert_relative_eq!(weights[1], 0.52);
        assert_relative_eq!(dataset.n_events_weighted(), 1.0);
    }

    #[test]
    fn test_dataset_filtering() {
        let metadata = Arc::new(
            DatasetMetadata::new(vec!["beam"], Vec::<String>::new())
                .expect("metadata should be valid"),
        );
        let events = vec![
            Arc::new(EventData {
                p4s: vec![Vec3::new(0.0, 0.0, 5.0).with_mass(0.0)],
                aux: vec![],
                weight: 1.0,
            }),
            Arc::new(EventData {
                p4s: vec![Vec3::new(0.0, 0.0, 5.0).with_mass(0.5)],
                aux: vec![],
                weight: 1.0,
            }),
            Arc::new(EventData {
                p4s: vec![Vec3::new(0.0, 0.0, 5.0).with_mass(1.1)],
                aux: vec![],
                weight: 1.0,
            }),
        ];
        let dataset = Dataset::new_with_metadata(events, metadata);

        let metadata = dataset.metadata_arc();
        let mut mass = Mass::new(["beam"]);
        mass.bind(metadata.as_ref()).unwrap();
        let expression = mass.gt(0.0).and(&mass.lt(1.0));

        let filtered = dataset.filter(&expression).unwrap();
        assert_eq!(filtered.n_events(), 1);
        assert_relative_eq!(mass.value(&filtered[0]), 0.5);
    }

    #[test]
    fn test_dataset_boost() {
        let dataset = test_dataset();
        let dataset_boosted = dataset.boost_to_rest_frame_of(&["proton", "kshort1", "kshort2"]);
        let p4_sum = dataset_boosted[0].get_p4_sum(["proton", "kshort1", "kshort2"]);
        assert_relative_eq!(p4_sum.px(), 0.0);
        assert_relative_eq!(p4_sum.py(), 0.0);
        assert_relative_eq!(p4_sum.pz(), 0.0, epsilon = f64::EPSILON.sqrt());
    }

    #[test]
    fn test_named_event_view() {
        let dataset = test_dataset();
        let view = dataset.named_event(0);

        assert_relative_eq!(view.weight(), dataset[0].weight);
        let beam = view.p4("beam").expect("beam p4");
        assert_relative_eq!(beam.px(), dataset[0].p4s[0].px());
        assert_relative_eq!(beam.e(), dataset[0].p4s[0].e());

        let summed = view.get_p4_sum(["kshort1", "kshort2"]);
        assert_relative_eq!(summed.e(), dataset[0].p4s[2].e() + dataset[0].p4s[3].e());

        let aux_angle = view.aux().get("pol_angle").copied().expect("pol angle");
        assert_relative_eq!(aux_angle, dataset[0].aux[1]);

        let metadata = dataset.metadata_arc();
        let boosted = view.boost_to_rest_frame_of(["proton", "kshort1", "kshort2"]);
        let boosted_event = Event::new(Arc::new(boosted), metadata);
        let boosted_sum = boosted_event.get_p4_sum(["proton", "kshort1", "kshort2"]);
        assert_relative_eq!(boosted_sum.px(), 0.0);
    }

    #[test]
    fn test_dataset_evaluate() {
        let dataset = test_dataset();
        let mass = Mass::new(["proton"]);
        assert_relative_eq!(dataset.evaluate(&mass).unwrap()[0], 1.007);
    }

    #[test]
    fn test_dataset_metadata_rejects_duplicate_names() {
        let err = DatasetMetadata::new(vec!["beam", "beam"], Vec::<String>::new());
        assert!(matches!(
            err,
            Err(LadduError::DuplicateName { category, .. }) if category == "p4"
        ));
        let err = DatasetMetadata::new(
            vec!["beam"],
            vec!["pol_angle".to_string(), "pol_angle".to_string()],
        );
        assert!(matches!(
            err,
            Err(LadduError::DuplicateName { category, .. }) if category == "aux"
        ));
    }

    #[test]
    fn test_dataset_lookup_by_name() {
        let dataset = test_dataset();
        let proton = dataset.p4_by_name(0, "proton").expect("proton p4");
        let proton_idx = dataset.metadata().p4_index("proton").unwrap();
        assert_relative_eq!(proton.e(), dataset[0].p4s[proton_idx].e());
        assert!(dataset.p4_by_name(0, "unknown").is_none());
        let angle = dataset.aux_by_name(0, "pol_angle").expect("pol_angle");
        assert_relative_eq!(angle, dataset[0].aux[1]);
        assert!(dataset.aux_by_name(0, "missing").is_none());
    }

    #[test]
    fn test_binned_dataset() {
        let dataset = Dataset::new(vec![
            Arc::new(EventData {
                p4s: vec![Vec3::new(0.0, 0.0, 1.0).with_mass(1.0)],
                aux: vec![],
                weight: 1.0,
            }),
            Arc::new(EventData {
                p4s: vec![Vec3::new(0.0, 0.0, 2.0).with_mass(2.0)],
                aux: vec![],
                weight: 2.0,
            }),
        ]);

        #[derive(Clone, Serialize, Deserialize, Debug)]
        struct BeamEnergy;
        impl Display for BeamEnergy {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, "BeamEnergy")
            }
        }
        #[typetag::serde]
        impl Variable for BeamEnergy {
            fn value(&self, event: &EventData) -> f64 {
                event.p4s[0].e()
            }
        }
        assert_eq!(BeamEnergy.to_string(), "BeamEnergy");

        // With edges at (0.0, 1.5, 3.0), the first event (E = √2 ≈ 1.41) falls
        // in the lower bin and the second (E = √8 ≈ 2.83) in the upper bin.
        let binned = dataset.bin_by(BeamEnergy, 2, (0.0, 3.0)).unwrap();

        assert_eq!(binned.n_bins(), 2);
        assert_eq!(binned.edges().len(), 3);
        assert_relative_eq!(binned.edges()[0], 0.0);
        assert_relative_eq!(binned.edges()[2], 3.0);
        assert_eq!(binned[0].n_events(), 1);
        assert_relative_eq!(binned[0].n_events_weighted(), 1.0);
        assert_eq!(binned[1].n_events(), 1);
        assert_relative_eq!(binned[1].n_events_weighted(), 2.0);
    }

    #[test]
    fn test_dataset_bootstrap() {
        let metadata = test_dataset().metadata_arc();
        let dataset = Dataset::new_with_metadata(
            vec![
                Arc::new(test_event()),
                Arc::new(EventData {
                    p4s: test_event().p4s.clone(),
                    aux: test_event().aux.clone(),
                    weight: 1.0,
                }),
            ],
            metadata,
        );
        assert_relative_ne!(dataset[0].weight, dataset[1].weight);

        // Resampling with this seed draws the same event twice, so the
        // bootstrapped weights match even though the originals differ.
        let bootstrapped = dataset.bootstrap(43);
        assert_eq!(bootstrapped.n_events(), dataset.n_events());
        assert_relative_eq!(bootstrapped[0].weight, bootstrapped[1].weight);

        // Bootstrapping an empty dataset should remain a no-op.
        let empty_dataset = Dataset::new(Vec::new());
        let empty_bootstrap = empty_dataset.bootstrap(43);
        assert_eq!(empty_bootstrap.n_events(), 0);
    }

    #[test]
    fn test_dataset_iteration_returns_events() {
        let dataset = test_dataset();
        let mut weights = Vec::new();
        for event in dataset.iter() {
            weights.push(event.weight());
        }
        assert_eq!(weights.len(), dataset.n_events());
        assert_relative_eq!(weights[0], dataset[0].weight);
    }

    #[test]
    fn test_dataset_into_iter_returns_events() {
        let dataset = test_dataset();
        let weights: Vec<f64> = dataset.into_iter().map(|event| event.weight()).collect();
        assert_eq!(weights.len(), 1);
        assert_relative_eq!(weights[0], test_event().weight);
    }

    #[test]
    fn test_event_display() {
        let event = test_event();
        let display_string = format!("{}", event);
        assert!(display_string.contains("Event:"));
        assert!(display_string.contains("p4s:"));
        assert!(display_string.contains("aux:"));
        assert!(display_string.contains("aux[0]: 0.38562805"));
        assert!(display_string.contains("aux[1]: 0.05708078"));
        assert!(display_string.contains("weight:"));
    }

    #[test]
    fn test_name_based_access() {
        let metadata =
            Arc::new(DatasetMetadata::new(vec!["beam", "target"], vec!["pol_angle"]).unwrap());
        let event = Arc::new(EventData {
            p4s: vec![Vec4::new(0.0, 0.0, 1.0, 1.0), Vec4::new(0.1, 0.2, 0.3, 0.5)],
            aux: vec![0.42],
            weight: 1.0,
        });
        let dataset = Dataset::new_with_metadata(vec![event], metadata);
        let beam = dataset.p4_by_name(0, "beam").unwrap();
        assert_relative_eq!(beam.px(), 0.0);
        assert_relative_eq!(beam.py(), 0.0);
        assert_relative_eq!(beam.pz(), 1.0);
        assert_relative_eq!(beam.e(), 1.0);
        assert_relative_eq!(dataset.aux_by_name(0, "pol_angle").unwrap(), 0.42);
        assert!(dataset.p4_by_name(0, "missing").is_none());
        assert!(dataset.aux_by_name(0, "missing").is_none());
    }

    #[test]
    fn test_parquet_roundtrip_to_tempfile() {
        let dataset = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
        let dir = make_temp_dir();
        let path = dir.join("roundtrip.parquet");
        let path_str = path.to_str().expect("path should be valid UTF-8");

        write_parquet(&dataset, path_str, &DatasetWriteOptions::default())
            .expect("writing parquet should succeed");
        let reopened = read_parquet(path_str, &DatasetReadOptions::new())
            .expect("parquet roundtrip should reopen");

        assert_datasets_close(&dataset, &reopened, TEST_P4_NAMES, TEST_AUX_NAMES);
        fs::remove_dir_all(&dir).expect("temp dir cleanup should succeed");
    }

    #[test]
    fn test_root_roundtrip_to_tempfile() {
        let dataset = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
        let dir = make_temp_dir();
        let path = dir.join("roundtrip.root");
        let path_str = path.to_str().expect("path should be valid UTF-8");

        write_root(&dataset, path_str, &DatasetWriteOptions::default())
            .expect("writing root should succeed");
        let reopened =
            read_root(path_str, &DatasetReadOptions::new()).expect("root roundtrip should reopen");

        assert_datasets_close(&dataset, &reopened, TEST_P4_NAMES, TEST_AUX_NAMES);
        fs::remove_dir_all(&dir).expect("temp dir cleanup should succeed");
    }
}