use std::{fmt::Display, str::FromStr};
use rand::Rng;
use serde::{Deserialize, Serialize};
/// An inclusive range of ports, `start..=end`.
///
/// Its text form (used by `Display`, `FromStr` and serde) is `"start..end"`, where `end`
/// is part of the range despite the `..` spelling.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PortRange {
    start: u16,
    end: u16,
}

impl PortRange {
    /// Number of ports in the range, counting both endpoints.
    ///
    /// Computed in `usize` so the full `0..=65535` range (65 536 ports) cannot overflow
    /// `u16`. Returns 0 if `start > end`.
    pub fn size(&self) -> usize {
        (usize::from(self.end) + 1).saturating_sub(usize::from(self.start))
    }

    /// Returns `true` if `port` lies within `start..=end`.
    pub fn contains(&self, port: u16) -> bool {
        (self.start..=self.end).contains(&port)
    }

    /// First port of the range (inclusive).
    pub fn start(&self) -> u16 {
        self.start
    }

    /// Last port of the range (inclusive).
    pub fn end(&self) -> u16 {
        self.end
    }
}

impl Default for PortRange {
    /// The IANA dynamic/private port range, `49152..=65535` (RFC 6335).
    fn default() -> Self {
        Self {
            start: 49152,
            end: 65535,
        }
    }
}

impl From<std::ops::Range<u16>> for PortRange {
    /// Builds a range from `start..end`, treating `end` as **inclusive**
    /// (e.g. `50000..50069` covers 70 ports), mirroring the `"start..end"` text form.
    ///
    /// # Panics
    /// Panics if `range.start > range.end`; enforced in release builds too so an inverted
    /// range can never be constructed through this conversion.
    fn from(range: std::ops::Range<u16>) -> Self {
        assert!(range.start <= range.end, "Port range start must be <= end");
        Self {
            start: range.start,
            end: range.end,
        }
    }
}

impl Display for PortRange {
    /// Formats as `"start..end"` (both ends inclusive).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}..{}", self.start, self.end)
    }
}
/// Error produced when text cannot be parsed into a [`PortRange`].
///
/// The wrapped string is exactly what `Display` prints: either the offending input text or
/// the message of the underlying integer-parse failure.
#[derive(Debug)]
pub struct PortRangeParseError(String);

impl std::fmt::Display for PortRangeParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}

impl std::error::Error for PortRangeParseError {}

impl From<std::num::ParseIntError> for PortRangeParseError {
    /// Wraps an integer-parse failure, keeping only its message text.
    fn from(error: std::num::ParseIntError) -> Self {
        let message = error.to_string();
        Self(message)
    }
}
impl FromStr for PortRange {
    type Err = PortRangeParseError;

    /// Parses the `"start..end"` form produced by `Display`; both ends are inclusive.
    ///
    /// # Errors
    /// Returns [`PortRangeParseError`] if the `..` separator is missing, if either side is
    /// not a valid `u16`, or if `start > end` (an inverted range would break `size()` and
    /// any allocator built from it).
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let (start, end) = s
            .split_once("..")
            // Lazy: only allocate the error message on failure.
            .ok_or_else(|| PortRangeParseError(s.to_string()))?;
        let start: u16 = start.parse()?;
        let end: u16 = end.parse()?;
        if start > end {
            return Err(PortRangeParseError(format!(
                "port range start must be <= end: {s}"
            )));
        }
        Ok(Self { start, end })
    }
}
impl Serialize for PortRange {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_str(&self.to_string())
}
}
impl<'de> Deserialize<'de> for PortRange {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
Self::from_str(&s).map_err(|e| serde::de::Error::custom(e.0))
}
}
/// Value of a single slot in the port allocator's bitmap.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Bit {
    /// Bit cleared (0): the port is free.
    Low,
    /// Bit set (1): the port is allocated.
    High,
}
/// Hands out ports from a [`PortRange`], tracking allocations in a bitmap.
///
/// Bit `i` of `buckets[b]`, counted from the most-significant bit, represents port
/// `port_range.start() + b * 64 + i`; a set bit means the port is allocated.
#[derive(Debug)]
pub struct PortAllocator {
    /// Range that ports are drawn from.
    port_range: PortRange,
    /// Allocation bitmap, 64 ports per bucket.
    buckets: Vec<u64>,
    /// Number of ports currently allocated.
    allocated: usize,
    /// Number of valid bits in the last bucket (1..=64); its remaining low-order bits are
    /// padding and never represent a port.
    bit_len: u32,
    /// Index of the last bucket.
    max_offset: usize,
}
impl Default for PortAllocator {
fn default() -> Self {
Self::new(PortRange::default())
}
}
impl PortAllocator {
    /// Creates an allocator over `port_range` with every port free.
    pub fn new(port_range: PortRange) -> Self {
        let capacity = port_range.size();
        // One u64 per 64 ports, rounded up.
        let bucket_size = (capacity + 63) / 64;
        // The last bucket may be only partially used; remember how many of its bits are ports.
        let tail_bits = capacity % 64;
        let bit_len = if tail_bits == 0 { 64 } else { tail_bits as u32 };
        Self {
            bit_len,
            buckets: vec![0; bucket_size],
            // Saturating so an empty range yields no buckets instead of an underflow.
            max_offset: bucket_size.saturating_sub(1),
            allocated: 0,
            port_range,
        }
    }

