use rama_core::Layer;
use rama_core::Service;
use rama_core::bytes::BytesMut;
use rama_core::error::BoxError;
use rama_core::error::ErrorContext;
use rama_core::extensions::ExtensionsRef;
use rama_core::telemetry::tracing;
use rama_http_headers::Connection;
use rama_http_headers::HeaderMapExt;
use rama_http_headers::Host;
use rama_http_headers::SecWebSocketKey;
use rama_http_headers::SecWebSocketVersion;
use rama_http_headers::Upgrade;
use rama_http_types::HeaderValue;
use rama_http_types::Method;
use rama_http_types::Request;
use rama_http_types::Version;
use rama_http_types::conn::TargetHttpVersion;
use rama_http_types::header::COOKIE;
use rama_http_types::header::Entry;
use rama_http_types::header::HOST;
use rama_http_types::header::{SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION};
use rama_http_types::proto::h2::ext::Protocol;
use rama_net::client::{ConnectorService, EstablishedClientConnection};
use rama_net::{AuthorityInputExt, Protocol as Scheme, ProtocolInputExt};
use crate::layer::remove_header::remove_illegal_h2_request_headers;
use rama_utils::macros::generate_set_and_with;
/// Connector middleware that first establishes a connection through the inner
/// connector and then adapts the HTTP version of the returned request.
///
/// The version to adapt to is taken from a [`TargetHttpVersion`] resolved via
/// the request extensions when present, otherwise from the optional configured
/// default version. When neither is available the request is left untouched.
#[derive(Clone, Debug)]
pub struct RequestVersionAdapter<S> {
    /// Inner connector used to establish the connection.
    inner: S,
    /// Fallback version used when no `TargetHttpVersion` is resolved.
    default_http_version: Option<Version>,
}
impl<S> RequestVersionAdapter<S> {
    /// Create a new [`RequestVersionAdapter`] wrapping the given connector.
    ///
    /// No default http version is configured, so requests without a resolved
    /// [`TargetHttpVersion`] are passed through with their version unchanged.
    #[must_use]
    pub fn new(inner: S) -> Self {
        Self {
            inner,
            default_http_version: None,
        }
    }

    generate_set_and_with! {
        /// Set the http version requests are adapted to when no
        /// [`TargetHttpVersion`] is resolved for them.
        ///
        /// `None` (the default) leaves such requests untouched.
        pub fn default_version(mut self, version: Option<Version>) -> Self {
            self.default_http_version = version;
            self
        }
    }
}
impl<S, Body> Service<Request<Body>> for RequestVersionAdapter<S>
where
    S: ConnectorService<Request<Body>, Error: Into<BoxError>>,
    Body: Send + 'static,
{
    type Output = EstablishedClientConnection<S::Connection, Request<Body>>;
    type Error = BoxError;

    /// Establish the connection through the inner connector, then adapt the
    /// returned request to the resolved target http version.
    ///
    /// A [`TargetHttpVersion`] resolved via the request extensions wins over the
    /// configured default version. When only the default applies, it is also
    /// recorded as a [`TargetHttpVersion`] on the connection extensions.
    async fn serve(&self, req: Request<Body>) -> Result<Self::Output, Self::Error> {
        let established = self.inner.connect(req).await.into_box_error()?;
        let EstablishedClientConnection {
            conn,
            input: mut req,
        } = established;

        let resolved_target = req
            .extensions()
            .clone_to_if_absent::<TargetHttpVersion>(conn.extensions())
            .map(|target| target.0);

        if let Some(version) = resolved_target {
            tracing::trace!(
                "setting request version to {:?} based on configured TargetHttpVersion (was: {:?})",
                version,
                req.version(),
            );
            adapt_request_version(&mut req, version)?;
        } else if let Some(version) = self.default_http_version {
            tracing::trace!(
                "setting request version to {:?} based on configured default http version (was: {:?})",
                version,
                req.version(),
            );
            adapt_request_version(&mut req, version)?;
            // Remember the chosen version on the connection so later layers see it.
            conn.extensions().insert(TargetHttpVersion(version));
        } else {
            tracing::trace!(
                "no TargetHttpVersion or default http version configured, leaving request version {:?}",
                req.version(),
            );
        }

        Ok(EstablishedClientConnection { input: req, conn })
    }
}
/// [`Layer`] producing a [`RequestVersionAdapter`] around an inner connector.
///
/// The optional default http version is copied into every service it creates.
#[derive(Clone, Debug, Default)]
pub struct RequestVersionAdapterLayer {
    /// Fallback version handed to each created `RequestVersionAdapter`.
    default_http_version: Option<Version>,
}
impl RequestVersionAdapterLayer {
    /// Create a new [`RequestVersionAdapterLayer`] with no default http version.
    #[must_use]
    pub fn new() -> Self {
        // The derived `Default` leaves `default_http_version` as `None`.
        Self::default()
    }

