use std::{collections::HashMap, sync::Arc};
use laddu_generation::{
CompositeGenerator, Distribution, EventGenerator, GeneratedBatch, GeneratedEventLayout,
GeneratedParticle, GeneratedParticleLayout, GeneratedReaction, GeneratedStorage,
GeneratedVertexKind, GeneratedVertexLayout, HistogramSampler, InitialGenerator,
MandelstamTDistribution, ParticleSpecies, Reconstruction, StableGenerator,
};
use pyo3::{exceptions::PyValueError, prelude::*, types::PyTuple};
use crate::{data::PyDataset, math::PyHistogram, variables::PyReaction, vectors::PyVec4};
/// Python-facing wrapper around [`Distribution`], exposed as `laddu.Distribution`.
///
/// Instances are built through the static constructors below. Each constructor
/// that takes a constrained parameter validates it up front and raises a Python
/// `ValueError` on failure.
#[pyclass(name = "Distribution", module = "laddu", from_py_object)]
#[derive(Clone, Debug)]
pub struct PyDistribution(pub Distribution);

#[pymethods]
impl PyDistribution {
    /// Build a `Distribution::Fixed` holding `value`.
    #[staticmethod]
    fn fixed(value: f64) -> Self {
        Self(Distribution::Fixed(value))
    }

    /// Build a `Distribution::Uniform` over `min..max`.
    ///
    /// # Errors
    ///
    /// Raises `ValueError` unless `max` is strictly greater than `min`. NaN
    /// bounds are rejected as well.
    #[staticmethod]
    fn uniform(min: f64, max: f64) -> PyResult<Self> {
        // Every ordered comparison involving NaN is false, so `max <= min` alone
        // would let NaN bounds through; reject them explicitly.
        if min.is_nan() || max.is_nan() || max <= min {
            return Err(PyValueError::new_err(
                "`max` must be greater than `min` for a uniform distribution",
            ));
        }
        Ok(Self(Distribution::Uniform { min, max }))
    }

    /// Build a `Distribution::Normal` with the given `mu` and `sigma`.
    ///
    /// # Errors
    ///
    /// Raises `ValueError` unless `sigma` is strictly positive (NaN is rejected).
    #[staticmethod]
    fn normal(mu: f64, sigma: f64) -> PyResult<Self> {
        // `sigma <= 0.0` is false for NaN, so NaN needs its own check.
        if sigma.is_nan() || sigma <= 0.0 {
            return Err(PyValueError::new_err(
                "`sigma` must be positive for a normal distribution",
            ));
        }
        Ok(Self(Distribution::Normal { mu, sigma }))
    }

    /// Build a `Distribution::Exponential` with the given `slope`.
    ///
    /// # Errors
    ///
    /// Raises `ValueError` unless `slope` is strictly positive (NaN is rejected).
    #[staticmethod]
    fn exponential(slope: f64) -> PyResult<Self> {
        // `slope <= 0.0` is false for NaN, so NaN needs its own check.
        if slope.is_nan() || slope <= 0.0 {
            return Err(PyValueError::new_err(
                "`slope` must be positive for an exponential distribution",
            ));
        }
        Ok(Self(Distribution::Exponential { slope }))
    }

    /// Build a `Distribution::Histogram` from a clone of the given histogram.
    ///
    /// # Errors
    ///
    /// Any error returned by `HistogramSampler::new` is propagated to Python.
    #[staticmethod]
    fn histogram(histogram: &PyHistogram) -> PyResult<Self> {
        let sampler = HistogramSampler::new(histogram.0.clone())?;
        Ok(Self(Distribution::Histogram(sampler)))
    }

    /// Return the `Debug` rendering of the wrapped distribution.
    fn __repr__(&self) -> String {
        format!("{:?}", self.0)
    }
}
/// Python-facing wrapper around [`MandelstamTDistribution`], exposed as
/// `laddu.MandelstamTDistribution`.
#[pyclass(name = "MandelstamTDistribution", module = "laddu", from_py_object)]
#[derive(Clone, Debug)]
pub struct PyMandelstamTDistribution(pub MandelstamTDistribution);

