// wasm-bindgen-threads-xform 0.2.84
//
// Support for threading-related transformations in wasm-bindgen
// Documentation
use anyhow::{anyhow, bail, Error};
use std::cmp;
use std::env;
use walrus::ir::Value;
use walrus::{
    ir::MemArg, ExportItem, FunctionId, GlobalId, GlobalKind, InitExpr, InstrSeqBuilder, MemoryId,
    Module, ValType,
};
use wasm_bindgen_wasm_conventions as wasm_conventions;

/// Number of bytes in one page of linear memory (`1 << 16` = 64 KiB). Used to
/// convert between byte sizes and the page counts stored on the memory.
const PAGE_SIZE: u32 = 1 << 16;
/// Memory argument shared by every atomic instruction emitted in this file:
/// alignment 4 and a zero offset, so the address operand is used as-is.
// NOTE(review): `align` is presumably expressed in bytes (matching the i32
// accesses below) — confirm against the walrus `MemArg` documentation.
const ATOMIC_MEM_ARG: MemArg = MemArg {
    align: 4,
    offset: 0,
};

/// Configuration for the transformation pass in this module.
///
/// Created primarily through `new` and then executed through `run`.
pub struct Config {
    /// Upper bound for the module's memory, in bytes. `run` converts it to
    /// pages and never lowers a maximum that is already larger.
    maximum_memory: u32,
    /// Size, in bytes, of the stack allocated for every thread but the first.
    thread_stack_size: u32,
    /// Forces the pass to run regardless of the module's memory; `new` sets it
    /// when the `WASM_BINDGEN_THREADS` environment variable is present.
    enabled: bool,
}

/// Handle to the local that the injected start code fills with the value read
/// from the thread counter; it is zero on the first thread and nonzero on
/// every later one. Returned by `Config::run` and consumed by `wrap_start`.
#[derive(Clone, Copy)]
pub struct ThreadCount(walrus::LocalId);

impl Config {
    /// Create a new configuration with default settings.
    ///
    /// The defaults are a 1GB maximum memory and a 1MB stack per spawned
    /// thread. The pass is force-enabled when the `WASM_BINDGEN_THREADS`
    /// environment variable is set (to any value).
    pub fn new() -> Config {
        Config {
            maximum_memory: 1 << 30,    // 1GB
            thread_stack_size: 1 << 20, // 1MB
            enabled: env::var("WASM_BINDGEN_THREADS").is_ok(),
        }
    }

    /// Is threaded Wasm enabled?
    ///
    /// Returns `true` when the pass was force-enabled through the environment
    /// or when the module's memory is marked as shared.
    pub fn is_enabled(&self, module: &Module) -> bool {
        if self.enabled {
            return true;
        }

        // Compatibility with older LLVM outputs. Newer LLVM outputs, when
        // atomics are enabled, emit a shared memory. That's a good indicator
        // that we have work to do. If shared memory isn't enabled, though, then
        // this isn't an atomic module so there's nothing to do. We still allow,
        // though, an environment variable to force us to go down this path to
        // remain compatible with older LLVM outputs.
        match wasm_conventions::get_memory(module) {
            Ok(memory) => module.memories.get(memory).shared,
            Err(_) => false,
        }
    }

    /// Specify the maximum amount of memory the wasm module can ever have.
    ///
    /// We'll be specifying that the memory for this wasm module is shared, and
    /// all shared memories must have their maximum limit specified (whereas
    /// by default Rust/LLVM/LLD don't specify a maximum).
    ///
    /// The default for this option is 1GB, and this can be used to change
    /// the maximum memory we'll be specifying.
    ///
    /// The `max` argument is in units of bytes.
    ///
    /// If the module already specifies a larger maximum, this setting won't
    /// have any effect: the larger of the two values is kept.
    pub fn maximum_memory(&mut self, max: u32) -> &mut Config {
        self.maximum_memory = max;
        self
    }

