use std::{path::PathBuf, sync::Arc};
use laddu_core::{
data::{
io::{
infer_p4_and_aux_names_from_columns, resolve_columns_case_insensitive,
resolve_optional_weight_column, resolve_p4_component_columns, P4_COMPONENT_SUFFIXES,
},
read_parquet as core_read_parquet,
read_parquet_chunks_with_options as core_read_parquet_chunks_with_options,
read_root as core_read_root, write_parquet as core_write_parquet,
write_root as core_write_root, BinnedDataset, Dataset, DatasetArcIter, DatasetMetadata,
DatasetWriteOptions, EventData, FloatPrecision, OwnedEvent, SharedDatasetIterExt,
},
variables::IntoP4Selection,
DatasetReadOptions, Vec4,
};
use numpy::{PyArray1, PyReadonlyArray1};
use pyo3::{
exceptions::{PyIndexError, PyKeyError, PyTypeError, PyValueError},
prelude::*,
types::{PyDict, PyList},
IntoPyObjectExt,
};
use crate::{
variables::{PyVariable, PyVariableExpression},
vectors::PyVec4,
};
/// Converts an optional Python alias dictionary into `(alias, particle names)` pairs.
///
/// Each dictionary value may be a single string or a sequence of strings.
/// A missing dictionary yields an empty list.
///
/// # Errors
///
/// Fails if a key cannot be extracted as a string, raises `TypeError` if a
/// value is neither a string nor a sequence of strings, and raises
/// `ValueError` if a sequence is empty.
fn parse_aliases(aliases: Option<Bound<'_, PyDict>>) -> PyResult<Vec<(String, Vec<String>)>> {
    let Some(alias_dict) = aliases else {
        return Ok(Vec::new());
    };
    let mut entries = Vec::new();
    for (py_key, py_value) in alias_dict.iter() {
        let alias_name: String = py_key.extract()?;
        // A bare string is tried first so it is treated as a single name.
        let names = match py_value.extract::<String>() {
            Ok(single) => vec![single],
            Err(_) => {
                let names: Vec<String> = py_value.extract().map_err(|_| {
                    PyTypeError::new_err("Alias values must be a string or a sequence of strings")
                })?;
                if names.is_empty() {
                    return Err(PyValueError::new_err(format!(
                        "Alias '{alias_name}' must reference at least one particle",
                    )));
                }
                names
            }
        };
        entries.push((alias_name, names));
    }
    Ok(entries)
}
/// Extracts a filesystem path from a Python `str` or path-like object.
///
/// A `str` is returned as-is; otherwise the value is extracted as a
/// [`PathBuf`] and converted lossily to a `String`.
///
/// # Errors
///
/// Raises `TypeError` ("Expected str or Path") when neither extraction works.
fn parse_dataset_path(path: Bound<'_, PyAny>) -> PyResult<String> {
    path.extract::<String>()
        .or_else(|_| {
            path.extract::<PathBuf>()
                .map(|as_path| as_path.to_string_lossy().into_owned())
        })
        .map_err(|_| PyTypeError::new_err("Expected str or Path"))
}
/// Maps a user-supplied precision name onto a [`FloatPrecision`].
///
/// Matching is ASCII case-insensitive. `None`, `"f64"`, `"float64"` and
/// `"double"` select `F64`; `"f32"`, `"float32"` and `"float"` select `F32`.
///
/// # Errors
///
/// Raises `ValueError` for any other name (the lowercased name is echoed back).
fn parse_precision_arg(value: Option<&str>) -> PyResult<FloatPrecision> {
    let Some(raw) = value else {
        return Ok(FloatPrecision::F64);
    };
    let lowered = raw.to_ascii_lowercase();
    match lowered.as_str() {
        "f64" | "float64" | "double" => Ok(FloatPrecision::F64),
        "f32" | "float32" | "float" => Ok(FloatPrecision::F32),
        other => Err(PyValueError::new_err(format!(
            "Unsupported precision '{other}' (expected 'f64' or 'f32')"
        ))),
    }
}
/// Converts a Python column object into a `Vec<f64>`.
///
/// Conversions are attempted in this order:
/// 1. a one-dimensional NumPy `float64` array,
/// 2. a one-dimensional NumPy `float32` array (each value widened to `f64`),
/// 3. anything extractable as `Vec<f64>`, then as `Vec<f32>`,
/// 4. a Python `list`, converted item by item.
///
/// NumPy arrays that are not contiguous in memory (for example strided views)
/// are copied element by element instead of being rejected.
///
/// # Errors
///
/// Raises `TypeError` mentioning `name` when none of the conversions apply.
fn extract_numeric_column(value: Bound<'_, PyAny>, name: &str) -> PyResult<Vec<f64>> {
    let type_error = || {
        PyTypeError::new_err(format!(
            "Column '{name}' must be numeric (float32/float64/list of floats)"
        ))
    };
    if let Ok(array) = value.extract::<PyReadonlyArray1<'_, f64>>() {
        // `as_slice` only succeeds for contiguous arrays; fall back to an
        // element-wise copy through the array view otherwise.
        let values = match array.as_slice() {
            Ok(contiguous) => contiguous.to_vec(),
            Err(_) => array.as_array().iter().copied().collect(),
        };
        return Ok(values);
    }
    if let Ok(array) = value.extract::<PyReadonlyArray1<'_, f32>>() {
        // The view iterator handles contiguous and strided layouts alike.
        return Ok(array.as_array().iter().map(|&v| f64::from(v)).collect());
    }
    if let Ok(values) = value.extract::<Vec<f64>>() {
        return Ok(values);
    }
    if let Ok(values) = value.extract::<Vec<f32>>() {
        return Ok(values.into_iter().map(f64::from).collect());
    }
    if let Ok(list) = value.cast::<PyList>() {
        return list
            .iter()
            .map(|item| item.extract::<f64>().map_err(|_| type_error()))
            .collect();
    }
    Err(type_error())
}
/// Builds a [`DatasetMetadata`] from explicit names plus optional aliases.
///
/// Aliases are parsed with `parse_aliases` and, when any are present,
/// registered on the metadata through `add_p4_aliases`.
///
/// # Errors
///
/// Propagates alias-parsing errors and any error reported while constructing
/// the metadata or registering the aliases.
fn metadata_from_names_and_aliases(
    p4_names: Vec<String>,
    aux_names: Vec<String>,
    aliases: Option<Bound<'_, PyDict>>,
) -> PyResult<DatasetMetadata> {
    let alias_entries = parse_aliases(aliases)?;
    let mut metadata = DatasetMetadata::new(p4_names, aux_names).map_err(PyErr::from)?;
    if alias_entries.is_empty() {
        return Ok(metadata);
    }
    let selections = alias_entries
        .into_iter()
        .map(|(alias, names)| (alias, names.into_selection()));
    metadata.add_p4_aliases(selections).map_err(PyErr::from)?;
    Ok(metadata)
}
fn parse_p4_mapping(p4: Bound<'_, PyDict>) -> PyResult<Vec<(String, Vec4)>> {
p4.iter()
.map(|(key, value)| {
Ok((
key.extract::<String>()?,
value
.extract::<PyVec4>()
.map_err(|_| PyTypeError::new_err("p4 values must be laddu.Vec4 instances"))?
.0,
))
})
.collect()
}
/// Converts an optional Python `{name: float}` dictionary into `(name, value)` pairs.
///
/// A missing dictionary yields an empty list.
///
/// # Errors
///
/// Propagates the extraction error if a key is not a string or a value is not
/// convertible to `f64`.
fn parse_aux_mapping(aux: Option<Bound<'_, PyDict>>) -> PyResult<Vec<(String, f64)>> {
    let mut entries = Vec::new();
    if let Some(aux_dict) = aux {
        for (py_name, py_value) in aux_dict.iter() {
            let name = py_name.extract::<String>()?;
            let value = py_value.extract::<f64>()?;
            entries.push((name, value));
        }
    }
    Ok(entries)
}
fn parse_p4_column(values: Vec<PyVec4>) -> Vec<Vec4> {
values.into_iter().map(|value| value.0).collect()
}
/// Shared constructor behind `Dataset(...)`, `from_events_local` and
/// `from_events_global`.
///
/// Metadata resolution:
/// * If none of `p4_names`, `aux_names`, `aliases` is given, the metadata of
///   the first event that carries metadata is reused (if any).
/// * Otherwise new metadata is built; names that were not given explicitly
///   fall back to those of the first event carrying metadata, or to an empty
///   list.
///
/// `global` selects between `Dataset::new_with_metadata` / `Dataset::new` and
/// `Dataset::new_local`.
///
/// # Errors
///
/// Raises `ValueError` when aliases are supplied but no four-momentum names
/// can be resolved; propagates alias-parsing and metadata errors.
fn dataset_from_py_events(
    events: Vec<PyEvent>,
    p4_names: Option<Vec<String>>,
    aux_names: Option<Vec<String>>,
    aliases: Option<Bound<PyDict>>,
    global: bool,
) -> PyResult<PyDataset> {
    let inferred_metadata = events
        .iter()
        .find(|candidate| candidate.has_metadata)
        .map(|candidate| candidate.event.metadata_arc());
    let alias_entries = parse_aliases(aliases)?;
    let has_overrides = p4_names.is_some() || aux_names.is_some() || !alias_entries.is_empty();

    let metadata = if !has_overrides {
        inferred_metadata
    } else {
        let fallback = inferred_metadata.as_ref();
        let resolved_p4_names = p4_names.unwrap_or_else(|| {
            fallback
                .map(|inferred| inferred.p4_names().to_vec())
                .unwrap_or_default()
        });
        let resolved_aux_names = aux_names.unwrap_or_else(|| {
            fallback
                .map(|inferred| inferred.aux_names().to_vec())
                .unwrap_or_default()
        });
        if resolved_p4_names.is_empty() && !alias_entries.is_empty() {
            return Err(PyValueError::new_err(
                "`aliases` requires `p4_names` or events with metadata for resolution",
            ));
        }
        let mut rebuilt =
            DatasetMetadata::new(resolved_p4_names, resolved_aux_names).map_err(PyErr::from)?;
        if !alias_entries.is_empty() {
            let selections = alias_entries
                .into_iter()
                .map(|(alias, names)| (alias, names.into_selection()));
            rebuilt.add_p4_aliases(selections).map_err(PyErr::from)?;
        }
        Some(Arc::new(rebuilt))
    };

    let event_data: Vec<Arc<EventData>> = events
        .into_iter()
        .map(|wrapper| wrapper.event.data_arc())
        .collect();
    let dataset = if global {
        match metadata {
            Some(metadata) => Dataset::new_with_metadata(event_data, metadata),
            None => Dataset::new(event_data),
        }
    } else {
        let metadata = metadata.unwrap_or_else(|| Arc::new(DatasetMetadata::default()));
        Dataset::new_local(event_data, metadata)
    };
    Ok(PyDataset(Arc::new(dataset)))
}
/// Python-facing `Event` class wrapping an [`OwnedEvent`].
#[pyclass(name = "Event", module = "laddu", from_py_object)]
#[derive(Clone)]
pub struct PyEvent {
/// The wrapped event: event data together with a handle to its metadata.
pub event: OwnedEvent,
/// `true` when metadata was supplied at construction or the event was
/// produced by a dataset; name-based accessors raise `ValueError` otherwise.
has_metadata: bool,
}
#[pymethods]
impl PyEvent {
#[new]
#[pyo3(signature = (p4s, aux, weight, *, p4_names=None, aux_names=None, aliases=None))]
fn new(
p4s: Vec<PyVec4>,
aux: Vec<f64>,
weight: f64,
p4_names: Option<Vec<String>>,
aux_names: Option<Vec<String>>,
aliases: Option<Bound<PyDict>>,
) -> PyResult<Self> {
let event = EventData {
p4s: p4s.into_iter().map(|arr| arr.0).collect(),
aux,
weight,
};
let aliases = parse_aliases(aliases)?;
let missing_p4_names = p4_names
.as_ref()
.map(|names| names.is_empty())
.unwrap_or(true);
if !aliases.is_empty() && missing_p4_names {
return Err(PyValueError::new_err(
"`aliases` requires `p4_names` so selections can be resolved",
));
}
let metadata_provided = p4_names.is_some() || aux_names.is_some() || !aliases.is_empty();
let metadata = if metadata_provided {
let p4_names = p4_names.unwrap_or_default();
let aux_names = aux_names.unwrap_or_default();
let mut metadata = DatasetMetadata::new(p4_names, aux_names).map_err(PyErr::from)?;
if !aliases.is_empty() {
metadata
.add_p4_aliases(
aliases.into_iter().map(|(alias_name, selection)| {
(alias_name, selection.into_selection())
}),
)
.map_err(PyErr::from)?;
}
Arc::new(metadata)
} else {
Arc::new(DatasetMetadata::empty())
};
let event = OwnedEvent::new(Arc::new(event), metadata);
Ok(Self {
event,
has_metadata: metadata_provided,
})
}
fn __str__(&self) -> String {
self.event.to_string()
}
fn __repr__(&self) -> String {
self.__str__()
}
#[getter]
fn p4s<'py>(&self, py: Python<'py>) -> PyResult<Py<PyDict>> {
self.ensure_metadata()?;
let mapping = PyDict::new(py);
for (name, vec4) in self.event.p4s() {
mapping.set_item(name, PyVec4(vec4))?;
}
Ok(mapping.into())
}
#[getter]
#[pyo3(name = "aux")]
fn aux_mapping<'py>(&self, py: Python<'py>) -> PyResult<Py<PyDict>> {
self.ensure_metadata()?;
let mapping = PyDict::new(py);
for (name, value) in self.event.aux() {
mapping.set_item(name, value)?;
}
Ok(mapping.into())
}
#[getter]
fn get_weight(&self) -> f64 {
self.event.weight()
}
fn get_p4_sum(&self, names: Vec<String>) -> PyResult<PyVec4> {
let indices = self.resolve_p4_indices(&names)?;
Ok(PyVec4(self.event.data().get_p4_sum(indices)))
}
pub fn boost_to_rest_frame_of(&self, names: Vec<String>) -> PyResult<Self> {
let indices = self.resolve_p4_indices(&names)?;
let boosted = self.event.data().boost_to_rest_frame_of(indices);
Ok(Self {
event: OwnedEvent::new(Arc::new(boosted), self.event.metadata_arc()),
has_metadata: self.has_metadata,
})
}
fn evaluate(&self, variable: Bound<'_, PyAny>) -> PyResult<f64> {
let mut variable = variable.extract::<PyVariable>()?;
let metadata = self.ensure_metadata()?;
variable.bind_in_place(metadata)?;
variable.evaluate_event(&self.event)
}
fn p4(&self, name: &str) -> PyResult<PyVec4> {
self.ensure_metadata()?;
self.event
.p4(name)
.map(PyVec4)
.ok_or_else(|| PyKeyError::new_err(format!("Unknown particle name '{name}'")))
}
}
impl PyEvent {
    /// Returns the event metadata, or `ValueError` when the event was built
    /// without any.
    fn ensure_metadata(&self) -> PyResult<&DatasetMetadata> {
        if self.has_metadata {
            return Ok(self.event.metadata());
        }
        Err(PyValueError::new_err(
            "Event has no associated metadata for name-based operations",
        ))
    }

