use std::future::Future;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::{Arc, OnceLock};
use std::time::Duration;
use crate::browser_emulation::BrowserProfile;
use crate::cookie::{CookieJar, CookieMiddleware};
use crate::decode::CompressionMode;
use crate::dns::{DnsCache, DnsConfig};
use crate::error::Result;
use crate::header::HeaderMap;
use crate::middleware::Middleware;
use crate::pool::PoolConfig;
use crate::progress::ProgressConfig;
use crate::protocol::grpc::GrpcRequestBuilder;
use crate::protocol::http1::Http1Transport;
use crate::protocol::websocket::WebSocketBuilder;
use crate::proxy::Proxy;
use crate::request::{
H2KeepAliveConfig, Method, ProgressCallback, ProtocolPolicy, Request, RequestBuilder,
TimeoutConfig,
};
use crate::response::Response;
use crate::retry::RetryPolicy;
use crate::tls::{RootStore, TlsBackend, TlsConfig};
use crate::url::Url;
/// Boxed, `Send + 'static` future that resolves to a [`Response`] or an error.
pub(crate) type ResponseFuture = Pin<Box<dyn Future<Output = Result<Response>> + Send + 'static>>;

/// Low-level executor that a [`Client`] hands fully-built requests to.
///
/// Implementations must be `Send + Sync` because a client (and its clones) share one
/// transport behind an `Arc`.
pub trait Transport: Send + Sync {
    /// Sends `request` and returns a future for its response.
    fn execute(&self, request: Request) -> ResponseFuture;

    /// Sends `request` with a redirect policy attached.
    ///
    /// The default implementation ignores the policy entirely and simply forwards to
    /// [`Transport::execute`]; transports that handle redirects themselves override it.
    fn execute_with_redirect(&self, request: Request, _redirect: RedirectPolicy) -> ResponseFuture {
        Self::execute(self, request)
    }

    /// Releases any resources held by the transport.
    ///
    /// The default implementation has nothing to release and always returns `Ok(())`.
    fn close(&self) -> Result<()> {
        Ok(())
    }
}
/// HTTP client that stores shared defaults and copies them into every builder it creates
/// (see `Client::request` and `Client::ws`).
///
/// Cloning shares the transport (it sits behind an `Arc`); every other field is cloned.
#[derive(Clone)]
pub struct Client {
    /// Executes requests; clones of this client share the same transport instance.
    transport: Arc<dyn Transport>,
    /// Base URL forwarded to every request and WebSocket builder.
    base_url: Option<Url>,
    /// Default headers forwarded to every request and WebSocket builder.
    headers: HeaderMap,
    /// Static `(name, value)` cookie pairs forwarded to every request and WebSocket builder.
    cookies: Vec<(String, String)>,
    /// Default redirect policy for request builders.
    redirect_policy: RedirectPolicy,
    /// Timeouts forwarded to request and WebSocket builders.
    timeout_config: TimeoutConfig,
    /// Default retry policy for request builders.
    retry_policy: RetryPolicy,
    /// Whether request builders default to HTTP/2 prior knowledge over cleartext.
    prior_knowledge_h2c: bool,
    /// User middlewares in registration order; `ClientBuilder::build` appends a
    /// `CookieMiddleware` after them when a cookie jar is configured.
    middlewares: Vec<Arc<dyn Middleware>>,
    /// Optional progress callback forwarded to request builders.
    progress_callback: Option<ProgressCallback>,
    /// Progress reporting settings forwarded to request builders.
    progress_config: ProgressConfig,
    /// HTTP/2 keep-alive settings forwarded to request builders.
    h2_keepalive_config: H2KeepAliveConfig,
    /// Cookie jar exposed by `Client::cookie_jar` and forwarded to WebSocket builders.
    cookie_jar: Option<CookieJar>,
    /// TLS configuration forwarded to request and WebSocket builders.
    tls_config: TlsConfig,
    /// Proxy forwarded to request builders (not to WebSocket builders).
    proxy: Option<Proxy>,
    /// Local address forwarded to WebSocket builders; the HTTP transport receives it
    /// separately when the client is built.
    local_addr: Option<SocketAddr>,
    /// Response decompression mode forwarded to request builders.
    compression_mode: CompressionMode,
    /// HTTP version selection policy forwarded to request builders.
    protocol_policy: ProtocolPolicy,
    /// Browser emulation profile forwarded to request builders, if one was configured.
    browser_profile: Option<BrowserProfile>,
}
/// Builder for [`Client`]; start from `Client::builder()` and finish with `build()`.
///
/// Every field starts from its `Default` value.
#[derive(Default)]
pub struct ClientBuilder {
    /// Base URL copied into the built client.
    base_url: Option<Url>,
    /// Default headers copied into the built client.
    headers: HeaderMap,
    /// Static `(name, value)` cookie pairs copied into the built client.
    cookies: Vec<(String, String)>,
    /// Redirect policy copied into the built client.
    redirect_policy: RedirectPolicy,
    /// Timeouts copied into the built client.
    timeout_config: TimeoutConfig,
    /// Retry policy copied into the built client.
    retry_policy: RetryPolicy,
    /// h2c prior-knowledge flag copied into the built client.
    prior_knowledge_h2c: bool,
    /// User middlewares, in registration order.
    middlewares: Vec<Arc<dyn Middleware>>,
    /// Optional progress callback.
    progress_callback: Option<ProgressCallback>,
    /// Progress reporting settings.
    progress_config: ProgressConfig,
    /// HTTP/2 keep-alive settings.
    h2_keepalive_config: H2KeepAliveConfig,
    /// Cookie jar; `build` wraps it in a `CookieMiddleware` when present.
    cookie_jar: Option<CookieJar>,
    /// TLS configuration, adjusted in place by the TLS-related setters.
    tls_config: TlsConfig,
    /// Proxy copied into the built client.
    proxy: Option<Proxy>,
    /// DNS settings; consumed by `build` (DNS prefetch and transport construction) and not
    /// stored on the client.
    dns_config: DnsConfig,
    /// Hostnames resolved eagerly by `build`; not stored on the client.
    dns_prefetch_hosts: Vec<String>,
    /// Local address handed to the transport and copied into the built client.
    local_addr: Option<SocketAddr>,
    /// Response decompression mode copied into the built client.
    compression_mode: CompressionMode,
    /// HTTP version selection policy copied into the built client.
    protocol_policy: ProtocolPolicy,
    /// Connection pool settings; consumed by `build` when creating the transport.
    pool_config: PoolConfig,
    /// Browser emulation profile, if one was configured.
    browser_profile: Option<BrowserProfile>,
}
/// How a request should treat HTTP redirects.
///
/// Note that the default `Transport::execute_with_redirect` ignores this value; the
/// effect depends on the transport that consumes it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RedirectPolicy {
    /// Presumably: never follow redirects (TODO confirm against the transport).
    None,
    /// Presumably: follow at most the given number of redirects (TODO confirm).
    Limit(usize),
}

