Skip to main content

reqwest/
redirect.rs

1//! Redirect Handling
2//!
3//! By default, a `Client` will automatically handle HTTP redirects, having a
4//! maximum redirect chain of 10 hops. To customize this behavior, a
5//! `redirect::Policy` can be used with a `ClientBuilder`.
6
7use std::fmt;
8use std::sync::atomic::{AtomicBool, Ordering};
9use std::{error::Error as StdError, sync::Arc};
10
11use crate::header::{AUTHORIZATION, COOKIE, PROXY_AUTHORIZATION, REFERER, WWW_AUTHENTICATE};
12use http::{HeaderMap, HeaderValue};
13use hyper::StatusCode;
14
15use crate::{async_impl, Url};
16use tower_http::follow_redirect::policy::{
17    Action as TowerAction, Attempt as TowerAttempt, Policy as TowerPolicy,
18};
19
20/// A type that controls the policy on how to handle the following of redirects.
21///
22/// The default value will catch redirect loops, and has a maximum of 10
23/// redirects it will follow in a chain before returning an error.
24///
25/// - `limited` can be used have the same as the default behavior, but adjust
26///   the allowed maximum redirect hops in a chain.
27/// - `none` can be used to disable all redirect behavior.
28/// - `custom` can be used to create a customized policy.
29pub struct Policy {
30    pub(crate) inner: PolicyKind,
31}
32
33/// A type that holds information on the next request and previous requests
34/// in redirect chain.
35#[derive(Debug)]
36pub struct Attempt<'a> {
37    status: StatusCode,
38    next: &'a Url,
39    previous: &'a [Url],
40}
41
42/// An action to perform when a redirect status code is found.
43#[derive(Debug)]
44pub struct Action {
45    inner: ActionKind,
46}
47
48impl Policy {
49    /// Create a `Policy` with a maximum number of redirects.
50    ///
51    /// An `Error` will be returned if the max is reached.
52    pub fn limited(max: usize) -> Self {
53        Self {
54            inner: PolicyKind::Limit(max),
55        }
56    }
57
58    /// Create a `Policy` that does not follow any redirect.
59    pub fn none() -> Self {
60        Self {
61            inner: PolicyKind::None,
62        }
63    }
64
65    /// Create a custom `Policy` using the passed function.
66    ///
67    /// # Note
68    ///
69    /// The default `Policy` handles a maximum loop
70    /// chain, but the custom variant does not do that for you automatically.
71    /// The custom policy should have some way of handling those.
72    ///
73    /// Information on the next request and previous requests can be found
74    /// on the [`Attempt`] argument passed to the closure.
75    ///
76    /// Actions can be conveniently created from methods on the
77    /// [`Attempt`].
78    ///
79    /// # Example
80    ///
81    /// ```rust
82    /// # use reqwest::{Error, redirect};
83    /// #
84    /// # fn run() -> Result<(), Error> {
85    /// let custom = redirect::Policy::custom(|attempt| {
86    ///     if attempt.previous().len() > 5 {
87    ///         attempt.error("too many redirects")
88    ///     } else if attempt.url().host_str() == Some("example.domain") {
89    ///         // prevent redirects to 'example.domain'
90    ///         attempt.stop()
91    ///     } else {
92    ///         attempt.follow()
93    ///     }
94    /// });
95    /// let client = reqwest::Client::builder()
96    ///     .redirect(custom)
97    ///     .build()?;
98    /// # Ok(())
99    /// # }
100    /// ```
101    ///
102    /// [`Attempt`]: struct.Attempt.html
103    pub fn custom<T>(policy: T) -> Self
104    where
105        T: Fn(Attempt) -> Action + Send + Sync + 'static,
106    {
107        Self {
108            inner: PolicyKind::Custom(Box::new(policy)),
109        }
110    }
111
112    /// Apply this policy to a given [`Attempt`] to produce a [`Action`].
113    ///
114    /// # Note
115    ///
116    /// This method can be used together with `Policy::custom()`
117    /// to construct one `Policy` that wraps another.
118    ///
119    /// # Example
120    ///
121    /// ```rust
122    /// # use reqwest::{Error, redirect};
123    /// #
124    /// # fn run() -> Result<(), Error> {
125    /// let custom = redirect::Policy::custom(|attempt| {
126    ///     eprintln!("{}, Location: {:?}", attempt.status(), attempt.url());
127    ///     redirect::Policy::default().redirect(attempt)
128    /// });
129    /// # Ok(())
130    /// # }
131    /// ```
132    pub fn redirect(&self, attempt: Attempt) -> Action {
133        match self.inner {
134            PolicyKind::Custom(ref custom) => custom(attempt),
135            PolicyKind::Limit(max) => {
136                // The first URL in the previous is the initial URL and not a redirection. It needs to be excluded.
137                if attempt.previous.len() > max {
138                    attempt.error(TooManyRedirects)
139                } else {
140                    attempt.follow()
141                }
142            }
143            PolicyKind::None => attempt.stop(),
144        }
145    }
146
147    pub(crate) fn check(&self, status: StatusCode, next: &Url, previous: &[Url]) -> ActionKind {
148        self.redirect(Attempt {
149            status,
150            next,
151            previous,
152        })
153        .inner
154    }
155
156    pub(crate) fn is_default(&self) -> bool {
157        matches!(self.inner, PolicyKind::Limit(10))
158    }
159}
160
161impl Default for Policy {
162    fn default() -> Policy {
163        // Keep `is_default` in sync
164        Policy::limited(10)
165    }
166}
167
168impl Attempt<'_> {
169    /// Get the type of redirect.
170    pub fn status(&self) -> StatusCode {
171        self.status
172    }
173
174    /// Get the next URL to redirect to.
175    pub fn url(&self) -> &Url {
176        self.next
177    }
178
179    /// Get the list of previous URLs that have already been requested in this chain.
180    pub fn previous(&self) -> &[Url] {
181        self.previous
182    }
183    /// Returns an action meaning reqwest should follow the next URL.
184    pub fn follow(self) -> Action {
185        Action {
186            inner: ActionKind::Follow,
187        }
188    }
189
190    /// Returns an action meaning reqwest should not follow the next URL.
191    ///
192    /// The 30x response will be returned as the `Ok` result.
193    pub fn stop(self) -> Action {
194        Action {
195            inner: ActionKind::Stop,
196        }
197    }
198
199    /// Returns an action failing the redirect with an error.
200    ///
201    /// The `Error` will be returned for the result of the sent request.
202    pub fn error<E: Into<Box<dyn StdError + Send + Sync>>>(self, error: E) -> Action {
203        Action {
204            inner: ActionKind::Error(error.into()),
205        }
206    }
207}
208
209pub(crate) enum PolicyKind {
210    Custom(Box<dyn Fn(Attempt) -> Action + Send + Sync + 'static>),
211    Limit(usize),
212    None,
213}
214
215impl fmt::Debug for Policy {
216    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
217        f.debug_tuple("Policy").field(&self.inner).finish()
218    }
219}
220
221impl fmt::Debug for PolicyKind {
222    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
223        match *self {
224            PolicyKind::Custom(..) => f.pad("Custom"),
225            PolicyKind::Limit(max) => f.debug_tuple("Limit").field(&max).finish(),
226            PolicyKind::None => f.pad("None"),
227        }
228    }
229}
230
231// pub(crate)
232
/// The concrete decision carried by an `Action`.
#[derive(Debug)]
pub(crate) enum ActionKind {
    /// Follow the redirect to the next URL.
    Follow,
    /// Do not follow the redirect.
    Stop,
    /// Fail the redirect with the contained error.
    Error(Box<dyn StdError + Send + Sync>),
}
239
240pub(crate) fn remove_sensitive_headers(headers: &mut HeaderMap, next: &Url, previous: &[Url]) {
241    if let Some(previous) = previous.last() {
242        let cross_host = next.host_str() != previous.host_str()
243            || next.port_or_known_default() != previous.port_or_known_default()
244            || next.scheme() != previous.scheme();
245        if cross_host {
246            headers.remove(AUTHORIZATION);
247            headers.remove(COOKIE);
248            headers.remove("cookie2");
249            headers.remove(PROXY_AUTHORIZATION);
250            headers.remove(WWW_AUTHENTICATE);
251        }
252    }
253}
254
/// Error produced by a limited policy when a redirect chain exceeds the
/// configured maximum.
#[derive(Debug)]
struct TooManyRedirects;

