utoipa_swagger_ui/
rocket.rs1#![cfg(feature = "rocket")]
2
3use std::{borrow::Cow, io::Cursor, sync::Arc};
4
5use base64::{prelude::BASE64_STANDARD, Engine};
6use rocket::{
7 http::{Header, Status},
8 request::{self, FromRequest},
9 response::{status::NotFound, Responder as RocketResponder},
10 route::{Handler, Outcome},
11 serde::json::Json,
12 Data as RocketData, Request, Response, Route,
13};
14
15use crate::{ApiDoc, BasicAuth, Config, SwaggerFile, SwaggerUi};
16
17impl From<SwaggerUi> for Vec<Route> {
18 fn from(swagger_ui: SwaggerUi) -> Self {
19 let mut routes =
20 Vec::<Route>::with_capacity(swagger_ui.urls.len() + 1 + swagger_ui.external_urls.len());
21 let mut api_docs =
22 Vec::<Route>::with_capacity(swagger_ui.urls.len() + swagger_ui.external_urls.len());
23
24 let urls = swagger_ui
25 .urls
26 .into_iter()
27 .map(|(url, openapi)| (url, ApiDoc::Utoipa(openapi)))
28 .chain(
29 swagger_ui
30 .external_urls
31 .into_iter()
32 .map(|(url, api_doc)| (url, ApiDoc::Value(api_doc))),
33 )
34 .map(|(url, openapi)| {
35 api_docs.push(Route::new(
36 rocket::http::Method::Get,
37 &url.url,
38 ServeApiDoc(openapi),
39 ));
40 url
41 });
42
43 routes.push(Route::new(
44 rocket::http::Method::Get,
45 swagger_ui.path.as_ref(),
46 ServeSwagger(
47 swagger_ui.path.clone(),
48 Arc::new(if let Some(config) = swagger_ui.config {
49 if config.url.is_some() || !config.urls.is_empty() {
50 config
51 } else {
52 config.configure_defaults(urls)
53 }
54 } else {
55 Config::new(urls)
56 }),
57 ),
58 ));
59 routes.extend(api_docs);
60
61 routes
62 }
63}
64
65#[derive(Clone)]
66struct ServeApiDoc(ApiDoc);
67
68#[rocket::async_trait]
69impl Handler for ServeApiDoc {
70 async fn handle<'r>(&self, request: &'r Request<'_>, _: RocketData<'r>) -> Outcome<'r> {
71 Outcome::from(request, Json(self.0.clone()))
72 }
73}
74
75#[derive(Clone)]
76struct ServeSwagger(Cow<'static, str>, Arc<Config<'static>>);
77
78#[rocket::async_trait]
79impl Handler for ServeSwagger {
80 async fn handle<'r>(&self, request: &'r Request<'_>, _: RocketData<'r>) -> Outcome<'r> {
81 if let Some(basic_auth) = &self.1.clone().basic_auth {
82 let request_guard = request.guard::<BasicAuth>().await;
83 match request_guard {
84 request::Outcome::Success(BasicAuth { username, password })
85 if username == basic_auth.username && password == basic_auth.password =>
86 {
87 ()
88 }
89 _ => return Outcome::from(request, BasicAuthErrorResponse),
90 }
91 }
92
93 let mut base_path = self.0.as_ref();
94 if let Some(index) = self.0.find('<') {
95 base_path = &base_path[..index];
96 }
97
98 let request_path = request.uri().path().as_str();
99 let request_path = match request_path.strip_prefix(base_path) {
100 Some(stripped) => stripped,
101 None => return Outcome::from(request, RedirectResponder(base_path.into())),
102 };
103 match super::serve(request_path, self.1.clone()) {
104 Ok(swagger_file) => swagger_file
105 .map(|file| Outcome::from(request, file))
106 .unwrap_or_else(|| Outcome::from(request, NotFound("Swagger UI file not found"))),
107 Err(error) => Outcome::from(
108 request,
109 rocket::response::status::Custom(Status::InternalServerError, error.to_string()),
110 ),
111 }
112 }
113}
114
115pub struct BasicAuthErrorResponse;
116
117impl<'r, 'o: 'r> RocketResponder<'r, 'o> for BasicAuthErrorResponse {
118 fn respond_to(self, _: &'r Request<'_>) -> rocket::response::Result<'o> {
119 Response::build()
120 .status(Status::Unauthorized)
121 .header(Header::new("WWW-Authenticate", "Basic realm=\":\""))
122 .ok()
123 }
124}
125
126impl<'r, 'o: 'r> RocketResponder<'r, 'o> for SwaggerFile<'o> {
127 fn respond_to(self, _: &'r Request<'_>) -> rocket::response::Result<'o> {
128 Ok(Response::build()
129 .header(Header::new("Content-Type", self.content_type))
130 .sized_body(self.bytes.len(), Cursor::new(self.bytes.to_vec()))
131 .status(Status::Ok)
132 .finalize())
133 }
134}
135
136struct RedirectResponder(String);
137impl<'r, 'a: 'r> RocketResponder<'r, 'a> for RedirectResponder {
138 fn respond_to(self, _request: &'r Request<'_>) -> rocket::response::Result<'a> {
139 Response::build()
140 .status(Status::Found)
141 .raw_header("Location", self.0)
142 .ok()
143 }
144}
145
146#[rocket::async_trait]
147impl<'r> FromRequest<'r> for BasicAuth {
148 type Error = ();
149
150 async fn from_request(req: &'r Request<'_>) -> request::Outcome<BasicAuth, ()> {
151 match req.headers().get_one("Authorization") {
152 None => request::Outcome::Error((Status::BadRequest, ())),
153 Some(credentials) => {
154 if let Some(basic_auth) = credentials
155 .strip_prefix("Basic ")
156 .and_then(|s| BASE64_STANDARD.decode(s).ok())
157 .and_then(|b| String::from_utf8(b).ok())
158 .and_then(|s| {
159 if let Some((username, password)) = s.split_once(':') {
160 Some(BasicAuth {
161 username: username.to_string(),
162 password: password.to_string(),
163 })
164 } else {
165 None
166 }
167 })
168 {
169 request::Outcome::Success(basic_auth)
170 } else {
171 request::Outcome::Error((Status::BadRequest, ()))
172 }
173 }
174 }
175 }
176}
177
#[cfg(test)]
mod tests {
    use rocket::local::blocking::Client;