    /// Expands particle names/aliases into four-momentum indices, in the
    /// order the names are given.
    ///
    /// Raises `ValueError` without metadata and `KeyError` for unknown names.
    fn resolve_p4_indices(&self, names: &[String]) -> PyResult<Vec<usize>> {
        let metadata = self.ensure_metadata()?;
        let mut indices = Vec::new();
        for name in names {
            match metadata.p4_selection(name) {
                Some(selection) => indices.extend_from_slice(selection.indices()),
                None => {
                    return Err(PyKeyError::new_err(format!(
                        "Unknown particle name '{name}'"
                    )))
                }
            }
        }
        Ok(indices)
    }

    /// Returns the metadata only if the event actually carries any.
    pub(crate) fn metadata_opt(&self) -> Option<&DatasetMetadata> {
        if self.has_metadata {
            Some(self.event.metadata())
        } else {
            None
        }
    }
}
/// Python-facing `Dataset` class: a shared handle (`Arc`) to a core [`Dataset`].
#[doc(hidden)]
#[pyclass(name = "Dataset", module = "laddu", subclass, skip_from_py_object)]
#[derive(Clone)]
pub struct PyDataset(pub Arc<Dataset>);
/// Python iterator over dataset chunks, as returned by `read_parquet_chunked`.
#[pyclass(
name = "ParquetChunkIter",
module = "laddu",
unsendable,
skip_from_py_object
)]
pub struct PyParquetChunkIter {
// Boxed chunk source; each item is either a dataset chunk or a read error.
chunks: Box<dyn Iterator<Item = laddu_core::LadduResult<Arc<Dataset>>> + Send>,
}
#[pymethods]
impl PyParquetChunkIter {
fn __iter__(slf: PyRef<'_, Self>) -> Py<PyParquetChunkIter> {
slf.into()
}
fn __next__(&mut self) -> PyResult<Option<PyDataset>> {
match self.chunks.next() {
Some(Ok(dataset)) => Ok(Some(PyDataset(dataset))),
Some(Err(err)) => Err(PyErr::from(err)),
None => Ok(None),
}
}
}
/// Python iterator returned by `Dataset.events_global`.
#[pyclass(
name = "DatasetEventsGlobal",
module = "laddu",
unsendable,
skip_from_py_object
)]
pub struct PyDatasetEventsGlobalIter {
// Underlying core iterator obtained from `shared_iter_global`.
iter: DatasetArcIter,
}
#[pymethods]
impl PyDatasetEventsGlobalIter {
    /// Returns the iterator itself (Python iterator protocol).
    fn __iter__(slf: PyRef<'_, Self>) -> Py<PyDatasetEventsGlobalIter> {
        slf.into()
    }

