use std::error::Error;
use std::fmt::Debug;
use std::fmt::Write;
use std::num::ParseIntError;
use std::sync::Arc;
use std::time::{Duration, SystemTime, SystemTimeError};
use std::{env, io, iter};
use anyhow::anyhow;
use http::{
HeaderMap, HeaderName, HeaderValue, Method, StatusCode,
header::{
AUTHORIZATION, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_TYPE, COOKIE, LOCATION,
PROXY_AUTHORIZATION, REFERER, TRANSFER_ENCODING, WWW_AUTHENTICATE,
},
};
use itertools::Itertools;
use reqwest::{
Certificate, Client, ClientBuilder, IntoUrl, NoProxy, Proxy, Request, Response, multipart,
};
use reqwest_middleware::{ClientWithMiddleware, Middleware};
use reqwest_retry::policies::ExponentialBackoff;
use reqwest_retry::{
Jitter, RetryPolicy, RetryTransientMiddleware, Retryable, RetryableStrategy,
default_on_request_error, default_on_request_success,
};
use thiserror::Error;
use tracing::{debug, trace, warn};
use url::ParseError;
use url::Url;
use uv_auth::{AuthMiddleware, Credentials, CredentialsCache, Indexes, PyxTokenStore};
use uv_configuration::ProxyUrlKind;
use uv_configuration::{KeyringProviderType, ProxyUrl, TrustedHost};
use uv_pep508::MarkerEnvironment;
use uv_platform_tags::Platform;
use uv_preview::Preview;
use uv_redacted::DisplaySafeUrl;
use uv_redacted::DisplaySafeUrlError;
use uv_static::EnvVars;
use uv_version::version;
use uv_warnings::warn_user_once;
use crate::linehaul::LineHaul;
use crate::middleware::OfflineMiddleware;
use crate::tls::{Certificates, read_identity};
use crate::{Connectivity, WrappedReqwestError};
/// Default number of retries; used as the `retries` value of a default `BaseClientBuilder`.
pub const DEFAULT_RETRIES: u32 = 3;
/// Upper bound on the number of redirects followed by the manual redirect loop in
/// `RedirectClientWithMiddleware::execute_with_redirect_handling`.
pub const DEFAULT_MAX_REDIRECTS: u32 = 10;
/// Default read timeout (30 seconds) of a default `BaseClientBuilder`.
pub const DEFAULT_READ_TIMEOUT: Duration = Duration::from_secs(30);
/// Default connect timeout (10 seconds) of a default `BaseClientBuilder`.
pub const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
/// A 15 minute read timeout.
// NOTE(review): presumably meant for upload requests; it is not referenced in this part of the
// file, so confirm against the callers.
pub const DEFAULT_READ_TIMEOUT_UPLOAD: Duration = Duration::from_mins(15);
/// Error returned when the underlying `reqwest` client cannot be constructed.
///
/// The wrapped `reqwest::Error` is exposed as the error source.
#[derive(Debug, Error)]
#[error("failed to build HTTP client")]
pub struct ClientBuildError(#[source] reqwest::Error);
/// Allows `?` on `reqwest` builder results inside functions returning `ClientBuildError`.
impl From<reqwest::Error> for ClientBuildError {
fn from(error: reqwest::Error) -> Self {
Self(error)
}
}
/// Selects how (and whether) the authentication middleware is added to a client.
#[derive(Debug, Clone, Copy, Default)]
pub enum AuthIntegration {
/// Install the authentication middleware with its standard configuration.
#[default]
Default,
/// Install the authentication middleware with `with_only_authenticated(true)`.
OnlyAuthenticated,
/// Do not install the authentication middleware at all.
NoAuthMiddleware,
}
/// Builder for a [`BaseClient`].
///
/// Collects TLS, proxy, retry, redirect and authentication settings, and produces a pair of
/// clients (certificate-verifying and non-verifying) wrapped in the configured middleware.
#[derive(Debug, Clone)]
pub struct BaseClientBuilder<'a> {
// Keyring provider handed to the authentication middleware.
keyring: KeyringProviderType,
// Preview settings handed to the authentication middleware.
preview: Preview,
// Hosts for which the client that accepts invalid certificates is used.
allow_insecure_host: Vec<TrustedHost>,
// When `true` (and no custom certificates come from the environment), the TLS builder's own
// root configuration is left untouched instead of being restricted to the bundled webpki roots.
system_certs: bool,
// Maximum number of retries; `0` disables the retry middleware.
retries: u32,
/// Whether requests are allowed to reach the network.
pub connectivity: Connectivity,
// Marker environment and platform, used only to build the linehaul part of the user agent.
markers: Option<&'a MarkerEnvironment>,
platform: Option<&'a Platform>,
// How the authentication middleware is installed.
auth_integration: AuthIntegration,
// Credentials cache shared with the authentication middleware and the built client.
credentials_cache: Arc<CredentialsCache>,
// Index definitions handed to the authentication middleware.
indexes: Indexes,
// Read and connect timeouts applied to newly created `reqwest` clients.
read_timeout: Duration,
connect_timeout: Duration,
// Additional middleware appended after the retry middleware.
extra_middleware: Option<ExtraMiddleware>,
// Explicit proxies, registered before `http_proxy` / `https_proxy`.
proxies: Vec<Proxy>,
// Proxy URLs for HTTP and HTTPS traffic; both honor `no_proxy`.
http_proxy: Option<ProxyUrl>,
https_proxy: Option<ProxyUrl>,
// Entries joined with `,` and parsed into a `NoProxy` exclusion list.
no_proxy: Option<Vec<String>>,
// Redirect handling strategy.
redirect_policy: RedirectPolicy,
// Whether sensitive headers are stripped on cross-origin redirects.
cross_origin_credential_policy: CrossOriginCredentialsPolicy,
// When set, used verbatim for both the secure and the insecure raw client.
custom_client: Option<Client>,
// Subcommand forwarded to the linehaul part of the user agent.
subcommand: Option<Vec<String>>,
// Optional label included in the timeout debug log emitted by `build`.
client_name: Option<&'static str>,
// When `true`, retries use zero-length backoff bounds.
no_retry_delay: bool,
}
/// How redirects are followed.
#[derive(Debug, Default, Clone, Copy)]
pub enum RedirectPolicy {
/// Let `reqwest` follow redirects itself (its default policy); the middleware stack only sees
/// the initial request.
#[default]
BypassMiddleware,
/// Disable `reqwest`'s redirect handling and follow redirects manually, so that every hop is
/// sent through the middleware stack again.
RetriggerMiddleware,
/// Do not follow redirects at all.
NoRedirect,
}
impl RedirectPolicy {
    /// Returns the `reqwest` redirect policy matching this variant.
    ///
    /// Only `BypassMiddleware` lets `reqwest` follow redirects; the other two variants turn
    /// `reqwest`'s own redirect handling off.
    pub fn reqwest_policy(self) -> reqwest::redirect::Policy {
        use reqwest::redirect::Policy;

        match self {
            Self::BypassMiddleware => Policy::default(),
            Self::RetriggerMiddleware | Self::NoRedirect => Policy::none(),
        }
    }
}
/// Additional middleware to append to the client's middleware stack.
#[derive(Clone)]
pub struct ExtraMiddleware(pub Vec<Arc<dyn Middleware>>);

