mll-axum-utils 0.2.1

一个 Axum 的工具库
use self::blacklist::Blacklist;
use axum::{body::Body, http::Request, response::Response};
use futures_util::future::BoxFuture;
use std::collections::HashSet;
use std::net::IpAddr;
use std::{
    sync::Arc,
    task::{Context, Poll},
};
use tower::{Layer, Service};

mod blacklist;

/// Hooks that run around a call to an inner service.
///
/// Both methods have pass-through default implementations, so an implementor
/// only needs to override the hook it cares about.
pub trait Intercept {
    /// Called with the incoming request before the inner service sees it.
    ///
    /// Returning `Ok(request)` forwards that request to the inner service;
    /// returning `Err(response)` makes [`InterceptorService`] answer with that
    /// response without calling the inner service. The default forwards the
    /// request unchanged.
    fn before(&self, req: Request<Body>) -> Result<Request<Body>, Response> {
        Ok(req)
    }
    /// Called with the inner service's response when that call succeeded.
    ///
    /// The returned response is what the caller receives. The default returns
    /// the response unchanged.
    fn after(&self, res: Response) -> Response {
        res
    }
}

/// Interceptor layer.
///
/// Holds a shared hook object of type `T` and, through its `Layer`
/// implementation, wraps services in [`InterceptorService`].
pub struct Interceptor<T> {
    /// The hook object, shared between this layer and every service it builds.
    pub intercept: Arc<T>,
}

// Implemented by hand instead of `#[derive(Clone)]`: the derive would add a
// `T: Clone` bound, although cloning only bumps the `Arc` reference count and
// never clones `T` itself.
impl<T> Clone for Interceptor<T> {
    fn clone(&self) -> Self {
        Self {
            intercept: Arc::clone(&self.intercept),
        }
    }
}

impl<T> Interceptor<T> {
    /// Creates an interceptor layer around the given hook object.
    ///
    /// `intercept` is moved into an [`Arc`]; this constructor never clones it,
    /// so no `Clone` bound is placed on `T`.
    pub fn new(intercept: T) -> Self {
        Self {
            intercept: Arc::new(intercept),
        }
    }
}

impl<S, T> Layer<S> for Interceptor<T> {
    type Service = InterceptorService<S, T>;

    fn layer(&self, inner: S) -> Self::Service {
        InterceptorService {
            inner,
            intercept: self.intercept.clone(),
        }
    }
}

/// Interceptor service.
///
/// Pairs an inner service with the shared hook object whose callbacks are run
/// around each call to that inner service.
pub struct InterceptorService<S, T> {
    /// The wrapped service that handles requests the hooks let through.
    pub inner: S,
    /// The hook object, shared with the layer that created this service.
    pub intercept: Arc<T>,
}

// Implemented by hand instead of `#[derive(Clone)]`: the derive would require
// `T: Clone` as well, although only the inner service is actually cloned; the
// hook object is shared by bumping the `Arc` reference count.
impl<S, T> Clone for InterceptorService<S, T>
where
    S: Clone,
{
    fn clone(&self) -> Self {
        Self {
            inner: self.inner.clone(),
            intercept: Arc::clone(&self.intercept),
        }
    }
}

impl<S, T> Service<Request<Body>> for InterceptorService<S, T>
where
    T: Intercept + Sync + Send + 'static,
    S: Service<Request<Body>, Response = Response> + Send + 'static,
    S::Future: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: Request<Body>) -> Self::Future {
        let intercept = self.intercept.clone();
        let future = self.intercept.before(req).map(|m| self.inner.call(m));
        Box::pin(async move {
            match future {
                Ok(res) => Ok(intercept.after(res.await?)),
                Err(res) => Ok(res),
            }
        })
    }
}

/// Builds an [`Interceptor`] layer around a [`Blacklist`] holding `ips`.
///
/// What the blacklist does with these addresses is defined in the `blacklist`
/// module (presumably requests from them are rejected; confirm there).
pub fn blacklist(ips: HashSet<IpAddr>) -> Interceptor<Blacklist> {
    let list = Blacklist { ips };
    Interceptor::new(list)
}

pub fn blacklist_vec(ips: Vec<&str>) -> Interceptor<Blacklist> {
    let ips = ips.into_iter().map(|m| m.parse().unwrap()).collect();
    Interceptor::new(Blacklist { ips })
}