use crate::{
HttpRequest,
error::Error,
http::{
FromRequestParts, FromRequestRef,
endpoints::args::{FromPayload, Payload, Source},
request_scope::HttpRequestScope,
},
};
use futures_util::future::{Ready, ready};
use hyper::http::{Extensions, request::Parts};
use std::fmt::Display;
use std::{net::SocketAddr, ops::Deref};
/// Socket address (IP *and* port) of the client that sent the current request.
///
/// A cheap `Copy` newtype around [`SocketAddr`]. The conversions in this module obtain it from
/// the `HttpRequestScope` stored in the request extensions.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ClientIp(pub(crate) SocketAddr);
/// Renders exactly like the wrapped [`SocketAddr`] (e.g. `127.0.0.1:8080`).
///
/// The caller's formatter is forwarded unchanged, so width/alignment flags still apply.
impl Display for ClientIp {
    #[inline]
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        Display::fmt(&self.0, f)
    }
}
impl Deref for ClientIp {
type Target = SocketAddr;
#[inline]
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl ClientIp {
#[inline]
pub fn into_inner(self) -> SocketAddr {
self.0
}
}
impl TryFrom<&Extensions> for ClientIp {
type Error = Error;
#[inline]
fn try_from(extensions: &Extensions) -> Result<Self, Self::Error> {
extensions
.get::<HttpRequestScope>()
.map(|s| s.client_ip)
.ok_or_else(|| Error::server_error("Client IP: missing"))
}
}
impl TryFrom<&Parts> for ClientIp {
type Error = Error;
#[inline]
fn try_from(parts: &Parts) -> Result<Self, Self::Error> {
ClientIp::try_from(&parts.extensions)
}
}
/// Request-head extractor; delegates to the `TryFrom<&Parts>` conversion.
impl FromRequestParts for ClientIp {
    #[inline]
    fn from_parts(parts: &Parts) -> Result<Self, Error> {
        Self::try_from(parts)
    }
}
/// Whole-request extractor; reads the client address from the request's extensions.
impl FromRequestRef for ClientIp {
    #[inline]
    fn from_request(req: &HttpRequest) -> Result<Self, Error> {
        Self::try_from(req.extensions())
    }
}
/// Handler-argument extractor; resolves immediately because only the request head is needed.
impl FromPayload for ClientIp {
    type Future = Ready<Result<Self, Error>>;

    // NOTE(review): presumably `Source::Parts` makes the dispatcher pass `Payload::Parts`,
    // which is why any other variant is treated as unreachable below — confirm in the caller.
    const SOURCE: Source = Source::Parts;

    #[inline]
    fn from_payload(payload: Payload<'_>) -> Self::Future {
        match payload {
            Payload::Parts(parts) => ready(Self::try_from(parts)),
            _ => unreachable!(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::http::endpoints::args::{FromPayload, FromRequestParts, FromRequestRef, Payload};
    use crate::http::request_scope::HttpRequestScope;
    use crate::{HttpBody, HttpRequest};
    use hyper::{Request, http::Extensions};

    /// Address used by every test that expects a client IP to be present.
    fn sample_ip() -> ClientIp {
        ClientIp(SocketAddr::from(([0, 0, 0, 0], 8080)))
    }

    /// Default request scope carrying `ip` as its client address.
    fn make_scope(ip: ClientIp) -> HttpRequestScope {
        HttpRequestScope {
            client_ip: ip,
            ..HttpRequestScope::default()
        }
    }

    /// `GET /` request builder, with a request scope attached only when `ip` is `Some`.
    fn builder_with(ip: Option<ClientIp>) -> hyper::http::request::Builder {
        let builder = Request::get("/");
        match ip {
            Some(ip) => builder.extension(make_scope(ip)),
            None => builder,
        }
    }

    /// Request head built from `builder_with(ip)`.
    fn parts_with(ip: Option<ClientIp>) -> Parts {
        let (parts, _) = builder_with(ip).body(()).unwrap().into_parts();
        parts
    }

    /// Full `HttpRequest` built from `builder_with(ip)` with an empty body.
    fn request_with(ip: Option<ClientIp>) -> HttpRequest {
        let (parts, body) = builder_with(ip)
            .body(HttpBody::empty())
            .unwrap()
            .into_parts();
        HttpRequest::from_parts(parts, body)
    }

    #[tokio::test]
    async fn it_reads_from_payload() {
        let ip = sample_ip();
        let parts = parts_with(Some(ip));
        let client_ip = ClientIp::from_payload(Payload::Parts(&parts))
            .await
            .unwrap();
        assert_eq!(client_ip, ip);
    }

    #[tokio::test]
    async fn it_gets_err_from_payload_if_missing() {
        let parts = parts_with(None);
        let client_ip = ClientIp::from_payload(Payload::Parts(&parts)).await;
        assert!(client_ip.is_err());
    }

    #[test]
    fn it_gets_from_extensions() {
        let ip = sample_ip();
        let mut extensions = Extensions::new();
        extensions.insert(make_scope(ip));
        let client_ip = ClientIp::try_from(&extensions).unwrap();
        assert_eq!(client_ip, ip);
    }

    #[test]
    fn it_gets_err_from_extensions_if_missing() {
        let extensions = Extensions::new();
        let client_ip = ClientIp::try_from(&extensions);
        assert!(client_ip.is_err());
    }

    #[test]
    fn it_gets_from_request_parts() {
        let ip = sample_ip();
        let parts = parts_with(Some(ip));
        let client_ip = ClientIp::from_parts(&parts).unwrap();
        assert_eq!(client_ip, ip);
    }

    #[test]
    fn it_gets_err_from_request_parts_if_missing() {
        let parts = parts_with(None);
        assert!(ClientIp::from_parts(&parts).is_err());
    }

    #[test]
    fn it_gets_from_request_ref() {
        let ip = sample_ip();
        let req = request_with(Some(ip));
        let client_ip = ClientIp::from_request(&req).unwrap();
        assert_eq!(client_ip, ip);
    }

    #[test]
    fn it_gets_err_from_request_ref_if_missing() {
        let req = request_with(None);
        assert!(ClientIp::from_request(&req).is_err());
    }

    #[test]
    fn it_displays_derefs_and_unwraps_like_socket_addr() {
        let ip = sample_ip();
        assert_eq!(ip.to_string(), "0.0.0.0:8080");
        assert_eq!(ip.port(), 8080);
        assert_eq!(ip.into_inner(), SocketAddr::from(([0, 0, 0, 0], 8080)));
    }
}