1use std::fmt;
8use std::sync::atomic::{AtomicBool, Ordering};
9use std::{error::Error as StdError, sync::Arc};
10
11use crate::header::{AUTHORIZATION, COOKIE, PROXY_AUTHORIZATION, REFERER, WWW_AUTHENTICATE};
12use http::{HeaderMap, HeaderValue};
13use hyper::StatusCode;
14
15use crate::{async_impl, Url};
16use tower_http::follow_redirect::policy::{
17 Action as TowerAction, Attempt as TowerAttempt, Policy as TowerPolicy,
18};
19
/// Controls how (and whether) redirects are followed.
///
/// Build one with [`Policy::limited`], [`Policy::none`] or [`Policy::custom`];
/// the [`Default`] policy is `Policy::limited(10)`.
pub struct Policy {
    // The strategy this policy applies; see `PolicyKind`.
    pub(crate) inner: PolicyKind,
}
32
/// One redirect decision point handed to a [`Policy`].
///
/// It exposes the redirect's status code, its target URL and the URLs already
/// requested. It is consumed by `follow`, `stop` or `error` to produce an
/// [`Action`].
#[derive(Debug)]
pub struct Attempt<'a> {
    // Status code of the redirect response being evaluated.
    status: StatusCode,
    // Target URL of the redirect.
    next: &'a Url,
    // URLs already requested in this redirect chain.
    previous: &'a [Url],
}
41
/// The decision a [`Policy`] makes for one [`Attempt`].
///
/// Obtained from `Attempt::follow`, `Attempt::stop` or `Attempt::error`.
#[derive(Debug)]
pub struct Action {
    // The concrete outcome; read back by `Policy::check`.
    inner: ActionKind,
}
47
48impl Policy {
49 pub fn limited(max: usize) -> Self {
53 Self {
54 inner: PolicyKind::Limit(max),
55 }
56 }
57
58 pub fn none() -> Self {
60 Self {
61 inner: PolicyKind::None,
62 }
63 }
64
65 pub fn custom<T>(policy: T) -> Self
104 where
105 T: Fn(Attempt) -> Action + Send + Sync + 'static,
106 {
107 Self {
108 inner: PolicyKind::Custom(Box::new(policy)),
109 }
110 }
111
112 pub fn redirect(&self, attempt: Attempt) -> Action {
133 match self.inner {
134 PolicyKind::Custom(ref custom) => custom(attempt),
135 PolicyKind::Limit(max) => {
136 if attempt.previous.len() > max {
138 attempt.error(TooManyRedirects)
139 } else {
140 attempt.follow()
141 }
142 }
143 PolicyKind::None => attempt.stop(),
144 }
145 }
146
147 pub(crate) fn check(&self, status: StatusCode, next: &Url, previous: &[Url]) -> ActionKind {
148 self.redirect(Attempt {
149 status,
150 next,
151 previous,
152 })
153 .inner
154 }
155
156 pub(crate) fn is_default(&self) -> bool {
157 matches!(self.inner, PolicyKind::Limit(10))
158 }
159}
160
161impl Default for Policy {
162 fn default() -> Policy {
163 Policy::limited(10)
165 }
166}
167
168impl Attempt<'_> {
169 pub fn status(&self) -> StatusCode {
171 self.status
172 }
173
174 pub fn url(&self) -> &Url {
176 self.next
177 }
178
179 pub fn previous(&self) -> &[Url] {
181 self.previous
182 }
183 pub fn follow(self) -> Action {
185 Action {
186 inner: ActionKind::Follow,
187 }
188 }
189
190 pub fn stop(self) -> Action {
194 Action {
195 inner: ActionKind::Stop,
196 }
197 }
198
199 pub fn error<E: Into<Box<dyn StdError + Send + Sync>>>(self, error: E) -> Action {
203 Action {
204 inner: ActionKind::Error(error.into()),
205 }
206 }
207}
208
/// Internal strategy behind a [`Policy`].
pub(crate) enum PolicyKind {
    /// User-supplied decision function, installed by `Policy::custom`.
    Custom(Box<dyn Fn(Attempt) -> Action + Send + Sync + 'static>),
    /// Follow redirects until the previous-URL list grows longer than this count.
    Limit(usize),
    /// Never follow: every attempt is stopped.
    None,
}
214
215impl fmt::Debug for Policy {
216 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
217 f.debug_tuple("Policy").field(&self.inner).finish()
218 }
219}
220
221impl fmt::Debug for PolicyKind {
222 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
223 match *self {
224 PolicyKind::Custom(..) => f.pad("Custom"),
225 PolicyKind::Limit(max) => f.debug_tuple("Limit").field(&max).finish(),
226 PolicyKind::None => f.pad("None"),
227 }
228 }
229}
230
/// Internal outcome of a policy decision, extracted from an [`Action`] by `Policy::check`.
#[derive(Debug)]
pub(crate) enum ActionKind {
    /// Follow the redirect.
    Follow,
    /// Do not follow; the tower adapter maps this to `TowerAction::Stop`.
    Stop,
    /// Abort the chain with this error; the tower adapter wraps it with `crate::error::redirect`.
    Error(Box<dyn StdError + Send + Sync>),
}
239
240pub(crate) fn remove_sensitive_headers(headers: &mut HeaderMap, next: &Url, previous: &[Url]) {
241 if let Some(previous) = previous.last() {
242 let cross_host = next.host_str() != previous.host_str()
243 || next.port_or_known_default() != previous.port_or_known_default()
244 || next.scheme() != previous.scheme();
245 if cross_host {
246 headers.remove(AUTHORIZATION);
247 headers.remove(COOKIE);
248 headers.remove("cookie2");
249 headers.remove(PROXY_AUTHORIZATION);
250 headers.remove(WWW_AUTHENTICATE);
251 }
252 }
253}
254
/// Error produced by the limiting policy once a redirect chain exceeds its maximum.
#[derive(Debug)]
struct TooManyRedirects;

