Skip to main content

wrpc_runtime_wasmtime/
lib.rs

1#![allow(clippy::type_complexity)] // TODO: https://github.com/bytecodealliance/wrpc/issues/2
2
3use core::any::Any;
4use core::borrow::Borrow;
5use core::fmt;
6use core::future::Future;
7use core::iter::zip;
8use core::pin::pin;
9use core::time::Duration;
10
11use std::collections::{BTreeMap, HashMap};
12use std::sync::Arc;
13
14use anyhow::anyhow;
15use bytes::{Bytes, BytesMut};
16use futures::future::try_join_all;
17use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt as _};
18use tokio_util::codec::Encoder;
19use tracing::{debug, instrument, trace, warn};
20use uuid::Uuid;
21use wasmtime::component::{
22    types, Func, Resource, ResourceAny, ResourceTable, ResourceType, Type, Val,
23};
24use wasmtime::error::Context as _;
25use wasmtime::{bail, AsContextMut, Engine};
26use wrpc_transport::Invoke;
27
28use crate::bindings::rpc::context::Context;
29use crate::bindings::rpc::error::Error;
30use crate::bindings::rpc::transport::{IncomingChannel, Invocation, OutgoingChannel};
31
32pub mod bindings;
33mod codec;
34pub mod paths;
35mod polyfill;
36pub mod rpc;
37mod serve;
38
39pub use codec::*;
40pub use polyfill::*;
41pub use serve::*;
42
/// Maps a wasmtime component function name to the name used for RPC.
///
/// Resource-related functions carry a `[constructor]`, `[static]` or `[method]` prefix in
/// wasmtime. The first of those prefixes that matches (checked in exactly that order) is stripped;
/// any other name is returned unchanged. This string-level approach is used because
/// [`types::ComponentFunc`] does not expose the function kind, and (re-)parsing the WIT just to
/// recover it is undesirable.
fn rpc_func_name(name: &str) -> &str {
    const KIND_PREFIXES: [&str; 3] = ["[constructor]", "[static]", "[method]"];
    KIND_PREFIXES
        .iter()
        .copied()
        .find_map(|prefix| name.strip_prefix(prefix))
        .unwrap_or(name)
}
57
58fn rpc_result_type<T: Borrow<Type>>(
59    host_resources: &HashMap<Box<str>, HashMap<Box<str>, (ResourceType, ResourceType)>>,
60    results_ty: impl IntoIterator<Item = T>,
61) -> Option<Option<Type>> {
62    let rpc_err_ty = host_resources
63        .get("wrpc:rpc/error@0.1.0")
64        .and_then(|instance| instance.get("error"));
65    let mut results_ty = results_ty.into_iter();
66    match (
67        rpc_err_ty,
68        results_ty.next().as_ref().map(Borrow::borrow),
69        results_ty.next(),
70    ) {
71        (Some((guest_rpc_err_ty, host_rpc_err_ty)), Some(Type::Result(result_ty)), None)
72            if *host_rpc_err_ty == ResourceType::host::<Error>()
73                && result_ty.err() == Some(Type::Own(*guest_rpc_err_ty)) =>
74        {
75            Some(result_ty.ok())
76        }
77        _ => None,
78    }
79}
80
/// An opaque byte payload wrapped in a dedicated type.
///
/// NOTE(review): judging by the name, this presumably holds the encoded handle of a resource that
/// lives on a remote peer — confirm against its uses elsewhere in the crate.
pub struct RemoteResource(pub Bytes);

/// A table of shared resources exported by the component
///
/// Maps a [`Uuid`] key to the component's [`ResourceAny`] handle.
#[derive(Debug, Default)]
pub struct SharedResourceTable(HashMap<Uuid, ResourceAny>);
86
/// Host-side wRPC state for a store.
///
/// Provides the [`Invoke`] client `T` together with its invocation context, the table of shared
/// exported resources and an optional invocation timeout.
pub trait WrpcCtx<T: Invoke>: Send {
    /// Returns context to use for invocation
    fn context(&self) -> T::Context;

    /// Returns an [Invoke] implementation used to satisfy polyfilled imports
    fn client(&self) -> &T;

    /// Returns a table of shared exported resources
    fn shared_resources(&mut self) -> &mut SharedResourceTable;

