use self::blacklist::Blacklist;
use axum::{body::Body, http::Request, response::Response};
use futures_util::future::BoxFuture;
use std::collections::HashSet;
use std::net::IpAddr;
use std::{
sync::Arc,
task::{Context, Poll},
};
use tower::{Layer, Service};
mod blacklist;
/// Hooks that run around a wrapped service's request handling.
///
/// Both methods have pass-through defaults, so an implementor overrides only
/// the hook it needs.
pub trait Intercept {
/// Inspects or rewrites the incoming request before the wrapped service sees it.
///
/// Return `Ok(req)` to continue with the (possibly modified) request, or
/// `Err(response)` to answer with `response` instead. As used by
/// `InterceptorService::call` in this file, an `Err` response is returned
/// directly: the wrapped service is not called and `after` is not applied.
///
/// The default implementation forwards the request unchanged.
fn before(&self, req: Request<Body>) -> Result<Request<Body>, Response> {
Ok(req)
}
/// Inspects or rewrites the response produced by the wrapped service.
///
/// The default implementation returns the response unchanged.
fn after(&self, res: Response) -> Response {
res
}
}
/// Layer that attaches a shared interceptor of type `T` to the services it wraps.
///
/// The interceptor is stored behind an [`Arc`], so cloning the layer only
/// bumps a reference count and never copies `T` itself.
pub struct Interceptor<T> {
    /// Shared interceptor handed to every service produced by this layer.
    pub intercept: Arc<T>,
}

// Implemented by hand on purpose: `#[derive(Clone)]` would add a `T: Clone`
// bound even though only the `Arc` handle is ever cloned.
impl<T> Clone for Interceptor<T> {
    fn clone(&self) -> Self {
        Self {
            intercept: Arc::clone(&self.intercept),
        }
    }
}

impl<T> Interceptor<T> {
    /// Wraps `intercept` in an [`Arc`] so it can be shared between clones of
    /// the layer and the services built from it.
    ///
    /// `T` does not need to be `Clone`; the value is moved into the `Arc`.
    pub fn new(intercept: T) -> Self {
        Self {
            intercept: Arc::new(intercept),
        }
    }
}
impl<S, T> Layer<S> for Interceptor<T> {
type Service = InterceptorService<S, T>;
fn layer(&self, inner: S) -> Self::Service {
InterceptorService {
inner,
intercept: self.intercept.clone(),
}
}
}
/// Service wrapper that pairs an inner service with a shared interceptor.
pub struct InterceptorService<S, T> {
    /// The wrapped service that handles requests.
    pub inner: S,
    /// Shared interceptor; cloning the service clones only this `Arc` handle.
    pub intercept: Arc<T>,
}

// Implemented by hand on purpose: `#[derive(Clone)]` would also require
// `T: Clone`, although `T` lives behind an `Arc` and is never cloned itself.
impl<S, T> Clone for InterceptorService<S, T>
where
    S: Clone,
{
    fn clone(&self) -> Self {
        Self {
            inner: self.inner.clone(),
            intercept: Arc::clone(&self.intercept),
        }
    }
}
impl<S, T> Service<Request<Body>> for InterceptorService<S, T>
where
T: Intercept + Sync + Send + 'static,
S: Service<Request<Body>, Response = Response> + Send + 'static,
S::Future: Send + 'static,
{
type Response = S::Response;
type Error = S::Error;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: Request<Body>) -> Self::Future {
let intercept = self.intercept.clone();
let future = self.intercept.before(req).map(|m| self.inner.call(m));
Box::pin(async move {
match future {
Ok(res) => Ok(intercept.after(res.await?)),
Err(res) => Ok(res),
}
})
}
}
/// Builds an [`Interceptor`] layer around a [`Blacklist`] holding `ips`.
pub fn blacklist(ips: HashSet<IpAddr>) -> Interceptor<Blacklist> {
    let filter = Blacklist { ips };
    Interceptor::new(filter)
}
pub fn blacklist_vec(ips: Vec<&str>) -> Interceptor<Blacklist> {
let ips = ips.into_iter().map(|m| m.parse().unwrap()).collect();
Interceptor::new(Blacklist { ips })
}