use accurate::{sum::Klein, traits::*};
use arrow::array::Float32Array;
use arrow::record_batch::RecordBatch;
use auto_ops::impl_op_ex;
use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
use serde::{Deserialize, Serialize};
use std::ops::{Deref, DerefMut, Index, IndexMut};
use std::path::Path;
use std::sync::Arc;
use std::{fmt::Display, fs::File};

#[cfg(feature = "rayon")]
use rayon::prelude::*;

#[cfg(feature = "mpi")]
use mpi::{datatype::PartitionMut, topology::SimpleCommunicator, traits::*};

#[cfg(feature = "mpi")]
use crate::mpi::LadduMPI;

use crate::utils::get_bin_edges;
use crate::{
    utils::{
        variables::{Variable, VariableExpression},
        vectors::{Vec3, Vec4},
    },
    Float, LadduError,
};
// Column-name prefixes expected when reading events from Parquet files.
const P4_PREFIX: &str = "p4_";
const AUX_PREFIX: &str = "aux_";

/// Generate a single test [`Event`] with four four-momenta, one auxiliary vector, and a weight
/// of 0.48 (used throughout this module's tests).
pub fn test_event() -> Event {
    use crate::utils::vectors::*;
    Event {
        p4s: vec![
            Vec3::new(0.0, 0.0, 8.747).with_mass(0.0),
            Vec3::new(0.119, 0.374, 0.222).with_mass(1.007),
            Vec3::new(-0.112, 0.293, 3.081).with_mass(0.498),
            Vec3::new(-0.007, -0.667, 5.446).with_mass(0.498),
        ],
        aux: vec![Vec3::new(0.385, 0.022, 0.000)],
        weight: 0.48,
    }
}

/// Generate a test [`Dataset`] containing a single [`test_event`].
pub fn test_dataset() -> Dataset {
    Dataset::new(vec![Arc::new(test_event())])
}

/// A single event: a list of four-momenta, a list of auxiliary three-vectors, and a weight.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Event {
    /// The four-momenta of the particles in the event.
    pub p4s: Vec<Vec4>,
    /// Auxiliary three-vectors attached to the event.
    pub aux: Vec<Vec3>,
    /// The event weight.
    pub weight: Float,
}

impl Display for Event {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Event:")?;
        writeln!(f, " p4s:")?;
        for p4 in &self.p4s {
            writeln!(f, " {}", p4.to_p4_string())?;
        }
        // Note: the `aux` vectors are printed under the `eps` heading.
        writeln!(f, " eps:")?;
        for eps_vec in &self.aux {
            writeln!(f, " [{}, {}, {}]", eps_vec.x, eps_vec.y, eps_vec.z)?;
        }
        writeln!(f, " weight:")?;
        writeln!(f, " {}", self.weight)?;
        Ok(())
    }
}

impl Event {
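    /// Sum the four-momenta at the given `indices` into a single [`Vec4`].
    ///
    /// A minimal usage sketch (marked `ignore` so it is not compiled as a doctest; it uses the
    /// `test_event` helper from this module):
    ///
    /// ```ignore
    /// let event = test_event();
    /// // Total four-momentum of particles 2 and 3; component accessors: px(), py(), pz(), e().
    /// let pair = event.get_p4_sum([2, 3]);
    /// println!("pair energy = {}", pair.e());
    /// ```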
    pub fn get_p4_sum<T: AsRef<[usize]>>(&self, indices: T) -> Vec4 {
        indices.as_ref().iter().map(|i| self.p4s[*i]).sum::<Vec4>()
    }

    /// Boost every four-momentum into the rest frame of the four-momentum sum at `indices`.
    pub fn boost_to_rest_frame_of<T: AsRef<[usize]>>(&self, indices: T) -> Self {
        let frame = self.get_p4_sum(indices);
        Event {
            p4s: self
                .p4s
                .iter()
                .map(|p4| p4.boost(&(-frame.beta())))
                .collect(),
            aux: self.aux.clone(),
            weight: self.weight,
        }
    }

    /// Evaluate a [`Variable`] on this event.
    pub fn evaluate<V: Variable>(&self, variable: &V) -> Float {
        variable.value(self)
    }
}

/// A collection of [`Event`]s. When the `mpi` feature is active and an MPI world exists, each
/// rank stores only its local partition of the events.
#[derive(Debug, Clone, Default)]
pub struct Dataset {
    /// The events stored on this process.
    pub events: Vec<Arc<Event>>,
}

impl Dataset {
    /// Get a reference to the [`Event`] at `index` within this process's local partition.
    pub fn index_local(&self, index: usize) -> &Event {
        &self.events[index]
    }

    /// Map a global event index to `(owning_rank, local_index)` given the global displacements.
    #[cfg(feature = "mpi")]
    fn get_rank_index(index: usize, displs: &[i32], world: &SimpleCommunicator) -> (i32, usize) {
        for (i, &displ) in displs.iter().enumerate() {
            if displ as usize > index {
                return (i as i32 - 1, index - displs[i - 1] as usize);
            }
        }
        (
            world.size() - 1,
            index - displs[world.size() as usize - 1] as usize,
        )
    }

    /// For a set of global indices, compute per-rank counts and displacements along with the
    /// local indices owned by the calling rank.
    #[cfg(feature = "mpi")]
    pub fn get_counts_displs_locals_from_indices(
        &self,
        indices: &[usize],
        world: &SimpleCommunicator,
    ) -> (Vec<i32>, Vec<i32>, Vec<usize>) {
        let mut counts = vec![0i32; world.size() as usize];
        let mut displs = vec![0i32; world.size() as usize];
        let (_, global_displs) = world.get_counts_displs(self.n_events());
        let owning_rank_locals: Vec<(i32, usize)> = indices
            .iter()
            .map(|i| Dataset::get_rank_index(*i, &global_displs, world))
            .collect();
        let mut locals_by_rank = vec![Vec::new(); world.size() as usize];
        for &(r, li) in owning_rank_locals.iter() {
            locals_by_rank[r as usize].push(li);
        }
        for rank in 0..world.size() as usize {
            counts[rank] = locals_by_rank[rank].len() as i32;
            displs[rank] = if rank == 0 {
                0
            } else {
                displs[rank - 1] + counts[rank - 1]
            };
        }
        (
            counts,
            displs,
            locals_by_rank[world.rank() as usize].clone(),
        )
    }

    /// Like [`Self::get_counts_displs_locals_from_indices`], but with counts and displacements
    /// scaled by `internal_len` (for gathering flattened per-event buffers).
    #[cfg(feature = "mpi")]
    pub fn get_flattened_counts_displs_locals_from_indices(
        &self,
        indices: &[usize],
        internal_len: usize,
        world: &SimpleCommunicator,
    ) -> (Vec<i32>, Vec<i32>, Vec<usize>) {
        let mut counts = vec![0i32; world.size() as usize];
        let mut displs = vec![0i32; world.size() as usize];
        let (_, global_displs) = world.get_counts_displs(self.n_events());
        let owning_rank_locals: Vec<(i32, usize)> = indices
            .iter()
            .map(|i| Dataset::get_rank_index(*i, &global_displs, world))
            .collect();
        let mut locals_by_rank = vec![Vec::new(); world.size() as usize];
        for &(r, li) in owning_rank_locals.iter() {
            locals_by_rank[r as usize].push(li);
        }
        for rank in 0..world.size() as usize {
            counts[rank] = (locals_by_rank[rank].len() * internal_len) as i32;
            displs[rank] = if rank == 0 {
                0
            } else {
                displs[rank - 1] + counts[rank - 1]
            };
        }
        (
            counts,
            displs,
            locals_by_rank[world.rank() as usize].clone(),
        )
    }

    /// Split `events` into one chunk per MPI rank according to the world's counts/displacements.
    #[cfg(feature = "mpi")]
    fn partition(events: Vec<Arc<Event>>, world: &SimpleCommunicator) -> Vec<Vec<Arc<Event>>> {
        let (counts, displs) = world.get_counts_displs(events.len());
        counts
            .iter()
            .zip(displs.iter())
            .map(|(&count, &displ)| {
                events
                    .iter()
                    .skip(displ as usize)
                    .take(count as usize)
                    .cloned()
                    .collect()
            })
            .collect()
    }

    /// Get the [`Event`] at a global `index`, broadcasting it from the owning rank.
    #[cfg(feature = "mpi")]
    pub fn index_mpi(&self, index: usize, world: &SimpleCommunicator) -> &Event {
        let (_, displs) = world.get_counts_displs(self.n_events());
        let (owning_rank, local_index) = Dataset::get_rank_index(index, &displs, world);
        let mut serialized_event_buffer_len: usize = 0;
        let mut serialized_event_buffer: Vec<u8> = Vec::default();
        let config = bincode::config::standard();
        if world.rank() == owning_rank {
            let event = self.index_local(local_index);
            serialized_event_buffer = bincode::serde::encode_to_vec(event, config).unwrap();
            serialized_event_buffer_len = serialized_event_buffer.len();
        }
        world
            .process_at_rank(owning_rank)
            .broadcast_into(&mut serialized_event_buffer_len);
        if world.rank() != owning_rank {
            serialized_event_buffer = vec![0; serialized_event_buffer_len];
        }
        world
            .process_at_rank(owning_rank)
            .broadcast_into(&mut serialized_event_buffer);
        let (event, _): (Event, usize) =
            bincode::serde::decode_from_slice(&serialized_event_buffer[..], config).unwrap();
        // Leak the deserialized event so a `&Event` can be returned; each call through this
        // path allocates an `Event` that is never freed.
        Box::leak(Box::new(event))
    }
}

impl Index<usize> for Dataset {
    type Output = Event;

    fn index(&self, index: usize) -> &Self::Output {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.index_mpi(index, &world);
            }
        }
        self.index_local(index)
    }
}

impl Dataset {
    /// Construct a [`Dataset`] containing only the given events, with no MPI partitioning.
    pub fn new_local(events: Vec<Arc<Event>>) -> Self {
        Dataset { events }
    }

    /// Construct a [`Dataset`] on each rank holding only that rank's partition of `events`.
    #[cfg(feature = "mpi")]
    pub fn new_mpi(events: Vec<Arc<Event>>, world: &SimpleCommunicator) -> Self {
        Dataset {
            events: Dataset::partition(events, world)[world.rank() as usize].clone(),
        }
    }

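    /// Construct a [`Dataset`] from a list of events, partitioning across MPI ranks when an MPI
    /// world is available and falling back to a purely local dataset otherwise.
    ///
    /// A minimal usage sketch (marked `ignore`; it assumes the surrounding items are in scope):
    ///
    /// ```ignore
    /// let dataset = Dataset::new(vec![Arc::new(test_event())]);
    /// assert_eq!(dataset.n_events(), 1);
    /// ```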
    pub fn new(events: Vec<Arc<Event>>) -> Self {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return Dataset::new_mpi(events, &world);
            }
        }
        Dataset::new_local(events)
    }

    /// The number of events stored locally on this process.
    pub fn n_events_local(&self) -> usize {
        self.events.len()
    }

    /// The total number of events across all MPI ranks.
    #[cfg(feature = "mpi")]
    pub fn n_events_mpi(&self, world: &SimpleCommunicator) -> usize {
        let mut n_events_partitioned: Vec<usize> = vec![0; world.size() as usize];
        let n_events_local = self.n_events_local();
        world.all_gather_into(&n_events_local, &mut n_events_partitioned);
        n_events_partitioned.iter().sum()
    }

    /// The total number of events in the dataset (global when MPI is in use).
    pub fn n_events(&self) -> usize {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.n_events_mpi(&world);
            }
        }
        self.n_events_local()
    }
}

impl Dataset {
    /// The weights of the locally stored events.
    pub fn weights_local(&self) -> Vec<Float> {
        #[cfg(feature = "rayon")]
        return self.events.par_iter().map(|e| e.weight).collect();
        #[cfg(not(feature = "rayon"))]
        return self.events.iter().map(|e| e.weight).collect();
    }

    /// The weights of all events, gathered from every MPI rank.
    #[cfg(feature = "mpi")]
    pub fn weights_mpi(&self, world: &SimpleCommunicator) -> Vec<Float> {
        let local_weights = self.weights_local();
        let n_events = self.n_events();
        let mut buffer: Vec<Float> = vec![0.0; n_events];
        let (counts, displs) = world.get_counts_displs(n_events);
        {
            let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
            world.all_gather_varcount_into(&local_weights, &mut partitioned_buffer);
        }
        buffer
    }

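    /// Collect the per-event weights, gathered across all MPI ranks when MPI is in use.
    ///
    /// A minimal usage sketch (marked `ignore`; assumes a `dataset` built as above):
    ///
    /// ```ignore
    /// let weights = dataset.weights();
    /// assert_eq!(weights.len(), dataset.n_events());
    /// // The weighted event count is the sum of these weights:
    /// println!("sum of weights = {}", dataset.n_events_weighted());
    /// ```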
    pub fn weights(&self) -> Vec<Float> {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.weights_mpi(&world);
            }
        }
        self.weights_local()
    }

    /// The sum of the locally stored event weights (compensated summation when `rayon` is on).
    pub fn n_events_weighted_local(&self) -> Float {
        #[cfg(feature = "rayon")]
        return self
            .events
            .par_iter()
            .map(|e| e.weight)
            .parallel_sum_with_accumulator::<Klein<Float>>();
        #[cfg(not(feature = "rayon"))]
        return self.events.iter().map(|e| e.weight).sum();
    }

    /// The sum of event weights across all MPI ranks.
    #[cfg(feature = "mpi")]
    pub fn n_events_weighted_mpi(&self, world: &SimpleCommunicator) -> Float {
        let mut n_events_weighted_partitioned: Vec<Float> = vec![0.0; world.size() as usize];
        let n_events_weighted_local = self.n_events_weighted_local();
        world.all_gather_into(&n_events_weighted_local, &mut n_events_weighted_partitioned);
        #[cfg(feature = "rayon")]
        return n_events_weighted_partitioned
            .into_par_iter()
            .parallel_sum_with_accumulator::<Klein<Float>>();
        #[cfg(not(feature = "rayon"))]
        return n_events_weighted_partitioned.iter().sum();
    }

    /// The weighted number of events (the sum of all event weights).
    pub fn n_events_weighted(&self) -> Float {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.n_events_weighted_mpi(&world);
            }
        }
        self.n_events_weighted_local()
    }

    /// Resample the local events with replacement using `seed`.
    pub fn bootstrap_local(&self, seed: usize) -> Arc<Dataset> {
        let mut rng = fastrand::Rng::with_seed(seed as u64);
        let mut indices: Vec<usize> = (0..self.n_events())
            .map(|_| rng.usize(0..self.n_events()))
            .collect::<Vec<usize>>();
        indices.sort();
        #[cfg(feature = "rayon")]
        let bootstrapped_events: Vec<Arc<Event>> = indices
            .into_par_iter()
            .map(|idx| self.events[idx].clone())
            .collect();
        #[cfg(not(feature = "rayon"))]
        let bootstrapped_events: Vec<Arc<Event>> = indices
            .into_iter()
            .map(|idx| self.events[idx].clone())
            .collect();
        Arc::new(Dataset {
            events: bootstrapped_events,
        })
    }

    /// Resample the global dataset with replacement; the root rank draws the indices and
    /// broadcasts them so that every rank keeps the resampled events it owns.
    #[cfg(feature = "mpi")]
    pub fn bootstrap_mpi(&self, seed: usize, world: &SimpleCommunicator) -> Arc<Dataset> {
        let n_events = self.n_events();
        let mut indices: Vec<usize> = vec![0; n_events];
        if world.is_root() {
            let mut rng = fastrand::Rng::with_seed(seed as u64);
            indices = (0..n_events)
                .map(|_| rng.usize(0..n_events))
                .collect::<Vec<usize>>();
            indices.sort();
        }
        world.process_at_root().broadcast_into(&mut indices);
        let (_, displs) = world.get_counts_displs(self.n_events());
        let local_indices: Vec<usize> = indices
            .into_iter()
            .filter_map(|idx| {
                let (owning_rank, local_index) = Dataset::get_rank_index(idx, &displs, world);
                if world.rank() == owning_rank {
                    Some(local_index)
                } else {
                    None
                }
            })
            .collect();
        #[cfg(feature = "rayon")]
        let bootstrapped_events: Vec<Arc<Event>> = local_indices
            .into_par_iter()
            .map(|idx| self.events[idx].clone())
            .collect();
        #[cfg(not(feature = "rayon"))]
        let bootstrapped_events: Vec<Arc<Event>> = local_indices
            .into_iter()
            .map(|idx| self.events[idx].clone())
            .collect();
        Arc::new(Dataset {
            events: bootstrapped_events,
        })
    }

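    /// Resample the dataset with replacement using the given `seed`, dispatching to the MPI
    /// implementation when an MPI world is available.
    ///
    /// A minimal usage sketch (marked `ignore`):
    ///
    /// ```ignore
    /// let resampled = dataset.bootstrap(42);
    /// // Bootstrapping preserves the number of events but repeats/drops individual entries.
    /// assert_eq!(resampled.n_events(), dataset.n_events());
    /// ```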
    pub fn bootstrap(&self, seed: usize) -> Arc<Dataset> {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.bootstrap_mpi(seed, &world);
            }
        }
        self.bootstrap_local(seed)
    }

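    /// Keep only the events for which the compiled [`VariableExpression`] evaluates to `true`.
    ///
    /// A minimal usage sketch (marked `ignore`; `Mass` is the invariant-mass variable used in
    /// this module's tests):
    ///
    /// ```ignore
    /// let mass = Mass::new([2, 3]);
    /// // Select events whose (2, 3) invariant mass lies in (1.0, 1.5):
    /// let selected = dataset.filter(&mass.gt(1.0).and(&mass.lt(1.5)));
    /// ```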
    pub fn filter(&self, expression: &VariableExpression) -> Arc<Dataset> {
        let compiled = expression.compile();
        #[cfg(feature = "rayon")]
        let filtered_events = self
            .events
            .par_iter()
            .filter(|e| compiled.evaluate(e))
            .cloned()
            .collect();
        #[cfg(not(feature = "rayon"))]
        let filtered_events = self
            .events
            .iter()
            .filter(|e| compiled.evaluate(e))
            .cloned()
            .collect();
        Arc::new(Dataset {
            events: filtered_events,
        })
    }

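    /// Split the dataset into `bins` sub-datasets according to the value of `variable` over the
    /// half-open interval `[range.0, range.1)`.
    ///
    /// A minimal usage sketch (marked `ignore`):
    ///
    /// ```ignore
    /// let binned = dataset.bin_by(Mass::new([2, 3]), 10, (0.0, 2.0));
    /// assert_eq!(binned.n_bins(), 10);
    /// assert_eq!(binned.edges().len(), 11);
    /// ```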
    pub fn bin_by<V>(&self, variable: V, bins: usize, range: (Float, Float)) -> BinnedDataset
    where
        V: Variable,
    {
        let bin_width = (range.1 - range.0) / bins as Float;
        let bin_edges = get_bin_edges(bins, range);
        #[cfg(feature = "rayon")]
        let evaluated: Vec<(usize, &Arc<Event>)> = self
            .events
            .par_iter()
            .filter_map(|event| {
                let value = variable.value(event.as_ref());
                if value >= range.0 && value < range.1 {
                    let bin_index = ((value - range.0) / bin_width) as usize;
                    // Clamp in case floating-point rounding lands a value in bin `bins`.
                    let bin_index = bin_index.min(bins - 1);
                    Some((bin_index, event))
                } else {
                    None
                }
            })
            .collect();
        #[cfg(not(feature = "rayon"))]
        let evaluated: Vec<(usize, &Arc<Event>)> = self
            .events
            .iter()
            .filter_map(|event| {
                let value = variable.value(event.as_ref());
                if value >= range.0 && value < range.1 {
                    let bin_index = ((value - range.0) / bin_width) as usize;
                    // Clamp in case floating-point rounding lands a value in bin `bins`.
                    let bin_index = bin_index.min(bins - 1);
                    Some((bin_index, event))
                } else {
                    None
                }
            })
            .collect();
        let mut binned_events: Vec<Vec<Arc<Event>>> = vec![Vec::default(); bins];
        for (bin_index, event) in evaluated {
            binned_events[bin_index].push(event.clone());
        }
        BinnedDataset {
            #[cfg(feature = "rayon")]
            datasets: binned_events
                .into_par_iter()
                .map(|events| Arc::new(Dataset { events }))
                .collect(),
            #[cfg(not(feature = "rayon"))]
            datasets: binned_events
                .into_iter()
                .map(|events| Arc::new(Dataset { events }))
                .collect(),
            edges: bin_edges,
        }
    }

    /// Boost every event into the rest frame of the four-momentum sum at `indices`.
    pub fn boost_to_rest_frame_of<T: AsRef<[usize]> + Sync>(&self, indices: T) -> Arc<Dataset> {
        #[cfg(feature = "rayon")]
        {
            Arc::new(Dataset {
                events: self
                    .events
                    .par_iter()
                    .map(|event| Arc::new(event.boost_to_rest_frame_of(indices.as_ref())))
                    .collect(),
            })
        }
        #[cfg(not(feature = "rayon"))]
        {
            Arc::new(Dataset {
                events: self
                    .events
                    .iter()
                    .map(|event| Arc::new(event.boost_to_rest_frame_of(indices.as_ref())))
                    .collect(),
            })
        }
    }

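    /// Evaluate a [`Variable`] on every event in the dataset.
    ///
    /// A minimal usage sketch (marked `ignore`):
    ///
    /// ```ignore
    /// let masses = dataset.evaluate(&Mass::new([2, 3]));
    /// println!("first event mass = {}", masses[0]);
    /// ```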
    pub fn evaluate<V: Variable>(&self, variable: &V) -> Vec<Float> {
        variable.value_on(self)
    }
}

// `&Dataset + &Dataset` concatenates the two event lists (all owned/reference operand
// combinations are generated by `auto_ops::impl_op_ex`).
impl_op_ex!(+ |a: &Dataset, b: &Dataset| -> Dataset { Dataset { events: a.events.iter().chain(b.events.iter()).cloned().collect() }});

/// Convert one row of an Arrow [`RecordBatch`] into an [`Event`].
fn batch_to_event(batch: &RecordBatch, row: usize) -> Event {
    let mut p4s = Vec::new();
    let mut aux = Vec::new();

    // Count the four-momentum and auxiliary-vector columns by their name prefixes
    // (four components per p4, three per aux vector).
    let p4_count = batch
        .schema()
        .fields()
        .iter()
        .filter(|field| field.name().starts_with(P4_PREFIX))
        .count()
        / 4;
    let aux_count = batch
        .schema()
        .fields()
        .iter()
        .filter(|field| field.name().starts_with(AUX_PREFIX))
        .count()
        / 3;

    for i in 0..p4_count {
        let e = batch
            .column_by_name(&format!("{}{}_E", P4_PREFIX, i))
            .unwrap()
            .as_any()
            .downcast_ref::<Float32Array>()
            .unwrap()
            .value(row) as Float;
        let px = batch
            .column_by_name(&format!("{}{}_Px", P4_PREFIX, i))
            .unwrap()
            .as_any()
            .downcast_ref::<Float32Array>()
            .unwrap()
            .value(row) as Float;
        let py = batch
            .column_by_name(&format!("{}{}_Py", P4_PREFIX, i))
            .unwrap()
            .as_any()
            .downcast_ref::<Float32Array>()
            .unwrap()
            .value(row) as Float;
        let pz = batch
            .column_by_name(&format!("{}{}_Pz", P4_PREFIX, i))
            .unwrap()
            .as_any()
            .downcast_ref::<Float32Array>()
            .unwrap()
            .value(row) as Float;
        p4s.push(Vec4::new(px, py, pz, e));
    }

    for i in 0..aux_count {
        let x = batch
            .column_by_name(&format!("{}{}_x", AUX_PREFIX, i))
            .unwrap()
            .as_any()
            .downcast_ref::<Float32Array>()
            .unwrap()
            .value(row) as Float;
        let y = batch
            .column_by_name(&format!("{}{}_y", AUX_PREFIX, i))
            .unwrap()
            .as_any()
            .downcast_ref::<Float32Array>()
            .unwrap()
            .value(row) as Float;
        let z = batch
            .column_by_name(&format!("{}{}_z", AUX_PREFIX, i))
            .unwrap()
            .as_any()
            .downcast_ref::<Float32Array>()
            .unwrap()
            .value(row) as Float;
        aux.push(Vec3::new(x, y, z));
    }

    // Look the weight column up by name (assumed to be "weight") rather than by the hard-coded
    // ordinal 19, which only matched a four-particle, one-aux-vector column layout.
    let weight = batch
        .column_by_name("weight")
        .unwrap()
        .as_any()
        .downcast_ref::<Float32Array>()
        .unwrap()
        .value(row) as Float;

    Event { p4s, aux, weight }
}

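/// Read a Parquet file into a [`Dataset`].
///
/// Columns named `p4_{i}_E`, `p4_{i}_Px`, `p4_{i}_Py`, `p4_{i}_Pz` become four-momenta, columns
/// named `aux_{i}_x`, `aux_{i}_y`, `aux_{i}_z` become auxiliary vectors, and a weight column
/// (as read by `batch_to_event`) provides per-event weights.
///
/// A minimal usage sketch (marked `ignore`; `"data.parquet"` is a placeholder path):
///
/// ```ignore
/// let dataset = open("data.parquet")?;
/// println!("{} events", dataset.n_events());
/// ```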
pub fn open<T: AsRef<str>>(file_path: T) -> Result<Arc<Dataset>, LadduError> {
    let file_path = Path::new(&*shellexpand::full(file_path.as_ref())?).canonicalize()?;
    let file = File::open(file_path)?;
    let builder = ParquetRecordBatchReaderBuilder::try_new(file)?;
    let reader = builder.build()?;
    let batches: Vec<RecordBatch> = reader.collect::<Result<Vec<_>, _>>()?;

    #[cfg(feature = "rayon")]
    let events: Vec<Arc<Event>> = batches
        .into_par_iter()
        .flat_map(|batch| {
            let num_rows = batch.num_rows();
            let mut local_events = Vec::with_capacity(num_rows);

            for row in 0..num_rows {
                let event = batch_to_event(&batch, row);
                local_events.push(Arc::new(event));
            }
            local_events
        })
        .collect();
    #[cfg(not(feature = "rayon"))]
    let events: Vec<Arc<Event>> = batches
        .into_iter()
        .flat_map(|batch| {
            let num_rows = batch.num_rows();
            let mut local_events = Vec::with_capacity(num_rows);

            for row in 0..num_rows {
                let event = batch_to_event(&batch, row);
                local_events.push(Arc::new(event));
            }
            local_events
        })
        .collect();
    Ok(Arc::new(Dataset::new(events)))
}

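/// Like [`open`], but boosts every event into the rest frame of the four-momentum sum at
/// `indices` while reading.
///
/// A minimal usage sketch (marked `ignore`):
///
/// ```ignore
/// // Boost each event into the rest frame of particles 1, 2, and 3 as it is loaded.
/// let dataset = open_boosted_to_rest_frame_of("data.parquet", [1, 2, 3])?;
/// ```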
pub fn open_boosted_to_rest_frame_of<T: AsRef<str>, I: AsRef<[usize]> + Sync>(
    file_path: T,
    indices: I,
) -> Result<Arc<Dataset>, LadduError> {
    let file_path = Path::new(&*shellexpand::full(file_path.as_ref())?).canonicalize()?;
    let file = File::open(file_path)?;
    let builder = ParquetRecordBatchReaderBuilder::try_new(file)?;
    let reader = builder.build()?;
    let batches: Vec<RecordBatch> = reader.collect::<Result<Vec<_>, _>>()?;

    #[cfg(feature = "rayon")]
    let events: Vec<Arc<Event>> = batches
        .into_par_iter()
        .flat_map(|batch| {
            let num_rows = batch.num_rows();
            let mut local_events = Vec::with_capacity(num_rows);

            for row in 0..num_rows {
                let mut event = batch_to_event(&batch, row);
                event = event.boost_to_rest_frame_of(indices.as_ref());
                local_events.push(Arc::new(event));
            }
            local_events
        })
        .collect();
    #[cfg(not(feature = "rayon"))]
    let events: Vec<Arc<Event>> = batches
        .into_iter()
        .flat_map(|batch| {
            let num_rows = batch.num_rows();
            let mut local_events = Vec::with_capacity(num_rows);

            for row in 0..num_rows {
                let mut event = batch_to_event(&batch, row);
                event = event.boost_to_rest_frame_of(indices.as_ref());
                local_events.push(Arc::new(event));
            }
            local_events
        })
        .collect();
    Ok(Arc::new(Dataset::new(events)))
}

/// A list of [`Dataset`]s produced by [`Dataset::bin_by`], along with the bin edges.
pub struct BinnedDataset {
    datasets: Vec<Arc<Dataset>>,
    edges: Vec<Float>,
}

impl Index<usize> for BinnedDataset {
    type Output = Arc<Dataset>;

    fn index(&self, index: usize) -> &Self::Output {
        &self.datasets[index]
    }
}

impl IndexMut<usize> for BinnedDataset {
    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
        &mut self.datasets[index]
    }
}

impl Deref for BinnedDataset {
    type Target = Vec<Arc<Dataset>>;

    fn deref(&self) -> &Self::Target {
        &self.datasets
    }
}

impl DerefMut for BinnedDataset {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.datasets
    }
}

impl BinnedDataset {
    /// The number of bins.
    pub fn n_bins(&self) -> usize {
        self.datasets.len()
    }

    /// The bin edges (one more edge than there are bins).
    pub fn edges(&self) -> Vec<Float> {
        self.edges.clone()
    }

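    /// The full range covered by the bins, i.e. the first and last bin edges.
    ///
    /// A minimal usage sketch (marked `ignore`):
    ///
    /// ```ignore
    /// let binned = dataset.bin_by(Mass::new([2, 3]), 5, (0.0, 2.0));
    /// let (lo, hi) = binned.range();
    /// assert_eq!(binned.edges().len(), binned.n_bins() + 1);
    /// ```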
    pub fn range(&self) -> (Float, Float) {
        (self.edges[0], self.edges[self.n_bins()])
    }
}

#[cfg(test)]
mod tests {
    use crate::Mass;

    use super::*;
    use approx::{assert_relative_eq, assert_relative_ne};
    use serde::{Deserialize, Serialize};

    #[test]
    fn test_event_creation() {
        let event = test_event();
        assert_eq!(event.p4s.len(), 4);
        assert_eq!(event.aux.len(), 1);
        assert_relative_eq!(event.weight, 0.48)
    }

    #[test]
    fn test_event_p4_sum() {
        let event = test_event();
        let sum = event.get_p4_sum([2, 3]);
        assert_relative_eq!(sum.px(), event.p4s[2].px() + event.p4s[3].px());
        assert_relative_eq!(sum.py(), event.p4s[2].py() + event.p4s[3].py());
        assert_relative_eq!(sum.pz(), event.p4s[2].pz() + event.p4s[3].pz());
        assert_relative_eq!(sum.e(), event.p4s[2].e() + event.p4s[3].e());
    }

    #[test]
    fn test_event_boost() {
        let event = test_event();
        let event_boosted = event.boost_to_rest_frame_of([1, 2, 3]);
        let p4_sum = event_boosted.get_p4_sum([1, 2, 3]);
        assert_relative_eq!(p4_sum.px(), 0.0, epsilon = Float::EPSILON.sqrt());
        assert_relative_eq!(p4_sum.py(), 0.0, epsilon = Float::EPSILON.sqrt());
        assert_relative_eq!(p4_sum.pz(), 0.0, epsilon = Float::EPSILON.sqrt());
    }

    #[test]
    fn test_event_evaluate() {
        let event = test_event();
        let mass = Mass::new([1]);
        assert_relative_eq!(event.evaluate(&mass), 1.007);
    }

    #[test]
    fn test_dataset_size_check() {
        let mut dataset = Dataset::default();
        assert_eq!(dataset.n_events(), 0);
        dataset.events.push(Arc::new(test_event()));
        assert_eq!(dataset.n_events(), 1);
    }

    #[test]
    fn test_dataset_sum() {
        let dataset = test_dataset();
        let dataset2 = Dataset::new(vec![Arc::new(Event {
            p4s: test_event().p4s,
            aux: test_event().aux,
            weight: 0.52,
        })]);
        let dataset_sum = &dataset + &dataset2;
        assert_eq!(dataset_sum[0].weight, dataset[0].weight);
        assert_eq!(dataset_sum[1].weight, dataset2[0].weight);
    }

    #[test]
    fn test_dataset_weights() {
        let mut dataset = Dataset::default();
        dataset.events.push(Arc::new(test_event()));
        dataset.events.push(Arc::new(Event {
            p4s: test_event().p4s,
            aux: test_event().aux,
            weight: 0.52,
        }));
        let weights = dataset.weights();
        assert_eq!(weights.len(), 2);
        assert_relative_eq!(weights[0], 0.48);
        assert_relative_eq!(weights[1], 0.52);
        assert_relative_eq!(dataset.n_events_weighted(), 1.0);
    }

    #[test]
    fn test_dataset_filtering() {
        let mut dataset = Dataset::default();
        dataset.events.push(Arc::new(Event {
            p4s: vec![Vec3::new(0.0, 0.0, 5.0).with_mass(0.0)],
            aux: vec![],
            weight: 1.0,
        }));
        dataset.events.push(Arc::new(Event {
            p4s: vec![Vec3::new(0.0, 0.0, 5.0).with_mass(0.5)],
            aux: vec![],
            weight: 1.0,
        }));
        dataset.events.push(Arc::new(Event {
            p4s: vec![Vec3::new(0.0, 0.0, 5.0).with_mass(1.1)],
            aux: vec![],
            weight: 1.0,
        }));

        let mass = Mass::new([0]);
        let expression = mass.gt(0.0).and(&mass.lt(1.0));

        let filtered = dataset.filter(&expression);
        assert_eq!(filtered.n_events(), 1);
        assert_relative_eq!(
            mass.value(&filtered[0]),
            0.5,
            epsilon = Float::EPSILON.sqrt()
        );
    }

    #[test]
    fn test_dataset_boost() {
        let dataset = test_dataset();
        let dataset_boosted = dataset.boost_to_rest_frame_of([1, 2, 3]);
        let p4_sum = dataset_boosted[0].get_p4_sum([1, 2, 3]);
        assert_relative_eq!(p4_sum.px(), 0.0, epsilon = Float::EPSILON.sqrt());
        assert_relative_eq!(p4_sum.py(), 0.0, epsilon = Float::EPSILON.sqrt());
        assert_relative_eq!(p4_sum.pz(), 0.0, epsilon = Float::EPSILON.sqrt());
    }

    #[test]
    fn test_dataset_evaluate() {
        let dataset = test_dataset();
        let mass = Mass::new([1]);
        assert_relative_eq!(dataset.evaluate(&mass)[0], 1.007);
    }

    #[test]
    fn test_binned_dataset() {
        let dataset = Dataset::new(vec![
            Arc::new(Event {
                p4s: vec![Vec3::new(0.0, 0.0, 1.0).with_mass(1.0)],
                aux: vec![],
                weight: 1.0,
            }),
            Arc::new(Event {
                p4s: vec![Vec3::new(0.0, 0.0, 2.0).with_mass(2.0)],
                aux: vec![],
                weight: 2.0,
            }),
        ]);

        #[derive(Clone, Serialize, Deserialize, Debug)]
        struct BeamEnergy;
        impl Display for BeamEnergy {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, "BeamEnergy")
            }
        }
        #[typetag::serde]
        impl Variable for BeamEnergy {
            fn value(&self, event: &Event) -> Float {
                event.p4s[0].e()
            }
        }
        assert_eq!(BeamEnergy.to_string(), "BeamEnergy");

        let binned = dataset.bin_by(BeamEnergy, 2, (0.0, 3.0));

        assert_eq!(binned.n_bins(), 2);
        assert_eq!(binned.edges().len(), 3);
        assert_relative_eq!(binned.edges()[0], 0.0);
        assert_relative_eq!(binned.edges()[2], 3.0);
        assert_eq!(binned[0].n_events(), 1);
        assert_relative_eq!(binned[0].n_events_weighted(), 1.0);
        assert_eq!(binned[1].n_events(), 1);
        assert_relative_eq!(binned[1].n_events_weighted(), 2.0);
    }

    #[test]
    fn test_dataset_bootstrap() {
        let mut dataset = test_dataset();
        dataset.events.push(Arc::new(Event {
            p4s: test_event().p4s.clone(),
            aux: test_event().aux.clone(),
            weight: 1.0,
        }));
        assert_relative_ne!(dataset[0].weight, dataset[1].weight);

        let bootstrapped = dataset.bootstrap(43);
        assert_eq!(bootstrapped.n_events(), dataset.n_events());
        assert_relative_eq!(bootstrapped[0].weight, bootstrapped[1].weight);

        let empty_dataset = Dataset::default();
        let empty_bootstrap = empty_dataset.bootstrap(43);
        assert_eq!(empty_bootstrap.n_events(), 0);
    }

    #[test]
    fn test_event_display() {
        let event = test_event();
        let display_string = format!("{}", event);
        assert_eq!(
            display_string,
            "Event:\n p4s:\n [e = 8.74700; p = (0.00000, 0.00000, 8.74700); m = 0.00000]\n [e = 1.10334; p = (0.11900, 0.37400, 0.22200); m = 1.00700]\n [e = 3.13671; p = (-0.11200, 0.29300, 3.08100); m = 0.49800]\n [e = 5.50925; p = (-0.00700, -0.66700, 5.44600); m = 0.49800]\n eps:\n [0.385, 0.022, 0]\n weight:\n 0.48\n"
        );
    }
}