use crate::common::{HttpHooks, Profile, RunCommon, RunTarget};
use bytes::Bytes;
use clap::Parser;
use http::{HeaderMap, HeaderName, HeaderValue, Response, StatusCode};
use http_body_util::combinators::UnsyncBoxBody;
use http_body_util::{BodyExt as _, Full};
use hyper::server::conn::http1;
use pin_project_lite::pin_project;
use std::convert::Infallible;
use std::ffi::OsString;
use std::net::SocketAddr;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{
path::PathBuf,
sync::{
Arc, Mutex,
atomic::{AtomicBool, AtomicU64, Ordering},
},
time::{Duration, Instant},
};
use tokio::io::{self, AsyncWrite};
use tokio::sync::{Notify, Semaphore};
use wasmtime::component::{Component, GuestTaskId, Linker};
use wasmtime::error::Context as _;
use wasmtime::{
AsContextMut as _, Engine, Result, Store, StoreContextMut, StoreLimits, UpdateDeadline, bail,
};
use wasmtime_cli_flags::opt::WasmtimeOptionValue;
use wasmtime_wasi::p2::{StreamError, StreamResult};
use wasmtime_wasi::{WasiCtx, WasiCtxBuilder, WasiCtxView, WasiView};
use wasmtime_wasi_http::WasiHttpCtx;
use wasmtime_wasi_http::handler::{
self, HandlerState, Instance, Prepared, Proxy, ProxyHandler, ProxyPre, ShouldAccept, ViewFn,
WorkerExpiration, WorkerState, WorkerStatus,
};
use wasmtime_wasi_http::io::TokioIo;
#[cfg(feature = "debug")]
use crate::commands::run::RunCommand;
#[cfg(feature = "wasi-config")]
use wasmtime_wasi_config::{WasiConfig, WasiConfigVariables};
#[cfg(feature = "wasi-keyvalue")]
use wasmtime_wasi_keyvalue::{WasiKeyValue, WasiKeyValueCtx, WasiKeyValueCtxBuilder};
#[cfg(feature = "wasi-nn")]
use wasmtime_wasi_nn::wit::WasiNnCtx;
/// Default total number of requests a single WASIp3 instance may serve when
/// `max_instance_reuse_count` is not given on the command line.
const DEFAULT_WASIP3_MAX_INSTANCE_REUSE_COUNT: usize = 128;
/// Default total number of requests a single WASIp2 instance may serve when
/// `max_instance_reuse_count` is not given on the command line.
const DEFAULT_WASIP2_MAX_INSTANCE_REUSE_COUNT: usize = 1;
/// Default number of requests a single WASIp3 instance may handle at the same
/// time when `max_instance_concurrent_reuse_count` is not given.
const DEFAULT_WASIP3_MAX_INSTANCE_CONCURRENT_REUSE_COUNT: usize = 16;
/// Data stored in every `Store` created by the `serve` command.
struct Host {
/// Resource table handed out alongside the WASI and wasi-http contexts.
table: wasmtime::component::ResourceTable,
/// WASI context built from the command's WASI options.
ctx: WasiCtx,
/// wasi-http context exposed through both the p2 and p3 views.
http: WasiHttpCtx,
/// Hooks passed to the wasi-http views.
hooks: HttpHooks,
/// Limits installed on the store via `Store::limiter`.
limits: StoreLimits,
/// wasi-nn state; only populated when wasi-nn is enabled.
#[cfg(feature = "wasi-nn")]
nn: Option<WasiNnCtx>,
/// wasi-config variables; only populated when wasi-config is enabled.
#[cfg(feature = "wasi-config")]
wasi_config: Option<WasiConfigVariables>,
/// wasi-keyvalue state; only populated when wasi-keyvalue is enabled.
#[cfg(feature = "wasi-keyvalue")]
wasi_keyvalue: Option<WasiKeyValueCtx>,
/// Guest profiler, installed by `setup_guest_profiler`.
#[cfg(feature = "profiling")]
guest_profiler: Option<Arc<wasmtime::GuestProfiler>>,
/// Callback invoked once when the worker owning this store is dropped.
write_profile: Option<WriteProfile>,
}
impl WasiView for Host {
fn ctx(&mut self) -> WasiCtxView<'_> {
WasiCtxView {
ctx: &mut self.ctx,
table: &mut self.table,
}
}
}
impl wasmtime_wasi_http::p2::WasiHttpView for Host {
fn http(&mut self) -> wasmtime_wasi_http::p2::WasiHttpCtxView<'_> {
wasmtime_wasi_http::p2::WasiHttpCtxView {
ctx: &mut self.http,
table: &mut self.table,
hooks: &mut self.hooks,
}
}
}
#[cfg(feature = "component-model-async")]
impl wasmtime_wasi_http::p3::WasiHttpView for Host {
    /// Borrows the wasi-http context, resource table and hooks for the p3 bindings.
    fn http(&mut self) -> wasmtime_wasi_http::p3::WasiHttpCtxView<'_> {
        let Host {
            http, table, hooks, ..
        } = self;
        wasmtime_wasi_http::p3::WasiHttpCtxView {
            table,
            ctx: http,
            hooks,
        }
    }
}
/// Address the server binds to when `--addr` is not given: `0.0.0.0:8080`.
const DEFAULT_ADDR: std::net::SocketAddr = {
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};
    SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 8080)
};
/// clap value parser for duration options.
///
/// Delegates to `Duration::parse` and turns its error into a plain string.
fn parse_duration(s: &str) -> Result<Duration, String> {
    match Duration::parse(Some(s)) {
        Ok(duration) => Ok(duration),
        Err(err) => Err(err.to_string()),
    }
}
// Command-line arguments of the `serve` command.
//
// NOTE: plain `//` comments are used on purpose in this struct; clap's derive
// turns `///` doc comments into help text, which would change `--help` output.
#[derive(Parser)]
pub struct ServeCommand {
// Options shared with other commands, flattened into this one.
#[command(flatten)]
run: RunCommon,
// Socket address the HTTP server binds to (defaults to `DEFAULT_ADDR`).
#[arg(long , value_name = "SOCKADDR", default_value_t = DEFAULT_ADDR)]
addr: SocketAddr,
// Optional socket address; a TCP connection to it requests a graceful shutdown.
#[arg(long, value_name = "SOCKADDR")]
shutdown_addr: Option<SocketAddr>,
// When set, guest stdout/stderr lines are not prefixed with the instance id.
#[arg(long)]
no_logging_prefix: bool,
// Path of the wasm component to serve.
#[arg(value_name = "WASM", required = true)]
component: PathBuf,
// Maximum total number of requests served by one instance; when absent a
// default is chosen depending on whether the component is P2 or P3.
#[arg(long)]
max_instance_reuse_count: Option<usize>,
// Maximum number of requests one instance handles concurrently; only
// consulted for P3 components (P2 always uses 1).
#[arg(long)]
max_instance_concurrent_reuse_count: Option<usize>,
// How long an instance may stay idle before its worker expires.
#[arg(long, default_value = "1s", value_parser = parse_duration)]
idle_instance_timeout: Duration,
// Headers forced onto every incoming request, either `NAME: VALUE` or
// `@path` to read one header per line from a file.
#[arg(short = 'H', long = "header", value_name = "HEADER")]
headers: Vec<String>,
// Upper bound on requests processed at once (default 1000, or 1 with a debugger).
#[arg(long)]
max_concurrent_requests: Option<usize>,
// Upper bound on open client connections (default 1000, or 1 with a debugger).
#[arg(long)]
max_concurrent_connections: Option<usize>,
}
impl ServeCommand {
/// Entry point of the `serve` command.
///
/// Initializes logging, validates the WASI-related options (forcing
/// wasi-http and the component model on), then runs [`Self::serve`] to
/// completion on a freshly built multi-threaded tokio runtime.
pub fn execute(mut self) -> Result<()> {
    self.run.common.init_logging()?;

    // Requesting wasi-nn is only an error when support was compiled out.
    #[cfg(not(feature = "wasi-nn"))]
    if self.run.common.wasi.nn == Some(true) {
        bail!("Cannot enable wasi-nn when the binary is not compiled with this feature.");
    }

    if let Some(true) = self.run.common.wasi.threads {
        bail!("wasi-threads does not support components yet")
    }

    // Both options are switched on unconditionally; an explicit "off" from
    // the user is rejected rather than silently overridden.
    let previous_http = self.run.common.wasi.http.replace(true);
    if previous_http == Some(false) {
        bail!("wasi-http is required for the serve command, and must not be disabled");
    }
    let previous_component_model = self.run.common.wasm.component_model.replace(true);
    if previous_component_model == Some(false) {
        bail!("components are required for the serve command, and must not be disabled");
    }

    let mut runtime_builder = tokio::runtime::Builder::new_multi_thread();
    runtime_builder.enable_time().enable_io();
    let runtime = runtime_builder.build()?;

    runtime.block_on(self.serve())
}
/// Prepares debugging support, if any was requested.
///
/// When a debugger component is configured (through the `debugger` debug
/// option, or implicitly through the gdbstub option when that feature is
/// compiled in), this forces the debuggee settings that debugging relies on
/// and builds the `RunCommand` used to run the debugger component.
///
/// Returns `Ok(None)` when no debugger is configured. Fails when the user
/// explicitly set an option to a value that conflicts with the one implied by
/// debugging, or when the gdbstub address is invalid.
#[cfg(feature = "debug")]
fn debugger_setup(&mut self) -> Result<Option<RunCommand>> {
// Sets `setting` to `value`, rejecting an explicit user choice of the
// opposite value. `place` and `name` are only used in the error message.
fn set_implicit_option(
place: &str,
name: &str,
setting: &mut Option<bool>,
value: bool,
) -> Result<()> {
if *setting == Some(!value) {
bail!(
"Explicitly-set option on {place} {name}={} is not compatible \
with debugging-implied setting {value}",
setting.unwrap()
);
}
*setting = Some(value);
Ok(())
}
// With gdbstub support the debugger component is the built-in one, and its
// bytes replace whatever would be loaded from the `debugger` path.
#[cfg(feature = "gdbstub")]
let override_bytes = if let Some(addr) = self.run.gdbstub.as_deref() {
if self.run.common.debug.debugger.is_some() {
bail!("-g/--gdb cannot be combined with -Ddebugger=");
}
// A bare port number means "that port on 127.0.0.1"; anything else must
// parse as a full socket address.
let addr = if addr.parse::<u16>().is_ok() {
format!("127.0.0.1:{addr}")
} else {
use std::net::SocketAddr as SA;
addr.parse::<SA>()
.with_context(|| format!("invalid gdbstub address: `{addr}`"))?;
addr.to_string()
};
// The address is handed to the debugger component as its argument.
self.run.common.debug.debugger = Some("<built-in gdbstub>".into());
self.run.common.debug.arg.push(addr);
Some(gdbstub_component_artifact::GDBSTUB_COMPONENT)
} else {
None
};
#[cfg(not(feature = "gdbstub"))]
let override_bytes = None;
if let Some(debugger_component_path) = self.run.common.debug.debugger.as_ref() {
// Settings implied on the debuggee (the served component).
set_implicit_option(
"debuggee",
"guest_debug",
&mut self.run.common.debug.guest_debug,
true,
)?;
set_implicit_option(
"debuggee",
"epoch_interruption",
&mut self.run.common.wasm.epoch_interruption,
true,
)?;
// The debugger is run as if by `run <debugger> <debug args...>`.
let mut debugger_run = RunCommand::try_parse_from(
["run".into(), debugger_component_path.into()]
.into_iter()
.chain(self.run.common.debug.arg.iter().map(OsString::from)),
)?;
debugger_run.module_bytes = override_bytes;
// TCP and network inheritance default to enabled for the debugger unless
// its own arguments already set them.
debugger_run.run.common.wasi.tcp.get_or_insert(true);
debugger_run
.run
.common
.wasi
.inherit_network
.get_or_insert(true);
// Stdio inheritance of the debugger follows the debug options of this
// command, defaulting to "not inherited".
set_implicit_option(
"debugger",
"inherit_stdin",
&mut debugger_run.run.common.wasi.inherit_stdin,
self.run.common.debug.inherit_stdin.unwrap_or(false),
)?;
set_implicit_option(
"debugger",
"inherit_stdout",
&mut debugger_run.run.common.wasi.inherit_stdout,
self.run.common.debug.inherit_stdout.unwrap_or(false),
)?;
set_implicit_option(
"debugger",
"inherit_stderr",
&mut debugger_run.run.common.wasi.inherit_stderr,
self.run.common.debug.inherit_stderr.unwrap_or(false),
)?;
Ok(Some(debugger_run))
} else {
Ok(None)
}
}
/// Runs the server with a debugger component attached.
///
/// A single debuggee store is created up front (without an instance id) and
/// the served component is registered with it for debugging. The debugger
/// component is then loaded with its own engine, store and linker, and
/// `invoke_debugger` is given a callback that runs the regular serving loop
/// against the debuggee store.
///
/// Fails if the debugger turns out to be a core module rather than a component.
#[cfg(feature = "debug")]
async fn serve_under_debugger(
self,
mut debug_run: RunCommand,
linker: Linker<Host>,
component: Component,
) -> Result<()> {
let mut debuggee_store = self.new_store(linker.engine(), None)?;
debuggee_store.debug_register_component(&component)?;
// The debugger gets an engine of its own, separate from the debuggee's.
let debug_engine = debug_run.new_engine()?;
let debug_main = debug_run.run.load_module(
&debug_engine,
debug_run.module_and_args[0].as_ref(),
debug_run.module_bytes.as_ref().map(|v| &v[..]),
)?;
let (mut debug_store, debug_linker) =
debug_run.new_store_and_linker(&debug_engine, &debug_main)?;
let debug_component = match debug_main {
RunTarget::Core(_) => {
bail!("Debugger component is a core module; only components are supported")
}
RunTarget::Component(c) => c,
};
// The core-module case already bailed out above, so the linker is treated
// as necessarily being the component flavor here.
let mut debug_linker = match debug_linker {
crate::commands::run::CliLinker::Core(_) => unreachable!(),
crate::commands::run::CliLinker::Component(l) => l,
};
debug_run.add_debugger_api(&mut debug_linker)?;
debug_run
.invoke_debugger(
&mut debug_store,
&debug_component,
&mut debug_linker,
debuggee_store,
// The serving loop runs inside the debugger invocation, using the store
// the debugger hands back.
move |store| Box::pin(self.serve_maybe_debug(linker, component, Some(store))),
)
.await
}
/// Builds a fresh `Store<Host>` configured from the command-line options.
///
/// When `instance_id` is given it is exported to the guest as the
/// `INSTANCE_ID` environment variable and, unless `no_logging_prefix` is
/// set, used to prefix every line the guest writes to stdout/stderr.
fn new_store(&self, engine: &Engine, instance_id: Option<u64>) -> Result<Store<Host>> {
    let wasi_opts = &self.run.common.wasi;

    let mut wasi = WasiCtxBuilder::new();
    self.run.configure_wasip2(&mut wasi)?;
    if let Some(id) = instance_id {
        wasi.env("INSTANCE_ID", id.to_string());
    }

    // Guest output is routed through `LogStream` so each line can be tagged
    // with the instance that produced it.
    let (stdout_prefix, stderr_prefix) = match instance_id {
        Some(id) if !self.no_logging_prefix => (
            format!("stdout [{id}] :: "),
            format!("stderr [{id}] :: "),
        ),
        _ => (String::new(), String::new()),
    };
    wasi.stdout(LogStream::new(stdout_prefix, Output::Stdout));
    wasi.stderr(LogStream::new(stderr_prefix, Output::Stderr));

    let mut table = wasmtime::component::ResourceTable::new();
    if let Some(max) = wasi_opts.max_resources {
        table.set_max_capacity(max);
    }

    let mut host = Host {
        table,
        ctx: wasi.build(),
        http: self.run.wasi_http_ctx()?,
        hooks: self.run.wasi_http_hooks(),
        limits: StoreLimits::default(),
        #[cfg(feature = "wasi-nn")]
        nn: None,
        #[cfg(feature = "wasi-config")]
        wasi_config: None,
        #[cfg(feature = "wasi-keyvalue")]
        wasi_keyvalue: None,
        #[cfg(feature = "profiling")]
        guest_profiler: None,
        write_profile: None,
    };

    // Optional WASI extensions: each one is populated only when it was both
    // compiled in and enabled on the command line.
    #[cfg(feature = "wasi-nn")]
    if wasi_opts.nn == Some(true) {
        let graphs: Vec<_> = wasi_opts
            .nn_graph
            .iter()
            .map(|g| (g.format.clone(), g.dir.clone()))
            .collect();
        let (backends, registry) = wasmtime_wasi_nn::preload(&graphs)?;
        host.nn = Some(WasiNnCtx::new(backends, registry));
    }

    #[cfg(feature = "wasi-config")]
    if wasi_opts.config == Some(true) {
        let pairs = wasi_opts
            .config_var
            .iter()
            .map(|v| (v.key.clone(), v.value.clone()));
        host.wasi_config = Some(WasiConfigVariables::from_iter(pairs));
    }

    #[cfg(feature = "wasi-keyvalue")]
    if wasi_opts.keyvalue == Some(true) {
        let seed = wasi_opts
            .keyvalue_in_memory_data
            .iter()
            .map(|v| (v.key.clone(), v.value.clone()));
        let keyvalue = WasiKeyValueCtxBuilder::new().in_memory_data(seed).build();
        host.wasi_keyvalue = Some(keyvalue);
    }

    let mut store = Store::new(engine, host);
    if let Some(fuel) = wasi_opts.hostcall_fuel {
        store.set_hostcall_fuel(fuel);
    }

    // The limits live inside the store's own data so the limiter callback can
    // borrow them from there.
    store.data_mut().limits = self.run.store_limits();
    store.limiter(|host| &mut host.limits);

    if let Some(fuel) = self.run.common.wasm.fuel {
        store.set_fuel(fuel)?;
    }
    Ok(store)
}
/// Registers every host interface the served component may import.
///
/// wasi-http is always added; the rest of WASI is added in full when the CLI
/// world is enabled. wasi-nn, wasi-config and wasi-keyvalue are added on
/// request and rejected when support for them was compiled out.
fn add_to_linker(&self, linker: &mut Linker<Host>) -> Result<()> {
    self.run.validate_p3_option()?;
    let wasi_opts = &self.run.common.wasi;

    match self.run.validate_cli_enabled()? {
        Some(true) => {
            self.run.add_wasmtime_wasi_to_linker(linker)?;
            wasmtime_wasi_http::p2::add_only_http_to_linker_async(linker)?;
            #[cfg(feature = "component-model-async")]
            if wasi_opts.p3.unwrap_or(crate::common::P3_DEFAULT) {
                wasmtime_wasi_http::p3::add_to_linker(linker)?;
            }
        }
        _ => {
            wasmtime_wasi_http::p2::add_to_linker_async(linker)?;
            #[cfg(feature = "component-model-async")]
            if wasi_opts.p3.unwrap_or(crate::common::P3_DEFAULT) {
                wasmtime_wasi_http::p3::add_to_linker(linker)?;
                wasmtime_wasi::p3::clocks::add_to_linker(linker)?;
                wasmtime_wasi::p3::random::add_to_linker(linker)?;
                wasmtime_wasi::p3::cli::add_to_linker(linker)?;
            }
        }
    }

    if wasi_opts.nn == Some(true) {
        #[cfg(not(feature = "wasi-nn"))]
        {
            bail!("support for wasi-nn was disabled at compile time");
        }
        #[cfg(feature = "wasi-nn")]
        {
            // The context is populated in `new_store` whenever the option is on.
            wasmtime_wasi_nn::wit::add_to_linker(linker, |host: &mut Host| {
                let nn = host.nn.as_mut().unwrap();
                wasmtime_wasi_nn::wit::WasiNnView::new(&mut host.table, nn)
            })?;
        }
    }

    if wasi_opts.config == Some(true) {
        #[cfg(not(feature = "wasi-config"))]
        {
            bail!("support for wasi-config was disabled at compile time");
        }
        #[cfg(feature = "wasi-config")]
        {
            wasmtime_wasi_config::add_to_linker(linker, |host| {
                WasiConfig::from(host.wasi_config.as_ref().unwrap())
            })?;
        }
    }

    if wasi_opts.keyvalue == Some(true) {
        #[cfg(not(feature = "wasi-keyvalue"))]
        {
            bail!("support for wasi-keyvalue was disabled at compile time");
        }
        #[cfg(feature = "wasi-keyvalue")]
        {
            wasmtime_wasi_keyvalue::add_to_linker(linker, |host: &mut Host| {
                WasiKeyValue::new(host.wasi_keyvalue.as_ref().unwrap(), &mut host.table)
            })?;
        }
    }

    if let Some(true) = wasi_opts.threads {
        bail!("support for wasi-threads is not available with components");
    }
    if let Some(false) = wasi_opts.http {
        bail!("support for wasi-http must be enabled for `serve` subcommand");
    }
    Ok(())
}
async fn serve(mut self) -> Result<()> {
#[cfg(feature = "debug")]
let debug_run = self.debugger_setup()?;
let mut config = self
.run
.common
.config(use_pooling_allocator_by_default().unwrap_or(None))?;
config.wasm_component_model(true);
if self.run.common.wasm.timeout.is_some() {
config.epoch_interruption(true);
}
match self.run.profile {
Some(Profile::Native(s)) => {
config.profiler(s);
}
Some(Profile::Guest { .. }) => {
config.epoch_interruption(true);
}
None => {}
}
let engine = Engine::new(&config)?;
let mut linker = Linker::new(&engine);
self.add_to_linker(&mut linker)?;
let component = match self.run.load_module(&engine, &self.component, None)? {
RunTarget::Core(_) => bail!("The serve command currently requires a component"),
RunTarget::Component(c) => c,
};
#[cfg(feature = "debug")]
if let Some(debug_run) = debug_run {
return self
.serve_under_debugger(debug_run, linker, component)
.await;
}
self.serve_maybe_debug(linker, component, None).await
}
/// Binds the listening socket and runs the accept loop until a shutdown is
/// requested, then waits for outstanding connection tasks to finish.
///
/// When `debuggee_store` is `Some`, every connection is handled inline on the
/// current task using that single store, and both concurrency limits must be
/// exactly 1. Otherwise each connection is handled on its own spawned task.
async fn serve_maybe_debug(
self,
linker: Linker<Host>,
component: Component,
mut debuggee_store: Option<&mut Store<Host>>,
) -> Result<()> {
let engine = linker.engine();
let request_headers = RequestHeaders::parse(&self.headers)?;
let instance = linker.instantiate_pre(&component)?;
// Prefer the P3 service world; fall back to the P2 proxy world when the
// component does not fit the P3 one.
#[cfg(feature = "component-model-async")]
let instance = match wasmtime_wasi_http::p3::bindings::ServicePre::new(instance.clone()) {
Ok(pre) => ProxyPre::P3(pre),
Err(_) => ProxyPre::P2(wasmtime_wasi_http::p2::bindings::ProxyPre::new(instance)?),
};
#[cfg(not(feature = "component-model-async"))]
let instance = ProxyPre::P2(wasmtime_wasi_http::p2::bindings::ProxyPre::new(instance)?);
let shutdown = Arc::new(GracefulShutdown::default());
// First shutdown trigger: ctrl-c.
tokio::task::spawn({
let shutdown = shutdown.clone();
async move {
tokio::signal::ctrl_c().await.unwrap();
shutdown.requested.notify_one();
}
});
// Second, optional shutdown trigger: the first accept on the shutdown
// address (whether it succeeds or fails) requests a shutdown.
if let Some(addr) = self.shutdown_addr {
let listener = tokio::net::TcpListener::bind(addr).await?;
eprintln!(
"Listening for shutdown on tcp://{}/",
listener.local_addr()?
);
let shutdown = shutdown.clone();
tokio::task::spawn(async move {
let _ = listener.accept().await;
shutdown.requested.notify_one();
});
}
let socket = match &self.addr {
SocketAddr::V4(_) => tokio::net::TcpSocket::new_v4()?,
SocketAddr::V6(_) => tokio::net::TcpSocket::new_v6()?,
};
// Address reuse is requested everywhere except on Windows.
socket.set_reuseaddr(!cfg!(windows))?;
socket.bind(self.addr)?;
let listener = socket.listen(100)?;
eprintln!("Serving HTTP on http://{}/", listener.local_addr()?);
log::info!("Listening on {}", self.addr);
// How often the engine epoch is bumped, if at all: the guest profiler's
// sampling interval wins, then the request timeout (capped at
// `EPOCH_INTERRUPT_PERIOD`), then a 1ms tick when running under a debugger.
let epoch_interval = if let Some(Profile::Guest { interval, .. }) = self.run.profile {
Some(interval)
} else if let Some(t) = self.run.common.wasm.timeout {
Some(EPOCH_INTERRUPT_PERIOD.min(t))
} else if debuggee_store.is_some() {
Some(Duration::from_millis(1))
} else {
None
};
// Kept alive until this function returns; dropping it stops the thread.
let _epoch_thread = epoch_interval.map(|t| EpochThread::spawn(t, engine.clone()));
let max_instance_reuse_count = self.max_instance_reuse_count.unwrap_or_else(|| {
if let ProxyPre::P3(_) = &instance {
DEFAULT_WASIP3_MAX_INSTANCE_REUSE_COUNT
} else {
DEFAULT_WASIP2_MAX_INSTANCE_REUSE_COUNT
}
});
// Concurrent reuse is only configurable for P3; P2 instances always get 1.
let max_instance_concurrent_reuse_count = if let ProxyPre::P3(_) = &instance {
self.max_instance_concurrent_reuse_count
.unwrap_or(DEFAULT_WASIP3_MAX_INSTANCE_CONCURRENT_REUSE_COUNT)
} else {
1
};
let max_concurrent_connections = self
.max_concurrent_connections
.unwrap_or(if debuggee_store.is_some() { 1 } else { 1000 });
let max_concurrent_requests = self
.max_concurrent_requests
.unwrap_or(if debuggee_store.is_some() { 1 } else { 1000 });
if debuggee_store.is_some() && max_concurrent_connections != 1 {
bail!("cannot have more than 1 max concurrent connections with a debugger");
}
if debuggee_store.is_some() && max_concurrent_requests != 1 {
bail!("cannot have more than 1 max concurrent requests with a debugger");
}
let sem_connections = Arc::new(Semaphore::new(max_concurrent_connections));
let handler = ProxyHandler::new(HostHandlerState {
sem_requests: Semaphore::new(max_concurrent_requests),
cmd: self,
component,
request_headers,
max_instance_reuse_count,
max_instance_concurrent_reuse_count,
instance,
next_instance_id: AtomicU64::default(),
next_request_id: AtomicU64::default(),
// Counts as an active task for graceful shutdown for as long as the
// handler state exists.
_shutdown_guard: Box::new(shutdown.clone().increment()),
});
loop {
// A connection permit is acquired before accepting, so no more than
// `max_concurrent_connections` connections are open at once. A shutdown
// request ends the loop.
let (connection_permit, stream) = tokio::select! {
_ = shutdown.requested.notified() => break,
v = async {
let permit = sem_connections.clone().acquire_owned().await?;
let (stream, _) = listener.accept().await?;
wasmtime::error::Ok((permit, stream))
} => v?,
};
stream.set_nodelay(true)?;
let shutdown_guard = shutdown.clone().increment();
match &mut debuggee_store {
// Debugging: serve the connection to completion before accepting the
// next one.
Some(store) => {
handle_client(stream, &handler, Some(store)).await;
}
// Normal operation: one task per connection; the guard and the permit
// are released when the connection is done.
None => {
let handler = handler.clone();
tokio::task::spawn(async move {
handle_client(stream, &handler, None).await;
drop(shutdown_guard);
drop(connection_permit);
});
}
}
}
// Closing the semaphore makes request-permit acquisition in
// `handle_request` fail from now on.
handler.state().sem_requests.close();
drop(handler);
// `close` reports whether no tasks were still active.
if shutdown.close() {
return Ok(());
}
eprintln!("Waiting for child tasks to exit, ctrl-c again to quit sooner...");
tokio::select! {
_ = tokio::signal::ctrl_c() => {}
_ = shutdown.complete.notified() => {}
}
Ok(())
}
}
// Expiration timer for a worker: tracks the two timeouts and the single
// `Sleep` that is re-armed whenever the computed deadline changes.
pin_project! {
struct HostWorkerExpiration {
// Timeout applied while the worker status is `Idle`.
idle_timeout: Duration,
// Timeout applied while the worker status is `Requests` or `PostReturn`.
request_timeout: Duration,
// Timer polled by `WorkerExpiration::poll`.
#[pin]
sleep: tokio::time::Sleep,
}
}
impl WorkerExpiration for HostWorkerExpiration {
fn poll(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
status: WorkerStatus,
start: Instant,
) -> Poll<()> {
let mut me = self.project();
let timeout = match status {
WorkerStatus::Idle => *me.idle_timeout,
WorkerStatus::Requests | WorkerStatus::PostReturn => *me.request_timeout,
};
if let Some(deadline) = start.checked_add(timeout) {
let deadline = deadline.into();
if deadline != me.sleep.deadline() {
me.sleep.as_mut().reset(deadline);
}
me.sleep.poll(cx)
} else {
Poll::Pending
}
}
}
/// Per-instance state consulted by the handler when routing requests.
struct HostWorkerState {
/// Identifier of the instance, used in log messages.
instance_id: u64,
/// Total number of requests after which the instance accepts no more.
max_instance_reuse_count: usize,
/// Number of in-flight requests at which the instance temporarily refuses more.
max_instance_concurrent_reuse_count: usize,
/// Duration of the timer returned for each request by `on_request_start`.
request_timeout: Duration,
}
impl WorkerState for HostWorkerState {
    type StoreData = Host;
    type RequestId = u64;

