#![warn(missing_docs)]
#![warn(rustdoc::broken_intra_doc_links)]
use histogram::AtomicHistogram;
use std::{
sync::atomic::{AtomicU64, Ordering::*},
time::Duration,
};
/// Error returned by fallible [`Hedge`] operations.
///
/// Wraps the underlying [`histogram::Error`]. Durations whose microsecond
/// count does not fit in a `u64` are also reported through this type, as
/// `histogram::Error::Overflow`.
#[derive(Debug, thiserror::Error)]
#[error("hedged error: {source}")]
pub struct Error {
    #[source]
    source: histogram::Error,
}
impl From<histogram::Error> for Error {
fn from(value: histogram::Error) -> Self {
Self { source: value }
}
}
/// Convenience alias for results whose error type is this module's [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
/// Adaptive hedging policy.
///
/// Records observed request latencies in a histogram and, every `period`
/// observations, recomputes the hedging timeout from a target percentile of
/// that histogram. The state that changes at runtime lives in atomics, so
/// recording observations and reading the timeout only need `&self`.
pub struct Hedge {
    /// Latency samples, stored in microseconds.
    histogram: AtomicHistogram,
    /// Current hedging timeout, in microseconds.
    current_usec: AtomicU64,
    /// Number of observations recorded so far; drives the rollout period.
    observation_count: AtomicU64,
    /// The timeout is recomputed every `period` observations (never 0).
    period: u64,
    /// Target percentile handed to the histogram snapshot on rollout
    /// (`Hedge::new` stores it scaled by 100).
    percentile: f64,
}
impl Default for Hedge {
fn default() -> Hedge {
Self {
histogram: AtomicHistogram::new(7, 64).expect("histogram"),
current_usec: AtomicU64::new(
Duration::from_secs(30)
.as_micros()
.try_into()
.expect("valid timeout"),
),
observation_count: AtomicU64::new(0),
period: 10,
percentile: 0.95,
}
}
}
impl Hedge {
pub fn new(
p: u8,
n: u8,
initial_timeout: Duration,
period: u64,
percentile: f64,
) -> Result<Self> {
if percentile <= 0.0 || percentile > 1.0 {
panic!("percentile should in (0.0, 1.0], was {percentile}");
}
if period == 0 {
panic!("period must be greater that 0");
}
Ok(Self {
histogram: AtomicHistogram::new(p, n)?,
current_usec: AtomicU64::new(
initial_timeout
.as_micros()
.try_into()
.map_err(|_| histogram::Error::Overflow)?,
),
observation_count: AtomicU64::new(0),
period,
percentile: percentile * 100.0,
})
}
pub fn with_initial_timeout(self, timeout: Duration) -> Result<Self> {
Ok(Self {
current_usec: AtomicU64::new(
timeout
.as_micros()
.try_into()
.map_err(|_| histogram::Error::Overflow)?,
),
..self
})
}
pub fn with_period(self, period: u64) -> Self {
if period == 0 {
panic!("period must be greater that 0");
}
Self { period, ..self }
}
#[cfg(feature = "tokio")]
pub async fn send<F, Fut, R>(&self, mut f: F) -> (R, Option<std::pin::Pin<Box<Fut>>>)
where
F: FnMut() -> Fut,
Fut: std::future::Future<Output = R>,
{
use tokio::time::{timeout, Instant};
let mut first_request = Box::pin((f)());
let first_start = Instant::now();
if let Ok(res) = timeout(self.value(), first_request.as_mut()).await {
let _ = self.observe(first_start.elapsed());
return (res, None);
};
let mut second_request = Box::pin((f)());
let second_start = Instant::now();
let (is_first, res) = tokio::select! {
res = first_request.as_mut() => {
let _ = self.observe(first_start.elapsed());
(true, res)
}
res = second_request.as_mut() => {
let _ = self.observe(second_start.elapsed());
(false, res)
}
};
let rem = if is_first {
second_request
} else {
first_request
};
(
res,
Some(rem),
)
}
pub fn value(&self) -> Duration {
let current = self.current_usec.load(Relaxed);
Duration::from_micros(current)
}
pub fn observe(&self, duration: Duration) -> Result<()> {
self.histogram.increment(
duration
.as_micros()
.try_into()
.map_err(|_| histogram::Error::Overflow)?,
)?;
let observation_count = self.observation_count.fetch_add(1, SeqCst) + 1;
if observation_count % self.period == 0 {
self.rollout()?;
}
Ok(())
}
#[inline(always)]
fn rollout(&self) -> Result<()> {
let snap = self.histogram.snapshot();
let bucket = snap.percentile(self.percentile)?;
self.current_usec.store(bucket.end(), Relaxed);
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The timeout must stay at its initial value for the first
    /// `period - 1` observations and be recomputed on the `period`-th.
    #[test]
    fn test_rollout_at_period() {
        let initial = Duration::from_secs(30);
        let hedge = Hedge::new(7, 64, initial, 10, 0.9).unwrap();
        assert_eq!(initial, hedge.value());

        // Nine samples: one short of the period of 10, so no rollout yet.
        for secs in [1, 2, 3, 3, 1, 2, 3, 3, 3] {
            hedge.observe(Duration::from_secs(secs)).unwrap();
            assert_eq!(initial, hedge.value());
        }

        // The tenth sample triggers a rollout to the 90th percentile (~3 s).
        hedge.observe(Duration::from_secs(10)).unwrap();
        assert_eq!(3.0, hedge.value().as_secs_f64().round());
    }
}