#![allow(
clippy::module_name_repetitions,
clippy::struct_excessive_bools,
clippy::default_trait_access,
clippy::used_underscore_binding
)]
use std::{collections::HashSet, sync::Arc, time::Duration};
use http::{
StatusCode,
header::{HeaderMap, HeaderValue},
};
use log::debug;
use octocrab::Octocrab;
use regex::RegexSet;
use reqwest::{header, redirect, tls};
use reqwest_cookie_store::CookieStoreMutex;
use secrecy::{ExposeSecret, SecretString};
use typed_builder::TypedBuilder;
use crate::{
Base, BasicAuthCredentials, ErrorKind, Request, Response, Result, Status, Uri,
chain::RequestChain,
checker::{file::FileChecker, mail::MailChecker, website::WebsiteChecker},
filter::Filter,
ratelimit::{ClientMap, HostConfigs, HostKey, HostPool, RateLimitConfig},
remap::Remaps,
types::{DEFAULT_ACCEPTED_STATUS_CODES, redirect_history::RedirectHistory},
};
/// Default value for [`ClientBuilder`]'s `max_redirects` field.
pub const DEFAULT_MAX_REDIRECTS: usize = 5;
/// Default value for [`ClientBuilder`]'s `max_retries` field (handed to the website checker).
pub const DEFAULT_MAX_RETRIES: u64 = 3;
/// Default value, in seconds, for [`ClientBuilder`]'s `retry_wait_time` field.
pub const DEFAULT_RETRY_WAIT_TIME_SECS: usize = 1;
/// Default request timeout, in seconds.
///
/// NOTE(review): not referenced in this file (`ClientBuilder::timeout` defaults
/// to `None`); presumably consumed by callers of this crate — confirm.
pub const DEFAULT_TIMEOUT_SECS: usize = 20;
/// Default `User-Agent` header value: `lychee/<crate version>`.
pub const DEFAULT_USER_AGENT: &str = concat!("lychee/", env!("CARGO_PKG_VERSION"));
/// Connect timeout applied to every `reqwest` client built by [`ClientBuilder`].
const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
/// TCP keepalive duration applied to every `reqwest` client built by [`ClientBuilder`].
const TCP_KEEPALIVE: Duration = Duration::from_secs(60);
/// Builder for a [`Client`].
///
/// Generated with `typed_builder`: every field is optional, every setter accepts
/// `impl Into<FieldType>`, and unset fields fall back to `Default::default()`
/// (so booleans default to `false`) unless a field-level `#[builder(...)]`
/// attribute specifies another default. Call [`ClientBuilder::client`] to
/// turn the configuration into a [`Client`].
#[derive(TypedBuilder, Debug, Clone)]
#[builder(field_defaults(default, setter(into)))]
pub struct ClientBuilder {
    /// GitHub token; when set and non-empty, a GitHub API client is built
    /// and handed to the website checker.
    github_token: Option<SecretString>,
    /// URI remapping rules applied to every request URI before it is checked.
    remaps: Option<Remaps>,
    /// Fallback extensions passed to the file checker.
    fallback_extensions: Vec<String>,
    /// Index file names passed to the file checker.
    #[builder(default = None)]
    index_files: Option<Vec<String>>,
    /// Include patterns for the URI filter.
    includes: Option<RegexSet>,
    /// Exclude patterns for the URI filter.
    excludes: Option<RegexSet>,
    /// When `true`, enables all three IP-based exclusions below regardless
    /// of their individual values.
    exclude_all_private: bool,
    /// Exclude private IP addresses (also enabled by `exclude_all_private`).
    exclude_private_ips: bool,
    /// Exclude link-local IP addresses (also enabled by `exclude_all_private`).
    exclude_link_local_ips: bool,
    /// Exclude loopback IP addresses (also enabled by `exclude_all_private`).
    exclude_loopback_ips: bool,
    /// Whether `mailto:` URIs are checked instead of excluded.
    include_mail: bool,
    /// Maximum number of redirects followed per request.
    #[builder(default = DEFAULT_MAX_REDIRECTS)]
    max_redirects: usize,
    /// Retry count handed to the website checker.
    #[builder(default = DEFAULT_MAX_RETRIES)]
    max_retries: u64,
    /// Minimum TLS version; left to `reqwest`'s default when `None`.
    min_tls_version: Option<tls::Version>,
    /// `User-Agent` header value; overrides any `User-Agent` in `custom_headers`.
    #[builder(default_code = "String::from(DEFAULT_USER_AGENT)")]
    user_agent: String,
    /// Accept invalid TLS certificates (`danger_accept_invalid_certs`).
    allow_insecure: bool,
    /// URI schemes passed to the filter.
    schemes: HashSet<String>,
    /// Extra headers sent with every request.
    custom_headers: HeaderMap,
    /// HTTP method used by the website checker.
    #[builder(default = reqwest::Method::GET)]
    method: reqwest::Method,
    /// Status codes the website checker treats as accepted.
    #[builder(default = DEFAULT_ACCEPTED_STATUS_CODES.clone())]
    accepted: HashSet<StatusCode>,
    /// Request timeout for `reqwest` clients; also passed to the mail checker.
    timeout: Option<Duration>,
    /// Base passed to the file checker.
    base: Option<Base>,
    /// Wait time between retries, handed to the website checker.
    #[builder(default_code = "Duration::from_secs(DEFAULT_RETRY_WAIT_TIME_SECS as u64)")]
    retry_wait_time: Duration,
    /// Flag handed to the website checker (see its docs for exact semantics).
    require_https: bool,
    /// Shared cookie store installed as the `reqwest` cookie provider.
    cookie_jar: Option<Arc<CookieStoreMutex>>,
    /// Whether fragments are checked; passed to both website and file checker.
    include_fragments: bool,
    /// Whether wikilinks are handled; passed to the file checker.
    include_wikilinks: bool,
    /// Plugin request chain handed to the website checker.
    plugin_request_chain: RequestChain,
    /// Global rate-limit configuration for the host pool.
    rate_limit_config: RateLimitConfig,
    /// Per-host configuration; each entry gets its own `reqwest` client
    /// carrying the host's extra headers.
    hosts: HostConfigs,
}
impl Default for ClientBuilder {
#[inline]
fn default() -> Self {
Self::builder().build()
}
}
impl ClientBuilder {
pub fn client(self) -> Result<Client> {
let redirect_history = RedirectHistory::new();
let reqwest_client = self
.build_client(&redirect_history)?
.build()
.map_err(ErrorKind::BuildRequestClient)?;
let client_map = self.build_host_clients(&redirect_history)?;
let host_pool = HostPool::new(
self.rate_limit_config,
self.hosts,
reqwest_client,
client_map,
);
let github_client = match self.github_token.as_ref().map(ExposeSecret::expose_secret) {
Some(token) if !token.is_empty() => Some(
Octocrab::builder()
.personal_token(token.to_string())
.build()
.map_err(|e: octocrab::Error| ErrorKind::BuildGithubClient(Box::new(e)))?,
),
_ => None,
};
let filter = Filter {
includes: self.includes.map(Into::into),
excludes: self.excludes.map(Into::into),
schemes: self.schemes,
exclude_private_ips: self.exclude_all_private || self.exclude_private_ips,
exclude_link_local_ips: self.exclude_all_private || self.exclude_link_local_ips,
exclude_loopback_ips: self.exclude_all_private || self.exclude_loopback_ips,
include_mail: self.include_mail,
};
let website_checker = WebsiteChecker::new(
self.method,
self.retry_wait_time,
redirect_history.clone(),
self.max_retries,
self.accepted,
github_client,
self.require_https,
self.plugin_request_chain,
self.include_fragments,
Arc::new(host_pool),
);
Ok(Client {
remaps: self.remaps,
filter,
email_checker: MailChecker::new(self.timeout),
website_checker,
file_checker: FileChecker::new(
self.base,
self.fallback_extensions,
self.index_files,
self.include_fragments,
self.include_wikilinks,
)?,
})
}
fn build_host_clients(&self, redirect_history: &RedirectHistory) -> Result<ClientMap> {
self.hosts
.iter()
.map(|(host, config)| {
let mut headers = self.default_headers()?;
headers.extend(config.headers.clone());
let client = self
.build_client(redirect_history)?
.default_headers(headers)
.build()
.map_err(ErrorKind::BuildRequestClient)?;
Ok((HostKey::from(host.as_str()), client))
})
.collect()
}
fn build_client(&self, redirect_history: &RedirectHistory) -> Result<reqwest::ClientBuilder> {
let mut builder = reqwest::ClientBuilder::new()
.gzip(true)
.default_headers(self.default_headers()?)
.danger_accept_invalid_certs(self.allow_insecure)
.connect_timeout(CONNECT_TIMEOUT)
.tcp_keepalive(TCP_KEEPALIVE)
.redirect(redirect_policy(
redirect_history.clone(),
self.max_redirects,
));
if let Some(cookie_jar) = self.cookie_jar.clone() {
builder = builder.cookie_provider(cookie_jar);
}
if let Some(min_tls) = self.min_tls_version {
builder = builder.min_tls_version(min_tls);
}
if let Some(timeout) = self.timeout {
builder = builder.timeout(timeout);
}
Ok(builder)
}
fn default_headers(&self) -> Result<HeaderMap> {
let user_agent = self.user_agent.clone();
let mut headers = self.custom_headers.clone();
if let Some(prev_user_agent) =
headers.insert(header::USER_AGENT, HeaderValue::try_from(&user_agent)?)
{
debug!(
"Found user-agent in headers: {}. Overriding it with {user_agent}.",
prev_user_agent.to_str().unwrap_or("�"),
);
}
headers.insert(
header::TRANSFER_ENCODING,
HeaderValue::from_static("chunked"),
);
Ok(headers)
}
}
/// Builds a `reqwest` redirect policy that follows at most `max_redirects` hops.
///
/// `Attempt::previous` lists every URL already requested in the chain,
/// including the initial one. While that count does not exceed
/// `max_redirects`, the hop is recorded in `redirect_history`, logged, and
/// followed. Otherwise the redirect is not followed.
fn redirect_policy(redirect_history: RedirectHistory, max_redirects: usize) -> redirect::Policy {
    redirect::Policy::custom(move |attempt| {
        let visited = attempt.previous().len();
        if visited <= max_redirects {
            redirect_history.record_redirects(&attempt);
            debug!("Following redirect to {}", attempt.url());
            return attempt.follow();
        }
        attempt.stop()
    })
}
/// Link checker that filters, optionally remaps, and dispatches URIs to the
/// website, file, or mail checker.
///
/// Construct one with [`ClientBuilder::client`].
#[derive(Debug, Clone)]
pub struct Client {
    /// Optional URI remapping rules applied before filtering.
    remaps: Option<Remaps>,
    /// Decides which URIs are excluded from checking.
    filter: Filter,
    /// Checker used for URIs that are not `tel:`, file, or mail URIs.
    website_checker: WebsiteChecker,
    /// Checker used for file URIs.
    file_checker: FileChecker,
    /// Checker used for mail URIs.
    email_checker: MailChecker,
}
impl Client {
    /// Returns the [`HostPool`] used by the website checker.
    #[must_use]
    pub fn host_pool(&self) -> Arc<HostPool> {
        self.website_checker.host_pool()
    }

