use std::collections::{HashMap, VecDeque};
use thiserror::Error;
use tokio::sync::{broadcast, mpsc, oneshot};
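/// Upper bound on permits granted per `ProcessQueue` pass, so one busy
/// resource cannot monopolize the actor's event loop; leftover waiters are
/// handled by a re-enqueued `ProcessQueue` command.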
const PERMIT_GRANT_BATCH_SIZE: usize = 64;
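/// An RAII permit for one unit of a resource. Dropping the guard returns the
/// permit to the `ResourceManager`.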
#[derive(Debug)]
pub struct PermitGuard {
pub resource_type: ResourceType,
control_tx: mpsc::Sender<ControlCommand>,
}
impl Drop for PermitGuard {
fn drop(&mut self) {
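        // Best-effort release: try_send never blocks inside drop. If the
        // control channel is full or the manager is gone, the permit leaks
        // rather than risking a blocking send on an async runtime here.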
let _ = self.control_tx.try_send(ControlCommand::Release {
resource: self.resource_type,
});
}
}
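/// The kinds of capacity the manager tracks. `Reserve` is tracked for
/// accounting only: it has no acquire channel, so it can never be requested
/// through the client.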
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub enum ResourceType {
Reserve,
PeerConnection,
DiskRead,
DiskWrite,
}
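/// Errors returned to callers of the `ResourceManagerClient` API.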
#[derive(Error, Debug, Clone)]
pub enum ResourceManagerError {
#[error("The resource manager has been shut down.")]
ManagerShutdown,
#[error("The request queue for the resource is full.")]
QueueFull,
}
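/// Cheap-to-clone handle for talking to the `ResourceManager` actor. Each
/// resource has its own bounded acquire channel, so a backlog on one resource
/// does not block requests for another.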
#[derive(Clone, Debug)]
pub struct ResourceManagerClient {
acquire_txs: HashMap<ResourceType, mpsc::Sender<AcquireCommand>>,
control_tx: mpsc::Sender<ControlCommand>,
}
impl ResourceManagerClient {
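    /// Requests a `PeerConnection` permit, queueing if the limit is reached.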
pub async fn acquire_peer_connection(&self) -> Result<PermitGuard, ResourceManagerError> {
self.acquire(ResourceType::PeerConnection).await
}
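    /// Requests a `DiskRead` permit, queueing if the limit is reached.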
pub async fn acquire_disk_read(&self) -> Result<PermitGuard, ResourceManagerError> {
self.acquire(ResourceType::DiskRead).await
}
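    /// Requests a `DiskWrite` permit, queueing if the limit is reached.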
pub async fn acquire_disk_write(&self) -> Result<PermitGuard, ResourceManagerError> {
self.acquire(ResourceType::DiskWrite).await
}
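    /// Replaces the limits of the listed resources. Raising a limit wakes
    /// queued waiters; lowering one never revokes outstanding permits, so
    /// `in_use` may exceed the new limit until enough guards are dropped.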
pub async fn update_limits(
&self,
new_limits: HashMap<ResourceType, usize>,
) -> Result<(), ResourceManagerError> {
let command = ControlCommand::UpdateLimits { limits: new_limits };
self.control_tx
.send(command)
.await
.map_err(|_| ResourceManagerError::ManagerShutdown)
}
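    /// Sends an acquire command and awaits the actor's verdict. Panics if
    /// `resource` has no registered acquire channel (i.e. `Reserve`).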
    async fn acquire(&self, resource: ResourceType) -> Result<PermitGuard, ResourceManagerError> {
        let (respond_to, rx) = oneshot::channel();
        let command = AcquireCommand { respond_to };
        let tx = self
            .acquire_txs
            .get(&resource)
            .expect("no acquire channel is registered for this resource type");
        tx.send(command)
            .await
            .map_err(|_| ResourceManagerError::ManagerShutdown)?;
        // A dropped responder means the actor shut down before answering.
        rx.await.map_err(|_| ResourceManagerError::ManagerShutdown)?
    }
}
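/// A permit request; the actor answers on the enclosed oneshot channel.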
#[derive(Debug)]
struct AcquireCommand {
respond_to: oneshot::Sender<Result<PermitGuard, ResourceManagerError>>,
}
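/// Commands on the actor's control channel: releases from dropped guards,
/// limit updates from the client, and self-scheduled queue processing.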
#[derive(Debug)]
pub enum ControlCommand {
Release {
resource: ResourceType,
},
UpdateLimits {
limits: HashMap<ResourceType, usize>,
},
ProcessQueue {
resource: ResourceType,
},
}
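/// The actor that owns all permit accounting. Every mutation happens inside
/// `run`, so no locks are needed around `ResourceState`.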
pub struct ResourceManager {
acquire_rxs: HashMap<ResourceType, mpsc::Receiver<AcquireCommand>>,
control_rx: mpsc::Receiver<ControlCommand>,
control_tx: mpsc::Sender<ControlCommand>,
resources: HashMap<ResourceType, ResourceState>,
shutdown_tx: broadcast::Sender<()>,
}
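/// Book-keeping for one resource: the current limit, how many permits are
/// out, and a bounded FIFO of waiting acquirers.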
struct ResourceState {
limit: usize,
in_use: usize,
max_queue_size: usize,
wait_queue: VecDeque<oneshot::Sender<Result<PermitGuard, ResourceManagerError>>>,
}
impl ResourceManager {
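    /// Builds the manager actor plus a cloneable client handle.
    ///
    /// `limits` maps each resource to a `(limit, max_queue_size)` pair:
    /// `limit` caps concurrently held permits and `max_queue_size` caps how
    /// many acquirers may wait before `QueueFull` is returned. A minimal
    /// wiring sketch (the numbers are illustrative, not recommendations):
    ///
    /// ```ignore
    /// let mut limits = HashMap::new();
    /// limits.insert(ResourceType::PeerConnection, (50, 128));
    /// limits.insert(ResourceType::DiskRead, (8, 64));
    /// limits.insert(ResourceType::DiskWrite, (8, 64));
    /// let (shutdown_tx, _) = broadcast::channel(1);
    /// let (actor, client) = ResourceManager::new(limits, shutdown_tx);
    /// tokio::spawn(actor.run());
    /// let permit = client.acquire_disk_read().await?;
    /// // ... bounded disk work ...
    /// drop(permit); // hands the slot back to the manager
    /// ```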
pub fn new(
limits: HashMap<ResourceType, (usize, usize)>,
shutdown_tx: broadcast::Sender<()>,
) -> (Self, ResourceManagerClient) {
let (control_tx, control_rx) = mpsc::channel(256);
let mut acquire_txs = HashMap::new();
let mut acquire_rxs = HashMap::new();
let mut resources = HashMap::new();
        for (res_type, (limit, max_queue_size)) in limits {
            resources.insert(
                res_type,
                ResourceState {
                    limit,
                    in_use: 0,
                    max_queue_size,
                    wait_queue: VecDeque::new(),
                },
            );
            // Reserve is tracked in `resources` but gets no acquire channel,
            // so clients can never request a Reserve permit.
            if res_type != ResourceType::Reserve {
                let (tx, rx) = mpsc::channel(256);
                acquire_txs.insert(res_type, tx);
                acquire_rxs.insert(res_type, rx);
            }
        }
let client = ResourceManagerClient {
acquire_txs,
control_tx: control_tx.clone(),
};
let actor = Self {
acquire_rxs,
control_rx,
control_tx,
resources,
shutdown_tx,
};
(actor, client)
}
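    /// Drives the actor: multiplexes acquire requests, control commands, and
    /// shutdown on one `select!` loop. Consumes `self`; spawn it on a task
    /// and interact through the `ResourceManagerClient`.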
pub async fn run(mut self) {
        let mut peer_rx = self
            .acquire_rxs
            .remove(&ResourceType::PeerConnection)
            .expect("PeerConnection must be present in the limits map");
        let mut read_rx = self
            .acquire_rxs
            .remove(&ResourceType::DiskRead)
            .expect("DiskRead must be present in the limits map");
        let mut write_rx = self
            .acquire_rxs
            .remove(&ResourceType::DiskWrite)
            .expect("DiskWrite must be present in the limits map");
let mut shutdown_rx = self.shutdown_tx.subscribe();
loop {
tokio::select! {
_ = shutdown_rx.recv() => break,
Some(cmd) = peer_rx.recv() => self.handle_acquire(ResourceType::PeerConnection, cmd.respond_to),
Some(cmd) = read_rx.recv() => self.handle_acquire(ResourceType::DiskRead, cmd.respond_to),
Some(cmd) = write_rx.recv() => self.handle_acquire(ResourceType::DiskWrite, cmd.respond_to),
Some(cmd) = self.control_rx.recv() => {
match cmd {
ControlCommand::Release { resource } => self.handle_release(resource),
ControlCommand::UpdateLimits { limits } => self.handle_update_limits(limits),
ControlCommand::ProcessQueue { resource } => self.handle_process_queue(resource),
}
},
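                // Defensive backstop: with the shutdown branch always armed
                // and the actor holding its own control_tx, this arm should
                // be unreachable in practice.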
else => { break; }
}
}
}
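    /// Grants a permit immediately when under the limit, parks the requester
    /// when the wait queue has room, and rejects with `QueueFull` otherwise.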
fn handle_acquire(
&mut self,
resource: ResourceType,
respond_to: oneshot::Sender<Result<PermitGuard, ResourceManagerError>>,
) {
        let state = self
            .resources
            .get_mut(&resource)
            .expect("resource state must be registered at construction");
        if state.in_use < state.limit {
            // Capacity is available: grant a permit immediately.
            state.in_use += 1;
            let guard = PermitGuard {
                resource_type: resource,
                control_tx: self.control_tx.clone(),
            };
            let _ = respond_to.send(Ok(guard));
        } else if state.wait_queue.len() < state.max_queue_size {
            // At the limit, but the wait queue has room: park the requester.
            state.wait_queue.push_back(respond_to);
        } else {
            let _ = respond_to.send(Err(ResourceManagerError::QueueFull));
        }
}
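    /// Returns one permit and schedules a queue pass so a waiter can take the
    /// freed slot.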
fn handle_release(&mut self, resource: ResourceType) {
        let state = self
            .resources
            .get_mut(&resource)
            .expect("resource state must be registered at construction");
        state.in_use = state.in_use.saturating_sub(1);
        // Best-effort nudge: if the control channel is full, the freed slot
        // is handed out on a later release or limit update instead.
        let _ = self
            .control_tx
            .try_send(ControlCommand::ProcessQueue { resource });
}
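    /// Applies new limits and schedules a queue pass for each recognized
    /// resource; unknown resource types are ignored.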
fn handle_update_limits(&mut self, limits: HashMap<ResourceType, usize>) {
for (resource, new_limit) in limits {
if let Some(state) = self.resources.get_mut(&resource) {
                state.limit = new_limit;
                // A raised limit may admit waiters immediately; a lowered one
                // only takes effect as outstanding guards are dropped.
                let _ = self
                    .control_tx
                    .try_send(ControlCommand::ProcessQueue { resource });
}
}
}
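    /// Hands queued waiters up to `PERMIT_GRANT_BATCH_SIZE` permits, then
    /// yields back to the event loop, re-enqueueing itself if work remains.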
    fn handle_process_queue(&mut self, resource: ResourceType) {
        let state = self
            .resources
            .get_mut(&resource)
            .expect("resource state must be registered at construction");
        for _ in 0..PERMIT_GRANT_BATCH_SIZE {
            if state.in_use >= state.limit {
                return;
            }
            let Some(next_in_line) = state.wait_queue.pop_front() else {
                return;
            };
            // Skip waiters whose receiver is already gone (e.g. the
            // requesting task was cancelled) so they do not consume a permit.
            if next_in_line.is_closed() {
                continue;
            }
            state.in_use += 1;
            let guard = PermitGuard {
                resource_type: resource,
                control_tx: self.control_tx.clone(),
            };
            if let Err(rejected) = next_in_line.send(Ok(guard)) {
                // The waiter vanished between the is_closed check and the
                // send. Forget the returned guard so its Drop cannot queue a
                // second Release, then roll the grant back by hand; letting
                // both run would decrement in_use twice for a single grant.
                std::mem::forget(rejected);
                state.in_use -= 1;
            }
        }
        // Work remains after a full batch: re-enqueue so other resources get
        // a turn on the event loop before this queue is drained further.
        if state.in_use < state.limit && !state.wait_queue.is_empty() {
            let _ = self
                .control_tx
                .try_send(ControlCommand::ProcessQueue { resource });
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::time::{advance, sleep, timeout};
fn create_limits(
peer: (usize, usize),
read: (usize, usize),
write: (usize, usize),
) -> HashMap<ResourceType, (usize, usize)> {
let mut limits = HashMap::new();
limits.insert(ResourceType::PeerConnection, peer);
limits.insert(ResourceType::DiskRead, read);
limits.insert(ResourceType::DiskWrite, write);
limits
}
fn setup_manager(
limits: HashMap<ResourceType, (usize, usize)>,
) -> (ResourceManagerClient, tokio::task::JoinHandle<()>) {
let (shutdown_tx, _) = broadcast::channel(1);
let (actor, client) = ResourceManager::new(limits, shutdown_tx);
let handle = tokio::spawn(actor.run());
(client, handle)
}
fn create_trial_limits(
resource: ResourceType,
limit: usize,
queue: usize,
) -> HashMap<ResourceType, (usize, usize)> {
        let mut limits = create_limits((1, 0), (1, 0), (1, 0));
        // Reserve has no acquire channel, so only the three acquirable
        // resources can be dialed up for a trial.
        if resource != ResourceType::Reserve {
            limits.insert(resource, (limit, queue));
        }
        limits
}
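    /// Runs `worker_count` tasks against one resource under Tokio's paused
    /// clock, advancing virtual time in `work_time` steps, and returns how
    /// many acquire-work-release cycles completed. Callers must use
    /// `#[tokio::test(start_paused = true)]`.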
async fn measure_throughput_for_resource(resource: ResourceType, limit: usize) -> usize {
let queue_size = 20_000;
let worker_count = 64;
let work_time = Duration::from_millis(10);
let run_steps = 120;
let limits = create_trial_limits(resource, limit, queue_size);
let (client, manager_handle) = setup_manager(limits);
let completed = Arc::new(AtomicUsize::new(0));
let stop = Arc::new(AtomicBool::new(false));
let mut workers = Vec::new();
for _ in 0..worker_count {
let worker_client = client.clone();
let worker_completed = completed.clone();
let worker_stop = stop.clone();
workers.push(tokio::spawn(async move {
loop {
if worker_stop.load(Ordering::Relaxed) {
break;
}
let permit_result = match resource {
ResourceType::PeerConnection => {
worker_client.acquire_peer_connection().await
}
ResourceType::DiskRead => worker_client.acquire_disk_read().await,
ResourceType::DiskWrite => worker_client.acquire_disk_write().await,
ResourceType::Reserve => unreachable!("Reserve is not acquirable"),
};
let permit = match permit_result {
Ok(permit) => permit,
Err(ResourceManagerError::QueueFull) => {
tokio::task::yield_now().await;
continue;
}
Err(ResourceManagerError::ManagerShutdown) => break,
};
sleep(work_time).await;
drop(permit);
worker_completed.fetch_add(1, Ordering::Relaxed);
}
}));
}
for _ in 0..run_steps {
tokio::task::yield_now().await;
advance(work_time).await;
}
stop.store(true, Ordering::Relaxed);
tokio::task::yield_now().await;
advance(work_time).await;
tokio::task::yield_now().await;
for worker in workers {
worker.abort();
let _ = worker.await;
}
manager_handle.abort();
let _ = manager_handle.await;
completed.load(Ordering::Relaxed)
}
#[tokio::test]
async fn test_acquire_release_success() {
let limits = create_limits((1, 1), (0, 0), (0, 0));
let (client, _handle) = setup_manager(limits);
let guard1 = client.acquire_peer_connection().await;
assert!(guard1.is_ok());
drop(guard1);
let guard2 = client.acquire_peer_connection().await;
assert!(guard2.is_ok());
}
#[tokio::test]
async fn test_acquire_blocks_and_wakes() {
let limits = create_limits((1, 1), (0, 0), (0, 0));
let (client, _handle) = setup_manager(limits);
let guard1 = client.acquire_peer_connection().await.unwrap();
let client_clone = client.clone();
let acquire_task =
tokio::spawn(async move { client_clone.acquire_peer_connection().await });
sleep(Duration::from_millis(50)).await;
assert!(
!acquire_task.is_finished(),
"Acquire did not block when it should have"
);
drop(guard1);
let result = timeout(Duration::from_millis(100), acquire_task).await;
assert!(result.is_ok(), "Task timed out, did not unblock");
        let inner_result = result.unwrap();
        assert!(inner_result.is_ok(), "Task join failed");
        assert!(inner_result.unwrap().is_ok(), "Acquire task failed");
    }
#[tokio::test]
async fn test_queue_full_rejection() {
let limits = create_limits((1, 1), (0, 0), (0, 0));
let (client, _handle) = setup_manager(limits);
let guard1 = client.acquire_peer_connection().await.unwrap();
let client_clone = client.clone();
let acquire_task2 =
tokio::spawn(async move { client_clone.acquire_peer_connection().await });
sleep(Duration::from_millis(50)).await;
assert!(!acquire_task2.is_finished());
let result = client.acquire_peer_connection().await;
match result {
Err(ResourceManagerError::QueueFull) => { }
_ => panic!("Expected QueueFull, got {:?}", result),
}
drop(guard1);
let _ = acquire_task2.await;
}
#[tokio::test]
async fn test_update_limit_increase_wakes_waiters() {
let limits = create_limits((1, 1), (0, 0), (0, 0));
let (client, _handle) = setup_manager(limits);
let _guard1 = client.acquire_peer_connection().await.unwrap();
let client_clone = client.clone();
let acquire_task =
tokio::spawn(async move { client_clone.acquire_peer_connection().await });
sleep(Duration::from_millis(50)).await;
assert!(!acquire_task.is_finished());
let mut new_limits = HashMap::new();
new_limits.insert(ResourceType::PeerConnection, 2);
client.update_limits(new_limits).await.unwrap();
let result = timeout(Duration::from_millis(100), acquire_task).await;
assert!(
result.is_ok(),
"Task timed out, did not unblock after limit update"
);
let inner_result = result.unwrap();
assert!(inner_result.is_ok(), "Task join failed");
assert!(inner_result.unwrap().is_ok(), "Acquire task failed");
}
#[tokio::test]
async fn test_update_limit_decrease() {
let limits = create_limits((2, 1), (0, 0), (0, 0));
let (client, _handle) = setup_manager(limits);
let guard1 = client.acquire_peer_connection().await.unwrap();
let guard2 = client.acquire_peer_connection().await.unwrap();
let mut new_limits = HashMap::new();
new_limits.insert(ResourceType::PeerConnection, 1);
client.update_limits(new_limits).await.unwrap();
let client_clone = client.clone();
let acquire_task =
tokio::spawn(async move { client_clone.acquire_peer_connection().await });
sleep(Duration::from_millis(50)).await;
assert!(!acquire_task.is_finished());
drop(guard1);
        sleep(Duration::from_millis(50)).await;
        assert!(!acquire_task.is_finished(), "Task unblocked too early");
drop(guard2);
let result = timeout(Duration::from_millis(100), acquire_task).await;
assert!(
result.is_ok(),
"Task did not unblock after second guard dropped"
);
let inner_result = result.unwrap();
assert!(inner_result.is_ok(), "Task join failed");
assert!(inner_result.unwrap().is_ok(), "Acquire task failed");
}
#[tokio::test]
async fn test_resources_are_independent() {
let limits = create_limits((1, 1), (1, 1), (0, 0));
let (client, _handle) = setup_manager(limits);
let _peer_guard = client.acquire_peer_connection().await.unwrap();
let client_clone = client.clone();
let peer_task = tokio::spawn(async move { client_clone.acquire_peer_connection().await });
sleep(Duration::from_millis(50)).await;
assert!(
!peer_task.is_finished(),
"Peer connection acquire did not block"
);
let read_result = client.acquire_disk_read().await;
assert!(
read_result.is_ok(),
"DiskRead acquire failed, was blocked by PeerConnection"
);
let write_result = client.acquire_disk_write().await;
match write_result {
Err(ResourceManagerError::QueueFull) => { }
_ => panic!("Expected QueueFull for 0-limit resource"),
}
drop(_peer_guard);
let _ = peer_task.await;
}
#[tokio::test]
async fn test_manager_shutdown() {
let limits = create_limits((1, 1), (0, 0), (0, 0));
let (client, handle) = setup_manager(limits);
handle.abort();
sleep(Duration::from_millis(20)).await;
let result = client.acquire_peer_connection().await;
match result {
Err(ResourceManagerError::ManagerShutdown) => { }
_ => panic!("Expected ManagerShutdown, got {:?}", result),
}
let result_update = client.update_limits(HashMap::new()).await;
match result_update {
Err(ResourceManagerError::ManagerShutdown) => { }
_ => panic!("Expected ManagerShutdown, got {:?}", result_update),
}
}
#[tokio::test]
async fn test_multiple_waiters_are_woken() {
let limit = 5;
let queue = 5;
let limits = create_limits((limit, queue), (0, 0), (0, 0));
let (client, _handle) = setup_manager(limits);
let mut guards = Vec::new();
for _ in 0..limit {
guards.push(client.acquire_peer_connection().await.unwrap());
}
let mut tasks = Vec::new();
for _ in 0..queue {
let client_clone = client.clone();
tasks.push(tokio::spawn(async move {
client_clone.acquire_peer_connection().await
}));
}
sleep(Duration::from_millis(50)).await;
for (i, task) in tasks.iter().enumerate() {
assert!(!task.is_finished(), "Task {} finished early", i);
}
drop(guards);
for (i, task) in tasks.into_iter().enumerate() {
let res = timeout(Duration::from_millis(100), task).await;
assert!(res.is_ok(), "Task {} timed out waiting to join", i);
let join_res = res.unwrap();
assert!(join_res.is_ok(), "Task {} join error", i);
assert!(join_res.unwrap().is_ok(), "Task {} acquire failed", i);
}
}
#[tokio::test]
async fn test_dropped_waiter_does_not_leak_permit() {
let limits = create_limits((1, 2), (0, 0), (0, 0));
let (client, _handle) = setup_manager(limits);
let guard1 = client.acquire_peer_connection().await.unwrap();
let client_clone = client.clone();
let waiting_task = tokio::spawn(async move {
client_clone.acquire_peer_connection().await
});
sleep(Duration::from_millis(20)).await;
waiting_task.abort();
sleep(Duration::from_millis(20)).await;
drop(guard1);
sleep(Duration::from_millis(20)).await;
let result = timeout(Duration::from_millis(100), client.acquire_peer_connection()).await;
assert!(
result.is_ok(),
"Permit leaked! The aborted waiter consumed a slot."
);
assert!(result.unwrap().is_ok());
}
#[tokio::test(start_paused = true)]
async fn test_disk_permit_throughput_roughly_halves_when_limit_halves() {
let baseline_limit = 16;
let half_limit = baseline_limit / 2;
let read_baseline =
measure_throughput_for_resource(ResourceType::DiskRead, baseline_limit).await;
let read_half = measure_throughput_for_resource(ResourceType::DiskRead, half_limit).await;
assert!(
read_baseline > 0,
"Read baseline throughput should be non-zero"
);
let read_ratio = read_half as f64 / read_baseline as f64;
assert!(
(0.35..=0.75).contains(&read_ratio),
"DiskRead throughput did not scale as expected: baseline={}, half={}, ratio={:.3}",
read_baseline,
read_half,
read_ratio
);
let write_baseline =
measure_throughput_for_resource(ResourceType::DiskWrite, baseline_limit).await;
let write_half = measure_throughput_for_resource(ResourceType::DiskWrite, half_limit).await;
assert!(
write_baseline > 0,
"Write baseline throughput should be non-zero"
);
let write_ratio = write_half as f64 / write_baseline as f64;
assert!(
(0.35..=0.75).contains(&write_ratio),
"DiskWrite throughput did not scale as expected: baseline={}, half={}, ratio={:.3}",
write_baseline,
write_half,
write_ratio
);
}
#[tokio::test(start_paused = true)]
async fn test_peer_limit_throughput_roughly_halves_when_limit_halves() {
let baseline_limit = 16;
let half_limit = baseline_limit / 2;
let baseline =
measure_throughput_for_resource(ResourceType::PeerConnection, baseline_limit).await;
let half = measure_throughput_for_resource(ResourceType::PeerConnection, half_limit).await;
assert!(baseline > 0, "Peer baseline throughput should be non-zero");
let ratio = half as f64 / baseline as f64;
assert!(
(0.35..=0.75).contains(&ratio),
"Peer throughput did not scale as expected: baseline={}, half={}, ratio={:.3}",
baseline,
half,
ratio
);
}
}