    /// Optional invocation timeout, component will trap if invocation is not finished within the
    /// returned [Duration]. If this method returns [None], then no timeout will be used.
    ///
    /// The default implementation returns [None].
    fn timeout(&self) -> Option<Duration> {
        None
    }
}
103
/// Simultaneous mutable borrows of a [`WrpcCtx`] and of the [`ResourceTable`] holding wRPC values.
///
/// The table stores invocations, channels, errors and contexts (see [`WrpcViewExt`]).
pub struct WrpcCtxView<'a, T: Invoke> {
    /// The wRPC context (client, invocation context, shared resources, timeout).
    pub ctx: &'a mut dyn WrpcCtx<T>,
    /// Resource table used to store wRPC host resources.
    pub table: &'a mut ResourceTable,
}
108
/// Access to a [`WrpcCtxView`].
///
/// [`call`] requires this of the store data (`C::Data: WrpcView`).
pub trait WrpcView: Send {
    /// The [`Invoke`] implementation used by the underlying [`WrpcCtx`].
    type Invoke: Invoke;

    /// Borrows the wRPC context and resource table.
    fn wrpc(&mut self) -> WrpcCtxView<'_, Self::Invoke>;
}
114
115impl<T: WrpcView> WrpcView for &mut T {
116    type Invoke = T::Invoke;
117
118    fn wrpc(&mut self) -> WrpcCtxView<'_, Self::Invoke> {
119        T::wrpc(self)
120    }
121}
122
/// Helpers on top of [`WrpcView`] for storing wRPC values in its [`ResourceTable`].
///
/// Every method pushes to, reads from or deletes from the table returned by [`WrpcView::wrpc`].
/// Values whose concrete type depends on [`WrpcView::Invoke`] (invocation results, channels,
/// contexts) are stored boxed and type-erased. They are recovered with a downcast, which fails with
/// an "invalid ... type" error if the stored type does not match.
pub trait WrpcViewExt: WrpcView {
    /// Stores a pending invocation in the table as [`Invocation::Future`].
    ///
    /// The future's output is boxed as `Box<dyn Any + Send>`, so the table entry does not depend
    /// on the concrete transport types.
    ///
    /// # Errors
    /// Fails if the resource cannot be pushed to the table.
    fn push_invocation(
        &mut self,
        invocation: impl Future<
                Output = anyhow::Result<(
                    <Self::Invoke as Invoke>::Outgoing,
                    <Self::Invoke as Invoke>::Incoming,
                )>,
            > + Send
            + 'static,
    ) -> wasmtime::Result<Resource<Invocation>> {
        self.wrpc()
            .table
            .push(Invocation::Future(Box::pin(async move {
                let res = invocation.await;
                Box::new(res) as Box<dyn Any + Send>
            })))
            .context("failed to push invocation to table")
    }

    /// Returns the result of a completed invocation without removing it from the table.
    ///
    /// Returns `Ok(None)` while the invocation is still [`Invocation::Future`], and `Ok(Some(..))`
    /// once it is [`Invocation::Ready`].
    ///
    /// # Errors
    /// Fails if the resource is not in the table or the stored value has an unexpected type.
    ///
    /// NOTE(review): the return type makes `downcast_ref` target
    /// `Box<anyhow::Result<(Outgoing, Incoming)>>`. However, `push_invocation` boxes a bare
    /// `anyhow::Result<(Outgoing, Incoming)>`, and `delete_invocation` downcasts to that bare
    /// type. If `Invocation::Ready` holds the value produced by the pushed future, this downcast
    /// cannot succeed and the "invalid invocation type" error is returned. Confirm how `Ready` is
    /// populated before relying on this method.
    fn get_invocation_result(
        &mut self,
        invocation: &Resource<Invocation>,
    ) -> wasmtime::Result<
        Option<
            &Box<
                anyhow::Result<(
                    <Self::Invoke as Invoke>::Outgoing,
                    <Self::Invoke as Invoke>::Incoming,
                )>,
            >,
        >,
    > {
        let invocation = self
            .wrpc()
            .table
            .get(invocation)
            .context("failed to get invocation from table")?;
        match invocation {
            Invocation::Future(..) => Ok(None),
            Invocation::Ready(res) => {
                // The downcast target type is inferred from this method's return type.
                let res = res.downcast_ref().context("invalid invocation type")?;
                Ok(Some(res))
            }
        }
    }

