use std::collections::{HashMap, HashSet, VecDeque};
use std::net::SocketAddr;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Arc, Mutex, OnceLock};
use std::time::Instant;
use async_channel::{Receiver, Sender, TryRecvError};
use async_io::Timer;
use bytes::Bytes;
use futures_lite::io::{AsyncReadExt, AsyncWriteExt};
use futures_lite::{StreamExt, future};
use crate::body::{Body, BodyData, BodyStream};
use crate::browser_emulation::{
BrowserProfile, Http2Fingerprint, Http2PriorityPhase, Http2PrioritySpec,
};
use crate::decode::CompressionMode;
use crate::decode::{DEFAULT_ACCEPT_ENCODING, decode_response_body, maybe_decode_response_body};
use crate::dns::{DnsCache, DnsConfig};
use crate::error::{Error, ErrorKind, Result};
use crate::header::HeaderMap;
use crate::metrics::Metrics;
use crate::pool::{PoolConfig, PoolKey};
use crate::progress::{ProgressConfig, ProgressPhase, ProgressReporter};
use crate::proxy::Proxy;
use crate::request::H2KeepAliveConfig;
use crate::request::{Method, ProgressCallback, ProtocolPolicy, Request, TimeoutConfig};
use crate::response::{Response, StatusCode, TrailerState, Version};
use crate::retry::should_retry_stale_connection;
use crate::tls::TlsConfig;
use crate::url::Url;
use crate::util::{parse_content_length, response_body_allowed};
use super::frame::{
DEFAULT_HEADER_TABLE_SIZE, DEFAULT_INITIAL_WINDOW_SIZE, DEFAULT_MAX_FRAME_SIZE, Flags, Frame,
FrameType, GoAwayFrame, HeaderCodec, MAX_FLOW_CONTROL_WINDOW, Setting,
build_request_header_list, client_settings_payload_with_fingerprint, header_list_size,
read_frame, read_header_block, read_push_promise, write_data_frame, write_headers_frames,
write_ping, write_ping_ack, write_preface_and_settings, write_priority_frame, write_rst_stream,
write_settings_ack, write_window_update,
};
use super::transport::{BoxedStream, connect_h2_stream, with_read_prefix, with_timeout_io};
/// Largest stream id a client may allocate: stream ids are 31-bit (RFC 9113 §5.1.1).
const MAX_CLIENT_STREAM_ID: u32 = (1 << 31) - 1;
/// RFC 9113 §7 CANCEL error code, sent when we abandon a stream locally.
const H2_ERROR_CANCEL: u32 = 0x8;
/// RFC 9113 §7 REFUSED_STREAM error code, used here to decline server push.
const H2_ERROR_REFUSED_STREAM: u32 = 0x7;
/// Execute `request` over HTTP/2, reusing a pooled connection when available.
///
/// Connection establishment depends on the scheme:
/// - cleartext `http` without prior knowledge first attempts an h2c upgrade
///   and may fall back to the server's HTTP/1.1 answer;
/// - otherwise a fresh connection is dialed (TLS, or prior-knowledge h2c).
///
/// Newly created connections are offered to `pool`; if the pool declines them
/// they are closed once idle. When a pooled connection fails in a way that
/// looks stale and the body is replayable, the request is retried exactly once
/// on a brand-new connection.
pub(super) async fn execute(
    request: Request,
    pool: Arc<Mutex<super::ConnectionPool>>,
    dns_cache: Arc<DnsCache>,
    dns_config: DnsConfig,
    local_addr: Option<SocketAddr>,
    pool_config: PoolConfig,
) -> Result<Response> {
    let prepared = PreparedRequest::from_request(request).await?;
    // Replayable copy for the stale-connection retry; `None` for streaming bodies.
    let retry_prepared = prepared.try_clone();
    let settings_payload =
        client_settings_payload_with_fingerprint(prepared.http2_fingerprint.as_ref())?;
    // The SETTINGS payload participates in the pool key so connections carrying
    // different browser fingerprints are never shared.
    let pool_key = if prepared.url.scheme() == "http" {
        PoolKey::for_h2c(
            &prepared.url,
            prepared.h2_keepalive_config,
            prepared.emulation_profile.as_ref(),
            &settings_payload,
        )?
    } else {
        PoolKey::for_h2(
            &prepared.url,
            &prepared.tls_config,
            prepared.h2_keepalive_config,
            prepared.emulation_profile.as_ref(),
            &settings_payload,
        )?
    };
    let method = prepared.method;
    let pooled = pool
        .lock()
        .unwrap_or_else(|err| err.into_inner())
        .acquire(&pool_key, pool_config);
    let reused_connection = pooled.is_some();
    let (connection, pooled_connection) = match pooled {
        Some(connection) => (connection, true),
        None => {
            // Cleartext without prior knowledge: try the Upgrade dance first;
            // the request itself rides along with the upgrade.
            if prepared.url.scheme() == "http" && !prepared.prior_knowledge_h2c {
                match H2ClientConnection::connect_via_h2c_upgrade(prepared).await? {
                    H2cConnectResult::Upgraded {
                        connection,
                        response,
                    } => {
                        let pooled_connection = pool
                            .lock()
                            .unwrap_or_else(|err| err.into_inner())
                            .insert(pool_key, connection.clone(), pool_config);
                        if !pooled_connection {
                            connection.close_when_idle();
                        }
                        return Ok(response);
                    }
                    H2cConnectResult::Http1(response) => return Ok(response),
                }
            }
            let connection = H2ClientConnection::connect(
                &prepared.url,
                prepared.timeout_config,
                &prepared.tls_config,
                prepared.prior_knowledge_h2c,
                prepared.h2_keepalive_config,
                prepared.http2_fingerprint.clone(),
                Arc::clone(&dns_cache),
                dns_config,
                local_addr,
                prepared.proxy.as_ref(),
            )
            .await?;
            let pooled_connection = pool.lock().unwrap_or_else(|err| err.into_inner()).insert(
                pool_key.clone(),
                connection.clone(),
                pool_config,
            );
            if !pooled_connection {
                connection.close_when_idle();
            }
            (connection, pooled_connection)
        }
    };
    match connection.execute(prepared).await {
        Ok(response) => Ok(annotate_metrics(response, reused_connection)),
        Err(err) if should_retry_stale_connection(method, &err) => {
            // Only retry when the body was buffered and therefore replayable.
            let Some(prepared) = retry_prepared else {
                if !pooled_connection {
                    connection.close();
                }
                return Err(err);
            };
            connection.close();
            let connection = H2ClientConnection::connect(
                &prepared.url,
                prepared.timeout_config,
                &prepared.tls_config,
                prepared.prior_knowledge_h2c,
                prepared.h2_keepalive_config,
                prepared.http2_fingerprint.clone(),
                Arc::clone(&dns_cache),
                dns_config,
                local_addr,
                prepared.proxy.as_ref(),
            )
            .await?;
            // Recover from a poisoned pool mutex like every other lock site;
            // a bare unwrap() here would panic the retry path.
            let pooled_retry = pool
                .lock()
                .unwrap_or_else(|err| err.into_inner())
                .insert(pool_key, connection.clone(), pool_config);
            if !pooled_retry {
                connection.close_when_idle();
            }
            let result = connection
                .execute(prepared)
                .await
                .map(|response| annotate_metrics(response, false));
            if result.is_err() && !pooled_retry {
                connection.close();
            }
            result
        }
        Err(err) => {
            if !pooled_connection {
                connection.close();
            }
            Err(err)
        }
    }
}
/// Attach basic transport metrics (negotiated protocol version and whether a
/// pooled connection was reused) to a finished response.
fn annotate_metrics(response: Response, reused_connection: bool) -> Response {
    let metrics = Metrics::default()
        .with_protocol(response.version())
        .with_reused_connection(reused_connection);
    response.with_metrics(metrics)
}
/// Request body after preparation: either fully buffered bytes (replayable)
/// or a one-shot streaming body.
enum PreparedRequestBody {
    /// Entire body held in memory; supports retry and h2c upgrade.
    Bytes(Bytes),
    /// Streaming body; consumed exactly once (`Option` tracks consumption).
    Stream {
        stream: Option<BodyStream>,
    },
}
impl PreparedRequestBody {
    /// Clone the body if it is replayable. Buffered bytes clone cheaply;
    /// streaming bodies cannot be replayed, so `None` disables retries.
    fn try_clone(&self) -> Option<Self> {
        match self {
            Self::Bytes(bytes) => Some(Self::Bytes(bytes.clone())),
            Self::Stream { .. } => None,
        }
    }
    /// Known body length in bytes; `None` for streams (length unknown up front).
    fn content_length(&self) -> Option<u64> {
        match self {
            Self::Bytes(bytes) => Some(bytes.len() as u64),
            Self::Stream { .. } => None,
        }
    }
    /// Extract the buffered bytes for an h2c upgrade request, which must be
    /// written in full before the 101 response; streaming bodies cannot do that.
    fn into_upgrade_bytes(self) -> Result<Bytes> {
        match self {
            Self::Bytes(bytes) => Ok(bytes),
            Self::Stream { .. } => Err(Error::new(
                ErrorKind::Transport,
                "h2c upgrade requires a buffered request body; use prior_knowledge_h2c(true) for streaming uploads",
            )),
        }
    }
    // Currently a no-op for both variants; the streaming upload thread wakes
    // the connection itself via `ConnectionCommand::Drive` (see `into_active`).
    fn attach_waker(&self, _sender: Sender<ConnectionCommand>) {
    }
    // No-op placeholder; abort for streams is signalled through the abort
    // channel created in `into_active`, not through this handle.
    fn abort_upload(&self) {
    }
    /// True when the HEADERS frame can carry END_STREAM (empty buffered body).
    fn ends_stream_with_headers(&self) -> bool {
        matches!(self, Self::Bytes(bytes) if bytes.is_empty())
    }
    /// Total upload size for progress reporting; unknown for streams.
    fn upload_total(&self) -> Option<usize> {
        match self {
            Self::Bytes(bytes) => Some(bytes.len()),
            Self::Stream { .. } => None,
        }
    }
    /// Convert into the form consumed by the connection task.
    ///
    /// For streams this spawns a dedicated upload thread that pulls chunks
    /// from the body and forwards them over a channel, nudging the connection
    /// task with `Drive` after each chunk so it flushes outbound DATA frames.
    /// The paired abort channel lets `cancel`/`fail` paths stop the thread.
    fn into_active(
        self,
        connection_sender: Sender<ConnectionCommand>,
    ) -> Result<ActiveRequestBody> {
        match self {
            Self::Bytes(bytes) => Ok(ActiveRequestBody::Bytes { data: bytes }),
            Self::Stream { stream } => {
                let mut stream_body = stream.expect("stream body already consumed");
                let (sender, receiver) = async_channel::unbounded();
                let (abort_sender, abort_receiver) = async_channel::bounded::<()>(1);
                std::thread::Builder::new()
                    .name("request-h2-upload".to_owned())
                    .spawn(move || {
                        async_io::block_on(async move {
                            loop {
                                enum UploadEvent {
                                    Chunk(Option<Result<Bytes>>),
                                    Abort,
                                }
                                // Race the next body chunk against an abort signal.
                                let event = future::or(
                                    async { UploadEvent::Chunk(stream_body.next().await) },
                                    async {
                                        let _ = abort_receiver.recv().await;
                                        UploadEvent::Abort
                                    },
                                )
                                .await;
                                match event {
                                    UploadEvent::Chunk(Some(chunk)) => {
                                        // Receiver dropped => request is gone; stop uploading.
                                        if sender.send(chunk.map(Some)).await.is_err() {
                                            return;
                                        }
                                        let _ =
                                            connection_sender.try_send(ConnectionCommand::Drive);
                                    }
                                    UploadEvent::Chunk(None) => {
                                        // End of body: signal completion, then exit.
                                        if sender.send(Ok(None)).await.is_ok() {
                                            let _ = connection_sender
                                                .try_send(ConnectionCommand::Drive);
                                        }
                                        return;
                                    }
                                    UploadEvent::Abort => return,
                                }
                            }
                        });
                    })
                    .map_err(|err| {
                        Error::with_source(
                            ErrorKind::Transport,
                            "failed to spawn http2 upload task",
                            err,
                        )
                    })?;
                Ok(ActiveRequestBody::Stream {
                    receiver,
                    abort_sender,
                    buffer: Bytes::new(),
                    offset: 0,
                    finished: false,
                })
            }
        }
    }
}
/// A `Request` flattened into the pieces the HTTP/2 transport needs, with the
/// body resolved into buffered bytes or a stream.
struct PreparedRequest {
    method: Method,
    url: Url,
    headers: HeaderMap,
    // (name, value) cookie pairs, serialized into the header list later.
    cookies: Vec<(String, String)>,
    timeout_config: TimeoutConfig,
    // Governs the fallback behavior when an h2c upgrade is rejected.
    protocol_policy: ProtocolPolicy,
    progress_callback: Option<ProgressCallback>,
    progress_config: ProgressConfig,
    h2_keepalive_config: H2KeepAliveConfig,
    tls_config: TlsConfig,
    // Skip the Upgrade dance and speak h2 directly over cleartext.
    prior_knowledge_h2c: bool,
    compression_mode: CompressionMode,
    proxy: Option<Proxy>,
    body: PreparedRequestBody,
    // Derived from `emulation_profile`; shapes SETTINGS, priorities, header order.
    http2_fingerprint: Option<Http2Fingerprint>,
    emulation_profile: Option<BrowserProfile>,
}
impl PreparedRequest {
    /// Decompose a `Request` into transport-ready parts, validating the scheme
    /// and resolving the body into buffered bytes or a stream.
    ///
    /// NOTE: the tuple order here must match `Request::into_parts` exactly;
    /// that definition lives outside this file.
    async fn from_request(request: Request) -> Result<Self> {
        let (
            method,
            url,
            headers,
            cookies,
            timeout_config,
            protocol_policy,
            _retry_policy,
            prior_knowledge_h2c,
            progress_callback,
            progress_config,
            h2_keepalive_config,
            tls_config,
            proxy,
            compression_mode,
            body,
            browser_profile,
        ) = request.into_parts();
        // HTTP/2 here runs over TLS (https) or cleartext h2c (http) only.
        if url.scheme() != "https" && url.scheme() != "http" {
            return Err(Error::new(
                ErrorKind::Transport,
                "http2 transport requires https or cleartext http",
            ));
        }
        let body = match body.into_data()? {
            BodyData::Bytes(bytes) => PreparedRequestBody::Bytes(bytes),
            BodyData::Stream(stream_body) => PreparedRequestBody::Stream {
                stream: Some(stream_body),
            },
        };
        Ok(Self {
            method,
            url,
            headers,
            cookies,
            timeout_config,
            protocol_policy,
            progress_callback,
            progress_config,
            h2_keepalive_config,
            tls_config,
            prior_knowledge_h2c,
            compression_mode,
            proxy,
            body,
            emulation_profile: browser_profile.clone(),
            http2_fingerprint: browser_profile.and_then(|p| p.http2_fingerprint().cloned()),
        })
    }
    /// Deep-copy for retry purposes; fails (returns `None`) when the body is a
    /// non-replayable stream.
    fn try_clone(&self) -> Option<Self> {
        Some(Self {
            method: self.method,
            url: self.url.clone(),
            headers: self.headers.clone(),
            cookies: self.cookies.clone(),
            timeout_config: self.timeout_config,
            protocol_policy: self.protocol_policy,
            progress_callback: self.progress_callback.clone(),
            progress_config: self.progress_config,
            h2_keepalive_config: self.h2_keepalive_config,
            tls_config: self.tls_config.clone(),
            prior_knowledge_h2c: self.prior_knowledge_h2c,
            compression_mode: self.compression_mode,
            proxy: self.proxy.clone(),
            body: self.body.try_clone()?,
            http2_fingerprint: self.http2_fingerprint.clone(),
            emulation_profile: self.emulation_profile.clone(),
        })
    }
}
/// Cheaply cloneable handle to one live HTTP/2 connection. All clones share
/// the same background connection task and bookkeeping through
/// `SharedConnection`; dropping handles does not close the connection.
#[derive(Clone)]
pub(super) struct H2ClientConnection {
    shared: Arc<SharedConnection>,
}
/// Outcome of attempting a cleartext HTTP/2 (h2c) upgrade.
enum H2cConnectResult {
    /// The server accepted the upgrade: a live h2 connection plus the response
    /// to the request that rode along with the Upgrade.
    Upgraded {
        connection: H2ClientConnection,
        response: Response,
    },
    /// The server declined and answered the request over HTTP/1.1 instead.
    Http1(Response),
}
impl H2ClientConnection {
    /// Dial a new HTTP/2 connection (TLS or prior-knowledge h2c) and spawn its
    /// background task on a dedicated thread running a local executor.
    async fn connect(
        url: &Url,
        timeout_config: TimeoutConfig,
        tls_config: &TlsConfig,
        prior_knowledge_h2c: bool,
        h2_keepalive_config: H2KeepAliveConfig,
        http2_fingerprint: Option<Http2Fingerprint>,
        dns_cache: Arc<DnsCache>,
        dns_config: DnsConfig,
        local_addr: Option<SocketAddr>,
        proxy: Option<&Proxy>,
    ) -> Result<Self> {
        let core = ConnectionCore::connect(
            url,
            timeout_config,
            tls_config,
            prior_knowledge_h2c,
            http2_fingerprint,
            dns_cache,
            dns_config,
            local_addr,
            proxy,
        )
        .await?;
        let (sender, receiver) = async_channel::unbounded();
        let shared = Arc::new(SharedConnection::new(
            sender,
            h2_keepalive_config.idle_timeout,
        ));
        let task_shared = Arc::clone(&shared);
        let _ = std::thread::Builder::new()
            .name(format!(
                "request-h2-{}:{}",
                url.host(),
                url.effective_port()
            ))
            .spawn(move || {
                async_io::block_on(async move {
                    run_connection_task(core, h2_keepalive_config, receiver, task_shared).await;
                });
            })
            .map_err(|err| {
                Error::with_source(ErrorKind::Transport, "failed to spawn h2 task", err)
            })?;
        Ok(Self { shared })
    }
    /// Perform an HTTP/1.1 `Upgrade: h2c` handshake carrying `request`.
    ///
    /// On 101 the stream (plus any bytes read past the 1.1 response head) is
    /// promoted to HTTP/2, the upgraded request becomes stream 1, and the
    /// caller receives both the connection and that request's response. On any
    /// other status the behavior depends on the protocol policy: fall back to
    /// the server's HTTP/1.1 response, or fail.
    async fn connect_via_h2c_upgrade(request: PreparedRequest) -> Result<H2cConnectResult> {
        let PreparedRequest {
            method,
            url,
            headers,
            cookies,
            timeout_config,
            protocol_policy,
            progress_callback,
            progress_config,
            h2_keepalive_config,
            tls_config: _,
            prior_knowledge_h2c: _,
            compression_mode: _,
            proxy: _,
            body,
            http2_fingerprint,
            emulation_profile: _,
        } = request;
        // The whole body must be sent before the 101 arrives, so it has to be buffered.
        let body = body.into_upgrade_bytes()?;
        let settings_payload =
            client_settings_payload_with_fingerprint(http2_fingerprint.as_ref())?;
        let request_head = encode_h2c_upgrade_request(
            method.as_str(),
            &url,
            &headers,
            &cookies,
            &body,
            &settings_payload,
        )?;
        let connect_future = async_net::TcpStream::connect((url.host(), url.effective_port()));
        let mut stream =
            with_timeout_io(timeout_config.connect, connect_future, "connect timed out").await?;
        with_timeout_io(
            timeout_config.write,
            stream.write_all(&request_head),
            "write timed out",
        )
        .await?;
        if !body.is_empty() {
            with_timeout_io(
                timeout_config.write,
                stream.write_all(&body),
                "write timed out",
            )
            .await?;
        }
        with_timeout_io(timeout_config.write, stream.flush(), "write timed out").await?;
        let upgrade = read_h2c_upgrade_response(&mut stream, timeout_config).await?;
        if upgrade.status != StatusCode::new(101) {
            // Upgrade declined: either accept the 1.1 answer or fail, per policy.
            if protocol_policy == ProtocolPolicy::PreferHttp2 {
                let response = read_http1_response_after_upgrade_rejection(
                    stream,
                    method,
                    url,
                    timeout_config,
                    progress_callback,
                    progress_config,
                    upgrade,
                )
                .await?;
                return Ok(H2cConnectResult::Http1(response));
            }
            return Err(Error::new(
                ErrorKind::Transport,
                format!(
                    "server rejected h2c upgrade with status {}",
                    upgrade.status.as_u16()
                ),
            ));
        }
        validate_h2c_upgrade_response(&upgrade.headers)?;
        // Bytes already read past the 101 head belong to the h2 connection preface.
        let stream = with_read_prefix(Box::new(stream), upgrade.remaining);
        let core =
            ConnectionCore::connect_upgraded(stream, timeout_config, http2_fingerprint).await?;
        let (sender, receiver) = async_channel::unbounded();
        let shared = Arc::new(SharedConnection::new(
            sender,
            h2_keepalive_config.idle_timeout,
        ));
        // The upgraded request is already in flight on stream 1; count it.
        shared.active_requests.fetch_add(1, Ordering::SeqCst);
        let task_shared = Arc::clone(&shared);
        let (response_tx, response_rx) = async_channel::bounded(1);
        // The body was fully written during the upgrade, so upload progress
        // completes immediately.
        let mut upload_progress = progress_callback.clone().map(|callback| {
            ProgressReporter::new(
                callback,
                ProgressPhase::Upload,
                Some(body.len()),
                progress_config,
            )
        });
        if let Some(reporter) = upload_progress.as_mut() {
            reporter.record(body.len());
            reporter.finish();
        }
        let download_progress = progress_callback.clone().map(|callback| {
            ProgressReporter::new(callback, ProgressPhase::Download, None, progress_config)
        });
        let local_stream_window = core.state.local_initial_stream_window_size;
        let initial_stream = ActiveStream::from_upgraded_request(
            url.clone(),
            method,
            upload_progress,
            download_progress,
            response_tx,
            local_stream_window,
        );
        let _ = std::thread::Builder::new()
            .name(format!(
                "request-h2-{}:{}",
                url.host(),
                url.effective_port()
            ))
            .spawn(move || {
                async_io::block_on(async move {
                    // Per RFC 7540 §3.2, the upgraded request occupies stream id 1.
                    run_connection_task_with_initial_stream(
                        core,
                        h2_keepalive_config,
                        receiver,
                        task_shared,
                        1,
                        initial_stream,
                    )
                    .await;
                });
            })
            .map_err(|err| {
                Error::with_source(ErrorKind::Transport, "failed to spawn h2 task", err)
            })?;
        let connection = Self { shared };
        let response = response_rx.recv().await.map_err(|_| {
            Error::new(
                ErrorKind::Transport,
                "http2 connection closed before upgraded response headers arrived",
            )
        })??;
        Ok(H2cConnectResult::Upgraded {
            connection,
            response,
        })
    }
    /// Submit a request to this connection's task and await its response.
    ///
    /// `active_requests` is incremented before submission; the decrement on
    /// the success path appears to happen inside the connection task
    /// (`release_request_slot`, defined outside this view) — TODO confirm.
    async fn execute(&self, request: PreparedRequest) -> Result<Response> {
        if self.is_closed() {
            return Err(Error::new(
                ErrorKind::StaleConnection,
                "http2 connection is closed",
            ));
        }
        self.shared.active_requests.fetch_add(1, Ordering::SeqCst);
        self.touch();
        let (response_tx, response_rx) = async_channel::bounded(1);
        self.shared
            .sender
            .send(ConnectionCommand::Execute(ExecuteCommand {
                request,
                response_tx,
            }))
            .await
            .map_err(|_| {
                // Send failed => command never reached the task; undo the count.
                self.shared.active_requests.fetch_sub(1, Ordering::SeqCst);
                self.touch();
                Error::new(
                    ErrorKind::StaleConnection,
                    "failed to submit request to http2 connection",
                )
            })?;
        let result = response_rx.recv().await.map_err(|_| {
            self.touch();
            Error::new(
                ErrorKind::StaleConnection,
                "http2 connection task stopped before completing the request",
            )
        })?;
        self.touch();
        result
    }
    /// Number of requests currently in flight on this connection.
    pub(super) fn load(&self) -> usize {
        self.shared.active_requests.load(Ordering::SeqCst)
    }
    /// Identity comparison: two handles to the same underlying connection.
    pub(super) fn ptr_eq(&self, other: &Self) -> bool {
        Arc::ptr_eq(&self.shared, &other.shared)
    }
    /// Whether the pool may multiplex another request onto this connection.
    pub(super) fn can_accept_new_stream(&self) -> bool {
        !self.is_closed() && self.load() < self.max_concurrent_streams()
    }
    /// Pool eviction policy: closed connections always go; busy connections
    /// never do; otherwise compare idle time against the pool's idle timeout.
    pub(super) fn should_evict(&self, now: Instant, pool_config: PoolConfig) -> bool {
        if self.is_closed() {
            return true;
        }
        if self.load() > 0 {
            return false;
        }
        match pool_config.idle_timeout {
            Some(timeout) => {
                now.duration_since(
                    *self
                        .shared
                        .last_used
                        .lock()
                        .unwrap_or_else(|err| err.into_inner()),
                ) > timeout
            }
            None => false,
        }
    }
    // Closed either explicitly (flag) or because the task dropped its receiver.
    fn is_closed(&self) -> bool {
        self.shared.closed.load(Ordering::SeqCst) || self.shared.sender.is_closed()
    }
    /// Hard-close: mark closed and sever the command channel so the task exits.
    pub(super) fn close(&self) {
        self.shared.closed.store(true, Ordering::SeqCst);
        self.shared.sender.close();
    }
    // Graceful close: the task exits once no work is pending or in flight.
    fn close_when_idle(&self) {
        self.shared.close_when_idle.store(true, Ordering::SeqCst);
    }
    // Peer's advertised concurrency limit (usize::MAX until known).
    fn max_concurrent_streams(&self) -> usize {
        self.shared.max_concurrent_streams.load(Ordering::SeqCst)
    }
    // Refresh the idle timestamp used by eviction and keepalive.
    fn touch(&self) {
        *self
            .shared
            .last_used
            .lock()
            .unwrap_or_else(|err| err.into_inner()) = Instant::now();
    }
}
/// State shared between every `H2ClientConnection` handle and the background
/// connection task.
struct SharedConnection {
    // Command channel into the connection task; closing it signals shutdown.
    sender: Sender<ConnectionCommand>,
    // Idle timeout from the keepalive config; `None` disables keepalive pings.
    keepalive_idle_timeout: Option<std::time::Duration>,
    // Requests currently submitted to / executing on this connection.
    active_requests: AtomicUsize,
    // Starts unbounded; presumably lowered when the peer's SETTINGS arrive
    // (updated outside this view) — TODO confirm.
    max_concurrent_streams: AtomicUsize,
    // Set once the connection must accept no further requests.
    closed: AtomicBool,
    // Ask the task to exit as soon as no work is pending or in flight.
    close_when_idle: AtomicBool,
    // Last-activity timestamp driving pool eviction and keepalive scheduling.
    last_used: Mutex<Instant>,
}
impl SharedConnection {
fn new(
sender: Sender<ConnectionCommand>,
keepalive_idle_timeout: Option<std::time::Duration>,
) -> Self {
Self {
sender,
keepalive_idle_timeout,
active_requests: AtomicUsize::new(0),
max_concurrent_streams: AtomicUsize::new(usize::MAX),
closed: AtomicBool::new(false),
close_when_idle: AtomicBool::new(false),
last_used: Mutex::new(Instant::now()),
}
}
}
/// Messages sent from request handles (and helper wake threads) to the
/// connection task.
enum ConnectionCommand {
    /// Submit a new request for execution on this connection.
    Execute(ExecuteCommand),
    /// Reset the given stream locally (RST_STREAM with CANCEL).
    CancelStream { stream_id: u32 },
    /// Pure wakeup: makes the task loop re-examine state (new upload chunks,
    /// keepalive deadlines, ...). Carries no payload.
    Drive,
}
/// A request queued on the connection task plus the channel through which its
/// single result (response or error) is delivered back to the caller.
struct ExecuteCommand {
    request: PreparedRequest,
    // Bounded(1): exactly one result per request.
    response_tx: Sender<Result<Response>>,
}
/// The I/O half of a connection: the raw stream plus the HPACK codec and
/// connection-level HTTP/2 state.
struct ConnectionCore {
    // Underlying TCP/TLS stream; after an h2c upgrade it is wrapped with a
    // read prefix holding bytes read past the 101 head.
    stream: BoxedStream,
    timeout_config: TimeoutConfig,
    // HPACK encoder/decoder shared by all streams on this connection.
    header_codec: HeaderCodec,
    // Flow-control windows, stream-id allocation, and peer settings.
    state: ConnectionState,
    // Optional browser-emulation fingerprint shaping SETTINGS and priorities.
    fingerprint: Option<Http2Fingerprint>,
}
impl ConnectionCore {
    /// Dial a new TCP/TLS stream for `url`, then complete the HTTP/2 handshake
    /// on it. Construction and handshake are shared with `connect_upgraded` so
    /// the two paths cannot drift apart.
    async fn connect(
        url: &Url,
        timeout_config: TimeoutConfig,
        tls_config: &TlsConfig,
        prior_knowledge_h2c: bool,
        fingerprint: Option<Http2Fingerprint>,
        dns_cache: Arc<DnsCache>,
        dns_config: DnsConfig,
        local_addr: Option<SocketAddr>,
        proxy: Option<&Proxy>,
    ) -> Result<Self> {
        let stream = connect_h2_stream(
            url,
            timeout_config,
            tls_config,
            prior_knowledge_h2c,
            dns_cache,
            dns_config,
            local_addr,
            proxy,
        )
        .await?;
        // Everything after dialing is identical to the upgraded path.
        Self::connect_upgraded(stream, timeout_config, fingerprint).await
    }
    /// Wrap an already-established stream (fresh dial or post-h2c-upgrade) in
    /// a connection core and perform the client-side HTTP/2 handshake.
    async fn connect_upgraded(
        stream: BoxedStream,
        timeout_config: TimeoutConfig,
        fingerprint: Option<Http2Fingerprint>,
    ) -> Result<Self> {
        let mut core = Self {
            stream,
            timeout_config,
            header_codec: HeaderCodec::new(),
            state: ConnectionState::default(),
            fingerprint,
        };
        core.handshake().await?;
        Ok(core)
    }
    /// Send the client preface and SETTINGS, plus (when a fingerprint requests
    /// a larger connection window) an initial connection-level WINDOW_UPDATE.
    async fn handshake(&mut self) -> Result<()> {
        self.apply_local_fingerprint()?;
        let settings_payload = client_settings_payload_with_fingerprint(self.fingerprint.as_ref())?;
        write_preface_and_settings(&mut self.stream, self.timeout_config, settings_payload).await?;
        if let Some(fp) = self.fingerprint.as_ref() {
            if let Some(size) = fp.initial_connection_window_size {
                // Only the delta above the protocol default needs announcing.
                let increment = size.saturating_sub(DEFAULT_INITIAL_WINDOW_SIZE as u32);
                if increment > 0 {
                    write_window_update(&mut self.stream, 0, increment, self.timeout_config)
                        .await?;
                }
            }
        }
        Ok(())
    }
    /// Apply fingerprint-driven local settings (window sizes, HPACK table size)
    /// to the connection state and header codec before the preface is written.
    fn apply_local_fingerprint(&mut self) -> Result<()> {
        if let Some(fp) = self.fingerprint.as_ref() {
            self.state.apply_local_fingerprint(fp)?;
            if let Some(size) = fp.header_table_size {
                self.header_codec.set_max_decoder_table_size(size as usize);
            }
        }
        Ok(())
    }
}
/// The single-threaded event loop owning a connection: it multiplexes pending
/// requests, in-flight streams, keepalive pings, and inbound frames.
struct ConnectionTask {
    core: ConnectionCore,
    keepalive_config: H2KeepAliveConfig,
    // Commands from request handles and wake threads.
    receiver: Receiver<ConnectionCommand>,
    shared: Arc<SharedConnection>,
    // Requests accepted but not yet opened as streams.
    pending: VecDeque<ExecuteCommand>,
    // Stream ids queued for RST_STREAM on the next loop pass.
    pending_resets: VecDeque<u32>,
    // Live streams keyed by stream id.
    in_flight: HashMap<u32, ActiveStream>,
    // Streams we reset locally; frames still arriving for them are tolerated.
    locally_reset_streams: HashSet<u32>,
    // Cleared when the peer's GOAWAY (or similar) forbids new streams — TODO confirm.
    accepting_new_streams: bool,
    // In-flight keepalive PING: (payload, send time).
    keepalive_outstanding: Option<([u8; 8], Instant)>,
    // Monotonic counter used as PING payload.
    next_keepalive_nonce: u64,
    // Deadline of the currently scheduled wake thread, if any (dedup guard).
    scheduled_keepalive_wake: Option<Instant>,
}
impl ConnectionTask {
    /// Build a task with no streams: empty queues, accepting new streams,
    /// keepalive idle.
    fn new(
        core: ConnectionCore,
        keepalive_config: H2KeepAliveConfig,
        receiver: Receiver<ConnectionCommand>,
        shared: Arc<SharedConnection>,
    ) -> Self {
        Self {
            core,
            keepalive_config,
            receiver,
            shared,
            pending: VecDeque::new(),
            pending_resets: VecDeque::new(),
            in_flight: HashMap::new(),
            locally_reset_streams: HashSet::new(),
            accepting_new_streams: true,
            keepalive_outstanding: None,
            next_keepalive_nonce: 1,
            scheduled_keepalive_wake: None,
        }
    }
fn new_with_initial_stream(
mut core: ConnectionCore,
keepalive_config: H2KeepAliveConfig,
receiver: Receiver<ConnectionCommand>,
shared: Arc<SharedConnection>,
stream_id: u32,
stream: ActiveStream,
) -> Self {
core.state.reserve_open_stream(stream_id);
let mut in_flight = HashMap::new();
in_flight.insert(stream_id, stream);
Self {
core,
keepalive_config,
receiver,
shared,
pending: VecDeque::new(),
pending_resets: VecDeque::new(),
in_flight,
locally_reset_streams: HashSet::new(),
accepting_new_streams: true,
keepalive_outstanding: None,
next_keepalive_nonce: 1,
scheduled_keepalive_wake: None,
}
}
    /// Main event loop. Each iteration, in order: absorb queued commands,
    /// flush queued RST_STREAMs, open streams up to the concurrency limit,
    /// flush outbound DATA, honor close-when-idle, service keepalive, then
    /// block on the next command or inbound frame.
    ///
    /// Returns `Ok(())` for graceful shutdown; any error tears the task down.
    async fn run(&mut self) -> Result<()> {
        loop {
            self.drain_commands();
            self.flush_pending_resets().await?;
            self.open_pending_streams().await?;
            self.flush_outbound_data().await?;
            // Graceful shutdown once all work has drained.
            if self.shared.close_when_idle.load(Ordering::SeqCst)
                && self.pending.is_empty()
                && self.in_flight.is_empty()
            {
                return Ok(());
            }
            self.refresh_keepalive_wake();
            self.check_keepalive_timeout()?;
            self.handle_keepalive_timer().await?;
            self.schedule_keepalive_wake();
            if self.receiver.is_closed() {
                // No more commands can arrive; exit when idle, otherwise keep
                // reading frames to finish the remaining streams.
                if self.pending.is_empty() && self.in_flight.is_empty() {
                    return Ok(());
                }
                let event = LoopEvent::Frame(
                    read_frame(
                        &mut self.core.stream,
                        self.core.timeout_config,
                        self.core.state.local_max_frame_size,
                    )
                    .await,
                );
                self.handle_loop_event(event).await?;
                continue;
            }
            // Race the command channel against the next inbound frame.
            let receiver = &self.receiver;
            let event = future::or(async { LoopEvent::Command(receiver.recv().await) }, async {
                LoopEvent::Frame(
                    read_frame(
                        &mut self.core.stream,
                        self.core.timeout_config,
                        self.core.state.local_max_frame_size,
                    )
                    .await,
                )
            })
            .await;
            self.handle_loop_event(event).await?;
        }
    }
    /// Dispatch one loop event: enqueue a received command, process an inbound
    /// frame, or propagate a read error.
    ///
    /// A closed command channel with no remaining work returns `Ok(())` here;
    /// the `receiver.is_closed()` check at the top of `run` performs the
    /// actual loop exit on the next iteration.
    async fn handle_loop_event(&mut self, event: LoopEvent) -> Result<()> {
        match event {
            LoopEvent::Command(Ok(command)) => {
                self.enqueue_command(command);
            }
            LoopEvent::Command(Err(_)) => {
                if self.pending.is_empty() && self.in_flight.is_empty() {
                    return Ok(());
                }
            }
            LoopEvent::Frame(Ok(frame)) => {
                // Any inbound traffic counts as connection activity.
                self.touch();
                self.process_frame(frame).await?;
            }
            LoopEvent::Frame(Err(err)) => return Err(err),
        }
        Ok(())
    }
fn touch(&self) {
*self
.shared
.last_used
.lock()
.unwrap_or_else(|err| err.into_inner()) = Instant::now();
}
fn refresh_keepalive_wake(&mut self) {
if self
.scheduled_keepalive_wake
.map(|deadline| Instant::now() >= deadline)
.unwrap_or(false)
{
self.scheduled_keepalive_wake = None;
}
}
    /// Ensure a helper thread will send `Drive` at the next keepalive deadline
    /// (idle-ping due, or outstanding-PING ack timeout), so the blocked event
    /// loop wakes up even when no frames or commands arrive.
    ///
    /// `scheduled_keepalive_wake` deduplicates: an identical deadline that is
    /// already scheduled spawns no second thread.
    fn schedule_keepalive_wake(&mut self) {
        let next_wake = self
            .should_poll_keepalive()
            .then(|| self.next_keepalive_deadline());
        if self.scheduled_keepalive_wake == next_wake {
            return;
        }
        self.scheduled_keepalive_wake = next_wake;
        let Some(deadline) = next_wake else {
            return;
        };
        let sender = self.shared.sender.clone();
        std::thread::Builder::new()
            .name("h2-keepalive-wake".to_owned())
            .spawn(move || {
                async_io::block_on(async move {
                    // Deadline already passed => wake immediately.
                    if let Some(duration) = deadline.checked_duration_since(Instant::now()) {
                        Timer::after(duration).await;
                    }
                    let _ = sender.try_send(ConnectionCommand::Drive);
                });
            })
            // Spawn failure is non-fatal: keepalive simply may fire late.
            .ok();
    }
fn should_poll_keepalive(&self) -> bool {
self.keepalive_outstanding.is_some()
|| (self.keepalive_config.idle_timeout.is_some()
&& self.pending.is_empty()
&& self.is_keepalive_idle())
}
fn is_keepalive_idle(&self) -> bool {
self.shared.active_requests.load(Ordering::SeqCst) == 0
}
fn next_keepalive_deadline(&self) -> Instant {
if let Some((_, sent_at)) = self.keepalive_outstanding {
sent_at + self.keepalive_config.ack_timeout
} else {
*self
.shared
.last_used
.lock()
.unwrap_or_else(|err| err.into_inner())
+ self
.keepalive_config
.idle_timeout
.unwrap_or(self.keepalive_config.ack_timeout)
}
}
fn check_keepalive_timeout(&self) -> Result<()> {
if let Some((_, sent_at)) = self.keepalive_outstanding {
if Instant::now().duration_since(sent_at) >= self.keepalive_config.ack_timeout {
self.shared.closed.store(true, Ordering::SeqCst);
self.shared.sender.close();
return Err(Error::new(
ErrorKind::Transport,
"http2 keepalive ping timed out",
));
}
}
Ok(())
}
    /// When idle keepalive is enabled and the connection has been quiet past
    /// the idle timeout, send a PING and schedule a wake thread that fires at
    /// the ACK deadline so the timeout is enforced even if nothing else wakes
    /// the loop.
    async fn handle_keepalive_timer(&mut self) -> Result<()> {
        self.check_keepalive_timeout()?;
        let Some(idle_timeout) = self.keepalive_config.idle_timeout else {
            return Ok(());
        };
        // Only one PING outstanding at a time, and only while truly idle.
        if self.keepalive_outstanding.is_some()
            || !self.pending.is_empty()
            || !self.is_keepalive_idle()
        {
            return Ok(());
        }
        if Instant::now().duration_since(
            *self
                .shared
                .last_used
                .lock()
                .unwrap_or_else(|err| err.into_inner()),
        ) < idle_timeout
        {
            return Ok(());
        }
        // Unique payload lets the ACK be matched to this PING.
        let payload = self.next_keepalive_nonce.to_be_bytes();
        self.next_keepalive_nonce = self.next_keepalive_nonce.wrapping_add(1);
        write_ping(&mut self.core.stream, &payload, self.core.timeout_config).await?;
        self.keepalive_outstanding = Some((payload, Instant::now()));
        let sender = self.shared.sender.clone();
        let ack_timeout = self.keepalive_config.ack_timeout;
        std::thread::Builder::new()
            .name("h2-keepalive-ack-timeout".to_owned())
            .spawn(move || {
                async_io::block_on(async move {
                    Timer::after(ack_timeout).await;
                    let _ = sender.try_send(ConnectionCommand::Drive);
                });
            })
            // Spawn failure is non-fatal; another wake will check the timeout.
            .ok();
        Ok(())
    }
fn drain_commands(&mut self) {
loop {
match self.receiver.try_recv() {
Ok(command) => self.enqueue_command(command),
Err(TryRecvError::Empty) | Err(TryRecvError::Closed) => break,
}
}
}
fn enqueue_command(&mut self, command: ConnectionCommand) {
match command {
ConnectionCommand::Execute(command) => self.pending.push_back(command),
ConnectionCommand::CancelStream { stream_id } => {
if self.in_flight.contains_key(&stream_id)
&& !self.locally_reset_streams.contains(&stream_id)
&& !self.pending_resets.contains(&stream_id)
{
self.pending_resets.push_back(stream_id);
}
}
ConnectionCommand::Drive => {}
}
}
async fn flush_pending_resets(&mut self) -> Result<()> {
while let Some(stream_id) = self.pending_resets.pop_front() {
self.cancel_stream(stream_id).await?;
}
Ok(())
}
async fn open_pending_streams(&mut self) -> Result<()> {
while self.accepting_new_streams
&& self.core.state.can_open_new_stream(self.in_flight.len())
{
let Some(command) = self.pending.pop_front() else {
break;
};
self.start_stream(command).await?;
}
Ok(())
}
    /// Open a new stream for `command`: allocate a stream id, build and send
    /// the HEADERS (plus fingerprint PRIORITY frames around them), then
    /// register the stream in `in_flight`.
    ///
    /// Error handling is tiered: failures before anything reaches the wire
    /// fail just this command; a failed HEADERS write poisons the whole
    /// connection (HPACK state may be torn mid-block) and fails everything.
    async fn start_stream(&mut self, command: ExecuteCommand) -> Result<()> {
        let stream_id = self.core.state.open_stream();
        command
            .request
            .body
            .attach_waker(self.shared.sender.clone());
        let body_len = command.request.body.content_length();
        // Empty buffered bodies let HEADERS carry END_STREAM directly.
        let end_stream = command.request.body.ends_stream_with_headers();
        let upload_total = command.request.body.upload_total();
        let pseudo_header_order = command
            .request
            .http2_fingerprint
            .as_ref()
            .map(|fp| fp.pseudo_header_order.as_slice());
        let regular_header_order = command
            .request
            .http2_fingerprint
            .as_ref()
            .map(|fp| fp.regular_header_order.as_slice());
        let request_headers = match build_request_header_list(
            command.request.method,
            &command.request.url,
            &command.request.headers,
            &command.request.cookies,
            command.request.compression_mode,
            body_len,
            pseudo_header_order,
            regular_header_order,
        ) {
            Ok(headers) => headers,
            Err(err) => {
                // Nothing on the wire yet: fail only this command.
                self.fail_pending_command(command, &err);
                return Err(err);
            }
        };
        // Reject early rather than let the peer kill the stream.
        if let Some(limit) = self.core.state.peer_max_header_list_size {
            let encoded_size = header_list_size(&request_headers);
            if encoded_size > limit {
                let err = Error::new(
                    ErrorKind::Transport,
                    format!(
                        "http2 request headers exceed peer SETTINGS_MAX_HEADER_LIST_SIZE: {encoded_size} > {limit}"
                    ),
                );
                self.fail_pending_command(command, &err);
                return Err(err);
            }
        }
        // Fingerprint-driven PRIORITY frames that browsers emit before HEADERS.
        if let Some(priorities) = command
            .request
            .http2_fingerprint
            .as_ref()
            .map(|fp| fp.priorities.as_slice())
        {
            if let Err(err) = write_priority_fingerprint_frames(
                &mut self.core.stream,
                priorities,
                Http2PriorityPhase::BeforeHeaders,
                stream_id,
                self.core.timeout_config,
            )
            .await
            {
                self.fail_pending_command(command, &err);
                return Err(err);
            }
        }
        let header_block = match self
            .core
            .header_codec
            .encode_request(&request_headers, self.core.state.peer_header_table_size)
        {
            Ok(header_block) => header_block,
            Err(err) => {
                self.fail_pending_command(command, &err);
                return Err(err);
            }
        };
        if let Err(err) = write_headers_frames(
            &mut self.core.stream,
            stream_id,
            &header_block,
            end_stream,
            self.core.timeout_config,
            self.core.state.peer_max_frame_size,
        )
        .await
        {
            // A partial HEADERS write corrupts connection state: close the
            // connection and fail every request on it.
            self.shared.closed.store(true, Ordering::SeqCst);
            self.shared.sender.close();
            self.pending.push_front(command);
            self.fail_all(&err);
            return Err(err);
        }
        // Fingerprint PRIORITY frames emitted after HEADERS.
        if let Some(priorities) = command
            .request
            .http2_fingerprint
            .as_ref()
            .map(|fp| fp.priorities.as_slice())
        {
            if let Err(err) = write_priority_fingerprint_frames(
                &mut self.core.stream,
                priorities,
                Http2PriorityPhase::AfterHeaders,
                stream_id,
                self.core.timeout_config,
            )
            .await
            {
                self.fail_pending_command(command, &err);
                return Err(err);
            }
        }
        let connection_sender = self.shared.sender.clone();
        let ExecuteCommand {
            request,
            response_tx,
        } = command;
        let PreparedRequest {
            method,
            url,
            headers: _,
            cookies: _,
            timeout_config: _,
            protocol_policy: _,
            progress_callback,
            progress_config,
            h2_keepalive_config: _,
            tls_config: _,
            prior_knowledge_h2c: _,
            compression_mode,
            proxy: _,
            body,
            http2_fingerprint: _,
            emulation_profile: _,
        } = request;
        // May spawn the streaming-upload thread; the returned Err still tears
        // down the task via the caller, so the opened stream does not leak.
        let request_body = match body.into_active(connection_sender) {
            Ok(body) => body,
            Err(err) => {
                let _ = response_tx.try_send(Err(connection_error(&err)));
                self.release_request_slot();
                return Err(err);
            }
        };
        let mut upload_progress = progress_callback.clone().map(|callback| {
            ProgressReporter::new(
                callback,
                ProgressPhase::Upload,
                upload_total,
                progress_config,
            )
        });
        // END_STREAM on HEADERS means the (empty) upload is already complete.
        if end_stream {
            if let Some(reporter) = upload_progress.as_mut() {
                reporter.finish();
            }
        }
        let download_progress = progress_callback.clone().map(|callback| {
            ProgressReporter::new(callback, ProgressPhase::Download, None, progress_config)
        });
        self.in_flight.insert(
            stream_id,
            ActiveStream {
                url,
                request_body,
                sent: 0,
                end_stream_sent: end_stream,
                outbound_window: self.core.state.initial_stream_window(),
                upload_progress,
                download_progress,
                response_state: ResponseStreamState::new(method, compression_mode),
                response_tx,
                response_sent: false,
                body_tx: None,
                deferred_trailers: None,
                inbound_stream_consumed: 0,
                local_stream_window: self.core.state.local_initial_stream_window_size,
            },
        );
        Ok(())
    }
    /// Write as many DATA frames as flow control allows, round-robining over
    /// the in-flight streams.
    ///
    /// Frame sizes are bounded by the connection window, the per-stream window
    /// (inside `next_outbound_frame`), and the peer's max frame size; the
    /// connection window is debited only after a successful write.
    async fn flush_outbound_data(&mut self) -> Result<()> {
        // Snapshot ids: streams may be removed from `in_flight` mid-loop.
        let stream_ids = self.in_flight.keys().copied().collect::<Vec<_>>();
        for stream_id in stream_ids {
            loop {
                let connection_capacity = self.core.state.available_connection_capacity();
                // Connection window exhausted: no stream can send; stop entirely.
                if connection_capacity == 0 {
                    return Ok(());
                }
                let frame = match {
                    let Some(stream) = self.in_flight.get_mut(&stream_id) else {
                        break;
                    };
                    stream.next_outbound_frame(
                        connection_capacity,
                        self.core.state.peer_max_frame_size,
                    )
                } {
                    Ok(frame) => frame,
                    Err(err) => {
                        // Body production failed: reset and fail just this stream.
                        self.fail_request_stream(stream_id, err).await?;
                        break;
                    }
                };
                // None => this stream has nothing sendable right now.
                let Some(frame) = frame else {
                    break;
                };
                write_data_frame(
                    &mut self.core.stream,
                    stream_id,
                    &frame.payload,
                    frame.end_stream,
                    self.core.timeout_config,
                )
                .await?;
                self.core
                    .state
                    .consume_connection_capacity(frame.payload.len())?;
                let Some(stream) = self.in_flight.get_mut(&stream_id) else {
                    break;
                };
                stream.record_sent(frame.payload.len(), frame.end_stream)?;
                if frame.end_stream {
                    break;
                }
            }
        }
        Ok(())
    }
fn track_locally_reset(&mut self, stream_id: u32) {
const MAX_LOCALLY_RESET: usize = 1000;
if self.locally_reset_streams.len() >= MAX_LOCALLY_RESET {
if let Some(&victim) = self.locally_reset_streams.iter().next() {
self.locally_reset_streams.remove(&victim);
}
}
self.locally_reset_streams.insert(stream_id);
}
    /// Locally cancel `stream_id`: abort any streaming upload, drop body and
    /// trailer channels, record the reset, send RST_STREAM(CANCEL), and free
    /// the request slot. A no-op for streams already gone.
    async fn cancel_stream(&mut self, stream_id: u32) -> Result<()> {
        let Some(mut stream) = self.in_flight.remove(&stream_id) else {
            // Already finished elsewhere; make sure no stale reset marker lingers.
            self.locally_reset_streams.remove(&stream_id);
            return Ok(());
        };
        stream.abort_request_body();
        stream.body_tx = None;
        stream.deferred_trailers = None;
        self.track_locally_reset(stream_id);
        let result = write_rst_stream(
            &mut self.core.stream,
            stream_id,
            H2_ERROR_CANCEL,
            self.core.timeout_config,
        )
        .await;
        // Release the slot even if the RST write failed.
        self.release_request_slot();
        result
    }
    /// Fail a single stream with `err`: drop its channels, record the local
    /// reset, send RST_STREAM(CANCEL), then deliver the error to the caller.
    /// Unlike `cancel_stream`, the caller is notified of the failure.
    async fn fail_request_stream(&mut self, stream_id: u32, err: Error) -> Result<()> {
        let Some(mut stream) = self.in_flight.remove(&stream_id) else {
            self.locally_reset_streams.remove(&stream_id);
            return Ok(());
        };
        stream.body_tx = None;
        stream.deferred_trailers = None;
        self.track_locally_reset(stream_id);
        let reset_result = write_rst_stream(
            &mut self.core.stream,
            stream_id,
            H2_ERROR_CANCEL,
            self.core.timeout_config,
        )
        .await;
        // Deliver `err` (the root cause) to the waiter before surfacing any
        // secondary write failure from the RST itself.
        self.fail_active_stream(stream, err);
        reset_result
    }
/// Dispatch a single inbound frame to its type-specific handler.
async fn process_frame(&mut self, frame: Frame) -> Result<()> {
    match frame.frame_type {
        FrameType::Data => self.handle_data(frame).await,
        FrameType::Headers => self.handle_headers(frame).await,
        FrameType::Settings => self.handle_settings(frame).await,
        FrameType::Ping => self.handle_ping(frame).await,
        FrameType::PushPromise => self.handle_push_promise(frame).await,
        FrameType::RstStream => self.handle_rst_stream(frame),
        FrameType::GoAway => self.handle_goaway(frame),
        FrameType::WindowUpdate => self.handle_window_update(frame),
        // CONTINUATION is only legal while a header block is being read,
        // and the HEADERS/PUSH_PROMISE readers consume those themselves.
        FrameType::Continuation => Err(Error::new(
            ErrorKind::Transport,
            "unexpected standalone http2 CONTINUATION frame",
        )),
        // PRIORITY and unknown frame types are ignored.
        FrameType::Priority | FrameType::Unknown => Ok(()),
    }
}
/// Reject a server push: decode its header block (to keep the shared HPACK
/// decoder state in sync) and refuse the promised stream.
async fn handle_push_promise(&mut self, frame: Frame) -> Result<()> {
    let parent_stream_id = frame.stream_id;
    // PUSH_PROMISE must reference a stream we opened (or recently reset).
    if !self.in_flight.contains_key(&parent_stream_id)
        && !self.locally_reset_streams.contains(&parent_stream_id)
    {
        return Err(Error::new(
            ErrorKind::Transport,
            format!("http2 PUSH_PROMISE received for unknown stream {parent_stream_id}"),
        ));
    }
    let promised = read_push_promise(
        &mut self.core.stream,
        frame,
        self.core.timeout_config,
        self.core.state.peer_max_frame_size,
    )
    .await?;
    // Decoding is mandatory even though the result is discarded: HPACK is
    // stateful, so skipping a block would corrupt later decodes.
    let _ = self.core.header_codec.decode_block(&promised.fragment)?;
    self.track_locally_reset(promised.promised_stream_id);
    write_rst_stream(
        &mut self.core.stream,
        promised.promised_stream_id,
        H2_ERROR_REFUSED_STREAM,
        self.core.timeout_config,
    )
    .await
}
/// Apply a peer SETTINGS frame and acknowledge it; SETTINGS ACK frames
/// carry no payload to apply and are ignored.
async fn handle_settings(&mut self, frame: Frame) -> Result<()> {
    let settings = frame.settings()?;
    if frame.flags.contains(Flags::ACK) {
        return Ok(());
    }
    self.apply_settings(&settings)?;
    write_settings_ack(&mut self.core.stream, self.core.timeout_config).await
}
/// Answer peer PINGs and match PING ACKs against an outstanding keepalive
/// probe.
async fn handle_ping(&mut self, frame: Frame) -> Result<()> {
    let payload = frame.validated_ping_payload()?.to_vec();
    if !frame.flags.contains(Flags::ACK) {
        // Peer-initiated ping: echo the opaque payload back with ACK set.
        return write_ping_ack(&mut self.core.stream, &payload, self.core.timeout_config).await;
    }
    // An ACK clears our keepalive marker only when the opaque payload
    // round-tripped unchanged.
    if let Some((expected, _)) = self.keepalive_outstanding {
        if payload.as_slice() == expected {
            self.keepalive_outstanding = None;
        }
    }
    Ok(())
}
/// Process a HEADERS frame (plus its CONTINUATIONs) for a response stream.
async fn handle_headers(&mut self, frame: Frame) -> Result<()> {
    let stream_id = frame.stream_id;
    let header_block = read_header_block(
        &mut self.core.stream,
        frame,
        self.core.timeout_config,
        self.core.state.peer_max_frame_size,
    )
    .await?;
    // Always decode to keep the shared HPACK decoder state consistent,
    // even when the stream was reset locally and the result is discarded.
    let decoded = self
        .core
        .header_codec
        .decode_block(&header_block.fragment)?;
    if self.locally_reset_streams.contains(&stream_id) {
        if header_block.end_stream {
            // Peer finished the reset stream; the marker can go.
            self.locally_reset_streams.remove(&stream_id);
        }
        let _ = decoded;
        return Ok(());
    }
    let stream = self.in_flight.get_mut(&stream_id).ok_or_else(|| {
        Error::new(
            ErrorKind::Transport,
            format!("http2 response headers received for unknown stream {stream_id}"),
        )
    })?;
    // The block may be informational (1xx), the final response, or trailers.
    let block_kind = stream.response_state.apply_headers(decoded)?;
    if matches!(block_kind, HeaderBlockKind::Trailers) && !header_block.end_stream {
        return Err(Error::new(
            ErrorKind::Transport,
            "http2 trailing headers must end the stream",
        ));
    }
    if header_block.end_stream {
        if matches!(block_kind, HeaderBlockKind::Informational) {
            return Err(Error::new(
                ErrorKind::Transport,
                "informational http2 response cannot end the stream",
            ));
        }
        stream.response_state.mark_end_stream();
    }
    // Hand the response out as soon as streaming is possible, then
    // finalize immediately if END_STREAM already arrived.
    self.maybe_send_streaming_response(stream_id)?;
    self.finish_stream_if_complete(stream_id).await
}
/// Process a DATA frame: deliver the payload to the response (buffered or
/// streaming) and replenish the inbound flow-control windows.
async fn handle_data(&mut self, frame: Frame) -> Result<()> {
    let stream_id = frame.stream_id;
    let end_stream = frame.is_end_stream();
    let payload = frame.into_data_payload()?;
    let payload_len = payload.len();
    let payload_bytes = Bytes::from(payload);
    let locally_reset = self.locally_reset_streams.contains(&stream_id);
    // Work out where the payload goes; `None` means it is dropped (reset
    // stream, bodiless response) or aggregated inside `response_state`.
    let body_tx = if locally_reset {
        None
    } else {
        let stream = self.in_flight.get_mut(&stream_id).ok_or_else(|| {
            Error::new(
                ErrorKind::Transport,
                format!("http2 DATA frame received for unknown stream {stream_id}"),
            )
        })?;
        if !stream.response_state.can_receive_data() {
            return Err(Error::new(
                ErrorKind::Transport,
                "http2 DATA frame arrived before final response headers",
            ));
        }
        stream.response_state.record_body_bytes(payload_len);
        if !stream.response_state.allows_response_body() {
            // Bodiless response: only a zero-length DATA frame is tolerated.
            if payload_len != 0 {
                return Err(Error::new(
                    ErrorKind::Transport,
                    "http2 response body is not allowed for this request or status",
                ));
            }
            None
        } else {
            if let Some(reporter) = stream.download_progress.as_mut() {
                reporter.record(payload_len);
            }
            if stream.response_sent {
                stream.body_tx.clone()
            } else {
                // Response not handed out yet: aggregate in response_state.
                stream.response_state.push_body(&payload_bytes);
                None
            }
        }
    };
    if let Some(body_tx) = body_tx {
        // A failed send means the caller dropped the body stream; stop
        // forwarding further chunks and trailers.
        if body_tx.try_send(Ok(payload_bytes)).is_err() {
            if let Some(stream) = self.in_flight.get_mut(&stream_id) {
                stream.body_tx = None;
                stream.deferred_trailers = None;
            }
        }
    }
    // Inbound flow control: increments are batched and become non-zero
    // only once half of the respective window has been consumed.
    let connection_increment = self.core.state.record_inbound_data(payload_len)?;
    let stream_increment = if !locally_reset {
        if let Some(stream) = self.in_flight.get_mut(&stream_id) {
            let size = payload_len as u32;
            stream.inbound_stream_consumed =
                stream.inbound_stream_consumed.saturating_add(size);
            let threshold = stream.local_stream_window / 2;
            if stream.inbound_stream_consumed > threshold {
                let inc = stream.inbound_stream_consumed;
                stream.inbound_stream_consumed = 0;
                inc
            } else {
                0
            }
        } else {
            0
        }
    } else {
        0
    };
    // NOTE(review): increments of 0 are still passed down here; presumably
    // write_window_update suppresses zero-increment frames (RFC 9113
    // forbids WINDOW_UPDATE with increment 0) — confirm in the frame module.
    write_window_update(
        &mut self.core.stream,
        0,
        connection_increment,
        self.core.timeout_config,
    )
    .await?;
    write_window_update(
        &mut self.core.stream,
        stream_id,
        stream_increment,
        self.core.timeout_config,
    )
    .await?;
    if locally_reset {
        if end_stream {
            // Peer finished the reset stream; forget the reset marker.
            self.locally_reset_streams.remove(&stream_id);
        }
        return Ok(());
    }
    if end_stream {
        if let Some(stream) = self.in_flight.get_mut(&stream_id) {
            stream.response_state.mark_end_stream();
        }
    }
    self.finish_stream_if_complete(stream_id).await
}
fn handle_rst_stream(&mut self, frame: Frame) -> Result<()> {
let stream_id = frame.stream_id;
let error_code = frame.rst_stream_error_code()?;
if self.locally_reset_streams.remove(&stream_id) {
return Ok(());
}
if let Some(stream) = self.in_flight.remove(&stream_id) {
self.fail_active_stream(
stream,
Error::new(
ErrorKind::Transport,
format!("http2 stream reset by peer: error_code={error_code}"),
),
);
}
Ok(())
}
/// Handle GOAWAY: stop accepting new work, fail everything queued, and
/// fail in-flight streams the peer will not process.
fn handle_goaway(&mut self, frame: Frame) -> Result<()> {
    let GoAwayFrame {
        last_stream_id,
        error_code,
    } = frame.goaway()?;
    let reason = format!(
        "http2 connection closed by peer: last_stream_id={last_stream_id}, error_code={error_code}"
    );
    self.accepting_new_streams = false;
    self.shared.closed.store(true, Ordering::SeqCst);
    self.shared
        .max_concurrent_streams
        .store(0, Ordering::SeqCst);
    self.shared.sender.close();
    // Requests that never made it onto the wire are safe to retry on a
    // fresh connection, hence StaleConnection.
    let pending_error = Error::new(ErrorKind::StaleConnection, reason.clone());
    while let Some(command) = self.pending.pop_front() {
        self.fail_pending_command(command, &pending_error);
    }
    self.fail_queued_commands(&pending_error);
    // A non-zero error code aborts the whole connection; the task's error
    // path (fail_all) then fails any remaining in-flight streams.
    if error_code != 0 {
        return Err(Error::new(ErrorKind::Transport, reason));
    }
    // Graceful GOAWAY: only streams above last_stream_id were discarded by
    // the peer; lower-numbered streams may still complete normally.
    let rejected_streams = self
        .in_flight
        .keys()
        .copied()
        .filter(|stream_id| *stream_id > last_stream_id)
        .collect::<Vec<_>>();
    for stream_id in rejected_streams {
        if let Some(stream) = self.in_flight.remove(&stream_id) {
            self.fail_active_stream(
                stream,
                Error::new(
                    ErrorKind::StaleConnection,
                    format!(
                        "http2 connection closed by peer before stream {stream_id} completed: last_stream_id={last_stream_id}, error_code={error_code}"
                    ),
                ),
            );
        }
    }
    Ok(())
}
/// Credit outbound flow-control windows from a peer WINDOW_UPDATE frame.
fn handle_window_update(&mut self, frame: Frame) -> Result<()> {
    let stream_id = frame.stream_id;
    let increment = frame.window_update_increment()?;
    // Stream 0 credits the shared connection-level window.
    if stream_id == 0 {
        return self.core.state.apply_connection_window_update(increment);
    }
    // Updates for reset or already-finished streams are ignored.
    if self.locally_reset_streams.contains(&stream_id) {
        return Ok(());
    }
    match self.in_flight.get_mut(&stream_id) {
        Some(stream) => stream.apply_window_update(increment),
        None => Ok(()),
    }
}
/// Apply peer SETTINGS values to connection state and to live streams.
fn apply_settings(&mut self, settings: &[Setting]) -> Result<()> {
    for setting in settings {
        match *setting {
            Setting::HeaderTableSize(size) => {
                self.core.state.peer_header_table_size = size;
            }
            // Server push is refused elsewhere, so the flag is irrelevant.
            Setting::EnablePush(_enabled) => {
            }
            Setting::MaxConcurrentStreams(limit) => {
                self.core.state.peer_max_concurrent_streams = limit as usize;
            }
            Setting::InitialWindowSize(size) => {
                // A changed initial window retroactively shifts every open
                // stream's window by the delta, which may be negative.
                let delta = size - self.core.state.peer_initial_window_size;
                for stream in self.in_flight.values_mut() {
                    stream.adjust_outbound_window(delta)?;
                }
                self.core.state.peer_initial_window_size = size;
            }
            Setting::MaxFrameSize(size) => {
                self.core.state.peer_max_frame_size = size;
            }
            Setting::MaxHeaderListSize(limit) => {
                self.core.state.peer_max_header_list_size = Some(limit as usize);
            }
            // Unknown settings must be ignored.
            Setting::Unknown(identifier, value) => {
                let _ = (identifier, value);
            }
        }
    }
    // Publish the (possibly changed) concurrency limit to request handles.
    self.shared.max_concurrent_streams.store(
        self.core.state.peer_max_concurrent_streams,
        Ordering::SeqCst,
    );
    Ok(())
}
/// If the stream's final headers have arrived and the body can be
/// streamed, hand the `Response` to the waiting caller.
fn maybe_send_streaming_response(&mut self, stream_id: u32) -> Result<()> {
    let Some(stream) = self.in_flight.get_mut(&stream_id) else {
        return Ok(());
    };
    if let Some(response) =
        stream.prepare_streaming_response(stream_id, self.shared.sender.clone())?
    {
        let response_tx = stream.response_tx.clone();
        // Caller gone: stop accumulating body/trailers for this stream.
        if response_tx.try_send(Ok(response)).is_err() {
            stream.body_tx = None;
            stream.deferred_trailers = None;
        }
    }
    Ok(())
}
/// Finalize a stream whose response state reports completion: deliver the
/// buffered response (or close the streaming body channel), publish
/// trailers, and free the request slot.
async fn finish_stream_if_complete(&mut self, stream_id: u32) -> Result<()> {
    let complete = self
        .in_flight
        .get(&stream_id)
        .map(|stream| stream.response_state.is_complete())
        .unwrap_or(false);
    if !complete {
        return Ok(());
    }
    let mut stream = self
        .in_flight
        .remove(&stream_id)
        .expect("stream must exist");
    // The response finished while the request body was still uploading:
    // cancel the remainder with RST_STREAM(CANCEL), best effort.
    let cancel_unfinished_request = !stream.end_stream_sent;
    if cancel_unfinished_request {
        stream.abort_request_body();
        self.track_locally_reset(stream_id);
        let _ = write_rst_stream(
            &mut self.core.stream,
            stream_id,
            H2_ERROR_CANCEL,
            self.core.timeout_config,
        )
        .await;
    }
    if let Some(reporter) = stream.upload_progress.as_mut() {
        reporter.finish();
    }
    if let Some(reporter) = stream.download_progress.as_mut() {
        reporter.finish();
    }
    if stream.response_sent {
        // Streaming response: report validation failures on the body
        // channel, then publish trailers and close the channel so the
        // caller's stream ends.
        if let Err(err) = stream.response_state.validate_received_body() {
            if let Some(body_tx) = stream.body_tx.take() {
                let _ = body_tx.try_send(Err(err));
            }
            self.release_request_slot();
            return Ok(());
        }
        if let Some(trailers) = stream.response_state.trailers.clone() {
            if let Some(deferred_trailers) = stream.deferred_trailers.take() {
                let _ = deferred_trailers.set(trailers);
            }
        }
        drop(stream.body_tx.take());
        self.release_request_slot();
        return Ok(());
    }
    // Buffered response: build it now and send it to the waiting caller.
    let url = stream.url.clone();
    let response_tx = stream.response_tx.clone();
    let result = stream.response_state.into_response(url);
    self.send_stream_result(response_tx, result);
    self.release_request_slot();
    Ok(())
}
/// Deliver the final request outcome to the waiting caller; a receiver
/// that has already gone away is silently ignored.
fn send_stream_result(&self, response_tx: Sender<Result<Response>>, result: Result<Response>) {
    drop(response_tx.try_send(result));
}
/// Reject a queued request that never reached the wire: stop its upload,
/// report the connection error, and give back its concurrency slot.
fn fail_pending_command(&self, command: ExecuteCommand, err: &Error) {
    let failure = connection_error(err);
    command.request.body.abort_upload();
    let _ = command.response_tx.try_send(Err(failure));
    self.release_request_slot();
}
/// Report `err` on whichever channel the caller is currently waiting on
/// and release the stream's concurrency slot.
fn fail_active_stream(&self, mut stream: ActiveStream, err: Error) {
    stream.abort_request_body();
    if let Some(reporter) = stream.upload_progress.as_mut() {
        reporter.finish();
    }
    if let Some(reporter) = stream.download_progress.as_mut() {
        reporter.finish();
    }
    // Once the response has been handed out, errors travel on the body
    // channel; before that, on the response channel itself.
    match stream.response_sent {
        true => {
            if let Some(body_tx) = stream.body_tx.take() {
                let _ = body_tx.try_send(Err(connection_error(&err)));
            }
        }
        false => {
            let _ = stream.response_tx.try_send(Err(connection_error(&err)));
        }
    }
    self.release_request_slot();
}
/// Drain the command channel, failing every not-yet-started request.
/// CancelStream and Drive commands are moot once the connection is gone.
fn fail_queued_commands(&mut self, err: &Error) {
    while let Ok(command) = self.receiver.try_recv() {
        if let ConnectionCommand::Execute(command) = command {
            self.fail_pending_command(command, err);
        }
    }
}
/// Give back one request slot and refresh the last-used timestamp.
fn release_request_slot(&self) {
    let previous = self.shared.active_requests.fetch_sub(1, Ordering::SeqCst);
    // A poisoned lock still yields usable data for this plain Instant.
    *self
        .shared
        .last_used
        .lock()
        .unwrap_or_else(|err| err.into_inner()) = Instant::now();
    if previous == 1 {
        // Connection just went idle: schedule a wakeup so the event loop
        // can re-evaluate the keepalive idle timeout. A dedicated thread is
        // spawned because this synchronous context has no executor handle;
        // spawn failure is deliberately ignored (best effort).
        if let Some(idle_timeout) = self.shared.keepalive_idle_timeout {
            let sender = self.shared.sender.clone();
            std::thread::Builder::new()
                .name("h2-keepalive-idle-wake".to_owned())
                .spawn(move || {
                    async_io::block_on(async move {
                        Timer::after(idle_timeout).await;
                        let _ = sender.try_send(ConnectionCommand::Drive);
                    });
                })
                .ok();
        }
    }
}
fn fail_all(&mut self, err: &Error) {
while let Some(command) = self.pending.pop_front() {
self.fail_pending_command(command, err);
}
let streams = self
.in_flight
.drain()
.map(|(_, stream)| stream)
.collect::<Vec<_>>();
for stream in streams {
self.fail_active_stream(stream, connection_error(err));
}
self.fail_queued_commands(err);
}
}
/// Drive a connection task to completion, then mark the shared handle
/// closed so no new requests are routed to this connection.
async fn run_connection_task(
    core: ConnectionCore,
    keepalive_config: H2KeepAliveConfig,
    receiver: Receiver<ConnectionCommand>,
    shared: Arc<SharedConnection>,
) {
    let mut task = ConnectionTask::new(core, keepalive_config, receiver, Arc::clone(&shared));
    if let Err(err) = task.run().await {
        // Close first so nothing re-enqueues onto a dead connection while
        // errors are being delivered to waiters.
        shared.closed.store(true, Ordering::SeqCst);
        shared.sender.close();
        task.fail_all(&err);
    }
    shared.closed.store(true, Ordering::SeqCst);
    shared.sender.close();
}
/// Same as `run_connection_task`, but the task starts with one stream
/// already open (the request carried by an h2c upgrade).
async fn run_connection_task_with_initial_stream(
    core: ConnectionCore,
    keepalive_config: H2KeepAliveConfig,
    receiver: Receiver<ConnectionCommand>,
    shared: Arc<SharedConnection>,
    stream_id: u32,
    stream: ActiveStream,
) {
    let mut task = ConnectionTask::new_with_initial_stream(
        core,
        keepalive_config,
        receiver,
        Arc::clone(&shared),
        stream_id,
        stream,
    );
    if let Err(err) = task.run().await {
        // Close first so nothing re-enqueues onto a dead connection while
        // errors are being delivered to waiters.
        shared.closed.store(true, Ordering::SeqCst);
        shared.sender.close();
        task.fail_all(&err);
    }
    shared.closed.store(true, Ordering::SeqCst);
    shared.sender.close();
}
/// One wakeup of the connection event loop: either a command from the
/// client-facing channel or an inbound frame from the socket.
enum LoopEvent {
    /// Command (or channel-closed error) from request handles.
    Command(std::result::Result<ConnectionCommand, async_channel::RecvError>),
    /// Frame read (or transport error) from the peer.
    Frame(Result<Frame>),
}
/// Drop guard attached to streaming response bodies: if the body stream is
/// dropped before being fully read, it asks the connection task to cancel
/// the underlying http2 stream.
struct CancelStreamOnDrop {
    sender: Sender<ConnectionCommand>,
    stream_id: u32,
    armed: bool,
}

