1#![allow(
11 clippy::module_name_repetitions,
12 clippy::struct_excessive_bools,
13 clippy::default_trait_access,
14 clippy::used_underscore_binding
15)]
16use std::{collections::HashSet, sync::Arc, time::Duration};
17
18use http::{
19 StatusCode,
20 header::{HeaderMap, HeaderValue},
21};
22use log::debug;
23use octocrab::Octocrab;
24use regex::RegexSet;
25use reqwest::{header, redirect, tls};
26use reqwest_cookie_store::CookieStoreMutex;
27use secrecy::{ExposeSecret, SecretString};
28use typed_builder::TypedBuilder;
29
30use crate::{
31 Base, BasicAuthCredentials, ErrorKind, Request, Response, Result, Status, Uri,
32 chain::RequestChain,
33 checker::{file::FileChecker, mail::MailChecker, website::WebsiteChecker},
34 filter::Filter,
35 ratelimit::{ClientMap, HostConfigs, HostKey, HostPool, RateLimitConfig},
36 remap::Remaps,
37 types::{DEFAULT_ACCEPTED_STATUS_CODES, redirect_history::RedirectHistory},
38};
39
/// Default value of `ClientBuilder::max_redirects`: the redirect limit handed
/// to the redirect policy when the caller does not set one.
pub const DEFAULT_MAX_REDIRECTS: usize = 5;
/// Default value of `ClientBuilder::max_retries`.
pub const DEFAULT_MAX_RETRIES: u64 = 3;
/// Default wait time between retries, in seconds. Converted to a `Duration`
/// for the default of `ClientBuilder::retry_wait_time`.
pub const DEFAULT_RETRY_WAIT_TIME_SECS: usize = 1;
/// Default timeout, in seconds.
// NOTE(review): not referenced in this file (`ClientBuilder::timeout` defaults
// to `None`); presumably consumed by callers elsewhere — confirm.
pub const DEFAULT_TIMEOUT_SECS: usize = 20;
/// Default `User-Agent` value: `lychee/` followed by the crate version taken
/// from `CARGO_PKG_VERSION` at compile time.
pub const DEFAULT_USER_AGENT: &str = concat!("lychee/", env!("CARGO_PKG_VERSION"));

/// Connect timeout applied to every reqwest client built in this file.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
/// TCP keepalive value applied to every reqwest client built in this file.
const TCP_KEEPALIVE: Duration = Duration::from_secs(60);
59
/// Configuration from which a [`Client`] is assembled.
///
/// The builder type is generated by [`TypedBuilder`]. Per the
/// `field_defaults` attribute below, every field is optional (falls back to
/// its default) and every setter accepts anything convertible `Into` the
/// field type. Call [`ClientBuilder::client`] on the built value to obtain a
/// [`Client`].
#[derive(TypedBuilder, Debug, Clone)]
#[builder(field_defaults(default, setter(into)))]
pub struct ClientBuilder {
    /// Optional GitHub token. When present and non-empty it is used as the
    /// personal token of the GitHub client handed to the website checker.
    github_token: Option<SecretString>,

    /// Optional remapping rules, applied to each URI by [`Client::remap`].
    remaps: Option<Remaps>,

    /// Fallback extensions, forwarded to the file checker.
    fallback_extensions: Vec<String>,

    /// Optional index file names, forwarded to the file checker.
    /// Defaults to `None`.
    #[builder(default = None)]
    index_files: Option<Vec<String>>,

    /// Optional set of include patterns, converted into the filter's
    /// `includes`.
    includes: Option<RegexSet>,

    /// Optional set of exclude patterns, converted into the filter's
    /// `excludes`.
    excludes: Option<RegexSet>,

    /// Shortcut that switches on all three IP exclusions below
    /// (private, link-local and loopback) at once.
    exclude_all_private: bool,

    /// Sets the filter's `exclude_private_ips` flag.
    exclude_private_ips: bool,

    /// Sets the filter's `exclude_link_local_ips` flag.
    exclude_link_local_ips: bool,

