1use std::collections::{HashMap, HashSet};
2
3use fastrand::Rng;
4use laddu_core::{
5 math::{q_m, Histogram, Sheet},
6 Dataset, DatasetMetadata, LadduError, LadduResult, Particle, Reaction, Vec3, Vec4, PI,
7};
8use serde::{Deserialize, Serialize};
9
10use crate::distributions::{
11 Distribution, HistogramSampler, LadduGenRngExt, MandelstamTDistribution, SimpleDistribution,
12};
13
14#[derive(Clone, Debug, PartialEq, Eq)]
20pub enum GeneratedStorage {
21 All,
23 Only(Vec<String>),
25}
26
27impl GeneratedStorage {
28 pub fn all() -> Self {
30 Self::All
31 }
32
33 pub fn only<I, S>(ids: I) -> Self
35 where
36 I: IntoIterator<Item = S>,
37 S: Into<String>,
38 {
39 Self::Only(ids.into_iter().map(Into::into).collect())
40 }
41
42 pub fn stores(&self, id: &str) -> bool {
44 match self {
45 Self::All => true,
46 Self::Only(ids) => ids.iter().any(|stored_id| stored_id == id),
47 }
48 }
49
50 fn validate(&self, available_ids: &[String]) -> LadduResult<()> {
51 let available = available_ids
52 .iter()
53 .map(String::as_str)
54 .collect::<HashSet<_>>();
55 let Self::Only(ids) = self else {
56 return Ok(());
57 };
58 let mut seen = HashSet::new();
59 for id in ids {
60 if !seen.insert(id.as_str()) {
61 return Err(LadduError::Custom(format!(
62 "generated storage contains duplicate particle ID '{id}'"
63 )));
64 }
65 if !available.contains(id.as_str()) {
66 return Err(LadduError::Custom(format!(
67 "generated storage references unknown particle ID '{id}'"
68 )));
69 }
70 }
71 Ok(())
72 }
73
74 fn stored_labels(&self, all_labels: &[String]) -> Vec<String> {
75 all_labels
76 .iter()
77 .filter(|label| self.stores(label))
78 .cloned()
79 .collect()
80 }
81}
82
83#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
89pub enum ParticleSpecies {
90 Code {
92 id: i64,
94 namespace: Option<String>,
96 },
97 Label(String),
99}
100
101impl ParticleSpecies {
102 pub fn code(id: i64) -> Self {
104 Self::Code {
105 id,
106 namespace: None,
107 }
108 }
109
110 pub fn with_namespace(namespace: impl Into<String>, id: i64) -> Self {
112 Self::Code {
113 id,
114 namespace: Some(namespace.into()),
115 }
116 }
117
118 pub fn label(label: impl Into<String>) -> Self {
120 Self::Label(label.into())
121 }
122}
123
124fn basis(z: Vec3) -> (Vec3, Vec3, Vec3) {
125 let z = z.unit();
126 let ref_axis = if z.z.abs() < 0.9 {
127 Vec3::z()
128 } else {
129 Vec3::y()
130 };
131 let x = ref_axis.cross(&z).unit();
132 let y = z.cross(&x);
133 (x, y, z)
134}
135
136#[derive(Clone, Debug)]
138pub struct InitialGenerator {
139 mass: f64,
140 energy_distribution: SimpleDistribution,
141}
142
143impl InitialGenerator {
144 pub fn beam_with_fixed_energy(mass: f64, energy: f64) -> Self {
146 debug_assert!(mass >= 0.0, "Mass cannot be negative!\nMass: {}", mass);
147 debug_assert!(energy > 0.0, "Energy must be positive!\nEnergy: {}", energy);
148 Self {
149 mass,
150 energy_distribution: SimpleDistribution::Fixed(energy),
151 }
152 }
153
154 pub fn beam(mass: f64, min_energy: f64, max_energy: f64) -> Self {
156 debug_assert!(mass >= 0.0, "Mass cannot be negative!\nMass: {}", mass);
157 debug_assert!(
158 min_energy > 0.0,
159 "Minimum energy must be positive!\nMinimum Energy: {}",
160 min_energy
161 );
162 debug_assert!(
163 max_energy > min_energy,
164 "Maximum energy must be greater than minimum energy!"
165 );
166 Self {
167 mass,
168 energy_distribution: SimpleDistribution::Uniform {
169 min: min_energy,
170 max: max_energy,
171 },
172 }
173 }
174
175 pub fn beam_with_energy_histogram(mass: f64, energy: Histogram) -> LadduResult<Self> {
177 debug_assert!(
178 mass >= 0.0,
179 "Mass must be positive and greater than zero!\nMass: {}",
180 mass
181 );
182 let sampler = HistogramSampler::new(energy)?;
183 debug_assert!(
184 sampler.hist.bin_edges()[0] >= mass,
185 "Mass cannot be greater than the minimum allowed energy!\nMass: {}\nMinimum Energy: {}",
186 mass,
187 sampler.hist.bin_edges()[0]
188 );
189 Ok(Self {
190 mass,
191 energy_distribution: SimpleDistribution::Histogram(sampler),
192 })
193 }
194
195 pub fn target(mass: f64) -> Self {
197 Self {
198 mass,
199 energy_distribution: SimpleDistribution::Fixed(mass),
200 }
201 }
202}
203
204#[derive(Clone, Debug)]
206pub struct CompositeGenerator {
207 mass_distribution: SimpleDistribution,
208}
209
210impl CompositeGenerator {
211 pub fn new(min_mass: f64, max_mass: f64) -> Self {
213 Self {
214 mass_distribution: SimpleDistribution::Uniform {
215 min: min_mass,
216 max: max_mass,
217 },
218 }
219 }
220
221 fn sample_mass(&self, rng: &mut Rng) -> f64 {
222 self.mass_distribution.sample(rng)
223 }
224}
225
226#[derive(Clone, Debug)]
228pub struct StableGenerator {
229 mass_distribution: SimpleDistribution,
230}
231
232impl StableGenerator {
233 pub fn new(mass: f64) -> Self {
235 debug_assert!(mass >= 0.0, "Mass cannot be negative!\nMass: {}", mass);
236 Self {
237 mass_distribution: SimpleDistribution::Fixed(mass),
238 }
239 }
240
241 fn sample_mass(&self, rng: &mut Rng) -> f64 {
242 self.mass_distribution.sample(rng)
243 }
244}
245
/// How a generated particle is represented in the reconstructed [`Reaction`].
///
/// Each variant maps onto the matching [`Particle`] constructor when the
/// reconstructed reaction is built.
#[derive(Clone, Debug, PartialEq)]
pub enum Reconstruction {
    /// Built with `Particle::stored` under the particle's ID.
    Stored,
    /// Built with `Particle::fixed` using the given four-momentum.
    Fixed(Vec4),
    /// Built with `Particle::missing`.
    Missing,
    /// Built with `Particle::composite` from the particle's two daughters;
    /// only valid for particles that have daughters.
    Composite,
}
258
/// A node in the generated reaction tree.
#[derive(Clone, Debug)]
pub enum GeneratedParticle {
    /// An initial-state particle.
    Initial {
        /// Unique identifier of the particle.
        id: String,
        /// Source of the particle's mass and energy.
        generator: InitialGenerator,
        /// How the particle appears in the reconstructed reaction.
        reconstruction: Reconstruction,
        /// Optional species annotation.
        species: Option<ParticleSpecies>,
    },
    /// A final-state particle without daughters.
    Stable {
        /// Unique identifier of the particle.
        id: String,
        /// Source of the particle's mass.
        generator: StableGenerator,
        /// How the particle appears in the reconstructed reaction.
        reconstruction: Reconstruction,
        /// Optional species annotation.
        species: Option<ParticleSpecies>,
    },
    /// A particle that decays into exactly two daughters.
    Composite {
        /// Unique identifier of the particle.
        id: String,
        /// Source of the particle's mass.
        generator: CompositeGenerator,
        /// The two decay products.
        daughters: (Box<GeneratedParticle>, Box<GeneratedParticle>),
        /// How the particle appears in the reconstructed reaction.
        reconstruction: Reconstruction,
        /// Optional species annotation.
        species: Option<ParticleSpecies>,
    },
}
285
286impl GeneratedParticle {
287 pub fn initial(
289 id: impl Into<String>,
290 generator: InitialGenerator,
291 reconstruction: Reconstruction,
292 ) -> Self {
293 Self::Initial {
294 id: id.into(),
295 generator,
296 reconstruction,
297 species: None,
298 }
299 }
300
301 pub fn stable(
303 id: impl Into<String>,
304 generator: StableGenerator,
305 reconstruction: Reconstruction,
306 ) -> Self {
307 Self::Stable {
308 id: id.into(),
309 generator,
310 reconstruction,
311 species: None,
312 }
313 }
314
315 pub fn composite(
317 id: impl Into<String>,
318 generator: CompositeGenerator,
319 daughters: (&GeneratedParticle, &GeneratedParticle),
320 reconstruction: Reconstruction,
321 ) -> Self {
322 Self::Composite {
323 id: id.into(),
324 generator,
325 daughters: (Box::new(daughters.0.clone()), Box::new(daughters.1.clone())),
326 reconstruction,
327 species: None,
328 }
329 }
330
331 pub fn with_species(mut self, species: ParticleSpecies) -> Self {
333 match &mut self {
334 Self::Initial {
335 species: particle_species,
336 ..
337 }
338 | Self::Stable {
339 species: particle_species,
340 ..
341 }
342 | Self::Composite {
343 species: particle_species,
344 ..
345 } => *particle_species = Some(species),
346 }
347 self
348 }
349
350 pub fn id(&self) -> &str {
352 match self {
353 Self::Initial { id, .. } | Self::Stable { id, .. } | Self::Composite { id, .. } => id,
354 }
355 }
356
357 pub fn species(&self) -> Option<&ParticleSpecies> {
359 match self {
360 Self::Initial { species, .. }
361 | Self::Stable { species, .. }
362 | Self::Composite { species, .. } => species.as_ref(),
363 }
364 }
365
366 pub fn reconstruction(&self) -> &Reconstruction {
368 match self {
369 Self::Initial { reconstruction, .. }
370 | Self::Stable { reconstruction, .. }
371 | Self::Composite { reconstruction, .. } => reconstruction,
372 }
373 }
374
375 fn p4_labels(&self) -> Vec<String> {
376 let mut labels = vec![self.id().to_string()];
377 if let Self::Composite { daughters, .. } = self {
378 labels.append(&mut daughters.0.p4_labels());
379 labels.append(&mut daughters.1.p4_labels());
380 }
381 labels
382 }
383
384 fn append_decay_layout(
385 &self,
386 parent_id: Option<usize>,
387 produced_vertex_id: Option<usize>,
388 storage: &GeneratedStorage,
389 particles: &mut Vec<GeneratedParticleLayout>,
390 vertices: &mut Vec<GeneratedVertexLayout>,
391 ) -> usize {
392 let product_id = particles.len();
393 particles.push(GeneratedParticleLayout {
394 id: self.id().to_string(),
395 product_id,
396 parent_id,
397 species: self.species().cloned(),
398 p4_label: storage.stores(self.id()).then(|| self.id().to_string()),
399 produced_vertex_id,
400 decay_vertex_id: None,
401 });
402 if let Self::Composite { daughters, .. } = self {
403 let vertex_id = vertices.len();
404 particles[product_id].decay_vertex_id = Some(vertex_id);
405 vertices.push(GeneratedVertexLayout {
406 vertex_id,
407 kind: GeneratedVertexKind::Decay,
408 incoming_product_ids: vec![product_id],
409 outgoing_product_ids: Vec::new(),
410 });
411 let daughter_1_id = daughters.0.append_decay_layout(
412 Some(product_id),
413 Some(vertex_id),
414 storage,
415 particles,
416 vertices,
417 );
418 let daughter_2_id = daughters.1.append_decay_layout(
419 Some(product_id),
420 Some(vertex_id),
421 storage,
422 particles,
423 vertices,
424 );
425 vertices[vertex_id].outgoing_product_ids = vec![daughter_1_id, daughter_2_id];
426 }
427 product_id
428 }
429
430 fn sample_mass(&self, rng: &mut Rng) -> f64 {
431 match self {
432 Self::Initial { generator, .. } => generator.mass,
433 Self::Stable { generator, .. } => generator.sample_mass(rng),
434 Self::Composite { generator, .. } => generator.sample_mass(rng),
435 }
436 }
437
438 fn generated_particle(&self) -> LadduResult<Particle> {
439 match self.reconstruction() {
440 Reconstruction::Stored => Ok(Particle::stored(self.id())),
441 Reconstruction::Fixed(p4) => Ok(Particle::fixed(self.id(), *p4)),
442 Reconstruction::Missing => Ok(Particle::missing(self.id())),
443 Reconstruction::Composite => {
444 let Self::Composite { daughters, .. } = self else {
445 return Err(LadduError::Custom(format!(
446 "particle '{}' cannot use composite reconstruction without daughters",
447 self.id()
448 )));
449 };
450 let daughter_1 = daughters.0.generated_particle()?;
451 let daughter_2 = daughters.1.generated_particle()?;
452 Particle::composite(self.id(), (&daughter_1, &daughter_2))
453 }
454 }
455 }
456
457 fn validate_reconstruction(&self) -> LadduResult<()> {
458 match (self, self.reconstruction()) {
459 (Self::Composite { daughters, .. }, Reconstruction::Composite) => {
460 daughters.0.validate_reconstruction()?;
461 daughters.1.validate_reconstruction()?;
462 Ok(())
463 }
464 (Self::Composite { .. }, _) => Ok(()),
465 (_, Reconstruction::Composite) => Err(LadduError::Custom(format!(
466 "particle '{}' cannot use composite reconstruction without daughters",
467 self.id()
468 ))),
469 _ => Ok(()),
470 }
471 }
472
473 fn collect_ids<'a>(&'a self, seen: &mut HashSet<&'a str>) -> LadduResult<()> {
474 if !seen.insert(self.id()) {
475 return Err(LadduError::Custom(format!(
476 "duplicate generated particle identifier '{}'",
477 self.id()
478 )));
479 }
480 if let Self::Composite { daughters, .. } = self {
481 daughters.0.collect_ids(seen)?;
482 daughters.1.collect_ids(seen)?;
483 }
484 Ok(())
485 }
486
487 fn generate_decay(
488 &self,
489 rng: &mut Rng,
490 p4_cm: Vec4,
491 cm_to_lab_boost: &Vec3,
492 p4_storage: &mut HashMap<String, Vec<Vec4>>,
493 ) {
494 let p4_lab = p4_cm.boost(cm_to_lab_boost);
495 if let Some(storage) = p4_storage.get_mut(self.id()) {
496 storage.push(p4_lab);
497 }
498
499 let Self::Composite { daughters, .. } = self else {
500 return;
501 };
502 let d1 = &daughters.0;
503 let d2 = &daughters.1;
504 let parent_mass = p4_cm.m();
505 let m1 = d1.sample_mass(rng);
506 let m2 = d2.sample_mass(rng);
507 let q = q_m(parent_mass, m1, m2, Sheet::Physical).re;
508 let parent_msq = parent_mass * parent_mass;
509 let msq1 = m1 * m1;
510 let msq2 = m2 * m2;
511 let e1 = (parent_msq + msq1 - msq2) / (2.0 * parent_mass);
512 let e2 = (parent_msq + msq2 - msq1) / (2.0 * parent_mass);
513
514 let cos_theta = rng.uniform(-1.0, 1.0);
515 let sin_theta = (1.0 - cos_theta * cos_theta).sqrt();
516 let phi = rng.uniform(0.0, 2.0 * PI);
517 let (sin_phi, cos_phi) = phi.sin_cos();
518
519 let dir = Vec3::new(sin_theta * cos_phi, sin_theta * sin_phi, cos_theta);
520 let p1_p4_rest = (dir * q).with_energy(e1);
521 let p2_p4_rest = (-dir * q).with_energy(e2);
522 let parent_to_cm_boost = p4_cm.beta();
523 let p1_p4_cm = p1_p4_rest.boost(&parent_to_cm_boost);
524 let p2_p4_cm = p2_p4_rest.boost(&parent_to_cm_boost);
525 d1.generate_decay(rng, p1_p4_cm, cm_to_lab_boost, p4_storage);
526 d2.generate_decay(rng, p2_p4_cm, cm_to_lab_boost, p4_storage);
527 }
528}
529
/// A generated two-to-two reaction `p1 + p2 -> p3 + p4`.
#[derive(Clone, Debug)]
pub struct GeneratedTwoToTwoReaction {
    /// First incoming particle; must be an `Initial` particle.
    p1: GeneratedParticle,
    /// Second incoming particle; must be an `Initial` particle.
    p2: GeneratedParticle,
    /// First outgoing particle; must be `Stable` or `Composite`.
    p3: GeneratedParticle,
    /// Second outgoing particle; must be `Stable` or `Composite`.
    p4: GeneratedParticle,
    /// Distribution sampled once per event for the Mandelstam `t` value.
    tdist: MandelstamTDistribution,
    /// Direction handed to `rng.p4` when building `p1` in the lab frame
    /// (initialized to +z by `new`).
    p1_p3_lab_dir: Vec3,
    /// Direction handed to `rng.p4` when building `p2` in the lab frame
    /// (initialized to -z by `new`).
    p2_p3_lab_dir: Vec3,
}
541
542impl GeneratedTwoToTwoReaction {
543 pub fn new(
545 p1: GeneratedParticle,
546 p2: GeneratedParticle,
547 p3: GeneratedParticle,
548 p4: GeneratedParticle,
549 tdist: MandelstamTDistribution,
550 ) -> LadduResult<Self> {
551 validate_initial_role(&p1, "p1")?;
552 validate_initial_role(&p2, "p2")?;
553 validate_final_role(&p3, "p3")?;
554 validate_final_role(&p4, "p4")?;
555 let reaction = Self {
556 p1,
557 p2,
558 p3,
559 p4,
560 tdist,
561 p1_p3_lab_dir: Vec3::z(),
562 p2_p3_lab_dir: -Vec3::z(),
563 };
564 reaction.validate()?;
565 Ok(reaction)
566 }
567
568 fn validate(&self) -> LadduResult<()> {
569 let mut seen = HashSet::new();
570 for particle in [&self.p1, &self.p2, &self.p3, &self.p4] {
571 particle.collect_ids(&mut seen)?;
572 particle.validate_reconstruction()?;
573 }
574 Ok(())
575 }
576
577 fn p4_labels(&self) -> Vec<String> {
578 let mut labels = Vec::new();
579 for particle in [&self.p1, &self.p2, &self.p3, &self.p4] {
580 labels.append(&mut particle.p4_labels());
581 }
582 labels
583 }
584
585 fn layout_components(
586 &self,
587 storage: &GeneratedStorage,
588 ) -> (Vec<GeneratedParticleLayout>, Vec<GeneratedVertexLayout>) {
589 let mut particles = Vec::new();
590 let mut vertices = vec![GeneratedVertexLayout {
591 vertex_id: 0,
592 kind: GeneratedVertexKind::Production,
593 incoming_product_ids: Vec::new(),
594 outgoing_product_ids: Vec::new(),
595 }];
596 let p1_id = self
597 .p1
598 .append_decay_layout(None, None, storage, &mut particles, &mut vertices);
599 let p2_id = self
600 .p2
601 .append_decay_layout(None, None, storage, &mut particles, &mut vertices);
602 let p3_id =
603 self.p3
604 .append_decay_layout(None, Some(0), storage, &mut particles, &mut vertices);
605 let p4_id =
606 self.p4
607 .append_decay_layout(None, Some(0), storage, &mut particles, &mut vertices);
608 vertices[0].incoming_product_ids = vec![p1_id, p2_id];
609 vertices[0].outgoing_product_ids = vec![p3_id, p4_id];
610 (particles, vertices)
611 }
612
613 fn particle_layouts(&self) -> Vec<GeneratedParticleLayout> {
614 self.particle_layouts_with_storage(&GeneratedStorage::All)
615 }
616
617 fn particle_layouts_with_storage(
618 &self,
619 storage: &GeneratedStorage,
620 ) -> Vec<GeneratedParticleLayout> {
621 self.layout_components(storage).0
622 }
623
624 fn vertex_layouts(&self) -> Vec<GeneratedVertexLayout> {
625 self.layout_components(&GeneratedStorage::All).1
626 }
627
628 fn reconstructed_reaction(&self) -> LadduResult<Reaction> {
629 Reaction::two_to_two(
630 &self.p1.generated_particle()?,
631 &self.p2.generated_particle()?,
632 &self.p3.generated_particle()?,
633 &self.p4.generated_particle()?,
634 )
635 }
636
637 fn generate_event(&self, rng: &mut Rng, p4_storage: &mut HashMap<String, Vec<Vec4>>) {
638 let GeneratedParticle::Initial {
639 id: p1_id,
640 generator: p1_generator,
641 ..
642 } = &self.p1
643 else {
644 unreachable!("validated generated two-to-two p1 role")
645 };
646 let GeneratedParticle::Initial {
647 id: p2_id,
648 generator: p2_generator,
649 ..
650 } = &self.p2
651 else {
652 unreachable!("validated generated two-to-two p2 role")
653 };
654
655 let p1_e = p1_generator.energy_distribution.sample(rng);
656 let p1_m = p1_generator.mass;
657 let p1_msq = p1_m * p1_m;
658 let p1_p4_lab = rng.p4(p1_m, p1_e, self.p1_p3_lab_dir);
659 if let Some(storage) = p4_storage.get_mut(p1_id) {
660 storage.push(p1_p4_lab)
661 }
662
663 let p2_e = p2_generator.energy_distribution.sample(rng);
664 let p2_m = p2_generator.mass;
665 let p2_msq = p2_m * p2_m;
666 let p2_p4_lab = rng.p4(p2_m, p2_e, self.p2_p3_lab_dir);
667 if let Some(storage) = p4_storage.get_mut(p2_id) {
668 storage.push(p2_p4_lab)
669 }
670
671 let cm = p1_p4_lab + p2_p4_lab;
672 let cm_boost = -cm.beta();
673 let s = cm.mag2();
674 let sqrt_s = s.sqrt();
675 let t = self.tdist.sample(rng);
676
677 let p1_p4_cm = p1_p4_lab.boost(&cm_boost);
678 let p3_m = self.p3.sample_mass(rng);
679 let p3_msq = p3_m * p3_m;
680 let p4_m = self.p4.sample_mass(rng);
681 let p4_msq = p4_m * p4_m;
682 let p_in_mag = q_m(sqrt_s, p1_m, p2_m, Sheet::Physical).re;
683 let p_out_mag = q_m(sqrt_s, p3_m, p4_m, Sheet::Physical).re;
684 let p1_e_cm = (s + p1_msq - p2_msq) / (2.0 * sqrt_s);
685 let p3_e_cm = (s + p3_msq - p4_msq) / (2.0 * sqrt_s);
686 let p4_e_cm = (s + p4_msq - p3_msq) / (2.0 * sqrt_s);
687 let costheta =
688 (t - p1_msq - p3_msq + 2.0 * p1_e_cm * p3_e_cm) / (2.0 * p_in_mag * p_out_mag);
689 let costheta = costheta.clamp(-1.0, 1.0);
690 let sintheta = (1.0 - costheta * costheta).sqrt();
691 let phi = rng.uniform(0.0, 2.0 * PI);
692 let (sin_phi, cos_phi) = phi.sin_cos();
693 let (x, y, z) = basis(p1_p4_cm.vec3());
694 let p3_dir_cm = x * (sintheta * cos_phi) + y * (sintheta * sin_phi) + z * costheta;
695
696 let p3_p4_cm = (p3_dir_cm * p_out_mag).with_energy(p3_e_cm);
697 self.p3
698 .generate_decay(rng, p3_p4_cm, &-cm_boost, p4_storage);
699 let p4_p4_cm = (-p3_dir_cm * p_out_mag).with_energy(p4_e_cm);
700 self.p4
701 .generate_decay(rng, p4_p4_cm, &-cm_boost, p4_storage);
702 }
703}
704
705fn validate_initial_role(particle: &GeneratedParticle, role: &str) -> LadduResult<()> {
706 if matches!(particle, GeneratedParticle::Initial { .. }) {
707 Ok(())
708 } else {
709 Err(LadduError::Custom(format!(
710 "generated two-to-two role '{role}' requires an initial particle"
711 )))
712 }
713}
714
715fn validate_final_role(particle: &GeneratedParticle, role: &str) -> LadduResult<()> {
716 if matches!(
717 particle,
718 GeneratedParticle::Stable { .. } | GeneratedParticle::Composite { .. }
719 ) {
720 Ok(())
721 } else {
722 Err(LadduError::Custom(format!(
723 "generated two-to-two role '{role}' requires an outgoing particle"
724 )))
725 }
726}
727
728#[derive(Clone, Debug)]
730pub enum GeneratedReactionTopology {
731 TwoToTwo(GeneratedTwoToTwoReaction),
733}
734
735impl GeneratedReactionTopology {
736 fn p4_labels(&self) -> Vec<String> {
737 match self {
738 Self::TwoToTwo(reaction) => reaction.p4_labels(),
739 }
740 }
741
742 fn particle_layouts(&self) -> Vec<GeneratedParticleLayout> {
743 match self {
744 Self::TwoToTwo(reaction) => reaction.particle_layouts(),
745 }
746 }
747
748 fn particle_layouts_with_storage(
749 &self,
750 storage: &GeneratedStorage,
751 ) -> Vec<GeneratedParticleLayout> {
752 match self {
753 Self::TwoToTwo(reaction) => reaction.particle_layouts_with_storage(storage),
754 }
755 }
756
757 fn vertex_layouts(&self) -> Vec<GeneratedVertexLayout> {
758 match self {
759 Self::TwoToTwo(reaction) => reaction.vertex_layouts(),
760 }
761 }
762
763 fn reconstructed_reaction(&self) -> LadduResult<Reaction> {
764 match self {
765 Self::TwoToTwo(reaction) => reaction.reconstructed_reaction(),
766 }
767 }
768
769 fn generate_event(&self, rng: &mut Rng, p4_storage: &mut HashMap<String, Vec<Vec4>>) {
770 match self {
771 Self::TwoToTwo(reaction) => reaction.generate_event(rng, p4_storage),
772 }
773 }
774}
775
776#[derive(Clone, Debug)]
778pub struct GeneratedReaction {
779 topology: GeneratedReactionTopology,
780}
781
782impl GeneratedReaction {
783 pub fn two_to_two(
785 p1: GeneratedParticle,
786 p2: GeneratedParticle,
787 p3: GeneratedParticle,
788 p4: GeneratedParticle,
789 tdist: MandelstamTDistribution,
790 ) -> LadduResult<Self> {
791 Ok(Self {
792 topology: GeneratedReactionTopology::TwoToTwo(GeneratedTwoToTwoReaction::new(
793 p1, p2, p3, p4, tdist,
794 )?),
795 })
796 }
797
798 pub fn p4_labels(&self) -> Vec<String> {
800 self.topology.p4_labels()
801 }
802
803 pub fn particle_layouts(&self) -> Vec<GeneratedParticleLayout> {
805 self.topology.particle_layouts()
806 }
807
808 pub fn particle_layouts_with_storage(
810 &self,
811 storage: &GeneratedStorage,
812 ) -> Vec<GeneratedParticleLayout> {
813 self.topology.particle_layouts_with_storage(storage)
814 }
815
816 pub fn vertex_layouts(&self) -> Vec<GeneratedVertexLayout> {
818 self.topology.vertex_layouts()
819 }
820
821 pub fn reconstructed_reaction(&self) -> LadduResult<Reaction> {
823 self.topology.reconstructed_reaction()
824 }
825
826 fn generate(
827 &self,
828 rng: &mut Rng,
829 p4_storage: &mut HashMap<String, Vec<Vec4>>,
830 n_events: usize,
831 ) {
832 for _ in 0..n_events {
833 self.topology.generate_event(rng, p4_storage);
834 }
835 }
836}
837
838#[derive(Clone, Debug, PartialEq, Eq)]
840pub struct GeneratedParticleLayout {
841 id: String,
842 product_id: usize,
843 parent_id: Option<usize>,
844 species: Option<ParticleSpecies>,
845 p4_label: Option<String>,
846 produced_vertex_id: Option<usize>,
847 decay_vertex_id: Option<usize>,
848}
849
850impl GeneratedParticleLayout {
851 pub fn id(&self) -> &str {
853 &self.id
854 }
855
856 pub fn product_id(&self) -> usize {
858 self.product_id
859 }
860
861 pub fn parent_id(&self) -> Option<usize> {
863 self.parent_id
864 }
865
866 pub fn species(&self) -> Option<&ParticleSpecies> {
868 self.species.as_ref()
869 }
870
871 pub fn p4_label(&self) -> Option<&str> {
873 self.p4_label.as_deref()
874 }
875
876 pub fn produced_vertex_id(&self) -> Option<usize> {
878 self.produced_vertex_id
879 }
880
881 pub fn decay_vertex_id(&self) -> Option<usize> {
883 self.decay_vertex_id
884 }
885}
886
/// Kind of vertex in a generated event.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum GeneratedVertexKind {
    /// The vertex where the initial particles produce the outgoing ones.
    Production,
    /// A vertex where a composite particle decays into its daughters.
    Decay,
}

