Skip to main content

wasmtime_cli/commands/
serve.rs

1use crate::common::{Profile, RunCommon, RunTarget};
2use bytes::Bytes;
3use clap::Parser;
4use futures::future::FutureExt;
5use http::{Response, StatusCode};
6use http_body_util::BodyExt as _;
7use http_body_util::combinators::UnsyncBoxBody;
8use std::convert::Infallible;
9use std::net::SocketAddr;
10use std::pin::Pin;
11use std::task::{Context, Poll};
12use std::{
13    path::PathBuf,
14    sync::{
15        Arc, Mutex,
16        atomic::{AtomicBool, Ordering},
17    },
18    time::Duration,
19};
20use tokio::io::{self, AsyncWrite};
21use tokio::sync::Notify;
22use wasmtime::component::{Component, Linker, ResourceTable};
23use wasmtime::{
24    Engine, Result, Store, StoreContextMut, StoreLimits, UpdateDeadline, bail, error::Context as _,
25};
26use wasmtime_cli_flags::opt::WasmtimeOptionValue;
27use wasmtime_wasi::p2::{StreamError, StreamResult};
28use wasmtime_wasi::{WasiCtx, WasiCtxBuilder, WasiCtxView, WasiView};
29#[cfg(feature = "component-model-async")]
30use wasmtime_wasi_http::handler::p2::bindings as p2;
31use wasmtime_wasi_http::handler::{HandlerState, Proxy, ProxyHandler, ProxyPre, StoreBundle};
32use wasmtime_wasi_http::io::TokioIo;
33use wasmtime_wasi_http::{
34    DEFAULT_OUTGOING_BODY_BUFFER_CHUNKS, DEFAULT_OUTGOING_BODY_CHUNK_SIZE, WasiHttpCtx,
35    WasiHttpView,
36};
37
38#[cfg(feature = "wasi-config")]
39use wasmtime_wasi_config::{WasiConfig, WasiConfigVariables};
40#[cfg(feature = "wasi-keyvalue")]
41use wasmtime_wasi_keyvalue::{WasiKeyValue, WasiKeyValueCtx, WasiKeyValueCtxBuilder};
42#[cfg(feature = "wasi-nn")]
43use wasmtime_wasi_nn::wit::WasiNnCtx;
44
/// Default for `--max-instance-reuse-count` when serving a WASIp3 component.
const DEFAULT_WASIP3_MAX_INSTANCE_REUSE_COUNT: usize = 128;
/// Default for `--max-instance-reuse-count` when serving a WASIp2 component:
/// each instance handles a single request.
const DEFAULT_WASIP2_MAX_INSTANCE_REUSE_COUNT: usize = 1;
/// Default for `--max-instance-concurrent-reuse-count` when serving a WASIp3
/// component (WASIp2 components are always limited to 1 in `serve`).
const DEFAULT_WASIP3_MAX_INSTANCE_CONCURRENT_REUSE_COUNT: usize = 16;
48
/// Per-store host state for `wasmtime serve`: WASI and wasi-http contexts,
/// the resource table, store limits, and optional contexts for feature-gated
/// WASI proposals.
struct Host {
    /// Resource table used by WASI, wasi-http and the optional proposals.
    table: wasmtime::component::ResourceTable,
    /// WASI context (stdio, env, ...) built in `ServeCommand::new_store`.
    ctx: WasiCtx,
    /// WASIp2 HTTP context.
    http: WasiHttpCtx,
    /// Override for the outgoing-body buffer chunk count; `None` means the
    /// wasi-http default.
    http_outgoing_body_buffer_chunks: Option<usize>,
    /// Override for the outgoing-body chunk size; `None` means the wasi-http
    /// default.
    http_outgoing_body_chunk_size: Option<usize>,

    /// WASIp3 HTTP context.
    #[cfg(feature = "component-model-async")]
    p3_http: crate::common::DefaultP3Ctx,

    /// Store limits installed via `Store::limiter`.
    limits: StoreLimits,

    /// wasi-nn context; `Some` only when wasi-nn is enabled.
    #[cfg(feature = "wasi-nn")]
    nn: Option<WasiNnCtx>,

    /// wasi-config variables; `Some` only when wasi-config is enabled.
    #[cfg(feature = "wasi-config")]
    wasi_config: Option<WasiConfigVariables>,

    /// wasi-keyvalue context; `Some` only when wasi-keyvalue is enabled.
    #[cfg(feature = "wasi-keyvalue")]
    wasi_keyvalue: Option<WasiKeyValueCtx>,