    /// Sets the filter's `exclude_loopback_ips` flag.
    exclude_loopback_ips: bool,

    /// Sets the filter's `include_mail` flag.
    include_mail: bool,

    /// Redirect limit handed to the redirect policy of every reqwest client.
    /// Defaults to [`DEFAULT_MAX_REDIRECTS`].
    #[builder(default = DEFAULT_MAX_REDIRECTS)]
    max_redirects: usize,

    /// Retry count handed to the website checker.
    /// Defaults to [`DEFAULT_MAX_RETRIES`].
    #[builder(default = DEFAULT_MAX_RETRIES)]
    max_retries: u64,

    /// Minimum TLS version; only applied to the reqwest client when set.
    min_tls_version: Option<tls::Version>,

    /// Value of the `User-Agent` header. It replaces any `User-Agent` found
    /// in `custom_headers`. Defaults to [`DEFAULT_USER_AGENT`].
    #[builder(default_code = "String::from(DEFAULT_USER_AGENT)")]
    user_agent: String,

    /// Passed to reqwest's `danger_accept_invalid_certs`.
    allow_insecure: bool,

    /// Set of schemes stored in the filter's `schemes`.
    schemes: HashSet<String>,

    /// Headers used as the starting point for the default headers of every
    /// reqwest client.
    custom_headers: HeaderMap,

    /// HTTP method handed to the website checker. Defaults to `GET`.
    #[builder(default = reqwest::Method::GET)]
    method: reqwest::Method,

    /// Accepted status codes handed to the website checker.
    /// Defaults to a clone of `DEFAULT_ACCEPTED_STATUS_CODES`.
    #[builder(default = DEFAULT_ACCEPTED_STATUS_CODES.clone())]
    accepted: HashSet<StatusCode>,

    /// Optional timeout. When set it becomes the reqwest client timeout;
    /// it is also passed to the mail checker.
    timeout: Option<Duration>,

    /// Optional base, forwarded to the file checker.
    base: Option<Base>,

    /// Wait time between retries handed to the website checker.
    /// Defaults to [`DEFAULT_RETRY_WAIT_TIME_SECS`] seconds.
    #[builder(default_code = "Duration::from_secs(DEFAULT_RETRY_WAIT_TIME_SECS as u64)")]
    retry_wait_time: Duration,

    /// `require_https` flag handed to the website checker.
    require_https: bool,

    /// Optional shared cookie store; installed as the reqwest cookie provider
    /// when set.
    cookie_jar: Option<Arc<CookieStoreMutex>>,

    /// `include_fragments` flag handed to both the website checker and the
    /// file checker.
    include_fragments: bool,

    /// `include_wikilinks` flag handed to the file checker.
    include_wikilinks: bool,

    /// Request chain handed to the website checker.
    plugin_request_chain: RequestChain,

    /// Rate-limit configuration handed to the host pool.
    rate_limit_config: RateLimitConfig,