/// Middleware objects are not required to be `Debug`, so only their count is printed.
impl Debug for ExtraMiddleware {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(middlewares) = self;
        let summary = format!("{} middlewares", middlewares.len());

        let mut output = f.debug_struct("ExtraMiddleware");
        output.field("0", &summary);
        output.finish()
    }
}
impl Default for BaseClientBuilder<'_> {
    /// Builds an online client configuration with the default timeouts and retry count, no
    /// proxies, no extra middleware and the default authentication integration.
    ///
    /// Retry delays are disabled when the `UV_TEST_NO_HTTP_RETRY_DELAY` environment variable is
    /// set (to any value).
    fn default() -> Self {
        let no_retry_delay = env::var_os(EnvVars::UV_TEST_NO_HTTP_RETRY_DELAY).is_some();

        Self {
            // Connectivity, TLS and timeouts.
            connectivity: Connectivity::Online,
            system_certs: false,
            allow_insecure_host: Vec::new(),
            read_timeout: DEFAULT_READ_TIMEOUT,
            connect_timeout: DEFAULT_CONNECT_TIMEOUT,
            custom_client: None,

            // Retries and redirects.
            retries: DEFAULT_RETRIES,
            no_retry_delay,
            redirect_policy: RedirectPolicy::default(),
            cross_origin_credential_policy: CrossOriginCredentialsPolicy::default(),

            // Authentication.
            keyring: KeyringProviderType::default(),
            preview: Preview::default(),
            auth_integration: AuthIntegration::default(),
            credentials_cache: Arc::default(),
            indexes: Indexes::new(),

            // Proxies.
            proxies: Vec::new(),
            http_proxy: None,
            https_proxy: None,
            no_proxy: None,

            // User agent and logging metadata.
            markers: None,
            platform: None,
            subcommand: None,
            client_name: None,

            extra_middleware: None,
        }
    }
}
impl<'a> BaseClientBuilder<'a> {
/// Creates a builder from the most commonly customized settings; every other setting takes
/// its value from [`Default::default`].
pub fn new(
    connectivity: Connectivity,
    system_certs: bool,
    allow_insecure_host: Vec<TrustedHost>,
    preview: Preview,
    read_timeout: Duration,
    connect_timeout: Duration,
    retries: u32,
) -> Self {
    let defaults = Self::default();
    Self {
        connectivity,
        system_certs,
        allow_insecure_host,
        preview,
        read_timeout,
        connect_timeout,
        retries,
        ..defaults
    }
}
#[must_use]
pub fn custom_client(mut self, client: Client) -> Self {
self.custom_client = Some(client);
self
}
#[must_use]
pub fn keyring(mut self, keyring_type: KeyringProviderType) -> Self {
self.keyring = keyring_type;
self
}
#[must_use]
pub fn allow_insecure_host(mut self, allow_insecure_host: Vec<TrustedHost>) -> Self {
self.allow_insecure_host = allow_insecure_host;
self
}
#[must_use]
pub fn connectivity(mut self, connectivity: Connectivity) -> Self {
self.connectivity = connectivity;
self
}
#[must_use]
pub fn retries(mut self, retries: u32) -> Self {
self.retries = retries;
self
}
#[must_use]
pub fn no_retry_delay(mut self, no_retry_delay: bool) -> Self {
self.no_retry_delay = no_retry_delay;
self
}
/// Returns the current `system_certs` setting (this is a getter; use `with_system_certs` to
/// change it).
#[must_use]
pub fn system_certs(&self) -> bool {
self.system_certs
}
#[must_use]
pub fn with_system_certs(mut self, system_certs: bool) -> Self {
self.system_certs = system_certs;
self
}
#[must_use]
pub fn markers(mut self, markers: &'a MarkerEnvironment) -> Self {
self.markers = Some(markers);
self
}
#[must_use]
pub fn platform(mut self, platform: &'a Platform) -> Self {
self.platform = Some(platform);
self
}
#[must_use]
pub fn auth_integration(mut self, auth_integration: AuthIntegration) -> Self {
self.auth_integration = auth_integration;
self
}
#[must_use]
pub fn indexes(mut self, indexes: Indexes) -> Self {
self.indexes = indexes;
self
}
#[must_use]
pub fn read_timeout(mut self, read_timeout: Duration) -> Self {
self.read_timeout = read_timeout;
self
}
#[must_use]
pub fn connect_timeout(mut self, connect_timeout: Duration) -> Self {
self.connect_timeout = connect_timeout;
self
}
#[must_use]
pub fn extra_middleware(mut self, middleware: ExtraMiddleware) -> Self {
self.extra_middleware = Some(middleware);
self
}
#[must_use]
pub fn proxy(mut self, proxy: Proxy) -> Self {
self.proxies.push(proxy);
self
}
#[must_use]
pub fn http_proxy(mut self, http_proxy: Option<ProxyUrl>) -> Self {
self.http_proxy = http_proxy;
self
}
#[must_use]
pub fn https_proxy(mut self, https_proxy: Option<ProxyUrl>) -> Self {
self.https_proxy = https_proxy;
self
}
#[must_use]
pub fn no_proxy(mut self, no_proxy: Option<Vec<String>>) -> Self {
self.no_proxy = no_proxy;
self
}
#[must_use]
pub fn redirect(mut self, policy: RedirectPolicy) -> Self {
self.redirect_policy = policy;
self
}
/// Test-only: keeps sensitive headers on cross-origin redirects.
#[cfg(test)]
#[must_use]
pub fn allow_cross_origin_credentials(self) -> Self {
    Self {
        cross_origin_credential_policy: CrossOriginCredentialsPolicy::Insecure,
        ..self
    }
}
#[must_use]
pub fn subcommand(mut self, subcommand: Vec<String>) -> Self {
self.subcommand = Some(subcommand);
self
}
#[must_use]
pub fn client_name(mut self, name: &'static str) -> Self {
self.client_name = Some(name);
self
}
/// Returns the credentials cache that built clients will share.
pub fn credentials_cache(&self) -> &CredentialsCache {
&self.credentials_cache
}
/// Forwards `url` to `CredentialsCache::store_credentials_from_url` on the shared cache and
/// returns its result.
// NOTE(review): presumably `true` means the URL carried credentials that were stored; confirm
// against `CredentialsCache`.
pub fn store_credentials_from_url(&self, url: &DisplaySafeUrl) -> bool {
self.credentials_cache.store_credentials_from_url(url)
}
/// Stores `credentials` for `url` in the shared credentials cache.
pub fn store_credentials(&self, url: &DisplaySafeUrl, credentials: Credentials) {
self.credentials_cache.store_credentials(url, credentials);
}
/// Returns `true` when the builder is configured for offline operation.
pub fn is_offline(&self) -> bool {
    match self.connectivity {
        Connectivity::Offline => true,
        Connectivity::Online => false,
    }
}
/// Returns the exponential backoff policy derived from `retries` and `no_retry_delay`.
pub fn retry_policy(&self) -> ExponentialBackoff {
retry_policy(self.retries, self.no_retry_delay)
}
pub fn build(&self) -> Result<BaseClient, ClientBuildError> {
if let Some(name) = self.client_name {
debug!(
"Using request connect timeout of {}s and read timeout of {}s for {} client",
self.connect_timeout.as_secs(),
self.read_timeout.as_secs(),
name
);
} else {
debug!(
"Using request connect timeout of {}s and read timeout of {}s",
self.connect_timeout.as_secs(),
self.read_timeout.as_secs()
);
}
let (raw_client, raw_dangerous_client) = match &self.custom_client {
Some(client) => (client.clone(), client.clone()),
None => {
self.create_secure_and_insecure_clients(self.read_timeout, self.connect_timeout)?
}
};
let client = RedirectClientWithMiddleware {
client: self.apply_middleware(raw_client.clone()),
redirect_policy: self.redirect_policy,
cross_origin_credentials_policy: self.cross_origin_credential_policy,
};
let dangerous_client = RedirectClientWithMiddleware {
client: self.apply_middleware(raw_dangerous_client.clone()),
redirect_policy: self.redirect_policy,
cross_origin_credentials_policy: self.cross_origin_credential_policy,
};
Ok(BaseClient {
connectivity: self.connectivity,
allow_insecure_host: self.allow_insecure_host.clone(),
retries: self.retries,
no_retry_delay: self.no_retry_delay,
client,
raw_client,
dangerous_client,
raw_dangerous_client,
read_timeout: self.read_timeout,
connect_timeout: self.connect_timeout,
credentials_cache: self.credentials_cache.clone(),
})
}
pub fn wrap_existing(&self, existing: &BaseClient) -> BaseClient {
let client = RedirectClientWithMiddleware {
client: self.apply_middleware(existing.raw_client.clone()),
redirect_policy: self.redirect_policy,
cross_origin_credentials_policy: self.cross_origin_credential_policy,
};
let dangerous_client = RedirectClientWithMiddleware {
client: self.apply_middleware(existing.raw_dangerous_client.clone()),
redirect_policy: self.redirect_policy,
cross_origin_credentials_policy: self.cross_origin_credential_policy,
};
BaseClient {
connectivity: self.connectivity,
allow_insecure_host: self.allow_insecure_host.clone(),
retries: self.retries,
no_retry_delay: self.no_retry_delay,
client,
dangerous_client,
raw_client: existing.raw_client.clone(),
raw_dangerous_client: existing.raw_dangerous_client.clone(),
read_timeout: existing.read_timeout,
connect_timeout: existing.connect_timeout,
credentials_cache: existing.credentials_cache.clone(),
}
}
fn create_secure_and_insecure_clients(
&self,
read_timeout: Duration,
connect_timeout: Duration,
) -> Result<(Client, Client), ClientBuildError> {
let mut user_agent_string = format!("uv/{}", version());
let linehaul = LineHaul::new(self.markers, self.platform, self.subcommand.clone());
if let Ok(output) = serde_json::to_string(&linehaul) {
let _ = write!(user_agent_string, " {output}");
}
let custom_certs = Certificates::from_env().map(|certs| certs.to_reqwest_certs());
let raw_client = self.create_client(
&user_agent_string,
read_timeout,
connect_timeout,
custom_certs.clone(),
Security::Secure,
self.redirect_policy,
)?;
let raw_dangerous_client = self.create_client(
&user_agent_string,
read_timeout,
connect_timeout,
custom_certs,
Security::Insecure,
self.redirect_policy,
)?;
Ok((raw_client, raw_dangerous_client))
}
/// Builds a single raw `reqwest` client.
///
/// `security` selects whether invalid certificates are accepted; `custom_certs`, when present,
/// become the only trusted certificates.
///
/// # Errors
///
/// Returns [`ClientBuildError`] if `reqwest` fails to build the client.
fn create_client(
&self,
user_agent: &str,
read_timeout: Duration,
connect_timeout: Duration,
custom_certs: Option<Vec<Certificate>>,
security: Security,
redirect_policy: RedirectPolicy,
) -> Result<Client, ClientBuildError> {
// Base configuration shared by the secure and the insecure client.
let client_builder = ClientBuilder::new()
.http1_title_case_headers()
.user_agent(user_agent)
.pool_max_idle_per_host(20)
.read_timeout(read_timeout)
.connect_timeout(connect_timeout)
.redirect(redirect_policy.reqwest_policy());
// Only the insecure client skips certificate validation.
let client_builder = match security {
Security::Secure => client_builder,
Security::Insecure => client_builder.danger_accept_invalid_certs(true),
};
let client_builder = client_builder.tls_backend_rustls();
// Root certificate selection, in priority order: certificates from the environment, then
// (with `system_certs`) whatever the builder uses when left untouched, then the bundled
// webpki roots.
let client_builder = if let Some(custom_certs) = custom_certs {
client_builder.tls_certs_only(custom_certs)
} else if self.system_certs {
client_builder
} else {
client_builder.tls_certs_only(Certificates::webpki_roots().to_reqwest_certs())
};
// Optional client certificate; an unreadable identity is reported once and then ignored
// rather than failing the build.
let client_builder = if let Some(ssl_client_cert) = env::var_os(EnvVars::SSL_CLIENT_CERT) {
match read_identity(&ssl_client_cert) {
Ok(identity) => client_builder.identity(identity),
Err(err) => {
warn_user_once!("Ignoring invalid `SSL_CLIENT_CERT`: {err}");
client_builder
}
}
} else {
client_builder
};
// Explicit proxies are registered first, followed by the HTTP and HTTPS proxy URLs.
let mut client_builder = client_builder;
for p in &self.proxies {
client_builder = client_builder.proxy(p.clone());
}
// The exclusion list applies only to `http_proxy` / `https_proxy`, not to the explicit
// proxies added above.
let no_proxy = self
.no_proxy
.as_ref()
.and_then(|no_proxy| NoProxy::from_string(&no_proxy.join(",")));
if let Some(http_proxy) = &self.http_proxy {
let proxy = http_proxy
.as_proxy(ProxyUrlKind::Http)
.no_proxy(no_proxy.clone());
client_builder = client_builder.proxy(proxy);
}
if let Some(https_proxy) = &self.https_proxy {
let proxy = https_proxy.as_proxy(ProxyUrlKind::Https).no_proxy(no_proxy);
client_builder = client_builder.proxy(proxy);
}
client_builder.build().map_err(Into::into)
}
fn apply_middleware(&self, client: Client) -> ClientWithMiddleware {
match self.connectivity {
Connectivity::Online => {
let base_client = {
let mut client = reqwest_middleware::ClientBuilder::new(client.clone());
if self.retries > 0 {
let retry_strategy = RetryTransientMiddleware::new_with_policy_and_strategy(
self.retry_policy(),
UvRetryableStrategy,
);
client = client.with(retry_strategy);
}
if let Some(extra_middleware) = &self.extra_middleware {
for middleware in &extra_middleware.0 {
client = client.with_arc(middleware.clone());
}
}
client.build()
};
let mut client = reqwest_middleware::ClientBuilder::new(client);
if self.retries > 0 {
let retry_strategy = RetryTransientMiddleware::new_with_policy_and_strategy(
self.retry_policy(),
UvRetryableStrategy,
);
client = client.with(retry_strategy);
}
if let Some(extra_middleware) = &self.extra_middleware {
for middleware in &extra_middleware.0 {
client = client.with_arc(middleware.clone());
}
}
match self.auth_integration {
AuthIntegration::Default => {
let mut auth_middleware = AuthMiddleware::new()
.with_cache_arc(self.credentials_cache.clone())
.with_base_client(base_client)
.with_indexes(self.indexes.clone())
.with_keyring(self.keyring.to_provider())
.with_preview(self.preview);
if let Ok(token_store) = PyxTokenStore::from_settings() {
auth_middleware = auth_middleware.with_pyx_token_store(token_store);
}
client = client.with(auth_middleware);
}
AuthIntegration::OnlyAuthenticated => {
let mut auth_middleware = AuthMiddleware::new()
.with_cache_arc(self.credentials_cache.clone())
.with_base_client(base_client)
.with_indexes(self.indexes.clone())
.with_keyring(self.keyring.to_provider())
.with_preview(self.preview)
.with_only_authenticated(true);
if let Ok(token_store) = PyxTokenStore::from_settings() {
auth_middleware = auth_middleware.with_pyx_token_store(token_store);
}
client = client.with(auth_middleware);
}
AuthIntegration::NoAuthMiddleware => {
}
}
client.build()
}
Connectivity::Offline => reqwest_middleware::ClientBuilder::new(client)
.with(OfflineMiddleware)
.build(),
}
}
}
/// An HTTP client pair (certificate-verifying and non-verifying) together with the settings
/// it was built from.
#[derive(Debug, Clone)]
pub struct BaseClient {
// Middleware-wrapped client that verifies TLS certificates.
client: RedirectClientWithMiddleware,
// Middleware-wrapped client used for hosts listed in `allow_insecure_host`.
dangerous_client: RedirectClientWithMiddleware,
// The underlying clients without middleware, kept so they can be re-wrapped later.
raw_client: Client,
raw_dangerous_client: Client,
// Whether the client may reach the network.
connectivity: Connectivity,
// Timeouts the raw clients were configured with.
read_timeout: Duration,
connect_timeout: Duration,
// Hosts for which `dangerous_client` is selected.
allow_insecure_host: Vec<TrustedHost>,
// Retry settings, used to recreate the retry policy on demand.
retries: u32,
no_retry_delay: bool,
// Credentials cache shared with the builder that produced this client.
credentials_cache: Arc<CredentialsCache>,
}
/// Whether a raw client validates TLS certificates.
#[derive(Debug, Clone, Copy)]
enum Security {
/// Certificates are validated.
Secure,
/// Invalid certificates are accepted (`danger_accept_invalid_certs(true)`).
Insecure,
}
impl BaseClient {
    /// Selects the client to use for `url`: the certificate-accepting client for hosts that
    /// were marked as insecure, the verifying client otherwise.
    pub fn for_host(&self, url: &DisplaySafeUrl) -> &RedirectClientWithMiddleware {
        if !self.disable_ssl(url) {
            return &self.client;
        }
        &self.dangerous_client
    }

