laddu_python/
data.rs

1use crate::utils::variables::PyVariable;
2use laddu_core::{
3    data::{open, open_boosted_to_rest_frame_of, BinnedDataset, Dataset, Event},
4    Float,
5};
6use numpy::PyArray1;
7use pyo3::{
8    exceptions::{PyIndexError, PyTypeError},
9    prelude::*,
10    IntoPyObjectExt,
11};
12use std::{path::PathBuf, sync::Arc};
13
14use crate::utils::vectors::{PyVec3, PyVec4};
15
16/// A single event
17///
18/// Events are composed of a set of 4-momenta of particles in the overall
19/// center-of-momentum frame, polarizations or helicities described by 3-vectors, and a
20/// weight
21///
22/// Parameters
23/// ----------
24/// p4s : list of Vec4
25///     4-momenta of each particle in the event in the overall center-of-momentum frame
26/// aux: list of Vec3
27///     3-vectors describing auxiliary data for each particle given in `p4s`
28/// weight : float
29///     The weight associated with this event
30/// rest_frame_indices : list of int, optional
31///     If supplied, the event will be boosted to the rest frame of the 4-momenta at the
32///     given indices
33///
34#[pyclass(name = "Event", module = "laddu")]
35#[derive(Clone)]
36pub struct PyEvent(pub Arc<Event>);
37
38#[pymethods]
39impl PyEvent {
40    #[new]
41    #[pyo3(signature = (p4s, aux, weight, *, rest_frame_indices=None))]
42    fn new(
43        p4s: Vec<PyVec4>,
44        aux: Vec<PyVec3>,
45        weight: Float,
46        rest_frame_indices: Option<Vec<usize>>,
47    ) -> Self {
48        let event = Event {
49            p4s: p4s.into_iter().map(|arr| arr.0).collect(),
50            aux: aux.into_iter().map(|arr| arr.0).collect(),
51            weight,
52        };
53        if let Some(indices) = rest_frame_indices {
54            Self(Arc::new(event.boost_to_rest_frame_of(indices)))
55        } else {
56            Self(Arc::new(event))
57        }
58    }
59    fn __str__(&self) -> String {
60        self.0.to_string()
61    }
62    /// The list of 4-momenta for each particle in the event
63    ///
64    #[getter]
65    fn get_p4s(&self) -> Vec<PyVec4> {
66        self.0.p4s.iter().map(|p4| PyVec4(*p4)).collect()
67    }
68    /// The list of 3-vectors describing the auxiliary data of particles in
69    /// the event
70    ///
71    #[getter]
72    fn get_aux(&self) -> Vec<PyVec3> {
73        self.0.aux.iter().map(|eps_vec| PyVec3(*eps_vec)).collect()
74    }
75    /// The weight of this event relative to others in a Dataset
76    ///
77    #[getter]
78    fn get_weight(&self) -> Float {
79        self.0.weight
80    }
81    /// Get the sum of the four-momenta within the event at the given indices
82    ///
83    /// Parameters
84    /// ----------
85    /// indices : list of int
86    ///     The indices of the four-momenta to sum
87    ///
88    /// Returns
89    /// -------
90    /// Vec4
91    ///     The result of summing the given four-momenta
92    ///
93    fn get_p4_sum(&self, indices: Vec<usize>) -> PyVec4 {
94        PyVec4(self.0.get_p4_sum(indices))
95    }
96    /// Boost all the four-momenta in the event to the rest frame of the given set of
97    /// four-momenta by indices.
98    ///
99    /// Parameters
100    /// ----------
101    /// indices : list of int
102    ///     The indices of the four-momenta to sum
103    ///
104    /// Returns
105    /// -------
106    /// Event
107    ///     The boosted event
108    ///
109    pub fn boost_to_rest_frame_of(&self, indices: Vec<usize>) -> Self {
110        PyEvent(Arc::new(self.0.boost_to_rest_frame_of(indices)))
111    }
112    /// Get the value of a Variable on the given Event
113    ///
114    /// Parameters
115    /// ----------
116    /// variable : {laddu.Mass, laddu.CosTheta, laddu.Phi, laddu.PolAngle, laddu.PolMagnitude, laddu.Mandelstam}
117    ///
118    /// Returns
119    /// -------
120    /// float
121    ///
122    fn evaluate(&self, variable: Bound<'_, PyAny>) -> PyResult<Float> {
123        Ok(self.0.evaluate(&variable.extract::<PyVariable>()?))
124    }
125}
126
127/// A set of Events
128///
129/// Datasets can be created from lists of Events or by using the provided ``laddu.open`` function
130///
131/// Datasets can also be indexed directly to access individual Events
132///
133/// Parameters
134/// ----------
135/// events : list of Event
136///
137/// See Also
138/// --------
139/// laddu.open
140///
141#[pyclass(name = "DatasetBase", module = "laddu", subclass)]
142#[derive(Clone)]
143pub struct PyDataset(pub Arc<Dataset>);
144
145#[pymethods]
146impl PyDataset {
147    #[new]
148    fn new(events: Vec<PyEvent>) -> Self {
149        Self(Arc::new(Dataset::new(
150            events.into_iter().map(|event| event.0).collect(),
151        )))
152    }
153    fn __len__(&self) -> usize {
154        self.0.n_events()
155    }
156    fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyDataset> {
157        if let Ok(other_ds) = other.extract::<PyRef<PyDataset>>() {
158            Ok(PyDataset(Arc::new(self.0.as_ref() + other_ds.0.as_ref())))
159        } else if let Ok(other_int) = other.extract::<usize>() {
160            if other_int == 0 {
161                Ok(self.clone())
162            } else {
163                Err(PyTypeError::new_err(
164                    "Addition with an integer for this type is only defined for 0",
165                ))
166            }
167        } else {
168            Err(PyTypeError::new_err("Unsupported operand type for +"))
169        }
170    }
171    fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyDataset> {
172        if let Ok(other_ds) = other.extract::<PyRef<PyDataset>>() {
173            Ok(PyDataset(Arc::new(other_ds.0.as_ref() + self.0.as_ref())))
174        } else if let Ok(other_int) = other.extract::<usize>() {
175            if other_int == 0 {
176                Ok(self.clone())
177            } else {
178                Err(PyTypeError::new_err(
179                    "Addition with an integer for this type is only defined for 0",
180                ))
181            }
182        } else {
183            Err(PyTypeError::new_err("Unsupported operand type for +"))
184        }
185    }
186    /// Get the number of Events in the Dataset
187    ///
188    /// Returns
189    /// -------
190    /// n_events : int
191    ///     The number of Events
192    ///
193    #[getter]
194    fn n_events(&self) -> usize {
195        self.0.n_events()
196    }
197    /// Get the weighted number of Events in the Dataset
198    ///
199    /// Returns
200    /// -------
201    /// n_events : float
202    ///     The sum of all Event weights
203    ///
204    #[getter]
205    fn n_events_weighted(&self) -> Float {
206        self.0.n_events_weighted()
207    }
208    /// The weights associated with the Dataset
209    ///
210    /// Returns
211    /// -------
212    /// weights : array_like
213    ///     A ``numpy`` array of Event weights
214    ///
215    #[getter]
216    fn weights<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<Float>> {
217        PyArray1::from_slice(py, &self.0.weights())
218    }
219    /// The internal list of Events stored in the Dataset
220    ///
221    /// Returns
222    /// -------
223    /// events : list of Event
224    ///     The Events in the Dataset
225    ///
226    #[getter]
227    fn events(&self) -> Vec<PyEvent> {
228        self.0
229            .events
230            .iter()
231            .map(|rust_event| PyEvent(rust_event.clone()))
232            .collect()
233    }
234    fn __getitem__<'py>(
235        &self,
236        py: Python<'py>,
237        index: Bound<'py, PyAny>,
238    ) -> PyResult<Bound<'py, PyAny>> {
239        if let Ok(value) = self.evaluate(py, index.clone()) {
240            value.into_bound_py_any(py)
241        } else if let Ok(index) = index.extract::<usize>() {
242            PyEvent(Arc::new(self.0[index].clone())).into_bound_py_any(py)
243        } else {
244            Err(PyTypeError::new_err(
245                "Unsupported index type (int or Variable)",
246            ))
247        }
248    }
249    /// Separates a Dataset into histogram bins by a Variable value
250    ///
251    /// Parameters
252    /// ----------
253    /// variable : {laddu.Mass, laddu.CosTheta, laddu.Phi, laddu.PolAngle, laddu.PolMagnitude, laddu.Mandelstam}
254    ///     The Variable by which each Event is binned
255    /// bins : int
256    ///     The number of equally-spaced bins
257    /// range : tuple[float, float]
258    ///     The minimum and maximum bin edges
259    ///
260    /// Returns
261    /// -------
262    /// datasets : BinnedDataset
263    ///     A pub structure that holds a list of Datasets binned by the given `variable`
264    ///
265    /// See Also
266    /// --------
267    /// laddu.Mass
268    /// laddu.CosTheta
269    /// laddu.Phi
270    /// laddu.PolAngle
271    /// laddu.PolMagnitude
272    /// laddu.Mandelstam
273    ///
274    /// Raises
275    /// ------
276    /// TypeError
277    ///     If the given `variable` is not a valid variable
278    ///
279    #[pyo3(signature = (variable, bins, range))]
280    fn bin_by(
281        &self,
282        variable: Bound<'_, PyAny>,
283        bins: usize,
284        range: (Float, Float),
285    ) -> PyResult<PyBinnedDataset> {
286        let py_variable = variable.extract::<PyVariable>()?;
287        Ok(PyBinnedDataset(self.0.bin_by(py_variable, bins, range)))
288    }
289    /// Generate a new bootstrapped Dataset by randomly resampling the original with replacement
290    ///
291    /// The new Dataset is resampled with a random generator seeded by the provided `seed`
292    ///
293    /// Parameters
294    /// ----------
295    /// seed : int
296    ///     The random seed used in the resampling process
297    ///
298    /// Returns
299    /// -------
300    /// Dataset
301    ///     A bootstrapped Dataset
302    ///
303    fn bootstrap(&self, seed: usize) -> PyDataset {
304        PyDataset(self.0.bootstrap(seed))
305    }
306    /// Boost all the four-momenta in all events to the rest frame of the given set of
307    /// four-momenta by indices.
308    ///
309    /// Parameters
310    /// ----------
311    /// indices : list of int
312    ///     The indices of the four-momenta to sum
313    ///
314    /// Returns
315    /// -------
316    /// Dataset
317    ///     The boosted dataset
318    ///
319    pub fn boost_to_rest_frame_of(&self, indices: Vec<usize>) -> PyDataset {
320        PyDataset(self.0.boost_to_rest_frame_of(indices))
321    }
322    /// Get the value of a Variable over every event in the Dataset.
323    ///
324    /// Parameters
325    /// ----------
326    /// variable : {laddu.Mass, laddu.CosTheta, laddu.Phi, laddu.PolAngle, laddu.PolMagnitude, laddu.Mandelstam}
327    ///
328    /// Returns
329    /// -------
330    /// values : array_like
331    ///
332    fn evaluate<'py>(
333        &self,
334        py: Python<'py>,
335        variable: Bound<'py, PyAny>,
336    ) -> PyResult<Bound<'py, PyArray1<Float>>> {
337        Ok(PyArray1::from_slice(
338            py,
339            &self.0.evaluate(&variable.extract::<PyVariable>()?),
340        ))
341    }
342}
343
344/// A collection of Datasets binned by a Variable
345///
346/// BinnedDatasets can be indexed directly to access the underlying Datasets by bin
347///
348/// See Also
349/// --------
350/// laddu.Dataset.bin_by
351///
352#[pyclass(name = "BinnedDataset", module = "laddu")]
353pub struct PyBinnedDataset(BinnedDataset);
354
355#[pymethods]
356impl PyBinnedDataset {
357    fn __len__(&self) -> usize {
358        self.0.n_bins()
359    }
360    /// The number of bins in the BinnedDataset
361    ///
362    #[getter]
363    fn n_bins(&self) -> usize {
364        self.0.n_bins()
365    }
366    /// The minimum and maximum values of the binning Variable used to create this BinnedDataset
367    ///
368    #[getter]
369    fn range(&self) -> (Float, Float) {
370        self.0.range()
371    }
372    /// The edges of each bin in the BinnedDataset
373    ///
374    #[getter]
375    fn edges<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<Float>> {
376        PyArray1::from_slice(py, &self.0.edges())
377    }
378    fn __getitem__(&self, index: usize) -> PyResult<PyDataset> {
379        self.0
380            .get(index)
381            .ok_or(PyIndexError::new_err("index out of range"))
382            .map(|rust_dataset| PyDataset(rust_dataset.clone()))
383    }
384}
385
386/// Open a Dataset from a file
387///
388/// Arguments
389/// ---------
390/// path : str or Path
391///     The path to the file
392/// rest_frame_indices : list of int, optional
393///     If supplied, the dataset will be boosted to the rest frame of the 4-momenta at the
394///     given indices
395///
396///
397/// Returns
398/// -------
399/// Dataset
400///
401/// Raises
402/// ------
403/// IOError
404///     If the file could not be read
405///
406/// Warnings
407/// --------
408/// This method will panic/fail if the columns do not have the correct names or data types.
409/// There is currently no way to make this nicer without a large performance dip (if you find a
410/// way, please open a PR).
411///
412/// Notes
413/// -----
414/// Data should be stored in Parquet format with each column being filled with 32-bit floats
415///
416/// Valid/required column names have the following formats:
417///
418/// ``p4_{particle index}_{E|Px|Py|Pz}`` (four-momentum components for each particle)
419///
420/// ``aux_{particle index}_{x|y|z}`` (auxiliary vectors for each particle)
421///
422/// ``weight`` (the weight of the Event)
423///
424/// For example, the four-momentum of the 0th particle in the event would be stored in columns
425/// with the names ``p4_0_E``, ``p4_0_Px``, ``p4_0_Py``, and ``p4_0_Pz``. That particle's
426/// polarization could be stored in the columns ``aux_0_x``, ``aux_0_y``, and ``aux_0_z``. This
427/// could continue for an arbitrary number of particles. The ``weight`` column is always
428/// required.
429///
430#[pyfunction(name = "open", signature = (path, *, rest_frame_indices=None))]
431pub fn py_open(path: Bound<PyAny>, rest_frame_indices: Option<Vec<usize>>) -> PyResult<PyDataset> {
432    let path_str = if let Ok(s) = path.extract::<String>() {
433        Ok(s)
434    } else if let Ok(pathbuf) = path.extract::<PathBuf>() {
435        Ok(pathbuf.to_string_lossy().into_owned())
436    } else {
437        Err(PyTypeError::new_err("Expected str or Path"))
438    }?;
439    if let Some(indices) = rest_frame_indices {
440        Ok(PyDataset(open_boosted_to_rest_frame_of(path_str, indices)?))
441    } else {
442        Ok(PyDataset(open(path_str)?))
443    }
444}