#[pymethods]
impl PyMandelstamTDistribution {
    /// Build a `MandelstamTDistribution::Exponential` with the given `slope`.
    ///
    /// # Errors
    ///
    /// Raises `ValueError` unless `slope` is strictly positive (NaN is rejected).
    #[staticmethod]
    fn exponential(slope: f64) -> PyResult<Self> {
        // Every ordered comparison involving NaN is false, so `slope <= 0.0`
        // alone would accept a NaN slope; reject it explicitly.
        if slope.is_nan() || slope <= 0.0 {
            return Err(PyValueError::new_err(
                "`slope` must be positive for an exponential distribution",
            ));
        }
        Ok(Self(MandelstamTDistribution::Exponential { slope }))
    }

    /// Build a `MandelstamTDistribution::Histogram` from a clone of the given
    /// histogram.
    ///
    /// # Errors
    ///
    /// Any error returned by `HistogramSampler::new` is propagated to Python.
    #[staticmethod]
    fn histogram(histogram: &PyHistogram) -> PyResult<Self> {
        let sampler = HistogramSampler::new(histogram.0.clone())?;
        Ok(Self(MandelstamTDistribution::Histogram(sampler)))
    }

    /// Return the `Debug` rendering of the wrapped distribution.
    fn __repr__(&self) -> String {
        format!("{:?}", self.0)
    }
}
/// Python-facing wrapper around [`InitialGenerator`], exposed as
/// `laddu.InitialGenerator`.
#[pyclass(name = "InitialGenerator", module = "laddu", from_py_object)]
#[derive(Clone, Debug)]
pub struct PyInitialGenerator(pub InitialGenerator);

#[pymethods]
impl PyInitialGenerator {
    /// Delegate to `InitialGenerator::beam_with_fixed_energy(mass, energy)`.
    #[staticmethod]
    fn beam_with_fixed_energy(mass: f64, energy: f64) -> Self {
        let inner = InitialGenerator::beam_with_fixed_energy(mass, energy);
        Self(inner)
    }

    /// Delegate to `InitialGenerator::beam(mass, min_energy, max_energy)`.
    #[staticmethod]
    fn beam(mass: f64, min_energy: f64, max_energy: f64) -> Self {
        let inner = InitialGenerator::beam(mass, min_energy, max_energy);
        Self(inner)
    }

    /// Delegate to `InitialGenerator::beam_with_energy_histogram`, passing a
    /// clone of the supplied histogram.
    ///
    /// # Errors
    ///
    /// Any error from the underlying constructor is propagated to Python.
    #[staticmethod]
    fn beam_with_energy_histogram(mass: f64, energy: &PyHistogram) -> PyResult<Self> {
        let spectrum = energy.0.clone();
        let inner = InitialGenerator::beam_with_energy_histogram(mass, spectrum)?;
        Ok(Self(inner))
    }

    /// Delegate to `InitialGenerator::target(mass)`.
    #[staticmethod]
    fn target(mass: f64) -> Self {
        let inner = InitialGenerator::target(mass);
        Self(inner)
    }

    /// Return the `Debug` rendering of the wrapped generator.
    fn __repr__(&self) -> String {
        let Self(inner) = self;
        format!("{inner:?}")
    }
}
/// Python-facing wrapper around [`CompositeGenerator`], exposed as
/// `laddu.CompositeGenerator`.
#[pyclass(name = "CompositeGenerator", module = "laddu", from_py_object)]
#[derive(Clone, Debug)]
pub struct PyCompositeGenerator(pub CompositeGenerator);

#[pymethods]
impl PyCompositeGenerator {
    /// Python constructor; delegates to `CompositeGenerator::new(min_mass, max_mass)`.
    #[new]
    fn new(min_mass: f64, max_mass: f64) -> Self {
        let inner = CompositeGenerator::new(min_mass, max_mass);
        Self(inner)
    }