    generate_set_and_with! {
        /// Set the http version that created services adapt requests to when
        /// no [`TargetHttpVersion`] is resolved for them.
        pub fn default_version(mut self, version: Option<Version>) -> Self {
            self.default_http_version = version;
            self
        }
    }
}
impl<S> Layer<S> for RequestVersionAdapterLayer {
    type Service = RequestVersionAdapter<S>;

    /// Wrap `inner` in a [`RequestVersionAdapter`] carrying this layer's
    /// default http version.
    fn layer(&self, inner: S) -> Self::Service {
        let default_http_version = self.default_http_version;
        RequestVersionAdapter {
            default_http_version,
            inner,
        }
    }
}
/// Rewrite `request` so it is valid for `target_version`.
///
/// A request already at `target_version` is returned untouched; no
/// normalization is applied in that case. Otherwise:
/// - HTTP/1.x -> HTTP/2+ translates a websocket `Upgrade` into an Extended CONNECT.
/// - HTTP/2+ -> HTTP/1.x translates a websocket Extended CONNECT into an `Upgrade`.
///
/// The version is then updated, and the request is normalized for it via
/// [`ensure_valid_request_for_version`].
///
/// # Errors
///
/// Fails when an upgrade / Extended CONNECT names a protocol other than
/// websocket, or when the version-specific normalization fails.
pub fn adapt_request_version<Body>(
    request: &mut Request<Body>,
    target_version: Version,
) -> Result<(), BoxError> {
    let request_version = request.version();
    if request_version == target_version {
        tracing::trace!(
            ?target_version,
            "request version already satisfied, skipping it"
        );
        return Ok(());
    }

    tracing::trace!(
        ?request_version,
        ?target_version,
        "changing request version"
    );

    let from_h1 = request_version <= Version::HTTP_11;
    let to_h1 = target_version <= Version::HTTP_11;
    if from_h1 && !to_h1 {
        translate_request_upgrade(request)?;
    } else if !from_h1 && to_h1 {
        translate_request_downgrade(request)?;
    }

    *request.version_mut() = target_version;
    ensure_valid_request_for_version(request)
}
/// Normalize `request` for the HTTP version it currently carries.
///
/// HTTP/1.x requests go through [`ensure_valid_h1_request`]; HTTP/2 and
/// HTTP/3 requests go through [`ensure_valid_h2_or_h3_request`].
///
/// # Errors
///
/// Propagates the error of the selected normalization function.
pub fn ensure_valid_request_for_version<Body>(request: &mut Request<Body>) -> Result<(), BoxError> {
    let version = request.version();
    if version > Version::HTTP_11 {
        ensure_valid_h2_or_h3_request(request)
    } else {
        ensure_valid_h1_request(request)
    }
}
/// Normalize an HTTP/1.x request.
///
/// First it adds a `Host` header derived from the request authority when one
/// is missing. Then it merges multiple `Cookie` headers into a single one.
///
/// # Errors
///
/// Fails when no authority can be resolved for the `Host` header or when the
/// merged cookie value is not a valid header value.
pub fn ensure_valid_h1_request<Body>(request: &mut Request<Body>) -> Result<(), BoxError> {
    ensure_h1_host_header(request)?;
    merge_cookie_headers_for_http1(request)
}
/// Normalize an HTTP/2 or HTTP/3 request.
///
/// The resolved authority (and scheme) is written into the URI when it has
/// no host. Headers rejected by `remove_illegal_h2_request_headers` are then
/// stripped.
///
/// # Errors
///
/// Fails when the URI has no host and no authority can be resolved.
pub fn ensure_valid_h2_or_h3_request<Body>(request: &mut Request<Body>) -> Result<(), BoxError> {
    ensure_h2_or_h3_uri_authority(request)?;
    let headers = request.headers_mut();
    remove_illegal_h2_request_headers(headers);
    Ok(())
}
/// Whether `protocol` names websocket, compared ASCII case-insensitively.
pub(crate) fn is_websocket_protocol(protocol: &Protocol) -> bool {
    const WEBSOCKET: &str = "websocket";
    WEBSOCKET.eq_ignore_ascii_case(protocol.as_str())
}
/// Determine which protocol, if any, `request` asks to switch to.
///
/// - For a `CONNECT` request this is the [`Protocol`] request extension
///   (the Extended CONNECT `:protocol`), when present.
/// - Otherwise it is the protocol named in the `Upgrade` header. This only
///   applies when `Connection` carries the `upgrade` option; an `Upgrade`
///   header without it is just an advertisement and yields `None`.
///
/// `Upgrade` is a comma-separated list of protocols (RFC 9110 §7.8). When
/// `websocket` is among them it is returned; otherwise the first listed token
/// is returned. Empty list elements are skipped. `None` is returned when no
/// token remains or when the header value is not valid UTF-8.
pub(crate) fn request_connect_protocol<Body>(request: &Request<Body>) -> Option<Protocol> {
    if request.method() == Method::CONNECT
        && let Some(protocol) = request.extensions().get_ref::<Protocol>()
    {
        return Some(protocol.clone());
    }

    let is_genuine_upgrade = request
        .headers()
        .typed_get::<Connection>()
        .is_some_and(|connection| connection.contains_upgrade());
    if !is_genuine_upgrade {
        return None;
    }

    let upgrade = request.headers().typed_get::<Upgrade>()?;
    let raw = std::str::from_utf8(upgrade.as_bytes()).ok()?;

    // Treat the header as a list so e.g. `websocket, foo` does not become the
    // (invalid) single protocol token "websocket, foo".
    let mut tokens = raw
        .split(',')
        .map(str::trim)
        .filter(|token| !token.is_empty());
    let first = tokens.next()?;
    let chosen = if first.eq_ignore_ascii_case("websocket") {
        first
    } else {
        // Websocket is the only protocol that can be translated across
        // versions, so prefer it whenever the client offers it.
        tokens
            .find(|token| token.eq_ignore_ascii_case("websocket"))
            .unwrap_or(first)
    };
    Some(Protocol::from(chosen))
}
/// Translate an HTTP/1 request into its HTTP/2+ form.
///
/// A websocket upgrade becomes an Extended CONNECT: the method is set to
/// `CONNECT` and a `websocket` [`Protocol`] extension is inserted. Requests
/// without an upgrade are left as-is.
///
/// # Errors
///
/// Fails when the request upgrades to a protocol other than websocket.
fn translate_request_upgrade<Body>(request: &mut Request<Body>) -> Result<(), BoxError> {
    let Some(protocol) = request_connect_protocol(request) else {
        return Ok(());
    };

    if !is_websocket_protocol(&protocol) {
        return Err(BoxError::from(format!(
            "cannot translate HTTP/1 `Upgrade: {}` into an HTTP/2+ Extended CONNECT: only websocket is supported",
            protocol.as_str(),
        )));
    }

    tracing::trace!("translating h1 websocket upgrade into h2/h3 extended CONNECT");
    *request.method_mut() = Method::CONNECT;
    request
        .extensions()
        .insert(Protocol::from_static("websocket"));
    Ok(())
}
/// Translate an HTTP/2+ request into its HTTP/1 form.
///
/// A websocket Extended CONNECT becomes a `GET` carrying `Upgrade: websocket`
/// and `Connection: upgrade`. A random `Sec-WebSocket-Key` and version 13 are
/// added unless those headers are already present. Requests without an
/// Extended CONNECT protocol (including plain `CONNECT`) are left as-is.
///
/// # Errors
///
/// Fails when the Extended CONNECT names a protocol other than websocket.
fn translate_request_downgrade<Body>(request: &mut Request<Body>) -> Result<(), BoxError> {
    let Some(protocol) = request_connect_protocol(request) else {
        return Ok(());
    };

    if !is_websocket_protocol(&protocol) {
        return Err(BoxError::from(format!(
            "cannot translate an HTTP/2+ Extended CONNECT `:protocol: {}` request to HTTP/1: only websocket is supported",
            protocol.as_str(),
        )));
    }

    tracing::trace!("translating h2/h3 extended CONNECT websocket into h1 upgrade");
    *request.method_mut() = Method::GET;

    let headers = request.headers_mut();
    headers.typed_insert(Upgrade::websocket());
    headers.typed_insert(Connection::upgrade());

    // Keep caller-provided handshake values; only fill in what is missing.
    let has_key = headers.contains_key(SEC_WEBSOCKET_KEY);
    if !has_key {
        headers.typed_insert(SecWebSocketKey::random());
    }
    let has_version = headers.contains_key(SEC_WEBSOCKET_VERSION);
    if !has_version {
        headers.typed_insert(SecWebSocketVersion::V13);
    }
    Ok(())
}
/// Ensure an HTTP/1 request carries a `Host` header.
///
/// An existing `Host` header is kept. Otherwise one is derived from the
/// request authority, with the port dropped when it is the default for the
/// request's protocol.
///
/// # Errors
///
/// Fails when `Host` is missing and no authority can be resolved.
pub fn ensure_h1_host_header<Body>(request: &mut Request<Body>) -> Result<(), BoxError> {
    let has_host = request.headers().contains_key(HOST);
    if has_host {
        return Ok(());
    }

    let resolved = request
        .authority()
        .context("ensure h1 Host header: request has no resolvable authority")?;
    let scheme = request.protocol().cloned();
    let authority = resolved.without_default_port_for(scheme.as_ref());

    tracing::trace!("adding Host header {authority} derived from request authority");
    let host = Host::from(authority);
    request.headers_mut().typed_insert(host);
    Ok(())
}
/// Ensure an HTTP/2 or HTTP/3 request URI carries its authority.
///
/// A URI that already has a host is kept. Otherwise the resolved request
/// authority is written into it, with the default port for the request's
/// protocol dropped. The scheme is also set, falling back to `http` when no
/// protocol is known.
///
/// # Errors
///
/// Fails when the URI has no host and no authority can be resolved.
pub fn ensure_h2_or_h3_uri_authority<Body>(request: &mut Request<Body>) -> Result<(), BoxError> {
    if request.uri().host().is_none() {
        let resolved = request
            .authority()
            .context("ensure h2 URI authority: request has no resolvable authority")?;
        let scheme = request.protocol().cloned();
        let authority = resolved.without_default_port_for(scheme.as_ref());

        tracing::trace!("materializing authority {authority} and scheme into request URI");
        let target = request.uri_mut();
        target.set_scheme(scheme.unwrap_or(Scheme::HTTP));
        target.set_host(authority.host);
        target.set_port(authority.port);
    }
    Ok(())
}
/// Merge multiple `Cookie` headers into one, as HTTP/1 expects.
///
/// The values are joined in their original order with `"; "`. A request with
/// zero or one `Cookie` header is left untouched.
///
/// # Errors
///
/// Fails when the combined bytes do not form a valid header value.
fn merge_cookie_headers_for_http1<Body>(request: &mut Request<Body>) -> Result<(), BoxError> {
    let headers = request.headers_mut();
    let Entry::Occupied(cookies) = headers.entry(COOKIE) else {
        return Ok(());
    };

    let (total_bytes, value_count) = cookies
        .iter()
        .fold((0usize, 0usize), |(bytes, count), value| {
            (bytes + value.as_bytes().len(), count + 1)
        });
    if value_count < 2 {
        return Ok(());
    }

    let (name, values) = cookies.remove_entry_mult();
    // Each value after the first is preceded by a two-byte "; " separator.
    let separator_bytes = (value_count - 1) * 2;
    let mut joined = BytesMut::with_capacity(total_bytes + separator_bytes);
    for (index, value) in values.enumerate() {
        if index != 0 {
            joined.extend_from_slice(b"; ");
        }
        joined.extend_from_slice(value.as_bytes());
    }

    let merged = HeaderValue::from_maybe_shared(joined)
        .context("create new cookie header value from combined multiple values")?;
    headers.insert(name, merged);
    Ok(())
}
// Unit tests for request version adaptation. They cover:
// - header stripping on HTTP/1 -> HTTP/2;
// - websocket Upgrade <-> Extended CONNECT translation in both directions;
// - Host header / URI authority materialization;
// - Cookie header merging when targeting HTTP/1.
#[cfg(test)]
mod tests {
    use super::*;
    use rama_http_types::header::{CONNECTION, COOKIE, HOST, TRANSFER_ENCODING, UPGRADE};

