Skip to main content

laddu_python/
data.rs

1use std::{path::PathBuf, sync::Arc};
2
3use laddu_core::{
4    data::{
5        io::{
6            infer_p4_and_aux_names_from_columns, resolve_columns_case_insensitive,
7            resolve_optional_weight_column, resolve_p4_component_columns, P4_COMPONENT_SUFFIXES,
8        },
9        read_parquet as core_read_parquet,
10        read_parquet_chunks_with_options as core_read_parquet_chunks_with_options,
11        read_root as core_read_root, write_parquet as core_write_parquet,
12        write_root as core_write_root, BinnedDataset, Dataset, DatasetArcIter, DatasetMetadata,
13        DatasetWriteOptions, EventData, FloatPrecision, OwnedEvent, SharedDatasetIterExt,
14    },
15    variables::IntoP4Selection,
16    DatasetReadOptions, Vec4,
17};
18use numpy::{PyArray1, PyReadonlyArray1};
19use pyo3::{
20    exceptions::{PyIndexError, PyKeyError, PyTypeError, PyValueError},
21    prelude::*,
22    types::{PyDict, PyList},
23    IntoPyObjectExt,
24};
25
26use crate::{
27    variables::{PyVariable, PyVariableExpression},
28    vectors::PyVec4,
29};
30
31fn parse_aliases(aliases: Option<Bound<'_, PyDict>>) -> PyResult<Vec<(String, Vec<String>)>> {
32    let Some(aliases) = aliases else {
33        return Ok(Vec::new());
34    };
35
36    let mut parsed = Vec::new();
37    for (key, value) in aliases.iter() {
38        let alias_name = key.extract::<String>()?;
39        let selection = if let Ok(single) = value.extract::<String>() {
40            vec![single]
41        } else {
42            let seq = value.extract::<Vec<String>>().map_err(|_| {
43                PyTypeError::new_err("Alias values must be a string or a sequence of strings")
44            })?;
45            if seq.is_empty() {
46                return Err(PyValueError::new_err(format!(
47                    "Alias '{alias_name}' must reference at least one particle",
48                )));
49            }
50            seq
51        };
52        parsed.push((alias_name, selection));
53    }
54
55    Ok(parsed)
56}
57
58fn parse_dataset_path(path: Bound<'_, PyAny>) -> PyResult<String> {
59    if let Ok(s) = path.extract::<String>() {
60        Ok(s)
61    } else if let Ok(pathbuf) = path.extract::<PathBuf>() {
62        Ok(pathbuf.to_string_lossy().into_owned())
63    } else {
64        Err(PyTypeError::new_err("Expected str or Path"))
65    }
66}
67
68fn parse_precision_arg(value: Option<&str>) -> PyResult<FloatPrecision> {
69    match value.map(|v| v.to_ascii_lowercase()) {
70        None => Ok(FloatPrecision::F64),
71        Some(name) if name == "f64" || name == "float64" || name == "double" => {
72            Ok(FloatPrecision::F64)
73        }
74        Some(name) if name == "f32" || name == "float32" || name == "float" => {
75            Ok(FloatPrecision::F32)
76        }
77        Some(other) => Err(PyValueError::new_err(format!(
78            "Unsupported precision '{other}' (expected 'f64' or 'f32')"
79        ))),
80    }
81}
82
83fn extract_numeric_column(value: Bound<'_, PyAny>, name: &str) -> PyResult<Vec<f64>> {
84    if let Ok(array) = value.extract::<PyReadonlyArray1<'_, f64>>() {
85        return Ok(array.as_slice()?.to_vec());
86    }
87    if let Ok(array) = value.extract::<PyReadonlyArray1<'_, f32>>() {
88        return Ok(array.as_slice()?.iter().map(|v| *v as f64).collect());
89    }
90    if let Ok(values) = value.extract::<Vec<f64>>() {
91        return Ok(values);
92    }
93    if let Ok(values) = value.extract::<Vec<f32>>() {
94        return Ok(values.into_iter().map(|v| v as f64).collect());
95    }
96    if let Ok(list) = value.cast::<PyList>() {
97        let mut converted = Vec::with_capacity(list.len());
98        for item in list.iter() {
99            converted.push(item.extract::<f64>().map_err(|_| {
100                PyTypeError::new_err(format!(
101                    "Column '{name}' must be numeric (float32/float64/list of floats)"
102                ))
103            })?);
104        }
105        return Ok(converted);
106    }
107    Err(PyTypeError::new_err(format!(
108        "Column '{name}' must be numeric (float32/float64/list of floats)"
109    )))
110}
111
112fn metadata_from_names_and_aliases(
113    p4_names: Vec<String>,
114    aux_names: Vec<String>,
115    aliases: Option<Bound<'_, PyDict>>,
116) -> PyResult<DatasetMetadata> {
117    let parsed_aliases = parse_aliases(aliases)?;
118    let mut metadata = DatasetMetadata::new(p4_names, aux_names).map_err(PyErr::from)?;
119    if !parsed_aliases.is_empty() {
120        metadata
121            .add_p4_aliases(
122                parsed_aliases
123                    .into_iter()
124                    .map(|(alias_name, selection)| (alias_name, selection.into_selection())),
125            )
126            .map_err(PyErr::from)?;
127    }
128    Ok(metadata)
129}
130
131fn parse_p4_mapping(p4: Bound<'_, PyDict>) -> PyResult<Vec<(String, Vec4)>> {
132    p4.iter()
133        .map(|(key, value)| {
134            Ok((
135                key.extract::<String>()?,
136                value
137                    .extract::<PyVec4>()
138                    .map_err(|_| PyTypeError::new_err("p4 values must be laddu.Vec4 instances"))?
139                    .0,
140            ))
141        })
142        .collect()
143}
144
145fn parse_aux_mapping(aux: Option<Bound<'_, PyDict>>) -> PyResult<Vec<(String, f64)>> {
146    let Some(aux) = aux else {
147        return Ok(Vec::new());
148    };
149    aux.iter()
150        .map(|(key, value)| Ok((key.extract::<String>()?, value.extract::<f64>()?)))
151        .collect()
152}
153
154fn parse_p4_column(values: Vec<PyVec4>) -> Vec<Vec4> {
155    values.into_iter().map(|value| value.0).collect()
156}
157
158fn dataset_from_py_events(
159    events: Vec<PyEvent>,
160    p4_names: Option<Vec<String>>,
161    aux_names: Option<Vec<String>>,
162    aliases: Option<Bound<PyDict>>,
163    global: bool,
164) -> PyResult<PyDataset> {
165    let inferred_metadata = events
166        .iter()
167        .find_map(|event| event.has_metadata.then(|| event.event.metadata_arc()));
168
169    let aliases = parse_aliases(aliases)?;
170    let use_explicit_metadata = p4_names.is_some() || aux_names.is_some() || !aliases.is_empty();
171
172    let metadata = if use_explicit_metadata {
173        let resolved_p4_names = match (p4_names, inferred_metadata.as_ref()) {
174            (Some(names), _) => names,
175            (None, Some(metadata)) => metadata.p4_names().to_vec(),
176            (None, None) => Vec::new(),
177        };
178        let resolved_aux_names = match (aux_names, inferred_metadata.as_ref()) {
179            (Some(names), _) => names,
180            (None, Some(metadata)) => metadata.aux_names().to_vec(),
181            (None, None) => Vec::new(),
182        };
183
184        if !aliases.is_empty() && resolved_p4_names.is_empty() {
185            return Err(PyValueError::new_err(
186                "`aliases` requires `p4_names` or events with metadata for resolution",
187            ));
188        }
189
190        let mut metadata =
191            DatasetMetadata::new(resolved_p4_names, resolved_aux_names).map_err(PyErr::from)?;
192        if !aliases.is_empty() {
193            metadata
194                .add_p4_aliases(
195                    aliases
196                        .into_iter()
197                        .map(|(alias_name, selection)| (alias_name, selection.into_selection())),
198                )
199                .map_err(PyErr::from)?;
200        }
201        Some(Arc::new(metadata))
202    } else {
203        inferred_metadata
204    };
205
206    let events: Vec<Arc<EventData>> = events
207        .into_iter()
208        .map(|event| event.event.data_arc())
209        .collect();
210    let dataset = match (metadata, global) {
211        (Some(metadata), true) => Dataset::new_with_metadata(events, metadata),
212        (Some(metadata), false) => Dataset::new_local(events, metadata),
213        (None, true) => Dataset::new(events),
214        (None, false) => Dataset::new_local(events, Arc::new(DatasetMetadata::default())),
215    };
216    Ok(PyDataset(Arc::new(dataset)))
217}
218
/// A single event
///
/// Events are composed of a set of 4-momenta of particles in the overall
/// center-of-momentum frame, optional auxiliary scalars (e.g. polarization magnitude or angle),
/// and a weight.
///
/// Parameters
/// ----------
/// p4s : list of Vec4
///     4-momenta of each particle in the event in the overall center-of-momentum frame
/// aux : list of float
///     Scalar auxiliary data associated with the event
/// weight : float
///     The weight associated with this event
/// p4_names : list of str, optional
///     Human-readable names for each four-momentum. Providing names enables name-based
///     lookups when evaluating variables.
/// aux_names : list of str, optional
///     Names for the auxiliary scalars corresponding to ``aux``.
/// aliases : dict of {str: str or list[str]}, optional
///     Additional particle identifiers that reference one or more entries from ``p4_names``.
///     Requires a non-empty ``p4_names``.
///
/// Raises
/// ------
/// ValueError
///     If ``aliases`` is given without a non-empty ``p4_names``
///
/// Examples
/// --------
/// >>> from laddu import Event, Vec3  # doctest: +SKIP
/// >>> event = Event(  # doctest: +SKIP
/// ...     [Vec3(0.0, 0.0, 1.0).with_mass(0.0), Vec3(0.0, 0.0, 1.0).with_mass(0.0)],
/// ...     [0.3],
/// ...     1.0,
/// ...     p4_names=['kshort1', 'kshort2'],
/// ...     aux_names=['pol_angle'],
/// ...     aliases={'pair': ['kshort1', 'kshort2']},
/// ... )
/// >>> event.p4('pair')  # doctest: +SKIP
/// Vec4(px=0.0, py=0.0, pz=2.0, e=2.0)
/// >>> event.aux['pol_angle']  # doctest: +SKIP
/// 0.3
///
#[pyclass(name = "Event", module = "laddu", from_py_object)]
#[derive(Clone)]
pub struct PyEvent {
    // The wrapped event: event data paired with its metadata.
    pub event: OwnedEvent,
    // True when the event was built with explicit names/aliases or produced by a
    // Dataset iterator; name-based operations are rejected when this is false.
    has_metadata: bool,
}
262
#[pymethods]
impl PyEvent {
    // Python constructor; see the class docstring for the parameter descriptions.
    #[new]
    #[pyo3(signature = (p4s, aux, weight, *, p4_names=None, aux_names=None, aliases=None))]
    fn new(
        p4s: Vec<PyVec4>,
        aux: Vec<f64>,
        weight: f64,
        p4_names: Option<Vec<String>>,
        aux_names: Option<Vec<String>>,
        aliases: Option<Bound<PyDict>>,
    ) -> PyResult<Self> {
        let event = EventData {
            p4s: p4s.into_iter().map(|arr| arr.0).collect(),
            aux,
            weight,
        };
        let aliases = parse_aliases(aliases)?;

        // `p4_names` counts as missing when it is absent or an empty list.
        let missing_p4_names = p4_names
            .as_ref()
            .map(|names| names.is_empty())
            .unwrap_or(true);

        // Aliases refer to p4 names, so they cannot be resolved without them.
        if !aliases.is_empty() && missing_p4_names {
            return Err(PyValueError::new_err(
                "`aliases` requires `p4_names` so selections can be resolved",
            ));
        }

        // Any explicitly supplied names or aliases mark the event as carrying metadata.
        let metadata_provided = p4_names.is_some() || aux_names.is_some() || !aliases.is_empty();
        let metadata = if metadata_provided {
            let p4_names = p4_names.unwrap_or_default();
            let aux_names = aux_names.unwrap_or_default();
            let mut metadata = DatasetMetadata::new(p4_names, aux_names).map_err(PyErr::from)?;
            if !aliases.is_empty() {
                metadata
                    .add_p4_aliases(
                        aliases.into_iter().map(|(alias_name, selection)| {
                            (alias_name, selection.into_selection())
                        }),
                    )
                    .map_err(PyErr::from)?;
            }
            Arc::new(metadata)
        } else {
            // Placeholder metadata; `has_metadata` stays false so name lookups are rejected.
            Arc::new(DatasetMetadata::empty())
        };
        let event = OwnedEvent::new(Arc::new(event), metadata);
        Ok(Self {
            event,
            has_metadata: metadata_provided,
        })
    }
    // String form delegates to the wrapped event's `to_string()`.
    fn __str__(&self) -> String {
        self.event.to_string()
    }
    // `repr()` is identical to `str()`.
    fn __repr__(&self) -> String {
        self.__str__()
    }
    /// A mapping from particle name to 4-momentum for each particle in the event
    ///
    /// Raises
    /// ------
    /// ValueError
    ///     If the event has no associated metadata
    ///
    #[getter]
    fn p4s<'py>(&self, py: Python<'py>) -> PyResult<Py<PyDict>> {
        self.ensure_metadata()?;
        let mapping = PyDict::new(py);
        for (name, vec4) in self.event.p4s() {
            mapping.set_item(name, PyVec4(vec4))?;
        }
        Ok(mapping.into())
    }
    /// A mapping from name to value for the auxiliary scalars associated with the event
    ///
    /// Raises
    /// ------
    /// ValueError
    ///     If the event has no associated metadata
    ///
    #[getter]
    #[pyo3(name = "aux")]
    fn aux_mapping<'py>(&self, py: Python<'py>) -> PyResult<Py<PyDict>> {
        self.ensure_metadata()?;
        let mapping = PyDict::new(py);
        for (name, value) in self.event.aux() {
            mapping.set_item(name, value)?;
        }
        Ok(mapping.into())
    }
    /// The weight of this event relative to others in a Dataset
    ///
    #[getter]
    fn get_weight(&self) -> f64 {
        self.event.weight()
    }
    /// Get the sum of the four-momenta within the event with the given names
    ///
    /// Parameters
    /// ----------
    /// names : list of str
    ///     The names of the four-momenta to sum
    ///
    /// Returns
    /// -------
    /// Vec4
    ///     The result of summing the given four-momenta
    ///
    /// Raises
    /// ------
    /// ValueError
    ///     If the event has no associated metadata
    /// KeyError
    ///     If any name is not a known particle name
    ///
    fn get_p4_sum(&self, names: Vec<String>) -> PyResult<PyVec4> {
        let indices = self.resolve_p4_indices(&names)?;
        Ok(PyVec4(self.event.data().get_p4_sum(indices)))
    }
    /// Boost all the four-momenta in the event to the rest frame of the given set of
    /// four-momenta, selected by name.
    ///
    /// Parameters
    /// ----------
    /// names : list of str
    ///     The names of the four-momenta whose rest frame should be used for the boost
    ///
    /// Returns
    /// -------
    /// Event
    ///     The boosted event
    ///
    /// Raises
    /// ------
    /// ValueError
    ///     If the event has no associated metadata
    /// KeyError
    ///     If any name is not a known particle name
    ///
    pub fn boost_to_rest_frame_of(&self, names: Vec<String>) -> PyResult<Self> {
        let indices = self.resolve_p4_indices(&names)?;
        let boosted = self.event.data().boost_to_rest_frame_of(indices);
        // The boosted event shares the original event's metadata.
        Ok(Self {
            event: OwnedEvent::new(Arc::new(boosted), self.event.metadata_arc()),
            has_metadata: self.has_metadata,
        })
    }
    /// Get the value of a Variable on the given Event
    ///
    /// Parameters
    /// ----------
    /// variable : {laddu.Mass, laddu.CosTheta, laddu.Phi, laddu.PolAngle, laddu.PolMagnitude, laddu.Mandelstam}
    ///
    /// Returns
    /// -------
    /// float
    ///
    /// Raises
    /// ------
    /// ValueError
    ///     If the event has no associated metadata
    ///
    /// Notes
    /// -----
    /// Variables that rely on particle names require the event to carry metadata. Provide
    /// ``p4_names``/``aux_names`` when constructing the event or evaluate variables through a
    /// ``laddu.Dataset`` to ensure the metadata is available.
    ///
    /// Examples
    /// --------
    /// >>> from laddu import Event, Vec3  # doctest: +SKIP
    /// >>> from laddu.utils.variables import Mass  # doctest: +SKIP
    /// >>> event = Event(  # doctest: +SKIP
    /// ...     [Vec3(0.0, 0.0, 1.0).with_mass(0.938)],
    /// ...     [],
    /// ...     1.0,
    /// ...     p4_names=['proton'],
    /// ... )
    /// >>> event.evaluate(Mass(['proton']))  # doctest: +SKIP
    /// 0.938
    ///
    fn evaluate(&self, variable: Bound<'_, PyAny>) -> PyResult<f64> {
        let mut variable = variable.extract::<PyVariable>()?;
        let metadata = self.ensure_metadata()?;
        // Bind the extracted variable to this event's metadata before evaluating it.
        variable.bind_in_place(metadata)?;
        variable.evaluate_event(&self.event)
    }

    /// Retrieve a four-momentum by name.
    ///
    /// Raises
    /// ------
    /// ValueError
    ///     If the event has no associated metadata
    /// KeyError
    ///     If ``name`` is not a known particle name
    ///
    fn p4(&self, name: &str) -> PyResult<PyVec4> {
        self.ensure_metadata()?;
        self.event
            .p4(name)
            .map(PyVec4)
            .ok_or_else(|| PyKeyError::new_err(format!("Unknown particle name '{name}'")))
    }
}
434
435impl PyEvent {
436    fn ensure_metadata(&self) -> PyResult<&DatasetMetadata> {
437        if !self.has_metadata {
438            Err(PyValueError::new_err(
439                "Event has no associated metadata for name-based operations",
440            ))
441        } else {
442            Ok(self.event.metadata())
443        }
444    }
445
446    fn resolve_p4_indices(&self, names: &[String]) -> PyResult<Vec<usize>> {
447        let metadata = self.ensure_metadata()?;
448        let mut resolved = Vec::new();
449        for name in names {
450            let selection = metadata
451                .p4_selection(name)
452                .ok_or_else(|| PyKeyError::new_err(format!("Unknown particle name '{name}'")))?;
453            resolved.extend_from_slice(selection.indices());
454        }
455        Ok(resolved)
456    }
457
458    pub(crate) fn metadata_opt(&self) -> Option<&DatasetMetadata> {
459        self.has_metadata.then(|| self.event.metadata())
460    }
461}
462
#[doc(hidden)]
/// A set of Events
///
/// Datasets can be created from lists of Events or by using the constructor helpers
/// such as :func:`laddu.io.read_parquet`, :func:`laddu.io.read_root`, and
/// :func:`laddu.io.read_amptools`
///
/// Datasets can also be indexed directly to access individual Events
///
/// Parameters
/// ----------
/// events : list of Event
///     The Events that make up the Dataset.
/// p4_names : list of str, optional
///     Names assigned to each four-momentum; enables name-based lookups if provided.
/// aux_names : list of str, optional
///     Names for auxiliary scalars stored alongside the events.
/// aliases : dict of {str: str or list[str]}, optional
///     Additional particle identifiers that override aliases stored on the Events.
///
/// Notes
/// -----
/// Explicit metadata provided here takes precedence over metadata embedded in the
/// input Events.
///
/// Iterating a Dataset directly raises ``TypeError``; use ``dataset.events_local``
/// or ``dataset.events_global`` instead.
///
/// Examples
/// --------
/// >>> from laddu import Dataset, Event, Vec3  # doctest: +SKIP
/// >>> event = Event(  # doctest: +SKIP
/// ...     [Vec3(0.0, 0.0, 1.0).with_mass(0.0), Vec3(0.0, 0.0, -1.0).with_mass(0.938)],
/// ...     [0.4, 0.3],
/// ...     1.0,
/// ... )
/// >>> dataset = Dataset(  # doctest: +SKIP
/// ...     [event],
/// ...     p4_names=['beam', 'proton'],
/// ...     aux_names=['pol_magnitude', 'pol_angle'],
/// ...     aliases={'target': 'proton'},
/// ... )
/// >>> dataset[0].p4('target')  # doctest: +SKIP
/// Vec4(px=0.0, py=0.0, pz=-1.0, e=1.371073...)
///
// Newtype over a shared `Dataset`; mutating methods go through `Arc::make_mut`.
#[pyclass(name = "Dataset", module = "laddu", subclass, skip_from_py_object)]
#[derive(Clone)]
pub struct PyDataset(pub Arc<Dataset>);
507
// Python iterator (exposed as `laddu.ParquetChunkIter`) that yields one Dataset per
// item of the wrapped chunk iterator.
// NOTE(review): presumably produced by the chunked Parquet reader — the constructor
// is not visible in this part of the file.
#[pyclass(
    name = "ParquetChunkIter",
    module = "laddu",
    unsendable,
    skip_from_py_object
)]
pub struct PyParquetChunkIter {
    // Boxed source of chunks; each item is either a dataset or a laddu error.
    chunks: Box<dyn Iterator<Item = laddu_core::LadduResult<Arc<Dataset>>> + Send>,
}
517
518#[pymethods]
519impl PyParquetChunkIter {
520    fn __iter__(slf: PyRef<'_, Self>) -> Py<PyParquetChunkIter> {
521        slf.into()
522    }
523
524    fn __next__(&mut self) -> PyResult<Option<PyDataset>> {
525        match self.chunks.next() {
526            Some(Ok(dataset)) => Ok(Some(PyDataset(dataset))),
527            Some(Err(err)) => Err(PyErr::from(err)),
528            None => Ok(None),
529        }
530    }
531}
532
// Python iterator (exposed as `laddu.DatasetEventsGlobal`) returned by
// `Dataset.events_global`; wraps a `DatasetArcIter` and yields its events.
#[pyclass(
    name = "DatasetEventsGlobal",
    module = "laddu",
    unsendable,
    skip_from_py_object
)]
pub struct PyDatasetEventsGlobalIter {
    // Underlying iterator obtained from `shared_iter_global()` on the dataset.
    iter: DatasetArcIter,
}
542
543#[pymethods]
544impl PyDatasetEventsGlobalIter {
545    fn __iter__(slf: PyRef<'_, Self>) -> Py<PyDatasetEventsGlobalIter> {
546        slf.into()
547    }
548
549    fn __next__(&mut self) -> Option<PyEvent> {
550        self.iter.next().map(|rust_event| PyEvent {
551            event: rust_event,
552            has_metadata: true,
553        })
554    }
555}
556
// Python iterator (exposed as `laddu.DatasetEventsLocal`) returned by
// `Dataset.events_local`; walks the dataset's local events by index.
#[pyclass(
    name = "DatasetEventsLocal",
    module = "laddu",
    unsendable,
    skip_from_py_object
)]
pub struct PyDatasetEventsLocalIter {
    // Shared handle to the dataset being iterated.
    dataset: Arc<Dataset>,
    // Index of the next local event to yield.
    index: usize,
}
567
568#[pymethods]
569impl PyDatasetEventsLocalIter {
570    fn __iter__(slf: PyRef<'_, Self>) -> Py<PyDatasetEventsLocalIter> {
571        slf.into()
572    }
573
574    fn __next__(&mut self) -> Option<PyEvent> {
575        if self.index >= self.dataset.n_events_local() {
576            return None;
577        }
578        let event = self
579            .dataset
580            .event_local(self.index)
581            .expect("local event index should exist")
582            .to_event_data();
583        self.index += 1;
584        Some(PyEvent {
585            event: OwnedEvent::new(Arc::new(event), self.dataset.metadata_arc()),
586            has_metadata: true,
587        })
588    }
589}
590
591#[pymethods]
592impl PyDataset {
    // Python constructor. Builds a global dataset (`global = true`), i.e. the same
    // call as `from_events_global`; see the class docstring for the parameters.
    #[new]
    #[pyo3(signature = (events, *, p4_names=None, aux_names=None, aliases=None))]
    fn new(
        events: Vec<PyEvent>,
        p4_names: Option<Vec<String>>,
        aux_names: Option<Vec<String>>,
        aliases: Option<Bound<PyDict>>,
    ) -> PyResult<Self> {
        dataset_from_py_events(events, p4_names, aux_names, aliases, true)
    }
603
    /// Create a local Dataset from a sequence of Events.
    ///
    /// Parameters
    /// ----------
    /// events : list of Event
    ///     The Events to store.
    /// p4_names : list of str, optional
    ///     Names assigned to each four-momentum.
    /// aux_names : list of str, optional
    ///     Names for auxiliary scalars stored alongside the events.
    /// aliases : dict of {str: str or list[str]}, optional
    ///     Additional particle identifiers that reference one or more p4 names.
    #[staticmethod]
    #[pyo3(signature = (events, *, p4_names=None, aux_names=None, aliases=None))]
    fn from_events_local(
        events: Vec<PyEvent>,
        p4_names: Option<Vec<String>>,
        aux_names: Option<Vec<String>>,
        aliases: Option<Bound<PyDict>>,
    ) -> PyResult<Self> {
        // `global = false` selects the local constructor path.
        dataset_from_py_events(events, p4_names, aux_names, aliases, false)
    }
615
    /// Create a global Dataset from a sequence of Events.
    ///
    /// Under MPI, every rank must pass the same global event list. Rows are
    /// partitioned across ranks using laddu's canonical contiguous partition.
    ///
    /// Parameters
    /// ----------
    /// events : list of Event
    ///     The Events to store.
    /// p4_names : list of str, optional
    ///     Names assigned to each four-momentum.
    /// aux_names : list of str, optional
    ///     Names for auxiliary scalars stored alongside the events.
    /// aliases : dict of {str: str or list[str]}, optional
    ///     Additional particle identifiers that reference one or more p4 names.
    #[staticmethod]
    #[pyo3(signature = (events, *, p4_names=None, aux_names=None, aliases=None))]
    fn from_events_global(
        events: Vec<PyEvent>,
        p4_names: Option<Vec<String>>,
        aux_names: Option<Vec<String>>,
        aliases: Option<Bound<PyDict>>,
    ) -> PyResult<Self> {
        // `global = true` selects the global constructor path.
        dataset_from_py_events(events, p4_names, aux_names, aliases, true)
    }
630
631    /// Create an empty local Dataset with explicit metadata.
632    ///
633    /// Parameters
634    /// ----------
635    /// p4_names : list of str
636    ///     Names of the four-momenta expected in each pushed event.
637    /// aux_names : list of str, optional
638    ///     Names of auxiliary scalar columns expected in each pushed event.
639    /// aliases : dict of {str: str or list[str]}, optional
640    ///     Additional particle identifiers that reference one or more p4 names.
641    #[staticmethod]
642    #[pyo3(signature = (*, p4_names, aux_names=None, aliases=None))]
643    fn empty_local(
644        p4_names: Vec<String>,
645        aux_names: Option<Vec<String>>,
646        aliases: Option<Bound<'_, PyDict>>,
647    ) -> PyResult<Self> {
648        let metadata =
649            metadata_from_names_and_aliases(p4_names, aux_names.unwrap_or_default(), aliases)?;
650        Ok(Self(Arc::new(Dataset::empty_local(metadata))))
651    }
652
653    /// Append one named event to the local Dataset partition.
654    ///
655    /// Parameters
656    /// ----------
657    /// p4 : dict[str, Vec4]
658    ///     Four-momenta keyed by p4 name.
659    /// aux : dict[str, float], optional
660    ///     Auxiliary scalar values keyed by aux name.
661    /// weight : float, optional
662    ///     Event weight. Defaults to 1.0.
663    #[pyo3(signature = (*, p4, aux=None, weight=1.0))]
664    fn push_event_local(
665        &mut self,
666        p4: Bound<'_, PyDict>,
667        aux: Option<Bound<'_, PyDict>>,
668        weight: f64,
669    ) -> PyResult<()> {
670        let p4 = parse_p4_mapping(p4)?;
671        let aux = parse_aux_mapping(aux)?;
672        Arc::make_mut(&mut self.0)
673            .push_event_named_local(p4, aux, weight)
674            .map_err(PyErr::from)
675    }
676
677    /// Append one named event collectively as a single global event.
678    ///
679    /// Under MPI, all ranks must call this method in the same order. Exactly
680    /// one rank stores each event, selected by round-robin global event index.
681    #[pyo3(signature = (*, p4, aux=None, weight=1.0))]
682    fn push_event_global(
683        &mut self,
684        p4: Bound<'_, PyDict>,
685        aux: Option<Bound<'_, PyDict>>,
686        weight: f64,
687    ) -> PyResult<()> {
688        let p4 = parse_p4_mapping(p4)?;
689        let aux = parse_aux_mapping(aux)?;
690        Arc::make_mut(&mut self.0)
691            .push_event_named_global(p4, aux, weight)
692            .map_err(PyErr::from)
693    }
694
695    /// Add a four-momentum column to the local Dataset partition.
696    #[pyo3(signature = (name, values))]
697    fn add_p4_column_local(&mut self, name: String, values: Vec<PyVec4>) -> PyResult<()> {
698        Arc::make_mut(&mut self.0)
699            .add_p4_column_local(name, parse_p4_column(values))
700            .map_err(PyErr::from)
701    }
702
703    /// Add an auxiliary scalar column to the local Dataset partition.
704    #[pyo3(signature = (name, values))]
705    fn add_aux_column_local(&mut self, name: String, values: Bound<'_, PyAny>) -> PyResult<()> {
706        let values = extract_numeric_column(values, &name)?;
707        Arc::make_mut(&mut self.0)
708            .add_aux_column_local(name, values)
709            .map_err(PyErr::from)
710    }
711
712    /// Add a four-momentum column collectively across all MPI ranks.
713    ///
714    /// Under MPI, all ranks must call this method in the same order with the
715    /// same column name. Each rank supplies values for its locally stored events.
716    #[pyo3(signature = (name, values))]
717    fn add_p4_column_global(&mut self, name: String, values: Vec<PyVec4>) -> PyResult<()> {
718        Arc::make_mut(&mut self.0)
719            .add_p4_column_global(name, parse_p4_column(values))
720            .map_err(PyErr::from)
721    }
722
723    /// Add an auxiliary scalar column collectively across all MPI ranks.
724    ///
725    /// Under MPI, all ranks must call this method in the same order with the
726    /// same column name. Each rank supplies values for its locally stored events.
727    #[pyo3(signature = (name, values))]
728    fn add_aux_column_global(&mut self, name: String, values: Bound<'_, PyAny>) -> PyResult<()> {
729        let values = extract_numeric_column(values, &name)?;
730        Arc::make_mut(&mut self.0)
731            .add_aux_column_global(name, values)
732            .map_err(PyErr::from)
733    }
734
    // `len(dataset)` reports the same count as the `n_events` property.
    fn __len__(&self) -> usize {
        self.0.n_events()
    }
    // Direct iteration is deliberately rejected: callers must pick `events_local`
    // or `events_global`. This always raises `TypeError`.
    fn __iter__(&self) -> PyResult<()> {
        Err(PyTypeError::new_err(
            "Dataset iteration is explicit; use dataset.events_local or dataset.events_global",
        ))
    }
    /// Get the number of events owned by the current rank.
    ///
    /// Returns
    /// -------
    /// n_events_local : int
    ///     The number of rank-local Events
    #[getter]
    fn n_events_local(&self) -> usize {
        self.0.n_events_local()
    }
748    fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyDataset> {
749        if let Ok(other_ds) = other.extract::<PyRef<PyDataset>>() {
750            Ok(PyDataset(Arc::new(self.0.as_ref() + other_ds.0.as_ref())))
751        } else if let Ok(other_int) = other.extract::<usize>() {
752            if other_int == 0 {
753                Ok(self.clone())
754            } else {
755                Err(PyTypeError::new_err(
756                    "Addition with an integer for this type is only defined for 0",
757                ))
758            }
759        } else {
760            Err(PyTypeError::new_err("Unsupported operand type for +"))
761        }
762    }
763    fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyDataset> {
764        if let Ok(other_ds) = other.extract::<PyRef<PyDataset>>() {
765            Ok(PyDataset(Arc::new(other_ds.0.as_ref() + self.0.as_ref())))
766        } else if let Ok(other_int) = other.extract::<usize>() {
767            if other_int == 0 {
768                Ok(self.clone())
769            } else {
770                Err(PyTypeError::new_err(
771                    "Addition with an integer for this type is only defined for 0",
772                ))
773            }
774        } else {
775            Err(PyTypeError::new_err("Unsupported operand type for +"))
776        }
777    }
778
779    fn __repr__(&self) -> String {
780        format!(
781            "Dataset(n_events={}, n_events_local={}, p4_names={:?}, aux_names={:?})",
782            self.0.n_events_global(),
783            self.0.n_events_local(),
784            self.0.p4_names(),
785            self.0.aux_names()
786        )
787    }
788
    // `str(dataset)` is identical to `repr(dataset)`.
    fn __str__(&self) -> String {
        self.__repr__()
    }
792
    /// Get the number of Events in the Dataset
    ///
    /// Notes
    /// -----
    /// When MPI is enabled, this returns the global event count.
    /// It therefore matches ``len(dataset)`` and the valid range for ``dataset[i]``.
    /// Use ``n_events_local`` for the number of Events stored on the current rank.
    ///
    /// Returns
    /// -------
    /// n_events : int
    ///     The number of Events
    ///
    #[getter]
    fn n_events(&self) -> usize {
        self.0.n_events()
    }
    /// Alias for ``n_events``.
    ///
    /// Returns
    /// -------
    /// n_events_global : int
    ///     The global number of Events
    #[getter]
    fn n_events_global(&self) -> usize {
        self.0.n_events_global()
    }
    /// Four-momentum (particle) names associated with this Dataset.
    ///
    /// Returns
    /// -------
    /// p4_names : list of str
    ///     A copy of the Dataset's four-momentum names
    #[getter]
    fn p4_names(&self) -> Vec<String> {
        self.0.p4_names().to_vec()
    }
    /// Auxiliary scalar names associated with this Dataset.
    ///
    /// Returns
    /// -------
    /// aux_names : list of str
    ///     A copy of the Dataset's auxiliary scalar names
    #[getter]
    fn aux_names(&self) -> Vec<String> {
        self.0.aux_names().to_vec()
    }
824
    /// Get the weighted number of Events in the Dataset
    ///
    /// Notes
    /// -----
    /// When MPI is enabled, this returns the global weighted event count.
    ///
    /// Returns
    /// -------
    /// n_events_weighted : float
    ///     The sum of all Event weights
    ///
    #[getter]
    fn n_events_weighted(&self) -> f64 {
        self.0.n_events_weighted()
    }
    /// Alias for ``n_events_weighted``.
    ///
    /// Returns
    /// -------
    /// n_events_weighted_global : float
    ///     The global sum of Event weights
    #[getter]
    fn n_events_weighted_global(&self) -> f64 {
        self.0.n_events_weighted_global()
    }
    /// Get the weighted number of local Events in the Dataset
    ///
    /// Notes
    /// -----
    /// When MPI is enabled, this returns the sum of the current rank's Event
    /// weights.
    ///
    /// Returns
    /// -------
    /// n_events_weighted_local : float
    ///     The sum of the current rank's Event weights
    ///
    #[getter]
    fn n_events_weighted_local(&self) -> f64 {
        self.0.n_events_weighted_local()
    }
861    /// The weights associated with the Dataset
862    ///
863    /// Notes
864    /// -----
865    /// When MPI is enabled, this returns the global weight vector in dataset order.
866    ///
867    /// Returns
868    /// -------
869    /// weights : array_like
870    ///     A ``numpy`` array of global Event weights
871    ///
872    #[getter]
873    fn weights<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<f64>> {
874        PyArray1::from_slice(py, &self.0.weights())
875    }
    /// Alias for ``weights``.
    ///
    /// Returns
    /// -------
    /// weights : array_like
    ///     A ``numpy`` array built from the wrapped dataset's ``weights_global``
    ///
    #[getter]
    fn weights_global<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<f64>> {
        // The weights are copied into a freshly created NumPy array on every access.
        PyArray1::from_slice(py, &self.0.weights_global())
    }
    /// The weights associated with the Dataset on the current rank.
    ///
    /// Notes
    /// -----
    /// This is the explicit rank-local counterpart to ``weights``.
    ///
    /// Returns
    /// -------
    /// weights : array_like
    ///     A ``numpy`` array of rank-local Event weights
    ///
    /// See Also
    /// --------
    /// weights
    /// weights_global
    ///
    #[getter]
    fn weights_local<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<f64>> {
        // The weights are copied into a freshly created NumPy array on every access.
        PyArray1::from_slice(py, &self.0.weights_local())
    }
    /// Iterate over global Events in dataset order.
    ///
    /// Notes
    /// -----
    /// Every access to this property creates a new iterator.
    ///
    /// See Also
    /// --------
    /// events_local
    ///
    #[getter]
    fn events_global(&self) -> PyDatasetEventsGlobalIter {
        // The Python-facing iterator simply wraps the core dataset's shared global iterator.
        PyDatasetEventsGlobalIter {
            iter: self.0.shared_iter_global(),
        }
    }
    /// Iterate over Events stored on the current rank.
    ///
    /// Notes
    /// -----
    /// This is the explicit rank-local counterpart to ``events_global``.
    /// Every access to this property creates a new iterator positioned at the
    /// first local Event.
    #[getter]
    fn events_local(&self) -> PyDatasetEventsLocalIter {
        // The iterator keeps its own handle to the dataset and a cursor starting at 0.
        PyDatasetEventsLocalIter {
            dataset: self.0.clone(),
            index: 0,
        }
    }
915    /// Retrieve a four-momentum by particle name for the event at ``index``.
916    fn p4_by_name(&self, index: usize, name: &str) -> PyResult<PyVec4> {
917        self.0
918            .p4_by_name(index, name)
919            .map(PyVec4)
920            .ok_or_else(|| PyKeyError::new_err(format!("Unknown particle name '{name}'")))
921    }
922    /// Retrieve an auxiliary scalar by name for the event at ``index``.
923    fn aux_by_name(&self, index: usize, name: &str) -> PyResult<f64> {
924        self.0
925            .aux_by_name(index, name)
926            .ok_or_else(|| PyKeyError::new_err(format!("Unknown auxiliary name '{name}'")))
927    }
928    /// Alias for ``dataset[index]``.
929    ///
930    /// Notes
931    /// -----
932    /// This preserves the default global indexing semantics under MPI.
933    fn event_global(&self, index: usize) -> PyResult<PyEvent> {
934        let event = self
935            .0
936            .event_global(index)
937            .map_err(|_| PyIndexError::new_err("index out of range"))?;
938        Ok(PyEvent {
939            event,
940            has_metadata: true,
941        })
942    }
943    fn __getitem__<'py>(
944        &self,
945        py: Python<'py>,
946        index: Bound<'py, PyAny>,
947    ) -> PyResult<Bound<'py, PyAny>> {
948        if let Ok(value) = self.evaluate(py, index.clone()) {
949            value.into_bound_py_any(py)
950        } else if let Ok(index) = index.extract::<usize>() {
951            let event = self
952                .0
953                .event_global(index)
954                .map_err(|_| PyIndexError::new_err("index out of range"))?;
955            PyEvent {
956                event,
957                has_metadata: true,
958            }
959            .into_bound_py_any(py)
960        } else {
961            Err(PyTypeError::new_err(
962                "Unsupported index type (int or Variable)",
963            ))
964        }
965    }
966    /// Separates a Dataset into histogram bins by a Variable value
967    ///
968    /// Parameters
969    /// ----------
970    /// variable : {laddu.Mass, laddu.CosTheta, laddu.Phi, laddu.PolAngle, laddu.PolMagnitude, laddu.Mandelstam}
971    ///     The Variable by which each Event is binned
972    /// bins : int
973    ///     The number of equally-spaced bins
974    /// range : tuple[float, float]
975    ///     The minimum and maximum bin edges
976    ///
977    /// Returns
978    /// -------
979    /// datasets : BinnedDataset
980    ///     A structure containing Datasets binned by the given `variable`
981    ///
982    /// See Also
983    /// --------
984    /// laddu.Mass
985    /// laddu.CosTheta
986    /// laddu.Phi
987    /// laddu.PolAngle
988    /// laddu.PolMagnitude
989    /// laddu.Mandelstam
990    ///
991    /// Examples
992    /// --------
993    /// >>> from laddu.utils.variables import Mass  # doctest: +SKIP
994    /// >>> binned = dataset.bin_by(Mass(['kshort1']), bins=10, range=(0.9, 1.5))  # doctest: +SKIP
995    /// >>> len(binned)  # doctest: +SKIP
996    /// 10
997    ///
998    /// Raises
999    /// ------
1000    /// TypeError
1001    ///     If the given `variable` is not a valid variable
1002    ///
1003    #[pyo3(signature = (variable, bins, range))]
1004    fn bin_by(
1005        &self,
1006        variable: Bound<'_, PyAny>,
1007        bins: usize,
1008        range: (f64, f64),
1009    ) -> PyResult<PyBinnedDataset> {
1010        let py_variable = variable.extract::<PyVariable>()?;
1011        let bound_variable = py_variable.bound(self.0.metadata())?;
1012        Ok(PyBinnedDataset(self.0.bin_by(
1013            bound_variable,
1014            bins,
1015            range,
1016        )?))
1017    }
1018    /// Filter the Dataset by a given VariableExpression, selecting events for which the expression returns ``True``.
1019    ///
1020    /// Parameters
1021    /// ----------
1022    /// expression : VariableExpression
1023    ///     The expression with which to filter the Dataset
1024    ///
1025    /// Returns
1026    /// -------
1027    /// Dataset
1028    ///     The filtered Dataset
1029    ///
1030    /// Examples
1031    /// --------
1032    /// >>> from laddu.utils.variables import Mass  # doctest: +SKIP
1033    /// >>> heavy = dataset.filter(Mass(['kshort1']) > 1.0)  # doctest: +SKIP
1034    ///
1035    pub fn filter(&self, expression: &PyVariableExpression) -> PyResult<PyDataset> {
1036        Ok(PyDataset(
1037            self.0.filter(&expression.0).map_err(PyErr::from)?,
1038        ))
1039    }
    /// Generate a new bootstrapped Dataset by randomly resampling the original with replacement
    ///
    /// The new Dataset is resampled with a random generator seeded by the provided `seed`
    ///
    /// Parameters
    /// ----------
    /// seed : int
    ///     The random seed used in the resampling process
    ///
    /// Returns
    /// -------
    /// Dataset
    ///     A bootstrapped Dataset
    ///
    /// Examples
    /// --------
    /// >>> replica = dataset.bootstrap(2024)  # doctest: +SKIP
    /// >>> len(replica) == len(dataset)  # doctest: +SKIP
    /// True
    ///
    fn bootstrap(&self, seed: usize) -> PyDataset {
        // The seed is forwarded unchanged; the result is wrapped as a new Python Dataset.
        PyDataset(self.0.bootstrap(seed))
    }
    /// Boost all the four-momenta in all events to the rest frame of the given set of
    /// named four-momenta.
    ///
    /// Parameters
    /// ----------
    /// names : list of str
    ///     The names of the four-momenta defining the rest frame
    ///
    /// Returns
    /// -------
    /// Dataset
    ///     The boosted dataset, returned as a new Dataset object
    ///
    /// Examples
    /// --------
    /// >>> dataset.boost_to_rest_frame_of(['kshort1', 'kshort2'])  # doctest: +SKIP
    ///
    pub fn boost_to_rest_frame_of(&self, names: Vec<String>) -> PyDataset {
        // The names are only borrowed by the core call; the result is wrapped for Python.
        PyDataset(self.0.boost_to_rest_frame_of(&names))
    }
