// laddu_python/data.rs

1use crate::utils::variables::{PyVariable, PyVariableExpression};
2use laddu_core::{
3    data::{
4        BinnedDataset, Dataset, DatasetMetadata, DatasetWriteOptions, Event, EventData,
5        FloatPrecision,
6    },
7    utils::variables::IntoP4Selection,
8    DatasetReadOptions,
9};
10use numpy::PyArray1;
11use pyo3::{
12    exceptions::{PyIndexError, PyKeyError, PyTypeError, PyValueError},
13    prelude::*,
14    types::PyDict,
15    IntoPyObjectExt,
16};
17use std::{path::PathBuf, sync::Arc};
18
19use crate::utils::vectors::PyVec4;
20
21fn parse_aliases(aliases: Option<Bound<'_, PyDict>>) -> PyResult<Vec<(String, Vec<String>)>> {
22    let Some(aliases) = aliases else {
23        return Ok(Vec::new());
24    };
25
26    let mut parsed = Vec::new();
27    for (key, value) in aliases.iter() {
28        let alias_name = key.extract::<String>()?;
29        let selection = if let Ok(single) = value.extract::<String>() {
30            vec![single]
31        } else {
32            let seq = value.extract::<Vec<String>>().map_err(|_| {
33                PyTypeError::new_err("Alias values must be a string or a sequence of strings")
34            })?;
35            if seq.is_empty() {
36                return Err(PyValueError::new_err(format!(
37                    "Alias '{alias_name}' must reference at least one particle",
38                )));
39            }
40            seq
41        };
42        parsed.push((alias_name, selection));
43    }
44
45    Ok(parsed)
46}
47
48fn parse_dataset_path(path: Bound<'_, PyAny>) -> PyResult<String> {
49    if let Ok(s) = path.extract::<String>() {
50        Ok(s)
51    } else if let Ok(pathbuf) = path.extract::<PathBuf>() {
52        Ok(pathbuf.to_string_lossy().into_owned())
53    } else {
54        Err(PyTypeError::new_err("Expected str or Path"))
55    }
56}
57
58fn parse_precision_arg(value: Option<&str>) -> PyResult<FloatPrecision> {
59    match value.map(|v| v.to_ascii_lowercase()) {
60        None => Ok(FloatPrecision::F64),
61        Some(name) if name == "f64" || name == "float64" || name == "double" => {
62            Ok(FloatPrecision::F64)
63        }
64        Some(name) if name == "f32" || name == "float32" || name == "float" => {
65            Ok(FloatPrecision::F32)
66        }
67        Some(other) => Err(PyValueError::new_err(format!(
68            "Unsupported precision '{other}' (expected 'f64' or 'f32')"
69        ))),
70    }
71}
72
/// A single event
///
/// Events are composed of a set of 4-momenta of particles in the overall
/// center-of-momentum frame, optional auxiliary scalars (e.g. polarization magnitude or angle),
/// and a weight.
///
/// Parameters
/// ----------
/// p4s : list of Vec4
///     4-momenta of each particle in the event in the overall center-of-momentum frame
/// aux : list of float
///     Scalar auxiliary data associated with the event
/// weight : float
///     The weight associated with this event
/// p4_names : list of str, optional
///     Human-readable aliases for each four-momentum. Providing names enables name-based
///     lookups when evaluating variables.
/// aux_names : list of str, optional
///     Aliases for auxiliary scalars corresponding to ``aux``.
/// aliases : dict of {str: str or list[str]}, optional
///     Additional particle identifiers that reference one or more entries from ``p4_names``.
///
/// Examples
/// --------
/// >>> from laddu import Event, Vec3  # doctest: +SKIP
/// >>> event = Event(  # doctest: +SKIP
/// ...     [Vec3(0.0, 0.0, 1.0).with_mass(0.0), Vec3(0.0, 0.0, 1.0).with_mass(0.0)],
/// ...     [],
/// ...     1.0,
/// ...     p4_names=['kshort1', 'kshort2'],
/// ...     aliases={'pair': ['kshort1', 'kshort2']},
/// ... )
/// >>> event.p4('pair')  # doctest: +SKIP
/// Vec4(px=0.0, py=0.0, pz=2.0, e=2.0)
///
#[pyclass(name = "Event", module = "laddu")]
#[derive(Clone)]
pub struct PyEvent {
    // Underlying laddu-core event (data payload plus a shared metadata handle).
    pub event: Event,
    // True when the event was constructed with name metadata
    // (p4_names/aux_names/aliases); name-based operations are rejected otherwise.
    has_metadata: bool,
}
114
#[pymethods]
impl PyEvent {
    #[new]
    #[pyo3(signature = (p4s, aux, weight, *, p4_names=None, aux_names=None, aliases=None))]
    fn new(
        p4s: Vec<PyVec4>,
        aux: Vec<f64>,
        weight: f64,
        p4_names: Option<Vec<String>>,
        aux_names: Option<Vec<String>>,
        aliases: Option<Bound<PyDict>>,
    ) -> PyResult<Self> {
        // Build the raw event payload by unwrapping the Python Vec4 wrappers.
        let event = EventData {
            p4s: p4s.into_iter().map(|arr| arr.0).collect(),
            aux,
            weight,
        };
        let aliases = parse_aliases(aliases)?;

        // Treat both `None` and an empty list as "no particle names supplied".
        let missing_p4_names = p4_names
            .as_ref()
            .map(|names| names.is_empty())
            .unwrap_or(true);

        // Aliases are resolved against particle names, so they need p4_names.
        if !aliases.is_empty() && missing_p4_names {
            return Err(PyValueError::new_err(
                "`aliases` requires `p4_names` so selections can be resolved",
            ));
        }

        // Only attach real metadata when the caller supplied any naming info;
        // otherwise keep an empty placeholder and remember that fact so
        // name-based lookups can fail with a clear error later.
        let metadata_provided = p4_names.is_some() || aux_names.is_some() || !aliases.is_empty();
        let metadata = if metadata_provided {
            let p4_names = p4_names.unwrap_or_default();
            let aux_names = aux_names.unwrap_or_default();
            let mut metadata = DatasetMetadata::new(p4_names, aux_names).map_err(PyErr::from)?;
            if !aliases.is_empty() {
                metadata
                    .add_p4_aliases(
                        aliases.into_iter().map(|(alias_name, selection)| {
                            (alias_name, selection.into_selection())
                        }),
                    )
                    .map_err(PyErr::from)?;
            }
            Arc::new(metadata)
        } else {
            Arc::new(DatasetMetadata::empty())
        };
        let event = Event::new(Arc::new(event), metadata);
        Ok(Self {
            event,
            has_metadata: metadata_provided,
        })
    }
    // Delegates string formatting to the underlying event data.
    fn __str__(&self) -> String {
        self.event.data().to_string()
    }
    /// The list of 4-momenta for each particle in the event
    ///
    /// Returned as a dict mapping particle name to Vec4; requires the event to
    /// carry metadata.
    ///
    #[getter]
    fn p4s<'py>(&self, py: Python<'py>) -> PyResult<Py<PyDict>> {
        self.ensure_metadata()?;
        let mapping = PyDict::new(py);
        for (name, vec4) in self.event.p4s() {
            mapping.set_item(name, PyVec4(vec4))?;
        }
        Ok(mapping.into())
    }
    /// The auxiliary scalar values associated with the event
    ///
    /// Returned as a dict mapping auxiliary name to value; requires the event to
    /// carry metadata.
    ///
    #[getter]
    #[pyo3(name = "aux")]
    fn aux_mapping<'py>(&self, py: Python<'py>) -> PyResult<Py<PyDict>> {
        self.ensure_metadata()?;
        let mapping = PyDict::new(py);
        for (name, value) in self.event.aux() {
            mapping.set_item(name, value)?;
        }
        Ok(mapping.into())
    }
    /// The weight of this event relative to others in a Dataset
    ///
    #[getter]
    fn get_weight(&self) -> f64 {
        self.event.weight()
    }
    /// Get the sum of the four-momenta within the event at the given indices
    ///
    /// Parameters
    /// ----------
    /// names : list of str
    ///     The names of the four-momenta to sum
    ///
    /// Returns
    /// -------
    /// Vec4
    ///     The result of summing the given four-momenta
    ///
    /// Raises
    /// ------
    /// KeyError
    ///     If any name is not a known particle name or alias
    /// ValueError
    ///     If the event has no metadata for name-based lookups
    ///
    fn get_p4_sum(&self, names: Vec<String>) -> PyResult<PyVec4> {
        let indices = self.resolve_p4_indices(&names)?;
        Ok(PyVec4(self.event.data().get_p4_sum(indices)))
    }
    /// Boost all the four-momenta in the event to the rest frame of the given set of
    /// four-momenta by indices.
    ///
    /// Parameters
    /// ----------
    /// names : list of str
    ///     The names of the four-momenta whose rest frame should be used for the boost
    ///
    /// Returns
    /// -------
    /// Event
    ///     The boosted event
    ///
    /// Raises
    /// ------
    /// KeyError
    ///     If any name is not a known particle name or alias
    /// ValueError
    ///     If the event has no metadata for name-based lookups
    ///
    pub fn boost_to_rest_frame_of(&self, names: Vec<String>) -> PyResult<Self> {
        let indices = self.resolve_p4_indices(&names)?;
        let boosted = self.event.data().boost_to_rest_frame_of(indices);
        // The boosted payload reuses the original metadata handle unchanged.
        Ok(Self {
            event: Event::new(Arc::new(boosted), self.event.metadata_arc()),
            has_metadata: self.has_metadata,
        })
    }
    /// Get the value of a Variable on the given Event
    ///
    /// Parameters
    /// ----------
    /// variable : {laddu.Mass, laddu.CosTheta, laddu.Phi, laddu.PolAngle, laddu.PolMagnitude, laddu.Mandelstam}
    ///
    /// Returns
    /// -------
    /// float
    ///
    /// Notes
    /// -----
    /// Variables that rely on particle names require the event to carry metadata. Provide
    /// ``p4_names``/``aux_names`` when constructing the event or evaluate variables through a
    /// ``laddu.Dataset`` to ensure the metadata is available.
    ///
    fn evaluate(&self, variable: Bound<'_, PyAny>) -> PyResult<f64> {
        let mut variable = variable.extract::<PyVariable>()?;
        if !self.has_metadata {
            return Err(PyValueError::new_err(
                "Cannot evaluate variable on an Event without associated metadata. Construct the Event with `p4_names`/`aux_names` or evaluate through a Dataset.",
            ));
        }
        // Resolve the variable's particle references against this event's metadata
        // before evaluating it on the raw event data.
        variable.bind_in_place(self.event.metadata())?;
        let event_arc = self.event.data_arc();
        variable.evaluate_event(&event_arc)
    }

    /// Retrieve a four-momentum by name (if present).
    ///
    /// Returns ``None`` when the name is unknown; raises ``ValueError`` when the
    /// event has no metadata at all.
    fn p4(&self, name: &str) -> PyResult<Option<PyVec4>> {
        self.ensure_metadata()?;
        Ok(self.event.p4(name).map(PyVec4))
    }
}
272
273impl PyEvent {
274    fn ensure_metadata(&self) -> PyResult<&DatasetMetadata> {
275        if !self.has_metadata {
276            Err(PyValueError::new_err(
277                "Event has no associated metadata for name-based operations",
278            ))
279        } else {
280            Ok(self.event.metadata())
281        }
282    }
283
284    fn resolve_p4_indices(&self, names: &[String]) -> PyResult<Vec<usize>> {
285        let metadata = self.ensure_metadata()?;
286        let mut resolved = Vec::new();
287        for name in names {
288            let selection = metadata
289                .p4_selection(name)
290                .ok_or_else(|| PyKeyError::new_err(format!("Unknown particle name '{name}'")))?;
291            resolved.extend_from_slice(selection.indices());
292        }
293        Ok(resolved)
294    }
295
296    pub(crate) fn metadata_opt(&self) -> Option<&DatasetMetadata> {
297        self.has_metadata.then(|| self.event.metadata())
298    }
299}
300
/// A set of Events
///
/// Datasets can be created from lists of Events or by using the constructor helpers
/// such as :meth:`laddu.Dataset.from_parquet`, :meth:`laddu.Dataset.from_root`, and
/// :meth:`laddu.Dataset.from_amptools`
///
/// Datasets can also be indexed directly to access individual Events
///
/// Parameters
/// ----------
/// events : list of Event
/// p4_names : list of str, optional
///     Names assigned to each four-momentum; enables name-based lookups if provided.
/// aux_names : list of str, optional
///     Names for auxiliary scalars stored alongside the events.
/// aliases : dict of {str: str or list[str]}, optional
///     Additional particle identifiers that override aliases stored on the Events.
///
/// Notes
/// -----
/// Explicit metadata provided here takes precedence over metadata embedded in the
/// input Events.
///
#[pyclass(name = "Dataset", module = "laddu", subclass)]
#[derive(Clone)]
// Thin shared-ownership wrapper around the core Dataset; cloning only bumps the Arc.
pub struct PyDataset(pub Arc<Dataset>);
327
/// Python iterator over the Events of a Dataset.
#[pyclass(name = "DatasetIter", module = "laddu")]
struct PyDatasetIter {
    // Shared handle to the dataset being traversed.
    dataset: Arc<Dataset>,
    // Next index to yield.
    index: usize,
    // Number of events captured when iteration started.
    total: usize,
}
334
335#[pymethods]
336impl PyDatasetIter {
337    fn __iter__(slf: PyRef<'_, Self>) -> Py<PyDatasetIter> {
338        slf.into()
339    }
340
341    fn __next__(&mut self) -> Option<PyEvent> {
342        if self.index >= self.total {
343            return None;
344        }
345        let event = self.dataset[self.index].clone();
346        self.index += 1;
347        Some(PyEvent {
348            event,
349            has_metadata: true,
350        })
351    }
352}
353
#[pymethods]
impl PyDataset {
    #[new]
    #[pyo3(signature = (events, *, p4_names=None, aux_names=None, aliases=None))]
    fn new(
        events: Vec<PyEvent>,
        p4_names: Option<Vec<String>>,
        aux_names: Option<Vec<String>>,
        aliases: Option<Bound<PyDict>>,
    ) -> PyResult<Self> {
        // Fall back to the first input event that carries metadata when the
        // caller provides no explicit names/aliases.
        let inferred_metadata = events
            .iter()
            .find_map(|event| event.has_metadata.then(|| event.event.metadata_arc()));

        let aliases = parse_aliases(aliases)?;
        // Any explicit argument makes us build fresh metadata instead of
        // reusing what the events carried.
        let use_explicit_metadata =
            p4_names.is_some() || aux_names.is_some() || !aliases.is_empty();

        let metadata =
            if use_explicit_metadata {
                // Prefer caller-supplied names, then names inferred from the
                // events, then an empty list.
                let resolved_p4_names = match (p4_names, inferred_metadata.as_ref()) {
                    (Some(names), _) => names,
                    (None, Some(metadata)) => metadata.p4_names().to_vec(),
                    (None, None) => Vec::new(),
                };
                let resolved_aux_names = match (aux_names, inferred_metadata.as_ref()) {
                    (Some(names), _) => names,
                    (None, Some(metadata)) => metadata.aux_names().to_vec(),
                    (None, None) => Vec::new(),
                };

                // Aliases are resolved against particle names, so some must exist.
                if !aliases.is_empty() && resolved_p4_names.is_empty() {
                    return Err(PyValueError::new_err(
                        "`aliases` requires `p4_names` or events with metadata for resolution",
                    ));
                }

                let mut metadata = DatasetMetadata::new(resolved_p4_names, resolved_aux_names)
                    .map_err(PyErr::from)?;
                if !aliases.is_empty() {
                    metadata
                        .add_p4_aliases(aliases.into_iter().map(|(alias_name, selection)| {
                            (alias_name, selection.into_selection())
                        }))
                        .map_err(PyErr::from)?;
                }
                Some(Arc::new(metadata))
            } else {
                inferred_metadata
            };

        // Strip the Python wrappers down to the shared raw event payloads.
        let events: Vec<Arc<EventData>> = events
            .into_iter()
            .map(|event| event.event.data_arc())
            .collect();
        let dataset = if let Some(metadata) = metadata {
            Dataset::new_with_metadata(events, metadata)
        } else {
            Dataset::new(events)
        };
        Ok(Self(Arc::new(dataset)))
    }

