1use std::net::IpAddr;
2
3use axum::extract::FromRequestParts;
4use http::request::Parts;
5
6use crate::error::Error;
7
/// The IP address of the client that sent the current request.
///
/// This is a thin newtype around [`IpAddr`]. It is read by its
/// `FromRequestParts` extractor from the request extensions, so some middleware
/// must insert it there first. The extractor's error message names
/// `ClientIpLayer` as that middleware; confirm against the layer's definition.
///
/// The derives mirror those of [`IpAddr`]. This lets values be compared with
/// `==` and used as keys in hashed collections, such as per-client rate limiting
/// maps.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ClientIp(pub IpAddr);
22
23impl<S: Send + Sync> FromRequestParts<S> for ClientIp {
24 type Rejection = Error;
25
26 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
33 parts.extensions.get::<ClientIp>().copied().ok_or_else(|| {
34 Error::internal("ClientIp not found in request extensions — is ClientIpLayer applied?")
35 })
36 }
37}
38
#[cfg(test)]
mod tests {
    // `Parts` and `IpAddr` both arrive through this glob from the parent
    // module's imports, so no separate `use http::request::Parts;` is needed.
    use super::*;

    /// Builds request `Parts` whose extensions already contain `ClientIp(ip)`,
    /// standing in for whatever middleware normally inserts it.
    fn parts_with_client_ip(ip: IpAddr) -> Parts {
        let mut req = http::Request::builder().body(()).unwrap();
        req.extensions_mut().insert(ClientIp(ip));
        req.into_parts().0
    }

    /// Builds request `Parts` with no `ClientIp` extension at all.
    fn parts_without_client_ip() -> Parts {
        let req = http::Request::builder().body(()).unwrap();
        req.into_parts().0
    }

    #[tokio::test]
    async fn extracts_client_ip_from_extensions() {
        let ip: IpAddr = "1.2.3.4".parse().unwrap();
        let mut parts = parts_with_client_ip(ip);
        let ClientIp(extracted) = ClientIp::from_request_parts(&mut parts, &())
            .await
            .expect("ClientIp was inserted into extensions, so extraction must succeed");
        assert_eq!(extracted, ip);
    }

    #[tokio::test]
    async fn extracts_ipv6_client_ip_from_extensions() {
        let ip: IpAddr = "2001:db8::1".parse().unwrap();
        let mut parts = parts_with_client_ip(ip);
        let ClientIp(extracted) = ClientIp::from_request_parts(&mut parts, &())
            .await
            .expect("IPv6 ClientIp was inserted into extensions, so extraction must succeed");
        assert_eq!(extracted, ip);
    }

    #[tokio::test]
    async fn extraction_leaves_extension_in_place() {
        // The extractor reads via `extensions.get`, which borrows rather than
        // removes, so a second extraction from the same parts must also work.
        let ip: IpAddr = "10.0.0.7".parse().unwrap();
        let mut parts = parts_with_client_ip(ip);
        for _ in 0..2 {
            let ClientIp(extracted) = ClientIp::from_request_parts(&mut parts, &())
                .await
                .expect("ClientIp should remain available after a previous extraction");
            assert_eq!(extracted, ip);
        }
    }

    #[tokio::test]
    async fn returns_error_when_missing() {
        let mut parts = parts_without_client_ip();
        let result = ClientIp::from_request_parts(&mut parts, &()).await;
        assert!(result.is_err());
    }
}