1use std::{
2 collections::{HashMap, HashSet},
3 sync::Arc,
4};
5
6use fastrand::Rng;
7use laddu_core::{
8 math::{q_m, Histogram, Sheet},
9 Dataset, DatasetMetadata, Expression, LadduError, LadduResult, Particle, Reaction, Vec3, Vec4,
10 PI,
11};
12use serde::{Deserialize, Serialize};
13
14use crate::distributions::{
15 Distribution, HistogramSampler, LadduGenRngExt, MandelstamTDistribution, SimpleDistribution,
16};
17
/// Policy selecting which generated particles are stored (i.e. receive a
/// four-momentum label in the event layout; see `GeneratedStorage::stores`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum GeneratedStorage {
    /// Store every generated particle.
    All,
    /// Store only the particles whose IDs appear in this list.
    Only(Vec<String>),
}
30
31impl GeneratedStorage {
32 pub fn all() -> Self {
34 Self::All
35 }
36
37 pub fn only<I, S>(ids: I) -> Self
39 where
40 I: IntoIterator<Item = S>,
41 S: Into<String>,
42 {
43 Self::Only(ids.into_iter().map(Into::into).collect())
44 }
45
46 pub fn stores(&self, id: &str) -> bool {
48 match self {
49 Self::All => true,
50 Self::Only(ids) => ids.iter().any(|stored_id| stored_id == id),
51 }
52 }
53
54 fn validate(&self, available_ids: &[String]) -> LadduResult<()> {
55 let available = available_ids
56 .iter()
57 .map(String::as_str)
58 .collect::<HashSet<_>>();
59 let Self::Only(ids) = self else {
60 return Ok(());
61 };
62 let mut seen = HashSet::new();
63 for id in ids {
64 if !seen.insert(id.as_str()) {
65 return Err(LadduError::Custom(format!(
66 "generated storage contains duplicate particle ID '{id}'"
67 )));
68 }
69 if !available.contains(id.as_str()) {
70 return Err(LadduError::Custom(format!(
71 "generated storage references unknown particle ID '{id}'"
72 )));
73 }
74 }
75 Ok(())
76 }
77
78 fn stored_labels(&self, all_labels: &[String]) -> Vec<String> {
79 all_labels
80 .iter()
81 .filter(|label| self.stores(label))
82 .cloned()
83 .collect()
84 }
85}
86
/// Species tag attached to a generated particle.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ParticleSpecies {
    /// A numeric species code.
    Code {
        /// The numeric code.
        id: i64,
        /// Optional namespace qualifying `id`; `None` when unqualified.
        namespace: Option<String>,
    },
    /// A free-form textual label.
    Label(String),
}
104
105impl ParticleSpecies {
106 pub fn code(id: i64) -> Self {
108 Self::Code {
109 id,
110 namespace: None,
111 }
112 }
113
114 pub fn with_namespace(namespace: impl Into<String>, id: i64) -> Self {
116 Self::Code {
117 id,
118 namespace: Some(namespace.into()),
119 }
120 }
121
122 pub fn label(label: impl Into<String>) -> Self {
124 Self::Label(label.into())
125 }
126}
127
128fn basis(z: Vec3) -> (Vec3, Vec3, Vec3) {
129 let z = z.unit();
130 let ref_axis = if z.z.abs() < 0.9 {
131 Vec3::z()
132 } else {
133 Vec3::y()
134 };
135 let x = ref_axis.cross(&z).unit();
136 let y = z.cross(&x);
137 (x, y, z)
138}
139
/// Generator for an initial-state particle: a fixed mass plus an energy
/// distribution that is sampled once per generated event.
#[derive(Clone, Debug)]
pub struct InitialGenerator {
    /// Rest mass of the particle.
    mass: f64,
    /// Distribution from which the particle's energy is drawn for each event.
    energy_distribution: SimpleDistribution,
}
146
147impl InitialGenerator {
148 pub fn beam_with_fixed_energy(mass: f64, energy: f64) -> Self {
150 debug_assert!(mass >= 0.0, "Mass cannot be negative!\nMass: {}", mass);
151 debug_assert!(energy > 0.0, "Energy must be positive!\nEnergy: {}", energy);
152 Self {
153 mass,
154 energy_distribution: SimpleDistribution::Fixed(energy),
155 }
156 }
157
158 pub fn beam(mass: f64, min_energy: f64, max_energy: f64) -> Self {
160 debug_assert!(mass >= 0.0, "Mass cannot be negative!\nMass: {}", mass);
161 debug_assert!(
162 min_energy > 0.0,
163 "Minimum energy must be positive!\nMinimum Energy: {}",
164 min_energy
165 );
166 debug_assert!(
167 max_energy > min_energy,
168 "Maximum energy must be greater than minimum energy!"
169 );
170 Self {
171 mass,
172 energy_distribution: SimpleDistribution::Uniform {
173 min: min_energy,
174 max: max_energy,
175 },
176 }
177 }
178
179 pub fn beam_with_energy_histogram(mass: f64, energy: Histogram) -> LadduResult<Self> {
181 debug_assert!(
182 mass >= 0.0,
183 "Mass must be positive and greater than zero!\nMass: {}",
184 mass
185 );
186 let sampler = HistogramSampler::new(energy)?;
187 debug_assert!(
188 sampler.hist.bin_edges()[0] >= mass,
189 "Mass cannot be greater than the minimum allowed energy!\nMass: {}\nMinimum Energy: {}",
190 mass,
191 sampler.hist.bin_edges()[0]
192 );
193 Ok(Self {
194 mass,
195 energy_distribution: SimpleDistribution::Histogram(sampler),
196 })
197 }
198
199 pub fn target(mass: f64) -> Self {
201 Self {
202 mass,
203 energy_distribution: SimpleDistribution::Fixed(mass),
204 }
205 }
206}
207
/// Generator for a composite (decaying) particle's mass.
#[derive(Clone, Debug)]
pub struct CompositeGenerator {
    /// Distribution from which the composite mass is drawn each time it is sampled.
    mass_distribution: SimpleDistribution,
}
213
214impl CompositeGenerator {
215 pub fn new(min_mass: f64, max_mass: f64) -> Self {
217 Self {
218 mass_distribution: SimpleDistribution::Uniform {
219 min: min_mass,
220 max: max_mass,
221 },
222 }
223 }
224
225 fn sample_mass(&self, rng: &mut Rng) -> f64 {
226 self.mass_distribution.sample(rng)
227 }
228}
229
/// Generator for a stable (non-decaying) particle's mass.
#[derive(Clone, Debug)]
pub struct StableGenerator {
    /// Mass distribution; `StableGenerator::new` always sets a fixed value.
    mass_distribution: SimpleDistribution,
}
235
236impl StableGenerator {
237 pub fn new(mass: f64) -> Self {
239 debug_assert!(mass >= 0.0, "Mass cannot be negative!\nMass: {}", mass);
240 Self {
241 mass_distribution: SimpleDistribution::Fixed(mass),
242 }
243 }
244
245 fn sample_mass(&self, rng: &mut Rng) -> f64 {
246 self.mass_distribution.sample(rng)
247 }
248}
249
/// How a generated particle is represented in the reconstructed `Reaction`
/// (see `GeneratedParticle::generated_particle`).
#[derive(Clone, Debug, PartialEq)]
pub enum Reconstruction {
    /// Built with `Particle::stored`.
    Stored,
    /// Built with `Particle::fixed` using this four-momentum.
    Fixed(Vec4),
    /// Built with `Particle::missing`.
    Missing,
    /// Built with `Particle::composite` from the two daughters; only valid for
    /// `GeneratedParticle::Composite`.
    Composite,
}
262
/// A node in the generated reaction tree.
#[derive(Clone, Debug)]
pub enum GeneratedParticle {
    /// An incoming particle with its own mass/energy generator.
    Initial {
        /// Unique identifier; also used as the particle's four-momentum label.
        id: String,
        /// Mass and energy generator.
        generator: InitialGenerator,
        /// How the particle appears in the reconstructed reaction.
        reconstruction: Reconstruction,
        /// Optional species tag.
        species: Option<ParticleSpecies>,
    },
    /// A final-state particle that does not decay.
    Stable {
        /// Unique identifier; also used as the particle's four-momentum label.
        id: String,
        /// Mass generator.
        generator: StableGenerator,
        /// How the particle appears in the reconstructed reaction.
        reconstruction: Reconstruction,
        /// Optional species tag.
        species: Option<ParticleSpecies>,
    },
    /// A particle that decays into exactly two daughters.
    Composite {
        /// Unique identifier; also used as the particle's four-momentum label.
        id: String,
        /// Mass generator for the composite itself.
        generator: CompositeGenerator,
        /// The two decay products.
        daughters: (Box<GeneratedParticle>, Box<GeneratedParticle>),
        /// How the particle appears in the reconstructed reaction.
        reconstruction: Reconstruction,
        /// Optional species tag.
        species: Option<ParticleSpecies>,
    },
}
289
impl GeneratedParticle {
    /// Creates an initial-state particle with no species assigned.
    pub fn initial(
        id: impl Into<String>,
        generator: InitialGenerator,
        reconstruction: Reconstruction,
    ) -> Self {
        Self::Initial {
            id: id.into(),
            generator,
            reconstruction,
            species: None,
        }
    }

    /// Creates a stable final-state particle with no species assigned.
    pub fn stable(
        id: impl Into<String>,
        generator: StableGenerator,
        reconstruction: Reconstruction,
    ) -> Self {
        Self::Stable {
            id: id.into(),
            generator,
            reconstruction,
            species: None,
        }
    }

    /// Creates a composite particle decaying into the two given daughters.
    ///
    /// The daughters are cloned into owned boxes; the caller keeps its originals.
    pub fn composite(
        id: impl Into<String>,
        generator: CompositeGenerator,
        daughters: (&GeneratedParticle, &GeneratedParticle),
        reconstruction: Reconstruction,
    ) -> Self {
        Self::Composite {
            id: id.into(),
            generator,
            daughters: (Box::new(daughters.0.clone()), Box::new(daughters.1.clone())),
            reconstruction,
            species: None,
        }
    }

    /// Returns this particle with its species set, replacing any previous value.
    pub fn with_species(mut self, species: ParticleSpecies) -> Self {
        match &mut self {
            Self::Initial {
                species: particle_species,
                ..
            }
            | Self::Stable {
                species: particle_species,
                ..
            }
            | Self::Composite {
                species: particle_species,
                ..
            } => *particle_species = Some(species),
        }
        self
    }

