laddu_python/
data.rs

1use crate::utils::variables::{PyVariable, PyVariableExpression};
2use laddu_core::{
3    data::{open, open_boosted_to_rest_frame_of, BinnedDataset, Dataset, Event},
4    Float,
5};
6use numpy::PyArray1;
7use pyo3::{
8    exceptions::{PyIndexError, PyTypeError},
9    prelude::*,
10    IntoPyObjectExt,
11};
12use std::{path::PathBuf, sync::Arc};
13
14use crate::utils::vectors::{PyVec3, PyVec4};
15
16/// A single event
17///
18/// Events are composed of a set of 4-momenta of particles in the overall
19/// center-of-momentum frame, polarizations or helicities described by 3-vectors, and a
20/// weight
21///
22/// Parameters
23/// ----------
24/// p4s : list of Vec4
25///     4-momenta of each particle in the event in the overall center-of-momentum frame
26/// aux: list of Vec3
27///     3-vectors describing auxiliary data for each particle given in `p4s`
28/// weight : float
29///     The weight associated with this event
30/// rest_frame_indices : list of int, optional
31///     If supplied, the event will be boosted to the rest frame of the 4-momenta at the
32///     given indices
33///
34#[pyclass(name = "Event", module = "laddu")]
35#[derive(Clone)]
36pub struct PyEvent(pub Arc<Event>);
37
38#[pymethods]
39impl PyEvent {
40    #[new]
41    #[pyo3(signature = (p4s, aux, weight, *, rest_frame_indices=None))]
42    fn new(
43        p4s: Vec<PyVec4>,
44        aux: Vec<PyVec3>,
45        weight: Float,
46        rest_frame_indices: Option<Vec<usize>>,
47    ) -> Self {
48        let event = Event {
49            p4s: p4s.into_iter().map(|arr| arr.0).collect(),
50            aux: aux.into_iter().map(|arr| arr.0).collect(),
51            weight,
52        };
53        if let Some(indices) = rest_frame_indices {
54            Self(Arc::new(event.boost_to_rest_frame_of(indices)))
55        } else {
56            Self(Arc::new(event))
57        }
58    }
59    fn __str__(&self) -> String {
60        self.0.to_string()
61    }
62    /// The list of 4-momenta for each particle in the event
63    ///
64    #[getter]
65    fn get_p4s(&self) -> Vec<PyVec4> {
66        self.0.p4s.iter().map(|p4| PyVec4(*p4)).collect()
67    }
68    /// The list of 3-vectors describing the auxiliary data of particles in
69    /// the event
70    ///
71    #[getter]
72    fn get_aux(&self) -> Vec<PyVec3> {
73        self.0.aux.iter().map(|eps_vec| PyVec3(*eps_vec)).collect()
74    }
75    /// The weight of this event relative to others in a Dataset
76    ///
77    #[getter]
78    fn get_weight(&self) -> Float {
79        self.0.weight
80    }
81    /// Get the sum of the four-momenta within the event at the given indices
82    ///
83    /// Parameters
84    /// ----------
85    /// indices : list of int
86    ///     The indices of the four-momenta to sum
87    ///
88    /// Returns
89    /// -------
90    /// Vec4
91    ///     The result of summing the given four-momenta
92    ///
93    fn get_p4_sum(&self, indices: Vec<usize>) -> PyVec4 {
94        PyVec4(self.0.get_p4_sum(indices))
95    }
96    /// Boost all the four-momenta in the event to the rest frame of the given set of
97    /// four-momenta by indices.
98    ///
99    /// Parameters
100    /// ----------
101    /// indices : list of int
102    ///     The indices of the four-momenta to sum
103    ///
104    /// Returns
105    /// -------
106    /// Event
107    ///     The boosted event
108    ///
109    pub fn boost_to_rest_frame_of(&self, indices: Vec<usize>) -> Self {
110        PyEvent(Arc::new(self.0.boost_to_rest_frame_of(indices)))
111    }
112    /// Get the value of a Variable on the given Event
113    ///
114    /// Parameters
115    /// ----------
116    /// variable : {laddu.Mass, laddu.CosTheta, laddu.Phi, laddu.PolAngle, laddu.PolMagnitude, laddu.Mandelstam}
117    ///
118    /// Returns
119    /// -------
120    /// float
121    ///
122    fn evaluate(&self, variable: Bound<'_, PyAny>) -> PyResult<Float> {
123        Ok(self.0.evaluate(&variable.extract::<PyVariable>()?))
124    }
125}
126
127/// A set of Events
128///
129/// Datasets can be created from lists of Events or by using the provided ``laddu.open`` function
130///
131/// Datasets can also be indexed directly to access individual Events
132///
133/// Parameters
134/// ----------
135/// events : list of Event
136///
137/// See Also
138/// --------
139/// laddu.open
140///
141#[pyclass(name = "DatasetBase", module = "laddu", subclass)]
142#[derive(Clone)]
143pub struct PyDataset(pub Arc<Dataset>);
144
145#[pymethods]
146impl PyDataset {
147    #[new]
148    fn new(events: Vec<PyEvent>) -> Self {
149        Self(Arc::new(Dataset::new(
150            events.into_iter().map(|event| event.0).collect(),
151        )))
152    }
153    fn __len__(&self) -> usize {
154        self.0.n_events()
155    }
156    fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyDataset> {
157        if let Ok(other_ds) = other.extract::<PyRef<PyDataset>>() {
158            Ok(PyDataset(Arc::new(self.0.as_ref() + other_ds.0.as_ref())))
159        } else if let Ok(other_int) = other.extract::<usize>() {
160            if other_int == 0 {
161                Ok(self.clone())
162            } else {
163                Err(PyTypeError::new_err(
164                    "Addition with an integer for this type is only defined for 0",
165                ))
166            }
167        } else {
168            Err(PyTypeError::new_err("Unsupported operand type for +"))
169        }
170    }
171    fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyDataset> {
172        if let Ok(other_ds) = other.extract::<PyRef<PyDataset>>() {
173            Ok(PyDataset(Arc::new(other_ds.0.as_ref() + self.0.as_ref())))
174        } else if let Ok(other_int) = other.extract::<usize>() {
175            if other_int == 0 {
176                Ok(self.clone())
177            } else {
178                Err(PyTypeError::new_err(
179                    "Addition with an integer for this type is only defined for 0",
180                ))
181            }
182        } else {
183            Err(PyTypeError::new_err("Unsupported operand type for +"))
184        }
185    }
186    /// Get the number of Events in the Dataset
187    ///
188    /// Returns
189    /// -------
190    /// n_events : int
191    ///     The number of Events
192    ///
193    #[getter]
194    fn n_events(&self) -> usize {
195        self.0.n_events()
196    }
197    /// Get the weighted number of Events in the Dataset
198    ///
199    /// Returns
200    /// -------
201    /// n_events : float
202    ///     The sum of all Event weights
203    ///
204    #[getter]
205    fn n_events_weighted(&self) -> Float {
206        self.0.n_events_weighted()
207    }
208    /// The weights associated with the Dataset
209    ///
210    /// Returns
211    /// -------
212    /// weights : array_like
213    ///     A ``numpy`` array of Event weights
214    ///
215    #[getter]
216    fn weights<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<Float>> {
217        PyArray1::from_slice(py, &self.0.weights())
218    }
219    /// The internal list of Events stored in the Dataset
220    ///
221    /// Returns
222    /// -------
223    /// events : list of Event
224    ///     The Events in the Dataset
225    ///
226    #[getter]
227    fn events(&self) -> Vec<PyEvent> {
228        self.0
229            .events
230            .iter()
231            .map(|rust_event| PyEvent(rust_event.clone()))
232            .collect()
233    }
234    fn __getitem__<'py>(
235        &self,
236        py: Python<'py>,
237        index: Bound<'py, PyAny>,
238    ) -> PyResult<Bound<'py, PyAny>> {
239        if let Ok(value) = self.evaluate(py, index.clone()) {
240            value.into_bound_py_any(py)
241        } else if let Ok(index) = index.extract::<usize>() {
242            PyEvent(Arc::new(self.0[index].clone())).into_bound_py_any(py)
243        } else {
244            Err(PyTypeError::new_err(
245                "Unsupported index type (int or Variable)",
246            ))
247        }
248    }
249    /// Separates a Dataset into histogram bins by a Variable value
250    ///
251    /// Parameters
252    /// ----------
253    /// variable : {laddu.Mass, laddu.CosTheta, laddu.Phi, laddu.PolAngle, laddu.PolMagnitude, laddu.Mandelstam}
254    ///     The Variable by which each Event is binned
255    /// bins : int
256    ///     The number of equally-spaced bins
257    /// range : tuple[float, float]
258    ///     The minimum and maximum bin edges
259    ///
260    /// Returns
261    /// -------
262    /// datasets : BinnedDataset
263    ///     A pub structure that holds a list of Datasets binned by the given `variable`
264    ///
265    /// See Also
266    /// --------
267    /// laddu.Mass
268    /// laddu.CosTheta
269    /// laddu.Phi
270    /// laddu.PolAngle
271    /// laddu.PolMagnitude
272    /// laddu.Mandelstam
273    ///
274    /// Raises
275    /// ------
276    /// TypeError
277    ///     If the given `variable` is not a valid variable
278    ///
279    #[pyo3(signature = (variable, bins, range))]
280    fn bin_by(
281        &self,
282        variable: Bound<'_, PyAny>,
283        bins: usize,
284        range: (Float, Float),
285    ) -> PyResult<PyBinnedDataset> {
286        let py_variable = variable.extract::<PyVariable>()?;
287        Ok(PyBinnedDataset(self.0.bin_by(py_variable, bins, range)))
288    }
289    /// Filter the Dataset by a given VariableExpression, selecting events for which the expression returns ``True``.
290    ///
291    /// Parameters
292    /// ----------
293    /// expression : VariableExpression
294    ///     The expression with which to filter the Dataset
295    ///
296    /// Returns
297    /// -------
298    /// Dataset
299    ///     The filtered Dataset
300    ///
301    pub fn filter(&self, expression: &PyVariableExpression) -> PyDataset {
302        PyDataset(self.0.filter(&expression.0))
303    }
304    /// Generate a new bootstrapped Dataset by randomly resampling the original with replacement
305    ///
306    /// The new Dataset is resampled with a random generator seeded by the provided `seed`
307    ///
308    /// Parameters
309    /// ----------
310    /// seed : int
311    ///     The random seed used in the resampling process
312    ///
313    /// Returns
314    /// -------
315    /// Dataset
316    ///     A bootstrapped Dataset
317    ///
318    fn bootstrap(&self, seed: usize) -> PyDataset {
319        PyDataset(self.0.bootstrap(seed))
320    }
321    /// Boost all the four-momenta in all events to the rest frame of the given set of
322    /// four-momenta by indices.
323    ///
324    /// Parameters
325    /// ----------
326    /// indices : list of int
327    ///     The indices of the four-momenta to sum
328    ///
329    /// Returns
330    /// -------
331    /// Dataset
332    ///     The boosted dataset
333    ///
334    pub fn boost_to_rest_frame_of(&self, indices: Vec<usize>) -> PyDataset {
335        PyDataset(self.0.boost_to_rest_frame_of(indices))
336    }
337    /// Get the value of a Variable over every event in the Dataset.
338    ///
339    /// Parameters
340    /// ----------
341    /// variable : {laddu.Mass, laddu.CosTheta, laddu.Phi, laddu.PolAngle, laddu.PolMagnitude, laddu.Mandelstam}
342    ///
343    /// Returns
344    /// -------
345    /// values : array_like
346    ///
347    fn evaluate<'py>(
348        &self,
349        py: Python<'py>,
350        variable: Bound<'py, PyAny>,
351    ) -> PyResult<Bound<'py, PyArray1<Float>>> {
352        Ok(PyArray1::from_slice(
353            py,
354            &self.0.evaluate(&variable.extract::<PyVariable>()?),
355        ))
356    }
357}
358
359/// A collection of Datasets binned by a Variable
360///
361/// BinnedDatasets can be indexed directly to access the underlying Datasets by bin
362///
363/// See Also
364/// --------
365/// laddu.Dataset.bin_by
366///
367#[pyclass(name = "BinnedDataset", module = "laddu")]
368pub struct PyBinnedDataset(BinnedDataset);
369
370#[pymethods]
371impl PyBinnedDataset {
372    fn __len__(&self) -> usize {
373        self.0.n_bins()
374    }
375    /// The number of bins in the BinnedDataset
376    ///
377    #[getter]
378    fn n_bins(&self) -> usize {
379        self.0.n_bins()
380    }
381    /// The minimum and maximum values of the binning Variable used to create this BinnedDataset
382    ///
383    #[getter]
384    fn range(&self) -> (Float, Float) {
385        self.0.range()
386    }
387    /// The edges of each bin in the BinnedDataset
388    ///
389    #[getter]
390    fn edges<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<Float>> {
391        PyArray1::from_slice(py, &self.0.edges())
392    }
393    fn __getitem__(&self, index: usize) -> PyResult<PyDataset> {
394        self.0
395            .get(index)
396            .ok_or(PyIndexError::new_err("index out of range"))
397            .map(|rust_dataset| PyDataset(rust_dataset.clone()))
398    }
399}
400
401/// Open a Dataset from a file
402///
403/// Arguments
404/// ---------
405/// path : str or Path
406///     The path to the file
407/// rest_frame_indices : list of int, optional
408///     If supplied, the dataset will be boosted to the rest frame of the 4-momenta at the
409///     given indices
410///
411///
412/// Returns
413/// -------
414/// Dataset
415///
416/// Raises
417/// ------
418/// IOError
419///     If the file could not be read
420///
421/// Warnings
422/// --------
423/// This method will panic/fail if the columns do not have the correct names or data types.
424/// There is currently no way to make this nicer without a large performance dip (if you find a
425/// way, please open a PR).
426///
427/// Notes
428/// -----
429/// Data should be stored in Parquet format with each column being filled with 32-bit floats
430///
431/// Valid/required column names have the following formats:
432///
433/// ``p4_{particle index}_{E|Px|Py|Pz}`` (four-momentum components for each particle)
434///
435/// ``aux_{particle index}_{x|y|z}`` (auxiliary vectors for each particle)
436///
437/// ``weight`` (the weight of the Event)
438///
439/// For example, the four-momentum of the 0th particle in the event would be stored in columns
440/// with the names ``p4_0_E``, ``p4_0_Px``, ``p4_0_Py``, and ``p4_0_Pz``. That particle's
441/// polarization could be stored in the columns ``aux_0_x``, ``aux_0_y``, and ``aux_0_z``. This
442/// could continue for an arbitrary number of particles. The ``weight`` column is always
443/// required.
444///
445#[pyfunction(name = "open", signature = (path, *, rest_frame_indices=None))]
446pub fn py_open(path: Bound<PyAny>, rest_frame_indices: Option<Vec<usize>>) -> PyResult<PyDataset> {
447    let path_str = if let Ok(s) = path.extract::<String>() {
448        Ok(s)
449    } else if let Ok(pathbuf) = path.extract::<PathBuf>() {
450        Ok(pathbuf.to_string_lossy().into_owned())
451    } else {
452        Err(PyTypeError::new_err("Expected str or Path"))
453    }?;
454    if let Some(indices) = rest_frame_indices {
455        Ok(PyDataset(open_boosted_to_rest_frame_of(path_str, indices)?))
456    } else {
457        Ok(PyDataset(open(path_str)?))
458    }
459}