    // HTTP/1.1 -> HTTP/2 must drop Host, Connection and Transfer-Encoding.
    // It must also drop the headers nominated by `Connection` (`keep-alive`,
    // `x-custom`), while keeping unrelated headers such as `x-keep`.
    #[test]
    fn test_h1_to_h2_strips_connection_specific_headers() {
        let mut req = Request::builder()
            .version(Version::HTTP_11)
            .uri("https://example.com")
            .header(HOST, "example.com")
            .header(CONNECTION, "keep-alive, x-custom")
            .header("keep-alive", "timeout=5")
            .header(TRANSFER_ENCODING, "chunked")
            .header("x-custom", "1")
            .header("x-keep", "yes")
            .body(())
            .unwrap();
        adapt_request_version(&mut req, Version::HTTP_2).unwrap();
        assert_eq!(req.version(), Version::HTTP_2);
        for illegal in [&HOST, &CONNECTION, &TRANSFER_ENCODING] {
            assert!(
                !req.headers().contains_key(illegal),
                "expected {illegal:?} to be removed"
            );
        }
        assert!(!req.headers().contains_key("x-custom"));
        assert!(!req.headers().contains_key("keep-alive"));
        assert_eq!(req.headers().get("x-keep").unwrap(), "yes");
    }

    // An HTTP/1.1 websocket upgrade becomes an HTTP/2 Extended CONNECT with a
    // `websocket` Protocol extension. Upgrade/Connection/Sec-WebSocket-Key are
    // expected to be removed; Sec-WebSocket-Version and Sec-WebSocket-Protocol
    // are expected to be kept.
    #[test]
    fn test_h1_to_h2_websocket_upgrade_becomes_extended_connect() {
        let mut req = Request::builder()
            .version(Version::HTTP_11)
            .method(Method::GET)
            .uri("https://example.com/chat")
            .header(UPGRADE, "websocket")
            .header(CONNECTION, "Upgrade")
            .header(SEC_WEBSOCKET_KEY, "dGhlIHNhbXBsZSBub25jZQ==")
            .header(SEC_WEBSOCKET_VERSION, "13")
            .header("sec-websocket-protocol", "chat")
            .body(())
            .unwrap();
        adapt_request_version(&mut req, Version::HTTP_2).unwrap();
        assert_eq!(req.method(), Method::CONNECT);
        assert_eq!(
            req.extensions().get_ref::<Protocol>().map(|p| p.as_str()),
            Some("websocket"),
        );
        assert!(!req.headers().contains_key(UPGRADE));
        assert!(!req.headers().contains_key(CONNECTION));
        assert!(!req.headers().contains_key(SEC_WEBSOCKET_KEY));
        assert_eq!(req.headers().get(SEC_WEBSOCKET_VERSION).unwrap(), "13");
        assert_eq!(req.headers().get("sec-websocket-protocol").unwrap(), "chat");
    }