    /// Decides whether this instance may take another request.
    ///
    /// `Never` once the total reuse budget is exhausted, `No` while the
    /// concurrency limit is reached, `Yes` otherwise.
    fn should_accept_request(&self, concurrent_count: usize, total_count: usize) -> ShouldAccept {
        if total_count >= self.max_instance_reuse_count {
            return ShouldAccept::Never;
        }
        if concurrent_count >= self.max_instance_concurrent_reuse_count {
            return ShouldAccept::No;
        }
        ShouldAccept::Yes
    }

    /// Logs the start of a request and returns a timer lasting `request_timeout`.
    fn on_request_start(
        &self,
        _store: StoreContextMut<Host>,
        request_id: u64,
        _task_id: GuestTaskId,
    ) -> Pin<Box<dyn Future<Output = ()> + 'static + Send + Sync>> {
        log::info!(
            "Instance {} handling request {request_id}",
            self.instance_id,
        );
        let timer = tokio::time::sleep(self.request_timeout);
        Box::pin(timer)
    }

    /// Final cleanup for a worker: reports a failure, if any, and runs the
    /// pending profile-writing callback before the store is destroyed.
    fn drop(&self, mut store: Store<Self::StoreData>, result: Result<(), wasmtime::Error>) {
        match result {
            Ok(()) => {}
            Err(error) => eprintln!("worker failed: {error:?}"),
        }

        let pending = store.data_mut().write_profile.take();
        if let Some(write_profile) = pending {
            write_profile(store.as_context_mut());
        }
        drop(store);
    }
}
/// State shared by all connections through the `ProxyHandler`.
struct HostHandlerState {
/// The parsed command, used to build stores and read options.
cmd: ServeCommand,
/// The component being served.
component: Component,
/// Headers forced onto every incoming request.
request_headers: RequestHeaders,
/// Resolved total-request limit per instance.
max_instance_reuse_count: usize,
/// Resolved concurrent-request limit per instance.
max_instance_concurrent_reuse_count: usize,
/// Pre-instantiated component, either the P2 or the P3 flavor.
instance: ProxyPre<Host>,
/// Counter handing out instance ids.
next_instance_id: AtomicU64,
/// Counter handing out request ids.
next_request_id: AtomicU64,
/// Limits the number of requests processed at once.
sem_requests: Semaphore,
/// Held only for its `Drop`: a graceful-shutdown guard created alongside the handler.
_shutdown_guard: Box<dyn std::any::Any + Send + Sync>,
}
impl HostHandlerState {
    /// Instantiates the served component inside `store`.
    ///
    /// The epoch handler is installed first, and the profile-writing callback
    /// it returns is stashed in the store data for later.
    async fn instantiate_into(&self, store: &mut Store<Host>) -> Result<Proxy> {
        let component = self.component.clone();
        let write_profile = setup_epoch_handler(&self.cmd, &mut *store, component)?;
        store.data_mut().write_profile = Some(write_profile);

        let proxy = self.instance.instantiate_async(&mut *store).await;
        proxy
    }

