Skip to main content

laddu_generation/
topology.rs

1use std::collections::{HashMap, HashSet};
2
3use fastrand::Rng;
4use laddu_core::{
5    math::{q_m, Histogram, Sheet},
6    Dataset, DatasetMetadata, LadduError, LadduResult, Particle, Reaction, Vec3, Vec4, PI,
7};
8use serde::{Deserialize, Serialize};
9
10use crate::distributions::{
11    Distribution, HistogramSampler, LadduGenRngExt, MandelstamTDistribution, SimpleDistribution,
12};
13
14/// Selects which generated particle four-momenta are written into generated datasets.
15///
16/// The generated reaction layout always retains the full generated graph. This policy only controls
17/// which generated particle IDs become p4 columns in generated [`Dataset`] values and which
18/// particles have a p4 label in [`GeneratedEventLayout`].
19#[derive(Clone, Debug, PartialEq, Eq)]
20pub enum GeneratedStorage {
21    /// Store every generated particle p4.
22    All,
23    /// Store only the listed generated particle IDs, preserving reaction p4-label order.
24    Only(Vec<String>),
25}
26
27impl GeneratedStorage {
28    /// Store every generated particle p4.
29    pub fn all() -> Self {
30        Self::All
31    }
32
33    /// Store only the listed generated particle IDs.
34    pub fn only<I, S>(ids: I) -> Self
35    where
36        I: IntoIterator<Item = S>,
37        S: Into<String>,
38    {
39        Self::Only(ids.into_iter().map(Into::into).collect())
40    }
41
42    /// Return true if `id` is selected for dataset storage.
43    pub fn stores(&self, id: &str) -> bool {
44        match self {
45            Self::All => true,
46            Self::Only(ids) => ids.iter().any(|stored_id| stored_id == id),
47        }
48    }
49
50    fn validate(&self, available_ids: &[String]) -> LadduResult<()> {
51        let available = available_ids
52            .iter()
53            .map(String::as_str)
54            .collect::<HashSet<_>>();
55        let Self::Only(ids) = self else {
56            return Ok(());
57        };
58        let mut seen = HashSet::new();
59        for id in ids {
60            if !seen.insert(id.as_str()) {
61                return Err(LadduError::Custom(format!(
62                    "generated storage contains duplicate particle ID '{id}'"
63                )));
64            }
65            if !available.contains(id.as_str()) {
66                return Err(LadduError::Custom(format!(
67                    "generated storage references unknown particle ID '{id}'"
68                )));
69            }
70        }
71        Ok(())
72    }
73
74    fn stored_labels(&self, all_labels: &[String]) -> Vec<String> {
75        all_labels
76            .iter()
77            .filter(|label| self.stores(label))
78            .cloned()
79            .collect()
80    }
81}
82
83/// Experiment-neutral metadata describing a generated particle species.
84///
85/// Species metadata is intentionally separate from generated particle IDs and reconstructed
86/// reaction particles. It is meant for generator/export layers that need an external particle code
87/// or label without forcing laddu to adopt an experiment-specific particle table.
88#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
89pub enum ParticleSpecies {
90    /// A numeric species code with an optional namespace.
91    Code {
92        /// Numeric species identifier.
93        id: i64,
94        /// Optional namespace, such as `"pdg"`.
95        namespace: Option<String>,
96    },
97    /// A free-form species label.
98    Label(String),
99}
100
101impl ParticleSpecies {
102    /// Construct a species from a numeric code with no namespace.
103    pub fn code(id: i64) -> Self {
104        Self::Code {
105            id,
106            namespace: None,
107        }
108    }
109
110    /// Construct a species from a numeric code in an explicit namespace.
111    pub fn with_namespace(namespace: impl Into<String>, id: i64) -> Self {
112        Self::Code {
113            id,
114            namespace: Some(namespace.into()),
115        }
116    }
117
118    /// Construct a species from a free-form label.
119    pub fn label(label: impl Into<String>) -> Self {
120        Self::Label(label.into())
121    }
122}
123
124fn basis(z: Vec3) -> (Vec3, Vec3, Vec3) {
125    let z = z.unit();
126    let ref_axis = if z.z.abs() < 0.9 {
127        Vec3::z()
128    } else {
129        Vec3::y()
130    };
131    let x = ref_axis.cross(&z).unit();
132    let y = z.cross(&x);
133    (x, y, z)
134}
135
136/// Generator settings for an initial-state particle.
137#[derive(Clone, Debug)]
138pub struct InitialGenerator {
139    mass: f64,
140    energy_distribution: SimpleDistribution,
141}
142
143impl InitialGenerator {
144    /// Construct a beam with fixed energy.
145    pub fn beam_with_fixed_energy(mass: f64, energy: f64) -> Self {
146        debug_assert!(mass >= 0.0, "Mass cannot be negative!\nMass: {}", mass);
147        debug_assert!(energy > 0.0, "Energy must be positive!\nEnergy: {}", energy);
148        Self {
149            mass,
150            energy_distribution: SimpleDistribution::Fixed(energy),
151        }
152    }
153
154    /// Construct a beam with uniformly sampled energy.
155    pub fn beam(mass: f64, min_energy: f64, max_energy: f64) -> Self {
156        debug_assert!(mass >= 0.0, "Mass cannot be negative!\nMass: {}", mass);
157        debug_assert!(
158            min_energy > 0.0,
159            "Minimum energy must be positive!\nMinimum Energy: {}",
160            min_energy
161        );
162        debug_assert!(
163            max_energy > min_energy,
164            "Maximum energy must be greater than minimum energy!"
165        );
166        Self {
167            mass,
168            energy_distribution: SimpleDistribution::Uniform {
169                min: min_energy,
170                max: max_energy,
171            },
172        }
173    }
174
175    /// Construct a beam with histogram-sampled energy.
176    pub fn beam_with_energy_histogram(mass: f64, energy: Histogram) -> LadduResult<Self> {
177        debug_assert!(
178            mass >= 0.0,
179            "Mass must be positive and greater than zero!\nMass: {}",
180            mass
181        );
182        let sampler = HistogramSampler::new(energy)?;
183        debug_assert!(
184            sampler.hist.bin_edges()[0] >= mass,
185            "Mass cannot be greater than the minimum allowed energy!\nMass: {}\nMinimum Energy: {}",
186            mass,
187            sampler.hist.bin_edges()[0]
188        );
189        Ok(Self {
190            mass,
191            energy_distribution: SimpleDistribution::Histogram(sampler),
192        })
193    }
194
195    /// Construct a target at rest.
196    pub fn target(mass: f64) -> Self {
197        Self {
198            mass,
199            energy_distribution: SimpleDistribution::Fixed(mass),
200        }
201    }
202}
203
204/// Generator settings for a generated composite particle.
205#[derive(Clone, Debug)]
206pub struct CompositeGenerator {
207    mass_distribution: SimpleDistribution,
208}
209
210impl CompositeGenerator {
211    /// Construct a composite mass generator with a uniform mass range.
212    pub fn new(min_mass: f64, max_mass: f64) -> Self {
213        Self {
214            mass_distribution: SimpleDistribution::Uniform {
215                min: min_mass,
216                max: max_mass,
217            },
218        }
219    }
220
221    fn sample_mass(&self, rng: &mut Rng) -> f64 {
222        self.mass_distribution.sample(rng)
223    }
224}
225
226/// Generator settings for a stable generated particle.
227#[derive(Clone, Debug)]
228pub struct StableGenerator {
229    mass_distribution: SimpleDistribution,
230}
231
232impl StableGenerator {
233    /// Construct a fixed-mass stable-particle generator.
234    pub fn new(mass: f64) -> Self {
235        debug_assert!(mass >= 0.0, "Mass cannot be negative!\nMass: {}", mass);
236        Self {
237            mass_distribution: SimpleDistribution::Fixed(mass),
238        }
239    }
240
241    fn sample_mass(&self, rng: &mut Rng) -> f64 {
242        self.mass_distribution.sample(rng)
243    }
244}
245
/// Reconstruction interpretation for a generated particle.
///
/// Selects which [`Particle`] constructor is used for this particle when the reconstructed
/// [`Reaction`] is built from a generated layout.
#[derive(Clone, Debug, PartialEq)]
pub enum Reconstruction {
    /// The particle p4 is stored in the analysis dataset under the generated particle ID.
    Stored,
    /// The particle p4 is fixed in the reconstructed reaction.
    Fixed(Vec4),
    /// The particle p4 is inferred from reaction-level constraints.
    Missing,
    /// The particle p4 is reconstructed as a composite of its two generated daughters.
    ///
    /// Only meaningful for [`GeneratedParticle::Composite`]; building the reconstructed particle
    /// fails for any other kind of generated particle.
    Composite,
}
258
/// A generated particle with generation and reconstruction metadata.
///
/// Composite particles own their daughters, so a top-level particle describes a whole decay
/// tree.
#[derive(Clone, Debug)]
pub enum GeneratedParticle {
    /// An initial-state generated particle.
    Initial {
        /// Generated particle ID; must be unique within a generated reaction.
        id: String,
        /// Mass and energy settings for this particle.
        generator: InitialGenerator,
        /// How this particle appears in the reconstructed reaction.
        reconstruction: Reconstruction,
        /// Optional external species metadata.
        species: Option<ParticleSpecies>,
    },
    /// A stable generated particle.
    Stable {
        /// Generated particle ID; must be unique within a generated reaction.
        id: String,
        /// Mass settings for this particle.
        generator: StableGenerator,
        /// How this particle appears in the reconstructed reaction.
        reconstruction: Reconstruction,
        /// Optional external species metadata.
        species: Option<ParticleSpecies>,
    },
    /// A generated composite particle with exactly two generated daughters.
    Composite {
        /// Generated particle ID; must be unique within a generated reaction.
        id: String,
        /// Mass-range settings for this parent.
        generator: CompositeGenerator,
        /// The two ordered daughters this particle decays into.
        daughters: (Box<GeneratedParticle>, Box<GeneratedParticle>),
        /// How this particle appears in the reconstructed reaction.
        reconstruction: Reconstruction,
        /// Optional external species metadata.
        species: Option<ParticleSpecies>,
    },
}
285
impl GeneratedParticle {
    /// Construct a generated initial-state particle.
    ///
    /// The particle starts without species metadata; use
    /// [`GeneratedParticle::with_species`] to attach it.
    pub fn initial(
        id: impl Into<String>,
        generator: InitialGenerator,
        reconstruction: Reconstruction,
    ) -> Self {
        Self::Initial {
            id: id.into(),
            generator,
            reconstruction,
            species: None,
        }
    }

