1use accurate::{sum::Klein, traits::*};
2use arrow::array::Float32Array;
3use arrow::record_batch::RecordBatch;
4use auto_ops::impl_op_ex;
5use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
6use serde::{Deserialize, Serialize};
7use std::ops::{Deref, DerefMut, Index, IndexMut};
8use std::path::Path;
9use std::sync::Arc;
10use std::{fmt::Display, fs::File};
11
12#[cfg(feature = "rayon")]
13use rayon::prelude::*;
14
15#[cfg(feature = "mpi")]
16use mpi::{datatype::PartitionMut, topology::SimpleCommunicator, traits::*};
17
18#[cfg(feature = "mpi")]
19use crate::mpi::LadduMPI;
20
21use crate::utils::get_bin_edges;
22use crate::{
23 utils::{
24 variables::Variable,
25 vectors::{Vec3, Vec4},
26 },
27 Float, LadduError,
28};
29
/// Column-name prefix for four-momentum components in Parquet files
/// (columns are named `p4_{i}_E`, `p4_{i}_Px`, `p4_{i}_Py`, `p4_{i}_Pz`).
const P4_PREFIX: &str = "p4_";
/// Column-name prefix for auxiliary vector components in Parquet files
/// (columns are named `aux_{i}_x`, `aux_{i}_y`, `aux_{i}_z`).
const AUX_PREFIX: &str = "aux_";
32
33pub fn test_event() -> Event {
37 use crate::utils::vectors::*;
38 Event {
39 p4s: vec![
40 Vec3::new(0.0, 0.0, 8.747).with_mass(0.0), Vec3::new(0.119, 0.374, 0.222).with_mass(1.007), Vec3::new(-0.112, 0.293, 3.081).with_mass(0.498), Vec3::new(-0.007, -0.667, 5.446).with_mass(0.498), ],
45 aux: vec![Vec3::new(0.385, 0.022, 0.000)],
46 weight: 0.48,
47 }
48}
49
50pub fn test_dataset() -> Dataset {
54 Dataset::new(vec![Arc::new(test_event())])
55}
56
/// A single experimental event: a set of particle four-momenta, auxiliary
/// three-vectors, and a statistical weight.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Event {
    /// Four-momenta of the particles in this event.
    pub p4s: Vec<Vec4>,
    /// Auxiliary three-vectors, one per configured aux column triple
    /// (presumably polarization-like quantities — TODO confirm semantics).
    pub aux: Vec<Vec3>,
    /// Statistical weight of this event.
    pub weight: Float,
}
67
impl Display for Event {
    // Multi-line, human-readable dump of the event. The exact output format
    // (including whitespace) is asserted byte-for-byte by `test_event_display`,
    // so do not change any literal here without updating that test.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Event:")?;
        writeln!(f, " p4s:")?;
        for p4 in &self.p4s {
            writeln!(f, " {}", p4.to_p4_string())?;
        }
        // NOTE(review): this section is labeled "eps" but prints the `aux`
        // field — presumably a historical name kept for output compatibility;
        // confirm before renaming either side.
        writeln!(f, " eps:")?;
        for eps_vec in &self.aux {
            writeln!(f, " [{}, {}, {}]", eps_vec.x, eps_vec.y, eps_vec.z)?;
        }
        writeln!(f, " weight:")?;
        writeln!(f, " {}", self.weight)?;
        Ok(())
    }
}
84
85impl Event {
86 pub fn get_p4_sum<T: AsRef<[usize]>>(&self, indices: T) -> Vec4 {
88 indices.as_ref().iter().map(|i| self.p4s[*i]).sum::<Vec4>()
89 }
90 pub fn boost_to_rest_frame_of<T: AsRef<[usize]>>(&self, indices: T) -> Self {
93 let frame = self.get_p4_sum(indices);
94 Event {
95 p4s: self
96 .p4s
97 .iter()
98 .map(|p4| p4.boost(&(-frame.beta())))
99 .collect(),
100 aux: self.aux.clone(),
101 weight: self.weight,
102 }
103 }
104}
105
/// A collection of [`Event`]s. Under the `mpi` feature, each process stores
/// only its own partition of the full event list; the `*_local`/`*_mpi`
/// method pairs on this type distinguish the two views.
#[derive(Debug, Clone, Default)]
pub struct Dataset {
    /// The events held by this process (`Arc`-shared so datasets can be
    /// cheaply filtered, binned, and bootstrapped without copying events).
    pub events: Vec<Arc<Event>>,
}
112
impl Dataset {
    /// Return a reference to the event at `index` within this process's local
    /// storage (panics if out of bounds).
    pub fn index_local(&self, index: usize) -> &Event {
        &self.events[index]
    }