    /// Removes an invocation from the table and returns a future that resolves to its result.
    ///
    /// The table entry is removed immediately, before the returned future is polled. If the
    /// invocation is still pending, the future awaits it.
    ///
    /// # Errors
    /// Fails immediately if the resource is not in the table. The returned future fails if the
    /// stored result has an unexpected type, and otherwise yields the invocation's own result.
    fn delete_invocation(
        &mut self,
        invocation: Resource<Invocation>,
    ) -> wasmtime::Result<
        impl Future<
            Output = anyhow::Result<(
                <Self::Invoke as Invoke>::Outgoing,
                <Self::Invoke as Invoke>::Incoming,
            )>,
        >,
    > {
        let invocation = self
            .wrpc()
            .table
            .delete(invocation)
            .context("failed to delete invocation from table")?;
        Ok(async move {
            let res = match invocation {
                Invocation::Future(fut) => fut.await,
                Invocation::Ready(res) => res,
            };
            let res = res
                .downcast()
                .map_err(|_| anyhow!("invalid invocation type"))?;
            *res
        })
    }

    /// Stores an outgoing transport channel in the table.
    ///
    /// The channel is wrapped in `Arc<RwLock<Box<..>>>` inside an [`OutgoingChannel`].
    ///
    /// # Errors
    /// Fails if the resource cannot be pushed to the table.
    fn push_outgoing_channel(
        &mut self,
        outgoing: <Self::Invoke as Invoke>::Outgoing,
    ) -> wasmtime::Result<Resource<OutgoingChannel>> {
        self.wrpc()
            .table
            .push(OutgoingChannel(Arc::new(std::sync::RwLock::new(Box::new(
                outgoing,
            )))))
            .context("failed to push outgoing channel to table")
    }

    /// Removes an outgoing channel from the table and returns the unwrapped transport value.
    ///
    /// # Errors
    /// Fails in any of these cases:
    /// - the resource is not in the table;
    /// - another clone of the channel's `Arc` is still alive (reported as an active stream);
    /// - the lock is poisoned;
    /// - the stored value has an unexpected type.
    ///
    /// Note that the table entry has already been removed when the later checks fail.
    fn delete_outgoing_channel(
        &mut self,
        outgoing: Resource<OutgoingChannel>,
    ) -> wasmtime::Result<<Self::Invoke as Invoke>::Outgoing> {
        let OutgoingChannel(outgoing) = self
            .wrpc()
            .table
            .delete(outgoing)
            .context("failed to delete outgoing channel from table")?;
        // `Arc::into_inner` only succeeds if this is the last strong reference.
        let outgoing =
            Arc::into_inner(outgoing).context("outgoing channel has an active stream")?;
        let Ok(outgoing) = outgoing.into_inner() else {
            bail!("lock poisoned");
        };
        let outgoing = outgoing
            .downcast()
            .map_err(|_| wasmtime::Error::msg("invalid outgoing channel type"))?;
        Ok(*outgoing)
    }

    /// Stores an incoming transport channel in the table.
    ///
    /// The channel is wrapped in `Arc<RwLock<Box<..>>>` inside an [`IncomingChannel`].
    ///
    /// # Errors
    /// Fails if the resource cannot be pushed to the table.
    fn push_incoming_channel(
        &mut self,
        incoming: <Self::Invoke as Invoke>::Incoming,
    ) -> wasmtime::Result<Resource<IncomingChannel>> {
        self.wrpc()
            .table
            .push(IncomingChannel(Arc::new(std::sync::RwLock::new(Box::new(
                incoming,
            )))))
            .context("failed to push incoming channel to table")
    }

    /// Removes an incoming channel from the table and returns the unwrapped transport value.
    ///
    /// # Errors
    /// Fails in any of these cases:
    /// - the resource is not in the table;
    /// - another clone of the channel's `Arc` is still alive (reported as an active stream);
    /// - the lock is poisoned;
    /// - the stored value has an unexpected type.
    ///
    /// Note that the table entry has already been removed when the later checks fail.
    fn delete_incoming_channel(
        &mut self,
        incoming: Resource<IncomingChannel>,
    ) -> wasmtime::Result<<Self::Invoke as Invoke>::Incoming> {
        let IncomingChannel(incoming) = self
            .wrpc()
            .table
            .delete(incoming)
            .context("failed to delete incoming channel from table")?;
        // `Arc::into_inner` only succeeds if this is the last strong reference.
        let incoming =
            Arc::into_inner(incoming).context("incoming channel has an active stream")?;
        let Ok(incoming) = incoming.into_inner() else {
            bail!("lock poisoned");
        };
        let incoming = incoming
            .downcast()
            .map_err(|_| wasmtime::Error::msg("invalid incoming channel type"))?;
        Ok(*incoming)
    }

