Skip to main content

roam_types/
calls.rs

1use std::{future::Future, pin::Pin, sync::Arc};
2
3use crate::{MaybeSend, MaybeSync, Metadata, RequestCall, RequestResponse, RoamError, SelfRef};
4
5// As a recap, a service defined like so:
6//
7// #[roam::service]
8// trait Hash {
9//   async fn hash(&self, payload: &[u8]) -> Result<&[u8], E>;
10// }
11//
// It expands to the following client:
13//
14// impl HashClient {
15//   async fn hash(&self, payload: &[u8]) -> Result<SelfRef<&[u8]>, RoamError<E>>;
16// }
17//
// It also expands to a service trait (what users implement):
19//
20// trait Hash {
21//   async fn hash(&self, call: impl Call<&[u8], E>, payload: &[u8]);
22// }
23//
// And to a HashDispatcher<S: Hash> that implements Handler<R: ReplySink>:
// it deserializes args, constructs a SinkCall<R> (the Call<T, E>
// implementation backed by the ReplySink), and routes to the appropriate
// method by method ID.
27//
28// For owned success returns, generated methods return values directly and
29// the dispatcher sends replies on their behalf.
30//
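// HashDispatcher<S> implements Handler<R>, so code that drives requests can
// work with any service through Handler<R> alone, without knowing S or the
// service trait. Note that Handler::handle returns `impl Future`, so Handler
// is not dyn-compatible as written: storing it as Box<dyn Handler<R>> would
// require an additional erasing wrapper.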
33//
// Why does the service trait take an `impl Call` instead of returning the
// result? So that the server can reply with something _borrowed_ from its
// own stack frame.
36//
37// For example:
38//
39// impl Hash for MyHasher {
40//   async fn hash(&self, call: impl Call<&[u8], E>, payload: &[u8]) {
41//     let result: [u8; 16] = compute_hash(payload);
42//     call.ok(&result).await;
43//   }
44// }
45//
// Call's public API is (the `'wire` lifetime parameter is omitted for brevity):
47//
48// trait Call<T, E> {
49//   async fn reply(self, result: Result<T, E>);
50//   async fn ok(self, value: T) { self.reply(Ok(value)).await }
51//   async fn err(self, error: E) { self.reply(Err(error)).await }
52// }
53//
54// If a Call is dropped before reply/ok/err is called, the caller will
55// receive a RoamError::Cancelled error. This is to ensure that the caller
56// is always notified, even if the handler panics or otherwise fails to
57// reply.
58
59/// Represents an in-progress call from a client that must be replied to.
60///
61/// A `Call` is handed to a [`Handler`] implementation and provides the
62/// mechanism for sending a response back to the caller. The response can
63/// be sent via [`Call::reply`], [`Call::ok`], or [`Call::err`].
64///
65/// # Cancellation
66///
67/// If a `Call` is dropped without a reply being sent, the caller will
68/// automatically receive a [`RoamError::Cancelled`] error. This guarantees
69/// that the caller is always notified, even if the handler panics or
70/// otherwise fails to produce a reply.
71///
72/// # Type Parameters
73///
74/// - `T`: The success value type of the response.
75/// - `E`: The error value type of the response.
76pub trait Call<'wire, T, E>: MaybeSend
77where
78    T: facet::Facet<'wire> + MaybeSend,
79    E: facet::Facet<'wire> + MaybeSend,
80{
81    /// Send a [`Result`] back to the caller, consuming this `Call`.
82    fn reply(self, result: Result<T, E>) -> impl std::future::Future<Output = ()> + MaybeSend;
83
84    /// Send a successful response back to the caller, consuming this `Call`.
85    ///
86    /// Equivalent to `self.reply(Ok(value)).await`.
87    fn ok(self, value: T) -> impl std::future::Future<Output = ()> + MaybeSend
88    where
89        Self: Sized,
90    {
91        self.reply(Ok(value))
92    }
93
94    /// Send an error response back to the caller, consuming this `Call`.
95    ///
96    /// Equivalent to `self.reply(Err(error)).await`.
97    fn err(self, error: E) -> impl std::future::Future<Output = ()> + MaybeSend
98    where
99        Self: Sized,
100    {
101        self.reply(Err(error))
102    }
103}
104
105/// Sink for sending a reply back to the caller.
106///
107/// Implemented by the session driver. Provides backpressure: `send_reply`
108/// awaits until the transport can accept the response before serializing it.
109///
110/// # Cancellation
111///
112/// If the `ReplySink` is dropped without `send_reply` being called, the caller
113/// will automatically receive a [`crate::RoamError::Cancelled`] error.
114pub trait ReplySink: MaybeSend + MaybeSync + 'static {
115    /// Send the response, consuming the sink. Any error that happens during send_reply
116    /// must set a flag in the driver for it to reply with an error.
117    ///
118    /// This cannot return a Result because we cannot trust callers to deal with it, and
119    /// it's not like they can try sending a second reply anyway.
120    ///
121    /// Do not spawn a task to send the error because it too, might fail.
122    fn send_reply(
123        self,
124        response: RequestResponse<'_>,
125    ) -> impl std::future::Future<Output = ()> + MaybeSend;
126
127    /// Send an error response back to the caller, consuming the sink.
128    ///
129    /// This is a convenience method used by generated dispatchers when
130    /// deserialization fails or the method ID is unknown.
131    fn send_error<E: for<'a> facet::Facet<'a> + MaybeSend>(
132        self,
133        error: RoamError<E>,
134    ) -> impl std::future::Future<Output = ()> + MaybeSend
135    where
136        Self: Sized,
137    {
138        use crate::{Payload, RequestResponse};
139        // Wire format is always Result<T, RoamError<E>>. We don't know T here,
140        // but postcard encodes () as zero bytes, so Result<(), RoamError<E>>
141        // produces the same Err variant encoding as any Result<T, RoamError<E>>.
142        async move {
143            let wire: Result<(), RoamError<E>> = Err(error);
144            self.send_reply(RequestResponse {
145                ret: Payload::outgoing(&wire),
146                channels: vec![],
147                metadata: Default::default(),
148            })
149            .await;
150        }
151    }
152
153    /// Return a channel binder for binding Tx/Rx handles in deserialized args.
154    ///
155    /// Returns `None` by default. The driver's `ReplySink` implementation
156    /// overrides this to provide actual channel binding.
157    #[cfg(not(target_arch = "wasm32"))]
158    fn channel_binder(&self) -> Option<&dyn crate::ChannelBinder> {
159        None
160    }
161}
162
163/// Type-erased handler for incoming service calls.
164///
165/// Implemented (by the macro-generated dispatch code) for server-side types.
166/// Takes a fully decoded [`RequestCall`](crate::RequestCall) — already parsed
167/// from the wire — and a [`ReplySink`] through which the response is sent.
168///
169/// The dispatch impl decodes the args, routes by [`crate::MethodId`], and
170/// invokes the appropriate typed [`Call`]-based method on the concrete server type.
171/// A cloneable handle to a connection, handed out by the session driver.
172///
173/// Generated clients hold an [`ErasedCaller`] and use it to send calls. The caller
174/// serializes the outgoing [`RequestCall`] (with borrowed args), registers a
175/// pending response slot, and awaits the response from the peer.
176pub trait Caller: Clone + MaybeSend + MaybeSync + 'static {
177    /// Send a call and wait for the response.
178    fn call<'a>(
179        &'a self,
180        call: RequestCall<'a>,
181    ) -> impl Future<Output = Result<SelfRef<RequestResponse<'static>>, RoamError>> + MaybeSend + 'a;
182
183    /// Return a channel binder for binding Tx/Rx handles in args before sending.
184    ///
185    /// Returns `None` by default. The driver's `Caller` implementation
186    /// overrides this to provide actual channel binding.
187    #[cfg(not(target_arch = "wasm32"))]
188    fn channel_binder(&self) -> Option<&dyn crate::ChannelBinder> {
189        None
190    }
191}
192
193trait ErasedCallerDyn: MaybeSend + MaybeSync + 'static {
194    #[cfg(not(target_arch = "wasm32"))]
195    fn call<'a>(
196        &'a self,
197        call: RequestCall<'a>,
198    ) -> Pin<
199        Box<dyn Future<Output = Result<SelfRef<RequestResponse<'static>>, RoamError>> + Send + 'a>,
200    >;
201    #[cfg(target_arch = "wasm32")]
202    fn call<'a>(
203        &'a self,
204        call: RequestCall<'a>,
205    ) -> Pin<Box<dyn Future<Output = Result<SelfRef<RequestResponse<'static>>, RoamError>> + 'a>>;
206
207    #[cfg(not(target_arch = "wasm32"))]
208    fn channel_binder(&self) -> Option<&dyn crate::ChannelBinder>;
209}
210
211impl<C: Caller> ErasedCallerDyn for C {
212    #[cfg(not(target_arch = "wasm32"))]
213    fn call<'a>(
214        &'a self,
215        call: RequestCall<'a>,
216    ) -> Pin<
217        Box<dyn Future<Output = Result<SelfRef<RequestResponse<'static>>, RoamError>> + Send + 'a>,
218    > {
219        Box::pin(Caller::call(self, call))
220    }
221    #[cfg(target_arch = "wasm32")]
222    fn call<'a>(
223        &'a self,
224        call: RequestCall<'a>,
225    ) -> Pin<Box<dyn Future<Output = Result<SelfRef<RequestResponse<'static>>, RoamError>> + 'a>>
226    {
227        Box::pin(Caller::call(self, call))
228    }
229
230    #[cfg(not(target_arch = "wasm32"))]
231    fn channel_binder(&self) -> Option<&dyn crate::ChannelBinder> {
232        Caller::channel_binder(self)
233    }
234}
235
236/// Type-erased [`Caller`] wrapper used by generated clients.
237#[derive(Clone)]
238pub struct ErasedCaller {
239    inner: Arc<dyn ErasedCallerDyn>,
240}
241
242impl ErasedCaller {
243    pub fn new<C: Caller>(caller: C) -> Self {
244        Self {
245            inner: Arc::new(caller),
246        }
247    }
248}
249
250impl Caller for ErasedCaller {
251    fn call<'a>(
252        &'a self,
253        call: RequestCall<'a>,
254    ) -> impl Future<Output = Result<SelfRef<RequestResponse<'static>>, RoamError>> + MaybeSend + 'a
255    {
256        self.inner.call(call)
257    }
258
259    #[cfg(not(target_arch = "wasm32"))]
260    fn channel_binder(&self) -> Option<&dyn crate::ChannelBinder> {
261        self.inner.channel_binder()
262    }
263}
264
265pub trait Handler<R: ReplySink>: MaybeSend + MaybeSync + 'static {
266    /// Dispatch an incoming call to the appropriate method implementation.
267    fn handle(
268        &self,
269        call: SelfRef<crate::RequestCall<'static>>,
270        reply: R,
271    ) -> impl std::future::Future<Output = ()> + MaybeSend + '_;
272}
273
274impl<R: ReplySink> Handler<R> for () {
275    async fn handle(&self, _call: SelfRef<crate::RequestCall<'static>>, _reply: R) {}
276}
277
278/// A decoded response value paired with response metadata.
279///
280/// This helper is available for lower-level callers that need both the
281/// decoded value and metadata together. Generated Rust client methods do
282/// not expose response metadata in their return types.
283pub struct ResponseParts<'a, T> {
284    /// The decoded return value.
285    pub ret: T,
286    /// Metadata attached to the response by the server.
287    pub metadata: Metadata<'a>,
288}
289
290impl<'a, T> std::ops::Deref for ResponseParts<'a, T> {
291    type Target = T;
292    fn deref(&self) -> &T {
293        &self.ret
294    }
295}
296
297/// Concrete [`Call`] implementation backed by a [`ReplySink`].
298///
299/// Constructed by the dispatcher and handed to the server method.
300/// When the server calls [`Call::reply`], the result is serialized and
301/// sent through the sink.
302pub struct SinkCall<R: ReplySink> {
303    reply: R,
304}
305
306impl<R: ReplySink> SinkCall<R> {
307    pub fn new(reply: R) -> Self {
308        Self { reply }
309    }
310}
311
/// Serializes the server method's result and forwards it through the wrapped
/// [`ReplySink`].
///
/// A user error `Err(e)` is wrapped as `RoamError::User(e)`, so the value sent
/// has type `Result<T, RoamError<E>>`. The response carries no channels and
/// default metadata.
impl<'wire, T, E, R> Call<'wire, T, E> for SinkCall<R>
where
    T: facet::Facet<'wire> + MaybeSend,
    E: facet::Facet<'wire> + MaybeSend,
    R: ReplySink,
{
    async fn reply(self, result: Result<T, E>) {
        use crate::{Payload, RequestResponse};
        // The wire type is `Result<T, RoamError<E>>`; user errors become `RoamError::User`.
        let wire: Result<T, crate::RoamError<E>> = result.map_err(crate::RoamError::User);
        // Build an untyped (pointer + shape) view of `wire` by hand.
        // NOTE(review): presumably done instead of `Payload::outgoing(&wire)` (as used in
        // `ReplySink::send_error`) because `T`/`E` here are only `Facet<'wire>`, not
        // `for<'a> Facet<'a>` — confirm against `Payload::outgoing`'s bounds.
        let ptr =
            facet::PtrConst::new((&wire as *const Result<T, crate::RoamError<E>>).cast::<u8>());
        let shape = <Result<T, crate::RoamError<E>> as facet::Facet<'wire>>::SHAPE;
        // SAFETY: `wire` lives until `send_reply(...).await` completes in this function,
        // and `shape` matches the pointed value exactly. `wire` is a local of this async
        // fn that is never moved after `ptr` is taken, and the future is pinned while it
        // is polled, so the pointer stays valid across the await.
        let ret = unsafe { Payload::outgoing_unchecked(ptr, shape) };
        self.reply
            .send_reply(RequestResponse {
                ret,
                channels: vec![],
                metadata: Default::default(),
            })
            .await;
    }
}
336
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use crate::{MaybeSend, Metadata, Payload, RequestCall, RequestResponse};

    use super::{Call, Caller, ErasedCaller, Handler, ReplySink, ResponseParts, SinkCall};

    /// `Call` implementation that stores the result it is asked to send.
    struct RecordingCall<T, E> {
        observed: Arc<Mutex<Option<Result<T, E>>>>,
    }

    impl<'wire, T, E> Call<'wire, T, E> for RecordingCall<T, E>
    where
        T: facet::Facet<'wire> + MaybeSend + Send + 'static,
        E: facet::Facet<'wire> + MaybeSend + Send + 'static,
    {
        async fn reply(self, result: Result<T, E>) {
            let mut guard = self.observed.lock().expect("recording mutex poisoned");
            *guard = Some(result);
        }
    }

    /// `ReplySink` that records whether `send_reply` ran and whether the
    /// response carried an outgoing payload.
    struct RecordingReplySink {
        saw_send_reply: Arc<Mutex<bool>>,
        saw_outgoing_payload: Arc<Mutex<bool>>,
    }

    impl ReplySink for RecordingReplySink {
        async fn send_reply(self, response: RequestResponse<'_>) {
            let mut saw_send_reply = self
                .saw_send_reply
                .lock()
                .expect("send-reply mutex poisoned");
            *saw_send_reply = true;

            let mut saw_outgoing = self
                .saw_outgoing_payload
                .lock()
                .expect("payload-kind mutex poisoned");
            *saw_outgoing = matches!(response.ret, Payload::Outgoing { .. });
        }
    }

    #[derive(Clone)]
    struct NoopCaller;

    impl Caller for NoopCaller {
        fn call<'a>(
            &'a self,
            _call: RequestCall<'a>,
        ) -> impl Future<
            Output = Result<crate::SelfRef<RequestResponse<'static>>, crate::RoamError>,
        > + MaybeSend
        + 'a {
            async move { unreachable!("NoopCaller::call is not used by this test") }
        }
    }

    #[tokio::test]
    async fn call_ok_and_err_route_through_reply() {
        let observed_ok: Arc<Mutex<Option<Result<u32, &'static str>>>> = Arc::new(Mutex::new(None));
        RecordingCall {
            observed: Arc::clone(&observed_ok),
        }
        .ok(7)
        .await;
        assert!(matches!(
            *observed_ok.lock().expect("ok mutex poisoned"),
            Some(Ok(7))
        ));

        let observed_err: Arc<Mutex<Option<Result<u32, &'static str>>>> =
            Arc::new(Mutex::new(None));
        RecordingCall {
            observed: Arc::clone(&observed_err),
        }
        .err("boom")
        .await;
        assert!(matches!(
            *observed_err.lock().expect("err mutex poisoned"),
            Some(Err("boom"))
        ));
    }

    #[tokio::test]
    async fn reply_sink_send_error_uses_outgoing_payload_and_reply_path() {
        let saw_send_reply = Arc::new(Mutex::new(false));
        let saw_outgoing_payload = Arc::new(Mutex::new(false));
        let sink = RecordingReplySink {
            saw_send_reply: Arc::clone(&saw_send_reply),
            saw_outgoing_payload: Arc::clone(&saw_outgoing_payload),
        };

        sink.send_error(crate::RoamError::<String>::Cancelled).await;

        assert!(*saw_send_reply.lock().expect("send-reply mutex poisoned"));
        assert!(
            *saw_outgoing_payload
                .lock()
                .expect("payload-kind mutex poisoned")
        );
    }

    #[tokio::test]
    async fn sink_call_reply_sends_outgoing_payload_through_sink() {
        let saw_send_reply = Arc::new(Mutex::new(false));
        let saw_outgoing_payload = Arc::new(Mutex::new(false));
        let call = SinkCall::new(RecordingReplySink {
            saw_send_reply: Arc::clone(&saw_send_reply),
            saw_outgoing_payload: Arc::clone(&saw_outgoing_payload),
        });

        call.reply(Ok::<u32, String>(7)).await;

        assert!(*saw_send_reply.lock().expect("send-reply mutex poisoned"));
        assert!(
            *saw_outgoing_payload
                .lock()
                .expect("payload-kind mutex poisoned")
        );
    }

    #[tokio::test]
    async fn unit_handler_is_noop() {
        let req = crate::SelfRef::owning(
            crate::Backing::Boxed(Box::<[u8]>::default()),
            RequestCall {
                method_id: crate::MethodId(1),
                channels: vec![],
                metadata: Metadata::default(),
                args: Payload::Incoming(&[]),
            },
        );
        let saw_send_reply = Arc::new(Mutex::new(false));
        ().handle(
            req,
            RecordingReplySink {
                saw_send_reply: Arc::clone(&saw_send_reply),
                saw_outgoing_payload: Arc::new(Mutex::new(false)),
            },
        )
        .await;
        // The unit handler must drop the sink without replying.
        assert!(!*saw_send_reply.lock().expect("send-reply mutex poisoned"));
    }

    #[test]
    fn response_parts_deref_exposes_ret() {
        let parts = ResponseParts {
            ret: 42_u32,
            metadata: Metadata::default(),
        };
        assert_eq!(*parts, 42);
    }

    #[test]
    fn default_channel_binder_accessor_for_caller_returns_none() {
        let caller = NoopCaller;
        assert!(caller.channel_binder().is_none());
    }

    #[test]
    fn erased_caller_forwards_channel_binder_to_inner_caller() {
        let caller = ErasedCaller::new(NoopCaller);
        assert!(caller.channel_binder().is_none());
        assert!(caller.clone().channel_binder().is_none());
    }

    #[test]
    fn default_channel_binder_accessor_for_reply_sink_returns_none() {
        let sink = RecordingReplySink {
            saw_send_reply: Arc::new(Mutex::new(false)),
            saw_outgoing_payload: Arc::new(Mutex::new(false)),
        };
        assert!(sink.channel_binder().is_none());
    }
}