use crate::Request;
use crate::headers::HeaderMapExt;
use crate::headers::forwarded::{
ForwardHeader, Via, XForwardedFor, XForwardedHost, XForwardedProto,
};
use rama_core::error::BoxError;
use rama_core::{Context, Layer, Service};
use rama_http_headers::forwarded::Forwarded;
use rama_net::address::Domain;
use rama_net::forwarded::{ForwardedElement, NodeId};
use rama_net::http::RequestContext;
use rama_net::stream::SocketInfo;
use std::fmt;
use std::marker::PhantomData;
/// Layer producing a [`SetForwardedHeaderService`], which appends the current
/// hop to the forwarding chain and writes it to the request as header type `T`
/// (defaults to [`Forwarded`]).
pub struct SetForwardedHeaderLayer<T = Forwarded> {
/// Node identifier passed on to the service and reported as `by` for this hop.
by_node: NodeId,
/// Marker for the header type. `PhantomData<fn() -> T>` does not own a `T`,
/// so the auto traits (`Send`/`Sync`) of the layer do not depend on `T`.
_headers: PhantomData<fn() -> T>,
}
// `T` is never formatted as a value, only its type name is printed, so no
// `T: Debug` bound is needed. This mirrors the `Debug` impl of
// `SetForwardedHeaderService`, which is likewise unbounded in `T`.
impl<T> fmt::Debug for SetForwardedHeaderLayer<T> {
    /// Formats the layer as a struct showing the `by` node id and the
    /// marker's type name (`fn() -> T`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("SetForwardedHeaderLayer")
            .field("by_node", &self.by_node)
            .field(
                "_headers",
                &format_args!("{}", std::any::type_name::<fn() -> T>()),
            )
            .finish()
    }
}
// No bound on `T` is required: `T` is only a marker held through
// `PhantomData<fn() -> T>`, which is `Clone` for every `T`. This matches the
// `Clone` impl of `SetForwardedHeaderService`, which bounds only `S`.
impl<T> Clone for SetForwardedHeaderLayer<T> {
    /// Returns a new layer with a clone of the `by` node id.
    fn clone(&self) -> Self {
        Self {
            by_node: self.by_node.clone(),
            _headers: PhantomData,
        }
    }
}
impl<T> SetForwardedHeaderLayer<T> {
    /// Builder-style setter for the node id reported as `by` for this hop.
    ///
    /// Consumes the layer and returns it, so it can be chained after a
    /// constructor.
    pub fn forward_by(mut self, node_id: impl Into<NodeId>) -> Self {
        self.set_forward_by(node_id);
        self
    }