impl fmt::Display for TooManyRedirects {
    /// Writes the fixed, human-readable description of this error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "too many redirects")
    }
}

impl StdError for TooManyRedirects {}
265
266#[derive(Clone)]
267pub(crate) struct TowerRedirectPolicy {
268    policy: Arc<Policy>,
269    referer: bool,
270    urls: Vec<Url>,
271    https_only: bool,
272    redirect_enabled: Arc<AtomicBool>,
273}
274
275impl TowerRedirectPolicy {
276    pub(crate) fn new(policy: Policy) -> Self {
277        let enabled = !matches!(policy.inner, PolicyKind::None);
278        Self {
279            policy: Arc::new(policy),
280            referer: false,
281            urls: Vec::new(),
282            https_only: false,
283            redirect_enabled: Arc::new(AtomicBool::new(enabled)),
284        }
285    }
286
287    pub(crate) fn with_referer(&mut self, referer: bool) -> &mut Self {
288        self.referer = referer;
289        self
290    }
291
292    pub(crate) fn with_https_only(&mut self, https_only: bool) -> &mut Self {
293        self.https_only = https_only;
294        self
295    }
296
297    pub(crate) fn redirect_enabled_ref(&self) -> Arc<AtomicBool> {
298        self.redirect_enabled.clone()
299    }
300}
301
302fn make_referer(next: &Url, previous: &Url) -> Option<HeaderValue> {
303    if next.scheme() == "http" && previous.scheme() == "https" {
304        return None;
305    }
306
307    let mut referer = previous.clone();
308    let _ = referer.set_username("");
309    let _ = referer.set_password(None);
310    referer.set_fragment(None);
311    referer.as_str().parse().ok()
312}
313
314impl TowerPolicy<async_impl::body::Body, crate::Error> for TowerRedirectPolicy {
315    fn redirect(&mut self, attempt: &TowerAttempt<'_>) -> Result<TowerAction, crate::Error> {
316        // Check if redirects are enabled
317        if !self.redirect_enabled.load(Ordering::Relaxed) {
318            return Ok(TowerAction::Stop);
319        }
320
321        let previous_url =
322            Url::parse(&attempt.previous().to_string()).expect("Previous URL must be valid");
323
324        let next_url = match Url::parse(&attempt.location().to_string()) {
325            Ok(url) => url,
326            Err(e) => return Err(crate::error::builder(e)),
327        };
328
329        self.urls.push(previous_url.clone());
330
331        match self.policy.check(attempt.status(), &next_url, &self.urls) {
332            ActionKind::Follow => {
333                if next_url.scheme() != "http" && next_url.scheme() != "https" {
334                    return Err(crate::error::url_bad_scheme(next_url));
335                }
336
337                if self.https_only && next_url.scheme() != "https" {
338                    return Err(crate::error::redirect(
339                        crate::error::url_bad_scheme(next_url.clone()),
340                        next_url,
341                    ));
342                }
343                Ok(TowerAction::Follow)
344            }
345            ActionKind::Stop => Ok(TowerAction::Stop),
346            ActionKind::Error(e) => Err(crate::error::redirect(e, previous_url)),
347        }
348    }
349
350    fn on_request(&mut self, req: &mut http::Request<async_impl::body::Body>) {
351        if let Ok(next_url) = Url::parse(&req.uri().to_string()) {
352            remove_sensitive_headers(req.headers_mut(), &next_url, &self.urls);
353            if self.referer {
354                if let Some(previous_url) = self.urls.last() {
355                    if let Some(v) = make_referer(&next_url, previous_url) {
356                        req.headers_mut().insert(REFERER, v);
357                    }
358                }
359            }
360        };
361    }
362
363    // This must be implemented to make 307 and 308 redirects work
364    fn clone_body(&self, body: &async_impl::body::Body) -> Option<async_impl::body::Body> {
365        body.try_clone()
366    }
367}
368
/// The default policy follows with 10 recorded URLs and errors with 11.
#[test]
fn test_redirect_policy_limit() {
    let policy = Policy::default();
    let next = Url::parse("http://x.y/z").unwrap();
    let mut previous: Vec<Url> = (0..10)
        .map(|i| Url::parse(&format!("http://a.b/c/{i}")).unwrap())
        .collect();

    // Ten recorded URLs are still within the default limit.
    let verdict = policy.check(StatusCode::FOUND, &next, &previous);
    assert!(matches!(verdict, ActionKind::Follow), "unexpected {verdict:?}");

    previous.push(Url::parse("http://a.b.d/e/33").unwrap());

    // The eleventh recorded URL pushes the chain over the limit.
    let verdict = policy.check(StatusCode::FOUND, &next, &previous);
    assert!(
        matches!(&verdict, ActionKind::Error(err) if err.is::<TooManyRedirects>()),
        "unexpected {verdict:?}"
    );
}
389
/// A limit of zero rejects the very first redirect.
#[test]
fn test_redirect_policy_limit_to_0() {
    let policy = Policy::limited(0);
    let next = Url::parse("http://x.y/z").unwrap();
    let previous = [Url::parse("http://a.b/c").unwrap()];

    let verdict = policy.check(StatusCode::FOUND, &next, &previous);
    assert!(
        matches!(&verdict, ActionKind::Error(err) if err.is::<TooManyRedirects>()),
        "unexpected {verdict:?}"
    );
}
401
/// A custom policy's closure decides per target host.
#[test]
fn test_redirect_policy_custom() {
    let policy = Policy::custom(|attempt| match attempt.url().host_str() {
        Some("foo") => attempt.stop(),
        _ => attempt.follow(),
    });

    // Any host other than "foo" is followed.
    let followed = Url::parse("http://bar/baz").unwrap();
    let verdict = policy.check(StatusCode::FOUND, &followed, &[]);
    assert!(matches!(verdict, ActionKind::Follow), "unexpected {verdict:?}");

    // Redirects to "foo" are stopped.
    let stopped = Url::parse("http://foo/baz").unwrap();
    let verdict = policy.check(StatusCode::FOUND, &stopped, &[]);
    assert!(matches!(verdict, ActionKind::Stop), "unexpected {verdict:?}");
}
424
/// Sensitive headers survive a same-origin hop and are dropped once the
/// last recorded URL is on a different host.
#[test]
fn test_remove_sensitive_headers() {
    use hyper::header::{HeaderValue, ACCEPT, AUTHORIZATION, COOKIE};

    let mut headers = HeaderMap::new();
    headers.insert(ACCEPT, HeaderValue::from_static("*/*"));
    headers.insert(AUTHORIZATION, HeaderValue::from_static("let me in"));
    headers.insert(COOKIE, HeaderValue::from_static("foo=bar"));

    let target = Url::parse("http://initial-domain.com/path").unwrap();
    let mut chain = vec![Url::parse("http://initial-domain.com/new_path").unwrap()];
    let mut expected = headers.clone();

    // Same origin as the last URL: nothing may be removed.
    remove_sensitive_headers(&mut headers, &target, &chain);
    assert_eq!(headers, expected);

    // Last URL now on a different host: credentials must be dropped.
    chain.push(Url::parse("http://new-domain.com/path").unwrap());
    for name in [AUTHORIZATION, COOKIE] {
        expected.remove(name);
    }

    remove_sensitive_headers(&mut headers, &target, &chain);
    assert_eq!(headers, expected);
}