    /// Return the `Debug` rendering of the wrapped generator.
    fn __repr__(&self) -> String {
        let Self(inner) = self;
        format!("{inner:?}")
    }
}
/// Python-facing wrapper around [`StableGenerator`], exposed as
/// `laddu.StableGenerator`.
#[pyclass(name = "StableGenerator", module = "laddu", from_py_object)]
#[derive(Clone, Debug)]
pub struct PyStableGenerator(pub StableGenerator);

#[pymethods]
impl PyStableGenerator {
    /// Python constructor; delegates to `StableGenerator::new(mass)`.
    #[new]
    fn new(mass: f64) -> Self {
        let inner = StableGenerator::new(mass);
        Self(inner)
    }

    /// Return the `Debug` rendering of the wrapped generator.
    fn __repr__(&self) -> String {
        let Self(inner) = self;
        format!("{inner:?}")
    }
}
/// Python-facing wrapper around [`Reconstruction`], exposed as
/// `laddu.Reconstruction`.
///
/// One static constructor is provided per wrapped enum variant.
#[pyclass(name = "Reconstruction", module = "laddu", from_py_object)]
#[derive(Clone, Debug)]
pub struct PyReconstruction(pub Reconstruction);

#[pymethods]
impl PyReconstruction {
    /// Wrap `Reconstruction::Stored`.
    #[staticmethod]
    fn stored() -> Self {
        let variant = Reconstruction::Stored;
        Self(variant)
    }

    /// Wrap `Reconstruction::Fixed`, carrying the four-vector held by `p4`.
    #[staticmethod]
    fn fixed(p4: &PyVec4) -> Self {
        let variant = Reconstruction::Fixed(p4.0);
        Self(variant)
    }

    /// Wrap `Reconstruction::Missing`.
    #[staticmethod]
    fn missing() -> Self {
        let variant = Reconstruction::Missing;
        Self(variant)
    }

    /// Wrap `Reconstruction::Composite`.
    #[staticmethod]
    fn composite() -> Self {
        let variant = Reconstruction::Composite;
        Self(variant)
    }

    /// Return the `Debug` rendering of the wrapped value.
    fn __repr__(&self) -> String {
        let Self(inner) = self;
        format!("{inner:?}")
    }
}
/// Python-facing wrapper around [`ParticleSpecies`], exposed as
/// `laddu.ParticleSpecies`.
///
/// The wrapped enum is either a `Code` (an `id` with an optional namespace) or
/// a `Label`; the getters return `None` for fields the active variant lacks.
#[pyclass(name = "ParticleSpecies", module = "laddu", from_py_object)]
#[derive(Clone, Debug)]
pub struct PyParticleSpecies(pub ParticleSpecies);

#[pymethods]
impl PyParticleSpecies {
    /// Delegate to `ParticleSpecies::code(id)`.
    #[staticmethod]
    fn code(id: i64) -> Self {
        let species = ParticleSpecies::code(id);
        Self(species)
    }

    /// Delegate to `ParticleSpecies::with_namespace(namespace, id)`.
    #[staticmethod]
    fn with_namespace(namespace: &str, id: i64) -> Self {
        let species = ParticleSpecies::with_namespace(namespace, id);
        Self(species)
    }

    /// Delegate to `ParticleSpecies::label(label)`.
    #[staticmethod]
    fn label(label: &str) -> Self {
        let species = ParticleSpecies::label(label);
        Self(species)
    }

    /// The numeric id of a `Code` species, or `None` for a `Label`.
    #[getter]
    fn id(&self) -> Option<i64> {
        match &self.0 {
            ParticleSpecies::Label(_) => None,
            ParticleSpecies::Code { id: code, .. } => Some(*code),
        }
    }

    /// The namespace stored on a `Code` species (which may itself be absent),
    /// or `None` for a `Label`.
    #[getter]
    fn namespace(&self) -> Option<String> {
        match &self.0 {
            ParticleSpecies::Label(_) => None,
            ParticleSpecies::Code { namespace: ns, .. } => ns.clone(),
        }
    }