    /// Map a *global* event `index` to `(owning_rank, local_index)` using the
    /// per-rank displacements `displs` (as produced by `get_counts_displs`).
    #[cfg(feature = "mpi")]
    fn get_rank_index(index: usize, displs: &[i32], world: &SimpleCommunicator) -> (i32, usize) {
        // The first displacement strictly greater than `index` belongs to the
        // rank *after* the owner, so step back one rank.
        for (i, &displ) in displs.iter().enumerate() {
            if displ as usize > index {
                return (i as i32 - 1, index - displs[i - 1] as usize);
            }
        }
        // `index` is at or past the last displacement: the final rank owns it.
        (
            world.size() - 1,
            index - displs[world.size() as usize - 1] as usize,
        )
    }

    /// Split `events` into one contiguous chunk per MPI rank, according to the
    /// world's counts/displacements for `events.len()` items.
    #[cfg(feature = "mpi")]
    fn partition(events: Vec<Arc<Event>>, world: &SimpleCommunicator) -> Vec<Vec<Arc<Event>>> {
        let (counts, displs) = world.get_counts_displs(events.len());
        counts
            .iter()
            .zip(displs.iter())
            .map(|(&count, &displ)| {
                events
                    .iter()
                    .skip(displ as usize)
                    .take(count as usize)
                    .cloned()
                    .collect()
            })
            .collect()
    }