    /// Yields the next event, always flagged as carrying metadata.
    fn __next__(&mut self) -> Option<PyEvent> {
        let event = self.iter.next()?;
        Some(PyEvent {
            event,
            has_metadata: true,
        })
    }
}
/// Python iterator returned by `Dataset.events_local`.
#[pyclass(
name = "DatasetEventsLocal",
module = "laddu",
unsendable,
skip_from_py_object
)]
pub struct PyDatasetEventsLocalIter {
// Dataset being iterated; kept alive for the lifetime of the iterator.
dataset: Arc<Dataset>,
// Index of the next local event to yield.
index: usize,
}
#[pymethods]
impl PyDatasetEventsLocalIter {
    /// Returns the iterator itself (Python iterator protocol).
    fn __iter__(slf: PyRef<'_, Self>) -> Py<PyDatasetEventsLocalIter> {
        slf.into()
    }

    /// Yields a copy of the next local event paired with the dataset's
    /// metadata, or `None` once `n_events_local` events have been produced.
    fn __next__(&mut self) -> Option<PyEvent> {
        let position = self.index;
        if position >= self.dataset.n_events_local() {
            return None;
        }
        // The bound check above is what makes this lookup an invariant.
        let data = self
            .dataset
            .event_local(position)
            .expect("local event index should exist")
            .to_event_data();
        self.index = position + 1;
        let event = OwnedEvent::new(Arc::new(data), self.dataset.metadata_arc());
        Some(PyEvent {
            event,
            has_metadata: true,
        })
    }
}
#[pymethods]
impl PyDataset {
    /// Builds a dataset from a list of events.
    ///
    /// Behaves exactly like `from_events_global`. See `dataset_from_py_events`
    /// for how `p4_names`, `aux_names` and `aliases` interact with metadata
    /// already attached to the events.
    #[new]
    #[pyo3(signature = (events, *, p4_names=None, aux_names=None, aliases=None))]
    fn new(
        events: Vec<PyEvent>,
        p4_names: Option<Vec<String>>,
        aux_names: Option<Vec<String>>,
        aliases: Option<Bound<PyDict>>,
    ) -> PyResult<Self> {
        dataset_from_py_events(events, p4_names, aux_names, aliases, true)
    }

