use std::{
future::Future,
pin::Pin,
sync::Arc,
task::{Context, Poll},
time::{Duration, Instant},
};
use tokio::sync::Mutex;
use tower::Layer;
/// Tower layer that throttles requests with a token bucket.
///
/// The bucket lives behind an `Arc`, so every clone of this layer and every
/// service it produces through [`Layer::layer`] draws from the same bucket.
#[derive(Clone, Debug)]
pub struct RateLimitLayer {
    state: Arc<Mutex<RateLimitState>>,
}

impl RateLimitLayer {
    /// Builds a limiter that starts with `requests` tokens and refills at a
    /// rate of `requests` tokens per `period`.
    pub fn new(requests: u32, period: Duration) -> Self {
        let bucket = RateLimitState::new(requests, period);
        Self {
            state: Arc::new(Mutex::new(bucket)),
        }
    }

    /// Convenience for `new(requests, Duration::from_secs(1))`.
    pub fn per_second(requests: u32) -> Self {
        let one_second = Duration::from_secs(1);
        Self::new(requests, one_second)
    }

    /// Single-token bucket that refills once per `delay`, i.e. `new(1, delay)`.
    pub fn with_min_delay(delay: Duration) -> Self {
        Self::new(1, delay)
    }
}

impl<S> Layer<S> for RateLimitLayer {
    type Service = RateLimitService<S>;

    /// Wraps `inner`, sharing this layer's bucket with the new service.
    fn layer(&self, inner: S) -> Self::Service {
        let shared = Arc::clone(&self.state);
        RateLimitService {
            service: inner,
            state: shared,
        }
    }
}
/// Token-bucket bookkeeping behind the rate limiter.
///
/// The bucket holds at most `capacity` tokens and refills continuously at
/// `refill_rate` tokens per nanosecond; each admitted request consumes one token.
#[derive(Debug)]
struct RateLimitState {
    /// Maximum number of stored tokens (the largest permitted burst).
    capacity: u32,
    /// Tokens currently available; may be fractional between refills.
    tokens: f64,
    /// Tokens added per elapsed nanosecond.
    refill_rate: f64,
    /// When `tokens` was last brought up to date.
    last_refill: Instant,
}

impl RateLimitState {
    /// Creates a full bucket allowing `requests` calls per `period`.
    fn new(requests: u32, period: Duration) -> Self {
        // `requests == 0` yields a rate of 0 (or NaN when `period` is also zero);
        // `try_acquire` reports a maximal wait in both cases instead of spinning.
        let refill_rate = f64::from(requests) / period.as_nanos() as f64;
        Self {
            capacity: requests,
            tokens: f64::from(requests),
            refill_rate,
            last_refill: Instant::now(),
        }
    }

    /// Takes one token if one is available.
    ///
    /// Returns `None` when a token was consumed. Otherwise returns
    /// `Some(wait)`, where `wait` estimates how long until a full token has
    /// accrued. The estimate is rounded *up* to a whole nanosecond, so it is
    /// never zero and never shorter than the true deficit.
    fn try_acquire(&mut self) -> Option<Duration> {
        self.refill();
        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            return None;
        }
        let needed = 1.0 - self.tokens;
        // Truncating here (the previous behaviour) could produce a zero or
        // too-short sleep, making callers re-lock and re-check needlessly.
        let wait_nanos = (needed / self.refill_rate).ceil();
        // A zero rate gives +inf and a 0/0 rate gives NaN. `NaN as u64` is 0,
        // which would make callers spin forever, so both map to the maximum wait.
        let wait = if wait_nanos.is_finite() {
            // Float-to-int `as` saturates for waits beyond u64::MAX nanoseconds.
            Duration::from_nanos(wait_nanos as u64)
        } else {
            Duration::from_nanos(u64::MAX)
        };
        Some(wait)
    }

    /// Adds the tokens accrued since `last_refill`, capped at `capacity`.
    fn refill(&mut self) {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_refill);
        let new_tokens = elapsed.as_nanos() as f64 * self.refill_rate;
        // `f64::min` returns the other operand when one is NaN, so a NaN/inf
        // product (zero period) still leaves `tokens` finite and <= capacity.
        self.tokens = (self.tokens + new_tokens).min(f64::from(self.capacity));
        self.last_refill = now;
    }
}
/// Service produced by [`RateLimitLayer`].
///
/// Each call first waits until the shared token bucket yields a token, then
/// forwards the request to the wrapped service.
#[derive(Clone, Debug)]
pub struct RateLimitService<S> {
    service: S,
    state: Arc<Mutex<RateLimitState>>,
}

