1use std::{path::PathBuf, sync::Arc};
2
3use laddu_core::{
4 data::{
5 io::{
6 infer_p4_and_aux_names_from_columns, resolve_columns_case_insensitive,
7 resolve_optional_weight_column, resolve_p4_component_columns, P4_COMPONENT_SUFFIXES,
8 },
9 read_parquet as core_read_parquet,
10 read_parquet_chunks_with_options as core_read_parquet_chunks_with_options,
11 read_root as core_read_root, write_parquet as core_write_parquet,
12 write_root as core_write_root, BinnedDataset, Dataset, DatasetArcIter, DatasetMetadata,
13 DatasetWriteOptions, EventData, FloatPrecision, OwnedEvent, SharedDatasetIterExt,
14 },
15 variables::IntoP4Selection,
16 DatasetReadOptions, Vec4,
17};
18use numpy::{PyArray1, PyReadonlyArray1};
19use pyo3::{
20 exceptions::{PyIndexError, PyKeyError, PyTypeError, PyValueError},
21 prelude::*,
22 types::{PyDict, PyList},
23 IntoPyObjectExt,
24};
25
26use crate::{
27 variables::{PyVariable, PyVariableExpression},
28 vectors::PyVec4,
29};
30
31fn parse_aliases(aliases: Option<Bound<'_, PyDict>>) -> PyResult<Vec<(String, Vec<String>)>> {
32 let Some(aliases) = aliases else {
33 return Ok(Vec::new());
34 };
35
36 let mut parsed = Vec::new();
37 for (key, value) in aliases.iter() {
38 let alias_name = key.extract::<String>()?;
39 let selection = if let Ok(single) = value.extract::<String>() {
40 vec![single]
41 } else {
42 let seq = value.extract::<Vec<String>>().map_err(|_| {
43 PyTypeError::new_err("Alias values must be a string or a sequence of strings")
44 })?;
45 if seq.is_empty() {
46 return Err(PyValueError::new_err(format!(
47 "Alias '{alias_name}' must reference at least one particle",
48 )));
49 }
50 seq
51 };
52 parsed.push((alias_name, selection));
53 }
54
55 Ok(parsed)
56}
57
58fn parse_dataset_path(path: Bound<'_, PyAny>) -> PyResult<String> {
59 if let Ok(s) = path.extract::<String>() {
60 Ok(s)
61 } else if let Ok(pathbuf) = path.extract::<PathBuf>() {
62 Ok(pathbuf.to_string_lossy().into_owned())
63 } else {
64 Err(PyTypeError::new_err("Expected str or Path"))
65 }
66}
67
68fn parse_precision_arg(value: Option<&str>) -> PyResult<FloatPrecision> {
69 match value.map(|v| v.to_ascii_lowercase()) {
70 None => Ok(FloatPrecision::F64),
71 Some(name) if name == "f64" || name == "float64" || name == "double" => {
72 Ok(FloatPrecision::F64)
73 }
74 Some(name) if name == "f32" || name == "float32" || name == "float" => {
75 Ok(FloatPrecision::F32)
76 }
77 Some(other) => Err(PyValueError::new_err(format!(
78 "Unsupported precision '{other}' (expected 'f64' or 'f32')"
79 ))),
80 }
81}
82
83fn extract_numeric_column(value: Bound<'_, PyAny>, name: &str) -> PyResult<Vec<f64>> {
84 if let Ok(array) = value.extract::<PyReadonlyArray1<'_, f64>>() {
85 return Ok(array.as_slice()?.to_vec());
86 }
87 if let Ok(array) = value.extract::<PyReadonlyArray1<'_, f32>>() {
88 return Ok(array.as_slice()?.iter().map(|v| *v as f64).collect());
89 }
90 if let Ok(values) = value.extract::<Vec<f64>>() {
91 return Ok(values);
92 }
93 if let Ok(values) = value.extract::<Vec<f32>>() {
94 return Ok(values.into_iter().map(|v| v as f64).collect());
95 }
96 if let Ok(list) = value.cast::<PyList>() {
97 let mut converted = Vec::with_capacity(list.len());
98 for item in list.iter() {
99 converted.push(item.extract::<f64>().map_err(|_| {
100 PyTypeError::new_err(format!(
101 "Column '{name}' must be numeric (float32/float64/list of floats)"
102 ))
103 })?);
104 }
105 return Ok(converted);
106 }
107 Err(PyTypeError::new_err(format!(
108 "Column '{name}' must be numeric (float32/float64/list of floats)"
109 )))
110}
111
112fn metadata_from_names_and_aliases(
113 p4_names: Vec<String>,
114 aux_names: Vec<String>,
115 aliases: Option<Bound<'_, PyDict>>,
116) -> PyResult<DatasetMetadata> {
117 let parsed_aliases = parse_aliases(aliases)?;
118 let mut metadata = DatasetMetadata::new(p4_names, aux_names).map_err(PyErr::from)?;
119 if !parsed_aliases.is_empty() {
120 metadata
121 .add_p4_aliases(
122 parsed_aliases
123 .into_iter()
124 .map(|(alias_name, selection)| (alias_name, selection.into_selection())),
125 )
126 .map_err(PyErr::from)?;
127 }
128 Ok(metadata)
129}
130
131fn parse_p4_mapping(p4: Bound<'_, PyDict>) -> PyResult<Vec<(String, Vec4)>> {
132 p4.iter()
133 .map(|(key, value)| {
134 Ok((
135 key.extract::<String>()?,
136 value
137 .extract::<PyVec4>()
138 .map_err(|_| PyTypeError::new_err("p4 values must be laddu.Vec4 instances"))?
139 .0,
140 ))
141 })
142 .collect()
143}
144
145fn parse_aux_mapping(aux: Option<Bound<'_, PyDict>>) -> PyResult<Vec<(String, f64)>> {
146 let Some(aux) = aux else {
147 return Ok(Vec::new());
148 };
149 aux.iter()
150 .map(|(key, value)| Ok((key.extract::<String>()?, value.extract::<f64>()?)))
151 .collect()
152}
153
154fn parse_p4_column(values: Vec<PyVec4>) -> Vec<Vec4> {
155 values.into_iter().map(|value| value.0).collect()
156}
157
158fn dataset_from_py_events(
159 events: Vec<PyEvent>,
160 p4_names: Option<Vec<String>>,
161 aux_names: Option<Vec<String>>,
162 aliases: Option<Bound<PyDict>>,
163 global: bool,
164) -> PyResult<PyDataset> {
165 let inferred_metadata = events
166 .iter()
167 .find_map(|event| event.has_metadata.then(|| event.event.metadata_arc()));
168
169 let aliases = parse_aliases(aliases)?;
170 let use_explicit_metadata = p4_names.is_some() || aux_names.is_some() || !aliases.is_empty();
171
172 let metadata = if use_explicit_metadata {
173 let resolved_p4_names = match (p4_names, inferred_metadata.as_ref()) {
174 (Some(names), _) => names,
175 (None, Some(metadata)) => metadata.p4_names().to_vec(),
176 (None, None) => Vec::new(),
177 };
178 let resolved_aux_names = match (aux_names, inferred_metadata.as_ref()) {
179 (Some(names), _) => names,
180 (None, Some(metadata)) => metadata.aux_names().to_vec(),
181 (None, None) => Vec::new(),
182 };
183
184 if !aliases.is_empty() && resolved_p4_names.is_empty() {
185 return Err(PyValueError::new_err(
186 "`aliases` requires `p4_names` or events with metadata for resolution",
187 ));
188 }
189
190 let mut metadata =
191 DatasetMetadata::new(resolved_p4_names, resolved_aux_names).map_err(PyErr::from)?;
192 if !aliases.is_empty() {
193 metadata
194 .add_p4_aliases(
195 aliases
196 .into_iter()
197 .map(|(alias_name, selection)| (alias_name, selection.into_selection())),
198 )
199 .map_err(PyErr::from)?;
200 }
201 Some(Arc::new(metadata))
202 } else {
203 inferred_metadata
204 };
205
206 let events: Vec<Arc<EventData>> = events
207 .into_iter()
208 .map(|event| event.event.data_arc())
209 .collect();
210 let dataset = match (metadata, global) {
211 (Some(metadata), true) => Dataset::new_with_metadata(events, metadata),
212 (Some(metadata), false) => Dataset::new_local(events, metadata),
213 (None, true) => Dataset::new(events),
214 (None, false) => Dataset::new_local(events, Arc::new(DatasetMetadata::default())),
215 };
216 Ok(PyDataset(Arc::new(dataset)))
217}
218
/// Python-facing `laddu.Event`: one event's data paired with a metadata handle.
#[pyclass(name = "Event", module = "laddu", from_py_object)]
#[derive(Clone)]
pub struct PyEvent {
    /// The wrapped event (data plus shared metadata).
    pub event: OwnedEvent,
    /// Whether `event` carries real metadata. Events built without names get an empty
    /// placeholder and `false` here, which makes name-based accessors raise `ValueError`.
    has_metadata: bool,
}
262
263#[pymethods]
264impl PyEvent {
265 #[new]
266 #[pyo3(signature = (p4s, aux, weight, *, p4_names=None, aux_names=None, aliases=None))]
267 fn new(
268 p4s: Vec<PyVec4>,
269 aux: Vec<f64>,
270 weight: f64,
271 p4_names: Option<Vec<String>>,
272 aux_names: Option<Vec<String>>,
273 aliases: Option<Bound<PyDict>>,
274 ) -> PyResult<Self> {
275 let event = EventData {
276 p4s: p4s.into_iter().map(|arr| arr.0).collect(),
277 aux,
278 weight,
279 };
280 let aliases = parse_aliases(aliases)?;
281
282 let missing_p4_names = p4_names
283 .as_ref()
284 .map(|names| names.is_empty())
285 .unwrap_or(true);
286
287 if !aliases.is_empty() && missing_p4_names {
288 return Err(PyValueError::new_err(
289 "`aliases` requires `p4_names` so selections can be resolved",
290 ));
291 }
292
293 let metadata_provided = p4_names.is_some() || aux_names.is_some() || !aliases.is_empty();
294 let metadata = if metadata_provided {
295 let p4_names = p4_names.unwrap_or_default();
296 let aux_names = aux_names.unwrap_or_default();
297 let mut metadata = DatasetMetadata::new(p4_names, aux_names).map_err(PyErr::from)?;
298 if !aliases.is_empty() {
299 metadata
300 .add_p4_aliases(
301 aliases.into_iter().map(|(alias_name, selection)| {
302 (alias_name, selection.into_selection())
303 }),
304 )
305 .map_err(PyErr::from)?;
306 }
307 Arc::new(metadata)
308 } else {
309 Arc::new(DatasetMetadata::empty())
310 };
311 let event = OwnedEvent::new(Arc::new(event), metadata);
312 Ok(Self {
313 event,
314 has_metadata: metadata_provided,
315 })
316 }
317 fn __str__(&self) -> String {
318 self.event.to_string()
319 }
320 fn __repr__(&self) -> String {
321 self.__str__()
322 }
323 #[getter]
326 fn p4s<'py>(&self, py: Python<'py>) -> PyResult<Py<PyDict>> {
327 self.ensure_metadata()?;
328 let mapping = PyDict::new(py);
329 for (name, vec4) in self.event.p4s() {
330 mapping.set_item(name, PyVec4(vec4))?;
331 }
332 Ok(mapping.into())
333 }
334 #[getter]
337 #[pyo3(name = "aux")]
338 fn aux_mapping<'py>(&self, py: Python<'py>) -> PyResult<Py<PyDict>> {
339 self.ensure_metadata()?;
340 let mapping = PyDict::new(py);
341 for (name, value) in self.event.aux() {
342 mapping.set_item(name, value)?;
343 }
344 Ok(mapping.into())
345 }
346 #[getter]
349 fn get_weight(&self) -> f64 {
350 self.event.weight()
351 }
352 fn get_p4_sum(&self, names: Vec<String>) -> PyResult<PyVec4> {
365 let indices = self.resolve_p4_indices(&names)?;
366 Ok(PyVec4(self.event.data().get_p4_sum(indices)))
367 }
368 pub fn boost_to_rest_frame_of(&self, names: Vec<String>) -> PyResult<Self> {
382 let indices = self.resolve_p4_indices(&names)?;
383 let boosted = self.event.data().boost_to_rest_frame_of(indices);
384 Ok(Self {
385 event: OwnedEvent::new(Arc::new(boosted), self.event.metadata_arc()),
386 has_metadata: self.has_metadata,
387 })
388 }
389 fn evaluate(&self, variable: Bound<'_, PyAny>) -> PyResult<f64> {
419 let mut variable = variable.extract::<PyVariable>()?;
420 let metadata = self.ensure_metadata()?;
421 variable.bind_in_place(metadata)?;
422 variable.evaluate_event(&self.event)
423 }
424
425 fn p4(&self, name: &str) -> PyResult<PyVec4> {
427 self.ensure_metadata()?;
428 self.event
429 .p4(name)
430 .map(PyVec4)
431 .ok_or_else(|| PyKeyError::new_err(format!("Unknown particle name '{name}'")))
432 }
433}
434
435impl PyEvent {
436 fn ensure_metadata(&self) -> PyResult<&DatasetMetadata> {
437 if !self.has_metadata {
438 Err(PyValueError::new_err(
439 "Event has no associated metadata for name-based operations",
440 ))
441 } else {
442 Ok(self.event.metadata())
443 }
444 }
445
446 fn resolve_p4_indices(&self, names: &[String]) -> PyResult<Vec<usize>> {
447 let metadata = self.ensure_metadata()?;
448 let mut resolved = Vec::new();
449 for name in names {
450 let selection = metadata
451 .p4_selection(name)
452 .ok_or_else(|| PyKeyError::new_err(format!("Unknown particle name '{name}'")))?;
453 resolved.extend_from_slice(selection.indices());
454 }
455 Ok(resolved)
456 }
457
458 pub(crate) fn metadata_opt(&self) -> Option<&DatasetMetadata> {
459 self.has_metadata.then(|| self.event.metadata())
460 }
461}
462
/// Python-facing `laddu.Dataset`, holding the underlying [`Dataset`] behind an [`Arc`].
#[doc(hidden)]
#[pyclass(name = "Dataset", module = "laddu", subclass, skip_from_py_object)]
#[derive(Clone)] // cloning shares the same `Arc<Dataset>`
pub struct PyDataset(pub Arc<Dataset>);
507
508#[pyclass(
509 name = "ParquetChunkIter",
510 module = "laddu",
511 unsendable,
512 skip_from_py_object
513)]
514pub struct PyParquetChunkIter {
515 chunks: Box<dyn Iterator<Item = laddu_core::LadduResult<Arc<Dataset>>> + Send>,
516}
517
518#[pymethods]
519impl PyParquetChunkIter {
520 fn __iter__(slf: PyRef<'_, Self>) -> Py<PyParquetChunkIter> {
521 slf.into()
522 }
523
524 fn __next__(&mut self) -> PyResult<Option<PyDataset>> {
525 match self.chunks.next() {
526 Some(Ok(dataset)) => Ok(Some(PyDataset(dataset))),
527 Some(Err(err)) => Err(PyErr::from(err)),
528 None => Ok(None),
529 }
530 }
531}
532
533#[pyclass(
534 name = "DatasetEventsGlobal",
535 module = "laddu",
536 unsendable,
537 skip_from_py_object
538)]
539pub struct PyDatasetEventsGlobalIter {
540 iter: DatasetArcIter,
541}
542
543#[pymethods]
544impl PyDatasetEventsGlobalIter {
545 fn __iter__(slf: PyRef<'_, Self>) -> Py<PyDatasetEventsGlobalIter> {
546 slf.into()
547 }
548
549 fn __next__(&mut self) -> Option<PyEvent> {
550 self.iter.next().map(|rust_event| PyEvent {
551 event: rust_event,
552 has_metadata: true,
553 })
554 }
555}
556
557#[pyclass(
558 name = "DatasetEventsLocal",
559 module = "laddu",
560 unsendable,
561 skip_from_py_object
562)]
563pub struct PyDatasetEventsLocalIter {
564 dataset: Arc<Dataset>,
565 index: usize,
566}
567
568#[pymethods]
569impl PyDatasetEventsLocalIter {
570 fn __iter__(slf: PyRef<'_, Self>) -> Py<PyDatasetEventsLocalIter> {
571 slf.into()
572 }
573
574 fn __next__(&mut self) -> Option<PyEvent> {
575 if self.index >= self.dataset.n_events_local() {
576 return None;
577 }
578 let event = self
579 .dataset
580 .event_local(self.index)
581 .expect("local event index should exist")
582 .to_event_data();
583 self.index += 1;
584 Some(PyEvent {
585 event: OwnedEvent::new(Arc::new(event), self.dataset.metadata_arc()),
586 has_metadata: true,
587 })
588 }
589}
590
591#[pymethods]
592impl PyDataset {
593 #[new]
594 #[pyo3(signature = (events, *, p4_names=None, aux_names=None, aliases=None))]
595 fn new(
596 events: Vec<PyEvent>,
597 p4_names: Option<Vec<String>>,
598 aux_names: Option<Vec<String>>,
599 aliases: Option<Bound<PyDict>>,
600 ) -> PyResult<Self> {
601 dataset_from_py_events(events, p4_names, aux_names, aliases, true)
602 }
603
604 #[staticmethod]
606 #[pyo3(signature = (events, *, p4_names=None, aux_names=None, aliases=None))]
607 fn from_events_local(
608 events: Vec<PyEvent>,
609 p4_names: Option<Vec<String>>,
610 aux_names: Option<Vec<String>>,
611 aliases: Option<Bound<PyDict>>,
612 ) -> PyResult<Self> {
613 dataset_from_py_events(events, p4_names, aux_names, aliases, false)
614 }
615
616 #[staticmethod]
621 #[pyo3(signature = (events, *, p4_names=None, aux_names=None, aliases=None))]
622 fn from_events_global(
623 events: Vec<PyEvent>,
624 p4_names: Option<Vec<String>>,
625 aux_names: Option<Vec<String>>,
626 aliases: Option<Bound<PyDict>>,
627 ) -> PyResult<Self> {
628 dataset_from_py_events(events, p4_names, aux_names, aliases, true)
629 }
630
631 #[staticmethod]
642 #[pyo3(signature = (*, p4_names, aux_names=None, aliases=None))]
643 fn empty_local(
644 p4_names: Vec<String>,
645 aux_names: Option<Vec<String>>,
646 aliases: Option<Bound<'_, PyDict>>,
647 ) -> PyResult<Self> {
648 let metadata =
649 metadata_from_names_and_aliases(p4_names, aux_names.unwrap_or_default(), aliases)?;
650 Ok(Self(Arc::new(Dataset::empty_local(metadata))))
651 }
652
653 #[pyo3(signature = (*, p4, aux=None, weight=1.0))]
664 fn push_event_local(
665 &mut self,
666 p4: Bound<'_, PyDict>,
667 aux: Option<Bound<'_, PyDict>>,
668 weight: f64,
669 ) -> PyResult<()> {
670 let p4 = parse_p4_mapping(p4)?;
671 let aux = parse_aux_mapping(aux)?;
672 Arc::make_mut(&mut self.0)
673 .push_event_named_local(p4, aux, weight)
674 .map_err(PyErr::from)
675 }
676
677 #[pyo3(signature = (*, p4, aux=None, weight=1.0))]
682 fn push_event_global(
683 &mut self,
684 p4: Bound<'_, PyDict>,
685 aux: Option<Bound<'_, PyDict>>,
686 weight: f64,
687 ) -> PyResult<()> {
688 let p4 = parse_p4_mapping(p4)?;
689 let aux = parse_aux_mapping(aux)?;
690 Arc::make_mut(&mut self.0)
691 .push_event_named_global(p4, aux, weight)
692 .map_err(PyErr::from)
693 }
694
695 #[pyo3(signature = (name, values))]
697 fn add_p4_column_local(&mut self, name: String, values: Vec<PyVec4>) -> PyResult<()> {
698 Arc::make_mut(&mut self.0)
699 .add_p4_column_local(name, parse_p4_column(values))
700 .map_err(PyErr::from)
701 }
702
703 #[pyo3(signature = (name, values))]
705 fn add_aux_column_local(&mut self, name: String, values: Bound<'_, PyAny>) -> PyResult<()> {
706 let values = extract_numeric_column(values, &name)?;
707 Arc::make_mut(&mut self.0)
708 .add_aux_column_local(name, values)
709 .map_err(PyErr::from)
710 }
711
712 #[pyo3(signature = (name, values))]
717 fn add_p4_column_global(&mut self, name: String, values: Vec<PyVec4>) -> PyResult<()> {
718 Arc::make_mut(&mut self.0)
719 .add_p4_column_global(name, parse_p4_column(values))
720 .map_err(PyErr::from)
721 }
722
723 #[pyo3(signature = (name, values))]
728 fn add_aux_column_global(&mut self, name: String, values: Bound<'_, PyAny>) -> PyResult<()> {
729 let values = extract_numeric_column(values, &name)?;
730 Arc::make_mut(&mut self.0)
731 .add_aux_column_global(name, values)
732 .map_err(PyErr::from)
733 }
734
735 fn __len__(&self) -> usize {
736 self.0.n_events()
737 }
738 fn __iter__(&self) -> PyResult<()> {
739 Err(PyTypeError::new_err(
740 "Dataset iteration is explicit; use dataset.events_local or dataset.events_global",
741 ))
742 }
743 #[getter]
745 fn n_events_local(&self) -> usize {
746 self.0.n_events_local()
747 }
748 fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyDataset> {
749 if let Ok(other_ds) = other.extract::<PyRef<PyDataset>>() {
750 Ok(PyDataset(Arc::new(self.0.as_ref() + other_ds.0.as_ref())))
751 } else if let Ok(other_int) = other.extract::<usize>() {
752 if other_int == 0 {
753 Ok(self.clone())
754 } else {
755 Err(PyTypeError::new_err(
756 "Addition with an integer for this type is only defined for 0",
757 ))
758 }
759 } else {
760 Err(PyTypeError::new_err("Unsupported operand type for +"))
761 }
762 }
763 fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyDataset> {
764 if let Ok(other_ds) = other.extract::<PyRef<PyDataset>>() {
765 Ok(PyDataset(Arc::new(other_ds.0.as_ref() + self.0.as_ref())))
766 } else if let Ok(other_int) = other.extract::<usize>() {
767 if other_int == 0 {
768 Ok(self.clone())
769 } else {
770 Err(PyTypeError::new_err(
771 "Addition with an integer for this type is only defined for 0",
772 ))
773 }
774 } else {
775 Err(PyTypeError::new_err("Unsupported operand type for +"))
776 }
777 }
778
779 fn __repr__(&self) -> String {
780 format!(
781 "Dataset(n_events={}, n_events_local={}, p4_names={:?}, aux_names={:?})",
782 self.0.n_events_global(),
783 self.0.n_events_local(),
784 self.0.p4_names(),
785 self.0.aux_names()
786 )
787 }
788
789 fn __str__(&self) -> String {
790 self.__repr__()
791 }
792
793 #[getter]
806 fn n_events(&self) -> usize {
807 self.0.n_events()
808 }
809 #[getter]
811 fn n_events_global(&self) -> usize {
812 self.0.n_events_global()
813 }
814 #[getter]
816 fn p4_names(&self) -> Vec<String> {
817 self.0.p4_names().to_vec()
818 }
819 #[getter]
821 fn aux_names(&self) -> Vec<String> {
822 self.0.aux_names().to_vec()
823 }
824
825 #[getter]
837 fn n_events_weighted(&self) -> f64 {
838 self.0.n_events_weighted()
839 }
840 #[getter]
842 fn n_events_weighted_global(&self) -> f64 {
843 self.0.n_events_weighted_global()
844 }
845 #[getter]
858 fn n_events_weighted_local(&self) -> f64 {
859 self.0.n_events_weighted_local()
860 }
861 #[getter]
873 fn weights<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<f64>> {
874 PyArray1::from_slice(py, &self.0.weights())
875 }
876 #[getter]
878 fn weights_global<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<f64>> {
879 PyArray1::from_slice(py, &self.0.weights_global())
880 }
881 #[getter]
893 fn weights_local<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<f64>> {
894 PyArray1::from_slice(py, &self.0.weights_local())
895 }
896 #[getter]
898 fn events_global(&self) -> PyDatasetEventsGlobalIter {
899 PyDatasetEventsGlobalIter {
900 iter: self.0.shared_iter_global(),
901 }
902 }
903 #[getter]
909 fn events_local(&self) -> PyDatasetEventsLocalIter {
910 PyDatasetEventsLocalIter {
911 dataset: self.0.clone(),
912 index: 0,
913 }
914 }
915 fn p4_by_name(&self, index: usize, name: &str) -> PyResult<PyVec4> {
917 self.0
918 .p4_by_name(index, name)
919 .map(PyVec4)
920 .ok_or_else(|| PyKeyError::new_err(format!("Unknown particle name '{name}'")))
921 }
922 fn aux_by_name(&self, index: usize, name: &str) -> PyResult<f64> {
924 self.0
925 .aux_by_name(index, name)
926 .ok_or_else(|| PyKeyError::new_err(format!("Unknown auxiliary name '{name}'")))
927 }
928 fn event_global(&self, index: usize) -> PyResult<PyEvent> {
934 let event = self
935 .0
936 .event_global(index)
937 .map_err(|_| PyIndexError::new_err("index out of range"))?;
938 Ok(PyEvent {
939 event,
940 has_metadata: true,
941 })
942 }
943 fn __getitem__<'py>(
944 &self,
945 py: Python<'py>,
946 index: Bound<'py, PyAny>,
947 ) -> PyResult<Bound<'py, PyAny>> {
948 if let Ok(value) = self.evaluate(py, index.clone()) {
949 value.into_bound_py_any(py)
950 } else if let Ok(index) = index.extract::<usize>() {
951 let event = self
952 .0
953 .event_global(index)
954 .map_err(|_| PyIndexError::new_err("index out of range"))?;
955 PyEvent {
956 event,
957 has_metadata: true,
958 }
959 .into_bound_py_any(py)
960 } else {
961 Err(PyTypeError::new_err(
962 "Unsupported index type (int or Variable)",
963 ))
964 }
965 }
966 #[pyo3(signature = (variable, bins, range))]
1004 fn bin_by(
1005 &self,
1006 variable: Bound<'_, PyAny>,
1007 bins: usize,
1008 range: (f64, f64),
1009 ) -> PyResult<PyBinnedDataset> {
1010 let py_variable = variable.extract::<PyVariable>()?;
1011 let bound_variable = py_variable.bound(self.0.metadata())?;
1012 Ok(PyBinnedDataset(self.0.bin_by(
1013 bound_variable,
1014 bins,
1015 range,
1016 )?))
1017 }
1018 pub fn filter(&self, expression: &PyVariableExpression) -> PyResult<PyDataset> {
1036 Ok(PyDataset(
1037 self.0.filter(&expression.0).map_err(PyErr::from)?,
1038 ))
1039 }
1040 fn bootstrap(&self, seed: usize) -> PyDataset {
1061 PyDataset(self.0.bootstrap(seed))
1062 }
1063 pub fn boost_to_rest_frame_of(&self, names: Vec<String>) -> PyDataset {
1081 PyDataset(self.0.boost_to_rest_frame_of(&names))
1082 }
1083 fn evaluate<'py>(
1101 &self,
1102 py: Python<'py>,
1103 variable: Bound<'py, PyAny>,
1104 ) -> PyResult<Bound<'py, PyArray1<f64>>> {
1105 let variable = variable.extract::<PyVariable>()?;
1106 let bound_variable = variable.bound(self.0.metadata())?;
1107 let values = self.0.evaluate(&bound_variable).map_err(PyErr::from)?;
1108 Ok(PyArray1::from_vec(py, values))
1109 }
1110}
1111
1112#[pyfunction]
1126#[pyo3(signature = (path, *, p4s=None, aux=None, aliases=None))]
1127pub fn read_parquet(
1128 path: Bound<PyAny>,
1129 p4s: Option<Vec<String>>,
1130 aux: Option<Vec<String>>,
1131 aliases: Option<Bound<PyDict>>,
1132) -> PyResult<PyDataset> {
1133 let path_str = parse_dataset_path(path)?;
1134 let mut read_options = DatasetReadOptions::default();
1135 if let Some(p4s) = p4s {
1136 read_options = read_options.p4_names(p4s);
1137 }
1138 if let Some(aux) = aux {
1139 read_options = read_options.aux_names(aux);
1140 }
1141 for (alias_name, selection) in parse_aliases(aliases)?.into_iter() {
1142 read_options = read_options.alias(alias_name, selection);
1143 }
1144 let dataset = core_read_parquet(&path_str, &read_options)?;
1145 Ok(PyDataset(dataset))
1146}
1147
1148#[pyfunction]
1150#[pyo3(signature = (path, *, p4s=None, aux=None, aliases=None, chunk_size=None))]
1151pub fn read_parquet_chunked(
1152 path: Bound<PyAny>,
1153 p4s: Option<Vec<String>>,
1154 aux: Option<Vec<String>>,
1155 aliases: Option<Bound<PyDict>>,
1156 chunk_size: Option<usize>,
1157) -> PyResult<PyParquetChunkIter> {
1158 let path_str = parse_dataset_path(path)?;
1159 let mut read_options = DatasetReadOptions::default();
1160 if let Some(p4s) = p4s {
1161 read_options = read_options.p4_names(p4s);
1162 }
1163 if let Some(aux) = aux {
1164 read_options = read_options.aux_names(aux);
1165 }
1166 if let Some(chunk_size) = chunk_size {
1167 read_options = read_options.chunk_size(chunk_size);
1168 }
1169 for (alias_name, selection) in parse_aliases(aliases)?.into_iter() {
1170 read_options = read_options.alias(alias_name, selection);
1171 }
1172
1173 let chunks = core_read_parquet_chunks_with_options(&path_str, &read_options)?;
1174 Ok(PyParquetChunkIter {
1175 chunks: Box::new(chunks),
1176 })
1177}
1178
1179#[pyfunction]
1193#[pyo3(signature = (path, *, tree=None, p4s=None, aux=None, aliases=None))]
1194pub fn read_root(
1195 path: Bound<PyAny>,
1196 tree: Option<String>,
1197 p4s: Option<Vec<String>>,
1198 aux: Option<Vec<String>>,
1199 aliases: Option<Bound<PyDict>>,
1200) -> PyResult<PyDataset> {
1201 let path_str = parse_dataset_path(path)?;
1202 let mut read_options = DatasetReadOptions::default();
1203 if let Some(p4s) = p4s {
1204 read_options = read_options.p4_names(p4s);
1205 }
1206 if let Some(aux) = aux {
1207 read_options = read_options.aux_names(aux);
1208 }
1209 if let Some(tree) = tree {
1210 read_options = read_options.tree(tree);
1211 }
1212 for (alias_name, selection) in parse_aliases(aliases)?.into_iter() {
1213 read_options = read_options.alias(alias_name, selection);
1214 }
1215 let dataset = core_read_root(&path_str, &read_options)?;
1216 Ok(PyDataset(dataset))
1217}
1218
1219#[pyfunction]
1221#[pyo3(signature = (dataset, path, *, chunk_size=None, precision="f64"))]
1222pub fn write_parquet(
1223 dataset: &PyDataset,
1224 path: Bound<PyAny>,
1225 chunk_size: Option<usize>,
1226 precision: &str,
1227) -> PyResult<()> {
1228 let path_str = parse_dataset_path(path)?;
1229 let mut write_options = DatasetWriteOptions::default();
1230 if let Some(size) = chunk_size {
1231 write_options.batch_size = size.max(1);
1232 }
1233 write_options.precision = parse_precision_arg(Some(precision))?;
1234 core_write_parquet(dataset.0.as_ref(), &path_str, &write_options).map_err(PyErr::from)
1235}
1236
1237#[pyfunction]
1239#[pyo3(signature = (dataset, path, *, tree=None, chunk_size=None, precision="f64"))]
1240pub fn write_root(
1241 dataset: &PyDataset,
1242 path: Bound<PyAny>,
1243 tree: Option<String>,
1244 chunk_size: Option<usize>,
1245 precision: &str,
1246) -> PyResult<()> {
1247 let path_str = parse_dataset_path(path)?;
1248 let mut write_options = DatasetWriteOptions::default();
1249 if let Some(name) = tree {
1250 write_options.tree = Some(name);
1251 }
1252 if let Some(size) = chunk_size {
1253 write_options.batch_size = size.max(1);
1254 }
1255 write_options.precision = parse_precision_arg(Some(precision))?;
1256 core_write_root(dataset.0.as_ref(), &path_str, &write_options).map_err(PyErr::from)
1257}
1258
1259#[doc(hidden)]
1260#[pyfunction]
1288#[pyo3(signature = (columns, *, p4s=None, aux=None, aliases=None))]
1289pub fn from_columns(
1290 columns: Bound<'_, PyDict>,
1291 p4s: Option<Vec<String>>,
1292 aux: Option<Vec<String>>,
1293 aliases: Option<Bound<'_, PyDict>>,
1294) -> PyResult<PyDataset> {
1295 let column_names = columns
1296 .iter()
1297 .map(|(key, _)| key.extract::<String>())
1298 .collect::<PyResult<Vec<_>>>()?;
1299
1300 let (detected_p4_names, detected_aux_names) =
1301 infer_p4_and_aux_names_from_columns(&column_names);
1302 let p4_names = p4s.unwrap_or(detected_p4_names);
1303 if p4_names.is_empty() {
1304 let mut partial_components: std::collections::BTreeMap<
1305 String,
1306 std::collections::BTreeSet<&str>,
1307 > = std::collections::BTreeMap::new();
1308 for column_name in &column_names {
1309 let lowered = column_name.to_ascii_lowercase();
1310 for suffix in P4_COMPONENT_SUFFIXES {
1311 if lowered.ends_with(suffix) && column_name.len() > suffix.len() {
1312 let prefix = column_name[..column_name.len() - suffix.len()].to_string();
1313 partial_components.entry(prefix).or_default().insert(suffix);
1314 }
1315 }
1316 }
1317 if let Some((prefix, present)) = partial_components.iter().next() {
1318 if present.len() < P4_COMPONENT_SUFFIXES.len() {
1319 let missing = P4_COMPONENT_SUFFIXES
1320 .iter()
1321 .filter(|suffix| !present.contains(**suffix))
1322 .map(|suffix| format!("{prefix}{suffix}"))
1323 .collect::<Vec<_>>()
1324 .join(", ");
1325 return Err(PyKeyError::new_err(format!(
1326 "Missing components [{missing}] for four-momentum '{prefix}'"
1327 )));
1328 }
1329 }
1330 return Err(PyValueError::new_err(
1331 "No four-momentum columns found (expected *_px, *_py, *_pz, *_e)",
1332 ));
1333 }
1334
1335 let aux_names = aux.unwrap_or(detected_aux_names);
1336 let p4_component_columns =
1337 resolve_p4_component_columns(&column_names, &p4_names).map_err(PyErr::from)?;
1338 let resolved_aux_columns =
1339 resolve_columns_case_insensitive(&column_names, &aux_names).map_err(PyErr::from)?;
1340
1341 let n_events = {
1342 let first_name = p4_component_columns
1343 .first()
1344 .map(|components| components[0].clone())
1345 .ok_or_else(|| PyKeyError::new_err("Missing required p4 column"))?;
1346 let values = extract_numeric_column(
1347 columns
1348 .get_item(first_name.as_str())?
1349 .ok_or_else(|| PyKeyError::new_err("Missing required p4 column"))?,
1350 &first_name,
1351 )?;
1352 values.len()
1353 };
1354
1355 let mut p4_columns: Vec<[Vec<f64>; 4]> = Vec::with_capacity(p4_names.len());
1356 for component_names in &p4_component_columns {
1357 let px = extract_numeric_column(
1358 columns
1359 .get_item(component_names[0].as_str())?
1360 .ok_or_else(|| PyKeyError::new_err(format!("Missing {}", component_names[0])))?,
1361 component_names[0].as_str(),
1362 )?;
1363 let py = extract_numeric_column(
1364 columns
1365 .get_item(component_names[1].as_str())?
1366 .ok_or_else(|| PyKeyError::new_err(format!("Missing {}", component_names[1])))?,
1367 component_names[1].as_str(),
1368 )?;
1369 let pz = extract_numeric_column(
1370 columns
1371 .get_item(component_names[2].as_str())?
1372 .ok_or_else(|| PyKeyError::new_err(format!("Missing {}", component_names[2])))?,
1373 component_names[2].as_str(),
1374 )?;
1375 let e = extract_numeric_column(
1376 columns
1377 .get_item(component_names[3].as_str())?
1378 .ok_or_else(|| PyKeyError::new_err(format!("Missing {}", component_names[3])))?,
1379 component_names[3].as_str(),
1380 )?;
1381 if px.len() != n_events
1382 || py.len() != n_events
1383 || pz.len() != n_events
1384 || e.len() != n_events
1385 {
1386 return Err(PyValueError::new_err(
1387 "All p4 components must have the same length",
1388 ));
1389 }
1390 p4_columns.push([px, py, pz, e]);
1391 }
1392
1393 let mut aux_columns: Vec<Vec<f64>> = Vec::with_capacity(resolved_aux_columns.len());
1394 for (aux_name, aux_column_name) in aux_names.iter().zip(&resolved_aux_columns) {
1395 let values = extract_numeric_column(
1396 columns.get_item(aux_column_name.as_str())?.ok_or_else(|| {
1397 PyKeyError::new_err(format!("Missing auxiliary column '{aux_name}'"))
1398 })?,
1399 aux_name,
1400 )?;
1401 if values.len() != n_events {
1402 return Err(PyValueError::new_err(format!(
1403 "Auxiliary column '{aux_name}' length does not match p4 columns"
1404 )));
1405 }
1406 aux_columns.push(values);
1407 }
1408
1409 let weights = if let Some(weight_column_name) = resolve_optional_weight_column(&column_names) {
1410 let weight_values = columns
1411 .get_item(weight_column_name.as_str())?
1412 .ok_or_else(|| PyKeyError::new_err("Missing weight column"))?;
1413 let values = extract_numeric_column(weight_values, "weight")?;
1414 if values.len() != n_events {
1415 return Err(PyValueError::new_err(
1416 "Column 'weight' length does not match p4 columns",
1417 ));
1418 }
1419 values
1420 } else {
1421 vec![1.0; n_events]
1422 };
1423
1424 let parsed_aliases = parse_aliases(aliases)?;
1425 let mut metadata =
1426 DatasetMetadata::new(p4_names.clone(), aux_names.clone()).map_err(PyErr::from)?;
1427 if !parsed_aliases.is_empty() {
1428 metadata
1429 .add_p4_aliases(
1430 parsed_aliases
1431 .into_iter()
1432 .map(|(alias_name, selection)| (alias_name, selection.into_selection())),
1433 )
1434 .map_err(PyErr::from)?;
1435 }
1436
1437 let p4_columns = p4_columns
1438 .into_iter()
1439 .map(|components| {
1440 (0..n_events)
1441 .map(|event_idx| {
1442 laddu_core::vectors::Vec4::new(
1443 components[0][event_idx],
1444 components[1][event_idx],
1445 components[2][event_idx],
1446 components[3][event_idx],
1447 )
1448 })
1449 .collect::<Vec<_>>()
1450 })
1451 .collect::<Vec<_>>();
1452
1453 Ok(PyDataset(Arc::new(Dataset::from_columns_global(
1454 metadata,
1455 p4_columns,
1456 aux_columns,
1457 weights,
1458 )?)))
1459}
1460
1461#[pyclass(name = "BinnedDataset", module = "laddu", skip_from_py_object)]
1470pub struct PyBinnedDataset(BinnedDataset);
1471
1472#[pymethods]
1473impl PyBinnedDataset {
1474 fn __len__(&self) -> usize {
1475 self.0.n_bins()
1476 }
1477 #[getter]
1480 fn n_bins(&self) -> usize {
1481 self.0.n_bins()
1482 }
1483 #[getter]
1486 fn range(&self) -> (f64, f64) {
1487 self.0.range()
1488 }
1489 #[getter]
1492 fn edges<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<f64>> {
1493 PyArray1::from_slice(py, &self.0.edges())
1494 }
1495 fn __getitem__(&self, index: usize) -> PyResult<PyDataset> {
1496 self.0
1497 .get(index)
1498 .ok_or(PyIndexError::new_err("index out of range"))
1499 .map(|rust_dataset| PyDataset(rust_dataset.clone()))
1500 }
1501
1502 fn __repr__(&self) -> String {
1503 format!(
1504 "BinnedDataset(n_bins={}, range={:?})",
1505 self.0.n_bins(),
1506 self.0.range()
1507 )
1508 }
1509
1510 fn __str__(&self) -> String {
1511 self.__repr__()
1512 }
1513}