//! Python bindings for `laddu` dataset types (`laddu_python/data.rs`):
//! events, datasets, binned datasets, and Parquet/ROOT I/O helpers.
1use crate::utils::variables::{PyVariable, PyVariableExpression};
2use laddu_core::{
3    data::{
4        read_parquet as core_read_parquet, read_root as core_read_root,
5        write_parquet as core_write_parquet, write_root as core_write_root, BinnedDataset, Dataset,
6        DatasetMetadata, DatasetWriteOptions, Event, EventData, FloatPrecision,
7    },
8    utils::variables::IntoP4Selection,
9    DatasetReadOptions,
10};
11use numpy::PyArray1;
12use pyo3::{
13    exceptions::{PyIndexError, PyKeyError, PyTypeError, PyValueError},
14    prelude::*,
15    types::PyDict,
16    IntoPyObjectExt,
17};
18use std::{path::PathBuf, sync::Arc};
19
20use crate::utils::vectors::PyVec4;
21
22fn parse_aliases(aliases: Option<Bound<'_, PyDict>>) -> PyResult<Vec<(String, Vec<String>)>> {
23    let Some(aliases) = aliases else {
24        return Ok(Vec::new());
25    };
26
27    let mut parsed = Vec::new();
28    for (key, value) in aliases.iter() {
29        let alias_name = key.extract::<String>()?;
30        let selection = if let Ok(single) = value.extract::<String>() {
31            vec![single]
32        } else {
33            let seq = value.extract::<Vec<String>>().map_err(|_| {
34                PyTypeError::new_err("Alias values must be a string or a sequence of strings")
35            })?;
36            if seq.is_empty() {
37                return Err(PyValueError::new_err(format!(
38                    "Alias '{alias_name}' must reference at least one particle",
39                )));
40            }
41            seq
42        };
43        parsed.push((alias_name, selection));
44    }
45
46    Ok(parsed)
47}
48
49fn parse_dataset_path(path: Bound<'_, PyAny>) -> PyResult<String> {
50    if let Ok(s) = path.extract::<String>() {
51        Ok(s)
52    } else if let Ok(pathbuf) = path.extract::<PathBuf>() {
53        Ok(pathbuf.to_string_lossy().into_owned())
54    } else {
55        Err(PyTypeError::new_err("Expected str or Path"))
56    }
57}
58
59fn parse_precision_arg(value: Option<&str>) -> PyResult<FloatPrecision> {
60    match value.map(|v| v.to_ascii_lowercase()) {
61        None => Ok(FloatPrecision::F64),
62        Some(name) if name == "f64" || name == "float64" || name == "double" => {
63            Ok(FloatPrecision::F64)
64        }
65        Some(name) if name == "f32" || name == "float32" || name == "float" => {
66            Ok(FloatPrecision::F32)
67        }
68        Some(other) => Err(PyValueError::new_err(format!(
69            "Unsupported precision '{other}' (expected 'f64' or 'f32')"
70        ))),
71    }
72}
73
74/// A single event
75///
76/// Events are composed of a set of 4-momenta of particles in the overall
77/// center-of-momentum frame, optional auxiliary scalars (e.g. polarization magnitude or angle),
78/// and a weight.
79///
80/// Parameters
81/// ----------
82/// p4s : list of Vec4
83///     4-momenta of each particle in the event in the overall center-of-momentum frame
84/// aux: list of float
85///     Scalar auxiliary data associated with the event
86/// weight : float
87///     The weight associated with this event
88/// p4_names : list of str, optional
89///     Human-readable aliases for each four-momentum. Providing names enables name-based
90///     lookups when evaluating variables.
91/// aux_names : list of str, optional
92///     Aliases for auxiliary scalars corresponding to ``aux``.
93/// aliases : dict of {str: str or list[str]}, optional
94///     Additional particle identifiers that reference one or more entries from ``p4_names``.
95///
96/// Examples
97/// --------
98/// >>> from laddu import Event, Vec3  # doctest: +SKIP
99/// >>> event = Event(  # doctest: +SKIP
100/// ...     [Vec3(0.0, 0.0, 1.0).with_mass(0.0), Vec3(0.0, 0.0, 1.0).with_mass(0.0)],
101/// ...     [],
102/// ...     1.0,
103/// ...     p4_names=['kshort1', 'kshort2'],
104/// ...     aliases={'pair': ['kshort1', 'kshort2']},
105/// ... )
106/// >>> event.p4('pair')  # doctest: +SKIP
107/// Vec4(px=0.0, py=0.0, pz=2.0, e=2.0)
108///
#[pyclass(name = "Event", module = "laddu")]
#[derive(Clone)]
pub struct PyEvent {
    // Underlying laddu-core event (4-momenta, auxiliary scalars, weight).
    pub event: Event,
    // True when the event was constructed with explicit names/aliases; gates
    // name-based operations (`p4s`, `aux`, `p4`, `evaluate`, index resolution).
    has_metadata: bool,
}
115
#[pymethods]
impl PyEvent {
    #[new]
    #[pyo3(signature = (p4s, aux, weight, *, p4_names=None, aux_names=None, aliases=None))]
    fn new(
        p4s: Vec<PyVec4>,
        aux: Vec<f64>,
        weight: f64,
        p4_names: Option<Vec<String>>,
        aux_names: Option<Vec<String>>,
        aliases: Option<Bound<PyDict>>,
    ) -> PyResult<Self> {
        // Raw kinematic payload; names/aliases are attached as metadata below.
        let event = EventData {
            p4s: p4s.into_iter().map(|arr| arr.0).collect(),
            aux,
            weight,
        };
        let aliases = parse_aliases(aliases)?;

        // `None` and an explicitly empty list both count as "no p4 names".
        let missing_p4_names = p4_names
            .as_ref()
            .map(|names| names.is_empty())
            .unwrap_or(true);

        // Aliases are resolved against `p4_names`, so they cannot stand alone.
        if !aliases.is_empty() && missing_p4_names {
            return Err(PyValueError::new_err(
                "`aliases` requires `p4_names` so selections can be resolved",
            ));
        }

        // Supplying any of the three inputs yields real metadata; otherwise the
        // event carries an empty placeholder and name-based lookups will error.
        let metadata_provided = p4_names.is_some() || aux_names.is_some() || !aliases.is_empty();
        let metadata = if metadata_provided {
            let p4_names = p4_names.unwrap_or_default();
            let aux_names = aux_names.unwrap_or_default();
            let mut metadata = DatasetMetadata::new(p4_names, aux_names).map_err(PyErr::from)?;
            if !aliases.is_empty() {
                metadata
                    .add_p4_aliases(
                        aliases.into_iter().map(|(alias_name, selection)| {
                            (alias_name, selection.into_selection())
                        }),
                    )
                    .map_err(PyErr::from)?;
            }
            Arc::new(metadata)
        } else {
            Arc::new(DatasetMetadata::empty())
        };
        let event = Event::new(Arc::new(event), metadata);
        Ok(Self {
            event,
            has_metadata: metadata_provided,
        })
    }
    // Delegates to the core event data's Display impl.
    fn __str__(&self) -> String {
        self.event.data().to_string()
    }
    /// The list of 4-momenta for each particle in the event
    ///
    /// Returned as a ``{name: Vec4}`` mapping; requires the event to carry
    /// metadata (raises ``ValueError`` otherwise).
    ///
    #[getter]
    fn p4s<'py>(&self, py: Python<'py>) -> PyResult<Py<PyDict>> {
        self.ensure_metadata()?;
        let mapping = PyDict::new(py);
        for (name, vec4) in self.event.p4s() {
            mapping.set_item(name, PyVec4(vec4))?;
        }
        Ok(mapping.into())
    }
    /// The auxiliary scalar values associated with the event
    ///
    /// Returned as a ``{name: float}`` mapping; requires the event to carry
    /// metadata (raises ``ValueError`` otherwise).
    ///
    #[getter]
    #[pyo3(name = "aux")]
    fn aux_mapping<'py>(&self, py: Python<'py>) -> PyResult<Py<PyDict>> {
        self.ensure_metadata()?;
        let mapping = PyDict::new(py);
        for (name, value) in self.event.aux() {
            mapping.set_item(name, value)?;
        }
        Ok(mapping.into())
    }
    /// The weight of this event relative to others in a Dataset
    ///
    #[getter]
    fn get_weight(&self) -> f64 {
        self.event.weight()
    }
    /// Get the sum of the four-momenta within the event at the given indices
    ///
    /// Parameters
    /// ----------
    /// names : list of str
    ///     The names of the four-momenta to sum
    ///
    /// Returns
    /// -------
    /// Vec4
    ///     The result of summing the given four-momenta
    ///
    fn get_p4_sum(&self, names: Vec<String>) -> PyResult<PyVec4> {
        // Names (including aliases) are expanded to raw p4 indices first.
        let indices = self.resolve_p4_indices(&names)?;
        Ok(PyVec4(self.event.data().get_p4_sum(indices)))
    }
    /// Boost all the four-momenta in the event to the rest frame of the given set of
    /// four-momenta by indices.
    ///
    /// Parameters
    /// ----------
    /// names : list of str
    ///     The names of the four-momenta whose rest frame should be used for the boost
    ///
    /// Returns
    /// -------
    /// Event
    ///     The boosted event
    ///
    pub fn boost_to_rest_frame_of(&self, names: Vec<String>) -> PyResult<Self> {
        let indices = self.resolve_p4_indices(&names)?;
        let boosted = self.event.data().boost_to_rest_frame_of(indices);
        // The boosted event shares the original's metadata handle.
        Ok(Self {
            event: Event::new(Arc::new(boosted), self.event.metadata_arc()),
            has_metadata: self.has_metadata,
        })
    }
    /// Get the value of a Variable on the given Event
    ///
    /// Parameters
    /// ----------
    /// variable : {laddu.Mass, laddu.CosTheta, laddu.Phi, laddu.PolAngle, laddu.PolMagnitude, laddu.Mandelstam}
    ///
    /// Returns
    /// -------
    /// float
    ///
    /// Notes
    /// -----
    /// Variables that rely on particle names require the event to carry metadata. Provide
    /// ``p4_names``/``aux_names`` when constructing the event or evaluate variables through a
    /// ``laddu.Dataset`` to ensure the metadata is available.
    ///
    fn evaluate(&self, variable: Bound<'_, PyAny>) -> PyResult<f64> {
        // Extraction happens first, so a bad `variable` raises TypeError even
        // for metadata-free events.
        let mut variable = variable.extract::<PyVariable>()?;
        if !self.has_metadata {
            return Err(PyValueError::new_err(
                "Cannot evaluate variable on an Event without associated metadata. Construct the Event with `p4_names`/`aux_names` or evaluate through a Dataset.",
            ));
        }
        // Resolve the variable's particle names against this event's metadata.
        variable.bind_in_place(self.event.metadata())?;
        let event_arc = self.event.data_arc();
        variable.evaluate_event(&event_arc)
    }

    /// Retrieve a four-momentum by name (if present).
    ///
    /// Returns ``None`` for an unknown name; raises ``ValueError`` if the
    /// event has no metadata.
    fn p4(&self, name: &str) -> PyResult<Option<PyVec4>> {
        self.ensure_metadata()?;
        Ok(self.event.p4(name).map(PyVec4))
    }
}
273
274impl PyEvent {
275    fn ensure_metadata(&self) -> PyResult<&DatasetMetadata> {
276        if !self.has_metadata {
277            Err(PyValueError::new_err(
278                "Event has no associated metadata for name-based operations",
279            ))
280        } else {
281            Ok(self.event.metadata())
282        }
283    }
284
285    fn resolve_p4_indices(&self, names: &[String]) -> PyResult<Vec<usize>> {
286        let metadata = self.ensure_metadata()?;
287        let mut resolved = Vec::new();
288        for name in names {
289            let selection = metadata
290                .p4_selection(name)
291                .ok_or_else(|| PyKeyError::new_err(format!("Unknown particle name '{name}'")))?;
292            resolved.extend_from_slice(selection.indices());
293        }
294        Ok(resolved)
295    }
296
297    pub(crate) fn metadata_opt(&self) -> Option<&DatasetMetadata> {
298        self.has_metadata.then(|| self.event.metadata())
299    }
300}
301
302/// A set of Events
303///
304/// Datasets can be created from lists of Events or by using the constructor helpers
305/// such as :func:`laddu.io.read_parquet`, :func:`laddu.io.read_root`, and
306/// :func:`laddu.io.read_amptools`
307///
308/// Datasets can also be indexed directly to access individual Events
309///
310/// Parameters
311/// ----------
312/// events : list of Event
313/// p4_names : list of str, optional
314///     Names assigned to each four-momentum; enables name-based lookups if provided.
315/// aux_names : list of str, optional
316///     Names for auxiliary scalars stored alongside the events.
317/// aliases : dict of {str: str or list[str]}, optional
318///     Additional particle identifiers that override aliases stored on the Events.
319///
320/// Notes
321/// -----
322/// Explicit metadata provided here takes precedence over metadata embedded in the
323/// input Events.
324///
#[pyclass(name = "Dataset", module = "laddu", subclass)]
#[derive(Clone)]
// Shared handle to the core dataset; `Clone` only bumps the `Arc` refcount.
pub struct PyDataset(pub Arc<Dataset>);
328
#[pyclass(name = "DatasetIter", module = "laddu")]
// Python-side iterator over a Dataset's events.
struct PyDatasetIter {
    dataset: Arc<Dataset>, // shared handle to the dataset being iterated
    index: usize,          // next event position to yield
    total: usize,          // event count captured when the iterator was created
}
335
336#[pymethods]
337impl PyDatasetIter {
338    fn __iter__(slf: PyRef<'_, Self>) -> Py<PyDatasetIter> {
339        slf.into()
340    }
341
342    fn __next__(&mut self) -> Option<PyEvent> {
343        if self.index >= self.total {
344            return None;
345        }
346        let event = self.dataset[self.index].clone();
347        self.index += 1;
348        Some(PyEvent {
349            event,
350            has_metadata: true,
351        })
352    }
353}
354
#[pymethods]
impl PyDataset {
    #[new]
    #[pyo3(signature = (events, *, p4_names=None, aux_names=None, aliases=None))]
    fn new(
        events: Vec<PyEvent>,
        p4_names: Option<Vec<String>>,
        aux_names: Option<Vec<String>>,
        aliases: Option<Bound<PyDict>>,
    ) -> PyResult<Self> {
        // Metadata inherited from the first event that carries any; serves as
        // a fallback when explicit names are not supplied.
        let inferred_metadata = events
            .iter()
            .find_map(|event| event.has_metadata.then(|| event.event.metadata_arc()));

        let aliases = parse_aliases(aliases)?;
        // Any explicit argument means we build fresh metadata, overriding what
        // the events carry (explicit metadata takes precedence — see class docs).
        let use_explicit_metadata =
            p4_names.is_some() || aux_names.is_some() || !aliases.is_empty();

        let metadata =
            if use_explicit_metadata {
                // Explicit names win; otherwise fall back to inferred, then empty.
                let resolved_p4_names = match (p4_names, inferred_metadata.as_ref()) {
                    (Some(names), _) => names,
                    (None, Some(metadata)) => metadata.p4_names().to_vec(),
                    (None, None) => Vec::new(),
                };
                let resolved_aux_names = match (aux_names, inferred_metadata.as_ref()) {
                    (Some(names), _) => names,
                    (None, Some(metadata)) => metadata.aux_names().to_vec(),
                    (None, None) => Vec::new(),
                };

                // Aliases need particle names to resolve against.
                if !aliases.is_empty() && resolved_p4_names.is_empty() {
                    return Err(PyValueError::new_err(
                        "`aliases` requires `p4_names` or events with metadata for resolution",
                    ));
                }

                let mut metadata = DatasetMetadata::new(resolved_p4_names, resolved_aux_names)
                    .map_err(PyErr::from)?;
                if !aliases.is_empty() {
                    metadata
                        .add_p4_aliases(aliases.into_iter().map(|(alias_name, selection)| {
                            (alias_name, selection.into_selection())
                        }))
                        .map_err(PyErr::from)?;
                }
                Some(Arc::new(metadata))
            } else {
                inferred_metadata
            };

        // Strip each event down to its raw data; the dataset attaches the
        // single resolved metadata to all of them.
        let events: Vec<Arc<EventData>> = events
            .into_iter()
            .map(|event| event.event.data_arc())
            .collect();
        let dataset = if let Some(metadata) = metadata {
            Dataset::new_with_metadata(events, metadata)
        } else {
            Dataset::new(events)
        };
        Ok(Self(Arc::new(dataset)))
    }