impl CancelStreamOnDrop {
    /// Create an armed guard for `stream_id`.
    fn new(sender: Sender<ConnectionCommand>, stream_id: u32) -> Self {
        Self { sender, stream_id, armed: true }
    }

    /// Defuse the guard once the body completed normally.
    fn disarm(&mut self) {
        self.armed = false;
    }
}

impl Drop for CancelStreamOnDrop {
    fn drop(&mut self) {
        if self.armed {
            // Best effort: the connection task may already be gone.
            let _ = self.sender.try_send(ConnectionCommand::CancelStream {
                stream_id: self.stream_id,
            });
        }
    }
}
/// Outbound request body attached to an active stream.
enum ActiveRequestBody {
    /// Fully buffered body, sent from a single contiguous buffer.
    Bytes {
        data: Bytes,
    },
    /// Incrementally produced body pulled chunk-by-chunk from the caller.
    Stream {
        /// Yields `Ok(Some(chunk))` per chunk and `Ok(None)` at the end.
        receiver: Receiver<Result<Option<Bytes>>>,
        /// Signals the producer to stop when the stream is aborted.
        abort_sender: Sender<()>,
        /// Chunk currently being drained onto the wire.
        buffer: Bytes,
        /// How much of `buffer` has already been sent.
        offset: usize,
        /// Set once the producer reported end-of-body (or went away).
        finished: bool,
    },
}
/// A single DATA frame payload ready to be written to the socket.
struct OutboundFrame {
    payload: Bytes,
    /// Whether this frame carries the END_STREAM flag.
    end_stream: bool,
}
/// Per-request bookkeeping for a stream currently open on the connection.
struct ActiveStream {
    url: Url,
    /// Request body still being transmitted.
    request_body: ActiveRequestBody,
    /// Total request-body bytes written so far.
    sent: usize,
    /// True once a frame carrying END_STREAM was written for the request.
    end_stream_sent: bool,
    /// Outbound flow-control window; may go negative after a peer
    /// SETTINGS_INITIAL_WINDOW_SIZE decrease (see adjust_outbound_window).
    outbound_window: i32,
    upload_progress: Option<ProgressReporter>,
    download_progress: Option<ProgressReporter>,
    /// Accumulated response header/body state.
    response_state: ResponseStreamState,
    /// Channel delivering the final `Response` (or error) to the caller.
    response_tx: Sender<Result<Response>>,
    /// True after a streaming `Response` was handed to the caller.
    response_sent: bool,
    /// Channel feeding body chunks to a streaming `Response`.
    body_tx: Option<Sender<Result<Bytes>>>,
    /// Slot for trailers that arrive after the response was handed out.
    deferred_trailers: Option<Arc<OnceLock<HeaderMap>>>,
    /// Received body bytes not yet acknowledged via WINDOW_UPDATE.
    inbound_stream_consumed: u32,
    /// Stream window size we advertised to the peer.
    local_stream_window: u32,
}
impl ActiveStream {
    /// Build the synthetic stream record for a request carried over an h2c
    /// upgrade: the request (including its body) already went out over
    /// HTTP/1.1, so the body here is empty and END_STREAM counts as sent.
    fn from_upgraded_request(
        url: Url,
        request_method: Method,
        upload_progress: Option<ProgressReporter>,
        download_progress: Option<ProgressReporter>,
        response_tx: Sender<Result<Response>>,
        local_stream_window: u32,
    ) -> Self {
        Self {
            url,
            request_body: ActiveRequestBody::Bytes { data: Bytes::new() },
            sent: 0,
            // The HTTP/1.1 upgrade request already carried the full body.
            end_stream_sent: true,
            outbound_window: DEFAULT_INITIAL_WINDOW_SIZE,
            upload_progress,
            download_progress,
            response_state: ResponseStreamState::new(request_method, CompressionMode::Auto),
            response_tx,
            response_sent: false,
            body_tx: None,
            deferred_trailers: None,
            inbound_stream_consumed: 0,
            local_stream_window,
        }
    }