impl Default for RedirectPolicy {
    /// Returns `RedirectPolicy::Limit(10)`.
    fn default() -> Self {
        RedirectPolicy::Limit(10)
    }
}
impl Client {
    /// Returns a [`ClientBuilder`] whose settings all start at their defaults.
    pub fn builder() -> ClientBuilder {
        ClientBuilder::default()
    }

    /// Starts a request with an arbitrary `method` against `url`.
    ///
    /// The returned builder shares this client's transport and is seeded with its defaults:
    /// base URL, headers, cookies, redirect/timeout/protocol/retry policies, h2c prior
    /// knowledge, middlewares, progress reporting, HTTP/2 keep-alive, TLS configuration,
    /// proxy, compression mode and browser profile.
    pub fn request(&self, method: Method, url: impl AsRef<str>) -> RequestBuilder {
        let shared_transport = Arc::clone(&self.transport);
        let fresh = RequestBuilder::new(shared_transport, method, url.as_ref());
        fresh.with_client_defaults(
            self.base_url.clone(),
            self.headers.clone(),
            self.cookies.clone(),
            self.redirect_policy,
            self.timeout_config,
            self.protocol_policy,
            self.retry_policy,
            self.prior_knowledge_h2c,
            self.middlewares.clone(),
            self.progress_callback.clone(),
            self.progress_config,
            self.h2_keepalive_config,
            self.tls_config.clone(),
            self.proxy.clone(),
            self.compression_mode,
            self.browser_profile.clone(),
        )
    }

    /// Starts a `GET` request to `url`.
    pub fn get(&self, url: impl AsRef<str>) -> RequestBuilder {
        self.request(Method::Get, url)
    }

    /// Starts a `POST` request to `url`.
    pub fn post(&self, url: impl AsRef<str>) -> RequestBuilder {
        self.request(Method::Post, url)
    }

    /// Starts a `PUT` request to `url`.
    pub fn put(&self, url: impl AsRef<str>) -> RequestBuilder {
        self.request(Method::Put, url)
    }

    /// Starts a `PATCH` request to `url`.
    pub fn patch(&self, url: impl AsRef<str>) -> RequestBuilder {
        self.request(Method::Patch, url)
    }

