spring_web/
middleware.rs

1use crate::config::CorsMiddleware;
2use crate::config::{
3    EnableMiddleware, LimitPayloadMiddleware, Middlewares, StaticAssetsMiddleware,
4    TimeoutRequestMiddleware, TraceLoggerMiddleware,
5};
6use anyhow::Context;
7use axum::Router;
8use spring::error::Result;
9use std::path::PathBuf;
10use std::str::FromStr;
11use std::time::Duration;
12use tower_http::trace::DefaultMakeSpan;
13use tower_http::trace::DefaultOnRequest;
14use tower_http::trace::DefaultOnResponse;
15use tower_http::{
16    catch_panic::CatchPanicLayer,
17    compression::CompressionLayer,
18    cors::CorsLayer,
19    limit::RequestBodyLimitLayer,
20    services::{ServeDir, ServeFile},
21    timeout::TimeoutLayer,
22    trace::TraceLayer,
23};
24use trace::DefaultOnEos;
25
26pub use tower_http::*;
27
28pub(crate) fn apply_middleware(mut router: Router, middleware: Middlewares) -> Router {
29    if Some(EnableMiddleware { enable: true }) == middleware.catch_panic {
30        router = router.layer(CatchPanicLayer::new());
31    }
32    if Some(EnableMiddleware { enable: true }) == middleware.compression {
33        router = router.layer(CompressionLayer::new());
34    }
35    if let Some(TraceLoggerMiddleware { enable, level }) = middleware.logger {
36        if enable {
37            let level = level.into();
38            router = router.layer(
39                TraceLayer::new_for_http()
40                    .make_span_with(DefaultMakeSpan::default().level(level))
41                    .on_request(DefaultOnRequest::default().level(level))
42                    .on_response(DefaultOnResponse::default().level(level))
43                    .on_eos(DefaultOnEos::default().level(level)),
44            );
45        }
46    }
47    if let Some(TimeoutRequestMiddleware { enable, timeout }) = middleware.timeout_request {
48        if enable {
49            router = router.layer(TimeoutLayer::new(Duration::from_millis(timeout)));
50        }
51    }
52    if let Some(LimitPayloadMiddleware { enable, body_limit }) = middleware.limit_payload {
53        if enable {
54            let limit = byte_unit::Byte::from_str(&body_limit)
55                .unwrap_or_else(|_| panic!("parse limit payload str failed: {}", &body_limit));
56
57            let limit_payload = RequestBodyLimitLayer::new(limit.as_u64() as usize);
58            router = router.layer(limit_payload);
59        }
60    }
61    if let Some(cors) = middleware.cors {
62        if cors.enable {
63            let cors = build_cors_middleware(&cors).expect("cors middleware build failed");
64            router = router.layer(cors);
65        }
66    }
67    if let Some(static_assets) = middleware.static_assets {
68        if static_assets.enable {
69            router = apply_static_dir(router, static_assets);
70        }
71    }
72    router
73}
74
75fn apply_static_dir(router: Router, static_assets: StaticAssetsMiddleware) -> Router {
76    if static_assets.must_exist
77        && (!PathBuf::from(&static_assets.path).exists()
78            || !PathBuf::from(&static_assets.fallback).exists())
79    {
80        panic!(
81            "one of the static path are not found, Folder `{}` fallback: `{}`",
82            static_assets.path, static_assets.fallback
83        );
84    }
85
86    let serve_dir =
87        ServeDir::new(static_assets.path).not_found_service(ServeFile::new(static_assets.fallback));
88
89    let service = if static_assets.precompressed {
90        tracing::info!("[Middleware] Enable precompressed static assets");
91        serve_dir.precompressed_gzip()
92    } else {
93        serve_dir
94    };
95
96    router.nest_service(&static_assets.uri, service)
97}
98
99fn build_cors_middleware(cors: &CorsMiddleware) -> Result<CorsLayer> {
100    let mut layer = CorsLayer::new();
101
102    if let Some(allow_origins) = &cors.allow_origins {
103        if allow_origins.iter().any(|item| item == "*") {
104            layer = layer.allow_origin(cors::Any);
105        } else {
106            let mut origins = Vec::with_capacity(allow_origins.len());
107            for origin in allow_origins {
108                let origin = origin
109                    .parse()
110                    .with_context(|| format!("cors origin parse failed:{}", origin))?;
111                origins.push(origin);
112            }
113            layer = layer.allow_origin(origins);
114        }
115    }
116
117    if let Some(allow_headers) = &cors.allow_headers {
118        if allow_headers.iter().any(|item| item == "*") {
119            layer = layer.allow_headers(cors::Any);
120        } else {
121            let mut headers = Vec::with_capacity(allow_headers.len());
122            for header in allow_headers {
123                let header = header
124                    .parse()
125                    .with_context(|| format!("http header parse failed:{}", header))?;
126                headers.push(header);
127            }
128            layer = layer.allow_headers(headers);
129        }
130    }
131
132    if let Some(allow_methods) = &cors.allow_methods {
133        if allow_methods.iter().any(|item| item == "*") {
134            layer = layer.allow_methods(cors::Any);
135        } else {
136            let mut methods = Vec::with_capacity(allow_methods.len());
137            for method in allow_methods {
138                let method = method
139                    .parse()
140                    .with_context(|| format!("http method parse failed:{}", method))?;
141                methods.push(method);
142            }
143            layer = layer.allow_methods(methods);
144        }
145    }
146
147    if let Some(max_age) = cors.max_age {
148        layer = layer.max_age(Duration::from_secs(max_age));
149    }
150
151    Ok(layer)
152}