1use crate::utils::variables::{PyVariable, PyVariableExpression};
2use laddu_core::{
3 data::{
4 io::{
5 infer_p4_and_aux_names_from_columns, resolve_columns_case_insensitive,
6 resolve_optional_weight_column, resolve_p4_component_columns, P4_COMPONENT_SUFFIXES,
7 },
8 read_parquet as core_read_parquet,
9 read_parquet_chunks_with_options as core_read_parquet_chunks_with_options,
10 read_root as core_read_root, write_parquet as core_write_parquet,
11 write_root as core_write_root, BinnedDataset, Dataset, DatasetArcIter, DatasetMetadata,
12 DatasetWriteOptions, Event, EventData, FloatPrecision, SharedDatasetIterExt,
13 },
14 utils::variables::IntoP4Selection,
15 DatasetReadOptions,
16};
17use numpy::{PyArray1, PyReadonlyArray1};
18use pyo3::{
19 exceptions::{PyIndexError, PyKeyError, PyTypeError, PyValueError},
20 prelude::*,
21 types::{PyDict, PyList},
22 IntoPyObjectExt,
23};
24use std::{path::PathBuf, sync::Arc};
25
26use crate::utils::vectors::PyVec4;
27
28fn parse_aliases(aliases: Option<Bound<'_, PyDict>>) -> PyResult<Vec<(String, Vec<String>)>> {
29 let Some(aliases) = aliases else {
30 return Ok(Vec::new());
31 };
32
33 let mut parsed = Vec::new();
34 for (key, value) in aliases.iter() {
35 let alias_name = key.extract::<String>()?;
36 let selection = if let Ok(single) = value.extract::<String>() {
37 vec![single]
38 } else {
39 let seq = value.extract::<Vec<String>>().map_err(|_| {
40 PyTypeError::new_err("Alias values must be a string or a sequence of strings")
41 })?;
42 if seq.is_empty() {
43 return Err(PyValueError::new_err(format!(
44 "Alias '{alias_name}' must reference at least one particle",
45 )));
46 }
47 seq
48 };
49 parsed.push((alias_name, selection));
50 }
51
52 Ok(parsed)
53}
54
55fn parse_dataset_path(path: Bound<'_, PyAny>) -> PyResult<String> {
56 if let Ok(s) = path.extract::<String>() {
57 Ok(s)
58 } else if let Ok(pathbuf) = path.extract::<PathBuf>() {
59 Ok(pathbuf.to_string_lossy().into_owned())
60 } else {
61 Err(PyTypeError::new_err("Expected str or Path"))
62 }
63}
64
65fn parse_precision_arg(value: Option<&str>) -> PyResult<FloatPrecision> {
66 match value.map(|v| v.to_ascii_lowercase()) {
67 None => Ok(FloatPrecision::F64),
68 Some(name) if name == "f64" || name == "float64" || name == "double" => {
69 Ok(FloatPrecision::F64)
70 }
71 Some(name) if name == "f32" || name == "float32" || name == "float" => {
72 Ok(FloatPrecision::F32)
73 }
74 Some(other) => Err(PyValueError::new_err(format!(
75 "Unsupported precision '{other}' (expected 'f64' or 'f32')"
76 ))),
77 }
78}
79
80fn extract_numeric_column(value: Bound<'_, PyAny>, name: &str) -> PyResult<Vec<f64>> {
81 if let Ok(array) = value.extract::<PyReadonlyArray1<'_, f64>>() {
82 return Ok(array.as_slice()?.to_vec());
83 }
84 if let Ok(array) = value.extract::<PyReadonlyArray1<'_, f32>>() {
85 return Ok(array.as_slice()?.iter().map(|v| *v as f64).collect());
86 }
87 if let Ok(values) = value.extract::<Vec<f64>>() {
88 return Ok(values);
89 }
90 if let Ok(values) = value.extract::<Vec<f32>>() {
91 return Ok(values.into_iter().map(|v| v as f64).collect());
92 }
93 if let Ok(list) = value.cast::<PyList>() {
94 let mut converted = Vec::with_capacity(list.len());
95 for item in list.iter() {
96 converted.push(item.extract::<f64>().map_err(|_| {
97 PyTypeError::new_err(format!(
98 "Column '{name}' must be numeric (float32/float64/list of floats)"
99 ))
100 })?);
101 }
102 return Ok(converted);
103 }
104 Err(PyTypeError::new_err(format!(
105 "Column '{name}' must be numeric (float32/float64/list of floats)"
106 )))
107}
108
#[pyclass(name = "Event", module = "laddu", from_py_object)]
/// Python-facing wrapper around a single core [`Event`].
///
/// `has_metadata` records whether particle/auxiliary names are attached. The
/// name-based accessors check it before touching the event's metadata.
#[derive(Clone)]
pub struct PyEvent {
    /// The wrapped core event: shared event data plus a shared metadata handle.
    pub event: Event,
    /// `true` when names were supplied to the constructor, or when the event was
    /// obtained from a `Dataset`.
    has_metadata: bool,
}
152
153#[pymethods]
154impl PyEvent {
155 #[new]
156 #[pyo3(signature = (p4s, aux, weight, *, p4_names=None, aux_names=None, aliases=None))]
157 fn new(
158 p4s: Vec<PyVec4>,
159 aux: Vec<f64>,
160 weight: f64,
161 p4_names: Option<Vec<String>>,
162 aux_names: Option<Vec<String>>,
163 aliases: Option<Bound<PyDict>>,
164 ) -> PyResult<Self> {
165 let event = EventData {
166 p4s: p4s.into_iter().map(|arr| arr.0).collect(),
167 aux,
168 weight,
169 };
170 let aliases = parse_aliases(aliases)?;
171
172 let missing_p4_names = p4_names
173 .as_ref()
174 .map(|names| names.is_empty())
175 .unwrap_or(true);
176
177 if !aliases.is_empty() && missing_p4_names {
178 return Err(PyValueError::new_err(
179 "`aliases` requires `p4_names` so selections can be resolved",
180 ));
181 }
182
183 let metadata_provided = p4_names.is_some() || aux_names.is_some() || !aliases.is_empty();
184 let metadata = if metadata_provided {
185 let p4_names = p4_names.unwrap_or_default();
186 let aux_names = aux_names.unwrap_or_default();
187 let mut metadata = DatasetMetadata::new(p4_names, aux_names).map_err(PyErr::from)?;
188 if !aliases.is_empty() {
189 metadata
190 .add_p4_aliases(
191 aliases.into_iter().map(|(alias_name, selection)| {
192 (alias_name, selection.into_selection())
193 }),
194 )
195 .map_err(PyErr::from)?;
196 }
197 Arc::new(metadata)
198 } else {
199 Arc::new(DatasetMetadata::empty())
200 };
201 let event = Event::new(Arc::new(event), metadata);
202 Ok(Self {
203 event,
204 has_metadata: metadata_provided,
205 })
206 }
207 fn __str__(&self) -> String {
208 self.event.data().to_string()
209 }
210 #[getter]
213 fn p4s<'py>(&self, py: Python<'py>) -> PyResult<Py<PyDict>> {
214 self.ensure_metadata()?;
215 let mapping = PyDict::new(py);
216 for (name, vec4) in self.event.p4s() {
217 mapping.set_item(name, PyVec4(vec4))?;
218 }
219 Ok(mapping.into())
220 }
221 #[getter]
224 #[pyo3(name = "aux")]
225 fn aux_mapping<'py>(&self, py: Python<'py>) -> PyResult<Py<PyDict>> {
226 self.ensure_metadata()?;
227 let mapping = PyDict::new(py);
228 for (name, value) in self.event.aux() {
229 mapping.set_item(name, value)?;
230 }
231 Ok(mapping.into())
232 }
233 #[getter]
236 fn get_weight(&self) -> f64 {
237 self.event.weight()
238 }
239 fn get_p4_sum(&self, names: Vec<String>) -> PyResult<PyVec4> {
252 let indices = self.resolve_p4_indices(&names)?;
253 Ok(PyVec4(self.event.data().get_p4_sum(indices)))
254 }
255 pub fn boost_to_rest_frame_of(&self, names: Vec<String>) -> PyResult<Self> {
269 let indices = self.resolve_p4_indices(&names)?;
270 let boosted = self.event.data().boost_to_rest_frame_of(indices);
271 Ok(Self {
272 event: Event::new(Arc::new(boosted), self.event.metadata_arc()),
273 has_metadata: self.has_metadata,
274 })
275 }
276 fn evaluate(&self, variable: Bound<'_, PyAny>) -> PyResult<f64> {
306 let mut variable = variable.extract::<PyVariable>()?;
307 let metadata = self.ensure_metadata()?;
308 variable.bind_in_place(metadata)?;
309 variable.evaluate_event(&self.event)
310 }
311
312 fn p4(&self, name: &str) -> PyResult<PyVec4> {
314 self.ensure_metadata()?;
315 self.event
316 .p4(name)
317 .map(PyVec4)
318 .ok_or_else(|| PyKeyError::new_err(format!("Unknown particle name '{name}'")))
319 }
320}
321
322impl PyEvent {
323 fn ensure_metadata(&self) -> PyResult<&DatasetMetadata> {
324 if !self.has_metadata {
325 Err(PyValueError::new_err(
326 "Event has no associated metadata for name-based operations",
327 ))
328 } else {
329 Ok(self.event.metadata())
330 }
331 }
332
333 fn resolve_p4_indices(&self, names: &[String]) -> PyResult<Vec<usize>> {
334 let metadata = self.ensure_metadata()?;
335 let mut resolved = Vec::new();
336 for name in names {
337 let selection = metadata
338 .p4_selection(name)
339 .ok_or_else(|| PyKeyError::new_err(format!("Unknown particle name '{name}'")))?;
340 resolved.extend_from_slice(selection.indices());
341 }
342 Ok(resolved)
343 }
344
345 pub(crate) fn metadata_opt(&self) -> Option<&DatasetMetadata> {
346 self.has_metadata.then(|| self.event.metadata())
347 }
348}
349
#[doc(hidden)]
#[pyclass(name = "Dataset", module = "laddu", subclass, skip_from_py_object)]
/// Python-facing wrapper around a reference-counted core [`Dataset`].
///
/// Cloning the wrapper only bumps the `Arc` count; the events are not copied.
#[derive(Clone)]
pub struct PyDataset(pub Arc<Dataset>);
394
#[pyclass(
    name = "ParquetChunkIter",
    module = "laddu",
    unsendable,
    skip_from_py_object
)]
/// Python iterator over `Dataset` chunks, as returned by `read_parquet_chunked`.
pub struct PyParquetChunkIter {
    // Boxed core chunk iterator; each item is either a chunk dataset or a read error
    // that `__next__` converts into a Python exception.
    chunks: Box<dyn Iterator<Item = laddu_core::LadduResult<Arc<Dataset>>> + Send>,
}
404
405#[pymethods]
406impl PyParquetChunkIter {
407 fn __iter__(slf: PyRef<'_, Self>) -> Py<PyParquetChunkIter> {
408 slf.into()
409 }
410
411 fn __next__(&mut self) -> PyResult<Option<PyDataset>> {
412 match self.chunks.next() {
413 Some(Ok(dataset)) => Ok(Some(PyDataset(dataset))),
414 Some(Err(err)) => Err(PyErr::from(err)),
415 None => Ok(None),
416 }
417 }
418}
419
#[pyclass(
    name = "DatasetIter",
    module = "laddu",
    unsendable,
    skip_from_py_object
)]
/// Python iterator over the events of a `Dataset`, yielding `Event` objects.
///
/// Produced by `Dataset.iter_local()`, `Dataset.iter_global()`, and `iter(dataset)`.
struct PyDatasetIter {
    // Which event source this iterator walks (local positions or the global iterator).
    kind: PyDatasetIterKind,
}
429
/// Event source backing a `PyDatasetIter`.
enum PyDatasetIterKind {
    /// Walks `dataset.events_local()` by position; `index` is the next position to yield.
    Local { dataset: Arc<Dataset>, index: usize },
    /// Delegates to the iterator returned by `Dataset::shared_iter_global`.
    Global(DatasetArcIter),
}
434
435#[pymethods]
436impl PyDatasetIter {
437 fn __iter__(slf: PyRef<'_, Self>) -> Py<PyDatasetIter> {
438 slf.into()
439 }
440
441 fn __next__(&mut self) -> Option<PyEvent> {
442 let event = match &mut self.kind {
443 PyDatasetIterKind::Local { dataset, index } => {
444 let event = dataset.events_local().get(*index)?.clone();
445 *index += 1;
446 event
447 }
448 PyDatasetIterKind::Global(iterator) => iterator.next()?,
449 };
450 Some(PyEvent {
451 event,
452 has_metadata: true,
453 })
454 }
455}
456
457#[pymethods]
458impl PyDataset {
459 #[new]
460 #[pyo3(signature = (events, *, p4_names=None, aux_names=None, aliases=None))]
461 fn new(
462 events: Vec<PyEvent>,
463 p4_names: Option<Vec<String>>,
464 aux_names: Option<Vec<String>>,
465 aliases: Option<Bound<PyDict>>,
466 ) -> PyResult<Self> {
467 let inferred_metadata = events
468 .iter()
469 .find_map(|event| event.has_metadata.then(|| event.event.metadata_arc()));
470
471 let aliases = parse_aliases(aliases)?;
472 let use_explicit_metadata =
473 p4_names.is_some() || aux_names.is_some() || !aliases.is_empty();
474
475 let metadata =
476 if use_explicit_metadata {
477 let resolved_p4_names = match (p4_names, inferred_metadata.as_ref()) {
478 (Some(names), _) => names,
479 (None, Some(metadata)) => metadata.p4_names().to_vec(),
480 (None, None) => Vec::new(),
481 };
482 let resolved_aux_names = match (aux_names, inferred_metadata.as_ref()) {
483 (Some(names), _) => names,
484 (None, Some(metadata)) => metadata.aux_names().to_vec(),
485 (None, None) => Vec::new(),
486 };
487
488 if !aliases.is_empty() && resolved_p4_names.is_empty() {
489 return Err(PyValueError::new_err(
490 "`aliases` requires `p4_names` or events with metadata for resolution",
491 ));
492 }
493
494 let mut metadata = DatasetMetadata::new(resolved_p4_names, resolved_aux_names)
495 .map_err(PyErr::from)?;
496 if !aliases.is_empty() {
497 metadata
498 .add_p4_aliases(aliases.into_iter().map(|(alias_name, selection)| {
499 (alias_name, selection.into_selection())
500 }))
501 .map_err(PyErr::from)?;
502 }
503 Some(Arc::new(metadata))
504 } else {
505 inferred_metadata
506 };
507
508 let events: Vec<Arc<EventData>> = events
509 .into_iter()
510 .map(|event| event.event.data_arc())
511 .collect();
512 let dataset = if let Some(metadata) = metadata {
513 Dataset::new_with_metadata(events, metadata)
514 } else {
515 Dataset::new(events)
516 };
517 Ok(Self(Arc::new(dataset)))
518 }
519
520 fn __len__(&self) -> usize {
521 self.0.n_events()
522 }
523 fn __iter__(&self) -> PyDatasetIter {
531 self.iter_global()
532 }
533 #[getter]
535 fn n_events_local(&self) -> usize {
536 self.0.n_events_local()
537 }
538 fn iter_local(&self) -> PyDatasetIter {
546 PyDatasetIter {
547 kind: PyDatasetIterKind::Local {
548 dataset: self.0.clone(),
549 index: 0,
550 },
551 }
552 }
553 fn iter_global(&self) -> PyDatasetIter {
561 PyDatasetIter {
562 kind: PyDatasetIterKind::Global(self.0.shared_iter_global()),
563 }
564 }
565 fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyDataset> {
566 if let Ok(other_ds) = other.extract::<PyRef<PyDataset>>() {
567 Ok(PyDataset(Arc::new(self.0.as_ref() + other_ds.0.as_ref())))
568 } else if let Ok(other_int) = other.extract::<usize>() {
569 if other_int == 0 {
570 Ok(self.clone())
571 } else {
572 Err(PyTypeError::new_err(
573 "Addition with an integer for this type is only defined for 0",
574 ))
575 }
576 } else {
577 Err(PyTypeError::new_err("Unsupported operand type for +"))
578 }
579 }
580 fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyDataset> {
581 if let Ok(other_ds) = other.extract::<PyRef<PyDataset>>() {
582 Ok(PyDataset(Arc::new(other_ds.0.as_ref() + self.0.as_ref())))
583 } else if let Ok(other_int) = other.extract::<usize>() {
584 if other_int == 0 {
585 Ok(self.clone())
586 } else {
587 Err(PyTypeError::new_err(
588 "Addition with an integer for this type is only defined for 0",
589 ))
590 }
591 } else {
592 Err(PyTypeError::new_err("Unsupported operand type for +"))
593 }
594 }
595 #[getter]
608 fn n_events(&self) -> usize {
609 self.0.n_events()
610 }
611 #[getter]
613 fn n_events_global(&self) -> usize {
614 self.0.n_events_global()
615 }
616 #[getter]
618 fn p4_names(&self) -> Vec<String> {
619 self.0.p4_names().to_vec()
620 }
621 #[getter]
623 fn aux_names(&self) -> Vec<String> {
624 self.0.aux_names().to_vec()
625 }
626
627 #[getter]
639 fn n_events_weighted(&self) -> f64 {
640 self.0.n_events_weighted()
641 }
642 #[getter]
644 fn n_events_weighted_global(&self) -> f64 {
645 self.0.n_events_weighted_global()
646 }
647 #[getter]
660 fn n_events_weighted_local(&self) -> f64 {
661 self.0.n_events_weighted_local()
662 }
663 #[getter]
675 fn weights<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<f64>> {
676 PyArray1::from_slice(py, &self.0.weights())
677 }
678 #[getter]
680 fn weights_global<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<f64>> {
681 PyArray1::from_slice(py, &self.0.weights_global())
682 }
683 #[getter]
695 fn weights_local<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<f64>> {
696 PyArray1::from_slice(py, &self.0.weights_local())
697 }
698 #[getter]
713 fn events(&self) -> Vec<PyEvent> {
714 self.0
715 .shared_iter()
716 .map(|rust_event| PyEvent {
717 event: rust_event,
718 has_metadata: true,
719 })
720 .collect()
721 }
722 #[getter]
724 fn events_global(&self) -> Vec<PyEvent> {
725 self.events()
726 }
727 #[getter]
734 fn events_local(&self) -> Vec<PyEvent> {
735 self.0
736 .events_local()
737 .iter()
738 .map(|rust_event| PyEvent {
739 event: rust_event.clone(),
740 has_metadata: true,
741 })
742 .collect()
743 }
744 fn p4_by_name(&self, index: usize, name: &str) -> PyResult<PyVec4> {
746 self.0
747 .p4_by_name(index, name)
748 .map(PyVec4)
749 .ok_or_else(|| PyKeyError::new_err(format!("Unknown particle name '{name}'")))
750 }
751 fn aux_by_name(&self, index: usize, name: &str) -> PyResult<f64> {
753 self.0
754 .aux_by_name(index, name)
755 .ok_or_else(|| PyKeyError::new_err(format!("Unknown auxiliary name '{name}'")))
756 }
757 fn event_global(&self, index: usize) -> PyResult<PyEvent> {
763 let event = self
764 .0
765 .get_event_global(index)
766 .ok_or_else(|| PyIndexError::new_err("index out of range"))?;
767 Ok(PyEvent {
768 event,
769 has_metadata: true,
770 })
771 }
772 fn __getitem__<'py>(
773 &self,
774 py: Python<'py>,
775 index: Bound<'py, PyAny>,
776 ) -> PyResult<Bound<'py, PyAny>> {
777 if let Ok(value) = self.evaluate(py, index.clone()) {
778 value.into_bound_py_any(py)
779 } else if let Ok(index) = index.extract::<usize>() {
780 let event = self
781 .0
782 .get_event(index)
783 .ok_or_else(|| PyIndexError::new_err("index out of range"))?;
784 PyEvent {
785 event,
786 has_metadata: true,
787 }
788 .into_bound_py_any(py)
789 } else {
790 Err(PyTypeError::new_err(
791 "Unsupported index type (int or Variable)",
792 ))
793 }
794 }
795 #[pyo3(signature = (variable, bins, range))]
833 fn bin_by(
834 &self,
835 variable: Bound<'_, PyAny>,
836 bins: usize,
837 range: (f64, f64),
838 ) -> PyResult<PyBinnedDataset> {
839 let py_variable = variable.extract::<PyVariable>()?;
840 let bound_variable = py_variable.bound(self.0.metadata())?;
841 Ok(PyBinnedDataset(self.0.bin_by(
842 bound_variable,
843 bins,
844 range,
845 )?))
846 }
847 pub fn filter(&self, expression: &PyVariableExpression) -> PyResult<PyDataset> {
865 Ok(PyDataset(
866 self.0.filter(&expression.0).map_err(PyErr::from)?,
867 ))
868 }
869 fn bootstrap(&self, seed: usize) -> PyDataset {
890 PyDataset(self.0.bootstrap(seed))
891 }
892 pub fn boost_to_rest_frame_of(&self, names: Vec<String>) -> PyDataset {
910 PyDataset(self.0.boost_to_rest_frame_of(&names))
911 }
912 fn evaluate<'py>(
930 &self,
931 py: Python<'py>,
932 variable: Bound<'py, PyAny>,
933 ) -> PyResult<Bound<'py, PyArray1<f64>>> {
934 let variable = variable.extract::<PyVariable>()?;
935 let bound_variable = variable.bound(self.0.metadata())?;
936 let values = self.0.evaluate(&bound_variable).map_err(PyErr::from)?;
937 Ok(PyArray1::from_vec(py, values))
938 }
939}
940
941#[pyfunction]
955#[pyo3(signature = (path, *, p4s=None, aux=None, aliases=None))]
956pub fn read_parquet(
957 path: Bound<PyAny>,
958 p4s: Option<Vec<String>>,
959 aux: Option<Vec<String>>,
960 aliases: Option<Bound<PyDict>>,
961) -> PyResult<PyDataset> {
962 let path_str = parse_dataset_path(path)?;
963 let mut read_options = DatasetReadOptions::default();
964 if let Some(p4s) = p4s {
965 read_options = read_options.p4_names(p4s);
966 }
967 if let Some(aux) = aux {
968 read_options = read_options.aux_names(aux);
969 }
970 for (alias_name, selection) in parse_aliases(aliases)?.into_iter() {
971 read_options = read_options.alias(alias_name, selection);
972 }
973 let dataset = core_read_parquet(&path_str, &read_options)?;
974 Ok(PyDataset(dataset))
975}
976
977#[pyfunction]
979#[pyo3(signature = (path, *, p4s=None, aux=None, aliases=None, chunk_size=None))]
980pub fn read_parquet_chunked(
981 path: Bound<PyAny>,
982 p4s: Option<Vec<String>>,
983 aux: Option<Vec<String>>,
984 aliases: Option<Bound<PyDict>>,
985 chunk_size: Option<usize>,
986) -> PyResult<PyParquetChunkIter> {
987 let path_str = parse_dataset_path(path)?;
988 let mut read_options = DatasetReadOptions::default();
989 if let Some(p4s) = p4s {
990 read_options = read_options.p4_names(p4s);
991 }
992 if let Some(aux) = aux {
993 read_options = read_options.aux_names(aux);
994 }
995 if let Some(chunk_size) = chunk_size {
996 read_options = read_options.chunk_size(chunk_size);
997 }
998 for (alias_name, selection) in parse_aliases(aliases)?.into_iter() {
999 read_options = read_options.alias(alias_name, selection);
1000 }
1001
1002 let chunks = core_read_parquet_chunks_with_options(&path_str, &read_options)?;
1003 Ok(PyParquetChunkIter {
1004 chunks: Box::new(chunks),
1005 })
1006}
1007
1008#[pyfunction]
1022#[pyo3(signature = (path, *, tree=None, p4s=None, aux=None, aliases=None))]
1023pub fn read_root(
1024 path: Bound<PyAny>,
1025 tree: Option<String>,
1026 p4s: Option<Vec<String>>,
1027 aux: Option<Vec<String>>,
1028 aliases: Option<Bound<PyDict>>,
1029) -> PyResult<PyDataset> {
1030 let path_str = parse_dataset_path(path)?;
1031 let mut read_options = DatasetReadOptions::default();
1032 if let Some(p4s) = p4s {
1033 read_options = read_options.p4_names(p4s);
1034 }
1035 if let Some(aux) = aux {
1036 read_options = read_options.aux_names(aux);
1037 }
1038 if let Some(tree) = tree {
1039 read_options = read_options.tree(tree);
1040 }
1041 for (alias_name, selection) in parse_aliases(aliases)?.into_iter() {
1042 read_options = read_options.alias(alias_name, selection);
1043 }
1044 let dataset = core_read_root(&path_str, &read_options)?;
1045 Ok(PyDataset(dataset))
1046}
1047
1048#[pyfunction]
1050#[pyo3(signature = (dataset, path, *, chunk_size=None, precision="f64"))]
1051pub fn write_parquet(
1052 dataset: &PyDataset,
1053 path: Bound<PyAny>,
1054 chunk_size: Option<usize>,
1055 precision: &str,
1056) -> PyResult<()> {
1057 let path_str = parse_dataset_path(path)?;
1058 let mut write_options = DatasetWriteOptions::default();
1059 if let Some(size) = chunk_size {
1060 write_options.batch_size = size.max(1);
1061 }
1062 write_options.precision = parse_precision_arg(Some(precision))?;
1063 core_write_parquet(dataset.0.as_ref(), &path_str, &write_options).map_err(PyErr::from)
1064}
1065
1066#[pyfunction]
1068#[pyo3(signature = (dataset, path, *, tree=None, chunk_size=None, precision="f64"))]
1069pub fn write_root(
1070 dataset: &PyDataset,
1071 path: Bound<PyAny>,
1072 tree: Option<String>,
1073 chunk_size: Option<usize>,
1074 precision: &str,
1075) -> PyResult<()> {
1076 let path_str = parse_dataset_path(path)?;
1077 let mut write_options = DatasetWriteOptions::default();
1078 if let Some(name) = tree {
1079 write_options.tree = Some(name);
1080 }
1081 if let Some(size) = chunk_size {
1082 write_options.batch_size = size.max(1);
1083 }
1084 write_options.precision = parse_precision_arg(Some(precision))?;
1085 core_write_root(dataset.0.as_ref(), &path_str, &write_options).map_err(PyErr::from)
1086}
1087
1088#[doc(hidden)]
1089#[pyfunction]
1117#[pyo3(signature = (columns, *, p4s=None, aux=None, aliases=None))]
1118pub fn from_columns(
1119 columns: Bound<'_, PyDict>,
1120 p4s: Option<Vec<String>>,
1121 aux: Option<Vec<String>>,
1122 aliases: Option<Bound<'_, PyDict>>,
1123) -> PyResult<PyDataset> {
1124 let column_names = columns
1125 .iter()
1126 .map(|(key, _)| key.extract::<String>())
1127 .collect::<PyResult<Vec<_>>>()?;
1128
1129 let (detected_p4_names, detected_aux_names) =
1130 infer_p4_and_aux_names_from_columns(&column_names);
1131 let p4_names = p4s.unwrap_or(detected_p4_names);
1132 if p4_names.is_empty() {
1133 let mut partial_components: std::collections::BTreeMap<
1134 String,
1135 std::collections::BTreeSet<&str>,
1136 > = std::collections::BTreeMap::new();
1137 for column_name in &column_names {
1138 let lowered = column_name.to_ascii_lowercase();
1139 for suffix in P4_COMPONENT_SUFFIXES {
1140 if lowered.ends_with(suffix) && column_name.len() > suffix.len() {
1141 let prefix = column_name[..column_name.len() - suffix.len()].to_string();
1142 partial_components.entry(prefix).or_default().insert(suffix);
1143 }
1144 }
1145 }
1146 if let Some((prefix, present)) = partial_components.iter().next() {
1147 if present.len() < P4_COMPONENT_SUFFIXES.len() {
1148 let missing = P4_COMPONENT_SUFFIXES
1149 .iter()
1150 .filter(|suffix| !present.contains(**suffix))
1151 .map(|suffix| format!("{prefix}{suffix}"))
1152 .collect::<Vec<_>>()
1153 .join(", ");
1154 return Err(PyKeyError::new_err(format!(
1155 "Missing components [{missing}] for four-momentum '{prefix}'"
1156 )));
1157 }
1158 }
1159 return Err(PyValueError::new_err(
1160 "No four-momentum columns found (expected *_px, *_py, *_pz, *_e)",
1161 ));
1162 }
1163
1164 let aux_names = aux.unwrap_or(detected_aux_names);
1165 let p4_component_columns =
1166 resolve_p4_component_columns(&column_names, &p4_names).map_err(PyErr::from)?;
1167 let resolved_aux_columns =
1168 resolve_columns_case_insensitive(&column_names, &aux_names).map_err(PyErr::from)?;
1169
1170 let n_events = {
1171 let first_name = p4_component_columns
1172 .first()
1173 .map(|components| components[0].clone())
1174 .ok_or_else(|| PyKeyError::new_err("Missing required p4 column"))?;
1175 let values = extract_numeric_column(
1176 columns
1177 .get_item(first_name.as_str())?
1178 .ok_or_else(|| PyKeyError::new_err("Missing required p4 column"))?,
1179 &first_name,
1180 )?;
1181 values.len()
1182 };
1183
1184 let mut p4_columns: Vec<[Vec<f64>; 4]> = Vec::with_capacity(p4_names.len());
1185 for component_names in &p4_component_columns {
1186 let px = extract_numeric_column(
1187 columns
1188 .get_item(component_names[0].as_str())?
1189 .ok_or_else(|| PyKeyError::new_err(format!("Missing {}", component_names[0])))?,
1190 component_names[0].as_str(),
1191 )?;
1192 let py = extract_numeric_column(
1193 columns
1194 .get_item(component_names[1].as_str())?
1195 .ok_or_else(|| PyKeyError::new_err(format!("Missing {}", component_names[1])))?,
1196 component_names[1].as_str(),
1197 )?;
1198 let pz = extract_numeric_column(
1199 columns
1200 .get_item(component_names[2].as_str())?
1201 .ok_or_else(|| PyKeyError::new_err(format!("Missing {}", component_names[2])))?,
1202 component_names[2].as_str(),
1203 )?;
1204 let e = extract_numeric_column(
1205 columns
1206 .get_item(component_names[3].as_str())?
1207 .ok_or_else(|| PyKeyError::new_err(format!("Missing {}", component_names[3])))?,
1208 component_names[3].as_str(),
1209 )?;
1210 if px.len() != n_events
1211 || py.len() != n_events
1212 || pz.len() != n_events
1213 || e.len() != n_events
1214 {
1215 return Err(PyValueError::new_err(
1216 "All p4 components must have the same length",
1217 ));
1218 }
1219 p4_columns.push([px, py, pz, e]);
1220 }
1221
1222 let mut aux_columns: Vec<Vec<f64>> = Vec::with_capacity(resolved_aux_columns.len());
1223 for (aux_name, aux_column_name) in aux_names.iter().zip(&resolved_aux_columns) {
1224 let values = extract_numeric_column(
1225 columns.get_item(aux_column_name.as_str())?.ok_or_else(|| {
1226 PyKeyError::new_err(format!("Missing auxiliary column '{aux_name}'"))
1227 })?,
1228 aux_name,
1229 )?;
1230 if values.len() != n_events {
1231 return Err(PyValueError::new_err(format!(
1232 "Auxiliary column '{aux_name}' length does not match p4 columns"
1233 )));
1234 }
1235 aux_columns.push(values);
1236 }
1237
1238 let weights = if let Some(weight_column_name) = resolve_optional_weight_column(&column_names) {
1239 let weight_values = columns
1240 .get_item(weight_column_name.as_str())?
1241 .ok_or_else(|| PyKeyError::new_err("Missing weight column"))?;
1242 let values = extract_numeric_column(weight_values, "weight")?;
1243 if values.len() != n_events {
1244 return Err(PyValueError::new_err(
1245 "Column 'weight' length does not match p4 columns",
1246 ));
1247 }
1248 values
1249 } else {
1250 vec![1.0; n_events]
1251 };
1252
1253 let parsed_aliases = parse_aliases(aliases)?;
1254 let mut metadata =
1255 DatasetMetadata::new(p4_names.clone(), aux_names.clone()).map_err(PyErr::from)?;
1256 if !parsed_aliases.is_empty() {
1257 metadata
1258 .add_p4_aliases(
1259 parsed_aliases
1260 .into_iter()
1261 .map(|(alias_name, selection)| (alias_name, selection.into_selection())),
1262 )
1263 .map_err(PyErr::from)?;
1264 }
1265
1266 let mut events = Vec::with_capacity(n_events);
1267 for event_idx in 0..n_events {
1268 let p4s = p4_columns
1269 .iter()
1270 .map(|components| {
1271 laddu_core::utils::vectors::Vec4::new(
1272 components[0][event_idx],
1273 components[1][event_idx],
1274 components[2][event_idx],
1275 components[3][event_idx],
1276 )
1277 })
1278 .collect::<Vec<_>>();
1279 let aux = aux_columns
1280 .iter()
1281 .map(|values| values[event_idx])
1282 .collect::<Vec<_>>();
1283 events.push(Arc::new(EventData {
1284 p4s,
1285 aux,
1286 weight: weights[event_idx],
1287 }));
1288 }
1289
1290 Ok(PyDataset(Arc::new(Dataset::new_with_metadata(
1291 events,
1292 Arc::new(metadata),
1293 ))))
1294}
1295
#[pyclass(name = "BinnedDataset", module = "laddu", skip_from_py_object)]
/// Python-facing wrapper around a core [`BinnedDataset`], as produced by `Dataset.bin_by`.
pub struct PyBinnedDataset(BinnedDataset);
1306
1307#[pymethods]
1308impl PyBinnedDataset {
1309 fn __len__(&self) -> usize {
1310 self.0.n_bins()
1311 }
1312 #[getter]
1315 fn n_bins(&self) -> usize {
1316 self.0.n_bins()
1317 }
1318 #[getter]
1321 fn range(&self) -> (f64, f64) {
1322 self.0.range()
1323 }
1324 #[getter]
1327 fn edges<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<f64>> {
1328 PyArray1::from_slice(py, &self.0.edges())
1329 }
1330 fn __getitem__(&self, index: usize) -> PyResult<PyDataset> {
1331 self.0
1332 .get(index)
1333 .ok_or(PyIndexError::new_err("index out of range"))
1334 .map(|rust_dataset| PyDataset(rust_dataset.clone()))
1335 }
1336}