wasm-server-runner 1.0.1

cargo run for wasm programs
mod certificate;

use std::borrow::Cow;
use std::collections::{BTreeMap, HashMap};
use std::net::{IpAddr, SocketAddr};
use std::path::PathBuf;

use anyhow::Context as _;
use axum::error_handling::HandleError;
use axum::extract::ws::{self, WebSocket};
use axum::extract::{Path, WebSocketUpgrade};
use axum::http::{HeaderValue, StatusCode};
use axum::response::{Html, IntoResponse, Response};
use axum::routing::{get, get_service};
use axum::Router;
use axum_server::tls_rustls::RustlsConfig;
use axum_server_dual_protocol::ServerExt;
use http::HeaderName;
use tower::ServiceBuilder;
use tower_http::compression::CompressionLayer;
use tower_http::services::ServeDir;
use tower_http::set_header::SetResponseHeaderLayer;

use crate::wasm_bindgen::WasmBindgenOutput;
use crate::Result;

/// Produces a fresh random token of 12 alphanumeric characters.
///
/// `run_server` generates one per server start and serves it at `/api/version`.
fn generate_version() -> String {
    const VERSION_LEN: usize = 12;
    (0..VERSION_LEN).map(|_| fastrand::alphanumeric()).collect()
}

/// Configuration for [`run_server`].
#[derive(Debug)]
pub struct Options {
    /// Text substituted for every `{{ TITLE }}` placeholder in the index page.
    pub title: String,
    /// Listen address: either a full socket address, or an address to which a picked
    /// free port is appended.
    pub address: String,
    /// Directory whose files are served for any path not matched by another route.
    /// A relative `custom_index_html` is also resolved against it.
    pub directory: PathBuf,
    /// Optional replacement for the built-in index page template.
    pub custom_index_html: Option<PathBuf>,
    /// Serve over HTTPS instead of plain HTTP.
    pub https: bool,
    /// Load the wasm-bindgen JS through a classic `<script>` tag instead of an ES module
    /// `import`.
    pub no_module: bool,
}

/// Builds the dev-server router and serves it until the server stops.
///
/// Routes:
/// - `/`: the index page (built-in template or `options.custom_index_html`);
/// - `/api/wasm.js`, `/api/wasm.wasm`: the wasm-bindgen output;
/// - `/api/version`: a random token generated once per call;
/// - `/api/snippets/{*rest}`: wasm-bindgen snippets and local JS modules;
/// - `/ws`: a websocket handled by `handle_ws`;
/// - anything else: static files from `options.directory`.
///
/// Every response goes through a compression layer and gets
/// `cross-origin-opener-policy: same-origin` and `cross-origin-embedder-policy: require-corp`
/// unless the handler already set them.
///
/// # Errors
///
/// Fails if the custom index page cannot be read, the address cannot be parsed, the TLS
/// certificate/config cannot be built (when `options.https` is set), or the server errors.
pub async fn run_server(options: Options, output: WasmBindgenOutput) -> Result<()> {
    let WasmBindgenOutput { js, wasm, snippets, local_modules } = output;

    let middleware_stack = ServiceBuilder::new()
        .layer(CompressionLayer::new())
        .layer(SetResponseHeaderLayer::if_not_present(
            HeaderName::from_static("cross-origin-opener-policy"),
            HeaderValue::from_static("same-origin"),
        ))
        .layer(SetResponseHeaderLayer::if_not_present(
            HeaderName::from_static("cross-origin-embedder-policy"),
            HeaderValue::from_static("require-corp"),
        ))
        .into_inner();

    let version = generate_version();
    let html = render_index_html(&options)?;

    let serve_dir =
        HandleError::new(get_service(ServeDir::new(options.directory)), internal_server_error);

    let app = Router::new()
        .route("/", get(move || async { Html(html) }))
        .route("/api/wasm.js", get(|| async { WithContentType("application/javascript", js) }))
        .route("/api/wasm.wasm", get(|| async { WithContentType("application/wasm", wasm) }))
        .route("/api/version", get(move || async { version }))
        .route("/ws", get(|ws: WebSocketUpgrade| async { ws.on_upgrade(handle_ws) }))
        .route(
            "/api/snippets/{*rest}",
            get(|Path(path): Path<String>| async move {
                match get_snippet_source(&path, &local_modules, &snippets) {
                    Ok(source) => Ok(WithContentType("application/javascript", source)),
                    Err(e) => {
                        tracing::error!("failed to serve snippet `{path}`: {e}");
                        Err(e)
                    }
                }
            }),
        )
        .fallback_service(serve_dir)
        .layer(middleware_stack);

    let addr = resolve_address(&options.address)?;

    if options.https {
        let certificate = certificate::certificate()?;
        let config =
            RustlsConfig::from_der(vec![certificate.certificate], certificate.private_key).await?;

        tracing::info!(target: "wasm_server_runner", "starting webserver at https://{}", addr);
        axum_server_dual_protocol::bind_dual_protocol(addr, config)
            .set_upgrade(true)
            .serve(app.into_make_service())
            .await?;
    } else {
        tracing::info!(target: "wasm_server_runner", "starting webserver at http://{}", addr);
        axum_server::bind(addr).serve(app.into_make_service()).await?;
    }

    Ok(())
}