    // An HTTP/2 websocket Extended CONNECT becomes an HTTP/1.1 GET upgrade.
    // A Sec-WebSocket-Key is generated and the given version is kept.
    #[test]
    fn test_h2_to_h1_websocket_connect_becomes_upgrade() {
        let mut req = Request::builder()
            .version(Version::HTTP_2)
            .method(Method::CONNECT)
            .uri("https://example.com/chat")
            .header(SEC_WEBSOCKET_VERSION, "13")
            .body(())
            .unwrap();
        req.extensions().insert(Protocol::from_static("websocket"));
        adapt_request_version(&mut req, Version::HTTP_11).unwrap();
        assert_eq!(req.method(), Method::GET);
        assert_eq!(req.version(), Version::HTTP_11);
        assert!(
            req.headers()
                .typed_get::<Upgrade>()
                .is_some_and(|u| u.is_websocket())
        );
        assert!(
            req.headers()
                .typed_get::<Connection>()
                .is_some_and(|c| c.contains_upgrade())
        );
        assert!(req.headers().contains_key(SEC_WEBSOCKET_KEY));
        assert_eq!(req.headers().get(SEC_WEBSOCKET_VERSION).unwrap(), "13");
    }

    // A plain CONNECT (no Protocol extension) keeps its method and gets no
    // websocket headers when downgraded.
    #[test]
    fn test_h2_to_h1_non_websocket_connect_untouched() {
        let mut req = Request::builder()
            .version(Version::HTTP_2)
            .method(Method::CONNECT)
            .uri("example.com:443")
            .header(HOST, "example.com:443")
            .body(())
            .unwrap();
        adapt_request_version(&mut req, Version::HTTP_11).unwrap();
        assert_eq!(req.method(), Method::CONNECT);
        assert!(!req.headers().contains_key(UPGRADE));
        assert!(!req.headers().contains_key(SEC_WEBSOCKET_KEY));
    }