    /// Per-host configuration. Each entry gets its own reqwest client whose
    /// default headers are extended with that host's configured headers; the
    /// whole map is also handed to the host pool.
    hosts: HostConfigs,
}
319
320impl Default for ClientBuilder {
321 #[inline]
322 fn default() -> Self {
323 Self::builder().build()
324 }
325}
326
327impl ClientBuilder {
328 pub fn client(self) -> Result<Client> {
343 let redirect_history = RedirectHistory::new();
344 let reqwest_client = self
345 .build_client(&redirect_history)?
346 .build()
347 .map_err(ErrorKind::BuildRequestClient)?;
348
349 let client_map = self.build_host_clients(&redirect_history)?;
350
351 let host_pool = HostPool::new(
352 self.rate_limit_config,
353 self.hosts,
354 reqwest_client,
355 client_map,
356 );
357
358 let github_client = match self.github_token.as_ref().map(ExposeSecret::expose_secret) {
359 Some(token) if !token.is_empty() => Some(
360 Octocrab::builder()
361 .personal_token(token.to_string())
362 .build()
363 .map_err(|e: octocrab::Error| ErrorKind::BuildGithubClient(Box::new(e)))?,
366 ),
367 _ => None,
368 };
369
370 let filter = Filter {
371 includes: self.includes.map(Into::into),
372 excludes: self.excludes.map(Into::into),
373 schemes: self.schemes,
374 exclude_private_ips: self.exclude_all_private || self.exclude_private_ips,
377 exclude_link_local_ips: self.exclude_all_private || self.exclude_link_local_ips,
378 exclude_loopback_ips: self.exclude_all_private || self.exclude_loopback_ips,
379 include_mail: self.include_mail,
380 };
381
382 let website_checker = WebsiteChecker::new(
383 self.method,
384 self.retry_wait_time,
385 redirect_history.clone(),
386 self.max_retries,
387 self.accepted,
388 github_client,
389 self.require_https,
390 self.plugin_request_chain,
391 self.include_fragments,
392 Arc::new(host_pool),
393 );
394
395 Ok(Client {
396 remaps: self.remaps,
397 filter,
398 email_checker: MailChecker::new(self.timeout),
399 website_checker,
400 file_checker: FileChecker::new(
401 self.base,
402 self.fallback_extensions,
403 self.index_files,
404 self.include_fragments,
405 self.include_wikilinks,
406 )?,
407 })
408 }
409
410 fn build_host_clients(&self, redirect_history: &RedirectHistory) -> Result<ClientMap> {
412 self.hosts
413 .iter()
414 .map(|(host, config)| {
415 let mut headers = self.default_headers()?;
416 headers.extend(config.headers.clone());
417 let client = self
418 .build_client(redirect_history)?
419 .default_headers(headers)
420 .build()
421 .map_err(ErrorKind::BuildRequestClient)?;
422 Ok((HostKey::from(host.as_str()), client))
423 })
424 .collect()
425 }
426
427 fn build_client(&self, redirect_history: &RedirectHistory) -> Result<reqwest::ClientBuilder> {
429 let mut builder = reqwest::ClientBuilder::new()
430 .gzip(true)
431 .default_headers(self.default_headers()?)
432 .danger_accept_invalid_certs(self.allow_insecure)
433 .connect_timeout(CONNECT_TIMEOUT)
434 .tcp_keepalive(TCP_KEEPALIVE)
435 .redirect(redirect_policy(
436 redirect_history.clone(),
437 self.max_redirects,
438 ));
439
440 if let Some(cookie_jar) = self.cookie_jar.clone() {
441 builder = builder.cookie_provider(cookie_jar);
442 }
443
444 if let Some(min_tls) = self.min_tls_version {
445 builder = builder.min_tls_version(min_tls);
446 }
447
448 if let Some(timeout) = self.timeout {
449 builder = builder.timeout(timeout);
450 }
451
452 Ok(builder)
453 }
454
455 fn default_headers(&self) -> Result<HeaderMap> {
456 let user_agent = self.user_agent.clone();
457 let mut headers = self.custom_headers.clone();
458
459 if let Some(prev_user_agent) =
460 headers.insert(header::USER_AGENT, HeaderValue::try_from(&user_agent)?)
461 {
462 debug!(
463 "Found user-agent in headers: {}. Overriding it with {user_agent}.",
464 prev_user_agent.to_str().unwrap_or("�"),
465 );
466 }
467
468 headers.insert(
469 header::TRANSFER_ENCODING,
470 HeaderValue::from_static("chunked"),
471 );
472
473 Ok(headers)
474 }
475}
476
477fn redirect_policy(redirect_history: RedirectHistory, max_redirects: usize) -> redirect::Policy {
480 redirect::Policy::custom(move |attempt| {
481 if attempt.previous().len() > max_redirects {
482 attempt.stop()
483 } else {
484 redirect_history.record_redirects(&attempt);
485 debug!("Following redirect to {}", attempt.url());
486 attempt.follow()
487 }
488 })
489}
490
/// Link checker produced by [`ClientBuilder::client`].
///
/// Depending on the kind of URI, a check is dispatched to the file, mail or
/// website checker (see [`Client::check`]).
#[derive(Debug, Clone)]
pub struct Client {
    /// Optional remapping rules applied to a URI before it is checked.
    remaps: Option<Remaps>,