    /// Returns the particle's identifier.
    pub fn id(&self) -> &str {
        match self {
            Self::Initial { id, .. } | Self::Stable { id, .. } | Self::Composite { id, .. } => id,
        }
    }

    /// Returns the particle's species, if one was assigned.
    pub fn species(&self) -> Option<&ParticleSpecies> {
        match self {
            Self::Initial { species, .. }
            | Self::Stable { species, .. }
            | Self::Composite { species, .. } => species.as_ref(),
        }
    }

    /// Returns how this particle is reconstructed.
    pub fn reconstruction(&self) -> &Reconstruction {
        match self {
            Self::Initial { reconstruction, .. }
            | Self::Stable { reconstruction, .. }
            | Self::Composite { reconstruction, .. } => reconstruction,
        }
    }

    /// Returns the IDs of this particle and all its descendants in depth-first
    /// pre-order: self, then daughter 1's subtree, then daughter 2's subtree.
    fn p4_labels(&self) -> Vec<String> {
        let mut labels = vec![self.id().to_string()];
        if let Self::Composite { daughters, .. } = self {
            labels.append(&mut daughters.0.p4_labels());
            labels.append(&mut daughters.1.p4_labels());
        }
        labels
    }

    /// Appends this particle (and, recursively, its decay tree) to the layout
    /// vectors and returns this particle's product ID.
    ///
    /// A product ID is the particle's index in `particles`, so IDs depend on the
    /// push order below (self, daughter 1 subtree, daughter 2 subtree). A composite
    /// also creates a decay vertex. The vertex's outgoing IDs are filled in only
    /// after both daughter subtrees have been appended.
    fn append_decay_layout(
        &self,
        parent_id: Option<usize>,
        produced_vertex_id: Option<usize>,
        storage: &GeneratedStorage,
        particles: &mut Vec<GeneratedParticleLayout>,
        vertices: &mut Vec<GeneratedVertexLayout>,
    ) -> usize {
        let product_id = particles.len();
        particles.push(GeneratedParticleLayout {
            id: self.id().to_string(),
            product_id,
            parent_id,
            species: self.species().cloned(),
            // Only particles selected by the storage policy get a p4 label.
            p4_label: storage.stores(self.id()).then(|| self.id().to_string()),
            produced_vertex_id,
            decay_vertex_id: None,
        });
        if let Self::Composite { daughters, .. } = self {
            let vertex_id = vertices.len();
            particles[product_id].decay_vertex_id = Some(vertex_id);
            vertices.push(GeneratedVertexLayout {
                vertex_id,
                kind: GeneratedVertexKind::Decay,
                incoming_product_ids: vec![product_id],
                // Placeholder; the daughters' IDs are only known after recursion.
                outgoing_product_ids: Vec::new(),
            });
            let daughter_1_id = daughters.0.append_decay_layout(
                Some(product_id),
                Some(vertex_id),
                storage,
                particles,
                vertices,
            );
            let daughter_2_id = daughters.1.append_decay_layout(
                Some(product_id),
                Some(vertex_id),
                storage,
                particles,
                vertices,
            );
            vertices[vertex_id].outgoing_product_ids = vec![daughter_1_id, daughter_2_id];
        }
        product_id
    }

    /// Samples a mass for this particle.
    ///
    /// Initial particles return their fixed mass without using `rng`. Stable and
    /// composite particles delegate to their generator.
    fn sample_mass(&self, rng: &mut Rng) -> f64 {
        match self {
            Self::Initial { generator, .. } => generator.mass,
            Self::Stable { generator, .. } => generator.sample_mass(rng),
            Self::Composite { generator, .. } => generator.sample_mass(rng),
        }
    }

    /// Builds the reconstruction-side `Particle` selected by `reconstruction()`.
    ///
    /// # Errors
    /// Returns an error if `Reconstruction::Composite` is requested for a
    /// non-composite particle, and propagates errors from daughters or from
    /// `Particle::composite`.
    fn generated_particle(&self) -> LadduResult<Particle> {
        match self.reconstruction() {
            Reconstruction::Stored => Ok(Particle::stored(self.id())),
            Reconstruction::Fixed(p4) => Ok(Particle::fixed(self.id(), *p4)),
            Reconstruction::Missing => Ok(Particle::missing(self.id())),
            Reconstruction::Composite => {
                let Self::Composite { daughters, .. } = self else {
                    return Err(LadduError::Custom(format!(
                        "particle '{}' cannot use composite reconstruction without daughters",
                        self.id()
                    )));
                };
                let daughter_1 = daughters.0.generated_particle()?;
                let daughter_2 = daughters.1.generated_particle()?;
                Particle::composite(self.id(), (&daughter_1, &daughter_2))
            }
        }
    }

    /// Rejects `Reconstruction::Composite` on particles without daughters.
    ///
    /// Daughters are checked only when the parent is itself reconstructed as a
    /// composite. Otherwise `generated_particle` never visits them, so their
    /// settings do not affect the reconstructed reaction.
    fn validate_reconstruction(&self) -> LadduResult<()> {
        match (self, self.reconstruction()) {
            (Self::Composite { daughters, .. }, Reconstruction::Composite) => {
                daughters.0.validate_reconstruction()?;
                daughters.1.validate_reconstruction()?;
                Ok(())
            }
            (Self::Composite { .. }, _) => Ok(()),
            (_, Reconstruction::Composite) => Err(LadduError::Custom(format!(
                "particle '{}' cannot use composite reconstruction without daughters",
                self.id()
            ))),
            _ => Ok(()),
        }
    }

    /// Inserts this particle's ID and all descendant IDs into `seen`.
    ///
    /// # Errors
    /// Returns an error for the first ID already present in `seen`.
    fn collect_ids<'a>(&'a self, seen: &mut HashSet<&'a str>) -> LadduResult<()> {
        if !seen.insert(self.id()) {
            return Err(LadduError::Custom(format!(
                "duplicate generated particle identifier '{}'",
                self.id()
            )));
        }
        if let Self::Composite { daughters, .. } = self {
            daughters.0.collect_ids(seen)?;
            daughters.1.collect_ids(seen)?;
        }
        Ok(())
    }

    /// Records this particle's lab-frame four-momentum and, for composites,
    /// generates its two-body decay and recurses into both daughters.
    ///
    /// `p4_cm` is this particle's four-momentum in the reaction CM frame, and
    /// `cm_to_lab_boost` is the boost from that frame to the lab. The momentum is
    /// pushed only if `p4_storage` already has an entry for this ID.
    ///
    /// The order of calls that may consume `rng` is fixed: m1, m2, cos θ, φ, then
    /// daughter 1's subtree, then daughter 2's. Reordering these statements changes
    /// the events generated for a given seed.
    fn generate_decay(
        &self,
        rng: &mut Rng,
        p4_cm: Vec4,
        cm_to_lab_boost: &Vec3,
        p4_storage: &mut HashMap<String, Vec<Vec4>>,
    ) {
        let p4_lab = p4_cm.boost(cm_to_lab_boost);
        if let Some(storage) = p4_storage.get_mut(self.id()) {
            storage.push(p4_lab);
        }

        let Self::Composite { daughters, .. } = self else {
            return;
        };
        let d1 = &daughters.0;
        let d2 = &daughters.1;
        let parent_mass = p4_cm.m();
        let m1 = d1.sample_mass(rng);
        let m2 = d2.sample_mass(rng);
        // Daughter momentum magnitude in the parent rest frame (real part of the
        // physical-sheet breakup momentum).
        // NOTE(review): there is no check that m1 + m2 <= parent_mass; confirm that
        // q_m's behavior below threshold is acceptable here.
        let q = q_m(parent_mass, m1, m2, Sheet::Physical).re;
        let parent_msq = parent_mass * parent_mass;
        let msq1 = m1 * m1;
        let msq2 = m2 * m2;
        // Two-body decay energies in the parent rest frame.
        let e1 = (parent_msq + msq1 - msq2) / (2.0 * parent_mass);
        let e2 = (parent_msq + msq2 - msq1) / (2.0 * parent_mass);

        // Decay direction: cos θ and φ drawn with `rng.uniform` (isotropic,
        // assuming `uniform` is a flat draw).
        let cos_theta = rng.uniform(-1.0, 1.0);
        let sin_theta = (1.0 - cos_theta * cos_theta).sqrt();
        let phi = rng.uniform(0.0, 2.0 * PI);
        let (sin_phi, cos_phi) = phi.sin_cos();

        // Back-to-back daughters in the parent rest frame, boosted into the CM frame.
        let dir = Vec3::new(sin_theta * cos_phi, sin_theta * sin_phi, cos_theta);
        let p1_p4_rest = (dir * q).with_energy(e1);
        let p2_p4_rest = (-dir * q).with_energy(e2);
        let parent_to_cm_boost = p4_cm.beta();
        let p1_p4_cm = p1_p4_rest.boost(&parent_to_cm_boost);
        let p2_p4_cm = p2_p4_rest.boost(&parent_to_cm_boost);
        d1.generate_decay(rng, p1_p4_cm, cm_to_lab_boost, p4_storage);
        d2.generate_decay(rng, p2_p4_cm, cm_to_lab_boost, p4_storage);
    }
}
533
/// A generated two-to-two reaction `p1 + p2 -> p3 + p4`, with the Mandelstam
/// `t` between `p1` and `p3` drawn from `tdist`.
#[derive(Clone, Debug)]
pub struct GeneratedTwoToTwoReaction {
    /// First incoming particle (validated to be `Initial`).
    p1: GeneratedParticle,
    /// Second incoming particle (validated to be `Initial`).
    p2: GeneratedParticle,
    /// First outgoing particle (`Stable` or `Composite`); `t` is measured against `p1`.
    p3: GeneratedParticle,
    /// Second outgoing particle (`Stable` or `Composite`).
    p4: GeneratedParticle,
    /// Distribution of the Mandelstam variable `t`.
    tdist: MandelstamTDistribution,
    /// Lab direction passed to `rng.p4` for `p1` (set to +z by `new`).
    p1_p3_lab_dir: Vec3,
    /// Lab direction passed to `rng.p4` for `p2` (set to -z by `new`).
    p2_p3_lab_dir: Vec3,
}
545
impl GeneratedTwoToTwoReaction {
    /// Builds a two-to-two reaction `p1 + p2 -> p3 + p4`.
    ///
    /// `p1` is oriented along +z and `p2` along -z in the lab.
    ///
    /// # Errors
    /// Returns an error if `p1`/`p2` are not `Initial`, if `p3`/`p4` are not
    /// `Stable`/`Composite`, if any particle ID in the whole tree is duplicated,
    /// or if reconstruction settings are inconsistent (see
    /// `GeneratedParticle::validate_reconstruction`).
    pub fn new(
        p1: GeneratedParticle,
        p2: GeneratedParticle,
        p3: GeneratedParticle,
        p4: GeneratedParticle,
        tdist: MandelstamTDistribution,
    ) -> LadduResult<Self> {
        validate_initial_role(&p1, "p1")?;
        validate_initial_role(&p2, "p2")?;
        validate_final_role(&p3, "p3")?;
        validate_final_role(&p4, "p4")?;
        let reaction = Self {
            p1,
            p2,
            p3,
            p4,
            tdist,
            p1_p3_lab_dir: Vec3::z(),
            p2_p3_lab_dir: -Vec3::z(),
        };
        reaction.validate()?;
        Ok(reaction)
    }

