use crate::cx::Cx;
use crate::http::h1::client::{ClientStreamingResponse, Http1Client};
use crate::http::h1::types::{Method, MultipartForm, Request, Response, Version};
use crate::http::pool::{Pool, PoolConfig, PoolKey};
use crate::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
use crate::net::tcp::stream::TcpStream;
#[cfg(feature = "tls")]
use crate::tls::{TlsConnectorBuilder, TlsStream};
use crate::types::Time;
use base64::Engine;
use memchr::memmem;
use parking_lot::Mutex;
use std::collections::HashMap;
use std::fmt::Write;
use std::future::poll_fn;
use std::io;
use std::net::{IpAddr, SocketAddr};
use std::pin::Pin;
use std::task::{Context, Poll};
// Cap on header bytes accepted while reading an HTTP CONNECT response
// (presumably enforced by `establish_http_connect_tunnel`, defined later in
// this file — confirm there).
const CONNECT_MAX_HEADERS_SIZE: usize = 64 * 1024;
// Per-host cap on stored cookies; enforced in `store_response_cookies`.
const MAX_COOKIES_PER_HOST: usize = 64;
// Cap on distinct hosts tracked by the cookie store; enforced in
// `store_response_cookies`.
const MAX_COOKIE_HOSTS: usize = 256;
// SOCKS5 wire constants (RFC 1928 / RFC 1929 values).
const SOCKS5_VERSION: u8 = 0x05;
const SOCKS5_AUTH_NONE: u8 = 0x00;
const SOCKS5_AUTH_USER_PASS: u8 = 0x02;
const SOCKS5_AUTH_NO_ACCEPTABLE: u8 = 0xFF;
/// Errors produced by [`HttpClient`] operations.
#[derive(Debug)]
pub enum ClientError {
    /// The URL could not be parsed (unsupported scheme, bad authority, …).
    InvalidUrl(String),
    /// Hostname resolution failed.
    DnsError(io::Error),
    /// The TCP (or proxy) connection could not be established.
    ConnectError(io::Error),
    /// TLS setup or handshake failed; carries a display string from the TLS layer.
    TlsError(String),
    /// Protocol-level HTTP error from the codec.
    HttpError(crate::http::h1::codec::HttpError),
    /// The redirect limit was exceeded.
    TooManyRedirects {
        /// Number of redirects attempted when the limit was hit.
        count: u32,
        /// Configured maximum.
        max: u32,
    },
    /// General I/O failure after the connection was established.
    Io(io::Error),
    /// The proxy answered a CONNECT request with a non-success status.
    ConnectTunnelRefused {
        status: u16,
        reason: String,
    },
    /// Invalid input supplied to the CONNECT tunnel API.
    InvalidConnectInput(String),
    /// Proxy negotiation failure (e.g. SOCKS5 handshake).
    ProxyError(String),
    /// No connection slot could be acquired for the given origin.
    PoolExhausted {
        host: String,
        port: u16,
    },
    /// The operation was cancelled via the `Cx` context.
    Cancelled,
}
impl std::fmt::Display for ClientError {
    // Human-readable, one-line rendering of each variant; inner errors are
    // rendered via their own `Display`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidUrl(url) => write!(f, "invalid URL: {url}"),
            Self::DnsError(e) => write!(f, "DNS resolution failed: {e}"),
            Self::ConnectError(e) => write!(f, "connection failed: {e}"),
            Self::TlsError(e) => write!(f, "TLS error: {e}"),
            Self::HttpError(e) => write!(f, "HTTP error: {e}"),
            Self::TooManyRedirects { count, max } => {
                write!(f, "too many redirects ({count} of max {max})")
            }
            Self::Io(e) => write!(f, "I/O error: {e}"),
            Self::ConnectTunnelRefused { status, reason } => {
                write!(
                    f,
                    "HTTP CONNECT tunnel rejected with status {status} ({reason})"
                )
            }
            Self::InvalidConnectInput(msg) => write!(f, "invalid CONNECT input: {msg}"),
            Self::ProxyError(msg) => write!(f, "proxy error: {msg}"),
            Self::PoolExhausted { host, port } => {
                write!(f, "connection pool exhausted for {host}:{port}")
            }
            Self::Cancelled => write!(f, "operation cancelled"),
        }
    }
}
impl std::error::Error for ClientError {
    // Expose the underlying error for variants that wrap one; all other
    // variants are leaf errors. Variants are listed exhaustively (no `_`)
    // so adding a new variant forces this match to be revisited.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::DnsError(e) | Self::ConnectError(e) | Self::Io(e) => Some(e),
            Self::HttpError(e) => Some(e),
            Self::ConnectTunnelRefused { .. }
            | Self::InvalidConnectInput(_)
            | Self::ProxyError(_)
            | Self::PoolExhausted { .. }
            | Self::TlsError(_)
            | Self::InvalidUrl(_)
            | Self::TooManyRedirects { .. }
            | Self::Cancelled => None,
        }
    }
}
impl ClientError {
    /// Returns `true` if this error is [`ClientError::Cancelled`], i.e. the
    /// request was aborted via its `Cx` context rather than failing.
    #[must_use]
    pub fn is_cancelled(&self) -> bool {
        matches!(self, Self::Cancelled)
    }
}
/// Checks the cancellation context, translating a failed checkpoint into
/// [`ClientError::Cancelled`].
///
/// # Errors
/// Returns `ClientError::Cancelled` when `cx.checkpoint()` reports an error
/// (i.e. the operation has been cancelled).
fn check_cx(cx: &Cx) -> Result<(), ClientError> {
    // `map(|_| ())` keeps this agnostic to the checkpoint's Ok payload;
    // combinators replace the original if/else branch on `is_err()`.
    cx.checkpoint()
        .map(|_| ())
        .map_err(|_| ClientError::Cancelled)
}
/// Default `time_getter` for [`HttpClientConfig`]: reads the wall clock.
/// Kept as a named `fn` so it can be stored as a plain function pointer.
fn wall_clock_now() -> Time {
    crate::time::wall_now()
}
// Allow `?` on codec results inside client methods.
impl From<crate::http::h1::codec::HttpError> for ClientError {
    fn from(e: crate::http::h1::codec::HttpError) -> Self {
        Self::HttpError(e)
    }
}
// Allow `?` on raw I/O results; note this maps to `Io`, not `ConnectError` —
// connection-phase failures are wrapped explicitly at the call sites.
impl From<io::Error> for ClientError {
    fn from(e: io::Error) -> Self {
        Self::Io(e)
    }
}
/// A minimally parsed `http`/`https` URL: scheme, host, port, and the
/// path-plus-query-plus-fragment remainder.
#[derive(Debug, Clone)]
pub struct ParsedUrl {
    pub scheme: Scheme,
    /// Host as written; bracketed IPv6 literals keep their brackets.
    pub host: String,
    /// Explicit port, or the scheme default (80/443) when absent.
    pub port: u16,
    /// Request target starting at `/`; never empty ("/" when the URL had none).
    pub path: String,
}
/// URL scheme supported by the client.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Scheme {
    Http,
    Https,
}

