use crate::error::{Errors, Result};
use rand::prelude::*;
use std::{collections::HashSet, net::IpAddr, ops::RangeInclusive};
pub mod error;
mod utils;
/// Lowest port that `PortPicker::pick` accepts as the start of a search range.
const MIN_PORT: u16 = 1024;
/// Highest port number; also the largest value a `u16` can hold (65535).
const MAX_PORT: u16 = u16::MAX;
/// Transport protocol(s) a port has to be free for.
///
/// The actual availability check for each variant is performed by
/// `utils::is_free_in_hosts`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Protocol {
    /// Every supported protocol — presumably both TCP and UDP; see `utils`.
    All,
    /// TCP only.
    Tcp,
    /// UDP only.
    Udp,
}
/// Configuration for finding an available port.
///
/// Create one with `PortPicker::new`, adjust it with the chaining setters (each takes
/// and returns `self`), then call `PortPicker::pick`.
pub struct PortPicker {
    /// Ports that must never be returned, even if they are free.
    exclude: HashSet<u16>,
    /// Inclusive range of ports to search.
    range: RangeInclusive<u16>,
    /// When `true`, ports are probed in random rather than ascending order.
    random: bool,
    /// Protocol(s) the port must be free for.
    protocol: Protocol,
    /// Optional IP address literal to check; `None` means the addresses returned by
    /// `utils::get_local_hosts` are checked.
    host: Option<String>,
}
impl PortPicker {
    /// Creates a picker with the default configuration.
    ///
    /// Defaults:
    /// - the range `MIN_PORT..=MAX_PORT`;
    /// - no excluded ports;
    /// - [`Protocol::All`];
    /// - no explicit host, so the addresses from `utils::get_local_hosts` are checked;
    /// - sequential (ascending) search.
    pub fn new() -> Self {
        PortPicker {
            range: MIN_PORT..=MAX_PORT,
            exclude: HashSet::new(),
            protocol: Protocol::All,
            host: None,
            random: false,
        }
    }

    /// Sets the inclusive range of ports to search.
    ///
    /// The range is validated by [`PortPicker::pick`], not here.
    pub fn port_range(mut self, range: RangeInclusive<u16>) -> Self {
        self.range = range;
        self
    }

    /// Replaces the set of ports that must never be returned.
    pub fn exclude(mut self, exclude: HashSet<u16>) -> Self {
        self.exclude = exclude;
        self
    }

    /// Adds one port to the set of ports that must never be returned.
    pub fn exclude_add(mut self, port: u16) -> Self {
        self.exclude.insert(port);
        self
    }

    /// Misspelled alias of [`PortPicker::exclude`], kept so existing callers still compile.
    pub fn execlude(self, exclude: HashSet<u16>) -> Self {
        self.exclude(exclude)
    }

    /// Misspelled alias of [`PortPicker::exclude_add`], kept so existing callers still compile.
    pub fn execlude_add(self, port: u16) -> Self {
        self.exclude_add(port)
    }

    /// Sets the protocol(s) the returned port must be free for.
    pub fn protocol(mut self, protocol: Protocol) -> Self {
        self.protocol = protocol;
        self
    }

    /// Restricts the check to a single host.
    ///
    /// `host` must be an IP address literal (e.g. `"127.0.0.1"` or `"::1"`). Hostnames are
    /// not resolved and make [`PortPicker::pick`] return [`Errors::InvalidOption`].
    pub fn host(mut self, host: String) -> Self {
        self.host = Some(host);
        self
    }

    /// When `true`, candidate ports are probed in random order instead of ascending order.
    pub fn random(mut self, random: bool) -> Self {
        self.random = random;
        self
    }

