use std::future::Future;
use tokio::sync::Semaphore;
use tokio_util::sync::CancellationToken;
/// Upper bound on a guest payload length: 10 MiB, i.e. `10 << 20` = 10 485 760.
///
/// NOTE(review): the unit is presumably bytes (the name says `LEN`); confirm against callers.
pub(crate) const MAX_GUEST_PAYLOAD_LEN: u64 = 10 << 20;
pub(crate) fn bounded_block_on<F, T>(
handle: &tokio::runtime::Handle,
semaphore: &Semaphore,
fut: F,
) -> T
where
F: Future<Output = T>,
{
tokio::task::block_in_place(|| {
handle.block_on(async {
let _permit = semaphore
.acquire()
.await
.expect("host semaphore closed: capsule HostState was dropped");
fut.await
})
})
}
pub(crate) fn bounded_block_on_cancellable<F, T>(
handle: &tokio::runtime::Handle,
semaphore: &Semaphore,
cancel_token: &CancellationToken,
fut: F,
) -> Option<T>
where
F: Future<Output = T>,
{
if cancel_token.is_cancelled() {
return None;
}
tokio::task::block_in_place(|| {
handle.block_on(async {
tokio::select! {
biased;
() = cancel_token.cancelled() => None,
result = async {
let _permit = semaphore
.acquire()
.await
.expect("host semaphore closed: capsule HostState was dropped");
fut.await
} => Some(result),
}
})
})
}
/// Awaits `fut` while holding one permit from `semaphore`.
///
/// The permit is acquired before `fut` is first polled and held until `fut` has produced its
/// output. Callers sharing a semaphore therefore run at most as many of these futures
/// concurrently as the semaphore has permits.
///
/// # Panics
///
/// Panics if `semaphore` has been closed. The panic message attributes this to the owning
/// `HostState` having been dropped.
pub(crate) async fn bounded_await<F, T>(semaphore: &Semaphore, fut: F) -> T
where
    F: Future<Output = T>,
{
    let acquired = semaphore.acquire().await;
    // Bind to a named (underscore-prefixed) local rather than `_`, so the permit lives until
    // the end of the function, i.e. until `fut` has completed.
    let _guard = acquired.expect("host semaphore closed: capsule HostState was dropped");
    fut.await
}
/// Awaits `fut` while holding one permit from `semaphore`, abandoning it if `cancel_token`
/// fires.
///
/// Returns `Some(output)` when `fut` completes. Returns `None` if the token is cancelled
/// before the call, while queued for a permit, or while `fut` runs. In the last two cases the
/// pending work, including any held permit, is dropped. The cancellation branch is polled
/// first, so it wins when both are ready together.
///
/// # Panics
///
/// Panics if `semaphore` has been closed. The panic message attributes this to the owning
/// `HostState` having been dropped.
pub(crate) async fn bounded_await_cancellable<F, T>(
    semaphore: &Semaphore,
    cancel_token: &CancellationToken,
    fut: F,
) -> Option<T>
where
    F: Future<Output = T>,
{
    if cancel_token.is_cancelled() {
        return None;
    }
    // Permit acquisition and the caller's future are raced together as one unit, so
    // cancellation also interrupts a caller still queued on the semaphore.
    let guarded = async {
        let _guard = semaphore
            .acquire()
            .await
            .expect("host semaphore closed: capsule HostState was dropped");
        fut.await
    };
    tokio::select! {
        biased;
        () = cancel_token.cancelled() => None,
        value = guarded => Some(value),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    /// Six blocking callers share a 2-permit semaphore; the observed peak must never exceed 2.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn bounded_block_on_limits_concurrency() {
        let semaphore = Arc::new(Semaphore::new(2));
        let handle = tokio::runtime::Handle::current();
        let concurrent = Arc::new(AtomicUsize::new(0));
        let max_concurrent = Arc::new(AtomicUsize::new(0));
        let mut tasks = Vec::new();
        for _ in 0..6 {
            let sem = semaphore.clone();
            let h = handle.clone();
            let c = concurrent.clone();
            let mc = max_concurrent.clone();
            tasks.push(tokio::task::spawn(async move {
                bounded_block_on(&h, &sem, async {
                    let current = c.fetch_add(1, Ordering::SeqCst) + 1;
                    mc.fetch_max(current, Ordering::SeqCst);
                    tokio::time::sleep(std::time::Duration::from_millis(50)).await;
                    c.fetch_sub(1, Ordering::SeqCst);
                });
            }));
        }
        for t in tasks {
            t.await.unwrap();
        }
        let max = max_concurrent.load(Ordering::SeqCst);
        assert!(max <= 2, "max concurrent was {max} but should be <= 2");
        assert!(
            max >= 1,
            "expected at least 1 concurrent execution, got {max}"
        );
    }
    /// Both `Ok` and `Err` outputs of the inner future are returned unchanged.
    #[tokio::test(flavor = "multi_thread")]
    async fn bounded_block_on_propagates_result() {
        let semaphore = Semaphore::new(4);
        let handle = tokio::runtime::Handle::current();
        let result: Result<u32, &str> = bounded_block_on(&handle, &semaphore, async { Ok(42) });
        assert_eq!(result.unwrap(), 42);
        let err: Result<u32, &str> = bounded_block_on(&handle, &semaphore, async { Err("fail") });
        assert_eq!(err.unwrap_err(), "fail");
    }
    /// Cancelling the token interrupts a long-running future and releases its permit.
    #[tokio::test(flavor = "multi_thread")]
    async fn cancellation_unblocks_bounded_block_on_cancellable() {
        let semaphore = Arc::new(Semaphore::new(4));
        let handle = tokio::runtime::Handle::current();
        let cancel_token = CancellationToken::new();
        let sem = semaphore.clone();
        let h = handle.clone();
        let ct = cancel_token.clone();
        let cancel = cancel_token.clone();
        tokio::spawn(async move {
            tokio::time::sleep(std::time::Duration::from_millis(50)).await;
            cancel.cancel();
        });
        let result = tokio::task::spawn(async move {
            bounded_block_on_cancellable(&h, &sem, &ct, async {
                tokio::time::sleep(std::time::Duration::from_secs(60)).await;
                42u32
            })
        })
        .await
        .unwrap();
        assert!(result.is_none(), "expected None on cancellation");
        // The abandoned future must not leak the permit it was holding.
        assert_eq!(semaphore.available_permits(), 4);
    }
    /// An already-cancelled token short-circuits without ever polling the future.
    #[tokio::test(flavor = "multi_thread")]
    async fn bounded_block_on_cancellable_pre_cancelled() {
        let semaphore = Semaphore::new(4);
        let handle = tokio::runtime::Handle::current();
        let cancel_token = CancellationToken::new();
        cancel_token.cancel();
        let result: Option<u32> =
            bounded_block_on_cancellable(&handle, &semaphore, &cancel_token, async {
                panic!("future should never execute when token is pre-cancelled");
            });
        assert!(result.is_none(), "expected None for pre-cancelled token");
    }
    /// Without cancellation the output is wrapped in `Some`, for both `Ok` and `Err`.
    #[tokio::test(flavor = "multi_thread")]
    async fn bounded_block_on_cancellable_normal_completion() {
        let semaphore = Semaphore::new(4);
        let handle = tokio::runtime::Handle::current();
        let cancel_token = CancellationToken::new();
        let result: Option<Result<u32, &str>> =
            bounded_block_on_cancellable(&handle, &semaphore, &cancel_token, async { Ok(42) });
        assert_eq!(result.unwrap().unwrap(), 42);
        let err: Option<Result<u32, &str>> =
            bounded_block_on_cancellable(&handle, &semaphore, &cancel_token, async { Err("fail") });
        assert_eq!(err.unwrap().unwrap_err(), "fail");
    }
    /// The cancellable variant honours the same 2-permit bound as `bounded_block_on`.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn bounded_block_on_cancellable_limits_concurrency() {
        let semaphore = Arc::new(Semaphore::new(2));
        let handle = tokio::runtime::Handle::current();
        let cancel_token = CancellationToken::new();
        let concurrent = Arc::new(AtomicUsize::new(0));
        let max_concurrent = Arc::new(AtomicUsize::new(0));
        let mut tasks = Vec::new();
        for _ in 0..6 {
            let sem = semaphore.clone();
            let h = handle.clone();
            let ct = cancel_token.clone();
            let c = concurrent.clone();
            let mc = max_concurrent.clone();
            tasks.push(tokio::task::spawn(async move {
                bounded_block_on_cancellable(&h, &sem, &ct, async {
                    let current = c.fetch_add(1, Ordering::SeqCst) + 1;
                    mc.fetch_max(current, Ordering::SeqCst);
                    tokio::time::sleep(std::time::Duration::from_millis(50)).await;
                    c.fetch_sub(1, Ordering::SeqCst);
                });
            }));
        }
        for t in tasks {
            t.await.unwrap();
        }
        let max = max_concurrent.load(Ordering::SeqCst);
        assert!(max <= 2, "max concurrent was {max} but should be <= 2");
        assert!(
            max >= 1,
            "expected at least 1 concurrent execution, got {max}"
        );
    }
    /// A caller still queued for a permit is released by cancellation.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn bounded_block_on_cancellable_cancel_while_queued_for_permit() {
        let semaphore = Arc::new(Semaphore::new(1));
        let handle = tokio::runtime::Handle::current();
        let cancel_token = CancellationToken::new();
        // Hold the only permit so the spawned caller has to queue.
        let _permit = semaphore.acquire().await.unwrap();
        let ct = cancel_token.clone();
        let sem = semaphore.clone();
        let h = handle.clone();
        let task =
            tokio::task::spawn(
                async move { bounded_block_on_cancellable(&h, &sem, &ct, async { 42 }) },
            );
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        cancel_token.cancel();
        let result = task.await.unwrap();
        assert!(
            result.is_none(),
            "expected None (cancelled), got {result:?}"
        );
    }
    /// `bounded_await` returns the inner output and gives its permit back afterwards.
    #[tokio::test]
    async fn bounded_await_propagates_result_and_releases_permit() {
        let semaphore = Semaphore::new(1);
        let value = bounded_await(&semaphore, async { 42u32 }).await;
        assert_eq!(value, 42);
        assert_eq!(semaphore.available_permits(), 1);
    }
    /// The async cancellable variant never polls the future for a pre-cancelled token.
    #[tokio::test]
    async fn bounded_await_cancellable_pre_cancelled() {
        let semaphore = Semaphore::new(1);
        let cancel_token = CancellationToken::new();
        cancel_token.cancel();
        let result: Option<u32> = bounded_await_cancellable(&semaphore, &cancel_token, async {
            panic!("future should never execute when token is pre-cancelled");
        })
        .await;
        assert!(result.is_none(), "expected None for pre-cancelled token");
    }
    /// Without cancellation the async variant yields `Some(output)` and releases the permit.
    #[tokio::test]
    async fn bounded_await_cancellable_normal_completion() {
        let semaphore = Semaphore::new(1);
        let cancel_token = CancellationToken::new();
        let result = bounded_await_cancellable(&semaphore, &cancel_token, async { 42u32 }).await;
        assert_eq!(result, Some(42));
        assert_eq!(semaphore.available_permits(), 1);
    }
    /// The async variant also gives up while still queued for a permit.
    #[tokio::test]
    async fn bounded_await_cancellable_cancel_while_queued_for_permit() {
        let semaphore = Semaphore::new(1);
        // Hold the only permit so the call under test has to queue.
        let _held = semaphore.acquire().await.unwrap();
        let cancel_token = CancellationToken::new();
        let canceller = cancel_token.clone();
        tokio::spawn(async move {
            tokio::time::sleep(std::time::Duration::from_millis(20)).await;
            canceller.cancel();
        });
        let result: Option<u32> =
            bounded_await_cancellable(&semaphore, &cancel_token, async { 42 }).await;
        assert!(
            result.is_none(),
            "expected None (cancelled), got {result:?}"
        );
    }
    /// Cancelling mid-flight drops the running future and returns its permit.
    #[tokio::test]
    async fn bounded_await_cancellable_releases_permit_when_cancelled_mid_flight() {
        let semaphore = Semaphore::new(1);
        let cancel_token = CancellationToken::new();
        let canceller = cancel_token.clone();
        tokio::spawn(async move {
            tokio::time::sleep(std::time::Duration::from_millis(20)).await;
            canceller.cancel();
        });
        let result = bounded_await_cancellable(&semaphore, &cancel_token, async {
            tokio::time::sleep(std::time::Duration::from_secs(60)).await;
            42u32
        })
        .await;
        assert!(result.is_none(), "expected None on cancellation");
        assert_eq!(semaphore.available_permits(), 1);
    }
}