1use std::{collections::HashMap, sync::Arc};
2
3use laddu_generation::{
4 CompositeGenerator, Distribution, EventGenerator, ExpressionIntensity, GeneratedBatch,
5 GeneratedEventLayout, GeneratedParticle, GeneratedParticleLayout, GeneratedReaction,
6 GeneratedStorage, GeneratedVertexKind, GeneratedVertexLayout, HistogramSampler,
7 InitialGenerator, MandelstamTDistribution, ParticleSpecies, Reconstruction, RejectionEnvelope,
8 RejectionSampleIter, RejectionSamplingDiagnostics, RejectionSamplingOptions, StableGenerator,
9};
10use pyo3::{exceptions::PyValueError, prelude::*, types::PyTuple};
11
12use crate::{
13 amplitudes::PyExpression, data::PyDataset, math::PyHistogram, variables::PyReaction,
14 vectors::PyVec4,
15};
16
17#[pyclass(name = "Distribution", module = "laddu", from_py_object)]
19#[derive(Clone, Debug)]
20pub struct PyDistribution(pub Distribution);
21
22#[pymethods]
23impl PyDistribution {
24 #[staticmethod]
26 fn fixed(value: f64) -> Self {
27 Self(Distribution::Fixed(value))
28 }
29
30 #[staticmethod]
32 fn uniform(min: f64, max: f64) -> PyResult<Self> {
33 if max <= min {
34 return Err(PyValueError::new_err(
35 "`max` must be greater than `min` for a uniform distribution",
36 ));
37 }
38 Ok(Self(Distribution::Uniform { min, max }))
39 }
40
41 #[staticmethod]
43 fn normal(mu: f64, sigma: f64) -> PyResult<Self> {
44 if sigma <= 0.0 {
45 return Err(PyValueError::new_err(
46 "`sigma` must be positive for a normal distribution",
47 ));
48 }
49 Ok(Self(Distribution::Normal { mu, sigma }))
50 }
51
52 #[staticmethod]
54 fn exponential(slope: f64) -> PyResult<Self> {
55 if slope <= 0.0 {
56 return Err(PyValueError::new_err(
57 "`slope` must be positive for an exponential distribution",
58 ));
59 }
60 Ok(Self(Distribution::Exponential { slope }))
61 }
62
63 #[staticmethod]
65 fn histogram(histogram: &PyHistogram) -> PyResult<Self> {
66 Ok(Self(Distribution::Histogram(HistogramSampler::new(
67 histogram.0.clone(),
68 )?)))
69 }
70
71 fn __repr__(&self) -> String {
72 format!("{:?}", self.0)
73 }
74}
75
76#[pyclass(name = "MandelstamTDistribution", module = "laddu", from_py_object)]
78#[derive(Clone, Debug)]
79pub struct PyMandelstamTDistribution(pub MandelstamTDistribution);
80
81#[pymethods]
82impl PyMandelstamTDistribution {
83 #[staticmethod]
85 fn exponential(slope: f64) -> PyResult<Self> {
86 if slope <= 0.0 {
87 return Err(PyValueError::new_err(
88 "`slope` must be positive for an exponential distribution",
89 ));
90 }
91 Ok(Self(MandelstamTDistribution::Exponential { slope }))
92 }
93
94 #[staticmethod]
96 fn histogram(histogram: &PyHistogram) -> PyResult<Self> {
97 Ok(Self(MandelstamTDistribution::Histogram(
98 HistogramSampler::new(histogram.0.clone())?,
99 )))
100 }
101
102 fn __repr__(&self) -> String {
103 format!("{:?}", self.0)
104 }
105}
106
107#[pyclass(name = "InitialGenerator", module = "laddu", from_py_object)]
109#[derive(Clone, Debug)]
110pub struct PyInitialGenerator(pub InitialGenerator);
111
112#[pymethods]
113impl PyInitialGenerator {
114 #[staticmethod]
116 fn beam_with_fixed_energy(mass: f64, energy: f64) -> Self {
117 Self(InitialGenerator::beam_with_fixed_energy(mass, energy))
118 }
119
120 #[staticmethod]
122 fn beam(mass: f64, min_energy: f64, max_energy: f64) -> Self {
123 Self(InitialGenerator::beam(mass, min_energy, max_energy))
124 }
125
126 #[staticmethod]
128 fn beam_with_energy_histogram(mass: f64, energy: &PyHistogram) -> PyResult<Self> {
129 Ok(Self(InitialGenerator::beam_with_energy_histogram(
130 mass,
131 energy.0.clone(),
132 )?))
133 }
134
135 #[staticmethod]
137 fn target(mass: f64) -> Self {
138 Self(InitialGenerator::target(mass))
139 }
140
141 fn __repr__(&self) -> String {
142 format!("{:?}", self.0)
143 }
144}
145
146#[pyclass(name = "CompositeGenerator", module = "laddu", from_py_object)]
148#[derive(Clone, Debug)]
149pub struct PyCompositeGenerator(pub CompositeGenerator);
150
151#[pymethods]
152impl PyCompositeGenerator {
153 #[new]
155 fn new(min_mass: f64, max_mass: f64) -> Self {
156 Self(CompositeGenerator::new(min_mass, max_mass))
157 }
158
159 fn __repr__(&self) -> String {
160 format!("{:?}", self.0)
161 }
162}
163
164#[pyclass(name = "StableGenerator", module = "laddu", from_py_object)]
166#[derive(Clone, Debug)]
167pub struct PyStableGenerator(pub StableGenerator);
168
169#[pymethods]
170impl PyStableGenerator {
171 #[new]
173 fn new(mass: f64) -> Self {
174 Self(StableGenerator::new(mass))
175 }
176
177 fn __repr__(&self) -> String {
178 format!("{:?}", self.0)
179 }
180}
181
182#[pyclass(name = "Reconstruction", module = "laddu", from_py_object)]
184#[derive(Clone, Debug)]
185pub struct PyReconstruction(pub Reconstruction);
186
187#[pymethods]
188impl PyReconstruction {
189 #[staticmethod]
191 fn stored() -> Self {
192 Self(Reconstruction::Stored)
193 }
194
195 #[staticmethod]
197 fn fixed(p4: &PyVec4) -> Self {
198 Self(Reconstruction::Fixed(p4.0))
199 }
200
201 #[staticmethod]
203 fn missing() -> Self {
204 Self(Reconstruction::Missing)
205 }
206
207 #[staticmethod]
209 fn composite() -> Self {
210 Self(Reconstruction::Composite)
211 }
212
213 fn __repr__(&self) -> String {
214 format!("{:?}", self.0)
215 }
216}
217
218#[pyclass(name = "ParticleSpecies", module = "laddu", from_py_object)]
220#[derive(Clone, Debug)]
221pub struct PyParticleSpecies(pub ParticleSpecies);
222
223#[pymethods]
224impl PyParticleSpecies {
225 #[staticmethod]
227 fn code(id: i64) -> Self {
228 Self(ParticleSpecies::code(id))
229 }
230
231 #[staticmethod]
233 fn with_namespace(namespace: &str, id: i64) -> Self {
234 Self(ParticleSpecies::with_namespace(namespace, id))
235 }
236
237 #[staticmethod]
239 fn label(label: &str) -> Self {
240 Self(ParticleSpecies::label(label))
241 }
242
243 #[getter]
245 fn id(&self) -> Option<i64> {
246 match &self.0 {
247 ParticleSpecies::Code { id, .. } => Some(*id),
248 ParticleSpecies::Label(_) => None,
249 }
250 }
251
252 #[getter]
254 fn namespace(&self) -> Option<String> {
255 match &self.0 {
256 ParticleSpecies::Code { namespace, .. } => namespace.clone(),
257 ParticleSpecies::Label(_) => None,
258 }
259 }
260
261 #[getter]
263 fn label_value(&self) -> Option<String> {
264 match &self.0 {
265 ParticleSpecies::Code { .. } => None,
266 ParticleSpecies::Label(label) => Some(label.clone()),
267 }
268 }
269
270 fn __repr__(&self) -> String {
271 format!("{:?}", self.0)
272 }
273}
274
275#[pyclass(name = "GeneratedParticle", module = "laddu", from_py_object)]
277#[derive(Clone, Debug)]
278pub struct PyGeneratedParticle(pub GeneratedParticle);
279
280#[pymethods]
281impl PyGeneratedParticle {
282 #[staticmethod]
284 fn initial(
285 id: &str,
286 generator: &PyInitialGenerator,
287 reconstruction: &PyReconstruction,
288 ) -> Self {
289 Self(GeneratedParticle::initial(
290 id,
291 generator.0.clone(),
292 reconstruction.0.clone(),
293 ))
294 }
295
296 #[staticmethod]
298 fn stable(id: &str, generator: &PyStableGenerator, reconstruction: &PyReconstruction) -> Self {
299 Self(GeneratedParticle::stable(
300 id,
301 generator.0.clone(),
302 reconstruction.0.clone(),
303 ))
304 }
305
306 #[staticmethod]
308 fn composite(
309 id: &str,
310 generator: &PyCompositeGenerator,
311 daughters: &Bound<'_, PyTuple>,
312 reconstruction: &PyReconstruction,
313 ) -> PyResult<Self> {
314 if daughters.len() != 2 {
315 return Err(PyValueError::new_err(
316 "composite particles require exactly two ordered daughters",
317 ));
318 }
319 let daughter_1 = daughters.get_item(0)?.extract::<Self>()?;
320 let daughter_2 = daughters.get_item(1)?.extract::<Self>()?;
321 Ok(Self(GeneratedParticle::composite(
322 id,
323 generator.0.clone(),
324 (&daughter_1.0, &daughter_2.0),
325 reconstruction.0.clone(),
326 )))
327 }
328
329 fn with_species(&self, species: &PyParticleSpecies) -> Self {
331 Self(self.0.clone().with_species(species.0.clone()))
332 }
333
334 #[getter]
336 fn id(&self) -> String {
337 self.0.id().to_string()
338 }
339
340 #[getter]
342 fn species(&self) -> Option<PyParticleSpecies> {
343 self.0.species().cloned().map(PyParticleSpecies)
344 }
345
346 fn __repr__(&self) -> String {
347 format!("{:?}", self.0)
348 }
349}
350
351#[pyclass(name = "GeneratedReaction", module = "laddu", from_py_object)]
353#[derive(Clone, Debug)]
354pub struct PyGeneratedReaction(pub GeneratedReaction);
355
356#[pymethods]
357impl PyGeneratedReaction {
358 #[staticmethod]
360 fn two_to_two(
361 p1: &PyGeneratedParticle,
362 p2: &PyGeneratedParticle,
363 p3: &PyGeneratedParticle,
364 p4: &PyGeneratedParticle,
365 tdist: &PyMandelstamTDistribution,
366 ) -> PyResult<Self> {
367 Ok(Self(GeneratedReaction::two_to_two(
368 p1.0.clone(),
369 p2.0.clone(),
370 p3.0.clone(),
371 p4.0.clone(),
372 tdist.0.clone(),
373 )?))
374 }
375
376 fn p4_labels(&self) -> Vec<String> {
378 self.0.p4_labels()
379 }
380
381 fn particle_layouts(&self) -> Vec<PyGeneratedParticleLayout> {
383 self.0
384 .particle_layouts()
385 .into_iter()
386 .map(PyGeneratedParticleLayout)
387 .collect()
388 }
389
390 fn reconstructed_reaction(&self) -> PyResult<PyReaction> {
392 Ok(PyReaction(self.0.reconstructed_reaction()?))
393 }
394
395 fn __repr__(&self) -> String {
396 format!("{:?}", self.0)
397 }
398}
399
400#[pyclass(name = "GeneratedStorage", module = "laddu", from_py_object)]
402#[derive(Clone, Debug)]
403pub struct PyGeneratedStorage(pub GeneratedStorage);
404
405#[pymethods]
406impl PyGeneratedStorage {
407 #[staticmethod]
409 fn all() -> Self {
410 Self(GeneratedStorage::all())
411 }
412
413 #[staticmethod]
415 fn only(ids: Vec<String>) -> Self {
416 Self(GeneratedStorage::only(ids))
417 }
418
419 fn __repr__(&self) -> String {
420 format!("{:?}", self.0)
421 }
422}
423
424#[pyclass(name = "GeneratedParticleLayout", module = "laddu", from_py_object)]
426#[derive(Clone, Debug)]
427pub struct PyGeneratedParticleLayout(pub GeneratedParticleLayout);
428
429#[pymethods]
430impl PyGeneratedParticleLayout {
431 #[getter]
433 fn id(&self) -> String {
434 self.0.id().to_string()
435 }
436
437 #[getter]
439 fn product_id(&self) -> usize {
440 self.0.product_id()
441 }
442
443 #[getter]
445 fn parent_id(&self) -> Option<usize> {
446 self.0.parent_id()
447 }
448
449 #[getter]
451 fn species(&self) -> Option<PyParticleSpecies> {
452 self.0.species().cloned().map(PyParticleSpecies)
453 }
454
455 #[getter]
457 fn p4_label(&self) -> Option<String> {
458 self.0.p4_label().map(str::to_string)
459 }
460
461 #[getter]
463 fn produced_vertex_id(&self) -> Option<usize> {
464 self.0.produced_vertex_id()
465 }
466
467 #[getter]
469 fn decay_vertex_id(&self) -> Option<usize> {
470 self.0.decay_vertex_id()
471 }
472
473 fn __repr__(&self) -> String {
474 format!("{:?}", self.0)
475 }
476}
477
478#[pyclass(name = "GeneratedVertexLayout", module = "laddu", from_py_object)]
480#[derive(Clone, Debug)]
481pub struct PyGeneratedVertexLayout(pub GeneratedVertexLayout);
482
483#[pymethods]
484impl PyGeneratedVertexLayout {
485 #[getter]
487 fn vertex_id(&self) -> usize {
488 self.0.vertex_id()
489 }
490
491 #[getter]
493 fn kind(&self) -> &'static str {
494 match self.0.kind() {
495 GeneratedVertexKind::Production => "Production",
496 GeneratedVertexKind::Decay => "Decay",
497 }
498 }
499
500 #[getter]
502 fn incoming_product_ids(&self) -> Vec<usize> {
503 self.0.incoming_product_ids().to_vec()
504 }
505
506 #[getter]
508 fn outgoing_product_ids(&self) -> Vec<usize> {
509 self.0.outgoing_product_ids().to_vec()
510 }
511
512 fn __repr__(&self) -> String {
513 format!("{:?}", self.0)
514 }
515}
516
517#[pyclass(name = "GeneratedEventLayout", module = "laddu", from_py_object)]
519#[derive(Clone, Debug)]
520pub struct PyGeneratedEventLayout(pub GeneratedEventLayout);
521
522#[pymethods]
523impl PyGeneratedEventLayout {
524 #[getter]
526 fn p4_labels(&self) -> Vec<String> {
527 self.0.p4_labels().to_vec()
528 }
529
530 #[getter]
532 fn aux_labels(&self) -> Vec<String> {
533 self.0.aux_labels().to_vec()
534 }
535
536 #[getter]
538 fn particles(&self) -> Vec<PyGeneratedParticleLayout> {
539 self.0
540 .particles()
541 .iter()
542 .cloned()
543 .map(PyGeneratedParticleLayout)
544 .collect()
545 }
546
547 fn particle(&self, id: &str) -> Option<PyGeneratedParticleLayout> {
549 self.0.particle(id).cloned().map(PyGeneratedParticleLayout)
550 }
551
552 fn product(&self, product_id: usize) -> Option<PyGeneratedParticleLayout> {
554 self.0
555 .product(product_id)
556 .cloned()
557 .map(PyGeneratedParticleLayout)
558 }
559
560 #[getter]
562 fn vertices(&self) -> Vec<PyGeneratedVertexLayout> {
563 self.0
564 .vertices()
565 .iter()
566 .cloned()
567 .map(PyGeneratedVertexLayout)
568 .collect()
569 }
570
571 fn vertex(&self, vertex_id: usize) -> Option<PyGeneratedVertexLayout> {
573 self.0
574 .vertex(vertex_id)
575 .cloned()
576 .map(PyGeneratedVertexLayout)
577 }
578
579 fn production_vertex(&self) -> Option<PyGeneratedVertexLayout> {
581 self.0
582 .production_vertex()
583 .cloned()
584 .map(PyGeneratedVertexLayout)
585 }
586
587 fn decay_products(&self, parent_product_id: usize) -> Vec<PyGeneratedParticleLayout> {
589 self.0
590 .decay_products(parent_product_id)
591 .into_iter()
592 .cloned()
593 .map(PyGeneratedParticleLayout)
594 .collect()
595 }
596
597 fn production_incoming(&self) -> Vec<PyGeneratedParticleLayout> {
599 self.0
600 .production_incoming()
601 .into_iter()
602 .cloned()
603 .map(PyGeneratedParticleLayout)
604 .collect()
605 }
606
607 fn production_outgoing(&self) -> Vec<PyGeneratedParticleLayout> {
609 self.0
610 .production_outgoing()
611 .into_iter()
612 .cloned()
613 .map(PyGeneratedParticleLayout)
614 .collect()
615 }
616
617 fn __repr__(&self) -> String {
618 format!("{:?}", self.0)
619 }
620}
621
622#[pyclass(name = "GeneratedBatch", module = "laddu", from_py_object)]
624#[derive(Clone, Debug)]
625pub struct PyGeneratedBatch(pub GeneratedBatch);
626
627#[pymethods]
628impl PyGeneratedBatch {
629 #[getter]
631 fn dataset(&self) -> PyDataset {
632 PyDataset(Arc::new(self.0.dataset().clone()))
633 }
634
635 #[getter]
637 fn reaction(&self) -> PyGeneratedReaction {
638 PyGeneratedReaction(self.0.reaction().clone())
639 }
640
641 #[getter]
643 fn layout(&self) -> PyGeneratedEventLayout {
644 PyGeneratedEventLayout(self.0.layout().clone())
645 }
646
647 fn __repr__(&self) -> String {
648 format!("{:?}", self.0)
649 }
650}
651
652#[pyclass(
654 name = "GeneratedBatchIter",
655 module = "laddu",
656 unsendable,
657 skip_from_py_object
658)]
659pub struct PyGeneratedBatchIter {
660 iter: Box<dyn Iterator<Item = laddu_core::LadduResult<GeneratedBatch>>>,
661}
662
663#[pymethods]
664impl PyGeneratedBatchIter {
665 fn __iter__(slf: PyRef<'_, Self>) -> Py<PyGeneratedBatchIter> {
666 slf.into()
667 }
668
669 fn __next__(&mut self) -> PyResult<Option<PyGeneratedBatch>> {
670 match self.iter.next() {
671 Some(Ok(batch)) => Ok(Some(PyGeneratedBatch(batch))),
672 Some(Err(err)) => Err(PyErr::from(err)),
673 None => Ok(None),
674 }
675 }
676}
677
678#[pyclass(name = "RejectionEnvelope", module = "laddu", from_py_object)]
680#[derive(Clone, Debug)]
681pub struct PyRejectionEnvelope(pub RejectionEnvelope);
682
683#[pymethods]
684impl PyRejectionEnvelope {
685 #[staticmethod]
687 fn fixed(max_weight: f64) -> Self {
688 Self(RejectionEnvelope::Fixed { max_weight })
689 }
690
691 #[staticmethod]
693 #[pyo3(signature = (pilot_events, *, safety_factor=1.2, batch_size=None))]
694 fn pilot(pilot_events: usize, safety_factor: f64, batch_size: Option<usize>) -> Self {
695 Self(RejectionEnvelope::Pilot {
696 pilot_events,
697 batch_size,
698 safety_factor,
699 })
700 }
701
702 fn __repr__(&self) -> String {
703 format!("{:?}", self.0)
704 }
705}
706
707#[pyclass(
709 name = "RejectionSamplingDiagnostics",
710 module = "laddu",
711 from_py_object
712)]
713#[derive(Clone, Debug)]
714pub struct PyRejectionSamplingDiagnostics(pub RejectionSamplingDiagnostics);
715
716#[pymethods]
717impl PyRejectionSamplingDiagnostics {
718 #[getter]
719 fn generated_events(&self) -> usize {
720 self.0.generated_events
721 }
722
723 #[getter]
724 fn accepted_events(&self) -> usize {
725 self.0.accepted_events
726 }
727
728 #[getter]
729 fn rejected_events(&self) -> usize {
730 self.0.rejected_events
731 }
732
733 #[getter]
734 fn max_observed_weight(&self) -> f64 {
735 self.0.max_observed_weight
736 }
737
738 #[getter]
739 fn envelope_max_weight(&self) -> f64 {
740 self.0.envelope_max_weight
741 }
742
743 #[getter]
744 fn envelope_violations(&self) -> usize {
745 self.0.envelope_violations
746 }
747
748 fn acceptance_efficiency(&self) -> f64 {
749 self.0.acceptance_efficiency()
750 }
751
752 fn __repr__(&self) -> String {
753 format!("{:?}", self.0)
754 }
755}
756
757#[pyclass(
759 name = "RejectionSampleIter",
760 module = "laddu",
761 unsendable,
762 skip_from_py_object
763)]
764pub struct PyRejectionSampleIter {
765 iter: RejectionSampleIter<ExpressionIntensity>,
766}
767
768#[pymethods]
769impl PyRejectionSampleIter {
770 fn __iter__(slf: PyRef<'_, Self>) -> Py<PyRejectionSampleIter> {
771 slf.into()
772 }
773
774 fn __next__(&mut self) -> PyResult<Option<PyGeneratedBatch>> {
775 match self.iter.next() {
776 Some(Ok(batch)) => Ok(Some(PyGeneratedBatch(batch))),
777 Some(Err(err)) => Err(PyErr::from(err)),
778 None => Ok(None),
779 }
780 }
781
782 #[getter]
783 fn diagnostics(&self) -> PyRejectionSamplingDiagnostics {
784 PyRejectionSamplingDiagnostics(self.iter.diagnostics().clone())
785 }
786}
787
788#[pyclass(name = "EventGenerator", module = "laddu", from_py_object)]
790#[derive(Clone, Debug)]
791pub struct PyEventGenerator(pub EventGenerator);
792
793#[pymethods]
794impl PyEventGenerator {
795 #[new]
797 #[pyo3(signature = (reaction, aux_generators=None, seed=None, storage=None))]
798 fn new(
799 reaction: &PyGeneratedReaction,
800 aux_generators: Option<HashMap<String, PyDistribution>>,
801 seed: Option<u64>,
802 storage: Option<&PyGeneratedStorage>,
803 ) -> PyResult<Self> {
804 let generator = EventGenerator::new(
805 reaction.0.clone(),
806 aux_generators
807 .unwrap_or_default()
808 .into_iter()
809 .map(|(name, distribution)| (name, distribution.0))
810 .collect(),
811 seed,
812 );
813 let generator = if let Some(storage) = storage {
814 generator.with_storage(storage.0.clone())?
815 } else {
816 generator
817 };
818 Ok(Self(generator))
819 }
820
821 fn generate_batch(&self, n_events: usize) -> PyResult<PyGeneratedBatch> {
823 Ok(PyGeneratedBatch(self.0.generate_batch(n_events)?))
824 }
825
826 fn generate_batches(
828 &self,
829 total_events: usize,
830 batch_size: usize,
831 ) -> PyResult<PyGeneratedBatchIter> {
832 Ok(PyGeneratedBatchIter {
833 iter: Box::new(self.0.generate_batches(total_events, batch_size)?),
834 })
835 }
836
837 #[allow(clippy::too_many_arguments)]
839 #[pyo3(signature = (
840 expression,
841 parameters,
842 *,
843 n_events,
844 generation_batch_size,
845 output_batch_size,
846 envelope,
847 seed=None
848 ))]
849 fn generate_batches_rejection(
850 &self,
851 expression: &PyExpression,
852 parameters: Vec<f64>,
853 n_events: usize,
854 generation_batch_size: usize,
855 output_batch_size: usize,
856 envelope: &PyRejectionEnvelope,
857 seed: Option<u64>,
858 ) -> PyResult<PyRejectionSampleIter> {
859 let sampler = self.0.rejection_sampler_with_expression(
860 expression.0.clone(),
861 parameters,
862 RejectionSamplingOptions {
863 target_accepted: n_events,
864 generation_batch_size,
865 output_batch_size,
866 envelope: envelope.0.clone(),
867 seed: seed.unwrap_or_else(|| fastrand::u64(..)),
868 },
869 )?;
870 Ok(PyRejectionSampleIter {
871 iter: sampler.accepted_batches(),
872 })
873 }
874
875 #[allow(clippy::too_many_arguments)]
877 #[pyo3(signature = (
878 expression,
879 parameters,
880 *,
881 n_events,
882 generation_batch_size,
883 output_batch_size,
884 envelope,
885 seed=None
886 ))]
887 fn generate_dataset_rejection(
888 &self,
889 expression: &PyExpression,
890 parameters: Vec<f64>,
891 n_events: usize,
892 generation_batch_size: usize,
893 output_batch_size: usize,
894 envelope: &PyRejectionEnvelope,
895 seed: Option<u64>,
896 ) -> PyResult<PyDataset> {
897 let mut iter = self.generate_batches_rejection(
898 expression,
899 parameters,
900 n_events,
901 generation_batch_size,
902 output_batch_size,
903 envelope,
904 seed,
905 )?;
906 let mut output = None;
907 while let Some(batch) = iter.__next__()? {
908 let dataset = batch.0.into_dataset();
909 if output.is_none() {
910 output = Some(laddu_core::Dataset::empty_local(dataset.metadata().clone()));
911 }
912 let output_dataset = output.as_mut().expect("output dataset should exist");
913 for index in 0..dataset.n_events() {
914 let event = dataset.event_global(index)?;
915 output_dataset.push_event_local(
916 event.p4s.clone(),
917 event.aux.clone(),
918 event.weight,
919 )?;
920 }
921 }
922 let output = match output {
923 Some(output) => output,
924 None => self.0.generate_batch(0)?.into_dataset(),
925 };
926 Ok(PyDataset(Arc::new(output)))
927 }
928
929 fn generate_dataset(&self, n_events: usize) -> PyResult<PyDataset> {
931 Ok(PyDataset(Arc::new(self.0.generate_dataset(n_events)?)))
932 }
933
934 fn __repr__(&self) -> String {
935 format!("{:?}", self.0)
936 }
937}