    /// Specify the stack size for all threads spawned.
    ///
    /// The stack size is typically set by rustc as an argument to LLD and
    /// defaults to 1MB for the main thread. All threads spawned by the
    /// main thread, however, need to allocate their own stack!
    ///
    /// This configuration option indicates how large the stack of each child
    /// thread will be. This will be allocated as part of the `start` function
    /// and will be stored in LLVM's global stack pointer.
    ///
    /// The `size` argument is in units of bytes and must be a multiple of the
    /// 64KB page size, otherwise `run` panics.
    pub fn thread_stack_size(&mut self, size: u32) -> &mut Config {
        self.thread_stack_size = size;
        self
    }

    /// Execute the transformation on the parsed wasm module specified.
    ///
    /// This function will prepare `Module` to be run on multiple threads,
    /// performing steps such as:
    ///
    /// * All data segments are switched to "passive" data segments to ensure
    ///   they're only initialized once (coming later)
    /// * If memory is exported from this module, it is instead switched to
    ///   being imported (with the same parameters).
    /// * The imported memory is required to be `shared`, ensuring it's backed
    ///   by a `SharedArrayBuffer` on the web.
    /// * A `global` for a thread ID is injected.
    /// * Four bytes in linear memory are reserved for the counter of thread
    ///   IDs.
    /// * A `start` function is injected (or prepended if one already exists)
    ///   which initializes memory for the first thread and otherwise allocates
    ///   thread ids for all threads.
    /// * Some stack space is prepared for each thread after the first one.
    ///
    /// More and/or less may happen here over time, stay tuned!
    ///
    /// Returns `Ok(None)` without touching the module when threading is not
    /// enabled (see `is_enabled`).
    ///
    /// # Errors
    ///
    /// Fails when the module lacks one of the pieces this pass relies on: its
    /// memory, a suitable `__heap_base` export, the `__wasm_init_tls`,
    /// `__tls_size` and `__tls_align` exports, the TLS base or shadow stack
    /// pointer globals, or the `__wbindgen_malloc` / `__wbindgen_free` exports.
    ///
    /// # Panics
    ///
    /// Panics if the memory is not shared, has no maximum, is not imported or
    /// still owns data segments, and if the configured thread stack size is
    /// not a multiple of the page size.
    pub fn run(&self, module: &mut Module) -> Result<Option<ThreadCount>, Error> {
        if !self.is_enabled(module) {
            return Ok(None);
        }

        let memory = wasm_conventions::get_memory(module)?;

        // Now we need to allocate extra static memory for:
        // - A thread id counter.
        // - A temporary stack for calls to `malloc()` and `free()`.
        // - A lock to synchronize usage of the above stack.
        // For this, we allocate 1 extra page of memory (should be enough as temporary
        // stack) and grab the first 2 _aligned_ i32 words to use as counter and lock.
        let static_data_align = 4;
        let static_data_pages = 1;
        let (base, addr) =
            allocate_static_data(module, memory, static_data_pages, static_data_align)?;

        let mem = module.memories.get_mut(memory);
        assert!(mem.shared);
        let prev_max = mem.maximum.unwrap();
        assert!(mem.import.is_some());
        // Never shrink a maximum the module already declares.
        mem.maximum = Some(cmp::max(self.maximum_memory / PAGE_SIZE, prev_max));
        assert!(mem.data_segments.is_empty());

        // The TLS-related exports are only needed by this pass, so they are
        // removed from the module while their values are collected.
        let tls = Tls {
            init: delete_synthetic_func(module, "__wasm_init_tls")?,
            size: delete_synthetic_global(module, "__tls_size")?,
            align: delete_synthetic_global(module, "__tls_align")?,
            base: wasm_conventions::get_tls_base(module)
                .ok_or_else(|| anyhow!("failed to find tls base"))?,
        };

        let thread_counter_addr = addr as i32;

        // Mutable global, initially 0, that records the base of the stack
        // allocated for the current thread.
        let stack_alloc =
            module
                .globals
                .add_local(ValType::I32, true, InitExpr::Value(Value::I32(0)));

        // Make sure the temporary stack is aligned down
        let temp_stack = (base + static_data_pages * PAGE_SIZE) & !(static_data_align - 1);

        let stack = Stack {
            pointer: wasm_conventions::get_shadow_stack_pointer(module)
                .ok_or_else(|| anyhow!("failed to find shadow stack pointer"))?,
            temp: temp_stack as i32,
            // The lock is the i32 word right after the thread counter.
            temp_lock: thread_counter_addr + 4,
            alloc: stack_alloc,
            size: self.thread_stack_size,
        };

        let _ = module.exports.add("__stack_alloc", stack.alloc);

        let thread_count = inject_start(module, &tls, &stack, thread_counter_addr, memory)?;

        // we expose a `__wbindgen_thread_destroy()` helper function that deallocates stack space.
        //
        // ## Safety
        // After calling this function in a given agent, the instance should be considered
        // "destroyed" and any further invocations into it will trigger UB. This function
        // should not be called from an agent that cannot block (e.g. the main document thread).
        //
        // You can also call it from a "leader" agent, passing appropriate values, if said leader
        // is in charge of cleaning up after a "follower" agent. In that case:
        // - The "appropriate values" are the values of the `__tls_base` and `__stack_alloc` globals
        //   from the follower thread, after initialization.
        // - The leader does _not_ need to block.
        // - Similar restrictions apply: the follower thread should be considered unusable afterwards,
        //   the leader should not call this function with the same set of parameters twice.
        // - Moreover, concurrent calls can lead to UB: the follower could be in the middle of a
        //   call while the leader is destroying its stack! You should make sure that this cannot happen.
        inject_destroy(module, &tls, &stack, memory)?;

        Ok(Some(thread_count))
    }
}

