mll-axum-utils 0.2.1

一个 Axum 的工具库
use std::{
    fs::File,
    io::Write,
    net::SocketAddr,
    sync::mpsc::{channel, Sender},
    task::{Context, Poll},
};

use axum::{
    body::Body,
    extract::ConnectInfo,
    http::{header::LOCATION, Request},
    response::Response,
};
use chrono::{DateTime, Local};
use color_string::{cs, fonts, Colored, Font::*};
use futures_util::future::BoxFuture;
use percent_encoding::percent_decode;
use tower::{Layer, Service};

use crate::utils::create_log_file;

/// Tower [`Layer`] that logs every request passing through an Axum service.
///
/// Each successfully answered request yields one record containing:
/// - start time and elapsed time,
/// - status code,
/// - client IP,
/// - method,
/// - percent-decoded path, and
/// - the `Location` target, if any.
///
/// The record is handed to a background writer over a channel.
///
/// The client IP is read from `ConnectInfo<SocketAddr>`. Serve the app with
/// `into_make_service_with_connect_info::<SocketAddr>()` as shown below.
///
/// # Examples
/// ```no_run
/// use axum::Router;
/// use mll_axum_utils::middleware::logger::Logger;
/// use std::net::SocketAddr;
///
/// #[tokio::main]
/// async fn main() {
///     let addr = "127.0.0.1:3000";
///     let app = Router::new().layer(Logger::default());
///
///     axum::Server::bind(&addr.parse().unwrap())
///         .serve(app.into_make_service_with_connect_info::<SocketAddr>())
///         .await
///         .unwrap();
/// }
/// ```
#[derive(Clone)]
pub struct Logger {
    /// Producer side of the channel that feeds the log writer; cloned into every service.
    sender: Sender<LogMsg>,
}

impl Logger {
    /// Creates a logger that prints records to stdout and/or appends them to a dated file.
    ///
    /// * `format` – `chrono` strftime pattern for the log-file path, e.g. `"logs/%Y-%m-%d.log"`.
    ///   It is re-evaluated whenever a record's start date differs from the date of the
    ///   currently open file. Only the calendar date is compared, so finer-grained patterns
    ///   (such as `%H`) do not cause more frequent rotation.
    /// * `stdout` – print a colored line per record to standard output.
    /// * `file_out` – append a plain-text line per record to the file named by `format`.
    ///   The first file is opened immediately, on the calling thread.
    ///
    /// Records are drained by a dedicated OS thread. That thread exits once every `Logger`
    /// and every service created from it has been dropped, which closes the channel.
    ///
    /// # Panics
    /// Panics if the operating system fails to spawn the writer thread
    /// (see [`std::thread::spawn`]).
    ///
    /// # Examples
    /// ```no_run
    /// use mll_axum_utils::middleware::logger::Logger;
    /// Logger::new("logs/%Y-%m-%d.log", true, true);
    /// ```
    pub fn new(format: &str, stdout: bool, file_out: bool) -> Self {
        let mut time = Local::now();

        let mut file = file_out.then(|| create_log_file(time.format(format).to_string()));

        let (sender, rx) = channel::<LogMsg>();
        let format = format.to_string();
        // `rx` is a *blocking* std receiver, so it must be drained on its own OS thread.
        // Iterating it inside `tokio::spawn` would pin a runtime worker forever. It would
        // deadlock a current-thread runtime, and it panics when no Tokio runtime is running.
        std::thread::spawn(move || {
            for msg in rx {
                if stdout {
                    msg.stdout();
                }

                if let Some(file) = file.as_mut() {
                    // Roll over to a new file when the record's calendar day changes.
                    if time.date_naive() != msg.begin.date_naive() {
                        time = msg.begin;
                        *file = create_log_file(time.format(&format).to_string());
                    }
                    msg.file_out(file);
                }
            }
        });

        Self { sender }
    }
}

impl Default for Logger {
    fn default() -> Self {
        Self::new("logs/%Y-%m-%d.log", true, false)
    }
}

impl<S> Layer<S> for Logger {
    type Service = LoggerService<S>;

    fn layer(&self, inner: S) -> Self::Service {
        LoggerService {
            inner,
            sender: self.sender.clone(),
        }
    }
}

/// Service produced by [`Logger`]. It forwards each request to `inner` and reports the result.
#[derive(Clone)]
pub struct LoggerService<S> {
    /// Wrapped service that actually handles the request.
    inner: S,
    /// Channel used to hand finished records to the log writer.
    sender: Sender<LogMsg>,
}