    use crate::BasicAuth;

    use super::*;

    /// Mounts a Swagger UI at `/swagger-ui` that requires `admin:password` basic auth,
    /// and returns a tracked blocking client for it.
    fn basic_auth_client() -> Client {
        let swagger_ui =
            SwaggerUi::new("/swagger-ui").config(Config::default().basic_auth(BasicAuth {
                username: "admin".to_string(),
                password: "password".to_string(),
            }));
        let routes: Vec<Route> = swagger_ui.into();
        Client::tracked(rocket::build().mount("/", routes)).unwrap()
    }

    #[test]
    fn mount_onto_path_not_end_with_slash() {
        let routes: Vec<Route> = SwaggerUi::new("/swagger-ui").into();
        let rocket = rocket::build().mount("/", routes);
        let client = Client::tracked(rocket).unwrap();
        let response = client.get("/swagger-ui").dispatch();
        assert_eq!(response.status(), Status::Ok);
    }

    #[test]
    fn basic_auth() {
        let client = basic_auth_client();
        let response = client.get("/swagger-ui").dispatch();
        assert_eq!(response.status(), Status::Unauthorized);
        let encoded_credentials = BASE64_STANDARD.encode("admin:password");
        let response = client
            .get("/swagger-ui")
            .header(Header::new(
                "Authorization",
                format!("Basic {}", encoded_credentials),
            ))
            .dispatch();
        assert_eq!(response.status(), Status::Ok);
    }

    /// Wrong, incomplete, malformed, or non-Basic credentials must all get `401`.
    #[test]
    fn basic_auth_rejects_bad_credentials() {
        let client = basic_auth_client();
        for value in [
            format!("Basic {}", BASE64_STANDARD.encode("admin:wrong")),
            format!("Basic {}", BASE64_STANDARD.encode("admin")),
            "Basic not-base64!".to_string(),
            "Bearer token".to_string(),
        ] {
            let response = client
                .get("/swagger-ui")
                .header(Header::new("Authorization", value))
                .dispatch();
            assert_eq!(response.status(), Status::Unauthorized);
        }
    }
}