use accurate::{sum::Klein, traits::*};
use arrow::{
    array::{Float32Array, Float64Array},
    datatypes::{DataType, Field, Schema},
    record_batch::RecordBatch,
};
use auto_ops::impl_op_ex;
use parking_lot::Mutex;
use parquet::arrow::{arrow_reader::ParquetRecordBatchReaderBuilder, ArrowWriter};
use serde::{Deserialize, Serialize};
use std::ops::{Deref, DerefMut, Index, IndexMut};
use std::path::Path;
use std::{fmt::Display, fs::File};
use std::{path::PathBuf, sync::Arc};

use oxyroot::{Branch, Named, ReaderTree, RootFile, WriterTree};

#[cfg(feature = "rayon")]
use rayon::prelude::*;

#[cfg(feature = "mpi")]
use mpi::{datatype::PartitionMut, topology::SimpleCommunicator, traits::*};

#[cfg(feature = "mpi")]
use crate::mpi::LadduMPI;

#[cfg(feature = "mpi")]
type WorldHandle = SimpleCommunicator;
#[cfg(not(feature = "mpi"))]
type WorldHandle = ();

use crate::utils::get_bin_edges;
use crate::{
    utils::{
        variables::{IntoP4Selection, P4Selection, Variable, VariableExpression},
        vectors::Vec4,
    },
    LadduError, LadduResult,
};
use indexmap::{IndexMap, IndexSet};

pub fn test_event() -> EventData {
    use crate::utils::vectors::*;
    let pol_magnitude = 0.38562805;
    let pol_angle = 0.05708078;
    EventData {
        p4s: vec![
            Vec3::new(0.0, 0.0, 8.747).with_mass(0.0),
            Vec3::new(0.119, 0.374, 0.222).with_mass(1.007),
            Vec3::new(-0.112, 0.293, 3.081).with_mass(0.498),
            Vec3::new(-0.007, -0.667, 5.446).with_mass(0.498),
        ],
        aux: vec![pol_magnitude, pol_angle],
        weight: 0.48,
    }
}

pub const TEST_P4_NAMES: &[&str] = &["beam", "proton", "kshort1", "kshort2"];
pub const TEST_AUX_NAMES: &[&str] = &["pol_magnitude", "pol_angle"];

pub fn test_dataset() -> Dataset {
    let metadata = Arc::new(
        DatasetMetadata::new(
            TEST_P4_NAMES.iter().map(|s| (*s).to_string()).collect(),
            TEST_AUX_NAMES.iter().map(|s| (*s).to_string()).collect(),
        )
        .expect("Test metadata should be valid"),
    );
    Dataset::new_with_metadata(vec![Arc::new(test_event())], metadata)
}

#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct EventData {
    pub p4s: Vec<Vec4>,
    pub aux: Vec<f64>,
    pub weight: f64,
}

impl Display for EventData {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Event:")?;
        writeln!(f, " p4s:")?;
        for p4 in &self.p4s {
            writeln!(f, " {}", p4.to_p4_string())?;
        }
        writeln!(f, " aux:")?;
        for (idx, value) in self.aux.iter().enumerate() {
            writeln!(f, " aux[{idx}]: {value}")?;
        }
        writeln!(f, " weight:")?;
        writeln!(f, " {}", self.weight)?;
        Ok(())
    }
}

impl EventData {
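    /// Returns the sum of the four-momenta at the given `indices`.
    ///
    /// A minimal sketch of typical usage (indices follow the ordering of
    /// [`test_event`], where 2 and 3 are the two K-shorts):
    ///
    /// ```ignore
    /// let event = test_event();
    /// let resonance = event.get_p4_sum([2, 3]);
    /// ```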
    pub fn get_p4_sum<T: AsRef<[usize]>>(&self, indices: T) -> Vec4 {
        indices.as_ref().iter().map(|i| self.p4s[*i]).sum::<Vec4>()
    }

    pub fn boost_to_rest_frame_of<T: AsRef<[usize]>>(&self, indices: T) -> Self {
        let frame = self.get_p4_sum(indices);
        EventData {
            p4s: self
                .p4s
                .iter()
                .map(|p4| p4.boost(&(-frame.beta())))
                .collect(),
            aux: self.aux.clone(),
            weight: self.weight,
        }
    }

    pub fn evaluate<V: Variable>(&self, variable: &V) -> f64 {
        variable.value(self)
    }
}

#[derive(Debug, Clone)]
pub struct DatasetMetadata {
    pub(crate) p4_names: Vec<String>,
    pub(crate) aux_names: Vec<String>,
    pub(crate) p4_lookup: IndexMap<String, usize>,
    pub(crate) aux_lookup: IndexMap<String, usize>,
    pub(crate) p4_selections: IndexMap<String, P4Selection>,
}

impl DatasetMetadata {
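    /// Builds metadata from ordered four-momentum and auxiliary column names,
    /// rejecting duplicate names within either category.
    ///
    /// A minimal sketch (the names are illustrative):
    ///
    /// ```ignore
    /// let metadata = DatasetMetadata::new(vec!["beam", "proton"], vec!["pol_angle"])?;
    /// assert_eq!(metadata.p4_index("proton"), Some(1));
    /// ```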
    pub fn new<P: Into<String>, A: Into<String>>(
        p4_names: Vec<P>,
        aux_names: Vec<A>,
    ) -> LadduResult<Self> {
        let mut p4_lookup = IndexMap::with_capacity(p4_names.len());
        let mut aux_lookup = IndexMap::with_capacity(aux_names.len());
        let mut p4_selections = IndexMap::with_capacity(p4_names.len());
        let p4_names: Vec<String> = p4_names
            .into_iter()
            .enumerate()
            .map(|(idx, name)| {
                let name = name.into();
                if p4_lookup.contains_key(&name) {
                    return Err(LadduError::DuplicateName {
                        category: "p4",
                        name,
                    });
                }
                p4_lookup.insert(name.clone(), idx);
                p4_selections.insert(
                    name.clone(),
                    P4Selection::with_indices(vec![name.clone()], vec![idx]),
                );
                Ok(name)
            })
            .collect::<Result<_, _>>()?;
        let aux_names: Vec<String> = aux_names
            .into_iter()
            .enumerate()
            .map(|(idx, name)| {
                let name = name.into();
                if aux_lookup.contains_key(&name) {
                    return Err(LadduError::DuplicateName {
                        category: "aux",
                        name,
                    });
                }
                aux_lookup.insert(name.clone(), idx);
                Ok(name)
            })
            .collect::<Result<_, _>>()?;
        Ok(Self {
            p4_names,
            aux_names,
            p4_lookup,
            aux_lookup,
            p4_selections,
        })
    }

    pub fn empty() -> Self {
        Self {
            p4_names: Vec::new(),
            aux_names: Vec::new(),
            p4_lookup: IndexMap::new(),
            aux_lookup: IndexMap::new(),
            p4_selections: IndexMap::new(),
        }
    }

    pub fn p4_index(&self, name: &str) -> Option<usize> {
        self.p4_lookup.get(name).copied()
    }

    pub fn p4_names(&self) -> &[String] {
        &self.p4_names
    }

    pub fn aux_index(&self, name: &str) -> Option<usize> {
        self.aux_lookup.get(name).copied()
    }

    pub fn aux_names(&self) -> &[String] {
        &self.aux_names
    }

    pub fn p4_selection(&self, name: &str) -> Option<&P4Selection> {
        self.p4_selections.get(name)
    }
    pub fn add_p4_alias<N>(&mut self, alias: N, mut selection: P4Selection) -> LadduResult<()>
    where
        N: Into<String>,
    {
        let alias = alias.into();
        if self.p4_selections.contains_key(&alias) {
            return Err(LadduError::DuplicateName {
                category: "alias",
                name: alias,
            });
        }
        selection.bind(self)?;
        self.p4_selections.insert(alias, selection);
        Ok(())
    }

    pub fn add_p4_aliases<I, N>(&mut self, entries: I) -> LadduResult<()>
    where
        I: IntoIterator<Item = (N, P4Selection)>,
        N: Into<String>,
    {
        for (alias, selection) in entries {
            self.add_p4_alias(alias, selection)?;
        }
        Ok(())
    }

    pub(crate) fn append_indices_for_name(
        &self,
        name: &str,
        target: &mut Vec<usize>,
    ) -> LadduResult<()> {
        if let Some(selection) = self.p4_selections.get(name) {
            target.extend_from_slice(selection.indices());
            return Ok(());
        }
        Err(LadduError::UnknownName {
            category: "p4",
            name: name.to_string(),
        })
    }
}

impl Default for DatasetMetadata {
    fn default() -> Self {
        Self::empty()
    }
}

#[derive(Debug, Clone)]
pub struct Dataset {
    pub events: Vec<Event>,
    pub(crate) metadata: Arc<DatasetMetadata>,
}

#[derive(Clone, Debug)]
pub struct Event {
    event: Arc<EventData>,
    metadata: Arc<DatasetMetadata>,
}

impl Event {
    pub fn new(event: Arc<EventData>, metadata: Arc<DatasetMetadata>) -> Self {
        Self { event, metadata }
    }

    pub fn data(&self) -> &EventData {
        &self.event
    }

    pub fn data_arc(&self) -> Arc<EventData> {
        self.event.clone()
    }

    pub fn p4s(&self) -> IndexMap<&str, Vec4> {
        let mut map = IndexMap::with_capacity(self.metadata.p4_names.len());
        for (idx, name) in self.metadata.p4_names.iter().enumerate() {
            if let Some(p4) = self.event.p4s.get(idx) {
                map.insert(name.as_str(), *p4);
            }
        }
        map
    }

    pub fn aux(&self) -> IndexMap<&str, f64> {
        let mut map = IndexMap::with_capacity(self.metadata.aux_names.len());
        for (idx, name) in self.metadata.aux_names.iter().enumerate() {
            if let Some(value) = self.event.aux.get(idx) {
                map.insert(name.as_str(), *value);
            }
        }
        map
    }

    pub fn weight(&self) -> f64 {
        self.event.weight
    }

    pub fn metadata(&self) -> &DatasetMetadata {
        &self.metadata
    }

    pub fn metadata_arc(&self) -> Arc<DatasetMetadata> {
        self.metadata.clone()
    }

    pub fn p4(&self, name: &str) -> Option<Vec4> {
        self.metadata
            .p4_selection(name)
            .map(|selection| selection.momentum(&self.event))
    }

    fn resolve_p4_indices<N>(&self, names: N) -> Vec<usize>
    where
        N: IntoIterator,
        N::Item: AsRef<str>,
    {
        let mut indices = Vec::new();
        for name in names {
            let name_ref = name.as_ref();
            if let Some(selection) = self.metadata.p4_selection(name_ref) {
                indices.extend_from_slice(selection.indices());
            } else {
                panic!("Unknown particle name '{name}'", name = name_ref);
            }
        }
        indices
    }

    pub fn get_p4_sum<N>(&self, names: N) -> Vec4
    where
        N: IntoIterator,
        N::Item: AsRef<str>,
    {
        let indices = self.resolve_p4_indices(names);
        self.event.get_p4_sum(&indices)
    }

    pub fn boost_to_rest_frame_of<N>(&self, names: N) -> EventData
    where
        N: IntoIterator,
        N::Item: AsRef<str>,
    {
        let indices = self.resolve_p4_indices(names);
        self.event.boost_to_rest_frame_of(&indices)
    }

    pub fn evaluate<V: Variable>(&self, variable: &V) -> f64 {
        self.event.evaluate(variable)
    }
}

impl Deref for Event {
    type Target = EventData;

    fn deref(&self) -> &Self::Target {
        &self.event
    }
}

impl AsRef<EventData> for Event {
    fn as_ref(&self) -> &EventData {
        self.data()
    }
}

impl IntoIterator for Dataset {
    type Item = Event;

    type IntoIter = DatasetIntoIter;

    fn into_iter(self) -> Self::IntoIter {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                let total = self.n_events();
                return DatasetIntoIter::Mpi(DatasetMpiIntoIter {
                    events: self.events,
                    metadata: self.metadata,
                    world,
                    index: 0,
                    total,
                });
            }
        }
        DatasetIntoIter::Local(self.events.into_iter())
    }
}

impl Dataset {
    pub fn iter(&self) -> DatasetIter<'_> {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return DatasetIter::Mpi(DatasetMpiIter {
                    dataset: self,
                    world,
                    index: 0,
                    total: self.n_events(),
                });
            }
        }
        DatasetIter::Local(self.events.iter())
    }

    pub fn metadata(&self) -> &DatasetMetadata {
        &self.metadata
    }

    pub fn metadata_arc(&self) -> Arc<DatasetMetadata> {
        self.metadata.clone()
    }

    pub fn p4_names(&self) -> &[String] {
        &self.metadata.p4_names
    }

    pub fn aux_names(&self) -> &[String] {
        &self.metadata.aux_names
    }

    pub fn p4_index(&self, name: &str) -> Option<usize> {
        self.metadata.p4_index(name)
    }

    pub fn aux_index(&self, name: &str) -> Option<usize> {
        self.metadata.aux_index(name)
    }

    pub fn named_event(&self, index: usize) -> Event {
        self.events[index].clone()
    }

    pub fn p4_by_name(&self, event_index: usize, name: &str) -> Option<Vec4> {
        self.events
            .get(event_index)
            .and_then(|event| event.p4(name))
    }

    pub fn aux_by_name(&self, event_index: usize, name: &str) -> Option<f64> {
        let idx = self.aux_index(name)?;
        self.events
            .get(event_index)
            .and_then(|event| event.aux.get(idx))
            .copied()
    }

    pub fn index_local(&self, index: usize) -> &Event {
        &self.events[index]
    }

    #[cfg(feature = "mpi")]
    fn partition(
        events: Vec<Arc<EventData>>,
        world: &SimpleCommunicator,
    ) -> Vec<Vec<Arc<EventData>>> {
        let (counts, displs) = world.get_counts_displs(events.len());
        counts
            .iter()
            .zip(displs.iter())
            .map(|(&count, &displ)| {
                events
                    .iter()
                    .skip(displ as usize)
                    .take(count as usize)
                    .cloned()
                    .collect()
            })
            .collect()
    }

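    /// Fetches the event at a global `index` from whichever rank owns it.
    ///
    /// Note that the returned reference is produced by leaking the fetched
    /// event via `Box::leak` so that [`Index`] can hand out `&Event`; prefer
    /// [`Dataset::iter`] for bulk access.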
    #[cfg(feature = "mpi")]
    pub fn index_mpi(&self, index: usize, world: &SimpleCommunicator) -> &Event {
        let total = self.n_events();
        let event = fetch_event_mpi(self, index, world, total);
        Box::leak(Box::new(event))
    }
}

impl Index<usize> for Dataset {
    type Output = Event;

    fn index(&self, index: usize) -> &Self::Output {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.index_mpi(index, &world);
            }
        }
        self.index_local(index)
    }
}

pub enum DatasetIter<'a> {
    Local(std::slice::Iter<'a, Event>),
    #[cfg(feature = "mpi")]
    Mpi(DatasetMpiIter<'a>),
}

impl<'a> Iterator for DatasetIter<'a> {
    type Item = Event;

    fn next(&mut self) -> Option<Self::Item> {
        match self {
            DatasetIter::Local(iter) => iter.next().cloned(),
            #[cfg(feature = "mpi")]
            DatasetIter::Mpi(iter) => iter.next(),
        }
    }
}

pub enum DatasetIntoIter {
    Local(std::vec::IntoIter<Event>),
    #[cfg(feature = "mpi")]
    Mpi(DatasetMpiIntoIter),
}

impl Iterator for DatasetIntoIter {
    type Item = Event;

    fn next(&mut self) -> Option<Self::Item> {
        match self {
            DatasetIntoIter::Local(iter) => iter.next(),
            #[cfg(feature = "mpi")]
            DatasetIntoIter::Mpi(iter) => iter.next(),
        }
    }
}

#[cfg(feature = "mpi")]
pub struct DatasetMpiIter<'a> {
    dataset: &'a Dataset,
    world: SimpleCommunicator,
    index: usize,
    total: usize,
}

#[cfg(feature = "mpi")]
impl<'a> Iterator for DatasetMpiIter<'a> {
    type Item = Event;

    fn next(&mut self) -> Option<Self::Item> {
        if self.index >= self.total {
            return None;
        }
        let event = fetch_event_mpi(self.dataset, self.index, &self.world, self.total);
        self.index += 1;
        Some(event)
    }
}

#[cfg(feature = "mpi")]
pub struct DatasetMpiIntoIter {
    events: Vec<Event>,
    metadata: Arc<DatasetMetadata>,
    world: SimpleCommunicator,
    index: usize,
    total: usize,
}

#[cfg(feature = "mpi")]
impl Iterator for DatasetMpiIntoIter {
    type Item = Event;

    fn next(&mut self) -> Option<Self::Item> {
        if self.index >= self.total {
            return None;
        }
        let event = fetch_event_mpi_from_events(
            &self.events,
            &self.metadata,
            self.index,
            &self.world,
            self.total,
        );
        self.index += 1;
        Some(event)
    }
}

#[cfg(feature = "mpi")]
fn fetch_event_mpi(
    dataset: &Dataset,
    global_index: usize,
    world: &SimpleCommunicator,
    total: usize,
) -> Event {
    fetch_event_mpi_generic(
        global_index,
        total,
        world,
        &dataset.metadata,
        |local_index| dataset.index_local(local_index),
    )
}

#[cfg(feature = "mpi")]
fn fetch_event_mpi_from_events(
    events: &[Event],
    metadata: &Arc<DatasetMetadata>,
    global_index: usize,
    world: &SimpleCommunicator,
    total: usize,
) -> Event {
    fetch_event_mpi_generic(global_index, total, world, metadata, |local_index| {
        &events[local_index]
    })
}

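// Broadcast protocol used below: the owning rank serializes the requested
// event with bincode, broadcasts the buffer length so other ranks can size
// their receive buffers, then broadcasts the payload itself; non-owning ranks
// decode it back into an `EventData`.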
#[cfg(feature = "mpi")]
fn fetch_event_mpi_generic<'a, F>(
    global_index: usize,
    total: usize,
    world: &SimpleCommunicator,
    metadata: &Arc<DatasetMetadata>,
    local_event: F,
) -> Event
where
    F: Fn(usize) -> &'a Event,
{
    let (owning_rank, local_index) = world.owner_of_global_index(global_index, total);
    let mut serialized_event_buffer_len: usize = 0;
    let mut serialized_event_buffer: Vec<u8> = Vec::default();
    let config = bincode::config::standard();
    if world.rank() == owning_rank {
        let event = local_event(local_index);
        serialized_event_buffer = bincode::serde::encode_to_vec(event.data(), config).unwrap();
        serialized_event_buffer_len = serialized_event_buffer.len();
    }
    world
        .process_at_rank(owning_rank)
        .broadcast_into(&mut serialized_event_buffer_len);
    if world.rank() != owning_rank {
        serialized_event_buffer = vec![0; serialized_event_buffer_len];
    }
    world
        .process_at_rank(owning_rank)
        .broadcast_into(&mut serialized_event_buffer);

    if world.rank() == owning_rank {
        local_event(local_index).clone()
    } else {
        let (event, _): (EventData, usize) =
            bincode::serde::decode_from_slice(&serialized_event_buffer[..], config).unwrap();
        Event::new(Arc::new(event), metadata.clone())
    }
}

impl Dataset {
    pub fn new_local(events: Vec<Arc<EventData>>, metadata: Arc<DatasetMetadata>) -> Self {
        let wrapped_events = events
            .into_iter()
            .map(|event| Event::new(event, metadata.clone()))
            .collect();
        Dataset {
            events: wrapped_events,
            metadata,
        }
    }

    #[cfg(feature = "mpi")]
    pub fn new_mpi(
        events: Vec<Arc<EventData>>,
        metadata: Arc<DatasetMetadata>,
        world: &SimpleCommunicator,
    ) -> Self {
        let partitions = Dataset::partition(events, world);
        let local = partitions[world.rank() as usize]
            .iter()
            .cloned()
            .map(|event| Event::new(event, metadata.clone()))
            .collect();
        Dataset {
            events: local,
            metadata,
        }
    }

    pub fn new(events: Vec<Arc<EventData>>) -> Self {
        Dataset::new_with_metadata(events, Arc::new(DatasetMetadata::default()))
    }

    pub fn new_with_metadata(events: Vec<Arc<EventData>>, metadata: Arc<DatasetMetadata>) -> Self {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return Dataset::new_mpi(events, metadata, &world);
            }
        }
        Dataset::new_local(events, metadata)
    }

    pub fn n_events_local(&self) -> usize {
        self.events.len()
    }

    #[cfg(feature = "mpi")]
    pub fn n_events_mpi(&self, world: &SimpleCommunicator) -> usize {
        let mut n_events_partitioned: Vec<usize> = vec![0; world.size() as usize];
        let n_events_local = self.n_events_local();
        world.all_gather_into(&n_events_local, &mut n_events_partitioned);
        n_events_partitioned.iter().sum()
    }

    pub fn n_events(&self) -> usize {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.n_events_mpi(&world);
            }
        }
        self.n_events_local()
    }
}

impl Dataset {
    pub fn weights_local(&self) -> Vec<f64> {
        #[cfg(feature = "rayon")]
        return self.events.par_iter().map(|e| e.weight).collect();
        #[cfg(not(feature = "rayon"))]
        return self.events.iter().map(|e| e.weight).collect();
    }

    #[cfg(feature = "mpi")]
    pub fn weights_mpi(&self, world: &SimpleCommunicator) -> Vec<f64> {
        let local_weights = self.weights_local();
        let n_events = self.n_events();
        let mut buffer: Vec<f64> = vec![0.0; n_events];
        let (counts, displs) = world.get_counts_displs(n_events);
        {
            let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
            world.all_gather_varcount_into(&local_weights, &mut partitioned_buffer);
        }
        buffer
    }

    pub fn weights(&self) -> Vec<f64> {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.weights_mpi(&world);
            }
        }
        self.weights_local()
    }

    pub fn n_events_weighted_local(&self) -> f64 {
        #[cfg(feature = "rayon")]
        return self
            .events
            .par_iter()
            .map(|e| e.weight)
            .parallel_sum_with_accumulator::<Klein<f64>>();
        #[cfg(not(feature = "rayon"))]
        return self.events.iter().map(|e| e.weight).sum();
    }

    #[cfg(feature = "mpi")]
    pub fn n_events_weighted_mpi(&self, world: &SimpleCommunicator) -> f64 {
        let mut n_events_weighted_partitioned: Vec<f64> = vec![0.0; world.size() as usize];
        let n_events_weighted_local = self.n_events_weighted_local();
        world.all_gather_into(&n_events_weighted_local, &mut n_events_weighted_partitioned);
        #[cfg(feature = "rayon")]
        return n_events_weighted_partitioned
            .into_par_iter()
            .parallel_sum_with_accumulator::<Klein<f64>>();
        #[cfg(not(feature = "rayon"))]
        return n_events_weighted_partitioned.iter().sum();
    }

    pub fn n_events_weighted(&self) -> f64 {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.n_events_weighted_mpi(&world);
            }
        }
        self.n_events_weighted_local()
    }

    pub fn bootstrap_local(&self, seed: usize) -> Arc<Dataset> {
        let mut rng = fastrand::Rng::with_seed(seed as u64);
        let mut indices: Vec<usize> = (0..self.n_events())
            .map(|_| rng.usize(0..self.n_events()))
            .collect::<Vec<usize>>();
        indices.sort();
        #[cfg(feature = "rayon")]
        let bootstrapped_events: Vec<Arc<EventData>> = indices
            .into_par_iter()
            .map(|idx| self.events[idx].data_arc())
            .collect();
        #[cfg(not(feature = "rayon"))]
        let bootstrapped_events: Vec<Arc<EventData>> = indices
            .into_iter()
            .map(|idx| self.events[idx].data_arc())
            .collect();
        Arc::new(Dataset::new_with_metadata(
            bootstrapped_events,
            self.metadata.clone(),
        ))
    }

    #[cfg(feature = "mpi")]
    pub fn bootstrap_mpi(&self, seed: usize, world: &SimpleCommunicator) -> Arc<Dataset> {
        let n_events = self.n_events();
        let mut indices: Vec<usize> = vec![0; n_events];
        if world.is_root() {
            let mut rng = fastrand::Rng::with_seed(seed as u64);
            indices = (0..n_events)
                .map(|_| rng.usize(0..n_events))
                .collect::<Vec<usize>>();
            indices.sort();
        }
        world.process_at_root().broadcast_into(&mut indices);
        let local_indices: Vec<usize> = indices
            .into_iter()
            .filter_map(|idx| {
                let (owning_rank, local_index) = world.owner_of_global_index(idx, n_events);
                if world.rank() == owning_rank {
                    Some(local_index)
                } else {
                    None
                }
            })
            .collect();
        #[cfg(feature = "rayon")]
        let bootstrapped_events: Vec<Arc<EventData>> = local_indices
            .into_par_iter()
            .map(|idx| self.events[idx].data_arc())
            .collect();
        #[cfg(not(feature = "rayon"))]
        let bootstrapped_events: Vec<Arc<EventData>> = local_indices
            .into_iter()
            .map(|idx| self.events[idx].data_arc())
            .collect();
        Arc::new(Dataset::new_with_metadata(
            bootstrapped_events,
            self.metadata.clone(),
        ))
    }

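    /// Resamples the dataset with replacement using `seed`, returning a new
    /// dataset that shares this one's metadata. The drawn indices are sorted
    /// before lookup.
    ///
    /// A minimal sketch:
    ///
    /// ```ignore
    /// let resampled = dataset.bootstrap(0);
    /// assert_eq!(resampled.n_events(), dataset.n_events());
    /// ```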
    pub fn bootstrap(&self, seed: usize) -> Arc<Dataset> {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.bootstrap_mpi(seed, &world);
            }
        }
        self.bootstrap_local(seed)
    }

    pub fn filter(&self, expression: &VariableExpression) -> LadduResult<Arc<Dataset>> {
        let compiled = expression.compile(&self.metadata)?;
        #[cfg(feature = "rayon")]
        let filtered_events: Vec<Arc<EventData>> = self
            .events
            .par_iter()
            .filter(|event| compiled.evaluate(event.as_ref()))
            .map(|event| event.data_arc())
            .collect();
        #[cfg(not(feature = "rayon"))]
        let filtered_events: Vec<Arc<EventData>> = self
            .events
            .iter()
            .filter(|event| compiled.evaluate(event.as_ref()))
            .map(|event| event.data_arc())
            .collect();
        Ok(Arc::new(Dataset::new_with_metadata(
            filtered_events,
            self.metadata.clone(),
        )))
    }

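    /// Splits events into `bins` equal-width bins of a [`Variable`] evaluated
    /// over `range`; events falling outside the range are dropped.
    ///
    /// A hedged sketch (the `Mass` constructor shown is illustrative):
    ///
    /// ```ignore
    /// let mass = Mass::new(["kshort1", "kshort2"]);
    /// let binned = dataset.bin_by(mass, 50, (1.0, 2.0))?;
    /// assert_eq!(binned.n_bins(), 50);
    /// ```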
    pub fn bin_by<V>(
        &self,
        mut variable: V,
        bins: usize,
        range: (f64, f64),
    ) -> LadduResult<BinnedDataset>
    where
        V: Variable,
    {
        variable.bind(self.metadata())?;
        let bin_width = (range.1 - range.0) / bins as f64;
        let bin_edges = get_bin_edges(bins, range);
        let variable = variable;
        #[cfg(feature = "rayon")]
        let evaluated: Vec<(usize, Arc<EventData>)> = self
            .events
            .par_iter()
            .filter_map(|event| {
                let value = variable.value(event.as_ref());
                if value >= range.0 && value < range.1 {
                    let bin_index = ((value - range.0) / bin_width) as usize;
                    let bin_index = bin_index.min(bins - 1);
                    Some((bin_index, event.data_arc()))
                } else {
                    None
                }
            })
            .collect();
        #[cfg(not(feature = "rayon"))]
        let evaluated: Vec<(usize, Arc<EventData>)> = self
            .events
            .iter()
            .filter_map(|event| {
                let value = variable.value(event.as_ref());
                if value >= range.0 && value < range.1 {
                    let bin_index = ((value - range.0) / bin_width) as usize;
                    let bin_index = bin_index.min(bins - 1);
                    Some((bin_index, event.data_arc()))
                } else {
                    None
                }
            })
            .collect();
        let mut binned_events: Vec<Vec<Arc<EventData>>> = vec![Vec::default(); bins];
        for (bin_index, event) in evaluated {
            binned_events[bin_index].push(event);
        }
        #[cfg(feature = "rayon")]
        let datasets: Vec<Arc<Dataset>> = binned_events
            .into_par_iter()
            .map(|events| Arc::new(Dataset::new_with_metadata(events, self.metadata.clone())))
            .collect();
        #[cfg(not(feature = "rayon"))]
        let datasets: Vec<Arc<Dataset>> = binned_events
            .into_iter()
            .map(|events| Arc::new(Dataset::new_with_metadata(events, self.metadata.clone())))
            .collect();
        Ok(BinnedDataset {
            datasets,
            edges: bin_edges,
        })
    }

    pub fn boost_to_rest_frame_of<S>(&self, names: &[S]) -> Arc<Dataset>
    where
        S: AsRef<str>,
    {
        let mut indices: Vec<usize> = Vec::new();
        for name in names {
            let name_ref = name.as_ref();
            if let Some(selection) = self.metadata.p4_selection(name_ref) {
                indices.extend_from_slice(selection.indices());
            } else {
                panic!("Unknown particle name '{name}'", name = name_ref);
            }
        }
        #[cfg(feature = "rayon")]
        let boosted_events: Vec<Arc<EventData>> = self
            .events
            .par_iter()
            .map(|event| Arc::new(event.data().boost_to_rest_frame_of(&indices)))
            .collect();
        #[cfg(not(feature = "rayon"))]
        let boosted_events: Vec<Arc<EventData>> = self
            .events
            .iter()
            .map(|event| Arc::new(event.data().boost_to_rest_frame_of(&indices)))
            .collect();
        Arc::new(Dataset::new_with_metadata(
            boosted_events,
            self.metadata.clone(),
        ))
    }

    pub fn evaluate<V: Variable>(&self, variable: &V) -> LadduResult<Vec<f64>> {
        variable.value_on(self)
    }

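    /// Reads a dataset from a Parquet file, detecting four-momentum and
    /// auxiliary columns from the schema unless `options` overrides them.
    ///
    /// A minimal sketch (the path and names are illustrative):
    ///
    /// ```ignore
    /// let options = DatasetReadOptions::new().p4_names(["beam", "proton"]);
    /// let dataset = Dataset::from_parquet("events.parquet", &options)?;
    /// ```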
    pub fn from_parquet(
        file_path: &str,
        options: &DatasetReadOptions,
    ) -> LadduResult<Arc<Dataset>> {
        let path = Self::canonicalize_dataset_path(file_path)?;
        Self::open_parquet_impl(path, options)
    }

    pub fn from_root(file_path: &str, options: &DatasetReadOptions) -> LadduResult<Arc<Dataset>> {
        let path = Self::canonicalize_dataset_path(file_path)?;
        Self::open_root_impl(path, options)
    }

    pub fn to_parquet(&self, file_path: &str, options: &DatasetWriteOptions) -> LadduResult<()> {
        let path = Self::expand_output_path(file_path)?;
        self.write_parquet_impl(path, options)
    }

    pub fn to_root(&self, file_path: &str, options: &DatasetWriteOptions) -> LadduResult<()> {
        let path = Self::expand_output_path(file_path)?;
        self.write_root_impl(path, options)
    }

    fn canonicalize_dataset_path(file_path: &str) -> LadduResult<PathBuf> {
        Ok(Path::new(&*shellexpand::full(file_path)?).canonicalize()?)
    }

    fn expand_output_path(file_path: &str) -> LadduResult<PathBuf> {
        Ok(PathBuf::from(&*shellexpand::full(file_path)?))
    }

    fn open_parquet_impl(
        file_path: PathBuf,
        options: &DatasetReadOptions,
    ) -> LadduResult<Arc<Dataset>> {
        let (detected_p4_names, detected_aux_names) = detect_columns(&file_path)?;
        let metadata = options.resolve_metadata(detected_p4_names, detected_aux_names)?;
        let file = File::open(file_path)?;
        let builder = ParquetRecordBatchReaderBuilder::try_new(file)?;
        let reader = builder.build()?;
        let batches: Vec<RecordBatch> = reader.collect::<Result<Vec<_>, _>>()?;

        #[cfg(feature = "rayon")]
        let events: Vec<Arc<EventData>> = {
            let batch_events: Vec<LadduResult<Vec<Arc<EventData>>>> = batches
                .into_par_iter()
                .map(|batch| record_batch_to_events(batch, &metadata.p4_names, &metadata.aux_names))
                .collect();
            let mut events = Vec::new();
            for batch in batch_events {
                let mut batch = batch?;
                events.append(&mut batch);
            }
            events
        };

        #[cfg(not(feature = "rayon"))]
        let events: Vec<Arc<EventData>> = {
            batches
                .into_iter()
                .map(|batch| record_batch_to_events(batch, &metadata.p4_names, &metadata.aux_names))
                .collect::<LadduResult<Vec<_>>>()?
                .into_iter()
                .flatten()
                .collect()
        };

        Ok(Arc::new(Dataset::new_with_metadata(events, metadata)))
    }

    fn open_root_impl(
        file_path: PathBuf,
        options: &DatasetReadOptions,
    ) -> LadduResult<Arc<Dataset>> {
        let mut file = RootFile::open(&file_path).map_err(|err| {
            LadduError::Custom(format!(
                "Failed to open ROOT file '{}': {err}",
                file_path.display()
            ))
        })?;

        let (tree, tree_name) = resolve_root_tree(&mut file, options.tree.as_deref())?;

        let branches: Vec<&Branch> = tree.branches().collect();
        let mut lookup: BranchLookup<'_> = IndexMap::new();
        for &branch in &branches {
            if let Some(kind) = branch_scalar_kind(branch) {
                lookup.insert(branch.name(), (kind, branch));
            }
        }

        if lookup.is_empty() {
            return Err(LadduError::Custom(format!(
                "No float or double branches found in ROOT tree '{tree_name}'"
            )));
        }

        let column_names: Vec<&str> = lookup.keys().copied().collect();
        let (detected_p4_names, detected_aux_names) = infer_p4_and_aux_names(&column_names);
        let metadata = options.resolve_metadata(detected_p4_names, detected_aux_names)?;

        struct RootP4Columns {
            px: Vec<f64>,
            py: Vec<f64>,
            pz: Vec<f64>,
            e: Vec<f64>,
        }

        let mut p4_columns = Vec::with_capacity(metadata.p4_names.len());
        for name in &metadata.p4_names {
            let logical = format!("{name}_px");
            let px = read_branch_values_from_candidates(
                &lookup,
                &component_candidates(name, "px"),
                &logical,
            )?;

            let logical = format!("{name}_py");
            let py = read_branch_values_from_candidates(
                &lookup,
                &component_candidates(name, "py"),
                &logical,
            )?;

            let logical = format!("{name}_pz");
            let pz = read_branch_values_from_candidates(
                &lookup,
                &component_candidates(name, "pz"),
                &logical,
            )?;

            let logical = format!("{name}_e");
            let e = read_branch_values_from_candidates(
                &lookup,
                &component_candidates(name, "e"),
                &logical,
            )?;

            p4_columns.push(RootP4Columns { px, py, pz, e });
        }

        let mut aux_columns = Vec::with_capacity(metadata.aux_names.len());
        for name in &metadata.aux_names {
            let values = read_branch_values(&lookup, name)?;
            aux_columns.push(values);
        }

        let n_events = if let Some(first) = p4_columns.first() {
            first.px.len()
        } else if let Some(first) = aux_columns.first() {
            first.len()
        } else {
            return Err(LadduError::Custom(
                "Unable to determine event count; dataset has no four-momentum or auxiliary columns".to_string(),
            ));
        };

        let weight_values = match read_branch_values_optional(&lookup, "weight")? {
            Some(values) => {
                if values.len() != n_events {
                    return Err(LadduError::Custom(format!(
                        "Column 'weight' has {} entries but expected {}",
                        values.len(),
                        n_events
                    )));
                }
                values
            }
            None => vec![1.0; n_events],
        };

        let mut events = Vec::with_capacity(n_events);
        for row in 0..n_events {
            let mut p4s = Vec::with_capacity(p4_columns.len());
            for columns in &p4_columns {
                p4s.push(Vec4::new(
                    columns.px[row],
                    columns.py[row],
                    columns.pz[row],
                    columns.e[row],
                ));
            }

            let mut aux = Vec::with_capacity(aux_columns.len());
            for column in &aux_columns {
                aux.push(column[row]);
            }

            let event = EventData {
                p4s,
                aux,
                weight: weight_values[row],
            };
            events.push(Arc::new(event));
        }

        Ok(Arc::new(Dataset::new_with_metadata(events, metadata)))
    }

    fn write_parquet_impl(
        &self,
        file_path: PathBuf,
        options: &DatasetWriteOptions,
    ) -> LadduResult<()> {
        let batch_size = options.batch_size.max(1);
        let precision = options.precision;
        let schema = Arc::new(build_parquet_schema(&self.metadata, precision));

        #[cfg(feature = "mpi")]
        let is_root = crate::mpi::get_world()
            .as_ref()
            .map_or(true, |world| world.rank() == 0);
        #[cfg(not(feature = "mpi"))]
        let is_root = true;

        let mut writer: Option<ArrowWriter<File>> = None;
        if is_root {
            let file = File::create(&file_path)?;
            writer = Some(
                ArrowWriter::try_new(file, schema.clone(), None).map_err(|err| {
                    LadduError::Custom(format!("Failed to create Parquet writer: {err}"))
                })?,
            );
        }

        let mut iter = self.iter();
        loop {
            let mut buffers =
                ColumnBuffers::new(self.metadata.p4_names.len(), self.metadata.aux_names.len());
            let mut rows = 0usize;

            while rows < batch_size {
                match iter.next() {
                    Some(event) => {
                        if is_root {
                            buffers.push_event(&event);
                        }
                        rows += 1;
                    }
                    None => break,
                }
            }

            if rows == 0 {
                break;
            }

            if let Some(writer) = writer.as_mut() {
                let batch = buffers
                    .into_record_batch(schema.clone(), precision)
                    .map_err(|err| {
                        LadduError::Custom(format!("Failed to build Parquet batch: {err}"))
                    })?;
                writer.write(&batch).map_err(|err| {
                    LadduError::Custom(format!("Failed to write Parquet batch: {err}"))
                })?;
            }
        }

        if let Some(writer) = writer {
            writer.close().map_err(|err| {
                LadduError::Custom(format!("Failed to finalise Parquet file: {err}"))
            })?;
        }

        Ok(())
    }

    fn write_root_impl(
        &self,
        file_path: PathBuf,
        options: &DatasetWriteOptions,
    ) -> LadduResult<()> {
        let tree_name = options.tree.clone().unwrap_or_else(|| "events".to_string());
        let branch_count = self.metadata.p4_names.len() * 4 + self.metadata.aux_names.len() + 1;

        #[cfg(feature = "mpi")]
        let mut world_opt = crate::mpi::get_world();
        #[cfg(feature = "mpi")]
        let is_root = world_opt.as_ref().map_or(true, |world| world.rank() == 0);
        #[cfg(not(feature = "mpi"))]
        let is_root = true;

        #[cfg(feature = "mpi")]
        let world: Option<WorldHandle> = world_opt.take();
        #[cfg(not(feature = "mpi"))]
        let world: Option<WorldHandle> = None;

        let total_events = self.n_events();
        let dataset_arc = Arc::new(self.clone());

        match options.precision {
            FloatPrecision::F64 => self.write_root_with_type::<f64>(
                dataset_arc,
                world,
                is_root,
                &file_path,
                &tree_name,
                branch_count,
                total_events,
            ),
            FloatPrecision::F32 => self.write_root_with_type::<f32>(
                dataset_arc,
                world,
                is_root,
                &file_path,
                &tree_name,
                branch_count,
                total_events,
            ),
        }
    }
}

impl_op_ex!(+ |a: &Dataset, b: &Dataset| -> Dataset {
    debug_assert_eq!(a.metadata.p4_names, b.metadata.p4_names);
    debug_assert_eq!(a.metadata.aux_names, b.metadata.aux_names);
    Dataset {
        events: a
            .events
            .iter()
            .chain(b.events.iter())
            .cloned()
            .collect(),
        metadata: a.metadata.clone(),
    }
});

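/// Options controlling how a dataset is read from Parquet or ROOT files.
///
/// A minimal builder sketch (the tree and column names are illustrative):
///
/// ```ignore
/// let options = DatasetReadOptions::new()
///     .tree("events")
///     .p4_names(["beam", "proton"])
///     .aux_names(["pol_magnitude", "pol_angle"]);
/// ```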
#[derive(Default, Clone)]
pub struct DatasetReadOptions {
    pub p4_names: Option<Vec<String>>,
    pub aux_names: Option<Vec<String>>,
    pub tree: Option<String>,
    pub aliases: IndexMap<String, P4Selection>,
}

#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum FloatPrecision {
    F32,
    #[default]
    F64,
}

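/// Options controlling how a dataset is written to Parquet or ROOT files.
///
/// A minimal builder sketch:
///
/// ```ignore
/// let options = DatasetWriteOptions::default()
///     .batch_size(50_000)
///     .precision(FloatPrecision::F32)
///     .tree("events");
/// ```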
#[derive(Clone, Debug)]
pub struct DatasetWriteOptions {
    pub batch_size: usize,
    pub precision: FloatPrecision,
    pub tree: Option<String>,
}

impl Default for DatasetWriteOptions {
    fn default() -> Self {
        Self {
            batch_size: DEFAULT_WRITE_BATCH_SIZE,
            precision: FloatPrecision::default(),
            tree: None,
        }
    }
}

impl DatasetWriteOptions {
    pub fn batch_size(mut self, batch_size: usize) -> Self {
        self.batch_size = batch_size;
        self
    }

    pub fn precision(mut self, precision: FloatPrecision) -> Self {
        self.precision = precision;
        self
    }

    pub fn tree<S: Into<String>>(mut self, name: S) -> Self {
        self.tree = Some(name.into());
        self
    }
}

impl DatasetReadOptions {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn p4_names<I, S>(mut self, names: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        self.p4_names = Some(names.into_iter().map(|s| s.as_ref().to_string()).collect());
        self
    }

    pub fn aux_names<I, S>(mut self, names: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        self.aux_names = Some(names.into_iter().map(|s| s.as_ref().to_string()).collect());
        self
    }

    pub fn tree<S>(mut self, name: S) -> Self
    where
        S: AsRef<str>,
    {
        self.tree = Some(name.as_ref().to_string());
        self
    }

    pub fn alias<N, S>(mut self, name: N, selection: S) -> Self
    where
        N: Into<String>,
        S: IntoP4Selection,
    {
        self.aliases.insert(name.into(), selection.into_selection());
        self
    }

    pub fn aliases<I, N, S>(mut self, aliases: I) -> Self
    where
        I: IntoIterator<Item = (N, S)>,
        N: Into<String>,
        S: IntoP4Selection,
    {
        for (name, selection) in aliases {
            self = self.alias(name, selection);
        }
        self
    }

    fn resolve_metadata(
        &self,
        detected_p4_names: Vec<String>,
        detected_aux_names: Vec<String>,
    ) -> LadduResult<Arc<DatasetMetadata>> {
        let p4_names_vec = self.p4_names.clone().unwrap_or(detected_p4_names);
        let aux_names_vec = self.aux_names.clone().unwrap_or(detected_aux_names);

        let mut metadata = DatasetMetadata::new(p4_names_vec, aux_names_vec)?;
        if !self.aliases.is_empty() {
            metadata.add_p4_aliases(self.aliases.clone())?;
        }
        Ok(Arc::new(metadata))
    }
}

const P4_COMPONENT_SUFFIXES: [&str; 4] = ["_px", "_py", "_pz", "_e"];
const DEFAULT_WRITE_BATCH_SIZE: usize = 10_000;

fn detect_columns(file_path: &PathBuf) -> LadduResult<(Vec<String>, Vec<String>)> {
    let file = File::open(file_path)?;
    let builder = ParquetRecordBatchReaderBuilder::try_new(file)?;
    let schema = builder.schema();
    let float_cols: Vec<&str> = schema
        .fields()
        .iter()
        .filter(|f| matches!(f.data_type(), DataType::Float32 | DataType::Float64))
        .map(|f| f.name().as_str())
        .collect();
    Ok(infer_p4_and_aux_names(&float_cols))
}

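// Any prefix that supplies all four of the `_px`/`_py`/`_pz`/`_e` components
// is treated as a four-momentum; remaining float columns (except a
// case-insensitive "weight") become auxiliary values. For example,
// ["beam_px", "beam_py", "beam_pz", "beam_e", "pol_angle"] yields the p4 name
// "beam" and the aux name "pol_angle".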
fn infer_p4_and_aux_names(float_cols: &[&str]) -> (Vec<String>, Vec<String>) {
    let suffix_set: IndexSet<&str> = P4_COMPONENT_SUFFIXES.iter().copied().collect();
    let mut groups: IndexMap<&str, IndexSet<&str>> = IndexMap::new();
    for col in float_cols {
        for suffix in &suffix_set {
            if let Some(prefix) = col.strip_suffix(suffix) {
                groups.entry(prefix).or_default().insert(*suffix);
            }
        }
    }

    let mut p4_names: Vec<String> = Vec::new();
    let mut p4_columns: IndexSet<String> = IndexSet::new();
    for (prefix, suffixes) in &groups {
        if suffixes.len() == suffix_set.len() {
            p4_names.push((*prefix).to_string());
            for suffix in &suffix_set {
                p4_columns.insert(format!("{prefix}{suffix}"));
            }
        }
    }

    let mut aux_names: Vec<String> = Vec::new();
    for col in float_cols {
        if p4_columns.contains(*col) {
            continue;
        }
        if col.eq_ignore_ascii_case("weight") {
            continue;
        }
        aux_names.push((*col).to_string());
    }

    (p4_names, aux_names)
}

type BranchLookup<'a> = IndexMap<&'a str, (RootScalarKind, &'a Branch)>;

#[derive(Clone, Copy)]
enum RootScalarKind {
    F32,
    F64,
}

fn branch_scalar_kind(branch: &Branch) -> Option<RootScalarKind> {
    let type_name = branch.item_type_name();
    let lower = type_name.to_ascii_lowercase();
    if lower.contains("vector") {
        return None;
    }
    match lower.as_str() {
        "float" | "float_t" | "float32_t" => Some(RootScalarKind::F32),
        "double" | "double_t" | "double32_t" => Some(RootScalarKind::F64),
        _ => None,
    }
}

fn read_branch_values<'a>(lookup: &BranchLookup<'a>, column_name: &str) -> LadduResult<Vec<f64>> {
    let (kind, branch) =
        lookup
            .get(column_name)
            .copied()
            .ok_or_else(|| LadduError::MissingColumn {
                name: column_name.to_string(),
            })?;

    let values = match kind {
        RootScalarKind::F32 => branch
            .as_iter::<f32>()
            .map_err(|err| map_root_error(&format!("Failed to read branch '{column_name}'"), err))?
            .map(|value| value as f64)
            .collect(),
        RootScalarKind::F64 => branch
            .as_iter::<f64>()
            .map_err(|err| map_root_error(&format!("Failed to read branch '{column_name}'"), err))?
            .collect(),
    };
    Ok(values)
}

fn read_branch_values_optional<'a>(
    lookup: &BranchLookup<'a>,
    column_name: &str,
) -> LadduResult<Option<Vec<f64>>> {
    if lookup.contains_key(column_name) {
        read_branch_values(lookup, column_name).map(Some)
    } else {
        Ok(None)
    }
}

fn read_branch_values_from_candidates<'a>(
    lookup: &BranchLookup<'a>,
    candidates: &[String],
    logical_name: &str,
) -> LadduResult<Vec<f64>> {
    for candidate in candidates {
        if lookup.contains_key(candidate.as_str()) {
            return read_branch_values(lookup, candidate);
        }
    }
    Err(LadduError::MissingColumn {
        name: logical_name.to_string(),
    })
}

fn resolve_root_tree(
    file: &mut RootFile,
    requested: Option<&str>,
) -> LadduResult<(ReaderTree, String)> {
    if let Some(name) = requested {
        let tree = file
            .get_tree(name)
            .map_err(|err| map_root_error(&format!("Failed to open ROOT tree '{name}'"), err))?;
        return Ok((tree, name.to_string()));
    }

    let tree_names: Vec<String> = file
        .keys()
        .into_iter()
        .filter(|key| key.class_name() == "TTree")
        .map(|key| key.name().to_string())
        .collect();

    if tree_names.is_empty() {
        return Err(LadduError::Custom(
            "ROOT file does not contain any TTrees".to_string(),
        ));
    }

    if tree_names.len() > 1 {
        return Err(LadduError::Custom(format!(
            "Multiple TTrees found ({:?}); specify DatasetReadOptions::tree to disambiguate",
            tree_names
        )));
    }

    let selected = &tree_names[0];
    let tree = file
        .get_tree(selected)
        .map_err(|err| map_root_error(&format!("Failed to open ROOT tree '{selected}'"), err))?;
    Ok((tree, selected.clone()))
}

fn map_root_error<E: std::fmt::Display>(context: &str, err: E) -> LadduError {
    LadduError::Custom(format!("{context}: {err}"))
}

#[derive(Clone, Copy)]
enum FloatColumn<'a> {
    F32(&'a Float32Array),
    F64(&'a Float64Array),
}

impl<'a> FloatColumn<'a> {
    fn value(&self, row: usize) -> f64 {
        match self {
            Self::F32(array) => array.value(row) as f64,
            Self::F64(array) => array.value(row),
        }
    }
}

struct P4Columns<'a> {
    px: FloatColumn<'a>,
    py: FloatColumn<'a>,
    pz: FloatColumn<'a>,
    e: FloatColumn<'a>,
}

fn prepare_float_column<'a>(batch: &'a RecordBatch, name: &str) -> LadduResult<FloatColumn<'a>> {
    prepare_float_column_from_candidates(batch, &[name.to_string()], name)
}

fn prepare_p4_columns<'a>(batch: &'a RecordBatch, name: &str) -> LadduResult<P4Columns<'a>> {
    Ok(P4Columns {
        px: prepare_float_column_from_candidates(
            batch,
            &component_candidates(name, "px"),
            &format!("{name}_px"),
        )?,
        py: prepare_float_column_from_candidates(
            batch,
            &component_candidates(name, "py"),
            &format!("{name}_py"),
        )?,
        pz: prepare_float_column_from_candidates(
            batch,
            &component_candidates(name, "pz"),
            &format!("{name}_pz"),
        )?,
        e: prepare_float_column_from_candidates(
            batch,
            &component_candidates(name, "e"),
            &format!("{name}_e"),
        )?,
    })
}

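// Candidate spellings tried for a p4 component column, in order. For
// ("beam", "px") this yields "beam_px", "beam_Px", and "beam_PX".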
fn component_candidates(name: &str, suffix: &str) -> Vec<String> {
    let mut candidates = Vec::with_capacity(3);
    let base = format!("{name}_{suffix}");
    candidates.push(base.clone());

    let mut capitalized = suffix.to_string();
    if let Some(first) = capitalized.get_mut(0..1) {
        first.make_ascii_uppercase();
    }
    if capitalized != suffix {
        candidates.push(format!("{name}_{capitalized}"));
    }

    let upper = suffix.to_ascii_uppercase();
    if upper != suffix && upper != capitalized {
        candidates.push(format!("{name}_{upper}"));
    }

    candidates
}

fn find_float_column_from_candidates<'a>(
    batch: &'a RecordBatch,
    candidates: &[String],
) -> LadduResult<Option<FloatColumn<'a>>> {
    use arrow::datatypes::DataType;

    for candidate in candidates {
        if let Some(column) = batch.column_by_name(candidate) {
            return match column.data_type() {
                DataType::Float32 => Ok(Some(FloatColumn::F32(
                    column
                        .as_any()
                        .downcast_ref::<Float32Array>()
                        .expect("Column advertised as Float32 but could not be downcast"),
                ))),
                DataType::Float64 => Ok(Some(FloatColumn::F64(
                    column
                        .as_any()
                        .downcast_ref::<Float64Array>()
                        .expect("Column advertised as Float64 but could not be downcast"),
                ))),
                other => {
                    return Err(LadduError::InvalidColumnType {
                        name: candidate.clone(),
                        datatype: other.to_string(),
                    })
                }
            };
        }
    }
    Ok(None)
}

fn prepare_float_column_from_candidates<'a>(
    batch: &'a RecordBatch,
    candidates: &[String],
    logical_name: &str,
) -> LadduResult<FloatColumn<'a>> {
    find_float_column_from_candidates(batch, candidates)?.ok_or_else(|| LadduError::MissingColumn {
        name: logical_name.to_string(),
    })
}

fn record_batch_to_events(
    batch: RecordBatch,
    p4_names: &[String],
    aux_names: &[String],
) -> LadduResult<Vec<Arc<EventData>>> {
    let batch_ref = &batch;
    let p4_columns: Vec<P4Columns<'_>> = p4_names
        .iter()
        .map(|name| prepare_p4_columns(batch_ref, name))
        .collect::<Result<_, _>>()?;

    let aux_columns: Vec<FloatColumn<'_>> = aux_names
        .iter()
        .map(|name| prepare_float_column(batch_ref, name))
        .collect::<Result<_, _>>()?;

    let weight_column = find_float_column_from_candidates(batch_ref, &["weight".to_string()])?;

    let mut events = Vec::with_capacity(batch_ref.num_rows());
    for row in 0..batch_ref.num_rows() {
        let mut p4s = Vec::with_capacity(p4_columns.len());
        for columns in &p4_columns {
            let px = columns.px.value(row);
            let py = columns.py.value(row);
            let pz = columns.pz.value(row);
            let e = columns.e.value(row);
            p4s.push(Vec4::new(px, py, pz, e));
        }

        let mut aux = Vec::with_capacity(aux_columns.len());
        for column in &aux_columns {
            aux.push(column.value(row));
        }

        let event = EventData {
            p4s,
            aux,
            weight: weight_column
                .as_ref()
                .map(|column| column.value(row))
                .unwrap_or(1.0),
        };
        events.push(Arc::new(event));
    }
    Ok(events)
}

struct ColumnBuffers {
    p4: Vec<P4Buffer>,
    aux: Vec<Vec<f64>>,
    weight: Vec<f64>,
}

impl ColumnBuffers {
    fn new(n_p4: usize, n_aux: usize) -> Self {
        let p4 = (0..n_p4).map(|_| P4Buffer::default()).collect();
        let aux = vec![Vec::new(); n_aux];
        Self {
            p4,
            aux,
            weight: Vec::new(),
        }
    }

    fn push_event(&mut self, event: &Event) {
        for (buffer, p4) in self.p4.iter_mut().zip(event.p4s.iter()) {
            buffer.px.push(p4.x);
            buffer.py.push(p4.y);
            buffer.pz.push(p4.z);
            buffer.e.push(p4.t);
        }

        for (buffer, value) in self.aux.iter_mut().zip(event.aux.iter()) {
            buffer.push(*value);
        }

        self.weight.push(event.weight);
    }

    fn into_record_batch(
        self,
        schema: Arc<Schema>,
        precision: FloatPrecision,
    ) -> arrow::error::Result<RecordBatch> {
        let mut columns: Vec<arrow::array::ArrayRef> = Vec::new();

        match precision {
            FloatPrecision::F64 => {
                for buffer in &self.p4 {
                    columns.push(Arc::new(Float64Array::from(buffer.px.clone())));
                    columns.push(Arc::new(Float64Array::from(buffer.py.clone())));
                    columns.push(Arc::new(Float64Array::from(buffer.pz.clone())));
                    columns.push(Arc::new(Float64Array::from(buffer.e.clone())));
                }

                for buffer in &self.aux {
                    columns.push(Arc::new(Float64Array::from(buffer.clone())));
                }

                columns.push(Arc::new(Float64Array::from(self.weight)));
            }
            FloatPrecision::F32 => {
                for buffer in &self.p4 {
                    columns.push(Arc::new(Float32Array::from(
                        buffer.px.iter().map(|v| *v as f32).collect::<Vec<_>>(),
                    )));
                    columns.push(Arc::new(Float32Array::from(
                        buffer.py.iter().map(|v| *v as f32).collect::<Vec<_>>(),
                    )));
                    columns.push(Arc::new(Float32Array::from(
                        buffer.pz.iter().map(|v| *v as f32).collect::<Vec<_>>(),
                    )));
                    columns.push(Arc::new(Float32Array::from(
                        buffer.e.iter().map(|v| *v as f32).collect::<Vec<_>>(),
                    )));
                }

                for buffer in &self.aux {
                    columns.push(Arc::new(Float32Array::from(
                        buffer.iter().map(|v| *v as f32).collect::<Vec<_>>(),
                    )));
                }

                columns.push(Arc::new(Float32Array::from(
                    self.weight.iter().map(|v| *v as f32).collect::<Vec<_>>(),
                )));
            }
        }

        RecordBatch::try_new(schema, columns)
    }
}

#[derive(Default)]
struct P4Buffer {
    px: Vec<f64>,
    py: Vec<f64>,
    pz: Vec<f64>,
    e: Vec<f64>,
}

fn build_parquet_schema(metadata: &DatasetMetadata, precision: FloatPrecision) -> Schema {
    let dtype = match precision {
        FloatPrecision::F64 => DataType::Float64,
        FloatPrecision::F32 => DataType::Float32,
    };

    let mut fields = Vec::new();
    for name in &metadata.p4_names {
        for suffix in P4_COMPONENT_SUFFIXES {
            fields.push(Field::new(format!("{name}{suffix}"), dtype.clone(), false));
        }
    }

    for name in &metadata.aux_names {
        fields.push(Field::new(name.clone(), dtype.clone(), false));
    }

    fields.push(Field::new("weight", dtype, false));
    Schema::new(fields)
}

trait FromF64 {
    fn from_f64(value: f64) -> Self;
}

impl FromF64 for f64 {
    fn from_f64(value: f64) -> Self {
        value
    }
}

impl FromF64 for f32 {
    fn from_f64(value: f64) -> Self {
        value as f32
    }
}

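// Every ROOT branch is written from its own iterator, all walking the same
// event indices in lock-step. `SharedEventFetcher` caches the event fetched
// for the current index and serves it to each of the `branch_count` column
// iterators before releasing it, so an event (possibly MPI-broadcast) is
// fetched once per index rather than once per branch.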
struct SharedEventFetcher {
    dataset: Arc<Dataset>,
    world: Option<WorldHandle>,
    total: usize,
    branch_count: usize,
    current_index: Option<usize>,
    current_event: Option<Event>,
    remaining: usize,
}

impl SharedEventFetcher {
    fn new(
        dataset: Arc<Dataset>,
        world: Option<WorldHandle>,
        total: usize,
        branch_count: usize,
    ) -> Self {
        Self {
            dataset,
            world,
            total,
            branch_count,
            current_index: None,
            current_event: None,
            remaining: 0,
        }
    }

    fn event_for_index(&mut self, index: usize) -> Option<Event> {
        if index >= self.total {
            return None;
        }

        let refresh_needed = match self.current_index {
            None => true,
            Some(current) => current != index || self.remaining == 0,
        };

        if refresh_needed {
            let event =
                fetch_event_for_index(&self.dataset, index, self.total, self.world.as_ref());
            self.current_index = Some(index);
            self.remaining = self.branch_count;
            self.current_event = Some(event);
        }

        let event = self.current_event.as_ref().cloned();
        if self.remaining > 0 {
            self.remaining -= 1;
        }
        if self.remaining == 0 {
            self.current_event = None;
        }
        event
    }
}

enum ColumnKind {
    Px(usize),
    Py(usize),
    Pz(usize),
    E(usize),
    Aux(usize),
    Weight,
}

struct ColumnIterator<T> {
    fetcher: Arc<Mutex<SharedEventFetcher>>,
    index: usize,
    kind: ColumnKind,
    _marker: std::marker::PhantomData<T>,
}

impl<T> ColumnIterator<T> {
    fn new(fetcher: Arc<Mutex<SharedEventFetcher>>, kind: ColumnKind) -> Self {
        Self {
            fetcher,
            index: 0,
            kind,
            _marker: std::marker::PhantomData,
        }
    }
}

impl<T> Iterator for ColumnIterator<T>
where
    T: FromF64,
{
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        let mut fetcher = self.fetcher.lock();
        let event = fetcher.event_for_index(self.index)?;
        self.index += 1;

        match self.kind {
            ColumnKind::Px(idx) => event.p4s.get(idx).map(|p4| T::from_f64(p4.x)),
            ColumnKind::Py(idx) => event.p4s.get(idx).map(|p4| T::from_f64(p4.y)),
            ColumnKind::Pz(idx) => event.p4s.get(idx).map(|p4| T::from_f64(p4.z)),
            ColumnKind::E(idx) => event.p4s.get(idx).map(|p4| T::from_f64(p4.t)),
            ColumnKind::Aux(idx) => event.aux.get(idx).map(|value| T::from_f64(*value)),
            ColumnKind::Weight => Some(T::from_f64(event.weight)),
        }
    }
}

fn build_root_column_iterators<T>(
    dataset: Arc<Dataset>,
    world: Option<WorldHandle>,
    branch_count: usize,
    total: usize,
) -> Vec<(String, ColumnIterator<T>)>
where
    T: FromF64,
{
    let fetcher = Arc::new(Mutex::new(SharedEventFetcher::new(
        dataset,
        world,
        total,
        branch_count,
    )));

    let p4_names: Vec<String> = fetcher.lock().dataset.metadata.p4_names.clone();
    let aux_names: Vec<String> = fetcher.lock().dataset.metadata.aux_names.clone();

    let mut iterators = Vec::new();

    for (idx, name) in p4_names.iter().enumerate() {
        iterators.push((
            format!("{name}_px"),
            ColumnIterator::new(fetcher.clone(), ColumnKind::Px(idx)),
        ));
        iterators.push((
            format!("{name}_py"),
            ColumnIterator::new(fetcher.clone(), ColumnKind::Py(idx)),
        ));
        iterators.push((
            format!("{name}_pz"),
            ColumnIterator::new(fetcher.clone(), ColumnKind::Pz(idx)),
        ));
        iterators.push((
            format!("{name}_e"),
            ColumnIterator::new(fetcher.clone(), ColumnKind::E(idx)),
        ));
    }

    for (idx, name) in aux_names.iter().enumerate() {
        iterators.push((
            name.clone(),
            ColumnIterator::new(fetcher.clone(), ColumnKind::Aux(idx)),
        ));
    }

    iterators.push((
        "weight".to_string(),
        ColumnIterator::new(fetcher, ColumnKind::Weight),
    ));

    iterators
}

fn drain_column_iterators<T>(iterators: &mut [(String, ColumnIterator<T>)], n_events: usize)
where
    T: FromF64,
{
    for _ in 0..n_events {
        for (_name, iterator) in iterators.iter_mut() {
            let _ = iterator.next();
        }
    }
}

fn fetch_event_for_index(
    dataset: &Dataset,
    index: usize,
    total: usize,
    world: Option<&WorldHandle>,
) -> Event {
    let _ = total;
    let _ = world;
    #[cfg(feature = "mpi")]
    {
        if let Some(world) = world {
            return fetch_event_mpi(dataset, index, world, total);
        }
    }

    dataset.index_local(index).clone()
}

impl Dataset {
    #[allow(clippy::too_many_arguments)]
    fn write_root_with_type<T>(
        &self,
        dataset: Arc<Dataset>,
        world: Option<WorldHandle>,
        is_root: bool,
        file_path: &Path,
        tree_name: &str,
        branch_count: usize,
        total_events: usize,
    ) -> LadduResult<()>
    where
        T: FromF64 + oxyroot::Marshaler + 'static,
    {
        let mut iterators =
            build_root_column_iterators::<T>(dataset, world, branch_count, total_events);

        if is_root {
            let mut file = RootFile::create(file_path).map_err(|err| {
                LadduError::Custom(format!(
                    "Failed to create ROOT file '{}': {err}",
                    file_path.display()
                ))
            })?;

            let mut tree = WriterTree::new(tree_name);
            for (name, iterator) in iterators {
                tree.new_branch(name, iterator);
            }

            tree.write(&mut file).map_err(|err| {
                LadduError::Custom(format!(
                    "Failed to write ROOT tree '{tree_name}' to '{}': {err}",
                    file_path.display()
                ))
            })?;

            file.close().map_err(|err| {
                LadduError::Custom(format!(
                    "Failed to close ROOT file '{}': {err}",
                    file_path.display()
                ))
            })?;
        } else {
            drain_column_iterators(&mut iterators, total_events);
        }

        Ok(())
    }
}
2336
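/// A dataset that has been partitioned into bins along some [`Variable`].
///
/// # Example
///
/// A minimal sketch (mirroring `test_binned_dataset` below; `dataset` is an
/// existing [`Dataset`] and `variable` a bound [`Variable`]):
///
/// ```ignore
/// let binned = dataset.bin_by(variable, 2, (0.0, 3.0))?;
/// assert_eq!(binned.n_bins(), 2);
/// assert_eq!(binned.edges().len(), 3);
/// let (lo, hi) = binned.range(); // (0.0, 3.0)
/// ```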
pub struct BinnedDataset {
    /// The dataset contained in each bin.
    datasets: Vec<Arc<Dataset>>,
    /// The bin edges (length `n_bins + 1`).
    edges: Vec<f64>,
}

impl Index<usize> for BinnedDataset {
    type Output = Arc<Dataset>;

    fn index(&self, index: usize) -> &Self::Output {
        &self.datasets[index]
    }
}

impl IndexMut<usize> for BinnedDataset {
    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
        &mut self.datasets[index]
    }
}

impl Deref for BinnedDataset {
    type Target = Vec<Arc<Dataset>>;

    fn deref(&self) -> &Self::Target {
        &self.datasets
    }
}

impl DerefMut for BinnedDataset {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.datasets
    }
}

impl BinnedDataset {
    /// Returns the number of bins.
    pub fn n_bins(&self) -> usize {
        self.datasets.len()
    }

    /// Returns a copy of the bin edges (one more edge than there are bins).
    pub fn edges(&self) -> Vec<f64> {
        self.edges.clone()
    }

    /// Returns the lowest and highest bin edges, i.e. the full binned range.
    pub fn range(&self) -> (f64, f64) {
        (self.edges[0], self.edges[self.n_bins()])
    }
}

#[cfg(test)]
mod tests {
    use crate::Mass;

    use super::*;
    use crate::utils::vectors::Vec3;
    use approx::{assert_relative_eq, assert_relative_ne};
    use serde::{Deserialize, Serialize};
    use std::{
        env, fs,
        path::{Path, PathBuf},
    };

    fn test_data_path(file: &str) -> PathBuf {
        Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("test_data")
            .join(file)
    }

    fn open_test_dataset(file: &str, options: DatasetReadOptions) -> Arc<Dataset> {
        let path = test_data_path(file);
        let path_str = path.to_str().expect("test data path should be valid UTF-8");
        let ext = path
            .extension()
            .and_then(|ext| ext.to_str())
            .unwrap_or_default()
            .to_ascii_lowercase();
        match ext.as_str() {
            "parquet" => Dataset::from_parquet(path_str, &options),
            "root" => Dataset::from_root(path_str, &options),
            other => panic!("Unsupported extension in test data: {other}"),
        }
        .expect("dataset should open")
    }

    fn make_temp_dir() -> PathBuf {
        let dir = env::temp_dir().join(format!("laddu_test_{}", fastrand::u64(..)));
        fs::create_dir(&dir).expect("temp dir should be created");
        dir
    }

    fn assert_events_close(left: &Event, right: &Event, p4_names: &[&str], aux_names: &[&str]) {
        for name in p4_names {
            let lp4 = left
                .p4(name)
                .unwrap_or_else(|| panic!("missing p4 '{name}' in left dataset"));
            let rp4 = right
                .p4(name)
                .unwrap_or_else(|| panic!("missing p4 '{name}' in right dataset"));
            assert_relative_eq!(lp4.px(), rp4.px(), epsilon = 1e-9);
            assert_relative_eq!(lp4.py(), rp4.py(), epsilon = 1e-9);
            assert_relative_eq!(lp4.pz(), rp4.pz(), epsilon = 1e-9);
            assert_relative_eq!(lp4.e(), rp4.e(), epsilon = 1e-9);
        }
        let left_aux = left.aux();
        let right_aux = right.aux();
        for name in aux_names {
            let laux = left_aux
                .get(name)
                .copied()
                .unwrap_or_else(|| panic!("missing aux '{name}' in left dataset"));
            let raux = right_aux
                .get(name)
                .copied()
                .unwrap_or_else(|| panic!("missing aux '{name}' in right dataset"));
            assert_relative_eq!(laux, raux, epsilon = 1e-9);
        }
        assert_relative_eq!(left.weight(), right.weight(), epsilon = 1e-9);
    }

    fn assert_datasets_close(
        left: &Arc<Dataset>,
        right: &Arc<Dataset>,
        p4_names: &[&str],
        aux_names: &[&str],
    ) {
        assert_eq!(left.n_events(), right.n_events());
        for idx in 0..left.n_events() {
            let levent = &left[idx];
            let revent = &right[idx];
            assert_events_close(levent, revent, p4_names, aux_names);
        }
    }

    #[test]
    fn test_from_parquet_auto_matches_explicit_names() {
        let auto = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
        let explicit_options = DatasetReadOptions::new()
            .p4_names(TEST_P4_NAMES)
            .aux_names(TEST_AUX_NAMES);
        let explicit = open_test_dataset("data_f32.parquet", explicit_options);

        let mut detected_p4: Vec<&str> = auto.p4_names().iter().map(String::as_str).collect();
        detected_p4.sort_unstable();
        let mut expected_p4 = TEST_P4_NAMES.to_vec();
        expected_p4.sort_unstable();
        assert_eq!(detected_p4, expected_p4);
        let mut detected_aux: Vec<&str> = auto.aux_names().iter().map(String::as_str).collect();
        detected_aux.sort_unstable();
        let mut expected_aux = TEST_AUX_NAMES.to_vec();
        expected_aux.sort_unstable();
        assert_eq!(detected_aux, expected_aux);
        assert_datasets_close(&auto, &explicit, TEST_P4_NAMES, TEST_AUX_NAMES);
    }

    #[test]
    fn test_from_parquet_with_aliases() {
        let dataset = open_test_dataset(
            "data_f32.parquet",
            DatasetReadOptions::new().alias("resonance", ["kshort1", "kshort2"]),
        );
        let event = dataset.named_event(0);
        let alias_vec = event.p4("resonance").expect("alias vector");
        let expected = event.get_p4_sum(["kshort1", "kshort2"]);
        assert_relative_eq!(alias_vec.px(), expected.px(), epsilon = 1e-9);
        assert_relative_eq!(alias_vec.py(), expected.py(), epsilon = 1e-9);
        assert_relative_eq!(alias_vec.pz(), expected.pz(), epsilon = 1e-9);
        assert_relative_eq!(alias_vec.e(), expected.e(), epsilon = 1e-9);
    }

    #[test]
    fn test_from_parquet_f64_matches_f32() {
        let f32_ds = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
        let f64_ds = open_test_dataset("data_f64.parquet", DatasetReadOptions::new());
        assert_datasets_close(&f64_ds, &f32_ds, TEST_P4_NAMES, TEST_AUX_NAMES);
    }

    #[test]
    fn test_from_root_detects_columns_and_matches_parquet() {
        let parquet = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
        let root_auto = open_test_dataset("data_f32.root", DatasetReadOptions::new());
        let mut detected_p4: Vec<&str> = root_auto.p4_names().iter().map(String::as_str).collect();
        detected_p4.sort_unstable();
        let mut expected_p4 = TEST_P4_NAMES.to_vec();
        expected_p4.sort_unstable();
        assert_eq!(detected_p4, expected_p4);
        let mut detected_aux: Vec<&str> =
            root_auto.aux_names().iter().map(String::as_str).collect();
        detected_aux.sort_unstable();
        let mut expected_aux = TEST_AUX_NAMES.to_vec();
        expected_aux.sort_unstable();
        assert_eq!(detected_aux, expected_aux);
        let root_named_options = DatasetReadOptions::new()
            .p4_names(TEST_P4_NAMES)
            .aux_names(TEST_AUX_NAMES);
        let root_named = open_test_dataset("data_f32.root", root_named_options);
        assert_datasets_close(&root_auto, &root_named, TEST_P4_NAMES, TEST_AUX_NAMES);
        assert_datasets_close(&root_auto, &parquet, TEST_P4_NAMES, TEST_AUX_NAMES);
    }

    #[test]
    fn test_from_root_f64_matches_parquet() {
        let parquet = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
        let root_f64 = open_test_dataset("data_f64.root", DatasetReadOptions::new());
        assert_datasets_close(&root_f64, &parquet, TEST_P4_NAMES, TEST_AUX_NAMES);
    }

    #[test]
    fn test_event_creation() {
        let event = test_event();
        assert_eq!(event.p4s.len(), 4);
        assert_eq!(event.aux.len(), 2);
        assert_relative_eq!(event.weight, 0.48);
    }

    #[test]
    fn test_event_p4_sum() {
        let event = test_event();
        let sum = event.get_p4_sum([2, 3]);
        assert_relative_eq!(sum.px(), event.p4s[2].px() + event.p4s[3].px());
        assert_relative_eq!(sum.py(), event.p4s[2].py() + event.p4s[3].py());
        assert_relative_eq!(sum.pz(), event.p4s[2].pz() + event.p4s[3].pz());
        assert_relative_eq!(sum.e(), event.p4s[2].e() + event.p4s[3].e());
    }

    #[test]
    fn test_event_boost() {
        let event = test_event();
        let event_boosted = event.boost_to_rest_frame_of([1, 2, 3]);
        let p4_sum = event_boosted.get_p4_sum([1, 2, 3]);
        assert_relative_eq!(p4_sum.px(), 0.0);
        assert_relative_eq!(p4_sum.py(), 0.0);
        assert_relative_eq!(p4_sum.pz(), 0.0, epsilon = f64::EPSILON.sqrt());
    }

    #[test]
    fn test_event_evaluate() {
        let event = test_event();
        let mut mass = Mass::new(["proton"]);
        mass.bind(
            &DatasetMetadata::new(
                TEST_P4_NAMES.iter().map(|s| (*s).to_string()).collect(),
                TEST_AUX_NAMES.iter().map(|s| (*s).to_string()).collect(),
            )
            .expect("metadata"),
        )
        .unwrap();
        assert_relative_eq!(event.evaluate(&mass), 1.007);
    }

    #[test]
    fn test_dataset_size_check() {
        let dataset = Dataset::new(Vec::new());
        assert_eq!(dataset.n_events(), 0);
        let dataset = Dataset::new(vec![Arc::new(test_event())]);
        assert_eq!(dataset.n_events(), 1);
    }

    #[test]
    fn test_dataset_sum() {
        let dataset = test_dataset();
        let metadata = dataset.metadata_arc();
        let dataset2 = Dataset::new_with_metadata(
            vec![Arc::new(EventData {
                p4s: test_event().p4s,
                aux: test_event().aux,
                weight: 0.52,
            })],
            metadata.clone(),
        );
        let dataset_sum = &dataset + &dataset2;
        assert_eq!(dataset_sum[0].weight, dataset[0].weight);
        assert_eq!(dataset_sum[1].weight, dataset2[0].weight);
    }

    #[test]
    fn test_dataset_weights() {
        let dataset = Dataset::new(vec![
            Arc::new(test_event()),
            Arc::new(EventData {
                p4s: test_event().p4s,
                aux: test_event().aux,
                weight: 0.52,
            }),
        ]);
        let weights = dataset.weights();
        assert_eq!(weights.len(), 2);
        assert_relative_eq!(weights[0], 0.48);
        assert_relative_eq!(weights[1], 0.52);
        assert_relative_eq!(dataset.n_events_weighted(), 1.0);
    }

    #[test]
    fn test_dataset_filtering() {
        let metadata = Arc::new(
            DatasetMetadata::new(vec!["beam"], Vec::<String>::new())
                .expect("metadata should be valid"),
        );
        let events = vec![
            Arc::new(EventData {
                p4s: vec![Vec3::new(0.0, 0.0, 5.0).with_mass(0.0)],
                aux: vec![],
                weight: 1.0,
            }),
            Arc::new(EventData {
                p4s: vec![Vec3::new(0.0, 0.0, 5.0).with_mass(0.5)],
                aux: vec![],
                weight: 1.0,
            }),
            Arc::new(EventData {
                p4s: vec![Vec3::new(0.0, 0.0, 5.0).with_mass(1.1)],
                aux: vec![],
                weight: 1.0,
            }),
        ];
        let dataset = Dataset::new_with_metadata(events, metadata);

        let metadata = dataset.metadata_arc();
        let mut mass = Mass::new(["beam"]);
        mass.bind(metadata.as_ref()).unwrap();
        let expression = mass.gt(0.0).and(&mass.lt(1.0));

        let filtered = dataset.filter(&expression).unwrap();
        assert_eq!(filtered.n_events(), 1);
        assert_relative_eq!(mass.value(&filtered[0]), 0.5);
    }

    #[test]
    fn test_dataset_boost() {
        let dataset = test_dataset();
        let dataset_boosted = dataset.boost_to_rest_frame_of(&["proton", "kshort1", "kshort2"]);
        let p4_sum = dataset_boosted[0].get_p4_sum(["proton", "kshort1", "kshort2"]);
        assert_relative_eq!(p4_sum.px(), 0.0);
        assert_relative_eq!(p4_sum.py(), 0.0);
        assert_relative_eq!(p4_sum.pz(), 0.0, epsilon = f64::EPSILON.sqrt());
    }

    #[test]
    fn test_named_event_view() {
        let dataset = test_dataset();
        let view = dataset.named_event(0);

        assert_relative_eq!(view.weight(), dataset[0].weight);
        let beam = view.p4("beam").expect("beam p4");
        assert_relative_eq!(beam.px(), dataset[0].p4s[0].px());
        assert_relative_eq!(beam.e(), dataset[0].p4s[0].e());

        let summed = view.get_p4_sum(["kshort1", "kshort2"]);
        assert_relative_eq!(summed.e(), dataset[0].p4s[2].e() + dataset[0].p4s[3].e());

        let aux_angle = view.aux().get("pol_angle").copied().expect("pol angle");
        assert_relative_eq!(aux_angle, dataset[0].aux[1]);

        let metadata = dataset.metadata_arc();
        let boosted = view.boost_to_rest_frame_of(["proton", "kshort1", "kshort2"]);
        let boosted_event = Event::new(Arc::new(boosted), metadata);
        let boosted_sum = boosted_event.get_p4_sum(["proton", "kshort1", "kshort2"]);
        assert_relative_eq!(boosted_sum.px(), 0.0);
    }

    #[test]
    fn test_dataset_evaluate() {
        let dataset = test_dataset();
        let mass = Mass::new(["proton"]);
        assert_relative_eq!(dataset.evaluate(&mass).unwrap()[0], 1.007);
    }

    #[test]
    fn test_dataset_metadata_rejects_duplicate_names() {
        let err = DatasetMetadata::new(vec!["beam", "beam"], Vec::<String>::new());
        assert!(matches!(
            err,
            Err(LadduError::DuplicateName { category, .. }) if category == "p4"
        ));
        let err = DatasetMetadata::new(
            vec!["beam"],
            vec!["pol_angle".to_string(), "pol_angle".to_string()],
        );
        assert!(matches!(
            err,
            Err(LadduError::DuplicateName { category, .. }) if category == "aux"
        ));
    }

    #[test]
    fn test_dataset_lookup_by_name() {
        let dataset = test_dataset();
        let proton = dataset.p4_by_name(0, "proton").expect("proton p4");
        let proton_idx = dataset.metadata().p4_index("proton").unwrap();
        assert_relative_eq!(proton.e(), dataset[0].p4s[proton_idx].e());
        assert!(dataset.p4_by_name(0, "unknown").is_none());
        let angle = dataset.aux_by_name(0, "pol_angle").expect("pol_angle");
        assert_relative_eq!(angle, dataset[0].aux[1]);
        assert!(dataset.aux_by_name(0, "missing").is_none());
    }

    #[test]
    fn test_binned_dataset() {
        let dataset = Dataset::new(vec![
            Arc::new(EventData {
                p4s: vec![Vec3::new(0.0, 0.0, 1.0).with_mass(1.0)],
                aux: vec![],
                weight: 1.0,
            }),
            Arc::new(EventData {
                p4s: vec![Vec3::new(0.0, 0.0, 2.0).with_mass(2.0)],
                aux: vec![],
                weight: 2.0,
            }),
        ]);

        #[derive(Clone, Serialize, Deserialize, Debug)]
        struct BeamEnergy;
        impl Display for BeamEnergy {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, "BeamEnergy")
            }
        }
        #[typetag::serde]
        impl Variable for BeamEnergy {
            fn value(&self, event: &EventData) -> f64 {
                event.p4s[0].e()
            }
        }
        assert_eq!(BeamEnergy.to_string(), "BeamEnergy");

        let binned = dataset.bin_by(BeamEnergy, 2, (0.0, 3.0)).unwrap();

        assert_eq!(binned.n_bins(), 2);
        assert_eq!(binned.edges().len(), 3);
        assert_relative_eq!(binned.edges()[0], 0.0);
        assert_relative_eq!(binned.edges()[2], 3.0);
        assert_eq!(binned[0].n_events(), 1);
        assert_relative_eq!(binned[0].n_events_weighted(), 1.0);
        assert_eq!(binned[1].n_events(), 1);
        assert_relative_eq!(binned[1].n_events_weighted(), 2.0);
    }

    #[test]
    fn test_dataset_bootstrap() {
        let metadata = test_dataset().metadata_arc();
        let dataset = Dataset::new_with_metadata(
            vec![
                Arc::new(test_event()),
                Arc::new(EventData {
                    p4s: test_event().p4s.clone(),
                    aux: test_event().aux.clone(),
                    weight: 1.0,
                }),
            ],
            metadata,
        );
        assert_relative_ne!(dataset[0].weight, dataset[1].weight);

        let bootstrapped = dataset.bootstrap(43);
        assert_eq!(bootstrapped.n_events(), dataset.n_events());
        assert_relative_eq!(bootstrapped[0].weight, bootstrapped[1].weight);

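        // Bootstrapping an empty dataset should just produce another empty dataset.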
        let empty_dataset = Dataset::new(Vec::new());
        let empty_bootstrap = empty_dataset.bootstrap(43);
        assert_eq!(empty_bootstrap.n_events(), 0);
    }

    #[test]
    fn test_dataset_iteration_returns_events() {
        let dataset = test_dataset();
        let mut weights = Vec::new();
        for event in dataset.iter() {
            weights.push(event.weight());
        }
        assert_eq!(weights.len(), dataset.n_events());
        assert_relative_eq!(weights[0], dataset[0].weight);
    }

    #[test]
    fn test_dataset_into_iter_returns_events() {
        let dataset = test_dataset();
        let weights: Vec<f64> = dataset.into_iter().map(|event| event.weight()).collect();
        assert_eq!(weights.len(), 1);
        assert_relative_eq!(weights[0], test_event().weight);
    }

    #[test]
    fn test_event_display() {
        let event = test_event();
        let display_string = format!("{event}");
        assert!(display_string.contains("Event:"));
        assert!(display_string.contains("p4s:"));
        assert!(display_string.contains("aux:"));
        assert!(display_string.contains("aux[0]: 0.38562805"));
        assert!(display_string.contains("aux[1]: 0.05708078"));
        assert!(display_string.contains("weight:"));
    }

    #[test]
    fn test_name_based_access() {
        let metadata =
            Arc::new(DatasetMetadata::new(vec!["beam", "target"], vec!["pol_angle"]).unwrap());
        let event = Arc::new(EventData {
            p4s: vec![Vec4::new(0.0, 0.0, 1.0, 1.0), Vec4::new(0.1, 0.2, 0.3, 0.5)],
            aux: vec![0.42],
            weight: 1.0,
        });
        let dataset = Dataset::new_with_metadata(vec![event], metadata);
        let beam = dataset.p4_by_name(0, "beam").unwrap();
        assert_relative_eq!(beam.px(), 0.0);
        assert_relative_eq!(beam.py(), 0.0);
        assert_relative_eq!(beam.pz(), 1.0);
        assert_relative_eq!(beam.e(), 1.0);
        assert_relative_eq!(dataset.aux_by_name(0, "pol_angle").unwrap(), 0.42);
        assert!(dataset.p4_by_name(0, "missing").is_none());
        assert!(dataset.aux_by_name(0, "missing").is_none());
    }

    #[test]
    fn test_parquet_roundtrip_to_tempfile() {
        let dataset = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
        let dir = make_temp_dir();
        let path = dir.join("roundtrip.parquet");
        let path_str = path.to_str().expect("path should be valid UTF-8");

        dataset
            .to_parquet(path_str, &DatasetWriteOptions::default())
            .expect("writing parquet should succeed");
        let reopened = Dataset::from_parquet(path_str, &DatasetReadOptions::new())
            .expect("parquet roundtrip should reopen");

        assert_datasets_close(&dataset, &reopened, TEST_P4_NAMES, TEST_AUX_NAMES);
        fs::remove_dir_all(&dir).expect("temp dir cleanup should succeed");
    }

    #[test]
    fn test_root_roundtrip_to_tempfile() {
        let dataset = open_test_dataset("data_f32.parquet", DatasetReadOptions::new());
        let dir = make_temp_dir();
        let path = dir.join("roundtrip.root");
        let path_str = path.to_str().expect("path should be valid UTF-8");

        dataset
            .to_root(path_str, &DatasetWriteOptions::default())
            .expect("writing root should succeed");
        let reopened = Dataset::from_root(path_str, &DatasetReadOptions::new())
            .expect("root roundtrip should reopen");

        assert_datasets_close(&dataset, &reopened, TEST_P4_NAMES, TEST_AUX_NAMES);
        fs::remove_dir_all(&dir).expect("temp dir cleanup should succeed");
    }
    }
}