impl Scheme {
    /// The well-known default port for this scheme (80 for HTTP, 443 for
    /// HTTPS). Centralizes a mapping that is otherwise repeated at several
    /// call sites in this module.
    #[must_use]
    pub const fn default_port(self) -> u16 {
        match self {
            Self::Http => 80,
            Self::Https => 443,
        }
    }
}
/// Transport used for a client connection: plain TCP, direct TLS, or TLS
/// layered over an HTTP CONNECT tunnel (which itself wraps a `ClientIo`).
#[derive(Debug)]
#[allow(clippy::large_enum_variant)]
pub enum ClientIo {
    Plain(TcpStream),
    #[cfg(feature = "tls")]
    Tls(TlsStream<TcpStream>),
    // Boxed: this variant nests a full TLS session over a tunnel over
    // another `ClientIo`, which would otherwise bloat the enum.
    #[cfg(feature = "tls")]
    TlsTunnel(Box<TlsStream<HttpConnectTunnel<Self>>>),
}
// Straight delegation: each variant forwards the poll to its inner stream.
// All inner types are `Unpin`, so re-pinning with `Pin::new` is sound.
impl AsyncRead for ClientIo {
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        match &mut *self {
            Self::Plain(s) => Pin::new(s).poll_read(cx, buf),
            #[cfg(feature = "tls")]
            Self::Tls(s) => Pin::new(s).poll_read(cx, buf),
            #[cfg(feature = "tls")]
            Self::TlsTunnel(s) => Pin::new(s.as_mut()).poll_read(cx, buf),
        }
    }
}
// Straight delegation of write/flush/shutdown to the active variant.
impl AsyncWrite for ClientIo {
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        data: &[u8],
    ) -> Poll<io::Result<usize>> {
        match &mut *self {
            Self::Plain(s) => Pin::new(s).poll_write(cx, data),
            #[cfg(feature = "tls")]
            Self::Tls(s) => Pin::new(s).poll_write(cx, data),
            #[cfg(feature = "tls")]
            Self::TlsTunnel(s) => Pin::new(s.as_mut()).poll_write(cx, data),
        }
    }
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        match &mut *self {
            Self::Plain(s) => Pin::new(s).poll_flush(cx),
            #[cfg(feature = "tls")]
            Self::Tls(s) => Pin::new(s).poll_flush(cx),
            #[cfg(feature = "tls")]
            Self::TlsTunnel(s) => Pin::new(s.as_mut()).poll_flush(cx),
        }
    }
    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        match &mut *self {
            Self::Plain(s) => Pin::new(s).poll_shutdown(cx),
            #[cfg(feature = "tls")]
            Self::Tls(s) => Pin::new(s).poll_shutdown(cx),
            #[cfg(feature = "tls")]
            Self::TlsTunnel(s) => Pin::new(s.as_mut()).poll_shutdown(cx),
        }
    }
}
/// A stream established through an HTTP CONNECT proxy.
///
/// Any bytes the proxy sent after its CONNECT response headers (e.g. the
/// start of the tunneled data) are buffered in `prefetched` and replayed to
/// the first reads before the underlying `io` is polled again.
#[derive(Debug)]
pub struct HttpConnectTunnel<T> {
    io: T,
    // Bytes read past the end of the CONNECT response; served first.
    prefetched: Vec<u8>,
    // Read cursor into `prefetched`.
    prefetched_pos: usize,
}
impl<T> HttpConnectTunnel<T> {
fn new(io: T, prefetched: Vec<u8>) -> Self {
Self {
io,
prefetched,
prefetched_pos: 0,
}
}
#[must_use]
pub fn prefetched_len(&self) -> usize {
self.prefetched.len().saturating_sub(self.prefetched_pos)
}
#[must_use]
pub fn into_inner(self) -> T {
self.io
}
}
impl<T> AsyncRead for HttpConnectTunnel<T>
where
    T: AsyncRead + Unpin,
{
    // Serve buffered prefetched bytes first; only once they are exhausted
    // (or the caller's buffer is full) fall through to the inner stream.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        if self.prefetched_pos < self.prefetched.len() && buf.remaining() > 0 {
            let remaining_prefetched = self.prefetched.len() - self.prefetched_pos;
            let to_copy = remaining_prefetched.min(buf.remaining());
            buf.put_slice(&self.prefetched[self.prefetched_pos..self.prefetched_pos + to_copy]);
            self.prefetched_pos += to_copy;
            // Return immediately rather than also polling `io`: a short read
            // is fine, and mixing the two sources in one call is not needed.
            return Poll::Ready(Ok(()));
        }
        Pin::new(&mut self.io).poll_read(cx, buf)
    }
}
// Writes bypass the prefetch buffer entirely and go straight to the inner
// stream — the buffer only affects the read path.
impl<T> AsyncWrite for HttpConnectTunnel<T>
where
    T: AsyncWrite + Unpin,
{
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        data: &[u8],
    ) -> Poll<io::Result<usize>> {
        Pin::new(&mut self.io).poll_write(cx, data)
    }
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.io).poll_flush(cx)
    }
    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.io).poll_shutdown(cx)
    }
}
impl ParsedUrl {
    /// Parses an `http://` or `https://` URL into scheme, host, port, and
    /// request path.
    ///
    /// Only those two schemes are accepted. Userinfo (`user@host`) is
    /// rejected, as are control characters or whitespace in the authority.
    /// Bracketed IPv6 literals keep their brackets in `host`; an authority
    /// with multiple colons and no brackets is treated as a bare IPv6 host
    /// with the scheme-default port.
    ///
    /// # Errors
    /// Returns [`ClientError::InvalidUrl`] for unsupported schemes, malformed
    /// authorities, invalid ports, or an empty host.
    pub fn parse(url: &str) -> Result<Self, ClientError> {
        let (scheme, rest) = if let Some(rest) = url.strip_prefix("https://") {
            (Scheme::Https, rest)
        } else if let Some(rest) = url.strip_prefix("http://") {
            (Scheme::Http, rest)
        } else {
            return Err(ClientError::InvalidUrl(format!(
                "unsupported scheme in: {url}"
            )));
        };
        // Hoisted: previously this scheme -> port match was repeated in three
        // separate branches below.
        let default_port = match scheme {
            Scheme::Http => 80,
            Scheme::Https => 443,
        };
        let authority_end = rest.find(['/', '?', '#']).unwrap_or(rest.len());
        let authority = &rest[..authority_end];
        let path_and_rest = &rest[authority_end..];
        let path = if path_and_rest.is_empty() {
            "/"
        } else {
            path_and_rest
        };
        if authority.contains('@') && !authority.starts_with('[') {
            return Err(ClientError::InvalidUrl(
                "URL must not contain userinfo (user@host)".into(),
            ));
        }
        if contains_ctl_or_whitespace(authority) {
            return Err(ClientError::InvalidUrl(
                "URL authority cannot contain control characters or whitespace".into(),
            ));
        }
        let (host, port) = if authority.starts_with('[') {
            // Bracketed IPv6 literal, optionally followed by ":port".
            let bracket_end = authority.find(']').ok_or_else(|| {
                ClientError::InvalidUrl("unclosed bracket in IPv6 address".into())
            })?;
            let host_str = &authority[..=bracket_end];
            let rest = &authority[bracket_end + 1..];
            if let Some(port_str) = rest.strip_prefix(':') {
                let port: u16 = port_str
                    .parse()
                    .map_err(|_| ClientError::InvalidUrl(format!("invalid port: {port_str}")))?;
                (host_str.to_owned(), port)
            } else if rest.is_empty() {
                (host_str.to_owned(), default_port)
            } else {
                return Err(ClientError::InvalidUrl(format!(
                    "unexpected characters after IPv6 address: {rest}"
                )));
            }
        } else if authority.matches(':').count() > 1 {
            // Multiple colons without brackets: treat the whole authority as a
            // bare IPv6 address with no explicit port.
            (authority.to_owned(), default_port)
        } else if let Some(i) = authority.rfind(':') {
            let port_str = &authority[i + 1..];
            let port: u16 = port_str
                .parse()
                .map_err(|_| ClientError::InvalidUrl(format!("invalid port: {port_str}")))?;
            (authority[..i].to_owned(), port)
        } else {
            (authority.to_owned(), default_port)
        };
        if host.is_empty() {
            return Err(ClientError::InvalidUrl("empty host".into()));
        }
        Ok(Self {
            scheme,
            host,
            port,
            path: path.to_owned(),
        })
    }
    /// Connection-pool key for this origin (host, port, TLS flag).
    #[must_use]
    pub fn pool_key(&self) -> PoolKey {
        PoolKey::new(&self.host, self.port, self.scheme == Scheme::Https)
    }
    /// Authority string for the `Host` header; the port is omitted when it
    /// equals the scheme default.
    #[must_use]
    pub fn authority(&self) -> String {
        let default_port = match self.scheme {
            Scheme::Http => 80,
            Scheme::Https => 443,
        };
        if self.port == default_port {
            self.host.clone()
        } else {
            format!("{}:{}", self.host, self.port)
        }
    }
}
/// How the client reacts to 3xx responses.
#[derive(Debug, Clone)]
pub enum RedirectPolicy {
    /// Return the redirect response to the caller unchanged.
    None,
    /// Follow up to this many redirects before erroring.
    Limited(u32),
}
impl Default for RedirectPolicy {
    /// Follow at most 10 redirects by default.
    fn default() -> Self {
        Self::Limited(10)
    }
}
/// Fluent builder for [`HttpClient`]; wraps an [`HttpClientConfig`] that each
/// setter mutates before [`HttpClientBuilder::build`] constructs the client.
#[derive(Debug, Clone, Default)]
pub struct HttpClientBuilder {
    config: HttpClientConfig,
}
impl HttpClientBuilder {
    /// Starts a builder with the default configuration.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }
    /// Replaces the entire pool configuration.
    #[must_use]
    pub fn pool_config(mut self, pool_config: PoolConfig) -> Self {
        self.config.pool_config = pool_config;
        self
    }
    /// Sets the per-host connection limit.
    #[must_use]
    pub fn max_connections_per_host(mut self, max: usize) -> Self {
        self.config.pool_config.max_connections_per_host = max;
        self
    }
    /// Sets the total connection limit across all hosts.
    #[must_use]
    pub fn max_total_connections(mut self, max: usize) -> Self {
        self.config.pool_config.max_total_connections = max;
        self
    }
    /// Sets how long idle pooled connections are kept before expiry.
    #[must_use]
    pub fn idle_timeout(mut self, timeout: std::time::Duration) -> Self {
        self.config.pool_config.idle_timeout = timeout;
        self
    }
    /// Sets how often expired pool entries are swept.
    #[must_use]
    pub fn cleanup_interval(mut self, interval: std::time::Duration) -> Self {
        self.config.pool_config.cleanup_interval = interval;
        self
    }
    /// Sets the redirect-following policy.
    #[must_use]
    pub fn redirect_policy(mut self, policy: RedirectPolicy) -> Self {
        self.config.redirect_policy = policy;
        self
    }
    /// Shorthand for `redirect_policy(RedirectPolicy::Limited(max))`.
    #[must_use]
    pub fn max_redirects(mut self, max: u32) -> Self {
        self.config.redirect_policy = RedirectPolicy::Limited(max);
        self
    }
    /// Disables redirect following entirely.
    #[must_use]
    pub fn no_redirects(mut self) -> Self {
        self.config.redirect_policy = RedirectPolicy::None;
        self
    }
    /// Sets the `User-Agent` header sent with each request.
    #[must_use]
    pub fn user_agent(mut self, user_agent: impl Into<String>) -> Self {
        self.config.user_agent = Some(user_agent.into());
        self
    }
    /// Suppresses the `User-Agent` header.
    #[must_use]
    pub fn no_user_agent(mut self) -> Self {
        self.config.user_agent = None;
        self
    }
    /// Enables or disables the in-memory cookie store (off by default).
    #[must_use]
    pub fn cookie_store(mut self, enabled: bool) -> Self {
        self.config.cookie_store = enabled;
        self
    }
    /// Disables the cookie store.
    #[must_use]
    pub fn no_cookie_store(mut self) -> Self {
        self.config.cookie_store = false;
        self
    }
    /// Caps response body size; larger bodies fail the request.
    #[must_use]
    pub fn max_body_size(mut self, size: usize) -> Self {
        self.config.max_body_size = Some(size);
        self
    }
    /// Routes all requests through the given proxy URL
    /// (HTTP/HTTPS/SOCKS5 — parsed later by `parse_proxy_endpoint`).
    #[must_use]
    pub fn proxy(mut self, proxy_url: impl Into<String>) -> Self {
        self.config.proxy_url = Some(proxy_url.into());
        self
    }
    /// Clears any configured proxy.
    #[must_use]
    pub fn no_proxy(mut self) -> Self {
        self.config.proxy_url = None;
        self
    }
    /// Overrides the clock used for pool bookkeeping (useful in tests).
    #[must_use]
    pub fn with_time_getter(mut self, time_getter: fn() -> Time) -> Self {
        self.config.time_getter = time_getter;
        self
    }
    /// Builds the client, consuming the builder.
    #[must_use]
    pub fn build(self) -> HttpClient {
        HttpClient::with_config(self.config)
    }
}
/// Complete configuration for an [`HttpClient`].
#[derive(Debug, Clone)]
pub struct HttpClientConfig {
    pub pool_config: PoolConfig,
    pub redirect_policy: RedirectPolicy,
    /// `User-Agent` header value; `None` suppresses the header.
    pub user_agent: Option<String>,
    /// Whether Set-Cookie responses are stored and replayed.
    pub cookie_store: bool,
    /// Optional proxy URL applied to every request.
    pub proxy_url: Option<String>,
    /// Optional cap on response body size.
    pub max_body_size: Option<usize>,
    // Private: clock source for pool timing; settable via the builder or
    // `with_time_getter` so tests can use a deterministic clock.
    time_getter: fn() -> Time,
}
impl Default for HttpClientConfig {
    /// Defaults: pool defaults, up to 10 redirects, a crate User-Agent,
    /// no cookie store, no proxy, unbounded bodies, wall-clock time.
    fn default() -> Self {
        Self {
            pool_config: PoolConfig::default(),
            redirect_policy: RedirectPolicy::default(),
            user_agent: Some("asupersync/0.1".into()),
            cookie_store: false,
            proxy_url: None,
            max_body_size: None,
            time_getter: wall_clock_now,
        }
    }
}
impl HttpClientConfig {
    /// Returns the config with its clock source replaced.
    #[must_use]
    pub const fn with_time_getter(mut self, time_getter: fn() -> Time) -> Self {
        self.time_getter = time_getter;
        self
    }
    /// The clock source used for pool bookkeeping.
    #[must_use]
    pub const fn time_getter(&self) -> fn() -> Time {
        self.time_getter
    }
}
/// HTTP/1.x client with connection pooling, optional cookie storage,
/// redirect following, and HTTP/SOCKS5 proxy support.
pub struct HttpClient {
    config: HttpClientConfig,
    // Slot accounting (limits, idle expiry). Locked separately from
    // `idle_connections` so bookkeeping and socket storage stay independent.
    pool: Mutex<Pool>,
    // Idle sockets per origin; each entry pairs a pool slot id with its IO.
    idle_connections: Mutex<HashMap<PoolKey, Vec<(u64, ClientIo)>>>,
    // host -> cookies; populated only when `config.cookie_store` is true.
    cookies: Mutex<HashMap<String, Vec<StoredCookie>>>,
}
impl HttpClient {
#[must_use]
pub fn builder() -> HttpClientBuilder {
HttpClientBuilder::new()
}
#[must_use]
pub fn new() -> Self {
Self::with_config(HttpClientConfig::default())
}
#[must_use]
pub fn with_config(config: HttpClientConfig) -> Self {
let pool = Pool::with_config(config.pool_config.clone());
Self {
config,
pool: Mutex::new(pool),
idle_connections: Mutex::new(HashMap::new()),
cookies: Mutex::new(HashMap::new()),
}
}
fn pool_now(&self) -> Time {
(self.config.time_getter)()
}
pub async fn get(&self, cx: &Cx, url: &str) -> Result<Response, ClientError> {
self.request(cx, Method::Get, url, Vec::new(), Vec::new())
.await
}
pub async fn post(&self, cx: &Cx, url: &str, body: Vec<u8>) -> Result<Response, ClientError> {
self.request(cx, Method::Post, url, Vec::new(), body).await
}
pub async fn post_multipart(
&self,
cx: &Cx,
url: &str,
form: &MultipartForm,
) -> Result<Response, ClientError> {
self.request_multipart(cx, Method::Post, url, Vec::new(), form)
.await
}
pub async fn post_streaming(
&self,
cx: &Cx,
url: &str,
body: Vec<u8>,
) -> Result<ClientStreamingResponse<ClientIo>, ClientError> {
self.request_streaming(cx, Method::Post, url, Vec::new(), body)
.await
}
pub async fn post_multipart_streaming(
&self,
cx: &Cx,
url: &str,
form: &MultipartForm,
) -> Result<ClientStreamingResponse<ClientIo>, ClientError> {
self.request_multipart_streaming(cx, Method::Post, url, Vec::new(), form)
.await
}
pub async fn put(&self, cx: &Cx, url: &str, body: Vec<u8>) -> Result<Response, ClientError> {
self.request(cx, Method::Put, url, Vec::new(), body).await
}
pub async fn delete(&self, cx: &Cx, url: &str) -> Result<Response, ClientError> {
self.request(cx, Method::Delete, url, Vec::new(), Vec::new())
.await
}
pub async fn request(
&self,
cx: &Cx,
method: Method,
url: &str,
extra_headers: Vec<(String, String)>,
body: Vec<u8>,
) -> Result<Response, ClientError> {
check_cx(cx)?;
let parsed = ParsedUrl::parse(url)?;
self.execute_with_redirects(cx, method, parsed, extra_headers, body, 0)
.await
}
pub async fn request_multipart(
&self,
cx: &Cx,
method: Method,
url: &str,
mut extra_headers: Vec<(String, String)>,
form: &MultipartForm,
) -> Result<Response, ClientError> {
ensure_multipart_content_type(&mut extra_headers, form);
self.request(cx, method, url, extra_headers, form.to_body())
.await
}
pub async fn request_streaming(
&self,
cx: &Cx,
method: Method,
url: &str,
extra_headers: Vec<(String, String)>,
body: Vec<u8>,
) -> Result<ClientStreamingResponse<ClientIo>, ClientError> {
check_cx(cx)?;
let parsed = ParsedUrl::parse(url)?;
self.execute_with_redirects_streaming(cx, method, parsed, extra_headers, body, 0)
.await
}
pub async fn request_multipart_streaming(
&self,
cx: &Cx,
method: Method,
url: &str,
mut extra_headers: Vec<(String, String)>,
form: &MultipartForm,
) -> Result<ClientStreamingResponse<ClientIo>, ClientError> {
ensure_multipart_content_type(&mut extra_headers, form);
self.request_streaming(cx, method, url, extra_headers, form.to_body())
.await
}
pub async fn connect_tunnel(
&self,
cx: &Cx,
proxy_url: &str,
target_authority: &str,
extra_headers: Vec<(String, String)>,
) -> Result<HttpConnectTunnel<ClientIo>, ClientError> {
check_cx(cx)?;
let proxy = ParsedUrl::parse(proxy_url)?;
let io = self.connect_io(cx, &proxy).await?;
establish_http_connect_tunnel(
io,
target_authority,
self.config.user_agent.as_deref(),
&extra_headers,
)
.await
}
fn execute_with_redirects<'a>(
&'a self,
cx: &'a Cx,
method: Method,
parsed: ParsedUrl,
extra_headers: Vec<(String, String)>,
body: Vec<u8>,
redirect_count: u32,
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<Response, ClientError>> + Send + 'a>,
> {
Box::pin(async move {
check_cx(cx)?;
let resp =
Box::pin(self.execute_single(cx, &method, &parsed, &extra_headers, &body)).await?;
if is_redirect(resp.status) {
match &self.config.redirect_policy {
RedirectPolicy::None => return Ok(resp),
RedirectPolicy::Limited(max) => {
if redirect_count >= *max {
return Err(ClientError::TooManyRedirects {
count: redirect_count + 1,
max: *max,
});
}
if let Some(location) = get_header(&resp.headers, "Location") {
let next_url = resolve_redirect(&parsed, &location);
let next_parsed = ParsedUrl::parse(&next_url)?;
let next_method = redirect_method(resp.status, &method);
let next_body = if next_method == Method::Get {
Vec::new()
} else {
body
};
let next_headers = strip_sensitive_headers_on_redirect(
&parsed,
&next_parsed,
extra_headers,
);
return self
.execute_with_redirects(
cx,
next_method,
next_parsed,
next_headers,
next_body,
redirect_count + 1,
)
.await;
}
}
}
}
Ok(resp)
})
}
fn execute_with_redirects_streaming<'a>(
&'a self,
cx: &'a Cx,
method: Method,
parsed: ParsedUrl,
extra_headers: Vec<(String, String)>,
body: Vec<u8>,
redirect_count: u32,
) -> std::pin::Pin<
Box<
dyn std::future::Future<Output = Result<ClientStreamingResponse<ClientIo>, ClientError>>
+ Send
+ 'a,
>,
> {
Box::pin(async move {
check_cx(cx)?;
let resp = self
.execute_single_streaming(cx, &method, &parsed, &extra_headers, &body)
.await?;
if is_redirect(resp.head.status) {
match &self.config.redirect_policy {
RedirectPolicy::None => return Ok(resp),
RedirectPolicy::Limited(max) => {
if redirect_count >= *max {
return Err(ClientError::TooManyRedirects {
count: redirect_count + 1,
max: *max,
});
}
if let Some(location) = get_header(&resp.head.headers, "Location") {
let status = resp.head.status;
drop(resp);
let next_url = resolve_redirect(&parsed, &location);
let next_parsed = ParsedUrl::parse(&next_url)?;
let next_method = redirect_method(status, &method);
let next_body = if next_method == Method::Get {
Vec::new()
} else {
body
};
let next_headers = strip_sensitive_headers_on_redirect(
&parsed,
&next_parsed,
extra_headers,
);
return self
.execute_with_redirects_streaming(
cx,
next_method,
next_parsed,
next_headers,
next_body,
redirect_count + 1,
)
.await;
}
}
}
}
Ok(resp)
})
}
async fn execute_single(
&self,
cx: &Cx,
method: &Method,
parsed: &ParsedUrl,
extra_headers: &[(String, String)],
body: &[u8],
) -> Result<Response, ClientError> {
check_cx(cx)?;
if let Some(proxy_url) = self.config.proxy_url.as_deref() {
return self
.execute_single_with_proxy(cx, method, parsed, extra_headers, body, proxy_url)
.await;
}
let req = self.build_request(method, parsed, extra_headers, body, None, None);
let request_forbids_reuse = request_forbids_connection_reuse(&req.headers);
let key = parsed.pool_key();
let acquired = self.acquire_connection(cx, parsed).await?;
let mut guard = ConnectionGuard::new(self, key.clone(), acquired.pool_id);
let reused_connection = !acquired.fresh;
check_cx(cx)?;
let result = if let Some(max_body_size) = self.config.max_body_size {
Http1Client::request_with_io_and_max_body_size(acquired.io, req, max_body_size).await
} else {
Http1Client::request_with_io(acquired.io, req).await
};
match result {
Ok((response, io)) => {
check_cx(cx)?;
guard.defused = true;
self.store_response_cookies(&parsed.host, &response.headers);
if !request_forbids_reuse && connection_can_be_reused(&response, method) {
self.release_connection(&key, acquired.pool_id, acquired.fresh, io);
} else {
self.drop_connection(&key, acquired.pool_id);
}
Ok(response)
}
Err(err) => {
let err = ClientError::from(err);
if reused_connection && should_retry_reused_connection_failure(method, &err) {
drop(guard);
return self
.retry_single_request_on_fresh_connection(
cx,
method,
parsed,
extra_headers,
body,
)
.await;
}
Err(err)
}
}
}
async fn retry_single_request_on_fresh_connection(
&self,
cx: &Cx,
method: &Method,
parsed: &ParsedUrl,
extra_headers: &[(String, String)],
body: &[u8],
) -> Result<Response, ClientError> {
let req = self.build_request(method, parsed, extra_headers, body, None, None);
let request_forbids_reuse = request_forbids_connection_reuse(&req.headers);
let key = parsed.pool_key();
let acquired = self.acquire_connection(cx, parsed).await?;
let mut guard = ConnectionGuard::new(self, key.clone(), acquired.pool_id);
check_cx(cx)?;
let result = if let Some(max_body_size) = self.config.max_body_size {
Http1Client::request_with_io_and_max_body_size(acquired.io, req, max_body_size).await
} else {
Http1Client::request_with_io(acquired.io, req).await
};
match result {
Ok((response, io)) => {
check_cx(cx)?;
guard.defused = true;
self.store_response_cookies(&parsed.host, &response.headers);
if !request_forbids_reuse && connection_can_be_reused(&response, method) {
self.release_connection(&key, acquired.pool_id, acquired.fresh, io);
} else {
self.drop_connection(&key, acquired.pool_id);
}
Ok(response)
}
Err(err) => Err(ClientError::from(err)),
}
}
async fn execute_single_streaming(
&self,
cx: &Cx,
method: &Method,
parsed: &ParsedUrl,
extra_headers: &[(String, String)],
body: &[u8],
) -> Result<ClientStreamingResponse<ClientIo>, ClientError> {
check_cx(cx)?;
if let Some(proxy_url) = self.config.proxy_url.as_deref() {
return self
.execute_single_streaming_with_proxy(
cx,
method,
parsed,
extra_headers,
body,
proxy_url,
)
.await;
}
let req = self.build_request(method, parsed, extra_headers, body, None, None);
let stream = self.connect_io(cx, parsed).await?;
check_cx(cx)?;
let resp = if let Some(max_body_size) = self.config.max_body_size {
Http1Client::request_streaming_with_max_body_size(stream, req, max_body_size).await?
} else {
Http1Client::request_streaming(stream, req).await?
};
check_cx(cx)?;
self.store_response_cookies(&parsed.host, &resp.head.headers);
Ok(resp)
}
async fn execute_single_with_proxy(
&self,
cx: &Cx,
method: &Method,
parsed: &ParsedUrl,
extra_headers: &[(String, String)],
body: &[u8],
proxy_url: &str,
) -> Result<Response, ClientError> {
check_cx(cx)?;
let proxy = parse_proxy_endpoint(proxy_url)?;
let proxy_conn = self.connect_via_proxy(cx, parsed, &proxy).await?;
check_cx(cx)?;
let request_target = if proxy_conn.use_absolute_form {
Some(absolute_request_target(parsed))
} else {
None
};
let req = self.build_request(
method,
parsed,
extra_headers,
body,
request_target,
proxy_conn.proxy_authorization.as_deref(),
);
let (response, _io) = if let Some(max_body_size) = self.config.max_body_size {
Http1Client::request_with_io_and_max_body_size(proxy_conn.io, req, max_body_size)
.await?
} else {
Http1Client::request_with_io(proxy_conn.io, req).await?
};
check_cx(cx)?;
self.store_response_cookies(&parsed.host, &response.headers);
Ok(response)
}
async fn execute_single_streaming_with_proxy(
&self,
cx: &Cx,
method: &Method,
parsed: &ParsedUrl,
extra_headers: &[(String, String)],
body: &[u8],
proxy_url: &str,
) -> Result<ClientStreamingResponse<ClientIo>, ClientError> {
check_cx(cx)?;
let proxy = parse_proxy_endpoint(proxy_url)?;
let proxy_conn = self.connect_via_proxy(cx, parsed, &proxy).await?;
check_cx(cx)?;
let request_target = if proxy_conn.use_absolute_form {
Some(absolute_request_target(parsed))
} else {
None
};
let req = self.build_request(
method,
parsed,
extra_headers,
body,
request_target,
proxy_conn.proxy_authorization.as_deref(),
);
let resp = if let Some(max_body_size) = self.config.max_body_size {
Http1Client::request_streaming_with_max_body_size(proxy_conn.io, req, max_body_size)
.await?
} else {
Http1Client::request_streaming(proxy_conn.io, req).await?
};
check_cx(cx)?;
self.store_response_cookies(&parsed.host, &resp.head.headers);
Ok(resp)
}
async fn connect_via_proxy(
&self,
cx: &Cx,
parsed: &ParsedUrl,
proxy: &ProxyEndpoint,
) -> Result<ProxyConnection, ClientError> {
match proxy.scheme {
ProxyScheme::Http | ProxyScheme::Https => {
let proxy_parsed = ParsedUrl {
scheme: match proxy.scheme {
ProxyScheme::Http => Scheme::Http,
ProxyScheme::Https => Scheme::Https,
ProxyScheme::Socks5 => unreachable!(),
},
host: proxy.host.clone(),
port: proxy.port,
path: "/".to_owned(),
};
let proxy_io = self.connect_io(cx, &proxy_parsed).await?;
if parsed.scheme == Scheme::Http {
return Ok(ProxyConnection {
io: proxy_io,
use_absolute_form: true,
proxy_authorization: proxy
.http_proxy_authorization()
.map(std::borrow::ToOwned::to_owned),
});
}
let mut connect_headers = Vec::new();
if let Some(auth) = proxy.http_proxy_authorization() {
connect_headers.push(("Proxy-Authorization".to_owned(), auth.to_owned()));
}
let tunnel = establish_http_connect_tunnel(
proxy_io,
&parsed.authority(),
self.config.user_agent.as_deref(),
&connect_headers,
)
.await?;
#[cfg(feature = "tls")]
{
let domain = parsed.host.trim_start_matches('[').trim_end_matches(']');
let tls = self.tls_connect_stream(domain, tunnel).await?;
Ok(ProxyConnection {
io: ClientIo::TlsTunnel(Box::new(tls)),
use_absolute_form: false,
proxy_authorization: None,
})
}
#[cfg(not(feature = "tls"))]
{
let _ = tunnel;
Err(ClientError::TlsError(
"TLS support is disabled (enable asupersync feature \"tls\")".into(),
))
}
}
ProxyScheme::Socks5 => {
let tcp = connect_via_socks5(proxy, parsed, cx).await?;
if parsed.scheme == Scheme::Http {
return Ok(ProxyConnection {
io: ClientIo::Plain(tcp),
use_absolute_form: false,
proxy_authorization: None,
});
}
#[cfg(feature = "tls")]
{
let domain = parsed.host.trim_start_matches('[').trim_end_matches(']');
let tls = self.tls_connect_stream(domain, tcp).await?;
Ok(ProxyConnection {
io: ClientIo::Tls(tls),
use_absolute_form: false,
proxy_authorization: None,
})
}
#[cfg(not(feature = "tls"))]
{
let _ = tcp;
Err(ClientError::TlsError(
"TLS support is disabled (enable asupersync feature \"tls\")".into(),
))
}
}
}
}
fn build_request(
&self,
method: &Method,
parsed: &ParsedUrl,
extra_headers: &[(String, String)],
body: &[u8],
request_target: Option<String>,
proxy_authorization: Option<&str>,
) -> Request {
let has_cookie_header = extra_headers
.iter()
.any(|(name, _)| name.eq_ignore_ascii_case("cookie"));
let has_proxy_authorization = extra_headers
.iter()
.any(|(name, _)| name.eq_ignore_ascii_case("proxy-authorization"));
let request_target = request_target.unwrap_or_else(|| parsed.path.clone());
let mut builder =
Request::builder(method.clone(), request_target).header("Host", parsed.authority());
if let Some(user_agent) = self.config.user_agent.as_deref() {
builder = builder.header("User-Agent", user_agent);
}
if self.config.cookie_store
&& !has_cookie_header
&& let Some(cookie_header) = self.cookie_header_for_host(&parsed.host)
{
builder = builder.header("Cookie", cookie_header);
}
if !has_proxy_authorization && let Some(value) = proxy_authorization {
builder = builder.header("Proxy-Authorization", value);
}
builder
.headers(
extra_headers
.iter()
.filter(|(name, _)| !name.eq_ignore_ascii_case("host"))
.cloned(),
)
.body(body.to_vec())
.build()
}
fn store_response_cookies(&self, host: &str, headers: &[(String, String)]) {
if !self.config.cookie_store {
return;
}
let host = canonical_cookie_host(host);
let mut cookies = self.cookies.lock();
if !cookies.contains_key(&host) && cookies.len() >= MAX_COOKIE_HOSTS {
return;
}
let mut touched = false;
{
let entry = cookies.entry(host.clone()).or_default();
for (_, value) in headers
.iter()
.filter(|(name, _)| name.eq_ignore_ascii_case("set-cookie"))
{
if let Some((name, value)) = parse_set_cookie_pair(value) {
touched = true;
if value.is_empty() {
entry.retain(|cookie| !cookie.name.eq_ignore_ascii_case(&name));
continue;
}
if let Some(existing) = entry
.iter_mut()
.find(|cookie| cookie.name.eq_ignore_ascii_case(&name))
{
existing.value = value;
} else if entry.len() < MAX_COOKIES_PER_HOST {
entry.push(StoredCookie { name, value });
}
}
}
}
if touched && cookies.get(&host).is_some_and(Vec::is_empty) {
cookies.remove(&host);
}
}
fn cookie_header_for_host(&self, host: &str) -> Option<String> {
let host = canonical_cookie_host(host);
let host_cookies = {
let cookies = self.cookies.lock();
cookies.get(&host)?.clone()
};
if host_cookies.is_empty() {
return None;
}
Some(
host_cookies
.into_iter()
.map(|cookie| format!("{}={}", cookie.name, cookie.value))
.collect::<Vec<_>>()
.join("; "),
)
}
#[cfg(feature = "tls")]
async fn tls_connect_stream<T>(
&self,
domain: &str,
stream: T,
) -> Result<TlsStream<T>, ClientError>
where
T: AsyncRead + AsyncWrite + Unpin,
{
let builder = TlsConnectorBuilder::new().alpn_protocols(vec![b"http/1.1".to_vec()]);
#[cfg(feature = "tls-native-roots")]
let builder = builder
.with_native_roots()
.map_err(|e| ClientError::TlsError(e.to_string()))?;
#[cfg(all(not(feature = "tls-native-roots"), feature = "tls-webpki-roots"))]
let builder = builder.with_webpki_roots();
let connector = builder
.build()
.map_err(|e| ClientError::TlsError(e.to_string()))?;
connector
.connect(domain, stream)
.await
.map_err(|e| ClientError::TlsError(e.to_string()))
}
async fn connect_io(&self, cx: &Cx, parsed: &ParsedUrl) -> Result<ClientIo, ClientError> {
check_cx(cx)?;
let stream = if let Some(socket_addr) = parsed_numeric_socket_addr(parsed) {
TcpStream::connect_socket_addr(socket_addr)
.await
.map_err(ClientError::ConnectError)?
} else {
let addr = format!("{}:{}", parsed.host, parsed.port);
TcpStream::connect(addr)
.await
.map_err(ClientError::ConnectError)?
};
check_cx(cx)?;
match parsed.scheme {
Scheme::Http => Ok(ClientIo::Plain(stream)),
Scheme::Https => {
#[cfg(feature = "tls")]
{
let domain = parsed.host.trim_start_matches('[').trim_end_matches(']');
let tls = self.tls_connect_stream(domain, stream).await?;
check_cx(cx)?;
Ok(ClientIo::Tls(tls))
}
#[cfg(not(feature = "tls"))]
{
let _ = stream;
Err(ClientError::TlsError(
"TLS support is disabled (enable asupersync feature \"tls\")".into(),
))
}
}
}
}
async fn acquire_connection(
&self,
cx: &Cx,
parsed: &ParsedUrl,
) -> Result<AcquiredConnection, ClientError> {
struct ConnectGuard<'a> {
client: &'a HttpClient,
key: PoolKey,
id: Option<u64>,
}
impl Drop for ConnectGuard<'_> {
fn drop(&mut self) {
if let Some(id) = self.id {
self.client.pool.lock().remove(&self.key, id);
}
}
}
let key = parsed.pool_key();
let now = self.pool_now();
self.cleanup_expired_idle_connections(now);
{
let mut pool = self.pool.lock();
let mut idle = self.idle_connections.lock();
match pool.try_acquire(&key, now) {
Some(pool_id) => {
if let Some(io) = Self::take_idle_connection_locked(&mut idle, &key, pool_id) {
return Ok(AcquiredConnection {
pool_id: Some(pool_id),
io,
fresh: false,
});
}
pool.remove(&key, pool_id);
}
None => {}
}
}
let fresh_id = {
let mut pool = self.pool.lock();
if pool.can_create_connection(&key, now) {
Some(pool.register_connecting(key.clone(), now, 1))
} else {
None
}
};
let mut guard = ConnectGuard {
client: self,
key: key.clone(),
id: fresh_id,
};
if fresh_id.is_none() {
return Err(ClientError::PoolExhausted {
host: parsed.host.clone(),
port: parsed.port,
});
}
let io = self.connect_io(cx, parsed).await?;
guard.id = None;
Ok(AcquiredConnection {
pool_id: fresh_id,
io,
fresh: true,
})
}
fn release_connection(&self, key: &PoolKey, pool_id: Option<u64>, fresh: bool, io: ClientIo) {
if let Some(id) = pool_id {
let now = self.pool_now();
let mut pool = self.pool.lock();
let returned_to_pool = if fresh {
pool.mark_connected(key, id, now)
} else {
pool.release(key, id, now)
};
if returned_to_pool {
let mut idle = self.idle_connections.lock();
Self::store_idle_connection_locked(&mut idle, key.clone(), id, io);
} else {
pool.remove(key, id);
let mut idle = self.idle_connections.lock();
Self::remove_idle_connection_locked(&mut idle, key, id);
}
}
}
fn drop_connection(&self, key: &PoolKey, pool_id: Option<u64>) {
if let Some(id) = pool_id {
self.pool.lock().remove(key, id);
self.remove_idle_connection(key, id);
}
}
fn remove_idle_connection(&self, key: &PoolKey, id: u64) {
let mut idle = self.idle_connections.lock();
Self::remove_idle_connection_locked(&mut idle, key, id);
}
/// Drops the idle I/O handles whose pool entries have expired.
///
/// The pool decides which `(key, id)` pairs are stale as of `now`; this
/// method only mirrors those removals into the idle-connection map.
fn cleanup_expired_idle_connections(&self, now: Time) {
    let expired = self.pool.lock().cleanup_expired_entries(now);
    if expired.is_empty() {
        return;
    }
    let mut idle = self.idle_connections.lock();
    for (key, id) in expired {
        let Some(entries) = idle.get_mut(&key) else {
            continue;
        };
        // Pool ids are unique per key, so this removes at most one entry.
        entries.retain(|(entry_id, _)| *entry_id != id);
        if entries.is_empty() {
            idle.remove(&key);
        }
    }
}
/// Returns a point-in-time snapshot of the connection pool's statistics.
pub fn pool_stats(&self) -> crate::http::pool::PoolStats {
    self.pool.lock().stats()
}
}
// Helpers that manipulate the idle-connection map while its lock is held.
// Each takes the locked map explicitly so the caller controls lock scope.
impl HttpClient {
    /// Removes and returns the idle I/O handle for `(key, id)`, pruning
    /// the map entry once the last handle for `key` is taken.
    fn take_idle_connection_locked(
        idle: &mut HashMap<PoolKey, Vec<(u64, ClientIo)>>,
        key: &PoolKey,
        id: u64,
    ) -> Option<ClientIo> {
        let entries = idle.get_mut(key)?;
        let position = entries.iter().position(|(entry_id, _)| *entry_id == id)?;
        let (_, io) = entries.swap_remove(position);
        if entries.is_empty() {
            idle.remove(key);
        }
        Some(io)
    }
    /// Parks an idle I/O handle under `key`.
    fn store_idle_connection_locked(
        idle: &mut HashMap<PoolKey, Vec<(u64, ClientIo)>>,
        key: PoolKey,
        id: u64,
        io: ClientIo,
    ) {
        idle.entry(key).or_default().push((id, io));
    }
    /// Deletes the idle I/O handle for `(key, id)` if present, pruning
    /// the map entry once no handles remain for `key`.
    fn remove_idle_connection_locked(
        idle: &mut HashMap<PoolKey, Vec<(u64, ClientIo)>>,
        key: &PoolKey,
        id: u64,
    ) {
        let Some(entries) = idle.get_mut(key) else {
            return;
        };
        if let Some(position) = entries.iter().position(|(entry_id, _)| *entry_id == id) {
            entries.swap_remove(position);
        }
        if entries.is_empty() {
            idle.remove(key);
        }
    }
}
// `HttpClient::default()` is equivalent to `HttpClient::new()`.
impl Default for HttpClient {
    fn default() -> Self {
        Self::new()
    }
}
/// A connection handed out by the pool for a single request.
struct AcquiredConnection {
    // Pool slot backing this connection; `None` when unpooled.
    pool_id: Option<u64>,
    // The underlying transport (plain TCP or TLS).
    io: ClientIo,
    // True when the connection was just established rather than reused.
    fresh: bool,
}
/// Transport scheme of a configured proxy URL.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ProxyScheme {
    Http,
    Https,
    Socks5,
}
/// Credentials attached to a proxy endpoint, stored in the wire form
/// each proxy protocol expects.
#[derive(Debug, Clone)]
enum ProxyCredentials {
    // Pre-encoded `Proxy-Authorization` value ("Basic <base64>").
    HttpBasic(String),
    // Raw username/password for the SOCKS5 (RFC 1929) handshake.
    Socks5 { username: String, password: String },
}
/// A parsed proxy target: scheme, host, port, and optional credentials.
#[derive(Debug, Clone)]
struct ProxyEndpoint {
    scheme: ProxyScheme,
    host: String,
    port: u16,
    credentials: Option<ProxyCredentials>,
}
impl ProxyEndpoint {
    /// Ready-to-send `Proxy-Authorization` header value, when the
    /// endpoint carries HTTP Basic credentials.
    fn http_proxy_authorization(&self) -> Option<&str> {
        if let Some(ProxyCredentials::HttpBasic(value)) = &self.credentials {
            Some(value.as_str())
        } else {
            None
        }
    }
    /// `(username, password)` when SOCKS5 credentials are configured.
    fn socks5_credentials(&self) -> Option<(&str, &str)> {
        if let Some(ProxyCredentials::Socks5 { username, password }) = &self.credentials {
            Some((username.as_str(), password.as_str()))
        } else {
            None
        }
    }
}
/// An established connection to a proxy plus how requests must be
/// framed through it.
struct ProxyConnection {
    io: ClientIo,
    // HTTP proxies take absolute-form request targets; tunnels do not.
    use_absolute_form: bool,
    // `Proxy-Authorization` value to attach to proxied requests, if any.
    proxy_authorization: Option<String>,
}
/// A cookie retained by the in-memory cookie store. Only name/value are
/// kept; `Set-Cookie` attributes are not tracked here.
#[derive(Debug, Clone)]
struct StoredCookie {
    name: String,
    value: String,
}
/// Drop guard that tears a pooled connection down unless explicitly
/// defused, so an early return or panic cannot leak the pool slot.
struct ConnectionGuard<'a> {
    client: &'a HttpClient,
    key: PoolKey,
    pool_id: Option<u64>,
    // Set once the connection has been handed back through the normal
    // release path, cancelling the Drop cleanup.
    defused: bool,
}
impl<'a> ConnectionGuard<'a> {
    /// Creates an armed guard for `(key, pool_id)`.
    fn new(client: &'a HttpClient, key: PoolKey, pool_id: Option<u64>) -> Self {
        Self {
            client,
            key,
            pool_id,
            defused: false,
        }
    }
}
impl Drop for ConnectionGuard<'_> {
    fn drop(&mut self) {
        // Still armed: remove the connection rather than leak its slot.
        if !self.defused {
            self.client.drop_connection(&self.key, self.pool_id);
        }
    }
}
/// Whether `status` is a redirect this client follows.
///
/// 304 Not Modified is deliberately excluded: it is a cache validator,
/// not a redirect.
fn is_redirect(status: u16) -> bool {
    match status {
        301 | 302 | 303 | 307 | 308 => true,
        _ => false,
    }
}
/// Returns a clone of the first header value whose name matches `name`
/// (ASCII case-insensitive), or `None` when absent.
fn get_header(headers: &[(String, String)], name: &str) -> Option<String> {
    for (header_name, header_value) in headers {
        if header_name.eq_ignore_ascii_case(name) {
            return Some(header_value.clone());
        }
    }
    None
}
/// Appends the multipart `Content-Type` header unless the caller already
/// provided one (checked ASCII case-insensitively).
fn ensure_multipart_content_type(headers: &mut Vec<(String, String)>, form: &MultipartForm) {
    let already_present = headers
        .iter()
        .any(|(name, _)| name.eq_ignore_ascii_case("content-type"));
    if !already_present {
        headers.push(("Content-Type".to_owned(), form.content_type_header()));
    }
}
/// Picks the request method to use when following a redirect.
///
/// 303 always downgrades to GET; 301/302 downgrade POST to GET to match
/// long-standing browser behavior; all other cases (notably 307/308)
/// keep the original method.
fn redirect_method(status: u16, original: &Method) -> Method {
    if status == 303 {
        return Method::Get;
    }
    if (status == 301 || status == 302) && *original == Method::Post {
        return Method::Get;
    }
    original.clone()
}
/// Normalizes a URL host for use as a cookie-store key: removes the
/// brackets around an IPv6 literal and lowercases the result.
///
/// Strips at most one bracket pair. The previous implementation used
/// `trim_start_matches`/`trim_end_matches`, which eat *every* repeated
/// bracket (and lone unmatched ones), silently collapsing distinct
/// malformed hosts onto the same cookie key.
fn canonical_cookie_host(host: &str) -> String {
    let host = host.strip_prefix('[').unwrap_or(host);
    let host = host.strip_suffix(']').unwrap_or(host);
    host.to_ascii_lowercase()
}
/// If the URL host is a numeric IP literal (IPv6 brackets removed),
/// returns it as a `SocketAddr` using the URL's port; otherwise `None`.
fn parsed_numeric_socket_addr(parsed: &ParsedUrl) -> Option<SocketAddr> {
    let bare_host = parsed.host.trim_start_matches('[').trim_end_matches(']');
    let ip: IpAddr = bare_host.parse().ok()?;
    Some(SocketAddr::new(ip, parsed.port))
}
/// Extracts the `name=value` pair from a `Set-Cookie` header value,
/// ignoring attributes after the first `;`.
///
/// Per RFC 6265 a cookie value may be wrapped in exactly one pair of
/// double quotes; only a *matched* pair is stripped. The previous
/// `trim_matches('"')` stripped any number of quotes and even a lone
/// unmatched quote, corrupting values such as `"a` or `""x""`.
///
/// Returns `None` when there is no `=` or the name is empty.
fn parse_set_cookie_pair(raw: &str) -> Option<(String, String)> {
    let pair = raw.split(';').next()?.trim();
    let (name, value) = pair.split_once('=')?;
    let name = name.trim();
    if name.is_empty() {
        return None;
    }
    let value = value.trim();
    // Strip one matched pair of surrounding DQUOTEs, if present.
    let value = value
        .strip_prefix('"')
        .and_then(|v| v.strip_suffix('"'))
        .unwrap_or(value);
    Some((name.to_owned(), value.to_owned()))
}
/// Parses a proxy URL (`http://`, `https://`, or `socks5://`) into a
/// [`ProxyEndpoint`], including optional `user:pass@` credentials.
///
/// HTTP(S) credentials are pre-encoded into a `Basic` authorization
/// value; SOCKS5 credentials are kept raw for the RFC 1929 handshake.
///
/// # Errors
/// Returns [`ClientError::InvalidUrl`] for an unknown scheme, a missing
/// authority, empty credentials, an empty SOCKS5 username, or SOCKS5
/// credentials longer than 255 bytes.
fn parse_proxy_endpoint(proxy_url: &str) -> Result<ProxyEndpoint, ClientError> {
    let (scheme, rest) = if let Some(rest) = proxy_url.strip_prefix("http://") {
        (ProxyScheme::Http, rest)
    } else if let Some(rest) = proxy_url.strip_prefix("https://") {
        (ProxyScheme::Https, rest)
    } else if let Some(rest) = proxy_url.strip_prefix("socks5://") {
        (ProxyScheme::Socks5, rest)
    } else {
        return Err(ClientError::InvalidUrl(format!(
            "unsupported proxy scheme in: {proxy_url}"
        )));
    };
    // The authority ends at the first path/query/fragment delimiter.
    let authority_end = rest.find(['/', '?', '#']).unwrap_or(rest.len());
    let authority = rest[..authority_end].trim();
    if authority.is_empty() {
        return Err(ClientError::InvalidUrl(format!(
            "proxy URL missing authority: {proxy_url}"
        )));
    }
    // Split on the LAST '@' so passwords containing '@' still parse.
    let (userinfo, host_port) = authority
        .rsplit_once('@')
        .map_or((None, authority), |(userinfo, host_port)| {
            (Some(userinfo), host_port)
        });
    let default_port = match scheme {
        ProxyScheme::Http => 80,
        ProxyScheme::Https => 443,
        ProxyScheme::Socks5 => 1080,
    };
    let (host, port) = parse_host_port(host_port, default_port)?;
    let credentials = match userinfo {
        None => None,
        Some(userinfo) => {
            if userinfo.is_empty() {
                return Err(ClientError::InvalidUrl(format!(
                    "proxy URL has empty credentials: {proxy_url}"
                )));
            }
            match scheme {
                ProxyScheme::Http | ProxyScheme::Https => {
                    let encoded = base64::engine::general_purpose::STANDARD.encode(userinfo);
                    Some(ProxyCredentials::HttpBasic(format!("Basic {encoded}")))
                }
                ProxyScheme::Socks5 => {
                    // "user" alone means an empty password.
                    let (username, password) = userinfo
                        .split_once(':')
                        .map_or((userinfo, ""), |(username, password)| (username, password));
                    if username.is_empty() {
                        return Err(ClientError::InvalidUrl(
                            "SOCKS5 username cannot be empty".into(),
                        ));
                    }
                    // RFC 1929 length fields are single bytes.
                    if username.len() > 255 || password.len() > 255 {
                        return Err(ClientError::InvalidUrl(
                            "SOCKS5 credentials must be <=255 bytes each".into(),
                        ));
                    }
                    Some(ProxyCredentials::Socks5 {
                        username: username.to_owned(),
                        password: password.to_owned(),
                    })
                }
            }
        }
    };
    Ok(ProxyEndpoint {
        scheme,
        host,
        port,
        credentials,
    })
}
/// Splits `authority` into `(host, port)`, using `default_port` when no
/// port is present.
///
/// Handles three shapes: bracketed IPv6 (`[v6]` / `[v6]:port`), bare
/// IPv6 (more than one `:` without brackets — never carries a port),
/// and plain `host` / `host:port`.
fn parse_host_port(authority: &str, default_port: u16) -> Result<(String, u16), ClientError> {
    if authority.is_empty() {
        return Err(ClientError::InvalidUrl("empty authority".into()));
    }
    if authority.starts_with('[') {
        // Bracketed IPv6 literal, optionally followed by `:port`.
        let bracket_end = authority
            .find(']')
            .ok_or_else(|| ClientError::InvalidUrl("unclosed bracket in IPv6 address".into()))?;
        let host = authority[..=bracket_end].to_owned();
        let rest = &authority[bracket_end + 1..];
        return match rest.strip_prefix(':') {
            None if rest.is_empty() => Ok((host, default_port)),
            Some(port_str) => {
                let port = port_str
                    .parse::<u16>()
                    .map_err(|_| ClientError::InvalidUrl(format!("invalid port: {port_str}")))?;
                Ok((host, port))
            }
            None => Err(ClientError::InvalidUrl(format!(
                "unexpected characters after IPv6 host: {rest}"
            ))),
        };
    }
    // More than one colon without brackets: a bare IPv6 literal, no port.
    if authority.matches(':').count() > 1 {
        return Ok((authority.to_owned(), default_port));
    }
    match authority.rsplit_once(':') {
        Some((host, port_str)) => {
            let port = port_str
                .parse::<u16>()
                .map_err(|_| ClientError::InvalidUrl(format!("invalid port: {port_str}")))?;
            if host.is_empty() {
                return Err(ClientError::InvalidUrl("empty host".into()));
            }
            Ok((host.to_owned(), port))
        }
        None => Ok((authority.to_owned(), default_port)),
    }
}
/// Builds an absolute-form request target (`scheme://authority/path`),
/// as required for requests sent through an HTTP proxy.
fn absolute_request_target(parsed: &ParsedUrl) -> String {
    let scheme = match parsed.scheme {
        Scheme::Https => "https",
        Scheme::Http => "http",
    };
    let mut target = String::new();
    target.push_str(scheme);
    target.push_str("://");
    target.push_str(&parsed.authority());
    target.push_str(&parsed.path);
    target
}
/// Opens a TCP connection to the SOCKS5 proxy and asks it to CONNECT to
/// `target`, returning the tunneled stream.
///
/// `cx` is checked between every protocol step so cancellation is
/// observed promptly rather than only at the next I/O await.
async fn connect_via_socks5(
    proxy: &ProxyEndpoint,
    target: &ParsedUrl,
    cx: &Cx,
) -> Result<TcpStream, ClientError> {
    check_cx(cx)?;
    let addr = format!("{}:{}", proxy.host, proxy.port);
    let mut stream = TcpStream::connect(addr)
        .await
        .map_err(ClientError::ConnectError)?;
    check_cx(cx)?;
    socks5_negotiate_auth(&mut stream, proxy.socks5_credentials()).await?;
    check_cx(cx)?;
    let connect_req = build_socks5_connect_request(target)?;
    stream.write_all(&connect_req).await?;
    stream.flush().await?;
    check_cx(cx)?;
    read_socks5_connect_reply(&mut stream).await?;
    check_cx(cx)?;
    Ok(stream)
}
/// Performs the SOCKS5 method-negotiation handshake (RFC 1928 section 3).
///
/// Offers "no authentication", plus username/password when credentials
/// are configured, then dispatches on the method the proxy selects.
async fn socks5_negotiate_auth(
    stream: &mut TcpStream,
    socks_creds: Option<(&str, &str)>,
) -> Result<(), ClientError> {
    let mut methods = vec![SOCKS5_AUTH_NONE];
    if socks_creds.is_some() {
        methods.push(SOCKS5_AUTH_USER_PASS);
    }
    // Greeting layout: VER, NMETHODS, METHODS...
    let mut greeting = Vec::with_capacity(2 + methods.len());
    greeting.push(SOCKS5_VERSION);
    greeting.push(
        u8::try_from(methods.len()).map_err(|_| {
            ClientError::ProxyError("too many SOCKS5 auth methods configured".into())
        })?,
    );
    greeting.extend_from_slice(&methods);
    stream.write_all(&greeting).await?;
    stream.flush().await?;
    // Reply layout: VER, METHOD.
    let mut method_reply = [0u8; 2];
    stream.read_exact(&mut method_reply).await?;
    if method_reply[0] != SOCKS5_VERSION {
        return Err(ClientError::ProxyError(format!(
            "unexpected SOCKS5 version {}",
            method_reply[0]
        )));
    }
    match method_reply[1] {
        SOCKS5_AUTH_NONE => Ok(()),
        SOCKS5_AUTH_USER_PASS => socks5_authenticate_user_pass(stream, socks_creds).await,
        SOCKS5_AUTH_NO_ACCEPTABLE => Err(ClientError::ProxyError(
            "SOCKS5 proxy rejected all authentication methods".into(),
        )),
        method => Err(ClientError::ProxyError(format!(
            "SOCKS5 proxy selected unsupported auth method: {method:#x}"
        ))),
    }
}
/// Runs the RFC 1929 username/password sub-negotiation after the proxy
/// selects that method.
async fn socks5_authenticate_user_pass(
    stream: &mut TcpStream,
    socks_creds: Option<(&str, &str)>,
) -> Result<(), ClientError> {
    let Some((username, password)) = socks_creds else {
        return Err(ClientError::ProxyError(
            "SOCKS5 proxy requested username/password auth but credentials were not set".into(),
        ));
    };
    // Both fields carry a single length-prefix byte on the wire.
    let user_len = u8::try_from(username.len())
        .map_err(|_| ClientError::ProxyError("SOCKS5 username exceeds 255 bytes".into()))?;
    let pass_len = u8::try_from(password.len())
        .map_err(|_| ClientError::ProxyError("SOCKS5 password exceeds 255 bytes".into()))?;
    let mut auth = Vec::with_capacity(3 + username.len() + password.len());
    auth.push(0x01); // sub-negotiation version
    auth.push(user_len);
    auth.extend_from_slice(username.as_bytes());
    auth.push(pass_len);
    auth.extend_from_slice(password.as_bytes());
    stream.write_all(&auth).await?;
    stream.flush().await?;
    // Reply layout: VER, STATUS (0x00 = success).
    let mut auth_reply = [0u8; 2];
    stream.read_exact(&mut auth_reply).await?;
    if auth_reply[0] != 0x01 || auth_reply[1] != 0x00 {
        return Err(ClientError::ProxyError(
            "SOCKS5 username/password authentication failed".into(),
        ));
    }
    Ok(())
}
/// Encodes the SOCKS5 CONNECT request (RFC 1928 section 4) for `target`.
///
/// The address type is chosen from the host: ATYP 0x01 for IPv4, 0x04
/// for IPv6 (brackets stripped first), 0x03 (domain name) otherwise.
fn build_socks5_connect_request(target: &ParsedUrl) -> Result<Vec<u8>, ClientError> {
    let host = target.host.trim_start_matches('[').trim_end_matches(']');
    // Fixed head: VER = 5, CMD = CONNECT (0x01), RSV = 0.
    let mut connect_req = Vec::with_capacity(300);
    connect_req.extend_from_slice(&[SOCKS5_VERSION, 0x01, 0x00]);
    match host.parse::<IpAddr>() {
        Ok(IpAddr::V4(addr)) => {
            connect_req.push(0x01);
            connect_req.extend_from_slice(&addr.octets());
        }
        Ok(IpAddr::V6(addr)) => {
            connect_req.push(0x04);
            connect_req.extend_from_slice(&addr.octets());
        }
        Err(_) => {
            // Domain name: single length byte then the raw bytes.
            let host_bytes = host.as_bytes();
            let host_len = u8::try_from(host_bytes.len()).map_err(|_| {
                ClientError::ProxyError("SOCKS5 domain name exceeds 255 bytes".into())
            })?;
            connect_req.push(0x03);
            connect_req.push(host_len);
            connect_req.extend_from_slice(host_bytes);
        }
    }
    connect_req.extend_from_slice(&target.port.to_be_bytes());
    Ok(connect_req)
}
/// Reads and validates the proxy's CONNECT reply (RFC 1928 section 6),
/// consuming the bound address and port so the stream is positioned at
/// the start of tunneled data.
async fn read_socks5_connect_reply(stream: &mut TcpStream) -> Result<(), ClientError> {
    // Fixed head: VER, REP, RSV, ATYP.
    let mut reply_head = [0u8; 4];
    stream.read_exact(&mut reply_head).await?;
    if reply_head[0] != SOCKS5_VERSION {
        return Err(ClientError::ProxyError(format!(
            "unexpected SOCKS5 connect reply version {}",
            reply_head[0]
        )));
    }
    if reply_head[1] != 0x00 {
        return Err(ClientError::ProxyError(format!(
            "SOCKS5 CONNECT failed: {}",
            socks5_reply_message(reply_head[1])
        )));
    }
    // Drain BND.ADDR; its length depends on ATYP.
    match reply_head[3] {
        0x01 => {
            let mut addr = [0u8; 4];
            stream.read_exact(&mut addr).await?;
        }
        0x04 => {
            let mut addr = [0u8; 16];
            stream.read_exact(&mut addr).await?;
        }
        0x03 => {
            // Domain form: one length byte, then that many bytes.
            let mut len = [0u8; 1];
            stream.read_exact(&mut len).await?;
            let mut addr = vec![0u8; usize::from(len[0])];
            stream.read_exact(&mut addr).await?;
        }
        atyp => {
            return Err(ClientError::ProxyError(format!(
                "SOCKS5 CONNECT reply has unknown ATYP {atyp:#x}"
            )));
        }
    }
    // Drain BND.PORT.
    let mut port = [0u8; 2];
    stream.read_exact(&mut port).await?;
    Ok(())
}
/// Human-readable description of a SOCKS5 REP code (RFC 1928 section 6).
/// Codes 0x01..=0x08 are defined; anything else maps to a generic text.
fn socks5_reply_message(code: u8) -> &'static str {
    const MESSAGES: [&str; 8] = [
        "general SOCKS server failure",
        "connection not allowed by ruleset",
        "network unreachable",
        "host unreachable",
        "connection refused by destination host",
        "TTL expired",
        "command not supported",
        "address type not supported",
    ];
    match usize::from(code).checked_sub(1) {
        Some(index) if index < MESSAGES.len() => MESSAGES[index],
        _ => "unknown SOCKS5 error",
    }
}
/// Decides whether the connection that produced `response` may go back
/// into the pool.
///
/// Reuse is refused for protocol upgrades, for responses without any
/// body-delimiting header (the peer then signals end-of-body by closing
/// the connection), and per the HTTP-version keep-alive rules.
fn connection_can_be_reused(response: &Response, req_method: &Method) -> bool {
    let headers = &response.headers;
    // Any upgrade (101, an Upgrade header, or `Connection: upgrade`)
    // takes the socket out of normal HTTP framing.
    let upgraded = response.status == 101
        || headers
            .iter()
            .any(|(name, _)| name.eq_ignore_ascii_case("upgrade"))
        || header_has_token(headers, "connection", "upgrade");
    if upgraded {
        return false;
    }
    // Without Content-Length or Transfer-Encoding the body runs to EOF,
    // unless the response cannot carry a body at all.
    let framed = headers.iter().any(|(n, _)| {
        n.eq_ignore_ascii_case("content-length") || n.eq_ignore_ascii_case("transfer-encoding")
    });
    let bodyless =
        matches!(response.status, 100..=199 | 204 | 304) || matches!(req_method, Method::Head);
    if !framed && !bodyless {
        return false;
    }
    // HTTP/1.1 defaults to keep-alive; HTTP/1.0 requires opting in.
    match response.version {
        Version::Http11 => !header_has_token(headers, "connection", "close"),
        Version::Http10 => header_has_token(headers, "connection", "keep-alive"),
    }
}
/// True when the outgoing request's own headers rule out keeping the
/// connection alive (explicit close or an upgrade attempt).
fn request_forbids_connection_reuse(headers: &[(String, String)]) -> bool {
    if header_has_token(headers, "connection", "close") {
        return true;
    }
    if header_has_token(headers, "connection", "upgrade") {
        return true;
    }
    headers
        .iter()
        .any(|(name, _)| name.eq_ignore_ascii_case("upgrade"))
}
/// A request sent on a reused connection may be retried once when the
/// method is safe to replay and the failure looks like the server closed
/// a stale keep-alive socket.
fn should_retry_reused_connection_failure(method: &Method, err: &ClientError) -> bool {
    if !method_is_safe_to_retry_after_stale_reuse(method) {
        return false;
    }
    client_error_looks_like_stale_reuse(err)
}
/// Methods that are safe to replay after a stale-connection failure
/// (the RFC 9110 "safe" methods).
fn method_is_safe_to_retry_after_stale_reuse(method: &Method) -> bool {
    match method {
        Method::Get | Method::Head | Method::Options | Method::Trace => true,
        _ => false,
    }
}
/// Whether `err` is the kind of I/O failure a dead keep-alive connection
/// produces (as opposed to an application-level HTTP error).
fn client_error_looks_like_stale_reuse(err: &ClientError) -> bool {
    match err {
        ClientError::Io(io_err) => io_error_looks_like_stale_reuse(io_err),
        // Transport errors can also surface wrapped inside codec errors.
        ClientError::HttpError(crate::http::h1::codec::HttpError::Io(io_err)) => {
            io_error_looks_like_stale_reuse(io_err)
        }
        _ => false,
    }
}
/// I/O error kinds that typically mean the peer silently closed an idle
/// keep-alive connection before we reused it.
fn io_error_looks_like_stale_reuse(err: &io::Error) -> bool {
    use io::ErrorKind::{
        BrokenPipe, ConnectionAborted, ConnectionReset, NotConnected, UnexpectedEof,
    };
    matches!(
        err.kind(),
        BrokenPipe | ConnectionAborted | ConnectionReset | NotConnected | UnexpectedEof
    )
}
/// Checks whether any header named `name` lists `token` among its
/// comma-separated elements. Both comparisons are ASCII
/// case-insensitive and elements are trimmed before comparing.
fn header_has_token(headers: &[(String, String)], name: &str, token: &str) -> bool {
    for (header_name, header_value) in headers {
        if !header_name.eq_ignore_ascii_case(name) {
            continue;
        }
        let mut parts = header_value.split(',');
        if parts.any(|part| part.trim().eq_ignore_ascii_case(token)) {
            return true;
        }
    }
    false
}
/// Resolves a `Location` header value against the current URL.
///
/// Supports absolute URLs, protocol-relative references (`//host/...`),
/// absolute paths, and relative paths. Relative paths are resolved
/// against the current path's directory with any query or fragment
/// stripped first. The port is always written explicitly for
/// non-absolute locations.
fn resolve_redirect(current: &ParsedUrl, location: &str) -> String {
    // Already absolute: use verbatim.
    if location.starts_with("http://") || location.starts_with("https://") {
        return location.to_owned();
    }
    let scheme = match current.scheme {
        Scheme::Http => "http",
        Scheme::Https => "https",
    };
    // Protocol-relative: keep the current scheme only.
    if let Some(rest) = location.strip_prefix("//") {
        return format!("{scheme}://{rest}");
    }
    // Absolute path: keep scheme, host, and port.
    if location.starts_with('/') {
        return format!("{scheme}://{}:{}{location}", current.host, current.port);
    }
    // Relative path: resolve against the current directory, ignoring any
    // slashes inside the query or fragment.
    let path_only = current
        .path
        .split_once(&['?', '#'][..])
        .map_or(current.path.as_str(), |(p, _)| p);
    let base_path = path_only.rfind('/').map_or("/", |i| &path_only[..=i]);
    format!(
        "{scheme}://{}:{}{base_path}{location}",
        current.host, current.port
    )
}
/// Two URLs share an origin when scheme, port, and (case-insensitive)
/// host all match.
fn same_origin(a: &ParsedUrl, b: &ParsedUrl) -> bool {
    if a.scheme != b.scheme || a.port != b.port {
        return false;
    }
    a.host.eq_ignore_ascii_case(&b.host)
}
/// Drops credential-bearing headers when a redirect crosses origins.
///
/// `Authorization`, `Cookie`, and `Proxy-Authorization` must not leak
/// to a different origin; same-origin redirects keep all headers.
fn strip_sensitive_headers_on_redirect(
    from: &ParsedUrl,
    to: &ParsedUrl,
    headers: Vec<(String, String)>,
) -> Vec<(String, String)> {
    if same_origin(from, to) {
        return headers;
    }
    headers
        .into_iter()
        .filter(|(name, _)| {
            // Compare case-insensitively in place instead of allocating a
            // lowercased String per header, as the previous version did.
            !name.eq_ignore_ascii_case("authorization")
                && !name.eq_ignore_ascii_case("cookie")
                && !name.eq_ignore_ascii_case("proxy-authorization")
        })
        .collect()
}
/// Returns the index one past the `\r\n\r\n` terminator of an HTTP
/// header block, or `None` when the buffer does not yet contain one.
fn find_headers_end(buf: &[u8]) -> Option<usize> {
    let idx = memmem::find(buf, b"\r\n\r\n")?;
    Some(idx + 4)
}
/// True when `s` contains a CR or LF (header-injection guard).
/// A byte scan is safe: 0x0D/0x0A never occur inside multi-byte UTF-8.
fn contains_ctl_line_break(s: &str) -> bool {
    s.bytes().any(|b| b == b'\r' || b == b'\n')
}
/// True when `s` contains any ASCII control character or any Unicode
/// whitespace character.
fn contains_ctl_or_whitespace(s: &str) -> bool {
    s.chars().any(|c| c.is_ascii_control()) || s.contains(char::is_whitespace)
}
/// Validates caller-supplied CONNECT parameters before they are spliced
/// into a request head, rejecting anything that could inject headers.
///
/// # Errors
/// Returns [`ClientError::InvalidConnectInput`] when the authority is
/// empty or contains whitespace/CR/LF, when the User-Agent contains
/// CR/LF, or when any extra header name is empty or a name/value
/// contains CR/LF.
fn validate_connect_inputs(
    target_authority: &str,
    extra_headers: &[(String, String)],
    user_agent: Option<&str>,
) -> Result<(), ClientError> {
    if target_authority.trim().is_empty() {
        return Err(ClientError::InvalidConnectInput(
            "target authority cannot be empty".into(),
        ));
    }
    // The authority is written into the request line verbatim, so any
    // whitespace or line break would corrupt request framing.
    if target_authority.chars().any(char::is_whitespace)
        || contains_ctl_line_break(target_authority)
    {
        return Err(ClientError::InvalidConnectInput(
            "target authority must be RFC authority-form without whitespace".into(),
        ));
    }
    if let Some(ua) = user_agent
        && contains_ctl_line_break(ua)
    {
        return Err(ClientError::InvalidConnectInput(
            "User-Agent header contains invalid control characters".into(),
        ));
    }
    for (name, value) in extra_headers {
        if name.trim().is_empty() {
            return Err(ClientError::InvalidConnectInput(
                "header name cannot be empty".into(),
            ));
        }
        // CR/LF in a name or value would let callers smuggle extra headers.
        if contains_ctl_line_break(name) || contains_ctl_line_break(value) {
            return Err(ClientError::InvalidConnectInput(
                "header name/value cannot contain CR or LF".into(),
            ));
        }
    }
    Ok(())
}
/// Parses the status line of a CONNECT response into `(code, reason)`.
///
/// Accepts `HTTP-version SP status-code [SP reason-phrase]`; the reason
/// phrase may be absent (reported as an empty string).
fn parse_connect_status_line(line: &str) -> Result<(u16, String), ClientError> {
    use crate::http::h1::codec::HttpError;
    let bad_line = || ClientError::HttpError(HttpError::BadRequestLine);
    let mut parts = line.splitn(3, ' ');
    let version = parts.next().ok_or_else(bad_line)?;
    let status = parts.next().ok_or_else(bad_line)?;
    let reason = parts.next().unwrap_or("").to_owned();
    if Version::from_bytes(version.as_bytes()).is_none() {
        return Err(ClientError::HttpError(HttpError::UnsupportedVersion));
    }
    let code = status.parse::<u16>().map_err(|_| bad_line())?;
    Ok((code, reason))
}
/// Sends an HTTP CONNECT request over `io` and waits for a 2xx reply,
/// returning a tunnel that carries any bytes already read past the
/// response header block.
///
/// Interim 1xx responses are consumed and skipped. Inputs are validated
/// first so CR/LF cannot be injected into the request head.
///
/// # Errors
/// - [`ClientError::InvalidConnectInput`] for malformed caller input.
/// - [`ClientError::ConnectTunnelRefused`] for a non-2xx final status.
/// - [`ClientError::HttpError`] for oversized or malformed response heads.
/// - [`ClientError::Io`] when the proxy closes the stream prematurely.
async fn establish_http_connect_tunnel<T>(
    mut io: T,
    target_authority: &str,
    user_agent: Option<&str>,
    extra_headers: &[(String, String)],
) -> Result<HttpConnectTunnel<T>, ClientError>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    validate_connect_inputs(target_authority, extra_headers, user_agent)?;
    let mut request = String::with_capacity(256);
    write!(&mut request, "CONNECT {target_authority} HTTP/1.1\r\n")
        .expect("in-memory string write cannot fail");
    write!(&mut request, "Host: {target_authority}\r\n")
        .expect("in-memory string write cannot fail");
    if let Some(ua) = user_agent {
        write!(&mut request, "User-Agent: {ua}\r\n").expect("in-memory string write cannot fail");
    }
    for (name, value) in extra_headers {
        write!(&mut request, "{name}: {value}\r\n").expect("in-memory string write cannot fail");
    }
    request.push_str("\r\n");
    io.write_all(request.as_bytes()).await?;
    io.flush().await?;
    let mut read_buf = Vec::with_capacity(8192);
    let mut scratch = [0u8; 8192];
    loop {
        if let Some(end) = find_headers_end(&read_buf) {
            if end > CONNECT_MAX_HEADERS_SIZE {
                return Err(ClientError::HttpError(
                    crate::http::h1::codec::HttpError::HeadersTooLarge,
                ));
            }
            let head = std::str::from_utf8(&read_buf[..end]).map_err(|_| {
                ClientError::HttpError(crate::http::h1::codec::HttpError::BadRequestLine)
            })?;
            let mut lines = head.split("\r\n");
            let status_line = lines.next().ok_or(ClientError::HttpError(
                crate::http::h1::codec::HttpError::BadRequestLine,
            ))?;
            let (status, reason) = parse_connect_status_line(status_line)?;
            // Informational response: discard it and keep reading for the
            // final status.
            if (100..=199).contains(&status) {
                read_buf.drain(..end);
                continue;
            }
            if !(200..=299).contains(&status) {
                return Err(ClientError::ConnectTunnelRefused { status, reason });
            }
            // Bytes past the header block already belong to the tunneled
            // stream; hand them to the tunnel so none are lost.
            let prefetched = read_buf[end..].to_vec();
            return Ok(HttpConnectTunnel::new(io, prefetched));
        }
        // Cap the buffered head size even before the terminator appears.
        if read_buf.len() > CONNECT_MAX_HEADERS_SIZE {
            return Err(ClientError::HttpError(
                crate::http::h1::codec::HttpError::HeadersTooLarge,
            ));
        }
        let n = poll_fn(|cx| {
            let mut rb = ReadBuf::new(&mut scratch);
            match Pin::new(&mut io).poll_read(cx, &mut rb) {
                Poll::Pending => Poll::Pending,
                Poll::Ready(Ok(())) => Poll::Ready(Ok(rb.filled().len())),
                Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
            }
        })
        .await?;
        if n == 0 {
            return Err(ClientError::Io(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                "proxy closed before CONNECT response headers were complete",
            )));
        }
        read_buf.extend_from_slice(&scratch[..n]);
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::io::AsyncWriteExt;
use futures_lite::future::block_on;
use std::cell::Cell;
use std::future::poll_fn;
use std::net::TcpListener;
// A thread-local fake clock so pool-timing tests are deterministic.
thread_local! {
    static HTTP_CLIENT_TEST_TIME_NANOS: Cell<u64> = const { Cell::new(0) };
}
// Sets the fake clock for the current test thread.
fn set_http_client_test_time(nanos: u64) {
    HTTP_CLIENT_TEST_TIME_NANOS.with(|t| t.set(nanos));
}
// Time getter handed to the client under test; reads the fake clock.
fn http_client_test_time() -> Time {
    Time::from_nanos(HTTP_CLIENT_TEST_TIME_NANOS.with(std::cell::Cell::get))
}
// Builds a real (loopback) ClientIo so pool bookkeeping tests can park a
// live connection without touching the network.
fn loopback_client_io() -> ClientIo {
    let listener = TcpListener::bind("127.0.0.1:0").expect("bind listener");
    let addr = listener.local_addr().expect("listener address");
    let client = std::net::TcpStream::connect(addr).expect("connect client");
    let (_server, _) = listener.accept().expect("accept client");
    let stream = TcpStream::from_std(client).expect("wrap stream");
    ClientIo::Plain(stream)
}
// --- URL parsing ---
#[test]
fn parse_http_url() {
    let url = ParsedUrl::parse("http://example.com/path").unwrap();
    assert_eq!(url.scheme, Scheme::Http);
    assert_eq!(url.host, "example.com");
    assert_eq!(url.port, 80);
    assert_eq!(url.path, "/path");
}
#[test]
fn parse_https_url() {
    let url = ParsedUrl::parse("https://example.com/api/v1").unwrap();
    assert_eq!(url.scheme, Scheme::Https);
    assert_eq!(url.host, "example.com");
    assert_eq!(url.port, 443);
    assert_eq!(url.path, "/api/v1");
}
#[test]
fn parse_url_with_port() {
    let url = ParsedUrl::parse("http://localhost:8080/test").unwrap();
    assert_eq!(url.host, "localhost");
    assert_eq!(url.port, 8080);
    assert_eq!(url.path, "/test");
}
#[test]
fn parse_url_no_path() {
    // A bare authority normalizes to path "/".
    let url = ParsedUrl::parse("http://example.com").unwrap();
    assert_eq!(url.path, "/");
}
#[test]
fn parse_url_with_query() {
    // The query string stays attached to `path`.
    let url = ParsedUrl::parse("http://example.com/search?q=test&page=1").unwrap();
    assert_eq!(url.path, "/search?q=test&page=1");
}
#[test]
fn parsed_numeric_socket_addr_handles_ipv4_host() {
    let url = ParsedUrl::parse("http://127.0.0.1:8765/healthz").expect("parse url");
    let socket_addr = parsed_numeric_socket_addr(&url).expect("numeric addr");
    assert_eq!(socket_addr, "127.0.0.1:8765".parse().unwrap());
}
#[test]
fn parsed_numeric_socket_addr_handles_ipv6_host() {
    let url = ParsedUrl::parse("http://[::1]:8765/healthz").expect("parse url");
    let socket_addr = parsed_numeric_socket_addr(&url).expect("numeric addr");
    assert_eq!(socket_addr, "[::1]:8765".parse().unwrap());
}
#[test]
fn parse_url_handles_unbracketed_ipv6_host() {
    let url = ParsedUrl::parse("http://::1/healthz").expect("parse url");
    assert_eq!(url.host, "::1");
    assert_eq!(url.port, 80);
}
// Rejection cases.
#[test]
fn parse_url_invalid_scheme() {
    let result = ParsedUrl::parse("ftp://example.com");
    assert!(result.is_err());
}
#[test]
fn parse_url_empty_host() {
    let result = ParsedUrl::parse("http:///path");
    assert!(result.is_err());
}
#[test]
fn parse_url_invalid_port() {
    let result = ParsedUrl::parse("http://example.com:abc/path");
    assert!(result.is_err());
}
#[test]
fn parse_url_rejects_control_or_whitespace_in_authority() {
    // Header-injection guard: CR/LF or spaces in the authority must fail.
    let result = ParsedUrl::parse("http://example.com\r\nx-injected: y/path");
    assert!(matches!(result, Err(ClientError::InvalidUrl(_))));
    let result = ParsedUrl::parse("http://example.com bad/path");
    assert!(matches!(result, Err(ClientError::InvalidUrl(_))));
}
// --- Proxy endpoint parsing and absolute-form targets ---
#[test]
fn parse_http_proxy_endpoint_with_basic_auth() {
    let proxy = parse_proxy_endpoint("http://alice:secret@proxy.local:8080")
        .expect("proxy should parse");
    assert_eq!(proxy.scheme, ProxyScheme::Http);
    assert_eq!(proxy.host, "proxy.local");
    assert_eq!(proxy.port, 8080);
    // "alice:secret" base64-encoded.
    assert_eq!(
        proxy.http_proxy_authorization(),
        Some("Basic YWxpY2U6c2VjcmV0")
    );
}
#[test]
fn parse_socks5_proxy_endpoint_with_credentials() {
    let proxy =
        parse_proxy_endpoint("socks5://agent:pw@127.0.0.1").expect("proxy should parse");
    assert_eq!(proxy.scheme, ProxyScheme::Socks5);
    assert_eq!(proxy.host, "127.0.0.1");
    // Default SOCKS5 port applies when none is given.
    assert_eq!(proxy.port, 1080);
    assert_eq!(proxy.socks5_credentials(), Some(("agent", "pw")));
}
#[test]
fn absolute_request_target_uses_full_uri() {
    let parsed = ParsedUrl::parse("http://example.com:8080/path?q=1").unwrap();
    assert_eq!(
        absolute_request_target(&parsed),
        "http://example.com:8080/path?q=1"
    );
}
// --- Pool keys and authority formatting ---
#[test]
fn pool_key_from_http_url() {
    let url = ParsedUrl::parse("http://example.com/path").unwrap();
    let key = url.pool_key();
    assert_eq!(key.host, "example.com");
    assert_eq!(key.port, 80);
    assert!(!key.is_https);
}
#[test]
fn pool_key_from_https_url() {
    let url = ParsedUrl::parse("https://example.com/path").unwrap();
    let key = url.pool_key();
    assert_eq!(key.host, "example.com");
    assert_eq!(key.port, 443);
    assert!(key.is_https);
}
#[test]
fn authority_default_port_omitted() {
    // Default ports (80/443) are elided from the authority string.
    let url = ParsedUrl::parse("http://example.com/path").unwrap();
    assert_eq!(url.authority(), "example.com");
    let url = ParsedUrl::parse("https://example.com/path").unwrap();
    assert_eq!(url.authority(), "example.com");
}
#[test]
fn authority_custom_port_included() {
    let url = ParsedUrl::parse("http://example.com:8080/path").unwrap();
    assert_eq!(url.authority(), "example.com:8080");
}
#[test]
fn is_redirect_detects_all_codes() {
    // Every followable redirect code...
    for code in [301u16, 302, 303, 307, 308] {
        assert!(is_redirect(code));
    }
    // ...and statuses that must not count as redirects (note 304).
    for code in [200u16, 404, 500, 304] {
        assert!(!is_redirect(code));
    }
}
// --- Redirect method selection ---
#[test]
fn redirect_method_303_always_get() {
    assert_eq!(redirect_method(303, &Method::Post), Method::Get);
    assert_eq!(redirect_method(303, &Method::Put), Method::Get);
    assert_eq!(redirect_method(303, &Method::Get), Method::Get);
}
#[test]
fn redirect_method_307_preserves() {
    assert_eq!(redirect_method(307, &Method::Post), Method::Post);
    assert_eq!(redirect_method(307, &Method::Get), Method::Get);
    assert_eq!(redirect_method(307, &Method::Put), Method::Put);
}
#[test]
fn redirect_method_308_preserves() {
    assert_eq!(redirect_method(308, &Method::Post), Method::Post);
    assert_eq!(redirect_method(308, &Method::Delete), Method::Delete);
}
// 301/302 downgrade only POST, matching browser behavior.
#[test]
fn redirect_method_301_post_becomes_get() {
    assert_eq!(redirect_method(301, &Method::Post), Method::Get);
    assert_eq!(redirect_method(301, &Method::Get), Method::Get);
}
#[test]
fn redirect_method_302_post_becomes_get() {
    assert_eq!(redirect_method(302, &Method::Post), Method::Get);
    assert_eq!(redirect_method(302, &Method::Get), Method::Get);
}
// --- Location resolution ---
// NOTE: the `&current` arguments below had been corrupted into the
// mojibake `¤t` (an HTML-entity mangling of `&current`), which does
// not compile; restored here.
#[test]
fn resolve_absolute_redirect() {
    let current = ParsedUrl::parse("http://example.com/old").unwrap();
    let result = resolve_redirect(&current, "https://other.com/new");
    assert_eq!(result, "https://other.com/new");
}
#[test]
fn resolve_protocol_relative_redirect() {
    let current = ParsedUrl::parse("https://example.com/old").unwrap();
    let result = resolve_redirect(&current, "//cdn.example.com/asset");
    assert_eq!(result, "https://cdn.example.com/asset");
}
#[test]
fn resolve_absolute_path_redirect() {
    let current = ParsedUrl::parse("http://example.com:8080/old/page").unwrap();
    let result = resolve_redirect(&current, "/new/page");
    assert_eq!(result, "http://example.com:8080/new/page");
}
#[test]
fn resolve_relative_path_redirect() {
    let current = ParsedUrl::parse("http://example.com/dir/old").unwrap();
    let result = resolve_redirect(&current, "new");
    assert_eq!(result, "http://example.com:80/dir/new");
}
#[test]
fn resolve_relative_path_redirect_ignores_query_slashes() {
    // Slashes inside the query must not shift the resolution base.
    let current = ParsedUrl::parse("http://example.com/dir/old?return=/home").unwrap();
    let result = resolve_redirect(&current, "new");
    assert_eq!(result, "http://example.com:80/dir/new");
}
// --- Header lookup and ClientError behavior ---
#[test]
fn get_header_case_insensitive() {
    let headers = vec![
        ("Content-Type".into(), "text/html".into()),
        ("location".into(), "/new".into()),
    ];
    assert_eq!(get_header(&headers, "Location"), Some("/new".into()));
    assert_eq!(get_header(&headers, "LOCATION"), Some("/new".into()));
    assert_eq!(
        get_header(&headers, "content-type"),
        Some("text/html".into())
    );
    assert_eq!(get_header(&headers, "X-Missing"), None);
}
#[test]
fn client_error_display() {
    let err = ClientError::InvalidUrl("bad".into());
    assert!(format!("{err}").contains("bad"));
    let err = ClientError::TooManyRedirects { count: 5, max: 10 };
    let msg = format!("{err}");
    assert!(msg.contains('5'));
    assert!(msg.contains("10"));
    let err = ClientError::Cancelled;
    assert!(format!("{err}").contains("cancelled"));
    let err = ClientError::PoolExhausted {
        host: "example.com".into(),
        port: 80,
    };
    let msg = format!("{err}");
    assert!(msg.contains("example.com"));
    assert!(msg.contains("80"));
}
#[test]
fn client_error_source() {
    use std::error::Error;
    let err = ClientError::InvalidUrl("x".into());
    assert!(err.source().is_none());
    // Io wraps its cause; source() must expose it.
    let io_err = io::Error::other("test");
    let err = ClientError::Io(io_err);
    assert!(err.source().is_some());
    let err = ClientError::Cancelled;
    assert!(err.source().is_none());
}
#[test]
fn client_error_is_cancelled() {
    assert!(ClientError::Cancelled.is_cancelled());
    assert!(!ClientError::InvalidUrl("x".into()).is_cancelled());
}
// --- Client configuration defaults and builder overrides ---
#[test]
fn default_config() {
    let config = HttpClientConfig::default();
    assert!(matches!(
        config.redirect_policy,
        RedirectPolicy::Limited(10)
    ));
    assert_eq!(config.user_agent, Some("asupersync/0.1".into()));
    assert!(!config.cookie_store);
    assert!(config.proxy_url.is_none());
}
#[test]
fn config_with_time_getter_exposes_custom_clock() {
    set_http_client_test_time(77);
    let config = HttpClientConfig::default().with_time_getter(http_client_test_time);
    assert_eq!((config.time_getter())().as_nanos(), 77);
}
#[test]
fn builder_default_matches_client_defaults() {
    let client = HttpClient::builder().build();
    assert_eq!(client.config.pool_config.max_connections_per_host, 6);
    assert_eq!(client.config.pool_config.max_total_connections, 100);
    assert_eq!(
        client.config.pool_config.idle_timeout,
        std::time::Duration::from_secs(90)
    );
    assert_eq!(
        client.config.pool_config.cleanup_interval,
        std::time::Duration::from_secs(30)
    );
    assert!(matches!(
        client.config.redirect_policy,
        RedirectPolicy::Limited(10)
    ));
    assert_eq!(client.config.user_agent.as_deref(), Some("asupersync/0.1"));
    assert!(!client.config.cookie_store);
    assert!(client.config.proxy_url.is_none());
}
#[test]
fn builder_overrides_pool_and_redirect_and_user_agent() {
    // Later builder calls win: cookie_store(true) then no_cookie_store(),
    // proxy(...) then no_proxy().
    let client = HttpClient::builder()
        .max_connections_per_host(12)
        .max_total_connections(240)
        .idle_timeout(std::time::Duration::from_secs(15))
        .cleanup_interval(std::time::Duration::from_secs(5))
        .no_redirects()
        .no_user_agent()
        .cookie_store(true)
        .no_cookie_store()
        .proxy("http://proxy.internal:3128")
        .no_proxy()
        .build();
    assert_eq!(client.config.pool_config.max_connections_per_host, 12);
    assert_eq!(client.config.pool_config.max_total_connections, 240);
    assert_eq!(
        client.config.pool_config.idle_timeout,
        std::time::Duration::from_secs(15)
    );
    assert_eq!(
        client.config.pool_config.cleanup_interval,
        std::time::Duration::from_secs(5)
    );
    assert!(matches!(
        client.config.redirect_policy,
        RedirectPolicy::None
    ));
    assert!(client.config.user_agent.is_none());
    assert!(!client.config.cookie_store);
    assert!(client.config.proxy_url.is_none());
}
// A whole `PoolConfig` can be injected via `pool_config(...)`, and
// `max_redirects(n)` maps to `RedirectPolicy::Limited(n)`. Also covers
// user-agent, cookie-store, and proxy-URL overrides surviving to the config.
#[test]
fn builder_pool_config_and_max_redirects() {
let pool_config = PoolConfig::builder()
.max_connections_per_host(3)
.max_total_connections(32)
.idle_timeout(std::time::Duration::from_secs(7))
.cleanup_interval(std::time::Duration::from_secs(3))
.build();
let client = HttpClient::builder()
.pool_config(pool_config)
.max_redirects(2)
.user_agent("asupersync-test/2.0")
.cookie_store(true)
.proxy("socks5://proxy.internal:1080")
.build();
assert_eq!(client.config.pool_config.max_connections_per_host, 3);
assert_eq!(client.config.pool_config.max_total_connections, 32);
assert_eq!(
client.config.pool_config.idle_timeout,
std::time::Duration::from_secs(7)
);
assert_eq!(
client.config.pool_config.cleanup_interval,
std::time::Duration::from_secs(3)
);
assert!(matches!(
client.config.redirect_policy,
RedirectPolicy::Limited(2)
));
assert_eq!(
client.config.user_agent.as_deref(),
Some("asupersync-test/2.0")
);
assert!(client.config.cookie_store);
assert_eq!(
client.config.proxy_url.as_deref(),
Some("socks5://proxy.internal:1080")
);
}
// The custom clock set on the builder must drive both the pool's notion of
// "now" (`pool_now`) and the clock exposed by the stored config.
#[test]
fn builder_with_time_getter_overrides_pool_clock() {
set_http_client_test_time(777);
let client = HttpClient::builder()
.with_time_getter(http_client_test_time)
.build();
assert_eq!(client.pool_now().as_nanos(), 777);
assert_eq!((client.config.time_getter())().as_nanos(), 777);
}
// A freshly constructed client must start with an empty connection pool.
#[test]
fn client_default_creates_pool() {
    let stats = HttpClient::new().pool_stats();
    assert_eq!(stats.total_connections, 0);
}
// Releasing a *fresh* connection (`is_fresh = true`) must stamp `last_used`
// from the client's custom clock, preserve the registration-time `created_at`,
// move the connection to Idle, and leave `requests_served` at zero.
#[test]
fn release_connection_marks_fresh_connection_with_custom_time_getter() {
set_http_client_test_time(123);
let client = HttpClient::builder()
.with_time_getter(http_client_test_time)
.build();
let key = PoolKey::http("example.com", None);
let id = client
.pool
.lock()
.register_connecting(key.clone(), Time::ZERO, 1);
client.release_connection(&key, Some(id), true, loopback_client_io());
// Copy the metadata out under the lock, then assert without holding it.
let (created_at, last_used, state, requests_served) = {
let pool = client.pool.lock();
let meta = pool
.get_connection_meta(&key, id)
.expect("connection metadata");
let values = (
meta.created_at,
meta.last_used,
meta.state,
meta.requests_served,
);
drop(pool);
values
};
assert_eq!(created_at, Time::ZERO);
assert_eq!(last_used, Time::from_nanos(123));
assert_eq!(state, crate::http::pool::PooledConnectionState::Idle);
assert_eq!(requests_served, 0);
}
// Releasing a *reused* connection (`is_fresh = false`) after an acquire must
// bump `requests_served` to 1, refresh `last_used` from the custom clock, and
// return the connection to Idle without touching `created_at`.
#[test]
fn release_connection_marks_reused_connection_with_custom_time_getter() {
let client = HttpClient::builder()
.with_time_getter(http_client_test_time)
.build();
let key = PoolKey::http("example.com", None);
// Drive the connection through connect -> acquire so the release below is a
// legal "reused" release.
let id = {
let mut pool = client.pool.lock();
let id = pool.register_connecting(key.clone(), Time::from_nanos(10), 1);
assert!(pool.mark_connected(&key, id, Time::from_nanos(20)));
let acquired = pool
.try_acquire(&key, Time::from_nanos(30))
.expect("acquire pooled connection");
assert_eq!(acquired, id);
drop(pool);
id
};
// The clock is advanced only now, so `last_used` must come from the release.
set_http_client_test_time(456);
client.release_connection(&key, Some(id), false, loopback_client_io());
let (created_at, last_used, state, requests_served) = {
let pool = client.pool.lock();
let meta = pool
.get_connection_meta(&key, id)
.expect("connection metadata");
let values = (
meta.created_at,
meta.last_used,
meta.state,
meta.requests_served,
);
drop(pool);
values
};
assert_eq!(created_at, Time::from_nanos(10));
assert_eq!(last_used, Time::from_nanos(456));
assert_eq!(state, crate::http::pool::PooledConnectionState::Idle);
assert_eq!(requests_served, 1);
}
// A "fresh" release for a connection that is already Connected is invalid:
// the pool must retire the metadata and the client must not stash idle IO.
#[test]
fn release_connection_drops_fresh_io_when_mark_connected_is_rejected() {
let client = HttpClient::builder()
.with_time_getter(http_client_test_time)
.build();
let key = PoolKey::http("example.com", None);
// Pre-connect the entry so a later `is_fresh = true` release is rejected.
let id = {
let mut pool = client.pool.lock();
let id = pool.register_connecting(key.clone(), Time::from_nanos(10), 1);
assert!(pool.mark_connected(&key, id, Time::from_nanos(20)));
drop(pool);
id
};
set_http_client_test_time(30);
client.release_connection(&key, Some(id), true, loopback_client_io());
assert!(
client.pool.lock().get_connection_meta(&key, id).is_none(),
"invalid fresh release must retire pool metadata instead of storing idle IO"
);
assert!(
client.idle_connections.lock().get(&key).is_none(),
"invalid fresh release must not leave idle IO behind"
);
}
// A "reused" release for a connection that was never connected/acquired is
// invalid: metadata must be retired and no idle IO may be recorded.
#[test]
fn release_connection_drops_reused_io_when_release_is_rejected() {
let client = HttpClient::builder()
.with_time_getter(http_client_test_time)
.build();
let key = PoolKey::http("example.com", None);
// Still in the Connecting state — a reused release cannot apply to it.
let id = client
.pool
.lock()
.register_connecting(key.clone(), Time::from_nanos(10), 1);
set_http_client_test_time(30);
client.release_connection(&key, Some(id), false, loopback_client_io());
assert!(
client.pool.lock().get_connection_meta(&key, id).is_none(),
"invalid reused release must retire pool metadata instead of storing idle IO"
);
assert!(
client.idle_connections.lock().get(&key).is_none(),
"invalid reused release must not leave idle IO behind"
);
}
// With a 100ns idle timeout, an IO released at t=0 must be dropped — both pool
// metadata and the stored idle IO — when cleanup runs at t=150.
#[test]
fn cleanup_expired_idle_connections_drops_stale_client_io() {
set_http_client_test_time(0);
let client = HttpClient::builder()
.with_time_getter(http_client_test_time)
.idle_timeout(std::time::Duration::from_nanos(100))
.build();
let key = PoolKey::http("example.com", None);
let id = client
.pool
.lock()
.register_connecting(key.clone(), Time::ZERO, 1);
client.release_connection(&key, Some(id), true, loopback_client_io());
assert_eq!(
client.idle_connections.lock().get(&key).map_or(0, Vec::len),
1,
"freshly released client IO should be tracked as idle"
);
// Advance past the idle timeout, then sweep.
set_http_client_test_time(150);
client.cleanup_expired_idle_connections(http_client_test_time());
assert!(
client.pool.lock().get_connection_meta(&key, id).is_none(),
"expired pool metadata must be pruned"
);
assert!(
client.idle_connections.lock().get(&key).is_none(),
"expired idle IO must be dropped alongside metadata"
);
}
// Race-regression test: `release_connection` must keep the pool mutex held
// until the idle IO has been inserted into `idle_connections`, so no other
// thread can acquire the connection while its IO is still in flight. The test
// blocks the releasing thread by pre-holding the `idle_connections` lock and
// then checks the pool lock is also unavailable during that window.
#[test]
fn release_connection_holds_pool_lock_until_idle_io_is_recorded() {
use std::sync::Arc;
use std::sync::mpsc;
use std::time::{Duration, Instant};
let client = Arc::new(
HttpClient::builder()
.with_time_getter(http_client_test_time)
.build(),
);
let key = PoolKey::http("example.com", None);
let id = client
.pool
.lock()
.register_connecting(key.clone(), Time::ZERO, 1);
// Hold the idle-connections lock so the release thread parks mid-release.
let idle_guard = client.idle_connections.lock();
let (tx, rx) = mpsc::channel();
let release_client = Arc::clone(&client);
let release_key = key.clone();
let release_thread = std::thread::spawn(move || {
tx.send(()).expect("signal release start");
release_client.release_connection(&release_key, Some(id), true, loopback_client_io());
});
rx.recv().expect("release thread should start");
// Spin (bounded by a 1s deadline) until the release thread owns the pool
// lock; `try_lock` returning None proves the lock is held across the handoff.
let deadline = Instant::now() + Duration::from_secs(1);
loop {
if client.pool.try_lock().is_none() {
break;
}
assert!(
Instant::now() <= deadline,
"release_connection should hold the pool lock until idle IO is inserted"
);
std::thread::yield_now();
}
// Unblock the release thread and let it finish.
drop(idle_guard);
release_thread
.join()
.expect("release thread should complete");
assert_eq!(
client.idle_connections.lock().get(&key).map_or(0, Vec::len),
1,
"release must leave idle IO recorded once the lock handoff completes"
);
let state = client
.pool
.lock()
.get_connection_meta(&key, id)
.expect("connection metadata should remain")
.state;
assert_eq!(state, crate::http::pool::PooledConnectionState::Idle);
}
// Only the leading `name=value` pair is extracted from a Set-Cookie line;
// attributes after the first ';' are ignored, and malformed input yields None.
#[test]
fn parse_set_cookie_pair_extracts_first_pair() {
    assert_eq!(
        parse_set_cookie_pair("session=abc123; Path=/; HttpOnly"),
        Some(("session".to_string(), "abc123".to_string()))
    );
    // Empty input, missing '=', and empty name are all rejected.
    for bad in ["", "invalid", " =value"] {
        assert_eq!(parse_set_cookie_pair(bad), None);
    }
}
// A DQUOTE-wrapped cookie value (RFC 6265 `cookie-value` form) must be stored
// and replayed without the surrounding quotes. The inner helper re-implements
// the RFC 6265 extraction independently so the production parser is checked
// against a reference, not against itself.
#[test]
fn cookie_store_replays_quoted_set_cookie_value_without_quotes() {
// Reference implementation: first pair, trimmed, with a matched pair of
// surrounding quotes stripped (unmatched quotes are kept verbatim).
fn rfc6265_reference_pair(raw: &str) -> Option<(String, String)> {
let pair = raw.split(';').next()?.trim();
let (name, value) = pair.split_once('=')?;
let name = name.trim();
if name.is_empty() {
return None;
}
let value = value.trim();
let value = value
.strip_prefix('"')
.and_then(|inner| inner.strip_suffix('"'))
.unwrap_or(value);
Some((name.to_string(), value.to_string()))
}
// Note the value itself contains '=' characters — split_once must only
// split on the first '='.
let raw = r#"session="abc123=="; Path=/; HttpOnly"#;
assert_eq!(parse_set_cookie_pair(raw), rfc6265_reference_pair(raw));
let client = HttpClient::builder().cookie_store(true).build();
client.store_response_cookies(
"example.com",
&[("Set-Cookie".to_string(), raw.to_string())],
);
let cookie_header = client
.cookie_header_for_host("example.com")
.expect("cookie header");
assert_eq!(cookie_header, "session=abc123==");
}
// Host matching is case-insensitive: cookies stored under "Example.COM" must
// be replayed on a request to "example.com", and cookie attributes (Path)
// are not echoed into the Cookie header.
#[test]
fn cookie_store_attaches_cookie_header_when_enabled() {
    let client = HttpClient::builder().cookie_store(true).build();
    let set_cookie = (
        "Set-Cookie".to_string(),
        "session=abc123; Path=/".to_string(),
    );
    client.store_response_cookies("Example.COM", &[set_cookie]);
    let parsed = ParsedUrl::parse("http://example.com/data").expect("valid URL");
    let req = client.build_request(&Method::Get, &parsed, &[], &[], None, None);
    assert_eq!(
        get_header(&req.headers, "cookie"),
        Some("session=abc123".to_string())
    );
}
// A caller-supplied Cookie header takes precedence: the store must not append
// or merge its own cookies when the request already carries one.
#[test]
fn cookie_store_respects_explicit_cookie_header() {
let client = HttpClient::builder().cookie_store(true).build();
client.store_response_cookies(
"example.com",
&[("Set-Cookie".to_string(), "session=abc123".to_string())],
);
let parsed = ParsedUrl::parse("http://example.com/path").expect("valid URL");
let req = client.build_request(
&Method::Get,
&parsed,
&[("Cookie".to_string(), "manual=1".to_string())],
&[],
None,
None,
);
// The stored "session" cookie must not appear — only the explicit header.
assert_eq!(
get_header(&req.headers, "cookie"),
Some("manual=1".to_string())
);
}
// When a proxy-auth value is supplied and the caller did not set one,
// build_request must add Proxy-Authorization, and a forward-proxy request must
// use the absolute-form request target (RFC 9112 §3.2.2).
#[test]
fn build_request_adds_proxy_authorization_when_not_explicit() {
let client = HttpClient::builder().build();
let parsed = ParsedUrl::parse("http://example.com/path").expect("valid URL");
let req = client.build_request(
&Method::Get,
&parsed,
&[("Accept".to_string(), "application/json".to_string())],
&[],
Some(absolute_request_target(&parsed)),
Some("Basic Zm9vOmJhcg=="),
);
assert_eq!(
get_header(&req.headers, "proxy-authorization"),
Some("Basic Zm9vOmJhcg==".to_string())
);
assert_eq!(
req.uri, "http://example.com/path",
"forward proxy request must use absolute-form URI"
);
}
// A caller-supplied Proxy-Authorization header wins over the client-level
// proxy-auth value ("Basic aWdub3JlZA==" must be ignored).
#[test]
fn build_request_preserves_explicit_proxy_authorization() {
let client = HttpClient::builder().build();
let parsed = ParsedUrl::parse("http://example.com/path").expect("valid URL");
let req = client.build_request(
&Method::Get,
&parsed,
&[(
"Proxy-Authorization".to_string(),
"Basic ZXhwbGljaXQ=".to_string(),
)],
&[],
Some(absolute_request_target(&parsed)),
Some("Basic aWdub3JlZA=="),
);
assert_eq!(
get_header(&req.headers, "proxy-authorization"),
Some("Basic ZXhwbGljaXQ=".to_string())
);
}
// A caller-supplied Host header must be discarded: exactly one Host header is
// emitted and it is derived from the parsed URL (defends against Host
// spoofing via the extra-headers path).
#[test]
fn build_request_ignores_explicit_host_header() {
    let client = HttpClient::builder().build();
    let parsed = ParsedUrl::parse("http://example.com/path").expect("valid URL");
    let spoofed = ("Host".to_string(), "attacker.test".to_string());
    let req = client.build_request(&Method::Get, &parsed, &[spoofed], &[], None, None);
    let hosts: Vec<&str> = req
        .headers
        .iter()
        .filter(|(name, _)| name.eq_ignore_ascii_case("host"))
        .map(|(_, value)| value.as_str())
        .collect();
    assert_eq!(hosts, ["example.com"]);
}
// When no Content-Type is present, the multipart header carrying the form's
// boundary must be appended.
#[test]
fn ensure_multipart_content_type_adds_header_when_missing() {
    let form = MultipartForm::with_boundary("upload-boundary")
        .unwrap()
        .text("user", "alice");
    let mut headers = vec![("Accept".to_string(), "application/json".to_string())];
    ensure_multipart_content_type(&mut headers, &form);
    let expected = "multipart/form-data; boundary=upload-boundary".to_string();
    assert_eq!(get_header(&headers, "content-type"), Some(expected));
}
// An existing Content-Type is left untouched — no overwrite, no duplicate.
#[test]
fn ensure_multipart_content_type_respects_existing_header() {
    let form = MultipartForm::with_boundary("upload-boundary")
        .unwrap()
        .text("user", "alice");
    let manual = "multipart/form-data; boundary=manual".to_string();
    let mut headers = vec![("Content-Type".to_string(), manual.clone())];
    ensure_multipart_content_type(&mut headers, &form);
    assert_eq!(get_header(&headers, "content-type"), Some(manual));
    // Header count unchanged: nothing was appended.
    assert_eq!(headers.len(), 1);
}
// Re-setting a cookie replaces its value in place, other cookies survive, and
// a Set-Cookie with an empty value ("session=") deletes the cookie.
#[test]
fn cookie_store_updates_and_removes_existing_cookie() {
let client = HttpClient::builder().cookie_store(true).build();
client.store_response_cookies(
"example.com",
&[("Set-Cookie".to_string(), "session=abc123".to_string())],
);
client.store_response_cookies(
"example.com",
&[("Set-Cookie".to_string(), "theme=dark".to_string())],
);
client.store_response_cookies(
"example.com",
&[("Set-Cookie".to_string(), "session=updated".to_string())],
);
let cookie_header = client
.cookie_header_for_host("example.com")
.expect("cookie header");
assert!(cookie_header.contains("session=updated"));
assert!(cookie_header.contains("theme=dark"));
// Empty value acts as a removal.
client.store_response_cookies(
"example.com",
&[("Set-Cookie".to_string(), "session=".to_string())],
);
let cookie_header = client
.cookie_header_for_host("example.com")
.expect("cookie header");
assert!(!cookie_header.contains("session="));
assert!(cookie_header.contains("theme=dark"));
}
// In-memory IO double for CONNECT-tunnel tests: serves `read_data` as the
// scripted peer response and captures everything written for later assertions.
#[derive(Debug)]
struct ConnectTestIo {
// Bytes the "server" will return to reads.
read_data: Vec<u8>,
// Read cursor into `read_data`.
read_pos: usize,
// Everything the client wrote, captured verbatim.
written: Vec<u8>,
}
impl ConnectTestIo {
// Builds a test IO whose reads replay `read_data` once, then report EOF.
fn new(read_data: impl AsRef<[u8]>) -> Self {
Self {
read_data: read_data.as_ref().to_vec(),
read_pos: 0,
written: Vec::new(),
}
}
}
impl AsyncRead for ConnectTestIo {
// Copies as much of the remaining scripted data as fits into `buf`.
// Once exhausted it returns Ready(Ok(())) without filling anything,
// which readers interpret as EOF. Never returns Pending.
fn poll_read(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
if self.read_pos >= self.read_data.len() {
return Poll::Ready(Ok(()));
}
let remaining = self.read_data.len() - self.read_pos;
let to_copy = remaining.min(buf.remaining());
buf.put_slice(&self.read_data[self.read_pos..self.read_pos + to_copy]);
self.read_pos += to_copy;
Poll::Ready(Ok(()))
}
}
impl AsyncWrite for ConnectTestIo {
// Appends the written bytes to the capture buffer and accepts the whole
// slice immediately; flush and shutdown are no-ops.
fn poll_write(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
data: &[u8],
) -> Poll<io::Result<usize>> {
self.written.extend_from_slice(data);
Poll::Ready(Ok(data.len()))
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
}
// The CONNECT request written to the proxy must carry the authority-form
// target, a matching Host header, the optional User-Agent, any caller headers,
// and terminate with the blank line.
#[test]
fn connect_tunnel_writes_expected_request() {
let io = ConnectTestIo::new("HTTP/1.1 200 Connection Established\r\n\r\n");
let tunnel = block_on(establish_http_connect_tunnel(
io,
"example.com:443",
Some("asupersync-test/1.0"),
&[("Proxy-Authorization".into(), "Basic abc".into())],
))
.expect("tunnel should establish");
// Recover the captured write buffer from the wrapped IO.
let io = tunnel.into_inner();
let written = String::from_utf8(io.written).expect("request should be utf8");
assert!(written.starts_with("CONNECT example.com:443 HTTP/1.1\r\n"));
assert!(written.contains("\r\nHost: example.com:443\r\n"));
assert!(written.contains("\r\nUser-Agent: asupersync-test/1.0\r\n"));
assert!(written.contains("\r\nProxy-Authorization: Basic abc\r\n"));
assert!(written.ends_with("\r\n\r\n"));
}
// Bytes the proxy sends after its 200 response ("HELLO") belong to the
// tunneled stream: the tunnel must buffer them, serve reads from that prefetch
// first, and still pass writes straight through to the underlying IO.
#[test]
fn connect_tunnel_preserves_prefetched_bytes_and_supports_write() {
let io = ConnectTestIo::new("HTTP/1.1 200 OK\r\n\r\nHELLO");
let mut tunnel = block_on(establish_http_connect_tunnel(
io,
"example.com:443",
None,
&[],
))
.expect("tunnel should establish");
// All 5 post-response bytes were prefetched during tunnel setup.
assert_eq!(tunnel.prefetched_len(), 5);
let mut first = [0u8; 3];
// Drive poll_read by hand: a 3-byte buffer must drain "HEL" from the prefetch.
block_on(async {
poll_fn(|cx| -> std::task::Poll<Result<(), std::io::Error>> {
let mut rb = ReadBuf::new(&mut first);
Pin::new(&mut tunnel).poll_read(cx, &mut rb)
})
.await
.expect("read prefetched bytes");
});
assert_eq!(&first, b"HEL");
assert_eq!(tunnel.prefetched_len(), 2);
block_on(async {
tunnel.write_all(b"PING").await.expect("write to tunnel");
tunnel.flush().await.expect("flush to tunnel");
});
let io = tunnel.into_inner();
let written = String::from_utf8(io.written).expect("request should be utf8");
// The write lands after the CONNECT request on the raw IO.
assert!(written.ends_with("\r\n\r\nPING"));
}
// A non-2xx proxy reply (here 407) must surface as ConnectTunnelRefused with
// the status code and reason phrase preserved.
#[test]
fn connect_tunnel_rejects_non_success_status() {
let io = ConnectTestIo::new("HTTP/1.1 407 Proxy Authentication Required\r\n\r\n");
let err = block_on(establish_http_connect_tunnel(
io,
"example.com:443",
None,
&[],
))
.expect_err("non-2xx should fail");
match err {
ClientError::ConnectTunnelRefused { status, reason } => {
assert_eq!(status, 407);
assert!(reason.contains("Proxy Authentication Required"));
}
other => panic!("unexpected error: {other:?}"),
}
}
// CRLF inside a caller-supplied header value would allow request smuggling
// into the CONNECT preamble; it must be rejected as InvalidConnectInput.
#[test]
fn connect_tunnel_rejects_header_injection() {
let io = ConnectTestIo::new("HTTP/1.1 200 OK\r\n\r\n");
let err = block_on(establish_http_connect_tunnel(
io,
"example.com:443",
None,
&[("X-Test".into(), "ok\r\nbad".into())],
))
.expect_err("CRLF in header value must be rejected");
match err {
ClientError::InvalidConnectInput(msg) => {
assert!(msg.contains("header name/value"));
}
other => panic!("unexpected error: {other:?}"),
}
}
// The default redirect policy caps the chain at 10 hops.
#[test]
fn redirect_policy_default_is_limited() {
    assert!(matches!(
        RedirectPolicy::default(),
        RedirectPolicy::Limited(10)
    ));
}
// Exercises the derives the test name promises: Copy (implicit bind), Clone
// (explicit call — the original test never actually invoked `clone()`),
// PartialEq, and Debug.
#[test]
fn scheme_debug_clone_copy_eq() {
    let a = Scheme::Http;
    let b = a; // Copy
    #[allow(clippy::clone_on_copy)] // deliberate: verify Clone explicitly
    let c = a.clone();
    assert_eq!(a, b);
    assert_eq!(a, c);
    assert_ne!(a, Scheme::Https);
    let dbg = format!("{a:?}");
    assert!(dbg.contains("Http"));
}
// A clone must carry the same payload, observable via identical Debug output.
#[test]
fn redirect_policy_debug_clone() {
    let original = RedirectPolicy::Limited(5);
    let duplicate = original.clone();
    let rendered = format!("{original:?}");
    assert!(rendered.contains("Limited"));
    assert!(rendered.contains('5'));
    assert_eq!(rendered, format!("{duplicate:?}"));
}
// Clone preserves every field and the Debug output names the type and host.
#[test]
fn parsed_url_debug_clone() {
    let url = ParsedUrl {
        scheme: Scheme::Https,
        host: "example.com".to_string(),
        port: 443,
        path: "/api/v1".to_string(),
    };
    let duplicate = url.clone();
    assert_eq!(duplicate.host, "example.com");
    assert_eq!(duplicate.port, 443);
    let rendered = format!("{url:?}");
    assert!(rendered.contains("ParsedUrl"));
    assert!(rendered.contains("example.com"));
}
// Tokens in a comma-separated header value must match case-insensitively,
// and absent tokens must not match.
#[test]
fn header_has_token_matches_case_insensitive_csv_values() {
    let headers: Vec<(String, String)> = [
        ("Connection", "keep-alive, Upgrade"),
        ("X-Test", "value"),
    ]
    .iter()
    .map(|&(name, value)| (name.to_string(), value.to_string()))
    .collect();
    for token in ["keep-alive", "upgrade"] {
        assert!(header_has_token(&headers, "connection", token));
    }
    assert!(!header_has_token(&headers, "connection", "close"));
}
// HTTP/1.1 with a framed body and "Connection: keep-alive" is reusable;
// swapping in "Connection: close" forbids reuse.
#[test]
fn connection_can_be_reused_for_http11_without_close() {
let response = Response {
version: Version::Http11,
status: 200,
reason: "OK".into(),
headers: vec![
("Content-Length".into(), "0".into()),
("Connection".into(), "keep-alive".into()),
],
body: Vec::new(),
trailers: Vec::new(),
};
assert!(connection_can_be_reused(&response, &Method::Get));
// Same response but with an explicit close directive.
let close_response = Response {
headers: vec![
("Content-Length".into(), "0".into()),
("Connection".into(), "close".into()),
],
..response
};
assert!(!connection_can_be_reused(&close_response, &Method::Get));
}
// HTTP/1.0 defaults to close: reuse is only allowed with an explicit
// "Connection: keep-alive" header; without it the connection must be dropped.
#[test]
fn connection_can_be_reused_for_http10_only_with_keep_alive() {
let response = Response {
version: Version::Http10,
status: 200,
reason: "OK".into(),
headers: vec![
("Content-Length".into(), "0".into()),
("Connection".into(), "keep-alive".into()),
],
body: Vec::new(),
trailers: Vec::new(),
};
assert!(connection_can_be_reused(&response, &Method::Get));
// No headers at all -> HTTP/1.0 default of close.
let no_header = Response {
headers: Vec::new(),
..response
};
assert!(!connection_can_be_reused(&no_header, &Method::Get));
}
// A 200 response with neither Content-Length nor Transfer-Encoding is
// EOF-delimited, so the connection cannot be reused. Bodiless statuses (204)
// and chunked framing keep the connection reusable.
#[test]
fn connection_not_reused_for_eof_delimited_body() {
let response = Response {
version: Version::Http11,
status: 200,
reason: "OK".into(),
headers: vec![],
body: Vec::new(),
trailers: Vec::new(),
};
assert!(!connection_can_be_reused(&response, &Method::Get));
// 204 never carries a body, so missing framing headers are fine.
let no_content = Response {
status: 204,
reason: "No Content".into(),
..response.clone()
};
assert!(connection_can_be_reused(&no_content, &Method::Get));
// Chunked framing delimits the body without Content-Length.
let chunked = Response {
headers: vec![("Transfer-Encoding".into(), "chunked".into())],
..response
};
assert!(connection_can_be_reused(&chunked, &Method::Get));
}
// A 101 Switching Protocols response hands the connection to the upgraded
// protocol (e.g. websocket) — it must never be returned to the HTTP pool.
#[test]
fn connection_not_reused_for_protocol_upgrade() {
let response = Response {
version: Version::Http11,
status: 101,
reason: "Switching Protocols".into(),
headers: vec![
("Connection".into(), "Upgrade".into()),
("Upgrade".into(), "websocket".into()),
],
body: Vec::new(),
trailers: Vec::new(),
};
assert!(!connection_can_be_reused(&response, &Method::Get));
}
// A "close" token anywhere in the request's Connection list forbids reuse,
// and so does the presence of an Upgrade header.
#[test]
fn request_connection_close_forbids_reuse() {
    let close = vec![("Connection".to_string(), "keep-alive, close".to_string())];
    assert!(request_forbids_connection_reuse(&close));
    let upgrade = vec![("Upgrade".to_string(), "websocket".to_string())];
    assert!(request_forbids_connection_reuse(&upgrade));
}
// check_cx passes on a live context and yields a cancelled error once
// cancellation has been requested.
#[test]
fn check_cx_returns_cancelled_when_cancelled() {
    let cx = Cx::for_testing();
    assert!(check_cx(&cx).is_ok());
    cx.set_cancel_requested(true);
    assert!(check_cx(&cx).unwrap_err().is_cancelled());
}
// A pre-cancelled context must short-circuit GET with Cancelled before any
// network activity is attempted.
#[test]
fn request_returns_cancelled_when_cx_already_cancelled() {
    let cx = Cx::for_testing();
    cx.set_cancel_requested(true);
    let client = HttpClient::new();
    let err = block_on(client.get(&cx, "http://example.com/test"))
        .expect_err("pre-cancelled cx must fail");
    assert!(err.is_cancelled());
}
// POST must short-circuit with Cancelled when the context is pre-cancelled.
#[test]
fn post_returns_cancelled_when_cx_already_cancelled() {
    let cx = Cx::for_testing();
    cx.set_cancel_requested(true);
    let client = HttpClient::new();
    let err = block_on(client.post(&cx, "http://example.com/submit", b"data".to_vec()))
        .expect_err("pre-cancelled cx must fail");
    assert!(err.is_cancelled());
}
// PUT must short-circuit with Cancelled when the context is pre-cancelled.
#[test]
fn put_returns_cancelled_when_cx_already_cancelled() {
    let cx = Cx::for_testing();
    cx.set_cancel_requested(true);
    let client = HttpClient::new();
    let err = block_on(client.put(&cx, "http://example.com/item", b"data".to_vec()))
        .expect_err("pre-cancelled cx must fail");
    assert!(err.is_cancelled());
}
// DELETE must short-circuit with Cancelled when the context is pre-cancelled.
#[test]
fn delete_returns_cancelled_when_cx_already_cancelled() {
    let cx = Cx::for_testing();
    cx.set_cancel_requested(true);
    let client = HttpClient::new();
    let err = block_on(client.delete(&cx, "http://example.com/item"))
        .expect_err("pre-cancelled cx must fail");
    assert!(err.is_cancelled());
}
// The streaming entry point must also honor pre-cancellation before dialing.
#[test]
fn request_streaming_returns_cancelled_when_cx_already_cancelled() {
let cx = Cx::for_testing();
cx.set_cancel_requested(true);
let client = HttpClient::new();
let result = block_on(client.request_streaming(
&cx,
Method::Get,
"http://example.com/stream",
Vec::new(),
Vec::new(),
));
assert!(result.unwrap_err().is_cancelled());
}
// connect_tunnel must honor pre-cancellation before contacting the proxy.
#[test]
fn connect_tunnel_returns_cancelled_when_cx_already_cancelled() {
let cx = Cx::for_testing();
cx.set_cancel_requested(true);
let client = HttpClient::new();
let result = block_on(client.connect_tunnel(
&cx,
"http://proxy.local:3128",
"example.com:443",
Vec::new(),
));
assert!(result.unwrap_err().is_cancelled());
}
// NOTE(review): the name is misleading — the request does NOT succeed. Port 1
// is (almost certainly) unreachable, so this asserts that a live, non-cancelled
// cx produces a genuine connection error rather than ClientError::Cancelled.
// Consider renaming to something like `live_cx_failure_is_not_cancelled`.
#[test]
fn request_succeeds_with_non_cancelled_cx() {
let cx = Cx::for_testing();
let client = HttpClient::new();
let result = block_on(client.get(&cx, "http://127.0.0.1:1/nonexistent"));
assert!(result.is_err());
assert!(!result.unwrap_err().is_cancelled());
}
// With a 1-connection pool already occupied by an in-flight registration,
// acquire_connection must fail with PoolExhausted BEFORE dialing. The
// non-blocking listener then confirms (for ~100ms) that no TCP connection
// attempt ever reached the server.
#[test]
fn acquire_connection_respects_pool_limits_before_opening_socket() {
use std::time::{Duration, Instant};
let listener = TcpListener::bind("127.0.0.1:0").expect("bind listener");
listener
.set_nonblocking(true)
.expect("set nonblocking listener");
let addr = listener.local_addr().expect("listener address");
let url = format!("http://{addr}/pooled");
let parsed = ParsedUrl::parse(&url).expect("parse pooled url");
let key = parsed.pool_key();
let client = HttpClient::builder()
.max_connections_per_host(1)
.max_total_connections(1)
.build();
// Occupy the only pool slot with a phantom in-flight connection.
client.pool.lock().register_connecting(key, Time::ZERO, 1);
let cx = Cx::for_testing();
let err = match block_on(client.acquire_connection(&cx, &parsed)) {
Ok(_) => panic!("pool exhaustion must reject before dialing"),
Err(err) => err,
};
assert!(matches!(
err,
ClientError::PoolExhausted { ref host, port }
if host == "127.0.0.1" && port == addr.port()
));
// Poll accept() until the deadline: WouldBlock the whole time means no
// socket was ever opened toward the listener.
let deadline = Instant::now() + Duration::from_millis(100);
loop {
match listener.accept() {
Ok(_) => panic!("pool-exhausted acquisition must not open a socket"),
Err(io_err) if io_err.kind() == io::ErrorKind::WouldBlock => {
if Instant::now() >= deadline {
break;
}
std::thread::sleep(Duration::from_millis(5));
}
Err(io_err) => panic!("accept failed: {io_err}"),
}
}
}
// Cancellation that arrives AFTER the request was sent but BEFORE the response
// is consumed must surface as Cancelled, and the (keep-alive-capable)
// connection must NOT be returned to the idle pool. Channels sequence the
// race: server sees the request -> test cancels -> server is allowed to reply.
#[test]
fn request_returns_cancelled_when_cx_is_cancelled_after_exchange() {
use std::io::{Read, Write};
use std::sync::{Arc, mpsc};
// Reads from the socket until the end-of-head CRLFCRLF marker arrives.
fn read_request_head(stream: &mut std::net::TcpStream) {
let mut buf = [0_u8; 1024];
let mut request = Vec::new();
loop {
let n = stream.read(&mut buf).expect("read request");
assert!(n > 0, "request must arrive before peer closes");
request.extend_from_slice(&buf[..n]);
if request.windows(4).any(|window| window == b"\r\n\r\n") {
break;
}
}
}
let listener = TcpListener::bind("127.0.0.1:0").expect("bind listener");
let addr = listener.local_addr().expect("listener address");
let (request_seen_tx, request_seen_rx) = mpsc::channel();
let (send_response_tx, send_response_rx) = mpsc::channel();
let server = std::thread::spawn(move || {
let (mut stream, _) = listener.accept().expect("accept request");
read_request_head(&mut stream);
request_seen_tx.send(()).expect("notify request arrival");
// Hold the response until the test has requested cancellation.
send_response_rx
.recv()
.expect("wait for cancellation before responding");
stream
.write_all(
b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\nConnection: keep-alive\r\n\r\nok",
)
.expect("write response");
stream.flush().expect("flush response");
});
let client = Arc::new(
HttpClient::builder()
.max_connections_per_host(1)
.max_total_connections(1)
.build(),
);
let cx = Cx::for_testing();
let request_cx = cx.clone();
let request_client = Arc::clone(&client);
let url = format!("http://{addr}/late-cancel");
let request_url = url.clone();
let request =
std::thread::spawn(move || block_on(request_client.get(&request_cx, &request_url)));
request_seen_rx
.recv()
.expect("server should observe the request before cancellation");
// Cancel only after the server has read the full request head.
cx.set_cancel_requested(true);
send_response_tx
.send(())
.expect("allow server to send response");
let result = request.join().expect("request thread should join");
assert!(matches!(result, Err(ClientError::Cancelled)));
server.join().expect("server thread should join");
let stats = client.pool_stats();
assert_eq!(
stats.idle_connections, 0,
"late-cancelled responses must not be returned to the idle pool"
);
}
// After a keep-alive response, the server closes the pooled connection while
// it sits idle. The next GET (a safe, idempotent method) must transparently
// retry on a fresh connection instead of failing.
// NOTE(review): sequencing relies on 50ms/100ms sleeps, so this test is
// timing-sensitive and could flake under heavy load.
#[test]
fn safe_method_retries_once_after_stale_pooled_connection_close() {
use std::io::{Read, Write};
use std::time::Duration;
// Reads from the socket until the end-of-head CRLFCRLF marker arrives.
fn read_request_head(stream: &mut std::net::TcpStream) {
let mut buf = [0_u8; 1024];
let mut request = Vec::new();
loop {
let n = stream.read(&mut buf).expect("read request");
assert!(n > 0, "request must arrive before peer closes");
request.extend_from_slice(&buf[..n]);
if request.windows(4).any(|window| window == b"\r\n\r\n") {
break;
}
}
}
let listener = TcpListener::bind("127.0.0.1:0").expect("bind listener");
let addr = listener.local_addr().expect("listener address");
let server = std::thread::spawn(move || {
let (mut first, _) = listener.accept().expect("accept first connection");
read_request_head(&mut first);
first
.write_all(
b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\nConnection: keep-alive\r\n\r\nok",
)
.expect("write first response");
first.flush().expect("flush first response");
// Give the client time to pool the connection, then kill it server-side.
std::thread::sleep(Duration::from_millis(50));
drop(first);
let (mut second, _) = listener.accept().expect("accept retry connection");
read_request_head(&mut second);
second
.write_all(
b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\nConnection: close\r\n\r\nfresh",
)
.expect("write retry response");
second.flush().expect("flush retry response");
});
let client = HttpClient::builder()
.max_connections_per_host(1)
.max_total_connections(1)
.build();
let cx = Cx::for_testing();
let url = format!("http://{addr}/healthz");
let first = block_on(client.get(&cx, &url)).expect("initial request should succeed");
assert_eq!(first.status, 200);
assert_eq!(first.body, b"ok");
// Wait out the server's close so the pooled connection is definitely stale.
std::thread::sleep(Duration::from_millis(100));
let second = block_on(client.get(&cx, &url)).expect("retry request should succeed");
assert_eq!(second.status, 200);
assert_eq!(second.body, b"fresh");
server.join().expect("server thread should join");
let stats = client.pool_stats();
assert!(
stats.connections_created >= 2,
"client should establish a fresh connection after stale pooled reuse fails"
);
}
// A request that itself carries "Connection: close" must never be served from
// a pooled idle connection, must not leave one behind, and must forward the
// close header verbatim — even when the server answered keep-alive. The server
// therefore expects TWO separate TCP connections.
#[test]
fn request_connection_close_header_prevents_pool_reuse() {
use std::io::{Read, Write};
use std::time::{Duration, Instant};
// Reads until CRLFCRLF and returns the raw request head for inspection.
fn read_request_head(stream: &mut std::net::TcpStream) -> Vec<u8> {
let mut buf = [0_u8; 1024];
let mut request = Vec::new();
loop {
let n = stream.read(&mut buf).expect("read request");
assert!(n > 0, "request must arrive before peer closes");
request.extend_from_slice(&buf[..n]);
if request.windows(4).any(|window| window == b"\r\n\r\n") {
break;
}
}
request
}
let listener = TcpListener::bind("127.0.0.1:0").expect("bind listener");
listener
.set_nonblocking(true)
.expect("set nonblocking listener");
let addr = listener.local_addr().expect("listener address");
let server = std::thread::spawn(move || {
let deadline = Instant::now() + Duration::from_secs(5);
// Non-blocking accept with a retry loop bounded by a 5s deadline.
let accept = |listener: &TcpListener| -> std::net::TcpStream {
loop {
match listener.accept() {
Ok((stream, _)) => return stream,
Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => {
assert!(
Instant::now() <= deadline,
"timed out waiting for a fresh connection"
);
std::thread::sleep(Duration::from_millis(10));
}
Err(err) => panic!("accept failed: {err}"),
}
}
};
let mut first = accept(&listener);
let first_req = read_request_head(&mut first);
// Server offers keep-alive; the client must still not pool this one.
first
.write_all(
b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\nConnection: keep-alive\r\n\r\nok",
)
.expect("write first response");
first.flush().expect("flush first response");
let mut second = accept(&listener);
let second_req = read_request_head(&mut second);
second
.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\nConnection: close\r\n\r\nok")
.expect("write second response");
second.flush().expect("flush second response");
(first_req, second_req)
});
let client = HttpClient::builder()
.max_connections_per_host(1)
.max_total_connections(1)
.build();
let cx = Cx::for_testing();
let url = format!("http://{addr}/close-after-response");
let headers = vec![("Connection".to_string(), "close".to_string())];
let first = block_on(client.request(&cx, Method::Get, &url, headers.clone(), Vec::new()))
.expect("first request should succeed");
assert_eq!(first.status, 200);
let second = block_on(client.request(&cx, Method::Get, &url, headers, Vec::new()))
.expect("second request should succeed");
assert_eq!(second.status, 200);
let (first_req, second_req) = server.join().expect("server thread should join");
let first_text = String::from_utf8(first_req).expect("first request should be utf8");
let second_text = String::from_utf8(second_req).expect("second request should be utf8");
// The caller's close header must reach the wire on both requests.
assert!(first_text.contains("Connection: close\r\n"));
assert!(second_text.contains("Connection: close\r\n"));
let stats = client.pool_stats();
assert!(
stats.connections_created >= 2,
"Connection: close requests must not be satisfied from a reused idle connection"
);
assert_eq!(
stats.idle_connections, 0,
"Connection: close requests must not leave pooled idle connections behind"
);
}
// When routed through a forward proxy, the non-streaming request path must
// still enforce `max_body_size`: a 2-byte body against a 1-byte limit yields
// BodyTooLargeDetailed. Also verifies the proxy request used absolute-form.
#[test]
fn proxy_non_streaming_path_respects_max_body_size() {
use crate::http::h1::codec::HttpError;
use std::io::{Read, Write};
// Reads until CRLFCRLF and returns the raw request head for inspection.
fn read_request_head(stream: &mut std::net::TcpStream) -> Vec<u8> {
let mut buf = [0_u8; 1024];
let mut request = Vec::new();
loop {
let n = stream.read(&mut buf).expect("read request");
assert!(n > 0, "request must arrive before peer closes");
request.extend_from_slice(&buf[..n]);
if request.windows(4).any(|window| window == b"\r\n\r\n") {
break;
}
}
request
}
let listener = TcpListener::bind("127.0.0.1:0").expect("bind listener");
let addr = listener.local_addr().expect("listener address");
// The listener plays the role of the forward proxy itself.
let server = std::thread::spawn(move || {
let (mut stream, _) = listener.accept().expect("accept proxy request");
let request = read_request_head(&mut stream);
stream
.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\nConnection: close\r\n\r\nok")
.expect("write proxy response");
stream.flush().expect("flush proxy response");
request
});
let client = HttpClient::builder()
.proxy(format!("http://{addr}"))
.max_body_size(1)
.build();
let cx = Cx::for_testing();
let err = block_on(client.get(&cx, "http://example.com/oversized"))
.expect_err("response body should exceed configured limit");
match err {
ClientError::HttpError(HttpError::BodyTooLargeDetailed { actual, limit }) => {
assert_eq!(actual, 2);
assert_eq!(limit, 1);
}
other => panic!("unexpected error: {other:?}"),
}
let request = server.join().expect("proxy server thread should join");
let request_text = String::from_utf8(request).expect("proxy request should be utf8");
assert!(
request_text.starts_with("GET http://example.com/oversized HTTP/1.1\r\n"),
"forward-proxy request must use absolute-form and hit the non-streaming proxy path"
);
}
}