rocket_client_addr/
client_addr.rs

1use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
2
3use rocket::{
4    http::Status,
5    outcome::Outcome,
6    request::{self, FromRequest, Request},
7};
8
/// The request guard used for getting an IP address from a client.
///
/// The guard is a plain value wrapper around the resolved [`IpAddr`]. It
/// implements the common comparison and hashing traits so that resolved
/// addresses can be compared directly or used as keys in hashed collections.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ClientAddr {
    /// IP address from a client.
    pub ip: IpAddr,
}
15
/// Returns `true` when `addr` falls into one of the address classes this
/// module treats as "local", i.e. not usable as a public client address.
///
/// IPv4 classes: private (`10/8`, `172.16/12`, `192.168/16`), loopback
/// (`127/8`), link-local (`169.254/16`), the broadcast address, the
/// documentation ranges (`192.0.2/24`, `198.51.100/24`, `203.0.113/24`) and
/// the unspecified address `0.0.0.0`.
///
/// IPv6 classes: loopback (`::1`), unspecified (`::`), unicast link-local
/// (`fe80::/10`), site-local (`fec0::/10`), unique-local (`fc00::/7`), the
/// documentation range (`2001:db8::/32`) and every multicast address whose
/// scope is not global.
///
/// IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`) are classified by the IPv4
/// rules applied to the embedded address, so that e.g. `::ffff:127.0.0.1` is
/// recognised as loopback.
fn is_local_ip(addr: &IpAddr) -> bool {
    match addr {
        IpAddr::V4(addr) => {
            match addr.octets() {
                // --- is_private ---
                [10, ..] => true,
                [172, b, ..] if (16..=31).contains(&b) => true,
                [192, 168, ..] => true,
                // --- is_loopback ---
                [127, ..] => true,
                // --- is_link_local ---
                [169, 254, ..] => true,
                // --- is_broadcast ---
                [255, 255, 255, 255] => true,
                // --- is_documentation ---
                [192, 0, 2, _] => true,
                [198, 51, 100, _] => true,
                [203, 0, 113, _] => true,
                // --- is_unspecified ---
                [0, 0, 0, 0] => true,
                _ => false,
            }
        },
        IpAddr::V6(addr) => {
            let segments = addr.segments();

            // --- IPv4-mapped (::ffff:a.b.c.d) ---
            // The embedded IPv4 address decides; without this a mapped
            // loopback/private address would be mistaken for a public one.
            if let [0, 0, 0, 0, 0, 0xFFFF, hi, lo] = segments {
                let [a, b] = hi.to_be_bytes();
                let [c, d] = lo.to_be_bytes();

                return is_local_ip(&IpAddr::V4(Ipv4Addr::new(a, b, c, d)));
            }

            let first = segments[0];

            // --- is_multicast ---
            if (first & 0xFF00) == 0xFF00 {
                // The low nibble is the multicast scope; 14 (0xE) means global.
                return (first & 0x000F) != 14;
            }

            match segments {
                // --- is_loopback ---
                [0, 0, 0, 0, 0, 0, 0, 1] => true,
                // --- is_unspecified ---
                [0, 0, 0, 0, 0, 0, 0, 0] => true,
                _ => {
                    // --- is_unicast_link_local ---
                    (first & 0xFFC0) == 0xFE80
                        // --- is_unicast_site_local ---
                        || (first & 0xFFC0) == 0xFEC0
                        // --- is_unique_local ---
                        || (first & 0xFE00) == 0xFC00
                        // --- is_documentation ---
                        || (first == 0x2001 && segments[1] == 0x0DB8)
                },
            }
        },
    }
}
75
76fn from_request(request: &Request<'_>) -> Option<ClientAddr> {
77    let (remote_ip, ok) = match request.remote() {
78        Some(addr) => {
79            let ip = addr.ip();
80
81            let ok = !is_local_ip(&ip);
82
83            (Some(ip), ok)
84        },
85        None => (None, false),
86    };
87
88    if ok {
89        match remote_ip {
90            Some(ip) => Some(ClientAddr {
91                ip,
92            }),
93            None => unreachable!(),
94        }
95    } else {
96        let forwarded_for_ip: Option<&str> = request.headers().get("x-forwarded-for").next(); // Only fetch the first one.
97
98        match forwarded_for_ip {
99            Some(forwarded_for_ip) => {
100                let forwarded_for_ips = forwarded_for_ip.rsplit(',');
101
102                let mut last_ip = None;
103
104                for forwarded_for_ip in forwarded_for_ips {
105                    match forwarded_for_ip.trim().parse::<IpAddr>() {
106                        Ok(ip) => {
107                            last_ip = Some(ip);
108
109                            if !is_local_ip(&ip) {
110                                break;
111                            }
112                        },
113                        Err(_) => {
114                            break;
115                        },
116                    }
117                }
118
119                match last_ip {
120                    Some(ip) => Some(ClientAddr {
121                        ip,
122                    }),
123                    None => match request.real_ip() {
124                        Some(real_ip) => Some(ClientAddr {
125                            ip: real_ip
126                        }),
127                        None => remote_ip.map(|ip| ClientAddr {
128                            ip,
129                        }),
130                    },
131                }
132            },
133            None => match request.real_ip() {
134                Some(real_ip) => Some(ClientAddr {
135                    ip: real_ip
136                }),
137                None => remote_ip.map(|ip| ClientAddr {
138                    ip,
139                }),
140            },
141        }
142    }
143}
144
145#[rocket::async_trait]
146impl<'r> FromRequest<'r> for ClientAddr {
147    type Error = ();
148
149    async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, Self::Error> {
150        match from_request(request) {
151            Some(client_addr) => Outcome::Success(client_addr),
152            None => Outcome::Forward(Status::BadRequest),
153        }
154    }
155}
156
157#[rocket::async_trait]
158impl<'r> FromRequest<'r> for &'r ClientAddr {
159    type Error = ();
160
161    async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, Self::Error> {
162        let cache: &Option<ClientAddr> = request.local_cache(|| from_request(request));
163
164        match cache.as_ref() {
165            Some(client_addr) => Outcome::Success(client_addr),
166            None => Outcome::Forward(Status::BadRequest),
167        }
168    }
169}
170
171impl ClientAddr {
172    /// Get an `Ipv4Addr` instance.
173    pub fn get_ipv4(&self) -> Option<Ipv4Addr> {
174        match &self.ip {
175            IpAddr::V4(ipv4) => Some(*ipv4),
176            IpAddr::V6(ipv6) => ipv6.to_ipv4(),
177        }
178    }
179
180    /// Get an IPv4 string.
181    pub fn get_ipv4_string(&self) -> Option<String> {
182        match &self.ip {
183            IpAddr::V4(ipv4) => Some(ipv4.to_string()),
184            IpAddr::V6(ipv6) => ipv6.to_ipv4().map(|ipv6| ipv6.to_string()),
185        }
186    }
187
188    /// Get an `Ipv6Addr` instance.
189    pub fn get_ipv6(&self) -> Ipv6Addr {
190        match &self.ip {
191            IpAddr::V4(ipv4) => ipv4.to_ipv6_mapped(),
192            IpAddr::V6(ipv6) => *ipv6,
193        }
194    }
195
196    /// Get an IPv6 string.
197    pub fn get_ipv6_string(&self) -> String {
198        match &self.ip {
199            IpAddr::V4(ipv4) => ipv4.to_ipv6_mapped().to_string(),
200            IpAddr::V6(ipv6) => ipv6.to_string(),
201        }
202    }
203}