    /// Executes `req` with the client appropriate for the request's URL.
    pub async fn execute(&self, req: Request) -> reqwest_middleware::Result<Response> {
        let target = DisplaySafeUrl::from_url(req.url().clone());
        self.for_host(&target).execute(req).await
    }

    /// Returns `true` if `url` matches one of the hosts for which invalid certificates are
    /// accepted.
    pub fn disable_ssl(&self, url: &DisplaySafeUrl) -> bool {
        for trusted_host in &self.allow_insecure_host {
            if trusted_host.matches(url) {
                return true;
            }
        }
        false
    }

    /// The read timeout of the underlying clients.
    pub fn read_timeout(&self) -> Duration {
        self.read_timeout
    }

    /// The connect timeout of the underlying clients.
    pub fn connect_timeout(&self) -> Duration {
        self.connect_timeout
    }

    /// Whether the client may reach the network.
    pub fn connectivity(&self) -> Connectivity {
        self.connectivity
    }

    /// The exponential backoff policy matching this client's retry settings.
    pub fn retry_policy(&self) -> ExponentialBackoff {
        retry_policy(self.retries, self.no_retry_delay)
    }

    /// The credentials cache shared with the authentication middleware.
    pub fn credentials_cache(&self) -> &CredentialsCache {
        &self.credentials_cache
    }

