Skip to main content

roam_types/
calls.rs

1use std::{future::Future, pin::Pin, sync::Arc};
2
3use crate::{MaybeSend, MaybeSync, Metadata, RequestCall, RequestResponse, RoamError, SelfRef};
4
5// As a recap, a service defined like so:
6//
7// #[roam::service]
8// trait Hash {
9//   async fn hash(&self, payload: &[u8]) -> Result<&[u8], E>;
10// }
11//
12// Would expand to the following caller:
13//
14// impl HashClient {
15//   async fn hash(&self, payload: &[u8]) -> Result<SelfRef<&[u8]>, RoamError<E>>;
16// }
17//
18// Would expand to a service trait (what users implement):
19//
20// trait Hash {
21//   async fn hash(&self, call: impl Call<&[u8], E>, payload: &[u8]);
22// }
23//
// And a HashDispatcher<S: Hash> that implements Handler<R: ReplySink>:
// it deserializes the args, wraps the ReplySink in a SinkCall<R> (which
// implements Call<T, E>), and routes to the appropriate method by method ID.
27//
28// For owned success returns, generated methods return values directly and
29// the dispatcher sends replies on their behalf.
30//
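// HashDispatcher<S> implements Handler<R>, so code that drives it only needs
// to know R, not S or the service type. Note that Handler::handle returns
// `impl Future`, which makes Handler not dyn-compatible: storing a dispatcher
// as Box<dyn Handler<R>> would need a boxing adapter, like the one
// ErasedCaller uses for Caller.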
33//
// Why does the service trait take an `impl Call` instead of returning the
// result? So that the server can reply with something _borrowed_ from its
// own stack frame.
36//
37// For example:
38//
39// impl Hash for MyHasher {
40//   async fn hash(&self, call: impl Call<&[u8], E>, payload: &[u8]) {
41//     let result: [u8; 16] = compute_hash(payload);
42//     call.ok(&result).await;
43//   }
44// }
45//
46// Call's public API is:
47//
48// trait Call<T, E> {
49//   async fn reply(self, result: Result<T, E>);
50//   async fn ok(self, value: T) { self.reply(Ok(value)).await }
51//   async fn err(self, error: E) { self.reply(Err(error)).await }
52// }
53//
54// If a Call is dropped before reply/ok/err is called, the caller will
55// receive a RoamError::Cancelled error. This is to ensure that the caller
56// is always notified, even if the handler panics or otherwise fails to
57// reply.
58
59/// Represents an in-progress call from a client that must be replied to.
60///
61/// A `Call` is handed to a [`Handler`] implementation and provides the
62/// mechanism for sending a response back to the caller. The response can
63/// be sent via [`Call::reply`], [`Call::ok`], or [`Call::err`].
64///
65/// # Cancellation
66///
67/// If a `Call` is dropped without a reply being sent, the caller will
68/// automatically receive a [`RoamError::Cancelled`] error. This guarantees
69/// that the caller is always notified, even if the handler panics or
70/// otherwise fails to produce a reply.
71///
72/// # Type Parameters
73///
74/// - `T`: The success value type of the response.
75/// - `E`: The error value type of the response.
76pub trait Call<'wire, T, E>: MaybeSend
77where
78    T: facet::Facet<'wire> + MaybeSend,
79    E: facet::Facet<'wire> + MaybeSend,
80{
81    /// Send a [`Result`] back to the caller, consuming this `Call`.
82    fn reply(self, result: Result<T, E>) -> impl std::future::Future<Output = ()> + MaybeSend;
83
84    /// Send a successful response back to the caller, consuming this `Call`.
85    ///
86    /// Equivalent to `self.reply(Ok(value)).await`.
87    fn ok(self, value: T) -> impl std::future::Future<Output = ()> + MaybeSend
88    where
89        Self: Sized,
90    {
91        self.reply(Ok(value))
92    }
93
94    /// Send an error response back to the caller, consuming this `Call`.
95    ///
96    /// Equivalent to `self.reply(Err(error)).await`.
97    fn err(self, error: E) -> impl std::future::Future<Output = ()> + MaybeSend
98    where
99        Self: Sized,
100    {
101        self.reply(Err(error))
102    }
103}
104
105/// Sink for sending a reply back to the caller.
106///
107/// Implemented by the session driver. Provides backpressure: `send_reply`
108/// awaits until the transport can accept the response before serializing it.
109///
110/// # Cancellation
111///
112/// If the `ReplySink` is dropped without `send_reply` being called, the caller
113/// will automatically receive a [`crate::RoamError::Cancelled`] error.
114pub trait ReplySink: MaybeSend + MaybeSync + 'static {
115    /// Send the response, consuming the sink. Any error that happens during send_reply
116    /// must set a flag in the driver for it to reply with an error.
117    ///
118    /// This cannot return a Result because we cannot trust callers to deal with it, and
119    /// it's not like they can try sending a second reply anyway.
120    ///
121    /// Do not spawn a task to send the error because it too, might fail.
122    fn send_reply(
123        self,
124        response: RequestResponse<'_>,
125    ) -> impl std::future::Future<Output = ()> + MaybeSend;
126
127    /// Send an error response back to the caller, consuming the sink.
128    ///
129    /// This is a convenience method used by generated dispatchers when
130    /// deserialization fails or the method ID is unknown.
131    fn send_error<E: for<'a> facet::Facet<'a> + MaybeSend>(
132        self,
133        error: RoamError<E>,
134    ) -> impl std::future::Future<Output = ()> + MaybeSend
135    where
136        Self: Sized,
137    {
138        use crate::{Payload, RequestResponse};
139        // Wire format is always Result<T, RoamError<E>>. We don't know T here,
140        // but postcard encodes () as zero bytes, so Result<(), RoamError<E>>
141        // produces the same Err variant encoding as any Result<T, RoamError<E>>.
142        async move {
143            let wire: Result<(), RoamError<E>> = Err(error);
144            self.send_reply(RequestResponse {
145                ret: Payload::outgoing(&wire),
146                channels: vec![],
147                metadata: Default::default(),
148            })
149            .await;
150        }
151    }
152
153    /// Return a channel binder for binding Tx/Rx handles in deserialized args.
154    ///
155    /// Returns `None` by default. The driver's `ReplySink` implementation
156    /// overrides this to provide actual channel binding.
157    #[cfg(not(target_arch = "wasm32"))]
158    fn channel_binder(&self) -> Option<&dyn crate::ChannelBinder> {
159        None
160    }
161}
162
163/// Type-erased handler for incoming service calls.
164///
165/// Implemented (by the macro-generated dispatch code) for server-side types.
166/// Takes a fully decoded [`RequestCall`](crate::RequestCall) — already parsed
167/// from the wire — and a [`ReplySink`] through which the response is sent.
168///
169/// The dispatch impl decodes the args, routes by [`crate::MethodId`], and
170/// invokes the appropriate typed [`Call`]-based method on the concrete server type.
171/// A cloneable handle to a connection, handed out by the session driver.
172///
173/// Generated clients hold an [`ErasedCaller`] and use it to send calls. The caller
174/// serializes the outgoing [`RequestCall`] (with borrowed args), registers a
175/// pending response slot, and awaits the response from the peer.
176pub trait Caller: Clone + MaybeSend + MaybeSync + 'static {
177    /// Send a call and wait for the response.
178    #[allow(async_fn_in_trait)]
179    async fn call<'a>(
180        &self,
181        call: RequestCall<'a>,
182    ) -> Result<SelfRef<RequestResponse<'static>>, RoamError>;
183
184    /// Return a channel binder for binding Tx/Rx handles in args before sending.
185    ///
186    /// Returns `None` by default. The driver's `Caller` implementation
187    /// overrides this to provide actual channel binding.
188    #[cfg(not(target_arch = "wasm32"))]
189    fn channel_binder(&self) -> Option<&dyn crate::ChannelBinder> {
190        None
191    }
192}
193
194trait ErasedCallerDyn: MaybeSend + MaybeSync + 'static {
195    fn call<'a>(
196        &'a self,
197        call: RequestCall<'a>,
198    ) -> Pin<Box<dyn Future<Output = Result<SelfRef<RequestResponse<'static>>, RoamError>> + 'a>>;
199
200    #[cfg(not(target_arch = "wasm32"))]
201    fn channel_binder(&self) -> Option<&dyn crate::ChannelBinder>;
202}
203
204impl<C: Caller> ErasedCallerDyn for C {
205    fn call<'a>(
206        &'a self,
207        call: RequestCall<'a>,
208    ) -> Pin<Box<dyn Future<Output = Result<SelfRef<RequestResponse<'static>>, RoamError>> + 'a>>
209    {
210        Box::pin(Caller::call(self, call))
211    }
212
213    #[cfg(not(target_arch = "wasm32"))]
214    fn channel_binder(&self) -> Option<&dyn crate::ChannelBinder> {
215        Caller::channel_binder(self)
216    }
217}
218
219/// Type-erased [`Caller`] wrapper used by generated clients.
220#[derive(Clone)]
221pub struct ErasedCaller {
222    inner: Arc<dyn ErasedCallerDyn>,
223}
224
225impl ErasedCaller {
226    pub fn new<C: Caller>(caller: C) -> Self {
227        Self {
228            inner: Arc::new(caller),
229        }
230    }
231}
232
233impl Caller for ErasedCaller {
234    async fn call<'a>(
235        &self,
236        call: RequestCall<'a>,
237    ) -> Result<SelfRef<RequestResponse<'static>>, RoamError> {
238        self.inner.call(call).await
239    }
240
241    #[cfg(not(target_arch = "wasm32"))]
242    fn channel_binder(&self) -> Option<&dyn crate::ChannelBinder> {
243        self.inner.channel_binder()
244    }
245}
246
247pub trait Handler<R: ReplySink>: MaybeSend + MaybeSync + 'static {
248    /// Dispatch an incoming call to the appropriate method implementation.
249    fn handle(
250        &self,
251        call: SelfRef<crate::RequestCall<'static>>,
252        reply: R,
253    ) -> impl std::future::Future<Output = ()> + MaybeSend + '_;
254}
255
256impl<R: ReplySink> Handler<R> for () {
257    async fn handle(&self, _call: SelfRef<crate::RequestCall<'static>>, _reply: R) {}
258}
259
260/// A decoded response value paired with response metadata.
261///
262/// This helper is available for lower-level callers that need both the
263/// decoded value and metadata together. Generated Rust client methods do
264/// not expose response metadata in their return types.
265pub struct ResponseParts<'a, T> {
266    /// The decoded return value.
267    pub ret: T,
268    /// Metadata attached to the response by the server.
269    pub metadata: Metadata<'a>,
270}
271
272impl<'a, T> std::ops::Deref for ResponseParts<'a, T> {
273    type Target = T;
274    fn deref(&self) -> &T {
275        &self.ret
276    }
277}
278
279/// Concrete [`Call`] implementation backed by a [`ReplySink`].
280///
281/// Constructed by the dispatcher and handed to the server method.
282/// When the server calls [`Call::reply`], the result is serialized and
283/// sent through the sink.
284pub struct SinkCall<R: ReplySink> {
285    reply: R,
286}
287
288impl<R: ReplySink> SinkCall<R> {
289    pub fn new(reply: R) -> Self {
290        Self { reply }
291    }
292}
293
294impl<'wire, T, E, R> Call<'wire, T, E> for SinkCall<R>
295where
296    T: facet::Facet<'wire> + MaybeSend,
297    E: facet::Facet<'wire> + MaybeSend,
298    R: ReplySink,
299{
300    async fn reply(self, result: Result<T, E>) {
301        use crate::{Payload, RequestResponse};
302        let wire: Result<T, crate::RoamError<E>> = result.map_err(crate::RoamError::User);
303        let ptr =
304            facet::PtrConst::new((&wire as *const Result<T, crate::RoamError<E>>).cast::<u8>());
305        let shape = <Result<T, crate::RoamError<E>> as facet::Facet<'wire>>::SHAPE;
306        // SAFETY: `wire` lives until `send_reply(...).await` completes in this function,
307        // and `shape` matches the pointed value exactly.
308        let ret = unsafe { Payload::outgoing_unchecked(ptr, shape) };
309        self.reply
310            .send_reply(RequestResponse {
311                ret,
312                channels: vec![],
313                metadata: Default::default(),
314            })
315            .await;
316    }
317}
318
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use crate::{MaybeSend, Metadata, Payload, RequestCall, RequestResponse};

    use super::{Call, Caller, Handler, ReplySink, ResponseParts, SinkCall};

    /// `Call` implementation that records the result it was given instead of
    /// sending it anywhere.
    struct RecordingCall<T, E> {
        observed: Arc<Mutex<Option<Result<T, E>>>>,
    }

    impl<'wire, T, E> Call<'wire, T, E> for RecordingCall<T, E>
    where
        T: facet::Facet<'wire> + MaybeSend + Send + 'static,
        E: facet::Facet<'wire> + MaybeSend + Send + 'static,
    {
        async fn reply(self, result: Result<T, E>) {
            let mut guard = self.observed.lock().expect("recording mutex poisoned");
            *guard = Some(result);
        }
    }

    /// `ReplySink` that records whether `send_reply` ran and whether the
    /// response payload was a `Payload::Outgoing`.
    struct RecordingReplySink {
        saw_send_reply: Arc<Mutex<bool>>,
        saw_outgoing_payload: Arc<Mutex<bool>>,
    }

    impl ReplySink for RecordingReplySink {
        async fn send_reply(self, response: RequestResponse<'_>) {
            let mut saw_send_reply = self
                .saw_send_reply
                .lock()
                .expect("send-reply mutex poisoned");
            *saw_send_reply = true;

            let mut saw_outgoing = self
                .saw_outgoing_payload
                .lock()
                .expect("payload-kind mutex poisoned");
            *saw_outgoing = matches!(response.ret, Payload::Outgoing { .. });
        }
    }

    // Only used by the channel-binder test, which exists only off wasm32.
    #[cfg(not(target_arch = "wasm32"))]
    #[derive(Clone)]
    struct NoopCaller;

    #[cfg(not(target_arch = "wasm32"))]
    impl Caller for NoopCaller {
        async fn call<'a>(
            &self,
            _call: RequestCall<'a>,
        ) -> Result<crate::SelfRef<RequestResponse<'static>>, crate::RoamError> {
            unreachable!("NoopCaller::call is not used by this test")
        }
    }

    #[tokio::test]
    async fn call_ok_and_err_route_through_reply() {
        let observed_ok: Arc<Mutex<Option<Result<u32, &'static str>>>> = Arc::new(Mutex::new(None));
        RecordingCall {
            observed: Arc::clone(&observed_ok),
        }
        .ok(7)
        .await;
        assert!(matches!(
            *observed_ok.lock().expect("ok mutex poisoned"),
            Some(Ok(7))
        ));

        let observed_err: Arc<Mutex<Option<Result<u32, &'static str>>>> =
            Arc::new(Mutex::new(None));
        RecordingCall {
            observed: Arc::clone(&observed_err),
        }
        .err("boom")
        .await;
        assert!(matches!(
            *observed_err.lock().expect("err mutex poisoned"),
            Some(Err("boom"))
        ));
    }

    #[tokio::test]
    async fn reply_sink_send_error_uses_outgoing_payload_and_reply_path() {
        let saw_send_reply = Arc::new(Mutex::new(false));
        let saw_outgoing_payload = Arc::new(Mutex::new(false));
        let sink = RecordingReplySink {
            saw_send_reply: Arc::clone(&saw_send_reply),
            saw_outgoing_payload: Arc::clone(&saw_outgoing_payload),
        };

        sink.send_error(crate::RoamError::<String>::Cancelled).await;

        assert!(*saw_send_reply.lock().expect("send-reply mutex poisoned"));
        assert!(
            *saw_outgoing_payload
                .lock()
                .expect("payload-kind mutex poisoned")
        );
    }

    #[tokio::test]
    async fn sink_call_reply_sends_outgoing_payload_through_sink() {
        let saw_send_reply = Arc::new(Mutex::new(false));
        let saw_outgoing_payload = Arc::new(Mutex::new(false));
        let call = SinkCall::new(RecordingReplySink {
            saw_send_reply: Arc::clone(&saw_send_reply),
            saw_outgoing_payload: Arc::clone(&saw_outgoing_payload),
        });

        // Exercises the `outgoing_unchecked` path in `SinkCall::reply`.
        let result: Result<u32, String> = Ok(7);
        call.reply(result).await;

        assert!(*saw_send_reply.lock().expect("send-reply mutex poisoned"));
        assert!(
            *saw_outgoing_payload
                .lock()
                .expect("payload-kind mutex poisoned")
        );
    }

    #[tokio::test]
    async fn unit_handler_is_noop() {
        let req = crate::SelfRef::owning(
            crate::Backing::Boxed(Box::<[u8]>::default()),
            RequestCall {
                method_id: crate::MethodId(1),
                channels: vec![],
                metadata: Metadata::default(),
                args: Payload::Incoming(&[]),
            },
        );
        let saw_send_reply = Arc::new(Mutex::new(false));
        let saw_outgoing_payload = Arc::new(Mutex::new(false));
        ().handle(
            req,
            RecordingReplySink {
                saw_send_reply: Arc::clone(&saw_send_reply),
                saw_outgoing_payload: Arc::clone(&saw_outgoing_payload),
            },
        )
        .await;

        // The unit handler must drop the sink without ever replying.
        assert!(!*saw_send_reply.lock().expect("send-reply mutex poisoned"));
        assert!(
            !*saw_outgoing_payload
                .lock()
                .expect("payload-kind mutex poisoned")
        );
    }

    #[test]
    fn response_parts_deref_exposes_ret() {
        let parts = ResponseParts {
            ret: 42_u32,
            metadata: Metadata::default(),
        };
        assert_eq!(*parts, 42);
    }

    // `channel_binder` only exists off wasm32, so these tests must be gated too.
    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn default_channel_binder_accessor_for_caller_returns_none() {
        let caller = NoopCaller;
        assert!(caller.channel_binder().is_none());
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn default_channel_binder_accessor_for_reply_sink_returns_none() {
        let sink = RecordingReplySink {
            saw_send_reply: Arc::new(Mutex::new(false)),
            saw_outgoing_payload: Arc::new(Mutex::new(false)),
        };
        assert!(sink.channel_binder().is_none());
    }
}