1use accurate::{sum::Klein, traits::*};
2use arrow::array::Float32Array;
3use arrow::record_batch::RecordBatch;
4use auto_ops::impl_op_ex;
5use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
6use serde::{Deserialize, Serialize};
7use std::ops::{Deref, DerefMut, Index, IndexMut};
8use std::path::Path;
9use std::sync::Arc;
10use std::{fmt::Display, fs::File};
11
12#[cfg(feature = "rayon")]
13use rayon::prelude::*;
14
15#[cfg(feature = "mpi")]
16use mpi::{datatype::PartitionMut, topology::SimpleCommunicator, traits::*};
17
18#[cfg(feature = "mpi")]
19use crate::mpi::LadduMPI;
20
21use crate::utils::get_bin_edges;
22use crate::{
23 utils::{
24 variables::{Variable, VariableExpression},
25 vectors::{Vec3, Vec4},
26 },
27 Float, LadduError,
28};
29
/// Column-name prefix for four-momentum components in Parquet files
/// (columns are named `p4_{i}_E`, `p4_{i}_Px`, `p4_{i}_Py`, `p4_{i}_Pz`).
const P4_PREFIX: &str = "p4_";
/// Column-name prefix for auxiliary vector components in Parquet files
/// (columns are named `aux_{i}_x`, `aux_{i}_y`, `aux_{i}_z`).
const AUX_PREFIX: &str = "aux_";
32
33pub fn test_event() -> Event {
37 use crate::utils::vectors::*;
38 Event {
39 p4s: vec![
40 Vec3::new(0.0, 0.0, 8.747).with_mass(0.0), Vec3::new(0.119, 0.374, 0.222).with_mass(1.007), Vec3::new(-0.112, 0.293, 3.081).with_mass(0.498), Vec3::new(-0.007, -0.667, 5.446).with_mass(0.498), ],
45 aux: vec![Vec3::new(0.385, 0.022, 0.000)],
46 weight: 0.48,
47 }
48}
49
50pub fn test_dataset() -> Dataset {
54 Dataset::new(vec![Arc::new(test_event())])
55}
56
/// A single scattering event: the particle four-momenta, any auxiliary
/// three-vectors, and a statistical weight.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Event {
    /// Four-momenta of the particles in this event.
    pub p4s: Vec<Vec4>,
    /// Auxiliary three-vectors (displayed under the label "eps" —
    /// presumably polarization-like data; confirm semantics with callers).
    pub aux: Vec<Vec3>,
    /// Statistical weight assigned to this event.
    pub weight: Float,
}
67
impl Display for Event {
    /// Multi-line human-readable dump of the event.
    ///
    /// NOTE(review): the `aux` field is printed under the historical label
    /// "eps:"; `test_event_display` pins this exact output, so the label is
    /// intentionally left unchanged.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Event:")?;
        writeln!(f, " p4s:")?;
        for p4 in &self.p4s {
            writeln!(f, " {}", p4.to_p4_string())?;
        }
        writeln!(f, " eps:")?;
        for eps_vec in &self.aux {
            writeln!(f, " [{}, {}, {}]", eps_vec.x, eps_vec.y, eps_vec.z)?;
        }
        writeln!(f, " weight:")?;
        writeln!(f, " {}", self.weight)?;
        Ok(())
    }
}
84
85impl Event {
86 pub fn get_p4_sum<T: AsRef<[usize]>>(&self, indices: T) -> Vec4 {
88 indices.as_ref().iter().map(|i| self.p4s[*i]).sum::<Vec4>()
89 }
90 pub fn boost_to_rest_frame_of<T: AsRef<[usize]>>(&self, indices: T) -> Self {
93 let frame = self.get_p4_sum(indices);
94 Event {
95 p4s: self
96 .p4s
97 .iter()
98 .map(|p4| p4.boost(&(-frame.beta())))
99 .collect(),
100 aux: self.aux.clone(),
101 weight: self.weight,
102 }
103 }
104 pub fn evaluate<V: Variable>(&self, variable: &V) -> Float {
106 variable.value(self)
107 }
108}
109
/// A collection of [`Event`]s. Under MPI (feature `mpi`), each rank holds only
/// its contiguous slice of the global event list; the `*_local` / `*_mpi`
/// method pairs on this type account for that.
#[derive(Debug, Clone, Default)]
pub struct Dataset {
    /// The events stored on this process (`Arc`-shared so clones are cheap).
    pub events: Vec<Arc<Event>>,
}
116
impl Dataset {
    /// Index into the events stored on this process.
    ///
    /// # Panics
    /// Panics if `index` is out of bounds for the local event vector.
    pub fn index_local(&self, index: usize) -> &Event {
        &self.events[index]
    }