    /// Starts a `DELETE` request to `url`.
    pub fn delete(&self, url: impl AsRef<str>) -> RequestBuilder {
        self.request(Method::Delete, url)
    }

    /// Starts a `HEAD` request to `url`.
    pub fn head(&self, url: impl AsRef<str>) -> RequestBuilder {
        self.request(Method::Head, url)
    }

    /// Starts an `OPTIONS` request to `url`.
    pub fn options(&self, url: impl AsRef<str>) -> RequestBuilder {
        self.request(Method::Options, url)
    }

    /// Starts a `TRACE` request to `url`.
    pub fn trace(&self, url: impl AsRef<str>) -> RequestBuilder {
        self.request(Method::Trace, url)
    }

    /// Starts a WebSocket builder for `url`.
    ///
    /// Only the base URL, headers, cookies, timeouts, TLS configuration, cookie jar and
    /// local address are forwarded; proxy, middleware and redirect settings are not.
    pub fn ws(&self, url: impl AsRef<str>) -> WebSocketBuilder {
        let handshake = WebSocketBuilder::new(url.as_ref());
        handshake.with_client_defaults(
            self.base_url.clone(),
            self.headers.clone(),
            self.cookies.clone(),
            self.timeout_config,
            self.tls_config.clone(),
            self.cookie_jar.clone(),
            self.local_addr,
        )
    }

    /// Starts a gRPC call: a `POST` to `url`, restricted to HTTP/2, wrapped in a
    /// [`GrpcRequestBuilder`].
    pub fn grpc(&self, url: impl AsRef<str>) -> GrpcRequestBuilder {
        let h2_post = self.post(url.as_ref()).http2_only();
        GrpcRequestBuilder::from_request_builder(h2_post)
    }

    /// Starts a download builder that borrows this client.
    pub fn download(&self, url: impl AsRef<str>) -> crate::download::DownloadBuilder<'_> {
        let target = url.as_ref();
        crate::download::DownloadBuilder::new(self, target)
    }

    /// Closes the underlying transport via `Transport::close`.
    ///
    /// Clones of this client share the same transport, so they are affected too. The
    /// returned future performs no awaiting; it completes on its first poll.
    pub async fn close(&self) -> Result<()> {
        let shared_transport = &self.transport;
        shared_transport.close()
    }

    /// Returns a clone of the configured cookie jar, or `None` if none was configured.
    ///
    /// Whether the clone shares storage with the client's jar depends on `CookieJar`'s
    /// `Clone` impl (presumably shared — confirm).
    pub fn cookie_jar(&self) -> Option<CookieJar> {
        self.cookie_jar.as_ref().cloned()
    }
}
impl ClientBuilder {
    /// Sets the base URL forwarded to every request and WebSocket builder of the client.
    ///
    /// # Errors
    ///
    /// Returns the error produced by [`Url::parse`] when `url` is not a valid URL.
    pub fn base_url(mut self, url: impl AsRef<str>) -> Result<Self> {
        self.base_url = Some(Url::parse(url)?);
        Ok(self)
    }

    /// Adds a default header via `HeaderMap::insert`.
    ///
    /// Presumably this replaces an earlier default with the same name (TODO confirm
    /// `HeaderMap::insert` semantics). [`ClientBuilder::headers`] uses `append` instead.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `HeaderMap::insert`, presumably for an invalid
    /// header name or value.
    pub fn header(mut self, name: impl AsRef<str>, value: impl AsRef<str>) -> Result<Self> {
        self.headers.insert(name, value)?;
        Ok(self)
    }

    /// Appends several default headers via `HeaderMap::append`, in iteration order.
    ///
    /// # Errors
    ///
    /// Stops at and returns the first error from `HeaderMap::append`. Pairs before the
    /// failing one have already been added, but the builder is consumed either way.
    pub fn headers<I, K, V>(mut self, headers: I) -> Result<Self>
    where
        I: IntoIterator<Item = (K, V)>,
        K: AsRef<str>,
        V: AsRef<str>,
    {
        for (name, value) in headers {
            self.headers.append(name, value)?;
        }
        Ok(self)
    }

    /// Adds a static `(name, value)` cookie pair forwarded to every request and WebSocket
    /// builder.
    pub fn cookie(mut self, name: impl AsRef<str>, value: impl AsRef<str>) -> Self {
        self.cookies
            .push((name.as_ref().to_owned(), value.as_ref().to_owned()));
        self
    }