    /// Selects the wasi-http view accessor matching the component's flavor.
    fn view(&self) -> ViewFn<Host> {
        match &self.instance {
            ProxyPre::P3(_) => ViewFn::P3(wasmtime_wasi_http::p3::WasiHttpView::http),
            ProxyPre::P2(_) => ViewFn::P2(wasmtime_wasi_http::p2::WasiHttpView::http),
        }
    }
}
impl HandlerState for HostHandlerState {
    type StoreData = Host;
    type WorkerExpiration = HostWorkerExpiration;
    type WorkerState = HostWorkerState;

    /// Creates a new instance with its own store and a fresh instance id.
    async fn instantiate(
        &self,
    ) -> Result<Instance<Self::StoreData, Self::WorkerExpiration, Self::WorkerState>> {
        let instance_id = self.next_instance_id.fetch_add(1, Ordering::Relaxed);

        let engine = self.component.engine();
        let mut store = self.cmd.new_store(engine, Some(instance_id))?;
        let proxy = self.instantiate_into(&mut store).await?;

        // No configured timeout means requests effectively never time out.
        let request_timeout = self.cmd.run.common.wasm.timeout.unwrap_or(Duration::MAX);

        let expiration = HostWorkerExpiration {
            idle_timeout: self.cmd.idle_instance_timeout,
            request_timeout,
            sleep: tokio::time::sleep(Duration::MAX),
        };
        let state = HostWorkerState {
            max_instance_reuse_count: self.max_instance_reuse_count,
            max_instance_concurrent_reuse_count: self.max_instance_concurrent_reuse_count,
            instance_id,
            request_timeout,
        };

        Ok(Instance {
            store,
            proxy,
            view: self.view(),
            expiration,
            state,
        })
    }
}
/// Coordination point for shutting the server down gracefully.
#[derive(Default)]
struct GracefulShutdown {
/// Notified when a shutdown has been requested.
requested: Notify,
/// Notified when the last active task finishes after `close` was called.
complete: Notify,
/// Task bookkeeping, guarded by a mutex.
state: Mutex<GracefulShutdownState>,
}
/// Mutable bookkeeping behind `GracefulShutdown::state`.
#[derive(Default)]
struct GracefulShutdownState {
/// Number of guards from `GracefulShutdown::increment` that are still alive.
active_tasks: u32,
/// Set by `GracefulShutdown::close`; from then on the guard that brings
/// `active_tasks` to zero signals `complete`.
notify_when_done: bool,
}
impl GracefulShutdown {
    /// Registers one more active task and returns a guard that unregisters it
    /// when dropped.
    ///
    /// # Panics
    ///
    /// Panics if called after [`Self::close`].
    fn increment(self: Arc<Self>) -> impl Drop + Send + Sync {
        struct Guard(Arc<GracefulShutdown>);

        impl Drop for Guard {
            fn drop(&mut self) {
                let shared = &self.0;
                let mut state = shared.state.lock().unwrap();
                state.active_tasks -= 1;
                let last_one_out = state.active_tasks == 0;
                if state.notify_when_done && last_one_out {
                    shared.complete.notify_one();
                }
            }
        }

        {
            // Scoped so the lock is released before `self` moves into the guard.
            let mut state = self.state.lock().unwrap();
            assert!(!state.notify_when_done);
            state.active_tasks += 1;
        }
        Guard(self)
    }