    /// The text of a `Label` species, or `None` for a `Code`.
    #[getter]
    fn label_value(&self) -> Option<String> {
        match &self.0 {
            ParticleSpecies::Label(text) => Some(text.clone()),
            ParticleSpecies::Code { .. } => None,
        }
    }

    /// Return the `Debug` rendering of the wrapped species.
    fn __repr__(&self) -> String {
        let Self(inner) = self;
        format!("{inner:?}")
    }
}
#[pyclass(name = "GeneratedParticle", module = "laddu", from_py_object)]
#[derive(Clone, Debug)]
pub struct PyGeneratedParticle(pub GeneratedParticle);
#[pymethods]
impl PyGeneratedParticle {
#[staticmethod]
fn initial(
id: &str,
generator: &PyInitialGenerator,
reconstruction: &PyReconstruction,
) -> Self {
Self(GeneratedParticle::initial(
id,
generator.0.clone(),
reconstruction.0.clone(),
))
}
#[staticmethod]
fn stable(id: &str, generator: &PyStableGenerator, reconstruction: &PyReconstruction) -> Self {
Self(GeneratedParticle::stable(
id,
generator.0.clone(),
reconstruction.0.clone(),
))
}
#[staticmethod]
fn composite(
id: &str,
generator: &PyCompositeGenerator,
daughters: &Bound<'_, PyTuple>,
reconstruction: &PyReconstruction,
) -> PyResult<Self> {
if daughters.len() != 2 {
return Err(PyValueError::new_err(
"composite particles require exactly two ordered daughters",
));
}
let daughter_1 = daughters.get_item(0)?.extract::<Self>()?;
let daughter_2 = daughters.get_item(1)?.extract::<Self>()?;
Ok(Self(GeneratedParticle::composite(
id,
generator.0.clone(),
(&daughter_1.0, &daughter_2.0),
reconstruction.0.clone(),
)))
}
fn with_species(&self, species: &PyParticleSpecies) -> Self {
Self(self.0.clone().with_species(species.0.clone()))
}
#[getter]
fn id(&self) -> String {
self.0.id().to_string()
}
#[getter]
fn species(&self) -> Option<PyParticleSpecies> {
self.0.species().cloned().map(PyParticleSpecies)
}
fn __repr__(&self) -> String {
format!("{:?}", self.0)
}
}
/// Python-facing wrapper around [`GeneratedReaction`], exposed as
/// `laddu.GeneratedReaction`.
#[pyclass(name = "GeneratedReaction", module = "laddu", from_py_object)]
#[derive(Clone, Debug)]
pub struct PyGeneratedReaction(pub GeneratedReaction);

#[pymethods]
impl PyGeneratedReaction {
    /// Delegate to `GeneratedReaction::two_to_two`, passing clones of the four
    /// particles (in the order `p1`..`p4`) and of `tdist`.
    ///
    /// # Errors
    ///
    /// Any error from the underlying constructor is propagated to Python.
    #[staticmethod]
    fn two_to_two(
        p1: &PyGeneratedParticle,
        p2: &PyGeneratedParticle,
        p3: &PyGeneratedParticle,
        p4: &PyGeneratedParticle,
        tdist: &PyMandelstamTDistribution,
    ) -> PyResult<Self> {
        let reaction = GeneratedReaction::two_to_two(
            p1.0.clone(),
            p2.0.clone(),
            p3.0.clone(),
            p4.0.clone(),
            tdist.0.clone(),
        )?;
        Ok(Self(reaction))
    }

    /// The four-momentum labels reported by the wrapped reaction.
    fn p4_labels(&self) -> Vec<String> {
        self.0.p4_labels()
    }

    /// The wrapped reaction's particle layouts, each wrapped for Python.
    fn particle_layouts(&self) -> Vec<PyGeneratedParticleLayout> {
        let layouts = self.0.particle_layouts();
        layouts
            .into_iter()
            .map(PyGeneratedParticleLayout)
            .collect()
    }

