use std::{future::Future, pin::Pin};
use tower::Service;
use crate::{
middleware::{error::RateLimitError, KeyExtractor},
storage::RateLimitStore,
};
/// Tower middleware service that consults a rate-limit store before
/// forwarding each request to the wrapped service.
///
/// Fields are `pub(crate)`; presumably a layer elsewhere in the crate builds
/// this value (confirm against callers).
#[derive(Clone, Debug)]
pub struct RateLimitService<S, Store, Extractor> {
    /// Wrapped service that receives every request not rejected by the limiter.
    pub(crate) inner: S,
    /// Store asked (via `allow`) whether a request's key may proceed.
    pub(crate) store: Store,
    /// Derives the rate-limit key from a request; requests for which it
    /// yields no key bypass the limiter entirely.
    pub(crate) extractor: Extractor,
}
/// Gates every request through the rate limiter before it reaches `inner`.
///
/// Readiness is delegated to the wrapped service. On each call the extractor
/// is asked for a key. If it yields one and the store refuses it, the request
/// is answered with `RateLimitError::Rejected` and `inner` is never called.
/// Otherwise the request is forwarded, and any error from the wrapped
/// service is wrapped in `RateLimitError::Inner`.
impl<S, Store, Extractor, Req> Service<Req> for RateLimitService<S, Store, Extractor>
where
    S: Service<Req> + Send,
    S::Future: Send + 'static,
    Store: RateLimitStore,
    Extractor: KeyExtractor<Req>,
{
    type Response = S::Response;
    type Error = RateLimitError<S::Error>;
    // Boxed so the rejection path and the forwarding path share one concrete type.
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Reports the wrapped service's readiness, wrapping any error in
    /// `RateLimitError::Inner`.
    fn poll_ready(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        let inner_readiness = self.inner.poll_ready(cx);
        inner_readiness.map_err(RateLimitError::Inner)
    }

    /// Checks the request's key against the store, then rejects or forwards it.
    fn call(&mut self, req: Req) -> Self::Future {
        // The store is consulted only when a key can be extracted; keyless
        // requests are never limited.
        let over_limit = match self.extractor.extract(&req) {
            Some(key) => !self.store.allow(&key),
            None => false,
        };

        if over_limit {
            // Short-circuit: the wrapped service never sees a rejected request.
            return Box::pin(async {
                Err(RateLimitError::Rejected(crate::errors::RateLimitExceeded))
            });
        }

        let forwarded = self.inner.call(req);
        Box::pin(async move {
            let outcome = forwarded.await;
            outcome.map_err(RateLimitError::Inner)
        })
    }
}