mod config;
mod error;
mod events;
mod layer;
pub use config::{HedgeConfig, HedgeConfigBuilder, HedgeDelay};
pub use error::HedgeError;
pub use events::HedgeEvent;
pub use layer::HedgeLayer;
use futures::future::BoxFuture;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::{Duration, Instant};
use tower::Service;
/// Tower middleware that wraps an inner service `S` and dispatches each
/// request through `execute_with_hedging` using a shared [`HedgeConfig`].
pub struct Hedge<S> {
/// The wrapped service. `call` hands the current instance to the request
/// future and leaves a fresh clone in its place.
inner: S,
/// Hedging settings, reference-counted so clones of this middleware and the
/// per-request futures share one allocation.
config: Arc<HedgeConfig>,
}
impl<S> Hedge<S> {
    /// Builds a hedging middleware around `inner`.
    ///
    /// The supplied `config` is moved behind an [`Arc`] so that clones of the
    /// middleware and in-flight requests can share it without copying.
    pub fn new(inner: S, config: HedgeConfig) -> Self {
        let config = Arc::new(config);
        Self { inner, config }
    }
}
impl<S: Clone> Clone for Hedge<S> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
config: Arc::clone(&self.config),
}
}
}
impl<S, Req> Service<Req> for Hedge<S>
where
    S: Service<Req> + Clone + Send + 'static,
    S::Response: Send + Sync + 'static,
    S::Error: Clone + Send + Sync + 'static,
    S::Future: Send,
    Req: Clone + Send + Sync + 'static,
{
    type Response = S::Response;
    type Error = HedgeError<S::Error>;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    /// Forwards readiness to the wrapped service, wrapping any readiness
    /// error in [`HedgeError::Inner`].
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        match self.inner.poll_ready(cx) {
            Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
            Poll::Ready(Err(err)) => Poll::Ready(Err(HedgeError::Inner(err))),
            Poll::Pending => Poll::Pending,
        }
    }

    /// Starts a hedged execution of `req`.
    ///
    /// The instance that was driven to readiness is moved into the returned
    /// future, and a fresh clone takes its place in `self` for later calls.
    fn call(&mut self, req: Req) -> Self::Future {
        let replacement = self.inner.clone();
        let ready_service = std::mem::replace(&mut self.inner, replacement);
        let shared_config = Arc::clone(&self.config);
        Box::pin(execute_with_hedging(ready_service, req, shared_config))
    }
}
async fn execute_with_hedging<S, Req>(
service: S,
req: Req,
config: Arc<HedgeConfig>,
) -> Result<S::Response, HedgeError<S::Error>>
where
S: Service<Req> + Clone + Send + 'static,
S::Response: Send + 'static,
S::Error: Clone + Send + 'static,
S::Future: Send,
Req: Clone + Send + 'static,
{
use tokio::sync::mpsc;
let max_attempts = config.max_hedged_attempts;
let start = Instant::now();
config.listeners.emit(&HedgeEvent::PrimaryStarted {
name: config.name.clone(),
timestamp: Instant::now(),
});
let (tx, mut rx) = mpsc::channel::<(usize, Result<S::Response, S::Error>)>(max_attempts);
let mut service_clone = service.clone();
let req_clone = req.clone();
let tx_clone = tx.clone();
tokio::spawn(async move {
let result = service_clone.call(req_clone).await;
let _ = tx_clone.send((0, result)).await;
});
let mut hedges_spawned: usize = 0;
let mut primary_error: Option<S::Error> = None;
let first_delay = config.delay.get_delay(1);
if max_attempts > 1 {
match first_delay {
Some(delay) if delay > Duration::ZERO => {
let mut delay_fut = std::pin::pin!(tokio::time::sleep(delay));
loop {
tokio::select! {
biased;
Some((attempt, result)) = rx.recv() => {
match &result {
Ok(_) => {
let duration = start.elapsed();
if attempt == 0 {
config.listeners.emit(&HedgeEvent::PrimarySucceeded {
name: config.name.clone(),
duration,
hedges_cancelled: hedges_spawned,
timestamp: Instant::now(),
});
} else {
config.listeners.emit(&HedgeEvent::HedgeSucceeded {
name: config.name.clone(),
attempt,
duration,
primary_cancelled: true,
timestamp: Instant::now(),
});
}
return result.map_err(HedgeError::Inner);
}
Err(e) => {
if attempt == 0 {
primary_error = Some(e.clone());
}
if hedges_spawned + 1 >= max_attempts {
config.listeners.emit(&HedgeEvent::AllFailed {
name: config.name.clone(),
attempts: hedges_spawned + 1,
timestamp: Instant::now(),
});
return Err(HedgeError::AllAttemptsFailed(
primary_error.unwrap_or_else(|| e.clone())
));
}
}
}
}
_ = &mut delay_fut, if hedges_spawned + 1 < max_attempts => {
hedges_spawned += 1;
let attempt_num = hedges_spawned;
config.listeners.emit(&HedgeEvent::HedgeStarted {
name: config.name.clone(),
attempt: attempt_num,
delay,
timestamp: Instant::now(),
});
let mut svc = service.clone();
let r = req.clone();
let tx_c = tx.clone();
tokio::spawn(async move {
let result = svc.call(r).await;
let _ = tx_c.send((attempt_num, result)).await;
});
if hedges_spawned + 1 < max_attempts {
if let Some(next_delay) = config.delay.get_delay(hedges_spawned + 1) {
delay_fut.set(tokio::time::sleep(next_delay));
}
}
}
else => {
if let Some((attempt, result)) = rx.recv().await {
match &result {
Ok(_) => {
let duration = start.elapsed();
if attempt == 0 {
config.listeners.emit(&HedgeEvent::PrimarySucceeded {
name: config.name.clone(),
duration,
hedges_cancelled: hedges_spawned,
timestamp: Instant::now(),
});
} else {
config.listeners.emit(&HedgeEvent::HedgeSucceeded {
name: config.name.clone(),
attempt,
duration,
primary_cancelled: attempt != 0,
timestamp: Instant::now(),
});
}
return result.map_err(HedgeError::Inner);
}
Err(e) => {
if attempt == 0 && primary_error.is_none() {
primary_error = Some(e.clone());
}
}
}
} else {
break;
}
}
}
}
}
_ => {
for i in 1..max_attempts {
hedges_spawned += 1;
config.listeners.emit(&HedgeEvent::HedgeStarted {
name: config.name.clone(),
attempt: i,
delay: Duration::ZERO,
timestamp: Instant::now(),
});
let mut svc = service.clone();
let r = req.clone();
let tx_c = tx.clone();
tokio::spawn(async move {
let result = svc.call(r).await;
let _ = tx_c.send((i, result)).await;
});
}
}
}
}
drop(tx);
let mut attempts_received: usize = 0;
let total_attempts = hedges_spawned + 1;
while let Some((attempt, result)) = rx.recv().await {
attempts_received += 1;
match result {
Ok(res) => {
let duration = start.elapsed();
if attempt == 0 {
config.listeners.emit(&HedgeEvent::PrimarySucceeded {
name: config.name.clone(),
duration,
hedges_cancelled: hedges_spawned.saturating_sub(attempts_received - 1),
timestamp: Instant::now(),
});
} else {
config.listeners.emit(&HedgeEvent::HedgeSucceeded {
name: config.name.clone(),
attempt,
duration,
primary_cancelled: true,
timestamp: Instant::now(),
});
}
return Ok(res);
}
Err(e) => {
if primary_error.is_none() {
primary_error = Some(e);
}
}
}
}
config.listeners.emit(&HedgeEvent::AllFailed {
name: config.name.clone(),
attempts: total_attempts,
timestamp: Instant::now(),
});
Err(HedgeError::AllAttemptsFailed(
primary_error.expect("at least one error should exist"),
))
}
// Integration-style tests that drive the middleware through `HedgeLayer`
// against small `tower::service_fn` services. Call counts are tracked with a
// shared atomic counter; several tests rely on real (wall-clock) sleeps.
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::{AtomicUsize, Ordering};
use tower::{Layer, ServiceExt};
// Minimal error type for the test services; `Clone` is required by the
// middleware's bounds on the inner error.
#[derive(Clone, Debug)]
struct TestError;
// A primary that answers immediately must win before the 100 ms hedge delay
// elapses, so the inner service is expected to be called exactly once.
#[tokio::test]
async fn test_primary_succeeds_no_hedge() {
let call_count = Arc::new(AtomicUsize::new(0));
let cc = Arc::clone(&call_count);
let service = tower::service_fn(move |_req: String| {
let cc = Arc::clone(&cc);
async move {
cc.fetch_add(1, Ordering::SeqCst);
Ok::<_, TestError>("success".to_string())
}
});
let layer = HedgeLayer::builder()
.delay(Duration::from_millis(100))
.max_hedged_attempts(2)
.build();
let mut service = layer.layer(service);
let result = service
.ready()
.await
.unwrap()
.call("test".to_string())
.await;
assert!(result.is_ok());
// Short grace period before checking that no hedge was started.
tokio::time::sleep(Duration::from_millis(10)).await;
assert_eq!(call_count.load(Ordering::SeqCst), 1);
}
// With `no_delay()` and three attempts the test expects all three calls to
// reach the inner service (presumably started together — see the builder).
#[tokio::test]
async fn test_parallel_mode_all_called() {
let call_count = Arc::new(AtomicUsize::new(0));
let cc = Arc::clone(&call_count);
let service = tower::service_fn(move |_req: String| {
let cc = Arc::clone(&cc);
async move {
cc.fetch_add(1, Ordering::SeqCst);
tokio::time::sleep(Duration::from_millis(50)).await;
Ok::<_, TestError>("success".to_string())
}
});
let layer = HedgeLayer::builder()
.no_delay()
.max_hedged_attempts(3)
.build();
let mut service = layer.layer(service);
let result = service
.ready()
.await
.unwrap()
.call("test".to_string())
.await;
assert!(result.is_ok());
// Give the remaining attempts time to run before counting calls.
tokio::time::sleep(Duration::from_millis(100)).await;
assert_eq!(call_count.load(Ordering::SeqCst), 3);
}
// The first call (the primary) sleeps 200 ms while later calls return at
// once. With a 50 ms hedge delay the hedge should win, so the request is
// expected to finish in under 150 ms with exactly two inner calls.
#[tokio::test]
async fn test_hedge_fires_after_delay() {
let call_count = Arc::new(AtomicUsize::new(0));
let cc = Arc::clone(&call_count);
let service = tower::service_fn(move |_req: String| {
let cc = Arc::clone(&cc);
async move {
// `fetch_add` returns the previous value, so `0` identifies the first call.
let count = cc.fetch_add(1, Ordering::SeqCst);
if count == 0 {
tokio::time::sleep(Duration::from_millis(200)).await;
}
Ok::<_, TestError>("success".to_string())
}
});
let layer = HedgeLayer::builder()
.delay(Duration::from_millis(50))
.max_hedged_attempts(2)
.build();
let mut service = layer.layer(service);
let start = Instant::now();
let result = service
.ready()
.await
.unwrap()
.call("test".to_string())
.await;
let elapsed = start.elapsed();
assert!(result.is_ok());
assert!(elapsed < Duration::from_millis(150));
tokio::time::sleep(Duration::from_millis(10)).await;
assert_eq!(call_count.load(Ordering::SeqCst), 2);
}
// When every attempt fails the caller must see `AllAttemptsFailed`.
#[tokio::test]
async fn test_all_fail_returns_error() {
let service = tower::service_fn(|_req: String| async move { Err::<String, _>(TestError) });
let layer = HedgeLayer::builder()
.no_delay()
.max_hedged_attempts(2)
.build();
let mut service = layer.layer(service);
let result = service
.ready()
.await
.unwrap()
.call("test".to_string())
.await;
assert!(matches!(result, Err(HedgeError::AllAttemptsFailed(_))));
}
// The preset tests below only check that each preset layer can be constructed.
#[test]
fn test_preset_conservative() {
let _layer = HedgeLayer::conservative();
}
#[test]
fn test_preset_standard() {
let _layer = HedgeLayer::standard();
}
#[test]
fn test_preset_aggressive() {
let _layer = HedgeLayer::aggressive();
}
}