use std::sync::atomic::AtomicU64;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::task::Context;
use std::task::Poll;
use std::time::SystemTime;
use std::time::UNIX_EPOCH;
use futures::ready;
use tower::Service;
use super::future::ResponseFuture;
use super::Rate;
use crate::plugins::traffic_shaping::rate::error::RateLimited;
/// Service wrapper that limits how many requests are let through per time window.
///
/// The counters and the window start are held in `Arc`s, so clones of this
/// struct (it derives `Clone`) share the same atomics.
#[derive(Debug, Clone)]
pub(crate) struct RateLimit<T> {
/// The wrapped service that readiness checks and calls are delegated to.
pub(crate) inner: T,
/// The configured limit: `rate.num()` is the request budget compared against
/// the estimate in `poll_ready`, `rate.per()` is the window length.
pub(crate) rate: Rate,
/// Start of the current window, in milliseconds since the UNIX epoch.
pub(crate) window_start: Arc<AtomicU64>,
/// Value `current_nb_requests` had when the window last rolled over.
pub(crate) previous_nb_requests: Arc<AtomicUsize>,
/// Counter for the current window; reset to 1 on rollover and incremented
/// each time `poll_ready` lets a request past the limit check.
pub(crate) current_nb_requests: Arc<AtomicUsize>,
}
impl<S, Request> Service<Request> for RateLimit<S>
where
    S: Service<Request>,
    S::Error: Into<tower::BoxError>,
{
    type Response = S::Response;
    type Error = tower::BoxError;
    type Future = ResponseFuture<S::Future>;

    /// Checks the request budget for the current window, then polls the inner service.
    ///
    /// The window is restarted once more than `rate.per()` has elapsed since
    /// `window_start` (or when the wall clock is observed to be earlier than
    /// `window_start`). On restart the finished window's counter is moved to
    /// `previous_nb_requests` and the current counter is reset to 1.
    ///
    /// # Errors
    ///
    /// Returns a boxed [`RateLimited`] error when the estimated number of requests
    /// exceeds `rate.num()`. Errors from the inner service's `poll_ready` are
    /// converted into [`tower::BoxError`].
    ///
    /// # Panics
    ///
    /// Panics if the system clock reports a time before the UNIX epoch.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // Window length in milliseconds; computed once and reused below.
        let time_unit = self.rate.per().as_millis() as u64;

        let updated =
            self.window_start
                .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |window_start| {
                    let duration_now = SystemTime::now()
                        .duration_since(UNIX_EPOCH)
                        .expect("system time must be after EPOCH")
                        .as_millis() as u64;
                    // `checked_sub` instead of `-`: if the wall clock moved backwards
                    // a plain subtraction would panic in builds with overflow checks.
                    // A clock that is behind `window_start` starts a new window.
                    match duration_now.checked_sub(window_start) {
                        Some(elapsed) if elapsed <= time_unit => None,
                        _ => Some(duration_now),
                    }
                });

        // `Ok` means this call is the one that moved `window_start` forward.
        if updated.is_ok() {
            // Take the finished window's count and reset the counter in a single
            // atomic step so increments landing in between are not dropped.
            let finished_window = self.current_nb_requests.swap(1, Ordering::SeqCst);
            self.previous_nb_requests
                .store(finished_window, Ordering::SeqCst);
        }

        // NOTE(review): `window_start` is an epoch timestamp in milliseconds (it is
        // assigned `duration_now` above), so `time_unit - window_start` underflows
        // for any realistic window length and this weight evaluates to 0; the
        // estimate is then just the current counter. If weighting the previous
        // window by the remaining share of the current one was intended, this
        // needs the elapsed time within the window instead. Confirm the intended
        // behavior before changing it, since it makes the limiter stricter.
        //
        // `checked_div` avoids a division-by-zero panic when `rate.per()` is
        // shorter than one millisecond (`time_unit == 0`).
        let previous_weight = time_unit
            .checked_sub(self.window_start.load(Ordering::SeqCst))
            .unwrap_or_default()
            .checked_div(time_unit)
            .unwrap_or_default() as usize;
        let estimated_cap = (self.previous_nb_requests.load(Ordering::SeqCst) * previous_weight)
            + self.current_nb_requests.load(Ordering::SeqCst);

        if estimated_cap as u64 > self.rate.num() {
            tracing::trace!("rate limit exceeded; sleeping.");
            return Poll::Ready(Err(RateLimited::new().into()));
        }

        // NOTE(review): the counter is incremented before the inner service reports
        // readiness, so when the inner `poll_ready` returns `Pending` the same
        // request is counted again on every re-poll.
        self.current_nb_requests.fetch_add(1, Ordering::SeqCst);

        Poll::Ready(ready!(self.inner.poll_ready(cx)).map_err(Into::into))
    }

    /// Forwards `request` to the inner service and wraps the returned future
    /// in a [`ResponseFuture`].
    fn call(&mut self, request: Request) -> Self::Future {
        ResponseFuture::new(self.inner.call(request))
    }
}