Skip to main content

laddu_generation/
distributions.rs

1use fastrand::Rng;
2use fastrand_contrib::RngExt;
3use laddu_core::{math::Histogram, LadduResult, Vec3, Vec4};
4
5/// Sampler for drawing values from a weighted histogram.
6#[derive(Clone, Debug)]
7pub struct HistogramSampler {
8    pub(crate) hist: Histogram,
9    cdf: Vec<f64>,
10    total: f64,
11}
12
13impl HistogramSampler {
14    /// Construct a histogram sampler.
15    pub fn new(hist: Histogram) -> LadduResult<Self> {
16        hist.validate()?;
17        hist.validate_positive_counts()?;
18        let mut cdf = Vec::with_capacity(hist.counts().len());
19        let mut total = 0.0;
20
21        for &count in hist.counts() {
22            total += count;
23            cdf.push(total);
24        }
25        Ok(Self { hist, cdf, total })
26    }
27
28    /// Sample a value uniformly within a histogram bin selected by bin weight.
29    pub fn sample(&self, rng: &mut Rng) -> f64 {
30        let r = rng.f64() * self.total;
31        let bin = self.cdf.partition_point(|&x| x <= r);
32        let lo = self.hist.bin_edges()[bin];
33        let hi = self.hist.bin_edges()[bin + 1];
34        lo + rng.f64() * (hi - lo)
35    }
36}
37
38#[derive(Clone, Debug)]
39pub enum SimpleDistribution {
40    Fixed(f64),
41    Uniform { min: f64, max: f64 },
42    Histogram(HistogramSampler),
43}
44impl SimpleDistribution {
45    pub fn sample(&self, rng: &mut Rng) -> f64 {
46        match self {
47            Self::Fixed(v) => *v,
48            Self::Uniform { min, max } => rng.uniform(*min, *max),
49            Self::Histogram(sampler) => sampler.sample(rng),
50        }
51    }
52}
53
54#[derive(Clone, Debug)]
55pub enum MandelstamTDistribution {
56    Exponential { slope: f64 },
57    Histogram(HistogramSampler),
58}
59impl MandelstamTDistribution {
60    pub fn sample(&self, rng: &mut Rng) -> f64 {
61        match self {
62            Self::Exponential { slope } => -rng.exponential(*slope),
63            Self::Histogram(sampler) => sampler.sample(rng),
64        }
65    }
66}
67
68#[derive(Clone, Debug)]
69pub enum Distribution {
70    Fixed(f64),
71    Uniform { min: f64, max: f64 },
72    Normal { mu: f64, sigma: f64 },
73    Exponential { slope: f64 },
74    Histogram(HistogramSampler),
75}
76impl Distribution {
77    pub fn sample(&self, rng: &mut Rng) -> f64 {
78        match self {
79            Self::Fixed(v) => *v,
80            Self::Uniform { min, max } => rng.uniform(*min, *max),
81            Self::Normal { mu, sigma } => rng.normal(*mu, *sigma),
82            Self::Exponential { slope } => rng.exponential(*slope),
83            Self::Histogram(hist) => hist.sample(rng),
84        }
85    }
86}
87
88pub trait LadduGenRngExt {
89    fn uniform(&mut self, min: f64, max: f64) -> f64;
90    fn normal(&mut self, mu: f64, sigma: f64) -> f64;
91    fn exponential(&mut self, slope: f64) -> f64;
92    fn p4(&mut self, mass: f64, energy: f64, direction: Vec3) -> Vec4;
93}
94
95impl LadduGenRngExt for Rng {
96    fn uniform(&mut self, min: f64, max: f64) -> f64 {
97        self.f64_range(min..=max)
98    }
99
100    fn normal(&mut self, mu: f64, sigma: f64) -> f64 {
101        self.f64_normal_approx(mu, sigma)
102    }
103
104    fn exponential(&mut self, slope: f64) -> f64 {
105        -(-self.f64()).ln_1p() / slope
106    }
107    fn p4(&mut self, mass: f64, energy: f64, direction: Vec3) -> Vec4 {
108        debug_assert!(
109            energy >= mass,
110            "Mass cannot be greater than energy!\nEnergy: {}\nMass: {}",
111            energy,
112            mass
113        );
114        let momentum = ((energy - mass) * (energy + mass)).max(0.0).sqrt();
115        (momentum * direction).with_mass(mass)
116    }
117}