    /// Map a global event `index` to `(owning rank, index local to that rank)`
    /// given the per-rank displacements `displs` (rank i owns the half-open
    /// range starting at `displs[i]`).
    #[cfg(feature = "mpi")]
    fn get_rank_index(index: usize, displs: &[i32], world: &SimpleCommunicator) -> (i32, usize) {
        for (i, &displ) in displs.iter().enumerate() {
            // First displacement past `index` means the previous rank owns it.
            if displ as usize > index {
                return (i as i32 - 1, index - displs[i - 1] as usize);
            }
        }
        // Not before any displacement: the last rank owns it.
        (
            world.size() - 1,
            index - displs[world.size() as usize - 1] as usize,
        )
    }

    /// Split `events` into one contiguous chunk per rank according to the
    /// counts/displacements computed for this communicator.
    #[cfg(feature = "mpi")]
    fn partition(events: Vec<Arc<Event>>, world: &SimpleCommunicator) -> Vec<Vec<Arc<Event>>> {
        let (counts, displs) = world.get_counts_displs(events.len());
        counts
            .iter()
            .zip(displs.iter())
            .map(|(&count, &displ)| {
                events
                    .iter()
                    .skip(displ as usize)
                    .take(count as usize)
                    .cloned()
                    .collect()
            })
            .collect()
    }

    /// Global indexing under MPI: the owning rank serializes the event with
    /// `bincode`, broadcasts its length and then its bytes, and every rank
    /// deserializes a copy.
    ///
    /// NOTE(review): `Box::leak` intentionally leaks one `Event` per call so a
    /// `&Event` can be returned from locally-decoded data; heavy per-event
    /// indexing under MPI will therefore accumulate memory. Confirm this
    /// trade-off is acceptable for callers.
    #[cfg(feature = "mpi")]
    pub fn index_mpi(&self, index: usize, world: &SimpleCommunicator) -> &Event {
        let (_, displs) = world.get_counts_displs(self.n_events());
        let (owning_rank, local_index) = Dataset::get_rank_index(index, &displs, world);
        let mut serialized_event_buffer_len: usize = 0;
        let mut serialized_event_buffer: Vec<u8> = Vec::default();
        let config = bincode::config::standard();
        if world.rank() == owning_rank {
            let event = self.index_local(local_index);
            serialized_event_buffer = bincode::serde::encode_to_vec(event, config).unwrap();
            serialized_event_buffer_len = serialized_event_buffer.len();
        }
        world
            .process_at_rank(owning_rank)
            .broadcast_into(&mut serialized_event_buffer_len);
        // Non-owning ranks size their receive buffer to match the broadcast length.
        if world.rank() != owning_rank {
            serialized_event_buffer = vec![0; serialized_event_buffer_len];
        }
        world
            .process_at_rank(owning_rank)
            .broadcast_into(&mut serialized_event_buffer);
        let (event, _): (Event, usize) =
            bincode::serde::decode_from_slice(&serialized_event_buffer[..], config).unwrap();
        Box::leak(Box::new(event))
    }
}
204
impl Index<usize> for Dataset {
    type Output = Event;

    /// Global event lookup: when an MPI world is active, `index` is resolved
    /// across ranks via [`Dataset::index_mpi`]; otherwise it indexes the
    /// local event vector directly.
    fn index(&self, index: usize) -> &Self::Output {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.index_mpi(index, &world);
            }
        }
        self.index_local(index)
    }
}
218
impl Dataset {
    /// Construct a dataset holding all of `events` on this process.
    pub fn new_local(events: Vec<Arc<Event>>) -> Self {
        Dataset { events }
    }

    /// Construct a dataset under MPI: `events` is partitioned contiguously
    /// across ranks and this rank keeps only its own slice.
    #[cfg(feature = "mpi")]
    pub fn new_mpi(events: Vec<Arc<Event>>, world: &SimpleCommunicator) -> Self {
        Dataset {
            // NOTE(review): every rank builds all partitions and then clones
            // only its own (the `Arc`s are cheap to clone, but the other
            // partition vectors are discarded).
            events: Dataset::partition(events, world)[world.rank() as usize].clone(),
        }
    }