    /// Fetch the event at global `index`: the owning rank serializes it with
    /// `bincode` and broadcasts it (length first, then payload) to all ranks.
    ///
    /// NOTE(review): the deserialized event is handed out via `Box::leak`, so
    /// every call on a non-owning rank leaks one `Event` allocation —
    /// presumably a deliberate trade-off to satisfy the `&Event` return type
    /// of the `Index` impl; confirm whether callers index repeatedly in loops.
    #[cfg(feature = "mpi")]
    pub fn index_mpi(&self, index: usize, world: &SimpleCommunicator) -> &Event {
        let (_, displs) = world.get_counts_displs(self.n_events());
        let (owning_rank, local_index) = Dataset::get_rank_index(index, &displs, world);
        let mut serialized_event_buffer_len: usize = 0;
        let mut serialized_event_buffer: Vec<u8> = Vec::default();
        let config = bincode::config::standard();
        if world.rank() == owning_rank {
            let event = self.index_local(local_index);
            serialized_event_buffer = bincode::serde::encode_to_vec(event, config).unwrap();
            serialized_event_buffer_len = serialized_event_buffer.len();
        }
        world
            .process_at_rank(owning_rank)
            .broadcast_into(&mut serialized_event_buffer_len);
        // Non-owners size their receive buffer from the broadcast length.
        if world.rank() != owning_rank {
            serialized_event_buffer = vec![0; serialized_event_buffer_len];
        }
        world
            .process_at_rank(owning_rank)
            .broadcast_into(&mut serialized_event_buffer);
        let (event, _): (Event, usize) =
            bincode::serde::decode_from_slice(&serialized_event_buffer[..], config).unwrap();
        Box::leak(Box::new(event))
    }
}
200
201impl Index<usize> for Dataset {
202 type Output = Event;
203
204 fn index(&self, index: usize) -> &Self::Output {
205 #[cfg(feature = "mpi")]
206 {
207 if let Some(world) = crate::mpi::get_world() {
208 return self.index_mpi(index, &world);
209 }
210 }
211 self.index_local(index)
212 }
213}
214
215impl Dataset {
216 pub fn new_local(events: Vec<Arc<Event>>) -> Self {
223 Dataset { events }
224 }
225
226 #[cfg(feature = "mpi")]
233 pub fn new_mpi(events: Vec<Arc<Event>>, world: &SimpleCommunicator) -> Self {
234 Dataset {
235 events: Dataset::partition(events, world)[world.rank() as usize].clone(),
236 }
237 }
238
239 pub fn new(events: Vec<Arc<Event>>) -> Self {
245 #[cfg(feature = "mpi")]
246 {
247 if let Some(world) = crate::mpi::get_world() {
248 return Dataset::new_mpi(events, &world);
249 }
250 }
251 Dataset::new_local(events)
252 }
253
254 pub fn n_events_local(&self) -> usize {
261 self.events.len()
262 }
263
264 #[cfg(feature = "mpi")]
271 pub fn n_events_mpi(&self, world: &SimpleCommunicator) -> usize {
272 let mut n_events_partitioned: Vec<usize> = vec![0; world.size() as usize];
273 let n_events_local = self.n_events_local();
274 world.all_gather_into(&n_events_local, &mut n_events_partitioned);
275 n_events_partitioned.iter().sum()
276 }
277
278 pub fn n_events(&self) -> usize {
280 #[cfg(feature = "mpi")]
281 {
282 if let Some(world) = crate::mpi::get_world() {
283 return self.n_events_mpi(&world);
284 }
285 }
286 self.n_events_local()
287 }
288}
289
290impl Dataset {
291 pub fn weights_local(&self) -> Vec<Float> {
298 #[cfg(feature = "rayon")]
299 return self.events.par_iter().map(|e| e.weight).collect();
300 #[cfg(not(feature = "rayon"))]
301 return self.events.iter().map(|e| e.weight).collect();
302 }
303
304 #[cfg(feature = "mpi")]
311 pub fn weights_mpi(&self, world: &SimpleCommunicator) -> Vec<Float> {
312 let local_weights = self.weights_local();
313 let n_events = self.n_events();
314 let mut buffer: Vec<Float> = vec![0.0; n_events];
315 let (counts, displs) = world.get_counts_displs(n_events);
316 {
317 let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
318 world.all_gather_varcount_into(&local_weights, &mut partitioned_buffer);
319 }
320 buffer
321 }
322
323 pub fn weights(&self) -> Vec<Float> {
325 #[cfg(feature = "mpi")]
326 {
327 if let Some(world) = crate::mpi::get_world() {
328 return self.weights_mpi(&world);
329 }
330 }
331 self.weights_local()
332 }
333
334 pub fn n_events_weighted_local(&self) -> Float {
341 #[cfg(feature = "rayon")]
342 return self
343 .events
344 .par_iter()
345 .map(|e| e.weight)
346 .parallel_sum_with_accumulator::<Klein<Float>>();
347 #[cfg(not(feature = "rayon"))]
348 return self.events.iter().map(|e| e.weight).sum();
349 }
350 #[cfg(feature = "mpi")]
357 pub fn n_events_weighted_mpi(&self, world: &SimpleCommunicator) -> Float {
358 let mut n_events_weighted_partitioned: Vec<Float> = vec![0.0; world.size() as usize];
359 let n_events_weighted_local = self.n_events_weighted_local();
360 world.all_gather_into(&n_events_weighted_local, &mut n_events_weighted_partitioned);
361 #[cfg(feature = "rayon")]
362 return n_events_weighted_partitioned
363 .into_par_iter()
364 .parallel_sum_with_accumulator::<Klein<Float>>();
365 #[cfg(not(feature = "rayon"))]
366 return n_events_weighted_partitioned.iter().sum();
367 }
368
369 pub fn n_events_weighted(&self) -> Float {
371 #[cfg(feature = "mpi")]
372 {
373 if let Some(world) = crate::mpi::get_world() {
374 return self.n_events_weighted_mpi(&world);
375 }
376 }
377 self.n_events_weighted_local()
378 }
379
380 pub fn bootstrap_local(&self, seed: usize) -> Arc<Dataset> {
388 let mut rng = fastrand::Rng::with_seed(seed as u64);
389 let mut indices: Vec<usize> = (0..self.n_events())
390 .map(|_| rng.usize(0..self.n_events()))
391 .collect::<Vec<usize>>();
392 indices.sort();
393 #[cfg(feature = "rayon")]
394 let bootstrapped_events: Vec<Arc<Event>> = indices
395 .into_par_iter()
396 .map(|idx| self.events[idx].clone())
397 .collect();
398 #[cfg(not(feature = "rayon"))]
399 let bootstrapped_events: Vec<Arc<Event>> = indices
400 .into_iter()
401 .map(|idx| self.events[idx].clone())
402 .collect();
403 Arc::new(Dataset {
404 events: bootstrapped_events,
405 })
406 }
407
408 #[cfg(feature = "mpi")]
416 pub fn bootstrap_mpi(&self, seed: usize, world: &SimpleCommunicator) -> Arc<Dataset> {
417 let n_events = self.n_events();
418 let mut indices: Vec<usize> = vec![0; n_events];
419 if world.is_root() {
420 let mut rng = fastrand::Rng::with_seed(seed as u64);
421 indices = (0..n_events)
422 .map(|_| rng.usize(0..n_events))
423 .collect::<Vec<usize>>();
424 indices.sort();
425 }
426 world.process_at_root().broadcast_into(&mut indices);
427 let (_, displs) = world.get_counts_displs(self.n_events());
428 let local_indices: Vec<usize> = indices
429 .into_iter()
430 .filter_map(|idx| {
431 let (owning_rank, local_index) = Dataset::get_rank_index(idx, &displs, world);
432 if world.rank() == owning_rank {
433 Some(local_index)
434 } else {
435 None
436 }
437 })
438 .collect();
439 #[cfg(feature = "rayon")]
442 let bootstrapped_events: Vec<Arc<Event>> = local_indices
443 .into_par_iter()
444 .map(|idx| self.events[idx].clone())
445 .collect();
446 #[cfg(not(feature = "rayon"))]
447 let bootstrapped_events: Vec<Arc<Event>> = local_indices
448 .into_iter()
449 .map(|idx| self.events[idx].clone())
450 .collect();
451 Arc::new(Dataset {
452 events: bootstrapped_events,
453 })
454 }
455
456 pub fn bootstrap(&self, seed: usize) -> Arc<Dataset> {
459 #[cfg(feature = "mpi")]
460 {
461 if let Some(world) = crate::mpi::get_world() {
462 return self.bootstrap_mpi(seed, &world);
463 }
464 }
465 self.bootstrap_local(seed)
466 }
467
468 pub fn filter<P>(&self, predicate: P) -> Arc<Dataset>
471 where
472 P: Fn(&Event) -> bool + Send + Sync,
473 {
474 #[cfg(feature = "rayon")]
475 let filtered_events = self
476 .events
477 .par_iter()
478 .filter(|e| predicate(e))
479 .cloned()
480 .collect();
481 #[cfg(not(feature = "rayon"))]
482 let filtered_events = self
483 .events
484 .iter()
485 .filter(|e| predicate(e))
486 .cloned()
487 .collect();
488 Arc::new(Dataset {
489 events: filtered_events,
490 })
491 }
492
493 pub fn bin_by<V>(&self, variable: V, bins: usize, range: (Float, Float)) -> BinnedDataset
496 where
497 V: Variable,
498 {
499 let bin_width = (range.1 - range.0) / bins as Float;
500 let bin_edges = get_bin_edges(bins, range);
501 #[cfg(feature = "rayon")]
502 let evaluated: Vec<(usize, &Arc<Event>)> = self
503 .events
504 .par_iter()
505 .filter_map(|event| {
506 let value = variable.value(event.as_ref());
507 if value >= range.0 && value < range.1 {
508 let bin_index = ((value - range.0) / bin_width) as usize;
509 let bin_index = bin_index.min(bins - 1);
510 Some((bin_index, event))
511 } else {
512 None
513 }
514 })
515 .collect();
516 #[cfg(not(feature = "rayon"))]
517 let evaluated: Vec<(usize, &Arc<Event>)> = self
518 .events
519 .iter()
520 .filter_map(|event| {
521 let value = variable.value(event.as_ref());
522 if value >= range.0 && value < range.1 {
523 let bin_index = ((value - range.0) / bin_width) as usize;
524 let bin_index = bin_index.min(bins - 1);
525 Some((bin_index, event))
526 } else {
527 None
528 }
529 })
530 .collect();
531 let mut binned_events: Vec<Vec<Arc<Event>>> = vec![Vec::default(); bins];
532 for (bin_index, event) in evaluated {
533 binned_events[bin_index].push(event.clone());
534 }
535 BinnedDataset {
536 #[cfg(feature = "rayon")]
537 datasets: binned_events
538 .into_par_iter()
539 .map(|events| Arc::new(Dataset { events }))
540 .collect(),
541 #[cfg(not(feature = "rayon"))]
542 datasets: binned_events
543 .into_iter()
544 .map(|events| Arc::new(Dataset { events }))
545 .collect(),
546 edges: bin_edges,
547 }
548 }
549
550 pub fn boost_to_rest_frame_of<T: AsRef<[usize]> + Sync>(&self, indices: T) -> Arc<Dataset> {
553 #[cfg(feature = "rayon")]
554 {
555 Arc::new(Dataset {
556 events: self
557 .events
558 .par_iter()
559 .map(|event| Arc::new(event.boost_to_rest_frame_of(indices.as_ref())))
560 .collect(),
561 })
562 }
563 #[cfg(not(feature = "rayon"))]
564 {
565 Arc::new(Dataset {
566 events: self
567 .events
568 .iter()
569 .map(|event| Arc::new(event.boost_to_rest_frame_of(indices.as_ref())))
570 .collect(),
571 })
572 }
573 }
574}
575
// `+` concatenates two datasets (by value or reference): the event `Arc`s of
// `a` are followed by those of `b`; only pointers are cloned, not the events.
impl_op_ex!(+ |a: &Dataset, b: &Dataset| -> Dataset { Dataset { events: a.events.iter().chain(b.events.iter()).cloned().collect() }});
577
578fn batch_to_event(batch: &RecordBatch, row: usize) -> Event {
579 let mut p4s = Vec::new();
580 let mut aux = Vec::new();
581
582 let p4_count = batch
583 .schema()
584 .fields()
585 .iter()
586 .filter(|field| field.name().starts_with(P4_PREFIX))
587 .count()
588 / 4;
589 let aux_count = batch
590 .schema()
591 .fields()
592 .iter()
593 .filter(|field| field.name().starts_with(AUX_PREFIX))
594 .count()
595 / 3;
596
597 for i in 0..p4_count {
598 let e = batch
599 .column_by_name(&format!("{}{}_E", P4_PREFIX, i))
600 .unwrap()
601 .as_any()
602 .downcast_ref::<Float32Array>()
603 .unwrap()
604 .value(row) as Float;
605 let px = batch
606 .column_by_name(&format!("{}{}_Px", P4_PREFIX, i))
607 .unwrap()
608 .as_any()
609 .downcast_ref::<Float32Array>()
610 .unwrap()
611 .value(row) as Float;
612 let py = batch
613 .column_by_name(&format!("{}{}_Py", P4_PREFIX, i))
614 .unwrap()
615 .as_any()
616 .downcast_ref::<Float32Array>()
617 .unwrap()
618 .value(row) as Float;
619 let pz = batch
620 .column_by_name(&format!("{}{}_Pz", P4_PREFIX, i))
621 .unwrap()
622 .as_any()
623 .downcast_ref::<Float32Array>()
624 .unwrap()
625 .value(row) as Float;
626 p4s.push(Vec4::new(px, py, pz, e));
627 }
628
629 for i in 0..aux_count {
631 let x = batch
632 .column_by_name(&format!("{}{}_x", AUX_PREFIX, i))
633 .unwrap()
634 .as_any()
635 .downcast_ref::<Float32Array>()
636 .unwrap()
637 .value(row) as Float;
638 let y = batch
639 .column_by_name(&format!("{}{}_y", AUX_PREFIX, i))
640 .unwrap()
641 .as_any()
642 .downcast_ref::<Float32Array>()
643 .unwrap()
644 .value(row) as Float;
645 let z = batch
646 .column_by_name(&format!("{}{}_z", AUX_PREFIX, i))
647 .unwrap()
648 .as_any()
649 .downcast_ref::<Float32Array>()
650 .unwrap()
651 .value(row) as Float;
652 aux.push(Vec3::new(x, y, z));
653 }
654
655 let weight = batch
656 .column(19)
657 .as_any()
658 .downcast_ref::<Float32Array>()
659 .unwrap()
660 .value(row) as Float;
661
662 Event { p4s, aux, weight }
663}
664
665pub fn open<T: AsRef<str>>(file_path: T) -> Result<Arc<Dataset>, LadduError> {
667 let file_path = Path::new(&*shellexpand::full(file_path.as_ref())?).canonicalize()?;
669 let file = File::open(file_path)?;
670 let builder = ParquetRecordBatchReaderBuilder::try_new(file)?;
671 let reader = builder.build()?;
672 let batches: Vec<RecordBatch> = reader.collect::<Result<Vec<_>, _>>()?;
673
674 #[cfg(feature = "rayon")]
675 let events: Vec<Arc<Event>> = batches
676 .into_par_iter()
677 .flat_map(|batch| {
678 let num_rows = batch.num_rows();
679 let mut local_events = Vec::with_capacity(num_rows);
680
681 for row in 0..num_rows {
683 let event = batch_to_event(&batch, row);
684 local_events.push(Arc::new(event));
685 }
686 local_events
687 })
688 .collect();
689 #[cfg(not(feature = "rayon"))]
690 let events: Vec<Arc<Event>> = batches
691 .into_iter()
692 .flat_map(|batch| {
693 let num_rows = batch.num_rows();
694 let mut local_events = Vec::with_capacity(num_rows);
695
696 for row in 0..num_rows {
698 let event = batch_to_event(&batch, row);
699 local_events.push(Arc::new(event));
700 }
701 local_events
702 })
703 .collect();
704 Ok(Arc::new(Dataset::new(events)))
705}
706
707pub fn open_boosted_to_rest_frame_of<T: AsRef<str>, I: AsRef<[usize]> + Sync>(
710 file_path: T,
711 indices: I,
712) -> Result<Arc<Dataset>, LadduError> {
713 let file_path = Path::new(&*shellexpand::full(file_path.as_ref())?).canonicalize()?;
715 let file = File::open(file_path)?;
716 let builder = ParquetRecordBatchReaderBuilder::try_new(file)?;
717 let reader = builder.build()?;
718 let batches: Vec<RecordBatch> = reader.collect::<Result<Vec<_>, _>>()?;
719
720 #[cfg(feature = "rayon")]
721 let events: Vec<Arc<Event>> = batches
722 .into_par_iter()
723 .flat_map(|batch| {
724 let num_rows = batch.num_rows();
725 let mut local_events = Vec::with_capacity(num_rows);
726
727 for row in 0..num_rows {
729 let mut event = batch_to_event(&batch, row);
730 event = event.boost_to_rest_frame_of(indices.as_ref());
731 local_events.push(Arc::new(event));
732 }
733 local_events
734 })
735 .collect();
736 #[cfg(not(feature = "rayon"))]
737 let events: Vec<Arc<Event>> = batches
738 .into_iter()
739 .flat_map(|batch| {
740 let num_rows = batch.num_rows();
741 let mut local_events = Vec::with_capacity(num_rows);
742
743 for row in 0..num_rows {
745 let mut event = batch_to_event(&batch, row);
746 event = event.boost_to_rest_frame_of(indices.as_ref());
747 local_events.push(Arc::new(event));
748 }
749 local_events
750 })
751 .collect();
752 Ok(Arc::new(Dataset::new(events)))
753}
754
/// A [`Dataset`] split into bins of some [`Variable`], as produced by
/// [`Dataset::bin_by`]. Dereferences to the underlying `Vec<Arc<Dataset>>`.
pub struct BinnedDataset {
    // One dataset per bin, in ascending bin order.
    datasets: Vec<Arc<Dataset>>,
    // Bin edges; one more entry than there are bins.
    edges: Vec<Float>,
}
760
761impl Index<usize> for BinnedDataset {
762 type Output = Arc<Dataset>;
763
764 fn index(&self, index: usize) -> &Self::Output {
765 &self.datasets[index]
766 }
767}
768
769impl IndexMut<usize> for BinnedDataset {
770 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
771 &mut self.datasets[index]
772 }
773}
774
// Expose the bin vector's read-only API (`len`, `iter`, …) directly on
// `BinnedDataset`.
impl Deref for BinnedDataset {
    type Target = Vec<Arc<Dataset>>;