    /// Adds several static `(name, value)` cookie pairs, preserving iteration order.
    pub fn cookies<I, K, V>(mut self, cookies: I) -> Self
    where
        I: IntoIterator<Item = (K, V)>,
        K: AsRef<str>,
        V: AsRef<str>,
    {
        self.cookies.extend(
            cookies
                .into_iter()
                .map(|(name, value)| (name.as_ref().to_owned(), value.as_ref().to_owned())),
        );
        self
    }

    /// Shorthand for `header("user-agent", value)`.
    ///
    /// # Errors
    ///
    /// Same as [`ClientBuilder::header`].
    pub fn user_agent(self, value: impl AsRef<str>) -> Result<Self> {
        self.header("user-agent", value)
    }

    /// Sets a default `authorization: Bearer <token>` header.
    ///
    /// # Errors
    ///
    /// Same as [`ClientBuilder::header`].
    pub fn bearer_auth(self, token: impl AsRef<str>) -> Result<Self> {
        self.header("authorization", format!("Bearer {}", token.as_ref()))
    }

    /// Sets a default `authorization: Basic <base64("username:password")>` header.
    ///
    /// The credentials are not validated. RFC 7617 forbids `:` in the user-id, and a
    /// username containing one yields an ambiguous header.
    ///
    /// # Errors
    ///
    /// Same as [`ClientBuilder::header`].
    pub fn basic_auth(self, username: impl AsRef<str>, password: impl AsRef<str>) -> Result<Self> {
        let raw = format!("{}:{}", username.as_ref(), password.as_ref());
        let encoded = encode_base64(raw.as_bytes());
        self.header("authorization", format!("Basic {encoded}"))
    }

    /// Sets the `total` field of the client's [`TimeoutConfig`].
    pub fn timeout(mut self, duration: Duration) -> Self {
        self.timeout_config.total = Some(duration);
        self
    }

    /// Sets the `connect` field of the client's [`TimeoutConfig`].
    pub fn connect_timeout(mut self, duration: Duration) -> Self {
        self.timeout_config.connect = Some(duration);
        self
    }

    /// Sets the `read` field of the client's [`TimeoutConfig`].
    pub fn read_timeout(mut self, duration: Duration) -> Self {
        self.timeout_config.read = Some(duration);
        self
    }

    /// Sets the `write` field of the client's [`TimeoutConfig`].
    pub fn write_timeout(mut self, duration: Duration) -> Self {
        self.timeout_config.write = Some(duration);
        self
    }

    /// Sets the default [`RedirectPolicy`] (initially `RedirectPolicy::Limit(10)`).
    pub fn redirect(mut self, policy: RedirectPolicy) -> Self {
        self.redirect_policy = policy;
        self
    }

    /// Sets the default [`RetryPolicy`].
    pub fn retry(mut self, policy: RetryPolicy) -> Self {
        self.retry_policy = policy;
        self
    }

    /// Appends a middleware.
    ///
    /// Middlewares keep their registration order. At `build` time a `CookieMiddleware` is
    /// appended after them if a cookie jar is configured.
    pub fn middleware<M>(mut self, middleware: M) -> Self
    where
        M: Middleware + 'static,
    {
        self.middlewares.push(Arc::new(middleware));
        self
    }

    /// Registers a progress callback with its reporting configuration.
    ///
    /// Replaces any previously registered callback and configuration.
    pub fn on_progress<F>(mut self, callback: F, config: ProgressConfig) -> Self
    where
        F: Fn(crate::Progress) + Send + Sync + 'static,
    {
        self.progress_callback = Some(Arc::new(callback));
        self.progress_config = config;
        self
    }

    /// Configures HTTP/2 keep-alive with the given idle and ack timeouts.
    pub fn h2_keepalive(mut self, idle_timeout: Duration, ack_timeout: Duration) -> Self {
        self.h2_keepalive_config = H2KeepAliveConfig {
            idle_timeout: Some(idle_timeout),
            ack_timeout,
        };
        self
    }

    /// Clears the HTTP/2 keep-alive idle timeout; `ack_timeout` is left untouched.
    pub fn disable_h2_keepalive(mut self) -> Self {
        self.h2_keepalive_config.idle_timeout = None;
        self
    }

    /// Uses `jar` as the client's cookie jar.
    ///
    /// At `build` time it is wrapped in a `CookieMiddleware`. Replaces any earlier jar.
    pub fn cookie_jar(mut self, jar: CookieJar) -> Self {
        self.cookie_jar = Some(jar);
        self
    }

    /// Like [`ClientBuilder::cookie_jar`] with a fresh `CookieJar::new()`; replaces any
    /// earlier jar.
    pub fn cookie_store(mut self) -> Self {
        self.cookie_jar = Some(CookieJar::new());
        self
    }

