1use indexmap::IndexMap;
2use std::str::FromStr;
3
/// A percentile stored in fixed point with `Percentile::FACTOR` units per whole
/// percent: `Percentile::MIN` is the 0th percentile and `Percentile::MAX` the 100th.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Percentile(u32);
7
8pub struct Distribution<T = u64> {
9 percentiles: IndexMap<Percentile, u64>,
11 _marker: std::marker::PhantomData<T>,
12}
13
/// Error returned when building or parsing a [`Distribution`] fails.
///
/// Derives the common value traits so callers can compare and copy errors
/// (e.g. `assert_eq!(err, InvalidDistribution::Unordered)`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum InvalidDistribution {
    /// After sorting by percentile, some value is smaller than the one before it.
    Unordered,
    /// A value could not be parsed into the distribution's value type.
    InvalidValue,
    /// A percentile was missing, could not be parsed, or was outside `0..=100`.
    InvalidPercentile,
}
20
/// Error returned when a value outside `0..=100` is converted into a [`Percentile`].
///
/// The private `()` field keeps construction inside this module.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct InvalidPercentile(());
23
24impl<T> Clone for Distribution<T> {
27 fn clone(&self) -> Self {
28 Self {
29 percentiles: self.percentiles.clone(),
30 _marker: self._marker,
31 }
32 }
33}
34
35impl<T> std::fmt::Debug for Distribution<T> {
36 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37 f.debug_struct("Distribution")
38 .field("percentiles", &self.percentiles)
39 .finish()
40 }
41}
42
43impl<T: Default + Into<u64>> Default for Distribution<T> {
44 fn default() -> Self {
45 let mut percentiles = IndexMap::new();
46 let v = T::default().into();
47 percentiles.entry(Percentile::MIN).or_insert(v);
48 percentiles.entry(Percentile::MAX).or_insert(v);
49 Self {
50 percentiles,
51 _marker: std::marker::PhantomData,
52 }
53 }
54}
55
56impl<T: FromStr + Default + Into<u64>> FromStr for Distribution<T> {
57 type Err = InvalidDistribution;
58
59 fn from_str(s: &str) -> Result<Self, Self::Err> {
60 let pvs = s.split(',').collect::<Vec<_>>();
61 if pvs.len() == 1 {
62 let pv = match pvs[0].splitn(2, '=').collect::<Vec<_>>().as_slice() {
64 [v] => {
65 let v = v
66 .parse::<T>()
67 .map_err(|_| InvalidDistribution::InvalidValue)?;
68 (0f32, v)
69 }
70 [p, v] => {
71 let p = p
72 .parse::<f32>()
73 .map_err(|_| InvalidDistribution::InvalidPercentile)?;
74 let v = v
75 .parse::<T>()
76 .map_err(|_| InvalidDistribution::InvalidValue)?;
77 (p, v)
78 }
79 _ => return Err(InvalidDistribution::InvalidPercentile),
80 };
81 Self::build(Some(pv))
82 } else {
83 let mut pairs = Vec::new();
84 for pv in pvs {
85 let mut pv = pv.splitn(2, '=');
86 match (pv.next(), pv.next()) {
87 (Some(p), Some(v)) => {
88 let p = p
89 .parse::<f32>()
90 .map_err(|_| InvalidDistribution::InvalidPercentile)?;
91 let v = v
92 .parse::<T>()
93 .map_err(|_| InvalidDistribution::InvalidValue)?;
94 pairs.push((p, v));
95 }
96 _ => return Err(InvalidDistribution::InvalidPercentile),
97 }
98 }
99 Self::build(pairs)
100 }
101 }
102}
103
104impl<T: Default + Into<u64>> Distribution<T> {
105 pub fn build<P>(pairs: impl IntoIterator<Item = (P, T)>) -> Result<Self, InvalidDistribution>
106 where
107 P: std::convert::TryInto<Percentile>,
108 {
109 let mut percentiles = IndexMap::new();
110 for (p, v) in pairs.into_iter() {
111 let p = p
112 .try_into()
113 .map_err(|_| InvalidDistribution::InvalidPercentile)?;
114 percentiles.insert(p, v.into());
115 }
116
117 percentiles
119 .entry(Percentile::MIN)
120 .or_insert(T::default().into());
121 percentiles.sort_keys();
122
123 let mut base_v = 0u64;
125 for v in percentiles.values() {
126 if *v < base_v {
127 return Err(InvalidDistribution::Unordered);
128 }
129 base_v = *v;
130 }
131
132 let max_v = base_v;
134 percentiles.entry(Percentile::MAX).or_insert(max_v);
135
136 Ok(Self {
137 percentiles,
138 _marker: std::marker::PhantomData,
139 })
140 }
141}
142
143impl<T: From<u64>> Distribution<T> {
144 #[cfg(test)]
145 pub fn min(&self) -> T {
146 let v = self.percentiles.get(&Percentile::MIN).unwrap();
147 (*v).into()
148 }
149
150 #[cfg(test)]
151 pub fn max(&self) -> T {
152 let v = self.percentiles.get(&Percentile::MAX).unwrap();
153 (*v).into()
154 }
155
156 #[cfg(test)]
157 pub fn try_get<P>(&self, p: P) -> Result<T, P::Error>
158 where
159 P: std::convert::TryInto<Percentile>,
160 {
161 let p = p.try_into()?;
162 Ok(self.get(p))
163 }
164
165 pub fn get(&self, Percentile(percentile): Percentile) -> T {
166 let mut lower_p = 0u32;
167 let mut lower_v = 0u64;
168 for (Percentile(p), v) in self.percentiles.iter() {
169 if *p == percentile {
170 return (*v).into();
171 }
172
173 if *p > percentile {
174 let p_delta = *p as u64 - lower_p as u64;
175 let added = if p_delta > 0 {
176 let v_delta = *v - lower_v;
177 let unit = v_delta as f64 / p_delta as f64;
178 let a = unit * ((percentile - lower_p) as u64) as f64;
179 a as u64
180 } else {
181 0
182 };
183 return (lower_v + added).into();
184 }
185
186 lower_p = *p;
187 lower_v = *v;
188 }
189
190 unreachable!("percentile must exist in distribution");
191 }
192}
193
194impl<T: From<u64>> rand::distributions::Distribution<T> for Distribution<T> {
195 fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> T {
196 self.get(rng.gen())
197 }
198}
199
200impl Percentile {
203 pub const MIN: Self = Self(0);
204 pub const MAX: Self = Self(100_0000);
205 const FACTOR: u32 = 10000;
206}
207
208impl std::convert::TryFrom<f32> for Percentile {
209 type Error = InvalidPercentile;
210
211 fn try_from(v: f32) -> Result<Self, Self::Error> {
212 if !(0.0..=100.0).contains(&v) {
213 return Err(InvalidPercentile(()));
214 }
215 let adjusted = v * (Self::FACTOR as f32);
216
217 Ok(Percentile(adjusted as u32))
218 }
219}
220
221impl std::convert::TryFrom<u32> for Percentile {
222 type Error = InvalidPercentile;
223
224 fn try_from(v: u32) -> Result<Self, Self::Error> {
225 if v > 100 {
226 return Err(InvalidPercentile(()));
227 }
228 Ok(Percentile(v * Self::FACTOR))
229 }
230}
231
232impl rand::distributions::Distribution<Percentile> for rand::distributions::Standard {
233 fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Percentile {
234 Percentile(rng.gen_range(0..=100_0000))
235 }
236}
237
238impl std::fmt::Display for InvalidDistribution {
241 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
242 match self {
243 Self::Unordered => write!(f, "Undordered distribution"),
244 Self::InvalidPercentile => write!(f, "Invalid percentile"),
245 Self::InvalidValue => write!(f, "Invalid value"),
246 }
247 }
248}
249
250impl std::error::Error for InvalidDistribution {}
251
252impl std::fmt::Display for InvalidPercentile {
253 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
254 write!(f, "Invalid percentile")
255 }
256}
257
258impl std::error::Error for InvalidPercentile {}
259
#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::TryFrom;

    #[test]
    fn convert_percentiles() {
        assert_eq!(Percentile::try_from(0).unwrap(), Percentile::MIN);
        assert_eq!(Percentile::try_from(0.0).unwrap(), Percentile::MIN);
        assert_eq!(Percentile::try_from(50).unwrap(), Percentile(500_000));
        assert_eq!(Percentile::try_from(50.0).unwrap(), Percentile(500_000));
        assert_eq!(Percentile::try_from(75).unwrap(), Percentile(750_000));
        assert_eq!(Percentile::try_from(75.0).unwrap(), Percentile(750_000));
        assert_eq!(Percentile::try_from(99).unwrap(), Percentile(990_000));
        assert_eq!(Percentile::try_from(99.0).unwrap(), Percentile(990_000));
        assert_eq!(Percentile::try_from(99.99).unwrap(), Percentile(999_900));
        assert_eq!(Percentile::try_from(99.99999).unwrap(), Percentile(999_999));
        assert_eq!(Percentile::try_from(100).unwrap(), Percentile::MAX);
        assert_eq!(Percentile::try_from(100.0).unwrap(), Percentile::MAX);

        assert!(Percentile::try_from(-1.0).is_err());
        assert!(Percentile::try_from(101.0).is_err());
        assert!(Percentile::try_from(101).is_err());
    }

    #[test]
    fn distributions() {
        let d = Distribution::<u64>::default();
        assert_eq!(d.min(), 0);
        assert_eq!(d.try_get(50).unwrap(), 0);
        assert_eq!(d.max(), 0);

        let d = Distribution::build(vec![(0, 1000u64), (100, 2000)]).unwrap();
        assert_eq!(d.min(), 1000);
        assert_eq!(d.try_get(50).unwrap(), 1500);
        assert_eq!(d.max(), 2000);
    }

    #[test]
    fn parse() {
        let d = "123".parse::<Distribution<u64>>().unwrap();
        assert_eq!(d.min(), 123);
        assert_eq!(d.try_get(50).unwrap(), 123);
        assert_eq!(d.max(), 123);

        let d = "50=123".parse::<Distribution<u64>>().unwrap();
        assert_eq!(d.min(), 0);
        assert_eq!(d.try_get(50).unwrap(), 123);
        assert_eq!(d.max(), 123);

        let d = "0=1,50=123,100=234".parse::<Distribution<u64>>().unwrap();
        assert_eq!(d.min(), 1);
        assert_eq!(d.try_get(50).unwrap(), 123);
        assert_eq!(d.max(), 234);

        #[derive(Debug, Default, PartialEq, Eq)]
        struct Dingus(u64);
        impl From<u64> for Dingus {
            fn from(n: u64) -> Self {
                Self(n)
            }
        }
        // Implement `From` rather than `Into` (clippy `from_over_into`): the
        // standard blanket impl then provides the `Dingus: Into<u64>` that
        // `Distribution` needs.
        impl From<Dingus> for u64 {
            fn from(d: Dingus) -> Self {
                d.0
            }
        }
        impl std::str::FromStr for Dingus {
            type Err = ();
            fn from_str(s: &str) -> Result<Self, ()> {
                match s {
                    "A" => Ok(Self(10)),
                    "B" => Ok(Self(20)),
                    "C" => Ok(Self(30)),
                    "D" => Ok(Self(40)),
                    _ => Err(()),
                }
            }
        }

        let d = "0=A,50=B,90=C,100=D"
            .parse::<Distribution<Dingus>>()
            .unwrap();
        assert_eq!(d.min(), Dingus(10));
        assert_eq!(d.try_get(50).unwrap(), Dingus(20));
        assert_eq!(d.try_get(90).unwrap(), Dingus(30));
        assert_eq!(d.try_get(95).unwrap(), Dingus(35));
        assert_eq!(d.max(), Dingus(40));
    }

    #[test]
    fn parse_errors() {
        fn parse_u64(s: &str) -> Result<Distribution<u64>, InvalidDistribution> {
            s.parse()
        }

        // Values that do not parse as `T`.
        assert!(matches!(parse_u64(""), Err(InvalidDistribution::InvalidValue)));
        assert!(matches!(parse_u64("abc"), Err(InvalidDistribution::InvalidValue)));
        assert!(matches!(parse_u64("50=abc"), Err(InvalidDistribution::InvalidValue)));

        // Malformed or out-of-range percentiles.
        assert!(matches!(parse_u64("x=1"), Err(InvalidDistribution::InvalidPercentile)));
        assert!(matches!(parse_u64("101=1"), Err(InvalidDistribution::InvalidPercentile)));
        assert!(matches!(parse_u64("-1=1"), Err(InvalidDistribution::InvalidPercentile)));

        // Every entry of a multi-entry list needs an explicit percentile.
        assert!(matches!(parse_u64("1,2"), Err(InvalidDistribution::InvalidPercentile)));

        // Values must not decrease as the percentile increases.
        assert!(matches!(parse_u64("0=5,50=1"), Err(InvalidDistribution::Unordered)));
        assert!(matches!(parse_u64("50=5,90=4"), Err(InvalidDistribution::Unordered)));
    }
}