    /// Wrap the result of `GeneratedReaction::reconstructed_reaction`.
    ///
    /// # Errors
    ///
    /// Any error from the underlying call is propagated to Python.
    fn reconstructed_reaction(&self) -> PyResult<PyReaction> {
        let reaction = self.0.reconstructed_reaction()?;
        Ok(PyReaction(reaction))
    }

    /// Return the `Debug` rendering of the wrapped reaction.
    fn __repr__(&self) -> String {
        let Self(inner) = self;
        format!("{inner:?}")
    }
}
/// Python-facing wrapper around [`GeneratedStorage`], exposed as
/// `laddu.GeneratedStorage`.
#[pyclass(name = "GeneratedStorage", module = "laddu", from_py_object)]
#[derive(Clone, Debug)]
pub struct PyGeneratedStorage(pub GeneratedStorage);

#[pymethods]
impl PyGeneratedStorage {
    /// Delegate to `GeneratedStorage::all()`.
    #[staticmethod]
    fn all() -> Self {
        let selection = GeneratedStorage::all();
        Self(selection)
    }

    /// Delegate to `GeneratedStorage::only(ids)`.
    #[staticmethod]
    fn only(ids: Vec<String>) -> Self {
        let selection = GeneratedStorage::only(ids);
        Self(selection)
    }

    /// Return the `Debug` rendering of the wrapped value.
    fn __repr__(&self) -> String {
        let Self(inner) = self;
        format!("{inner:?}")
    }
}
#[pyclass(name = "GeneratedParticleLayout", module = "laddu", from_py_object)]
#[derive(Clone, Debug)]
pub struct PyGeneratedParticleLayout(pub GeneratedParticleLayout);
#[pymethods]
impl PyGeneratedParticleLayout {
#[getter]
fn id(&self) -> String {
self.0.id().to_string()
}
#[getter]
fn product_id(&self) -> usize {
self.0.product_id()
}
#[getter]
fn parent_id(&self) -> Option<usize> {
self.0.parent_id()
}
#[getter]
fn species(&self) -> Option<PyParticleSpecies> {
self.0.species().cloned().map(PyParticleSpecies)
}
#[getter]
fn p4_label(&self) -> Option<String> {
self.0.p4_label().map(str::to_string)
}
#[getter]
fn produced_vertex_id(&self) -> Option<usize> {
self.0.produced_vertex_id()
}
#[getter]
fn decay_vertex_id(&self) -> Option<usize> {
self.0.decay_vertex_id()
}
fn __repr__(&self) -> String {
format!("{:?}", self.0)
}
}
/// Python-facing, read-only view of a [`GeneratedVertexLayout`], exposed as
/// `laddu.GeneratedVertexLayout`.
#[pyclass(name = "GeneratedVertexLayout", module = "laddu", from_py_object)]
#[derive(Clone, Debug)]
pub struct PyGeneratedVertexLayout(pub GeneratedVertexLayout);

#[pymethods]
impl PyGeneratedVertexLayout {
    /// The vertex id reported by the wrapped layout.
    #[getter]
    fn vertex_id(&self) -> usize {
        let Self(layout) = self;
        layout.vertex_id()
    }