    fn __len__(&self) -> usize {
        self.0.n_events()
    }
    // Snapshot the current length so the iterator is stable even if callers
    // hold it across other operations.
    fn __iter__(&self) -> PyDatasetIter {
        PyDatasetIter {
            dataset: self.0.clone(),
            index: 0,
            total: self.0.n_events(),
        }
    }
    // `dataset + dataset` concatenates; `dataset + 0` is identity so that
    // Python's `sum(...)` (which starts from 0) works on lists of Datasets.
    fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyDataset> {
        if let Ok(other_ds) = other.extract::<PyRef<PyDataset>>() {
            Ok(PyDataset(Arc::new(self.0.as_ref() + other_ds.0.as_ref())))
        } else if let Ok(other_int) = other.extract::<usize>() {
            if other_int == 0 {
                Ok(self.clone())
            } else {
                Err(PyTypeError::new_err(
                    "Addition with an integer for this type is only defined for 0",
                ))
            }
        } else {
            Err(PyTypeError::new_err("Unsupported operand type for +"))
        }
    }
    // Mirror of `__add__` with the operands swapped (note the reversed order
    // in the concatenation, preserving left-to-right event order).
    fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyDataset> {
        if let Ok(other_ds) = other.extract::<PyRef<PyDataset>>() {
            Ok(PyDataset(Arc::new(other_ds.0.as_ref() + self.0.as_ref())))
        } else if let Ok(other_int) = other.extract::<usize>() {
            if other_int == 0 {
                Ok(self.clone())
            } else {
                Err(PyTypeError::new_err(
                    "Addition with an integer for this type is only defined for 0",
                ))
            }
        } else {
            Err(PyTypeError::new_err("Unsupported operand type for +"))
        }
    }
    /// Get the number of Events in the Dataset
    ///
    /// Returns
    /// -------
    /// n_events : int
    ///     The number of Events
    ///
    #[getter]
    fn n_events(&self) -> usize {
        self.0.n_events()
    }
    /// Particle names used to construct four-momenta when loading from a Parquet file.
    #[getter]
    fn p4_names(&self) -> Vec<String> {
        self.0.p4_names().to_vec()
    }
    /// Auxiliary scalar names associated with this Dataset.
    #[getter]
    fn aux_names(&self) -> Vec<String> {
        self.0.aux_names().to_vec()
    }