    /// Decides whether a URI is excluded from checking.
    filter: Filter,

    /// Checker used for URIs that are neither `tel`, file nor mail URIs.
    website_checker: WebsiteChecker,

    /// Checker used for file URIs.
    file_checker: FileChecker,

    /// Checker used for mail URIs.
    email_checker: MailChecker,
}
512
513impl Client {
514 #[must_use]
516 pub fn host_pool(&self) -> Arc<HostPool> {
517 self.website_checker.host_pool()
518 }
519
520 #[allow(clippy::missing_panics_doc)]
532 pub async fn check<T, E>(&self, request: T) -> Result<Response>
533 where
534 Request: TryFrom<T, Error = E>,
535 ErrorKind: From<E>,
536 {
537 let Request {
538 ref mut uri,
539 credentials,
540 source,
541 ..
542 } = request.try_into()?;
543
544 self.remap(uri)?;
545
546 if self.is_excluded(uri) {
547 return Ok(Response::new(uri.clone(), Status::Excluded, source.into()));
548 }
549
550 let status = match uri.scheme() {
551 _ if uri.is_tel() => Status::Excluded, _ if uri.is_file() => self.check_file(uri).await,
553 _ if uri.is_mail() => self.check_mail(uri).await,
554 _ => self.check_website(uri, credentials).await?,
555 };
556
557 Ok(Response::new(uri.clone(), status, source.into()))
558 }
559
560 pub async fn check_file(&self, uri: &Uri) -> Status {
562 self.file_checker.check(uri).await
563 }
564
565 pub fn remap(&self, uri: &mut Uri) -> Result<()> {
571 if let Some(ref remaps) = self.remaps {
572 uri.url = remaps.remap(&uri.url)?;
573 }
574 Ok(())
575 }
576
577 #[must_use]
579 pub fn is_excluded(&self, uri: &Uri) -> bool {
580 self.filter.is_excluded(uri)
581 }
582
583 pub async fn check_website(
593 &self,
594 uri: &Uri,
595 credentials: Option<BasicAuthCredentials>,
596 ) -> Result<Status> {
597 self.website_checker.check_website(uri, credentials).await
598 }
599
600 pub async fn check_mail(&self, uri: &Uri) -> Status {
602 self.email_checker.check_mail(uri).await
603 }
604}
605
606pub async fn check<T, E>(request: T) -> Result<Response>
619where
620 Request: TryFrom<T, Error = E>,
621 ErrorKind: From<E>,
622{
623 let client = ClientBuilder::builder().build().client()?;
624 client.check(request).await
625}
626
// Tests for `ClientBuilder` / `Client`. Several of them talk to real hosts
// (github.com, youtube.com, badssl.com, crates.io, example.com, ...) and
// therefore need network access; the others use local wiremock servers.
#[cfg(test)]
mod tests {
    use std::{
        fs::File,
        time::{Duration, Instant},
    };

    use async_trait::async_trait;
    use http::{StatusCode, header::HeaderMap};
    use reqwest::header;
    use tempfile::tempdir;
    use test_utils::get_mock_client_response;
    use test_utils::mock_server;
    use test_utils::redirecting_mock_server;
    use wiremock::{
        Mock,
        matchers::{method, path},
    };

    use super::ClientBuilder;
    use crate::{
        ErrorKind, Redirect, Redirects, Request, Status, Uri,
        chain::{ChainResult, Handler, RequestChain},
    };

    // A mock server answering 404 must yield an error status.
    #[tokio::test]
    async fn test_nonexistent() {
        let mock_server = mock_server!(StatusCode::NOT_FOUND);
        let res = get_mock_client_response!(mock_server.uri()).await;

        assert!(res.status().is_error());
    }