    /// Checks ID uniqueness across the whole reaction tree and the
    /// reconstruction settings of each top-level particle.
    fn validate(&self) -> LadduResult<()> {
        // One shared set, so duplicates are detected across p1..p4 and their daughters.
        let mut seen = HashSet::new();
        for particle in [&self.p1, &self.p2, &self.p3, &self.p4] {
            particle.collect_ids(&mut seen)?;
            particle.validate_reconstruction()?;
        }
        Ok(())
    }

    /// Returns every particle ID in order: p1, p2, p3 (with its decay tree),
    /// then p4 (with its decay tree).
    fn p4_labels(&self) -> Vec<String> {
        let mut labels = Vec::new();
        for particle in [&self.p1, &self.p2, &self.p3, &self.p4] {
            labels.append(&mut particle.p4_labels());
        }
        labels
    }

    /// Builds the particle and vertex layouts.
    ///
    /// Vertex 0 is the production vertex (incoming p1, p2; outgoing p3, p4). Decay
    /// vertices of composite particles are appended in traversal order. Only
    /// particles selected by `storage` receive a `p4_label`.
    fn layout_components(
        &self,
        storage: &GeneratedStorage,
    ) -> (Vec<GeneratedParticleLayout>, Vec<GeneratedVertexLayout>) {
        let mut particles = Vec::new();
        let mut vertices = vec![GeneratedVertexLayout {
            vertex_id: 0,
            kind: GeneratedVertexKind::Production,
            incoming_product_ids: Vec::new(),
            outgoing_product_ids: Vec::new(),
        }];
        // Incoming particles have no producing vertex; outgoing ones come from vertex 0.
        let p1_id = self
            .p1
            .append_decay_layout(None, None, storage, &mut particles, &mut vertices);
        let p2_id = self
            .p2
            .append_decay_layout(None, None, storage, &mut particles, &mut vertices);
        let p3_id =
            self.p3
                .append_decay_layout(None, Some(0), storage, &mut particles, &mut vertices);
        let p4_id =
            self.p4
                .append_decay_layout(None, Some(0), storage, &mut particles, &mut vertices);
        vertices[0].incoming_product_ids = vec![p1_id, p2_id];
        vertices[0].outgoing_product_ids = vec![p3_id, p4_id];
        (particles, vertices)
    }

    /// Particle layouts with every particle stored.
    fn particle_layouts(&self) -> Vec<GeneratedParticleLayout> {
        self.particle_layouts_with_storage(&GeneratedStorage::All)
    }

    /// Particle layouts whose `p4_label`s follow `storage`.
    fn particle_layouts_with_storage(
        &self,
        storage: &GeneratedStorage,
    ) -> Vec<GeneratedParticleLayout> {
        self.layout_components(storage).0
    }

    /// Vertex layouts. Storage only affects `p4_label`s, so `All` is used here.
    fn vertex_layouts(&self) -> Vec<GeneratedVertexLayout> {
        self.layout_components(&GeneratedStorage::All).1
    }

    /// Builds the reconstructed `Reaction` from each particle's `Reconstruction`.
    fn reconstructed_reaction(&self) -> LadduResult<Reaction> {
        Reaction::two_to_two(
            &self.p1.generated_particle()?,
            &self.p2.generated_particle()?,
            &self.p3.generated_particle()?,
            &self.p4.generated_particle()?,
        )
    }

