laddu_python/data.rs
1use crate::utils::variables::PyVariable;
2use laddu_core::{
3 data::{open, open_boosted_to_rest_frame_of, BinnedDataset, Dataset, Event},
4 Float,
5};
6use numpy::PyArray1;
7use pyo3::{
8 exceptions::{PyIndexError, PyTypeError},
9 prelude::*,
10};
11use std::{path::PathBuf, sync::Arc};
12
13use crate::utils::vectors::{PyVec3, PyVec4};
14
15/// A single event
16///
17/// Events are composed of a set of 4-momenta of particles in the overall
18/// center-of-momentum frame, polarizations or helicities described by 3-vectors, and a
19/// weight
20///
21/// Parameters
22/// ----------
23/// p4s : list of Vec4
24/// 4-momenta of each particle in the event in the overall center-of-momentum frame
25/// aux: list of Vec3
26/// 3-vectors describing auxiliary data for each particle given in `p4s`
27/// weight : float
28/// The weight associated with this event
29/// rest_frame_indices : list of int, optional
30/// If supplied, the event will be boosted to the rest frame of the 4-momenta at the
31/// given indices
32///
33#[pyclass(name = "Event", module = "laddu")]
34#[derive(Clone)]
35pub struct PyEvent(pub Arc<Event>);
36
37#[pymethods]
38impl PyEvent {
39 #[new]
40 #[pyo3(signature = (p4s, aux, weight, *, rest_frame_indices=None))]
41 fn new(
42 p4s: Vec<PyVec4>,
43 aux: Vec<PyVec3>,
44 weight: Float,
45 rest_frame_indices: Option<Vec<usize>>,
46 ) -> Self {
47 let event = Event {
48 p4s: p4s.into_iter().map(|arr| arr.0).collect(),
49 aux: aux.into_iter().map(|arr| arr.0).collect(),
50 weight,
51 };
52 if let Some(indices) = rest_frame_indices {
53 Self(Arc::new(event.boost_to_rest_frame_of(indices)))
54 } else {
55 Self(Arc::new(event))
56 }
57 }
58 fn __str__(&self) -> String {
59 self.0.to_string()
60 }
61 /// The list of 4-momenta for each particle in the event
62 ///
63 #[getter]
64 fn get_p4s(&self) -> Vec<PyVec4> {
65 self.0.p4s.iter().map(|p4| PyVec4(*p4)).collect()
66 }
67 /// The list of 3-vectors describing the auxiliary data of particles in
68 /// the event
69 ///
70 #[getter]
71 fn get_aux(&self) -> Vec<PyVec3> {
72 self.0.aux.iter().map(|eps_vec| PyVec3(*eps_vec)).collect()
73 }
74 /// The weight of this event relative to others in a Dataset
75 ///
76 #[getter]
77 fn get_weight(&self) -> Float {
78 self.0.weight
79 }
80 /// Get the sum of the four-momenta within the event at the given indices
81 ///
82 /// Parameters
83 /// ----------
84 /// indices : list of int
85 /// The indices of the four-momenta to sum
86 ///
87 /// Returns
88 /// -------
89 /// Vec4
90 /// The result of summing the given four-momenta
91 ///
92 fn get_p4_sum(&self, indices: Vec<usize>) -> PyVec4 {
93 PyVec4(self.0.get_p4_sum(indices))
94 }
95 /// Boost all the four-momenta in the event to the rest frame of the given set of
96 /// four-momenta by indices.
97 ///
98 /// Parameters
99 /// ----------
100 /// indices : list of int
101 /// The indices of the four-momenta to sum
102 ///
103 /// Returns
104 /// -------
105 /// Event
106 /// The boosted event
107 ///
108 pub fn boost_to_rest_frame_of(&self, indices: Vec<usize>) -> Self {
109 PyEvent(Arc::new(self.0.boost_to_rest_frame_of(indices)))
110 }
111}
112
113/// A set of Events
114///
115/// Datasets can be created from lists of Events or by using the provided ``laddu.open`` function
116///
117/// Datasets can also be indexed directly to access individual Events
118///
119/// Parameters
120/// ----------
121/// events : list of Event
122///
123/// See Also
124/// --------
125/// laddu.open
126///
127#[pyclass(name = "DatasetBase", module = "laddu", subclass)]
128#[derive(Clone)]
129pub struct PyDataset(pub Arc<Dataset>);
130
131#[pymethods]
132impl PyDataset {
133 #[new]
134 fn new(events: Vec<PyEvent>) -> Self {
135 Self(Arc::new(Dataset::new(
136 events.into_iter().map(|event| event.0).collect(),
137 )))
138 }
139 fn __len__(&self) -> usize {
140 self.0.n_events()
141 }
142 fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyDataset> {
143 if let Ok(other_ds) = other.extract::<PyRef<PyDataset>>() {
144 Ok(PyDataset(Arc::new(self.0.as_ref() + other_ds.0.as_ref())))
145 } else if let Ok(other_int) = other.extract::<usize>() {
146 if other_int == 0 {
147 Ok(self.clone())
148 } else {
149 Err(PyTypeError::new_err(
150 "Addition with an integer for this type is only defined for 0",
151 ))
152 }
153 } else {
154 Err(PyTypeError::new_err("Unsupported operand type for +"))
155 }
156 }
157 fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyDataset> {
158 if let Ok(other_ds) = other.extract::<PyRef<PyDataset>>() {
159 Ok(PyDataset(Arc::new(other_ds.0.as_ref() + self.0.as_ref())))
160 } else if let Ok(other_int) = other.extract::<usize>() {
161 if other_int == 0 {
162 Ok(self.clone())
163 } else {
164 Err(PyTypeError::new_err(
165 "Addition with an integer for this type is only defined for 0",
166 ))
167 }
168 } else {
169 Err(PyTypeError::new_err("Unsupported operand type for +"))
170 }
171 }
172 /// Get the number of Events in the Dataset
173 ///
174 /// Returns
175 /// -------
176 /// n_events : int
177 /// The number of Events
178 ///
179 #[getter]
180 fn n_events(&self) -> usize {
181 self.0.n_events()
182 }
183 /// Get the weighted number of Events in the Dataset
184 ///
185 /// Returns
186 /// -------
187 /// n_events : float
188 /// The sum of all Event weights
189 ///
190 #[getter]
191 fn n_events_weighted(&self) -> Float {
192 self.0.n_events_weighted()
193 }
194 /// The weights associated with the Dataset
195 ///
196 /// Returns
197 /// -------
198 /// weights : array_like
199 /// A ``numpy`` array of Event weights
200 ///
201 #[getter]
202 fn weights<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<Float>> {
203 PyArray1::from_slice(py, &self.0.weights())
204 }
205 /// The internal list of Events stored in the Dataset
206 ///
207 /// Returns
208 /// -------
209 /// events : list of Event
210 /// The Events in the Dataset
211 ///
212 #[getter]
213 fn events(&self) -> Vec<PyEvent> {
214 self.0
215 .events
216 .iter()
217 .map(|rust_event| PyEvent(rust_event.clone()))
218 .collect()
219 }
220 fn __getitem__(&self, index: usize) -> PyEvent {
221 PyEvent(Arc::new(self.0[index].clone()))
222 }
223 /// Separates a Dataset into histogram bins by a Variable value
224 ///
225 /// Parameters
226 /// ----------
227 /// variable : {laddu.Mass, laddu.CosTheta, laddu.Phi, laddu.PolAngle, laddu.PolMagnitude, laddu.Mandelstam}
228 /// The Variable by which each Event is binned
229 /// bins : int
230 /// The number of equally-spaced bins
231 /// range : tuple[float, float]
232 /// The minimum and maximum bin edges
233 ///
234 /// Returns
235 /// -------
236 /// datasets : BinnedDataset
237 /// A pub structure that holds a list of Datasets binned by the given `variable`
238 ///
239 /// See Also
240 /// --------
241 /// laddu.Mass
242 /// laddu.CosTheta
243 /// laddu.Phi
244 /// laddu.PolAngle
245 /// laddu.PolMagnitude
246 /// laddu.Mandelstam
247 ///
248 /// Raises
249 /// ------
250 /// TypeError
251 /// If the given `variable` is not a valid variable
252 ///
253 #[pyo3(signature = (variable, bins, range))]
254 fn bin_by(
255 &self,
256 variable: Bound<'_, PyAny>,
257 bins: usize,
258 range: (Float, Float),
259 ) -> PyResult<PyBinnedDataset> {
260 let py_variable = variable.extract::<PyVariable>()?;
261 Ok(PyBinnedDataset(self.0.bin_by(py_variable, bins, range)))
262 }
263 /// Generate a new bootstrapped Dataset by randomly resampling the original with replacement
264 ///
265 /// The new Dataset is resampled with a random generator seeded by the provided `seed`
266 ///
267 /// Parameters
268 /// ----------
269 /// seed : int
270 /// The random seed used in the resampling process
271 ///
272 /// Returns
273 /// -------
274 /// Dataset
275 /// A bootstrapped Dataset
276 ///
277 fn bootstrap(&self, seed: usize) -> PyDataset {
278 PyDataset(self.0.bootstrap(seed))
279 }
280 /// Boost all the four-momenta in all events to the rest frame of the given set of
281 /// four-momenta by indices.
282 ///
283 /// Parameters
284 /// ----------
285 /// indices : list of int
286 /// The indices of the four-momenta to sum
287 ///
288 /// Returns
289 /// -------
290 /// Dataset
291 /// The boosted dataset
292 ///
293 pub fn boost_to_rest_frame_of(&self, indices: Vec<usize>) -> PyDataset {
294 PyDataset(self.0.boost_to_rest_frame_of(indices))
295 }
296}
297
298/// A collection of Datasets binned by a Variable
299///
300/// BinnedDatasets can be indexed directly to access the underlying Datasets by bin
301///
302/// See Also
303/// --------
304/// laddu.Dataset.bin_by
305///
306#[pyclass(name = "BinnedDataset", module = "laddu")]
307pub struct PyBinnedDataset(BinnedDataset);
308
309#[pymethods]
310impl PyBinnedDataset {
311 fn __len__(&self) -> usize {
312 self.0.n_bins()
313 }
314 /// The number of bins in the BinnedDataset
315 ///
316 #[getter]
317 fn n_bins(&self) -> usize {
318 self.0.n_bins()
319 }
320 /// The minimum and maximum values of the binning Variable used to create this BinnedDataset
321 ///
322 #[getter]
323 fn range(&self) -> (Float, Float) {
324 self.0.range()
325 }
326 /// The edges of each bin in the BinnedDataset
327 ///
328 #[getter]
329 fn edges<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<Float>> {
330 PyArray1::from_slice(py, &self.0.edges())
331 }
332 fn __getitem__(&self, index: usize) -> PyResult<PyDataset> {
333 self.0
334 .get(index)
335 .ok_or(PyIndexError::new_err("index out of range"))
336 .map(|rust_dataset| PyDataset(rust_dataset.clone()))
337 }
338}
339
340/// Open a Dataset from a file
341///
342/// Arguments
343/// ---------
344/// path : str or Path
345/// The path to the file
346/// rest_frame_indices : list of int, optional
347/// If supplied, the dataset will be boosted to the rest frame of the 4-momenta at the
348/// given indices
349///
350///
351/// Returns
352/// -------
353/// Dataset
354///
355/// Raises
356/// ------
357/// IOError
358/// If the file could not be read
359///
360/// Warnings
361/// --------
362/// This method will panic/fail if the columns do not have the correct names or data types.
363/// There is currently no way to make this nicer without a large performance dip (if you find a
364/// way, please open a PR).
365///
366/// Notes
367/// -----
368/// Data should be stored in Parquet format with each column being filled with 32-bit floats
369///
370/// Valid/required column names have the following formats:
371///
372/// ``p4_{particle index}_{E|Px|Py|Pz}`` (four-momentum components for each particle)
373///
374/// ``aux_{particle index}_{x|y|z}`` (auxiliary vectors for each particle)
375///
376/// ``weight`` (the weight of the Event)
377///
378/// For example, the four-momentum of the 0th particle in the event would be stored in columns
379/// with the names ``p4_0_E``, ``p4_0_Px``, ``p4_0_Py``, and ``p4_0_Pz``. That particle's
380/// polarization could be stored in the columns ``aux_0_x``, ``aux_0_y``, and ``aux_0_z``. This
381/// could continue for an arbitrary number of particles. The ``weight`` column is always
382/// required.
383///
384#[pyfunction(name = "open", signature = (path, *, rest_frame_indices=None))]
385pub fn py_open(path: Bound<PyAny>, rest_frame_indices: Option<Vec<usize>>) -> PyResult<PyDataset> {
386 let path_str = if let Ok(s) = path.extract::<String>() {
387 Ok(s)
388 } else if let Ok(pathbuf) = path.extract::<PathBuf>() {
389 Ok(pathbuf.to_string_lossy().into_owned())
390 } else {
391 Err(PyTypeError::new_err("Expected str or Path"))
392 }?;
393 if let Some(indices) = rest_frame_indices {
394 Ok(PyDataset(open_boosted_to_rest_frame_of(path_str, indices)?))
395 } else {
396 Ok(PyDataset(open(path_str)?))
397 }
398}