    // A request to a path on 127.0.0.1 must yield an error status.
    // NOTE(review): presumably relies on nothing listening on port 80 locally.
    #[tokio::test]
    async fn test_nonexistent_with_path() {
        let res = get_mock_client_response!("http://127.0.0.1/invalid").await;
        assert!(res.status().is_error());
    }

    // An existing GitHub repository is reported as success (needs network).
    #[tokio::test]
    async fn test_github() {
        let res = get_mock_client_response!("https://github.com/lycheeverse/lychee").await;
        assert!(res.status().is_success());
    }

    // A non-existent GitHub repository is reported as an error (needs network).
    #[tokio::test]
    async fn test_github_nonexistent_repo() {
        let res = get_mock_client_response!("https://github.com/lycheeverse/not-lychee").await;
        assert!(res.status().is_error());
    }

    // A non-existent file inside an existing GitHub repository is reported as
    // an error (needs network).
    #[tokio::test]
    async fn test_github_nonexistent_file() {
        let res = get_mock_client_response!(
            "https://github.com/lycheeverse/lychee/blob/master/NON_EXISTENT_FILE.md",
        )
        .await;
        assert!(res.status().is_error());
    }

    // A valid YouTube video link succeeds, one with an invalid video id fails
    // (needs network).
    #[tokio::test]
    async fn test_youtube() {
        // Existing video
        let res = get_mock_client_response!("https://www.youtube.com/watch?v=NlKuICiT470&list=PLbWDhxwM_45mPVToqaIZNbZeIzFchsKKQ&index=7").await;
        assert!(res.status().is_success());

        // Same link with a bogus video id
        let res = get_mock_client_response!("https://www.youtube.com/watch?v=invalidNlKuICiT470&list=PLbWDhxwM_45mPVToqaIZNbZeIzFchsKKQ&index=7").await;
        assert!(res.status().is_error());
    }

    // Without credentials the endpoint answers 401; with basic-auth
    // credentials the check ends in a redirect that resolves to 200
    // (needs network).
    #[tokio::test]
    async fn test_basic_auth() {
        let mut r: Request = "https://authenticationtest.com/HTTPAuth/"
            .try_into()
            .unwrap();

        // No credentials yet: expect 401
        let res = get_mock_client_response!(r.clone()).await;
        assert_eq!(res.status().code(), Some(401.try_into().unwrap()));

        r.credentials = Some(crate::BasicAuthCredentials {
            username: "user".into(),
            password: "pass".into(),
        });

        let res = get_mock_client_response!(r).await;
        assert!(matches!(
            res.status(),
            Status::Redirected(StatusCode::OK, _)
        ));
    }

    // A mock server answering 200 must yield a success status.
    #[tokio::test]
    async fn test_non_github() {
        let mock_server = mock_server!(StatusCode::OK);
        let res = get_mock_client_response!(mock_server.uri()).await;

        assert!(res.status().is_success());
    }

    // An expired certificate fails by default, but is accepted when
    // `allow_insecure` is enabled (needs network).
    #[tokio::test]
    async fn test_invalid_ssl() {
        let res = get_mock_client_response!("https://expired.badssl.com/").await;

        assert!(res.status().is_error());

        // Same request, this time with certificate validation disabled
        let res = ClientBuilder::builder()
            .allow_insecure(true)
            .build()
            .client()
            .unwrap()
            .check("https://expired.badssl.com/")
            .await
            .unwrap();
        assert!(res.status().is_success());
    }

    // A `file://` URI pointing at an existing temporary file succeeds.
    #[tokio::test]
    async fn test_file() {
        let dir = tempdir().unwrap();
        let file = dir.path().join("temp");
        File::create(file).unwrap();
        let uri = format!("file://{}", dir.path().join("temp").to_str().unwrap());

        let res = get_mock_client_response!(uri).await;
        assert!(res.status().is_success());
    }

    // A client with a custom `Accept` header can still check a real site
    // (needs network).
    #[tokio::test]
    async fn test_custom_headers() {
        let mut custom = HeaderMap::new();
        custom.insert(header::ACCEPT, "text/html".parse().unwrap());
        let res = ClientBuilder::builder()
            .custom_headers(custom)
            .build()
            .client()
            .unwrap()
            .check("https://crates.io/crates/lychee")
            .await
            .unwrap();
        assert!(res.status().is_success());
    }