    // Downgrading to HTTP/1.1 adds a Host header from the URI authority; the
    // default https port is not included.
    #[test]
    fn test_h2_to_h1_adds_host_from_authority() {
        let mut req = Request::builder()
            .version(Version::HTTP_2)
            .uri("https://example.com/path")
            .body(())
            .unwrap();
        adapt_request_version(&mut req, Version::HTTP_11).unwrap();
        assert_eq!(req.version(), Version::HTTP_11);
        assert_eq!(req.headers().get(HOST).unwrap(), "example.com");
    }

    // Upgrading an origin-form HTTP/1.1 request moves the Host header into the
    // URI authority and removes the header.
    #[test]
    fn test_h1_to_h2_materializes_uri_authority_and_strips_host() {
        let mut req = Request::builder()
            .version(Version::HTTP_11)
            .uri("/path")
            .header(HOST, "example.com")
            .body(())
            .unwrap();
        adapt_request_version(&mut req, Version::HTTP_2).unwrap();
        assert_eq!(req.version(), Version::HTTP_2);
        assert_eq!(req.uri().host_str().as_deref(), Some("example.com"));
        assert!(!req.headers().contains_key(HOST));
    }

    // HTTP/3 targets get the same websocket -> Extended CONNECT treatment as HTTP/2.
    #[test]
    fn test_h1_to_h3_websocket_upgrade_becomes_extended_connect() {
        let mut req = Request::builder()
            .version(Version::HTTP_11)
            .method(Method::GET)
            .uri("https://example.com/chat")
            .header(UPGRADE, "websocket")
            .header(CONNECTION, "Upgrade")
            .header(SEC_WEBSOCKET_KEY, "dGhlIHNhbXBsZSBub25jZQ==")
            .body(())
            .unwrap();
        adapt_request_version(&mut req, Version::HTTP_3).unwrap();
        assert_eq!(req.version(), Version::HTTP_3);
        assert_eq!(req.method(), Method::CONNECT);
        assert_eq!(
            req.extensions().get_ref::<Protocol>().map(|p| p.as_str()),
            Some("websocket"),
        );
        assert!(!req.headers().contains_key(UPGRADE));
        assert!(!req.headers().contains_key(SEC_WEBSOCKET_KEY));
    }