    /// Construct a generated stable particle.
    ///
    /// The particle starts without species metadata; use
    /// [`GeneratedParticle::with_species`] to attach it.
    pub fn stable(
        id: impl Into<String>,
        generator: StableGenerator,
        reconstruction: Reconstruction,
    ) -> Self {
        Self::Stable {
            id: id.into(),
            generator,
            reconstruction,
            species: None,
        }
    }

    /// Construct a generated composite particle from exactly two ordered daughters.
    ///
    /// The daughters are only borrowed and are cloned into the new particle. Daughter order is
    /// significant: the first daughter's subtree precedes the second's in p4 labels and layout
    /// product IDs, and it is decayed first during generation. The particle starts without
    /// species metadata.
    pub fn composite(
        id: impl Into<String>,
        generator: CompositeGenerator,
        daughters: (&GeneratedParticle, &GeneratedParticle),
        reconstruction: Reconstruction,
    ) -> Self {
        Self::Composite {
            id: id.into(),
            generator,
            daughters: (Box::new(daughters.0.clone()), Box::new(daughters.1.clone())),
            reconstruction,
            species: None,
        }
    }

    /// Return a copy of this generated particle with species metadata attached.
    ///
    /// Any previously attached species is replaced. Only this particle is changed; daughters of a
    /// composite keep their own species metadata.
    pub fn with_species(mut self, species: ParticleSpecies) -> Self {
        match &mut self {
            Self::Initial {
                species: particle_species,
                ..
            }
            | Self::Stable {
                species: particle_species,
                ..
            }
            | Self::Composite {
                species: particle_species,
                ..
            } => *particle_species = Some(species),
        }
        self
    }

    /// Return the generated particle ID.
    pub fn id(&self) -> &str {
        match self {
            Self::Initial { id, .. } | Self::Stable { id, .. } | Self::Composite { id, .. } => id,
        }
    }

    /// Return optional species metadata for this generated particle.
    pub fn species(&self) -> Option<&ParticleSpecies> {
        match self {
            Self::Initial { species, .. }
            | Self::Stable { species, .. }
            | Self::Composite { species, .. } => species.as_ref(),
        }
    }

    /// Return this particle's reconstruction interpretation.
    pub fn reconstruction(&self) -> &Reconstruction {
        match self {
            Self::Initial { reconstruction, .. }
            | Self::Stable { reconstruction, .. }
            | Self::Composite { reconstruction, .. } => reconstruction,
        }
    }

    /// Collect the p4 labels of this particle and its decay subtree.
    ///
    /// Labels are generated IDs in depth-first pre-order: this particle, then the whole subtree
    /// of the first daughter, then the whole subtree of the second daughter.
    fn p4_labels(&self) -> Vec<String> {
        let mut labels = vec![self.id().to_string()];
        if let Self::Composite { daughters, .. } = self {
            labels.append(&mut daughters.0.p4_labels());
            labels.append(&mut daughters.1.p4_labels());
        }
        labels
    }

    /// Append layout entries for this particle and its decay subtree.
    ///
    /// Product IDs are indices into `particles` and vertex IDs are indices into `vertices`, so
    /// entries are numbered in the same depth-first pre-order as `p4_labels`. A composite opens a
    /// new [`GeneratedVertexKind::Decay`] vertex whose single incoming product is this particle
    /// and whose outgoing products are its two daughters, in order.
    ///
    /// * `parent_id` - product ID of the decaying parent, or `None` for top-level particles.
    /// * `produced_vertex_id` - vertex that produced this particle, if any.
    /// * `storage` - decides whether this particle's entry carries a p4 label.
    ///
    /// Returns the product ID assigned to this particle.
    fn append_decay_layout(
        &self,
        parent_id: Option<usize>,
        produced_vertex_id: Option<usize>,
        storage: &GeneratedStorage,
        particles: &mut Vec<GeneratedParticleLayout>,
        vertices: &mut Vec<GeneratedVertexLayout>,
    ) -> usize {
        // Claim this particle's product ID before recursing so parents precede their daughters.
        let product_id = particles.len();
        particles.push(GeneratedParticleLayout {
            id: self.id().to_string(),
            product_id,
            parent_id,
            species: self.species().cloned(),
            p4_label: storage.stores(self.id()).then(|| self.id().to_string()),
            produced_vertex_id,
            decay_vertex_id: None,
        });
        if let Self::Composite { daughters, .. } = self {
            let vertex_id = vertices.len();
            particles[product_id].decay_vertex_id = Some(vertex_id);
            vertices.push(GeneratedVertexLayout {
                vertex_id,
                kind: GeneratedVertexKind::Decay,
                incoming_product_ids: vec![product_id],
                outgoing_product_ids: Vec::new(),
            });
            let daughter_1_id = daughters.0.append_decay_layout(
                Some(product_id),
                Some(vertex_id),
                storage,
                particles,
                vertices,
            );
            let daughter_2_id = daughters.1.append_decay_layout(
                Some(product_id),
                Some(vertex_id),
                storage,
                particles,
                vertices,
            );
            // Daughter product IDs are only known after recursion, so fill the outgoing list last.
            vertices[vertex_id].outgoing_product_ids = vec![daughter_1_id, daughter_2_id];
        }
        product_id
    }

    /// Draw a mass for this particle.
    ///
    /// Initial-state particles return their configured mass without touching `rng`; stable and
    /// composite particles sample their generator's mass distribution.
    fn sample_mass(&self, rng: &mut Rng) -> f64 {
        match self {
            Self::Initial { generator, .. } => generator.mass,
            Self::Stable { generator, .. } => generator.sample_mass(rng),
            Self::Composite { generator, .. } => generator.sample_mass(rng),
        }
    }

    /// Build the reconstructed [`Particle`] described by this particle's [`Reconstruction`].
    ///
    /// `Composite` reconstruction recursively builds both daughters and combines them.
    ///
    /// # Errors
    ///
    /// Returns an error if `Composite` reconstruction is requested by a particle without
    /// daughters, or if building a daughter or the composite [`Particle`] fails.
    fn generated_particle(&self) -> LadduResult<Particle> {
        match self.reconstruction() {
            Reconstruction::Stored => Ok(Particle::stored(self.id())),
            Reconstruction::Fixed(p4) => Ok(Particle::fixed(self.id(), *p4)),
            Reconstruction::Missing => Ok(Particle::missing(self.id())),
            Reconstruction::Composite => {
                let Self::Composite { daughters, .. } = self else {
                    return Err(LadduError::Custom(format!(
                        "particle '{}' cannot use composite reconstruction without daughters",
                        self.id()
                    )));
                };
                let daughter_1 = daughters.0.generated_particle()?;
                let daughter_2 = daughters.1.generated_particle()?;
                Particle::composite(self.id(), (&daughter_1, &daughter_2))
            }
        }
    }

    /// Check that `Composite` reconstruction is only requested by particles with daughters.
    ///
    /// Daughters are checked only when their parent itself uses `Composite` reconstruction. This
    /// mirrors the cases where `generated_particle` recurses into them.
    ///
    /// # Errors
    ///
    /// Returns an error naming the first checked non-composite particle that requests
    /// `Composite` reconstruction.
    fn validate_reconstruction(&self) -> LadduResult<()> {
        match (self, self.reconstruction()) {
            (Self::Composite { daughters, .. }, Reconstruction::Composite) => {
                daughters.0.validate_reconstruction()?;
                daughters.1.validate_reconstruction()?;
                Ok(())
            }
            (Self::Composite { .. }, _) => Ok(()),
            (_, Reconstruction::Composite) => Err(LadduError::Custom(format!(
                "particle '{}' cannot use composite reconstruction without daughters",
                self.id()
            ))),
            _ => Ok(()),
        }
    }