    /// Generates one event. Lab-frame four-momenta are pushed into `p4_storage`
    /// for every particle ID that already has an entry there.
    ///
    /// The order of calls that may consume `rng` is fixed: p1 energy/p4, p2
    /// energy/p4, t, p3 mass, p4 mass, φ, then p3's decay tree, then p4's.
    /// Reordering these statements changes the events generated for a given seed.
    fn generate_event(&self, rng: &mut Rng, p4_storage: &mut HashMap<String, Vec<Vec4>>) {
        let GeneratedParticle::Initial {
            id: p1_id,
            generator: p1_generator,
            ..
        } = &self.p1
        else {
            unreachable!("validated generated two-to-two p1 role")
        };
        let GeneratedParticle::Initial {
            id: p2_id,
            generator: p2_generator,
            ..
        } = &self.p2
        else {
            unreachable!("validated generated two-to-two p2 role")
        };

        // Incoming beam/target four-momenta in the lab.
        let p1_e = p1_generator.energy_distribution.sample(rng);
        let p1_m = p1_generator.mass;
        let p1_msq = p1_m * p1_m;
        let p1_p4_lab = rng.p4(p1_m, p1_e, self.p1_p3_lab_dir);
        if let Some(storage) = p4_storage.get_mut(p1_id) {
            storage.push(p1_p4_lab)
        }

        let p2_e = p2_generator.energy_distribution.sample(rng);
        let p2_m = p2_generator.mass;
        let p2_msq = p2_m * p2_m;
        let p2_p4_lab = rng.p4(p2_m, p2_e, self.p2_p3_lab_dir);
        if let Some(storage) = p4_storage.get_mut(p2_id) {
            storage.push(p2_p4_lab)
        }

        // `cm_boost` takes lab quantities into the CM frame; `-cm_boost` goes back.
        // `mag2()` is used as s (assumes it is the Minkowski norm squared; confirm).
        let cm = p1_p4_lab + p2_p4_lab;
        let cm_boost = -cm.beta();
        let s = cm.mag2();
        let sqrt_s = s.sqrt();
        // `t` is drawn independently of `s`.
        let t = self.tdist.sample(rng);

        let p1_p4_cm = p1_p4_lab.boost(&cm_boost);
        let p3_m = self.p3.sample_mass(rng);
        let p3_msq = p3_m * p3_m;
        let p4_m = self.p4.sample_mass(rng);
        let p4_msq = p4_m * p4_m;
        // CM momentum magnitudes and energies from two-body kinematics.
        let p_in_mag = q_m(sqrt_s, p1_m, p2_m, Sheet::Physical).re;
        let p_out_mag = q_m(sqrt_s, p3_m, p4_m, Sheet::Physical).re;
        let p1_e_cm = (s + p1_msq - p2_msq) / (2.0 * sqrt_s);
        let p3_e_cm = (s + p3_msq - p4_msq) / (2.0 * sqrt_s);
        let p4_e_cm = (s + p4_msq - p3_msq) / (2.0 * sqrt_s);
        // Solve t = m1² + m3² - 2(E1 E3 - |p_in||p_out| cos θ) for cos θ.
        let costheta =
            (t - p1_msq - p3_msq + 2.0 * p1_e_cm * p3_e_cm) / (2.0 * p_in_mag * p_out_mag);
        // Clamp a t outside the kinematic range onto the boundary.
        // NOTE(review): `clamp` passes NaN through (e.g. when p_in_mag * p_out_mag
        // is 0 below threshold); confirm that case cannot occur.
        let costheta = costheta.clamp(-1.0, 1.0);
        let sintheta = (1.0 - costheta * costheta).sqrt();
        let phi = rng.uniform(0.0, 2.0 * PI);
        let (sin_phi, cos_phi) = phi.sin_cos();
        // Polar angle θ is measured from the p1 direction in the CM frame.
        let (x, y, z) = basis(p1_p4_cm.vec3());
        let p3_dir_cm = x * (sintheta * cos_phi) + y * (sintheta * sin_phi) + z * costheta;

        // p3 and p4 are back-to-back in the CM frame; each recurses into its decays.
        let p3_p4_cm = (p3_dir_cm * p_out_mag).with_energy(p3_e_cm);
        self.p3
            .generate_decay(rng, p3_p4_cm, &-cm_boost, p4_storage);
        let p4_p4_cm = (-p3_dir_cm * p_out_mag).with_energy(p4_e_cm);
        self.p4
            .generate_decay(rng, p4_p4_cm, &-cm_boost, p4_storage);
    }
}
708
709fn validate_initial_role(particle: &GeneratedParticle, role: &str) -> LadduResult<()> {
710 if matches!(particle, GeneratedParticle::Initial { .. }) {
711 Ok(())
712 } else {
713 Err(LadduError::Custom(format!(
714 "generated two-to-two role '{role}' requires an initial particle"
715 )))
716 }
717}
718
719fn validate_final_role(particle: &GeneratedParticle, role: &str) -> LadduResult<()> {
720 if matches!(
721 particle,
722 GeneratedParticle::Stable { .. } | GeneratedParticle::Composite { .. }
723 ) {
724 Ok(())
725 } else {
726 Err(LadduError::Custom(format!(
727 "generated two-to-two role '{role}' requires an outgoing particle"
728 )))
729 }
730}
731
/// The topology of a generated reaction.
#[derive(Clone, Debug)]
pub enum GeneratedReactionTopology {
    /// A two-to-two reaction `p1 + p2 -> p3 + p4`.
    TwoToTwo(GeneratedTwoToTwoReaction),
}
738
739impl GeneratedReactionTopology {
740 fn p4_labels(&self) -> Vec<String> {
741 match self {
742 Self::TwoToTwo(reaction) => reaction.p4_labels(),
743 }
744 }
745
746 fn particle_layouts(&self) -> Vec<GeneratedParticleLayout> {
747 match self {
748 Self::TwoToTwo(reaction) => reaction.particle_layouts(),
749 }
750 }
751
752 fn particle_layouts_with_storage(
753 &self,
754 storage: &GeneratedStorage,
755 ) -> Vec<GeneratedParticleLayout> {
756 match self {
757 Self::TwoToTwo(reaction) => reaction.particle_layouts_with_storage(storage),
758 }
759 }
760
761 fn vertex_layouts(&self) -> Vec<GeneratedVertexLayout> {
762 match self {
763 Self::TwoToTwo(reaction) => reaction.vertex_layouts(),
764 }
765 }
766
767 fn reconstructed_reaction(&self) -> LadduResult<Reaction> {
768 match self {
769 Self::TwoToTwo(reaction) => reaction.reconstructed_reaction(),
770 }
771 }
772
773 fn generate_event(&self, rng: &mut Rng, p4_storage: &mut HashMap<String, Vec<Vec4>>) {
774 match self {
775 Self::TwoToTwo(reaction) => reaction.generate_event(rng, p4_storage),
776 }
777 }
778}
779
/// A validated generated reaction; wraps a `GeneratedReactionTopology`.
#[derive(Clone, Debug)]
pub struct GeneratedReaction {
    /// The reaction's topology and particle tree.
    topology: GeneratedReactionTopology,
}
785
786impl GeneratedReaction {
787 pub fn two_to_two(
789 p1: GeneratedParticle,
790 p2: GeneratedParticle,
791 p3: GeneratedParticle,
792 p4: GeneratedParticle,
793 tdist: MandelstamTDistribution,
794 ) -> LadduResult<Self> {
795 Ok(Self {
796 topology: GeneratedReactionTopology::TwoToTwo(GeneratedTwoToTwoReaction::new(
797 p1, p2, p3, p4, tdist,
798 )?),
799 })
800 }
801
802 pub fn p4_labels(&self) -> Vec<String> {
804 self.topology.p4_labels()
805 }
806
807 pub fn particle_layouts(&self) -> Vec<GeneratedParticleLayout> {
809 self.topology.particle_layouts()
810 }
811
812 pub fn particle_layouts_with_storage(
814 &self,
815 storage: &GeneratedStorage,
816 ) -> Vec<GeneratedParticleLayout> {
817 self.topology.particle_layouts_with_storage(storage)
818 }
819
820 pub fn vertex_layouts(&self) -> Vec<GeneratedVertexLayout> {
822 self.topology.vertex_layouts()
823 }
824
825 pub fn reconstructed_reaction(&self) -> LadduResult<Reaction> {
827 self.topology.reconstructed_reaction()
828 }
829
830 fn generate(
831 &self,
832 rng: &mut Rng,
833 p4_storage: &mut HashMap<String, Vec<Vec4>>,
834 n_events: usize,
835 ) {
836 for _ in 0..n_events {
837 self.topology.generate_event(rng, p4_storage);
838 }
839 }
840}
841
/// One particle node in a generated event layout.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct GeneratedParticleLayout {
    /// The particle's identifier.
    id: String,
    /// The particle's position in the layout's particle list (assigned by
    /// `GeneratedParticle::append_decay_layout`).
    product_id: usize,
    /// Product ID of the decaying parent; `None` for top-level particles.
    parent_id: Option<usize>,
    /// Optional species tag.
    species: Option<ParticleSpecies>,
    /// Four-momentum label (the ID) when the storage policy stores this particle.
    p4_label: Option<String>,
    /// Vertex that produced this particle; `None` for incoming particles.
    produced_vertex_id: Option<usize>,
    /// Decay vertex for composite particles; `None` otherwise.
    decay_vertex_id: Option<usize>,
}
853
854impl GeneratedParticleLayout {
855 pub fn id(&self) -> &str {
857 &self.id
858 }
859
860 pub fn product_id(&self) -> usize {
862 self.product_id
863 }
864
865 pub fn parent_id(&self) -> Option<usize> {
867 self.parent_id
868 }
869
870 pub fn species(&self) -> Option<&ParticleSpecies> {
872 self.species.as_ref()
873 }
874
875 pub fn p4_label(&self) -> Option<&str> {
877 self.p4_label.as_deref()
878 }
879
880 pub fn produced_vertex_id(&self) -> Option<usize> {
882 self.produced_vertex_id
883 }
884
885 pub fn decay_vertex_id(&self) -> Option<usize> {
887 self.decay_vertex_id
888 }
889}
890
/// Kind of vertex in a generated event layout.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum GeneratedVertexKind {
    /// The production vertex, where the incoming particles produce the outgoing ones.
    Production,
    /// A two-body decay of a composite particle.
    Decay,
}
899
/// One vertex in a generated event layout, referencing particles by product ID.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct GeneratedVertexLayout {
    /// The vertex's identifier (its index in the layout's vertex list).
    vertex_id: usize,
    /// Production or decay.
    kind: GeneratedVertexKind,
    /// Product IDs of particles entering the vertex.
    incoming_product_ids: Vec<usize>,
    /// Product IDs of particles leaving the vertex.
    outgoing_product_ids: Vec<usize>,
}
908
909impl GeneratedVertexLayout {
910 pub fn vertex_id(&self) -> usize {
912 self.vertex_id
913 }
914
915 pub fn kind(&self) -> GeneratedVertexKind {
917 self.kind
918 }
919
920 pub fn incoming_product_ids(&self) -> &[usize] {
922 &self.incoming_product_ids
923 }
924
925 pub fn outgoing_product_ids(&self) -> &[usize] {
927 &self.outgoing_product_ids
928 }
929}
930
/// Layout of generated events: column labels plus the particle/vertex graph.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct GeneratedEventLayout {
    /// Four-momentum labels.
    p4_labels: Vec<String>,
    /// Auxiliary-value labels.
    aux_labels: Vec<String>,
    /// Particle nodes.
    particles: Vec<GeneratedParticleLayout>,
    /// Production and decay vertices.
    vertices: Vec<GeneratedVertexLayout>,
}
939
940impl GeneratedEventLayout {
941 pub fn new(
943 p4_labels: Vec<String>,
944 aux_labels: Vec<String>,
945 particles: Vec<GeneratedParticleLayout>,
946 vertices: Vec<GeneratedVertexLayout>,
947 ) -> Self {
948 Self {
949 p4_labels,
950 aux_labels,
951 particles,
952 vertices,
953 }
954 }
955
956 pub fn p4_labels(&self) -> &[String] {
958 &self.p4_labels
959 }
960
961 pub fn aux_labels(&self) -> &[String] {
963 &self.aux_labels
964 }
965
966 pub fn particles(&self) -> &[GeneratedParticleLayout] {
968 &self.particles
969 }
970
971 pub fn particle(&self, id: &str) -> Option<&GeneratedParticleLayout> {
973 self.particles.iter().find(|particle| particle.id() == id)
974 }
975
976 pub fn product(&self, product_id: usize) -> Option<&GeneratedParticleLayout> {
978 self.particles
979 .iter()
980 .find(|particle| particle.product_id() == product_id)
981 }
982
983 pub fn vertices(&self) -> &[GeneratedVertexLayout] {
985 &self.vertices
986 }
987
988 pub fn vertex(&self, vertex_id: usize) -> Option<&GeneratedVertexLayout> {
990 self.vertices
991 .iter()
992 .find(|vertex| vertex.vertex_id() == vertex_id)
993 }
994
995 pub fn production_vertex(&self) -> Option<&GeneratedVertexLayout> {
997 self.vertices
998 .iter()
999 .find(|vertex| vertex.kind() == GeneratedVertexKind::Production)
1000 }
1001
1002 pub fn decay_products(&self, parent_product_id: usize) -> Vec<&GeneratedParticleLayout> {
1004 self.particles
1005 .iter()
1006 .filter(|particle| particle.parent_id() == Some(parent_product_id))
1007 .collect()
1008 }
1009
1010 pub fn production_incoming(&self) -> Vec<&GeneratedParticleLayout> {
1012 self.production_vertex_products(GeneratedVertexLayout::incoming_product_ids)
1013 }
1014
1015 pub fn production_outgoing(&self) -> Vec<&GeneratedParticleLayout> {
1017 self.production_vertex_products(GeneratedVertexLayout::outgoing_product_ids)
1018 }
1019
1020 fn production_vertex_products(
1021 &self,
1022 ids: impl FnOnce(&GeneratedVertexLayout) -> &[usize],
1023 ) -> Vec<&GeneratedParticleLayout> {
1024 self.production_vertex()
1025 .map(|vertex| {
1026 ids(vertex)
1027 .iter()
1028 .filter_map(|product_id| self.product(*product_id))
1029 .collect()
1030 })
1031 .unwrap_or_default()
1032 }
1033}
1034
/// A batch of generated events together with the reaction and layout that describe it.
#[derive(Clone, Debug)]
pub struct GeneratedBatch {
    /// The generated events.
    dataset: Dataset,
    /// The reaction the events were generated from.
    reaction: GeneratedReaction,
    /// Labels and particle/vertex graph for the events.
    layout: GeneratedEventLayout,
}
1042
1043impl GeneratedBatch {
1044 pub fn new(
1046 dataset: Dataset,
1047 reaction: GeneratedReaction,
1048 layout: GeneratedEventLayout,
1049 ) -> Self {
1050 Self {
1051 dataset,
1052 reaction,
1053 layout,
1054 }
1055 }
1056
1057 pub fn dataset(&self) -> &Dataset {
1059 &self.dataset
1060 }
1061
1062 pub fn into_dataset(self) -> Dataset {
1064 self.dataset
1065 }
1066
1067 pub fn reaction(&self) -> &GeneratedReaction {
1069 &self.reaction
1070 }
1071
1072 pub fn layout(&self) -> &GeneratedEventLayout {
1074 &self.layout
1075 }
1076}
1077
/// Top-level event generator: a reaction, auxiliary-value generators, a storage
/// policy and an RNG seed.
#[derive(Clone, Debug)]
pub struct EventGenerator {
    /// The reaction whose events are generated.
    reaction: GeneratedReaction,
    /// Auxiliary distributions keyed by name (presumably the aux labels; TODO confirm).
    aux_generators: HashMap<String, Distribution>,
    /// Which particles have their four-momenta stored.
    storage: GeneratedStorage,
    /// Seed for the generation RNG (used by `RejectionSampler::accepted_batches`
    /// and by pilot envelope estimation).
    seed: u64,
}
1086
/// Iterator producing a fixed total number of generated events in batches.
#[derive(Clone, Debug)]
pub struct GeneratedBatchIter {
    /// Generator used for every batch.
    generator: EventGenerator,
    /// Events still to be produced; reduced as each batch is emitted.
    remaining_events: usize,
    /// Maximum events per batch; the final batch may be smaller.
    batch_size: usize,
    /// RNG shared across batches so successive batches continue one stream.
    rng: Rng,
}
1095
1096impl Iterator for GeneratedBatchIter {
1097 type Item = LadduResult<GeneratedBatch>;
1098
1099 fn next(&mut self) -> Option<Self::Item> {
1100 if self.remaining_events == 0 {
1101 return None;
1102 }
1103 let n_events = self.batch_size.min(self.remaining_events);
1104 self.remaining_events -= n_events;
1105 Some(
1106 self.generator
1107 .generate_batch_with_rng(n_events, &mut self.rng),
1108 )
1109 }
1110}
1111
1112pub trait BatchIntensity {
1114 fn evaluate(&self, batch: &GeneratedBatch) -> LadduResult<Vec<f64>>;
1116
1117 fn evaluate_checked(&self, batch: &GeneratedBatch) -> LadduResult<Vec<f64>> {
1119 let intensities = self.evaluate(batch)?;
1120 if intensities.len() != batch.dataset().n_events() {
1121 return Err(LadduError::Custom(format!(
1122 "intensity length mismatch: expected {}, got {}",
1123 batch.dataset().n_events(),
1124 intensities.len()
1125 )));
1126 }
1127 for (index, weight) in intensities.iter().enumerate() {
1128 if !weight.is_finite() || *weight < 0.0 {
1129 return Err(LadduError::Custom(format!(
1130 "intensity at event {index} must be finite and nonnegative, got {weight}"
1131 )));
1132 }
1133 }
1134 Ok(intensities)
1135 }
1136}
1137
1138impl<F> BatchIntensity for F
1139where
1140 F: Fn(&GeneratedBatch) -> LadduResult<Vec<f64>>,
1141{
1142 fn evaluate(&self, batch: &GeneratedBatch) -> LadduResult<Vec<f64>> {
1143 self(batch)
1144 }
1145}
1146
/// Intensity given by a laddu `Expression` evaluated at fixed parameter values.
#[derive(Clone, Debug)]
pub struct ExpressionIntensity {
    /// Expression loaded against each batch's dataset.
    expression: Expression,
    /// Parameter values passed to the evaluator.
    parameters: Vec<f64>,
}
1153
1154impl ExpressionIntensity {
1155 pub fn new(expression: Expression, parameters: Vec<f64>) -> Self {
1157 Self {
1158 expression,
1159 parameters,
1160 }
1161 }
1162}
1163
1164impl BatchIntensity for ExpressionIntensity {
1165 fn evaluate(&self, batch: &GeneratedBatch) -> LadduResult<Vec<f64>> {
1166 let dataset = Arc::new(batch.dataset().clone());
1167 let evaluator = self.expression.load(&dataset)?;
1168 evaluator
1169 .evaluate(&self.parameters)?
1170 .into_iter()
1171 .enumerate()
1172 .map(|(index, value)| {
1173 if !value.im.is_finite() || value.im.abs() > f64::EPSILON {
1174 return Err(LadduError::Custom(format!(
1175 "expression intensity at event {index} must be real-valued, got {value}"
1176 )));
1177 }
1178 Ok(value.re)
1179 })
1180 .collect()
1181 }
1182}
1183
/// How the rejection-sampling envelope (maximum weight) is obtained.
#[derive(Clone, Debug)]
pub enum RejectionEnvelope {
    /// Use a caller-supplied maximum weight.
    Fixed {
        /// Envelope maximum; must be finite and positive.
        max_weight: f64,
    },
    /// Estimate the maximum from a pilot run (see `estimate_rejection_envelope`).
    Pilot {
        /// Number of pilot events to generate; must be nonzero.
        pilot_events: usize,
        /// Pilot batch size; defaults to `generation_batch_size` when `None`.
        batch_size: Option<usize>,
        /// Multiplier applied to the largest observed weight; finite and positive.
        safety_factor: f64,
    },
}
1202
/// Configuration for `RejectionSampler`.
#[derive(Clone, Debug)]
pub struct RejectionSamplingOptions {
    /// Requested number of accepted events (consumed by code outside this view;
    /// TODO confirm exact semantics).
    pub target_accepted: usize,
    /// Events generated per source batch; must be nonzero.
    pub generation_batch_size: usize,
    /// Size of emitted output batches (presumably accepted events per batch);
    /// must be nonzero.
    pub output_batch_size: usize,
    /// How the envelope maximum weight is obtained.
    pub envelope: RejectionEnvelope,
    /// Seed for the rejection RNG stream (separate from the generation seed).
    pub seed: u64,
}
1217
/// Running statistics collected while rejection sampling.
#[derive(Clone, Debug, Default)]
pub struct RejectionSamplingDiagnostics {
    /// Events generated so far (incremented per loaded source batch).
    pub generated_events: usize,
    /// Events accepted so far.
    pub accepted_events: usize,
    /// Events rejected so far.
    pub rejected_events: usize,
    /// Largest intensity seen so far (update logic is outside this view).
    pub max_observed_weight: f64,
    /// Envelope maximum in use, copied from the sampler.
    pub envelope_max_weight: f64,
    /// Count of envelope violations (presumably intensities above
    /// `envelope_max_weight`; TODO confirm).
    pub envelope_violations: usize,
}
1234
1235impl RejectionSamplingDiagnostics {
1236 pub fn acceptance_efficiency(&self) -> f64 {
1238 if self.generated_events == 0 {
1239 0.0
1240 } else {
1241 self.accepted_events as f64 / self.generated_events as f64
1242 }
1243 }
1244}
1245
/// Rejection sampler that filters generated events by an intensity.
#[derive(Clone, Debug)]
pub struct RejectionSampler<I> {
    /// Source of candidate events.
    generator: EventGenerator,
    /// Intensity evaluated on each candidate batch.
    intensity: I,
    /// Envelope maximum weight (validated finite and positive in `new`).
    max_weight: f64,
    /// Sampling configuration.
    options: RejectionSamplingOptions,
}
1254
1255impl<I> RejectionSampler<I>
1256where
1257 I: BatchIntensity,
1258{
1259 pub fn new(
1261 generator: EventGenerator,
1262 intensity: I,
1263 options: RejectionSamplingOptions,
1264 ) -> LadduResult<Self> {
1265 if options.generation_batch_size == 0 {
1266 return Err(LadduError::Custom(
1267 "generation_batch_size must be greater than zero".to_string(),
1268 ));
1269 }
1270 if options.output_batch_size == 0 {
1271 return Err(LadduError::Custom(
1272 "output_batch_size must be greater than zero".to_string(),
1273 ));
1274 }
1275 let max_weight = match options.envelope {
1276 RejectionEnvelope::Fixed { max_weight } => max_weight,
1277 RejectionEnvelope::Pilot {
1278 pilot_events,
1279 batch_size,
1280 safety_factor,
1281 } => estimate_rejection_envelope(
1282 &generator,
1283 &intensity,
1284 pilot_events,
1285 batch_size.unwrap_or(options.generation_batch_size),
1286 safety_factor,
1287 )?,
1288 };
1289 if !max_weight.is_finite() || max_weight <= 0.0 {
1290 return Err(LadduError::Custom(
1291 "rejection envelope max_weight must be finite and positive".to_string(),
1292 ));
1293 }
1294 Ok(Self {
1295 generator,
1296 intensity,
1297 max_weight,
1298 options,
1299 })
1300 }
1301
1302 pub fn accepted_batches(self) -> RejectionSampleIter<I> {
1304 RejectionSampleIter {
1305 generation_rng: Rng::with_seed(self.generator.seed),
1306 rejection_rng: Rng::with_seed(self.options.seed),
1307 diagnostics: RejectionSamplingDiagnostics {
1308 envelope_max_weight: self.max_weight,
1309 ..Default::default()
1310 },
1311 sampler: self,
1312 current_batch: None,
1313 current_intensities: Vec::new(),
1314 current_index: 0,
1315 }
1316 }
1317}
1318
1319fn estimate_rejection_envelope<I>(
1320 generator: &EventGenerator,
1321 intensity: &I,
1322 pilot_events: usize,
1323 batch_size: usize,
1324 safety_factor: f64,
1325) -> LadduResult<f64>
1326where
1327 I: BatchIntensity,
1328{
1329 if pilot_events == 0 {
1330 return Err(LadduError::Custom(
1331 "pilot_events must be greater than zero".to_string(),
1332 ));
1333 }
1334 if batch_size == 0 {
1335 return Err(LadduError::Custom(
1336 "pilot batch_size must be greater than zero".to_string(),
1337 ));
1338 }
1339 if !safety_factor.is_finite() || safety_factor <= 0.0 {
1340 return Err(LadduError::Custom(
1341 "pilot safety_factor must be finite and positive".to_string(),
1342 ));
1343 }
1344
1345 let mut max_observed = 0.0_f64;
1346 let mut rng = Rng::with_seed(generator.seed);
1347 let mut remaining = pilot_events;
1348 while remaining > 0 {
1349 let n_events = remaining.min(batch_size);
1350 let batch = generator.generate_batch_with_rng(n_events, &mut rng)?;
1351 let weights = intensity.evaluate_checked(&batch)?;
1352 for weight in weights {
1353 max_observed = max_observed.max(weight);
1354 }
1355 remaining -= n_events;
1356 }
1357
1358 let max_weight = max_observed * safety_factor;
1359 if !max_weight.is_finite() || max_weight <= 0.0 {
1360 return Err(LadduError::Custom(format!(
1361 "pilot envelope produced invalid max_weight {max_weight}; observed maximum was {max_observed}"
1362 )));
1363 }
1364 Ok(max_weight)
1365}
1366
/// Iterator over batches of events accepted by a [`RejectionSampler`].
///
/// Created by `RejectionSampler::accepted_batches`. Each item holds at most
/// `output_batch_size` accepted events. Iteration stops once `target_accepted` events have
/// been accepted.
#[derive(Clone, Debug)]
pub struct RejectionSampleIter<I> {
    /// The sampler being drained (generator, intensity, envelope, and options).
    sampler: RejectionSampler<I>,
    /// RNG used to generate source events; seeded with the generator's seed.
    generation_rng: Rng,
    /// RNG used for accept/reject decisions; seeded with the sampling options' seed.
    rejection_rng: Rng,
    /// Running counters exposed through `diagnostics()`.
    diagnostics: RejectionSamplingDiagnostics,
    /// Source batch currently being consumed, or `None` before the first one is generated.
    current_batch: Option<GeneratedBatch>,
    /// Intensity of each event in `current_batch`, in event order.
    current_intensities: Vec<f64>,
    /// Index of the next unprocessed event in `current_batch`.
    current_index: usize,
}
1378
1379impl<I> RejectionSampleIter<I> {
1380 pub fn diagnostics(&self) -> &RejectionSamplingDiagnostics {
1382 &self.diagnostics
1383 }
1384}
1385
1386impl<I> RejectionSampleIter<I>
1387where
1388 I: BatchIntensity,
1389{
1390 fn load_next_source_batch(&mut self) -> LadduResult<()> {
1391 let batch = self.sampler.generator.generate_batch_with_rng(
1392 self.sampler.options.generation_batch_size,
1393 &mut self.generation_rng,
1394 )?;
1395 let intensities = self.sampler.intensity.evaluate_checked(&batch)?;
1396 self.diagnostics.generated_events += batch.dataset().n_events();
1397 self.current_batch = Some(batch);
1398 self.current_intensities = intensities;
1399 self.current_index = 0;
1400 Ok(())
1401 }
1402
1403 fn empty_output_batch(source: &GeneratedBatch) -> GeneratedBatch {
1404 GeneratedBatch::new(
1405 Dataset::empty_local(source.dataset().metadata().clone()),
1406 source.reaction().clone(),
1407 source.layout().clone(),
1408 )
1409 }
1410}
1411
1412impl<I> Iterator for RejectionSampleIter<I>
1413where
1414 I: BatchIntensity,
1415{
1416 type Item = LadduResult<GeneratedBatch>;
1417
1418 fn next(&mut self) -> Option<Self::Item> {
1419 if self.diagnostics.accepted_events >= self.sampler.options.target_accepted {
1420 return None;
1421 }
1422
1423 let mut output: Option<GeneratedBatch> = None;
1424 while self.diagnostics.accepted_events < self.sampler.options.target_accepted {
1425 let needs_batch = self
1426 .current_batch
1427 .as_ref()
1428 .map(|batch| self.current_index >= batch.dataset().n_events())
1429 .unwrap_or(true);
1430 if needs_batch {
1431 if let Err(err) = self.load_next_source_batch() {
1432 return Some(Err(err));
1433 }
1434 }
1435
1436 let source = self
1437 .current_batch
1438 .as_ref()
1439 .expect("source batch should be loaded");
1440 if output.is_none() {
1441 output = Some(Self::empty_output_batch(source));
1442 }
1443
1444 let weight = self.current_intensities[self.current_index];
1445 self.diagnostics.max_observed_weight = self.diagnostics.max_observed_weight.max(weight);
1446 let envelope_max = self.sampler.max_weight;
1447 if weight > envelope_max {
1448 self.diagnostics.envelope_violations += 1;
1449 return Some(Err(LadduError::Custom(format!(
1450 "rejection envelope violation: observed weight {weight} exceeds max_weight {envelope_max}"
1451 ))));
1452 }
1453
1454 let accepted = self.rejection_rng.f64() * envelope_max < weight;
1455 if accepted {
1456 let event = match source.dataset().event_global(self.current_index) {
1457 Ok(event) => event,
1458 Err(err) => return Some(Err(err)),
1459 };
1460 if let Err(err) = output.as_mut().unwrap().dataset.push_event_local(
1461 event.p4s.clone(),
1462 event.aux.clone(),
1463 event.weight,
1464 ) {
1465 return Some(Err(err));
1466 }
1467 self.diagnostics.accepted_events += 1;
1468 } else {
1469 self.diagnostics.rejected_events += 1;
1470 }
1471 self.current_index += 1;
1472
1473 if output.as_ref().unwrap().dataset().n_events()
1474 >= self.sampler.options.output_batch_size
1475 || self.diagnostics.accepted_events >= self.sampler.options.target_accepted
1476 {
1477 break;
1478 }
1479 }
1480
1481 output
1482 .filter(|batch| batch.dataset().n_events() > 0)
1483 .map(Ok)
1484 }
1485}
1486
1487impl EventGenerator {
1488 pub fn new(
1490 reaction: GeneratedReaction,
1491 aux_generators: HashMap<String, Distribution>,
1492 seed: Option<u64>,
1493 ) -> Self {
1494 Self {
1495 reaction,
1496 aux_generators,
1497 storage: GeneratedStorage::All,
1498 seed: seed.unwrap_or_else(|| fastrand::u64(..)),
1499 }
1500 }
1501
1502 pub fn storage(&self) -> &GeneratedStorage {
1504 &self.storage
1505 }
1506
1507 pub fn with_storage(mut self, storage: GeneratedStorage) -> LadduResult<Self> {
1509 storage.validate(&self.reaction.p4_labels())?;
1510 self.storage = storage;
1511 Ok(self)
1512 }
1513
1514 fn aux_entries(&self) -> Vec<(&String, &Distribution)> {
1515 let mut aux_entries = self.aux_generators.iter().collect::<Vec<_>>();
1516 aux_entries.sort_by_key(|(label, _)| *label);
1517 aux_entries
1518 }
1519
1520 fn generate_batch_with_rng(
1521 &self,
1522 n_events: usize,
1523 rng: &mut Rng,
1524 ) -> LadduResult<GeneratedBatch> {
1525 let all_p4_labels = self.reaction.p4_labels();
1526 self.storage.validate(&all_p4_labels)?;
1527 let p4_labels = self.storage.stored_labels(&all_p4_labels);
1528 let aux_entries = self.aux_entries();
1529 let aux_labels = aux_entries
1530 .iter()
1531 .map(|(label, _)| (*label).clone())
1532 .collect::<Vec<_>>();
1533 let mut p4_data: HashMap<String, Vec<Vec4>> = p4_labels
1534 .iter()
1535 .map(|label| (label.clone(), Vec::with_capacity(n_events)))
1536 .collect();
1537 let metadata = DatasetMetadata::new(p4_labels.clone(), aux_labels.clone())?;
1538 let mut aux: Vec<Vec<f64>> = aux_entries
1539 .iter()
1540 .map(|_| Vec::with_capacity(n_events))
1541 .collect();
1542 let weights = vec![1.0; n_events];
1543 for _ in 0..n_events {
1544 for ((_, distribution), column) in aux_entries.iter().zip(aux.iter_mut()) {
1545 column.push(distribution.sample(rng));
1546 }
1547 self.reaction.generate(rng, &mut p4_data, 1);
1548 }
1549 let p4 = p4_labels
1550 .iter()
1551 .filter_map(|label| p4_data.remove(label))
1552 .collect();
1553 let dataset = Dataset::from_columns_local(metadata, p4, aux, weights)?;
1554 Ok(GeneratedBatch::new(
1555 dataset,
1556 self.reaction.clone(),
1557 GeneratedEventLayout::new(
1558 p4_labels,
1559 aux_labels,
1560 self.reaction.particle_layouts_with_storage(&self.storage),
1561 self.reaction.vertex_layouts(),
1562 ),
1563 ))
1564 }
1565
1566 pub fn generate_batch(&self, n_events: usize) -> LadduResult<GeneratedBatch> {
1568 let mut rng = Rng::with_seed(self.seed);
1569 self.generate_batch_with_rng(n_events, &mut rng)
1570 }
1571
1572 pub fn generate_batches(
1577 &self,
1578 total_events: usize,
1579 batch_size: usize,
1580 ) -> LadduResult<GeneratedBatchIter> {
1581 if batch_size == 0 {
1582 return Err(LadduError::Custom(
1583 "batch_size must be greater than zero".to_string(),
1584 ));
1585 }
1586 Ok(GeneratedBatchIter {
1587 generator: self.clone(),
1588 remaining_events: total_events,
1589 batch_size,
1590 rng: Rng::with_seed(self.seed),
1591 })
1592 }
1593
1594 pub fn generate_dataset(&self, n_events: usize) -> LadduResult<Dataset> {
1596 Ok(self.generate_batch(n_events)?.into_dataset())
1597 }
1598
1599 pub fn rejection_sampler_with_expression(
1601 &self,
1602 expression: Expression,
1603 parameters: Vec<f64>,
1604 options: RejectionSamplingOptions,
1605 ) -> LadduResult<RejectionSampler<ExpressionIntensity>> {
1606 RejectionSampler::new(
1607 self.clone(),
1608 ExpressionIntensity::new(expression, parameters),
1609 options,
1610 )
1611 }
1612}
1613
#[cfg(test)]
mod tests {
    // Unit tests for event generation, storage projection, species metadata, and rejection
    // sampling.
    use approx::assert_relative_eq;
    use laddu_core::{traits::Variable, Channel, Expression, Frame};