    // HTTP/2 -> HTTP/3 is not an h1 boundary, so multiple Cookie headers survive.
    #[test]
    fn test_h2_to_h3_only_changes_version() {
        let mut req = Request::builder()
            .version(Version::HTTP_2)
            .uri("https://example.com/path")
            .header(COOKIE, "a=1")
            .header(COOKIE, "b=2")
            .body(())
            .unwrap();
        adapt_request_version(&mut req, Version::HTTP_3).unwrap();
        assert_eq!(req.version(), Version::HTTP_3);
        assert_eq!(req.headers().get_all(COOKIE).iter().count(), 2);
    }

    // A genuine upgrade (Connection: Upgrade) to a non-websocket protocol
    // cannot be translated and must error.
    #[test]
    fn test_h1_to_h2_unsupported_upgrade_errors() {
        let mut req = Request::builder()
            .version(Version::HTTP_11)
            .method(Method::GET)
            .uri("https://example.com/")
            .header(UPGRADE, "myproto")
            .header(CONNECTION, "Upgrade")
            .body(())
            .unwrap();
        let err = adapt_request_version(&mut req, Version::HTTP_2).unwrap_err();
        assert!(
            err.to_string().contains("only websocket is supported"),
            "{err}"
        );
    }

    // An Upgrade header without `Connection: Upgrade` is only an advertisement.
    // It does not error and is expected to be removed for HTTP/2.
    #[test]
    fn test_h1_to_h2_upgrade_advertisement_is_not_a_switch() {
        let mut req = Request::builder()
            .version(Version::HTTP_11)
            .method(Method::GET)
            .uri("https://example.com/")
            .header(UPGRADE, "h2c")
            .body(())
            .unwrap();
        adapt_request_version(&mut req, Version::HTTP_2).unwrap();
        assert_eq!(req.version(), Version::HTTP_2);
        assert!(!req.headers().contains_key(UPGRADE));
    }