/// Loads the index page template and fills in its `{{ TITLE }}`, `{{ NO_MODULE }}` and
/// `// {{ MODULE }}` placeholders.
fn render_index_html(options: &Options) -> Result<String> {
    let source = match &options.custom_index_html {
        Some(index_html_path) => {
            // Relative paths are resolved against the served directory.
            let path = if index_html_path.is_absolute() {
                index_html_path.clone()
            } else {
                options.directory.join(index_html_path)
            };
            let contents = std::fs::read_to_string(&path).with_context(|| {
                format!("failed to read custom index.html at `{}`", path.display())
            })?;
            Cow::Owned(contents)
        }
        None => Cow::Borrowed(include_str!("../static/index.html")),
    };

    let html = source.replace("{{ TITLE }}", &options.title);
    let html = if options.no_module {
        html.replace("{{ NO_MODULE }}", "<script src=\"./api/wasm.js\"></script>")
            .replace("// {{ MODULE }}", "")
    } else {
        html.replace("// {{ MODULE }}", "import wasm_bindgen from './api/wasm.js';")
            .replace("{{ NO_MODULE }}", "")
    };
    Ok(html)
}

/// Parses the configured listen address.
///
/// A full socket address (`127.0.0.1:8080`, `[::1]:8080`) is used as-is. A bare IP
/// (`127.0.0.1`, `::1`, `[::1]`) is combined with a free port picked starting at 1334.
fn resolve_address(address: &str) -> Result<SocketAddr> {
    if let Ok(addr) = address.parse::<SocketAddr>() {
        return Ok(addr);
    }

    // Build the socket address from a parsed IP rather than formatting `"{ip}:{port}"`:
    // for IPv6 that text (e.g. `::1:1334`) is not a valid socket address.
    let host = address
        .strip_prefix('[')
        .and_then(|rest| rest.strip_suffix(']'))
        .unwrap_or(address);
    let ip: IpAddr = host.parse().context("Error parsing WASM_SERVER_RUNNER_ADDRESS")?;
    let port = pick_port::pick_free_port(1334, 10).unwrap_or(1334);
    Ok(SocketAddr::new(ip, port))
}

/// Resolves a path under `/api/snippets/` to JavaScript source.
///
/// The path is first looked up verbatim in `local_modules`. Otherwise it must have the
/// shape `<snippet>/inline<N>.js`, which selects entry `N` of the list stored under
/// `<snippet>` in `snippets`.
///
/// # Errors
///
/// Returns a static description of the first check that failed.
fn get_snippet_source(
    path: &str,
    local_modules: &HashMap<String, String>,
    snippets: &BTreeMap<String, Vec<String>>,
) -> Result<String, &'static str> {
    match local_modules.get(path) {
        Some(module) => Ok(module.to_owned()),
        None => {
            let Some((snippet_name, file_name)) = path.split_once('/') else {
                return Err("invalid snippet path");
            };
            let Some(index_text) =
                file_name.strip_prefix("inline").and_then(|rest| rest.strip_suffix(".js"))
            else {
                return Err("invalid snippet name in path");
            };
            let Ok(index) = index_text.parse::<usize>() else {
                return Err("invalid index");
            };
            let Some(inline_snippets) = snippets.get(snippet_name) else {
                return Err("invalid snippet name");
            };
            inline_snippets.get(index).cloned().ok_or("snippet index out of bounds")
        }
    }
}

