pyodide_webassembly_runtime_layer/
func.rs1use std::{
2 any::TypeId,
3 marker::PhantomData,
4 sync::{Arc, Weak},
5};
6
7use pyo3::{exceptions::PyRuntimeError, prelude::*, types::PyTuple, PyTypeInfo};
8use pyo3_error::PyErrChain;
9use wasm_runtime_layer::{
10 backend::{AsContext, AsContextMut, Val, WasmFunc, WasmStoreContext},
11 FuncType,
12};
13use wobbly::sync::Wobbly;
14
15use crate::{
16 conversion::{py_to_js_proxy, ToPy, ValExt},
17 store::StoreContextMut,
18 ArgumentVec, Engine,
19};
20
/// A WebAssembly function handle backed by a Python reference to a callable.
///
/// The callable is either a JS proxy wrapping a host closure (created through
/// `WasmFunc::new`) or a function exported from an instance (created through
/// `Func::from_exported_function`).
#[derive(Debug)]
pub struct Func {
    /// Python handle to the callable that is invoked when this function is called.
    pyfunc: Py<PyAny>,
    /// WebAssembly signature of the function.
    ty: FuncType,
    /// Type id of the store's user state for host functions created through `new`; checked
    /// on every `call`. `None` for exported functions, which are not tied to a user state.
    user_state: Option<TypeId>,
}
34
35impl Clone for Func {
36 fn clone(&self) -> Self {
37 Python::attach(|py| Self {
38 pyfunc: self.pyfunc.clone_ref(py),
39 ty: self.ty.clone(),
40 user_state: self.user_state,
41 })
42 }
43}
44
45impl WasmFunc<Engine> for Func {
46 fn new<T: 'static>(
47 mut ctx: impl AsContextMut<Engine, UserState = T>,
48 ty: FuncType,
49 func: impl 'static
50 + Send
51 + Sync
52 + Fn(StoreContextMut<T>, &[Val<Engine>], &mut [Val<Engine>]) -> anyhow::Result<()>,
53 ) -> Self {
54 Python::attach(|py| -> Result<Self, PyErr> {
55 #[cfg(feature = "tracing")]
56 tracing::debug!("Func::new");
57
58 let mut store: StoreContextMut<T> = ctx.as_context_mut();
59
60 let weak_store = store.as_weak_proof();
61
62 let user_state = non_static_type_id(store.data());
63 let ty_clone = ty.clone();
64
65 let func = Arc::new(move |args: Bound<PyTuple>| -> Result<Py<PyAny>, PyErr> {
66 let py = args.py();
67
68 let Some(mut strong_store) = Weak::upgrade(&weak_store) else {
69 return Err(PyRuntimeError::new_err(
70 "host func called after free of its associated store",
71 ));
72 };
73
74 let store = unsafe { StoreContextMut::from_proof_unchecked(&mut strong_store) };
80
81 let ty = &ty_clone;
82
83 let args = ty
84 .params()
85 .iter()
86 .zip(args.iter())
87 .map(|(ty, arg)| Val::from_py_typed(arg, *ty))
88 .collect::<Result<ArgumentVec<_>, _>>()?;
89 let mut results = vec![Val::I32(0); ty.results().len()];
90
91 #[cfg(feature = "tracing")]
92 let _span = tracing::debug_span!("call_host", ?args, ?ty).entered();
93
94 match func(store, &args, &mut results) {
95 Ok(()) => {
96 #[cfg(feature = "tracing")]
97 tracing::debug!(?results, "result");
98 },
99 Err(err) => {
100 #[cfg(feature = "tracing")]
101 tracing::error!("{err:?}");
102 return Err(PyErrChain::pyerr_from_err(py, err));
103 },
104 }
105
106 let results = match results.as_slice() {
107 [] => py.None(),
108 [res] => res.to_py(py)?,
109 results => PyTuple::new(
110 py,
111 results
112 .iter()
113 .map(|res| res.to_py(py))
114 .collect::<Result<ArgumentVec<_>, PyErr>>()?,
115 )?
116 .into_any()
117 .unbind(),
118 };
119
120 Ok(results)
121 });
122
123 let func = Bound::new(
124 py,
125 PyHostFunc {
126 func: store.register_host_func(func),
127 #[cfg(feature = "tracing")]
128 ty: ty.clone(),
129 },
130 )?;
131 let func = py_to_js_proxy(func)?;
132
133 Ok(Self {
134 pyfunc: func.unbind(),
135 ty,
136 user_state: Some(user_state),
137 })
138 })
139 .expect("Func::new should not fail")
140 }
141
142 fn ty(&self, _ctx: impl AsContext<Engine>) -> FuncType {
143 self.ty.clone()
144 }
145
146 fn call<T>(
147 &self,
148 mut ctx: impl AsContextMut<Engine>,
149 args: &[Val<Engine>],
150 results: &mut [Val<Engine>],
151 ) -> anyhow::Result<()> {
152 Python::attach(|py| {
153 let store: StoreContextMut<_> = ctx.as_context_mut();
154
155 if let Some(user_state) = self.user_state {
156 assert_eq!(user_state, non_static_type_id(store.data()));
157 }
158
159 #[cfg(feature = "tracing")]
160 let _span = tracing::debug_span!("call_guest", ?args, ?self.ty).entered();
161
162 assert_eq!(self.ty.params().len(), args.len());
164 assert_eq!(self.ty.results().len(), results.len());
165
166 let args = args
167 .iter()
168 .map(|arg| arg.to_py(py))
169 .collect::<Result<ArgumentVec<_>, PyErr>>()?;
170 let args = PyTuple::new(py, args)?;
171
172 let res = self.pyfunc.bind(py).call1(args)?;
173
174 #[cfg(feature = "tracing")]
175 tracing::debug!(%res, ?self.ty);
176
177 match (self.ty.results(), results) {
178 ([], []) => (),
179 ([ty], [result]) => *result = Val::from_py_typed(res, *ty)?,
180 (tys, results) => {
181 let res: Bound<PyTuple> = PyTuple::type_object(py)
182 .call1((res,))?
183 .extract()
184 .map_err(PyErr::from)?;
185
186 assert_eq!(tys.len(), res.len());
188
189 for ((ty, result), value) in self
190 .ty
191 .results()
192 .iter()
193 .zip(results.iter_mut())
194 .zip(res.iter())
195 {
196 *result = Val::from_py_typed(value, *ty)?;
197 }
198 },
199 }
200
201 Ok(())
202 })
203 }
204}
205
206impl ToPy for Func {
207 fn to_py(&self, py: Python) -> Result<Py<PyAny>, PyErr> {
208 Ok(self.pyfunc.clone_ref(py))
209 }
210}
211
212impl Func {
213 pub(crate) fn from_exported_function(func: Bound<PyAny>, ty: FuncType) -> anyhow::Result<Self> {
215 if !func.is_callable() {
216 anyhow::bail!("expected WebAssembly.Function but found {func:?} which is not callable");
217 }
218
219 #[cfg(feature = "tracing")]
220 tracing::debug!(%func, ?ty, "Func::from_exported_function");
221
222 Ok(Self {
223 pyfunc: func.unbind(),
224 ty,
225 user_state: None,
226 })
227 }
228}
229
230pub type PyHostFuncFn = dyn 'static + Send + Sync + Fn(Bound<PyTuple>) -> Result<Py<PyAny>, PyErr>;
231
// Python-callable trampoline that forwards calls to a registered host closure.
// Plain `//` comments are used deliberately: `///` docs on a `#[pyclass]` would become the
// class's Python-visible `__doc__`.
#[pyclass(frozen)]
struct PyHostFunc {
    // Handle to the host closure. `upgrade` returning `None` is reported to Python as the
    // associated store having been freed.
    func: Wobbly<PyHostFuncFn>,
    // Signature of the wrapped function, kept only for tracing output.
    #[cfg(feature = "tracing")]
    ty: FuncType,
}
238
239#[pymethods]
240impl PyHostFunc {
241 #[pyo3(signature = (*args))]
242 fn __call__(&self, args: Bound<PyTuple>) -> Result<Py<PyAny>, PyErr> {
243 #[cfg(feature = "tracing")]
244 let _span = tracing::debug_span!("call_trampoline", ?self.ty, args = %args).entered();
245
246 let Some(func) = self.func.upgrade() else {
247 return Err(PyRuntimeError::new_err(
248 "weak host func called after free of its associated store",
249 ));
250 };
251
252 func(args)
253 }
254}
255
/// Returns the [`TypeId`] of `T` without requiring `T: 'static`.
///
/// `TypeId::of` normally needs a `'static` type. Here a `PhantomData<T>` is coerced to a trait
/// object, and that object's lifetime is transmuted to `'static`. This satisfies the
/// `Self: 'static` bound on `get_type_id`, whose dynamically dispatched implementation then
/// evaluates `TypeId::of::<T>()`.
///
/// NOTE(review): lifetimes are erased before code generation, so presumably types that differ
/// only in their lifetime parameters map to the same id. That is fine for the user-state
/// consistency check in this file, but confirm it before relying on this elsewhere.
fn non_static_type_id<T: ?Sized>(_x: &T) -> TypeId {
    // Helper trait: calling `get_type_id` requires the receiver type to be `'static`, which a
    // trait object borrowed from a local is not. That is why the transmute below is needed.
    trait NonStaticAny {
        fn get_type_id(&self) -> TypeId
        where
            Self: 'static;
    }

    impl<T: ?Sized> NonStaticAny for PhantomData<T> {
        fn get_type_id(&self) -> TypeId
        where
            Self: 'static,
        {
            TypeId::of::<T>()
        }
    }

    let phantom_data = PhantomData::<T>;
    NonStaticAny::get_type_id(unsafe {
        // SAFETY: `PhantomData<T>` is zero-sized and stores no `T`. `get_type_id` only
        // computes a `TypeId` and reads nothing through the reference, so extending the trait
        // object's lifetime to `'static` cannot let borrowed data be observed after it ends.
        core::mem::transmute::<&dyn NonStaticAny, &(dyn NonStaticAny + 'static)>(&phantom_data)
    })
}