    /// Get the weighted number of Events in the Dataset
    ///
    /// Returns
    /// -------
    /// n_events : float
    ///     The sum of all Event weights
    ///
    #[getter]
    fn n_events_weighted(&self) -> f64 {
        self.0.n_events_weighted()
    }
    /// The weights associated with the Dataset
    ///
    /// Returns
    /// -------
    /// weights : array_like
    ///     A ``numpy`` array of Event weights
    ///
    #[getter]
    fn weights<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<f64>> {
        PyArray1::from_slice(py, &self.0.weights())
    }
    /// The internal list of Events stored in the Dataset
    ///
    /// Notes
    /// -----
    /// When MPI is enabled, this returns only the events local to the current rank.
    /// Use Python iteration (`for event in dataset`, `list(dataset)`, etc.) to
    /// traverse all events across ranks.
    ///
    /// Returns
    /// -------
    /// events : list of Event
    ///     The Events in the Dataset
    ///
    #[getter]
    fn events(&self) -> Vec<PyEvent> {
        self.0
            .events
            .iter()
            .map(|rust_event| PyEvent {
                event: rust_event.clone(),
                has_metadata: true,
            })
            .collect()
    }
    /// Retrieve a four-momentum by particle name for the event at ``index``.
    ///
    /// Raises ``KeyError`` for an unknown name.
    fn p4_by_name(&self, index: usize, name: &str) -> PyResult<PyVec4> {
        self.0
            .p4_by_name(index, name)
            .map(PyVec4)
            .ok_or_else(|| PyKeyError::new_err(format!("Unknown particle name '{name}'")))
    }
    /// Retrieve an auxiliary scalar by name for the event at ``index``.
    ///
    /// Raises ``KeyError`` for an unknown name.
    fn aux_by_name(&self, index: usize, name: &str) -> PyResult<f64> {
        self.0
            .aux_by_name(index, name)
            .ok_or_else(|| PyKeyError::new_err(format!("Unknown auxiliary name '{name}'")))
    }
    // Dual-purpose indexing: `dataset[variable]` evaluates the Variable over
    // all events; `dataset[i]` returns the i-th Event.
    fn __getitem__<'py>(
        &self,
        py: Python<'py>,
        index: Bound<'py, PyAny>,
    ) -> PyResult<Bound<'py, PyAny>> {
        // Try Variable evaluation first; an int index fails this and falls through.
        if let Ok(value) = self.evaluate(py, index.clone()) {
            value.into_bound_py_any(py)
        } else if let Ok(index) = index.extract::<usize>() {
            // NOTE(review): `self.0[index]` panics for an out-of-range index
            // (surfacing as a PanicException, not IndexError), and negative
            // Python indices are rejected by the usize extraction — confirm
            // whether bounds-checked/negative indexing is intended here.
            PyEvent {
                event: self.0[index].clone(),
                has_metadata: true,
            }
            .into_bound_py_any(py)
        } else {
            Err(PyTypeError::new_err(
                "Unsupported index type (int or Variable)",
            ))
        }
    }
    /// Separates a Dataset into histogram bins by a Variable value
    ///
    /// Parameters
    /// ----------
    /// variable : {laddu.Mass, laddu.CosTheta, laddu.Phi, laddu.PolAngle, laddu.PolMagnitude, laddu.Mandelstam}
    ///     The Variable by which each Event is binned
    /// bins : int
    ///     The number of equally-spaced bins
    /// range : tuple[float, float]
    ///     The minimum and maximum bin edges
    ///
    /// Returns
    /// -------
    /// datasets : BinnedDataset
    ///     A pub structure that holds a list of Datasets binned by the given `variable`
    ///
    /// See Also
    /// --------
    /// laddu.Mass
    /// laddu.CosTheta
    /// laddu.Phi
    /// laddu.PolAngle
    /// laddu.PolMagnitude
    /// laddu.Mandelstam
    ///
    /// Examples
    /// --------
    /// >>> from laddu.utils.variables import Mass  # doctest: +SKIP
    /// >>> binned = dataset.bin_by(Mass(['kshort1']), bins=10, range=(0.9, 1.5))  # doctest: +SKIP
    /// >>> len(binned)  # doctest: +SKIP
    /// 10
    ///
    /// Raises
    /// ------
    /// TypeError
    ///     If the given `variable` is not a valid variable
    ///
    #[pyo3(signature = (variable, bins, range))]
    fn bin_by(
        &self,
        variable: Bound<'_, PyAny>,
        bins: usize,
        range: (f64, f64),
    ) -> PyResult<PyBinnedDataset> {
        let py_variable = variable.extract::<PyVariable>()?;
        // Bind the variable's particle names against this dataset's metadata
        // before binning.
        let bound_variable = py_variable.bound(self.0.metadata())?;
        Ok(PyBinnedDataset(self.0.bin_by(
            bound_variable,
            bins,
            range,
        )?))
    }
    /// Filter the Dataset by a given VariableExpression, selecting events for which the expression returns ``True``.
    ///
    /// Parameters
    /// ----------
    /// expression : VariableExpression
    ///     The expression with which to filter the Dataset
    ///
    /// Returns
    /// -------
    /// Dataset
    ///     The filtered Dataset
    ///
    /// Examples
    /// --------
    /// >>> from laddu.utils.variables import Mass  # doctest: +SKIP
    /// >>> heavy = dataset.filter(Mass(['kshort1']) > 1.0)  # doctest: +SKIP
    ///
    pub fn filter(&self, expression: &PyVariableExpression) -> PyResult<PyDataset> {
        Ok(PyDataset(
            self.0.filter(&expression.0).map_err(PyErr::from)?,
        ))
    }
    /// Generate a new bootstrapped Dataset by randomly resampling the original with replacement
    ///
    /// The new Dataset is resampled with a random generator seeded by the provided `seed`
    ///
    /// Parameters
    /// ----------
    /// seed : int
    ///     The random seed used in the resampling process
    ///
    /// Returns
    /// -------
    /// Dataset
    ///     A bootstrapped Dataset
    ///
    /// Examples
    /// --------
    /// >>> replica = dataset.bootstrap(2024)  # doctest: +SKIP
    /// >>> len(replica) == len(dataset)  # doctest: +SKIP
    /// True
    ///
    fn bootstrap(&self, seed: usize) -> PyDataset {
        PyDataset(self.0.bootstrap(seed))
    }
    /// Boost all the four-momenta in all events to the rest frame of the given set of
    /// named four-momenta.
    ///
    /// Parameters
    /// ----------
    /// names : list of str
    ///     The names of the four-momenta defining the rest frame
    ///
    /// Returns
    /// -------
    /// Dataset
    ///     The boosted dataset
    ///
    /// Examples
    /// --------
    /// >>> dataset.boost_to_rest_frame_of(['kshort1', 'kshort2'])  # doctest: +SKIP
    ///
    pub fn boost_to_rest_frame_of(&self, names: Vec<String>) -> PyDataset {
        PyDataset(self.0.boost_to_rest_frame_of(&names))
    }
    /// Get the value of a Variable over every event in the Dataset.
    ///
    /// Parameters
    /// ----------
    /// variable : {laddu.Mass, laddu.CosTheta, laddu.Phi, laddu.PolAngle, laddu.PolMagnitude, laddu.Mandelstam}
    ///
    /// Returns
    /// -------
    /// values : array_like
    ///
    /// Examples
    /// --------
    /// >>> from laddu.utils.variables import Mass  # doctest: +SKIP
    /// >>> masses = dataset.evaluate(Mass(['kshort1']))  # doctest: +SKIP
    /// >>> masses.shape  # doctest: +SKIP
    /// (len(dataset),)
    ///
    fn evaluate<'py>(
        &self,
        py: Python<'py>,
        variable: Bound<'py, PyAny>,
    ) -> PyResult<Bound<'py, PyArray1<f64>>> {
        let variable = variable.extract::<PyVariable>()?;
        // Resolve names against the dataset's own metadata, then evaluate.
        let bound_variable = variable.bound(self.0.metadata())?;
        let values = self.0.evaluate(&bound_variable).map_err(PyErr::from)?;
        Ok(PyArray1::from_vec(py, values))
    }
}
703
704/// Read a Dataset from a Parquet file.
705#[pyfunction]
706#[pyo3(signature = (path, *, p4s=None, aux=None, aliases=None))]
707pub fn read_parquet(
708    path: Bound<PyAny>,
709    p4s: Option<Vec<String>>,
710    aux: Option<Vec<String>>,
711    aliases: Option<Bound<PyDict>>,
712) -> PyResult<PyDataset> {
713    let path_str = parse_dataset_path(path)?;
714    let mut read_options = DatasetReadOptions::default();
715    if let Some(p4s) = p4s {
716        read_options = read_options.p4_names(p4s);
717    }
718    if let Some(aux) = aux {
719        read_options = read_options.aux_names(aux);
720    }
721    for (alias_name, selection) in parse_aliases(aliases)?.into_iter() {
722        read_options = read_options.alias(alias_name, selection);
723    }
724    let dataset = core_read_parquet(&path_str, &read_options)?;
725    Ok(PyDataset(dataset))
726}
727
728/// Read a Dataset from a ROOT file using the oxyroot backend.
729#[pyfunction]
730#[pyo3(signature = (path, *, tree=None, p4s=None, aux=None, aliases=None))]
731pub fn read_root(
732    path: Bound<PyAny>,
733    tree: Option<String>,
734    p4s: Option<Vec<String>>,
735    aux: Option<Vec<String>>,
736    aliases: Option<Bound<PyDict>>,
737) -> PyResult<PyDataset> {
738    let path_str = parse_dataset_path(path)?;
739    let mut read_options = DatasetReadOptions::default();
740    if let Some(p4s) = p4s {
741        read_options = read_options.p4_names(p4s);
742    }
743    if let Some(aux) = aux {
744        read_options = read_options.aux_names(aux);
745    }
746    if let Some(tree) = tree {
747        read_options = read_options.tree(tree);
748    }
749    for (alias_name, selection) in parse_aliases(aliases)?.into_iter() {
750        read_options = read_options.alias(alias_name, selection);
751    }
752    let dataset = core_read_root(&path_str, &read_options)?;
753    Ok(PyDataset(dataset))
754}
755
756/// Write a Dataset to a Parquet file.
757#[pyfunction]
758#[pyo3(signature = (dataset, path, *, chunk_size=None, precision="f64"))]
759pub fn write_parquet(
760    dataset: &PyDataset,
761    path: Bound<PyAny>,
762    chunk_size: Option<usize>,
763    precision: &str,
764) -> PyResult<()> {
765    let path_str = parse_dataset_path(path)?;
766    let mut write_options = DatasetWriteOptions::default();
767    if let Some(size) = chunk_size {
768        write_options.batch_size = size.max(1);
769    }
770    write_options.precision = parse_precision_arg(Some(precision))?;
771    core_write_parquet(dataset.0.as_ref(), &path_str, &write_options).map_err(PyErr::from)
772}
773
774/// Write a Dataset to a ROOT file using the oxyroot backend.
775#[pyfunction]
776#[pyo3(signature = (dataset, path, *, tree=None, chunk_size=None, precision="f64"))]
777pub fn write_root(
778    dataset: &PyDataset,
779    path: Bound<PyAny>,
780    tree: Option<String>,
781    chunk_size: Option<usize>,
782    precision: &str,
783) -> PyResult<()> {
784    let path_str = parse_dataset_path(path)?;
785    let mut write_options = DatasetWriteOptions::default();
786    if let Some(name) = tree {
787        write_options.tree = Some(name);
788    }
789    if let Some(size) = chunk_size {
790        write_options.batch_size = size.max(1);
791    }
792    write_options.precision = parse_precision_arg(Some(precision))?;
793    core_write_root(dataset.0.as_ref(), &path_str, &write_options).map_err(PyErr::from)
794}
795
796/// A collection of Datasets binned by a Variable
797///
798/// BinnedDatasets can be indexed directly to access the underlying Datasets by bin
799///
800/// See Also
801/// --------
802/// laddu.Dataset.bin_by
803///
#[pyclass(name = "BinnedDataset", module = "laddu")]
// Owns the core binned-dataset collection produced by `Dataset.bin_by`.
pub struct PyBinnedDataset(BinnedDataset);
806
807#[pymethods]
808impl PyBinnedDataset {
809    fn __len__(&self) -> usize {
810        self.0.n_bins()
811    }
812    /// The number of bins in the BinnedDataset
813    ///
814    #[getter]
815    fn n_bins(&self) -> usize {
816        self.0.n_bins()
817    }
818    /// The minimum and maximum values of the binning Variable used to create this BinnedDataset
819    ///
820    #[getter]
821    fn range(&self) -> (f64, f64) {
822        self.0.range()
823    }
824    /// The edges of each bin in the BinnedDataset
825    ///
826    #[getter]
827    fn edges<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<f64>> {
828        PyArray1::from_slice(py, &self.0.edges())
829    }
830    fn __getitem__(&self, index: usize) -> PyResult<PyDataset> {
831        self.0
832            .get(index)
833            .ok_or(PyIndexError::new_err("index out of range"))
834            .map(|rust_dataset| PyDataset(rust_dataset.clone()))
835    }
836}