    /// Insert this particle's ID and every descendant ID into `seen`.
    ///
    /// # Errors
    ///
    /// Returns an error for the first ID that is already present in `seen`.
    fn collect_ids<'a>(&'a self, seen: &mut HashSet<&'a str>) -> LadduResult<()> {
        if !seen.insert(self.id()) {
            return Err(LadduError::Custom(format!(
                "duplicate generated particle identifier '{}'",
                self.id()
            )));
        }
        if let Self::Composite { daughters, .. } = self {
            daughters.0.collect_ids(seen)?;
            daughters.1.collect_ids(seen)?;
        }
        Ok(())
    }

    /// Record this particle's lab-frame p4 and, for composites, generate its two-body decay.
    ///
    /// `p4_cm` is this particle's four-momentum in the reaction centre-of-mass frame, and
    /// `cm_to_lab_boost` is the boost applied to move CM-frame vectors into the lab frame. The
    /// lab-frame p4 is pushed to `p4_storage` only if the map already has an entry for this ID,
    /// so unselected particles are still generated but not stored.
    ///
    /// For a composite, both daughter masses are sampled. The daughters then get back-to-back
    /// momenta of magnitude `q` in the parent rest frame. The direction uses `cos(theta)` from
    /// `rng.uniform(-1.0, 1.0)` and `phi` from `rng.uniform(0.0, 2.0 * PI)`. Both daughters are
    /// boosted into the CM frame before recursing.
    ///
    /// `rng` is called in a fixed order: mass 1, mass 2, `cos(theta)`, `phi`, then the first
    /// daughter's subtree, then the second's. Reordering these calls would change the sequence of
    /// random draws for a given seed.
    fn generate_decay(
        &self,
        rng: &mut Rng,
        p4_cm: Vec4,
        cm_to_lab_boost: &Vec3,
        p4_storage: &mut HashMap<String, Vec<Vec4>>,
    ) {
        let p4_lab = p4_cm.boost(cm_to_lab_boost);
        if let Some(storage) = p4_storage.get_mut(self.id()) {
            storage.push(p4_lab);
        }

        let Self::Composite { daughters, .. } = self else {
            return;
        };
        let d1 = &daughters.0;
        let d2 = &daughters.1;
        // Two-body decay kinematics in the parent rest frame.
        let parent_mass = p4_cm.m();
        let m1 = d1.sample_mass(rng);
        let m2 = d2.sample_mass(rng);
        // NOTE(review): there is no check that parent_mass >= m1 + m2; for a kinematically
        // forbidden decay only the real part of q_m is used — confirm this is intended.
        let q = q_m(parent_mass, m1, m2, Sheet::Physical).re;
        let parent_msq = parent_mass * parent_mass;
        let msq1 = m1 * m1;
        let msq2 = m2 * m2;
        let e1 = (parent_msq + msq1 - msq2) / (2.0 * parent_mass);
        let e2 = (parent_msq + msq2 - msq1) / (2.0 * parent_mass);

        let cos_theta = rng.uniform(-1.0, 1.0);
        let sin_theta = (1.0 - cos_theta * cos_theta).sqrt();
        let phi = rng.uniform(0.0, 2.0 * PI);
        let (sin_phi, cos_phi) = phi.sin_cos();

        let dir = Vec3::new(sin_theta * cos_phi, sin_theta * sin_phi, cos_theta);
        let p1_p4_rest = (dir * q).with_energy(e1);
        let p2_p4_rest = (-dir * q).with_energy(e2);
        // The parent's velocity in the CM frame boosts rest-frame daughters into the CM frame.
        let parent_to_cm_boost = p4_cm.beta();
        let p1_p4_cm = p1_p4_rest.boost(&parent_to_cm_boost);
        let p2_p4_cm = p2_p4_rest.boost(&parent_to_cm_boost);
        d1.generate_decay(rng, p1_p4_cm, cm_to_lab_boost, p4_storage);
        d2.generate_decay(rng, p2_p4_cm, cm_to_lab_boost, p4_storage);
    }
}
529
/// A generated two-to-two reaction preserving `p1 + p2 -> p3 + p4` role semantics.
///
/// `p1` and `p2` must be [`GeneratedParticle::Initial`], and `p3` and `p4` must be stable or
/// composite. Generated IDs must be unique across all four particles and their daughters.
#[derive(Clone, Debug)]
pub struct GeneratedTwoToTwoReaction {
    /// First incoming particle.
    p1: GeneratedParticle,
    /// Second incoming particle.
    p2: GeneratedParticle,
    /// First outgoing particle; the sampled `t` is defined between `p1` and this particle.
    p3: GeneratedParticle,
    /// Second outgoing particle, back-to-back with `p3` in the CM frame.
    p4: GeneratedParticle,
    /// Distribution of the Mandelstam `t`, sampled once per event.
    tdist: MandelstamTDistribution,
    /// Lab-frame direction passed to `rng.p4` when generating `p1` (`new` sets +z).
    p1_p3_lab_dir: Vec3,
    /// Lab-frame direction passed to `rng.p4` when generating `p2` (`new` sets -z).
    p2_p3_lab_dir: Vec3,
}

impl GeneratedTwoToTwoReaction {
    /// Construct a generated two-to-two reaction.
    ///
    /// The lab-frame directions are fixed so that `p1` is generated along +z and `p2` along -z.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    ///
    /// * `p1` or `p2` is not an initial-state particle.
    /// * `p3` or `p4` is not stable or composite.
    /// * A generated ID is repeated.
    /// * Reconstruction validation fails.
    pub fn new(
        p1: GeneratedParticle,
        p2: GeneratedParticle,
        p3: GeneratedParticle,
        p4: GeneratedParticle,
        tdist: MandelstamTDistribution,
    ) -> LadduResult<Self> {
        validate_initial_role(&p1, "p1")?;
        validate_initial_role(&p2, "p2")?;
        validate_final_role(&p3, "p3")?;
        validate_final_role(&p4, "p4")?;
        let reaction = Self {
            p1,
            p2,
            p3,
            p4,
            tdist,
            p1_p3_lab_dir: Vec3::z(),
            p2_p3_lab_dir: -Vec3::z(),
        };
        reaction.validate()?;
        Ok(reaction)
    }

    /// Check ID uniqueness and reconstruction consistency across all four particles.
    ///
    /// A single `seen` set is shared, so an ID may not repeat anywhere in the reaction.
    fn validate(&self) -> LadduResult<()> {
        let mut seen = HashSet::new();
        for particle in [&self.p1, &self.p2, &self.p3, &self.p4] {
            particle.collect_ids(&mut seen)?;
            particle.validate_reconstruction()?;
        }
        Ok(())
    }

    /// All generated p4 labels, ordered `p1`, `p2`, `p3`, `p4`, each followed by its subtree.
    fn p4_labels(&self) -> Vec<String> {
        let mut labels = Vec::new();
        for particle in [&self.p1, &self.p2, &self.p3, &self.p4] {
            labels.append(&mut particle.p4_labels());
        }
        labels
    }

    /// Build the particle and vertex layouts together.
    ///
    /// Vertex 0 is the production vertex, with `p1` and `p2` incoming and `p3` and `p4`
    /// outgoing. Decay vertices follow in the order they are encountered. Initial particles have
    /// no producing vertex.
    fn layout_components(
        &self,
        storage: &GeneratedStorage,
    ) -> (Vec<GeneratedParticleLayout>, Vec<GeneratedVertexLayout>) {
        let mut particles = Vec::new();
        let mut vertices = vec![GeneratedVertexLayout {
            vertex_id: 0,
            kind: GeneratedVertexKind::Production,
            incoming_product_ids: Vec::new(),
            outgoing_product_ids: Vec::new(),
        }];
        let p1_id = self
            .p1
            .append_decay_layout(None, None, storage, &mut particles, &mut vertices);
        let p2_id = self
            .p2
            .append_decay_layout(None, None, storage, &mut particles, &mut vertices);
        let p3_id =
            self.p3
                .append_decay_layout(None, Some(0), storage, &mut particles, &mut vertices);
        let p4_id =
            self.p4
                .append_decay_layout(None, Some(0), storage, &mut particles, &mut vertices);
        vertices[0].incoming_product_ids = vec![p1_id, p2_id];
        vertices[0].outgoing_product_ids = vec![p3_id, p4_id];
        (particles, vertices)
    }

    /// Particle layouts with every particle's p4 stored.
    fn particle_layouts(&self) -> Vec<GeneratedParticleLayout> {
        self.particle_layouts_with_storage(&GeneratedStorage::All)
    }

    /// Particle layouts with p4 labels set according to `storage`.
    fn particle_layouts_with_storage(
        &self,
        storage: &GeneratedStorage,
    ) -> Vec<GeneratedParticleLayout> {
        self.layout_components(storage).0
    }

    /// Vertex layouts, which do not depend on the storage policy.
    fn vertex_layouts(&self) -> Vec<GeneratedVertexLayout> {
        self.layout_components(&GeneratedStorage::All).1
    }

    /// Build the reconstructed two-to-two [`Reaction`] from each role's reconstructed particle.
    ///
    /// # Errors
    ///
    /// Propagates errors from building any role's particle or the reaction itself.
    fn reconstructed_reaction(&self) -> LadduResult<Reaction> {
        Reaction::two_to_two(
            &self.p1.generated_particle()?,
            &self.p2.generated_particle()?,
            &self.p3.generated_particle()?,
            &self.p4.generated_particle()?,
        )
    }

