utoipa_swagger_ui/
rocket.rs

1#![cfg(feature = "rocket")]
2
3use std::{borrow::Cow, io::Cursor, sync::Arc};
4
5use base64::{prelude::BASE64_STANDARD, Engine};
6use rocket::{
7    http::{Header, Status},
8    request::{self, FromRequest},
9    response::{status::NotFound, Responder as RocketResponder},
10    route::{Handler, Outcome},
11    serde::json::Json,
12    Data as RocketData, Request, Response, Route,
13};
14
15use crate::{ApiDoc, BasicAuth, Config, SwaggerFile, SwaggerUi};
16
17impl From<SwaggerUi> for Vec<Route> {
18    fn from(swagger_ui: SwaggerUi) -> Self {
19        let mut routes =
20            Vec::<Route>::with_capacity(swagger_ui.urls.len() + 1 + swagger_ui.external_urls.len());
21        let mut api_docs =
22            Vec::<Route>::with_capacity(swagger_ui.urls.len() + swagger_ui.external_urls.len());
23
24        let urls = swagger_ui
25            .urls
26            .into_iter()
27            .map(|(url, openapi)| (url, ApiDoc::Utoipa(openapi)))
28            .chain(
29                swagger_ui
30                    .external_urls
31                    .into_iter()
32                    .map(|(url, api_doc)| (url, ApiDoc::Value(api_doc))),
33            )
34            .map(|(url, openapi)| {
35                api_docs.push(Route::new(
36                    rocket::http::Method::Get,
37                    &url.url,
38                    ServeApiDoc(openapi),
39                ));
40                url
41            });
42
43        routes.push(Route::new(
44            rocket::http::Method::Get,
45            swagger_ui.path.as_ref(),
46            ServeSwagger(
47                swagger_ui.path.clone(),
48                Arc::new(if let Some(config) = swagger_ui.config {
49                    if config.url.is_some() || !config.urls.is_empty() {
50                        config
51                    } else {
52                        config.configure_defaults(urls)
53                    }
54                } else {
55                    Config::new(urls)
56                }),
57            ),
58        ));
59        routes.extend(api_docs);
60
61        routes
62    }
63}
64
65#[derive(Clone)]
66struct ServeApiDoc(ApiDoc);
67
68#[rocket::async_trait]
69impl Handler for ServeApiDoc {
70    async fn handle<'r>(&self, request: &'r Request<'_>, _: RocketData<'r>) -> Outcome<'r> {
71        Outcome::from(request, Json(self.0.clone()))
72    }
73}
74
75#[derive(Clone)]
76struct ServeSwagger(Cow<'static, str>, Arc<Config<'static>>);
77
78#[rocket::async_trait]
79impl Handler for ServeSwagger {
80    async fn handle<'r>(&self, request: &'r Request<'_>, _: RocketData<'r>) -> Outcome<'r> {
81        if let Some(basic_auth) = &self.1.clone().basic_auth {
82            let request_guard = request.guard::<BasicAuth>().await;
83            match request_guard {
84                request::Outcome::Success(BasicAuth { username, password })
85                    if username == basic_auth.username && password == basic_auth.password =>
86                {
87                    ()
88                }
89                _ => return Outcome::from(request, BasicAuthErrorResponse),
90            }
91        }
92
93        let mut base_path = self.0.as_ref();
94        if let Some(index) = self.0.find('<') {
95            base_path = &base_path[..index];
96        }
97
98        let request_path = request.uri().path().as_str();
99        let request_path = match request_path.strip_prefix(base_path) {
100            Some(stripped) => stripped,
101            None => return Outcome::from(request, RedirectResponder(base_path.into())),
102        };
103        match super::serve(request_path, self.1.clone()) {
104            Ok(swagger_file) => swagger_file
105                .map(|file| Outcome::from(request, file))
106                .unwrap_or_else(|| Outcome::from(request, NotFound("Swagger UI file not found"))),
107            Err(error) => Outcome::from(
108                request,
109                rocket::response::status::Custom(Status::InternalServerError, error.to_string()),
110            ),
111        }
112    }
113}
114
115pub struct BasicAuthErrorResponse;
116
117impl<'r, 'o: 'r> RocketResponder<'r, 'o> for BasicAuthErrorResponse {
118    fn respond_to(self, _: &'r Request<'_>) -> rocket::response::Result<'o> {
119        Response::build()
120            .status(Status::Unauthorized)
121            .header(Header::new("WWW-Authenticate", "Basic realm=\":\""))
122            .ok()
123    }
124}
125
126impl<'r, 'o: 'r> RocketResponder<'r, 'o> for SwaggerFile<'o> {
127    fn respond_to(self, _: &'r Request<'_>) -> rocket::response::Result<'o> {
128        Ok(Response::build()
129            .header(Header::new("Content-Type", self.content_type))
130            .sized_body(self.bytes.len(), Cursor::new(self.bytes.to_vec()))
131            .status(Status::Ok)
132            .finalize())
133    }
134}
135
136struct RedirectResponder(String);
137impl<'r, 'a: 'r> RocketResponder<'r, 'a> for RedirectResponder {
138    fn respond_to(self, _request: &'r Request<'_>) -> rocket::response::Result<'a> {
139        Response::build()
140            .status(Status::Found)
141            .raw_header("Location", self.0)
142            .ok()
143    }
144}
145
146#[rocket::async_trait]
147impl<'r> FromRequest<'r> for BasicAuth {
148    type Error = ();
149
150    async fn from_request(req: &'r Request<'_>) -> request::Outcome<BasicAuth, ()> {
151        match req.headers().get_one("Authorization") {
152            None => request::Outcome::Error((Status::BadRequest, ())),
153            Some(credentials) => {
154                if let Some(basic_auth) = credentials
155                    .strip_prefix("Basic ")
156                    .and_then(|s| BASE64_STANDARD.decode(s).ok())
157                    .and_then(|b| String::from_utf8(b).ok())
158                    .and_then(|s| {
159                        if let Some((username, password)) = s.split_once(':') {
160                            Some(BasicAuth {
161                                username: username.to_string(),
162                                password: password.to_string(),
163                            })
164                        } else {
165                            None
166                        }
167                    })
168                {
169                    request::Outcome::Success(basic_auth)
170                } else {
171                    request::Outcome::Error((Status::BadRequest, ()))
172                }
173            }
174        }
175    }
176}
177
#[cfg(test)]
mod tests {
    use rocket::local::blocking::Client;