    /// The certificate-verifying `reqwest` client without any middleware.
    pub fn raw_client(&self) -> &Client {
        &self.raw_client
    }
}
/// A middleware client bundled with the redirect settings that govern how `execute` follows
/// redirects.
#[derive(Debug, Clone)]
pub struct RedirectClientWithMiddleware {
// The wrapped middleware client.
client: ClientWithMiddleware,
// Decides whether redirects are followed manually (through the middleware) or not.
redirect_policy: RedirectPolicy,
// Decides whether sensitive headers survive cross-origin redirects.
cross_origin_credentials_policy: CrossOriginCredentialsPolicy,
}
impl RedirectClientWithMiddleware {
    /// Starts a `GET` request to `url`.
    pub fn get<U: IntoUrl>(&self, url: U) -> RequestBuilder<'_> {
        let inner = self.client.get(url);
        RequestBuilder::new(inner, self)
    }

    /// Starts a `POST` request to `url`.
    pub fn post<U: IntoUrl>(&self, url: U) -> RequestBuilder<'_> {
        let inner = self.client.post(url);
        RequestBuilder::new(inner, self)
    }

    /// Starts a `HEAD` request to `url`.
    pub fn head<U: IntoUrl>(&self, url: U) -> RequestBuilder<'_> {
        let inner = self.client.head(url);
        RequestBuilder::new(inner, self)
    }

    /// Executes `req`, following redirects manually only for
    /// `RedirectPolicy::RetriggerMiddleware`; otherwise the request is handed straight to the
    /// wrapped client.
    pub async fn execute(&self, req: Request) -> reqwest_middleware::Result<Response> {
        match self.redirect_policy {
            RedirectPolicy::RetriggerMiddleware => self.execute_with_redirect_handling(req).await,
            RedirectPolicy::BypassMiddleware | RedirectPolicy::NoRedirect => {
                self.client.execute(req).await
            }
        }
    }

    /// Executes `req` and follows up to `DEFAULT_MAX_REDIRECTS` redirects, sending every hop
    /// through the middleware stack.
    ///
    /// Once the redirect budget is used up, the latest response is returned as is, even if it
    /// is another redirect.
    ///
    /// # Panics
    ///
    /// Panics if a request cannot be cloned.
    async fn execute_with_redirect_handling(
        &self,
        req: Request,
    ) -> reqwest_middleware::Result<Response> {
        let mut remaining_redirects = DEFAULT_MAX_REDIRECTS;
        let mut current = req;

        loop {
            // Keep the original around: it is needed to derive the follow-up request.
            let attempt = current
                .try_clone()
                .expect("HTTP request must be cloneable");
            let response = self.client.execute(attempt).await?;

            if remaining_redirects == 0 {
                return Ok(response);
            }

            match request_into_redirect(current, &response, self.cross_origin_credentials_policy)?
            {
                Some(next) => {
                    remaining_redirects -= 1;
                    current = next;
                }
                None => return Ok(response),
            }
        }
    }

    /// The wrapped middleware client, which does not apply the manual redirect handling.
    pub fn raw_client(&self) -> &ClientWithMiddleware {
        &self.client
    }
}
impl From<RedirectClientWithMiddleware> for ClientWithMiddleware {
fn from(item: RedirectClientWithMiddleware) -> Self {
item.client
}
}
/// Turns `req` into the follow-up request for the redirect described by `res`.
///
/// Returns `Ok(None)` when `res` is not a redirect (anything other than 301, 302, 303, 307 or
/// 308), and `Ok(Some(request))` with the request to send next otherwise.
///
/// For a redirect the request is adjusted as follows:
/// * On 303 the body and the body-describing headers are dropped and any method other than
///   `GET`/`HEAD` becomes `GET`.
/// * The target is taken from `Location`, resolved against the original URL when relative.
/// * The original URL's fragment is inherited when `Location` does not carry one.
/// * On a change of host or port, sensitive headers are removed (under the `Secure` policy);
///   otherwise an existing `Referer` header is refreshed.
/// * Credentials embedded in the target URL are moved into an `Authorization` header.
///
/// # Errors
///
/// Returns a middleware error if `Location` is missing, is not visible ASCII, cannot be parsed
/// as a URL or an HTTP URI, or uses a scheme other than `http`/`https`.
fn request_into_redirect(
    mut req: Request,
    res: &Response,
    cross_origin_credentials_policy: CrossOriginCredentialsPolicy,
) -> reqwest_middleware::Result<Option<Request>> {
    let original_req_url = DisplaySafeUrl::from_url(req.url().clone());
    let status = res.status();
    let should_redirect = match status {
        StatusCode::MOVED_PERMANENTLY
        | StatusCode::FOUND
        | StatusCode::TEMPORARY_REDIRECT
        | StatusCode::PERMANENT_REDIRECT => true,
        StatusCode::SEE_OTHER => {
            // A 303 asks for the target to be retrieved without a body: drop the body and the
            // headers describing it, and downgrade the method to `GET` unless it already is
            // `GET` or `HEAD`.
            *req.body_mut() = None;
            for header in &[
                TRANSFER_ENCODING,
                CONTENT_ENCODING,
                CONTENT_TYPE,
                CONTENT_LENGTH,
            ] {
                req.headers_mut().remove(header);
            }
            match *req.method() {
                Method::GET | Method::HEAD => {}
                _ => {
                    *req.method_mut() = Method::GET;
                }
            }
            true
        }
        _ => false,
    };
    if !should_redirect {
        return Ok(None);
    }

    // Build the error lazily: a `Location` header is present on the common path.
    let location = res
        .headers()
        .get(LOCATION)
        .ok_or_else(|| {
            reqwest_middleware::Error::Middleware(anyhow!(
                "Server returned redirect (HTTP {status}) without destination URL. This may indicate a server configuration issue"
            ))
        })?
        .to_str()
        .map_err(|_| {
            reqwest_middleware::Error::Middleware(anyhow!(
                "Invalid HTTP {status} 'Location' value: must only contain visible ascii characters"
            ))
        })?;

    let mut redirect_url = match DisplaySafeUrl::parse(location) {
        Ok(url) => url,
        // A relative `Location` is resolved against the URL of the original request.
        Err(DisplaySafeUrlError::Url(ParseError::RelativeUrlWithoutBase)) => {
            original_req_url.join(location).map_err(|err| {
                reqwest_middleware::Error::Middleware(anyhow!(
                    "Invalid HTTP {status} 'Location' value `{location}` relative to `{original_req_url}`: {err}"
                ))
            })?
        }
        Err(err) => {
            return Err(reqwest_middleware::Error::Middleware(anyhow!(
                "Invalid HTTP {status} 'Location' value `{location}`: {err}"
            )));
        }
    };

    // Per RFC 7231, section 7.1.2, the redirect inherits the fragment of the original request
    // only if the `Location` value does not have a fragment of its own.
    if redirect_url.fragment().is_none() {
        if let Some(fragment) = original_req_url.fragment() {
            redirect_url.set_fragment(Some(fragment));
        }
    }

    if let Err(err) = redirect_url.as_str().parse::<http::Uri>() {
        return Err(reqwest_middleware::Error::Middleware(anyhow!(
            "HTTP {status} 'Location' value `{redirect_url}` is not a valid HTTP URI: {err}"
        )));
    }
    if redirect_url.scheme() != "http" && redirect_url.scheme() != "https" {
        return Err(reqwest_middleware::Error::Middleware(anyhow!(
            "Invalid HTTP {status} 'Location' value `{redirect_url}`: scheme needs to be https or http"
        )));
    }

    // Take the headers out of the request so they can be edited freely; they are swapped back
    // below.
    let mut headers = HeaderMap::new();
    std::mem::swap(req.headers_mut(), &mut headers);

    // A redirect counts as cross-origin when the host or the (effective) port changes.
    let cross_host = redirect_url.host_str() != original_req_url.host_str()
        || redirect_url.port_or_known_default() != original_req_url.port_or_known_default();
    if cross_host {
        if cross_origin_credentials_policy == CrossOriginCredentialsPolicy::Secure {
            debug!("Received a cross-origin redirect. Removing sensitive headers.");
            headers.remove(AUTHORIZATION);
            headers.remove(COOKIE);
            headers.remove(PROXY_AUTHORIZATION);
            headers.remove(WWW_AUTHENTICATE);
        }
    } else if headers.contains_key(REFERER) {
        // Only refresh a `Referer` header that the request already carried.
        if let Some(referer) = make_referer(&redirect_url, &original_req_url) {
            headers.insert(REFERER, referer);
        }
    }

    // Credentials embedded in the redirect URL are moved into an `Authorization` header and
    // removed from the URL.
    if !redirect_url.username().is_empty() {
        if let Some(credentials) = Credentials::from_url(&redirect_url) {
            let _ = redirect_url.set_username("");
            let _ = redirect_url.set_password(None);
            headers.insert(AUTHORIZATION, credentials.to_header_value());
        }
    }

    std::mem::swap(req.headers_mut(), &mut headers);
    *req.url_mut() = Url::from(redirect_url);
    debug!(
        "Received HTTP {status}. Redirecting to {}",
        DisplaySafeUrl::ref_cast(req.url())
    );
    Ok(Some(req))
}
/// Builds the `Referer` header value for a redirect from `original_url` to `redirect_url`.
///
/// Returns `None` when redirecting from `https` to `http`, or when the sanitized URL is not a
/// valid header value. Credentials and the fragment are removed from the referer.
fn make_referer(
    redirect_url: &DisplaySafeUrl,
    original_url: &DisplaySafeUrl,
) -> Option<HeaderValue> {
    let is_downgrade = redirect_url.scheme() == "http" && original_url.scheme() == "https";
    if is_downgrade {
        None
    } else {
        let mut sanitized = original_url.clone();
        sanitized.remove_credentials();
        sanitized.set_fragment(None);
        sanitized.as_str().parse::<HeaderValue>().ok()
    }
}
/// Whether sensitive headers are kept when a redirect leaves the original host or port.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub(crate) enum CrossOriginCredentialsPolicy {
/// Remove `Authorization`, `Cookie`, `Proxy-Authorization` and `WWW-Authenticate` on
/// cross-origin redirects.
#[default]
Secure,
/// Test-only: leave all headers in place on cross-origin redirects.
#[cfg(test)]
Insecure,
}
/// A request builder that remembers the [`RedirectClientWithMiddleware`] it came from, so that
/// `send` goes through that client's redirect handling.
#[derive(Debug)]
#[must_use]
pub struct RequestBuilder<'a> {
// The wrapped middleware request builder.
builder: reqwest_middleware::RequestBuilder,
// The client used to execute the request in `send`.
client: &'a RedirectClientWithMiddleware,
}
impl<'a> RequestBuilder<'a> {
    /// Pairs a middleware request builder with the client that will execute it.
    pub fn new(
        builder: reqwest_middleware::RequestBuilder,
        client: &'a RedirectClientWithMiddleware,
    ) -> Self {
        Self { builder, client }
    }