    /// Iterates over the configured range in ascending order, skipping excluded ports.
    fn candidates(&self) -> impl Iterator<Item = u16> + '_ {
        self.range
            .clone()
            .filter(move |port| !self.exclude.contains(port))
    }

    /// Probes the candidate ports in a uniformly random order and returns the first free one.
    ///
    /// Ports are drawn *without* replacement, using an incremental Fisher–Yates selection.
    /// Each candidate is therefore probed at most once. A port that is free when probed is
    /// always found. Sampling with replacement would re-probe some ports, skip others, and
    /// could fail while free ports remain.
    fn random_port(&self, ip_addrs: &HashSet<IpAddr>) -> Result<u16> {
        let mut rng = rand::thread_rng();
        let mut remaining: Vec<u16> = self.candidates().collect();
        while !remaining.is_empty() {
            let index = rng.gen_range(0..remaining.len());
            // swap_remove is O(1); order of the leftovers is irrelevant since we draw randomly.
            let port = remaining.swap_remove(index);
            if utils::is_free_in_hosts(port, ip_addrs, &self.protocol) {
                return Ok(port);
            }
        }
        Err(Errors::NoAvailablePort)
    }

    /// Returns the lowest candidate port that is free on all of `ip_addrs`.
    fn get_port(&self, ip_addrs: &HashSet<IpAddr>) -> Result<u16> {
        self.candidates()
            .find(|&port| utils::is_free_in_hosts(port, ip_addrs, &self.protocol))
            .ok_or(Errors::NoAvailablePort)
    }

    /// Finds a port that satisfies the configuration.
    ///
    /// # Errors
    ///
    /// - [`Errors::InvalidOption`] in any of these cases:
    ///   - the range is empty (start greater than end);
    ///   - the range reaches outside `MIN_PORT..=MAX_PORT`;
    ///   - the configured host is not a valid IP address literal.
    /// - [`Errors::NoAvailablePort`] if no non-excluded port in the range is free.
    pub fn pick(&self) -> Result<u16> {
        if self.range.is_empty() {
            return Err(Errors::InvalidOption(
                "The start port must be less than or equal to the end port".to_string(),
            ));
        }
        // `contains` avoids an always-false `end > u16::MAX` comparison. It stays correct if
        // the bounds ever change.
        let allowed = MIN_PORT..=MAX_PORT;
        if !allowed.contains(self.range.start()) || !allowed.contains(self.range.end()) {
            return Err(Errors::InvalidOption(format!(
                "The port range must be between {} and {}",
                MIN_PORT, MAX_PORT
            )));
        }

        let ip_addrs: HashSet<IpAddr> = match &self.host {
            Some(host) => {
                let ip_addr = host.parse::<IpAddr>().map_err(|_| {
                    Errors::InvalidOption(format!(
                        "The host {} is not a valid IP address",
                        host
                    ))
                })?;
                std::iter::once(ip_addr).collect()
            }
            None => utils::get_local_hosts(),
        };

        if self.random {
            self.random_port(&ip_addrs)
        } else {
            self.get_port(&ip_addrs)
        }
    }
}

impl Default for PortPicker {
    /// Same as [`PortPicker::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Reports whether `port` is currently free for `protocol`.
///
/// If `host` is `Some`, only that address is checked. It must be an IP address literal;
/// a value that fails to parse makes the function return `false` without probing
/// anything. If `host` is `None`, every address returned by `utils::get_local_hosts`
/// is checked.
pub fn is_free(port: u16, host: Option<String>, protocol: Protocol) -> bool {
    let addrs: HashSet<IpAddr> = match host {
        None => utils::get_local_hosts(),
        Some(text) => match text.parse::<IpAddr>() {
            Ok(addr) => {
                let mut single = HashSet::new();
                single.insert(addr);
                single
            }
            Err(_) => return false,
        },
    };
    utils::is_free_in_hosts(port, &addrs, &protocol)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::Errors;
    use std::net::TcpListener;

    #[test]
    fn test_port_picker() {
        let port = PortPicker::new().pick().unwrap();
        assert!((MIN_PORT..=MAX_PORT).contains(&port));

        let port = PortPicker::new().port_range(3000..=4000).pick().unwrap();
        assert!((3000..=4000).contains(&port));
    }

    #[test]
    fn test_random_port_in_range() {
        let port = PortPicker::new()
            .port_range(3000..=4000)
            .random(true)
            .pick()
            .unwrap();
        assert!((3000..=4000).contains(&port));
    }

    #[test]
    fn test_invalid_options() {
        // Start greater than end: the range is empty.
        let reversed = PortPicker::new()
            .port_range(std::ops::RangeInclusive::new(4000, 3000))
            .pick();
        assert!(matches!(reversed, Err(Errors::InvalidOption(_))));

        // Start below MIN_PORT.
        let too_low = PortPicker::new().port_range(80..=2000).pick();
        assert!(matches!(too_low, Err(Errors::InvalidOption(_))));

        // Host must be an IP address literal.
        let bad_host = PortPicker::new().host("not-an-ip".to_string()).pick();
        assert!(matches!(bad_host, Err(Errors::InvalidOption(_))));
    }

    #[test]
    fn test_all_ports_excluded() {
        // Excluded ports are never probed, so this is deterministic regardless of the host.
        for random in [false, true] {
            let result = PortPicker::new()
                .port_range(5000..=5001)
                .execlude_add(5000)
                .execlude_add(5001)
                .random(random)
                .pick();
            assert!(matches!(result, Err(Errors::NoAvailablePort)));
        }
    }

    #[test]
    fn test_is_free() {
        let port = PortPicker::new().pick().unwrap();
        assert!(is_free(port, None, Protocol::All));

        // Hold a listener open so the port is definitely occupied. Assuming a service on a
        // well-known port such as 80 is unreliable: nothing may be bound there, and running
        // as root changes the outcome.
        let listener = TcpListener::bind(("127.0.0.1", 0)).unwrap();
        let busy = listener.local_addr().unwrap().port();
        assert!(!is_free(busy, Some("127.0.0.1".to_string()), Protocol::Tcp));
        drop(listener);
    }

    #[test]
    fn test_is_free_invalid_host() {
        assert!(!is_free(8080, Some("not-an-ip".to_string()), Protocol::All));
    }
}