impl ThreadCount {
    pub fn wrap_start(self, builder: &mut walrus::FunctionBuilder, start: FunctionId) {
        // We only want to call the start function if we are in the first thread.
        // The thread counter should be 0 for the first thread.
        builder.func_body().local_get(self.0).if_else(
            None,
            |_| {},
            |body| {
                body.call(start);
            },
        );
    }
}

fn delete_synthetic_func(module: &mut Module, name: &str) -> Result<FunctionId, Error> {
    match delete_synthetic_export(module, name)? {
        walrus::ExportItem::Function(f) => Ok(f),
        _ => bail!("`{}` must be a function", name),
    }
}

/// Removes the export called `name` from the module and returns the constant
/// its global is initialized with, reinterpreted as a `u32`.
///
/// Fails if the export is missing, is not a global, refers to an imported
/// global, or is not initialized with an `i32` constant.
fn delete_synthetic_global(module: &mut Module, name: &str) -> Result<u32, Error> {
    let removed = delete_synthetic_export(module, name)?;
    let global_id = if let walrus::ExportItem::Global(id) = removed {
        id
    } else {
        bail!("`{}` must be a global", name)
    };

    // Only a locally defined global carries an initializer we can read.
    match module.globals.get(global_id).kind {
        walrus::GlobalKind::Import(_) => bail!("`{}` must not be an imported global", name),
        walrus::GlobalKind::Local(InitExpr::Value(Value::I32(value))) => Ok(value as u32),
        walrus::GlobalKind::Local(_) => bail!("`{}` was not an `i32` constant", name),
    }
}

/// Looks up the export called `name`, deletes it from the module and returns
/// the item it pointed at.
///
/// Fails if the module has no export with that name.
fn delete_synthetic_export(module: &mut Module, name: &str) -> Result<ExportItem, Error> {
    let (export_id, exported) = module
        .exports
        .iter()
        .find(|export| export.name == name)
        .map(|export| (export.id(), export.item))
        .ok_or_else(|| anyhow!("failed to find `{}`", name))?;

    module.exports.delete(export_id);
    Ok(exported)
}

