use std::future::{ready, Future, Ready};
use std::pin::Pin;
use std::rc::Rc;
use std::sync::Arc;
use std::task::{Context, Poll};

use actix_service::{Service, Transform};
use actix_web::{
    body::EitherBody,
    dev::{ServiceRequest, ServiceResponse},
    http::StatusCode,
    Error, HttpResponse,
};

use crate::algorithm::Algorithm;
use crate::decision::Decision;
use crate::quota::Quota;
use crate::storage::Storage;
/// Rate-limiting middleware factory: implements actix's `Transform` to wrap a
/// service in a [`RateLimiterMiddleware`].
///
/// The counter storage sits behind an [`Arc`], so clones of the limiter and
/// every middleware instance it creates share the same store.
pub struct RateLimiter<S, A> {
    /// The request budget enforced per client key.
    quota: Quota,
    /// The limiting algorithm that decides whether a request is admitted.
    algorithm: A,
    /// Shared backing store for per-key rate-limit state.
    storage: Arc<S>,
}
impl<S, A> RateLimiter<S, A>
where
    S: Storage,
    A: Algorithm + Clone,
{
    /// Builds a limiter that enforces `quota` with `algorithm`.
    ///
    /// Takes ownership of `storage` and moves it behind an [`Arc`] so it can be
    /// shared cheaply by clones and by the middleware instances created later.
    pub fn new(storage: S, algorithm: A, quota: Quota) -> Self {
        let shared_storage = Arc::new(storage);
        Self {
            quota,
            algorithm,
            storage: shared_storage,
        }
    }
}
impl<S, A> Clone for RateLimiter<S, A>
where
A: Clone,
{
fn clone(&self) -> Self {
Self {
storage: self.storage.clone(),
algorithm: self.algorithm.clone(),
quota: self.quota.clone(),
}
}
}
impl<S, A, Svc, B> Transform<Svc, ServiceRequest> for RateLimiter<S, A>
where
S: Storage + Send + Sync + 'static,
A: Algorithm + Clone + Send + Sync + 'static,
Svc: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
Svc::Future: 'static,
B: 'static,
{
type Response = ServiceResponse<EitherBody<B>>;
type Error = Error;
type Transform = RateLimiterMiddleware<S, A, Svc>;
type InitError = ();
type Future = Ready<Result<Self::Transform, Self::InitError>>;
fn new_transform(&self, service: Svc) -> Self::Future {
ready(Ok(RateLimiterMiddleware {
service,
storage: self.storage.clone(),
algorithm: self.algorithm.clone(),
quota: self.quota.clone(),
}))
}
}
pub struct RateLimiterMiddleware<S, A, Svc> {
service: Svc,
storage: Arc<S>,
algorithm: A,
quota: Quota,
}
impl<S, A, Svc, B> Service<ServiceRequest> for RateLimiterMiddleware<S, A, Svc>
where
S: Storage + Send + Sync + 'static,
A: Algorithm + Clone + Send + Sync + 'static,
Svc: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
Svc::Future: 'static,
B: 'static,
{
type Response = ServiceResponse<EitherBody<B>>;
type Error = Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.service.poll_ready(cx)
}
fn call(&self, req: ServiceRequest) -> Self::Future {
let storage = self.storage.clone();
let algorithm = self.algorithm.clone();
let quota = self.quota.clone();
let key = extract_key(&req);
let fut = self.service.call(req);
Box::pin(async move {
let decision = algorithm
.check_and_record(&*storage, &key, "a)
.await
.unwrap_or_else(|_| {
Decision::allowed(crate::decision::RateLimitInfo::new(
quota.max_requests(),
quota.max_requests(),
std::time::Instant::now() + quota.window(),
std::time::Instant::now(),
))
});
if decision.is_denied() {
let info = decision.info();
let retry_after = info
.retry_after
.map(|d| d.as_secs().to_string())
.unwrap_or_else(|| "60".to_string());
let body = format!(
r#"{{"error":"Too Many Requests","retry_after":{},"remaining":{},"limit":{}}}"#,
retry_after, info.remaining, info.limit
);
let response = HttpResponse::build(StatusCode::TOO_MANY_REQUESTS)
.insert_header(("Content-Type", "application/json"))
.insert_header(("X-RateLimit-Limit", info.limit.to_string()))
.insert_header(("X-RateLimit-Remaining", info.remaining.to_string()))
.insert_header(("X-RateLimit-Reset", info.reset_seconds().to_string()))
.insert_header(("Retry-After", retry_after))
.body(body);
return Err(actix_web::error::InternalError::from_response(
"Rate limited",
response,
)
.into());
}
let res = fut.await?;
Ok(res.map_into_left_body())
})
}
}
/// Derives the rate-limit bucket key for `req`.
///
/// Sources are tried in order; the first non-blank one wins:
/// 1. the first comma-separated entry of `X-Forwarded-For` (`ip:<addr>`),
/// 2. `X-Real-IP` (`ip:<addr>`),
/// 3. the peer address reported by actix's connection info (`ip:<addr>`),
/// 4. otherwise the request path (`path:<path>`).
///
/// Header values are trimmed. Empty or whitespace-only values are skipped, so
/// clients sending a blank header no longer all share the single key `"ip:"`.
///
/// SECURITY NOTE(review): `X-Forwarded-For` and `X-Real-IP` are client-controlled
/// unless a trusted reverse proxy overwrites them. If this service is reachable
/// directly, a client can rotate these headers to get a fresh bucket per request.
fn extract_key(req: &ServiceRequest) -> String {
    let headers = req.headers();

    let forwarded = headers
        .get("x-forwarded-for")
        .and_then(|value| value.to_str().ok())
        .and_then(|value| value.split(',').next())
        .map(str::trim)
        .filter(|ip| !ip.is_empty());
    if let Some(ip) = forwarded {
        return format!("ip:{}", ip);
    }

    let real_ip = headers
        .get("x-real-ip")
        .and_then(|value| value.to_str().ok())
        .map(str::trim)
        .filter(|ip| !ip.is_empty());
    if let Some(ip) = real_ip {
        return format!("ip:{}", ip);
    }

    if let Some(peer) = req.connection_info().peer_addr() {
        return format!("ip:{}", peer);
    }

    format!("path:{}", req.path())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::algorithm::GCRA;
    use crate::storage::MemoryStorage;

    /// `RateLimiter::new` stores the quota it was given unchanged.
    #[test]
    fn test_rate_limiter_creation() {
        let quota = Quota::per_second(10);
        let limiter = RateLimiter::new(MemoryStorage::new(), GCRA::new(), quota);
        assert_eq!(limiter.quota.max_requests(), 10);
    }
}