impl fmt::Display for TooManyRedirects {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "too many redirects")
    }
}

// A leaf error: no `source`, so the default trait methods suffice.
impl StdError for TooManyRedirects {}
265
/// Adapter implementing tower-http's follow-redirect `Policy` trait on top of a
/// reqwest [`Policy`].
///
/// Because of `#[derive(Clone)]`, each clone has its own `urls` list. `policy`
/// and `redirect_enabled` are shared between clones through `Arc`s.
#[derive(Clone)]
pub(crate) struct TowerRedirectPolicy {
    // The policy consulted on every redirect.
    policy: Arc<Policy>,
    // When true, `on_request` sets a `Referer` header derived from the previous URL.
    referer: bool,
    // URLs requested so far in this chain; `redirect` appends the previous URL
    // before consulting `policy`.
    urls: Vec<Url>,
    // When true, a followed redirect to a non-https URL is turned into an error.
    https_only: bool,
    // Shared on/off switch checked at the start of `redirect`. It starts false
    // only for `Policy::none()`. A handle is exposed via `redirect_enabled_ref`,
    // presumably so other code can toggle it. TODO: confirm at call sites.
    redirect_enabled: Arc<AtomicBool>,
}
274
275impl TowerRedirectPolicy {
276 pub(crate) fn new(policy: Policy) -> Self {
277 let enabled = !matches!(policy.inner, PolicyKind::None);
278 Self {
279 policy: Arc::new(policy),
280 referer: false,
281 urls: Vec::new(),
282 https_only: false,
283 redirect_enabled: Arc::new(AtomicBool::new(enabled)),
284 }
285 }
286
287 pub(crate) fn with_referer(&mut self, referer: bool) -> &mut Self {
288 self.referer = referer;
289 self
290 }
291
292 pub(crate) fn with_https_only(&mut self, https_only: bool) -> &mut Self {
293 self.https_only = https_only;
294 self
295 }
296
297 pub(crate) fn redirect_enabled_ref(&self) -> Arc<AtomicBool> {
298 self.redirect_enabled.clone()
299 }
300}
301
302fn make_referer(next: &Url, previous: &Url) -> Option<HeaderValue> {
303 if next.scheme() == "http" && previous.scheme() == "https" {
304 return None;
305 }
306
307 let mut referer = previous.clone();
308 let _ = referer.set_username("");
309 let _ = referer.set_password(None);
310 referer.set_fragment(None);
311 referer.as_str().parse().ok()
312}
313
314impl TowerPolicy<async_impl::body::Body, crate::Error> for TowerRedirectPolicy {
315 fn redirect(&mut self, attempt: &TowerAttempt<'_>) -> Result<TowerAction, crate::Error> {
316 if !self.redirect_enabled.load(Ordering::Relaxed) {
318 return Ok(TowerAction::Stop);
319 }
320
321 let previous_url =
322 Url::parse(&attempt.previous().to_string()).expect("Previous URL must be valid");
323
324 let next_url = match Url::parse(&attempt.location().to_string()) {
325 Ok(url) => url,
326 Err(e) => return Err(crate::error::builder(e)),
327 };
328
329 self.urls.push(previous_url.clone());
330
331 match self.policy.check(attempt.status(), &next_url, &self.urls) {
332 ActionKind::Follow => {
333 if next_url.scheme() != "http" && next_url.scheme() != "https" {
334 return Err(crate::error::url_bad_scheme(next_url));
335 }
336
337 if self.https_only && next_url.scheme() != "https" {
338 return Err(crate::error::redirect(
339 crate::error::url_bad_scheme(next_url.clone()),
340 next_url,
341 ));
342 }
343 Ok(TowerAction::Follow)
344 }
345 ActionKind::Stop => Ok(TowerAction::Stop),
346 ActionKind::Error(e) => Err(crate::error::redirect(e, previous_url)),
347 }
348 }
349
350 fn on_request(&mut self, req: &mut http::Request<async_impl::body::Body>) {
351 if let Ok(next_url) = Url::parse(&req.uri().to_string()) {
352 remove_sensitive_headers(req.headers_mut(), &next_url, &self.urls);
353 if self.referer {
354 if let Some(previous_url) = self.urls.last() {
355 if let Some(v) = make_referer(&next_url, previous_url) {
356 req.headers_mut().insert(REFERER, v);
357 }
358 }
359 }
360 };
361 }
362
363 fn clone_body(&self, body: &async_impl::body::Body) -> Option<async_impl::body::Body> {
365 body.try_clone()
366 }
367}
368
#[test]
fn test_redirect_policy_limit() {
    let policy = Policy::default();
    let next = Url::parse("http://x.y/z").unwrap();

    // Ten previously requested URLs is exactly the default limit: still followed.
    let mut previous: Vec<Url> = (0..10)
        .map(|i| Url::parse(&format!("http://a.b/c/{i}")).unwrap())
        .collect();
    let action = policy.check(StatusCode::FOUND, &next, &previous);
    assert!(matches!(action, ActionKind::Follow), "unexpected {action:?}");

    // An eleventh entry pushes the chain past the limit.
    previous.push(Url::parse("http://a.b.d/e/33").unwrap());
    let action = policy.check(StatusCode::FOUND, &next, &previous);
    assert!(
        matches!(&action, ActionKind::Error(err) if err.is::<TooManyRedirects>()),
        "unexpected {action:?}"
    );
}
389
#[test]
fn test_redirect_policy_limit_to_0() {
    // With a limit of zero, even a chain holding only the original URL is refused.
    let policy = Policy::limited(0);
    let next = Url::parse("http://x.y/z").unwrap();
    let previous = [Url::parse("http://a.b/c").unwrap()];

    let action = policy.check(StatusCode::FOUND, &next, &previous);
    assert!(
        matches!(&action, ActionKind::Error(err) if err.is::<TooManyRedirects>()),
        "unexpected {action:?}"
    );
}
401
#[test]
fn test_redirect_policy_custom() {
    // Stop on host "foo", follow everywhere else.
    let policy = Policy::custom(|attempt| {
        let blocked = attempt.url().host_str() == Some("foo");
        if blocked {
            attempt.stop()
        } else {
            attempt.follow()
        }
    });

    for (target, expect_follow) in [("http://bar/baz", true), ("http://foo/baz", false)] {
        let next = Url::parse(target).unwrap();
        let action = policy.check(StatusCode::FOUND, &next, &[]);
        let as_expected = if expect_follow {
            matches!(action, ActionKind::Follow)
        } else {
            matches!(action, ActionKind::Stop)
        };
        assert!(as_expected, "unexpected {action:?} for {target}");
    }
}
424
#[test]
fn test_remove_sensitive_headers() {
    use hyper::header::{HeaderValue, ACCEPT, AUTHORIZATION, COOKIE, PROXY_AUTHORIZATION};

    let mut headers = HeaderMap::new();
    headers.insert(ACCEPT, HeaderValue::from_static("*/*"));
    headers.insert(AUTHORIZATION, HeaderValue::from_static("let me in"));
    headers.insert(COOKIE, HeaderValue::from_static("foo=bar"));
    headers.insert(PROXY_AUTHORIZATION, HeaderValue::from_static("proxy creds"));
    headers.insert("cookie2", HeaderValue::from_static("baz=qux"));

    // What should survive a cross-origin hop: only the non-sensitive header.
    let mut scrubbed = HeaderMap::new();
    scrubbed.insert(ACCEPT, HeaderValue::from_static("*/*"));

    let next = Url::parse("http://initial-domain.com/path").unwrap();

    // Runs `remove_sensitive_headers` on a fresh copy with the given history.
    let run = |history: &[&str]| {
        let previous: Vec<Url> = history.iter().map(|u| Url::parse(u).unwrap()).collect();
        let mut copy = headers.clone();
        remove_sensitive_headers(&mut copy, &next, &previous);
        copy
    };

    // No history: nothing to compare against, nothing removed.
    assert_eq!(run(&[]), headers);

    // Same scheme, host and port: headers are kept.
    assert_eq!(run(&["http://initial-domain.com/new_path"]), headers);

    // An explicit default port is the same origin as an implicit one.
    assert_eq!(run(&["http://initial-domain.com:80/other"]), headers);

    // Only the last URL matters: an earlier foreign host does not trigger scrubbing.
    assert_eq!(
        run(&["http://new-domain.com/path", "http://initial-domain.com/x"]),
        headers
    );

    // Different host.
    assert_eq!(
        run(&["http://initial-domain.com/new_path", "http://new-domain.com/path"]),
        scrubbed
    );

    // Different port.
    assert_eq!(run(&["http://initial-domain.com:8080/path"]), scrubbed);

    // Different scheme (https -> http downgrade on the same host).
    assert_eq!(run(&["https://initial-domain.com/path"]), scrubbed);
}