    /// Guest profiler; `Some` only while guest profiling is active.
    #[cfg(feature = "profiling")]
    guest_profiler: Option<Arc<wasmtime::GuestProfiler>>,
}
73
74impl WasiView for Host {
75    fn ctx(&mut self) -> WasiCtxView<'_> {
76        WasiCtxView {
77            ctx: &mut self.ctx,
78            table: &mut self.table,
79        }
80    }
81}
82
83impl WasiHttpView for Host {
84    fn ctx(&mut self) -> &mut WasiHttpCtx {
85        &mut self.http
86    }
87    fn table(&mut self) -> &mut ResourceTable {
88        &mut self.table
89    }
90
91    fn outgoing_body_buffer_chunks(&mut self) -> usize {
92        self.http_outgoing_body_buffer_chunks
93            .unwrap_or_else(|| DEFAULT_OUTGOING_BODY_BUFFER_CHUNKS)
94    }
95
96    fn outgoing_body_chunk_size(&mut self) -> usize {
97        self.http_outgoing_body_chunk_size
98            .unwrap_or_else(|| DEFAULT_OUTGOING_BODY_CHUNK_SIZE)
99    }
100}
101
#[cfg(feature = "component-model-async")]
impl wasmtime_wasi_http::p3::WasiHttpView for Host {
    /// Borrows the WASIp3 HTTP context together with the shared resource
    /// table.
    fn http(&mut self) -> wasmtime_wasi_http::p3::WasiHttpCtxView<'_> {
        let Self { table, p3_http, .. } = self;
        wasmtime_wasi_http::p3::WasiHttpCtxView {
            ctx: p3_http,
            table,
        }
    }
}
111
/// Address the server binds to when `--addr` is not given: port 8080 on the
/// unspecified IPv4 address (`0.0.0.0`, i.e. all IPv4 interfaces).
const DEFAULT_ADDR: std::net::SocketAddr = std::net::SocketAddr::V4(
    std::net::SocketAddrV4::new(std::net::Ipv4Addr::UNSPECIFIED, 8080),
);
116
117fn parse_duration(s: &str) -> Result<Duration, String> {
118    Duration::parse(Some(s)).map_err(|e| e.to_string())
119}
120
// NOTE(review): the `///` text on this struct and its fields is turned into
// user-visible clap help text, so it is left untouched here. The struct-level
// line reads like a copy of `wasmtime run`'s; this command actually serves
// HTTP requests with a wasi-http component. Confirm which about-text users
// see (a doc comment on the subcommand's enum variant may override it).
/// Runs a WebAssembly module
#[derive(Parser)]
pub struct ServeCommand {
    // Options shared with `wasmtime run` (`-S`, `-W`, ...), flattened into
    // this command's own flags.
    #[command(flatten)]
    run: RunCommon,

    // Defaults to `DEFAULT_ADDR`.
    /// Socket address for the web server to bind to.
    #[arg(long , value_name = "SOCKADDR", default_value_t = DEFAULT_ADDR)]
    addr: SocketAddr,

    // When set, `serve` spawns a listener on this address and treats the
    // first accepted connection as a shutdown request.
    /// Socket address where, when connected to, will initiate a graceful
    /// shutdown.
    ///
    /// Note that graceful shutdown is also supported on ctrl-c.
    #[arg(long, value_name = "SOCKADDR")]
    shutdown_addr: Option<SocketAddr>,

    // NOTE(review): `new_store` only applies a prefix when it is given a
    // request id, and `handle_request` only passes one when the effective
    // max instance reuse count is 1. With instance reuse (e.g. the WASIp3
    // default) no prefix is printed even without this flag, so the help text
    // below may overstate things — confirm.
    /// Disable log prefixes of wasi-http handlers.
    /// if unspecified, logs will be prefixed with 'stdout|stderr [{req_id}] :: '
    #[arg(long)]
    no_logging_prefix: bool,

    // Loaded in `serve`; a core module is rejected there.
    /// The WebAssembly component to run.
    #[arg(value_name = "WASM", required = true)]
    component: PathBuf,

    // `None` is resolved in `serve` to DEFAULT_WASIP2_MAX_INSTANCE_REUSE_COUNT
    // or DEFAULT_WASIP3_MAX_INSTANCE_REUSE_COUNT depending on the component.
    /// Maximum number of requests to send to a single component instance before
    /// dropping it.
    ///
    /// This defaults to 1 for WASIp2 components and 128 for WASIp3 components.
    #[arg(long)]
    max_instance_reuse_count: Option<usize>,

    // Only consulted for WASIp3 components; `serve` forces 1 for WASIp2.
    /// Maximum number of concurrent requests to send to a single component
    /// instance.
    ///
    /// This defaults to 1 for WASIp2 components and 16 for WASIp3 components.
    /// Note that setting it to more than 1 will have no effect for WASIp2
    /// components since they cannot be called concurrently.
    #[arg(long)]
    max_instance_concurrent_reuse_count: Option<usize>,

