Skip to main content

laddu_python/
data.rs

1use std::{path::PathBuf, sync::Arc};
2
3use laddu_core::{
4    data::{
5        io::{
6            infer_p4_and_aux_names_from_columns, resolve_columns_case_insensitive,
7            resolve_optional_weight_column, resolve_p4_component_columns, ParquetBatchWriter,
8            P4_COMPONENT_SUFFIXES,
9        },
10        read_parquet as core_read_parquet,
11        read_parquet_chunks_with_options as core_read_parquet_chunks_with_options,
12        read_root as core_read_root, write_parquet as core_write_parquet,
13        write_root as core_write_root, BinnedDataset, Dataset, DatasetArcIter, DatasetMetadata,
14        DatasetWriteOptions, EventData, FloatPrecision, OwnedEvent, SharedDatasetIterExt,
15    },
16    variables::IntoP4Selection,
17    DatasetReadOptions, Vec4,
18};
19use numpy::{PyArray1, PyReadonlyArray1};
20use pyo3::{
21    exceptions::{PyIndexError, PyKeyError, PyTypeError, PyValueError},
22    prelude::*,
23    types::{PyDict, PyList},
24    IntoPyObjectExt,
25};
26
27use crate::{
28    variables::{PyVariable, PyVariableExpression},
29    vectors::PyVec4,
30};
31
32fn parse_aliases(aliases: Option<Bound<'_, PyDict>>) -> PyResult<Vec<(String, Vec<String>)>> {
33    let Some(aliases) = aliases else {
34        return Ok(Vec::new());
35    };
36
37    let mut parsed = Vec::new();
38    for (key, value) in aliases.iter() {
39        let alias_name = key.extract::<String>()?;
40        let selection = if let Ok(single) = value.extract::<String>() {
41            vec![single]
42        } else {
43            let seq = value.extract::<Vec<String>>().map_err(|_| {
44                PyTypeError::new_err("Alias values must be a string or a sequence of strings")
45            })?;
46            if seq.is_empty() {
47                return Err(PyValueError::new_err(format!(
48                    "Alias '{alias_name}' must reference at least one particle",
49                )));
50            }
51            seq
52        };
53        parsed.push((alias_name, selection));
54    }
55
56    Ok(parsed)
57}
58
59fn parse_dataset_path(path: Bound<'_, PyAny>) -> PyResult<String> {
60    if let Ok(s) = path.extract::<String>() {
61        Ok(s)
62    } else if let Ok(pathbuf) = path.extract::<PathBuf>() {
63        Ok(pathbuf.to_string_lossy().into_owned())
64    } else {
65        Err(PyTypeError::new_err("Expected str or Path"))
66    }
67}
68
69fn parse_precision_arg(value: Option<&str>) -> PyResult<FloatPrecision> {
70    match value.map(|v| v.to_ascii_lowercase()) {
71        None => Ok(FloatPrecision::F64),
72        Some(name) if name == "f64" || name == "float64" || name == "double" => {
73            Ok(FloatPrecision::F64)
74        }
75        Some(name) if name == "f32" || name == "float32" || name == "float" => {
76            Ok(FloatPrecision::F32)
77        }
78        Some(other) => Err(PyValueError::new_err(format!(
79            "Unsupported precision '{other}' (expected 'f64' or 'f32')"
80        ))),
81    }
82}
83
84fn extract_numeric_column(value: Bound<'_, PyAny>, name: &str) -> PyResult<Vec<f64>> {
85    if let Ok(array) = value.extract::<PyReadonlyArray1<'_, f64>>() {
86        return Ok(array.as_slice()?.to_vec());
87    }
88    if let Ok(array) = value.extract::<PyReadonlyArray1<'_, f32>>() {
89        return Ok(array.as_slice()?.iter().map(|v| *v as f64).collect());
90    }
91    if let Ok(values) = value.extract::<Vec<f64>>() {
92        return Ok(values);
93    }
94    if let Ok(values) = value.extract::<Vec<f32>>() {
95        return Ok(values.into_iter().map(|v| v as f64).collect());
96    }
97    if let Ok(list) = value.cast::<PyList>() {
98        let mut converted = Vec::with_capacity(list.len());
99        for item in list.iter() {
100            converted.push(item.extract::<f64>().map_err(|_| {
101                PyTypeError::new_err(format!(
102                    "Column '{name}' must be numeric (float32/float64/list of floats)"
103                ))
104            })?);
105        }
106        return Ok(converted);
107    }
108    Err(PyTypeError::new_err(format!(
109        "Column '{name}' must be numeric (float32/float64/list of floats)"
110    )))
111}
112
113fn metadata_from_names_and_aliases(
114    p4_names: Vec<String>,
115    aux_names: Vec<String>,
116    aliases: Option<Bound<'_, PyDict>>,
117) -> PyResult<DatasetMetadata> {
118    let parsed_aliases = parse_aliases(aliases)?;
119    let mut metadata = DatasetMetadata::new(p4_names, aux_names).map_err(PyErr::from)?;
120    if !parsed_aliases.is_empty() {
121        metadata
122            .add_p4_aliases(
123                parsed_aliases
124                    .into_iter()
125                    .map(|(alias_name, selection)| (alias_name, selection.into_selection())),
126            )
127            .map_err(PyErr::from)?;
128    }
129    Ok(metadata)
130}
131
132fn parse_p4_mapping(p4: Bound<'_, PyDict>) -> PyResult<Vec<(String, Vec4)>> {
133    p4.iter()
134        .map(|(key, value)| {
135            Ok((
136                key.extract::<String>()?,
137                value
138                    .extract::<PyVec4>()
139                    .map_err(|_| PyTypeError::new_err("p4 values must be laddu.Vec4 instances"))?
140                    .0,
141            ))
142        })
143        .collect()
144}
145
146fn parse_aux_mapping(aux: Option<Bound<'_, PyDict>>) -> PyResult<Vec<(String, f64)>> {
147    let Some(aux) = aux else {
148        return Ok(Vec::new());
149    };
150    aux.iter()
151        .map(|(key, value)| Ok((key.extract::<String>()?, value.extract::<f64>()?)))
152        .collect()
153}
154
155fn parse_p4_column(values: Vec<PyVec4>) -> Vec<Vec4> {
156    values.into_iter().map(|value| value.0).collect()
157}
158
159fn dataset_from_py_events(
160    events: Vec<PyEvent>,
161    p4_names: Option<Vec<String>>,
162    aux_names: Option<Vec<String>>,
163    aliases: Option<Bound<PyDict>>,
164    global: bool,
165) -> PyResult<PyDataset> {
166    let inferred_metadata = events
167        .iter()
168        .find_map(|event| event.has_metadata.then(|| event.event.metadata_arc()));
169
170    let aliases = parse_aliases(aliases)?;
171    let use_explicit_metadata = p4_names.is_some() || aux_names.is_some() || !aliases.is_empty();
172
173    let metadata = if use_explicit_metadata {
174        let resolved_p4_names = match (p4_names, inferred_metadata.as_ref()) {
175            (Some(names), _) => names,
176            (None, Some(metadata)) => metadata.p4_names().to_vec(),
177            (None, None) => Vec::new(),
178        };
179        let resolved_aux_names = match (aux_names, inferred_metadata.as_ref()) {
180            (Some(names), _) => names,
181            (None, Some(metadata)) => metadata.aux_names().to_vec(),
182            (None, None) => Vec::new(),
183        };
184
185        if !aliases.is_empty() && resolved_p4_names.is_empty() {
186            return Err(PyValueError::new_err(
187                "`aliases` requires `p4_names` or events with metadata for resolution",
188            ));
189        }
190
191        let mut metadata =
192            DatasetMetadata::new(resolved_p4_names, resolved_aux_names).map_err(PyErr::from)?;
193        if !aliases.is_empty() {
194            metadata
195                .add_p4_aliases(
196                    aliases
197                        .into_iter()
198                        .map(|(alias_name, selection)| (alias_name, selection.into_selection())),
199                )
200                .map_err(PyErr::from)?;
201        }
202        Some(Arc::new(metadata))
203    } else {
204        inferred_metadata
205    };
206
207    let events: Vec<Arc<EventData>> = events
208        .into_iter()
209        .map(|event| event.event.data_arc())
210        .collect();
211    let dataset = match (metadata, global) {
212        (Some(metadata), true) => Dataset::new_with_metadata(events, metadata),
213        (Some(metadata), false) => Dataset::new_local(events, metadata),
214        (None, true) => Dataset::new(events),
215        (None, false) => Dataset::new_local(events, Arc::new(DatasetMetadata::default())),
216    };
217    Ok(PyDataset(Arc::new(dataset)))
218}
219
220/// A single event
221///
222/// Events are composed of a set of 4-momenta of particles in the overall
223/// center-of-momentum frame, optional auxiliary scalars (e.g. polarization magnitude or angle),
224/// and a weight.
225///
226/// Parameters
227/// ----------
228/// p4s : list of Vec4
229///     4-momenta of each particle in the event in the overall center-of-momentum frame
230/// aux: list of float
231///     Scalar auxiliary data associated with the event
232/// weight : float
233///     The weight associated with this event
234/// p4_names : list of str, optional
235///     Human-readable aliases for each four-momentum. Providing names enables name-based
236///     lookups when evaluating variables.
237/// aux_names : list of str, optional
238///     Aliases for auxiliary scalars corresponding to ``aux``.
239/// aliases : dict of {str: str or list[str]}, optional
240///     Additional particle identifiers that reference one or more entries from ``p4_names``.
241///
242/// Examples
243/// --------
244/// >>> from laddu import Event, Vec3  # doctest: +SKIP
245/// >>> event = Event(  # doctest: +SKIP
246/// ...     [Vec3(0.0, 0.0, 1.0).with_mass(0.0), Vec3(0.0, 0.0, 1.0).with_mass(0.0)],
247/// ...     [],
248/// ...     1.0,
249/// ...     p4_names=['kshort1', 'kshort2'],
250/// ...     aliases={'pair': ['kshort1', 'kshort2']},
251/// ... )
252/// >>> event.p4('pair')  # doctest: +SKIP
253/// Vec4(px=0.0, py=0.0, pz=2.0, e=2.0)
254/// >>> event.aux['pol_angle']  # doctest: +SKIP
255/// 0.3
256///
257#[pyclass(name = "Event", module = "laddu", from_py_object)]
258#[derive(Clone)]
259pub struct PyEvent {
260    pub event: OwnedEvent,
261    has_metadata: bool,
262}
263
264#[pymethods]
265impl PyEvent {
266    #[new]
267    #[pyo3(signature = (p4s, aux, weight, *, p4_names=None, aux_names=None, aliases=None))]
268    fn new(
269        p4s: Vec<PyVec4>,
270        aux: Vec<f64>,
271        weight: f64,
272        p4_names: Option<Vec<String>>,
273        aux_names: Option<Vec<String>>,
274        aliases: Option<Bound<PyDict>>,
275    ) -> PyResult<Self> {
276        let event = EventData {
277            p4s: p4s.into_iter().map(|arr| arr.0).collect(),
278            aux,
279            weight,
280        };
281        let aliases = parse_aliases(aliases)?;
282
283        let missing_p4_names = p4_names
284            .as_ref()
285            .map(|names| names.is_empty())
286            .unwrap_or(true);
287
288        if !aliases.is_empty() && missing_p4_names {
289            return Err(PyValueError::new_err(
290                "`aliases` requires `p4_names` so selections can be resolved",
291            ));
292        }
293
294        let metadata_provided = p4_names.is_some() || aux_names.is_some() || !aliases.is_empty();
295        let metadata = if metadata_provided {
296            let p4_names = p4_names.unwrap_or_default();
297            let aux_names = aux_names.unwrap_or_default();
298            let mut metadata = DatasetMetadata::new(p4_names, aux_names).map_err(PyErr::from)?;
299            if !aliases.is_empty() {
300                metadata
301                    .add_p4_aliases(
302                        aliases.into_iter().map(|(alias_name, selection)| {
303                            (alias_name, selection.into_selection())
304                        }),
305                    )
306                    .map_err(PyErr::from)?;
307            }
308            Arc::new(metadata)
309        } else {
310            Arc::new(DatasetMetadata::empty())
311        };
312        let event = OwnedEvent::new(Arc::new(event), metadata);
313        Ok(Self {
314            event,
315            has_metadata: metadata_provided,
316        })
317    }
318    fn __str__(&self) -> String {
319        self.event.to_string()
320    }
321    fn __repr__(&self) -> String {
322        self.__str__()
323    }
324    /// The list of 4-momenta for each particle in the event
325    ///
326    #[getter]
327    fn p4s<'py>(&self, py: Python<'py>) -> PyResult<Py<PyDict>> {
328        self.ensure_metadata()?;
329        let mapping = PyDict::new(py);
330        for (name, vec4) in self.event.p4s() {
331            mapping.set_item(name, PyVec4(vec4))?;
332        }
333        Ok(mapping.into())
334    }
335    /// The auxiliary scalar values associated with the event
336    ///
337    #[getter]
338    #[pyo3(name = "aux")]
339    fn aux_mapping<'py>(&self, py: Python<'py>) -> PyResult<Py<PyDict>> {
340        self.ensure_metadata()?;
341        let mapping = PyDict::new(py);
342        for (name, value) in self.event.aux() {
343            mapping.set_item(name, value)?;
344        }
345        Ok(mapping.into())
346    }
347    /// The weight of this event relative to others in a Dataset
348    ///
349    #[getter]
350    fn get_weight(&self) -> f64 {
351        self.event.weight()
352    }
353    /// Get the sum of the four-momenta within the event at the given indices
354    ///
355    /// Parameters
356    /// ----------
357    /// names : list of str
358    ///     The names of the four-momenta to sum
359    ///
360    /// Returns
361    /// -------
362    /// Vec4
363    ///     The result of summing the given four-momenta
364    ///
365    fn get_p4_sum(&self, names: Vec<String>) -> PyResult<PyVec4> {
366        let indices = self.resolve_p4_indices(&names)?;
367        Ok(PyVec4(self.event.data().get_p4_sum(indices)))
368    }
369    /// Boost all the four-momenta in the event to the rest frame of the given set of
370    /// four-momenta by indices.
371    ///
372    /// Parameters
373    /// ----------
374    /// names : list of str
375    ///     The names of the four-momenta whose rest frame should be used for the boost
376    ///
377    /// Returns
378    /// -------
379    /// Event
380    ///     The boosted event
381    ///
382    pub fn boost_to_rest_frame_of(&self, names: Vec<String>) -> PyResult<Self> {
383        let indices = self.resolve_p4_indices(&names)?;
384        let boosted = self.event.data().boost_to_rest_frame_of(indices);
385        Ok(Self {
386            event: OwnedEvent::new(Arc::new(boosted), self.event.metadata_arc()),
387            has_metadata: self.has_metadata,
388        })
389    }
390    /// Get the value of a Variable on the given Event
391    ///
392    /// Parameters
393    /// ----------
394    /// variable : {laddu.Mass, laddu.CosTheta, laddu.Phi, laddu.PolAngle, laddu.PolMagnitude, laddu.Mandelstam}
395    ///
396    /// Returns
397    /// -------
398    /// float
399    ///
400    /// Notes
401    /// -----
402    /// Variables that rely on particle names require the event to carry metadata. Provide
403    /// ``p4_names``/``aux_names`` when constructing the event or evaluate variables through a
404    /// ``laddu.Dataset`` to ensure the metadata is available.
405    ///
406    /// Examples
407    /// --------
408    /// >>> from laddu import Event, Vec3  # doctest: +SKIP
409    /// >>> from laddu.utils.variables import Mass  # doctest: +SKIP
410    /// >>> event = Event(  # doctest: +SKIP
411    /// ...     [Vec3(0.0, 0.0, 1.0).with_mass(0.938)],
412    /// ...     [],
413    /// ...     1.0,
414    /// ...     p4_names=['proton'],
415    /// ... )
416    /// >>> event.evaluate(Mass(['proton']))  # doctest: +SKIP
417    /// 0.938
418    ///
419    fn evaluate(&self, variable: Bound<'_, PyAny>) -> PyResult<f64> {
420        let mut variable = variable.extract::<PyVariable>()?;
421        let metadata = self.ensure_metadata()?;
422        variable.bind_in_place(metadata)?;
423        variable.evaluate_event(&self.event)
424    }
425
426    /// Retrieve a four-momentum by name.
427    fn p4(&self, name: &str) -> PyResult<PyVec4> {
428        self.ensure_metadata()?;
429        self.event
430            .p4(name)
431            .map(PyVec4)
432            .ok_or_else(|| PyKeyError::new_err(format!("Unknown particle name '{name}'")))
433    }
434}
435
436impl PyEvent {
437    fn ensure_metadata(&self) -> PyResult<&DatasetMetadata> {
438        if !self.has_metadata {
439            Err(PyValueError::new_err(
440                "Event has no associated metadata for name-based operations",
441            ))
442        } else {
443            Ok(self.event.metadata())
444        }
445    }
446
447    fn resolve_p4_indices(&self, names: &[String]) -> PyResult<Vec<usize>> {
448        let metadata = self.ensure_metadata()?;
449        let mut resolved = Vec::new();
450        for name in names {
451            let selection = metadata
452                .p4_selection(name)
453                .ok_or_else(|| PyKeyError::new_err(format!("Unknown particle name '{name}'")))?;
454            resolved.extend_from_slice(selection.indices());
455        }
456        Ok(resolved)
457    }
458
459    pub(crate) fn metadata_opt(&self) -> Option<&DatasetMetadata> {
460        self.has_metadata.then(|| self.event.metadata())
461    }
462}
463
464#[doc(hidden)]
465/// A set of Events
466///
467/// Datasets can be created from lists of Events or by using the constructor helpers
468/// such as :func:`laddu.io.read_parquet`, :func:`laddu.io.read_root`, and
469/// :func:`laddu.io.read_amptools`
470///
471/// Datasets can also be indexed directly to access individual Events
472///
473/// Parameters
474/// ----------
475/// events : list of Event
476/// p4_names : list of str, optional
477///     Names assigned to each four-momentum; enables name-based lookups if provided.
478/// aux_names : list of str, optional
479///     Names for auxiliary scalars stored alongside the events.
480/// aliases : dict of {str: str or list[str]}, optional
481///     Additional particle identifiers that override aliases stored on the Events.
482///
483/// Notes
484/// -----
485/// Explicit metadata provided here takes precedence over metadata embedded in the
486/// input Events.
487///
488/// Examples
489/// --------
490/// >>> from laddu import Dataset, Event, Vec3  # doctest: +SKIP
491/// >>> event = Event(  # doctest: +SKIP
492/// ...     [Vec3(0.0, 0.0, 1.0).with_mass(0.0), Vec3(0.0, 0.0, -1.0).with_mass(0.938)],
493/// ...     [0.4, 0.3],
494/// ...     1.0,
495/// ... )
496/// >>> dataset = Dataset(  # doctest: +SKIP
497/// ...     [event],
498/// ...     p4_names=['beam', 'proton'],
499/// ...     aux_names=['pol_magnitude', 'pol_angle'],
500/// ...     aliases={'target': 'proton'},
501/// ... )
502/// >>> dataset[0].p4('target')  # doctest: +SKIP
503/// Vec4(px=0.0, py=0.0, pz=-1.0, e=1.371073...)
504///
505#[pyclass(name = "Dataset", module = "laddu", subclass, skip_from_py_object)]
506#[derive(Clone)]
507pub struct PyDataset(pub Arc<Dataset>);
508
509#[pyclass(
510    name = "ParquetChunkIter",
511    module = "laddu",
512    unsendable,
513    skip_from_py_object
514)]
515pub struct PyParquetChunkIter {
516    chunks: Box<dyn Iterator<Item = laddu_core::LadduResult<Arc<Dataset>>> + Send>,
517}
518
519#[pymethods]
520impl PyParquetChunkIter {
521    fn __iter__(slf: PyRef<'_, Self>) -> Py<PyParquetChunkIter> {
522        slf.into()
523    }
524
525    fn __next__(&mut self) -> PyResult<Option<PyDataset>> {
526        match self.chunks.next() {
527            Some(Ok(dataset)) => Ok(Some(PyDataset(dataset))),
528            Some(Err(err)) => Err(PyErr::from(err)),
529            None => Ok(None),
530        }
531    }
532}
533
534/// Streaming writer for appending compatible Dataset batches to one Parquet file.
535#[pyclass(name = "ParquetBatchWriter", module = "laddu", unsendable)]
536pub struct PyParquetBatchWriter {
537    writer: Option<ParquetBatchWriter>,
538}
539
540#[pymethods]
541impl PyParquetBatchWriter {
542    /// Append one Dataset batch.
543    fn write(&mut self, dataset: &PyDataset) -> PyResult<()> {
544        self.writer_mut()?
545            .write(dataset.0.as_ref())
546            .map_err(PyErr::from)
547    }
548
549    /// Finalize the Parquet file.
550    fn close(&mut self) -> PyResult<()> {
551        if let Some(writer) = &mut self.writer {
552            writer.close()?;
553        }
554        self.writer = None;
555        Ok(())
556    }
557
558    fn __enter__(slf: Py<Self>) -> Py<Self> {
559        slf
560    }
561
562    fn __exit__(
563        &mut self,
564        _exc_type: Bound<'_, PyAny>,
565        _exc_value: Bound<'_, PyAny>,
566        _traceback: Bound<'_, PyAny>,
567    ) -> PyResult<bool> {
568        self.close()?;
569        Ok(false)
570    }
571}
572
573impl PyParquetBatchWriter {
574    fn new(writer: ParquetBatchWriter) -> Self {
575        Self {
576            writer: Some(writer),
577        }
578    }
579
580    fn writer_mut(&mut self) -> PyResult<&mut ParquetBatchWriter> {
581        self.writer
582            .as_mut()
583            .ok_or_else(|| PyValueError::new_err("ParquetBatchWriter is closed"))
584    }
585}
586
587#[pyclass(
588    name = "DatasetEventsGlobal",
589    module = "laddu",
590    unsendable,
591    skip_from_py_object
592)]
593pub struct PyDatasetEventsGlobalIter {
594    iter: DatasetArcIter,
595}
596
597#[pymethods]
598impl PyDatasetEventsGlobalIter {
599    fn __iter__(slf: PyRef<'_, Self>) -> Py<PyDatasetEventsGlobalIter> {
600        slf.into()
601    }
602
603    fn __next__(&mut self) -> Option<PyEvent> {
604        self.iter.next().map(|rust_event| PyEvent {
605            event: rust_event,
606            has_metadata: true,
607        })
608    }
609}
610
611#[pyclass(
612    name = "DatasetEventsLocal",
613    module = "laddu",
614    unsendable,
615    skip_from_py_object
616)]
617pub struct PyDatasetEventsLocalIter {
618    dataset: Arc<Dataset>,
619    index: usize,
620}
621
622#[pymethods]
623impl PyDatasetEventsLocalIter {
624    fn __iter__(slf: PyRef<'_, Self>) -> Py<PyDatasetEventsLocalIter> {
625        slf.into()
626    }
627
628    fn __next__(&mut self) -> Option<PyEvent> {
629        if self.index >= self.dataset.n_events_local() {
630            return None;
631        }
632        let event = self
633            .dataset
634            .event_local(self.index)
635            .expect("local event index should exist")
636            .to_event_data();
637        self.index += 1;
638        Some(PyEvent {
639            event: OwnedEvent::new(Arc::new(event), self.dataset.metadata_arc()),
640            has_metadata: true,
641        })
642    }
643}
644
645#[pymethods]
646impl PyDataset {
647    #[new]
648    #[pyo3(signature = (events, *, p4_names=None, aux_names=None, aliases=None))]
649    fn new(
650        events: Vec<PyEvent>,
651        p4_names: Option<Vec<String>>,
652        aux_names: Option<Vec<String>>,
653        aliases: Option<Bound<PyDict>>,
654    ) -> PyResult<Self> {
655        dataset_from_py_events(events, p4_names, aux_names, aliases, true)
656    }
657
658    /// Create a local Dataset from a sequence of Events.
659    #[staticmethod]
660    #[pyo3(signature = (events, *, p4_names=None, aux_names=None, aliases=None))]
661    fn from_events_local(
662        events: Vec<PyEvent>,
663        p4_names: Option<Vec<String>>,
664        aux_names: Option<Vec<String>>,
665        aliases: Option<Bound<PyDict>>,
666    ) -> PyResult<Self> {
667        dataset_from_py_events(events, p4_names, aux_names, aliases, false)
668    }
669
670    /// Create a global Dataset from a sequence of Events.
671    ///
672    /// Under MPI, every rank must pass the same global event list. Rows are
673    /// partitioned across ranks using laddu's canonical contiguous partition.
674    #[staticmethod]
675    #[pyo3(signature = (events, *, p4_names=None, aux_names=None, aliases=None))]
676    fn from_events_global(
677        events: Vec<PyEvent>,
678        p4_names: Option<Vec<String>>,
679        aux_names: Option<Vec<String>>,
680        aliases: Option<Bound<PyDict>>,
681    ) -> PyResult<Self> {
682        dataset_from_py_events(events, p4_names, aux_names, aliases, true)
683    }
684
685    /// Create an empty local Dataset with explicit metadata.
686    ///
687    /// Parameters
688    /// ----------
689    /// p4_names : list of str
690    ///     Names of the four-momenta expected in each pushed event.
691    /// aux_names : list of str, optional
692    ///     Names of auxiliary scalar columns expected in each pushed event.
693    /// aliases : dict of {str: str or list[str]}, optional
694    ///     Additional particle identifiers that reference one or more p4 names.
695    #[staticmethod]
696    #[pyo3(signature = (*, p4_names, aux_names=None, aliases=None))]
697    fn empty_local(
698        p4_names: Vec<String>,
699        aux_names: Option<Vec<String>>,
700        aliases: Option<Bound<'_, PyDict>>,
701    ) -> PyResult<Self> {
702        let metadata =
703            metadata_from_names_and_aliases(p4_names, aux_names.unwrap_or_default(), aliases)?;
704        Ok(Self(Arc::new(Dataset::empty_local(metadata))))
705    }
706
707    /// Append one named event to the local Dataset partition.
708    ///
709    /// Parameters
710    /// ----------
711    /// p4 : dict[str, Vec4]
712    ///     Four-momenta keyed by p4 name.
713    /// aux : dict[str, float], optional
714    ///     Auxiliary scalar values keyed by aux name.
715    /// weight : float, optional
716    ///     Event weight. Defaults to 1.0.
717    #[pyo3(signature = (*, p4, aux=None, weight=1.0))]
718    fn push_event_local(
719        &mut self,
720        p4: Bound<'_, PyDict>,
721        aux: Option<Bound<'_, PyDict>>,
722        weight: f64,
723    ) -> PyResult<()> {
724        let p4 = parse_p4_mapping(p4)?;
725        let aux = parse_aux_mapping(aux)?;
726        Arc::make_mut(&mut self.0)
727            .push_event_named_local(p4, aux, weight)
728            .map_err(PyErr::from)
729    }
730
731    /// Append one named event collectively as a single global event.
732    ///
733    /// Under MPI, all ranks must call this method in the same order. Exactly
734    /// one rank stores each event, selected by round-robin global event index.
735    #[pyo3(signature = (*, p4, aux=None, weight=1.0))]
736    fn push_event_global(
737        &mut self,
738        p4: Bound<'_, PyDict>,
739        aux: Option<Bound<'_, PyDict>>,
740        weight: f64,
741    ) -> PyResult<()> {
742        let p4 = parse_p4_mapping(p4)?;
743        let aux = parse_aux_mapping(aux)?;
744        Arc::make_mut(&mut self.0)
745            .push_event_named_global(p4, aux, weight)
746            .map_err(PyErr::from)
747    }
748
749    /// Add a four-momentum column to the local Dataset partition.
750    #[pyo3(signature = (name, values))]
751    fn add_p4_column_local(&mut self, name: String, values: Vec<PyVec4>) -> PyResult<()> {
752        Arc::make_mut(&mut self.0)
753            .add_p4_column_local(name, parse_p4_column(values))
754            .map_err(PyErr::from)
755    }
756
757    /// Add an auxiliary scalar column to the local Dataset partition.
758    #[pyo3(signature = (name, values))]
759    fn add_aux_column_local(&mut self, name: String, values: Bound<'_, PyAny>) -> PyResult<()> {
760        let values = extract_numeric_column(values, &name)?;
761        Arc::make_mut(&mut self.0)
762            .add_aux_column_local(name, values)
763            .map_err(PyErr::from)
764    }
765
766    /// Add a four-momentum column collectively across all MPI ranks.
767    ///
768    /// Under MPI, all ranks must call this method in the same order with the
769    /// same column name. Each rank supplies values for its locally stored events.
770    #[pyo3(signature = (name, values))]
771    fn add_p4_column_global(&mut self, name: String, values: Vec<PyVec4>) -> PyResult<()> {
772        Arc::make_mut(&mut self.0)
773            .add_p4_column_global(name, parse_p4_column(values))
774            .map_err(PyErr::from)
775    }
776
777    /// Add an auxiliary scalar column collectively across all MPI ranks.
778    ///
779    /// Under MPI, all ranks must call this method in the same order with the
780    /// same column name. Each rank supplies values for its locally stored events.
781    #[pyo3(signature = (name, values))]
782    fn add_aux_column_global(&mut self, name: String, values: Bound<'_, PyAny>) -> PyResult<()> {
783        let values = extract_numeric_column(values, &name)?;
784        Arc::make_mut(&mut self.0)
785            .add_aux_column_global(name, values)
786            .map_err(PyErr::from)
787    }
788
789    fn __len__(&self) -> usize {
790        self.0.n_events()
791    }
792    fn __iter__(&self) -> PyResult<()> {
793        Err(PyTypeError::new_err(
794            "Dataset iteration is explicit; use dataset.events_local or dataset.events_global",
795        ))
796    }
797    /// Get the number of events owned by the current rank.
798    #[getter]
799    fn n_events_local(&self) -> usize {
800        self.0.n_events_local()
801    }
802    fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyDataset> {
803        if let Ok(other_ds) = other.extract::<PyRef<PyDataset>>() {
804            Ok(PyDataset(Arc::new(self.0.as_ref() + other_ds.0.as_ref())))
805        } else if let Ok(other_int) = other.extract::<usize>() {
806            if other_int == 0 {
807                Ok(self.clone())
808            } else {
809                Err(PyTypeError::new_err(
810                    "Addition with an integer for this type is only defined for 0",
811                ))
812            }
813        } else {
814            Err(PyTypeError::new_err("Unsupported operand type for +"))
815        }
816    }
817    fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyDataset> {
818        if let Ok(other_ds) = other.extract::<PyRef<PyDataset>>() {
819            Ok(PyDataset(Arc::new(other_ds.0.as_ref() + self.0.as_ref())))
820        } else if let Ok(other_int) = other.extract::<usize>() {
821            if other_int == 0 {
822                Ok(self.clone())
823            } else {
824                Err(PyTypeError::new_err(
825                    "Addition with an integer for this type is only defined for 0",
826                ))
827            }
828        } else {
829            Err(PyTypeError::new_err("Unsupported operand type for +"))
830        }
831    }
832
833    fn __repr__(&self) -> String {
834        format!(
835            "Dataset(n_events={}, n_events_local={}, p4_names={:?}, aux_names={:?})",
836            self.0.n_events_global(),
837            self.0.n_events_local(),
838            self.0.p4_names(),
839            self.0.aux_names()
840        )
841    }
842
843    fn __str__(&self) -> String {
844        self.__repr__()
845    }
846
847    /// Get the number of Events in the Dataset
848    ///
849    /// Notes
850    /// -----
851    /// When MPI is enabled, this returns the global event count.
852    /// It therefore matches ``len(dataset)`` and the valid range for ``dataset[i]``.
853    ///
854    /// Returns
855    /// -------
856    /// n_events : int
857    ///     The number of Events
858    ///
859    #[getter]
860    fn n_events(&self) -> usize {
861        self.0.n_events()
862    }
863    /// Alias for ``n_events``.
864    #[getter]
865    fn n_events_global(&self) -> usize {
866        self.0.n_events_global()
867    }
868    /// Particle names used to construct four-momenta when loading from a Parquet file.
869    #[getter]
870    fn p4_names(&self) -> Vec<String> {
871        self.0.p4_names().to_vec()
872    }
873    /// Auxiliary scalar names associated with this Dataset.
874    #[getter]
875    fn aux_names(&self) -> Vec<String> {
876        self.0.aux_names().to_vec()
877    }
878
879    /// Get the weighted number of Events in the Dataset
880    ///
881    /// Notes
882    /// -----
883    /// When MPI is enabled, this returns the global weighted event count.
884    ///
885    /// Returns
886    /// -------
887    /// n_events : float
888    ///     The sum of all Event weights
889    ///
890    #[getter]
891    fn n_events_weighted(&self) -> f64 {
892        self.0.n_events_weighted()
893    }
894    /// Alias for ``n_events_weighted``.
895    #[getter]
896    fn n_events_weighted_global(&self) -> f64 {
897        self.0.n_events_weighted_global()
898    }
899    /// Get the weighted number of local Events in the Dataset
900    ///
901    /// Notes
902    /// -----
903    /// When MPI is enabled, this returns the sum of the current rank's Event
904    /// weights.
905    ///
906    /// Returns
907    /// -------
908    /// n_events : float
909    ///     The sum of the current rank's Event weights
910    ///
911    #[getter]
912    fn n_events_weighted_local(&self) -> f64 {
913        self.0.n_events_weighted_local()
914    }
915    /// The weights associated with the Dataset
916    ///
917    /// Notes
918    /// -----
919    /// When MPI is enabled, this returns the global weight vector in dataset order.
920    ///
921    /// Returns
922    /// -------
923    /// weights : array_like
924    ///     A ``numpy`` array of global Event weights
925    ///
926    #[getter]
927    fn weights<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<f64>> {
928        PyArray1::from_slice(py, &self.0.weights())
929    }
930    /// Alias for ``weights``.
931    #[getter]
932    fn weights_global<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<f64>> {
933        PyArray1::from_slice(py, &self.0.weights_global())
934    }
935    /// The weights associated with the Dataset on the current rank.
936    ///
937    /// Notes
938    /// -----
939    /// This is the explicit rank-local counterpart to ``weights``.
940    ///
941    /// Returns
942    /// -------
943    /// weights : array_like
944    ///     A ``numpy`` array of rank-local Event weights
945    ///
946    #[getter]
947    fn weights_local<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<f64>> {
948        PyArray1::from_slice(py, &self.0.weights_local())
949    }
950    /// Iterate over global Events in dataset order.
951    #[getter]
952    fn events_global(&self) -> PyDatasetEventsGlobalIter {
953        PyDatasetEventsGlobalIter {
954            iter: self.0.shared_iter_global(),
955        }
956    }
957    /// Iterate over Events stored on the current rank.
958    ///
959    /// Notes
960    /// -----
961    /// This is the explicit rank-local counterpart to ``events_global``.
962    #[getter]
963    fn events_local(&self) -> PyDatasetEventsLocalIter {
964        PyDatasetEventsLocalIter {
965            dataset: self.0.clone(),
966            index: 0,
967        }
968    }
969    /// Retrieve a four-momentum by particle name for the event at ``index``.
970    fn p4_by_name(&self, index: usize, name: &str) -> PyResult<PyVec4> {
971        self.0
972            .p4_by_name(index, name)
973            .map(PyVec4)
974            .ok_or_else(|| PyKeyError::new_err(format!("Unknown particle name '{name}'")))
975    }
976    /// Retrieve an auxiliary scalar by name for the event at ``index``.
977    fn aux_by_name(&self, index: usize, name: &str) -> PyResult<f64> {
978        self.0
979            .aux_by_name(index, name)
980            .ok_or_else(|| PyKeyError::new_err(format!("Unknown auxiliary name '{name}'")))
981    }
982    /// Alias for ``dataset[index]``.
983    ///
984    /// Notes
985    /// -----
986    /// This preserves the default global indexing semantics under MPI.
987    fn event_global(&self, index: usize) -> PyResult<PyEvent> {
988        let event = self
989            .0
990            .event_global(index)
991            .map_err(|_| PyIndexError::new_err("index out of range"))?;
992        Ok(PyEvent {
993            event,
994            has_metadata: true,
995        })
996    }
997    fn __getitem__<'py>(
998        &self,
999        py: Python<'py>,
1000        index: Bound<'py, PyAny>,
1001    ) -> PyResult<Bound<'py, PyAny>> {
1002        if let Ok(value) = self.evaluate(py, index.clone()) {
1003            value.into_bound_py_any(py)
1004        } else if let Ok(index) = index.extract::<usize>() {
1005            let event = self
1006                .0
1007                .event_global(index)
1008                .map_err(|_| PyIndexError::new_err("index out of range"))?;
1009            PyEvent {
1010                event,
1011                has_metadata: true,
1012            }
1013            .into_bound_py_any(py)
1014        } else {
1015            Err(PyTypeError::new_err(
1016                "Unsupported index type (int or Variable)",
1017            ))
1018        }
1019    }
1020    /// Separates a Dataset into histogram bins by a Variable value
1021    ///
1022    /// Parameters
1023    /// ----------
1024    /// variable : {laddu.Mass, laddu.CosTheta, laddu.Phi, laddu.PolAngle, laddu.PolMagnitude, laddu.Mandelstam}
1025    ///     The Variable by which each Event is binned
1026    /// bins : int
1027    ///     The number of equally-spaced bins
1028    /// range : tuple[float, float]
1029    ///     The minimum and maximum bin edges
1030    ///
1031    /// Returns
1032    /// -------
1033    /// datasets : BinnedDataset
1034    ///     A structure containing Datasets binned by the given `variable`
1035    ///
1036    /// See Also
1037    /// --------
1038    /// laddu.Mass
1039    /// laddu.CosTheta
1040    /// laddu.Phi
1041    /// laddu.PolAngle
1042    /// laddu.PolMagnitude
1043    /// laddu.Mandelstam
1044    ///
1045    /// Examples
1046    /// --------
1047    /// >>> from laddu.utils.variables import Mass  # doctest: +SKIP
1048    /// >>> binned = dataset.bin_by(Mass(['kshort1']), bins=10, range=(0.9, 1.5))  # doctest: +SKIP
1049    /// >>> len(binned)  # doctest: +SKIP
1050    /// 10
1051    ///
1052    /// Raises
1053    /// ------
1054    /// TypeError
1055    ///     If the given `variable` is not a valid variable
1056    ///
1057    #[pyo3(signature = (variable, bins, range))]
1058    fn bin_by(
1059        &self,
1060        variable: Bound<'_, PyAny>,
1061        bins: usize,
1062        range: (f64, f64),
1063    ) -> PyResult<PyBinnedDataset> {
1064        let py_variable = variable.extract::<PyVariable>()?;
1065        let bound_variable = py_variable.bound(self.0.metadata())?;
1066        Ok(PyBinnedDataset(self.0.bin_by(
1067            bound_variable,
1068            bins,
1069            range,
1070        )?))
1071    }
1072    /// Filter the Dataset by a given VariableExpression, selecting events for which the expression returns ``True``.
1073    ///
1074    /// Parameters
1075    /// ----------
1076    /// expression : VariableExpression
1077    ///     The expression with which to filter the Dataset
1078    ///
1079    /// Returns
1080    /// -------
1081    /// Dataset
1082    ///     The filtered Dataset
1083    ///
1084    /// Examples
1085    /// --------
1086    /// >>> from laddu.utils.variables import Mass  # doctest: +SKIP
1087    /// >>> heavy = dataset.filter(Mass(['kshort1']) > 1.0)  # doctest: +SKIP
1088    ///
1089    pub fn filter(&self, expression: &PyVariableExpression) -> PyResult<PyDataset> {
1090        Ok(PyDataset(
1091            self.0.filter(&expression.0).map_err(PyErr::from)?,
1092        ))
1093    }
1094    /// Generate a new bootstrapped Dataset by randomly resampling the original with replacement
1095    ///
1096    /// The new Dataset is resampled with a random generator seeded by the provided `seed`
1097    ///
1098    /// Parameters
1099    /// ----------
1100    /// seed : int
1101    ///     The random seed used in the resampling process
1102    ///
1103    /// Returns
1104    /// -------
1105    /// Dataset
1106    ///     A bootstrapped Dataset
1107    ///
1108    /// Examples
1109    /// --------
1110    /// >>> replica = dataset.bootstrap(2024)  # doctest: +SKIP
1111    /// >>> len(replica) == len(dataset)  # doctest: +SKIP
1112    /// True
1113    ///
1114    fn bootstrap(&self, seed: usize) -> PyDataset {
1115        PyDataset(self.0.bootstrap(seed))
1116    }
1117    /// Boost all the four-momenta in all events to the rest frame of the given set of
1118    /// named four-momenta.
1119    ///
1120    /// Parameters
1121    /// ----------
1122    /// names : list of str
1123    ///     The names of the four-momenta defining the rest frame
1124    ///
1125    /// Returns
1126    /// -------
1127    /// Dataset
1128    ///     The boosted dataset
1129    ///
1130    /// Examples
1131    /// --------
1132    /// >>> dataset.boost_to_rest_frame_of(['kshort1', 'kshort2'])  # doctest: +SKIP
1133    ///
1134    pub fn boost_to_rest_frame_of(&self, names: Vec<String>) -> PyDataset {
1135        PyDataset(self.0.boost_to_rest_frame_of(&names))
1136    }
1137    /// Get the value of a Variable over every event in the Dataset.
1138    ///
1139    /// Parameters
1140    /// ----------
1141    /// variable : {laddu.Mass, laddu.CosTheta, laddu.Phi, laddu.PolAngle, laddu.PolMagnitude, laddu.Mandelstam}
1142    ///
1143    /// Returns
1144    /// -------
1145    /// values : array_like
1146    ///
1147    /// Examples
1148    /// --------
1149    /// >>> from laddu.utils.variables import Mass  # doctest: +SKIP
1150    /// >>> masses = dataset.evaluate(Mass(['kshort1']))  # doctest: +SKIP
1151    /// >>> masses.shape  # doctest: +SKIP
1152    /// (len(dataset),)
1153    ///
1154    fn evaluate<'py>(
1155        &self,
1156        py: Python<'py>,
1157        variable: Bound<'py, PyAny>,
1158    ) -> PyResult<Bound<'py, PyArray1<f64>>> {
1159        let variable = variable.extract::<PyVariable>()?;
1160        let bound_variable = variable.bound(self.0.metadata())?;
1161        let values = self.0.evaluate(&bound_variable).map_err(PyErr::from)?;
1162        Ok(PyArray1::from_vec(py, values))
1163    }
1164}
1165
1166/// Read a Dataset from a Parquet file.
1167///
1168/// Examples
1169/// --------
1170/// >>> import laddu.io as ldio  # doctest: +SKIP
1171/// >>> dataset = ldio.read_parquet(  # doctest: +SKIP
1172/// ...     'events.parquet',
1173/// ...     p4s=['beam', 'proton'],
1174/// ...     aux=['pol_magnitude', 'pol_angle'],
1175/// ...     aliases={'target': 'proton'},
1176/// ... )
1177/// >>> dataset.p4_names  # doctest: +SKIP
1178/// ['beam', 'proton']
1179#[pyfunction]
1180#[pyo3(signature = (path, *, p4s=None, aux=None, aliases=None))]
1181pub fn read_parquet(
1182    path: Bound<PyAny>,
1183    p4s: Option<Vec<String>>,
1184    aux: Option<Vec<String>>,
1185    aliases: Option<Bound<PyDict>>,
1186) -> PyResult<PyDataset> {
1187    let path_str = parse_dataset_path(path)?;
1188    let mut read_options = DatasetReadOptions::default();
1189    if let Some(p4s) = p4s {
1190        read_options = read_options.p4_names(p4s);
1191    }
1192    if let Some(aux) = aux {
1193        read_options = read_options.aux_names(aux);
1194    }
1195    for (alias_name, selection) in parse_aliases(aliases)?.into_iter() {
1196        read_options = read_options.alias(alias_name, selection);
1197    }
1198    let dataset = core_read_parquet(&path_str, &read_options)?;
1199    Ok(PyDataset(dataset))
1200}
1201
1202/// Read a Dataset from a Parquet file in chunks.
1203#[pyfunction]
1204#[pyo3(signature = (path, *, p4s=None, aux=None, aliases=None, chunk_size=None))]
1205pub fn read_parquet_chunked(
1206    path: Bound<PyAny>,
1207    p4s: Option<Vec<String>>,
1208    aux: Option<Vec<String>>,
1209    aliases: Option<Bound<PyDict>>,
1210    chunk_size: Option<usize>,
1211) -> PyResult<PyParquetChunkIter> {
1212    let path_str = parse_dataset_path(path)?;
1213    let mut read_options = DatasetReadOptions::default();
1214    if let Some(p4s) = p4s {
1215        read_options = read_options.p4_names(p4s);
1216    }
1217    if let Some(aux) = aux {
1218        read_options = read_options.aux_names(aux);
1219    }
1220    if let Some(chunk_size) = chunk_size {
1221        read_options = read_options.chunk_size(chunk_size);
1222    }
1223    for (alias_name, selection) in parse_aliases(aliases)?.into_iter() {
1224        read_options = read_options.alias(alias_name, selection);
1225    }
1226
1227    let chunks = core_read_parquet_chunks_with_options(&path_str, &read_options)?;
1228    Ok(PyParquetChunkIter {
1229        chunks: Box::new(chunks),
1230    })
1231}
1232
1233/// Read a Dataset from a ROOT file using the oxyroot backend.
1234///
1235/// Examples
1236/// --------
1237/// >>> import laddu.io as ldio  # doctest: +SKIP
1238/// >>> dataset = ldio.read_root(  # doctest: +SKIP
1239/// ...     'events.root',
1240/// ...     tree='kin',
1241/// ...     p4s=['beam', 'proton'],
1242/// ...     aux=['pol_magnitude', 'pol_angle'],
1243/// ... )
1244/// >>> dataset.aux_names  # doctest: +SKIP
1245/// ['pol_magnitude', 'pol_angle']
1246#[pyfunction]
1247#[pyo3(signature = (path, *, tree=None, p4s=None, aux=None, aliases=None))]
1248pub fn read_root(
1249    path: Bound<PyAny>,
1250    tree: Option<String>,
1251    p4s: Option<Vec<String>>,
1252    aux: Option<Vec<String>>,
1253    aliases: Option<Bound<PyDict>>,
1254) -> PyResult<PyDataset> {
1255    let path_str = parse_dataset_path(path)?;
1256    let mut read_options = DatasetReadOptions::default();
1257    if let Some(p4s) = p4s {
1258        read_options = read_options.p4_names(p4s);
1259    }
1260    if let Some(aux) = aux {
1261        read_options = read_options.aux_names(aux);
1262    }
1263    if let Some(tree) = tree {
1264        read_options = read_options.tree(tree);
1265    }
1266    for (alias_name, selection) in parse_aliases(aliases)?.into_iter() {
1267        read_options = read_options.alias(alias_name, selection);
1268    }
1269    let dataset = core_read_root(&path_str, &read_options)?;
1270    Ok(PyDataset(dataset))
1271}
1272
1273/// Write a Dataset to a Parquet file.
1274#[pyfunction]
1275#[pyo3(signature = (dataset, path, *, chunk_size=None, precision="f64"))]
1276pub fn write_parquet(
1277    dataset: &PyDataset,
1278    path: Bound<PyAny>,
1279    chunk_size: Option<usize>,
1280    precision: &str,
1281) -> PyResult<()> {
1282    let path_str = parse_dataset_path(path)?;
1283    let mut write_options = DatasetWriteOptions::default();
1284    if let Some(size) = chunk_size {
1285        write_options.batch_size = size.max(1);
1286    }
1287    write_options.precision = parse_precision_arg(Some(precision))?;
1288    core_write_parquet(dataset.0.as_ref(), &path_str, &write_options).map_err(PyErr::from)
1289}
1290
1291/// Open a streaming Parquet writer for compatible Dataset batches.
1292#[pyfunction]
1293#[pyo3(signature = (path, *, chunk_size=None, precision="f64"))]
1294pub fn open_parquet_writer(
1295    path: Bound<PyAny>,
1296    chunk_size: Option<usize>,
1297    precision: &str,
1298) -> PyResult<PyParquetBatchWriter> {
1299    let path_str = parse_dataset_path(path)?;
1300    let mut write_options = DatasetWriteOptions::default();
1301    if let Some(size) = chunk_size {
1302        write_options.batch_size = size.max(1);
1303    }
1304    write_options.precision = parse_precision_arg(Some(precision))?;
1305    Ok(PyParquetBatchWriter::new(ParquetBatchWriter::new(
1306        &path_str,
1307        write_options,
1308    )?))
1309}
1310
1311/// Write a Dataset to a ROOT file using the oxyroot backend.
1312#[pyfunction]
1313#[pyo3(signature = (dataset, path, *, tree=None, chunk_size=None, precision="f64"))]
1314pub fn write_root(
1315    dataset: &PyDataset,
1316    path: Bound<PyAny>,
1317    tree: Option<String>,
1318    chunk_size: Option<usize>,
1319    precision: &str,
1320) -> PyResult<()> {
1321    let path_str = parse_dataset_path(path)?;
1322    let mut write_options = DatasetWriteOptions::default();
1323    if let Some(name) = tree {
1324        write_options.tree = Some(name);
1325    }
1326    if let Some(size) = chunk_size {
1327        write_options.batch_size = size.max(1);
1328    }
1329    write_options.precision = parse_precision_arg(Some(precision))?;
1330    core_write_root(dataset.0.as_ref(), &path_str, &write_options).map_err(PyErr::from)
1331}
1332
1333#[doc(hidden)]
1334/// Build a Dataset from columnar arrays.
1335///
1336/// This is the canonical high-throughput ingestion path used by Python reader helpers.
1337///
1338/// Examples
1339/// --------
1340/// >>> import laddu.io as ldio  # doctest: +SKIP
1341/// >>> dataset = ldio.from_columns(  # doctest: +SKIP
1342/// ...     {
1343/// ...         'beam_px': [0.0],
1344/// ...         'beam_py': [0.0],
1345/// ...         'beam_pz': [8.5],
1346/// ...         'beam_e': [8.5],
1347/// ...         'proton_px': [0.0],
1348/// ...         'proton_py': [0.0],
1349/// ...         'proton_pz': [-0.2],
1350/// ...         'proton_e': [0.959],
1351/// ...         'pol_magnitude': [0.4],
1352/// ...         'pol_angle': [0.3],
1353/// ...         'weight': [1.0],
1354/// ...     },
1355/// ...     p4s=['beam', 'proton'],
1356/// ...     aux=['pol_magnitude', 'pol_angle'],
1357/// ...     aliases={'target': 'proton'},
1358/// ... )
1359/// >>> dataset[0].p4('target')  # doctest: +SKIP
1360/// Vec4(px=0.0, py=0.0, pz=-0.2, e=0.959)
1361#[pyfunction]
1362#[pyo3(signature = (columns, *, p4s=None, aux=None, aliases=None))]
1363pub fn from_columns(
1364    columns: Bound<'_, PyDict>,
1365    p4s: Option<Vec<String>>,
1366    aux: Option<Vec<String>>,
1367    aliases: Option<Bound<'_, PyDict>>,
1368) -> PyResult<PyDataset> {
1369    let column_names = columns
1370        .iter()
1371        .map(|(key, _)| key.extract::<String>())
1372        .collect::<PyResult<Vec<_>>>()?;
1373
1374    let (detected_p4_names, detected_aux_names) =
1375        infer_p4_and_aux_names_from_columns(&column_names);
1376    let p4_names = p4s.unwrap_or(detected_p4_names);
1377    if p4_names.is_empty() {
1378        let mut partial_components: std::collections::BTreeMap<
1379            String,
1380            std::collections::BTreeSet<&str>,
1381        > = std::collections::BTreeMap::new();
1382        for column_name in &column_names {
1383            let lowered = column_name.to_ascii_lowercase();
1384            for suffix in P4_COMPONENT_SUFFIXES {
1385                if lowered.ends_with(suffix) && column_name.len() > suffix.len() {
1386                    let prefix = column_name[..column_name.len() - suffix.len()].to_string();
1387                    partial_components.entry(prefix).or_default().insert(suffix);
1388                }
1389            }
1390        }
1391        if let Some((prefix, present)) = partial_components.iter().next() {
1392            if present.len() < P4_COMPONENT_SUFFIXES.len() {
1393                let missing = P4_COMPONENT_SUFFIXES
1394                    .iter()
1395                    .filter(|suffix| !present.contains(**suffix))
1396                    .map(|suffix| format!("{prefix}{suffix}"))
1397                    .collect::<Vec<_>>()
1398                    .join(", ");
1399                return Err(PyKeyError::new_err(format!(
1400                    "Missing components [{missing}] for four-momentum '{prefix}'"
1401                )));
1402            }
1403        }
1404        return Err(PyValueError::new_err(
1405            "No four-momentum columns found (expected *_px, *_py, *_pz, *_e)",
1406        ));
1407    }
1408
1409    let aux_names = aux.unwrap_or(detected_aux_names);
1410    let p4_component_columns =
1411        resolve_p4_component_columns(&column_names, &p4_names).map_err(PyErr::from)?;
1412    let resolved_aux_columns =
1413        resolve_columns_case_insensitive(&column_names, &aux_names).map_err(PyErr::from)?;
1414
1415    let n_events = {
1416        let first_name = p4_component_columns
1417            .first()
1418            .map(|components| components[0].clone())
1419            .ok_or_else(|| PyKeyError::new_err("Missing required p4 column"))?;
1420        let values = extract_numeric_column(
1421            columns
1422                .get_item(first_name.as_str())?
1423                .ok_or_else(|| PyKeyError::new_err("Missing required p4 column"))?,
1424            &first_name,
1425        )?;
1426        values.len()
1427    };
1428
1429    let mut p4_columns: Vec<[Vec<f64>; 4]> = Vec::with_capacity(p4_names.len());
1430    for component_names in &p4_component_columns {
1431        let px = extract_numeric_column(
1432            columns
1433                .get_item(component_names[0].as_str())?
1434                .ok_or_else(|| PyKeyError::new_err(format!("Missing {}", component_names[0])))?,
1435            component_names[0].as_str(),
1436        )?;
1437        let py = extract_numeric_column(
1438            columns
1439                .get_item(component_names[1].as_str())?
1440                .ok_or_else(|| PyKeyError::new_err(format!("Missing {}", component_names[1])))?,
1441            component_names[1].as_str(),
1442        )?;
1443        let pz = extract_numeric_column(
1444            columns
1445                .get_item(component_names[2].as_str())?
1446                .ok_or_else(|| PyKeyError::new_err(format!("Missing {}", component_names[2])))?,
1447            component_names[2].as_str(),
1448        )?;
1449        let e = extract_numeric_column(
1450            columns
1451                .get_item(component_names[3].as_str())?
1452                .ok_or_else(|| PyKeyError::new_err(format!("Missing {}", component_names[3])))?,
1453            component_names[3].as_str(),
1454        )?;
1455        if px.len() != n_events
1456            || py.len() != n_events
1457            || pz.len() != n_events
1458            || e.len() != n_events
1459        {
1460            return Err(PyValueError::new_err(
1461                "All p4 components must have the same length",
1462            ));
1463        }
1464        p4_columns.push([px, py, pz, e]);
1465    }
1466
1467    let mut aux_columns: Vec<Vec<f64>> = Vec::with_capacity(resolved_aux_columns.len());
1468    for (aux_name, aux_column_name) in aux_names.iter().zip(&resolved_aux_columns) {
1469        let values = extract_numeric_column(
1470            columns.get_item(aux_column_name.as_str())?.ok_or_else(|| {
1471                PyKeyError::new_err(format!("Missing auxiliary column '{aux_name}'"))
1472            })?,
1473            aux_name,
1474        )?;
1475        if values.len() != n_events {
1476            return Err(PyValueError::new_err(format!(
1477                "Auxiliary column '{aux_name}' length does not match p4 columns"
1478            )));
1479        }
1480        aux_columns.push(values);
1481    }
1482
1483    let weights = if let Some(weight_column_name) = resolve_optional_weight_column(&column_names) {
1484        let weight_values = columns
1485            .get_item(weight_column_name.as_str())?
1486            .ok_or_else(|| PyKeyError::new_err("Missing weight column"))?;
1487        let values = extract_numeric_column(weight_values, "weight")?;
1488        if values.len() != n_events {
1489            return Err(PyValueError::new_err(
1490                "Column 'weight' length does not match p4 columns",
1491            ));
1492        }
1493        values
1494    } else {
1495        vec![1.0; n_events]
1496    };
1497
1498    let parsed_aliases = parse_aliases(aliases)?;
1499    let mut metadata =
1500        DatasetMetadata::new(p4_names.clone(), aux_names.clone()).map_err(PyErr::from)?;
1501    if !parsed_aliases.is_empty() {
1502        metadata
1503            .add_p4_aliases(
1504                parsed_aliases
1505                    .into_iter()
1506                    .map(|(alias_name, selection)| (alias_name, selection.into_selection())),
1507            )
1508            .map_err(PyErr::from)?;
1509    }
1510
1511    let p4_columns = p4_columns
1512        .into_iter()
1513        .map(|components| {
1514            (0..n_events)
1515                .map(|event_idx| {
1516                    laddu_core::vectors::Vec4::new(
1517                        components[0][event_idx],
1518                        components[1][event_idx],
1519                        components[2][event_idx],
1520                        components[3][event_idx],
1521                    )
1522                })
1523                .collect::<Vec<_>>()
1524        })
1525        .collect::<Vec<_>>();
1526
1527    Ok(PyDataset(Arc::new(Dataset::from_columns_global(
1528        metadata,
1529        p4_columns,
1530        aux_columns,
1531        weights,
1532    )?)))
1533}
1534
1535/// A collection of Datasets binned by a Variable
1536///
1537/// BinnedDatasets can be indexed directly to access the underlying Datasets by bin
1538///
1539/// See Also
1540/// --------
1541/// laddu.Dataset.bin_by
1542///
1543#[pyclass(name = "BinnedDataset", module = "laddu", skip_from_py_object)]
1544pub struct PyBinnedDataset(BinnedDataset);
1545
1546#[pymethods]
1547impl PyBinnedDataset {
1548    fn __len__(&self) -> usize {
1549        self.0.n_bins()
1550    }
1551    /// The number of bins in the BinnedDataset
1552    ///
1553    #[getter]
1554    fn n_bins(&self) -> usize {
1555        self.0.n_bins()
1556    }
1557    /// The minimum and maximum values of the binning Variable used to create this BinnedDataset
1558    ///
1559    #[getter]
1560    fn range(&self) -> (f64, f64) {
1561        self.0.range()
1562    }
1563    /// The edges of each bin in the BinnedDataset
1564    ///
1565    #[getter]
1566    fn edges<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<f64>> {
1567        PyArray1::from_slice(py, &self.0.edges())
1568    }
1569    fn __getitem__(&self, index: usize) -> PyResult<PyDataset> {
1570        self.0
1571            .get(index)
1572            .ok_or(PyIndexError::new_err("index out of range"))
1573            .map(|rust_dataset| PyDataset(rust_dataset.clone()))
1574    }
1575
1576    fn __repr__(&self) -> String {
1577        format!(
1578            "BinnedDataset(n_bins={}, range={:?})",
1579            self.0.n_bins(),
1580            self.0.range()
1581        )
1582    }
1583
1584    fn __str__(&self) -> String {
1585        self.__repr__()
1586    }
1587}