    /// Tell a streaming request-body producer to stop (best effort).
    fn abort_request_body(&self) {
        if let ActiveRequestBody::Stream { abort_sender, .. } = &self.request_body {
            let _ = abort_sender.try_send(());
        }
    }

    /// Construct the caller-facing `Response` once final headers are in and
    /// the body can be streamed. Returns `None` when the response was
    /// already sent, final headers are still pending, or the body must be
    /// aggregated first.
    fn prepare_streaming_response(
        &mut self,
        stream_id: u32,
        command_tx: Sender<ConnectionCommand>,
    ) -> Result<Option<Response>> {
        if self.response_sent
            || !self.response_state.received_final_headers
            || self.response_state.requires_aggregated_body()
        {
            return Ok(None);
        }
        let status = self.response_state.status.ok_or_else(|| {
            Error::new(
                ErrorKind::Transport,
                "http2 response completed without final headers",
            )
        })?;
        let headers = self.response_state.headers.clone();
        let body = if self.response_state.end_stream {
            Body::empty()
        } else {
            // Body chunks flow through this channel. Dropping the stream
            // before it completes fires CancelStreamOnDrop, which asks the
            // connection task to RST_STREAM.
            let (body_tx, body_rx) = async_channel::unbounded();
            self.body_tx = Some(body_tx);
            Body::from_stream(Box::pin(async_stream::stream! {
                let mut cancel = CancelStreamOnDrop::new(command_tx, stream_id);
                while let Ok(item) = body_rx.recv().await {
                    yield item;
                }
                cancel.disarm();
            }))
        };
        let trailers = if self.response_state.end_stream {
            TrailerState::Ready(None)
        } else {
            // Trailers may arrive after the response is handed out; expose
            // them through a deferred slot filled on stream completion.
            let deferred = Arc::new(OnceLock::new());
            self.deferred_trailers = Some(Arc::clone(&deferred));
            TrailerState::Deferred(deferred)
        };
        self.response_sent = true;
        Ok(Some(Response::new_with_trailer_state(
            status,
            Version::Http2,
            self.url.clone(),
            headers,
            trailers,
            body,
        )))
    }

