pyodide-webassembly-runtime-layer 0.13.0

A WASM runtime compatibility interface implementation for the web browser's WebAssembly runtime, exposed through Pyodide.
Documentation
use std::{
    any::TypeId,
    marker::PhantomData,
    sync::{Arc, Weak},
};

use pyo3::{exceptions::PyRuntimeError, prelude::*, types::PyTuple, PyTypeInfo};
use pyo3_error::PyErrChain;
use wasm_runtime_layer::{
    backend::{AsContext, AsContextMut, Val, WasmFunc, WasmStoreContext},
    FuncType,
};
use wobbly::sync::Wobbly;

use crate::{
    conversion::{py_to_js_proxy, ToPy, ValExt},
    store::StoreContextMut,
    ArgumentVec, Engine,
};

/// A bound function, which may be an export from a WASM [`Instance`] or a host
/// function.
///
/// Cloning a [`Func`] attaches to the Python interpreter to create a new
/// reference to the wrapped Python object.
///
/// [`Instance`]: crate::instance::Instance
#[derive(Debug)]
pub struct Func {
    /// The inner function: either the JS proxy of a host-function trampoline
    /// (created by `Func::new`) or a callable WebAssembly export (created by
    /// `Func::from_exported_function`)
    pyfunc: Py<PyAny>,
    /// The function signature
    ty: FuncType,
    /// The user state type of the context. This is `Some` for host functions,
    /// and calls assert that the store passed in has the same user state type.
    /// It is `None` for exported functions, which skip that check.
    user_state: Option<TypeId>,
}

impl Clone for Func {
    /// Produces another handle to the same Python callable.
    ///
    /// Only the reference-count bump needs the interpreter. The signature and
    /// user-state id are plain Rust data and are copied outside of it.
    fn clone(&self) -> Self {
        let pyfunc = Python::attach(|py| self.pyfunc.clone_ref(py));

        Self {
            pyfunc,
            ty: self.ty.clone(),
            user_state: self.user_state,
        }
    }
}

impl WasmFunc<Engine> for Func {
    fn new<T: 'static>(
        mut ctx: impl AsContextMut<Engine, UserState = T>,
        ty: FuncType,
        func: impl 'static
            + Send
            + Sync
            + Fn(StoreContextMut<T>, &[Val<Engine>], &mut [Val<Engine>]) -> anyhow::Result<()>,
    ) -> Self {
        Python::attach(|py| -> Result<Self, PyErr> {
            #[cfg(feature = "tracing")]
            tracing::debug!("Func::new");

            let mut store: StoreContextMut<T> = ctx.as_context_mut();

            let weak_store = store.as_weak_proof();

            let user_state = non_static_type_id(store.data());
            let ty_clone = ty.clone();

            let func = Arc::new(move |args: Bound<PyTuple>| -> Result<Py<PyAny>, PyErr> {
                let py = args.py();

                let Some(mut strong_store) = Weak::upgrade(&weak_store) else {
                    return Err(PyRuntimeError::new_err(
                        "host func called after free of its associated store",
                    ));
                };

                // Safety:
                //
                // - The proof is constructed from a mutable store context
                // - Calling a host function (from the host or from WASM) provides that call
                //   with a mutable reborrow of the store context
                let store = unsafe { StoreContextMut::from_proof_unchecked(&mut strong_store) };

                let ty = &ty_clone;

                let args = ty
                    .params()
                    .iter()
                    .zip(args.iter())
                    .map(|(ty, arg)| Val::from_py_typed(arg, *ty))
                    .collect::<Result<ArgumentVec<_>, _>>()?;
                let mut results = vec![Val::I32(0); ty.results().len()];

                #[cfg(feature = "tracing")]
                let _span = tracing::debug_span!("call_host", ?args, ?ty).entered();

                match func(store, &args, &mut results) {
                    Ok(()) => {
                        #[cfg(feature = "tracing")]
                        tracing::debug!(?results, "result");
                    },
                    Err(err) => {
                        #[cfg(feature = "tracing")]
                        tracing::error!("{err:?}");
                        return Err(PyErrChain::pyerr_from_err(py, err));
                    },
                }

                let results = match results.as_slice() {
                    [] => py.None(),
                    [res] => res.to_py(py)?,
                    results => PyTuple::new(
                        py,
                        results
                            .iter()
                            .map(|res| res.to_py(py))
                            .collect::<Result<ArgumentVec<_>, PyErr>>()?,
                    )?
                    .into_any()
                    .unbind(),
                };

                Ok(results)
            });

            let func = Bound::new(
                py,
                PyHostFunc {
                    func: store.register_host_func(func),
                    #[cfg(feature = "tracing")]
                    ty: ty.clone(),
                },
            )?;
            let func = py_to_js_proxy(func)?;

            Ok(Self {
                pyfunc: func.unbind(),
                ty,
                user_state: Some(user_state),
            })
        })
        .expect("Func::new should not fail")
    }

    fn ty(&self, _ctx: impl AsContext<Engine>) -> FuncType {
        self.ty.clone()
    }

    fn call<T>(
        &self,
        mut ctx: impl AsContextMut<Engine>,
        args: &[Val<Engine>],
        results: &mut [Val<Engine>],
    ) -> anyhow::Result<()> {
        Python::attach(|py| {
            let store: StoreContextMut<_> = ctx.as_context_mut();

            if let Some(user_state) = self.user_state {
                assert_eq!(user_state, non_static_type_id(store.data()));
            }

            #[cfg(feature = "tracing")]
            let _span = tracing::debug_span!("call_guest", ?args, ?self.ty).entered();

            // https://webassembly.github.io/spec/js-api/#exported-function-exotic-objects
            assert_eq!(self.ty.params().len(), args.len());
            assert_eq!(self.ty.results().len(), results.len());

            let args = args
                .iter()
                .map(|arg| arg.to_py(py))
                .collect::<Result<ArgumentVec<_>, PyErr>>()?;
            let args = PyTuple::new(py, args)?;

            let res = self.pyfunc.bind(py).call1(args)?;

            #[cfg(feature = "tracing")]
            tracing::debug!(%res, ?self.ty);

            match (self.ty.results(), results) {
                ([], []) => (),
                ([ty], [result]) => *result = Val::from_py_typed(res, *ty)?,
                (tys, results) => {
                    let res: Bound<PyTuple> = PyTuple::type_object(py)
                        .call1((res,))?
                        .extract()
                        .map_err(PyErr::from)?;

                    // https://webassembly.github.io/spec/js-api/#exported-function-exotic-objects
                    assert_eq!(tys.len(), res.len());

                    for ((ty, result), value) in self
                        .ty
                        .results()
                        .iter()
                        .zip(results.iter_mut())
                        .zip(res.iter())
                    {
                        *result = Val::from_py_typed(value, *ty)?;
                    }
                },
            }

            Ok(())
        })
    }
}

