use rama_core::{
Context, Service,
context::{self, RequestContextExt},
error::{BoxError, ErrorContext, OpaqueError},
inspect::RequestInspector,
};
use rama_http_headers::{HeaderMapExt, Host};
use rama_http_types::{
Method, Request, Response, Version,
dep::{http::uri::PathAndQuery, http_body},
header::{CONNECTION, HOST, KEEP_ALIVE, PROXY_CONNECTION, TRANSFER_ENCODING, UPGRADE},
};
use rama_net::{address::ProxyAddress, http::RequestContext};
use std::fmt;
use tokio::sync::Mutex;
pub(super) enum SendRequest<Body> {
Http1(Mutex<rama_http_core::client::conn::http1::SendRequest<Body>>),
Http2(rama_http_core::client::conn::http2::SendRequest<Body>),
}
impl<Body: fmt::Debug> fmt::Debug for SendRequest<Body> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let mut f = f.debug_tuple("SendRequest");
match self {
SendRequest::Http1(send_request) => f.field(send_request).finish(),
SendRequest::Http2(send_request) => f.field(send_request).finish(),
}
}
}
pub struct HttpClientService<Body, I = ()> {
pub(super) sender: SendRequest<Body>,
pub(super) http_req_inspector: I,
}
impl<State, BodyIn, BodyOut, I> Service<State, Request<BodyIn>> for HttpClientService<BodyOut, I>
where
State: Clone + Send + Sync + 'static,
BodyIn: Send + 'static,
BodyOut: http_body::Body<Data: Send + 'static, Error: Into<BoxError>> + Unpin + Send + 'static,
I: RequestInspector<
State,
Request<BodyIn>,
Error: Into<BoxError>,
RequestOut = Request<BodyOut>,
>,
{
type Response = Response;
type Error = BoxError;
async fn serve(
&self,
ctx: Context<State>,
mut req: Request<BodyIn>,
) -> Result<Self::Response, Self::Error> {
let original_http_version = req.version();
match self.sender {
SendRequest::Http1(_) => match original_http_version {
Version::HTTP_10 | Version::HTTP_11 => {
tracing::trace!(
?original_http_version,
"request version is already h1 compatible, it will remain unchanged",
);
}
_ => {
tracing::debug!(
?original_http_version,
new_http_version = ?Version::HTTP_11,
"modify request version to compatible h1 connection version",
);
*req.version_mut() = Version::HTTP_11;
}
},
SendRequest::Http2(_) => match original_http_version {
Version::HTTP_2 => {
tracing::trace!(
?original_http_version,
"request version is already h2 compatible, it will remain unchanged",
);
}
_ => {
tracing::debug!(
?original_http_version,
new_http_version = ?Version::HTTP_2,
"modify request version to compatible h2 connection version",
);
*req.version_mut() = Version::HTTP_2;
}
},
}
let (mut ctx, req) = self
.http_req_inspector
.inspect_request(ctx, req)
.await
.map_err(Into::into)?;
let req = sanitize_client_req_header(&mut ctx, req)?;
let context::Parts { extensions, .. } = ctx.into_parts();
let mut resp = match &self.sender {
SendRequest::Http1(sender) => {
let mut sender = sender.lock().await;
sender.ready().await?;
sender.send_request(req).await
}
SendRequest::Http2(sender) => {
let mut sender = sender.clone();
sender.ready().await?;
sender.send_request(req).await
}
}?;
resp.extensions_mut()
.insert(RequestContextExt::from(extensions));
let original_resp_http_version = resp.version();
if original_resp_http_version == original_http_version {
tracing::trace!(
?original_http_version,
"response version matches original http request version, it will remain unchanged",
);
} else {
*resp.version_mut() = original_http_version;
tracing::trace!(
?original_http_version,
?original_resp_http_version,
"change the response http version into the original http request version",
);
}
Ok(resp.map(rama_http_types::Body::new))
}
}
fn sanitize_client_req_header<S, B>(
ctx: &mut Context<S>,
req: Request<B>,
) -> Result<Request<B>, BoxError> {
if req.method() == Method::CONNECT && req.uri().host().is_none() {
return Err(OpaqueError::from_display("missing host in CONNECT request").into());
}
let is_request_proxied = ctx.contains::<ProxyAddress>();
let request_ctx = ctx
.get_or_try_insert_with_ctx::<RequestContext, _>(|ctx| (ctx, &req).try_into())
.context("fetch request context")?;
Ok(match req.version() {
Version::HTTP_09 | Version::HTTP_10 | Version::HTTP_11 => {
if !is_request_proxied && req.uri().host().is_some() {
tracing::trace!(
"remove authority and scheme from non-connect direct http(~1) request"
);
let (mut parts, body) = req.into_parts();
let mut uri_parts = parts.uri.into_parts();
uri_parts.scheme = None;
uri_parts.authority = None;
if uri_parts.path_and_query.as_ref().map(|pq| pq.as_str()) == Some("/") {
uri_parts.path_and_query = Some(PathAndQuery::from_static("/"));
}
if !parts.headers.contains_key(HOST) {
if request_ctx.authority_has_default_port() {
let host = request_ctx.authority.host().clone();
tracing::trace!(
%host,
"add missing host from authority as host header"
);
parts.headers.typed_insert(Host::from(host));
} else {
let authority = request_ctx.authority.clone();
tracing::trace!(
%authority,
"add missing authority as host header"
);
parts.headers.typed_insert(Host::from(authority));
}
}
parts.uri = rama_http_types::Uri::from_parts(uri_parts)?;
Request::from_parts(parts, body)
} else if !req.headers().contains_key(HOST) {
let mut req = req;
if request_ctx.authority_has_default_port() {
let authority = request_ctx.authority.clone();
tracing::trace!(
uri = %req.uri(),
%authority,
"add host from authority as HOST header to req (was missing it)",
);
req.headers_mut().typed_insert(Host::from(authority));
} else {
let host = request_ctx.authority.host().clone();
tracing::trace!(
uri = %req.uri(),
%host,
"add authority as HOST header to req (was missing it)",
);
req.headers_mut().typed_insert(Host::from(host));
}
req
} else {
req
}
}
Version::HTTP_2 => {
let mut req = if req.uri().host().is_none() {
let request_ctx = ctx.get::<RequestContext>().ok_or_else(|| {
OpaqueError::from_display("[h2+] add scheme/host: missing RequestCtx")
.into_boxed()
})?;
tracing::trace!(
http_version = ?req.version(),
"defining authority and scheme to non-connect direct http request"
);
let (mut parts, body) = req.into_parts();
let mut uri_parts = parts.uri.into_parts();
uri_parts.scheme = Some(
request_ctx
.protocol
.as_str()
.try_into()
.context("use RequestContext.protocol as http scheme")?,
);
let authority = if request_ctx.authority_has_default_port() {
request_ctx.authority.host().to_string()
} else {
request_ctx.authority.to_string()
};
uri_parts.authority = Some(
authority
.try_into()
.context("use RequestContext.authority as http authority")?,
);
parts.uri = rama_http_types::Uri::from_parts(uri_parts)
.context("create http uri from parts")?;
Request::from_parts(parts, body)
} else {
req
};
for illegal_h2_header in [
&CONNECTION,
&TRANSFER_ENCODING,
&PROXY_CONNECTION,
&UPGRADE,
&KEEP_ALIVE,
&HOST,
] {
if let Some(header) = req.headers_mut().remove(illegal_h2_header) {
tracing::trace!(?header, "removed illegal (~http1) header from h2 request");
}
}
req
}
Version::HTTP_3 => {
tracing::debug!(
uri = %req.uri(),
"h3 request detected, but sanitize_client_req_header does not yet support this",
);
req
}
_ => {
tracing::debug!(
uri = %req.uri(),
method = ?req.method(),
"request with unknown version detected, sanitize_client_req_header cannot support this",
);
req
}
})
}