use core::{fmt, marker};
use core::net::IpAddr;
use crate::find_next_ip_after_filter;
use crate::filter::Filter;
use crate::forwarded::parse_forwarded_for_rev;
use ohkami024::{FangProc, Fang};
use ohkami024::{FromRequest, Request, Response};
#[repr(transparent)]
#[derive(Copy, Clone)]
pub struct ClientIp<F: Filter> {
pub inner: Option<IpAddr>,
_filter: marker::PhantomData<F>
}
impl<F: Filter> ClientIp<F> {
    /// Wraps an already-resolved (or absent) address in a `ClientIp` tagged with `F`.
    #[inline(always)]
    fn new(inner: Option<IpAddr>) -> Self {
        let _filter = marker::PhantomData;
        Self { inner, _filter }
    }

    /// Consumes the wrapper and returns the address it carries, if any.
    #[inline(always)]
    pub fn into_inner(self) -> Option<IpAddr> {
        let Self { inner, .. } = self;
        inner
    }
}
impl<F: Filter> fmt::Debug for ClientIp<F> {
#[inline(always)]
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Debug::fmt(&self.inner, fmt)
}
}
impl<'req, F: Send + Sync + Filter + 'static> FromRequest<'req> for ClientIp<F> {
    /// Extraction itself never yields an error value.
    type Error = core::convert::Infallible;

    /// Builds a `ClientIp` from the request's `Forwarded` header.
    ///
    /// Returns `None` when no value of type `F` is present in the request context.
    /// Otherwise the result is always `Some(Ok(..))`: the wrapped address is whatever
    /// `find_next_ip_after_filter` selects from the parsed header, and it is `None` when the
    /// header is absent or no address is selected.
    fn from_request(req: &'req Request) -> Option<Result<Self, Self::Error>> {
        let filter = req.context.get::<F>()?;
        let resolved = req
            .headers
            .forwarded()
            .and_then(|header| find_next_ip_after_filter(parse_forwarded_for_rev(header), filter));
        Some(Ok(ClientIp::new(resolved)))
    }
}
/// Middleware that carries a filter of type `F`.
///
/// When chained (see the `Fang` implementation in this file), the resulting processor
/// replaces an unspecified `req.ip` with an address taken from the `Forwarded` header.
#[derive(Clone)]
pub struct ClientIpMiddleware<F> {
    /// Filter handed (by clone) to every processor created from this middleware.
    filter: F,
}
impl<F: Filter + Clone + 'static> ClientIpMiddleware<F> {
#[inline(always)]
pub const fn new(filter: F) -> Self {
Self {
filter
}
}
}
impl<F: Filter + Clone + 'static, I: FangProc> Fang<I> for ClientIpMiddleware<F> {
    type Proc = ClientIpMiddlewareProc<I, F>;

    /// Wraps `inner` in a `ClientIpMiddlewareProc` that owns its own clone of the filter.
    #[inline(always)]
    fn chain(&self, inner: I) -> Self::Proc {
        let filter = self.filter.clone();
        ClientIpMiddlewareProc { inner, filter }
    }
}
/// Request processor produced by chaining `ClientIpMiddleware` onto an inner processor.
pub struct ClientIpMiddlewareProc<I, F> {
    /// The next processor, invoked after the IP fix-up.
    inner: I,
    /// Filter passed to `find_next_ip_after_filter` when reading the `Forwarded` header.
    filter: F,
}
impl<I: FangProc, F: Filter + 'static> FangProc for ClientIpMiddlewareProc<I, F> {
    /// Fills in `req.ip` from the `Forwarded` header, then hands the request to `inner`.
    ///
    /// The header is consulted only when `req.ip` is unspecified; an already-specified
    /// address is never overwritten. If the header is missing or
    /// `find_next_ip_after_filter` selects nothing, `req.ip` is left as it was.
    /// The fix-up runs when `bite` is called, before the inner future is created.
    #[inline(always)]
    fn bite<'b>(&'b self, req: &'b mut Request) -> impl Future<Output = Response> {
        if req.ip.is_unspecified() {
            let forwarded_ip = req
                .headers
                .forwarded()
                .and_then(|header| {
                    find_next_ip_after_filter(parse_forwarded_for_rev(header), &self.filter)
                });
            if let Some(ip) = forwarded_ip {
                req.ip = ip;
            }
        }
        self.inner.bite(req)
    }
}