use std::collections::BTreeSet;
use std::sync::Mutex;
use crate::{Port, PortRange, is_free};
/// Thread-safe, in-process bookkeeping of ports handed out from a fixed `range`.
///
/// A port stays reserved until it is explicitly released via `release_port` or
/// `release_ports`; merely dropping a returned port number does not release it.
#[derive(Debug)]
pub struct PortReservation<R: PortRange> {
    /// Candidate ports, scanned in iteration order when reserving.
    range: R,
    /// Ports currently reserved through this tracker; the mutex lets one tracker be
    /// shared between threads.
    reserved_ports: Mutex<BTreeSet<Port>>,
}
/// Tracks the half-open range `15000..65535`, so port 65535 itself is never handed out.
impl Default for PortReservation<std::ops::Range<Port>> {
    fn default() -> Self {
        let range = std::ops::Range {
            start: 15000,
            end: 65535,
        };
        Self::new(range)
    }
}
/// Tracks the closed range `15000..=65535`, so port 65535 is a candidate as well.
impl Default for PortReservation<std::ops::RangeInclusive<Port>> {
    fn default() -> Self {
        Self::new(std::ops::RangeInclusive::new(15000, 65535))
    }
}
/// Reports whether `port` may be handed out.
///
/// The port must be absent from `ports` (the tracker's current reservations) and must
/// pass the crate-level `is_free` check. The set lookup runs first, so the `is_free`
/// probe is skipped entirely for ports that are already reserved.
fn reservation_is_free(ports: &BTreeSet<Port>, port: Port) -> bool {
    if ports.contains(&port) {
        return false;
    }
    is_free(port)
}
impl<R: PortRange> PortReservation<R> {
    /// Creates a tracker that hands out ports from `range`, with nothing reserved yet.
    #[must_use]
    pub const fn new(range: R) -> Self {
        Self {
            range,
            reserved_ports: Mutex::new(BTreeSet::new()),
        }
    }

    /// Locks the reserved-port set, recovering the guard if a previous holder panicked.
    ///
    /// The set is only changed through individual `insert`/`remove` calls, so a panic
    /// while the lock was held (e.g. inside the crate-level `is_free` probe) cannot
    /// leave it structurally inconsistent. Propagating the poison instead would make
    /// every later call on this tracker panic.
    fn lock_reserved(&self) -> std::sync::MutexGuard<'_, BTreeSet<Port>> {
        self.reserved_ports
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Reserves up to `num_ports` ports, scanning the range in iteration order.
    ///
    /// A port is taken only if this tracker has not already reserved it and the
    /// crate-level `is_free` check accepts it. The whole scan runs under a single lock,
    /// so concurrent callers never receive the same port.
    ///
    /// Returns fewer than `num_ports` ports (possibly none) when the range runs out of
    /// acceptable candidates. The ports stay reserved until released.
    #[must_use = "ignoring the returned ports leaks their reservation"]
    pub fn reserve_ports(&self, num_ports: usize) -> Vec<Port> {
        let mut reserved_ports = self.lock_reserved();
        let mut ports = Vec::new();
        for port in self.range.iter() {
            // Checked before probing, so `num_ports == 0` never calls `is_free`.
            if ports.len() >= num_ports {
                break;
            }
            if reservation_is_free(&reserved_ports, port) {
                reserved_ports.insert(port);
                ports.push(port);
            }
        }
        drop(reserved_ports);
        ports
    }

    /// Reserves the first acceptable port in the range.
    ///
    /// Returns `None` if no candidate is free. The same acceptance rules as
    /// `reserve_ports` apply.
    #[must_use = "ignoring the returned port leaks its reservation"]
    pub fn reserve_port(&self) -> Option<Port> {
        let mut reserved_ports = self.lock_reserved();
        let port = self
            .range
            .iter()
            .find(|x| reservation_is_free(&reserved_ports, *x))?;
        reserved_ports.insert(port);
        drop(reserved_ports);
        Some(port)
    }