    /// Marks the shutdown as started and reports whether nothing is left to
    /// wait for (`true` when no tasks are active).
    fn close(&self) -> bool {
        let mut state = self.state.lock().unwrap();
        state.notify_when_done = true;
        let nothing_pending = state.active_tasks == 0;
        nothing_pending
    }
}
/// Upper bound on the epoch tick interval used when a request timeout is set.
const EPOCH_INTERRUPT_PERIOD: Duration = Duration::from_millis(50);
/// Handle to the background thread that periodically increments the engine's
/// epoch; dropping it stops and joins the thread.
struct EpochThread {
/// Flag telling the thread to exit its loop.
shutdown: Arc<AtomicBool>,
/// Join handle, taken on drop.
handle: Option<std::thread::JoinHandle<()>>,
}
impl EpochThread {
    /// Starts a thread that sleeps for `interval` and then increments the
    /// epoch of `engine`, repeating until asked to shut down.
    fn spawn(interval: std::time::Duration, engine: Engine) -> Self {
        let shutdown = Arc::new(AtomicBool::new(false));
        let stop_requested = Arc::clone(&shutdown);

        let ticker = std::thread::spawn(move || {
            loop {
                if stop_requested.load(Ordering::Relaxed) {
                    break;
                }
                std::thread::sleep(interval);
                engine.increment_epoch();
            }
        });

        EpochThread {
            shutdown,
            handle: Some(ticker),
        }
    }
}
impl Drop for EpochThread {
    /// Signals the ticking thread to stop and waits for it to exit.
    fn drop(&mut self) {
        let Some(ticker) = self.handle.take() else {
            return;
        };
        self.shutdown.store(true, Ordering::Relaxed);
        ticker.join().unwrap();
    }
}
/// One-shot callback, stored in `Host::write_profile`, that is run with the
/// store context when a worker is torn down.
type WriteProfile = Box<dyn FnOnce(StoreContextMut<Host>) + Send>;
/// Configures how `store` reacts to epoch ticks and returns the callback to
/// run when the store's worker is torn down.
///
/// With guest profiling requested this defers to `setup_guest_profiler` (or
/// fails when profiling was compiled out). Otherwise, if a timeout or a
/// debugger is configured, the store is set to yield on each epoch tick, and
/// the returned callback does nothing.
fn setup_epoch_handler(
    cmd: &ServeCommand,
    store: &mut Store<Host>,
    component: Component,
) -> Result<WriteProfile> {
    let opts = &cmd.run;

    if let Some(Profile::Guest { interval, path }) = &opts.profile {
        #[cfg(feature = "profiling")]
        return setup_guest_profiler(store, path.clone(), *interval, component.clone());
        #[cfg(not(feature = "profiling"))]
        {
            let _ = (path, interval);
            bail!("support for profiling disabled at compile time!");
        }
    }

    let wants_yield = opts.common.wasm.timeout.is_some() || opts.common.debug.debugger.is_some();
    if wants_yield {
        store.epoch_deadline_async_yield_and_update(1);
    }

    let nothing_to_write: WriteProfile = Box::new(|_store| {});
    Ok(nothing_to_write)
}
/// Installs a guest profiler on `store` and returns the callback that writes
/// the collected profile to `path`.
///
/// The profiler is stored in the store's `Host` data. It is fed from two
/// places: the store's call hook and the epoch-deadline callback, which takes
/// a sample on every epoch tick.
#[cfg(feature = "profiling")]
fn setup_guest_profiler(
store: &mut Store<Host>,
path: String,
interval: Duration,
component: Component,
) -> Result<WriteProfile> {
use wasmtime::{AsContext, GuestProfiler, StoreContext, StoreContextMut};
let module_name = "<main>";
store.data_mut().guest_profiler = Some(Arc::new(GuestProfiler::new_component(
store.engine(),
module_name,
interval,
component,
std::iter::empty(),
)?));
// Temporarily takes the profiler out of the store data so it can be mutated
// while the store itself is borrowed, then puts it back. Panics if the
// profiler `Arc` is shared.
fn sample(
mut store: StoreContextMut<Host>,
f: impl FnOnce(&mut GuestProfiler, StoreContext<Host>),
) {
let mut profiler = store.data_mut().guest_profiler.take().unwrap();
f(
Arc::get_mut(&mut profiler).expect("profiling doesn't support threads yet"),
store.as_context(),
);
store.data_mut().guest_profiler = Some(profiler);
}
store.call_hook(|store, kind| {
sample(store, |profiler, store| profiler.call_hook(store, kind));
Ok(())
});
// Sample on every epoch tick and keep execution going until the next one.
store.epoch_deadline_callback(move |store| {
sample(store, |profiler, store| {
profiler.sample(store, std::time::Duration::ZERO)
});
Ok(UpdateDeadline::Continue(1))
});
store.set_epoch_deadline(1);
// Consumes the profiler and writes the profile; failures are reported on
// stderr rather than propagated.
let write_profile = Box::new(move |mut store: StoreContextMut<Host>| {
let profiler = Arc::try_unwrap(store.data_mut().guest_profiler.take().unwrap())
.expect("profiling doesn't support threads yet");
if let Err(e) = std::fs::File::create(&path)
.map_err(wasmtime::Error::new)
.and_then(|output| profiler.finish(std::io::BufWriter::new(output)))
{
eprintln!("failed writing profile at {path}: {e:#}");
} else {
eprintln!();
eprintln!("Profile written to: {path}");
eprintln!("View this profile at https://profiler.firefox.com/.");
}
});
Ok(write_profile)
}
/// Incoming HTTP request as delivered by hyper.
type Request = hyper::Request<hyper::body::Incoming>;
/// Serves one client connection over HTTP/1 with keep-alive enabled.
///
/// Each request is passed to `handle_request`. If that fails, the error is
/// printed to stderr and the client receives a generic 500 page, so the
/// service function itself never fails. Connection-level errors are printed
/// to stderr as well.
async fn handle_client(
client: tokio::net::TcpStream,
handler: &ProxyHandler<HostHandlerState>,
debuggee_store: Option<&mut Store<Host>>,
) {
// The service closure may be invoked repeatedly, so the exclusive store
// reference is put behind an async mutex and locked per request.
let lock = &debuggee_store.map(tokio::sync::Mutex::new);
if let Err(e) = http1::Builder::new()
.keep_alive(true)
.serve_connection(
TokioIo::new(client),
hyper::service::service_fn(move |req| async move {
let mut debuggee_store = match &lock {
Some(store) => Some(store.lock().await),
None => None,
};
// Peel the mutex guard down to a plain `&mut Store<Host>`.
let debuggee_store = debuggee_store.as_mut().map(|s| &mut ***s);
match handle_request(handler, debuggee_store, req).await {
Ok(r) => Ok::<_, Infallible>(r),
Err(e) => {
eprintln!("error: {e:?}");
let error_html = "\
<!doctype html>
<html>
<head>
<title>500 Internal Server Error</title>
</head>
<body>
<center>
<h1>500 Internal Server Error</h1>
<hr>
wasmtime
</center>
</body>
</html>";
Ok(Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.header("Content-Type", "text/html; charset=UTF-8")
.body(
Full::new(bytes::Bytes::from(error_html))
.map_err(|_| unreachable!())
.boxed_unsync(),
)
.unwrap())
}
}
}),
)
.await
{
eprintln!("error: {e:?}");
}
}
/// Handles a single HTTP request and produces the response.
///
/// A request permit is held for the whole call, bounding the number of
/// requests in flight. The configured header overrides are applied to the
/// request before it reaches the guest.
///
/// With a `debuggee_store` the component is instantiated into that store and
/// the request is run there directly; otherwise the request is passed to
/// `handler.handle`.
async fn handle_request(
handler: &ProxyHandler<HostHandlerState>,
debuggee_store: Option<&mut Store<Host>>,
mut req: Request,
) -> Result<hyper::Response<UnsyncBoxBody<Bytes, wasmtime::Error>>> {
use wasmtime_wasi_http::p3::bindings::http::types::ErrorCode;
// Propagates an error instead of waiting if the permit cannot be acquired.
let _request_permit = handler.state().sem_requests.acquire().await?;
handler.state().request_headers.apply(req.headers_mut());
let request_id = handler
.state()
.next_request_id
.fetch_add(1, Ordering::Relaxed);
log::info!(
"Received request {request_id}: {} {}",
req.method(),
req.uri()
);
// Convert hyper body errors into the handler's error code type and box the body.
let req = req.map(|body| {
body.map_err(ErrorCode::from_hyper_request_error)
.map_err(handler::ErrorCode::from)
.boxed_unsync()
});
match debuggee_store {
Some(store) => {
let instance = handler.state().instantiate_into(store).await?;
// The response comes back through this one-shot channel.
let (tx, rx) = futures::channel::oneshot::channel();
let prepared = Prepared::new(
store.as_context_mut(),
&instance,
req,
handler.state().view(),
tx,
)?;
store
.run_concurrent(async |store| prepared.run(store, std::future::pending()).await)
.await??;
rx.await?
}
None => handler.handle(request_id, req).await,
}
}
/// Headers from the command line that are forced onto every incoming request.
#[derive(Clone, Default)]
struct RequestHeaders {
/// Parsed `(name, value)` pairs, in the order they were given.
entries: Vec<(HeaderName, HeaderValue)>,
}
impl RequestHeaders {
    /// Parses the `--header` arguments.
    ///
    /// Each argument is either a `NAME: VALUE` header or `@path`, naming a
    /// file that holds one header per line (blank lines are skipped).
    fn parse(headers: &[String]) -> Result<Self> {
        let mut entries = Vec::new();

        for header in headers {
            match header.strip_prefix('@') {
                None => entries.push(parse_header(header)?),
                Some(path) => {
                    let contents = std::fs::read_to_string(path)
                        .with_context(|| format!("failed to read header file `{path}`"))?;
                    for line in contents.lines() {
                        if line.trim().is_empty() {
                            continue;
                        }
                        entries.push(parse_header(line)?);
                    }
                }
            }
        }

        Ok(Self { entries })
    }

