Skip to main content

laddu_generation/
topology.rs

1use std::{
2    collections::{HashMap, HashSet},
3    sync::Arc,
4};
5
6use fastrand::Rng;
7use laddu_core::{
8    math::{q_m, Histogram, Sheet},
9    Dataset, DatasetMetadata, Expression, LadduError, LadduResult, Particle, Reaction, Vec3, Vec4,
10    PI,
11};
12use serde::{Deserialize, Serialize};
13
14use crate::distributions::{
15    Distribution, HistogramSampler, LadduGenRngExt, MandelstamTDistribution, SimpleDistribution,
16};
17
18/// Selects which generated particle four-momenta are written into generated datasets.
19///
20/// The generated reaction layout always retains the full generated graph. This policy only controls
21/// which generated particle IDs become p4 columns in generated [`Dataset`] values and which
22/// particles have a p4 label in [`GeneratedEventLayout`].
23#[derive(Clone, Debug, PartialEq, Eq)]
24pub enum GeneratedStorage {
25    /// Store every generated particle p4.
26    All,
27    /// Store only the listed generated particle IDs, preserving reaction p4-label order.
28    Only(Vec<String>),
29}
30
31impl GeneratedStorage {
32    /// Store every generated particle p4.
33    pub fn all() -> Self {
34        Self::All
35    }
36
37    /// Store only the listed generated particle IDs.
38    pub fn only<I, S>(ids: I) -> Self
39    where
40        I: IntoIterator<Item = S>,
41        S: Into<String>,
42    {
43        Self::Only(ids.into_iter().map(Into::into).collect())
44    }
45
46    /// Return true if `id` is selected for dataset storage.
47    pub fn stores(&self, id: &str) -> bool {
48        match self {
49            Self::All => true,
50            Self::Only(ids) => ids.iter().any(|stored_id| stored_id == id),
51        }
52    }
53
54    fn validate(&self, available_ids: &[String]) -> LadduResult<()> {
55        let available = available_ids
56            .iter()
57            .map(String::as_str)
58            .collect::<HashSet<_>>();
59        let Self::Only(ids) = self else {
60            return Ok(());
61        };
62        let mut seen = HashSet::new();
63        for id in ids {
64            if !seen.insert(id.as_str()) {
65                return Err(LadduError::Custom(format!(
66                    "generated storage contains duplicate particle ID '{id}'"
67                )));
68            }
69            if !available.contains(id.as_str()) {
70                return Err(LadduError::Custom(format!(
71                    "generated storage references unknown particle ID '{id}'"
72                )));
73            }
74        }
75        Ok(())
76    }
77
78    fn stored_labels(&self, all_labels: &[String]) -> Vec<String> {
79        all_labels
80            .iter()
81            .filter(|label| self.stores(label))
82            .cloned()
83            .collect()
84    }
85}
86
87/// Experiment-neutral metadata describing a generated particle species.
88///
89/// Species metadata is intentionally separate from generated particle IDs and reconstructed
90/// reaction particles. It is meant for generator/export layers that need an external particle code
91/// or label without forcing laddu to adopt an experiment-specific particle table.
92#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
93pub enum ParticleSpecies {
94    /// A numeric species code with an optional namespace.
95    Code {
96        /// Numeric species identifier.
97        id: i64,
98        /// Optional namespace, such as `"pdg"`.
99        namespace: Option<String>,
100    },
101    /// A free-form species label.
102    Label(String),
103}
104
105impl ParticleSpecies {
106    /// Construct a species from a numeric code with no namespace.
107    pub fn code(id: i64) -> Self {
108        Self::Code {
109            id,
110            namespace: None,
111        }
112    }
113
114    /// Construct a species from a numeric code in an explicit namespace.
115    pub fn with_namespace(namespace: impl Into<String>, id: i64) -> Self {
116        Self::Code {
117            id,
118            namespace: Some(namespace.into()),
119        }
120    }
121
122    /// Construct a species from a free-form label.
123    pub fn label(label: impl Into<String>) -> Self {
124        Self::Label(label.into())
125    }
126}
127
128fn basis(z: Vec3) -> (Vec3, Vec3, Vec3) {
129    let z = z.unit();
130    let ref_axis = if z.z.abs() < 0.9 {
131        Vec3::z()
132    } else {
133        Vec3::y()
134    };
135    let x = ref_axis.cross(&z).unit();
136    let y = z.cross(&x);
137    (x, y, z)
138}
139
140/// Generator settings for an initial-state particle.
141#[derive(Clone, Debug)]
142pub struct InitialGenerator {
143    mass: f64,
144    energy_distribution: SimpleDistribution,
145}
146
147impl InitialGenerator {
148    /// Construct a beam with fixed energy.
149    pub fn beam_with_fixed_energy(mass: f64, energy: f64) -> Self {
150        debug_assert!(mass >= 0.0, "Mass cannot be negative!\nMass: {}", mass);
151        debug_assert!(energy > 0.0, "Energy must be positive!\nEnergy: {}", energy);
152        Self {
153            mass,
154            energy_distribution: SimpleDistribution::Fixed(energy),
155        }
156    }
157
158    /// Construct a beam with uniformly sampled energy.
159    pub fn beam(mass: f64, min_energy: f64, max_energy: f64) -> Self {
160        debug_assert!(mass >= 0.0, "Mass cannot be negative!\nMass: {}", mass);
161        debug_assert!(
162            min_energy > 0.0,
163            "Minimum energy must be positive!\nMinimum Energy: {}",
164            min_energy
165        );
166        debug_assert!(
167            max_energy > min_energy,
168            "Maximum energy must be greater than minimum energy!"
169        );
170        Self {
171            mass,
172            energy_distribution: SimpleDistribution::Uniform {
173                min: min_energy,
174                max: max_energy,
175            },
176        }
177    }
178
179    /// Construct a beam with histogram-sampled energy.
180    pub fn beam_with_energy_histogram(mass: f64, energy: Histogram) -> LadduResult<Self> {
181        debug_assert!(
182            mass >= 0.0,
183            "Mass must be positive and greater than zero!\nMass: {}",
184            mass
185        );
186        let sampler = HistogramSampler::new(energy)?;
187        debug_assert!(
188            sampler.hist.bin_edges()[0] >= mass,
189            "Mass cannot be greater than the minimum allowed energy!\nMass: {}\nMinimum Energy: {}",
190            mass,
191            sampler.hist.bin_edges()[0]
192        );
193        Ok(Self {
194            mass,
195            energy_distribution: SimpleDistribution::Histogram(sampler),
196        })
197    }
198
199    /// Construct a target at rest.
200    pub fn target(mass: f64) -> Self {
201        Self {
202            mass,
203            energy_distribution: SimpleDistribution::Fixed(mass),
204        }
205    }
206}
207
208/// Generator settings for a generated composite particle.
209#[derive(Clone, Debug)]
210pub struct CompositeGenerator {
211    mass_distribution: SimpleDistribution,
212}
213
214impl CompositeGenerator {
215    /// Construct a composite mass generator with a uniform mass range.
216    pub fn new(min_mass: f64, max_mass: f64) -> Self {
217        Self {
218            mass_distribution: SimpleDistribution::Uniform {
219                min: min_mass,
220                max: max_mass,
221            },
222        }
223    }
224
225    fn sample_mass(&self, rng: &mut Rng) -> f64 {
226        self.mass_distribution.sample(rng)
227    }
228}
229
230/// Generator settings for a stable generated particle.
231#[derive(Clone, Debug)]
232pub struct StableGenerator {
233    mass_distribution: SimpleDistribution,
234}
235
236impl StableGenerator {
237    /// Construct a fixed-mass stable-particle generator.
238    pub fn new(mass: f64) -> Self {
239        debug_assert!(mass >= 0.0, "Mass cannot be negative!\nMass: {}", mass);
240        Self {
241            mass_distribution: SimpleDistribution::Fixed(mass),
242        }
243    }
244
245    fn sample_mass(&self, rng: &mut Rng) -> f64 {
246        self.mass_distribution.sample(rng)
247    }
248}
249
250/// Reconstruction interpretation for a generated particle.
251#[derive(Clone, Debug, PartialEq)]
252pub enum Reconstruction {
253    /// The particle p4 is stored in the analysis dataset under the generated particle ID.
254    Stored,
255    /// The particle p4 is fixed in the reconstructed reaction.
256    Fixed(Vec4),
257    /// The particle p4 is inferred from reaction-level constraints.
258    Missing,
259    /// The particle p4 is reconstructed as a composite of its two generated daughters.
260    Composite,
261}
262
263/// A generated particle with generation and reconstruction metadata.
264#[derive(Clone, Debug)]
265pub enum GeneratedParticle {
266    /// An initial-state generated particle.
267    Initial {
268        id: String,
269        generator: InitialGenerator,
270        reconstruction: Reconstruction,
271        species: Option<ParticleSpecies>,
272    },
273    /// A stable generated particle.
274    Stable {
275        id: String,
276        generator: StableGenerator,
277        reconstruction: Reconstruction,
278        species: Option<ParticleSpecies>,
279    },
280    /// A generated composite particle with exactly two generated daughters.
281    Composite {
282        id: String,
283        generator: CompositeGenerator,
284        daughters: (Box<GeneratedParticle>, Box<GeneratedParticle>),
285        reconstruction: Reconstruction,
286        species: Option<ParticleSpecies>,
287    },
288}
289
290impl GeneratedParticle {
291    /// Construct a generated initial-state particle.
292    pub fn initial(
293        id: impl Into<String>,
294        generator: InitialGenerator,
295        reconstruction: Reconstruction,
296    ) -> Self {
297        Self::Initial {
298            id: id.into(),
299            generator,
300            reconstruction,
301            species: None,
302        }
303    }
304
305    /// Construct a generated stable particle.
306    pub fn stable(
307        id: impl Into<String>,
308        generator: StableGenerator,
309        reconstruction: Reconstruction,
310    ) -> Self {
311        Self::Stable {
312            id: id.into(),
313            generator,
314            reconstruction,
315            species: None,
316        }
317    }
318
319    /// Construct a generated composite particle from exactly two ordered daughters.
320    pub fn composite(
321        id: impl Into<String>,
322        generator: CompositeGenerator,
323        daughters: (&GeneratedParticle, &GeneratedParticle),
324        reconstruction: Reconstruction,
325    ) -> Self {
326        Self::Composite {
327            id: id.into(),
328            generator,
329            daughters: (Box::new(daughters.0.clone()), Box::new(daughters.1.clone())),
330            reconstruction,
331            species: None,
332        }
333    }
334
335    /// Return a copy of this generated particle with species metadata attached.
336    pub fn with_species(mut self, species: ParticleSpecies) -> Self {
337        match &mut self {
338            Self::Initial {
339                species: particle_species,
340                ..
341            }
342            | Self::Stable {
343                species: particle_species,
344                ..
345            }
346            | Self::Composite {
347                species: particle_species,
348                ..
349            } => *particle_species = Some(species),
350        }
351        self
352    }
353
354    /// Return the generated particle ID.
355    pub fn id(&self) -> &str {
356        match self {
357            Self::Initial { id, .. } | Self::Stable { id, .. } | Self::Composite { id, .. } => id,
358        }
359    }
360
361    /// Return optional species metadata for this generated particle.
362    pub fn species(&self) -> Option<&ParticleSpecies> {
363        match self {
364            Self::Initial { species, .. }
365            | Self::Stable { species, .. }
366            | Self::Composite { species, .. } => species.as_ref(),
367        }
368    }
369
370    /// Return this particle's reconstruction interpretation.
371    pub fn reconstruction(&self) -> &Reconstruction {
372        match self {
373            Self::Initial { reconstruction, .. }
374            | Self::Stable { reconstruction, .. }
375            | Self::Composite { reconstruction, .. } => reconstruction,
376        }
377    }
378
379    fn p4_labels(&self) -> Vec<String> {
380        let mut labels = vec![self.id().to_string()];
381        if let Self::Composite { daughters, .. } = self {
382            labels.append(&mut daughters.0.p4_labels());
383            labels.append(&mut daughters.1.p4_labels());
384        }
385        labels
386    }
387
388    fn append_decay_layout(
389        &self,
390        parent_id: Option<usize>,
391        produced_vertex_id: Option<usize>,
392        storage: &GeneratedStorage,
393        particles: &mut Vec<GeneratedParticleLayout>,
394        vertices: &mut Vec<GeneratedVertexLayout>,
395    ) -> usize {
396        let product_id = particles.len();
397        particles.push(GeneratedParticleLayout {
398            id: self.id().to_string(),
399            product_id,
400            parent_id,
401            species: self.species().cloned(),
402            p4_label: storage.stores(self.id()).then(|| self.id().to_string()),
403            produced_vertex_id,
404            decay_vertex_id: None,
405        });
406        if let Self::Composite { daughters, .. } = self {
407            let vertex_id = vertices.len();
408            particles[product_id].decay_vertex_id = Some(vertex_id);
409            vertices.push(GeneratedVertexLayout {
410                vertex_id,
411                kind: GeneratedVertexKind::Decay,
412                incoming_product_ids: vec![product_id],
413                outgoing_product_ids: Vec::new(),
414            });
415            let daughter_1_id = daughters.0.append_decay_layout(
416                Some(product_id),
417                Some(vertex_id),
418                storage,
419                particles,
420                vertices,
421            );
422            let daughter_2_id = daughters.1.append_decay_layout(
423                Some(product_id),
424                Some(vertex_id),
425                storage,
426                particles,
427                vertices,
428            );
429            vertices[vertex_id].outgoing_product_ids = vec![daughter_1_id, daughter_2_id];
430        }
431        product_id
432    }
433
434    fn sample_mass(&self, rng: &mut Rng) -> f64 {
435        match self {
436            Self::Initial { generator, .. } => generator.mass,
437            Self::Stable { generator, .. } => generator.sample_mass(rng),
438            Self::Composite { generator, .. } => generator.sample_mass(rng),
439        }
440    }
441
442    fn generated_particle(&self) -> LadduResult<Particle> {
443        match self.reconstruction() {
444            Reconstruction::Stored => Ok(Particle::stored(self.id())),
445            Reconstruction::Fixed(p4) => Ok(Particle::fixed(self.id(), *p4)),
446            Reconstruction::Missing => Ok(Particle::missing(self.id())),
447            Reconstruction::Composite => {
448                let Self::Composite { daughters, .. } = self else {
449                    return Err(LadduError::Custom(format!(
450                        "particle '{}' cannot use composite reconstruction without daughters",
451                        self.id()
452                    )));
453                };
454                let daughter_1 = daughters.0.generated_particle()?;
455                let daughter_2 = daughters.1.generated_particle()?;
456                Particle::composite(self.id(), (&daughter_1, &daughter_2))
457            }
458        }
459    }
460
461    fn validate_reconstruction(&self) -> LadduResult<()> {
462        match (self, self.reconstruction()) {
463            (Self::Composite { daughters, .. }, Reconstruction::Composite) => {
464                daughters.0.validate_reconstruction()?;
465                daughters.1.validate_reconstruction()?;
466                Ok(())
467            }
468            (Self::Composite { .. }, _) => Ok(()),
469            (_, Reconstruction::Composite) => Err(LadduError::Custom(format!(
470                "particle '{}' cannot use composite reconstruction without daughters",
471                self.id()
472            ))),
473            _ => Ok(()),
474        }
475    }
476
477    fn collect_ids<'a>(&'a self, seen: &mut HashSet<&'a str>) -> LadduResult<()> {
478        if !seen.insert(self.id()) {
479            return Err(LadduError::Custom(format!(
480                "duplicate generated particle identifier '{}'",
481                self.id()
482            )));
483        }
484        if let Self::Composite { daughters, .. } = self {
485            daughters.0.collect_ids(seen)?;
486            daughters.1.collect_ids(seen)?;
487        }
488        Ok(())
489    }
490
491    fn generate_decay(
492        &self,
493        rng: &mut Rng,
494        p4_cm: Vec4,
495        cm_to_lab_boost: &Vec3,
496        p4_storage: &mut HashMap<String, Vec<Vec4>>,
497    ) {
498        let p4_lab = p4_cm.boost(cm_to_lab_boost);
499        if let Some(storage) = p4_storage.get_mut(self.id()) {
500            storage.push(p4_lab);
501        }
502
503        let Self::Composite { daughters, .. } = self else {
504            return;
505        };
506        let d1 = &daughters.0;
507        let d2 = &daughters.1;
508        let parent_mass = p4_cm.m();
509        let m1 = d1.sample_mass(rng);
510        let m2 = d2.sample_mass(rng);
511        let q = q_m(parent_mass, m1, m2, Sheet::Physical).re;
512        let parent_msq = parent_mass * parent_mass;
513        let msq1 = m1 * m1;
514        let msq2 = m2 * m2;
515        let e1 = (parent_msq + msq1 - msq2) / (2.0 * parent_mass);
516        let e2 = (parent_msq + msq2 - msq1) / (2.0 * parent_mass);
517
518        let cos_theta = rng.uniform(-1.0, 1.0);
519        let sin_theta = (1.0 - cos_theta * cos_theta).sqrt();
520        let phi = rng.uniform(0.0, 2.0 * PI);
521        let (sin_phi, cos_phi) = phi.sin_cos();
522
523        let dir = Vec3::new(sin_theta * cos_phi, sin_theta * sin_phi, cos_theta);
524        let p1_p4_rest = (dir * q).with_energy(e1);
525        let p2_p4_rest = (-dir * q).with_energy(e2);
526        let parent_to_cm_boost = p4_cm.beta();
527        let p1_p4_cm = p1_p4_rest.boost(&parent_to_cm_boost);
528        let p2_p4_cm = p2_p4_rest.boost(&parent_to_cm_boost);
529        d1.generate_decay(rng, p1_p4_cm, cm_to_lab_boost, p4_storage);
530        d2.generate_decay(rng, p2_p4_cm, cm_to_lab_boost, p4_storage);
531    }
532}
533
534/// A generated two-to-two reaction preserving `p1 + p2 -> p3 + p4` role semantics.
535#[derive(Clone, Debug)]
536pub struct GeneratedTwoToTwoReaction {
537    p1: GeneratedParticle,
538    p2: GeneratedParticle,
539    p3: GeneratedParticle,
540    p4: GeneratedParticle,
541    tdist: MandelstamTDistribution,
542    p1_p3_lab_dir: Vec3,
543    p2_p3_lab_dir: Vec3,
544}
545
546impl GeneratedTwoToTwoReaction {
547    /// Construct a generated two-to-two reaction.
548    pub fn new(
549        p1: GeneratedParticle,
550        p2: GeneratedParticle,
551        p3: GeneratedParticle,
552        p4: GeneratedParticle,
553        tdist: MandelstamTDistribution,
554    ) -> LadduResult<Self> {
555        validate_initial_role(&p1, "p1")?;
556        validate_initial_role(&p2, "p2")?;
557        validate_final_role(&p3, "p3")?;
558        validate_final_role(&p4, "p4")?;
559        let reaction = Self {
560            p1,
561            p2,
562            p3,
563            p4,
564            tdist,
565            p1_p3_lab_dir: Vec3::z(),
566            p2_p3_lab_dir: -Vec3::z(),
567        };
568        reaction.validate()?;
569        Ok(reaction)
570    }
571
572    fn validate(&self) -> LadduResult<()> {
573        let mut seen = HashSet::new();
574        for particle in [&self.p1, &self.p2, &self.p3, &self.p4] {
575            particle.collect_ids(&mut seen)?;
576            particle.validate_reconstruction()?;
577        }
578        Ok(())
579    }
580
581    fn p4_labels(&self) -> Vec<String> {
582        let mut labels = Vec::new();
583        for particle in [&self.p1, &self.p2, &self.p3, &self.p4] {
584            labels.append(&mut particle.p4_labels());
585        }
586        labels
587    }
588
589    fn layout_components(
590        &self,
591        storage: &GeneratedStorage,
592    ) -> (Vec<GeneratedParticleLayout>, Vec<GeneratedVertexLayout>) {
593        let mut particles = Vec::new();
594        let mut vertices = vec![GeneratedVertexLayout {
595            vertex_id: 0,
596            kind: GeneratedVertexKind::Production,
597            incoming_product_ids: Vec::new(),
598            outgoing_product_ids: Vec::new(),
599        }];
600        let p1_id = self
601            .p1
602            .append_decay_layout(None, None, storage, &mut particles, &mut vertices);
603        let p2_id = self
604            .p2
605            .append_decay_layout(None, None, storage, &mut particles, &mut vertices);
606        let p3_id =
607            self.p3
608                .append_decay_layout(None, Some(0), storage, &mut particles, &mut vertices);
609        let p4_id =
610            self.p4
611                .append_decay_layout(None, Some(0), storage, &mut particles, &mut vertices);
612        vertices[0].incoming_product_ids = vec![p1_id, p2_id];
613        vertices[0].outgoing_product_ids = vec![p3_id, p4_id];
614        (particles, vertices)
615    }
616
617    fn particle_layouts(&self) -> Vec<GeneratedParticleLayout> {
618        self.particle_layouts_with_storage(&GeneratedStorage::All)
619    }
620
621    fn particle_layouts_with_storage(
622        &self,
623        storage: &GeneratedStorage,
624    ) -> Vec<GeneratedParticleLayout> {
625        self.layout_components(storage).0
626    }
627
628    fn vertex_layouts(&self) -> Vec<GeneratedVertexLayout> {
629        self.layout_components(&GeneratedStorage::All).1
630    }
631
632    fn reconstructed_reaction(&self) -> LadduResult<Reaction> {
633        Reaction::two_to_two(
634            &self.p1.generated_particle()?,
635            &self.p2.generated_particle()?,
636            &self.p3.generated_particle()?,
637            &self.p4.generated_particle()?,
638        )
639    }
640
641    fn generate_event(&self, rng: &mut Rng, p4_storage: &mut HashMap<String, Vec<Vec4>>) {
642        let GeneratedParticle::Initial {
643            id: p1_id,
644            generator: p1_generator,
645            ..
646        } = &self.p1
647        else {
648            unreachable!("validated generated two-to-two p1 role")
649        };
650        let GeneratedParticle::Initial {
651            id: p2_id,
652            generator: p2_generator,
653            ..
654        } = &self.p2
655        else {
656            unreachable!("validated generated two-to-two p2 role")
657        };
658
659        let p1_e = p1_generator.energy_distribution.sample(rng);
660        let p1_m = p1_generator.mass;
661        let p1_msq = p1_m * p1_m;
662        let p1_p4_lab = rng.p4(p1_m, p1_e, self.p1_p3_lab_dir);
663        if let Some(storage) = p4_storage.get_mut(p1_id) {
664            storage.push(p1_p4_lab)
665        }
666
667        let p2_e = p2_generator.energy_distribution.sample(rng);
668        let p2_m = p2_generator.mass;
669        let p2_msq = p2_m * p2_m;
670        let p2_p4_lab = rng.p4(p2_m, p2_e, self.p2_p3_lab_dir);
671        if let Some(storage) = p4_storage.get_mut(p2_id) {
672            storage.push(p2_p4_lab)
673        }
674
675        let cm = p1_p4_lab + p2_p4_lab;
676        let cm_boost = -cm.beta();
677        let s = cm.mag2();
678        let sqrt_s = s.sqrt();
679        let t = self.tdist.sample(rng);
680
681        let p1_p4_cm = p1_p4_lab.boost(&cm_boost);
682        let p3_m = self.p3.sample_mass(rng);
683        let p3_msq = p3_m * p3_m;
684        let p4_m = self.p4.sample_mass(rng);
685        let p4_msq = p4_m * p4_m;
686        let p_in_mag = q_m(sqrt_s, p1_m, p2_m, Sheet::Physical).re;
687        let p_out_mag = q_m(sqrt_s, p3_m, p4_m, Sheet::Physical).re;
688        let p1_e_cm = (s + p1_msq - p2_msq) / (2.0 * sqrt_s);
689        let p3_e_cm = (s + p3_msq - p4_msq) / (2.0 * sqrt_s);
690        let p4_e_cm = (s + p4_msq - p3_msq) / (2.0 * sqrt_s);
691        let costheta =
692            (t - p1_msq - p3_msq + 2.0 * p1_e_cm * p3_e_cm) / (2.0 * p_in_mag * p_out_mag);
693        let costheta = costheta.clamp(-1.0, 1.0);
694        let sintheta = (1.0 - costheta * costheta).sqrt();
695        let phi = rng.uniform(0.0, 2.0 * PI);
696        let (sin_phi, cos_phi) = phi.sin_cos();
697        let (x, y, z) = basis(p1_p4_cm.vec3());
698        let p3_dir_cm = x * (sintheta * cos_phi) + y * (sintheta * sin_phi) + z * costheta;
699
700        let p3_p4_cm = (p3_dir_cm * p_out_mag).with_energy(p3_e_cm);
701        self.p3
702            .generate_decay(rng, p3_p4_cm, &-cm_boost, p4_storage);
703        let p4_p4_cm = (-p3_dir_cm * p_out_mag).with_energy(p4_e_cm);
704        self.p4
705            .generate_decay(rng, p4_p4_cm, &-cm_boost, p4_storage);
706    }
707}
708
709fn validate_initial_role(particle: &GeneratedParticle, role: &str) -> LadduResult<()> {
710    if matches!(particle, GeneratedParticle::Initial { .. }) {
711        Ok(())
712    } else {
713        Err(LadduError::Custom(format!(
714            "generated two-to-two role '{role}' requires an initial particle"
715        )))
716    }
717}
718
719fn validate_final_role(particle: &GeneratedParticle, role: &str) -> LadduResult<()> {
720    if matches!(
721        particle,
722        GeneratedParticle::Stable { .. } | GeneratedParticle::Composite { .. }
723    ) {
724        Ok(())
725    } else {
726        Err(LadduError::Custom(format!(
727            "generated two-to-two role '{role}' requires an outgoing particle"
728        )))
729    }
730}
731
732/// A generated reaction topology.
733#[derive(Clone, Debug)]
734pub enum GeneratedReactionTopology {
735    /// A generated two-to-two topology.
736    TwoToTwo(GeneratedTwoToTwoReaction),
737}
738
739impl GeneratedReactionTopology {
740    fn p4_labels(&self) -> Vec<String> {
741        match self {
742            Self::TwoToTwo(reaction) => reaction.p4_labels(),
743        }
744    }
745
746    fn particle_layouts(&self) -> Vec<GeneratedParticleLayout> {
747        match self {
748            Self::TwoToTwo(reaction) => reaction.particle_layouts(),
749        }
750    }
751
752    fn particle_layouts_with_storage(
753        &self,
754        storage: &GeneratedStorage,
755    ) -> Vec<GeneratedParticleLayout> {
756        match self {
757            Self::TwoToTwo(reaction) => reaction.particle_layouts_with_storage(storage),
758        }
759    }
760
761    fn vertex_layouts(&self) -> Vec<GeneratedVertexLayout> {
762        match self {
763            Self::TwoToTwo(reaction) => reaction.vertex_layouts(),
764        }
765    }
766
767    fn reconstructed_reaction(&self) -> LadduResult<Reaction> {
768        match self {
769            Self::TwoToTwo(reaction) => reaction.reconstructed_reaction(),
770        }
771    }
772
773    fn generate_event(&self, rng: &mut Rng, p4_storage: &mut HashMap<String, Vec<Vec4>>) {
774        match self {
775            Self::TwoToTwo(reaction) => reaction.generate_event(rng, p4_storage),
776        }
777    }
778}
779
780/// A generated reaction layout.
781#[derive(Clone, Debug)]
782pub struct GeneratedReaction {
783    topology: GeneratedReactionTopology,
784}
785
786impl GeneratedReaction {
787    /// Construct a generated two-to-two reaction.
788    pub fn two_to_two(
789        p1: GeneratedParticle,
790        p2: GeneratedParticle,
791        p3: GeneratedParticle,
792        p4: GeneratedParticle,
793        tdist: MandelstamTDistribution,
794    ) -> LadduResult<Self> {
795        Ok(Self {
796            topology: GeneratedReactionTopology::TwoToTwo(GeneratedTwoToTwoReaction::new(
797                p1, p2, p3, p4, tdist,
798            )?),
799        })
800    }
801
802    /// Return generated p4 labels.
803    pub fn p4_labels(&self) -> Vec<String> {
804        self.topology.p4_labels()
805    }
806
807    /// Return generated particle layout entries in stable product-ID order.
808    pub fn particle_layouts(&self) -> Vec<GeneratedParticleLayout> {
809        self.topology.particle_layouts()
810    }
811
812    /// Return generated particle layout entries for a dataset storage policy.
813    pub fn particle_layouts_with_storage(
814        &self,
815        storage: &GeneratedStorage,
816    ) -> Vec<GeneratedParticleLayout> {
817        self.topology.particle_layouts_with_storage(storage)
818    }
819
820    /// Return generated vertex layout entries in stable vertex-ID order.
821    pub fn vertex_layouts(&self) -> Vec<GeneratedVertexLayout> {
822        self.topology.vertex_layouts()
823    }
824
825    /// Build the reconstructed reaction corresponding to this generated layout.
826    pub fn reconstructed_reaction(&self) -> LadduResult<Reaction> {
827        self.topology.reconstructed_reaction()
828    }
829
830    fn generate(
831        &self,
832        rng: &mut Rng,
833        p4_storage: &mut HashMap<String, Vec<Vec4>>,
834        n_events: usize,
835    ) {
836        for _ in 0..n_events {
837            self.topology.generate_event(rng, p4_storage);
838        }
839    }
840}
841
842/// Metadata for one generated particle in a generated event layout.
843#[derive(Clone, Debug, PartialEq, Eq)]
844pub struct GeneratedParticleLayout {
845    id: String,
846    product_id: usize,
847    parent_id: Option<usize>,
848    species: Option<ParticleSpecies>,
849    p4_label: Option<String>,
850    produced_vertex_id: Option<usize>,
851    decay_vertex_id: Option<usize>,
852}
853
854impl GeneratedParticleLayout {
855    /// Return the generated particle identifier.
856    pub fn id(&self) -> &str {
857        &self.id
858    }
859
860    /// Return the zero-based stable product ID in generated-layout order.
861    pub fn product_id(&self) -> usize {
862        self.product_id
863    }
864
865    /// Return the decay-parent product ID, if this particle is a decay daughter.
866    pub fn parent_id(&self) -> Option<usize> {
867        self.parent_id
868    }
869
870    /// Return optional species metadata associated with this generated particle.
871    pub fn species(&self) -> Option<&ParticleSpecies> {
872        self.species.as_ref()
873    }
874
875    /// Return the dataset p4 label associated with this particle, if stored in the batch.
876    pub fn p4_label(&self) -> Option<&str> {
877        self.p4_label.as_deref()
878    }
879
880    /// Return the vertex ID where this particle was produced, if any.
881    pub fn produced_vertex_id(&self) -> Option<usize> {
882        self.produced_vertex_id
883    }
884
885    /// Return the vertex ID where this particle decays, if it is a generated parent.
886    pub fn decay_vertex_id(&self) -> Option<usize> {
887        self.decay_vertex_id
888    }
889}
890
/// The semantic kind of a generated vertex.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum GeneratedVertexKind {
    /// The production vertex, where the initial-state particles become the outgoing products.
    Production,
    /// A decay vertex, where one generated parent splits into its generated daughters.
    Decay,
}