    /// Releases every port in `ports`; ports that were not reserved are ignored.
    ///
    /// Accepts any `IntoIterator` (a `Vec`, an array, or an iterator). The input is
    /// drained before the lock is taken, so caller-supplied iterator code never runs
    /// while the (non-reentrant) mutex is held.
    pub fn release_ports(&self, ports: impl IntoIterator<Item = Port>) {
        let to_release: Vec<Port> = ports.into_iter().collect();
        let mut reserved_ports = self.lock_reserved();
        for port in &to_release {
            reserved_ports.remove(port);
        }
    }

    /// Releases `port`. This is a no-op if it was not reserved.
    pub fn release_port(&self, port: Port) {
        self.lock_reserved().remove(&port);
    }

    /// Returns whether `port` is currently reserved by this tracker.
    #[must_use]
    pub fn is_reserved(&self, port: Port) -> bool {
        self.lock_reserved().contains(&port)
    }

    /// Returns whether `port` is *not* reserved by this tracker.
    ///
    /// This consults only the tracker's own bookkeeping. It does not call the
    /// crate-level `is_free` probe used during reservation, and it returns `true` for
    /// ports outside the tracked range.
    #[must_use]
    pub fn is_free(&self, port: Port) -> bool {
        !self.is_reserved(port)
    }
}
#[cfg(test)]
mod tests {
    use serial_test::serial;

    use super::*;
    use crate::test_utils::{next_port_range, next_port_range_inclusive};

    #[test_log::test]
    fn test_reserve_port() {
        let range = next_port_range(100);
        let reservation = PortReservation::new(range.clone());
        let port = reservation
            .reserve_port()
            .expect("a 100-port range should yield a port");
        assert!(range.contains(&port));
        assert!(reservation.is_reserved(port));
    }

    #[test_log::test]
    fn test_reserve_ports() {
        let range = next_port_range(100);
        let reservation = PortReservation::new(range.clone());
        let ports = reservation.reserve_ports(5);
        assert_eq!(ports.len(), 5);
        for port in &ports {
            assert!(range.contains(port));
            assert!(reservation.is_reserved(*port));
        }
        let mut unique_ports = ports.clone();
        unique_ports.sort_unstable();
        unique_ports.dedup();
        assert_eq!(unique_ports.len(), ports.len());
    }

    #[test_log::test]
    fn test_release_port() {
        let range = next_port_range(100);
        let reservation = PortReservation::new(range);
        let port = reservation.reserve_port().unwrap();
        assert!(reservation.is_reserved(port));
        reservation.release_port(port);
        assert!(!reservation.is_reserved(port));
    }

    #[test_log::test]
    fn test_release_ports() {
        let range = next_port_range(100);
        let reservation = PortReservation::new(range);
        let ports = reservation.reserve_ports(10);
        assert_eq!(ports.len(), 10);
        for port in &ports {
            assert!(reservation.is_reserved(*port));
        }
        reservation.release_ports(ports.iter().copied());
        for port in ports {
            assert!(!reservation.is_reserved(port));
        }
    }

    // The default trackers share the fixed 15000.. range, so these run serially.
    #[test_log::test]
    #[serial]
    fn test_default_implementation() {
        let reservation: PortReservation<std::ops::Range<u16>> = PortReservation::default();
        let port = reservation
            .reserve_port()
            .expect("the default range should yield a port");
        assert!((15000..65535).contains(&port));
    }

    #[test_log::test]
    #[serial]
    fn test_default_implementation_inclusive() {
        let reservation: PortReservation<std::ops::RangeInclusive<u16>> =
            PortReservation::default();
        let port = reservation
            .reserve_port()
            .expect("the default inclusive range should yield a port");
        assert!((15000..=65535).contains(&port));
    }

    #[test_log::test]
    fn test_reserve_more_than_available() {
        let range = next_port_range(2);
        let reservation = PortReservation::new(range.clone());
        let ports = reservation.reserve_ports(10);
        assert!(ports.len() <= 2);
        assert!(!ports.is_empty());
        for port in ports {
            assert!(range.contains(&port));
            assert!(reservation.is_reserved(port));
        }
    }

