// laddu_python/data.rs

use std::sync::Arc;

use laddu_core::{
    data::{open, BinnedDataset, Dataset, Event},
    Float,
};
use numpy::PyArray1;
use pyo3::{
    exceptions::{PyIndexError, PyTypeError},
    prelude::*,
};

use crate::utils::{
    variables::PyVariable,
    vectors::{PyVector3, PyVector4},
};
14
15/// A single event
16///
17/// Events are composed of a set of 4-momenta of particles in the overall
18/// center-of-momentum frame, polarizations or helicities described by 3-vectors, and a
19/// weight
20///
21/// Parameters
22/// ----------
23/// p4s : list of Vector4
24///     4-momenta of each particle in the event in the overall center-of-momentum frame
25/// eps : list of Vector3
26///     3-vectors describing the polarization or helicity of the particles
27///     given in `p4s`
28/// weight : float
29///     The weight associated with this event
30///
31#[pyclass(name = "Event", module = "laddu")]
32#[derive(Clone)]
33pub struct PyEvent(pub Arc<Event>);
34
35#[pymethods]
36impl PyEvent {
37    #[new]
38    fn new(p4s: Vec<PyVector4>, eps: Vec<PyVector3>, weight: Float) -> Self {
39        Self(Arc::new(Event {
40            p4s: p4s.into_iter().map(|arr| arr.0).collect(),
41            eps: eps.into_iter().map(|arr| arr.0).collect(),
42            weight,
43        }))
44    }
45    fn __str__(&self) -> String {
46        self.0.to_string()
47    }
48    /// The list of 4-momenta for each particle in the event
49    ///
50    #[getter]
51    fn get_p4s(&self) -> Vec<PyVector4> {
52        self.0.p4s.iter().map(|p4| PyVector4(*p4)).collect()
53    }
54    /// The list of 3-vectors describing the polarization or helicity of particles in
55    /// the event
56    ///
57    #[getter]
58    fn get_eps(&self) -> Vec<PyVector3> {
59        self.0
60            .eps
61            .iter()
62            .map(|eps_vec| PyVector3(*eps_vec))
63            .collect()
64    }
65    /// The weight of this event relative to others in a Dataset
66    ///
67    #[getter]
68    fn get_weight(&self) -> Float {
69        self.0.weight
70    }
71    /// Get the sum of the four-momenta within the event at the given indices
72    ///
73    /// Parameters
74    /// ----------
75    /// indices : list of int
76    ///     The indices of the four-momenta to sum
77    ///
78    /// Returns
79    /// -------
80    /// Vector4
81    ///     The result of summing the given four-momenta
82    ///
83    fn get_p4_sum(&self, indices: Vec<usize>) -> PyVector4 {
84        PyVector4(self.0.get_p4_sum(indices))
85    }
86}
87
88/// A set of Events
89///
90/// Datasets can be created from lists of Events or by using the provided ``laddu.open`` function
91///
92/// Datasets can also be indexed directly to access individual Events
93///
94/// Parameters
95/// ----------
96/// events : list of Event
97///
98/// See Also
99/// --------
100/// laddu.open
101///
102#[pyclass(name = "Dataset", module = "laddu")]
103#[derive(Clone)]
104pub struct PyDataset(pub Arc<Dataset>);
105
106#[pymethods]
107impl PyDataset {
108    #[new]
109    fn new(events: Vec<PyEvent>) -> Self {
110        Self(Arc::new(Dataset::new(
111            events.into_iter().map(|event| event.0).collect(),
112        )))
113    }
114    fn __len__(&self) -> usize {
115        self.0.n_events()
116    }
117    fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyDataset> {
118        if let Ok(other_ds) = other.extract::<PyRef<PyDataset>>() {
119            Ok(PyDataset(Arc::new(self.0.as_ref() + other_ds.0.as_ref())))
120        } else if let Ok(other_int) = other.extract::<usize>() {
121            if other_int == 0 {
122                Ok(self.clone())
123            } else {
124                Err(PyTypeError::new_err(
125                    "Addition with an integer for this type is only defined for 0",
126                ))
127            }
128        } else {
129            Err(PyTypeError::new_err("Unsupported operand type for +"))
130        }
131    }
132    fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyDataset> {
133        if let Ok(other_ds) = other.extract::<PyRef<PyDataset>>() {
134            Ok(PyDataset(Arc::new(other_ds.0.as_ref() + self.0.as_ref())))
135        } else if let Ok(other_int) = other.extract::<usize>() {
136            if other_int == 0 {
137                Ok(self.clone())
138            } else {
139                Err(PyTypeError::new_err(
140                    "Addition with an integer for this type is only defined for 0",
141                ))
142            }
143        } else {
144            Err(PyTypeError::new_err("Unsupported operand type for +"))
145        }
146    }
147    /// Get the number of Events in the Dataset
148    ///
149    /// Returns
150    /// -------
151    /// n_events : int
152    ///     The number of Events
153    ///
154    #[getter]
155    fn n_events(&self) -> usize {
156        self.0.n_events()
157    }
158    /// Get the weighted number of Events in the Dataset
159    ///
160    /// Returns
161    /// -------
162    /// n_events : float
163    ///     The sum of all Event weights
164    ///
165    #[getter]
166    fn n_events_weighted(&self) -> Float {
167        self.0.n_events_weighted()
168    }
169    /// The weights associated with the Dataset
170    ///
171    /// Returns
172    /// -------
173    /// weights : array_like
174    ///     A ``numpy`` array of Event weights
175    ///
176    #[getter]
177    fn weights<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<Float>> {
178        PyArray1::from_slice(py, &self.0.weights())
179    }
180    /// The internal list of Events stored in the Dataset
181    ///
182    /// Returns
183    /// -------
184    /// events : list of Event
185    ///     The Events in the Dataset
186    ///
187    #[getter]
188    fn events(&self) -> Vec<PyEvent> {
189        self.0
190            .events
191            .iter()
192            .map(|rust_event| PyEvent(rust_event.clone()))
193            .collect()
194    }
195    fn __getitem__(&self, index: usize) -> PyEvent {
196        PyEvent(Arc::new(self.0[index].clone()))
197    }
198    /// Separates a Dataset into histogram bins by a Variable value
199    ///
200    /// Parameters
201    /// ----------
202    /// variable : {laddu.Mass, laddu.CosTheta, laddu.Phi, laddu.PolAngle, laddu.PolMagnitude, laddu.Mandelstam}
203    ///     The Variable by which each Event is binned
204    /// bins : int
205    ///     The number of equally-spaced bins
206    /// range : tuple[float, float]
207    ///     The minimum and maximum bin edges
208    ///
209    /// Returns
210    /// -------
211    /// datasets : BinnedDataset
212    ///     A pub structure that holds a list of Datasets binned by the given `variable`
213    ///
214    /// See Also
215    /// --------
216    /// laddu.Mass
217    /// laddu.CosTheta
218    /// laddu.Phi
219    /// laddu.PolAngle
220    /// laddu.PolMagnitude
221    /// laddu.Mandelstam
222    ///
223    /// Raises
224    /// ------
225    /// TypeError
226    ///     If the given `variable` is not a valid variable
227    ///
228    #[pyo3(signature = (variable, bins, range))]
229    fn bin_by(
230        &self,
231        variable: Bound<'_, PyAny>,
232        bins: usize,
233        range: (Float, Float),
234    ) -> PyResult<PyBinnedDataset> {
235        let py_variable = variable.extract::<PyVariable>()?;
236        Ok(PyBinnedDataset(self.0.bin_by(py_variable, bins, range)))
237    }
238    /// Generate a new bootstrapped Dataset by randomly resampling the original with replacement
239    ///
240    /// The new Dataset is resampled with a random generator seeded by the provided `seed`
241    ///
242    /// Parameters
243    /// ----------
244    /// seed : int
245    ///     The random seed used in the resampling process
246    ///
247    /// Returns
248    /// -------
249    /// Dataset
250    ///     A bootstrapped Dataset
251    ///
252    fn bootstrap(&self, seed: usize) -> PyDataset {
253        PyDataset(self.0.bootstrap(seed))
254    }
255}
256
257/// A collection of Datasets binned by a Variable
258///
259/// BinnedDatasets can be indexed directly to access the underlying Datasets by bin
260///
261/// See Also
262/// --------
263/// laddu.Dataset.bin_by
264///
265#[pyclass(name = "BinnedDataset", module = "laddu")]
266pub struct PyBinnedDataset(BinnedDataset);
267
268#[pymethods]
269impl PyBinnedDataset {
270    fn __len__(&self) -> usize {
271        self.0.n_bins()
272    }
273    /// The number of bins in the BinnedDataset
274    ///
275    #[getter]
276    fn n_bins(&self) -> usize {
277        self.0.n_bins()
278    }
279    /// The minimum and maximum values of the binning Variable used to create this BinnedDataset
280    ///
281    #[getter]
282    fn range(&self) -> (Float, Float) {
283        self.0.range()
284    }
285    /// The edges of each bin in the BinnedDataset
286    ///
287    #[getter]
288    fn edges<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<Float>> {
289        PyArray1::from_slice(py, &self.0.edges())
290    }
291    fn __getitem__(&self, index: usize) -> PyResult<PyDataset> {
292        self.0
293            .get(index)
294            .ok_or(PyIndexError::new_err("index out of range"))
295            .map(|rust_dataset| PyDataset(rust_dataset.clone()))
296    }
297}
298
299/// Open a Dataset from a file
300///
301/// Returns
302/// -------
303/// Dataset
304///
305/// Raises
306/// ------
307/// IOError
308///     If the file could not be read
309///
310/// Warnings
311/// --------
312/// This method will panic/fail if the columns do not have the correct names or data types.
313/// There is currently no way to make this nicer without a large performance dip (if you find a
314/// way, please open a PR).
315///
316/// Notes
317/// -----
318/// Data should be stored in Parquet format with each column being filled with 32-bit floats
319///
320/// Valid/required column names have the following formats:
321///
322/// ``p4_{particle index}_{E|Px|Py|Pz}`` (four-momentum components for each particle)
323///
324/// ``eps_{particle index}_{x|y|z}`` (polarization/helicity vectors for each particle)
325///
326/// ``weight`` (the weight of the Event)
327///
328/// For example, the four-momentum of the 0th particle in the event would be stored in columns
329/// with the names ``p4_0_E``, ``p4_0_Px``, ``p4_0_Py``, and ``p4_0_Pz``. That particle's
330/// polarization could be stored in the columns ``eps_0_x``, ``eps_0_y``, and ``eps_0_z``. This
331/// could continue for an arbitrary number of particles. The ``weight`` column is always
332/// required.
333///
334#[pyfunction(name = "open")]
335pub fn py_open(path: &str) -> PyResult<PyDataset> {
336    Ok(PyDataset(open(path)?))
337}