1use std::{collections::HashMap, sync::Arc};
2
3use laddu_generation::{
4 CompositeGenerator, Distribution, EventGenerator, GeneratedBatch, GeneratedEventLayout,
5 GeneratedParticle, GeneratedParticleLayout, GeneratedReaction, GeneratedStorage,
6 GeneratedVertexKind, GeneratedVertexLayout, HistogramSampler, InitialGenerator,
7 MandelstamTDistribution, ParticleSpecies, Reconstruction, StableGenerator,
8};
9use pyo3::{exceptions::PyValueError, prelude::*, types::PyTuple};
10
11use crate::{data::PyDataset, math::PyHistogram, variables::PyReaction, vectors::PyVec4};
12
13#[pyclass(name = "Distribution", module = "laddu", from_py_object)]
15#[derive(Clone, Debug)]
16pub struct PyDistribution(pub Distribution);
17
18#[pymethods]
19impl PyDistribution {
20 #[staticmethod]
22 fn fixed(value: f64) -> Self {
23 Self(Distribution::Fixed(value))
24 }
25
26 #[staticmethod]
28 fn uniform(min: f64, max: f64) -> PyResult<Self> {
29 if max <= min {
30 return Err(PyValueError::new_err(
31 "`max` must be greater than `min` for a uniform distribution",
32 ));
33 }
34 Ok(Self(Distribution::Uniform { min, max }))
35 }
36
37 #[staticmethod]
39 fn normal(mu: f64, sigma: f64) -> PyResult<Self> {
40 if sigma <= 0.0 {
41 return Err(PyValueError::new_err(
42 "`sigma` must be positive for a normal distribution",
43 ));
44 }
45 Ok(Self(Distribution::Normal { mu, sigma }))
46 }
47
48 #[staticmethod]
50 fn exponential(slope: f64) -> PyResult<Self> {
51 if slope <= 0.0 {
52 return Err(PyValueError::new_err(
53 "`slope` must be positive for an exponential distribution",
54 ));
55 }
56 Ok(Self(Distribution::Exponential { slope }))
57 }
58
59 #[staticmethod]
61 fn histogram(histogram: &PyHistogram) -> PyResult<Self> {
62 Ok(Self(Distribution::Histogram(HistogramSampler::new(
63 histogram.0.clone(),
64 )?)))
65 }
66
67 fn __repr__(&self) -> String {
68 format!("{:?}", self.0)
69 }
70}
71
72#[pyclass(name = "MandelstamTDistribution", module = "laddu", from_py_object)]
74#[derive(Clone, Debug)]
75pub struct PyMandelstamTDistribution(pub MandelstamTDistribution);
76
77#[pymethods]
78impl PyMandelstamTDistribution {
79 #[staticmethod]
81 fn exponential(slope: f64) -> PyResult<Self> {
82 if slope <= 0.0 {
83 return Err(PyValueError::new_err(
84 "`slope` must be positive for an exponential distribution",
85 ));
86 }
87 Ok(Self(MandelstamTDistribution::Exponential { slope }))
88 }
89
90 #[staticmethod]
92 fn histogram(histogram: &PyHistogram) -> PyResult<Self> {
93 Ok(Self(MandelstamTDistribution::Histogram(
94 HistogramSampler::new(histogram.0.clone())?,
95 )))
96 }
97
98 fn __repr__(&self) -> String {
99 format!("{:?}", self.0)
100 }
101}
102
103#[pyclass(name = "InitialGenerator", module = "laddu", from_py_object)]
105#[derive(Clone, Debug)]
106pub struct PyInitialGenerator(pub InitialGenerator);
107
108#[pymethods]
109impl PyInitialGenerator {
110 #[staticmethod]
112 fn beam_with_fixed_energy(mass: f64, energy: f64) -> Self {
113 Self(InitialGenerator::beam_with_fixed_energy(mass, energy))
114 }
115
116 #[staticmethod]
118 fn beam(mass: f64, min_energy: f64, max_energy: f64) -> Self {
119 Self(InitialGenerator::beam(mass, min_energy, max_energy))
120 }
121
122 #[staticmethod]
124 fn beam_with_energy_histogram(mass: f64, energy: &PyHistogram) -> PyResult<Self> {
125 Ok(Self(InitialGenerator::beam_with_energy_histogram(
126 mass,
127 energy.0.clone(),
128 )?))
129 }
130
131 #[staticmethod]
133 fn target(mass: f64) -> Self {
134 Self(InitialGenerator::target(mass))
135 }
136
137 fn __repr__(&self) -> String {
138 format!("{:?}", self.0)
139 }
140}
141
142#[pyclass(name = "CompositeGenerator", module = "laddu", from_py_object)]
144#[derive(Clone, Debug)]
145pub struct PyCompositeGenerator(pub CompositeGenerator);
146
147#[pymethods]
148impl PyCompositeGenerator {
149 #[new]
151 fn new(min_mass: f64, max_mass: f64) -> Self {
152 Self(CompositeGenerator::new(min_mass, max_mass))
153 }
154
155 fn __repr__(&self) -> String {
156 format!("{:?}", self.0)
157 }
158}
159
160#[pyclass(name = "StableGenerator", module = "laddu", from_py_object)]
162#[derive(Clone, Debug)]
163pub struct PyStableGenerator(pub StableGenerator);
164
165#[pymethods]
166impl PyStableGenerator {
167 #[new]
169 fn new(mass: f64) -> Self {
170 Self(StableGenerator::new(mass))
171 }
172
173 fn __repr__(&self) -> String {
174 format!("{:?}", self.0)
175 }
176}
177
178#[pyclass(name = "Reconstruction", module = "laddu", from_py_object)]
180#[derive(Clone, Debug)]
181pub struct PyReconstruction(pub Reconstruction);
182
183#[pymethods]
184impl PyReconstruction {
185 #[staticmethod]
187 fn stored() -> Self {
188 Self(Reconstruction::Stored)
189 }
190
191 #[staticmethod]
193 fn fixed(p4: &PyVec4) -> Self {
194 Self(Reconstruction::Fixed(p4.0))
195 }
196
197 #[staticmethod]
199 fn missing() -> Self {
200 Self(Reconstruction::Missing)
201 }
202
203 #[staticmethod]
205 fn composite() -> Self {
206 Self(Reconstruction::Composite)
207 }
208
209 fn __repr__(&self) -> String {
210 format!("{:?}", self.0)
211 }
212}
213
214#[pyclass(name = "ParticleSpecies", module = "laddu", from_py_object)]
216#[derive(Clone, Debug)]
217pub struct PyParticleSpecies(pub ParticleSpecies);
218
219#[pymethods]
220impl PyParticleSpecies {
221 #[staticmethod]
223 fn code(id: i64) -> Self {
224 Self(ParticleSpecies::code(id))
225 }
226
227 #[staticmethod]
229 fn with_namespace(namespace: &str, id: i64) -> Self {
230 Self(ParticleSpecies::with_namespace(namespace, id))
231 }
232
233 #[staticmethod]
235 fn label(label: &str) -> Self {
236 Self(ParticleSpecies::label(label))
237 }
238
239 #[getter]
241 fn id(&self) -> Option<i64> {
242 match &self.0 {
243 ParticleSpecies::Code { id, .. } => Some(*id),
244 ParticleSpecies::Label(_) => None,
245 }
246 }
247
248 #[getter]
250 fn namespace(&self) -> Option<String> {
251 match &self.0 {
252 ParticleSpecies::Code { namespace, .. } => namespace.clone(),
253 ParticleSpecies::Label(_) => None,
254 }
255 }
256
257 #[getter]
259 fn label_value(&self) -> Option<String> {
260 match &self.0 {
261 ParticleSpecies::Code { .. } => None,
262 ParticleSpecies::Label(label) => Some(label.clone()),
263 }
264 }
265
266 fn __repr__(&self) -> String {
267 format!("{:?}", self.0)
268 }
269}
270
271#[pyclass(name = "GeneratedParticle", module = "laddu", from_py_object)]
273#[derive(Clone, Debug)]
274pub struct PyGeneratedParticle(pub GeneratedParticle);
275
276#[pymethods]
277impl PyGeneratedParticle {
278 #[staticmethod]
280 fn initial(
281 id: &str,
282 generator: &PyInitialGenerator,
283 reconstruction: &PyReconstruction,
284 ) -> Self {
285 Self(GeneratedParticle::initial(
286 id,
287 generator.0.clone(),
288 reconstruction.0.clone(),
289 ))
290 }
291
292 #[staticmethod]
294 fn stable(id: &str, generator: &PyStableGenerator, reconstruction: &PyReconstruction) -> Self {
295 Self(GeneratedParticle::stable(
296 id,
297 generator.0.clone(),
298 reconstruction.0.clone(),
299 ))
300 }
301
302 #[staticmethod]
304 fn composite(
305 id: &str,
306 generator: &PyCompositeGenerator,
307 daughters: &Bound<'_, PyTuple>,
308 reconstruction: &PyReconstruction,
309 ) -> PyResult<Self> {
310 if daughters.len() != 2 {
311 return Err(PyValueError::new_err(
312 "composite particles require exactly two ordered daughters",
313 ));
314 }
315 let daughter_1 = daughters.get_item(0)?.extract::<Self>()?;
316 let daughter_2 = daughters.get_item(1)?.extract::<Self>()?;
317 Ok(Self(GeneratedParticle::composite(
318 id,
319 generator.0.clone(),
320 (&daughter_1.0, &daughter_2.0),
321 reconstruction.0.clone(),
322 )))
323 }
324
325 fn with_species(&self, species: &PyParticleSpecies) -> Self {
327 Self(self.0.clone().with_species(species.0.clone()))
328 }
329
330 #[getter]
332 fn id(&self) -> String {
333 self.0.id().to_string()
334 }
335
336 #[getter]
338 fn species(&self) -> Option<PyParticleSpecies> {
339 self.0.species().cloned().map(PyParticleSpecies)
340 }
341
342 fn __repr__(&self) -> String {
343 format!("{:?}", self.0)
344 }
345}
346
347#[pyclass(name = "GeneratedReaction", module = "laddu", from_py_object)]
349#[derive(Clone, Debug)]
350pub struct PyGeneratedReaction(pub GeneratedReaction);
351
352#[pymethods]
353impl PyGeneratedReaction {
354 #[staticmethod]
356 fn two_to_two(
357 p1: &PyGeneratedParticle,
358 p2: &PyGeneratedParticle,
359 p3: &PyGeneratedParticle,
360 p4: &PyGeneratedParticle,
361 tdist: &PyMandelstamTDistribution,
362 ) -> PyResult<Self> {
363 Ok(Self(GeneratedReaction::two_to_two(
364 p1.0.clone(),
365 p2.0.clone(),
366 p3.0.clone(),
367 p4.0.clone(),
368 tdist.0.clone(),
369 )?))
370 }
371
372 fn p4_labels(&self) -> Vec<String> {
374 self.0.p4_labels()
375 }
376
377 fn particle_layouts(&self) -> Vec<PyGeneratedParticleLayout> {
379 self.0
380 .particle_layouts()
381 .into_iter()
382 .map(PyGeneratedParticleLayout)
383 .collect()
384 }
385
386 fn reconstructed_reaction(&self) -> PyResult<PyReaction> {
388 Ok(PyReaction(self.0.reconstructed_reaction()?))
389 }
390
391 fn __repr__(&self) -> String {
392 format!("{:?}", self.0)
393 }
394}
395
396#[pyclass(name = "GeneratedStorage", module = "laddu", from_py_object)]
398#[derive(Clone, Debug)]
399pub struct PyGeneratedStorage(pub GeneratedStorage);
400
401#[pymethods]
402impl PyGeneratedStorage {
403 #[staticmethod]
405 fn all() -> Self {
406 Self(GeneratedStorage::all())
407 }
408
409 #[staticmethod]
411 fn only(ids: Vec<String>) -> Self {
412 Self(GeneratedStorage::only(ids))
413 }
414
415 fn __repr__(&self) -> String {
416 format!("{:?}", self.0)
417 }
418}
419
420#[pyclass(name = "GeneratedParticleLayout", module = "laddu", from_py_object)]
422#[derive(Clone, Debug)]
423pub struct PyGeneratedParticleLayout(pub GeneratedParticleLayout);
424
425#[pymethods]
426impl PyGeneratedParticleLayout {
427 #[getter]
429 fn id(&self) -> String {
430 self.0.id().to_string()
431 }
432
433 #[getter]
435 fn product_id(&self) -> usize {
436 self.0.product_id()
437 }
438
439 #[getter]
441 fn parent_id(&self) -> Option<usize> {
442 self.0.parent_id()
443 }
444
445 #[getter]
447 fn species(&self) -> Option<PyParticleSpecies> {
448 self.0.species().cloned().map(PyParticleSpecies)
449 }
450
451 #[getter]
453 fn p4_label(&self) -> Option<String> {
454 self.0.p4_label().map(str::to_string)
455 }
456
457 #[getter]
459 fn produced_vertex_id(&self) -> Option<usize> {
460 self.0.produced_vertex_id()
461 }
462
463 #[getter]
465 fn decay_vertex_id(&self) -> Option<usize> {
466 self.0.decay_vertex_id()
467 }
468
469 fn __repr__(&self) -> String {
470 format!("{:?}", self.0)
471 }
472}
473
474#[pyclass(name = "GeneratedVertexLayout", module = "laddu", from_py_object)]
476#[derive(Clone, Debug)]
477pub struct PyGeneratedVertexLayout(pub GeneratedVertexLayout);
478
479#[pymethods]
480impl PyGeneratedVertexLayout {
481 #[getter]
483 fn vertex_id(&self) -> usize {
484 self.0.vertex_id()
485 }
486
487 #[getter]
489 fn kind(&self) -> &'static str {
490 match self.0.kind() {
491 GeneratedVertexKind::Production => "Production",
492 GeneratedVertexKind::Decay => "Decay",
493 }
494 }
495
496 #[getter]
498 fn incoming_product_ids(&self) -> Vec<usize> {
499 self.0.incoming_product_ids().to_vec()
500 }
501
502 #[getter]
504 fn outgoing_product_ids(&self) -> Vec<usize> {
505 self.0.outgoing_product_ids().to_vec()
506 }
507
508 fn __repr__(&self) -> String {
509 format!("{:?}", self.0)
510 }
511}
512
513#[pyclass(name = "GeneratedEventLayout", module = "laddu", from_py_object)]
515#[derive(Clone, Debug)]
516pub struct PyGeneratedEventLayout(pub GeneratedEventLayout);
517
518#[pymethods]
519impl PyGeneratedEventLayout {
520 #[getter]
522 fn p4_labels(&self) -> Vec<String> {
523 self.0.p4_labels().to_vec()
524 }
525
526 #[getter]
528 fn aux_labels(&self) -> Vec<String> {
529 self.0.aux_labels().to_vec()
530 }
531
532 #[getter]
534 fn particles(&self) -> Vec<PyGeneratedParticleLayout> {
535 self.0
536 .particles()
537 .iter()
538 .cloned()
539 .map(PyGeneratedParticleLayout)
540 .collect()
541 }
542
543 fn particle(&self, id: &str) -> Option<PyGeneratedParticleLayout> {
545 self.0.particle(id).cloned().map(PyGeneratedParticleLayout)
546 }
547
548 fn product(&self, product_id: usize) -> Option<PyGeneratedParticleLayout> {
550 self.0
551 .product(product_id)
552 .cloned()
553 .map(PyGeneratedParticleLayout)
554 }
555
556 #[getter]
558 fn vertices(&self) -> Vec<PyGeneratedVertexLayout> {
559 self.0
560 .vertices()
561 .iter()
562 .cloned()
563 .map(PyGeneratedVertexLayout)
564 .collect()
565 }
566
567 fn vertex(&self, vertex_id: usize) -> Option<PyGeneratedVertexLayout> {
569 self.0
570 .vertex(vertex_id)
571 .cloned()
572 .map(PyGeneratedVertexLayout)
573 }
574
575 fn production_vertex(&self) -> Option<PyGeneratedVertexLayout> {
577 self.0
578 .production_vertex()
579 .cloned()
580 .map(PyGeneratedVertexLayout)
581 }
582
583 fn decay_products(&self, parent_product_id: usize) -> Vec<PyGeneratedParticleLayout> {
585 self.0
586 .decay_products(parent_product_id)
587 .into_iter()
588 .cloned()
589 .map(PyGeneratedParticleLayout)
590 .collect()
591 }
592
593 fn production_incoming(&self) -> Vec<PyGeneratedParticleLayout> {
595 self.0
596 .production_incoming()
597 .into_iter()
598 .cloned()
599 .map(PyGeneratedParticleLayout)
600 .collect()
601 }
602
603 fn production_outgoing(&self) -> Vec<PyGeneratedParticleLayout> {
605 self.0
606 .production_outgoing()
607 .into_iter()
608 .cloned()
609 .map(PyGeneratedParticleLayout)
610 .collect()
611 }
612
613 fn __repr__(&self) -> String {
614 format!("{:?}", self.0)
615 }
616}
617
618#[pyclass(name = "GeneratedBatch", module = "laddu", from_py_object)]
620#[derive(Clone, Debug)]
621pub struct PyGeneratedBatch(pub GeneratedBatch);
622
623#[pymethods]
624impl PyGeneratedBatch {
625 #[getter]
627 fn dataset(&self) -> PyDataset {
628 PyDataset(Arc::new(self.0.dataset().clone()))
629 }
630
631 #[getter]
633 fn reaction(&self) -> PyGeneratedReaction {
634 PyGeneratedReaction(self.0.reaction().clone())
635 }
636
637 #[getter]
639 fn layout(&self) -> PyGeneratedEventLayout {
640 PyGeneratedEventLayout(self.0.layout().clone())
641 }
642
643 fn __repr__(&self) -> String {
644 format!("{:?}", self.0)
645 }
646}
647
648#[pyclass(
650 name = "GeneratedBatchIter",
651 module = "laddu",
652 unsendable,
653 skip_from_py_object
654)]
655pub struct PyGeneratedBatchIter {
656 iter: Box<dyn Iterator<Item = laddu_core::LadduResult<GeneratedBatch>>>,
657}
658
659#[pymethods]
660impl PyGeneratedBatchIter {
661 fn __iter__(slf: PyRef<'_, Self>) -> Py<PyGeneratedBatchIter> {
662 slf.into()
663 }
664
665 fn __next__(&mut self) -> PyResult<Option<PyGeneratedBatch>> {
666 match self.iter.next() {
667 Some(Ok(batch)) => Ok(Some(PyGeneratedBatch(batch))),
668 Some(Err(err)) => Err(PyErr::from(err)),
669 None => Ok(None),
670 }
671 }
672}
673
674#[pyclass(name = "EventGenerator", module = "laddu", from_py_object)]
676#[derive(Clone, Debug)]
677pub struct PyEventGenerator(pub EventGenerator);
678
679#[pymethods]
680impl PyEventGenerator {
681 #[new]
683 #[pyo3(signature = (reaction, aux_generators=None, seed=None, storage=None))]
684 fn new(
685 reaction: &PyGeneratedReaction,
686 aux_generators: Option<HashMap<String, PyDistribution>>,
687 seed: Option<u64>,
688 storage: Option<&PyGeneratedStorage>,
689 ) -> PyResult<Self> {
690 let generator = EventGenerator::new(
691 reaction.0.clone(),
692 aux_generators
693 .unwrap_or_default()
694 .into_iter()
695 .map(|(name, distribution)| (name, distribution.0))
696 .collect(),
697 seed,
698 );
699 let generator = if let Some(storage) = storage {
700 generator.with_storage(storage.0.clone())?
701 } else {
702 generator
703 };
704 Ok(Self(generator))
705 }
706
707 fn generate_batch(&self, n_events: usize) -> PyResult<PyGeneratedBatch> {
709 Ok(PyGeneratedBatch(self.0.generate_batch(n_events)?))
710 }
711
712 fn generate_batches(
714 &self,
715 total_events: usize,
716 batch_size: usize,
717 ) -> PyResult<PyGeneratedBatchIter> {
718 Ok(PyGeneratedBatchIter {
719 iter: Box::new(self.0.generate_batches(total_events, batch_size)?),
720 })
721 }
722
723 fn generate_dataset(&self, n_events: usize) -> PyResult<PyDataset> {
725 Ok(PyDataset(Arc::new(self.0.generate_dataset(n_events)?)))
726 }
727
728 fn __repr__(&self) -> String {
729 format!("{:?}", self.0)
730 }
731}