actix_web_lab/
redirect_to_https.rs

1use std::{
2    future::{Ready, ready},
3    rc::Rc,
4};
5
6use actix_web::{
7    HttpResponse, Responder as _,
8    body::EitherBody,
9    dev::{Service, ServiceRequest, ServiceResponse, Transform, forward_ready},
10    http::header::TryIntoHeaderPair,
11    web::Redirect,
12};
13use futures_core::future::LocalBoxFuture;
14
15use crate::header::StrictTransportSecurity;
16
/// Middleware to redirect traffic to HTTPS if connection is insecure.
///
/// Requests whose connection scheme (as reported by the request's connection info) is not
/// `https` are answered directly with a `307 Temporary Redirect` to the equivalent `https://`
/// URL. The wrapped service is not called for them. Any port in the incoming host is dropped
/// and replaced by the port set with [`to_port`](Self::to_port), if one was set. All other
/// requests are forwarded to the wrapped service.
///
/// # HSTS
///
/// [HTTP Strict Transport Security (HSTS)] is configurable. Care should be taken when setting up
/// HSTS for your site; misconfiguration can potentially leave parts of your site in an unusable
/// state. By default it is disabled.
///
/// When enabled, the header is added to redirect responses and to successful responses from the
/// wrapped service. It is not added when the wrapped service returns an error.
///
/// See [`StrictTransportSecurity`] docs for more info.
///
/// # Examples
///
/// ```
/// # use std::time::Duration;
/// # use actix_web::App;
/// use actix_web_lab::{header::StrictTransportSecurity, middleware::RedirectHttps};
///
/// let mw = RedirectHttps::default();
/// let mw = RedirectHttps::default().to_port(8443);
/// let mw = RedirectHttps::with_hsts(StrictTransportSecurity::default());
/// let mw = RedirectHttps::with_hsts(StrictTransportSecurity::new(Duration::from_secs(60 * 60)));
/// let mw = RedirectHttps::with_hsts(StrictTransportSecurity::recommended());
///
/// App::new().wrap(mw)
/// # ;
/// ```
///
/// [HTTP Strict Transport Security (HSTS)]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Strict-Transport-Security
#[derive(Debug, Clone, Default)]
pub struct RedirectHttps {
    /// HSTS header to attach to responses; `None` (the default) disables it.
    hsts: Option<StrictTransportSecurity>,
    /// Port written into redirect URLs; `None` (the default) omits the port entirely.
    port: Option<u16>,
}
50
51impl RedirectHttps {
52    /// Construct new HTTP redirect middleware with strict transport security configuration.
53    pub fn with_hsts(hsts: StrictTransportSecurity) -> Self {
54        Self {
55            hsts: Some(hsts),
56            ..Self::default()
57        }
58    }
59
60    /// Sets custom secure redirect port.
61    ///
62    /// By default, no port is set explicitly so the standard HTTPS port (443) is used.
63    pub fn to_port(mut self, port: u16) -> Self {
64        self.port = Some(port);
65        self
66    }
67}
68
69impl<S, B> Transform<S, ServiceRequest> for RedirectHttps
70where
71    S: Service<ServiceRequest, Response = ServiceResponse<B>> + 'static,
72{
73    type Response = ServiceResponse<EitherBody<B, ()>>;
74    type Error = S::Error;
75    type Transform = RedirectHttpsMiddleware<S>;
76    type InitError = ();
77    type Future = Ready<Result<Self::Transform, Self::InitError>>;
78
79    fn new_transform(&self, service: S) -> Self::Future {
80        ready(Ok(RedirectHttpsMiddleware {
81            service: Rc::new(service),
82            hsts: self.hsts,
83            port: self.port,
84        }))
85    }
86}
87
/// Middleware service implementation for [`RedirectHttps`].
///
/// Constructed by `RedirectHttps`'s `Transform` impl, which copies its HSTS and port settings
/// into this struct.
#[doc(hidden)]
// No `Debug` impl: the wrapped service type `S` is not required to implement `Debug`.
#[allow(missing_debug_implementations)]
pub struct RedirectHttpsMiddleware<S> {
    /// Wrapped service; held in an `Rc` so `call` can move a handle into its `'static` future.
    service: Rc<S>,
    /// HSTS header to attach to responses, if any.
    hsts: Option<StrictTransportSecurity>,
    /// Explicit port for redirect URLs, if any.
    port: Option<u16>,
}
96
97impl<S, B> Service<ServiceRequest> for RedirectHttpsMiddleware<S>
98where
99    S: Service<ServiceRequest, Response = ServiceResponse<B>> + 'static,
100{
101    type Response = ServiceResponse<EitherBody<B, ()>>;
102    type Error = S::Error;
103    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
104
105    forward_ready!(service);
106
107    fn call(&self, req: ServiceRequest) -> Self::Future {
108        #![allow(clippy::await_holding_refcell_ref)] // RefCell is dropped before await
109
110        let service = Rc::clone(&self.service);
111        let hsts = self.hsts;
112        let port = self.port;
113
114        Box::pin(async move {
115            let (req, pl) = req.into_parts();
116            let conn_info = req.connection_info();
117
118            if conn_info.scheme() != "https" {
119                let host = conn_info.host();
120
121                // construct equivalent https path
122                let parsed_url = url::Url::parse(&format!("http://{host}"));
123                let hostname = match &parsed_url {
124                    Ok(url) => url.host_str().unwrap_or(""),
125                    Err(_) => host.split_once(':').map_or("", |(host, _port)| host),
126                };
127
128                let path = req.uri().path();
129                let uri = match port {
130                    Some(port) => format!("https://{hostname}:{port}{path}"),
131                    None => format!("https://{hostname}{path}"),
132                };
133
134                // all connection info is acquired
135                drop(conn_info);
136
137                // create redirection response
138                let redirect = Redirect::to(uri);
139
140                let mut res = redirect.respond_to(&req).map_into_right_body();
141                apply_hsts(&mut res, hsts);
142
143                return Ok(ServiceResponse::new(req, res));
144            }
145
146            drop(conn_info);
147
148            let req = ServiceRequest::from_parts(req, pl);
149
150            // TODO: apply HSTS header to error case
151
152            service.call(req).await.map(|mut res| {
153                apply_hsts(res.response_mut(), hsts);
154                res.map_into_left_body()
155            })
156        })
157    }
158}
159
160/// Apply HSTS config to an `HttpResponse`.
161fn apply_hsts<B>(res: &mut HttpResponse<B>, hsts: Option<StrictTransportSecurity>) {
162    if let Some(hsts) = hsts {
163        let (name, val) = hsts.try_into_pair().unwrap();
164        res.headers_mut().insert(name, val);
165    }
166}
167
#[cfg(test)]
mod tests {
    use actix_web::{
        App, Error, HttpResponse,
        body::MessageBody,
        dev::ServiceFactory,
        http::{
            StatusCode,
            header::{self, Header as _},
        },
        test, web,
    };