    /// Construct a dataset, dispatching to the MPI path when a world is active.
    pub fn new(events: Vec<Arc<Event>>) -> Self {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return Dataset::new_mpi(events, &world);
            }
        }
        Dataset::new_local(events)
    }

    /// Number of events stored on this process.
    pub fn n_events_local(&self) -> usize {
        self.events.len()
    }

    /// Total number of events across all MPI ranks (all-gather of local counts).
    #[cfg(feature = "mpi")]
    pub fn n_events_mpi(&self, world: &SimpleCommunicator) -> usize {
        let mut n_events_partitioned: Vec<usize> = vec![0; world.size() as usize];
        let n_events_local = self.n_events_local();
        world.all_gather_into(&n_events_local, &mut n_events_partitioned);
        n_events_partitioned.iter().sum()
    }

    /// Total number of events (MPI-aware dispatcher).
    pub fn n_events(&self) -> usize {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.n_events_mpi(&world);
            }
        }
        self.n_events_local()
    }
}
293
impl Dataset {
    /// Weights of the events stored on this process (parallel under `rayon`).
    pub fn weights_local(&self) -> Vec<Float> {
        #[cfg(feature = "rayon")]
        return self.events.par_iter().map(|e| e.weight).collect();
        #[cfg(not(feature = "rayon"))]
        return self.events.iter().map(|e| e.weight).collect();
    }

    /// Weights of every event across all MPI ranks, gathered onto every rank
    /// in global order via a variable-count all-gather.
    #[cfg(feature = "mpi")]
    pub fn weights_mpi(&self, world: &SimpleCommunicator) -> Vec<Float> {
        let local_weights = self.weights_local();
        let n_events = self.n_events();
        let mut buffer: Vec<Float> = vec![0.0; n_events];
        let (counts, displs) = world.get_counts_displs(n_events);
        {
            // Scope the partitioned view so the mutable borrow of `buffer`
            // ends before `buffer` is returned.
            let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
            world.all_gather_varcount_into(&local_weights, &mut partitioned_buffer);
        }
        buffer
    }

