Skip to main content

laddu_python/
generation.rs

1use std::{collections::HashMap, sync::Arc};
2
3use laddu_generation::{
4    CompositeGenerator, Distribution, EventGenerator, ExpressionIntensity, GeneratedBatch,
5    GeneratedEventLayout, GeneratedParticle, GeneratedParticleLayout, GeneratedReaction,
6    GeneratedStorage, GeneratedVertexKind, GeneratedVertexLayout, HistogramSampler,
7    InitialGenerator, MandelstamTDistribution, ParticleSpecies, Reconstruction, RejectionEnvelope,
8    RejectionSampleIter, RejectionSamplingDiagnostics, RejectionSamplingOptions, StableGenerator,
9};
10use pyo3::{exceptions::PyValueError, prelude::*, types::PyTuple};
11
12use crate::{
13    amplitudes::PyExpression, data::PyDataset, math::PyHistogram, variables::PyReaction,
14    vectors::PyVec4,
15};
16
17/// A scalar distribution used by generated auxiliary columns.
18#[pyclass(name = "Distribution", module = "laddu", from_py_object)]
19#[derive(Clone, Debug)]
20pub struct PyDistribution(pub Distribution);
21
22#[pymethods]
23impl PyDistribution {
24    /// Construct a fixed scalar distribution.
25    #[staticmethod]
26    fn fixed(value: f64) -> Self {
27        Self(Distribution::Fixed(value))
28    }
29
30    /// Construct a uniform scalar distribution.
31    #[staticmethod]
32    fn uniform(min: f64, max: f64) -> PyResult<Self> {
33        if max <= min {
34            return Err(PyValueError::new_err(
35                "`max` must be greater than `min` for a uniform distribution",
36            ));
37        }
38        Ok(Self(Distribution::Uniform { min, max }))
39    }
40
41    /// Construct a normal scalar distribution.
42    #[staticmethod]
43    fn normal(mu: f64, sigma: f64) -> PyResult<Self> {
44        if sigma <= 0.0 {
45            return Err(PyValueError::new_err(
46                "`sigma` must be positive for a normal distribution",
47            ));
48        }
49        Ok(Self(Distribution::Normal { mu, sigma }))
50    }
51
52    /// Construct an exponential scalar distribution.
53    #[staticmethod]
54    fn exponential(slope: f64) -> PyResult<Self> {
55        if slope <= 0.0 {
56            return Err(PyValueError::new_err(
57                "`slope` must be positive for an exponential distribution",
58            ));
59        }
60        Ok(Self(Distribution::Exponential { slope }))
61    }
62
63    /// Construct a histogram-sampled scalar distribution.
64    #[staticmethod]
65    fn histogram(histogram: &PyHistogram) -> PyResult<Self> {
66        Ok(Self(Distribution::Histogram(HistogramSampler::new(
67            histogram.0.clone(),
68        )?)))
69    }
70
71    fn __repr__(&self) -> String {
72        format!("{:?}", self.0)
73    }
74}
75
76/// A Mandelstam-t distribution for generated two-to-two reactions.
77#[pyclass(name = "MandelstamTDistribution", module = "laddu", from_py_object)]
78#[derive(Clone, Debug)]
79pub struct PyMandelstamTDistribution(pub MandelstamTDistribution);
80
81#[pymethods]
82impl PyMandelstamTDistribution {
83    /// Construct an exponential Mandelstam-t distribution.
84    #[staticmethod]
85    fn exponential(slope: f64) -> PyResult<Self> {
86        if slope <= 0.0 {
87            return Err(PyValueError::new_err(
88                "`slope` must be positive for an exponential distribution",
89            ));
90        }
91        Ok(Self(MandelstamTDistribution::Exponential { slope }))
92    }
93
94    /// Construct a histogram-sampled Mandelstam-t distribution.
95    #[staticmethod]
96    fn histogram(histogram: &PyHistogram) -> PyResult<Self> {
97        Ok(Self(MandelstamTDistribution::Histogram(
98            HistogramSampler::new(histogram.0.clone())?,
99        )))
100    }
101
102    fn __repr__(&self) -> String {
103        format!("{:?}", self.0)
104    }
105}
106
107/// Generator settings for an initial generated particle.
108#[pyclass(name = "InitialGenerator", module = "laddu", from_py_object)]
109#[derive(Clone, Debug)]
110pub struct PyInitialGenerator(pub InitialGenerator);
111
112#[pymethods]
113impl PyInitialGenerator {
114    /// Construct a beam with fixed energy.
115    #[staticmethod]
116    fn beam_with_fixed_energy(mass: f64, energy: f64) -> Self {
117        Self(InitialGenerator::beam_with_fixed_energy(mass, energy))
118    }
119
120    /// Construct a beam with uniformly sampled energy.
121    #[staticmethod]
122    fn beam(mass: f64, min_energy: f64, max_energy: f64) -> Self {
123        Self(InitialGenerator::beam(mass, min_energy, max_energy))
124    }
125
126    /// Construct a beam with histogram-sampled energy.
127    #[staticmethod]
128    fn beam_with_energy_histogram(mass: f64, energy: &PyHistogram) -> PyResult<Self> {
129        Ok(Self(InitialGenerator::beam_with_energy_histogram(
130            mass,
131            energy.0.clone(),
132        )?))
133    }
134
135    /// Construct a target at rest.
136    #[staticmethod]
137    fn target(mass: f64) -> Self {
138        Self(InitialGenerator::target(mass))
139    }
140
141    fn __repr__(&self) -> String {
142        format!("{:?}", self.0)
143    }
144}
145
146/// Generator settings for a generated composite particle.
147#[pyclass(name = "CompositeGenerator", module = "laddu", from_py_object)]
148#[derive(Clone, Debug)]
149pub struct PyCompositeGenerator(pub CompositeGenerator);
150
151#[pymethods]
152impl PyCompositeGenerator {
153    /// Construct a composite mass generator with a uniform mass range.
154    #[new]
155    fn new(min_mass: f64, max_mass: f64) -> Self {
156        Self(CompositeGenerator::new(min_mass, max_mass))
157    }
158
159    fn __repr__(&self) -> String {
160        format!("{:?}", self.0)
161    }
162}
163
164/// Generator settings for a stable generated particle.
165#[pyclass(name = "StableGenerator", module = "laddu", from_py_object)]
166#[derive(Clone, Debug)]
167pub struct PyStableGenerator(pub StableGenerator);
168
169#[pymethods]
170impl PyStableGenerator {
171    /// Construct a fixed-mass stable-particle generator.
172    #[new]
173    fn new(mass: f64) -> Self {
174        Self(StableGenerator::new(mass))
175    }
176
177    fn __repr__(&self) -> String {
178        format!("{:?}", self.0)
179    }
180}
181
182/// Reconstruction metadata for a generated particle.
183#[pyclass(name = "Reconstruction", module = "laddu", from_py_object)]
184#[derive(Clone, Debug)]
185pub struct PyReconstruction(pub Reconstruction);
186
187#[pymethods]
188impl PyReconstruction {
189    /// Mark a generated particle as stored under its generated ID.
190    #[staticmethod]
191    fn stored() -> Self {
192        Self(Reconstruction::Stored)
193    }
194
195    /// Mark a generated particle as fixed in the reconstructed reaction.
196    #[staticmethod]
197    fn fixed(p4: &PyVec4) -> Self {
198        Self(Reconstruction::Fixed(p4.0))
199    }
200
201    /// Mark a generated particle as missing in the reconstructed reaction.
202    #[staticmethod]
203    fn missing() -> Self {
204        Self(Reconstruction::Missing)
205    }
206
207    /// Mark a generated particle as reconstructed from its generated daughters.
208    #[staticmethod]
209    fn composite() -> Self {
210        Self(Reconstruction::Composite)
211    }
212
213    fn __repr__(&self) -> String {
214        format!("{:?}", self.0)
215    }
216}
217
218/// Experiment-neutral metadata describing a generated particle species.
219#[pyclass(name = "ParticleSpecies", module = "laddu", from_py_object)]
220#[derive(Clone, Debug)]
221pub struct PyParticleSpecies(pub ParticleSpecies);
222
223#[pymethods]
224impl PyParticleSpecies {
225    /// Construct a species from a numeric code with no namespace.
226    #[staticmethod]
227    fn code(id: i64) -> Self {
228        Self(ParticleSpecies::code(id))
229    }
230
231    /// Construct a species from a numeric code in an explicit namespace.
232    #[staticmethod]
233    fn with_namespace(namespace: &str, id: i64) -> Self {
234        Self(ParticleSpecies::with_namespace(namespace, id))
235    }
236
237    /// Construct a species from a free-form label.
238    #[staticmethod]
239    fn label(label: &str) -> Self {
240        Self(ParticleSpecies::label(label))
241    }
242
243    /// The numeric species code, if this is a code-based species.
244    #[getter]
245    fn id(&self) -> Option<i64> {
246        match &self.0 {
247            ParticleSpecies::Code { id, .. } => Some(*id),
248            ParticleSpecies::Label(_) => None,
249        }
250    }
251
252    /// The numeric species namespace, if this is a namespaced code-based species.
253    #[getter]
254    fn namespace(&self) -> Option<String> {
255        match &self.0 {
256            ParticleSpecies::Code { namespace, .. } => namespace.clone(),
257            ParticleSpecies::Label(_) => None,
258        }
259    }
260
261    /// The species label, if this is a label-based species.
262    #[getter]
263    fn label_value(&self) -> Option<String> {
264        match &self.0 {
265            ParticleSpecies::Code { .. } => None,
266            ParticleSpecies::Label(label) => Some(label.clone()),
267        }
268    }
269
270    fn __repr__(&self) -> String {
271        format!("{:?}", self.0)
272    }
273}
274
275/// A generated particle with generation and reconstruction metadata.
276#[pyclass(name = "GeneratedParticle", module = "laddu", from_py_object)]
277#[derive(Clone, Debug)]
278pub struct PyGeneratedParticle(pub GeneratedParticle);
279
280#[pymethods]
281impl PyGeneratedParticle {
282    /// Construct an initial generated particle.
283    #[staticmethod]
284    fn initial(
285        id: &str,
286        generator: &PyInitialGenerator,
287        reconstruction: &PyReconstruction,
288    ) -> Self {
289        Self(GeneratedParticle::initial(
290            id,
291            generator.0.clone(),
292            reconstruction.0.clone(),
293        ))
294    }
295
296    /// Construct a stable generated particle.
297    #[staticmethod]
298    fn stable(id: &str, generator: &PyStableGenerator, reconstruction: &PyReconstruction) -> Self {
299        Self(GeneratedParticle::stable(
300            id,
301            generator.0.clone(),
302            reconstruction.0.clone(),
303        ))
304    }
305
306    /// Construct a generated composite from exactly two ordered daughters.
307    #[staticmethod]
308    fn composite(
309        id: &str,
310        generator: &PyCompositeGenerator,
311        daughters: &Bound<'_, PyTuple>,
312        reconstruction: &PyReconstruction,
313    ) -> PyResult<Self> {
314        if daughters.len() != 2 {
315            return Err(PyValueError::new_err(
316                "composite particles require exactly two ordered daughters",
317            ));
318        }
319        let daughter_1 = daughters.get_item(0)?.extract::<Self>()?;
320        let daughter_2 = daughters.get_item(1)?.extract::<Self>()?;
321        Ok(Self(GeneratedParticle::composite(
322            id,
323            generator.0.clone(),
324            (&daughter_1.0, &daughter_2.0),
325            reconstruction.0.clone(),
326        )))
327    }
328
329    /// Return a copy of this generated particle with species metadata attached.
330    fn with_species(&self, species: &PyParticleSpecies) -> Self {
331        Self(self.0.clone().with_species(species.0.clone()))
332    }
333
334    /// The generated particle ID.
335    #[getter]
336    fn id(&self) -> String {
337        self.0.id().to_string()
338    }
339
340    /// Optional species metadata for this generated particle.
341    #[getter]
342    fn species(&self) -> Option<PyParticleSpecies> {
343        self.0.species().cloned().map(PyParticleSpecies)
344    }
345
346    fn __repr__(&self) -> String {
347        format!("{:?}", self.0)
348    }
349}
350
351/// A generated reaction layout.
352#[pyclass(name = "GeneratedReaction", module = "laddu", from_py_object)]
353#[derive(Clone, Debug)]
354pub struct PyGeneratedReaction(pub GeneratedReaction);
355
356#[pymethods]
357impl PyGeneratedReaction {
358    /// Construct a generated two-to-two reaction.
359    #[staticmethod]
360    fn two_to_two(
361        p1: &PyGeneratedParticle,
362        p2: &PyGeneratedParticle,
363        p3: &PyGeneratedParticle,
364        p4: &PyGeneratedParticle,
365        tdist: &PyMandelstamTDistribution,
366    ) -> PyResult<Self> {
367        Ok(Self(GeneratedReaction::two_to_two(
368            p1.0.clone(),
369            p2.0.clone(),
370            p3.0.clone(),
371            p4.0.clone(),
372            tdist.0.clone(),
373        )?))
374    }
375
376    /// Return generated p4 labels.
377    fn p4_labels(&self) -> Vec<String> {
378        self.0.p4_labels()
379    }
380
381    /// Return generated particle layout entries in stable product-ID order.
382    fn particle_layouts(&self) -> Vec<PyGeneratedParticleLayout> {
383        self.0
384            .particle_layouts()
385            .into_iter()
386            .map(PyGeneratedParticleLayout)
387            .collect()
388    }
389
390    /// Build the reconstructed reaction corresponding to this generated layout.
391    fn reconstructed_reaction(&self) -> PyResult<PyReaction> {
392        Ok(PyReaction(self.0.reconstructed_reaction()?))
393    }
394
395    fn __repr__(&self) -> String {
396        format!("{:?}", self.0)
397    }
398}
399
400/// Selects which generated particle p4s are written into generated datasets.
401#[pyclass(name = "GeneratedStorage", module = "laddu", from_py_object)]
402#[derive(Clone, Debug)]
403pub struct PyGeneratedStorage(pub GeneratedStorage);
404
405#[pymethods]
406impl PyGeneratedStorage {
407    /// Store every generated particle p4.
408    #[staticmethod]
409    fn all() -> Self {
410        Self(GeneratedStorage::all())
411    }
412
413    /// Store only the listed generated particle IDs.
414    #[staticmethod]
415    fn only(ids: Vec<String>) -> Self {
416        Self(GeneratedStorage::only(ids))
417    }
418
419    fn __repr__(&self) -> String {
420        format!("{:?}", self.0)
421    }
422}
423
424/// Metadata for one generated particle in a generated event layout.
425#[pyclass(name = "GeneratedParticleLayout", module = "laddu", from_py_object)]
426#[derive(Clone, Debug)]
427pub struct PyGeneratedParticleLayout(pub GeneratedParticleLayout);
428
429#[pymethods]
430impl PyGeneratedParticleLayout {
431    /// The generated particle identifier.
432    #[getter]
433    fn id(&self) -> String {
434        self.0.id().to_string()
435    }
436
437    /// The zero-based stable product ID in generated-layout order.
438    #[getter]
439    fn product_id(&self) -> usize {
440        self.0.product_id()
441    }
442
443    /// The decay-parent product ID, or None if this particle has no decay parent.
444    #[getter]
445    fn parent_id(&self) -> Option<usize> {
446        self.0.parent_id()
447    }
448
449    /// Optional species metadata associated with this generated particle.
450    #[getter]
451    fn species(&self) -> Option<PyParticleSpecies> {
452        self.0.species().cloned().map(PyParticleSpecies)
453    }
454
455    /// The dataset p4 label associated with this particle, if stored in the batch.
456    #[getter]
457    fn p4_label(&self) -> Option<String> {
458        self.0.p4_label().map(str::to_string)
459    }
460
461    /// The vertex ID where this particle was produced, if any.
462    #[getter]
463    fn produced_vertex_id(&self) -> Option<usize> {
464        self.0.produced_vertex_id()
465    }
466
467    /// The vertex ID where this particle decays, if it is a generated parent.
468    #[getter]
469    fn decay_vertex_id(&self) -> Option<usize> {
470        self.0.decay_vertex_id()
471    }
472
473    fn __repr__(&self) -> String {
474        format!("{:?}", self.0)
475    }
476}
477
478/// Metadata for one generated vertex in a generated event layout.
479#[pyclass(name = "GeneratedVertexLayout", module = "laddu", from_py_object)]
480#[derive(Clone, Debug)]
481pub struct PyGeneratedVertexLayout(pub GeneratedVertexLayout);
482
483#[pymethods]
484impl PyGeneratedVertexLayout {
485    /// The zero-based stable vertex ID in generated-layout order.
486    #[getter]
487    fn vertex_id(&self) -> usize {
488        self.0.vertex_id()
489    }
490
491    /// The semantic vertex kind.
492    #[getter]
493    fn kind(&self) -> &'static str {
494        match self.0.kind() {
495            GeneratedVertexKind::Production => "Production",
496            GeneratedVertexKind::Decay => "Decay",
497        }
498    }
499
500    /// Product IDs entering this vertex.
501    #[getter]
502    fn incoming_product_ids(&self) -> Vec<usize> {
503        self.0.incoming_product_ids().to_vec()
504    }
505
506    /// Product IDs leaving this vertex.
507    #[getter]
508    fn outgoing_product_ids(&self) -> Vec<usize> {
509        self.0.outgoing_product_ids().to_vec()
510    }
511
512    fn __repr__(&self) -> String {
513        format!("{:?}", self.0)
514    }
515}
516
517/// Metadata describing the columns in a generated event batch.
518#[pyclass(name = "GeneratedEventLayout", module = "laddu", from_py_object)]
519#[derive(Clone, Debug)]
520pub struct PyGeneratedEventLayout(pub GeneratedEventLayout);
521
522#[pymethods]
523impl PyGeneratedEventLayout {
524    /// Generated p4 column labels in dataset order.
525    #[getter]
526    fn p4_labels(&self) -> Vec<String> {
527        self.0.p4_labels().to_vec()
528    }
529
530    /// Generated auxiliary column labels in dataset order.
531    #[getter]
532    fn aux_labels(&self) -> Vec<String> {
533        self.0.aux_labels().to_vec()
534    }
535
536    /// Generated particle layout entries in stable product-ID order.
537    #[getter]
538    fn particles(&self) -> Vec<PyGeneratedParticleLayout> {
539        self.0
540            .particles()
541            .iter()
542            .cloned()
543            .map(PyGeneratedParticleLayout)
544            .collect()
545    }
546
547    /// Return the generated particle layout for a generated particle ID.
548    fn particle(&self, id: &str) -> Option<PyGeneratedParticleLayout> {
549        self.0.particle(id).cloned().map(PyGeneratedParticleLayout)
550    }
551
552    /// Return the generated particle layout for a stable product ID.
553    fn product(&self, product_id: usize) -> Option<PyGeneratedParticleLayout> {
554        self.0
555            .product(product_id)
556            .cloned()
557            .map(PyGeneratedParticleLayout)
558    }
559
560    /// Generated vertex layout entries in stable vertex-ID order.
561    #[getter]
562    fn vertices(&self) -> Vec<PyGeneratedVertexLayout> {
563        self.0
564            .vertices()
565            .iter()
566            .cloned()
567            .map(PyGeneratedVertexLayout)
568            .collect()
569    }
570
571    /// Return the generated vertex layout for a stable vertex ID.
572    fn vertex(&self, vertex_id: usize) -> Option<PyGeneratedVertexLayout> {
573        self.0
574            .vertex(vertex_id)
575            .cloned()
576            .map(PyGeneratedVertexLayout)
577    }
578
579    /// Return the production vertex layout, if the generated layout has one.
580    fn production_vertex(&self) -> Option<PyGeneratedVertexLayout> {
581        self.0
582            .production_vertex()
583            .cloned()
584            .map(PyGeneratedVertexLayout)
585    }
586
587    /// Return the generated decay daughters of a parent product ID.
588    fn decay_products(&self, parent_product_id: usize) -> Vec<PyGeneratedParticleLayout> {
589        self.0
590            .decay_products(parent_product_id)
591            .into_iter()
592            .cloned()
593            .map(PyGeneratedParticleLayout)
594            .collect()
595    }
596
597    /// Return production-level incoming particle layouts.
598    fn production_incoming(&self) -> Vec<PyGeneratedParticleLayout> {
599        self.0
600            .production_incoming()
601            .into_iter()
602            .cloned()
603            .map(PyGeneratedParticleLayout)
604            .collect()
605    }
606
607    /// Return production-level outgoing particle layouts.
608    fn production_outgoing(&self) -> Vec<PyGeneratedParticleLayout> {
609        self.0
610            .production_outgoing()
611            .into_iter()
612            .cloned()
613            .map(PyGeneratedParticleLayout)
614            .collect()
615    }
616
617    fn __repr__(&self) -> String {
618        format!("{:?}", self.0)
619    }
620}
621
622/// A generated dataset batch plus generated reaction and layout metadata.
623#[pyclass(name = "GeneratedBatch", module = "laddu", from_py_object)]
624#[derive(Clone, Debug)]
625pub struct PyGeneratedBatch(pub GeneratedBatch);
626
627#[pymethods]
628impl PyGeneratedBatch {
629    /// The generated dataset for this batch.
630    #[getter]
631    fn dataset(&self) -> PyDataset {
632        PyDataset(Arc::new(self.0.dataset().clone()))
633    }
634
635    /// The generated reaction metadata for this batch.
636    #[getter]
637    fn reaction(&self) -> PyGeneratedReaction {
638        PyGeneratedReaction(self.0.reaction().clone())
639    }
640
641    /// The generated event layout metadata for this batch.
642    #[getter]
643    fn layout(&self) -> PyGeneratedEventLayout {
644        PyGeneratedEventLayout(self.0.layout().clone())
645    }
646
647    fn __repr__(&self) -> String {
648        format!("{:?}", self.0)
649    }
650}
651
652/// Finite iterator over generated dataset batches.
653#[pyclass(
654    name = "GeneratedBatchIter",
655    module = "laddu",
656    unsendable,
657    skip_from_py_object
658)]
659pub struct PyGeneratedBatchIter {
660    iter: Box<dyn Iterator<Item = laddu_core::LadduResult<GeneratedBatch>>>,
661}
662
663#[pymethods]
664impl PyGeneratedBatchIter {
665    fn __iter__(slf: PyRef<'_, Self>) -> Py<PyGeneratedBatchIter> {
666        slf.into()
667    }
668
669    fn __next__(&mut self) -> PyResult<Option<PyGeneratedBatch>> {
670        match self.iter.next() {
671            Some(Ok(batch)) => Ok(Some(PyGeneratedBatch(batch))),
672            Some(Err(err)) => Err(PyErr::from(err)),
673            None => Ok(None),
674        }
675    }
676}
677
678/// Envelope strategy used by rejection sampling.
679#[pyclass(name = "RejectionEnvelope", module = "laddu", from_py_object)]
680#[derive(Clone, Debug)]
681pub struct PyRejectionEnvelope(pub RejectionEnvelope);
682
683#[pymethods]
684impl PyRejectionEnvelope {
685    /// Use a fixed maximum event weight.
686    #[staticmethod]
687    fn fixed(max_weight: f64) -> Self {
688        Self(RejectionEnvelope::Fixed { max_weight })
689    }
690
691    /// Estimate the maximum event weight from a pilot sample.
692    #[staticmethod]
693    #[pyo3(signature = (pilot_events, *, safety_factor=1.2, batch_size=None))]
694    fn pilot(pilot_events: usize, safety_factor: f64, batch_size: Option<usize>) -> Self {
695        Self(RejectionEnvelope::Pilot {
696            pilot_events,
697            batch_size,
698            safety_factor,
699        })
700    }
701
702    fn __repr__(&self) -> String {
703        format!("{:?}", self.0)
704    }
705}
706
707/// Rejection-sampling diagnostics.
708#[pyclass(
709    name = "RejectionSamplingDiagnostics",
710    module = "laddu",
711    from_py_object
712)]
713#[derive(Clone, Debug)]
714pub struct PyRejectionSamplingDiagnostics(pub RejectionSamplingDiagnostics);
715
716#[pymethods]
717impl PyRejectionSamplingDiagnostics {
718    #[getter]
719    fn generated_events(&self) -> usize {
720        self.0.generated_events
721    }
722
723    #[getter]
724    fn accepted_events(&self) -> usize {
725        self.0.accepted_events
726    }
727
728    #[getter]
729    fn rejected_events(&self) -> usize {
730        self.0.rejected_events
731    }
732
733    #[getter]
734    fn max_observed_weight(&self) -> f64 {
735        self.0.max_observed_weight
736    }
737
738    #[getter]
739    fn envelope_max_weight(&self) -> f64 {
740        self.0.envelope_max_weight
741    }
742
743    #[getter]
744    fn envelope_violations(&self) -> usize {
745        self.0.envelope_violations
746    }
747
748    fn acceptance_efficiency(&self) -> f64 {
749        self.0.acceptance_efficiency()
750    }
751
752    fn __repr__(&self) -> String {
753        format!("{:?}", self.0)
754    }
755}
756
757/// Iterator over expression rejection-sampled generated batches.
758#[pyclass(
759    name = "RejectionSampleIter",
760    module = "laddu",
761    unsendable,
762    skip_from_py_object
763)]
764pub struct PyRejectionSampleIter {
765    iter: RejectionSampleIter<ExpressionIntensity>,
766}
767
768#[pymethods]
769impl PyRejectionSampleIter {
770    fn __iter__(slf: PyRef<'_, Self>) -> Py<PyRejectionSampleIter> {
771        slf.into()
772    }
773
774    fn __next__(&mut self) -> PyResult<Option<PyGeneratedBatch>> {
775        match self.iter.next() {
776            Some(Ok(batch)) => Ok(Some(PyGeneratedBatch(batch))),
777            Some(Err(err)) => Err(PyErr::from(err)),
778            None => Ok(None),
779        }
780    }
781
782    #[getter]
783    fn diagnostics(&self) -> PyRejectionSamplingDiagnostics {
784        PyRejectionSamplingDiagnostics(self.iter.diagnostics().clone())
785    }
786}
787
788/// Event generator for generated reaction layouts.
789#[pyclass(name = "EventGenerator", module = "laddu", from_py_object)]
790#[derive(Clone, Debug)]
791pub struct PyEventGenerator(pub EventGenerator);
792
793#[pymethods]
794impl PyEventGenerator {
795    /// Construct an event generator.
796    #[new]
797    #[pyo3(signature = (reaction, aux_generators=None, seed=None, storage=None))]
798    fn new(
799        reaction: &PyGeneratedReaction,
800        aux_generators: Option<HashMap<String, PyDistribution>>,
801        seed: Option<u64>,
802        storage: Option<&PyGeneratedStorage>,
803    ) -> PyResult<Self> {
804        let generator = EventGenerator::new(
805            reaction.0.clone(),
806            aux_generators
807                .unwrap_or_default()
808                .into_iter()
809                .map(|(name, distribution)| (name, distribution.0))
810                .collect(),
811            seed,
812        );
813        let generator = if let Some(storage) = storage {
814            generator.with_storage(storage.0.clone())?
815        } else {
816            generator
817        };
818        Ok(Self(generator))
819    }
820
821    /// Generate one dataset batch with generated layout metadata.
822    fn generate_batch(&self, n_events: usize) -> PyResult<PyGeneratedBatch> {
823        Ok(PyGeneratedBatch(self.0.generate_batch(n_events)?))
824    }
825
826    /// Generate a finite iterator over generated dataset batches.
827    fn generate_batches(
828        &self,
829        total_events: usize,
830        batch_size: usize,
831    ) -> PyResult<PyGeneratedBatchIter> {
832        Ok(PyGeneratedBatchIter {
833            iter: Box::new(self.0.generate_batches(total_events, batch_size)?),
834        })
835    }
836
837    /// Generate accepted batches using a real-valued expression as the rejection intensity.
838    #[allow(clippy::too_many_arguments)]
839    #[pyo3(signature = (
840        expression,
841        parameters,
842        *,
843        n_events,
844        generation_batch_size,
845        output_batch_size,
846        envelope,
847        seed=None
848    ))]
849    fn generate_batches_rejection(
850        &self,
851        expression: &PyExpression,
852        parameters: Vec<f64>,
853        n_events: usize,
854        generation_batch_size: usize,
855        output_batch_size: usize,
856        envelope: &PyRejectionEnvelope,
857        seed: Option<u64>,
858    ) -> PyResult<PyRejectionSampleIter> {
859        let sampler = self.0.rejection_sampler_with_expression(
860            expression.0.clone(),
861            parameters,
862            RejectionSamplingOptions {
863                target_accepted: n_events,
864                generation_batch_size,
865                output_batch_size,
866                envelope: envelope.0.clone(),
867                seed: seed.unwrap_or_else(|| fastrand::u64(..)),
868            },
869        )?;
870        Ok(PyRejectionSampleIter {
871            iter: sampler.accepted_batches(),
872        })
873    }
874
875    /// Generate a Dataset using a real-valued expression as the rejection intensity.
876    #[allow(clippy::too_many_arguments)]
877    #[pyo3(signature = (
878        expression,
879        parameters,
880        *,
881        n_events,
882        generation_batch_size,
883        output_batch_size,
884        envelope,
885        seed=None
886    ))]
887    fn generate_dataset_rejection(
888        &self,
889        expression: &PyExpression,
890        parameters: Vec<f64>,
891        n_events: usize,
892        generation_batch_size: usize,
893        output_batch_size: usize,
894        envelope: &PyRejectionEnvelope,
895        seed: Option<u64>,
896    ) -> PyResult<PyDataset> {
897        let mut iter = self.generate_batches_rejection(
898            expression,
899            parameters,
900            n_events,
901            generation_batch_size,
902            output_batch_size,
903            envelope,
904            seed,
905        )?;
906        let mut output = None;
907        while let Some(batch) = iter.__next__()? {
908            let dataset = batch.0.into_dataset();
909            if output.is_none() {
910                output = Some(laddu_core::Dataset::empty_local(dataset.metadata().clone()));
911            }
912            let output_dataset = output.as_mut().expect("output dataset should exist");
913            for index in 0..dataset.n_events() {
914                let event = dataset.event_global(index)?;
915                output_dataset.push_event_local(
916                    event.p4s.clone(),
917                    event.aux.clone(),
918                    event.weight,
919                )?;
920            }
921        }
922        let output = match output {
923            Some(output) => output,
924            None => self.0.generate_batch(0)?.into_dataset(),
925        };
926        Ok(PyDataset(Arc::new(output)))
927    }
928
929    /// Generate a dataset.
930    fn generate_dataset(&self, n_events: usize) -> PyResult<PyDataset> {
931        Ok(PyDataset(Arc::new(self.0.generate_dataset(n_events)?)))
932    }
933
934    fn __repr__(&self) -> String {
935        format!("{:?}", self.0)
936    }
937}