    /// Checks a single request and returns the resulting [`Response`].
    ///
    /// The request URI is first remapped and then tested against the filter.
    /// Excluded URIs are reported as [`Status::Excluded`] without being
    /// checked. `tel:` URIs are always reported as excluded. File and mail URIs
    /// go to their dedicated checkers. Everything else goes to the website
    /// checker together with the request's basic-auth credentials.
    ///
    /// # Errors
    ///
    /// Returns an error if `request` cannot be converted into a [`Request`],
    /// if remapping the URI fails, or if the website checker returns an error.
    #[allow(clippy::missing_panics_doc)]
    pub async fn check<T, E>(&self, request: T) -> Result<Response>
    where
        Request: TryFrom<T, Error = E>,
        ErrorKind: From<E>,
    {
        let Request {
            ref mut uri,
            credentials,
            source,
            ..
        } = request.try_into()?;

        self.remap(uri)?;

        if self.is_excluded(uri) {
            return Ok(Response::new(uri.clone(), Status::Excluded, source.into()));
        }

        let status = if uri.is_tel() {
            // `tel:` URIs are never checked.
            Status::Excluded
        } else if uri.is_file() {
            self.check_file(uri).await
        } else if uri.is_mail() {
            self.check_mail(uri).await
        } else {
            self.check_website(uri, credentials).await?
        };

        Ok(Response::new(uri.clone(), status, source.into()))
    }