    fn deref(&self) -> &Self::Target {
        &self.datasets
    }
}
782
// Mutable counterpart of the `Deref` impl above; note this allows resizing
// `datasets` independently of `edges`, so callers must keep them consistent.
impl DerefMut for BinnedDataset {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.datasets
    }
}
788
789impl BinnedDataset {
790 pub fn n_bins(&self) -> usize {
792 self.datasets.len()
793 }
794
795 pub fn edges(&self) -> Vec<Float> {
797 self.edges.clone()
798 }
799
800 pub fn range(&self) -> (Float, Float) {
802 (self.edges[0], self.edges[self.n_bins()])
803 }
804}
805
#[cfg(test)]
mod tests {
    use super::*;
    use approx::{assert_relative_eq, assert_relative_ne};
    use serde::{Deserialize, Serialize};
    // The reference event has four four-momenta, one aux vector, weight 0.48.
    #[test]
    fn test_event_creation() {
        let event = test_event();
        assert_eq!(event.p4s.len(), 4);
        assert_eq!(event.aux.len(), 1);
        assert_relative_eq!(event.weight, 0.48)
    }

    // `get_p4_sum` must add the selected four-momenta component-wise.
    #[test]
    fn test_event_p4_sum() {
        let event = test_event();
        let sum = event.get_p4_sum([2, 3]);
        assert_relative_eq!(sum.px(), event.p4s[2].px() + event.p4s[3].px());
        assert_relative_eq!(sum.py(), event.p4s[2].py() + event.p4s[3].py());
        assert_relative_eq!(sum.pz(), event.p4s[2].pz() + event.p4s[3].pz());
        assert_relative_eq!(sum.e(), event.p4s[2].e() + event.p4s[3].e());
    }