    use super::*;

    // Builds the two-to-two reaction `beam + target -> kk + recoil` shared by most tests.
    // `kk` decays to `kshort1 + kshort2`. The target is `Reconstruction::Missing`, `kk` is
    // `Reconstruction::Composite`, and t follows an exponential distribution with slope 0.1.
    fn demo_reaction() -> GeneratedReaction {
        let beam = GeneratedParticle::initial(
            "beam",
            InitialGenerator::beam_with_fixed_energy(0.0, 8.0),
            Reconstruction::Stored,
        );
        let target = GeneratedParticle::initial(
            "target",
            InitialGenerator::target(0.938272),
            Reconstruction::Missing,
        );
        let ks1 = GeneratedParticle::stable(
            "kshort1",
            StableGenerator::new(0.497611),
            Reconstruction::Stored,
        );
        let ks2 = GeneratedParticle::stable(
            "kshort2",
            StableGenerator::new(0.497611),
            Reconstruction::Stored,
        );
        let kk = GeneratedParticle::composite(
            "kk",
            CompositeGenerator::new(1.1, 1.6),
            (&ks1, &ks2),
            Reconstruction::Composite,
        );
        let recoil = GeneratedParticle::stable(
            "recoil",
            StableGenerator::new(0.938272),
            Reconstruction::Stored,
        );
        let tdist = MandelstamTDistribution::Exponential { slope: 0.1 };
        GeneratedReaction::two_to_two(beam, target, kk, recoil, tdist).unwrap()
    }

