use std::time::{Duration, Instant};
use tokio::{
select,
sync::{Mutex, Semaphore, SemaphorePermit},
};
use crate::Service;
/// Service wrapper that meters calls to `inner` through a semaphore of `permits`
/// permits that is periodically refilled.
///
/// Permits used by `call` are forgotten rather than released, so capacity only
/// comes back through the refill loop driven from `acquire`.
#[derive(Debug)]
pub struct RateLimit<S> {
/// The wrapped service that actually handles requests.
inner: S,
/// Permits still available until the next refill.
semaphore: Semaphore,
/// Reference time from which the next refill deadline is computed
/// (`*last_update + interval`); initialised at construction. The lock is held
/// for the whole refill loop, so only one waiter drives refills at a time.
last_update: Mutex<Instant>,
/// Refill period added to `last_update` to get the next refill deadline.
interval: Duration,
/// Number of permits the semaphore is reset to on each refill.
permits: usize,
}
impl<S> RateLimit<S> {
    /// Wraps `inner` in a limiter that hands out `permits` permits per `interval`.
    ///
    /// The full budget is available immediately; the first refill is scheduled
    /// `interval` after this call. With `permits == 0` no request is ever admitted,
    /// because every refill restores zero permits.
    ///
    /// # Panics
    ///
    /// Panics if `permits` exceeds [`Semaphore::MAX_PERMITS`].
    pub(crate) fn new(inner: S, interval: Duration, permits: usize) -> Self {
        let semaphore = Semaphore::new(permits);
        let last_update = Mutex::new(Instant::now());
        Self {
            inner,
            semaphore,
            last_update,
            interval,
            permits,
        }
    }
}
/// Permit produced by `RateLimit::acquire`. It pairs the inner service's permit
/// with one permit taken from the rate-limit semaphore.
///
/// If it is dropped without being passed to `call`, tokio's `SemaphorePermit`
/// returns the rate-limit permit to the semaphore. `call` instead forgets it, so
/// it stays spent until the next refill.
#[derive(Debug)]
pub struct RateLimitPermit<'a, S, Request>
where
S: Service<Request> + 'a,
{
/// Permit obtained from the wrapped service.
inner: S::Permit<'a>,
/// Slot taken from the rate-limit semaphore.
_permit: SemaphorePermit<'a>,
}
impl<Request, S> Service<Request> for RateLimit<S>
where
    S: Service<Request>,
{
    type Response = S::Response;
    type Permit<'a> = RateLimitPermit<'a, S, Request>
    where
        Self: 'a;

    /// Waits for a rate-limit permit, then acquires a permit from the inner service.
    ///
    /// While no permit is available, the caller holding `last_update` drives the
    /// refill loop:
    /// - it sleeps until `last_update + interval`;
    /// - it resets the semaphore to `permits`;
    /// - it records when the refill happened.
    ///
    /// Other waiters block on the lock. They take over once the current driver gets
    /// its permit and its refill future is dropped.
    async fn acquire(&self) -> Self::Permit<'_> {
        // Never completes on its own; `select!` drops it (releasing the lock) as
        // soon as the semaphore hands out a permit.
        let refill = async move {
            let mut last_update = self.last_update.lock().await;
            loop {
                let deadline = *last_update + self.interval;
                tokio::time::sleep_until(deadline.into()).await;
                // Discard whatever is left from the previous window and restore
                // exactly `permits`, so unused permits never accumulate.
                self.semaphore.forget_permits(usize::MAX);
                self.semaphore.add_permits(self.permits);
                // Sample the clock *after* sleeping. A timestamp taken before the
                // sleep would be stale, letting the next refill fire less than
                // `interval` after this one and admitting extra requests.
                *last_update = Instant::now();
            }
        };
        let permit = select! {
            permit = self.semaphore.acquire() => permit,
            never = refill => never,
        };
        RateLimitPermit {
            _permit: permit.expect("rate-limit semaphore is never closed"),
            inner: self.inner.acquire().await,
        }
    }

    /// Spends the rate-limit permit and forwards `request` to the inner service.
    async fn call<'a>(permit: Self::Permit<'a>, request: Request) -> Self::Response
    where
        Self: 'a,
    {
        let RateLimitPermit {
            inner,
            _permit: rate_permit,
        } = permit;
        // Forget instead of dropping: dropping would hand the permit back to the
        // semaphore, while a used permit must stay spent until the next refill.
        rate_permit.forget();
        S::call(inner, request).await
    }
}

#[cfg(test)]
mod tests {
    use std::time::{Duration, Instant};

    use crate::{service_fn, ServiceExt};

    /// Seven calls against 2 permits per 100ms: calls 3, 5 and 7 each wait for a
    /// refill, and refills are at least 100ms apart, so the run takes >= ~300ms.
    #[tokio::test]
    async fn limit() {
        let svc = service_fn(|x: u32| async move { x.to_string() })
            .rate_limit(Duration::from_millis(100), 2);
        let start = Instant::now();
        for _ in 0..7 {
            svc.oneshot(1).await;
        }
        let elapsed = start.elapsed();
        // The limiter's clock starts when it is built, slightly before `start`;
        // the small slack absorbs that gap.
        let expected = Duration::from_millis(300);
        assert!(
            elapsed + Duration::from_millis(5) >= expected,
            "7 calls finished in {elapsed:?}, expected at least {expected:?}"
        );
    }
}