1#![allow(
11 clippy::module_name_repetitions,
12 clippy::struct_excessive_bools,
13 clippy::default_trait_access,
14 clippy::used_underscore_binding
15)]
16use std::{collections::HashSet, sync::Arc, time::Duration};
17
18use http::{
19 StatusCode,
20 header::{HeaderMap, HeaderValue},
21};
22use log::debug;
23use octocrab::Octocrab;
24use regex::RegexSet;
25use reqwest::{header, redirect, tls};
26use reqwest_cookie_store::CookieStoreMutex;
27use secrecy::{ExposeSecret, SecretString};
28use typed_builder::TypedBuilder;
29
30use crate::{
31 Base, BasicAuthCredentials, ErrorKind, Request, Response, Result, Status, Uri,
32 chain::RequestChain,
33 checker::{file::FileChecker, mail::MailChecker, website::WebsiteChecker},
34 filter::Filter,
35 remap::Remaps,
36 types::{DEFAULT_ACCEPTED_STATUS_CODES, redirect_history::RedirectHistory},
37};
38
/// Default for `ClientBuilder`'s `max_redirects`: how many redirects a single
/// request may follow before the redirect policy stops.
pub const DEFAULT_MAX_REDIRECTS: usize = 5;

/// Default for `ClientBuilder`'s `max_retries`, forwarded to the website checker.
pub const DEFAULT_MAX_RETRIES: u64 = 3;

/// Default for `ClientBuilder`'s `retry_wait_time`, in seconds.
pub const DEFAULT_RETRY_WAIT_TIME_SECS: usize = 1;

/// Suggested request timeout, in seconds.
///
/// Note that `ClientBuilder` itself defaults `timeout` to `None`; this value is
/// presumably applied by callers that want a timeout — confirm at use sites.
pub const DEFAULT_TIMEOUT_SECS: usize = 20;
47pub const DEFAULT_USER_AGENT: &str = concat!("lychee/", env!("CARGO_PKG_VERSION"));
49
/// Connection-establishment timeout in seconds, passed to
/// `reqwest::ClientBuilder::connect_timeout`.
const CONNECT_TIMEOUT: u64 = 10;