    // Every particle gets a column. Each event conserves four-momentum at both vertices and
    // keeps the stable particles on their mass shells.
    #[test]
    fn test_generation() {
        let reaction = demo_reaction();
        let generator = EventGenerator::new(reaction, HashMap::new(), Some(12345));
        let n_events = 1_000;
        let dataset = generator.generate_dataset(n_events).unwrap();
        assert_eq!(dataset.n_events(), n_events);
        let metadata = dataset.metadata();
        assert!(metadata.p4_index("beam").is_some());
        assert!(metadata.p4_index("target").is_some());
        assert!(metadata.p4_index("kk").is_some());
        assert!(metadata.p4_index("kshort1").is_some());
        assert!(metadata.p4_index("kshort2").is_some());
        assert!(metadata.p4_index("recoil").is_some());

        for event in dataset.events_global() {
            let beam_p4 = event.p4("beam").unwrap();
            let target_p4 = event.p4("target").unwrap();
            let kk_p4 = event.p4("kk").unwrap();
            let kshort1_p4 = event.p4("kshort1").unwrap();
            let kshort2_p4 = event.p4("kshort2").unwrap();
            let recoil_p4 = event.p4("recoil").unwrap();

            assert!(beam_p4.e().is_finite());
            assert!(target_p4.e().is_finite());
            assert!(kk_p4.e().is_finite());
            assert!(kshort1_p4.e().is_finite());
            assert!(kshort2_p4.e().is_finite());
            assert!(recoil_p4.e().is_finite());

            assert_relative_eq!(kk_p4, kshort1_p4 + kshort2_p4, epsilon = 1e-10);
            assert_relative_eq!(beam_p4 + target_p4, kk_p4 + recoil_p4, epsilon = 1e-10);
            assert_relative_eq!(kshort1_p4.m(), 0.497611, epsilon = 1e-10);
            assert_relative_eq!(kshort2_p4.m(), 0.497611, epsilon = 1e-10);
            assert_relative_eq!(recoil_p4.m(), 0.938272, epsilon = 1e-10);
        }
    }

