1#![allow(
11 clippy::module_name_repetitions,
12 clippy::struct_excessive_bools,
13 clippy::default_trait_access,
14 clippy::used_underscore_binding
15)]
16use crate::remap::Remap;
17use std::{collections::HashSet, sync::Arc, time::Duration};
18
19use http::{
20 StatusCode,
21 header::{HeaderMap, HeaderValue},
22};
23use log::debug;
24use octocrab::Octocrab;
25use regex::RegexSet;
26use reqwest::{header, redirect, tls};
27use reqwest_cookie_store::CookieStoreMutex;
28use secrecy::{ExposeSecret, SecretString};
29use typed_builder::TypedBuilder;
30
31use crate::{
32 BaseInfo, BasicAuthCredentials, ErrorKind, Request, Response, Result, Status, Uri,
33 chain::RequestChain,
34 checker::{file::FileChecker, mail::MailChecker, website::WebsiteChecker},
35 filter::Filter,
36 ratelimit::{ClientMap, HostConfigs, HostKey, HostPool, RateLimitConfig},
37 remap::Remaps,
38 types::{DEFAULT_ACCEPTED_STATUS_CODES, Redirects, redirect_history::RedirectHistory},
39};
40
/// Default maximum number of redirects followed per request
/// (builder default for `ClientBuilder::max_redirects`).
pub const DEFAULT_MAX_REDIRECTS: usize = 10;
/// Default maximum number of retries (builder default for `ClientBuilder::max_retries`).
pub const DEFAULT_MAX_RETRIES: u64 = 3;
/// Default retry wait time in seconds (builder default for `ClientBuilder::retry_wait_time`).
pub const DEFAULT_RETRY_WAIT_TIME_SECS: u64 = 1;
/// Default timeout in seconds.
///
/// NOTE(review): not referenced in this file — the builder's own `timeout`
/// defaults to `None`; presumably consumed by callers (e.g. a CLI). Confirm.
pub const DEFAULT_TIMEOUT_SECS: u64 = 20;
/// Default `User-Agent` header value: `lychee/<crate version>`.
pub const DEFAULT_USER_AGENT: &str = concat!("lychee/", env!("CARGO_PKG_VERSION"));

/// Connection-establishment timeout applied to every reqwest client built here.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
/// TCP keep-alive interval applied to every reqwest client built here.
const TCP_KEEPALIVE: Duration = Duration::from_secs(60);
60
/// Switches controlling which kinds of URL fragments get validated.
///
/// Both switches are off by default.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct FragmentCheckerOptions {
    /// Whether anchor fragments are checked.
    pub check_anchor_fragments: bool,
    /// Whether text fragments are checked.
    pub check_text_fragments: bool,
}

