laddu_python/
data.rs

1use crate::utils::variables::PyVariable;
2use laddu_core::{
3    data::{open, BinnedDataset, Dataset, Event},
4    Float,
5};
6use numpy::PyArray1;
7use pyo3::{
8    exceptions::{PyIndexError, PyTypeError},
9    prelude::*,
10};
11use std::sync::Arc;
12
13use crate::utils::vectors::{PyVector3, PyVector4};
14
15/// A single event
16///
17/// Events are composed of a set of 4-momenta of particles in the overall
18/// center-of-momentum frame, polarizations or helicities described by 3-vectors, and a
19/// weight
20///
21/// Parameters
22/// ----------
23/// p4s : list of Vector4
24///     4-momenta of each particle in the event in the overall center-of-momentum frame
25/// eps : list of Vector3
26///     3-vectors describing the polarization or helicity of the particles
27///     given in `p4s`
28/// weight : float
29///     The weight associated with this event
30///
31#[pyclass(name = "Event", module = "laddu")]
32#[derive(Clone)]
33pub struct PyEvent(pub Arc<Event>);
34
35#[pymethods]
36impl PyEvent {
37    #[new]
38    fn new(p4s: Vec<PyVector4>, eps: Vec<PyVector3>, weight: Float) -> Self {
39        Self(Arc::new(Event {
40            p4s: p4s.into_iter().map(|arr| arr.0).collect(),
41            aux: eps.into_iter().map(|arr| arr.0).collect(),
42            weight,
43        }))
44    }
45    fn __str__(&self) -> String {
46        self.0.to_string()
47    }
48    /// The list of 4-momenta for each particle in the event
49    ///
50    #[getter]
51    fn get_p4s(&self) -> Vec<PyVector4> {
52        self.0.p4s.iter().map(|p4| PyVector4(*p4)).collect()
53    }
54    /// The list of 3-vectors describing the polarization or helicity of particles in
55    /// the event
56    ///
57    #[getter]
58    fn get_eps(&self) -> Vec<PyVector3> {
59        self.0
60            .aux
61            .iter()
62            .map(|eps_vec| PyVector3(*eps_vec))
63            .collect()
64    }
65    /// The weight of this event relative to others in a Dataset
66    ///
67    #[getter]
68    fn get_weight(&self) -> Float {
69        self.0.weight
70    }
71    /// Get the sum of the four-momenta within the event at the given indices
72    ///
73    /// Parameters
74    /// ----------
75    /// indices : list of int
76    ///     The indices of the four-momenta to sum
77    ///
78    /// Returns
79    /// -------
80    /// Vector4
81    ///     The result of summing the given four-momenta
82    ///
83    fn get_p4_sum(&self, indices: Vec<usize>) -> PyVector4 {
84        PyVector4(self.0.get_p4_sum(indices))
85    }
86    /// Boost all the four-momenta in the event to the rest frame of the given set of
87    /// four-momenta by indices.
88    ///
89    /// Parameters
90    /// ----------
91    /// indices : list of int
92    ///     The indices of the four-momenta to sum
93    ///
94    /// Returns
95    /// -------
96    /// Event
97    ///     The boosted event
98    ///
99    pub fn boost_to_rest_frame_of(&self, indices: Vec<usize>) -> Self {
100        PyEvent(Arc::new(self.0.boost_to_rest_frame_of(indices)))
101    }
102}
103
104/// A set of Events
105///
106/// Datasets can be created from lists of Events or by using the provided ``laddu.open`` function
107///
108/// Datasets can also be indexed directly to access individual Events
109///
110/// Parameters
111/// ----------
112/// events : list of Event
113///
114/// See Also
115/// --------
116/// laddu.open
117///
118#[pyclass(name = "Dataset", module = "laddu")]
119#[derive(Clone)]
120pub struct PyDataset(pub Arc<Dataset>);
121
122#[pymethods]
123impl PyDataset {
124    #[new]
125    fn new(events: Vec<PyEvent>) -> Self {
126        Self(Arc::new(Dataset::new(
127            events.into_iter().map(|event| event.0).collect(),
128        )))
129    }
130    fn __len__(&self) -> usize {
131        self.0.n_events()
132    }
133    fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyDataset> {
134        if let Ok(other_ds) = other.extract::<PyRef<PyDataset>>() {
135            Ok(PyDataset(Arc::new(self.0.as_ref() + other_ds.0.as_ref())))
136        } else if let Ok(other_int) = other.extract::<usize>() {
137            if other_int == 0 {
138                Ok(self.clone())
139            } else {
140                Err(PyTypeError::new_err(
141                    "Addition with an integer for this type is only defined for 0",
142                ))
143            }
144        } else {
145            Err(PyTypeError::new_err("Unsupported operand type for +"))
146        }
147    }
148    fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyDataset> {
149        if let Ok(other_ds) = other.extract::<PyRef<PyDataset>>() {
150            Ok(PyDataset(Arc::new(other_ds.0.as_ref() + self.0.as_ref())))
151        } else if let Ok(other_int) = other.extract::<usize>() {
152            if other_int == 0 {
153                Ok(self.clone())
154            } else {
155                Err(PyTypeError::new_err(
156                    "Addition with an integer for this type is only defined for 0",
157                ))
158            }
159        } else {
160            Err(PyTypeError::new_err("Unsupported operand type for +"))
161        }
162    }
163    /// Get the number of Events in the Dataset
164    ///
165    /// Returns
166    /// -------
167    /// n_events : int
168    ///     The number of Events
169    ///
170    #[getter]
171    fn n_events(&self) -> usize {
172        self.0.n_events()
173    }
174    /// Get the weighted number of Events in the Dataset
175    ///
176    /// Returns
177    /// -------
178    /// n_events : float
179    ///     The sum of all Event weights
180    ///
181    #[getter]
182    fn n_events_weighted(&self) -> Float {
183        self.0.n_events_weighted()
184    }
185    /// The weights associated with the Dataset
186    ///
187    /// Returns
188    /// -------
189    /// weights : array_like
190    ///     A ``numpy`` array of Event weights
191    ///
192    #[getter]
193    fn weights<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<Float>> {
194        PyArray1::from_slice(py, &self.0.weights())
195    }
196    /// The internal list of Events stored in the Dataset
197    ///
198    /// Returns
199    /// -------
200    /// events : list of Event
201    ///     The Events in the Dataset
202    ///
203    #[getter]
204    fn events(&self) -> Vec<PyEvent> {
205        self.0
206            .events
207            .iter()
208            .map(|rust_event| PyEvent(rust_event.clone()))
209            .collect()
210    }
211    fn __getitem__(&self, index: usize) -> PyEvent {
212        PyEvent(Arc::new(self.0[index].clone()))
213    }
214    /// Separates a Dataset into histogram bins by a Variable value
215    ///
216    /// Parameters
217    /// ----------
218    /// variable : {laddu.Mass, laddu.CosTheta, laddu.Phi, laddu.PolAngle, laddu.PolMagnitude, laddu.Mandelstam}
219    ///     The Variable by which each Event is binned
220    /// bins : int
221    ///     The number of equally-spaced bins
222    /// range : tuple[float, float]
223    ///     The minimum and maximum bin edges
224    ///
225    /// Returns
226    /// -------
227    /// datasets : BinnedDataset
228    ///     A pub structure that holds a list of Datasets binned by the given `variable`
229    ///
230    /// See Also
231    /// --------
232    /// laddu.Mass
233    /// laddu.CosTheta
234    /// laddu.Phi
235    /// laddu.PolAngle
236    /// laddu.PolMagnitude
237    /// laddu.Mandelstam
238    ///
239    /// Raises
240    /// ------
241    /// TypeError
242    ///     If the given `variable` is not a valid variable
243    ///
244    #[pyo3(signature = (variable, bins, range))]
245    fn bin_by(
246        &self,
247        variable: Bound<'_, PyAny>,
248        bins: usize,
249        range: (Float, Float),
250    ) -> PyResult<PyBinnedDataset> {
251        let py_variable = variable.extract::<PyVariable>()?;
252        Ok(PyBinnedDataset(self.0.bin_by(py_variable, bins, range)))
253    }
254    /// Generate a new bootstrapped Dataset by randomly resampling the original with replacement
255    ///
256    /// The new Dataset is resampled with a random generator seeded by the provided `seed`
257    ///
258    /// Parameters
259    /// ----------
260    /// seed : int
261    ///     The random seed used in the resampling process
262    ///
263    /// Returns
264    /// -------
265    /// Dataset
266    ///     A bootstrapped Dataset
267    ///
268    fn bootstrap(&self, seed: usize) -> PyDataset {
269        PyDataset(self.0.bootstrap(seed))
270    }
271    /// Boost all the four-momenta in all events to the rest frame of the given set of
272    /// four-momenta by indices.
273    ///
274    /// Parameters
275    /// ----------
276    /// indices : list of int
277    ///     The indices of the four-momenta to sum
278    ///
279    /// Returns
280    /// -------
281    /// Dataset
282    ///     The boosted dataset
283    ///
284    pub fn boost_to_rest_frame_of(&self, indices: Vec<usize>) -> PyDataset {
285        PyDataset(self.0.boost_to_rest_frame_of(indices))
286    }
287}
288
289/// A collection of Datasets binned by a Variable
290///
291/// BinnedDatasets can be indexed directly to access the underlying Datasets by bin
292///
293/// See Also
294/// --------
295/// laddu.Dataset.bin_by
296///
297#[pyclass(name = "BinnedDataset", module = "laddu")]
298pub struct PyBinnedDataset(BinnedDataset);
299
300#[pymethods]
301impl PyBinnedDataset {
302    fn __len__(&self) -> usize {
303        self.0.n_bins()
304    }
305    /// The number of bins in the BinnedDataset
306    ///
307    #[getter]
308    fn n_bins(&self) -> usize {
309        self.0.n_bins()
310    }
311    /// The minimum and maximum values of the binning Variable used to create this BinnedDataset
312    ///
313    #[getter]
314    fn range(&self) -> (Float, Float) {
315        self.0.range()
316    }
317    /// The edges of each bin in the BinnedDataset
318    ///
319    #[getter]
320    fn edges<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<Float>> {
321        PyArray1::from_slice(py, &self.0.edges())
322    }
323    fn __getitem__(&self, index: usize) -> PyResult<PyDataset> {
324        self.0
325            .get(index)
326            .ok_or(PyIndexError::new_err("index out of range"))
327            .map(|rust_dataset| PyDataset(rust_dataset.clone()))
328    }
329}
330
331/// Open a Dataset from a file
332///
333/// Returns
334/// -------
335/// Dataset
336///
337/// Raises
338/// ------
339/// IOError
340///     If the file could not be read
341///
342/// Warnings
343/// --------
344/// This method will panic/fail if the columns do not have the correct names or data types.
345/// There is currently no way to make this nicer without a large performance dip (if you find a
346/// way, please open a PR).
347///
348/// Notes
349/// -----
350/// Data should be stored in Parquet format with each column being filled with 32-bit floats
351///
352/// Valid/required column names have the following formats:
353///
354/// ``p4_{particle index}_{E|Px|Py|Pz}`` (four-momentum components for each particle)
355///
356/// ``eps_{particle index}_{x|y|z}`` (polarization/helicity vectors for each particle)
357///
358/// ``weight`` (the weight of the Event)
359///
360/// For example, the four-momentum of the 0th particle in the event would be stored in columns
361/// with the names ``p4_0_E``, ``p4_0_Px``, ``p4_0_Py``, and ``p4_0_Pz``. That particle's
362/// polarization could be stored in the columns ``eps_0_x``, ``eps_0_y``, and ``eps_0_z``. This
363/// could continue for an arbitrary number of particles. The ``weight`` column is always
364/// required.
365///
366#[pyfunction(name = "open")]
367pub fn py_open(path: &str) -> PyResult<PyDataset> {
368    Ok(PyDataset(open(path)?))
369}