    // Parsed by `parse_duration`.
    /// Time to hold an idle component instance for possible reuse before
    /// dropping it.
    ///
    /// A number with no suffix or with an `s` suffix is interpreted as seconds;
    /// other accepted suffixes include `ms` (milliseconds), `us` or `μs`
    /// (microseconds), and `ns` (nanoseconds).
    #[arg(long, default_value = "1s", value_parser = parse_duration)]
    idle_instance_timeout: Duration,
}
167    /// other accepted suffixes include `ms` (milliseconds), `us` or `μs`
168    /// (microseconds), and `ns` (nanoseconds).
169    #[arg(long, default_value = "1s", value_parser = parse_duration)]
170    idle_instance_timeout: Duration,
171}
172
173impl ServeCommand {
174    /// Start a server to run the given wasi-http proxy component
175    pub fn execute(mut self) -> Result<()> {
176        self.run.common.init_logging()?;
177
178        // We force cli errors before starting to listen for connections so then
179        // we don't accidentally delay them to the first request.
180
181        if self.run.common.wasi.nn == Some(true) {
182            #[cfg(not(feature = "wasi-nn"))]
183            {
184                bail!("Cannot enable wasi-nn when the binary is not compiled with this feature.");
185            }
186        }
187
188        if self.run.common.wasi.threads == Some(true) {
189            bail!("wasi-threads does not support components yet")
190        }
191
192        // The serve command requires both wasi-http and the component model, so
193        // we enable those by default here.
194        if self.run.common.wasi.http.replace(true) == Some(false) {
195            bail!("wasi-http is required for the serve command, and must not be disabled");
196        }
197        if self.run.common.wasm.component_model.replace(true) == Some(false) {
198            bail!("components are required for the serve command, and must not be disabled");
199        }
200
201        let runtime = tokio::runtime::Builder::new_multi_thread()
202            .enable_time()
203            .enable_io()
204            .build()?;
205
206        runtime.block_on(self.serve())?;
207
208        Ok(())
209    }
210
211    fn new_store(&self, engine: &Engine, req_id: Option<u64>) -> Result<Store<Host>> {
212        let mut builder = WasiCtxBuilder::new();
213        self.run.configure_wasip2(&mut builder)?;
214
215        if let Some(req_id) = req_id {
216            builder.env("REQUEST_ID", req_id.to_string());
217        }
218
219        let stdout_prefix: String;
220        let stderr_prefix: String;
221        match req_id {
222            Some(req_id) if !self.no_logging_prefix => {
223                stdout_prefix = format!("stdout [{req_id}] :: ");
224                stderr_prefix = format!("stderr [{req_id}] :: ");
225            }
226            _ => {
227                stdout_prefix = "".to_string();
228                stderr_prefix = "".to_string();
229            }
230        }
231        builder.stdout(LogStream::new(stdout_prefix, Output::Stdout));
232        builder.stderr(LogStream::new(stderr_prefix, Output::Stderr));
233
234        let mut table = wasmtime::component::ResourceTable::new();
235        if let Some(max) = self.run.common.wasi.max_resources {
236            table.set_max_capacity(max);
237        }
238        let mut host = Host {
239            table,
240            ctx: builder.build(),
241            http: self.run.wasi_http_ctx()?,
242            http_outgoing_body_buffer_chunks: self.run.common.wasi.http_outgoing_body_buffer_chunks,
243            http_outgoing_body_chunk_size: self.run.common.wasi.http_outgoing_body_chunk_size,
244
245            limits: StoreLimits::default(),
246
247            #[cfg(feature = "wasi-nn")]
248            nn: None,
249            #[cfg(feature = "wasi-config")]
250            wasi_config: None,
251            #[cfg(feature = "wasi-keyvalue")]
252            wasi_keyvalue: None,
253            #[cfg(feature = "profiling")]
254            guest_profiler: None,
255            #[cfg(feature = "component-model-async")]
256            p3_http: crate::common::DefaultP3Ctx,
257        };
258
259        if self.run.common.wasi.nn == Some(true) {
260            #[cfg(feature = "wasi-nn")]
261            {
262                let graphs = self
263                    .run
264                    .common
265                    .wasi
266                    .nn_graph
267                    .iter()
268                    .map(|g| (g.format.clone(), g.dir.clone()))
269                    .collect::<Vec<_>>();
270                let (backends, registry) = wasmtime_wasi_nn::preload(&graphs)?;
271                host.nn.replace(WasiNnCtx::new(backends, registry));
272            }
273        }
274
275        if self.run.common.wasi.config == Some(true) {
276            #[cfg(feature = "wasi-config")]
277            {
278                let vars = WasiConfigVariables::from_iter(
279                    self.run
280                        .common
281                        .wasi
282                        .config_var
283                        .iter()
284                        .map(|v| (v.key.clone(), v.value.clone())),
285                );
286                host.wasi_config.replace(vars);
287            }
288        }
289
290        if self.run.common.wasi.keyvalue == Some(true) {
291            #[cfg(feature = "wasi-keyvalue")]
292            {
293                let ctx = WasiKeyValueCtxBuilder::new()
294                    .in_memory_data(
295                        self.run
296                            .common
297                            .wasi
298                            .keyvalue_in_memory_data
299                            .iter()
300                            .map(|v| (v.key.clone(), v.value.clone())),
301                    )
302                    .build();
303                host.wasi_keyvalue.replace(ctx);
304            }
305        }
306
307        let mut store = Store::new(engine, host);
308
309        if let Some(fuel) = self.run.common.wasi.hostcall_fuel {
310            store.set_hostcall_fuel(fuel);
311        }
312
313        store.data_mut().limits = self.run.store_limits();
314        store.limiter(|t| &mut t.limits);
315
316        // If fuel has been configured, we want to add the configured
317        // fuel amount to this store.
318        if let Some(fuel) = self.run.common.wasm.fuel {
319            store.set_fuel(fuel)?;
320        }
321
322        Ok(store)
323    }
324
325    fn add_to_linker(&self, linker: &mut Linker<Host>) -> Result<()> {
326        self.run.validate_p3_option()?;
327        let cli = self.run.validate_cli_enabled()?;
328
329        // Repurpose the `-Scli` flag of `wasmtime run` for `wasmtime serve`
330        // to serve as a signal to enable all WASI interfaces instead of just
331        // those in the `proxy` world. If `-Scli` is present then add all
332        // `command` APIs and then additionally add in the required HTTP APIs.
333        //
334        // If `-Scli` isn't passed then use the `add_to_linker_async`
335        // bindings which adds just those interfaces that the proxy interface
336        // uses.
337        if cli == Some(true) {
338            self.run.add_wasmtime_wasi_to_linker(linker)?;
339            wasmtime_wasi_http::add_only_http_to_linker_async(linker)?;
340            #[cfg(feature = "component-model-async")]
341            if self.run.common.wasi.p3.unwrap_or(crate::common::P3_DEFAULT) {
342                wasmtime_wasi_http::p3::add_to_linker(linker)?;
343            }
344        } else {
345            wasmtime_wasi_http::add_to_linker_async(linker)?;
346            #[cfg(feature = "component-model-async")]
347            if self.run.common.wasi.p3.unwrap_or(crate::common::P3_DEFAULT) {
348                wasmtime_wasi_http::p3::add_to_linker(linker)?;
349                wasmtime_wasi::p3::clocks::add_to_linker(linker)?;
350                wasmtime_wasi::p3::random::add_to_linker(linker)?;
351                wasmtime_wasi::p3::cli::add_to_linker(linker)?;
352            }
353        }
354
355        if self.run.common.wasi.nn == Some(true) {
356            #[cfg(not(feature = "wasi-nn"))]
357            {
358                bail!("support for wasi-nn was disabled at compile time");
359            }
360            #[cfg(feature = "wasi-nn")]
361            {
362                wasmtime_wasi_nn::wit::add_to_linker(linker, |h: &mut Host| {
363                    let ctx = h.nn.as_mut().unwrap();
364                    wasmtime_wasi_nn::wit::WasiNnView::new(&mut h.table, ctx)
365                })?;
366            }
367        }
368
369        if self.run.common.wasi.config == Some(true) {
370            #[cfg(not(feature = "wasi-config"))]
371            {
372                bail!("support for wasi-config was disabled at compile time");
373            }
374            #[cfg(feature = "wasi-config")]
375            {
376                wasmtime_wasi_config::add_to_linker(linker, |h| {
377                    WasiConfig::from(h.wasi_config.as_ref().unwrap())
378                })?;
379            }
380        }
381
382        if self.run.common.wasi.keyvalue == Some(true) {
383            #[cfg(not(feature = "wasi-keyvalue"))]
384            {
385                bail!("support for wasi-keyvalue was disabled at compile time");
386            }
387            #[cfg(feature = "wasi-keyvalue")]
388            {
389                wasmtime_wasi_keyvalue::add_to_linker(linker, |h: &mut Host| {
390                    WasiKeyValue::new(h.wasi_keyvalue.as_ref().unwrap(), &mut h.table)
391                })?;
392            }
393        }
394
395        if self.run.common.wasi.threads == Some(true) {
396            bail!("support for wasi-threads is not available with components");
397        }
398
399        if self.run.common.wasi.http == Some(false) {
400            bail!("support for wasi-http must be enabled for `serve` subcommand");
401        }
402
403        Ok(())
404    }
405
406    async fn serve(mut self) -> Result<()> {
407        use hyper::server::conn::http1;
408
409        let mut config = self
410            .run
411            .common
412            .config(use_pooling_allocator_by_default().unwrap_or(None))?;
413        config.wasm_component_model(true);
414
415        if self.run.common.wasm.timeout.is_some() {
416            config.epoch_interruption(true);
417        }
418
419        match self.run.profile {
420            Some(Profile::Native(s)) => {
421                config.profiler(s);
422            }
423            Some(Profile::Guest { .. }) => {
424                config.epoch_interruption(true);
425            }
426            None => {}
427        }
428
429        let engine = Engine::new(&config)?;
430        let mut linker = Linker::new(&engine);
431
432        self.add_to_linker(&mut linker)?;
433
434        let component = match self.run.load_module(&engine, &self.component)? {
435            RunTarget::Core(_) => bail!("The serve command currently requires a component"),
436            RunTarget::Component(c) => c,
437        };
438
439        let instance = linker.instantiate_pre(&component)?;
440        #[cfg(feature = "component-model-async")]
441        let instance = match wasmtime_wasi_http::p3::bindings::ServicePre::new(instance.clone()) {
442            Ok(pre) => ProxyPre::P3(pre),
443            Err(_) => ProxyPre::P2(p2::ProxyPre::new(instance)?),
444        };
445        #[cfg(not(feature = "component-model-async"))]
446        let instance = ProxyPre::P2(p2::ProxyPre::new(instance)?);
447
448        // Spawn background task(s) waiting for graceful shutdown signals. This
449        // always listens for ctrl-c but additionally can listen for a TCP
450        // connection to the specified address.
451        let shutdown = Arc::new(GracefulShutdown::default());
452        tokio::task::spawn({
453            let shutdown = shutdown.clone();
454            async move {
455                tokio::signal::ctrl_c().await.unwrap();
456                shutdown.requested.notify_one();
457            }
458        });
459        if let Some(addr) = self.shutdown_addr {
460            let listener = tokio::net::TcpListener::bind(addr).await?;
461            eprintln!(
462                "Listening for shutdown on tcp://{}/",
463                listener.local_addr()?
464            );
465            let shutdown = shutdown.clone();
466            tokio::task::spawn(async move {
467                let _ = listener.accept().await;
468                shutdown.requested.notify_one();
469            });
470        }
471
472        let socket = match &self.addr {
473            SocketAddr::V4(_) => tokio::net::TcpSocket::new_v4()?,
474            SocketAddr::V6(_) => tokio::net::TcpSocket::new_v6()?,
475        };
476        // Conditionally enable `SO_REUSEADDR` depending on the current
477        // platform. On Unix we want this to be able to rebind an address in
478        // the `TIME_WAIT` state which can happen then a server is killed with
479        // active TCP connections and then restarted. On Windows though if
480        // `SO_REUSEADDR` is specified then it enables multiple applications to
481        // bind the port at the same time which is not something we want. Hence
482        // this is conditionally set based on the platform (and deviates from
483        // Tokio's default from always-on).
484        socket.set_reuseaddr(!cfg!(windows))?;
485        socket.bind(self.addr)?;
486        let listener = socket.listen(100)?;
487
488        eprintln!("Serving HTTP on http://{}/", listener.local_addr()?);
489
490        log::info!("Listening on {}", self.addr);
491
492        let epoch_interval = if let Some(Profile::Guest { interval, .. }) = self.run.profile {
493            Some(interval)
494        } else if let Some(t) = self.run.common.wasm.timeout {
495            Some(EPOCH_INTERRUPT_PERIOD.min(t))
496        } else {
497            None
498        };
499        let _epoch_thread = epoch_interval.map(|t| EpochThread::spawn(t, engine.clone()));
500
501        let max_instance_reuse_count = self.max_instance_reuse_count.unwrap_or_else(|| {
502            if let ProxyPre::P3(_) = &instance {
503                DEFAULT_WASIP3_MAX_INSTANCE_REUSE_COUNT
504            } else {
505                DEFAULT_WASIP2_MAX_INSTANCE_REUSE_COUNT
506            }
507        });
508
509        let max_instance_concurrent_reuse_count = if let ProxyPre::P3(_) = &instance {
510            self.max_instance_concurrent_reuse_count
511                .unwrap_or(DEFAULT_WASIP3_MAX_INSTANCE_CONCURRENT_REUSE_COUNT)
512        } else {
513            1
514        };
515
516        let handler = ProxyHandler::new(
517            HostHandlerState {
518                cmd: self,
519                engine,
520                component,
521                max_instance_reuse_count,
522                max_instance_concurrent_reuse_count,
523            },
524            instance,
525        );
526
527        loop {
528            // Wait for a socket, but also "race" against shutdown to break out
529            // of this loop. Once the graceful shutdown signal is received then
530            // this loop exits immediately.
531            let (stream, _) = tokio::select! {
532                _ = shutdown.requested.notified() => break,
533                v = listener.accept() => v?,
534            };
535
536            // The Nagle algorithm can impose a significant latency penalty
537            // (e.g. 40ms on Linux) on guests which write small, intermittent
538            // response body chunks (e.g. SSE streams).  Here we disable that
539            // algorithm and rely on the guest to buffer if appropriate to avoid
540            // TCP fragmentation.
541            stream.set_nodelay(true)?;
542
543            let stream = TokioIo::new(stream);
544            let h = handler.clone();
545            let shutdown_guard = shutdown.clone().increment();
546            tokio::task::spawn(async move {
547                if let Err(e) = http1::Builder::new()
548                    .keep_alive(true)
549                    .serve_connection(
550                        stream,
551                        hyper::service::service_fn(move |req| {
552                            let h = h.clone();
553                            async move {
554                                use http_body_util::{BodyExt, Full};
555                                match handle_request(h, req).await {
556                                    Ok(r) => Ok::<_, Infallible>(r),
557                                    Err(e) => {
558                                        eprintln!("error: {e:?}");
559                                        let error_html = "\
560<!doctype html>
561<html>
562<head>
563    <title>500 Internal Server Error</title>
564</head>
565<body>
566    <center>
567        <h1>500 Internal Server Error</h1>
568        <hr>
569        wasmtime
570    </center>
571</body>
572</html>";
573                                        Ok(Response::builder()
574                                            .status(StatusCode::INTERNAL_SERVER_ERROR)
575                                            .header("Content-Type", "text/html; charset=UTF-8")
576                                            .body(
577                                                Full::new(bytes::Bytes::from(error_html))
578                                                    .map_err(|_| unreachable!())
579                                                    .boxed_unsync(),
580                                            )
581                                            .unwrap())
582                                    }
583                                }
584                            }
585                        }),
586                    )
587                    .await
588                {
589                    eprintln!("error: {e:?}");
590                }
591                drop(shutdown_guard);
592            });
593        }
594
595        // Upon exiting the loop we'll no longer process any more incoming
596        // connections but there may still be outstanding connections
597        // processing in child tasks. If there are wait for those to complete
598        // before shutting down completely. Also enable short-circuiting this
599        // wait with a second ctrl-c signal.
600        if shutdown.close() {
601            return Ok(());
602        }
603        eprintln!("Waiting for child tasks to exit, ctrl-c again to quit sooner...");
604        tokio::select! {
605            _ = tokio::signal::ctrl_c() => {}
606            _ = shutdown.complete.notified() => {}
607        }
608
609        Ok(())
610    }
611}
612
613struct HostHandlerState {
614    cmd: ServeCommand,
615    engine: Engine,
616    component: Component,
617    max_instance_reuse_count: usize,
618    max_instance_concurrent_reuse_count: usize,
619}
620
621impl HandlerState for HostHandlerState {
622    type StoreData = Host;
623
624    fn new_store(&self, req_id: Option<u64>) -> Result<StoreBundle<Host>> {
625        let mut store = self.cmd.new_store(&self.engine, req_id)?;
626        let write_profile = setup_epoch_handler(&self.cmd, &mut store, self.component.clone())?;
627
628        Ok(StoreBundle {
629            store,
630            write_profile,
631        })
632    }
633
634    fn request_timeout(&self) -> Duration {
635        self.cmd.run.common.wasm.timeout.unwrap_or(Duration::MAX)
636    }
637
638    fn idle_instance_timeout(&self) -> Duration {
639        self.cmd.idle_instance_timeout
640    }
641
642    fn max_instance_reuse_count(&self) -> usize {
643        self.max_instance_reuse_count
644    }
645
646    fn max_instance_concurrent_reuse_count(&self) -> usize {
647        self.max_instance_concurrent_reuse_count
648    }
649
650    fn handle_worker_error(&self, error: wasmtime::Error) {
651        eprintln!("worker error: {error}");
652    }
653}
654
655/// Helper structure to manage graceful shutdown int he accept loop above.
656#[derive(Default)]
657struct GracefulShutdown {
658    /// Async notification that shutdown has been requested.
659    requested: Notify,
660    /// Async notification that shutdown has completed, signaled when
661    /// `notify_when_done` is `true` and `active_tasks` reaches 0.
662    complete: Notify,
663    /// Internal state related to what's in progress when shutdown is requested.
664    state: Mutex<GracefulShutdownState>,
665}
666
667#[derive(Default)]
668struct GracefulShutdownState {
669    active_tasks: u32,
670    notify_when_done: bool,
671}
672
673impl GracefulShutdown {
674    /// Increments the number of active tasks and returns a guard indicating
675    fn increment(self: Arc<Self>) -> impl Drop {
676        struct Guard(Arc<GracefulShutdown>);
677
678        let mut state = self.state.lock().unwrap();
679        assert!(!state.notify_when_done);
680        state.active_tasks += 1;
681        drop(state);
682
683        return Guard(self);
684
685        impl Drop for Guard {
686            fn drop(&mut self) {
687                let mut state = self.0.state.lock().unwrap();
688                state.active_tasks -= 1;
689                if state.notify_when_done && state.active_tasks == 0 {
690                    self.0.complete.notify_one();
691                }
692            }
693        }
694    }
695
696    /// Flags this state as done spawning tasks and returns whether there are no
697    /// more child tasks remaining.
698    fn close(&self) -> bool {
699        let mut state = self.state.lock().unwrap();
700        state.notify_when_done = true;
701        state.active_tasks == 0
702    }
703}
704
705/// When executing with a timeout enabled, this is how frequently epoch
706/// interrupts will be executed to check for timeouts. If guest profiling
707/// is enabled, the guest epoch period will be used.
708const EPOCH_INTERRUPT_PERIOD: Duration = Duration::from_millis(50);
709
710struct EpochThread {
711    shutdown: Arc<AtomicBool>,
712    handle: Option<std::thread::JoinHandle<()>>,
713}
714
715impl EpochThread {
716    fn spawn(interval: std::time::Duration, engine: Engine) -> Self {
717        let shutdown = Arc::new(AtomicBool::new(false));
718        let handle = {
719            let shutdown = Arc::clone(&shutdown);
720            let handle = std::thread::spawn(move || {
721                while !shutdown.load(Ordering::Relaxed) {
722                    std::thread::sleep(interval);
723                    engine.increment_epoch();
724                }
725            });
726            Some(handle)
727        };
728
729        EpochThread { shutdown, handle }
730    }
731}
732
733impl Drop for EpochThread {
734    fn drop(&mut self) {
735        if let Some(handle) = self.handle.take() {
736            self.shutdown.store(true, Ordering::Relaxed);
737            handle.join().unwrap();
738        }
739    }
740}
741
742type WriteProfile = Box<dyn FnOnce(StoreContextMut<Host>) + Send>;
743
744fn setup_epoch_handler(
745    cmd: &ServeCommand,
746    store: &mut Store<Host>,
747    component: Component,
748) -> Result<WriteProfile> {
749    // Profiling Enabled
750    if let Some(Profile::Guest { interval, path }) = &cmd.run.profile {
751        #[cfg(feature = "profiling")]
752        return setup_guest_profiler(store, path.clone(), *interval, component.clone());
753        #[cfg(not(feature = "profiling"))]
754        {
755            let _ = (path, interval);
756            bail!("support for profiling disabled at compile time!");
757        }
758    }
759
760    // Profiling disabled but there's a global request timeout
761    if cmd.run.common.wasm.timeout.is_some() {
762        store.epoch_deadline_async_yield_and_update(1);
763    }
764
765    Ok(Box::new(|_store| {}))
766}
767
768#[cfg(feature = "profiling")]
769fn setup_guest_profiler(
770    store: &mut Store<Host>,
771    path: String,
772    interval: Duration,
773    component: Component,
774) -> Result<WriteProfile> {
775    use wasmtime::{AsContext, GuestProfiler, StoreContext, StoreContextMut};
776
777    let module_name = "<main>";
778
779    store.data_mut().guest_profiler = Some(Arc::new(GuestProfiler::new_component(
780        store.engine(),
781        module_name,
782        interval,
783        component,
784        std::iter::empty(),
785    )?));
786
787    fn sample(
788        mut store: StoreContextMut<Host>,
789        f: impl FnOnce(&mut GuestProfiler, StoreContext<Host>),
790    ) {
791        let mut profiler = store.data_mut().guest_profiler.take().unwrap();
792        f(
793            Arc::get_mut(&mut profiler).expect("profiling doesn't support threads yet"),
794            store.as_context(),
795        );
796        store.data_mut().guest_profiler = Some(profiler);
797    }
798
799    // Hostcall entry/exit, etc.
800    store.call_hook(|store, kind| {
801        sample(store, |profiler, store| profiler.call_hook(store, kind));
802        Ok(())
803    });
804
805    store.epoch_deadline_callback(move |store| {
806        sample(store, |profiler, store| {
807            profiler.sample(store, std::time::Duration::ZERO)
808        });
809
810        Ok(UpdateDeadline::Continue(1))
811    });
812
813    store.set_epoch_deadline(1);
814
815    let write_profile = Box::new(move |mut store: StoreContextMut<Host>| {
816        let profiler = Arc::try_unwrap(store.data_mut().guest_profiler.take().unwrap())
817            .expect("profiling doesn't support threads yet");
818        if let Err(e) = std::fs::File::create(&path)
819            .map_err(wasmtime::Error::new)
820            .and_then(|output| profiler.finish(std::io::BufWriter::new(output)))
821        {
822            eprintln!("failed writing profile at {path}: {e:#}");
823        } else {
824            eprintln!();
825            eprintln!("Profile written to: {path}");
826            eprintln!("View this profile at https://profiler.firefox.com/.");
827        }
828    });
829
830    Ok(write_profile)
831}
832
833type Request = hyper::Request<hyper::body::Incoming>;
834
835async fn handle_request(
836    handler: ProxyHandler<HostHandlerState>,
837    req: Request,
838) -> Result<hyper::Response<UnsyncBoxBody<Bytes, wasmtime::Error>>> {
839    use tokio::sync::oneshot;
840
841    let req_id = handler.next_req_id();
842
843    log::info!(
844        "Request {req_id} handling {} to {}",
845        req.method(),
846        req.uri()
847    );
848
849    // Here we must declare different channel types for p2 and p3 since p2's
850    // `WasiHttpView::new_response_outparam` expects a specific kind of sender
851    // that uses `p2::http::types::ErrorCode`, and we don't want to have to
852    // convert from the p3 `ErrorCode` to the p2 one, only to convert again to
853    // `wasmtime::Error`.
854
855    type P2Response = Result<
856        hyper::Response<wasmtime_wasi_http::body::HyperOutgoingBody>,
857        p2::http::types::ErrorCode,
858    >;
859    type P3Response = hyper::Response<UnsyncBoxBody<Bytes, wasmtime::Error>>;
860
861    enum Sender {
862        P2(oneshot::Sender<P2Response>),
863        P3(oneshot::Sender<P3Response>),
864    }
865
866    enum Receiver {
867        P2(oneshot::Receiver<P2Response>),
868        P3(oneshot::Receiver<P3Response>),
869    }
870
871    let (tx, rx) = match handler.instance_pre() {
872        ProxyPre::P2(_) => {
873            let (tx, rx) = oneshot::channel();
874            (Sender::P2(tx), Receiver::P2(rx))
875        }
876        ProxyPre::P3(_) => {
877            let (tx, rx) = oneshot::channel();
878            (Sender::P3(tx), Receiver::P3(rx))
879        }
880    };
881
882    handler.spawn(
883        if handler.state().max_instance_reuse_count() == 1 {
884            Some(req_id)
885        } else {
886            None
887        },
888        Box::new(move |store, proxy| {
889            Box::pin(
890                async move {
891                    match proxy {
892                        Proxy::P2(proxy) => {
893                            let Sender::P2(tx) = tx else { unreachable!() };
894                            let (req, out) = store.with(move |mut store| {
895                                let req = store
896                                    .data_mut()
897                                    .new_incoming_request(p2::http::types::Scheme::Http, req)?;
898                                let out = store.data_mut().new_response_outparam(tx)?;
899                                wasmtime::error::Ok((req, out))
900                            })?;
901
902                            proxy
903                                .wasi_http_incoming_handler()
904                                .call_handle(store, req, out)
905                                .await
906                        }
907                        Proxy::P3(proxy) => {
908                            use wasmtime_wasi_http::p3::bindings::http::types::{
909                                ErrorCode, Request,
910                            };
911
912                            let Sender::P3(tx) = tx else { unreachable!() };
913                            let (req, body) = req.into_parts();
914                            let body = body.map_err(ErrorCode::from_hyper_request_error);
915                            let req = http::Request::from_parts(req, body);
916                            let (request, request_io_result) = Request::from_http(req);
917                            let (res, task) = proxy.handle(store, request).await??;
918                            let res = store
919                                .with(|mut store| res.into_http(&mut store, request_io_result))?;
920                            _ = tx.send(res.map(|body| body.map_err(|e| e.into()).boxed_unsync()));
921
922                            // Wait for the task to finish.
923                            task.block(store).await;
924                            Ok(())
925                        }
926                    }
927                }
928                .map(move |result| {
929                    if let Err(error) = result {
930                        eprintln!("[{req_id}] :: {error:?}");
931                    }
932                }),
933            )
934        }),
935    );
936
937    Ok(match rx {
938        Receiver::P2(rx) => rx
939            .await
940            .context("guest never invoked `response-outparam::set` method")?
941            .map_err(|e| wasmtime::Error::from(e))?
942            .map(|body| body.map_err(|e| e.into()).boxed_unsync()),
943        Receiver::P3(rx) => rx.await?,
944    })
945}
946
/// Which of the host process's standard streams a [`LogStream`] writes to.
#[derive(Clone)]
enum Output {
    Stdout,
    Stderr,
}

