1use std::{
2 future::{Ready, ready},
3 rc::Rc,
4};
5
6use actix_web::{
7 HttpResponse, Responder as _,
8 body::EitherBody,
9 dev::{Service, ServiceRequest, ServiceResponse, Transform, forward_ready},
10 http::header::TryIntoHeaderPair,
11 web::Redirect,
12};
13use futures_core::future::LocalBoxFuture;
14
15use crate::header::StrictTransportSecurity;
16
17#[derive(Debug, Clone, Default)]
46pub struct RedirectHttps {
47 hsts: Option<StrictTransportSecurity>,
48 port: Option<u16>,
49}
50
51impl RedirectHttps {
52 pub fn with_hsts(hsts: StrictTransportSecurity) -> Self {
54 Self {
55 hsts: Some(hsts),
56 ..Self::default()
57 }
58 }
59
60 pub fn to_port(mut self, port: u16) -> Self {
64 self.port = Some(port);
65 self
66 }
67}
68
69impl<S, B> Transform<S, ServiceRequest> for RedirectHttps
70where
71 S: Service<ServiceRequest, Response = ServiceResponse<B>> + 'static,
72{
73 type Response = ServiceResponse<EitherBody<B, ()>>;
74 type Error = S::Error;
75 type Transform = RedirectHttpsMiddleware<S>;
76 type InitError = ();
77 type Future = Ready<Result<Self::Transform, Self::InitError>>;
78
79 fn new_transform(&self, service: S) -> Self::Future {
80 ready(Ok(RedirectHttpsMiddleware {
81 service: Rc::new(service),
82 hsts: self.hsts,
83 port: self.port,
84 }))
85 }
86}
87
88#[doc(hidden)]
90#[allow(missing_debug_implementations)]
91pub struct RedirectHttpsMiddleware<S> {
92 service: Rc<S>,
93 hsts: Option<StrictTransportSecurity>,
94 port: Option<u16>,
95}
96
97impl<S, B> Service<ServiceRequest> for RedirectHttpsMiddleware<S>
98where
99 S: Service<ServiceRequest, Response = ServiceResponse<B>> + 'static,
100{
101 type Response = ServiceResponse<EitherBody<B, ()>>;
102 type Error = S::Error;
103 type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
104
105 forward_ready!(service);
106
107 fn call(&self, req: ServiceRequest) -> Self::Future {
108 #![allow(clippy::await_holding_refcell_ref)] let service = Rc::clone(&self.service);
111 let hsts = self.hsts;
112 let port = self.port;
113
114 Box::pin(async move {
115 let (req, pl) = req.into_parts();
116 let conn_info = req.connection_info();
117
118 if conn_info.scheme() != "https" {
119 let host = conn_info.host();
120
121 let parsed_url = url::Url::parse(&format!("http://{host}"));
123 let hostname = match &parsed_url {
124 Ok(url) => url.host_str().unwrap_or(""),
125 Err(_) => host.split_once(':').map_or("", |(host, _port)| host),
126 };
127
128 let path = req.uri().path();
129 let uri = match port {
130 Some(port) => format!("https://{hostname}:{port}{path}"),
131 None => format!("https://{hostname}{path}"),
132 };
133
134 drop(conn_info);
136
137 let redirect = Redirect::to(uri);
139
140 let mut res = redirect.respond_to(&req).map_into_right_body();
141 apply_hsts(&mut res, hsts);
142
143 return Ok(ServiceResponse::new(req, res));
144 }
145
146 drop(conn_info);
147
148 let req = ServiceRequest::from_parts(req, pl);
149
150 service.call(req).await.map(|mut res| {
153 apply_hsts(res.response_mut(), hsts);
154 res.map_into_left_body()
155 })
156 })
157 }
158}
159
160fn apply_hsts<B>(res: &mut HttpResponse<B>, hsts: Option<StrictTransportSecurity>) {
162 if let Some(hsts) = hsts {
163 let (name, val) = hsts.try_into_pair().unwrap();
164 res.headers_mut().insert(name, val);
165 }
166}
167
#[cfg(test)]
mod tests {
    use actix_web::{
        App, Error, HttpResponse,
        body::MessageBody,
        dev::ServiceFactory,
        http::{
            StatusCode,
            header::{self, Header as _},
        },
        test, web,
    };

