use http::header::{CONNECTION, TE, TRANSFER_ENCODING, UPGRADE};
use http::{HeaderMap, HeaderName, HeaderValue, Request, Response, Version};
use hyper::body::Incoming;
use hyper::client::conn::{http1, http2};
#[cfg(feature = "websocket")]
use openwire_core::Connection;
use openwire_core::{
BoxConnection, CoalescingInfo, Connected, ConnectionInfo, HyperExecutor, RequestBody,
SharedTimer, WireError,
};
use crate::connection::{Address, ConnectionProtocol, RouteKind, UriScheme};
/// Runs the HTTP/1.1 client handshake on an already-established transport.
///
/// On success the request sender and the connection future are handed back to the
/// caller; this function neither spawns nor polls the returned connection.
///
/// # Errors
///
/// Returns a protocol-binding [`WireError`] when hyper rejects the handshake.
pub(super) async fn bind_http1(
    stream: BoxConnection,
) -> Result<
    (
        http1::SendRequest<RequestBody>,
        http1::Connection<BoxConnection, RequestBody>,
    ),
    WireError,
> {
    let builder = http1::Builder::new();
    let outcome = builder.handshake(stream).await;
    outcome.map_err(|source| {
        WireError::protocol_binding("HTTP/1.1 client handshake failed", source)
    })
}
/// Runs the HTTP/2 client handshake on an already-established transport.
///
/// The hyper builder receives `executor` and `timer`. Keep-alive pings are configured only
/// when `config.http2_keep_alive_interval` is set. In that case
/// `config.http2_keep_alive_while_idle` is forwarded to the builder as well. The returned
/// connection future is not spawned or polled here.
///
/// # Errors
///
/// Returns a protocol-binding [`WireError`] when hyper rejects the handshake.
pub(super) async fn bind_http2(
    stream: BoxConnection,
    config: &crate::client::TransportConfig,
    executor: HyperExecutor,
    timer: SharedTimer,
) -> Result<
    (
        http2::SendRequest<RequestBody>,
        http2::Connection<BoxConnection, RequestBody, HyperExecutor>,
    ),
    WireError,