    // `include_mail` is left at its default here; mail addresses must then be
    // excluded.
    #[tokio::test]
    async fn test_exclude_mail_by_default() {
        let client = ClientBuilder::builder()
            .exclude_all_private(true)
            .build()
            .client()
            .unwrap();
        assert!(client.is_excluded(&Uri {
            url: "mailto://mail@example.com".try_into().unwrap()
        }));
    }

    // Mail addresses are excluded with `include_mail(false)` and not excluded
    // with `include_mail(true)`.
    #[tokio::test]
    async fn test_include_mail() {
        let client = ClientBuilder::builder()
            .include_mail(false)
            .exclude_all_private(true)
            .build()
            .client()
            .unwrap();
        assert!(client.is_excluded(&Uri {
            url: "mailto://mail@example.com".try_into().unwrap()
        }));

        let client = ClientBuilder::builder()
            .include_mail(true)
            .exclude_all_private(true)
            .build()
            .client()
            .unwrap();
        assert!(!client.is_excluded(&Uri {
            url: "mailto://mail@example.com".try_into().unwrap()
        }));
    }

    // With the default configuration a `tel:` URI is excluded.
    #[tokio::test]
    async fn test_include_tel() {
        let client = ClientBuilder::builder().build().client().unwrap();
        assert!(client.is_excluded(&Uri {
            url: "tel:1234567890".try_into().unwrap()
        }));
    }

    // A plain `http` link succeeds by default but is an error once
    // `require_https` is enabled (needs network).
    #[tokio::test]
    async fn test_require_https() {
        let client = ClientBuilder::builder().build().client().unwrap();
        let res = client.check("http://example.com").await.unwrap();
        assert!(res.status().is_success());

        // Same request with `require_https` switched on
        let client = ClientBuilder::builder()
            .require_https(true)
            .build()
            .client()
            .unwrap();
        let res = client.check("http://example.com").await.unwrap();
        assert!(res.status().is_error());
    }

    // The mock server responds more slowly than the client timeout allows,
    // so the check must report a timeout. Retries are disabled.
    #[tokio::test]
    async fn test_timeout() {
        let mock_delay = Duration::from_millis(20);
        let checker_timeout = Duration::from_millis(10);
        // Sanity check of the test setup itself
        assert!(mock_delay > checker_timeout);

        let mock_server = mock_server!(StatusCode::OK, set_delay(mock_delay));

        let client = ClientBuilder::builder()
            .timeout(checker_timeout)
            .max_retries(0u64)
            .build()
            .client()
            .unwrap();

        let res = client.check(mock_server.uri()).await.unwrap();
        assert!(res.status().is_timeout());
    }

    // Measures the total time of a check that keeps timing out with three
    // retries and a 50 ms retry wait time.
    #[tokio::test]
    async fn test_exponential_backoff() {
        let mock_delay = Duration::from_millis(20);
        let checker_timeout = Duration::from_millis(10);
        // Sanity check of the test setup itself
        assert!(mock_delay > checker_timeout);

        let mock_server = mock_server!(StatusCode::OK, set_delay(mock_delay));

        // One request without timeout or retries is sent before measuring.
        // NOTE(review): presumably warms up the mock server so start-up cost
        // does not distort the timing below — confirm.
        let warm_up_client = ClientBuilder::builder()
            .max_retries(0_u64)
            .build()
            .client()
            .unwrap();
        let _res = warm_up_client.check(mock_server.uri()).await.unwrap();

        let client = ClientBuilder::builder()
            .timeout(checker_timeout)
            .max_retries(3_u64)
            .retry_wait_time(Duration::from_millis(50))
            .build()
            .client()
            .unwrap();

        let start = Instant::now();
        let res = client.check(mock_server.uri()).await.unwrap();
        let end = start.elapsed();

        assert!(res.status().is_error());

        // NOTE(review): the lower bound presumably corresponds to waits of
        // 50 + 100 + 200 ms between attempts — confirm against the retry
        // implementation. The upper bound leaves room for timeouts/overhead.
        assert!((350..=550).contains(&end.as_millis()));
    }