    /// How many request-body bytes may be sent right now, bounded by the
    /// connection window, this stream's window, and the peer frame size.
    fn outbound_capacity(&self, connection_capacity: usize, max_frame_size: usize) -> usize {
        if self.end_stream_sent {
            return 0;
        }
        connection_capacity
            .min(self.outbound_window.max(0) as usize)
            .min(max_frame_size)
    }

    /// Produce the next DATA frame payload, or `None` when nothing can be
    /// sent yet (no window, or a streaming body has no chunk ready).
    fn next_outbound_frame(
        &mut self,
        connection_capacity: usize,
        max_frame_size: usize,
    ) -> Result<Option<OutboundFrame>> {
        let capacity = self.outbound_capacity(connection_capacity, max_frame_size);
        if capacity == 0 {
            return Ok(None);
        }
        match &mut self.request_body {
            ActiveRequestBody::Bytes { data } => {
                if self.sent >= data.len() {
                    return Ok(None);
                }
                let to_send = capacity.min(data.len() - self.sent);
                Ok(Some(OutboundFrame {
                    payload: data.slice(self.sent..self.sent + to_send),
                    end_stream: self.sent + to_send == data.len(),
                }))
            }
            ActiveRequestBody::Stream {
                receiver,
                buffer,
                offset,
                finished,
                ..
            } => {
                // Refill the chunk buffer without blocking: skip empty
                // chunks; `None` or a closed channel marks end-of-body.
                while *offset >= buffer.len() && !*finished {
                    match receiver.try_recv() {
                        Ok(Ok(Some(chunk))) if chunk.is_empty() => continue,
                        Ok(Ok(Some(chunk))) => {
                            *buffer = chunk;
                            *offset = 0;
                            break;
                        }
                        Ok(Ok(None)) | Err(TryRecvError::Closed) => {
                            *finished = true;
                            break;
                        }
                        Ok(Err(err)) => return Err(err),
                        Err(TryRecvError::Empty) => break,
                    }
                }
                if *offset < buffer.len() {
                    let to_send = capacity.min(buffer.len() - *offset);
                    Ok(Some(OutboundFrame {
                        payload: buffer.slice(*offset..*offset + to_send),
                        end_stream: *finished && *offset + to_send == buffer.len(),
                    }))
                } else if *finished {
                    // Body ended exactly on a chunk boundary: emit an empty
                    // DATA frame that only carries END_STREAM.
                    Ok(Some(OutboundFrame {
                        payload: Bytes::new(),
                        end_stream: true,
                    }))
                } else {
                    Ok(None)
                }
            }
        }
    }