    /// Weights of all events (MPI-aware dispatcher).
    pub fn weights(&self) -> Vec<Float> {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.weights_mpi(&world);
            }
        }
        self.weights_local()
    }

    /// Sum of event weights on this process. Under `rayon` the sum uses a
    /// compensated (Klein) accumulator to reduce floating-point error.
    pub fn n_events_weighted_local(&self) -> Float {
        #[cfg(feature = "rayon")]
        return self
            .events
            .par_iter()
            .map(|e| e.weight)
            .parallel_sum_with_accumulator::<Klein<Float>>();
        #[cfg(not(feature = "rayon"))]
        return self.events.iter().map(|e| e.weight).sum();
    }
    /// Sum of event weights across all MPI ranks (all-gather of local sums,
    /// then a final reduction).
    #[cfg(feature = "mpi")]
    pub fn n_events_weighted_mpi(&self, world: &SimpleCommunicator) -> Float {
        let mut n_events_weighted_partitioned: Vec<Float> = vec![0.0; world.size() as usize];
        let n_events_weighted_local = self.n_events_weighted_local();
        world.all_gather_into(&n_events_weighted_local, &mut n_events_weighted_partitioned);
        #[cfg(feature = "rayon")]
        return n_events_weighted_partitioned
            .into_par_iter()
            .parallel_sum_with_accumulator::<Klein<Float>>();
        #[cfg(not(feature = "rayon"))]
        return n_events_weighted_partitioned.iter().sum();
    }

    /// Total weighted event count (MPI-aware dispatcher).
    pub fn n_events_weighted(&self) -> Float {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.n_events_weighted_mpi(&world);
            }
        }
        self.n_events_weighted_local()
    }

    /// Resample the dataset with replacement, seeded for reproducibility.
    ///
    /// Indices are drawn uniformly and then sorted, so a given seed always
    /// yields the same (ordered) sample.
    pub fn bootstrap_local(&self, seed: usize) -> Arc<Dataset> {
        let mut rng = fastrand::Rng::with_seed(seed as u64);
        let mut indices: Vec<usize> = (0..self.n_events())
            .map(|_| rng.usize(0..self.n_events()))
            .collect::<Vec<usize>>();
        indices.sort();
        #[cfg(feature = "rayon")]
        let bootstrapped_events: Vec<Arc<Event>> = indices
            .into_par_iter()
            .map(|idx| self.events[idx].clone())
            .collect();
        #[cfg(not(feature = "rayon"))]
        let bootstrapped_events: Vec<Arc<Event>> = indices
            .into_iter()
            .map(|idx| self.events[idx].clone())
            .collect();
        Arc::new(Dataset {
            events: bootstrapped_events,
        })
    }

    /// MPI bootstrap: the root rank draws the sorted global index sample and
    /// broadcasts it; each rank then keeps only the indices it owns.
    #[cfg(feature = "mpi")]
    pub fn bootstrap_mpi(&self, seed: usize, world: &SimpleCommunicator) -> Arc<Dataset> {
        let n_events = self.n_events();
        let mut indices: Vec<usize> = vec![0; n_events];
        if world.is_root() {
            let mut rng = fastrand::Rng::with_seed(seed as u64);
            indices = (0..n_events)
                .map(|_| rng.usize(0..n_events))
                .collect::<Vec<usize>>();
            indices.sort();
        }
        world.process_at_root().broadcast_into(&mut indices);
        let (_, displs) = world.get_counts_displs(self.n_events());
        let local_indices: Vec<usize> = indices
            .into_iter()
            .filter_map(|idx| {
                let (owning_rank, local_index) = Dataset::get_rank_index(idx, &displs, world);
                if world.rank() == owning_rank {
                    Some(local_index)
                } else {
                    None
                }
            })
            .collect();
        #[cfg(feature = "rayon")]
        let bootstrapped_events: Vec<Arc<Event>> = local_indices
            .into_par_iter()
            .map(|idx| self.events[idx].clone())
            .collect();
        #[cfg(not(feature = "rayon"))]
        let bootstrapped_events: Vec<Arc<Event>> = local_indices
            .into_iter()
            .map(|idx| self.events[idx].clone())
            .collect();
        Arc::new(Dataset {
            events: bootstrapped_events,
        })
    }

    /// Bootstrap resample (MPI-aware dispatcher).
    pub fn bootstrap(&self, seed: usize) -> Arc<Dataset> {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.bootstrap_mpi(seed, &world);
            }
        }
        self.bootstrap_local(seed)
    }

    /// Keep only the events for which the compiled `expression` evaluates true.
    pub fn filter(&self, expression: &VariableExpression) -> Arc<Dataset> {
        // Compile once, outside the per-event loop.
        let compiled = expression.compile();
        #[cfg(feature = "rayon")]
        let filtered_events = self
            .events
            .par_iter()
            .filter(|e| compiled.evaluate(e))
            .cloned()
            .collect();
        #[cfg(not(feature = "rayon"))]
        let filtered_events = self
            .events
            .iter()
            .filter(|e| compiled.evaluate(e))
            .cloned()
            .collect();
        Arc::new(Dataset {
            events: filtered_events,
        })
    }

    /// Partition the events into `bins` equal-width bins of `variable` over
    /// the half-open interval `[range.0, range.1)`; out-of-range events are
    /// dropped.
    pub fn bin_by<V>(&self, variable: V, bins: usize, range: (Float, Float)) -> BinnedDataset
    where
        V: Variable,
    {
        let bin_width = (range.1 - range.0) / bins as Float;
        let bin_edges = get_bin_edges(bins, range);
        #[cfg(feature = "rayon")]
        let evaluated: Vec<(usize, &Arc<Event>)> = self
            .events
            .par_iter()
            .filter_map(|event| {
                let value = variable.value(event.as_ref());
                if value >= range.0 && value < range.1 {
                    let bin_index = ((value - range.0) / bin_width) as usize;
                    // Clamp to guard against floating-point rounding pushing a
                    // value just under `range.1` into a nonexistent bin.
                    let bin_index = bin_index.min(bins - 1);
                    Some((bin_index, event))
                } else {
                    None
                }
            })
            .collect();
        #[cfg(not(feature = "rayon"))]
        let evaluated: Vec<(usize, &Arc<Event>)> = self
            .events
            .iter()
            .filter_map(|event| {
                let value = variable.value(event.as_ref());
                if value >= range.0 && value < range.1 {
                    let bin_index = ((value - range.0) / bin_width) as usize;
                    // Clamp to guard against floating-point rounding pushing a
                    // value just under `range.1` into a nonexistent bin.
                    let bin_index = bin_index.min(bins - 1);
                    Some((bin_index, event))
                } else {
                    None
                }
            })
            .collect();
        let mut binned_events: Vec<Vec<Arc<Event>>> = vec![Vec::default(); bins];
        for (bin_index, event) in evaluated {
            binned_events[bin_index].push(event.clone());
        }
        BinnedDataset {
            #[cfg(feature = "rayon")]
            datasets: binned_events
                .into_par_iter()
                .map(|events| Arc::new(Dataset { events }))
                .collect(),
            #[cfg(not(feature = "rayon"))]
            datasets: binned_events
                .into_iter()
                .map(|events| Arc::new(Dataset { events }))
                .collect(),
            edges: bin_edges,
        }
    }

    /// Boost every event into the rest frame of the summed four-momenta at
    /// `indices` (see [`Event::boost_to_rest_frame_of`]).
    pub fn boost_to_rest_frame_of<T: AsRef<[usize]> + Sync>(&self, indices: T) -> Arc<Dataset> {
        #[cfg(feature = "rayon")]
        {
            Arc::new(Dataset {
                events: self
                    .events
                    .par_iter()
                    .map(|event| Arc::new(event.boost_to_rest_frame_of(indices.as_ref())))
                    .collect(),
            })
        }
        #[cfg(not(feature = "rayon"))]
        {
            Arc::new(Dataset {
                events: self
                    .events
                    .iter()
                    .map(|event| Arc::new(event.boost_to_rest_frame_of(indices.as_ref())))
                    .collect(),
            })
        }
    }
    /// Evaluate a [`Variable`] on every event in the dataset.
    pub fn evaluate<V: Variable>(&self, variable: &V) -> Vec<Float> {
        variable.value_on(self)
    }
}
581
// Dataset concatenation: `&a + &b` yields a new `Dataset` whose events are the
// `Arc`s of `a` followed by those of `b` (events are shared, not deep-copied).
impl_op_ex!(+ |a: &Dataset, b: &Dataset| -> Dataset { Dataset { events: a.events.iter().chain(b.events.iter()).cloned().collect() }});
583
584fn batch_to_event(batch: &RecordBatch, row: usize) -> Event {
585 let mut p4s = Vec::new();
586 let mut aux = Vec::new();
587
588 let p4_count = batch
589 .schema()
590 .fields()
591 .iter()
592 .filter(|field| field.name().starts_with(P4_PREFIX))
593 .count()
594 / 4;
595 let aux_count = batch
596 .schema()
597 .fields()
598 .iter()
599 .filter(|field| field.name().starts_with(AUX_PREFIX))
600 .count()
601 / 3;
602
603 for i in 0..p4_count {
604 let e = batch
605 .column_by_name(&format!("{}{}_E", P4_PREFIX, i))
606 .unwrap()
607 .as_any()
608 .downcast_ref::<Float32Array>()
609 .unwrap()
610 .value(row) as Float;
611 let px = batch
612 .column_by_name(&format!("{}{}_Px", P4_PREFIX, i))
613 .unwrap()
614 .as_any()
615 .downcast_ref::<Float32Array>()
616 .unwrap()
617 .value(row) as Float;
618 let py = batch
619 .column_by_name(&format!("{}{}_Py", P4_PREFIX, i))
620 .unwrap()
621 .as_any()
622 .downcast_ref::<Float32Array>()
623 .unwrap()
624 .value(row) as Float;
625 let pz = batch
626 .column_by_name(&format!("{}{}_Pz", P4_PREFIX, i))
627 .unwrap()
628 .as_any()
629 .downcast_ref::<Float32Array>()
630 .unwrap()
631 .value(row) as Float;
632 p4s.push(Vec4::new(px, py, pz, e));
633 }
634
635 for i in 0..aux_count {
637 let x = batch
638 .column_by_name(&format!("{}{}_x", AUX_PREFIX, i))
639 .unwrap()
640 .as_any()
641 .downcast_ref::<Float32Array>()
642 .unwrap()
643 .value(row) as Float;
644 let y = batch
645 .column_by_name(&format!("{}{}_y", AUX_PREFIX, i))
646 .unwrap()
647 .as_any()
648 .downcast_ref::<Float32Array>()
649 .unwrap()
650 .value(row) as Float;
651 let z = batch
652 .column_by_name(&format!("{}{}_z", AUX_PREFIX, i))
653 .unwrap()
654 .as_any()
655 .downcast_ref::<Float32Array>()
656 .unwrap()
657 .value(row) as Float;
658 aux.push(Vec3::new(x, y, z));
659 }
660
661 let weight = batch
662 .column(19)
663 .as_any()
664 .downcast_ref::<Float32Array>()
665 .unwrap()
666 .value(row) as Float;
667
668 Event { p4s, aux, weight }
669}
670
/// Read a Parquet file into a new [`Dataset`].
///
/// The path may contain `~` or environment variables (expanded via
/// `shellexpand`) and is canonicalized before opening. Every row of every
/// record batch is converted to an [`Event`] (see `batch_to_event`); batches
/// are processed in parallel when the `rayon` feature is enabled.
///
/// # Errors
/// Returns [`LadduError`] if the path cannot be resolved or the file is not
/// readable Parquet.
pub fn open<T: AsRef<str>>(file_path: T) -> Result<Arc<Dataset>, LadduError> {
    let file_path = Path::new(&*shellexpand::full(file_path.as_ref())?).canonicalize()?;
    let file = File::open(file_path)?;
    let builder = ParquetRecordBatchReaderBuilder::try_new(file)?;
    let reader = builder.build()?;
    let batches: Vec<RecordBatch> = reader.collect::<Result<Vec<_>, _>>()?;

    #[cfg(feature = "rayon")]
    let events: Vec<Arc<Event>> = batches
        .into_par_iter()
        .flat_map(|batch| {
            let num_rows = batch.num_rows();
            let mut local_events = Vec::with_capacity(num_rows);

            for row in 0..num_rows {
                let event = batch_to_event(&batch, row);
                local_events.push(Arc::new(event));
            }
            local_events
        })
        .collect();
    #[cfg(not(feature = "rayon"))]
    let events: Vec<Arc<Event>> = batches
        .into_iter()
        .flat_map(|batch| {
            let num_rows = batch.num_rows();
            let mut local_events = Vec::with_capacity(num_rows);

            for row in 0..num_rows {
                let event = batch_to_event(&batch, row);
                local_events.push(Arc::new(event));
            }
            local_events
        })
        .collect();
    Ok(Arc::new(Dataset::new(events)))
}
712
/// Read a Parquet file into a new [`Dataset`], boosting every event into the
/// rest frame of the four-momenta at `indices` as it is loaded (see
/// [`Event::boost_to_rest_frame_of`]). Otherwise identical to [`open`].
///
/// # Errors
/// Returns [`LadduError`] if the path cannot be resolved or the file is not
/// readable Parquet.
pub fn open_boosted_to_rest_frame_of<T: AsRef<str>, I: AsRef<[usize]> + Sync>(
    file_path: T,
    indices: I,
) -> Result<Arc<Dataset>, LadduError> {
    let file_path = Path::new(&*shellexpand::full(file_path.as_ref())?).canonicalize()?;
    let file = File::open(file_path)?;
    let builder = ParquetRecordBatchReaderBuilder::try_new(file)?;
    let reader = builder.build()?;
    let batches: Vec<RecordBatch> = reader.collect::<Result<Vec<_>, _>>()?;

    #[cfg(feature = "rayon")]
    let events: Vec<Arc<Event>> = batches
        .into_par_iter()
        .flat_map(|batch| {
            let num_rows = batch.num_rows();
            let mut local_events = Vec::with_capacity(num_rows);

            for row in 0..num_rows {
                let mut event = batch_to_event(&batch, row);
                event = event.boost_to_rest_frame_of(indices.as_ref());
                local_events.push(Arc::new(event));
            }
            local_events
        })
        .collect();
    #[cfg(not(feature = "rayon"))]
    let events: Vec<Arc<Event>> = batches
        .into_iter()
        .flat_map(|batch| {
            let num_rows = batch.num_rows();
            let mut local_events = Vec::with_capacity(num_rows);

            for row in 0..num_rows {
                let mut event = batch_to_event(&batch, row);
                event = event.boost_to_rest_frame_of(indices.as_ref());
                local_events.push(Arc::new(event));
            }
            local_events
        })
        .collect();
    Ok(Arc::new(Dataset::new(events)))
}
760
/// A dataset partitioned into bins of some variable: one [`Dataset`] per bin
/// plus the bin edges (see [`Dataset::bin_by`]).
pub struct BinnedDataset {
    // One dataset per bin, in bin order.
    datasets: Vec<Arc<Dataset>>,
    // Bin edges; length is number of bins + 1.
    edges: Vec<Float>,
}
766
impl Index<usize> for BinnedDataset {
    type Output = Arc<Dataset>;

