use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use crate::bulkhead::Bulkhead;
use crate::circuit_breaker::BreakerPolicy;
use crate::policy::Policy;
use crate::rate_limit::RateLimiter;
use crate::retry_policy::RetryPolicy;
use crate::timeout::TimeoutPolicy;
/// A set of optional resilience policies applied around an async operation by
/// [`Pipeline::run`].
///
/// Every policy starts out as `None` (see [`Pipeline::new`]) and is installed through the
/// corresponding `with_*` builder method.
#[derive(Clone)]
pub struct Pipeline {
    /// Retry policy; when set, `run` hands each attempt to `RetryPolicy::call`.
    retry_policy: Option<RetryPolicy>,
    /// Circuit breaker checked before the call and told the final outcome afterwards.
    circuit_breaker: Option<BreakerPolicy>,
    /// Per-attempt timeout plus its success / failure / timeout callbacks.
    timeout: Option<TimeoutPolicy>,
    /// Rate limiter; `run` calls `try_consume(1)` once per attempt.
    rate_limiter: Option<RateLimiter>,
    /// Bulkhead from which `run` tries to acquire a permit held for the whole call.
    bulkhead: Option<Bulkhead>,
}
impl Pipeline {
pub fn new() -> Self {
Pipeline {
retry_policy: None,
circuit_breaker: None,
timeout: None,
rate_limiter: None,
bulkhead: None,
}
}
}
impl Default for Pipeline {
fn default() -> Self {
Self::new()
}
}
impl Pipeline {
    /// Sets (or replaces) the retry policy used by [`Pipeline::run`].
    pub fn with_retry(mut self, policy: RetryPolicy) -> Self {
        self.retry_policy = Some(policy);
        self
    }

    /// Sets (or replaces) the circuit breaker consulted and updated by [`Pipeline::run`].
    pub fn with_circuit_breaker(mut self, policy: BreakerPolicy) -> Self {
        self.circuit_breaker = Some(policy);
        self
    }

    /// Sets (or replaces) the per-attempt timeout policy used by [`Pipeline::run`].
    pub fn with_timeout(mut self, policy: TimeoutPolicy) -> Self {
        self.timeout = Some(policy);
        self
    }

    /// Sets (or replaces) the rate limiter; [`Pipeline::run`] consumes one unit per attempt.
    pub fn with_rate_limiter(mut self, policy: RateLimiter) -> Self {
        self.rate_limiter = Some(policy);
        self
    }

    /// Sets (or replaces) the bulkhead from which [`Pipeline::run`] tries to take a permit.
    pub fn with_bulkhead(mut self, policy: Bulkhead) -> Self {
        self.bulkhead = Some(policy);
        self
    }

    /// Converts this pipeline into a [`FallbackPipeline`] that awaits `fallback()` whenever
    /// the primary run returns an error.
    pub fn or_else<F, Fut, T, E>(self, fallback: F) -> FallbackPipeline<F, Fut, T, E> {
        FallbackPipeline {
            primary: self,
            fallback,
            _marker: PhantomData,
        }
    }

