laddu_python/data.rs
1use crate::utils::variables::PyVariable;
2use laddu_core::{
3 data::{open, open_boosted_to_rest_frame_of, BinnedDataset, Dataset, Event},
4 Float,
5};
6use numpy::PyArray1;
7use pyo3::{
8 exceptions::{PyIndexError, PyTypeError},
9 prelude::*,
10 IntoPyObjectExt,
11};
12use std::{path::PathBuf, sync::Arc};
13
14use crate::utils::vectors::{PyVec3, PyVec4};
15
16/// A single event
17///
18/// Events are composed of a set of 4-momenta of particles in the overall
19/// center-of-momentum frame, polarizations or helicities described by 3-vectors, and a
20/// weight
21///
22/// Parameters
23/// ----------
24/// p4s : list of Vec4
25/// 4-momenta of each particle in the event in the overall center-of-momentum frame
26/// aux: list of Vec3
27/// 3-vectors describing auxiliary data for each particle given in `p4s`
28/// weight : float
29/// The weight associated with this event
30/// rest_frame_indices : list of int, optional
31/// If supplied, the event will be boosted to the rest frame of the 4-momenta at the
32/// given indices
33///
34#[pyclass(name = "Event", module = "laddu")]
35#[derive(Clone)]
36pub struct PyEvent(pub Arc<Event>);
37
38#[pymethods]
39impl PyEvent {
40 #[new]
41 #[pyo3(signature = (p4s, aux, weight, *, rest_frame_indices=None))]
42 fn new(
43 p4s: Vec<PyVec4>,
44 aux: Vec<PyVec3>,
45 weight: Float,
46 rest_frame_indices: Option<Vec<usize>>,
47 ) -> Self {
48 let event = Event {
49 p4s: p4s.into_iter().map(|arr| arr.0).collect(),
50 aux: aux.into_iter().map(|arr| arr.0).collect(),
51 weight,
52 };
53 if let Some(indices) = rest_frame_indices {
54 Self(Arc::new(event.boost_to_rest_frame_of(indices)))
55 } else {
56 Self(Arc::new(event))
57 }
58 }
59 fn __str__(&self) -> String {
60 self.0.to_string()
61 }
62 /// The list of 4-momenta for each particle in the event
63 ///
64 #[getter]
65 fn get_p4s(&self) -> Vec<PyVec4> {
66 self.0.p4s.iter().map(|p4| PyVec4(*p4)).collect()
67 }
68 /// The list of 3-vectors describing the auxiliary data of particles in
69 /// the event
70 ///
71 #[getter]
72 fn get_aux(&self) -> Vec<PyVec3> {
73 self.0.aux.iter().map(|eps_vec| PyVec3(*eps_vec)).collect()
74 }
75 /// The weight of this event relative to others in a Dataset
76 ///
77 #[getter]
78 fn get_weight(&self) -> Float {
79 self.0.weight
80 }
81 /// Get the sum of the four-momenta within the event at the given indices
82 ///
83 /// Parameters
84 /// ----------
85 /// indices : list of int
86 /// The indices of the four-momenta to sum
87 ///
88 /// Returns
89 /// -------
90 /// Vec4
91 /// The result of summing the given four-momenta
92 ///
93 fn get_p4_sum(&self, indices: Vec<usize>) -> PyVec4 {
94 PyVec4(self.0.get_p4_sum(indices))
95 }
96 /// Boost all the four-momenta in the event to the rest frame of the given set of
97 /// four-momenta by indices.
98 ///
99 /// Parameters
100 /// ----------
101 /// indices : list of int
102 /// The indices of the four-momenta to sum
103 ///
104 /// Returns
105 /// -------
106 /// Event
107 /// The boosted event
108 ///
109 pub fn boost_to_rest_frame_of(&self, indices: Vec<usize>) -> Self {
110 PyEvent(Arc::new(self.0.boost_to_rest_frame_of(indices)))
111 }
112 /// Get the value of a Variable on the given Event
113 ///
114 /// Parameters
115 /// ----------
116 /// variable : {laddu.Mass, laddu.CosTheta, laddu.Phi, laddu.PolAngle, laddu.PolMagnitude, laddu.Mandelstam}
117 ///
118 /// Returns
119 /// -------
120 /// float
121 ///
122 fn evaluate(&self, variable: Bound<'_, PyAny>) -> PyResult<Float> {
123 Ok(self.0.evaluate(&variable.extract::<PyVariable>()?))
124 }
125}
126
127/// A set of Events
128///
129/// Datasets can be created from lists of Events or by using the provided ``laddu.open`` function
130///
131/// Datasets can also be indexed directly to access individual Events
132///
133/// Parameters
134/// ----------
135/// events : list of Event
136///
137/// See Also
138/// --------
139/// laddu.open
140///
141#[pyclass(name = "DatasetBase", module = "laddu", subclass)]
142#[derive(Clone)]
143pub struct PyDataset(pub Arc<Dataset>);
144
145#[pymethods]
146impl PyDataset {
147 #[new]
148 fn new(events: Vec<PyEvent>) -> Self {
149 Self(Arc::new(Dataset::new(
150 events.into_iter().map(|event| event.0).collect(),
151 )))
152 }
153 fn __len__(&self) -> usize {
154 self.0.n_events()
155 }
156 fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyDataset> {
157 if let Ok(other_ds) = other.extract::<PyRef<PyDataset>>() {
158 Ok(PyDataset(Arc::new(self.0.as_ref() + other_ds.0.as_ref())))
159 } else if let Ok(other_int) = other.extract::<usize>() {
160 if other_int == 0 {
161 Ok(self.clone())
162 } else {
163 Err(PyTypeError::new_err(
164 "Addition with an integer for this type is only defined for 0",
165 ))
166 }
167 } else {
168 Err(PyTypeError::new_err("Unsupported operand type for +"))
169 }
170 }
171 fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyDataset> {
172 if let Ok(other_ds) = other.extract::<PyRef<PyDataset>>() {
173 Ok(PyDataset(Arc::new(other_ds.0.as_ref() + self.0.as_ref())))
174 } else if let Ok(other_int) = other.extract::<usize>() {
175 if other_int == 0 {
176 Ok(self.clone())
177 } else {
178 Err(PyTypeError::new_err(
179 "Addition with an integer for this type is only defined for 0",
180 ))
181 }
182 } else {
183 Err(PyTypeError::new_err("Unsupported operand type for +"))
184 }
185 }
186 /// Get the number of Events in the Dataset
187 ///
188 /// Returns
189 /// -------
190 /// n_events : int
191 /// The number of Events
192 ///
193 #[getter]
194 fn n_events(&self) -> usize {
195 self.0.n_events()
196 }
197 /// Get the weighted number of Events in the Dataset
198 ///
199 /// Returns
200 /// -------
201 /// n_events : float
202 /// The sum of all Event weights
203 ///
204 #[getter]
205 fn n_events_weighted(&self) -> Float {
206 self.0.n_events_weighted()
207 }
208 /// The weights associated with the Dataset
209 ///
210 /// Returns
211 /// -------
212 /// weights : array_like
213 /// A ``numpy`` array of Event weights
214 ///
215 #[getter]
216 fn weights<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<Float>> {
217 PyArray1::from_slice(py, &self.0.weights())
218 }
219 /// The internal list of Events stored in the Dataset
220 ///
221 /// Returns
222 /// -------
223 /// events : list of Event
224 /// The Events in the Dataset
225 ///
226 #[getter]
227 fn events(&self) -> Vec<PyEvent> {
228 self.0
229 .events
230 .iter()
231 .map(|rust_event| PyEvent(rust_event.clone()))
232 .collect()
233 }
234 fn __getitem__<'py>(
235 &self,
236 py: Python<'py>,
237 index: Bound<'py, PyAny>,
238 ) -> PyResult<Bound<'py, PyAny>> {
239 if let Ok(value) = self.evaluate(py, index.clone()) {
240 value.into_bound_py_any(py)
241 } else if let Ok(index) = index.extract::<usize>() {
242 PyEvent(Arc::new(self.0[index].clone())).into_bound_py_any(py)
243 } else {
244 Err(PyTypeError::new_err(
245 "Unsupported index type (int or Variable)",
246 ))
247 }
248 }
249 /// Separates a Dataset into histogram bins by a Variable value
250 ///
251 /// Parameters
252 /// ----------
253 /// variable : {laddu.Mass, laddu.CosTheta, laddu.Phi, laddu.PolAngle, laddu.PolMagnitude, laddu.Mandelstam}
254 /// The Variable by which each Event is binned
255 /// bins : int
256 /// The number of equally-spaced bins
257 /// range : tuple[float, float]
258 /// The minimum and maximum bin edges
259 ///
260 /// Returns
261 /// -------
262 /// datasets : BinnedDataset
263 /// A pub structure that holds a list of Datasets binned by the given `variable`
264 ///
265 /// See Also
266 /// --------
267 /// laddu.Mass
268 /// laddu.CosTheta
269 /// laddu.Phi
270 /// laddu.PolAngle
271 /// laddu.PolMagnitude
272 /// laddu.Mandelstam
273 ///
274 /// Raises
275 /// ------
276 /// TypeError
277 /// If the given `variable` is not a valid variable
278 ///
279 #[pyo3(signature = (variable, bins, range))]
280 fn bin_by(
281 &self,
282 variable: Bound<'_, PyAny>,
283 bins: usize,
284 range: (Float, Float),
285 ) -> PyResult<PyBinnedDataset> {
286 let py_variable = variable.extract::<PyVariable>()?;
287 Ok(PyBinnedDataset(self.0.bin_by(py_variable, bins, range)))
288 }
289 /// Generate a new bootstrapped Dataset by randomly resampling the original with replacement
290 ///
291 /// The new Dataset is resampled with a random generator seeded by the provided `seed`
292 ///
293 /// Parameters
294 /// ----------
295 /// seed : int
296 /// The random seed used in the resampling process
297 ///
298 /// Returns
299 /// -------
300 /// Dataset
301 /// A bootstrapped Dataset
302 ///
303 fn bootstrap(&self, seed: usize) -> PyDataset {
304 PyDataset(self.0.bootstrap(seed))
305 }
306 /// Boost all the four-momenta in all events to the rest frame of the given set of
307 /// four-momenta by indices.
308 ///
309 /// Parameters
310 /// ----------
311 /// indices : list of int
312 /// The indices of the four-momenta to sum
313 ///
314 /// Returns
315 /// -------
316 /// Dataset
317 /// The boosted dataset
318 ///
319 pub fn boost_to_rest_frame_of(&self, indices: Vec<usize>) -> PyDataset {
320 PyDataset(self.0.boost_to_rest_frame_of(indices))
321 }
322 /// Get the value of a Variable over every event in the Dataset.
323 ///
324 /// Parameters
325 /// ----------
326 /// variable : {laddu.Mass, laddu.CosTheta, laddu.Phi, laddu.PolAngle, laddu.PolMagnitude, laddu.Mandelstam}
327 ///
328 /// Returns
329 /// -------
330 /// values : array_like
331 ///
332 fn evaluate<'py>(
333 &self,
334 py: Python<'py>,
335 variable: Bound<'py, PyAny>,
336 ) -> PyResult<Bound<'py, PyArray1<Float>>> {
337 Ok(PyArray1::from_slice(
338 py,
339 &self.0.evaluate(&variable.extract::<PyVariable>()?),
340 ))
341 }
342}
343
344/// A collection of Datasets binned by a Variable
345///
346/// BinnedDatasets can be indexed directly to access the underlying Datasets by bin
347///
348/// See Also
349/// --------
350/// laddu.Dataset.bin_by
351///
352#[pyclass(name = "BinnedDataset", module = "laddu")]
353pub struct PyBinnedDataset(BinnedDataset);
354
355#[pymethods]
356impl PyBinnedDataset {
357 fn __len__(&self) -> usize {
358 self.0.n_bins()
359 }
360 /// The number of bins in the BinnedDataset
361 ///
362 #[getter]
363 fn n_bins(&self) -> usize {
364 self.0.n_bins()
365 }
366 /// The minimum and maximum values of the binning Variable used to create this BinnedDataset
367 ///
368 #[getter]
369 fn range(&self) -> (Float, Float) {
370 self.0.range()
371 }
372 /// The edges of each bin in the BinnedDataset
373 ///
374 #[getter]
375 fn edges<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<Float>> {
376 PyArray1::from_slice(py, &self.0.edges())
377 }
378 fn __getitem__(&self, index: usize) -> PyResult<PyDataset> {
379 self.0
380 .get(index)
381 .ok_or(PyIndexError::new_err("index out of range"))
382 .map(|rust_dataset| PyDataset(rust_dataset.clone()))
383 }
384}
385
386/// Open a Dataset from a file
387///
388/// Arguments
389/// ---------
390/// path : str or Path
391/// The path to the file
392/// rest_frame_indices : list of int, optional
393/// If supplied, the dataset will be boosted to the rest frame of the 4-momenta at the
394/// given indices
395///
396///
397/// Returns
398/// -------
399/// Dataset
400///
401/// Raises
402/// ------
403/// IOError
404/// If the file could not be read
405///
406/// Warnings
407/// --------
408/// This method will panic/fail if the columns do not have the correct names or data types.
409/// There is currently no way to make this nicer without a large performance dip (if you find a
410/// way, please open a PR).
411///
412/// Notes
413/// -----
414/// Data should be stored in Parquet format with each column being filled with 32-bit floats
415///
416/// Valid/required column names have the following formats:
417///
418/// ``p4_{particle index}_{E|Px|Py|Pz}`` (four-momentum components for each particle)
419///
420/// ``aux_{particle index}_{x|y|z}`` (auxiliary vectors for each particle)
421///
422/// ``weight`` (the weight of the Event)
423///
424/// For example, the four-momentum of the 0th particle in the event would be stored in columns
425/// with the names ``p4_0_E``, ``p4_0_Px``, ``p4_0_Py``, and ``p4_0_Pz``. That particle's
426/// polarization could be stored in the columns ``aux_0_x``, ``aux_0_y``, and ``aux_0_z``. This
427/// could continue for an arbitrary number of particles. The ``weight`` column is always
428/// required.
429///
430#[pyfunction(name = "open", signature = (path, *, rest_frame_indices=None))]
431pub fn py_open(path: Bound<PyAny>, rest_frame_indices: Option<Vec<usize>>) -> PyResult<PyDataset> {
432 let path_str = if let Ok(s) = path.extract::<String>() {
433 Ok(s)
434 } else if let Ok(pathbuf) = path.extract::<PathBuf>() {
435 Ok(pathbuf.to_string_lossy().into_owned())
436 } else {
437 Err(PyTypeError::new_err("Expected str or Path"))
438 }?;
439 if let Some(indices) = rest_frame_indices {
440 Ok(PyDataset(open_boosted_to_rest_frame_of(path_str, indices)?))
441 } else {
442 Ok(PyDataset(open(path_str)?))
443 }
444}