impl Output {
    /// Writes all of `buf` to the selected host stream.
    fn write_all(&self, buf: &[u8]) -> io::Result<()> {
        use std::io::Write as _;

        match *self {
            Self::Stdout => std::io::stdout().lock().write_all(buf),
            Self::Stderr => std::io::stderr().lock().write_all(buf),
        }
    }
}

/// A guest output stream that forwards bytes to the host's stdout or stderr
/// and inserts a fixed prefix at the start of every line.
///
/// Clones share one [`LogStreamState`], so line-start tracking carries over
/// between them.
#[derive(Clone)]
struct LogStream {
    output: Output,
    state: Arc<LogStreamState>,
}

/// Line-prefixing state shared by all clones of a [`LogStream`].
struct LogStreamState {
    /// Text emitted before each new line of output.
    prefix: String,
    /// True when the next byte written starts a new line (initially, and
    /// after every `\n`), meaning `prefix` must be emitted first.
    needs_prefix_on_next_write: AtomicBool,
}

impl LogStream {
    /// Creates a stream that writes to `output` and prefixes each line with
    /// `prefix`. The very first write is prefixed too.
    fn new(prefix: String, output: Output) -> LogStream {
        let state = LogStreamState {
            prefix,
            needs_prefix_on_next_write: AtomicBool::new(true),
        };
        LogStream {
            output,
            state: Arc::new(state),
        }
    }

