use crate::utils::variables::{PyVariable, PyVariableExpression};
use laddu_core::{
    data::{
        BinnedDataset, Dataset, DatasetMetadata, DatasetWriteOptions, Event, EventData,
        FloatPrecision,
    },
    utils::variables::IntoP4Selection,
    DatasetReadOptions,
};
use numpy::PyArray1;
use pyo3::{
    exceptions::{PyIndexError, PyKeyError, PyTypeError, PyValueError},
    prelude::*,
    types::PyDict,
    IntoPyObjectExt,
};
use std::{path::PathBuf, sync::Arc};

use crate::utils::vectors::PyVec4;

fn parse_aliases(aliases: Option<Bound<'_, PyDict>>) -> PyResult<Vec<(String, Vec<String>)>> {
    let Some(aliases) = aliases else {
        return Ok(Vec::new());
    };

    let mut parsed = Vec::new();
    for (key, value) in aliases.iter() {
        let alias_name = key.extract::<String>()?;
        let selection = if let Ok(single) = value.extract::<String>() {
            vec![single]
        } else {
            let seq = value.extract::<Vec<String>>().map_err(|_| {
                PyTypeError::new_err("Alias values must be a string or a sequence of strings")
            })?;
            if seq.is_empty() {
                return Err(PyValueError::new_err(format!(
                    "Alias '{alias_name}' must reference at least one particle",
                )));
            }
            seq
        };
        parsed.push((alias_name, selection));
    }

    Ok(parsed)
}

fn parse_dataset_path(path: Bound<'_, PyAny>) -> PyResult<String> {
    if let Ok(s) = path.extract::<String>() {
        Ok(s)
    } else if let Ok(pathbuf) = path.extract::<PathBuf>() {
        Ok(pathbuf.to_string_lossy().into_owned())
    } else {
        Err(PyTypeError::new_err("Expected str or Path"))
    }
}

fn parse_precision_arg(value: Option<&str>) -> PyResult<FloatPrecision> {
    match value.map(|v| v.to_ascii_lowercase()) {
        None => Ok(FloatPrecision::F64),
        Some(name) if name == "f64" || name == "float64" || name == "double" => {
            Ok(FloatPrecision::F64)
        }
        Some(name) if name == "f32" || name == "float32" || name == "float" => {
            Ok(FloatPrecision::F32)
        }
        Some(other) => Err(PyValueError::new_err(format!(
            "Unsupported precision '{other}' (expected 'f64' or 'f32')"
        ))),
    }
}

#[pyclass(name = "Event", module = "laddu")]
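/// A single event: a list of four-momenta, a list of auxiliary scalar values, and a weight.
///
/// Passing `p4_names`, `aux_names`, or `aliases` attaches metadata to the event, which
/// enables name-based access (`p4s`, `aux`, `get_p4_sum`, `evaluate`, ...). An alias maps a
/// new name to a single particle name or to a sequence of particle names whose four-momenta
/// are treated as one selection; `aliases` therefore requires `p4_names` as well.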
#[derive(Clone)]
pub struct PyEvent {
    pub event: Event,
    has_metadata: bool,
}

#[pymethods]
impl PyEvent {
    #[new]
    #[pyo3(signature = (p4s, aux, weight, *, p4_names=None, aux_names=None, aliases=None))]
    fn new(
        p4s: Vec<PyVec4>,
        aux: Vec<f64>,
        weight: f64,
        p4_names: Option<Vec<String>>,
        aux_names: Option<Vec<String>>,
        aliases: Option<Bound<PyDict>>,
    ) -> PyResult<Self> {
        let event = EventData {
            p4s: p4s.into_iter().map(|arr| arr.0).collect(),
            aux,
            weight,
        };
        let aliases = parse_aliases(aliases)?;

        let missing_p4_names = p4_names
            .as_ref()
            .map(|names| names.is_empty())
            .unwrap_or(true);

        if !aliases.is_empty() && missing_p4_names {
            return Err(PyValueError::new_err(
                "`aliases` requires `p4_names` so selections can be resolved",
            ));
        }

        let metadata_provided = p4_names.is_some() || aux_names.is_some() || !aliases.is_empty();
        let metadata = if metadata_provided {
            let p4_names = p4_names.unwrap_or_default();
            let aux_names = aux_names.unwrap_or_default();
            let mut metadata = DatasetMetadata::new(p4_names, aux_names).map_err(PyErr::from)?;
            if !aliases.is_empty() {
                metadata
                    .add_p4_aliases(
                        aliases.into_iter().map(|(alias_name, selection)| {
                            (alias_name, selection.into_selection())
                        }),
                    )
                    .map_err(PyErr::from)?;
            }
            Arc::new(metadata)
        } else {
            Arc::new(DatasetMetadata::empty())
        };
        let event = Event::new(Arc::new(event), metadata);
        Ok(Self {
            event,
            has_metadata: metadata_provided,
        })
    }
    fn __str__(&self) -> String {
        self.event.data().to_string()
    }
    #[getter]
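    /// The event's four-momenta keyed by particle name (requires metadata).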
    fn p4s<'py>(&self, py: Python<'py>) -> PyResult<Py<PyDict>> {
        self.ensure_metadata()?;
        let mapping = PyDict::new(py);
        for (name, vec4) in self.event.p4s() {
            mapping.set_item(name, PyVec4(vec4))?;
        }
        Ok(mapping.into())
    }
    #[getter]
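    /// The event's auxiliary values keyed by name (requires metadata).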
    #[pyo3(name = "aux")]
    fn aux_mapping<'py>(&self, py: Python<'py>) -> PyResult<Py<PyDict>> {
        self.ensure_metadata()?;
        let mapping = PyDict::new(py);
        for (name, value) in self.event.aux() {
            mapping.set_item(name, value)?;
        }
        Ok(mapping.into())
    }
    #[getter]
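    /// The event weight.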
    fn get_weight(&self) -> f64 {
        self.event.weight()
    }
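    /// Sum the four-momenta selected by `names` (particle names or aliases; requires metadata).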
    fn get_p4_sum(&self, names: Vec<String>) -> PyResult<PyVec4> {
        let indices = self.resolve_p4_indices(&names)?;
        Ok(PyVec4(self.event.data().get_p4_sum(indices)))
    }
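    /// Return a copy of this event boosted to the rest frame of the four-momentum sum of `names`.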
    pub fn boost_to_rest_frame_of(&self, names: Vec<String>) -> PyResult<Self> {
        let indices = self.resolve_p4_indices(&names)?;
        let boosted = self.event.data().boost_to_rest_frame_of(indices);
        Ok(Self {
            event: Event::new(Arc::new(boosted), self.event.metadata_arc()),
            has_metadata: self.has_metadata,
        })
    }
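    /// Evaluate a `Variable` on this single event (requires metadata so names can be resolved).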
    fn evaluate(&self, variable: Bound<'_, PyAny>) -> PyResult<f64> {
        let mut variable = variable.extract::<PyVariable>()?;
        if !self.has_metadata {
            return Err(PyValueError::new_err(
                "Cannot evaluate variable on an Event without associated metadata. Construct the Event with `p4_names`/`aux_names` or evaluate through a Dataset.",
            ));
        }
        variable.bind_in_place(self.event.metadata())?;
        let event_arc = self.event.data_arc();
        variable.evaluate_event(&event_arc)
    }

    fn p4(&self, name: &str) -> PyResult<Option<PyVec4>> {
        self.ensure_metadata()?;
        Ok(self.event.p4(name).map(PyVec4))
    }
}

impl PyEvent {
    fn ensure_metadata(&self) -> PyResult<&DatasetMetadata> {
        if !self.has_metadata {
            Err(PyValueError::new_err(
                "Event has no associated metadata for name-based operations",
            ))
        } else {
            Ok(self.event.metadata())
        }
    }

    fn resolve_p4_indices(&self, names: &[String]) -> PyResult<Vec<usize>> {
        let metadata = self.ensure_metadata()?;
        let mut resolved = Vec::new();
        for name in names {
            let selection = metadata
                .p4_selection(name)
                .ok_or_else(|| PyKeyError::new_err(format!("Unknown particle name '{name}'")))?;
            resolved.extend_from_slice(selection.indices());
        }
        Ok(resolved)
    }

    pub(crate) fn metadata_opt(&self) -> Option<&DatasetMetadata> {
        self.has_metadata.then(|| self.event.metadata())
    }
}

#[pyclass(name = "Dataset", module = "laddu", subclass)]
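/// A collection of `Event`s sharing one set of metadata (particle names, auxiliary names,
/// and aliases). Datasets can be built from a list of events, read from Parquet or ROOT
/// files, indexed, iterated, filtered, binned, bootstrapped, and evaluated against variables.
///
/// A sketch of the indexing behaviour (`dataset` and `variable` are placeholders):
///
/// ```python
/// n = len(dataset)             # number of events
/// first = dataset[0]           # an Event
/// values = dataset[variable]   # equivalent to dataset.evaluate(variable)
/// ```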
#[derive(Clone)]
pub struct PyDataset(pub Arc<Dataset>);

#[pyclass(name = "DatasetIter", module = "laddu")]
struct PyDatasetIter {
    dataset: Arc<Dataset>,
    index: usize,
    total: usize,
}

#[pymethods]
impl PyDatasetIter {
    fn __iter__(slf: PyRef<'_, Self>) -> Py<PyDatasetIter> {
        slf.into()
    }

    fn __next__(&mut self) -> Option<PyEvent> {
        if self.index >= self.total {
            return None;
        }
        let event = self.dataset[self.index].clone();
        self.index += 1;
        Some(PyEvent {
            event,
            has_metadata: true,
        })
    }
}

#[pymethods]
impl PyDataset {
    #[new]
    #[pyo3(signature = (events, *, p4_names=None, aux_names=None, aliases=None))]
    fn new(
        events: Vec<PyEvent>,
        p4_names: Option<Vec<String>>,
        aux_names: Option<Vec<String>>,
        aliases: Option<Bound<PyDict>>,
    ) -> PyResult<Self> {
        let inferred_metadata = events
            .iter()
            .find_map(|event| event.has_metadata.then(|| event.event.metadata_arc()));

        let aliases = parse_aliases(aliases)?;
        let use_explicit_metadata =
            p4_names.is_some() || aux_names.is_some() || !aliases.is_empty();

        let metadata =
            if use_explicit_metadata {
                let resolved_p4_names = match (p4_names, inferred_metadata.as_ref()) {
                    (Some(names), _) => names,
                    (None, Some(metadata)) => metadata.p4_names().to_vec(),
                    (None, None) => Vec::new(),
                };
                let resolved_aux_names = match (aux_names, inferred_metadata.as_ref()) {
                    (Some(names), _) => names,
                    (None, Some(metadata)) => metadata.aux_names().to_vec(),
                    (None, None) => Vec::new(),
                };

                if !aliases.is_empty() && resolved_p4_names.is_empty() {
                    return Err(PyValueError::new_err(
                        "`aliases` requires `p4_names` or events with metadata for resolution",
                    ));
                }

                let mut metadata = DatasetMetadata::new(resolved_p4_names, resolved_aux_names)
                    .map_err(PyErr::from)?;
                if !aliases.is_empty() {
                    metadata
                        .add_p4_aliases(aliases.into_iter().map(|(alias_name, selection)| {
                            (alias_name, selection.into_selection())
                        }))
                        .map_err(PyErr::from)?;
                }
                Some(Arc::new(metadata))
            } else {
                inferred_metadata
            };

        let events: Vec<Arc<EventData>> = events
            .into_iter()
            .map(|event| event.event.data_arc())
            .collect();
        let dataset = if let Some(metadata) = metadata {
            Dataset::new_with_metadata(events, metadata)
        } else {
            Dataset::new(events)
        };
        Ok(Self(Arc::new(dataset)))
    }

    #[staticmethod]
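    /// Read a `Dataset` from a Parquet file.
    ///
    /// `path` may be a `str` or a `pathlib.Path`. `p4s` and `aux` are forwarded to the
    /// reader options as the four-momentum and auxiliary names to load, and `aliases`
    /// maps alias names to particle names or lists of particle names, as in the constructor.
    ///
    /// A minimal usage sketch (file and particle names are illustrative only):
    ///
    /// ```python
    /// data = Dataset.from_parquet("events.parquet", p4s=["beam", "proton", "pi+", "pi-"])
    /// ```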
    #[pyo3(signature = (path, *, p4s=None, aux=None, aliases=None))]
    fn from_parquet(
        path: Bound<PyAny>,
        p4s: Option<Vec<String>>,
        aux: Option<Vec<String>>,
        aliases: Option<Bound<PyDict>>,
    ) -> PyResult<Self> {
        let path_str = parse_dataset_path(path)?;

        let mut read_options = DatasetReadOptions::default();
        if let Some(p4s) = p4s {
            read_options = read_options.p4_names(p4s);
        }
        if let Some(aux) = aux {
            read_options = read_options.aux_names(aux);
        }
        for (alias_name, selection) in parse_aliases(aliases)?.into_iter() {
            read_options = read_options.alias(alias_name, selection);
        }
        let dataset = Dataset::from_parquet(&path_str, &read_options)?;

        Ok(Self(dataset))
    }

    #[staticmethod]
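    /// Read a `Dataset` from a ROOT file.
    ///
    /// Behaves like `from_parquet`, with an additional `tree` argument that selects which
    /// tree to read; `p4s`, `aux`, and `aliases` are forwarded to the reader options.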
    #[pyo3(signature = (path, *, tree=None, p4s=None, aux=None, aliases=None))]
    fn from_root(
        path: Bound<PyAny>,
        tree: Option<String>,
        p4s: Option<Vec<String>>,
        aux: Option<Vec<String>>,
        aliases: Option<Bound<PyDict>>,
    ) -> PyResult<Self> {
        let path_str = parse_dataset_path(path)?;

        let mut read_options = DatasetReadOptions::default();
        if let Some(p4s) = p4s {
            read_options = read_options.p4_names(p4s);
        }
        if let Some(aux) = aux {
            read_options = read_options.aux_names(aux);
        }
        if let Some(tree) = tree {
            read_options = read_options.tree(tree);
        }
        for (alias_name, selection) in parse_aliases(aliases)?.into_iter() {
            read_options = read_options.alias(alias_name, selection);
        }
        let dataset = Dataset::from_root(&path_str, &read_options)?;

        Ok(Self(dataset))
    }
    fn __len__(&self) -> usize {
        self.0.n_events()
    }
    fn __iter__(&self) -> PyDatasetIter {
        PyDatasetIter {
            dataset: self.0.clone(),
            index: 0,
            total: self.0.n_events(),
        }
    }
    fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyDataset> {
        if let Ok(other_ds) = other.extract::<PyRef<PyDataset>>() {
            Ok(PyDataset(Arc::new(self.0.as_ref() + other_ds.0.as_ref())))
        } else if let Ok(other_int) = other.extract::<usize>() {
            if other_int == 0 {
                Ok(self.clone())
            } else {
                Err(PyTypeError::new_err(
                    "Addition with an integer for this type is only defined for 0",
                ))
            }
        } else {
            Err(PyTypeError::new_err("Unsupported operand type for +"))
        }
    }
    fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyDataset> {
        if let Ok(other_ds) = other.extract::<PyRef<PyDataset>>() {
            Ok(PyDataset(Arc::new(other_ds.0.as_ref() + self.0.as_ref())))
        } else if let Ok(other_int) = other.extract::<usize>() {
            if other_int == 0 {
                Ok(self.clone())
            } else {
                Err(PyTypeError::new_err(
                    "Addition with an integer for this type is only defined for 0",
                ))
            }
        } else {
            Err(PyTypeError::new_err("Unsupported operand type for +"))
        }
    }
    #[getter]
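    /// The number of events in the dataset.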
    fn n_events(&self) -> usize {
        self.0.n_events()
    }
    #[getter]
    fn p4_names(&self) -> Vec<String> {
        self.0.p4_names().to_vec()
    }
    #[getter]
    fn aux_names(&self) -> Vec<String> {
        self.0.aux_names().to_vec()
    }

    #[pyo3(signature = (path, *, chunk_size=None, precision="f64"))]
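    /// Write the dataset to a Parquet file; `chunk_size` sets the write batch size
    /// (minimum 1) and `precision` is `'f64'` (default) or `'f32'`.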
    fn to_parquet(
        &self,
        path: Bound<'_, PyAny>,
        chunk_size: Option<usize>,
        precision: &str,
    ) -> PyResult<()> {
        let path_str = parse_dataset_path(path)?;
        let mut write_options = DatasetWriteOptions::default();
        if let Some(size) = chunk_size {
            write_options.batch_size = size.max(1);
        }
        write_options.precision = parse_precision_arg(Some(precision))?;

        self.0
            .to_parquet(&path_str, &write_options)
            .map_err(PyErr::from)
    }

    #[pyo3(signature = (path, *, tree=None, chunk_size=None, precision="f64"))]
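    /// Write the dataset to a ROOT file; `tree` optionally names the output tree, and
    /// `chunk_size`/`precision` behave as in `to_parquet`.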
    fn to_root(
        &self,
        path: Bound<'_, PyAny>,
        tree: Option<String>,
        chunk_size: Option<usize>,
        precision: &str,
    ) -> PyResult<()> {
        let path_str = parse_dataset_path(path)?;
        let mut write_options = DatasetWriteOptions::default();
        if let Some(name) = tree {
            write_options.tree = Some(name);
        }
        if let Some(size) = chunk_size {
            write_options.batch_size = size.max(1);
        }
        write_options.precision = parse_precision_arg(Some(precision))?;

        self.0
            .to_root(&path_str, &write_options)
            .map_err(PyErr::from)
    }
    #[getter]
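    /// The weighted number of events (sum of event weights).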
    fn n_events_weighted(&self) -> f64 {
        self.0.n_events_weighted()
    }
    #[getter]
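    /// The event weights as a NumPy array.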
    fn weights<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<f64>> {
        PyArray1::from_slice(py, &self.0.weights())
    }
    #[getter]
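    /// All events as a list of `Event` objects, each carrying the dataset's metadata.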
    fn events(&self) -> Vec<PyEvent> {
        self.0
            .events
            .iter()
            .map(|rust_event| PyEvent {
                event: rust_event.clone(),
                has_metadata: true,
            })
            .collect()
    }
    fn p4_by_name(&self, index: usize, name: &str) -> PyResult<PyVec4> {
        self.0
            .p4_by_name(index, name)
            .map(PyVec4)
            .ok_or_else(|| PyKeyError::new_err(format!("Unknown particle name '{name}'")))
    }
    fn aux_by_name(&self, index: usize, name: &str) -> PyResult<f64> {
        self.0
            .aux_by_name(index, name)
            .ok_or_else(|| PyKeyError::new_err(format!("Unknown auxiliary name '{name}'")))
    }
    fn __getitem__<'py>(
        &self,
        py: Python<'py>,
        index: Bound<'py, PyAny>,
    ) -> PyResult<Bound<'py, PyAny>> {
        if let Ok(value) = self.evaluate(py, index.clone()) {
            value.into_bound_py_any(py)
        } else if let Ok(index) = index.extract::<usize>() {
            PyEvent {
                event: self.0[index].clone(),
                has_metadata: true,
            }
            .into_bound_py_any(py)
        } else {
            Err(PyTypeError::new_err(
                "Unsupported index type (int or Variable)",
            ))
        }
    }
    #[pyo3(signature = (variable, bins, range))]
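    /// Split the dataset into `bins` sub-datasets according to the value of `variable`
    /// over the given `range`, returning a `BinnedDataset`.
    ///
    /// A minimal sketch (`mass` stands for any laddu Variable and is illustrative only):
    ///
    /// ```python
    /// binned = dataset.bin_by(mass, bins=50, range=(1.0, 2.0))
    /// ```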
    fn bin_by(
        &self,
        variable: Bound<'_, PyAny>,
        bins: usize,
        range: (f64, f64),
    ) -> PyResult<PyBinnedDataset> {
        let py_variable = variable.extract::<PyVariable>()?;
        let bound_variable = py_variable.bound(self.0.metadata())?;
        Ok(PyBinnedDataset(self.0.bin_by(
            bound_variable,
            bins,
            range,
        )?))
    }
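    /// Return a new dataset containing only the events selected by the variable `expression`.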
    pub fn filter(&self, expression: &PyVariableExpression) -> PyResult<PyDataset> {
        Ok(PyDataset(
            self.0.filter(&expression.0).map_err(PyErr::from)?,
        ))
    }
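    /// Return a bootstrap resample of the dataset (sampling with replacement) using the given seed.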
    fn bootstrap(&self, seed: usize) -> PyDataset {
        PyDataset(self.0.bootstrap(seed))
    }
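    /// Return a copy of the dataset with every event boosted to the rest frame of the
    /// four-momentum sum of `names`.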
    pub fn boost_to_rest_frame_of(&self, names: Vec<String>) -> PyDataset {
        PyDataset(self.0.boost_to_rest_frame_of(&names))
    }
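    /// Evaluate a `Variable` over every event, returning a NumPy array with one value per event.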
    fn evaluate<'py>(
        &self,
        py: Python<'py>,
        variable: Bound<'py, PyAny>,
    ) -> PyResult<Bound<'py, PyArray1<f64>>> {
        let variable = variable.extract::<PyVariable>()?;
        let bound_variable = variable.bound(self.0.metadata())?;
        let values = self.0.evaluate(&bound_variable).map_err(PyErr::from)?;
        Ok(PyArray1::from_vec(py, values))
    }
}

#[pyclass(name = "BinnedDataset", module = "laddu")]
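/// The result of `Dataset.bin_by`: a sequence of datasets, one per bin, together with the
/// binning `range` and bin `edges`. Supports `len()` and integer indexing.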
pub struct PyBinnedDataset(BinnedDataset);

#[pymethods]
impl PyBinnedDataset {
    fn __len__(&self) -> usize {
        self.0.n_bins()
    }
    #[getter]
    fn n_bins(&self) -> usize {
        self.0.n_bins()
    }
    #[getter]
    fn range(&self) -> (f64, f64) {
        self.0.range()
    }
    #[getter]
    fn edges<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<f64>> {
        PyArray1::from_slice(py, &self.0.edges())
    }
    fn __getitem__(&self, index: usize) -> PyResult<PyDataset> {
        self.0
            .get(index)
            .ok_or(PyIndexError::new_err("index out of range"))
            .map(|rust_dataset| PyDataset(rust_dataset.clone()))
    }
}