    /// Read a Dataset from a Parquet file.
    ///
    /// Parameters
    /// ----------
    /// path : str or Path
    ///     The path to the Parquet file.
    /// p4s : list[str], optional
    ///     Particle identifiers corresponding to ``*_px``, ``*_py``, ``*_pz``, ``*_e`` columns.
    /// aux : list[str], optional
    ///     Auxiliary scalar column names copied verbatim in order.
    /// aliases : dict of {str: str or list[str]}, optional
    ///     Additional particle identifiers resolved against the particle names.
    ///
    /// Returns
    /// -------
    /// Dataset
    ///     The loaded Dataset
    ///
    /// Notes
    /// -----
    /// If `p4s` or `aux` are not provided, they will be inferred from the column names. If all of
    /// the valid suffixes are provided for a particle, the corresponding columns will be read as a
    /// four-momentum, otherwise they will be read as auxiliary scalars.
    ///
    #[staticmethod]
    #[pyo3(signature = (path, *, p4s=None, aux=None, aliases=None))]
    fn from_parquet(
        path: Bound<PyAny>,
        p4s: Option<Vec<String>>,
        aux: Option<Vec<String>>,
        aliases: Option<Bound<PyDict>>,
    ) -> PyResult<Self> {
        let path_str = parse_dataset_path(path)?;

        // Only override the defaults for options the caller actually supplied.
        let mut read_options = DatasetReadOptions::default();
        if let Some(p4s) = p4s {
            read_options = read_options.p4_names(p4s);
        }
        if let Some(aux) = aux {
            read_options = read_options.aux_names(aux);
        }
        for (alias_name, selection) in parse_aliases(aliases)?.into_iter() {
            read_options = read_options.alias(alias_name, selection);
        }
        let dataset = Dataset::from_parquet(&path_str, &read_options)?;

        Ok(Self(dataset))
    }

