use crate::Request;
use crate::headers::forwarded::{
ForwardHeader, Via, XForwardedFor, XForwardedHost, XForwardedProto,
};
use rama_core::{Context, Layer, Service};
use rama_http_headers::HeaderMapExt;
use rama_http_headers::forwarded::Forwarded;
use rama_net::forwarded::ForwardedElement;
use std::fmt;
use std::marker::PhantomData;
/// Layer producing a [`GetForwardedHeaderService`], which decodes the forwarding-related
/// header `T` from each request and records its elements in the request [`Context`].
///
/// `T` defaults to the `Forwarded` header type.
pub struct GetForwardedHeaderLayer<T = rama_http_headers::forwarded::Forwarded> {
    // `fn() -> T` marks the header type without owning a `T`, so the layer's auto-traits
    // (`Send`/`Sync`) do not depend on `T`.
    _headers: PhantomData<fn() -> T>,
}

// Hand-written (not derived) so that no `T: Debug` bound is required: `T` is only a marker,
// and only its type name is printed (matching `GetForwardedHeaderService`'s `Debug`).
impl<T> fmt::Debug for GetForwardedHeaderLayer<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("GetForwardedHeaderLayer")
            .field("_headers", &format_args!("{}", std::any::type_name::<T>()))
            .finish()
    }
}

// Hand-written (not derived) so that no `T: Clone` bound is required: no `T` is stored.
impl<T> Clone for GetForwardedHeaderLayer<T> {
    fn clone(&self) -> Self {
        Self {
            _headers: PhantomData,
        }
    }
}
impl Default for GetForwardedHeaderLayer {
fn default() -> Self {
Self::forwarded()
}
}
impl<T> GetForwardedHeaderLayer<T> {
    /// Creates a layer for an arbitrary header type `T`.
    ///
    /// The named constructors (`forwarded`, `via`, `x_forwarded_for`, ...) cover the common
    /// header types; use this one for any other, e.g. `GetForwardedHeaderLayer::<XRealIp>::new()`.
    pub const fn new() -> Self {
        let _headers = PhantomData;
        Self { _headers }
    }
}
impl GetForwardedHeaderLayer {
#[inline]
pub fn forwarded() -> Self {
Self::new()
}
}
impl GetForwardedHeaderLayer<Via> {
#[inline]
pub fn via() -> Self {
Self::new()
}
}
impl GetForwardedHeaderLayer<XForwardedFor> {
#[inline]
pub fn x_forwarded_for() -> Self {
Self::new()
}
}
impl GetForwardedHeaderLayer<XForwardedHost> {
#[inline]
pub fn x_forwarded_host() -> Self {
Self::new()
}
}
impl GetForwardedHeaderLayer<XForwardedProto> {
#[inline]
pub fn x_forwarded_proto() -> Self {
Self::new()
}
}
impl<H, S> Layer<S> for GetForwardedHeaderLayer<H> {
type Service = GetForwardedHeaderService<S, H>;
fn layer(&self, inner: S) -> Self::Service {
Self::Service {
inner,
_headers: PhantomData,
}
}
}
/// Middleware service that decodes the forwarding-related header `T` from each request,
/// records its elements in the request [`Context`], and then calls the inner service `S`.
///
/// `T` defaults to the `Forwarded` header type.
pub struct GetForwardedHeaderService<S, T = Forwarded> {
    inner: S,
    // Marker only: no `T` is stored, and auto-traits do not depend on `T`.
    _headers: PhantomData<fn() -> T>,
}

impl<S: fmt::Debug, T> fmt::Debug for GetForwardedHeaderService<S, T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Only the header type's name is shown, since no `T` value exists.
        let header_type = std::any::type_name::<T>();
        let mut builder = f.debug_struct("GetForwardedHeaderService");
        builder.field("inner", &self.inner);
        builder.field("_headers", &format_args!("{header_type}"));
        builder.finish()
    }
}