    /// Generate one event and append stored p4s to `p4_storage`.
    ///
    /// The steps are:
    ///
    /// 1. Sample the incoming energies in the lab and boost the pair to its CM frame.
    /// 2. Sample `t` from `tdist` and convert it to the polar angle of `p3` relative to `p1`'s CM
    ///    direction, with an azimuth from `rng.uniform(0.0, 2.0 * PI)`.
    /// 3. Generate the outgoing particles and their decays with `generate_decay`, boosting back
    ///    to the lab.
    ///
    /// # Panics
    ///
    /// Panics via `unreachable!` if `p1` or `p2` is not an initial-state particle, which `new`
    /// rules out.
    fn generate_event(&self, rng: &mut Rng, p4_storage: &mut HashMap<String, Vec<Vec4>>) {
        let GeneratedParticle::Initial {
            id: p1_id,
            generator: p1_generator,
            ..
        } = &self.p1
        else {
            unreachable!("validated generated two-to-two p1 role")
        };
        let GeneratedParticle::Initial {
            id: p2_id,
            generator: p2_generator,
            ..
        } = &self.p2
        else {
            unreachable!("validated generated two-to-two p2 role")
        };

        let p1_e = p1_generator.energy_distribution.sample(rng);
        let p1_m = p1_generator.mass;
        let p1_msq = p1_m * p1_m;
        let p1_p4_lab = rng.p4(p1_m, p1_e, self.p1_p3_lab_dir);
        if let Some(storage) = p4_storage.get_mut(p1_id) {
            storage.push(p1_p4_lab)
        }

        let p2_e = p2_generator.energy_distribution.sample(rng);
        let p2_m = p2_generator.mass;
        let p2_msq = p2_m * p2_m;
        let p2_p4_lab = rng.p4(p2_m, p2_e, self.p2_p3_lab_dir);
        if let Some(storage) = p4_storage.get_mut(p2_id) {
            storage.push(p2_p4_lab)
        }

        // `cm_boost` takes lab-frame vectors to the CM frame; `-cm_boost` goes back to the lab.
        let cm = p1_p4_lab + p2_p4_lab;
        let cm_boost = -cm.beta();
        let s = cm.mag2();
        let sqrt_s = s.sqrt();
        let t = self.tdist.sample(rng);

        let p1_p4_cm = p1_p4_lab.boost(&cm_boost);
        let p3_m = self.p3.sample_mass(rng);
        let p3_msq = p3_m * p3_m;
        let p4_m = self.p4.sample_mass(rng);
        let p4_msq = p4_m * p4_m;
        let p_in_mag = q_m(sqrt_s, p1_m, p2_m, Sheet::Physical).re;
        let p_out_mag = q_m(sqrt_s, p3_m, p4_m, Sheet::Physical).re;
        let p1_e_cm = (s + p1_msq - p2_msq) / (2.0 * sqrt_s);
        let p3_e_cm = (s + p3_msq - p4_msq) / (2.0 * sqrt_s);
        let p4_e_cm = (s + p4_msq - p3_msq) / (2.0 * sqrt_s);
        // Invert t = m1^2 + m3^2 - 2 (E1 E3 - |p_in| |p_out| cos(theta)) for cos(theta).
        let costheta =
            (t - p1_msq - p3_msq + 2.0 * p1_e_cm * p3_e_cm) / (2.0 * p_in_mag * p_out_mag);
        // NOTE(review): t is sampled independently of s and the outgoing masses, so values
        // outside the physical range are clamped to forward/backward scattering rather than
        // rejected; a zero momentum magnitude would also make costheta non-finite — confirm.
        let costheta = costheta.clamp(-1.0, 1.0);
        let sintheta = (1.0 - costheta * costheta).sqrt();
        let phi = rng.uniform(0.0, 2.0 * PI);
        let (sin_phi, cos_phi) = phi.sin_cos();
        // Angles are measured relative to p1's direction in the CM frame.
        let (x, y, z) = basis(p1_p4_cm.vec3());
        let p3_dir_cm = x * (sintheta * cos_phi) + y * (sintheta * sin_phi) + z * costheta;

        let p3_p4_cm = (p3_dir_cm * p_out_mag).with_energy(p3_e_cm);
        self.p3
            .generate_decay(rng, p3_p4_cm, &-cm_boost, p4_storage);
        let p4_p4_cm = (-p3_dir_cm * p_out_mag).with_energy(p4_e_cm);
        self.p4
            .generate_decay(rng, p4_p4_cm, &-cm_boost, p4_storage);
    }
}
704
705fn validate_initial_role(particle: &GeneratedParticle, role: &str) -> LadduResult<()> {
706    if matches!(particle, GeneratedParticle::Initial { .. }) {
707        Ok(())
708    } else {
709        Err(LadduError::Custom(format!(
710            "generated two-to-two role '{role}' requires an initial particle"
711        )))
712    }
713}
714
715fn validate_final_role(particle: &GeneratedParticle, role: &str) -> LadduResult<()> {
716    if matches!(
717        particle,
718        GeneratedParticle::Stable { .. } | GeneratedParticle::Composite { .. }
719    ) {
720        Ok(())
721    } else {
722        Err(LadduError::Custom(format!(
723            "generated two-to-two role '{role}' requires an outgoing particle"
724        )))
725    }
726}
727
728/// A generated reaction topology.
729#[derive(Clone, Debug)]
730pub enum GeneratedReactionTopology {
731    /// A generated two-to-two topology.
732    TwoToTwo(GeneratedTwoToTwoReaction),
733}
734
735impl GeneratedReactionTopology {
736    fn p4_labels(&self) -> Vec<String> {
737        match self {
738            Self::TwoToTwo(reaction) => reaction.p4_labels(),
739        }
740    }
741
742    fn particle_layouts(&self) -> Vec<GeneratedParticleLayout> {
743        match self {
744            Self::TwoToTwo(reaction) => reaction.particle_layouts(),
745        }
746    }
747
748    fn particle_layouts_with_storage(
749        &self,
750        storage: &GeneratedStorage,
751    ) -> Vec<GeneratedParticleLayout> {
752        match self {
753            Self::TwoToTwo(reaction) => reaction.particle_layouts_with_storage(storage),
754        }
755    }
756
757    fn vertex_layouts(&self) -> Vec<GeneratedVertexLayout> {
758        match self {
759            Self::TwoToTwo(reaction) => reaction.vertex_layouts(),
760        }
761    }
762
763    fn reconstructed_reaction(&self) -> LadduResult<Reaction> {
764        match self {
765            Self::TwoToTwo(reaction) => reaction.reconstructed_reaction(),
766        }
767    }
768
769    fn generate_event(&self, rng: &mut Rng, p4_storage: &mut HashMap<String, Vec<Vec4>>) {
770        match self {
771            Self::TwoToTwo(reaction) => reaction.generate_event(rng, p4_storage),
772        }
773    }
774}
775
776/// A generated reaction layout.
777#[derive(Clone, Debug)]
778pub struct GeneratedReaction {
779    topology: GeneratedReactionTopology,
780}
781
782impl GeneratedReaction {
783    /// Construct a generated two-to-two reaction.
784    pub fn two_to_two(
785        p1: GeneratedParticle,
786        p2: GeneratedParticle,
787        p3: GeneratedParticle,
788        p4: GeneratedParticle,
789        tdist: MandelstamTDistribution,
790    ) -> LadduResult<Self> {
791        Ok(Self {
792            topology: GeneratedReactionTopology::TwoToTwo(GeneratedTwoToTwoReaction::new(
793                p1, p2, p3, p4, tdist,
794            )?),
795        })
796    }
797
798    /// Return generated p4 labels.
799    pub fn p4_labels(&self) -> Vec<String> {
800        self.topology.p4_labels()
801    }
802
803    /// Return generated particle layout entries in stable product-ID order.
804    pub fn particle_layouts(&self) -> Vec<GeneratedParticleLayout> {
805        self.topology.particle_layouts()
806    }
807
808    /// Return generated particle layout entries for a dataset storage policy.
809    pub fn particle_layouts_with_storage(
810        &self,
811        storage: &GeneratedStorage,
812    ) -> Vec<GeneratedParticleLayout> {
813        self.topology.particle_layouts_with_storage(storage)
814    }
815
816    /// Return generated vertex layout entries in stable vertex-ID order.
817    pub fn vertex_layouts(&self) -> Vec<GeneratedVertexLayout> {
818        self.topology.vertex_layouts()
819    }
820
821    /// Build the reconstructed reaction corresponding to this generated layout.
822    pub fn reconstructed_reaction(&self) -> LadduResult<Reaction> {
823        self.topology.reconstructed_reaction()
824    }
825
826    fn generate(
827        &self,
828        rng: &mut Rng,
829        p4_storage: &mut HashMap<String, Vec<Vec4>>,
830        n_events: usize,
831    ) {
832        for _ in 0..n_events {
833            self.topology.generate_event(rng, p4_storage);
834        }
835    }
836}
837
838/// Metadata for one generated particle in a generated event layout.
839#[derive(Clone, Debug, PartialEq, Eq)]
840pub struct GeneratedParticleLayout {
841    id: String,
842    product_id: usize,
843    parent_id: Option<usize>,
844    species: Option<ParticleSpecies>,
845    p4_label: Option<String>,
846    produced_vertex_id: Option<usize>,
847    decay_vertex_id: Option<usize>,
848}
849
850impl GeneratedParticleLayout {
851    /// Return the generated particle identifier.
852    pub fn id(&self) -> &str {
853        &self.id
854    }
855
856    /// Return the zero-based stable product ID in generated-layout order.
857    pub fn product_id(&self) -> usize {
858        self.product_id
859    }
860
861    /// Return the decay-parent product ID, if this particle is a decay daughter.
862    pub fn parent_id(&self) -> Option<usize> {
863        self.parent_id
864    }
865
866    /// Return optional species metadata associated with this generated particle.
867    pub fn species(&self) -> Option<&ParticleSpecies> {
868        self.species.as_ref()
869    }
870
871    /// Return the dataset p4 label associated with this particle, if stored in the batch.
872    pub fn p4_label(&self) -> Option<&str> {
873        self.p4_label.as_deref()
874    }
875
876    /// Return the vertex ID where this particle was produced, if any.
877    pub fn produced_vertex_id(&self) -> Option<usize> {
878        self.produced_vertex_id
879    }
880
881    /// Return the vertex ID where this particle decays, if it is a generated parent.
882    pub fn decay_vertex_id(&self) -> Option<usize> {
883        self.decay_vertex_id
884    }
885}
886
/// The semantic kind of a generated vertex.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum GeneratedVertexKind {
    /// Production vertex: initial-state particles enter, outgoing reaction products leave.
    Production,
    /// Decay vertex: one generated parent enters, its generated daughters leave.
    Decay,
}

/// Metadata for one generated vertex in a generated event layout.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct GeneratedVertexLayout {
    /// Zero-based position in generated-layout order.
    vertex_id: usize,
    /// Whether this is a production or a decay vertex.
    kind: GeneratedVertexKind,
    /// Product IDs of the particles entering the vertex.
    incoming_product_ids: Vec<usize>,
    /// Product IDs of the particles leaving the vertex.
    outgoing_product_ids: Vec<usize>,
}

impl GeneratedVertexLayout {
    /// Zero-based vertex ID in stable generated-layout order.
    pub fn vertex_id(&self) -> usize {
        self.vertex_id
    }

    /// Semantic kind of this vertex.
    pub fn kind(&self) -> GeneratedVertexKind {
        self.kind
    }

    /// Borrow the product IDs entering this vertex.
    pub fn incoming_product_ids(&self) -> &[usize] {
        self.incoming_product_ids.as_slice()
    }

