Skip to main content

roam_types/
calls.rs

1use std::{future::Future, pin::Pin, sync::Arc};
2
3use crate::{MaybeSend, MaybeSync, Metadata, RequestCall, RequestResponse, RoamError, SelfRef};
4
5// As a recap, a service defined like so:
6//
7// #[roam::service]
8// trait Hash {
9//   async fn hash(&self, payload: &[u8]) -> Result<&[u8], E>;
10// }
11//
12// Would expand to the following caller:
13//
14// impl HashClient {
15//   async fn hash(&self, payload: &[u8]) -> Result<SelfRef<&[u8]>, RoamError<E>>;
16// }
17//
18// Would expand to a service trait (what users implement):
19//
20// trait Hash {
21//   async fn hash(&self, call: impl Call<&[u8], E>, payload: &[u8]);
22// }
23//
// And a HashDispatcher<S: Hash> that implements Handler<R: ReplySink>:
// it deserializes args, wraps the ReplySink in a SinkCall<R> (the concrete
// Call<T, E> implementation), and routes to the appropriate method by method ID.
27//
28// For owned success returns, generated methods return values directly and
29// the dispatcher sends replies on their behalf.
30//
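// HashDispatcher<S> implements Handler<R>, hiding S's per-method signatures
// behind a single `handle` entry point. Because `Handler::handle` returns
// `impl Future`, Handler is not dyn-compatible, so it cannot be stored as
// Box<dyn Handler<R>> directly; erasing S behind `dyn` needs a boxing adapter.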
33//
// Why `impl Call` in the Hash service trait? So that the server can reply with
// something _borrowed_ from its own stack frame.
36//
37// For example:
38//
39// impl Hash for MyHasher {
40//   async fn hash(&self, call: impl Call<&[u8], E>, payload: &[u8]) {
41//     let result: [u8; 16] = compute_hash(payload);
42//     call.ok(&result).await;
43//   }
44// }
45//
46// Call's public API is:
47//
48// trait Call<T, E> {
49//   async fn reply(self, result: Result<T, E>);
50//   async fn ok(self, value: T) { self.reply(Ok(value)).await }
51//   async fn err(self, error: E) { self.reply(Err(error)).await }
52// }
53//
54// If a Call is dropped before reply/ok/err is called, the caller will
55// receive a RoamError::Cancelled error. This is to ensure that the caller
56// is always notified, even if the handler panics or otherwise fails to
57// reply.
58
59/// Represents an in-progress call from a client that must be replied to.
60///
61/// A `Call` is handed to a [`Handler`] implementation and provides the
62/// mechanism for sending a response back to the caller. The response can
63/// be sent via [`Call::reply`], [`Call::ok`], or [`Call::err`].
64///
65/// # Cancellation
66///
67/// If a `Call` is dropped without a reply being sent, the caller will
68/// automatically receive a [`RoamError::Cancelled`] error. This guarantees
69/// that the caller is always notified, even if the handler panics or
70/// otherwise fails to produce a reply.
71///
72/// # Type Parameters
73///
74/// - `T`: The success value type of the response.
75/// - `E`: The error value type of the response.
76pub trait Call<'wire, T, E>: MaybeSend
77where
78    T: facet::Facet<'wire> + MaybeSend,
79    E: facet::Facet<'wire> + MaybeSend,
80{
81    /// Send a [`Result`] back to the caller, consuming this `Call`.
82    fn reply(self, result: Result<T, E>) -> impl std::future::Future<Output = ()> + MaybeSend;
83
84    /// Send a successful response back to the caller, consuming this `Call`.
85    ///
86    /// Equivalent to `self.reply(Ok(value)).await`.
87    fn ok(self, value: T) -> impl std::future::Future<Output = ()> + MaybeSend
88    where
89        Self: Sized,
90    {
91        self.reply(Ok(value))
92    }
93
94    /// Send an error response back to the caller, consuming this `Call`.
95    ///
96    /// Equivalent to `self.reply(Err(error)).await`.
97    fn err(self, error: E) -> impl std::future::Future<Output = ()> + MaybeSend
98    where
99        Self: Sized,
100    {
101        self.reply(Err(error))
102    }
103}
104
105/// Sink for sending a reply back to the caller.
106///
107/// Implemented by the session driver. Provides backpressure: `send_reply`
108/// awaits until the transport can accept the response before serializing it.
109///
110/// # Cancellation
111///
112/// If the `ReplySink` is dropped without `send_reply` being called, the caller
113/// will automatically receive a [`crate::RoamError::Cancelled`] error.
114pub trait ReplySink: MaybeSend + MaybeSync + 'static {
115    /// Send the response, consuming the sink. Any error that happens during send_reply
116    /// must set a flag in the driver for it to reply with an error.
117    ///
118    /// This cannot return a Result because we cannot trust callers to deal with it, and
119    /// it's not like they can try sending a second reply anyway.
120    ///
121    /// Do not spawn a task to send the error because it too, might fail.
122    fn send_reply(
123        self,
124        response: RequestResponse<'_>,
125    ) -> impl std::future::Future<Output = ()> + MaybeSend;
126
127    /// Send an error response back to the caller, consuming the sink.
128    ///
129    /// This is a convenience method used by generated dispatchers when
130    /// deserialization fails or the method ID is unknown.
131    fn send_error<E: for<'a> facet::Facet<'a> + MaybeSend>(
132        self,
133        error: RoamError<E>,
134    ) -> impl std::future::Future<Output = ()> + MaybeSend
135    where
136        Self: Sized,
137    {
138        use crate::{Payload, RequestResponse};
139        // Wire format is always Result<T, RoamError<E>>. We don't know T here,
140        // but postcard encodes () as zero bytes, so Result<(), RoamError<E>>
141        // produces the same Err variant encoding as any Result<T, RoamError<E>>.
142        async move {
143            let wire: Result<(), RoamError<E>> = Err(error);
144            self.send_reply(RequestResponse {
145                ret: Payload::outgoing(&wire),
146                channels: vec![],
147                metadata: Default::default(),
148            })
149            .await;
150        }
151    }
152
153    /// Return a channel binder for binding Tx/Rx handles in deserialized args.
154    ///
155    /// Returns `None` by default. The driver's `ReplySink` implementation
156    /// overrides this to provide actual channel binding.
157    #[cfg(not(target_arch = "wasm32"))]
158    fn channel_binder(&self) -> Option<&dyn crate::ChannelBinder> {
159        None
160    }
161}
162
163/// Type-erased handler for incoming service calls.
164///
165/// Implemented (by the macro-generated dispatch code) for server-side types.
166/// Takes a fully decoded [`RequestCall`](crate::RequestCall) — already parsed
167/// from the wire — and a [`ReplySink`] through which the response is sent.
168///
169/// The dispatch impl decodes the args, routes by [`crate::MethodId`], and
170/// invokes the appropriate typed [`Call`]-based method on the concrete server type.
171/// A cloneable handle to a connection, handed out by the session driver.
172///
173/// Generated clients hold an [`ErasedCaller`] and use it to send calls. The caller
174/// serializes the outgoing [`RequestCall`] (with borrowed args), registers a
175/// pending response slot, and awaits the response from the peer.
176pub trait Caller: Clone + MaybeSend + MaybeSync + 'static {
177    /// Send a call and wait for the response.
178    fn call<'a>(
179        &'a self,
180        call: RequestCall<'a>,
181    ) -> impl Future<Output = Result<SelfRef<RequestResponse<'static>>, RoamError>> + MaybeSend + 'a;
182
183    /// Resolve when the underlying connection closes.
184    ///
185    /// Runtime-backed callers can override this to expose connection liveness.
186    /// The default implementation never resolves.
187    #[cfg(not(target_arch = "wasm32"))]
188    fn closed(&self) -> Pin<Box<dyn Future<Output = ()> + Send + '_>> {
189        Box::pin(std::future::pending())
190    }
191    #[cfg(target_arch = "wasm32")]
192    fn closed(&self) -> Pin<Box<dyn Future<Output = ()> + '_>> {
193        Box::pin(std::future::pending())
194    }
195
196    /// Return whether the underlying connection is still considered connected.
197    ///
198    /// Runtime-backed callers can override this to provide eager liveness
199    /// checks. The default implementation assumes the connection is live.
200    fn is_connected(&self) -> bool {
201        true
202    }
203
204    /// Return a channel binder for binding Tx/Rx handles in args before sending.
205    ///
206    /// Returns `None` by default. The driver's `Caller` implementation
207    /// overrides this to provide actual channel binding.
208    #[cfg(not(target_arch = "wasm32"))]
209    fn channel_binder(&self) -> Option<&dyn crate::ChannelBinder> {
210        None
211    }
212}
213
214trait ErasedCallerDyn: MaybeSend + MaybeSync + 'static {
215    #[cfg(not(target_arch = "wasm32"))]
216    fn call<'a>(
217        &'a self,
218        call: RequestCall<'a>,
219    ) -> Pin<
220        Box<dyn Future<Output = Result<SelfRef<RequestResponse<'static>>, RoamError>> + Send + 'a>,
221    >;
222    #[cfg(target_arch = "wasm32")]
223    fn call<'a>(
224        &'a self,
225        call: RequestCall<'a>,
226    ) -> Pin<Box<dyn Future<Output = Result<SelfRef<RequestResponse<'static>>, RoamError>> + 'a>>;
227
228    #[cfg(not(target_arch = "wasm32"))]
229    fn closed(&self) -> Pin<Box<dyn Future<Output = ()> + Send + '_>>;
230    #[cfg(target_arch = "wasm32")]
231    fn closed(&self) -> Pin<Box<dyn Future<Output = ()> + '_>>;
232
233    fn is_connected(&self) -> bool;
234
235    #[cfg(not(target_arch = "wasm32"))]
236    fn channel_binder(&self) -> Option<&dyn crate::ChannelBinder>;
237}
238
239impl<C: Caller> ErasedCallerDyn for C {
240    #[cfg(not(target_arch = "wasm32"))]
241    fn call<'a>(
242        &'a self,
243        call: RequestCall<'a>,
244    ) -> Pin<
245        Box<dyn Future<Output = Result<SelfRef<RequestResponse<'static>>, RoamError>> + Send + 'a>,
246    > {
247        Box::pin(Caller::call(self, call))
248    }
249    #[cfg(target_arch = "wasm32")]
250    fn call<'a>(
251        &'a self,
252        call: RequestCall<'a>,
253    ) -> Pin<Box<dyn Future<Output = Result<SelfRef<RequestResponse<'static>>, RoamError>> + 'a>>
254    {
255        Box::pin(Caller::call(self, call))
256    }
257
258    #[cfg(not(target_arch = "wasm32"))]
259    fn closed(&self) -> Pin<Box<dyn Future<Output = ()> + Send + '_>> {
260        Caller::closed(self)
261    }
262    #[cfg(target_arch = "wasm32")]
263    fn closed(&self) -> Pin<Box<dyn Future<Output = ()> + '_>> {
264        Caller::closed(self)
265    }
266
267    fn is_connected(&self) -> bool {
268        Caller::is_connected(self)
269    }
270
271    #[cfg(not(target_arch = "wasm32"))]
272    fn channel_binder(&self) -> Option<&dyn crate::ChannelBinder> {
273        Caller::channel_binder(self)
274    }
275}
276
277/// Type-erased [`Caller`] wrapper used by generated clients.
278#[derive(Clone)]
279pub struct ErasedCaller {
280    inner: Arc<dyn ErasedCallerDyn>,
281}
282
283impl ErasedCaller {
284    pub fn new<C: Caller>(caller: C) -> Self {
285        Self {
286            inner: Arc::new(caller),
287        }
288    }
289}
290
291impl Caller for ErasedCaller {
292    fn call<'a>(
293        &'a self,
294        call: RequestCall<'a>,
295    ) -> impl Future<Output = Result<SelfRef<RequestResponse<'static>>, RoamError>> + MaybeSend + 'a
296    {
297        self.inner.call(call)
298    }
299
300    #[cfg(not(target_arch = "wasm32"))]
301    fn closed(&self) -> Pin<Box<dyn Future<Output = ()> + Send + '_>> {
302        self.inner.closed()
303    }
304
305    #[cfg(target_arch = "wasm32")]
306    fn closed(&self) -> Pin<Box<dyn Future<Output = ()> + '_>> {
307        self.inner.closed()
308    }
309
310    fn is_connected(&self) -> bool {
311        self.inner.is_connected()
312    }
313
314    #[cfg(not(target_arch = "wasm32"))]
315    fn channel_binder(&self) -> Option<&dyn crate::ChannelBinder> {
316        self.inner.channel_binder()
317    }
318}
319
320pub trait Handler<R: ReplySink>: MaybeSend + MaybeSync + 'static {
321    /// Dispatch an incoming call to the appropriate method implementation.
322    fn handle(
323        &self,
324        call: SelfRef<crate::RequestCall<'static>>,
325        reply: R,
326    ) -> impl std::future::Future<Output = ()> + MaybeSend + '_;
327}
328
329impl<R: ReplySink> Handler<R> for () {
330    async fn handle(&self, _call: SelfRef<crate::RequestCall<'static>>, _reply: R) {}
331}
332
333/// A decoded response value paired with response metadata.
334///
335/// This helper is available for lower-level callers that need both the
336/// decoded value and metadata together. Generated Rust client methods do
337/// not expose response metadata in their return types.
338pub struct ResponseParts<'a, T> {
339    /// The decoded return value.
340    pub ret: T,
341    /// Metadata attached to the response by the server.
342    pub metadata: Metadata<'a>,
343}
344
345impl<'a, T> std::ops::Deref for ResponseParts<'a, T> {
346    type Target = T;
347    fn deref(&self) -> &T {
348        &self.ret
349    }
350}
351
352/// Concrete [`Call`] implementation backed by a [`ReplySink`].
353///
354/// Constructed by the dispatcher and handed to the server method.
355/// When the server calls [`Call::reply`], the result is serialized and
356/// sent through the sink.
357pub struct SinkCall<R: ReplySink> {
358    reply: R,
359}
360
361impl<R: ReplySink> SinkCall<R> {
362    pub fn new(reply: R) -> Self {
363        Self { reply }
364    }
365}
366
/// Serializes the handler's result as `Result<T, RoamError<E>>` and sends it
/// through the wrapped [`ReplySink`].
impl<'wire, T, E, R> Call<'wire, T, E> for SinkCall<R>
where
    T: facet::Facet<'wire> + MaybeSend,
    E: facet::Facet<'wire> + MaybeSend,
    R: ReplySink,
{
    async fn reply(self, result: Result<T, E>) {
        use crate::{Payload, RequestResponse};
        // Wrap user errors in `RoamError::User` so the value has the same
        // `Result<_, RoamError<E>>` shape that `ReplySink::send_error` encodes.
        let wire: Result<T, crate::RoamError<E>> = result.map_err(crate::RoamError::User);
        // Build the payload from a raw pointer + shape instead of
        // `Payload::outgoing(&wire)`. NOTE(review): presumably because
        // `Payload::outgoing` needs `for<'a> Facet<'a>` (the bound `send_error`
        // uses), while `T`/`E` here are only `Facet<'wire>` — confirm.
        let ptr =
            facet::PtrConst::new((&wire as *const Result<T, crate::RoamError<E>>).cast::<u8>());
        let shape = <Result<T, crate::RoamError<E>> as facet::Facet<'wire>>::SHAPE;
        // SAFETY: `wire` lives until `send_reply(...).await` completes in this function,
        // and `shape` matches the pointed value exactly.
        let ret = unsafe { Payload::outgoing_unchecked(ptr, shape) };
        // `ret` points at `wire`, so `wire` must stay in place (not moved or
        // dropped) until this await finishes.
        self.reply
            .send_reply(RequestResponse {
                ret,
                channels: vec![],
                metadata: Default::default(),
            })
            .await;
    }
}
391
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use crate::{MaybeSend, Metadata, Payload, RequestCall, RequestResponse};

    use super::{Call, Handler, ReplySink, ResponseParts};
    // `Caller` is only used by `NoopCaller` and the `channel_binder` test. That
    // test is gated because `channel_binder` itself does not exist on wasm32.
    #[cfg(not(target_arch = "wasm32"))]
    use super::Caller;

    /// `Call` impl that records the result it receives instead of sending it.
    struct RecordingCall<T, E> {
        observed: Arc<Mutex<Option<Result<T, E>>>>,
    }

    impl<'wire, T, E> Call<'wire, T, E> for RecordingCall<T, E>
    where
        T: facet::Facet<'wire> + MaybeSend + Send + 'static,
        E: facet::Facet<'wire> + MaybeSend + Send + 'static,
    {
        async fn reply(self, result: Result<T, E>) {
            let mut guard = self.observed.lock().expect("recording mutex poisoned");
            *guard = Some(result);
        }
    }

    /// `ReplySink` that records whether `send_reply` ran and whether the
    /// response carried an outgoing payload.
    struct RecordingReplySink {
        saw_send_reply: Arc<Mutex<bool>>,
        saw_outgoing_payload: Arc<Mutex<bool>>,
    }

    impl ReplySink for RecordingReplySink {
        async fn send_reply(self, response: RequestResponse<'_>) {
            let mut saw_send_reply = self
                .saw_send_reply
                .lock()
                .expect("send-reply mutex poisoned");
            *saw_send_reply = true;

            let mut saw_outgoing = self
                .saw_outgoing_payload
                .lock()
                .expect("payload-kind mutex poisoned");
            *saw_outgoing = matches!(response.ret, Payload::Outgoing { .. });
        }
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[derive(Clone)]
    struct NoopCaller;

    #[cfg(not(target_arch = "wasm32"))]
    impl Caller for NoopCaller {
        fn call<'a>(
            &'a self,
            _call: RequestCall<'a>,
        ) -> impl Future<
            Output = Result<crate::SelfRef<RequestResponse<'static>>, crate::RoamError>,
        > + MaybeSend
        + 'a {
            async move { unreachable!("NoopCaller::call is not used by this test") }
        }
    }

    #[tokio::test]
    async fn call_ok_and_err_route_through_reply() {
        let observed_ok: Arc<Mutex<Option<Result<u32, &'static str>>>> = Arc::new(Mutex::new(None));
        RecordingCall {
            observed: Arc::clone(&observed_ok),
        }
        .ok(7)
        .await;
        assert!(matches!(
            *observed_ok.lock().expect("ok mutex poisoned"),
            Some(Ok(7))
        ));

        let observed_err: Arc<Mutex<Option<Result<u32, &'static str>>>> =
            Arc::new(Mutex::new(None));
        RecordingCall {
            observed: Arc::clone(&observed_err),
        }
        .err("boom")
        .await;
        assert!(matches!(
            *observed_err.lock().expect("err mutex poisoned"),
            Some(Err("boom"))
        ));
    }

    #[tokio::test]
    async fn reply_sink_send_error_uses_outgoing_payload_and_reply_path() {
        let saw_send_reply = Arc::new(Mutex::new(false));
        let saw_outgoing_payload = Arc::new(Mutex::new(false));
        let sink = RecordingReplySink {
            saw_send_reply: Arc::clone(&saw_send_reply),
            saw_outgoing_payload: Arc::clone(&saw_outgoing_payload),
        };

        sink.send_error(crate::RoamError::<String>::Cancelled).await;

        assert!(*saw_send_reply.lock().expect("send-reply mutex poisoned"));
        assert!(
            *saw_outgoing_payload
                .lock()
                .expect("payload-kind mutex poisoned")
        );
    }

    #[tokio::test]
    async fn unit_handler_is_noop() {
        let req = crate::SelfRef::owning(
            crate::Backing::Boxed(Box::<[u8]>::default()),
            RequestCall {
                method_id: crate::MethodId(1),
                channels: vec![],
                metadata: Metadata::default(),
                args: Payload::Incoming(&[]),
            },
        );
        let saw_send_reply = Arc::new(Mutex::new(false));
        ().handle(
            req,
            RecordingReplySink {
                saw_send_reply: Arc::clone(&saw_send_reply),
                saw_outgoing_payload: Arc::new(Mutex::new(false)),
            },
        )
        .await;
        // The unit handler must drop the sink without ever replying.
        assert!(!*saw_send_reply.lock().expect("send-reply mutex poisoned"));
    }

    #[test]
    fn response_parts_deref_exposes_ret() {
        let parts = ResponseParts {
            ret: 42_u32,
            metadata: Metadata::default(),
        };
        assert_eq!(*parts, 42);
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn default_channel_binder_accessor_for_caller_returns_none() {
        let caller = NoopCaller;
        assert!(caller.channel_binder().is_none());
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn default_channel_binder_accessor_for_reply_sink_returns_none() {
        let sink = RecordingReplySink {
            saw_send_reply: Arc::new(Mutex::new(false)),
            saw_outgoing_payload: Arc::new(Mutex::new(false)),
        };
        assert!(sink.channel_binder().is_none());
    }
}