ort_core/
distribution.rs

1use indexmap::IndexMap;
2use std::str::FromStr;
3
/// A percentile stored as a fixed-point integer.
///
/// The wrapped value is the percentage multiplied by 10_000, so it ranges
/// from 0 (0%) up to a maximum of 1_000_000 (100%).
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Percentile(u32);
7
8pub struct Distribution<T = u64> {
9    // A sorted list of percentile-value pairs.
10    percentiles: IndexMap<Percentile, u64>,
11    _marker: std::marker::PhantomData<T>,
12}
13
/// Reasons a distribution could not be built or parsed.
///
/// The type is comparable and cloneable so callers and tests can assert on
/// the exact failure.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum InvalidDistribution {
    /// Values decrease as percentiles increase.
    Unordered,
    /// A value could not be parsed.
    InvalidValue,
    /// A percentile could not be parsed, was missing, or was out of range.
    InvalidPercentile,
}
20
/// Error returned when a number cannot be converted into a percentile.
///
/// The private unit field prevents construction outside this module. The type
/// is comparable and cloneable so callers and tests can assert on it.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct InvalidPercentile(());
23
24// === impl Distribution ===
25
26impl<T> Clone for Distribution<T> {
27    fn clone(&self) -> Self {
28        Self {
29            percentiles: self.percentiles.clone(),
30            _marker: self._marker,
31        }
32    }
33}
34
35impl<T> std::fmt::Debug for Distribution<T> {
36    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37        f.debug_struct("Distribution")
38            .field("percentiles", &self.percentiles)
39            .finish()
40    }
41}
42
43impl<T: Default + Into<u64>> Default for Distribution<T> {
44    fn default() -> Self {
45        let mut percentiles = IndexMap::new();
46        let v = T::default().into();
47        percentiles.entry(Percentile::MIN).or_insert(v);
48        percentiles.entry(Percentile::MAX).or_insert(v);
49        Self {
50            percentiles,
51            _marker: std::marker::PhantomData,
52        }
53    }
54}
55
56impl<T: FromStr + Default + Into<u64>> FromStr for Distribution<T> {
57    type Err = InvalidDistribution;
58
59    fn from_str(s: &str) -> Result<Self, Self::Err> {
60        let pvs = s.split(',').collect::<Vec<_>>();
61        if pvs.len() == 1 {
62            // If there is only a single item, it may not necessarily have a percentile.
63            let pv = match pvs[0].splitn(2, '=').collect::<Vec<_>>().as_slice() {
64                [v] => {
65                    let v = v
66                        .parse::<T>()
67                        .map_err(|_| InvalidDistribution::InvalidValue)?;
68                    (0f32, v)
69                }
70                [p, v] => {
71                    let p = p
72                        .parse::<f32>()
73                        .map_err(|_| InvalidDistribution::InvalidPercentile)?;
74                    let v = v
75                        .parse::<T>()
76                        .map_err(|_| InvalidDistribution::InvalidValue)?;
77                    (p, v)
78                }
79                _ => return Err(InvalidDistribution::InvalidPercentile),
80            };
81            Self::build(Some(pv))
82        } else {
83            let mut pairs = Vec::new();
84            for pv in pvs {
85                let mut pv = pv.splitn(2, '=');
86                match (pv.next(), pv.next()) {
87                    (Some(p), Some(v)) => {
88                        let p = p
89                            .parse::<f32>()
90                            .map_err(|_| InvalidDistribution::InvalidPercentile)?;
91                        let v = v
92                            .parse::<T>()
93                            .map_err(|_| InvalidDistribution::InvalidValue)?;
94                        pairs.push((p, v));
95                    }
96                    _ => return Err(InvalidDistribution::InvalidPercentile),
97                }
98            }
99            Self::build(pairs)
100        }
101    }
102}
103
104impl<T: Default + Into<u64>> Distribution<T> {
105    pub fn build<P>(pairs: impl IntoIterator<Item = (P, T)>) -> Result<Self, InvalidDistribution>
106    where
107        P: std::convert::TryInto<Percentile>,
108    {
109        let mut percentiles = IndexMap::new();
110        for (p, v) in pairs.into_iter() {
111            let p = p
112                .try_into()
113                .map_err(|_| InvalidDistribution::InvalidPercentile)?;
114            percentiles.insert(p, v.into());
115        }
116
117        // Ensure there is a minimum value in the distribution.
118        percentiles
119            .entry(Percentile::MIN)
120            .or_insert(T::default().into());
121        percentiles.sort_keys();
122
123        // Ensure all values are in ascending order.
124        let mut base_v = 0u64;
125        for v in percentiles.values() {
126            if *v < base_v {
127                return Err(InvalidDistribution::Unordered);
128            }
129            base_v = *v;
130        }
131
132        // Ensure there is a maximum value in the distribution.
133        let max_v = base_v;
134        percentiles.entry(Percentile::MAX).or_insert(max_v);
135
136        Ok(Self {
137            percentiles,
138            _marker: std::marker::PhantomData,
139        })
140    }
141}
142
143impl<T: From<u64>> Distribution<T> {
144    #[cfg(test)]
145    pub fn min(&self) -> T {
146        let v = self.percentiles.get(&Percentile::MIN).unwrap();
147        (*v).into()
148    }
149
150    #[cfg(test)]
151    pub fn max(&self) -> T {
152        let v = self.percentiles.get(&Percentile::MAX).unwrap();
153        (*v).into()
154    }
155
156    #[cfg(test)]
157    pub fn try_get<P>(&self, p: P) -> Result<T, P::Error>
158    where
159        P: std::convert::TryInto<Percentile>,
160    {
161        let p = p.try_into()?;
162        Ok(self.get(p))
163    }
164
165    pub fn get(&self, Percentile(percentile): Percentile) -> T {
166        let mut lower_p = 0u32;
167        let mut lower_v = 0u64;
168        for (Percentile(p), v) in self.percentiles.iter() {
169            if *p == percentile {
170                return (*v).into();
171            }
172
173            if *p > percentile {
174                let p_delta = *p as u64 - lower_p as u64;
175                let added = if p_delta > 0 {
176                    let v_delta = *v - lower_v;
177                    let unit = v_delta as f64 / p_delta as f64;
178                    let a = unit * ((percentile - lower_p) as u64) as f64;
179                    a as u64
180                } else {
181                    0
182                };
183                return (lower_v + added).into();
184            }
185
186            lower_p = *p;
187            lower_v = *v;
188        }
189
190        unreachable!("percentile must exist in distribution");
191    }
192}
193
194impl<T: From<u64>> rand::distributions::Distribution<T> for Distribution<T> {
195    fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> T {
196        self.get(rng.gen())
197    }
198}
199
200// === impl Percentile ===
201
202impl Percentile {
203    pub const MIN: Self = Self(0);
204    pub const MAX: Self = Self(100_0000);
205    const FACTOR: u32 = 10000;
206}
207
208impl std::convert::TryFrom<f32> for Percentile {
209    type Error = InvalidPercentile;
210
211    fn try_from(v: f32) -> Result<Self, Self::Error> {
212        if !(0.0..=100.0).contains(&v) {
213            return Err(InvalidPercentile(()));
214        }
215        let adjusted = v * (Self::FACTOR as f32);
216
217        Ok(Percentile(adjusted as u32))
218    }
219}
220
221impl std::convert::TryFrom<u32> for Percentile {
222    type Error = InvalidPercentile;
223
224    fn try_from(v: u32) -> Result<Self, Self::Error> {
225        if v > 100 {
226            return Err(InvalidPercentile(()));
227        }
228        Ok(Percentile(v * Self::FACTOR))
229    }
230}
231
232impl rand::distributions::Distribution<Percentile> for rand::distributions::Standard {
233    fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Percentile {
234        Percentile(rng.gen_range(0..=100_0000))
235    }
236}
237
238// === impl InvalidDistribution ===
239
240impl std::fmt::Display for InvalidDistribution {
241    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
242        match self {
243            Self::Unordered => write!(f, "Undordered distribution"),
244            Self::InvalidPercentile => write!(f, "Invalid percentile"),
245            Self::InvalidValue => write!(f, "Invalid value"),
246        }
247    }
248}
249
250impl std::error::Error for InvalidDistribution {}
251
252impl std::fmt::Display for InvalidPercentile {
253    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
254        write!(f, "Invalid percentile")
255    }
256}
257
258impl std::error::Error for InvalidPercentile {}
259
#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::TryFrom;

    /// Integer and float percentages scale identically and share the same
    /// bounds.
    #[test]
    fn convert_percentiles() {
        assert_eq!(Percentile::try_from(0).unwrap(), Percentile::MIN);
        assert_eq!(Percentile::try_from(0.0).unwrap(), Percentile::MIN);
        assert_eq!(Percentile::try_from(50).unwrap(), Percentile(50_0000));
        assert_eq!(Percentile::try_from(50.0).unwrap(), Percentile(50_0000));
        assert_eq!(Percentile::try_from(75).unwrap(), Percentile(75_0000));
        assert_eq!(Percentile::try_from(75.0).unwrap(), Percentile(75_0000));
        assert_eq!(Percentile::try_from(99).unwrap(), Percentile(99_0000));
        assert_eq!(Percentile::try_from(99.0).unwrap(), Percentile(99_0000));
        assert_eq!(Percentile::try_from(99.99).unwrap(), Percentile(99_9900));
        assert_eq!(Percentile::try_from(99.99999).unwrap(), Percentile(99_9999));
        assert_eq!(Percentile::try_from(100).unwrap(), Percentile::MAX);
        assert_eq!(Percentile::try_from(100.0).unwrap(), Percentile::MAX);

        assert!(Percentile::try_from(-1.0).is_err());
        assert!(Percentile::try_from(101.0).is_err());
    }

    /// Default and explicitly built distributions expose their bounds and
    /// interpolate between them.
    #[test]
    fn distributions() {
        let d = Distribution::<u64>::default();
        assert_eq!(d.min(), 0);
        assert_eq!(d.try_get(50).unwrap(), 0);
        assert_eq!(d.max(), 0);

        let d = Distribution::build(vec![(0, 1000u64), (100, 2000)]).unwrap();
        assert_eq!(d.min(), 1000);
        assert_eq!(d.try_get(50).unwrap(), 1500);
        assert_eq!(d.max(), 2000);
    }

    /// Strings parse into distributions, including for a custom value type.
    #[test]
    fn parse() {
        let d = "123".parse::<Distribution<u64>>().unwrap();
        assert_eq!(d.min(), 123);
        assert_eq!(d.try_get(50).unwrap(), 123);
        assert_eq!(d.max(), 123);

        let d = "50=123".parse::<Distribution<u64>>().unwrap();
        assert_eq!(d.min(), 0);
        assert_eq!(d.try_get(50).unwrap(), 123);
        assert_eq!(d.max(), 123);

        let d = "0=1,50=123,100=234".parse::<Distribution<u64>>().unwrap();
        assert_eq!(d.min(), 1);
        assert_eq!(d.try_get(50).unwrap(), 123);
        assert_eq!(d.max(), 234);

        #[derive(Debug, Default, PartialEq, Eq)]
        struct Dingus(u64);
        impl From<u64> for Dingus {
            fn from(n: u64) -> Self {
                Self(n)
            }
        }
        // Implement `From` rather than `Into`: the standard blanket impl then
        // provides `Dingus: Into<u64>`, which is what `Distribution` requires.
        impl From<Dingus> for u64 {
            fn from(d: Dingus) -> u64 {
                d.0
            }
        }
        impl std::str::FromStr for Dingus {
            type Err = ();
            fn from_str(s: &str) -> Result<Self, ()> {
                match s {
                    "A" => Ok(Self(10)),
                    "B" => Ok(Self(20)),
                    "C" => Ok(Self(30)),
                    "D" => Ok(Self(40)),
                    _ => Err(()),
                }
            }
        }

        let d = "0=A,50=B,90=C,100=D"
            .parse::<Distribution<Dingus>>()
            .unwrap();
        assert_eq!(d.min(), Dingus(10));
        assert_eq!(d.try_get(50).unwrap(), Dingus(20));
        assert_eq!(d.try_get(90).unwrap(), Dingus(30));
        assert_eq!(d.try_get(95).unwrap(), Dingus(35));
        assert_eq!(d.max(), Dingus(40));
    }

    /// Malformed strings are rejected with the matching error variant.
    #[test]
    fn parse_errors() {
        fn parse(s: &str) -> Result<Distribution<u64>, InvalidDistribution> {
            s.parse()
        }

        assert!(matches!(parse("abc"), Err(InvalidDistribution::InvalidValue)));
        assert!(matches!(
            parse("x=1"),
            Err(InvalidDistribution::InvalidPercentile)
        ));
        assert!(matches!(
            parse("1,2"),
            Err(InvalidDistribution::InvalidPercentile)
        ));
        assert!(matches!(
            parse("101=1"),
            Err(InvalidDistribution::InvalidPercentile)
        ));
        assert!(matches!(
            parse("50=2,100=1"),
            Err(InvalidDistribution::Unordered)
        ));
    }
}
348}