1083    /// Get the value of a Variable over every event in the Dataset.
1084    ///
1085    /// Parameters
1086    /// ----------
1087    /// variable : {laddu.Mass, laddu.CosTheta, laddu.Phi, laddu.PolAngle, laddu.PolMagnitude, laddu.Mandelstam}
1088    ///
1089    /// Returns
1090    /// -------
1091    /// values : array_like
1092    ///
1093    /// Examples
1094    /// --------
1095    /// >>> from laddu.utils.variables import Mass  # doctest: +SKIP
1096    /// >>> masses = dataset.evaluate(Mass(['kshort1']))  # doctest: +SKIP
1097    /// >>> masses.shape  # doctest: +SKIP
1098    /// (len(dataset),)
1099    ///
1100    fn evaluate<'py>(
1101        &self,
1102        py: Python<'py>,
1103        variable: Bound<'py, PyAny>,
1104    ) -> PyResult<Bound<'py, PyArray1<f64>>> {
1105        let variable = variable.extract::<PyVariable>()?;
1106        let bound_variable = variable.bound(self.0.metadata())?;
1107        let values = self.0.evaluate(&bound_variable).map_err(PyErr::from)?;
1108        Ok(PyArray1::from_vec(py, values))
1109    }
1110}
1111
1112/// Read a Dataset from a Parquet file.
1113///
1114/// Examples
1115/// --------
1116/// >>> import laddu.io as ldio  # doctest: +SKIP
1117/// >>> dataset = ldio.read_parquet(  # doctest: +SKIP
1118/// ...     'events.parquet',
1119/// ...     p4s=['beam', 'proton'],
1120/// ...     aux=['pol_magnitude', 'pol_angle'],
1121/// ...     aliases={'target': 'proton'},
1122/// ... )
1123/// >>> dataset.p4_names  # doctest: +SKIP
1124/// ['beam', 'proton']
1125#[pyfunction]
1126#[pyo3(signature = (path, *, p4s=None, aux=None, aliases=None))]
1127pub fn read_parquet(
1128    path: Bound<PyAny>,
1129    p4s: Option<Vec<String>>,
1130    aux: Option<Vec<String>>,
1131    aliases: Option<Bound<PyDict>>,
1132) -> PyResult<PyDataset> {
1133    let path_str = parse_dataset_path(path)?;
1134    let mut read_options = DatasetReadOptions::default();
1135    if let Some(p4s) = p4s {
1136        read_options = read_options.p4_names(p4s);
1137    }
1138    if let Some(aux) = aux {
1139        read_options = read_options.aux_names(aux);
1140    }
1141    for (alias_name, selection) in parse_aliases(aliases)?.into_iter() {
1142        read_options = read_options.alias(alias_name, selection);
1143    }
1144    let dataset = core_read_parquet(&path_str, &read_options)?;
1145    Ok(PyDataset(dataset))
1146}
1147
1148/// Read a Dataset from a Parquet file in chunks.
1149#[pyfunction]
1150#[pyo3(signature = (path, *, p4s=None, aux=None, aliases=None, chunk_size=None))]
1151pub fn read_parquet_chunked(
1152    path: Bound<PyAny>,
1153    p4s: Option<Vec<String>>,
1154    aux: Option<Vec<String>>,
1155    aliases: Option<Bound<PyDict>>,
1156    chunk_size: Option<usize>,
1157) -> PyResult<PyParquetChunkIter> {
1158    let path_str = parse_dataset_path(path)?;
1159    let mut read_options = DatasetReadOptions::default();
1160    if let Some(p4s) = p4s {
1161        read_options = read_options.p4_names(p4s);
1162    }
1163    if let Some(aux) = aux {
1164        read_options = read_options.aux_names(aux);
1165    }
1166    if let Some(chunk_size) = chunk_size {
1167        read_options = read_options.chunk_size(chunk_size);
1168    }
1169    for (alias_name, selection) in parse_aliases(aliases)?.into_iter() {
1170        read_options = read_options.alias(alias_name, selection);
1171    }
1172
1173    let chunks = core_read_parquet_chunks_with_options(&path_str, &read_options)?;
1174    Ok(PyParquetChunkIter {
1175        chunks: Box::new(chunks),
1176    })
1177}
1178
1179/// Read a Dataset from a ROOT file using the oxyroot backend.
1180///
1181/// Examples
1182/// --------
1183/// >>> import laddu.io as ldio  # doctest: +SKIP
1184/// >>> dataset = ldio.read_root(  # doctest: +SKIP
1185/// ...     'events.root',
1186/// ...     tree='kin',
1187/// ...     p4s=['beam', 'proton'],
1188/// ...     aux=['pol_magnitude', 'pol_angle'],
1189/// ... )
1190/// >>> dataset.aux_names  # doctest: +SKIP
1191/// ['pol_magnitude', 'pol_angle']
1192#[pyfunction]
1193#[pyo3(signature = (path, *, tree=None, p4s=None, aux=None, aliases=None))]
1194pub fn read_root(
1195    path: Bound<PyAny>,
1196    tree: Option<String>,
1197    p4s: Option<Vec<String>>,
1198    aux: Option<Vec<String>>,
1199    aliases: Option<Bound<PyDict>>,
1200) -> PyResult<PyDataset> {
1201    let path_str = parse_dataset_path(path)?;
1202    let mut read_options = DatasetReadOptions::default();
1203    if let Some(p4s) = p4s {
1204        read_options = read_options.p4_names(p4s);
1205    }
1206    if let Some(aux) = aux {
1207        read_options = read_options.aux_names(aux);
1208    }
1209    if let Some(tree) = tree {
1210        read_options = read_options.tree(tree);
1211    }
1212    for (alias_name, selection) in parse_aliases(aliases)?.into_iter() {
1213        read_options = read_options.alias(alias_name, selection);
1214    }
1215    let dataset = core_read_root(&path_str, &read_options)?;
1216    Ok(PyDataset(dataset))
1217}
1218
1219/// Write a Dataset to a Parquet file.
1220#[pyfunction]
1221#[pyo3(signature = (dataset, path, *, chunk_size=None, precision="f64"))]
1222pub fn write_parquet(
1223    dataset: &PyDataset,
1224    path: Bound<PyAny>,
1225    chunk_size: Option<usize>,
1226    precision: &str,
1227) -> PyResult<()> {
1228    let path_str = parse_dataset_path(path)?;
1229    let mut write_options = DatasetWriteOptions::default();
1230    if let Some(size) = chunk_size {
1231        write_options.batch_size = size.max(1);
1232    }
1233    write_options.precision = parse_precision_arg(Some(precision))?;
1234    core_write_parquet(dataset.0.as_ref(), &path_str, &write_options).map_err(PyErr::from)
1235}
1236
1237/// Write a Dataset to a ROOT file using the oxyroot backend.
1238#[pyfunction]
1239#[pyo3(signature = (dataset, path, *, tree=None, chunk_size=None, precision="f64"))]
1240pub fn write_root(
1241    dataset: &PyDataset,
1242    path: Bound<PyAny>,
1243    tree: Option<String>,
1244    chunk_size: Option<usize>,
1245    precision: &str,
1246) -> PyResult<()> {
1247    let path_str = parse_dataset_path(path)?;
1248    let mut write_options = DatasetWriteOptions::default();
1249    if let Some(name) = tree {
1250        write_options.tree = Some(name);
1251    }
1252    if let Some(size) = chunk_size {
1253        write_options.batch_size = size.max(1);
1254    }
1255    write_options.precision = parse_precision_arg(Some(precision))?;
1256    core_write_root(dataset.0.as_ref(), &path_str, &write_options).map_err(PyErr::from)
1257}
1258
1259#[doc(hidden)]
1260/// Build a Dataset from columnar arrays.
1261///
1262/// This is the canonical high-throughput ingestion path used by Python reader helpers.
1263///
1264/// Examples
1265/// --------
1266/// >>> import laddu.io as ldio  # doctest: +SKIP
1267/// >>> dataset = ldio.from_columns(  # doctest: +SKIP
1268/// ...     {
1269/// ...         'beam_px': [0.0],
1270/// ...         'beam_py': [0.0],
1271/// ...         'beam_pz': [8.5],
1272/// ...         'beam_e': [8.5],
1273/// ...         'proton_px': [0.0],
1274/// ...         'proton_py': [0.0],
1275/// ...         'proton_pz': [-0.2],
1276/// ...         'proton_e': [0.959],
1277/// ...         'pol_magnitude': [0.4],
1278/// ...         'pol_angle': [0.3],
1279/// ...         'weight': [1.0],
1280/// ...     },
1281/// ...     p4s=['beam', 'proton'],
1282/// ...     aux=['pol_magnitude', 'pol_angle'],
1283/// ...     aliases={'target': 'proton'},
1284/// ... )
1285/// >>> dataset[0].p4('target')  # doctest: +SKIP
1286/// Vec4(px=0.0, py=0.0, pz=-0.2, e=0.959)
1287#[pyfunction]
1288#[pyo3(signature = (columns, *, p4s=None, aux=None, aliases=None))]
1289pub fn from_columns(
1290    columns: Bound<'_, PyDict>,
1291    p4s: Option<Vec<String>>,
1292    aux: Option<Vec<String>>,
1293    aliases: Option<Bound<'_, PyDict>>,
1294) -> PyResult<PyDataset> {
1295    let column_names = columns
1296        .iter()
1297        .map(|(key, _)| key.extract::<String>())
1298        .collect::<PyResult<Vec<_>>>()?;
1299
1300    let (detected_p4_names, detected_aux_names) =
1301        infer_p4_and_aux_names_from_columns(&column_names);
1302    let p4_names = p4s.unwrap_or(detected_p4_names);
1303    if p4_names.is_empty() {
1304        let mut partial_components: std::collections::BTreeMap<
1305            String,
1306            std::collections::BTreeSet<&str>,
1307        > = std::collections::BTreeMap::new();
1308        for column_name in &column_names {
1309            let lowered = column_name.to_ascii_lowercase();
1310            for suffix in P4_COMPONENT_SUFFIXES {
1311                if lowered.ends_with(suffix) && column_name.len() > suffix.len() {
1312                    let prefix = column_name[..column_name.len() - suffix.len()].to_string();
1313                    partial_components.entry(prefix).or_default().insert(suffix);
1314                }
1315            }
1316        }
1317        if let Some((prefix, present)) = partial_components.iter().next() {
1318            if present.len() < P4_COMPONENT_SUFFIXES.len() {
1319                let missing = P4_COMPONENT_SUFFIXES
1320                    .iter()
1321                    .filter(|suffix| !present.contains(**suffix))
1322                    .map(|suffix| format!("{prefix}{suffix}"))
1323                    .collect::<Vec<_>>()
1324                    .join(", ");
1325                return Err(PyKeyError::new_err(format!(
1326                    "Missing components [{missing}] for four-momentum '{prefix}'"
1327                )));
1328            }
1329        }
1330        return Err(PyValueError::new_err(
1331            "No four-momentum columns found (expected *_px, *_py, *_pz, *_e)",
1332        ));
1333    }
1334
1335    let aux_names = aux.unwrap_or(detected_aux_names);
1336    let p4_component_columns =
1337        resolve_p4_component_columns(&column_names, &p4_names).map_err(PyErr::from)?;
1338    let resolved_aux_columns =
1339        resolve_columns_case_insensitive(&column_names, &aux_names).map_err(PyErr::from)?;
1340
1341    let n_events = {
1342        let first_name = p4_component_columns
1343            .first()
1344            .map(|components| components[0].clone())
1345            .ok_or_else(|| PyKeyError::new_err("Missing required p4 column"))?;
1346        let values = extract_numeric_column(
1347            columns
1348                .get_item(first_name.as_str())?
1349                .ok_or_else(|| PyKeyError::new_err("Missing required p4 column"))?,
1350            &first_name,
1351        )?;
1352        values.len()
1353    };
1354
1355    let mut p4_columns: Vec<[Vec<f64>; 4]> = Vec::with_capacity(p4_names.len());
1356    for component_names in &p4_component_columns {
1357        let px = extract_numeric_column(
1358            columns
1359                .get_item(component_names[0].as_str())?
1360                .ok_or_else(|| PyKeyError::new_err(format!("Missing {}", component_names[0])))?,
1361            component_names[0].as_str(),
1362        )?;
1363        let py = extract_numeric_column(
1364            columns
1365                .get_item(component_names[1].as_str())?
1366                .ok_or_else(|| PyKeyError::new_err(format!("Missing {}", component_names[1])))?,
1367            component_names[1].as_str(),
1368        )?;
1369        let pz = extract_numeric_column(
1370            columns
1371                .get_item(component_names[2].as_str())?
1372                .ok_or_else(|| PyKeyError::new_err(format!("Missing {}", component_names[2])))?,
1373            component_names[2].as_str(),
1374        )?;
1375        let e = extract_numeric_column(
1376            columns
1377                .get_item(component_names[3].as_str())?
1378                .ok_or_else(|| PyKeyError::new_err(format!("Missing {}", component_names[3])))?,
1379            component_names[3].as_str(),
1380        )?;
1381        if px.len() != n_events
1382            || py.len() != n_events
1383            || pz.len() != n_events
1384            || e.len() != n_events
1385        {
1386            return Err(PyValueError::new_err(
1387                "All p4 components must have the same length",
1388            ));
1389        }
1390        p4_columns.push([px, py, pz, e]);
1391    }
1392
1393    let mut aux_columns: Vec<Vec<f64>> = Vec::with_capacity(resolved_aux_columns.len());
1394    for (aux_name, aux_column_name) in aux_names.iter().zip(&resolved_aux_columns) {
1395        let values = extract_numeric_column(
1396            columns.get_item(aux_column_name.as_str())?.ok_or_else(|| {
1397                PyKeyError::new_err(format!("Missing auxiliary column '{aux_name}'"))
1398            })?,
1399            aux_name,
1400        )?;
1401        if values.len() != n_events {
1402            return Err(PyValueError::new_err(format!(
1403                "Auxiliary column '{aux_name}' length does not match p4 columns"
1404            )));
1405        }
1406        aux_columns.push(values);
1407    }
1408
1409    let weights = if let Some(weight_column_name) = resolve_optional_weight_column(&column_names) {
1410        let weight_values = columns
1411            .get_item(weight_column_name.as_str())?
1412            .ok_or_else(|| PyKeyError::new_err("Missing weight column"))?;
1413        let values = extract_numeric_column(weight_values, "weight")?;
1414        if values.len() != n_events {
1415            return Err(PyValueError::new_err(
1416                "Column 'weight' length does not match p4 columns",
1417            ));
1418        }
1419        values
1420    } else {
1421        vec![1.0; n_events]
1422    };
1423
1424    let parsed_aliases = parse_aliases(aliases)?;
1425    let mut metadata =
1426        DatasetMetadata::new(p4_names.clone(), aux_names.clone()).map_err(PyErr::from)?;
1427    if !parsed_aliases.is_empty() {
1428        metadata
1429            .add_p4_aliases(
1430                parsed_aliases
1431                    .into_iter()
1432                    .map(|(alias_name, selection)| (alias_name, selection.into_selection())),
1433            )
1434            .map_err(PyErr::from)?;
1435    }
1436
1437    let p4_columns = p4_columns
1438        .into_iter()
1439        .map(|components| {
1440            (0..n_events)
1441                .map(|event_idx| {
1442                    laddu_core::vectors::Vec4::new(
1443                        components[0][event_idx],
1444                        components[1][event_idx],
1445                        components[2][event_idx],
1446                        components[3][event_idx],
1447                    )
1448                })
1449                .collect::<Vec<_>>()
1450        })
1451        .collect::<Vec<_>>();
1452
1453    Ok(PyDataset(Arc::new(Dataset::from_columns_global(
1454        metadata,
1455        p4_columns,
1456        aux_columns,
1457        weights,
1458    )?)))
1459}
1460
1461/// A collection of Datasets binned by a Variable
1462///
1463/// BinnedDatasets can be indexed directly to access the underlying Datasets by bin
1464///
1465/// See Also
1466/// --------
1467/// laddu.Dataset.bin_by
1468///
1469#[pyclass(name = "BinnedDataset", module = "laddu", skip_from_py_object)]
1470pub struct PyBinnedDataset(BinnedDataset);
1471
1472#[pymethods]
1473impl PyBinnedDataset {
1474    fn __len__(&self) -> usize {
1475        self.0.n_bins()
1476    }
1477    /// The number of bins in the BinnedDataset
1478    ///
1479    #[getter]
1480    fn n_bins(&self) -> usize {
1481        self.0.n_bins()
1482    }
1483    /// The minimum and maximum values of the binning Variable used to create this BinnedDataset
1484    ///
1485    #[getter]
1486    fn range(&self) -> (f64, f64) {
1487        self.0.range()
1488    }
1489    /// The edges of each bin in the BinnedDataset
1490    ///
1491    #[getter]
1492    fn edges<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<f64>> {
1493        PyArray1::from_slice(py, &self.0.edges())
1494    }
1495    fn __getitem__(&self, index: usize) -> PyResult<PyDataset> {
1496        self.0
1497            .get(index)
1498            .ok_or(PyIndexError::new_err("index out of range"))
1499            .map(|rust_dataset| PyDataset(rust_dataset.clone()))
1500    }
1501
1502    fn __repr__(&self) -> String {
1503        format!(
1504            "BinnedDataset(n_bins={}, range={:?})",
1505            self.0.n_bins(),
1506            self.0.range()
1507        )
1508    }
1509
1510    fn __str__(&self) -> String {
1511        self.__repr__()
1512    }
1513}