    /// Account for bytes written to the socket: advance cursors, charge the
    /// stream window, and update upload progress.
    fn record_sent(&mut self, size: usize, end_stream: bool) -> Result<()> {
        let size_i32 = i32::try_from(size).map_err(|_| {
            Error::new(
                ErrorKind::Transport,
                "http2 outbound data chunk exceeded supported window size",
            )
        })?;
        self.sent += size;
        self.outbound_window -= size_i32;
        if let ActiveRequestBody::Stream { buffer, offset, .. } = &mut self.request_body {
            if size > 0 {
                *offset += size;
                if *offset >= buffer.len() {
                    // Chunk fully sent; release its memory.
                    *buffer = Bytes::new();
                    *offset = 0;
                }
            }
        }
        if let Some(reporter) = self.upload_progress.as_mut() {
            reporter.record(size);
            if end_stream {
                reporter.finish();
            }
        }
        if end_stream {
            self.end_stream_sent = true;
        }
        Ok(())
    }

    /// Retroactively shift the stream window after a peer
    /// SETTINGS_INITIAL_WINDOW_SIZE change; `delta` may be negative.
    fn adjust_outbound_window(&mut self, delta: i32) -> Result<()> {
        self.outbound_window = self.outbound_window.checked_add(delta).ok_or_else(|| {
            Error::new(
                ErrorKind::Transport,
                "http2 outbound stream window overflowed",
            )
        })?;
        if self.outbound_window > MAX_FLOW_CONTROL_WINDOW {
            return Err(Error::new(
                ErrorKind::Transport,
                "http2 outbound stream window exceeded protocol maximum",
            ));
        }
        Ok(())
    }