    /// In-place setter for the node id reported as `by` for this hop.
    ///
    /// Returns `&mut Self` to allow chaining further mutations.
    pub fn set_forward_by(&mut self, node_id: impl Into<NodeId>) -> &mut Self {
        let node: NodeId = node_id.into();
        self.by_node = node;
        self
    }
}
impl<T> SetForwardedHeaderLayer<T> {
pub fn new() -> Self {
Self {
by_node: Domain::from_static("rama").into(),
_headers: PhantomData,
}
}
}
impl Default for SetForwardedHeaderLayer {
fn default() -> Self {
Self::forwarded()
}
}
impl SetForwardedHeaderLayer {
#[inline]
pub fn forwarded() -> Self {
Self::new()
}
}
impl SetForwardedHeaderLayer<Via> {
#[inline]
pub fn via() -> Self {
Self::new()
}
}
impl SetForwardedHeaderLayer<XForwardedFor> {
#[inline]
pub fn x_forwarded_for() -> Self {
Self::new()
}
}
impl SetForwardedHeaderLayer<XForwardedHost> {
#[inline]
pub fn x_forwarded_host() -> Self {
Self::new()
}
}
impl SetForwardedHeaderLayer<XForwardedProto> {
#[inline]
pub fn x_forwarded_proto() -> Self {
Self::new()
}
}
impl<H, S> Layer<S> for SetForwardedHeaderLayer<H> {
type Service = SetForwardedHeaderService<S, H>;
fn layer(&self, inner: S) -> Self::Service {
Self::Service {
inner,
by_node: self.by_node.clone(),
_headers: PhantomData,
}
}
fn into_layer(self, inner: S) -> Self::Service {
Self::Service {
inner,
by_node: self.by_node,
_headers: PhantomData,
}
}
}
/// Service that appends the current hop to the forwarding chain and inserts it
/// into the request as header type `T` (defaults to [`Forwarded`]) before
/// calling the inner service.
pub struct SetForwardedHeaderService<S, T = Forwarded> {
/// The wrapped service that receives the modified request.
inner: S,
/// Node identifier reported as `by` for this hop.
by_node: NodeId,
/// Marker for the header type. `PhantomData<fn() -> T>` does not own a `T`,
/// so the auto traits (`Send`/`Sync`) of the service do not depend on `T`.
_headers: PhantomData<fn() -> T>,
}
impl<S: fmt::Debug, T> fmt::Debug for SetForwardedHeaderService<S, T> {
    /// Formats the service as a struct showing the inner service, the `by`
    /// node id and the marker's type name (`fn() -> T`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let headers_type = std::any::type_name::<fn() -> T>();
        let mut dbg = f.debug_struct("SetForwardedHeaderService");
        dbg.field("inner", &self.inner);
        dbg.field("by_node", &self.by_node);
        dbg.field("_headers", &format_args!("{}", headers_type));
        dbg.finish()
    }
}
impl<S: Clone, T> Clone for SetForwardedHeaderService<S, T> {
fn clone(&self) -> Self {
SetForwardedHeaderService {
inner: self.inner.clone(),
by_node: self.by_node.clone(),
_headers: PhantomData,
}
}
}
impl<S, T> SetForwardedHeaderService<S, T> {
    /// Builder-style setter for the node id reported as `by` for this hop.
    ///
    /// Consumes the service and returns it, so it can be chained after a
    /// constructor.
    pub fn forward_by(mut self, node_id: impl Into<NodeId>) -> Self {
        self.set_forward_by(node_id);
        self
    }

