1use crate::error::{Errors, Result};
2use rand::prelude::*;
3use std::{collections::HashSet, net::IpAddr, ops::RangeInclusive};
4
5pub mod error;
6mod utils;
7
/// Lowest port the picker accepts; ports below 1024 are the IANA well-known range.
const MIN_PORT: u16 = 1024;
/// Highest port the picker accepts: the largest value a TCP/UDP port number can take.
const MAX_PORT: u16 = u16::MAX;
10
/// Transport protocol(s) a port has to be free for.
///
/// The actual availability check for each variant is performed by
/// `utils::is_free_in_hosts`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Protocol {
    /// Every protocol the crate checks (exact meaning is defined in `utils`).
    All,
    /// TCP.
    Tcp,
    /// UDP.
    Udp,
}
17
18pub struct PortPicker {
30 range: RangeInclusive<u16>,
31 exclude: HashSet<u16>,
32 protocol: Protocol,
33 host: Option<String>,
34 random: bool,
35}
36
37impl PortPicker {
38 pub fn new() -> Self {
39 PortPicker {
40 range: MIN_PORT..=MAX_PORT,
41 exclude: HashSet::new(),
42 protocol: Protocol::All,
43 host: None,
44 random: false,
45 }
46 }
47
48 pub fn port_range(mut self, range: RangeInclusive<u16>) -> Self {
50 self.range = range;
51 self
52 }
53
54 pub fn execlude(mut self, exclude: HashSet<u16>) -> Self {
56 self.exclude = exclude;
57 self
58 }
59
60 pub fn execlude_add(mut self, port: u16) -> Self {
62 self.exclude.insert(port);
63 self
64 }
65
66 pub fn protocol(mut self, protocol: Protocol) -> Self {
68 self.protocol = protocol;
69 self
70 }
71
72 pub fn host(mut self, host: String) -> Self {
75 self.host = Some(host);
76 self
77 }
78
79 pub fn random(mut self, random: bool) -> Self {
82 self.random = random;
83 self
84 }
85
86 fn random_port(&self, ip_addrs: HashSet<IpAddr>) -> Result<u16> {
87 let mut rng = rand::thread_rng();
88 let len = self.range.len();
89 for _ in 0..len {
90 let port = rng.gen_range(*self.range.start()..=*self.range.end());
91 if self.exclude.contains(&port) {
92 continue;
93 }
94 if utils::is_free_in_hosts(port, &ip_addrs, &self.protocol) {
95 return Ok(port);
96 }
97 }
98 Err(Errors::NoAvailablePort)
99 }
100
101 fn get_port(&self, ip_addrs: HashSet<IpAddr>) -> Result<u16> {
102 for port in self.range.clone() {
103 if self.exclude.contains(&port) {
104 continue;
105 }
106 if utils::is_free_in_hosts(port, &ip_addrs, &self.protocol) {
107 return Ok(port);
108 }
109 }
110 Err(Errors::NoAvailablePort)
111 }
112
113 pub fn pick(&self) -> Result<u16> {
114 if self.range.is_empty() {
116 return Err(Errors::InvalidOption(
117 "The start port must be less than or equal to the end port".to_string(),
118 ));
119 }
120 if *self.range.start() < MIN_PORT || *self.range.end() > MAX_PORT {
121 return Err(Errors::InvalidOption(format!(
122 "The port range must be between {} and {}",
123 MIN_PORT, MAX_PORT
124 )));
125 }
126
127 let mut ip_addrs: HashSet<IpAddr> = HashSet::new();
128 if let Some(host) = &self.host {
129 if let Ok(ip_addr) = host.parse::<IpAddr>() {
130 ip_addrs.insert(ip_addr);
131 } else {
132 return Err(Errors::InvalidOption(format!(
133 "The host {} is not a valid IP address",
134 host
135 )));
136 }
137 } else {
138 ip_addrs = utils::get_local_hosts();
139 }
140 if self.random {
141 self.random_port(ip_addrs)
142 } else {
143 self.get_port(ip_addrs)
144 }
145 }
146}
147
148pub fn is_free(port: u16, host: Option<String>, protocol: Protocol) -> bool {
155 let mut ip_addrs: HashSet<IpAddr> = HashSet::new();
156 if let Some(host) = host {
157 if let Ok(ip_addr) = host.parse::<IpAddr>() {
158 ip_addrs.insert(ip_addr);
159 } else {
160 return false;
161 }
162 } else {
163 ip_addrs = utils::get_local_hosts();
164 }
165 utils::is_free_in_hosts(port, &ip_addrs, &protocol)
166}
167
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_port_picker() {
        let port = PortPicker::new().pick().unwrap();
        assert!((MIN_PORT..=MAX_PORT).contains(&port));

        let port = PortPicker::new().port_range(3000..=4000).pick().unwrap();
        assert!((3000..=4000).contains(&port));
    }

    #[test]
    fn test_invalid_options() {
        // start > end
        assert!(PortPicker::new().port_range(4000..=3000).pick().is_err());
        // start below MIN_PORT
        assert!(PortPicker::new().port_range(80..=2000).pick().is_err());
        // host that is not an IP address literal
        assert!(PortPicker::new()
            .host("localhost".to_string())
            .pick()
            .is_err());
    }

    #[test]
    fn test_fully_excluded_range() {
        let picker = || {
            PortPicker::new()
                .port_range(5000..=5001)
                .execlude_add(5000)
                .execlude_add(5001)
        };
        assert!(picker().pick().is_err());
        assert!(picker().random(true).pick().is_err());
    }

    #[test]
    fn test_is_free() {
        let port = PortPicker::new().pick().unwrap();
        assert!(is_free(port, None, Protocol::All));

        // An unparsable host is never reported as free.
        assert!(!is_free(port, Some("not an ip".to_string()), Protocol::All));

        // A port this test is actively listening on must not be reported as free.
        // (Replaces a check on port 80, whose state depends on the machine and on
        // whether the test runs with privileges.)
        let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
        let busy = listener.local_addr().unwrap().port();
        assert!(!is_free(busy, Some("127.0.0.1".to_string()), Protocol::Tcp));
        drop(listener);
    }
}