impl FragmentCheckerOptions {
    /// Reports whether at least one kind of fragment check is switched on.
    #[must_use]
    pub const fn any_enabled(self) -> bool {
        let Self {
            check_anchor_fragments: anchors,
            check_text_fragments: text,
        } = self;
        anchors || text
    }
}
77
78#[derive(TypedBuilder, Debug, Clone)]
82#[builder(field_defaults(default, setter(into)))]
83pub struct ClientBuilder {
84 github_token: Option<SecretString>,
93
94 remaps: Option<Remaps>,
109
110 fallback_extensions: Vec<String>,
114
115 #[builder(default = None)]
129 index_files: Option<Vec<String>>,
130
131 includes: Option<RegexSet>,
137
138 excludes: Option<RegexSet>,
141
142 exclude_all_private: bool,
149
150 exclude_private_ips: bool,
177
178 exclude_link_local_ips: bool,
198
199 exclude_loopback_ips: bool,
215
216 include_mail: bool,
218
219 #[builder(default = DEFAULT_MAX_REDIRECTS)]
223 max_redirects: usize,
224
225 #[builder(default = DEFAULT_MAX_RETRIES)]
229 max_retries: u64,
230
231 min_tls_version: Option<tls::Version>,
233
234 #[builder(default_code = "String::from(DEFAULT_USER_AGENT)")]
244 user_agent: String,
245
246 allow_insecure: bool,
256
257 schemes: HashSet<String>,
262
263 custom_headers: HeaderMap,
271
272 #[builder(default = reqwest::Method::GET)]
274 method: reqwest::Method,
275
276 #[builder(default = DEFAULT_ACCEPTED_STATUS_CODES.clone())]
280 accepted: HashSet<StatusCode>,
281
282 timeout: Option<Duration>,
284
285 base: BaseInfo,
290
291 #[builder(default_code = "Duration::from_secs(DEFAULT_RETRY_WAIT_TIME_SECS as u64)")]
304 retry_wait_time: Duration,
305
306 require_https: bool,
312
313 cookie_jar: Option<Arc<CookieStoreMutex>>,
317
318 fragment_checker_options: FragmentCheckerOptions,
320
321 include_wikilinks: bool,
324
325 plugin_request_chain: RequestChain,
330
331 rate_limit_config: RateLimitConfig,
333
334 hosts: HostConfigs,
336}
337
338impl Default for ClientBuilder {
339 #[inline]
340 fn default() -> Self {
341 Self::builder().build()
342 }
343}
344
345impl ClientBuilder {
346 pub fn client(self) -> Result<Client> {
361 let redirect_history = RedirectHistory::new();
362 let reqwest_client = self
363 .build_client(redirect_history.clone())?
364 .build()
365 .map_err(ErrorKind::BuildRequestClient)?;
366
367 let client_map = self.build_host_clients(&redirect_history)?;
368
369 let host_pool = HostPool::new(
370 self.rate_limit_config,
371 self.hosts,
372 reqwest_client,
373 client_map,
374 );
375
376 let github_client = match self.github_token.as_ref().map(ExposeSecret::expose_secret) {
377 Some(token) if !token.is_empty() => Some(
378 Octocrab::builder()
379 .personal_token(token.to_string())
380 .build()
381 .map_err(|e: octocrab::Error| ErrorKind::BuildGithubClient(Box::new(e)))?,
384 ),
385 _ => None,
386 };
387
388 let filter = Filter {
389 includes: self.includes.map(Into::into),
390 excludes: self.excludes.map(Into::into),
391 schemes: self.schemes,
392 exclude_private_ips: self.exclude_all_private || self.exclude_private_ips,
395 exclude_link_local_ips: self.exclude_all_private || self.exclude_link_local_ips,
396 exclude_loopback_ips: self.exclude_all_private || self.exclude_loopback_ips,
397 include_mail: self.include_mail,
398 };
399
400 let website_checker = WebsiteChecker::new(
401 self.method,
402 self.retry_wait_time,
403 redirect_history.clone(),
404 self.max_retries,
405 self.accepted,
406 github_client,
407 self.require_https,
408 self.plugin_request_chain,
409 self.fragment_checker_options,
410 Arc::new(host_pool),
411 );
412
413 Ok(Client {
414 remaps: self.remaps,
415 filter,
416 email_checker: MailChecker::new(self.timeout),
417 website_checker,
418 file_checker: FileChecker::new(
419 &self.base,
420 self.fallback_extensions,
421 self.index_files,
422 self.fragment_checker_options,
423 self.include_wikilinks,
424 )?,
425 })
426 }
427
428 fn build_host_clients(&self, redirect_history: &RedirectHistory) -> Result<ClientMap> {
430 self.hosts
431 .iter()
432 .map(|(host, config)| {
433 let mut headers = self.default_headers()?;
434 headers.extend(config.headers.clone());
435 let client = self
436 .build_client(redirect_history.clone())?
437 .default_headers(headers)
438 .build()
439 .map_err(ErrorKind::BuildRequestClient)?;
440 Ok((HostKey::from(host.as_str()), client))
441 })
442 .collect()
443 }
444
445 fn build_client(&self, redirect_history: RedirectHistory) -> Result<reqwest::ClientBuilder> {
447 let mut builder = reqwest::ClientBuilder::new()
448 .gzip(true)
449 .default_headers(self.default_headers()?)
450 .danger_accept_invalid_certs(self.allow_insecure)
451 .connect_timeout(CONNECT_TIMEOUT)
452 .tcp_keepalive(TCP_KEEPALIVE)
453 .redirect(redirect_policy(redirect_history, self.max_redirects));
454
455 if let Some(cookie_jar) = self.cookie_jar.clone() {
456 builder = builder.cookie_provider(cookie_jar);
457 }
458
459 if let Some(min_tls) = self.min_tls_version {
460 builder = builder.min_tls_version(min_tls);
461 }
462
463 if let Some(timeout) = self.timeout {
464 builder = builder.timeout(timeout);
465 }
466
467 Ok(builder)
468 }
469
470 fn default_headers(&self) -> Result<HeaderMap> {
471 let user_agent = self.user_agent.clone();
472 let mut headers = self.custom_headers.clone();
473
474 if let Some(prev_user_agent) =
475 headers.insert(header::USER_AGENT, HeaderValue::try_from(&user_agent)?)
476 {
477 debug!(
478 "Found user-agent in headers: {}. Overriding it with {user_agent}.",
479 prev_user_agent.to_str().unwrap_or("�"),
480 );
481 }
482
483 headers.insert(
484 header::TRANSFER_ENCODING,
485 HeaderValue::from_static("chunked"),
486 );
487
488 Ok(headers)
489 }
490}
491
492fn redirect_policy(redirect_history: RedirectHistory, max_redirects: usize) -> redirect::Policy {
495 redirect::Policy::custom(move |attempt| {
496 if attempt.previous().len() > max_redirects {
497 attempt.stop()
498 } else {
499 redirect_history.record_redirects(&attempt);
500 debug!("Following redirect to {}", attempt.url());
501 attempt.follow()
502 }
503 })
504}
505
/// Link checker that routes each request to the website, file, or mail checker.
///
/// Created via [`ClientBuilder::client`].
#[derive(Debug, Clone)]
pub struct Client {
    /// Optional URI rewrite rules, applied before filtering and checking.
    remaps: Option<Remaps>,