/// Layout entry describing one vertex of a generated event.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct GeneratedVertexLayout {
    vertex_id: usize,
    kind: GeneratedVertexKind,
    incoming_product_ids: Vec<usize>,
    outgoing_product_ids: Vec<usize>,
}

impl GeneratedVertexLayout {
    /// The vertex's ID.
    pub fn vertex_id(&self) -> usize {
        let Self { vertex_id, .. } = self;
        *vertex_id
    }

    /// Whether this is a production or a decay vertex.
    pub fn kind(&self) -> GeneratedVertexKind {
        let Self { kind, .. } = self;
        *kind
    }

    /// Product IDs of the particles entering the vertex.
    pub fn incoming_product_ids(&self) -> &[usize] {
        self.incoming_product_ids.as_slice()
    }

    /// Product IDs of the particles leaving the vertex.
    pub fn outgoing_product_ids(&self) -> &[usize] {
        self.outgoing_product_ids.as_slice()
    }
}
926
927#[derive(Clone, Debug, PartialEq, Eq)]
929pub struct GeneratedEventLayout {
930 p4_labels: Vec<String>,
931 aux_labels: Vec<String>,
932 particles: Vec<GeneratedParticleLayout>,
933 vertices: Vec<GeneratedVertexLayout>,
934}
935
936impl GeneratedEventLayout {
937 pub fn new(
939 p4_labels: Vec<String>,
940 aux_labels: Vec<String>,
941 particles: Vec<GeneratedParticleLayout>,
942 vertices: Vec<GeneratedVertexLayout>,
943 ) -> Self {
944 Self {
945 p4_labels,
946 aux_labels,
947 particles,
948 vertices,
949 }
950 }
951
952 pub fn p4_labels(&self) -> &[String] {
954 &self.p4_labels
955 }
956
957 pub fn aux_labels(&self) -> &[String] {
959 &self.aux_labels
960 }
961
962 pub fn particles(&self) -> &[GeneratedParticleLayout] {
964 &self.particles
965 }
966
967 pub fn particle(&self, id: &str) -> Option<&GeneratedParticleLayout> {
969 self.particles.iter().find(|particle| particle.id() == id)
970 }
971
972 pub fn product(&self, product_id: usize) -> Option<&GeneratedParticleLayout> {
974 self.particles
975 .iter()
976 .find(|particle| particle.product_id() == product_id)
977 }
978
979 pub fn vertices(&self) -> &[GeneratedVertexLayout] {
981 &self.vertices
982 }
983
984 pub fn vertex(&self, vertex_id: usize) -> Option<&GeneratedVertexLayout> {
986 self.vertices
987 .iter()
988 .find(|vertex| vertex.vertex_id() == vertex_id)
989 }
990
991 pub fn production_vertex(&self) -> Option<&GeneratedVertexLayout> {
993 self.vertices
994 .iter()
995 .find(|vertex| vertex.kind() == GeneratedVertexKind::Production)
996 }
997
998 pub fn decay_products(&self, parent_product_id: usize) -> Vec<&GeneratedParticleLayout> {
1000 self.particles
1001 .iter()
1002 .filter(|particle| particle.parent_id() == Some(parent_product_id))
1003 .collect()
1004 }
1005
1006 pub fn production_incoming(&self) -> Vec<&GeneratedParticleLayout> {
1008 self.production_vertex_products(GeneratedVertexLayout::incoming_product_ids)
1009 }
1010
1011 pub fn production_outgoing(&self) -> Vec<&GeneratedParticleLayout> {
1013 self.production_vertex_products(GeneratedVertexLayout::outgoing_product_ids)
1014 }
1015
1016 fn production_vertex_products(
1017 &self,
1018 ids: impl FnOnce(&GeneratedVertexLayout) -> &[usize],
1019 ) -> Vec<&GeneratedParticleLayout> {
1020 self.production_vertex()
1021 .map(|vertex| {
1022 ids(vertex)
1023 .iter()
1024 .filter_map(|product_id| self.product(*product_id))
1025 .collect()
1026 })
1027 .unwrap_or_default()
1028 }
1029}
1030
1031#[derive(Clone, Debug)]
1033pub struct GeneratedBatch {
1034 dataset: Dataset,
1035 reaction: GeneratedReaction,
1036 layout: GeneratedEventLayout,
1037}
1038
1039impl GeneratedBatch {
1040 pub fn new(
1042 dataset: Dataset,
1043 reaction: GeneratedReaction,
1044 layout: GeneratedEventLayout,
1045 ) -> Self {
1046 Self {
1047 dataset,
1048 reaction,
1049 layout,
1050 }
1051 }
1052
1053 pub fn dataset(&self) -> &Dataset {
1055 &self.dataset
1056 }
1057
1058 pub fn into_dataset(self) -> Dataset {
1060 self.dataset
1061 }
1062
1063 pub fn reaction(&self) -> &GeneratedReaction {
1065 &self.reaction
1066 }
1067
1068 pub fn layout(&self) -> &GeneratedEventLayout {
1070 &self.layout
1071 }
1072}
1073
/// Generator of events for a [`GeneratedReaction`].
#[derive(Clone, Debug)]
pub struct EventGenerator {
    /// The reaction whose events are generated.
    reaction: GeneratedReaction,
    /// Distributions keyed by name for auxiliary values.
    aux_generators: HashMap<String, Distribution>,
    /// Which particles have their four-momenta stored.
    storage: GeneratedStorage,
    /// Seed for the generation RNG.
    seed: u64,
}
1082
1083#[derive(Clone, Debug)]
1085pub struct GeneratedBatchIter {
1086 generator: EventGenerator,
1087 remaining_events: usize,
1088 batch_size: usize,
1089 rng: Rng,
1090}
1091
1092impl Iterator for GeneratedBatchIter {
1093 type Item = LadduResult<GeneratedBatch>;
1094
1095 fn next(&mut self) -> Option<Self::Item> {
1096 if self.remaining_events == 0 {
1097 return None;
1098 }
1099 let n_events = self.batch_size.min(self.remaining_events);
1100 self.remaining_events -= n_events;
1101 Some(
1102 self.generator
1103 .generate_batch_with_rng(n_events, &mut self.rng),
1104 )
1105 }
1106}
1107
1108pub trait BatchIntensity {
1110 fn evaluate(&mut self, batch: &GeneratedBatch) -> LadduResult<Vec<f64>>;
1112}
1113
1114impl<F> BatchIntensity for F
1115where
1116 F: FnMut(&GeneratedBatch) -> LadduResult<Vec<f64>>,
1117{
1118 fn evaluate(&mut self, batch: &GeneratedBatch) -> LadduResult<Vec<f64>> {
1119 self(batch)
1120 }
1121}
1122
/// Upper bound used by rejection sampling.
#[derive(Clone, Debug)]
pub enum RejectionEnvelope {
    /// A constant bound.
    Fixed {
        /// The largest weight the envelope allows.
        max_weight: f64,
    },
}