/// TCP keep-alive interval in seconds, passed to
/// `reqwest::ClientBuilder::tcp_keepalive`.
const TCP_KEEPALIVE: u64 = 60;
58
59#[derive(TypedBuilder, Debug, Clone)]
63#[builder(field_defaults(default, setter(into)))]
64pub struct ClientBuilder {
65 github_token: Option<SecretString>,
74
75 remaps: Option<Remaps>,
90
91 fallback_extensions: Vec<String>,
95
96 #[builder(default = None)]
110 index_files: Option<Vec<String>>,
111
112 includes: Option<RegexSet>,
118
119 excludes: Option<RegexSet>,
122
123 exclude_all_private: bool,
130
131 exclude_private_ips: bool,
158
159 exclude_link_local_ips: bool,
179
180 exclude_loopback_ips: bool,
196
197 include_mail: bool,
199
200 #[builder(default = DEFAULT_MAX_REDIRECTS)]
204 max_redirects: usize,
205
206 #[builder(default = DEFAULT_MAX_RETRIES)]
210 max_retries: u64,
211
212 min_tls_version: Option<tls::Version>,
214
215 #[builder(default_code = "String::from(DEFAULT_USER_AGENT)")]
225 user_agent: String,
226
227 allow_insecure: bool,
237
238 schemes: HashSet<String>,
243
244 custom_headers: HeaderMap,
252
253 #[builder(default = reqwest::Method::GET)]
255 method: reqwest::Method,
256
257 #[builder(default = DEFAULT_ACCEPTED_STATUS_CODES.clone())]
261 accepted: HashSet<StatusCode>,
262
263 timeout: Option<Duration>,
265
266 base: Option<Base>,
271
272 #[builder(default_code = "Duration::from_secs(DEFAULT_RETRY_WAIT_TIME_SECS as u64)")]
285 retry_wait_time: Duration,
286
287 require_https: bool,
293
294 cookie_jar: Option<Arc<CookieStoreMutex>>,
298
299 include_fragments: bool,
301
302 plugin_request_chain: RequestChain,
307}
308
309impl Default for ClientBuilder {
310 #[inline]
311 fn default() -> Self {
312 Self::builder().build()
313 }
314}
315
316impl ClientBuilder {
317 pub fn client(self) -> Result<Client> {
332 let Self {
333 user_agent,
334 custom_headers: mut headers,
335 ..
336 } = self;
337
338 if let Some(prev_user_agent) =
339 headers.insert(header::USER_AGENT, HeaderValue::try_from(&user_agent)?)
340 {
341 debug!(
342 "Found user-agent in headers: {}. Overriding it with {user_agent}.",
343 prev_user_agent.to_str().unwrap_or("�"),
344 );
345 }
346
347 headers.insert(
348 header::TRANSFER_ENCODING,
349 HeaderValue::from_static("chunked"),
350 );
351
352 let redirect_history = RedirectHistory::new();
353
354 let mut builder = reqwest::ClientBuilder::new()
355 .gzip(true)
356 .default_headers(headers)
357 .danger_accept_invalid_certs(self.allow_insecure)
358 .connect_timeout(Duration::from_secs(CONNECT_TIMEOUT))
359 .tcp_keepalive(Duration::from_secs(TCP_KEEPALIVE))
360 .redirect(redirect_policy(
361 redirect_history.clone(),
362 self.max_redirects,
363 ));
364
365 if let Some(cookie_jar) = self.cookie_jar {
366 builder = builder.cookie_provider(cookie_jar);
367 }
368
369 if let Some(min_tls) = self.min_tls_version {
370 builder = builder.min_tls_version(min_tls);
371 }
372
373 let reqwest_client = match self.timeout {
374 Some(t) => builder.timeout(t),
375 None => builder,
376 }
377 .build()
378 .map_err(ErrorKind::BuildRequestClient)?;
379
380 let github_client = match self.github_token.as_ref().map(ExposeSecret::expose_secret) {
381 Some(token) if !token.is_empty() => Some(
382 Octocrab::builder()
383 .personal_token(token.to_string())
384 .build()
385 .map_err(|e: octocrab::Error| ErrorKind::BuildGithubClient(Box::new(e)))?,
388 ),
389 _ => None,
390 };
391
392 let filter = Filter {
393 includes: self.includes.map(Into::into),
394 excludes: self.excludes.map(Into::into),
395 schemes: self.schemes,
396 exclude_private_ips: self.exclude_all_private || self.exclude_private_ips,
399 exclude_link_local_ips: self.exclude_all_private || self.exclude_link_local_ips,
400 exclude_loopback_ips: self.exclude_all_private || self.exclude_loopback_ips,
401 include_mail: self.include_mail,
402 };
403
404 let website_checker = WebsiteChecker::new(
405 self.method,
406 self.retry_wait_time,
407 redirect_history.clone(),
408 self.max_retries,
409 reqwest_client,
410 self.accepted,
411 github_client,
412 self.require_https,
413 self.plugin_request_chain,
414 self.include_fragments,
415 );
416
417 Ok(Client {
418 remaps: self.remaps,
419 filter,
420 email_checker: MailChecker::new(),
421 website_checker,
422 file_checker: FileChecker::new(
423 self.base,
424 self.fallback_extensions,
425 self.index_files,
426 self.include_fragments,
427 ),
428 })
429 }
430}
431
432fn redirect_policy(redirect_history: RedirectHistory, max_redirects: usize) -> redirect::Policy {
435 redirect::Policy::custom(move |attempt| {
436 if attempt.previous().len() > max_redirects {
437 attempt.stop()
438 } else {
439 let redirects = &[attempt.previous(), &[attempt.url().clone()]].concat();
440 redirect_history.record_redirects(redirects);
441 debug!("Following redirect to {}", attempt.url());
442 attempt.follow()
443 }
444 })
445}
446
447#[derive(Debug, Clone)]
452pub struct Client {
453 remaps: Option<Remaps>,
455
456 filter: Filter,
458
459 website_checker: WebsiteChecker,
461
462 file_checker: FileChecker,
464
465 email_checker: MailChecker,
467}
468
469impl Client {
470 #[allow(clippy::missing_panics_doc)]
482 pub async fn check<T, E>(&self, request: T) -> Result<Response>
483 where
484 Request: TryFrom<T, Error = E>,
485 ErrorKind: From<E>,
486 {
487 let Request {
488 ref mut uri,
489 credentials,
490 source,
491 ..
492 } = request.try_into()?;
493
494 self.remap(uri)?;
495
496 if self.is_excluded(uri) {
497 return Ok(Response::new(uri.clone(), Status::Excluded, source));
498 }
499
500 let status = match uri.scheme() {
501 _ if uri.is_tel() => Status::Excluded,
503 _ if uri.is_file() => self.check_file(uri).await,
504 _ if uri.is_mail() => self.check_mail(uri).await,
505 _ => self.check_website(uri, credentials).await?,
506 };
507
508 Ok(Response::new(uri.clone(), status, source))
509 }
510
511 pub async fn check_file(&self, uri: &Uri) -> Status {
513 self.file_checker.check(uri).await
514 }
515
516 pub fn remap(&self, uri: &mut Uri) -> Result<()> {
522 if let Some(ref remaps) = self.remaps {
523 uri.url = remaps.remap(&uri.url)?;
524 }
525 Ok(())
526 }
527
528 #[must_use]
530 pub fn is_excluded(&self, uri: &Uri) -> bool {
531 self.filter.is_excluded(uri)
532 }
533
534 pub async fn check_website(
544 &self,
545 uri: &Uri,
546 credentials: Option<BasicAuthCredentials>,
547 ) -> Result<Status> {
548 self.website_checker.check_website(uri, credentials).await
549 }
550
551 pub async fn check_mail(&self, uri: &Uri) -> Status {
553 self.email_checker.check_mail(uri).await
554 }
555}
556
557pub async fn check<T, E>(request: T) -> Result<Response>
570where
571 Request: TryFrom<T, Error = E>,
572 ErrorKind: From<E>,
573{
574 let client = ClientBuilder::builder().build().client()?;
575 client.check(request).await
576}
577
#[cfg(test)]
mod tests {
    use std::{
        fs::File,
        time::{Duration, Instant},
    };