    /// Executes `f` through the configured policies.
    ///
    /// Order of operations:
    /// 1. Circuit breaker: `should_allow_request` is checked first.
    /// 2. Bulkhead: a permit is requested with `try_acquire` and held until `run` returns.
    /// 3. Retry / timeout: with a retry policy, attempts are driven by `RetryPolicy::call`.
    ///    With a timeout policy, each attempt is bounded by `tokio::time::timeout`, and the
    ///    policy's success / failure / timeout callbacks are awaited.
    /// 4. Rate limiter: `try_consume(1)` is called once per attempt.
    /// 5. The final result is reported to the circuit breaker
    ///    (`record_success` / `record_failure`).
    ///
    /// # Errors
    ///
    /// Returns the operation's own error, or `TimeoutError::Elapsed` converted into `E` when
    /// an attempt exceeds the configured timeout. With a retry policy, the error that is
    /// ultimately returned is whatever `RetryPolicy::call` yields.
    pub async fn run<F, Fut, T, E>(&self, mut f: F) -> Result<T, E>
    where
        F: FnMut() -> Fut + Send,
        Fut: Future<Output = Result<T, E>> + Send,
        T: Send,
        E: Send + 'static + From<crate::timeout::TimeoutError>,
    {
        if let Some(ref cb) = self.circuit_breaker {
            if !cb.should_allow_request() {
                // NOTE(review): when the breaker refuses a request, the operation still runs
                // (with no other policy applied) and its result is not recorded. A rejection
                // error cannot be produced here because `E` only offers
                // `From<TimeoutError>` — confirm intended behavior.
                return f().await;
            }
        }

        // NOTE(review): if `try_acquire` yields `None` (presumably "no capacity" — confirm),
        // the call still proceeds without a permit.
        let _bulkhead_permit = if let Some(ref bulkhead) = self.bulkhead {
            bulkhead.try_acquire()
        } else {
            None
        };

        let result = match (self.retry_policy.as_ref(), self.timeout.as_ref()) {
            (Some(retry), Some(timeout)) => {
                // Shared with the retry policy so it can tell a timed-out attempt apart from
                // an ordinary failure; cleared at the start of every attempt.
                let timeout_flag = Arc::new(AtomicBool::new(false));
                let mut retry = retry.clone();
                retry.timeout_occurred = Some(timeout_flag.clone());
                let duration = timeout.duration;
                let on_success = timeout.on_success.clone();
                let on_failure = timeout.on_failure.clone();
                let on_timeout = timeout.on_timeout.clone();
                let name = timeout.name.clone();
                let rate_limiter = self.rate_limiter.clone();
                let mut timed = || {
                    if let Some(ref rl) = rate_limiter {
                        if !rl.try_consume(1) {
                            // NOTE(review): a rejected token still runs the attempt, and
                            // without the timeout wrapper — confirm intended behavior.
                            return Box::pin(f())
                                as Pin<Box<dyn Future<Output = Result<T, E>> + Send>>;
                        }
                    }
                    timeout_flag.store(false, Ordering::Relaxed);
                    let fut = f();
                    let flag = timeout_flag.clone();
                    let os = on_success.clone();
                    let of = on_failure.clone();
                    let ot = on_timeout.clone();
                    // Each attempt needs its own copy since the closure may run repeatedly.
                    let attempt_name = name.clone();
                    Box::pin(async move {
                        match tokio::time::timeout(duration, fut).await {
                            Ok(Ok(val)) => {
                                if let Some(cb) = os {
                                    cb().await;
                                }
                                Ok(val)
                            }
                            Ok(Err(e)) => {
                                if let Some(cb) = of {
                                    cb().await;
                                }
                                Err(e)
                            }
                            Err(_elapsed) => {
                                if let Some(cb) = ot {
                                    cb().await;
                                }
                                flag.store(true, Ordering::Relaxed);
                                // The retry policy awaits this future to completion, so a
                                // timeout must be reported as an error (not a panic) for
                                // the policy to observe it and decide what to do next.
                                Err(crate::timeout::TimeoutError::Elapsed {
                                    duration,
                                    name: attempt_name,
                                }
                                .into())
                            }
                        }
                    }) as Pin<Box<dyn Future<Output = Result<T, E>> + Send>>
                };
                retry.call(&mut timed).await
            }
            (Some(retry), None) => {
                let mut g = || {
                    if let Some(ref rl) = self.rate_limiter {
                        if !rl.try_consume(1) {
                            // NOTE(review): a rejected token still runs the attempt —
                            // confirm intended behavior.
                            return Box::pin(f())
                                as Pin<Box<dyn Future<Output = Result<T, E>> + Send>>;
                        }
                    }
                    Box::pin(f()) as Pin<Box<dyn Future<Output = Result<T, E>> + Send>>
                };
                retry.call(&mut g).await
            }
            (None, Some(timeout)) => {
                let duration = timeout.duration;
                let on_success = timeout.on_success.clone();
                let on_failure = timeout.on_failure.clone();
                let on_timeout = timeout.on_timeout.clone();
                let name = timeout.name.clone();
                let rate_limiter = self.rate_limiter.clone();
                let g = move || {
                    if let Some(ref rl) = rate_limiter {
                        if !rl.try_consume(1) {
                            // NOTE(review): a rejected token still runs the call, and
                            // without the timeout wrapper — confirm intended behavior.
                            return Box::pin(f())
                                as Pin<Box<dyn Future<Output = Result<T, E>> + Send>>;
                        }
                    }
                    let fut = f();
                    let os = on_success.clone();
                    let of = on_failure.clone();
                    let ot = on_timeout.clone();
                    Box::pin(async move {
                        match tokio::time::timeout(duration, fut).await {
                            Ok(Ok(val)) => {
                                if let Some(cb) = os {
                                    cb().await;
                                }
                                Ok(val)
                            }
                            Ok(Err(e)) => {
                                if let Some(cb) = of {
                                    cb().await;
                                }
                                Err(e)
                            }
                            Err(_elapsed) => {
                                if let Some(cb) = ot {
                                    cb().await;
                                }
                                Err(crate::timeout::TimeoutError::Elapsed { duration, name }.into())
                            }
                        }
                    }) as Pin<Box<dyn Future<Output = Result<T, E>> + Send>>
                };
                g().await
            }
            (None, None) => {
                if let Some(ref rl) = self.rate_limiter {
                    // NOTE(review): the limiter's answer is discarded, so it never blocks
                    // the call — confirm intended behavior.
                    let _ = rl.try_consume(1);
                }
                f().await
            }
        };

        if let Some(ref cb) = self.circuit_breaker {
            match &result {
                Ok(_) => cb.record_success(),
                Err(_) => cb.record_failure(),
            }
        }
        result
    }
}
/// A [`Pipeline`] paired with a fallback operation, created by [`Pipeline::or_else`].
///
/// Only the fallback `F` is stored. Its future type `Fut` and the result types `T` / `E`
/// are carried through `PhantomData`.
pub struct FallbackPipeline<F, Fut, T, E> {
    /// Pipeline used to run the primary operation.
    primary: Pipeline,
    /// Operation awaited when the primary run returns an error.
    fallback: F,
    /// Holds the otherwise-unused `Fut`, `T`, and `E` type parameters.
    _marker: PhantomData<(Fut, T, E)>,
}
impl<F, Fut, T, E> FallbackPipeline<F, Fut, T, E>
where
    F: Fn() -> Fut + Send + Sync,
    Fut: Future<Output = Result<T, E>> + Send,
    T: Send,
    E: Send + 'static,
{
    /// Runs `op` through the primary pipeline and falls back on failure.
    ///
    /// A successful primary result is returned as-is. On any primary error, that error is
    /// discarded and the fallback is called once; its result is returned instead.
    pub async fn run<G, GFut>(&self, op: &mut G) -> Result<T, E>
    where
        G: FnMut() -> GFut + Send,
        GFut: Future<Output = Result<T, E>> + Send,
        E: Send + 'static + From<crate::timeout::TimeoutError>,
    {
        let outcome = self.primary.run(op).await;
        if outcome.is_err() {
            let fallback = &self.fallback;
            return fallback().await;
        }
        outcome
    }
}