use std::future::Future;
use extism::{CurrentPlugin, Error, Val};
use tokio::sync::Semaphore;
use tokio_util::sync::CancellationToken;
/// Upper bound, in bytes, for a general guest payload read from plugin memory (10 MiB).
pub(crate) const MAX_GUEST_PAYLOAD_LEN: u64 = 10 << 20;
/// Upper bound, in bytes, for a guest-supplied path (4 KiB).
pub(crate) const MAX_PATH_LEN: u64 = 4 << 10;
/// Upper bound, in bytes, for a guest-supplied log message (64 KiB).
pub(crate) const MAX_LOG_MESSAGE_LEN: u64 = 64 << 10;
/// Upper bound, in bytes, for a guest-supplied key (4 KiB).
pub(crate) const MAX_KEY_LEN: u64 = 4 << 10;
#[expect(clippy::cast_sign_loss)]
pub(crate) fn get_safe_string(
plugin: &mut CurrentPlugin,
val: &Val,
limit: u64,
) -> Result<String, Error> {
let ptr = match val {
Val::I64(v) => *v as u64,
Val::I32(v) => u64::from(*v as u32),
_ => return Err(Error::msg("expected memory pointer value")),
};
let len = plugin.memory_length(ptr)?;
if len > limit {
return Err(Error::msg(format!(
"memory allocation of {len} bytes exceeds maximum allowed limit of {limit} bytes"
)));
}
let safe_val =
Val::I64(i64::try_from(ptr).map_err(|_| Error::msg("pointer value out of i64 range"))?);
plugin.memory_get_val(&safe_val)
}
#[expect(clippy::cast_sign_loss)]
pub(crate) fn get_safe_bytes(
plugin: &mut CurrentPlugin,
val: &Val,
limit: u64,
) -> Result<Vec<u8>, Error> {
let ptr = match val {
Val::I64(v) => *v as u64,
Val::I32(v) => u64::from(*v as u32),
_ => return Err(Error::msg("expected memory pointer value")),
};
let len = plugin.memory_length(ptr)?;
if len > limit {
return Err(Error::msg(format!(
"memory allocation of {len} bytes exceeds maximum allowed limit of {limit} bytes"
)));
}
let safe_val =
Val::I64(i64::try_from(ptr).map_err(|_| Error::msg("pointer value out of i64 range"))?);
let memory: Vec<u8> = plugin.memory_get_val(&safe_val)?;
Ok(memory)
}
pub(crate) const HOST_RESULT_OK: u8 = 0x00;
pub(crate) const HOST_RESULT_ERR: u8 = 0x01;
pub(crate) fn write_host_result(
plugin: &mut CurrentPlugin,
outputs: &mut [Val],
result: Result<Vec<u8>, String>,
) -> Result<(), Error> {
let (status, data) = match result {
Ok(payload) => (HOST_RESULT_OK, payload),
Err(msg) => (HOST_RESULT_ERR, msg.into_bytes()),
};
let mut buf = Vec::with_capacity(1 + data.len());
buf.push(status);
buf.extend_from_slice(&data);
let mem = plugin.memory_new(&buf)?;
outputs[0] = plugin.memory_to_val(mem);
Ok(())
}
/// Synchronously drives `fut` to completion on `handle` while holding one
/// permit from `semaphore`.
///
/// The current worker thread is handed over to blocking work through
/// `tokio::task::block_in_place`. The future is then run with
/// `Handle::block_on`, and its output is returned.
///
/// The permit is acquired before `fut` is first polled and released only after
/// it completes. This caps how many such futures run at once at the number of
/// permits in `semaphore`.
///
/// # Panics
///
/// Panics if `semaphore` has been closed. It also inherits the panics of
/// `block_in_place`, which per tokio's docs include being called from a
/// current-thread runtime.
pub(crate) fn bounded_block_on<F, T>(
    handle: &tokio::runtime::Handle,
    semaphore: &Semaphore,
    fut: F,
) -> T
where
    F: Future<Output = T>,
{
    // Building the async block does no work; it only runs once polled below.
    let guarded = async move {
        let permit = semaphore
            .acquire()
            .await
            .expect("host semaphore closed: capsule HostState was dropped");
        let output = fut.await;
        drop(permit);
        output
    };
    tokio::task::block_in_place(move || handle.block_on(guarded))
}
/// Like `bounded_block_on`, but abandons the work once `cancel_token` is
/// cancelled.
///
/// Return values:
/// - `None` if the token is already cancelled on entry. In that case `fut` is
///   never polled and no thread is handed to `block_in_place`.
/// - `None` if the token is cancelled while waiting for a permit, or while `fut`
///   runs. `fut` is then dropped unfinished.
/// - `Some(output)` if `fut` completes first.
///
/// The `select!` is `biased`, so cancellation is checked before each poll of
/// the guarded future.
///
/// # Panics
///
/// Panics if `semaphore` has been closed. It also inherits the panics of
/// `tokio::task::block_in_place`.
pub(crate) fn bounded_block_on_cancellable<F, T>(
    handle: &tokio::runtime::Handle,
    semaphore: &Semaphore,
    cancel_token: &CancellationToken,
    fut: F,
) -> Option<T>
where
    F: Future<Output = T>,
{
    if cancel_token.is_cancelled() {
        return None;
    }
    let guarded = async move {
        let permit = semaphore
            .acquire()
            .await
            .expect("host semaphore closed: capsule HostState was dropped");
        let output = fut.await;
        drop(permit);
        output
    };
    tokio::task::block_in_place(move || {
        handle.block_on(async move {
            tokio::select! {
                biased;
                () = cancel_token.cancelled() => None,
                output = guarded => Some(output),
            }
        })
    })
}
#[cfg(test)]
mod tests {
// Tests for the semaphore-bounded blocking helpers. Every test runs on the
// multi-threaded runtime, because `tokio::task::block_in_place`, which both
// helpers use, cannot be called from a current-thread runtime.
use super::*;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
// Six tasks compete for a 2-permit semaphore. `concurrent` counts futures that
// are currently inside the guarded section, and `max_concurrent` records its
// high-water mark, which must never exceed the permit count. The 50 ms sleep
// keeps each permit held long enough for the tasks to overlap.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn bounded_block_on_limits_concurrency() {
let semaphore = Arc::new(Semaphore::new(2));
let handle = tokio::runtime::Handle::current();
let concurrent = Arc::new(AtomicUsize::new(0));
let max_concurrent = Arc::new(AtomicUsize::new(0));
let mut tasks = Vec::new();
for _ in 0..6 {
let sem = semaphore.clone();
let h = handle.clone();
let c = concurrent.clone();
let mc = max_concurrent.clone();
tasks.push(tokio::task::spawn(async move {
bounded_block_on(&h, &sem, async {
let current = c.fetch_add(1, Ordering::SeqCst) + 1;
mc.fetch_max(current, Ordering::SeqCst);
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
c.fetch_sub(1, Ordering::SeqCst);
});
}));
}
for t in tasks {
t.await.unwrap();
}
let max = max_concurrent.load(Ordering::SeqCst);
assert!(max <= 2, "max concurrent was {max} but should be <= 2");
// Only a lower bound of 1 is asserted: reaching exactly 2 depends on
// scheduling and would make the test flaky.
assert!(
max >= 1,
"expected at least 1 concurrent execution, got {max}"
);
}
// The future's output, here both an `Ok` and an `Err`, is returned unchanged.
#[tokio::test(flavor = "multi_thread")]
async fn bounded_block_on_propagates_result() {
let semaphore = Semaphore::new(4);
let handle = tokio::runtime::Handle::current();
let result: Result<u32, &str> = bounded_block_on(&handle, &semaphore, async { Ok(42) });
assert_eq!(result.unwrap(), 42);
let err: Result<u32, &str> = bounded_block_on(&handle, &semaphore, async { Err("fail") });
assert_eq!(err.unwrap_err(), "fail");
}
// A background task cancels the token after 50 ms, while the guarded future is
// still inside its 60 s sleep. The call must return `None` promptly instead of
// waiting for the sleep to finish.
#[tokio::test(flavor = "multi_thread")]
async fn cancellation_unblocks_bounded_block_on_cancellable() {
let semaphore = Arc::new(Semaphore::new(4));
let handle = tokio::runtime::Handle::current();
let cancel_token = CancellationToken::new();
let sem = semaphore.clone();
let h = handle.clone();
let ct = cancel_token.clone();
let cancel = cancel_token.clone();
tokio::spawn(async move {
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
cancel.cancel();
});
let result = tokio::task::spawn(async move {
bounded_block_on_cancellable(&h, &sem, &ct, async {
tokio::time::sleep(std::time::Duration::from_secs(60)).await;
42u32
})
})
.await
.unwrap();
assert!(result.is_none(), "expected None on cancellation");
}
// With the token cancelled before the call, the early-return path must apply.
// The future panics if it is ever polled, which proves it is not.
#[tokio::test(flavor = "multi_thread")]
async fn bounded_block_on_cancellable_pre_cancelled() {
let semaphore = Semaphore::new(4);
let handle = tokio::runtime::Handle::current();
let cancel_token = CancellationToken::new();
cancel_token.cancel();
let result: Option<u32> =
bounded_block_on_cancellable(&handle, &semaphore, &cancel_token, async {
panic!("future should never execute when token is pre-cancelled");
});
assert!(result.is_none(), "expected None for pre-cancelled token");
}
// If the token is never cancelled, the output is wrapped in `Some`, and both
// `Ok` and `Err` come through unchanged.
#[tokio::test(flavor = "multi_thread")]
async fn bounded_block_on_cancellable_normal_completion() {
let semaphore = Semaphore::new(4);
let handle = tokio::runtime::Handle::current();
let cancel_token = CancellationToken::new();
let result: Option<Result<u32, &str>> =
bounded_block_on_cancellable(&handle, &semaphore, &cancel_token, async { Ok(42) });
assert_eq!(result.unwrap().unwrap(), 42);
let err: Option<Result<u32, &str>> =
bounded_block_on_cancellable(&handle, &semaphore, &cancel_token, async { Err("fail") });
assert_eq!(err.unwrap().unwrap_err(), "fail");
}
// The concurrency-cap check from `bounded_block_on_limits_concurrency`,
// repeated for the cancellable variant with a token that is never cancelled.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn bounded_block_on_cancellable_limits_concurrency() {
let semaphore = Arc::new(Semaphore::new(2));
let handle = tokio::runtime::Handle::current();
let cancel_token = CancellationToken::new();
let concurrent = Arc::new(AtomicUsize::new(0));
let max_concurrent = Arc::new(AtomicUsize::new(0));
let mut tasks = Vec::new();
for _ in 0..6 {
let sem = semaphore.clone();
let h = handle.clone();
let ct = cancel_token.clone();
let c = concurrent.clone();
let mc = max_concurrent.clone();
tasks.push(tokio::task::spawn(async move {
bounded_block_on_cancellable(&h, &sem, &ct, async {
let current = c.fetch_add(1, Ordering::SeqCst) + 1;
mc.fetch_max(current, Ordering::SeqCst);
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
c.fetch_sub(1, Ordering::SeqCst);
});
}));
}
for t in tasks {
t.await.unwrap();
}
let max = max_concurrent.load(Ordering::SeqCst);
assert!(max <= 2, "max concurrent was {max} but should be <= 2");
assert!(
max >= 1,
"expected at least 1 concurrent execution, got {max}"
);
}
// The test takes the semaphore's only permit (`_permit`, held until the test
// ends) before spawning, so the spawned call is stuck in `acquire`. Cancelling
// after 50 ms must release that wait and yield `None`.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn bounded_block_on_cancellable_cancel_while_queued_for_permit() {
let semaphore = Arc::new(Semaphore::new(1));
let handle = tokio::runtime::Handle::current();
let cancel_token = CancellationToken::new();
let _permit = semaphore.acquire().await.unwrap();
let ct = cancel_token.clone();
let sem = semaphore.clone();
let h = handle.clone();
let task =
tokio::task::spawn(
async move { bounded_block_on_cancellable(&h, &sem, &ct, async { 42 }) },
);
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
cancel_token.cancel();
let result = task.await.unwrap();
assert!(
result.is_none(),
"expected None (cancelled), got {result:?}"
);
}
}