    // An Extended CONNECT for a non-websocket protocol cannot be downgraded.
    #[test]
    fn test_h2_to_h1_unsupported_extended_connect_errors() {
        let mut req = Request::builder()
            .version(Version::HTTP_2)
            .method(Method::CONNECT)
            .uri("https://example.com/")
            .body(())
            .unwrap();
        req.extensions()
            .insert(Protocol::from_static("connect-udp"));
        let err = adapt_request_version(&mut req, Version::HTTP_11).unwrap_err();
        assert!(
            err.to_string().contains("only websocket is supported"),
            "{err}"
        );
    }

    // Multiple HTTP/2 Cookie headers are joined with "; " for HTTP/1.1.
    #[test]
    fn test_merge_multiple_cookies_http2_to_http1() {
        let mut req = Request::builder()
            .version(Version::HTTP_2)
            .uri("https://example.com")
            .header(COOKIE, "a=1")
            .header(COOKIE, "b=2")
            .header(COOKIE, "c=3")
            .body(())
            .unwrap();
        adapt_request_version(&mut req, Version::HTTP_11).unwrap();
        let cookie_values: Vec<_> = req.headers().get_all(COOKIE).iter().collect();
        assert_eq!(
            cookie_values.len(),
            1,
            "Should have exactly one Cookie header"
        );
        assert_eq!(cookie_values[0].as_bytes(), b"a=1; b=2; c=3");
        assert_eq!(req.version(), Version::HTTP_11);
    }

    // Same cookie merging applies when coming from HTTP/3.
    #[test]
    fn test_merge_multiple_cookies_http3_to_http1() {
        let mut req = Request::builder()
            .version(Version::HTTP_3)
            .uri("https://example.com")
            .header(COOKIE, "session=abc123")
            .header(COOKIE, "token=xyz789")
            .body(())
            .unwrap();
        adapt_request_version(&mut req, Version::HTTP_11).unwrap();
        let cookie_values: Vec<_> = req.headers().get_all(COOKIE).iter().collect();
        assert_eq!(
            cookie_values.len(),
            1,
            "Should have exactly one Cookie header"
        );
        assert_eq!(cookie_values[0].as_bytes(), b"session=abc123; token=xyz789");
    }

    // A single Cookie header is left byte-for-byte unchanged.
    #[test]
    fn test_single_cookie_http2_to_http1_unchanged() {
        let mut req = Request::builder()
            .version(Version::HTTP_2)
            .uri("https://example.com")
            .header(COOKIE, "single=cookie")
            .body(())
            .unwrap();
        adapt_request_version(&mut req, Version::HTTP_11).unwrap();
        let cookie_values: Vec<_> = req.headers().get_all(COOKIE).iter().collect();
        assert_eq!(cookie_values.len(), 1);
        assert_eq!(cookie_values[0].as_bytes(), b"single=cookie");
    }

    // Cookie merging only happens when targeting HTTP/1.
    #[test]
    fn test_no_merge_http1_to_http2() {
        let mut req = Request::builder()
            .version(Version::HTTP_11)
            .uri("https://example.com")
            .header(COOKIE, "a=1")
            .header(COOKIE, "b=2")
            .body(())
            .unwrap();
        adapt_request_version(&mut req, Version::HTTP_2).unwrap();
        let cookie_values: Vec<_> = req.headers().get_all(COOKIE).iter().collect();
        assert_eq!(
            cookie_values.len(),
            2,
            "Should preserve multiple headers when converting to HTTP/2"
        );
    }