/// Allocates extra space for static data. Returns `(addr, base)`, where:
/// - `base` is the starting address of the extra `pages`.
/// - `addr` is the _first_ address in that chunk that is aligned to `align`.
fn allocate_static_data(
    module: &mut Module,
    memory: MemoryId,
    pages: u32,
    align: u32,
) -> Result<(u32, u32), Error> {
    // First up, look for a `__heap_base` export which is injected by LLD as
    // part of the linking process. Note that `__heap_base` should in theory be
    // *after* the stack and data, which means it's at the very end of the
    // address space and should be safe for us to inject extra pages of data at.
    let heap_base = module
        .exports
        .iter()
        .filter(|e| e.name == "__heap_base")
        .filter_map(|e| match e.item {
            ExportItem::Global(id) => Some(id),
            _ => None,
        })
        .next();
    let heap_base = match heap_base {
        Some(idx) => idx,
        None => bail!("failed to find `__heap_base` for injecting thread id"),
    };

    // Now we need to bump up `__heap_base` by a few pages. Do lots of validation
    // here to make sure that `__heap_base` is an non-mutable integer, and then do
    // some logic to ensure that the return the correct, aligned `address` as specified
    // by `align`.
    let (base, address) = {
        let global = module.globals.get_mut(heap_base);
        if global.ty != ValType::I32 {
            bail!("the `__heap_base` global doesn't have the type `i32`");
        }
        if global.mutable {
            bail!("the `__heap_base` global is unexpectedly mutable");
        }
        let offset = match &mut global.kind {
            GlobalKind::Local(InitExpr::Value(Value::I32(n))) => n,
            _ => bail!("`__heap_base` not a locally defined `i32`"),
        };

        let address = (*offset as u32 + (align - 1)) & !(align - 1); // align up
        let base = *offset;

        *offset = *offset + (pages * PAGE_SIZE) as i32;

        (base, address)
    };

    let memory = module.memories.get_mut(memory);
    memory.initial += pages;
    memory.maximum = memory.maximum.map(|m| cmp::max(m, memory.initial));

    Ok((base as u32, address))
}

/// Thread-local-storage details collected from the module by `Config::run`.
struct Tls {
    /// The function formerly exported as `__wasm_init_tls`; the injected start
    /// code calls it with the freshly allocated TLS base.
    init: walrus::FunctionId,
    /// Value of the `__tls_size` global; passed to malloc/free for the TLS block.
    size: u32,
    /// Value of the `__tls_align` global (pushed but dropped, so not yet
    /// honored, when the TLS block is allocated).
    align: u32,
    /// The global that holds the TLS base address.
    base: GlobalId,
}

/// Stack-related globals and addresses used by the injected start and destroy
/// code.
struct Stack {
    /// The stack pointer global
    pointer: GlobalId,
    /// The address of a small, "scratch-space" stack; the stack pointer is set
    /// to it while malloc/free run on behalf of a thread without a usable stack
    temp: i32,
    /// The address of a lock for the temporary stack (an i32 word: 0 when
    /// free, 1 when held)
    temp_lock: i32,
    /// A global to store allocated stack (the base address returned by malloc)
    alloc: GlobalId,
    /// The size of the stack, in bytes
    size: u32,
}

