#[cfg(feature = "p3")]
use crate::p3;
use futures::stream::{FuturesUnordered, StreamExt};
use std::collections::VecDeque;
use std::collections::btree_map::{BTreeMap, Entry};
use std::future;
use std::pin::{Pin, pin};
use std::sync::{
Arc, Mutex,
atomic::{
AtomicBool, AtomicU64, AtomicUsize,
Ordering::{Relaxed, SeqCst},
},
};
use std::task::Poll;
use std::time::{Duration, Instant};
use tokio::sync::Notify;
use wasmtime::AsContextMut;
use wasmtime::component::Accessor;
use wasmtime::{Result, Store, StoreContextMut, format_err};
/// WASI 0.2 bindings for the `wasi:http/proxy` world, behind the `p2` feature.
#[cfg(feature = "p2")]
pub mod p2 {
    #[expect(missing_docs, reason = "bindgen-generated code")]
    pub mod bindings {
        // Generated bindings: imports use the `tracing` option, exports are
        // `async` and store-taking, store data must be `Send`, and the
        // `wasi:http` / `wasi:io` interfaces are mapped onto the existing
        // binding modules named in `with` rather than regenerated.
        wasmtime::component::bindgen!({
            path: "wit",
            world: "wasi:http/proxy",
            imports: { default: tracing },
            exports: { default: async | store },
            require_store_data_send: true,
            with: {
                "wasi:http": crate::p2::bindings::http,
                "wasi:io": wasmtime_wasi::p2::bindings::io,
            }
        });
        pub use wasi::*;
    }
}
/// A pre-instantiated proxy component, one variant per supported WASI version.
pub enum ProxyPre<T: 'static> {
    /// WASI 0.2 `wasi:http/proxy` world.
    #[cfg(feature = "p2")]
    P2(p2::bindings::ProxyPre<T>),
    /// WASI 0.3 service world.
    #[cfg(feature = "p3")]
    P3(p3::bindings::ServicePre<T>),
}
impl<T: 'static> ProxyPre<T> {
    /// Instantiate this pre-instantiated component into `store`, producing a
    /// [`Proxy`] of the matching WASI version.
    async fn instantiate_async(&self, store: impl AsContextMut<Data = T>) -> Result<Proxy>
    where
        T: Send,
    {
        match self {
            #[cfg(feature = "p2")]
            Self::P2(pre) => Ok(Proxy::P2(pre.instantiate_async(store).await?)),
            #[cfg(feature = "p3")]
            Self::P3(pre) => Ok(Proxy::P3(pre.instantiate_async(store).await?)),
        }
    }
}
/// An instantiated proxy component, one variant per supported WASI version.
pub enum Proxy {
    /// WASI 0.2 `wasi:http/proxy` instance.
    #[cfg(feature = "p2")]
    P2(p2::bindings::Proxy),
    /// WASI 0.3 service instance.
    #[cfg(feature = "p3")]
    P3(p3::bindings::Service),
}
/// A boxed unit of work handed to a worker: given an [`Accessor`] for the
/// store and the instantiated [`Proxy`], it produces the future that handles
/// one request.
pub type TaskFn<T> = Box<
    dyn for<'a> FnOnce(&'a Accessor<T>, &'a Proxy) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>
        + Send,
>;
/// A simple FIFO queue with an async, notification-driven `pop`.
struct Queue<T> {
    // Pending items; the mutex guard is only held for short, non-async
    // push/pop operations.
    queue: Mutex<VecDeque<T>>,
    // Wakes at most one task blocked in `pop` per `push`.
    notify: Notify,
}
impl<T> Default for Queue<T> {
fn default() -> Self {
Self {
queue: Default::default(),
notify: Default::default(),
}
}
}
impl<T> Queue<T> {
    /// Whether the queue currently holds no items.
    fn is_empty(&self) -> bool {
        self.queue.lock().unwrap().is_empty()
    }

    /// Append `item` and wake at most one waiter blocked in `pop`.
    fn push(&self, item: T) {
        self.queue.lock().unwrap().push_back(item);
        self.notify.notify_one();
    }

    /// Pop the front item if one is immediately available.
    fn try_pop(&self) -> Option<T> {
        self.queue.lock().unwrap().pop_front()
    }

