//! laddu_python/data.rs — Python bindings for laddu events and datasets.

1use crate::utils::variables::PyVariable;
2use laddu_core::{
3    data::{open, open_boosted_to_rest_frame_of, BinnedDataset, Dataset, Event},
4    Float,
5};
6use numpy::PyArray1;
7use pyo3::{
8    exceptions::{PyIndexError, PyTypeError},
9    prelude::*,
10};
11use std::{path::PathBuf, sync::Arc};
12
13use crate::utils::vectors::{PyVec3, PyVec4};
14
/// A single event
///
/// Events are composed of a set of 4-momenta of particles in the overall
/// center-of-momentum frame, polarizations or helicities described by 3-vectors, and a
/// weight
///
/// Parameters
/// ----------
/// p4s : list of Vec4
///     4-momenta of each particle in the event in the overall center-of-momentum frame
/// aux : list of Vec3
///     3-vectors describing auxiliary data for each particle given in `p4s`
/// weight : float
///     The weight associated with this event
/// rest_frame_indices : list of int, optional
///     If supplied, the event will be boosted to the rest frame of the 4-momenta at the
///     given indices
///
// Newtype wrapper exposing `laddu_core::data::Event` to Python. The `Arc` makes
// `Clone` a cheap reference-count bump, so Python handles can share the same
// underlying event with a `Dataset` without copying it.
#[pyclass(name = "Event", module = "laddu")]
#[derive(Clone)]
pub struct PyEvent(pub Arc<Event>);
36
37#[pymethods]
38impl PyEvent {
39    #[new]
40    #[pyo3(signature = (p4s, aux, weight, *, rest_frame_indices=None))]
41    fn new(
42        p4s: Vec<PyVec4>,
43        aux: Vec<PyVec3>,
44        weight: Float,
45        rest_frame_indices: Option<Vec<usize>>,
46    ) -> Self {
47        let event = Event {
48            p4s: p4s.into_iter().map(|arr| arr.0).collect(),
49            aux: aux.into_iter().map(|arr| arr.0).collect(),
50            weight,
51        };
52        if let Some(indices) = rest_frame_indices {
53            Self(Arc::new(event.boost_to_rest_frame_of(indices)))
54        } else {
55            Self(Arc::new(event))
56        }
57    }
58    fn __str__(&self) -> String {
59        self.0.to_string()
60    }
61    /// The list of 4-momenta for each particle in the event
62    ///
63    #[getter]
64    fn get_p4s(&self) -> Vec<PyVec4> {
65        self.0.p4s.iter().map(|p4| PyVec4(*p4)).collect()
66    }
67    /// The list of 3-vectors describing the auxiliary data of particles in
68    /// the event
69    ///
70    #[getter]
71    fn get_aux(&self) -> Vec<PyVec3> {
72        self.0.aux.iter().map(|eps_vec| PyVec3(*eps_vec)).collect()
73    }
74    /// The weight of this event relative to others in a Dataset
75    ///
76    #[getter]
77    fn get_weight(&self) -> Float {
78        self.0.weight
79    }
80    /// Get the sum of the four-momenta within the event at the given indices
81    ///
82    /// Parameters
83    /// ----------
84    /// indices : list of int
85    ///     The indices of the four-momenta to sum
86    ///
87    /// Returns
88    /// -------
89    /// Vec4
90    ///     The result of summing the given four-momenta
91    ///
92    fn get_p4_sum(&self, indices: Vec<usize>) -> PyVec4 {
93        PyVec4(self.0.get_p4_sum(indices))
94    }
95    /// Boost all the four-momenta in the event to the rest frame of the given set of
96    /// four-momenta by indices.
97    ///
98    /// Parameters
99    /// ----------
100    /// indices : list of int
101    ///     The indices of the four-momenta to sum
102    ///
103    /// Returns
104    /// -------
105    /// Event
106    ///     The boosted event
107    ///
108    pub fn boost_to_rest_frame_of(&self, indices: Vec<usize>) -> Self {
109        PyEvent(Arc::new(self.0.boost_to_rest_frame_of(indices)))
110    }
111}
112
/// A set of Events
///
/// Datasets can be created from lists of Events or by using the provided ``laddu.open`` function
///
/// Datasets can also be indexed directly to access individual Events
///
/// Parameters
/// ----------
/// events : list of Event
///
/// See Also
/// --------
/// laddu.open
///
// Newtype wrapper exposing `laddu_core::data::Dataset` to Python. Declared
// `subclass` so Python-side code can derive from it; `Arc` keeps clones cheap.
#[pyclass(name = "DatasetBase", module = "laddu", subclass)]
#[derive(Clone)]
pub struct PyDataset(pub Arc<Dataset>);
130
131#[pymethods]
132impl PyDataset {
133    #[new]
134    fn new(events: Vec<PyEvent>) -> Self {
135        Self(Arc::new(Dataset::new(
136            events.into_iter().map(|event| event.0).collect(),
137        )))
138    }
139    fn __len__(&self) -> usize {
140        self.0.n_events()
141    }
142    fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyDataset> {
143        if let Ok(other_ds) = other.extract::<PyRef<PyDataset>>() {
144            Ok(PyDataset(Arc::new(self.0.as_ref() + other_ds.0.as_ref())))
145        } else if let Ok(other_int) = other.extract::<usize>() {
146            if other_int == 0 {
147                Ok(self.clone())
148            } else {
149                Err(PyTypeError::new_err(
150                    "Addition with an integer for this type is only defined for 0",
151                ))
152            }
153        } else {
154            Err(PyTypeError::new_err("Unsupported operand type for +"))
155        }
156    }
157    fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyDataset> {
158        if let Ok(other_ds) = other.extract::<PyRef<PyDataset>>() {
159            Ok(PyDataset(Arc::new(other_ds.0.as_ref() + self.0.as_ref())))
160        } else if let Ok(other_int) = other.extract::<usize>() {
161            if other_int == 0 {
162                Ok(self.clone())
163            } else {
164                Err(PyTypeError::new_err(
165                    "Addition with an integer for this type is only defined for 0",
166                ))
167            }
168        } else {
169            Err(PyTypeError::new_err("Unsupported operand type for +"))
170        }
171    }
172    /// Get the number of Events in the Dataset
173    ///
174    /// Returns
175    /// -------
176    /// n_events : int
177    ///     The number of Events
178    ///
179    #[getter]
180    fn n_events(&self) -> usize {
181        self.0.n_events()
182    }
183    /// Get the weighted number of Events in the Dataset
184    ///
185    /// Returns
186    /// -------
187    /// n_events : float
188    ///     The sum of all Event weights
189    ///
190    #[getter]
191    fn n_events_weighted(&self) -> Float {
192        self.0.n_events_weighted()
193    }
194    /// The weights associated with the Dataset
195    ///
196    /// Returns
197    /// -------
198    /// weights : array_like
199    ///     A ``numpy`` array of Event weights
200    ///
201    #[getter]
202    fn weights<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<Float>> {
203        PyArray1::from_slice(py, &self.0.weights())
204    }
205    /// The internal list of Events stored in the Dataset
206    ///
207    /// Returns
208    /// -------
209    /// events : list of Event
210    ///     The Events in the Dataset
211    ///
212    #[getter]
213    fn events(&self) -> Vec<PyEvent> {
214        self.0
215            .events
216            .iter()
217            .map(|rust_event| PyEvent(rust_event.clone()))
218            .collect()
219    }
220    fn __getitem__(&self, index: usize) -> PyEvent {
221        PyEvent(Arc::new(self.0[index].clone()))
222    }
223    /// Separates a Dataset into histogram bins by a Variable value
224    ///
225    /// Parameters
226    /// ----------
227    /// variable : {laddu.Mass, laddu.CosTheta, laddu.Phi, laddu.PolAngle, laddu.PolMagnitude, laddu.Mandelstam}
228    ///     The Variable by which each Event is binned
229    /// bins : int
230    ///     The number of equally-spaced bins
231    /// range : tuple[float, float]
232    ///     The minimum and maximum bin edges
233    ///
234    /// Returns
235    /// -------
236    /// datasets : BinnedDataset
237    ///     A pub structure that holds a list of Datasets binned by the given `variable`
238    ///
239    /// See Also
240    /// --------
241    /// laddu.Mass
242    /// laddu.CosTheta
243    /// laddu.Phi
244    /// laddu.PolAngle
245    /// laddu.PolMagnitude
246    /// laddu.Mandelstam
247    ///
248    /// Raises
249    /// ------
250    /// TypeError
251    ///     If the given `variable` is not a valid variable
252    ///
253    #[pyo3(signature = (variable, bins, range))]
254    fn bin_by(
255        &self,
256        variable: Bound<'_, PyAny>,
257        bins: usize,
258        range: (Float, Float),
259    ) -> PyResult<PyBinnedDataset> {
260        let py_variable = variable.extract::<PyVariable>()?;
261        Ok(PyBinnedDataset(self.0.bin_by(py_variable, bins, range)))
262    }
263    /// Generate a new bootstrapped Dataset by randomly resampling the original with replacement
264    ///
265    /// The new Dataset is resampled with a random generator seeded by the provided `seed`
266    ///
267    /// Parameters
268    /// ----------
269    /// seed : int
270    ///     The random seed used in the resampling process
271    ///
272    /// Returns
273    /// -------
274    /// Dataset
275    ///     A bootstrapped Dataset
276    ///
277    fn bootstrap(&self, seed: usize) -> PyDataset {
278        PyDataset(self.0.bootstrap(seed))
279    }
280    /// Boost all the four-momenta in all events to the rest frame of the given set of
281    /// four-momenta by indices.
282    ///
283    /// Parameters
284    /// ----------
285    /// indices : list of int
286    ///     The indices of the four-momenta to sum
287    ///
288    /// Returns
289    /// -------
290    /// Dataset
291    ///     The boosted dataset
292    ///
293    pub fn boost_to_rest_frame_of(&self, indices: Vec<usize>) -> PyDataset {
294        PyDataset(self.0.boost_to_rest_frame_of(indices))
295    }
296}
297
/// A collection of Datasets binned by a Variable
///
/// BinnedDatasets can be indexed directly to access the underlying Datasets by bin
///
/// See Also
/// --------
/// laddu.Dataset.bin_by
///
// Newtype wrapper over `laddu_core::data::BinnedDataset`; only constructed from
// Rust (via `PyDataset::bin_by`), so no `#[new]` constructor is exposed.
#[pyclass(name = "BinnedDataset", module = "laddu")]
pub struct PyBinnedDataset(BinnedDataset);
308
309#[pymethods]
310impl PyBinnedDataset {
311    fn __len__(&self) -> usize {
312        self.0.n_bins()
313    }
314    /// The number of bins in the BinnedDataset
315    ///
316    #[getter]
317    fn n_bins(&self) -> usize {
318        self.0.n_bins()
319    }
320    /// The minimum and maximum values of the binning Variable used to create this BinnedDataset
321    ///
322    #[getter]
323    fn range(&self) -> (Float, Float) {
324        self.0.range()
325    }
326    /// The edges of each bin in the BinnedDataset
327    ///
328    #[getter]
329    fn edges<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<Float>> {
330        PyArray1::from_slice(py, &self.0.edges())
331    }
332    fn __getitem__(&self, index: usize) -> PyResult<PyDataset> {
333        self.0
334            .get(index)
335            .ok_or(PyIndexError::new_err("index out of range"))
336            .map(|rust_dataset| PyDataset(rust_dataset.clone()))
337    }
338}
339
340/// Open a Dataset from a file
341///
342/// Arguments
343/// ---------
344/// path : str or Path
345///     The path to the file
346/// rest_frame_indices : list of int, optional
347///     If supplied, the dataset will be boosted to the rest frame of the 4-momenta at the
348///     given indices
349///
350///
351/// Returns
352/// -------
353/// Dataset
354///
355/// Raises
356/// ------
357/// IOError
358///     If the file could not be read
359///
360/// Warnings
361/// --------
362/// This method will panic/fail if the columns do not have the correct names or data types.
363/// There is currently no way to make this nicer without a large performance dip (if you find a
364/// way, please open a PR).
365///
366/// Notes
367/// -----
368/// Data should be stored in Parquet format with each column being filled with 32-bit floats
369///
370/// Valid/required column names have the following formats:
371///
372/// ``p4_{particle index}_{E|Px|Py|Pz}`` (four-momentum components for each particle)
373///
374/// ``aux_{particle index}_{x|y|z}`` (auxiliary vectors for each particle)
375///
376/// ``weight`` (the weight of the Event)
377///
378/// For example, the four-momentum of the 0th particle in the event would be stored in columns
379/// with the names ``p4_0_E``, ``p4_0_Px``, ``p4_0_Py``, and ``p4_0_Pz``. That particle's
380/// polarization could be stored in the columns ``aux_0_x``, ``aux_0_y``, and ``aux_0_z``. This
381/// could continue for an arbitrary number of particles. The ``weight`` column is always
382/// required.
383///
384#[pyfunction(name = "open", signature = (path, *, rest_frame_indices=None))]
385pub fn py_open(path: Bound<PyAny>, rest_frame_indices: Option<Vec<usize>>) -> PyResult<PyDataset> {
386    let path_str = if let Ok(s) = path.extract::<String>() {
387        Ok(s)
388    } else if let Ok(pathbuf) = path.extract::<PathBuf>() {
389        Ok(pathbuf.to_string_lossy().into_owned())
390    } else {
391        Err(PyTypeError::new_err("Expected str or Path"))
392    }?;
393    if let Some(indices) = rest_frame_indices {
394        Ok(PyDataset(open_boosted_to_rest_frame_of(path_str, indices)?))
395    } else {
396        Ok(PyDataset(open(path_str)?))
397    }
398}