    // Reserves from the very top of the port space (65530..65535). Some of these ports
    // may be rejected by the crate-level `is_free` check, so the count is only bounded.
    #[test_log::test]
    fn test_no_free_ports() {
        let range = 65530..65535;
        let reservation = PortReservation::new(range.clone());
        let ports = reservation.reserve_ports(5);
        assert!(ports.len() <= 5);
        for port in ports {
            assert!(range.contains(&port));
            assert!(reservation.is_reserved(port));
        }
    }

    #[test_log::test]
    fn test_inclusive_range() {
        let range = next_port_range_inclusive(11);
        let reservation = PortReservation::new(range.clone());
        let ports = reservation.reserve_ports(5);
        assert_eq!(ports.len(), 5);
        for port in ports {
            assert!(range.contains(&port));
            assert!(reservation.is_reserved(port));
        }
    }

    #[test_log::test]
    fn test_reserve_after_release() {
        let range = next_port_range(100);
        let reservation = PortReservation::new(range.clone());
        let ports = reservation.reserve_ports(3);
        assert_eq!(ports.len(), 3);
        reservation.release_ports(ports.iter().copied());
        let new_ports = reservation.reserve_ports(3);
        assert_eq!(new_ports.len(), 3);
        for port in new_ports {
            assert!(range.contains(&port));
            assert!(reservation.is_reserved(port));
        }
    }

    #[test_log::test]
    fn test_concurrent_reservations() {
        use std::sync::Arc;
        use std::thread;

        let range = next_port_range(100);
        let reservation = Arc::new(PortReservation::new(range));
        let mut handles = vec![];
        for _ in 0..10 {
            let reservation_clone = Arc::clone(&reservation);
            let handle = thread::spawn(move || reservation_clone.reserve_ports(5));
            handles.push(handle);
        }
        let mut all_ports = Vec::new();
        for handle in handles {
            let ports = handle.join().unwrap();
            all_ports.extend(ports);
        }
        let mut unique_ports = all_ports.clone();
        unique_ports.sort_unstable();
        unique_ports.dedup();
        assert_eq!(
            unique_ports.len(),
            all_ports.len(),
            "Concurrent reservations should not result in duplicate ports"
        );
        for port in &all_ports {
            assert!(
                reservation.is_reserved(*port),
                "Port {port} should be marked as reserved"
            );
        }
    }

    #[test_log::test]
    fn test_concurrent_reserve_and_release() {
        use std::sync::Arc;
        use std::thread;

        let range = next_port_range(100);
        let reservation = Arc::new(PortReservation::new(range));
        let mut handles = vec![];
        for i in 0..5 {
            let reservation_clone = Arc::clone(&reservation);
            // Even-numbered threads give their ports back; odd ones keep them.
            let handle = thread::spawn(move || {
                let ports = reservation_clone.reserve_ports(3);
                if i % 2 == 0 {
                    reservation_clone.release_ports(ports.iter().copied());
                    vec![]
                } else {
                    ports
                }
            });
            handles.push(handle);
        }
        let mut remaining_ports = Vec::new();
        for handle in handles {
            let ports = handle.join().unwrap();
            remaining_ports.extend(ports);
        }
        for port in &remaining_ports {
            assert!(
                reservation.is_reserved(*port),
                "Port {port} should still be reserved"
            );
        }
        let mut unique_ports = remaining_ports.clone();
        unique_ports.sort_unstable();
        unique_ports.dedup();
        assert_eq!(
            unique_ports.len(),
            remaining_ports.len(),
            "No duplicate ports should remain"
        );
    }

    #[test_log::test]
    fn test_is_reserved_non_reserved_port() {
        let range = next_port_range(100);
        let reservation = PortReservation::new(range.clone());
        assert!(!reservation.is_reserved(range.start + 50));
    }

    #[test_log::test]
    fn test_release_non_reserved_port() {
        let range = next_port_range(100);
        let reservation = PortReservation::new(range.clone());
        let port = range.start + 50;
        reservation.release_port(port);
        assert!(!reservation.is_reserved(port));
    }