    /// Builds a dataset from events through the core `Dataset::new_local` path.
    #[staticmethod]
    #[pyo3(signature = (events, *, p4_names=None, aux_names=None, aliases=None))]
    fn from_events_local(
        events: Vec<PyEvent>,
        p4_names: Option<Vec<String>>,
        aux_names: Option<Vec<String>>,
        aliases: Option<Bound<PyDict>>,
    ) -> PyResult<Self> {
        dataset_from_py_events(events, p4_names, aux_names, aliases, false)
    }

    /// Builds a dataset from events through the core global constructors
    /// (`Dataset::new_with_metadata` / `Dataset::new`).
    #[staticmethod]
    #[pyo3(signature = (events, *, p4_names=None, aux_names=None, aliases=None))]
    fn from_events_global(
        events: Vec<PyEvent>,
        p4_names: Option<Vec<String>>,
        aux_names: Option<Vec<String>>,
        aliases: Option<Bound<PyDict>>,
    ) -> PyResult<Self> {
        dataset_from_py_events(events, p4_names, aux_names, aliases, true)
    }

    /// Creates an empty local dataset with the given column names and aliases.
    #[staticmethod]
    #[pyo3(signature = (*, p4_names, aux_names=None, aliases=None))]
    fn empty_local(
        p4_names: Vec<String>,
        aux_names: Option<Vec<String>>,
        aliases: Option<Bound<'_, PyDict>>,
    ) -> PyResult<Self> {
        let metadata =
            metadata_from_names_and_aliases(p4_names, aux_names.unwrap_or_default(), aliases)?;
        Ok(Self(Arc::new(Dataset::empty_local(metadata))))
    }

    /// Appends one event given as `{name: Vec4}` / `{name: float}` mappings
    /// via `Dataset::push_event_named_local`.
    ///
    /// `Arc::make_mut` is used, so the underlying dataset is cloned first if
    /// it is shared with another handle.
    #[pyo3(signature = (*, p4, aux=None, weight=1.0))]
    fn push_event_local(
        &mut self,
        p4: Bound<'_, PyDict>,
        aux: Option<Bound<'_, PyDict>>,
        weight: f64,
    ) -> PyResult<()> {
        let p4 = parse_p4_mapping(p4)?;
        let aux = parse_aux_mapping(aux)?;
        Arc::make_mut(&mut self.0)
            .push_event_named_local(p4, aux, weight)
            .map_err(PyErr::from)
    }

    /// Appends one event given as `{name: Vec4}` / `{name: float}` mappings
    /// via `Dataset::push_event_named_global`.
    #[pyo3(signature = (*, p4, aux=None, weight=1.0))]
    fn push_event_global(
        &mut self,
        p4: Bound<'_, PyDict>,
        aux: Option<Bound<'_, PyDict>>,
        weight: f64,
    ) -> PyResult<()> {
        let p4 = parse_p4_mapping(p4)?;
        let aux = parse_aux_mapping(aux)?;
        Arc::make_mut(&mut self.0)
            .push_event_named_global(p4, aux, weight)
            .map_err(PyErr::from)
    }