    /// Replaces the whole TLS configuration.
    ///
    /// This discards anything earlier TLS setters (`root_store`, `pin_certificate`,
    /// `alpn_protocols`, ...) applied.
    pub fn tls_config(mut self, tls_config: TlsConfig) -> Self {
        self.tls_config = tls_config;
        self
    }

    /// Applies `TlsConfig::root_store` to the current TLS configuration.
    pub fn root_store(mut self, root_store: RootStore) -> Self {
        // `self` is owned, so the field can be moved out and reassigned; no clone needed.
        self.tls_config = self.tls_config.root_store(root_store);
        self
    }

    /// Sets the DNS cache TTL (`DnsConfig::ttl`).
    pub fn dns_cache_ttl(mut self, ttl: Duration) -> Self {
        self.dns_config.ttl = ttl;
        self
    }

    /// Sets the hostnames that `build` resolves eagerly.
    ///
    /// Entries are trimmed and blank ones dropped. The list replaces, rather than extends,
    /// the list from any earlier call.
    pub fn dns_prefetch<I, S>(mut self, hosts: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        self.dns_prefetch_hosts = hosts
            .into_iter()
            .map(|host| host.as_ref().trim().to_owned())
            .filter(|host| !host.is_empty())
            .collect();
        self
    }

    /// Sets the local address handed to the HTTP transport and to WebSocket builders.
    ///
    /// Presumably this is the bind address for outgoing connections.
    pub fn local_addr(mut self, addr: SocketAddr) -> Self {
        self.local_addr = Some(addr);
        self
    }

    /// Sets `DnsConfig::dual_stack`.
    pub fn dual_stack(mut self, enabled: bool) -> Self {
        self.dns_config.dual_stack = enabled;
        self
    }

    /// Adds a certificate pin for `domain` via `TlsConfig::pin_certificate`.
    ///
    /// # Errors
    ///
    /// Propagates the error from `TlsConfig::pin_certificate`, presumably for a malformed
    /// fingerprint.
    pub fn pin_certificate(
        mut self,
        domain: impl AsRef<str>,
        fingerprint: impl AsRef<str>,
    ) -> Result<Self> {
        self.tls_config = self.tls_config.pin_certificate(domain, fingerprint)?;
        Ok(self)
    }

    /// Sets the proxy forwarded to request builders (WebSocket builders do not receive it).
    pub fn proxy(mut self, proxy: impl Into<Proxy>) -> Self {
        self.proxy = Some(proxy.into());
        self
    }

    /// Applies `TlsConfig::danger_accept_invalid_certs`.
    ///
    /// As the name warns, enabling this weakens TLS security.
    pub fn danger_accept_invalid_certs(mut self, enabled: bool) -> Self {
        self.tls_config = self.tls_config.danger_accept_invalid_certs(enabled);
        self
    }

    /// Selects the TLS backend via `TlsConfig::backend`.
    pub fn tls_backend(mut self, backend: TlsBackend) -> Self {
        self.tls_config = self.tls_config.backend(backend);
        self
    }

    /// Alias for [`ClientBuilder::emulation`].
    #[cfg(feature = "emulation")]
    pub fn browser_profile(self, profile: BrowserProfile) -> Self {
        self.emulation(profile)
    }

    /// Alias for [`ClientBuilder::emulation`].
    #[cfg(feature = "emulation")]
    pub fn emulation_profile(self, profile: BrowserProfile) -> Self {
        self.emulation(profile)
    }

    /// Enables browser emulation with the given profile.
    ///
    /// With the `h2` feature, a protocol policy still at `Auto` is switched to
    /// `PreferHttp2`; an explicitly chosen policy is left alone. `build` then ensures the
    /// TLS configuration uses an emulation-capable backend.
    #[cfg(feature = "emulation")]
    pub fn emulation<T>(mut self, emulation: T) -> Self
    where
        T: Into<BrowserProfile>,
    {
        self.browser_profile = Some(emulation.into());
        #[cfg(feature = "h2")]
        if self.protocol_policy == ProtocolPolicy::Auto {
            self.protocol_policy = ProtocolPolicy::PreferHttp2;
        }
        self
    }

    /// Sets the ALPN protocol list via `TlsConfig::alpn_protocols`.
    pub fn alpn_protocols<I, S>(mut self, protocols: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        self.tls_config = self.tls_config.alpn_protocols(protocols);
        self
    }

    /// Disables ALPN via `TlsConfig::disable_alpn`.
    pub fn disable_alpn(mut self) -> Self {
        self.tls_config = self.tls_config.disable_alpn();
        self
    }