impl RejectionEnvelope {
    /// Returns the largest weight the envelope allows.
    fn max_weight(&self) -> f64 {
        let Self::Fixed { max_weight } = self;
        *max_weight
    }
}
1140
/// Settings for a [`RejectionSampler`].
#[derive(Clone, Debug)]
pub struct RejectionSamplingOptions {
    /// Number of accepted events after which sampling stops.
    pub target_accepted: usize,
    /// Number of events generated per source batch; must be greater than zero.
    pub generation_batch_size: usize,
    /// Number of accepted events at which an output batch is emitted; must be
    /// greater than zero.
    pub output_batch_size: usize,
    /// Envelope bounding the event weights; its maximum must be finite and
    /// positive.
    pub envelope: RejectionEnvelope,
    /// Seed for the RNG that makes the accept/reject decisions.
    pub seed: u64,
}
1155
/// Running counters describing a rejection-sampling run.
#[derive(Clone, Debug, Default)]
pub struct RejectionSamplingDiagnostics {
    /// Number of source events generated so far.
    pub generated_events: usize,
    /// Number of events accepted so far.
    pub accepted_events: usize,
    /// Number of events rejected so far.
    pub rejected_events: usize,
    /// Largest weight seen so far.
    pub max_observed_weight: f64,
    /// Maximum weight of the envelope in use.
    pub envelope_max_weight: f64,
    /// Number of times a weight exceeded the envelope.
    pub envelope_violations: usize,
}