    /// Overrides headers in `headers` with the configured ones.
    ///
    /// Existing values for every configured name are removed first, then all
    /// configured values are appended, so a name configured more than once
    /// keeps all of its values.
    fn apply(&self, headers: &mut HeaderMap) {
        for (name, _) in &self.entries {
            headers.remove(name);
        }
        for (name, value) in &self.entries {
            headers.append(name, value.clone());
        }
    }
}
/// Parses one `NAME: VALUE` header.
///
/// The text is split at the first `:`. The name is trimmed on both sides,
/// while the value only loses its leading whitespace.
fn parse_header(header: &str) -> Result<(HeaderName, HeaderValue)> {
    let pieces = header.split_once(':');
    let (raw_name, raw_value) =
        pieces.with_context(|| format!("header `{header}` is missing `:`"))?;

    let name_bytes = raw_name.trim().as_bytes();
    let name = HeaderName::from_bytes(name_bytes)
        .with_context(|| format!("invalid header name in header `{header}`"))?;

    let value_text = raw_value.trim_start();
    let value = HeaderValue::from_str(value_text)
        .with_context(|| format!("invalid header value in header `{header}`"))?;

    Ok((name, value))
}
/// Which of the process's standard streams guest output is forwarded to.
#[derive(Clone)]
enum Output {
    Stdout,
    Stderr,
}