    // After boosting to the rest frame of particles 1-3, their total
    // three-momentum must vanish (to within float tolerance).
    #[test]
    fn test_event_boost() {
        let event = test_event();
        let event_boosted = event.boost_to_rest_frame_of([1, 2, 3]);
        let p4_sum = event_boosted.get_p4_sum([1, 2, 3]);
        assert_relative_eq!(p4_sum.px(), 0.0, epsilon = Float::EPSILON.sqrt());
        assert_relative_eq!(p4_sum.py(), 0.0, epsilon = Float::EPSILON.sqrt());
        assert_relative_eq!(p4_sum.pz(), 0.0, epsilon = Float::EPSILON.sqrt());
    }

    #[test]
    fn test_dataset_size_check() {
        let mut dataset = Dataset::default();
        assert_eq!(dataset.n_events(), 0);
        dataset.events.push(Arc::new(test_event()));
        assert_eq!(dataset.n_events(), 1);
    }

    // `+` concatenates datasets in order: lhs events first, then rhs events.
    #[test]
    fn test_dataset_sum() {
        let dataset = test_dataset();
        let dataset2 = Dataset::new(vec![Arc::new(Event {
            p4s: test_event().p4s,
            aux: test_event().aux,
            weight: 0.52,
        })]);
        let dataset_sum = &dataset + &dataset2;
        assert_eq!(dataset_sum[0].weight, dataset[0].weight);
        assert_eq!(dataset_sum[1].weight, dataset2[0].weight);
    }