    /// Credit this stream's window from a peer WINDOW_UPDATE frame.
    fn apply_window_update(&mut self, increment: u32) -> Result<()> {
        let increment = i32::try_from(increment).map_err(|_| {
            Error::new(
                ErrorKind::Transport,
                "http2 WINDOW_UPDATE increment exceeded supported range",
            )
        })?;
        self.outbound_window = checked_window_add(self.outbound_window, increment)?;
        Ok(())
    }
}
/// Negotiated and local HTTP/2 connection parameters plus flow-control
/// accounting shared by all streams.
#[derive(Debug)]
struct ConnectionState {
    /// Next client-initiated (odd) stream id to hand out.
    next_stream_id: u32,
    /// Peer SETTINGS_INITIAL_WINDOW_SIZE applied to new streams.
    peer_initial_window_size: i32,
    /// Peer SETTINGS_MAX_FRAME_SIZE.
    peer_max_frame_size: usize,
    /// Peer SETTINGS_MAX_CONCURRENT_STREAMS (usize::MAX until announced).
    peer_max_concurrent_streams: usize,
    /// Peer SETTINGS_HEADER_TABLE_SIZE.
    peer_header_table_size: usize,
    /// Peer SETTINGS_MAX_HEADER_LIST_SIZE, if announced.
    peer_max_header_list_size: Option<usize>,
    /// Connection-level outbound flow-control window.
    outbound_connection_window: i32,
    /// Received bytes not yet acknowledged via a stream-0 WINDOW_UPDATE.
    inbound_connection_consumed: u32,
    /// Connection-level window size we advertise.
    local_connection_window: u32,
    /// Maximum frame size we advertise.
    local_max_frame_size: usize,
    /// Per-stream window size we advertise.
    local_initial_stream_window_size: u32,
}
impl Default for ConnectionState {
    /// HTTP/2 protocol defaults, in effect before any SETTINGS exchange.
    fn default() -> Self {
        Self {
            // Client streams are odd-numbered, starting at 1.
            next_stream_id: 1,
            peer_initial_window_size: DEFAULT_INITIAL_WINDOW_SIZE,
            peer_max_frame_size: DEFAULT_MAX_FRAME_SIZE,
            // Unlimited until the peer announces a limit.
            peer_max_concurrent_streams: usize::MAX,
            peer_header_table_size: DEFAULT_HEADER_TABLE_SIZE,
            peer_max_header_list_size: None,
            outbound_connection_window: DEFAULT_INITIAL_WINDOW_SIZE,
            inbound_connection_consumed: 0,
            local_connection_window: DEFAULT_INITIAL_WINDOW_SIZE as u32,
            local_max_frame_size: DEFAULT_MAX_FRAME_SIZE,
            local_initial_stream_window_size: DEFAULT_INITIAL_WINDOW_SIZE as u32,
        }
    }
}
impl ConnectionState {
    /// Whether a new client stream may be opened: ids not exhausted and the
    /// peer's concurrency limit not reached.
    fn can_open_new_stream(&self, in_flight: usize) -> bool {
        self.next_stream_id <= MAX_CLIENT_STREAM_ID && in_flight < self.peer_max_concurrent_streams
    }

    /// Allocate the next odd (client-initiated) stream id.
    fn open_stream(&mut self) -> u32 {
        let stream_id = self.next_stream_id;
        self.next_stream_id = self.next_stream_id.saturating_add(2);
        stream_id
    }

    /// Mark `stream_id` as already used (e.g. an h2c upgrade claims a
    /// stream out of band) so future allocations skip past it.
    fn reserve_open_stream(&mut self, stream_id: u32) {
        self.next_stream_id = self.next_stream_id.max(stream_id.saturating_add(2));
    }

    /// Initial outbound window for a freshly opened stream.
    fn initial_stream_window(&self) -> i32 {
        self.peer_initial_window_size
    }

    /// Validate and install browser-fingerprint overrides for the settings
    /// we advertise to the peer.
    fn apply_local_fingerprint(&mut self, fp: &Http2Fingerprint) -> Result<()> {
        if let Some(size) = fp.max_frame_size {
            if !(DEFAULT_MAX_FRAME_SIZE as u32..=super::frame::MAX_FRAME_SIZE_UPPER_BOUND as u32)
                .contains(&size)
            {
                return Err(Error::new(
                    ErrorKind::Transport,
                    "invalid http2 MAX_FRAME_SIZE value",
                ));
            }
            self.local_max_frame_size = size as usize;
        }
        if let Some(size) = fp.initial_window_size {
            if size > MAX_FLOW_CONTROL_WINDOW as u32 {
                return Err(Error::new(
                    ErrorKind::Transport,
                    "invalid http2 INITIAL_WINDOW_SIZE value",
                ));
            }
            self.local_initial_stream_window_size = size;
        }
        if let Some(size) = fp.initial_connection_window_size {
            if !(DEFAULT_INITIAL_WINDOW_SIZE as u32..=MAX_FLOW_CONTROL_WINDOW as u32)
                .contains(&size)
            {
                return Err(Error::new(
                    ErrorKind::Transport,
                    "invalid http2 INITIAL_CONNECTION_WINDOW_SIZE value",
                ));
            }
            self.local_connection_window = size;
        }
        Ok(())
    }

    /// Bytes we may place into the next DATA frame at connection level.
    fn available_connection_capacity(&self) -> usize {
        (self.outbound_connection_window.max(0) as usize).min(self.peer_max_frame_size)
    }

    /// Charge `size` sent bytes against the connection window.
    fn consume_connection_capacity(&mut self, size: usize) -> Result<()> {
        let size = i32::try_from(size).map_err(|_| {
            Error::new(
                ErrorKind::Transport,
                "http2 outbound data chunk exceeded supported connection window",
            )
        })?;
        self.outbound_connection_window -= size;
        Ok(())
    }

    /// Credit the connection window from a stream-0 WINDOW_UPDATE frame.
    fn apply_connection_window_update(&mut self, increment: u32) -> Result<()> {
        let increment = i32::try_from(increment).map_err(|_| {
            Error::new(
                ErrorKind::Transport,
                "http2 WINDOW_UPDATE increment exceeded supported range",
            )
        })?;
        self.outbound_connection_window =
            checked_window_add(self.outbound_connection_window, increment)?;
        Ok(())
    }