    // A malformed URL must be reported as `Unsupported` instead of panicking.
    #[tokio::test]
    async fn test_avoid_reqwest_panic() {
        let client = ClientBuilder::builder().build().client().unwrap();
        let res = client.check("http://\"").await.unwrap();

        assert!(matches!(
            res.status(),
            Status::Unsupported(ErrorKind::BuildRequestClient(_))
        ));
        assert!(res.status().is_unsupported());
    }

    // An endpoint that redirects to itself forever: the client must send the
    // initial request plus exactly `max_redirects` follow-ups and then report
    // the redirect status code as rejected.
    #[tokio::test]
    async fn test_max_redirects() {
        let mock_server = wiremock::MockServer::start().await;

        let redirect_uri = format!("{}/redirect", &mock_server.uri());
        let redirect = wiremock::ResponseTemplate::new(StatusCode::PERMANENT_REDIRECT)
            .insert_header("Location", redirect_uri.as_str());

        let redirect_count = 15usize;
        let initial_invocation = 1;

        // `/redirect` always answers with a redirect back to `/redirect`
        Mock::given(method("GET"))
            .and(path("/redirect"))
            .respond_with(move |_: &_| redirect.clone())
            .expect(initial_invocation + redirect_count as u64)
            .mount(&mock_server)
            .await;

        let res = ClientBuilder::builder()
            .max_redirects(redirect_count)
            .build()
            .client()
            .unwrap()
            .check(redirect_uri.clone())
            .await
            .unwrap();

        assert_eq!(
            res.status(),
            &Status::Error(ErrorKind::RejectedStatusCode(
                StatusCode::PERMANENT_REDIRECT
            ))
        );
    }

    // A single redirect is followed and reported as `Status::Redirected`,
    // including the recorded redirect target and status code.
    #[tokio::test]
    async fn test_redirects() {
        redirecting_mock_server!(async |redirect_url: Url, ok_url| {
            let res = ClientBuilder::builder()
                .max_redirects(1_usize)
                .build()
                .client()
                .unwrap()
                .check(Uri::from((redirect_url).clone()))
                .await
                .unwrap();

            // Expected history: start at `redirect_url`, one hop to `ok_url`
            let mut redirects = Redirects::new(redirect_url);
            redirects.push(Redirect {
                url: ok_url,
                code: StatusCode::PERMANENT_REDIRECT,
            });
            assert_eq!(res.status(), &Status::Redirected(StatusCode::OK, redirects));
        })
        .await;
    }

    // URIs with schemes the client cannot check are reported as unsupported.
    #[tokio::test]
    async fn test_unsupported_scheme() {
        let examples = vec![
            "ftp://example.com",
            "gopher://example.com",
            "slack://example.com",
        ];

        for example in examples {
            let client = ClientBuilder::builder().build().client().unwrap();
            let res = client.check(example).await.unwrap();
            assert!(res.status().is_unsupported());
        }
    }

    // A request-chain handler that finishes the chain with `Status::Excluded`
    // determines the final status of the check.
    #[tokio::test]
    async fn test_chain() {
        use reqwest::Request;

        #[derive(Debug)]
        struct ExampleHandler();

        #[async_trait]
        impl Handler<Request, Status> for ExampleHandler {
            async fn handle(&mut self, _: Request) -> ChainResult<Request, Status> {
                ChainResult::Done(Status::Excluded)
            }
        }

        let chain = RequestChain::new(vec![Box::new(ExampleHandler {})]);

        let client = ClientBuilder::builder()
            .plugin_request_chain(chain)
            .build()
            .client()
            .unwrap();

        let result = client.check("http://example.com");
        let res = result.await.unwrap();
        assert_eq!(res.status(), &Status::Excluded);
    }
}