    /// Sets the HTTP version selection policy; later policy setters override it.
    pub fn protocol_policy(mut self, policy: ProtocolPolicy) -> Self {
        self.protocol_policy = policy;
        self
    }

    /// Shorthand for `protocol_policy(ProtocolPolicy::PreferHttp3)`.
    pub fn prefer_http3(mut self) -> Self {
        self.protocol_policy = ProtocolPolicy::PreferHttp3;
        self
    }

    /// Shorthand for `protocol_policy(ProtocolPolicy::PreferHttp2)`.
    pub fn prefer_http2(mut self) -> Self {
        self.protocol_policy = ProtocolPolicy::PreferHttp2;
        self
    }

    /// Shorthand for `protocol_policy(ProtocolPolicy::Http1Only)`.
    pub fn http1_only(mut self) -> Self {
        self.protocol_policy = ProtocolPolicy::Http1Only;
        self
    }

    /// Shorthand for `protocol_policy(ProtocolPolicy::Http2Only)`.
    pub fn http2_only(mut self) -> Self {
        self.protocol_policy = ProtocolPolicy::Http2Only;
        self
    }

    /// Shorthand for `protocol_policy(ProtocolPolicy::Http3Only)`.
    pub fn http3_only(mut self) -> Self {
        self.protocol_policy = ProtocolPolicy::Http3Only;
        self
    }

    /// Enables or disables HTTP/2 prior knowledge over cleartext (h2c) for requests.
    pub fn prior_knowledge_h2c(mut self, enabled: bool) -> Self {
        self.prior_knowledge_h2c = enabled;
        self
    }

    /// Sets the response decompression mode.
    pub fn compression_mode(mut self, compression_mode: CompressionMode) -> Self {
        self.compression_mode = compression_mode;
        self
    }

    /// Replaces the whole connection pool configuration.
    pub fn pool_config(mut self, pool_config: PoolConfig) -> Self {
        self.pool_config = pool_config;
        self
    }

    /// Sets `PoolConfig::max_idle_per_host`.
    pub fn max_idle_per_host(mut self, max_idle_per_host: usize) -> Self {
        self.pool_config.max_idle_per_host = max_idle_per_host;
        self
    }

    /// Sets the connection-pool idle timeout (`PoolConfig::idle_timeout`).
    ///
    /// This is unrelated to the HTTP/2 keep-alive idle timeout.
    pub fn idle_timeout(mut self, idle_timeout: Duration) -> Self {
        self.pool_config.idle_timeout = Some(idle_timeout);
        self
    }

    /// Builds the [`Client`].
    ///
    /// Steps, in order:
    /// 1. With the `emulation` feature and a browser profile set, the TLS configuration is
    ///    switched to an emulation-capable backend.
    /// 2. A `CookieMiddleware` is appended after user middlewares if a cookie jar is set.
    /// 3. A fresh `DnsCache` is created, and the configured hosts are prefetched into it.
    /// 4. The transport (an `Http1Transport`) is created from the pool, DNS and
    ///    local-address settings.
    ///
    /// # Errors
    ///
    /// Fails if the emulation backend cannot be ensured or if DNS prefetching fails.
    pub fn build(mut self) -> Result<Client> {
        self.prepare_browser_profile()?;

        let mut middlewares = self.middlewares;
        if let Some(jar) = &self.cookie_jar {
            middlewares.push(Arc::new(CookieMiddleware::new(jar.clone())));
        }

        let dns_cache = Arc::new(DnsCache::new());
        if !self.dns_prefetch_hosts.is_empty() {
            dns_cache.prefetch(&self.dns_prefetch_hosts, self.dns_config)?;
        }

        Ok(Client {
            transport: Arc::new(Http1Transport::new(
                self.pool_config,
                dns_cache,
                self.dns_config,
                self.local_addr,
            )),
            base_url: self.base_url,
            headers: self.headers,
            cookies: self.cookies,
            redirect_policy: self.redirect_policy,
            timeout_config: self.timeout_config,
            retry_policy: self.retry_policy,
            prior_knowledge_h2c: self.prior_knowledge_h2c,
            middlewares,
            progress_callback: self.progress_callback,
            progress_config: self.progress_config,
            h2_keepalive_config: self.h2_keepalive_config,
            cookie_jar: self.cookie_jar,
            tls_config: self.tls_config,
            proxy: self.proxy,
            local_addr: self.local_addr,
            compression_mode: self.compression_mode,
            protocol_policy: self.protocol_policy,
            browser_profile: self.browser_profile,
        })
    }