    /// Decides which URIs are excluded from checking.
    filter: Filter,

    /// Checks every URI that is neither a file nor a mail URI.
    website_checker: WebsiteChecker,

    /// Checks file URIs.
    file_checker: FileChecker,

    /// Checks mail URIs.
    email_checker: MailChecker,
}
527
528impl Client {
529 #[must_use]
531 pub fn host_pool(&self) -> Arc<HostPool> {
532 self.website_checker.host_pool()
533 }
534
535 #[allow(clippy::missing_panics_doc)]
547 pub async fn check<T, E>(&self, request: T) -> Result<Response>
548 where
549 Request: TryFrom<T, Error = E>,
550 ErrorKind: From<E>,
551 {
552 let Request {
553 mut uri,
554 credentials,
555 source,
556 span,
557 ..
558 } = request.try_into()?;
559
560 let start = std::time::Instant::now(); let remap = self.remap(&mut uri)?.inspect(|r| debug!("Remapping {r}"));
562
563 let (status, redirects) = match uri.scheme() {
564 _ if self.is_excluded(&uri) => (Status::Excluded, None),
565 _ if uri.is_tel() => (Status::Excluded, None), _ if uri.is_file() => (self.check_file(&uri).await, None),
567 _ if uri.is_mail() => (self.check_mail(&uri).await, None),
568 _ => self.check_website(&uri, credentials).await,
569 };
570
571 Ok(Response::new(
572 uri,
573 status,
574 redirects,
575 remap,
576 source.into(),
577 span,
578 Some(start.elapsed()),
579 ))
580 }
581
582 pub async fn check_file(&self, uri: &Uri) -> Status {
584 self.file_checker.check(uri).await
585 }
586
587 pub fn remap(&self, uri: &mut Uri) -> Result<Option<Remap>> {
594 match self.remaps {
595 Some(ref remaps) => {
596 let remapped = remaps.remap(uri)?;
597 if let Some(remapped) = &remapped {
598 *uri = remapped.new.clone();
599 }
600
601 Ok(remapped)
602 }
603 None => Ok(None),
604 }
605 }
606
607 #[must_use]
609 pub fn is_excluded(&self, uri: &Uri) -> bool {
610 self.filter.is_excluded(uri)
611 }
612
613 pub async fn check_website(
623 &self,
624 uri: &Uri,
625 credentials: Option<BasicAuthCredentials>,
626 ) -> (Status, Option<Redirects>) {
627 self.website_checker.check_website(uri, credentials).await
628 }
629
630 pub async fn check_mail(&self, uri: &Uri) -> Status {
632 self.email_checker.check_mail(uri).await
633 }
634}
635
636pub async fn check<T, E>(request: T) -> Result<Response>
649where
650 Request: TryFrom<T, Error = E>,
651 ErrorKind: From<E>,
652{
653 let client = ClientBuilder::builder().build().client()?;
654 client.check(request).await
655}
656
#[cfg(test)]
mod tests {
    //! Tests for `ClientBuilder` / `Client`.
    //!
    //! NOTE(review): many of these talk to real external hosts (github.com,
    //! youtube.com, badssl.com, crates.io, rust-lang.org,
    //! authenticationtest.com, example.com) and need network access; the
    //! timing-based tests (`test_timeout`, `test_exponential_backoff`) may be
    //! sensitive to slow CI machines.

