laddu_python/data.rs
use crate::utils::variables::PyVariable;
use laddu_core::{
    data::{open, BinnedDataset, Dataset, Event},
    Float,
};
use numpy::PyArray1;
use pyo3::{
    exceptions::{PyIndexError, PyTypeError},
    prelude::*,
};
use std::sync::Arc;

use crate::utils::vectors::{PyVector3, PyVector4};

/// A single event
///
/// Events are composed of a set of 4-momenta of particles in the overall
/// center-of-momentum frame, polarizations or helicities described by 3-vectors, and a
/// weight
///
/// Parameters
/// ----------
/// p4s : list of Vector4
///     4-momenta of each particle in the event in the overall center-of-momentum frame
/// eps : list of Vector3
///     3-vectors describing the polarization or helicity of the particles
///     given in `p4s`
/// weight : float
///     The weight associated with this event
///
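/// Examples
/// --------
/// A minimal construction sketch; the component values and the component order assumed for
/// the ``Vector4``/``Vector3`` constructors below are illustrative and may differ from the
/// actual signatures:
///
/// >>> import laddu as ld
/// >>> p4s = [ld.Vector4(0.119, 0.374, 3.293, 3.442)]  # assumed (px, py, pz, E) order
/// >>> eps = [ld.Vector3(0.385, 0.022, 0.000)]
/// >>> event = ld.Event(p4s, eps, weight=1.0)
///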
#[pyclass(name = "Event", module = "laddu")]
#[derive(Clone)]
pub struct PyEvent(pub Arc<Event>);

#[pymethods]
impl PyEvent {
    #[new]
    fn new(p4s: Vec<PyVector4>, eps: Vec<PyVector3>, weight: Float) -> Self {
        Self(Arc::new(Event {
            p4s: p4s.into_iter().map(|arr| arr.0).collect(),
            eps: eps.into_iter().map(|arr| arr.0).collect(),
            weight,
        }))
    }
    fn __str__(&self) -> String {
        self.0.to_string()
    }
    /// The list of 4-momenta for each particle in the event
    ///
    #[getter]
    fn get_p4s(&self) -> Vec<PyVector4> {
        self.0.p4s.iter().map(|p4| PyVector4(*p4)).collect()
    }
    /// The list of 3-vectors describing the polarization or helicity of particles in
    /// the event
    ///
    #[getter]
    fn get_eps(&self) -> Vec<PyVector3> {
        self.0
            .eps
            .iter()
            .map(|eps_vec| PyVector3(*eps_vec))
            .collect()
    }
    /// The weight of this event relative to others in a Dataset
    ///
    #[getter]
    fn get_weight(&self) -> Float {
        self.0.weight
    }
    /// Get the sum of the four-momenta within the event at the given indices
    ///
    /// Parameters
    /// ----------
    /// indices : list of int
    ///     The indices of the four-momenta to sum
    ///
    /// Returns
    /// -------
    /// Vector4
    ///     The result of summing the given four-momenta
    ///
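    /// Examples
    /// --------
    /// A usage sketch; the indices below are placeholders for whichever particles form the
    /// system of interest:
    ///
    /// >>> resonance_p4 = event.get_p4_sum([2, 3])
    ///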
    fn get_p4_sum(&self, indices: Vec<usize>) -> PyVector4 {
        PyVector4(self.0.get_p4_sum(indices))
    }
}

/// A set of Events
///
/// Datasets can be created from lists of Events or by using the provided ``laddu.open`` function
///
/// Datasets can also be indexed directly to access individual Events
///
/// Parameters
/// ----------
/// events : list of Event
///
/// See Also
/// --------
/// laddu.open
///
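/// Examples
/// --------
/// A brief sketch of the two construction paths and of indexing (the file path and event
/// contents are placeholders):
///
/// >>> import laddu as ld
/// >>> dataset = ld.open("data.parquet")
/// >>> first_event = dataset[0]
/// >>> subset = ld.Dataset([first_event])
///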
#[pyclass(name = "Dataset", module = "laddu")]
#[derive(Clone)]
pub struct PyDataset(pub Arc<Dataset>);

#[pymethods]
impl PyDataset {
    #[new]
    fn new(events: Vec<PyEvent>) -> Self {
        Self(Arc::new(Dataset::new(
            events.into_iter().map(|event| event.0).collect(),
        )))
    }
    fn __len__(&self) -> usize {
        self.0.n_events()
    }
    fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyDataset> {
        if let Ok(other_ds) = other.extract::<PyRef<PyDataset>>() {
            Ok(PyDataset(Arc::new(self.0.as_ref() + other_ds.0.as_ref())))
        } else if let Ok(other_int) = other.extract::<usize>() {
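            // Note: treating `dataset + 0` as a no-op is presumably here so that Python's
            // built-in `sum()`, which starts from the integer 0, can concatenate Datasets.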
            if other_int == 0 {
                Ok(self.clone())
            } else {
                Err(PyTypeError::new_err(
                    "Addition with an integer for this type is only defined for 0",
                ))
            }
        } else {
            Err(PyTypeError::new_err("Unsupported operand type for +"))
        }
    }
    fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyDataset> {
        if let Ok(other_ds) = other.extract::<PyRef<PyDataset>>() {
            Ok(PyDataset(Arc::new(other_ds.0.as_ref() + self.0.as_ref())))
        } else if let Ok(other_int) = other.extract::<usize>() {
            if other_int == 0 {
                Ok(self.clone())
            } else {
                Err(PyTypeError::new_err(
                    "Addition with an integer for this type is only defined for 0",
                ))
            }
        } else {
            Err(PyTypeError::new_err("Unsupported operand type for +"))
        }
    }
    /// Get the number of Events in the Dataset
    ///
    /// Returns
    /// -------
    /// n_events : int
    ///     The number of Events
    ///
    #[getter]
    fn n_events(&self) -> usize {
        self.0.n_events()
    }
    /// Get the weighted number of Events in the Dataset
    ///
    /// Returns
    /// -------
    /// n_events : float
    ///     The sum of all Event weights
    ///
    #[getter]
    fn n_events_weighted(&self) -> Float {
        self.0.n_events_weighted()
    }
    /// The weights associated with the Dataset
    ///
    /// Returns
    /// -------
    /// weights : array_like
    ///     A ``numpy`` array of Event weights
    ///
    #[getter]
    fn weights<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<Float>> {
        PyArray1::from_slice(py, &self.0.weights())
    }
    /// The internal list of Events stored in the Dataset
    ///
    /// Returns
    /// -------
    /// events : list of Event
    ///     The Events in the Dataset
    ///
    #[getter]
    fn events(&self) -> Vec<PyEvent> {
        self.0
            .events
            .iter()
            .map(|rust_event| PyEvent(rust_event.clone()))
            .collect()
    }
    fn __getitem__(&self, index: usize) -> PyEvent {
        PyEvent(Arc::new(self.0[index].clone()))
    }
    /// Separates a Dataset into histogram bins by a Variable value
    ///
    /// Parameters
    /// ----------
    /// variable : {laddu.Mass, laddu.CosTheta, laddu.Phi, laddu.PolAngle, laddu.PolMagnitude, laddu.Mandelstam}
    ///     The Variable by which each Event is binned
    /// bins : int
    ///     The number of equally-spaced bins
    /// range : tuple[float, float]
    ///     The minimum and maximum bin edges
    ///
    /// Returns
    /// -------
    /// datasets : BinnedDataset
    ///     A structure that holds a list of Datasets binned by the given `variable`
    ///
    /// See Also
    /// --------
    /// laddu.Mass
    /// laddu.CosTheta
    /// laddu.Phi
    /// laddu.PolAngle
    /// laddu.PolMagnitude
    /// laddu.Mandelstam
    ///
    /// Raises
    /// ------
    /// TypeError
    ///     If the given `variable` is not a valid variable
    ///
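    /// Examples
    /// --------
    /// A binning sketch; the constituent indices given to ``laddu.Mass`` and the bin
    /// settings are placeholders:
    ///
    /// >>> import laddu as ld
    /// >>> res_mass = ld.Mass([2, 3])
    /// >>> binned = dataset.bin_by(res_mass, 50, (1.0, 2.0))
    ///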
    #[pyo3(signature = (variable, bins, range))]
    fn bin_by(
        &self,
        variable: Bound<'_, PyAny>,
        bins: usize,
        range: (Float, Float),
    ) -> PyResult<PyBinnedDataset> {
        let py_variable = variable.extract::<PyVariable>()?;
        Ok(PyBinnedDataset(self.0.bin_by(py_variable, bins, range)))
    }
    /// Generate a new bootstrapped Dataset by randomly resampling the original with replacement
    ///
    /// The new Dataset is resampled with a random generator seeded by the provided `seed`
    ///
    /// Parameters
    /// ----------
    /// seed : int
    ///     The random seed used in the resampling process
    ///
    /// Returns
    /// -------
    /// Dataset
    ///     A bootstrapped Dataset
    ///
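    /// Examples
    /// --------
    /// A resampling sketch (the seed value is arbitrary):
    ///
    /// >>> resampled = dataset.bootstrap(42)
    ///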
    fn bootstrap(&self, seed: usize) -> PyDataset {
        PyDataset(self.0.bootstrap(seed))
    }
}

/// A collection of Datasets binned by a Variable
///
/// BinnedDatasets can be indexed directly to access the underlying Datasets by bin
///
/// See Also
/// --------
/// laddu.Dataset.bin_by
///
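/// Examples
/// --------
/// An indexing sketch, assuming ``binned`` was produced by ``laddu.Dataset.bin_by``:
///
/// >>> first_bin = binned[0]
/// >>> centers = (binned.edges[:-1] + binned.edges[1:]) / 2
///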
#[pyclass(name = "BinnedDataset", module = "laddu")]
pub struct PyBinnedDataset(BinnedDataset);

#[pymethods]
impl PyBinnedDataset {
    fn __len__(&self) -> usize {
        self.0.n_bins()
    }
    /// The number of bins in the BinnedDataset
    ///
    #[getter]
    fn n_bins(&self) -> usize {
        self.0.n_bins()
    }
    /// The minimum and maximum values of the binning Variable used to create this BinnedDataset
    ///
    #[getter]
    fn range(&self) -> (Float, Float) {
        self.0.range()
    }
    /// The edges of each bin in the BinnedDataset
    ///
    #[getter]
    fn edges<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<Float>> {
        PyArray1::from_slice(py, &self.0.edges())
    }
    fn __getitem__(&self, index: usize) -> PyResult<PyDataset> {
        self.0
            .get(index)
            .ok_or(PyIndexError::new_err("index out of range"))
            .map(|rust_dataset| PyDataset(rust_dataset.clone()))
    }
}

/// Open a Dataset from a file
///
/// Parameters
/// ----------
/// path : str
///     The path to the file to open
///
/// Returns
/// -------
/// Dataset
///
/// Raises
/// ------
/// IOError
///     If the file could not be read
///
/// Warnings
/// --------
/// This function will panic/fail if the columns do not have the correct names or data types.
/// There is currently no way to make this nicer without a large performance dip (if you find a
/// way, please open a PR).
///
/// Notes
/// -----
/// Data should be stored in Parquet format with each column being filled with 32-bit floats
///
/// Valid/required column names have the following formats:
///
/// ``p4_{particle index}_{E|Px|Py|Pz}`` (four-momentum components for each particle)
///
/// ``eps_{particle index}_{x|y|z}`` (polarization/helicity vectors for each particle)
///
/// ``weight`` (the weight of the Event)
///
/// For example, the four-momentum of the 0th particle in the event would be stored in columns
/// with the names ``p4_0_E``, ``p4_0_Px``, ``p4_0_Py``, and ``p4_0_Pz``. That particle's
/// polarization could be stored in the columns ``eps_0_x``, ``eps_0_y``, and ``eps_0_z``. This
/// could continue for an arbitrary number of particles. The ``weight`` column is always
/// required.
///
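/// Examples
/// --------
/// A sketch of writing a compatible single-event file with ``pandas`` and opening it; the
/// file name and numerical values are placeholders, and ``pandas``/``pyarrow`` are only used
/// here for illustration:
///
/// >>> import pandas as pd
/// >>> import laddu
/// >>> df = pd.DataFrame(
/// ...     {
/// ...         "p4_0_E": [3.442],
/// ...         "p4_0_Px": [0.119],
/// ...         "p4_0_Py": [0.374],
/// ...         "p4_0_Pz": [3.293],
/// ...         "eps_0_x": [0.385],
/// ...         "eps_0_y": [0.022],
/// ...         "eps_0_z": [0.000],
/// ...         "weight": [1.0],
/// ...     },
/// ...     dtype="float32",
/// ... )
/// >>> df.to_parquet("single_event.parquet")
/// >>> dataset = laddu.open("single_event.parquet")
///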
#[pyfunction(name = "open")]
pub fn py_open(path: &str) -> PyResult<PyDataset> {
    Ok(PyDataset(open(path)?))
}