use std::sync::Arc;
use std::task::{Context, Poll};
use axum::body::Body;
use axum::response::IntoResponse;
use futures::future::BoxFuture;
use hyper::Request;
use tower::{Layer, Service};
use super::extractors::extract_client_ip;
use super::limiter::RateLimitManager;
use crate::app::ServerMode;
/// Tower [`Layer`] that puts a per-client-IP rate-limit check in front of a service.
///
/// Each service it produces consults the shared limiter under a fixed `category`
/// before forwarding a request.
#[derive(Clone)]
pub struct RateLimitLayer {
    /// Limiter state shared by every service this layer produces.
    manager: Arc<RateLimitManager>,
    /// Category name passed to the limiter alongside the client IP.
    category: &'static str,
    /// Passed to `extract_client_ip` when resolving a request's client IP.
    mode: ServerMode,
}
impl RateLimitLayer {
pub fn new(manager: Arc<RateLimitManager>, category: &'static str, mode: ServerMode) -> Self {
Self { manager, category, mode }
}
}
impl<S> Layer<S> for RateLimitLayer {
type Service = RateLimitService<S>;
fn layer(&self, inner: S) -> Self::Service {
RateLimitService {
inner,
manager: self.manager.clone(),
category: self.category,
mode: self.mode,
}
}
}
/// Service produced by [`RateLimitLayer`].
///
/// When the limiter rejects a request's client IP, this service answers with the
/// limiter's error response instead of calling `inner`.
#[derive(Clone)]
pub struct RateLimitService<S> {
    /// Wrapped service that handles requests which pass the check.
    inner: S,
    /// Limiter state shared with the originating layer.
    manager: Arc<RateLimitManager>,
    /// Category name passed to the limiter alongside the client IP.
    category: &'static str,
    /// Passed to `extract_client_ip` when resolving a request's client IP.
    mode: ServerMode,
}
impl<S> Service<Request<Body>> for RateLimitService<S>
where
S: Service<Request<Body>, Response = axum::response::Response> + Clone + Send + 'static,
S::Future: Send + 'static,
{
type Response = S::Response;
type Error = S::Error;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: Request<Body>) -> Self::Future {
let manager = self.manager.clone();
let category = self.category;
let mode = self.mode;
let mut inner = self.inner.clone();
Box::pin(async move {
let client_ip = extract_client_ip(&req, &mode);
if let Some(ip) = client_ip {
if let Err(error) = manager.check(&ip, category) {
return Ok(error.into_response());
}
}
inner.call(req).await
})
}
}