use std::ops::Add;
use std::{fmt::Display, fs::File, iter::repeat_with, path::Path, sync::Arc};
use itertools::{izip, Either, Itertools};
use nalgebra::Vector3;
use oxyroot::{ReaderTree, RootFile, Slice};
use parquet::record::Field as ParquetField;
use parquet::{
file::reader::{FileReader, SerializedFileReader},
record::Row,
};
use rayon::prelude::*;
use tracing::info;
use crate::convert;
use crate::{errors::RustitudeError, prelude::FourMomentum, Field};
/// A single event: a weight, the beam and recoil four-momenta, the daughter
/// four-momenta and an EPS vector.
///
/// The readers in this file fill the recoil from the first entry of the
/// `*_FinalState` columns/branches and the daughters from the remaining entries.
#[derive(Debug, Default, Clone)]
pub struct Event<F: Field + 'static> {
/// Index of the event; the readers in this file set it to the event's
/// position in the input, and `Dataset::reindex` resets it to the position
/// in the dataset.
pub index: usize,
/// Event weight (read from the `Weight` column/branch).
pub weight: F,
/// Four-momentum built from the `*_Beam` columns/branches.
pub beam_p4: FourMomentum<F>,
/// Four-momentum built from entry 0 of the `*_FinalState` lists.
pub recoil_p4: FourMomentum<F>,
/// Four-momenta built from entries 1.. of the `*_FinalState` lists.
pub daughter_p4s: Vec<FourMomentum<F>>,
/// Three-component EPS vector; how it is filled depends on the `ReadMethod`
/// used when reading.
pub eps: Vector3<F>,
}
impl<F: Field + 'static> Display for Event<F> {
    /// Writes a multi-line summary of the event: index, weight, beam and
    /// recoil four-momenta, one tab-indented line per daughter, and the three
    /// EPS components.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            index,
            weight,
            beam_p4,
            recoil_p4,
            daughter_p4s,
            eps,
        } = self;
        writeln!(f, "Index: {}", index)?;
        writeln!(f, "Weight: {}", weight)?;
        writeln!(f, "Beam P4: {}", beam_p4)?;
        writeln!(f, "Recoil P4: {}", recoil_p4)?;
        writeln!(f, "Daughters:")?;
        daughter_p4s
            .iter()
            .enumerate()
            .try_for_each(|(i, p4)| writeln!(f, "\t{i} -> {p4}"))?;
        // The last write's result is the function's result.
        writeln!(f, "EPS: [{}, {}, {}]", eps[0], eps[1], eps[2])
    }
}
/// Selects how the readers in this file fill an event's beam four-momentum and
/// EPS vector.
#[derive(Copy, Clone)]
pub enum ReadMethod<F: Field> {
/// Beam four-momentum comes from the `*_Beam` columns/branches and the EPS
/// vector from the `EPS` column/branch.
Standard,
/// The beam momentum columns/branches carry the EPS components instead of a
/// momentum; the beam four-momentum is built from `E_Beam` alone, with
/// `pz = E_Beam`.
EPSInBeam,
/// A fixed EPS vector `(x, y, z)` supplied by the caller; the beam
/// four-momentum comes from the `*_Beam` columns/branches.
EPS(F, F, F),
}
impl<F: Field> ReadMethod<F> {
    /// Builds a [`ReadMethod::EPS`] from a magnitude `p_gamma` and an angle
    /// `phi`: the components are `(p_gamma * cos(phi), p_gamma * sin(phi), 0)`.
    pub fn from_linear_polarization(p_gamma: F, phi: F) -> Self {
        let eps_x = p_gamma * F::cos(phi);
        let eps_y = p_gamma * F::sin(phi);
        let eps_z = F::zero();
        Self::EPS(eps_x, eps_y, eps_z)
    }
}
impl<F: Field> Event<F> {
/// Returns the Euclidean length of the EPS vector,
/// `sqrt(eps.x^2 + eps.y^2 + eps.z^2)`.
pub fn eps_mag(&self) -> F {
    let x_sq = F::powi(self.eps.x, 2);
    let y_sq = F::powi(self.eps.y, 2);
    let z_sq = F::powi(self.eps.z, 2);
    // Same left-to-right summation order as before: (x^2 + y^2) + z^2.
    let sum_sq = x_sq + y_sq + z_sq;
    F::sqrt(sum_sq)
}
fn read_parquet_row(
index: usize,
row: Result<Row, parquet::errors::ParquetError>,
method: ReadMethod<F>,
) -> Result<Self, RustitudeError> {
let mut event = Self {
index,
..Default::default()
};
let mut e_fs: Vec<F> = Vec::new();
let mut px_fs: Vec<F> = Vec::new();
let mut py_fs: Vec<F> = Vec::new();
let mut pz_fs: Vec<F> = Vec::new();
for (name, field) in row?.get_column_iter() {
match (name.as_str(), field) {
("E_Beam", ParquetField::Float(value)) => {
event.beam_p4.set_e(convert!(*value, F));
if matches!(method, ReadMethod::EPSInBeam) {
event.beam_p4.set_pz(convert!(*value, F));
}
}
("Px_Beam", ParquetField::Float(value)) => {
if matches!(method, ReadMethod::EPSInBeam) {
event.eps[0] = convert!(*value, F);
} else {
event.beam_p4.set_px(convert!(*value, F));
}
}
("Py_Beam", ParquetField::Float(value)) => {
if matches!(method, ReadMethod::EPSInBeam) {
event.eps[1] = convert!(*value, F);
} else {
event.beam_p4.set_py(convert!(*value, F));
}
}
("Pz_Beam", ParquetField::Float(value)) => {
if !matches!(method, ReadMethod::EPSInBeam) {
event.beam_p4.set_pz(convert!(*value, F));
}
}
("Weight", ParquetField::Float(value)) => {
event.weight = convert!(*value, F);
}
("EPS", ParquetField::ListInternal(list)) => match method {
ReadMethod::Standard => {
event.eps = Vector3::from_vec(
list.elements()
.iter()
.map(|field| {
if let ParquetField::Float(value) = field {
convert!(*value, F)
} else {
panic!()
}
})
.collect(),
);
}
ReadMethod::EPS(x, y, z) => *event.eps = *Vector3::new(x, y, z),
_ => {}
},
("E_FinalState", ParquetField::ListInternal(list)) => {
e_fs = list
.elements()
.iter()
.map(|field| {
if let ParquetField::Float(value) = field {
convert!(*value, F)
} else {
panic!()
}
})
.collect()
}
("Px_FinalState", ParquetField::ListInternal(list)) => {
px_fs = list
.elements()
.iter()
.map(|field| {
if let ParquetField::Float(value) = field {
convert!(*value, F)
} else {
panic!()
}
})
.collect()
}
("Py_FinalState", ParquetField::ListInternal(list)) => {
py_fs = list
.elements()
.iter()
.map(|field| {
if let ParquetField::Float(value) = field {
convert!(*value, F)
} else {
panic!()
}
})
.collect()
}
("Pz_FinalState", ParquetField::ListInternal(list)) => {
pz_fs = list
.elements()
.iter()
.map(|field| {
if let ParquetField::Float(value) = field {
convert!(*value, F)
} else {
panic!()
}
})
.collect()
}
_ => {}
}
}
event.recoil_p4 = FourMomentum::new(e_fs[0], px_fs[0], py_fs[0], pz_fs[0]);
event.daughter_p4s = e_fs[1..]
.iter()
.zip(px_fs[1..].iter())
.zip(py_fs[1..].iter())
.zip(pz_fs[1..].iter())
.map(|(((e, px), py), pz)| FourMomentum::new(*e, *px, *py, *pz))
.collect();
Ok(event)
}
}
/// A collection of [`Event`]s held behind an [`Arc`], so cloning a `Dataset`
/// shares the event list instead of copying it.
#[derive(Default, Debug, Clone)]
pub struct Dataset<F: Field + 'static> {
/// The shared list of events.
pub events: Arc<Vec<Event<F>>>,
}
impl<F: Field + 'static> Dataset<F> {
pub fn reindex(&mut self) {
self.events = Arc::new(
(*self.events)
.clone()
.iter_mut()
.enumerate()
.map(|(i, event)| {
event.index = i;
event.clone()
})
.collect(),
)
}
/// Returns the weight of every event, in dataset order.
pub fn weights(&self) -> Vec<F> {
    let mut all_weights = Vec::with_capacity(self.events.len());
    for event in self.events.iter() {
        all_weights.push(event.weight);
    }
    all_weights
}
/// Returns the weights of the events at the given positions, in the order the
/// positions are listed.
///
/// # Panics
///
/// Panics if any entry of `indices` is not a valid position in the event list.
pub fn weights_indexed(&self, indices: &[usize]) -> Vec<F> {
    let mut selected = Vec::with_capacity(indices.len());
    for &position in indices {
        selected.push(self.events[position].weight);
    }
    selected
}
/// Bins the events by the mass of the summed four-momentum of selected
/// daughters.
///
/// `daughter_indices` lists the positions in `daughter_p4s` to sum; `None`
/// means daughters 0 and 1. `range` and `bins` are forwarded to
/// [`Dataset::get_binned_indices`], whose result is returned unchanged:
/// `(indices per bin, underflow indices, overflow indices)`.
///
/// # Panics
///
/// Panics if a daughter index is out of range for some event.
pub fn split_m(
    &self,
    range: (F, F),
    bins: usize,
    daughter_indices: Option<Vec<usize>>,
) -> (Vec<Vec<usize>>, Vec<usize>, Vec<usize>) {
    // Resolve the default once instead of cloning the list for every event.
    let daughters = daughter_indices.unwrap_or_else(|| vec![0, 1]);
    let mass = |event: &Event<F>| {
        let p4: FourMomentum<F> = daughters.iter().map(|&i| event.daughter_p4s[i]).sum();
        p4.m()
    };
    self.get_binned_indices(mass, range, bins)
}
/// Reads a dataset from the Parquet file at `path`, converting each row with
/// `Event::read_parquet_row` and numbering events by row position.
///
/// # Errors
///
/// Returns an error if the file cannot be opened or read as Parquet, or as
/// soon as any row fails to convert.
pub fn from_parquet(path: &str, method: ReadMethod<F>) -> Result<Self, RustitudeError> {
    let reader = SerializedFileReader::new(File::open(Path::new(path))?)?;
    let mut events: Vec<Event<F>> = Vec::new();
    for (i, row) in reader.get_row_iter(None)?.enumerate() {
        // Stop at the first row that fails, as the fallible collect did.
        events.push(Event::read_parquet_row(i, row, method)?);
    }
    Ok(Self::new(events))
}
/// Reads the scalar branch `branch` from `ttree` and converts every value to
/// `F`. `path` is only used in the error message.
///
/// Note that, despite the function's name, the branch is iterated as `f64`.
///
/// # Errors
///
/// Returns `RustitudeError::OxyrootError` if the branch does not exist or
/// cannot be iterated as `f64`.
fn extract_f32(path: &str, ttree: &ReaderTree, branch: &str) -> Result<Vec<F>, RustitudeError> {
    let found = match ttree.branch(branch) {
        Some(found) => found,
        None => {
            return Err(RustitudeError::OxyrootError(format!(
                "Could not find {} branch in {}",
                branch, path
            )))
        }
    };
    let values = match found.as_iter::<f64>() {
        Ok(values) => values,
        Err(err) => return Err(RustitudeError::OxyrootError(err.to_string())),
    };
    Ok(values.map(|val| convert!(val, F)).collect())
}
/// Reads the array-valued branch `branch` from `ttree`, producing one
/// `Vec<F>` per entry. `path` is only used in the error message.
///
/// Note that, despite the function's name, the branch is iterated as
/// `Slice<f64>`.
///
/// # Errors
///
/// Returns `RustitudeError::OxyrootError` if the branch does not exist or
/// cannot be iterated as `Slice<f64>`.
fn extract_vec_f32(
    path: &str,
    ttree: &ReaderTree,
    branch: &str,
) -> Result<Vec<Vec<F>>, RustitudeError> {
    let found = match ttree.branch(branch) {
        Some(found) => found,
        None => {
            return Err(RustitudeError::OxyrootError(format!(
                "Could not find {} branch in {}",
                branch, path
            )))
        }
    };
    let slices = match found.as_iter::<Slice<f64>>() {
        Ok(slices) => slices,
        Err(err) => return Err(RustitudeError::OxyrootError(err.to_string())),
    };
    let mut rows: Vec<Vec<F>> = Vec::new();
    for slice in slices {
        let row: Vec<F> = slice
            .into_vec()
            .into_iter()
            .map(|val| convert!(val, F))
            .collect();
        rows.push(row);
    }
    Ok(rows)
}
/// Reads a dataset from the `kin` tree of the ROOT file at `path`.
///
/// The branches `Weight`, `E_Beam`, `Px_Beam`, `Py_Beam`, `Pz_Beam`,
/// `E_FinalState`, `Px_FinalState`, `Py_FinalState` and `Pz_FinalState` are
/// always read; `EPS` is read only for [`ReadMethod::Standard`]. Events are
/// numbered by entry position. Entry 0 of the final-state arrays becomes the
/// recoil four-momentum and the remaining entries become the daughters.
///
/// How `method` is applied:
///
/// * [`ReadMethod::Standard`]: beam from the `*_Beam` branches, EPS from the
///   first three entries of the `EPS` branch.
/// * [`ReadMethod::EPSInBeam`]: beam is `(E, 0, 0, E)` and EPS is
///   `(Px_Beam, Py_Beam, Pz_Beam)`.
/// * [`ReadMethod::EPS`]: beam from the `*_Beam` branches, EPS is the supplied
///   vector.
///
/// The branches are zipped together, so the dataset has as many events as the
/// shortest branch.
///
/// # Errors
///
/// Returns `RustitudeError::OxyrootError` if the file or tree cannot be
/// opened, if a branch is missing or unreadable, if an event has empty
/// final-state arrays, or if a `Standard` event has fewer than three EPS
/// components.
pub fn from_root(path: &str, method: ReadMethod<F>) -> Result<Self, RustitudeError> {
    let ttree = RootFile::open(path)
        .map_err(|err| RustitudeError::OxyrootError(err.to_string()))?
        .get_tree("kin")
        .map_err(|err| RustitudeError::OxyrootError(err.to_string()))?;
    let weight: Vec<F> = Self::extract_f32(path, &ttree, "Weight")?;
    let e_beam: Vec<F> = Self::extract_f32(path, &ttree, "E_Beam")?;
    let px_beam: Vec<F> = Self::extract_f32(path, &ttree, "Px_Beam")?;
    let py_beam: Vec<F> = Self::extract_f32(path, &ttree, "Py_Beam")?;
    let pz_beam: Vec<F> = Self::extract_f32(path, &ttree, "Pz_Beam")?;
    let e_fs: Vec<Vec<F>> = Self::extract_vec_f32(path, &ttree, "E_FinalState")?;
    let px_fs: Vec<Vec<F>> = Self::extract_vec_f32(path, &ttree, "Px_FinalState")?;
    let py_fs: Vec<Vec<F>> = Self::extract_vec_f32(path, &ttree, "Py_FinalState")?;
    let pz_fs: Vec<Vec<F>> = Self::extract_vec_f32(path, &ttree, "Pz_FinalState")?;
    // Only `Standard` needs the EPS branch; the other methods get a zero
    // placeholder so the zip below has something to iterate over.
    let eps_extracted: Vec<Vec<F>> = if matches!(method, ReadMethod::Standard) {
        Self::extract_vec_f32(path, &ttree, "EPS")?
    } else {
        vec![vec![F::zero(); 3]; weight.len()]
    };
    let events = izip!(
        weight,
        e_beam,
        px_beam,
        py_beam,
        pz_beam,
        e_fs,
        px_fs,
        py_fs,
        pz_fs,
        eps_extracted
    )
    .enumerate()
    .map(
        |(i, (w, e_b, px_b, py_b, pz_b, e_f, px_f, py_f, pz_f, eps_vec))| {
            let (beam_p4, eps) = match method {
                ReadMethod::Standard => {
                    // Report a short EPS array instead of panicking on it.
                    if eps_vec.len() < 3 {
                        return Err(RustitudeError::OxyrootError(format!(
                            "Event {} in {} has {} EPS components, expected 3",
                            i,
                            path,
                            eps_vec.len()
                        )));
                    }
                    (
                        FourMomentum::new(e_b, px_b, py_b, pz_b),
                        Vector3::new(eps_vec[0], eps_vec[1], eps_vec[2]),
                    )
                }
                ReadMethod::EPSInBeam => (
                    FourMomentum::new(e_b, F::zero(), F::zero(), e_b),
                    Vector3::new(px_b, py_b, pz_b),
                ),
                ReadMethod::EPS(x, y, z) => (
                    FourMomentum::new(e_b, px_b, py_b, pz_b),
                    Vector3::new(x, y, z),
                ),
            };
            // Entry 0 of every final-state array is the recoil particle.
            let recoil_p4 = match (e_f.first(), px_f.first(), py_f.first(), pz_f.first()) {
                (Some(&e), Some(&px), Some(&py), Some(&pz)) => FourMomentum::new(e, px, py, pz),
                _ => {
                    return Err(RustitudeError::OxyrootError(format!(
                        "Event {} in {} has an empty final state",
                        i, path
                    )))
                }
            };
            let daughter_p4s = izip!(e_f.iter(), px_f.iter(), py_f.iter(), pz_f.iter())
                .skip(1)
                .map(|(e, px, py, pz)| FourMomentum::new(*e, *px, *py, *pz))
                .collect();
            Ok(Event {
                index: i,
                weight: w,
                beam_p4,
                recoil_p4,
                daughter_p4s,
                eps,
            })
        },
    )
    .collect::<Result<Vec<Event<F>>, RustitudeError>>()?;
    Ok(Self::new(events))
}
/// Wraps `events` in a new dataset and logs the number of events at `info`
/// level.
pub fn new(events: Vec<Event<F>>) -> Self {
    let n_events = events.len();
    info!("Dataset created with {} events", n_events);
    let events = Arc::new(events);
    Self { events }
}
pub fn is_empty(&self) -> bool {
self.events.is_empty()
}
pub fn len(&self) -> usize {
self.events.len()
}
pub fn get_bootstrap_indices(&self, seed: usize) -> Vec<usize> {
fastrand::seed(seed as u64);
let mut inds: Vec<usize> = repeat_with(|| fastrand::usize(0..self.len()))
.take(self.len())
.collect();
inds.sort_unstable();
inds
}
/// Splits the events by `query`, evaluated in parallel.
///
/// Returns `(selected, rejected)`, each holding the `index` field (not the
/// position) of the matching events, sorted in ascending order.
pub fn get_selected_indices(
    &self,
    query: impl Fn(&Event<F>) -> bool + Sync + Send,
) -> (Vec<usize>, Vec<usize>) {
    let (mut accepted, mut rejected): (Vec<usize>, Vec<usize>) =
        self.events.par_iter().partition_map(|event| {
            let id = event.index;
            if query(event) {
                Either::Left(id)
            } else {
                Either::Right(id)
            }
        });
    // Sort both halves so the output order is deterministic.
    accepted.sort_unstable();
    rejected.sort_unstable();
    (accepted, rejected)
}
/// Bins the events by the value of `variable` into `nbins` equal-width bins
/// spanning `range`.
///
/// Bin edges are `range.0 + width * m` for `m` in `0..=nbins`, with
/// `width = (range.1 - range.0) / nbins`. Bin `k` holds events with
/// `edge[k] <= value < edge[k + 1]`. Underflow holds events with
/// `value < edge[0]` and overflow holds events with `value >= edge[nbins]`.
/// Events whose value compares false against everything (e.g. NaN) land
/// nowhere.
///
/// Returns `(indices per bin, underflow, overflow)`; every list contains the
/// `index` field of the matching events, sorted in ascending order.
///
/// `variable` is evaluated once per event, in parallel.
pub fn get_binned_indices(
    &self,
    variable: impl Fn(&Event<F>) -> F + Sync + Send,
    range: (F, F),
    nbins: usize,
) -> (Vec<Vec<usize>>, Vec<usize>, Vec<usize>) {
    let width = (range.1 - range.0) / convert!(nbins, F);
    let edges: Vec<F> = (0..=nbins)
        .map(|m| F::mul_add(width, convert!(m, F), range.0))
        .collect();
    // (lower, upper) edge pair for every bin.
    let intervals: Vec<(F, F)> = edges.iter().copied().tuple_windows().collect();
    let (lowest, highest) = (edges[0], edges[nbins]);

    // One parallel pass: classify each event as (index, underflow?, overflow?, bin).
    let classified: Vec<(usize, bool, bool, Option<usize>)> = self
        .events
        .par_iter()
        .map(|event| {
            let value = variable(event);
            // Candidate bin = number of edges <= value, minus one. It is then
            // checked against the exact bin condition, so degenerate edges
            // (reversed range, NaN) can never produce a wrong assignment.
            let bin = edges
                .partition_point(|edge| *edge <= value)
                .checked_sub(1)
                .filter(|&k| {
                    intervals
                        .get(k)
                        .map_or(false, |&(lb, ub)| lb <= value && value < ub)
                });
            (event.index, value < lowest, value >= highest, bin)
        })
        .collect();

    let mut binned_indices: Vec<Vec<usize>> = vec![Vec::new(); nbins];
    let mut underflow: Vec<usize> = Vec::new();
    let mut overflow: Vec<usize> = Vec::new();
    for (index, below, above, bin) in classified {
        // Underflow and overflow are tested independently of each other.
        if below {
            underflow.push(index);
        }
        if above {
            overflow.push(index);
        }
        if let Some(k) = bin {
            binned_indices[k].push(index);
        }
    }
    underflow.sort_unstable();
    overflow.sort_unstable();
    for bin in &mut binned_indices {
        bin.sort_unstable();
    }
    (binned_indices, underflow, overflow)
}
}
impl<F: Field + 'static> Add for Dataset<F> {
type Output = Self;
fn add(self, other: Self) -> Self::Output {
let mut combined_events = Vec::with_capacity(self.events.len() + other.events.len());
combined_events.extend(Arc::try_unwrap(self.events).unwrap_or_else(|arc| (*arc).clone()));
combined_events.extend(Arc::try_unwrap(other.events).unwrap_or_else(|arc| (*arc).clone()));
Self {
events: Arc::new(combined_events),
}
}
}