    use async_trait::async_trait;
    use http::{StatusCode, header::HeaderMap};
    use reqwest::header;
    use tempfile::tempdir;
    use test_utils::get_mock_client_response;
    use test_utils::mock_server;
    use test_utils::redirecting_mock_server;
    use wiremock::{
        Mock,
        matchers::{method, path},
    };

    use super::{Client, ClientBuilder};
    use crate::{
        ErrorKind, Request, Status, Uri,
        chain::{ChainResult, Handler, RequestChain},
    };

    /// Builds a client with every option left at its default.
    fn default_client() -> Client {
        ClientBuilder::builder().build().client().unwrap()
    }

    /// Wraps `url` into a [`Uri`], panicking on malformed input.
    fn parse_uri(url: &str) -> Uri {
        Uri {
            url: url.try_into().unwrap(),
        }
    }

    #[tokio::test]
    async fn test_nonexistent() {
        let server = mock_server!(StatusCode::NOT_FOUND);
        let response = get_mock_client_response!(server.uri()).await;
        assert!(response.status().is_error());
    }

    #[tokio::test]
    async fn test_nonexistent_with_path() {
        let response = get_mock_client_response!("http://127.0.0.1/invalid").await;
        assert!(response.status().is_error());
    }

    #[tokio::test]
    async fn test_github() {
        let response = get_mock_client_response!("https://github.com/lycheeverse/lychee").await;
        assert!(response.status().is_success());
    }

    #[tokio::test]
    async fn test_github_nonexistent_repo() {
        let response =
            get_mock_client_response!("https://github.com/lycheeverse/not-lychee").await;
        assert!(response.status().is_error());
    }

    #[tokio::test]
    async fn test_github_nonexistent_file() {
        let response = get_mock_client_response!(
            "https://github.com/lycheeverse/lychee/blob/master/NON_EXISTENT_FILE.md",
        )
        .await;
        assert!(response.status().is_error());
    }

