rama_net/forwarded/
node.rs

1use super::{ObfNode, ObfPort};
2use crate::address::{Authority, Domain, Host};
3use rama_core::error::{ErrorContext, OpaqueError};
4use std::{
5    fmt,
6    net::{IpAddr, Ipv6Addr, SocketAddr},
7};
8
9#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
10/// Node Identifier
11///
12/// The node identifier is one of the following:
13///
14/// - The client's IP address, with an optional port number
15/// - A token indicating that the IP address of the client is not known
16///   to the proxy server (unknown)
17/// - A generated token, allowing for tracing and debugging, while
18///   allowing the internal structure or sensitive information to be
19///   hidden
20///
21/// As specified in proposal section:
22/// <https://datatracker.ietf.org/doc/html/rfc7239#section-6>
23pub struct NodeId {
24    name: NodeName,
25    port: Option<NodePort>,
26}
27
28#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
29enum NodeName {
30    Unknown,
31    Ip(IpAddr),
32    Obf(ObfNode),
33}
34
35#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
36enum NodePort {
37    Num(u16),
38    Obf(ObfPort),
39}
40
41impl NodeId {
42    /// Try to convert a vector of bytes to a [`NodeId`].
43    pub fn try_from_bytes(vec: Vec<u8>) -> Result<Self, OpaqueError> {
44        vec.try_into()
45    }
46
47    /// Try to convert a string slice to a [`NodeId`].
48    pub fn try_from_str(s: &str) -> Result<Self, OpaqueError> {
49        s.to_owned().try_into()
50    }
51
52    #[inline]
53    /// Converts a vector of bytes to a [`NodeId`], converting invalid characters to underscore.
54    pub fn from_bytes_lossy(vec: &[u8]) -> Self {
55        let s = String::from_utf8_lossy(vec);
56        Self::from_str_lossy(&s)
57    }
58
59    /// Converts a string slice to a [`NodeId`], converting invalid characters to underscore.
60    pub fn from_str_lossy(s: &str) -> Self {
61        let s_original = s;
62
63        if s.eq_ignore_ascii_case(UNKNOWN_STR) {
64            return NodeId {
65                name: NodeName::Unknown,
66                port: None,
67            };
68        }
69
70        if let Ok(ip) = try_to_parse_str_to_ip(s) {
71            // early return to prevent stuff like `::1` to
72            // be interpreted as node { name = obf(:), port = num(1) }
73            return NodeId {
74                name: NodeName::Ip(ip),
75                port: None,
76            };
77        }
78
79        let (s, port) = try_to_split_node_port_lossy_from_str(s);
80        let name = try_to_parse_str_to_ip(s)
81            .map(NodeName::Ip)
82            .unwrap_or_else(|_| NodeName::Obf(ObfNode::from_str_lossy(s)));
83
84        match name {
85            NodeName::Ip(IpAddr::V6(_)) if port.is_some() && !s.starts_with('[') => NodeId {
86                name: NodeName::Obf(ObfNode::from_str_lossy(s_original)),
87                port: None,
88            },
89            _ => NodeId { name, port },
90        }
91    }
92
93    /// Return the [`IpAddr`] if one was defined for this [`NodeId`].
94    pub fn ip(&self) -> Option<IpAddr> {
95        match &self.name {
96            NodeName::Ip(addr) => Some(*addr),
97            NodeName::Unknown | NodeName::Obf(_) => None,
98        }
99    }
100
101    /// Return true if this [`NodeId`] has a any kind of port defined,
102    /// even if obfuscated.
103    pub fn has_any_port(&self) -> bool {
104        self.port.is_some()
105    }
106
107    /// Return the numeric port if one was defined for this [`NodeId`].
108    pub fn port(&self) -> Option<u16> {
109        if let Some(NodePort::Num(n)) = self.port {
110            Some(n)
111        } else {
112            None
113        }
114    }
115
116    /// Return the [`Authority`] if this [`NodeId`] has either
117    /// an [`IpAddr`] or [`Domain`] defined, as well as a numeric port.
118    pub fn authority(&self) -> Option<Authority> {
119        match (&self.name, self.port()) {
120            (NodeName::Ip(ip), Some(port)) => Some((*ip, port).into()),
121            // every domain is a valid node name, but not every valid node name is a valid domain!!
122            (NodeName::Obf(s), Some(port)) => s
123                .as_str()
124                .parse::<Domain>()
125                .ok()
126                .map(|domain| (domain, port).into()),
127            _ => None,
128        }
129    }
130}
131
132impl NodePort {
133    /// Converts a string slice to a [`NodePort`], converting invalid characters to underscore.
134    fn from_str_lossy(s: &str) -> Self {
135        s.parse::<u16>()
136            .map(NodePort::Num)
137            .unwrap_or_else(|_| NodePort::Obf(ObfPort::from_str_lossy(s)))
138    }
139}
140
141impl From<IpAddr> for NodeId {
142    #[inline]
143    fn from(ip: IpAddr) -> Self {
144        (ip, None).into()
145    }
146}
147
148impl From<(IpAddr, u16)> for NodeId {
149    #[inline]
150    fn from((ip, port): (IpAddr, u16)) -> Self {
151        (ip, Some(port)).into()
152    }
153}
154
155impl From<(IpAddr, Option<u16>)> for NodeId {
156    fn from((ip, port): (IpAddr, Option<u16>)) -> Self {
157        NodeId {
158            name: NodeName::Ip(ip),
159            port: port.map(NodePort::Num),
160        }
161    }
162}
163
164impl From<Domain> for NodeId {
165    #[inline]
166    fn from(domain: Domain) -> Self {
167        (domain, None).into()
168    }
169}
170
171impl From<(Domain, u16)> for NodeId {
172    #[inline]
173    fn from((domain, port): (Domain, u16)) -> Self {
174        (domain, Some(port)).into()
175    }
176}
177
178impl From<(Domain, Option<u16>)> for NodeId {
179    fn from((domain, port): (Domain, Option<u16>)) -> Self {
180        NodeId {
181            // NOTE: this assumes all domains are valid obf nodes,
182            // which should be ok given the validation rules for domains are more strict!
183            name: NodeName::Obf(ObfNode::from_inner(domain.into_inner())),
184            port: port.map(NodePort::Num),
185        }
186    }
187}
188
189impl From<Authority> for NodeId {
190    fn from(authority: Authority) -> Self {
191        let (host, port) = authority.into_parts();
192        match host {
193            Host::Name(domain) => (domain, port).into(),
194            Host::Address(ip) => (ip, port).into(),
195        }
196    }
197}
198
199impl From<SocketAddr> for NodeId {
200    fn from(addr: SocketAddr) -> Self {
201        NodeId {
202            name: NodeName::Ip(addr.ip()),
203            port: Some(NodePort::Num(addr.port())),
204        }
205    }
206}
207
208impl From<&SocketAddr> for NodeId {
209    fn from(addr: &SocketAddr) -> Self {
210        NodeId {
211            name: NodeName::Ip(addr.ip()),
212            port: Some(NodePort::Num(addr.port())),
213        }
214    }
215}
216
/// Token for a node whose identity is not known to the proxy (the
/// `unknown` nodename of RFC 7239, section 6). Parsing compares it
/// case-insensitively; `Display` always emits it in lowercase.
const UNKNOWN_STR: &str = "unknown";
218
219impl fmt::Display for NodeId {
220    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
221        match &self.name {
222            NodeName::Unknown => UNKNOWN_STR.fmt(f),
223            NodeName::Ip(ip) => match &self.port {
224                None => ip.fmt(f),
225                Some(port) => match ip {
226                    std::net::IpAddr::V4(ip) => write!(f, "{ip}:{port}"),
227                    std::net::IpAddr::V6(ip) => write!(f, "[{ip}]:{port}"),
228                },
229            },
230            NodeName::Obf(s) => match &self.port {
231                None => s.fmt(f),
232                Some(port) => write!(f, "{s}:{port}"),
233            },
234        }
235    }
236}
237
238impl fmt::Display for NodePort {
239    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
240        match self {
241            NodePort::Num(num) => num.fmt(f),
242            NodePort::Obf(s) => s.fmt(f),
243        }
244    }
245}
246
247impl std::str::FromStr for NodeId {
248    type Err = OpaqueError;
249
250    fn from_str(s: &str) -> Result<Self, Self::Err> {
251        NodeId::try_from(s)
252    }
253}
254
255impl TryFrom<String> for NodeId {
256    type Error = OpaqueError;
257
258    fn try_from(s: String) -> Result<Self, Self::Error> {
259        s.as_str().try_into()
260    }
261}
262
263impl TryFrom<&str> for NodeId {
264    type Error = OpaqueError;
265
266    fn try_from(s: &str) -> Result<Self, Self::Error> {
267        if s.eq_ignore_ascii_case(UNKNOWN_STR) {
268            return Ok(NodeId {
269                name: NodeName::Unknown,
270                port: None,
271            });
272        }
273
274        if let Ok(ip) = try_to_parse_str_to_ip(s) {
275            // early return to prevent stuff like `::1` to
276            // be interpreted as node { name = obf(:), port = num(1) }
277            return Ok(NodeId {
278                name: NodeName::Ip(ip),
279                port: None,
280            });
281        }
282
283        let (s, port) = try_to_split_node_port_from_str(s);
284        let name = try_to_parse_str_to_ip(s)
285            .map(NodeName::Ip)
286            .or_else(|_| s.parse::<ObfNode>().map(NodeName::Obf))
287            .context("parse str as Node")?;
288
289        match name {
290            NodeName::Ip(IpAddr::V6(_)) if port.is_some() && !s.starts_with('[') => Err(
291                OpaqueError::from_display("missing brackets for node IPv6 address with port"),
292            ),
293            _ => Ok(NodeId { name, port }),
294        }
295    }
296}
297
298fn try_to_parse_str_to_ip(value: &str) -> Result<IpAddr, OpaqueError> {
299    if value.starts_with('[') || value.ends_with(']') {
300        let value = value
301            .strip_prefix('[')
302            .and_then(|value| value.strip_suffix(']'))
303            .context("strip brackets from ipv6 str")?;
304        Ok(IpAddr::V6(
305            value.parse::<Ipv6Addr>().context("parse str as ipv6")?,
306        ))
307    } else {
308        value.parse::<IpAddr>().context("parse ipv4/6 str")
309    }
310}
311
312impl TryFrom<Vec<u8>> for NodeId {
313    type Error = OpaqueError;
314
315    fn try_from(bytes: Vec<u8>) -> Result<Self, Self::Error> {
316        let s = String::from_utf8(bytes).context("parse node from bytes")?;
317        s.try_into()
318    }
319}
320
321impl TryFrom<&[u8]> for NodeId {
322    type Error = OpaqueError;
323
324    fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
325        let s = std::str::from_utf8(bytes).context("parse node from bytes")?;
326        s.try_into()
327    }
328}
329
330fn try_to_split_node_port_from_str(s: &str) -> (&str, Option<NodePort>) {
331    if let Some(colon) = s.as_bytes().iter().rposition(|c| *c == b':') {
332        match s[colon + 1..].parse() {
333            Ok(port) => (&s[..colon], Some(port)),
334            Err(_) => (s, None),
335        }
336    } else {
337        (s, None)
338    }
339}
340
341fn try_to_split_node_port_lossy_from_str(s: &str) -> (&str, Option<NodePort>) {
342    if let Some(colon) = s.as_bytes().iter().rposition(|c| *c == b':') {
343        let port = NodePort::from_str_lossy(&s[colon + 1..]);
344        let s = &s[..colon];
345        (s, Some(port))
346    } else {
347        (s, None)
348    }
349}
350
351impl std::str::FromStr for NodePort {
352    type Err = OpaqueError;
353
354    fn from_str(s: &str) -> Result<Self, Self::Err> {
355        s.parse::<u16>()
356            .map(NodePort::Num)
357            .or_else(|_| s.parse::<ObfPort>().map(NodePort::Obf))
358            .context("parse str as NodePort")
359    }
360}
361
362impl serde::Serialize for NodeId {
363    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
364    where
365        S: serde::Serializer,
366    {
367        let address = self.to_string();
368        address.serialize(serializer)
369    }
370}
371
372impl<'de> serde::Deserialize<'de> for NodeId {
373    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
374    where
375        D: serde::Deserializer<'de>,
376    {
377        let s = <std::borrow::Cow<'de, str>>::deserialize(deserializer)?;
378        s.parse().map_err(serde::de::Error::custom)
379    }
380}
381
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_node_id_valid() {
        for (s, expected) in [
            (
                "unknown",
                NodeId {
                    name: NodeName::Unknown,
                    port: None,
                },
            ),
            (
                "::1",
                NodeId {
                    name: NodeName::Ip("::1".parse().unwrap()),
                    port: None,
                },
            ),
            (
                "127.0.0.1",
                NodeId {
                    name: NodeName::Ip("127.0.0.1".parse().unwrap()),
                    port: None,
                },
            ),
            (
                "192.0.2.43:47011",
                NodeId {
                    name: NodeName::Ip("192.0.2.43".parse().unwrap()),
                    port: Some(NodePort::Num(47011)),
                },
            ),
            (
                "[2001:db8:cafe::17]:47011",
                NodeId {
                    name: NodeName::Ip("2001:db8:cafe::17".parse().unwrap()),
                    port: Some(NodePort::Num(47011)),
                },
            ),
            (
                "192.0.2.43:_foo",
                NodeId {
                    name: NodeName::Ip("192.0.2.43".parse().unwrap()),
                    port: Some(NodePort::Obf(ObfPort::from_static("_foo"))),
                },
            ),
            (
                "[2001:db8:cafe::17]:_bar",
                NodeId {
                    name: NodeName::Ip("2001:db8:cafe::17".parse().unwrap()),
                    port: Some(NodePort::Obf(ObfPort::from_static("_bar"))),
                },
            ),
            (
                "foo",
                NodeId {
                    name: NodeName::Obf(ObfNode::from_static("foo")),
                    port: None,
                },
            ),
            (
                "_foo",
                NodeId {
                    name: NodeName::Obf(ObfNode::from_static("_foo")),
                    port: None,
                },
            ),
            (
                "foo:_bar",
                NodeId {
                    name: NodeName::Obf(ObfNode::from_static("foo")),
                    port: Some(NodePort::Obf(ObfPort::from_static("_bar"))),
                },
            ),
            (
                "foo:42",
                NodeId {
                    name: NodeName::Obf(ObfNode::from_static("foo")),
                    port: Some(NodePort::Num(42)),
                },
            ),
        ] {
            match s.parse::<NodeId>() {
                Err(err) => panic!("failed to parse '{s}': {err}"),
                Ok(node_id) => assert_eq!(node_id, expected, "parse: {}", s),
            }
        }
    }

    #[test]
    fn test_parse_node_id_invalid() {
        for s in [
            "",
            "@",
            "2001:db8:3333:4444:5555:6666:7777:8888:80",
            "foo:bar",
            "foo:_b+r",
            "😀",
            "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz",
        ] {
            let node_result = s.parse::<NodeId>();
            assert!(
                node_result.is_err(),
                "parse invalid: {}; parsed: {:?}",
                s,
                node_result
            );
        }
    }

    #[test]
    fn test_parse_node_id_lossy() {
        for (s, expected) in [
            (
                "",
                NodeId {
                    name: NodeName::Obf(ObfNode::from_static("_")),
                    port: None,
                },
            ),
            (
                "@",
                NodeId {
                    name: NodeName::Obf(ObfNode::from_static("_")),
                    port: None,
                },
            ),
            (
                "2001:db8:3333:4444:5555:6666:7777:8888:80",
                NodeId {
                    name: NodeName::Obf(ObfNode::from_static(
                        "2001_db8_3333_4444_5555_6666_7777_8888_80",
                    )),
                    port: None,
                },
            ),
            (
                "foo:bar",
                NodeId {
                    name: NodeName::Obf(ObfNode::from_static("foo")),
                    port: Some(NodePort::Obf(ObfPort::from_static("_bar"))),
                },
            ),
            (
                "foo:_b+r",
                NodeId {
                    name: NodeName::Obf(ObfNode::from_static("foo")),
                    port: Some(NodePort::Obf(ObfPort::from_static("_b_r"))),
                },
            ),
            (
                "😀",
                NodeId {
                    name: NodeName::Obf(ObfNode::from_static("____")),
                    port: None,
                },
            ),
            (
                "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz",
                NodeId {
                    name: NodeName::Obf(ObfNode::from_static(
                        "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuv",
                    )),
                    port: None,
                },
            ),
        ] {
            let node_id = NodeId::from_str_lossy(s);
            assert_eq!(node_id, expected, "parse str: {}", s);

            let node_id = NodeId::from_bytes_lossy(s.as_bytes());
            assert_eq!(node_id, expected, "parse bytes: {}", s);
        }
    }

    #[test]
    fn test_parse_node_id_unknown_is_case_insensitive() {
        for s in ["unknown", "UNKNOWN", "UnKnOwN"] {
            let node_id: NodeId = s.parse().unwrap();
            assert_eq!(node_id.name, NodeName::Unknown, "parse: {s}");
            assert!(!node_id.has_any_port(), "parse: {s}");
            // `Display` always normalises the token to lowercase.
            assert_eq!(node_id.to_string(), "unknown", "display: {s}");
        }
    }

    #[test]
    fn test_node_id_display_round_trip() {
        // Inputs whose canonical textual form equals the input itself.
        for s in [
            "unknown",
            "::1",
            "127.0.0.1",
            "192.0.2.43:47011",
            "[2001:db8:cafe::17]:47011",
        ] {
            let node_id: NodeId = s.parse().unwrap();
            assert_eq!(node_id.to_string(), s, "display: {s}");
        }

        // Without a port an IPv6 address is displayed unbracketed, which must
        // still parse back to the same node.
        let node_id: NodeId = "[2001:db8:cafe::17]".parse().unwrap();
        let reparsed: NodeId = node_id.to_string().parse().unwrap();
        assert_eq!(reparsed, node_id);
    }

    #[test]
    fn test_node_id_accessors() {
        let ipv4: IpAddr = "192.0.2.43".parse().unwrap();
        let ipv6: IpAddr = "2001:db8:cafe::17".parse().unwrap();
        for (s, ip, port, has_any_port) in [
            ("unknown", None, None, false),
            ("foo", None, None, false),
            ("foo:42", None, Some(42), true),
            ("foo:_bar", None, None, true),
            ("192.0.2.43", Some(ipv4), None, false),
            ("192.0.2.43:47011", Some(ipv4), Some(47011), true),
            ("192.0.2.43:_foo", Some(ipv4), None, true),
            ("[2001:db8:cafe::17]", Some(ipv6), None, false),
            ("[2001:db8:cafe::17]:47011", Some(ipv6), Some(47011), true),
        ] {
            let node_id: NodeId = s.parse().unwrap();
            assert_eq!(node_id.ip(), ip, "ip: {s}");
            assert_eq!(node_id.port(), port, "port: {s}");
            assert_eq!(node_id.has_any_port(), has_any_port, "has_any_port: {s}");
        }
    }

    #[test]
    fn test_node_id_authority() {
        let node_id: NodeId = "192.0.2.43:47011".parse().unwrap();
        assert!(node_id.authority().is_some());

        // No numeric port (or no address/domain) means no authority.
        for s in [
            "unknown",
            "192.0.2.43",
            "192.0.2.43:_foo",
            "[2001:db8:cafe::17]",
            "foo:_bar",
        ] {
            let node_id: NodeId = s.parse().unwrap();
            assert!(node_id.authority().is_none(), "authority: {s}");
        }
    }

    #[test]
    fn test_node_id_entry_points_agree() {
        let expected: NodeId = "192.0.2.43:47011".parse().unwrap();

        assert_eq!(NodeId::try_from_str("192.0.2.43:47011").unwrap(), expected);
        assert_eq!(
            NodeId::try_from_bytes(b"192.0.2.43:47011".to_vec()).unwrap(),
            expected
        );
        assert_eq!(NodeId::try_from(&b"192.0.2.43:47011"[..]).unwrap(), expected);
        assert_eq!(NodeId::from_str_lossy("192.0.2.43:47011"), expected);

        let addr: SocketAddr = "192.0.2.43:47011".parse().unwrap();
        assert_eq!(NodeId::from(addr), expected);
        assert_eq!(NodeId::from(&addr), expected);
        assert_eq!(NodeId::from((addr.ip(), addr.port())), expected);

        assert!(NodeId::try_from_bytes(vec![0xff]).is_err());
    }
}