    /// Stores an [`Error`] in the table.
    ///
    /// # Errors
    /// Fails if the resource cannot be pushed to the table.
    fn push_error(&mut self, error: Error) -> wasmtime::Result<Resource<Error>> {
        self.wrpc()
            .table
            .push(error)
            .context("failed to push error to table")
    }

    /// Borrows an [`Error`] stored in the table.
    ///
    /// # Errors
    /// Fails if the resource is not in the table.
    fn get_error(&mut self, error: &Resource<Error>) -> wasmtime::Result<&Error> {
        let error = self
            .wrpc()
            .table
            .get(error)
            .context("failed to get error from table")?;
        Ok(error)
    }

    /// Mutably borrows an [`Error`] stored in the table.
    ///
    /// # Errors
    /// Fails if the resource is not in the table.
    fn get_error_mut(&mut self, error: &Resource<Error>) -> wasmtime::Result<&mut Error> {
        let error = self
            .wrpc()
            .table
            .get_mut(error)
            .context("failed to get error from table")?;
        Ok(error)
    }

    /// Removes an [`Error`] from the table and returns it.
    ///
    /// # Errors
    /// Fails if the resource is not in the table.
    fn delete_error(&mut self, error: Resource<Error>) -> wasmtime::Result<Error> {
        let error = self
            .wrpc()
            .table
            .delete(error)
            .context("failed to delete error from table")?;
        Ok(error)
    }

    /// Stores an invocation context in the table, boxed inside a [`Context`].
    ///
    /// # Errors
    /// Fails if the resource cannot be pushed to the table.
    fn push_context(
        &mut self,
        cx: <Self::Invoke as Invoke>::Context,
    ) -> wasmtime::Result<Resource<Context>>
    where
        <Self::Invoke as Invoke>::Context: 'static,
    {
        self.wrpc()
            .table
            .push(Context(Box::new(cx)))
            .context("failed to push context to table")
    }

