use std::{convert::Infallible, future::Future, pin::Pin, sync::Arc, task::ready, time::Duration};
use http_body_util::BodyExt;
use hyper::{Request, Response, StatusCode};
use pin_project_lite::pin_project;
use rand::Rng;
use serde::{Deserialize, Serialize};
use tokio::time::Sleep;
use tower_layer::Layer;
use spacegate_kernel::{SgBody, SgBoxLayer, SgResponseExt};
use crate::def_plugin;
#[derive(Debug, Clone)]
pub struct RetryLayer<P> {
policy_default: P,
}
impl<P> RetryLayer<P>
where
P: Policy,
{
pub fn new(policy_default: P) -> Self {
Self { policy_default }
}
}
impl<S, P> Layer<S> for RetryLayer<P>
where
P: Clone,
{
type Service = Retry<P, S>;
fn layer(&self, service: S) -> Self::Service {
Retry {
policy: self.policy_default.clone(),
service,
}
}
}
#[derive(Debug, Default, Serialize, Deserialize, Clone)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
pub enum BackOff {
Fixed,
#[default]
Exponential,
Random,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[serde(default)]
pub struct SgPluginRetryConfig {
pub retries: u16,
pub retirable_methods: Vec<String>,
pub backoff: BackOff,
pub base_interval: u64,
pub max_interval: u64,
}
impl Default for SgPluginRetryConfig {
fn default() -> Self {
Self {
retries: 3,
retirable_methods: vec!["*".to_string()],
backoff: BackOff::default(),
base_interval: 100,
max_interval: 10000,
}
}
}
#[derive(Clone)]
pub struct RetryPolicy {
times: usize,
config: Arc<SgPluginRetryConfig>,
}
pin_project! {
pub struct Delay<T> {
value: Option<T>,
#[pin]
sleep: Sleep,
}
}
impl<T> Delay<T> {
pub fn new(value: T, duration: Duration) -> Self {
Self {
value: Some(value),
sleep: tokio::time::sleep(duration),
}
}
}
impl<T> Future for Delay<T> {
type Output = T;
fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Self::Output> {
let this = self.project();
ready!(this.sleep.poll(cx));
std::task::Poll::Ready(this.value.take().expect("poll after ready"))
}
}
pub trait Policy: Sized {
type Future: Future<Output = Self>;
fn retry(&self, req: &Request<SgBody>, response: &Response<SgBody>) -> Option<Self::Future>;
}
impl Policy for RetryPolicy {
type Future = Delay<Self>;
fn retry(&self, _req: &Request<SgBody>, response: &Response<SgBody>) -> Option<Self::Future> {
if self.times < self.config.retries.into() && response.status() == StatusCode::INTERNAL_SERVER_ERROR {
let delay = match self.config.backoff {
BackOff::Fixed => self.config.base_interval,
BackOff::Exponential => self.config.base_interval * 2u64.pow(self.times as u32),
BackOff::Random => {
let mut rng = rand::rng();
rng.gen_range(self.config.base_interval..self.config.max_interval)
}
};
Some(Delay::new(
RetryPolicy {
times: self.times + 1,
config: self.config.clone(),
},
Duration::from_millis(delay),
))
} else {
None
}
}
}
#[derive(Debug, Clone)]
pub struct Retry<P, S> {
policy: P,
service: S,
}
pin_project_lite::pin_project! {
pub struct RetryFuture<P, S>
where
P: Policy,
S: hyper::service::Service<Request<SgBody>, Response = Response<SgBody>, Error = Infallible>,
{
policy: P,
service: S,
#[pin]
state: RetryState<P::Future, S::Future>,
request: Option<Request<SgBody>>
}
}
impl<P, S> RetryFuture<P, S>
where
P: Policy,
S: hyper::service::Service<Request<SgBody>, Response = Response<SgBody>, Error = Infallible>,
{
pub fn new(policy: P, service: S, req: Request<SgBody>) -> Self {
let (parts, body) = req.into_parts();
let body = body.collect();
Self {
policy,
service,
state: RetryState::Collecting { body, parts },
request: None,
}
}
}
pin_project_lite::pin_project! {
#[project = RetryStateProj]
pub enum RetryState<PF, SF> {
Collecting {
#[pin]
body: http_body_util::combinators::Collect<SgBody>,
parts: hyper::http::request::Parts,
},
Requesting {
#[pin]
future: SF,
},
Retrying {
#[pin]
future: PF,
},
}
}
impl<P, S> Future for RetryFuture<P, S>
where
P: Policy,
S: hyper::service::Service<Request<SgBody>, Response = Response<SgBody>, Error = Infallible>,
{
type Output = Result<Response<SgBody>, Infallible>;
fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Self::Output> {
let mut this = self.project();
loop {
match this.state.as_mut().project() {
RetryStateProj::Collecting { body, parts: part } => {
let body = ready!(body.poll(cx));
match body {
Ok(body) => {
let req = Request::from_parts(part.clone(), SgBody::full(body.to_bytes()));
{
let req = req.clone();
*this.request = Some(req);
}
let fut = this.service.call(req);
this.state.set(RetryState::Requesting { future: fut });
}
Err(e) => {
return std::task::Poll::Ready(Ok(Response::with_code_message(StatusCode::BAD_REQUEST, e.to_string())));
}
}
}
RetryStateProj::Requesting { future } => {
let resp = ready!(future.poll(cx));
match resp {
Ok(resp) => {
if let Some(fut) = this.policy.retry(this.request.as_ref().expect("status conflict"), &resp) {
this.state.set(RetryState::Retrying { future: fut });
} else {
return std::task::Poll::Ready(Ok(resp));
}
}
Err(_e) => {
unreachable!()
}
}
}
RetryStateProj::Retrying { future } => {
let next_p = ready!(future.poll(cx));
*this.policy = next_p;
let fut = this.service.call(this.request.as_ref().expect("status conflict").clone());
this.state.set(RetryState::Requesting { future: fut });
}
}
}
}
}
impl<P, S> hyper::service::Service<Request<SgBody>> for Retry<P, S>
where
P: Policy + Clone,
S: hyper::service::Service<Request<SgBody>, Response = Response<SgBody>, Error = Infallible> + Clone,
{
type Response = Response<SgBody>;
type Error = Infallible;
type Future = RetryFuture<P, S>;
fn call(&self, req: Request<SgBody>) -> Self::Future {
RetryFuture::new(self.policy.clone(), self.service.clone(), req)
}
}
def_plugin!("retry", RetryPlugin, SgPluginRetryConfig; #[cfg(feature = "schema")] schema;);
#[cfg(feature = "schema")]
crate::schema! {
RetryPlugin,
SgPluginRetryConfig
}