1use std::{
2 convert::TryFrom,
3 future::Future,
4 net::SocketAddr,
5 pin::Pin,
6 rc::Rc,
7 task::{Context, Poll},
8};
9
10use actix_http::{header, Method, RequestHead, RequestHeadType, StatusCode, Uri};
11use actix_service::Service;
12use bytes::Bytes;
13use futures_core::ready;
14
15use super::Transform;
16use crate::{
17 any_body::AnyBody,
18 client::{InvalidUrl, SendRequestError},
19 connect::{ConnectRequest, ConnectResponse},
20 ClientResponse,
21};
22
/// Client middleware that follows HTTP redirect responses.
///
/// Responses with status 301, 302, 303, 307 or 308 that carry a `Location` header are followed
/// up to a configurable number of times (see [`Redirect::max_redirect_times()`]; 10 by default).
/// 307 and 308 re-send the original method and in-memory body. The other statuses are followed
/// with a body-less request whose method becomes `GET` unless it already was `GET` or `HEAD`.
/// `Cookie`, `Authorization` and `Proxy-Authorization` are dropped when a redirect changes the
/// scheme, host or port.
#[derive(Debug, Clone)]
pub struct Redirect {
    /// Maximum number of redirects to follow before the redirect response is returned as-is.
    max_redirect_times: u8,
}
26
27impl Default for Redirect {
28 fn default() -> Self {
29 Self::new()
30 }
31}
32
33impl Redirect {
34 pub fn new() -> Self {
35 Self {
36 max_redirect_times: 10,
37 }
38 }
39
40 pub fn max_redirect_times(mut self, times: u8) -> Self {
41 self.max_redirect_times = times;
42 self
43 }
44}
45
46impl<S> Transform<S, ConnectRequest> for Redirect
47where
48 S: Service<ConnectRequest, Response = ConnectResponse, Error = SendRequestError> + 'static,
49{
50 type Transform = RedirectService<S>;
51
52 fn new_transform(self, service: S) -> Self::Transform {
53 RedirectService {
54 max_redirect_times: self.max_redirect_times,
55 connector: Rc::new(service),
56 }
57 }
58}
59
/// Service produced by [`Redirect`].
///
/// Forwards requests to the wrapped connector and follows the redirect responses it returns.
pub struct RedirectService<S> {
    /// Maximum number of redirects followed for each request.
    max_redirect_times: u8,
    /// Wrapped connector. It is shared (via `Rc`) with in-flight response futures
    /// so they can re-issue requests.
    connector: Rc<S>,
}
64
65impl<S> Service<ConnectRequest> for RedirectService<S>
66where
67 S: Service<ConnectRequest, Response = ConnectResponse, Error = SendRequestError> + 'static,
68{
69 type Response = S::Response;
70 type Error = S::Error;
71 type Future = RedirectServiceFuture<S>;
72
73 actix_service::forward_ready!(connector);
74
75 fn call(&self, req: ConnectRequest) -> Self::Future {
76 match req {
77 ConnectRequest::Tunnel(head, addr) => {
78 let fut = self.connector.call(ConnectRequest::Tunnel(head, addr));
79 RedirectServiceFuture::Tunnel { fut }
80 }
81 ConnectRequest::Client(head, body, addr) => {
82 let connector = self.connector.clone();
83 let max_redirect_times = self.max_redirect_times;
84
85 let (uri, method, headers) = match head {
87 RequestHeadType::Owned(ref head) => {
88 (head.uri.clone(), head.method.clone(), head.headers.clone())
89 }
90 RequestHeadType::Rc(ref head, ..) => {
91 (head.uri.clone(), head.method.clone(), head.headers.clone())
92 }
93 };
94
95 let body_opt = match body {
96 AnyBody::Bytes { ref body } => Some(body.clone()),
97 _ => None,
98 };
99
100 let fut = connector.call(ConnectRequest::Client(head, body, addr));
101
102 RedirectServiceFuture::Client {
103 fut,
104 max_redirect_times,
105 uri: Some(uri),
106 method: Some(method),
107 headers: Some(headers),
108 body: body_opt,
109 addr,
110 connector: Some(connector),
111 }
112 }
113 }
114 }
115}
116
pin_project_lite::pin_project! {
    // Response future of `RedirectService`.
    #[project = RedirectServiceProj]
    pub enum RedirectServiceFuture<S>
    where
        S: Service<ConnectRequest, Response = ConnectResponse, Error = SendRequestError>,
        S: 'static
    {
        // Tunnel request: the connector's future is forwarded and redirects are not followed.
        Tunnel { #[pin] fut: S::Future },
        // Client request in flight, plus the state needed to re-issue it on a redirect.
        Client {
            // Future of the request currently in flight.
            #[pin]
            fut: S::Future,
            // Number of redirects that may still be followed.
            max_redirect_times: u8,
            // URI of the in-flight request; taken (left `None`) while the next request is built.
            uri: Option<Uri>,
            // Method of the in-flight request.
            method: Option<Method>,
            // Headers of the in-flight request.
            headers: Option<header::HeaderMap>,
            // Saved body, `Some` only when it was an `AnyBody::Bytes` and can be replayed.
            body: Option<Bytes>,
            // Socket address given with the request, passed unchanged to the connector.
            addr: Option<SocketAddr>,
            // Connector used to send follow-up requests.
            connector: Option<Rc<S>>,
        }
    }
}
138
139impl<S> Future for RedirectServiceFuture<S>
140where
141 S: Service<ConnectRequest, Response = ConnectResponse, Error = SendRequestError> + 'static,
142{
143 type Output = Result<ConnectResponse, SendRequestError>;
144
145 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
146 match self.as_mut().project() {
147 RedirectServiceProj::Tunnel { fut } => fut.poll(cx),
148 RedirectServiceProj::Client {
149 fut,
150 max_redirect_times,
151 uri,
152 method,
153 headers,
154 body,
155 addr,
156 connector,
157 } => match ready!(fut.poll(cx))? {
158 ConnectResponse::Client(res) => match res.head().status {
159 StatusCode::MOVED_PERMANENTLY
160 | StatusCode::FOUND
161 | StatusCode::SEE_OTHER
162 | StatusCode::TEMPORARY_REDIRECT
163 | StatusCode::PERMANENT_REDIRECT
164 if *max_redirect_times > 0
165 && res.headers().contains_key(header::LOCATION) =>
166 {
167 let reuse_body = res.head().status == StatusCode::TEMPORARY_REDIRECT
168 || res.head().status == StatusCode::PERMANENT_REDIRECT;
169
170 let prev_uri = uri.take().unwrap();
171
172 let next_uri = build_next_uri(&res, &prev_uri)?;
174
175 let addr = addr.take();
177 let connector = connector.take();
178
179 let method = if reuse_body {
181 method.take().unwrap()
182 } else {
183 let method = method.take().unwrap();
184 match method {
185 Method::GET | Method::HEAD => method,
186 _ => Method::GET,
187 }
188 };
189
190 let mut body = body.take();
191 let body_new = if reuse_body {
192 match body {
194 Some(ref bytes) => AnyBody::Bytes {
195 body: bytes.clone(),
196 },
197
198 _ => AnyBody::empty(),
200 }
201 } else {
202 body = None;
203 AnyBody::None
205 };
206
207 let mut headers = headers.take().unwrap();
208
209 remove_sensitive_headers(&mut headers, &prev_uri, &next_uri);
210
211 let mut head = RequestHead::default();
213 head.uri = next_uri.clone();
214 head.method = method.clone();
215 head.headers = headers.clone();
216
217 let head = RequestHeadType::Owned(head);
218
219 let mut max_redirect_times = *max_redirect_times;
220 max_redirect_times -= 1;
221
222 let fut = connector
223 .as_ref()
224 .unwrap()
225 .call(ConnectRequest::Client(head, body_new, addr));
226
227 self.set(RedirectServiceFuture::Client {
228 fut,
229 max_redirect_times,
230 uri: Some(next_uri),
231 method: Some(method),
232 headers: Some(headers),
233 body,
234 addr,
235 connector,
236 });
237
238 self.poll(cx)
239 }
240 _ => Poll::Ready(Ok(ConnectResponse::Client(res))),
241 },
242 _ => unreachable!("ConnectRequest::Tunnel is not handled by Redirect"),
243 },
244 }
245 }
246}
247
248fn build_next_uri(res: &ClientResponse, prev_uri: &Uri) -> Result<Uri, SendRequestError> {
249 let location = res.headers().get(header::LOCATION).unwrap();
251
252 let uri = Uri::try_from(location.as_bytes()).unwrap_or_else(|_| Uri::default());
254
255 let uri = if uri.scheme().is_none() || uri.authority().is_none() {
256 let builder = Uri::builder()
257 .scheme(prev_uri.scheme().cloned().unwrap())
258 .authority(prev_uri.authority().cloned().unwrap());
259
260 let path = if location.as_bytes().starts_with(b"/") {
263 location.as_bytes().to_owned()
264 } else {
265 [b"/", location.as_bytes()].concat()
266 };
267
268 builder
269 .path_and_query(path)
270 .build()
271 .map_err(|err| SendRequestError::Url(InvalidUrl::HttpError(err)))?
272 } else {
273 uri
274 };
275
276 Ok(uri)
277}
278
279fn remove_sensitive_headers(headers: &mut header::HeaderMap, prev_uri: &Uri, next_uri: &Uri) {
280 if next_uri.host() != prev_uri.host()
281 || next_uri.port() != prev_uri.port()
282 || next_uri.scheme() != prev_uri.scheme()
283 {
284 headers.remove(header::COOKIE);
285 headers.remove(header::AUTHORIZATION);
286 headers.remove(header::PROXY_AUTHORIZATION);
287 }
288}
289
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use actix_web::{web, App, Error, HttpRequest, HttpResponse};

    use super::*;
    use crate::{
        http::{header::HeaderValue, StatusCode},
        ClientBuilder,
    };

    // A 302 from `/` to `/test` is followed by an explicitly wrapped `Redirect`
    // middleware; the client ends up with `/test`'s 400 response.
    #[actix_rt::test]
    async fn basic_redirect() {
        let client = ClientBuilder::new()
            .disable_redirects()
            .wrap(Redirect::new().max_redirect_times(10))
            .finish();

        let srv = actix_test::start(|| {
            App::new()
                .service(web::resource("/test").route(web::to(|| async {
                    Ok::<_, Error>(HttpResponse::BadRequest())
                })))
                .service(web::resource("/").route(web::to(|| async {
                    Ok::<_, Error>(
                        HttpResponse::Found()
                            .append_header(("location", "/test"))
                            .finish(),
                    )
                })))
        });

        let res = client.get(srv.url("/")).send().await.unwrap();

        assert_eq!(res.status().as_u16(), 400);
    }

    // A Location without a leading slash (`abc/`) is resolved to `/abc/`
    // using the default client configuration.
    #[actix_rt::test]
    async fn redirect_relative_without_leading_slash() {
        let client = ClientBuilder::new().finish();

        let srv = actix_test::start(|| {
            App::new()
                .service(web::resource("/").route(web::to(|| async {
                    HttpResponse::Found()
                        .insert_header(("location", "abc/"))
                        .finish()
                })))
                .service(
                    web::resource("/abc/")
                        .route(web::to(|| async { HttpResponse::Accepted().finish() })),
                )
        });

        let res = client.get(srv.url("/")).send().await.unwrap();
        assert_eq!(res.status(), StatusCode::ACCEPTED);
    }

    // A redirect status without a Location header is not followed; the 302 is returned.
    #[actix_rt::test]
    async fn redirect_without_location() {
        let client = ClientBuilder::new()
            .disable_redirects()
            .wrap(Redirect::new().max_redirect_times(10))
            .finish();

        let srv = actix_test::start(|| {
            App::new().service(web::resource("/").route(web::to(|| async {
                Ok::<_, Error>(HttpResponse::Found().finish())
            })))
        });

        let res = client.get(srv.url("/")).send().await.unwrap();
        assert_eq!(res.status(), StatusCode::FOUND);
    }

    // With a limit of 1, `/` -> `/test` is followed but `/test` -> `/test2` is not, so the
    // second 302 (pointing at `/test2`) is returned to the caller.
    #[actix_rt::test]
    async fn test_redirect_limit() {
        let client = ClientBuilder::new()
            .disable_redirects()
            .wrap(Redirect::new().max_redirect_times(1))
            .connector(crate::Connector::new())
            .finish();

        let srv = actix_test::start(|| {
            App::new()
                .service(web::resource("/").route(web::to(|| async {
                    Ok::<_, Error>(
                        HttpResponse::Found()
                            .insert_header(("location", "/test"))
                            .finish(),
                    )
                })))
                .service(web::resource("/test").route(web::to(|| async {
                    Ok::<_, Error>(
                        HttpResponse::Found()
                            .insert_header(("location", "/test2"))
                            .finish(),
                    )
                })))
                .service(web::resource("/test2").route(web::to(|| async {
                    Ok::<_, Error>(HttpResponse::BadRequest())
                })))
        });

        let res = client.get(srv.url("/")).send().await.unwrap();
        assert_eq!(res.status(), StatusCode::FOUND);
        assert_eq!(
            res.headers()
                .get(header::LOCATION)
                .unwrap()
                .to_str()
                .unwrap(),
            "/test2"
        );
    }

    // A 307 keeps the POST method and its body: `/test` only answers 200 when both arrive.
    #[actix_rt::test]
    async fn test_redirect_status_kind_307_308() {
        let srv = actix_test::start(|| {
            async fn root() -> HttpResponse {
                HttpResponse::TemporaryRedirect()
                    .append_header(("location", "/test"))
                    .finish()
            }

            async fn test(req: HttpRequest, body: Bytes) -> HttpResponse {
                if req.method() == Method::POST && !body.is_empty() {
                    HttpResponse::Ok().finish()
                } else {
                    HttpResponse::InternalServerError().finish()
                }
            }

            App::new()
                .service(web::resource("/").route(web::to(root)))
                .service(web::resource("/test").route(web::to(test)))
        });

        let res = srv.post("/").send_body("Hello").await.unwrap();
        assert_eq!(res.status().as_u16(), 200);
    }

    // A 302 turns a POST into a body-less GET/HEAD, with and without an original body.
    #[actix_rt::test]
    async fn test_redirect_status_kind_301_302_303() {
        let srv = actix_test::start(|| {
            async fn root() -> HttpResponse {
                HttpResponse::Found()
                    .append_header(("location", "/test"))
                    .finish()
            }

            async fn test(req: HttpRequest, body: Bytes) -> HttpResponse {
                if (req.method() == Method::GET || req.method() == Method::HEAD)
                    && body.is_empty()
                {
                    HttpResponse::Ok().finish()
                } else {
                    HttpResponse::InternalServerError().finish()
                }
            }

            App::new()
                .service(web::resource("/").route(web::to(root)))
                .service(web::resource("/test").route(web::to(test)))
        });

        let res = srv.post("/").send_body("Hello").await.unwrap();
        assert_eq!(res.status().as_u16(), 200);

        let res = srv.post("/").send().await.unwrap();
        assert_eq!(res.status().as_u16(), 200);
    }

    // Non-sensitive headers (here `custom`) survive a same-origin redirect, whether they
    // come from the client's default headers or from the individual request.
    #[actix_rt::test]
    async fn test_redirect_headers() {
        let srv = actix_test::start(|| {
            async fn root(req: HttpRequest) -> HttpResponse {
                if req
                    .headers()
                    .get("custom")
                    .unwrap_or(&HeaderValue::from_str("").unwrap())
                    == "value"
                {
                    HttpResponse::Found()
                        .append_header(("location", "/test"))
                        .finish()
                } else {
                    HttpResponse::InternalServerError().finish()
                }
            }

            async fn test(req: HttpRequest) -> HttpResponse {
                if req
                    .headers()
                    .get("custom")
                    .unwrap_or(&HeaderValue::from_str("").unwrap())
                    == "value"
                {
                    HttpResponse::Ok().finish()
                } else {
                    HttpResponse::InternalServerError().finish()
                }
            }

            App::new()
                .service(web::resource("/").route(web::to(root)))
                .service(web::resource("/test").route(web::to(test)))
        });

        // redirects disabled: the 302 itself is returned
        let client = ClientBuilder::new()
            .add_default_header(("custom", "value"))
            .disable_redirects()
            .finish();
        let res = client.get(srv.url("/")).send().await.unwrap();
        assert_eq!(res.status().as_u16(), 302);

        // default header is still present on the redirected request
        let client = ClientBuilder::new()
            .add_default_header(("custom", "value"))
            .finish();
        let res = client.get(srv.url("/")).send().await.unwrap();
        assert_eq!(res.status().as_u16(), 200);

        // per-request header is still present on the redirected request
        let client = ClientBuilder::new().finish();
        let res = client
            .get(srv.url("/"))
            .insert_header(("custom", "value"))
            .send()
            .await
            .unwrap();
        assert_eq!(res.status().as_u16(), 200);
    }

    // `Authorization` is dropped when redirecting to another server (different port) and
    // kept for a redirect within the same server.
    #[actix_rt::test]
    async fn test_redirect_cross_origin_headers() {
        // target of the cross-origin redirect: succeeds only if Authorization was stripped
        let srv2 = actix_test::start(|| {
            async fn root(req: HttpRequest) -> HttpResponse {
                if req.headers().get(header::AUTHORIZATION).is_none() {
                    HttpResponse::Ok().finish()
                } else {
                    HttpResponse::InternalServerError().finish()
                }
            }

            App::new().service(web::resource("/").route(web::to(root)))
        });
        let srv2_port: u16 = srv2.addr().port();

        let srv1 = actix_test::start(move || {
            // redirects to srv2 (another origin) when Authorization was received
            async fn root(req: HttpRequest) -> HttpResponse {
                let port = *req.app_data::<u16>().unwrap();
                if req.headers().get(header::AUTHORIZATION).is_some() {
                    HttpResponse::Found()
                        .append_header((
                            "location",
                            format!("http://localhost:{}/", port).as_str(),
                        ))
                        .finish()
                } else {
                    HttpResponse::InternalServerError().finish()
                }
            }

            // same-origin redirect to `/test2`
            async fn test1(req: HttpRequest) -> HttpResponse {
                if req.headers().get(header::AUTHORIZATION).is_some() {
                    HttpResponse::Found()
                        .append_header(("location", "/test2"))
                        .finish()
                } else {
                    HttpResponse::InternalServerError().finish()
                }
            }

            // succeeds only if Authorization survived the same-origin redirect
            async fn test2(req: HttpRequest) -> HttpResponse {
                if req.headers().get(header::AUTHORIZATION).is_some() {
                    HttpResponse::Ok().finish()
                } else {
                    HttpResponse::InternalServerError().finish()
                }
            }

            App::new()
                .app_data(srv2_port)
                .service(web::resource("/").route(web::to(root)))
                .service(web::resource("/test1").route(web::to(test1)))
                .service(web::resource("/test2").route(web::to(test2)))
        });

        // every request starts out with an Authorization header
        let client = ClientBuilder::new()
            .add_default_header((header::AUTHORIZATION, "auth_key_value"))
            .finish();
        let res = client.get(srv1.url("/")).send().await.unwrap();
        assert_eq!(res.status().as_u16(), 200);

        let res = client.get(srv1.url("/test1")).send().await.unwrap();
        assert_eq!(res.status().as_u16(), 200);
    }

    // Only a change of scheme, host or port strips the three sensitive headers;
    // `User-Agent` is always kept (hence the remaining length of 1).
    #[actix_rt::test]
    async fn test_remove_sensitive_headers() {
        fn gen_headers() -> header::HeaderMap {
            let mut headers = header::HeaderMap::new();
            headers.insert(header::USER_AGENT, HeaderValue::from_str("value").unwrap());
            headers.insert(
                header::AUTHORIZATION,
                HeaderValue::from_str("value").unwrap(),
            );
            headers.insert(
                header::PROXY_AUTHORIZATION,
                HeaderValue::from_str("value").unwrap(),
            );
            headers.insert(header::COOKIE, HeaderValue::from_str("value").unwrap());
            headers
        }

        // same origin, different path: nothing removed
        let prev_uri = Uri::from_str("https://host/path1").unwrap();
        let next_uri = Uri::from_str("https://host/path2").unwrap();
        let mut headers = gen_headers();
        remove_sensitive_headers(&mut headers, &prev_uri, &next_uri);
        assert_eq!(headers.len(), 4);

        // scheme change
        let prev_uri = Uri::from_str("http://host/").unwrap();
        let next_uri = Uri::from_str("https://host/").unwrap();
        let mut headers = gen_headers();
        remove_sensitive_headers(&mut headers, &prev_uri, &next_uri);
        assert_eq!(headers.len(), 1);

        // host change
        let prev_uri = Uri::from_str("https://host1/").unwrap();
        let next_uri = Uri::from_str("https://host2/").unwrap();
        let mut headers = gen_headers();
        remove_sensitive_headers(&mut headers, &prev_uri, &next_uri);
        assert_eq!(headers.len(), 1);

        // port change
        let prev_uri = Uri::from_str("https://host:12/").unwrap();
        let next_uri = Uri::from_str("https://host:23/").unwrap();
        let mut headers = gen_headers();
        remove_sensitive_headers(&mut headers, &prev_uri, &next_uri);
        assert_eq!(headers.len(), 1);

        // everything changes
        let prev_uri = Uri::from_str("http://host1:12/path1").unwrap();
        let next_uri = Uri::from_str("https://host2:23/path2").unwrap();
        let mut headers = gen_headers();
        remove_sensitive_headers(&mut headers, &prev_uri, &next_uri);
        assert_eq!(headers.len(), 1);
    }
}