impl RejectionSamplingDiagnostics {
    /// Fraction of generated events that were accepted, or `0.0` when nothing
    /// has been generated yet.
    pub fn acceptance_efficiency(&self) -> f64 {
        match self.generated_events {
            0 => 0.0,
            generated => self.accepted_events as f64 / generated as f64,
        }
    }
}
1183
1184#[derive(Clone, Debug)]
1186pub struct RejectionSampler<I> {
1187 generator: EventGenerator,
1188 intensity: I,
1189 options: RejectionSamplingOptions,
1190}
1191
1192impl<I> RejectionSampler<I>
1193where
1194 I: BatchIntensity,
1195{
1196 pub fn new(
1198 generator: EventGenerator,
1199 intensity: I,
1200 options: RejectionSamplingOptions,
1201 ) -> LadduResult<Self> {
1202 if options.generation_batch_size == 0 {
1203 return Err(LadduError::Custom(
1204 "generation_batch_size must be greater than zero".to_string(),
1205 ));
1206 }
1207 if options.output_batch_size == 0 {
1208 return Err(LadduError::Custom(
1209 "output_batch_size must be greater than zero".to_string(),
1210 ));
1211 }
1212 let max_weight = options.envelope.max_weight();
1213 if !max_weight.is_finite() || max_weight <= 0.0 {
1214 return Err(LadduError::Custom(
1215 "rejection envelope max_weight must be finite and positive".to_string(),
1216 ));
1217 }
1218 Ok(Self {
1219 generator,
1220 intensity,
1221 options,
1222 })
1223 }
1224
1225 pub fn accepted_batches(self) -> RejectionSampleIter<I> {
1227 let envelope_max_weight = self.options.envelope.max_weight();
1228 RejectionSampleIter {
1229 generation_rng: Rng::with_seed(self.generator.seed),
1230 rejection_rng: Rng::with_seed(self.options.seed),
1231 diagnostics: RejectionSamplingDiagnostics {
1232 envelope_max_weight,
1233 ..Default::default()
1234 },
1235 sampler: self,
1236 current_batch: None,
1237 current_intensities: Vec::new(),
1238 current_index: 0,
1239 }
1240 }
1241}
1242
1243#[derive(Clone, Debug)]
1245pub struct RejectionSampleIter<I> {
1246 sampler: RejectionSampler<I>,
1247 generation_rng: Rng,
1248 rejection_rng: Rng,
1249 diagnostics: RejectionSamplingDiagnostics,
1250 current_batch: Option<GeneratedBatch>,
1251 current_intensities: Vec<f64>,
1252 current_index: usize,
1253}
1254
1255impl<I> RejectionSampleIter<I> {
1256 pub fn diagnostics(&self) -> &RejectionSamplingDiagnostics {
1258 &self.diagnostics
1259 }
1260}
1261
1262impl<I> RejectionSampleIter<I>
1263where
1264 I: BatchIntensity,
1265{
1266 fn load_next_source_batch(&mut self) -> LadduResult<()> {
1267 let batch = self.sampler.generator.generate_batch_with_rng(
1268 self.sampler.options.generation_batch_size,
1269 &mut self.generation_rng,
1270 )?;
1271 let intensities = self.sampler.intensity.evaluate(&batch)?;
1272 if intensities.len() != batch.dataset().n_events() {
1273 return Err(LadduError::Custom(format!(
1274 "intensity length mismatch: expected {}, got {}",
1275 batch.dataset().n_events(),
1276 intensities.len()
1277 )));
1278 }
1279 self.diagnostics.generated_events += batch.dataset().n_events();
1280 self.current_batch = Some(batch);
1281 self.current_intensities = intensities;
1282 self.current_index = 0;
1283 Ok(())
1284 }
1285
1286 fn empty_output_batch(source: &GeneratedBatch) -> GeneratedBatch {
1287 GeneratedBatch::new(
1288 Dataset::empty_local(source.dataset().metadata().clone()),
1289 source.reaction().clone(),
1290 source.layout().clone(),
1291 )
1292 }
1293}
1294
1295impl<I> Iterator for RejectionSampleIter<I>
1296where
1297 I: BatchIntensity,
1298{
1299 type Item = LadduResult<GeneratedBatch>;
1300
1301 fn next(&mut self) -> Option<Self::Item> {
1302 if self.diagnostics.accepted_events >= self.sampler.options.target_accepted {
1303 return None;
1304 }
1305
1306 let mut output: Option<GeneratedBatch> = None;
1307 while self.diagnostics.accepted_events < self.sampler.options.target_accepted {
1308 let needs_batch = self
1309 .current_batch
1310 .as_ref()
1311 .map(|batch| self.current_index >= batch.dataset().n_events())
1312 .unwrap_or(true);
1313 if needs_batch {
1314 if let Err(err) = self.load_next_source_batch() {
1315 return Some(Err(err));
1316 }
1317 }
1318
1319 let source = self
1320 .current_batch
1321 .as_ref()
1322 .expect("source batch should be loaded");
1323 if output.is_none() {
1324 output = Some(Self::empty_output_batch(source));
1325 }
1326
1327 let weight = self.current_intensities[self.current_index];
1328 if !weight.is_finite() || weight < 0.0 {
1329 return Some(Err(LadduError::Custom(format!(
1330 "intensity at event {} must be finite and nonnegative, got {weight}",
1331 self.current_index
1332 ))));
1333 }
1334 self.diagnostics.max_observed_weight = self.diagnostics.max_observed_weight.max(weight);
1335 let envelope_max = self.sampler.options.envelope.max_weight();
1336 if weight > envelope_max {
1337 self.diagnostics.envelope_violations += 1;
1338 return Some(Err(LadduError::Custom(format!(
1339 "rejection envelope violation: observed weight {weight} exceeds max_weight {envelope_max}"
1340 ))));
1341 }
1342
1343 let accepted = self.rejection_rng.f64() * envelope_max < weight;
1344 if accepted {
1345 let event = match source.dataset().event_global(self.current_index) {
1346 Ok(event) => event,
1347 Err(err) => return Some(Err(err)),
1348 };
1349 if let Err(err) = output.as_mut().unwrap().dataset.push_event_local(
1350 event.p4s.clone(),
1351 event.aux.clone(),
1352 event.weight,
1353 ) {
1354 return Some(Err(err));
1355 }
1356 self.diagnostics.accepted_events += 1;
1357 } else {
1358 self.diagnostics.rejected_events += 1;
1359 }
1360 self.current_index += 1;
1361
1362 if output.as_ref().unwrap().dataset().n_events()
1363 >= self.sampler.options.output_batch_size
1364 || self.diagnostics.accepted_events >= self.sampler.options.target_accepted
1365 {
1366 break;
1367 }
1368 }
1369
1370 output
1371 .filter(|batch| batch.dataset().n_events() > 0)
1372 .map(Ok)
1373 }
1374}
1375
1376impl EventGenerator {
1377 pub fn new(
1379 reaction: GeneratedReaction,
1380 aux_generators: HashMap<String, Distribution>,
1381 seed: Option<u64>,
1382 ) -> Self {
1383 Self {
1384 reaction,
1385 aux_generators,
1386 storage: GeneratedStorage::All,
1387 seed: seed.unwrap_or_else(|| fastrand::u64(..)),
1388 }
1389 }
1390
1391 pub fn storage(&self) -> &GeneratedStorage {
1393 &self.storage
1394 }
1395
1396 pub fn with_storage(mut self, storage: GeneratedStorage) -> LadduResult<Self> {
1398 storage.validate(&self.reaction.p4_labels())?;
1399 self.storage = storage;
1400 Ok(self)
1401 }
1402
1403 fn aux_entries(&self) -> Vec<(&String, &Distribution)> {
1404 let mut aux_entries = self.aux_generators.iter().collect::<Vec<_>>();
1405 aux_entries.sort_by_key(|(label, _)| *label);
1406 aux_entries
1407 }
1408
1409 fn generate_batch_with_rng(
1410 &self,
1411 n_events: usize,
1412 rng: &mut Rng,
1413 ) -> LadduResult<GeneratedBatch> {
1414 let all_p4_labels = self.reaction.p4_labels();
1415 self.storage.validate(&all_p4_labels)?;
1416 let p4_labels = self.storage.stored_labels(&all_p4_labels);
1417 let aux_entries = self.aux_entries();
1418 let aux_labels = aux_entries
1419 .iter()
1420 .map(|(label, _)| (*label).clone())
1421 .collect::<Vec<_>>();
1422 let mut p4_data: HashMap<String, Vec<Vec4>> = p4_labels
1423 .iter()
1424 .map(|label| (label.clone(), Vec::with_capacity(n_events)))
1425 .collect();
1426 let metadata = DatasetMetadata::new(p4_labels.clone(), aux_labels.clone())?;
1427 let mut aux: Vec<Vec<f64>> = aux_entries
1428 .iter()
1429 .map(|_| Vec::with_capacity(n_events))
1430 .collect();
1431 let weights = vec![1.0; n_events];
1432 for _ in 0..n_events {
1433 for ((_, distribution), column) in aux_entries.iter().zip(aux.iter_mut()) {
1434 column.push(distribution.sample(rng));
1435 }
1436 self.reaction.generate(rng, &mut p4_data, 1);
1437 }
1438 let p4 = p4_labels
1439 .iter()
1440 .filter_map(|label| p4_data.remove(label))
1441 .collect();
1442 let dataset = Dataset::from_columns_local(metadata, p4, aux, weights)?;
1443 Ok(GeneratedBatch::new(
1444 dataset,
1445 self.reaction.clone(),
1446 GeneratedEventLayout::new(
1447 p4_labels,
1448 aux_labels,
1449 self.reaction.particle_layouts_with_storage(&self.storage),
1450 self.reaction.vertex_layouts(),
1451 ),
1452 ))
1453 }
1454
1455 pub fn generate_batch(&self, n_events: usize) -> LadduResult<GeneratedBatch> {
1457 let mut rng = Rng::with_seed(self.seed);
1458 self.generate_batch_with_rng(n_events, &mut rng)
1459 }
1460
1461 pub fn generate_batches(
1466 &self,
1467 total_events: usize,
1468 batch_size: usize,
1469 ) -> LadduResult<GeneratedBatchIter> {
1470 if batch_size == 0 {
1471 return Err(LadduError::Custom(
1472 "batch_size must be greater than zero".to_string(),
1473 ));
1474 }
1475 Ok(GeneratedBatchIter {
1476 generator: self.clone(),
1477 remaining_events: total_events,
1478 batch_size,
1479 rng: Rng::with_seed(self.seed),
1480 })
1481 }
1482
1483 pub fn generate_dataset(&self, n_events: usize) -> LadduResult<Dataset> {
1485 Ok(self.generate_batch(n_events)?.into_dataset())
1486 }
1487}
1488
1489#[cfg(test)]
1490mod tests {
1491 use approx::assert_relative_eq;
1492 use laddu_core::{traits::Variable, Channel, Frame};
1493
1494 use super::*;
1495
1496 fn demo_reaction() -> GeneratedReaction {
1497 let beam = GeneratedParticle::initial(
1498 "beam",
1499 InitialGenerator::beam_with_fixed_energy(0.0, 8.0),
1500 Reconstruction::Stored,
1501 );
1502 let target = GeneratedParticle::initial(
1503 "target",
1504 InitialGenerator::target(0.938272),
1505 Reconstruction::Missing,
1506 );
1507 let ks1 = GeneratedParticle::stable(
1508 "kshort1",
1509 StableGenerator::new(0.497611),
1510 Reconstruction::Stored,
1511 );
1512 let ks2 = GeneratedParticle::stable(
1513 "kshort2",
1514 StableGenerator::new(0.497611),
1515 Reconstruction::Stored,
1516 );
1517 let kk = GeneratedParticle::composite(
1518 "kk",
1519 CompositeGenerator::new(1.1, 1.6),
1520 (&ks1, &ks2),
1521 Reconstruction::Composite,
1522 );
1523 let recoil = GeneratedParticle::stable(
1524 "recoil",
1525 StableGenerator::new(0.938272),
1526 Reconstruction::Stored,
1527 );
1528 let tdist = MandelstamTDistribution::Exponential { slope: 0.1 };
1529 GeneratedReaction::two_to_two(beam, target, kk, recoil, tdist).unwrap()
1530 }
1531
1532 #[test]
1533 fn test_generation() {
1534 let reaction = demo_reaction();
1535 let generator = EventGenerator::new(reaction, HashMap::new(), Some(12345));
1536 let n_events = 1_000;
1537 let dataset = generator.generate_dataset(n_events).unwrap();
1538 assert_eq!(dataset.n_events(), n_events);
1539 let metadata = dataset.metadata();
1540 assert!(metadata.p4_index("beam").is_some());
1541 assert!(metadata.p4_index("target").is_some());
1542 assert!(metadata.p4_index("kk").is_some());
1543 assert!(metadata.p4_index("kshort1").is_some());
1544 assert!(metadata.p4_index("kshort2").is_some());
1545 assert!(metadata.p4_index("recoil").is_some());
1546
1547 for event in dataset.events_global() {
1548 let beam_p4 = event.p4("beam").unwrap();
1549 let target_p4 = event.p4("target").unwrap();
1550 let kk_p4 = event.p4("kk").unwrap();
1551 let kshort1_p4 = event.p4("kshort1").unwrap();
1552 let kshort2_p4 = event.p4("kshort2").unwrap();
1553 let recoil_p4 = event.p4("recoil").unwrap();
1554
1555 assert!(beam_p4.e().is_finite());
1556 assert!(target_p4.e().is_finite());
1557 assert!(kk_p4.e().is_finite());
1558 assert!(kshort1_p4.e().is_finite());
1559 assert!(kshort2_p4.e().is_finite());
1560 assert!(recoil_p4.e().is_finite());
1561
1562 assert_relative_eq!(kk_p4, kshort1_p4 + kshort2_p4, epsilon = 1e-10);
1563 assert_relative_eq!(beam_p4 + target_p4, kk_p4 + recoil_p4, epsilon = 1e-10);
1564 assert_relative_eq!(kshort1_p4.m(), 0.497611, epsilon = 1e-10);
1565 assert_relative_eq!(kshort2_p4.m(), 0.497611, epsilon = 1e-10);
1566 assert_relative_eq!(recoil_p4.m(), 0.938272, epsilon = 1e-10);
1567 }
1568 }
1569
1570 #[test]
1571 fn test_reconstructed_reaction() {
1572 let generated = demo_reaction();
1573 let reaction = generated.reconstructed_reaction().unwrap();
1574 let dataset = EventGenerator::new(generated, HashMap::new(), Some(12345))
1575 .generate_dataset(4)
1576 .unwrap();
1577 let mass = reaction.mass("kk").value_on(&dataset).unwrap();
1578 let angles = reaction
1579 .decay("kk")
1580 .unwrap()
1581 .angles("kshort1", Frame::Helicity)
1582 .unwrap();
1583 let mandelstam = reaction
1584 .mandelstam(Channel::S)
1585 .unwrap()
1586 .value_on(&dataset)
1587 .unwrap();
1588
1589 assert_eq!(mass.len(), 4);
1590 assert_eq!(
1591 angles.costheta.to_string(),
1592 "CosTheta(parent=kk, daughter=kshort1, frame=Helicity)"
1593 );
1594 assert_eq!(mandelstam.len(), 4);
1595 }
1596
1597 #[test]
1598 fn test_generated_batch_metadata() {
1599 let generated = demo_reaction();
1600 let generator = EventGenerator::new(
1601 generated,
1602 HashMap::from([("pol_angle".to_string(), Distribution::Fixed(0.25))]),
1603 Some(12345),
1604 );
1605 let batch = generator.generate_batch(4).unwrap();
1606
1607 assert_eq!(batch.dataset().n_events(), 4);
1608 assert_eq!(
1609 batch.layout().p4_labels(),
1610 &["beam", "target", "kk", "kshort1", "kshort2", "recoil"]
1611 );
1612 assert_eq!(batch.layout().aux_labels(), &["pol_angle"]);
1613 assert_eq!(
1614 batch.reaction().p4_labels(),
1615 vec!["beam", "target", "kk", "kshort1", "kshort2", "recoil"]
1616 );
1617 assert_eq!(batch.dataset().p4_names(), batch.layout().p4_labels());
1618 assert_eq!(batch.dataset().aux_names(), batch.layout().aux_labels());
1619 let particles = batch.layout().particles();
1620 assert_eq!(particles.len(), 6);
1621 assert_eq!(particles[0].id(), "beam");
1622 assert_eq!(particles[0].product_id(), 0);
1623 assert_eq!(particles[0].parent_id(), None);
1624 assert_eq!(particles[0].produced_vertex_id(), None);
1625 assert_eq!(particles[0].decay_vertex_id(), None);
1626 assert_eq!(particles[1].id(), "target");
1627 assert_eq!(particles[1].parent_id(), None);
1628 assert_eq!(particles[1].produced_vertex_id(), None);
1629 assert_eq!(particles[1].decay_vertex_id(), None);
1630 assert_eq!(particles[2].id(), "kk");
1631 assert_eq!(particles[2].product_id(), 2);
1632 assert_eq!(particles[2].parent_id(), None);
1633 assert_eq!(particles[2].produced_vertex_id(), Some(0));
1634 assert_eq!(particles[2].decay_vertex_id(), Some(1));
1635 assert_eq!(particles[3].id(), "kshort1");
1636 assert_eq!(particles[3].parent_id(), Some(2));
1637 assert_eq!(particles[3].produced_vertex_id(), Some(1));
1638 assert_eq!(particles[3].decay_vertex_id(), None);
1639 assert_eq!(particles[4].id(), "kshort2");
1640 assert_eq!(particles[4].parent_id(), Some(2));
1641 assert_eq!(particles[4].produced_vertex_id(), Some(1));
1642 assert_eq!(particles[4].decay_vertex_id(), None);
1643 assert_eq!(particles[5].id(), "recoil");
1644 assert_eq!(particles[5].parent_id(), None);
1645 assert_eq!(particles[5].produced_vertex_id(), Some(0));
1646 assert_eq!(particles[5].decay_vertex_id(), None);
1647 for particle in particles {
1648 assert_eq!(particle.p4_label(), Some(particle.id()));
1649 }
1650 let vertices = batch.layout().vertices();
1651 assert_eq!(vertices.len(), 2);
1652 assert_eq!(vertices[0].vertex_id(), 0);
1653 assert_eq!(vertices[0].kind(), GeneratedVertexKind::Production);
1654 assert_eq!(vertices[0].incoming_product_ids(), &[0, 1]);
1655 assert_eq!(vertices[0].outgoing_product_ids(), &[2, 5]);
1656 assert_eq!(vertices[1].vertex_id(), 1);
1657 assert_eq!(vertices[1].kind(), GeneratedVertexKind::Decay);
1658 assert_eq!(vertices[1].incoming_product_ids(), &[2]);
1659 assert_eq!(vertices[1].outgoing_product_ids(), &[3, 4]);
1660
1661 assert_eq!(batch.layout().particle("kk"), Some(&particles[2]));
1662 assert_eq!(batch.layout().particle("missing_id"), None);
1663 assert_eq!(batch.layout().product(5), Some(&particles[5]));
1664 assert_eq!(batch.layout().product(6), None);
1665 assert_eq!(batch.layout().vertex(1), Some(&vertices[1]));
1666 assert_eq!(batch.layout().vertex(2), None);
1667 assert_eq!(batch.layout().production_vertex(), Some(&vertices[0]));
1668 assert_eq!(
1669 batch
1670 .layout()
1671 .production_incoming()
1672 .iter()
1673 .map(|particle| particle.id())
1674 .collect::<Vec<_>>(),
1675 vec!["beam", "target"]
1676 );
1677 assert_eq!(
1678 batch
1679 .layout()
1680 .production_outgoing()
1681 .iter()
1682 .map(|particle| particle.id())
1683 .collect::<Vec<_>>(),
1684 vec!["kk", "recoil"]
1685 );
1686 assert_eq!(
1687 batch
1688 .layout()
1689 .decay_products(2)
1690 .iter()
1691 .map(|particle| particle.id())
1692 .collect::<Vec<_>>(),
1693 vec!["kshort1", "kshort2"]
1694 );
1695 assert!(batch.layout().decay_products(5).is_empty());
1696 }
1697
1698 #[test]
1699 fn generated_storage_only_projects_dataset_columns() {
1700 let generated = demo_reaction();
1701 let generator = EventGenerator::new(generated, HashMap::new(), Some(12345))
1702 .with_storage(GeneratedStorage::only([
1703 "beam", "target", "kshort1", "kshort2", "recoil",
1704 ]))
1705 .unwrap();
1706 let batch = generator.generate_batch(4).unwrap();
1707
1708 assert_eq!(
1709 batch.reaction().p4_labels(),
1710 vec!["beam", "target", "kk", "kshort1", "kshort2", "recoil"]
1711 );
1712 assert_eq!(
1713 batch.layout().p4_labels(),
1714 &["beam", "target", "kshort1", "kshort2", "recoil"]
1715 );
1716 assert_eq!(batch.dataset().p4_names(), batch.layout().p4_labels());
1717 assert!(batch.dataset().metadata().p4_index("kk").is_none());
1718
1719 let particles = batch.layout().particles();
1720 assert_eq!(particles.len(), 6);
1721 assert_eq!(particles[2].id(), "kk");
1722 assert_eq!(particles[2].p4_label(), None);
1723 assert_eq!(particles[3].p4_label(), Some("kshort1"));
1724 assert_eq!(particles[4].p4_label(), Some("kshort2"));
1725
1726 for index in 0..batch.dataset().n_events() {
1727 let event = batch.dataset().event_global(index).unwrap();
1728 assert_relative_eq!(
1729 event.p4("beam").unwrap() + event.p4("target").unwrap(),
1730 event.p4("kshort1").unwrap()
1731 + event.p4("kshort2").unwrap()
1732 + event.p4("recoil").unwrap(),
1733 epsilon = 1e-10
1734 );
1735 }
1736 }
1737
1738 #[test]
1739 fn generated_storage_rejects_unknown_and_duplicate_ids() {
1740 assert!(
1741 EventGenerator::new(demo_reaction(), HashMap::new(), Some(12345))
1742 .with_storage(GeneratedStorage::only(["beam", "does_not_exist"]))
1743 .is_err()
1744 );
1745 assert!(
1746 EventGenerator::new(demo_reaction(), HashMap::new(), Some(12345))
1747 .with_storage(GeneratedStorage::only(["beam", "beam"]))
1748 .is_err()
1749 );
1750 }
1751
1752 #[test]
1753 fn generated_species_metadata_propagates_to_layout() {
1754 let beam = GeneratedParticle::initial(
1755 "beam",
1756 InitialGenerator::beam_with_fixed_energy(0.0, 8.0),
1757 Reconstruction::Stored,
1758 )
1759 .with_species(ParticleSpecies::code(22));
1760 let target = GeneratedParticle::initial(
1761 "target",
1762 InitialGenerator::target(0.938272),
1763 Reconstruction::Missing,
1764 )
1765 .with_species(ParticleSpecies::with_namespace("pdg", 2212));
1766 let kshort1 = GeneratedParticle::stable(
1767 "kshort1",
1768 StableGenerator::new(0.497611),
1769 Reconstruction::Stored,
1770 )
1771 .with_species(ParticleSpecies::label("KShort"));
1772 let kshort2 = GeneratedParticle::stable(
1773 "kshort2",
1774 StableGenerator::new(0.497611),
1775 Reconstruction::Stored,
1776 )
1777 .with_species(ParticleSpecies::label("KShort"));
1778 let kk = GeneratedParticle::composite(
1779 "kk",
1780 CompositeGenerator::new(1.1, 1.6),
1781 (&kshort1, &kshort2),
1782 Reconstruction::Composite,
1783 )
1784 .with_species(ParticleSpecies::label("KK"));
1785 let recoil = GeneratedParticle::stable(
1786 "recoil",
1787 StableGenerator::new(0.938272),
1788 Reconstruction::Stored,
1789 )
1790 .with_species(ParticleSpecies::code(2212));
1791 let reaction = GeneratedReaction::two_to_two(
1792 beam,
1793 target,
1794 kk,
1795 recoil,
1796 MandelstamTDistribution::Exponential { slope: 0.1 },
1797 )
1798 .unwrap();
1799 let particles = reaction.particle_layouts();
1800
1801 assert_eq!(particles[0].species(), Some(&ParticleSpecies::code(22)));
1802 assert_eq!(
1803 particles[1].species(),
1804 Some(&ParticleSpecies::with_namespace("pdg", 2212))
1805 );
1806 assert_eq!(particles[2].species(), Some(&ParticleSpecies::label("KK")));
1807 assert_eq!(
1808 particles[3].species(),
1809 Some(&ParticleSpecies::label("KShort"))
1810 );
1811 assert_eq!(
1812 particles[4].species(),
1813 Some(&ParticleSpecies::label("KShort"))
1814 );
1815 assert_eq!(particles[5].species(), Some(&ParticleSpecies::code(2212)));
1816 }
1817
1818 #[test]
1819 fn generated_batches_match_one_shot_generation() {
1820 let generated = demo_reaction();
1821 let generator = EventGenerator::new(
1822 generated,
1823 HashMap::from([(
1824 "pol_angle".to_string(),
1825 Distribution::Uniform { min: 0.0, max: 1.0 },
1826 )]),
1827 Some(12345),
1828 );
1829 let one_shot = generator.generate_dataset(7).unwrap();
1830 let batches = generator
1831 .generate_batches(7, 3)
1832 .unwrap()
1833 .collect::<LadduResult<Vec<_>>>()
1834 .unwrap();
1835 let batch_sizes = batches
1836 .iter()
1837 .map(|batch| batch.dataset().n_events())
1838 .collect::<Vec<_>>();
1839 assert_eq!(batch_sizes, vec![3, 3, 1]);
1840
1841 let mut offset = 0;
1842 for batch in batches {
1843 for local_index in 0..batch.dataset().n_events() {
1844 let expected = one_shot.event_global(offset + local_index).unwrap();
1845 let actual = batch.dataset().event_global(local_index).unwrap();
1846 for name in one_shot.p4_names() {
1847 assert_relative_eq!(
1848 actual.p4(name).unwrap(),
1849 expected.p4(name).unwrap(),
1850 epsilon = 1e-10
1851 );
1852 }
1853 for aux_index in 0..one_shot.aux_names().len() {
1854 assert_relative_eq!(actual.aux[aux_index], expected.aux[aux_index]);
1855 }
1856 assert_relative_eq!(actual.weight(), expected.weight());
1857 }
1858 offset += batch.dataset().n_events();
1859 }
1860 assert_eq!(offset, one_shot.n_events());
1861 assert!(generator.generate_batches(1, 0).is_err());
1862 }
1863
1864 #[test]
1865 fn fixed_envelope_rejection_sampler_streams_accepted_batches() {
1866 let generator = EventGenerator::new(demo_reaction(), HashMap::new(), Some(12345));
1867 let sampler = RejectionSampler::new(
1868 generator,
1869 |batch: &GeneratedBatch| Ok(vec![1.0; batch.dataset().n_events()]),
1870 RejectionSamplingOptions {
1871 target_accepted: 5,
1872 generation_batch_size: 4,
1873 output_batch_size: 2,
1874 envelope: RejectionEnvelope::Fixed { max_weight: 1.0 },
1875 seed: 67890,
1876 },
1877 )
1878 .unwrap();
1879
1880 let mut iter = sampler.accepted_batches();
1881 let mut accepted_batches = Vec::new();
1882 for batch in iter.by_ref() {
1883 accepted_batches.push(batch.unwrap());
1884 }
1885 assert_eq!(
1886 accepted_batches
1887 .iter()
1888 .map(|batch| batch.dataset().n_events())
1889 .collect::<Vec<_>>(),
1890 vec![2, 2, 1]
1891 );
1892 assert_eq!(iter.diagnostics().generated_events, 8);
1893 assert_eq!(iter.diagnostics().accepted_events, 5);
1894 assert_eq!(iter.diagnostics().rejected_events, 0);
1895 assert_relative_eq!(iter.diagnostics().acceptance_efficiency(), 5.0 / 8.0);
1896 for batch in accepted_batches {
1897 assert_eq!(
1898 batch.layout().p4_labels(),
1899 &["beam", "target", "kk", "kshort1", "kshort2", "recoil"]
1900 );
1901 }
1902 }
1903
1904 #[test]
1905 fn fixed_envelope_rejection_sampler_rejects_violations() {
1906 let generator = EventGenerator::new(demo_reaction(), HashMap::new(), Some(12345));
1907 let sampler = RejectionSampler::new(
1908 generator,
1909 |batch: &GeneratedBatch| Ok(vec![2.0; batch.dataset().n_events()]),
1910 RejectionSamplingOptions {
1911 target_accepted: 1,
1912 generation_batch_size: 1,
1913 output_batch_size: 1,
1914 envelope: RejectionEnvelope::Fixed { max_weight: 1.0 },
1915 seed: 67890,
1916 },
1917 )
1918 .unwrap();
1919
1920 let mut iter = sampler.accepted_batches();
1921 let err = iter.next().expect("sampler should produce an error");
1922 assert!(err.is_err());
1923 assert_eq!(iter.diagnostics().envelope_violations, 1);
1924 assert_relative_eq!(iter.diagnostics().max_observed_weight, 2.0);
1925 }
1926
1927 #[test]
1928 fn duplicate_generated_particle_ids_are_rejected() {
1929 let beam = GeneratedParticle::initial(
1930 "beam",
1931 InitialGenerator::beam_with_fixed_energy(0.0, 8.0),
1932 Reconstruction::Stored,
1933 );
1934 let target = GeneratedParticle::initial(
1935 "target",
1936 InitialGenerator::target(0.938272),
1937 Reconstruction::Missing,
1938 );
1939 let duplicate = GeneratedParticle::stable(
1940 "beam",
1941 StableGenerator::new(0.497611),
1942 Reconstruction::Stored,
1943 );
1944 let recoil = GeneratedParticle::stable(
1945 "recoil",
1946 StableGenerator::new(0.938272),
1947 Reconstruction::Stored,
1948 );
1949
1950 assert!(GeneratedReaction::two_to_two(
1951 beam,
1952 target,
1953 duplicate,
1954 recoil,
1955 MandelstamTDistribution::Exponential { slope: 0.1 },
1956 )
1957 .is_err());
1958 }
1959}