    use super::*;
    use crate::{assert_response_matches, test_request};

    /// Builds an app wrapped in a default `RedirectHttps` with a single `GET /` route that
    /// responds with `"content"`.
    fn test_app() -> App<
        impl ServiceFactory<
            ServiceRequest,
            Response = ServiceResponse<impl MessageBody>,
            Config = (),
            InitError = (),
            Error = Error,
        >,
    > {
        let handler = || async { HttpResponse::Ok().body("content") };

        App::new()
            .wrap(RedirectHttps::default())
            .route("/", web::get().to(handler))
    }

    #[actix_web::test]
    async fn redirect_non_https() {
        let svc = test::init_service(test_app()).await;

        // A default test request is plain HTTP, so it must be redirected.
        let resp = test::call_service(&svc, test::TestRequest::default().to_request()).await;
        assert_eq!(resp.status(), StatusCode::TEMPORARY_REDIRECT);

        let location = resp.headers().get(header::LOCATION).unwrap();
        assert!(location.as_bytes().starts_with(b"https://"));

        // The redirect itself carries no body.
        assert!(test::read_body(resp).await.is_empty());
    }

    #[actix_web::test]
    async fn do_not_redirect_already_https() {
        let svc = test::init_service(test_app()).await;

        let https_req = test::TestRequest::default()
            .uri("https://localhost:443/")
            .to_request();
        let resp = test::call_service(&svc, https_req).await;

        assert_eq!(resp.status(), StatusCode::OK);
        assert!(!resp.headers().contains_key(header::LOCATION));
        assert_eq!(test::read_body(resp).await, "content");
    }

    #[actix_web::test]
    async fn with_hsts() {
        // Without an HSTS policy, neither redirects nor forwarded responses get the header.
        let svc = RedirectHttps::default()
            .new_transform(test::ok_service())
            .await
            .unwrap();

        for test_req in [
            test_request!(GET "http://localhost/"),
            test_request!(GET "https://localhost:443/"),
        ] {
            let resp = test::call_service(&svc, test_req.to_srv_request()).await;
            assert!(!resp.headers().contains_key(StrictTransportSecurity::name()));
        }

        // With an HSTS policy, both kinds of response carry the header.
        let svc = RedirectHttps::with_hsts(StrictTransportSecurity::recommended())
            .new_transform(test::ok_service())
            .await
            .unwrap();

        for test_req in [
            test_request!(GET "http://localhost/"),
            test_request!(GET "https://localhost:443/"),
        ] {
            let resp = test::call_service(&svc, test_req.to_srv_request()).await;
            assert!(resp.headers().contains_key(StrictTransportSecurity::name()));
        }
    }

    #[actix_web::test]
    async fn to_custom_port() {
        let svc = RedirectHttps::default()
            .to_port(8443)
            .new_transform(test::ok_service())
            .await
            .unwrap();

        let srv_req = test_request!(GET "http://localhost/").to_srv_request();
        let resp = test::call_service(&svc, srv_req).await;
        assert_response_matches!(resp, TEMPORARY_REDIRECT; "location" => "https://localhost:8443/");
    }

    #[actix_web::test]
    async fn to_ipv6() {
        let svc = RedirectHttps::default()
            .new_transform(test::ok_service())
            .await
            .unwrap();

        // IPv6 literals must stay bracketed in the redirect target.
        let srv_req = test_request!(GET "http://[fe80::1234:1234:1234:1234]/").to_srv_request();
        let resp = test::call_service(&svc, srv_req).await;
        assert_response_matches!(resp, TEMPORARY_REDIRECT; "location" => "https://[fe80::1234:1234:1234:1234]/");
    }

    #[actix_web::test]
    async fn to_custom_port_when_port_in_host() {
        let svc = RedirectHttps::default()
            .to_port(8443)
            .new_transform(test::ok_service())
            .await
            .unwrap();

        // The incoming port (8080) is replaced by the configured one.
        let srv_req = test_request!(GET "http://localhost:8080/").to_srv_request();
        let resp = test::call_service(&svc, srv_req).await;
        assert_response_matches!(resp, TEMPORARY_REDIRECT; "location" => "https://localhost:8443/");
    }
}