1use std::fmt;
8use std::sync::atomic::{AtomicBool, Ordering};
9use std::{error::Error as StdError, sync::Arc};
10
11use crate::header::{AUTHORIZATION, COOKIE, PROXY_AUTHORIZATION, REFERER, WWW_AUTHENTICATE};
12use http::{HeaderMap, HeaderValue};
13use hyper::StatusCode;
14
15use crate::{async_impl, Url};
16use tower_http::follow_redirect::policy::{
17 Action as TowerAction, Attempt as TowerAttempt, Policy as TowerPolicy,
18};
19
/// Decides how redirects are handled.
///
/// A policy either follows redirects up to a limit (`Policy::limited`, also the
/// `Default`), never follows them (`Policy::none`), or delegates every decision to
/// a user closure (`Policy::custom`).
pub struct Policy {
    // Which of the three strategies this policy uses.
    pub(crate) inner: PolicyKind,
}
32
/// One redirect that a [`Policy`] has to decide on.
///
/// A custom policy closure receives this value and turns it into an [`Action`]
/// by calling `follow`, `stop`, or `error` on it.
#[derive(Debug)]
pub struct Attempt<'a> {
    // Status code of the redirect response.
    status: StatusCode,
    // URL the redirect points to (the one that would be requested next).
    next: &'a Url,
    // URLs already visited in this redirect chain.
    previous: &'a [Url],
}
41
/// The decision a [`Policy`] made for an [`Attempt`].
///
/// Only obtainable through `Attempt::follow`, `Attempt::stop`, or `Attempt::error`.
#[derive(Debug)]
pub struct Action {
    // Crate-internal form of the decision.
    inner: ActionKind,
}
47
48impl Policy {
49 pub fn limited(max: usize) -> Self {
53 Self {
54 inner: PolicyKind::Limit(max),
55 }
56 }
57
58 pub fn none() -> Self {
60 Self {
61 inner: PolicyKind::None,
62 }
63 }
64
65 pub fn custom<T>(policy: T) -> Self
104 where
105 T: Fn(Attempt) -> Action + Send + Sync + 'static,
106 {
107 Self {
108 inner: PolicyKind::Custom(Box::new(policy)),
109 }
110 }
111
112 pub fn redirect(&self, attempt: Attempt) -> Action {
133 match self.inner {
134 PolicyKind::Custom(ref custom) => custom(attempt),
135 PolicyKind::Limit(max) => {
136 if attempt.previous.len() > max {
138 attempt.error(TooManyRedirects)
139 } else {
140 attempt.follow()
141 }
142 }
143 PolicyKind::None => attempt.stop(),
144 }
145 }
146
147 pub(crate) fn check(&self, status: StatusCode, next: &Url, previous: &[Url]) -> ActionKind {
148 self.redirect(Attempt {
149 status,
150 next,
151 previous,
152 })
153 .inner
154 }
155
156 pub(crate) fn is_default(&self) -> bool {
157 matches!(self.inner, PolicyKind::Limit(10))
158 }
159}
160
161impl Default for Policy {
162 fn default() -> Policy {
163 Policy::limited(10)
165 }
166}
167
168impl Attempt<'_> {
169 pub fn status(&self) -> StatusCode {
171 self.status
172 }
173
174 pub fn url(&self) -> &Url {
176 self.next
177 }
178
179 pub fn previous(&self) -> &[Url] {
181 self.previous
182 }
183 pub fn follow(self) -> Action {
185 Action {
186 inner: ActionKind::Follow,
187 }
188 }
189
190 pub fn stop(self) -> Action {
194 Action {
195 inner: ActionKind::Stop,
196 }
197 }
198
199 pub fn error<E: Into<Box<dyn StdError + Send + Sync>>>(self, error: E) -> Action {
203 Action {
204 inner: ActionKind::Error(error.into()),
205 }
206 }
207}
208
/// Crate-internal representation of the strategy a [`Policy`] uses.
pub(crate) enum PolicyKind {
    /// Hand every decision to a user-supplied closure.
    Custom(Box<dyn Fn(Attempt) -> Action + Send + Sync + 'static>),
    /// Follow while the list of previously visited URLs is not longer than this value.
    Limit(usize),
    /// Never follow redirects.
    None,
}
214
215impl fmt::Debug for Policy {
216 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
217 f.debug_tuple("Policy").field(&self.inner).finish()
218 }
219}
220
221impl fmt::Debug for PolicyKind {
222 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
223 match *self {
224 PolicyKind::Custom(..) => f.pad("Custom"),
225 PolicyKind::Limit(max) => f.debug_tuple("Limit").field(&max).finish(),
226 PolicyKind::None => f.pad("None"),
227 }
228 }
229}
230
/// Crate-internal form of an [`Action`]: the decision taken for one redirect.
#[derive(Debug)]
pub(crate) enum ActionKind {
    /// Follow the redirect.
    Follow,
    /// Do not follow; `TowerRedirectPolicy` maps this to `TowerAction::Stop`.
    Stop,
    /// Abort the redirect chain with the contained error.
    Error(Box<dyn StdError + Send + Sync>),
}
239
240pub(crate) fn remove_sensitive_headers(headers: &mut HeaderMap, next: &Url, previous: &[Url]) {
241 if let Some(previous) = previous.last() {
242 let cross_host = next.host_str() != previous.host_str()
243 || next.port_or_known_default() != previous.port_or_known_default();
244 if cross_host {
245 headers.remove(AUTHORIZATION);
246 headers.remove(COOKIE);
247 headers.remove("cookie2");
248 headers.remove(PROXY_AUTHORIZATION);
249 headers.remove(WWW_AUTHENTICATE);
250 }
251 }
252}
253
/// Error produced by a limited `Policy` once a redirect chain exceeds its maximum.
#[derive(Debug)]
struct TooManyRedirects;