    /// Wait until an item is available and pop it.
    ///
    /// Uses the `Notified::enable` pattern to avoid lost wakeups: the waiter
    /// registers interest *before* checking the queue, so a `push` that lands
    /// between the check and the `.await` still wakes this task.
    async fn pop(&self) -> T {
        let mut notified = pin!(self.notify.notified());
        loop {
            // Register for a notification before inspecting the queue.
            notified.as_mut().enable();
            if let Some(item) = self.try_pop() {
                return item;
            }
            notified.as_mut().await;
            // The previous `Notified` future has completed; arm a fresh one
            // for the next iteration.
            notified.set(self.notify.notified());
        }
    }
}
/// A freshly-created store paired with a hook to flush profiling data before
/// the store is discarded.
pub struct StoreBundle<T: 'static> {
    pub store: Store<T>,
    /// Invoked with the store context once the worker is done with the store.
    pub write_profile: Box<dyn FnOnce(StoreContextMut<T>) + Send>,
}
/// Host-side policy and store factory consumed by [`ProxyHandler`] workers.
pub trait HandlerState: 'static + Sync + Send {
    /// Per-store host data type.
    type StoreData: Send;
    /// Build a fresh store (plus its profile-flush hook), optionally tagged
    /// with the request id that triggered it.
    fn new_store(&self, req_id: Option<u64>) -> Result<StoreBundle<Self::StoreData>>;
    /// Maximum wall-clock time a single guest task may run.
    fn request_timeout(&self) -> Duration;
    /// How long an idle instance waits for new work before retiring.
    fn idle_instance_timeout(&self) -> Duration;
    /// Total number of tasks one instance may serve over its lifetime
    /// (`spawn` panics if this is zero).
    fn max_instance_reuse_count(&self) -> usize;
    /// Maximum number of tasks one instance may serve concurrently.
    fn max_instance_concurrent_reuse_count(&self) -> usize;
    /// Called with any error that tore down a worker.
    fn handle_worker_error(&self, error: wasmtime::Error);
}
/// State shared (via `Arc`) between a [`ProxyHandler`] and its workers.
struct ProxyHandlerInner<S: HandlerState> {
    // Host policy and store factory.
    state: S,
    // Component each worker instantiates.
    instance_pre: ProxyPre<S::StoreData>,
    // Source of monotonically increasing request ids.
    next_id: AtomicU64,
    // Tasks waiting for an available worker.
    task_queue: Queue<TaskFn<S::StoreData>>,
    // Number of workers currently advertising availability.
    worker_count: AtomicUsize,
}
/// An ordered multiset of `Instant`s: each key maps to its occurrence count,
/// and the earliest entry is cheap to query.
#[derive(Default)]
struct StartTimes(BTreeMap<Instant, usize>);

impl StartTimes {
    /// Record one occurrence of `time`.
    fn add(&mut self, time: Instant) {
        self.0
            .entry(time)
            .and_modify(|count| *count += 1)
            .or_insert(1);
    }

    /// Remove one previously-recorded occurrence of `time`.
    ///
    /// Panics if `time` is not present; callers maintain a strict add/remove
    /// pairing, so a miss here is an invariant violation.
    fn remove(&mut self, time: Instant) {
        let count = self.0.get(&time).copied().unwrap_or_else(|| unreachable!());
        match count {
            // A zero count is never stored: `add` inserts at least 1 and we
            // delete the entry when it would drop to 0.
            0 => unreachable!(),
            1 => {
                self.0.remove(&time);
            }
            _ => {
                *self.0.get_mut(&time).unwrap() -= 1;
            }
        }
    }