    use std::{
        fs::File,
        time::{Duration, Instant},
    };

    use async_trait::async_trait;
    use http::{StatusCode, header::HeaderMap};
    use reqwest::header;
    use tempfile::tempdir;
    use test_utils::get_mock_client_response;
    use test_utils::mock_server;
    use test_utils::redirecting_mock_server;
    use wiremock::{
        Mock,
        matchers::{method, path},
    };

    use super::ClientBuilder;
    use crate::{
        ErrorKind, Redirect, Redirects, Request, Status, Uri,
        chain::{ChainResult, Handler, RequestChain},
        remap::{Remap, Remaps},
    };

    /// A 404 from a local mock server is reported as an error.
    #[tokio::test]
    async fn test_nonexistent() {
        let mock_server = mock_server!(StatusCode::NOT_FOUND);
        let res = get_mock_client_response!(mock_server.uri()).await;

        assert!(res.status().is_error());
    }

    /// Assumes nothing on 127.0.0.1 (default HTTP port) successfully serves
    /// `/invalid` — TODO confirm on CI hosts.
    #[tokio::test]
    async fn test_nonexistent_with_path() {
        let res = get_mock_client_response!("http://127.0.0.1/invalid").await;
        assert!(res.status().is_error());
    }

    /// An existing public GitHub repository is reported as success.
    #[tokio::test]
    async fn test_github() {
        let res = get_mock_client_response!("https://github.com/lycheeverse/lychee").await;
        assert!(res.status().is_success());
    }

    /// A repository that does not exist is reported as an error.
    #[tokio::test]
    async fn test_github_nonexistent_repo() {
        let res = get_mock_client_response!("https://github.com/lycheeverse/not-lychee").await;
        assert!(res.status().is_error());
    }

    /// A missing file inside an existing repository is reported as an error.
    #[tokio::test]
    async fn test_github_nonexistent_file() {
        let res = get_mock_client_response!(
            "https://github.com/lycheeverse/lychee/blob/master/NON_EXISTENT_FILE.md",
        )
        .await;
        assert!(res.status().is_error());
    }