impl ToPy for Func {
    /// Hands out a fresh strong reference to the wrapped Python callable.
    fn to_py(&self, py: Python) -> Result<Py<PyAny>, PyErr> {
        let callable = self.pyfunc.clone_ref(py);
        Ok(callable)
    }
}

impl Func {
    /// Wraps a callable Python object, expected to be a WebAssembly function
    /// export, as a [`Func`] with signature `ty`.
    ///
    /// The result records no user state type, so [`WasmFunc::call`] does not
    /// check the store's user state type for it.
    ///
    /// # Errors
    ///
    /// Returns an error if `func` is not callable.
    pub(crate) fn from_exported_function(func: Bound<PyAny>, ty: FuncType) -> anyhow::Result<Self> {
        anyhow::ensure!(
            func.is_callable(),
            "expected WebAssembly.Function but found {func:?} which is not callable"
        );

        #[cfg(feature = "tracing")]
        tracing::debug!(%func, ?ty, "Func::from_exported_function");

        let pyfunc = func.unbind();
        Ok(Self {
            pyfunc,
            ty,
            user_state: None,
        })
    }
}

/// Type-erased closure invoked by the [`PyHostFunc`] trampoline.
///
/// It receives the call's positional Python arguments as a tuple and returns
/// either the Python result or a Python exception.
pub type PyHostFuncFn = dyn 'static + Send + Sync + Fn(Bound<PyTuple>) -> Result<Py<PyAny>, PyErr>;

/// Python-callable trampoline that forwards calls to a registered host function.
///
/// The host closure is held through a [`Wobbly`] handle. If the handle can no
/// longer be upgraded when the trampoline is called, a Python `RuntimeError`
/// is raised.
#[pyclass(frozen)]
struct PyHostFunc {
    /// Handle to the host closure, as returned by the store's `register_host_func`
    func: Wobbly<PyHostFuncFn>,
    /// The host function's signature, kept only for tracing output
    #[cfg(feature = "tracing")]
    ty: FuncType,
}

#[pymethods]
impl PyHostFunc {
    /// Invokes the wrapped host closure with all positional arguments.
    ///
    /// Raises a `RuntimeError` if the closure is no longer reachable.
    #[pyo3(signature = (*args))]
    fn __call__(&self, args: Bound<PyTuple>) -> Result<Py<PyAny>, PyErr> {
        #[cfg(feature = "tracing")]
        let _span = tracing::debug_span!("call_trampoline", ?self.ty, args = %args).entered();

        match self.func.upgrade() {
            Some(host_fn) => host_fn(args),
            None => Err(PyRuntimeError::new_err(
                "weak host func called after free of its associated store",
            )),
        }
    }
}

// Courtesy of David Tolnay:
// https://github.com/rust-lang/rust/issues/41875#issuecomment-317292888
/// Returns the [`TypeId`] of `T` without requiring `T: 'static`.
///
/// [`TypeId::of`] normally requires a `'static` type. This function calls a
/// `Self: 'static`-bounded trait method through a `PhantomData<T>` trait
/// object whose lifetime bound has been transmuted to `'static`.
///
/// NOTE(review): `TypeId` does not encode lifetimes. Types that differ only in
/// their lifetime parameters presumably yield the same id here. In this file
/// it is only used to compare store user-state types.
fn non_static_type_id<T: ?Sized>(_x: &T) -> TypeId {
    // Helper trait whose method can only be called on `'static` receivers.
    trait NonStaticAny {
        fn get_type_id(&self) -> TypeId
        where
            Self: 'static;
    }

    impl<T: ?Sized> NonStaticAny for PhantomData<T> {
        fn get_type_id(&self) -> TypeId
        where
            Self: 'static,
        {
            TypeId::of::<T>()
        }
    }

    let phantom_data = PhantomData::<T>;
    // SAFETY: the transmute only widens the trait object's lifetime bound to
    // `'static` so that `get_type_id` can be called. The reference is used
    // solely for that call and does not escape this function, and
    // `get_type_id` never reads through `self`; it only returns
    // `TypeId::of::<T>()`.
    NonStaticAny::get_type_id(unsafe {
        core::mem::transmute::<&dyn NonStaticAny, &(dyn NonStaticAny + 'static)>(&phantom_data)
    })
}