use std::collections::HashSet;
use std::net::{IpAddr, SocketAddr};
use tokio::net::TcpListener;
use crate::error::{Error, Result};
/// Hands out TCP ports from a fixed inclusive range and tracks the ones already given out.
///
/// The tracking set is purely in-memory bookkeeping held by this value; cloning the
/// allocator copies that set.
#[derive(Debug, Clone)]
pub struct PortAllocator {
/// Lowest port number in the managed range (inclusive).
range_lo: u16,
/// Highest port number in the managed range (inclusive).
range_hi: u16,
/// Ports currently considered taken by this allocator.
reserved: HashSet<u16>,
}
impl PortAllocator {
    /// Creates an allocator over the inclusive port range between `lo` and `hi`.
    ///
    /// The bounds are swapped when passed in descending order, so the stored range is
    /// always ascending. No ports are reserved initially.
    pub fn new(lo: u16, hi: u16) -> Self {
        let (lo, hi) = if lo <= hi { (lo, hi) } else { (hi, lo) };
        Self {
            range_lo: lo,
            range_hi: hi,
            reserved: HashSet::new(),
        }
    }

    /// Reserves the lowest port in the range that is free and bindable on `bind_host`.
    ///
    /// A candidate is skipped when it appears in `db_ports` or is already reserved by
    /// this allocator. The remaining candidates are test-bound in ascending order; the
    /// first one that binds successfully is recorded as reserved and returned.
    ///
    /// Port 0 is never a candidate: binding to it asks the OS for an arbitrary
    /// ephemeral port, so the probe would always succeed and `0` would be handed out
    /// as if it were a real port.
    ///
    /// The probe only checks that the port can be bound at this moment; it does not
    /// keep the socket open, so another process may still claim the port before the
    /// caller binds it.
    ///
    /// # Errors
    ///
    /// Returns [`Error::PortRangeExhausted`] when no candidate could be bound. Any bind
    /// failure (not only "address in use") counts as the port being unavailable.
    pub async fn reserve(
        &mut self,
        bind_host: IpAddr,
        db_ports: &HashSet<u16>,
    ) -> Result<u16> {
        // Start at 1 at the earliest so the wildcard port 0 is never probed.
        let first = self.range_lo.max(1);
        for p in first..=self.range_hi {
            if db_ports.contains(&p) || self.reserved.contains(&p) {
                continue;
            }
            if probe_bind(bind_host, p).await.is_ok() {
                self.reserved.insert(p);
                return Ok(p);
            }
        }
        Err(Error::PortRangeExhausted)
    }

    /// Forgets a reservation so `port` can be handed out again by [`Self::reserve`].
    ///
    /// Releasing a port that was not reserved is a no-op.
    pub fn release(&mut self, port: u16) {
        self.reserved.remove(&port);
    }

    /// Marks `port` as reserved without probing it.
    ///
    /// The port is recorded even if it lies outside the allocator's range.
    pub fn mark_reserved(&mut self, port: u16) {
        self.reserved.insert(port);
    }
}
/// Checks whether a TCP listener can be bound to `host:port` right now.
///
/// The listener exists only for the duration of this call; nothing stays bound
/// once the function returns.
///
/// # Errors
///
/// Propagates the bind failure, converted into the crate's error type by `?`.
async fn probe_bind(host: IpAddr, port: u16) -> Result<()> {
    // Binding is the entire test; the listener is closed when it falls out of scope.
    let _listener = TcpListener::bind(SocketAddr::new(host, port)).await?;
    Ok(())
}
/// Finds the lowest port at or above `start` that is not in `taken` and can be bound
/// on `bind_host`.
///
/// The search never begins below 1025: smaller `start` values are raised to 1025.
/// Candidates are tried in ascending order up to and including `u16::MAX`.
///
/// The bind test is a momentary probe; the returned port is not held open, so it may
/// be claimed by another process before the caller binds it.
///
/// # Errors
///
/// Returns [`Error::PortRangeExhausted`] when no candidate qualifies. Any bind failure
/// counts as the port being unavailable.
pub async fn find_free_port(
    start: u16,
    bind_host: IpAddr,
    taken: &HashSet<u16>,
) -> Result<u16> {
    // Inclusive upper bound so that port 65535 itself is probed as well.
    for p in start.max(1025)..=u16::MAX {
        if !taken.contains(&p) && probe_bind(bind_host, p).await.is_ok() {
            return Ok(p);
        }
    }
    Err(Error::PortRangeExhausted)
}