    /// The earliest recorded time, if any entries remain.
    fn earliest(&self) -> Option<Instant> {
        self.0.keys().next().copied()
    }
}
/// A spawned task that owns one instantiated guest and services tasks from
/// the handler's queue until the instance is retired.
struct Worker<S>
where
    S: HandlerState,
{
    // Handle back to the shared handler state and task queue.
    handler: ProxyHandler<S>,
    // Whether this worker is currently counted in `worker_count` as able to
    // accept another task.
    available: bool,
}
impl<S> Worker<S>
where
    S: HandlerState,
{
    /// Flip this worker's availability, keeping the handler's global count of
    /// available workers in sync.
    fn set_available(&mut self, available: bool) {
        if available != self.available {
            self.available = available;
            if available {
                self.handler.0.worker_count.fetch_add(1, Relaxed);
            } else {
                // SeqCst pairs with the SeqCst re-check in
                // `ProxyHandler::spawn`: if we were the last available worker
                // and tasks are still queued, start a replacement so those
                // tasks are not stranded.
                let count = self.handler.0.worker_count.fetch_sub(1, SeqCst);
                if count == 1 && !self.handler.0.task_queue.is_empty() {
                    self.handler.start_worker(None, None);
                }
            }
        }
    }

    /// Run the worker to completion, reporting any error to the handler state.
    async fn run(mut self, task: Option<TaskFn<S::StoreData>>, req_id: Option<u64>) {
        if let Err(error) = self.run_(task, req_id).await {
            self.handler.0.state.handle_worker_error(error);
        }
    }

    /// Worker main loop: create a store, instantiate the guest once, then
    /// service tasks on that single instance until the reuse limit, an idle
    /// timeout, or a request timeout retires it.
    async fn run_(
        &mut self,
        task: Option<TaskFn<S::StoreData>>,
        req_id: Option<u64>,
    ) -> Result<()> {
        let handler = &self.handler.0;
        let StoreBundle {
            mut store,
            write_profile,
        } = handler.state.new_store(req_id)?;
        // Snapshot the policy knobs once; they are fixed for this instance.
        let request_timeout = handler.state.request_timeout();
        let idle_instance_timeout = handler.state.idle_instance_timeout();
        let max_instance_reuse_count = handler.state.max_instance_reuse_count();
        let max_instance_concurrent_reuse_count =
            handler.state.max_instance_concurrent_reuse_count();
        let proxy = &handler.instance_pre.instantiate_async(&mut store).await?;
        // Cleared each time a task is accepted and re-set by the outer poll
        // loop below; this throttles how quickly *additional* concurrent
        // tasks can be accepted while others are already in flight.
        let accept_concurrent = AtomicBool::new(true);
        // Per-task deadlines (`now + request_timeout`), observed by the
        // watchdog in the outer poll loop below.
        let task_start_times = Mutex::new(StartTimes::default());
        let mut future = pin!(store.run_concurrent(async |accessor| {
            let mut reuse_count = 0;
            let mut timed_out = false;
            let mut futures = FuturesUnordered::new();
            // Start `task` on the guest under a per-request timeout and
            // record its deadline for the external watchdog.
            let accept_task = |task: TaskFn<S::StoreData>,
                               futures: &mut FuturesUnordered<_>,
                               reuse_count: &mut usize| {
                accept_concurrent.store(false, Relaxed);
                *reuse_count += 1;
                // NOTE(review): despite the name, this is the task's
                // *deadline*; `checked_add` guards against `Instant`
                // overflow (in which case no watchdog entry is recorded).
                let start_time = Instant::now().checked_add(request_timeout);
                if let Some(start_time) = start_time {
                    task_start_times.lock().unwrap().add(start_time);
                }
                futures.push(tokio::time::timeout(request_timeout, async move {
                    (task)(accessor, proxy).await;
                    start_time
                }));
            };
            if let Some(task) = task {
                accept_task(task, &mut futures, &mut reuse_count);
            }
            let handler = self.handler.clone();
            // Keep going while tasks are in flight or the instance may still
            // be reused for new ones.
            while !(futures.is_empty() && reuse_count >= max_instance_reuse_count) {
                let new_task = {
                    let future_count = futures.len();
                    // Next in-flight task completion; never resolves when
                    // nothing is in flight.
                    let mut next_future = pin!(async {
                        if futures.is_empty() {
                            future::pending().await
                        } else {
                            futures.next().await.unwrap()
                        }
                    });
                    // Next queued task; an idle instance (nothing in flight)
                    // only waits up to `idle_instance_timeout` for new work.
                    let mut next_task = pin!(tokio::time::timeout(
                        if future_count == 0 {
                            idle_instance_timeout
                        } else {
                            Duration::MAX
                        },
                        handler.0.task_queue.pop()
                    ));
                    future::poll_fn(|cx| match next_future.as_mut().poll(cx) {
                        Poll::Pending => {
                            // Only advertise availability (and poll for a new
                            // task) while under both reuse limits and — for a
                            // concurrent task — when the throttle allows it.
                            self.set_available(
                                reuse_count < max_instance_reuse_count
                                    && future_count < max_instance_concurrent_reuse_count
                                    && (future_count == 0 || accept_concurrent.load(Relaxed)),
                            );
                            if self.available {
                                next_task.as_mut().poll(cx).map(Some)
                            } else {
                                Poll::Pending
                            }
                        }
                        Poll::Ready(Ok(start_time)) => {
                            // A task finished within its timeout; retire its
                            // deadline from the watchdog set.
                            if let Some(start_time) = start_time {
                                task_start_times.lock().unwrap().remove(start_time);
                            }
                            Poll::Ready(None)
                        }
                        Poll::Ready(Err(_)) => {
                            // A task hit `request_timeout`; stop accepting
                            // work and wind this instance down.
                            timed_out = true;
                            reuse_count = max_instance_reuse_count;
                            Poll::Ready(None)
                        }
                    })
                    .await
                };
                match new_task {
                    Some(Ok(task)) => {
                        accept_task(task, &mut futures, &mut reuse_count);
                    }
                    // Idle timeout elapsed with nothing queued: retire.
                    Some(Err(_)) => break,
                    None => {}
                }
            }
            // Flush profiling data before the store is torn down.
            accessor.with(|mut access| write_profile(access.as_context_mut()));
            if timed_out {
                Err(format_err!("guest timed out"))
            } else {
                wasmtime::error::Ok(())
            }
        }));
        // Outer poll loop with a watchdog: if the earliest in-flight task is
        // still running `2 * request_timeout` past its recorded deadline
        // (i.e. the in-guest timeout above never got a chance to fire), fail
        // the whole worker.
        let mut sleep = pin!(tokio::time::sleep(Duration::MAX));
        future::poll_fn(|cx| {
            let poll = future.as_mut().poll(cx);
            if poll.is_pending() {
                if let Some(deadline) = task_start_times
                    .lock()
                    .unwrap()
                    .earliest()
                    .and_then(|v| v.checked_add(request_timeout.saturating_mul(2)))
                {
                    sleep.as_mut().reset(deadline.into());
                    if sleep.as_mut().poll(cx).is_ready() {
                        return Poll::Ready(Err(format_err!("guest timed out")));
                    }
                }
                // Re-open the concurrency throttle; if it was closed, poll
                // once more so a pending task acceptance can make progress.
                if !accept_concurrent.swap(true, Relaxed) {
                    return future.as_mut().poll(cx);
                }
            }
            poll
        })
        .await?
    }
}
impl<S> Drop for Worker<S>
where
    S: HandlerState,
{
    /// Ensure a dropped worker is no longer counted as available (which may
    /// also spawn a replacement if queued tasks would otherwise be stranded).
    fn drop(&mut self) {
        self.set_available(false);
    }
}
/// Shared, cloneable handle that dispatches request tasks onto pooled guest
/// worker instances.
pub struct ProxyHandler<S: HandlerState>(Arc<ProxyHandlerInner<S>>);
impl<S: HandlerState> Clone for ProxyHandler<S> {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}
impl<S> ProxyHandler<S>
where
    S: HandlerState,
{
    /// Create a handler around `state` and the pre-instantiated component.
    pub fn new(state: S, instance_pre: ProxyPre<S::StoreData>) -> Self {
        Self(Arc::new(ProxyHandlerInner {
            state,
            instance_pre,
            next_id: AtomicU64::from(0),
            task_queue: Default::default(),
            worker_count: AtomicUsize::from(0),
        }))
    }

    /// Submit a task, starting a fresh worker when none is available.
    ///
    /// # Panics
    ///
    /// Panics if `max_instance_reuse_count()` is zero.
    pub fn spawn(&self, req_id: Option<u64>, task: TaskFn<S::StoreData>) {
        match self.0.state.max_instance_reuse_count() {
            0 => panic!("`max_instance_reuse_count` must be at least 1"),
            _ => {
                if self.0.worker_count.load(Relaxed) == 0 {
                    // No available worker: seed a brand-new one with the task.
                    self.start_worker(Some(task), req_id);
                } else {
                    self.0.task_queue.push(task);
                    // Re-check with SeqCst *after* pushing: this pairs with
                    // the SeqCst decrement in `Worker::set_available`, so if
                    // every worker became unavailable in the meantime we
                    // start one rather than leaving the queued task stranded.
                    if self.0.worker_count.load(SeqCst) == 0 {
                        self.start_worker(None, None);
                    }
                }
            }
        }
    }

    /// Allocate the next request id (monotonically increasing from 0).
    pub fn next_req_id(&self) -> u64 {
        self.0.next_id.fetch_add(1, Relaxed)
    }

    /// The handler state supplied at construction.
    pub fn state(&self) -> &S {
        &self.0.state
    }

    /// The pre-instantiated component that workers instantiate from.
    pub fn instance_pre(&self) -> &ProxyPre<S::StoreData> {
        &self.0.instance_pre
    }

    /// Spawn a worker task, optionally seeded with an initial task/request id.
    fn start_worker(&self, task: Option<TaskFn<S::StoreData>>, req_id: Option<u64>) {
        tokio::spawn(
            Worker {
                handler: self.clone(),
                available: false,
            }
            .run(task, req_id),
        );
    }
}