1use crate::utils::variables::{PyVariable, PyVariableExpression};
2use laddu_core::{
3 data::{
4 read_parquet as core_read_parquet, read_root as core_read_root,
5 write_parquet as core_write_parquet, write_root as core_write_root, BinnedDataset, Dataset,
6 DatasetMetadata, DatasetWriteOptions, Event, EventData, FloatPrecision,
7 },
8 utils::variables::IntoP4Selection,
9 DatasetReadOptions,
10};
11use numpy::PyArray1;
12use pyo3::{
13 exceptions::{PyIndexError, PyKeyError, PyTypeError, PyValueError},
14 prelude::*,
15 types::PyDict,
16 IntoPyObjectExt,
17};
18use std::{path::PathBuf, sync::Arc};
19
20use crate::utils::vectors::PyVec4;
21
22fn parse_aliases(aliases: Option<Bound<'_, PyDict>>) -> PyResult<Vec<(String, Vec<String>)>> {
23 let Some(aliases) = aliases else {
24 return Ok(Vec::new());
25 };
26
27 let mut parsed = Vec::new();
28 for (key, value) in aliases.iter() {
29 let alias_name = key.extract::<String>()?;
30 let selection = if let Ok(single) = value.extract::<String>() {
31 vec![single]
32 } else {
33 let seq = value.extract::<Vec<String>>().map_err(|_| {
34 PyTypeError::new_err("Alias values must be a string or a sequence of strings")
35 })?;
36 if seq.is_empty() {
37 return Err(PyValueError::new_err(format!(
38 "Alias '{alias_name}' must reference at least one particle",
39 )));
40 }
41 seq
42 };
43 parsed.push((alias_name, selection));
44 }
45
46 Ok(parsed)
47}
48
49fn parse_dataset_path(path: Bound<'_, PyAny>) -> PyResult<String> {
50 if let Ok(s) = path.extract::<String>() {
51 Ok(s)
52 } else if let Ok(pathbuf) = path.extract::<PathBuf>() {
53 Ok(pathbuf.to_string_lossy().into_owned())
54 } else {
55 Err(PyTypeError::new_err("Expected str or Path"))
56 }
57}
58
59fn parse_precision_arg(value: Option<&str>) -> PyResult<FloatPrecision> {
60 match value.map(|v| v.to_ascii_lowercase()) {
61 None => Ok(FloatPrecision::F64),
62 Some(name) if name == "f64" || name == "float64" || name == "double" => {
63 Ok(FloatPrecision::F64)
64 }
65 Some(name) if name == "f32" || name == "float32" || name == "float" => {
66 Ok(FloatPrecision::F32)
67 }
68 Some(other) => Err(PyValueError::new_err(format!(
69 "Unsupported precision '{other}' (expected 'f64' or 'f32')"
70 ))),
71 }
72}
73
#[pyclass(name = "Event", module = "laddu")]
/// Python wrapper around a single [`Event`].
///
/// Name-based accessors (`p4s`, `aux`, `p4`, `get_p4_sum`, `evaluate`, ...) require
/// metadata. `has_metadata` records whether this event was created with particle or
/// auxiliary names, or was taken from a `Dataset`.
#[derive(Clone)]
pub struct PyEvent {
    /// The wrapped event: shared event data plus shared metadata.
    pub event: Event,
    /// `false` when the event was built without any names. In that case the attached
    /// metadata is `DatasetMetadata::empty()` and name lookups are refused.
    has_metadata: bool,
}
115
116#[pymethods]
117impl PyEvent {
118 #[new]
119 #[pyo3(signature = (p4s, aux, weight, *, p4_names=None, aux_names=None, aliases=None))]
120 fn new(
121 p4s: Vec<PyVec4>,
122 aux: Vec<f64>,
123 weight: f64,
124 p4_names: Option<Vec<String>>,
125 aux_names: Option<Vec<String>>,
126 aliases: Option<Bound<PyDict>>,
127 ) -> PyResult<Self> {
128 let event = EventData {
129 p4s: p4s.into_iter().map(|arr| arr.0).collect(),
130 aux,
131 weight,
132 };
133 let aliases = parse_aliases(aliases)?;
134
135 let missing_p4_names = p4_names
136 .as_ref()
137 .map(|names| names.is_empty())
138 .unwrap_or(true);
139
140 if !aliases.is_empty() && missing_p4_names {
141 return Err(PyValueError::new_err(
142 "`aliases` requires `p4_names` so selections can be resolved",
143 ));
144 }
145
146 let metadata_provided = p4_names.is_some() || aux_names.is_some() || !aliases.is_empty();
147 let metadata = if metadata_provided {
148 let p4_names = p4_names.unwrap_or_default();
149 let aux_names = aux_names.unwrap_or_default();
150 let mut metadata = DatasetMetadata::new(p4_names, aux_names).map_err(PyErr::from)?;
151 if !aliases.is_empty() {
152 metadata
153 .add_p4_aliases(
154 aliases.into_iter().map(|(alias_name, selection)| {
155 (alias_name, selection.into_selection())
156 }),
157 )
158 .map_err(PyErr::from)?;
159 }
160 Arc::new(metadata)
161 } else {
162 Arc::new(DatasetMetadata::empty())
163 };
164 let event = Event::new(Arc::new(event), metadata);
165 Ok(Self {
166 event,
167 has_metadata: metadata_provided,
168 })
169 }
170 fn __str__(&self) -> String {
171 self.event.data().to_string()
172 }
173 #[getter]
176 fn p4s<'py>(&self, py: Python<'py>) -> PyResult<Py<PyDict>> {
177 self.ensure_metadata()?;
178 let mapping = PyDict::new(py);
179 for (name, vec4) in self.event.p4s() {
180 mapping.set_item(name, PyVec4(vec4))?;
181 }
182 Ok(mapping.into())
183 }
184 #[getter]
187 #[pyo3(name = "aux")]
188 fn aux_mapping<'py>(&self, py: Python<'py>) -> PyResult<Py<PyDict>> {
189 self.ensure_metadata()?;
190 let mapping = PyDict::new(py);
191 for (name, value) in self.event.aux() {
192 mapping.set_item(name, value)?;
193 }
194 Ok(mapping.into())
195 }
196 #[getter]
199 fn get_weight(&self) -> f64 {
200 self.event.weight()
201 }
202 fn get_p4_sum(&self, names: Vec<String>) -> PyResult<PyVec4> {
215 let indices = self.resolve_p4_indices(&names)?;
216 Ok(PyVec4(self.event.data().get_p4_sum(indices)))
217 }
218 pub fn boost_to_rest_frame_of(&self, names: Vec<String>) -> PyResult<Self> {
232 let indices = self.resolve_p4_indices(&names)?;
233 let boosted = self.event.data().boost_to_rest_frame_of(indices);
234 Ok(Self {
235 event: Event::new(Arc::new(boosted), self.event.metadata_arc()),
236 has_metadata: self.has_metadata,
237 })
238 }
239 fn evaluate(&self, variable: Bound<'_, PyAny>) -> PyResult<f64> {
256 let mut variable = variable.extract::<PyVariable>()?;
257 if !self.has_metadata {
258 return Err(PyValueError::new_err(
259 "Cannot evaluate variable on an Event without associated metadata. Construct the Event with `p4_names`/`aux_names` or evaluate through a Dataset.",
260 ));
261 }
262 variable.bind_in_place(self.event.metadata())?;
263 let event_arc = self.event.data_arc();
264 variable.evaluate_event(&event_arc)
265 }
266
267 fn p4(&self, name: &str) -> PyResult<Option<PyVec4>> {
269 self.ensure_metadata()?;
270 Ok(self.event.p4(name).map(PyVec4))
271 }
272}
273
274impl PyEvent {
275 fn ensure_metadata(&self) -> PyResult<&DatasetMetadata> {
276 if !self.has_metadata {
277 Err(PyValueError::new_err(
278 "Event has no associated metadata for name-based operations",
279 ))
280 } else {
281 Ok(self.event.metadata())
282 }
283 }
284
285 fn resolve_p4_indices(&self, names: &[String]) -> PyResult<Vec<usize>> {
286 let metadata = self.ensure_metadata()?;
287 let mut resolved = Vec::new();
288 for name in names {
289 let selection = metadata
290 .p4_selection(name)
291 .ok_or_else(|| PyKeyError::new_err(format!("Unknown particle name '{name}'")))?;
292 resolved.extend_from_slice(selection.indices());
293 }
294 Ok(resolved)
295 }
296
297 pub(crate) fn metadata_opt(&self) -> Option<&DatasetMetadata> {
298 self.has_metadata.then(|| self.event.metadata())
299 }
300}
301
#[pyclass(name = "Dataset", module = "laddu", subclass)]
/// Python-facing handle to a [`Dataset`].
///
/// The dataset is held behind an [`Arc`], so cloning this wrapper only bumps the
/// reference count; clones share the same underlying events.
#[derive(Clone)]
pub struct PyDataset(pub Arc<Dataset>);
328
#[pyclass(name = "DatasetIter", module = "laddu")]
/// Python iterator over the events of a dataset, returned by `Dataset.__iter__`.
struct PyDatasetIter {
    /// Shared handle to the dataset being iterated.
    dataset: Arc<Dataset>,
    /// Position of the next event to yield.
    index: usize,
    /// Number of events, captured when the iterator was created.
    total: usize,
}
335
336#[pymethods]
337impl PyDatasetIter {
338 fn __iter__(slf: PyRef<'_, Self>) -> Py<PyDatasetIter> {
339 slf.into()
340 }
341
342 fn __next__(&mut self) -> Option<PyEvent> {
343 if self.index >= self.total {
344 return None;
345 }
346 let event = self.dataset[self.index].clone();
347 self.index += 1;
348 Some(PyEvent {
349 event,
350 has_metadata: true,
351 })
352 }
353}
354
355#[pymethods]
356impl PyDataset {
357 #[new]
358 #[pyo3(signature = (events, *, p4_names=None, aux_names=None, aliases=None))]
359 fn new(
360 events: Vec<PyEvent>,
361 p4_names: Option<Vec<String>>,
362 aux_names: Option<Vec<String>>,
363 aliases: Option<Bound<PyDict>>,
364 ) -> PyResult<Self> {
365 let inferred_metadata = events
366 .iter()
367 .find_map(|event| event.has_metadata.then(|| event.event.metadata_arc()));
368
369 let aliases = parse_aliases(aliases)?;
370 let use_explicit_metadata =
371 p4_names.is_some() || aux_names.is_some() || !aliases.is_empty();
372
373 let metadata =
374 if use_explicit_metadata {
375 let resolved_p4_names = match (p4_names, inferred_metadata.as_ref()) {
376 (Some(names), _) => names,
377 (None, Some(metadata)) => metadata.p4_names().to_vec(),
378 (None, None) => Vec::new(),
379 };
380 let resolved_aux_names = match (aux_names, inferred_metadata.as_ref()) {
381 (Some(names), _) => names,
382 (None, Some(metadata)) => metadata.aux_names().to_vec(),
383 (None, None) => Vec::new(),
384 };
385
386 if !aliases.is_empty() && resolved_p4_names.is_empty() {
387 return Err(PyValueError::new_err(
388 "`aliases` requires `p4_names` or events with metadata for resolution",
389 ));
390 }
391
392 let mut metadata = DatasetMetadata::new(resolved_p4_names, resolved_aux_names)
393 .map_err(PyErr::from)?;
394 if !aliases.is_empty() {
395 metadata
396 .add_p4_aliases(aliases.into_iter().map(|(alias_name, selection)| {
397 (alias_name, selection.into_selection())
398 }))
399 .map_err(PyErr::from)?;
400 }
401 Some(Arc::new(metadata))
402 } else {
403 inferred_metadata
404 };
405
406 let events: Vec<Arc<EventData>> = events
407 .into_iter()
408 .map(|event| event.event.data_arc())
409 .collect();
410 let dataset = if let Some(metadata) = metadata {
411 Dataset::new_with_metadata(events, metadata)
412 } else {
413 Dataset::new(events)
414 };
415 Ok(Self(Arc::new(dataset)))
416 }
417
418 fn __len__(&self) -> usize {
419 self.0.n_events()
420 }
421 fn __iter__(&self) -> PyDatasetIter {
422 PyDatasetIter {
423 dataset: self.0.clone(),
424 index: 0,
425 total: self.0.n_events(),
426 }
427 }
428 fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyDataset> {
429 if let Ok(other_ds) = other.extract::<PyRef<PyDataset>>() {
430 Ok(PyDataset(Arc::new(self.0.as_ref() + other_ds.0.as_ref())))
431 } else if let Ok(other_int) = other.extract::<usize>() {
432 if other_int == 0 {
433 Ok(self.clone())
434 } else {
435 Err(PyTypeError::new_err(
436 "Addition with an integer for this type is only defined for 0",
437 ))
438 }
439 } else {
440 Err(PyTypeError::new_err("Unsupported operand type for +"))
441 }
442 }
443 fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyDataset> {
444 if let Ok(other_ds) = other.extract::<PyRef<PyDataset>>() {
445 Ok(PyDataset(Arc::new(other_ds.0.as_ref() + self.0.as_ref())))
446 } else if let Ok(other_int) = other.extract::<usize>() {
447 if other_int == 0 {
448 Ok(self.clone())
449 } else {
450 Err(PyTypeError::new_err(
451 "Addition with an integer for this type is only defined for 0",
452 ))
453 }
454 } else {
455 Err(PyTypeError::new_err("Unsupported operand type for +"))
456 }
457 }
458 #[getter]
466 fn n_events(&self) -> usize {
467 self.0.n_events()
468 }
469 #[getter]
471 fn p4_names(&self) -> Vec<String> {
472 self.0.p4_names().to_vec()
473 }
474 #[getter]
476 fn aux_names(&self) -> Vec<String> {
477 self.0.aux_names().to_vec()
478 }
479
480 #[getter]
488 fn n_events_weighted(&self) -> f64 {
489 self.0.n_events_weighted()
490 }
491 #[getter]
499 fn weights<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<f64>> {
500 PyArray1::from_slice(py, &self.0.weights())
501 }
502 #[getter]
516 fn events(&self) -> Vec<PyEvent> {
517 self.0
518 .events
519 .iter()
520 .map(|rust_event| PyEvent {
521 event: rust_event.clone(),
522 has_metadata: true,
523 })
524 .collect()
525 }
526 fn p4_by_name(&self, index: usize, name: &str) -> PyResult<PyVec4> {
528 self.0
529 .p4_by_name(index, name)
530 .map(PyVec4)
531 .ok_or_else(|| PyKeyError::new_err(format!("Unknown particle name '{name}'")))
532 }
533 fn aux_by_name(&self, index: usize, name: &str) -> PyResult<f64> {
535 self.0
536 .aux_by_name(index, name)
537 .ok_or_else(|| PyKeyError::new_err(format!("Unknown auxiliary name '{name}'")))
538 }
539 fn __getitem__<'py>(
540 &self,
541 py: Python<'py>,
542 index: Bound<'py, PyAny>,
543 ) -> PyResult<Bound<'py, PyAny>> {
544 if let Ok(value) = self.evaluate(py, index.clone()) {
545 value.into_bound_py_any(py)
546 } else if let Ok(index) = index.extract::<usize>() {
547 PyEvent {
548 event: self.0[index].clone(),
549 has_metadata: true,
550 }
551 .into_bound_py_any(py)
552 } else {
553 Err(PyTypeError::new_err(
554 "Unsupported index type (int or Variable)",
555 ))
556 }
557 }
558 #[pyo3(signature = (variable, bins, range))]
596 fn bin_by(
597 &self,
598 variable: Bound<'_, PyAny>,
599 bins: usize,
600 range: (f64, f64),
601 ) -> PyResult<PyBinnedDataset> {
602 let py_variable = variable.extract::<PyVariable>()?;
603 let bound_variable = py_variable.bound(self.0.metadata())?;
604 Ok(PyBinnedDataset(self.0.bin_by(
605 bound_variable,
606 bins,
607 range,
608 )?))
609 }
610 pub fn filter(&self, expression: &PyVariableExpression) -> PyResult<PyDataset> {
628 Ok(PyDataset(
629 self.0.filter(&expression.0).map_err(PyErr::from)?,
630 ))
631 }
632 fn bootstrap(&self, seed: usize) -> PyDataset {
653 PyDataset(self.0.bootstrap(seed))
654 }
655 pub fn boost_to_rest_frame_of(&self, names: Vec<String>) -> PyDataset {
673 PyDataset(self.0.boost_to_rest_frame_of(&names))
674 }
675 fn evaluate<'py>(
693 &self,
694 py: Python<'py>,
695 variable: Bound<'py, PyAny>,
696 ) -> PyResult<Bound<'py, PyArray1<f64>>> {
697 let variable = variable.extract::<PyVariable>()?;
698 let bound_variable = variable.bound(self.0.metadata())?;
699 let values = self.0.evaluate(&bound_variable).map_err(PyErr::from)?;
700 Ok(PyArray1::from_vec(py, values))
701 }
702}
703
704#[pyfunction]
706#[pyo3(signature = (path, *, p4s=None, aux=None, aliases=None))]
707pub fn read_parquet(
708 path: Bound<PyAny>,
709 p4s: Option<Vec<String>>,
710 aux: Option<Vec<String>>,
711 aliases: Option<Bound<PyDict>>,
712) -> PyResult<PyDataset> {
713 let path_str = parse_dataset_path(path)?;
714 let mut read_options = DatasetReadOptions::default();
715 if let Some(p4s) = p4s {
716 read_options = read_options.p4_names(p4s);
717 }
718 if let Some(aux) = aux {
719 read_options = read_options.aux_names(aux);
720 }
721 for (alias_name, selection) in parse_aliases(aliases)?.into_iter() {
722 read_options = read_options.alias(alias_name, selection);
723 }
724 let dataset = core_read_parquet(&path_str, &read_options)?;
725 Ok(PyDataset(dataset))
726}
727
728#[pyfunction]
730#[pyo3(signature = (path, *, tree=None, p4s=None, aux=None, aliases=None))]
731pub fn read_root(
732 path: Bound<PyAny>,
733 tree: Option<String>,
734 p4s: Option<Vec<String>>,
735 aux: Option<Vec<String>>,
736 aliases: Option<Bound<PyDict>>,
737) -> PyResult<PyDataset> {
738 let path_str = parse_dataset_path(path)?;
739 let mut read_options = DatasetReadOptions::default();
740 if let Some(p4s) = p4s {
741 read_options = read_options.p4_names(p4s);
742 }
743 if let Some(aux) = aux {
744 read_options = read_options.aux_names(aux);
745 }
746 if let Some(tree) = tree {
747 read_options = read_options.tree(tree);
748 }
749 for (alias_name, selection) in parse_aliases(aliases)?.into_iter() {
750 read_options = read_options.alias(alias_name, selection);
751 }
752 let dataset = core_read_root(&path_str, &read_options)?;
753 Ok(PyDataset(dataset))
754}
755
756#[pyfunction]
758#[pyo3(signature = (dataset, path, *, chunk_size=None, precision="f64"))]
759pub fn write_parquet(
760 dataset: &PyDataset,
761 path: Bound<PyAny>,
762 chunk_size: Option<usize>,
763 precision: &str,
764) -> PyResult<()> {
765 let path_str = parse_dataset_path(path)?;
766 let mut write_options = DatasetWriteOptions::default();
767 if let Some(size) = chunk_size {
768 write_options.batch_size = size.max(1);
769 }
770 write_options.precision = parse_precision_arg(Some(precision))?;
771 core_write_parquet(dataset.0.as_ref(), &path_str, &write_options).map_err(PyErr::from)
772}
773
774#[pyfunction]
776#[pyo3(signature = (dataset, path, *, tree=None, chunk_size=None, precision="f64"))]
777pub fn write_root(
778 dataset: &PyDataset,
779 path: Bound<PyAny>,
780 tree: Option<String>,
781 chunk_size: Option<usize>,
782 precision: &str,
783) -> PyResult<()> {
784 let path_str = parse_dataset_path(path)?;
785 let mut write_options = DatasetWriteOptions::default();
786 if let Some(name) = tree {
787 write_options.tree = Some(name);
788 }
789 if let Some(size) = chunk_size {
790 write_options.batch_size = size.max(1);
791 }
792 write_options.precision = parse_precision_arg(Some(precision))?;
793 core_write_root(dataset.0.as_ref(), &path_str, &write_options).map_err(PyErr::from)
794}
795
#[pyclass(name = "BinnedDataset", module = "laddu")]
/// Python wrapper around a [`BinnedDataset`], which is produced by `Dataset.bin_by`.
/// Indexing it yields the per-bin datasets.
pub struct PyBinnedDataset(BinnedDataset);
806
807#[pymethods]
808impl PyBinnedDataset {
809 fn __len__(&self) -> usize {
810 self.0.n_bins()
811 }
812 #[getter]
815 fn n_bins(&self) -> usize {
816 self.0.n_bins()
817 }
818 #[getter]
821 fn range(&self) -> (f64, f64) {
822 self.0.range()
823 }
824 #[getter]
827 fn edges<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<f64>> {
828 PyArray1::from_slice(py, &self.0.edges())
829 }
830 fn __getitem__(&self, index: usize) -> PyResult<PyDataset> {
831 self.0
832 .get(index)
833 .ok_or(PyIndexError::new_err("index out of range"))
834 .map(|rust_dataset| PyDataset(rust_dataset.clone()))
835 }
836}