impl<S: Clone, T> Clone for GetForwardedHeaderService<S, T> {
    fn clone(&self) -> Self {
        let inner = self.inner.clone();
        Self {
            inner,
            _headers: PhantomData,
        }
    }
}
impl<S, T> GetForwardedHeaderService<S, T> {
    /// Wraps `inner` in a service that reads an arbitrary header type `T`.
    ///
    /// The named constructors (`forwarded`, `via`, `x_forwarded_for`, ...) cover the common
    /// header types.
    pub const fn new(inner: S) -> Self {
        let _headers = PhantomData;
        Self { inner, _headers }
    }
}
impl<S> GetForwardedHeaderService<S> {
#[inline]
pub fn forwarded(inner: S) -> Self {
Self::new(inner)
}
}
impl<S> GetForwardedHeaderService<S, Via> {
#[inline]
pub fn via(inner: S) -> Self {
Self::new(inner)
}
}
impl<S> GetForwardedHeaderService<S, XForwardedFor> {
#[inline]
pub fn x_forwarded_for(inner: S) -> Self {
Self::new(inner)
}
}
impl<S> GetForwardedHeaderService<S, XForwardedHost> {
#[inline]
pub fn x_forwarded_host(inner: S) -> Self {
Self::new(inner)
}
}
impl<S> GetForwardedHeaderService<S, XForwardedProto> {
#[inline]
pub fn x_forwarded_proto(inner: S) -> Self {
Self::new(inner)
}
}
/// Appends `elements` to the [`rama_net::forwarded::Forwarded`] stored in `ctx`, inserting a
/// new one when `ctx` does not hold one yet. Leaves `ctx` untouched if `elements` is empty.
fn extend_forwarded_in_ctx<State>(
    ctx: &mut Context<State>,
    elements: impl IntoIterator<Item = ForwardedElement>,
) where
    State: Clone + Send + Sync + 'static,
{
    let mut elements = elements.into_iter();
    // `Forwarded::new` needs a first element; peel it off instead of unwrapping later.
    let Some(first) = elements.next() else {
        return;
    };
    // Look up and insert the same concrete type, so that stacked layers extend one shared
    // value, and readers of `rama_net::forwarded::Forwarded` see every element.
    match ctx.get_mut::<rama_net::forwarded::Forwarded>() {
        Some(forwarded) => {
            forwarded.extend(std::iter::once(first).chain(elements));
        }
        None => {
            let mut forwarded = rama_net::forwarded::Forwarded::new(first);
            forwarded.extend(elements);
            ctx.insert(forwarded);
        }
    }
}