    // `weights` preserves event order; the weighted count is their sum
    // (0.48 + 0.52 = 1.0).
    #[test]
    fn test_dataset_weights() {
        let mut dataset = Dataset::default();
        dataset.events.push(Arc::new(test_event()));
        dataset.events.push(Arc::new(Event {
            p4s: test_event().p4s,
            aux: test_event().aux,
            weight: 0.52,
        }));
        let weights = dataset.weights();
        assert_eq!(weights.len(), 2);
        assert_relative_eq!(weights[0], 0.48);
        assert_relative_eq!(weights[1], 0.52);
        assert_relative_eq!(dataset.n_events_weighted(), 1.0);
    }

    // `filter` keeps only events satisfying the predicate (here: exactly two
    // four-momenta) and shares the surviving events by `Arc`.
    #[test]
    fn test_dataset_filtering() {
        let mut dataset = test_dataset();
        dataset.events.push(Arc::new(Event {
            p4s: vec![
                Vec3::new(0.0, 0.0, 5.0).with_mass(0.0),
                Vec3::new(0.0, 0.0, 1.0).with_mass(1.0),
            ],
            aux: vec![],
            weight: 1.0,
        }));

        let filtered = dataset.filter(|event| event.p4s.len() == 2);
        assert_eq!(filtered.n_events(), 1);
        assert_eq!(filtered[0].p4s.len(), 2);
    }