    // The reconstructed reaction evaluates mass, decay-angle, and Mandelstam variables on a
    // generated dataset.
    #[test]
    fn test_reconstructed_reaction() {
        let generated = demo_reaction();
        let reaction = generated.reconstructed_reaction().unwrap();
        let dataset = EventGenerator::new(generated, HashMap::new(), Some(12345))
            .generate_dataset(4)
            .unwrap();
        let mass = reaction.mass("kk").value_on(&dataset).unwrap();
        let angles = reaction
            .decay("kk")
            .unwrap()
            .angles("kshort1", Frame::Helicity)
            .unwrap();
        let mandelstam = reaction
            .mandelstam(Channel::S)
            .unwrap()
            .value_on(&dataset)
            .unwrap();

        assert_eq!(mass.len(), 4);
        assert_eq!(
            angles.costheta.to_string(),
            "CosTheta(parent=kk, daughter=kshort1, frame=Helicity)"
        );
        assert_eq!(mandelstam.len(), 4);
    }

    // Checks a generated batch's column labels, particle and vertex layout, and the layout
    // lookup helpers.
    #[test]
    fn test_generated_batch_metadata() {
        let generated = demo_reaction();
        let generator = EventGenerator::new(
            generated,
            HashMap::from([("pol_angle".to_string(), Distribution::Fixed(0.25))]),
            Some(12345),
        );
        let batch = generator.generate_batch(4).unwrap();

        assert_eq!(batch.dataset().n_events(), 4);
        assert_eq!(
            batch.layout().p4_labels(),
            &["beam", "target", "kk", "kshort1", "kshort2", "recoil"]
        );
        assert_eq!(batch.layout().aux_labels(), &["pol_angle"]);
        assert_eq!(
            batch.reaction().p4_labels(),
            vec!["beam", "target", "kk", "kshort1", "kshort2", "recoil"]
        );
        assert_eq!(batch.dataset().p4_names(), batch.layout().p4_labels());
        assert_eq!(batch.dataset().aux_names(), batch.layout().aux_labels());
        let particles = batch.layout().particles();
        assert_eq!(particles.len(), 6);
        assert_eq!(particles[0].id(), "beam");
        assert_eq!(particles[0].product_id(), 0);
        assert_eq!(particles[0].parent_id(), None);
        assert_eq!(particles[0].produced_vertex_id(), None);
        assert_eq!(particles[0].decay_vertex_id(), None);
        assert_eq!(particles[1].id(), "target");
        assert_eq!(particles[1].parent_id(), None);
        assert_eq!(particles[1].produced_vertex_id(), None);
        assert_eq!(particles[1].decay_vertex_id(), None);
        assert_eq!(particles[2].id(), "kk");
        assert_eq!(particles[2].product_id(), 2);
        assert_eq!(particles[2].parent_id(), None);
        assert_eq!(particles[2].produced_vertex_id(), Some(0));
        assert_eq!(particles[2].decay_vertex_id(), Some(1));
        assert_eq!(particles[3].id(), "kshort1");
        assert_eq!(particles[3].parent_id(), Some(2));
        assert_eq!(particles[3].produced_vertex_id(), Some(1));
        assert_eq!(particles[3].decay_vertex_id(), None);
        assert_eq!(particles[4].id(), "kshort2");
        assert_eq!(particles[4].parent_id(), Some(2));
        assert_eq!(particles[4].produced_vertex_id(), Some(1));
        assert_eq!(particles[4].decay_vertex_id(), None);
        assert_eq!(particles[5].id(), "recoil");
        assert_eq!(particles[5].parent_id(), None);
        assert_eq!(particles[5].produced_vertex_id(), Some(0));
        assert_eq!(particles[5].decay_vertex_id(), None);
        for particle in particles {
            assert_eq!(particle.p4_label(), Some(particle.id()));
        }
        let vertices = batch.layout().vertices();
        assert_eq!(vertices.len(), 2);
        assert_eq!(vertices[0].vertex_id(), 0);
        assert_eq!(vertices[0].kind(), GeneratedVertexKind::Production);
        assert_eq!(vertices[0].incoming_product_ids(), &[0, 1]);
        assert_eq!(vertices[0].outgoing_product_ids(), &[2, 5]);
        assert_eq!(vertices[1].vertex_id(), 1);
        assert_eq!(vertices[1].kind(), GeneratedVertexKind::Decay);
        assert_eq!(vertices[1].incoming_product_ids(), &[2]);
        assert_eq!(vertices[1].outgoing_product_ids(), &[3, 4]);

        assert_eq!(batch.layout().particle("kk"), Some(&particles[2]));
        assert_eq!(batch.layout().particle("missing_id"), None);
        assert_eq!(batch.layout().product(5), Some(&particles[5]));
        assert_eq!(batch.layout().product(6), None);
        assert_eq!(batch.layout().vertex(1), Some(&vertices[1]));
        assert_eq!(batch.layout().vertex(2), None);
        assert_eq!(batch.layout().production_vertex(), Some(&vertices[0]));
        assert_eq!(
            batch
                .layout()
                .production_incoming()
                .iter()
                .map(|particle| particle.id())
                .collect::<Vec<_>>(),
            vec!["beam", "target"]
        );
        assert_eq!(
            batch
                .layout()
                .production_outgoing()
                .iter()
                .map(|particle| particle.id())
                .collect::<Vec<_>>(),
            vec!["kk", "recoil"]
        );
        assert_eq!(
            batch
                .layout()
                .decay_products(2)
                .iter()
                .map(|particle| particle.id())
                .collect::<Vec<_>>(),
            vec!["kshort1", "kshort2"]
        );
        assert!(batch.layout().decay_products(5).is_empty());
    }

    // Leaving `kk` out of storage removes its dataset column. It stays in the reaction and in
    // the particle layout, with no p4 label. Four-momentum is still conserved over the stored
    // particles.
    #[test]
    fn generated_storage_only_projects_dataset_columns() {
        let generated = demo_reaction();
        let generator = EventGenerator::new(generated, HashMap::new(), Some(12345))
            .with_storage(GeneratedStorage::only([
                "beam", "target", "kshort1", "kshort2", "recoil",
            ]))
            .unwrap();
        let batch = generator.generate_batch(4).unwrap();

        assert_eq!(
            batch.reaction().p4_labels(),
            vec!["beam", "target", "kk", "kshort1", "kshort2", "recoil"]
        );
        assert_eq!(
            batch.layout().p4_labels(),
            &["beam", "target", "kshort1", "kshort2", "recoil"]
        );
        assert_eq!(batch.dataset().p4_names(), batch.layout().p4_labels());
        assert!(batch.dataset().metadata().p4_index("kk").is_none());

        let particles = batch.layout().particles();
        assert_eq!(particles.len(), 6);
        assert_eq!(particles[2].id(), "kk");
        assert_eq!(particles[2].p4_label(), None);
        assert_eq!(particles[3].p4_label(), Some("kshort1"));
        assert_eq!(particles[4].p4_label(), Some("kshort2"));

        for index in 0..batch.dataset().n_events() {
            let event = batch.dataset().event_global(index).unwrap();
            assert_relative_eq!(
                event.p4("beam").unwrap() + event.p4("target").unwrap(),
                event.p4("kshort1").unwrap()
                    + event.p4("kshort2").unwrap()
                    + event.p4("recoil").unwrap(),
                epsilon = 1e-10
            );
        }
    }

    // A storage selection that names an unknown ID, or repeats an ID, is rejected.
    #[test]
    fn generated_storage_rejects_unknown_and_duplicate_ids() {
        assert!(
            EventGenerator::new(demo_reaction(), HashMap::new(), Some(12345))
                .with_storage(GeneratedStorage::only(["beam", "does_not_exist"]))
                .is_err()
        );
        assert!(
            EventGenerator::new(demo_reaction(), HashMap::new(), Some(12345))
                .with_storage(GeneratedStorage::only(["beam", "beam"]))
                .is_err()
        );
    }