/// Metadata for one generated vertex in a generated event layout.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct GeneratedVertexLayout {
    /// Zero-based vertex ID in generated-layout order.
    vertex_id: usize,
    /// Whether this is the production vertex or a decay vertex.
    kind: GeneratedVertexKind,
    /// Product IDs of the particles entering the vertex.
    incoming_product_ids: Vec<usize>,
    /// Product IDs of the particles leaving the vertex.
    outgoing_product_ids: Vec<usize>,
}

impl GeneratedVertexLayout {
    /// Zero-based vertex ID, stable for a given generated reaction.
    pub fn vertex_id(&self) -> usize {
        self.vertex_id
    }

    /// Production or decay.
    pub fn kind(&self) -> GeneratedVertexKind {
        self.kind
    }

    /// Product IDs of the particles entering this vertex.
    pub fn incoming_product_ids(&self) -> &[usize] {
        self.incoming_product_ids.as_slice()
    }

    /// Product IDs of the particles leaving this vertex.
    pub fn outgoing_product_ids(&self) -> &[usize] {
        self.outgoing_product_ids.as_slice()
    }
}
930
931/// Metadata describing the columns and generated particles in a generated event batch.
932#[derive(Clone, Debug, PartialEq, Eq)]
933pub struct GeneratedEventLayout {
934    p4_labels: Vec<String>,
935    aux_labels: Vec<String>,
936    particles: Vec<GeneratedParticleLayout>,
937    vertices: Vec<GeneratedVertexLayout>,
938}
939
940impl GeneratedEventLayout {
941    /// Construct generated event layout metadata from p4 and auxiliary labels.
942    pub fn new(
943        p4_labels: Vec<String>,
944        aux_labels: Vec<String>,
945        particles: Vec<GeneratedParticleLayout>,
946        vertices: Vec<GeneratedVertexLayout>,
947    ) -> Self {
948        Self {
949            p4_labels,
950            aux_labels,
951            particles,
952            vertices,
953        }
954    }
955
956    /// Return generated p4 column labels in dataset order.
957    pub fn p4_labels(&self) -> &[String] {
958        &self.p4_labels
959    }
960
961    /// Return generated auxiliary column labels in dataset order.
962    pub fn aux_labels(&self) -> &[String] {
963        &self.aux_labels
964    }
965
966    /// Return generated particle layout entries in stable product-ID order.
967    pub fn particles(&self) -> &[GeneratedParticleLayout] {
968        &self.particles
969    }
970
971    /// Return the generated particle layout for a generated particle ID.
972    pub fn particle(&self, id: &str) -> Option<&GeneratedParticleLayout> {
973        self.particles.iter().find(|particle| particle.id() == id)
974    }
975
976    /// Return the generated particle layout for a stable product ID.
977    pub fn product(&self, product_id: usize) -> Option<&GeneratedParticleLayout> {
978        self.particles
979            .iter()
980            .find(|particle| particle.product_id() == product_id)
981    }
982
983    /// Return generated vertex layout entries in stable vertex-ID order.
984    pub fn vertices(&self) -> &[GeneratedVertexLayout] {
985        &self.vertices
986    }
987
988    /// Return the generated vertex layout for a stable vertex ID.
989    pub fn vertex(&self, vertex_id: usize) -> Option<&GeneratedVertexLayout> {
990        self.vertices
991            .iter()
992            .find(|vertex| vertex.vertex_id() == vertex_id)
993    }
994
995    /// Return the production vertex layout, if the generated layout has one.
996    pub fn production_vertex(&self) -> Option<&GeneratedVertexLayout> {
997        self.vertices
998            .iter()
999            .find(|vertex| vertex.kind() == GeneratedVertexKind::Production)
1000    }
1001
1002    /// Return the generated decay daughters of a parent product ID.
1003    pub fn decay_products(&self, parent_product_id: usize) -> Vec<&GeneratedParticleLayout> {
1004        self.particles
1005            .iter()
1006            .filter(|particle| particle.parent_id() == Some(parent_product_id))
1007            .collect()
1008    }
1009
1010    /// Return production-level incoming particle layouts.
1011    pub fn production_incoming(&self) -> Vec<&GeneratedParticleLayout> {
1012        self.production_vertex_products(GeneratedVertexLayout::incoming_product_ids)
1013    }
1014
1015    /// Return production-level outgoing particle layouts.
1016    pub fn production_outgoing(&self) -> Vec<&GeneratedParticleLayout> {
1017        self.production_vertex_products(GeneratedVertexLayout::outgoing_product_ids)
1018    }
1019
1020    fn production_vertex_products(
1021        &self,
1022        ids: impl FnOnce(&GeneratedVertexLayout) -> &[usize],
1023    ) -> Vec<&GeneratedParticleLayout> {
1024        self.production_vertex()
1025            .map(|vertex| {
1026                ids(vertex)
1027                    .iter()
1028                    .filter_map(|product_id| self.product(*product_id))
1029                    .collect()
1030            })
1031            .unwrap_or_default()
1032    }
1033}
1034
1035/// A generated dataset batch plus the metadata needed to interpret it.
1036#[derive(Clone, Debug)]
1037pub struct GeneratedBatch {
1038    dataset: Dataset,
1039    reaction: GeneratedReaction,
1040    layout: GeneratedEventLayout,
1041}
1042
1043impl GeneratedBatch {
1044    /// Construct a generated batch.
1045    pub fn new(
1046        dataset: Dataset,
1047        reaction: GeneratedReaction,
1048        layout: GeneratedEventLayout,
1049    ) -> Self {
1050        Self {
1051            dataset,
1052            reaction,
1053            layout,
1054        }
1055    }
1056
1057    /// Borrow the generated dataset.
1058    pub fn dataset(&self) -> &Dataset {
1059        &self.dataset
1060    }
1061
1062    /// Consume this batch and return the generated dataset.
1063    pub fn into_dataset(self) -> Dataset {
1064        self.dataset
1065    }
1066
1067    /// Borrow the generated reaction metadata.
1068    pub fn reaction(&self) -> &GeneratedReaction {
1069        &self.reaction
1070    }
1071
1072    /// Borrow the generated event layout metadata.
1073    pub fn layout(&self) -> &GeneratedEventLayout {
1074        &self.layout
1075    }
1076}
1077
/// Event generator for generated reactions.
///
/// Combines a [`GeneratedReaction`], one distribution per auxiliary column, a p4 storage policy,
/// and a base RNG seed from which generation streams are created.
#[derive(Clone, Debug)]
pub struct EventGenerator {
    /// Reaction whose four-momenta are generated.
    reaction: GeneratedReaction,
    /// Distribution sampled for each auxiliary column, keyed by column label.
    aux_generators: HashMap<String, Distribution>,
    /// Policy selecting which generated p4 columns are written to datasets.
    storage: GeneratedStorage,
    /// Base seed used to create generation RNG streams.
    seed: u64,
}
1086
1087/// Finite iterator over generated dataset batches.
1088#[derive(Clone, Debug)]
1089pub struct GeneratedBatchIter {
1090    generator: EventGenerator,
1091    remaining_events: usize,
1092    batch_size: usize,
1093    rng: Rng,
1094}
1095
1096impl Iterator for GeneratedBatchIter {
1097    type Item = LadduResult<GeneratedBatch>;
1098
1099    fn next(&mut self) -> Option<Self::Item> {
1100        if self.remaining_events == 0 {
1101            return None;
1102        }
1103        let n_events = self.batch_size.min(self.remaining_events);
1104        self.remaining_events -= n_events;
1105        Some(
1106            self.generator
1107                .generate_batch_with_rng(n_events, &mut self.rng),
1108        )
1109    }
1110}
1111
1112/// Evaluates unnormalized intensities for generated batches.
1113pub trait BatchIntensity {
1114    /// Return one real-valued intensity for each event in `batch`.
1115    fn evaluate(&self, batch: &GeneratedBatch) -> LadduResult<Vec<f64>>;
1116
1117    /// Evaluate and validate one finite nonnegative intensity for each event in `batch`.
1118    fn evaluate_checked(&self, batch: &GeneratedBatch) -> LadduResult<Vec<f64>> {
1119        let intensities = self.evaluate(batch)?;
1120        if intensities.len() != batch.dataset().n_events() {
1121            return Err(LadduError::Custom(format!(
1122                "intensity length mismatch: expected {}, got {}",
1123                batch.dataset().n_events(),
1124                intensities.len()
1125            )));
1126        }
1127        for (index, weight) in intensities.iter().enumerate() {
1128            if !weight.is_finite() || *weight < 0.0 {
1129                return Err(LadduError::Custom(format!(
1130                    "intensity at event {index} must be finite and nonnegative, got {weight}"
1131                )));
1132            }
1133        }
1134        Ok(intensities)
1135    }
1136}
1137
1138impl<F> BatchIntensity for F
1139where
1140    F: Fn(&GeneratedBatch) -> LadduResult<Vec<f64>>,
1141{
1142    fn evaluate(&self, batch: &GeneratedBatch) -> LadduResult<Vec<f64>> {
1143        self(batch)
1144    }
1145}
1146
1147/// Evaluates a real-valued [`Expression`] as a generated-batch intensity.
1148#[derive(Clone, Debug)]
1149pub struct ExpressionIntensity {
1150    expression: Expression,
1151    parameters: Vec<f64>,
1152}
1153
1154impl ExpressionIntensity {
1155    /// Construct an expression-backed generated-batch intensity.
1156    pub fn new(expression: Expression, parameters: Vec<f64>) -> Self {
1157        Self {
1158            expression,
1159            parameters,
1160        }
1161    }
1162}
1163
1164impl BatchIntensity for ExpressionIntensity {
1165    fn evaluate(&self, batch: &GeneratedBatch) -> LadduResult<Vec<f64>> {
1166        let dataset = Arc::new(batch.dataset().clone());
1167        let evaluator = self.expression.load(&dataset)?;
1168        evaluator
1169            .evaluate(&self.parameters)?
1170            .into_iter()
1171            .enumerate()
1172            .map(|(index, value)| {
1173                if !value.im.is_finite() || value.im.abs() > f64::EPSILON {
1174                    return Err(LadduError::Custom(format!(
1175                        "expression intensity at event {index} must be real-valued, got {value}"
1176                    )));
1177                }
1178                Ok(value.re)
1179            })
1180            .collect()
1181    }
1182}
1183
/// Envelope strategy used by rejection sampling.
///
/// The envelope supplies the `max_weight` against which each event's intensity is compared.
/// An event is accepted with probability `weight / max_weight`. An event whose weight exceeds
/// `max_weight` is reported as an envelope violation. The resulting `max_weight` must be finite
/// and positive.
#[derive(Clone, Debug)]
pub enum RejectionEnvelope {
    /// Use a fixed maximum event weight.
    Fixed {
        /// Maximum event weight used as the rejection envelope.
        max_weight: f64,
    },
    /// Estimate the maximum event weight from a pilot sample and inflate it.
    ///
    /// Pilot events are generated from a fresh RNG stream seeded with the generator's own seed.
    Pilot {
        /// Number of pilot events used to estimate the envelope (must be greater than zero).
        pilot_events: usize,
        /// Optional raw generation batch size used for pilot evaluation; `None` falls back to
        /// the sampler's `generation_batch_size`.
        batch_size: Option<usize>,
        /// Multiplicative safety factor applied to the observed maximum (must be finite and
        /// positive).
        safety_factor: f64,
    },
}
1202
/// Options for rejection sampling generated events.
#[derive(Clone, Debug)]
pub struct RejectionSamplingOptions {
    /// Number of accepted events to produce.
    pub target_accepted: usize,
    /// Number of raw events to generate per source batch (must be greater than zero).
    pub generation_batch_size: usize,
    /// Maximum number of accepted events emitted per output batch (must be greater than zero);
    /// some batches, such as the final one, may hold fewer.
    pub output_batch_size: usize,
    /// Envelope used by the rejection sampler.
    pub envelope: RejectionEnvelope,
    /// Random seed used for accept/reject decisions. Raw event generation uses the generator's
    /// own seed instead.
    pub seed: u64,
}
1217
/// Rejection-sampling diagnostics accumulated while sampling.
#[derive(Clone, Debug, Default)]
pub struct RejectionSamplingDiagnostics {
    /// Number of generated events inspected.
    pub generated_events: usize,
    /// Number of events accepted.
    pub accepted_events: usize,
    /// Number of events rejected.
    pub rejected_events: usize,
    /// Largest event intensity seen so far.
    pub max_observed_weight: f64,
    /// Envelope maximum weight used for accept/reject decisions.
    pub envelope_max_weight: f64,
    /// Number of envelope violations (intensity above the envelope) observed.
    pub envelope_violations: usize,
}

