1#![allow(clippy::type_complexity)] use core::any::Any;
4use core::borrow::Borrow;
5use core::fmt;
6use core::future::Future;
7use core::iter::zip;
8use core::pin::pin;
9use core::time::Duration;
10
11use std::collections::{BTreeMap, HashMap};
12use std::sync::Arc;
13
14use anyhow::anyhow;
15use bytes::{Bytes, BytesMut};
16use futures::future::try_join_all;
17use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt as _};
18use tokio_util::codec::Encoder;
19use tracing::{debug, instrument, trace, warn};
20use uuid::Uuid;
21use wasmtime::component::{
22 types, Func, Resource, ResourceAny, ResourceTable, ResourceType, Type, Val,
23};
24use wasmtime::error::Context as _;
25use wasmtime::{bail, AsContextMut, Engine};
26use wrpc_transport::Invoke;
27
28use crate::bindings::rpc::context::Context;
29use crate::bindings::rpc::error::Error;
30use crate::bindings::rpc::transport::{IncomingChannel, Invocation, OutgoingChannel};
31
32pub mod bindings;
33mod codec;
34pub mod paths;
35mod polyfill;
36pub mod rpc;
37mod serve;
38
39pub use codec::*;
40pub use polyfill::*;
41pub use serve::*;
42
/// Strips a component-model function-kind prefix (`[constructor]`, `[static]` or
/// `[method]`) from `name`, returning `name` unchanged when none of them is present.
///
/// At most one prefix is removed; the prefixes are tried in the order listed above.
fn rpc_func_name(name: &str) -> &str {
    const KIND_PREFIXES: [&str; 3] = ["[constructor]", "[static]", "[method]"];
    KIND_PREFIXES
        .iter()
        .find_map(|prefix| name.strip_prefix(prefix))
        .unwrap_or(name)
}
57
58fn rpc_result_type<T: Borrow<Type>>(
59 host_resources: &HashMap<Box<str>, HashMap<Box<str>, (ResourceType, ResourceType)>>,
60 results_ty: impl IntoIterator<Item = T>,
61) -> Option<Option<Type>> {
62 let rpc_err_ty = host_resources
63 .get("wrpc:rpc/error@0.1.0")
64 .and_then(|instance| instance.get("error"));
65 let mut results_ty = results_ty.into_iter();
66 match (
67 rpc_err_ty,
68 results_ty.next().as_ref().map(Borrow::borrow),
69 results_ty.next(),
70 ) {
71 (Some((guest_rpc_err_ty, host_rpc_err_ty)), Some(Type::Result(result_ty)), None)
72 if *host_rpc_err_ty == ResourceType::host::<Error>()
73 && result_ty.err() == Some(Type::Own(*guest_rpc_err_ty)) =>
74 {
75 Some(result_ty.ok())
76 }
77 _ => None,
78 }
79}
80
/// A resource represented by an opaque byte string.
///
/// NOTE(review): presumably the bytes identify a resource owned by a remote wRPC peer —
/// confirm against the code that constructs this type.
pub struct RemoteResource(pub Bytes);
82
/// Table of component resources (`ResourceAny`) keyed by `Uuid`.
///
/// NOTE(review): presumably used to refer to guest resources by ID across invocations —
/// confirm with the users of [`WrpcCtx::shared_resources`].
#[derive(Debug, Default)]
pub struct SharedResourceTable(HashMap<Uuid, ResourceAny>);
86
/// Host-side wRPC state for the [`Invoke`] client `T`, reached through
/// [`WrpcCtxView::ctx`].
pub trait WrpcCtx<T: Invoke>: Send {
    /// Returns a transport context value of type `T::Context`.
    ///
    /// NOTE(review): presumably passed along with outgoing invocations — confirm with callers.
    fn context(&self) -> T::Context;

    /// Returns the wRPC client.
    fn client(&self) -> &T;

    /// Returns the table of resources shared by ID.
    fn shared_resources(&mut self) -> &mut SharedResourceTable;