    use crate::BasicAuth;

    use super::*;

    /// Mounts the routes generated from `swagger_ui` at `/` and returns a tracked client.
    fn client_for(swagger_ui: SwaggerUi) -> Client {
        let routes: Vec<Route> = swagger_ui.into();
        Client::tracked(rocket::build().mount("/", routes)).unwrap()
    }

    #[test]
    fn mount_onto_path_not_end_with_slash() {
        let client = client_for(SwaggerUi::new("/swagger-ui"));

        let status = client.get("/swagger-ui").dispatch().status();
        assert_eq!(status, Status::Ok);
    }

    #[test]
    fn basic_auth() {
        let credentials = BasicAuth {
            username: "admin".to_string(),
            password: "password".to_string(),
        };
        let client = client_for(
            SwaggerUi::new("/swagger-ui").config(Config::default().basic_auth(credentials)),
        );

        // Without an Authorization header the UI must be refused.
        let status = client.get("/swagger-ui").dispatch().status();
        assert_eq!(status, Status::Unauthorized);

        // With the configured credentials the UI is served.
        let authorization = format!("Basic {}", BASE64_STANDARD.encode("admin:password"));
        let status = client
            .get("/swagger-ui")
            .header(Header::new("Authorization", authorization))
            .dispatch()
            .status();
        assert_eq!(status, Status::Ok);
    }
}