    /// Checks `uri` with the file checker.
    pub async fn check_file(&self, uri: &Uri) -> Status {
        self.file_checker.check(uri).await
    }

    /// Rewrites `uri` in place using the configured remaps, if any.
    ///
    /// # Errors
    ///
    /// Returns an error if applying the remap rules fails.
    pub fn remap(&self, uri: &mut Uri) -> Result<()> {
        if let Some(ref remaps) = self.remaps {
            uri.url = remaps.remap(&uri.url)?;
        }
        Ok(())
    }

    /// Returns `true` if the filter excludes `uri` from checking.
    #[must_use]
    pub fn is_excluded(&self, uri: &Uri) -> bool {
        self.filter.is_excluded(uri)
    }

    /// Checks `uri` with the website checker, using `credentials` if given.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the website checker.
    pub async fn check_website(
        &self,
        uri: &Uri,
        credentials: Option<BasicAuthCredentials>,
    ) -> Result<Status> {
        self.website_checker.check_website(uri, credentials).await
    }

    /// Checks `uri` with the mail checker.
    pub async fn check_mail(&self, uri: &Uri) -> Status {
        self.email_checker.check_mail(uri).await
    }
}
/// Checks a single request with a [`Client`] built from default settings.
///
/// A new client is constructed on every call. When checking many links,
/// build one [`Client`] through [`ClientBuilder`] and reuse it.
///
/// # Errors
///
/// Returns an error if the default client cannot be built or if checking
/// the request fails (see [`Client::check`]).
pub async fn check<T, E>(request: T) -> Result<Response>
where
    Request: TryFrom<T, Error = E>,
    ErrorKind: From<E>,
{
    ClientBuilder::default().client()?.check(request).await
}
// Client tests. Several of them talk to real public hosts (GitHub, YouTube,
// badssl.com, crates.io, example.com, authenticationtest.com) and therefore
// need network access; the rest use local wiremock servers or temp files.
#[cfg(test)]
mod tests {
    use std::{
        fs::File,
        time::{Duration, Instant},
    };