impl RejectionSamplingDiagnostics {
    /// Ratio of accepted to generated events, or `0.0` when nothing has been generated.
    pub fn acceptance_efficiency(&self) -> f64 {
        match self.generated_events {
            0 => 0.0,
            generated => self.accepted_events as f64 / generated as f64,
        }
    }
}
1245
1246/// Rejection sampler over generated batches.
1247#[derive(Clone, Debug)]
1248pub struct RejectionSampler<I> {
1249    generator: EventGenerator,
1250    intensity: I,
1251    max_weight: f64,
1252    options: RejectionSamplingOptions,
1253}
1254
1255impl<I> RejectionSampler<I>
1256where
1257    I: BatchIntensity,
1258{
1259    /// Construct a rejection sampler.
1260    pub fn new(
1261        generator: EventGenerator,
1262        intensity: I,
1263        options: RejectionSamplingOptions,
1264    ) -> LadduResult<Self> {
1265        if options.generation_batch_size == 0 {
1266            return Err(LadduError::Custom(
1267                "generation_batch_size must be greater than zero".to_string(),
1268            ));
1269        }
1270        if options.output_batch_size == 0 {
1271            return Err(LadduError::Custom(
1272                "output_batch_size must be greater than zero".to_string(),
1273            ));
1274        }
1275        let max_weight = match options.envelope {
1276            RejectionEnvelope::Fixed { max_weight } => max_weight,
1277            RejectionEnvelope::Pilot {
1278                pilot_events,
1279                batch_size,
1280                safety_factor,
1281            } => estimate_rejection_envelope(
1282                &generator,
1283                &intensity,
1284                pilot_events,
1285                batch_size.unwrap_or(options.generation_batch_size),
1286                safety_factor,
1287            )?,
1288        };
1289        if !max_weight.is_finite() || max_weight <= 0.0 {
1290            return Err(LadduError::Custom(
1291                "rejection envelope max_weight must be finite and positive".to_string(),
1292            ));
1293        }
1294        Ok(Self {
1295            generator,
1296            intensity,
1297            max_weight,
1298            options,
1299        })
1300    }
1301
1302    /// Consume this sampler and return an iterator over accepted generated batches.
1303    pub fn accepted_batches(self) -> RejectionSampleIter<I> {
1304        RejectionSampleIter {
1305            generation_rng: Rng::with_seed(self.generator.seed),
1306            rejection_rng: Rng::with_seed(self.options.seed),
1307            diagnostics: RejectionSamplingDiagnostics {
1308                envelope_max_weight: self.max_weight,
1309                ..Default::default()
1310            },
1311            sampler: self,
1312            current_batch: None,
1313            current_intensities: Vec::new(),
1314            current_index: 0,
1315        }
1316    }
1317}
1318
1319fn estimate_rejection_envelope<I>(
1320    generator: &EventGenerator,
1321    intensity: &I,
1322    pilot_events: usize,
1323    batch_size: usize,
1324    safety_factor: f64,
1325) -> LadduResult<f64>
1326where
1327    I: BatchIntensity,
1328{
1329    if pilot_events == 0 {
1330        return Err(LadduError::Custom(
1331            "pilot_events must be greater than zero".to_string(),
1332        ));
1333    }
1334    if batch_size == 0 {
1335        return Err(LadduError::Custom(
1336            "pilot batch_size must be greater than zero".to_string(),
1337        ));
1338    }
1339    if !safety_factor.is_finite() || safety_factor <= 0.0 {
1340        return Err(LadduError::Custom(
1341            "pilot safety_factor must be finite and positive".to_string(),
1342        ));
1343    }
1344
1345    let mut max_observed = 0.0_f64;
1346    let mut rng = Rng::with_seed(generator.seed);
1347    let mut remaining = pilot_events;
1348    while remaining > 0 {
1349        let n_events = remaining.min(batch_size);
1350        let batch = generator.generate_batch_with_rng(n_events, &mut rng)?;
1351        let weights = intensity.evaluate_checked(&batch)?;
1352        for weight in weights {
1353            max_observed = max_observed.max(weight);
1354        }
1355        remaining -= n_events;
1356    }
1357
1358    let max_weight = max_observed * safety_factor;
1359    if !max_weight.is_finite() || max_weight <= 0.0 {
1360        return Err(LadduError::Custom(format!(
1361            "pilot envelope produced invalid max_weight {max_weight}; observed maximum was {max_observed}"
1362        )));
1363    }
1364    Ok(max_weight)
1365}
1366
1367/// Iterator over accepted generated batches.
1368#[derive(Clone, Debug)]
1369pub struct RejectionSampleIter<I> {
1370    sampler: RejectionSampler<I>,
1371    generation_rng: Rng,
1372    rejection_rng: Rng,
1373    diagnostics: RejectionSamplingDiagnostics,
1374    current_batch: Option<GeneratedBatch>,
1375    current_intensities: Vec<f64>,
1376    current_index: usize,
1377}
1378
1379impl<I> RejectionSampleIter<I> {
1380    /// Borrow rejection-sampling diagnostics accumulated so far.
1381    pub fn diagnostics(&self) -> &RejectionSamplingDiagnostics {
1382        &self.diagnostics
1383    }
1384}
1385
1386impl<I> RejectionSampleIter<I>
1387where
1388    I: BatchIntensity,
1389{
1390    fn load_next_source_batch(&mut self) -> LadduResult<()> {
1391        let batch = self.sampler.generator.generate_batch_with_rng(
1392            self.sampler.options.generation_batch_size,
1393            &mut self.generation_rng,
1394        )?;
1395        let intensities = self.sampler.intensity.evaluate_checked(&batch)?;
1396        self.diagnostics.generated_events += batch.dataset().n_events();
1397        self.current_batch = Some(batch);
1398        self.current_intensities = intensities;
1399        self.current_index = 0;
1400        Ok(())
1401    }
1402
1403    fn empty_output_batch(source: &GeneratedBatch) -> GeneratedBatch {
1404        GeneratedBatch::new(
1405            Dataset::empty_local(source.dataset().metadata().clone()),
1406            source.reaction().clone(),
1407            source.layout().clone(),
1408        )
1409    }
1410}
1411
1412impl<I> Iterator for RejectionSampleIter<I>
1413where
1414    I: BatchIntensity,
1415{
1416    type Item = LadduResult<GeneratedBatch>;
1417
1418    fn next(&mut self) -> Option<Self::Item> {
1419        if self.diagnostics.accepted_events >= self.sampler.options.target_accepted {
1420            return None;
1421        }
1422
1423        let mut output: Option<GeneratedBatch> = None;
1424        while self.diagnostics.accepted_events < self.sampler.options.target_accepted {
1425            let needs_batch = self
1426                .current_batch
1427                .as_ref()
1428                .map(|batch| self.current_index >= batch.dataset().n_events())
1429                .unwrap_or(true);
1430            if needs_batch {
1431                if let Err(err) = self.load_next_source_batch() {
1432                    return Some(Err(err));
1433                }
1434            }
1435
1436            let source = self
1437                .current_batch
1438                .as_ref()
1439                .expect("source batch should be loaded");
1440            if output.is_none() {
1441                output = Some(Self::empty_output_batch(source));
1442            }
1443
1444            let weight = self.current_intensities[self.current_index];
1445            self.diagnostics.max_observed_weight = self.diagnostics.max_observed_weight.max(weight);
1446            let envelope_max = self.sampler.max_weight;
1447            if weight > envelope_max {
1448                self.diagnostics.envelope_violations += 1;
1449                return Some(Err(LadduError::Custom(format!(
1450                    "rejection envelope violation: observed weight {weight} exceeds max_weight {envelope_max}"
1451                ))));
1452            }
1453
1454            let accepted = self.rejection_rng.f64() * envelope_max < weight;
1455            if accepted {
1456                let event = match source.dataset().event_global(self.current_index) {
1457                    Ok(event) => event,
1458                    Err(err) => return Some(Err(err)),
1459                };
1460                if let Err(err) = output.as_mut().unwrap().dataset.push_event_local(
1461                    event.p4s.clone(),
1462                    event.aux.clone(),
1463                    event.weight,
1464                ) {
1465                    return Some(Err(err));
1466                }
1467                self.diagnostics.accepted_events += 1;
1468            } else {
1469                self.diagnostics.rejected_events += 1;
1470            }
1471            self.current_index += 1;
1472
1473            if output.as_ref().unwrap().dataset().n_events()
1474                >= self.sampler.options.output_batch_size
1475                || self.diagnostics.accepted_events >= self.sampler.options.target_accepted
1476            {
1477                break;
1478            }
1479        }
1480
1481        output
1482            .filter(|batch| batch.dataset().n_events() > 0)
1483            .map(Ok)
1484    }
1485}
1486
impl EventGenerator {
    /// Construct an event generator.
    ///
    /// The storage policy starts as [`GeneratedStorage::All`]. When `seed` is `None`, a seed is
    /// drawn from `fastrand`'s global generator, so output is not reproducible across runs.
    pub fn new(
        reaction: GeneratedReaction,
        aux_generators: HashMap<String, Distribution>,
        seed: Option<u64>,
    ) -> Self {
        Self {
            reaction,
            aux_generators,
            storage: GeneratedStorage::All,
            seed: seed.unwrap_or_else(|| fastrand::u64(..)),
        }
    }

