1#![cfg(feature = "axum")]
2
3use std::sync::Arc;
4
5use axum::{
6 body::Body,
7 extract::Path,
8 http::{HeaderMap, Request, Response, StatusCode},
9 middleware::{self, Next},
10 response::IntoResponse,
11 routing, Extension, Json, Router,
12};
13use base64::{prelude::BASE64_STANDARD, Engine};
14
15use crate::{ApiDoc, BasicAuth, Config, SwaggerUi, Url};
16
17impl<S> From<SwaggerUi> for Router<S>
18where
19 S: Clone + Send + Sync + 'static,
20{
21 fn from(swagger_ui: SwaggerUi) -> Self {
22 let urls_capacity = swagger_ui.urls.len();
23 let external_urls_capacity = swagger_ui.external_urls.len();
24
25 let (router, urls) = swagger_ui.urls.into_iter().fold(
26 (
27 Router::<S>::new(),
28 Vec::<Url>::with_capacity(urls_capacity + external_urls_capacity),
29 ),
30 |router_and_urls, (url, openapi)| {
31 add_api_doc_to_urls(router_and_urls, (url, Arc::new(ApiDoc::Utoipa(openapi))))
32 },
33 );
34 let (router, urls) = swagger_ui.external_urls.into_iter().fold(
35 (router, urls),
36 |router_and_urls, (url, openapi)| {
37 add_api_doc_to_urls(router_and_urls, (url, Arc::new(ApiDoc::Value(openapi))))
38 },
39 );
40
41 let config = if let Some(config) = swagger_ui.config {
42 if config.url.is_some() || !config.urls.is_empty() {
43 config
44 } else {
45 config.configure_defaults(urls)
46 }
47 } else {
48 Config::new(urls)
49 };
50
51 let handler = routing::get(serve_swagger_ui).layer(Extension(Arc::new(config.clone())));
52 let path: &str = swagger_ui.path.as_ref();
53
54 let mut router = if path == "/" {
55 router
56 .route(path, handler.clone())
57 .route(&format!("{}{{*rest}}", path), handler)
58 } else {
59 let path = if path.ends_with('/') {
60 &path[..path.len() - 1]
61 } else {
62 path
63 };
64 debug_assert!(!path.is_empty());
65
66 let slash_path = format!("{}/", path);
67 router
68 .route(
69 path,
70 routing::get(|| async move { axum::response::Redirect::to(&slash_path) }),
71 )
72 .route(&format!("{}/", path), handler.clone())
73 .route(&format!("{}/{{*rest}}", path), handler)
74 };
75
76 if let Some(BasicAuth { username, password }) = config.basic_auth {
77 let username = Arc::new(username);
78 let password = Arc::new(password);
79 let basic_auth_middleware =
80 move |headers: HeaderMap, req: Request<Body>, next: Next| {
81 let username = username.clone();
82 let password = password.clone();
83 async move {
84 if let Some(header) = headers.get("Authorization") {
85 if let Ok(header_str) = header.to_str() {
86 let base64_encoded_credentials =
87 BASE64_STANDARD.encode(format!("{}:{}", &username, &password));
88 if header_str == format!("Basic {}", base64_encoded_credentials) {
89 return Ok::<Response<Body>, StatusCode>(next.run(req).await);
90 }
91 }
92 }
93 Ok::<Response<Body>, StatusCode>(
94 (
95 StatusCode::UNAUTHORIZED,
96 [("WWW-Authenticate", "Basic realm=\":\"")],
97 )
98 .into_response(),
99 )
100 }
101 };
102 router = router.layer(middleware::from_fn(basic_auth_middleware));
103 }
104
105 router
106 }
107}
108
109fn add_api_doc_to_urls<S>(
110 router_and_urls: (Router<S>, Vec<Url<'static>>),
111 url: (Url<'static>, Arc<ApiDoc>),
112) -> (Router<S>, Vec<Url<'static>>)
113where
114 S: Clone + Send + Sync + 'static,
115{
116 let (router, mut urls) = router_and_urls;
117 let (url, openapi) = url;
118 (
119 router.route(
120 url.url.as_ref(),
121 routing::get(move || async { Json(openapi) }),
122 ),
123 {
124 urls.push(url);
125 urls
126 },
127 )
128}
129
130async fn serve_swagger_ui(
131 path: Option<Path<String>>,
132 Extension(state): Extension<Arc<Config<'static>>>,
133) -> impl IntoResponse {
134 let tail = match path.as_ref() {
135 Some(tail) => tail,
136 None => "",
137 };
138
139 match super::serve(tail, state) {
140 Ok(file) => file
141 .map(|file| {
142 (
143 StatusCode::OK,
144 [("Content-Type", file.content_type)],
145 file.bytes,
146 )
147 .into_response()
148 })
149 .unwrap_or_else(|| StatusCode::NOT_FOUND.into_response()),
150 Err(error) => (StatusCode::INTERNAL_SERVER_ERROR, error.to_string()).into_response(),
151 }
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use http::header::{AUTHORIZATION, LOCATION};
    use http::HeaderValue;
    use tower::util::ServiceExt;

    #[tokio::test]
    async fn mount_onto_root() {
        let app = Router::<()>::from(SwaggerUi::new("/"));
        assert_eq!(send(&app, get("/")).await.status(), StatusCode::OK);
        assert_eq!(
            send(&app, get("/swagger-ui.css")).await.status(),
            StatusCode::OK
        );
    }

    #[tokio::test]
    async fn mount_onto_path_ends_with_slash() {
        let app = Router::<()>::from(SwaggerUi::new("/swagger-ui/"));
        assert_mounted_under_swagger_ui(&app).await;
    }

    #[tokio::test]
    async fn mount_onto_path_not_end_with_slash() {
        let app = Router::<()>::from(SwaggerUi::new("/swagger-ui"));
        assert_mounted_under_swagger_ui(&app).await;
    }

    #[tokio::test]
    async fn basic_auth() {
        let app = Router::<()>::from(swagger_ui_with_basic_auth());

        let response = send(&app, get("/swagger-ui")).await;
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);

        let authorization = basic_authorization("admin", "password");
        let response = send(&app, authorized_get("/swagger-ui", &authorization)).await;
        assert_eq!(response.status(), StatusCode::SEE_OTHER);
        assert_eq!(response.headers()[LOCATION], "/swagger-ui/");

        let response = send(&app, authorized_get("/swagger-ui/", &authorization)).await;
        assert_eq!(response.status(), StatusCode::OK);

        let request = authorized_get("/swagger-ui/swagger-ui.css", &authorization);
        let response = send(&app, request).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn basic_auth_rejects_invalid_credentials() {
        let app = Router::<()>::from(swagger_ui_with_basic_auth());
        let rejected = [
            basic_authorization("admin", "wrong"),
            basic_authorization("other", "password"),
            "Basic".to_string(),
            "Bearer token".to_string(),
        ];
        for authorization in rejected {
            let response = send(&app, authorized_get("/swagger-ui/", &authorization)).await;
            assert_eq!(
                response.status(),
                StatusCode::UNAUTHORIZED,
                "unexpectedly accepted {authorization:?}"
            );
            assert!(response.headers().contains_key("WWW-Authenticate"));
        }
    }

    /// Checks the redirect, UI root and asset routes expected under `/swagger-ui`.
    async fn assert_mounted_under_swagger_ui(app: &Router) {
        let response = send(app, get("/swagger-ui")).await;
        assert_eq!(response.status(), StatusCode::SEE_OTHER);
        assert_eq!(response.headers()[LOCATION], "/swagger-ui/");

        assert_eq!(send(app, get("/swagger-ui/")).await.status(), StatusCode::OK);
        assert_eq!(
            send(app, get("/swagger-ui/swagger-ui.css")).await.status(),
            StatusCode::OK
        );
    }

    /// Sends `request` through a fresh clone of `app`.
    async fn send(app: &Router, request: Request<Body>) -> Response<Body> {
        app.clone().oneshot(request).await.unwrap()
    }

    /// Swagger UI mounted at `/swagger-ui`, protected by `admin:password`.
    fn swagger_ui_with_basic_auth() -> SwaggerUi {
        SwaggerUi::new("/swagger-ui").config(Config::default().basic_auth(BasicAuth {
            username: "admin".to_string(),
            password: "password".to_string(),
        }))
    }

    /// Builds an `Authorization` header value for HTTP Basic authentication.
    fn basic_authorization(username: &str, password: &str) -> String {
        format!(
            "Basic {}",
            BASE64_STANDARD.encode(format!("{username}:{password}"))
        )
    }

    fn get(url: &str) -> Request<Body> {
        Request::builder().uri(url).body(Body::empty()).unwrap()
    }

    fn authorized_get(url: &str, authorization: &str) -> Request<Body> {
        Request::builder()
            .uri(url)
            .header(AUTHORIZATION, HeaderValue::from_str(authorization).unwrap())
            .body(Body::empty())
            .unwrap()
    }
}