    /// Shared access to the dataset in bin `index`.
    fn index(&self, index: usize) -> &Self::Output {
        &self.datasets[index]
    }
}
774
impl IndexMut<usize> for BinnedDataset {
    /// Mutable access to the dataset in bin `index`.
    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
        &mut self.datasets[index]
    }
}
780
impl Deref for BinnedDataset {
    type Target = Vec<Arc<Dataset>>;

    /// Expose the per-bin dataset vector, so `Vec` methods work directly.
    fn deref(&self) -> &Self::Target {
        &self.datasets
    }
}
788
impl DerefMut for BinnedDataset {
    /// Mutable counterpart of [`Deref`] over the per-bin dataset vector.
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.datasets
    }
}
794
795impl BinnedDataset {
796 pub fn n_bins(&self) -> usize {
798 self.datasets.len()
799 }
800
801 pub fn edges(&self) -> Vec<Float> {
803 self.edges.clone()
804 }
805
806 pub fn range(&self) -> (Float, Float) {
808 (self.edges[0], self.edges[self.n_bins()])
809 }
810}
811
#[cfg(test)]
mod tests {
    // Unit tests for Event/Dataset construction, arithmetic, filtering,
    // binning, bootstrapping, and display formatting. All fixtures come from
    // `test_event()` / `test_dataset()`.
    use crate::Mass;