    /// Adds a header to the request.
    pub fn header<K, V>(self, key: K, value: V) -> Self
    where
        HeaderName: TryFrom<K>,
        <HeaderName as TryFrom<K>>::Error: Into<http::Error>,
        HeaderValue: TryFrom<V>,
        <HeaderValue as TryFrom<V>>::Error: Into<http::Error>,
    {
        let Self { builder, client } = self;
        Self {
            builder: builder.header(key, value),
            client,
        }
    }

    /// Adds all of `headers` to the request.
    pub fn headers(self, headers: HeaderMap) -> Self {
        let Self { builder, client } = self;
        Self {
            builder: builder.headers(headers),
            client,
        }
    }

    /// Sets the HTTP version of the request.
    #[cfg(not(target_arch = "wasm32"))]
    pub fn version(self, version: reqwest::Version) -> Self {
        let Self { builder, client } = self;
        Self {
            builder: builder.version(version),
            client,
        }
    }

    /// Sends a multipart form as the request body.
    #[cfg_attr(docsrs, doc(cfg(feature = "multipart")))]
    pub fn multipart(self, multipart: multipart::Form) -> Self {
        let Self { builder, client } = self;
        Self {
            builder: builder.multipart(multipart),
            client,
        }
    }

    /// Finalizes the builder into a `Request` without sending it.
    pub fn build(self) -> reqwest::Result<Request> {
        let Self { builder, .. } = self;
        builder.build()
    }