/// Adds the per-thread initialization code to the module's start function
/// (as obtained from `get_or_insert_start_builder`).
///
/// The emitted code atomically increments the thread counter stored at
/// `thread_counter_addr`, allocates a stack for every thread that reads a
/// nonzero count, and then allocates and initializes the thread-local storage
/// block.
///
/// Returns a handle to the local holding the count read from the counter.
///
/// # Errors
///
/// Fails if the module does not export a `__wbindgen_malloc` function.
///
/// # Panics
///
/// Panics if `stack.size` is not a multiple of `PAGE_SIZE`.
fn inject_start(
    module: &mut Module,
    tls: &Tls,
    stack: &Stack,
    thread_counter_addr: i32,
    memory: MemoryId,
) -> Result<ThreadCount, Error> {
    use walrus::ir::*;

    assert!(stack.size % PAGE_SIZE == 0);

    // `local` receives the malloc'd stack base; it is written via `local_tee`
    // below but not read again within this function.
    let local = module.locals.add(ValType::I32);
    let thread_count = module.locals.add(ValType::I32);

    let malloc = find_function(module, "__wbindgen_malloc")?;

    let builder = wasm_bindgen_wasm_conventions::get_or_insert_start_builder(module);

    let mut body = builder.func_body();

    // Perform an if/else based on whether we're the first thread or not. Our
    // thread ID will be zero if we're the first thread, otherwise it'll be
    // nonzero (assuming we don't overflow...)
    body.i32_const(thread_counter_addr)
        .i32_const(1)
        .atomic_rmw(memory, AtomicOp::Add, AtomicWidth::I32, ATOMIC_MEM_ARG)
        .local_tee(thread_count)
        .if_else(
            None,
            // If our thread id is nonzero then we're the second or greater thread, so
            // we give ourselves a stack and we update our stack
            // pointer as the default stack pointer is surely wrong for us.
            |body| {
                // local = malloc(stack.size) [aka base]
                // The call runs on the shared temporary stack because this
                // thread has no stack of its own yet; the result stays on the
                // value stack for the `global_set` below.
                with_temp_stack(body, memory, &stack, |body| {
                    body.i32_const(stack.size as i32)
                        .call(malloc)
                        .local_tee(local);
                });

                // stack.alloc = base
                body.global_set(stack.alloc);

                // stack_pointer = base + stack.size
                body.global_get(stack.alloc)
                    .i32_const(stack.size as i32)
                    .binop(BinaryOp::I32Add)
                    .global_set(stack.pointer);
            },
            // If the thread id is zero then the default stack pointer works for
            // us.
            |_| {},
        );

    // Afterwards we need to initialize our thread-local state:
    // tls.base = malloc(tls.size); tls.init(tls.base)
    body.i32_const(tls.size as i32)
        .i32_const(tls.align as i32)
        .drop() // TODO: need to actually respect alignment
        .call(malloc)
        .global_set(tls.base)
        .global_get(tls.base)
        .call(tls.init);

    Ok(ThreadCount(thread_count))
}

/// Adds and exports the `__wbindgen_thread_destroy(tls_base, stack_alloc)`
/// function, which frees a thread's TLS block and stack through
/// `__wbindgen_free`.
///
/// Each parameter of the generated function may be zero, in which case the
/// calling instance's own global (`tls.base` / `stack.alloc`) is used instead
/// and is poisoned afterwards. Freeing the caller's own stack is done on the
/// shared temporary stack.
///
/// # Errors
///
/// Fails if the module does not export a `__wbindgen_free` function.
fn inject_destroy(
    module: &mut Module,
    tls: &Tls,
    stack: &Stack,
    memory: MemoryId,
) -> Result<(), Error> {
    let free = find_function(module, "__wbindgen_free")?;

    let mut builder =
        walrus::FunctionBuilder::new(&mut module.types, &[ValType::I32, ValType::I32], &[]);

    builder.name("__wbindgen_thread_destroy".into());

    let mut body = builder.func_body();

    // if no explicit parameters are passed (i.e. their value is 0) then we assume
    // we're being called from the agent that must be destroyed and rely on its globals
    let tls_base = module.locals.add(ValType::I32);
    let stack_alloc = module.locals.add(ValType::I32);

    // Ideally, at this point, we would destroy the values stored in TLS.
    // We can't really do that without help from the standard library.
    // See https://github.com/rustwasm/wasm-bindgen/pull/2769#issuecomment-1015775467.

    // free the TLS block calling `__wbindgen_free(tls base, tls.size)`
    body.local_get(tls_base).if_else(
        None,
        |body| {
            // an explicit TLS base was passed: free somebody else's block
            body.local_get(tls_base)
                .i32_const(tls.size as i32)
                .call(free);
        },
        |body| {
            body.global_get(tls.base)
                .i32_const(tls.size as i32)
                .call(free);

            // set tls.base = i32::MIN to trigger invalid memory
            body.i32_const(i32::MIN).global_set(tls.base);
        },
    );

    // free the stack calling `__wbindgen_free(stack.alloc, stack.size)`
    body.local_get(stack_alloc).if_else(
        None,
        |body| {
            // we're destroying somebody else's stack, so we can use our own
            body.local_get(stack_alloc)
                .i32_const(stack.size as i32)
                .call(free);
        },
        |body| {
            // we're freeing the stack we are running on, so the call to
            // `free` itself has to run on the temporary stack
            with_temp_stack(body, memory, stack, |body| {
                body.global_get(stack.alloc)
                    .i32_const(stack.size as i32)
                    .call(free);
            });

            // set stack.alloc = 0 to trigger invalid memory
            body.i32_const(0).global_set(stack.alloc);
        },
    );

    // The two locals must be registered as the function's arguments, in
    // signature order; otherwise they are not tied to the declared `i32`
    // parameters and the explicit values passed by the caller are never read.
    let destroy_id = builder.finish(vec![tls_base, stack_alloc], &mut module.funcs);

    module.exports.add("__wbindgen_thread_destroy", destroy_id);

    Ok(())
}