/// Forwards log lines sent over the `/ws` websocket to `tracing` under the `app` target.
///
/// Each text message is expected to look like `<level>,<text>`, where `<level>` is one of
/// `log`, `trace`, `debug`, `info`, `warn` or `error` (`log` is emitted at info). If
/// `<text>` starts with a `TRACE `/`DEBUG `/`INFO `/`WARN `/`ERROR ` prefix, the prefix is
/// stripped and decides the level instead.
///
/// The loop ends on a receive error or a `Close` frame. Malformed messages, unknown levels
/// and non-text frames are skipped with a warning rather than panicking: the messages come
/// from the browser, so they must not be able to crash the connection task.
async fn handle_ws(mut socket: WebSocket) {
    // Text prefix -> level the line is re-logged at.
    // NOTE(review): `TRACE ` maps to "debug" (not "trace"), as before; confirm this is
    // intended (e.g. so trace output stays visible under the default filter).
    const LEVEL_PREFIXES: [(&str, &str); 5] = [
        ("TRACE ", "debug"),
        ("DEBUG ", "debug"),
        ("INFO ", "info"),
        ("WARN ", "warn"),
        ("ERROR ", "error"),
    ];

    while let Some(msg) = socket.recv().await {
        let msg = match msg {
            Ok(msg) => msg,
            Err(e) => return tracing::warn!("got error {e}, closing websocket connection"),
        };

        let payload = match msg {
            ws::Message::Text(payload) => payload,
            ws::Message::Close(_) => return,
            ws::Message::Binary(_) => {
                tracing::warn!("ignoring unexpected binary websocket message");
                continue;
            }
            // Ping/Pong control frames carry no log data.
            _ => continue,
        };
        let line: &str = &payload;

        let Some((level, text)) = line.split_once(',') else {
            tracing::warn!("ignoring malformed websocket log message: {line}");
            continue;
        };

        let (level, text) = LEVEL_PREFIXES
            .iter()
            .find_map(|&(prefix, mapped)| text.strip_prefix(prefix).map(|rest| (mapped, rest)))
            .unwrap_or((level, text));

        match level {
            "log" | "info" => tracing::info!(target: "app", "{text}"),
            "trace" => tracing::trace!(target: "app", "{text}"),
            "debug" => tracing::debug!(target: "app", "{text}"),
            "warn" => tracing::warn!(target: "app", "{text}"),
            "error" => tracing::error!(target: "app", "{text}"),
            _ => tracing::warn!(target: "app", "unexpected log level {level}: {text}"),
        }
    }
}

/// Response wrapper that sets the `Content-Type` header on whatever the inner body renders to.
///
/// Field 0 is the MIME type; field 1 is any [`IntoResponse`] body.
struct WithContentType<T>(&'static str, T);

impl<T> IntoResponse for WithContentType<T>
where
    T: IntoResponse,
{
    fn into_response(self) -> Response {
        let Self(content_type, body) = self;
        let mut response = body.into_response();
        let header_value = HeaderValue::from_static(content_type);
        response.headers_mut().insert("Content-Type", header_value);
        response
    }
}

/// Turns an error from the static-file service into a `500 Internal Server Error`
/// response whose body describes the error.
async fn internal_server_error(error: impl std::fmt::Display) -> impl IntoResponse {
    let body = format!("Unhandled internal error: {error}");
    (StatusCode::INTERNAL_SERVER_ERROR, body)
}

/// Helpers for choosing a TCP port that is currently free.
///
/// A port counts as free only if it can be bound on both the IPv6 and IPv4 unspecified
/// addresses. The probe is inherently racy: another process may take the port between
/// this check and the server actually binding it.
mod pick_port {
    use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6, TcpListener, ToSocketAddrs};

    /// Binds a listener on `addr` and returns the port it got, or `None` if binding failed.
    /// The listener is dropped before returning, releasing the port again.
    fn test_bind_tcp<A: ToSocketAddrs>(addr: A) -> Option<u16> {
        Some(TcpListener::bind(addr).ok()?.local_addr().ok()?.port())
    }

    /// Returns `true` if `port` can be bound on both `[::]` and `0.0.0.0`.
    fn is_free_tcp(port: u16) -> bool {
        let ipv4 = SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, port);
        let ipv6 = SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, port, 0, 0);

        test_bind_tcp(ipv6).is_some() && test_bind_tcp(ipv4).is_some()
    }

    /// Lets the OS assign a port by binding port 0, trying IPv6 first and then IPv4.
    fn ask_free_tcp_port() -> Option<u16> {
        let ipv4 = SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0);
        let ipv6 = SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0);
        test_bind_tcp(ipv6).or_else(|| test_bind_tcp(ipv4))
    }

    /// Returns the first free port in `starting_at..=starting_at + try_consecutive`.
    ///
    /// If none of those ports is free, falls back to an OS-assigned port. Returns `None`
    /// only if that fails too.
    ///
    /// The upper bound saturates at `u16::MAX` instead of overflowing. Port 0 is never
    /// returned, because binding it means "any port" rather than a specific one.
    pub fn pick_free_port(starting_at: u16, try_consecutive: u16) -> Option<u16> {
        let last = starting_at.saturating_add(try_consecutive);
        (starting_at..=last)
            .find(|&port| port != 0 && is_free_tcp(port))
            .or_else(ask_free_tcp_port)
    }
}