use std::collections::HashMap;
use std::net::SocketAddr;
use std::net::ToSocketAddrs;
use std::path::Path;
use color_eyre::Result;
use color_eyre::eyre::eyre;
use serde::Deserialize;
use serde::Serialize;
use socket2::Domain;
use socket2::Socket;
use socket2::Type;
use tokio::net::TcpListener;
use tokio::sync::RwLock;
use tracing::info;
/// First port (inclusive) scanned by `PortAssignments::allocate_port`.
///
/// NOTE(review): this range overlaps the common Linux ephemeral range
/// (`net.ipv4.ip_local_port_range`), so the kernel may also hand these ports out to
/// outgoing connections — confirm the overlap is intended.
const PORT_RANGE_START: u16 = 32_768;
/// Last port (inclusive) scanned by `PortAssignments::allocate_port`.
const PORT_RANGE_END: u16 = 61_000;
/// Creates an unbound TCP stream socket matching the address family of `addr`,
/// with `SO_REUSEADDR` and the freebind option enabled.
///
/// The socket is *not* bound; callers bind it themselves. Freebind (see Linux `ip(7)`,
/// `IP_FREEBIND`) permits binding to an address not currently assigned to a local
/// interface.
///
/// # Errors
///
/// Returns any I/O error raised while creating the socket or setting its options.
pub fn create_freebind_socket(addr: &SocketAddr) -> std::io::Result<Socket> {
    let sock = match addr {
        SocketAddr::V4(_) => {
            let v4 = Socket::new(Domain::IPV4, Type::STREAM, None)?;
            v4.set_reuse_address(true)?;
            v4.set_freebind_v4(true)?;
            v4
        }
        SocketAddr::V6(_) => {
            let v6 = Socket::new(Domain::IPV6, Type::STREAM, None)?;
            v6.set_reuse_address(true)?;
            v6.set_freebind_v6(true)?;
            v6
        }
    };
    Ok(sock)
}
/// Reports whether a freebind TCP socket can be bound to the first address that
/// `addr` resolves to.
///
/// Only the first resolved address is probed. Any failure yields `false`: a
/// resolution error, an empty resolution, a socket-creation error, or a bind error.
/// The probe socket is dropped when the function returns, so nothing stays reserved.
pub fn is_port_free<A: ToSocketAddrs>(addr: A) -> bool {
    let first_addr = addr
        .to_socket_addrs()
        .ok()
        .and_then(|mut resolved| resolved.next());
    match first_addr {
        None => false,
        Some(target) => match create_freebind_socket(&target) {
            Ok(probe) => probe.bind(&target.into()).is_ok(),
            Err(_) => false,
        },
    }
}
/// Keeps addresses reserved by holding a listening TCP socket bound to each one.
///
/// A reservation lasts as long as its listener stays in the map. Removing the
/// listener (on deallocation) drops it and frees the address again.
#[derive(Debug)]
pub struct PortAllocator {
    // Active reservations, keyed by the address that was passed to `allocate`.
    sockets: RwLock<HashMap<SocketAddr, TcpListener>>,
}
impl Default for PortAllocator {
fn default() -> Self {
Self::new()
}
}
impl PortAllocator {
pub fn new() -> Self {
Self {
sockets: RwLock::new(HashMap::new()),
}
}
pub async fn allocate(&self, addr: SocketAddr) -> Result<()> {
let socket = create_freebind_socket(&addr)
.map_err(|e| eyre!("Failed to create socket for {addr}: {e}"))?;
socket
.bind(&addr.into())
.map_err(|e| eyre!("Failed to reserve {addr}: {e}"))?;
socket
.listen(128)
.map_err(|e| eyre!("Failed to listen on {addr}: {e}"))?;
let std_listener: std::net::TcpListener = socket.into();
std_listener
.set_nonblocking(true)
.map_err(|e| eyre!("Failed to set nonblocking for {addr}: {e}"))?;
let listener = TcpListener::from_std(std_listener)
.map_err(|e| eyre!("Failed to create tokio listener for {addr}: {e}"))?;
info!("Reserved {addr}");
self.sockets.write().await.insert(addr, listener);
Ok(())
}
pub async fn deallocate(&self, addr: SocketAddr) {
self.sockets.write().await.remove(&addr);
info!("Released {addr}");
}
pub async fn is_allocated(&self, addr: SocketAddr) -> bool {
if self.sockets.read().await.contains_key(&addr) {
return true;
}
let Ok(socket) = create_freebind_socket(&addr) else {
return false;
};
match socket.bind(&addr.into()) {
Err(e) => e.kind() == std::io::ErrorKind::AddrInUse,
Ok(_) => false,
}
}
pub async fn deallocate_all(&self) {
let mut sockets = self.sockets.write().await;
let count = sockets.len();
sockets.clear();
info!("Released all {count} reservations");
}
}
/// Persistent mapping from string keys to assigned TCP ports.
///
/// Serialized as JSON by `PortAssignments::save` and read back by
/// `PortAssignments::load`.
#[derive(Debug, Clone, Default)]
#[derive(Serialize, Deserialize)]
pub struct PortAssignments {
    // Key -> assigned port; also the JSON field name on disk.
    assignments: HashMap<String, u16>,
}
impl PortAssignments {
    /// Loads assignments from the JSON file at `path`.
    ///
    /// Best-effort: a missing, unreadable, or malformed file yields an empty set of
    /// assignments instead of an error.
    pub fn load(path: &Path) -> Self {
        std::fs::read_to_string(path)
            .ok()
            .and_then(|s| serde_json::from_str(&s).ok())
            .unwrap_or_default()
    }