    /// Optional timeout. The default implementation returns `None`.
    ///
    /// NOTE(review): presumably bounds the duration of outgoing invocations — confirm.
    fn timeout(&self) -> Option<Duration> {
        None
    }
}
103
/// Borrowed view of wRPC state, as returned by [`WrpcView::wrpc`].
pub struct WrpcCtxView<'a, T: Invoke> {
    /// The wRPC context (client, transport context, shared resources, timeout).
    pub ctx: &'a mut dyn WrpcCtx<T>,
    /// Resource table in which the [`WrpcViewExt`] helpers store invocations, channels,
    /// errors and contexts.
    pub table: &'a mut ResourceTable,
}
108
/// Gives access to wRPC state; [`call`] requires the store's data type to implement it.
pub trait WrpcView: Send {
    /// The [`Invoke`] implementation whose associated types (`Outgoing`, `Incoming`,
    /// `Context`) are stored in the resource table.
    type Invoke: Invoke;

    /// Returns a borrowed view of the wRPC context and resource table.
    fn wrpc(&mut self) -> WrpcCtxView<'_, Self::Invoke>;
}
114
115impl<T: WrpcView> WrpcView for &mut T {
116 type Invoke = T::Invoke;
117
118 fn wrpc(&mut self) -> WrpcCtxView<'_, Self::Invoke> {
119 T::wrpc(self)
120 }
121}
122
123pub trait WrpcViewExt: WrpcView {
124 fn push_invocation(
125 &mut self,
126 invocation: impl Future<
127 Output = anyhow::Result<(
128 <Self::Invoke as Invoke>::Outgoing,
129 <Self::Invoke as Invoke>::Incoming,
130 )>,
131 > + Send
132 + 'static,
133 ) -> wasmtime::Result<Resource<Invocation>> {
134 self.wrpc()
135 .table
136 .push(Invocation::Future(Box::pin(async move {
137 let res = invocation.await;
138 Box::new(res) as Box<dyn Any + Send>
139 })))
140 .context("failed to push invocation to table")
141 }
142
143 fn get_invocation_result(
144 &mut self,
145 invocation: &Resource<Invocation>,
146 ) -> wasmtime::Result<
147 Option<
148 &Box<
149 anyhow::Result<(
150 <Self::Invoke as Invoke>::Outgoing,
151 <Self::Invoke as Invoke>::Incoming,
152 )>,
153 >,
154 >,
155 > {
156 let invocation = self
157 .wrpc()
158 .table
159 .get(invocation)
160 .context("failed to get invocation from table")?;
161 match invocation {
162 Invocation::Future(..) => Ok(None),
163 Invocation::Ready(res) => {
164 let res = res.downcast_ref().context("invalid invocation type")?;
165 Ok(Some(res))
166 }
167 }
168 }
169
170 fn delete_invocation(
171 &mut self,
172 invocation: Resource<Invocation>,
173 ) -> wasmtime::Result<
174 impl Future<
175 Output = anyhow::Result<(
176 <Self::Invoke as Invoke>::Outgoing,
177 <Self::Invoke as Invoke>::Incoming,
178 )>,
179 >,
180 > {
181 let invocation = self
182 .wrpc()
183 .table
184 .delete(invocation)
185 .context("failed to delete invocation from table")?;
186 Ok(async move {
187 let res = match invocation {
188 Invocation::Future(fut) => fut.await,
189 Invocation::Ready(res) => res,
190 };
191 let res = res
192 .downcast()
193 .map_err(|_| anyhow!("invalid invocation type"))?;
194 *res
195 })
196 }
197
198 fn push_outgoing_channel(
199 &mut self,
200 outgoing: <Self::Invoke as Invoke>::Outgoing,
201 ) -> wasmtime::Result<Resource<OutgoingChannel>> {
202 self.wrpc()
203 .table
204 .push(OutgoingChannel(Arc::new(std::sync::RwLock::new(Box::new(
205 outgoing,
206 )))))
207 .context("failed to push outgoing channel to table")
208 }
209
210 fn delete_outgoing_channel(
211 &mut self,
212 outgoing: Resource<OutgoingChannel>,
213 ) -> wasmtime::Result<<Self::Invoke as Invoke>::Outgoing> {
214 let OutgoingChannel(outgoing) = self
215 .wrpc()
216 .table
217 .delete(outgoing)
218 .context("failed to delete outgoing channel from table")?;
219 let outgoing =
220 Arc::into_inner(outgoing).context("outgoing channel has an active stream")?;
221 let Ok(outgoing) = outgoing.into_inner() else {
222 bail!("lock poisoned");
223 };
224 let outgoing = outgoing
225 .downcast()
226 .map_err(|_| wasmtime::Error::msg("invalid outgoing channel type"))?;
227 Ok(*outgoing)
228 }
229
230 fn push_incoming_channel(
231 &mut self,
232 incoming: <Self::Invoke as Invoke>::Incoming,
233 ) -> wasmtime::Result<Resource<IncomingChannel>> {
234 self.wrpc()
235 .table
236 .push(IncomingChannel(Arc::new(std::sync::RwLock::new(Box::new(
237 incoming,
238 )))))
239 .context("failed to push incoming channel to table")
240 }
241
242 fn delete_incoming_channel(
243 &mut self,
244 incoming: Resource<IncomingChannel>,
245 ) -> wasmtime::Result<<Self::Invoke as Invoke>::Incoming> {
246 let IncomingChannel(incoming) = self
247 .wrpc()
248 .table
249 .delete(incoming)
250 .context("failed to delete incoming channel from table")?;
251 let incoming =
252 Arc::into_inner(incoming).context("incoming channel has an active stream")?;
253 let Ok(incoming) = incoming.into_inner() else {
254 bail!("lock poisoned");
255 };
256 let incoming = incoming
257 .downcast()
258 .map_err(|_| wasmtime::Error::msg("invalid incoming channel type"))?;
259 Ok(*incoming)
260 }
261
262 fn push_error(&mut self, error: Error) -> wasmtime::Result<Resource<Error>> {
263 self.wrpc()
264 .table
265 .push(error)
266 .context("failed to push error to table")
267 }
268
269 fn get_error(&mut self, error: &Resource<Error>) -> wasmtime::Result<&Error> {
270 let error = self
271 .wrpc()
272 .table
273 .get(error)
274 .context("failed to get error from table")?;
275 Ok(error)
276 }
277
278 fn get_error_mut(&mut self, error: &Resource<Error>) -> wasmtime::Result<&mut Error> {
279 let error = self
280 .wrpc()
281 .table
282 .get_mut(error)
283 .context("failed to get error from table")?;
284 Ok(error)
285 }
286
287 fn delete_error(&mut self, error: Resource<Error>) -> wasmtime::Result<Error> {
288 let error = self
289 .wrpc()
290 .table
291 .delete(error)
292 .context("failed to delete error from table")?;
293 Ok(error)
294 }
295
296 fn push_context(
297 &mut self,
298 cx: <Self::Invoke as Invoke>::Context,
299 ) -> wasmtime::Result<Resource<Context>>
300 where
301 <Self::Invoke as Invoke>::Context: 'static,
302 {
303 self.wrpc()
304 .table
305 .push(Context(Box::new(cx)))
306 .context("failed to push context to table")
307 }
308
309 fn delete_context(
310 &mut self,
311 cx: Resource<Context>,
312 ) -> wasmtime::Result<<Self::Invoke as Invoke>::Context>
313 where
314 <Self::Invoke as Invoke>::Context: 'static,
315 {
316 let Context(cx) = self
317 .wrpc()
318 .table
319 .delete(cx)
320 .context("failed to delete context from table")?;
321 let cx = cx
322 .downcast()
323 .map_err(|_| wasmtime::Error::msg("invalid context type"))?;
324 Ok(*cx)
325 }
326}
327
// Blanket impl: every `WrpcView` (including `&mut T`, via the forwarding `WrpcView` impl)
// gets the default table helpers of `WrpcViewExt`.
impl<T: WrpcView> WrpcViewExt for T {}
329
/// Error returned by [`call`], classified by the stage of the call that failed.
pub enum CallError {
    /// Decoding a parameter value from the incoming stream failed.
    Decode(wasmtime::Error),
    /// Encoding a result value failed.
    Encode(wasmtime::Error),
    /// Removing the guest-returned RPC error resource from the resource table failed.
    Table(wasmtime::Error),
    /// Calling the component function failed.
    Call(wasmtime::Error),
    /// The function's results did not have the expected shape or type.
    TypeMismatch(wasmtime::Error),
    /// Writing the encoded results to the outgoing stream failed.
    Write(wasmtime::Error),
    /// Flushing the outgoing stream failed.
    Flush(wasmtime::Error),
    /// Writing a deferred value to its indexed sub-stream failed.
    Deferred(wasmtime::Error),
    /// Post-return cleanup failed.
    ///
    /// NOTE(review): [`call`] never constructs this variant; presumably kept for API
    /// compatibility — confirm.
    PostReturn(wasmtime::Error),
    /// The guest function returned an RPC error value (`own<error>`).
    Guest(Error),
}
343
// Uses the default `source()` (which returns `None`), so the wrapped error is exposed only
// through this type's `Debug` and `Display` output.
impl core::error::Error for CallError {}
345
346impl fmt::Debug for CallError {
347 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
348 match self {
349 CallError::Decode(error)
350 | CallError::Encode(error)
351 | CallError::Table(error)
352 | CallError::Call(error)
353 | CallError::TypeMismatch(error)
354 | CallError::Write(error)
355 | CallError::Flush(error)
356 | CallError::Deferred(error)
357 | CallError::PostReturn(error) => error.fmt(f),
358 CallError::Guest(error) => error.fmt(f),
359 }
360 }
361}
362
363impl fmt::Display for CallError {
364 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
365 match self {
366 CallError::Decode(error)
367 | CallError::Encode(error)
368 | CallError::Table(error)
369 | CallError::Call(error)
370 | CallError::TypeMismatch(error)
371 | CallError::Write(error)
372 | CallError::Flush(error)
373 | CallError::Deferred(error)
374 | CallError::PostReturn(error) => error.fmt(f),
375 CallError::Guest(error) => error.fmt(f),
376 }
377 }
378}
379
380#[allow(clippy::too_many_arguments)]
381pub async fn call<C, I, O>(
382 mut store: C,
383 rx: I,
384 mut tx: O,
385 guest_resources: &[ResourceType],
386 host_resources: &HashMap<Box<str>, HashMap<Box<str>, (ResourceType, ResourceType)>>,
387 io_streams: &[ResourceType],
388 params_ty: impl ExactSizeIterator<Item = &Type>,
389 results_ty: &[Type],
390 func: Func,
391) -> Result<(), CallError>
392where
393 I: AsyncRead + wrpc_transport::Index<I> + Send + Sync + Unpin + 'static,
394 O: AsyncWrite + wrpc_transport::Index<O> + Send + Sync + Unpin + 'static,
395 C: AsContextMut,
396 C::Data: WrpcView,
397{
398 let mut params = vec![Val::Bool(false); params_ty.len()];
399 let mut rx = pin!(rx);
400 for (i, (v, ty)) in zip(&mut params, params_ty).enumerate() {
401 read_value(
402 &mut store,
403 &mut rx,
404 guest_resources,
405 io_streams,
406 v,
407 ty,
408 &[i],
409 )
410 .await
411 .with_context(|| format!("failed to decode parameter value {i}"))
412 .map_err(CallError::Decode)?;
413 }
414 let mut results = vec![Val::Bool(false); results_ty.len()];
415 func.call_async(&mut store, ¶ms, &mut results)
416 .await
417 .context("failed to call function")
418 .map_err(CallError::Call)?;
419
420 let mut buf = BytesMut::default();
421 let mut deferred = vec![];
422 match (
423 &rpc_result_type(host_resources, results_ty),
424 results.as_slice(),
425 ) {
426 (None, results) => {
427 for (i, (v, ty)) in zip(results, results_ty).enumerate() {
428 let mut enc =
429 ValEncoder::new(store.as_context_mut(), ty, guest_resources, io_streams);
430 enc.encode(v, &mut buf)
431 .with_context(|| format!("failed to encode result value {i}"))
432 .map_err(CallError::Encode)?;
433 deferred.push(enc.deferred);
434 }
435 }
436 (Some(None), [Val::Result(Ok(None))]) => {}
438 (Some(Some(ty)), [Val::Result(Ok(Some(v)))]) => {
440 let mut enc = ValEncoder::new(store.as_context_mut(), ty, guest_resources, io_streams);
441 enc.encode(v, &mut buf)
442 .context("failed to encode result value 0")
443 .map_err(CallError::Encode)?;
444 deferred.push(enc.deferred);
445 }
446 (Some(..), [Val::Result(Err(Some(err)))]) => {
447 let Val::Resource(err) = &**err else {
448 return Err(CallError::TypeMismatch(wasmtime::Error::msg(
449 "RPC result error value is not a resource",
450 )));
451 };
452 let mut store = store.as_context_mut();
453 let err = err
454 .try_into_resource(&mut store)
455 .context("RPC result error resource type mismatch")
456 .map_err(CallError::TypeMismatch)?;
457 let err = store
458 .data_mut()
459 .delete_error(err)
460 .map_err(CallError::Table)?;
461 return Err(CallError::Guest(err));
462 }
463 _ => {
464 return Err(CallError::TypeMismatch(wasmtime::Error::msg(
465 "RPC result type mismatch",
466 )))
467 }
468 }
469
470 debug!("transmitting results");
471 tx.write_all(&buf)
472 .await
473 .context("failed to transmit results")
474 .map_err(CallError::Write)?;
475 tx.flush()
476 .await
477 .context("failed to flush outgoing stream")
478 .map_err(CallError::Flush)?;
479 if let Err(err) = tx.shutdown().await {
480 trace!(?err, "failed to shutdown outgoing stream");
481 }
482 try_join_all(
483 zip(0.., deferred)
484 .filter_map(|(i, f)| f.map(|f| (tx.index(&[i]), f)))
485 .map(|(w, f)| async move {
486 let w = w.map_err(wasmtime::Error::from_anyhow)?;
487 f(w).await
488 }),
489 )
490 .await
491 .map_err(CallError::Deferred)?;
492 Ok(())
493}
494
495#[instrument(level = "debug", skip_all)]
497pub fn collect_item_resource_exports(
498 engine: &Engine,
499 ty: types::ComponentItem,
500 resources: &mut impl Extend<types::ResourceType>,
501) {
502 match ty {
503 types::ComponentItem::ComponentFunc(_)
504 | types::ComponentItem::CoreFunc(_)
505 | types::ComponentItem::Module(_)
506 | types::ComponentItem::Type(_) => {}
507 types::ComponentItem::Component(ty) => {
508 collect_component_resource_exports(engine, &ty, resources)
509 }
510
511 types::ComponentItem::ComponentInstance(ty) => {
512 collect_instance_resource_exports(engine, &ty, resources)
513 }
514 types::ComponentItem::Resource(ty) => {
515 debug!(?ty, "collect resource export");
516 resources.extend([ty])
517 }
518 }
519}
520
521#[instrument(level = "debug", skip_all)]
523pub fn collect_instance_resource_exports(
524 engine: &Engine,
525 ty: &types::ComponentInstance,
526 resources: &mut impl Extend<types::ResourceType>,
527) {
528 for (name, ty) in ty.exports(engine) {
529 trace!(name, ?ty, "collect instance item resource exports");
530 collect_item_resource_exports(engine, ty, resources);
531 }
532}
533
534#[instrument(level = "debug", skip_all)]
536pub fn collect_component_resource_exports(
537 engine: &Engine,
538 ty: &types::Component,
539 resources: &mut impl Extend<types::ResourceType>,
540) {
541 for (name, ty) in ty.exports(engine) {
542 trace!(name, ?ty, "collect component item resource exports");
543 collect_item_resource_exports(engine, ty, resources);
544 }
545}
546
547#[instrument(level = "debug", skip_all)]
549pub fn collect_component_resource_imports(
550 engine: &Engine,
551 ty: &types::Component,
552 resources: &mut BTreeMap<Box<str>, HashMap<Box<str>, types::ResourceType>>,
553) {
554 for (name, ty) in ty.imports(engine) {
555 match ty {
556 types::ComponentItem::ComponentFunc(..)
557 | types::ComponentItem::CoreFunc(..)
558 | types::ComponentItem::Module(..)
559 | types::ComponentItem::Type(..)
560 | types::ComponentItem::Component(..) => {}
561 types::ComponentItem::ComponentInstance(ty) => {
562 let instance = name;
563 for (name, ty) in ty.exports(engine) {
564 if let types::ComponentItem::Resource(ty) = ty {
565 debug!(instance, name, ?ty, "collect instance resource import");
566 if let Some(resources) = resources.get_mut(instance) {
567 resources.insert(name.into(), ty);
568 } else {
569 resources.insert(instance.into(), HashMap::from([(name.into(), ty)]));
570 }
571 }
572 }
573 }
574 types::ComponentItem::Resource(ty) => {
575 debug!(name, "collect component resource import");
576 if let Some(resources) = resources.get_mut("") {
577 resources.insert(name.into(), ty);
578 } else {
579 resources.insert("".into(), HashMap::from([(name.into(), ty)]));
580 }
581 }
582 }
583 }
584}