    use super::*;
    use approx::{assert_relative_eq, assert_relative_ne};
    use serde::{Deserialize, Serialize};
    // Fixture sanity: particle/aux counts and weight of the canned event.
    #[test]
    fn test_event_creation() {
        let event = test_event();
        assert_eq!(event.p4s.len(), 4);
        assert_eq!(event.aux.len(), 1);
        assert_relative_eq!(event.weight, 0.48)
    }

    // get_p4_sum must equal the component-wise sum of the selected momenta.
    #[test]
    fn test_event_p4_sum() {
        let event = test_event();
        let sum = event.get_p4_sum([2, 3]);
        assert_relative_eq!(sum.px(), event.p4s[2].px() + event.p4s[3].px());
        assert_relative_eq!(sum.py(), event.p4s[2].py() + event.p4s[3].py());
        assert_relative_eq!(sum.pz(), event.p4s[2].pz() + event.p4s[3].pz());
        assert_relative_eq!(sum.e(), event.p4s[2].e() + event.p4s[3].e());
    }

    // After boosting to the rest frame of particles 1-3, their summed
    // three-momentum must vanish (within float tolerance).
    #[test]
    fn test_event_boost() {
        let event = test_event();
        let event_boosted = event.boost_to_rest_frame_of([1, 2, 3]);
        let p4_sum = event_boosted.get_p4_sum([1, 2, 3]);
        assert_relative_eq!(p4_sum.px(), 0.0, epsilon = Float::EPSILON.sqrt());
        assert_relative_eq!(p4_sum.py(), 0.0, epsilon = Float::EPSILON.sqrt());
        assert_relative_eq!(p4_sum.pz(), 0.0, epsilon = Float::EPSILON.sqrt());
    }