    /// Ensures an emulation-capable TLS backend when a browser profile is configured.
    #[cfg(feature = "emulation")]
    fn prepare_browser_profile(&mut self) -> Result<()> {
        if self.browser_profile.is_some() {
            // Through `&mut self` the field cannot be moved out; cloning also leaves
            // `tls_config` intact if the call fails.
            self.tls_config = self.tls_config.clone().ensure_emulation_backend()?;
        }
        Ok(())
    }

    /// No-op without the `emulation` feature.
    #[cfg(not(feature = "emulation"))]
    fn prepare_browser_profile(&mut self) -> Result<()> {
        Ok(())
    }
}
/// Test helper returning a new default [`Http1Transport`] as a `Transport` trait object.
///
/// Despite the name, every call constructs a new transport; nothing is cached here.
#[cfg(test)]
pub(crate) fn shared_http1_transport() -> Arc<dyn Transport> {
    // The `Arc<Http1Transport>` is unsize-coerced to `Arc<dyn Transport>` at the return site.
    Arc::new(Http1Transport::default())
}

/// Returns the process-wide client built from default settings.
///
/// The client is built the first time this is called; later calls return the same instance.
///
/// # Panics
///
/// Panics on first use if building the default client fails.
pub(crate) fn default_client() -> &'static Client {
    static DEFAULT_CLIENT: OnceLock<Client> = OnceLock::new();
    let init = || Client::builder().build().expect("default client to build");
    DEFAULT_CLIENT.get_or_init(init)
}

