1use std::fmt;
2use std::net::{IpAddr, SocketAddr};
3use std::ops::RangeInclusive;
4use std::str::FromStr;
5
6use derive_more::Display;
7use serde::{de, Deserialize, Serialize};
8
/// A single port, or an inclusive range of ports.
///
/// Formatted with `Display` as `START` when `end` is `None`, and as
/// `START:END` when `end` is `Some(END)`.
///
/// Nothing in this type enforces `end >= start`; a range whose `end` is
/// smaller than `start` iterates over no ports at all. Note also that
/// `end: Some(start)` covers the same single port as `end: None`, but the two
/// values compare unequal because `PartialEq` is derived field-by-field.
#[derive(Copy, Clone, Debug, Display, PartialEq, Eq)]
#[display(
    fmt = "{}{}",
    start,
    // Appends ":END" only when an end port is present.
    "end.as_ref().map(|end| format!(\":{}\", end)).unwrap_or_default()"
)]
pub struct PortRange {
    /// First port of the range (the only port when `end` is `None`).
    pub start: u16,
    /// Inclusive last port of the range, or `None` for a single port.
    pub end: Option<u16>,
}
20
21impl PortRange {
22 pub const EPHEMERAL: Self = Self {
24 start: 0,
25 end: None,
26 };
27
28 #[inline]
30 pub fn single(port: u16) -> Self {
31 Self {
32 start: port,
33 end: None,
34 }
35 }
36
37 pub fn make_socket_addrs(&self, addr: impl Into<IpAddr>) -> Vec<SocketAddr> {
39 let mut socket_addrs = Vec::new();
40 let addr = addr.into();
41
42 for port in self {
43 socket_addrs.push(SocketAddr::from((addr, port)));
44 }
45
46 socket_addrs
47 }
48
49 #[inline]
51 pub fn is_ephemeral(&self) -> bool {
52 self == &Self::EPHEMERAL
53 }
54}
55
56impl From<u16> for PortRange {
57 fn from(port: u16) -> Self {
58 Self::single(port)
59 }
60}
61
62impl From<RangeInclusive<u16>> for PortRange {
63 fn from(r: RangeInclusive<u16>) -> Self {
64 let (start, end) = r.into_inner();
65 Self {
66 start,
67 end: Some(end),
68 }
69 }
70}
71
72impl<'a> IntoIterator for &'a PortRange {
73 type IntoIter = RangeInclusive<u16>;
74 type Item = u16;
75
76 fn into_iter(self) -> Self::IntoIter {
77 self.start..=self.end.unwrap_or(self.start)
78 }
79}
80
81impl IntoIterator for PortRange {
82 type IntoIter = RangeInclusive<u16>;
83 type Item = u16;
84
85 fn into_iter(self) -> Self::IntoIter {
86 self.start..=self.end.unwrap_or(self.start)
87 }
88}
89
90impl FromStr for PortRange {
91 type Err = std::num::ParseIntError;
92
93 fn from_str(s: &str) -> Result<Self, Self::Err> {
95 match s.find(':') {
96 Some(idx) if idx + 1 < s.len() => Ok(Self {
97 start: s[..idx].parse()?,
98 end: Some(s[(idx + 1)..].parse()?),
99 }),
100 _ => Ok(Self {
101 start: s.parse()?,
102 end: None,
103 }),
104 }
105 }
106}
107
108impl Serialize for PortRange {
109 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
110 where
111 S: serde::ser::Serializer,
112 {
113 String::serialize(&self.to_string(), serializer)
114 }
115}
116
117impl<'de> Deserialize<'de> for PortRange {
118 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
119 where
120 D: serde::de::Deserializer<'de>,
121 {
122 struct PortRangeVisitor;
123 impl<'de> de::Visitor<'de> for PortRangeVisitor {
124 type Value = PortRange;
125
126 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
127 write!(formatter, "a port in the form NUMBER or START:END")
128 }
129
130 fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
131 where
132 E: de::Error,
133 {
134 FromStr::from_str(s).map_err(de::Error::custom)
135 }
136
137 fn visit_u8<E>(self, v: u8) -> Result<Self::Value, E>
138 where
139 E: de::Error,
140 {
141 Ok(PortRange {
142 start: v as u16,
143 end: None,
144 })
145 }
146
147 fn visit_u16<E>(self, v: u16) -> Result<Self::Value, E>
148 where
149 E: de::Error,
150 {
151 Ok(PortRange {
152 start: v,
153 end: None,
154 })
155 }
156
157 fn visit_u32<E>(self, v: u32) -> Result<Self::Value, E>
158 where
159 E: de::Error,
160 {
161 Ok(PortRange {
162 start: v.try_into().map_err(de::Error::custom)?,
163 end: None,
164 })
165 }
166
167 fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
168 where
169 E: de::Error,
170 {
171 Ok(PortRange {
172 start: v.try_into().map_err(de::Error::custom)?,
173 end: None,
174 })
175 }
176
177 fn visit_u128<E>(self, v: u128) -> Result<Self::Value, E>
178 where
179 E: de::Error,
180 {
181 Ok(PortRange {
182 start: v.try_into().map_err(de::Error::custom)?,
183 end: None,
184 })
185 }
186
187 fn visit_i8<E>(self, v: i8) -> Result<Self::Value, E>
188 where
189 E: de::Error,
190 {
191 Ok(PortRange {
192 start: v.try_into().map_err(de::Error::custom)?,
193 end: None,
194 })
195 }
196
197 fn visit_i16<E>(self, v: i16) -> Result<Self::Value, E>
198 where
199 E: de::Error,
200 {
201 Ok(PortRange {
202 start: v.try_into().map_err(de::Error::custom)?,
203 end: None,
204 })
205 }
206
207 fn visit_i32<E>(self, v: i32) -> Result<Self::Value, E>
208 where
209 E: de::Error,
210 {
211 Ok(PortRange {
212 start: v.try_into().map_err(de::Error::custom)?,
213 end: None,
214 })
215 }
216
217 fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
218 where
219 E: de::Error,
220 {
221 Ok(PortRange {
222 start: v.try_into().map_err(de::Error::custom)?,
223 end: None,
224 })
225 }
226
227 fn visit_i128<E>(self, v: i128) -> Result<Self::Value, E>
228 where
229 E: de::Error,
230 {
231 Ok(PortRange {
232 start: v.try_into().map_err(de::Error::custom)?,
233 end: None,
234 })
235 }
236 }
237
238 deserializer.deserialize_any(PortRangeVisitor)
239 }
240}
241
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn display_should_properly_reflect_port_range() {
        let p = PortRange {
            start: 100,
            end: None,
        };
        assert_eq!(p.to_string(), "100");

        let p = PortRange {
            start: 100,
            end: Some(200),
        };
        assert_eq!(p.to_string(), "100:200");
    }

    #[test]
    fn display_and_parse_should_round_trip() {
        for p in [
            PortRange::single(0),
            PortRange::single(65535),
            PortRange {
                start: 100,
                end: Some(200),
            },
        ] {
            assert_eq!(p.to_string().parse::<PortRange>().unwrap(), p);
        }
    }

    #[test]
    fn from_u16_should_produce_single_port() {
        assert_eq!(
            PortRange::from(80),
            PortRange {
                start: 80,
                end: None
            }
        );
    }

    #[test]
    fn from_range_inclusive_should_map_to_port_range() {
        let p = PortRange::from(100..=200);
        assert_eq!(p.start, 100);
        assert_eq!(p.end, Some(200));
    }

    #[test]
    fn is_ephemeral_should_only_match_port_zero_without_end() {
        assert!(PortRange::EPHEMERAL.is_ephemeral());
        assert!(PortRange::single(0).is_ephemeral());
        assert!(!PortRange::single(1).is_ephemeral());

        let zero_to_zero = PortRange {
            start: 0,
            end: Some(0),
        };
        assert!(!zero_to_zero.is_ephemeral());
    }

    #[test]
    fn into_iterator_should_support_port_range() {
        let p = PortRange {
            start: 1,
            end: None,
        };
        assert_eq!((&p).into_iter().collect::<Vec<u16>>(), vec![1]);
        assert_eq!(p.into_iter().collect::<Vec<u16>>(), vec![1]);

        let p = PortRange {
            start: 1,
            end: Some(3),
        };
        assert_eq!((&p).into_iter().collect::<Vec<u16>>(), vec![1, 2, 3]);
        assert_eq!(p.into_iter().collect::<Vec<u16>>(), vec![1, 2, 3]);
    }

    #[test]
    fn into_iterator_should_yield_nothing_when_end_is_before_start() {
        let p = PortRange {
            start: 5,
            end: Some(3),
        };
        assert_eq!((&p).into_iter().count(), 0);
        assert_eq!(p.into_iter().count(), 0);
    }

    #[test]
    fn make_socket_addrs_should_produce_a_socket_addr_per_port() {
        let ip_addr = "127.0.0.1".parse::<IpAddr>().unwrap();

        let p = PortRange {
            start: 1,
            end: None,
        };
        assert_eq!(
            p.make_socket_addrs(ip_addr),
            vec![SocketAddr::new(ip_addr, 1)]
        );

        let p = PortRange {
            start: 1,
            end: Some(3),
        };
        assert_eq!(
            p.make_socket_addrs(ip_addr),
            vec![
                SocketAddr::new(ip_addr, 1),
                SocketAddr::new(ip_addr, 2),
                SocketAddr::new(ip_addr, 3),
            ]
        );
    }

    #[test]
    fn parse_should_fail_if_not_starting_with_number() {
        assert!("100a".parse::<PortRange>().is_err());
    }

    #[test]
    fn parse_should_fail_if_provided_end_port_that_is_not_a_number() {
        assert!("100:200a".parse::<PortRange>().is_err());
    }

    #[test]
    fn parse_should_fail_if_a_side_of_the_colon_is_missing() {
        assert!("100:".parse::<PortRange>().is_err());
        assert!(":200".parse::<PortRange>().is_err());
        assert!(":".parse::<PortRange>().is_err());
    }

    #[test]
    fn parse_should_be_able_to_properly_read_in_port_range() {
        let p: PortRange = "100".parse().unwrap();
        assert_eq!(
            p,
            PortRange {
                start: 100,
                end: None
            }
        );

        let p: PortRange = "100:200".parse().unwrap();
        assert_eq!(
            p,
            PortRange {
                start: 100,
                end: Some(200)
            }
        );
    }

    #[test]
    fn serialize_should_leverage_tostring() {
        assert_eq!(
            serde_json::to_value(PortRange {
                start: 123,
                end: None,
            })
            .unwrap(),
            serde_json::Value::String("123".to_string())
        );

        assert_eq!(
            serde_json::to_value(PortRange {
                start: 123,
                end: Some(456),
            })
            .unwrap(),
            serde_json::Value::String("123:456".to_string())
        );
    }

    #[test]
    fn serialize_then_deserialize_should_round_trip() {
        let p = PortRange {
            start: 8000,
            end: Some(8080),
        };
        let json = serde_json::to_string(&p).unwrap();
        assert_eq!(serde_json::from_str::<PortRange>(&json).unwrap(), p);
    }

    #[test]
    fn deserialize_should_use_single_number_as_start() {
        assert_eq!(
            serde_json::from_str::<PortRange>("123").unwrap(),
            PortRange {
                start: 123,
                end: None
            }
        );
    }

    #[test]
    fn deserialize_should_reject_numbers_outside_port_range() {
        assert!(serde_json::from_str::<PortRange>("65536").is_err());
        assert!(serde_json::from_str::<PortRange>("-1").is_err());
        assert!(serde_json::from_str::<PortRange>("\"65536\"").is_err());
    }

    #[test]
    fn deserialize_should_leverage_fromstr_for_strings() {
        assert_eq!(
            serde_json::from_str::<PortRange>("\"123\"").unwrap(),
            PortRange {
                start: 123,
                end: None
            }
        );

        assert_eq!(
            serde_json::from_str::<PortRange>("\"123:456\"").unwrap(),
            PortRange {
                start: 123,
                end: Some(456)
            }
        );
    }
}