    /// A valid video link succeeds; the same link with a mangled `v=` id fails.
    #[tokio::test]
    async fn test_youtube() {
        let res = get_mock_client_response!("https://www.youtube.com/watch?v=NlKuICiT470&list=PLbWDhxwM_45mPVToqaIZNbZeIzFchsKKQ&index=7").await;
        assert!(res.status().is_success());

        let res = get_mock_client_response!("https://www.youtube.com/watch?v=invalidNlKuICiT470&list=PLbWDhxwM_45mPVToqaIZNbZeIzFchsKKQ&index=7").await;
        assert!(res.status().is_error());
    }

    /// Without credentials the endpoint answers 401. With basic-auth
    /// credentials attached to the request, the same URL succeeds.
    #[tokio::test]
    async fn test_basic_auth() {
        let mut r: Request = "https://authenticationtest.com/HTTPAuth/"
            .try_into()
            .unwrap();

        let res = get_mock_client_response!(r.clone()).await;
        assert_eq!(res.status().code(), Some(401.try_into().unwrap()));

        r.credentials = Some(crate::BasicAuthCredentials {
            username: "user".into(),
            password: "pass".into(),
        });

        let res = get_mock_client_response!(r).await;
        assert!(res.status().is_success());
    }

    /// A 200 from a local mock server is reported as success.
    #[tokio::test]
    async fn test_non_github() {
        let mock_server = mock_server!(StatusCode::OK);
        let res = get_mock_client_response!(mock_server.uri()).await;

        assert!(res.status().is_success());
    }

    /// An expired certificate fails by default and passes once
    /// `allow_insecure(true)` is set.
    #[tokio::test]
    async fn test_invalid_ssl() {
        let res = get_mock_client_response!("https://expired.badssl.com/").await;

        assert!(res.status().is_error());

        let res = ClientBuilder::builder()
            .allow_insecure(true)
            .build()
            .client()
            .unwrap()
            .check("https://expired.badssl.com/")
            .await
            .unwrap();
        assert!(res.status().is_success());
    }

    /// An existing file addressed by a `file://` URI is reported as success.
    #[tokio::test]
    async fn test_file() {
        let dir = tempdir().unwrap();
        let file = dir.path().join("temp");
        File::create(file).unwrap();
        let uri = format!("file://{}", dir.path().join("temp").to_str().unwrap());

        let res = get_mock_client_response!(uri).await;
        assert!(res.status().is_success());
    }

    /// A client configured with a custom `Accept: text/html` header can still
    /// check a real page successfully (only success is asserted).
    #[tokio::test]
    async fn test_custom_headers() {
        let mut custom = HeaderMap::new();
        custom.insert(header::ACCEPT, "text/html".parse().unwrap());
        let res = ClientBuilder::builder()
            .custom_headers(custom)
            .build()
            .client()
            .unwrap()
            .check("https://crates.io/crates/lychee")
            .await
            .unwrap();
        assert!(res.status().is_success());
    }

    /// With `include_mail` left at its default (`false`), mail URIs are excluded.
    #[tokio::test]
    async fn test_exclude_mail_by_default() {
        let client = ClientBuilder::builder()
            .exclude_all_private(true)
            .build()
            .client()
            .unwrap();
        assert!(client.is_excluded(&Uri {
            url: "mailto://mail@example.com".try_into().unwrap()
        }));
    }

    /// `include_mail(false)` excludes mail URIs; `include_mail(true)` keeps them.
    #[tokio::test]
    async fn test_include_mail() {
        let client = ClientBuilder::builder()
            .include_mail(false)
            .exclude_all_private(true)
            .build()
            .client()
            .unwrap();
        assert!(client.is_excluded(&Uri {
            url: "mailto://mail@example.com".try_into().unwrap()
        }));

        let client = ClientBuilder::builder()
            .include_mail(true)
            .exclude_all_private(true)
            .build()
            .client()
            .unwrap();
        assert!(!client.is_excluded(&Uri {
            url: "mailto://mail@example.com".try_into().unwrap()
        }));
    }