impl<H, S, State, Body> Service<State, Request<Body>> for GetForwardedHeaderService<S, H>
where
    H: ForwardHeader + Send + Sync + 'static,
    S: Service<State, Request<Body>>,
    Body: Send + 'static,
    State: Clone + Send + Sync + 'static,
{
    type Response = S::Response;
    type Error = S::Error;

    /// Decodes header `H` from `req` and records its elements in the context's
    /// `rama_net::forwarded::Forwarded`, then delegates to the inner service.
    ///
    /// - If the context already holds a `Forwarded` value (e.g. inserted by another
    ///   `GetForwardedHeaderService` earlier in the stack), the new elements are appended.
    /// - Otherwise a new one is created from the elements and inserted.
    /// - A missing or undecodable header, or one yielding no elements, leaves the context
    ///   unchanged.
    fn serve(
        &self,
        mut ctx: Context<State>,
        req: Request<Body>,
    ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
        if let Some(header) = req.headers().typed_get::<H>() {
            extend_forwarded_in_ctx(&mut ctx, header);
        }
        self.inner.serve(ctx, req)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Response, StatusCode, service::web::response::IntoResponse};
    use rama_core::{Layer, error::OpaqueError, service::service_fn};
    use rama_http_headers::forwarded::{TrueClientIp, XRealIp};
    use rama_net::forwarded::{ForwardedProtocol, ForwardedVersion};
    use std::{convert::Infallible, net::IpAddr};

    /// Compile-time check that `T` implements `Service` for a unit state and body.
    fn assert_is_service<T: Service<(), Request<()>>>(_: T) {}

    async fn dummy_service_fn() -> Result<Response, OpaqueError> {
        Ok(StatusCode::OK.into_response())
    }

    #[test]
    fn test_get_forwarded_service_is_service() {
        assert_is_service(GetForwardedHeaderService::forwarded(service_fn(
            dummy_service_fn,
        )));
        assert_is_service(GetForwardedHeaderService::via(service_fn(dummy_service_fn)));
        assert_is_service(GetForwardedHeaderService::x_forwarded_for(service_fn(
            dummy_service_fn,
        )));
        assert_is_service(GetForwardedHeaderService::x_forwarded_proto(service_fn(
            dummy_service_fn,
        )));
        assert_is_service(GetForwardedHeaderService::x_forwarded_host(service_fn(
            dummy_service_fn,
        )));
        assert_is_service(GetForwardedHeaderService::<_, TrueClientIp>::new(
            service_fn(dummy_service_fn),
        ));
        assert_is_service(
            GetForwardedHeaderLayer::forwarded().into_layer(service_fn(dummy_service_fn)),
        );
        assert_is_service(GetForwardedHeaderLayer::via().into_layer(service_fn(dummy_service_fn)));
        assert_is_service(
            GetForwardedHeaderLayer::<XRealIp>::new().into_layer(service_fn(dummy_service_fn)),
        );
    }

    #[tokio::test]
    async fn test_get_forwarded_header_forwarded() {
        let service = GetForwardedHeaderLayer::forwarded().into_layer(service_fn(
            async |ctx: Context<()>, _| {
                let forwarded = ctx.get::<rama_net::forwarded::Forwarded>().unwrap();
                assert_eq!(forwarded.client_ip(), Some(IpAddr::from([12, 23, 34, 45])));
                assert_eq!(forwarded.client_proto(), Some(ForwardedProtocol::HTTP));
                Ok::<_, Infallible>(())
            },
        ));
        let req = Request::builder()
            .header("Forwarded", "for=\"12.23.34.45:5000\";proto=http")
            .body(())
            .unwrap();
        service.serve(Context::default(), req).await.unwrap();
    }

    #[tokio::test]
    async fn test_get_forwarded_header_via() {
        let service =
            GetForwardedHeaderLayer::via().into_layer(service_fn(async |ctx: Context<()>, _| {
                let forwarded = ctx.get::<rama_net::forwarded::Forwarded>().unwrap();
                assert!(forwarded.client_ip().is_none());
                assert_eq!(
                    forwarded.iter().next().unwrap().ref_forwarded_by(),
                    Some(&(IpAddr::from([12, 23, 34, 45]), 5000).into())
                );
                assert!(forwarded.client_proto().is_none());
                assert_eq!(forwarded.client_version(), Some(ForwardedVersion::HTTP_11));
                Ok::<_, Infallible>(())
            }));
        let req = Request::builder()
            .header("Via", "1.1 12.23.34.45:5000")
            .body(())
            .unwrap();
        service.serve(Context::default(), req).await.unwrap();
    }

    #[tokio::test]
    async fn test_get_forwarded_header_x_forwarded_for() {
        let service = GetForwardedHeaderLayer::x_forwarded_for().into_layer(service_fn(
            async |ctx: Context<()>, _| {
                let forwarded = ctx.get::<rama_net::forwarded::Forwarded>().unwrap();
                assert_eq!(forwarded.client_ip(), Some(IpAddr::from([12, 23, 34, 45])));
                assert!(forwarded.client_proto().is_none());
                Ok::<_, Infallible>(())
            },
        ));
        let req = Request::builder()
            .header("X-Forwarded-For", "12.23.34.45, 127.0.0.1")
            .body(())
            .unwrap();
        service.serve(Context::default(), req).await.unwrap();
    }

    /// Stacked layers must extend the `Forwarded` value already in the context rather than
    /// replace it or create a separate one: the outer `Forwarded` layer inserts one element,
    /// and the inner `X-Forwarded-For` layer appends two more.
    #[tokio::test]
    async fn test_get_forwarded_header_stacked_layers_extend_context() {
        let service = GetForwardedHeaderLayer::forwarded().into_layer(
            GetForwardedHeaderLayer::x_forwarded_for().into_layer(service_fn(
                async |ctx: Context<()>, _| {
                    let forwarded = ctx.get::<rama_net::forwarded::Forwarded>().unwrap();
                    // The first element still comes from the outer `Forwarded` header.
                    assert_eq!(forwarded.client_ip(), Some(IpAddr::from([12, 23, 34, 45])));
                    assert_eq!(forwarded.iter().count(), 3);
                    Ok::<_, Infallible>(())
                },
            )),
        );
        let req = Request::builder()
            .header("Forwarded", "for=\"12.23.34.45:5000\";proto=http")
            .header("X-Forwarded-For", "98.76.54.32, 127.0.0.1")
            .body(())
            .unwrap();
        service.serve(Context::default(), req).await.unwrap();
    }
}