    #[tokio::test]
    async fn test_youtube() {
        let existing = "https://www.youtube.com/watch?v=NlKuICiT470&list=PLbWDhxwM_45mPVToqaIZNbZeIzFchsKKQ&index=7";
        let response = get_mock_client_response!(existing).await;
        assert!(response.status().is_success());

        let missing = "https://www.youtube.com/watch?v=invalidNlKuICiT470&list=PLbWDhxwM_45mPVToqaIZNbZeIzFchsKKQ&index=7";
        let response = get_mock_client_response!(missing).await;
        assert!(response.status().is_error());
    }

    #[tokio::test]
    async fn test_basic_auth() {
        let mut request: Request = "https://authenticationtest.com/HTTPAuth/"
            .try_into()
            .unwrap();

        // Without credentials the server answers 401.
        let anonymous = get_mock_client_response!(request.clone()).await;
        assert_eq!(anonymous.status().code(), Some(401.try_into().unwrap()));

        request.credentials = Some(crate::BasicAuthCredentials {
            username: "user".into(),
            password: "pass".into(),
        });

        let authenticated = get_mock_client_response!(request).await;
        assert!(matches!(
            authenticated.status(),
            Status::Redirected(StatusCode::OK, _)
        ));
    }

    #[tokio::test]
    async fn test_non_github() {
        let server = mock_server!(StatusCode::OK);
        let response = get_mock_client_response!(server.uri()).await;
        assert!(response.status().is_success());
    }

    #[tokio::test]
    async fn test_invalid_ssl() {
        let expired = "https://expired.badssl.com/";

        let strict = get_mock_client_response!(expired).await;
        assert!(strict.status().is_error());

        let insecure_client = ClientBuilder::builder()
            .allow_insecure(true)
            .build()
            .client()
            .unwrap();
        let lenient = insecure_client.check(expired).await.unwrap();
        assert!(lenient.status().is_success());
    }

    #[tokio::test]
    async fn test_file() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("temp");
        File::create(&file_path).unwrap();
        let uri = format!("file://{}", file_path.to_str().unwrap());