    use async_trait::async_trait;
    use http::{StatusCode, header::HeaderMap};
    use reqwest::header;
    use tempfile::tempdir;
    use test_utils::get_mock_client_response;
    use test_utils::mock_server;
    use test_utils::redirecting_mock_server;
    use wiremock::{
        Mock,
        matchers::{method, path},
    };

    use super::ClientBuilder;
    use crate::{
        ErrorKind, Redirect, Redirects, Request, Status, Uri,
        chain::{ChainResult, Handler, RequestChain},
    };

    // A 404 from the local mock server is reported as an error.
    #[tokio::test]
    async fn test_nonexistent() {
        let mock_server = mock_server!(StatusCode::NOT_FOUND);
        let res = get_mock_client_response!(mock_server.uri()).await;
        assert!(res.status().is_error());
    }

    // NOTE(review): assumes nothing on 127.0.0.1 serves `/invalid` successfully.
    #[tokio::test]
    async fn test_nonexistent_with_path() {
        let res = get_mock_client_response!("http://127.0.0.1/invalid").await;
        assert!(res.status().is_error());
    }

    // Network: an existing GitHub repository is reported as success.
    #[tokio::test]
    async fn test_github() {
        let res = get_mock_client_response!("https://github.com/lycheeverse/lychee").await;
        assert!(res.status().is_success());
    }

