1use std::{path::PathBuf, sync::Arc};
2
3use laddu_core::{
4 data::{
5 io::{
6 infer_p4_and_aux_names_from_columns, resolve_columns_case_insensitive,
7 resolve_optional_weight_column, resolve_p4_component_columns, ParquetBatchWriter,
8 P4_COMPONENT_SUFFIXES,
9 },
10 read_parquet as core_read_parquet,
11 read_parquet_chunks_with_options as core_read_parquet_chunks_with_options,
12 read_root as core_read_root, write_parquet as core_write_parquet,
13 write_root as core_write_root, BinnedDataset, Dataset, DatasetArcIter, DatasetMetadata,
14 DatasetWriteOptions, EventData, FloatPrecision, OwnedEvent, SharedDatasetIterExt,
15 },
16 variables::IntoP4Selection,
17 DatasetReadOptions, Vec4,
18};
19use numpy::{PyArray1, PyReadonlyArray1};
20use pyo3::{
21 exceptions::{PyIndexError, PyKeyError, PyTypeError, PyValueError},
22 prelude::*,
23 types::{PyDict, PyList},
24 IntoPyObjectExt,
25};
26
27use crate::{
28 variables::{PyVariable, PyVariableExpression},
29 vectors::PyVec4,
30};
31
32fn parse_aliases(aliases: Option<Bound<'_, PyDict>>) -> PyResult<Vec<(String, Vec<String>)>> {
33 let Some(aliases) = aliases else {
34 return Ok(Vec::new());
35 };
36
37 let mut parsed = Vec::new();
38 for (key, value) in aliases.iter() {
39 let alias_name = key.extract::<String>()?;
40 let selection = if let Ok(single) = value.extract::<String>() {
41 vec![single]
42 } else {
43 let seq = value.extract::<Vec<String>>().map_err(|_| {
44 PyTypeError::new_err("Alias values must be a string or a sequence of strings")
45 })?;
46 if seq.is_empty() {
47 return Err(PyValueError::new_err(format!(
48 "Alias '{alias_name}' must reference at least one particle",
49 )));
50 }
51 seq
52 };
53 parsed.push((alias_name, selection));
54 }
55
56 Ok(parsed)
57}
58
59fn parse_dataset_path(path: Bound<'_, PyAny>) -> PyResult<String> {
60 if let Ok(s) = path.extract::<String>() {
61 Ok(s)
62 } else if let Ok(pathbuf) = path.extract::<PathBuf>() {
63 Ok(pathbuf.to_string_lossy().into_owned())
64 } else {
65 Err(PyTypeError::new_err("Expected str or Path"))
66 }
67}
68
69fn parse_precision_arg(value: Option<&str>) -> PyResult<FloatPrecision> {
70 match value.map(|v| v.to_ascii_lowercase()) {
71 None => Ok(FloatPrecision::F64),
72 Some(name) if name == "f64" || name == "float64" || name == "double" => {
73 Ok(FloatPrecision::F64)
74 }
75 Some(name) if name == "f32" || name == "float32" || name == "float" => {
76 Ok(FloatPrecision::F32)
77 }
78 Some(other) => Err(PyValueError::new_err(format!(
79 "Unsupported precision '{other}' (expected 'f64' or 'f32')"
80 ))),
81 }
82}
83
84fn extract_numeric_column(value: Bound<'_, PyAny>, name: &str) -> PyResult<Vec<f64>> {
85 if let Ok(array) = value.extract::<PyReadonlyArray1<'_, f64>>() {
86 return Ok(array.as_slice()?.to_vec());
87 }
88 if let Ok(array) = value.extract::<PyReadonlyArray1<'_, f32>>() {
89 return Ok(array.as_slice()?.iter().map(|v| *v as f64).collect());
90 }
91 if let Ok(values) = value.extract::<Vec<f64>>() {
92 return Ok(values);
93 }
94 if let Ok(values) = value.extract::<Vec<f32>>() {
95 return Ok(values.into_iter().map(|v| v as f64).collect());
96 }
97 if let Ok(list) = value.cast::<PyList>() {
98 let mut converted = Vec::with_capacity(list.len());
99 for item in list.iter() {
100 converted.push(item.extract::<f64>().map_err(|_| {
101 PyTypeError::new_err(format!(
102 "Column '{name}' must be numeric (float32/float64/list of floats)"
103 ))
104 })?);
105 }
106 return Ok(converted);
107 }
108 Err(PyTypeError::new_err(format!(
109 "Column '{name}' must be numeric (float32/float64/list of floats)"
110 )))
111}
112
113fn metadata_from_names_and_aliases(
114 p4_names: Vec<String>,
115 aux_names: Vec<String>,
116 aliases: Option<Bound<'_, PyDict>>,
117) -> PyResult<DatasetMetadata> {
118 let parsed_aliases = parse_aliases(aliases)?;
119 let mut metadata = DatasetMetadata::new(p4_names, aux_names).map_err(PyErr::from)?;
120 if !parsed_aliases.is_empty() {
121 metadata
122 .add_p4_aliases(
123 parsed_aliases
124 .into_iter()
125 .map(|(alias_name, selection)| (alias_name, selection.into_selection())),
126 )
127 .map_err(PyErr::from)?;
128 }
129 Ok(metadata)
130}
131
132fn parse_p4_mapping(p4: Bound<'_, PyDict>) -> PyResult<Vec<(String, Vec4)>> {
133 p4.iter()
134 .map(|(key, value)| {
135 Ok((
136 key.extract::<String>()?,
137 value
138 .extract::<PyVec4>()
139 .map_err(|_| PyTypeError::new_err("p4 values must be laddu.Vec4 instances"))?
140 .0,
141 ))
142 })
143 .collect()
144}
145
146fn parse_aux_mapping(aux: Option<Bound<'_, PyDict>>) -> PyResult<Vec<(String, f64)>> {
147 let Some(aux) = aux else {
148 return Ok(Vec::new());
149 };
150 aux.iter()
151 .map(|(key, value)| Ok((key.extract::<String>()?, value.extract::<f64>()?)))
152 .collect()
153}
154
155fn parse_p4_column(values: Vec<PyVec4>) -> Vec<Vec4> {
156 values.into_iter().map(|value| value.0).collect()
157}
158
159fn dataset_from_py_events(
160 events: Vec<PyEvent>,
161 p4_names: Option<Vec<String>>,
162 aux_names: Option<Vec<String>>,
163 aliases: Option<Bound<PyDict>>,
164 global: bool,
165) -> PyResult<PyDataset> {
166 let inferred_metadata = events
167 .iter()
168 .find_map(|event| event.has_metadata.then(|| event.event.metadata_arc()));
169
170 let aliases = parse_aliases(aliases)?;
171 let use_explicit_metadata = p4_names.is_some() || aux_names.is_some() || !aliases.is_empty();
172
173 let metadata = if use_explicit_metadata {
174 let resolved_p4_names = match (p4_names, inferred_metadata.as_ref()) {
175 (Some(names), _) => names,
176 (None, Some(metadata)) => metadata.p4_names().to_vec(),
177 (None, None) => Vec::new(),
178 };
179 let resolved_aux_names = match (aux_names, inferred_metadata.as_ref()) {
180 (Some(names), _) => names,
181 (None, Some(metadata)) => metadata.aux_names().to_vec(),
182 (None, None) => Vec::new(),
183 };
184
185 if !aliases.is_empty() && resolved_p4_names.is_empty() {
186 return Err(PyValueError::new_err(
187 "`aliases` requires `p4_names` or events with metadata for resolution",
188 ));
189 }
190
191 let mut metadata =
192 DatasetMetadata::new(resolved_p4_names, resolved_aux_names).map_err(PyErr::from)?;
193 if !aliases.is_empty() {
194 metadata
195 .add_p4_aliases(
196 aliases
197 .into_iter()
198 .map(|(alias_name, selection)| (alias_name, selection.into_selection())),
199 )
200 .map_err(PyErr::from)?;
201 }
202 Some(Arc::new(metadata))
203 } else {
204 inferred_metadata
205 };
206
207 let events: Vec<Arc<EventData>> = events
208 .into_iter()
209 .map(|event| event.event.data_arc())
210 .collect();
211 let dataset = match (metadata, global) {
212 (Some(metadata), true) => Dataset::new_with_metadata(events, metadata),
213 (Some(metadata), false) => Dataset::new_local(events, metadata),
214 (None, true) => Dataset::new(events),
215 (None, false) => Dataset::new_local(events, Arc::new(DatasetMetadata::default())),
216 };
217 Ok(PyDataset(Arc::new(dataset)))
218}
219
220#[pyclass(name = "Event", module = "laddu", from_py_object)]
258#[derive(Clone)]
259pub struct PyEvent {
260 pub event: OwnedEvent,
261 has_metadata: bool,
262}
263
264#[pymethods]
265impl PyEvent {
266 #[new]
267 #[pyo3(signature = (p4s, aux, weight, *, p4_names=None, aux_names=None, aliases=None))]
268 fn new(
269 p4s: Vec<PyVec4>,
270 aux: Vec<f64>,
271 weight: f64,
272 p4_names: Option<Vec<String>>,
273 aux_names: Option<Vec<String>>,
274 aliases: Option<Bound<PyDict>>,
275 ) -> PyResult<Self> {
276 let event = EventData {
277 p4s: p4s.into_iter().map(|arr| arr.0).collect(),
278 aux,
279 weight,
280 };
281 let aliases = parse_aliases(aliases)?;
282
283 let missing_p4_names = p4_names
284 .as_ref()
285 .map(|names| names.is_empty())
286 .unwrap_or(true);
287
288 if !aliases.is_empty() && missing_p4_names {
289 return Err(PyValueError::new_err(
290 "`aliases` requires `p4_names` so selections can be resolved",
291 ));
292 }
293
294 let metadata_provided = p4_names.is_some() || aux_names.is_some() || !aliases.is_empty();
295 let metadata = if metadata_provided {
296 let p4_names = p4_names.unwrap_or_default();
297 let aux_names = aux_names.unwrap_or_default();
298 let mut metadata = DatasetMetadata::new(p4_names, aux_names).map_err(PyErr::from)?;
299 if !aliases.is_empty() {
300 metadata
301 .add_p4_aliases(
302 aliases.into_iter().map(|(alias_name, selection)| {
303 (alias_name, selection.into_selection())
304 }),
305 )
306 .map_err(PyErr::from)?;
307 }
308 Arc::new(metadata)
309 } else {
310 Arc::new(DatasetMetadata::empty())
311 };
312 let event = OwnedEvent::new(Arc::new(event), metadata);
313 Ok(Self {
314 event,
315 has_metadata: metadata_provided,
316 })
317 }
318 fn __str__(&self) -> String {
319 self.event.to_string()
320 }
321 fn __repr__(&self) -> String {
322 self.__str__()
323 }
324 #[getter]
327 fn p4s<'py>(&self, py: Python<'py>) -> PyResult<Py<PyDict>> {
328 self.ensure_metadata()?;
329 let mapping = PyDict::new(py);
330 for (name, vec4) in self.event.p4s() {
331 mapping.set_item(name, PyVec4(vec4))?;
332 }
333 Ok(mapping.into())
334 }
335 #[getter]
338 #[pyo3(name = "aux")]
339 fn aux_mapping<'py>(&self, py: Python<'py>) -> PyResult<Py<PyDict>> {
340 self.ensure_metadata()?;
341 let mapping = PyDict::new(py);
342 for (name, value) in self.event.aux() {
343 mapping.set_item(name, value)?;
344 }
345 Ok(mapping.into())
346 }
347 #[getter]
350 fn get_weight(&self) -> f64 {
351 self.event.weight()
352 }
353 fn get_p4_sum(&self, names: Vec<String>) -> PyResult<PyVec4> {
366 let indices = self.resolve_p4_indices(&names)?;
367 Ok(PyVec4(self.event.data().get_p4_sum(indices)))
368 }
369 pub fn boost_to_rest_frame_of(&self, names: Vec<String>) -> PyResult<Self> {
383 let indices = self.resolve_p4_indices(&names)?;
384 let boosted = self.event.data().boost_to_rest_frame_of(indices);
385 Ok(Self {
386 event: OwnedEvent::new(Arc::new(boosted), self.event.metadata_arc()),
387 has_metadata: self.has_metadata,
388 })
389 }
390 fn evaluate(&self, variable: Bound<'_, PyAny>) -> PyResult<f64> {
420 let mut variable = variable.extract::<PyVariable>()?;
421 let metadata = self.ensure_metadata()?;
422 variable.bind_in_place(metadata)?;
423 variable.evaluate_event(&self.event)
424 }
425
426 fn p4(&self, name: &str) -> PyResult<PyVec4> {
428 self.ensure_metadata()?;
429 self.event
430 .p4(name)
431 .map(PyVec4)
432 .ok_or_else(|| PyKeyError::new_err(format!("Unknown particle name '{name}'")))
433 }
434}
435
436impl PyEvent {
437 fn ensure_metadata(&self) -> PyResult<&DatasetMetadata> {
438 if !self.has_metadata {
439 Err(PyValueError::new_err(
440 "Event has no associated metadata for name-based operations",
441 ))
442 } else {
443 Ok(self.event.metadata())
444 }
445 }
446
447 fn resolve_p4_indices(&self, names: &[String]) -> PyResult<Vec<usize>> {
448 let metadata = self.ensure_metadata()?;
449 let mut resolved = Vec::new();
450 for name in names {
451 let selection = metadata
452 .p4_selection(name)
453 .ok_or_else(|| PyKeyError::new_err(format!("Unknown particle name '{name}'")))?;
454 resolved.extend_from_slice(selection.indices());
455 }
456 Ok(resolved)
457 }
458
459 pub(crate) fn metadata_opt(&self) -> Option<&DatasetMetadata> {
460 self.has_metadata.then(|| self.event.metadata())
461 }
462}
463
464#[doc(hidden)]
465#[pyclass(name = "Dataset", module = "laddu", subclass, skip_from_py_object)]
506#[derive(Clone)]
507pub struct PyDataset(pub Arc<Dataset>);
508
509#[pyclass(
510 name = "ParquetChunkIter",
511 module = "laddu",
512 unsendable,
513 skip_from_py_object
514)]
515pub struct PyParquetChunkIter {
516 chunks: Box<dyn Iterator<Item = laddu_core::LadduResult<Arc<Dataset>>> + Send>,
517}
518
519#[pymethods]
520impl PyParquetChunkIter {
521 fn __iter__(slf: PyRef<'_, Self>) -> Py<PyParquetChunkIter> {
522 slf.into()
523 }
524
525 fn __next__(&mut self) -> PyResult<Option<PyDataset>> {
526 match self.chunks.next() {
527 Some(Ok(dataset)) => Ok(Some(PyDataset(dataset))),
528 Some(Err(err)) => Err(PyErr::from(err)),
529 None => Ok(None),
530 }
531 }
532}
533
534#[pyclass(name = "ParquetBatchWriter", module = "laddu", unsendable)]
536pub struct PyParquetBatchWriter {
537 writer: Option<ParquetBatchWriter>,
538}
539
540#[pymethods]
541impl PyParquetBatchWriter {
542 fn write(&mut self, dataset: &PyDataset) -> PyResult<()> {
544 self.writer_mut()?
545 .write(dataset.0.as_ref())
546 .map_err(PyErr::from)
547 }
548
549 fn close(&mut self) -> PyResult<()> {
551 if let Some(writer) = &mut self.writer {
552 writer.close()?;
553 }
554 self.writer = None;
555 Ok(())
556 }
557
558 fn __enter__(slf: Py<Self>) -> Py<Self> {
559 slf
560 }
561
562 fn __exit__(
563 &mut self,
564 _exc_type: Bound<'_, PyAny>,
565 _exc_value: Bound<'_, PyAny>,
566 _traceback: Bound<'_, PyAny>,
567 ) -> PyResult<bool> {
568 self.close()?;
569 Ok(false)
570 }
571}
572
573impl PyParquetBatchWriter {
574 fn new(writer: ParquetBatchWriter) -> Self {
575 Self {
576 writer: Some(writer),
577 }
578 }
579
580 fn writer_mut(&mut self) -> PyResult<&mut ParquetBatchWriter> {
581 self.writer
582 .as_mut()
583 .ok_or_else(|| PyValueError::new_err("ParquetBatchWriter is closed"))
584 }
585}
586
587#[pyclass(
588 name = "DatasetEventsGlobal",
589 module = "laddu",
590 unsendable,
591 skip_from_py_object
592)]
593pub struct PyDatasetEventsGlobalIter {
594 iter: DatasetArcIter,
595}
596
597#[pymethods]
598impl PyDatasetEventsGlobalIter {
599 fn __iter__(slf: PyRef<'_, Self>) -> Py<PyDatasetEventsGlobalIter> {
600 slf.into()
601 }
602
603 fn __next__(&mut self) -> Option<PyEvent> {
604 self.iter.next().map(|rust_event| PyEvent {
605 event: rust_event,
606 has_metadata: true,
607 })
608 }
609}
610
611#[pyclass(
612 name = "DatasetEventsLocal",
613 module = "laddu",
614 unsendable,
615 skip_from_py_object
616)]
617pub struct PyDatasetEventsLocalIter {
618 dataset: Arc<Dataset>,
619 index: usize,
620}
621
622#[pymethods]
623impl PyDatasetEventsLocalIter {
624 fn __iter__(slf: PyRef<'_, Self>) -> Py<PyDatasetEventsLocalIter> {
625 slf.into()
626 }
627
628 fn __next__(&mut self) -> Option<PyEvent> {
629 if self.index >= self.dataset.n_events_local() {
630 return None;
631 }
632 let event = self
633 .dataset
634 .event_local(self.index)
635 .expect("local event index should exist")
636 .to_event_data();
637 self.index += 1;
638 Some(PyEvent {
639 event: OwnedEvent::new(Arc::new(event), self.dataset.metadata_arc()),
640 has_metadata: true,
641 })
642 }
643}
644
645#[pymethods]
646impl PyDataset {
647 #[new]
648 #[pyo3(signature = (events, *, p4_names=None, aux_names=None, aliases=None))]
649 fn new(
650 events: Vec<PyEvent>,
651 p4_names: Option<Vec<String>>,
652 aux_names: Option<Vec<String>>,
653 aliases: Option<Bound<PyDict>>,
654 ) -> PyResult<Self> {
655 dataset_from_py_events(events, p4_names, aux_names, aliases, true)
656 }
657
658 #[staticmethod]
660 #[pyo3(signature = (events, *, p4_names=None, aux_names=None, aliases=None))]
661 fn from_events_local(
662 events: Vec<PyEvent>,
663 p4_names: Option<Vec<String>>,
664 aux_names: Option<Vec<String>>,
665 aliases: Option<Bound<PyDict>>,
666 ) -> PyResult<Self> {
667 dataset_from_py_events(events, p4_names, aux_names, aliases, false)
668 }
669
670 #[staticmethod]
675 #[pyo3(signature = (events, *, p4_names=None, aux_names=None, aliases=None))]
676 fn from_events_global(
677 events: Vec<PyEvent>,
678 p4_names: Option<Vec<String>>,
679 aux_names: Option<Vec<String>>,
680 aliases: Option<Bound<PyDict>>,
681 ) -> PyResult<Self> {
682 dataset_from_py_events(events, p4_names, aux_names, aliases, true)
683 }
684
685 #[staticmethod]
696 #[pyo3(signature = (*, p4_names, aux_names=None, aliases=None))]
697 fn empty_local(
698 p4_names: Vec<String>,
699 aux_names: Option<Vec<String>>,
700 aliases: Option<Bound<'_, PyDict>>,
701 ) -> PyResult<Self> {
702 let metadata =
703 metadata_from_names_and_aliases(p4_names, aux_names.unwrap_or_default(), aliases)?;
704 Ok(Self(Arc::new(Dataset::empty_local(metadata))))
705 }
706
707 #[pyo3(signature = (*, p4, aux=None, weight=1.0))]
718 fn push_event_local(
719 &mut self,
720 p4: Bound<'_, PyDict>,
721 aux: Option<Bound<'_, PyDict>>,
722 weight: f64,
723 ) -> PyResult<()> {
724 let p4 = parse_p4_mapping(p4)?;
725 let aux = parse_aux_mapping(aux)?;
726 Arc::make_mut(&mut self.0)
727 .push_event_named_local(p4, aux, weight)
728 .map_err(PyErr::from)
729 }
730
731 #[pyo3(signature = (*, p4, aux=None, weight=1.0))]
736 fn push_event_global(
737 &mut self,
738 p4: Bound<'_, PyDict>,
739 aux: Option<Bound<'_, PyDict>>,
740 weight: f64,
741 ) -> PyResult<()> {
742 let p4 = parse_p4_mapping(p4)?;
743 let aux = parse_aux_mapping(aux)?;
744 Arc::make_mut(&mut self.0)
745 .push_event_named_global(p4, aux, weight)
746 .map_err(PyErr::from)
747 }
748
749 #[pyo3(signature = (name, values))]
751 fn add_p4_column_local(&mut self, name: String, values: Vec<PyVec4>) -> PyResult<()> {
752 Arc::make_mut(&mut self.0)
753 .add_p4_column_local(name, parse_p4_column(values))
754 .map_err(PyErr::from)
755 }
756
757 #[pyo3(signature = (name, values))]
759 fn add_aux_column_local(&mut self, name: String, values: Bound<'_, PyAny>) -> PyResult<()> {
760 let values = extract_numeric_column(values, &name)?;
761 Arc::make_mut(&mut self.0)
762 .add_aux_column_local(name, values)
763 .map_err(PyErr::from)
764 }
765
766 #[pyo3(signature = (name, values))]
771 fn add_p4_column_global(&mut self, name: String, values: Vec<PyVec4>) -> PyResult<()> {
772 Arc::make_mut(&mut self.0)
773 .add_p4_column_global(name, parse_p4_column(values))
774 .map_err(PyErr::from)
775 }
776
777 #[pyo3(signature = (name, values))]
782 fn add_aux_column_global(&mut self, name: String, values: Bound<'_, PyAny>) -> PyResult<()> {
783 let values = extract_numeric_column(values, &name)?;
784 Arc::make_mut(&mut self.0)
785 .add_aux_column_global(name, values)
786 .map_err(PyErr::from)
787 }
788
789 fn __len__(&self) -> usize {
790 self.0.n_events()
791 }
792 fn __iter__(&self) -> PyResult<()> {
793 Err(PyTypeError::new_err(
794 "Dataset iteration is explicit; use dataset.events_local or dataset.events_global",
795 ))
796 }
797 #[getter]
799 fn n_events_local(&self) -> usize {
800 self.0.n_events_local()
801 }
802 fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyDataset> {
803 if let Ok(other_ds) = other.extract::<PyRef<PyDataset>>() {
804 Ok(PyDataset(Arc::new(self.0.as_ref() + other_ds.0.as_ref())))
805 } else if let Ok(other_int) = other.extract::<usize>() {
806 if other_int == 0 {
807 Ok(self.clone())
808 } else {
809 Err(PyTypeError::new_err(
810 "Addition with an integer for this type is only defined for 0",
811 ))
812 }
813 } else {
814 Err(PyTypeError::new_err("Unsupported operand type for +"))
815 }
816 }
817 fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyDataset> {
818 if let Ok(other_ds) = other.extract::<PyRef<PyDataset>>() {
819 Ok(PyDataset(Arc::new(other_ds.0.as_ref() + self.0.as_ref())))
820 } else if let Ok(other_int) = other.extract::<usize>() {
821 if other_int == 0 {
822 Ok(self.clone())
823 } else {
824 Err(PyTypeError::new_err(
825 "Addition with an integer for this type is only defined for 0",
826 ))
827 }
828 } else {
829 Err(PyTypeError::new_err("Unsupported operand type for +"))
830 }
831 }
832
833 fn __repr__(&self) -> String {
834 format!(
835 "Dataset(n_events={}, n_events_local={}, p4_names={:?}, aux_names={:?})",
836 self.0.n_events_global(),
837 self.0.n_events_local(),
838 self.0.p4_names(),
839 self.0.aux_names()
840 )
841 }
842
843 fn __str__(&self) -> String {
844 self.__repr__()
845 }
846
847 #[getter]
860 fn n_events(&self) -> usize {
861 self.0.n_events()
862 }
863 #[getter]
865 fn n_events_global(&self) -> usize {
866 self.0.n_events_global()
867 }
868 #[getter]
870 fn p4_names(&self) -> Vec<String> {
871 self.0.p4_names().to_vec()
872 }
873 #[getter]
875 fn aux_names(&self) -> Vec<String> {
876 self.0.aux_names().to_vec()
877 }
878
879 #[getter]
891 fn n_events_weighted(&self) -> f64 {
892 self.0.n_events_weighted()
893 }
894 #[getter]
896 fn n_events_weighted_global(&self) -> f64 {
897 self.0.n_events_weighted_global()
898 }
899 #[getter]
912 fn n_events_weighted_local(&self) -> f64 {
913 self.0.n_events_weighted_local()
914 }
915 #[getter]
927 fn weights<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<f64>> {
928 PyArray1::from_slice(py, &self.0.weights())
929 }
930 #[getter]
932 fn weights_global<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<f64>> {
933 PyArray1::from_slice(py, &self.0.weights_global())
934 }
935 #[getter]
947 fn weights_local<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<f64>> {
948 PyArray1::from_slice(py, &self.0.weights_local())
949 }
950 #[getter]
952 fn events_global(&self) -> PyDatasetEventsGlobalIter {
953 PyDatasetEventsGlobalIter {
954 iter: self.0.shared_iter_global(),
955 }
956 }
957 #[getter]
963 fn events_local(&self) -> PyDatasetEventsLocalIter {
964 PyDatasetEventsLocalIter {
965 dataset: self.0.clone(),
966 index: 0,
967 }
968 }
969 fn p4_by_name(&self, index: usize, name: &str) -> PyResult<PyVec4> {
971 self.0
972 .p4_by_name(index, name)
973 .map(PyVec4)
974 .ok_or_else(|| PyKeyError::new_err(format!("Unknown particle name '{name}'")))
975 }
976 fn aux_by_name(&self, index: usize, name: &str) -> PyResult<f64> {
978 self.0
979 .aux_by_name(index, name)
980 .ok_or_else(|| PyKeyError::new_err(format!("Unknown auxiliary name '{name}'")))
981 }
982 fn event_global(&self, index: usize) -> PyResult<PyEvent> {
988 let event = self
989 .0
990 .event_global(index)
991 .map_err(|_| PyIndexError::new_err("index out of range"))?;
992 Ok(PyEvent {
993 event,
994 has_metadata: true,
995 })
996 }
997 fn __getitem__<'py>(
998 &self,
999 py: Python<'py>,
1000 index: Bound<'py, PyAny>,
1001 ) -> PyResult<Bound<'py, PyAny>> {
1002 if let Ok(value) = self.evaluate(py, index.clone()) {
1003 value.into_bound_py_any(py)
1004 } else if let Ok(index) = index.extract::<usize>() {
1005 let event = self
1006 .0
1007 .event_global(index)
1008 .map_err(|_| PyIndexError::new_err("index out of range"))?;
1009 PyEvent {
1010 event,
1011 has_metadata: true,
1012 }
1013 .into_bound_py_any(py)
1014 } else {
1015 Err(PyTypeError::new_err(
1016 "Unsupported index type (int or Variable)",
1017 ))
1018 }
1019 }
1020 #[pyo3(signature = (variable, bins, range))]
1058 fn bin_by(
1059 &self,
1060 variable: Bound<'_, PyAny>,
1061 bins: usize,
1062 range: (f64, f64),
1063 ) -> PyResult<PyBinnedDataset> {
1064 let py_variable = variable.extract::<PyVariable>()?;
1065 let bound_variable = py_variable.bound(self.0.metadata())?;
1066 Ok(PyBinnedDataset(self.0.bin_by(
1067 bound_variable,
1068 bins,
1069 range,
1070 )?))
1071 }
1072 pub fn filter(&self, expression: &PyVariableExpression) -> PyResult<PyDataset> {
1090 Ok(PyDataset(
1091 self.0.filter(&expression.0).map_err(PyErr::from)?,
1092 ))
1093 }
1094 fn bootstrap(&self, seed: usize) -> PyDataset {
1115 PyDataset(self.0.bootstrap(seed))
1116 }
1117 pub fn boost_to_rest_frame_of(&self, names: Vec<String>) -> PyDataset {
1135 PyDataset(self.0.boost_to_rest_frame_of(&names))
1136 }
1137 fn evaluate<'py>(
1155 &self,
1156 py: Python<'py>,
1157 variable: Bound<'py, PyAny>,
1158 ) -> PyResult<Bound<'py, PyArray1<f64>>> {
1159 let variable = variable.extract::<PyVariable>()?;
1160 let bound_variable = variable.bound(self.0.metadata())?;
1161 let values = self.0.evaluate(&bound_variable).map_err(PyErr::from)?;
1162 Ok(PyArray1::from_vec(py, values))
1163 }
1164}
1165
1166#[pyfunction]
1180#[pyo3(signature = (path, *, p4s=None, aux=None, aliases=None))]
1181pub fn read_parquet(
1182 path: Bound<PyAny>,
1183 p4s: Option<Vec<String>>,
1184 aux: Option<Vec<String>>,
1185 aliases: Option<Bound<PyDict>>,
1186) -> PyResult<PyDataset> {
1187 let path_str = parse_dataset_path(path)?;
1188 let mut read_options = DatasetReadOptions::default();
1189 if let Some(p4s) = p4s {
1190 read_options = read_options.p4_names(p4s);
1191 }
1192 if let Some(aux) = aux {
1193 read_options = read_options.aux_names(aux);
1194 }
1195 for (alias_name, selection) in parse_aliases(aliases)?.into_iter() {
1196 read_options = read_options.alias(alias_name, selection);
1197 }
1198 let dataset = core_read_parquet(&path_str, &read_options)?;
1199 Ok(PyDataset(dataset))
1200}
1201
1202#[pyfunction]
1204#[pyo3(signature = (path, *, p4s=None, aux=None, aliases=None, chunk_size=None))]
1205pub fn read_parquet_chunked(
1206 path: Bound<PyAny>,
1207 p4s: Option<Vec<String>>,
1208 aux: Option<Vec<String>>,
1209 aliases: Option<Bound<PyDict>>,
1210 chunk_size: Option<usize>,
1211) -> PyResult<PyParquetChunkIter> {
1212 let path_str = parse_dataset_path(path)?;
1213 let mut read_options = DatasetReadOptions::default();
1214 if let Some(p4s) = p4s {
1215 read_options = read_options.p4_names(p4s);
1216 }
1217 if let Some(aux) = aux {
1218 read_options = read_options.aux_names(aux);
1219 }
1220 if let Some(chunk_size) = chunk_size {
1221 read_options = read_options.chunk_size(chunk_size);
1222 }
1223 for (alias_name, selection) in parse_aliases(aliases)?.into_iter() {
1224 read_options = read_options.alias(alias_name, selection);
1225 }
1226
1227 let chunks = core_read_parquet_chunks_with_options(&path_str, &read_options)?;
1228 Ok(PyParquetChunkIter {
1229 chunks: Box::new(chunks),
1230 })
1231}
1232
1233#[pyfunction]
1247#[pyo3(signature = (path, *, tree=None, p4s=None, aux=None, aliases=None))]
1248pub fn read_root(
1249 path: Bound<PyAny>,
1250 tree: Option<String>,
1251 p4s: Option<Vec<String>>,
1252 aux: Option<Vec<String>>,
1253 aliases: Option<Bound<PyDict>>,
1254) -> PyResult<PyDataset> {
1255 let path_str = parse_dataset_path(path)?;
1256 let mut read_options = DatasetReadOptions::default();
1257 if let Some(p4s) = p4s {
1258 read_options = read_options.p4_names(p4s);
1259 }
1260 if let Some(aux) = aux {
1261 read_options = read_options.aux_names(aux);
1262 }
1263 if let Some(tree) = tree {
1264 read_options = read_options.tree(tree);
1265 }
1266 for (alias_name, selection) in parse_aliases(aliases)?.into_iter() {
1267 read_options = read_options.alias(alias_name, selection);
1268 }
1269 let dataset = core_read_root(&path_str, &read_options)?;
1270 Ok(PyDataset(dataset))
1271}
1272
1273#[pyfunction]
1275#[pyo3(signature = (dataset, path, *, chunk_size=None, precision="f64"))]
1276pub fn write_parquet(
1277 dataset: &PyDataset,
1278 path: Bound<PyAny>,
1279 chunk_size: Option<usize>,
1280 precision: &str,
1281) -> PyResult<()> {
1282 let path_str = parse_dataset_path(path)?;
1283 let mut write_options = DatasetWriteOptions::default();
1284 if let Some(size) = chunk_size {
1285 write_options.batch_size = size.max(1);
1286 }
1287 write_options.precision = parse_precision_arg(Some(precision))?;
1288 core_write_parquet(dataset.0.as_ref(), &path_str, &write_options).map_err(PyErr::from)
1289}
1290
1291#[pyfunction]
1293#[pyo3(signature = (path, *, chunk_size=None, precision="f64"))]
1294pub fn open_parquet_writer(
1295 path: Bound<PyAny>,
1296 chunk_size: Option<usize>,
1297 precision: &str,
1298) -> PyResult<PyParquetBatchWriter> {
1299 let path_str = parse_dataset_path(path)?;
1300 let mut write_options = DatasetWriteOptions::default();
1301 if let Some(size) = chunk_size {
1302 write_options.batch_size = size.max(1);
1303 }
1304 write_options.precision = parse_precision_arg(Some(precision))?;
1305 Ok(PyParquetBatchWriter::new(ParquetBatchWriter::new(
1306 &path_str,
1307 write_options,
1308 )?))
1309}
1310
1311#[pyfunction]
1313#[pyo3(signature = (dataset, path, *, tree=None, chunk_size=None, precision="f64"))]
1314pub fn write_root(
1315 dataset: &PyDataset,
1316 path: Bound<PyAny>,
1317 tree: Option<String>,
1318 chunk_size: Option<usize>,
1319 precision: &str,
1320) -> PyResult<()> {
1321 let path_str = parse_dataset_path(path)?;
1322 let mut write_options = DatasetWriteOptions::default();
1323 if let Some(name) = tree {
1324 write_options.tree = Some(name);
1325 }
1326 if let Some(size) = chunk_size {
1327 write_options.batch_size = size.max(1);
1328 }
1329 write_options.precision = parse_precision_arg(Some(precision))?;
1330 core_write_root(dataset.0.as_ref(), &path_str, &write_options).map_err(PyErr::from)
1331}
1332
1333#[doc(hidden)]
1334#[pyfunction]
1362#[pyo3(signature = (columns, *, p4s=None, aux=None, aliases=None))]
1363pub fn from_columns(
1364 columns: Bound<'_, PyDict>,
1365 p4s: Option<Vec<String>>,
1366 aux: Option<Vec<String>>,
1367 aliases: Option<Bound<'_, PyDict>>,
1368) -> PyResult<PyDataset> {
1369 let column_names = columns
1370 .iter()
1371 .map(|(key, _)| key.extract::<String>())
1372 .collect::<PyResult<Vec<_>>>()?;
1373
1374 let (detected_p4_names, detected_aux_names) =
1375 infer_p4_and_aux_names_from_columns(&column_names);
1376 let p4_names = p4s.unwrap_or(detected_p4_names);
1377 if p4_names.is_empty() {
1378 let mut partial_components: std::collections::BTreeMap<
1379 String,
1380 std::collections::BTreeSet<&str>,
1381 > = std::collections::BTreeMap::new();
1382 for column_name in &column_names {
1383 let lowered = column_name.to_ascii_lowercase();
1384 for suffix in P4_COMPONENT_SUFFIXES {
1385 if lowered.ends_with(suffix) && column_name.len() > suffix.len() {
1386 let prefix = column_name[..column_name.len() - suffix.len()].to_string();
1387 partial_components.entry(prefix).or_default().insert(suffix);
1388 }
1389 }
1390 }
1391 if let Some((prefix, present)) = partial_components.iter().next() {
1392 if present.len() < P4_COMPONENT_SUFFIXES.len() {
1393 let missing = P4_COMPONENT_SUFFIXES
1394 .iter()
1395 .filter(|suffix| !present.contains(**suffix))
1396 .map(|suffix| format!("{prefix}{suffix}"))
1397 .collect::<Vec<_>>()
1398 .join(", ");
1399 return Err(PyKeyError::new_err(format!(
1400 "Missing components [{missing}] for four-momentum '{prefix}'"
1401 )));
1402 }
1403 }
1404 return Err(PyValueError::new_err(
1405 "No four-momentum columns found (expected *_px, *_py, *_pz, *_e)",
1406 ));
1407 }
1408
1409 let aux_names = aux.unwrap_or(detected_aux_names);
1410 let p4_component_columns =
1411 resolve_p4_component_columns(&column_names, &p4_names).map_err(PyErr::from)?;
1412 let resolved_aux_columns =
1413 resolve_columns_case_insensitive(&column_names, &aux_names).map_err(PyErr::from)?;
1414
1415 let n_events = {
1416 let first_name = p4_component_columns
1417 .first()
1418 .map(|components| components[0].clone())
1419 .ok_or_else(|| PyKeyError::new_err("Missing required p4 column"))?;
1420 let values = extract_numeric_column(
1421 columns
1422 .get_item(first_name.as_str())?
1423 .ok_or_else(|| PyKeyError::new_err("Missing required p4 column"))?,
1424 &first_name,
1425 )?;
1426 values.len()
1427 };
1428
1429 let mut p4_columns: Vec<[Vec<f64>; 4]> = Vec::with_capacity(p4_names.len());
1430 for component_names in &p4_component_columns {
1431 let px = extract_numeric_column(
1432 columns
1433 .get_item(component_names[0].as_str())?
1434 .ok_or_else(|| PyKeyError::new_err(format!("Missing {}", component_names[0])))?,
1435 component_names[0].as_str(),
1436 )?;
1437 let py = extract_numeric_column(
1438 columns
1439 .get_item(component_names[1].as_str())?
1440 .ok_or_else(|| PyKeyError::new_err(format!("Missing {}", component_names[1])))?,
1441 component_names[1].as_str(),
1442 )?;
1443 let pz = extract_numeric_column(
1444 columns
1445 .get_item(component_names[2].as_str())?
1446 .ok_or_else(|| PyKeyError::new_err(format!("Missing {}", component_names[2])))?,
1447 component_names[2].as_str(),
1448 )?;
1449 let e = extract_numeric_column(
1450 columns
1451 .get_item(component_names[3].as_str())?
1452 .ok_or_else(|| PyKeyError::new_err(format!("Missing {}", component_names[3])))?,
1453 component_names[3].as_str(),
1454 )?;
1455 if px.len() != n_events
1456 || py.len() != n_events
1457 || pz.len() != n_events
1458 || e.len() != n_events
1459 {
1460 return Err(PyValueError::new_err(
1461 "All p4 components must have the same length",
1462 ));
1463 }
1464 p4_columns.push([px, py, pz, e]);
1465 }
1466
1467 let mut aux_columns: Vec<Vec<f64>> = Vec::with_capacity(resolved_aux_columns.len());
1468 for (aux_name, aux_column_name) in aux_names.iter().zip(&resolved_aux_columns) {
1469 let values = extract_numeric_column(
1470 columns.get_item(aux_column_name.as_str())?.ok_or_else(|| {
1471 PyKeyError::new_err(format!("Missing auxiliary column '{aux_name}'"))
1472 })?,
1473 aux_name,
1474 )?;
1475 if values.len() != n_events {
1476 return Err(PyValueError::new_err(format!(
1477 "Auxiliary column '{aux_name}' length does not match p4 columns"
1478 )));
1479 }
1480 aux_columns.push(values);
1481 }
1482
1483 let weights = if let Some(weight_column_name) = resolve_optional_weight_column(&column_names) {
1484 let weight_values = columns
1485 .get_item(weight_column_name.as_str())?
1486 .ok_or_else(|| PyKeyError::new_err("Missing weight column"))?;
1487 let values = extract_numeric_column(weight_values, "weight")?;
1488 if values.len() != n_events {
1489 return Err(PyValueError::new_err(
1490 "Column 'weight' length does not match p4 columns",
1491 ));
1492 }
1493 values
1494 } else {
1495 vec![1.0; n_events]
1496 };
1497
1498 let parsed_aliases = parse_aliases(aliases)?;
1499 let mut metadata =
1500 DatasetMetadata::new(p4_names.clone(), aux_names.clone()).map_err(PyErr::from)?;
1501 if !parsed_aliases.is_empty() {
1502 metadata
1503 .add_p4_aliases(
1504 parsed_aliases
1505 .into_iter()
1506 .map(|(alias_name, selection)| (alias_name, selection.into_selection())),
1507 )
1508 .map_err(PyErr::from)?;
1509 }
1510
1511 let p4_columns = p4_columns
1512 .into_iter()
1513 .map(|components| {
1514 (0..n_events)
1515 .map(|event_idx| {
1516 laddu_core::vectors::Vec4::new(
1517 components[0][event_idx],
1518 components[1][event_idx],
1519 components[2][event_idx],
1520 components[3][event_idx],
1521 )
1522 })
1523 .collect::<Vec<_>>()
1524 })
1525 .collect::<Vec<_>>();
1526
1527 Ok(PyDataset(Arc::new(Dataset::from_columns_global(
1528 metadata,
1529 p4_columns,
1530 aux_columns,
1531 weights,
1532 )?)))
1533}
1534
1535#[pyclass(name = "BinnedDataset", module = "laddu", skip_from_py_object)]
1544pub struct PyBinnedDataset(BinnedDataset);
1545
1546#[pymethods]
1547impl PyBinnedDataset {
1548 fn __len__(&self) -> usize {
1549 self.0.n_bins()
1550 }
1551 #[getter]
1554 fn n_bins(&self) -> usize {
1555 self.0.n_bins()
1556 }
1557 #[getter]
1560 fn range(&self) -> (f64, f64) {
1561 self.0.range()
1562 }
1563 #[getter]
1566 fn edges<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<f64>> {
1567 PyArray1::from_slice(py, &self.0.edges())
1568 }
1569 fn __getitem__(&self, index: usize) -> PyResult<PyDataset> {
1570 self.0
1571 .get(index)
1572 .ok_or(PyIndexError::new_err("index out of range"))
1573 .map(|rust_dataset| PyDataset(rust_dataset.clone()))
1574 }
1575
1576 fn __repr__(&self) -> String {
1577 format!(
1578 "BinnedDataset(n_bins={}, range={:?})",
1579 self.0.n_bins(),
1580 self.0.range()
1581 )
1582 }
1583
1584 fn __str__(&self) -> String {
1585 self.__repr__()
1586 }
1587}