    /// Adds a named four-momentum column via `Dataset::add_p4_column_local`.
    #[pyo3(signature = (name, values))]
    fn add_p4_column_local(&mut self, name: String, values: Vec<PyVec4>) -> PyResult<()> {
        Arc::make_mut(&mut self.0)
            .add_p4_column_local(name, parse_p4_column(values))
            .map_err(PyErr::from)
    }

    /// Adds a named auxiliary column via `Dataset::add_aux_column_local`.
    ///
    /// `values` may be a NumPy float array or a list of numbers.
    #[pyo3(signature = (name, values))]
    fn add_aux_column_local(&mut self, name: String, values: Bound<'_, PyAny>) -> PyResult<()> {
        let values = extract_numeric_column(values, &name)?;
        Arc::make_mut(&mut self.0)
            .add_aux_column_local(name, values)
            .map_err(PyErr::from)
    }

    /// Adds a named four-momentum column via `Dataset::add_p4_column_global`.
    #[pyo3(signature = (name, values))]
    fn add_p4_column_global(&mut self, name: String, values: Vec<PyVec4>) -> PyResult<()> {
        Arc::make_mut(&mut self.0)
            .add_p4_column_global(name, parse_p4_column(values))
            .map_err(PyErr::from)
    }

    /// Adds a named auxiliary column via `Dataset::add_aux_column_global`.
    ///
    /// `values` may be a NumPy float array or a list of numbers.
    #[pyo3(signature = (name, values))]
    fn add_aux_column_global(&mut self, name: String, values: Bound<'_, PyAny>) -> PyResult<()> {
        let values = extract_numeric_column(values, &name)?;
        Arc::make_mut(&mut self.0)
            .add_aux_column_global(name, values)
            .map_err(PyErr::from)
    }

    /// Number of events as reported by `Dataset::n_events`.
    fn __len__(&self) -> usize {
        self.0.n_events()
    }

    /// Always raises `TypeError`: iteration must go through `events_local`
    /// or `events_global`.
    fn __iter__(&self) -> PyResult<()> {
        Err(PyTypeError::new_err(
            "Dataset iteration is explicit; use dataset.events_local or dataset.events_global",
        ))
    }

    /// Number of events as reported by `Dataset::n_events_local`.
    #[getter]
    fn n_events_local(&self) -> usize {
        self.0.n_events_local()
    }

    /// `self + other`: combines two datasets; `self + 0` returns a copy of
    /// this handle. Any other operand raises `TypeError`.
    fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyDataset> {
        if let Ok(other_ds) = other.extract::<PyRef<PyDataset>>() {
            return Ok(PyDataset(Arc::new(self.0.as_ref() + other_ds.0.as_ref())));
        }
        match other.extract::<usize>() {
            Ok(0) => Ok(self.clone()),
            Ok(_) => Err(PyTypeError::new_err(
                "Addition with an integer for this type is only defined for 0",
            )),
            Err(_) => Err(PyTypeError::new_err("Unsupported operand type for +")),
        }
    }

    /// `other + self`: mirror of `__add__`; accepting `0 + dataset` is what
    /// lets Python's built-in `sum()` start from its default integer `0`.
    fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyDataset> {
        if let Ok(other_ds) = other.extract::<PyRef<PyDataset>>() {
            return Ok(PyDataset(Arc::new(other_ds.0.as_ref() + self.0.as_ref())));
        }
        match other.extract::<usize>() {
            Ok(0) => Ok(self.clone()),
            Ok(_) => Err(PyTypeError::new_err(
                "Addition with an integer for this type is only defined for 0",
            )),
            Err(_) => Err(PyTypeError::new_err("Unsupported operand type for +")),
        }
    }

    /// Summary string with event counts and column names.
    fn __repr__(&self) -> String {
        format!(
            "Dataset(n_events={}, n_events_local={}, p4_names={:?}, aux_names={:?})",
            self.0.n_events_global(),
            self.0.n_events_local(),
            self.0.p4_names(),
            self.0.aux_names()
        )
    }

    /// Same text as `__repr__`.
    fn __str__(&self) -> String {
        self.__repr__()
    }

    /// Number of events as reported by `Dataset::n_events`.
    #[getter]
    fn n_events(&self) -> usize {
        self.0.n_events()
    }

    /// Number of events as reported by `Dataset::n_events_global`.
    #[getter]
    fn n_events_global(&self) -> usize {
        self.0.n_events_global()
    }

    /// Names of the four-momentum columns.
    #[getter]
    fn p4_names(&self) -> Vec<String> {
        self.0.p4_names().to_vec()
    }

    /// Names of the auxiliary columns.
    #[getter]
    fn aux_names(&self) -> Vec<String> {
        self.0.aux_names().to_vec()
    }

    /// Weighted event count as reported by `Dataset::n_events_weighted`.
    #[getter]
    fn n_events_weighted(&self) -> f64 {
        self.0.n_events_weighted()
    }

    /// Weighted event count as reported by `Dataset::n_events_weighted_global`.
    #[getter]
    fn n_events_weighted_global(&self) -> f64 {
        self.0.n_events_weighted_global()
    }

    /// Weighted event count as reported by `Dataset::n_events_weighted_local`.
    #[getter]
    fn n_events_weighted_local(&self) -> f64 {
        self.0.n_events_weighted_local()
    }

    /// Event weights (`Dataset::weights`) copied into a new NumPy array.
    #[getter]
    fn weights<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<f64>> {
        PyArray1::from_slice(py, &self.0.weights())
    }

    /// Event weights (`Dataset::weights_global`) copied into a new NumPy array.
    #[getter]
    fn weights_global<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<f64>> {
        PyArray1::from_slice(py, &self.0.weights_global())
    }

    /// Event weights (`Dataset::weights_local`) copied into a new NumPy array.
    #[getter]
    fn weights_local<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<f64>> {
        PyArray1::from_slice(py, &self.0.weights_local())
    }

