Skip to main content

laddu_python/
data.rs

1use crate::utils::variables::{PyVariable, PyVariableExpression};
2use laddu_core::{
3    data::{
4        io::{
5            infer_p4_and_aux_names_from_columns, resolve_columns_case_insensitive,
6            resolve_optional_weight_column, resolve_p4_component_columns, P4_COMPONENT_SUFFIXES,
7        },
8        read_parquet as core_read_parquet,
9        read_parquet_chunks_with_options as core_read_parquet_chunks_with_options,
10        read_root as core_read_root, write_parquet as core_write_parquet,
11        write_root as core_write_root, BinnedDataset, Dataset, DatasetArcIter, DatasetMetadata,
12        DatasetWriteOptions, Event, EventData, FloatPrecision, SharedDatasetIterExt,
13    },
14    utils::variables::IntoP4Selection,
15    DatasetReadOptions,
16};
17use numpy::{PyArray1, PyReadonlyArray1};
18use pyo3::{
19    exceptions::{PyIndexError, PyKeyError, PyTypeError, PyValueError},
20    prelude::*,
21    types::{PyDict, PyList},
22    IntoPyObjectExt,
23};
24use std::{path::PathBuf, sync::Arc};
25
26use crate::utils::vectors::PyVec4;
27
28fn parse_aliases(aliases: Option<Bound<'_, PyDict>>) -> PyResult<Vec<(String, Vec<String>)>> {
29    let Some(aliases) = aliases else {
30        return Ok(Vec::new());
31    };
32
33    let mut parsed = Vec::new();
34    for (key, value) in aliases.iter() {
35        let alias_name = key.extract::<String>()?;
36        let selection = if let Ok(single) = value.extract::<String>() {
37            vec![single]
38        } else {
39            let seq = value.extract::<Vec<String>>().map_err(|_| {
40                PyTypeError::new_err("Alias values must be a string or a sequence of strings")
41            })?;
42            if seq.is_empty() {
43                return Err(PyValueError::new_err(format!(
44                    "Alias '{alias_name}' must reference at least one particle",
45                )));
46            }
47            seq
48        };
49        parsed.push((alias_name, selection));
50    }
51
52    Ok(parsed)
53}
54
55fn parse_dataset_path(path: Bound<'_, PyAny>) -> PyResult<String> {
56    if let Ok(s) = path.extract::<String>() {
57        Ok(s)
58    } else if let Ok(pathbuf) = path.extract::<PathBuf>() {
59        Ok(pathbuf.to_string_lossy().into_owned())
60    } else {
61        Err(PyTypeError::new_err("Expected str or Path"))
62    }
63}
64
65fn parse_precision_arg(value: Option<&str>) -> PyResult<FloatPrecision> {
66    match value.map(|v| v.to_ascii_lowercase()) {
67        None => Ok(FloatPrecision::F64),
68        Some(name) if name == "f64" || name == "float64" || name == "double" => {
69            Ok(FloatPrecision::F64)
70        }
71        Some(name) if name == "f32" || name == "float32" || name == "float" => {
72            Ok(FloatPrecision::F32)
73        }
74        Some(other) => Err(PyValueError::new_err(format!(
75            "Unsupported precision '{other}' (expected 'f64' or 'f32')"
76        ))),
77    }
78}
79
80fn extract_numeric_column(value: Bound<'_, PyAny>, name: &str) -> PyResult<Vec<f64>> {
81    if let Ok(array) = value.extract::<PyReadonlyArray1<'_, f64>>() {
82        return Ok(array.as_slice()?.to_vec());
83    }
84    if let Ok(array) = value.extract::<PyReadonlyArray1<'_, f32>>() {
85        return Ok(array.as_slice()?.iter().map(|v| *v as f64).collect());
86    }
87    if let Ok(values) = value.extract::<Vec<f64>>() {
88        return Ok(values);
89    }
90    if let Ok(values) = value.extract::<Vec<f32>>() {
91        return Ok(values.into_iter().map(|v| v as f64).collect());
92    }
93    if let Ok(list) = value.cast::<PyList>() {
94        let mut converted = Vec::with_capacity(list.len());
95        for item in list.iter() {
96            converted.push(item.extract::<f64>().map_err(|_| {
97                PyTypeError::new_err(format!(
98                    "Column '{name}' must be numeric (float32/float64/list of floats)"
99                ))
100            })?);
101        }
102        return Ok(converted);
103    }
104    Err(PyTypeError::new_err(format!(
105        "Column '{name}' must be numeric (float32/float64/list of floats)"
106    )))
107}
108
109/// A single event
110///
111/// Events are composed of a set of 4-momenta of particles in the overall
112/// center-of-momentum frame, optional auxiliary scalars (e.g. polarization magnitude or angle),
113/// and a weight.
114///
115/// Parameters
116/// ----------
117/// p4s : list of Vec4
118///     4-momenta of each particle in the event in the overall center-of-momentum frame
119/// aux: list of float
120///     Scalar auxiliary data associated with the event
121/// weight : float
122///     The weight associated with this event
123/// p4_names : list of str, optional
124///     Human-readable aliases for each four-momentum. Providing names enables name-based
125///     lookups when evaluating variables.
126/// aux_names : list of str, optional
127///     Aliases for auxiliary scalars corresponding to ``aux``.
128/// aliases : dict of {str: str or list[str]}, optional
129///     Additional particle identifiers that reference one or more entries from ``p4_names``.
130///
131/// Examples
132/// --------
133/// >>> from laddu import Event, Vec3  # doctest: +SKIP
134/// >>> event = Event(  # doctest: +SKIP
135/// ...     [Vec3(0.0, 0.0, 1.0).with_mass(0.0), Vec3(0.0, 0.0, 1.0).with_mass(0.0)],
136/// ...     [],
137/// ...     1.0,
138/// ...     p4_names=['kshort1', 'kshort2'],
139/// ...     aliases={'pair': ['kshort1', 'kshort2']},
140/// ... )
141/// >>> event.p4('pair')  # doctest: +SKIP
142/// Vec4(px=0.0, py=0.0, pz=2.0, e=2.0)
143/// >>> event.aux['pol_angle']  # doctest: +SKIP
144/// 0.3
145///
146#[pyclass(name = "Event", module = "laddu", from_py_object)]
147#[derive(Clone)]
148pub struct PyEvent {
149    pub event: Event,
150    has_metadata: bool,
151}
152
153#[pymethods]
154impl PyEvent {
155    #[new]
156    #[pyo3(signature = (p4s, aux, weight, *, p4_names=None, aux_names=None, aliases=None))]
157    fn new(
158        p4s: Vec<PyVec4>,
159        aux: Vec<f64>,
160        weight: f64,
161        p4_names: Option<Vec<String>>,
162        aux_names: Option<Vec<String>>,
163        aliases: Option<Bound<PyDict>>,
164    ) -> PyResult<Self> {
165        let event = EventData {
166            p4s: p4s.into_iter().map(|arr| arr.0).collect(),
167            aux,
168            weight,
169        };
170        let aliases = parse_aliases(aliases)?;
171
172        let missing_p4_names = p4_names
173            .as_ref()
174            .map(|names| names.is_empty())
175            .unwrap_or(true);
176
177        if !aliases.is_empty() && missing_p4_names {
178            return Err(PyValueError::new_err(
179                "`aliases` requires `p4_names` so selections can be resolved",
180            ));
181        }
182
183        let metadata_provided = p4_names.is_some() || aux_names.is_some() || !aliases.is_empty();
184        let metadata = if metadata_provided {
185            let p4_names = p4_names.unwrap_or_default();
186            let aux_names = aux_names.unwrap_or_default();
187            let mut metadata = DatasetMetadata::new(p4_names, aux_names).map_err(PyErr::from)?;
188            if !aliases.is_empty() {
189                metadata
190                    .add_p4_aliases(
191                        aliases.into_iter().map(|(alias_name, selection)| {
192                            (alias_name, selection.into_selection())
193                        }),
194                    )
195                    .map_err(PyErr::from)?;
196            }
197            Arc::new(metadata)
198        } else {
199            Arc::new(DatasetMetadata::empty())
200        };
201        let event = Event::new(Arc::new(event), metadata);
202        Ok(Self {
203            event,
204            has_metadata: metadata_provided,
205        })
206    }
207    fn __str__(&self) -> String {
208        self.event.data().to_string()
209    }
210    /// The list of 4-momenta for each particle in the event
211    ///
212    #[getter]
213    fn p4s<'py>(&self, py: Python<'py>) -> PyResult<Py<PyDict>> {
214        self.ensure_metadata()?;
215        let mapping = PyDict::new(py);
216        for (name, vec4) in self.event.p4s() {
217            mapping.set_item(name, PyVec4(vec4))?;
218        }
219        Ok(mapping.into())
220    }
221    /// The auxiliary scalar values associated with the event
222    ///
223    #[getter]
224    #[pyo3(name = "aux")]
225    fn aux_mapping<'py>(&self, py: Python<'py>) -> PyResult<Py<PyDict>> {
226        self.ensure_metadata()?;
227        let mapping = PyDict::new(py);
228        for (name, value) in self.event.aux() {
229            mapping.set_item(name, value)?;
230        }
231        Ok(mapping.into())
232    }
233    /// The weight of this event relative to others in a Dataset
234    ///
235    #[getter]
236    fn get_weight(&self) -> f64 {
237        self.event.weight()
238    }
239    /// Get the sum of the four-momenta within the event at the given indices
240    ///
241    /// Parameters
242    /// ----------
243    /// names : list of str
244    ///     The names of the four-momenta to sum
245    ///
246    /// Returns
247    /// -------
248    /// Vec4
249    ///     The result of summing the given four-momenta
250    ///
251    fn get_p4_sum(&self, names: Vec<String>) -> PyResult<PyVec4> {
252        let indices = self.resolve_p4_indices(&names)?;
253        Ok(PyVec4(self.event.data().get_p4_sum(indices)))
254    }
255    /// Boost all the four-momenta in the event to the rest frame of the given set of
256    /// four-momenta by indices.
257    ///
258    /// Parameters
259    /// ----------
260    /// names : list of str
261    ///     The names of the four-momenta whose rest frame should be used for the boost
262    ///
263    /// Returns
264    /// -------
265    /// Event
266    ///     The boosted event
267    ///
268    pub fn boost_to_rest_frame_of(&self, names: Vec<String>) -> PyResult<Self> {
269        let indices = self.resolve_p4_indices(&names)?;
270        let boosted = self.event.data().boost_to_rest_frame_of(indices);
271        Ok(Self {
272            event: Event::new(Arc::new(boosted), self.event.metadata_arc()),
273            has_metadata: self.has_metadata,
274        })
275    }
276    /// Get the value of a Variable on the given Event
277    ///
278    /// Parameters
279    /// ----------
280    /// variable : {laddu.Mass, laddu.CosTheta, laddu.Phi, laddu.PolAngle, laddu.PolMagnitude, laddu.Mandelstam}
281    ///
282    /// Returns
283    /// -------
284    /// float
285    ///
286    /// Notes
287    /// -----
288    /// Variables that rely on particle names require the event to carry metadata. Provide
289    /// ``p4_names``/``aux_names`` when constructing the event or evaluate variables through a
290    /// ``laddu.Dataset`` to ensure the metadata is available.
291    ///
292    /// Examples
293    /// --------
294    /// >>> from laddu import Event, Vec3  # doctest: +SKIP
295    /// >>> from laddu.utils.variables import Mass  # doctest: +SKIP
296    /// >>> event = Event(  # doctest: +SKIP
297    /// ...     [Vec3(0.0, 0.0, 1.0).with_mass(0.938)],
298    /// ...     [],
299    /// ...     1.0,
300    /// ...     p4_names=['proton'],
301    /// ... )
302    /// >>> event.evaluate(Mass(['proton']))  # doctest: +SKIP
303    /// 0.938
304    ///
305    fn evaluate(&self, variable: Bound<'_, PyAny>) -> PyResult<f64> {
306        let mut variable = variable.extract::<PyVariable>()?;
307        let metadata = self.ensure_metadata()?;
308        variable.bind_in_place(metadata)?;
309        variable.evaluate_event(&self.event)
310    }
311
312    /// Retrieve a four-momentum by name.
313    fn p4(&self, name: &str) -> PyResult<PyVec4> {
314        self.ensure_metadata()?;
315        self.event
316            .p4(name)
317            .map(PyVec4)
318            .ok_or_else(|| PyKeyError::new_err(format!("Unknown particle name '{name}'")))
319    }
320}
321
322impl PyEvent {
323    fn ensure_metadata(&self) -> PyResult<&DatasetMetadata> {
324        if !self.has_metadata {
325            Err(PyValueError::new_err(
326                "Event has no associated metadata for name-based operations",
327            ))
328        } else {
329            Ok(self.event.metadata())
330        }
331    }
332
333    fn resolve_p4_indices(&self, names: &[String]) -> PyResult<Vec<usize>> {
334        let metadata = self.ensure_metadata()?;
335        let mut resolved = Vec::new();
336        for name in names {
337            let selection = metadata
338                .p4_selection(name)
339                .ok_or_else(|| PyKeyError::new_err(format!("Unknown particle name '{name}'")))?;
340            resolved.extend_from_slice(selection.indices());
341        }
342        Ok(resolved)
343    }
344
345    pub(crate) fn metadata_opt(&self) -> Option<&DatasetMetadata> {
346        self.has_metadata.then(|| self.event.metadata())
347    }
348}
349
#[doc(hidden)]
/// A set of Events
///
/// Datasets can be created from lists of Events or by using the constructor helpers
/// such as :func:`laddu.io.read_parquet`, :func:`laddu.io.read_root`, and
/// :func:`laddu.io.read_amptools`.
///
/// Datasets can also be indexed directly to access individual Events.
///
/// Parameters
/// ----------
/// events : list of Event
///     The Events to store. Their four-momenta, auxiliary values, and weights are kept;
///     metadata attached to them is only used as a fallback (see Notes).
/// p4_names : list of str, optional
///     Names assigned to each four-momentum; enables name-based lookups if provided.
/// aux_names : list of str, optional
///     Names for auxiliary scalars stored alongside the events.
/// aliases : dict of {str: str or list[str]}, optional
///     Additional particle identifiers that override aliases stored on the Events.
///
/// Raises
/// ------
/// ValueError
///     If ``aliases`` is given but no four-momentum names are available, either from
///     ``p4_names`` or from an input Event's metadata.
///
/// Notes
/// -----
/// Explicit metadata provided here takes precedence over metadata embedded in the
/// input Events. When no metadata keyword is given, the metadata of the first Event that
/// carries any is used as-is. When any metadata keyword is given, names that are not passed
/// explicitly are copied from that Event's metadata, but its aliases are not.
///
/// Examples
/// --------
/// >>> from laddu import Dataset, Event, Vec3  # doctest: +SKIP
/// >>> event = Event(  # doctest: +SKIP
/// ...     [Vec3(0.0, 0.0, 1.0).with_mass(0.0), Vec3(0.0, 0.0, -1.0).with_mass(0.938)],
/// ...     [0.4, 0.3],
/// ...     1.0,
/// ... )
/// >>> dataset = Dataset(  # doctest: +SKIP
/// ...     [event],
/// ...     p4_names=['beam', 'proton'],
/// ...     aux_names=['pol_magnitude', 'pol_angle'],
/// ...     aliases={'target': 'proton'},
/// ... )
/// >>> dataset[0].p4('target')  # doctest: +SKIP
/// Vec4(px=0.0, py=0.0, pz=-1.0, e=1.371073...)
///
#[pyclass(name = "Dataset", module = "laddu", subclass, skip_from_py_object)]
#[derive(Clone)]
pub struct PyDataset(pub Arc<Dataset>);
394
395#[pyclass(
396    name = "ParquetChunkIter",
397    module = "laddu",
398    unsendable,
399    skip_from_py_object
400)]
401pub struct PyParquetChunkIter {
402    chunks: Box<dyn Iterator<Item = laddu_core::LadduResult<Arc<Dataset>>> + Send>,
403}
404
405#[pymethods]
406impl PyParquetChunkIter {
407    fn __iter__(slf: PyRef<'_, Self>) -> Py<PyParquetChunkIter> {
408        slf.into()
409    }
410
411    fn __next__(&mut self) -> PyResult<Option<PyDataset>> {
412        match self.chunks.next() {
413            Some(Ok(dataset)) => Ok(Some(PyDataset(dataset))),
414            Some(Err(err)) => Err(PyErr::from(err)),
415            None => Ok(None),
416        }
417    }
418}
419
420#[pyclass(
421    name = "DatasetIter",
422    module = "laddu",
423    unsendable,
424    skip_from_py_object
425)]
426struct PyDatasetIter {
427    kind: PyDatasetIterKind,
428}
429
430enum PyDatasetIterKind {
431    Local { dataset: Arc<Dataset>, index: usize },
432    Global(DatasetArcIter),
433}
434
435#[pymethods]
436impl PyDatasetIter {
437    fn __iter__(slf: PyRef<'_, Self>) -> Py<PyDatasetIter> {
438        slf.into()
439    }
440
441    fn __next__(&mut self) -> Option<PyEvent> {
442        let event = match &mut self.kind {
443            PyDatasetIterKind::Local { dataset, index } => {
444                let event = dataset.events_local().get(*index)?.clone();
445                *index += 1;
446                event
447            }
448            PyDatasetIterKind::Global(iterator) => iterator.next()?,
449        };
450        Some(PyEvent {
451            event,
452            has_metadata: true,
453        })
454    }
455}
456
457#[pymethods]
458impl PyDataset {
459    #[new]
460    #[pyo3(signature = (events, *, p4_names=None, aux_names=None, aliases=None))]
461    fn new(
462        events: Vec<PyEvent>,
463        p4_names: Option<Vec<String>>,
464        aux_names: Option<Vec<String>>,
465        aliases: Option<Bound<PyDict>>,
466    ) -> PyResult<Self> {
467        let inferred_metadata = events
468            .iter()
469            .find_map(|event| event.has_metadata.then(|| event.event.metadata_arc()));
470
471        let aliases = parse_aliases(aliases)?;
472        let use_explicit_metadata =
473            p4_names.is_some() || aux_names.is_some() || !aliases.is_empty();
474
475        let metadata =
476            if use_explicit_metadata {
477                let resolved_p4_names = match (p4_names, inferred_metadata.as_ref()) {
478                    (Some(names), _) => names,
479                    (None, Some(metadata)) => metadata.p4_names().to_vec(),
480                    (None, None) => Vec::new(),
481                };
482                let resolved_aux_names = match (aux_names, inferred_metadata.as_ref()) {
483                    (Some(names), _) => names,
484                    (None, Some(metadata)) => metadata.aux_names().to_vec(),
485                    (None, None) => Vec::new(),
486                };
487
488                if !aliases.is_empty() && resolved_p4_names.is_empty() {
489                    return Err(PyValueError::new_err(
490                        "`aliases` requires `p4_names` or events with metadata for resolution",
491                    ));
492                }
493
494                let mut metadata = DatasetMetadata::new(resolved_p4_names, resolved_aux_names)
495                    .map_err(PyErr::from)?;
496                if !aliases.is_empty() {
497                    metadata
498                        .add_p4_aliases(aliases.into_iter().map(|(alias_name, selection)| {
499                            (alias_name, selection.into_selection())
500                        }))
501                        .map_err(PyErr::from)?;
502                }
503                Some(Arc::new(metadata))
504            } else {
505                inferred_metadata
506            };
507
508        let events: Vec<Arc<EventData>> = events
509            .into_iter()
510            .map(|event| event.event.data_arc())
511            .collect();
512        let dataset = if let Some(metadata) = metadata {
513            Dataset::new_with_metadata(events, metadata)
514        } else {
515            Dataset::new(events)
516        };
517        Ok(Self(Arc::new(dataset)))
518    }
519
520    fn __len__(&self) -> usize {
521        self.0.n_events()
522    }
523    /// Iterate over all events in dataset order.
524    ///
525    /// Notes
526    /// -----
527    /// This is the default iterator used by ``for event in dataset``.
528    /// When MPI is enabled, it preserves global indexing semantics and may fetch
529    /// remote events as needed.
530    fn __iter__(&self) -> PyDatasetIter {
531        self.iter_global()
532    }
533    /// Get the number of events owned by the current rank.
534    #[getter]
535    fn n_events_local(&self) -> usize {
536        self.0.n_events_local()
537    }
538    /// Iterate over the events owned by the current rank.
539    ///
540    /// Notes
541    /// -----
542    /// When MPI is disabled, this iterates over the full Dataset.
543    /// When MPI is enabled, this iterates only over events stored on the current rank.
544    /// The yielded order matches the current rank's local storage order.
545    fn iter_local(&self) -> PyDatasetIter {
546        PyDatasetIter {
547            kind: PyDatasetIterKind::Local {
548                dataset: self.0.clone(),
549                index: 0,
550            },
551        }
552    }
553    /// Iterate over all events in the Dataset.
554    ///
555    /// Notes
556    /// -----
557    /// This is the default iterator used by ``for event in dataset``.
558    /// When MPI is enabled, this preserves global dataset order and performs
559    /// explicit cross-rank event fetches as needed.
560    fn iter_global(&self) -> PyDatasetIter {
561        PyDatasetIter {
562            kind: PyDatasetIterKind::Global(self.0.shared_iter_global()),
563        }
564    }
565    fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyDataset> {
566        if let Ok(other_ds) = other.extract::<PyRef<PyDataset>>() {
567            Ok(PyDataset(Arc::new(self.0.as_ref() + other_ds.0.as_ref())))
568        } else if let Ok(other_int) = other.extract::<usize>() {
569            if other_int == 0 {
570                Ok(self.clone())
571            } else {
572                Err(PyTypeError::new_err(
573                    "Addition with an integer for this type is only defined for 0",
574                ))
575            }
576        } else {
577            Err(PyTypeError::new_err("Unsupported operand type for +"))
578        }
579    }
580    fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyDataset> {
581        if let Ok(other_ds) = other.extract::<PyRef<PyDataset>>() {
582            Ok(PyDataset(Arc::new(other_ds.0.as_ref() + self.0.as_ref())))
583        } else if let Ok(other_int) = other.extract::<usize>() {
584            if other_int == 0 {
585                Ok(self.clone())
586            } else {
587                Err(PyTypeError::new_err(
588                    "Addition with an integer for this type is only defined for 0",
589                ))
590            }
591        } else {
592            Err(PyTypeError::new_err("Unsupported operand type for +"))
593        }
594    }
595    /// Get the number of Events in the Dataset
596    ///
597    /// Notes
598    /// -----
599    /// When MPI is enabled, this returns the global event count.
600    /// It therefore matches ``len(dataset)`` and the valid range for ``dataset[i]``.
601    ///
602    /// Returns
603    /// -------
604    /// n_events : int
605    ///     The number of Events
606    ///
607    #[getter]
608    fn n_events(&self) -> usize {
609        self.0.n_events()
610    }
611    /// Alias for ``n_events``.
612    #[getter]
613    fn n_events_global(&self) -> usize {
614        self.0.n_events_global()
615    }
616    /// Particle names used to construct four-momenta when loading from a Parquet file.
617    #[getter]
618    fn p4_names(&self) -> Vec<String> {
619        self.0.p4_names().to_vec()
620    }
621    /// Auxiliary scalar names associated with this Dataset.
622    #[getter]
623    fn aux_names(&self) -> Vec<String> {
624        self.0.aux_names().to_vec()
625    }
626
627    /// Get the weighted number of Events in the Dataset
628    ///
629    /// Notes
630    /// -----
631    /// When MPI is enabled, this returns the global weighted event count.
632    ///
633    /// Returns
634    /// -------
635    /// n_events : float
636    ///     The sum of all Event weights
637    ///
638    #[getter]
639    fn n_events_weighted(&self) -> f64 {
640        self.0.n_events_weighted()
641    }
642    /// Alias for ``n_events_weighted``.
643    #[getter]
644    fn n_events_weighted_global(&self) -> f64 {
645        self.0.n_events_weighted_global()
646    }
647    /// Get the weighted number of local Events in the Dataset
648    ///
649    /// Notes
650    /// -----
651    /// When MPI is enabled, this returns the sum of the current rank's Event
652    /// weights.
653    ///
654    /// Returns
655    /// -------
656    /// n_events : float
657    ///     The sum of the current rank's Event weights
658    ///
659    #[getter]
660    fn n_events_weighted_local(&self) -> f64 {
661        self.0.n_events_weighted_local()
662    }
663    /// The weights associated with the Dataset
664    ///
665    /// Notes
666    /// -----
667    /// When MPI is enabled, this returns the global weight vector in dataset order.
668    ///
669    /// Returns
670    /// -------
671    /// weights : array_like
672    ///     A ``numpy`` array of global Event weights
673    ///
674    #[getter]
675    fn weights<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<f64>> {
676        PyArray1::from_slice(py, &self.0.weights())
677    }
678    /// Alias for ``weights``.
679    #[getter]
680    fn weights_global<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<f64>> {
681        PyArray1::from_slice(py, &self.0.weights_global())
682    }
683    /// The weights associated with the Dataset on the current rank.
684    ///
685    /// Notes
686    /// -----
687    /// This is the explicit rank-local counterpart to ``weights``.
688    ///
689    /// Returns
690    /// -------
691    /// weights : array_like
692    ///     A ``numpy`` array of rank-local Event weights
693    ///
694    #[getter]
695    fn weights_local<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<f64>> {
696        PyArray1::from_slice(py, &self.0.weights_local())
697    }
698    /// The internal list of Events stored in the Dataset
699    ///
700    /// Notes
701    /// -----
702    /// When MPI is enabled, this returns the full global event list.
703    /// Use ``events_local`` or ``iter_local()`` to access only the current rank's
704    /// event ownership.
705    /// The returned list matches the order produced by ``for event in dataset``.
706    ///
707    /// Returns
708    /// -------
709    /// events : list of Event
710    ///     The Events in the Dataset
711    ///
712    #[getter]
713    fn events(&self) -> Vec<PyEvent> {
714        self.0
715            .shared_iter()
716            .map(|rust_event| PyEvent {
717                event: rust_event,
718                has_metadata: true,
719            })
720            .collect()
721    }
722    /// Alias for ``events``.
723    #[getter]
724    fn events_global(&self) -> Vec<PyEvent> {
725        self.events()
726    }
727    /// The list of Events stored on the current rank.
728    ///
729    /// Notes
730    /// -----
731    /// This is the explicit rank-local counterpart to ``events``.
732    /// The returned list matches the order produced by ``iter_local()``.
733    #[getter]
734    fn events_local(&self) -> Vec<PyEvent> {
735        self.0
736            .events_local()
737            .iter()
738            .map(|rust_event| PyEvent {
739                event: rust_event.clone(),
740                has_metadata: true,
741            })
742            .collect()
743    }
744    /// Retrieve a four-momentum by particle name for the event at ``index``.
745    fn p4_by_name(&self, index: usize, name: &str) -> PyResult<PyVec4> {
746        self.0
747            .p4_by_name(index, name)
748            .map(PyVec4)
749            .ok_or_else(|| PyKeyError::new_err(format!("Unknown particle name '{name}'")))
750    }
751    /// Retrieve an auxiliary scalar by name for the event at ``index``.
752    fn aux_by_name(&self, index: usize, name: &str) -> PyResult<f64> {
753        self.0
754            .aux_by_name(index, name)
755            .ok_or_else(|| PyKeyError::new_err(format!("Unknown auxiliary name '{name}'")))
756    }
757    /// Alias for ``dataset[index]``.
758    ///
759    /// Notes
760    /// -----
761    /// This preserves the default global indexing semantics under MPI.
762    fn event_global(&self, index: usize) -> PyResult<PyEvent> {
763        let event = self
764            .0
765            .get_event_global(index)
766            .ok_or_else(|| PyIndexError::new_err("index out of range"))?;
767        Ok(PyEvent {
768            event,
769            has_metadata: true,
770        })
771    }
772    fn __getitem__<'py>(
773        &self,
774        py: Python<'py>,
775        index: Bound<'py, PyAny>,
776    ) -> PyResult<Bound<'py, PyAny>> {
777        if let Ok(value) = self.evaluate(py, index.clone()) {
778            value.into_bound_py_any(py)
779        } else if let Ok(index) = index.extract::<usize>() {
780            let event = self
781                .0
782                .get_event(index)
783                .ok_or_else(|| PyIndexError::new_err("index out of range"))?;
784            PyEvent {
785                event,
786                has_metadata: true,
787            }
788            .into_bound_py_any(py)
789        } else {
790            Err(PyTypeError::new_err(
791                "Unsupported index type (int or Variable)",
792            ))
793        }
794    }
795    /// Separates a Dataset into histogram bins by a Variable value
796    ///
797    /// Parameters
798    /// ----------
799    /// variable : {laddu.Mass, laddu.CosTheta, laddu.Phi, laddu.PolAngle, laddu.PolMagnitude, laddu.Mandelstam}
800    ///     The Variable by which each Event is binned
801    /// bins : int
802    ///     The number of equally-spaced bins
803    /// range : tuple[float, float]
804    ///     The minimum and maximum bin edges
805    ///
806    /// Returns
807    /// -------
808    /// datasets : BinnedDataset
809    ///     A structure containing Datasets binned by the given `variable`
810    ///
811    /// See Also
812    /// --------
813    /// laddu.Mass
814    /// laddu.CosTheta
815    /// laddu.Phi
816    /// laddu.PolAngle
817    /// laddu.PolMagnitude
818    /// laddu.Mandelstam
819    ///
820    /// Examples
821    /// --------
822    /// >>> from laddu.utils.variables import Mass  # doctest: +SKIP
823    /// >>> binned = dataset.bin_by(Mass(['kshort1']), bins=10, range=(0.9, 1.5))  # doctest: +SKIP
824    /// >>> len(binned)  # doctest: +SKIP
825    /// 10
826    ///
827    /// Raises
828    /// ------
829    /// TypeError
830    ///     If the given `variable` is not a valid variable
831    ///
832    #[pyo3(signature = (variable, bins, range))]
833    fn bin_by(
834        &self,
835        variable: Bound<'_, PyAny>,
836        bins: usize,
837        range: (f64, f64),
838    ) -> PyResult<PyBinnedDataset> {
839        let py_variable = variable.extract::<PyVariable>()?;
840        let bound_variable = py_variable.bound(self.0.metadata())?;
841        Ok(PyBinnedDataset(self.0.bin_by(
842            bound_variable,
843            bins,
844            range,
845        )?))
846    }
847    /// Filter the Dataset by a given VariableExpression, selecting events for which the expression returns ``True``.
848    ///
849    /// Parameters
850    /// ----------
851    /// expression : VariableExpression
852    ///     The expression with which to filter the Dataset
853    ///
854    /// Returns
855    /// -------
856    /// Dataset
857    ///     The filtered Dataset
858    ///
859    /// Examples
860    /// --------
861    /// >>> from laddu.utils.variables import Mass  # doctest: +SKIP
862    /// >>> heavy = dataset.filter(Mass(['kshort1']) > 1.0)  # doctest: +SKIP
863    ///
864    pub fn filter(&self, expression: &PyVariableExpression) -> PyResult<PyDataset> {
865        Ok(PyDataset(
866            self.0.filter(&expression.0).map_err(PyErr::from)?,
867        ))
868    }
869    /// Generate a new bootstrapped Dataset by randomly resampling the original with replacement
870    ///
871    /// The new Dataset is resampled with a random generator seeded by the provided `seed`
872    ///
873    /// Parameters
874    /// ----------
875    /// seed : int
876    ///     The random seed used in the resampling process
877    ///
878    /// Returns
879    /// -------
880    /// Dataset
881    ///     A bootstrapped Dataset
882    ///
883    /// Examples
884    /// --------
885    /// >>> replica = dataset.bootstrap(2024)  # doctest: +SKIP
886    /// >>> len(replica) == len(dataset)  # doctest: +SKIP
887    /// True
888    ///
889    fn bootstrap(&self, seed: usize) -> PyDataset {
890        PyDataset(self.0.bootstrap(seed))
891    }
892    /// Boost all the four-momenta in all events to the rest frame of the given set of
893    /// named four-momenta.
894    ///
895    /// Parameters
896    /// ----------
897    /// names : list of str
898    ///     The names of the four-momenta defining the rest frame
899    ///
900    /// Returns
901    /// -------
902    /// Dataset
903    ///     The boosted dataset
904    ///
905    /// Examples
906    /// --------
907    /// >>> dataset.boost_to_rest_frame_of(['kshort1', 'kshort2'])  # doctest: +SKIP
908    ///
909    pub fn boost_to_rest_frame_of(&self, names: Vec<String>) -> PyDataset {
910        PyDataset(self.0.boost_to_rest_frame_of(&names))
911    }
912    /// Get the value of a Variable over every event in the Dataset.
913    ///
914    /// Parameters
915    /// ----------
916    /// variable : {laddu.Mass, laddu.CosTheta, laddu.Phi, laddu.PolAngle, laddu.PolMagnitude, laddu.Mandelstam}
917    ///
918    /// Returns
919    /// -------
920    /// values : array_like
921    ///
922    /// Examples
923    /// --------
924    /// >>> from laddu.utils.variables import Mass  # doctest: +SKIP
925    /// >>> masses = dataset.evaluate(Mass(['kshort1']))  # doctest: +SKIP
926    /// >>> masses.shape  # doctest: +SKIP
927    /// (len(dataset),)
928    ///
929    fn evaluate<'py>(
930        &self,
931        py: Python<'py>,
932        variable: Bound<'py, PyAny>,
933    ) -> PyResult<Bound<'py, PyArray1<f64>>> {
934        let variable = variable.extract::<PyVariable>()?;
935        let bound_variable = variable.bound(self.0.metadata())?;
936        let values = self.0.evaluate(&bound_variable).map_err(PyErr::from)?;
937        Ok(PyArray1::from_vec(py, values))
938    }
939}
940
941/// Read a Dataset from a Parquet file.
942///
943/// Examples
944/// --------
945/// >>> import laddu.io as ldio  # doctest: +SKIP
946/// >>> dataset = ldio.read_parquet(  # doctest: +SKIP
947/// ...     'events.parquet',
948/// ...     p4s=['beam', 'proton'],
949/// ...     aux=['pol_magnitude', 'pol_angle'],
950/// ...     aliases={'target': 'proton'},
951/// ... )
952/// >>> dataset.p4_names  # doctest: +SKIP
953/// ['beam', 'proton']
954#[pyfunction]
955#[pyo3(signature = (path, *, p4s=None, aux=None, aliases=None))]
956pub fn read_parquet(
957    path: Bound<PyAny>,
958    p4s: Option<Vec<String>>,
959    aux: Option<Vec<String>>,
960    aliases: Option<Bound<PyDict>>,
961) -> PyResult<PyDataset> {
962    let path_str = parse_dataset_path(path)?;
963    let mut read_options = DatasetReadOptions::default();
964    if let Some(p4s) = p4s {
965        read_options = read_options.p4_names(p4s);
966    }
967    if let Some(aux) = aux {
968        read_options = read_options.aux_names(aux);
969    }
970    for (alias_name, selection) in parse_aliases(aliases)?.into_iter() {
971        read_options = read_options.alias(alias_name, selection);
972    }
973    let dataset = core_read_parquet(&path_str, &read_options)?;
974    Ok(PyDataset(dataset))
975}
976
977/// Read a Dataset from a Parquet file in chunks.
978#[pyfunction]
979#[pyo3(signature = (path, *, p4s=None, aux=None, aliases=None, chunk_size=None))]
980pub fn read_parquet_chunked(
981    path: Bound<PyAny>,
982    p4s: Option<Vec<String>>,
983    aux: Option<Vec<String>>,
984    aliases: Option<Bound<PyDict>>,
985    chunk_size: Option<usize>,
986) -> PyResult<PyParquetChunkIter> {
987    let path_str = parse_dataset_path(path)?;
988    let mut read_options = DatasetReadOptions::default();
989    if let Some(p4s) = p4s {
990        read_options = read_options.p4_names(p4s);
991    }
992    if let Some(aux) = aux {
993        read_options = read_options.aux_names(aux);
994    }
995    if let Some(chunk_size) = chunk_size {
996        read_options = read_options.chunk_size(chunk_size);
997    }
998    for (alias_name, selection) in parse_aliases(aliases)?.into_iter() {
999        read_options = read_options.alias(alias_name, selection);
1000    }
1001
1002    let chunks = core_read_parquet_chunks_with_options(&path_str, &read_options)?;
1003    Ok(PyParquetChunkIter {
1004        chunks: Box::new(chunks),
1005    })
1006}
1007
1008/// Read a Dataset from a ROOT file using the oxyroot backend.
1009///
1010/// Examples
1011/// --------
1012/// >>> import laddu.io as ldio  # doctest: +SKIP
1013/// >>> dataset = ldio.read_root(  # doctest: +SKIP
1014/// ...     'events.root',
1015/// ...     tree='kin',
1016/// ...     p4s=['beam', 'proton'],
1017/// ...     aux=['pol_magnitude', 'pol_angle'],
1018/// ... )
1019/// >>> dataset.aux_names  # doctest: +SKIP
1020/// ['pol_magnitude', 'pol_angle']
1021#[pyfunction]
1022#[pyo3(signature = (path, *, tree=None, p4s=None, aux=None, aliases=None))]
1023pub fn read_root(
1024    path: Bound<PyAny>,
1025    tree: Option<String>,
1026    p4s: Option<Vec<String>>,
1027    aux: Option<Vec<String>>,
1028    aliases: Option<Bound<PyDict>>,
1029) -> PyResult<PyDataset> {
1030    let path_str = parse_dataset_path(path)?;
1031    let mut read_options = DatasetReadOptions::default();
1032    if let Some(p4s) = p4s {
1033        read_options = read_options.p4_names(p4s);
1034    }
1035    if let Some(aux) = aux {
1036        read_options = read_options.aux_names(aux);
1037    }
1038    if let Some(tree) = tree {
1039        read_options = read_options.tree(tree);
1040    }
1041    for (alias_name, selection) in parse_aliases(aliases)?.into_iter() {
1042        read_options = read_options.alias(alias_name, selection);
1043    }
1044    let dataset = core_read_root(&path_str, &read_options)?;
1045    Ok(PyDataset(dataset))
1046}
1047
1048/// Write a Dataset to a Parquet file.
1049#[pyfunction]
1050#[pyo3(signature = (dataset, path, *, chunk_size=None, precision="f64"))]
1051pub fn write_parquet(
1052    dataset: &PyDataset,
1053    path: Bound<PyAny>,
1054    chunk_size: Option<usize>,
1055    precision: &str,
1056) -> PyResult<()> {
1057    let path_str = parse_dataset_path(path)?;
1058    let mut write_options = DatasetWriteOptions::default();
1059    if let Some(size) = chunk_size {
1060        write_options.batch_size = size.max(1);
1061    }
1062    write_options.precision = parse_precision_arg(Some(precision))?;
1063    core_write_parquet(dataset.0.as_ref(), &path_str, &write_options).map_err(PyErr::from)
1064}
1065
1066/// Write a Dataset to a ROOT file using the oxyroot backend.
1067#[pyfunction]
1068#[pyo3(signature = (dataset, path, *, tree=None, chunk_size=None, precision="f64"))]
1069pub fn write_root(
1070    dataset: &PyDataset,
1071    path: Bound<PyAny>,
1072    tree: Option<String>,
1073    chunk_size: Option<usize>,
1074    precision: &str,
1075) -> PyResult<()> {
1076    let path_str = parse_dataset_path(path)?;
1077    let mut write_options = DatasetWriteOptions::default();
1078    if let Some(name) = tree {
1079        write_options.tree = Some(name);
1080    }
1081    if let Some(size) = chunk_size {
1082        write_options.batch_size = size.max(1);
1083    }
1084    write_options.precision = parse_precision_arg(Some(precision))?;
1085    core_write_root(dataset.0.as_ref(), &path_str, &write_options).map_err(PyErr::from)
1086}
1087
1088#[doc(hidden)]
1089/// Build a Dataset from columnar arrays.
1090///
1091/// This is the canonical high-throughput ingestion path used by Python reader helpers.
1092///
1093/// Examples
1094/// --------
1095/// >>> import laddu.io as ldio  # doctest: +SKIP
1096/// >>> dataset = ldio.from_columns(  # doctest: +SKIP
1097/// ...     {
1098/// ...         'beam_px': [0.0],
1099/// ...         'beam_py': [0.0],
1100/// ...         'beam_pz': [8.5],
1101/// ...         'beam_e': [8.5],
1102/// ...         'proton_px': [0.0],
1103/// ...         'proton_py': [0.0],
1104/// ...         'proton_pz': [-0.2],
1105/// ...         'proton_e': [0.959],
1106/// ...         'pol_magnitude': [0.4],
1107/// ...         'pol_angle': [0.3],
1108/// ...         'weight': [1.0],
1109/// ...     },
1110/// ...     p4s=['beam', 'proton'],
1111/// ...     aux=['pol_magnitude', 'pol_angle'],
1112/// ...     aliases={'target': 'proton'},
1113/// ... )
1114/// >>> dataset[0].p4('target')  # doctest: +SKIP
1115/// Vec4(px=0.0, py=0.0, pz=-0.2, e=0.959)
1116#[pyfunction]
1117#[pyo3(signature = (columns, *, p4s=None, aux=None, aliases=None))]
1118pub fn from_columns(
1119    columns: Bound<'_, PyDict>,
1120    p4s: Option<Vec<String>>,
1121    aux: Option<Vec<String>>,
1122    aliases: Option<Bound<'_, PyDict>>,
1123) -> PyResult<PyDataset> {
1124    let column_names = columns
1125        .iter()
1126        .map(|(key, _)| key.extract::<String>())
1127        .collect::<PyResult<Vec<_>>>()?;
1128
1129    let (detected_p4_names, detected_aux_names) =
1130        infer_p4_and_aux_names_from_columns(&column_names);
1131    let p4_names = p4s.unwrap_or(detected_p4_names);
1132    if p4_names.is_empty() {
1133        let mut partial_components: std::collections::BTreeMap<
1134            String,
1135            std::collections::BTreeSet<&str>,
1136        > = std::collections::BTreeMap::new();
1137        for column_name in &column_names {
1138            let lowered = column_name.to_ascii_lowercase();
1139            for suffix in P4_COMPONENT_SUFFIXES {
1140                if lowered.ends_with(suffix) && column_name.len() > suffix.len() {
1141                    let prefix = column_name[..column_name.len() - suffix.len()].to_string();
1142                    partial_components.entry(prefix).or_default().insert(suffix);
1143                }
1144            }
1145        }
1146        if let Some((prefix, present)) = partial_components.iter().next() {
1147            if present.len() < P4_COMPONENT_SUFFIXES.len() {
1148                let missing = P4_COMPONENT_SUFFIXES
1149                    .iter()
1150                    .filter(|suffix| !present.contains(**suffix))
1151                    .map(|suffix| format!("{prefix}{suffix}"))
1152                    .collect::<Vec<_>>()
1153                    .join(", ");
1154                return Err(PyKeyError::new_err(format!(
1155                    "Missing components [{missing}] for four-momentum '{prefix}'"
1156                )));
1157            }
1158        }
1159        return Err(PyValueError::new_err(
1160            "No four-momentum columns found (expected *_px, *_py, *_pz, *_e)",
1161        ));
1162    }
1163
1164    let aux_names = aux.unwrap_or(detected_aux_names);
1165    let p4_component_columns =
1166        resolve_p4_component_columns(&column_names, &p4_names).map_err(PyErr::from)?;
1167    let resolved_aux_columns =
1168        resolve_columns_case_insensitive(&column_names, &aux_names).map_err(PyErr::from)?;
1169
1170    let n_events = {
1171        let first_name = p4_component_columns
1172            .first()
1173            .map(|components| components[0].clone())
1174            .ok_or_else(|| PyKeyError::new_err("Missing required p4 column"))?;
1175        let values = extract_numeric_column(
1176            columns
1177                .get_item(first_name.as_str())?
1178                .ok_or_else(|| PyKeyError::new_err("Missing required p4 column"))?,
1179            &first_name,
1180        )?;
1181        values.len()
1182    };
1183
1184    let mut p4_columns: Vec<[Vec<f64>; 4]> = Vec::with_capacity(p4_names.len());
1185    for component_names in &p4_component_columns {
1186        let px = extract_numeric_column(
1187            columns
1188                .get_item(component_names[0].as_str())?
1189                .ok_or_else(|| PyKeyError::new_err(format!("Missing {}", component_names[0])))?,
1190            component_names[0].as_str(),
1191        )?;
1192        let py = extract_numeric_column(
1193            columns
1194                .get_item(component_names[1].as_str())?
1195                .ok_or_else(|| PyKeyError::new_err(format!("Missing {}", component_names[1])))?,
1196            component_names[1].as_str(),
1197        )?;
1198        let pz = extract_numeric_column(
1199            columns
1200                .get_item(component_names[2].as_str())?
1201                .ok_or_else(|| PyKeyError::new_err(format!("Missing {}", component_names[2])))?,
1202            component_names[2].as_str(),
1203        )?;
1204        let e = extract_numeric_column(
1205            columns
1206                .get_item(component_names[3].as_str())?
1207                .ok_or_else(|| PyKeyError::new_err(format!("Missing {}", component_names[3])))?,
1208            component_names[3].as_str(),
1209        )?;
1210        if px.len() != n_events
1211            || py.len() != n_events
1212            || pz.len() != n_events
1213            || e.len() != n_events
1214        {
1215            return Err(PyValueError::new_err(
1216                "All p4 components must have the same length",
1217            ));
1218        }
1219        p4_columns.push([px, py, pz, e]);
1220    }
1221
1222    let mut aux_columns: Vec<Vec<f64>> = Vec::with_capacity(resolved_aux_columns.len());
1223    for (aux_name, aux_column_name) in aux_names.iter().zip(&resolved_aux_columns) {
1224        let values = extract_numeric_column(
1225            columns.get_item(aux_column_name.as_str())?.ok_or_else(|| {
1226                PyKeyError::new_err(format!("Missing auxiliary column '{aux_name}'"))
1227            })?,
1228            aux_name,
1229        )?;
1230        if values.len() != n_events {
1231            return Err(PyValueError::new_err(format!(
1232                "Auxiliary column '{aux_name}' length does not match p4 columns"
1233            )));
1234        }
1235        aux_columns.push(values);
1236    }
1237
1238    let weights = if let Some(weight_column_name) = resolve_optional_weight_column(&column_names) {
1239        let weight_values = columns
1240            .get_item(weight_column_name.as_str())?
1241            .ok_or_else(|| PyKeyError::new_err("Missing weight column"))?;
1242        let values = extract_numeric_column(weight_values, "weight")?;
1243        if values.len() != n_events {
1244            return Err(PyValueError::new_err(
1245                "Column 'weight' length does not match p4 columns",
1246            ));
1247        }
1248        values
1249    } else {
1250        vec![1.0; n_events]
1251    };
1252
1253    let parsed_aliases = parse_aliases(aliases)?;
1254    let mut metadata =
1255        DatasetMetadata::new(p4_names.clone(), aux_names.clone()).map_err(PyErr::from)?;
1256    if !parsed_aliases.is_empty() {
1257        metadata
1258            .add_p4_aliases(
1259                parsed_aliases
1260                    .into_iter()
1261                    .map(|(alias_name, selection)| (alias_name, selection.into_selection())),
1262            )
1263            .map_err(PyErr::from)?;
1264    }
1265
1266    let mut events = Vec::with_capacity(n_events);
1267    for event_idx in 0..n_events {
1268        let p4s = p4_columns
1269            .iter()
1270            .map(|components| {
1271                laddu_core::utils::vectors::Vec4::new(
1272                    components[0][event_idx],
1273                    components[1][event_idx],
1274                    components[2][event_idx],
1275                    components[3][event_idx],
1276                )
1277            })
1278            .collect::<Vec<_>>();
1279        let aux = aux_columns
1280            .iter()
1281            .map(|values| values[event_idx])
1282            .collect::<Vec<_>>();
1283        events.push(Arc::new(EventData {
1284            p4s,
1285            aux,
1286            weight: weights[event_idx],
1287        }));
1288    }
1289
1290    Ok(PyDataset(Arc::new(Dataset::new_with_metadata(
1291        events,
1292        Arc::new(metadata),
1293    ))))
1294}
1295
1296/// A collection of Datasets binned by a Variable
1297///
1298/// BinnedDatasets can be indexed directly to access the underlying Datasets by bin
1299///
1300/// See Also
1301/// --------
1302/// laddu.Dataset.bin_by
1303///
1304#[pyclass(name = "BinnedDataset", module = "laddu", skip_from_py_object)]
1305pub struct PyBinnedDataset(BinnedDataset);
1306
1307#[pymethods]
1308impl PyBinnedDataset {
1309    fn __len__(&self) -> usize {
1310        self.0.n_bins()
1311    }
1312    /// The number of bins in the BinnedDataset
1313    ///
1314    #[getter]
1315    fn n_bins(&self) -> usize {
1316        self.0.n_bins()
1317    }
1318    /// The minimum and maximum values of the binning Variable used to create this BinnedDataset
1319    ///
1320    #[getter]
1321    fn range(&self) -> (f64, f64) {
1322        self.0.range()
1323    }
1324    /// The edges of each bin in the BinnedDataset
1325    ///
1326    #[getter]
1327    fn edges<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<f64>> {
1328        PyArray1::from_slice(py, &self.0.edges())
1329    }
1330    fn __getitem__(&self, index: usize) -> PyResult<PyDataset> {
1331        self.0
1332            .get(index)
1333            .ok_or(PyIndexError::new_err("index out of range"))
1334            .map(|rust_dataset| PyDataset(rust_dataset.clone()))
1335    }
1336}