1#[cfg(feature = "serde1")]
2use serde::{Deserialize, Serialize};
3
4use crate::traits::{
5 Cdf, HasDensity, Mean, Parameterized, Sampleable, Variance,
6};
7use rand::Rng;
8
/// Empirical distribution over `f64`, backed by a stored collection of
/// samples.
#[derive(Debug, Clone, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub struct Empirical {
    /// The samples; `Empirical::new` sorts them in ascending order before
    /// storing them here.
    xs: Vec<f64>,
    /// First and last element of the sorted samples, i.e. `(min, max)`,
    /// cached by `Empirical::new`.
    range: (f64, f64),
}
40
/// Location of a query value relative to the stored samples, as computed by
/// `Empirical::pos`.
#[derive(Clone, Copy, Debug)]
enum Pos {
    /// The query is strictly below the lower end of the range.
    First,
    /// The query is at or above the upper end of the range.
    Last,
    /// The query equals a stored sample; carries the index of a matching
    /// sample.
    Present(usize),
    /// The query lies inside the range but matches no stored sample; carries
    /// the index at which it could be inserted while keeping the samples
    /// sorted.
    Absent(usize),
}
48
/// Parameter bundle for [`Empirical`]: the raw samples the distribution is
/// built from.
///
/// Derives `Debug`, `Clone` and `PartialEq` so parameter sets can be logged,
/// duplicated and compared like any other public value type.
#[derive(Debug, Clone, PartialEq)]
pub struct EmpiricalParameters {
    /// The samples. They do not need to be sorted.
    pub xs: Vec<f64>,
}
52
53impl Parameterized for Empirical {
54 type Parameters = EmpiricalParameters;
55
56 fn emit_params(&self) -> Self::Parameters {
57 Self::Parameters {
58 xs: self.xs.clone(),
59 }
60 }
61
62 fn from_params(params: Self::Parameters) -> Self {
63 Self::new(params.xs)
64 }
65}
66
67impl Empirical {
68 #[must_use]
70 pub fn new(mut xs: Vec<f64>) -> Self {
71 xs.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
72 let min = xs[0];
73 let max = xs[xs.len() - 1];
74 Empirical {
75 xs,
76 range: (min, max),
77 }
78 }
79
80 fn pos(&self, x: f64) -> Pos {
81 if x < self.range.0 {
82 Pos::First
83 } else if x >= self.range.1 {
84 Pos::Last
85 } else {
86 self.xs
87 .binary_search_by(|&probe| probe.partial_cmp(&x).unwrap())
88 .map_or_else(Pos::Absent, Pos::Present)
89 }
90 }
91
92 fn empcdf(&self, pos: Pos) -> f64 {
94 match pos {
95 Pos::First => 0.0,
96 Pos::Last => 1.0,
97 Pos::Present(ix) => ix as f64 / self.xs.len() as f64,
98 Pos::Absent(ix) => ix as f64 / self.xs.len() as f64,
99 }
100 }
101
102 #[must_use]
104 pub fn empcdfs(&self, values: &[f64]) -> Vec<f64> {
105 values
106 .iter()
107 .map(|&value| {
108 let pos = self.pos(value);
109 self.empcdf(pos)
110 })
111 .collect()
112 }
113
114 #[must_use]
116 pub fn pp(&self, other: &Self) -> (Vec<f64>, Vec<f64>) {
117 let mut xys = self.xs.clone();
118 xys.append(&mut other.xs.clone());
119 xys.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
120 (self.empcdfs(&xys), other.empcdfs(&xys))
121 }
122
123 #[must_use]
125 pub fn err(&self, other: &Self) -> f64 {
126 let (fxs, fys) = self.pp(other);
127 let diff: Vec<f64> = fxs
128 .iter()
129 .zip(fys.iter())
130 .map(|(fx, fy)| (fx - fy).abs())
131 .collect();
132
133 let mut q = 0.0;
134 for i in 1..fxs.len() {
135 let step = fxs[i] - fxs[i - 1];
136 let trap = diff[i] + diff[i - 1];
137 q += step * trap;
138 }
139 q / 2.0
140 }
141
142 #[must_use]
144 pub fn range(&self) -> &(f64, f64) {
145 &self.range
146 }
147}
148
149impl HasDensity<f64> for Empirical {
150 fn f(&self, x: &f64) -> f64 {
151 eprintln!(
152 "WARNING: empirical.f is unstable. You probably don't want to use it."
153 );
154 match self.pos(*x) {
155 Pos::First => 0.0,
156 Pos::Last => 0.0,
157 Pos::Present(0) => 0.0,
158 Pos::Present(ix) => {
159 let cdf_x = self.empcdf(Pos::Present(ix));
160 let cdf_y = self.empcdf(Pos::Present(ix - 1));
161 let y = self.xs[ix - 1];
162 let h = x - y;
163 (cdf_x - cdf_y) / h
164 }
165 Pos::Absent(ix) => {
166 let cdf_x = self.empcdf(Pos::Absent(ix));
167 let cdf_y = self.empcdf(Pos::Present(ix - 1));
168 let y = self.xs[ix - 1];
169 let h = x - y;
170 (cdf_x - cdf_y) / h
171 }
172 }
173 }
174
175 fn ln_f(&self, x: &f64) -> f64 {
176 self.f(x).ln()
177 }
178}
179
180impl Sampleable<f64> for Empirical {
181 fn draw<R: Rng>(&self, rng: &mut R) -> f64 {
182 let n = self.xs.len();
183 let ix: usize = rng.random_range(0..n);
184 self.xs[ix]
185 }
186}
187
188impl Cdf<f64> for Empirical {
189 fn cdf(&self, x: &f64) -> f64 {
190 let pos = self.pos(*x);
191 self.empcdf(pos)
192 }
193}
194
195impl Mean<f64> for Empirical {
196 fn mean(&self) -> Option<f64> {
197 let n = self.xs.len() as f64;
198 Some(self.xs.iter().sum::<f64>() / n)
199 }
200}
201
202impl Variance<f64> for Empirical {
203 fn variance(&self) -> Option<f64> {
204 let n = self.xs.len() as f64;
205 self.mean().map(|m| {
206 self.xs.iter().map(|&x| (x - m) * (x - m)).sum::<f64>() / n
207 })
208 }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{dist::Gaussian, misc::linspace};
    use rand::SeedableRng;
    use rand_xoshiro::Xoshiro256Plus;

    /// Largest absolute value in `errs`; panics on an empty slice or when two
    /// entries cannot be ordered.
    fn max_abs(errs: &[f64]) -> f64 {
        errs.iter()
            .map(|e| e.abs())
            .max_by(|a, b| a.partial_cmp(b).unwrap())
            .unwrap()
    }

    // Compares the empirical density and CDF of a standard Gaussian sample
    // against the true ones on a grid spanning the sample range.
    #[test]
    #[ignore = "This failure is expected, ln_f should not be used."]
    fn gaussian_sample() {
        let mut rng = Xoshiro256Plus::seed_from_u64(0xABCD);
        let truth = Gaussian::standard();
        let draws: Vec<f64> = truth.sample(10000, &mut rng);
        let fitted = Empirical::new(draws);

        let &(lo, hi) = fitted.range();
        let mut f_errs: Vec<f64> = Vec::new();
        let mut cdf_errs: Vec<f64> = Vec::new();
        for x in linspace(lo, hi, 1000) {
            let f_true = truth.f(&x);
            let f_emp = fitted.f(&x);
            let cdf_true = truth.cdf(&x);
            let cdf_emp = fitted.cdf(&x);
            f_errs.push(f_emp - f_true);
            cdf_errs.push(cdf_emp - cdf_true);
        }

        let worst_f = max_abs(&f_errs);
        let worst_cdf = max_abs(&cdf_errs);

        assert!(worst_cdf < 1E-5);
        assert!(worst_f < 1E-5);
    }

    // Drawing repeatedly from a small sample set must not panic.
    #[test]
    fn draw_smoke() {
        let emp_dist = Empirical::new(vec![0.0, 1.0, 2.0]);
        let mut rng = rand::rng();

        (0..1_000).for_each(|_| {
            let _x: f64 = emp_dist.draw(&mut rng);
        });
    }

    // Emitting parameters and rebuilding from them yields an equal value.
    #[test]
    fn emit_and_from_params_are_identity() {
        let original = Empirical::new(vec![0.0, 0.2, 0.3]);
        let rebuilt = Empirical::from_params(original.emit_params());
        assert_eq!(original, rebuilt);
    }
}