    /// Removes an invocation context from the table and returns it unboxed.
    ///
    /// # Errors
    /// Fails if the resource is not in the table or the stored value has an unexpected type.
    fn delete_context(
        &mut self,
        cx: Resource<Context>,
    ) -> wasmtime::Result<<Self::Invoke as Invoke>::Context>
    where
        <Self::Invoke as Invoke>::Context: 'static,
    {
        let Context(cx) = self
            .wrpc()
            .table
            .delete(cx)
            .context("failed to delete context from table")?;
        let cx = cx
            .downcast()
            .map_err(|_| wasmtime::Error::msg("invalid context type"))?;
        Ok(*cx)
    }
}
327
/// Blanket implementation: every [`WrpcView`] gets the [`WrpcViewExt`] helpers.
impl<T: WrpcView> WrpcViewExt for T {}
329
330/// Error type returned by [call]
331pub enum CallError {
332    Decode(wasmtime::Error),
333    Encode(wasmtime::Error),
334    Table(wasmtime::Error),
335    Call(wasmtime::Error),
336    TypeMismatch(wasmtime::Error),
337    Write(wasmtime::Error),
338    Flush(wasmtime::Error),
339    Deferred(wasmtime::Error),
340    PostReturn(wasmtime::Error),
341    Guest(Error),
342}
343
344impl core::error::Error for CallError {}
345
346impl fmt::Debug for CallError {
347    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
348        match self {
349            CallError::Decode(error)
350            | CallError::Encode(error)
351            | CallError::Table(error)
352            | CallError::Call(error)
353            | CallError::TypeMismatch(error)
354            | CallError::Write(error)
355            | CallError::Flush(error)
356            | CallError::Deferred(error)
357            | CallError::PostReturn(error) => error.fmt(f),
358            CallError::Guest(error) => error.fmt(f),
359        }
360    }
361}
362
363impl fmt::Display for CallError {
364    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
365        match self {
366            CallError::Decode(error)
367            | CallError::Encode(error)
368            | CallError::Table(error)
369            | CallError::Call(error)
370            | CallError::TypeMismatch(error)
371            | CallError::Write(error)
372            | CallError::Flush(error)
373            | CallError::Deferred(error)
374            | CallError::PostReturn(error) => error.fmt(f),
375            CallError::Guest(error) => error.fmt(f),
376        }
377    }
378}
379
/// Decodes the parameters for `func` from `rx`, calls it, and transmits its results over `tx`.
///
/// Parameters are read in order, one per type yielded by `params_ty`. The parameter position `i`
/// is passed to `read_value` as `&[i]`. After the call, results are encoded into a single buffer,
/// which is written to `tx`, flushed, and followed by a shutdown of `tx`. Finally, any deferred
/// parts produced while encoding result `i` are written concurrently on the sub-stream
/// `tx.index(&[i])`.
///
/// If `results_ty` is a single `result<T, error>` whose error is the guest's
/// `wrpc:rpc/error@0.1.0` `error` resource backed by the host [`Error`] type, only the `ok`
/// payload is encoded. An `err` value is instead removed from the resource table and returned as
/// [`CallError::Guest`], without anything being written to `tx`.
///
/// # Errors
/// Returns the [`CallError`] variant matching the stage that failed. A failure to shut down `tx`
/// is not an error; it is only logged at `trace` level.
#[allow(clippy::too_many_arguments)]
pub async fn call<C, I, O>(
    mut store: C,
    rx: I,
    mut tx: O,
    guest_resources: &[ResourceType],
    host_resources: &HashMap<Box<str>, HashMap<Box<str>, (ResourceType, ResourceType)>>,
    io_streams: &[ResourceType],
    params_ty: impl ExactSizeIterator<Item = &Type>,
    results_ty: &[Type],
    func: Func,
) -> Result<(), CallError>
where
    I: AsyncRead + wrpc_transport::Index<I> + Send + Sync + Unpin + 'static,
    O: AsyncWrite + wrpc_transport::Index<O> + Send + Sync + Unpin + 'static,
    C: AsContextMut,
    C::Data: WrpcView,
{
    // Placeholder values; each slot is handed to `read_value` to be filled in.
    let mut params = vec![Val::Bool(false); params_ty.len()];
    let mut rx = pin!(rx);
    for (i, (v, ty)) in zip(&mut params, params_ty).enumerate() {
        read_value(
            &mut store,
            &mut rx,
            guest_resources,
            io_streams,
            v,
            ty,
            &[i],
        )
        .await
        .with_context(|| format!("failed to decode parameter value {i}"))
        .map_err(CallError::Decode)?;
    }
    // Placeholder result slots, written by `call_async`.
    let mut results = vec![Val::Bool(false); results_ty.len()];
    func.call_async(&mut store, &params, &mut results)
        .await
        .context("failed to call function")
        .map_err(CallError::Call)?;

    let mut buf = BytesMut::default();
    // One entry per encoded result, indexed by result position; `None` entries are skipped below.
    let mut deferred = vec![];
    match (
        &rpc_result_type(host_resources, results_ty),
        results.as_slice(),
    ) {
        // Plain results: encode every value in order.
        (None, results) => {
            for (i, (v, ty)) in zip(results, results_ty).enumerate() {
                let mut enc =
                    ValEncoder::new(store.as_context_mut(), ty, guest_resources, io_streams);
                enc.encode(v, &mut buf)
                    .with_context(|| format!("failed to encode result value {i}"))
                    .map_err(CallError::Encode)?;
                deferred.push(enc.deferred);
            }
        }
        // `result<_, rpc-error>`
        (Some(None), [Val::Result(Ok(None))]) => {}
        // `result<T, rpc-error>`
        (Some(Some(ty)), [Val::Result(Ok(Some(v)))]) => {
            let mut enc = ValEncoder::new(store.as_context_mut(), ty, guest_resources, io_streams);
            enc.encode(v, &mut buf)
                .context("failed to encode result value 0")
                .map_err(CallError::Encode)?;
            deferred.push(enc.deferred);
        }
        // The guest returned its RPC error resource: take it out of the table and surface it.
        (Some(..), [Val::Result(Err(Some(err)))]) => {
            let Val::Resource(err) = &**err else {
                return Err(CallError::TypeMismatch(wasmtime::Error::msg(
                    "RPC result error value is not a resource",
                )));
            };
            let mut store = store.as_context_mut();
            let err = err
                .try_into_resource(&mut store)
                .context("RPC result error resource type mismatch")
                .map_err(CallError::TypeMismatch)?;
            let err = store
                .data_mut()
                .delete_error(err)
                .map_err(CallError::Table)?;
            return Err(CallError::Guest(err));
        }
        _ => {
            return Err(CallError::TypeMismatch(wasmtime::Error::msg(
                "RPC result type mismatch",
            )))
        }
    }

    debug!("transmitting results");
    tx.write_all(&buf)
        .await
        .context("failed to transmit results")
        .map_err(CallError::Write)?;
    tx.flush()
        .await
        .context("failed to flush outgoing stream")
        .map_err(CallError::Flush)?;
    // Best-effort: a failed shutdown is logged, not returned.
    if let Err(err) = tx.shutdown().await {
        trace!(?err, "failed to shutdown outgoing stream");
    }
    // Write every deferred part concurrently on the sub-stream indexed by its result position.
    try_join_all(
        zip(0.., deferred)
            .filter_map(|(i, f)| f.map(|f| (tx.index(&[i]), f)))
            .map(|(w, f)| async move {
                let w = w.map_err(wasmtime::Error::from_anyhow)?;
                f(w).await
            }),
    )
    .await
    .map_err(CallError::Deferred)?;
    Ok(())
}
494
495/// Recursively iterates the component item type and collects all exported resource types
496#[instrument(level = "debug", skip_all)]
497pub fn collect_item_resource_exports(
498    engine: &Engine,
499    ty: types::ComponentItem,
500    resources: &mut impl Extend<types::ResourceType>,
501) {
502    match ty {
503        types::ComponentItem::ComponentFunc(_)
504        | types::ComponentItem::CoreFunc(_)
505        | types::ComponentItem::Module(_)
506        | types::ComponentItem::Type(_) => {}
507        types::ComponentItem::Component(ty) => {
508            collect_component_resource_exports(engine, &ty, resources)
509        }
510
511        types::ComponentItem::ComponentInstance(ty) => {
512            collect_instance_resource_exports(engine, &ty, resources)
513        }
514        types::ComponentItem::Resource(ty) => {
515            debug!(?ty, "collect resource export");
516            resources.extend([ty])
517        }
518    }
519}
520
521/// Recursively iterates the instance type and collects all exported resource types
522#[instrument(level = "debug", skip_all)]
523pub fn collect_instance_resource_exports(
524    engine: &Engine,
525    ty: &types::ComponentInstance,
526    resources: &mut impl Extend<types::ResourceType>,
527) {
528    for (name, ty) in ty.exports(engine) {
529        trace!(name, ?ty, "collect instance item resource exports");
530        collect_item_resource_exports(engine, ty, resources);
531    }
532}
533
534/// Recursively iterates the component type and collects all exported resource types
535#[instrument(level = "debug", skip_all)]
536pub fn collect_component_resource_exports(
537    engine: &Engine,
538    ty: &types::Component,
539    resources: &mut impl Extend<types::ResourceType>,
540) {
541    for (name, ty) in ty.exports(engine) {
542        trace!(name, ?ty, "collect component item resource exports");
543        collect_item_resource_exports(engine, ty, resources);
544    }
545}
546
547/// Iterates the component type and collects all imported resource types
548#[instrument(level = "debug", skip_all)]
549pub fn collect_component_resource_imports(
550    engine: &Engine,
551    ty: &types::Component,
552    resources: &mut BTreeMap<Box<str>, HashMap<Box<str>, types::ResourceType>>,
553) {
554    for (name, ty) in ty.imports(engine) {
555        match ty {
556            types::ComponentItem::ComponentFunc(..)
557            | types::ComponentItem::CoreFunc(..)
558            | types::ComponentItem::Module(..)
559            | types::ComponentItem::Type(..)
560            | types::ComponentItem::Component(..) => {}
561            types::ComponentItem::ComponentInstance(ty) => {
562                let instance = name;
563                for (name, ty) in ty.exports(engine) {
564                    if let types::ComponentItem::Resource(ty) = ty {
565                        debug!(instance, name, ?ty, "collect instance resource import");
566                        if let Some(resources) = resources.get_mut(instance) {
567                            resources.insert(name.into(), ty);
568                        } else {
569                            resources.insert(instance.into(), HashMap::from([(name.into(), ty)]));
570                        }
571                    }
572                }
573            }
574            types::ComponentItem::Resource(ty) => {
575                debug!(name, "collect component resource import");
576                if let Some(resources) = resources.get_mut("") {
577                    resources.insert(name.into(), ty);
578                } else {
579                    resources.insert("".into(), HashMap::from([(name.into(), ty)]));
580                }
581            }
582        }
583    }
584}