1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license OR Apache 2.0
//! Types for building an Snocat server and accepting, authenticating, and routing connections
#![warn(unused_imports)]
use futures::future::FutureExt;
use std::collections::HashSet;
use std::{ops::RangeInclusive, sync::Arc};
use tokio::sync::Mutex;

pub mod modular;

#[derive(Debug, Clone)]
pub struct PortRangeAllocator {
  range: std::ops::RangeInclusive<u16>,
  allocated: Arc<Mutex<std::collections::HashSet<u16>>>,
  mark_queue: tokio::sync::mpsc::UnboundedSender<u16>,
  // UnboundedReceiver does not implement clone, so we need an ArcMut of it
  mark_receiver: Arc<Mutex<tokio::sync::mpsc::UnboundedReceiver<u16>>>,
}

#[derive(thiserror::Error, Debug)]
pub enum PortRangeAllocationError {
  #[error("No ports were available to be allocated in range {0:?}")]
  NoFreePorts(std::ops::RangeInclusive<u16>),
}

impl PortRangeAllocator {
  pub fn new<T: Into<u16>>(bind_port_range: std::ops::RangeInclusive<T>) -> PortRangeAllocator {
    let (start, end): (u16, u16) = {
      let (a, b) = bind_port_range.into_inner();
      (a.into(), b.into())
    };
    let (mark_sender, mark_receiver) = tokio::sync::mpsc::unbounded_channel();
    PortRangeAllocator {
      range: std::ops::RangeInclusive::new(start, end),
      allocated: Default::default(),
      mark_queue: mark_sender,
      mark_receiver: Arc::new(Mutex::new(mark_receiver)),
    }
  }

  pub async fn allocate(&self) -> Result<PortRangeAllocationHandle, PortRangeAllocationError> {
    // Used for cleaning up in the deallocator
    let cloned_self = self.clone();
    let range = self.range.clone();
    let mark_receiver = Arc::clone(&self.mark_receiver);
    let mut lock = self.allocated.lock().await;

    // Consume existing marks for freed ports
    {
      let mut mark_receiver = mark_receiver.lock().await;
      Self::cleanup_freed_ports(&mut *lock, &mut *mark_receiver);
    }

    let port = range
      .clone()
      .into_iter()
      .filter(|test_port| !lock.contains(test_port))
      .min()
      .ok_or_else(|| PortRangeAllocationError::NoFreePorts(range.clone()))?;

    let allocation = PortRangeAllocationHandle::new(port, cloned_self);
    lock.insert(allocation.port);
    Ok(allocation)
  }

  pub async fn free(&self, port: u16) -> Result<bool, anyhow::Error> {
    let mark_receiver = Arc::clone(&self.mark_receiver);
    let mut lock = self.allocated.lock().await;
    let removed = lock.remove(&port);
    if removed {
      tracing::trace!(port = port, "unbound port");
    }
    let mut mark_receiver = mark_receiver.lock().await;
    Self::cleanup_freed_ports(&mut *lock, &mut *mark_receiver);
    Ok(removed)
  }

  fn cleanup_freed_ports(
    allocations: &mut HashSet<u16>,
    mark_receiver: &mut tokio::sync::mpsc::UnboundedReceiver<u16>,
  ) {
    // recv waits forever if a sender can still produce values
    // skip that by only receiving those immediately available
    // HACK: Relies on unbounded receivers being immediately available without intermediate polling
    while let Some(Some(marked)) = mark_receiver.recv().now_or_never() {
      let removed = allocations.remove(&marked);
      if removed {
        tracing::trace!(port = marked, "unbound marked port");
      }
    }
  }

  pub fn mark_freed(&self, port: u16) {
    match self.allocated.try_lock() {
      // fast path if synchronous is possible, skipping the mark queue
      Ok(mut allocations) => {
        // remove specified port; we don't care if it succeeded or not
        let _removed = allocations.remove(&port);
        return;
      }
      Err(_would_block) => {
        match self.mark_queue.send(port) {
          // Message queued, do nothing
          Ok(()) => (),
          // Other side was closed
          // Without a receiver, we don't actually need to free anything, so do nothing
          Err(_send_error) => (),
        }
      }
    }
  }

  pub fn range(&self) -> &RangeInclusive<u16> {
    &self.range
  }
}

pub struct PortRangeAllocationHandle {
  port: u16,
  allocated_in: Option<PortRangeAllocator>,
}

impl PortRangeAllocationHandle {
  pub fn new(port: u16, allocated_in: PortRangeAllocator) -> Self {
    Self {
      port,
      allocated_in: Some(allocated_in),
    }
  }
  pub fn port(&self) -> u16 {
    self.port
  }
}

impl Drop for PortRangeAllocationHandle {
  fn drop(&mut self) {
    match std::mem::replace(&mut self.allocated_in, None) {
      None => (),
      Some(allocator) => {
        allocator.mark_freed(self.port);
      }
    }
  }
}