impl Output {
    /// Writes all of `buf` to the selected standard stream.
    fn write_all(&self, buf: &[u8]) -> io::Result<()> {
        use std::io::Write;

        if let Output::Stderr = self {
            return std::io::stderr().write_all(buf);
        }
        std::io::stdout().write_all(buf)
    }
}
/// Guest stdout/stderr sink that prefixes each line with a fixed string.
///
/// Clones share the same `LogStreamState`, so the "start of line" tracking is
/// common to all of them.
#[derive(Clone)]
struct LogStream {
/// Standard stream the bytes end up on.
output: Output,
/// Prefix and line-tracking state shared between clones.
state: Arc<LogStreamState>,
}
/// State shared by all clones of a `LogStream`.
struct LogStreamState {
/// Text written at the start of every line.
prefix: String,
/// Whether the next non-empty write must begin with `prefix`; initially
/// true, and set again after each newline.
needs_prefix_on_next_write: AtomicBool,
}
impl LogStream {
    /// Creates a stream writing to `output` that prefixes each line with `prefix`.
    fn new(prefix: String, output: Output) -> LogStream {
        let state = LogStreamState {
            prefix,
            needs_prefix_on_next_write: AtomicBool::new(true),
        };
        LogStream {
            output,
            state: Arc::new(state),
        }
    }