    /// The vertex kind as a string: `"Production"` or `"Decay"`.
    #[getter]
    fn kind(&self) -> &'static str {
        // Exhaustive on purpose: a new variant should fail to compile here.
        let name: &'static str = match self.0.kind() {
            GeneratedVertexKind::Decay => "Decay",
            GeneratedVertexKind::Production => "Production",
        };
        name
    }

    /// An owned copy of the incoming product ids.
    #[getter]
    fn incoming_product_ids(&self) -> Vec<usize> {
        let Self(layout) = self;
        layout.incoming_product_ids().to_vec()
    }

    /// An owned copy of the outgoing product ids.
    #[getter]
    fn outgoing_product_ids(&self) -> Vec<usize> {
        let Self(layout) = self;
        layout.outgoing_product_ids().to_vec()
    }

    /// Return the `Debug` rendering of the wrapped layout.
    fn __repr__(&self) -> String {
        let Self(layout) = self;
        format!("{layout:?}")
    }
}
/// Python-facing, read-only view of a [`GeneratedEventLayout`], exposed as
/// `laddu.GeneratedEventLayout`.
///
/// Each method forwards to the accessor of the same name on the wrapped layout
/// and hands back cloned, Python-wrapped copies of whatever it yields.
#[pyclass(name = "GeneratedEventLayout", module = "laddu", from_py_object)]
#[derive(Clone, Debug)]
pub struct PyGeneratedEventLayout(pub GeneratedEventLayout);

#[pymethods]
impl PyGeneratedEventLayout {
    /// An owned copy of the four-momentum labels.
    #[getter]
    fn p4_labels(&self) -> Vec<String> {
        let Self(layout) = self;
        layout.p4_labels().to_vec()
    }

    /// An owned copy of the auxiliary labels.
    #[getter]
    fn aux_labels(&self) -> Vec<String> {
        let Self(layout) = self;
        layout.aux_labels().to_vec()
    }

    /// Copies of all particle layouts, in the order the wrapped layout stores them.
    #[getter]
    fn particles(&self) -> Vec<PyGeneratedParticleLayout> {
        self.0
            .particles()
            .iter()
            .map(|particle| PyGeneratedParticleLayout(particle.clone()))
            .collect()
    }

    /// A copy of the particle layout looked up by `id`, or `None` when the
    /// wrapped layout returns nothing for it.
    fn particle(&self, id: &str) -> Option<PyGeneratedParticleLayout> {
        self.0
            .particle(id)
            .map(|particle| PyGeneratedParticleLayout(particle.clone()))
    }

    /// A copy of the particle layout looked up by `product_id`, or `None` when
    /// the wrapped layout returns nothing for it.
    fn product(&self, product_id: usize) -> Option<PyGeneratedParticleLayout> {
        self.0
            .product(product_id)
            .map(|particle| PyGeneratedParticleLayout(particle.clone()))
    }

    /// Copies of all vertex layouts, in the order the wrapped layout stores them.
    #[getter]
    fn vertices(&self) -> Vec<PyGeneratedVertexLayout> {
        self.0
            .vertices()
            .iter()
            .map(|vertex| PyGeneratedVertexLayout(vertex.clone()))
            .collect()
    }

    /// A copy of the vertex layout looked up by `vertex_id`, or `None` when the
    /// wrapped layout returns nothing for it.
    fn vertex(&self, vertex_id: usize) -> Option<PyGeneratedVertexLayout> {
        self.0
            .vertex(vertex_id)
            .map(|vertex| PyGeneratedVertexLayout(vertex.clone()))
    }

    /// A copy of the production vertex layout, if the wrapped layout has one.
    fn production_vertex(&self) -> Option<PyGeneratedVertexLayout> {
        self.0
            .production_vertex()
            .map(|vertex| PyGeneratedVertexLayout(vertex.clone()))
    }

    /// Copies of the layouts returned by `decay_products(parent_product_id)`.
    fn decay_products(&self, parent_product_id: usize) -> Vec<PyGeneratedParticleLayout> {
        let daughters = self.0.decay_products(parent_product_id);
        daughters
            .into_iter()
            .cloned()
            .map(PyGeneratedParticleLayout)
            .collect()
    }

    /// Copies of the layouts returned by `production_incoming()`.
    fn production_incoming(&self) -> Vec<PyGeneratedParticleLayout> {
        let incoming = self.0.production_incoming();
        incoming
            .into_iter()
            .cloned()
            .map(PyGeneratedParticleLayout)
            .collect()
    }

    /// Copies of the layouts returned by `production_outgoing()`.
    fn production_outgoing(&self) -> Vec<PyGeneratedParticleLayout> {
        let outgoing = self.0.production_outgoing();
        outgoing
            .into_iter()
            .cloned()
            .map(PyGeneratedParticleLayout)
            .collect()
    }

