1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
use crate::error::{Errors, Result};
use rand::prelude::*;
use std::{collections::HashSet, net::IpAddr, ops::RangeInclusive};

pub mod error;
mod utils;

/// Smallest port number accepted in a `PortPicker` range (inclusive).
const MIN_PORT: u16 = 1_024;
/// Largest port number accepted in a `PortPicker` range (inclusive): the whole upper `u16` span.
const MAX_PORT: u16 = u16::MAX;

/// Transport protocol(s) for which a port must be free.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Protocol {
    /// Every protocol; this is the default used by `PortPicker::new`.
    All,
    /// TCP.
    Tcp,
    /// UDP.
    Udp,
}

/// A builder that picks a free port on the local machine.
///
/// It can be used to find a free port to start a server, or for any other use case.
/// Configure it with the builder methods, then call `pick` to search the port range.
///
/// # Examples
///
/// ```
/// use random_port::PortPicker;
/// let port = PortPicker::new().pick().unwrap();
/// println!("The free port is {}", port);
/// ```
pub struct PortPicker {
    /// Inclusive range of candidate ports.
    range: RangeInclusive<u16>,
    /// Ports that must never be returned.
    exclude: HashSet<u16>,
    /// Protocol(s) for which the port must be free.
    protocol: Protocol,
    /// Host IP address to check, as text; `None` means all local addresses.
    host: Option<String>,
    /// Pick a random free port instead of the first free one.
    random: bool,
}

impl Default for PortPicker {
    /// Equivalent to [`PortPicker::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl PortPicker {
    /// Creates a picker with the default configuration.
    ///
    /// The defaults are: ports `1024..=65535`, no exclusions, `Protocol::All`, all local
    /// addresses, and sequential (non-random) search.
    pub fn new() -> Self {
        PortPicker {
            range: MIN_PORT..=MAX_PORT,
            exclude: HashSet::new(),
            protocol: Protocol::All,
            host: None,
            random: false,
        }
    }

    /// Specifies the range of ports to check, e.g. `port_range(3000..=4000)`.
    ///
    /// The range must be non-empty and lie within `1024..=65535`. This is validated by
    /// [`PortPicker::pick`], which returns `Errors::InvalidOption` otherwise.
    pub fn port_range(mut self, range: RangeInclusive<u16>) -> Self {
        self.range = range;
        self
    }

    /// Replaces the set of ports to exclude.
    ///
    /// Misspelled name kept for backward compatibility; prefer [`PortPicker::exclude`].
    pub fn execlude(mut self, exclude: HashSet<u16>) -> Self {
        self.exclude = exclude;
        self
    }

    /// Replaces the set of ports that must never be returned.
    pub fn exclude(self, exclude: HashSet<u16>) -> Self {
        self.execlude(exclude)
    }

    /// Adds a single port to exclude.
    ///
    /// Misspelled name kept for backward compatibility; prefer [`PortPicker::exclude_add`].
    pub fn execlude_add(mut self, port: u16) -> Self {
        self.exclude.insert(port);
        self
    }

    /// Adds a single port that must never be returned.
    pub fn exclude_add(self, port: u16) -> Self {
        self.execlude_add(port)
    }

    /// Specifies the protocol to check. The default is `Protocol::All`.
    ///
    /// Can be `Protocol::Tcp`, `Protocol::Udp` or `Protocol::All`.
    pub fn protocol(mut self, protocol: Protocol) -> Self {
        self.protocol = protocol;
        self
    }

    /// Specifies the host to check, as an IPv4 or IPv6 address.
    ///
    /// If not specified, availability is checked on all local addresses defined in the system.
    /// The address is parsed by [`PortPicker::pick`], which returns `Errors::InvalidOption`
    /// if it is not a valid IP address.
    pub fn host(mut self, host: String) -> Self {
        self.host = Some(host);
        self
    }

    /// Specifies whether to pick a random port from the range. The default is `false`.
    ///
    /// If `false`, the first available port in the range is picked.
    pub fn random(mut self, random: bool) -> Self {
        self.random = random;
        self
    }

    /// Returns a random free port, visiting every non-excluded port in the range exactly once.
    ///
    /// Candidates are shuffled rather than sampled with replacement. This guarantees that a
    /// free port is found whenever one exists, even when only a single port is available.
    fn random_port(&self, ip_addrs: &HashSet<IpAddr>) -> Result<u16> {
        let mut candidates: Vec<u16> = self
            .range
            .clone()
            .filter(|port| !self.exclude.contains(port))
            .collect();
        candidates.shuffle(&mut rand::thread_rng());
        candidates
            .into_iter()
            .find(|&port| utils::is_free_in_hosts(port, ip_addrs, &self.protocol))
            .ok_or(Errors::NoAvailablePort)
    }

