distant_net/common/
port.rs

1use std::fmt;
2use std::net::{IpAddr, SocketAddr};
3use std::ops::RangeInclusive;
4use std::str::FromStr;
5
6use derive_more::Display;
7use serde::{de, Deserialize, Serialize};
8
9/// Represents some range of ports
10#[derive(Copy, Clone, Debug, Display, PartialEq, Eq)]
11#[display(
12    fmt = "{}{}",
13    start,
14    "end.as_ref().map(|end| format!(\":{}\", end)).unwrap_or_default()"
15)]
16pub struct PortRange {
17    pub start: u16,
18    pub end: Option<u16>,
19}
20
21impl PortRange {
22    /// Represents an ephemeral port as defined using the port range of 0.
23    pub const EPHEMERAL: Self = Self {
24        start: 0,
25        end: None,
26    };
27
28    /// Creates a port range targeting a single `port`.
29    #[inline]
30    pub fn single(port: u16) -> Self {
31        Self {
32            start: port,
33            end: None,
34        }
35    }
36
37    /// Builds a collection of `SocketAddr` instances from the port range and given ip address
38    pub fn make_socket_addrs(&self, addr: impl Into<IpAddr>) -> Vec<SocketAddr> {
39        let mut socket_addrs = Vec::new();
40        let addr = addr.into();
41
42        for port in self {
43            socket_addrs.push(SocketAddr::from((addr, port)));
44        }
45
46        socket_addrs
47    }
48
49    /// Returns true if port range represents the ephemeral port.
50    #[inline]
51    pub fn is_ephemeral(&self) -> bool {
52        self == &Self::EPHEMERAL
53    }
54}
55
56impl From<u16> for PortRange {
57    fn from(port: u16) -> Self {
58        Self::single(port)
59    }
60}
61
62impl From<RangeInclusive<u16>> for PortRange {
63    fn from(r: RangeInclusive<u16>) -> Self {
64        let (start, end) = r.into_inner();
65        Self {
66            start,
67            end: Some(end),
68        }
69    }
70}
71
72impl<'a> IntoIterator for &'a PortRange {
73    type IntoIter = RangeInclusive<u16>;
74    type Item = u16;
75
76    fn into_iter(self) -> Self::IntoIter {
77        self.start..=self.end.unwrap_or(self.start)
78    }
79}
80
81impl IntoIterator for PortRange {
82    type IntoIter = RangeInclusive<u16>;
83    type Item = u16;
84
85    fn into_iter(self) -> Self::IntoIter {
86        self.start..=self.end.unwrap_or(self.start)
87    }
88}
89
90impl FromStr for PortRange {
91    type Err = std::num::ParseIntError;
92
93    /// Parses PORT into single range or PORT1:PORTN into full range
94    fn from_str(s: &str) -> Result<Self, Self::Err> {
95        match s.find(':') {
96            Some(idx) if idx + 1 < s.len() => Ok(Self {
97                start: s[..idx].parse()?,
98                end: Some(s[(idx + 1)..].parse()?),
99            }),
100            _ => Ok(Self {
101                start: s.parse()?,
102                end: None,
103            }),
104        }
105    }
106}
107
108impl Serialize for PortRange {
109    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
110    where
111        S: serde::ser::Serializer,
112    {
113        String::serialize(&self.to_string(), serializer)
114    }
115}
116
117impl<'de> Deserialize<'de> for PortRange {
118    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
119    where
120        D: serde::de::Deserializer<'de>,
121    {
122        struct PortRangeVisitor;
123        impl<'de> de::Visitor<'de> for PortRangeVisitor {
124            type Value = PortRange;
125
126            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
127                write!(formatter, "a port in the form NUMBER or START:END")
128            }
129
130            fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
131            where
132                E: de::Error,
133            {
134                FromStr::from_str(s).map_err(de::Error::custom)
135            }
136
137            fn visit_u8<E>(self, v: u8) -> Result<Self::Value, E>
138            where
139                E: de::Error,
140            {
141                Ok(PortRange {
142                    start: v as u16,
143                    end: None,
144                })
145            }
146
147            fn visit_u16<E>(self, v: u16) -> Result<Self::Value, E>
148            where
149                E: de::Error,
150            {
151                Ok(PortRange {
152                    start: v,
153                    end: None,
154                })
155            }
156
157            fn visit_u32<E>(self, v: u32) -> Result<Self::Value, E>
158            where
159                E: de::Error,
160            {
161                Ok(PortRange {
162                    start: v.try_into().map_err(de::Error::custom)?,
163                    end: None,
164                })
165            }
166
167            fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
168            where
169                E: de::Error,
170            {
171                Ok(PortRange {
172                    start: v.try_into().map_err(de::Error::custom)?,
173                    end: None,
174                })
175            }
176
177            fn visit_u128<E>(self, v: u128) -> Result<Self::Value, E>
178            where
179                E: de::Error,
180            {
181                Ok(PortRange {
182                    start: v.try_into().map_err(de::Error::custom)?,
183                    end: None,
184                })
185            }
186
187            fn visit_i8<E>(self, v: i8) -> Result<Self::Value, E>
188            where
189                E: de::Error,
190            {
191                Ok(PortRange {
192                    start: v.try_into().map_err(de::Error::custom)?,
193                    end: None,
194                })
195            }
196
197            fn visit_i16<E>(self, v: i16) -> Result<Self::Value, E>
198            where
199                E: de::Error,
200            {
201                Ok(PortRange {
202                    start: v.try_into().map_err(de::Error::custom)?,
203                    end: None,
204                })
205            }
206
207            fn visit_i32<E>(self, v: i32) -> Result<Self::Value, E>
208            where
209                E: de::Error,
210            {
211                Ok(PortRange {
212                    start: v.try_into().map_err(de::Error::custom)?,
213                    end: None,
214                })
215            }
216
217            fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
218            where
219                E: de::Error,
220            {
221                Ok(PortRange {
222                    start: v.try_into().map_err(de::Error::custom)?,
223                    end: None,
224                })
225            }
226
227            fn visit_i128<E>(self, v: i128) -> Result<Self::Value, E>
228            where
229                E: de::Error,
230            {
231                Ok(PortRange {
232                    start: v.try_into().map_err(de::Error::custom)?,
233                    end: None,
234                })
235            }
236        }
237
238        deserializer.deserialize_any(PortRangeVisitor)
239    }
240}
241
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn display_should_properly_reflect_port_range() {
        let p = PortRange {
            start: 100,
            end: None,
        };
        assert_eq!(p.to_string(), "100");

        let p = PortRange {
            start: 100,
            end: Some(200),
        };
        assert_eq!(p.to_string(), "100:200");
    }

    #[test]
    fn from_u16_should_map_to_single_port_range() {
        assert_eq!(
            PortRange::from(22u16),
            PortRange {
                start: 22,
                end: None
            }
        );
    }

    #[test]
    fn from_range_inclusive_should_map_to_port_range() {
        let p = PortRange::from(100..=200);
        assert_eq!(p.start, 100);
        assert_eq!(p.end, Some(200));
    }

    #[test]
    fn is_ephemeral_should_only_be_true_for_port_zero() {
        assert!(PortRange::EPHEMERAL.is_ephemeral());
        assert!(PortRange::single(0).is_ephemeral());
        assert!(!PortRange::single(8080).is_ephemeral());
        assert!(!PortRange {
            start: 0,
            end: Some(5)
        }
        .is_ephemeral());
    }

    #[test]
    fn into_iterator_should_support_port_range() {
        let p = PortRange {
            start: 1,
            end: None,
        };
        assert_eq!((&p).into_iter().collect::<Vec<u16>>(), vec![1]);
        assert_eq!(p.into_iter().collect::<Vec<u16>>(), vec![1]);

        let p = PortRange {
            start: 1,
            end: Some(3),
        };
        assert_eq!((&p).into_iter().collect::<Vec<u16>>(), vec![1, 2, 3]);
        assert_eq!(p.into_iter().collect::<Vec<u16>>(), vec![1, 2, 3]);
    }

    #[test]
    fn into_iterator_should_be_empty_when_end_precedes_start() {
        let p = PortRange {
            start: 5,
            end: Some(2),
        };
        assert_eq!((&p).into_iter().count(), 0);
        assert_eq!(p.into_iter().count(), 0);
    }

    #[test]
    fn make_socket_addrs_should_produce_a_socket_addr_per_port() {
        let ip_addr = "127.0.0.1".parse::<IpAddr>().unwrap();

        let p = PortRange {
            start: 1,
            end: None,
        };
        assert_eq!(
            p.make_socket_addrs(ip_addr),
            vec![SocketAddr::new(ip_addr, 1)]
        );

        let p = PortRange {
            start: 1,
            end: Some(3),
        };
        assert_eq!(
            p.make_socket_addrs(ip_addr),
            vec![
                SocketAddr::new(ip_addr, 1),
                SocketAddr::new(ip_addr, 2),
                SocketAddr::new(ip_addr, 3),
            ]
        );
    }

    #[test]
    fn parse_should_fail_if_start_port_is_not_a_number() {
        assert!("a100".parse::<PortRange>().is_err());
        assert!("100a".parse::<PortRange>().is_err());
    }

    #[test]
    fn parse_should_fail_if_provided_end_port_that_is_not_a_number() {
        assert!("100:200a".parse::<PortRange>().is_err());
    }

    #[test]
    fn parse_should_fail_if_either_side_of_range_is_missing() {
        assert!("".parse::<PortRange>().is_err());
        assert!("100:".parse::<PortRange>().is_err());
        assert!(":200".parse::<PortRange>().is_err());
    }

    #[test]
    fn parse_should_fail_if_port_exceeds_u16() {
        assert!("70000".parse::<PortRange>().is_err());
        assert!("100:70000".parse::<PortRange>().is_err());
    }

    #[test]
    fn parse_should_be_able_to_properly_read_in_port_range() {
        let p: PortRange = "100".parse().unwrap();
        assert_eq!(
            p,
            PortRange {
                start: 100,
                end: None
            }
        );

        let p: PortRange = "100:200".parse().unwrap();
        assert_eq!(
            p,
            PortRange {
                start: 100,
                end: Some(200)
            }
        );
    }

    #[test]
    fn serialize_should_leverage_tostring() {
        assert_eq!(
            serde_json::to_value(PortRange {
                start: 123,
                end: None,
            })
            .unwrap(),
            serde_json::Value::String("123".to_string())
        );

        assert_eq!(
            serde_json::to_value(PortRange {
                start: 123,
                end: Some(456),
            })
            .unwrap(),
            serde_json::Value::String("123:456".to_string())
        );
    }

    #[test]
    fn serialize_then_deserialize_should_round_trip() {
        let p = PortRange {
            start: 8000,
            end: Some(8080),
        };
        let json = serde_json::to_string(&p).unwrap();
        assert_eq!(serde_json::from_str::<PortRange>(&json).unwrap(), p);
    }

    #[test]
    fn deserialize_should_use_single_number_as_start() {
        // Supports parsing numbers
        assert_eq!(
            serde_json::from_str::<PortRange>("123").unwrap(),
            PortRange {
                start: 123,
                end: None
            }
        );
    }

    #[test]
    fn deserialize_should_fail_for_numbers_outside_port_range() {
        assert!(serde_json::from_str::<PortRange>("70000").is_err());
        assert!(serde_json::from_str::<PortRange>("-1").is_err());
    }

    #[test]
    fn deserialize_should_leverage_fromstr_for_strings() {
        // Supports string number
        assert_eq!(
            serde_json::from_str::<PortRange>("\"123\"").unwrap(),
            PortRange {
                start: 123,
                end: None
            }
        );

        // Supports string start:end
        assert_eq!(
            serde_json::from_str::<PortRange>("\"123:456\"").unwrap(),
            PortRange {
                start: 123,
                end: Some(456)
            }
        );
    }
}