    /// Return the generated p4 storage policy.
    pub fn storage(&self) -> &GeneratedStorage {
        &self.storage
    }

    /// Return a copy of this generator with a generated p4 storage policy.
    ///
    /// # Errors
    ///
    /// Returns an error if `storage` fails validation against the reaction's p4 labels.
    pub fn with_storage(mut self, storage: GeneratedStorage) -> LadduResult<Self> {
        storage.validate(&self.reaction.p4_labels())?;
        self.storage = storage;
        Ok(self)
    }

    /// Auxiliary generators sorted by label.
    ///
    /// `HashMap` iteration order is unspecified; sorting fixes both the auxiliary column order
    /// and the order in which distributions consume the RNG.
    fn aux_entries(&self) -> Vec<(&String, &Distribution)> {
        let mut aux_entries = self.aux_generators.iter().collect::<Vec<_>>();
        aux_entries.sort_by_key(|(label, _)| *label);
        aux_entries
    }

    /// Generate `n_events` events from `rng` into a batch with unit event weights.
    ///
    /// For every event, each auxiliary distribution is sampled in label order, then the reaction
    /// generates that event's four-momenta. This per-event draw order determines the RNG stream,
    /// so it must not change if batches are to stay reproducible. Only p4 columns selected by the
    /// storage policy are allocated and written.
    fn generate_batch_with_rng(
        &self,
        n_events: usize,
        rng: &mut Rng,
    ) -> LadduResult<GeneratedBatch> {
        let all_p4_labels = self.reaction.p4_labels();
        // NOTE(review): storage is already validated in `with_storage`; this presumably guards
        // against a policy that no longer matches the reaction — confirm before removing.
        self.storage.validate(&all_p4_labels)?;
        let p4_labels = self.storage.stored_labels(&all_p4_labels);
        let aux_entries = self.aux_entries();
        let aux_labels = aux_entries
            .iter()
            .map(|(label, _)| (*label).clone())
            .collect::<Vec<_>>();
        let mut p4_data: HashMap<String, Vec<Vec4>> = p4_labels
            .iter()
            .map(|label| (label.clone(), Vec::with_capacity(n_events)))
            .collect();
        let metadata = DatasetMetadata::new(p4_labels.clone(), aux_labels.clone())?;
        let mut aux: Vec<Vec<f64>> = aux_entries
            .iter()
            .map(|_| Vec::with_capacity(n_events))
            .collect();
        let weights = vec![1.0; n_events];
        for _ in 0..n_events {
            for ((_, distribution), column) in aux_entries.iter().zip(aux.iter_mut()) {
                column.push(distribution.sample(rng));
            }
            // Presumably appends one event's p4s into the stored-label columns — TODO confirm
            // how `generate` treats labels absent from `p4_data`.
            self.reaction.generate(rng, &mut p4_data, 1);
        }
        // Emit p4 columns in stored-label order; a label with no column would be skipped here.
        let p4 = p4_labels
            .iter()
            .filter_map(|label| p4_data.remove(label))
            .collect();
        let dataset = Dataset::from_columns_local(metadata, p4, aux, weights)?;
        Ok(GeneratedBatch::new(
            dataset,
            self.reaction.clone(),
            GeneratedEventLayout::new(
                p4_labels,
                aux_labels,
                self.reaction.particle_layouts_with_storage(&self.storage),
                self.reaction.vertex_layouts(),
            ),
        ))
    }