/// Finds the export called `name` and returns the function it refers to,
/// leaving the export in place.
///
/// Fails if there is no such export or if it does not refer to a function.
fn find_function(module: &Module, name: &str) -> Result<FunctionId, Error> {
    let export = match module.exports.iter().find(|export| export.name == name) {
        Some(export) => export,
        None => bail!("failed to find `{}`", name),
    };

    if let walrus::ExportItem::Function(func) = export.item {
        Ok(func)
    } else {
        Err(anyhow!("`{}` wasn't a funtion", name))
    }
}

/// Wraps the instructions fed by `block()` so that they can assume that the temporary, scratch
/// stack is usable. Clobbers `stack.pointer`.
///
/// The emitted sequence is: point the stack pointer at the temporary stack,
/// spin/wait until the lock word at `stack.temp_lock` has been switched from
/// 0 to 1, emit the instructions produced by `block`, then store 0 into the
/// lock word and notify one waiter.
///
/// The stack pointer is _not_ restored afterwards; callers set it themselves
/// if they keep running. Values that `block` leaves on the value stack remain
/// there below the unlock sequence.
fn with_temp_stack(
    body: &mut InstrSeqBuilder<'_>,
    memory: MemoryId,
    stack: &Stack,
    block: impl Fn(&mut InstrSeqBuilder<'_>),
) {
    use walrus::ir::*;

    // stack_pointer = stack.temp
    body.i32_const(stack.temp).global_set(stack.pointer);

    // Acquire the lock: retry the compare-exchange until it observes 0.
    body.loop_(None, |loop_| {
        let loop_id = loop_.id();

        loop_
            // cmpxchg(temp_lock, expected = 0, replacement = 1); the result
            // is the previous value of the lock word.
            .i32_const(stack.temp_lock)
            .i32_const(0)
            .i32_const(1)
            .cmpxchg(memory, AtomicWidth::I32, ATOMIC_MEM_ARG)
            .if_else(
                None,
                // Previous value was nonzero: somebody else holds the lock.
                // Wait while the word still equals 1, then try again.
                // NOTE(review): the -1 timeout presumably means "wait forever"
                // and `false` presumably selects the 32-bit wait — confirm
                // against the walrus `atomic_wait` documentation.
                |body| {
                    body.i32_const(stack.temp_lock)
                        .i32_const(1)
                        .i64_const(-1)
                        .atomic_wait(memory, ATOMIC_MEM_ARG, false)
                        .drop()
                        .br(loop_id);
                },
                // Previous value was 0: the lock is now ours.
                |_| {},
            );
    });

    block(body);

    // Release the lock (atomic store of 0) and wake up at most one waiter,
    // ignoring the number of waiters actually woken.
    body.i32_const(stack.temp_lock)
        .i32_const(0)
        .store(memory, StoreKind::I32 { atomic: true }, ATOMIC_MEM_ARG)
        .i32_const(stack.temp_lock)
        .i32_const(1)
        .atomic_notify(memory, ATOMIC_MEM_ARG)
        .drop();
}