    // Evaluating the Mass variable on particle 1 returns its rest mass.
    #[test]
    fn test_event_evaluate() {
        let event = test_event();
        let mass = Mass::new([1]);
        assert_relative_eq!(event.evaluate(&mass), 1.007);
    }

    #[test]
    fn test_dataset_size_check() {
        let mut dataset = Dataset::default();
        assert_eq!(dataset.n_events(), 0);
        dataset.events.push(Arc::new(test_event()));
        assert_eq!(dataset.n_events(), 1);
    }

    // `&a + &b` concatenates: a's events first, then b's.
    #[test]
    fn test_dataset_sum() {
        let dataset = test_dataset();
        let dataset2 = Dataset::new(vec![Arc::new(Event {
            p4s: test_event().p4s,
            aux: test_event().aux,
            weight: 0.52,
        })]);
        let dataset_sum = &dataset + &dataset2;
        assert_eq!(dataset_sum[0].weight, dataset[0].weight);
        assert_eq!(dataset_sum[1].weight, dataset2[0].weight);
    }

    // weights() preserves order; n_events_weighted() sums them (0.48 + 0.52).
    #[test]
    fn test_dataset_weights() {
        let mut dataset = Dataset::default();
        dataset.events.push(Arc::new(test_event()));
        dataset.events.push(Arc::new(Event {
            p4s: test_event().p4s,
            aux: test_event().aux,
            weight: 0.52,
        }));
        let weights = dataset.weights();
        assert_eq!(weights.len(), 2);
        assert_relative_eq!(weights[0], 0.48);
        assert_relative_eq!(weights[1], 0.52);
        assert_relative_eq!(dataset.n_events_weighted(), 1.0);
    }

    // Only the event with mass strictly inside (0, 1) survives the filter.
    #[test]
    fn test_dataset_filtering() {
        let mut dataset = Dataset::default();
        dataset.events.push(Arc::new(Event {
            p4s: vec![Vec3::new(0.0, 0.0, 5.0).with_mass(0.0)],
            aux: vec![],
            weight: 1.0,
        }));
        dataset.events.push(Arc::new(Event {
            p4s: vec![Vec3::new(0.0, 0.0, 5.0).with_mass(0.5)],
            aux: vec![],
            weight: 1.0,
        }));
        dataset.events.push(Arc::new(Event {
            p4s: vec![Vec3::new(0.0, 0.0, 5.0).with_mass(1.1)],
            aux: vec![],
            weight: 1.0,
        }));

        let mass = Mass::new([0]);
        let expression = mass.gt(0.0).and(&mass.lt(1.0));

        let filtered = dataset.filter(&expression);
        assert_eq!(filtered.n_events(), 1);
        assert_relative_eq!(
            mass.value(&filtered[0]),
            0.5,
            epsilon = Float::EPSILON.sqrt()
        );
    }

