#[cfg(feature = "native-tls")]
use async_native_tls::TlsConnector as NativeTlsConnector;
#[cfg(test)]
use async_net::TcpListener;
use async_net::TcpStream;
#[cfg(all(test, feature = "rustls"))]
use async_tls::TlsAcceptor;
use bytes::Bytes;
use futures_lite::StreamExt;
use futures_lite::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
#[cfg(test)]
use rcgen::generate_simple_self_signed;
#[cfg(all(test, feature = "rustls"))]
use rustls::ServerConfig;
use std::net::SocketAddr;
use std::sync::Mutex;
use std::sync::{Arc, OnceLock};
use std::time::Instant;
use crate::alt_svc::AltSvcCache;
use crate::body::{Body, BodyData};
use crate::browser_emulation::Http1Fingerprint;
use crate::client::{RedirectPolicy, ResponseFuture, Transport};
use crate::connection::tcp::{connect_happy_eyeballs, connect_tcp_addr};
#[cfg(feature = "native-tls")]
use crate::connection::tls::build_native_tls_connector_for_protocols;
#[cfg(feature = "rustls")]
use crate::connection::tls::connect_async_tls_with_config;
use crate::decode::{DEFAULT_ACCEPT_ENCODING, maybe_decode_response_body};
use crate::dns::{DnsCache, DnsConfig};
use crate::error::{Error, ErrorKind, Result};
use crate::header::HeaderMap;
use crate::metrics::Metrics;
use crate::pool::{IdlePool, PoolConfig, PoolKey};
use crate::progress::{ProgressPhase, ProgressReporter};
use crate::protocol::http2 as h2;
use crate::protocol::http3 as h3;
use crate::proxy::{Proxy, build_http_connect_request};
use crate::request::{Method, ProgressCallback, ProtocolPolicy, Request, TimeoutConfig};
use crate::response::{Response, StatusCode, TrailerState, Version};
use crate::retry::{backoff, retry_attempts, should_retry_request, should_retry_stale_connection};
#[cfg(any(feature = "rustls", feature = "native-tls", feature = "btls-backend"))]
use crate::tls::TlsBackend;
#[cfg(any(feature = "rustls", feature = "native-tls", feature = "btls-backend"))]
use crate::tls::TlsConfig;
#[cfg(feature = "native-tls")]
use crate::tls::verify_pinned_certificate;
use crate::url::Url;
use crate::util::response_body_allowed;
/// Marker trait unifying the stream requirements (async read + write,
/// `Unpin`, `Send`) so plain TCP and TLS streams can share one pool type.
trait IoStream: AsyncRead + AsyncWrite + Unpin + Send {}
impl<T> IoStream for T where T: AsyncRead + AsyncWrite + Unpin + Send {}
/// Type-erased connection stream stored in the idle pool.
type BoxedStream = Box<dyn IoStream>;
/// Idle HTTP/1 connection pool holding type-erased streams.
type ConnectionPool = IdlePool<BoxedStream>;
/// Durations measured while establishing a fresh connection. All fields stay
/// `None` when a pooled connection is reused (no DNS/connect/TLS happened).
#[derive(Clone, Copy, Default)]
struct ConnectTiming {
    // Time spent on DNS resolution, if a lookup was performed.
    dns: Option<std::time::Duration>,
    // Time spent establishing the TCP connection.
    connect: Option<std::time::Duration>,
    // Time spent on the TLS handshake, if the connection was upgraded.
    tls: Option<std::time::Duration>,
}
/// HTTP/1.1 transport with connection pooling and opportunistic upgrades to
/// HTTP/2 and HTTP/3 (driven by protocol policy or cached Alt-Svc hints).
pub struct Http1Transport {
    // Pool limits applied when parking and checking out idle connections.
    pool_config: PoolConfig,
    // Idle HTTP/1 connections (plain TCP or TLS, type-erased).
    pool: Arc<Mutex<ConnectionPool>>,
    // Shared pool for requests upgraded to HTTP/2.
    h2_pool: Arc<Mutex<h2::ConnectionPool>>,
    // Shared pool for requests upgraded to HTTP/3.
    h3_pool: Arc<Mutex<h3::ConnectionPool>>,
    dns_cache: Arc<DnsCache>,
    dns_config: DnsConfig,
    // Optional local address to bind outgoing sockets to.
    local_addr: Option<SocketAddr>,
    // Records `alt-svc` response headers so later requests can prefer h3.
    alt_svc_cache: AltSvcCache,
}
impl Default for Http1Transport {
    /// Builds a transport with default pool settings, a fresh DNS cache,
    /// default DNS configuration, and no bound local address.
    fn default() -> Self {
        let dns_cache = Arc::new(DnsCache::default());
        Self::new(PoolConfig::default(), dns_cache, DnsConfig::default(), None)
    }
}
impl Http1Transport {
    /// Creates a transport with the given pool settings, shared DNS cache,
    /// DNS configuration, and optional local bind address. All protocol
    /// pools (h1/h2/h3) and the Alt-Svc cache start out empty.
    pub fn new(
        pool_config: PoolConfig,
        dns_cache: Arc<DnsCache>,
        dns_config: DnsConfig,
        local_addr: Option<SocketAddr>,
    ) -> Self {
        let pool = Arc::new(Mutex::new(ConnectionPool::default()));
        let h2_pool = Arc::new(Mutex::new(h2::ConnectionPool::default()));
        let h3_pool = Arc::new(Mutex::new(h3::ConnectionPool::default()));
        Self {
            pool_config,
            pool,
            h2_pool,
            h3_pool,
            dns_cache,
            dns_config,
            local_addr,
            alt_svc_cache: AltSvcCache::default(),
        }
    }
}
impl Transport for Http1Transport {
    /// Executes `request` once (no redirect handling), attaching the method
    /// and URL to any resulting error for context.
    fn execute(&self, request: Request) -> ResponseFuture {
        // Snapshot shared state so the returned future is self-contained.
        let pool_config = self.pool_config;
        let pool = Arc::clone(&self.pool);
        let h2_pool = Arc::clone(&self.h2_pool);
        let h3_pool = Arc::clone(&self.h3_pool);
        let dns_cache = Arc::clone(&self.dns_cache);
        let dns_config = self.dns_config;
        let local_addr = self.local_addr;
        let alt_svc_cache = self.alt_svc_cache.clone();
        // Captured up front for error context, before `request` is moved.
        let method_str = request.method().as_str();
        let url_str = request.url().to_string();
        Box::pin(async move {
            execute_http1(
                request,
                pool,
                h2_pool,
                h3_pool,
                dns_cache,
                dns_config,
                local_addr,
                pool_config,
                alt_svc_cache,
            )
            .await
            .map_err(|e| e.with_request_context(method_str, &url_str))
        })
    }
    /// Executes `request`, following redirects according to `policy`.
    fn execute_with_redirect(&self, request: Request, policy: RedirectPolicy) -> ResponseFuture {
        let pool_config = self.pool_config;
        let pool = Arc::clone(&self.pool);
        let h2_pool = Arc::clone(&self.h2_pool);
        let h3_pool = Arc::clone(&self.h3_pool);
        let dns_cache = Arc::clone(&self.dns_cache);
        let dns_config = self.dns_config;
        let local_addr = self.local_addr;
        let alt_svc_cache = self.alt_svc_cache.clone();
        let method_str = request.method().as_str();
        let url_str = request.url().to_string();
        Box::pin(async move {
            execute_http1_with_redirect(
                request,
                policy,
                pool,
                h2_pool,
                h3_pool,
                dns_cache,
                dns_config,
                local_addr,
                pool_config,
                alt_svc_cache,
            )
            .await
            // Pass `method_str` directly, matching `execute` above; the
            // previous `&method_str` was a `&&str` relying on deref coercion.
            .map_err(|e| e.with_request_context(method_str, &url_str))
        })
    }
    /// Drops every pooled connection (h1/h2/h3) and clears Alt-Svc state.
    fn close(&self) -> Result<()> {
        // A poisoned lock still holds usable data; recover instead of panic.
        self.pool
            .lock()
            .unwrap_or_else(|err| err.into_inner())
            .clear();
        self.h2_pool
            .lock()
            .unwrap_or_else(|err| err.into_inner())
            .clear();
        self.h3_pool
            .lock()
            .unwrap_or_else(|err| err.into_inner())
            .clear();
        self.alt_svc_cache.clear();
        Ok(())
    }
}
/// Core HTTP/1 execution path with optional protocol upgrades.
///
/// Order of operations:
/// 1. Possibly attempt HTTP/3 (explicit policy, or Auto upgraded by a cached
///    Alt-Svc advertisement), falling back to TCP transports on
///    transport/timeout errors.
/// 2. Possibly attempt HTTP/2 (explicit policy, or h2c prior knowledge over
///    cleartext), with the same fallback behavior.
/// 3. Encode and send the request over HTTP/1.1, with connection pooling,
///    stale-connection replay, and the configured retry policy (retries are
///    only possible for buffered bodies, which can be replayed).
///
/// On success, any `alt-svc` response header is recorded for future h3 use.
async fn execute_http1(
    request: Request,
    pool: Arc<Mutex<ConnectionPool>>,
    h2_pool: Arc<Mutex<h2::ConnectionPool>>,
    h3_pool: Arc<Mutex<h3::ConnectionPool>>,
    dns_cache: Arc<DnsCache>,
    dns_config: DnsConfig,
    local_addr: Option<SocketAddr>,
    pool_config: PoolConfig,
    alt_svc_cache: AltSvcCache,
) -> Result<Response> {
    let protocol_policy = request.protocol_policy();
    // A cached Alt-Svc h3 advertisement upgrades Auto to PreferHttp3.
    let protocol_policy = if protocol_policy == ProtocolPolicy::Auto
        && h3::can_attempt_h3(&request)
        && alt_svc_cache.has_h3(request.url().host(), request.url().effective_port())
    {
        ProtocolPolicy::PreferHttp3
    } else {
        protocol_policy
    };
    if protocol_policy == ProtocolPolicy::Http3Only {
        return h3::execute(
            request,
            Arc::clone(&h3_pool),
            dns_cache,
            dns_config,
            local_addr,
            pool_config,
        )
        .await;
    }
    if protocol_policy == ProtocolPolicy::PreferHttp3 && h3::can_attempt_h3(&request) {
        if let Ok(cloned) = h3::clone_request_for_h3(&request) {
            match h3::execute(
                cloned,
                Arc::clone(&h3_pool),
                Arc::clone(&dns_cache),
                dns_config,
                local_addr,
                pool_config,
            )
            .await
            {
                Ok(response) => return Ok(response),
                // Transport/timeout failures fall through to TCP transports.
                Err(err) if matches!(err.kind(), ErrorKind::Transport | ErrorKind::Timeout) => {}
                Err(err) => return Err(err),
            }
        }
    }
    // Re-read the caller's policy: the Alt-Svc upgrade above must not affect
    // the h2 decision as if the user had asked for PreferHttp3 themselves.
    let protocol_policy = request.protocol_policy();
    let should_try_h2 = match protocol_policy {
        ProtocolPolicy::Http2Only | ProtocolPolicy::PreferHttp2 => true,
        ProtocolPolicy::Auto | ProtocolPolicy::PreferHttp3 => {
            // Cleartext h2 only with prior knowledge (h2c).
            request.url().scheme() == "http" && request.prior_knowledge_h2c()
        }
        ProtocolPolicy::Http1Only | ProtocolPolicy::Http3Only => false,
    };
    if should_try_h2 && h2::can_attempt_h2(&request) {
        if protocol_policy == ProtocolPolicy::Http2Only {
            return h2::execute(
                request,
                Arc::clone(&h2_pool),
                dns_cache,
                dns_config,
                local_addr,
                pool_config,
            )
            .await;
        }
        match h2::clone_request_for_h2(&request) {
            Ok(cloned) => match h2::execute(
                cloned,
                Arc::clone(&h2_pool),
                Arc::clone(&dns_cache),
                dns_config,
                local_addr,
                pool_config,
            )
            .await
            {
                Ok(response) => return Ok(response),
                // Transport/timeout failures fall back to HTTP/1 below.
                Err(err) if matches!(err.kind(), ErrorKind::Transport | ErrorKind::Timeout) => {}
                Err(err) => return Err(err),
            },
            // Un-clonable request (e.g. streaming body): fall back to HTTP/1.
            Err(_) => {}
        }
    }
    // "Only" policies that reached this point could not be satisfied above.
    match request.protocol_policy() {
        ProtocolPolicy::Http2Only => {
            return Err(Error::new(
                ErrorKind::Transport,
                "http2-only requested, but http2 transport is unavailable",
            ));
        }
        ProtocolPolicy::Http3Only => {
            return Err(Error::new(
                ErrorKind::Transport,
                "http3-only requested, but http3 transport is unavailable",
            ));
        }
        _ => {}
    }
    // From here on the request is consumed and sent over HTTP/1.1.
    let (
        method,
        url,
        headers,
        cookies,
        timeout_config,
        protocol_policy,
        retry_policy,
        _prior_knowledge_h2c,
        progress_callback,
        progress_config,
        _h2_keepalive_config,
        tls_config,
        proxy,
        compression_mode,
        body,
        browser_profile,
    ) = request.into_parts();
    let http1_fingerprint = browser_profile.as_ref().and_then(|p| p.http1_fingerprint());
    // An explicit `Connection: close` forbids returning the socket to the pool.
    let request_allows_reuse = !headers
        .get_all("connection")
        .iter()
        .any(|value| value.eq_ignore_ascii_case("close"));
    let body_data = body.into_data()?;
    // Plain-HTTP requests through an HTTP proxy use the absolute-form request
    // target rather than a CONNECT tunnel.
    let use_http_proxy_absolute_form =
        matches!(proxy, Some(Proxy::Http { .. })) && url.scheme() == "http";
    let (request_head, request_body, write_mode) = if use_http_proxy_absolute_form {
        encode_proxy_request(
            method.as_str(),
            &url,
            &headers,
            &cookies,
            proxy.as_ref().expect("proxy exists"),
            compression_mode,
            body_data,
            http1_fingerprint,
        )?
    } else {
        encode_request(
            method.as_str(),
            &url,
            &headers,
            &cookies,
            compression_mode,
            body_data,
            http1_fingerprint,
        )?
    };
    match protocol_policy {
        ProtocolPolicy::Http2Only | ProtocolPolicy::Http3Only => {
            unreachable!("http2/http3-only requests are handled before HTTP/1 execution")
        }
        ProtocolPolicy::Auto
        | ProtocolPolicy::Http1Only
        | ProtocolPolicy::PreferHttp2
        | ProtocolPolicy::PreferHttp3 => {}
    }
    let pool_key = PoolKey::for_http1(&url, &tls_config, protocol_policy, browser_profile.as_ref());
    let result = match request_body {
        // Buffered bodies can be replayed, so retries are possible.
        BodyData::Bytes(bytes) => {
            let attempts = retry_attempts(retry_policy);
            match url.scheme() {
                "http" => {
                    let mut remaining_attempts = attempts;
                    let mut attempt_index: usize = 0;
                    loop {
                        // Prefer a pooled connection; otherwise dial directly
                        // or through the configured proxy.
                        let (stream, reused, timing) = if let Some(stream) = pool
                            .lock()
                            .unwrap_or_else(|err| err.into_inner())
                            .checkout(&pool_key, pool_config)
                        {
                            (stream, true, ConnectTiming::default())
                        } else {
                            let (stream, timing) = if let Some(proxy) = &proxy {
                                let s = match proxy {
                                    Proxy::Http { .. } => Box::new(
                                        with_timeout_io(
                                            timeout_config.connect,
                                            TcpStream::connect(proxy.addr()),
                                            "proxy connect timed out",
                                        )
                                        .await?,
                                    )
                                        as BoxedStream,
                                    Proxy::Socks5 { .. } => Box::new(
                                        connect_via_proxy(
                                            proxy,
                                            url.host(),
                                            url.effective_port(),
                                            timeout_config,
                                            local_addr,
                                        )
                                        .await?,
                                    )
                                        as BoxedStream,
                                };
                                // Proxy dials do not report DNS/connect timings.
                                (s, ConnectTiming::default())
                            } else {
                                let (stream, timing) = connect_tcp_target(
                                    &dns_cache,
                                    dns_config,
                                    &url,
                                    timeout_config,
                                    local_addr,
                                )
                                .await?;
                                (Box::new(stream) as BoxedStream, timing)
                            };
                            (stream, false, timing)
                        };
                        let result = execute_http1_over_stream(
                            stream,
                            method,
                            &url,
                            timeout_config,
                            compression_mode,
                            progress_callback.clone(),
                            progress_config,
                            &request_head,
                            BodyData::Bytes(bytes.clone()),
                            write_mode,
                            request_allows_reuse,
                            reused,
                            timing,
                            Arc::clone(&pool),
                            pool_key.clone(),
                            pool_config,
                        )
                        .await;
                        match result {
                            Ok(response) => break Ok(response),
                            // The pooled connection died under us; replay the
                            // buffered body once over a fresh connection.
                            Err(err) if reused && should_retry_stale_connection(method, &err) => {
                                let (stream, timing) = if let Some(proxy) = &proxy {
                                    let s = match proxy {
                                        Proxy::Http { .. } => Box::new(
                                            with_timeout_io(
                                                timeout_config.connect,
                                                TcpStream::connect(proxy.addr()),
                                                "proxy connect timed out",
                                            )
                                            .await?,
                                        )
                                            as BoxedStream,
                                        Proxy::Socks5 { .. } => Box::new(
                                            connect_via_proxy(
                                                proxy,
                                                url.host(),
                                                url.effective_port(),
                                                timeout_config,
                                                local_addr,
                                            )
                                            .await?,
                                        )
                                            as BoxedStream,
                                    };
                                    (s, ConnectTiming::default())
                                } else {
                                    let (stream, timing) = connect_tcp_target(
                                        &dns_cache,
                                        dns_config,
                                        &url,
                                        timeout_config,
                                        local_addr,
                                    )
                                    .await?;
                                    (Box::new(stream) as BoxedStream, timing)
                                };
                                let retried = execute_http1_over_stream(
                                    stream,
                                    method,
                                    &url,
                                    timeout_config,
                                    compression_mode,
                                    progress_callback.clone(),
                                    progress_config,
                                    &request_head,
                                    BodyData::Bytes(bytes.clone()),
                                    write_mode,
                                    request_allows_reuse,
                                    false,
                                    timing,
                                    Arc::clone(&pool),
                                    pool_key.clone(),
                                    pool_config,
                                )
                                .await;
                                match retried {
                                    Ok(response) => break Ok(response),
                                    Err(err)
                                        if should_retry_request(
                                            method,
                                            &err,
                                            remaining_attempts,
                                        ) =>
                                    {
                                        remaining_attempts -= 1;
                                        backoff(attempt_index).await;
                                        attempt_index += 1;
                                        continue;
                                    }
                                    Err(err) => break Err(err),
                                }
                            }
                            Err(err) if should_retry_request(method, &err, remaining_attempts) => {
                                remaining_attempts -= 1;
                                backoff(attempt_index).await;
                                attempt_index += 1;
                                continue;
                            }
                            Err(err) => break Err(err),
                        }
                    }
                }
                "https" => {
                    let mut remaining_attempts = attempts;
                    let mut attempt_index: usize = 0;
                    loop {
                        let result = execute_https_request(
                            method,
                            &url,
                            &tls_config,
                            protocol_policy,
                            timeout_config,
                            progress_callback.clone(),
                            progress_config,
                            &request_head,
                            BodyData::Bytes(bytes.clone()),
                            write_mode,
                            compression_mode,
                            request_allows_reuse,
                            Arc::clone(&pool),
                            pool_key.clone(),
                            Arc::clone(&dns_cache),
                            dns_config,
                            local_addr,
                            pool_config,
                            proxy.as_ref(),
                        )
                        .await;
                        match result {
                            Ok(response) => break Ok(response),
                            Err(err) if should_retry_request(method, &err, remaining_attempts) => {
                                remaining_attempts -= 1;
                                backoff(attempt_index).await;
                                attempt_index += 1;
                                continue;
                            }
                            Err(err) => break Err(err),
                        }
                    }
                }
                _ => Err(Error::new(
                    ErrorKind::Transport,
                    format!("unsupported scheme for current transport: {}", url.scheme()),
                )),
            }
        }
        // Streaming bodies cannot be replayed, so no retry loop here.
        BodyData::Stream(body_stream) => match url.scheme() {
            "http" => {
                let (stream, reused, timing) = if let Some(stream) = pool
                    .lock()
                    .unwrap_or_else(|err| err.into_inner())
                    .checkout(&pool_key, pool_config)
                {
                    (stream, true, ConnectTiming::default())
                } else {
                    let (stream, timing) = if let Some(proxy) = &proxy {
                        let s = match proxy {
                            Proxy::Http { .. } => Box::new(
                                with_timeout_io(
                                    timeout_config.connect,
                                    TcpStream::connect(proxy.addr()),
                                    "proxy connect timed out",
                                )
                                .await?,
                            ) as BoxedStream,
                            Proxy::Socks5 { .. } => Box::new(
                                connect_via_proxy(
                                    proxy,
                                    url.host(),
                                    url.effective_port(),
                                    timeout_config,
                                    local_addr,
                                )
                                .await?,
                            ) as BoxedStream,
                        };
                        (s, ConnectTiming::default())
                    } else {
                        let (stream, timing) = connect_tcp_target(
                            &dns_cache,
                            dns_config,
                            &url,
                            timeout_config,
                            local_addr,
                        )
                        .await?;
                        (Box::new(stream) as BoxedStream, timing)
                    };
                    (stream, false, timing)
                };
                execute_http1_over_stream(
                    stream,
                    method,
                    &url,
                    timeout_config,
                    compression_mode,
                    progress_callback,
                    progress_config,
                    &request_head,
                    BodyData::Stream(body_stream),
                    write_mode,
                    request_allows_reuse,
                    reused,
                    timing,
                    Arc::clone(&pool),
                    pool_key,
                    pool_config,
                )
                .await
            }
            "https" => {
                execute_https_request(
                    method,
                    &url,
                    &tls_config,
                    protocol_policy,
                    timeout_config,
                    progress_callback,
                    progress_config,
                    &request_head,
                    BodyData::Stream(body_stream),
                    write_mode,
                    compression_mode,
                    request_allows_reuse,
                    Arc::clone(&pool),
                    pool_key,
                    dns_cache,
                    dns_config,
                    local_addr,
                    pool_config,
                    proxy.as_ref(),
                )
                .await
            }
            _ => Err(Error::new(
                ErrorKind::Transport,
                format!("unsupported scheme for current transport: {}", url.scheme()),
            )),
        },
    };
    // Remember Alt-Svc advertisements so future requests can prefer h3.
    if let Ok(ref response) = result {
        if let Some(alt_svc_value) = response.headers().get("alt-svc") {
            alt_svc_cache.record(url.host(), url.effective_port(), alt_svc_value);
        }
    }
    result
}
/// Dispatches an HTTPS request to the TLS backend selected in `tls_config`,
/// falling back to a compile-time default when none is configured.
#[cfg(any(feature = "rustls", feature = "native-tls", feature = "btls-backend"))]
async fn execute_https_request(
    method: Method,
    url: &Url,
    tls_config: &TlsConfig,
    protocol_policy: ProtocolPolicy,
    timeout_config: TimeoutConfig,
    progress_callback: Option<ProgressCallback>,
    progress_config: crate::progress::ProgressConfig,
    request_head: &[u8],
    request_body: BodyData,
    write_mode: BodyWriteMode,
    compression_mode: crate::CompressionMode,
    request_allows_reuse: bool,
    pool: Arc<Mutex<ConnectionPool>>,
    pool_key: PoolKey,
    dns_cache: Arc<DnsCache>,
    dns_config: DnsConfig,
    local_addr: Option<SocketAddr>,
    pool_config: PoolConfig,
    proxy: Option<&Proxy>,
) -> Result<Response> {
    // Preference order when no backend is configured: rustls, then
    // native-tls, then boring. Exactly one branch survives cfg expansion
    // (the enclosing fn requires at least one of the three features).
    fn default_backend() -> TlsBackend {
        #[cfg(feature = "rustls")]
        {
            return TlsBackend::Rustls;
        }
        #[cfg(all(not(feature = "rustls"), feature = "native-tls"))]
        {
            return TlsBackend::Native;
        }
        #[cfg(all(
            not(feature = "rustls"),
            not(feature = "native-tls"),
            feature = "btls-backend"
        ))]
        {
            return TlsBackend::Boring;
        }
    }
    let backend = tls_config.backend.unwrap_or_else(default_backend);
    match backend {
        #[cfg(feature = "rustls")]
        TlsBackend::Rustls => {
            execute_https_request_async_tls(
                url,
                tls_config,
                protocol_policy,
                timeout_config,
                progress_callback,
                progress_config,
                request_head,
                request_body,
                write_mode,
                compression_mode,
                method,
                request_allows_reuse,
                pool,
                pool_key,
                dns_cache,
                dns_config,
                local_addr,
                pool_config,
                proxy,
            )
            .await
        }
        #[cfg(feature = "native-tls")]
        TlsBackend::Native => {
            execute_https_request_native(
                url,
                tls_config,
                protocol_policy,
                timeout_config,
                progress_callback,
                progress_config,
                request_head,
                request_body,
                write_mode,
                compression_mode,
                method,
                request_allows_reuse,
                pool,
                pool_key,
                dns_cache,
                dns_config,
                local_addr,
                pool_config,
                proxy,
            )
            .await
        }
        #[cfg(feature = "btls-backend")]
        TlsBackend::Boring => {
            execute_https_request_boring(
                url,
                tls_config,
                protocol_policy,
                timeout_config,
                progress_callback,
                progress_config,
                request_head,
                request_body,
                write_mode,
                compression_mode,
                method,
                request_allows_reuse,
                pool,
                pool_key,
                dns_cache,
                dns_config,
                local_addr,
                pool_config,
                proxy,
            )
            .await
        }
    }
}
#[cfg(all(
not(feature = "rustls"),
not(feature = "native-tls"),
not(feature = "btls-backend")
))]
async fn execute_https_request(
_method: Method,
_url: &Url,
_tls_config: &crate::tls::TlsConfig,
_protocol_policy: ProtocolPolicy,
_timeout_config: TimeoutConfig,
_progress_callback: Option<ProgressCallback>,
_request_head: &[u8],
_request_body: BodyData,
_write_mode: BodyWriteMode,
_compression_mode: crate::CompressionMode,
_request_allows_reuse: bool,
_pool: Arc<Mutex<ConnectionPool>>,
_pool_key: PoolKey,
_dns_cache: Arc<DnsCache>,
_dns_config: DnsConfig,
_local_addr: Option<SocketAddr>,
_pool_config: PoolConfig,
_proxy: Option<&Proxy>,
) -> Result<Response> {
Err(Error::new(
ErrorKind::Transport,
"https support requires a tls feature",
))
}
/// HTTPS via the rustls (async-tls) backend: checks out or establishes a TLS
/// connection (optionally tunneled through a proxy), sends the request, and
/// replays buffered bodies once on a fresh connection if a pooled one was
/// stale.
#[cfg(feature = "rustls")]
async fn execute_https_request_async_tls(
    url: &Url,
    tls_config: &TlsConfig,
    protocol_policy: ProtocolPolicy,
    timeout_config: TimeoutConfig,
    progress_callback: Option<ProgressCallback>,
    progress_config: crate::progress::ProgressConfig,
    request_head: &[u8],
    request_body: BodyData,
    write_mode: BodyWriteMode,
    compression_mode: crate::CompressionMode,
    method: Method,
    request_allows_reuse: bool,
    pool: Arc<Mutex<ConnectionPool>>,
    pool_key: PoolKey,
    dns_cache: Arc<DnsCache>,
    dns_config: DnsConfig,
    local_addr: Option<SocketAddr>,
    pool_config: PoolConfig,
    proxy: Option<&Proxy>,
) -> Result<Response> {
    // Prefer an idle pooled connection; otherwise dial and handshake,
    // recording per-phase timings for metrics.
    let (stream, reused, timing) = if let Some(stream) = pool
        .lock()
        .unwrap_or_else(|err| err.into_inner())
        .checkout(&pool_key, pool_config)
    {
        (stream, true, ConnectTiming::default())
    } else {
        let (tcp_stream, mut timing) = if let Some(proxy) = proxy {
            let s = match proxy {
                // HTTPS through an HTTP proxy needs a CONNECT tunnel.
                Proxy::Http { addr, auth } => {
                    connect_via_http_proxy_tunnel(
                        addr,
                        auth,
                        url.host(),
                        url.effective_port(),
                        timeout_config,
                        local_addr,
                    )
                    .await?
                }
                Proxy::Socks5 { .. } => {
                    connect_via_proxy(
                        proxy,
                        url.host(),
                        url.effective_port(),
                        timeout_config,
                        local_addr,
                    )
                    .await?
                }
            };
            // Proxy dials do not report DNS/connect timings.
            (s, ConnectTiming::default())
        } else {
            let (stream, timing) =
                connect_tcp_target(&dns_cache, dns_config, url, timeout_config, local_addr).await?;
            (stream, timing)
        };
        let tls_started = Instant::now();
        let tls_stream =
            connect_tls(tcp_stream, url, tls_config, protocol_policy, timeout_config).await?;
        timing.tls = Some(tls_started.elapsed());
        (Box::new(tls_stream) as BoxedStream, false, timing)
    };
    match request_body {
        BodyData::Bytes(bytes) => {
            let result = execute_http1_over_stream(
                stream,
                method,
                url,
                timeout_config,
                compression_mode,
                progress_callback.clone(),
                progress_config,
                request_head,
                BodyData::Bytes(bytes.clone()),
                write_mode,
                request_allows_reuse,
                reused,
                timing,
                Arc::clone(&pool),
                pool_key.clone(),
                pool_config,
            )
            .await;
            match result {
                Ok(response) => Ok(response),
                // Pooled connection died under us: replay the buffered body
                // once over a brand-new (proxy-aware) connection.
                Err(err) if reused && should_retry_stale_connection(method, &err) => {
                    let (tcp_stream, mut retry_timing) = if let Some(proxy) = proxy {
                        let s = match proxy {
                            Proxy::Http { addr, auth } => {
                                connect_via_http_proxy_tunnel(
                                    addr,
                                    auth,
                                    url.host(),
                                    url.effective_port(),
                                    timeout_config,
                                    local_addr,
                                )
                                .await?
                            }
                            Proxy::Socks5 { .. } => {
                                connect_via_proxy(
                                    proxy,
                                    url.host(),
                                    url.effective_port(),
                                    timeout_config,
                                    local_addr,
                                )
                                .await?
                            }
                        };
                        (s, ConnectTiming::default())
                    } else {
                        let (stream, timing) = connect_tcp_target(
                            &dns_cache,
                            dns_config,
                            url,
                            timeout_config,
                            local_addr,
                        )
                        .await?;
                        (stream, timing)
                    };
                    let tls_started = Instant::now();
                    let stream =
                        connect_tls(tcp_stream, url, tls_config, protocol_policy, timeout_config)
                            .await?;
                    retry_timing.tls = Some(tls_started.elapsed());
                    execute_http1_over_stream(
                        Box::new(stream) as BoxedStream,
                        method,
                        url,
                        timeout_config,
                        compression_mode,
                        progress_callback,
                        progress_config,
                        request_head,
                        BodyData::Bytes(bytes),
                        write_mode,
                        request_allows_reuse,
                        false,
                        retry_timing,
                        Arc::clone(&pool),
                        pool_key,
                        pool_config,
                    )
                    .await
                }
                Err(err) => Err(err),
            }
        }
        // Streaming bodies cannot be replayed, so no stale-connection retry.
        BodyData::Stream(stream_body) => {
            execute_http1_over_stream(
                stream,
                method,
                url,
                timeout_config,
                compression_mode,
                progress_callback,
                progress_config,
                request_head,
                BodyData::Stream(stream_body),
                write_mode,
                request_allows_reuse,
                reused,
                timing,
                Arc::clone(&pool),
                pool_key,
                pool_config,
            )
            .await
        }
    }
}
/// HTTPS via the native-tls backend. Mirrors the rustls path, plus
/// certificate-pinning verification against the peer's DER certificate.
#[cfg(feature = "native-tls")]
async fn execute_https_request_native(
    url: &Url,
    tls_config: &TlsConfig,
    protocol_policy: ProtocolPolicy,
    timeout_config: TimeoutConfig,
    progress_callback: Option<ProgressCallback>,
    progress_config: crate::progress::ProgressConfig,
    request_head: &[u8],
    request_body: BodyData,
    write_mode: BodyWriteMode,
    compression_mode: crate::CompressionMode,
    method: Method,
    request_allows_reuse: bool,
    pool: Arc<Mutex<ConnectionPool>>,
    pool_key: PoolKey,
    dns_cache: Arc<DnsCache>,
    dns_config: DnsConfig,
    local_addr: Option<SocketAddr>,
    pool_config: PoolConfig,
    proxy: Option<&Proxy>,
) -> Result<Response> {
    let (stream, reused, timing) = if let Some(stream) = pool
        .lock()
        .unwrap_or_else(|err| err.into_inner())
        .checkout(&pool_key, pool_config)
    {
        (stream, true, ConnectTiming::default())
    } else {
        let (tcp_stream, mut timing) = if let Some(proxy) = proxy {
            let s = match proxy {
                Proxy::Http { addr, auth } => {
                    connect_via_http_proxy_tunnel(
                        addr,
                        auth,
                        url.host(),
                        url.effective_port(),
                        timeout_config,
                        local_addr,
                    )
                    .await?
                }
                Proxy::Socks5 { .. } => {
                    connect_via_proxy(
                        proxy,
                        url.host(),
                        url.effective_port(),
                        timeout_config,
                        local_addr,
                    )
                    .await?
                }
            };
            (s, ConnectTiming::default())
        } else {
            let (stream, timing) =
                connect_tcp_target(&dns_cache, dns_config, url, timeout_config, local_addr).await?;
            (stream, timing)
        };
        let connector = build_native_tls_connector(tls_config, protocol_policy)?;
        let tls_started = Instant::now();
        // Outer Result: the timeout wrapper; inner Result: the handshake.
        let stream = with_timeout(
            timeout_config.connect,
            connector.connect(url.host(), tcp_stream),
            ErrorKind::Timeout,
            "tls handshake timed out",
        )
        .await
        .map_err(|err| {
            if err.kind() == &ErrorKind::Timeout {
                err
            } else {
                Error::with_source(ErrorKind::Transport, "failed tls handshake", err)
            }
        })?
        .map_err(|err| Error::with_source(ErrorKind::Transport, "failed tls handshake", err))?;
        timing.tls = Some(tls_started.elapsed());
        // Enforce certificate pinning against the DER-encoded peer cert.
        if let Some(cert) = stream.peer_certificate().map_err(|err| {
            Error::with_source(
                ErrorKind::Transport,
                "failed to inspect peer certificate",
                err,
            )
        })? {
            let cert_der = cert.to_der().map_err(|err| {
                Error::with_source(
                    ErrorKind::Transport,
                    "failed to serialize peer certificate",
                    err,
                )
            })?;
            verify_pinned_certificate(&tls_config.pinned_certificates, url.host(), &cert_der)?;
        }
        (Box::new(stream) as BoxedStream, false, timing)
    };
    match request_body {
        BodyData::Bytes(bytes) => {
            let result = execute_http1_over_stream(
                stream,
                method,
                url,
                timeout_config,
                compression_mode,
                progress_callback.clone(),
                progress_config,
                request_head,
                BodyData::Bytes(bytes.clone()),
                write_mode,
                request_allows_reuse,
                reused,
                timing,
                Arc::clone(&pool),
                pool_key.clone(),
                pool_config,
            )
            .await;
            match result {
                Ok(response) => Ok(response),
                Err(err) if reused && should_retry_stale_connection(method, &err) => {
                    // NOTE(review): unlike the rustls path, this stale-retry
                    // reconnects directly via connect_tcp_target and ignores
                    // `proxy` — confirm whether retries should re-tunnel
                    // through the configured proxy instead.
                    let (tcp_stream, mut retry_timing) =
                        connect_tcp_target(&dns_cache, dns_config, url, timeout_config, local_addr)
                            .await?;
                    let connector = build_native_tls_connector(tls_config, protocol_policy)?;
                    let tls_started = Instant::now();
                    let stream = with_timeout(
                        timeout_config.connect,
                        connector.connect(url.host(), tcp_stream),
                        ErrorKind::Timeout,
                        "tls handshake timed out",
                    )
                    .await
                    .map_err(|err| {
                        if err.kind() == &ErrorKind::Timeout {
                            err
                        } else {
                            Error::with_source(ErrorKind::Transport, "failed tls handshake", err)
                        }
                    })?
                    .map_err(|err| {
                        Error::with_source(ErrorKind::Transport, "failed tls handshake", err)
                    })?;
                    retry_timing.tls = Some(tls_started.elapsed());
                    if let Some(cert) = stream.peer_certificate().map_err(|err| {
                        Error::with_source(
                            ErrorKind::Transport,
                            "failed to inspect peer certificate",
                            err,
                        )
                    })? {
                        let cert_der = cert.to_der().map_err(|err| {
                            Error::with_source(
                                ErrorKind::Transport,
                                "failed to serialize peer certificate",
                                err,
                            )
                        })?;
                        verify_pinned_certificate(
                            &tls_config.pinned_certificates,
                            url.host(),
                            &cert_der,
                        )?;
                    }
                    execute_http1_over_stream(
                        Box::new(stream) as BoxedStream,
                        method,
                        url,
                        timeout_config,
                        compression_mode,
                        progress_callback,
                        progress_config,
                        request_head,
                        BodyData::Bytes(bytes),
                        write_mode,
                        request_allows_reuse,
                        false,
                        retry_timing,
                        Arc::clone(&pool),
                        pool_key,
                        pool_config,
                    )
                    .await
                }
                Err(err) => Err(err),
            }
        }
        BodyData::Stream(stream_body) => {
            execute_http1_over_stream(
                stream,
                method,
                url,
                timeout_config,
                compression_mode,
                progress_callback,
                progress_config,
                request_head,
                BodyData::Stream(stream_body),
                write_mode,
                request_allows_reuse,
                reused,
                timing,
                Arc::clone(&pool),
                pool_key,
                pool_config,
            )
            .await
        }
    }
}
/// Performs a BoringSSL TLS handshake over `stream` for `url`'s host, with
/// ALPN protocols derived from the HTTP/1 protocol policy.
#[cfg(feature = "btls-backend")]
async fn connect_tls_boring(
    stream: TcpStream,
    url: &Url,
    tls_config: &TlsConfig,
    protocol_policy: ProtocolPolicy,
    timeout_config: TimeoutConfig,
) -> Result<crate::connection::tls::BoringTlsStream> {
    // Validate the policy and derive the ALPN list, surfacing validation
    // failures as transport errors.
    let alpn_protocols = match tls_config.validate_http1_alpn(protocol_policy) {
        Ok(protocols) => protocols,
        Err(message) => return Err(Error::new(ErrorKind::Transport, message)),
    };
    let connector = crate::connection::tls::build_boring_tls_connector_for_protocols(
        tls_config,
        &alpn_protocols,
    )?;
    crate::connection::tls::connect_boring_tls_with_connector(
        stream,
        url.host(),
        connector,
        tls_config,
        timeout_config.connect,
    )
    .await
}
/// HTTPS via the BoringSSL backend. Mirrors the rustls path's pool checkout,
/// proxy handling, and stale-connection replay for buffered bodies.
#[cfg(feature = "btls-backend")]
async fn execute_https_request_boring(
    url: &Url,
    tls_config: &TlsConfig,
    protocol_policy: ProtocolPolicy,
    timeout_config: TimeoutConfig,
    progress_callback: Option<ProgressCallback>,
    progress_config: crate::progress::ProgressConfig,
    request_head: &[u8],
    request_body: BodyData,
    write_mode: BodyWriteMode,
    compression_mode: crate::CompressionMode,
    method: Method,
    request_allows_reuse: bool,
    pool: Arc<Mutex<ConnectionPool>>,
    pool_key: PoolKey,
    dns_cache: Arc<DnsCache>,
    dns_config: DnsConfig,
    local_addr: Option<SocketAddr>,
    pool_config: PoolConfig,
    proxy: Option<&Proxy>,
) -> Result<Response> {
    let (stream, reused, timing) = if let Some(stream) = pool
        .lock()
        .unwrap_or_else(|err| err.into_inner())
        .checkout(&pool_key, pool_config)
    {
        (stream, true, ConnectTiming::default())
    } else {
        let (tcp_stream, mut timing) = if let Some(proxy) = proxy {
            let s = match proxy {
                Proxy::Http { addr, auth } => {
                    connect_via_http_proxy_tunnel(
                        addr,
                        auth,
                        url.host(),
                        url.effective_port(),
                        timeout_config,
                        local_addr,
                    )
                    .await?
                }
                Proxy::Socks5 { .. } => {
                    connect_via_proxy(
                        proxy,
                        url.host(),
                        url.effective_port(),
                        timeout_config,
                        local_addr,
                    )
                    .await?
                }
            };
            (s, ConnectTiming::default())
        } else {
            let (stream, timing) =
                connect_tcp_target(&dns_cache, dns_config, url, timeout_config, local_addr).await?;
            (stream, timing)
        };
        let tls_started = Instant::now();
        let tls_stream =
            connect_tls_boring(tcp_stream, url, tls_config, protocol_policy, timeout_config)
                .await?;
        timing.tls = Some(tls_started.elapsed());
        (Box::new(tls_stream) as BoxedStream, false, timing)
    };
    match request_body {
        BodyData::Bytes(bytes) => {
            let result = execute_http1_over_stream(
                stream,
                method,
                url,
                timeout_config,
                compression_mode,
                progress_callback.clone(),
                progress_config,
                request_head,
                BodyData::Bytes(bytes.clone()),
                write_mode,
                request_allows_reuse,
                reused,
                timing,
                Arc::clone(&pool),
                pool_key.clone(),
                pool_config,
            )
            .await;
            match result {
                Ok(response) => Ok(response),
                Err(err) if reused && should_retry_stale_connection(method, &err) => {
                    // NOTE(review): like the native-tls path, this retry
                    // reconnects directly and ignores `proxy` — confirm
                    // whether stale retries should go back through the proxy.
                    let (tcp_stream, mut retry_timing) =
                        connect_tcp_target(&dns_cache, dns_config, url, timeout_config, local_addr)
                            .await?;
                    let tls_started = Instant::now();
                    let stream = connect_tls_boring(
                        tcp_stream,
                        url,
                        tls_config,
                        protocol_policy,
                        timeout_config,
                    )
                    .await?;
                    retry_timing.tls = Some(tls_started.elapsed());
                    execute_http1_over_stream(
                        Box::new(stream) as BoxedStream,
                        method,
                        url,
                        timeout_config,
                        compression_mode,
                        progress_callback,
                        progress_config,
                        request_head,
                        BodyData::Bytes(bytes),
                        write_mode,
                        request_allows_reuse,
                        false,
                        retry_timing,
                        Arc::clone(&pool),
                        pool_key,
                        pool_config,
                    )
                    .await
                }
                Err(err) => Err(err),
            }
        }
        BodyData::Stream(stream_body) => {
            execute_http1_over_stream(
                stream,
                method,
                url,
                timeout_config,
                compression_mode,
                progress_callback,
                progress_config,
                request_head,
                BodyData::Stream(stream_body),
                write_mode,
                request_allows_reuse,
                reused,
                timing,
                Arc::clone(&pool),
                pool_key,
                pool_config,
            )
            .await
        }
    }
}
/// Performs a rustls (async-tls) handshake over `stream` for `url`'s host,
/// using a client config whose ALPN list reflects `protocol_policy`.
#[cfg(feature = "rustls")]
async fn connect_tls(
    stream: TcpStream,
    url: &Url,
    tls_config: &TlsConfig,
    protocol_policy: ProtocolPolicy,
    timeout_config: TimeoutConfig,
) -> Result<async_tls::client::TlsStream<TcpStream>> {
    // Configuration validation failures surface as transport errors.
    let client_config = match tls_config.build_client_config(protocol_policy) {
        Ok(config) => config,
        Err(message) => return Err(Error::new(ErrorKind::Transport, message)),
    };
    connect_async_tls_with_config(stream, url.host(), client_config, timeout_config.connect).await
}
/// Builds a native-tls connector whose ALPN list matches the HTTP/1 policy.
#[cfg(feature = "native-tls")]
fn build_native_tls_connector(
    tls_config: &TlsConfig,
    protocol_policy: ProtocolPolicy,
) -> Result<NativeTlsConnector> {
    // Reject policies incompatible with HTTP/1 ALPN as transport errors.
    let alpn = match tls_config.validate_http1_alpn(protocol_policy) {
        Ok(protocols) => protocols,
        Err(message) => return Err(Error::new(ErrorKind::Transport, message)),
    };
    build_native_tls_connector_for_protocols(tls_config, &alpn)
}
/// Writes an encoded HTTP/1.1 request (head + body) to `stream`, reads the
/// response, and parks the connection back in `pool` when both the request
/// headers and the response framing permit reuse.
///
/// `reused_connection` marks a pooled stream: a transport failure while
/// writing the head is converted to `StaleConnection` so callers can retry
/// on a fresh socket. TTFB and download timings come from `read_response`
/// (a previously dead local `Instant` that shadowed that measurement has
/// been removed).
async fn execute_http1_over_stream(
    mut stream: BoxedStream,
    method: Method,
    url: &Url,
    timeout_config: TimeoutConfig,
    compression_mode: crate::CompressionMode,
    progress_callback: Option<ProgressCallback>,
    progress_config: crate::progress::ProgressConfig,
    request_head: &[u8],
    request_body: BodyData,
    write_mode: BodyWriteMode,
    request_allows_reuse: bool,
    reused_connection: bool,
    connect_timing: ConnectTiming,
    pool: Arc<Mutex<ConnectionPool>>,
    pool_key: PoolKey,
    pool_config: PoolConfig,
) -> Result<Response> {
    let write_started = Instant::now();
    with_timeout_io(
        timeout_config.write,
        stream.write_all(request_head),
        "write timed out",
    )
    .await
    .map_err(|err| {
        // A transport error on a pooled connection usually means the server
        // closed it while idle; report it distinctly so callers can retry.
        if reused_connection && err.kind() == &ErrorKind::Transport {
            Error::new(
                ErrorKind::StaleConnection,
                "write on reused connection failed",
            )
        } else {
            err
        }
    })?;
    match request_body {
        BodyData::Bytes(bytes) => {
            // Buffered body: the total upload size is known up front.
            let mut upload_progress = progress_callback.clone().map(|callback| {
                ProgressReporter::new(
                    callback,
                    ProgressPhase::Upload,
                    Some(bytes.len()),
                    progress_config,
                )
            });
            if !bytes.is_empty() {
                match write_mode {
                    BodyWriteMode::Fixed => {
                        with_timeout_io(
                            timeout_config.write,
                            stream.write_all(&bytes),
                            "write timed out",
                        )
                        .await?;
                    }
                    BodyWriteMode::Chunked => {
                        // One chunk for the whole body, then the terminator.
                        let header = format!("{:X}\r\n", bytes.len());
                        with_timeout_io(
                            timeout_config.write,
                            stream.write_all(header.as_bytes()),
                            "write timed out",
                        )
                        .await?;
                        with_timeout_io(
                            timeout_config.write,
                            stream.write_all(&bytes),
                            "write timed out",
                        )
                        .await?;
                        with_timeout_io(
                            timeout_config.write,
                            stream.write_all(b"\r\n"),
                            "write timed out",
                        )
                        .await?;
                        with_timeout_io(
                            timeout_config.write,
                            stream.write_all(b"0\r\n\r\n"),
                            "write timed out",
                        )
                        .await?;
                    }
                }
            } else if write_mode == BodyWriteMode::Chunked {
                // An empty chunked body still needs the terminating chunk.
                with_timeout_io(
                    timeout_config.write,
                    stream.write_all(b"0\r\n\r\n"),
                    "write timed out",
                )
                .await?;
            }
            with_timeout_io(timeout_config.write, stream.flush(), "write timed out").await?;
            if let Some(progress) = &mut upload_progress {
                progress.record(bytes.len());
                progress.finish();
            }
        }
        BodyData::Stream(mut stream_body) => {
            // Streaming body: total size unknown, report progress per chunk.
            let mut upload_progress = progress_callback.clone().map(|callback| {
                ProgressReporter::new(callback, ProgressPhase::Upload, None, progress_config)
            });
            while let Some(chunk) = stream_body.next().await {
                let chunk = chunk?;
                // An empty chunk would be a premature chunked terminator.
                if chunk.is_empty() {
                    continue;
                }
                match write_mode {
                    BodyWriteMode::Fixed => {
                        with_timeout_io(
                            timeout_config.write,
                            stream.write_all(&chunk),
                            "write timed out",
                        )
                        .await?;
                    }
                    BodyWriteMode::Chunked => {
                        let header = format!("{:X}\r\n", chunk.len());
                        with_timeout_io(
                            timeout_config.write,
                            stream.write_all(header.as_bytes()),
                            "write timed out",
                        )
                        .await?;
                        with_timeout_io(
                            timeout_config.write,
                            stream.write_all(&chunk),
                            "write timed out",
                        )
                        .await?;
                        with_timeout_io(
                            timeout_config.write,
                            stream.write_all(b"\r\n"),
                            "write timed out",
                        )
                        .await?;
                    }
                }
                if let Some(progress) = &mut upload_progress {
                    progress.record(chunk.len());
                }
            }
            if write_mode == BodyWriteMode::Chunked {
                with_timeout_io(
                    timeout_config.write,
                    stream.write_all(b"0\r\n\r\n"),
                    "write timed out",
                )
                .await?;
            }
            with_timeout_io(timeout_config.write, stream.flush(), "write timed out").await?;
            if let Some(progress) = &mut upload_progress {
                progress.finish();
            }
        }
    }
    let write_duration = write_started.elapsed();
    let (
        status,
        response_headers,
        response_body,
        trailers,
        reusable,
        reusable_stream,
        ttfb,
        download_duration,
    ) = read_response(
        stream,
        method,
        compression_mode,
        timeout_config,
        progress_callback,
        progress_config,
    )
    .await?;
    // Park the connection only when both sides allow reuse and the reader
    // handed the stream back (i.e. the body was fully consumed).
    if request_allows_reuse && reusable {
        if let Some(stream) = reusable_stream {
            pool.lock().unwrap_or_else(|err| err.into_inner()).insert(
                pool_key,
                stream,
                pool_config,
            );
        }
    }
    Ok(Response::new_with_trailer_state(
        status,
        Version::Http11,
        url.clone(),
        response_headers,
        trailers,
        response_body,
    )
    .with_metrics(
        Metrics::default()
            .with_protocol(Version::Http11)
            .with_reused_connection(reused_connection)
            .with_dns_duration(connect_timing.dns)
            .with_connect_duration(connect_timing.connect)
            .with_tls_duration(connect_timing.tls)
            .with_request_write_duration(Some(write_duration))
            .with_ttfb(Some(ttfb))
            .with_download_duration(download_duration),
    ))
}
/// Executes an HTTP/1.1 request, following redirects per `policy`.
///
/// `RedirectPolicy::None` performs a single request and returns whatever
/// comes back, redirect or not. `RedirectPolicy::Limit(n)` follows up to
/// `n` redirects, rebuilding the request for each hop; exceeding the
/// limit yields a Transport error.
async fn execute_http1_with_redirect(
    request: Request,
    policy: RedirectPolicy,
    pool: Arc<Mutex<ConnectionPool>>,
    h2_pool: Arc<Mutex<h2::ConnectionPool>>,
    h3_pool: Arc<Mutex<h3::ConnectionPool>>,
    dns_cache: Arc<DnsCache>,
    dns_config: DnsConfig,
    local_addr: Option<SocketAddr>,
    pool_config: PoolConfig,
    alt_svc_cache: AltSvcCache,
) -> Result<Response> {
    // Fast path: no redirect following requested.
    if matches!(policy, RedirectPolicy::None) {
        return execute_http1(
            request,
            pool,
            h2_pool,
            h3_pool,
            dns_cache,
            dns_config,
            local_addr,
            pool_config,
            alt_svc_cache,
        )
        .await;
    }
    // The None arm is unreachable here (handled above) but keeps the match exhaustive.
    let limit = match policy {
        RedirectPolicy::None => 0,
        RedirectPolicy::Limit(limit) => limit,
    };
    let mut current_request = request;
    // The first iteration is the original request, so up to `limit` redirects follow.
    for _ in 0..=limit {
        // Snapshot everything needed to rebuild the request for a follow-up hop;
        // executing the request below consumes `current_request`.
        let original_method = current_request.method();
        let original_headers = current_request.headers().clone();
        let original_cookies = current_request.cookies().to_vec();
        let original_timeout_config = current_request.timeout_config();
        let original_protocol_policy = current_request.protocol_policy();
        let original_retry_policy = current_request.retry_policy();
        let original_prior_knowledge_h2c = current_request.prior_knowledge_h2c();
        let original_progress_callback = current_request.progress_callback().cloned();
        let original_progress_config = current_request.progress_config();
        let original_h2_keepalive_config = current_request.h2_keepalive_config();
        let original_tls_config = current_request.tls_config().clone();
        let original_proxy = current_request.proxy_cloned();
        let original_compression_mode = current_request.compression_mode();
        let original_browser_profile = current_request.emulation_profile().cloned();
        // Bodies are only replayable across hops when they can be cloned.
        let original_body = current_request.body().try_clone();
        let response = execute_http1(
            current_request,
            Arc::clone(&pool),
            Arc::clone(&h2_pool),
            Arc::clone(&h3_pool),
            Arc::clone(&dns_cache),
            dns_config,
            local_addr,
            pool_config,
            alt_svc_cache.clone(),
        )
        .await?;
        let status = response.status();
        if !status.is_redirect() {
            return Ok(response);
        }
        // A followable redirect must carry a Location header.
        let location = response.headers().get("location").ok_or_else(|| {
            Error::new(
                ErrorKind::Transport,
                "redirect response missing location header",
            )
        })?;
        // Location may be relative; resolve it against the response URL.
        let next_url = resolve_redirect_url(response.url(), location)?;
        let same_origin = same_origin(response.url(), &next_url);
        // 301/302/303 may switch to GET and drop the body; 307/308 keep both.
        let (next_method, preserve_body) = redirect_behavior(status, original_method);
        let next_body = if preserve_body {
            // Replaying the body requires that it was clonable up front.
            original_body.ok_or_else(|| {
                Error::new(
                    ErrorKind::Transport,
                    "redirect requires replayable request body",
                )
            })?
        } else {
            Body::empty()
        };
        // Cookies never cross origins (credential headers are stripped below too).
        let next_cookies = if same_origin {
            original_cookies
        } else {
            Vec::new()
        };
        current_request = Request::new(
            next_method,
            next_url,
            sanitize_redirect_headers(original_headers, same_origin),
            next_cookies,
            original_timeout_config,
            original_protocol_policy,
            original_retry_policy,
            original_prior_knowledge_h2c,
            original_progress_callback,
            original_progress_config,
            original_h2_keepalive_config,
            original_tls_config,
            original_proxy,
            original_compression_mode,
            next_body,
            original_browser_profile.clone(),
        );
    }
    Err(Error::new(ErrorKind::Transport, "redirect limit exceeded"))
}
/// How a request body is framed on the wire.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum BodyWriteMode {
    /// Length declared up front via `Content-Length`; bytes written as-is.
    Fixed,
    /// Body sent as `Transfer-Encoding: chunked` frames with a zero-length terminator.
    Chunked,
}
/// One serialized request header line.
#[derive(Clone, Debug, Eq, PartialEq)]
struct EncodedHeaderLine {
    /// Header name normalized to lowercase; used for ordering and casing lookups.
    lower_name: String,
    /// Header value exactly as it will be written to the wire.
    value: String,
}
/// Merges caller headers with client-generated ones and, when a
/// fingerprint supplies a header order, stable-sorts the result so listed
/// names come first (in the given order) while unlisted names keep their
/// original relative position.
fn ordered_headers(
    headers: &HeaderMap,
    generated_headers: Vec<EncodedHeaderLine>,
    order: Option<&[String]>,
) -> Vec<EncodedHeaderLine> {
    // Caller-supplied headers first, generated ones appended after.
    let mut lines: Vec<EncodedHeaderLine> = Vec::new();
    for (name, value) in headers.iter() {
        lines.push(EncodedHeaderLine {
            lower_name: name.as_str().to_owned(),
            value: value.as_str().to_owned(),
        });
    }
    lines.extend(generated_headers);
    if let Some(order) = order.filter(|order| !order.is_empty()) {
        // Names absent from the order list rank last (order.len()), and the
        // stable sort keeps their relative order intact.
        let rank = |line: &EncodedHeaderLine| {
            order
                .iter()
                .position(|target| target.eq_ignore_ascii_case(line.lower_name.as_str()))
                .unwrap_or(order.len())
        };
        lines.sort_by_key(rank);
    }
    lines
}
/// Appends one `name: value\r\n` header line to `buffer`.
///
/// When `original_header_case` is provided (browser-emulation fingerprint),
/// the recorded original casing for this name is used; otherwise the
/// lowercase name is written verbatim.
fn write_serialized_header(
    buffer: &mut Vec<u8>,
    lower_name: &str,
    value: &str,
    original_header_case: Option<&[(String, String)]>,
) {
    // Look up the fingerprint's preferred casing, case-insensitively.
    let rendered_name = match original_header_case {
        Some(pairs) => pairs
            .iter()
            .find_map(|(candidate, original)| {
                if candidate.eq_ignore_ascii_case(lower_name) {
                    Some(original.as_str())
                } else {
                    None
                }
            })
            .unwrap_or(lower_name),
        None => lower_name,
    };
    buffer.extend_from_slice(rendered_name.as_bytes());
    buffer.extend_from_slice(b": ");
    buffer.extend_from_slice(value.as_bytes());
    buffer.extend_from_slice(b"\r\n");
}
/// Computes the headers this client must generate itself (host, body
/// framing, accept-encoding, cookie) and decides how the body is written.
///
/// Returns the generated header lines, the body (passed through
/// unchanged), and the framing mode: `Fixed` (content-length) or
/// `Chunked` (transfer-encoding). Errors with `InvalidHeaderValue` when
/// the caller supplied contradictory or unsupported framing headers.
fn build_generated_headers(
    url: &Url,
    cookies: &[(String, String)],
    compression_mode: crate::CompressionMode,
    body: BodyData,
    headers: &HeaderMap,
    mut extra_headers: Vec<EncodedHeaderLine>,
) -> Result<(Vec<EncodedHeaderLine>, BodyData, BodyWriteMode)> {
    let mut generated_headers = Vec::new();
    // Scan caller headers once to learn which framing-relevant names exist.
    let mut has_host = false;
    let mut has_content_length = false;
    let mut has_accept_encoding = false;
    let mut transfer_encoding = None;
    for (name, value) in headers.iter() {
        match name.as_str() {
            "host" => has_host = true,
            "content-length" => has_content_length = true,
            "transfer-encoding" => transfer_encoding = Some(value.as_str().to_owned()),
            "accept-encoding" => has_accept_encoding = true,
            _ => {}
        }
    }
    // HTTP/1.1 requires Host; derive it from the URL authority when absent.
    if !has_host {
        generated_headers.push(EncodedHeaderLine {
            lower_name: "host".to_owned(),
            value: url.authority().to_owned(),
        });
    }
    generated_headers.append(&mut extra_headers);
    // Pick body framing. An explicit transfer-encoding wins but must include
    // chunked and must not be combined with content-length.
    let write_mode = match (&body, transfer_encoding.as_deref()) {
        (_, Some(value)) => {
            if has_content_length {
                return Err(Error::new(
                    ErrorKind::InvalidHeaderValue,
                    "content-length cannot be used with transfer-encoding",
                ));
            }
            if !transfer_encoding_is_chunked(value) {
                return Err(Error::new(
                    ErrorKind::InvalidHeaderValue,
                    "transfer-encoding must include chunked",
                ));
            }
            BodyWriteMode::Chunked
        }
        (BodyData::Stream(_), None) => {
            // A stream with a caller-provided content-length is written as-is;
            // otherwise chunked framing is generated for it.
            if has_content_length {
                BodyWriteMode::Fixed
            } else {
                generated_headers.push(EncodedHeaderLine {
                    lower_name: "transfer-encoding".to_owned(),
                    value: "chunked".to_owned(),
                });
                BodyWriteMode::Chunked
            }
        }
        (BodyData::Bytes(bytes), None) => {
            // In-memory bodies get an exact content-length when none was given.
            if !has_content_length {
                generated_headers.push(EncodedHeaderLine {
                    lower_name: "content-length".to_owned(),
                    value: bytes.len().to_string(),
                });
            }
            BodyWriteMode::Fixed
        }
    };
    // Advertise supported decodings unless the caller set their own.
    if !has_accept_encoding && compression_mode.should_add_accept_encoding() {
        generated_headers.push(EncodedHeaderLine {
            lower_name: "accept-encoding".to_owned(),
            value: DEFAULT_ACCEPT_ENCODING.to_owned(),
        });
    }
    // Fold all cookie pairs into a single Cookie header ("a=1; b=2").
    if !cookies.is_empty() {
        let cookie_value = cookies
            .iter()
            .map(|(name, value)| format!("{name}={value}"))
            .collect::<Vec<_>>()
            .join("; ");
        generated_headers.push(EncodedHeaderLine {
            lower_name: "cookie".to_owned(),
            value: cookie_value,
        });
    }
    Ok((generated_headers, body, write_mode))
}
/// Serializes an origin-form HTTP/1.1 request head (request line plus
/// headers plus the blank separator line) into bytes.
///
/// Returns the encoded head, the body to write afterwards, and the framing
/// mode chosen by `build_generated_headers`.
fn encode_request(
    method: &str,
    url: &Url,
    headers: &HeaderMap,
    cookies: &[(String, String)],
    compression_mode: crate::CompressionMode,
    body: BodyData,
    http1_fingerprint: Option<&Http1Fingerprint>,
) -> Result<(Vec<u8>, BodyData, BodyWriteMode)> {
    // Fingerprint data (when present) drives header ordering and casing.
    let header_order = http1_fingerprint.map(|fp| fp.header_order.as_slice());
    let original_header_case =
        http1_fingerprint.map(|fp| fp.original_header_case.as_slice());
    let (generated_headers, body, write_mode) =
        build_generated_headers(url, cookies, compression_mode, body, headers, Vec::new())?;
    // Request line uses the origin-form target: path plus query.
    let mut buffer = format!("{} {} HTTP/1.1\r\n", method, url.path_and_query()).into_bytes();
    for line in ordered_headers(headers, generated_headers, header_order) {
        write_serialized_header(&mut buffer, &line.lower_name, &line.value, original_header_case);
    }
    // Blank line terminates the header section.
    buffer.extend_from_slice(b"\r\n");
    Ok((buffer, body, write_mode))
}
/// Serializes an HTTP/1.1 request head for delivery through a forward
/// proxy: the target is written in absolute form (the full URL) and a
/// `Proxy-Authorization` header is generated when the proxy has
/// credentials.
fn encode_proxy_request(
    method: &str,
    url: &Url,
    headers: &HeaderMap,
    cookies: &[(String, String)],
    proxy: &Proxy,
    compression_mode: crate::CompressionMode,
    body: BodyData,
    http1_fingerprint: Option<&Http1Fingerprint>,
) -> Result<(Vec<u8>, BodyData, BodyWriteMode)> {
    let header_order = http1_fingerprint.map(|fp| fp.header_order.as_slice());
    let original_header_case =
        http1_fingerprint.map(|fp| fp.original_header_case.as_slice());
    // Only an HTTP proxy with credentials contributes an extra header.
    let extra_headers = if let Proxy::Http {
        auth: Some(auth), ..
    } = proxy
    {
        let credentials = format!("{}:{}", auth.username, auth.password);
        vec![EncodedHeaderLine {
            lower_name: "proxy-authorization".to_owned(),
            value: format!("Basic {}", base64_encode(credentials.as_bytes())),
        }]
    } else {
        Vec::new()
    };
    let (generated_headers, body, write_mode) =
        build_generated_headers(url, cookies, compression_mode, body, headers, extra_headers)?;
    // Proxied requests use the absolute-form target (the complete URL).
    let mut buffer = format!("{} {} HTTP/1.1\r\n", method, url.as_str()).into_bytes();
    for line in ordered_headers(headers, generated_headers, header_order) {
        write_serialized_header(&mut buffer, &line.lower_name, &line.value, original_header_case);
    }
    buffer.extend_from_slice(b"\r\n");
    Ok((buffer, body, write_mode))
}
/// Returns true when a `Transfer-Encoding` header value lists `chunked`
/// among its comma-separated codings (case-insensitive, whitespace
/// tolerant).
fn transfer_encoding_is_chunked(value: &str) -> bool {
    for coding in value.split(',') {
        if coding.trim().eq_ignore_ascii_case("chunked") {
            return true;
        }
    }
    false
}
/// Reads a response body of declared `content_length` bytes, counting any
/// bytes already buffered in `body`, and reports download progress.
///
/// Stops early (without error) if the peer closes the stream before the
/// declared length is reached; the caller receives a short body.
async fn read_length_body_bytes<S>(
    stream: &mut S,
    mut body: Vec<u8>,
    timeout_config: TimeoutConfig,
    progress_callback: Option<ProgressCallback>,
    progress_config: crate::progress::ProgressConfig,
    content_length: usize,
) -> Result<Vec<u8>>
where
    S: AsyncRead + AsyncWrite + Unpin + ?Sized,
{
    // Cap each scratch allocation so a huge (server-controlled)
    // content-length cannot force a giant up-front allocation before any
    // body bytes have actually arrived.
    const MAX_READ_CHUNK: usize = 64 * 1024;
    let mut progress = progress_callback.map(|callback| {
        ProgressReporter::new(
            callback,
            ProgressPhase::Download,
            Some(content_length),
            progress_config,
        )
    });
    // Bytes read ahead of this call (while parsing headers) count as progress.
    if let Some(progress) = &mut progress {
        if !body.is_empty() {
            progress.record(body.len());
        }
    }
    while body.len() < content_length {
        let remaining = content_length - body.len();
        let mut chunk = vec![0; remaining.min(MAX_READ_CHUNK)];
        let read = with_timeout_io(
            timeout_config.read,
            stream.read(&mut chunk),
            "read timed out",
        )
        .await?;
        if read == 0 {
            // EOF before the declared length: return what was received.
            break;
        }
        body.extend_from_slice(&chunk[..read]);
        if let Some(progress) = &mut progress {
            progress.record(read);
        }
    }
    if let Some(progress) = &mut progress {
        progress.finish();
    }
    Ok(body)
}
/// Reads the remainder of `stream` until EOF, appending to any bytes
/// already present in `body`, and reports download progress (no total is
/// known, so the reporter gets no length hint).
async fn read_to_eof_bytes<S>(
    stream: &mut S,
    mut body: Vec<u8>,
    timeout_config: TimeoutConfig,
    progress_callback: Option<ProgressCallback>,
    progress_config: crate::progress::ProgressConfig,
) -> Result<Vec<u8>>
where
    S: AsyncRead + AsyncWrite + Unpin + ?Sized,
{
    let mut reporter = progress_callback.map(|callback| {
        ProgressReporter::new(callback, ProgressPhase::Download, None, progress_config)
    });
    // Credit bytes that were already buffered before this call.
    if !body.is_empty() {
        if let Some(reporter) = &mut reporter {
            reporter.record(body.len());
        }
    }
    let mut scratch = [0_u8; 4096];
    loop {
        let read = with_timeout_io(
            timeout_config.read,
            stream.read(&mut scratch),
            "read timed out",
        )
        .await?;
        if read == 0 {
            break;
        }
        body.extend_from_slice(&scratch[..read]);
        if let Some(reporter) = &mut reporter {
            reporter.record(read);
        }
    }
    if let Some(reporter) = &mut reporter {
        reporter.finish();
    }
    Ok(body)
}
/// Wraps `stream` in a lazy `Body` that yields bytes until EOF.
///
/// `initial` holds body bytes already read while parsing the response
/// head; they are yielded first. Progress is reported as chunks are
/// produced, with no total-length hint.
fn read_to_eof_body_stream(
    mut stream: BoxedStream,
    initial: Vec<u8>,
    timeout_config: TimeoutConfig,
    progress_callback: Option<ProgressCallback>,
    progress_config: crate::progress::ProgressConfig,
) -> Body {
    let stream_body = async_stream::try_stream! {
        let mut progress = progress_callback.map(|callback| {
            ProgressReporter::new(callback, ProgressPhase::Download, None, progress_config)
        });
        // Emit the pre-buffered prefix before touching the socket.
        if !initial.is_empty() {
            if let Some(progress) = &mut progress {
                progress.record(initial.len());
            }
            yield Bytes::from(initial);
        }
        loop {
            // Fresh allocation per chunk: the buffer is handed off via Bytes::from.
            let mut chunk = vec![0_u8; 8192];
            let read = with_timeout_io(
                timeout_config.read,
                stream.read(&mut chunk),
                "read timed out",
            )
            .await?;
            if read == 0 {
                break;
            }
            chunk.truncate(read);
            if let Some(progress) = &mut progress {
                progress.record(read);
            }
            yield Bytes::from(chunk);
        }
        if let Some(progress) = &mut progress {
            progress.finish();
        }
    };
    Body::from_stream(Box::pin(stream_body))
}
/// Reads and parses one HTTP/1.1 response from `stream`.
///
/// Returns the status, headers, body, trailer state, whether the
/// connection may be pooled, the stream itself (only when reusable),
/// time-to-first-byte, and the download duration (`None` when the body is
/// returned as a lazy stream and is still being downloaded).
async fn read_response(
    mut stream: BoxedStream,
    method: Method,
    compression_mode: crate::CompressionMode,
    timeout_config: TimeoutConfig,
    progress_callback: Option<ProgressCallback>,
    progress_config: crate::progress::ProgressConfig,
) -> Result<(
    StatusCode,
    HeaderMap,
    Body,
    TrailerState,
    bool,
    Option<BoxedStream>,
    std::time::Duration, // ttfb (time from start of read_response to first response byte)
    Option<std::time::Duration>, // download_duration (None for streaming bodies)
)> {
    let read_started = Instant::now();
    let mut buffer = Vec::new();
    // Outer loop skips 1xx informational responses (except 101) until a
    // final head has been parsed; body_prefix is whatever body bytes were
    // read past the header terminator.
    let (status, mut headers, body_prefix) = loop {
        // Accumulate bytes until the CRLFCRLF header terminator appears.
        let mut header_end = find_headers_end(&buffer);
        while header_end.is_none() {
            let mut chunk = [0_u8; 1024];
            let read = with_timeout_io(
                timeout_config.read,
                stream.read(&mut chunk),
                "read timed out",
            )
            .await?;
            if read == 0 {
                break;
            }
            buffer.extend_from_slice(&chunk[..read]);
            header_end = find_headers_end(&buffer);
        }
        // EOF with nothing received means the pooled connection was already
        // dead (stale); such requests can be retried on a fresh connection.
        let header_end = header_end.ok_or_else(|| {
            let kind = if buffer.is_empty() {
                ErrorKind::StaleConnection
            } else {
                ErrorKind::Transport
            };
            Error::new(kind, "response headers are incomplete")
        })?;
        let (head, rest) = buffer.split_at(header_end);
        // Skip the CRLFCRLF terminator itself; `rest` is the body prefix.
        let rest = rest[4..].to_vec();
        let head_text = std::str::from_utf8(head).map_err(|err| {
            Error::with_source(
                ErrorKind::Decode,
                "response headers are not valid utf-8",
                err,
            )
        })?;
        let mut lines = head_text.split("\r\n");
        let status_line = lines
            .next()
            .ok_or_else(|| Error::new(ErrorKind::Transport, "missing status line"))?;
        let status = parse_status_line(status_line)?;
        let mut headers = HeaderMap::new();
        for line in lines {
            if line.is_empty() {
                continue;
            }
            let (name, value) = line.split_once(':').ok_or_else(|| {
                Error::new(ErrorKind::Transport, format!("invalid header line: {line}"))
            })?;
            headers.append(name.trim(), value.trim())?;
        }
        // 1xx interim responses are discarded and parsing restarts on the
        // leftover bytes; 101 (switching protocols) is surfaced as final.
        if (100..200).contains(&status.as_u16()) && status.as_u16() != 101 {
            buffer = rest;
            continue;
        }
        break (status, headers, rest);
    };
    let ttfb = read_started.elapsed();
    // "Connection: close" (any of possibly several values) forbids reuse.
    let response_allows_reuse = !headers
        .get_all("connection")
        .iter()
        .any(|value| value.eq_ignore_ascii_case("close"));
    // Bodiless responses (per method/status) can return immediately; the
    // stream is only reusable if no stray body bytes were received.
    if !response_body_allowed(method, status) {
        let reusable = response_allows_reuse && body_prefix.is_empty();
        return Ok((
            status,
            headers,
            Body::empty(),
            TrailerState::Ready(None),
            reusable,
            if reusable { Some(stream) } else { None },
            ttfb,
            Some(std::time::Duration::ZERO),
        ));
    }
    let has_content_encoding = headers.get("content-encoding").is_some();
    let download_started = Instant::now();
    // Body strategy: chunked bodies with content-encoding are buffered so
    // they can be decoded; plain chunked bodies stream lazily (trailers are
    // deferred, connection not returned to the pool). Fixed-length bodies
    // are buffered; everything else is read to EOF.
    let (body, trailers, reusable_by_body, reusable_stream, download_duration) = if headers
        .get_all("transfer-encoding")
        .iter()
        .any(|value| value.eq_ignore_ascii_case("chunked"))
    {
        if has_content_encoding {
            let (body_bytes, trailers) = read_chunked_body_bytes(
                &mut stream,
                body_prefix.clone(),
                timeout_config,
                progress_callback,
                progress_config,
            )
            .await?;
            let body = maybe_decode_response_body(&mut headers, body_bytes, compression_mode)?;
            (
                body,
                TrailerState::Ready(trailers),
                true,
                Some(stream),
                Some(download_started.elapsed()),
            )
        } else {
            // Trailers become available through the OnceLock once the lazy
            // stream has been fully consumed.
            let trailers = Arc::new(OnceLock::new());
            let body = read_chunked_body_stream(
                stream,
                body_prefix.clone(),
                timeout_config,
                progress_callback,
                progress_config,
                Arc::clone(&trailers),
            );
            (body, TrailerState::Deferred(trailers), false, None, None)
        }
    } else if let Some(content_length) = headers.get("content-length") {
        let content_length: usize = content_length
            .parse()
            .map_err(|_| Error::new(ErrorKind::Transport, "invalid content-length"))?;
        let body_bytes = read_length_body_bytes(
            &mut stream,
            body_prefix.clone(),
            timeout_config,
            progress_callback,
            progress_config,
            content_length,
        )
        .await?;
        let body = if has_content_encoding {
            maybe_decode_response_body(&mut headers, body_bytes, compression_mode)?
        } else {
            Body::from(body_bytes)
        };
        (
            body,
            TrailerState::Ready(None),
            true,
            Some(stream),
            Some(download_started.elapsed()),
        )
    } else {
        // No framing headers: the body runs to EOF, so the connection can
        // never be reused afterwards.
        if has_content_encoding {
            let body_bytes = read_to_eof_bytes(
                &mut stream,
                body_prefix.clone(),
                timeout_config,
                progress_callback,
                progress_config,
            )
            .await?;
            let body = maybe_decode_response_body(&mut headers, body_bytes, compression_mode)?;
            (
                body,
                TrailerState::Ready(None),
                false,
                None,
                Some(download_started.elapsed()),
            )
        } else {
            let body = read_to_eof_body_stream(
                stream,
                body_prefix.clone(),
                timeout_config,
                progress_callback,
                progress_config,
            );
            (body, TrailerState::Ready(None), false, None, None)
        }
    };
    Ok((
        status,
        headers,
        body,
        trailers,
        response_allows_reuse && reusable_by_body,
        reusable_stream,
        ttfb,
        download_duration,
    ))
}
/// Decodes an entire chunked transfer-encoded body into memory.
///
/// `buffer` seeds the decoder with body bytes already read while parsing
/// the response head. Returns the decoded bytes plus any trailer headers
/// found after the zero-length terminator chunk.
async fn read_chunked_body_bytes<S>(
    stream: &mut S,
    mut buffer: Vec<u8>,
    timeout_config: TimeoutConfig,
    progress_callback: Option<ProgressCallback>,
    progress_config: crate::progress::ProgressConfig,
) -> Result<(Vec<u8>, Option<HeaderMap>)>
where
    S: AsyncRead + Unpin + ?Sized,
{
    let mut decoded = Vec::new();
    let mut progress = progress_callback.map(|callback| {
        ProgressReporter::new(callback, ProgressPhase::Download, None, progress_config)
    });
    loop {
        // Read until the chunk-size line (terminated by CRLF) is complete.
        let size_line_end = loop {
            if let Some(pos) = find_crlf(&buffer) {
                break pos;
            }
            read_more(stream, &mut buffer, timeout_config).await?;
        };
        let size_line = String::from_utf8(buffer[..size_line_end].to_vec()).map_err(|err| {
            Error::with_source(ErrorKind::Decode, "chunk size is not valid utf-8", err)
        })?;
        // Drop the size line and its CRLF from the working buffer.
        buffer.drain(..size_line_end + 2);
        // Chunk sizes are hexadecimal.
        let size = usize::from_str_radix(size_line.trim(), 16).map_err(|_| {
            Error::new(
                ErrorKind::Transport,
                format!("invalid chunk size: {size_line}"),
            )
        })?;
        if size == 0 {
            // Zero-length chunk terminates the body.
            if let Some(progress) = &mut progress {
                progress.finish();
            }
            // Immediate CRLF means no trailers follow.
            // NOTE(review): the terminating CRLF (and any extra buffered
            // bytes) are discarded with `buffer` here rather than pushed
            // back to the stream — confirm callers never depend on bytes
            // beyond the chunk terminator.
            if buffer.starts_with(b"\r\n") {
                return Ok((decoded, None));
            }
            // Otherwise read the full trailer block (ends with CRLFCRLF).
            while find_headers_end(&buffer).is_none() {
                read_more(stream, &mut buffer, timeout_config).await?;
            }
            let trailers = parse_trailers(&buffer)?;
            return Ok((decoded, trailers));
        }
        // Wait for the chunk payload plus its trailing CRLF.
        while buffer.len() < size + 2 {
            read_more(stream, &mut buffer, timeout_config).await?;
        }
        decoded.extend_from_slice(&buffer[..size]);
        if let Some(progress) = &mut progress {
            progress.record(size);
        }
        // Remove payload and the trailing CRLF.
        buffer.drain(..size + 2);
    }
}
/// Wraps a chunked transfer-encoded body in a lazy `Body` stream.
///
/// Chunks are yielded as they are decoded. Trailer headers discovered
/// after the zero-length terminator are published through the shared
/// `OnceLock`, which backs the response's `TrailerState::Deferred`.
fn read_chunked_body_stream(
    mut stream: BoxedStream,
    mut buffer: Vec<u8>,
    timeout_config: TimeoutConfig,
    progress_callback: Option<ProgressCallback>,
    progress_config: crate::progress::ProgressConfig,
    trailers: Arc<OnceLock<HeaderMap>>,
) -> Body {
    let stream_body = async_stream::try_stream! {
        let mut progress =
            progress_callback.map(|callback| ProgressReporter::new(callback, ProgressPhase::Download, None, progress_config));
        loop {
            // Read until the hexadecimal chunk-size line is complete.
            let size_line_end = loop {
                if let Some(pos) = find_crlf(&buffer) {
                    break pos;
                }
                read_more(&mut stream, &mut buffer, timeout_config).await?;
            };
            let size_line = String::from_utf8(buffer[..size_line_end].to_vec()).map_err(|err| {
                Error::with_source(ErrorKind::Decode, "chunk size is not valid utf-8", err)
            })?;
            buffer.drain(..size_line_end + 2);
            let size = usize::from_str_radix(size_line.trim(), 16).map_err(|_| {
                Error::new(
                    ErrorKind::Transport,
                    format!("invalid chunk size: {size_line}"),
                )
            })?;
            if size == 0 {
                // Terminator chunk: either a bare CRLF or a trailer block.
                if buffer.starts_with(b"\r\n") {
                    buffer.drain(..2);
                } else {
                    while find_headers_end(&buffer).is_none() {
                        read_more(&mut stream, &mut buffer, timeout_config).await?;
                    }
                    let parsed = parse_trailers(&buffer)?;
                    if let Some(parsed) = parsed {
                        // set() fails only if already populated; ignore that case.
                        let _ = trailers.set(parsed);
                    }
                }
                if let Some(progress) = &mut progress {
                    progress.finish();
                }
                break;
            }
            // Wait for the payload plus its trailing CRLF, then emit it.
            while buffer.len() < size + 2 {
                read_more(&mut stream, &mut buffer, timeout_config).await?;
            }
            let chunk = buffer.drain(..size).collect::<Vec<_>>();
            buffer.drain(..2);
            if let Some(progress) = &mut progress {
                progress.record(size);
            }
            yield Bytes::from(chunk);
        }
    };
    Body::from_stream(Box::pin(stream_body))
}
/// Returns the offset of the first CRLFCRLF (end-of-headers marker), or
/// `None` if the byte sequence is absent.
fn find_headers_end(bytes: &[u8]) -> Option<usize> {
    (0..bytes.len().saturating_sub(3)).find(|&start| &bytes[start..start + 4] == b"\r\n\r\n")
}
/// Returns the offset of the first CRLF pair, or `None` if absent.
fn find_crlf(bytes: &[u8]) -> Option<usize> {
    bytes
        .iter()
        .zip(bytes.iter().skip(1))
        .position(|(first, second)| *first == b'\r' && *second == b'\n')
}
/// Reads another batch of bytes from `stream` into `buffer`.
///
/// EOF is an error here: callers invoke this only while a chunked body is
/// still incomplete, so running out of bytes means the peer closed the
/// connection mid-body.
async fn read_more<S>(
    stream: &mut S,
    buffer: &mut Vec<u8>,
    timeout_config: TimeoutConfig,
) -> Result<()>
where
    S: AsyncRead + Unpin + ?Sized,
{
    let mut chunk = [0_u8; 1024];
    let read = with_timeout_io(
        timeout_config.read,
        stream.read(&mut chunk),
        "read timed out",
    )
    .await?;
    if read == 0 {
        return Err(Error::new(
            ErrorKind::Transport,
            "unexpected eof while reading chunked body",
        ));
    }
    buffer.extend_from_slice(&chunk[..read]);
    Ok(())
}
/// Races `future` against an optional deadline.
///
/// With `Some(duration)` the future's output is returned as `Ok` if it
/// completes first, otherwise an error of the given `kind`/`message` is
/// produced when the timer fires. With `None` the future runs unbounded.
async fn with_timeout<F, T>(
    timeout: Option<std::time::Duration>,
    future: F,
    kind: ErrorKind,
    message: &'static str,
) -> Result<T>
where
    F: std::future::Future<Output = T>,
{
    match timeout {
        Some(duration) => {
            // `or` polls the future first, so an already-ready future wins
            // even when the duration is zero.
            futures_lite::future::or(async move { Ok(future.await) }, async move {
                async_io::Timer::after(duration).await;
                Err(Error::new(kind, message))
            })
            .await
        }
        None => Ok(future.await),
    }
}
async fn with_timeout_io<F, T>(
timeout: Option<std::time::Duration>,
future: F,
timeout_message: &'static str,
) -> Result<T>
where
F: std::future::Future<Output = std::io::Result<T>>,
{
match with_timeout(timeout, future, ErrorKind::Timeout, timeout_message).await {
Ok(Ok(value)) => Ok(value),
Ok(Err(err)) => Err(Error::with_source(
ErrorKind::Transport,
"io operation failed",
err,
)),
Err(err) => Err(err),
}
}
/// Resolves the URL's host and opens a TCP connection to it, recording
/// per-phase timings (DNS and connect; TLS is filled in later by callers).
///
/// Addresses are split by family and connected with the Happy Eyeballs
/// strategy (IPv6 preferred, IPv4 fallback).
async fn connect_tcp_target(
    dns_cache: &DnsCache,
    dns_config: DnsConfig,
    url: &Url,
    timeout_config: TimeoutConfig,
    local_addr: Option<SocketAddr>,
) -> Result<(TcpStream, ConnectTiming)> {
    let dns_started = Instant::now();
    let addrs = dns_cache.resolve_socket_addrs(url.host(), url.effective_port(), dns_config)?;
    let dns_duration = dns_started.elapsed();
    let (primary, fallback) = split_by_family(addrs);
    let connect_started = Instant::now();
    let stream =
        connect_happy_eyeballs(primary, fallback, local_addr, timeout_config.connect).await?;
    let connect_duration = connect_started.elapsed();
    let timing = ConnectTiming {
        dns: Some(dns_duration),
        connect: Some(connect_duration),
        tls: None,
    };
    Ok((stream, timing))
}
/// Splits resolved addresses into (primary, fallback) lists for Happy
/// Eyeballs: IPv6 addresses are preferred when any exist, with IPv4 as
/// the fallback family; an IPv4-only set becomes the primary list.
fn split_by_family(addrs: Vec<SocketAddr>) -> (Vec<SocketAddr>, Vec<SocketAddr>) {
    // Single pass instead of two filtered scans; `partition` preserves the
    // relative order within each family, matching the previous behavior.
    let (v6, v4): (Vec<SocketAddr>, Vec<SocketAddr>) =
        addrs.into_iter().partition(SocketAddr::is_ipv6);
    if v6.is_empty() { (v4, Vec::new()) } else { (v6, v4) }
}
/// Opens a TCP connection to `target_host:target_port` through the given
/// proxy, dispatching to the HTTP CONNECT or SOCKS5 handshake as needed.
async fn connect_via_proxy(
    proxy: &Proxy,
    target_host: &str,
    target_port: u16,
    timeout_config: TimeoutConfig,
    local_addr: Option<SocketAddr>,
) -> Result<TcpStream> {
    match proxy {
        Proxy::Http { addr, auth } => {
            connect_via_http_proxy(
                addr,
                auth,
                target_host,
                target_port,
                timeout_config,
                local_addr,
            )
            .await
        }
        Proxy::Socks5 { addr, auth } => {
            connect_via_socks5_proxy(
                addr,
                auth,
                target_host,
                target_port,
                timeout_config,
                local_addr,
            )
            .await
        }
    }
}
/// Establishes an HTTP CONNECT tunnel through a proxy.
///
/// Thin alias over `connect_via_http_proxy`, kept so tunnel-establishing
/// call sites read explicitly at the call site.
async fn connect_via_http_proxy_tunnel(
    proxy_addr: &std::net::SocketAddr,
    auth: &Option<crate::proxy::ProxyAuth>,
    target_host: &str,
    target_port: u16,
    timeout_config: TimeoutConfig,
    local_addr: Option<SocketAddr>,
) -> Result<TcpStream> {
    connect_via_http_proxy(
        proxy_addr,
        auth,
        target_host,
        target_port,
        timeout_config,
        local_addr,
    )
    .await
}
/// Connects to an HTTP proxy and issues a CONNECT request for
/// `target_host:target_port`, returning the tunneled stream on success.
async fn connect_via_http_proxy(
    proxy_addr: &std::net::SocketAddr,
    auth: &Option<crate::proxy::ProxyAuth>,
    target_host: &str,
    target_port: u16,
    timeout_config: TimeoutConfig,
    local_addr: Option<SocketAddr>,
) -> Result<TcpStream> {
    let stream = connect_tcp_addr(*proxy_addr, local_addr, timeout_config.connect)
        .await
        .map_err(|err| {
            // Keep timeout errors distinguishable so retry logic can see them.
            if err.kind() == &ErrorKind::Timeout {
                err
            } else {
                Error::with_source(ErrorKind::Transport, "failed to connect to HTTP proxy", err)
            }
        })?;
    let mut stream = stream;
    let connect_request = build_http_connect_request(target_host, target_port, auth.as_ref());
    with_timeout_io(
        timeout_config.write,
        stream.write_all(connect_request.as_bytes()),
        "proxy connect write timed out",
    )
    .await?;
    // Read the CONNECT response one byte at a time so no tunneled bytes
    // beyond the CRLFCRLF response terminator are consumed by accident.
    let mut response = Vec::new();
    let mut buf = [0u8; 1];
    loop {
        let n = with_timeout_io(
            timeout_config.read,
            stream.read(&mut buf),
            "proxy connect read timed out",
        )
        .await?;
        if n == 0 {
            break;
        }
        response.push(buf[0]);
        if response.ends_with(b"\r\n\r\n") {
            break;
        }
    }
    let response_str = String::from_utf8_lossy(&response);
    // NOTE(review): the HTTP/2 status-line prefixes are unusual for a plain
    // TCP CONNECT exchange — confirm which proxies actually answer that way.
    let success = response_str.starts_with("HTTP/1.1 200")
        || response_str.starts_with("HTTP/1.0 200")
        || response_str.starts_with("HTTP/2 200")
        || response_str.starts_with("HTTP/2.0 200");
    if !success {
        return Err(Error::new(
            ErrorKind::Transport,
            format!("HTTP proxy connection failed: {}", response_str.trim()),
        ));
    }
    Ok(stream)
}
/// Connects through a SOCKS5 proxy (RFC 1928) to `target_host:target_port`,
/// optionally performing username/password authentication (RFC 1929).
///
/// Returns the tunneled stream once the proxy reports a successful
/// connection; handshake failures surface as Transport (or Timeout) errors.
async fn connect_via_socks5_proxy(
    proxy_addr: &std::net::SocketAddr,
    auth: &Option<crate::proxy::ProxyAuth>,
    target_host: &str,
    target_port: u16,
    timeout_config: TimeoutConfig,
    local_addr: Option<SocketAddr>,
) -> Result<TcpStream> {
    let stream = connect_tcp_addr(*proxy_addr, local_addr, timeout_config.connect)
        .await
        .map_err(|err| {
            // Keep timeout errors distinguishable so retry logic can see them.
            if err.kind() == &ErrorKind::Timeout {
                err
            } else {
                Error::with_source(
                    ErrorKind::Transport,
                    "failed to connect to SOCKS5 proxy",
                    err,
                )
            }
        })?;
    let mut stream = stream;
    let auth = auth.as_ref();
    let has_auth = auth.is_some();
    // Greeting: version 5 plus the method list we can use — no-auth (0x00)
    // always, username/password (0x02) only when credentials exist.
    if has_auth {
        with_timeout_io(
            timeout_config.write,
            stream.write_all(&[0x05, 0x02, 0x00, 0x02]),
            "SOCKS5 handshake timed out",
        )
        .await
        .map_err(|err| {
            if err.kind() == &ErrorKind::Timeout {
                err
            } else {
                Error::with_source(ErrorKind::Transport, "SOCKS5 auth handshake failed", err)
            }
        })?;
    } else {
        with_timeout_io(
            timeout_config.write,
            stream.write_all(&[0x05, 0x01, 0x00]),
            "SOCKS5 handshake timed out",
        )
        .await
        .map_err(|err| {
            if err.kind() == &ErrorKind::Timeout {
                err
            } else {
                Error::with_source(ErrorKind::Transport, "SOCKS5 handshake failed", err)
            }
        })?;
    }
    // Method-selection reply: [version, chosen method].
    let mut reply = [0u8; 2];
    with_timeout_io(
        timeout_config.read,
        stream.read_exact(&mut reply),
        "SOCKS5 handshake timed out",
    )
    .await
    .map_err(|err| {
        if err.kind() == &ErrorKind::Timeout {
            err
        } else {
            Error::with_source(ErrorKind::Transport, "SOCKS5 handshake reply failed", err)
        }
    })?;
    if reply[0] != 0x05 {
        return Err(Error::new(
            ErrorKind::Transport,
            "SOCKS5 protocol version mismatch",
        ));
    }
    // The proxy must pick the method matching what we offered.
    let auth_method = reply[1];
    if has_auth && auth_method != 0x02 {
        return Err(Error::new(
            ErrorKind::Transport,
            format!("SOCKS5 proxy rejected auth method: {:02x}", auth_method),
        ));
    } else if !has_auth && auth_method != 0x00 {
        return Err(Error::new(
            ErrorKind::Transport,
            format!("SOCKS5 proxy rejected auth method: {:02x}", auth_method),
        ));
    }
    if has_auth {
        let auth = auth.unwrap();
        let username = auth.username.as_bytes();
        let password = auth.password.as_bytes();
        // RFC 1929 length fields are a single byte; longer values would be
        // silently truncated by the `as u8` casts below, corrupting the
        // handshake, so reject them explicitly.
        if username.len() > 255 || password.len() > 255 {
            return Err(Error::new(
                ErrorKind::Transport,
                "SOCKS5 username and password must each be at most 255 bytes",
            ));
        }
        // Sub-negotiation: [0x01, ulen, user..., plen, pass...].
        let mut auth_request = vec![0x01, username.len() as u8];
        auth_request.extend_from_slice(username);
        auth_request.push(password.len() as u8);
        auth_request.extend_from_slice(password);
        with_timeout_io(
            timeout_config.write,
            stream.write_all(&auth_request),
            "SOCKS5 handshake timed out",
        )
        .await
        .map_err(|err| {
            if err.kind() == &ErrorKind::Timeout {
                err
            } else {
                Error::with_source(ErrorKind::Transport, "SOCKS5 auth failed", err)
            }
        })?;
        // Auth reply: [version, status]; non-zero status means rejection.
        let mut auth_reply = [0u8; 2];
        with_timeout_io(
            timeout_config.read,
            stream.read_exact(&mut auth_reply),
            "SOCKS5 handshake timed out",
        )
        .await
        .map_err(|err| {
            if err.kind() == &ErrorKind::Timeout {
                err
            } else {
                Error::with_source(ErrorKind::Transport, "SOCKS5 auth reply failed", err)
            }
        })?;
        if auth_reply[1] != 0x00 {
            return Err(Error::new(
                ErrorKind::Transport,
                "SOCKS5 authentication failed",
            ));
        }
    }
    // CONNECT request using the domain-name address type (ATYP = 0x03):
    // [ver, cmd=CONNECT, rsv, atyp, len, host..., port_hi, port_lo].
    let mut connect_request = vec![0x05, 0x01, 0x00, 0x03];
    let target_host_bytes = target_host.as_bytes();
    // The domain-length field is one byte; reject names that cannot fit
    // instead of truncating them with the `as u8` cast.
    if target_host_bytes.len() > 255 {
        return Err(Error::new(
            ErrorKind::Transport,
            "SOCKS5 target host must be at most 255 bytes",
        ));
    }
    connect_request.push(target_host_bytes.len() as u8);
    connect_request.extend_from_slice(target_host_bytes);
    connect_request.extend_from_slice(&target_port.to_be_bytes());
    with_timeout_io(
        timeout_config.write,
        stream.write_all(&connect_request),
        "SOCKS5 handshake timed out",
    )
    .await
    .map_err(|err| {
        if err.kind() == &ErrorKind::Timeout {
            err
        } else {
            Error::with_source(ErrorKind::Transport, "SOCKS5 connect request failed", err)
        }
    })?;
    // Final reply: 10 bytes matches the header plus an IPv4 bound address.
    // NOTE(review): a proxy replying with a domain or IPv6 bound address
    // sends more than 10 bytes, which would be left unread here — confirm
    // the proxies in use always answer with an IPv4 BND address.
    let mut reply = [0u8; 10];
    with_timeout_io(
        timeout_config.read,
        stream.read_exact(&mut reply),
        "SOCKS5 handshake timed out",
    )
    .await
    .map_err(|err| {
        if err.kind() == &ErrorKind::Timeout {
            err
        } else {
            Error::with_source(ErrorKind::Transport, "SOCKS5 connect reply failed", err)
        }
    })?;
    if reply[0] != 0x05 || reply[1] != 0x00 {
        return Err(Error::new(
            ErrorKind::Transport,
            format!(
                "SOCKS5 connection failed: reply={:02x}{:02x}",
                reply[0], reply[1]
            ),
        ));
    }
    Ok(stream)
}
/// Encodes `input` as base64 via the crate-wide helper; used for Basic
/// proxy credentials.
fn base64_encode(input: &[u8]) -> String {
    crate::util::encode_base64(input)
}
/// Parses the trailer section that follows a chunked body's terminator.
///
/// `buffer` starts at the trailer block; an exact bare CRLF means there
/// are no trailers. Each trailer line is `name: value` and is trimmed
/// before insertion.
fn parse_trailers(buffer: &[u8]) -> Result<Option<HeaderMap>> {
    if buffer == b"\r\n" {
        return Ok(None);
    }
    // The trailer block ends with CRLFCRLF, like a header section.
    let header_end = find_headers_end(buffer)
        .ok_or_else(|| Error::new(ErrorKind::Transport, "incomplete trailer block"))?;
    let head = &buffer[..header_end];
    let text = std::str::from_utf8(head).map_err(|err| {
        Error::with_source(ErrorKind::Decode, "trailers are not valid utf-8", err)
    })?;
    let mut headers = HeaderMap::new();
    for line in text.split("\r\n") {
        if line.is_empty() {
            continue;
        }
        let (name, value) = line.split_once(':').ok_or_else(|| {
            Error::new(
                ErrorKind::Transport,
                format!("invalid trailer line: {line}"),
            )
        })?;
        headers.append(name.trim(), value.trim())?;
    }
    Ok(Some(headers))
}
fn parse_status_line(line: &str) -> Result<StatusCode> {
let mut parts = line.split_whitespace();
let version = parts
.next()
.ok_or_else(|| Error::new(ErrorKind::Transport, "invalid status line"))?;
if !version.starts_with("HTTP/1.") {
return Err(Error::new(
ErrorKind::Transport,
format!("unsupported http version: {version}"),
));
}
let code = parts
.next()
.ok_or_else(|| Error::new(ErrorKind::Transport, "missing status code"))?;
let code = code
.parse()
.map_err(|_| Error::new(ErrorKind::Transport, format!("invalid status code: {code}")))?;
Ok(StatusCode::new(code))
}
/// Decides the method and whether the body is replayed for a redirect hop.
///
/// 301/302/303 switch to GET (HEAD stays HEAD) and drop the body, matching
/// common browser behavior; 307/308 preserve both method and body.
fn redirect_behavior(status: StatusCode, method: Method) -> (Method, bool) {
    match status {
        StatusCode::MOVED_PERMANENTLY | StatusCode::FOUND | StatusCode::SEE_OTHER => {
            if method == Method::Head {
                (Method::Head, false)
            } else {
                (Method::Get, false)
            }
        }
        StatusCode::TEMPORARY_REDIRECT | StatusCode::PERMANENT_REDIRECT => (method, true),
        // Callers only invoke this for redirect statuses; any other 3xx
        // conservatively keeps the method and body.
        _ => (method, true),
    }
}
/// Resolves a Location header value against the URL of the redirecting
/// response.
///
/// Handles three shapes: absolute URLs (contain "://"), host-relative
/// paths (leading '/'), and paths relative to the base URL's directory.
/// NOTE(review): this is a simplified resolver — it performs no "../"
/// segment normalization and treats the base's query string as part of
/// the path when locating the parent directory; confirm that is acceptable
/// for the servers being targeted.
fn resolve_redirect_url(base: &Url, location: &str) -> Result<Url> {
    if location.contains("://") {
        return Url::parse(location);
    }
    if location.starts_with('/') {
        let origin = format!("{}://{}", base.scheme(), base.authority());
        return Url::parse(format!("{origin}{location}"));
    }
    // Relative path: splice it onto the base URL's parent directory.
    let mut base_prefix = format!("{}://{}", base.scheme(), base.authority());
    let base_path = base.path_and_query();
    let dir = base_path
        .rsplit_once('/')
        .map(|(prefix, _)| prefix)
        .unwrap_or("");
    if !dir.starts_with('/') {
        base_prefix.push('/');
    }
    base_prefix.push_str(dir.trim_end_matches('/'));
    if !location.starts_with('/') {
        base_prefix.push('/');
    }
    base_prefix.push_str(location);
    Url::parse(base_prefix)
}
/// Rebuilds the header map for a redirect hop, stripping credential
/// headers (`authorization`, `cookie`) when the redirect crosses origins.
fn sanitize_redirect_headers(headers: HeaderMap, same_origin: bool) -> HeaderMap {
    let strip_credentials = !same_origin;
    let mut sanitized = HeaderMap::new();
    for (name, value) in headers.iter() {
        let is_credential = matches!(name.as_str(), "authorization" | "cookie");
        if strip_credentials && is_credential {
            continue;
        }
        // Values were already accepted into the source map, so append
        // errors are safe to ignore here.
        let _ = sanitized.append(name.as_str(), value.as_str());
    }
    sanitized
}
/// Returns true when two URLs share an origin: same scheme, host, and
/// effective port (scheme default applied when no explicit port is set).
fn same_origin(left: &Url, right: &Url) -> bool {
    left.scheme() == right.scheme()
        && left.host() == right.host()
        && left.effective_port() == right.effective_port()
}
/// Test helper: starts a one-shot TCP server on an ephemeral localhost
/// port that reads a single request (up to 4096 bytes) and answers with
/// the given canned response. Returns the server's `http://` base URL.
#[cfg(test)]
pub(crate) async fn spawn_test_server(response: &'static str) -> Result<String> {
    let listener = TcpListener::bind(("127.0.0.1", 0)).await.map_err(|err| {
        Error::with_source(ErrorKind::Transport, "failed to bind test server", err)
    })?;
    let addr = listener.local_addr().map_err(|err| {
        Error::with_source(ErrorKind::Transport, "failed to inspect test server", err)
    })?;
    // Serve exactly one connection on a dedicated thread with its own
    // block_on, so the test's executor is not involved.
    std::thread::spawn(move || {
        futures_lite::future::block_on(async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            let mut scratch = vec![0_u8; 4096];
            let _ = stream.read(&mut scratch).await.unwrap();
            stream.write_all(response.as_bytes()).await.unwrap();
            stream.flush().await.unwrap();
        });
    });
    Ok(format!("http://{}", addr))
}
#[cfg(test)]
/// One-shot HTTP test server that answers a single connection with the raw
/// `response` bytes. Returns the base URL (`http://…`).
pub(crate) async fn spawn_bytes_server(response: Vec<u8>) -> Result<String> {
    let listener = TcpListener::bind(("127.0.0.1", 0)).await.map_err(|err| {
        Error::with_source(ErrorKind::Transport, "failed to bind test server", err)
    })?;
    let addr = listener.local_addr().map_err(|err| {
        Error::with_source(ErrorKind::Transport, "failed to inspect test server", err)
    })?;
    std::thread::spawn(move || {
        futures_lite::future::block_on(async move {
            let (mut conn, _) = listener.accept().await.unwrap();
            // Read the request (up to 4 KiB) before writing the canned reply.
            let mut request_buf = [0_u8; 4096];
            let _ = conn.read(&mut request_buf).await.unwrap();
            conn.write_all(&response).await.unwrap();
            conn.flush().await.unwrap();
        });
    });
    Ok(format!("http://{addr}"))
}
#[cfg(test)]
/// One-shot test server that records the connecting client's socket address
/// into the returned `Mutex<Option<SocketAddr>>` before replying with a
/// fixed 200 response.
pub(crate) async fn spawn_peer_addr_server() -> Result<(String, Arc<Mutex<Option<SocketAddr>>>)> {
    let listener = TcpListener::bind(("127.0.0.1", 0)).await.map_err(|err| {
        Error::with_source(ErrorKind::Transport, "failed to bind test server", err)
    })?;
    let addr = listener.local_addr().map_err(|err| {
        Error::with_source(ErrorKind::Transport, "failed to inspect test server", err)
    })?;
    let peer_addr: Arc<Mutex<Option<SocketAddr>>> = Arc::new(Mutex::new(None));
    let captured = Arc::clone(&peer_addr);
    std::thread::spawn(move || {
        futures_lite::future::block_on(async move {
            let (mut conn, remote_addr) = listener.accept().await.unwrap();
            // Capture the peer address for the test to assert on.
            *captured.lock().unwrap() = Some(remote_addr);
            let mut request_buf = [0_u8; 4096];
            let _ = conn.read(&mut request_buf).await.unwrap();
            conn.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok")
                .await
                .unwrap();
            conn.flush().await.unwrap();
        });
    });
    Ok((format!("http://{addr}"), peer_addr))
}
#[cfg(test)]
/// Test server that serves one fixed response per accepted connection, in
/// the order given. Each connection gets exactly one request/response pair.
pub(crate) async fn spawn_sequence_server(responses: Vec<&'static str>) -> Result<String> {
    let listener = TcpListener::bind(("127.0.0.1", 0)).await.map_err(|err| {
        Error::with_source(ErrorKind::Transport, "failed to bind test server", err)
    })?;
    let addr = listener.local_addr().map_err(|err| {
        Error::with_source(ErrorKind::Transport, "failed to inspect test server", err)
    })?;
    std::thread::spawn(move || {
        futures_lite::future::block_on(async move {
            for response in responses {
                let (mut conn, _) = listener.accept().await.unwrap();
                let mut request_buf = [0_u8; 4096];
                let _ = conn.read(&mut request_buf).await.unwrap();
                conn.write_all(response.as_bytes()).await.unwrap();
                conn.flush().await.unwrap();
            }
        });
    });
    Ok(format!("http://{addr}"))
}
#[cfg(test)]
/// Like `spawn_sequence_server`, but each response is written only after its
/// configured delay. Every accepted connection is handled on a dedicated
/// thread so a long delay never blocks subsequent accepts.
pub(crate) async fn spawn_timed_sequence_server(
    responses: Vec<(&'static str, std::time::Duration)>,
) -> Result<String> {
    let listener = TcpListener::bind(("127.0.0.1", 0)).await.map_err(|err| {
        Error::with_source(ErrorKind::Transport, "failed to bind test server", err)
    })?;
    let addr = listener.local_addr().map_err(|err| {
        Error::with_source(ErrorKind::Transport, "failed to inspect test server", err)
    })?;
    std::thread::spawn(move || {
        futures_lite::future::block_on(async move {
            for (response, delay) in responses {
                let (mut conn, _) = listener.accept().await.unwrap();
                std::thread::spawn(move || {
                    futures_lite::future::block_on(async move {
                        let mut request_buf = [0_u8; 4096];
                        let _ = conn.read(&mut request_buf).await;
                        async_io::Timer::after(delay).await;
                        // The client may have timed out and disconnected by
                        // now; write/flush failures are expected and ignored.
                        if conn.write_all(response.as_bytes()).await.is_err() {
                            return;
                        }
                        let _ = conn.flush().await;
                    });
                });
            }
        });
    });
    Ok(format!("http://{addr}"))
}
#[cfg(test)]
/// Read one complete HTTP/1.x request from `stream` and return it as text
/// (lossy UTF-8).
///
/// Incrementally buffers reads until the headers are complete, then decides
/// how much body to expect: `Content-Length` bytes, a chunked body ending in
/// the `0\r\n\r\n` terminator, or no body at all. EOF (`read == 0`) always
/// ends the loop with whatever was buffered.
async fn read_full_request(stream: &mut TcpStream) -> String {
    let mut buffer = Vec::new();
    // Byte offset where the header section ends, once known.
    let mut header_end = None;
    // Body length from a parsed Content-Length header, if any.
    let mut expected_body_len = None;
    // Whether Transfer-Encoding lists "chunked".
    let mut chunked = false;
    loop {
        let mut scratch = [0_u8; 1024];
        let read = stream.read(&mut scratch).await.unwrap();
        if read == 0 {
            break;
        }
        buffer.extend_from_slice(&scratch[..read]);
        // Parse framing headers exactly once, as soon as the head is complete.
        if header_end.is_none() {
            header_end = find_headers_end(&buffer);
            if let Some(end) = header_end {
                let head = &buffer[..end];
                if let Ok(head_text) = std::str::from_utf8(head) {
                    for line in head_text.split("\r\n") {
                        // Skip the request line and any malformed lines.
                        let Some((name, value)) = line.split_once(':') else {
                            continue;
                        };
                        let name = name.trim().to_ascii_lowercase();
                        let value = value.trim();
                        if name == "content-length" {
                            if let Ok(len) = value.parse::<usize>() {
                                expected_body_len = Some(len);
                            }
                        }
                        // Transfer-Encoding is a comma-separated list; only
                        // the "chunked" token matters here.
                        if name == "transfer-encoding"
                            && value
                                .split(',')
                                .any(|part| part.trim().eq_ignore_ascii_case("chunked"))
                        {
                            chunked = true;
                        }
                    }
                }
            }
        }
        let Some(end) = header_end else {
            continue;
        };
        // NOTE(review): `end + 4` implies find_headers_end returns the offset
        // where the final "\r\n\r\n" begins — confirm against its definition.
        let body_start = end + 4;
        if buffer.len() < body_start {
            continue;
        }
        if let Some(len) = expected_body_len {
            if buffer.len() >= body_start + len {
                break;
            }
        } else if chunked {
            // A chunked body is complete once the last-chunk marker appears.
            let body = &buffer[body_start..];
            if body.windows(5).any(|window| window == b"0\r\n\r\n") {
                break;
            }
        } else {
            // No declared body: the request ends with the headers.
            break;
        }
    }
    String::from_utf8_lossy(&buffer).to_string()
}
#[cfg(test)]
/// In-process HTTP proxy for tests. Validates the first request with
/// `assert_request`; for a CONNECT request it acknowledges the tunnel and
/// then expects the real GET/POST inside it. Finally replies with `response`.
async fn spawn_http_proxy_server(
    assert_request: fn(&str),
    response: &'static str,
) -> Result<Proxy> {
    let listener = TcpListener::bind(("127.0.0.1", 0)).await.map_err(|err| {
        Error::with_source(ErrorKind::Transport, "failed to bind proxy server", err)
    })?;
    let addr = listener.local_addr().map_err(|err| {
        Error::with_source(ErrorKind::Transport, "failed to inspect proxy server", err)
    })?;
    std::thread::spawn(move || {
        futures_lite::future::block_on(async move {
            let (mut client, _) = listener.accept().await.unwrap();
            let first_request = read_full_request(&mut client).await;
            assert_request(&first_request);
            if first_request.starts_with("CONNECT ") {
                client
                    .write_all(b"HTTP/1.1 200 Connection Established\r\n\r\n")
                    .await
                    .unwrap();
                client.flush().await.unwrap();
                // After the tunnel is established the client sends the real
                // request through it.
                let tunneled_request = read_full_request(&mut client).await;
                assert!(
                    tunneled_request.starts_with("GET ") || tunneled_request.starts_with("POST ")
                );
            }
            client.write_all(response.as_bytes()).await.unwrap();
            client.flush().await.unwrap();
        });
    });
    Ok(Proxy::http(addr))
}
#[cfg(test)]
/// Minimal in-process SOCKS5 proxy for tests (RFC 1928, with RFC 1929
/// username/password auth when `expect_auth` is set). Accepts one client,
/// performs method negotiation, validates a domain-name CONNECT request,
/// then answers the tunneled HTTP request with `response`.
async fn spawn_socks5_proxy_server(
    expect_auth: Option<(&'static str, &'static str)>,
    response: &'static str,
) -> Result<Proxy> {
    let listener = TcpListener::bind(("127.0.0.1", 0)).await.map_err(|err| {
        Error::with_source(
            ErrorKind::Transport,
            "failed to bind socks5 proxy server",
            err,
        )
    })?;
    let addr = listener.local_addr().map_err(|err| {
        Error::with_source(
            ErrorKind::Transport,
            "failed to inspect socks5 proxy server",
            err,
        )
    })?;
    std::thread::spawn(move || {
        futures_lite::future::block_on(async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            // Client greeting: VER (must be 0x05) + number of offered
            // authentication methods, followed by the method list.
            let mut hello = [0_u8; 4];
            stream.read_exact(&mut hello[..2]).await.unwrap();
            assert_eq!(hello[0], 0x05);
            let methods_len = hello[1] as usize;
            let mut methods = vec![0_u8; methods_len];
            stream.read_exact(&mut methods).await.unwrap();
            match expect_auth {
                Some((username, password)) => {
                    // Require username/password (method 0x02), then run the
                    // RFC 1929 sub-negotiation and verify the credentials.
                    assert!(methods.contains(&0x02));
                    stream.write_all(&[0x05, 0x02]).await.unwrap();
                    let mut auth_header = [0_u8; 2];
                    stream.read_exact(&mut auth_header).await.unwrap();
                    assert_eq!(auth_header[0], 0x01);
                    let username_len = auth_header[1] as usize;
                    let mut username_buf = vec![0_u8; username_len];
                    stream.read_exact(&mut username_buf).await.unwrap();
                    let mut password_len = [0_u8; 1];
                    stream.read_exact(&mut password_len).await.unwrap();
                    let mut password_buf = vec![0_u8; password_len[0] as usize];
                    stream.read_exact(&mut password_buf).await.unwrap();
                    assert_eq!(String::from_utf8_lossy(&username_buf), username);
                    assert_eq!(String::from_utf8_lossy(&password_buf), password);
                    // Auth success reply.
                    stream.write_all(&[0x01, 0x00]).await.unwrap();
                }
                None => {
                    // Select "no authentication required" (method 0x00).
                    assert!(methods.contains(&0x00));
                    stream.write_all(&[0x05, 0x00]).await.unwrap();
                }
            }
            // CONNECT request header: VER=5, CMD=1 (connect), RSV=0,
            // ATYP=3 (domain name), then length-prefixed domain and port.
            let mut request_header = [0_u8; 4];
            stream.read_exact(&mut request_header).await.unwrap();
            assert_eq!(request_header, [0x05, 0x01, 0x00, 0x03]);
            let mut domain_len = [0_u8; 1];
            stream.read_exact(&mut domain_len).await.unwrap();
            let mut domain = vec![0_u8; domain_len[0] as usize];
            stream.read_exact(&mut domain).await.unwrap();
            let mut port = [0_u8; 2];
            stream.read_exact(&mut port).await.unwrap();
            assert!(!domain.is_empty());
            assert_ne!(u16::from_be_bytes(port), 0);
            // Success reply with a dummy IPv4 bind address (127.0.0.1:8080).
            let mut reply = vec![0x05, 0x00, 0x00, 0x01];
            reply.extend_from_slice(&[127, 0, 0, 1]);
            reply.extend_from_slice(&8080_u16.to_be_bytes());
            stream.write_all(&reply).await.unwrap();
            stream.flush().await.unwrap();
            // Tunnel established: serve the proxied HTTP exchange directly.
            let tunneled_request = read_full_request(&mut stream).await;
            assert!(tunneled_request.starts_with("GET ") || tunneled_request.starts_with("POST "));
            stream.write_all(response.as_bytes()).await.unwrap();
            stream.flush().await.unwrap();
        });
    });
    Ok(match expect_auth {
        Some((username, password)) => Proxy::socks5_with_auth(addr, username, password),
        None => Proxy::socks5(addr),
    })
}
#[cfg(test)]
/// One-shot test server that runs `assert_request` against the complete
/// incoming request before replying with `response`.
pub(crate) async fn spawn_assert_server(
    assert_request: fn(&str),
    response: &'static str,
) -> Result<String> {
    let listener = TcpListener::bind(("127.0.0.1", 0)).await.map_err(|err| {
        Error::with_source(ErrorKind::Transport, "failed to bind test server", err)
    })?;
    let addr = listener.local_addr().map_err(|err| {
        Error::with_source(ErrorKind::Transport, "failed to inspect test server", err)
    })?;
    std::thread::spawn(move || {
        futures_lite::future::block_on(async move {
            let (mut conn, _) = listener.accept().await.unwrap();
            let received = read_full_request(&mut conn).await;
            assert_request(&received);
            conn.write_all(response.as_bytes()).await.unwrap();
            conn.flush().await.unwrap();
        });
    });
    Ok(format!("http://{addr}"))
}
#[cfg(test)]
/// Test server that serves two requests over a single keep-alive connection,
/// replying with a fixed `Connection: keep-alive` response each time.
pub(crate) async fn spawn_keep_alive_server() -> Result<String> {
    let listener = TcpListener::bind(("127.0.0.1", 0)).await.map_err(|err| {
        Error::with_source(ErrorKind::Transport, "failed to bind test server", err)
    })?;
    let addr = listener.local_addr().map_err(|err| {
        Error::with_source(ErrorKind::Transport, "failed to inspect test server", err)
    })?;
    std::thread::spawn(move || {
        futures_lite::future::block_on(async move {
            let (mut conn, _) = listener.accept().await.unwrap();
            for _ in 0..2 {
                // Accumulate bytes until the request headers are complete.
                let mut request = Vec::new();
                loop {
                    let mut chunk = [0_u8; 1024];
                    let read = conn.read(&mut chunk).await.unwrap();
                    if read == 0 {
                        // Client closed the connection early; stop serving.
                        return;
                    }
                    request.extend_from_slice(&chunk[..read]);
                    if find_headers_end(&request).is_some() {
                        break;
                    }
                }
                let response =
                    "HTTP/1.1 200 OK\r\nContent-Length: 2\r\nConnection: keep-alive\r\n\r\nok";
                conn.write_all(response.as_bytes()).await.unwrap();
                conn.flush().await.unwrap();
            }
        });
    });
    Ok(format!("http://{addr}"))
}
#[cfg(test)]
/// Test server that serves `expected_connections` connections (one request
/// each) and counts every accepted connection so tests can assert how many
/// connections the client actually opened.
pub(crate) async fn spawn_connection_count_server(
    response: &'static str,
    expected_connections: usize,
) -> Result<(String, Arc<std::sync::atomic::AtomicUsize>)> {
    use std::sync::atomic::{AtomicUsize, Ordering};
    let listener = TcpListener::bind(("127.0.0.1", 0)).await.map_err(|err| {
        Error::with_source(ErrorKind::Transport, "failed to bind test server", err)
    })?;
    let addr = listener.local_addr().map_err(|err| {
        Error::with_source(ErrorKind::Transport, "failed to inspect test server", err)
    })?;
    let connection_count = Arc::new(AtomicUsize::new(0));
    let counter = Arc::clone(&connection_count);
    std::thread::spawn(move || {
        futures_lite::future::block_on(async move {
            for _ in 0..expected_connections {
                let (mut conn, _) = listener.accept().await.unwrap();
                counter.fetch_add(1, Ordering::SeqCst);
                // Read until the request headers are complete (or EOF).
                let mut request = Vec::new();
                loop {
                    let mut chunk = [0_u8; 1024];
                    let read = conn.read(&mut chunk).await.unwrap();
                    if read == 0 {
                        break;
                    }
                    request.extend_from_slice(&chunk[..read]);
                    if find_headers_end(&request).is_some() {
                        break;
                    }
                }
                conn.write_all(response.as_bytes()).await.unwrap();
                conn.flush().await.unwrap();
            }
        });
    });
    Ok((format!("http://{addr}"), connection_count))
}
#[cfg(test)]
/// One-shot test server that sleeps for `delay` between reading the request
/// and writing `response` — used to exercise client timeouts.
pub(crate) async fn spawn_delayed_server(
    delay: std::time::Duration,
    response: &'static str,
) -> Result<String> {
    let listener = TcpListener::bind(("127.0.0.1", 0)).await.map_err(|err| {
        Error::with_source(ErrorKind::Transport, "failed to bind test server", err)
    })?;
    let addr = listener.local_addr().map_err(|err| {
        Error::with_source(ErrorKind::Transport, "failed to inspect test server", err)
    })?;
    std::thread::spawn(move || {
        futures_lite::future::block_on(async move {
            let (mut conn, _) = listener.accept().await.unwrap();
            let mut request_buf = [0_u8; 4096];
            let _ = conn.read(&mut request_buf).await.unwrap();
            // A blocking sleep is fine here: this thread drives only this
            // single task.
            std::thread::sleep(delay);
            conn.write_all(response.as_bytes()).await.unwrap();
            conn.flush().await.unwrap();
        });
    });
    Ok(format!("http://{addr}"))
}
#[cfg(all(test, feature = "rustls"))]
/// One-shot HTTPS test server; convenience wrapper that discards the
/// generated certificate PEM and returns only the base URL.
pub(crate) async fn spawn_tls_test_server(response: &'static str) -> Result<String> {
    spawn_tls_test_server_with_cert(response)
        .await
        .map(|(base, _cert_pem)| base)
}
#[cfg(all(test, feature = "rustls"))]
/// One-shot HTTPS test server using a freshly generated self-signed
/// certificate for "localhost". Returns the base URL and the certificate
/// PEM so callers can install it as a trust root.
pub(crate) async fn spawn_tls_test_server_with_cert(
    response: &'static str,
) -> Result<(String, String)> {
    let cert = generate_simple_self_signed(vec!["localhost".into()]).map_err(|err| {
        Error::with_source(ErrorKind::Transport, "failed to generate certificate", err)
    })?;
    let cert_pem = cert.cert.pem();
    let cert_der = rustls::Certificate(cert.cert.der().to_vec());
    let key_der = rustls::PrivateKey(cert.signing_key.serialize_der());
    let server_config = ServerConfig::builder()
        .with_safe_defaults()
        .with_no_client_auth()
        .with_single_cert(vec![cert_der], key_der)
        .map_err(|err| {
            Error::with_source(
                ErrorKind::Transport,
                "failed to build tls server config",
                err,
            )
        })?;
    let acceptor = TlsAcceptor::from(Arc::new(server_config));
    let listener = TcpListener::bind(("127.0.0.1", 0)).await.map_err(|err| {
        Error::with_source(ErrorKind::Transport, "failed to bind tls test server", err)
    })?;
    let addr = listener.local_addr().map_err(|err| {
        Error::with_source(
            ErrorKind::Transport,
            "failed to inspect tls test server",
            err,
        )
    })?;
    std::thread::spawn(move || {
        futures_lite::future::block_on(async move {
            let (stream, _) = listener.accept().await.unwrap();
            // Complete the TLS handshake before speaking HTTP.
            let mut stream = acceptor.accept(stream).await.unwrap();
            let mut scratch = vec![0_u8; 4096];
            let _ = stream.read(&mut scratch).await.unwrap();
            stream.write_all(response.as_bytes()).await.unwrap();
            stream.flush().await.unwrap();
        });
    });
    // The URL uses "localhost" so the certificate's SAN matches the host.
    Ok((format!("https://localhost:{}", addr.port()), cert_pem))
}
#[cfg(all(test, feature = "rustls", feature = "btls-backend"))]
/// HTTPS test server with configurable ALPN protocols. Uses a blocking
/// `std::net::TcpListener` with `rustls::StreamOwned` directly (no async TLS
/// acceptor). When `expected_alpn` is set and the negotiated protocol
/// differs, the server closes the connection without responding.
pub(crate) async fn spawn_tls_test_server_with_cert_and_alpn(
    response: &'static str,
    server_alpn_protocols: Vec<Vec<u8>>,
    expected_alpn: Option<Vec<u8>>,
) -> Result<(String, String)> {
    use std::io::{Read, Write};
    let cert = generate_simple_self_signed(vec!["localhost".into()]).map_err(|err| {
        Error::with_source(ErrorKind::Transport, "failed to generate certificate", err)
    })?;
    let cert_pem = cert.cert.pem();
    let cert_der = rustls::Certificate(cert.cert.der().to_vec());
    let key_der = rustls::PrivateKey(cert.signing_key.serialize_der());
    let mut server_config = ServerConfig::builder()
        .with_safe_defaults()
        .with_no_client_auth()
        .with_single_cert(vec![cert_der], key_der)
        .map_err(|err| {
            Error::with_source(
                ErrorKind::Transport,
                "failed to build tls server config",
                err,
            )
        })?;
    server_config.alpn_protocols = server_alpn_protocols;
    let server_config = Arc::new(server_config);
    let listener = std::net::TcpListener::bind(("127.0.0.1", 0)).map_err(|err| {
        Error::with_source(ErrorKind::Transport, "failed to bind tls test server", err)
    })?;
    let addr = listener.local_addr().map_err(|err| {
        Error::with_source(
            ErrorKind::Transport,
            "failed to inspect tls test server",
            err,
        )
    })?;
    std::thread::spawn(move || {
        let (tcp, _) = listener.accept().unwrap();
        let conn = rustls::ServerConnection::new(server_config).unwrap();
        let mut tls = rustls::StreamOwned::new(conn, tcp);
        // Drive the handshake to completion so ALPN is negotiated before we
        // inspect the result.
        while tls.conn.is_handshaking() {
            tls.conn.complete_io(&mut tls.sock).unwrap();
        }
        if let Some(expected) = expected_alpn {
            let negotiated = tls.conn.alpn_protocol();
            if negotiated != Some(expected.as_slice()) {
                // Unexpected protocol negotiated: drop the connection.
                return;
            }
        }
        let mut scratch = [0_u8; 4096];
        let _ = tls.read(&mut scratch).unwrap();
        tls.write_all(response.as_bytes()).unwrap();
        tls.flush().unwrap();
    });
    Ok((format!("https://localhost:{}", addr.port()), cert_pem))
}
#[cfg(all(test, feature = "rustls"))]
/// TLS test server that serves two keep-alive requests over one connection;
/// returns the base URL plus the accepted-connection counter.
pub(crate) async fn spawn_tls_keep_alive_server()
-> Result<(String, Arc<std::sync::atomic::AtomicUsize>)> {
    // One connection, two requests per connection.
    spawn_tls_connection_count_server(1, 2).await
}
#[cfg(all(test, feature = "rustls"))]
/// TLS test server that accepts `expected_connections` connections and
/// serves `requests_per_connection` keep-alive requests on each. Returns the
/// base URL and a counter of accepted connections for tests to assert on.
pub(crate) async fn spawn_tls_connection_count_server(
    expected_connections: usize,
    requests_per_connection: usize,
) -> Result<(String, Arc<std::sync::atomic::AtomicUsize>)> {
    use std::sync::atomic::{AtomicUsize, Ordering};
    // Self-signed certificate for "localhost"; the URL below uses the same
    // host so the SAN matches.
    let cert = generate_simple_self_signed(vec!["localhost".into()]).map_err(|err| {
        Error::with_source(ErrorKind::Transport, "failed to generate certificate", err)
    })?;
    let cert_der = rustls::Certificate(cert.cert.der().to_vec());
    let key_der = rustls::PrivateKey(cert.signing_key.serialize_der());
    let server_config = ServerConfig::builder()
        .with_safe_defaults()
        .with_no_client_auth()
        .with_single_cert(vec![cert_der], key_der)
        .map_err(|err| {
            Error::with_source(
                ErrorKind::Transport,
                "failed to build tls server config",
                err,
            )
        })?;
    let acceptor = TlsAcceptor::from(Arc::new(server_config));
    let listener = TcpListener::bind(("127.0.0.1", 0)).await.map_err(|err| {
        Error::with_source(ErrorKind::Transport, "failed to bind tls test server", err)
    })?;
    let addr = listener.local_addr().map_err(|err| {
        Error::with_source(
            ErrorKind::Transport,
            "failed to inspect tls test server",
            err,
        )
    })?;
    let connection_count = Arc::new(AtomicUsize::new(0));
    let connection_count_task = Arc::clone(&connection_count);
    std::thread::spawn(move || {
        futures_lite::future::block_on(async move {
            for _ in 0..expected_connections {
                let (stream, _) = listener.accept().await.unwrap();
                connection_count_task.fetch_add(1, Ordering::SeqCst);
                let mut stream = acceptor.accept(stream).await.unwrap();
                for _ in 0..requests_per_connection {
                    // Read until the request headers are complete.
                    let mut request = Vec::new();
                    loop {
                        let mut chunk = [0_u8; 1024];
                        let read = stream.read(&mut chunk).await.unwrap();
                        if read == 0 {
                            // Client hung up; stop serving entirely.
                            return;
                        }
                        request.extend_from_slice(&chunk[..read]);
                        if find_headers_end(&request).is_some() {
                            break;
                        }
                    }
                    let response =
                        "HTTP/1.1 200 OK\r\nContent-Length: 2\r\nConnection: keep-alive\r\n\r\nok";
                    stream.write_all(response.as_bytes()).await.unwrap();
                    stream.flush().await.unwrap();
                }
            }
        });
    });
    Ok((
        format!("https://localhost:{}", addr.port()),
        connection_count,
    ))
}
#[cfg(test)]
mod tests {
#[cfg(feature = "emulation")]
use crate::{
Emulation, EmulationProfile, Http2Fingerprint, Http2PriorityPhase, Http2PrioritySpec,
};
use bytes::Bytes;
use flate2::Compression;
use flate2::read::GzDecoder;
use flate2::write::DeflateEncoder;
use flate2::write::GzEncoder;
use futures_lite::StreamExt;
use futures_lite::future::block_on;
use futures_lite::io::{AsyncReadExt, AsyncWriteExt};
use futures_lite::stream;
#[cfg(all(feature = "rustls", feature = "h2"))]
use hpack::{Decoder as HpackDecoder, Encoder as HpackEncoder};
#[cfg(feature = "h3")]
use quiche::h3::NameValue;
use sha2::Digest;
#[cfg(all(feature = "rustls", feature = "h2"))]
use std::collections::HashMap;
use std::io::{Read, Write};
#[cfg(all(feature = "rustls", feature = "h2"))]
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use std::sync::{Arc, Mutex};
use super::TcpListener;
#[cfg(all(feature = "rustls", feature = "h2"))]
use super::{ServerConfig, TlsAcceptor, generate_simple_self_signed};
use super::{
spawn_assert_server, spawn_bytes_server, spawn_connection_count_server,
spawn_delayed_server, spawn_http_proxy_server, spawn_keep_alive_server,
spawn_peer_addr_server, spawn_sequence_server, spawn_socks5_proxy_server,
spawn_test_server, spawn_timed_sequence_server,
};
#[cfg(feature = "rustls")]
use super::{
spawn_tls_connection_count_server, spawn_tls_keep_alive_server, spawn_tls_test_server,
spawn_tls_test_server_with_cert,
};
use crate::decode::DEFAULT_ACCEPT_ENCODING;
use crate::progress::ProgressConfig;
use crate::{Error, ErrorKind, Result, StatusCode, get};
/// Synchronously drive any `IntoFuture` value to completion on the current
/// thread.
fn run<T>(value: T) -> T::Output
where
    T: std::future::IntoFuture,
{
    // `value.await` desugars to `value.into_future().await`, so calling
    // `into_future` explicitly is equivalent.
    block_on(value.into_future())
}
/// Synchronous variant of `spawn_assert_server` for use outside an async
/// context: binds via `block_on`, validates the complete request with
/// `assert_request`, then replies with `response`.
fn spawn_chunked_assert_server(
    assert_request: fn(&str),
    response: &'static str,
) -> Result<String> {
    let listener = block_on(TcpListener::bind(("127.0.0.1", 0))).map_err(|err| {
        Error::with_source(ErrorKind::Transport, "failed to bind test server", err)
    })?;
    let addr = listener.local_addr().map_err(|err| {
        Error::with_source(ErrorKind::Transport, "failed to inspect test server", err)
    })?;
    std::thread::spawn(move || {
        futures_lite::future::block_on(async move {
            let (mut conn, _) = listener.accept().await.unwrap();
            let received = super::read_full_request(&mut conn).await;
            assert_request(&received);
            conn.write_all(response.as_bytes()).await.unwrap();
            conn.flush().await.unwrap();
        });
    });
    Ok(format!("http://{addr}"))
}
// HTTP/2 frame flag bits used by the test server (RFC 9113 §6).
#[cfg(all(feature = "rustls", feature = "h2"))]
const H2_END_STREAM: u8 = 0x1;
#[cfg(all(feature = "rustls", feature = "h2"))]
const H2_END_HEADERS: u8 = 0x4;
// ACK flag shared by SETTINGS and PING frames.
#[cfg(all(feature = "rustls", feature = "h2"))]
const H2_ACK: u8 = 0x1;
// HTTP/2 frame type identifiers (RFC 9113 §6).
#[cfg(all(feature = "rustls", feature = "h2"))]
const H2_TYPE_DATA: u8 = 0x0;
#[cfg(all(feature = "rustls", feature = "h2"))]
const H2_TYPE_HEADERS: u8 = 0x1;
#[cfg(all(feature = "rustls", feature = "h2", feature = "emulation"))]
const H2_TYPE_PRIORITY: u8 = 0x2;
#[cfg(all(feature = "rustls", feature = "h2"))]
const H2_TYPE_RST_STREAM: u8 = 0x3;
#[cfg(all(feature = "rustls", feature = "h2"))]
const H2_TYPE_SETTINGS: u8 = 0x4;
#[cfg(all(feature = "rustls", feature = "h2"))]
#[allow(dead_code)] const H2_TYPE_PUSH_PROMISE: u8 = 0x5;
#[cfg(all(feature = "rustls", feature = "h2"))]
const H2_TYPE_PING: u8 = 0x6;
#[cfg(all(feature = "rustls", feature = "h2"))]
const H2_TYPE_GOAWAY: u8 = 0x7;
#[cfg(all(feature = "rustls", feature = "h2"))]
const H2_TYPE_WINDOW_UPDATE: u8 = 0x8;
#[cfg(all(feature = "rustls", feature = "h2"))]
const H2_TYPE_CONTINUATION: u8 = 0x9;
// Fixed connection preface every HTTP/2 client sends first (RFC 9113 §3.4).
#[cfg(all(feature = "rustls", feature = "h2"))]
const H2_CLIENT_PREFACE: &[u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n";
#[cfg(all(feature = "rustls", feature = "h2"))]
#[derive(Clone, Debug, Default)]
/// Scripted HTTP/2 response served by the in-process test server.
struct H2TestResponse {
    status: u16,
    headers: Vec<(String, String)>,
    body_frames: Vec<Vec<u8>>,
    trailers: Vec<(String, String)>,
    send_goaway: bool,
}
#[cfg(all(feature = "rustls", feature = "h2"))]
impl H2TestResponse {
    /// Response whose body is a single DATA frame plus a matching
    /// `content-length` header.
    fn text(status: u16, body: impl AsRef<[u8]>) -> Self {
        let body = body.as_ref().to_vec();
        let content_length = ("content-length".to_owned(), body.len().to_string());
        Self {
            status,
            headers: vec![content_length],
            body_frames: vec![body],
            ..Self::default()
        }
    }
    /// Header-only response with no body frames or trailers.
    fn empty(status: u16) -> Self {
        Self {
            status,
            ..Self::default()
        }
    }
    /// Builder: append a response header.
    fn header(mut self, name: &str, value: impl Into<String>) -> Self {
        self.headers.push((name.to_owned(), value.into()));
        self
    }
    /// Builder: append one DATA frame to the body.
    fn body_frame(mut self, chunk: impl AsRef<[u8]>) -> Self {
        self.body_frames.push(chunk.as_ref().to_vec());
        self
    }
    /// Builder: append a trailer sent after the last DATA frame.
    fn trailer(mut self, name: &str, value: impl Into<String>) -> Self {
        self.trailers.push((name.to_owned(), value.into()));
        self
    }
    /// Builder: have the server send a GOAWAY frame after this response.
    fn goaway(mut self) -> Self {
        self.send_goaway = true;
        self
    }
}
#[cfg(all(feature = "rustls", feature = "h2"))]
// Per-stream accumulation state built up while the test server reads a
// client request frame by frame. NOTE(review): populated by
// serve_h2_test_connection; field semantics inferred from names — the
// consuming code is outside this chunk, confirm there.
#[derive(Default)]
struct H2RequestState {
    // Presumably set once the header block is complete (END_HEADERS seen).
    headers_complete: bool,
    // Presumably set once END_STREAM has been observed for this stream.
    end_stream: bool,
    // Raw HPACK bytes gathered before decoding.
    header_block: Vec<u8>,
    // Decoded (name, value) header pairs.
    headers: Vec<(String, String)>,
    // Accumulated request body bytes.
    body: Vec<u8>,
}
#[cfg(all(feature = "rustls", feature = "h2"))]
#[derive(Clone, Debug, Default)]
/// A fully received client request recorded by the HTTP/2 test server.
struct H2CapturedRequest {
    headers: Vec<(String, String)>,
    body: Vec<u8>,
}
#[cfg(all(feature = "rustls", feature = "h2"))]
impl H2CapturedRequest {
    /// Case-insensitive lookup of the first header with the given name.
    fn header(&self, name: &str) -> Option<&str> {
        for (existing, value) in &self.headers {
            if existing.eq_ignore_ascii_case(name) {
                return Some(value.as_str());
            }
        }
        None
    }
    /// Raw request body bytes.
    fn body(&self) -> &[u8] {
        self.body.as_slice()
    }
}
#[cfg(all(feature = "rustls", feature = "h2"))]
/// Decode an HPACK header block into UTF-8 `(name, value)` pairs, mapping
/// decode and UTF-8 failures onto transport errors.
fn decode_h2_test_block(
    decoder: &mut HpackDecoder<'static>,
    block: &[u8],
) -> Result<Vec<(String, String)>> {
    let decoded = decoder.decode(block).map_err(|err| {
        Error::new(
            ErrorKind::Transport,
            format!("failed to decode test hpack block: {err:?}"),
        )
    })?;
    decoded
        .into_iter()
        .map(|(raw_name, raw_value)| {
            let name = String::from_utf8(raw_name).map_err(|err| {
                Error::with_source(ErrorKind::Transport, "invalid test hpack header name", err)
            })?;
            let value = String::from_utf8(raw_value).map_err(|err| {
                Error::with_source(ErrorKind::Transport, "invalid test hpack header value", err)
            })?;
            Ok((name, value))
        })
        .collect()
}
#[cfg(all(feature = "rustls", feature = "h2"))]
/// Build one gRPC length-prefixed message: a 1-byte compressed flag, a
/// 4-byte big-endian payload length, then the payload itself.
fn grpc_test_frame_with_flag(payload: impl AsRef<[u8]>, compressed: bool) -> Vec<u8> {
    let payload = payload.as_ref();
    let mut frame = Vec::with_capacity(5 + payload.len());
    frame.push(if compressed { 1 } else { 0 });
    frame.extend_from_slice(&(payload.len() as u32).to_be_bytes());
    frame.extend_from_slice(payload);
    frame
}
#[cfg(all(feature = "rustls", feature = "h2"))]
/// Uncompressed gRPC frame wrapping `payload`.
fn grpc_test_frame(payload: impl AsRef<[u8]>) -> Vec<u8> {
    grpc_test_frame_with_flag(payload, false)
}
#[cfg(all(feature = "rustls", feature = "h2"))]
/// Gzip-compress `payload` at the default compression level.
fn grpc_test_gzip_bytes(payload: impl AsRef<[u8]>) -> Vec<u8> {
    let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
    encoder.write_all(payload.as_ref()).unwrap();
    encoder.finish().unwrap()
}
#[cfg(all(feature = "rustls", feature = "h2"))]
/// Compressed gRPC frame: gzipped payload with the compressed flag set.
fn grpc_test_gzip_frame(payload: impl AsRef<[u8]>) -> Vec<u8> {
    grpc_test_frame_with_flag(grpc_test_gzip_bytes(payload), true)
}
#[cfg(all(feature = "rustls", feature = "h2"))]
/// Uncompressed gRPC frame containing `value` serialized as JSON.
fn grpc_test_json_frame<T: serde::Serialize>(value: &T) -> Vec<u8> {
    grpc_test_frame(serde_json::to_vec(value).unwrap())
}
#[cfg(all(feature = "rustls", feature = "h2"))]
/// Compressed gRPC frame containing `value` serialized as JSON, then gzipped.
fn grpc_test_gzip_json_frame<T: serde::Serialize>(value: &T) -> Vec<u8> {
    grpc_test_gzip_frame(serde_json::to_vec(value).unwrap())
}
#[cfg(all(feature = "rustls", feature = "h2"))]
/// Split a gRPC message stream (`body`) into its individual payloads.
///
/// Each frame is a 1-byte compressed flag + 4-byte big-endian length +
/// payload. A compressed frame (flag 1) is only accepted when `compression`
/// is "gzip" (case-insensitive) and is decompressed before being returned.
/// Errors on truncated frames or an unknown compression flag.
fn decode_grpc_test_frames(body: &[u8], compression: Option<&str>) -> Result<Vec<Bytes>> {
    let mut offset = 0usize;
    let mut messages = Vec::new();
    while offset < body.len() {
        // Need at least the 5-byte frame header.
        if body.len() - offset < 5 {
            return Err(Error::new(
                ErrorKind::Transport,
                "incomplete grpc test frame header",
            ));
        }
        let compressed = body[offset];
        let len = u32::from_be_bytes([
            body[offset + 1],
            body[offset + 2],
            body[offset + 3],
            body[offset + 4],
        ]) as usize;
        offset += 5;
        if body.len() - offset < len {
            return Err(Error::new(
                ErrorKind::Transport,
                "grpc test frame length exceeds remaining body",
            ));
        }
        let payload = &body[offset..offset + len];
        let payload = match compressed {
            0 => Bytes::copy_from_slice(payload),
            1 => {
                // A compressed frame requires a declared gzip algorithm.
                if !matches!(compression, Some(value) if value.eq_ignore_ascii_case("gzip")) {
                    return Err(Error::new(
                        ErrorKind::Transport,
                        "grpc test frame is compressed but compression algorithm is missing",
                    ));
                }
                let mut decoder = GzDecoder::new(payload);
                let mut decoded = Vec::new();
                decoder.read_to_end(&mut decoded).map_err(|err| {
                    Error::with_source(
                        ErrorKind::Transport,
                        "failed to decode gzip grpc test frame",
                        err,
                    )
                })?;
                Bytes::from(decoded)
            }
            other => {
                return Err(Error::new(
                    ErrorKind::Transport,
                    format!("invalid grpc test compression flag: {other}"),
                ));
            }
        };
        messages.push(payload);
        offset += len;
    }
    Ok(messages)
}
#[cfg(all(feature = "rustls", feature = "h2"))]
/// Write one HTTP/2 frame: a 9-byte header (24-bit big-endian length, type,
/// flags, 31-bit stream id with the reserved bit cleared) followed by the
/// payload, then flush. Errors when the payload exceeds 2^24 - 1 bytes.
async fn write_h2_test_frame<S>(
    stream: &mut S,
    frame_type: u8,
    flags: u8,
    stream_id: u32,
    payload: &[u8],
) -> Result<()>
where
    S: futures_lite::io::AsyncWrite + Unpin,
{
    let len = payload.len();
    if len > 0xFF_FF_FF {
        return Err(Error::new(
            ErrorKind::Transport,
            "test http2 frame payload exceeds maximum size",
        ));
    }
    // Length fits in 24 bits here, so the top byte of the u32 encoding is 0
    // and the low three bytes are the big-endian frame length.
    let len_bytes = (len as u32).to_be_bytes();
    let mut header = [0_u8; 9];
    header[..3].copy_from_slice(&len_bytes[1..]);
    header[3] = frame_type;
    header[4] = flags;
    header[5..9].copy_from_slice(&(stream_id & 0x7FFF_FFFF).to_be_bytes());
    stream.write_all(&header).await.map_err(|err| {
        Error::with_source(
            ErrorKind::Transport,
            "failed to write test http2 frame",
            err,
        )
    })?;
    if !payload.is_empty() {
        stream.write_all(payload).await.map_err(|err| {
            Error::with_source(
                ErrorKind::Transport,
                "failed to write test http2 frame payload",
                err,
            )
        })?;
    }
    stream.flush().await.map_err(|err| {
        Error::with_source(
            ErrorKind::Transport,
            "failed to flush test http2 frame",
            err,
        )
    })?;
    Ok(())
}
#[cfg(all(feature = "rustls", feature = "h2"))]
/// Serialize SETTINGS entries: each is a 2-byte identifier followed by a
/// 4-byte value, both big-endian (RFC 9113 §6.5.1).
fn encode_h2_test_settings(settings: &[(u16, u32)]) -> Vec<u8> {
    let mut payload = Vec::with_capacity(settings.len() * 6);
    for (identifier, value) in settings {
        payload.extend(identifier.to_be_bytes());
        payload.extend(value.to_be_bytes());
    }
    payload
}
#[cfg(all(feature = "rustls", feature = "h2"))]
/// Read one HTTP/2 frame and return `(type, flags, stream id, payload)`.
/// The reserved high bit of the stream identifier is masked off.
async fn read_h2_test_frame<S>(stream: &mut S) -> Result<(u8, u8, u32, Vec<u8>)>
where
    S: futures_lite::io::AsyncRead + Unpin,
{
    let mut header = [0_u8; 9];
    stream.read_exact(&mut header).await.map_err(|err| {
        Error::with_source(ErrorKind::Transport, "failed to read test http2 frame", err)
    })?;
    // 24-bit big-endian payload length occupies the first three header bytes.
    let len = u32::from_be_bytes([0, header[0], header[1], header[2]]) as usize;
    let mut payload = vec![0_u8; len];
    if len > 0 {
        stream.read_exact(&mut payload).await.map_err(|err| {
            Error::with_source(
                ErrorKind::Transport,
                "failed to read test http2 frame payload",
                err,
            )
        })?;
    }
    let stream_id =
        u32::from_be_bytes([header[5], header[6], header[7], header[8]]) & 0x7FFF_FFFF;
    Ok((header[3], header[4], stream_id, payload))
}
#[cfg(all(feature = "rustls", feature = "h2", feature = "emulation"))]
/// Parse a SETTINGS payload into `(identifier, value)` pairs. Any trailing
/// partial entry (fewer than 6 bytes) is ignored.
fn decode_h2_test_settings(payload: &[u8]) -> Vec<(u16, u32)> {
    let mut settings = Vec::with_capacity(payload.len() / 6);
    for entry in payload.chunks_exact(6) {
        let identifier = u16::from_be_bytes([entry[0], entry[1]]);
        let value = u32::from_be_bytes([entry[2], entry[3], entry[4], entry[5]]);
        settings.push((identifier, value));
    }
    settings
}
#[cfg(all(feature = "rustls", feature = "h2", feature = "emulation"))]
/// Extract the 31-bit window-size increment from a WINDOW_UPDATE payload;
/// the high (reserved) bit is cleared.
fn decode_h2_test_window_update(payload: &[u8]) -> u32 {
    let raw = u32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
    raw & 0x7FFF_FFFF
}
#[cfg(all(feature = "rustls", feature = "h2", feature = "emulation"))]
/// Parse a PRIORITY payload into `(stream dependency, weight, exclusive)`.
/// The first word holds a 1-bit exclusive flag plus a 31-bit dependency; the
/// weight byte is stored on the wire as `weight - 1`.
fn decode_h2_test_priority(payload: &[u8]) -> (u32, u16, bool) {
    let raw = u32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
    let exclusive = raw >> 31 == 1;
    let stream_dependency = raw & 0x7FFF_FFFF;
    let weight = u16::from(payload[4]) + 1;
    (stream_dependency, weight, exclusive)
}
#[cfg(all(feature = "rustls", feature = "h2"))]
/// HPACK-encode a response header block: the `:status` pseudo-header first,
/// followed by the regular headers in order.
fn encode_h2_test_headers(
    encoder: &mut HpackEncoder<'static>,
    status: u16,
    headers: &[(String, String)],
) -> Vec<u8> {
    let status_value = status.to_string().into_bytes();
    let mut pairs: Vec<(Vec<u8>, Vec<u8>)> = vec![(b":status".to_vec(), status_value)];
    pairs.extend(
        headers
            .iter()
            .map(|(name, value)| (name.as_bytes().to_vec(), value.as_bytes().to_vec())),
    );
    // The encoder wants borrowed slices, so build a parallel view.
    let refs: Vec<(&[u8], &[u8])> = pairs
        .iter()
        .map(|(name, value)| (name.as_slice(), value.as_slice()))
        .collect();
    encoder.encode(refs)
}
#[cfg(all(feature = "rustls", feature = "h2"))]
/// HPACK-encode a trailer block (regular headers only, no pseudo-headers).
fn encode_h2_test_trailers(
    encoder: &mut HpackEncoder<'static>,
    trailers: &[(String, String)],
) -> Vec<u8> {
    let mut refs: Vec<(&[u8], &[u8])> = Vec::with_capacity(trailers.len());
    for (name, value) in trailers {
        refs.push((name.as_bytes(), value.as_bytes()));
    }
    encoder.encode(refs)
}
#[cfg(all(feature = "rustls", feature = "h2"))]
/// Write one scripted response on `stream_id`: HEADERS, then the configured
/// DATA frames, then optional trailers, then an optional GOAWAY. END_STREAM
/// is placed exactly once, on whichever frame ends the response.
async fn send_h2_test_response<S>(
    stream: &mut S,
    encoder: &mut HpackEncoder<'static>,
    stream_id: u32,
    response: H2TestResponse,
) -> Result<()>
where
    S: futures_lite::io::AsyncRead + futures_lite::io::AsyncWrite + Unpin,
{
    // With neither body frames nor trailers, HEADERS itself ends the stream.
    let end_headers_stream = response.body_frames.is_empty() && response.trailers.is_empty();
    let header_block = encode_h2_test_headers(encoder, response.status, &response.headers);
    write_h2_test_frame(
        stream,
        H2_TYPE_HEADERS,
        H2_END_HEADERS | if end_headers_stream { H2_END_STREAM } else { 0 },
        stream_id,
        &header_block,
    )
    .await?;
    for (index, chunk) in response.body_frames.iter().enumerate() {
        // Only the final DATA frame may end the stream, and only when no
        // trailers follow it.
        let end_stream =
            index + 1 == response.body_frames.len() && response.trailers.is_empty();
        write_h2_test_frame(
            stream,
            H2_TYPE_DATA,
            if end_stream { H2_END_STREAM } else { 0 },
            stream_id,
            chunk,
        )
        .await?;
    }
    if !response.trailers.is_empty() {
        // Trailers are a second HEADERS frame that always ends the stream.
        let trailer_block = encode_h2_test_trailers(encoder, &response.trailers);
        write_h2_test_frame(
            stream,
            H2_TYPE_HEADERS,
            H2_END_HEADERS | H2_END_STREAM,
            stream_id,
            &trailer_block,
        )
        .await?;
    }
    if response.send_goaway {
        // GOAWAY payload: 4-byte last-stream-id (this stream) followed by a
        // zero error code.
        let mut payload = [0_u8; 8];
        payload[..4].copy_from_slice(&(stream_id & 0x7FFF_FFFF).to_be_bytes());
        write_h2_test_frame(stream, H2_TYPE_GOAWAY, 0, 0, &payload).await?;
    }
    Ok(())
}
#[cfg(all(feature = "rustls", feature = "h2"))]
/// Drives one scripted HTTP/2 connection for tests: validates the client
/// preface, sends the server's initial SETTINGS frame, then answers each
/// fully-received request with the next scripted response until every
/// response has been served.
///
/// * `responses` — replies handed out in arrival order of completed requests.
/// * `recorded_requests` — when present, each completed request (decoded
///   headers plus accumulated body) is appended for later assertions.
/// * `initial_settings_payload` — written verbatim as the payload of the
///   server's first SETTINGS frame.
async fn serve_h2_test_connection<S>(
    stream: &mut S,
    responses: Vec<H2TestResponse>,
    recorded_requests: Option<Arc<Mutex<Vec<H2CapturedRequest>>>>,
    initial_settings_payload: &[u8],
) -> Result<()>
where
    S: futures_lite::io::AsyncRead + futures_lite::io::AsyncWrite + Unpin,
{
    // The 24-byte connection preface must arrive before any frames.
    let mut preface = [0_u8; 24];
    stream.read_exact(&mut preface).await.map_err(|err| {
        Error::with_source(
            ErrorKind::Transport,
            "failed to read test http2 preface",
            err,
        )
    })?;
    if preface.as_slice() != H2_CLIENT_PREFACE {
        return Err(Error::new(
            ErrorKind::Transport,
            format!(
                "unexpected client http2 preface in test server: {:?}",
                String::from_utf8_lossy(&preface)
            ),
        ));
    }
    write_h2_test_frame(stream, H2_TYPE_SETTINGS, 0, 0, initial_settings_payload).await?;
    let mut encoder = HpackEncoder::new();
    let mut decoder = HpackDecoder::new();
    // Per-stream accumulation of header-block fragments and body bytes.
    let mut request_states = HashMap::<u32, H2RequestState>::new();
    let mut response_iter = responses.into_iter();
    let mut served = 0usize;
    let expected = response_iter.len();
    while served < expected {
        let (frame_type, flags, stream_id, payload) = read_h2_test_frame(stream).await?;
        match frame_type {
            H2_TYPE_SETTINGS => {
                // Acknowledge the peer's SETTINGS; never ACK an ACK.
                if flags & H2_ACK == 0 {
                    write_h2_test_frame(stream, H2_TYPE_SETTINGS, H2_ACK, 0, &[]).await?;
                }
            }
            H2_TYPE_PING => {
                // Echo the opaque PING payload back with the ACK flag set.
                if flags & H2_ACK == 0 {
                    write_h2_test_frame(stream, H2_TYPE_PING, H2_ACK, 0, &payload).await?;
                }
            }
            H2_TYPE_HEADERS => {
                let state = request_states.entry(stream_id).or_default();
                state.header_block.extend_from_slice(&payload);
                if flags & H2_END_HEADERS != 0 {
                    state.headers_complete = true;
                    state.headers = decode_h2_test_block(&mut decoder, &state.header_block)?;
                    state.header_block.clear();
                }
                if flags & H2_END_STREAM != 0 {
                    state.end_stream = true;
                }
                // Request fully received: record it and emit the next scripted reply.
                if state.headers_complete && state.end_stream {
                    let state = request_states.remove(&stream_id).expect("request state");
                    if let Some(recorded_requests) = recorded_requests.as_ref() {
                        recorded_requests.lock().unwrap().push(H2CapturedRequest {
                            headers: state.headers,
                            body: state.body,
                        });
                    }
                    let response = response_iter.next().ok_or_else(|| {
                        Error::new(
                            ErrorKind::Transport,
                            "missing scripted http2 response for request",
                        )
                    })?;
                    send_h2_test_response(stream, &mut encoder, stream_id, response).await?;
                    served += 1;
                }
            }
            H2_TYPE_CONTINUATION => {
                // CONTINUATION extends a header block that exceeded one frame.
                let state = request_states.entry(stream_id).or_default();
                state.header_block.extend_from_slice(&payload);
                if flags & H2_END_HEADERS != 0 {
                    state.headers_complete = true;
                    state.headers = decode_h2_test_block(&mut decoder, &state.header_block)?;
                    state.header_block.clear();
                }
                if state.headers_complete && state.end_stream {
                    let state = request_states.remove(&stream_id).expect("request state");
                    if let Some(recorded_requests) = recorded_requests.as_ref() {
                        recorded_requests.lock().unwrap().push(H2CapturedRequest {
                            headers: state.headers,
                            body: state.body,
                        });
                    }
                    let response = response_iter.next().ok_or_else(|| {
                        Error::new(
                            ErrorKind::Transport,
                            "missing scripted http2 response for continued request",
                        )
                    })?;
                    send_h2_test_response(stream, &mut encoder, stream_id, response).await?;
                    served += 1;
                }
            }
            H2_TYPE_DATA => {
                let state = request_states.entry(stream_id).or_default();
                state.body.extend_from_slice(&payload);
                if flags & H2_END_STREAM != 0 {
                    state.end_stream = true;
                }
                if state.headers_complete && state.end_stream {
                    let state = request_states.remove(&stream_id).expect("request state");
                    if let Some(recorded_requests) = recorded_requests.as_ref() {
                        recorded_requests.lock().unwrap().push(H2CapturedRequest {
                            headers: state.headers,
                            body: state.body,
                        });
                    }
                    let response = response_iter.next().ok_or_else(|| {
                        Error::new(
                            ErrorKind::Transport,
                            "missing scripted http2 response for data request",
                        )
                    })?;
                    send_h2_test_response(stream, &mut encoder, stream_id, response).await?;
                    served += 1;
                }
            }
            // Flow-control updates need no action in this scripted server.
            H2_TYPE_WINDOW_UPDATE => {}
            // Client is shutting the connection down; stop serving.
            H2_TYPE_GOAWAY => break,
            _ => {}
        }
    }
    // Grace period so the client can finish reading before the socket drops.
    async_io::Timer::after(std::time::Duration::from_millis(50)).await;
    Ok(())
}
#[cfg(all(feature = "rustls", feature = "h2"))]
/// Convenience wrapper around [`spawn_h2_tls_recording_server_with_settings`]
/// that starts the server with an empty initial SETTINGS payload.
async fn spawn_h2_tls_recording_server(
    connection_scripts: Vec<Vec<H2TestResponse>>,
) -> Result<(String, Arc<AtomicUsize>, Arc<Mutex<Vec<H2CapturedRequest>>>)> {
    let no_settings = Vec::new();
    spawn_h2_tls_recording_server_with_settings(connection_scripts, no_settings).await
}
#[cfg(all(feature = "rustls", feature = "h2"))]
/// Spawns a TLS HTTP/2 test server that serves the given per-connection
/// response scripts and records every completed request.
///
/// Returns the `https://localhost:<port>` base URL, a counter of accepted
/// connections, and the shared capture of decoded requests.
///
/// `initial_settings_payload` is sent verbatim as the payload of the server's
/// first SETTINGS frame on each connection.
async fn spawn_h2_tls_recording_server_with_settings(
    connection_scripts: Vec<Vec<H2TestResponse>>,
    initial_settings_payload: Vec<u8>,
) -> Result<(String, Arc<AtomicUsize>, Arc<Mutex<Vec<H2CapturedRequest>>>)> {
    // Self-signed certificate for "localhost"; tests configure the client to
    // trust it explicitly.
    let cert = generate_simple_self_signed(vec!["localhost".into()]).map_err(|err| {
        Error::with_source(ErrorKind::Transport, "failed to generate certificate", err)
    })?;
    let cert_der = rustls::Certificate(cert.cert.der().to_vec());
    let key_der = rustls::PrivateKey(cert.signing_key.serialize_der());
    let mut server_config = ServerConfig::builder()
        .with_safe_defaults()
        .with_no_client_auth()
        .with_single_cert(vec![cert_der], key_der)
        .map_err(|err| {
            Error::with_source(
                ErrorKind::Transport,
                "failed to build h2 tls server config",
                err,
            )
        })?;
    // Advertise only "h2" so ALPN forces the client onto HTTP/2.
    server_config.alpn_protocols = vec![b"h2".to_vec()];
    let acceptor = TlsAcceptor::from(Arc::new(server_config));
    let listener = TcpListener::bind(("127.0.0.1", 0)).await.map_err(|err| {
        Error::with_source(
            ErrorKind::Transport,
            "failed to bind h2 tls test server",
            err,
        )
    })?;
    let addr = listener.local_addr().map_err(|err| {
        Error::with_source(
            ErrorKind::Transport,
            "failed to inspect h2 tls test server",
            err,
        )
    })?;
    let connection_count = Arc::new(AtomicUsize::new(0));
    let task_connection_count = Arc::clone(&connection_count);
    let recorded_requests = Arc::new(Mutex::new(Vec::new()));
    let task_recorded_requests = Arc::clone(&recorded_requests);
    // The server thread takes ownership of `initial_settings_payload` directly;
    // the previous `.clone()` was redundant because the value is never used on
    // this side again.
    std::thread::spawn(move || {
        futures_lite::future::block_on(async move {
            // One scripted exchange per accepted connection.
            for script in connection_scripts {
                let (stream, _) = listener.accept().await.unwrap();
                task_connection_count.fetch_add(1, Ordering::SeqCst);
                let mut stream = acceptor.accept(stream).await.unwrap();
                serve_h2_test_connection(
                    &mut stream,
                    script,
                    Some(Arc::clone(&task_recorded_requests)),
                    &initial_settings_payload,
                )
                .await
                .unwrap();
            }
        });
    });
    Ok((
        format!("https://localhost:{}", addr.port()),
        connection_count,
        recorded_requests,
    ))
}
#[cfg(all(feature = "rustls", feature = "h2"))]
/// Starts the recording h2-over-TLS server but exposes only the base URL and
/// the accepted-connection counter, discarding the request capture handle.
async fn spawn_h2_tls_script_server(
    connection_scripts: Vec<Vec<H2TestResponse>>,
) -> Result<(String, Arc<AtomicUsize>)> {
    spawn_h2_tls_recording_server(connection_scripts)
        .await
        .map(|(base, connection_count, _recorded_requests)| (base, connection_count))
}
#[cfg(all(feature = "rustls", feature = "h2", feature = "emulation"))]
/// Spawns a single-connection TLS HTTP/2 server that captures the client's
/// connection-lifecycle frames (SETTINGS, connection-level WINDOW_UPDATE,
/// PRIORITY, and the full frame sequence) together with the first decoded
/// request, then answers with a plain 200 "ok".
///
/// Returns the base URL and a slot that is filled once the request completes.
async fn spawn_h2_tls_lifecycle_probe_server() -> Result<(
    String,
    Arc<Mutex<Option<(H2LifecycleCapture, H2CapturedRequest)>>>,
)> {
    let cert = generate_simple_self_signed(vec!["localhost".into()]).map_err(|err| {
        Error::with_source(ErrorKind::Transport, "failed to generate certificate", err)
    })?;
    let cert_der = rustls::Certificate(cert.cert.der().to_vec());
    let key_der = rustls::PrivateKey(cert.signing_key.serialize_der());
    let mut server_config = ServerConfig::builder()
        .with_safe_defaults()
        .with_no_client_auth()
        .with_single_cert(vec![cert_der], key_der)
        .map_err(|err| {
            Error::with_source(
                ErrorKind::Transport,
                "failed to build h2 lifecycle tls server config",
                err,
            )
        })?;
    // ALPN restricted to "h2" so the client negotiates HTTP/2.
    server_config.alpn_protocols = vec![b"h2".to_vec()];
    let acceptor = TlsAcceptor::from(Arc::new(server_config));
    let listener = TcpListener::bind(("127.0.0.1", 0)).await.map_err(|err| {
        Error::with_source(
            ErrorKind::Transport,
            "failed to bind h2 lifecycle tls test server",
            err,
        )
    })?;
    let addr = listener.local_addr().map_err(|err| {
        Error::with_source(
            ErrorKind::Transport,
            "failed to inspect h2 lifecycle tls test server",
            err,
        )
    })?;
    let capture = Arc::new(Mutex::new(None));
    let task_capture = Arc::clone(&capture);
    std::thread::spawn(move || {
        futures_lite::future::block_on(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let mut stream = acceptor.accept(stream).await.unwrap();
            // Validate the 24-byte client preface before framing starts.
            let mut preface = [0_u8; 24];
            stream.read_exact(&mut preface).await.unwrap();
            assert_eq!(preface.as_slice(), H2_CLIENT_PREFACE);
            write_h2_test_frame(&mut stream, H2_TYPE_SETTINGS, 0, 0, &[])
                .await
                .unwrap();
            let mut decoder = HpackDecoder::new();
            let mut request_states = HashMap::<u32, H2RequestState>::new();
            let mut lifecycle = H2LifecycleCapture::default();
            // Read frames until one request completes, recording every
            // lifecycle frame along the way.
            let (stream_id, request) = read_complete_h2_test_request_with_capture(
                &mut stream,
                &mut decoder,
                &mut request_states,
                &mut lifecycle,
            )
            .await
            .unwrap();
            *task_capture.lock().unwrap() = Some((lifecycle, request));
            let mut encoder = HpackEncoder::new();
            send_h2_test_response(
                &mut stream,
                &mut encoder,
                stream_id,
                H2TestResponse::text(200, "ok"),
            )
            .await
            .unwrap();
            // Grace period so the client can finish reading before close.
            async_io::Timer::after(std::time::Duration::from_millis(50)).await;
        });
    });
    Ok((format!("https://localhost:{}", addr.port()), capture))
}
#[cfg(all(feature = "rustls", feature = "h2"))]
/// Spawns a cleartext (h2c, prior-knowledge) HTTP/2 test server serving the
/// given per-connection scripts and recording every completed request.
///
/// Returns the `http://<addr>` base URL, a counter of accepted connections,
/// and the shared capture of decoded requests.
async fn spawn_h2c_recording_server(
    connection_scripts: Vec<Vec<H2TestResponse>>,
) -> Result<(String, Arc<AtomicUsize>, Arc<Mutex<Vec<H2CapturedRequest>>>)> {
    let listener = TcpListener::bind(("127.0.0.1", 0)).await.map_err(|err| {
        Error::with_source(ErrorKind::Transport, "failed to bind h2c test server", err)
    })?;
    let addr = listener.local_addr().map_err(|err| {
        Error::with_source(
            ErrorKind::Transport,
            "failed to inspect h2c test server",
            err,
        )
    })?;
    let connection_count = Arc::new(AtomicUsize::new(0));
    let task_connection_count = Arc::clone(&connection_count);
    let recorded_requests = Arc::new(Mutex::new(Vec::new()));
    let task_recorded_requests = Arc::clone(&recorded_requests);
    std::thread::spawn(move || {
        futures_lite::future::block_on(async move {
            // One scripted exchange per accepted connection; no TLS handshake
            // and an empty initial SETTINGS payload.
            for script in connection_scripts {
                let (mut stream, _) = listener.accept().await.unwrap();
                task_connection_count.fetch_add(1, Ordering::SeqCst);
                serve_h2_test_connection(
                    &mut stream,
                    script,
                    Some(Arc::clone(&task_recorded_requests)),
                    &[],
                )
                .await
                .unwrap();
            }
        });
    });
    Ok((
        format!("http://{}", addr),
        connection_count,
        recorded_requests,
    ))
}
#[cfg(all(feature = "rustls", feature = "h2"))]
/// Starts the recording h2c server but returns only the base URL and the
/// accepted-connection counter, dropping the request capture handle.
async fn spawn_h2c_script_server(
    connection_scripts: Vec<Vec<H2TestResponse>>,
) -> Result<(String, Arc<AtomicUsize>)> {
    spawn_h2c_recording_server(connection_scripts)
        .await
        .map(|(base, connection_count, _recorded_requests)| (base, connection_count))
}
#[cfg(all(feature = "rustls", feature = "h2"))]
/// Spawns a test server that accepts an HTTP/1.1 `Upgrade: h2c` request,
/// answers 101 Switching Protocols, completes the HTTP/2 handshake, and then
/// serves one scripted response on stream 1 (the implicitly-upgraded request).
async fn spawn_h2c_upgrade_server(response: H2TestResponse) -> Result<String> {
    let listener = TcpListener::bind(("127.0.0.1", 0)).await.map_err(|err| {
        Error::with_source(
            ErrorKind::Transport,
            "failed to bind h2c upgrade test server",
            err,
        )
    })?;
    let addr = listener.local_addr().map_err(|err| {
        Error::with_source(
            ErrorKind::Transport,
            "failed to inspect h2c upgrade test server",
            err,
        )
    })?;
    std::thread::spawn(move || {
        futures_lite::future::block_on(async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            // The initial HTTP/1.1 request must carry the h2c upgrade headers
            // (header names may arrive in either case).
            let request = super::read_full_request(&mut stream).await;
            assert!(request.contains("Upgrade: h2c") || request.contains("upgrade: h2c"));
            assert!(request.contains("HTTP2-Settings:") || request.contains("http2-settings:"));
            stream
                .write_all(
                    b"HTTP/1.1 101 Switching Protocols\r\nConnection: Upgrade\r\nUpgrade: h2c\r\n\r\n",
                )
                .await
                .unwrap();
            stream.flush().await.unwrap();
            // After 101 the client speaks HTTP/2: preface, then its SETTINGS.
            let mut preface = [0_u8; 24];
            stream.read_exact(&mut preface).await.unwrap();
            assert_eq!(preface.as_slice(), H2_CLIENT_PREFACE);
            let (frame_type, flags, _stream_id, _payload) =
                read_h2_test_frame(&mut stream).await.unwrap();
            assert_eq!(frame_type, H2_TYPE_SETTINGS);
            assert_eq!(flags & H2_ACK, 0);
            write_h2_test_frame(&mut stream, H2_TYPE_SETTINGS, H2_ACK, 0, &[])
                .await
                .unwrap();
            write_h2_test_frame(&mut stream, H2_TYPE_SETTINGS, 0, 0, &[])
                .await
                .unwrap();
            let mut encoder = HpackEncoder::new();
            // Stream 1 is reserved for the request that triggered the upgrade.
            send_h2_test_response(&mut stream, &mut encoder, 1, response)
                .await
                .unwrap();
            // Grace period so the client can finish reading before close.
            async_io::Timer::after(std::time::Duration::from_millis(50)).await;
        });
    });
    Ok(format!("http://{}", addr))
}
#[cfg(all(feature = "rustls", feature = "h2"))]
#[allow(dead_code)]
/// Spawns a single-connection TLS HTTP/2 server that answers one request with
/// 200 "ok" but also emits an unsolicited PUSH_PROMISE for stream 2, then
/// waits for the client to refuse the push via RST_STREAM.
///
/// Returns the base URL and the RST_STREAM error codes observed for stream 2.
async fn spawn_h2_tls_push_refusal_server() -> Result<(String, Arc<Mutex<Vec<u32>>>)> {
    let cert = generate_simple_self_signed(vec!["localhost".into()]).map_err(|err| {
        Error::with_source(ErrorKind::Transport, "failed to generate certificate", err)
    })?;
    let cert_der = rustls::Certificate(cert.cert.der().to_vec());
    let key_der = rustls::PrivateKey(cert.signing_key.serialize_der());
    let mut server_config = ServerConfig::builder()
        .with_safe_defaults()
        .with_no_client_auth()
        .with_single_cert(vec![cert_der], key_der)
        .map_err(|err| {
            Error::with_source(
                ErrorKind::Transport,
                "failed to build h2 tls server config",
                err,
            )
        })?;
    // ALPN restricted to "h2" so the client negotiates HTTP/2.
    server_config.alpn_protocols = vec![b"h2".to_vec()];
    let acceptor = TlsAcceptor::from(Arc::new(server_config));
    let listener = TcpListener::bind(("127.0.0.1", 0)).await.map_err(|err| {
        Error::with_source(
            ErrorKind::Transport,
            "failed to bind h2 tls test server",
            err,
        )
    })?;
    let addr = listener.local_addr().map_err(|err| {
        Error::with_source(
            ErrorKind::Transport,
            "failed to inspect h2 tls test server",
            err,
        )
    })?;
    let reset_codes = Arc::new(Mutex::new(Vec::new()));
    let task_reset_codes = Arc::clone(&reset_codes);
    std::thread::spawn(move || {
        futures_lite::future::block_on(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let mut stream = acceptor.accept(stream).await.unwrap();
            let mut preface = [0_u8; 24];
            stream.read_exact(&mut preface).await.unwrap();
            assert_eq!(preface.as_slice(), H2_CLIENT_PREFACE);
            write_h2_test_frame(&mut stream, H2_TYPE_SETTINGS, 0, 0, &[])
                .await
                .unwrap();
            let mut decoder = HpackDecoder::new();
            let mut request_states = HashMap::<u32, H2RequestState>::new();
            let (stream_id, _request) =
                read_complete_h2_test_request(&mut stream, &mut decoder, &mut request_states)
                    .await
                    .unwrap();
            let mut encoder = HpackEncoder::new();
            // Pseudo-headers for the pushed request; "x-shared" also appears in
            // the real response so HPACK dynamic-table indexing is exercised.
            let push_headers = vec![
                (b":method".as_slice(), b"GET".as_slice()),
                (b":scheme".as_slice(), b"https".as_slice()),
                (b":authority".as_slice(), b"localhost".as_slice()),
                (b":path".as_slice(), b"/pushed".as_slice()),
                (b"x-shared".as_slice(), b"indexed".as_slice()),
            ];
            let push_block = encoder.encode(push_headers);
            // PUSH_PROMISE payload: 4-byte promised stream id (2) + header block.
            let mut push_payload = Vec::with_capacity(4 + push_block.len());
            push_payload.extend_from_slice(&(2_u32).to_be_bytes());
            push_payload.extend_from_slice(&push_block);
            write_h2_test_frame(
                &mut stream,
                H2_TYPE_PUSH_PROMISE,
                H2_END_HEADERS,
                stream_id,
                &push_payload,
            )
            .await
            .unwrap();
            let response = H2TestResponse::text(200, "ok").header("x-shared", "indexed");
            send_h2_test_response(&mut stream, &mut encoder, stream_id, response)
                .await
                .unwrap();
            // Wait until the client refuses the promised stream (RST_STREAM on
            // stream 2) or tears the connection down.
            loop {
                let (frame_type, _flags, frame_stream_id, payload) =
                    read_h2_test_frame(&mut stream).await.unwrap();
                match frame_type {
                    H2_TYPE_RST_STREAM if frame_stream_id == 2 => {
                        let code = u32::from_be_bytes([
                            payload[0], payload[1], payload[2], payload[3],
                        ]);
                        task_reset_codes.lock().unwrap().push(code);
                        break;
                    }
                    H2_TYPE_SETTINGS if payload.is_empty() => {}
                    H2_TYPE_WINDOW_UPDATE | H2_TYPE_PING => {}
                    H2_TYPE_GOAWAY => break,
                    _ => {}
                }
            }
        });
    });
    Ok((format!("https://localhost:{}", addr.port()), reset_codes))
}
#[cfg(all(feature = "rustls", feature = "h2"))]
/// Spawns a single-connection TLS HTTP/2 server simulating a server-streaming
/// gRPC call: response headers, one DATA message, a 60 ms pause, a second DATA
/// message, then trailers (`grpc-status: 0`) ending the stream. The delay lets
/// tests observe incremental message delivery.
async fn spawn_h2_tls_delayed_grpc_stream_server() -> Result<String> {
    let cert = generate_simple_self_signed(vec!["localhost".into()]).map_err(|err| {
        Error::with_source(ErrorKind::Transport, "failed to generate certificate", err)
    })?;
    let cert_der = rustls::Certificate(cert.cert.der().to_vec());
    let key_der = rustls::PrivateKey(cert.signing_key.serialize_der());
    let mut server_config = ServerConfig::builder()
        .with_safe_defaults()
        .with_no_client_auth()
        .with_single_cert(vec![cert_der], key_der)
        .map_err(|err| {
            Error::with_source(
                ErrorKind::Transport,
                "failed to build h2 tls server config",
                err,
            )
        })?;
    // ALPN restricted to "h2" so the client negotiates HTTP/2.
    server_config.alpn_protocols = vec![b"h2".to_vec()];
    let acceptor = TlsAcceptor::from(Arc::new(server_config));
    let listener = TcpListener::bind(("127.0.0.1", 0)).await.map_err(|err| {
        Error::with_source(
            ErrorKind::Transport,
            "failed to bind h2 tls test server",
            err,
        )
    })?;
    let addr = listener.local_addr().map_err(|err| {
        Error::with_source(
            ErrorKind::Transport,
            "failed to inspect h2 tls test server",
            err,
        )
    })?;
    std::thread::spawn(move || {
        futures_lite::future::block_on(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let mut stream = acceptor.accept(stream).await.unwrap();
            let mut preface = [0_u8; 24];
            stream.read_exact(&mut preface).await.unwrap();
            assert_eq!(preface.as_slice(), H2_CLIENT_PREFACE);
            write_h2_test_frame(&mut stream, H2_TYPE_SETTINGS, 0, 0, &[])
                .await
                .unwrap();
            let mut encoder = HpackEncoder::new();
            let mut decoder = HpackDecoder::new();
            let mut request_states = HashMap::<u32, H2RequestState>::new();
            let (stream_id, _request) =
                read_complete_h2_test_request(&mut stream, &mut decoder, &mut request_states)
                    .await
                    .unwrap();
            // Response HEADERS without END_STREAM: body and trailers follow.
            let headers = encode_h2_test_headers(
                &mut encoder,
                200,
                &[(
                    "content-type".to_owned(),
                    "application/grpc+json".to_owned(),
                )],
            );
            write_h2_test_frame(
                &mut stream,
                H2_TYPE_HEADERS,
                H2_END_HEADERS,
                stream_id,
                &headers,
            )
            .await
            .unwrap();
            // First gRPC message, immediately.
            write_h2_test_frame(
                &mut stream,
                H2_TYPE_DATA,
                0,
                stream_id,
                &grpc_test_json_frame(&serde_json::json!({ "seq": 1 })),
            )
            .await
            .unwrap();
            // Deliberate delay so the client sees the stream arrive in parts.
            async_io::Timer::after(std::time::Duration::from_millis(60)).await;
            write_h2_test_frame(
                &mut stream,
                H2_TYPE_DATA,
                0,
                stream_id,
                &grpc_test_json_frame(&serde_json::json!({ "seq": 2 })),
            )
            .await
            .unwrap();
            // gRPC status is delivered in trailers; END_STREAM closes the call.
            let trailers = encode_h2_test_trailers(
                &mut encoder,
                &[
                    ("grpc-status".to_owned(), "0".to_owned()),
                    ("x-trace-id".to_owned(), "delayed".to_owned()),
                ],
            );
            write_h2_test_frame(
                &mut stream,
                H2_TYPE_HEADERS,
                H2_END_HEADERS | H2_END_STREAM,
                stream_id,
                &trailers,
            )
            .await
            .unwrap();
            // Grace period so the client can finish reading before close.
            async_io::Timer::after(std::time::Duration::from_millis(50)).await;
        });
    });
    Ok(format!("https://localhost:{}", addr.port()))
}
#[cfg(all(feature = "rustls", feature = "h2"))]
/// Spawns a single-connection TLS HTTP/2 server simulating a bidirectional
/// gRPC call: on the client's first DATA frame it sends response headers plus
/// the first message; once the client half-closes it records the full request,
/// sends a second message and trailers, then ends the stream.
///
/// Returns the base URL and the capture of completed requests.
async fn spawn_h2_tls_grpc_bidi_probe_server()
-> Result<(String, Arc<Mutex<Vec<H2CapturedRequest>>>)> {
    let cert = generate_simple_self_signed(vec!["localhost".into()]).map_err(|err| {
        Error::with_source(ErrorKind::Transport, "failed to generate certificate", err)
    })?;
    let cert_der = rustls::Certificate(cert.cert.der().to_vec());
    let key_der = rustls::PrivateKey(cert.signing_key.serialize_der());
    let mut server_config = ServerConfig::builder()
        .with_safe_defaults()
        .with_no_client_auth()
        .with_single_cert(vec![cert_der], key_der)
        .map_err(|err| {
            Error::with_source(
                ErrorKind::Transport,
                "failed to build h2 tls server config",
                err,
            )
        })?;
    // ALPN restricted to "h2" so the client negotiates HTTP/2.
    server_config.alpn_protocols = vec![b"h2".to_vec()];
    let acceptor = TlsAcceptor::from(Arc::new(server_config));
    let listener = TcpListener::bind(("127.0.0.1", 0)).await.map_err(|err| {
        Error::with_source(
            ErrorKind::Transport,
            "failed to bind h2 tls test server",
            err,
        )
    })?;
    let addr = listener.local_addr().map_err(|err| {
        Error::with_source(
            ErrorKind::Transport,
            "failed to inspect h2 tls test server",
            err,
        )
    })?;
    let recorded_requests = Arc::new(Mutex::new(Vec::new()));
    let task_recorded_requests = Arc::clone(&recorded_requests);
    std::thread::spawn(move || {
        futures_lite::future::block_on(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let mut stream = acceptor.accept(stream).await.unwrap();
            let mut preface = [0_u8; 24];
            stream.read_exact(&mut preface).await.unwrap();
            assert_eq!(preface.as_slice(), H2_CLIENT_PREFACE);
            write_h2_test_frame(&mut stream, H2_TYPE_SETTINGS, 0, 0, &[])
                .await
                .unwrap();
            let mut encoder = HpackEncoder::new();
            let mut decoder = HpackDecoder::new();
            let mut request_states = HashMap::<u32, H2RequestState>::new();
            // Ensures response headers + first message are sent only once,
            // triggered by the client's first DATA frame.
            let mut responded = false;
            loop {
                let (frame_type, flags, stream_id, payload) =
                    read_h2_test_frame(&mut stream).await.unwrap();
                match frame_type {
                    H2_TYPE_SETTINGS => {
                        // Acknowledge the peer's SETTINGS; never ACK an ACK.
                        if flags & H2_ACK == 0 {
                            write_h2_test_frame(&mut stream, H2_TYPE_SETTINGS, H2_ACK, 0, &[])
                                .await
                                .unwrap();
                        }
                    }
                    H2_TYPE_PING => {
                        // Echo the opaque PING payload back with the ACK flag.
                        if flags & H2_ACK == 0 {
                            write_h2_test_frame(&mut stream, H2_TYPE_PING, H2_ACK, 0, &payload)
                                .await
                                .unwrap();
                        }
                    }
                    H2_TYPE_HEADERS => {
                        let state = request_states.entry(stream_id).or_default();
                        state.header_block.extend_from_slice(&payload);
                        if flags & H2_END_HEADERS != 0 {
                            state.headers_complete = true;
                            state.headers =
                                decode_h2_test_block(&mut decoder, &state.header_block)
                                    .unwrap();
                            state.header_block.clear();
                        }
                        if flags & H2_END_STREAM != 0 {
                            state.end_stream = true;
                        }
                    }
                    H2_TYPE_CONTINUATION => {
                        let state = request_states.entry(stream_id).or_default();
                        state.header_block.extend_from_slice(&payload);
                        if flags & H2_END_HEADERS != 0 {
                            state.headers_complete = true;
                            state.headers =
                                decode_h2_test_block(&mut decoder, &state.header_block)
                                    .unwrap();
                            state.header_block.clear();
                        }
                    }
                    H2_TYPE_DATA => {
                        let state = request_states.entry(stream_id).or_default();
                        state.body.extend_from_slice(&payload);
                        // First client message: start the response stream while
                        // the request is still open (bidirectional behavior).
                        if !responded {
                            let headers = encode_h2_test_headers(
                                &mut encoder,
                                200,
                                &[
                                    (
                                        "content-type".to_owned(),
                                        "application/grpc+json".to_owned(),
                                    ),
                                    ("x-stream-bin".to_owned(), "AQI=".to_owned()),
                                ],
                            );
                            write_h2_test_frame(
                                &mut stream,
                                H2_TYPE_HEADERS,
                                H2_END_HEADERS,
                                stream_id,
                                &headers,
                            )
                            .await
                            .unwrap();
                            write_h2_test_frame(
                                &mut stream,
                                H2_TYPE_DATA,
                                0,
                                stream_id,
                                &grpc_test_json_frame(&serde_json::json!({ "seq": 1 })),
                            )
                            .await
                            .unwrap();
                            responded = true;
                        }
                        if flags & H2_END_STREAM != 0 {
                            state.end_stream = true;
                        }
                        // Client half-closed: record the request, then finish
                        // the response with a second message and trailers.
                        if state.headers_complete && state.end_stream {
                            let state = request_states.remove(&stream_id).unwrap();
                            task_recorded_requests
                                .lock()
                                .unwrap()
                                .push(H2CapturedRequest {
                                    headers: state.headers,
                                    body: state.body,
                                });
                            write_h2_test_frame(
                                &mut stream,
                                H2_TYPE_DATA,
                                0,
                                stream_id,
                                &grpc_test_json_frame(&serde_json::json!({ "seq": 2 })),
                            )
                            .await
                            .unwrap();
                            let trailers = encode_h2_test_trailers(
                                &mut encoder,
                                &[
                                    ("grpc-status".to_owned(), "0".to_owned()),
                                    ("x-trace-bin".to_owned(), "AAE=".to_owned()),
                                ],
                            );
                            write_h2_test_frame(
                                &mut stream,
                                H2_TYPE_HEADERS,
                                H2_END_HEADERS | H2_END_STREAM,
                                stream_id,
                                &trailers,
                            )
                            .await
                            .unwrap();
                            break;
                        }
                    }
                    H2_TYPE_WINDOW_UPDATE => {}
                    _ => {}
                }
            }
            // Grace period so the client can finish reading before close.
            async_io::Timer::after(std::time::Duration::from_millis(50)).await;
        });
    });
    Ok((
        format!("https://localhost:{}", addr.port()),
        recorded_requests,
    ))
}
#[cfg(all(feature = "rustls", feature = "h2"))]
/// Reads frames until one request stream is complete (END_HEADERS observed and
/// END_STREAM set), answering SETTINGS and PING along the way.
///
/// Returns the stream id plus the decoded headers and accumulated body.
/// Errors if the client sends GOAWAY before any request completes.
async fn read_complete_h2_test_request<S>(
    stream: &mut S,
    decoder: &mut HpackDecoder<'static>,
    request_states: &mut HashMap<u32, H2RequestState>,
) -> Result<(u32, H2CapturedRequest)>
where
    S: futures_lite::io::AsyncRead + futures_lite::io::AsyncWrite + Unpin,
{
    loop {
        let (frame_type, flags, stream_id, payload) = read_h2_test_frame(stream).await?;
        match frame_type {
            H2_TYPE_SETTINGS => {
                // Acknowledge the peer's SETTINGS; never ACK an ACK.
                if flags & H2_ACK == 0 {
                    write_h2_test_frame(stream, H2_TYPE_SETTINGS, H2_ACK, 0, &[]).await?;
                }
            }
            H2_TYPE_PING => {
                // Echo the opaque PING payload back with the ACK flag.
                if flags & H2_ACK == 0 {
                    write_h2_test_frame(stream, H2_TYPE_PING, H2_ACK, 0, &payload).await?;
                }
            }
            H2_TYPE_HEADERS => {
                let state = request_states.entry(stream_id).or_default();
                state.header_block.extend_from_slice(&payload);
                if flags & H2_END_HEADERS != 0 {
                    state.headers_complete = true;
                    state.headers = decode_h2_test_block(decoder, &state.header_block)?;
                    state.header_block.clear();
                }
                if flags & H2_END_STREAM != 0 {
                    state.end_stream = true;
                }
                if state.headers_complete && state.end_stream {
                    let state = request_states.remove(&stream_id).expect("request state");
                    return Ok((
                        stream_id,
                        H2CapturedRequest {
                            headers: state.headers,
                            body: state.body,
                        },
                    ));
                }
            }
            H2_TYPE_CONTINUATION => {
                // CONTINUATION extends a header block that exceeded one frame.
                let state = request_states.entry(stream_id).or_default();
                state.header_block.extend_from_slice(&payload);
                if flags & H2_END_HEADERS != 0 {
                    state.headers_complete = true;
                    state.headers = decode_h2_test_block(decoder, &state.header_block)?;
                    state.header_block.clear();
                }
                if state.headers_complete && state.end_stream {
                    let state = request_states.remove(&stream_id).expect("request state");
                    return Ok((
                        stream_id,
                        H2CapturedRequest {
                            headers: state.headers,
                            body: state.body,
                        },
                    ));
                }
            }
            H2_TYPE_DATA => {
                let state = request_states.entry(stream_id).or_default();
                state.body.extend_from_slice(&payload);
                if flags & H2_END_STREAM != 0 {
                    state.end_stream = true;
                }
                if state.headers_complete && state.end_stream {
                    let state = request_states.remove(&stream_id).expect("request state");
                    return Ok((
                        stream_id,
                        H2CapturedRequest {
                            headers: state.headers,
                            body: state.body,
                        },
                    ));
                }
            }
            // Flow control and stream resets need no action here.
            H2_TYPE_WINDOW_UPDATE => {}
            H2_TYPE_RST_STREAM => {}
            H2_TYPE_GOAWAY => {
                return Err(Error::new(
                    ErrorKind::Transport,
                    "test server received GOAWAY before request completed",
                ));
            }
            _ => {}
        }
    }
}
#[cfg(all(feature = "rustls", feature = "h2", feature = "emulation"))]
#[derive(Debug, Default)]
/// Connection-lifecycle data observed from the client by the probe server,
/// used to assert on emulated HTTP/2 fingerprints.
struct H2LifecycleCapture {
    // Decoded (identifier, value) pairs from the client's last non-ACK SETTINGS frame.
    client_settings: Option<Vec<(u16, u32)>>,
    // Increment from the client's connection-level (stream 0) WINDOW_UPDATE.
    connection_window_increment: Option<u32>,
    // PRIORITY frames as (stream id, stream dependency, weight, exclusive).
    priority_frames: Vec<(u32, u32, u16, bool)>,
    // Every received frame as (type, flags, stream id), in arrival order.
    frame_sequence: Vec<(u8, u8, u32)>,
}
#[cfg(all(feature = "rustls", feature = "h2", feature = "emulation"))]
/// Like `read_complete_h2_test_request`, but additionally records lifecycle
/// frames (SETTINGS values, connection WINDOW_UPDATE, PRIORITY, and the raw
/// frame sequence) into `capture`. After the request completes it briefly
/// drains any follow-up lifecycle frames before returning.
async fn read_complete_h2_test_request_with_capture<S>(
    stream: &mut S,
    decoder: &mut HpackDecoder<'static>,
    request_states: &mut HashMap<u32, H2RequestState>,
    capture: &mut H2LifecycleCapture,
) -> Result<(u32, H2CapturedRequest)>
where
    S: futures_lite::io::AsyncRead + futures_lite::io::AsyncWrite + Unpin,
{
    loop {
        let (frame_type, flags, stream_id, payload) = read_h2_test_frame(stream).await?;
        // Every frame is logged, even types handled below.
        capture.frame_sequence.push((frame_type, flags, stream_id));
        match frame_type {
            H2_TYPE_SETTINGS => {
                if flags & H2_ACK == 0 {
                    capture.client_settings = Some(decode_h2_test_settings(&payload));
                    write_h2_test_frame(stream, H2_TYPE_SETTINGS, H2_ACK, 0, &[]).await?;
                }
            }
            H2_TYPE_PING => {
                // Echo the opaque PING payload back with the ACK flag.
                if flags & H2_ACK == 0 {
                    write_h2_test_frame(stream, H2_TYPE_PING, H2_ACK, 0, &payload).await?;
                }
            }
            // Stream 0 carries the connection-level flow-control window.
            H2_TYPE_WINDOW_UPDATE if stream_id == 0 => {
                capture.connection_window_increment =
                    Some(decode_h2_test_window_update(&payload));
            }
            H2_TYPE_PRIORITY => {
                let (stream_dependency, weight, exclusive) = decode_h2_test_priority(&payload);
                capture
                    .priority_frames
                    .push((stream_id, stream_dependency, weight, exclusive));
            }
            H2_TYPE_HEADERS => {
                let state = request_states.entry(stream_id).or_default();
                state.header_block.extend_from_slice(&payload);
                if flags & H2_END_HEADERS != 0 {
                    state.headers_complete = true;
                    state.headers = decode_h2_test_block(decoder, &state.header_block)?;
                    state.header_block.clear();
                }
                if flags & H2_END_STREAM != 0 {
                    state.end_stream = true;
                }
                if state.headers_complete && state.end_stream {
                    let state = request_states.remove(&stream_id).expect("request state");
                    // Pick up lifecycle frames sent just after the request.
                    drain_h2_lifecycle_follow_up_frames(stream, capture).await?;
                    return Ok((
                        stream_id,
                        H2CapturedRequest {
                            headers: state.headers,
                            body: state.body,
                        },
                    ));
                }
            }
            H2_TYPE_CONTINUATION => {
                let state = request_states.entry(stream_id).or_default();
                state.header_block.extend_from_slice(&payload);
                if flags & H2_END_HEADERS != 0 {
                    state.headers_complete = true;
                    state.headers = decode_h2_test_block(decoder, &state.header_block)?;
                    state.header_block.clear();
                }
                // NOTE(review): unlike the HEADERS and DATA completion paths,
                // this path returns without draining follow-up lifecycle
                // frames — confirm whether that asymmetry is intentional.
                if state.headers_complete && state.end_stream {
                    let state = request_states.remove(&stream_id).expect("request state");
                    return Ok((
                        stream_id,
                        H2CapturedRequest {
                            headers: state.headers,
                            body: state.body,
                        },
                    ));
                }
            }
            H2_TYPE_DATA => {
                let state = request_states.entry(stream_id).or_default();
                state.body.extend_from_slice(&payload);
                if flags & H2_END_STREAM != 0 {
                    state.end_stream = true;
                }
                if state.headers_complete && state.end_stream {
                    let state = request_states.remove(&stream_id).expect("request state");
                    // Pick up lifecycle frames sent just after the request.
                    drain_h2_lifecycle_follow_up_frames(stream, capture).await?;
                    return Ok((
                        stream_id,
                        H2CapturedRequest {
                            headers: state.headers,
                            body: state.body,
                        },
                    ));
                }
            }
            // Per-stream WINDOW_UPDATEs (stream 0 handled above) and resets
            // need no action here.
            H2_TYPE_WINDOW_UPDATE => {}
            H2_TYPE_RST_STREAM => {}
            H2_TYPE_GOAWAY => {
                return Err(Error::new(
                    ErrorKind::Transport,
                    "test server received GOAWAY before request completed",
                ));
            }
            _ => {}
        }
    }
}
#[cfg(all(feature = "rustls", feature = "h2", feature = "emulation"))]
/// After a request has completed, keeps reading frames as long as each new
/// frame arrives within 10 ms, recording lifecycle data (PRIORITY, SETTINGS,
/// connection WINDOW_UPDATE) into `capture`. Returns once the window elapses
/// with no frame.
///
/// NOTE(review): if the timer wins while `read_h2_test_frame` has already
/// consumed part of a frame, those bytes are lost with the dropped future —
/// acceptable for this short-lived test helper, but not a general reader.
async fn drain_h2_lifecycle_follow_up_frames<S>(
    stream: &mut S,
    capture: &mut H2LifecycleCapture,
) -> Result<()>
where
    S: futures_lite::io::AsyncRead + futures_lite::io::AsyncWrite + Unpin,
{
    loop {
        let read_future = read_h2_test_frame(stream);
        // Race a 10 ms timer against the next frame; `None` means quiet.
        let event = futures_lite::future::or(
            async {
                async_io::Timer::after(std::time::Duration::from_millis(10)).await;
                None
            },
            async { Some(read_future.await) },
        )
        .await;
        let Some(frame) = event else {
            return Ok(());
        };
        let (frame_type, flags, stream_id, payload) = frame?;
        capture.frame_sequence.push((frame_type, flags, stream_id));
        match frame_type {
            H2_TYPE_PRIORITY => {
                let (stream_dependency, weight, exclusive) = decode_h2_test_priority(&payload);
                capture
                    .priority_frames
                    .push((stream_id, stream_dependency, weight, exclusive));
            }
            H2_TYPE_SETTINGS => {
                if flags & H2_ACK == 0 {
                    capture.client_settings = Some(decode_h2_test_settings(&payload));
                    write_h2_test_frame(stream, H2_TYPE_SETTINGS, H2_ACK, 0, &[]).await?;
                }
            }
            H2_TYPE_PING => {
                // Echo the opaque PING payload back with the ACK flag.
                if flags & H2_ACK == 0 {
                    write_h2_test_frame(stream, H2_TYPE_PING, H2_ACK, 0, &payload).await?;
                }
            }
            // Stream 0 carries the connection-level flow-control window.
            H2_TYPE_WINDOW_UPDATE if stream_id == 0 => {
                capture.connection_window_increment =
                    Some(decode_h2_test_window_update(&payload));
            }
            H2_TYPE_GOAWAY => {
                return Err(Error::new(
                    ErrorKind::Transport,
                    "test server received GOAWAY before response was sent",
                ));
            }
            _ => {}
        }
    }
}
#[cfg(all(feature = "rustls", feature = "h2"))]
#[derive(Clone, Copy)]
/// Protocol violations the error-injection test server can produce.
enum H2ProtocolErrorScenario {
    // Send a DATA frame on the stream before any response HEADERS.
    DataBeforeHeaders,
    // Send a trailing HEADERS block without the END_STREAM flag.
    TrailersWithoutEndStream,
}
#[cfg(all(feature = "rustls", feature = "h2"))]
/// Spawns a single-connection TLS HTTP/2 server that deliberately violates the
/// protocol after a request completes, so client-side error handling can be
/// exercised:
///
/// * [`H2ProtocolErrorScenario::DataBeforeHeaders`] — DATA before any HEADERS.
/// * [`H2ProtocolErrorScenario::TrailersWithoutEndStream`] — trailing HEADERS
///   without END_STREAM.
///
/// Returns the base URL and a counter of accepted connections.
async fn spawn_h2_tls_protocol_error_server(
    scenario: H2ProtocolErrorScenario,
) -> Result<(String, Arc<AtomicUsize>)> {
    let cert = generate_simple_self_signed(vec!["localhost".into()]).map_err(|err| {
        Error::with_source(ErrorKind::Transport, "failed to generate certificate", err)
    })?;
    let cert_der = rustls::Certificate(cert.cert.der().to_vec());
    let key_der = rustls::PrivateKey(cert.signing_key.serialize_der());
    let mut server_config = ServerConfig::builder()
        .with_safe_defaults()
        .with_no_client_auth()
        .with_single_cert(vec![cert_der], key_der)
        .map_err(|err| {
            Error::with_source(
                ErrorKind::Transport,
                "failed to build h2 tls server config",
                err,
            )
        })?;
    // ALPN restricted to "h2" so the client negotiates HTTP/2.
    server_config.alpn_protocols = vec![b"h2".to_vec()];
    let acceptor = TlsAcceptor::from(Arc::new(server_config));
    let listener = TcpListener::bind(("127.0.0.1", 0)).await.map_err(|err| {
        Error::with_source(
            ErrorKind::Transport,
            "failed to bind h2 tls test server",
            err,
        )
    })?;
    let addr = listener.local_addr().map_err(|err| {
        Error::with_source(
            ErrorKind::Transport,
            "failed to inspect h2 tls test server",
            err,
        )
    })?;
    let connection_count = Arc::new(AtomicUsize::new(0));
    let task_connection_count = Arc::clone(&connection_count);
    std::thread::spawn(move || {
        futures_lite::future::block_on(async move {
            let (stream, _) = listener.accept().await.unwrap();
            task_connection_count.fetch_add(1, Ordering::SeqCst);
            let mut stream = acceptor.accept(stream).await.unwrap();
            let mut preface = [0_u8; 24];
            stream.read_exact(&mut preface).await.unwrap();
            assert_eq!(preface.as_slice(), H2_CLIENT_PREFACE);
            // Every scenario shares the same handshake and request read; only
            // the faulty reply differs, so the previous outer match whose
            // single arm covered all variants has been flattened away.
            write_h2_test_frame(&mut stream, H2_TYPE_SETTINGS, 0, 0, &[])
                .await
                .unwrap();
            let mut encoder = HpackEncoder::new();
            let mut decoder = HpackDecoder::new();
            let mut request_states = HashMap::<u32, H2RequestState>::new();
            let (stream_id, _request) =
                read_complete_h2_test_request(&mut stream, &mut decoder, &mut request_states)
                    .await
                    .unwrap();
            match scenario {
                H2ProtocolErrorScenario::DataBeforeHeaders => {
                    // DATA on a stream that never saw response HEADERS.
                    write_h2_test_frame(
                        &mut stream,
                        H2_TYPE_DATA,
                        H2_END_STREAM,
                        stream_id,
                        b"oops",
                    )
                    .await
                    .unwrap();
                }
                H2ProtocolErrorScenario::TrailersWithoutEndStream => {
                    let headers = encode_h2_test_headers(&mut encoder, 200, &[]);
                    write_h2_test_frame(
                        &mut stream,
                        H2_TYPE_HEADERS,
                        H2_END_HEADERS,
                        stream_id,
                        &headers,
                    )
                    .await
                    .unwrap();
                    // Trailers must end the stream; omitting END_STREAM here is
                    // the violation under test.
                    let trailers = encode_h2_test_trailers(
                        &mut encoder,
                        &[("x-bad".to_owned(), "1".to_owned())],
                    );
                    write_h2_test_frame(
                        &mut stream,
                        H2_TYPE_HEADERS,
                        H2_END_HEADERS,
                        stream_id,
                        &trailers,
                    )
                    .await
                    .unwrap();
                }
            }
            // Grace period so the client can observe the error before close.
            async_io::Timer::after(std::time::Duration::from_millis(50)).await;
        });
    });
    Ok((
        format!("https://localhost:{}", addr.port()),
        connection_count,
    ))
}
/// Spawns a single-connection HTTPS/h2 server that answers two sequential
/// requests ("one", then "two") on the same connection, so tests can
/// assert that the client keeps and reuses a pooled h2 connection.
///
/// Returns the base URL and a counter of accepted TCP connections
/// (expected to stay at 1 when keep-alive works).
#[cfg(all(feature = "rustls", feature = "h2"))]
async fn spawn_h2_tls_keepalive_ack_server() -> Result<(String, Arc<AtomicUsize>)> {
    let cert = generate_simple_self_signed(vec!["localhost".into()]).map_err(|err| {
        Error::with_source(ErrorKind::Transport, "failed to generate certificate", err)
    })?;
    let cert_der = rustls::Certificate(cert.cert.der().to_vec());
    let key_der = rustls::PrivateKey(cert.signing_key.serialize_der());
    let mut server_config = ServerConfig::builder()
        .with_safe_defaults()
        .with_no_client_auth()
        .with_single_cert(vec![cert_der], key_der)
        .map_err(|err| {
            Error::with_source(
                ErrorKind::Transport,
                "failed to build h2 tls server config",
                err,
            )
        })?;
    // Negotiate HTTP/2 via ALPN.
    server_config.alpn_protocols = vec![b"h2".to_vec()];
    let acceptor = TlsAcceptor::from(Arc::new(server_config));
    let listener = TcpListener::bind(("127.0.0.1", 0)).await.map_err(|err| {
        Error::with_source(
            ErrorKind::Transport,
            "failed to bind h2 keepalive test server",
            err,
        )
    })?;
    let addr = listener.local_addr().map_err(|err| {
        Error::with_source(
            ErrorKind::Transport,
            "failed to inspect h2 keepalive test server",
            err,
        )
    })?;
    let connection_count = Arc::new(AtomicUsize::new(0));
    let task_connection_count = Arc::clone(&connection_count);
    // Dedicated thread + executor keeps the server independent of the test
    // runtime.
    std::thread::spawn(move || {
        futures_lite::future::block_on(async move {
            let (stream, _) = listener.accept().await.unwrap();
            task_connection_count.fetch_add(1, Ordering::SeqCst);
            let mut stream = acceptor.accept(stream).await.unwrap();
            // Client connection preface, then our (empty) SETTINGS frame.
            let mut preface = [0_u8; 24];
            stream.read_exact(&mut preface).await.unwrap();
            assert_eq!(preface.as_slice(), H2_CLIENT_PREFACE);
            write_h2_test_frame(&mut stream, H2_TYPE_SETTINGS, 0, 0, &[])
                .await
                .unwrap();
            let mut encoder = HpackEncoder::new();
            let mut decoder = HpackDecoder::new();
            let mut request_states = HashMap::<u32, H2RequestState>::new();
            // First request/response pair.
            let (first_stream_id, _) =
                read_complete_h2_test_request(&mut stream, &mut decoder, &mut request_states)
                    .await
                    .unwrap();
            send_h2_test_response(
                &mut stream,
                &mut encoder,
                first_stream_id,
                H2TestResponse::text(200, "one"),
            )
            .await
            .unwrap();
            // Second request arriving on the SAME connection proves reuse.
            let (second_stream_id, _) =
                read_complete_h2_test_request(&mut stream, &mut decoder, &mut request_states)
                    .await
                    .unwrap();
            send_h2_test_response(
                &mut stream,
                &mut encoder,
                second_stream_id,
                H2TestResponse::text(200, "two"),
            )
            .await
            .unwrap();
            // Give the client time to finish reading before the socket drops.
            async_io::Timer::after(std::time::Duration::from_millis(50)).await;
        });
    });
    Ok((
        format!("https://localhost:{}", addr.port()),
        connection_count,
    ))
}
/// Spawns an HTTPS/h2 server that answers one request, then deliberately
/// stalls (never ACKing the client's keepalive PING) and drops the first
/// connection, forcing the client to reconnect for the second request.
/// Panics if the client sends a complete request on the stale connection.
///
/// Returns the base URL and a counter of accepted TCP connections
/// (expected to reach 2: the original plus the reconnect).
#[cfg(all(feature = "rustls", feature = "h2"))]
async fn spawn_h2_tls_keepalive_timeout_server() -> Result<(String, Arc<AtomicUsize>)> {
    let cert = generate_simple_self_signed(vec!["localhost".into()]).map_err(|err| {
        Error::with_source(ErrorKind::Transport, "failed to generate certificate", err)
    })?;
    let cert_der = rustls::Certificate(cert.cert.der().to_vec());
    let key_der = rustls::PrivateKey(cert.signing_key.serialize_der());
    let mut server_config = ServerConfig::builder()
        .with_safe_defaults()
        .with_no_client_auth()
        .with_single_cert(vec![cert_der], key_der)
        .map_err(|err| {
            Error::with_source(
                ErrorKind::Transport,
                "failed to build h2 tls server config",
                err,
            )
        })?;
    server_config.alpn_protocols = vec![b"h2".to_vec()];
    let acceptor = TlsAcceptor::from(Arc::new(server_config));
    let listener = TcpListener::bind(("127.0.0.1", 0)).await.map_err(|err| {
        Error::with_source(
            ErrorKind::Transport,
            "failed to bind h2 keepalive timeout server",
            err,
        )
    })?;
    let addr = listener.local_addr().map_err(|err| {
        Error::with_source(
            ErrorKind::Transport,
            "failed to inspect h2 keepalive timeout server",
            err,
        )
    })?;
    let connection_count = Arc::new(AtomicUsize::new(0));
    let task_connection_count = Arc::clone(&connection_count);
    std::thread::spawn(move || {
        futures_lite::future::block_on(async move {
            // --- First connection: serve one response, then go silent.
            let (stream, _) = listener.accept().await.unwrap();
            task_connection_count.fetch_add(1, Ordering::SeqCst);
            let mut stream = acceptor.accept(stream).await.unwrap();
            let mut preface = [0_u8; 24];
            stream.read_exact(&mut preface).await.unwrap();
            assert_eq!(preface.as_slice(), H2_CLIENT_PREFACE);
            write_h2_test_frame(&mut stream, H2_TYPE_SETTINGS, 0, 0, &[])
                .await
                .unwrap();
            let mut encoder = HpackEncoder::new();
            let mut decoder = HpackDecoder::new();
            let mut request_states = HashMap::<u32, H2RequestState>::new();
            let (first_stream_id, _) =
                read_complete_h2_test_request(&mut stream, &mut decoder, &mut request_states)
                    .await
                    .unwrap();
            send_h2_test_response(
                &mut stream,
                &mut encoder,
                first_stream_id,
                H2TestResponse::text(200, "one"),
            )
            .await
            .unwrap();
            // Fresh request tracking for the watch loop below (shadows the
            // map used by the first request).
            let mut request_states = HashMap::<u32, H2RequestState>::new();
            loop {
                let (frame_type, flags, stream_id, payload) =
                    read_h2_test_frame(&mut stream).await.unwrap();
                match frame_type {
                    H2_TYPE_SETTINGS if flags & H2_ACK == 0 => {
                        write_h2_test_frame(&mut stream, H2_TYPE_SETTINGS, H2_ACK, 0, &[])
                            .await
                            .unwrap();
                    }
                    H2_TYPE_PING if flags & H2_ACK == 0 => {
                        // Deliberately never ACK the keepalive PING: stall past
                        // the client's keepalive timeout, then drop the
                        // connection so the client must reconnect.
                        async_io::Timer::after(std::time::Duration::from_millis(120)).await;
                        break;
                    }
                    H2_TYPE_HEADERS => {
                        // A complete request here means the client wrongly
                        // reused the timed-out connection.
                        let state = request_states.entry(stream_id).or_default();
                        state.header_block.extend_from_slice(&payload);
                        if flags & H2_END_HEADERS != 0 {
                            state.headers_complete = true;
                        }
                        if flags & H2_END_STREAM != 0 {
                            state.end_stream = true;
                        }
                        if state.headers_complete && state.end_stream {
                            panic!(
                                "h2 keepalive timeout test unexpectedly reused the first connection"
                            );
                        }
                    }
                    H2_TYPE_CONTINUATION => {
                        let state = request_states.entry(stream_id).or_default();
                        state.header_block.extend_from_slice(&payload);
                        if flags & H2_END_HEADERS != 0 {
                            state.headers_complete = true;
                        }
                        if state.headers_complete && state.end_stream {
                            panic!(
                                "h2 keepalive timeout test unexpectedly reused the first connection"
                            );
                        }
                    }
                    H2_TYPE_DATA => {
                        let state = request_states.entry(stream_id).or_default();
                        state.body.extend_from_slice(&payload);
                        if flags & H2_END_STREAM != 0 {
                            state.end_stream = true;
                        }
                        if state.headers_complete && state.end_stream {
                            panic!(
                                "h2 keepalive timeout test unexpectedly reused the first connection"
                            );
                        }
                    }
                    H2_TYPE_WINDOW_UPDATE => {}
                    H2_TYPE_GOAWAY => break,
                    _ => {}
                }
            }
            // --- Second connection: the client reconnects and gets "two".
            let (stream, _) = listener.accept().await.unwrap();
            task_connection_count.fetch_add(1, Ordering::SeqCst);
            let mut stream = acceptor.accept(stream).await.unwrap();
            let mut preface = [0_u8; 24];
            stream.read_exact(&mut preface).await.unwrap();
            assert_eq!(preface.as_slice(), H2_CLIENT_PREFACE);
            write_h2_test_frame(&mut stream, H2_TYPE_SETTINGS, 0, 0, &[])
                .await
                .unwrap();
            let mut decoder = HpackDecoder::new();
            let mut request_states = HashMap::<u32, H2RequestState>::new();
            let (second_stream_id, _) =
                read_complete_h2_test_request(&mut stream, &mut decoder, &mut request_states)
                    .await
                    .unwrap();
            // NOTE(review): `encoder` still carries HPACK state from the first
            // connection while the reconnecting client starts with a fresh
            // decoder; this assumes HpackEncoder never emits dynamic-table
            // references the peer lacks — confirm against its implementation.
            send_h2_test_response(
                &mut stream,
                &mut encoder,
                second_stream_id,
                H2TestResponse::text(200, "two"),
            )
            .await
            .unwrap();
            async_io::Timer::after(std::time::Duration::from_millis(50)).await;
        });
    });
    Ok((
        format!("https://localhost:{}", addr.port()),
        connection_count,
    ))
}
/// Process-wide counter that gives every h3 test a unique temp-file name
/// (see `next_h3_test_asset_paths`). Starts at 1.
#[cfg(feature = "h3")]
static NEXT_H3_TEST_ASSET_ID: std::sync::atomic::AtomicU64 =
    std::sync::atomic::AtomicU64::new(1);
/// Returns a process-unique, monotonically increasing id for naming h3
/// test assets. Relaxed ordering suffices: only uniqueness matters, not
/// ordering relative to other memory operations.
#[cfg(feature = "h3")]
fn next_h3_test_asset_id() -> u64 {
    NEXT_H3_TEST_ASSET_ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
}
/// Builds a unique (certificate, key) temp-file path pair for an h3 test
/// server, namespaced by process id and a per-process counter so parallel
/// test runs never collide.
#[cfg(feature = "h3")]
fn next_h3_test_asset_paths() -> (std::path::PathBuf, std::path::PathBuf) {
    let unique = next_h3_test_asset_id();
    let pid = std::process::id();
    let tmp = std::env::temp_dir();
    // Same stem for both files; only the extension differs.
    let path_for = |ext: &str| tmp.join(format!("request-h3-{pid}-{unique}.{ext}"));
    (path_for("crt"), path_for("key"))
}
/// Spawns a minimal HTTP/3-over-QUIC test server (via quiche) that answers
/// exactly one request with a 200 response carrying `body`, then shuts down.
///
/// Returns the base URL `https://localhost:<udp-port>`.
#[cfg(feature = "h3")]
async fn spawn_h3_quic_text_server(body: &'static str) -> Result<String> {
    let cert = generate_simple_self_signed(vec!["localhost".into()]).map_err(|err| {
        Error::with_source(
            ErrorKind::Transport,
            "failed to generate h3 test certificate",
            err,
        )
    })?;
    // quiche loads credentials from files, so write the self-signed PEMs to
    // unique temp paths.
    // NOTE(review): these temp files are never deleted — consider cleanup.
    let (cert_path, key_path) = next_h3_test_asset_paths();
    std::fs::write(&cert_path, cert.cert.pem()).map_err(|err| {
        Error::with_source(
            ErrorKind::Transport,
            "failed to write h3 test certificate",
            err,
        )
    })?;
    std::fs::write(&key_path, cert.signing_key.serialize_pem()).map_err(|err| {
        Error::with_source(ErrorKind::Transport, "failed to write h3 test key", err)
    })?;
    let mut config = quiche::Config::new(quiche::PROTOCOL_VERSION).map_err(|err| {
        Error::with_source(ErrorKind::Transport, "failed to build h3 test config", err)
    })?;
    config
        .load_cert_chain_from_pem_file(cert_path.to_str().unwrap())
        .map_err(|err| {
            Error::with_source(
                ErrorKind::Transport,
                "failed to load h3 test certificate",
                err,
            )
        })?;
    config
        .load_priv_key_from_pem_file(key_path.to_str().unwrap())
        .map_err(|err| {
            Error::with_source(ErrorKind::Transport, "failed to load h3 test key", err)
        })?;
    config
        .set_application_protos(quiche::h3::APPLICATION_PROTOCOL)
        .map_err(|err| {
            Error::with_source(
                ErrorKind::Transport,
                "failed to configure h3 test alpn",
                err,
            )
        })?;
    // Generous transport limits so a small test exchange never stalls on
    // flow control.
    config.set_max_idle_timeout(30_000);
    config.set_max_recv_udp_payload_size(1350);
    config.set_max_send_udp_payload_size(1350);
    config.set_initial_max_data(10_000_000);
    config.set_initial_max_stream_data_bidi_local(1_000_000);
    config.set_initial_max_stream_data_bidi_remote(1_000_000);
    config.set_initial_max_stream_data_uni(1_000_000);
    config.set_initial_max_streams_bidi(16);
    config.set_initial_max_streams_uni(16);
    config.set_disable_active_migration(true);
    // Port 0: OS-assigned UDP port, read back below.
    let socket = async_net::UdpSocket::bind(("127.0.0.1", 0))
        .await
        .map_err(|err| {
            Error::with_source(ErrorKind::Transport, "failed to bind h3 test socket", err)
        })?;
    let addr = socket.local_addr().map_err(|err| {
        Error::with_source(
            ErrorKind::Transport,
            "failed to inspect h3 test socket",
            err,
        )
    })?;
    let body = body.as_bytes().to_vec();
    // Single-threaded QUIC event loop on a dedicated thread.
    std::thread::spawn(move || {
        futures_lite::future::block_on(async move {
            let mut recv_buffer = [0_u8; 65535];
            let mut send_buffer = [0_u8; 1350];
            let h3_config = quiche::h3::Config::new().unwrap();
            let mut conn = None::<quiche::Connection>;
            let mut h3_conn = None::<quiche::h3::Connection>;
            let mut responded = false;
            loop {
                // Flush any packets quiche has queued for the peer.
                if let Some(conn) = conn.as_mut() {
                    loop {
                        match conn.send(&mut send_buffer) {
                            Ok((written, send_info)) => {
                                socket
                                    .send_to(&send_buffer[..written], send_info.to)
                                    .await
                                    .unwrap();
                            }
                            Err(quiche::Error::Done) => break,
                            Err(err) => panic!("h3 test server send failed: {err:?}"),
                        }
                    }
                }
                // One response is all this server serves; linger briefly so
                // the final packets reach the client, then exit.
                if responded {
                    async_io::Timer::after(std::time::Duration::from_millis(50)).await;
                    break;
                }
                let (len, from) = socket.recv_from(&mut recv_buffer).await.unwrap();
                let local_addr = socket.local_addr().unwrap();
                // First datagram: accept the connection with a fixed SCID
                // (fine for a one-connection test server).
                if conn.is_none() {
                    let scid_bytes = [0xAB_u8; quiche::MAX_CONN_ID_LEN];
                    let scid = quiche::ConnectionId::from_ref(&scid_bytes);
                    conn = Some(
                        quiche::accept(&scid, None, local_addr, from, &mut config).unwrap(),
                    );
                }
                let conn = conn.as_mut().unwrap();
                conn.recv(
                    &mut recv_buffer[..len],
                    quiche::RecvInfo {
                        from,
                        to: local_addr,
                    },
                )
                .unwrap();
                // Bring up the HTTP/3 layer once the QUIC handshake allows it.
                if (conn.is_in_early_data() || conn.is_established()) && h3_conn.is_none() {
                    h3_conn =
                        Some(quiche::h3::Connection::with_transport(conn, &h3_config).unwrap());
                }
                if let Some(h3_conn) = h3_conn.as_mut() {
                    loop {
                        match h3_conn.poll(conn) {
                            Ok((stream_id, quiche::h3::Event::Headers { .. })) => {
                                // Ignore any request body: stop reading the
                                // stream and answer immediately.
                                let _ =
                                    conn.stream_shutdown(stream_id, quiche::Shutdown::Read, 0);
                                let content_length = body.len().to_string();
                                let headers = vec![
                                    quiche::h3::Header::new(b":status", b"200"),
                                    quiche::h3::Header::new(
                                        b"content-length",
                                        content_length.as_bytes(),
                                    ),
                                ];
                                h3_conn
                                    .send_response(conn, stream_id, &headers, false)
                                    .unwrap();
                                let written =
                                    h3_conn.send_body(conn, stream_id, &body, true).unwrap();
                                // Flow-control windows are sized above so the
                                // whole body fits in one call.
                                assert_eq!(written, body.len());
                                responded = true;
                                break;
                            }
                            Ok((_stream_id, quiche::h3::Event::Data)) => {}
                            Ok((_stream_id, quiche::h3::Event::Finished)) => {}
                            Ok((_stream_id, quiche::h3::Event::Reset(_))) => {}
                            Ok((_flow_id, quiche::h3::Event::PriorityUpdate)) => {}
                            Ok((_goaway_id, quiche::h3::Event::GoAway)) => {}
                            Err(quiche::h3::Error::Done) => break,
                            Err(err) => panic!("h3 test server poll failed: {err:?}"),
                        }
                    }
                }
            }
        });
    });
    Ok(format!("https://localhost:{}", addr.port()))
}
/// A request observed by an h3 test server: the decoded header list plus
/// whatever body bytes the server captured.
#[cfg(feature = "h3")]
#[derive(Clone, Debug, Default)]
struct H3CapturedRequest {
    // (name, value) pairs in arrival order, including pseudo-headers.
    headers: Vec<(String, String)>,
    // Captured request body bytes (may stay empty if the capturing server
    // shuts down the read side before body data arrives).
    body: Vec<u8>,
}
#[cfg(feature = "h3")]
impl H3CapturedRequest {
    /// Returns the value of the first header whose name matches `name`
    /// (ASCII case-insensitive), or `None` if absent.
    fn header(&self, name: &str) -> Option<&str> {
        for (existing, value) in &self.headers {
            if existing.eq_ignore_ascii_case(name) {
                return Some(value.as_str());
            }
        }
        None
    }
    /// Raw request body bytes as captured by the test server.
    fn body(&self) -> &[u8] {
        self.body.as_slice()
    }
}
/// Canned response description consumed by the h3 test servers, built via
/// the fluent methods on its impl block.
///
/// NOTE(review): the server visible in this file
/// (`spawn_h3_quic_recording_server`) honors only `status`, `headers`, and
/// `body`; the remaining knobs are presumably consumed by servers defined
/// elsewhere — confirm against their implementations.
#[cfg(feature = "h3")]
#[derive(Clone, Debug, Default)]
struct H3TestResponse {
    // HTTP status code sent as the `:status` pseudo-header.
    status: u16,
    // Response headers as (name, value) pairs.
    headers: Vec<(String, String)>,
    // Trailer fields (unused by the server in this chunk).
    trailers: Vec<(String, String)>,
    // Response body bytes.
    body: Vec<u8>,
    // Delay before responding (unused by the server in this chunk).
    response_delay: std::time::Duration,
    // Misbehaviour knobs (unused by the server in this chunk).
    send_goaway: bool,
    close_connection_before_response: bool,
    close_connection_after_body: bool,
    close_connection_after_response_delay: Option<std::time::Duration>,
}
#[cfg(feature = "h3")]
impl H3TestResponse {
    /// Canned response with the given status, `body`, and a matching
    /// `content-length` header; every other knob takes its default.
    fn text(status: u16, body: impl AsRef<[u8]>) -> Self {
        let body = body.as_ref().to_vec();
        let content_length = body.len().to_string();
        Self {
            status,
            headers: vec![("content-length".to_owned(), content_length)],
            body,
            ..Self::default()
        }
    }
    /// Replaces the body and rewrites `content-length` to match.
    fn body_bytes(mut self, body: impl AsRef<[u8]>) -> Self {
        self.body = body.as_ref().to_vec();
        // Drop any existing content-length, then append the corrected one.
        let kept = std::mem::take(&mut self.headers)
            .into_iter()
            .filter(|(name, _)| !name.eq_ignore_ascii_case("content-length"));
        self.headers = kept
            .chain(std::iter::once((
                "content-length".to_owned(),
                self.body.len().to_string(),
            )))
            .collect();
        self
    }
    /// Appends a response header.
    fn header(mut self, name: &str, value: impl Into<String>) -> Self {
        self.headers.push((name.to_owned(), value.into()));
        self
    }
    /// Appends a trailer field.
    fn trailer(mut self, name: &str, value: impl Into<String>) -> Self {
        self.trailers.push((name.to_owned(), value.into()));
        self
    }
    /// Sets a delay to apply before the response is sent.
    fn delay(mut self, duration: std::time::Duration) -> Self {
        self.response_delay = duration;
        self
    }
    /// Requests that the server send a GOAWAY frame.
    fn goaway(mut self) -> Self {
        self.send_goaway = true;
        self
    }
    /// Requests that the server drop the connection before responding.
    fn close_connection_before_response(mut self) -> Self {
        self.close_connection_before_response = true;
        self
    }
    /// Requests that the server drop the connection right after the body.
    fn close_connection_after_body(mut self) -> Self {
        self.close_connection_after_body = true;
        self
    }
    /// Requests that the server drop the connection `duration` after
    /// responding.
    fn close_connection_after_response_delay(mut self, duration: std::time::Duration) -> Self {
        self.close_connection_after_response_delay = Some(duration);
        self
    }
}
/// Spawns a one-shot HTTP/3-over-QUIC test server that records the headers
/// of the single request it receives and answers it with `response`.
///
/// Only `response.status`, `response.headers`, and `response.body` are
/// honored here; the other `H3TestResponse` knobs (trailers, delays,
/// goaway, close flags) are ignored by this server. The captured request
/// `body` is always empty: Data events are discarded and the read side is
/// shut down as soon as the headers arrive.
///
/// Returns the base URL and the shared list of captured requests.
#[cfg(feature = "h3")]
async fn spawn_h3_quic_recording_server(
    response: H3TestResponse,
) -> Result<(String, Arc<Mutex<Vec<H3CapturedRequest>>>)> {
    let cert = generate_simple_self_signed(vec!["localhost".into()]).map_err(|err| {
        Error::with_source(
            ErrorKind::Transport,
            "failed to generate h3 test certificate",
            err,
        )
    })?;
    // quiche loads credentials from files; write the PEMs to unique temp
    // paths. NOTE(review): these temp files are never deleted.
    let (cert_path, key_path) = next_h3_test_asset_paths();
    std::fs::write(&cert_path, cert.cert.pem()).map_err(|err| {
        Error::with_source(
            ErrorKind::Transport,
            "failed to write h3 test certificate",
            err,
        )
    })?;
    std::fs::write(&key_path, cert.signing_key.serialize_pem()).map_err(|err| {
        Error::with_source(ErrorKind::Transport, "failed to write h3 test key", err)
    })?;
    let mut config = quiche::Config::new(quiche::PROTOCOL_VERSION).map_err(|err| {
        Error::with_source(ErrorKind::Transport, "failed to build h3 test config", err)
    })?;
    config
        .load_cert_chain_from_pem_file(cert_path.to_str().unwrap())
        .map_err(|err| {
            Error::with_source(
                ErrorKind::Transport,
                "failed to load h3 test certificate",
                err,
            )
        })?;
    config
        .load_priv_key_from_pem_file(key_path.to_str().unwrap())
        .map_err(|err| {
            Error::with_source(ErrorKind::Transport, "failed to load h3 test key", err)
        })?;
    config
        .set_application_protos(quiche::h3::APPLICATION_PROTOCOL)
        .map_err(|err| {
            Error::with_source(
                ErrorKind::Transport,
                "failed to configure h3 test alpn",
                err,
            )
        })?;
    // Generous transport limits so the small test exchange never stalls on
    // flow control.
    config.set_max_idle_timeout(30_000);
    config.set_max_recv_udp_payload_size(1350);
    config.set_max_send_udp_payload_size(1350);
    config.set_initial_max_data(10_000_000);
    config.set_initial_max_stream_data_bidi_local(1_000_000);
    config.set_initial_max_stream_data_bidi_remote(1_000_000);
    config.set_initial_max_stream_data_uni(1_000_000);
    config.set_initial_max_streams_bidi(16);
    config.set_initial_max_streams_uni(16);
    config.set_disable_active_migration(true);
    let socket = async_net::UdpSocket::bind(("127.0.0.1", 0))
        .await
        .map_err(|err| {
            Error::with_source(ErrorKind::Transport, "failed to bind h3 test socket", err)
        })?;
    let addr = socket.local_addr().map_err(|err| {
        Error::with_source(
            ErrorKind::Transport,
            "failed to inspect h3 test socket",
            err,
        )
    })?;
    let recorded_requests = Arc::new(Mutex::new(Vec::new()));
    let task_recorded_requests = Arc::clone(&recorded_requests);
    // Single-threaded QUIC event loop on a dedicated thread.
    std::thread::spawn(move || {
        futures_lite::future::block_on(async move {
            let mut recv_buffer = [0_u8; 65535];
            let mut send_buffer = [0_u8; 1350];
            let h3_config = quiche::h3::Config::new().unwrap();
            let mut conn = None::<quiche::Connection>;
            let mut h3_conn = None::<quiche::h3::Connection>;
            let mut responded = false;
            loop {
                // Flush any packets quiche has queued for the peer.
                if let Some(conn) = conn.as_mut() {
                    loop {
                        match conn.send(&mut send_buffer) {
                            Ok((written, send_info)) => {
                                socket
                                    .send_to(&send_buffer[..written], send_info.to)
                                    .await
                                    .unwrap();
                            }
                            Err(quiche::Error::Done) => break,
                            Err(err) => panic!("h3 test server send failed: {err:?}"),
                        }
                    }
                }
                // One response, then linger briefly and exit.
                if responded {
                    async_io::Timer::after(std::time::Duration::from_millis(50)).await;
                    break;
                }
                let (len, from) = socket.recv_from(&mut recv_buffer).await.unwrap();
                let local_addr = socket.local_addr().unwrap();
                // First datagram: accept the connection with a fixed SCID.
                if conn.is_none() {
                    let scid_bytes = [0xAB_u8; quiche::MAX_CONN_ID_LEN];
                    let scid = quiche::ConnectionId::from_ref(&scid_bytes);
                    conn = Some(
                        quiche::accept(&scid, None, local_addr, from, &mut config).unwrap(),
                    );
                }
                let conn = conn.as_mut().unwrap();
                conn.recv(
                    &mut recv_buffer[..len],
                    quiche::RecvInfo {
                        from,
                        to: local_addr,
                    },
                )
                .unwrap();
                // Bring up the HTTP/3 layer once the handshake allows it.
                if (conn.is_in_early_data() || conn.is_established()) && h3_conn.is_none() {
                    h3_conn =
                        Some(quiche::h3::Connection::with_transport(conn, &h3_config).unwrap());
                }
                if let Some(h3_conn) = h3_conn.as_mut() {
                    loop {
                        match h3_conn.poll(conn) {
                            Ok((stream_id, quiche::h3::Event::Headers { list, .. })) => {
                                // Record the decoded request headers.
                                let headers = list
                                    .into_iter()
                                    .map(|header| {
                                        (
                                            String::from_utf8(header.name().to_vec()).unwrap(),
                                            String::from_utf8(header.value().to_vec()).unwrap(),
                                        )
                                    })
                                    .collect::<Vec<_>>();
                                task_recorded_requests.lock().unwrap().push(
                                    H3CapturedRequest {
                                        headers,
                                        body: Vec::new(),
                                    },
                                );
                                // Stop reading the request stream and respond
                                // immediately.
                                let _ =
                                    conn.stream_shutdown(stream_id, quiche::Shutdown::Read, 0);
                                let headers = std::iter::once((
                                    ":status".to_owned(),
                                    response.status.to_string(),
                                ))
                                .chain(response.headers.iter().cloned())
                                .map(|(name, value)| {
                                    quiche::h3::Header::new(name.as_bytes(), value.as_bytes())
                                })
                                .collect::<Vec<_>>();
                                h3_conn
                                    .send_response(conn, stream_id, &headers, false)
                                    .unwrap();
                                let written = h3_conn
                                    .send_body(conn, stream_id, &response.body, true)
                                    .unwrap();
                                // Windows are sized so the whole body fits.
                                assert_eq!(written, response.body.len());
                                responded = true;
                                break;
                            }
                            Ok((_stream_id, quiche::h3::Event::Data)) => {}
                            Ok((_stream_id, quiche::h3::Event::Finished)) => {}
                            Ok((_stream_id, quiche::h3::Event::Reset(_))) => {}
                            Ok((_flow_id, quiche::h3::Event::PriorityUpdate)) => {}
                            Ok((_goaway_id, quiche::h3::Event::GoAway)) => {}
                            Err(quiche::h3::Error::Done) => break,
                            Err(err) => panic!("h3 test server poll failed: {err:?}"),
                        }
                    }
                }
            }
        });
    });
    Ok((
        format!("https://localhost:{}", addr.port()),
        recorded_requests,
    ))
}
/// Spawns an HTTPS/h2 server for testing client-side stream cancellation:
/// the first response advertises `content-length: 11` but delivers only 5
/// body bytes ("hello") without END_STREAM, leaving the stream open so the
/// client cancels it. The server records the RST_STREAM error code, then
/// answers a second request normally ("ok") on the same connection.
///
/// Returns the base URL, a counter of accepted TCP connections, and the
/// list of RST_STREAM error codes observed for the first stream.
#[cfg(all(feature = "rustls", feature = "h2"))]
async fn spawn_h2_tls_cancel_server() -> Result<(String, Arc<AtomicUsize>, Arc<Mutex<Vec<u32>>>)>
{
    let cert = generate_simple_self_signed(vec!["localhost".into()]).map_err(|err| {
        Error::with_source(ErrorKind::Transport, "failed to generate certificate", err)
    })?;
    let cert_der = rustls::Certificate(cert.cert.der().to_vec());
    let key_der = rustls::PrivateKey(cert.signing_key.serialize_der());
    let mut server_config = ServerConfig::builder()
        .with_safe_defaults()
        .with_no_client_auth()
        .with_single_cert(vec![cert_der], key_der)
        .map_err(|err| {
            Error::with_source(
                ErrorKind::Transport,
                "failed to build h2 tls server config",
                err,
            )
        })?;
    server_config.alpn_protocols = vec![b"h2".to_vec()];
    let acceptor = TlsAcceptor::from(Arc::new(server_config));
    let listener = TcpListener::bind(("127.0.0.1", 0)).await.map_err(|err| {
        Error::with_source(
            ErrorKind::Transport,
            "failed to bind h2 tls test server",
            err,
        )
    })?;
    let addr = listener.local_addr().map_err(|err| {
        Error::with_source(
            ErrorKind::Transport,
            "failed to inspect h2 tls test server",
            err,
        )
    })?;
    let connection_count = Arc::new(AtomicUsize::new(0));
    let task_connection_count = Arc::clone(&connection_count);
    let reset_codes = Arc::new(Mutex::new(Vec::new()));
    let task_reset_codes = Arc::clone(&reset_codes);
    std::thread::spawn(move || {
        futures_lite::future::block_on(async move {
            let (stream, _) = listener.accept().await.unwrap();
            task_connection_count.fetch_add(1, Ordering::SeqCst);
            let mut stream = acceptor.accept(stream).await.unwrap();
            // HTTP/2 connection preface, then our SETTINGS.
            let mut preface = [0_u8; 24];
            stream.read_exact(&mut preface).await.unwrap();
            assert_eq!(preface.as_slice(), H2_CLIENT_PREFACE);
            write_h2_test_frame(&mut stream, H2_TYPE_SETTINGS, 0, 0, &[])
                .await
                .unwrap();
            let mut encoder = HpackEncoder::new();
            let mut decoder = HpackDecoder::new();
            let mut request_states = HashMap::<u32, H2RequestState>::new();
            let (first_stream_id, _first_request) =
                read_complete_h2_test_request(&mut stream, &mut decoder, &mut request_states)
                    .await
                    .unwrap();
            // Promise 11 bytes...
            let first_headers = encode_h2_test_headers(
                &mut encoder,
                200,
                &[("content-length".to_owned(), "11".to_owned())],
            );
            write_h2_test_frame(
                &mut stream,
                H2_TYPE_HEADERS,
                H2_END_HEADERS,
                first_stream_id,
                &first_headers,
            )
            .await
            .unwrap();
            // ...but deliver only 5, with no END_STREAM: the stream stays
            // open so the client has something to cancel.
            write_h2_test_frame(&mut stream, H2_TYPE_DATA, 0, first_stream_id, b"hello")
                .await
                .unwrap();
            // Wait for the client's RST_STREAM, servicing SETTINGS/PING in
            // the meantime.
            loop {
                let (frame_type, flags, stream_id, payload) =
                    read_h2_test_frame(&mut stream).await.unwrap();
                match frame_type {
                    H2_TYPE_SETTINGS => {
                        if flags & H2_ACK == 0 {
                            write_h2_test_frame(&mut stream, H2_TYPE_SETTINGS, H2_ACK, 0, &[])
                                .await
                                .unwrap();
                        }
                    }
                    H2_TYPE_PING => {
                        if flags & H2_ACK == 0 {
                            write_h2_test_frame(&mut stream, H2_TYPE_PING, H2_ACK, 0, &payload)
                                .await
                                .unwrap();
                        }
                    }
                    H2_TYPE_WINDOW_UPDATE => {}
                    H2_TYPE_RST_STREAM if stream_id == first_stream_id => {
                        // Record the 4-byte big-endian error code for the test
                        // to assert on.
                        task_reset_codes.lock().unwrap().push(u32::from_be_bytes([
                            payload[0], payload[1], payload[2], payload[3],
                        ]));
                        break;
                    }
                    _ => {}
                }
            }
            // The connection must still be usable after cancellation: serve a
            // second request normally.
            let (second_stream_id, _second_request) =
                read_complete_h2_test_request(&mut stream, &mut decoder, &mut request_states)
                    .await
                    .unwrap();
            send_h2_test_response(
                &mut stream,
                &mut encoder,
                second_stream_id,
                H2TestResponse::text(200, "ok"),
            )
            .await
            .unwrap();
            async_io::Timer::after(std::time::Duration::from_millis(50)).await;
        });
    });
    Ok((
        format!("https://localhost:{}", addr.port()),
        connection_count,
        reset_codes,
    ))
}
/// Spawns an HTTPS/h2 server that sends a complete early response ("done")
/// as soon as the FIRST DATA frame of the client's upload arrives — i.e.
/// mid-upload — then waits for the client's RST_STREAM, recording its
/// error code and whatever part of the request was received. A second
/// request on the same connection is then answered normally ("ok").
///
/// Returns the base URL, a connection counter, the RST_STREAM error codes
/// observed for the first stream, and the captured requests (the partial
/// first upload, then the complete second request).
#[cfg(all(feature = "rustls", feature = "h2"))]
async fn spawn_h2_tls_early_response_server() -> Result<(
    String,
    Arc<AtomicUsize>,
    Arc<Mutex<Vec<u32>>>,
    Arc<Mutex<Vec<H2CapturedRequest>>>,
)> {
    let cert = generate_simple_self_signed(vec!["localhost".into()]).map_err(|err| {
        Error::with_source(ErrorKind::Transport, "failed to generate certificate", err)
    })?;
    let cert_der = rustls::Certificate(cert.cert.der().to_vec());
    let key_der = rustls::PrivateKey(cert.signing_key.serialize_der());
    let mut server_config = ServerConfig::builder()
        .with_safe_defaults()
        .with_no_client_auth()
        .with_single_cert(vec![cert_der], key_der)
        .map_err(|err| {
            Error::with_source(
                ErrorKind::Transport,
                "failed to build h2 tls server config",
                err,
            )
        })?;
    server_config.alpn_protocols = vec![b"h2".to_vec()];
    let acceptor = TlsAcceptor::from(Arc::new(server_config));
    let listener = TcpListener::bind(("127.0.0.1", 0)).await.map_err(|err| {
        Error::with_source(
            ErrorKind::Transport,
            "failed to bind h2 tls test server",
            err,
        )
    })?;
    let addr = listener.local_addr().map_err(|err| {
        Error::with_source(
            ErrorKind::Transport,
            "failed to inspect h2 tls test server",
            err,
        )
    })?;
    let connection_count = Arc::new(AtomicUsize::new(0));
    let task_connection_count = Arc::clone(&connection_count);
    let reset_codes = Arc::new(Mutex::new(Vec::new()));
    let task_reset_codes = Arc::clone(&reset_codes);
    let recorded_requests = Arc::new(Mutex::new(Vec::new()));
    let task_recorded_requests = Arc::clone(&recorded_requests);
    std::thread::spawn(move || {
        futures_lite::future::block_on(async move {
            let (stream, _) = listener.accept().await.unwrap();
            task_connection_count.fetch_add(1, Ordering::SeqCst);
            let mut stream = acceptor.accept(stream).await.unwrap();
            // HTTP/2 connection preface, then our SETTINGS.
            let mut preface = [0_u8; 24];
            stream.read_exact(&mut preface).await.unwrap();
            assert_eq!(preface.as_slice(), H2_CLIENT_PREFACE);
            write_h2_test_frame(&mut stream, H2_TYPE_SETTINGS, 0, 0, &[])
                .await
                .unwrap();
            let mut encoder = HpackEncoder::new();
            let mut decoder = HpackDecoder::new();
            let mut request_states = HashMap::<u32, H2RequestState>::new();
            let mut first_stream_id = None;
            let mut responded = false;
            // Frame loop for the first request: accumulate headers/body,
            // respond early on the first DATA frame, then wait for the
            // client's RST_STREAM.
            loop {
                let (frame_type, flags, stream_id, payload) =
                    read_h2_test_frame(&mut stream).await.unwrap();
                match frame_type {
                    H2_TYPE_SETTINGS => {
                        if flags & H2_ACK == 0 {
                            write_h2_test_frame(&mut stream, H2_TYPE_SETTINGS, H2_ACK, 0, &[])
                                .await
                                .unwrap();
                        }
                    }
                    H2_TYPE_PING => {
                        if flags & H2_ACK == 0 {
                            write_h2_test_frame(&mut stream, H2_TYPE_PING, H2_ACK, 0, &payload)
                                .await
                                .unwrap();
                        }
                    }
                    H2_TYPE_HEADERS => {
                        let state = request_states.entry(stream_id).or_default();
                        // Remember the first stream so the RST_STREAM arm
                        // below only matches the stream under test.
                        first_stream_id.get_or_insert(stream_id);
                        state.header_block.extend_from_slice(&payload);
                        if flags & H2_END_HEADERS != 0 {
                            state.headers_complete = true;
                            state.headers =
                                decode_h2_test_block(&mut decoder, &state.header_block)
                                    .unwrap();
                            state.header_block.clear();
                        }
                        if flags & H2_END_STREAM != 0 {
                            state.end_stream = true;
                        }
                    }
                    H2_TYPE_CONTINUATION => {
                        let state = request_states.entry(stream_id).or_default();
                        state.header_block.extend_from_slice(&payload);
                        if flags & H2_END_HEADERS != 0 {
                            state.headers_complete = true;
                            state.headers =
                                decode_h2_test_block(&mut decoder, &state.header_block)
                                    .unwrap();
                            state.header_block.clear();
                        }
                    }
                    H2_TYPE_DATA => {
                        let state = request_states.entry(stream_id).or_default();
                        state.body.extend_from_slice(&payload);
                        // Early response: answer as soon as the first body
                        // bytes arrive, while the client is still uploading.
                        if !responded {
                            let headers = encode_h2_test_headers(
                                &mut encoder,
                                200,
                                &[("content-length".to_owned(), "4".to_owned())],
                            );
                            write_h2_test_frame(
                                &mut stream,
                                H2_TYPE_HEADERS,
                                H2_END_HEADERS,
                                stream_id,
                                &headers,
                            )
                            .await
                            .unwrap();
                            write_h2_test_frame(
                                &mut stream,
                                H2_TYPE_DATA,
                                H2_END_STREAM,
                                stream_id,
                                b"done",
                            )
                            .await
                            .unwrap();
                            responded = true;
                        }
                        if flags & H2_END_STREAM != 0 {
                            state.end_stream = true;
                        }
                    }
                    H2_TYPE_WINDOW_UPDATE => {}
                    H2_TYPE_RST_STREAM if Some(stream_id) == first_stream_id => {
                        // Record the error code and whatever portion of the
                        // upload made it to the server.
                        task_reset_codes.lock().unwrap().push(u32::from_be_bytes([
                            payload[0], payload[1], payload[2], payload[3],
                        ]));
                        let state = request_states.remove(&stream_id).unwrap_or_default();
                        task_recorded_requests
                            .lock()
                            .unwrap()
                            .push(H2CapturedRequest {
                                headers: state.headers,
                                body: state.body,
                            });
                        break;
                    }
                    _ => {}
                }
            }
            // The connection must remain usable: serve a second request
            // normally and record it.
            let (second_stream_id, second_request) =
                read_complete_h2_test_request(&mut stream, &mut decoder, &mut request_states)
                    .await
                    .unwrap();
            task_recorded_requests.lock().unwrap().push(second_request);
            send_h2_test_response(
                &mut stream,
                &mut encoder,
                second_stream_id,
                H2TestResponse::text(200, "ok"),
            )
            .await
            .unwrap();
            async_io::Timer::after(std::time::Duration::from_millis(50)).await;
        });
    });
    Ok((
        format!("https://localhost:{}", addr.port()),
        connection_count,
        reset_codes,
        recorded_requests,
    ))
}
/// Spawns an HTTPS/h2 server for testing upload failures: it accepts the
/// first request's frames but never responds, waiting until the client
/// aborts with RST_STREAM. The error code and the partial request are
/// recorded, then a second request on the same connection is answered
/// normally ("ok").
///
/// Returns the base URL, a connection counter, the RST_STREAM error codes
/// observed for the first stream, and the captured requests (the aborted
/// first upload, then the complete second request).
#[cfg(all(feature = "rustls", feature = "h2"))]
async fn spawn_h2_tls_upload_error_server() -> Result<(
    String,
    Arc<AtomicUsize>,
    Arc<Mutex<Vec<u32>>>,
    Arc<Mutex<Vec<H2CapturedRequest>>>,
)> {
    let cert = generate_simple_self_signed(vec!["localhost".into()]).map_err(|err| {
        Error::with_source(ErrorKind::Transport, "failed to generate certificate", err)
    })?;
    let cert_der = rustls::Certificate(cert.cert.der().to_vec());
    let key_der = rustls::PrivateKey(cert.signing_key.serialize_der());
    let mut server_config = ServerConfig::builder()
        .with_safe_defaults()
        .with_no_client_auth()
        .with_single_cert(vec![cert_der], key_der)
        .map_err(|err| {
            Error::with_source(
                ErrorKind::Transport,
                "failed to build h2 tls server config",
                err,
            )
        })?;
    server_config.alpn_protocols = vec![b"h2".to_vec()];
    let acceptor = TlsAcceptor::from(Arc::new(server_config));
    let listener = TcpListener::bind(("127.0.0.1", 0)).await.map_err(|err| {
        Error::with_source(
            ErrorKind::Transport,
            "failed to bind h2 tls test server",
            err,
        )
    })?;
    let addr = listener.local_addr().map_err(|err| {
        Error::with_source(
            ErrorKind::Transport,
            "failed to inspect h2 tls test server",
            err,
        )
    })?;
    let connection_count = Arc::new(AtomicUsize::new(0));
    let task_connection_count = Arc::clone(&connection_count);
    let reset_codes = Arc::new(Mutex::new(Vec::new()));
    let task_reset_codes = Arc::clone(&reset_codes);
    let recorded_requests = Arc::new(Mutex::new(Vec::new()));
    let task_recorded_requests = Arc::clone(&recorded_requests);
    std::thread::spawn(move || {
        futures_lite::future::block_on(async move {
            let (stream, _) = listener.accept().await.unwrap();
            task_connection_count.fetch_add(1, Ordering::SeqCst);
            let mut stream = acceptor.accept(stream).await.unwrap();
            // HTTP/2 connection preface, then our SETTINGS.
            let mut preface = [0_u8; 24];
            stream.read_exact(&mut preface).await.unwrap();
            assert_eq!(preface.as_slice(), H2_CLIENT_PREFACE);
            write_h2_test_frame(&mut stream, H2_TYPE_SETTINGS, 0, 0, &[])
                .await
                .unwrap();
            let mut encoder = HpackEncoder::new();
            let mut decoder = HpackDecoder::new();
            let mut request_states = HashMap::<u32, H2RequestState>::new();
            let mut first_stream_id = None;
            // Frame loop for the first request: accumulate headers/body but
            // NEVER respond — wait for the client to abort via RST_STREAM.
            loop {
                let (frame_type, flags, stream_id, payload) =
                    read_h2_test_frame(&mut stream).await.unwrap();
                match frame_type {
                    H2_TYPE_SETTINGS => {
                        if flags & H2_ACK == 0 {
                            write_h2_test_frame(&mut stream, H2_TYPE_SETTINGS, H2_ACK, 0, &[])
                                .await
                                .unwrap();
                        }
                    }
                    H2_TYPE_PING => {
                        if flags & H2_ACK == 0 {
                            write_h2_test_frame(&mut stream, H2_TYPE_PING, H2_ACK, 0, &payload)
                                .await
                                .unwrap();
                        }
                    }
                    H2_TYPE_HEADERS => {
                        let state = request_states.entry(stream_id).or_default();
                        // Remember the first stream so the RST_STREAM arm
                        // below only matches the stream under test.
                        first_stream_id.get_or_insert(stream_id);
                        state.header_block.extend_from_slice(&payload);
                        if flags & H2_END_HEADERS != 0 {
                            state.headers_complete = true;
                            state.headers =
                                decode_h2_test_block(&mut decoder, &state.header_block)
                                    .unwrap();
                            state.header_block.clear();
                        }
                        if flags & H2_END_STREAM != 0 {
                            state.end_stream = true;
                        }
                    }
                    H2_TYPE_CONTINUATION => {
                        let state = request_states.entry(stream_id).or_default();
                        state.header_block.extend_from_slice(&payload);
                        if flags & H2_END_HEADERS != 0 {
                            state.headers_complete = true;
                            state.headers =
                                decode_h2_test_block(&mut decoder, &state.header_block)
                                    .unwrap();
                            state.header_block.clear();
                        }
                    }
                    H2_TYPE_DATA => {
                        let state = request_states.entry(stream_id).or_default();
                        state.body.extend_from_slice(&payload);
                        if flags & H2_END_STREAM != 0 {
                            state.end_stream = true;
                        }
                    }
                    H2_TYPE_WINDOW_UPDATE => {}
                    H2_TYPE_RST_STREAM if Some(stream_id) == first_stream_id => {
                        // Record the error code and the partial upload.
                        task_reset_codes.lock().unwrap().push(u32::from_be_bytes([
                            payload[0], payload[1], payload[2], payload[3],
                        ]));
                        let state = request_states.remove(&stream_id).unwrap_or_default();
                        task_recorded_requests
                            .lock()
                            .unwrap()
                            .push(H2CapturedRequest {
                                headers: state.headers,
                                body: state.body,
                            });
                        break;
                    }
                    _ => {}
                }
            }
            // The connection must remain usable: serve a second request
            // normally and record it.
            let (second_stream_id, second_request) =
                read_complete_h2_test_request(&mut stream, &mut decoder, &mut request_states)
                    .await
                    .unwrap();
            task_recorded_requests.lock().unwrap().push(second_request);
            send_h2_test_response(
                &mut stream,
                &mut encoder,
                second_stream_id,
                H2TestResponse::text(200, "ok"),
            )
            .await
            .unwrap();
            async_io::Timer::after(std::time::Duration::from_millis(50)).await;
        });
    });
    Ok((
        format!("https://localhost:{}", addr.port()),
        connection_count,
        reset_codes,
        recorded_requests,
    ))
}
#[test]
fn performs_plain_http_request() {
    // Serve a canned HTTP/1.1 response and confirm a plain GET reads its body.
    let raw = "HTTP/1.1 200 OK\r\nContent-Length: 5\r\nX-Test: ok\r\n\r\nhello";
    let base = block_on(spawn_test_server(raw)).unwrap();
    let response = run(get(format!("{base}/users"))).unwrap();
    assert_eq!(block_on(response.text()).unwrap(), "hello");
}
#[test]
fn reports_download_progress_on_request() {
    // A response carrying Content-Length must surface a completed Download
    // progress event whose total matches the advertised length.
    let base = block_on(spawn_test_server(
        "HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello",
    ))
    .unwrap();
    let events = Arc::new(Mutex::new(Vec::new()));
    let sink = Arc::clone(&events);
    let request = get(format!("{base}/progress")).on_progress(
        move |progress| sink.lock().unwrap().push(progress),
        ProgressConfig::default(),
    );
    let response = run(request).unwrap();
    assert_eq!(block_on(response.text()).unwrap(), "hello");
    let seen = events.lock().unwrap();
    let completed = seen.iter().any(|event| {
        event.phase() == crate::ProgressPhase::Download
            && event.transferred() == 5
            && event.total() == Some(5)
            && event.is_done()
    });
    assert!(completed);
}
#[test]
fn reports_upload_progress_on_request() {
    // The server asserts the request carried a 5-byte body; the client side
    // must emit a completed Upload progress event for those 5 bytes.
    let base = block_on(spawn_assert_server(
        |request| {
            assert!(request.contains("\r\ncontent-length: 5\r\n"));
            assert!(request.ends_with("hello"));
        },
        "HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok",
    ))
    .unwrap();
    let events = Arc::new(Mutex::new(Vec::new()));
    let sink = Arc::clone(&events);
    let request = get(format!("{base}/upload")).body("hello").on_progress(
        move |progress| sink.lock().unwrap().push(progress),
        ProgressConfig::default(),
    );
    assert_eq!(block_on(run(request).unwrap().text()).unwrap(), "ok");
    let seen = events.lock().unwrap();
    let uploaded = seen.iter().any(|event| {
        event.phase() == crate::ProgressPhase::Upload
            && event.transferred() == 5
            && event.total() == Some(5)
            && event.is_done()
    });
    assert!(uploaded);
}
#[test]
fn request_progress_callback_overrides_client_default() {
    // A per-request progress callback must replace, not supplement, the
    // callback configured on the client builder.
    let base = block_on(spawn_test_server(
        "HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello",
    ))
    .unwrap();
    let client_events = Arc::new(Mutex::new(Vec::new()));
    let request_events = Arc::new(Mutex::new(Vec::new()));
    let client_sink = Arc::clone(&client_events);
    let request_sink = Arc::clone(&request_events);
    let client = crate::Client::builder()
        .on_progress(
            move |progress| client_sink.lock().unwrap().push(progress),
            ProgressConfig::default(),
        )
        .build()
        .unwrap();
    let request = client.get(format!("{base}/override")).on_progress(
        move |progress| request_sink.lock().unwrap().push(progress),
        ProgressConfig::default(),
    );
    let response = run(request).unwrap();
    assert_eq!(block_on(response.text()).unwrap(), "hello");
    // Only the request-level callback may have observed events.
    assert!(client_events.lock().unwrap().is_empty());
    assert!(!request_events.lock().unwrap().is_empty());
}
#[test]
fn reports_download_progress_without_content_length() {
    // Without a Content-Length header, progress events still fire but the
    // total must be unknown (None) while transferred bytes are tracked.
    let base = block_on(spawn_test_server("HTTP/1.1 200 OK\r\n\r\nhello")).unwrap();
    let events = Arc::new(Mutex::new(Vec::new()));
    let sink = Arc::clone(&events);
    let request = get(format!("{base}/streaming")).on_progress(
        move |progress| sink.lock().unwrap().push(progress),
        ProgressConfig::default(),
    );
    assert_eq!(block_on(run(request).unwrap().text()).unwrap(), "hello");
    let seen = events.lock().unwrap();
    let completed = seen.iter().any(|event| {
        event.phase() == crate::ProgressPhase::Download
            && event.transferred() == 5
            && event.total().is_none()
            && event.is_done()
    });
    assert!(completed);
}
#[test]
fn redirect_keeps_progress_callback_active() {
    // The progress callback installed on the original request must keep
    // firing after the client follows a 302 to a second server.
    let target = block_on(spawn_test_server(
        "HTTP/1.1 200 OK\r\nContent-Length: 7\r\n\r\ntarget!",
    ))
    .unwrap();
    // spawn_test_server takes a 'static str, and the redirect response embeds
    // the target's randomly chosen port, so the string is deliberately leaked.
    let redirect_body: &'static str = Box::leak(
        format!("HTTP/1.1 302 Found\r\nLocation: {target}/done\r\nContent-Length: 0\r\n\r\n")
            .into_boxed_str(),
    );
    let redirect = block_on(spawn_test_server(redirect_body)).unwrap();
    let events = Arc::new(Mutex::new(Vec::new()));
    let sink = Arc::clone(&events);
    let request = get(format!("{redirect}/start")).on_progress(
        move |progress| sink.lock().unwrap().push(progress),
        ProgressConfig::default(),
    );
    assert_eq!(block_on(run(request).unwrap().text()).unwrap(), "target!");
    let seen = events.lock().unwrap();
    let completed = seen.iter().any(|event| {
        event.phase() == crate::ProgressPhase::Download
            && event.transferred() == 7
            && event.total() == Some(7)
            && event.is_done()
    });
    assert!(completed);
}
#[test]
fn retries_idempotent_request_after_timeout() {
    // The first response stalls past the 1ms read timeout; GET is idempotent,
    // so the one-retry budget kicks in and the second, instant response wins.
    let responses = vec![
        (
            "HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok",
            std::time::Duration::from_millis(50),
        ),
        (
            "HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok",
            std::time::Duration::from_millis(0),
        ),
    ];
    let base = block_on(spawn_timed_sequence_server(responses)).unwrap();
    let request = get(format!("{base}/retry"))
        .read_timeout(std::time::Duration::from_millis(1))
        .retry(crate::RetryPolicy::Limit(1));
    let response = run(request).unwrap();
    assert_eq!(block_on(response.text()).unwrap(), "ok");
}
#[test]
fn does_not_retry_non_idempotent_post_by_default() {
    // The first response stalls past the read timeout; because POST is not
    // idempotent the retry budget must stay unused, so the request surfaces
    // a Timeout error instead of the second scripted response.
    let responses = vec![
        (
            "HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok",
            std::time::Duration::from_millis(50),
        ),
        (
            "HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok",
            std::time::Duration::from_millis(0),
        ),
    ];
    let base = block_on(spawn_timed_sequence_server(responses)).unwrap();
    let request = crate::post(format!("{base}/retry"))
        .body("hello")
        .read_timeout(std::time::Duration::from_millis(10))
        .retry(crate::RetryPolicy::Limit(1));
    let result = run(request);
    assert!(matches!(result, Err(err) if err.kind() == &crate::ErrorKind::Timeout));
}
#[test]
fn request_retry_policy_overrides_client_default() {
    // The client disables retries, but the request opts back in with a
    // one-retry budget; the per-request policy must win, so the stalled
    // first attempt is retried and the instant second attempt succeeds.
    let responses = vec![
        (
            "HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok",
            std::time::Duration::from_millis(50),
        ),
        (
            "HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok",
            std::time::Duration::from_millis(0),
        ),
    ];
    let base = block_on(spawn_timed_sequence_server(responses)).unwrap();
    let client = crate::Client::builder()
        .retry(crate::RetryPolicy::None)
        .build()
        .unwrap();
    let request = client
        .get(format!("{base}/retry"))
        .read_timeout(std::time::Duration::from_millis(10))
        .retry(crate::RetryPolicy::Limit(1));
    let response = run(request).unwrap();
    assert_eq!(block_on(response.text()).unwrap(), "ok");
}
#[test]
fn reuses_keep_alive_connection_for_same_client() {
    // Two sequential requests on one client against a keep-alive server
    // should both succeed (the second can ride the pooled connection).
    let base = block_on(spawn_keep_alive_server()).unwrap();
    let client = crate::Client::builder().build().unwrap();
    for path in ["one", "two"] {
        let response = run(client.get(format!("{base}/{path}"))).unwrap();
        assert_eq!(block_on(response.text()).unwrap(), "ok");
    }
}
#[test]
fn max_idle_per_host_zero_disables_connection_reuse() {
    // With an idle pool capacity of zero the client must dial a new
    // connection per request, so the server should count two connections.
    let (base, connection_count) = block_on(spawn_connection_count_server(
        "HTTP/1.1 200 OK\r\nContent-Length: 2\r\nConnection: keep-alive\r\n\r\nok",
        2,
    ))
    .unwrap();
    let client = crate::Client::builder()
        .max_idle_per_host(0)
        .build()
        .unwrap();
    // Issue both requests before draining either body (same ordering as the
    // reuse-enabled tests) so only the pool setting differs.
    let first_body = run(client.get(format!("{base}/one"))).unwrap().text();
    let second_body = run(client.get(format!("{base}/two"))).unwrap().text();
    assert_eq!(block_on(first_body).unwrap(), "ok");
    assert_eq!(block_on(second_body).unwrap(), "ok");
    assert_eq!(connection_count.load(Ordering::SeqCst), 2);
}
#[test]
fn connection_close_request_header_disables_reuse() {
    // An explicit `connection: close` request header must keep the first
    // connection out of the pool, forcing a second dial.
    let (base, connection_count) = block_on(spawn_connection_count_server(
        "HTTP/1.1 200 OK\r\nContent-Length: 2\r\nConnection: keep-alive\r\n\r\nok",
        2,
    ))
    .unwrap();
    let client = crate::Client::builder().build().unwrap();
    let first_request = client
        .get(format!("{base}/one"))
        .header("connection", "close")
        .unwrap();
    // Fire both requests first, then drain both bodies.
    let first_body = run(first_request).unwrap().text();
    let second_body = run(client.get(format!("{base}/two"))).unwrap().text();
    assert_eq!(block_on(first_body).unwrap(), "ok");
    assert_eq!(block_on(second_body).unwrap(), "ok");
    assert_eq!(connection_count.load(Ordering::SeqCst), 2);
}
#[test]
fn connection_close_response_header_disables_reuse() {
    // A `Connection: close` RESPONSE header must keep the first connection
    // out of the pool, so the second request opens a new one.
    let (base, connection_count) = block_on(spawn_connection_count_server(
        "HTTP/1.1 200 OK\r\nContent-Length: 2\r\nConnection: close\r\n\r\nok",
        2,
    ))
    .unwrap();
    let client = crate::Client::builder().build().unwrap();
    // Fire both requests first, then drain both bodies.
    let first_body = run(client.get(format!("{base}/one"))).unwrap().text();
    let second_body = run(client.get(format!("{base}/two"))).unwrap().text();
    assert_eq!(block_on(first_body).unwrap(), "ok");
    assert_eq!(block_on(second_body).unwrap(), "ok");
    assert_eq!(connection_count.load(Ordering::SeqCst), 2);
}
#[test]
fn idle_timeout_expiry_prevents_reuse() {
    // The pooled connection's 1ms idle timeout elapses during the 10ms
    // sleep, so the second request must dial a fresh connection.
    let (base, connection_count) = block_on(spawn_connection_count_server(
        "HTTP/1.1 200 OK\r\nContent-Length: 2\r\nConnection: keep-alive\r\n\r\nok",
        2,
    ))
    .unwrap();
    let client = crate::Client::builder()
        .idle_timeout(std::time::Duration::from_millis(1))
        .build()
        .unwrap();
    let first_body = run(client.get(format!("{base}/one"))).unwrap().text();
    assert_eq!(block_on(first_body).unwrap(), "ok");
    std::thread::sleep(std::time::Duration::from_millis(10));
    let second_body = run(client.get(format!("{base}/two"))).unwrap().text();
    assert_eq!(block_on(second_body).unwrap(), "ok");
    assert_eq!(connection_count.load(Ordering::SeqCst), 2);
}
#[test]
fn different_clients_do_not_share_connection_pool() {
    // Connection pools are per-client: a second client must open its own
    // connection even though the first client left a reusable one idle.
    let (base, connection_count) = block_on(spawn_connection_count_server(
        "HTTP/1.1 200 OK\r\nContent-Length: 2\r\nConnection: keep-alive\r\n\r\nok",
        2,
    ))
    .unwrap();
    let client_a = crate::Client::builder().build().unwrap();
    let client_b = crate::Client::builder().build().unwrap();
    // Fire one request per client, then drain both bodies.
    let first_body = run(client_a.get(format!("{base}/one"))).unwrap().text();
    let second_body = run(client_b.get(format!("{base}/two"))).unwrap().text();
    assert_eq!(block_on(first_body).unwrap(), "ok");
    assert_eq!(block_on(second_body).unwrap(), "ok");
    assert_eq!(connection_count.load(Ordering::SeqCst), 2);
}
#[cfg(feature = "h3")]
/// Builds a quiche server config for h3 tests using the default cap of 16
/// concurrent client-initiated bidirectional streams.
fn build_h3_test_server_config() -> Result<quiche::Config> {
    build_h3_test_server_config_with_max_streams_bidi(16)
}
#[cfg(feature = "h3")]
/// Builds a quiche server config for h3 tests: generates a fresh self-signed
/// certificate for "localhost", writes the certificate and key to per-test
/// asset paths (quiche loads credentials from PEM files), advertises the
/// HTTP/3 ALPN, and applies generous loopback-sized transport limits.
///
/// `initial_max_streams_bidi` caps how many client-initiated bidirectional
/// streams the server allows, letting tests exercise stream-limit behavior.
fn build_h3_test_server_config_with_max_streams_bidi(
    initial_max_streams_bidi: u64,
) -> Result<quiche::Config> {
    let cert = generate_simple_self_signed(vec!["localhost".into()]).map_err(|err| {
        Error::with_source(
            ErrorKind::Transport,
            "failed to generate h3 test certificate",
            err,
        )
    })?;
    // quiche only loads credentials from files, so persist both PEM halves.
    let (cert_path, key_path) = next_h3_test_asset_paths();
    std::fs::write(&cert_path, cert.cert.pem()).map_err(|err| {
        Error::with_source(
            ErrorKind::Transport,
            "failed to write h3 test certificate",
            err,
        )
    })?;
    std::fs::write(&key_path, cert.signing_key.serialize_pem()).map_err(|err| {
        Error::with_source(ErrorKind::Transport, "failed to write h3 test key", err)
    })?;
    let mut config = quiche::Config::new(quiche::PROTOCOL_VERSION).map_err(|err| {
        Error::with_source(ErrorKind::Transport, "failed to build h3 test config", err)
    })?;
    config
        .load_cert_chain_from_pem_file(cert_path.to_str().unwrap())
        .map_err(|err| {
            Error::with_source(
                ErrorKind::Transport,
                "failed to load h3 test certificate",
                err,
            )
        })?;
    config
        .load_priv_key_from_pem_file(key_path.to_str().unwrap())
        .map_err(|err| {
            Error::with_source(ErrorKind::Transport, "failed to load h3 test key", err)
        })?;
    // Advertise the HTTP/3 ALPN so the handshake negotiates h3.
    config
        .set_application_protos(quiche::h3::APPLICATION_PROTOCOL)
        .map_err(|err| {
            Error::with_source(
                ErrorKind::Transport,
                "failed to configure h3 test alpn",
                err,
            )
        })?;
    // Generous flow-control limits for loopback tests; UDP payload sizes stay
    // under a typical 1500-byte MTU.
    config.set_max_idle_timeout(30_000);
    config.set_max_recv_udp_payload_size(1350);
    config.set_max_send_udp_payload_size(1350);
    config.set_initial_max_data(10_000_000);
    config.set_initial_max_stream_data_bidi_local(1_000_000);
    config.set_initial_max_stream_data_bidi_remote(1_000_000);
    config.set_initial_max_stream_data_uni(1_000_000);
    config.set_initial_max_streams_bidi(initial_max_streams_bidi);
    config.set_initial_max_streams_uni(16);
    config.set_disable_active_migration(true);
    Ok(config)
}
#[cfg(feature = "h3")]
/// Drains every pending outbound QUIC packet from `conn` onto `socket`,
/// returning once quiche reports nothing left to send. Panics on any send
/// error other than `Done`, since a test server has no recovery path.
async fn flush_h3_test_connection(
    socket: &async_net::UdpSocket,
    conn: &mut quiche::Connection,
    send_buffer: &mut [u8],
) {
    loop {
        let (written, send_info) = match conn.send(send_buffer) {
            Ok(packet) => packet,
            // Done: the connection has no more datagrams queued.
            Err(quiche::Error::Done) => return,
            Err(err) => panic!("h3 test server send failed: {err:?}"),
        };
        socket
            .send_to(&send_buffer[..written], send_info.to)
            .await
            .unwrap();
    }
}
#[cfg(feature = "h3")]
/// Writes one scripted `H3TestResponse` onto `stream_id`, flushing queued
/// QUIC packets after every HTTP/3 frame so the client observes them
/// promptly.
///
/// The response's flags drive several fault-injection paths: delaying the
/// response, closing the connection before the response or right after the
/// body, appending trailers, sending GOAWAY, and closing after an extra
/// post-response delay.
async fn send_h3_test_response(
    socket: &async_net::UdpSocket,
    conn: &mut quiche::Connection,
    h3_conn: &mut quiche::h3::Connection,
    stream_id: u64,
    response: &H3TestResponse,
    send_buffer: &mut [u8],
) {
    if !response.response_delay.is_zero() {
        async_io::Timer::after(response.response_delay).await;
    }
    if response.close_connection_before_response {
        // Abort the whole connection without ever answering the request.
        conn.close(false, 0, b"scripted close").unwrap();
        flush_h3_test_connection(socket, conn, send_buffer).await;
        return;
    }
    // The :status pseudo-header must precede the regular response headers.
    let headers = std::iter::once((":status".to_owned(), response.status.to_string()))
        .chain(response.headers.iter().cloned())
        .map(|(name, value)| quiche::h3::Header::new(name.as_bytes(), value.as_bytes()))
        .collect::<Vec<_>>();
    // FIN rides the HEADERS frame only when no body and no trailers follow.
    let fin_with_headers = response.body.is_empty() && response.trailers.is_empty();
    h3_conn
        .send_response(conn, stream_id, &headers, fin_with_headers)
        .unwrap();
    flush_h3_test_connection(socket, conn, send_buffer).await;
    if !response.body.is_empty() {
        let fin = response.trailers.is_empty();
        let written = h3_conn
            .send_body(conn, stream_id, &response.body, fin)
            .unwrap();
        // Flow-control windows in the test config are large enough that the
        // whole body is expected to be accepted in one call.
        assert_eq!(written, response.body.len());
        flush_h3_test_connection(socket, conn, send_buffer).await;
        if response.close_connection_after_body {
            conn.close(false, 0, b"scripted close after body").unwrap();
            flush_h3_test_connection(socket, conn, send_buffer).await;
            return;
        }
    }
    if !response.trailers.is_empty() {
        let trailers = response
            .trailers
            .iter()
            .map(|(name, value)| quiche::h3::Header::new(name.as_bytes(), value.as_bytes()))
            .collect::<Vec<_>>();
        h3_conn
            .send_additional_headers(conn, stream_id, &trailers, true, true)
            .unwrap();
        flush_h3_test_connection(socket, conn, send_buffer).await;
    }
    if response.send_goaway {
        h3_conn.send_goaway(conn, stream_id).unwrap();
        flush_h3_test_connection(socket, conn, send_buffer).await;
    }
    if let Some(delay) = response.close_connection_after_response_delay {
        async_io::Timer::after(delay).await;
        conn.close(false, 0, b"scripted idle close").unwrap();
        flush_h3_test_connection(socket, conn, send_buffer).await;
    }
}
#[cfg(feature = "h3")]
/// Spawns a scripted HTTP/3 test server with the default bidirectional
/// stream limit (16); returns its https URL and a counter of accepted
/// QUIC connections.
async fn spawn_h3_quic_script_server(
    responses: Vec<H3TestResponse>,
) -> Result<(String, Arc<std::sync::atomic::AtomicUsize>)> {
    spawn_h3_quic_script_server_with_max_streams_bidi(responses, 16).await
}
#[cfg(feature = "h3")]
/// Spawns a background thread running a minimal quiche-based HTTP/3 server
/// that answers each incoming request with the next scripted response.
///
/// Supports multiple concurrent QUIC connections (keyed by peer address) so
/// tests can observe reconnects. Returns the server's https URL and a
/// counter incremented once per accepted connection. The thread exits
/// shortly after the last scripted response has been sent.
async fn spawn_h3_quic_script_server_with_max_streams_bidi(
    responses: Vec<H3TestResponse>,
    initial_max_streams_bidi: u64,
) -> Result<(String, Arc<std::sync::atomic::AtomicUsize>)> {
    let mut config =
        build_h3_test_server_config_with_max_streams_bidi(initial_max_streams_bidi)?;
    let socket = async_net::UdpSocket::bind(("127.0.0.1", 0))
        .await
        .map_err(|err| {
            Error::with_source(ErrorKind::Transport, "failed to bind h3 test socket", err)
        })?;
    let addr = socket.local_addr().map_err(|err| {
        Error::with_source(
            ErrorKind::Transport,
            "failed to inspect h3 test socket",
            err,
        )
    })?;
    let connection_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
    let task_connection_count = Arc::clone(&connection_count);
    std::thread::spawn(move || {
        futures_lite::future::block_on(async move {
            // Per-peer QUIC connection plus its lazily created h3 layer.
            struct TestConnection {
                conn: quiche::Connection,
                h3_conn: Option<quiche::h3::Connection>,
            }
            let mut recv_buffer = [0_u8; 65535];
            let mut send_buffer = [0_u8; 1350];
            let h3_config = quiche::h3::Config::new().unwrap();
            let mut pending = std::collections::VecDeque::from(responses);
            let mut connections =
                std::collections::HashMap::<std::net::SocketAddr, TestConnection>::new();
            let mut next_conn_id = 1_u64;
            let mut completed = false;
            loop {
                // Drain queued outbound packets for every connection before
                // blocking on the next inbound datagram.
                for state in connections.values_mut() {
                    flush_h3_test_connection(&socket, &mut state.conn, &mut send_buffer).await;
                }
                if completed {
                    // Linger briefly so the final packets are delivered.
                    async_io::Timer::after(std::time::Duration::from_millis(50)).await;
                    break;
                }
                let (len, from) = socket.recv_from(&mut recv_buffer).await.unwrap();
                let local_addr = socket.local_addr().unwrap();
                let state = connections.entry(from).or_insert_with(|| {
                    // New peer: accept a QUIC connection under a unique SCID
                    // and bump the externally observable connection counter.
                    let mut scid_bytes = [0_u8; quiche::MAX_CONN_ID_LEN];
                    scid_bytes[..8].copy_from_slice(&next_conn_id.to_be_bytes());
                    next_conn_id += 1;
                    let scid = quiche::ConnectionId::from_ref(&scid_bytes);
                    task_connection_count.fetch_add(1, Ordering::SeqCst);
                    TestConnection {
                        conn: quiche::accept(&scid, None, local_addr, from, &mut config)
                            .unwrap(),
                        h3_conn: None,
                    }
                });
                state
                    .conn
                    .recv(
                        &mut recv_buffer[..len],
                        quiche::RecvInfo {
                            from,
                            to: local_addr,
                        },
                    )
                    .unwrap();
                // Attach the HTTP/3 layer once the handshake allows it.
                if (state.conn.is_in_early_data() || state.conn.is_established())
                    && state.h3_conn.is_none()
                {
                    state.h3_conn = Some(
                        quiche::h3::Connection::with_transport(&mut state.conn, &h3_config)
                            .unwrap(),
                    );
                }
                if let Some(h3_conn) = state.h3_conn.as_mut() {
                    loop {
                        match h3_conn.poll(&mut state.conn) {
                            Ok((stream_id, quiche::h3::Event::Headers { .. })) => {
                                // Each request consumes the next scripted response.
                                let response =
                                    pending.pop_front().expect("h3 scripted response");
                                send_h3_test_response(
                                    &socket,
                                    &mut state.conn,
                                    h3_conn,
                                    stream_id,
                                    &response,
                                    &mut send_buffer,
                                )
                                .await;
                                if pending.is_empty() {
                                    completed = true;
                                }
                            }
                            Ok((stream_id, quiche::h3::Event::Data)) => {
                                // Request bodies are read and discarded.
                                let mut discard = [0_u8; 4096];
                                loop {
                                    match h3_conn.recv_body(
                                        &mut state.conn,
                                        stream_id,
                                        &mut discard,
                                    ) {
                                        Ok(0) | Err(quiche::h3::Error::Done) => break,
                                        Ok(_) => {}
                                        Err(err) => {
                                            panic!("h3 script server recv body failed: {err:?}")
                                        }
                                    }
                                }
                            }
                            Ok((_stream_id, quiche::h3::Event::Finished)) => {}
                            Ok((_stream_id, quiche::h3::Event::Reset(_))) => {}
                            Ok((_flow_id, quiche::h3::Event::PriorityUpdate)) => {}
                            Ok((_goaway_id, quiche::h3::Event::GoAway)) => {}
                            Err(quiche::h3::Error::Done) => break,
                            Err(err) => panic!("h3 script server poll failed: {err:?}"),
                        }
                    }
                }
            }
        });
    });
    Ok((
        format!("https://localhost:{}", addr.port()),
        connection_count,
    ))
}
#[cfg(feature = "h3")]
/// Spawns a single-connection HTTP/3 server that records the first complete
/// request (decoded headers plus raw body bytes) and then replies with the
/// supplied scripted response. Returns the server's https URL and the shared
/// list of captured requests. The server thread exits shortly after
/// responding once.
async fn spawn_h3_quic_recording_body_server(
    response: H3TestResponse,
) -> Result<(String, Arc<Mutex<Vec<H3CapturedRequest>>>)> {
    let mut config = build_h3_test_server_config()?;
    let socket = async_net::UdpSocket::bind(("127.0.0.1", 0))
        .await
        .map_err(|err| {
            Error::with_source(ErrorKind::Transport, "failed to bind h3 test socket", err)
        })?;
    let addr = socket.local_addr().map_err(|err| {
        Error::with_source(
            ErrorKind::Transport,
            "failed to inspect h3 test socket",
            err,
        )
    })?;
    let recorded_requests = Arc::new(Mutex::new(Vec::new()));
    let task_recorded_requests = Arc::clone(&recorded_requests);
    std::thread::spawn(move || {
        futures_lite::future::block_on(async move {
            let mut recv_buffer = [0_u8; 65535];
            let mut send_buffer = [0_u8; 1350];
            let h3_config = quiche::h3::Config::new().unwrap();
            let mut conn = None::<quiche::Connection>;
            let mut h3_conn = None::<quiche::h3::Connection>;
            // In-flight requests keyed by stream id.
            let mut request_states = std::collections::HashMap::<u64, H3CapturedRequest>::new();
            let mut responded = false;
            loop {
                if let Some(conn) = conn.as_mut() {
                    flush_h3_test_connection(&socket, conn, &mut send_buffer).await;
                }
                if responded {
                    // Linger briefly so the final packets are delivered.
                    async_io::Timer::after(std::time::Duration::from_millis(50)).await;
                    break;
                }
                let (len, from) = socket.recv_from(&mut recv_buffer).await.unwrap();
                let local_addr = socket.local_addr().unwrap();
                if conn.is_none() {
                    // Single-connection server: accept on the first datagram.
                    let scid_bytes = [0xAB_u8; quiche::MAX_CONN_ID_LEN];
                    let scid = quiche::ConnectionId::from_ref(&scid_bytes);
                    conn = Some(
                        quiche::accept(&scid, None, local_addr, from, &mut config).unwrap(),
                    );
                }
                let conn = conn.as_mut().unwrap();
                conn.recv(
                    &mut recv_buffer[..len],
                    quiche::RecvInfo {
                        from,
                        to: local_addr,
                    },
                )
                .unwrap();
                // Attach the HTTP/3 layer once the handshake allows it.
                if (conn.is_in_early_data() || conn.is_established()) && h3_conn.is_none() {
                    h3_conn =
                        Some(quiche::h3::Connection::with_transport(conn, &h3_config).unwrap());
                }
                if let Some(h3_conn) = h3_conn.as_mut() {
                    loop {
                        match h3_conn.poll(conn) {
                            Ok((
                                stream_id,
                                quiche::h3::Event::Headers { list, more_frames },
                            )) => {
                                let headers = list
                                    .into_iter()
                                    .map(|header| {
                                        (
                                            String::from_utf8(header.name().to_vec()).unwrap(),
                                            String::from_utf8(header.value().to_vec()).unwrap(),
                                        )
                                    })
                                    .collect::<Vec<_>>();
                                // NOTE(review): a second HEADERS frame on the
                                // same stream (e.g. trailers) would overwrite
                                // this state — presumably no test client sends
                                // request trailers here; confirm if that changes.
                                request_states.insert(
                                    stream_id,
                                    H3CapturedRequest {
                                        headers,
                                        body: Vec::new(),
                                    },
                                );
                                if !more_frames {
                                    // Body-less request: record and answer now.
                                    let request = request_states.remove(&stream_id).unwrap();
                                    task_recorded_requests.lock().unwrap().push(request);
                                    send_h3_test_response(
                                        &socket,
                                        conn,
                                        h3_conn,
                                        stream_id,
                                        &response,
                                        &mut send_buffer,
                                    )
                                    .await;
                                    responded = true;
                                    break;
                                }
                            }
                            Ok((stream_id, quiche::h3::Event::Data)) => {
                                // Accumulate request body bytes for capture.
                                let state = request_states.get_mut(&stream_id).unwrap();
                                loop {
                                    let mut chunk = [0_u8; 4096];
                                    match h3_conn.recv_body(conn, stream_id, &mut chunk) {
                                        Ok(0) | Err(quiche::h3::Error::Done) => break,
                                        Ok(read) => {
                                            state.body.extend_from_slice(&chunk[..read])
                                        }
                                        Err(err) => panic!(
                                            "h3 recording server recv body failed: {err:?}"
                                        ),
                                    }
                                }
                            }
                            Ok((stream_id, quiche::h3::Event::Finished)) => {
                                // Request fully received: record and respond.
                                let request = request_states.remove(&stream_id).unwrap();
                                task_recorded_requests.lock().unwrap().push(request);
                                send_h3_test_response(
                                    &socket,
                                    conn,
                                    h3_conn,
                                    stream_id,
                                    &response,
                                    &mut send_buffer,
                                )
                                .await;
                                responded = true;
                                break;
                            }
                            Ok((_stream_id, quiche::h3::Event::Reset(_))) => {}
                            Ok((_flow_id, quiche::h3::Event::PriorityUpdate)) => {}
                            Ok((_goaway_id, quiche::h3::Event::GoAway)) => {}
                            Err(quiche::h3::Error::Done) => break,
                            Err(err) => panic!("h3 recording server poll failed: {err:?}"),
                        }
                    }
                }
            }
        });
    });
    Ok((
        format!("https://localhost:{}", addr.port()),
        recorded_requests,
    ))
}
#[cfg(feature = "h3")]
/// Spawns a single-connection HTTP/3 server emulating a gRPC-style
/// bidirectional stream: after the first request-body bytes arrive it sends
/// response headers plus a `seq: 1` gRPC/JSON frame (stream still open), and
/// once the request stream finishes it sends a `seq: 2` frame followed by
/// gRPC trailers. The captured request (headers + full body) is pushed to the
/// returned list. The server thread loops until its test ends.
async fn spawn_h3_quic_grpc_bidi_probe_server()
-> Result<(String, Arc<Mutex<Vec<H3CapturedRequest>>>)> {
    let mut config = build_h3_test_server_config()?;
    let socket = async_net::UdpSocket::bind(("127.0.0.1", 0))
        .await
        .map_err(|err| {
            Error::with_source(ErrorKind::Transport, "failed to bind h3 test socket", err)
        })?;
    let addr = socket.local_addr().map_err(|err| {
        Error::with_source(
            ErrorKind::Transport,
            "failed to inspect h3 test socket",
            err,
        )
    })?;
    let recorded_requests = Arc::new(Mutex::new(Vec::new()));
    let task_recorded_requests = Arc::clone(&recorded_requests);
    std::thread::spawn(move || {
        futures_lite::future::block_on(async move {
            // Per-stream request capture plus a flag marking that the early
            // (mid-stream) response has already been sent.
            #[derive(Default)]
            struct RequestState {
                headers: Vec<(String, String)>,
                body: Vec<u8>,
                responded_first: bool,
            }
            let mut recv_buffer = [0_u8; 65535];
            let mut send_buffer = [0_u8; 1350];
            let h3_config = quiche::h3::Config::new().unwrap();
            let mut conn = None::<quiche::Connection>;
            let mut h3_conn = None::<quiche::h3::Connection>;
            let mut request_states = std::collections::HashMap::<u64, RequestState>::new();
            loop {
                if let Some(conn) = conn.as_mut() {
                    flush_h3_test_connection(&socket, conn, &mut send_buffer).await;
                }
                let (len, from) = socket.recv_from(&mut recv_buffer).await.unwrap();
                let local_addr = socket.local_addr().unwrap();
                if conn.is_none() {
                    // Single-connection server: accept on the first datagram.
                    let scid_bytes = [0xBC_u8; quiche::MAX_CONN_ID_LEN];
                    let scid = quiche::ConnectionId::from_ref(&scid_bytes);
                    conn = Some(
                        quiche::accept(&scid, None, local_addr, from, &mut config).unwrap(),
                    );
                }
                let conn = conn.as_mut().unwrap();
                conn.recv(
                    &mut recv_buffer[..len],
                    quiche::RecvInfo {
                        from,
                        to: local_addr,
                    },
                )
                .unwrap();
                // Attach the HTTP/3 layer once the handshake allows it.
                if (conn.is_in_early_data() || conn.is_established()) && h3_conn.is_none() {
                    h3_conn =
                        Some(quiche::h3::Connection::with_transport(conn, &h3_config).unwrap());
                }
                if let Some(h3_conn) = h3_conn.as_mut() {
                    loop {
                        match h3_conn.poll(conn) {
                            Ok((stream_id, quiche::h3::Event::Headers { list, .. })) => {
                                let headers = list
                                    .into_iter()
                                    .map(|header| {
                                        (
                                            String::from_utf8(header.name().to_vec()).unwrap(),
                                            String::from_utf8(header.value().to_vec()).unwrap(),
                                        )
                                    })
                                    .collect::<Vec<_>>();
                                request_states.insert(
                                    stream_id,
                                    RequestState {
                                        headers,
                                        body: Vec::new(),
                                        responded_first: false,
                                    },
                                );
                            }
                            Ok((stream_id, quiche::h3::Event::Data)) => {
                                let mut chunk = [0_u8; 4096];
                                loop {
                                    match h3_conn.recv_body(conn, stream_id, &mut chunk) {
                                        Ok(0) | Err(quiche::h3::Error::Done) => break,
                                        Ok(read) => {
                                            let state =
                                                request_states.get_mut(&stream_id).unwrap();
                                            state.body.extend_from_slice(&chunk[..read]);
                                            // First body bytes trigger the early
                                            // response while the stream is open.
                                            if !state.responded_first && !state.body.is_empty()
                                            {
                                                let headers = vec![
                                                    quiche::h3::Header::new(b":status", b"200"),
                                                    quiche::h3::Header::new(
                                                        b"content-type",
                                                        b"application/grpc+json",
                                                    ),
                                                    quiche::h3::Header::new(
                                                        b"x-stream-bin",
                                                        b"AQI=",
                                                    ),
                                                ];
                                                h3_conn
                                                    .send_response(
                                                        conn, stream_id, &headers, false,
                                                    )
                                                    .unwrap();
                                                flush_h3_test_connection(
                                                    &socket,
                                                    conn,
                                                    &mut send_buffer,
                                                )
                                                .await;
                                                let written = h3_conn
                                                    .send_body(
                                                        conn,
                                                        stream_id,
                                                        &grpc_test_json_frame(
                                                            &serde_json::json!({ "seq": 1 }),
                                                        ),
                                                        false,
                                                    )
                                                    .unwrap();
                                                assert_eq!(
                                                    written,
                                                    grpc_test_json_frame(
                                                        &serde_json::json!({ "seq": 1 })
                                                    )
                                                    .len()
                                                );
                                                flush_h3_test_connection(
                                                    &socket,
                                                    conn,
                                                    &mut send_buffer,
                                                )
                                                .await;
                                                state.responded_first = true;
                                            }
                                        }
                                        Err(err) => {
                                            panic!("h3 grpc bidi recv body failed: {err:?}")
                                        }
                                    }
                                }
                            }
                            Ok((stream_id, quiche::h3::Event::Finished)) => {
                                // Client finished its stream: record the capture,
                                // then send the second frame and gRPC trailers.
                                let state = request_states.remove(&stream_id).unwrap();
                                task_recorded_requests.lock().unwrap().push(
                                    H3CapturedRequest {
                                        headers: state.headers,
                                        body: state.body,
                                    },
                                );
                                let written = h3_conn
                                    .send_body(
                                        conn,
                                        stream_id,
                                        &grpc_test_json_frame(&serde_json::json!({ "seq": 2 })),
                                        false,
                                    )
                                    .unwrap();
                                assert_eq!(
                                    written,
                                    grpc_test_json_frame(&serde_json::json!({ "seq": 2 }))
                                        .len()
                                );
                                flush_h3_test_connection(&socket, conn, &mut send_buffer).await;
                                let trailers = vec![
                                    quiche::h3::Header::new(b"grpc-status", b"0"),
                                    quiche::h3::Header::new(b"x-trace-bin", b"AAE="),
                                ];
                                h3_conn
                                    .send_additional_headers(
                                        conn, stream_id, &trailers, true, true,
                                    )
                                    .unwrap();
                                flush_h3_test_connection(&socket, conn, &mut send_buffer).await;
                                break;
                            }
                            Ok((_stream_id, quiche::h3::Event::Reset(_))) => break,
                            Ok((_flow_id, quiche::h3::Event::PriorityUpdate)) => {}
                            Ok((_goaway_id, quiche::h3::Event::GoAway)) => {}
                            Err(quiche::h3::Error::Done) => break,
                            Err(err) => panic!("h3 grpc bidi server poll failed: {err:?}"),
                        }
                    }
                }
            }
        });
    });
    Ok((
        format!("https://localhost:{}", addr.port()),
        recorded_requests,
    ))
}
#[cfg(feature = "h3")]
/// Spawns a single-connection HTTP/3 server that streams its response body
/// in two chunks ("hello", then " world" after a 50ms pause) under a
/// content-length of 11, finishing with an `x-finished: yes` trailer — used
/// to exercise incremental body delivery and trailer handling on the client.
/// Returns the server's https URL; the thread exits shortly after responding.
async fn spawn_h3_quic_streaming_body_server() -> Result<String> {
    let mut config = build_h3_test_server_config()?;
    let socket = async_net::UdpSocket::bind(("127.0.0.1", 0))
        .await
        .map_err(|err| {
            Error::with_source(ErrorKind::Transport, "failed to bind h3 test socket", err)
        })?;
    let addr = socket.local_addr().map_err(|err| {
        Error::with_source(
            ErrorKind::Transport,
            "failed to inspect h3 test socket",
            err,
        )
    })?;
    std::thread::spawn(move || {
        futures_lite::future::block_on(async move {
            let mut recv_buffer = [0_u8; 65535];
            let mut send_buffer = [0_u8; 1350];
            let h3_config = quiche::h3::Config::new().unwrap();
            let mut conn = None::<quiche::Connection>;
            let mut h3_conn = None::<quiche::h3::Connection>;
            let mut responded = false;
            loop {
                if let Some(conn) = conn.as_mut() {
                    flush_h3_test_connection(&socket, conn, &mut send_buffer).await;
                }
                if responded {
                    // Linger briefly so the final packets are delivered.
                    async_io::Timer::after(std::time::Duration::from_millis(50)).await;
                    break;
                }
                let (len, from) = socket.recv_from(&mut recv_buffer).await.unwrap();
                let local_addr = socket.local_addr().unwrap();
                if conn.is_none() {
                    // Single-connection server: accept on the first datagram.
                    let scid_bytes = [0xAB_u8; quiche::MAX_CONN_ID_LEN];
                    let scid = quiche::ConnectionId::from_ref(&scid_bytes);
                    conn = Some(
                        quiche::accept(&scid, None, local_addr, from, &mut config).unwrap(),
                    );
                }
                let conn = conn.as_mut().unwrap();
                conn.recv(
                    &mut recv_buffer[..len],
                    quiche::RecvInfo {
                        from,
                        to: local_addr,
                    },
                )
                .unwrap();
                // Attach the HTTP/3 layer once the handshake allows it.
                if (conn.is_in_early_data() || conn.is_established()) && h3_conn.is_none() {
                    h3_conn =
                        Some(quiche::h3::Connection::with_transport(conn, &h3_config).unwrap());
                }
                if let Some(h3_conn) = h3_conn.as_mut() {
                    loop {
                        match h3_conn.poll(conn) {
                            Ok((stream_id, quiche::h3::Event::Headers { more_frames, .. })) => {
                                // Only respond to a body-less request; content
                                // length 11 covers "hello" + " world".
                                if !more_frames {
                                    let headers = vec![
                                        quiche::h3::Header::new(b":status", b"200"),
                                        quiche::h3::Header::new(b"content-length", b"11"),
                                    ];
                                    h3_conn
                                        .send_response(conn, stream_id, &headers, false)
                                        .unwrap();
                                    flush_h3_test_connection(&socket, conn, &mut send_buffer)
                                        .await;
                                    let written = h3_conn
                                        .send_body(conn, stream_id, b"hello", false)
                                        .unwrap();
                                    assert_eq!(written, 5);
                                    flush_h3_test_connection(&socket, conn, &mut send_buffer)
                                        .await;
                                    // Pause between chunks so the client must
                                    // observe a genuinely incremental body.
                                    async_io::Timer::after(std::time::Duration::from_millis(
                                        50,
                                    ))
                                    .await;
                                    let written = h3_conn
                                        .send_body(conn, stream_id, b" world", false)
                                        .unwrap();
                                    assert_eq!(written, 6);
                                    flush_h3_test_connection(&socket, conn, &mut send_buffer)
                                        .await;
                                    let trailers =
                                        vec![quiche::h3::Header::new(b"x-finished", b"yes")];
                                    h3_conn
                                        .send_additional_headers(
                                            conn, stream_id, &trailers, true, true,
                                        )
                                        .unwrap();
                                    flush_h3_test_connection(&socket, conn, &mut send_buffer)
                                        .await;
                                    responded = true;
                                    break;
                                }
                            }
                            Ok((stream_id, quiche::h3::Event::Data)) => {
                                // Request bodies are read and discarded.
                                let mut discard = [0_u8; 4096];
                                loop {
                                    match h3_conn.recv_body(conn, stream_id, &mut discard) {
                                        Ok(0) | Err(quiche::h3::Error::Done) => break,
                                        Ok(_) => {}
                                        Err(err) => panic!(
                                            "h3 streaming server recv body failed: {err:?}"
                                        ),
                                    }
                                }
                            }
                            Ok((_stream_id, quiche::h3::Event::Finished)) => {}
                            Ok((_stream_id, quiche::h3::Event::Reset(_))) => {}
                            Ok((_flow_id, quiche::h3::Event::PriorityUpdate)) => {}
                            Ok((_goaway_id, quiche::h3::Event::GoAway)) => {}
                            Err(quiche::h3::Error::Done) => break,
                            Err(err) => panic!("h3 streaming server poll failed: {err:?}"),
                        }
                    }
                }
            }
        });
    });
    Ok(format!("https://localhost:{}", addr.port()))
}
#[cfg(feature = "h3")]
/// Spawns a single-connection HTTP/3 server for the "early response" case:
/// as soon as the first request-body bytes arrive it answers "done" without
/// waiting for the request stream to finish; a subsequent body-less request
/// then receives "ok" and shuts the server down. Returns the https URL and
/// a counter of accepted QUIC connections.
async fn spawn_h3_quic_early_response_server() -> Result<(String, Arc<AtomicUsize>)> {
    let mut config = build_h3_test_server_config()?;
    let socket = async_net::UdpSocket::bind(("127.0.0.1", 0))
        .await
        .map_err(|err| {
            Error::with_source(ErrorKind::Transport, "failed to bind h3 test socket", err)
        })?;
    let addr = socket.local_addr().map_err(|err| {
        Error::with_source(
            ErrorKind::Transport,
            "failed to inspect h3 test socket",
            err,
        )
    })?;
    let connection_count = Arc::new(AtomicUsize::new(0));
    let task_connection_count = Arc::clone(&connection_count);
    std::thread::spawn(move || {
        futures_lite::future::block_on(async move {
            let mut recv_buffer = [0_u8; 65535];
            let mut send_buffer = [0_u8; 1350];
            let h3_config = quiche::h3::Config::new().unwrap();
            let mut conn = None::<quiche::Connection>;
            let mut h3_conn = None::<quiche::h3::Connection>;
            // Set once the mid-upload "done" response has been sent.
            let mut responded_early = false;
            let mut finished = false;
            loop {
                if let Some(conn) = conn.as_mut() {
                    flush_h3_test_connection(&socket, conn, &mut send_buffer).await;
                }
                if finished {
                    // Linger briefly so the final packets are delivered.
                    async_io::Timer::after(std::time::Duration::from_millis(50)).await;
                    break;
                }
                let (len, from) = socket.recv_from(&mut recv_buffer).await.unwrap();
                let local_addr = socket.local_addr().unwrap();
                if conn.is_none() {
                    // Single-connection server: accept on the first datagram
                    // and bump the externally observable connection counter.
                    let scid_bytes = [0xCD_u8; quiche::MAX_CONN_ID_LEN];
                    let scid = quiche::ConnectionId::from_ref(&scid_bytes);
                    task_connection_count.fetch_add(1, Ordering::SeqCst);
                    conn = Some(
                        quiche::accept(&scid, None, local_addr, from, &mut config).unwrap(),
                    );
                }
                let conn = conn.as_mut().unwrap();
                conn.recv(
                    &mut recv_buffer[..len],
                    quiche::RecvInfo {
                        from,
                        to: local_addr,
                    },
                )
                .unwrap();
                // Attach the HTTP/3 layer once the handshake allows it.
                if (conn.is_in_early_data() || conn.is_established()) && h3_conn.is_none() {
                    h3_conn =
                        Some(quiche::h3::Connection::with_transport(conn, &h3_config).unwrap());
                }
                if let Some(h3_conn) = h3_conn.as_mut() {
                    loop {
                        match h3_conn.poll(conn) {
                            Ok((stream_id, quiche::h3::Event::Headers { more_frames, .. })) => {
                                // Second, body-less request after the early
                                // response: answer "ok" and shut down.
                                if responded_early && !more_frames {
                                    let response = H3TestResponse::text(200, "ok");
                                    send_h3_test_response(
                                        &socket,
                                        conn,
                                        h3_conn,
                                        stream_id,
                                        &response,
                                        &mut send_buffer,
                                    )
                                    .await;
                                    finished = true;
                                    break;
                                }
                            }
                            Ok((stream_id, quiche::h3::Event::Data)) => {
                                let mut saw_body = false;
                                loop {
                                    let mut chunk = [0_u8; 4096];
                                    match h3_conn.recv_body(conn, stream_id, &mut chunk) {
                                        Ok(0) | Err(quiche::h3::Error::Done) => break,
                                        Ok(read) => saw_body |= read > 0,
                                        Err(err) => panic!(
                                            "h3 early-response server recv body failed: {err:?}"
                                        ),
                                    }
                                }
                                // First body bytes observed: answer "done"
                                // before the request stream has finished.
                                if saw_body && !responded_early {
                                    let response = H3TestResponse::text(200, "done");
                                    send_h3_test_response(
                                        &socket,
                                        conn,
                                        h3_conn,
                                        stream_id,
                                        &response,
                                        &mut send_buffer,
                                    )
                                    .await;
                                    responded_early = true;
                                }
                            }
                            Ok((_stream_id, quiche::h3::Event::Finished)) => {}
                            Ok((_stream_id, quiche::h3::Event::Reset(_))) => {}
                            Ok((_flow_id, quiche::h3::Event::PriorityUpdate)) => {}
                            Ok((_goaway_id, quiche::h3::Event::GoAway)) => {}
                            Err(quiche::h3::Error::Done) => break,
                            Err(err) => {
                                panic!("h3 early-response server poll failed: {err:?}")
                            }
                        }
                    }
                }
            }
        });
    });
    Ok((
        format!("https://localhost:{}", addr.port()),
        connection_count,
    ))
}
#[cfg(feature = "h3")]
/// Spawns a single-connection HTTP/3 test server used by the
/// response-stream cancellation test.
///
/// The first request is answered with headers declaring `content-length: 11`
/// but only 5 body bytes ("hello") and no FIN, so the client can drop the
/// response stream mid-body. A later request — once its headers arrive with
/// no more frames pending — receives a complete `200 "ok"` response, after
/// which the server drains its send queue and exits.
///
/// Returns the server's base URL plus a counter of accepted QUIC
/// connections, letting tests assert that cancelling a stream did not force
/// a reconnect.
async fn spawn_h3_quic_cancel_server() -> Result<(String, Arc<AtomicUsize>)> {
    let mut config = build_h3_test_server_config()?;
    let socket = async_net::UdpSocket::bind(("127.0.0.1", 0))
        .await
        .map_err(|err| {
            Error::with_source(ErrorKind::Transport, "failed to bind h3 test socket", err)
        })?;
    let addr = socket.local_addr().map_err(|err| {
        Error::with_source(
            ErrorKind::Transport,
            "failed to inspect h3 test socket",
            err,
        )
    })?;
    let connection_count = Arc::new(AtomicUsize::new(0));
    let task_connection_count = Arc::clone(&connection_count);
    // Drive the QUIC endpoint on its own thread so tests can run the client
    // from their own executor.
    std::thread::spawn(move || {
        futures_lite::future::block_on(async move {
            let mut recv_buffer = [0_u8; 65535];
            let mut send_buffer = [0_u8; 1350];
            let h3_config = quiche::h3::Config::new().unwrap();
            let mut conn = None::<quiche::Connection>;
            let mut h3_conn = None::<quiche::h3::Connection>;
            let mut first_response_sent = false;
            let mut finished = false;
            loop {
                // Flush any pending outbound packets before blocking on reads.
                if let Some(conn) = conn.as_mut() {
                    flush_h3_test_connection(&socket, conn, &mut send_buffer).await;
                }
                if finished {
                    // Give the final packets a moment on the wire, then stop.
                    async_io::Timer::after(std::time::Duration::from_millis(50)).await;
                    break;
                }
                let (len, from) = socket.recv_from(&mut recv_buffer).await.unwrap();
                let local_addr = socket.local_addr().unwrap();
                if conn.is_none() {
                    // First datagram from a peer: accept a new QUIC connection
                    // and bump the counter observed by the test.
                    let scid_bytes = [0xCE_u8; quiche::MAX_CONN_ID_LEN];
                    let scid = quiche::ConnectionId::from_ref(&scid_bytes);
                    task_connection_count.fetch_add(1, Ordering::SeqCst);
                    conn = Some(
                        quiche::accept(&scid, None, local_addr, from, &mut config).unwrap(),
                    );
                }
                let conn = conn.as_mut().unwrap();
                conn.recv(
                    &mut recv_buffer[..len],
                    quiche::RecvInfo {
                        from,
                        to: local_addr,
                    },
                )
                .unwrap();
                // Layer HTTP/3 on top once the handshake (or early data) is ready.
                if (conn.is_in_early_data() || conn.is_established()) && h3_conn.is_none() {
                    h3_conn =
                        Some(quiche::h3::Connection::with_transport(conn, &h3_config).unwrap());
                }
                if let Some(h3_conn) = h3_conn.as_mut() {
                    loop {
                        match h3_conn.poll(conn) {
                            Ok((stream_id, quiche::h3::Event::Headers { more_frames, .. })) => {
                                if !first_response_sent {
                                    // Promise 11 body bytes but deliver only 5
                                    // and never finish the stream: the client is
                                    // expected to drop this response mid-body.
                                    let headers = vec![
                                        quiche::h3::Header::new(b":status", b"200"),
                                        quiche::h3::Header::new(b"content-length", b"11"),
                                    ];
                                    h3_conn
                                        .send_response(conn, stream_id, &headers, false)
                                        .unwrap();
                                    flush_h3_test_connection(&socket, conn, &mut send_buffer)
                                        .await;
                                    let written = h3_conn
                                        .send_body(conn, stream_id, b"hello", false)
                                        .unwrap();
                                    assert_eq!(written, 5);
                                    flush_h3_test_connection(&socket, conn, &mut send_buffer)
                                        .await;
                                    first_response_sent = true;
                                } else if !more_frames {
                                    // Follow-up request: answer completely, then
                                    // wind the server down.
                                    let response = H3TestResponse::text(200, "ok");
                                    send_h3_test_response(
                                        &socket,
                                        conn,
                                        h3_conn,
                                        stream_id,
                                        &response,
                                        &mut send_buffer,
                                    )
                                    .await;
                                    finished = true;
                                    break;
                                }
                            }
                            // The remaining events carry no state this server needs.
                            Ok((_stream_id, quiche::h3::Event::Data)) => {}
                            Ok((_stream_id, quiche::h3::Event::Finished)) => {}
                            Ok((_stream_id, quiche::h3::Event::Reset(_))) => {}
                            Ok((_flow_id, quiche::h3::Event::PriorityUpdate)) => {}
                            Ok((_goaway_id, quiche::h3::Event::GoAway)) => {}
                            Err(quiche::h3::Error::Done) => break,
                            Err(err) => panic!("h3 cancel server poll failed: {err:?}"),
                        }
                    }
                }
            }
        });
    });
    Ok((
        format!("https://localhost:{}", addr.port()),
        connection_count,
    ))
}
#[cfg(feature = "h3")]
/// Spawns an HTTP/3 test server for the mid-upload-failure test.
///
/// Every request's headers and body bytes are accumulated into
/// `H3CapturedRequest` records. The client's first upload (stream 0) is
/// expected to be aborted mid-body, so it is recorded opportunistically:
/// either when its own `Finished` event arrives, or — failing that — when a
/// later stream's headers show up. Any non-zero stream is recorded, answered
/// with `200 "ok"`, and ends the server loop.
///
/// Returns the base URL, an accepted-connection counter, and the shared
/// list of captured requests.
async fn spawn_h3_quic_upload_error_server()
-> Result<(String, Arc<AtomicUsize>, Arc<Mutex<Vec<H3CapturedRequest>>>)> {
    let mut config = build_h3_test_server_config()?;
    let socket = async_net::UdpSocket::bind(("127.0.0.1", 0))
        .await
        .map_err(|err| {
            Error::with_source(ErrorKind::Transport, "failed to bind h3 test socket", err)
        })?;
    let addr = socket.local_addr().map_err(|err| {
        Error::with_source(
            ErrorKind::Transport,
            "failed to inspect h3 test socket",
            err,
        )
    })?;
    let connection_count = Arc::new(AtomicUsize::new(0));
    let task_connection_count = Arc::clone(&connection_count);
    let recorded_requests = Arc::new(Mutex::new(Vec::new()));
    let task_recorded_requests = Arc::clone(&recorded_requests);
    // Run the QUIC endpoint on a dedicated thread, independent of the
    // client-side executor used by the test.
    std::thread::spawn(move || {
        futures_lite::future::block_on(async move {
            let mut recv_buffer = [0_u8; 65535];
            let mut send_buffer = [0_u8; 1350];
            let h3_config = quiche::h3::Config::new().unwrap();
            let mut conn = None::<quiche::Connection>;
            let mut h3_conn = None::<quiche::h3::Connection>;
            // Per-stream accumulation of headers and body bytes.
            let mut request_states = std::collections::HashMap::<u64, H3CapturedRequest>::new();
            let mut first_stream_recorded = false;
            let mut finished = false;
            loop {
                // Flush queued outbound packets before the next read.
                if let Some(conn) = conn.as_mut() {
                    flush_h3_test_connection(&socket, conn, &mut send_buffer).await;
                }
                if finished {
                    // Let the final packets drain, then stop serving.
                    async_io::Timer::after(std::time::Duration::from_millis(50)).await;
                    break;
                }
                let (len, from) = socket.recv_from(&mut recv_buffer).await.unwrap();
                let local_addr = socket.local_addr().unwrap();
                if conn.is_none() {
                    // Accept the new QUIC connection and count the dial.
                    let scid_bytes = [0xCF_u8; quiche::MAX_CONN_ID_LEN];
                    let scid = quiche::ConnectionId::from_ref(&scid_bytes);
                    task_connection_count.fetch_add(1, Ordering::SeqCst);
                    conn = Some(
                        quiche::accept(&scid, None, local_addr, from, &mut config).unwrap(),
                    );
                }
                let conn = conn.as_mut().unwrap();
                conn.recv(
                    &mut recv_buffer[..len],
                    quiche::RecvInfo {
                        from,
                        to: local_addr,
                    },
                )
                .unwrap();
                // Bring up the HTTP/3 layer once the transport is usable.
                if (conn.is_in_early_data() || conn.is_established()) && h3_conn.is_none() {
                    h3_conn =
                        Some(quiche::h3::Connection::with_transport(conn, &h3_config).unwrap());
                }
                if let Some(h3_conn) = h3_conn.as_mut() {
                    loop {
                        match h3_conn.poll(conn) {
                            Ok((
                                stream_id,
                                quiche::h3::Event::Headers { list, more_frames },
                            )) => {
                                // Decode the header list into (name, value) strings
                                // for later assertions by the test.
                                let headers = list
                                    .into_iter()
                                    .map(|header| {
                                        (
                                            String::from_utf8(header.name().to_vec()).unwrap(),
                                            String::from_utf8(header.value().to_vec()).unwrap(),
                                        )
                                    })
                                    .collect::<Vec<_>>();
                                // A later stream implies the aborted stream 0 will
                                // never finish cleanly: record whatever arrived.
                                if stream_id != 0 && !first_stream_recorded {
                                    if let Some(request) = request_states.remove(&0) {
                                        task_recorded_requests.lock().unwrap().push(request);
                                    }
                                    first_stream_recorded = true;
                                }
                                request_states.insert(
                                    stream_id,
                                    H3CapturedRequest {
                                        headers,
                                        body: Vec::new(),
                                    },
                                );
                                // Bodyless follow-up request fully received:
                                // record it, answer "ok", and shut down.
                                if stream_id != 0 && !more_frames {
                                    let request = request_states.remove(&stream_id).unwrap();
                                    task_recorded_requests.lock().unwrap().push(request);
                                    let response = H3TestResponse::text(200, "ok");
                                    send_h3_test_response(
                                        &socket,
                                        conn,
                                        h3_conn,
                                        stream_id,
                                        &response,
                                        &mut send_buffer,
                                    )
                                    .await;
                                    finished = true;
                                    break;
                                }
                            }
                            Ok((stream_id, quiche::h3::Event::Data)) => {
                                // Drain all available body bytes into this
                                // stream's capture record.
                                let state = request_states.entry(stream_id).or_default();
                                loop {
                                    let mut chunk = [0_u8; 4096];
                                    match h3_conn.recv_body(conn, stream_id, &mut chunk) {
                                        Ok(0) | Err(quiche::h3::Error::Done) => break,
                                        Ok(read) => {
                                            state.body.extend_from_slice(&chunk[..read]);
                                        }
                                        Err(err) => panic!(
                                            "h3 upload-error server recv body failed: {err:?}"
                                        ),
                                    }
                                }
                            }
                            Ok((stream_id, quiche::h3::Event::Finished)) => {
                                if stream_id == 0 {
                                    // Stream 0 finished after all: record it once.
                                    if !first_stream_recorded {
                                        if let Some(request) = request_states.remove(&stream_id)
                                        {
                                            task_recorded_requests
                                                .lock()
                                                .unwrap()
                                                .push(request);
                                        }
                                        first_stream_recorded = true;
                                    }
                                } else {
                                    // Follow-up stream complete: record, respond,
                                    // and end the server.
                                    let request = request_states.remove(&stream_id).unwrap();
                                    task_recorded_requests.lock().unwrap().push(request);
                                    let response = H3TestResponse::text(200, "ok");
                                    send_h3_test_response(
                                        &socket,
                                        conn,
                                        h3_conn,
                                        stream_id,
                                        &response,
                                        &mut send_buffer,
                                    )
                                    .await;
                                    finished = true;
                                    break;
                                }
                            }
                            // Resets are expected here (the aborted upload);
                            // nothing extra to do.
                            Ok((_stream_id, quiche::h3::Event::Reset(_error_code))) => {}
                            Ok((_flow_id, quiche::h3::Event::PriorityUpdate)) => {}
                            Ok((_goaway_id, quiche::h3::Event::GoAway)) => {}
                            Err(quiche::h3::Error::Done) => break,
                            Err(err) => {
                                panic!("h3 upload-error server poll failed: {err:?}")
                            }
                        }
                    }
                }
            }
        });
    });
    Ok((
        format!("https://localhost:{}", addr.port()),
        connection_count,
        recorded_requests,
    ))
}
#[cfg(feature = "h3")]
#[test]
fn executes_http3_only_request_over_quic() {
    // A plain-text HTTP/3 server; `http3_only` must negotiate QUIC end to end.
    let base = block_on(spawn_h3_quic_text_server("hello h3")).unwrap();
    let builder = crate::Client::builder().base_url(&base).unwrap();
    let client = builder
        .h2_keepalive(
            std::time::Duration::from_millis(30),
            std::time::Duration::from_millis(30),
        )
        .danger_accept_invalid_certs(true)
        .build()
        .unwrap();
    let resp = run(client.get("/h3").http3_only()).unwrap();
    // The response must report HTTP/3 and carry the server payload verbatim.
    assert_eq!(resp.version(), crate::Version::Http3);
    let body = block_on(resp.text()).unwrap();
    assert_eq!(body, "hello h3");
}
#[cfg(feature = "h3")]
#[test]
fn prefer_http3_uses_http3_transport_when_available() {
    // The recording server only speaks QUIC, so a prefer-HTTP/3 client
    // should land there without `http3_only` being forced.
    let (base, _) = block_on(spawn_h3_quic_recording_server(H3TestResponse::text(
        200,
        "preferred",
    )))
    .unwrap();
    let builder = crate::Client::builder().base_url(&base).unwrap();
    let client = builder
        .danger_accept_invalid_certs(true)
        .prefer_http3()
        .build()
        .unwrap();
    let resp = run(client.get("/pref")).unwrap();
    assert_eq!(resp.version(), crate::Version::Http3);
    let body = block_on(resp.text()).unwrap();
    assert_eq!(body, "preferred");
}
#[cfg(feature = "h3")]
#[test]
fn decodes_gzip_http3_response_body() {
    // Pre-compress the payload so the client has to inflate it transparently.
    let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
    encoder.write_all(b"compressed over h3").unwrap();
    let gzipped = encoder.finish().unwrap();
    let scripted = H3TestResponse::text(200, "placeholder")
        .body_bytes(gzipped)
        .header("content-encoding", "gzip");
    let (base, _) = block_on(spawn_h3_quic_recording_server(scripted)).unwrap();
    let builder = crate::Client::builder().base_url(&base).unwrap();
    let client = builder
        .h2_keepalive(
            std::time::Duration::from_millis(30),
            std::time::Duration::from_millis(30),
        )
        .danger_accept_invalid_certs(true)
        .build()
        .unwrap();
    let resp = run(client.get("/gzip").http3_only()).unwrap();
    let body = block_on(resp.text()).unwrap();
    assert_eq!(body, "compressed over h3");
}
#[cfg(feature = "h3")]
#[test]
fn no_content_response_does_not_surface_http3_body_frames() {
    // A 204 scripted with a (bogus) body: the client must never hand those
    // body frames to the caller.
    let scripted = H3TestResponse::text(204, "oops");
    let (base, _) = block_on(spawn_h3_quic_recording_server(scripted)).unwrap();
    let client = crate::Client::builder()
        .base_url(&base)
        .unwrap()
        .danger_accept_invalid_certs(true)
        .build()
        .unwrap();
    let outcome = run(client.get("/no-content").http3_only());
    match outcome {
        Ok(resp) => {
            assert_eq!(resp.version(), crate::Version::Http3);
            assert_eq!(resp.status().as_u16(), 204);
            // Either the body reads back empty, or the transport rejects the
            // stray frames outright — both are acceptable.
            match block_on(resp.text()) {
                Ok(text) => assert_eq!(text, ""),
                Err(err) => assert_eq!(err.kind(), &crate::ErrorKind::Transport),
            }
        }
        Err(err) => assert_eq!(err.kind(), &crate::ErrorKind::Transport),
    }
}
#[cfg(all(feature = "h3", feature = "zstd"))]
#[test]
fn decodes_zstd_http3_response_body() {
    // Serve a zstd-compressed body and expect transparent decompression.
    let encoded = zstd::stream::encode_all(&b"compressed over h3 zstd"[..], 0).unwrap();
    let scripted = H3TestResponse::text(200, "placeholder")
        .body_bytes(encoded)
        .header("content-encoding", "zstd");
    let (base, _) = block_on(spawn_h3_quic_recording_server(scripted)).unwrap();
    let client = crate::Client::builder()
        .base_url(&base)
        .unwrap()
        .danger_accept_invalid_certs(true)
        .build()
        .unwrap();
    let resp = run(client.get("/zstd").http3_only()).unwrap();
    let body = block_on(resp.text()).unwrap();
    assert_eq!(body, "compressed over h3 zstd");
}
#[cfg(feature = "h3")]
#[test]
fn http3_sends_default_accept_encoding_header() {
    let (base, recorded_requests) = block_on(spawn_h3_quic_recording_server(
        H3TestResponse::text(200, "ok"),
    ))
    .unwrap();
    let client = crate::Client::builder()
        .base_url(&base)
        .unwrap()
        .danger_accept_invalid_certs(true)
        .build()
        .unwrap();
    let resp = run(client.get("/headers").http3_only()).unwrap();
    let body = block_on(resp.text()).unwrap();
    assert_eq!(body, "ok");
    // Exactly one request must have been captured, carrying the library's
    // default accept-encoding value.
    let captured = recorded_requests.lock().unwrap();
    assert_eq!(captured.len(), 1);
    let accept_encoding = captured[0].header("accept-encoding");
    assert_eq!(accept_encoding, Some(DEFAULT_ACCEPT_ENCODING));
}
#[cfg(feature = "h3")]
#[test]
fn reuses_http3_connection_for_multiple_requests() {
    let script = vec![
        H3TestResponse::text(200, "first"),
        H3TestResponse::text(200, "second"),
    ];
    let (base, connection_count) = block_on(spawn_h3_quic_script_server(script)).unwrap();
    let client = crate::Client::builder()
        .base_url(&base)
        .unwrap()
        .danger_accept_invalid_certs(true)
        .build()
        .unwrap();
    // Issue both requests sequentially against the same client.
    for (path, expected) in [("/one", "first"), ("/two", "second")] {
        let resp = run(client.get(path).http3_only()).unwrap();
        assert_eq!(block_on(resp.text()).unwrap(), expected);
    }
    // Both requests must have shared a single QUIC connection.
    assert_eq!(connection_count.load(Ordering::SeqCst), 1);
}
#[cfg(feature = "h3")]
#[test]
fn retries_http3_request_after_goaway_on_reused_connection() {
    // First response carries a GOAWAY, invalidating the pooled connection.
    let script = vec![
        H3TestResponse::text(200, "first").goaway(),
        H3TestResponse::text(200, "second"),
    ];
    let (base, connection_count) = block_on(spawn_h3_quic_script_server(script)).unwrap();
    let client = crate::Client::builder()
        .base_url(&base)
        .unwrap()
        .danger_accept_invalid_certs(true)
        .build()
        .unwrap();
    for (path, expected) in [("/one", "first"), ("/two", "second")] {
        let resp = run(client.get(path).http3_only()).unwrap();
        assert_eq!(block_on(resp.text()).unwrap(), expected);
    }
    // The GOAWAY forces a second dial for the follow-up request.
    assert_eq!(connection_count.load(Ordering::SeqCst), 2);
}
#[cfg(feature = "h3")]
#[test]
fn retries_http3_request_after_peer_closes_reused_connection_before_response() {
    let script = vec![
        H3TestResponse::text(200, "first"),
        H3TestResponse::text(200, "").close_connection_before_response(),
        H3TestResponse::text(200, "retried"),
    ];
    let (base, connection_count) = block_on(spawn_h3_quic_script_server(script)).unwrap();
    let client = crate::Client::builder()
        .base_url(&base)
        .unwrap()
        .danger_accept_invalid_certs(true)
        .build()
        .unwrap();
    let first_resp = run(client.get("/one").http3_only()).unwrap();
    assert_eq!(block_on(first_resp.text()).unwrap(), "first");
    // The second attempt hits a connection the peer closes before responding;
    // the client must retry transparently on a fresh connection.
    let second_resp = run(client.get("/two").http3_only()).unwrap();
    assert_eq!(block_on(second_resp.text()).unwrap(), "retried");
    assert_eq!(connection_count.load(Ordering::SeqCst), 2);
}
#[cfg(feature = "h3")]
#[test]
fn opens_new_http3_connection_after_idle_peer_close() {
    // The server closes the connection 40ms after the first response.
    let script = vec![
        H3TestResponse::text(200, "first")
            .close_connection_after_response_delay(std::time::Duration::from_millis(40)),
        H3TestResponse::text(200, "second"),
    ];
    let (base, connection_count) = block_on(spawn_h3_quic_script_server(script)).unwrap();
    let client = crate::Client::builder()
        .base_url(&base)
        .unwrap()
        .danger_accept_invalid_certs(true)
        .build()
        .unwrap();
    let first_resp = run(client.get("/one").http3_only()).unwrap();
    assert_eq!(block_on(first_resp.text()).unwrap(), "first");
    // Wait past the scripted close so the pooled connection is already dead.
    std::thread::sleep(std::time::Duration::from_millis(100));
    let second_resp = run(client.get("/two").http3_only()).unwrap();
    assert_eq!(block_on(second_resp.text()).unwrap(), "second");
    // The dead idle connection must trigger a fresh dial, not an error.
    assert_eq!(connection_count.load(Ordering::SeqCst), 2);
}
#[cfg(feature = "h3")]
#[test]
fn does_not_retry_http3_request_after_response_headers_have_started() {
    #[derive(Clone, Debug, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
    struct HelloRequest {
        name: String,
    }
    #[derive(Clone, Debug, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
    struct HelloReply {
        message: String,
    }
    // Build a gRPC-JSON frame and drop its final byte so the scripted response
    // body is truncated after the headers have already been delivered.
    let mut truncated = grpc_test_json_frame(&HelloReply {
        message: "partial".to_owned(),
    });
    truncated.pop();
    // Script: warm-up, then the truncated gRPC response whose connection closes
    // after the body, then a response only reachable over a fresh connection.
    let (base, connection_count) = block_on(spawn_h3_quic_script_server(vec![
        H3TestResponse::text(200, "warm"),
        H3TestResponse::text(200, "")
            .header("content-type", "application/grpc+json")
            .body_bytes(truncated)
            .close_connection_after_body(),
        H3TestResponse::text(200, "fresh"),
    ]))
    .unwrap();
    let client = crate::Client::builder()
        .base_url(&base)
        .unwrap()
        .danger_accept_invalid_certs(true)
        .build()
        .unwrap();
    // Warm-up establishes the pooled connection the gRPC call will reuse.
    let warm = run(client.get("/warm").http3_only()).unwrap();
    assert_eq!(block_on(warm.text()).unwrap(), "warm");
    let request = HelloRequest {
        name: "Ada".to_owned(),
    };
    let mut response = block_on(async {
        client
            .grpc("/helloworld.Greeter/SayHello")
            .http3_only()
            .message(&request)?
            .send_streaming()
            .await
    })
    .unwrap();
    // Headers already reached the caller, so the truncated body must surface a
    // transport error instead of triggering a silent retry — the connection
    // count is still 1 at this point.
    let err = block_on(response.next_message::<HelloReply>()).unwrap_err();
    assert_eq!(err.kind(), &ErrorKind::Transport);
    assert_eq!(connection_count.load(Ordering::SeqCst), 1);
    // A later request dials anew (the scripted close killed the old connection).
    let fresh = run(client.get("/fresh").http3_only()).unwrap();
    assert_eq!(block_on(fresh.text()).unwrap(), "fresh");
    assert_eq!(connection_count.load(Ordering::SeqCst), 2);
}
#[cfg(feature = "h3")]
#[test]
fn evicts_idle_http3_goaway_connection_before_reusing_pool_slot() {
    let script = vec![
        H3TestResponse::text(200, "first").goaway(),
        H3TestResponse::text(200, "second"),
        H3TestResponse::text(200, "third"),
    ];
    let (base, connection_count) = block_on(spawn_h3_quic_script_server(script)).unwrap();
    // A single idle slot forces the GOAWAY'd connection to be evicted rather
    // than parked alongside a healthy replacement.
    let client = crate::Client::builder()
        .base_url(&base)
        .unwrap()
        .max_idle_per_host(1)
        .danger_accept_invalid_certs(true)
        .build()
        .unwrap();
    for (path, expected) in [("/one", "first"), ("/two", "second"), ("/three", "third")] {
        let resp = run(client.get(path).http3_only()).unwrap();
        assert_eq!(block_on(resp.text()).unwrap(), expected);
    }
    // One redial after the GOAWAY; the replacement then serves both follow-ups.
    assert_eq!(connection_count.load(Ordering::SeqCst), 2);
}
#[cfg(feature = "h3")]
#[test]
fn evicts_idle_http3_peer_closed_connection_before_reusing_pool_slot() {
    let script = vec![
        H3TestResponse::text(200, "first").close_connection_after_body(),
        H3TestResponse::text(200, "second"),
        H3TestResponse::text(200, "third"),
    ];
    let (base, connection_count) = block_on(spawn_h3_quic_script_server(script)).unwrap();
    // With a single idle slot, the peer-closed connection must be evicted
    // instead of blocking the pool.
    let client = crate::Client::builder()
        .base_url(&base)
        .unwrap()
        .max_idle_per_host(1)
        .danger_accept_invalid_certs(true)
        .build()
        .unwrap();
    for (path, expected) in [("/one", "first"), ("/two", "second"), ("/three", "third")] {
        let resp = run(client.get(path).http3_only()).unwrap();
        assert_eq!(block_on(resp.text()).unwrap(), expected);
    }
    // One redial after the peer close; the fresh connection serves the rest.
    assert_eq!(connection_count.load(Ordering::SeqCst), 2);
}
#[cfg(feature = "h3")]
#[test]
fn multiplexes_concurrent_http3_requests_on_single_connection() {
    let script = vec![
        H3TestResponse::text(200, "warm"),
        H3TestResponse::text(200, "ok").delay(std::time::Duration::from_millis(80)),
        H3TestResponse::text(200, "ok").delay(std::time::Duration::from_millis(80)),
    ];
    let (base, connection_count) = block_on(spawn_h3_quic_script_server(script)).unwrap();
    let client = crate::Client::builder()
        .base_url(&base)
        .unwrap()
        .danger_accept_invalid_certs(true)
        .build()
        .unwrap();
    // Warm-up request establishes the pooled QUIC connection.
    let warm = run(client.get("/warm").http3_only()).unwrap();
    assert_eq!(block_on(warm.text()).unwrap(), "warm");
    // Fire two delayed requests in parallel from separate threads.
    let client_a = client.clone();
    let client_b = client.clone();
    let worker_a = std::thread::spawn(move || {
        let resp = run(client_a.get("/one").http3_only()).unwrap();
        block_on(resp.text()).unwrap()
    });
    let worker_b = std::thread::spawn(move || {
        let resp = run(client_b.get("/two").http3_only()).unwrap();
        block_on(resp.text()).unwrap()
    });
    assert_eq!(worker_a.join().unwrap(), "ok");
    assert_eq!(worker_b.join().unwrap(), "ok");
    // Both overlapping requests multiplexed onto the single warm connection.
    assert_eq!(connection_count.load(Ordering::SeqCst), 1);
}
#[cfg(feature = "h3")]
#[test]
fn opens_second_http3_connection_when_peer_allows_only_one_bidi_stream() {
    // The server caps concurrent bidi streams at one, so overlapping requests
    // cannot multiplex and the client must dial a second connection.
    let script = vec![
        H3TestResponse::text(200, "warm"),
        H3TestResponse::text(200, "first").delay(std::time::Duration::from_millis(80)),
        H3TestResponse::text(200, "second"),
    ];
    let (base, connection_count) =
        block_on(spawn_h3_quic_script_server_with_max_streams_bidi(script, 1)).unwrap();
    let client = crate::Client::builder()
        .base_url(&base)
        .unwrap()
        .danger_accept_invalid_certs(true)
        .build()
        .unwrap();
    let warm = run(client.get("/warm").http3_only()).unwrap();
    assert_eq!(block_on(warm.text()).unwrap(), "warm");
    let client_a = client.clone();
    let client_b = client.clone();
    let worker_a = std::thread::spawn(move || {
        let resp = run(client_a.get("/one").http3_only()).unwrap();
        block_on(resp.text()).unwrap()
    });
    // Stagger the second request so it arrives while the first still holds
    // the lone stream slot.
    std::thread::sleep(std::time::Duration::from_millis(10));
    let worker_b = std::thread::spawn(move || {
        let resp = run(client_b.get("/two").http3_only()).unwrap();
        block_on(resp.text()).unwrap()
    });
    let mut bodies = vec![worker_a.join().unwrap(), worker_b.join().unwrap()];
    bodies.sort();
    assert_eq!(bodies, vec!["first".to_owned(), "second".to_owned()]);
    assert_eq!(connection_count.load(Ordering::SeqCst), 2);
}
#[cfg(feature = "h3")]
#[test]
fn early_http3_response_cancels_remaining_request_body_and_keeps_connection_reusable() {
    let (base, connection_count) = block_on(spawn_h3_quic_early_response_server()).unwrap();
    let client = crate::Client::builder()
        .base_url(&base)
        .unwrap()
        .danger_accept_invalid_certs(true)
        .build()
        .unwrap();
    // Counts how far the upload stream actually got; the server responds as
    // soon as it sees the first chunk, so the second yield (after the 80ms
    // pause) should never be reached.
    let upload_polls = Arc::new(std::sync::atomic::AtomicUsize::new(0));
    let upload_polls_task = Arc::clone(&upload_polls);
    let body_stream = async_stream::stream! {
        upload_polls_task.fetch_add(1, Ordering::SeqCst);
        yield Ok(Bytes::from_static(b"hello"));
        async_io::Timer::after(std::time::Duration::from_millis(80)).await;
        upload_polls_task.fetch_add(1, Ordering::SeqCst);
        yield Ok(Bytes::from_static(b"world"));
    };
    let response = run(client
        .post("/early")
        .http3_only()
        .timeout(std::time::Duration::from_secs(1))
        .body_stream(Box::pin(body_stream)))
    .unwrap();
    // The early response arrives before the upload completes.
    assert_eq!(block_on(response.text()).unwrap(), "done");
    // Wait longer than the stream's pause so a still-running upload would
    // have bumped the counter a second time by now.
    block_on(async {
        async_io::Timer::after(std::time::Duration::from_millis(140)).await;
    });
    // The connection must remain usable for a follow-up request...
    let second = run(client
        .get("/after")
        .http3_only()
        .timeout(std::time::Duration::from_secs(1)))
    .unwrap();
    assert_eq!(block_on(second.text()).unwrap(), "ok");
    // ...without redialing, and the cancelled upload advanced exactly once.
    assert_eq!(connection_count.load(Ordering::SeqCst), 1);
    assert_eq!(upload_polls.load(Ordering::SeqCst), 1);
}
#[cfg(feature = "h3")]
#[test]
fn dropping_http3_response_stream_keeps_connection_reusable() {
    let (base, connection_count) = block_on(spawn_h3_quic_cancel_server()).unwrap();
    let client = crate::Client::builder()
        .base_url(&base)
        .unwrap()
        .danger_accept_invalid_certs(true)
        .build()
        .unwrap();
    let resp = run(client.get("/stream").http3_only()).unwrap();
    let mut stream = resp.bytes_stream();
    let first_chunk = block_on(stream.next()).unwrap().unwrap();
    assert_eq!(first_chunk, Bytes::from_static(b"hello"));
    // Abandon the stream mid-body; only the stream should be cancelled,
    // not the whole connection.
    drop(stream);
    block_on(async {
        async_io::Timer::after(std::time::Duration::from_millis(80)).await;
    });
    let follow_up = run(client.get("/after").http3_only()).unwrap();
    assert_eq!(block_on(follow_up.text()).unwrap(), "ok");
    // The follow-up rode the original connection.
    assert_eq!(connection_count.load(Ordering::SeqCst), 1);
}
#[cfg(feature = "h3")]
#[test]
fn request_body_stream_error_fails_only_the_failed_http3_stream() {
    let (base, connection_count, recorded_requests) =
        block_on(spawn_h3_quic_upload_error_server()).unwrap();
    let client = crate::Client::builder()
        .base_url(&base)
        .unwrap()
        .danger_accept_invalid_certs(true)
        .build()
        .unwrap();
    // Upload body that yields one chunk and then fails mid-stream.
    let failing_body = async_stream::stream! {
        yield Ok(Bytes::from_static(b"hello"));
        yield Err(Error::new(ErrorKind::Transport, "upload failed"));
    };
    let upload_err = run(client
        .post("/upload")
        .http3_only()
        .timeout(std::time::Duration::from_secs(1))
        .body_stream(Box::pin(failing_body)))
    .unwrap_err();
    assert_eq!(upload_err.kind(), &ErrorKind::Transport);
    assert!(upload_err.to_string().contains("upload failed"));
    // A follow-up request must succeed on the very same connection.
    let follow_up = run(client
        .get("/after")
        .http3_only()
        .timeout(std::time::Duration::from_secs(1)))
    .unwrap();
    assert_eq!(block_on(follow_up.text()).unwrap(), "ok");
    assert_eq!(connection_count.load(Ordering::SeqCst), 1);
    let captured = recorded_requests.lock().unwrap();
    assert!(
        captured
            .iter()
            .any(|request| request.header(":path") == Some("/after"))
    );
    // If the aborted upload was captured at all, only the pre-error chunk
    // ever reached the server.
    if let Some(upload_request) = captured
        .iter()
        .find(|request| request.header(":path") == Some("/upload"))
    {
        assert_eq!(upload_request.body(), b"hello");
    }
}
#[cfg(feature = "h3")]
#[test]
fn streams_http3_response_body_and_trailers() {
    let base = block_on(spawn_h3_quic_streaming_body_server()).unwrap();
    let client = crate::Client::builder()
        .base_url(&base)
        .unwrap()
        .danger_accept_invalid_certs(true)
        .build()
        .unwrap();
    let resp = run(client.get("/stream").http3_only()).unwrap();
    // Split the response into a chunk stream plus deferred trailer state.
    let (mut body, trailer_state) = resp.into_body_stream_and_trailer_state();
    let first_chunk = block_on(body.next()).unwrap().unwrap();
    assert_eq!(first_chunk, Bytes::from_static(b"hello"));
    let second_chunk = block_on(body.next()).unwrap().unwrap();
    assert_eq!(second_chunk, Bytes::from_static(b" world"));
    // End of body; trailers become available only after the stream drains.
    assert!(block_on(body.next()).is_none());
    let trailers = trailer_state.take().unwrap();
    assert_eq!(trailers.get("x-finished"), Some("yes"));
}
#[cfg(feature = "h3")]
#[test]
fn streams_http3_request_body_without_content_length() {
    let (base, recorded_requests) = block_on(spawn_h3_quic_recording_body_server(
        H3TestResponse::text(200, "ok"),
    ))
    .unwrap();
    let client = crate::Client::builder()
        .base_url(&base)
        .unwrap()
        .danger_accept_invalid_certs(true)
        .build()
        .unwrap();
    // Chunked upload with a pause between chunks; the total length is not
    // known up front.
    let upload = async_stream::stream! {
        yield Ok(Bytes::from_static(b"hello"));
        async_io::Timer::after(std::time::Duration::from_millis(25)).await;
        yield Ok(Bytes::from_static(b"world"));
    };
    let resp = run(client
        .post("/upload")
        .http3_only()
        .body_stream(Box::pin(upload)))
    .unwrap();
    assert_eq!(block_on(resp.text()).unwrap(), "ok");
    let captured = recorded_requests.lock().unwrap();
    assert_eq!(captured.len(), 1);
    let first = &captured[0];
    assert_eq!(first.header(":path"), Some("/upload"));
    assert_eq!(first.body(), b"helloworld");
    // A streamed body must not advertise a content-length header.
    assert_eq!(first.header("content-length"), None);
}
#[cfg(feature = "h3")]
#[test]
fn executes_grpc_unary_json_request_over_http3() {
    #[derive(Clone, Debug, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
    struct HelloRequest {
        name: String,
    }
    #[derive(Clone, Debug, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
    struct HelloReply {
        message: String,
    }
    let request = HelloRequest {
        name: "Ada".to_owned(),
    };
    let reply = HelloReply {
        message: "hello over h3".to_owned(),
    };
    // Script a well-formed unary gRPC-JSON reply with an OK status trailer.
    let scripted = H3TestResponse::text(200, "")
        .header("content-type", "application/grpc+json")
        .body_bytes(grpc_test_json_frame(&reply))
        .trailer("grpc-status", "0");
    let (base, recorded_requests) =
        block_on(spawn_h3_quic_recording_body_server(scripted)).unwrap();
    let client = crate::Client::builder()
        .base_url(&base)
        .unwrap()
        .danger_accept_invalid_certs(true)
        .build()
        .unwrap();
    let mut grpc_response = block_on(async {
        client
            .grpc("/helloworld.Greeter/SayHello")
            .http3_only()
            .message(&request)?
            .await
    })
    .unwrap();
    assert_eq!(block_on(grpc_response.message::<HelloReply>()).unwrap(), reply);
    // Exactly one request reached the server, framed as gRPC-JSON.
    let captured = recorded_requests.lock().unwrap();
    assert_eq!(captured.len(), 1);
    let first = &captured[0];
    assert_eq!(first.header(":path"), Some("/helloworld.Greeter/SayHello"));
    assert_eq!(first.header("content-type"), Some("application/grpc+json"));
    assert_eq!(first.body(), grpc_test_json_frame(&request).as_slice());
}
#[cfg(feature = "h3")]
#[test]
fn grpc_http3_binary_metadata_round_trips_through_helpers() {
    #[derive(Clone, Debug, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
    struct HelloReply {
        ok: bool,
    }
    // Script a gRPC response whose headers and trailers carry base64-encoded
    // `-bin` metadata that the typed helpers must decode.
    let response = H3TestResponse::text(200, "")
        .header("content-type", "application/grpc+json")
        .header("x-stream-bin", "AQI=")
        .body_bytes(grpc_test_json_frame(&HelloReply { ok: true }))
        .trailer("grpc-status", "0")
        .trailer("grpc-status-details-bin", "ZGV0YWlscw==")
        .trailer("x-trace-bin", "AAE=");
    let (base, recorded_requests) =
        block_on(spawn_h3_quic_recording_body_server(response)).unwrap();
    let client = crate::Client::builder()
        .base_url(&base)
        .unwrap()
        .danger_accept_invalid_certs(true)
        .build()
        .unwrap();
    let mut response = block_on(async {
        client
            .grpc("/helloworld.Greeter/SayHello")
            .http3_only()
            .metadata_bin("x-request", b"\x00\x01")?
            .message(&serde_json::json!({ "name": "Ada" }))?
            .await
    })
    .unwrap();
    // `assert!` instead of `assert_eq!(.., true)` — same check, idiomatic form.
    assert!(block_on(response.message::<HelloReply>()).unwrap().ok);
    // Header metadata: "AQI=" decodes to the bytes 0x01 0x02.
    assert_eq!(
        response.metadata_bin("x-stream").unwrap(),
        vec![bytes::Bytes::from_static(b"\x01\x02")]
    );
    let status = response.status().unwrap();
    assert_eq!(status.details_bin(), Some(&b"details"[..]));
    // Raw trailers remain accessible in their base64 wire form.
    let trailers = response.trailers().unwrap().unwrap();
    assert_eq!(trailers.get("x-trace-bin"), Some("AAE="));
    // The client must have appended the `-bin` suffix and base64-encoded the
    // request metadata on the wire.
    let recorded = recorded_requests.lock().unwrap().pop().unwrap();
    assert_eq!(
        recorded.header(":path"),
        Some("/helloworld.Greeter/SayHello")
    );
    assert_eq!(recorded.header("x-request-bin"), Some("AAE="));
}
#[cfg(feature = "h3")]
#[test]
fn executes_grpc_protobuf_bytes_request_over_http3() {
    // Raw protobuf wire bytes; the test exercises the bytes-level codec path.
    let request = Bytes::from_static(&[0x08, 0x96, 0x01]);
    let reply = Bytes::from_static(&[0x10, 0x01]);
    let response = H3TestResponse::text(200, "")
        .header("content-type", "application/grpc+proto")
        .body_bytes(grpc_test_frame(&reply))
        .trailer("grpc-status", "0");
    let (base, recorded_requests) =
        block_on(spawn_h3_quic_recording_body_server(response)).unwrap();
    let client = crate::Client::builder()
        .base_url(&base)
        .unwrap()
        .danger_accept_invalid_certs(true)
        .build()
        .unwrap();
    let mut response = block_on(async {
        client
            .grpc("/protobuf.Greeter/SayHello")
            .http3_only()
            .codec(crate::GrpcCodec::Protobuf)
            .message_bytes(request.clone())?
            .await
    })
    .unwrap();
    assert_eq!(block_on(response.message_bytes()).unwrap(), reply);
    let recorded_requests = recorded_requests.lock().unwrap();
    // Guard the index below with an explicit count check (consistent with the
    // other recording tests) so a missing capture fails with a clear message
    // instead of an index panic.
    assert_eq!(recorded_requests.len(), 1);
    let recorded = &recorded_requests[0];
    assert_eq!(
        recorded.header("content-type"),
        Some("application/grpc+proto")
    );
    assert_eq!(recorded.body(), grpc_test_frame(&request).as_slice());
}
#[cfg(feature = "h3")]
#[test]
fn executes_grpc_client_streaming_json_request_over_http3() {
    #[derive(Clone, Debug, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
    struct Chunk {
        value: String,
    }
    #[derive(Clone, Debug, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
    struct Summary {
        count: usize,
    }
    let chunks = vec![
        Chunk {
            value: "one".to_owned(),
        },
        Chunk {
            value: "two".to_owned(),
        },
    ];
    let summary = Summary { count: 2 };
    // Scripted server reply: one summary frame plus an OK status trailer.
    let scripted = H3TestResponse::text(200, "")
        .header("content-type", "application/grpc+json")
        .body_bytes(grpc_test_json_frame(&summary))
        .trailer("grpc-status", "0");
    let (base, recorded_requests) =
        block_on(spawn_h3_quic_recording_body_server(scripted)).unwrap();
    let client = crate::Client::builder()
        .base_url(&base)
        .unwrap()
        .danger_accept_invalid_certs(true)
        .build()
        .unwrap();
    let outgoing = stream::iter(chunks.clone().into_iter().map(Ok::<_, Error>));
    let mut grpc_response = block_on(async {
        client
            .grpc("/upload.Writer/Append")
            .http3_only()
            .messages(outgoing)?
            .await
    })
    .unwrap();
    assert_eq!(block_on(grpc_response.message::<Summary>()).unwrap(), summary);
    let captured = recorded_requests.lock().unwrap();
    assert_eq!(captured.len(), 1);
    let first = &captured[0];
    assert_eq!(first.header("content-type"), Some("application/grpc+json"));
    // Client streaming implies an unknown total size: no content-length.
    assert_eq!(first.header("content-length"), None);
    // Both chunks arrive framed back to back in one request body.
    let expected_body = [
        grpc_test_json_frame(&chunks[0]),
        grpc_test_json_frame(&chunks[1]),
    ]
    .concat();
    assert_eq!(first.body(), expected_body.as_slice());
}
#[cfg(feature = "h3")]
#[test]
fn executes_grpc_client_streaming_protobuf_bytes_request_over_http3() {
    // Raw protobuf wire bytes streamed as individual messages.
    let chunks = vec![
        Bytes::from_static(&[0x08, 0x01]),
        Bytes::from_static(&[0x08, 0x02]),
    ];
    let summary = Bytes::from_static(&[0x10, 0x02]);
    let scripted = H3TestResponse::text(200, "")
        .header("content-type", "application/grpc+proto")
        .body_bytes(grpc_test_frame(&summary))
        .trailer("grpc-status", "0");
    let (base, recorded_requests) =
        block_on(spawn_h3_quic_recording_body_server(scripted)).unwrap();
    let client = crate::Client::builder()
        .base_url(&base)
        .unwrap()
        .danger_accept_invalid_certs(true)
        .build()
        .unwrap();
    let outgoing = stream::iter(chunks.clone().into_iter().map(Ok::<_, Error>));
    let mut grpc_response = block_on(async {
        client
            .grpc("/upload.Writer/AppendBytes")
            .http3_only()
            .codec(crate::GrpcCodec::Protobuf)
            .messages_bytes(outgoing)?
            .await
    })
    .unwrap();
    assert_eq!(block_on(grpc_response.message_bytes()).unwrap(), summary);
    let captured = recorded_requests.lock().unwrap();
    assert_eq!(captured.len(), 1);
    let first = &captured[0];
    assert_eq!(first.header("content-type"), Some("application/grpc+proto"));
    // Streaming upload: no content-length may be advertised.
    assert_eq!(first.header("content-length"), None);
    let expected_body = [grpc_test_frame(&chunks[0]), grpc_test_frame(&chunks[1])].concat();
    assert_eq!(first.body(), expected_body.as_slice());
}
// Server-streaming gRPC over HTTP/3 with the JSON codec: a unary request must
// be sent as a single gRPC JSON frame, and the multi-frame response body must
// be decoded into a stream of typed messages in order.
#[cfg(feature = "h3")]
#[test]
fn reads_grpc_server_streaming_messages_over_http3() {
#[derive(Clone, Debug, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
struct ListRequest {
topic: String,
}
#[derive(Clone, Debug, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
struct Event {
seq: usize,
}
let request = ListRequest {
topic: "builds".to_owned(),
};
// The scripted server replies with two JSON-framed messages in one body.
let response = H3TestResponse::text(200, "")
.header("content-type", "application/grpc+json")
.body_bytes(
[
grpc_test_json_frame(&Event { seq: 1 }),
grpc_test_json_frame(&Event { seq: 2 }),
]
.concat(),
)
// grpc-status 0 == OK, delivered as an HTTP/3 trailer.
.trailer("grpc-status", "0");
let (base, recorded_requests) =
block_on(spawn_h3_quic_recording_body_server(response)).unwrap();
// Self-signed test cert: skip verification.
let client = crate::Client::builder()
.base_url(&base)
.unwrap()
.danger_accept_invalid_certs(true)
.build()
.unwrap();
let mut response = block_on(async {
client
.grpc("/stream.Events/List")
.http3_only()
.message(&request)?
.await
})
.unwrap();
// Drain the typed message stream and unwrap each item.
let messages = block_on(response.messages::<Event>().unwrap().collect::<Vec<_>>())
.into_iter()
.map(|item| item.unwrap())
.collect::<Vec<_>>();
assert_eq!(messages, vec![Event { seq: 1 }, Event { seq: 2 }]);
let recorded_requests = recorded_requests.lock().unwrap();
assert_eq!(recorded_requests.len(), 1);
let recorded = &recorded_requests[0];
assert_eq!(recorded.header(":path"), Some("/stream.Events/List"));
assert_eq!(
recorded.header("content-type"),
Some("application/grpc+json")
);
// The unary request must be exactly one gRPC JSON frame.
assert_eq!(recorded.body(), grpc_test_json_frame(&request).as_slice());
}
#[test]
fn http3_only_policy_returns_explicit_error() {
    // A plain-http URL can never satisfy an HTTP/3-only policy, so the
    // request must fail with a Transport error instead of downgrading.
    let request = get("http://example.com").protocol_policy(crate::ProtocolPolicy::Http3Only);
    let err = run(request).unwrap_err();
    assert_eq!(err.kind(), &crate::ErrorKind::Transport);
}
#[test]
fn http2_only_policy_returns_explicit_error() {
    // A plain-http URL can never satisfy an HTTP/2-only policy, so the
    // request must fail with a Transport error instead of downgrading.
    let request = get("http://example.com").protocol_policy(crate::ProtocolPolicy::Http2Only);
    let err = run(request).unwrap_err();
    assert_eq!(err.kind(), &crate::ErrorKind::Transport);
}
// Happy path: an http2_only request against a TLS server that negotiates h2
// must complete over HTTP/2 and deliver the scripted body.
#[cfg(all(feature = "rustls", feature = "h2"))]
#[test]
fn executes_http2_only_request_over_tls() {
let (base, _) = block_on(spawn_h2_tls_script_server(vec![vec![
H2TestResponse::text(200, "hello"),
]]))
.unwrap();
// Self-signed test cert: skip verification.
let client = crate::Client::builder()
.base_url(&base)
.unwrap()
.danger_accept_invalid_certs(true)
.build()
.unwrap();
let response = run(client.get("/secure").http2_only()).unwrap();
assert_eq!(response.version(), crate::Version::Http2);
assert_eq!(block_on(response.text()).unwrap(), "hello");
}
// With prior knowledge enabled, an http2_only request over plain http must
// speak h2c directly (no Upgrade dance) and complete over HTTP/2.
#[cfg(all(feature = "rustls", feature = "h2"))]
#[test]
fn executes_h2c_prior_knowledge_request_over_http() {
let (base, _) = block_on(spawn_h2c_script_server(vec![vec![H2TestResponse::text(
200, "h2c",
)]]))
.unwrap();
let response = run(get(format!("{base}/resource"))
.http2_only()
.prior_knowledge_h2c(true))
.unwrap();
assert_eq!(response.version(), crate::Version::Http2);
assert_eq!(block_on(response.text()).unwrap(), "h2c");
}
// Even without http2_only, the automatic protocol policy must pick h2c when
// prior knowledge is enabled on a plain-http request.
#[cfg(all(feature = "rustls", feature = "h2"))]
#[test]
fn auto_uses_h2c_when_prior_knowledge_enabled() {
let (base, _) = block_on(spawn_h2c_script_server(vec![vec![H2TestResponse::text(
200, "auto-h2c",
)]]))
.unwrap();
let response = run(get(format!("{base}/auto")).prior_knowledge_h2c(true)).unwrap();
assert_eq!(response.version(), crate::Version::Http2);
assert_eq!(block_on(response.text()).unwrap(), "auto-h2c");
}
// The prefer_http2 client option must select the HTTP/2 transport when the
// TLS server negotiates h2, without requiring http2_only on the request.
#[cfg(all(feature = "rustls", feature = "h2"))]
#[test]
fn prefer_http2_uses_http2_transport_when_available() {
let (base, _) = block_on(spawn_h2_tls_script_server(vec![vec![
H2TestResponse::text(200, "preferred"),
]]))
.unwrap();
// Self-signed test cert: skip verification.
let client = crate::Client::builder()
.base_url(&base)
.unwrap()
.danger_accept_invalid_certs(true)
.prefer_http2()
.build()
.unwrap();
let response = run(client.get("/pref")).unwrap();
assert_eq!(response.version(), crate::Version::Http2);
assert_eq!(block_on(response.text()).unwrap(), "preferred");
}
// Without prior knowledge, http2_only over plain http must perform the HTTP/1.1
// "Upgrade: h2c" handshake and then complete the request over HTTP/2.
#[cfg(all(feature = "rustls", feature = "h2"))]
#[test]
fn executes_h2c_upgrade_request_over_http() {
let base = block_on(spawn_h2c_upgrade_server(H2TestResponse::text(
200,
"upgrade-ok",
)))
.unwrap();
let response = run(get(format!("{base}/upgrade")).http2_only()).unwrap();
assert_eq!(response.version(), crate::Version::Http2);
assert_eq!(block_on(response.text()).unwrap(), "upgrade-ok");
}
// When the server ignores the h2c upgrade offer and answers with a plain
// HTTP/1.1 response, prefer_http2 must fall back to HTTP/1.1 transparently.
// The assertion callback also verifies the upgrade offer headers were sent.
#[cfg(all(feature = "rustls", feature = "h2"))]
#[test]
fn prefer_http2_falls_back_to_http1_when_h2c_upgrade_is_rejected() {
// Verifies the outgoing request carried the h2c upgrade headers; header
// casing may vary, so both forms are accepted.
fn assert_request(request: &str) {
assert!(request.starts_with("GET /fallback HTTP/1.1\r\n"));
assert!(
request.contains("\r\nUpgrade: h2c\r\n")
|| request.contains("\r\nupgrade: h2c\r\n")
);
assert!(
request.contains("\r\nConnection: Upgrade, HTTP2-Settings\r\n")
|| request.contains("\r\nconnection: Upgrade, HTTP2-Settings\r\n")
|| request.contains("\r\nconnection: upgrade, http2-settings\r\n")
);
}
// The server replies in plain HTTP/1.1, i.e. it rejects the upgrade.
let base = block_on(spawn_assert_server(
assert_request,
"HTTP/1.1 200 OK\r\nContent-Length: 8\r\nConnection: close\r\n\r\nfallback",
))
.unwrap();
let response = run(get(format!("{base}/fallback")).prefer_http2()).unwrap();
assert_eq!(response.version(), crate::Version::Http11);
assert_eq!(block_on(response.text()).unwrap(), "fallback");
}
// An HTTP/2 response split across multiple DATA frames must be reassembled,
// and trailing headers must be surfaced via bytes_and_trailers().
#[cfg(all(feature = "rustls", feature = "h2"))]
#[test]
fn streams_http2_response_body_and_trailers() {
let scripted = H2TestResponse::empty(200)
.header("content-length", "11")
.body_frame("hello")
.body_frame(" world")
.trailer("x-finished", "yes");
let (base, _) = block_on(spawn_h2_tls_script_server(vec![vec![scripted]])).unwrap();
// Self-signed test cert: skip verification.
let client = crate::Client::builder()
.base_url(&base)
.unwrap()
.danger_accept_invalid_certs(true)
.build()
.unwrap();
let response = run(client.get("/stream").http2_only()).unwrap();
let (body, trailers) = block_on(response.bytes_and_trailers()).unwrap();
assert_eq!(body, Bytes::from_static(b"hello world"));
let trailers = trailers.expect("trailers");
assert_eq!(trailers.get("x-finished"), Some("yes"));
}
// Dropping a response body stream mid-read must send RST_STREAM with error
// code CANCEL (8) for that stream only, leaving the HTTP/2 connection usable
// for a follow-up request on the same connection.
#[cfg(all(feature = "rustls", feature = "h2"))]
#[test]
fn dropping_http2_response_stream_sends_rst_stream_cancel_and_keeps_connection_reusable() {
let (base, connection_count, reset_codes) = block_on(spawn_h2_tls_cancel_server()).unwrap();
// Self-signed test cert: skip verification.
let client = crate::Client::builder()
.base_url(&base)
.unwrap()
.danger_accept_invalid_certs(true)
.build()
.unwrap();
let response = run(client.get("/stream").http2_only()).unwrap();
let mut body = response.bytes_stream();
// Read only the first chunk, then abandon the stream.
assert_eq!(
block_on(body.next()).unwrap().unwrap(),
Bytes::from_static(b"hello")
);
drop(body);
// A second request must succeed over the same connection.
let second = run(client.get("/after").http2_only()).unwrap();
assert_eq!(block_on(second.text()).unwrap(), "ok");
assert_eq!(connection_count.load(Ordering::SeqCst), 1);
// 8 == RST_STREAM error code CANCEL.
assert_eq!(reset_codes.lock().unwrap().as_slice(), &[8]);
}
// When the server answers before the request body is fully uploaded, the
// client must stop polling the body stream, cancel the stream with CANCEL (8),
// and keep the connection reusable. Timer durations make the race
// deterministic: the second body chunk is delayed past the server's early
// response, and the test sleeps long enough afterwards to observe that the
// second chunk was never polled.
#[cfg(all(feature = "rustls", feature = "h2"))]
#[test]
fn early_http2_response_cancels_remaining_request_body_and_keeps_connection_reusable() {
let (base, connection_count, reset_codes, recorded_requests) =
block_on(spawn_h2_tls_early_response_server()).unwrap();
// Self-signed test cert: skip verification.
let client = crate::Client::builder()
.base_url(&base)
.unwrap()
.danger_accept_invalid_certs(true)
.build()
.unwrap();
// Counts how many times the upload stream yields; must stay at 1.
let upload_polls = Arc::new(std::sync::atomic::AtomicUsize::new(0));
let upload_polls_task = Arc::clone(&upload_polls);
let body_stream = async_stream::stream! {
upload_polls_task.fetch_add(1, Ordering::SeqCst);
yield Ok(Bytes::from_static(b"hello"));
// Delay the second chunk so the server's early response arrives first.
async_io::Timer::after(std::time::Duration::from_millis(80)).await;
upload_polls_task.fetch_add(1, Ordering::SeqCst);
yield Ok(Bytes::from_static(b"world"));
};
let response = run(client
.post("/early")
.http2_only()
.timeout(std::time::Duration::from_secs(1))
.body_stream(Box::pin(body_stream)))
.unwrap();
assert_eq!(block_on(response.text()).unwrap(), "done");
// Wait past the stream's 80ms delay to prove the second chunk is never polled.
block_on(async {
async_io::Timer::after(std::time::Duration::from_millis(140)).await;
});
let second = run(client.get("/after").http2_only()).unwrap();
assert_eq!(block_on(second.text()).unwrap(), "ok");
assert_eq!(connection_count.load(Ordering::SeqCst), 1);
// 8 == RST_STREAM error code CANCEL.
assert_eq!(reset_codes.lock().unwrap().as_slice(), &[8]);
assert_eq!(upload_polls.load(Ordering::SeqCst), 1);
let recorded_requests = recorded_requests.lock().unwrap();
assert_eq!(recorded_requests.len(), 2);
assert_eq!(recorded_requests[0].header(":path"), Some("/early"));
// Only the first chunk made it onto the wire.
assert_eq!(recorded_requests[0].body(), b"hello");
assert_eq!(recorded_requests[1].header(":path"), Some("/after"));
}
// If the request body stream yields an error mid-upload, only that HTTP/2
// stream must be reset (CANCEL, code 8); the connection itself stays healthy
// and serves a subsequent request.
#[cfg(all(feature = "rustls", feature = "h2"))]
#[test]
fn request_body_stream_error_resets_only_the_failed_http2_stream() {
let (base, connection_count, reset_codes, recorded_requests) =
block_on(spawn_h2_tls_upload_error_server()).unwrap();
// Self-signed test cert: skip verification.
let client = crate::Client::builder()
.base_url(&base)
.unwrap()
.danger_accept_invalid_certs(true)
.build()
.unwrap();
// One good chunk, then a synthetic transport error from the body source.
let body_stream = async_stream::stream! {
yield Ok(Bytes::from_static(b"hello"));
yield Err(Error::new(ErrorKind::Transport, "upload failed"));
};
let err = run(client
.post("/upload")
.http2_only()
.timeout(std::time::Duration::from_secs(1))
.body_stream(Box::pin(body_stream)))
.unwrap_err();
assert_eq!(err.kind(), &ErrorKind::Transport);
assert!(err.to_string().contains("upload failed"));
// The same connection must still serve another request.
let second = run(client.get("/after").http2_only()).unwrap();
assert_eq!(block_on(second.text()).unwrap(), "ok");
assert_eq!(connection_count.load(Ordering::SeqCst), 1);
// 8 == RST_STREAM error code CANCEL.
assert_eq!(reset_codes.lock().unwrap().as_slice(), &[8]);
let recorded_requests = recorded_requests.lock().unwrap();
assert_eq!(recorded_requests.len(), 2);
assert_eq!(recorded_requests[0].header(":path"), Some("/upload"));
assert_eq!(recorded_requests[0].body(), b"hello");
assert_eq!(recorded_requests[1].header(":path"), Some("/after"));
}
// Protocol violation: a DATA frame arriving before the final response HEADERS
// must surface as a Transport error with a descriptive message.
#[cfg(all(feature = "rustls", feature = "h2"))]
#[test]
fn rejects_http2_data_frame_before_response_headers() {
let (base, connection_count) = block_on(spawn_h2_tls_protocol_error_server(
H2ProtocolErrorScenario::DataBeforeHeaders,
))
.unwrap();
// Self-signed test cert: skip verification.
let client = crate::Client::builder()
.base_url(&base)
.unwrap()
.danger_accept_invalid_certs(true)
.build()
.unwrap();
let err = run(client
.get("/bad")
.http2_only()
.timeout(std::time::Duration::from_secs(1)))
.unwrap_err();
let message = err.to_string();
assert_eq!(err.kind(), &ErrorKind::Transport);
assert!(message.contains("http2 DATA frame arrived before final response headers"));
assert_eq!(connection_count.load(Ordering::SeqCst), 1);
}
// Protocol violation: trailing HEADERS without END_STREAM must fail while
// reading the body (the response headers themselves were fine, so the error
// surfaces from text(), not from the request).
#[cfg(all(feature = "rustls", feature = "h2"))]
#[test]
fn rejects_http2_trailers_without_end_stream() {
let (base, connection_count) = block_on(spawn_h2_tls_protocol_error_server(
H2ProtocolErrorScenario::TrailersWithoutEndStream,
))
.unwrap();
// Self-signed test cert: skip verification.
let client = crate::Client::builder()
.base_url(&base)
.unwrap()
.danger_accept_invalid_certs(true)
.build()
.unwrap();
let response = run(client
.get("/bad")
.http2_only()
.timeout(std::time::Duration::from_secs(1)))
.unwrap();
// The violation is detected while draining the body.
let err = block_on(response.text()).unwrap_err();
assert_eq!(err.kind(), &ErrorKind::Transport);
assert!(
err.to_string()
.contains("http2 trailing headers must end the stream")
);
assert_eq!(connection_count.load(Ordering::SeqCst), 1);
}
// The client must tolerate a server advertising SETTINGS_ENABLE_PUSH in its
// SETTINGS frame and still complete the request normally.
#[cfg(all(feature = "rustls", feature = "h2"))]
#[test]
fn accepts_http2_settings_enable_push_from_server() {
    // 0x2 == SETTINGS_ENABLE_PUSH, value 0 (push disabled).
    let settings = encode_h2_test_settings(&[(0x2, 0)]);
    let (base, _connection_count, _recorded_requests) =
        block_on(spawn_h2_tls_recording_server_with_settings(
            vec![vec![H2TestResponse::text(200, "ok")]],
            settings,
        ))
        .unwrap();
    // Self-signed test cert: skip verification.
    let client = crate::Client::builder()
        .base_url(&base)
        .unwrap()
        .danger_accept_invalid_certs(true)
        .build()
        .unwrap();
    let response = run(client
        .get("/")
        .http2_only()
        .timeout(std::time::Duration::from_secs(5)))
    .unwrap();
    assert_eq!(response.status().as_u16(), 200);
}
// A request whose header block exceeds the server-advertised
// SETTINGS_MAX_HEADER_LIST_SIZE (0x6) must be rejected locally with a
// Transport error before anything is sent, keeping the connection intact.
#[cfg(all(feature = "rustls", feature = "h2"))]
#[test]
fn rejects_http2_request_headers_that_exceed_peer_max_header_list_size() {
// 0x6 == SETTINGS_MAX_HEADER_LIST_SIZE, capped at 512 bytes.
let settings = encode_h2_test_settings(&[(0x6, 512)]);
let (base, connection_count, recorded_requests) =
block_on(spawn_h2_tls_recording_server_with_settings(
vec![vec![H2TestResponse::text(200, "one")]],
settings,
))
.unwrap();
// Self-signed test cert: skip verification.
let client = crate::Client::builder()
.base_url(&base)
.unwrap()
.danger_accept_invalid_certs(true)
.build()
.unwrap();
// A small request fits under the limit.
let first = run(client.get("/one").http2_only()).unwrap();
assert_eq!(block_on(first.text()).unwrap(), "one");
// 600 bytes of header value blows past the 512-byte cap.
let oversized = "a".repeat(600);
let err = run(client
.get("/two")
.header("x-oversized", &oversized)
.unwrap()
.http2_only())
.unwrap_err();
assert_eq!(err.kind(), &ErrorKind::Transport);
assert!(err.to_string().contains("SETTINGS_MAX_HEADER_LIST_SIZE"));
// The oversized request never reached the server; same connection reused.
assert_eq!(recorded_requests.lock().unwrap().len(), 1);
assert_eq!(connection_count.load(Ordering::SeqCst), 1);
}
// A gzip-encoded HTTP/2 response body must be transparently decompressed
// based on the content-encoding header.
#[cfg(all(feature = "rustls", feature = "h2"))]
#[test]
fn decodes_gzip_http2_response_body() {
let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
encoder.write_all(b"compressed").unwrap();
let compressed = encoder.finish().unwrap();
let scripted = H2TestResponse::empty(200)
.header("content-encoding", "gzip")
.header("content-length", compressed.len().to_string())
.body_frame(compressed);
let (base, _) = block_on(spawn_h2_tls_script_server(vec![vec![scripted]])).unwrap();
// Self-signed test cert: skip verification.
let client = crate::Client::builder()
.base_url(&base)
.unwrap()
.danger_accept_invalid_certs(true)
.build()
.unwrap();
let response = run(client.get("/gzip").http2_only()).unwrap();
assert_eq!(block_on(response.text()).unwrap(), "compressed");
}
// A zstd-encoded HTTP/2 response body must be transparently decompressed
// when the zstd feature is enabled.
#[cfg(all(feature = "rustls", feature = "h2", feature = "zstd"))]
#[test]
fn decodes_zstd_http2_response_body() {
// Level 0 == zstd's default compression level.
let compressed = zstd::stream::encode_all(&b"compressed zstd"[..], 0).unwrap();
let response = H2TestResponse::empty(200)
.header("content-encoding", "zstd")
.header("content-length", compressed.len().to_string())
.body_frame(compressed);
let (base, _, _) = block_on(spawn_h2_tls_recording_server(vec![vec![response]])).unwrap();
// Self-signed test cert: skip verification.
let client = crate::Client::builder()
.base_url(&base)
.unwrap()
.danger_accept_invalid_certs(true)
.build()
.unwrap();
let response = run(client.get("/zstd").http2_only()).unwrap();
assert_eq!(block_on(response.text()).unwrap(), "compressed zstd");
}
// HTTP/2 requests must carry the crate's default accept-encoding header when
// the caller does not set one explicitly.
#[cfg(all(feature = "rustls", feature = "h2"))]
#[test]
fn http2_sends_default_accept_encoding_header() {
let (base, _, recorded_requests) = block_on(spawn_h2_tls_recording_server(vec![vec![
H2TestResponse::text(200, "ok"),
]]))
.unwrap();
// Self-signed test cert: skip verification.
let client = crate::Client::builder()
.base_url(&base)
.unwrap()
.danger_accept_invalid_certs(true)
.build()
.unwrap();
let response = run(client.get("/headers").http2_only()).unwrap();
assert_eq!(block_on(response.text()).unwrap(), "ok");
let recorded = recorded_requests.lock().unwrap();
assert_eq!(recorded.len(), 1);
assert_eq!(
recorded[0].header("accept-encoding"),
Some(DEFAULT_ACCEPT_ENCODING)
);
}
// Browser emulation (Chrome 136 profile): the HTTP/2 connection preface must
// match Chrome's fingerprint — its SETTINGS values and order, a connection
// WINDOW_UPDATE before the first HEADERS, no PRIORITY frames, and Chrome's
// pseudo-header/regular-header ordering on the request.
#[cfg(all(feature = "rustls", feature = "h2", feature = "emulation"))]
#[test]
fn http2_emulation_chrome_writes_expected_lifecycle_frames_and_header_order() {
let (base, capture) = block_on(spawn_h2_tls_lifecycle_probe_server()).unwrap();
// Self-signed test cert: skip verification.
let client = crate::Client::builder()
.base_url(&base)
.unwrap()
.danger_accept_invalid_certs(true)
.emulation(Emulation::Chrome136)
.build()
.unwrap();
let response = run(client.get("/probe").http2_only()).unwrap();
assert_eq!(block_on(response.text()).unwrap(), "ok");
let (lifecycle, request) = capture
.lock()
.unwrap()
.take()
.expect("captured h2 lifecycle");
// Expected Chrome SETTINGS: (id, value) pairs in Chrome's order —
// HEADER_TABLE_SIZE, ENABLE_PUSH, INITIAL_WINDOW_SIZE, MAX_FRAME_SIZE.
assert_eq!(
lifecycle.client_settings,
Some(vec![
(0x1, 65_536),
(0x2, 0),
(0x4, 6_291_456),
(0x5, 16_384)
])
);
assert_eq!(lifecycle.connection_window_increment, Some(15_597_570));
// The Chrome profile emits no PRIORITY frames.
assert!(lifecycle.priority_frames.is_empty());
let first_headers_idx = lifecycle
.frame_sequence
.iter()
.position(|(frame_type, _, _)| *frame_type == H2_TYPE_HEADERS)
.expect("headers frame in lifecycle");
// The client's own (non-ACK) SETTINGS on stream 0.
let settings_idx = lifecycle
.frame_sequence
.iter()
.position(|(frame_type, flags, stream_id)| {
*frame_type == H2_TYPE_SETTINGS && *flags & H2_ACK == 0 && *stream_id == 0
})
.expect("client settings frame");
let window_update_idx = lifecycle
.frame_sequence
.iter()
.position(|(frame_type, _, stream_id)| {
*frame_type == H2_TYPE_WINDOW_UPDATE && *stream_id == 0
})
.expect("connection window update");
// Frame order: SETTINGS first, connection WINDOW_UPDATE before HEADERS.
assert_eq!(settings_idx, 0);
assert!(window_update_idx < first_headers_idx);
let names = request
.headers
.iter()
.map(|(name, _)| name.as_str())
.collect::<Vec<_>>();
// Chrome's header order: pseudo-headers first, then UA/accept block.
assert_eq!(
names,
vec![
":method",
":authority",
":scheme",
":path",
"user-agent",
"accept",
"accept-language",
"accept-encoding",
"content-length",
]
);
assert_eq!(request.header(":path"), Some("/probe"));
assert_eq!(
request.header("accept-encoding"),
Some(DEFAULT_ACCEPT_ENCODING)
);
assert_eq!(request.header("content-length"), Some("0"));
}
// A custom HTTP/2 fingerprint must be honored verbatim: SETTINGS emitted in
// the configured order, the connection window sized from
// initial_connection_window_size, a PRIORITY frame written before HEADERS
// (stream_id: None means "use the request's own stream"), and headers emitted
// in the configured pseudo/regular order.
#[cfg(all(feature = "rustls", feature = "h2", feature = "emulation"))]
#[test]
fn http2_custom_fingerprint_writes_priority_before_headers() {
let (base, capture) = block_on(spawn_h2_tls_lifecycle_probe_server()).unwrap();
let profile = EmulationProfile::builder()
.profile_id("custom-h2-probe")
.http2_fingerprint(Http2Fingerprint {
settings_order: vec![
"INITIAL_WINDOW_SIZE".to_owned(),
"HEADER_TABLE_SIZE".to_owned(),
"ENABLE_PUSH".to_owned(),
"MAX_FRAME_SIZE".to_owned(),
],
pseudo_header_order: vec![
":method".to_owned(),
":path".to_owned(),
":authority".to_owned(),
":scheme".to_owned(),
],
regular_header_order: vec!["x-test".to_owned(), "accept-encoding".to_owned()],
header_table_size: Some(32_768),
initial_window_size: Some(131_072),
initial_connection_window_size: Some(1_048_576),
max_frame_size: Some(32_768),
// One PRIORITY frame for the request stream, before its HEADERS.
priorities: vec![Http2PrioritySpec {
stream_id: None,
phase: Http2PriorityPhase::BeforeHeaders,
stream_dependency: 0,
weight: 220,
exclusive: true,
}],
})
.build();
// Self-signed test cert: skip verification.
let client = crate::Client::builder()
.base_url(&base)
.unwrap()
.danger_accept_invalid_certs(true)
.emulation(profile)
.build()
.unwrap();
let response = run(client
.get("/custom")
.http2_only()
.header("x-test", "1")
.unwrap())
.unwrap();
assert_eq!(block_on(response.text()).unwrap(), "ok");
let (lifecycle, request) = capture
.lock()
.unwrap()
.take()
.expect("captured h2 lifecycle");
// SETTINGS appear in the configured order with the configured values.
assert_eq!(
lifecycle.client_settings,
Some(vec![(0x4, 131_072), (0x1, 32_768), (0x2, 0), (0x5, 32_768)])
);
// 983_041 = 1_048_576 (target) - 65_535 (default initial window).
assert_eq!(lifecycle.connection_window_increment, Some(983_041));
// (stream_id, dependency, weight, exclusive): stream 1 is the first request.
assert_eq!(lifecycle.priority_frames, vec![(1, 0, 220, true)]);
let first_priority_idx = lifecycle
.frame_sequence
.iter()
.position(|(frame_type, _, _)| *frame_type == H2_TYPE_PRIORITY)
.expect("priority frame in lifecycle");
let first_headers_idx = lifecycle
.frame_sequence
.iter()
.position(|(frame_type, _, _)| *frame_type == H2_TYPE_HEADERS)
.expect("headers frame in lifecycle");
assert!(first_priority_idx < first_headers_idx);
let names = request
.headers
.iter()
.map(|(name, _)| name.as_str())
.collect::<Vec<_>>();
// Header order follows the fingerprint's pseudo/regular ordering.
assert_eq!(
names,
vec![
":method",
":path",
":authority",
":scheme",
"x-test",
"accept-encoding",
"content-length",
]
);
assert_eq!(request.header(":path"), Some("/custom"));
assert_eq!(request.header("x-test"), Some("1"));
assert_eq!(request.header("content-length"), Some("0"));
}
// A fingerprint may declare a placeholder priority tree: PRIORITY frames for
// explicit stream ids (3 and 5 here, never actually opened) that must all be
// written before the first HEADERS frame.
#[cfg(all(feature = "rustls", feature = "h2", feature = "emulation"))]
#[test]
fn http2_custom_fingerprint_can_emit_placeholder_priority_tree_before_headers() {
let (base, capture) = block_on(spawn_h2_tls_lifecycle_probe_server()).unwrap();
let profile = EmulationProfile::builder()
.profile_id("custom-h2-priority-tree")
.http2_fingerprint(Http2Fingerprint {
settings_order: vec![
"HEADER_TABLE_SIZE".to_owned(),
"ENABLE_PUSH".to_owned(),
"INITIAL_WINDOW_SIZE".to_owned(),
"MAX_FRAME_SIZE".to_owned(),
],
pseudo_header_order: vec![
":method".to_owned(),
":authority".to_owned(),
":scheme".to_owned(),
":path".to_owned(),
],
regular_header_order: vec!["accept-encoding".to_owned()],
header_table_size: Some(65_536),
initial_window_size: Some(6_291_456),
initial_connection_window_size: Some(15_663_105),
max_frame_size: Some(16_384),
// Placeholder tree: stream 5 depends exclusively on stream 3.
priorities: vec![
Http2PrioritySpec {
stream_id: Some(3),
phase: Http2PriorityPhase::BeforeHeaders,
stream_dependency: 0,
weight: 201,
exclusive: false,
},
Http2PrioritySpec {
stream_id: Some(5),
phase: Http2PriorityPhase::BeforeHeaders,
stream_dependency: 3,
weight: 101,
exclusive: true,
},
],
})
.build();
// Self-signed test cert: skip verification.
let client = crate::Client::builder()
.base_url(&base)
.unwrap()
.danger_accept_invalid_certs(true)
.emulation(profile)
.build()
.unwrap();
let response = run(client.get("/tree").http2_only()).unwrap();
assert_eq!(block_on(response.text()).unwrap(), "ok");
let (lifecycle, request) = capture
.lock()
.unwrap()
.take()
.expect("captured h2 lifecycle");
// (stream_id, dependency, weight, exclusive) in declaration order.
assert_eq!(
lifecycle.priority_frames,
vec![(3, 0, 201, false), (5, 3, 101, true)]
);
let priority_indices = lifecycle
.frame_sequence
.iter()
.enumerate()
.filter_map(|(idx, (frame_type, _, _))| {
(*frame_type == H2_TYPE_PRIORITY).then_some(idx)
})
.collect::<Vec<_>>();
let first_headers_idx = lifecycle
.frame_sequence
.iter()
.position(|(frame_type, _, _)| *frame_type == H2_TYPE_HEADERS)
.expect("headers frame in lifecycle");
assert_eq!(priority_indices.len(), 2);
// Both PRIORITY frames precede the first HEADERS frame.
assert!(priority_indices.iter().all(|idx| *idx < first_headers_idx));
assert_eq!(request.header(":path"), Some("/tree"));
}
// A fingerprint may mix phases: one PRIORITY frame before HEADERS (placeholder
// stream 3) and one AfterHeaders reprioritization of the request's own stream
// (stream_id: None resolves to stream 1), emitted after HEADERS.
#[cfg(all(feature = "rustls", feature = "h2", feature = "emulation"))]
#[test]
fn http2_custom_fingerprint_can_emit_post_headers_reprioritization() {
let (base, capture) = block_on(spawn_h2_tls_lifecycle_probe_server()).unwrap();
let profile = EmulationProfile::builder()
.profile_id("custom-h2-post-headers-priority")
.http2_fingerprint(Http2Fingerprint {
settings_order: vec![
"HEADER_TABLE_SIZE".to_owned(),
"ENABLE_PUSH".to_owned(),
"INITIAL_WINDOW_SIZE".to_owned(),
"MAX_FRAME_SIZE".to_owned(),
],
pseudo_header_order: vec![
":method".to_owned(),
":authority".to_owned(),
":scheme".to_owned(),
":path".to_owned(),
],
regular_header_order: vec!["x-test".to_owned(), "accept-encoding".to_owned()],
header_table_size: Some(65_536),
initial_window_size: Some(6_291_456),
initial_connection_window_size: Some(15_663_105),
max_frame_size: Some(16_384),
priorities: vec![
// Placeholder before the request's HEADERS.
Http2PrioritySpec {
stream_id: Some(3),
phase: Http2PriorityPhase::BeforeHeaders,
stream_dependency: 0,
weight: 201,
exclusive: false,
},
// Reprioritize the request stream itself after HEADERS.
Http2PrioritySpec {
stream_id: None,
phase: Http2PriorityPhase::AfterHeaders,
stream_dependency: 3,
weight: 111,
exclusive: false,
},
],
})
.build();
// Self-signed test cert: skip verification.
let client = crate::Client::builder()
.base_url(&base)
.unwrap()
.danger_accept_invalid_certs(true)
.emulation(profile)
.build()
.unwrap();
let response = run(client
.get("/reprioritize")
.http2_only()
.header("x-test", "1")
.unwrap())
.unwrap();
assert_eq!(block_on(response.text()).unwrap(), "ok");
let (lifecycle, request) = capture
.lock()
.unwrap()
.take()
.expect("captured h2 lifecycle");
// (stream_id, dependency, weight, exclusive); stream 1 is the request.
assert_eq!(
lifecycle.priority_frames,
vec![(3, 0, 201, false), (1, 3, 111, false)]
);
let priority_indices = lifecycle
.frame_sequence
.iter()
.enumerate()
.filter_map(|(idx, (frame_type, _, _))| {
(*frame_type == H2_TYPE_PRIORITY).then_some(idx)
})
.collect::<Vec<_>>();
let first_headers_idx = lifecycle
.frame_sequence
.iter()
.position(|(frame_type, _, _)| *frame_type == H2_TYPE_HEADERS)
.expect("headers frame in lifecycle");
assert_eq!(priority_indices.len(), 2);
// First PRIORITY precedes HEADERS; the reprioritization follows it.
assert!(priority_indices[0] < first_headers_idx);
assert!(priority_indices[1] > first_headers_idx);
assert_eq!(request.header(":path"), Some("/reprioritize"));
assert_eq!(request.header("x-test"), Some("1"));
}
// End-to-end SSE over HTTP/2: the last_event_id() request option must be sent
// as the last-event-id header, and the event-stream body must parse into
// events with id and data fields.
#[cfg(all(feature = "rustls", feature = "h2"))]
#[test]
fn sse_last_event_id_flows_end_to_end_over_http2() {
let scripted = H2TestResponse::text(200, "id: 9\ndata: hello\n\n")
.header("content-type", "text/event-stream");
let (base, _, recorded_requests) =
block_on(spawn_h2_tls_recording_server(vec![vec![scripted]])).unwrap();
// Self-signed test cert: skip verification.
let client = crate::Client::builder()
.base_url(&base)
.unwrap()
.danger_accept_invalid_certs(true)
.build()
.unwrap();
let response = run(client
.get("/events")
.http2_only()
.last_event_id("evt-9")
.unwrap())
.unwrap();
let mut stream = response.sse().unwrap();
let event = block_on(async { stream.next().await.unwrap().unwrap() });
assert_eq!(event.id.as_deref(), Some("9"));
assert_eq!(event.data, "hello");
let recorded = recorded_requests.lock().unwrap();
assert_eq!(recorded.len(), 1);
assert_eq!(recorded[0].header("last-event-id"), Some("evt-9"));
}
// Two sequential requests to the same h2c origin must multiplex over a single
// pooled connection (connection count stays at 1).
#[cfg(all(feature = "rustls", feature = "h2"))]
#[test]
fn reuses_h2c_connection_for_multiple_requests() {
let (base, connection_count) = block_on(spawn_h2c_script_server(vec![vec![
H2TestResponse::text(200, "one"),
H2TestResponse::text(200, "two"),
]]))
.unwrap();
let client = crate::Client::builder()
.http2_only()
.prior_knowledge_h2c(true)
.build()
.unwrap();
let first = run(client.get(format!("{base}/one"))).unwrap();
let second = run(client.get(format!("{base}/two"))).unwrap();
assert_eq!(block_on(first.text()).unwrap(), "one");
assert_eq!(block_on(second.text()).unwrap(), "two");
assert_eq!(connection_count.load(Ordering::SeqCst), 1);
}
// Two sequential requests over HTTP/2 (TLS) must reuse a single pooled
// connection (connection count stays at 1).
#[cfg(all(feature = "rustls", feature = "h2"))]
#[test]
fn reuses_http2_connection_for_multiple_requests() {
let (base, connection_count) = block_on(spawn_h2_tls_script_server(vec![vec![
H2TestResponse::text(200, "one"),
H2TestResponse::text(200, "two"),
]]))
.unwrap();
// Self-signed test cert: skip verification.
let client = crate::Client::builder()
.base_url(&base)
.unwrap()
.danger_accept_invalid_certs(true)
.build()
.unwrap();
let first = run(client.get("/one").http2_only()).unwrap().text();
let second = run(client.get("/two").http2_only()).unwrap().text();
assert_eq!(block_on(first).unwrap(), "one");
assert_eq!(block_on(second).unwrap(), "two");
assert_eq!(connection_count.load(Ordering::SeqCst), 1);
}
// A pooled HTTP/2 connection that receives a PING ACK during idle keepalive
// must still be considered healthy and reused after an idle pause.
#[cfg(all(feature = "rustls", feature = "h2"))]
#[test]
fn reuses_http2_connection_after_keepalive_ping_ack() {
let (base, connection_count) = block_on(spawn_h2_tls_keepalive_ack_server()).unwrap();
// Self-signed test cert: skip verification.
let client = crate::Client::builder()
.base_url(&base)
.unwrap()
.danger_accept_invalid_certs(true)
.build()
.unwrap();
let first = run(client.get("/one").http2_only()).unwrap().text();
assert_eq!(block_on(first).unwrap(), "one");
// Idle long enough for a keepalive ping/ack round trip to happen.
std::thread::sleep(std::time::Duration::from_millis(80));
let second = run(client.get("/two").http2_only()).unwrap().text();
assert_eq!(block_on(second).unwrap(), "two");
assert_eq!(connection_count.load(Ordering::SeqCst), 1);
}
// If the server never ACKs keepalive pings, the pooled connection must be
// discarded after the ping timeout and a fresh connection opened for the
// next request (connection count reaches 2).
#[cfg(all(feature = "rustls", feature = "h2"))]
#[test]
fn opens_new_http2_connection_after_keepalive_ping_timeout() {
let (base, connection_count) = block_on(spawn_h2_tls_keepalive_timeout_server()).unwrap();
// Self-signed test cert: skip verification.
let client = crate::Client::builder()
.base_url(&base)
.unwrap()
.danger_accept_invalid_certs(true)
// 30ms ping interval and 30ms ping timeout.
.h2_keepalive(
std::time::Duration::from_millis(30),
std::time::Duration::from_millis(30),
)
.build()
.unwrap();
let first = run(client.get("/one").http2_only()).unwrap().text();
assert_eq!(block_on(first).unwrap(), "one");
// Sleep past interval + timeout so the unanswered ping expires.
std::thread::sleep(std::time::Duration::from_millis(120));
let second = run(client.get("/two").http2_only()).unwrap().text();
assert_eq!(block_on(second).unwrap(), "two");
assert_eq!(connection_count.load(Ordering::SeqCst), 2);
}
// After a server GOAWAY the connection must not be reused: the next request
// opens a new connection (connection count reaches 2).
#[cfg(all(feature = "rustls", feature = "h2"))]
#[test]
fn opens_new_http2_connection_after_goaway() {
let (base, connection_count) = block_on(spawn_h2_tls_script_server(vec![
// First connection: one response, then GOAWAY.
vec![H2TestResponse::text(200, "first").goaway()],
vec![H2TestResponse::text(200, "second")],
]))
.unwrap();
// Self-signed test cert: skip verification.
let client = crate::Client::builder()
.base_url(&base)
.unwrap()
.danger_accept_invalid_certs(true)
.build()
.unwrap();
let first = run(client.get("/one").http2_only()).unwrap().text();
let second = run(client.get("/two").http2_only()).unwrap().text();
assert_eq!(block_on(first).unwrap(), "first");
assert_eq!(block_on(second).unwrap(), "second");
assert_eq!(connection_count.load(Ordering::SeqCst), 2);
}
// Per-request keepalive settings must not affect pooling: two requests with
// different h2_keepalive values still share one connection.
#[cfg(all(feature = "rustls", feature = "h2"))]
#[test]
fn reuses_http2_connection_regardless_of_keepalive_config() {
let (base, connection_count) = block_on(spawn_h2_tls_script_server(vec![vec![
H2TestResponse::text(200, "first"),
H2TestResponse::text(200, "second"),
]]))
.unwrap();
// Self-signed test cert: skip verification.
let client = crate::Client::builder()
.base_url(&base)
.unwrap()
.danger_accept_invalid_certs(true)
.build()
.unwrap();
// First request: 30ms keepalive interval/timeout.
let first = run(client.get("/one").http2_only().h2_keepalive(
std::time::Duration::from_millis(30),
std::time::Duration::from_millis(30),
))
.unwrap()
.text();
assert_eq!(block_on(first).unwrap(), "first");
// Second request: different (80ms) keepalive config, same connection.
let second = run(client.get("/two").http2_only().h2_keepalive(
std::time::Duration::from_millis(80),
std::time::Duration::from_millis(80),
))
.unwrap()
.text();
assert_eq!(block_on(second).unwrap(), "second");
assert_eq!(connection_count.load(Ordering::SeqCst), 1);
}
// If a reused connection dies with GOAWAY while a GET is in flight, the
// request must be retried transparently on a fresh connection; the caller
// sees a success, and only one extra connection is opened.
#[cfg(all(feature = "rustls", feature = "h2"))]
#[test]
fn retries_get_request_transparently_after_http2_goaway_on_reused_connection() {
let (base, connection_count) = block_on(spawn_h2_tls_script_server(vec![
vec![
H2TestResponse::text(200, "first"),
// Second response on connection 1 is poisoned with GOAWAY.
H2TestResponse::text(200, "retried").goaway(),
],
// Connection 2 serves the retried request.
vec![H2TestResponse::text(200, "retried")],
]))
.unwrap();
// Self-signed test cert: skip verification.
let client = crate::Client::builder()
.base_url(&base)
.unwrap()
.danger_accept_invalid_certs(true)
.build()
.unwrap();
let first = run(client.get("/one").http2_only()).unwrap().text();
assert_eq!(block_on(first).unwrap(), "first");
let second = run(client.get("/two").http2_only()).unwrap().text();
assert_eq!(block_on(second).unwrap(), "retried");
let third = run(client.get("/three").http2_only()).unwrap().text();
assert_eq!(block_on(third).unwrap(), "retried");
assert_eq!(connection_count.load(Ordering::SeqCst), 2);
}
// A HEAD response over HTTP/2 must yield an empty body even though the
// scripted response advertises a non-zero content-length.
#[cfg(all(feature = "rustls", feature = "h2"))]
#[test]
fn head_request_returns_empty_http2_body() {
let scripted = H2TestResponse::empty(200).header("content-length", "5");
let (base, _) = block_on(spawn_h2_tls_script_server(vec![vec![scripted]])).unwrap();
// Self-signed test cert: skip verification.
let client = crate::Client::builder()
.base_url(&base)
.unwrap()
.danger_accept_invalid_certs(true)
.build()
.unwrap();
let response = run(client.head("/head").http2_only()).unwrap();
assert_eq!(response.version(), crate::Version::Http2);
assert_eq!(block_on(response.text()).unwrap(), "");
}
// A 204 No Content response that illegally carries a DATA frame must never
// surface body bytes to the caller. The client may either reject the
// response/body with a Transport error or deliver the 204 with an empty
// body — both outcomes are accepted here.
#[cfg(all(feature = "rustls", feature = "h2"))]
#[test]
fn no_content_response_does_not_surface_http2_body_frames() {
let scripted = H2TestResponse::empty(204)
.header("content-length", "4")
// Illegal payload on a 204 response.
.body_frame("oops");
let (base, _) = block_on(spawn_h2_tls_script_server(vec![vec![scripted]])).unwrap();
// Self-signed test cert: skip verification.
let client = crate::Client::builder()
.base_url(&base)
.unwrap()
.danger_accept_invalid_certs(true)
.build()
.unwrap();
match run(client.get("/no-content").http2_only()) {
Ok(response) => {
assert_eq!(response.version(), crate::Version::Http2);
assert_eq!(response.status().as_u16(), 204);
match block_on(response.text()) {
// Either the body is empty, or reading it fails as Transport.
Ok(text) => assert_eq!(text, ""),
Err(err) => assert_eq!(err.kind(), &crate::ErrorKind::Transport),
}
}
// Or the request itself is rejected as a protocol violation.
Err(err) => assert_eq!(err.kind(), &crate::ErrorKind::Transport),
}
}
// A streamed request body with unknown length must be sent over HTTP/2 with
// neither content-length nor transfer-encoding (HTTP/2 uses END_STREAM for
// framing), and all chunks must arrive in order.
#[cfg(all(feature = "rustls", feature = "h2"))]
#[test]
fn streams_http2_request_body_without_content_length() {
let scripted = H2TestResponse::text(200, "ok");
let (base, _, recorded_requests) =
block_on(spawn_h2_tls_recording_server(vec![vec![scripted]])).unwrap();
// Self-signed test cert: skip verification.
let client = crate::Client::builder()
.base_url(&base)
.unwrap()
.danger_accept_invalid_certs(true)
.build()
.unwrap();
// Small delays between chunks force genuinely incremental streaming.
let body_stream = async_stream::stream! {
yield Ok(Bytes::from_static(b"hello"));
async_io::Timer::after(std::time::Duration::from_millis(25)).await;
yield Ok(Bytes::from_static(b"-"));
async_io::Timer::after(std::time::Duration::from_millis(25)).await;
yield Ok(Bytes::from_static(b"world"));
};
let response = run(client
.post("/upload")
.http2_only()
.timeout(std::time::Duration::from_secs(1))
.body_stream(Box::pin(body_stream)))
.unwrap();
assert_eq!(block_on(response.text()).unwrap(), "ok");
let recorded_requests = recorded_requests.lock().unwrap();
assert_eq!(recorded_requests.len(), 1);
let recorded = &recorded_requests[0];
assert_eq!(recorded.header(":method"), Some("POST"));
assert_eq!(recorded.header(":path"), Some("/upload"));
// HTTP/2 streaming bodies carry neither of these framing headers.
assert_eq!(recorded.header("content-length"), None);
assert_eq!(recorded.header("transfer-encoding"), None);
assert_eq!(recorded.body(), b"hello-world");
}
#[cfg(all(feature = "rustls", feature = "h2"))]
#[test]
// A unary JSON gRPC call works over cleartext HTTP/2 (h2c) when both the
// client and the request opt into prior knowledge, and the mandatory gRPC
// request headers are emitted on the wire.
fn executes_grpc_unary_json_request_over_h2c() {
    #[derive(serde::Serialize)]
    struct HelloRequest {
        name: String,
    }
    #[derive(serde::Deserialize, serde::Serialize)]
    struct HelloReply {
        ok: bool,
    }
    let reply = HelloReply { ok: true };
    // Successful unary reply: one JSON frame plus an OK status trailer.
    let scripted = H2TestResponse::empty(200)
        .header("content-type", "application/grpc+json")
        .body_frame(grpc_test_json_frame(&reply))
        .trailer("grpc-status", "0");
    let (base, _count, recorded_requests) =
        block_on(spawn_h2c_recording_server(vec![vec![scripted]])).unwrap();
    let mut response = block_on(async {
        crate::Client::builder()
            .prior_knowledge_h2c(true)
            .build()
            .unwrap()
            .grpc(format!("{base}/helloworld.Greeter/SayHello"))
            .prior_knowledge_h2c(true)
            .message(&HelloRequest {
                name: "Ada".to_owned(),
            })?
            .await
    })
    .unwrap();
    let message = block_on(response.message::<HelloReply>()).unwrap();
    assert!(message.ok);
    let recorded = recorded_requests.lock().unwrap().pop().unwrap();
    // gRPC requests are always POST and must advertise trailer support.
    assert_eq!(recorded.header(":method"), Some("POST"));
    assert_eq!(
        recorded.header("content-type"),
        Some("application/grpc+json")
    );
    assert_eq!(recorded.header("te"), Some("trailers"));
}
#[cfg(all(feature = "rustls", feature = "h2"))]
#[test]
// End-to-end unary JSON gRPC call over HTTP/2 + TLS: custom request metadata
// is forwarded, the reply message decodes, the OK status and trailers are
// exposed, and the recorded request carries the required gRPC headers and a
// length-prefixed JSON frame as its body.
fn executes_grpc_unary_json_request_over_http2() {
    #[derive(Clone, Debug, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
    struct HelloRequest {
        name: String,
    }
    #[derive(Clone, Debug, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
    struct HelloReply {
        message: String,
    }
    let request = HelloRequest {
        name: "Ada".to_owned(),
    };
    let reply = HelloReply {
        message: "hello Ada".to_owned(),
    };
    // OK reply with an extra application trailer (x-trace-id).
    let scripted = H2TestResponse::empty(200)
        .header("content-type", "application/grpc+json")
        .body_frame(grpc_test_json_frame(&reply))
        .trailer("grpc-status", "0")
        .trailer("x-trace-id", "abc");
    let (base, _, recorded_requests) =
        block_on(spawn_h2_tls_recording_server(vec![vec![scripted]])).unwrap();
    let client = crate::Client::builder()
        .base_url(&base)
        .unwrap()
        .danger_accept_invalid_certs(true)
        .build()
        .unwrap();
    let mut response = block_on(async {
        client
            .grpc("/helloworld.Greeter/SayHello")
            .metadata("x-request-id", "req-1")?
            .message(&request)?
            .await
    })
    .unwrap();
    assert_eq!(block_on(response.message::<HelloReply>()).unwrap(), reply);
    let status = response.status().unwrap();
    assert_eq!(status.code(), 0);
    assert_eq!(status.message(), "");
    let trailers = response.trailers().unwrap().unwrap();
    assert_eq!(trailers.get("grpc-status"), Some("0"));
    assert_eq!(trailers.get("x-trace-id"), Some("abc"));
    let recorded_requests = recorded_requests.lock().unwrap();
    assert_eq!(recorded_requests.len(), 1);
    let recorded = &recorded_requests[0];
    // Mandatory gRPC request shape: POST to the method path, grpc content
    // type, te: trailers, plus the caller-supplied metadata.
    assert_eq!(recorded.header(":method"), Some("POST"));
    assert_eq!(
        recorded.header(":path"),
        Some("/helloworld.Greeter/SayHello")
    );
    assert_eq!(
        recorded.header("content-type"),
        Some("application/grpc+json")
    );
    assert_eq!(recorded.header("te"), Some("trailers"));
    assert_eq!(recorded.header("x-request-id"), Some("req-1"));
    assert_eq!(recorded.body(), grpc_test_json_frame(&request).as_slice());
}
#[cfg(all(feature = "rustls", feature = "h2"))]
#[test]
// Binary (-bin suffixed) gRPC metadata round-trips through the helper APIs:
// request metadata_bin is base64-encoded onto the wire with the -bin suffix,
// and response header / trailer -bin values decode back into raw bytes.
fn grpc_http2_binary_metadata_round_trips_through_helpers() {
    #[derive(Clone, Debug, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
    struct HelloReply {
        ok: bool,
    }
    // Scripted reply carries base64 binary metadata in both the initial
    // headers (x-stream-bin) and the trailers (grpc-status-details-bin,
    // x-trace-bin).
    let scripted = H2TestResponse::empty(200)
        .header("content-type", "application/grpc+json")
        .header("x-stream-bin", "AQI=")
        .body_frame(grpc_test_json_frame(&HelloReply { ok: true }))
        .trailer("grpc-status", "0")
        .trailer("grpc-status-details-bin", "ZGV0YWlscw==")
        .trailer("x-trace-bin", "AAE=");
    let (base, _, recorded_requests) =
        block_on(spawn_h2_tls_recording_server(vec![vec![scripted]])).unwrap();
    let client = crate::Client::builder()
        .base_url(&base)
        .unwrap()
        .danger_accept_invalid_certs(true)
        .build()
        .unwrap();
    let mut response = block_on(async {
        client
            .grpc("/helloworld.Greeter/SayHello")
            .metadata_bin("x-request", b"\x00\x01")?
            .message(&serde_json::json!({ "name": "Ada" }))?
            .await
    })
    .unwrap();
    // assert! on the boolean field instead of assert_eq!(…, true)
    // (clippy::bool_assert_comparison).
    assert!(block_on(response.message::<HelloReply>()).unwrap().ok);
    // "AQI=" decodes to the bytes [0x01, 0x02].
    assert_eq!(
        response.metadata_bin("x-stream").unwrap(),
        vec![bytes::Bytes::from_static(b"\x01\x02")]
    );
    let status = response.status().unwrap();
    // "ZGV0YWlscw==" is base64 for "details".
    assert_eq!(status.details_bin(), Some(&b"details"[..]));
    let trailers = response.trailers().unwrap().unwrap();
    // Raw trailer access still exposes the undecoded base64 text.
    assert_eq!(trailers.get("x-trace-bin"), Some("AAE="));
    let recorded = recorded_requests.lock().unwrap().pop().unwrap();
    // metadata_bin must have been sent base64-encoded under the -bin key.
    assert_eq!(recorded.header("x-request-bin"), Some("AAE="));
}
#[cfg(all(feature = "rustls", feature = "h2"))]
#[test]
// Client-streaming gRPC: a stream of JSON request messages is sent as
// consecutive length-prefixed frames (no content-length, since the total
// size is unknown), and the single summary reply decodes.
fn executes_grpc_client_streaming_json_request_over_http2() {
    #[derive(Clone, Debug, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
    struct Chunk {
        value: String,
    }
    #[derive(Clone, Debug, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
    struct Summary {
        count: usize,
    }
    let chunks = vec![
        Chunk {
            value: "one".to_owned(),
        },
        Chunk {
            value: "two".to_owned(),
        },
    ];
    let summary = Summary { count: 2 };
    let scripted = H2TestResponse::empty(200)
        .header("content-type", "application/grpc+json")
        .body_frame(grpc_test_json_frame(&summary))
        .trailer("grpc-status", "0");
    let (base, _, recorded_requests) =
        block_on(spawn_h2_tls_recording_server(vec![vec![scripted]])).unwrap();
    let client = crate::Client::builder()
        .base_url(&base)
        .unwrap()
        .danger_accept_invalid_certs(true)
        .build()
        .unwrap();
    let request_stream = stream::iter(chunks.clone().into_iter().map(Ok::<_, Error>));
    let mut response = block_on(async {
        client
            .grpc("/upload.Writer/Append")
            .messages(request_stream)?
            .await
    })
    .unwrap();
    assert_eq!(block_on(response.message::<Summary>()).unwrap(), summary);
    let recorded_requests = recorded_requests.lock().unwrap();
    assert_eq!(recorded_requests.len(), 1);
    let recorded = &recorded_requests[0];
    assert_eq!(
        recorded.header("content-type"),
        Some("application/grpc+json")
    );
    // Streamed request bodies must not carry a content-length.
    assert_eq!(recorded.header("content-length"), None);
    // Wire body is the concatenation of one frame per streamed message.
    let expected_body = [
        grpc_test_json_frame(&chunks[0]),
        grpc_test_json_frame(&chunks[1]),
    ]
    .concat();
    assert_eq!(recorded.body(), expected_body.as_slice());
}
#[cfg(all(feature = "rustls", feature = "h2"))]
#[test]
// Client-streaming gRPC with the Protobuf codec and raw byte payloads:
// each streamed Bytes message is framed individually and the framed reply
// bytes decode back out via message_bytes().
fn executes_grpc_client_streaming_protobuf_bytes_request_over_http2() {
    let chunks = vec![
        Bytes::from_static(&[0x08, 0x01]),
        Bytes::from_static(&[0x08, 0x02]),
    ];
    let summary = Bytes::from_static(&[0x10, 0x02]);
    let scripted = H2TestResponse::empty(200)
        .header("content-type", "application/grpc+proto")
        .body_frame(grpc_test_frame(&summary))
        .trailer("grpc-status", "0");
    let (base, _, recorded_requests) =
        block_on(spawn_h2_tls_recording_server(vec![vec![scripted]])).unwrap();
    let client = crate::Client::builder()
        .base_url(&base)
        .unwrap()
        .danger_accept_invalid_certs(true)
        .build()
        .unwrap();
    let request_stream = stream::iter(chunks.clone().into_iter().map(Ok::<_, Error>));
    let mut response = block_on(async {
        client
            .grpc("/upload.Writer/AppendBytes")
            .codec(crate::GrpcCodec::Protobuf)
            .messages_bytes(request_stream)?
            .await
    })
    .unwrap();
    assert_eq!(block_on(response.message_bytes()).unwrap(), summary);
    let recorded_requests = recorded_requests.lock().unwrap();
    assert_eq!(recorded_requests.len(), 1);
    let recorded = &recorded_requests[0];
    // Protobuf codec selects the +proto content type.
    assert_eq!(
        recorded.header("content-type"),
        Some("application/grpc+proto")
    );
    assert_eq!(recorded.header("content-length"), None);
    // Wire body is one length-prefixed frame per streamed chunk, in order.
    let expected_body = [grpc_test_frame(&chunks[0]), grpc_test_frame(&chunks[1])].concat();
    assert_eq!(recorded.body(), expected_body.as_slice());
}
#[cfg(all(feature = "rustls", feature = "h2"))]
#[test]
// Server-streaming gRPC: multiple JSON frames in one response are surfaced
// as a message stream that yields each decoded item in order.
fn reads_grpc_server_streaming_messages_over_http2() {
    #[derive(Clone, Debug, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
    struct ListRequest {
        topic: String,
    }
    #[derive(Clone, Debug, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
    struct Event {
        seq: usize,
    }
    let request = ListRequest {
        topic: "builds".to_owned(),
    };
    // Two message frames followed by an OK status trailer.
    let scripted = H2TestResponse::empty(200)
        .header("content-type", "application/grpc+json")
        .body_frame(grpc_test_json_frame(&Event { seq: 1 }))
        .body_frame(grpc_test_json_frame(&Event { seq: 2 }))
        .trailer("grpc-status", "0");
    let (base, _) = block_on(spawn_h2_tls_script_server(vec![vec![scripted]])).unwrap();
    let client = crate::Client::builder()
        .base_url(&base)
        .unwrap()
        .danger_accept_invalid_certs(true)
        .build()
        .unwrap();
    let mut response =
        block_on(async { client.grpc("/stream.Events/List").message(&request)?.await })
            .unwrap();
    // Drain the typed message stream and unwrap each item.
    let messages = block_on(response.messages::<Event>().unwrap().collect::<Vec<_>>())
        .into_iter()
        .map(|item| item.unwrap())
        .collect::<Vec<_>>();
    assert_eq!(messages, vec![Event { seq: 1 }, Event { seq: 2 }]);
}
#[cfg(all(feature = "rustls", feature = "h2"))]
#[test]
// With send_streaming(), server messages are delivered incrementally: the
// call is not complete and status/trailers are unavailable until the final
// frame and trailers have been consumed via next_message().
fn streams_grpc_server_messages_incrementally_over_http2() {
    #[derive(Clone, Debug, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
    struct ListRequest {
        topic: String,
    }
    #[derive(Clone, Debug, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
    struct Event {
        seq: usize,
    }
    // Server emits its frames with delays so arrival order is observable.
    let base = block_on(spawn_h2_tls_delayed_grpc_stream_server()).unwrap();
    let client = crate::Client::builder()
        .base_url(&base)
        .unwrap()
        .danger_accept_invalid_certs(true)
        .build()
        .unwrap();
    let request = ListRequest {
        topic: "builds".to_owned(),
    };
    let mut response = block_on(async {
        client
            .grpc("/stream.Events/List")
            .message(&request)?
            .send_streaming()
            .await
    })
    .unwrap();
    // Before any message arrives: not complete, no status, no trailers.
    assert!(!response.is_complete());
    assert_eq!(response.status().unwrap(), None);
    assert_eq!(response.trailers().unwrap(), None);
    assert_eq!(
        block_on(response.next_message::<Event>()).unwrap(),
        Some(Event { seq: 1 })
    );
    // Mid-stream: still in progress.
    assert!(!response.is_complete());
    assert_eq!(response.status().unwrap(), None);
    assert_eq!(
        block_on(response.next_message::<Event>()).unwrap(),
        Some(Event { seq: 2 })
    );
    // None marks end of stream; only now are status and trailers available.
    assert_eq!(block_on(response.next_message::<Event>()).unwrap(), None);
    assert!(response.is_complete());
    let status = response.status().unwrap().unwrap();
    assert_eq!(status.code(), 0);
    assert_eq!(status.message(), "");
    let trailers = response.trailers().unwrap().unwrap();
    assert_eq!(trailers.get("grpc-status"), Some("0"));
    assert_eq!(trailers.get("x-trace-id"), Some("delayed"));
}
#[cfg(all(feature = "rustls", feature = "h2"))]
#[test]
// A non-zero grpc-status in the trailers of a streamed response must surface
// as a Transport error from next_message() after the valid messages have
// been delivered, and the decoded status/trailers must remain accessible.
fn streaming_grpc_response_surfaces_trailing_error_status() {
    #[derive(Clone, Debug, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
    struct HelloRequest {
        name: String,
    }
    #[derive(Clone, Debug, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
    struct HelloReply {
        message: String,
    }
    let request = HelloRequest {
        name: "Ada".to_owned(),
    };
    // One good message, then trailers carrying PERMISSION_DENIED (7) with a
    // percent-encoded grpc-message.
    let scripted = H2TestResponse::empty(200)
        .header("content-type", "application/grpc+json")
        .body_frame(grpc_test_json_frame(&HelloReply {
            message: "hello Ada".to_owned(),
        }))
        .trailer("grpc-status", "7")
        .trailer("grpc-message", "permission%20denied");
    let (base, _) = block_on(spawn_h2_tls_script_server(vec![vec![scripted]])).unwrap();
    let client = crate::Client::builder()
        .base_url(&base)
        .unwrap()
        .danger_accept_invalid_certs(true)
        .build()
        .unwrap();
    let mut response = block_on(async {
        client
            .grpc("/helloworld.Greeter/SayHello")
            .message(&request)?
            .send_streaming()
            .await
    })
    .unwrap();
    // The message preceding the error trailer is still delivered.
    assert_eq!(
        block_on(response.next_message::<HelloReply>()).unwrap(),
        Some(HelloReply {
            message: "hello Ada".to_owned(),
        })
    );
    // The next poll hits the error trailer and fails.
    let err = block_on(response.next_message::<HelloReply>()).unwrap_err();
    assert_eq!(err.kind(), &ErrorKind::Transport);
    assert!(
        err.to_string()
            .contains("grpc request failed with status 7")
    );
    // Decoded status: the percent-encoded message is unescaped.
    let status = response.status().unwrap().unwrap();
    assert_eq!(status.code(), 7);
    assert_eq!(status.message(), "permission denied");
    // Raw trailers keep the wire form.
    let trailers = response.trailers().unwrap().unwrap();
    assert_eq!(trailers.get("grpc-status"), Some("7"));
    assert_eq!(trailers.get("grpc-message"), Some("permission%20denied"));
}
#[cfg(all(feature = "rustls", feature = "h2"))]
#[test]
// Bidirectional behavior with the streaming API: response messages can be
// read while the request stream is still producing (the second chunk is
// delayed 60 ms), i.e. sending and receiving are not serialized.
fn grpc_streaming_api_can_receive_response_before_request_stream_completes() {
    #[derive(Clone, Debug, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
    struct Chunk {
        value: String,
    }
    #[derive(Clone, Debug, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
    struct Event {
        seq: usize,
    }
    let (base, recorded_requests) = block_on(spawn_h2_tls_grpc_bidi_probe_server()).unwrap();
    let client = crate::Client::builder()
        .base_url(&base)
        .unwrap()
        .danger_accept_invalid_certs(true)
        .build()
        .unwrap();
    let chunks = vec![
        Chunk {
            value: "one".to_owned(),
        },
        Chunk {
            value: "two".to_owned(),
        },
    ];
    let request_chunks = chunks.clone();
    // Delay the second request chunk so the first response message must
    // arrive before the request stream is finished.
    let request_stream = async_stream::stream! {
        yield Ok::<_, Error>(request_chunks[0].clone());
        async_io::Timer::after(std::time::Duration::from_millis(60)).await;
        yield Ok::<_, Error>(request_chunks[1].clone());
    };
    let mut response = block_on(async {
        client
            .grpc("/chat.Service/Talk")
            .messages(request_stream)?
            .send_streaming()
            .await
    })
    .unwrap();
    assert_eq!(response.status().unwrap(), None);
    assert_eq!(
        block_on(response.next_message::<Event>()).unwrap(),
        Some(Event { seq: 1 })
    );
    assert_eq!(response.status().unwrap(), None);
    assert_eq!(
        block_on(response.next_message::<Event>()).unwrap(),
        Some(Event { seq: 2 })
    );
    assert_eq!(block_on(response.next_message::<Event>()).unwrap(), None);
    assert_eq!(response.status().unwrap().unwrap().code(), 0);
    let recorded_requests = recorded_requests.lock().unwrap();
    assert_eq!(recorded_requests.len(), 1);
    let recorded = &recorded_requests[0];
    // Despite interleaving, the server saw the full framed request body.
    let expected_body = [
        grpc_test_json_frame(&chunks[0]),
        grpc_test_json_frame(&chunks[1]),
    ]
    .concat();
    assert_eq!(recorded.header(":path"), Some("/chat.Service/Talk"));
    assert_eq!(recorded.body(), expected_body.as_slice());
}
#[cfg(feature = "h3")]
#[test]
// HTTP/3 variant of the bidirectional streaming probe: response messages
// arrive while the (delayed) request stream is still producing, over QUIC.
fn grpc_streaming_api_can_receive_response_before_request_stream_completes_over_http3() {
    #[derive(Clone, Debug, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
    struct Chunk {
        value: String,
    }
    #[derive(Clone, Debug, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
    struct Event {
        seq: usize,
    }
    let (base, recorded_requests) = block_on(spawn_h3_quic_grpc_bidi_probe_server()).unwrap();
    let client = crate::Client::builder()
        .base_url(&base)
        .unwrap()
        .danger_accept_invalid_certs(true)
        .build()
        .unwrap();
    let chunks = vec![
        Chunk {
            value: "one".to_owned(),
        },
        Chunk {
            value: "two".to_owned(),
        },
    ];
    let request_chunks = chunks.clone();
    // Second request chunk is held back 60 ms to force interleaving.
    let request_stream = async_stream::stream! {
        yield Ok::<_, Error>(request_chunks[0].clone());
        async_io::Timer::after(std::time::Duration::from_millis(60)).await;
        yield Ok::<_, Error>(request_chunks[1].clone());
    };
    let mut response = block_on(async {
        client
            .grpc("/chat.Service/Talk")
            .http3_only()
            .messages(request_stream)?
            .send_streaming()
            .await
    })
    .unwrap();
    assert_eq!(response.status().unwrap(), None);
    assert_eq!(
        block_on(response.next_message::<Event>()).unwrap(),
        Some(Event { seq: 1 })
    );
    assert_eq!(response.status().unwrap(), None);
    assert_eq!(
        block_on(response.next_message::<Event>()).unwrap(),
        Some(Event { seq: 2 })
    );
    assert_eq!(block_on(response.next_message::<Event>()).unwrap(), None);
    assert_eq!(response.status().unwrap().unwrap().code(), 0);
    let recorded_requests = recorded_requests.lock().unwrap();
    assert_eq!(recorded_requests.len(), 1);
    let recorded = &recorded_requests[0];
    // The server still received the complete framed request body.
    let expected_body = [
        grpc_test_json_frame(&chunks[0]),
        grpc_test_json_frame(&chunks[1]),
    ]
    .concat();
    assert_eq!(recorded.header(":path"), Some("/chat.Service/Talk"));
    assert_eq!(recorded.body(), expected_body.as_slice());
}
#[cfg(feature = "h3")]
#[test]
// HTTP/3 variant: a non-zero grpc-status in the trailers of a streamed
// response surfaces as a Transport error from next_message() after the valid
// message is delivered; decoded status and raw trailers stay accessible.
fn streaming_grpc_http3_response_surfaces_trailing_error_status() {
    #[derive(Clone, Debug, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
    struct HelloRequest {
        name: String,
    }
    #[derive(Clone, Debug, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
    struct HelloReply {
        message: String,
    }
    let request = HelloRequest {
        name: "Ada".to_owned(),
    };
    // One good message, then PERMISSION_DENIED (7) in the trailers.
    let response = H3TestResponse::text(200, "")
        .header("content-type", "application/grpc+json")
        .body_bytes(grpc_test_json_frame(&HelloReply {
            message: "hello Ada".to_owned(),
        }))
        .trailer("grpc-status", "7")
        .trailer("grpc-message", "permission%20denied");
    let (base, _) = block_on(spawn_h3_quic_recording_body_server(response)).unwrap();
    let client = crate::Client::builder()
        .base_url(&base)
        .unwrap()
        .danger_accept_invalid_certs(true)
        .build()
        .unwrap();
    let mut response = block_on(async {
        client
            .grpc("/helloworld.Greeter/SayHello")
            .http3_only()
            .message(&request)?
            .send_streaming()
            .await
    })
    .unwrap();
    assert_eq!(
        block_on(response.next_message::<HelloReply>()).unwrap(),
        Some(HelloReply {
            message: "hello Ada".to_owned(),
        })
    );
    // The next poll reaches the error trailer and fails.
    let err = block_on(response.next_message::<HelloReply>()).unwrap_err();
    assert_eq!(err.kind(), &ErrorKind::Transport);
    assert!(
        err.to_string()
            .contains("grpc request failed with status 7")
    );
    // Decoded status unescapes the percent-encoded message.
    let status = response.status().unwrap().unwrap();
    assert_eq!(status.code(), 7);
    assert_eq!(status.message(), "permission denied");
    let trailers = response.trailers().unwrap().unwrap();
    assert_eq!(trailers.get("grpc-status"), Some("7"));
    assert_eq!(trailers.get("grpc-message"), Some("permission%20denied"));
}
#[cfg(feature = "h3")]
#[test]
// Trailers-only gRPC error over HTTP/3: when grpc-status arrives in the
// initial headers with no body, the status is exposed via status()/trailers()
// and reading a message fails with a Transport error.
fn exposes_grpc_error_status_from_trailers_only_response_over_http3() {
    #[derive(Clone, Debug, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
    struct HelloRequest {
        name: String,
    }
    #[derive(Clone, Debug, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
    struct HelloReply {
        message: String,
    }
    // NOT_FOUND (5) carried in the headers of an empty response.
    let response = H3TestResponse::text(200, "")
        .header("content-type", "application/grpc+json")
        .header("grpc-status", "5")
        .header("grpc-message", "user%20not%20found");
    let (base, _) = block_on(spawn_h3_quic_recording_body_server(response)).unwrap();
    let client = crate::Client::builder()
        .base_url(&base)
        .unwrap()
        .danger_accept_invalid_certs(true)
        .build()
        .unwrap();
    let mut response = block_on(async {
        client
            .grpc("/helloworld.Greeter/GetUser")
            .http3_only()
            .message(&HelloRequest {
                name: "missing".to_owned(),
            })?
            .await
    })
    .unwrap();
    // Decoded status unescapes the percent-encoded message.
    let status = response.status().unwrap();
    assert_eq!(status.code(), 5);
    assert_eq!(status.message(), "user not found");
    // The header-carried status is also surfaced through trailers().
    let trailers = response.trailers().unwrap().unwrap();
    assert_eq!(trailers.get("grpc-status"), Some("5"));
    assert_eq!(trailers.get("grpc-message"), Some("user%20not%20found"));
    let err = block_on(response.message::<HelloReply>()).unwrap_err();
    assert_eq!(err.kind(), &crate::ErrorKind::Transport);
}
#[cfg(feature = "h3")]
#[test]
// Duplex gRPC over HTTP/3: open_duplex() returns a call handle on which
// send_message() and next_message() can be interleaved; finish_sending()
// closes the request side, after which the remaining responses, status, and
// binary trailer metadata are read.
fn grpc_duplex_call_supports_interleaved_send_and_receive_over_http3() {
    #[derive(Clone, Debug, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
    struct Chunk {
        value: String,
    }
    #[derive(Clone, Debug, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
    struct Event {
        seq: usize,
    }
    let (base, recorded_requests) = block_on(spawn_h3_quic_grpc_bidi_probe_server()).unwrap();
    let client = crate::Client::builder()
        .base_url(&base)
        .unwrap()
        .danger_accept_invalid_certs(true)
        .build()
        .unwrap();
    let mut call = block_on(async {
        client
            .grpc("/chat.Service/Talk")
            .http3_only()
            .metadata("x-chat-id", "room-1")?
            .open_duplex()
            .await
    })
    .unwrap();
    // Send one request message, then read a response before sending more.
    block_on(call.send_message(&Chunk {
        value: "one".to_owned(),
    }))
    .unwrap();
    assert_eq!(
        block_on(call.next_message::<Event>()).unwrap(),
        Some(Event { seq: 1 })
    );
    // Binary response metadata is decoded from its base64 wire form.
    assert_eq!(
        call.metadata_bin("x-stream").unwrap(),
        vec![bytes::Bytes::from_static(b"\x01\x02")]
    );
    assert_eq!(call.status().unwrap(), None);
    block_on(call.send_message(&Chunk {
        value: "two".to_owned(),
    }))
    .unwrap();
    // Close the request side; the server can now finish its stream.
    call.finish_sending();
    assert_eq!(
        block_on(call.next_message::<Event>()).unwrap(),
        Some(Event { seq: 2 })
    );
    assert_eq!(block_on(call.next_message::<Event>()).unwrap(), None);
    assert!(call.is_complete());
    assert_eq!(call.status().unwrap().unwrap().code(), 0);
    assert_eq!(
        call.trailer_metadata_bin("x-trace").unwrap(),
        vec![bytes::Bytes::from_static(b"\x00\x01")]
    );
    let recorded_requests = recorded_requests.lock().unwrap();
    assert_eq!(recorded_requests.len(), 1);
    let recorded = &recorded_requests[0];
    assert_eq!(recorded.header(":path"), Some("/chat.Service/Talk"));
    assert_eq!(recorded.header("x-chat-id"), Some("room-1"));
    // Server received both framed messages despite the interleaving.
    let expected_body = [
        grpc_test_json_frame(&Chunk {
            value: "one".to_owned(),
        }),
        grpc_test_json_frame(&Chunk {
            value: "two".to_owned(),
        }),
    ]
    .concat();
    assert_eq!(recorded.body(), expected_body.as_slice());
}
#[cfg(all(feature = "rustls", feature = "h2"))]
#[test]
// HTTP/2 variant of the duplex call test: send_message()/next_message() can
// be interleaved on an open_duplex() handle; finish_sending() closes the
// request side before draining the remaining responses, status, and binary
// trailer metadata.
fn grpc_duplex_call_supports_interleaved_send_and_receive_over_http2() {
    #[derive(Clone, Debug, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
    struct Chunk {
        value: String,
    }
    #[derive(Clone, Debug, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
    struct Event {
        seq: usize,
    }
    let (base, recorded_requests) = block_on(spawn_h2_tls_grpc_bidi_probe_server()).unwrap();
    let client = crate::Client::builder()
        .base_url(&base)
        .unwrap()
        .danger_accept_invalid_certs(true)
        .build()
        .unwrap();
    let mut call = block_on(async {
        client
            .grpc("/chat.Service/Talk")
            .metadata("x-chat-id", "room-1")?
            .open_duplex()
            .await
    })
    .unwrap();
    // Send one request message, then read a response before sending more.
    block_on(call.send_message(&Chunk {
        value: "one".to_owned(),
    }))
    .unwrap();
    assert_eq!(
        block_on(call.next_message::<Event>()).unwrap(),
        Some(Event { seq: 1 })
    );
    // Binary response metadata is decoded from its base64 wire form.
    assert_eq!(
        call.metadata_bin("x-stream").unwrap(),
        vec![bytes::Bytes::from_static(b"\x01\x02")]
    );
    assert_eq!(call.status().unwrap(), None);
    block_on(call.send_message(&Chunk {
        value: "two".to_owned(),
    }))
    .unwrap();
    // Close the request side; the server can now finish its stream.
    call.finish_sending();
    assert_eq!(
        block_on(call.next_message::<Event>()).unwrap(),
        Some(Event { seq: 2 })
    );
    assert_eq!(block_on(call.next_message::<Event>()).unwrap(), None);
    assert!(call.is_complete());
    assert_eq!(call.status().unwrap().unwrap().code(), 0);
    assert_eq!(
        call.trailer_metadata_bin("x-trace").unwrap(),
        vec![bytes::Bytes::from_static(b"\x00\x01")]
    );
    let recorded_requests = recorded_requests.lock().unwrap();
    assert_eq!(recorded_requests.len(), 1);
    let recorded = &recorded_requests[0];
    assert_eq!(recorded.header(":path"), Some("/chat.Service/Talk"));
    assert_eq!(recorded.header("x-chat-id"), Some("room-1"));
    // Server received both framed messages despite the interleaving.
    let expected_body = [
        grpc_test_json_frame(&Chunk {
            value: "one".to_owned(),
        }),
        grpc_test_json_frame(&Chunk {
            value: "two".to_owned(),
        }),
    ]
    .concat();
    assert_eq!(recorded.body(), expected_body.as_slice());
}
#[cfg(all(feature = "rustls", feature = "h2"))]
#[test]
// A unary gRPC call with the Protobuf codec sends the raw request bytes
// wrapped in a length-prefixed frame under application/grpc+proto, and
// decodes the framed reply bytes via message_bytes().
fn executes_grpc_protobuf_bytes_request_over_http2() {
    let request = Bytes::from_static(&[0x08, 0x96, 0x01]);
    let reply = Bytes::from_static(&[0x10, 0x01]);
    let scripted = H2TestResponse::empty(200)
        .header("content-type", "application/grpc+proto")
        .body_frame(grpc_test_frame(&reply))
        .trailer("grpc-status", "0");
    let (base, _, recorded_requests) =
        block_on(spawn_h2_tls_recording_server(vec![vec![scripted]])).unwrap();
    let client = crate::Client::builder()
        .base_url(&base)
        .unwrap()
        .danger_accept_invalid_certs(true)
        .build()
        .unwrap();
    let mut response = block_on(async {
        client
            .grpc("/protobuf.Greeter/SayHello")
            .codec(crate::GrpcCodec::Protobuf)
            .message_bytes(request.clone())?
            .await
    })
    .unwrap();
    assert_eq!(block_on(response.message_bytes()).unwrap(), reply);
    let recorded_requests = recorded_requests.lock().unwrap();
    // Guard the index below with an explicit count check, matching the other
    // recording tests and giving a clearer failure than an index panic.
    assert_eq!(recorded_requests.len(), 1);
    let recorded = &recorded_requests[0];
    assert_eq!(
        recorded.header("content-type"),
        Some("application/grpc+proto")
    );
    // Wire body is the request bytes inside one gRPC length-prefixed frame.
    assert_eq!(recorded.body(), grpc_test_frame(&request).as_slice());
}
#[cfg(all(feature = "rustls", feature = "h2"))]
#[test]
// A response frame marked grpc-encoding: gzip is transparently decompressed
// before the JSON message is decoded.
fn decodes_gzip_compressed_grpc_json_response() {
    #[derive(Clone, Debug, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
    struct HelloRequest {
        name: String,
    }
    #[derive(Clone, Debug, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
    struct HelloReply {
        message: String,
    }
    let expected_reply = HelloReply {
        message: "hello zipped".to_owned(),
    };
    // Server answers with a gzip-compressed JSON frame plus an OK trailer.
    let scripted = H2TestResponse::empty(200)
        .header("content-type", "application/grpc+json")
        .header("grpc-encoding", "gzip")
        .body_frame(grpc_test_gzip_json_frame(&expected_reply))
        .trailer("grpc-status", "0");
    let (server_url, _) = block_on(spawn_h2_tls_script_server(vec![vec![scripted]])).unwrap();
    let client = crate::Client::builder()
        .base_url(&server_url)
        .unwrap()
        .danger_accept_invalid_certs(true)
        .build()
        .unwrap();
    let outgoing = HelloRequest {
        name: "Ada".to_owned(),
    };
    let mut response = block_on(async {
        client
            .grpc("/helloworld.Greeter/SayHello")
            .message(&outgoing)?
            .await
    })
    .unwrap();
    let decoded = block_on(response.message::<HelloReply>()).unwrap();
    assert_eq!(decoded, expected_reply);
}
#[cfg(all(feature = "rustls", feature = "h2"))]
#[test]
// Requesting gzip compression sets grpc-encoding/grpc-accept-encoding on the
// outgoing call and gzips the request frame; the recorded body decodes back
// to the original JSON message.
fn encodes_gzip_compressed_grpc_json_request_body() {
    #[derive(Clone, Debug, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
    struct HelloRequest {
        name: String,
    }
    #[derive(Clone, Debug, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
    struct HelloReply {
        ok: bool,
    }
    let request = HelloRequest {
        name: "Ada".to_owned(),
    };
    let scripted = H2TestResponse::empty(200)
        .header("content-type", "application/grpc+json")
        .body_frame(grpc_test_json_frame(&HelloReply { ok: true }))
        .trailer("grpc-status", "0");
    let (base, _, recorded_requests) =
        block_on(spawn_h2_tls_recording_server(vec![vec![scripted]])).unwrap();
    let client = crate::Client::builder()
        .base_url(&base)
        .unwrap()
        .danger_accept_invalid_certs(true)
        .build()
        .unwrap();
    let mut response = block_on(async {
        client
            .grpc("/helloworld.Greeter/SayHello")
            .compression("gzip")
            .message(&request)?
            .await
    })
    .unwrap();
    assert_eq!(
        block_on(response.message::<HelloReply>()).unwrap(),
        HelloReply { ok: true }
    );
    let recorded_requests = recorded_requests.lock().unwrap();
    // Guard the index below with an explicit count check, matching the other
    // recording tests and giving a clearer failure than an index panic.
    assert_eq!(recorded_requests.len(), 1);
    let recorded = &recorded_requests[0];
    assert_eq!(recorded.header("grpc-encoding"), Some("gzip"));
    assert_eq!(recorded.header("grpc-accept-encoding"), Some("gzip"));
    // Un-gzipping the recorded frame must yield the original JSON message.
    let decoded = decode_grpc_test_frames(recorded.body(), Some("gzip")).unwrap();
    assert_eq!(decoded.len(), 1);
    let decoded_request: HelloRequest = serde_json::from_slice(&decoded[0]).unwrap();
    assert_eq!(decoded_request, request);
}
#[cfg(all(feature = "rustls", feature = "h2"))]
#[test]
// Trailers-only gRPC error over HTTP/2: when grpc-status arrives in the
// initial headers with no body frames, the status is exposed via
// status()/trailers() and reading a message fails with a Transport error.
fn exposes_grpc_error_status_from_trailers_only_response() {
    #[derive(Clone, Debug, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
    struct HelloRequest {
        name: String,
    }
    #[derive(Clone, Debug, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
    struct HelloReply {
        message: String,
    }
    // NOT_FOUND (5) carried directly in the response headers.
    let scripted = H2TestResponse::empty(200)
        .header("content-type", "application/grpc+json")
        .header("grpc-status", "5")
        .header("grpc-message", "user%20not%20found");
    let (base, _) = block_on(spawn_h2_tls_script_server(vec![vec![scripted]])).unwrap();
    let client = crate::Client::builder()
        .base_url(&base)
        .unwrap()
        .danger_accept_invalid_certs(true)
        .build()
        .unwrap();
    let mut response = block_on(async {
        client
            .grpc("/helloworld.Greeter/GetUser")
            .message(&HelloRequest {
                name: "missing".to_owned(),
            })?
            .await
    })
    .unwrap();
    // Decoded status unescapes the percent-encoded grpc-message.
    let status = response.status().unwrap();
    assert_eq!(status.code(), 5);
    assert_eq!(status.message(), "user not found");
    // The header-carried status is also surfaced through trailers().
    let trailers = response.trailers().unwrap().unwrap();
    assert_eq!(trailers.get("grpc-status"), Some("5"));
    assert_eq!(trailers.get("grpc-message"), Some("user%20not%20found"));
    let err = block_on(response.message::<HelloReply>()).unwrap_err();
    assert_eq!(err.kind(), &crate::ErrorKind::Transport);
}
#[cfg(feature = "rustls")]
#[test]
// A plain HTTPS GET succeeds against a self-signed server once certificate
// validation is disabled on the client.
fn performs_https_request_with_insecure_test_config() {
    let canned = "HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello";
    let server_url = block_on(spawn_tls_test_server(canned)).unwrap();
    let builder = crate::Client::builder().base_url(&server_url).unwrap();
    let client = builder.danger_accept_invalid_certs(true).build().unwrap();
    let resp = run(client.get("/secure")).unwrap();
    let body = block_on(resp.text()).unwrap();
    assert_eq!(body, "hello");
}
#[cfg(feature = "rustls")]
#[test]
// A per-request danger_accept_invalid_certs(true) overrides the client's
// default (strict) TLS validation for that request only.
fn request_can_override_client_tls_validation() {
    let canned = "HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello";
    let server_url = block_on(spawn_tls_test_server(canned)).unwrap();
    let client = crate::Client::builder()
        .base_url(&server_url)
        .unwrap()
        .build()
        .unwrap();
    let relaxed = client.get("/secure").danger_accept_invalid_certs(true);
    let body = block_on(run(relaxed).unwrap().text()).unwrap();
    assert_eq!(body, "hello");
}
#[cfg(feature = "rustls")]
#[test]
// A per-request TlsConfig::default() restores strict validation even when the
// client was built to accept invalid certificates, so the self-signed test
// server is rejected with a Transport error.
fn request_can_override_client_tls_config_back_to_strict() {
    let canned = "HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello";
    let server_url = block_on(spawn_tls_test_server(canned)).unwrap();
    let client = crate::Client::builder()
        .base_url(&server_url)
        .unwrap()
        .danger_accept_invalid_certs(true)
        .build()
        .unwrap();
    let strict = crate::TlsConfig::default();
    let outcome = run(client.get("/secure").tls_config(strict));
    assert!(matches!(outcome, Err(err) if err.kind() == &crate::ErrorKind::Transport));
}
#[cfg(feature = "rustls")]
#[test]
// Trusting the server's self-signed certificate via a PEM root store lets a
// strictly-validated request succeed.
fn pem_root_store_allows_self_signed_https_request() {
    let canned = "HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello";
    let (server_url, cert_pem) = block_on(spawn_tls_test_server_with_cert(canned)).unwrap();
    let client = crate::Client::builder()
        .base_url(&server_url)
        .unwrap()
        .root_store(crate::RootStore::pem(cert_pem))
        .build()
        .unwrap();
    let resp = run(client.get("/secure")).unwrap();
    assert_eq!(block_on(resp.text()).unwrap(), "hello");
}
#[cfg(all(feature = "rustls", feature = "native-tls"))]
#[test]
// The PEM root store is honored by the native-tls backend as well.
fn pem_root_store_allows_self_signed_https_request_with_native_tls() {
    let canned = "HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello";
    let (server_url, cert_pem) = block_on(spawn_tls_test_server_with_cert(canned)).unwrap();
    let client = crate::Client::builder()
        .base_url(&server_url)
        .unwrap()
        .tls_backend(crate::TlsBackend::Native)
        .root_store(crate::RootStore::pem(cert_pem))
        .build()
        .unwrap();
    let resp = run(client.get("/secure")).unwrap();
    assert_eq!(block_on(resp.text()).unwrap(), "hello");
}
#[cfg(all(feature = "rustls", feature = "btls-backend"))]
#[test]
// The BoringSSL backend also completes an insecure HTTPS GET against a
// self-signed server.
fn performs_https_request_with_insecure_test_config_with_boring() {
    let canned = "HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello";
    let server_url = block_on(spawn_tls_test_server(canned)).unwrap();
    let client = crate::Client::builder()
        .base_url(&server_url)
        .unwrap()
        .tls_backend(crate::TlsBackend::Boring)
        .danger_accept_invalid_certs(true)
        .build()
        .unwrap();
    let resp = run(client.get("/secure")).unwrap();
    assert_eq!(block_on(resp.text()).unwrap(), "hello");
}
#[cfg(all(feature = "rustls", feature = "btls-backend"))]
#[test]
// The PEM root store is honored by the BoringSSL backend as well.
fn pem_root_store_allows_self_signed_https_request_with_boring() {
    let canned = "HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello";
    let (server_url, cert_pem) = block_on(spawn_tls_test_server_with_cert(canned)).unwrap();
    let client = crate::Client::builder()
        .base_url(&server_url)
        .unwrap()
        .tls_backend(crate::TlsBackend::Boring)
        .root_store(crate::RootStore::pem(cert_pem))
        .build()
        .unwrap();
    let resp = run(client.get("/secure")).unwrap();
    assert_eq!(block_on(resp.text()).unwrap(), "hello");
}
#[cfg(all(feature = "rustls", feature = "btls-backend"))]
#[test]
// With the BoringSSL backend, the handshake succeeds against a server that
// advertises — and expects the client to negotiate — the http/1.1 ALPN id.
fn boring_negotiates_http11_alpn_when_server_advertises_it() {
    let canned = "HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok";
    let (server_url, cert_pem) = block_on(super::spawn_tls_test_server_with_cert_and_alpn(
        canned,
        vec![b"http/1.1".to_vec()],
        Some(b"http/1.1".to_vec()),
    ))
    .unwrap();
    let client = crate::Client::builder()
        .base_url(&server_url)
        .unwrap()
        .tls_backend(crate::TlsBackend::Boring)
        .root_store(crate::RootStore::pem(cert_pem))
        .build()
        .unwrap();
    let resp = run(client.get("/secure")).unwrap();
    assert_eq!(block_on(resp.text()).unwrap(), "ok");
}
#[cfg(feature = "rustls")]
#[test]
// Pinning the SHA-256 fingerprint of the certificate the server actually
// serves allows the request to proceed.
fn pinned_certificate_allows_matching_self_signed_https_request() {
    let canned = "HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello";
    let (server_url, cert_pem) = block_on(spawn_tls_test_server_with_cert(canned)).unwrap();
    // Recover the DER certificate from the PEM and hex-encode its SHA-256.
    let mut pem_reader = std::io::BufReader::new(cert_pem.as_bytes());
    let der = rustls_pemfile::certs(&mut pem_reader).unwrap().remove(0);
    let digest = sha2::Sha256::digest(&der);
    let mut pin = String::new();
    for byte in digest.iter() {
        pin.push_str(&format!("{byte:02x}"));
    }
    let client = crate::Client::builder()
        .base_url(&server_url)
        .unwrap()
        .root_store(crate::RootStore::pem(cert_pem))
        .pin_certificate("localhost", pin)
        .unwrap()
        .build()
        .unwrap();
    let resp = run(client.get("/secure")).unwrap();
    assert_eq!(block_on(resp.text()).unwrap(), "hello");
}
// Certificate pinning, negative case: even though the root store trusts the
// server's cert, a pin that does not match its SHA-256 must fail the request
// with a Transport error.
#[cfg(feature = "rustls")]
#[test]
fn pinned_certificate_rejects_mismatched_self_signed_https_request() {
let (base, cert_pem) = block_on(spawn_tls_test_server_with_cert(
"HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello",
))
.unwrap();
let client = crate::Client::builder()
.base_url(&base)
.unwrap()
.root_store(crate::RootStore::pem(cert_pem))
.pin_certificate(
"localhost",
// All-zero digest: guaranteed not to match any real certificate.
"0000000000000000000000000000000000000000000000000000000000000000",
)
.unwrap()
.build()
.unwrap();
let result = run(client.get("/secure"));
assert!(matches!(result, Err(err) if err.kind() == &crate::ErrorKind::Transport));
}
// Two sequential requests over HTTPS keep-alive must reuse one TLS
// connection; the server-side counter verifies only one connect happened.
#[cfg(feature = "rustls")]
#[test]
fn reuses_https_keep_alive_connection_for_same_client() {
let (base, connection_count) = block_on(spawn_tls_keep_alive_server()).unwrap();
let client = crate::Client::builder()
.base_url(&base)
.unwrap()
.danger_accept_invalid_certs(true)
.build()
.unwrap();
let first = run(client.get("/one")).unwrap();
assert_eq!(block_on(first.text()).unwrap(), "ok");
let second = run(client.get("/two")).unwrap();
assert_eq!(block_on(second.text()).unwrap(), "ok");
// Exactly one accepted connection proves the pool reused it.
assert_eq!(connection_count.load(Ordering::SeqCst), 1);
}
// Smoke test: an HTTPS request completes when the client is configured with
// the native-tls backend (server cert is self-signed, so verification is
// disabled for this test).
#[cfg(all(feature = "rustls", feature = "native-tls"))]
#[test]
fn performs_https_request_with_native_tls_backend() {
let base = block_on(spawn_tls_test_server(
"HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello",
))
.unwrap();
let client = crate::Client::builder()
.base_url(&base)
.unwrap()
.tls_backend(crate::TlsBackend::Native)
.danger_accept_invalid_certs(true)
.build()
.unwrap();
let response = run(client.get("/secure")).unwrap();
assert_eq!(block_on(response.text()).unwrap(), "hello");
}
// A per-request tls_backend(Native) must override the client's default
// backend and still complete the HTTPS exchange.
#[cfg(all(feature = "rustls", feature = "native-tls"))]
#[test]
fn request_can_override_client_tls_backend() {
let base = block_on(spawn_tls_test_server(
"HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello",
))
.unwrap();
let client = crate::Client::builder()
.base_url(&base)
.unwrap()
.danger_accept_invalid_certs(true)
.build()
.unwrap();
// Backend selected on the request, not the client builder.
let response = run(client.get("/secure").tls_backend(crate::TlsBackend::Native)).unwrap();
assert_eq!(block_on(response.text()).unwrap(), "hello");
}
// Pool-key isolation: requests made with different TLS backends must not
// share a pooled connection — the counting server expects two connects.
#[cfg(all(feature = "rustls", feature = "native-tls"))]
#[test]
fn different_tls_backends_do_not_share_https_pool_entry() {
// Server accepts 2 connections, serving 1 request each — TODO confirm
// the (2, 1) argument meaning against spawn_tls_connection_count_server.
let (base, connection_count) = block_on(spawn_tls_connection_count_server(2, 1)).unwrap();
let client = crate::Client::builder()
.base_url(&base)
.unwrap()
.danger_accept_invalid_certs(true)
.build()
.unwrap();
let first = run(client.get("/one").tls_backend(crate::TlsBackend::Rustls))
.unwrap()
.text();
let second = run(client.get("/two").tls_backend(crate::TlsBackend::Native))
.unwrap()
.text();
assert_eq!(block_on(first).unwrap(), "ok");
assert_eq!(block_on(second).unwrap(), "ok");
// Two connections: one per backend, proving no pool sharing.
assert_eq!(connection_count.load(Ordering::SeqCst), 2);
}
// Pool-key isolation: a request with ALPN disabled must not reuse the
// connection established by a default-ALPN request.
#[cfg(feature = "rustls")]
#[test]
fn different_alpn_configs_do_not_share_https_pool_entry() {
let (base, connection_count) = block_on(spawn_tls_connection_count_server(2, 1)).unwrap();
let client = crate::Client::builder()
.base_url(&base)
.unwrap()
.danger_accept_invalid_certs(true)
.build()
.unwrap();
let first = run(client.get("/one")).unwrap().text();
let second = run(client.get("/two").disable_alpn()).unwrap().text();
assert_eq!(block_on(first).unwrap(), "ok");
assert_eq!(block_on(second).unwrap(), "ok");
// Two distinct ALPN configs -> two distinct connections.
assert_eq!(connection_count.load(Ordering::SeqCst), 2);
}
// Requesting a custom ALPN list of ["h2"] against the HTTP/1-only TLS test
// transport must fail with a Transport error rather than silently proceed.
#[cfg(feature = "rustls")]
#[test]
fn rejects_unsupported_custom_alpn_on_http1_tls_transport() {
let base = block_on(spawn_tls_test_server(
"HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello",
))
.unwrap();
let result = run(crate::Client::builder()
.base_url(&base)
.unwrap()
.danger_accept_invalid_certs(true)
.build()
.unwrap()
.get("/secure")
.alpn_protocols(["h2"]));
assert!(matches!(result, Err(err) if err.kind() == &crate::ErrorKind::Transport));
}
// Default client behavior: certificate verification is enabled, so a
// self-signed test server must be rejected with a Transport error.
#[cfg(feature = "rustls")]
#[test]
fn https_rejects_invalid_certificate_by_default() {
let base = block_on(spawn_tls_test_server(
"HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello",
))
.unwrap();
// Note: no danger_accept_invalid_certs here — that is the point.
let builder = crate::Client::builder().base_url(&base).unwrap();
let client = builder.build().unwrap();
let outcome = run(client.get("/secure"));
assert!(matches!(outcome, Err(err) if err.kind() == &crate::ErrorKind::Transport));
}
// Query parameters must be percent-encoded into the request target and
// multiple cookie() calls must fold into one "cookie:" header joined by "; ".
#[test]
fn serializes_query_and_cookies() {
let base = block_on(spawn_assert_server(
|request| {
// Space in "hello world" is encoded as %20 in the query string.
assert!(request.starts_with("GET /users?q=hello%20world HTTP/1.1\r\n"));
assert!(request.contains("\r\ncookie: session=a; theme=dark\r\n"));
},
"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok",
))
.unwrap();
let response = run(get(format!("{base}/users"))
.query([("q", "hello world")])
.cookie("session", "a")
.cookie("theme", "dark"))
.unwrap();
assert_eq!(block_on(response.text()).unwrap(), "ok");
}
// A Transfer-Encoding: chunked body ("hello" + " world") must be
// reassembled into the full payload by text().
#[test]
fn reads_chunked_http_response() {
let raw =
"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n5\r\nhello\r\n6\r\n world\r\n0\r\n\r\n";
let base = block_on(spawn_test_server(raw)).unwrap();
let url = format!("{base}/chunked");
let response = run(get(url)).unwrap();
let body = block_on(response.text()).unwrap();
assert_eq!(body, "hello world");
}
// Streaming consumption of a chunked response must yield exactly one Bytes
// item per wire chunk, in order.
#[test]
fn reads_chunked_http_response_streaming() {
let raw =
"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n5\r\nhello\r\n6\r\n world\r\n0\r\n\r\n";
let base = block_on(spawn_test_server(raw)).unwrap();
let response = run(get(format!("{base}/chunked"))).unwrap();
let collected: Vec<_> = block_on(response.bytes_stream().collect());
let payloads: Vec<&Bytes> = collected
.iter()
.map(|chunk| chunk.as_ref().unwrap())
.collect();
assert_eq!(
payloads,
[&Bytes::from_static(b"hello"), &Bytes::from_static(b" world")]
);
}
// Trailer headers after the terminating 0-length chunk must be surfaced by
// bytes_and_trailers() alongside the reassembled body.
#[test]
fn reads_chunked_trailers() {
let base = block_on(spawn_test_server(
"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n5\r\nhello\r\n0\r\nx-grpc-status: 0\r\nx-trace-id: abc\r\n\r\n",
))
.unwrap();
let response = run(get(format!("{base}/chunked"))).unwrap();
let (body, trailers) = block_on(response.bytes_and_trailers()).unwrap();
assert_eq!(body, Bytes::from_static(b"hello"));
let trailers = trailers.unwrap();
assert_eq!(trailers.get("x-grpc-status"), Some("0"));
assert_eq!(trailers.get("x-trace-id"), Some("abc"));
}
// A 302 with an absolute Location pointing at a second server must be
// followed automatically, yielding the target's body.
#[test]
fn follows_redirect_responses() {
let target = block_on(spawn_test_server(
"HTTP/1.1 200 OK\r\nContent-Length: 7\r\n\r\ntarget!",
))
.unwrap();
let redirect_response =
format!("HTTP/1.1 302 Found\r\nLocation: {target}/done\r\nContent-Length: 0\r\n\r\n");
// Leaked to produce a &'static str — presumably required by
// spawn_test_server's signature (TODO confirm); fine in a test.
let redirect = block_on(spawn_test_server(Box::leak(
redirect_response.into_boxed_str(),
)))
.unwrap();
let response = run(get(format!("{redirect}/start"))).unwrap();
assert_eq!(block_on(response.text()).unwrap(), "target!");
}
// A relative Location ("/done") must be resolved against the original
// origin, and the final Response.url() must reflect the resolved target.
#[test]
fn follows_relative_redirect_responses() {
let redirect = block_on(spawn_sequence_server(vec![
"HTTP/1.1 302 Found\r\nLocation: /done\r\nContent-Length: 0\r\n\r\n",
"HTTP/1.1 200 OK\r\nContent-Length: 8\r\n\r\nrelative",
]))
.unwrap();
let response = run(get(format!("{redirect}/start"))).unwrap();
// Capture before text() consumes the response.
let final_url = response.url().as_str().to_owned();
assert_eq!(block_on(response.text()).unwrap(), "relative");
assert_eq!(final_url, format!("{redirect}/done"));
}
// With RedirectPolicy::None the 302 must be handed back to the caller
// untouched, Location header intact.
#[test]
fn redirect_policy_none_returns_redirect_response() {
let base = block_on(spawn_test_server(
"HTTP/1.1 302 Found\r\nLocation: /done\r\nContent-Length: 0\r\n\r\n",
))
.unwrap();
let request = get(format!("{base}/start")).redirect(crate::RedirectPolicy::None);
let response = run(request).unwrap();
assert_eq!(response.status(), StatusCode::FOUND);
assert_eq!(response.headers().get("location"), Some("/done"));
}
// A server that always redirects to itself must trip RedirectPolicy::Limit(1)
// and surface a Transport error.
#[test]
fn redirect_limit_exceeded_returns_error() {
let base = block_on(spawn_test_server(
"HTTP/1.1 302 Found\r\nLocation: /loop\r\nContent-Length: 0\r\n\r\n",
))
.unwrap();
let request = get(format!("{base}/loop")).redirect(crate::RedirectPolicy::Limit(1));
let result = run(request);
assert!(matches!(result, Err(err) if err.kind() == &crate::ErrorKind::Transport));
}
// Security: a redirect to a different origin must drop the client's default
// Authorization header so credentials never leak to the target host.
#[test]
fn cross_origin_redirect_strips_authorization_header() {
let target = block_on(spawn_assert_server(
|request| {
assert!(!request.contains("\r\nauthorization: Bearer token-123\r\n"));
},
"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok",
))
.unwrap();
let redirect_response =
format!("HTTP/1.1 302 Found\r\nLocation: {target}/next\r\nContent-Length: 0\r\n\r\n");
// Leaked for a &'static response string; acceptable in a test.
let redirect = block_on(spawn_test_server(Box::leak(
redirect_response.into_boxed_str(),
)))
.unwrap();
let client = crate::Client::builder()
.base_url(&redirect)
.unwrap()
.bearer_auth("token-123")
.unwrap()
.build()
.unwrap();
let response = run(client.get("/start")).unwrap();
assert_eq!(block_on(response.text()).unwrap(), "ok");
}
// Security: cookies configured on the client builder must not be forwarded
// across a cross-origin redirect.
#[test]
fn cross_origin_redirect_strips_client_builder_cookie() {
let target = block_on(spawn_assert_server(
|request| {
// Scan every header line: no cookie header of any casing allowed.
assert!(
!request
.lines()
.any(|line| line.to_lowercase().starts_with("cookie:")),
"cookie header must be absent at cross-origin target; got:\n{request}"
);
},
"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok",
))
.unwrap();
let redirect_response =
format!("HTTP/1.1 302 Found\r\nLocation: {target}/next\r\nContent-Length: 0\r\n\r\n");
let redirect = block_on(spawn_test_server(Box::leak(
redirect_response.into_boxed_str(),
)))
.unwrap();
let client = crate::Client::builder()
.base_url(&redirect)
.unwrap()
.cookie("session", "very-secret")
.build()
.unwrap();
let response = run(client.get("/start")).unwrap();
assert_eq!(block_on(response.text()).unwrap(), "ok");
}
// RFC-conformant 307 handling: the redirected request must keep the POST
// method and replay the original body unchanged.
#[test]
fn temporary_redirect_preserves_method_and_body() {
let target = block_on(spawn_assert_server(
|request| {
assert!(request.starts_with("POST /final HTTP/1.1\r\n"));
assert!(request.ends_with("\r\n\r\npayload"));
},
"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok",
))
.unwrap();
let redirect_response = format!(
"HTTP/1.1 307 Temporary Redirect\r\nLocation: {target}/final\r\nContent-Length: 0\r\n\r\n"
);
let redirect = block_on(spawn_test_server(Box::leak(
redirect_response.into_boxed_str(),
)))
.unwrap();
let response = run(crate::post(format!("{redirect}/start"))
.text("payload")
.unwrap())
.unwrap();
assert_eq!(block_on(response.text()).unwrap(), "ok");
}
// A 307 cannot be followed when the body is a one-shot stream: the client
// must fail with a Transport error mentioning a replayable body rather than
// resend an empty/partial body.
#[test]
fn temporary_redirect_rejects_non_replayable_body() {
let redirect_response =
"HTTP/1.1 307 Temporary Redirect\r\nLocation: /final\r\nContent-Length: 0\r\n\r\n";
let redirect = block_on(spawn_test_server(redirect_response)).unwrap();
// A single-chunk stream — consumed once, not replayable.
let body_stream: crate::BodyStream =
Box::pin(stream::once(Ok(Bytes::from_static(b"payload"))));
let result =
run(crate::post(format!("{redirect}/start"))
.body(crate::Body::from_stream(body_stream)));
let err = result.expect_err("stream body redirect should fail");
assert_eq!(err.kind(), &crate::ErrorKind::Transport);
assert!(err.to_string().contains("replayable request body"));
}
// RFC-conformant 303 handling: the follow-up request must be a GET with no
// body, even though the original was a POST with a payload.
#[test]
fn see_other_redirect_switches_to_get_and_drops_body() {
let target = block_on(spawn_assert_server(
|request| {
assert!(request.starts_with("GET /final HTTP/1.1\r\n"));
// Request ends right after the header block — no body bytes.
assert!(request.ends_with("\r\n\r\n"));
assert!(!request.ends_with("\r\n\r\npayload"));
},
"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok",
))
.unwrap();
let redirect_response = format!(
"HTTP/1.1 303 See Other\r\nLocation: {target}/final\r\nContent-Length: 0\r\n\r\n"
);
let redirect = block_on(spawn_test_server(Box::leak(
redirect_response.into_boxed_str(),
)))
.unwrap();
let response = run(crate::post(format!("{redirect}/start"))
.text("payload")
.unwrap())
.unwrap();
assert_eq!(block_on(response.text()).unwrap(), "ok");
}
// Following a 302 from a HEAD request must keep the HEAD method (unlike
// POST, which some statuses rewrite to GET).
#[test]
fn head_redirect_keeps_head_method() {
let target = block_on(spawn_assert_server(
|request| {
assert!(request.starts_with("HEAD /final HTTP/1.1\r\n"));
},
"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n",
))
.unwrap();
let redirect_response =
format!("HTTP/1.1 302 Found\r\nLocation: {target}/final\r\nContent-Length: 0\r\n\r\n");
let redirect = block_on(spawn_test_server(Box::leak(
redirect_response.into_boxed_str(),
)))
.unwrap();
let response = run(crate::head(format!("{redirect}/start"))).unwrap();
assert_eq!(response.status(), StatusCode::OK);
}
// An overall request timeout (10ms) shorter than the server's delay (50ms)
// must surface ErrorKind::Timeout.
#[test]
fn request_timeout_triggers_error() {
let server_delay = std::time::Duration::from_millis(50);
let base = block_on(spawn_delayed_server(
server_delay,
"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok",
))
.unwrap();
let deadline = std::time::Duration::from_millis(10);
let result = run(get(format!("{base}/slow")).timeout(deadline));
assert!(matches!(result, Err(err) if err.kind() == &crate::ErrorKind::Timeout));
}
// The per-read timeout (10ms) must also fire against a server that stalls
// for 50ms before responding.
#[test]
fn read_timeout_triggers_error() {
let server_delay = std::time::Duration::from_millis(50);
let base = block_on(spawn_delayed_server(
server_delay,
"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok",
))
.unwrap();
let per_read = std::time::Duration::from_millis(10);
let result = run(get(format!("{base}/slow")).read_timeout(per_read));
assert!(matches!(result, Err(err) if err.kind() == &crate::ErrorKind::Timeout));
}
// json() must serialize the value, set content-type: application/json, and
// send the serialized bytes as the body.
#[test]
fn sends_json_request_body() {
let base = block_on(spawn_assert_server(
|request| {
assert!(request.starts_with("POST /users HTTP/1.1\r\n"));
assert!(request.contains("\r\ncontent-type: application/json\r\n"));
assert!(request.ends_with("\r\n\r\n{\"name\":\"alice\"}"));
},
"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok",
))
.unwrap();
let response = run(crate::post(format!("{base}/users"))
.json(&serde_json::json!({ "name": "alice" }))
.unwrap())
.unwrap();
assert_eq!(block_on(response.text()).unwrap(), "ok");
}
// A streamed body with unknown length must go out as transfer-encoding:
// chunked, with no content-length header, and carry all stream items.
#[test]
fn sends_chunked_body_for_stream() {
let base = spawn_chunked_assert_server(
|request| {
assert!(request.starts_with("POST /upload HTTP/1.1\r\n"));
assert!(request.contains("\r\ntransfer-encoding: chunked\r\n"));
assert!(!request.contains("\r\ncontent-length:"));
// Only check the payloads are present — chunk framing sizes are
// the transport's business.
let body = request.split("\r\n\r\n").nth(1).unwrap_or("");
assert!(body.contains("hello"));
assert!(body.contains("world"));
},
"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok",
)
.unwrap();
let stream = stream::iter(vec![
Ok(Bytes::from_static(b"hello")),
Ok(Bytes::from_static(b"world")),
]);
let response =
run(crate::post(format!("{base}/upload")).body_stream(Box::pin(stream))).unwrap();
assert_eq!(block_on(response.text()).unwrap(), "ok");
}
// An explicit transfer-encoding: chunked header must force chunked framing
// even for a fixed bytes body (and suppress content-length).
#[test]
fn bytes_body_can_use_chunked_transfer_encoding() {
let base = spawn_chunked_assert_server(
|request| {
assert!(request.starts_with("POST /upload HTTP/1.1\r\n"));
assert!(request.contains("\r\ntransfer-encoding: chunked\r\n"));
assert!(!request.contains("\r\ncontent-length:"));
let body = request.split("\r\n\r\n").nth(1).unwrap_or("");
assert!(body.contains("hello"));
},
"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok",
)
.unwrap();
let response = run(crate::post(format!("{base}/upload"))
.header("transfer-encoding", "chunked")
.unwrap()
.body("hello"))
.unwrap();
assert_eq!(block_on(response.text()).unwrap(), "ok");
}
// A streamed body paired with an explicit content-length header must be sent
// with that length and WITHOUT chunked framing.
#[test]
fn allows_stream_with_content_length_header() {
let base = block_on(spawn_assert_server(
|request| {
assert!(request.contains("\r\ncontent-length: 5\r\n"));
assert!(!request.contains("\r\ntransfer-encoding: chunked\r\n"));
assert!(request.ends_with("\r\n\r\nhello"));
},
"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok",
))
.unwrap();
let stream = stream::iter(vec![Ok(Bytes::from_static(b"hello"))]);
let response = run(crate::post(format!("{base}/upload"))
.header("content-length", "5")
.unwrap()
.body_stream(Box::pin(stream)))
.unwrap();
assert_eq!(block_on(response.text()).unwrap(), "ok");
}
// When the caller sets no accept-encoding, the client must advertise the
// crate's DEFAULT_ACCEPT_ENCODING set.
#[test]
fn sends_default_accept_encoding_header() {
let base = block_on(spawn_assert_server(
|request| {
let expected = format!("\r\naccept-encoding: {}\r\n", DEFAULT_ACCEPT_ENCODING);
assert!(request.contains(&expected));
},
"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok",
))
.unwrap();
let response = run(get(format!("{base}/encoding"))).unwrap();
assert_eq!(block_on(response.text()).unwrap(), "ok");
}
// Through an HTTP proxy, a plain-HTTP request must use the absolute-form
// request target (full URL) while still sending a Host header.
#[test]
fn http_proxy_forwards_http1_request() {
let proxy = block_on(spawn_http_proxy_server(
|request| {
// Absolute-form target is required when talking to a proxy.
assert!(request.starts_with("GET http://example.com/through-proxy HTTP/1.1\r\n"));
assert!(
request.contains("\r\nhost: example.com\r\n")
|| request.contains("\r\nHost: example.com\r\n")
);
},
"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok",
))
.unwrap();
let response = run(crate::Client::builder()
.proxy(proxy)
.build()
.unwrap()
.get("http://example.com/through-proxy"))
.unwrap();
assert_eq!(block_on(response.text()).unwrap(), "ok");
}
// Proxy credentials must be sent as a Proxy-Authorization: Basic header.
#[test]
fn http_proxy_sends_basic_auth_header() {
let proxy = block_on(spawn_http_proxy_server(
|request| {
// Compare case-insensitively; "dxnlcjpwyxnz" is the lowercased
// form of base64("user:pass") == "dXNlcjpwYXNz".
let lower = request.to_ascii_lowercase();
assert!(lower.contains("\r\nproxy-authorization: basic dxnlcjpwyxnz\r\n"));
},
"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok",
))
.unwrap();
// Rebuild the spawned proxy's address into an authenticated proxy config.
let proxy = match proxy {
crate::Proxy::Http { addr, .. } => crate::Proxy::http_with_auth(addr, "user", "pass"),
_ => unreachable!(),
};
let response = run(crate::Client::builder()
.proxy(proxy)
.build()
.unwrap()
.get("http://example.com/auth-proxy"))
.unwrap();
assert_eq!(block_on(response.text()).unwrap(), "ok");
}
// A plain HTTP request must tunnel through an unauthenticated SOCKS5 proxy.
#[test]
fn socks5_proxy_forwards_http1_request() {
let proxy = block_on(spawn_socks5_proxy_server(
// No username/password: the no-auth method is negotiated.
None,
"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok",
))
.unwrap();
let response = run(crate::Client::builder()
.proxy(proxy)
.build()
.unwrap()
.get("http://example.com/socks"))
.unwrap();
assert_eq!(block_on(response.text()).unwrap(), "ok");
}
// SOCKS5 username/password authentication (RFC 1929) must be negotiated
// when the proxy requires it.
#[test]
fn socks5_proxy_supports_username_password_auth() {
let proxy = block_on(spawn_socks5_proxy_server(
Some(("user", "pass")),
"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok",
))
.unwrap();
let response = run(crate::Client::builder()
.proxy(proxy)
.build()
.unwrap()
.get("http://example.com/socks-auth"))
.unwrap();
assert_eq!(block_on(response.text()).unwrap(), "ok");
}
// A caller-supplied accept-encoding header must replace (not duplicate) the
// client's default accept-encoding value.
#[test]
fn respects_custom_accept_encoding_header() {
let base = block_on(spawn_assert_server(
|request| {
assert!(request.contains("\r\naccept-encoding: gzip\r\n"));
// The default set must NOT also be present.
let default = format!("\r\naccept-encoding: {}\r\n", DEFAULT_ACCEPT_ENCODING);
assert!(!request.contains(&default));
},
"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok",
))
.unwrap();
let response = run(get(format!("{base}/encoding"))
.header("accept-encoding", "gzip")
.unwrap())
.unwrap();
assert_eq!(block_on(response.text()).unwrap(), "ok");
}
// A gzip-encoded body must be transparently decompressed, and the
// content-encoding / content-length headers removed from the decoded
// response (they no longer describe the body the caller sees).
#[test]
fn decodes_gzip_response_body() {
let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
encoder.write_all(b"hello gzip").unwrap();
let compressed = encoder.finish().unwrap();
let mut response = format!(
"HTTP/1.1 200 OK\r\nContent-Encoding: gzip\r\nContent-Length: {}\r\n\r\n",
compressed.len()
)
.into_bytes();
response.extend_from_slice(&compressed);
let base = block_on(spawn_bytes_server(response)).unwrap();
let response = run(get(format!("{base}/gzip"))).unwrap();
assert_eq!(response.headers().get("content-encoding"), None);
assert_eq!(response.headers().get("content-length"), None);
assert_eq!(block_on(response.text()).unwrap(), "hello gzip");
}
// A deflate-encoded body must be transparently decompressed.
#[test]
fn decodes_deflate_response_body() {
let mut encoder = DeflateEncoder::new(Vec::new(), Compression::default());
encoder.write_all(b"hello deflate").unwrap();
let compressed = encoder.finish().unwrap();
let mut response = format!(
"HTTP/1.1 200 OK\r\nContent-Encoding: deflate\r\nContent-Length: {}\r\n\r\n",
compressed.len()
)
.into_bytes();
response.extend_from_slice(&compressed);
let base = block_on(spawn_bytes_server(response)).unwrap();
let response = run(get(format!("{base}/deflate"))).unwrap();
assert_eq!(block_on(response.text()).unwrap(), "hello deflate");
}
// A brotli (content-encoding: br) body must be transparently decompressed.
#[cfg(feature = "brotli")]
#[test]
fn decodes_brotli_response_body() {
let mut compressed = Vec::new();
{
// Inner scope ends the CompressorWriter's borrow of `compressed`
// (and finishes the stream on drop) before we read the bytes.
let mut compressor = brotli::CompressorWriter::new(&mut compressed, 4096, 5, 22);
compressor.write_all(b"hello br").unwrap();
compressor.flush().unwrap();
}
let mut response = format!(
"HTTP/1.1 200 OK\r\nContent-Encoding: br\r\nContent-Length: {}\r\n\r\n",
compressed.len()
)
.into_bytes();
response.extend_from_slice(&compressed);
let base = block_on(spawn_bytes_server(response)).unwrap();
let response = run(get(format!("{base}/br"))).unwrap();
assert_eq!(block_on(response.text()).unwrap(), "hello br");
}
// A zstd-encoded body must be transparently decompressed.
#[cfg(feature = "zstd")]
#[test]
fn decodes_zstd_response_body() {
// Level 0 selects zstd's default compression level.
let compressed = zstd::stream::encode_all(&b"hello zstd"[..], 0).unwrap();
let mut response = format!(
"HTTP/1.1 200 OK\r\nContent-Encoding: zstd\r\nContent-Length: {}\r\n\r\n",
compressed.len()
)
.into_bytes();
response.extend_from_slice(&compressed);
let base = block_on(spawn_bytes_server(response)).unwrap();
let response = run(get(format!("{base}/zstd"))).unwrap();
assert_eq!(block_on(response.text()).unwrap(), "hello zstd");
}
// local_addr must bind the outgoing connection's source address: the server
// records its peer address, which must carry the reserved port.
#[test]
fn local_addr_binds_http_connection_source_port() {
let (base, peer_addr) = block_on(spawn_peer_addr_server()).unwrap();
// Reserve an ephemeral port by binding, then release it so the client
// can bind it. NOTE(review): small race — another process could grab
// the port between drop and connect; acceptable flake risk in a test.
let reserved = std::net::TcpListener::bind(("127.0.0.1", 0)).unwrap();
let local_addr = reserved.local_addr().unwrap();
drop(reserved);
let client = crate::Client::builder()
.local_addr(local_addr)
.build()
.unwrap();
let response = run(client.get(format!("{base}/bound"))).unwrap();
assert_eq!(block_on(response.text()).unwrap(), "ok");
assert_eq!(peer_addr.lock().unwrap().unwrap().port(), local_addr.port());
}
// form() must urlencode pairs (space -> %20) with the form content-type, and
// basic_auth must emit "Basic " + base64("user:pass") == "dXNlcjpwYXNz".
#[test]
fn sends_form_and_auth_headers() {
let base = block_on(spawn_assert_server(
|request| {
assert!(request.starts_with("POST /login HTTP/1.1\r\n"));
assert!(
request.contains("\r\ncontent-type: application/x-www-form-urlencoded\r\n")
);
assert!(request.contains("\r\nauthorization: Basic dXNlcjpwYXNz\r\n"));
assert!(request.ends_with("\r\n\r\nname=alice%20bob&role=admin"));
},
"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok",
))
.unwrap();
let response = run(crate::post(format!("{base}/login"))
.basic_auth("user", "pass")
.unwrap()
.form([("name", "alice bob"), ("role", "admin")])
.unwrap())
.unwrap();
assert_eq!(block_on(response.text()).unwrap(), "ok");
}
// A bearer token configured on the client builder must be attached to every
// request as an Authorization header.
#[test]
fn client_default_bearer_auth_is_sent() {
let base = block_on(spawn_assert_server(
|request| {
assert!(request.contains("\r\nauthorization: Bearer token-123\r\n"));
},
"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n",
))
.unwrap();
let client = crate::Client::builder()
.base_url(&base)
.unwrap()
.bearer_auth("token-123")
.unwrap()
.build()
.unwrap();
let _ = run(client.get("/me")).unwrap();
}
// ordered_headers must emit headers named in the profile order first (in
// that order), with unlisted headers (x-a) after them.
#[test]
fn ordered_headers_respects_profile_order() {
use crate::HeaderMap;
let mut headers = HeaderMap::new();
headers.insert("x-a", "1").unwrap();
headers.insert("x-b", "2").unwrap();
headers.insert("user-agent", "ua").unwrap();
let profile = vec!["x-b".to_owned(), "user-agent".to_owned()];
let ordered = super::ordered_headers(&headers, Vec::new(), Some(profile.as_slice()));
let names: Vec<_> = ordered.into_iter().map(|h| h.lower_name).collect();
let index_of = |name: &str| names.iter().position(|n| n == name).unwrap();
assert!(index_of("x-b") < index_of("user-agent"));
assert!(index_of("user-agent") < index_of("x-a"));
}
// The auto-generated host header must be slotted into the fingerprint's
// header_order position (first), before user-agent and x-a.
#[test]
fn encode_request_orders_generated_host_header_with_profile_order() {
use crate::Body;
use crate::CompressionMode;
use crate::HeaderMap;
let mut headers = HeaderMap::new();
headers.insert("x-a", "1").unwrap();
headers.insert("user-agent", "ua").unwrap();
let fingerprint = super::Http1Fingerprint {
header_order: vec!["host".to_owned(), "user-agent".to_owned(), "x-a".to_owned()],
original_header_case: Vec::new(),
};
let (encoded, _, _) = super::encode_request(
"GET",
&crate::Url::parse("http://example.com/path").unwrap(),
&headers,
// No cookies for this case.
&[],
CompressionMode::Disabled,
Body::empty().into_data().unwrap(),
Some(&fingerprint),
)
.unwrap();
let request = String::from_utf8(encoded).unwrap();
// find() returns byte offsets — compare positions to verify ordering.
let host_pos = request.find("\r\nhost: example.com\r\n").unwrap();
let user_agent_pos = request.find("\r\nuser-agent: ua\r\n").unwrap();
let x_a_pos = request.find("\r\nx-a: 1\r\n").unwrap();
assert!(host_pos < user_agent_pos);
assert!(user_agent_pos < x_a_pos);
}
// When the fingerprint supplies original_header_case mappings, every emitted
// header listed there — user-set (user-agent), generated (host,
// content-length), cookie, and accept-encoding — must use the recorded
// casing, and the lowercase forms must not appear.
#[test]
fn encode_request_restores_original_header_case_for_profile_overrides() {
use crate::Body;
use crate::CompressionMode;
use crate::HeaderMap;
let mut headers = HeaderMap::new();
headers.insert("user-agent", "custom-ua").unwrap();
let fingerprint = super::Http1Fingerprint {
header_order: vec![
"host".to_owned(),
"user-agent".to_owned(),
"accept-encoding".to_owned(),
"cookie".to_owned(),
"content-length".to_owned(),
],
// lowercase name -> casing to restore on the wire.
original_header_case: vec![
("host".to_owned(), "Host".to_owned()),
("user-agent".to_owned(), "User-Agent".to_owned()),
("accept-encoding".to_owned(), "Accept-Encoding".to_owned()),
("cookie".to_owned(), "Cookie".to_owned()),
("content-length".to_owned(), "Content-Length".to_owned()),
],
};
let (encoded, _, _) = super::encode_request(
"POST",
&crate::Url::parse("http://example.com/upload").unwrap(),
&headers,
&[("sid".to_owned(), "abc".to_owned())],
// Auto compression makes encode_request add accept-encoding itself.
CompressionMode::Auto,
Body::from("hello").into_data().unwrap(),
Some(&fingerprint),
)
.unwrap();
let request = String::from_utf8(encoded).unwrap();
assert!(request.contains("\r\nHost: example.com\r\n"));
assert!(request.contains("\r\nUser-Agent: custom-ua\r\n"));
let accept_encoding = format!("\r\nAccept-Encoding: {}\r\n", DEFAULT_ACCEPT_ENCODING);
assert!(request.contains(&accept_encoding));
assert!(request.contains("\r\nCookie: sid=abc\r\n"));
assert!(request.contains("\r\nContent-Length: 5\r\n"));
// The lowercase spellings must have been replaced, not duplicated.
assert!(!request.contains("\r\nuser-agent:"));
assert!(!request.contains("\r\ncontent-length:"));
}
}