impl<S> Service<Request<Body>> for LoggerService<S>
where
    S: Service<Request<Body>, Response = Response> + Send + 'static,
    S::Future: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: Request<Body>) -> Self::Future {
        // 开始时间
        let begin = Local::now();
        // 请求方式
        let method = req.method().to_string();
        // 连接 ip
        let ip = match req.extensions().get::<ConnectInfo<SocketAddr>>() {
            Some(v) => v.0.ip().to_string(),
            None => panic!("Axum service 未配置 ConnectInfo<SocketAddr>"),
        };
        // 请求路径 解码为 utf-8
        let mut path = percent_decode(req.uri().path().as_bytes())
            .decode_utf8_lossy()
            .to_string();

        let sender = self.sender.clone();
        let future = self.inner.call(req);

        Box::pin(async move {
            let response: Self::Response = future.await?;
            // 状态码
            let status = response.status().as_u16();
            // 是否重定向
            if let Some(p) = response.headers().get(LOCATION) {
                path = format!(
                    "{path} -> {}",
                    percent_decode(p.as_bytes()).decode_utf8_lossy()
                )
            }

            let msg = LogMsg {
                logo: "[AXUM]".into(),
                begin,
                end: Local::now(),
                status,
                ip,
                method,
                path,
                other: "".into(),
            };

            if let Err(err) = sender.send(msg) {
                eprintln!("Send 日志时出现错误 -> {err}")
            }
            Ok(response)
        })
    }
}

/// One request's log record, sent from the middleware to the log writer.
struct LogMsg {
    /// Tag printed near the start of the line (the middleware sets `"[AXUM]"`).
    logo: String,
    /// Moment the request entered the middleware.
    begin: DateTime<Local>,
    /// Moment the response was obtained from the inner service.
    end: DateTime<Local>,
    /// HTTP status code of the response.
    status: u16,
    /// Client address string shown in the IP column.
    ip: String,
    /// HTTP method, e.g. `GET`.
    method: String,
    /// Percent-decoded request path, possibly followed by ` -> <Location>`.
    path: String,
    /// Extra trailing text; the middleware leaves it empty.
    other: String,
}

impl LogMsg {
    /// Request duration rendered as `"<n>ms"`; shared by both output formats.
    fn elapsed(&self) -> String {
        format!("{}ms", (self.end - self.begin).num_milliseconds())
    }

    /// Prints the record to stdout as one ANSI-colored line.
    ///
    /// The status background reflects its class:
    /// - 2xx: green
    /// - 3xx: blue
    /// - 4xx/5xx: red
    /// - anything else: yellow
    ///
    /// The method background is blue for GET/POST, red for DELETE, and yellow for any other method.
    fn stdout(&self) {
        let status = match self.status / 100 {
            2 => cs!(BgGreen; " {} ", self.status),
            3 => cs!(BgBlue; " {} ", self.status),
            4 | 5 => cs!(BgRed; " {} ", self.status),
            _ => cs!(BgYellow; " {} ", self.status),
        };

        let method = match self.method.as_str() {
            "GET" | "POST" => cs!(BgBlue; " {:<6} ", self.method),
            "DELETE" => cs!(BgRed; " {:<6} ", self.method),
            _ => cs!(BgYellow; " {:<6} ", self.method),
        };

        println!(
            "[{}] {} |{}| {:>6} | {} |{} {} {}",
            self.begin.format("%Y-%m-%d %H:%M:%S").color(127, 132, 142),
            self.logo.fonts(fonts!(Bold, Purple)),
            status,
            self.elapsed(),
            cs!(Yellow; "{:<15}", self.ip),
            method,
            self.path,
            self.other
        );
    }

    /// Appends the record to `file` as one uncolored, newline-terminated line.
    ///
    /// The line is written with a single `write_all` call. A write failure is reported on
    /// stderr, alongside the middleware's other error reports, and is otherwise ignored,
    /// so later records are still attempted.
    fn file_out(&self, file: &mut File) {
        let line = format!(
            "[{}] {} | {} | {:>6} | {:>15} | {:<6} {} {}\n",
            self.begin.format("%Y-%m-%d %H:%M:%S"),
            self.logo,
            self.status,
            self.elapsed(),
            self.ip,
            self.method,
            self.path,
            self.other
        );
        if let Err(err) = file.write_all(line.as_bytes()) {
            eprintln!("日志写入文件时出错 -> {err}")
        }
    }
}