    /// Borrow the product IDs leaving this vertex.
    pub fn outgoing_product_ids(&self) -> &[usize] {
        self.outgoing_product_ids.as_slice()
    }
}
926
927/// Metadata describing the columns and generated particles in a generated event batch.
928#[derive(Clone, Debug, PartialEq, Eq)]
929pub struct GeneratedEventLayout {
930    p4_labels: Vec<String>,
931    aux_labels: Vec<String>,
932    particles: Vec<GeneratedParticleLayout>,
933    vertices: Vec<GeneratedVertexLayout>,
934}
935
936impl GeneratedEventLayout {
937    /// Construct generated event layout metadata from p4 and auxiliary labels.
938    pub fn new(
939        p4_labels: Vec<String>,
940        aux_labels: Vec<String>,
941        particles: Vec<GeneratedParticleLayout>,
942        vertices: Vec<GeneratedVertexLayout>,
943    ) -> Self {
944        Self {
945            p4_labels,
946            aux_labels,
947            particles,
948            vertices,
949        }
950    }
951
952    /// Return generated p4 column labels in dataset order.
953    pub fn p4_labels(&self) -> &[String] {
954        &self.p4_labels
955    }
956
957    /// Return generated auxiliary column labels in dataset order.
958    pub fn aux_labels(&self) -> &[String] {
959        &self.aux_labels
960    }
961
962    /// Return generated particle layout entries in stable product-ID order.
963    pub fn particles(&self) -> &[GeneratedParticleLayout] {
964        &self.particles
965    }
966
967    /// Return the generated particle layout for a generated particle ID.
968    pub fn particle(&self, id: &str) -> Option<&GeneratedParticleLayout> {
969        self.particles.iter().find(|particle| particle.id() == id)
970    }
971
972    /// Return the generated particle layout for a stable product ID.
973    pub fn product(&self, product_id: usize) -> Option<&GeneratedParticleLayout> {
974        self.particles
975            .iter()
976            .find(|particle| particle.product_id() == product_id)
977    }
978
979    /// Return generated vertex layout entries in stable vertex-ID order.
980    pub fn vertices(&self) -> &[GeneratedVertexLayout] {
981        &self.vertices
982    }
983
984    /// Return the generated vertex layout for a stable vertex ID.
985    pub fn vertex(&self, vertex_id: usize) -> Option<&GeneratedVertexLayout> {
986        self.vertices
987            .iter()
988            .find(|vertex| vertex.vertex_id() == vertex_id)
989    }
990
991    /// Return the production vertex layout, if the generated layout has one.
992    pub fn production_vertex(&self) -> Option<&GeneratedVertexLayout> {
993        self.vertices
994            .iter()
995            .find(|vertex| vertex.kind() == GeneratedVertexKind::Production)
996    }
997
998    /// Return the generated decay daughters of a parent product ID.
999    pub fn decay_products(&self, parent_product_id: usize) -> Vec<&GeneratedParticleLayout> {
1000        self.particles
1001            .iter()
1002            .filter(|particle| particle.parent_id() == Some(parent_product_id))
1003            .collect()
1004    }
1005
1006    /// Return production-level incoming particle layouts.
1007    pub fn production_incoming(&self) -> Vec<&GeneratedParticleLayout> {
1008        self.production_vertex_products(GeneratedVertexLayout::incoming_product_ids)
1009    }
1010
1011    /// Return production-level outgoing particle layouts.
1012    pub fn production_outgoing(&self) -> Vec<&GeneratedParticleLayout> {
1013        self.production_vertex_products(GeneratedVertexLayout::outgoing_product_ids)
1014    }
1015
1016    fn production_vertex_products(
1017        &self,
1018        ids: impl FnOnce(&GeneratedVertexLayout) -> &[usize],
1019    ) -> Vec<&GeneratedParticleLayout> {
1020        self.production_vertex()
1021            .map(|vertex| {
1022                ids(vertex)
1023                    .iter()
1024                    .filter_map(|product_id| self.product(*product_id))
1025                    .collect()
1026            })
1027            .unwrap_or_default()
1028    }
1029}
1030
1031/// A generated dataset batch plus the metadata needed to interpret it.
1032#[derive(Clone, Debug)]
1033pub struct GeneratedBatch {
1034    dataset: Dataset,
1035    reaction: GeneratedReaction,
1036    layout: GeneratedEventLayout,
1037}
1038
1039impl GeneratedBatch {
1040    /// Construct a generated batch.
1041    pub fn new(
1042        dataset: Dataset,
1043        reaction: GeneratedReaction,
1044        layout: GeneratedEventLayout,
1045    ) -> Self {
1046        Self {
1047            dataset,
1048            reaction,
1049            layout,
1050        }
1051    }
1052
1053    /// Borrow the generated dataset.
1054    pub fn dataset(&self) -> &Dataset {
1055        &self.dataset
1056    }
1057
1058    /// Consume this batch and return the generated dataset.
1059    pub fn into_dataset(self) -> Dataset {
1060        self.dataset
1061    }
1062
1063    /// Borrow the generated reaction metadata.
1064    pub fn reaction(&self) -> &GeneratedReaction {
1065        &self.reaction
1066    }
1067
1068    /// Borrow the generated event layout metadata.
1069    pub fn layout(&self) -> &GeneratedEventLayout {
1070        &self.layout
1071    }
1072}
1073
1074/// Event generator for generated reactions.
1075#[derive(Clone, Debug)]
1076pub struct EventGenerator {
1077    reaction: GeneratedReaction,
1078    aux_generators: HashMap<String, Distribution>,
1079    storage: GeneratedStorage,
1080    seed: u64,
1081}
1082
1083/// Finite iterator over generated dataset batches.
1084#[derive(Clone, Debug)]
1085pub struct GeneratedBatchIter {
1086    generator: EventGenerator,
1087    remaining_events: usize,
1088    batch_size: usize,
1089    rng: Rng,
1090}
1091
1092impl Iterator for GeneratedBatchIter {
1093    type Item = LadduResult<GeneratedBatch>;
1094
1095    fn next(&mut self) -> Option<Self::Item> {
1096        if self.remaining_events == 0 {
1097            return None;
1098        }
1099        let n_events = self.batch_size.min(self.remaining_events);
1100        self.remaining_events -= n_events;
1101        Some(
1102            self.generator
1103                .generate_batch_with_rng(n_events, &mut self.rng),
1104        )
1105    }
1106}
1107
1108/// Evaluates unnormalized intensities for generated batches.
1109pub trait BatchIntensity {
1110    /// Return one nonnegative finite intensity for each event in `batch`.
1111    fn evaluate(&mut self, batch: &GeneratedBatch) -> LadduResult<Vec<f64>>;
1112}
1113
1114impl<F> BatchIntensity for F
1115where
1116    F: FnMut(&GeneratedBatch) -> LadduResult<Vec<f64>>,
1117{
1118    fn evaluate(&mut self, batch: &GeneratedBatch) -> LadduResult<Vec<f64>> {
1119        self(batch)
1120    }
1121}
1122
/// How the rejection sampler bounds event intensities from above.
#[derive(Debug, Clone)]
pub enum RejectionEnvelope {
    /// A constant upper bound shared by every event.
    Fixed {
        /// Largest event weight the envelope allows; also scales the accept/reject draw.
        max_weight: f64,
    },
}

impl RejectionEnvelope {
    /// Upper bound this envelope places on event weights.
    fn max_weight(&self) -> f64 {
        let Self::Fixed { max_weight } = self;
        *max_weight
    }
}

/// Configuration for a [`RejectionSampler`].
#[derive(Debug, Clone)]
pub struct RejectionSamplingOptions {
    /// Total number of accepted events to produce.
    pub target_accepted: usize,
    /// Number of raw events in each generated source batch; must be nonzero.
    pub generation_batch_size: usize,
    /// Target number of accepted events per yielded batch; must be nonzero.
    pub output_batch_size: usize,
    /// Rejection envelope; its maximum weight must be finite and positive.
    pub envelope: RejectionEnvelope,
    /// Seed of the RNG used for accept/reject decisions.
    pub seed: u64,
}
1155
/// Counters and extrema collected by the rejection sampler as it runs.
#[derive(Debug, Clone, Default)]
pub struct RejectionSamplingDiagnostics {
    /// Number of events generated so far; a whole source batch is counted when it is loaded.
    pub generated_events: usize,
    /// Number of events accepted.
    pub accepted_events: usize,
    /// Number of events rejected.
    pub rejected_events: usize,
    /// Largest event intensity seen so far.
    pub max_observed_weight: f64,
    /// Maximum weight of the envelope in use.
    pub envelope_max_weight: f64,
    /// Number of events whose intensity exceeded the fixed envelope.
    pub envelope_violations: usize,
}