    /// Writes the assignments to `path` as pretty-printed JSON, creating parent
    /// directories as needed.
    ///
    /// The JSON is first written to a sibling `<path>.tmp` file, which is then renamed
    /// over `path`. An interrupted save therefore cannot leave a truncated file at
    /// `path`, which `load` would silently read as empty. The rename replaces the
    /// target atomically on platforms where `rename` does so, such as POSIX.
    ///
    /// # Errors
    ///
    /// Returns any I/O error from creating directories, writing, or renaming. JSON
    /// serialization errors are converted into `std::io::Error`.
    pub fn save(&self, path: &Path) -> Result<(), std::io::Error> {
        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent)?;
        }
        let contents = serde_json::to_string_pretty(self)?;

        let mut tmp_name = path.as_os_str().to_owned();
        tmp_name.push(".tmp");
        let tmp_path = std::path::PathBuf::from(tmp_name);

        std::fs::write(&tmp_path, contents)?;
        if let Err(e) = std::fs::rename(&tmp_path, path) {
            // Best-effort cleanup; the rename error is the one worth reporting.
            let _ = std::fs::remove_file(&tmp_path);
            return Err(e);
        }
        Ok(())
    }

    /// Returns the port assigned to `key`, if any.
    pub fn get(&self, key: &str) -> Option<u16> {
        self.assignments.get(key).copied()
    }

    /// Assigns `port` to `key`, replacing any previous assignment for that key.
    pub fn set(&mut self, key: String, port: u16) {
        self.assignments.insert(key, port);
    }

    /// Removes the assignment for `key`, returning the port it had.
    #[allow(dead_code)] // currently only exercised by tests
    pub fn remove(&mut self, key: &str) -> Option<u16> {
        self.assignments.remove(key)
    }

    /// Returns `true` if any key is currently assigned `port`.
    pub fn is_used(&self, port: u16) -> bool {
        self.assignments.values().any(|&p| p == port)
    }

    /// Returns the port for `key`, assigning a fresh one first if the key has none.
    ///
    /// Returns `None` only when no free port exists in the scanned range.
    pub fn get_or_allocate(&mut self, key: &str) -> Option<u16> {
        if let Some(&p) = self.assignments.get(key) {
            return Some(p);
        }
        let p = self.allocate_port()?;
        self.assignments.insert(key.to_string(), p);
        Some(p)
    }

    /// Finds the first port in `PORT_RANGE_START..=PORT_RANGE_END` that is neither
    /// assigned here nor currently bound on `0.0.0.0`.
    ///
    /// The port is only probed, not held, so another process could take it before
    /// the caller binds it.
    fn allocate_port(&self) -> Option<u16> {
        // Snapshot the assigned ports once so each candidate is an O(1) lookup,
        // instead of scanning every assignment for every candidate port.
        let used: std::collections::HashSet<u16> =
            self.assignments.values().copied().collect();
        (PORT_RANGE_START..=PORT_RANGE_END)
            .find(|port| !used.contains(port) && is_port_free(format!("0.0.0.0:{port}")))
    }
}
#[cfg(test)]
mod tests {
    // Explicit import shadows the tokio `TcpListener` pulled in by `super::*`.
    use std::net::TcpListener;

