radiate_utils/stats/
quantile.rs1use crate::Float;
2#[cfg(feature = "serde")]
3use serde::{Deserialize, Serialize};
4use std::cmp::Ordering;
5
6#[derive(Clone, PartialEq)]
20#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
21pub struct Quantile<F: Float = f32> {
22 q: F,
23 heights: [F; 5],
24 positions: [F; 5],
25 desired: [F; 5],
26 increments: [F; 5],
27 count: u32,
28}
29
30impl<F: Float> Quantile<F> {
31 pub fn new(q: F) -> Self {
32 assert!(
33 q > F::ZERO && q < F::ONE,
34 "Quantile q must be in the open interval (0, 1)"
35 );
36 Self {
37 q,
38 heights: [F::ZERO; 5],
39 positions: [F::ZERO; 5],
40 desired: [F::ZERO; 5],
41 increments: Self::compute_increments(q),
42 count: 0,
43 }
44 }
45
46 pub fn q(&self) -> F {
47 self.q
48 }
49
50 pub fn count(&self) -> u32 {
51 self.count
52 }
53
54 pub fn value(&self) -> Option<F> {
58 match self.count {
59 0 => None,
60 n if n < 5 => Some(self.interp_partial(n as usize)),
61 _ => Some(self.heights[2]),
62 }
63 }
64
65 pub fn add(&mut self, x: F) {
66 if self.count < 5 {
67 self.heights[self.count as usize] = x;
68 self.count += 1;
69 if self.count == 5 {
70 self.heights
71 .sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
72 for i in 0..5 {
73 self.positions[i] = F::from(i + 1).unwrap_or(F::ZERO);
74 }
75 self.desired[0] = F::ONE;
76 self.desired[1] = F::ONE + F::TWO * self.q;
77 self.desired[2] = F::ONE + F::FOUR * self.q;
78 self.desired[3] = F::THREE + F::TWO * self.q;
79 self.desired[4] = F::FIVE;
80 }
81 return;
82 }
83
84 let k = self.find_cell(x);
85
86 for i in (k + 1)..5 {
87 self.positions[i] = self.positions[i] + F::ONE;
88 }
89
90 for i in 0..5 {
91 self.desired[i] = self.desired[i] + self.increments[i];
92 }
93
94 for i in 1..4 {
95 let d = self.desired[i] - self.positions[i];
96 let up = self.positions[i + 1] - self.positions[i];
97 let down = self.positions[i - 1] - self.positions[i];
98
99 let sign = if d >= F::ONE && up > F::ONE {
100 F::ONE
101 } else if d <= -F::ONE && down < -F::ONE {
102 -F::ONE
103 } else {
104 continue;
105 };
106
107 let qs = self.parabolic(i, sign);
108 let new_h = if self.heights[i - 1] < qs && qs < self.heights[i + 1] {
109 qs
110 } else {
111 self.linear(i, sign)
112 };
113 self.heights[i] = new_h;
114 self.positions[i] = self.positions[i] + sign;
115 }
116
117 self.count += 1;
118 }
119
120 pub fn clear(&mut self) {
121 self.heights = [F::ZERO; 5];
122 self.positions = [F::ZERO; 5];
123 self.desired = [F::ZERO; 5];
124 self.increments = Self::compute_increments(self.q);
125 self.count = 0;
126 }
127
128 fn compute_increments(q: F) -> [F; 5] {
129 let half = F::from(0.5).unwrap_or(F::ZERO);
130 [F::ZERO, q * half, q, (F::ONE + q) * half, F::ONE]
131 }
132
133 fn find_cell(&mut self, x: F) -> usize {
134 if x < self.heights[0] {
135 self.heights[0] = x;
136 return 0;
137 }
138 if x >= self.heights[4] {
139 self.heights[4] = x;
140 return 3;
141 }
142 let mut k = 0;
143 while k < 3 && x >= self.heights[k + 1] {
144 k += 1;
145 }
146 k
147 }
148
149 fn parabolic(&self, i: usize, d: F) -> F {
150 let n_prev = self.positions[i - 1];
151 let n = self.positions[i];
152 let n_next = self.positions[i + 1];
153 let h_prev = self.heights[i - 1];
154 let h = self.heights[i];
155 let h_next = self.heights[i + 1];
156
157 h + d / (n_next - n_prev)
158 * ((n - n_prev + d) * (h_next - h) / (n_next - n)
159 + (n_next - n - d) * (h - h_prev) / (n - n_prev))
160 }
161
162 fn linear(&self, i: usize, d: F) -> F {
163 let neighbor = if d > F::ZERO { i + 1 } else { i - 1 };
164 self.heights[i]
165 + d * (self.heights[neighbor] - self.heights[i])
166 / (self.positions[neighbor] - self.positions[i])
167 }
168
169 fn interp_partial(&self, n: usize) -> F {
170 let mut buf = [F::ZERO; 5];
171 buf[..n].copy_from_slice(&self.heights[..n]);
172 buf[..n].sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
173
174 if n == 1 {
175 return buf[0];
176 }
177 let q_f = self.q.to_f64().unwrap_or(0.0);
178 let pos = q_f * (n - 1) as f64;
179 let lo = pos.floor() as usize;
180 let hi = pos.ceil() as usize;
181 let frac = F::from(pos - lo as f64).unwrap_or(F::ZERO);
182 buf[lo] + frac * (buf[hi] - buf[lo])
183 }
184}
185
186impl<F: Float> Extend<F> for Quantile<F> {
187 fn extend<T: IntoIterator<Item = F>>(&mut self, iter: T) {
188 for v in iter {
189 self.add(v);
190 }
191 }
192}
193
194impl<F: Float> std::fmt::Debug for Quantile<F> {
195 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
196 f.debug_struct("Quantile")
197 .field("q", &self.q)
198 .field("count", &self.count)
199 .field("value", &self.value())
200 .finish()
201 }
202}
203
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn median_of_uniform_sequence() {
        let mut q = Quantile::<f32>::new(0.5);
        for i in 1..=100 {
            q.add(i as f32);
        }
        let v = q.value().unwrap();
        assert!((v - 50.5).abs() < 2.0, "got {v}");
    }

    #[test]
    fn p95_of_uniform_sequence() {
        let mut q = Quantile::<f32>::new(0.95);
        for i in 1..=1000 {
            q.add(i as f32);
        }
        let v = q.value().unwrap();
        assert!((v - 950.0).abs() < 15.0, "got {v}");
    }

    #[test]
    fn constant_input_returns_constant() {
        let mut q = Quantile::<f32>::new(0.5);
        for _ in 0..50 {
            q.add(7.0);
        }
        assert_eq!(q.value().unwrap(), 7.0);
    }

    #[test]
    fn single_sample_returns_that_sample() {
        let mut q = Quantile::<f32>::new(0.5);
        // Deliberately not `3.14`: clippy's deny-by-default `approx_constant`
        // lint rejects literals that approximate `PI`.
        q.add(2.5);
        assert_eq!(q.value().unwrap(), 2.5);
    }

    #[test]
    fn empty_returns_none() {
        let q = Quantile::<f32>::new(0.5);
        assert!(q.value().is_none());
    }

    #[test]
    fn fewer_than_five_uses_exact_interp() {
        let mut q = Quantile::<f32>::new(0.5);
        q.add(1.0);
        q.add(3.0);
        q.add(5.0);
        assert_eq!(q.value().unwrap(), 3.0);
    }

    #[test]
    fn fewer_than_five_interpolates_non_median_quantile() {
        let mut q = Quantile::<f32>::new(0.25);
        // Unsorted on purpose: the partial estimate must sort its buffer.
        for x in [4.0, 1.0, 3.0, 2.0] {
            q.add(x);
        }
        // Position 0.25 * (4 - 1) = 0.75 lies between sorted samples 1.0 and 2.0.
        assert_eq!(q.value().unwrap(), 1.75);
    }

    #[test]
    fn clear_resets_state() {
        let mut q = Quantile::<f32>::new(0.5);
        for i in 1..=20 {
            q.add(i as f32);
        }
        q.clear();
        assert_eq!(q.count(), 0);
        assert!(q.value().is_none());
    }

    #[test]
    fn extend_from_iter() {
        let mut q = Quantile::<f32>::new(0.5);
        q.extend((1..=100).map(|i| i as f32));
        assert_eq!(q.count(), 100);
        let v = q.value().unwrap();
        assert!((v - 50.5).abs() < 2.0, "got {v}");
    }

    #[test]
    #[should_panic]
    fn rejects_q_zero() {
        let _ = Quantile::<f32>::new(0.0);
    }

    #[test]
    #[should_panic]
    fn rejects_q_one() {
        let _ = Quantile::<f32>::new(1.0);
    }
}