Skip to main content

pyodide_webassembly_runtime_layer/
func.rs

1use std::{
2    any::TypeId,
3    marker::PhantomData,
4    sync::{Arc, Weak},
5};
6
7use pyo3::{exceptions::PyRuntimeError, prelude::*, types::PyTuple, PyTypeInfo};
8use pyo3_error::PyErrChain;
9use wasm_runtime_layer::{
10    backend::{AsContext, AsContextMut, Val, WasmFunc, WasmStoreContext},
11    FuncType,
12};
13use wobbly::sync::Wobbly;
14
15use crate::{
16    conversion::{py_to_js_proxy, ToPy, ValExt},
17    store::StoreContextMut,
18    ArgumentVec, Engine,
19};
20
/// A bound function, which may be an export from a WASM [`Instance`] or a host
/// function.
///
/// Cloning a `Func` shares the underlying Python object by reference rather
/// than copying it.
///
/// [`Instance`]: crate::instance::Instance
#[derive(Debug)]
pub struct Func {
    /// The Python callable invoked by [`WasmFunc::call`]. For host functions
    /// this is the proxied trampoline created in `Func::new`; for exports it is
    /// the callable handed to `Func::from_exported_function`.
    pyfunc: Py<PyAny>,
    /// The function signature, used to convert values between Python and WASM
    /// and to validate call arity.
    ty: FuncType,
    /// `Some(type id of the store's user state)` for host functions, so calls
    /// can verify they run against a store of the same user state type;
    /// `None` for exported functions, which skip that check.
    user_state: Option<TypeId>,
}
34
35impl Clone for Func {
36    fn clone(&self) -> Self {
37        Python::attach(|py| Self {
38            pyfunc: self.pyfunc.clone_ref(py),
39            ty: self.ty.clone(),
40            user_state: self.user_state,
41        })
42    }
43}
44
45impl WasmFunc<Engine> for Func {
46    fn new<T: 'static>(
47        mut ctx: impl AsContextMut<Engine, UserState = T>,
48        ty: FuncType,
49        func: impl 'static
50            + Send
51            + Sync
52            + Fn(StoreContextMut<T>, &[Val<Engine>], &mut [Val<Engine>]) -> anyhow::Result<()>,
53    ) -> Self {
54        Python::attach(|py| -> Result<Self, PyErr> {
55            #[cfg(feature = "tracing")]
56            tracing::debug!("Func::new");
57
58            let mut store: StoreContextMut<T> = ctx.as_context_mut();
59
60            let weak_store = store.as_weak_proof();
61
62            let user_state = non_static_type_id(store.data());
63            let ty_clone = ty.clone();
64
65            let func = Arc::new(move |args: Bound<PyTuple>| -> Result<Py<PyAny>, PyErr> {
66                let py = args.py();
67
68                let Some(mut strong_store) = Weak::upgrade(&weak_store) else {
69                    return Err(PyRuntimeError::new_err(
70                        "host func called after free of its associated store",
71                    ));
72                };
73
74                // Safety:
75                //
76                // - The proof is constructed from a mutable store context
77                // - Calling a host function (from the host or from WASM) provides that call
78                //   with a mutable reborrow of the store context
79                let store = unsafe { StoreContextMut::from_proof_unchecked(&mut strong_store) };
80
81                let ty = &ty_clone;
82
83                let args = ty
84                    .params()
85                    .iter()
86                    .zip(args.iter())
87                    .map(|(ty, arg)| Val::from_py_typed(arg, *ty))
88                    .collect::<Result<ArgumentVec<_>, _>>()?;
89                let mut results = vec![Val::I32(0); ty.results().len()];
90
91                #[cfg(feature = "tracing")]
92                let _span = tracing::debug_span!("call_host", ?args, ?ty).entered();
93
94                match func(store, &args, &mut results) {
95                    Ok(()) => {
96                        #[cfg(feature = "tracing")]
97                        tracing::debug!(?results, "result");
98                    },
99                    Err(err) => {
100                        #[cfg(feature = "tracing")]
101                        tracing::error!("{err:?}");
102                        return Err(PyErrChain::pyerr_from_err(py, err));
103                    },
104                }
105
106                let results = match results.as_slice() {
107                    [] => py.None(),
108                    [res] => res.to_py(py)?,
109                    results => PyTuple::new(
110                        py,
111                        results
112                            .iter()
113                            .map(|res| res.to_py(py))
114                            .collect::<Result<ArgumentVec<_>, PyErr>>()?,
115                    )?
116                    .into_any()
117                    .unbind(),
118                };
119
120                Ok(results)
121            });
122
123            let func = Bound::new(
124                py,
125                PyHostFunc {
126                    func: store.register_host_func(func),
127                    #[cfg(feature = "tracing")]
128                    ty: ty.clone(),
129                },
130            )?;
131            let func = py_to_js_proxy(func)?;
132
133            Ok(Self {
134                pyfunc: func.unbind(),
135                ty,
136                user_state: Some(user_state),
137            })
138        })
139        .expect("Func::new should not fail")
140    }
141
142    fn ty(&self, _ctx: impl AsContext<Engine>) -> FuncType {
143        self.ty.clone()
144    }
145
146    fn call<T>(
147        &self,
148        mut ctx: impl AsContextMut<Engine>,
149        args: &[Val<Engine>],
150        results: &mut [Val<Engine>],
151    ) -> anyhow::Result<()> {
152        Python::attach(|py| {
153            let store: StoreContextMut<_> = ctx.as_context_mut();
154
155            if let Some(user_state) = self.user_state {
156                assert_eq!(user_state, non_static_type_id(store.data()));
157            }
158
159            #[cfg(feature = "tracing")]
160            let _span = tracing::debug_span!("call_guest", ?args, ?self.ty).entered();
161
162            // https://webassembly.github.io/spec/js-api/#exported-function-exotic-objects
163            assert_eq!(self.ty.params().len(), args.len());
164            assert_eq!(self.ty.results().len(), results.len());
165
166            let args = args
167                .iter()
168                .map(|arg| arg.to_py(py))
169                .collect::<Result<ArgumentVec<_>, PyErr>>()?;
170            let args = PyTuple::new(py, args)?;
171
172            let res = self.pyfunc.bind(py).call1(args)?;
173
174            #[cfg(feature = "tracing")]
175            tracing::debug!(%res, ?self.ty);
176
177            match (self.ty.results(), results) {
178                ([], []) => (),
179                ([ty], [result]) => *result = Val::from_py_typed(res, *ty)?,
180                (tys, results) => {
181                    let res: Bound<PyTuple> = PyTuple::type_object(py)
182                        .call1((res,))?
183                        .extract()
184                        .map_err(PyErr::from)?;
185
186                    // https://webassembly.github.io/spec/js-api/#exported-function-exotic-objects
187                    assert_eq!(tys.len(), res.len());
188
189                    for ((ty, result), value) in self
190                        .ty
191                        .results()
192                        .iter()
193                        .zip(results.iter_mut())
194                        .zip(res.iter())
195                    {
196                        *result = Val::from_py_typed(value, *ty)?;
197                    }
198                },
199            }
200
201            Ok(())
202        })
203    }
204}
205
206impl ToPy for Func {
207    fn to_py(&self, py: Python) -> Result<Py<PyAny>, PyErr> {
208        Ok(self.pyfunc.clone_ref(py))
209    }
210}
211
212impl Func {
213    /// Creates a new function from a Python value
214    pub(crate) fn from_exported_function(func: Bound<PyAny>, ty: FuncType) -> anyhow::Result<Self> {
215        if !func.is_callable() {
216            anyhow::bail!("expected WebAssembly.Function but found {func:?} which is not callable");
217        }
218
219        #[cfg(feature = "tracing")]
220        tracing::debug!(%func, ?ty, "Func::from_exported_function");
221
222        Ok(Self {
223            pyfunc: func.unbind(),
224            ty,
225            user_state: None,
226        })
227    }
228}
229
/// Type-erased host trampoline invoked by `PyHostFunc::__call__`.
///
/// It receives the positional Python arguments as a tuple. It returns the
/// Python result value, or a Python exception.
pub type PyHostFuncFn = dyn 'static + Send + Sync + Fn(Bound<PyTuple>) -> Result<Py<PyAny>, PyErr>;
231
/// Python-callable wrapper around a host function registered with a store.
#[pyclass(frozen)]
struct PyHostFunc {
    /// Handle to the host closure registered with the store. `upgrade()`
    /// yields `None` once the closure is no longer available, which the call
    /// path reports as the store having been freed.
    func: Wobbly<PyHostFuncFn>,
    /// The host function's signature, kept only for tracing output.
    #[cfg(feature = "tracing")]
    ty: FuncType,
}
238
239#[pymethods]
240impl PyHostFunc {
241    #[pyo3(signature = (*args))]
242    fn __call__(&self, args: Bound<PyTuple>) -> Result<Py<PyAny>, PyErr> {
243        #[cfg(feature = "tracing")]
244        let _span = tracing::debug_span!("call_trampoline", ?self.ty, args = %args).entered();
245
246        let Some(func) = self.func.upgrade() else {
247            return Err(PyRuntimeError::new_err(
248                "weak host func called after free of its associated store",
249            ));
250        };
251
252        func(args)
253    }
254}
255
// Courtesy of David Tolnay:
// https://github.com/rust-lang/rust/issues/41875#issuecomment-317292888
/// Returns the [`TypeId`] of `T` (the type of `_x`) without requiring
/// `T: 'static`.
///
/// Only the argument's type matters; its value is never read. In this file the
/// result is compared against a store's user state type id.
///
/// NOTE(review): lifetimes are erased, so types that differ only in their
/// lifetime parameters presumably produce the same id. That is acceptable for
/// the equality check it backs here, but confirm before relying on it elsewhere.
fn non_static_type_id<T: ?Sized>(_x: &T) -> TypeId {
    /// Like `Any`, but its method is only callable when `Self: 'static` can be
    /// shown, which the transmute below provides.
    trait NonStaticAny {
        fn get_type_id(&self) -> TypeId
        where
            Self: 'static;
    }

    impl<T: ?Sized> NonStaticAny for PhantomData<T> {
        fn get_type_id(&self) -> TypeId
        where
            Self: 'static,
        {
            TypeId::of::<T>()
        }
    }

    let phantom_data = PhantomData::<T>;
    // SAFETY: the transmute only extends the trait object's lifetime bound to
    // `'static`. `PhantomData<T>` holds no data, and `get_type_id` never reads
    // `self`, so no value with a shorter lifetime is accessed through the
    // extended reference.
    NonStaticAny::get_type_id(unsafe {
        core::mem::transmute::<&dyn NonStaticAny, &(dyn NonStaticAny + 'static)>(&phantom_data)
    })
}