        let response = get_mock_client_response!(uri).await;
        assert!(response.status().is_success());
    }

    #[tokio::test]
    async fn test_custom_headers() {
        let mut headers = HeaderMap::new();
        headers.insert(header::ACCEPT, "text/html".parse().unwrap());

        let client = ClientBuilder::builder()
            .custom_headers(headers)
            .build()
            .client()
            .unwrap();
        let response = client
            .check("https://crates.io/crates/lychee")
            .await
            .unwrap();
        assert!(response.status().is_success());
    }

    #[tokio::test]
    async fn test_exclude_mail_by_default() {
        let client = ClientBuilder::builder()
            .exclude_all_private(true)
            .build()
            .client()
            .unwrap();
        assert!(client.is_excluded(&parse_uri("mailto://mail@example.com")));
    }

    #[tokio::test]
    async fn test_include_mail() {
        // (include_mail, expected exclusion of a mailto link)
        for (include_mail, expect_excluded) in [(false, true), (true, false)] {
            let client = ClientBuilder::builder()
                .include_mail(include_mail)
                .exclude_all_private(true)
                .build()
                .client()
                .unwrap();
            assert_eq!(
                client.is_excluded(&parse_uri("mailto://mail@example.com")),
                expect_excluded
            );
        }
    }

    #[tokio::test]
    async fn test_include_tel() {
        let client = default_client();
        assert!(client.is_excluded(&parse_uri("tel:1234567890")));
    }

    #[tokio::test]
    async fn test_require_https() {
        let response = default_client()
            .check("http://example.com")
            .await
            .unwrap();
        assert!(response.status().is_success());

        let https_only = ClientBuilder::builder()
            .require_https(true)
            .build()
            .client()
            .unwrap();
        let response = https_only.check("http://example.com").await.unwrap();
        assert!(response.status().is_error());
    }

    #[tokio::test]
    async fn test_timeout() {
        // The server must answer more slowly than the client is willing to wait.
        let mock_delay = Duration::from_millis(20);
        let checker_timeout = Duration::from_millis(10);
        assert!(mock_delay > checker_timeout);

        let server = mock_server!(StatusCode::OK, set_delay(mock_delay));

        let client = ClientBuilder::builder()
            .timeout(checker_timeout)
            .build()
            .client()
            .unwrap();

        let response = client.check(server.uri()).await.unwrap();
        assert!(response.status().is_timeout());
    }

    #[tokio::test]
    async fn test_exponential_backoff() {
        let mock_delay = Duration::from_millis(20);
        let checker_timeout = Duration::from_millis(10);
        assert!(mock_delay > checker_timeout);

        let server = mock_server!(StatusCode::OK, set_delay(mock_delay));

        // Hit the server once up front, presumably so first-request start-up
        // cost does not skew the timed run below.
        let warm_up_client = ClientBuilder::builder()
            .max_retries(0_u64)
            .build()
            .client()
            .unwrap();
        let _warm_up = warm_up_client.check(server.uri()).await.unwrap();

        let client = ClientBuilder::builder()
            .timeout(checker_timeout)
            .max_retries(3_u64)
            .retry_wait_time(Duration::from_millis(50))
            .build()
            .client()
            .unwrap();

        let started = Instant::now();
        let response = client.check(server.uri()).await.unwrap();
        let elapsed = started.elapsed();

        assert!(response.status().is_error());
        assert!((350..=550).contains(&elapsed.as_millis()));
    }

    #[tokio::test]
    async fn test_avoid_reqwest_panic() {
        let response = default_client().check("http://\"").await.unwrap();

        assert!(matches!(
            response.status(),
            Status::Unsupported(ErrorKind::BuildRequestClient(_))
        ));
        assert!(response.status().is_unsupported());
    }

    #[tokio::test]
    async fn test_max_redirects() {
        let server = wiremock::MockServer::start().await;

        // Every request to `/redirect` is redirected back to itself.
        let redirect_uri = format!("{}/redirect", server.uri());
        let self_redirect = wiremock::ResponseTemplate::new(StatusCode::PERMANENT_REDIRECT)
            .insert_header("Location", redirect_uri.as_str());

        let max_redirects = 15_usize;
        // The initial request plus one request per followed redirect.
        let expected_requests = 1 + max_redirects as u64;

        Mock::given(method("GET"))
            .and(path("/redirect"))
            .respond_with(move |_: &_| self_redirect.clone())
            .expect(expected_requests)
            .mount(&server)
            .await;

        let response = ClientBuilder::builder()
            .max_redirects(max_redirects)
            .build()
            .client()
            .unwrap()
            .check(redirect_uri)
            .await
            .unwrap();

        assert_eq!(
            response.status(),
            &Status::Error(ErrorKind::RejectedStatusCode(
                StatusCode::PERMANENT_REDIRECT
            ))
        );
    }

    #[tokio::test]
    async fn test_redirects() {
        redirecting_mock_server!(async |redirect_url: Url, ok_url| {
            let response = ClientBuilder::builder()
                .max_redirects(1_usize)
                .build()
                .client()
                .unwrap()
                .check(Uri::from(redirect_url.clone()))
                .await
                .unwrap();

            assert_eq!(
                response.status(),
                &Status::Redirected(StatusCode::OK, vec![redirect_url, ok_url].into())
            );
        })
        .await;
    }

    #[tokio::test]
    async fn test_unsupported_scheme() {
        let unsupported = [
            "ftp://example.com",
            "gopher://example.com",
            "slack://example.com",
        ];

        for link in unsupported {
            let response = default_client().check(link).await.unwrap();
            assert!(response.status().is_unsupported());
        }
    }

    #[tokio::test]
    async fn test_chain() {
        use reqwest::Request;

        /// Handler that short-circuits every request as excluded.
        #[derive(Debug)]
        struct ExampleHandler;

        #[async_trait]
        impl Handler<Request, Status> for ExampleHandler {
            async fn handle(&mut self, _: Request) -> ChainResult<Request, Status> {
                ChainResult::Done(Status::Excluded)
            }
        }

        let chain = RequestChain::new(vec![Box::new(ExampleHandler)]);

        let client = ClientBuilder::builder()
            .plugin_request_chain(chain)
            .build()
            .client()
            .unwrap();

        let response = client.check("http://example.com").await.unwrap();
        assert_eq!(response.status(), &Status::Excluded);
    }
}
959}