Skip to main content

laddu_python/
generation.rs

1use std::{collections::HashMap, sync::Arc};
2
3use laddu_generation::{
4    CompositeGenerator, Distribution, EventGenerator, GeneratedBatch, GeneratedEventLayout,
5    GeneratedParticle, GeneratedParticleLayout, GeneratedReaction, GeneratedStorage,
6    GeneratedVertexKind, GeneratedVertexLayout, HistogramSampler, InitialGenerator,
7    MandelstamTDistribution, ParticleSpecies, Reconstruction, StableGenerator,
8};
9use pyo3::{exceptions::PyValueError, prelude::*, types::PyTuple};
10
11use crate::{data::PyDataset, math::PyHistogram, variables::PyReaction, vectors::PyVec4};
12
13/// A scalar distribution used by generated auxiliary columns.
14#[pyclass(name = "Distribution", module = "laddu", from_py_object)]
15#[derive(Clone, Debug)]
16pub struct PyDistribution(pub Distribution);
17
18#[pymethods]
19impl PyDistribution {
20    /// Construct a fixed scalar distribution.
21    #[staticmethod]
22    fn fixed(value: f64) -> Self {
23        Self(Distribution::Fixed(value))
24    }
25
26    /// Construct a uniform scalar distribution.
27    #[staticmethod]
28    fn uniform(min: f64, max: f64) -> PyResult<Self> {
29        if max <= min {
30            return Err(PyValueError::new_err(
31                "`max` must be greater than `min` for a uniform distribution",
32            ));
33        }
34        Ok(Self(Distribution::Uniform { min, max }))
35    }
36
37    /// Construct a normal scalar distribution.
38    #[staticmethod]
39    fn normal(mu: f64, sigma: f64) -> PyResult<Self> {
40        if sigma <= 0.0 {
41            return Err(PyValueError::new_err(
42                "`sigma` must be positive for a normal distribution",
43            ));
44        }
45        Ok(Self(Distribution::Normal { mu, sigma }))
46    }
47
48    /// Construct an exponential scalar distribution.
49    #[staticmethod]
50    fn exponential(slope: f64) -> PyResult<Self> {
51        if slope <= 0.0 {
52            return Err(PyValueError::new_err(
53                "`slope` must be positive for an exponential distribution",
54            ));
55        }
56        Ok(Self(Distribution::Exponential { slope }))
57    }
58
59    /// Construct a histogram-sampled scalar distribution.
60    #[staticmethod]
61    fn histogram(histogram: &PyHistogram) -> PyResult<Self> {
62        Ok(Self(Distribution::Histogram(HistogramSampler::new(
63            histogram.0.clone(),
64        )?)))
65    }
66
67    fn __repr__(&self) -> String {
68        format!("{:?}", self.0)
69    }
70}
71
72/// A Mandelstam-t distribution for generated two-to-two reactions.
73#[pyclass(name = "MandelstamTDistribution", module = "laddu", from_py_object)]
74#[derive(Clone, Debug)]
75pub struct PyMandelstamTDistribution(pub MandelstamTDistribution);
76
77#[pymethods]
78impl PyMandelstamTDistribution {
79    /// Construct an exponential Mandelstam-t distribution.
80    #[staticmethod]
81    fn exponential(slope: f64) -> PyResult<Self> {
82        if slope <= 0.0 {
83            return Err(PyValueError::new_err(
84                "`slope` must be positive for an exponential distribution",
85            ));
86        }
87        Ok(Self(MandelstamTDistribution::Exponential { slope }))
88    }
89
90    /// Construct a histogram-sampled Mandelstam-t distribution.
91    #[staticmethod]
92    fn histogram(histogram: &PyHistogram) -> PyResult<Self> {
93        Ok(Self(MandelstamTDistribution::Histogram(
94            HistogramSampler::new(histogram.0.clone())?,
95        )))
96    }
97
98    fn __repr__(&self) -> String {
99        format!("{:?}", self.0)
100    }
101}
102
103/// Generator settings for an initial generated particle.
104#[pyclass(name = "InitialGenerator", module = "laddu", from_py_object)]
105#[derive(Clone, Debug)]
106pub struct PyInitialGenerator(pub InitialGenerator);
107
108#[pymethods]
109impl PyInitialGenerator {
110    /// Construct a beam with fixed energy.
111    #[staticmethod]
112    fn beam_with_fixed_energy(mass: f64, energy: f64) -> Self {
113        Self(InitialGenerator::beam_with_fixed_energy(mass, energy))
114    }
115
116    /// Construct a beam with uniformly sampled energy.
117    #[staticmethod]
118    fn beam(mass: f64, min_energy: f64, max_energy: f64) -> Self {
119        Self(InitialGenerator::beam(mass, min_energy, max_energy))
120    }
121
122    /// Construct a beam with histogram-sampled energy.
123    #[staticmethod]
124    fn beam_with_energy_histogram(mass: f64, energy: &PyHistogram) -> PyResult<Self> {
125        Ok(Self(InitialGenerator::beam_with_energy_histogram(
126            mass,
127            energy.0.clone(),
128        )?))
129    }
130
131    /// Construct a target at rest.
132    #[staticmethod]
133    fn target(mass: f64) -> Self {
134        Self(InitialGenerator::target(mass))
135    }
136
137    fn __repr__(&self) -> String {
138        format!("{:?}", self.0)
139    }
140}
141
142/// Generator settings for a generated composite particle.
143#[pyclass(name = "CompositeGenerator", module = "laddu", from_py_object)]
144#[derive(Clone, Debug)]
145pub struct PyCompositeGenerator(pub CompositeGenerator);
146
147#[pymethods]
148impl PyCompositeGenerator {
149    /// Construct a composite mass generator with a uniform mass range.
150    #[new]
151    fn new(min_mass: f64, max_mass: f64) -> Self {
152        Self(CompositeGenerator::new(min_mass, max_mass))
153    }
154
155    fn __repr__(&self) -> String {
156        format!("{:?}", self.0)
157    }
158}
159
160/// Generator settings for a stable generated particle.
161#[pyclass(name = "StableGenerator", module = "laddu", from_py_object)]
162#[derive(Clone, Debug)]
163pub struct PyStableGenerator(pub StableGenerator);
164
165#[pymethods]
166impl PyStableGenerator {
167    /// Construct a fixed-mass stable-particle generator.
168    #[new]
169    fn new(mass: f64) -> Self {
170        Self(StableGenerator::new(mass))
171    }
172
173    fn __repr__(&self) -> String {
174        format!("{:?}", self.0)
175    }
176}
177
178/// Reconstruction metadata for a generated particle.
179#[pyclass(name = "Reconstruction", module = "laddu", from_py_object)]
180#[derive(Clone, Debug)]
181pub struct PyReconstruction(pub Reconstruction);
182
183#[pymethods]
184impl PyReconstruction {
185    /// Mark a generated particle as stored under its generated ID.
186    #[staticmethod]
187    fn stored() -> Self {
188        Self(Reconstruction::Stored)
189    }
190
191    /// Mark a generated particle as fixed in the reconstructed reaction.
192    #[staticmethod]
193    fn fixed(p4: &PyVec4) -> Self {
194        Self(Reconstruction::Fixed(p4.0))
195    }
196
197    /// Mark a generated particle as missing in the reconstructed reaction.
198    #[staticmethod]
199    fn missing() -> Self {
200        Self(Reconstruction::Missing)
201    }
202
203    /// Mark a generated particle as reconstructed from its generated daughters.
204    #[staticmethod]
205    fn composite() -> Self {
206        Self(Reconstruction::Composite)
207    }
208
209    fn __repr__(&self) -> String {
210        format!("{:?}", self.0)
211    }
212}
213
214/// Experiment-neutral metadata describing a generated particle species.
215#[pyclass(name = "ParticleSpecies", module = "laddu", from_py_object)]
216#[derive(Clone, Debug)]
217pub struct PyParticleSpecies(pub ParticleSpecies);
218
219#[pymethods]
220impl PyParticleSpecies {
221    /// Construct a species from a numeric code with no namespace.
222    #[staticmethod]
223    fn code(id: i64) -> Self {
224        Self(ParticleSpecies::code(id))
225    }
226
227    /// Construct a species from a numeric code in an explicit namespace.
228    #[staticmethod]
229    fn with_namespace(namespace: &str, id: i64) -> Self {
230        Self(ParticleSpecies::with_namespace(namespace, id))
231    }
232
233    /// Construct a species from a free-form label.
234    #[staticmethod]
235    fn label(label: &str) -> Self {
236        Self(ParticleSpecies::label(label))
237    }
238
239    /// The numeric species code, if this is a code-based species.
240    #[getter]
241    fn id(&self) -> Option<i64> {
242        match &self.0 {
243            ParticleSpecies::Code { id, .. } => Some(*id),
244            ParticleSpecies::Label(_) => None,
245        }
246    }
247
248    /// The numeric species namespace, if this is a namespaced code-based species.
249    #[getter]
250    fn namespace(&self) -> Option<String> {
251        match &self.0 {
252            ParticleSpecies::Code { namespace, .. } => namespace.clone(),
253            ParticleSpecies::Label(_) => None,
254        }
255    }
256
257    /// The species label, if this is a label-based species.
258    #[getter]
259    fn label_value(&self) -> Option<String> {
260        match &self.0 {
261            ParticleSpecies::Code { .. } => None,
262            ParticleSpecies::Label(label) => Some(label.clone()),
263        }
264    }
265
266    fn __repr__(&self) -> String {
267        format!("{:?}", self.0)
268    }
269}
270
271/// A generated particle with generation and reconstruction metadata.
272#[pyclass(name = "GeneratedParticle", module = "laddu", from_py_object)]
273#[derive(Clone, Debug)]
274pub struct PyGeneratedParticle(pub GeneratedParticle);
275
276#[pymethods]
277impl PyGeneratedParticle {
278    /// Construct an initial generated particle.
279    #[staticmethod]
280    fn initial(
281        id: &str,
282        generator: &PyInitialGenerator,
283        reconstruction: &PyReconstruction,
284    ) -> Self {
285        Self(GeneratedParticle::initial(
286            id,
287            generator.0.clone(),
288            reconstruction.0.clone(),
289        ))
290    }
291
292    /// Construct a stable generated particle.
293    #[staticmethod]
294    fn stable(id: &str, generator: &PyStableGenerator, reconstruction: &PyReconstruction) -> Self {
295        Self(GeneratedParticle::stable(
296            id,
297            generator.0.clone(),
298            reconstruction.0.clone(),
299        ))
300    }
301
302    /// Construct a generated composite from exactly two ordered daughters.
303    #[staticmethod]
304    fn composite(
305        id: &str,
306        generator: &PyCompositeGenerator,
307        daughters: &Bound<'_, PyTuple>,
308        reconstruction: &PyReconstruction,
309    ) -> PyResult<Self> {
310        if daughters.len() != 2 {
311            return Err(PyValueError::new_err(
312                "composite particles require exactly two ordered daughters",
313            ));
314        }
315        let daughter_1 = daughters.get_item(0)?.extract::<Self>()?;
316        let daughter_2 = daughters.get_item(1)?.extract::<Self>()?;
317        Ok(Self(GeneratedParticle::composite(
318            id,
319            generator.0.clone(),
320            (&daughter_1.0, &daughter_2.0),
321            reconstruction.0.clone(),
322        )))
323    }
324
325    /// Return a copy of this generated particle with species metadata attached.
326    fn with_species(&self, species: &PyParticleSpecies) -> Self {
327        Self(self.0.clone().with_species(species.0.clone()))
328    }
329
330    /// The generated particle ID.
331    #[getter]
332    fn id(&self) -> String {
333        self.0.id().to_string()
334    }
335
336    /// Optional species metadata for this generated particle.
337    #[getter]
338    fn species(&self) -> Option<PyParticleSpecies> {
339        self.0.species().cloned().map(PyParticleSpecies)
340    }
341
342    fn __repr__(&self) -> String {
343        format!("{:?}", self.0)
344    }
345}
346
347/// A generated reaction layout.
348#[pyclass(name = "GeneratedReaction", module = "laddu", from_py_object)]
349#[derive(Clone, Debug)]
350pub struct PyGeneratedReaction(pub GeneratedReaction);
351
352#[pymethods]
353impl PyGeneratedReaction {
354    /// Construct a generated two-to-two reaction.
355    #[staticmethod]
356    fn two_to_two(
357        p1: &PyGeneratedParticle,
358        p2: &PyGeneratedParticle,
359        p3: &PyGeneratedParticle,
360        p4: &PyGeneratedParticle,
361        tdist: &PyMandelstamTDistribution,
362    ) -> PyResult<Self> {
363        Ok(Self(GeneratedReaction::two_to_two(
364            p1.0.clone(),
365            p2.0.clone(),
366            p3.0.clone(),
367            p4.0.clone(),
368            tdist.0.clone(),
369        )?))
370    }
371
372    /// Return generated p4 labels.
373    fn p4_labels(&self) -> Vec<String> {
374        self.0.p4_labels()
375    }
376
377    /// Return generated particle layout entries in stable product-ID order.
378    fn particle_layouts(&self) -> Vec<PyGeneratedParticleLayout> {
379        self.0
380            .particle_layouts()
381            .into_iter()
382            .map(PyGeneratedParticleLayout)
383            .collect()
384    }
385
386    /// Build the reconstructed reaction corresponding to this generated layout.
387    fn reconstructed_reaction(&self) -> PyResult<PyReaction> {
388        Ok(PyReaction(self.0.reconstructed_reaction()?))
389    }
390
391    fn __repr__(&self) -> String {
392        format!("{:?}", self.0)
393    }
394}
395
396/// Selects which generated particle p4s are written into generated datasets.
397#[pyclass(name = "GeneratedStorage", module = "laddu", from_py_object)]
398#[derive(Clone, Debug)]
399pub struct PyGeneratedStorage(pub GeneratedStorage);
400
401#[pymethods]
402impl PyGeneratedStorage {
403    /// Store every generated particle p4.
404    #[staticmethod]
405    fn all() -> Self {
406        Self(GeneratedStorage::all())
407    }
408
409    /// Store only the listed generated particle IDs.
410    #[staticmethod]
411    fn only(ids: Vec<String>) -> Self {
412        Self(GeneratedStorage::only(ids))
413    }
414
415    fn __repr__(&self) -> String {
416        format!("{:?}", self.0)
417    }
418}
419
420/// Metadata for one generated particle in a generated event layout.
421#[pyclass(name = "GeneratedParticleLayout", module = "laddu", from_py_object)]
422#[derive(Clone, Debug)]
423pub struct PyGeneratedParticleLayout(pub GeneratedParticleLayout);
424
425#[pymethods]
426impl PyGeneratedParticleLayout {
427    /// The generated particle identifier.
428    #[getter]
429    fn id(&self) -> String {
430        self.0.id().to_string()
431    }
432
433    /// The zero-based stable product ID in generated-layout order.
434    #[getter]
435    fn product_id(&self) -> usize {
436        self.0.product_id()
437    }
438
439    /// The decay-parent product ID, or None if this particle has no decay parent.
440    #[getter]
441    fn parent_id(&self) -> Option<usize> {
442        self.0.parent_id()
443    }
444
445    /// Optional species metadata associated with this generated particle.
446    #[getter]
447    fn species(&self) -> Option<PyParticleSpecies> {
448        self.0.species().cloned().map(PyParticleSpecies)
449    }
450
451    /// The dataset p4 label associated with this particle, if stored in the batch.
452    #[getter]
453    fn p4_label(&self) -> Option<String> {
454        self.0.p4_label().map(str::to_string)
455    }
456
457    /// The vertex ID where this particle was produced, if any.
458    #[getter]
459    fn produced_vertex_id(&self) -> Option<usize> {
460        self.0.produced_vertex_id()
461    }
462
463    /// The vertex ID where this particle decays, if it is a generated parent.
464    #[getter]
465    fn decay_vertex_id(&self) -> Option<usize> {
466        self.0.decay_vertex_id()
467    }
468
469    fn __repr__(&self) -> String {
470        format!("{:?}", self.0)
471    }
472}
473
474/// Metadata for one generated vertex in a generated event layout.
475#[pyclass(name = "GeneratedVertexLayout", module = "laddu", from_py_object)]
476#[derive(Clone, Debug)]
477pub struct PyGeneratedVertexLayout(pub GeneratedVertexLayout);
478
479#[pymethods]
480impl PyGeneratedVertexLayout {
481    /// The zero-based stable vertex ID in generated-layout order.
482    #[getter]
483    fn vertex_id(&self) -> usize {
484        self.0.vertex_id()
485    }
486
487    /// The semantic vertex kind.
488    #[getter]
489    fn kind(&self) -> &'static str {
490        match self.0.kind() {
491            GeneratedVertexKind::Production => "Production",
492            GeneratedVertexKind::Decay => "Decay",
493        }
494    }
495
496    /// Product IDs entering this vertex.
497    #[getter]
498    fn incoming_product_ids(&self) -> Vec<usize> {
499        self.0.incoming_product_ids().to_vec()
500    }
501
502    /// Product IDs leaving this vertex.
503    #[getter]
504    fn outgoing_product_ids(&self) -> Vec<usize> {
505        self.0.outgoing_product_ids().to_vec()
506    }
507
508    fn __repr__(&self) -> String {
509        format!("{:?}", self.0)
510    }
511}
512
513/// Metadata describing the columns in a generated event batch.
514#[pyclass(name = "GeneratedEventLayout", module = "laddu", from_py_object)]
515#[derive(Clone, Debug)]
516pub struct PyGeneratedEventLayout(pub GeneratedEventLayout);
517
518#[pymethods]
519impl PyGeneratedEventLayout {
520    /// Generated p4 column labels in dataset order.
521    #[getter]
522    fn p4_labels(&self) -> Vec<String> {
523        self.0.p4_labels().to_vec()
524    }
525
526    /// Generated auxiliary column labels in dataset order.
527    #[getter]
528    fn aux_labels(&self) -> Vec<String> {
529        self.0.aux_labels().to_vec()
530    }
531
532    /// Generated particle layout entries in stable product-ID order.
533    #[getter]
534    fn particles(&self) -> Vec<PyGeneratedParticleLayout> {
535        self.0
536            .particles()
537            .iter()
538            .cloned()
539            .map(PyGeneratedParticleLayout)
540            .collect()
541    }
542
543    /// Return the generated particle layout for a generated particle ID.
544    fn particle(&self, id: &str) -> Option<PyGeneratedParticleLayout> {
545        self.0.particle(id).cloned().map(PyGeneratedParticleLayout)
546    }
547
548    /// Return the generated particle layout for a stable product ID.
549    fn product(&self, product_id: usize) -> Option<PyGeneratedParticleLayout> {
550        self.0
551            .product(product_id)
552            .cloned()
553            .map(PyGeneratedParticleLayout)
554    }
555
556    /// Generated vertex layout entries in stable vertex-ID order.
557    #[getter]
558    fn vertices(&self) -> Vec<PyGeneratedVertexLayout> {
559        self.0
560            .vertices()
561            .iter()
562            .cloned()
563            .map(PyGeneratedVertexLayout)
564            .collect()
565    }
566
567    /// Return the generated vertex layout for a stable vertex ID.
568    fn vertex(&self, vertex_id: usize) -> Option<PyGeneratedVertexLayout> {
569        self.0
570            .vertex(vertex_id)
571            .cloned()
572            .map(PyGeneratedVertexLayout)
573    }
574
575    /// Return the production vertex layout, if the generated layout has one.
576    fn production_vertex(&self) -> Option<PyGeneratedVertexLayout> {
577        self.0
578            .production_vertex()
579            .cloned()
580            .map(PyGeneratedVertexLayout)
581    }
582
583    /// Return the generated decay daughters of a parent product ID.
584    fn decay_products(&self, parent_product_id: usize) -> Vec<PyGeneratedParticleLayout> {
585        self.0
586            .decay_products(parent_product_id)
587            .into_iter()
588            .cloned()
589            .map(PyGeneratedParticleLayout)
590            .collect()
591    }
592
593    /// Return production-level incoming particle layouts.
594    fn production_incoming(&self) -> Vec<PyGeneratedParticleLayout> {
595        self.0
596            .production_incoming()
597            .into_iter()
598            .cloned()
599            .map(PyGeneratedParticleLayout)
600            .collect()
601    }
602
603    /// Return production-level outgoing particle layouts.
604    fn production_outgoing(&self) -> Vec<PyGeneratedParticleLayout> {
605        self.0
606            .production_outgoing()
607            .into_iter()
608            .cloned()
609            .map(PyGeneratedParticleLayout)
610            .collect()
611    }
612
613    fn __repr__(&self) -> String {
614        format!("{:?}", self.0)
615    }
616}
617
618/// A generated dataset batch plus generated reaction and layout metadata.
619#[pyclass(name = "GeneratedBatch", module = "laddu", from_py_object)]
620#[derive(Clone, Debug)]
621pub struct PyGeneratedBatch(pub GeneratedBatch);
622
623#[pymethods]
624impl PyGeneratedBatch {
625    /// The generated dataset for this batch.
626    #[getter]
627    fn dataset(&self) -> PyDataset {
628        PyDataset(Arc::new(self.0.dataset().clone()))
629    }
630
631    /// The generated reaction metadata for this batch.
632    #[getter]
633    fn reaction(&self) -> PyGeneratedReaction {
634        PyGeneratedReaction(self.0.reaction().clone())
635    }
636
637    /// The generated event layout metadata for this batch.
638    #[getter]
639    fn layout(&self) -> PyGeneratedEventLayout {
640        PyGeneratedEventLayout(self.0.layout().clone())
641    }
642
643    fn __repr__(&self) -> String {
644        format!("{:?}", self.0)
645    }
646}
647
648/// Finite iterator over generated dataset batches.
649#[pyclass(
650    name = "GeneratedBatchIter",
651    module = "laddu",
652    unsendable,
653    skip_from_py_object
654)]
655pub struct PyGeneratedBatchIter {
656    iter: Box<dyn Iterator<Item = laddu_core::LadduResult<GeneratedBatch>>>,
657}
658
659#[pymethods]
660impl PyGeneratedBatchIter {
661    fn __iter__(slf: PyRef<'_, Self>) -> Py<PyGeneratedBatchIter> {
662        slf.into()
663    }
664
665    fn __next__(&mut self) -> PyResult<Option<PyGeneratedBatch>> {
666        match self.iter.next() {
667            Some(Ok(batch)) => Ok(Some(PyGeneratedBatch(batch))),
668            Some(Err(err)) => Err(PyErr::from(err)),
669            None => Ok(None),
670        }
671    }
672}
673
674/// Event generator for generated reaction layouts.
675#[pyclass(name = "EventGenerator", module = "laddu", from_py_object)]
676#[derive(Clone, Debug)]
677pub struct PyEventGenerator(pub EventGenerator);
678
679#[pymethods]
680impl PyEventGenerator {
681    /// Construct an event generator.
682    #[new]
683    #[pyo3(signature = (reaction, aux_generators=None, seed=None, storage=None))]
684    fn new(
685        reaction: &PyGeneratedReaction,
686        aux_generators: Option<HashMap<String, PyDistribution>>,
687        seed: Option<u64>,
688        storage: Option<&PyGeneratedStorage>,
689    ) -> PyResult<Self> {
690        let generator = EventGenerator::new(
691            reaction.0.clone(),
692            aux_generators
693                .unwrap_or_default()
694                .into_iter()
695                .map(|(name, distribution)| (name, distribution.0))
696                .collect(),
697            seed,
698        );
699        let generator = if let Some(storage) = storage {
700            generator.with_storage(storage.0.clone())?
701        } else {
702            generator
703        };
704        Ok(Self(generator))
705    }
706
707    /// Generate one dataset batch with generated layout metadata.
708    fn generate_batch(&self, n_events: usize) -> PyResult<PyGeneratedBatch> {
709        Ok(PyGeneratedBatch(self.0.generate_batch(n_events)?))
710    }
711
712    /// Generate a finite iterator over generated dataset batches.
713    fn generate_batches(
714        &self,
715        total_events: usize,
716        batch_size: usize,
717    ) -> PyResult<PyGeneratedBatchIter> {
718        Ok(PyGeneratedBatchIter {
719            iter: Box::new(self.0.generate_batches(total_events, batch_size)?),
720        })
721    }
722
723    /// Generate a dataset.
724    fn generate_dataset(&self, n_events: usize) -> PyResult<PyDataset> {
725        Ok(PyDataset(Arc::new(self.0.generate_dataset(n_events)?)))
726    }
727
728    fn __repr__(&self) -> String {
729        format!("{:?}", self.0)
730    }
731}