    /// Generate one dataset batch with generated layout metadata.
    ///
    /// Each call starts a fresh RNG stream from this generator's seed, so repeated calls with the
    /// same `n_events` return the same events.
    pub fn generate_batch(&self, n_events: usize) -> LadduResult<GeneratedBatch> {
        let mut rng = Rng::with_seed(self.seed);
        self.generate_batch_with_rng(n_events, &mut rng)
    }

    /// Generate a finite iterator over batches.
    ///
    /// The iterator advances one RNG stream, so concatenating all yielded batches is
    /// deterministic and matches [`EventGenerator::generate_dataset`] for the same total count.
    ///
    /// # Errors
    ///
    /// Returns an error if `batch_size` is zero.
    pub fn generate_batches(
        &self,
        total_events: usize,
        batch_size: usize,
    ) -> LadduResult<GeneratedBatchIter> {
        if batch_size == 0 {
            return Err(LadduError::Custom(
                "batch_size must be greater than zero".to_string(),
            ));
        }
        Ok(GeneratedBatchIter {
            generator: self.clone(),
            remaining_events: total_events,
            batch_size,
            rng: Rng::with_seed(self.seed),
        })
    }

    /// Generate a dataset.
    ///
    /// Equivalent to [`EventGenerator::generate_batch`] with the layout metadata discarded.
    pub fn generate_dataset(&self, n_events: usize) -> LadduResult<Dataset> {
        Ok(self.generate_batch(n_events)?.into_dataset())
    }