    // No Cookie header in, no Cookie header out.
    #[test]
    fn test_no_cookies_http2_to_http1() {
        let mut req = Request::builder()
            .version(Version::HTTP_2)
            .uri("https://example.com")
            .body(())
            .unwrap();
        adapt_request_version(&mut req, Version::HTTP_11).unwrap();
        let cookie_values: Vec<_> = req.headers().get_all(COOKIE).iter().collect();
        assert_eq!(cookie_values.len(), 0);
    }

    // The merged cookie keeps the original header order.
    #[test]
    fn test_merge_preserves_order() {
        let mut req = Request::builder()
            .version(Version::HTTP_2)
            .uri("https://example.com")
            .header(COOKIE, "first=1")
            .header(COOKIE, "second=2")
            .header(COOKIE, "third=3")
            .header(COOKIE, "fourth=4")
            .body(())
            .unwrap();
        adapt_request_version(&mut req, Version::HTTP_11).unwrap();
        let cookie_values: Vec<_> = req.headers().get_all(COOKIE).iter().collect();
        assert_eq!(cookie_values.len(), 1);
        assert_eq!(
            cookie_values[0].as_bytes(),
            b"first=1; second=2; third=3; fourth=4",
            "Cookie order should be preserved"
        );
    }

    // Values containing `=` and `&`, and empty cookie values, are copied verbatim.
    #[test]
    fn test_complex_cookie_values() {
        let mut req = Request::builder()
            .version(Version::HTTP_2)
            .uri("https://example.com")
            .header(COOKIE, "uaid=abc123def456")
            .header(COOKIE, "MSCC=NR")
            .header(COOKIE, "MUID=1234567890ABCDEF")
            .header(COOKIE, "VAL1=ASD=DSA&HASH=41&LV=41&V=4&LU=41")
            .header(COOKIE, "empty=")
            .body(())
            .unwrap();
        adapt_request_version(&mut req, Version::HTTP_11).unwrap();
        let cookie_values: Vec<_> = req.headers().get_all(COOKIE).iter().collect();
        assert_eq!(cookie_values.len(), 1);
        assert_eq!(
            cookie_values[0].as_bytes(),
            b"uaid=abc123def456; MSCC=NR; MUID=1234567890ABCDEF; VAL1=ASD=DSA&HASH=41&LV=41&V=4&LU=41; empty=",
        );
    }

    // Same-version adaptation (HTTP/2) is a no-op.
    #[test]
    fn test_same_version_http2_keeps_multiple_cookies() {
        let mut req = Request::builder()
            .version(Version::HTTP_2)
            .uri("https://example.com")
            .header(COOKIE, "a=1")
            .header(COOKIE, "b=2")
            .body(())
            .unwrap();
        adapt_request_version(&mut req, Version::HTTP_2).unwrap();
        let cookie_values: Vec<_> = req.headers().get_all(COOKIE).iter().collect();
        assert_eq!(cookie_values.len(), 2);
    }

    // Same-version adaptation (HTTP/1.1) skips normalization, so cookies stay unmerged.
    #[test]
    fn test_same_version_http1_is_noop() {
        let mut req = Request::builder()
            .version(Version::HTTP_11)
            .uri("https://example.com")
            .header(COOKIE, "a=1")
            .header(COOKIE, "b=2")
            .body(())
            .unwrap();
        adapt_request_version(&mut req, Version::HTTP_11).unwrap();
        let cookie_values: Vec<_> = req.headers().get_all(COOKIE).iter().collect();
        assert_eq!(cookie_values.len(), 2);
    }

    // Calling the normalizer directly on an HTTP/1.1 request merges cookies
    // and adds a Host header.
    #[test]
    fn test_ensure_valid_h1_request_normalizes() {
        let mut req = Request::builder()
            .version(Version::HTTP_11)
            .uri("https://example.com")
            .header(COOKIE, "a=1")
            .header(COOKIE, "b=2")
            .body(())
            .unwrap();
        ensure_valid_request_for_version(&mut req).unwrap();
        let cookie_values: Vec<_> = req.headers().get_all(COOKIE).iter().collect();
        assert_eq!(cookie_values.len(), 1);
        assert_eq!(cookie_values[0].as_bytes(), b"a=1; b=2");
        assert!(req.headers().contains_key(HOST));
    }
}