/// Module-local alias for the crate's base64 encoder.
fn encode_base64(bytes: &[u8]) -> String {
    crate::util::encode_base64(bytes)
}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use futures_lite::future::block_on;

    use super::Client;
    use crate::middleware::{Middleware, Next};
    use crate::protocol::http1::{spawn_assert_server, spawn_test_server};
    use crate::{ProgressConfig, ProtocolPolicy, Request, Response, RetryPolicy};
    #[cfg(feature = "emulation")]
    use crate::{BrowserProfile, Emulation};

    /// Canned `200 OK` response with an empty body, served by the test servers.
    const EMPTY_OK: &str = "HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n";

    /// Blocks the current thread until `value`'s future completes.
    fn run<T>(value: T) -> T::Output
    where
        T: std::future::IntoFuture,
    {
        block_on(async move { value.await })
    }

    /// Inserts `x-from-middleware: 1` into every outgoing request.
    struct AddHeaderMiddleware;

    #[async_trait::async_trait]
    impl Middleware for AddHeaderMiddleware {
        async fn handle(&self, mut outgoing: Request, next: Next<'_>) -> crate::Result<Response> {
            outgoing.headers_mut().insert("x-from-middleware", "1")?;
            next.run(outgoing).await
        }
    }

    /// Inserts `x-processed: yes` into every response on its way back.
    struct TagResponseMiddleware;

    #[async_trait::async_trait]
    impl Middleware for TagResponseMiddleware {
        async fn handle(&self, req: Request, next: Next<'_>) -> crate::Result<Response> {
            let mut tagged = next.run(req).await?;
            tagged.headers_mut().insert("x-processed", "yes")?;
            Ok(tagged)
        }
    }

    #[test]
    fn client_applies_default_base_url() {
        let server = block_on(spawn_test_server(EMPTY_OK)).unwrap();
        let client = Client::builder().base_url(&server).unwrap().build().unwrap();

        let resp = run(client.get("/users/1")).unwrap();

        assert_eq!(resp.url().as_str(), format!("{server}/users/1"));
    }

    #[cfg(feature = "emulation")]
    #[test]
    fn emulation_sets_default_headers() {
        let safari_ua = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/18.4 Safari/605.1.15";
        let client = Client::builder()
            .emulation(Emulation::Safari18_4)
            .build()
            .unwrap();

        let req = client
            .get("https://example.com")
            .build_request()
            .expect("request to build");

        assert_eq!(req.headers().get("user-agent"), Some(safari_ua));
        #[cfg(feature = "h2")]
        assert_eq!(client.protocol_policy, ProtocolPolicy::PreferHttp2);
    }

    #[cfg(feature = "emulation")]
    #[test]
    fn emulation_accepts_custom_profile() {
        let custom = BrowserProfile::builder()
            .default_header("user-agent", "CustomAgent/1.0")
            .unwrap()
            .build();
        let client = Client::builder().emulation(custom).build().unwrap();

        let req = client
            .get("https://example.com")
            .build_request()
            .expect("request to build");

        assert_eq!(req.headers().get("user-agent"), Some("CustomAgent/1.0"));
        #[cfg(feature = "h2")]
        assert_eq!(client.protocol_policy, ProtocolPolicy::PreferHttp2);
    }

    #[cfg(feature = "emulation")]
    #[test]
    fn emulation_does_not_override_user_headers() {
        let client = Client::builder()
            .header("user-agent", "MyAgent/1.0")
            .unwrap()
            .emulation(Emulation::Safari18_4)
            .build()
            .unwrap();

        let req = client
            .get("https://example.com")
            .build_request()
            .expect("request to build");

        assert_eq!(req.headers().get("user-agent"), Some("MyAgent/1.0"));
    }

    #[test]
    fn request_can_override_client_base_url() {
        let default_server = block_on(spawn_test_server(EMPTY_OK)).unwrap();
        let override_server = block_on(spawn_test_server(EMPTY_OK)).unwrap();
        let client = Client::builder()
            .base_url(&default_server)
            .unwrap()
            .build()
            .unwrap();

        let resp = run(client.get("/users/1").base_url(&override_server).unwrap()).unwrap();

        assert_eq!(resp.url().as_str(), format!("{override_server}/users/1"));
    }

    #[test]
    fn client_builder_http3_only_sets_protocol_policy() {
        let client = Client::builder().http3_only().build().unwrap();
        let req = client
            .get("https://example.com")
            .build_request()
            .expect("request to build");
        assert_eq!(req.protocol_policy(), ProtocolPolicy::Http3Only);
    }

    #[test]
    fn client_builder_http2_only_sets_protocol_policy() {
        let client = Client::builder().http2_only().build().unwrap();
        let req = client
            .get("https://example.com")
            .build_request()
            .expect("request to build");
        assert_eq!(req.protocol_policy(), ProtocolPolicy::Http2Only);
    }

    #[test]
    fn client_builder_can_enable_prior_knowledge_h2c() {
        let client = Client::builder().prior_knowledge_h2c(true).build().unwrap();
        let req = client
            .get("http://example.com")
            .build_request()
            .expect("request to build");
        assert!(req.prior_knowledge_h2c());
    }

    #[test]
    fn client_builder_can_disable_alpn() {
        let client = Client::builder().disable_alpn().build().unwrap();
        let req = client
            .get("https://example.com")
            .build_request()
            .expect("request to build");
        let advertised = req.tls_config().effective_alpn_protocols(ProtocolPolicy::Auto);
        assert!(advertised.is_empty());
    }

    #[test]
    fn client_builder_can_set_progress_callback() {
        let client = Client::builder()
            .on_progress(|_| {}, ProgressConfig::default())
            .build()
            .unwrap();
        let req = client
            .get("https://example.com")
            .build_request()
            .expect("request to build");
        assert!(req.progress_callback().is_some());
    }

    #[test]
    fn client_builder_can_set_retry_policy() {
        let client = Client::builder()
            .retry(RetryPolicy::Limit(2))
            .build()
            .unwrap();
        let req = client
            .get("https://example.com")
            .build_request()
            .expect("request to build");
        assert_eq!(req.retry_policy(), RetryPolicy::Limit(2));
    }

    #[test]
    fn client_builder_can_prefetch_dns_hosts() {
        let client = Client::builder()
            .dns_cache_ttl(Duration::from_secs(5))
            .dual_stack(true)
            .dns_prefetch(["localhost"])
            .build()
            .unwrap();
        let req = client
            .get("http://localhost")
            .build_request()
            .expect("request to build");
        assert_eq!(req.url().host(), "localhost");
    }

    #[test]
    fn client_builder_can_set_compression_mode() {
        let client = Client::builder()
            .compression_mode(crate::CompressionMode::Manual)
            .build()
            .unwrap();
        let req = client
            .get("https://example.com")
            .build_request()
            .expect("request to build");
        assert_eq!(req.compression_mode(), crate::CompressionMode::Manual);
    }

    #[test]
    fn middleware_can_mutate_request_before_send() {
        let server = block_on(spawn_assert_server(
            |raw| {
                assert!(raw.contains("\r\nx-from-middleware: 1\r\n"));
            },
            EMPTY_OK,
        ))
        .unwrap();
        let client = Client::builder()
            .base_url(&server)
            .unwrap()
            .middleware(AddHeaderMiddleware)
            .build()
            .unwrap();

        let _ = run(client.get("/middleware")).unwrap();
    }

    #[test]
    fn middleware_can_mutate_response_after_send() {
        let server = block_on(spawn_test_server(EMPTY_OK)).unwrap();
        let client = Client::builder()
            .base_url(&server)
            .unwrap()
            .middleware(TagResponseMiddleware)
            .build()
            .unwrap();

        let resp = run(client.get("/middleware")).unwrap();

        assert_eq!(resp.headers().get("x-processed"), Some("yes"));
    }
}