    /// Writes `bytes`, emitting the prefix before any byte that begins a line.
    ///
    /// A trailing partial line is written as-is. The prefix for whatever
    /// follows a final `\n` is deferred until the next non-empty write.
    fn write_all(&mut self, bytes: &[u8]) -> io::Result<()> {
        let state = &*self.state;
        // Each chunk is either a complete line (ending in `\n`) or the
        // trailing partial line.
        for chunk in bytes.split_inclusive(|&byte| byte == b'\n') {
            if state.needs_prefix_on_next_write.load(Ordering::Relaxed) {
                self.output.write_all(state.prefix.as_bytes())?;
                state
                    .needs_prefix_on_next_write
                    .store(false, Ordering::Relaxed);
            }
            self.output.write_all(chunk)?;
            if chunk.last() == Some(&b'\n') {
                state
                    .needs_prefix_on_next_write
                    .store(true, Ordering::Relaxed);
            }
        }
        Ok(())
    }
}
1017
1018impl wasmtime_wasi::cli::StdoutStream for LogStream {
1019    fn p2_stream(&self) -> Box<dyn wasmtime_wasi::p2::OutputStream> {
1020        Box::new(self.clone())
1021    }
1022    fn async_stream(&self) -> Box<dyn AsyncWrite + Send + Sync> {
1023        Box::new(self.clone())
1024    }
1025}
1026
1027impl wasmtime_wasi::cli::IsTerminal for LogStream {
1028    fn is_terminal(&self) -> bool {
1029        match &self.output {
1030            Output::Stdout => std::io::stdout().is_terminal(),
1031            Output::Stderr => std::io::stderr().is_terminal(),
1032        }
1033    }
1034}
1035
1036impl wasmtime_wasi::p2::OutputStream for LogStream {
1037    fn write(&mut self, bytes: bytes::Bytes) -> StreamResult<()> {
1038        self.write_all(&bytes)
1039            .map_err(|e| StreamError::LastOperationFailed(e.into()))?;
1040        Ok(())
1041    }
1042
1043    fn flush(&mut self) -> StreamResult<()> {
1044        Ok(())
1045    }
1046
1047    fn check_write(&mut self) -> StreamResult<usize> {
1048        Ok(1024 * 1024)
1049    }
1050}
1051
#[async_trait::async_trait]
impl wasmtime_wasi::p2::Pollable for LogStream {
    /// Resolves immediately. Writes to this stream complete synchronously, so
    /// there is never anything to wait for.
    async fn ready(&mut self) {}
}
1056
1057impl AsyncWrite for LogStream {
1058    fn poll_write(
1059        mut self: Pin<&mut Self>,
1060        _cx: &mut Context<'_>,
1061        buf: &[u8],
1062    ) -> Poll<io::Result<usize>> {
1063        Poll::Ready(self.write_all(buf).map(|_| buf.len()))
1064    }
1065    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
1066        Poll::Ready(Ok(()))
1067    }
1068    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
1069        Poll::Ready(Ok(()))
1070    }
1071}
1072
1073/// The pooling allocator is tailor made for the `wasmtime serve` use case, so
1074/// try to use it when we can. The main cost of the pooling allocator, however,
1075/// is the virtual memory required to run it. Not all systems support the same
1076/// amount of virtual memory, for example some aarch64 and riscv64 configuration
1077/// only support 39 bits of virtual address space.
1078///
1079/// The pooling allocator, by default, will request 1000 linear memories each
1080/// sized at 6G per linear memory. This is 6T of virtual memory which ends up
1081/// being about 42 bits of the address space. This exceeds the 39 bit limit of
1082/// some systems, so there the pooling allocator will fail by default.
1083///
1084/// This function attempts to dynamically determine the hint for the pooling
1085/// allocator. This returns `Some(true)` if the pooling allocator should be used
1086/// by default, or `None` or an error otherwise.
1087///
1088/// The method for testing this is to allocate a 0-sized 64-bit linear memory
1089/// with a maximum size that's N bits large where we force all memories to be
1090/// static. This should attempt to acquire N bits of the virtual address space.
1091/// If successful that should mean that the pooling allocator is OK to use, but
1092/// if it fails then the pooling allocator is not used and the normal mmap-based
1093/// implementation is used instead.
1094fn use_pooling_allocator_by_default() -> Result<Option<bool>> {
1095    use wasmtime::{Config, Memory, MemoryType};
1096    const BITS_TO_TEST: u32 = 42;
1097    let mut config = Config::new();
1098    config.wasm_memory64(true);
1099    config.memory_reservation(1 << BITS_TO_TEST);
1100    let engine = Engine::new(&config)?;
1101    let mut store = Store::new(&engine, ());
1102    // NB: the maximum size is in wasm pages to take out the 16-bits of wasm
1103    // page size here from the maximum size.
1104    let ty = MemoryType::new64(0, Some(1 << (BITS_TO_TEST - 16)));
1105    if Memory::new(&mut store, ty).is_ok() {
1106        Ok(Some(true))
1107    } else {
1108        Ok(None)
1109    }
1110}