    // Network: a missing GitHub repository is reported as an error.
    #[tokio::test]
    async fn test_github_nonexistent_repo() {
        let res = get_mock_client_response!("https://github.com/lycheeverse/not-lychee").await;
        assert!(res.status().is_error());
    }

    // Network: a missing file inside an existing repository is reported as an error.
    #[tokio::test]
    async fn test_github_nonexistent_file() {
        let res = get_mock_client_response!(
            "https://github.com/lycheeverse/lychee/blob/master/NON_EXISTENT_FILE.md",
        )
        .await;
        assert!(res.status().is_error());
    }

    // Network: a valid video link succeeds, a mangled video id fails.
    #[tokio::test]
    async fn test_youtube() {
        let res = get_mock_client_response!("https://www.youtube.com/watch?v=NlKuICiT470&list=PLbWDhxwM_45mPVToqaIZNbZeIzFchsKKQ&index=7").await;
        assert!(res.status().is_success());
        let res = get_mock_client_response!("https://www.youtube.com/watch?v=invalidNlKuICiT470&list=PLbWDhxwM_45mPVToqaIZNbZeIzFchsKKQ&index=7").await;
        assert!(res.status().is_error());
    }

    // Network: without credentials the endpoint answers 401; with basic-auth
    // credentials the check ends as a redirect chain that resolves to 200.
    #[tokio::test]
    async fn test_basic_auth() {
        let mut r: Request = "https://authenticationtest.com/HTTPAuth/"
            .try_into()
            .unwrap();
        let res = get_mock_client_response!(r.clone()).await;
        assert_eq!(res.status().code(), Some(401.try_into().unwrap()));
        r.credentials = Some(crate::BasicAuthCredentials {
            username: "user".into(),
            password: "pass".into(),
        });
        let res = get_mock_client_response!(r).await;
        assert!(matches!(
            res.status(),
            Status::Redirected(StatusCode::OK, _)
        ));
    }

    // A 200 from the local mock server (non-GitHub host) is reported as success.
    #[tokio::test]
    async fn test_non_github() {
        let mock_server = mock_server!(StatusCode::OK);
        let res = get_mock_client_response!(mock_server.uri()).await;
        assert!(res.status().is_success());
    }

    // Network: an expired certificate fails by default but passes once
    // `allow_insecure` is enabled.
    #[tokio::test]
    async fn test_invalid_ssl() {
        let res = get_mock_client_response!("https://expired.badssl.com/").await;
        assert!(res.status().is_error());
        let res = ClientBuilder::builder()
            .allow_insecure(true)
            .build()
            .client()
            .unwrap()
            .check("https://expired.badssl.com/")
            .await
            .unwrap();
        assert!(res.status().is_success());
    }

    // A `file://` URI pointing at an existing temp file is reported as success.
    #[tokio::test]
    async fn test_file() {
        let dir = tempdir().unwrap();
        let file = dir.path().join("temp");
        File::create(file).unwrap();
        let uri = format!("file://{}", dir.path().join("temp").to_str().unwrap());
        let res = get_mock_client_response!(uri).await;
        assert!(res.status().is_success());
    }