    // The species attached to each generated particle appears on the matching particle layout.
    #[test]
    fn generated_species_metadata_propagates_to_layout() {
        let beam = GeneratedParticle::initial(
            "beam",
            InitialGenerator::beam_with_fixed_energy(0.0, 8.0),
            Reconstruction::Stored,
        )
        .with_species(ParticleSpecies::code(22));
        let target = GeneratedParticle::initial(
            "target",
            InitialGenerator::target(0.938272),
            Reconstruction::Missing,
        )
        .with_species(ParticleSpecies::with_namespace("pdg", 2212));
        let kshort1 = GeneratedParticle::stable(
            "kshort1",
            StableGenerator::new(0.497611),
            Reconstruction::Stored,
        )
        .with_species(ParticleSpecies::label("KShort"));
        let kshort2 = GeneratedParticle::stable(
            "kshort2",
            StableGenerator::new(0.497611),
            Reconstruction::Stored,
        )
        .with_species(ParticleSpecies::label("KShort"));
        let kk = GeneratedParticle::composite(
            "kk",
            CompositeGenerator::new(1.1, 1.6),
            (&kshort1, &kshort2),
            Reconstruction::Composite,
        )
        .with_species(ParticleSpecies::label("KK"));
        let recoil = GeneratedParticle::stable(
            "recoil",
            StableGenerator::new(0.938272),
            Reconstruction::Stored,
        )
        .with_species(ParticleSpecies::code(2212));
        let reaction = GeneratedReaction::two_to_two(
            beam,
            target,
            kk,
            recoil,
            MandelstamTDistribution::Exponential { slope: 0.1 },
        )
        .unwrap();
        let particles = reaction.particle_layouts();

        assert_eq!(particles[0].species(), Some(&ParticleSpecies::code(22)));
        assert_eq!(
            particles[1].species(),
            Some(&ParticleSpecies::with_namespace("pdg", 2212))
        );
        assert_eq!(particles[2].species(), Some(&ParticleSpecies::label("KK")));
        assert_eq!(
            particles[3].species(),
            Some(&ParticleSpecies::label("KShort"))
        );
        assert_eq!(
            particles[4].species(),
            Some(&ParticleSpecies::label("KShort"))
        );
        assert_eq!(particles[5].species(), Some(&ParticleSpecies::code(2212)));
    }

    // Batched generation reproduces one-shot generation event by event, including aux columns
    // and weights, with a short final batch. A zero batch size is rejected.
    #[test]
    fn generated_batches_match_one_shot_generation() {
        let generated = demo_reaction();
        let generator = EventGenerator::new(
            generated,
            HashMap::from([(
                "pol_angle".to_string(),
                Distribution::Uniform { min: 0.0, max: 1.0 },
            )]),
            Some(12345),
        );
        let one_shot = generator.generate_dataset(7).unwrap();
        let batches = generator
            .generate_batches(7, 3)
            .unwrap()
            .collect::<LadduResult<Vec<_>>>()
            .unwrap();
        let batch_sizes = batches
            .iter()
            .map(|batch| batch.dataset().n_events())
            .collect::<Vec<_>>();
        assert_eq!(batch_sizes, vec![3, 3, 1]);

        let mut offset = 0;
        for batch in batches {
            for local_index in 0..batch.dataset().n_events() {
                let expected = one_shot.event_global(offset + local_index).unwrap();
                let actual = batch.dataset().event_global(local_index).unwrap();
                for name in one_shot.p4_names() {
                    assert_relative_eq!(
                        actual.p4(name).unwrap(),
                        expected.p4(name).unwrap(),
                        epsilon = 1e-10
                    );
                }
                for aux_index in 0..one_shot.aux_names().len() {
                    assert_relative_eq!(actual.aux[aux_index], expected.aux[aux_index]);
                }
                assert_relative_eq!(actual.weight(), expected.weight());
            }
            offset += batch.dataset().n_events();
        }
        assert_eq!(offset, one_shot.n_events());
        assert!(generator.generate_batches(1, 0).is_err());
    }

    // With a constant intensity equal to the envelope, every event is accepted. The 5 target
    // events therefore need two source batches of 4 and are emitted as batches of 2, 2, and 1.
    #[test]
    fn fixed_envelope_rejection_sampler_streams_accepted_batches() {
        let generator = EventGenerator::new(demo_reaction(), HashMap::new(), Some(12345));
        let sampler = RejectionSampler::new(
            generator,
            |batch: &GeneratedBatch| Ok(vec![1.0; batch.dataset().n_events()]),
            RejectionSamplingOptions {
                target_accepted: 5,
                generation_batch_size: 4,
                output_batch_size: 2,
                envelope: RejectionEnvelope::Fixed { max_weight: 1.0 },
                seed: 67890,
            },
        )
        .unwrap();

        let mut iter = sampler.accepted_batches();
        let mut accepted_batches = Vec::new();
        for batch in iter.by_ref() {
            accepted_batches.push(batch.unwrap());
        }
        assert_eq!(
            accepted_batches
                .iter()
                .map(|batch| batch.dataset().n_events())
                .collect::<Vec<_>>(),
            vec![2, 2, 1]
        );
        assert_eq!(iter.diagnostics().generated_events, 8);
        assert_eq!(iter.diagnostics().accepted_events, 5);
        assert_eq!(iter.diagnostics().rejected_events, 0);
        assert_relative_eq!(iter.diagnostics().acceptance_efficiency(), 5.0 / 8.0);
        for batch in accepted_batches {
            assert_eq!(
                batch.layout().p4_labels(),
                &["beam", "target", "kk", "kshort1", "kshort2", "recoil"]
            );
        }
    }

    // An intensity above the fixed envelope yields an error and is recorded in the diagnostics.
    #[test]
    fn fixed_envelope_rejection_sampler_rejects_violations() {
        let generator = EventGenerator::new(demo_reaction(), HashMap::new(), Some(12345));
        let sampler = RejectionSampler::new(
            generator,
            |batch: &GeneratedBatch| Ok(vec![2.0; batch.dataset().n_events()]),
            RejectionSamplingOptions {
                target_accepted: 1,
                generation_batch_size: 1,
                output_batch_size: 1,
                envelope: RejectionEnvelope::Fixed { max_weight: 1.0 },
                seed: 67890,
            },
        )
        .unwrap();

        let mut iter = sampler.accepted_batches();
        let err = iter.next().expect("sampler should produce an error");
        assert!(err.is_err());
        assert_eq!(iter.diagnostics().envelope_violations, 1);
        assert_relative_eq!(iter.diagnostics().max_observed_weight, 2.0);
    }

    // A custom intensity that returns NaN produces an error instead of a batch.
    #[test]
    fn rejection_sampler_validates_custom_batch_intensities() {
        let generator = EventGenerator::new(demo_reaction(), HashMap::new(), Some(12345));
        let sampler = RejectionSampler::new(
            generator,
            |batch: &GeneratedBatch| Ok(vec![f64::NAN; batch.dataset().n_events()]),
            RejectionSamplingOptions {
                target_accepted: 1,
                generation_batch_size: 1,
                output_batch_size: 1,
                envelope: RejectionEnvelope::Fixed { max_weight: 1.0 },
                seed: 67890,
            },
        )
        .unwrap();

        let err = sampler
            .accepted_batches()
            .next()
            .expect("sampler should produce an error");
        assert!(err.is_err());
    }

    // A constant expression intensity streams exactly the requested number of events.
    #[test]
    fn expression_rejection_sampler_streams_exact_target_count() {
        let generator = EventGenerator::new(demo_reaction(), HashMap::new(), Some(12345));
        let sampler = generator
            .rejection_sampler_with_expression(
                Expression::from(1.0),
                vec![],
                RejectionSamplingOptions {
                    target_accepted: 5,
                    generation_batch_size: 4,
                    output_batch_size: 2,
                    envelope: RejectionEnvelope::Fixed { max_weight: 1.0 },
                    seed: 67890,
                },
            )
            .unwrap();

        let batches = sampler
            .accepted_batches()
            .collect::<LadduResult<Vec<_>>>()
            .unwrap();
        assert_eq!(
            batches
                .iter()
                .map(|batch| batch.dataset().n_events())
                .collect::<Vec<_>>(),
            vec![2, 2, 1]
        );
        assert_eq!(
            batches
                .iter()
                .map(|batch| batch.dataset().n_events())
                .sum::<usize>(),
            5
        );
    }

    // The pilot envelope is the observed maximum intensity (1.0) scaled by the safety factor
    // (1.25).
    #[test]
    fn expression_rejection_sampler_supports_pilot_envelope() {
        let generator = EventGenerator::new(demo_reaction(), HashMap::new(), Some(12345));
        let sampler = generator
            .rejection_sampler_with_expression(
                Expression::from(1.0),
                vec![],
                RejectionSamplingOptions {
                    target_accepted: 3,
                    generation_batch_size: 2,
                    output_batch_size: 2,
                    envelope: RejectionEnvelope::Pilot {
                        pilot_events: 4,
                        batch_size: Some(2),
                        safety_factor: 1.25,
                    },
                    seed: 67890,
                },
            )
            .unwrap();
        let mut iter = sampler.accepted_batches();
        let batches = iter.by_ref().collect::<LadduResult<Vec<_>>>().unwrap();
        assert_eq!(
            batches
                .iter()
                .map(|batch| batch.dataset().n_events())
                .sum::<usize>(),
            3
        );
        assert_relative_eq!(iter.diagnostics().envelope_max_weight, 1.25);
    }

    // A negative expression intensity is reported as an error rather than silently rejected.
    #[test]
    fn expression_rejection_sampler_rejects_invalid_intensities() {
        let generator = EventGenerator::new(demo_reaction(), HashMap::new(), Some(12345));
        let err = generator
            .rejection_sampler_with_expression(
                Expression::from(-1.0),
                vec![],
                RejectionSamplingOptions {
                    target_accepted: 1,
                    generation_batch_size: 1,
                    output_batch_size: 1,
                    envelope: RejectionEnvelope::Fixed { max_weight: 1.0 },
                    seed: 67890,
                },
            )
            .unwrap()
            .accepted_batches()
            .next()
            .expect("sampler should emit an error");
        assert!(err.is_err());
    }

    // `two_to_two` rejects a reaction that reuses a particle ID ("beam" appears twice).
    #[test]
    fn duplicate_generated_particle_ids_are_rejected() {
        let beam = GeneratedParticle::initial(
            "beam",
            InitialGenerator::beam_with_fixed_energy(0.0, 8.0),
            Reconstruction::Stored,
        );
        let target = GeneratedParticle::initial(
            "target",
            InitialGenerator::target(0.938272),
            Reconstruction::Missing,
        );
        let duplicate = GeneratedParticle::stable(
            "beam",
            StableGenerator::new(0.497611),
            Reconstruction::Stored,
        );
        let recoil = GeneratedParticle::stable(
            "recoil",
            StableGenerator::new(0.938272),
            Reconstruction::Stored,
        );

        assert!(GeneratedReaction::two_to_two(
            beam,
            target,
            duplicate,
            recoil,
            MandelstamTDistribution::Exponential { slope: 0.1 },
        )
        .is_err());
    }
}