    /// Writes `bytes`, inserting the prefix at the start of every line.
    ///
    /// Whether a line is currently open is remembered across calls, so a line
    /// written in several pieces is prefixed only once.
    fn write_all(&mut self, mut bytes: &[u8]) -> io::Result<()> {
        let at_line_start = &self.state.needs_prefix_on_next_write;

        while !bytes.is_empty() {
            if at_line_start.load(Ordering::Relaxed) {
                self.output.write_all(self.state.prefix.as_bytes())?;
                at_line_start.store(false, Ordering::Relaxed);
            }

            // No newline left: the remainder is an unfinished line.
            let Some(newline) = bytes.iter().position(|&b| b == b'\n') else {
                self.output.write_all(bytes)?;
                break;
            };

            let (line, rest) = bytes.split_at(newline + 1);
            bytes = rest;
            self.output.write_all(line)?;
            at_line_start.store(true, Ordering::Relaxed);
        }

        Ok(())
    }
}
impl wasmtime_wasi::cli::StdoutStream for LogStream {
    /// Hands out a clone of this stream as a p2 output stream.
    fn p2_stream(&self) -> Box<dyn wasmtime_wasi::p2::OutputStream> {
        let stream = self.clone();
        Box::new(stream)
    }

    /// Hands out a clone of this stream as an async writer.
    fn async_stream(&self) -> Box<dyn AsyncWrite + Send + Sync> {
        let stream = self.clone();
        Box::new(stream)
    }
}
impl wasmtime_wasi::cli::IsTerminal for LogStream {
    /// Reports whether the underlying standard stream is a terminal.
    fn is_terminal(&self) -> bool {
        if let Output::Stderr = &self.output {
            std::io::stderr().is_terminal()
        } else {
            std::io::stdout().is_terminal()
        }
    }
}
impl wasmtime_wasi::p2::OutputStream for LogStream {
fn write(&mut self, bytes: bytes::Bytes) -> StreamResult<()> {
self.write_all(&bytes)
.map_err(|e| StreamError::LastOperationFailed(e.into()))?;
Ok(())
}
fn flush(&mut self) -> StreamResult<()> {
Ok(())
}
fn check_write(&mut self) -> StreamResult<usize> {
Ok(1024 * 1024)
}
}
#[async_trait::async_trait]
impl wasmtime_wasi::p2::Pollable for LogStream {
// Completes immediately: this implementation never waits for readiness.
async fn ready(&mut self) {}
}
impl AsyncWrite for LogStream {
fn poll_write(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Poll::Ready(self.write_all(buf).map(|_| buf.len()))
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
}
/// Probes whether the pooling allocator should be enabled by default.
///
/// The probe creates a throwaway engine with a memory reservation of
/// `1 << 42` bytes and tries to allocate a 64-bit memory in it. Success
/// yields `Some(true)`; failure yields `None`, i.e. no preference.
fn use_pooling_allocator_by_default() -> Result<Option<bool>> {
    use wasmtime::{Config, Memory, MemoryType};

    const BITS_TO_TEST: u32 = 42;

    let mut config = Config::new();
    config.wasm_memory64(true);
    config.memory_reservation(1 << BITS_TO_TEST);
    let engine = Engine::new(&config)?;
    let mut store = Store::new(&engine, ());

    // The maximum is given in wasm pages, hence the `- 16`
    // (presumably 64 KiB pages; confirm against `MemoryType::new64`).
    let max_pages = 1 << (BITS_TO_TEST - 16);
    let ty = MemoryType::new64(0, Some(max_pages));

    let reservation_works = Memory::new(&mut store, ty).is_ok();
    Ok(reservation_works.then_some(true))
}