1use crate::utils::variables::{PyVariable, PyVariableExpression};
2use laddu_core::{
3 data::{
4 io::{
5 infer_p4_and_aux_names_from_columns, resolve_columns_case_insensitive,
6 resolve_optional_weight_column, resolve_p4_component_columns, P4_COMPONENT_SUFFIXES,
7 },
8 read_parquet as core_read_parquet,
9 read_parquet_chunks_with_options as core_read_parquet_chunks_with_options,
10 read_root as core_read_root, write_parquet as core_write_parquet,
11 write_root as core_write_root, BinnedDataset, Dataset, DatasetArcIter, DatasetMetadata,
12 DatasetWriteOptions, Event, EventData, FloatPrecision, SharedDatasetIterExt,
13 },
14 utils::variables::IntoP4Selection,
15 DatasetReadOptions,
16};
17use numpy::{PyArray1, PyReadonlyArray1};
18use pyo3::{
19 exceptions::{PyIndexError, PyKeyError, PyTypeError, PyValueError},
20 prelude::*,
21 types::{PyDict, PyList},
22 IntoPyObjectExt,
23};
24use std::{path::PathBuf, sync::Arc};
25
26use crate::utils::vectors::PyVec4;
27
28fn parse_aliases(aliases: Option<Bound<'_, PyDict>>) -> PyResult<Vec<(String, Vec<String>)>> {
29 let Some(aliases) = aliases else {
30 return Ok(Vec::new());
31 };
32
33 let mut parsed = Vec::new();
34 for (key, value) in aliases.iter() {
35 let alias_name = key.extract::<String>()?;
36 let selection = if let Ok(single) = value.extract::<String>() {
37 vec![single]
38 } else {
39 let seq = value.extract::<Vec<String>>().map_err(|_| {
40 PyTypeError::new_err("Alias values must be a string or a sequence of strings")
41 })?;
42 if seq.is_empty() {
43 return Err(PyValueError::new_err(format!(
44 "Alias '{alias_name}' must reference at least one particle",
45 )));
46 }
47 seq
48 };
49 parsed.push((alias_name, selection));
50 }
51
52 Ok(parsed)
53}
54
55fn parse_dataset_path(path: Bound<'_, PyAny>) -> PyResult<String> {
56 if let Ok(s) = path.extract::<String>() {
57 Ok(s)
58 } else if let Ok(pathbuf) = path.extract::<PathBuf>() {
59 Ok(pathbuf.to_string_lossy().into_owned())
60 } else {
61 Err(PyTypeError::new_err("Expected str or Path"))
62 }
63}
64
65fn parse_precision_arg(value: Option<&str>) -> PyResult<FloatPrecision> {
66 match value.map(|v| v.to_ascii_lowercase()) {
67 None => Ok(FloatPrecision::F64),
68 Some(name) if name == "f64" || name == "float64" || name == "double" => {
69 Ok(FloatPrecision::F64)
70 }
71 Some(name) if name == "f32" || name == "float32" || name == "float" => {
72 Ok(FloatPrecision::F32)
73 }
74 Some(other) => Err(PyValueError::new_err(format!(
75 "Unsupported precision '{other}' (expected 'f64' or 'f32')"
76 ))),
77 }
78}
79
80fn extract_numeric_column(value: Bound<'_, PyAny>, name: &str) -> PyResult<Vec<f64>> {
81 if let Ok(array) = value.extract::<PyReadonlyArray1<'_, f64>>() {
82 return Ok(array.as_slice()?.to_vec());
83 }
84 if let Ok(array) = value.extract::<PyReadonlyArray1<'_, f32>>() {
85 return Ok(array.as_slice()?.iter().map(|v| *v as f64).collect());
86 }
87 if let Ok(values) = value.extract::<Vec<f64>>() {
88 return Ok(values);
89 }
90 if let Ok(values) = value.extract::<Vec<f32>>() {
91 return Ok(values.into_iter().map(|v| v as f64).collect());
92 }
93 if let Ok(list) = value.cast::<PyList>() {
94 let mut converted = Vec::with_capacity(list.len());
95 for item in list.iter() {
96 converted.push(item.extract::<f64>().map_err(|_| {
97 PyTypeError::new_err(format!(
98 "Column '{name}' must be numeric (float32/float64/list of floats)"
99 ))
100 })?);
101 }
102 return Ok(converted);
103 }
104 Err(PyTypeError::new_err(format!(
105 "Column '{name}' must be numeric (float32/float64/list of floats)"
106 )))
107}
108
/// Python-facing wrapper around a core [`Event`], exposed to Python as `laddu.Event`.
#[pyclass(name = "Event", module = "laddu", from_py_object)]
#[derive(Clone)]
pub struct PyEvent {
    /// The wrapped core event (event data plus its metadata handle).
    pub event: Event,
    /// Whether the event carries usable metadata. Name-based operations are refused
    /// with a `ValueError` when this is `false`.
    has_metadata: bool,
}
152
153#[pymethods]
154impl PyEvent {
155 #[new]
156 #[pyo3(signature = (p4s, aux, weight, *, p4_names=None, aux_names=None, aliases=None))]
157 fn new(
158 p4s: Vec<PyVec4>,
159 aux: Vec<f64>,
160 weight: f64,
161 p4_names: Option<Vec<String>>,
162 aux_names: Option<Vec<String>>,
163 aliases: Option<Bound<PyDict>>,
164 ) -> PyResult<Self> {
165 let event = EventData {
166 p4s: p4s.into_iter().map(|arr| arr.0).collect(),
167 aux,
168 weight,
169 };
170 let aliases = parse_aliases(aliases)?;
171
172 let missing_p4_names = p4_names
173 .as_ref()
174 .map(|names| names.is_empty())
175 .unwrap_or(true);
176
177 if !aliases.is_empty() && missing_p4_names {
178 return Err(PyValueError::new_err(
179 "`aliases` requires `p4_names` so selections can be resolved",
180 ));
181 }
182
183 let metadata_provided = p4_names.is_some() || aux_names.is_some() || !aliases.is_empty();
184 let metadata = if metadata_provided {
185 let p4_names = p4_names.unwrap_or_default();
186 let aux_names = aux_names.unwrap_or_default();
187 let mut metadata = DatasetMetadata::new(p4_names, aux_names).map_err(PyErr::from)?;
188 if !aliases.is_empty() {
189 metadata
190 .add_p4_aliases(
191 aliases.into_iter().map(|(alias_name, selection)| {
192 (alias_name, selection.into_selection())
193 }),
194 )
195 .map_err(PyErr::from)?;
196 }
197 Arc::new(metadata)
198 } else {
199 Arc::new(DatasetMetadata::empty())
200 };
201 let event = Event::new(Arc::new(event), metadata);
202 Ok(Self {
203 event,
204 has_metadata: metadata_provided,
205 })
206 }
207 fn __str__(&self) -> String {
208 self.event.data().to_string()
209 }
210 #[getter]
213 fn p4s<'py>(&self, py: Python<'py>) -> PyResult<Py<PyDict>> {
214 self.ensure_metadata()?;
215 let mapping = PyDict::new(py);
216 for (name, vec4) in self.event.p4s() {
217 mapping.set_item(name, PyVec4(vec4))?;
218 }
219 Ok(mapping.into())
220 }
221 #[getter]
224 #[pyo3(name = "aux")]
225 fn aux_mapping<'py>(&self, py: Python<'py>) -> PyResult<Py<PyDict>> {
226 self.ensure_metadata()?;
227 let mapping = PyDict::new(py);
228 for (name, value) in self.event.aux() {
229 mapping.set_item(name, value)?;
230 }
231 Ok(mapping.into())
232 }
233 #[getter]
236 fn get_weight(&self) -> f64 {
237 self.event.weight()
238 }
239 fn get_p4_sum(&self, names: Vec<String>) -> PyResult<PyVec4> {
252 let indices = self.resolve_p4_indices(&names)?;
253 Ok(PyVec4(self.event.data().get_p4_sum(indices)))
254 }
255 pub fn boost_to_rest_frame_of(&self, names: Vec<String>) -> PyResult<Self> {
269 let indices = self.resolve_p4_indices(&names)?;
270 let boosted = self.event.data().boost_to_rest_frame_of(indices);
271 Ok(Self {
272 event: Event::new(Arc::new(boosted), self.event.metadata_arc()),
273 has_metadata: self.has_metadata,
274 })
275 }
276 fn evaluate(&self, variable: Bound<'_, PyAny>) -> PyResult<f64> {
306 let mut variable = variable.extract::<PyVariable>()?;
307 let metadata = self.ensure_metadata()?;
308 variable.bind_in_place(metadata)?;
309 variable.evaluate_event(&self.event)
310 }
311
312 fn p4(&self, name: &str) -> PyResult<PyVec4> {
314 self.ensure_metadata()?;
315 self.event
316 .p4(name)
317 .map(PyVec4)
318 .ok_or_else(|| PyKeyError::new_err(format!("Unknown particle name '{name}'")))
319 }
320}
321
322impl PyEvent {
323 fn ensure_metadata(&self) -> PyResult<&DatasetMetadata> {
324 if !self.has_metadata {
325 Err(PyValueError::new_err(
326 "Event has no associated metadata for name-based operations",
327 ))
328 } else {
329 Ok(self.event.metadata())
330 }
331 }
332
333 fn resolve_p4_indices(&self, names: &[String]) -> PyResult<Vec<usize>> {
334 let metadata = self.ensure_metadata()?;
335 let mut resolved = Vec::new();
336 for name in names {
337 let selection = metadata
338 .p4_selection(name)
339 .ok_or_else(|| PyKeyError::new_err(format!("Unknown particle name '{name}'")))?;
340 resolved.extend_from_slice(selection.indices());
341 }
342 Ok(resolved)
343 }
344
345 pub(crate) fn metadata_opt(&self) -> Option<&DatasetMetadata> {
346 self.has_metadata.then(|| self.event.metadata())
347 }
348}
349
/// Python-facing handle to a shared core [`Dataset`], exposed to Python as
/// `laddu.Dataset`. Cloning only clones the `Arc`, not the events.
#[pyclass(name = "Dataset", module = "laddu", subclass, skip_from_py_object)]
#[derive(Clone)]
pub struct PyDataset(pub Arc<Dataset>);
393
394#[pyclass(
395 name = "ParquetChunkIter",
396 module = "laddu",
397 unsendable,
398 skip_from_py_object
399)]
400pub struct PyParquetChunkIter {
401 chunks: Box<dyn Iterator<Item = laddu_core::LadduResult<Arc<Dataset>>> + Send>,
402}
403
404#[pymethods]
405impl PyParquetChunkIter {
406 fn __iter__(slf: PyRef<'_, Self>) -> Py<PyParquetChunkIter> {
407 slf.into()
408 }
409
410 fn __next__(&mut self) -> PyResult<Option<PyDataset>> {
411 match self.chunks.next() {
412 Some(Ok(dataset)) => Ok(Some(PyDataset(dataset))),
413 Some(Err(err)) => Err(PyErr::from(err)),
414 None => Ok(None),
415 }
416 }
417}
418
419#[pyclass(
420 name = "DatasetIter",
421 module = "laddu",
422 unsendable,
423 skip_from_py_object
424)]
425struct PyDatasetIter {
426 kind: PyDatasetIterKind,
427}
428
429enum PyDatasetIterKind {
430 Local { dataset: Arc<Dataset>, index: usize },
431 Global(DatasetArcIter),
432}
433
434#[pymethods]
435impl PyDatasetIter {
436 fn __iter__(slf: PyRef<'_, Self>) -> Py<PyDatasetIter> {
437 slf.into()
438 }
439
440 fn __next__(&mut self) -> Option<PyEvent> {
441 let event = match &mut self.kind {
442 PyDatasetIterKind::Local { dataset, index } => {
443 let event = dataset.events_local().get(*index)?.clone();
444 *index += 1;
445 event
446 }
447 PyDatasetIterKind::Global(iterator) => iterator.next()?,
448 };
449 Some(PyEvent {
450 event,
451 has_metadata: true,
452 })
453 }
454}
455
456#[pymethods]
457impl PyDataset {
458 #[new]
459 #[pyo3(signature = (events, *, p4_names=None, aux_names=None, aliases=None))]
460 fn new(
461 events: Vec<PyEvent>,
462 p4_names: Option<Vec<String>>,
463 aux_names: Option<Vec<String>>,
464 aliases: Option<Bound<PyDict>>,
465 ) -> PyResult<Self> {
466 let inferred_metadata = events
467 .iter()
468 .find_map(|event| event.has_metadata.then(|| event.event.metadata_arc()));
469
470 let aliases = parse_aliases(aliases)?;
471 let use_explicit_metadata =
472 p4_names.is_some() || aux_names.is_some() || !aliases.is_empty();
473
474 let metadata =
475 if use_explicit_metadata {
476 let resolved_p4_names = match (p4_names, inferred_metadata.as_ref()) {
477 (Some(names), _) => names,
478 (None, Some(metadata)) => metadata.p4_names().to_vec(),
479 (None, None) => Vec::new(),
480 };
481 let resolved_aux_names = match (aux_names, inferred_metadata.as_ref()) {
482 (Some(names), _) => names,
483 (None, Some(metadata)) => metadata.aux_names().to_vec(),
484 (None, None) => Vec::new(),
485 };
486
487 if !aliases.is_empty() && resolved_p4_names.is_empty() {
488 return Err(PyValueError::new_err(
489 "`aliases` requires `p4_names` or events with metadata for resolution",
490 ));
491 }
492
493 let mut metadata = DatasetMetadata::new(resolved_p4_names, resolved_aux_names)
494 .map_err(PyErr::from)?;
495 if !aliases.is_empty() {
496 metadata
497 .add_p4_aliases(aliases.into_iter().map(|(alias_name, selection)| {
498 (alias_name, selection.into_selection())
499 }))
500 .map_err(PyErr::from)?;
501 }
502 Some(Arc::new(metadata))
503 } else {
504 inferred_metadata
505 };
506
507 let events: Vec<Arc<EventData>> = events
508 .into_iter()
509 .map(|event| event.event.data_arc())
510 .collect();
511 let dataset = if let Some(metadata) = metadata {
512 Dataset::new_with_metadata(events, metadata)
513 } else {
514 Dataset::new(events)
515 };
516 Ok(Self(Arc::new(dataset)))
517 }
518
519 fn __len__(&self) -> usize {
520 self.0.n_events()
521 }
522 fn __iter__(&self) -> PyDatasetIter {
530 self.iter_global()
531 }
532 #[getter]
534 fn n_events_local(&self) -> usize {
535 self.0.n_events_local()
536 }
537 fn iter_local(&self) -> PyDatasetIter {
545 PyDatasetIter {
546 kind: PyDatasetIterKind::Local {
547 dataset: self.0.clone(),
548 index: 0,
549 },
550 }
551 }
552 fn iter_global(&self) -> PyDatasetIter {
560 PyDatasetIter {
561 kind: PyDatasetIterKind::Global(self.0.shared_iter_global()),
562 }
563 }
564 fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyDataset> {
565 if let Ok(other_ds) = other.extract::<PyRef<PyDataset>>() {
566 Ok(PyDataset(Arc::new(self.0.as_ref() + other_ds.0.as_ref())))
567 } else if let Ok(other_int) = other.extract::<usize>() {
568 if other_int == 0 {
569 Ok(self.clone())
570 } else {
571 Err(PyTypeError::new_err(
572 "Addition with an integer for this type is only defined for 0",
573 ))
574 }
575 } else {
576 Err(PyTypeError::new_err("Unsupported operand type for +"))
577 }
578 }
579 fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyDataset> {
580 if let Ok(other_ds) = other.extract::<PyRef<PyDataset>>() {
581 Ok(PyDataset(Arc::new(other_ds.0.as_ref() + self.0.as_ref())))
582 } else if let Ok(other_int) = other.extract::<usize>() {
583 if other_int == 0 {
584 Ok(self.clone())
585 } else {
586 Err(PyTypeError::new_err(
587 "Addition with an integer for this type is only defined for 0",
588 ))
589 }
590 } else {
591 Err(PyTypeError::new_err("Unsupported operand type for +"))
592 }
593 }
594 #[getter]
607 fn n_events(&self) -> usize {
608 self.0.n_events()
609 }
610 #[getter]
612 fn n_events_global(&self) -> usize {
613 self.0.n_events_global()
614 }
615 #[getter]
617 fn p4_names(&self) -> Vec<String> {
618 self.0.p4_names().to_vec()
619 }
620 #[getter]
622 fn aux_names(&self) -> Vec<String> {
623 self.0.aux_names().to_vec()
624 }
625
626 #[getter]
638 fn n_events_weighted(&self) -> f64 {
639 self.0.n_events_weighted()
640 }
641 #[getter]
643 fn n_events_weighted_global(&self) -> f64 {
644 self.0.n_events_weighted_global()
645 }
646 #[getter]
659 fn n_events_weighted_local(&self) -> f64 {
660 self.0.n_events_weighted_local()
661 }
662 #[getter]
674 fn weights<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<f64>> {
675 PyArray1::from_slice(py, &self.0.weights())
676 }
677 #[getter]
679 fn weights_global<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<f64>> {
680 PyArray1::from_slice(py, &self.0.weights_global())
681 }
682 #[getter]
694 fn weights_local<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<f64>> {
695 PyArray1::from_slice(py, &self.0.weights_local())
696 }
697 #[getter]
712 fn events(&self) -> Vec<PyEvent> {
713 self.0
714 .shared_iter()
715 .map(|rust_event| PyEvent {
716 event: rust_event,
717 has_metadata: true,
718 })
719 .collect()
720 }
721 #[getter]
723 fn events_global(&self) -> Vec<PyEvent> {
724 self.events()
725 }
726 #[getter]
733 fn events_local(&self) -> Vec<PyEvent> {
734 self.0
735 .events_local()
736 .iter()
737 .map(|rust_event| PyEvent {
738 event: rust_event.clone(),
739 has_metadata: true,
740 })
741 .collect()
742 }
743 fn p4_by_name(&self, index: usize, name: &str) -> PyResult<PyVec4> {
745 self.0
746 .p4_by_name(index, name)
747 .map(PyVec4)
748 .ok_or_else(|| PyKeyError::new_err(format!("Unknown particle name '{name}'")))
749 }
750 fn aux_by_name(&self, index: usize, name: &str) -> PyResult<f64> {
752 self.0
753 .aux_by_name(index, name)
754 .ok_or_else(|| PyKeyError::new_err(format!("Unknown auxiliary name '{name}'")))
755 }
756 fn event_global(&self, index: usize) -> PyResult<PyEvent> {
762 let event = self
763 .0
764 .get_event_global(index)
765 .ok_or_else(|| PyIndexError::new_err("index out of range"))?;
766 Ok(PyEvent {
767 event,
768 has_metadata: true,
769 })
770 }
771 fn __getitem__<'py>(
772 &self,
773 py: Python<'py>,
774 index: Bound<'py, PyAny>,
775 ) -> PyResult<Bound<'py, PyAny>> {
776 if let Ok(value) = self.evaluate(py, index.clone()) {
777 value.into_bound_py_any(py)
778 } else if let Ok(index) = index.extract::<usize>() {
779 let event = self
780 .0
781 .get_event(index)
782 .ok_or_else(|| PyIndexError::new_err("index out of range"))?;
783 PyEvent {
784 event,
785 has_metadata: true,
786 }
787 .into_bound_py_any(py)
788 } else {
789 Err(PyTypeError::new_err(
790 "Unsupported index type (int or Variable)",
791 ))
792 }
793 }
794 #[pyo3(signature = (variable, bins, range))]
832 fn bin_by(
833 &self,
834 variable: Bound<'_, PyAny>,
835 bins: usize,
836 range: (f64, f64),
837 ) -> PyResult<PyBinnedDataset> {
838 let py_variable = variable.extract::<PyVariable>()?;
839 let bound_variable = py_variable.bound(self.0.metadata())?;
840 Ok(PyBinnedDataset(self.0.bin_by(
841 bound_variable,
842 bins,
843 range,
844 )?))
845 }
846 pub fn filter(&self, expression: &PyVariableExpression) -> PyResult<PyDataset> {
864 Ok(PyDataset(
865 self.0.filter(&expression.0).map_err(PyErr::from)?,
866 ))
867 }
868 fn bootstrap(&self, seed: usize) -> PyDataset {
889 PyDataset(self.0.bootstrap(seed))
890 }
891 pub fn boost_to_rest_frame_of(&self, names: Vec<String>) -> PyDataset {
909 PyDataset(self.0.boost_to_rest_frame_of(&names))
910 }
911 fn evaluate<'py>(
929 &self,
930 py: Python<'py>,
931 variable: Bound<'py, PyAny>,
932 ) -> PyResult<Bound<'py, PyArray1<f64>>> {
933 let variable = variable.extract::<PyVariable>()?;
934 let bound_variable = variable.bound(self.0.metadata())?;
935 let values = self.0.evaluate(&bound_variable).map_err(PyErr::from)?;
936 Ok(PyArray1::from_vec(py, values))
937 }
938}
939
940#[pyfunction]
954#[pyo3(signature = (path, *, p4s=None, aux=None, aliases=None))]
955pub fn read_parquet(
956 path: Bound<PyAny>,
957 p4s: Option<Vec<String>>,
958 aux: Option<Vec<String>>,
959 aliases: Option<Bound<PyDict>>,
960) -> PyResult<PyDataset> {
961 let path_str = parse_dataset_path(path)?;
962 let mut read_options = DatasetReadOptions::default();
963 if let Some(p4s) = p4s {
964 read_options = read_options.p4_names(p4s);
965 }
966 if let Some(aux) = aux {
967 read_options = read_options.aux_names(aux);
968 }
969 for (alias_name, selection) in parse_aliases(aliases)?.into_iter() {
970 read_options = read_options.alias(alias_name, selection);
971 }
972 let dataset = core_read_parquet(&path_str, &read_options)?;
973 Ok(PyDataset(dataset))
974}
975
976#[pyfunction]
978#[pyo3(signature = (path, *, p4s=None, aux=None, aliases=None, chunk_size=None))]
979pub fn read_parquet_chunked(
980 path: Bound<PyAny>,
981 p4s: Option<Vec<String>>,
982 aux: Option<Vec<String>>,
983 aliases: Option<Bound<PyDict>>,
984 chunk_size: Option<usize>,
985) -> PyResult<PyParquetChunkIter> {
986 let path_str = parse_dataset_path(path)?;
987 let mut read_options = DatasetReadOptions::default();
988 if let Some(p4s) = p4s {
989 read_options = read_options.p4_names(p4s);
990 }
991 if let Some(aux) = aux {
992 read_options = read_options.aux_names(aux);
993 }
994 if let Some(chunk_size) = chunk_size {
995 read_options = read_options.chunk_size(chunk_size);
996 }
997 for (alias_name, selection) in parse_aliases(aliases)?.into_iter() {
998 read_options = read_options.alias(alias_name, selection);
999 }
1000
1001 let chunks = core_read_parquet_chunks_with_options(&path_str, &read_options)?;
1002 Ok(PyParquetChunkIter {
1003 chunks: Box::new(chunks),
1004 })
1005}
1006
1007#[pyfunction]
1021#[pyo3(signature = (path, *, tree=None, p4s=None, aux=None, aliases=None))]
1022pub fn read_root(
1023 path: Bound<PyAny>,
1024 tree: Option<String>,
1025 p4s: Option<Vec<String>>,
1026 aux: Option<Vec<String>>,
1027 aliases: Option<Bound<PyDict>>,
1028) -> PyResult<PyDataset> {
1029 let path_str = parse_dataset_path(path)?;
1030 let mut read_options = DatasetReadOptions::default();
1031 if let Some(p4s) = p4s {
1032 read_options = read_options.p4_names(p4s);
1033 }
1034 if let Some(aux) = aux {
1035 read_options = read_options.aux_names(aux);
1036 }
1037 if let Some(tree) = tree {
1038 read_options = read_options.tree(tree);
1039 }
1040 for (alias_name, selection) in parse_aliases(aliases)?.into_iter() {
1041 read_options = read_options.alias(alias_name, selection);
1042 }
1043 let dataset = core_read_root(&path_str, &read_options)?;
1044 Ok(PyDataset(dataset))
1045}
1046
1047#[pyfunction]
1049#[pyo3(signature = (dataset, path, *, chunk_size=None, precision="f64"))]
1050pub fn write_parquet(
1051 dataset: &PyDataset,
1052 path: Bound<PyAny>,
1053 chunk_size: Option<usize>,
1054 precision: &str,
1055) -> PyResult<()> {
1056 let path_str = parse_dataset_path(path)?;
1057 let mut write_options = DatasetWriteOptions::default();
1058 if let Some(size) = chunk_size {
1059 write_options.batch_size = size.max(1);
1060 }
1061 write_options.precision = parse_precision_arg(Some(precision))?;
1062 core_write_parquet(dataset.0.as_ref(), &path_str, &write_options).map_err(PyErr::from)
1063}
1064
1065#[pyfunction]
1067#[pyo3(signature = (dataset, path, *, tree=None, chunk_size=None, precision="f64"))]
1068pub fn write_root(
1069 dataset: &PyDataset,
1070 path: Bound<PyAny>,
1071 tree: Option<String>,
1072 chunk_size: Option<usize>,
1073 precision: &str,
1074) -> PyResult<()> {
1075 let path_str = parse_dataset_path(path)?;
1076 let mut write_options = DatasetWriteOptions::default();
1077 if let Some(name) = tree {
1078 write_options.tree = Some(name);
1079 }
1080 if let Some(size) = chunk_size {
1081 write_options.batch_size = size.max(1);
1082 }
1083 write_options.precision = parse_precision_arg(Some(precision))?;
1084 core_write_root(dataset.0.as_ref(), &path_str, &write_options).map_err(PyErr::from)
1085}
1086
1087#[pyfunction]
1115#[pyo3(signature = (columns, *, p4s=None, aux=None, aliases=None))]
1116pub fn from_columns(
1117 columns: Bound<'_, PyDict>,
1118 p4s: Option<Vec<String>>,
1119 aux: Option<Vec<String>>,
1120 aliases: Option<Bound<'_, PyDict>>,
1121) -> PyResult<PyDataset> {
1122 let column_names = columns
1123 .iter()
1124 .map(|(key, _)| key.extract::<String>())
1125 .collect::<PyResult<Vec<_>>>()?;
1126
1127 let (detected_p4_names, detected_aux_names) =
1128 infer_p4_and_aux_names_from_columns(&column_names);
1129 let p4_names = p4s.unwrap_or(detected_p4_names);
1130 if p4_names.is_empty() {
1131 let mut partial_components: std::collections::BTreeMap<
1132 String,
1133 std::collections::BTreeSet<&str>,
1134 > = std::collections::BTreeMap::new();
1135 for column_name in &column_names {
1136 let lowered = column_name.to_ascii_lowercase();
1137 for suffix in P4_COMPONENT_SUFFIXES {
1138 if lowered.ends_with(suffix) && column_name.len() > suffix.len() {
1139 let prefix = column_name[..column_name.len() - suffix.len()].to_string();
1140 partial_components.entry(prefix).or_default().insert(suffix);
1141 }
1142 }
1143 }
1144 if let Some((prefix, present)) = partial_components.iter().next() {
1145 if present.len() < P4_COMPONENT_SUFFIXES.len() {
1146 let missing = P4_COMPONENT_SUFFIXES
1147 .iter()
1148 .filter(|suffix| !present.contains(**suffix))
1149 .map(|suffix| format!("{prefix}{suffix}"))
1150 .collect::<Vec<_>>()
1151 .join(", ");
1152 return Err(PyKeyError::new_err(format!(
1153 "Missing components [{missing}] for four-momentum '{prefix}'"
1154 )));
1155 }
1156 }
1157 return Err(PyValueError::new_err(
1158 "No four-momentum columns found (expected *_px, *_py, *_pz, *_e)",
1159 ));
1160 }
1161
1162 let aux_names = aux.unwrap_or(detected_aux_names);
1163 let p4_component_columns =
1164 resolve_p4_component_columns(&column_names, &p4_names).map_err(PyErr::from)?;
1165 let resolved_aux_columns =
1166 resolve_columns_case_insensitive(&column_names, &aux_names).map_err(PyErr::from)?;
1167
1168 let n_events = {
1169 let first_name = p4_component_columns
1170 .first()
1171 .map(|components| components[0].clone())
1172 .ok_or_else(|| PyKeyError::new_err("Missing required p4 column"))?;
1173 let values = extract_numeric_column(
1174 columns
1175 .get_item(first_name.as_str())?
1176 .ok_or_else(|| PyKeyError::new_err("Missing required p4 column"))?,
1177 &first_name,
1178 )?;
1179 values.len()
1180 };
1181
1182 let mut p4_columns: Vec<[Vec<f64>; 4]> = Vec::with_capacity(p4_names.len());
1183 for component_names in &p4_component_columns {
1184 let px = extract_numeric_column(
1185 columns
1186 .get_item(component_names[0].as_str())?
1187 .ok_or_else(|| PyKeyError::new_err(format!("Missing {}", component_names[0])))?,
1188 component_names[0].as_str(),
1189 )?;
1190 let py = extract_numeric_column(
1191 columns
1192 .get_item(component_names[1].as_str())?
1193 .ok_or_else(|| PyKeyError::new_err(format!("Missing {}", component_names[1])))?,
1194 component_names[1].as_str(),
1195 )?;
1196 let pz = extract_numeric_column(
1197 columns
1198 .get_item(component_names[2].as_str())?
1199 .ok_or_else(|| PyKeyError::new_err(format!("Missing {}", component_names[2])))?,
1200 component_names[2].as_str(),
1201 )?;
1202 let e = extract_numeric_column(
1203 columns
1204 .get_item(component_names[3].as_str())?
1205 .ok_or_else(|| PyKeyError::new_err(format!("Missing {}", component_names[3])))?,
1206 component_names[3].as_str(),
1207 )?;
1208 if px.len() != n_events
1209 || py.len() != n_events
1210 || pz.len() != n_events
1211 || e.len() != n_events
1212 {
1213 return Err(PyValueError::new_err(
1214 "All p4 components must have the same length",
1215 ));
1216 }
1217 p4_columns.push([px, py, pz, e]);
1218 }
1219
1220 let mut aux_columns: Vec<Vec<f64>> = Vec::with_capacity(resolved_aux_columns.len());
1221 for (aux_name, aux_column_name) in aux_names.iter().zip(&resolved_aux_columns) {
1222 let values = extract_numeric_column(
1223 columns.get_item(aux_column_name.as_str())?.ok_or_else(|| {
1224 PyKeyError::new_err(format!("Missing auxiliary column '{aux_name}'"))
1225 })?,
1226 aux_name,
1227 )?;
1228 if values.len() != n_events {
1229 return Err(PyValueError::new_err(format!(
1230 "Auxiliary column '{aux_name}' length does not match p4 columns"
1231 )));
1232 }
1233 aux_columns.push(values);
1234 }
1235
1236 let weights = if let Some(weight_column_name) = resolve_optional_weight_column(&column_names) {
1237 let weight_values = columns
1238 .get_item(weight_column_name.as_str())?
1239 .ok_or_else(|| PyKeyError::new_err("Missing weight column"))?;
1240 let values = extract_numeric_column(weight_values, "weight")?;
1241 if values.len() != n_events {
1242 return Err(PyValueError::new_err(
1243 "Column 'weight' length does not match p4 columns",
1244 ));
1245 }
1246 values
1247 } else {
1248 vec![1.0; n_events]
1249 };
1250
1251 let parsed_aliases = parse_aliases(aliases)?;
1252 let mut metadata =
1253 DatasetMetadata::new(p4_names.clone(), aux_names.clone()).map_err(PyErr::from)?;
1254 if !parsed_aliases.is_empty() {
1255 metadata
1256 .add_p4_aliases(
1257 parsed_aliases
1258 .into_iter()
1259 .map(|(alias_name, selection)| (alias_name, selection.into_selection())),
1260 )
1261 .map_err(PyErr::from)?;
1262 }
1263
1264 let mut events = Vec::with_capacity(n_events);
1265 for event_idx in 0..n_events {
1266 let p4s = p4_columns
1267 .iter()
1268 .map(|components| {
1269 laddu_core::utils::vectors::Vec4::new(
1270 components[0][event_idx],
1271 components[1][event_idx],
1272 components[2][event_idx],
1273 components[3][event_idx],
1274 )
1275 })
1276 .collect::<Vec<_>>();
1277 let aux = aux_columns
1278 .iter()
1279 .map(|values| values[event_idx])
1280 .collect::<Vec<_>>();
1281 events.push(Arc::new(EventData {
1282 p4s,
1283 aux,
1284 weight: weights[event_idx],
1285 }));
1286 }
1287
1288 Ok(PyDataset(Arc::new(Dataset::new_with_metadata(
1289 events,
1290 Arc::new(metadata),
1291 ))))
1292}
1293
1294#[pyclass(name = "BinnedDataset", module = "laddu", skip_from_py_object)]
1303pub struct PyBinnedDataset(BinnedDataset);
1304
1305#[pymethods]
1306impl PyBinnedDataset {
1307 fn __len__(&self) -> usize {
1308 self.0.n_bins()
1309 }
1310 #[getter]
1313 fn n_bins(&self) -> usize {
1314 self.0.n_bins()
1315 }
1316 #[getter]
1319 fn range(&self) -> (f64, f64) {
1320 self.0.range()
1321 }
1322 #[getter]
1325 fn edges<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<f64>> {
1326 PyArray1::from_slice(py, &self.0.edges())
1327 }
1328 fn __getitem__(&self, index: usize) -> PyResult<PyDataset> {
1329 self.0
1330 .get(index)
1331 .ok_or(PyIndexError::new_err("index out of range"))
1332 .map(|rust_dataset| PyDataset(rust_dataset.clone()))
1333 }
1334}