    /// Construct a rejection sampler using a real-valued expression intensity.
    ///
    /// The sampler owns a clone of this generator.
    pub fn rejection_sampler_with_expression(
        &self,
        expression: Expression,
        parameters: Vec<f64>,
        options: RejectionSamplingOptions,
    ) -> LadduResult<RejectionSampler<ExpressionIntensity>> {
        RejectionSampler::new(
            self.clone(),
            ExpressionIntensity::new(expression, parameters),
            options,
        )
    }
}
1613
1614#[cfg(test)]
1615mod tests {
1616    use approx::assert_relative_eq;
1617    use laddu_core::{traits::Variable, Channel, Expression, Frame};
1618
1619    use super::*;
1620
1621    fn demo_reaction() -> GeneratedReaction {
1622        let beam = GeneratedParticle::initial(
1623            "beam",
1624            InitialGenerator::beam_with_fixed_energy(0.0, 8.0),
1625            Reconstruction::Stored,
1626        );
1627        let target = GeneratedParticle::initial(
1628            "target",
1629            InitialGenerator::target(0.938272),
1630            Reconstruction::Missing,
1631        );
1632        let ks1 = GeneratedParticle::stable(
1633            "kshort1",
1634            StableGenerator::new(0.497611),
1635            Reconstruction::Stored,
1636        );
1637        let ks2 = GeneratedParticle::stable(
1638            "kshort2",
1639            StableGenerator::new(0.497611),
1640            Reconstruction::Stored,
1641        );
1642        let kk = GeneratedParticle::composite(
1643            "kk",
1644            CompositeGenerator::new(1.1, 1.6),
1645            (&ks1, &ks2),
1646            Reconstruction::Composite,
1647        );
1648        let recoil = GeneratedParticle::stable(
1649            "recoil",
1650            StableGenerator::new(0.938272),
1651            Reconstruction::Stored,
1652        );
1653        let tdist = MandelstamTDistribution::Exponential { slope: 0.1 };
1654        GeneratedReaction::two_to_two(beam, target, kk, recoil, tdist).unwrap()
1655    }
1656
1657    #[test]
1658    fn test_generation() {
1659        let reaction = demo_reaction();
1660        let generator = EventGenerator::new(reaction, HashMap::new(), Some(12345));
1661        let n_events = 1_000;
1662        let dataset = generator.generate_dataset(n_events).unwrap();
1663        assert_eq!(dataset.n_events(), n_events);
1664        let metadata = dataset.metadata();
1665        assert!(metadata.p4_index("beam").is_some());
1666        assert!(metadata.p4_index("target").is_some());
1667        assert!(metadata.p4_index("kk").is_some());
1668        assert!(metadata.p4_index("kshort1").is_some());
1669        assert!(metadata.p4_index("kshort2").is_some());
1670        assert!(metadata.p4_index("recoil").is_some());
1671
1672        for event in dataset.events_global() {
1673            let beam_p4 = event.p4("beam").unwrap();
1674            let target_p4 = event.p4("target").unwrap();
1675            let kk_p4 = event.p4("kk").unwrap();
1676            let kshort1_p4 = event.p4("kshort1").unwrap();
1677            let kshort2_p4 = event.p4("kshort2").unwrap();
1678            let recoil_p4 = event.p4("recoil").unwrap();
1679
1680            assert!(beam_p4.e().is_finite());
1681            assert!(target_p4.e().is_finite());
1682            assert!(kk_p4.e().is_finite());
1683            assert!(kshort1_p4.e().is_finite());
1684            assert!(kshort2_p4.e().is_finite());
1685            assert!(recoil_p4.e().is_finite());
1686
1687            assert_relative_eq!(kk_p4, kshort1_p4 + kshort2_p4, epsilon = 1e-10);
1688            assert_relative_eq!(beam_p4 + target_p4, kk_p4 + recoil_p4, epsilon = 1e-10);
1689            assert_relative_eq!(kshort1_p4.m(), 0.497611, epsilon = 1e-10);
1690            assert_relative_eq!(kshort2_p4.m(), 0.497611, epsilon = 1e-10);
1691            assert_relative_eq!(recoil_p4.m(), 0.938272, epsilon = 1e-10);
1692        }
1693    }
1694
1695    #[test]
1696    fn test_reconstructed_reaction() {
1697        let generated = demo_reaction();
1698        let reaction = generated.reconstructed_reaction().unwrap();
1699        let dataset = EventGenerator::new(generated, HashMap::new(), Some(12345))
1700            .generate_dataset(4)
1701            .unwrap();
1702        let mass = reaction.mass("kk").value_on(&dataset).unwrap();
1703        let angles = reaction
1704            .decay("kk")
1705            .unwrap()
1706            .angles("kshort1", Frame::Helicity)
1707            .unwrap();
1708        let mandelstam = reaction
1709            .mandelstam(Channel::S)
1710            .unwrap()
1711            .value_on(&dataset)
1712            .unwrap();
1713
1714        assert_eq!(mass.len(), 4);
1715        assert_eq!(
1716            angles.costheta.to_string(),
1717            "CosTheta(parent=kk, daughter=kshort1, frame=Helicity)"
1718        );
1719        assert_eq!(mandelstam.len(), 4);
1720    }
1721
1722    #[test]
1723    fn test_generated_batch_metadata() {
1724        let generated = demo_reaction();
1725        let generator = EventGenerator::new(
1726            generated,
1727            HashMap::from([("pol_angle".to_string(), Distribution::Fixed(0.25))]),
1728            Some(12345),
1729        );
1730        let batch = generator.generate_batch(4).unwrap();
1731
1732        assert_eq!(batch.dataset().n_events(), 4);
1733        assert_eq!(
1734            batch.layout().p4_labels(),
1735            &["beam", "target", "kk", "kshort1", "kshort2", "recoil"]
1736        );
1737        assert_eq!(batch.layout().aux_labels(), &["pol_angle"]);
1738        assert_eq!(
1739            batch.reaction().p4_labels(),
1740            vec!["beam", "target", "kk", "kshort1", "kshort2", "recoil"]
1741        );
1742        assert_eq!(batch.dataset().p4_names(), batch.layout().p4_labels());
1743        assert_eq!(batch.dataset().aux_names(), batch.layout().aux_labels());
1744        let particles = batch.layout().particles();
1745        assert_eq!(particles.len(), 6);
1746        assert_eq!(particles[0].id(), "beam");
1747        assert_eq!(particles[0].product_id(), 0);
1748        assert_eq!(particles[0].parent_id(), None);
1749        assert_eq!(particles[0].produced_vertex_id(), None);
1750        assert_eq!(particles[0].decay_vertex_id(), None);
1751        assert_eq!(particles[1].id(), "target");
1752        assert_eq!(particles[1].parent_id(), None);
1753        assert_eq!(particles[1].produced_vertex_id(), None);
1754        assert_eq!(particles[1].decay_vertex_id(), None);
1755        assert_eq!(particles[2].id(), "kk");
1756        assert_eq!(particles[2].product_id(), 2);
1757        assert_eq!(particles[2].parent_id(), None);
1758        assert_eq!(particles[2].produced_vertex_id(), Some(0));
1759        assert_eq!(particles[2].decay_vertex_id(), Some(1));
1760        assert_eq!(particles[3].id(), "kshort1");
1761        assert_eq!(particles[3].parent_id(), Some(2));
1762        assert_eq!(particles[3].produced_vertex_id(), Some(1));
1763        assert_eq!(particles[3].decay_vertex_id(), None);
1764        assert_eq!(particles[4].id(), "kshort2");
1765        assert_eq!(particles[4].parent_id(), Some(2));
1766        assert_eq!(particles[4].produced_vertex_id(), Some(1));
1767        assert_eq!(particles[4].decay_vertex_id(), None);
1768        assert_eq!(particles[5].id(), "recoil");
1769        assert_eq!(particles[5].parent_id(), None);
1770        assert_eq!(particles[5].produced_vertex_id(), Some(0));
1771        assert_eq!(particles[5].decay_vertex_id(), None);
1772        for particle in particles {
1773            assert_eq!(particle.p4_label(), Some(particle.id()));
1774        }
1775        let vertices = batch.layout().vertices();
1776        assert_eq!(vertices.len(), 2);
1777        assert_eq!(vertices[0].vertex_id(), 0);
1778        assert_eq!(vertices[0].kind(), GeneratedVertexKind::Production);
1779        assert_eq!(vertices[0].incoming_product_ids(), &[0, 1]);
1780        assert_eq!(vertices[0].outgoing_product_ids(), &[2, 5]);
1781        assert_eq!(vertices[1].vertex_id(), 1);
1782        assert_eq!(vertices[1].kind(), GeneratedVertexKind::Decay);
1783        assert_eq!(vertices[1].incoming_product_ids(), &[2]);
1784        assert_eq!(vertices[1].outgoing_product_ids(), &[3, 4]);
1785
1786        assert_eq!(batch.layout().particle("kk"), Some(&particles[2]));
1787        assert_eq!(batch.layout().particle("missing_id"), None);
1788        assert_eq!(batch.layout().product(5), Some(&particles[5]));
1789        assert_eq!(batch.layout().product(6), None);
1790        assert_eq!(batch.layout().vertex(1), Some(&vertices[1]));
1791        assert_eq!(batch.layout().vertex(2), None);
1792        assert_eq!(batch.layout().production_vertex(), Some(&vertices[0]));
1793        assert_eq!(
1794            batch
1795                .layout()
1796                .production_incoming()
1797                .iter()
1798                .map(|particle| particle.id())
1799                .collect::<Vec<_>>(),
1800            vec!["beam", "target"]
1801        );
1802        assert_eq!(
1803            batch
1804                .layout()
1805                .production_outgoing()
1806                .iter()
1807                .map(|particle| particle.id())
1808                .collect::<Vec<_>>(),
1809            vec!["kk", "recoil"]
1810        );
1811        assert_eq!(
1812            batch
1813                .layout()
1814                .decay_products(2)
1815                .iter()
1816                .map(|particle| particle.id())
1817                .collect::<Vec<_>>(),
1818            vec!["kshort1", "kshort2"]
1819        );
1820        assert!(batch.layout().decay_products(5).is_empty());
1821    }
1822
1823    #[test]
1824    fn generated_storage_only_projects_dataset_columns() {
1825        let generated = demo_reaction();
1826        let generator = EventGenerator::new(generated, HashMap::new(), Some(12345))
1827            .with_storage(GeneratedStorage::only([
1828                "beam", "target", "kshort1", "kshort2", "recoil",
1829            ]))
1830            .unwrap();
1831        let batch = generator.generate_batch(4).unwrap();
1832
1833        assert_eq!(
1834            batch.reaction().p4_labels(),
1835            vec!["beam", "target", "kk", "kshort1", "kshort2", "recoil"]
1836        );
1837        assert_eq!(
1838            batch.layout().p4_labels(),
1839            &["beam", "target", "kshort1", "kshort2", "recoil"]
1840        );
1841        assert_eq!(batch.dataset().p4_names(), batch.layout().p4_labels());
1842        assert!(batch.dataset().metadata().p4_index("kk").is_none());
1843
1844        let particles = batch.layout().particles();
1845        assert_eq!(particles.len(), 6);
1846        assert_eq!(particles[2].id(), "kk");
1847        assert_eq!(particles[2].p4_label(), None);
1848        assert_eq!(particles[3].p4_label(), Some("kshort1"));
1849        assert_eq!(particles[4].p4_label(), Some("kshort2"));
1850
1851        for index in 0..batch.dataset().n_events() {
1852            let event = batch.dataset().event_global(index).unwrap();
1853            assert_relative_eq!(
1854                event.p4("beam").unwrap() + event.p4("target").unwrap(),
1855                event.p4("kshort1").unwrap()
1856                    + event.p4("kshort2").unwrap()
1857                    + event.p4("recoil").unwrap(),
1858                epsilon = 1e-10
1859            );
1860        }
1861    }
1862
1863    #[test]
1864    fn generated_storage_rejects_unknown_and_duplicate_ids() {
1865        assert!(
1866            EventGenerator::new(demo_reaction(), HashMap::new(), Some(12345))
1867                .with_storage(GeneratedStorage::only(["beam", "does_not_exist"]))
1868                .is_err()
1869        );
1870        assert!(
1871            EventGenerator::new(demo_reaction(), HashMap::new(), Some(12345))
1872                .with_storage(GeneratedStorage::only(["beam", "beam"]))
1873                .is_err()
1874        );
1875    }
1876
1877    #[test]
1878    fn generated_species_metadata_propagates_to_layout() {
1879        let beam = GeneratedParticle::initial(
1880            "beam",
1881            InitialGenerator::beam_with_fixed_energy(0.0, 8.0),
1882            Reconstruction::Stored,
1883        )
1884        .with_species(ParticleSpecies::code(22));
1885        let target = GeneratedParticle::initial(
1886            "target",
1887            InitialGenerator::target(0.938272),
1888            Reconstruction::Missing,
1889        )
1890        .with_species(ParticleSpecies::with_namespace("pdg", 2212));
1891        let kshort1 = GeneratedParticle::stable(
1892            "kshort1",
1893            StableGenerator::new(0.497611),
1894            Reconstruction::Stored,
1895        )
1896        .with_species(ParticleSpecies::label("KShort"));
1897        let kshort2 = GeneratedParticle::stable(
1898            "kshort2",
1899            StableGenerator::new(0.497611),
1900            Reconstruction::Stored,
1901        )
1902        .with_species(ParticleSpecies::label("KShort"));
1903        let kk = GeneratedParticle::composite(
1904            "kk",
1905            CompositeGenerator::new(1.1, 1.6),
1906            (&kshort1, &kshort2),
1907            Reconstruction::Composite,
1908        )
1909        .with_species(ParticleSpecies::label("KK"));
1910        let recoil = GeneratedParticle::stable(
1911            "recoil",
1912            StableGenerator::new(0.938272),
1913            Reconstruction::Stored,
1914        )
1915        .with_species(ParticleSpecies::code(2212));
1916        let reaction = GeneratedReaction::two_to_two(
1917            beam,
1918            target,
1919            kk,
1920            recoil,
1921            MandelstamTDistribution::Exponential { slope: 0.1 },
1922        )
1923        .unwrap();
1924        let particles = reaction.particle_layouts();
1925
1926        assert_eq!(particles[0].species(), Some(&ParticleSpecies::code(22)));
1927        assert_eq!(
1928            particles[1].species(),
1929            Some(&ParticleSpecies::with_namespace("pdg", 2212))
1930        );
1931        assert_eq!(particles[2].species(), Some(&ParticleSpecies::label("KK")));
1932        assert_eq!(
1933            particles[3].species(),
1934            Some(&ParticleSpecies::label("KShort"))
1935        );
1936        assert_eq!(
1937            particles[4].species(),
1938            Some(&ParticleSpecies::label("KShort"))
1939        );
1940        assert_eq!(particles[5].species(), Some(&ParticleSpecies::code(2212)));
1941    }
1942
1943    #[test]
1944    fn generated_batches_match_one_shot_generation() {
1945        let generated = demo_reaction();
1946        let generator = EventGenerator::new(
1947            generated,
1948            HashMap::from([(
1949                "pol_angle".to_string(),
1950                Distribution::Uniform { min: 0.0, max: 1.0 },
1951            )]),
1952            Some(12345),
1953        );
1954        let one_shot = generator.generate_dataset(7).unwrap();
1955        let batches = generator
1956            .generate_batches(7, 3)
1957            .unwrap()
1958            .collect::<LadduResult<Vec<_>>>()
1959            .unwrap();
1960        let batch_sizes = batches
1961            .iter()
1962            .map(|batch| batch.dataset().n_events())
1963            .collect::<Vec<_>>();
1964        assert_eq!(batch_sizes, vec![3, 3, 1]);
1965
1966        let mut offset = 0;
1967        for batch in batches {
1968            for local_index in 0..batch.dataset().n_events() {
1969                let expected = one_shot.event_global(offset + local_index).unwrap();
1970                let actual = batch.dataset().event_global(local_index).unwrap();
1971                for name in one_shot.p4_names() {
1972                    assert_relative_eq!(
1973                        actual.p4(name).unwrap(),
1974                        expected.p4(name).unwrap(),
1975                        epsilon = 1e-10
1976                    );
1977                }
1978                for aux_index in 0..one_shot.aux_names().len() {
1979                    assert_relative_eq!(actual.aux[aux_index], expected.aux[aux_index]);
1980                }
1981                assert_relative_eq!(actual.weight(), expected.weight());
1982            }
1983            offset += batch.dataset().n_events();
1984        }
1985        assert_eq!(offset, one_shot.n_events());
1986        assert!(generator.generate_batches(1, 0).is_err());
1987    }
1988
1989    #[test]
1990    fn fixed_envelope_rejection_sampler_streams_accepted_batches() {
1991        let generator = EventGenerator::new(demo_reaction(), HashMap::new(), Some(12345));
1992        let sampler = RejectionSampler::new(
1993            generator,
1994            |batch: &GeneratedBatch| Ok(vec![1.0; batch.dataset().n_events()]),
1995            RejectionSamplingOptions {
1996                target_accepted: 5,
1997                generation_batch_size: 4,
1998                output_batch_size: 2,
1999                envelope: RejectionEnvelope::Fixed { max_weight: 1.0 },
2000                seed: 67890,
2001            },
2002        )
2003        .unwrap();
2004
2005        let mut iter = sampler.accepted_batches();
2006        let mut accepted_batches = Vec::new();
2007        for batch in iter.by_ref() {
2008            accepted_batches.push(batch.unwrap());
2009        }
2010        assert_eq!(
2011            accepted_batches
2012                .iter()
2013                .map(|batch| batch.dataset().n_events())
2014                .collect::<Vec<_>>(),
2015            vec![2, 2, 1]
2016        );
2017        assert_eq!(iter.diagnostics().generated_events, 8);
2018        assert_eq!(iter.diagnostics().accepted_events, 5);
2019        assert_eq!(iter.diagnostics().rejected_events, 0);
2020        assert_relative_eq!(iter.diagnostics().acceptance_efficiency(), 5.0 / 8.0);
2021        for batch in accepted_batches {
2022            assert_eq!(
2023                batch.layout().p4_labels(),
2024                &["beam", "target", "kk", "kshort1", "kshort2", "recoil"]
2025            );
2026        }
2027    }
2028
2029    #[test]
2030    fn fixed_envelope_rejection_sampler_rejects_violations() {
2031        let generator = EventGenerator::new(demo_reaction(), HashMap::new(), Some(12345));
2032        let sampler = RejectionSampler::new(
2033            generator,
2034            |batch: &GeneratedBatch| Ok(vec![2.0; batch.dataset().n_events()]),
2035            RejectionSamplingOptions {
2036                target_accepted: 1,
2037                generation_batch_size: 1,
2038                output_batch_size: 1,
2039                envelope: RejectionEnvelope::Fixed { max_weight: 1.0 },
2040                seed: 67890,
2041            },
2042        )
2043        .unwrap();
2044
2045        let mut iter = sampler.accepted_batches();
2046        let err = iter.next().expect("sampler should produce an error");
2047        assert!(err.is_err());
2048        assert_eq!(iter.diagnostics().envelope_violations, 1);
2049        assert_relative_eq!(iter.diagnostics().max_observed_weight, 2.0);
2050    }
2051
2052    #[test]
2053    fn rejection_sampler_validates_custom_batch_intensities() {
2054        let generator = EventGenerator::new(demo_reaction(), HashMap::new(), Some(12345));
2055        let sampler = RejectionSampler::new(
2056            generator,
2057            |batch: &GeneratedBatch| Ok(vec![f64::NAN; batch.dataset().n_events()]),
2058            RejectionSamplingOptions {
2059                target_accepted: 1,
2060                generation_batch_size: 1,
2061                output_batch_size: 1,
2062                envelope: RejectionEnvelope::Fixed { max_weight: 1.0 },
2063                seed: 67890,
2064            },
2065        )
2066        .unwrap();
2067
2068        let err = sampler
2069            .accepted_batches()
2070            .next()
2071            .expect("sampler should produce an error");
2072        assert!(err.is_err());
2073    }
2074
2075    #[test]
2076    fn expression_rejection_sampler_streams_exact_target_count() {
2077        let generator = EventGenerator::new(demo_reaction(), HashMap::new(), Some(12345));
2078        let sampler = generator
2079            .rejection_sampler_with_expression(
2080                Expression::from(1.0),
2081                vec![],
2082                RejectionSamplingOptions {
2083                    target_accepted: 5,
2084                    generation_batch_size: 4,
2085                    output_batch_size: 2,
2086                    envelope: RejectionEnvelope::Fixed { max_weight: 1.0 },
2087                    seed: 67890,
2088                },
2089            )
2090            .unwrap();
2091
2092        let batches = sampler
2093            .accepted_batches()
2094            .collect::<LadduResult<Vec<_>>>()
2095            .unwrap();
2096        assert_eq!(
2097            batches
2098                .iter()
2099                .map(|batch| batch.dataset().n_events())
2100                .collect::<Vec<_>>(),
2101            vec![2, 2, 1]
2102        );
2103        assert_eq!(
2104            batches
2105                .iter()
2106                .map(|batch| batch.dataset().n_events())
2107                .sum::<usize>(),
2108            5
2109        );
2110    }
2111
2112    #[test]
2113    fn expression_rejection_sampler_supports_pilot_envelope() {
2114        let generator = EventGenerator::new(demo_reaction(), HashMap::new(), Some(12345));
2115        let sampler = generator
2116            .rejection_sampler_with_expression(
2117                Expression::from(1.0),
2118                vec![],
2119                RejectionSamplingOptions {
2120                    target_accepted: 3,
2121                    generation_batch_size: 2,
2122                    output_batch_size: 2,
2123                    envelope: RejectionEnvelope::Pilot {
2124                        pilot_events: 4,
2125                        batch_size: Some(2),
2126                        safety_factor: 1.25,
2127                    },
2128                    seed: 67890,
2129                },
2130            )
2131            .unwrap();
2132        let mut iter = sampler.accepted_batches();
2133        let batches = iter.by_ref().collect::<LadduResult<Vec<_>>>().unwrap();
2134        assert_eq!(
2135            batches
2136                .iter()
2137                .map(|batch| batch.dataset().n_events())
2138                .sum::<usize>(),
2139            3
2140        );
2141        assert_relative_eq!(iter.diagnostics().envelope_max_weight, 1.25);
2142    }
2143
2144    #[test]
2145    fn expression_rejection_sampler_rejects_invalid_intensities() {
2146        let generator = EventGenerator::new(demo_reaction(), HashMap::new(), Some(12345));
2147        let err = generator
2148            .rejection_sampler_with_expression(
2149                Expression::from(-1.0),
2150                vec![],
2151                RejectionSamplingOptions {
2152                    target_accepted: 1,
2153                    generation_batch_size: 1,
2154                    output_batch_size: 1,
2155                    envelope: RejectionEnvelope::Fixed { max_weight: 1.0 },
2156                    seed: 67890,
2157                },
2158            )
2159            .unwrap()
2160            .accepted_batches()
2161            .next()
2162            .expect("sampler should emit an error");
2163        assert!(err.is_err());
2164    }
2165
2166    #[test]
2167    fn duplicate_generated_particle_ids_are_rejected() {
2168        let beam = GeneratedParticle::initial(
2169            "beam",
2170            InitialGenerator::beam_with_fixed_energy(0.0, 8.0),
2171            Reconstruction::Stored,
2172        );
2173        let target = GeneratedParticle::initial(
2174            "target",
2175            InitialGenerator::target(0.938272),
2176            Reconstruction::Missing,
2177        );
2178        let duplicate = GeneratedParticle::stable(
2179            "beam",
2180            StableGenerator::new(0.497611),
2181            Reconstruction::Stored,
2182        );
2183        let recoil = GeneratedParticle::stable(
2184            "recoil",
2185            StableGenerator::new(0.938272),
2186            Reconstruction::Stored,
2187        );
2188
2189        assert!(GeneratedReaction::two_to_two(
2190            beam,
2191            target,
2192            duplicate,
2193            recoil,
2194            MandelstamTDistribution::Exponential { slope: 0.1 },
2195        )
2196        .is_err());
2197    }
2198}