impl RejectionSamplingDiagnostics {
    /// Fraction of generated events that were accepted, or `0.0` before any event was generated.
    pub fn acceptance_efficiency(&self) -> f64 {
        match self.generated_events {
            0 => 0.0,
            generated => self.accepted_events as f64 / generated as f64,
        }
    }
}
1183
1184/// Rejection sampler over generated batches.
1185#[derive(Clone, Debug)]
1186pub struct RejectionSampler<I> {
1187    generator: EventGenerator,
1188    intensity: I,
1189    options: RejectionSamplingOptions,
1190}
1191
1192impl<I> RejectionSampler<I>
1193where
1194    I: BatchIntensity,
1195{
1196    /// Construct a rejection sampler.
1197    pub fn new(
1198        generator: EventGenerator,
1199        intensity: I,
1200        options: RejectionSamplingOptions,
1201    ) -> LadduResult<Self> {
1202        if options.generation_batch_size == 0 {
1203            return Err(LadduError::Custom(
1204                "generation_batch_size must be greater than zero".to_string(),
1205            ));
1206        }
1207        if options.output_batch_size == 0 {
1208            return Err(LadduError::Custom(
1209                "output_batch_size must be greater than zero".to_string(),
1210            ));
1211        }
1212        let max_weight = options.envelope.max_weight();
1213        if !max_weight.is_finite() || max_weight <= 0.0 {
1214            return Err(LadduError::Custom(
1215                "rejection envelope max_weight must be finite and positive".to_string(),
1216            ));
1217        }
1218        Ok(Self {
1219            generator,
1220            intensity,
1221            options,
1222        })
1223    }
1224
1225    /// Consume this sampler and return an iterator over accepted generated batches.
1226    pub fn accepted_batches(self) -> RejectionSampleIter<I> {
1227        let envelope_max_weight = self.options.envelope.max_weight();
1228        RejectionSampleIter {
1229            generation_rng: Rng::with_seed(self.generator.seed),
1230            rejection_rng: Rng::with_seed(self.options.seed),
1231            diagnostics: RejectionSamplingDiagnostics {
1232                envelope_max_weight,
1233                ..Default::default()
1234            },
1235            sampler: self,
1236            current_batch: None,
1237            current_intensities: Vec::new(),
1238            current_index: 0,
1239        }
1240    }
1241}
1242
1243/// Iterator over accepted generated batches.
1244#[derive(Clone, Debug)]
1245pub struct RejectionSampleIter<I> {
1246    sampler: RejectionSampler<I>,
1247    generation_rng: Rng,
1248    rejection_rng: Rng,
1249    diagnostics: RejectionSamplingDiagnostics,
1250    current_batch: Option<GeneratedBatch>,
1251    current_intensities: Vec<f64>,
1252    current_index: usize,
1253}
1254
1255impl<I> RejectionSampleIter<I> {
1256    /// Borrow rejection-sampling diagnostics accumulated so far.
1257    pub fn diagnostics(&self) -> &RejectionSamplingDiagnostics {
1258        &self.diagnostics
1259    }
1260}
1261
1262impl<I> RejectionSampleIter<I>
1263where
1264    I: BatchIntensity,
1265{
1266    fn load_next_source_batch(&mut self) -> LadduResult<()> {
1267        let batch = self.sampler.generator.generate_batch_with_rng(
1268            self.sampler.options.generation_batch_size,
1269            &mut self.generation_rng,
1270        )?;
1271        let intensities = self.sampler.intensity.evaluate(&batch)?;
1272        if intensities.len() != batch.dataset().n_events() {
1273            return Err(LadduError::Custom(format!(
1274                "intensity length mismatch: expected {}, got {}",
1275                batch.dataset().n_events(),
1276                intensities.len()
1277            )));
1278        }
1279        self.diagnostics.generated_events += batch.dataset().n_events();
1280        self.current_batch = Some(batch);
1281        self.current_intensities = intensities;
1282        self.current_index = 0;
1283        Ok(())
1284    }
1285
1286    fn empty_output_batch(source: &GeneratedBatch) -> GeneratedBatch {
1287        GeneratedBatch::new(
1288            Dataset::empty_local(source.dataset().metadata().clone()),
1289            source.reaction().clone(),
1290            source.layout().clone(),
1291        )
1292    }
1293}
1294
1295impl<I> Iterator for RejectionSampleIter<I>
1296where
1297    I: BatchIntensity,
1298{
1299    type Item = LadduResult<GeneratedBatch>;
1300
1301    fn next(&mut self) -> Option<Self::Item> {
1302        if self.diagnostics.accepted_events >= self.sampler.options.target_accepted {
1303            return None;
1304        }
1305
1306        let mut output: Option<GeneratedBatch> = None;
1307        while self.diagnostics.accepted_events < self.sampler.options.target_accepted {
1308            let needs_batch = self
1309                .current_batch
1310                .as_ref()
1311                .map(|batch| self.current_index >= batch.dataset().n_events())
1312                .unwrap_or(true);
1313            if needs_batch {
1314                if let Err(err) = self.load_next_source_batch() {
1315                    return Some(Err(err));
1316                }
1317            }
1318
1319            let source = self
1320                .current_batch
1321                .as_ref()
1322                .expect("source batch should be loaded");
1323            if output.is_none() {
1324                output = Some(Self::empty_output_batch(source));
1325            }
1326
1327            let weight = self.current_intensities[self.current_index];
1328            if !weight.is_finite() || weight < 0.0 {
1329                return Some(Err(LadduError::Custom(format!(
1330                    "intensity at event {} must be finite and nonnegative, got {weight}",
1331                    self.current_index
1332                ))));
1333            }
1334            self.diagnostics.max_observed_weight = self.diagnostics.max_observed_weight.max(weight);
1335            let envelope_max = self.sampler.options.envelope.max_weight();
1336            if weight > envelope_max {
1337                self.diagnostics.envelope_violations += 1;
1338                return Some(Err(LadduError::Custom(format!(
1339                    "rejection envelope violation: observed weight {weight} exceeds max_weight {envelope_max}"
1340                ))));
1341            }
1342
1343            let accepted = self.rejection_rng.f64() * envelope_max < weight;
1344            if accepted {
1345                let event = match source.dataset().event_global(self.current_index) {
1346                    Ok(event) => event,
1347                    Err(err) => return Some(Err(err)),
1348                };
1349                if let Err(err) = output.as_mut().unwrap().dataset.push_event_local(
1350                    event.p4s.clone(),
1351                    event.aux.clone(),
1352                    event.weight,
1353                ) {
1354                    return Some(Err(err));
1355                }
1356                self.diagnostics.accepted_events += 1;
1357            } else {
1358                self.diagnostics.rejected_events += 1;
1359            }
1360            self.current_index += 1;
1361
1362            if output.as_ref().unwrap().dataset().n_events()
1363                >= self.sampler.options.output_batch_size
1364                || self.diagnostics.accepted_events >= self.sampler.options.target_accepted
1365            {
1366                break;
1367            }
1368        }
1369
1370        output
1371            .filter(|batch| batch.dataset().n_events() > 0)
1372            .map(Ok)
1373    }
1374}
1375
1376impl EventGenerator {
1377    /// Construct an event generator.
1378    pub fn new(
1379        reaction: GeneratedReaction,
1380        aux_generators: HashMap<String, Distribution>,
1381        seed: Option<u64>,
1382    ) -> Self {
1383        Self {
1384            reaction,
1385            aux_generators,
1386            storage: GeneratedStorage::All,
1387            seed: seed.unwrap_or_else(|| fastrand::u64(..)),
1388        }
1389    }
1390
1391    /// Return the generated p4 storage policy.
1392    pub fn storage(&self) -> &GeneratedStorage {
1393        &self.storage
1394    }
1395
1396    /// Return a copy of this generator with a generated p4 storage policy.
1397    pub fn with_storage(mut self, storage: GeneratedStorage) -> LadduResult<Self> {
1398        storage.validate(&self.reaction.p4_labels())?;
1399        self.storage = storage;
1400        Ok(self)
1401    }
1402
1403    fn aux_entries(&self) -> Vec<(&String, &Distribution)> {
1404        let mut aux_entries = self.aux_generators.iter().collect::<Vec<_>>();
1405        aux_entries.sort_by_key(|(label, _)| *label);
1406        aux_entries
1407    }
1408
1409    fn generate_batch_with_rng(
1410        &self,
1411        n_events: usize,
1412        rng: &mut Rng,
1413    ) -> LadduResult<GeneratedBatch> {
1414        let all_p4_labels = self.reaction.p4_labels();
1415        self.storage.validate(&all_p4_labels)?;
1416        let p4_labels = self.storage.stored_labels(&all_p4_labels);
1417        let aux_entries = self.aux_entries();
1418        let aux_labels = aux_entries
1419            .iter()
1420            .map(|(label, _)| (*label).clone())
1421            .collect::<Vec<_>>();
1422        let mut p4_data: HashMap<String, Vec<Vec4>> = p4_labels
1423            .iter()
1424            .map(|label| (label.clone(), Vec::with_capacity(n_events)))
1425            .collect();
1426        let metadata = DatasetMetadata::new(p4_labels.clone(), aux_labels.clone())?;
1427        let mut aux: Vec<Vec<f64>> = aux_entries
1428            .iter()
1429            .map(|_| Vec::with_capacity(n_events))
1430            .collect();
1431        let weights = vec![1.0; n_events];
1432        for _ in 0..n_events {
1433            for ((_, distribution), column) in aux_entries.iter().zip(aux.iter_mut()) {
1434                column.push(distribution.sample(rng));
1435            }
1436            self.reaction.generate(rng, &mut p4_data, 1);
1437        }
1438        let p4 = p4_labels
1439            .iter()
1440            .filter_map(|label| p4_data.remove(label))
1441            .collect();
1442        let dataset = Dataset::from_columns_local(metadata, p4, aux, weights)?;
1443        Ok(GeneratedBatch::new(
1444            dataset,
1445            self.reaction.clone(),
1446            GeneratedEventLayout::new(
1447                p4_labels,
1448                aux_labels,
1449                self.reaction.particle_layouts_with_storage(&self.storage),
1450                self.reaction.vertex_layouts(),
1451            ),
1452        ))
1453    }
1454
1455    /// Generate one dataset batch with generated layout metadata.
1456    pub fn generate_batch(&self, n_events: usize) -> LadduResult<GeneratedBatch> {
1457        let mut rng = Rng::with_seed(self.seed);
1458        self.generate_batch_with_rng(n_events, &mut rng)
1459    }
1460
1461    /// Generate a finite iterator over batches.
1462    ///
1463    /// The iterator advances one RNG stream, so concatenating all yielded batches is
1464    /// deterministic and matches [`EventGenerator::generate_dataset`] for the same total count.
1465    pub fn generate_batches(
1466        &self,
1467        total_events: usize,
1468        batch_size: usize,
1469    ) -> LadduResult<GeneratedBatchIter> {
1470        if batch_size == 0 {
1471            return Err(LadduError::Custom(
1472                "batch_size must be greater than zero".to_string(),
1473            ));
1474        }
1475        Ok(GeneratedBatchIter {
1476            generator: self.clone(),
1477            remaining_events: total_events,
1478            batch_size,
1479            rng: Rng::with_seed(self.seed),
1480        })
1481    }
1482
1483    /// Generate a dataset.
1484    pub fn generate_dataset(&self, n_events: usize) -> LadduResult<Dataset> {
1485        Ok(self.generate_batch(n_events)?.into_dataset())
1486    }
1487}
1488
1489#[cfg(test)]
1490mod tests {
1491    use approx::assert_relative_eq;
1492    use laddu_core::{traits::Variable, Channel, Frame};
1493
1494    use super::*;
1495
1496    fn demo_reaction() -> GeneratedReaction {
1497        let beam = GeneratedParticle::initial(
1498            "beam",
1499            InitialGenerator::beam_with_fixed_energy(0.0, 8.0),
1500            Reconstruction::Stored,
1501        );
1502        let target = GeneratedParticle::initial(
1503            "target",
1504            InitialGenerator::target(0.938272),
1505            Reconstruction::Missing,
1506        );
1507        let ks1 = GeneratedParticle::stable(
1508            "kshort1",
1509            StableGenerator::new(0.497611),
1510            Reconstruction::Stored,
1511        );
1512        let ks2 = GeneratedParticle::stable(
1513            "kshort2",
1514            StableGenerator::new(0.497611),
1515            Reconstruction::Stored,
1516        );
1517        let kk = GeneratedParticle::composite(
1518            "kk",
1519            CompositeGenerator::new(1.1, 1.6),
1520            (&ks1, &ks2),
1521            Reconstruction::Composite,
1522        );
1523        let recoil = GeneratedParticle::stable(
1524            "recoil",
1525            StableGenerator::new(0.938272),
1526            Reconstruction::Stored,
1527        );
1528        let tdist = MandelstamTDistribution::Exponential { slope: 0.1 };
1529        GeneratedReaction::two_to_two(beam, target, kk, recoil, tdist).unwrap()
1530    }
1531
1532    #[test]
1533    fn test_generation() {
1534        let reaction = demo_reaction();
1535        let generator = EventGenerator::new(reaction, HashMap::new(), Some(12345));
1536        let n_events = 1_000;
1537        let dataset = generator.generate_dataset(n_events).unwrap();
1538        assert_eq!(dataset.n_events(), n_events);
1539        let metadata = dataset.metadata();
1540        assert!(metadata.p4_index("beam").is_some());
1541        assert!(metadata.p4_index("target").is_some());
1542        assert!(metadata.p4_index("kk").is_some());
1543        assert!(metadata.p4_index("kshort1").is_some());
1544        assert!(metadata.p4_index("kshort2").is_some());
1545        assert!(metadata.p4_index("recoil").is_some());
1546
1547        for event in dataset.events_global() {
1548            let beam_p4 = event.p4("beam").unwrap();
1549            let target_p4 = event.p4("target").unwrap();
1550            let kk_p4 = event.p4("kk").unwrap();
1551            let kshort1_p4 = event.p4("kshort1").unwrap();
1552            let kshort2_p4 = event.p4("kshort2").unwrap();
1553            let recoil_p4 = event.p4("recoil").unwrap();
1554
1555            assert!(beam_p4.e().is_finite());
1556            assert!(target_p4.e().is_finite());
1557            assert!(kk_p4.e().is_finite());
1558            assert!(kshort1_p4.e().is_finite());
1559            assert!(kshort2_p4.e().is_finite());
1560            assert!(recoil_p4.e().is_finite());
1561
1562            assert_relative_eq!(kk_p4, kshort1_p4 + kshort2_p4, epsilon = 1e-10);
1563            assert_relative_eq!(beam_p4 + target_p4, kk_p4 + recoil_p4, epsilon = 1e-10);
1564            assert_relative_eq!(kshort1_p4.m(), 0.497611, epsilon = 1e-10);
1565            assert_relative_eq!(kshort2_p4.m(), 0.497611, epsilon = 1e-10);
1566            assert_relative_eq!(recoil_p4.m(), 0.938272, epsilon = 1e-10);
1567        }
1568    }
1569
1570    #[test]
1571    fn test_reconstructed_reaction() {
1572        let generated = demo_reaction();
1573        let reaction = generated.reconstructed_reaction().unwrap();
1574        let dataset = EventGenerator::new(generated, HashMap::new(), Some(12345))
1575            .generate_dataset(4)
1576            .unwrap();
1577        let mass = reaction.mass("kk").value_on(&dataset).unwrap();
1578        let angles = reaction
1579            .decay("kk")
1580            .unwrap()
1581            .angles("kshort1", Frame::Helicity)
1582            .unwrap();
1583        let mandelstam = reaction
1584            .mandelstam(Channel::S)
1585            .unwrap()
1586            .value_on(&dataset)
1587            .unwrap();
1588
1589        assert_eq!(mass.len(), 4);
1590        assert_eq!(
1591            angles.costheta.to_string(),
1592            "CosTheta(parent=kk, daughter=kshort1, frame=Helicity)"
1593        );
1594        assert_eq!(mandelstam.len(), 4);
1595    }
1596
1597    #[test]
1598    fn test_generated_batch_metadata() {
1599        let generated = demo_reaction();
1600        let generator = EventGenerator::new(
1601            generated,
1602            HashMap::from([("pol_angle".to_string(), Distribution::Fixed(0.25))]),
1603            Some(12345),
1604        );
1605        let batch = generator.generate_batch(4).unwrap();
1606
1607        assert_eq!(batch.dataset().n_events(), 4);
1608        assert_eq!(
1609            batch.layout().p4_labels(),
1610            &["beam", "target", "kk", "kshort1", "kshort2", "recoil"]
1611        );
1612        assert_eq!(batch.layout().aux_labels(), &["pol_angle"]);
1613        assert_eq!(
1614            batch.reaction().p4_labels(),
1615            vec!["beam", "target", "kk", "kshort1", "kshort2", "recoil"]
1616        );
1617        assert_eq!(batch.dataset().p4_names(), batch.layout().p4_labels());
1618        assert_eq!(batch.dataset().aux_names(), batch.layout().aux_labels());
1619        let particles = batch.layout().particles();
1620        assert_eq!(particles.len(), 6);
1621        assert_eq!(particles[0].id(), "beam");
1622        assert_eq!(particles[0].product_id(), 0);
1623        assert_eq!(particles[0].parent_id(), None);
1624        assert_eq!(particles[0].produced_vertex_id(), None);
1625        assert_eq!(particles[0].decay_vertex_id(), None);
1626        assert_eq!(particles[1].id(), "target");
1627        assert_eq!(particles[1].parent_id(), None);
1628        assert_eq!(particles[1].produced_vertex_id(), None);
1629        assert_eq!(particles[1].decay_vertex_id(), None);
1630        assert_eq!(particles[2].id(), "kk");
1631        assert_eq!(particles[2].product_id(), 2);
1632        assert_eq!(particles[2].parent_id(), None);
1633        assert_eq!(particles[2].produced_vertex_id(), Some(0));
1634        assert_eq!(particles[2].decay_vertex_id(), Some(1));
1635        assert_eq!(particles[3].id(), "kshort1");
1636        assert_eq!(particles[3].parent_id(), Some(2));
1637        assert_eq!(particles[3].produced_vertex_id(), Some(1));
1638        assert_eq!(particles[3].decay_vertex_id(), None);
1639        assert_eq!(particles[4].id(), "kshort2");
1640        assert_eq!(particles[4].parent_id(), Some(2));
1641        assert_eq!(particles[4].produced_vertex_id(), Some(1));
1642        assert_eq!(particles[4].decay_vertex_id(), None);
1643        assert_eq!(particles[5].id(), "recoil");
1644        assert_eq!(particles[5].parent_id(), None);
1645        assert_eq!(particles[5].produced_vertex_id(), Some(0));
1646        assert_eq!(particles[5].decay_vertex_id(), None);
1647        for particle in particles {
1648            assert_eq!(particle.p4_label(), Some(particle.id()));
1649        }
1650        let vertices = batch.layout().vertices();
1651        assert_eq!(vertices.len(), 2);
1652        assert_eq!(vertices[0].vertex_id(), 0);
1653        assert_eq!(vertices[0].kind(), GeneratedVertexKind::Production);
1654        assert_eq!(vertices[0].incoming_product_ids(), &[0, 1]);
1655        assert_eq!(vertices[0].outgoing_product_ids(), &[2, 5]);
1656        assert_eq!(vertices[1].vertex_id(), 1);
1657        assert_eq!(vertices[1].kind(), GeneratedVertexKind::Decay);
1658        assert_eq!(vertices[1].incoming_product_ids(), &[2]);
1659        assert_eq!(vertices[1].outgoing_product_ids(), &[3, 4]);
1660
1661        assert_eq!(batch.layout().particle("kk"), Some(&particles[2]));
1662        assert_eq!(batch.layout().particle("missing_id"), None);
1663        assert_eq!(batch.layout().product(5), Some(&particles[5]));
1664        assert_eq!(batch.layout().product(6), None);
1665        assert_eq!(batch.layout().vertex(1), Some(&vertices[1]));
1666        assert_eq!(batch.layout().vertex(2), None);
1667        assert_eq!(batch.layout().production_vertex(), Some(&vertices[0]));
1668        assert_eq!(
1669            batch
1670                .layout()
1671                .production_incoming()
1672                .iter()
1673                .map(|particle| particle.id())
1674                .collect::<Vec<_>>(),
1675            vec!["beam", "target"]
1676        );
1677        assert_eq!(
1678            batch
1679                .layout()
1680                .production_outgoing()
1681                .iter()
1682                .map(|particle| particle.id())
1683                .collect::<Vec<_>>(),
1684            vec!["kk", "recoil"]
1685        );
1686        assert_eq!(
1687            batch
1688                .layout()
1689                .decay_products(2)
1690                .iter()
1691                .map(|particle| particle.id())
1692                .collect::<Vec<_>>(),
1693            vec!["kshort1", "kshort2"]
1694        );
1695        assert!(batch.layout().decay_products(5).is_empty());
1696    }
1697
1698    #[test]
1699    fn generated_storage_only_projects_dataset_columns() {
1700        let generated = demo_reaction();
1701        let generator = EventGenerator::new(generated, HashMap::new(), Some(12345))
1702            .with_storage(GeneratedStorage::only([
1703                "beam", "target", "kshort1", "kshort2", "recoil",
1704            ]))
1705            .unwrap();
1706        let batch = generator.generate_batch(4).unwrap();
1707
1708        assert_eq!(
1709            batch.reaction().p4_labels(),
1710            vec!["beam", "target", "kk", "kshort1", "kshort2", "recoil"]
1711        );
1712        assert_eq!(
1713            batch.layout().p4_labels(),
1714            &["beam", "target", "kshort1", "kshort2", "recoil"]
1715        );
1716        assert_eq!(batch.dataset().p4_names(), batch.layout().p4_labels());
1717        assert!(batch.dataset().metadata().p4_index("kk").is_none());
1718
1719        let particles = batch.layout().particles();
1720        assert_eq!(particles.len(), 6);
1721        assert_eq!(particles[2].id(), "kk");
1722        assert_eq!(particles[2].p4_label(), None);
1723        assert_eq!(particles[3].p4_label(), Some("kshort1"));
1724        assert_eq!(particles[4].p4_label(), Some("kshort2"));
1725
1726        for index in 0..batch.dataset().n_events() {
1727            let event = batch.dataset().event_global(index).unwrap();
1728            assert_relative_eq!(
1729                event.p4("beam").unwrap() + event.p4("target").unwrap(),
1730                event.p4("kshort1").unwrap()
1731                    + event.p4("kshort2").unwrap()
1732                    + event.p4("recoil").unwrap(),
1733                epsilon = 1e-10
1734            );
1735        }
1736    }
1737
1738    #[test]
1739    fn generated_storage_rejects_unknown_and_duplicate_ids() {
1740        assert!(
1741            EventGenerator::new(demo_reaction(), HashMap::new(), Some(12345))
1742                .with_storage(GeneratedStorage::only(["beam", "does_not_exist"]))
1743                .is_err()
1744        );
1745        assert!(
1746            EventGenerator::new(demo_reaction(), HashMap::new(), Some(12345))
1747                .with_storage(GeneratedStorage::only(["beam", "beam"]))
1748                .is_err()
1749        );
1750    }
1751
1752    #[test]
1753    fn generated_species_metadata_propagates_to_layout() {
1754        let beam = GeneratedParticle::initial(
1755            "beam",
1756            InitialGenerator::beam_with_fixed_energy(0.0, 8.0),
1757            Reconstruction::Stored,
1758        )
1759        .with_species(ParticleSpecies::code(22));
1760        let target = GeneratedParticle::initial(
1761            "target",
1762            InitialGenerator::target(0.938272),
1763            Reconstruction::Missing,
1764        )
1765        .with_species(ParticleSpecies::with_namespace("pdg", 2212));
1766        let kshort1 = GeneratedParticle::stable(
1767            "kshort1",
1768            StableGenerator::new(0.497611),
1769            Reconstruction::Stored,
1770        )
1771        .with_species(ParticleSpecies::label("KShort"));
1772        let kshort2 = GeneratedParticle::stable(
1773            "kshort2",
1774            StableGenerator::new(0.497611),
1775            Reconstruction::Stored,
1776        )
1777        .with_species(ParticleSpecies::label("KShort"));
1778        let kk = GeneratedParticle::composite(
1779            "kk",
1780            CompositeGenerator::new(1.1, 1.6),
1781            (&kshort1, &kshort2),
1782            Reconstruction::Composite,
1783        )
1784        .with_species(ParticleSpecies::label("KK"));
1785        let recoil = GeneratedParticle::stable(
1786            "recoil",
1787            StableGenerator::new(0.938272),
1788            Reconstruction::Stored,
1789        )
1790        .with_species(ParticleSpecies::code(2212));
1791        let reaction = GeneratedReaction::two_to_two(
1792            beam,
1793            target,
1794            kk,
1795            recoil,
1796            MandelstamTDistribution::Exponential { slope: 0.1 },
1797        )
1798        .unwrap();
1799        let particles = reaction.particle_layouts();
1800
1801        assert_eq!(particles[0].species(), Some(&ParticleSpecies::code(22)));
1802        assert_eq!(
1803            particles[1].species(),
1804            Some(&ParticleSpecies::with_namespace("pdg", 2212))
1805        );
1806        assert_eq!(particles[2].species(), Some(&ParticleSpecies::label("KK")));
1807        assert_eq!(
1808            particles[3].species(),
1809            Some(&ParticleSpecies::label("KShort"))
1810        );
1811        assert_eq!(
1812            particles[4].species(),
1813            Some(&ParticleSpecies::label("KShort"))
1814        );
1815        assert_eq!(particles[5].species(), Some(&ParticleSpecies::code(2212)));
1816    }
1817
1818    #[test]
1819    fn generated_batches_match_one_shot_generation() {
1820        let generated = demo_reaction();
1821        let generator = EventGenerator::new(
1822            generated,
1823            HashMap::from([(
1824                "pol_angle".to_string(),
1825                Distribution::Uniform { min: 0.0, max: 1.0 },
1826            )]),
1827            Some(12345),
1828        );
1829        let one_shot = generator.generate_dataset(7).unwrap();
1830        let batches = generator
1831            .generate_batches(7, 3)
1832            .unwrap()
1833            .collect::<LadduResult<Vec<_>>>()
1834            .unwrap();
1835        let batch_sizes = batches
1836            .iter()
1837            .map(|batch| batch.dataset().n_events())
1838            .collect::<Vec<_>>();
1839        assert_eq!(batch_sizes, vec![3, 3, 1]);
1840
1841        let mut offset = 0;
1842        for batch in batches {
1843            for local_index in 0..batch.dataset().n_events() {
1844                let expected = one_shot.event_global(offset + local_index).unwrap();
1845                let actual = batch.dataset().event_global(local_index).unwrap();
1846                for name in one_shot.p4_names() {
1847                    assert_relative_eq!(
1848                        actual.p4(name).unwrap(),
1849                        expected.p4(name).unwrap(),
1850                        epsilon = 1e-10
1851                    );
1852                }
1853                for aux_index in 0..one_shot.aux_names().len() {
1854                    assert_relative_eq!(actual.aux[aux_index], expected.aux[aux_index]);
1855                }
1856                assert_relative_eq!(actual.weight(), expected.weight());
1857            }
1858            offset += batch.dataset().n_events();
1859        }
1860        assert_eq!(offset, one_shot.n_events());
1861        assert!(generator.generate_batches(1, 0).is_err());
1862    }
1863
1864    #[test]
1865    fn fixed_envelope_rejection_sampler_streams_accepted_batches() {
1866        let generator = EventGenerator::new(demo_reaction(), HashMap::new(), Some(12345));
1867        let sampler = RejectionSampler::new(
1868            generator,
1869            |batch: &GeneratedBatch| Ok(vec![1.0; batch.dataset().n_events()]),
1870            RejectionSamplingOptions {
1871                target_accepted: 5,
1872                generation_batch_size: 4,
1873                output_batch_size: 2,
1874                envelope: RejectionEnvelope::Fixed { max_weight: 1.0 },
1875                seed: 67890,
1876            },
1877        )
1878        .unwrap();
1879
1880        let mut iter = sampler.accepted_batches();
1881        let mut accepted_batches = Vec::new();
1882        for batch in iter.by_ref() {
1883            accepted_batches.push(batch.unwrap());
1884        }
1885        assert_eq!(
1886            accepted_batches
1887                .iter()
1888                .map(|batch| batch.dataset().n_events())
1889                .collect::<Vec<_>>(),
1890            vec![2, 2, 1]
1891        );
1892        assert_eq!(iter.diagnostics().generated_events, 8);
1893        assert_eq!(iter.diagnostics().accepted_events, 5);
1894        assert_eq!(iter.diagnostics().rejected_events, 0);
1895        assert_relative_eq!(iter.diagnostics().acceptance_efficiency(), 5.0 / 8.0);
1896        for batch in accepted_batches {
1897            assert_eq!(
1898                batch.layout().p4_labels(),
1899                &["beam", "target", "kk", "kshort1", "kshort2", "recoil"]
1900            );
1901        }
1902    }
1903
1904    #[test]
1905    fn fixed_envelope_rejection_sampler_rejects_violations() {
1906        let generator = EventGenerator::new(demo_reaction(), HashMap::new(), Some(12345));
1907        let sampler = RejectionSampler::new(
1908            generator,
1909            |batch: &GeneratedBatch| Ok(vec![2.0; batch.dataset().n_events()]),
1910            RejectionSamplingOptions {
1911                target_accepted: 1,
1912                generation_batch_size: 1,
1913                output_batch_size: 1,
1914                envelope: RejectionEnvelope::Fixed { max_weight: 1.0 },
1915                seed: 67890,
1916            },
1917        )
1918        .unwrap();
1919
1920        let mut iter = sampler.accepted_batches();
1921        let err = iter.next().expect("sampler should produce an error");
1922        assert!(err.is_err());
1923        assert_eq!(iter.diagnostics().envelope_violations, 1);
1924        assert_relative_eq!(iter.diagnostics().max_observed_weight, 2.0);
1925    }
1926
1927    #[test]
1928    fn duplicate_generated_particle_ids_are_rejected() {
1929        let beam = GeneratedParticle::initial(
1930            "beam",
1931            InitialGenerator::beam_with_fixed_energy(0.0, 8.0),
1932            Reconstruction::Stored,
1933        );
1934        let target = GeneratedParticle::initial(
1935            "target",
1936            InitialGenerator::target(0.938272),
1937            Reconstruction::Missing,
1938        );
1939        let duplicate = GeneratedParticle::stable(
1940            "beam",
1941            StableGenerator::new(0.497611),
1942            Reconstruction::Stored,
1943        );
1944        let recoil = GeneratedParticle::stable(
1945            "recoil",
1946            StableGenerator::new(0.938272),
1947            Reconstruction::Stored,
1948        );
1949
1950        assert!(GeneratedReaction::two_to_two(
1951            beam,
1952            target,
1953            duplicate,
1954            recoil,
1955            MandelstamTDistribution::Exponential { slope: 0.1 },
1956        )
1957        .is_err());
1958    }
1959}