    /// Iterator over events obtained from `Dataset::shared_iter_global`.
    #[getter]
    fn events_global(&self) -> PyDatasetEventsGlobalIter {
        PyDatasetEventsGlobalIter {
            iter: self.0.shared_iter_global(),
        }
    }

    /// Iterator over the first `n_events_local` events, starting at index 0.
    #[getter]
    fn events_local(&self) -> PyDatasetEventsLocalIter {
        PyDatasetEventsLocalIter {
            dataset: self.0.clone(),
            index: 0,
        }
    }

    /// Four-momentum `name` of the event at `index`.
    ///
    /// Raises `KeyError` when the core lookup returns nothing.
    fn p4_by_name(&self, index: usize, name: &str) -> PyResult<PyVec4> {
        self.0
            .p4_by_name(index, name)
            .map(PyVec4)
            .ok_or_else(|| PyKeyError::new_err(format!("Unknown particle name '{name}'")))
    }

    /// Auxiliary value `name` of the event at `index`.
    ///
    /// Raises `KeyError` when the core lookup returns nothing.
    fn aux_by_name(&self, index: usize, name: &str) -> PyResult<f64> {
        self.0
            .aux_by_name(index, name)
            .ok_or_else(|| PyKeyError::new_err(format!("Unknown auxiliary name '{name}'")))
    }

    /// Event at `index` via `Dataset::event_global`.
    ///
    /// Any failure of the core lookup is reported as `IndexError`.
    fn event_global(&self, index: usize) -> PyResult<PyEvent> {
        let event = self
            .0
            .event_global(index)
            .map_err(|_| PyIndexError::new_err("index out of range"))?;
        Ok(PyEvent {
            event,
            has_metadata: true,
        })
    }

