Skip to main content

p_square/
lib.rs

1//! The P-Square (P2) algorithm for dynamic calculation of quantiles without
2//! storing observations.
3//!
4//! An implementation based on the algorithm described in this [paper].
5//!
6//! This algorithm calculates estimates for percentiles of observation sets
7//! dynamically, with O(1) space complexity.
8//!
9//! # Examples
10//!
11//! ```
12//! # use p_square::P2;
13//! #
14//! let mut p2 = P2::new(0.3);
15//!
16//! for n in 1..=100 {
17//!     p2.feed(n as f64);
18//! }
19//!
20//! assert_eq!(p2.estimate(), 30.0);
21//! ```
22//!
23//! [paper]: https://www.cse.wustl.edu/~jain/papers/ftp/psqr.pdf
24
/// Number of markers tracked by the P-Square state: the minimum, the
/// `p/2`-, `p`- and `(1+p)/2`-quantile estimates, and the maximum.
const MARKERS_COUNT: usize = 5;
26
/// Marker indices for `q` (the marker heights), ordered from the smallest
/// marker to the largest.
mod marker_index {
    /// Minimum of the observations so far.
    pub(super) const MINIMUM: usize = 0;
    /// Current estimate of the `p/2`-quantile.
    pub(super) const LOWER_MEDIAN: usize = 1;
    /// Current estimate of the `p`-quantile.
    pub(super) const QUANTILE: usize = 2;
    /// Current estimate of the `(1+p)/2`-quantile.
    pub(super) const UPPER_MEDIAN: usize = 3;
    /// Maximum of the observations so far.
    pub(super) const MAXIMUM: usize = 4;
}
40
41#[derive(Clone)]
42pub struct P2 {
43    quantile: f64,
44
45    // q
46    heights: [f64; MARKERS_COUNT],
47    // n
48    positions: [f64; MARKERS_COUNT],
49    // n'
50    desired_positions: [f64; MARKERS_COUNT],
51    // dn'
52    increments: [f64; MARKERS_COUNT],
53
54    observations_counter: usize,
55}
56
57impl P2 {
58    /// Construct a new P2 state for estimating the `quantile` of the
59    /// observations.
60    ///
61    /// See the [crate](crate) documentation for more.
62    ///
63    /// # Panics
64    ///
65    /// Will panic if `quantile` is not in the range `0.0..=1.0`.
66    ///
67    /// # Examples
68    ///
69    /// ```
70    /// use p_square::P2;
71    ///
72    /// // A P-Square state to estimate a median.
73    /// let mut p2 = P2::new(0.5);
74    /// ```
75    pub const fn new(quantile: f64) -> Self {
76        assert!(
77            0.0 <= quantile && quantile <= 1.0,
78            "quantile must be in the range 0.0..=1.0"
79        );
80
81        let heights = [0.0; MARKERS_COUNT];
82        let positions = [0.0, 1.0, 2.0, 3.0, 4.0];
83        let desired_positions = [
84            0.0,
85            2.0 * quantile,
86            4.0 * quantile,
87            2.0 + 2.0 * quantile,
88            4.0,
89        ];
90        let increments = [0.0, quantile / 2.0, quantile, (1.0 + quantile) / 2.0, 1.0];
91
92        Self {
93            quantile,
94            heights,
95            positions,
96            desired_positions,
97            increments,
98            observations_counter: 0,
99        }
100    }
101
102    /// Current estimate of the desired quantile.
103    pub fn estimate(&self) -> f64 {
104        // PERF: this might be a good use case for hint::unlikely (but it's not
105        // stable).
106        if self.observations_counter <= MARKERS_COUNT {
107            // XXX: is this the best way to handle this situation?
108
109            let len = self.observations_counter;
110            let Some(max_idx) = len.checked_sub(1) else {
111                return 0.0;
112            };
113
114            let index = (max_idx as f64 * self.quantile).round() as usize;
115            debug_assert!(index <= max_idx, "quantile <= 1");
116
117            let initialized_heights = &mut self.heights.clone()[..len];
118
119            let (_lesser, v, _greater) =
120                initialized_heights.select_nth_unstable_by(index, |a, b| a.total_cmp(b));
121
122            *v
123        } else {
124            self.q3()
125        }
126    }
127
128    /// Feed a new observation.
129    pub fn feed(&mut self, observation: f64) {
130        let j = self.observations_counter;
131        self.observations_counter = self.observations_counter.saturating_add(1);
132
133        if j < MARKERS_COUNT {
134            self.heights[j] = observation;
135
136            if j + 1 == MARKERS_COUNT {
137                self.heights.sort_unstable_by(|a, b| a.total_cmp(b));
138            }
139
140            return;
141        }
142
143        // B.1.
144        let k: usize = if observation < self.q1() {
145            self.heights[marker_index::MINIMUM] = observation;
146            0
147        } else if self.q1() <= observation && observation < self.q2() {
148            0
149        } else if self.q2() <= observation && observation < self.q3() {
150            1
151        } else if self.q3() <= observation && observation < self.q4() {
152            2
153        } else if self.q4() <= observation && observation <= self.q5() {
154            3
155        } else if self.q5() < observation {
156            self.heights[marker_index::MAXIMUM] = observation;
157            3
158        } else {
159            unreachable!();
160        };
161
162        // B.2.
163        for n in self.positions.iter_mut().skip(k + 1) {
164            *n += 1.0;
165        }
166
167        for (n, d) in self.desired_positions.iter_mut().zip(self.increments) {
168            *n += d;
169        }
170
171        // B.3.
172        for i in marker_index::LOWER_MEDIAN..=marker_index::UPPER_MEDIAN {
173            let d = self.np(i) - self.n(i);
174
175            if (d >= 1.0 && self.n(i + 1) - self.n(i) > 1.0)
176                || (d <= -1.0 && self.n(i - 1) - self.n(i) < -1.0)
177            {
178                let d_sign = d.signum();
179                let qp = self.parabolic(i, d_sign);
180
181                self.heights[i] = if self.q(i - 1) < qp && qp < self.q(i + 1) {
182                    qp
183                } else {
184                    self.linear(i, d_sign)
185                };
186
187                self.positions[i] += d_sign;
188            }
189        }
190    }
191
192    const fn parabolic(&self, i: usize, d: f64) -> f64 {
193        self.q(i)
194            + (d / (self.n(i + 1) - self.n(i - 1)))
195                * ((self.n(i) - self.n(i - 1) + d)
196                    * ((self.q(i + 1) - self.q(i)) / (self.n(i + 1) - self.n(i)))
197                    + (self.n(i + 1) - self.n(i) - d)
198                        * ((self.q(i) - self.q(i - 1)) / (self.n(i) - self.n(i - 1))))
199    }
200
201    const fn linear(&self, i: usize, d: f64) -> f64 {
202        let i_plus_d = (i as i64 + d as i64) as usize;
203
204        self.q(i) + d * ((self.q(i_plus_d) - self.q(i)) / (self.n(i_plus_d) - self.n(i)))
205    }
206
207    // Helper getters
208
209    #[inline]
210    const fn q(&self, i: usize) -> f64 {
211        self.heights[i]
212    }
213
214    #[inline]
215    const fn n(&self, i: usize) -> f64 {
216        self.positions[i]
217    }
218
219    #[inline]
220    const fn np(&self, i: usize) -> f64 {
221        self.desired_positions[i]
222    }
223
224    #[inline]
225    fn q1(&self) -> f64 {
226        self.heights[marker_index::MINIMUM]
227    }
228
229    #[inline]
230    fn q2(&self) -> f64 {
231        self.heights[marker_index::LOWER_MEDIAN]
232    }
233
234    #[inline]
235    fn q3(&self) -> f64 {
236        self.heights[marker_index::QUANTILE]
237    }
238
239    #[inline]
240    fn q4(&self) -> f64 {
241        self.heights[marker_index::UPPER_MEDIAN]
242    }
243
244    #[inline]
245    fn q5(&self) -> f64 {
246        self.heights[marker_index::MINIMUM]
247    }
248}
249
250pub fn from_iter<I>(quantile: f64, iter: I) -> f64
251where
252    I: Iterator<Item = f64>,
253{
254    let mut state = P2::new(quantile);
255
256    iter.for_each(|observation| state.feed(observation));
257
258    state.estimate()
259}
260
#[cfg(test)]
mod tests {
    use super::*;

    /// With at most five observations the estimate is selected directly
    /// among the observations fed so far.
    #[test]
    fn before_initialization() {
        // Estimate of `quantile` over the observations `1.0..=count`.
        let estimate_of =
            |quantile: f64, count: u32| from_iter(quantile, (1..=count).map(f64::from));

        assert_eq!(estimate_of(0.5, 1), 1.0);
        assert_eq!(estimate_of(0.7, 2), 2.0);
        assert_eq!(estimate_of(0.5, 3), 2.0);
        assert_eq!(estimate_of(0.25, 4), 2.0);
        assert_eq!(estimate_of(0.6, 5), 3.0);
    }

    /// A quantile outside `0.0..=1.0` is rejected at construction time.
    #[test]
    #[should_panic(expected = "quantile must be in the range 0.0..=1.0")]
    fn invalid_quantile() {
        let _ = P2::new(1.2);
    }
}