    // Network: a client with a custom `Accept` header succeeds against crates.io.
    #[tokio::test]
    async fn test_custom_headers() {
        let mut custom = HeaderMap::new();
        custom.insert(header::ACCEPT, "text/html".parse().unwrap());
        let res = ClientBuilder::builder()
            .custom_headers(custom)
            .build()
            .client()
            .unwrap()
            .check("https://crates.io/crates/lychee")
            .await
            .unwrap();
        assert!(res.status().is_success());
    }

    // With `include_mail` left at its default (false), `mailto:` URIs are excluded.
    #[tokio::test]
    async fn test_exclude_mail_by_default() {
        let client = ClientBuilder::builder()
            .exclude_all_private(true)
            .build()
            .client()
            .unwrap();
        assert!(client.is_excluded(&Uri {
            url: "mailto://mail@example.com".try_into().unwrap()
        }));
    }

    // `include_mail` toggles whether `mailto:` URIs are excluded.
    #[tokio::test]
    async fn test_include_mail() {
        let client = ClientBuilder::builder()
            .include_mail(false)
            .exclude_all_private(true)
            .build()
            .client()
            .unwrap();
        assert!(client.is_excluded(&Uri {
            url: "mailto://mail@example.com".try_into().unwrap()
        }));
        let client = ClientBuilder::builder()
            .include_mail(true)
            .exclude_all_private(true)
            .build()
            .client()
            .unwrap();
        assert!(!client.is_excluded(&Uri {
            url: "mailto://mail@example.com".try_into().unwrap()
        }));
    }

    // Despite its name, this asserts that `tel:` URIs are excluded by a default client.
    #[tokio::test]
    async fn test_include_tel() {
        let client = ClientBuilder::builder().build().client().unwrap();
        assert!(client.is_excluded(&Uri {
            url: "tel:1234567890".try_into().unwrap()
        }));
    }

    // Network: plain-HTTP example.com succeeds by default but is reported as
    // an error once `require_https` is enabled.
    #[tokio::test]
    async fn test_require_https() {
        let client = ClientBuilder::builder().build().client().unwrap();
        let res = client.check("http://example.com").await.unwrap();
        assert!(res.status().is_success());
        let client = ClientBuilder::builder()
            .require_https(true)
            .build()
            .client()
            .unwrap();
        let res = client.check("http://example.com").await.unwrap();
        assert!(res.status().is_error());
    }

    // The mock responds more slowly than the client timeout; with retries
    // disabled the result must be a timeout status.
    #[tokio::test]
    async fn test_timeout() {
        let mock_delay = Duration::from_millis(20);
        let checker_timeout = Duration::from_millis(10);
        assert!(mock_delay > checker_timeout);
        let mock_server = mock_server!(StatusCode::OK, set_delay(mock_delay));
        let client = ClientBuilder::builder()
            .timeout(checker_timeout)
            .max_retries(0u64)
            .build()
            .client()
            .unwrap();
        let res = client.check(mock_server.uri()).await.unwrap();
        assert!(res.status().is_timeout());
    }

    // Three retries with a 50 ms base wait against a server that always exceeds
    // the timeout must fail within a bounded time window.
    #[tokio::test]
    async fn test_exponential_backoff() {
        let mock_delay = Duration::from_millis(20);
        let checker_timeout = Duration::from_millis(10);
        assert!(mock_delay > checker_timeout);
        let mock_server = mock_server!(StatusCode::OK, set_delay(mock_delay));
        // NOTE(review): presumably a warm-up so one-off setup latency does not
        // skew the timing measured below.
        let warm_up_client = ClientBuilder::builder()
            .max_retries(0_u64)
            .build()
            .client()
            .unwrap();
        let _res = warm_up_client.check(mock_server.uri()).await.unwrap();
        let client = ClientBuilder::builder()
            .timeout(checker_timeout)
            .max_retries(3_u64)
            .retry_wait_time(Duration::from_millis(50))
            .build()
            .client()
            .unwrap();
        let start = Instant::now();
        let res = client.check(mock_server.uri()).await.unwrap();
        let end = start.elapsed();
        assert!(res.status().is_error());
        // NOTE(review): window assumes the website checker's backoff grows from
        // the 50 ms base across the three retries — confirm against WebsiteChecker.
        assert!((350..=550).contains(&end.as_millis()));
    }

