laddu_python/data.rs
1use crate::utils::variables::PyVariable;
2use laddu_core::{
3 data::{open, BinnedDataset, Dataset, Event},
4 Float,
5};
6use numpy::PyArray1;
7use pyo3::{
8 exceptions::{PyIndexError, PyTypeError},
9 prelude::*,
10};
11use std::sync::Arc;
12
13use crate::utils::vectors::{PyVector3, PyVector4};
14
/// A single event
///
/// Events are composed of a set of 4-momenta of particles in the overall
/// center-of-momentum frame, polarizations or helicities described by 3-vectors, and a
/// weight.
///
/// Parameters
/// ----------
/// p4s : list of Vector4
///     4-momenta of each particle in the event in the overall center-of-momentum frame
/// eps : list of Vector3
///     3-vectors describing the polarization or helicity of the particles
///     given in `p4s`
/// weight : float
///     The weight associated with this event
///
#[pyclass(name = "Event", module = "laddu")]
#[derive(Clone)]
pub struct PyEvent(pub Arc<Event>);
34
35#[pymethods]
36impl PyEvent {
37 #[new]
38 fn new(p4s: Vec<PyVector4>, eps: Vec<PyVector3>, weight: Float) -> Self {
39 Self(Arc::new(Event {
40 p4s: p4s.into_iter().map(|arr| arr.0).collect(),
41 aux: eps.into_iter().map(|arr| arr.0).collect(),
42 weight,
43 }))
44 }
45 fn __str__(&self) -> String {
46 self.0.to_string()
47 }
48 /// The list of 4-momenta for each particle in the event
49 ///
50 #[getter]
51 fn get_p4s(&self) -> Vec<PyVector4> {
52 self.0.p4s.iter().map(|p4| PyVector4(*p4)).collect()
53 }
54 /// The list of 3-vectors describing the polarization or helicity of particles in
55 /// the event
56 ///
57 #[getter]
58 fn get_eps(&self) -> Vec<PyVector3> {
59 self.0
60 .aux
61 .iter()
62 .map(|eps_vec| PyVector3(*eps_vec))
63 .collect()
64 }
65 /// The weight of this event relative to others in a Dataset
66 ///
67 #[getter]
68 fn get_weight(&self) -> Float {
69 self.0.weight
70 }
71 /// Get the sum of the four-momenta within the event at the given indices
72 ///
73 /// Parameters
74 /// ----------
75 /// indices : list of int
76 /// The indices of the four-momenta to sum
77 ///
78 /// Returns
79 /// -------
80 /// Vector4
81 /// The result of summing the given four-momenta
82 ///
83 fn get_p4_sum(&self, indices: Vec<usize>) -> PyVector4 {
84 PyVector4(self.0.get_p4_sum(indices))
85 }
86 /// Boost all the four-momenta in the event to the rest frame of the given set of
87 /// four-momenta by indices.
88 ///
89 /// Parameters
90 /// ----------
91 /// indices : list of int
92 /// The indices of the four-momenta to sum
93 ///
94 /// Returns
95 /// -------
96 /// Event
97 /// The boosted event
98 ///
99 pub fn boost_to_rest_frame_of(&self, indices: Vec<usize>) -> Self {
100 PyEvent(Arc::new(self.0.boost_to_rest_frame_of(indices)))
101 }
102}
103
/// A set of Events
///
/// Datasets can be created from lists of Events or by using the provided ``laddu.open``
/// function.
///
/// Datasets can also be indexed directly to access individual Events.
///
/// Parameters
/// ----------
/// events : list of Event
///     The Events from which the Dataset is built
///
/// See Also
/// --------
/// laddu.open
///
#[pyclass(name = "Dataset", module = "laddu")]
#[derive(Clone)]
pub struct PyDataset(pub Arc<Dataset>);
121
122#[pymethods]
123impl PyDataset {
124 #[new]
125 fn new(events: Vec<PyEvent>) -> Self {
126 Self(Arc::new(Dataset::new(
127 events.into_iter().map(|event| event.0).collect(),
128 )))
129 }
130 fn __len__(&self) -> usize {
131 self.0.n_events()
132 }
133 fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyDataset> {
134 if let Ok(other_ds) = other.extract::<PyRef<PyDataset>>() {
135 Ok(PyDataset(Arc::new(self.0.as_ref() + other_ds.0.as_ref())))
136 } else if let Ok(other_int) = other.extract::<usize>() {
137 if other_int == 0 {
138 Ok(self.clone())
139 } else {
140 Err(PyTypeError::new_err(
141 "Addition with an integer for this type is only defined for 0",
142 ))
143 }
144 } else {
145 Err(PyTypeError::new_err("Unsupported operand type for +"))
146 }
147 }
148 fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyDataset> {
149 if let Ok(other_ds) = other.extract::<PyRef<PyDataset>>() {
150 Ok(PyDataset(Arc::new(other_ds.0.as_ref() + self.0.as_ref())))
151 } else if let Ok(other_int) = other.extract::<usize>() {
152 if other_int == 0 {
153 Ok(self.clone())
154 } else {
155 Err(PyTypeError::new_err(
156 "Addition with an integer for this type is only defined for 0",
157 ))
158 }
159 } else {
160 Err(PyTypeError::new_err("Unsupported operand type for +"))
161 }
162 }
163 /// Get the number of Events in the Dataset
164 ///
165 /// Returns
166 /// -------
167 /// n_events : int
168 /// The number of Events
169 ///
170 #[getter]
171 fn n_events(&self) -> usize {
172 self.0.n_events()
173 }
174 /// Get the weighted number of Events in the Dataset
175 ///
176 /// Returns
177 /// -------
178 /// n_events : float
179 /// The sum of all Event weights
180 ///
181 #[getter]
182 fn n_events_weighted(&self) -> Float {
183 self.0.n_events_weighted()
184 }
185 /// The weights associated with the Dataset
186 ///
187 /// Returns
188 /// -------
189 /// weights : array_like
190 /// A ``numpy`` array of Event weights
191 ///
192 #[getter]
193 fn weights<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<Float>> {
194 PyArray1::from_slice(py, &self.0.weights())
195 }
196 /// The internal list of Events stored in the Dataset
197 ///
198 /// Returns
199 /// -------
200 /// events : list of Event
201 /// The Events in the Dataset
202 ///
203 #[getter]
204 fn events(&self) -> Vec<PyEvent> {
205 self.0
206 .events
207 .iter()
208 .map(|rust_event| PyEvent(rust_event.clone()))
209 .collect()
210 }
211 fn __getitem__(&self, index: usize) -> PyEvent {
212 PyEvent(Arc::new(self.0[index].clone()))
213 }
214 /// Separates a Dataset into histogram bins by a Variable value
215 ///
216 /// Parameters
217 /// ----------
218 /// variable : {laddu.Mass, laddu.CosTheta, laddu.Phi, laddu.PolAngle, laddu.PolMagnitude, laddu.Mandelstam}
219 /// The Variable by which each Event is binned
220 /// bins : int
221 /// The number of equally-spaced bins
222 /// range : tuple[float, float]
223 /// The minimum and maximum bin edges
224 ///
225 /// Returns
226 /// -------
227 /// datasets : BinnedDataset
228 /// A pub structure that holds a list of Datasets binned by the given `variable`
229 ///
230 /// See Also
231 /// --------
232 /// laddu.Mass
233 /// laddu.CosTheta
234 /// laddu.Phi
235 /// laddu.PolAngle
236 /// laddu.PolMagnitude
237 /// laddu.Mandelstam
238 ///
239 /// Raises
240 /// ------
241 /// TypeError
242 /// If the given `variable` is not a valid variable
243 ///
244 #[pyo3(signature = (variable, bins, range))]
245 fn bin_by(
246 &self,
247 variable: Bound<'_, PyAny>,
248 bins: usize,
249 range: (Float, Float),
250 ) -> PyResult<PyBinnedDataset> {
251 let py_variable = variable.extract::<PyVariable>()?;
252 Ok(PyBinnedDataset(self.0.bin_by(py_variable, bins, range)))
253 }
254 /// Generate a new bootstrapped Dataset by randomly resampling the original with replacement
255 ///
256 /// The new Dataset is resampled with a random generator seeded by the provided `seed`
257 ///
258 /// Parameters
259 /// ----------
260 /// seed : int
261 /// The random seed used in the resampling process
262 ///
263 /// Returns
264 /// -------
265 /// Dataset
266 /// A bootstrapped Dataset
267 ///
268 fn bootstrap(&self, seed: usize) -> PyDataset {
269 PyDataset(self.0.bootstrap(seed))
270 }
271 /// Boost all the four-momenta in all events to the rest frame of the given set of
272 /// four-momenta by indices.
273 ///
274 /// Parameters
275 /// ----------
276 /// indices : list of int
277 /// The indices of the four-momenta to sum
278 ///
279 /// Returns
280 /// -------
281 /// Dataset
282 /// The boosted dataset
283 ///
284 pub fn boost_to_rest_frame_of(&self, indices: Vec<usize>) -> PyDataset {
285 PyDataset(self.0.boost_to_rest_frame_of(indices))
286 }
287}
288
/// A collection of Datasets binned by a Variable
///
/// BinnedDatasets can be indexed directly to access the underlying Datasets by bin;
/// indexing past the last bin raises an ``IndexError``.
///
/// See Also
/// --------
/// laddu.Dataset.bin_by
///
#[pyclass(name = "BinnedDataset", module = "laddu")]
pub struct PyBinnedDataset(BinnedDataset);
299
300#[pymethods]
301impl PyBinnedDataset {
302 fn __len__(&self) -> usize {
303 self.0.n_bins()
304 }
305 /// The number of bins in the BinnedDataset
306 ///
307 #[getter]
308 fn n_bins(&self) -> usize {
309 self.0.n_bins()
310 }
311 /// The minimum and maximum values of the binning Variable used to create this BinnedDataset
312 ///
313 #[getter]
314 fn range(&self) -> (Float, Float) {
315 self.0.range()
316 }
317 /// The edges of each bin in the BinnedDataset
318 ///
319 #[getter]
320 fn edges<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<Float>> {
321 PyArray1::from_slice(py, &self.0.edges())
322 }
323 fn __getitem__(&self, index: usize) -> PyResult<PyDataset> {
324 self.0
325 .get(index)
326 .ok_or(PyIndexError::new_err("index out of range"))
327 .map(|rust_dataset| PyDataset(rust_dataset.clone()))
328 }
329}
330
331/// Open a Dataset from a file
332///
333/// Returns
334/// -------
335/// Dataset
336///
337/// Raises
338/// ------
339/// IOError
340/// If the file could not be read
341///
342/// Warnings
343/// --------
344/// This method will panic/fail if the columns do not have the correct names or data types.
345/// There is currently no way to make this nicer without a large performance dip (if you find a
346/// way, please open a PR).
347///
348/// Notes
349/// -----
350/// Data should be stored in Parquet format with each column being filled with 32-bit floats
351///
352/// Valid/required column names have the following formats:
353///
354/// ``p4_{particle index}_{E|Px|Py|Pz}`` (four-momentum components for each particle)
355///
356/// ``eps_{particle index}_{x|y|z}`` (polarization/helicity vectors for each particle)
357///
358/// ``weight`` (the weight of the Event)
359///
360/// For example, the four-momentum of the 0th particle in the event would be stored in columns
361/// with the names ``p4_0_E``, ``p4_0_Px``, ``p4_0_Py``, and ``p4_0_Pz``. That particle's
362/// polarization could be stored in the columns ``eps_0_x``, ``eps_0_y``, and ``eps_0_z``. This
363/// could continue for an arbitrary number of particles. The ``weight`` column is always
364/// required.
365///
366#[pyfunction(name = "open")]
367pub fn py_open(path: &str) -> PyResult<PyDataset> {
368 Ok(PyDataset(open(path)?))
369}