    /// Total number of ports this allocator manages.
    pub fn capacity(&self) -> usize {
        self.port_range.size()
    }

    /// The range ports are drawn from.
    pub fn port_range(&self) -> &PortRange {
        &self.port_range
    }

    /// Number of ports currently allocated.
    pub fn len(&self) -> usize {
        self.allocated
    }

    /// Returns `true` if no ports are currently allocated.
    pub fn is_empty(&self) -> bool {
        self.allocated == 0
    }

    /// Allocates a free port and returns it, or `None` if every port is taken.
    ///
    /// Buckets of 64 ports are scanned starting at bucket `start` (wrapped modulo the number
    /// of buckets) or, when `start` is `None`, at a randomly chosen bucket. The scan wraps
    /// around once. It returns the lowest free port in the first bucket that has room.
    pub fn allocate(&mut self, start: Option<usize>) -> Option<u16> {
        let bucket_count = self.buckets.len();
        if bucket_count == 0 {
            return None;
        }
        // Out-of-range hints wrap instead of indexing past the end of `buckets`.
        let first = match start {
            Some(hint) => hint % bucket_count,
            None => rand::rng().random_range(0..=self.max_offset),
        };
        let (offset, index) = (0..bucket_count)
            .map(|step| (first + step) % bucket_count)
            .find_map(|offset| self.first_free_bit(offset).map(|index| (offset, index)))?;
        self.set_bit(offset, index, Bit::High);
        self.allocated += 1;
        let num = u16::try_from(offset * 64 + index)
            .expect("bit position is < capacity <= 65536, so it fits in u16");
        Some(self.port_range.start() + num)
    }

    /// Index (from the MSB) of the first free bit of bucket `offset` that maps to a real
    /// port, if any.
    fn first_free_bit(&self, offset: usize) -> Option<usize> {
        // `leading_ones` counts set bits from the MSB, i.e. it is the index of the first
        // clear bit (64 when the bucket is full).
        let index = self.buckets[offset].leading_ones();
        // Only the first `bit_len` bits of the last bucket correspond to ports.
        let limit = if offset == self.max_offset { self.bit_len } else { 64 };
        (index < limit).then_some(index as usize)
    }

    /// Mask selecting bit `index`, counted from the most-significant bit.
    fn bit_mask(index: usize) -> u64 {
        1u64 << (63 - index)
    }

    /// Sets (`Bit::High`) or clears (`Bit::Low`) bit `index` of bucket `bucket`.
    ///
    /// This is a raw bitmap write: it does not change the count reported by
    /// [`len`](Self::len).
    ///
    /// # Panics
    /// Panics if `bucket` is out of bounds or `index >= 64`.
    pub fn set_bit(&mut self, bucket: usize, index: usize, bit: Bit) {
        assert!(index < 64, "bit index {index} out of range 0..64");
        let mask = Self::bit_mask(index);
        let word = &mut self.buckets[bucket];
        match bit {
            Bit::High => *word |= mask,
            Bit::Low => *word &= !mask,
        }
    }

    /// Releases `port` so it can be allocated again. Releasing a port that is not currently
    /// allocated is a no-op.
    ///
    /// # Panics
    /// Panics if `port` is outside the allocator's range.
    pub fn deallocate(&mut self, port: u16) {
        assert!(
            self.port_range.contains(port),
            "port {port} is outside {}",
            self.port_range
        );
        let offset = usize::from(port - self.port_range.start());
        let (bucket, index) = (offset / 64, offset % 64);
        if (self.buckets[bucket] & Self::bit_mask(index)) == 0 {
            // Already free: the allocation count must not go down.
            return;
        }
        self.set_bit(bucket, index, Bit::Low);
        self.allocated -= 1;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Draining a 70-port range (one full bucket plus a 6-bit tail bucket) must hand out
    /// every port exactly once, including both endpoints.
    #[test]
    fn allocate_all_ports_without_gaps_includes_tail_bits() {
        let range = PortRange::from(50000..50069);
        let mut allocator = PortAllocator::new(range);
        let mut seen = HashSet::new();
        for port in std::iter::from_fn(|| allocator.allocate(None)) {
            assert!(range.contains(port));
            assert!(seen.insert(port));
        }
        assert_eq!(allocator.capacity(), seen.len());
        let lowest = seen.iter().copied().min().unwrap();
        let highest = seen.iter().copied().max().unwrap();
        assert_eq!(range.start(), lowest);
        assert_eq!(range.end(), highest);
    }

    /// With two buckets and a random starting bucket, the first port handed out by fresh
    /// allocators should not always be the same.
    #[test]
    fn random_allocation_varies_first_port() {
        let range = PortRange::from(50000..50127);
        let distinct_first_ports: HashSet<u16> = (0..128)
            .filter_map(|_| PortAllocator::new(range).allocate(None))
            .collect();
        assert!(distinct_first_ports.len() > 1);
    }
}