random_port/
lib.rs

1use crate::error::{Errors, Result};
2use rand::prelude::*;
3use std::{collections::HashSet, net::IpAddr, ops::RangeInclusive};
4
5pub mod error;
6mod utils;
7
/// Smallest port number a `PortPicker` range may start at.
const MIN_PORT: u16 = 1_024;
/// Largest port number a `PortPicker` range may end at (the maximum value of a `u16`).
const MAX_PORT: u16 = u16::MAX;
10
/// Protocol(s) whose availability is checked for a port.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Protocol {
    /// Check all supported protocols. This is the default used by `PortPicker::new`.
    All,
    /// Check TCP.
    Tcp,
    /// Check UDP.
    Udp,
}
17
18/// PortPicker is a simple library to pick a free port in the local machine.
19///
20/// It can be used to find a free port to start a server or any other use case.
21///
22/// #Examples:
23///
24/// ```
25/// use random_port::PortPicker;
26/// let port = PortPicker::new().pick().unwrap();
27/// println!("The free port is {}", port);
28/// ```
29pub struct PortPicker {
30    range: RangeInclusive<u16>,
31    exclude: HashSet<u16>,
32    protocol: Protocol,
33    host: Option<String>,
34    random: bool,
35}
36
37impl PortPicker {
38    pub fn new() -> Self {
39        PortPicker {
40            range: MIN_PORT..=MAX_PORT,
41            exclude: HashSet::new(),
42            protocol: Protocol::All,
43            host: None,
44            random: false,
45        }
46    }
47
48    /// Specifies the range of ports to check. Must be in the range `1024..=65535`. E.g. `port_range(1024..=65535)`.
49    pub fn port_range(mut self, range: RangeInclusive<u16>) -> Self {
50        self.range = range;
51        self
52    }
53
54    /// Specifies the ports to exclude.
55    pub fn execlude(mut self, exclude: HashSet<u16>) -> Self {
56        self.exclude = exclude;
57        self
58    }
59
60    /// Specifies a port to exclude.
61    pub fn execlude_add(mut self, port: u16) -> Self {
62        self.exclude.insert(port);
63        self
64    }
65
66    /// Specifies the protocol to check, Default is `Protocol::All`. Can be either `Protocol::Tcp`, `Protocol::Udp` or `Protocol::All`.
67    pub fn protocol(mut self, protocol: Protocol) -> Self {
68        self.protocol = protocol;
69        self
70    }
71
72    /// Specifies the host to check. Can be either an Ipv4 or Ipv6 address.
73    /// If not specified, will checks availability on all local addresses defined in the system.
74    pub fn host(mut self, host: String) -> Self {
75        self.host = Some(host);
76        self
77    }
78
79    /// Specifies whether to pick a random port from the range.
80    /// If not specified, will pick the first available port from the range.
81    pub fn random(mut self, random: bool) -> Self {
82        self.random = random;
83        self
84    }
85
86    fn random_port(&self, ip_addrs: HashSet<IpAddr>) -> Result<u16> {
87        let mut rng = rand::thread_rng();
88        let len = self.range.len();
89        for _ in 0..len {
90            let port = rng.gen_range(*self.range.start()..=*self.range.end());
91            if self.exclude.contains(&port) {
92                continue;
93            }
94            if utils::is_free_in_hosts(port, &ip_addrs, &self.protocol) {
95                return Ok(port);
96            }
97        }
98        Err(Errors::NoAvailablePort)
99    }
100
101    fn get_port(&self, ip_addrs: HashSet<IpAddr>) -> Result<u16> {
102        for port in self.range.clone() {
103            if self.exclude.contains(&port) {
104                continue;
105            }
106            if utils::is_free_in_hosts(port, &ip_addrs, &self.protocol) {
107                return Ok(port);
108            }
109        }
110        Err(Errors::NoAvailablePort)
111    }
112
113    pub fn pick(&self) -> Result<u16> {
114        // check params
115        if self.range.is_empty() {
116            return Err(Errors::InvalidOption(
117                "The start port must be less than or equal to the end port".to_string(),
118            ));
119        }
120        if *self.range.start() < MIN_PORT || *self.range.end() > MAX_PORT {
121            return Err(Errors::InvalidOption(format!(
122                "The port range must be between {} and {}",
123                MIN_PORT, MAX_PORT
124            )));
125        }
126
127        let mut ip_addrs: HashSet<IpAddr> = HashSet::new();
128        if let Some(host) = &self.host {
129            if let Ok(ip_addr) = host.parse::<IpAddr>() {
130                ip_addrs.insert(ip_addr);
131            } else {
132                return Err(Errors::InvalidOption(format!(
133                    "The host {} is not a valid IP address",
134                    host
135                )));
136            }
137        } else {
138            ip_addrs = utils::get_local_hosts();
139        }
140        if self.random {
141            self.random_port(ip_addrs)
142        } else {
143            self.get_port(ip_addrs)
144        }
145    }
146}
147
148/// Check if a port is free in the local machine.
149/// If the host is not specified, it will check on all local addresses defined in the system.
150///
151/// - `port`: The port to check.
152/// - `host`: The host to check. Can be either an Ipv4 or Ipv6 address.
153/// - `protocol`: The protocol to check. Can be either `Protocol::Tcp`, `Protocol::Udp` or `Protocol::All`.
154pub fn is_free(port: u16, host: Option<String>, protocol: Protocol) -> bool {
155    let mut ip_addrs: HashSet<IpAddr> = HashSet::new();
156    if let Some(host) = host {
157        if let Ok(ip_addr) = host.parse::<IpAddr>() {
158            ip_addrs.insert(ip_addr);
159        } else {
160            return false;
161        }
162    } else {
163        ip_addrs = utils::get_local_hosts();
164    }
165    utils::is_free_in_hosts(port, &ip_addrs, &protocol)
166}
167
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_port_picker() {
        let port = PortPicker::new().pick().unwrap();
        assert!((MIN_PORT..=MAX_PORT).contains(&port));

        let port = PortPicker::new().port_range(3000..=4000).pick().unwrap();
        assert!((3000..=4000).contains(&port));
    }

    #[test]
    fn test_random_port_picker() {
        let port = PortPicker::new()
            .port_range(3000..=4000)
            .random(true)
            .pick()
            .unwrap();
        assert!((3000..=4000).contains(&port));
    }

    #[test]
    fn test_excluded_ports_are_never_picked() {
        // Every candidate is excluded, so the picker must report that nothing is available.
        let result = PortPicker::new()
            .port_range(3000..=3001)
            .execlude_add(3000)
            .execlude_add(3001)
            .pick();
        assert!(matches!(result, Err(Errors::NoAvailablePort)));
    }

    #[test]
    fn test_invalid_options() {
        // `RangeInclusive::new` avoids clippy's `reversed_empty_ranges` on a literal range.
        let empty = PortPicker::new()
            .port_range(RangeInclusive::new(4000, 3000))
            .pick();
        assert!(matches!(empty, Err(Errors::InvalidOption(_))));

        let below_min = PortPicker::new().port_range(80..=90).pick();
        assert!(matches!(below_min, Err(Errors::InvalidOption(_))));

        let bad_host = PortPicker::new().host("not-an-ip".to_string()).pick();
        assert!(matches!(bad_host, Err(Errors::InvalidOption(_))));
    }

    #[test]
    fn test_is_free() {
        let port = PortPicker::new().pick().unwrap();
        assert!(is_free(port, None, Protocol::All));
        assert!(!is_free(port, Some("not-an-ip".to_string()), Protocol::All));

        // Occupy a port ourselves instead of assuming a well-known port (e.g. 80) is in use
        // on the machine running the tests.
        let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
        let busy = listener.local_addr().unwrap().port();
        assert!(!is_free(busy, Some("127.0.0.1".to_string()), Protocol::Tcp));
    }
}