> {
    let mut builder = http2::Builder::new(executor);
    builder.timer(timer);
    if let Some(interval) = config.http2_keep_alive_interval {
        // The idle flag only matters once pings are enabled, so it is set alongside them.
        builder
            .keep_alive_interval(interval)
            .keep_alive_while_idle(config.http2_keep_alive_while_idle);
    }
    let outcome = builder.handshake(stream).await;
    outcome.map_err(|source| WireError::protocol_binding("HTTP/2 client handshake failed", source))
}
/// Sends `request` as an HTTP/1.1 upgrade attempt (e.g. a WebSocket opening handshake)
/// over a fresh connection on `io`.
///
/// The connection driver is spawned with `tokio::spawn` with upgrade support enabled, so a
/// Tokio runtime must be current. If the server answers `101 Switching Protocols`, the
/// upgraded stream is awaited and returned as `Some`. Otherwise the second element is
/// `None`.
///
/// The returned `Response<()>` carries only the server's status, version and headers. The
/// response body and extensions are not kept.
///
/// # Errors
///
/// Fails when the handshake, the request exchange, or the upgrade itself fails.
#[cfg(feature = "websocket")]
pub(crate) async fn bind_websocket_handshake(
    io: BoxConnection,
    request: Request<RequestBody>,
) -> Result<(Response<()>, Option<hyper::upgrade::Upgraded>), WireError> {
    let (mut sender, connection) = http1::handshake(io)
        .await
        .map_err(|error| WireError::protocol_binding("HTTP/1.1 client handshake failed", error))?;

    // The connection has to be polled for the exchange to progress. `with_upgrades` is
    // what lets `hyper::upgrade::on` take over the I/O after a 101 response.
    tokio::spawn(async move {
        let _ = connection.with_upgrades().await;
    });

    let mut response = sender.send_request(request).await.map_err(WireError::from)?;
    let switching = response.status() == http::StatusCode::SWITCHING_PROTOCOLS;
    let upgraded = if switching {
        let stream = hyper::upgrade::on(&mut response)
            .await
            .map_err(WireError::from)?;
        Some(stream)
    } else {
        None
    };

    let mut head = Response::new(());
    *head.status_mut() = response.status();
    *head.version_mut() = response.version();
    *head.headers_mut() = response.headers().clone();
    Ok((head, upgraded))
}
/// Wraps an upgraded hyper stream so it can be used wherever a [`BoxConnection`] is
/// expected.
#[cfg(feature = "websocket")]
pub(crate) fn upgraded_into_box_connection(upgraded: hyper::upgrade::Upgraded) -> BoxConnection {
    let adapter = UpgradedConnection { inner: upgraded };
    Box::new(adapter)
}
/// Adapter exposing a hyper [`hyper::upgrade::Upgraded`] stream as an openwire connection.
/// All reads and writes are delegated to `inner`.
#[cfg(feature = "websocket")]
struct UpgradedConnection {
    /// The upgraded stream that every I/O call is forwarded to.
    inner: hyper::upgrade::Upgraded,
}
#[cfg(feature = "websocket")]
impl Connection for UpgradedConnection {
    /// Returns a freshly constructed `Connected::new()` value.
    ///
    /// NOTE(review): only the `Upgraded` stream is retained by this adapter, so any
    /// metadata reported by the pre-upgrade transport is not carried over here. Confirm
    /// that callers do not rely on it after an upgrade.
    fn connected(&self) -> Connected {
        Connected::new()
    }
}
#[cfg(feature = "websocket")]
impl hyper::rt::Read for UpgradedConnection {
    /// Forwards the read to the wrapped upgraded stream.
    fn poll_read(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: hyper::rt::ReadBufCursor<'_>,
    ) -> std::task::Poll<Result<(), std::io::Error>> {
        let this = self.get_mut();
        let inner = std::pin::Pin::new(&mut this.inner);
        inner.poll_read(cx, buf)
    }
}
/// Forwards every write-side operation, including vectored writes, to the wrapped
/// upgraded stream.
///
/// The vectored methods are forwarded explicitly. Otherwise the trait defaults would
/// report `is_write_vectored() == false` and degrade every vectored write to writing a
/// single buffer, hiding the capability of the underlying stream.
#[cfg(feature = "websocket")]
impl hyper::rt::Write for UpgradedConnection {
    fn poll_write(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<Result<usize, std::io::Error>> {
        std::pin::Pin::new(&mut self.get_mut().inner).poll_write(cx, buf)
    }

    fn poll_write_vectored(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> std::task::Poll<Result<usize, std::io::Error>> {
        std::pin::Pin::new(&mut self.get_mut().inner).poll_write_vectored(cx, bufs)
    }

    fn is_write_vectored(&self) -> bool {
        self.inner.is_write_vectored()
    }

    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), std::io::Error>> {
        std::pin::Pin::new(&mut self.get_mut().inner).poll_flush(cx)
    }

    fn poll_shutdown(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), std::io::Error>> {
        std::pin::Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
    }
}
/// Picks the wire protocol for a freshly established connection.
///
/// Addresses with the plain `http` scheme always use HTTP/1.1. For any other scheme,
/// HTTP/2 is chosen only when `connected.is_negotiated_h2()` is true; otherwise HTTP/1.1
/// is used.
pub(super) fn determine_protocol(address: &Address, connected: &Connected) -> ConnectionProtocol {
    let cleartext = address.scheme() == UriScheme::Http;
    if !cleartext && connected.is_negotiated_h2() {
        ConnectionProtocol::Http2
    } else {
        ConnectionProtocol::Http1
    }
}
pub(super) fn prepare_bound_request(
mut request: Request<RequestBody>,
protocol: ConnectionProtocol,
route_kind: &RouteKind,
) -> Result<Request<RequestBody>, WireError> {
if protocol == ConnectionProtocol::Http2 {
strip_http2_connection_specific_headers(request.headers_mut());
return Ok(request);
}
if protocol != ConnectionProtocol::Http1
|| matches!(route_kind, RouteKind::HttpForwardProxy { .. })
{
return Ok(request);
}
let origin_form = request
.uri()
.path_and_query()
.map(|path| path.as_str())
.unwrap_or("/");
*request.uri_mut() = origin_form.parse().map_err(|error| {
WireError::internal(
"failed to normalize request URI for direct HTTP/1.1 binding",
error,
)
})?;
Ok(request)
}
fn strip_http2_connection_specific_headers(headers: &mut HeaderMap) {
let connection_nominated = connection_nominated_headers(headers);
let te_values = headers.get_all(TE).iter().cloned().collect::<Vec<_>>();
headers.remove(TE);
headers.remove(CONNECTION);
headers.remove(keep_alive_header());
headers.remove(proxy_connection_header());
headers.remove(TRANSFER_ENCODING);
headers.remove(UPGRADE);
let te_was_nominated = connection_nominated.contains(&TE);
for name in connection_nominated {
headers.remove(name);
}
if !te_was_nominated {
append_http2_te_trailers(headers, te_values);
}
}
/// Collects the header names listed in all `Connection` header values.
///
/// Non-ASCII values, empty list elements and tokens that are not valid header names are
/// skipped. Each name appears once, in first-seen order.
fn connection_nominated_headers(headers: &HeaderMap) -> Vec<HeaderName> {
    let candidates = headers
        .get_all(CONNECTION)
        .iter()
        .filter_map(|value| value.to_str().ok())
        .flat_map(|value| value.split(','))
        .map(str::trim)
        .filter(|token| !token.is_empty())
        .filter_map(|token| HeaderName::from_bytes(token.as_bytes()).ok());

    let mut unique: Vec<HeaderName> = Vec::new();
    for name in candidates {
        if !unique.contains(&name) {
            unique.push(name);
        }
    }
    unique
}
/// Re-inserts `TE: trailers` when any of `te_values` lists the `trailers` token.
///
/// Tokens are matched case-insensitively after trimming. Values that are not visible
/// ASCII are ignored, and every other transfer coding is dropped. HTTP/2 only permits
/// `trailers` in `TE`.
fn append_http2_te_trailers(headers: &mut HeaderMap, te_values: Vec<HeaderValue>) {
    let wants_trailers = te_values
        .iter()
        .filter_map(|value| value.to_str().ok())
        .flat_map(|value| value.split(','))
        .any(|token| token.trim().eq_ignore_ascii_case("trailers"));
    if wants_trailers {
        headers.insert(TE, HeaderValue::from_static("trailers"));
    }
}
/// Name of the `Keep-Alive` header, which has no predefined constant in `http::header`.
fn keep_alive_header() -> HeaderName {
    const NAME: &str = "keep-alive";
    HeaderName::from_static(NAME)
}
/// Name of the non-standard `Proxy-Connection` header, which has no predefined constant in
/// `http::header`.
fn proxy_connection_header() -> HeaderName {
    const NAME: &str = "proxy-connection";
    HeaderName::from_static(NAME)
}
/// Decides whether an HTTP/1.x connection may be reused after this exchange.
///
/// Reuse is refused in any of these cases:
/// - the request asked for the connection to close;
/// - the response is HTTP/1.0;
/// - the response's `Connection` header requests a close
///   (see [`connection_header_requests_close`]).
pub(super) fn http1_exchange_allows_reuse(
    request_requests_close: bool,
    response: &Response<Incoming>,
) -> bool {
    let http10_response = response.version() == Version::HTTP_10;
    !request_requests_close
        && !http10_response
        && !connection_header_requests_close(response.headers())
}
/// Returns `true` if any `Connection` header value contains the `close` option.
///
/// A value that cannot be parsed also counts as a close request. This errs on the side of
/// not reusing the connection.
pub(super) fn connection_header_requests_close(headers: &HeaderMap) -> bool {
    for value in headers.get_all(CONNECTION) {
        if connection_header_value_requests_close(value).unwrap_or(true) {
            return true;
        }
    }
    false
}
fn connection_header_value_requests_close(value: &HeaderValue) -> Result<bool, ()> {
let value = value.to_str().map_err(|_| ())?;
let bytes = value.as_bytes();
let mut index = 0usize;
while index < bytes.len() {
while index < bytes.len() && is_optional_whitespace(bytes[index]) {
index += 1;
}
if index == bytes.len() {
return Ok(false);
}
if bytes[index] == b',' {
index += 1;
continue;
}
let token_start = index;
while index < bytes.len() && is_tchar(bytes[index]) {
index += 1;
}
if token_start == index {
return Err(());
}
let token = &value[token_start..index];
while index < bytes.len() && is_optional_whitespace(bytes[index]) {
index += 1;
}
match bytes.get(index).copied() {
None => return Ok(token.eq_ignore_ascii_case("close")),
Some(b',') => {
if token.eq_ignore_ascii_case("close") {
return Ok(true);
}
index += 1;
}
Some(_) => return Err(()),
}
}
Ok(false)
}
/// Returns `true` for the optional-whitespace bytes: space and horizontal tab.
fn is_optional_whitespace(byte: u8) -> bool {
    byte == b' ' || byte == b'\t'
}
/// Returns `true` if `byte` may appear in an HTTP token: ASCII letters, digits, or one of
/// ``!#$%&'*+-.^_`|~``.
fn is_tchar(byte: u8) -> bool {
    const SYMBOLS: &[u8] = b"!#$%&'*+-.^_`|~";
    byte.is_ascii_alphanumeric() || SYMBOLS.contains(&byte)
}
pub(super) fn map_hyper_error(error: hyper::Error) -> WireError {
if let Some(source) = find_wire_error(&error) {
return source.clone();
}
WireError::from(error)
}
/// Walks `error` and its `source()` chain, returning the first link that is a [`WireError`].
fn find_wire_error<'a>(error: &'a (dyn std::error::Error + 'static)) -> Option<&'a WireError> {
    std::iter::successors(Some(error), |&link| link.source())
        .find_map(|link| link.downcast_ref::<WireError>())
}
pub(super) fn connection_info_from_connected(connected: &Connected) -> ConnectionInfo {
connected.connection_info_or_default()
}
/// Returns an owned copy of the coalescing information recorded on `connected`.
pub(super) fn coalescing_info_from_connected(connected: &Connected) -> CoalescingInfo {
    let info = connected.coalescing_info();
    info.clone()
}
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;
    use http::header::{CONNECTION, TE, TRANSFER_ENCODING, UPGRADE};
    use http::{HeaderMap, HeaderValue, Request};
    use super::{connection_header_requests_close, prepare_bound_request};
    use crate::connection::{ConnectionProtocol, RouteKind};
    use openwire_core::RequestBody;

    /// Fixed hop-by-hop headers and `Connection`-nominated headers are removed for HTTP/2.
    /// A `TE` value without `trailers` is dropped entirely.
    #[test]
    fn http2_bound_request_strips_connection_specific_headers() {
        let request = Request::builder()
            .uri("https://example.test/resource")
            .header(CONNECTION, "keep-alive, x-hop, upgrade")
            .header("keep-alive", "timeout=5")
            .header("proxy-connection", "keep-alive")
            .header(TRANSFER_ENCODING, "chunked")
            .header(UPGRADE, "websocket")
            .header("x-hop", "secret")
            .header(TE, "gzip")
            .body(RequestBody::empty())
            .expect("request");
        let prepared = prepare_bound_request(request, ConnectionProtocol::Http2, &direct_route())
            .expect("prepared request");
        let headers = prepared.headers();
        assert!(headers.get(CONNECTION).is_none());
        assert!(headers.get("keep-alive").is_none());
        assert!(headers.get("proxy-connection").is_none());
        assert!(headers.get(TRANSFER_ENCODING).is_none());
        assert!(headers.get(UPGRADE).is_none());
        assert!(headers.get("x-hop").is_none());
        assert!(headers.get(TE).is_none());
    }

    /// Multiple `TE` values that mention `trailers` collapse to a single `TE: trailers`.
    #[test]
    fn http2_bound_request_preserves_te_trailers_only() {
        let request = Request::builder()
            .uri("https://example.test/resource")
            .header(TE, "gzip, trailers")
            .header(TE, "trailers")
            .body(RequestBody::empty())
            .expect("request");
        let prepared = prepare_bound_request(request, ConnectionProtocol::Http2, &direct_route())
            .expect("prepared request");
        let te_values = prepared
            .headers()
            .get_all(TE)
            .iter()
            .map(|value| value.to_str().expect("te header"))
            .collect::<Vec<_>>();
        assert_eq!(te_values, vec!["trailers"]);
    }

    /// When `Connection` nominates `TE`, it is not re-added even if it asked for trailers.
    #[test]
    fn http2_bound_request_strips_connection_nominated_te() {
        let request = Request::builder()
            .uri("https://example.test/resource")
            .header(CONNECTION, "te")
            .header(TE, "trailers")
            .body(RequestBody::empty())
            .expect("request");
        let prepared = prepare_bound_request(request, ConnectionProtocol::Http2, &direct_route())
            .expect("prepared request");
        assert!(prepared.headers().get(TE).is_none());
    }

    /// Direct HTTP/1.1 requests are rewritten to origin-form, keeping the query string.
    #[test]
    fn http1_direct_request_uses_origin_form() {
        let request = Request::builder()
            .uri("http://example.test/resource?x=1")
            .body(RequestBody::empty())
            .expect("request");
        let prepared = prepare_bound_request(request, ConnectionProtocol::Http1, &direct_route())
            .expect("prepared request");
        assert_eq!(prepared.uri().to_string(), "/resource?x=1");
    }

    /// A URI without a path becomes the origin-form root `/`.
    #[test]
    fn http1_direct_request_without_path_uses_root() {
        let request = Request::builder()
            .uri("http://example.test")
            .body(RequestBody::empty())
            .expect("request");
        let prepared = prepare_bound_request(request, ConnectionProtocol::Http1, &direct_route())
            .expect("prepared request");
        assert_eq!(prepared.uri().to_string(), "/");
    }

    /// `close` is detected case-insensitively, anywhere in the list and across header values.
    #[test]
    fn connection_header_close_detection() {
        fn requests_close(values: &[&'static str]) -> bool {
            let mut headers = HeaderMap::new();
            for &value in values {
                headers.append(CONNECTION, HeaderValue::from_static(value));
            }
            connection_header_requests_close(&headers)
        }

        assert!(!requests_close(&[]));
        assert!(requests_close(&["close"]));
        assert!(requests_close(&["Keep-Alive, CLOSE"]));
        assert!(requests_close(&[" , close ,"]));
        assert!(requests_close(&["keep-alive", "close"]));
        assert!(!requests_close(&["keep-alive"]));
        assert!(!requests_close(&["upgrade, , keep-alive,"]));
    }

    /// Values that cannot be parsed as a token list are conservatively treated as `close`.
    #[test]
    fn unparseable_connection_header_is_treated_as_close() {
        let mut headers = HeaderMap::new();
        headers.insert(CONNECTION, HeaderValue::from_static("keep-alive bogus"));
        assert!(connection_header_requests_close(&headers));

        let mut headers = HeaderMap::new();
        headers.insert(
            CONNECTION,
            HeaderValue::from_bytes(b"keep-alive\xff").expect("obs-text header value"),
        );
        assert!(connection_header_requests_close(&headers));
    }

    fn direct_route() -> RouteKind {
        RouteKind::Direct {
            target: SocketAddr::from(([127, 0, 0, 1], 443)),
        }
    }
}