    /// Despite its name, this asserts that `tel:` URIs are excluded by the
    /// default filter.
    #[tokio::test]
    async fn test_include_tel() {
        let client = ClientBuilder::builder().build().client().unwrap();
        assert!(client.is_excluded(&Uri {
            url: "tel:1234567890".try_into().unwrap()
        }));
    }

    /// A plain `http://` link passes by default but fails with `require_https(true)`.
    #[tokio::test]
    async fn test_require_https() {
        let client = ClientBuilder::builder().build().client().unwrap();
        let res = client.check("http://rust-lang.org/").await.unwrap();
        assert!(res.status().is_success());

        let client = ClientBuilder::builder()
            .require_https(true)
            .build()
            .client()
            .unwrap();
        let res = client.check("http://rust-lang.org/").await.unwrap();
        assert!(res.status().is_error());
    }

    /// A server slower than the configured timeout yields a timeout status;
    /// retries are disabled so only one attempt is made.
    #[tokio::test]
    async fn test_timeout() {
        let mock_delay = Duration::from_millis(20);
        let checker_timeout = Duration::from_millis(10);
        assert!(mock_delay > checker_timeout);

        let mock_server = mock_server!(StatusCode::OK, set_delay(mock_delay));

        let client = ClientBuilder::builder()
            .timeout(checker_timeout)
            .max_retries(0u64)
            .build()
            .client()
            .unwrap();

        let res = client.check(mock_server.uri()).await.unwrap();
        assert!(res.status().is_timeout());
    }

    /// With 3 retries and a 50 ms base wait, the failing check must take
    /// between 350 and 550 ms.
    #[tokio::test]
    async fn test_exponential_backoff() {
        let mock_delay = Duration::from_millis(20);
        let checker_timeout = Duration::from_millis(10);
        assert!(mock_delay > checker_timeout);

        let mock_server = mock_server!(StatusCode::OK, set_delay(mock_delay));

        // Untimed warm-up request against the same server; presumably keeps
        // one-time initialization cost out of the measurement below — confirm.
        let warm_up_client = ClientBuilder::builder()
            .max_retries(0_u64)
            .build()
            .client()
            .unwrap();
        let _res = warm_up_client.check(mock_server.uri()).await.unwrap();

        let client = ClientBuilder::builder()
            .timeout(checker_timeout)
            .max_retries(3_u64)
            .retry_wait_time(Duration::from_millis(50))
            .build()
            .client()
            .unwrap();

        let start = Instant::now();
        let res = client.check(mock_server.uri()).await.unwrap();
        let end = start.elapsed();

        assert!(res.status().is_error());

        // NOTE(review): the window fits waits of 50 + 100 + 200 ms plus four
        // ~10 ms timeouts, i.e. it assumes the wait doubles per retry in the
        // website checker — confirm there.
        assert!((350..=550).contains(&end.as_millis()));
    }

    /// A malformed URL is reported as `Unsupported(BuildRequestClient)` through
    /// the response status, and `check` still returns `Ok`.
    #[tokio::test]
    async fn test_avoid_reqwest_panic() {
        let client = ClientBuilder::builder().build().client().unwrap();
        let res = client.check("http://\"").await.unwrap();

        assert!(matches!(
            res.status(),
            Status::Unsupported(ErrorKind::BuildRequestClient(_))
        ));
        assert!(res.status().is_unsupported());
    }