    use super::*;
    use crate::{assert_response_matches, test_request};

    /// Builds an app wrapped in the default `RedirectHttps` (no HSTS, no custom port) with a
    /// single `GET /` route that responds `200 OK` with the body `"content"`.
    fn test_app() -> App<
        impl ServiceFactory<
            ServiceRequest,
            Response = ServiceResponse<impl MessageBody>,
            Config = (),
            InitError = (),
            Error = Error,
        >,
    > {
        App::new().wrap(RedirectHttps::default()).route(
            "/",
            web::get().to(|| async { HttpResponse::Ok().body("content") }),
        )
    }

    /// An insecure request gets an empty-bodied temporary redirect to an `https://` location.
    #[actix_web::test]
    async fn redirect_non_https() {
        let app = test::init_service(test_app()).await;

        // the default test request is expected to be treated as insecure (non-https scheme)
        let req = test::TestRequest::default().to_request();
        let res = test::call_service(&app, req).await;

        assert_eq!(res.status(), StatusCode::TEMPORARY_REDIRECT);
        let loc = res.headers().get(header::LOCATION);
        assert!(loc.unwrap().as_bytes().starts_with(b"https://"));

        let body = test::read_body(res).await;
        assert!(body.is_empty());
    }

    /// A request with an `https://` URI is passed through to the route without a redirect.
    #[actix_web::test]
    async fn do_not_redirect_already_https() {
        let app = test::init_service(test_app()).await;

        let req = test::TestRequest::default()
            .uri("https://localhost:443/")
            .to_request();

        let res = test::call_service(&app, req).await;
        assert_eq!(res.status(), StatusCode::OK);
        assert!(res.headers().get(header::LOCATION).is_none());

        let body = test::read_body(res).await;
        assert_eq!(body, "content");
    }

    /// The HSTS header is absent when not configured, and present on both redirected (http) and
    /// passed-through (https) responses when configured.
    #[actix_web::test]
    async fn with_hsts() {
        // no HSTS
        let app = RedirectHttps::default()
            .new_transform(test::ok_service())
            .await
            .unwrap();

        let req = test_request!(GET "http://localhost/").to_srv_request();
        let res = test::call_service(&app, req).await;
        assert!(!res.headers().contains_key(StrictTransportSecurity::name()));

        let req = test_request!(GET "https://localhost:443/").to_srv_request();
        let res = test::call_service(&app, req).await;
        assert!(!res.headers().contains_key(StrictTransportSecurity::name()));

        // with HSTS
        let app = RedirectHttps::with_hsts(StrictTransportSecurity::recommended())
            .new_transform(test::ok_service())
            .await
            .unwrap();

        let req = test_request!(GET "http://localhost/").to_srv_request();
        let res = test::call_service(&app, req).await;
        assert!(res.headers().contains_key(StrictTransportSecurity::name()));

        let req = test_request!(GET "https://localhost:443/").to_srv_request();
        let res = test::call_service(&app, req).await;
        assert!(res.headers().contains_key(StrictTransportSecurity::name()));
    }

    /// A configured redirect port appears explicitly in the `Location` header.
    #[actix_web::test]
    async fn to_custom_port() {
        let app = RedirectHttps::default()
            .to_port(8443)
            .new_transform(test::ok_service())
            .await
            .unwrap();

        let req = test_request!(GET "http://localhost/").to_srv_request();
        let res = test::call_service(&app, req).await;
        assert_response_matches!(res, TEMPORARY_REDIRECT; "location" => "https://localhost:8443/");
    }

    /// IPv6 literal hosts keep their square brackets in the redirect location.
    #[actix_web::test]
    async fn to_ipv6() {
        let app = RedirectHttps::default()
            .new_transform(test::ok_service())
            .await
            .unwrap();

        let req = test_request!(GET "http://[fe80::1234:1234:1234:1234]/").to_srv_request();
        let res = test::call_service(&app, req).await;
        assert_response_matches!(res, TEMPORARY_REDIRECT; "location" => "https://[fe80::1234:1234:1234:1234]/");
    }

    /// A port already present in the request host is replaced by the configured redirect port.
    #[actix_web::test]
    async fn to_custom_port_when_port_in_host() {
        let app = RedirectHttps::default()
            .to_port(8443)
            .new_transform(test::ok_service())
            .await
            .unwrap();

        let req = test_request!(GET "http://localhost:8080/").to_srv_request();
        let res = test::call_service(&app, req).await;
        assert_response_matches!(res, TEMPORARY_REDIRECT; "location" => "https://localhost:8443/");
    }
}