impl fmt::Display for TooManyRedirects {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "too many redirects")
    }
}

// No underlying source error: the message is the whole story.
impl StdError for TooManyRedirects {}
264
/// Adapter that lets a [`Policy`] drive tower-http's redirect-following machinery.
#[derive(Clone)]
pub(crate) struct TowerRedirectPolicy {
    // The user's policy; behind an `Arc` so clones of this adapter share it.
    policy: Arc<Policy>,
    // When true, `on_request` may set a `Referer` header (see `make_referer`).
    referer: bool,
    // URLs visited so far in this redirect chain, in the order they were pushed.
    urls: Vec<Url>,
    // When true, following a redirect to a non-`https` URL is an error.
    https_only: bool,
    // Shared on/off switch; while false, every redirect is stopped.
    redirect_enabled: Arc<AtomicBool>,
}
273
274impl TowerRedirectPolicy {
275 pub(crate) fn new(policy: Policy) -> Self {
276 let enabled = !matches!(policy.inner, PolicyKind::None);
277 Self {
278 policy: Arc::new(policy),
279 referer: false,
280 urls: Vec::new(),
281 https_only: false,
282 redirect_enabled: Arc::new(AtomicBool::new(enabled)),
283 }
284 }
285
286 pub(crate) fn with_referer(&mut self, referer: bool) -> &mut Self {
287 self.referer = referer;
288 self
289 }
290
291 pub(crate) fn with_https_only(&mut self, https_only: bool) -> &mut Self {
292 self.https_only = https_only;
293 self
294 }
295
296 pub(crate) fn redirect_enabled_ref(&self) -> Arc<AtomicBool> {
297 self.redirect_enabled.clone()
298 }
299}
300
301fn make_referer(next: &Url, previous: &Url) -> Option<HeaderValue> {
302 if next.scheme() == "http" && previous.scheme() == "https" {
303 return None;
304 }
305
306 let mut referer = previous.clone();
307 let _ = referer.set_username("");
308 let _ = referer.set_password(None);
309 referer.set_fragment(None);
310 referer.as_str().parse().ok()
311}
312
313impl TowerPolicy<async_impl::body::Body, crate::Error> for TowerRedirectPolicy {
314 fn redirect(&mut self, attempt: &TowerAttempt<'_>) -> Result<TowerAction, crate::Error> {
315 if !self.redirect_enabled.load(Ordering::Relaxed) {
317 return Ok(TowerAction::Stop);
318 }
319
320 let previous_url =
321 Url::parse(&attempt.previous().to_string()).expect("Previous URL must be valid");
322
323 let next_url = match Url::parse(&attempt.location().to_string()) {
324 Ok(url) => url,
325 Err(e) => return Err(crate::error::builder(e)),
326 };
327
328 self.urls.push(previous_url.clone());
329
330 match self.policy.check(attempt.status(), &next_url, &self.urls) {
331 ActionKind::Follow => {
332 if next_url.scheme() != "http" && next_url.scheme() != "https" {
333 return Err(crate::error::url_bad_scheme(next_url));
334 }
335
336 if self.https_only && next_url.scheme() != "https" {
337 return Err(crate::error::redirect(
338 crate::error::url_bad_scheme(next_url.clone()),
339 next_url,
340 ));
341 }
342 Ok(TowerAction::Follow)
343 }
344 ActionKind::Stop => Ok(TowerAction::Stop),
345 ActionKind::Error(e) => Err(crate::error::redirect(e, previous_url)),
346 }
347 }
348
349 fn on_request(&mut self, req: &mut http::Request<async_impl::body::Body>) {
350 if let Ok(next_url) = Url::parse(&req.uri().to_string()) {
351 remove_sensitive_headers(req.headers_mut(), &next_url, &self.urls);
352 if self.referer {
353 if let Some(previous_url) = self.urls.last() {
354 if let Some(v) = make_referer(&next_url, previous_url) {
355 req.headers_mut().insert(REFERER, v);
356 }
357 }
358 }
359 };
360 }
361
362 fn clone_body(&self, body: &async_impl::body::Body) -> Option<async_impl::body::Body> {
364 body.try_clone()
365 }
366}
367
#[test]
fn test_redirect_policy_limit() {
    let policy = Policy::default();
    let next = Url::parse("http://x.y/z").unwrap();
    // Ten URLs already visited: exactly at the default limit, still allowed.
    let mut previous: Vec<Url> = (0..10)
        .map(|i| Url::parse(&format!("http://a.b/c/{i}")).unwrap())
        .collect();

    let at_limit = policy.check(StatusCode::FOUND, &next, &previous);
    assert!(
        matches!(at_limit, ActionKind::Follow),
        "unexpected {at_limit:?}"
    );

    // One more visited URL pushes the chain past the limit.
    previous.push(Url::parse("http://a.b.d/e/33").unwrap());

    let past_limit = policy.check(StatusCode::FOUND, &next, &previous);
    assert!(
        matches!(&past_limit, ActionKind::Error(err) if err.is::<TooManyRedirects>()),
        "unexpected {past_limit:?}"
    );
}
388
#[test]
fn test_redirect_policy_limit_to_0() {
    let policy = Policy::limited(0);
    let next = Url::parse("http://x.y/z").unwrap();
    // With a zero limit, a single visited URL already exceeds it.
    let previous = [Url::parse("http://a.b/c").unwrap()];

    let action = policy.check(StatusCode::FOUND, &next, &previous);
    assert!(
        matches!(&action, ActionKind::Error(err) if err.is::<TooManyRedirects>()),
        "unexpected {action:?}"
    );
}
400
#[test]
fn test_redirect_policy_custom() {
    // Stop redirects that target host `foo`; follow everything else.
    let policy = Policy::custom(|attempt| {
        let targets_foo = attempt.url().host_str() == Some("foo");
        if targets_foo {
            attempt.stop()
        } else {
            attempt.follow()
        }
    });

    let to_bar = Url::parse("http://bar/baz").unwrap();
    let action = policy.check(StatusCode::FOUND, &to_bar, &[]);
    assert!(matches!(action, ActionKind::Follow), "unexpected {action:?}");

    let to_foo = Url::parse("http://foo/baz").unwrap();
    let action = policy.check(StatusCode::FOUND, &to_foo, &[]);
    assert!(matches!(action, ActionKind::Stop), "unexpected {action:?}");
}
423
#[test]
fn test_remove_sensitive_headers() {
    use hyper::header::{HeaderValue, ACCEPT, AUTHORIZATION, COOKIE};

    let accept = HeaderValue::from_static("*/*");
    let mut headers = HeaderMap::new();
    headers.insert(ACCEPT, accept.clone());
    headers.insert(AUTHORIZATION, HeaderValue::from_static("let me in"));
    headers.insert(COOKIE, HeaderValue::from_static("foo=bar"));
    let untouched = headers.clone();

    let next = Url::parse("http://initial-domain.com/path").unwrap();
    let mut visited = vec![Url::parse("http://initial-domain.com/new_path").unwrap()];

    // Same host and port as the last visited URL: nothing is stripped.
    remove_sensitive_headers(&mut headers, &next, &visited);
    assert_eq!(headers, untouched);

    // The last visited URL is on another host: only `Accept` survives.
    visited.push(Url::parse("http://new-domain.com/path").unwrap());
    let mut only_accept = HeaderMap::new();
    only_accept.insert(ACCEPT, accept);

    remove_sensitive_headers(&mut headers, &next, &visited);
    assert_eq!(headers, only_accept);
}