    /// `/redirect` redirects to itself forever. With `max_redirects(15)` the
    /// server expects the initial request plus 15 followed redirects. The
    /// final 308 is then reported as a rejected status, with 15 redirects
    /// recorded.
    #[tokio::test]
    async fn test_max_redirects() {
        let mock_server = wiremock::MockServer::start().await;

        let redirect_uri = format!("{}/redirect", &mock_server.uri());
        let redirect = wiremock::ResponseTemplate::new(StatusCode::PERMANENT_REDIRECT)
            .insert_header("Location", redirect_uri.as_str());

        let redirect_count = 15usize;
        let initial_invocation = 1;

        Mock::given(method("GET"))
            .and(path("/redirect"))
            .respond_with(move |_: &_| redirect.clone())
            .expect(initial_invocation + redirect_count as u64)
            .mount(&mock_server)
            .await;

        let res = ClientBuilder::builder()
            .max_redirects(redirect_count)
            .build()
            .client()
            .unwrap()
            .check(redirect_uri.clone())
            .await
            .unwrap();

        assert!(matches!(
            res.status(),
            Status::Error(ErrorKind::RejectedStatusCode(
                StatusCode::PERMANENT_REDIRECT
            ))
        ));
        assert!(matches!(
            res.redirects(),
            Some(redirects) if redirects.count() == redirect_count,
        ));
    }

    /// A single 308 redirect to an OK server succeeds, and the redirect chain
    /// is recorded on the response.
    ///
    /// NOTE(review): `Url` is not imported in this module; presumably the
    /// `redirecting_mock_server!` expansion brings it into scope — confirm.
    #[tokio::test]
    async fn test_redirects() {
        redirecting_mock_server!(async |redirect_url: Url, ok_url| {
            let res = ClientBuilder::builder()
                .max_redirects(1_usize)
                .build()
                .client()
                .unwrap()
                .check(Uri::from((redirect_url).clone()))
                .await
                .unwrap();

            let mut redirects = Redirects::new(redirect_url);
            redirects.push(Redirect {
                url: ok_url,
                code: StatusCode::PERMANENT_REDIRECT,
            });

            assert_eq!(res.status(), &Status::Ok(StatusCode::OK));
            assert_eq!(res.redirects(), Some(&redirects));
        })
        .await;
    }

    /// An HTTP URL remapped to a non-existent `file://` path is checked as a
    /// file (and fails), and the response records the remap. The expected
    /// URIs carry a trailing `/` added during URI parsing.
    #[tokio::test]
    async fn test_remaps() {
        let mapped = String::from("file:///nope");
        let client = ClientBuilder::builder()
            .remaps(Remaps::new(vec![(
                regex::Regex::new("http://example.org").unwrap(),
                mapped.clone(),
            )]))
            .build()
            .client()
            .unwrap();

        let input = Uri::try_from("http://example.org").unwrap();
        let res = client.check(input.clone()).await.unwrap();

        assert_eq!(
            res.status(),
            &Status::Error(ErrorKind::InvalidFilePath(
                format!("{mapped}/").try_into().unwrap(),
            ))
        );
        assert_eq!(
            res.remap(),
            Some(&Remap {
                original: input,
                new: format!("{mapped}/").try_into().unwrap(),
            })
        );
    }

    /// Non-HTTP, non-file, non-mail schemes are reported as unsupported.
    #[tokio::test]
    async fn test_unsupported_scheme() {
        let examples = vec![
            "ftp://example.com",
            "gopher://example.com",
            "slack://example.com",
        ];

        for example in examples {
            let client = ClientBuilder::builder().build().client().unwrap();
            let res = client.check(example).await.unwrap();
            assert!(res.status().is_unsupported());
        }
    }

    /// A plugin handler that settles every request as `Excluded` determines
    /// the final status of a website check.
    #[tokio::test]
    async fn test_chain() {
        use reqwest::Request;

        #[derive(Debug)]
        struct ExampleHandler();

        #[async_trait]
        impl Handler<Request, Status> for ExampleHandler {
            async fn handle(&mut self, _: Request) -> ChainResult<Request, Status> {
                ChainResult::Done(Status::Excluded)
            }
        }

        let chain = RequestChain::new(vec![Box::new(ExampleHandler {})]);

        let client = ClientBuilder::builder()
            .plugin_request_chain(chain)
            .build()
            .client()
            .unwrap();

        let result = client.check("http://example.com");
        let res = result.await.unwrap();
        assert_eq!(res.status(), &Status::Excluded);
    }
}
1075}