awc/middleware/
redirect.rs

1use std::{
2    convert::TryFrom,
3    future::Future,
4    net::SocketAddr,
5    pin::Pin,
6    rc::Rc,
7    task::{Context, Poll},
8};
9
10use actix_http::{header, Method, RequestHead, RequestHeadType, StatusCode, Uri};
11use actix_service::Service;
12use bytes::Bytes;
13use futures_core::ready;
14
15use super::Transform;
16use crate::{
17    any_body::AnyBody,
18    client::{InvalidUrl, SendRequestError},
19    connect::{ConnectRequest, ConnectResponse},
20    ClientResponse,
21};
22
/// Client middleware configuration for following HTTP redirect responses.
///
/// Stores the maximum number of redirects that may be followed for one request.
pub struct Redirect {
    max_redirect_times: u8,
}

impl Default for Redirect {
    /// Builds the middleware with a limit of 10 redirects.
    fn default() -> Self {
        Redirect {
            max_redirect_times: 10,
        }
    }
}

impl Redirect {
    /// Creates the middleware with the default limit of 10 redirects.
    pub fn new() -> Self {
        Self::default()
    }

    /// Replaces the redirect limit with `times` and hands the middleware back for chaining.
    pub fn max_redirect_times(self, times: u8) -> Self {
        Redirect {
            max_redirect_times: times,
        }
    }
}
45
46impl<S> Transform<S, ConnectRequest> for Redirect
47where
48    S: Service<ConnectRequest, Response = ConnectResponse, Error = SendRequestError> + 'static,
49{
50    type Transform = RedirectService<S>;
51
52    fn new_transform(self, service: S) -> Self::Transform {
53        RedirectService {
54            max_redirect_times: self.max_redirect_times,
55            connector: Rc::new(service),
56        }
57    }
58}
59
60pub struct RedirectService<S> {
61    max_redirect_times: u8,
62    connector: Rc<S>,
63}
64
65impl<S> Service<ConnectRequest> for RedirectService<S>
66where
67    S: Service<ConnectRequest, Response = ConnectResponse, Error = SendRequestError> + 'static,
68{
69    type Response = S::Response;
70    type Error = S::Error;
71    type Future = RedirectServiceFuture<S>;
72
73    actix_service::forward_ready!(connector);
74
75    fn call(&self, req: ConnectRequest) -> Self::Future {
76        match req {
77            ConnectRequest::Tunnel(head, addr) => {
78                let fut = self.connector.call(ConnectRequest::Tunnel(head, addr));
79                RedirectServiceFuture::Tunnel { fut }
80            }
81            ConnectRequest::Client(head, body, addr) => {
82                let connector = self.connector.clone();
83                let max_redirect_times = self.max_redirect_times;
84
85                // backup the uri and method for reuse schema and authority.
86                let (uri, method, headers) = match head {
87                    RequestHeadType::Owned(ref head) => {
88                        (head.uri.clone(), head.method.clone(), head.headers.clone())
89                    }
90                    RequestHeadType::Rc(ref head, ..) => {
91                        (head.uri.clone(), head.method.clone(), head.headers.clone())
92                    }
93                };
94
95                let body_opt = match body {
96                    AnyBody::Bytes { ref body } => Some(body.clone()),
97                    _ => None,
98                };
99
100                let fut = connector.call(ConnectRequest::Client(head, body, addr));
101
102                RedirectServiceFuture::Client {
103                    fut,
104                    max_redirect_times,
105                    uri: Some(uri),
106                    method: Some(method),
107                    headers: Some(headers),
108                    body: body_opt,
109                    addr,
110                    connector: Some(connector),
111                }
112            }
113        }
114    }
115}
116
// Future returned by `RedirectService::call`.
//
// `Tunnel` only wraps the inner connector future. `Client` additionally carries everything
// needed to issue a follow-up request when the response is a redirect.
pin_project_lite::pin_project! {
    #[project = RedirectServiceProj]
    pub enum RedirectServiceFuture<S>
    where
        // NOTE(review): the `'static` bound is presumably a separate predicate because of
        // limits in the `pin_project!` where-clause grammar — confirm before merging them.
        S: Service<ConnectRequest, Response = ConnectResponse, Error = SendRequestError>,
        S: 'static
    {
        // pass-through of the inner connector future for tunnel requests
        Tunnel { #[pin] fut: S::Future },
        Client {
            // in-flight request on the inner connector
            #[pin]
            fut: S::Future,
            // number of redirects that may still be followed
            max_redirect_times: u8,
            // URI of the request currently in flight
            uri: Option<Uri>,
            // method of the request currently in flight
            method: Option<Method>,
            // headers of the request currently in flight
            headers: Option<header::HeaderMap>,
            // copy of the request body, kept only when it was an in-memory `Bytes` body
            body: Option<Bytes>,
            // socket address passed along with every request
            addr: Option<SocketAddr>,
            // shared handle to the inner connector used for follow-up requests
            connector: Option<Rc<S>>,
        }
    }
}
138
139impl<S> Future for RedirectServiceFuture<S>
140where
141    S: Service<ConnectRequest, Response = ConnectResponse, Error = SendRequestError> + 'static,
142{
143    type Output = Result<ConnectResponse, SendRequestError>;
144
145    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
146        match self.as_mut().project() {
147            RedirectServiceProj::Tunnel { fut } => fut.poll(cx),
148            RedirectServiceProj::Client {
149                fut,
150                max_redirect_times,
151                uri,
152                method,
153                headers,
154                body,
155                addr,
156                connector,
157            } => match ready!(fut.poll(cx))? {
158                ConnectResponse::Client(res) => match res.head().status {
159                    StatusCode::MOVED_PERMANENTLY
160                    | StatusCode::FOUND
161                    | StatusCode::SEE_OTHER
162                    | StatusCode::TEMPORARY_REDIRECT
163                    | StatusCode::PERMANENT_REDIRECT
164                        if *max_redirect_times > 0
165                            && res.headers().contains_key(header::LOCATION) =>
166                    {
167                        let reuse_body = res.head().status == StatusCode::TEMPORARY_REDIRECT
168                            || res.head().status == StatusCode::PERMANENT_REDIRECT;
169
170                        let prev_uri = uri.take().unwrap();
171
172                        // rebuild uri from the location header value.
173                        let next_uri = build_next_uri(&res, &prev_uri)?;
174
175                        // take ownership of states that could be reused
176                        let addr = addr.take();
177                        let connector = connector.take();
178
179                        // reset method
180                        let method = if reuse_body {
181                            method.take().unwrap()
182                        } else {
183                            let method = method.take().unwrap();
184                            match method {
185                                Method::GET | Method::HEAD => method,
186                                _ => Method::GET,
187                            }
188                        };
189
190                        let mut body = body.take();
191                        let body_new = if reuse_body {
192                            // try to reuse saved body
193                            match body {
194                                Some(ref bytes) => AnyBody::Bytes {
195                                    body: bytes.clone(),
196                                },
197
198                                // body was a non-reusable type so send an empty body instead
199                                _ => AnyBody::empty(),
200                            }
201                        } else {
202                            body = None;
203                            // remove body since we're downgrading to a GET
204                            AnyBody::None
205                        };
206
207                        let mut headers = headers.take().unwrap();
208
209                        remove_sensitive_headers(&mut headers, &prev_uri, &next_uri);
210
211                        // use a new request head.
212                        let mut head = RequestHead::default();
213                        head.uri = next_uri.clone();
214                        head.method = method.clone();
215                        head.headers = headers.clone();
216
217                        let head = RequestHeadType::Owned(head);
218
219                        let mut max_redirect_times = *max_redirect_times;
220                        max_redirect_times -= 1;
221
222                        let fut = connector
223                            .as_ref()
224                            .unwrap()
225                            .call(ConnectRequest::Client(head, body_new, addr));
226
227                        self.set(RedirectServiceFuture::Client {
228                            fut,
229                            max_redirect_times,
230                            uri: Some(next_uri),
231                            method: Some(method),
232                            headers: Some(headers),
233                            body,
234                            addr,
235                            connector,
236                        });
237
238                        self.poll(cx)
239                    }
240                    _ => Poll::Ready(Ok(ConnectResponse::Client(res))),
241                },
242                _ => unreachable!("ConnectRequest::Tunnel is not handled by Redirect"),
243            },
244        }
245    }
246}
247
248fn build_next_uri(res: &ClientResponse, prev_uri: &Uri) -> Result<Uri, SendRequestError> {
249    // responses without this header are not processed
250    let location = res.headers().get(header::LOCATION).unwrap();
251
252    // try to parse the location and resolve to a full URI but fall back to default if it fails
253    let uri = Uri::try_from(location.as_bytes()).unwrap_or_else(|_| Uri::default());
254
255    let uri = if uri.scheme().is_none() || uri.authority().is_none() {
256        let builder = Uri::builder()
257            .scheme(prev_uri.scheme().cloned().unwrap())
258            .authority(prev_uri.authority().cloned().unwrap());
259
260        // when scheme or authority is missing treat the location value as path and query
261        // recover error where location does not have leading slash
262        let path = if location.as_bytes().starts_with(b"/") {
263            location.as_bytes().to_owned()
264        } else {
265            [b"/", location.as_bytes()].concat()
266        };
267
268        builder
269            .path_and_query(path)
270            .build()
271            .map_err(|err| SendRequestError::Url(InvalidUrl::HttpError(err)))?
272    } else {
273        uri
274    };
275
276    Ok(uri)
277}
278
279fn remove_sensitive_headers(headers: &mut header::HeaderMap, prev_uri: &Uri, next_uri: &Uri) {
280    if next_uri.host() != prev_uri.host()
281        || next_uri.port() != prev_uri.port()
282        || next_uri.scheme() != prev_uri.scheme()
283    {
284        headers.remove(header::COOKIE);
285        headers.remove(header::AUTHORIZATION);
286        headers.remove(header::PROXY_AUTHORIZATION);
287    }
288}
289
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use actix_web::{web, App, Error, HttpRequest, HttpResponse};

    use super::*;
    use crate::{
        http::{header::HeaderValue, StatusCode},
        ClientBuilder,
    };

    /// A 302 with a location header is followed to the target resource.
    #[actix_rt::test]
    async fn basic_redirect() {
        let client = ClientBuilder::new()
            .disable_redirects()
            .wrap(Redirect::new().max_redirect_times(10))
            .finish();

        let server = actix_test::start(|| {
            let target = web::resource("/test")
                .route(web::to(|| async { Ok::<_, Error>(HttpResponse::BadRequest()) }));

            let entry = web::resource("/").route(web::to(|| async {
                Ok::<_, Error>(
                    HttpResponse::Found()
                        .append_header(("location", "/test"))
                        .finish(),
                )
            }));

            App::new().service(target).service(entry)
        });

        let response = client.get(server.url("/")).send().await.unwrap();

        assert_eq!(response.status().as_u16(), 400);
    }

    /// A relative location without a leading slash is resolved against the root.
    #[actix_rt::test]
    async fn redirect_relative_without_leading_slash() {
        let client = ClientBuilder::new().finish();

        let server = actix_test::start(|| {
            let entry = web::resource("/").route(web::to(|| async {
                HttpResponse::Found()
                    .insert_header(("location", "abc/"))
                    .finish()
            }));

            let target = web::resource("/abc/")
                .route(web::to(|| async { HttpResponse::Accepted().finish() }));

            App::new().service(entry).service(target)
        });

        let response = client.get(server.url("/")).send().await.unwrap();
        assert_eq!(response.status(), StatusCode::ACCEPTED);
    }

    /// A redirect status without a location header is returned to the caller unchanged.
    #[actix_rt::test]
    async fn redirect_without_location() {
        let client = ClientBuilder::new()
            .disable_redirects()
            .wrap(Redirect::new().max_redirect_times(10))
            .finish();

        let server = actix_test::start(|| {
            let entry = web::resource("/")
                .route(web::to(|| async { Ok::<_, Error>(HttpResponse::Found().finish()) }));

            App::new().service(entry)
        });

        let response = client.get(server.url("/")).send().await.unwrap();
        assert_eq!(response.status(), StatusCode::FOUND);
    }

    /// With a limit of one, the second redirect response is handed back instead of followed.
    #[actix_rt::test]
    async fn test_redirect_limit() {
        let client = ClientBuilder::new()
            .disable_redirects()
            .wrap(Redirect::new().max_redirect_times(1))
            .connector(crate::Connector::new())
            .finish();

        let server = actix_test::start(|| {
            let first_hop = web::resource("/").route(web::to(|| async {
                Ok::<_, Error>(
                    HttpResponse::Found()
                        .insert_header(("location", "/test"))
                        .finish(),
                )
            }));

            let second_hop = web::resource("/test").route(web::to(|| async {
                Ok::<_, Error>(
                    HttpResponse::Found()
                        .insert_header(("location", "/test2"))
                        .finish(),
                )
            }));

            let unreached = web::resource("/test2")
                .route(web::to(|| async { Ok::<_, Error>(HttpResponse::BadRequest()) }));

            App::new()
                .service(first_hop)
                .service(second_hop)
                .service(unreached)
        });

        let response = client.get(server.url("/")).send().await.unwrap();
        assert_eq!(response.status(), StatusCode::FOUND);

        let location = response.headers().get(header::LOCATION).unwrap();
        assert_eq!(location.to_str().unwrap(), "/test2");
    }

    /// 307 keeps the POST method and replays the request body.
    #[actix_rt::test]
    async fn test_redirect_status_kind_307_308() {
        let server = actix_test::start(|| {
            async fn entry() -> HttpResponse {
                HttpResponse::TemporaryRedirect()
                    .append_header(("location", "/test"))
                    .finish()
            }

            async fn target(req: HttpRequest, body: Bytes) -> HttpResponse {
                let replayed = req.method() == Method::POST && !body.is_empty();

                if replayed {
                    HttpResponse::Ok().finish()
                } else {
                    HttpResponse::InternalServerError().finish()
                }
            }

            App::new()
                .service(web::resource("/").route(web::to(entry)))
                .service(web::resource("/test").route(web::to(target)))
        });

        let response = server.post("/").send_body("Hello").await.unwrap();
        assert_eq!(response.status().as_u16(), 200);
    }

    /// 302 downgrades a POST to a body-less GET/HEAD request.
    #[actix_rt::test]
    async fn test_redirect_status_kind_301_302_303() {
        let server = actix_test::start(|| {
            async fn entry() -> HttpResponse {
                HttpResponse::Found()
                    .append_header(("location", "/test"))
                    .finish()
            }

            async fn target(req: HttpRequest, body: Bytes) -> HttpResponse {
                let method = req.method();
                let downgraded =
                    (method == Method::GET || method == Method::HEAD) && body.is_empty();

                if downgraded {
                    HttpResponse::Ok().finish()
                } else {
                    HttpResponse::InternalServerError().finish()
                }
            }

            App::new()
                .service(web::resource("/").route(web::to(entry)))
                .service(web::resource("/test").route(web::to(target)))
        });

        let with_body = server.post("/").send_body("Hello").await.unwrap();
        assert_eq!(with_body.status().as_u16(), 200);

        let without_body = server.post("/").send().await.unwrap();
        assert_eq!(without_body.status().as_u16(), 200);
    }

    /// Custom request headers survive a same-origin redirect.
    #[actix_rt::test]
    async fn test_redirect_headers() {
        let server = actix_test::start(|| {
            async fn entry(req: HttpRequest) -> HttpResponse {
                let has_custom = req.headers().get("custom").map_or(false, |v| v == "value");

                if has_custom {
                    HttpResponse::Found()
                        .append_header(("location", "/test"))
                        .finish()
                } else {
                    HttpResponse::InternalServerError().finish()
                }
            }

            async fn target(req: HttpRequest) -> HttpResponse {
                let has_custom = req.headers().get("custom").map_or(false, |v| v == "value");

                if has_custom {
                    HttpResponse::Ok().finish()
                } else {
                    HttpResponse::InternalServerError().finish()
                }
            }

            App::new()
                .service(web::resource("/").route(web::to(entry)))
                .service(web::resource("/test").route(web::to(target)))
        });

        // redirects disabled: the 302 itself is observed
        let no_redirect_client = ClientBuilder::new()
            .add_default_header(("custom", "value"))
            .disable_redirects()
            .finish();
        let response = no_redirect_client
            .get(server.url("/"))
            .send()
            .await
            .unwrap();
        assert_eq!(response.status().as_u16(), 302);

        // default header is forwarded to the redirect target
        let default_header_client = ClientBuilder::new()
            .add_default_header(("custom", "value"))
            .finish();
        let response = default_header_client
            .get(server.url("/"))
            .send()
            .await
            .unwrap();
        assert_eq!(response.status().as_u16(), 200);

        // per-request header is forwarded to the redirect target
        let plain_client = ClientBuilder::new().finish();
        let response = plain_client
            .get(server.url("/"))
            .insert_header(("custom", "value"))
            .send()
            .await
            .unwrap();
        assert_eq!(response.status().as_u16(), 200);
    }

    /// Authorization is dropped on a cross-origin redirect and kept on a same-origin one.
    #[actix_rt::test]
    async fn test_redirect_cross_origin_headers() {
        // defining two services to have two different origins
        let other_origin = actix_test::start(|| {
            async fn root(req: HttpRequest) -> HttpResponse {
                if req.headers().contains_key(header::AUTHORIZATION) {
                    HttpResponse::InternalServerError().finish()
                } else {
                    HttpResponse::Ok().finish()
                }
            }

            App::new().service(web::resource("/").route(web::to(root)))
        });
        let other_port: u16 = other_origin.addr().port();

        let origin = actix_test::start(move || {
            async fn root(req: HttpRequest) -> HttpResponse {
                let port = *req.app_data::<u16>().unwrap();

                if !req.headers().contains_key(header::AUTHORIZATION) {
                    return HttpResponse::InternalServerError().finish();
                }

                HttpResponse::Found()
                    .append_header((
                        "location",
                        format!("http://localhost:{}/", port).as_str(),
                    ))
                    .finish()
            }

            async fn test1(req: HttpRequest) -> HttpResponse {
                if !req.headers().contains_key(header::AUTHORIZATION) {
                    return HttpResponse::InternalServerError().finish();
                }

                HttpResponse::Found()
                    .append_header(("location", "/test2"))
                    .finish()
            }

            async fn test2(req: HttpRequest) -> HttpResponse {
                if !req.headers().contains_key(header::AUTHORIZATION) {
                    return HttpResponse::InternalServerError().finish();
                }

                HttpResponse::Ok().finish()
            }

            App::new()
                .app_data(other_port)
                .service(web::resource("/").route(web::to(root)))
                .service(web::resource("/test1").route(web::to(test1)))
                .service(web::resource("/test2").route(web::to(test2)))
        });

        let client = ClientBuilder::new()
            .add_default_header((header::AUTHORIZATION, "auth_key_value"))
            .finish();

        // redirect from the first origin to the second one: the header must be removed
        let response = client.get(origin.url("/")).send().await.unwrap();
        assert_eq!(response.status().as_u16(), 200);

        // redirect within the first origin (/test1 -> /test2): no header may be removed
        let response = client.get(origin.url("/test1")).send().await.unwrap();
        assert_eq!(response.status().as_u16(), 200);
    }

    /// Sensitive headers are kept for the same origin and stripped when scheme, host or port
    /// changes.
    #[actix_rt::test]
    async fn test_remove_sensitive_headers() {
        fn gen_headers() -> header::HeaderMap {
            let names = [
                header::USER_AGENT,
                header::AUTHORIZATION,
                header::PROXY_AUTHORIZATION,
                header::COOKIE,
            ];

            let mut headers = header::HeaderMap::new();
            for name in names.iter() {
                headers.insert(name.clone(), HeaderValue::from_str("value").unwrap());
            }
            headers
        }

        // (previous URI, next URI, number of headers expected to remain)
        let cases = [
            // same origin
            ("https://host/path1", "https://host/path2", 4),
            // different scheme
            ("http://host/", "https://host/", 1),
            // different host
            ("https://host1/", "https://host2/", 1),
            // different port
            ("https://host:12/", "https://host:23/", 1),
            // different everything!
            ("http://host1:12/path1", "https://host2:23/path2", 1),
        ];

        for &(prev, next, remaining) in cases.iter() {
            let prev_uri = Uri::from_str(prev).unwrap();
            let next_uri = Uri::from_str(next).unwrap();
            let mut headers = gen_headers();

            remove_sensitive_headers(&mut headers, &prev_uri, &next_uri);

            assert_eq!(headers.len(), remaining);
        }
    }
}