    /// Read a Dataset from a ROOT TTree using the oxyroot backend.
    ///
    /// Parameters
    /// ----------
    /// path : str or Path
    ///     The path to the ROOT file.
    /// tree : str, optional
    ///     Name of the TTree to read when opening ROOT files.
    /// p4s : list[str], optional
    ///     Particle identifiers corresponding to ``*_px``, ``*_py``, ``*_pz``, ``*_e`` columns.
    /// aux : list[str], optional
    ///     Auxiliary scalar column names copied verbatim in order.
    /// aliases : dict of {str: str or list[str]}, optional
    ///     Additional particle identifiers resolved against the particle names.
    ///
    /// Returns
    /// -------
    /// Dataset
    ///     The loaded Dataset
    ///
    /// Notes
    /// -----
    /// If `p4s` or `aux` are not provided, they will be inferred from the column names. If all of
    /// the valid suffixes are provided for a particle, the corresponding columns will be read as a
    /// four-momentum, otherwise they will be read as auxiliary scalars.
    ///
    #[staticmethod]
    #[pyo3(signature = (path, *, tree=None, p4s=None, aux=None, aliases=None))]
    fn from_root(
        path: Bound<PyAny>,
        tree: Option<String>,
        p4s: Option<Vec<String>>,
        aux: Option<Vec<String>>,
        aliases: Option<Bound<PyDict>>,
    ) -> PyResult<Self> {
        let path_str = parse_dataset_path(path)?;

        // Only override the defaults for options the caller actually supplied.
        let mut read_options = DatasetReadOptions::default();
        if let Some(p4s) = p4s {
            read_options = read_options.p4_names(p4s);
        }
        if let Some(aux) = aux {
            read_options = read_options.aux_names(aux);
        }
        if let Some(tree) = tree {
            read_options = read_options.tree(tree);
        }
        for (alias_name, selection) in parse_aliases(aliases)?.into_iter() {
            read_options = read_options.alias(alias_name, selection);
        }
        let dataset = Dataset::from_root(&path_str, &read_options)?;

        Ok(Self(dataset))
    }
    fn __len__(&self) -> usize {
        self.0.n_events()
    }
    // Returns an iterator object; the event count is captured up front.
    fn __iter__(&self) -> PyDatasetIter {
        PyDatasetIter {
            dataset: self.0.clone(),
            index: 0,
            total: self.0.n_events(),
        }
    }
    // `dataset + dataset` concatenates; `dataset + 0` is accepted so that
    // Python's built-in `sum()` (which starts from 0) works on Datasets.
    fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyDataset> {
        if let Ok(other_ds) = other.extract::<PyRef<PyDataset>>() {
            Ok(PyDataset(Arc::new(self.0.as_ref() + other_ds.0.as_ref())))
        } else if let Ok(other_int) = other.extract::<usize>() {
            if other_int == 0 {
                Ok(self.clone())
            } else {
                Err(PyTypeError::new_err(
                    "Addition with an integer for this type is only defined for 0",
                ))
            }
        } else {
            Err(PyTypeError::new_err("Unsupported operand type for +"))
        }
    }
    // Mirror of `__add__` with the operand order swapped (for `0 + dataset`).
    fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyDataset> {
        if let Ok(other_ds) = other.extract::<PyRef<PyDataset>>() {
            Ok(PyDataset(Arc::new(other_ds.0.as_ref() + self.0.as_ref())))
        } else if let Ok(other_int) = other.extract::<usize>() {
            if other_int == 0 {
                Ok(self.clone())
            } else {
                Err(PyTypeError::new_err(
                    "Addition with an integer for this type is only defined for 0",
                ))
            }
        } else {
            Err(PyTypeError::new_err("Unsupported operand type for +"))
        }
    }
    /// Get the number of Events in the Dataset
    ///
    /// Returns
    /// -------
    /// n_events : int
    ///     The number of Events
    ///
    #[getter]
    fn n_events(&self) -> usize {
        self.0.n_events()
    }
    /// Particle names used to construct four-momenta when loading from a Parquet file.
    #[getter]
    fn p4_names(&self) -> Vec<String> {
        self.0.p4_names().to_vec()
    }
    /// Auxiliary scalar names associated with this Dataset.
    #[getter]
    fn aux_names(&self) -> Vec<String> {
        self.0.aux_names().to_vec()
    }