    /// Returns the lowest free, non-excluded port in the range.
    fn get_port(&self, ip_addrs: &HashSet<IpAddr>) -> Result<u16> {
        self.range
            .clone()
            .filter(|port| !self.exclude.contains(port))
            .find(|&port| utils::is_free_in_hosts(port, ip_addrs, &self.protocol))
            .ok_or(Errors::NoAvailablePort)
    }

    /// Returns the addresses to check: the configured host, or all local hosts if none is set.
    fn target_addrs(&self) -> Result<HashSet<IpAddr>> {
        match &self.host {
            None => Ok(utils::get_local_hosts()),
            Some(host) => host
                .parse::<IpAddr>()
                .map(|ip_addr| std::iter::once(ip_addr).collect())
                .map_err(|_| {
                    Errors::InvalidOption(format!(
                        "The host {} is not a valid IP address",
                        host
                    ))
                }),
        }
    }

    /// Searches the configured range for a free port.
    ///
    /// # Errors
    ///
    /// - `Errors::InvalidOption` if the range is empty or extends outside `1024..=65535`,
    ///   or if the configured host is not a valid IP address.
    /// - `Errors::NoAvailablePort` if no non-excluded port in the range is free.
    pub fn pick(&self) -> Result<u16> {
        if self.range.is_empty() {
            return Err(Errors::InvalidOption(
                "The start port must be less than or equal to the end port".to_string(),
            ));
        }
        // Containment check instead of `end > MAX_PORT`, which is always false for u16.
        let allowed = MIN_PORT..=MAX_PORT;
        if !allowed.contains(self.range.start()) || !allowed.contains(self.range.end()) {
            return Err(Errors::InvalidOption(format!(
                "The port range must be between {} and {}",
                MIN_PORT, MAX_PORT
            )));
        }

        let ip_addrs = self.target_addrs()?;
        if self.random {
            self.random_port(&ip_addrs)
        } else {
            self.get_port(&ip_addrs)
        }
    }
}

/// Reports whether `port` is free on the local machine.
///
/// - `port`: the port to check.
/// - `host`: an IPv4 or IPv6 address to check. When `None`, every local address defined in
///   the system is checked.
/// - `protocol`: the protocol to check: `Protocol::Tcp`, `Protocol::Udp` or `Protocol::All`.
///
/// Returns `false` when `host` is given but cannot be parsed as an IP address.
pub fn is_free(port: u16, host: Option<String>, protocol: Protocol) -> bool {
    let targets: HashSet<IpAddr> = match host {
        None => utils::get_local_hosts(),
        Some(text) => match text.parse::<IpAddr>() {
            Ok(addr) => std::iter::once(addr).collect(),
            Err(_) => return false,
        },
    };
    utils::is_free_in_hosts(port, &targets, &protocol)
}

#[cfg(test)]
mod tests {

    use super::*;
    use std::net::TcpListener;

    #[test]
    fn test_port_picker() {
        let port = PortPicker::new().pick().unwrap();
        assert!((MIN_PORT..=MAX_PORT).contains(&port));

        let result = PortPicker::new().port_range(3000..=4000).pick();
        assert!(result.is_ok());
        let port = result.unwrap();
        assert!((3000..=4000).contains(&port));
    }

    #[test]
    fn test_random_port_in_range() {
        let port = PortPicker::new()
            .port_range(20000..=30000)
            .random(true)
            .pick()
            .unwrap();
        assert!((20000..=30000).contains(&port));
    }

    #[test]
    fn test_excluded_port_is_never_returned() {
        // Only 40001 is eligible; if it happens to be busy we get an error, never 40000.
        if let Ok(port) = PortPicker::new()
            .port_range(40000..=40001)
            .execlude_add(40000)
            .pick()
        {
            assert_eq!(port, 40001);
        }
    }

    #[test]
    fn test_invalid_options() {
        let reversed = std::ops::RangeInclusive::new(4000, 3000);
        assert!(PortPicker::new().port_range(reversed).pick().is_err());
        assert!(PortPicker::new().port_range(80..=2000).pick().is_err());
        assert!(PortPicker::new()
            .host("not-an-ip".to_string())
            .pick()
            .is_err());
    }

    #[test]
    fn test_is_free() {
        let port = PortPicker::new().pick().unwrap();
        assert!(is_free(port, None, Protocol::All));

        // Occupy a port ourselves rather than assuming a well-known port (e.g. 80) is busy,
        // which depends on the machine running the tests.
        let listener = TcpListener::bind(("127.0.0.1", 0)).unwrap();
        let busy = listener.local_addr().unwrap().port();
        assert!(!is_free(busy, Some("127.0.0.1".to_string()), Protocol::Tcp));
        drop(listener);
    }
}