impl<S, Request> tower::Service<Request> for RateLimitService<S>
where
    S: tower::Service<Request> + Clone + Send + 'static,
    S::Future: Send,
    Request: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Delegates readiness to the inner service; throttling happens inside the
    /// future returned by `call`.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.service.poll_ready(cx)
    }

    fn call(&mut self, request: Request) -> Self::Future {
        let state = Arc::clone(&self.state);
        // `poll_ready` was driven on `self.service`, not on a fresh clone. Move
        // that ready instance into the future and leave the unpolled clone behind
        // for the next `poll_ready`. Calling a fresh clone breaks the `Service`
        // contract and can panic for inner services whose readiness is tracked
        // per instance (e.g. a concurrency limit holding a permit).
        let replacement = self.service.clone();
        let mut service = std::mem::replace(&mut self.service, replacement);
        Box::pin(async move {
            loop {
                // The lock is released at the end of this block, so it is never
                // held across the sleep below.
                let wait_time = {
                    let mut state = state.lock().await;
                    state.try_acquire()
                };
                match wait_time {
                    None => break,
                    Some(duration) => tokio::time::sleep(duration).await,
                }
            }
            service.call(request).await
        })
    }
}
#[cfg(test)]
mod tests {
    // `Duration`, `Instant`, `Context`, `Poll` and `Layer` all arrive via this glob.
    use super::*;

    #[tokio::test]
    async fn test_rate_limit_state_immediate_acquire() {
        let mut state = RateLimitState::new(10, Duration::from_secs(1));
        assert!(state.try_acquire().is_none());
        assert!(state.try_acquire().is_none());
    }

    #[tokio::test]
    async fn test_rate_limit_state_exhaustion() {
        let mut state = RateLimitState::new(2, Duration::from_secs(1));
        assert!(state.try_acquire().is_none());
        assert!(state.try_acquire().is_none());
        let wait = state.try_acquire();
        assert!(wait.is_some());
    }

    #[tokio::test]
    async fn test_rate_limit_state_refill() {
        let mut state = RateLimitState::new(10, Duration::from_secs(1));
        // Drain the whole initial burst; each of these must succeed immediately.
        for _ in 0..10 {
            assert!(state.try_acquire().is_none());
        }
        // At 10 tokens/s, 200 ms accrues ~2 tokens.
        tokio::time::sleep(Duration::from_millis(200)).await;
        assert!(state.try_acquire().is_none());
    }

    #[tokio::test]
    async fn test_rate_limit_layer_construction() {
        let layer = RateLimitLayer::new(10, Duration::from_secs(1));
        assert_eq!(layer.state.lock().await.capacity, 10);
    }

    #[tokio::test]
    async fn test_rate_limit_per_second() {
        let layer = RateLimitLayer::per_second(25);
        assert_eq!(layer.state.lock().await.capacity, 25);
    }

    #[tokio::test]
    async fn test_rate_limit_enforces_rate() {
        #[derive(Clone)]
        struct InstantService;

        impl tower::Service<()> for InstantService {
            type Response = ();
            type Error = std::convert::Infallible;
            type Future = std::future::Ready<Result<(), std::convert::Infallible>>;

            fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                Poll::Ready(Ok(()))
            }

            fn call(&mut self, _req: ()) -> Self::Future {
                std::future::ready(Ok(()))
            }
        }

        let layer = RateLimitLayer::per_second(5);
        let mut service = layer.layer(InstantService);
        let start = Instant::now();
        // Five requests drain the initial burst; the sixth waits ~200 ms for a refill.
        for _ in 0..6 {
            // Honour the tower contract: drive `poll_ready` to Ready before `call`.
            std::future::poll_fn(|cx| tower::Service::<()>::poll_ready(&mut service, cx))
                .await
                .unwrap();
            tower::Service::call(&mut service, ()).await.unwrap();
        }
        let elapsed = start.elapsed();
        // Slightly below 200 ms to tolerate timer and clock jitter.
        assert!(elapsed >= Duration::from_millis(180), "elapsed only {elapsed:?}");
    }
}