    /// Write the dataset to a Parquet file.
    ///
    /// Parameters
    /// ----------
    /// path : str or Path
    ///     The output file path.
    /// chunk_size : int, optional
    ///     Number of events per written batch (values below 1 are clamped to 1).
    /// precision : {'f64', 'f32'}, default 'f64'
    ///     Floating-point precision used for the written columns.
    ///
    #[pyo3(signature = (path, *, chunk_size=None, precision="f64"))]
    fn to_parquet(
        &self,
        path: Bound<'_, PyAny>,
        chunk_size: Option<usize>,
        precision: &str,
    ) -> PyResult<()> {
        let path_str = parse_dataset_path(path)?;
        let mut write_options = DatasetWriteOptions::default();
        if let Some(size) = chunk_size {
            // Guard against a zero batch size.
            write_options.batch_size = size.max(1);
        }
        write_options.precision = parse_precision_arg(Some(precision))?;

        self.0
            .to_parquet(&path_str, &write_options)
            .map_err(PyErr::from)
    }

    /// Write the dataset to a ROOT file using the oxyroot backend.
    ///
    /// Parameters
    /// ----------
    /// path : str or Path
    ///     The output file path.
    /// tree : str, optional
    ///     Name of the TTree to write.
    /// chunk_size : int, optional
    ///     Number of events per written batch (values below 1 are clamped to 1).
    /// precision : {'f64', 'f32'}, default 'f64'
    ///     Floating-point precision used for the written branches.
    ///
    #[pyo3(signature = (path, *, tree=None, chunk_size=None, precision="f64"))]
    fn to_root(
        &self,
        path: Bound<'_, PyAny>,
        tree: Option<String>,
        chunk_size: Option<usize>,
        precision: &str,
    ) -> PyResult<()> {
        let path_str = parse_dataset_path(path)?;
        let mut write_options = DatasetWriteOptions::default();
        if let Some(name) = tree {
            write_options.tree = Some(name);
        }
        if let Some(size) = chunk_size {
            // Guard against a zero batch size.
            write_options.batch_size = size.max(1);
        }
        write_options.precision = parse_precision_arg(Some(precision))?;

        self.0
            .to_root(&path_str, &write_options)
            .map_err(PyErr::from)
    }
    /// Get the weighted number of Events in the Dataset
    ///
    /// Returns
    /// -------
    /// n_events : float
    ///     The sum of all Event weights
    ///
    #[getter]
    fn n_events_weighted(&self) -> f64 {
        self.0.n_events_weighted()
    }
    /// The weights associated with the Dataset
    ///
    /// Returns
    /// -------
    /// weights : array_like
    ///     A ``numpy`` array of Event weights
    ///
    #[getter]
    fn weights<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<f64>> {
        PyArray1::from_slice(py, &self.0.weights())
    }
    /// The internal list of Events stored in the Dataset
    ///
    /// Notes
    /// -----
    /// When MPI is enabled, this returns only the events local to the current rank.
    /// Use Python iteration (`for event in dataset`, `list(dataset)`, etc.) to
    /// traverse all events across ranks.
    ///
    /// Returns
    /// -------
    /// events : list of Event
    ///     The Events in the Dataset
    ///
    #[getter]
    fn events(&self) -> Vec<PyEvent> {
        self.0
            .events
            .iter()
            .map(|rust_event| PyEvent {
                event: rust_event.clone(),
                has_metadata: true,
            })
            .collect()
    }
    /// Retrieve a four-momentum by particle name for the event at ``index``.
    ///
    /// Raises ``KeyError`` if the name is unknown.
    fn p4_by_name(&self, index: usize, name: &str) -> PyResult<PyVec4> {
        self.0
            .p4_by_name(index, name)
            .map(PyVec4)
            .ok_or_else(|| PyKeyError::new_err(format!("Unknown particle name '{name}'")))
    }
    /// Retrieve an auxiliary scalar by name for the event at ``index``.
    ///
    /// Raises ``KeyError`` if the name is unknown.
    fn aux_by_name(&self, index: usize, name: &str) -> PyResult<f64> {
        self.0
            .aux_by_name(index, name)
            .ok_or_else(|| PyKeyError::new_err(format!("Unknown auxiliary name '{name}'")))
    }
    // `dataset[variable]` evaluates the Variable over all events;
    // `dataset[i]` returns the i-th Event. Variable evaluation is attempted
    // first; NOTE(review): a failing Variable evaluation falls through to the
    // integer path and may surface as the generic TypeError below.
    fn __getitem__<'py>(
        &self,
        py: Python<'py>,
        index: Bound<'py, PyAny>,
    ) -> PyResult<Bound<'py, PyAny>> {
        if let Ok(value) = self.evaluate(py, index.clone()) {
            value.into_bound_py_any(py)
        } else if let Ok(index) = index.extract::<usize>() {
            PyEvent {
                event: self.0[index].clone(),
                has_metadata: true,
            }
            .into_bound_py_any(py)
        } else {
            Err(PyTypeError::new_err(
                "Unsupported index type (int or Variable)",
            ))
        }
    }
    /// Separates a Dataset into histogram bins by a Variable value
    ///
    /// Parameters
    /// ----------
    /// variable : {laddu.Mass, laddu.CosTheta, laddu.Phi, laddu.PolAngle, laddu.PolMagnitude, laddu.Mandelstam}
    ///     The Variable by which each Event is binned
    /// bins : int
    ///     The number of equally-spaced bins
    /// range : tuple[float, float]
    ///     The minimum and maximum bin edges
    ///
    /// Returns
    /// -------
    /// datasets : BinnedDataset
    ///     A structure that holds a list of Datasets binned by the given `variable`
    ///
    /// See Also
    /// --------
    /// laddu.Mass
    /// laddu.CosTheta
    /// laddu.Phi
    /// laddu.PolAngle
    /// laddu.PolMagnitude
    /// laddu.Mandelstam
    ///
    /// Examples
    /// --------
    /// >>> from laddu.utils.variables import Mass  # doctest: +SKIP
    /// >>> binned = dataset.bin_by(Mass(['kshort1']), bins=10, range=(0.9, 1.5))  # doctest: +SKIP
    /// >>> len(binned)  # doctest: +SKIP
    /// 10
    ///
    /// Raises
    /// ------
    /// TypeError
    ///     If the given `variable` is not a valid variable
    ///
    #[pyo3(signature = (variable, bins, range))]
    fn bin_by(
        &self,
        variable: Bound<'_, PyAny>,
        bins: usize,
        range: (f64, f64),
    ) -> PyResult<PyBinnedDataset> {
        let py_variable = variable.extract::<PyVariable>()?;
        // Bind the variable's particle references to this dataset's metadata
        // before binning.
        let bound_variable = py_variable.bound(self.0.metadata())?;
        Ok(PyBinnedDataset(self.0.bin_by(
            bound_variable,
            bins,
            range,
        )?))
    }
    /// Filter the Dataset by a given VariableExpression, selecting events for which the expression returns ``True``.
    ///
    /// Parameters
    /// ----------
    /// expression : VariableExpression
    ///     The expression with which to filter the Dataset
    ///
    /// Returns
    /// -------
    /// Dataset
    ///     The filtered Dataset
    ///
    /// Examples
    /// --------
    /// >>> from laddu.utils.variables import Mass  # doctest: +SKIP
    /// >>> heavy = dataset.filter(Mass(['kshort1']) > 1.0)  # doctest: +SKIP
    ///
    pub fn filter(&self, expression: &PyVariableExpression) -> PyResult<PyDataset> {
        Ok(PyDataset(
            self.0.filter(&expression.0).map_err(PyErr::from)?,
        ))
    }
    /// Generate a new bootstrapped Dataset by randomly resampling the original with replacement
    ///
    /// The new Dataset is resampled with a random generator seeded by the provided `seed`
    ///
    /// Parameters
    /// ----------
    /// seed : int
    ///     The random seed used in the resampling process
    ///
    /// Returns
    /// -------
    /// Dataset
    ///     A bootstrapped Dataset
    ///
    /// Examples
    /// --------
    /// >>> replica = dataset.bootstrap(2024)  # doctest: +SKIP
    /// >>> len(replica) == len(dataset)  # doctest: +SKIP
    /// True
    ///
    fn bootstrap(&self, seed: usize) -> PyDataset {
        PyDataset(self.0.bootstrap(seed))
    }
    /// Boost all the four-momenta in all events to the rest frame of the given set of
    /// named four-momenta.
    ///
    /// Parameters
    /// ----------
    /// names : list of str
    ///     The names of the four-momenta defining the rest frame
    ///
    /// Returns
    /// -------
    /// Dataset
    ///     The boosted dataset
    ///
    /// Examples
    /// --------
    /// >>> dataset.boost_to_rest_frame_of(['kshort1', 'kshort2'])  # doctest: +SKIP
    ///
    pub fn boost_to_rest_frame_of(&self, names: Vec<String>) -> PyDataset {
        PyDataset(self.0.boost_to_rest_frame_of(&names))
    }
    /// Get the value of a Variable over every event in the Dataset.
    ///
    /// Parameters
    /// ----------
    /// variable : {laddu.Mass, laddu.CosTheta, laddu.Phi, laddu.PolAngle, laddu.PolMagnitude, laddu.Mandelstam}
    ///
    /// Returns
    /// -------
    /// values : array_like
    ///
    /// Examples
    /// --------
    /// >>> from laddu.utils.variables import Mass  # doctest: +SKIP
    /// >>> masses = dataset.evaluate(Mass(['kshort1']))  # doctest: +SKIP
    /// >>> masses.shape  # doctest: +SKIP
    /// (len(dataset),)
    ///
    fn evaluate<'py>(
        &self,
        py: Python<'py>,
        variable: Bound<'py, PyAny>,
    ) -> PyResult<Bound<'py, PyArray1<f64>>> {
        let variable = variable.extract::<PyVariable>()?;
        // Bind name references against this dataset's metadata, then evaluate
        // over all events and hand the result back as a numpy array.
        let bound_variable = variable.bound(self.0.metadata())?;
        let values = self.0.evaluate(&bound_variable).map_err(PyErr::from)?;
        Ok(PyArray1::from_vec(py, values))
    }
}
834
/// A collection of Datasets binned by a Variable
///
/// BinnedDatasets can be indexed directly to access the underlying Datasets by bin
///
/// See Also
/// --------
/// laddu.Dataset.bin_by
///
#[pyclass(name = "BinnedDataset", module = "laddu")]
// Wrapper over the core BinnedDataset produced by `Dataset::bin_by`.
pub struct PyBinnedDataset(BinnedDataset);
845
846#[pymethods]
847impl PyBinnedDataset {
848    fn __len__(&self) -> usize {
849        self.0.n_bins()
850    }
851    /// The number of bins in the BinnedDataset
852    ///
853    #[getter]
854    fn n_bins(&self) -> usize {
855        self.0.n_bins()
856    }
857    /// The minimum and maximum values of the binning Variable used to create this BinnedDataset
858    ///
859    #[getter]
860    fn range(&self) -> (f64, f64) {
861        self.0.range()
862    }
863    /// The edges of each bin in the BinnedDataset
864    ///
865    #[getter]
866    fn edges<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<f64>> {
867        PyArray1::from_slice(py, &self.0.edges())
868    }
869    fn __getitem__(&self, index: usize) -> PyResult<PyDataset> {
870        self.0
871            .get(index)
872            .ok_or(PyIndexError::new_err("index out of range"))
873            .map(|rust_dataset| PyDataset(rust_dataset.clone()))
874    }
875}