    // Dataset-level boost must zero the summed three-momentum, like the
    // per-event version.
    #[test]
    fn test_dataset_boost() {
        let dataset = test_dataset();
        let dataset_boosted = dataset.boost_to_rest_frame_of([1, 2, 3]);
        let p4_sum = dataset_boosted[0].get_p4_sum([1, 2, 3]);
        assert_relative_eq!(p4_sum.px(), 0.0, epsilon = Float::EPSILON.sqrt());
        assert_relative_eq!(p4_sum.py(), 0.0, epsilon = Float::EPSILON.sqrt());
        assert_relative_eq!(p4_sum.pz(), 0.0, epsilon = Float::EPSILON.sqrt());
    }

    #[test]
    fn test_dataset_evaluate() {
        let dataset = test_dataset();
        let mass = Mass::new([1]);
        assert_relative_eq!(dataset.evaluate(&mass)[0], 1.007);
    }

    // bin_by with a custom Variable: two events at beam energies 1 and 2 land
    // in separate bins of [0, 3) split into 2 bins.
    #[test]
    fn test_binned_dataset() {
        let dataset = Dataset::new(vec![
            Arc::new(Event {
                p4s: vec![Vec3::new(0.0, 0.0, 1.0).with_mass(1.0)],
                aux: vec![],
                weight: 1.0,
            }),
            Arc::new(Event {
                p4s: vec![Vec3::new(0.0, 0.0, 2.0).with_mass(2.0)],
                aux: vec![],
                weight: 2.0,
            }),
        ]);

        // Minimal Variable implementation: energy of the first four-momentum.
        #[derive(Clone, Serialize, Deserialize, Debug)]
        struct BeamEnergy;
        impl Display for BeamEnergy {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, "BeamEnergy")
            }
        }
        #[typetag::serde]
        impl Variable for BeamEnergy {
            fn value(&self, event: &Event) -> Float {
                event.p4s[0].e()
            }
        }
        assert_eq!(BeamEnergy.to_string(), "BeamEnergy");

        let binned = dataset.bin_by(BeamEnergy, 2, (0.0, 3.0));

        assert_eq!(binned.n_bins(), 2);
        assert_eq!(binned.edges().len(), 3);
        assert_relative_eq!(binned.edges()[0], 0.0);
        assert_relative_eq!(binned.edges()[2], 3.0);
        assert_eq!(binned[0].n_events(), 1);
        assert_relative_eq!(binned[0].n_events_weighted(), 1.0);
        assert_eq!(binned[1].n_events(), 1);
        assert_relative_eq!(binned[1].n_events_weighted(), 2.0);
    }

    // Seed 43 is chosen so the resample duplicates one event (equal weights),
    // which the unequal-weight original cannot produce by accident.
    #[test]
    fn test_dataset_bootstrap() {
        let mut dataset = test_dataset();
        dataset.events.push(Arc::new(Event {
            p4s: test_event().p4s.clone(),
            aux: test_event().aux.clone(),
            weight: 1.0,
        }));
        assert_relative_ne!(dataset[0].weight, dataset[1].weight);

        let bootstrapped = dataset.bootstrap(43);
        assert_eq!(bootstrapped.n_events(), dataset.n_events());
        assert_relative_eq!(bootstrapped[0].weight, bootstrapped[1].weight);

        let empty_dataset = Dataset::default();
        let empty_bootstrap = empty_dataset.bootstrap(43);
        assert_eq!(empty_bootstrap.n_events(), 0);
    }
    // Pins the exact Display output, including the legacy "eps:" label.
    #[test]
    fn test_event_display() {
        let event = test_event();
        let display_string = format!("{}", event);
        assert_eq!(
            display_string,
            "Event:\n p4s:\n [e = 8.74700; p = (0.00000, 0.00000, 8.74700); m = 0.00000]\n [e = 1.10334; p = (0.11900, 0.37400, 0.22200); m = 1.00700]\n [e = 3.13671; p = (-0.11200, 0.29300, 3.08100); m = 0.49800]\n [e = 5.50925; p = (-0.00700, -0.66700, 5.44600); m = 0.49800]\n eps:\n [0.385, 0.022, 0]\n weight:\n 0.48\n"
        );
    }
}
1012}