    /// Builds the request and executes it with the originating client, including that client's
    /// redirect handling.
    pub async fn send(self) -> reqwest_middleware::Result<Response> {
        let Self { builder, client } = self;
        let request = builder.build()?;
        client.execute(request).await
    }

    /// Borrows the wrapped middleware request builder.
    pub fn raw_builder(&self) -> &reqwest_middleware::RequestBuilder {
        &self.builder
    }
}
/// Creates the exponential backoff policy for at most `retries` retries.
///
/// With `no_retry_delay` the backoff bounds are both zero; otherwise bounded jitter is used
/// with backoff bounds of 2 and 30 seconds.
fn retry_policy(retries: u32, no_retry_delay: bool) -> ExponentialBackoff {
    let builder = ExponentialBackoff::builder();
    let builder = if no_retry_delay {
        builder.retry_bounds(Duration::ZERO, Duration::ZERO)
    } else {
        let (lower, upper) = (Duration::from_secs(2), Duration::from_secs(30));
        builder.jitter(Jitter::Bounded).retry_bounds(lower, upper)
    };
    builder.build_with_max_retries(retries)
}
pub struct UvRetryableStrategy;
impl RetryableStrategy for UvRetryableStrategy {
fn handle(&self, res: &Result<Response, reqwest_middleware::Error>) -> Option<Retryable> {
let retryable = match res {
Ok(success) => default_on_request_success(success),
Err(err) => retryable_on_request_failure(err),
};
if retryable == Some(Retryable::Transient) {
match res {
Ok(response) => {
debug!("Transient request failure for: {}", response.url());
}
Err(err) => {
let context = iter::successors(err.source(), |&err| err.source())
.map(|err| format!(" Caused by: {err}"))
.join("\n");
debug!(
"Transient request failure for {}, retrying: {err}\n{context}",
err.url().map(Url::as_str).unwrap_or("unknown URL")
);
}
}
}
retryable
}
}
pub fn retryable_on_request_failure(err: &(dyn Error + 'static)) -> Option<Retryable> {
if let Some((Some(status), Some(url))) = find_source::<WrappedReqwestError>(&err)
.map(|request_err| (request_err.status(), request_err.url()))
{
trace!(
"Considering retry of response HTTP {status} for {url}",
url = DisplaySafeUrl::from_url(url.clone())
);
} else {
trace!("Considering retry of error: {err:?}");
}
let mut has_known_error = false;
let mut current_source = Some(err);
while let Some(source) = current_source {
let reqwest_err = if let Some(reqwest_err) = source.downcast_ref::<reqwest::Error>() {
Some(reqwest_err)
} else if let Some(reqwest_err) = source
.downcast_ref::<WrappedReqwestError>()
.and_then(|err| err.inner())
{
Some(reqwest_err)
} else if let Some(reqwest_middleware::Error::Reqwest(reqwest_err)) =
source.downcast_ref::<reqwest_middleware::Error>()
{
Some(reqwest_err)
} else {
None
};
if let Some(reqwest_err) = reqwest_err {
has_known_error = true;
if default_on_request_error(reqwest_err) == Some(Retryable::Transient) {
trace!("Transient nested reqwest error");
return Some(Retryable::Transient);
}
if is_retryable_status_error(reqwest_err) {
trace!("Transient nested reqwest status code error");
return Some(Retryable::Transient);
}
trace!("Fatal nested reqwest error");
} else if source.downcast_ref::<h2::Error>().is_some() {
trace!("Transient nested h2 error");
return Some(Retryable::Transient);
} else if let Some(io_err) = source.downcast_ref::<io::Error>() {
has_known_error = true;
let retryable_io_err_kinds = [
io::ErrorKind::BrokenPipe,
io::ErrorKind::ConnectionAborted,
io::ErrorKind::ConnectionReset,
io::ErrorKind::InvalidData,
io::ErrorKind::TimedOut,
io::ErrorKind::UnexpectedEof,
];
if retryable_io_err_kinds.contains(&io_err.kind()) {
trace!("Transient IO error: `{}`", io_err.kind());
return Some(Retryable::Transient);
}
trace!(
"Fatal IO error `{}`, not a transient IO error kind",
io_err.kind()
);
}
current_source = source.source();
}
if !has_known_error {
trace!("Cannot retry error: neither an IO error nor a reqwest error");
}
None
}
/// Bookkeeping for a manually driven retry loop.
pub struct RetryState {
// Policy that decides whether and when to retry.
retry_policy: ExponentialBackoff,
// Time at which the retry loop started; passed to the policy on every decision.
start_time: SystemTime,
// Number of retries counted so far, including retries reported by the errors themselves.
total_retries: u32,
// URL shown in the backoff log message.
url: DisplaySafeUrl,
}
impl RetryState {
    /// Starts tracking retries for `url` from the current time.
    pub fn start(retry_policy: ExponentialBackoff, url: impl Into<DisplaySafeUrl>) -> Self {
        let start_time = SystemTime::now();
        let url = url.into();
        Self {
            retry_policy,
            start_time,
            total_retries: 0,
            url,
        }
    }

    /// The number of retries counted so far.
    pub fn total_retries(&self) -> u32 {
        self.total_retries
    }

    /// The time elapsed since the retry loop was started.
    ///
    /// # Errors
    ///
    /// Fails if the system clock moved backwards past the start time.
    pub fn duration(&self) -> Result<Duration, SystemTimeError> {
        self.start_time.elapsed()
    }

    /// Decides whether the operation that failed with `err` should be retried.
    ///
    /// `error_retries` is the number of retries already performed for this error (for example
    /// by the retry middleware); it is added to the total before the policy is consulted.
    ///
    /// Returns the time to wait before the next attempt, or `None` if the error is not
    /// transient or the policy does not allow another retry.
    #[must_use]
    pub fn should_retry(
        &mut self,
        err: &(dyn Error + 'static),
        error_retries: u32,
    ) -> Option<Duration> {
        self.total_retries += error_retries;

        if !matches!(
            retryable_on_request_failure(err),
            Some(Retryable::Transient)
        ) {
            return None;
        }

        // Captured before asking the policy, so the wait is measured from this point.
        let now = SystemTime::now();
        let reqwest_retry::RetryDecision::Retry { execute_after } = self
            .retry_policy
            .should_retry(self.start_time, self.total_retries)
        else {
            return None;
        };

        // A deadline that has already passed means no waiting at all.
        let wait = execute_after.duration_since(now).unwrap_or_default();
        self.total_retries += 1;
        Some(wait)
    }

    /// Logs the upcoming retry and sleeps for `duration`.
    pub async fn sleep_backoff(&self, duration: Duration) {
        let seconds = duration.as_secs_f32();
        debug!(
            "Transient failure while handling response from {}; retrying after {:.1}s...",
            self.url, seconds,
        );
        tokio::time::sleep(duration).await;
    }
}
/// An error type that can take part in the retry loop of `fetch_with_url_fallback`.
pub trait RetriableError: std::error::Error + Sized + 'static {
/// Whether the failure should make `fetch_with_url_fallback` move on to the next URL
/// instead of retrying.
fn should_try_next_url(&self) -> bool;
/// Number of retries already performed for this error; `fetch_with_url_fallback` passes it
/// to `RetryState::should_retry`, which adds it to the running total.
fn retries(&self) -> u32;
/// Returns the error annotated with the total number of retries and the time spent
/// retrying.
#[must_use]
fn into_retried(self, retries: u32, duration: Duration) -> Self;
}
/// Runs `attempt` against each URL in `urls` in order, with fallback and retries.
///
/// * A failure for which `should_try_next_url` is `true` (on any URL but the last) logs a
///   warning and moves on to the next URL.
/// * Any other failure is checked against the retry policy; if a retry is allowed, the
///   function sleeps for the backoff and starts over from the first URL.
/// * Otherwise the error is returned, annotated via `into_retried` if at least one retry was
///   counted.
///
/// `subject` is only used in the fallback warning.
///
/// # Panics
///
/// Panics if `urls` is empty.
pub async fn fetch_with_url_fallback<T, E, F>(
urls: &[DisplaySafeUrl],
retry_policy: ExponentialBackoff,
subject: &str,
mut attempt: F,
) -> Result<T, E>
where
F: AsyncFnMut(DisplaySafeUrl) -> Result<T, E>,
E: RetriableError + From<SystemTimeError>,
{
// The last URL is the one reported in backoff log messages.
let mut retry_state = RetryState::start(
retry_policy,
urls.last().expect("urls must not be empty").clone(),
);
'retry: loop {
for (i, url) in urls.iter().enumerate() {
let is_last = i == urls.len() - 1;
match attempt(url.clone()).await {
Ok(result) => return Ok(result),
Err(err) => {
// Fall back to the next URL without consuming a retry.
if !is_last && err.should_try_next_url() {
warn!(
"Failed to fetch {subject} from {url} ({err}); falling back to {}",
urls[i + 1]
);
continue;
}
// A retry restarts from the first URL, not from the one that failed.
if let Some(backoff) = retry_state.should_retry(&err, err.retries()) {
retry_state.sleep_backoff(backoff).await;
continue 'retry;
}
return if retry_state.total_retries() > 0 {
let retries = retry_state.total_retries();
Err(err.into_retried(retries, retry_state.duration()?))
} else {
Err(err)
};
}
}
}
// The last URL always returns or restarts the outer loop, and an empty list already
// panicked above.
unreachable!("urls must not be empty");
}
}
/// Returns `true` if `reqwest_err` carries a status code worth retrying: any 5xx, 408 (request
/// timeout) or 429 (too many requests). Errors without a status code are not retryable here.
fn is_retryable_status_error(reqwest_err: &reqwest::Error) -> bool {
    reqwest_err.status().is_some_and(|status| {
        status.is_server_error()
            || matches!(
                status,
                StatusCode::REQUEST_TIMEOUT | StatusCode::TOO_MANY_REQUESTS
            )
    })
}
/// Finds the first error of type `E` in the source chain of `orig`.
///
/// The search starts at `orig.source()`, so `orig` itself is never returned.
fn find_source<E: Error + 'static>(orig: &dyn Error) -> Option<&E> {
    iter::successors(orig.source(), |&cause| cause.source())
        .find_map(|cause| cause.downcast_ref::<E>())
}
/// Error raised when the `UV_HTTP_RETRIES` value cannot be parsed.
#[derive(Debug, Error)]
pub enum RetryParsingError {
/// The value is not a valid integer.
#[error("Failed to parse `UV_HTTP_RETRIES`")]
ParseInt(#[from] ParseIntError),
}
// Tests for the manual redirect handling and for the retry classification.
#[cfg(test)]
mod tests {
use super::*;
use anyhow::Result;
use insta::assert_debug_snapshot;
use reqwest::{Client, Method};
use wiremock::matchers::{method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};
use crate::base_client::request_into_redirect;
// A redirect to the same origin keeps the `Authorization` header, for every redirect status.
#[tokio::test]
async fn test_redirect_preserves_authorization_header_on_same_origin() -> Result<()> {
for status in &[301, 302, 303, 307, 308] {
let server = MockServer::start().await;
Mock::given(method("GET"))
.respond_with(
ResponseTemplate::new(*status)
.insert_header("location", format!("{}/redirect", server.uri())),
)
.mount(&server)
.await;
let request = Client::new()
.get(server.uri())
.basic_auth("username", Some("password"))
.build()
.unwrap();
assert!(request.headers().contains_key(AUTHORIZATION));
// Redirects are disabled so that the raw redirect response can be inspected.
let response = Client::builder()
.redirect(reqwest::redirect::Policy::none())
.build()
.unwrap()
.execute(request.try_clone().unwrap())
.await
.unwrap();
let redirect_request =
request_into_redirect(request, &response, CrossOriginCredentialsPolicy::Secure)?
.unwrap();
assert!(redirect_request.headers().contains_key(AUTHORIZATION));
}
Ok(())
}
// The fragment of the original request is carried over when `Location` has none.
#[tokio::test]
async fn test_redirect_preserves_fragment() -> Result<()> {
for status in &[301, 302, 303, 307, 308] {
let server = MockServer::start().await;
Mock::given(method("GET"))
.respond_with(
ResponseTemplate::new(*status)
.insert_header("location", format!("{}/redirect", server.uri())),
)
.mount(&server)
.await;
let request = Client::new()
.get(format!("{}#fragment", server.uri()))
.build()
.unwrap();
let response = Client::builder()
.redirect(reqwest::redirect::Policy::none())
.build()
.unwrap()
.execute(request.try_clone().unwrap())
.await
.unwrap();
let redirect_request =
request_into_redirect(request, &response, CrossOriginCredentialsPolicy::Secure)?
.unwrap();
assert!(
redirect_request
.url()
.fragment()
.is_some_and(|fragment| fragment == "fragment")
);
}
Ok(())
}
// A redirect to a different host drops the `Authorization` header under the `Secure` policy.
#[tokio::test]
async fn test_redirect_removes_authorization_header_on_cross_origin() -> Result<()> {
for status in &[301, 302, 303, 307, 308] {
let server = MockServer::start().await;
Mock::given(method("GET"))
.respond_with(
ResponseTemplate::new(*status)
.insert_header("location", "https://cross-origin.com/simple"),
)
.mount(&server)
.await;
let request = Client::new()
.get(server.uri())
.basic_auth("username", Some("password"))
.build()
.unwrap();
assert!(request.headers().contains_key(AUTHORIZATION));
let response = Client::builder()
.redirect(reqwest::redirect::Policy::none())
.build()
.unwrap()
.execute(request.try_clone().unwrap())
.await
.unwrap();
let redirect_request =
request_into_redirect(request, &response, CrossOriginCredentialsPolicy::Secure)?
.unwrap();
assert!(!redirect_request.headers().contains_key(AUTHORIZATION));
}
Ok(())
}
// A 303 response turns a `POST` into a `GET`.
#[tokio::test]
async fn test_redirect_303_changes_post_to_get() -> Result<()> {
let server = MockServer::start().await;
Mock::given(method("POST"))
.respond_with(
ResponseTemplate::new(303)
.insert_header("location", format!("{}/redirect", server.uri())),
)
.mount(&server)
.await;
let request = Client::new()
.post(server.uri())
.basic_auth("username", Some("password"))
.build()
.unwrap();
assert_eq!(request.method(), Method::POST);
let response = Client::builder()
.redirect(reqwest::redirect::Policy::none())
.build()
.unwrap()
.execute(request.try_clone().unwrap())
.await
.unwrap();
let redirect_request =
request_into_redirect(request, &response, CrossOriginCredentialsPolicy::Secure)?
.unwrap();
assert_eq!(redirect_request.method(), Method::GET);
Ok(())
}
// No `Referer` header is added when the original request did not carry one.
#[tokio::test]
async fn test_redirect_no_referer_if_disabled() -> Result<()> {
for status in &[301, 302, 303, 307, 308] {
let server = MockServer::start().await;
Mock::given(method("GET"))
.respond_with(
ResponseTemplate::new(*status)
.insert_header("location", format!("{}/redirect", server.uri())),
)
.mount(&server)
.await;
let request = Client::builder()
.referer(false)
.build()
.unwrap()
.get(server.uri())
.basic_auth("username", Some("password"))
.build()
.unwrap();
assert!(!request.headers().contains_key(REFERER));
let response = Client::builder()
.redirect(reqwest::redirect::Policy::none())
.build()
.unwrap()
.execute(request.try_clone().unwrap())
.await
.unwrap();
let redirect_request =
request_into_redirect(request, &response, CrossOriginCredentialsPolicy::Secure)?
.unwrap();
assert!(!redirect_request.headers().contains_key(REFERER));
}
Ok(())
}
// The middleware strategy and `retryable_on_request_failure` must agree on which status codes
// are retried; the snapshot pins the resulting list.
#[tokio::test]
async fn retried_status_codes() -> Result<()> {
let server = MockServer::start().await;
let client = Client::default();
let middleware_client = ClientWithMiddleware::default();
let mut retried = Vec::new();
for status in 100..599 {
// Skip codes without a canonical reason phrase, with the exception of 420.
if StatusCode::from_u16(status)?.canonical_reason().is_none() && status != 420 {
continue;
}
Mock::given(path(format!("/{status}")))
.respond_with(ResponseTemplate::new(status))
.mount(&server)
.await;
// Classification as seen by the retry middleware strategy.
let response = middleware_client
.get(format!("{}/{}", server.uri(), status))
.send()
.await;
let middleware_retry =
UvRetryableStrategy.handle(&response) == Some(Retryable::Transient);
// Classification of the status error produced by `error_for_status`.
let response = client
.get(format!("{}/{}", server.uri(), status))
.send()
.await?;
let uv_retry = match response.error_for_status() {
Ok(_) => false,
Err(err) => retryable_on_request_failure(&err) == Some(Retryable::Transient),
};
assert_eq!(middleware_retry, uv_retry);
if uv_retry {
retried.push(status);
}
}
assert_debug_snapshot!(retried, @"
[
100,
102,
103,
408,
429,
500,
501,
502,
503,
504,
505,
506,
507,
508,
510,
511,
]
");
Ok(())
}
}