rama_http_headers/forwarded/exotic_forward_ip.rs

1use crate::Header;
2use rama_core::error::{ErrorContext, OpaqueError};
3use rama_http_types::header::{
4    CF_CONNECTING_IP, CLIENT_IP, TRUE_CLIENT_IP, X_CLIENT_IP, X_REAL_IP,
5};
6use rama_http_types::{HeaderName, HeaderValue};
7use rama_macros::paste;
8use rama_net::forwarded::{ForwardedElement, NodeId};
9use std::fmt;
10use std::net::{IpAddr, Ipv6Addr};
11
/// The address carried by one of the single-IP forwarding headers in this file:
/// an IP address plus an optional port.
#[derive(Clone, Debug, Eq, PartialEq)]
struct ClientAddr {
    /// Client IP address (IPv4 or IPv6).
    ip: IpAddr,
    /// Client port, present only when the header value included one.
    port: Option<u16>,
}
17
18impl fmt::Display for ClientAddr {
19    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
20        match &self.port {
21            Some(port) => match &self.ip {
22                IpAddr::V6(ip) => write!(f, "[{ip}]:{port}"),
23                IpAddr::V4(ip) => write!(f, "{ip}:{port}"),
24            },
25            None => self.ip.fmt(f),
26        }
27    }
28}
29
30impl std::str::FromStr for ClientAddr {
31    type Err = OpaqueError;
32
33    fn from_str(s: &str) -> Result<Self, Self::Err> {
34        if let Ok(ip) = s.parse() {
35            // first try host alone, as it is most common,
36            // and also prevents IPv6 to be seen by default with port
37            return Ok(ClientAddr { ip, port: None });
38        }
39
40        let (s, port) = try_to_split_num_port_from_str(s);
41        let ip = try_to_parse_str_to_ip(s).context("parse forwarded ip")?;
42
43        match ip {
44            IpAddr::V6(_) if port.is_some() && !s.starts_with('[') => Err(
45                OpaqueError::from_display("missing brackets for IPv6 address with port"),
46            ),
47            _ => Ok(ClientAddr { ip, port }),
48        }
49    }
50}
51
/// Parses `value` as an IP address, also accepting the bracketed IPv6 form (`[::1]`).
///
/// If `value` starts with `[` or ends with `]`, it must be fully bracketed and
/// contain an IPv6 address. Otherwise it is parsed as a plain IPv4/IPv6 address.
/// Returns `None` when no valid address results.
fn try_to_parse_str_to_ip(value: &str) -> Option<IpAddr> {
    let looks_bracketed = value.starts_with('[') || value.ends_with(']');
    if !looks_bracketed {
        return value.parse().ok();
    }
    let inner = value.strip_prefix('[')?.strip_suffix(']')?;
    inner.parse::<Ipv6Addr>().ok().map(IpAddr::V6)
}
62
/// Splits a trailing `:<port>` off `s`, if there is one.
///
/// Only the text after the *last* `:` is considered. It is taken as a port only
/// when it is a non-empty run of ASCII digits that fits in a `u16`. Otherwise `s`
/// is returned unchanged with no port. This covers, for example, a hex final
/// IPv6 group (`::abcd`), a missing port (`1.2.3.4:`) and malformed ports.
fn try_to_split_num_port_from_str(s: &str) -> (&str, Option<u16>) {
    s.rsplit_once(':')
        // `u16::from_str` accepts a leading `+`, which is not valid port syntax,
        // so insist on plain digits before parsing.
        .filter(|(_, port)| !port.is_empty() && port.bytes().all(|b| b.is_ascii_digit()))
        .and_then(|(host, port)| port.parse::<u16>().ok().map(|port| (host, Some(port))))
        .unwrap_or((s, None))
}
73
/// Generates one header type per entry, for headers that carry a single client
/// IP address with an optional port (e.g. `X-Real-IP: 203.0.113.195`).
///
/// Each entry is written as:
///
/// ```text
/// #[doc = "summary"]
/// #[header = HEADER_NAME_CONST]
/// #[extra_attribute]*
/// pub struct Name;
/// ```
///
/// For each entry this expands to:
/// - `pub struct Name(ClientAddr)`, carrying the given doc and extra attributes;
/// - a `Header` impl that decodes only the *first* header value and encodes the
///   address in `ClientAddr`'s `Display` form;
/// - a `ForwardHeader` impl that builds the header from the `for` node of the
///   first `ForwardedElement`;
/// - an `IntoIterator` impl (through a generated `NameIterator`) yielding exactly
///   one `for` `ForwardedElement`.
macro_rules! exotic_forward_ip_headers {
    (
        $(
            #[doc = $desc:literal]
            #[header = $header:ident]
            $(#[$outer:meta])*
            pub struct $name:ident;
        )+
    ) => {
        $(
            #[derive(Debug, Clone, PartialEq, Eq)]
            #[doc = $desc]
            $(#[$outer])*
            pub struct $name(ClientAddr);

            impl Header for $name {
                fn name() -> &'static HeaderName {
                    &$header
                }

                fn decode<'i, I: Iterator<Item = &'i HeaderValue>>(
                    values: &mut I,
                ) -> Result<Self, crate::Error> {
                    // Only the first value is inspected. A missing value, one that
                    // cannot be viewed as `&str`, or one that does not parse as a
                    // `ClientAddr` is reported as invalid.
                    values
                        .next()
                        .and_then(|value| value.to_str().ok())
                        .and_then(|s| s.parse::<ClientAddr>().ok())
                        .map($name)
                        .ok_or_else(crate::Error::invalid)
                }

                fn encode<E: Extend<HeaderValue>>(&self, values: &mut E) {
                    // The Display form of an IP address plus a `u16` port contains
                    // only ASCII hex digits, `.`, `:`, `[` and `]`. All of these
                    // are valid header-value bytes.
                    let value = HeaderValue::from_str(&self.0.to_string())
                        .expect("client address display is always a valid header value");
                    values.extend(::std::iter::once(value));
                }
            }

            impl super::ForwardHeader for $name {
                fn try_from_forwarded<'a, I>(input: I) -> Option<Self>
                where
                    I: IntoIterator<Item = &'a ForwardedElement>,
                {
                    // Only the first element is consulted. If it has no `for`
                    // node, or that node yields no IP, no header is produced.
                    let node = input.into_iter().next()?.ref_forwarded_for()?;
                    let ip = node.ip()?;
                    let port = node.port();
                    Some($name(ClientAddr { ip, port }))
                }
            }

            paste! {
                impl IntoIterator for $name {
                    type Item = ForwardedElement;
                    type IntoIter = [<$name Iterator>];

                    fn into_iter(self) -> Self::IntoIter {
                        [<$name Iterator>](Some(self.0))
                    }
                }

                #[derive(Debug, Clone)]
                #[doc = concat!("An iterator over the `", stringify!($name), "` header's elements.")]
                pub struct [<$name Iterator>](Option<ClientAddr>);

                impl Iterator for [<$name Iterator>] {
                    type Item = ForwardedElement;

                    fn next(&mut self) -> Option<Self::Item> {
                        // `take` yields the single address once, then leaves `None`
                        // so the iterator stays exhausted.
                        self.0.take().map(|addr| {
                            let node: NodeId = (addr.ip, addr.port).into();
                            ForwardedElement::forwarded_for(node)
                        })
                    }

                    fn size_hint(&self) -> (usize, Option<usize>) {
                        let remaining = usize::from(self.0.is_some());
                        (remaining, Some(remaining))
                    }
                }

                impl ExactSizeIterator for [<$name Iterator>] {}

                impl ::std::iter::FusedIterator for [<$name Iterator>] {}
            }
        )+
    };
}
154
155exotic_forward_ip_headers! {
156    #[doc = "CF-Connecting-IP provides the client IP address connecting to Cloudflare to the origin web server."]
157    #[header = CF_CONNECTING_IP]
158    pub struct CFConnectingIp;
159
160    #[doc = "True-Client-IP provides the original client IP address to the origin web server (Cloudflare Enterprise)."]
161    #[header = TRUE_CLIENT_IP]
162    pub struct TrueClientIp;
163
164    #[doc = "X-Real-Ip is used by some proxy software to set the real client Ip Address (known to them)."]
165    #[header = X_REAL_IP]
166    pub struct XRealIp;
167
168
169    #[doc = "Client-Ip is used by some proxy software to set the real client Ip Address (known to them)."]
170    #[header = CLIENT_IP]
171    pub struct ClientIp;
172
173    #[doc = "X-Client-Ip is used by some proxy software to set the real client Ip Address (known to them)."]
174    #[header = X_CLIENT_IP]
175    pub struct XClientIp;
176}
177
#[cfg(test)]
mod tests {
    use super::*;

    macro_rules! test_headers {
        ($($ty: ident),+ ; $name: ident, $input: expr, $expected: literal) => {
            #[test]
            fn $name() {
                $(
                    assert_eq!(
                        $ty::decode(
                            &mut $input
                                .into_iter()
                                .map(|s| HeaderValue::from_bytes(s.as_bytes()).unwrap())
                                .collect::<Vec<_>>()
                                .iter()
                        )
                        .unwrap(),
                        $ty($expected.parse().unwrap()),
                    );
                )+
            }
        };
    }

    macro_rules! test_header {
        ($name: ident, $input: expr, $expected: literal) => {
            test_headers!(CFConnectingIp, TrueClientIp, XRealIp, ClientIp, XClientIp; $name, $input, $expected);
        };
    }

    // Tests from the Docs
    test_header!(test1, vec!["203.0.113.195"], "203.0.113.195");
    test_header!(test2, vec!["203.0.113.195:80"], "203.0.113.195:80");
    test_header!(
        test3,
        vec!["2001:db8:85a3:8d3:1319:8a2e:370:7348"],
        "2001:db8:85a3:8d3:1319:8a2e:370:7348"
    );
    test_header!(
        test4,
        vec!["[2001:db8:85a3:8d3:1319:8a2e:370:7348]:8080"],
        "[2001:db8:85a3:8d3:1319:8a2e:370:7348]:8080"
    );

    #[test]
    fn test_bracketed_ipv6_without_port() {
        let addr: ClientAddr = "[2001:db8::1]".parse().unwrap();
        assert_eq!(
            addr,
            ClientAddr {
                ip: "2001:db8::1".parse().unwrap(),
                port: None,
            }
        );
    }

    #[test]
    fn test_invalid_client_addr() {
        for input in [
            "",
            "not-an-ip",
            "[2001:db8::1",
            "2001:db8::1]",
            "[203.0.113.195]:80",
            // IPv6 followed by a port must be bracketed
            "::ffff:203.0.113.195:80",
        ] {
            assert!(
                input.parse::<ClientAddr>().is_err(),
                "{input:?} should be rejected"
            );
        }
    }

    #[test]
    fn test_client_addr_display() {
        let with_port = ClientAddr {
            ip: "2001:db8::1".parse().unwrap(),
            port: Some(8080),
        };
        assert_eq!(with_port.to_string(), "[2001:db8::1]:8080");

        let without_port = ClientAddr {
            ip: "2001:db8::1".parse().unwrap(),
            port: None,
        };
        assert_eq!(without_port.to_string(), "2001:db8::1");
    }

    #[test]
    fn test_decode_uses_first_value_only() {
        let values = [
            HeaderValue::from_static("203.0.113.195"),
            HeaderValue::from_static("198.51.100.7"),
        ];
        assert_eq!(
            XRealIp::decode(&mut values.iter()).unwrap(),
            XRealIp("203.0.113.195".parse().unwrap()),
        );
    }

    #[test]
    fn test_decode_rejects_missing_or_invalid_value() {
        assert!(XRealIp::decode(&mut std::iter::empty::<&HeaderValue>()).is_err());

        let invalid = [HeaderValue::from_static("not-an-ip")];
        assert!(XRealIp::decode(&mut invalid.iter()).is_err());
    }

    macro_rules! symmetric_test_header {
        ($name: ident) => {
            for input in [
                $name("127.0.0.1:8080".parse().unwrap()),
                $name(
                    "[2001:db8:85a3:8d3:1319:8a2e:370:7348]:8080"
                        .parse()
                        .unwrap(),
                ),
                $name("2001:db8:85a3:8d3:1319:8a2e:370:7348".parse().unwrap()),
                $name("[2001:db8::1]".parse().unwrap()),
                $name("203.0.113.195".parse().unwrap()),
                $name("203.0.113.195:80".parse().unwrap()),
            ] {
                let mut values = Vec::new();
                input.encode(&mut values);
                assert_eq!($name::decode(&mut values.iter()).unwrap(), input);
            }
        };
    }

    #[test]
    fn test_symmetry_encode() {
        symmetric_test_header!(CFConnectingIp);
        symmetric_test_header!(TrueClientIp);
        symmetric_test_header!(XRealIp);
        symmetric_test_header!(ClientIp);
        symmetric_test_header!(XClientIp);
    }
}