// volga 0.9.1
//
// Easy & Fast Web Framework for Rust
// Documentation
//! Extractors for client IP address

use crate::{
    HttpRequest,
    error::Error,
    http::{
        FromRequestParts, FromRequestRef,
        endpoints::args::{FromPayload, Payload, Source},
        request_scope::HttpRequestScope,
    },
};
use futures_util::future::{Ready, ready};
use hyper::http::{Extensions, request::Parts};
use std::fmt::Display;
use std::{net::SocketAddr, ops::Deref};

/// The client address recorded for the current request, held as a
/// [`SocketAddr`] (IP address plus port).
///
/// Dereferences to [`SocketAddr`], so its methods (e.g. `ip()`, `port()`) are
/// callable directly on a `ClientIp`; [`ClientIp::into_inner`] takes the
/// address out by value.
///
/// # Example
/// ```no_run
/// use volga::{HttpResult, ClientIp, ok};
///
/// async fn handle(ip: ClientIp) -> HttpResult {
///     ok!("Client IP: {ip}")
/// }
/// ```
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ClientIp(pub(crate) SocketAddr);

impl Display for ClientIp {
    /// Renders the wrapped [`SocketAddr`] (e.g. `127.0.0.1:8080`), handing the
    /// caller's `Formatter` through unchanged.
    #[inline]
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        Display::fmt(&self.0, f)
    }
}

impl Deref for ClientIp {
    type Target = SocketAddr;

    #[inline]
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl ClientIp {
    /// Unwraps the inner [`SocketAddr`]
    #[inline]
    pub fn into_inner(self) -> SocketAddr {
        self.0
    }
}

impl TryFrom<&Extensions> for ClientIp {
    type Error = Error;

    #[inline]
    fn try_from(extensions: &Extensions) -> Result<Self, Self::Error> {
        extensions
            .get::<HttpRequestScope>()
            .map(|s| s.client_ip)
            .ok_or_else(|| Error::server_error("Client IP: missing"))
    }
}

impl TryFrom<&Parts> for ClientIp {
    type Error = Error;

    #[inline]
    fn try_from(parts: &Parts) -> Result<Self, Self::Error> {
        ClientIp::try_from(&parts.extensions)
    }
}

/// Extracts `ClientIp` from request parts via the `TryFrom<&Parts>` conversion
impl FromRequestParts for ClientIp {
    #[inline]
    fn from_parts(parts: &Parts) -> Result<Self, Error> {
        Self::try_from(parts)
    }
}

/// Extracts `ClientIp` from a borrowed request by reading its extensions
impl FromRequestRef for ClientIp {
    #[inline]
    fn from_request(req: &HttpRequest) -> Result<Self, Error> {
        let extensions = req.extensions();
        Self::try_from(extensions)
    }
}

/// Extracts `ClientIp` from request payload.
///
/// `SOURCE` is `Source::Parts`, so only the `Payload::Parts` variant is
/// expected here; any other variant panics via `unreachable!()`.
impl FromPayload for ClientIp {
    type Future = Ready<Result<Self, Error>>;

    const SOURCE: Source = Source::Parts;

    #[inline]
    fn from_payload(payload: Payload<'_>) -> Self::Future {
        match payload {
            Payload::Parts(parts) => ready(Self::try_from(parts)),
            _ => unreachable!(),
        }
    }
}

#[cfg(test)]
mod tests {
    // Unit tests for the `ClientIp` extractor: every extraction path (extensions,
    // parts, request ref, payload) in both the present and missing-scope cases,
    // plus the wrapper's `Display`, `Deref` and `into_inner` behavior.
    use super::*;
    use crate::http::endpoints::args::{FromPayload, FromRequestParts, FromRequestRef, Payload};
    use crate::http::request_scope::HttpRequestScope;
    use crate::{HttpBody, HttpRequest};
    use hyper::{Request, http::Extensions};

    /// Builds a default request scope whose `client_ip` is `ip`.
    fn make_scope(ip: ClientIp) -> HttpRequestScope {
        HttpRequestScope {
            client_ip: ip,
            ..HttpRequestScope::default()
        }
    }

    #[tokio::test]
    async fn it_reads_from_payload() {
        let ip = ClientIp(SocketAddr::from(([0, 0, 0, 0], 8080)));
        let req = Request::get("/")
            .extension(make_scope(ip))
            .body(())
            .unwrap();

        let (parts, _) = req.into_parts();
        let client_ip = ClientIp::from_payload(Payload::Parts(&parts))
            .await
            .unwrap();

        assert_eq!(client_ip, ip);
    }

    #[tokio::test]
    async fn it_gets_err_from_payload_if_missing() {
        let (parts, _) = Request::get("/").body(()).unwrap().into_parts();

        let client_ip = ClientIp::from_payload(Payload::Parts(&parts)).await;

        assert!(client_ip.is_err());
    }

    #[test]
    fn it_gets_from_extensions() {
        let ip = ClientIp(SocketAddr::from(([0, 0, 0, 0], 8080)));
        let mut extensions = Extensions::new();
        extensions.insert(make_scope(ip));

        let client_ip = ClientIp::try_from(&extensions).unwrap();

        assert_eq!(client_ip, ip);
    }

    #[test]
    fn it_gets_err_from_extensions_if_missing() {
        let extensions = Extensions::new();

        let client_ip = ClientIp::try_from(&extensions);

        assert!(client_ip.is_err());
    }

    #[test]
    fn it_gets_from_request_parts() {
        let ip = ClientIp(SocketAddr::from(([0, 0, 0, 0], 8080)));
        let req = Request::get("/")
            .extension(make_scope(ip))
            .body(())
            .unwrap();

        let (parts, _) = req.into_parts();
        let client_ip = ClientIp::from_parts(&parts).unwrap();

        assert_eq!(client_ip, ip);
    }

    #[test]
    fn it_gets_err_from_request_parts_if_missing() {
        let (parts, _) = Request::get("/").body(()).unwrap().into_parts();

        assert!(ClientIp::from_parts(&parts).is_err());
    }

    #[test]
    fn it_gets_from_request_ref() {
        let ip = ClientIp(SocketAddr::from(([0, 0, 0, 0], 8080)));
        let req = Request::get("/")
            .extension(make_scope(ip))
            .body(HttpBody::empty())
            .unwrap();

        let (parts, body) = req.into_parts();
        let req = HttpRequest::from_parts(parts, body);

        let client_ip = ClientIp::from_request(&req).unwrap();

        assert_eq!(client_ip, ip);
    }

    #[test]
    fn it_displays_like_inner_socket_addr() {
        let addr = SocketAddr::from(([127, 0, 0, 1], 8080));
        let ip = ClientIp(addr);

        assert_eq!(ip.to_string(), "127.0.0.1:8080");
        assert_eq!(format!("{ip}"), format!("{addr}"));
    }

    #[test]
    fn it_derefs_and_unwraps_to_inner_socket_addr() {
        let addr = SocketAddr::from(([127, 0, 0, 1], 8080));
        let ip = ClientIp(addr);

        assert_eq!(*ip, addr);
        assert_eq!(ip.port(), 8080);
        assert_eq!(ip.into_inner(), addr);
    }
}