    /// In-place setter for the node id reported as `by` for this hop.
    ///
    /// Returns `&mut Self` to allow chaining further mutations.
    pub fn set_forward_by(&mut self, node_id: impl Into<NodeId>) -> &mut Self {
        let node: NodeId = node_id.into();
        self.by_node = node;
        self
    }
}
impl<S, T> SetForwardedHeaderService<S, T> {
pub fn new(inner: S) -> Self {
Self {
inner,
by_node: Domain::from_static("rama").into(),
_headers: PhantomData,
}
}
}
impl<S> SetForwardedHeaderService<S> {
#[inline]
pub fn forwarded(inner: S) -> Self {
Self::new(inner)
}
}
impl<S> SetForwardedHeaderService<S, Via> {
#[inline]
pub fn via(inner: S) -> Self {
Self::new(inner)
}
}
impl<S> SetForwardedHeaderService<S, XForwardedFor> {
#[inline]
pub fn x_forwarded_for(inner: S) -> Self {
Self::new(inner)
}
}
impl<S> SetForwardedHeaderService<S, XForwardedHost> {
#[inline]
pub fn x_forwarded_host(inner: S) -> Self {
Self::new(inner)
}
}
impl<S> SetForwardedHeaderService<S, XForwardedProto> {
#[inline]
pub fn x_forwarded_proto(inner: S) -> Self {
Self::new(inner)
}
}
impl<S, H, State, Body> Service<State, Request<Body>> for SetForwardedHeaderService<S, H>
where
    S: Service<State, Request<Body>, Error: Into<BoxError>>,
    H: ForwardHeader + Send + Sync + 'static,
    Body: Send + 'static,
    State: Clone + Send + Sync + 'static,
{
    type Response = S::Response;
    type Error = BoxError;

    /// Adds a forwarded element for this hop to the request as header `H`,
    /// then calls the inner service.
    ///
    /// The element contains:
    /// - `by`: the configured node id;
    /// - `for`: the peer address of the [`SocketInfo`] in the context, if any;
    /// - `host`: the authority of the [`RequestContext`], obtained through
    ///   `get_or_try_insert_with_ctx` (derived from the request when absent);
    /// - `proto`: the request protocol, if it converts to a forwarded protocol.
    ///
    /// If the context already holds a `rama_net::forwarded::Forwarded` chain,
    /// the element is appended to a *copy* of it (the context's chain itself is
    /// not modified). Otherwise a new chain is started. The header is only
    /// inserted when `H::try_from_forwarded` produces a value.
    ///
    /// # Errors
    ///
    /// Returns an error if a [`RequestContext`] cannot be derived from the
    /// request. Errors of the inner service are converted into [`BoxError`].
    async fn serve(
        &self,
        mut ctx: Context<State>,
        mut req: Request<Body>,
    ) -> Result<Self::Response, Self::Error> {
        // Copy the chain collected so far, before `ctx` gets mutably borrowed
        // for the request context below.
        let forwarded: Option<rama_net::forwarded::Forwarded> = ctx.get().cloned();

        let mut forwarded_element = ForwardedElement::forwarded_by(self.by_node.clone());
        if let Some(peer_addr) = ctx.get::<SocketInfo>().map(|socket| *socket.peer_addr()) {
            forwarded_element.set_forwarded_for(peer_addr);
        }

        let request_ctx: &mut RequestContext =
            ctx.get_or_try_insert_with_ctx(|ctx| (ctx, &req).try_into())?;
        forwarded_element.set_forwarded_host(request_ctx.authority.clone());
        // Not every protocol maps to a forwarded `proto` value; skip it if not.
        if let Ok(forwarded_proto) = (&request_ctx.protocol).try_into() {
            forwarded_element.set_forwarded_proto(forwarded_proto);
        }

        // Either extend the existing chain or start a new one. The result is
        // always present, so no `Option` round-trip is needed.
        let forwarded = match forwarded {
            Some(mut chain) => {
                chain.append(forwarded_element);
                chain
            }
            None => rama_net::forwarded::Forwarded::new(forwarded_element),
        };

        if let Some(header) = H::try_from_forwarded(forwarded.iter()) {
            req.headers_mut().typed_insert(header);
        }

        self.inner.serve(ctx, req).await.map_err(Into::into)
    }
}
// Unit tests for `SetForwardedHeaderService` / `SetForwardedHeaderLayer`.
#[cfg(test)]
mod tests {
use super::*;
use crate::{
Response, StatusCode,
headers::forwarded::{TrueClientIp, XRealIp},
service::web::response::IntoResponse,
};
use rama_core::{Layer, error::OpaqueError, service::service_fn};
use std::{convert::Infallible, net::IpAddr};
// Compile-time check: only type-checks if `T` is a `Service<(), Request<()>>`.
fn assert_is_service<T: Service<(), Request<()>>>(_: T) {}
// Minimal inner service that always answers `200 OK`.
async fn dummy_service_fn() -> Result<Response, OpaqueError> {
Ok(StatusCode::OK.into_response())
}
// Every constructor (service and layer, for each header type) yields a valid service.
#[test]
fn test_set_forwarded_service_is_service() {
assert_is_service(SetForwardedHeaderService::forwarded(service_fn(
dummy_service_fn,
)));
assert_is_service(SetForwardedHeaderService::via(service_fn(dummy_service_fn)));
assert_is_service(SetForwardedHeaderService::x_forwarded_for(service_fn(
dummy_service_fn,
)));
assert_is_service(SetForwardedHeaderService::x_forwarded_proto(service_fn(
dummy_service_fn,
)));
assert_is_service(SetForwardedHeaderService::x_forwarded_host(service_fn(
dummy_service_fn,
)));
assert_is_service(SetForwardedHeaderService::<_, TrueClientIp>::new(
service_fn(dummy_service_fn),
));
assert_is_service(SetForwardedHeaderLayer::via().into_layer(service_fn(dummy_service_fn)));
assert_is_service(
SetForwardedHeaderLayer::<XRealIp>::new().into_layer(service_fn(dummy_service_fn)),
);
}
// No prior chain and no socket info: the header holds only `by`, `host` and `proto`.
#[tokio::test]
async fn test_set_forwarded_service_forwarded() {
async fn svc(request: Request<()>) -> Result<(), Infallible> {
assert_eq!(
request.headers().get("Forwarded").unwrap(),
"by=rama;host=\"example.com:80\";proto=http"
);
Ok(())
}
let service = SetForwardedHeaderService::forwarded(service_fn(svc));
let req = Request::builder().uri("example.com").body(()).unwrap();
service.serve(Context::default(), req).await.unwrap();
}
// A chain already in the context is kept, and this hop is appended after it.
#[tokio::test]
async fn test_set_forwarded_service_forwarded_with_chain() {
async fn svc(request: Request<()>) -> Result<(), Infallible> {
assert_eq!(
request.headers().get("Forwarded").unwrap(),
"for=12.23.34.45,by=rama;for=\"127.0.0.1:62345\";host=\"www.example.com:443\";proto=https",
);
Ok(())
}
let service = SetForwardedHeaderService::forwarded(service_fn(svc));
let req = Request::builder()
.uri("https://www.example.com")
.body(())
.unwrap();
let mut ctx = Context::default();
ctx.insert(rama_net::forwarded::Forwarded::new(
ForwardedElement::forwarded_for(IpAddr::from([12, 23, 34, 45])),
));
ctx.insert(SocketInfo::new(None, "127.0.0.1:62345".parse().unwrap()));
service.serve(ctx, req).await.unwrap();
}
// Same chain as above, rendered as `X-Forwarded-For` (IP addresses only, no ports).
#[tokio::test]
async fn test_set_forwarded_service_x_forwarded_for_with_chain() {
async fn svc(request: Request<()>) -> Result<(), Infallible> {
assert_eq!(
request.headers().get("X-Forwarded-For").unwrap(),
"12.23.34.45, 127.0.0.1",
);
Ok(())
}
let service = SetForwardedHeaderService::x_forwarded_for(service_fn(svc));
let req = Request::builder()
.uri("https://www.example.com")
.body(())
.unwrap();
let mut ctx = Context::default();
ctx.insert(rama_net::forwarded::Forwarded::new(
ForwardedElement::forwarded_for(IpAddr::from([12, 23, 34, 45])),
));
ctx.insert(SocketInfo::new(None, "127.0.0.1:62345".parse().unwrap()));
service.serve(ctx, req).await.unwrap();
}
// A custom `by` node id (set with `forward_by`) replaces the default `rama`.
#[tokio::test]
async fn test_set_forwarded_service_forwarded_fully_defined() {
async fn svc(request: Request<()>) -> Result<(), Infallible> {
assert_eq!(
request.headers().get("Forwarded").unwrap(),
"by=12.23.34.45;for=\"127.0.0.1:62345\";host=\"www.example.com:443\";proto=https",
);
Ok(())
}
let service = SetForwardedHeaderService::forwarded(service_fn(svc))
.forward_by(IpAddr::from([12, 23, 34, 45]));
let req = Request::builder()
.uri("https://www.example.com")
.body(())
.unwrap();
let mut ctx = Context::default();
ctx.insert(SocketInfo::new(None, "127.0.0.1:62345".parse().unwrap()));
service.serve(ctx, req).await.unwrap();
}
// NOTE(review): despite the `_with_chain` suffix, this test inserts no prior
// `Forwarded` chain into the context (only `SocketInfo`). It therefore covers
// the single-hop path with the default `by=rama` node. Confirm whether a chain
// was meant to be inserted here.
#[tokio::test]
async fn test_set_forwarded_service_forwarded_fully_defined_with_chain() {
async fn svc(request: Request<()>) -> Result<(), Infallible> {
assert_eq!(
request.headers().get("Forwarded").unwrap(),
"by=rama;for=\"127.0.0.1:62345\";host=\"www.example.com:443\";proto=https",
);
Ok(())
}
let service = SetForwardedHeaderService::forwarded(service_fn(svc));
let req = Request::builder()
.uri("https://www.example.com")
.body(())
.unwrap();
let mut ctx = Context::default();
ctx.insert(SocketInfo::new(None, "127.0.0.1:62345".parse().unwrap()));
service.serve(ctx, req).await.unwrap();
}
}