Skip to main content

rv/dist/
empirical.rs

1#[cfg(feature = "serde1")]
2use serde::{Deserialize, Serialize};
3
4use crate::traits::{
5    Cdf, HasDensity, Mean, Parameterized, Sampleable, Variance,
6};
7use rand::Rng;
8
/// A distribution described directly by a collection of observed samples.
///
/// The observations are kept in a vector alongside a `(min, max)` pair
/// recording the smallest and largest observation.
///
/// __WARNING__: The `ln_f` and `f` methods are poor approximations.
/// Both are likely to have unbounded errors.
///
/// # Example
///
/// ```rust
/// use rv::dist::{Gaussian, Empirical};
/// use rv::prelude::*;
/// use rv::misc::linspace;
/// use rand_xoshiro::Xoshiro256Plus;
/// use rand::SeedableRng;
///
/// let mut rng = Xoshiro256Plus::seed_from_u64(0xABCD);
/// let dist = Gaussian::standard();
///
/// let sample: Vec<f64> = dist.sample(1000, &mut rng);
/// let emp_dist = Empirical::new(sample);
///
/// let ln_f_err: Vec<f64> = linspace(emp_dist.range().0, emp_dist.range().1, 1000)
///     .iter()
///     .map(|x| {
///         dist.ln_f(x) - emp_dist.ln_f(x)
///     }).collect();
/// ```
#[derive(Clone, Debug, PartialEq, PartialOrd)]
#[cfg_attr(
    feature = "serde1",
    derive(Serialize, Deserialize),
    serde(rename_all = "snake_case")
)]
pub struct Empirical {
    /// The observed values (sorted ascending when built through `new`).
    xs: Vec<f64>,
    /// The `(smallest, largest)` observation.
    range: (f64, f64),
}
40
/// Where a query value falls relative to the stored observations.
#[derive(Debug, Copy, Clone)]
enum Pos {
    /// The query is strictly below the smallest observation.
    First,
    /// The query is at or above the largest observation.
    Last,
    /// The query equals the observation stored at this index.
    Present(usize),
    /// The query matches no observation; the index is where it would have to
    /// be inserted to keep the observations sorted.
    Absent(usize),
}
48
/// The parameters of an [`Empirical`] distribution: its observed values.
///
/// Produced by `Empirical::emit_params` and consumed by
/// `Empirical::from_params`.
#[derive(Debug, Clone, PartialEq)]
pub struct EmpiricalParameters {
    /// The observed values that define the distribution.
    pub xs: Vec<f64>,
}
52
53impl Parameterized for Empirical {
54    type Parameters = EmpiricalParameters;
55
56    fn emit_params(&self) -> Self::Parameters {
57        Self::Parameters {
58            xs: self.xs.clone(),
59        }
60    }
61
62    fn from_params(params: Self::Parameters) -> Self {
63        Self::new(params.xs)
64    }
65}
66
67impl Empirical {
68    /// Create a new Empirical distribution with the given observed values
69    #[must_use]
70    pub fn new(mut xs: Vec<f64>) -> Self {
71        xs.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
72        let min = xs[0];
73        let max = xs[xs.len() - 1];
74        Empirical {
75            xs,
76            range: (min, max),
77        }
78    }
79
80    fn pos(&self, x: f64) -> Pos {
81        if x < self.range.0 {
82            Pos::First
83        } else if x >= self.range.1 {
84            Pos::Last
85        } else {
86            self.xs
87                .binary_search_by(|&probe| probe.partial_cmp(&x).unwrap())
88                .map_or_else(Pos::Absent, Pos::Present)
89        }
90    }
91
92    /// Return the CDF at X
93    fn empcdf(&self, pos: Pos) -> f64 {
94        match pos {
95            Pos::First => 0.0,
96            Pos::Last => 1.0,
97            Pos::Present(ix) => ix as f64 / self.xs.len() as f64,
98            Pos::Absent(ix) => ix as f64 / self.xs.len() as f64,
99        }
100    }
101
102    /// Compute the CDF of a number of values
103    #[must_use]
104    pub fn empcdfs(&self, values: &[f64]) -> Vec<f64> {
105        values
106            .iter()
107            .map(|&value| {
108                let pos = self.pos(value);
109                self.empcdf(pos)
110            })
111            .collect()
112    }
113
114    /// A utility for computing a P-P plot.
115    #[must_use]
116    pub fn pp(&self, other: &Self) -> (Vec<f64>, Vec<f64>) {
117        let mut xys = self.xs.clone();
118        xys.append(&mut other.xs.clone());
119        xys.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
120        (self.empcdfs(&xys), other.empcdfs(&xys))
121    }
122
123    /// Area between CDF-CDF (1-1) line
124    #[must_use]
125    pub fn err(&self, other: &Self) -> f64 {
126        let (fxs, fys) = self.pp(other);
127        let diff: Vec<f64> = fxs
128            .iter()
129            .zip(fys.iter())
130            .map(|(fx, fy)| (fx - fy).abs())
131            .collect();
132
133        let mut q = 0.0;
134        for i in 1..fxs.len() {
135            let step = fxs[i] - fxs[i - 1];
136            let trap = diff[i] + diff[i - 1];
137            q += step * trap;
138        }
139        q / 2.0
140    }
141
142    /// Return the range of non-zero support for this distribution.
143    #[must_use]
144    pub fn range(&self) -> &(f64, f64) {
145        &self.range
146    }
147}
148
149impl HasDensity<f64> for Empirical {
150    fn f(&self, x: &f64) -> f64 {
151        eprintln!(
152            "WARNING: empirical.f is unstable. You probably don't want to use it."
153        );
154        match self.pos(*x) {
155            Pos::First => 0.0,
156            Pos::Last => 0.0,
157            Pos::Present(0) => 0.0,
158            Pos::Present(ix) => {
159                let cdf_x = self.empcdf(Pos::Present(ix));
160                let cdf_y = self.empcdf(Pos::Present(ix - 1));
161                let y = self.xs[ix - 1];
162                let h = x - y;
163                (cdf_x - cdf_y) / h
164            }
165            Pos::Absent(ix) => {
166                let cdf_x = self.empcdf(Pos::Absent(ix));
167                let cdf_y = self.empcdf(Pos::Present(ix - 1));
168                let y = self.xs[ix - 1];
169                let h = x - y;
170                (cdf_x - cdf_y) / h
171            }
172        }
173    }
174
175    fn ln_f(&self, x: &f64) -> f64 {
176        self.f(x).ln()
177    }
178}
179
180impl Sampleable<f64> for Empirical {
181    fn draw<R: Rng>(&self, rng: &mut R) -> f64 {
182        let n = self.xs.len();
183        let ix: usize = rng.random_range(0..n);
184        self.xs[ix]
185    }
186}
187
188impl Cdf<f64> for Empirical {
189    fn cdf(&self, x: &f64) -> f64 {
190        let pos = self.pos(*x);
191        self.empcdf(pos)
192    }
193}
194
195impl Mean<f64> for Empirical {
196    fn mean(&self) -> Option<f64> {
197        let n = self.xs.len() as f64;
198        Some(self.xs.iter().sum::<f64>() / n)
199    }
200}
201
202impl Variance<f64> for Empirical {
203    fn variance(&self) -> Option<f64> {
204        let n = self.xs.len() as f64;
205        self.mean().map(|m| {
206            self.xs.iter().map(|&x| (x - m) * (x - m)).sum::<f64>() / n
207        })
208    }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{dist::Gaussian, misc::linspace};
    use rand::SeedableRng;
    use rand_xoshiro::Xoshiro256Plus;

    /// Largest absolute value among `errs`.
    fn max_abs(errs: &[f64]) -> f64 {
        errs.iter()
            .map(|e| e.abs())
            .max_by(|a, b| a.partial_cmp(b).unwrap())
            .unwrap()
    }

    // Compares the empirical density and CDF of a large standard Gaussian
    // sample against the exact ones on a grid spanning the sample range.
    #[test]
    #[ignore = "This failure is expected, ln_f should not be used."]
    fn gaussian_sample() {
        let mut rng = Xoshiro256Plus::seed_from_u64(0xABCD);
        let dist = Gaussian::standard();
        let sample: Vec<f64> = dist.sample(10000, &mut rng);
        let emp_dist = Empirical::new(sample);

        let &(lo, hi) = emp_dist.range();
        let mut f_errs: Vec<f64> = Vec::new();
        let mut cdf_errs: Vec<f64> = Vec::new();
        for x in linspace(lo, hi, 1000) {
            let ft = dist.f(&x);
            let fe = emp_dist.f(&x);
            let cdf_t = dist.cdf(&x);
            let cdf_e = emp_dist.cdf(&x);
            f_errs.push(fe - ft);
            cdf_errs.push(cdf_e - cdf_t);
        }

        let max_f_err = max_abs(&f_errs);
        let max_cdf_err = max_abs(&cdf_errs);

        assert!(max_cdf_err < 1E-5);
        assert!(max_f_err < 1E-5);
    }

    // Only checks that repeated draws run without panicking.
    #[test]
    fn draw_smoke() {
        let mut rng = rand::rng();
        // Use just a few observations so that the draws cover all of them.
        let emp_dist = Empirical::new(vec![0.0, 1.0, 2.0]);

        (0..1_000).for_each(|_| {
            let _x: f64 = emp_dist.draw(&mut rng);
        });
    }

    // Round-tripping through the parameter bundle must reproduce the
    // distribution.
    #[test]
    fn emit_and_from_params_are_identity() {
        let original = Empirical::new(vec![0.0, 0.2, 0.3]);
        let rebuilt = Empirical::from_params(original.emit_params());
        assert_eq!(original, rebuilt);
    }
}