    /// Return the `Debug` rendering of the wrapped layout.
    fn __repr__(&self) -> String {
        let Self(layout) = self;
        format!("{layout:?}")
    }
}
/// Python-facing wrapper around [`GeneratedBatch`], exposed as
/// `laddu.GeneratedBatch`.
///
/// Each getter returns a fresh clone of the corresponding part of the batch.
#[pyclass(name = "GeneratedBatch", module = "laddu", from_py_object)]
#[derive(Clone, Debug)]
pub struct PyGeneratedBatch(pub GeneratedBatch);

#[pymethods]
impl PyGeneratedBatch {
    /// A clone of the batch's dataset, placed in a new `Arc`.
    #[getter]
    fn dataset(&self) -> PyDataset {
        let dataset = self.0.dataset().clone();
        PyDataset(Arc::new(dataset))
    }

    /// A clone of the reaction associated with the batch.
    #[getter]
    fn reaction(&self) -> PyGeneratedReaction {
        let reaction = self.0.reaction().clone();
        PyGeneratedReaction(reaction)
    }

    /// A clone of the event layout associated with the batch.
    #[getter]
    fn layout(&self) -> PyGeneratedEventLayout {
        let layout = self.0.layout().clone();
        PyGeneratedEventLayout(layout)
    }

    /// Return the `Debug` rendering of the wrapped batch.
    fn __repr__(&self) -> String {
        let Self(inner) = self;
        format!("{inner:?}")
    }
}
#[pyclass(
name = "GeneratedBatchIter",
module = "laddu",
unsendable,
skip_from_py_object
)]
pub struct PyGeneratedBatchIter {
iter: Box<dyn Iterator<Item = laddu_core::LadduResult<GeneratedBatch>>>,
}
#[pymethods]
impl PyGeneratedBatchIter {
fn __iter__(slf: PyRef<'_, Self>) -> Py<PyGeneratedBatchIter> {
slf.into()
}
fn __next__(&mut self) -> PyResult<Option<PyGeneratedBatch>> {
match self.iter.next() {
Some(Ok(batch)) => Ok(Some(PyGeneratedBatch(batch))),
Some(Err(err)) => Err(PyErr::from(err)),
None => Ok(None),
}
}
}
#[pyclass(name = "EventGenerator", module = "laddu", from_py_object)]
#[derive(Clone, Debug)]
pub struct PyEventGenerator(pub EventGenerator);
#[pymethods]
impl PyEventGenerator {
#[new]
#[pyo3(signature = (reaction, aux_generators=None, seed=None, storage=None))]
fn new(
reaction: &PyGeneratedReaction,
aux_generators: Option<HashMap<String, PyDistribution>>,
seed: Option<u64>,
storage: Option<&PyGeneratedStorage>,
) -> PyResult<Self> {
let generator = EventGenerator::new(
reaction.0.clone(),
aux_generators
.unwrap_or_default()
.into_iter()
.map(|(name, distribution)| (name, distribution.0))
.collect(),
seed,
);
let generator = if let Some(storage) = storage {
generator.with_storage(storage.0.clone())?
} else {
generator
};
Ok(Self(generator))
}
fn generate_batch(&self, n_events: usize) -> PyResult<PyGeneratedBatch> {
Ok(PyGeneratedBatch(self.0.generate_batch(n_events)?))
}
fn generate_batches(
&self,
total_events: usize,
batch_size: usize,
) -> PyResult<PyGeneratedBatchIter> {
Ok(PyGeneratedBatchIter {
iter: Box::new(self.0.generate_batches(total_events, batch_size)?),
})
}
fn generate_dataset(&self, n_events: usize) -> PyResult<PyDataset> {
Ok(PyDataset(Arc::new(self.0.generate_dataset(n_events)?)))
}
fn __repr__(&self) -> String {
format!("{:?}", self.0)
}
}