    /// Track received body bytes; returns the WINDOW_UPDATE increment to
    /// send (batched: non-zero only once half the window is consumed).
    fn record_inbound_data(&mut self, size: usize) -> Result<u32> {
        let size = u32::try_from(size).map_err(|_| {
            Error::new(
                ErrorKind::Transport,
                "http2 inbound data chunk exceeded supported window size",
            )
        })?;
        self.inbound_connection_consumed = self.inbound_connection_consumed.saturating_add(size);
        let threshold = self.local_connection_window / 2;
        if self.inbound_connection_consumed > threshold {
            let increment = self.inbound_connection_consumed;
            self.inbound_connection_consumed = 0;
            Ok(increment)
        } else {
            Ok(0)
        }
    }
}
/// Parsed HTTP/1.1 response head for an h2c upgrade attempt.
struct H2cUpgradeResponse {
    status: StatusCode,
    headers: HeaderMap,
    /// Bytes read past the header block (start of the body or h2 frames).
    remaining: Vec<u8>,
}
/// Serialize the HTTP/1.1 request that offers an h2c upgrade: caller
/// headers plus synthesized `host`/`content-length`/`accept-encoding`
/// defaults, cookies, and the `Upgrade`/`HTTP2-Settings` handshake headers.
fn encode_h2c_upgrade_request(
    method: &str,
    url: &Url,
    headers: &HeaderMap,
    cookies: &[(String, String)],
    body: &Bytes,
    settings_payload: &[u8],
) -> Result<Vec<u8>> {
    // First pass: reject unsupported framing and note which defaults need
    // to be synthesized below.
    let mut has_host = false;
    let mut has_content_length = false;
    let mut has_accept_encoding = false;
    for (name, _) in headers.iter() {
        match name.as_str() {
            "host" => has_host = true,
            "content-length" => has_content_length = true,
            "accept-encoding" => has_accept_encoding = true,
            "transfer-encoding" => {
                return Err(Error::new(
                    ErrorKind::InvalidHeaderValue,
                    "h2c upgrade does not support transfer-encoding; provide a buffered body instead",
                ));
            }
            _ => {}
        }
    }
    let mut buffer = Vec::new();
    buffer.extend_from_slice(method.as_bytes());
    buffer.extend_from_slice(b" ");
    buffer.extend_from_slice(url.path_and_query().as_bytes());
    buffer.extend_from_slice(b" HTTP/1.1\r\n");
    // Second pass: emit caller headers, dropping the connection-management
    // ones this function writes itself.
    for (name, value) in headers.iter() {
        if matches!(name.as_str(), "connection" | "upgrade" | "http2-settings") {
            continue;
        }
        buffer.extend_from_slice(name.as_str().as_bytes());
        buffer.extend_from_slice(b": ");
        buffer.extend_from_slice(value.as_str().as_bytes());
        buffer.extend_from_slice(b"\r\n");
    }
    if !has_host {
        buffer.extend_from_slice(b"host: ");
        buffer.extend_from_slice(url.authority().as_bytes());
        buffer.extend_from_slice(b"\r\n");
    }
    if !has_content_length {
        buffer.extend_from_slice(b"content-length: ");
        buffer.extend_from_slice(body.len().to_string().as_bytes());
        buffer.extend_from_slice(b"\r\n");
    }
    if !has_accept_encoding {
        buffer.extend_from_slice(b"accept-encoding: ");
        buffer.extend_from_slice(DEFAULT_ACCEPT_ENCODING.as_bytes());
        buffer.extend_from_slice(b"\r\n");
    }
    if !cookies.is_empty() {
        buffer.extend_from_slice(b"cookie: ");
        let mut first = true;
        for (name, value) in cookies {
            if !first {
                buffer.extend_from_slice(b"; ");
            }
            first = false;
            buffer.extend_from_slice(name.as_bytes());
            buffer.extend_from_slice(b"=");
            buffer.extend_from_slice(value.as_bytes());
        }
        buffer.extend_from_slice(b"\r\n");
    }
    buffer.extend_from_slice(b"connection: Upgrade, HTTP2-Settings\r\n");
    buffer.extend_from_slice(b"upgrade: h2c\r\n");
    buffer.extend_from_slice(b"http2-settings: ");
    buffer.extend_from_slice(crate::util::encode_base64url_no_pad(settings_payload).as_bytes());
    buffer.extend_from_slice(b"\r\n\r\n");
    Ok(buffer)
}
/// Read and parse the HTTP/1.1 response head to an h2c upgrade request.
///
/// Returns the status, headers, and any bytes read past the header block
/// (which belong to the response body or the first h2 frames).
async fn read_h2c_upgrade_response<S>(
    stream: &mut S,
    timeout_config: TimeoutConfig,
) -> Result<H2cUpgradeResponse>
where
    S: futures_lite::io::AsyncRead + Unpin + ?Sized,
{
    let mut buffer = Vec::new();
    let mut header_end = None;
    // Accumulate until the header terminator appears (or EOF).
    while header_end.is_none() {
        let mut chunk = [0_u8; 1024];
        let read = with_timeout_io(
            timeout_config.read,
            stream.read(&mut chunk),
            "read timed out",
        )
        .await?;
        if read == 0 {
            break;
        }
        buffer.extend_from_slice(&chunk[..read]);
        header_end = find_header_block_end(&buffer);
    }
    let header_end = header_end.ok_or_else(|| {
        Error::new(
            ErrorKind::Transport,
            "h2c upgrade response headers are incomplete",
        )
    })?;
    let head = &buffer[..header_end];
    // Skip the 4-byte "\r\n\r\n" terminator; keep everything after it.
    let remaining = buffer[header_end + 4..].to_vec();
    let head_text = std::str::from_utf8(head).map_err(|err| {
        Error::with_source(
            ErrorKind::Decode,
            "h2c upgrade response headers are not valid utf-8",
            err,
        )
    })?;
    let mut lines = head_text.split("\r\n");
    let status_line = lines
        .next()
        .ok_or_else(|| Error::new(ErrorKind::Transport, "missing h2c upgrade status line"))?;
    let status = parse_http1_status_line(status_line)?;
    let mut headers = HeaderMap::new();
    for line in lines {
        if line.is_empty() {
            continue;
        }
        let (name, value) = line.split_once(':').ok_or_else(|| {
            Error::new(
                ErrorKind::Transport,
                format!("invalid h2c upgrade header line: {line}"),
            )
        })?;
        headers.append(name.trim(), value.trim())?;
    }
    Ok(H2cUpgradeResponse {
        status,
        headers,
        remaining,
    })
}
/// Build the HTTP/1.1 `Response` after the server declined the h2c upgrade
/// and answered the original request over HTTP/1.1 instead.
///
/// `upgrade.remaining` holds body bytes read together with the response
/// head. Framing is chosen in order: chunked transfer-encoding, explicit
/// content-length, then read-to-EOF.
async fn read_http1_response_after_upgrade_rejection(
    stream: async_net::TcpStream,
    request_method: Method,
    url: Url,
    timeout_config: TimeoutConfig,
    progress_callback: Option<ProgressCallback>,
    progress_config: ProgressConfig,
    upgrade: H2cUpgradeResponse,
) -> Result<Response> {
    let H2cUpgradeResponse {
        status,
        mut headers,
        remaining,
    } = upgrade;
    // Requests/statuses without a payload short-circuit to an empty body.
    let allows_body = response_body_allowed(request_method, status);
    if !allows_body {
        return Ok(Response::new(
            status,
            Version::Http11,
            url,
            headers,
            None,
            Body::empty(),
        ));
    }
    let mut stream: BoxedStream = Box::new(stream);
    // Compressed bodies are buffered so they can be decoded in one pass;
    // uncompressed bodies are streamed to the caller.
    let has_content_encoding = headers.get("content-encoding").is_some();
    if header_contains_token(&headers, "transfer-encoding", "chunked") {
        if has_content_encoding {
            let (body_bytes, trailers) = read_http1_chunked_body_bytes(
                &mut stream,
                remaining,
                timeout_config,
                progress_callback,
                progress_config,
            )
            .await?;
            let body = decode_response_body(&mut headers, body_bytes)?;
            return Ok(Response::new(
                status,
                Version::Http11,
                url,
                headers,
                trailers,
                body,
            ));
        }
        // Streamed chunked body: trailers become available only once the
        // stream finishes, via the deferred slot.
        let trailers = Arc::new(OnceLock::new());
        let body = read_http1_chunked_body_stream(
            stream,
            remaining,
            timeout_config,
            progress_callback,
            progress_config,
            Arc::clone(&trailers),
        );
        return Ok(Response::new_with_trailer_state(
            status,
            Version::Http11,
            url,
            headers,
            TrailerState::Deferred(trailers),
            body,
        ));
    }
    if let Some(content_length) = parse_content_length(&headers, "http1")? {
        let body_bytes = read_http1_length_body_bytes(
            &mut stream,
            remaining,
            timeout_config,
            progress_callback,
            progress_config,
            content_length,
        )
        .await?;
        let body = if has_content_encoding {
            decode_response_body(&mut headers, body_bytes)?
        } else {
            Body::from(body_bytes)
        };
        return Ok(Response::new(
            status,
            Version::Http11,
            url,
            headers,
            None,
            body,
        ));
    }
    // No framing headers: the body runs until the connection closes.
    if has_content_encoding {
        let body_bytes = read_http1_to_eof_bytes(
            &mut stream,
            remaining,
            timeout_config,
            progress_callback,
            progress_config,
        )
        .await?;
        let body = decode_response_body(&mut headers, body_bytes)?;
        return Ok(Response::new(
            status,
            Version::Http11,
            url,
            headers,
            None,
            body,
        ));
    }
    let body = read_http1_to_eof_body_stream(
        stream,
        remaining,
        timeout_config,
        progress_callback,
        progress_config,
    );
    Ok(Response::new(
        status,
        Version::Http11,
        url,
        headers,
        None,
        body,
    ))
}
/// Buffer a Content-Length delimited HTTP/1.1 body into memory.
///
/// `body` arrives pre-seeded with any bytes already read past the response
/// headers. Reads stop at `content_length` bytes or at EOF, whichever comes
/// first (a short body is returned as-is for the caller to deal with).
async fn read_http1_length_body_bytes<S>(
    stream: &mut S,
    mut body: Vec<u8>,
    timeout_config: TimeoutConfig,
    progress_callback: Option<ProgressCallback>,
    progress_config: ProgressConfig,
    content_length: usize,
) -> Result<Vec<u8>>
where
    S: futures_lite::io::AsyncRead + Unpin + ?Sized,
{
    // Cap per-read buffers: Content-Length is peer-controlled, and sizing a
    // single allocation from it would let a hostile server demand an
    // arbitrarily large up-front allocation.
    const MAX_READ_CHUNK: usize = 64 * 1024;
    // Bytes past the declared length are not part of this message's body.
    if body.len() > content_length {
        body.truncate(content_length);
    }
    let mut progress = progress_callback.map(|callback| {
        ProgressReporter::new(
            callback,
            ProgressPhase::Download,
            Some(content_length),
            progress_config,
        )
    });
    if let Some(progress) = &mut progress {
        if !body.is_empty() {
            progress.record(body.len());
        }
    }
    while body.len() < content_length {
        let mut chunk = vec![0; (content_length - body.len()).min(MAX_READ_CHUNK)];
        let read = with_timeout_io(
            timeout_config.read,
            stream.read(&mut chunk),
            "read timed out",
        )
        .await?;
        if read == 0 {
            // Peer closed before delivering the full body.
            break;
        }
        body.extend_from_slice(&chunk[..read]);
        if let Some(progress) = &mut progress {
            progress.record(read);
        }
    }
    if let Some(progress) = &mut progress {
        progress.finish();
    }
    Ok(body)
}
/// Buffer an entire read-until-EOF HTTP/1.1 body into memory.
///
/// `body` arrives pre-seeded with any bytes already read past the headers.
async fn read_http1_to_eof_bytes<S>(
    stream: &mut S,
    mut body: Vec<u8>,
    timeout_config: TimeoutConfig,
    progress_callback: Option<ProgressCallback>,
    progress_config: ProgressConfig,
) -> Result<Vec<u8>>
where
    S: futures_lite::io::AsyncRead + Unpin + ?Sized,
{
    let mut progress = progress_callback.map(|callback| {
        ProgressReporter::new(callback, ProgressPhase::Download, None, progress_config)
    });
    if !body.is_empty() {
        if let Some(progress) = &mut progress {
            progress.record(body.len());
        }
    }
    // One stack buffer, reused across every read.
    let mut chunk = [0_u8; 4096];
    loop {
        let read = with_timeout_io(
            timeout_config.read,
            stream.read(&mut chunk),
            "read timed out",
        )
        .await?;
        if read == 0 {
            break;
        }
        body.extend_from_slice(&chunk[..read]);
        if let Some(progress) = &mut progress {
            progress.record(read);
        }
    }
    if let Some(progress) = &mut progress {
        progress.finish();
    }
    Ok(body)
}
/// Stream an HTTP/1.1 body delimited by connection close (no Content-Length
/// and not chunked). `initial` holds bytes read together with the headers.
fn read_http1_to_eof_body_stream(
    mut stream: BoxedStream,
    initial: Vec<u8>,
    timeout_config: TimeoutConfig,
    progress_callback: Option<ProgressCallback>,
    progress_config: ProgressConfig,
) -> Body {
    let stream_body = async_stream::try_stream! {
        let mut progress = progress_callback.map(|callback| {
            ProgressReporter::new(callback, ProgressPhase::Download, None, progress_config)
        });
        // Emit the pre-buffered bytes before touching the socket again.
        if !initial.is_empty() {
            if let Some(progress) = &mut progress {
                progress.record(initial.len());
            }
            yield Bytes::from(initial);
        }
        loop {
            let mut chunk = vec![0_u8; 8192];
            let read = with_timeout_io(
                timeout_config.read,
                stream.read(&mut chunk),
                "read timed out",
            )
            .await?;
            if read == 0 {
                break;
            }
            chunk.truncate(read);
            if let Some(progress) = &mut progress {
                progress.record(read);
            }
            yield Bytes::from(chunk);
        }
        if let Some(progress) = &mut progress {
            progress.finish();
        }
    };
    Body::from_stream(Box::pin(stream_body))
}
/// Decodes an HTTP/1 `Transfer-Encoding: chunked` body into a single
/// buffer.
///
/// `buffer` holds bytes already drained while parsing the response
/// head. Returns the decoded body together with any trailer headers
/// that followed the terminal chunk.
///
/// # Errors
///
/// Fails on read timeout, premature EOF, or malformed chunk framing
/// (non-hex size line, invalid trailer line).
async fn read_http1_chunked_body_bytes<S>(
    stream: &mut S,
    mut buffer: Vec<u8>,
    timeout_config: TimeoutConfig,
    progress_callback: Option<ProgressCallback>,
    progress_config: ProgressConfig,
) -> Result<(Vec<u8>, Option<HeaderMap>)>
where
    S: futures_lite::io::AsyncRead + Unpin + ?Sized,
{
    let mut decoded = Vec::new();
    let mut progress = progress_callback.map(|callback| {
        ProgressReporter::new(callback, ProgressPhase::Download, None, progress_config)
    });
    loop {
        // Accumulate input until a full "<hex-size>\r\n" line is buffered.
        let size_line_end = loop {
            if let Some(pos) = find_http1_crlf(&buffer) {
                break pos;
            }
            read_http1_more(stream, &mut buffer, timeout_config).await?;
        };
        let size_line = String::from_utf8(buffer[..size_line_end].to_vec()).map_err(|err| {
            Error::with_source(ErrorKind::Decode, "chunk size is not valid utf-8", err)
        })?;
        buffer.drain(..size_line_end + 2);
        let size = usize::from_str_radix(size_line.trim(), 16).map_err(|_| {
            Error::new(
                ErrorKind::Transport,
                format!("invalid chunk size: {size_line}"),
            )
        })?;
        if size == 0 {
            // Terminal chunk: the body ends with either a bare CRLF (no
            // trailers) or a trailer block terminated by CRLF CRLF. Keep
            // reading until one of the two terminators is visible;
            // waiting only for CRLF CRLF (as the previous code did when
            // the bare CRLF had not yet been buffered) stalls forever
            // when the peer sends no trailers.
            while !buffer.starts_with(b"\r\n") && find_header_block_end(&buffer).is_none() {
                read_http1_more(stream, &mut buffer, timeout_config).await?;
            }
            if let Some(progress) = &mut progress {
                progress.finish();
            }
            if buffer.starts_with(b"\r\n") {
                return Ok((decoded, None));
            }
            let trailers = parse_http1_trailers(&buffer)?;
            return Ok((decoded, trailers));
        }
        // Wait for the chunk payload plus its trailing CRLF.
        while buffer.len() < size + 2 {
            read_http1_more(stream, &mut buffer, timeout_config).await?;
        }
        decoded.extend_from_slice(&buffer[..size]);
        if let Some(progress) = &mut progress {
            progress.record(size);
        }
        buffer.drain(..size + 2);
    }
}
/// Streams an HTTP/1 chunked body, yielding each decoded chunk as it
/// completes.
///
/// `buffer` holds bytes already drained while parsing the response
/// head. Trailer headers that follow the terminal chunk are published
/// through `trailers` so the response can expose them after the stream
/// ends.
fn read_http1_chunked_body_stream(
    mut stream: BoxedStream,
    mut buffer: Vec<u8>,
    timeout_config: TimeoutConfig,
    progress_callback: Option<ProgressCallback>,
    progress_config: ProgressConfig,
    trailers: Arc<OnceLock<HeaderMap>>,
) -> Body {
    let stream_body = async_stream::try_stream! {
        let mut progress = progress_callback.map(|callback| {
            ProgressReporter::new(callback, ProgressPhase::Download, None, progress_config)
        });
        loop {
            // Accumulate input until a full "<hex-size>\r\n" line is buffered.
            let size_line_end = loop {
                if let Some(pos) = find_http1_crlf(&buffer) {
                    break pos;
                }
                read_http1_more(&mut stream, &mut buffer, timeout_config).await?;
            };
            let size_line = String::from_utf8(buffer[..size_line_end].to_vec()).map_err(|err| {
                Error::with_source(ErrorKind::Decode, "chunk size is not valid utf-8", err)
            })?;
            buffer.drain(..size_line_end + 2);
            let size = usize::from_str_radix(size_line.trim(), 16).map_err(|_| {
                Error::new(
                    ErrorKind::Transport,
                    format!("invalid chunk size: {size_line}"),
                )
            })?;
            if size == 0 {
                // Terminal chunk: the body ends with either a bare CRLF
                // (no trailers) or a trailer block terminated by CRLF
                // CRLF. Keep reading until one of the two terminators is
                // visible; waiting only for CRLF CRLF (as the previous
                // code did when the bare CRLF had not yet been buffered)
                // stalls forever when the peer sends no trailers.
                while !buffer.starts_with(b"\r\n") && find_header_block_end(&buffer).is_none() {
                    read_http1_more(&mut stream, &mut buffer, timeout_config).await?;
                }
                if buffer.starts_with(b"\r\n") {
                    buffer.drain(..2);
                } else {
                    let parsed = parse_http1_trailers(&buffer)?;
                    if let Some(parsed) = parsed {
                        let _ = trailers.set(parsed);
                    }
                }
                if let Some(progress) = &mut progress {
                    progress.finish();
                }
                break;
            }
            // Wait for the chunk payload plus its trailing CRLF.
            while buffer.len() < size + 2 {
                read_http1_more(&mut stream, &mut buffer, timeout_config).await?;
            }
            let chunk = buffer.drain(..size).collect::<Vec<_>>();
            buffer.drain(..2);
            if let Some(progress) = &mut progress {
                progress.record(size);
            }
            yield Bytes::from(chunk);
        }
    };
    Body::from_stream(Box::pin(stream_body))
}
fn validate_h2c_upgrade_response(headers: &HeaderMap) -> Result<()> {
let upgrade = headers.get("upgrade").ok_or_else(|| {
Error::new(
ErrorKind::Transport,
"h2c upgrade response is missing Upgrade: h2c",
)
})?;
if !upgrade.eq_ignore_ascii_case("h2c") {
return Err(Error::new(
ErrorKind::Transport,
format!("unexpected h2c upgrade token: {upgrade}"),
));
}
if !header_contains_token(headers, "connection", "upgrade") {
return Err(Error::new(
ErrorKind::Transport,
"h2c upgrade response is missing Connection: Upgrade",
));
}
Ok(())
}
/// Returns true when any comma-separated element of the named header
/// equals `token` case-insensitively (surrounding whitespace ignored).
fn header_contains_token(headers: &HeaderMap, name: &str, token: &str) -> bool {
    for value in headers.get_all(name).iter() {
        for part in value.split(',') {
            if part.trim().eq_ignore_ascii_case(token) {
                return true;
            }
        }
    }
    false
}
/// Returns the index of the first `\r\n\r\n` sequence in `bytes`, if any.
fn find_header_block_end(bytes: &[u8]) -> Option<usize> {
    if bytes.len() < 4 {
        return None;
    }
    (0..=bytes.len() - 4).find(|&i| &bytes[i..i + 4] == b"\r\n\r\n")
}
/// Returns the index of the first CRLF (`\r\n`) in `bytes`, if any.
fn find_http1_crlf(bytes: &[u8]) -> Option<usize> {
    bytes
        .iter()
        .zip(bytes.iter().skip(1))
        .position(|(first, second)| *first == b'\r' && *second == b'\n')
}
async fn read_http1_more<S>(
stream: &mut S,
buffer: &mut Vec<u8>,
timeout_config: TimeoutConfig,
) -> Result<()>
where
S: futures_lite::io::AsyncRead + Unpin + ?Sized,
{
let mut chunk = [0_u8; 1024];
let read = with_timeout_io(
timeout_config.read,
stream.read(&mut chunk),
"read timed out",
)
.await?;
if read == 0 {
return Err(Error::new(
ErrorKind::Transport,
"unexpected eof while reading chunked body",
));
}
buffer.extend_from_slice(&chunk[..read]);
Ok(())
}
fn parse_http1_trailers(buffer: &[u8]) -> Result<Option<HeaderMap>> {
if buffer == b"\r\n" {
return Ok(None);
}
let header_end = find_header_block_end(buffer)
.ok_or_else(|| Error::new(ErrorKind::Transport, "incomplete trailer block"))?;
let head = &buffer[..header_end];
let text = std::str::from_utf8(head).map_err(|err| {
Error::with_source(ErrorKind::Decode, "trailers are not valid utf-8", err)
})?;
let mut headers = HeaderMap::new();
for line in text.split("\r\n") {
if line.is_empty() {
continue;
}
let (name, value) = line.split_once(':').ok_or_else(|| {
Error::new(
ErrorKind::Transport,
format!("invalid trailer line: {line}"),
)
})?;
headers.append(name.trim(), value.trim())?;
}
Ok(Some(headers))
}
fn parse_http1_status_line(line: &str) -> Result<StatusCode> {
let mut parts = line.split_whitespace();
let version = parts
.next()
.ok_or_else(|| Error::new(ErrorKind::Transport, "invalid h2c upgrade status line"))?;
if !version.starts_with("HTTP/1.") {
return Err(Error::new(
ErrorKind::Transport,
format!("unsupported h2c upgrade response version: {version}"),
));
}
let code = parts
.next()
.ok_or_else(|| Error::new(ErrorKind::Transport, "missing h2c upgrade status code"))?;
let code = code.parse().map_err(|_| {
Error::new(
ErrorKind::Transport,
format!("invalid h2c upgrade status code: {code}"),
)
})?;
Ok(StatusCode::new(code))
}
fn checked_window_add(current: i32, increment: i32) -> Result<i32> {
let updated = current
.checked_add(increment)
.ok_or_else(|| Error::new(ErrorKind::Transport, "http2 flow-control window overflowed"))?;
if updated > MAX_FLOW_CONTROL_WINDOW {
return Err(Error::new(
ErrorKind::Transport,
"http2 flow-control window exceeded protocol maximum",
));
}
Ok(updated)
}
/// Sends the PRIORITY frames of a browser fingerprint that belong to
/// the given phase.
///
/// A spec without an explicit stream id targets the request stream;
/// targeting stream 0 is a protocol violation and is rejected.
async fn write_priority_fingerprint_frames(
    stream: &mut BoxedStream,
    priorities: &[Http2PrioritySpec],
    phase: Http2PriorityPhase,
    request_stream_id: u32,
    timeout_config: TimeoutConfig,
) -> Result<()> {
    for priority in priorities {
        if priority.phase != phase {
            continue;
        }
        let target = priority.stream_id.unwrap_or(request_stream_id);
        if target == 0 {
            return Err(Error::new(
                ErrorKind::Transport,
                "http2 priority fingerprint cannot target stream 0",
            ));
        }
        write_priority_frame(
            stream,
            target,
            priority.stream_dependency,
            priority.weight,
            priority.exclusive,
            timeout_config,
        )
        .await?;
    }
    Ok(())
}
fn connection_error(err: &Error) -> Error {
Error::new(err.kind().clone(), err.to_string())
}
/// Classification of a decoded HTTP/2 header block on a response stream,
/// as returned by `ResponseStreamState::apply_headers`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum HeaderBlockKind {
    /// A 1xx interim response; further header blocks are expected.
    Informational,
    /// The final (status >= 200) response headers.
    Final,
    /// A trailer block received after the final headers.
    Trailers,
}
/// Accumulates the pieces of an HTTP/2 response as frames arrive on a
/// single stream: status, headers, body bytes, and optional trailers.
struct ResponseStreamState {
    /// Request method; used to decide whether a response body is allowed.
    request_method: Method,
    /// How (and whether) to decompress the aggregated body.
    compression_mode: CompressionMode,
    /// Final response status; `None` until the final header block arrives.
    status: Option<StatusCode>,
    /// Headers from the final (non-informational) header block.
    headers: HeaderMap,
    /// Trailer headers, if a trailer block was received.
    trailers: Option<HeaderMap>,
    /// Aggregated body bytes received so far.
    body: Vec<u8>,
    /// Total DATA payload bytes seen, for content-length validation.
    received_body_bytes: usize,
    /// True once the final (status >= 200) header block was applied.
    received_final_headers: bool,
    /// True once a trailer block was applied.
    received_trailers: bool,
    /// True once a frame carrying END_STREAM was processed.
    end_stream: bool,
    /// Parsed `content-length` from the final headers, if present.
    expected_content_length: Option<usize>,
}
impl ResponseStreamState {
    /// Creates an empty state for a stream whose request has been sent.
    fn new(request_method: Method, compression_mode: CompressionMode) -> Self {
        Self {
            request_method,
            compression_mode,
            status: None,
            headers: HeaderMap::new(),
            trailers: None,
            body: Vec::new(),
            received_body_bytes: 0,
            received_final_headers: false,
            received_trailers: false,
            end_stream: false,
            expected_content_length: None,
        }
    }
    /// True once a frame carrying END_STREAM has been processed.
    fn is_complete(&self) -> bool {
        self.end_stream
    }
    fn mark_end_stream(&mut self) {
        self.end_stream = true;
    }
    /// DATA frames are only valid between the final headers and trailers.
    fn can_receive_data(&self) -> bool {
        self.received_final_headers && !self.received_trailers
    }
    /// Whether a body is permitted for this method/status combination.
    /// False until the final headers have arrived.
    fn allows_response_body(&self) -> bool {
        self.status
            .map(|status| response_body_allowed(self.request_method, status))
            .unwrap_or(false)
    }
    /// Content-encoded responses must be aggregated so they can be
    /// decoded as a whole rather than streamed.
    fn requires_aggregated_body(&self) -> bool {
        self.allows_response_body() && self.headers.get("content-encoding").is_some()
    }
    fn record_body_bytes(&mut self, size: usize) {
        self.received_body_bytes = self.received_body_bytes.saturating_add(size);
    }
    fn push_body(&mut self, chunk: &[u8]) {
        self.body.extend_from_slice(chunk);
    }
    /// Checks the received byte count against the declared
    /// content-length, when a body is allowed and a length was declared.
    fn validate_received_body(&self) -> Result<()> {
        if !self.allows_response_body() {
            return Ok(());
        }
        if let Some(expected) = self.expected_content_length {
            if self.received_body_bytes != expected {
                return Err(Error::new(
                    ErrorKind::Transport,
                    format!(
                        "http2 response body length mismatch: expected {expected}, got {}",
                        self.received_body_bytes
                    ),
                ));
            }
        }
        Ok(())
    }
    /// Folds a decoded header block into the state and classifies it.
    ///
    /// 1xx blocks are reported as `Informational` and otherwise ignored;
    /// the first block with status >= 200 becomes the final headers; any
    /// later block (which must not carry `:status`) becomes trailers.
    fn apply_headers(&mut self, decoded: Vec<(String, String)>) -> Result<HeaderBlockKind> {
        // `:status` is only legal before the final headers were seen.
        let parsed = ParsedHeaders::parse(decoded, !self.received_final_headers)?;
        if let Some(status) = parsed.status {
            if status.as_u16() < 200 {
                return Ok(HeaderBlockKind::Informational);
            }
            if self.received_final_headers {
                return Err(Error::new(
                    ErrorKind::Transport,
                    "http2 final response headers received more than once",
                ));
            }
            self.expected_content_length = parse_content_length(&parsed.headers, "http2")?;
            for (name, value) in parsed.headers.iter() {
                self.headers.append(name.as_str(), value.as_str())?;
            }
            self.status = Some(status);
            self.received_final_headers = true;
            return Ok(HeaderBlockKind::Final);
        }
        if !self.received_final_headers {
            return Err(Error::new(
                ErrorKind::Transport,
                "http2 response headers missing :status",
            ));
        }
        if self.received_trailers {
            return Err(Error::new(
                ErrorKind::Transport,
                "http2 trailers received more than once",
            ));
        }
        self.trailers = Some(parsed.headers);
        self.received_trailers = true;
        Ok(HeaderBlockKind::Trailers)
    }
    /// Consumes the state and builds the final `Response`, validating
    /// the body length and decoding the body when one is allowed.
    fn into_response(self, url: Url) -> Result<Response> {
        // Reuse the shared length check rather than duplicating it here.
        // (A missing status yields allows_response_body() == false, so
        // the dedicated "missing final headers" error below still fires.)
        self.validate_received_body()?;
        let ResponseStreamState {
            request_method,
            compression_mode,
            status,
            mut headers,
            trailers,
            body,
            ..
        } = self;
        let status = status.ok_or_else(|| {
            Error::new(
                ErrorKind::Transport,
                "http2 response completed without final headers",
            )
        })?;
        let body = if response_body_allowed(request_method, status) {
            maybe_decode_response_body(&mut headers, body, compression_mode)?
        } else {
            Body::from(body)
        };
        Ok(Response::new_with_trailer_state(
            status,
            Version::Http2,
            url,
            headers,
            TrailerState::Ready(trailers),
            body,
        ))
    }
}
/// Result of parsing one decoded HTTP/2 header block: the optional
/// `:status` pseudo header plus all regular headers.
struct ParsedHeaders {
    /// Parsed `:status` value, if the block carried one.
    status: Option<StatusCode>,
    /// All non-pseudo headers, in arrival order.
    headers: HeaderMap,
}
impl ParsedHeaders {
    /// Splits a decoded header list into the optional `:status` pseudo
    /// header and the regular headers, enforcing HTTP/2 rules: pseudo
    /// headers must precede regular ones, `:status` may appear at most
    /// once, and when `allow_status` is false (trailer blocks) no
    /// pseudo header is permitted at all.
    fn parse(decoded: Vec<(String, String)>, allow_status: bool) -> Result<Self> {
        let mut status = None;
        let mut seen_regular = false;
        let mut headers = HeaderMap::new();
        for (name, value) in decoded {
            let pseudo = match name.strip_prefix(':') {
                Some(pseudo) => pseudo,
                None => {
                    seen_regular = true;
                    headers.append(name, value)?;
                    continue;
                }
            };
            if seen_regular {
                return Err(Error::new(
                    ErrorKind::Transport,
                    "http2 pseudo headers must appear before regular headers",
                ));
            }
            if !allow_status || pseudo != "status" {
                return Err(Error::new(
                    ErrorKind::Transport,
                    format!("unexpected http2 pseudo header: {name}"),
                ));
            }
            if status.is_some() {
                return Err(Error::new(
                    ErrorKind::Transport,
                    "http2 response repeated :status pseudo header",
                ));
            }
            let code = value.parse::<u16>().map_err(|err| {
                Error::with_source(ErrorKind::Transport, "invalid http2 :status value", err)
            })?;
            status = Some(StatusCode::new(code));
        }
        Ok(Self { status, headers })
    }
}