use descartes_core::{scheduler, SimTime};
use http::Request;
use pin_project::pin_project;
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use tower::{Layer, Service};
use crate::tower::{ServiceError, SimBody};
/// Tower [`Layer`] that wraps a service in a [`DesRateLimit`] token bucket.
///
/// The layer only stores the configuration; every call to `layer` builds a
/// fresh [`DesRateLimit`] from it.
#[derive(Debug, Clone)]
pub struct DesRateLimitLayer {
    /// Refill rate handed to every [`DesRateLimit`] built by this layer.
    rate: f64,
    /// Bucket capacity handed to every [`DesRateLimit`] built by this layer.
    burst: usize,
}

impl DesRateLimitLayer {
    /// Creates a layer configuration with the given refill `rate` and `burst`
    /// capacity.
    #[must_use]
    pub fn new(rate: f64, burst: usize) -> Self {
        Self { rate, burst }
    }
}
impl<S> Layer<S> for DesRateLimitLayer {
    type Service = DesRateLimit<S>;

    /// Wraps `inner` in a new [`DesRateLimit`] configured with this layer's
    /// rate and burst values.
    fn layer(&self, inner: S) -> Self::Service {
        // Both fields are `Copy`, so destructuring through the reference
        // copies them out without moving from `self`.
        let Self { rate, burst } = *self;
        DesRateLimit::new(inner, rate, burst)
    }
}
/// Token-bucket rate limiter wrapping an inner service.
///
/// The bucket state lives behind `Arc`s, so clones of this service share the
/// same token balance and refill timestamp.
#[derive(Clone)]
pub struct DesRateLimit<S> {
    /// The wrapped service that receives requests which obtained a token.
    inner: S,
    /// Tokens added per second of elapsed scheduler time.
    rate: f64,
    /// Upper bound on the number of tokens the bucket can hold.
    burst: usize,
    /// Current (fractional) token balance, shared between clones.
    tokens: Arc<Mutex<f64>>,
    /// Scheduler time at which the bucket was last refilled, shared between
    /// clones.
    last_refill: Arc<Mutex<SimTime>>,
}
impl<S> DesRateLimit<S> {
    /// Creates a rate limiter around `inner`.
    ///
    /// The bucket starts full (`burst` tokens) and its last-refill timestamp
    /// starts at `SimTime::zero()`. `rate` is the number of tokens added per
    /// second of elapsed scheduler time; `burst` caps the balance.
    pub fn new(inner: S, rate: f64, burst: usize) -> Self {
        Self {
            inner,
            rate,
            burst,
            tokens: Arc::new(Mutex::new(burst as f64)),
            last_refill: Arc::new(Mutex::new(SimTime::zero())),
        }
    }

    /// Refills the bucket for the time elapsed since the previous call and
    /// then tries to take one token.
    ///
    /// Returns `true` when a token was consumed and `false` when fewer than
    /// one token is available.
    ///
    /// # Panics
    ///
    /// Panics if either state mutex is poisoned.
    fn try_acquire_token(&self) -> bool {
        // Fall back to time zero when the scheduler does not report a time.
        let current_time = scheduler::current_time().unwrap_or(SimTime::zero());

        // Both guards are always taken in this order (tokens, then timestamp)
        // and held for the whole update so the two values change together.
        let mut tokens = self.tokens.lock().unwrap();
        let mut last_refill = self.last_refill.lock().unwrap();

        let elapsed = current_time.duration_since(*last_refill);

        // Clamp the refill to be non-negative. `f64::max` returns the non-NaN
        // operand, so a NaN product (NaN rate, or `0 * inf`) becomes 0 instead
        // of turning the `min` below into "reset to a full bucket", and a
        // negative rate can no longer drain the balance.
        let tokens_to_add = (elapsed.as_secs_f64() * self.rate).max(0.0);

        let capacity = self.burst as f64;
        *tokens = (*tokens + tokens_to_add).min(capacity);
        *last_refill = current_time;

        if *tokens >= 1.0 {
            *tokens -= 1.0;
            true
        } else {
            false
        }
    }
}
/// Response future returned by [`DesRateLimit`].
///
/// It either forwards to the inner service's future or yields a stored error
/// the first time it is polled.
#[pin_project]
pub struct DesRateLimitFuture<F> {
    /// Future of the inner service; `None` when the inner service was not
    /// called for this request.
    #[pin]
    inner: Option<F>,
    /// Error to return on the first poll; taken (left as `None`) once it has
    /// been yielded.
    immediate_error: Option<ServiceError>,
}
impl<S, ReqBody> Service<Request<ReqBody>> for DesRateLimit<S>
where
    S: Service<Request<ReqBody>, Response = http::Response<SimBody>, Error = ServiceError>,
{
    type Response = S::Response;
    type Error = ServiceError;
    type Future = DesRateLimitFuture<S::Future>;

    /// Readiness is delegated entirely to the inner service; the token bucket
    /// is only consulted in `call`.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Forwards `req` to the inner service if a token can be acquired.
    ///
    /// When no token is available the inner service is not called and the
    /// returned future resolves to `ServiceError::Overloaded`.
    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
        if self.try_acquire_token() {
            DesRateLimitFuture {
                inner: Some(self.inner.call(req)),
                immediate_error: None,
            }
        } else {
            DesRateLimitFuture {
                inner: None,
                immediate_error: Some(ServiceError::Overloaded),
            }
        }
    }
}
impl<F> Future for DesRateLimitFuture<F>
where
    F: Future<Output = Result<http::Response<SimBody>, ServiceError>>,
{
    type Output = Result<http::Response<SimBody>, ServiceError>;

    /// Yields the stored error if one is pending, otherwise polls the inner
    /// future.
    ///
    /// If neither an error nor an inner future is present (for example when
    /// polled again after the stored error was already returned), resolves to
    /// `ServiceError::RateLimiterInvalidState`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let projected = self.project();

        // A stored error takes precedence and is handed out exactly once.
        if let Some(rejection) = projected.immediate_error.take() {
            return Poll::Ready(Err(rejection));
        }

        match projected.inner.as_pin_mut() {
            Some(pending) => pending.poll(cx),
            None => Poll::Ready(Err(ServiceError::RateLimiterInvalidState)),
        }
    }
}