    #[test_log::test]
    fn test_reserve_port_returns_none_when_all_occupied() {
        let range = next_port_range(2);
        let reservation = PortReservation::new(range);
        let port1 = reservation.reserve_port();
        let port2 = reservation.reserve_port();
        // One of the two ports may be rejected by the crate-level `is_free` check,
        // so only require that at least one reservation succeeded.
        assert!(port1.is_some() || port2.is_some());
        // With both ports of the two-port range reserved, a third request must fail.
        if let (Some(port1), Some(port2)) = (port1, port2) {
            assert_ne!(port1, port2);
            assert!(
                reservation.reserve_port().is_none(),
                "Should run out of ports in a small range after multiple reservations"
            );
        }
    }

    #[test_log::test]
    fn test_double_release() {
        let range = next_port_range(100);
        let reservation = PortReservation::new(range);
        let port = reservation.reserve_port().unwrap();
        assert!(reservation.is_reserved(port));
        reservation.release_port(port);
        assert!(!reservation.is_reserved(port));
        reservation.release_port(port);
        assert!(!reservation.is_reserved(port));
    }

    #[test_log::test]
    fn test_reserve_zero_ports() {
        let range = next_port_range(100);
        let reservation = PortReservation::new(range);
        let ports = reservation.reserve_ports(0);
        assert!(ports.is_empty());
    }

    #[test_log::test]
    fn test_release_empty_iterator() {
        let range = next_port_range(100);
        let reservation = PortReservation::new(range);
        let port = reservation.reserve_port().unwrap();
        reservation.release_ports(std::iter::empty());
        assert!(reservation.is_reserved(port));
    }

    // Saturating arithmetic keeps the probes in `u16` even when the range sits at the
    // edge of the port space; a saturated value still lies outside the half-open range.
    #[test_log::test]
    fn test_is_reserved_port_outside_range() {
        let range = next_port_range(100);
        let reservation = PortReservation::new(range.clone());
        assert!(!reservation.is_reserved(range.start.saturating_sub(1)));
        assert!(!reservation.is_reserved(range.end));
        assert!(!reservation.is_reserved(range.end.saturating_add(1000)));
    }

    #[test_log::test]
    fn test_is_free_port_outside_range() {
        let range = next_port_range(100);
        let reservation = PortReservation::new(range.clone());
        assert!(reservation.is_free(range.start.saturating_sub(1)));
        assert!(reservation.is_free(range.end));
        assert!(reservation.is_free(range.end.saturating_add(1000)));
    }

    #[test_log::test]
    fn test_sequential_single_port_reservations() {
        let range = next_port_range(100);
        let reservation = PortReservation::new(range);
        let port1 = reservation.reserve_port().expect("first port");
        let port2 = reservation.reserve_port().expect("second port");
        let port3 = reservation.reserve_port().expect("third port");
        assert_ne!(port1, port2);
        assert_ne!(port2, port3);
        assert_ne!(port1, port3);
        assert!(reservation.is_reserved(port1));
        assert!(reservation.is_reserved(port2));
        assert!(reservation.is_reserved(port3));
    }

    #[test_log::test]
    fn test_release_specific_subset_of_ports() {
        let range = next_port_range(100);
        let reservation = PortReservation::new(range);
        let ports = reservation.reserve_ports(5);
        assert_eq!(ports.len(), 5);
        let middle_ports: Vec<_> = ports[1..4].to_vec();
        reservation.release_ports(middle_ports.iter().copied());
        assert!(reservation.is_reserved(ports[0]));
        assert!(reservation.is_reserved(ports[4]));
        assert!(!reservation.is_reserved(ports[1]));
        assert!(!reservation.is_reserved(ports[2]));
        assert!(!reservation.is_reserved(ports[3]));
    }

    #[test_log::test]
    fn test_reservation_is_free_helper() {
        use std::collections::BTreeSet;

        let mut reserved = BTreeSet::new();
        let range = next_port_range(100);
        let port = crate::pick_unused_port(range);
        // If no unused port is found, there is nothing meaningful to check.
        if let Some(port) = port {
            assert!(reservation_is_free(&reserved, port));
            reserved.insert(port);
            assert!(!reservation_is_free(&reserved, port));
        }
    }
}