    // Dataset-level boost applies the event-level boost to every event.
    #[test]
    fn test_dataset_boost() {
        let dataset = test_dataset();
        let dataset_boosted = dataset.boost_to_rest_frame_of([1, 2, 3]);
        let p4_sum = dataset_boosted[0].get_p4_sum([1, 2, 3]);
        assert_relative_eq!(p4_sum.px(), 0.0, epsilon = Float::EPSILON.sqrt());
        assert_relative_eq!(p4_sum.py(), 0.0, epsilon = Float::EPSILON.sqrt());
        assert_relative_eq!(p4_sum.pz(), 0.0, epsilon = Float::EPSILON.sqrt());
    }

    // Two bins over [0, 3): the 1 GeV event lands in bin 0, the 2 GeV event
    // in bin 1; edges are [0.0, 1.5, 3.0].
    #[test]
    fn test_binned_dataset() {
        let dataset = Dataset::new(vec![
            Arc::new(Event {
                p4s: vec![Vec3::new(0.0, 0.0, 1.0).with_mass(1.0)],
                aux: vec![],
                weight: 1.0,
            }),
            Arc::new(Event {
                p4s: vec![Vec3::new(0.0, 0.0, 2.0).with_mass(2.0)],
                aux: vec![],
                weight: 2.0,
            }),
        ]);

        // Minimal `Variable` implementation: the energy of the first p4.
        #[derive(Clone, Serialize, Deserialize, Debug)]
        struct BeamEnergy;
        impl Display for BeamEnergy {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, "BeamEnergy")
            }
        }
        #[typetag::serde]
        impl Variable for BeamEnergy {
            fn value(&self, event: &Event) -> Float {
                event.p4s[0].e()
            }
        }
        assert_eq!(BeamEnergy.to_string(), "BeamEnergy");

        let binned = dataset.bin_by(BeamEnergy, 2, (0.0, 3.0));

        assert_eq!(binned.n_bins(), 2);
        assert_eq!(binned.edges().len(), 3);
        assert_relative_eq!(binned.edges()[0], 0.0);
        assert_relative_eq!(binned.edges()[2], 3.0);
        assert_eq!(binned[0].n_events(), 1);
        assert_relative_eq!(binned[0].n_events_weighted(), 1.0);
        assert_eq!(binned[1].n_events(), 1);
        assert_relative_eq!(binned[1].n_events_weighted(), 2.0);
    }

    // With seed 43 on a two-event dataset, the bootstrap draws the same event
    // twice — the equal-weight assertion below depends on that seed.
    #[test]
    fn test_dataset_bootstrap() {
        let mut dataset = test_dataset();
        dataset.events.push(Arc::new(Event {
            p4s: test_event().p4s.clone(),
            aux: test_event().aux.clone(),
            weight: 1.0,
        }));
        assert_relative_ne!(dataset[0].weight, dataset[1].weight);

        let bootstrapped = dataset.bootstrap(43);
        assert_eq!(bootstrapped.n_events(), dataset.n_events());
        assert_relative_eq!(bootstrapped[0].weight, bootstrapped[1].weight);

        // Bootstrapping an empty dataset must not panic and stays empty.
        let empty_dataset = Dataset::default();
        let empty_bootstrap = empty_dataset.bootstrap(43);
        assert_eq!(empty_bootstrap.n_events(), 0);
    }
    // Pins the exact `Display` output of the reference event byte-for-byte.
    #[test]
    fn test_event_display() {
        let event = test_event();
        let display_string = format!("{}", event);
        assert_eq!(
            display_string,
            "Event:\n p4s:\n [e = 8.74700; p = (0.00000, 0.00000, 8.74700); m = 0.00000]\n [e = 1.10334; p = (0.11900, 0.37400, 0.22200); m = 1.00700]\n [e = 3.13671; p = (-0.11200, 0.29300, 3.08100); m = 0.49800]\n [e = 5.50925; p = (-0.00700, -0.66700, 5.44600); m = 0.49800]\n eps:\n [0.385, 0.022, 0]\n weight:\n 0.48\n"
        );
    }
}