    /// `dataset[variable]` evaluates the variable over the dataset and returns
    /// a NumPy array; `dataset[int]` returns the event at that index.
    ///
    /// The index type is checked before evaluating, so that a failure while
    /// binding or evaluating a valid variable is raised as-is instead of being
    /// masked by the "Unsupported index type" error.
    ///
    /// # Errors
    ///
    /// `IndexError` for an out-of-range integer, `TypeError` for an index that
    /// is neither a variable nor a non-negative integer, and any error raised
    /// while evaluating a variable.
    fn __getitem__<'py>(
        &self,
        py: Python<'py>,
        index: Bound<'py, PyAny>,
    ) -> PyResult<Bound<'py, PyAny>> {
        if index.extract::<PyVariable>().is_ok() {
            return self.evaluate(py, index)?.into_bound_py_any(py);
        }
        if let Ok(position) = index.extract::<usize>() {
            return self.event_global(position)?.into_bound_py_any(py);
        }
        Err(PyTypeError::new_err(
            "Unsupported index type (int or Variable)",
        ))
    }

    /// Splits the dataset into `bins` bins of `variable` over `range` via
    /// `Dataset::bin_by`.
    #[pyo3(signature = (variable, bins, range))]
    fn bin_by(
        &self,
        variable: Bound<'_, PyAny>,
        bins: usize,
        range: (f64, f64),
    ) -> PyResult<PyBinnedDataset> {
        let py_variable = variable.extract::<PyVariable>()?;
        let bound_variable = py_variable.bound(self.0.metadata())?;
        Ok(PyBinnedDataset(self.0.bin_by(
            bound_variable,
            bins,
            range,
        )?))
    }

    /// Returns the dataset produced by `Dataset::filter` for `expression`.
    pub fn filter(&self, expression: &PyVariableExpression) -> PyResult<PyDataset> {
        Ok(PyDataset(
            self.0.filter(&expression.0).map_err(PyErr::from)?,
        ))
    }

    /// Returns the dataset produced by `Dataset::bootstrap` for `seed`.
    fn bootstrap(&self, seed: usize) -> PyDataset {
        PyDataset(self.0.bootstrap(seed))
    }

    /// Returns the dataset produced by `Dataset::boost_to_rest_frame_of` for
    /// the particles named in `names`.
    pub fn boost_to_rest_frame_of(&self, names: Vec<String>) -> PyDataset {
        PyDataset(self.0.boost_to_rest_frame_of(&names))
    }

    /// Binds `variable` to the dataset metadata and evaluates it, returning
    /// the values as a NumPy array.
    ///
    /// Fails if `variable` is not a supported variable object or if binding
    /// or evaluation fails.
    fn evaluate<'py>(
        &self,
        py: Python<'py>,
        variable: Bound<'py, PyAny>,
    ) -> PyResult<Bound<'py, PyArray1<f64>>> {
        let variable = variable.extract::<PyVariable>()?;
        let bound_variable = variable.bound(self.0.metadata())?;
        let values = self.0.evaluate(&bound_variable).map_err(PyErr::from)?;
        Ok(PyArray1::from_vec(py, values))
    }
}
#[pyfunction]
#[pyo3(signature = (path, *, p4s=None, aux=None, aliases=None))]
pub fn read_parquet(
path: Bound<PyAny>,
p4s: Option<Vec<String>>,
aux: Option<Vec<String>>,
aliases: Option<Bound<PyDict>>,
) -> PyResult<PyDataset> {
let path_str = parse_dataset_path(path)?;
let mut read_options = DatasetReadOptions::default();
if let Some(p4s) = p4s {
read_options = read_options.p4_names(p4s);
}
if let Some(aux) = aux {
read_options = read_options.aux_names(aux);
}
for (alias_name, selection) in parse_aliases(aliases)?.into_iter() {
read_options = read_options.alias(alias_name, selection);
}
let dataset = core_read_parquet(&path_str, &read_options)?;
Ok(PyDataset(dataset))
}
/// Opens a Parquet file for chunked reading and returns an iterator of
/// `Dataset` chunks.
///
/// `p4s`, `aux`, `aliases` and `chunk_size` are forwarded to
/// [`DatasetReadOptions`] only when given.
///
/// # Errors
///
/// Raises `TypeError` for an invalid `path`, propagates alias-parsing errors
/// and any error reported while opening the chunked reader.
#[pyfunction]
#[pyo3(signature = (path, *, p4s=None, aux=None, aliases=None, chunk_size=None))]
pub fn read_parquet_chunked(
    path: Bound<PyAny>,
    p4s: Option<Vec<String>>,
    aux: Option<Vec<String>>,
    aliases: Option<Bound<PyDict>>,
    chunk_size: Option<usize>,
) -> PyResult<PyParquetChunkIter> {
    let source = parse_dataset_path(path)?;
    let mut options = DatasetReadOptions::default();
    if let Some(names) = p4s {
        options = options.p4_names(names);
    }
    if let Some(names) = aux {
        options = options.aux_names(names);
    }
    if let Some(rows) = chunk_size {
        options = options.chunk_size(rows);
    }
    let options = parse_aliases(aliases)?
        .into_iter()
        .fold(options, |acc, (alias, names)| acc.alias(alias, names));
    let chunk_source = core_read_parquet_chunks_with_options(&source, &options)?;
    Ok(PyParquetChunkIter {
        chunks: Box::new(chunk_source),
    })
}
#[pyfunction]
#[pyo3(signature = (path, *, tree=None, p4s=None, aux=None, aliases=None))]
pub fn read_root(
path: Bound<PyAny>,
tree: Option<String>,
p4s: Option<Vec<String>>,
aux: Option<Vec<String>>,
aliases: Option<Bound<PyDict>>,
) -> PyResult<PyDataset> {
let path_str = parse_dataset_path(path)?;
let mut read_options = DatasetReadOptions::default();
if let Some(p4s) = p4s {
read_options = read_options.p4_names(p4s);
}
if let Some(aux) = aux {
read_options = read_options.aux_names(aux);
}
if let Some(tree) = tree {
read_options = read_options.tree(tree);
}
for (alias_name, selection) in parse_aliases(aliases)?.into_iter() {
read_options = read_options.alias(alias_name, selection);
}
let dataset = core_read_root(&path_str, &read_options)?;
Ok(PyDataset(dataset))
}
/// Writes `dataset` to a Parquet file.
///
/// `chunk_size`, when given, sets the writer batch size (clamped to at least
/// 1). `precision` accepts the names understood by `parse_precision_arg`.
///
/// # Errors
///
/// Raises `TypeError` for an invalid `path`, `ValueError` for an unknown
/// `precision`, and propagates any error reported by the core writer.
#[pyfunction]
#[pyo3(signature = (dataset, path, *, chunk_size=None, precision="f64"))]
pub fn write_parquet(
    dataset: &PyDataset,
    path: Bound<PyAny>,
    chunk_size: Option<usize>,
    precision: &str,
) -> PyResult<()> {
    let destination = parse_dataset_path(path)?;
    let mut options = DatasetWriteOptions::default();
    if let Some(rows) = chunk_size {
        options.batch_size = usize::max(rows, 1);
    }
    options.precision = parse_precision_arg(Some(precision))?;
    let written = core_write_parquet(dataset.0.as_ref(), &destination, &options);
    written.map_err(PyErr::from)
}
/// Writes `dataset` to a ROOT file.
///
/// `tree`, when given, sets the tree name in the write options. `chunk_size`,
/// when given, sets the writer batch size (clamped to at least 1).
/// `precision` accepts the names understood by `parse_precision_arg`.
///
/// # Errors
///
/// Raises `TypeError` for an invalid `path`, `ValueError` for an unknown
/// `precision`, and propagates any error reported by the core writer.
#[pyfunction]
#[pyo3(signature = (dataset, path, *, tree=None, chunk_size=None, precision="f64"))]
pub fn write_root(
    dataset: &PyDataset,
    path: Bound<PyAny>,
    tree: Option<String>,
    chunk_size: Option<usize>,
    precision: &str,
) -> PyResult<()> {
    let destination = parse_dataset_path(path)?;
    let mut options = DatasetWriteOptions::default();
    if tree.is_some() {
        options.tree = tree;
    }
    if let Some(rows) = chunk_size {
        options.batch_size = usize::max(rows, 1);
    }
    options.precision = parse_precision_arg(Some(precision))?;
    let written = core_write_root(dataset.0.as_ref(), &destination, &options);
    written.map_err(PyErr::from)
}
#[doc(hidden)]
#[pyfunction]
#[pyo3(signature = (columns, *, p4s=None, aux=None, aliases=None))]
pub fn from_columns(
columns: Bound<'_, PyDict>,
p4s: Option<Vec<String>>,
aux: Option<Vec<String>>,
aliases: Option<Bound<'_, PyDict>>,
) -> PyResult<PyDataset> {
let column_names = columns
.iter()
.map(|(key, _)| key.extract::<String>())
.collect::<PyResult<Vec<_>>>()?;
let (detected_p4_names, detected_aux_names) =
infer_p4_and_aux_names_from_columns(&column_names);
let p4_names = p4s.unwrap_or(detected_p4_names);
if p4_names.is_empty() {
let mut partial_components: std::collections::BTreeMap<
String,
std::collections::BTreeSet<&str>,
> = std::collections::BTreeMap::new();
for column_name in &column_names {
let lowered = column_name.to_ascii_lowercase();
for suffix in P4_COMPONENT_SUFFIXES {
if lowered.ends_with(suffix) && column_name.len() > suffix.len() {
let prefix = column_name[..column_name.len() - suffix.len()].to_string();
partial_components.entry(prefix).or_default().insert(suffix);
}
}
}
if let Some((prefix, present)) = partial_components.iter().next() {
if present.len() < P4_COMPONENT_SUFFIXES.len() {
let missing = P4_COMPONENT_SUFFIXES
.iter()
.filter(|suffix| !present.contains(**suffix))
.map(|suffix| format!("{prefix}{suffix}"))
.collect::<Vec<_>>()
.join(", ");
return Err(PyKeyError::new_err(format!(
"Missing components [{missing}] for four-momentum '{prefix}'"
)));
}
}
return Err(PyValueError::new_err(
"No four-momentum columns found (expected *_px, *_py, *_pz, *_e)",
));
}
let aux_names = aux.unwrap_or(detected_aux_names);
let p4_component_columns =
resolve_p4_component_columns(&column_names, &p4_names).map_err(PyErr::from)?;
let resolved_aux_columns =
resolve_columns_case_insensitive(&column_names, &aux_names).map_err(PyErr::from)?;
let n_events = {
let first_name = p4_component_columns
.first()
.map(|components| components[0].clone())
.ok_or_else(|| PyKeyError::new_err("Missing required p4 column"))?;
let values = extract_numeric_column(
columns
.get_item(first_name.as_str())?
.ok_or_else(|| PyKeyError::new_err("Missing required p4 column"))?,
&first_name,
)?;
values.len()
};
let mut p4_columns: Vec<[Vec<f64>; 4]> = Vec::with_capacity(p4_names.len());
for component_names in &p4_component_columns {
let px = extract_numeric_column(
columns
.get_item(component_names[0].as_str())?
.ok_or_else(|| PyKeyError::new_err(format!("Missing {}", component_names[0])))?,
component_names[0].as_str(),
)?;
let py = extract_numeric_column(
columns
.get_item(component_names[1].as_str())?
.ok_or_else(|| PyKeyError::new_err(format!("Missing {}", component_names[1])))?,
component_names[1].as_str(),
)?;
let pz = extract_numeric_column(
columns
.get_item(component_names[2].as_str())?
.ok_or_else(|| PyKeyError::new_err(format!("Missing {}", component_names[2])))?,
component_names[2].as_str(),
)?;
let e = extract_numeric_column(
columns
.get_item(component_names[3].as_str())?
.ok_or_else(|| PyKeyError::new_err(format!("Missing {}", component_names[3])))?,
component_names[3].as_str(),
)?;
if px.len() != n_events
|| py.len() != n_events
|| pz.len() != n_events
|| e.len() != n_events
{
return Err(PyValueError::new_err(
"All p4 components must have the same length",
));
}
p4_columns.push([px, py, pz, e]);
}
let mut aux_columns: Vec<Vec<f64>> = Vec::with_capacity(resolved_aux_columns.len());
for (aux_name, aux_column_name) in aux_names.iter().zip(&resolved_aux_columns) {
let values = extract_numeric_column(
columns.get_item(aux_column_name.as_str())?.ok_or_else(|| {
PyKeyError::new_err(format!("Missing auxiliary column '{aux_name}'"))
})?,
aux_name,
)?;
if values.len() != n_events {
return Err(PyValueError::new_err(format!(
"Auxiliary column '{aux_name}' length does not match p4 columns"
)));
}
aux_columns.push(values);
}
let weights = if let Some(weight_column_name) = resolve_optional_weight_column(&column_names) {
let weight_values = columns
.get_item(weight_column_name.as_str())?
.ok_or_else(|| PyKeyError::new_err("Missing weight column"))?;
let values = extract_numeric_column(weight_values, "weight")?;
if values.len() != n_events {
return Err(PyValueError::new_err(
"Column 'weight' length does not match p4 columns",
));
}
values
} else {
vec![1.0; n_events]
};
let parsed_aliases = parse_aliases(aliases)?;
let mut metadata =
DatasetMetadata::new(p4_names.clone(), aux_names.clone()).map_err(PyErr::from)?;
if !parsed_aliases.is_empty() {
metadata
.add_p4_aliases(
parsed_aliases
.into_iter()
.map(|(alias_name, selection)| (alias_name, selection.into_selection())),
)
.map_err(PyErr::from)?;
}
let p4_columns = p4_columns
.into_iter()
.map(|components| {
(0..n_events)
.map(|event_idx| {
laddu_core::vectors::Vec4::new(
components[0][event_idx],
components[1][event_idx],
components[2][event_idx],
components[3][event_idx],
)
})
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();
Ok(PyDataset(Arc::new(Dataset::from_columns_global(
metadata,
p4_columns,
aux_columns,
weights,
)?)))
}
/// Python-facing `BinnedDataset` class wrapping a core [`BinnedDataset`],
/// as produced by `Dataset.bin_by`.
#[pyclass(name = "BinnedDataset", module = "laddu", skip_from_py_object)]
pub struct PyBinnedDataset(BinnedDataset);
#[pymethods]
impl PyBinnedDataset {
    /// Number of bins.
    fn __len__(&self) -> usize {
        self.0.n_bins()
    }

    /// Number of bins.
    #[getter]
    fn n_bins(&self) -> usize {
        self.0.n_bins()
    }

    /// The `(low, high)` range reported by `BinnedDataset::range`.
    #[getter]
    fn range(&self) -> (f64, f64) {
        self.0.range()
    }

    /// Bin edges copied into a new NumPy array.
    #[getter]
    fn edges<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<f64>> {
        PyArray1::from_slice(py, &self.0.edges())
    }

    /// Returns the dataset stored in bin `index`.
    ///
    /// Raises `IndexError` when `index` is out of range. The error is built
    /// lazily, so successful lookups do not construct one.
    fn __getitem__(&self, index: usize) -> PyResult<PyDataset> {
        self.0
            .get(index)
            .map(|bin| PyDataset(bin.clone()))
            .ok_or_else(|| PyIndexError::new_err("index out of range"))
    }

    /// Summary string with the bin count and range.
    fn __repr__(&self) -> String {
        format!(
            "BinnedDataset(n_bins={}, range={:?})",
            self.0.n_bins(),
            self.0.range()
        )
    }

    /// Same text as `__repr__`.
    fn __str__(&self) -> String {
        self.__repr__()
    }
}