    // A malformed URL must surface as `Unsupported(BuildRequestClient)` rather
    // than a panic (per the test name, guarding against a panic inside reqwest).
    #[tokio::test]
    async fn test_avoid_reqwest_panic() {
        let client = ClientBuilder::builder().build().client().unwrap();
        let res = client.check("http://\"").await.unwrap();
        assert!(matches!(
            res.status(),
            Status::Unsupported(ErrorKind::BuildRequestClient(_))
        ));
        assert!(res.status().is_unsupported());
    }

    // An endpoint that redirects to itself forever: with `max_redirects = 15`
    // the server must see exactly 1 initial + 15 redirected requests, and the
    // final 308 response is rejected.
    #[tokio::test]
    async fn test_max_redirects() {
        let mock_server = wiremock::MockServer::start().await;
        let redirect_uri = format!("{}/redirect", &mock_server.uri());
        let redirect = wiremock::ResponseTemplate::new(StatusCode::PERMANENT_REDIRECT)
            .insert_header("Location", redirect_uri.as_str());
        let redirect_count = 15usize;
        let initial_invocation = 1;
        Mock::given(method("GET"))
            .and(path("/redirect"))
            .respond_with(move |_: &_| redirect.clone())
            .expect(initial_invocation + redirect_count as u64)
            .mount(&mock_server)
            .await;
        let res = ClientBuilder::builder()
            .max_redirects(redirect_count)
            .build()
            .client()
            .unwrap()
            .check(redirect_uri.clone())
            .await
            .unwrap();
        assert_eq!(
            res.status(),
            &Status::Error(ErrorKind::RejectedStatusCode(
                StatusCode::PERMANENT_REDIRECT
            ))
        );
    }

    // A single followed redirect is recorded in the `Redirected` status.
    #[tokio::test]
    async fn test_redirects() {
        redirecting_mock_server!(async |redirect_url: Url, ok_url| {
            let res = ClientBuilder::builder()
                .max_redirects(1_usize)
                .build()
                .client()
                .unwrap()
                .check(Uri::from((redirect_url).clone()))
                .await
                .unwrap();
            let mut redirects = Redirects::new(redirect_url);
            redirects.push(Redirect {
                url: ok_url,
                code: StatusCode::PERMANENT_REDIRECT,
            });
            assert_eq!(res.status(), &Status::Redirected(StatusCode::OK, redirects));
        })
        .await;
    }

    // Schemes without a dedicated checker are reported as unsupported.
    #[tokio::test]
    async fn test_unsupported_scheme() {
        let examples = vec![
            "ftp://example.com",
            "gopher://example.com",
            "slack://example.com",
        ];
        for example in examples {
            let client = ClientBuilder::builder().build().client().unwrap();
            let res = client.check(example).await.unwrap();
            assert!(res.status().is_unsupported());
        }
    }

    // A plugin handler that short-circuits with `Excluded` determines the final status.
    #[tokio::test]
    async fn test_chain() {
        use reqwest::Request;
        #[derive(Debug)]
        struct ExampleHandler();
        #[async_trait]
        impl Handler<Request, Status> for ExampleHandler {
            async fn handle(&mut self, _: Request) -> ChainResult<Request, Status> {
                ChainResult::Done(Status::Excluded)
            }
        }
        let chain = RequestChain::new(vec![Box::new(ExampleHandler {})]);
        let client = ClientBuilder::builder()
            .plugin_request_chain(chain)
            .build()
            .client()
            .unwrap();
        let result = client.check("http://example.com");
        let res = result.await.unwrap();
        assert_eq!(res.status(), &Status::Excluded);
    }
}