spring_web/
middleware.rs

1use crate::config::CorsMiddleware;
2use crate::config::{
3    EnableMiddleware, LimitPayloadMiddleware, Middlewares, StaticAssetsMiddleware,
4    TimeoutRequestMiddleware, TraceLoggerMiddleware,
5};
6use crate::Router;
7use anyhow::Context;
8use axum::http::StatusCode;
9use spring::error::Result;
10use std::path::PathBuf;
11use std::str::FromStr;
12use std::time::Duration;
13use tower_http::trace::DefaultMakeSpan;
14use tower_http::trace::DefaultOnRequest;
15use tower_http::trace::DefaultOnResponse;
16use tower_http::{
17    catch_panic::CatchPanicLayer,
18    compression::CompressionLayer,
19    cors::CorsLayer,
20    limit::RequestBodyLimitLayer,
21    services::{ServeDir, ServeFile},
22    timeout::TimeoutLayer,
23    trace::TraceLayer,
24};
25use trace::DefaultOnEos;
26
27pub use tower_http::*;
28
29pub(crate) fn apply_middleware(mut router: Router, middleware: Middlewares) -> Router {
30    // Always apply URI capture middleware first (for Problem Details)
31    router = router.layer(axum::middleware::from_fn(crate::problem_details::capture_request_uri_middleware));
32    
33    if Some(EnableMiddleware { enable: true }) == middleware.catch_panic {
34        router = router.layer(CatchPanicLayer::new());
35    }
36    if Some(EnableMiddleware { enable: true }) == middleware.compression {
37        router = router.layer(CompressionLayer::new());
38    }
39    if let Some(TraceLoggerMiddleware { enable, level }) = middleware.logger {
40        if enable {
41            let level = level.into();
42            router = router.layer(
43                TraceLayer::new_for_http()
44                    .make_span_with(DefaultMakeSpan::default().level(level))
45                    .on_request(DefaultOnRequest::default().level(level))
46                    .on_response(DefaultOnResponse::default().level(level))
47                    .on_eos(DefaultOnEos::default().level(level)),
48            );
49        }
50    }
51    if let Some(TimeoutRequestMiddleware { enable, timeout }) = middleware.timeout_request {
52        if enable {
53            router = router.layer(TimeoutLayer::with_status_code(StatusCode::REQUEST_TIMEOUT, Duration::from_millis(timeout)));
54        }
55    }
56    if let Some(LimitPayloadMiddleware { enable, body_limit }) = middleware.limit_payload {
57        if enable {
58            let limit = byte_unit::Byte::from_str(&body_limit)
59                .unwrap_or_else(|_| panic!("parse limit payload str failed: {}", &body_limit));
60
61            let limit_payload = RequestBodyLimitLayer::new(limit.as_u64() as usize);
62            router = router.layer(limit_payload);
63        }
64    }
65    if let Some(cors) = middleware.cors {
66        if cors.enable {
67            let cors = build_cors_middleware(&cors).expect("cors middleware build failed");
68            router = router.layer(cors);
69        }
70    }
71    if let Some(static_assets) = middleware.static_assets {
72        if static_assets.enable {
73            router = apply_static_dir(router, static_assets);
74        }
75    }
76    router
77}
78
79fn apply_static_dir(router: Router, static_assets: StaticAssetsMiddleware) -> Router {
80    if static_assets.must_exist
81        && (!PathBuf::from(&static_assets.path).exists()
82            || !PathBuf::from(&static_assets.fallback).exists())
83    {
84        panic!(
85            "one of the static path are not found, Folder `{}` fallback: `{}`",
86            static_assets.path, static_assets.fallback
87        );
88    }
89
90    let fallback = ServeFile::new(format!("{}/{}", static_assets.path, static_assets.fallback));
91    let serve_dir = ServeDir::new(static_assets.path).not_found_service(fallback);
92
93    let service = if static_assets.precompressed {
94        tracing::info!("[Middleware] Enable precompressed static assets");
95        serve_dir.precompressed_gzip()
96    } else {
97        serve_dir
98    };
99
100    if static_assets.uri == "/" {
101        router.fallback_service(service)
102    } else {
103        router.nest_service(&static_assets.uri, service)
104    }
105}
106
107fn build_cors_middleware(cors: &CorsMiddleware) -> Result<CorsLayer> {
108    let mut layer = CorsLayer::new();
109
110    if let Some(allow_origins) = &cors.allow_origins {
111        if allow_origins.iter().any(|item| item == "*") {
112            layer = layer.allow_origin(cors::Any);
113        } else {
114            let mut origins = Vec::with_capacity(allow_origins.len());
115            for origin in allow_origins {
116                let origin = origin
117                    .parse()
118                    .with_context(|| format!("cors origin parse failed:{origin}"))?;
119                origins.push(origin);
120            }
121            layer = layer.allow_origin(origins);
122        }
123    }
124
125    if let Some(allow_headers) = &cors.allow_headers {
126        if allow_headers.iter().any(|item| item == "*") {
127            layer = layer.allow_headers(cors::Any);
128        } else {
129            let mut headers = Vec::with_capacity(allow_headers.len());
130            for header in allow_headers {
131                let header = header
132                    .parse()
133                    .with_context(|| format!("http header parse failed:{header}"))?;
134                headers.push(header);
135            }
136            layer = layer.allow_headers(headers);
137        }
138    }
139
140    if let Some(allow_methods) = &cors.allow_methods {
141        if allow_methods.iter().any(|item| item == "*") {
142            layer = layer.allow_methods(cors::Any);
143        } else {
144            let mut methods = Vec::with_capacity(allow_methods.len());
145            for method in allow_methods {
146                let method = method
147                    .parse()
148                    .with_context(|| format!("http method parse failed:{method}"))?;
149                methods.push(method);
150            }
151            layer = layer.allow_methods(methods);
152        }
153    }
154
155    if let Some(max_age) = cors.max_age {
156        layer = layer.max_age(Duration::from_secs(max_age));
157    }
158
159    Ok(layer)
160}