    use tempfile::TempDir;

    use super::*;

    #[test]
    fn set_and_get() {
        let mut pa = PortAssignments::default();
        pa.set("example-drive-80".into(), 32000);
        assert_eq!(pa.get("example-drive-80"), Some(32000));
    }

    #[test]
    fn remove() {
        let mut pa = PortAssignments::default();
        pa.set("key".into(), 32000);
        assert_eq!(pa.remove("key"), Some(32000));
        assert_eq!(pa.get("key"), None);
    }

    #[test]
    fn persistence() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("ports.json");
        let mut pa = PortAssignments::default();
        pa.set("s1".into(), 40000);
        pa.set("s2".into(), 40001);
        pa.save(&path).unwrap();
        let loaded = PortAssignments::load(&path);
        assert_eq!(loaded.get("s1"), Some(40000));
        assert_eq!(loaded.get("s2"), Some(40001));
    }

    #[test]
    fn load_missing_file_returns_default() {
        let dir = TempDir::new().unwrap();
        let loaded = PortAssignments::load(&dir.path().join("missing.json"));
        assert_eq!(loaded.get("anything"), None);
    }

    #[test]
    fn load_malformed_file_returns_default() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("ports.json");
        std::fs::write(&path, "not json").unwrap();
        assert_eq!(PortAssignments::load(&path).get("s1"), None);
    }

    #[test]
    fn save_creates_parent_directories() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("nested").join("deeper").join("ports.json");
        let mut pa = PortAssignments::default();
        pa.set("s1".into(), 40000);
        pa.save(&path).unwrap();
        assert_eq!(PortAssignments::load(&path).get("s1"), Some(40000));
    }

    #[test]
    fn save_overwrites_previous_contents() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("ports.json");
        let mut pa = PortAssignments::default();
        pa.set("s1".into(), 40000);
        pa.save(&path).unwrap();
        pa.remove("s1");
        pa.set("s2".into(), 40001);
        pa.save(&path).unwrap();
        let loaded = PortAssignments::load(&path);
        assert_eq!(loaded.get("s1"), None);
        assert_eq!(loaded.get("s2"), Some(40001));
    }

    #[test]
    fn is_used_tracks_assigned_ports() {
        let mut pa = PortAssignments::default();
        assert!(!pa.is_used(40000));
        pa.set("s1".into(), 40000);
        assert!(pa.is_used(40000));
        assert!(!pa.is_used(40001));
    }

    #[test]
    fn get_or_allocate_returns_existing() {
        let mut pa = PortAssignments::default();
        pa.set("key".into(), 32000);
        assert_eq!(pa.get_or_allocate("key"), Some(32000));
    }

    #[test]
    fn get_or_allocate_creates_new() {
        let mut pa = PortAssignments::default();
        let p = pa.get_or_allocate("new-key").unwrap();
        assert!((PORT_RANGE_START..=PORT_RANGE_END).contains(&p));
        assert_eq!(pa.get("new-key"), Some(p));
    }

    #[test]
    fn is_port_free_localhost() {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        drop(listener);
        assert!(is_port_free(format!("127.0.0.1:{port}")));
    }

    #[test]
    fn is_port_free_occupied() {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        assert!(!is_port_free(format!("127.0.0.1:{port}")));
        drop(listener);
        assert!(is_port_free(format!("127.0.0.1:{port}")));
    }

    #[test]
    fn allocate_port_assigns_unique() {
        let mut pa = PortAssignments::default();
        let p1 = pa.allocate_port().unwrap();
        pa.set("s1".into(), p1);
        let p2 = pa.allocate_port().unwrap();
        assert_ne!(p1, p2);
    }
}