Skip to main content

roam_types/
channel_binding.rs

1#![cfg(not(target_arch = "wasm32"))]
//! Channel binding infrastructure for connecting Tx/Rx handles to the driver.
//!
//! Two binding functions handle channels that appear in request args:
//!
//! - [`bind_channels_caller_args`]: caller side, arg position. Allocates IDs and
//!   stores bindings in the shared core so the paired handle can use them.
//! - [`bind_channels_callee_args`]: callee side, arg position. Binds deserialized
//!   standalone handles directly, using IDs from `Request.channels`.

use std::sync::Arc;

use crate::ChannelId;
use crate::channel::{
    BoundChannelReceiver, BoundChannelSink, ChannelBinding, ChannelLivenessHandle, ChannelSink,
    CoreSlot, LivenessSlot, ReceiverSlot, ReplenisherSlot, SinkSlot,
};
use crate::rpc_plan::{ChannelKind, RpcPlan};
use facet_core::PtrMut;
use facet_path::PathAccessError;

22/// Trait for channel operations, implemented by the session driver.
23///
24/// This abstraction lets the binding functions and macro-generated code bind
25/// channels without depending on concrete driver types.
26pub trait ChannelBinder: Send + Sync {
27    /// Allocate a channel ID and create a sink for sending items.
28    ///
29    /// `initial_credit` is the const generic `N` from `Tx<T, N>` or `Rx<T, N>`.
30    fn create_tx(&self, initial_credit: u32) -> (ChannelId, Arc<dyn ChannelSink>);
31
32    /// Allocate a channel ID, register it for routing, and return a receiver.
33    fn create_rx(&self, initial_credit: u32) -> (ChannelId, BoundChannelReceiver);
34
35    /// Create a sink for a known channel ID (callee side).
36    ///
37    /// The channel ID comes from `Request.channels`.
38    /// `initial_credit` is the const generic `N` from `Tx<T, N>`.
39    fn bind_tx(&self, channel_id: ChannelId, initial_credit: u32) -> Arc<dyn ChannelSink>;
40
41    /// Register an inbound channel by ID and return the receiver (callee side).
42    ///
43    /// The channel ID comes from `Request.channels`.
44    fn register_rx(&self, channel_id: ChannelId, initial_credit: u32) -> BoundChannelReceiver;
45
46    /// Optional opaque handle that keeps the underlying session/connection alive
47    /// for the lifetime of any bound channel handle.
48    fn channel_liveness(&self) -> Option<ChannelLivenessHandle> {
49        None
50    }
51}
52
53// r[impl rpc.channel.binding.caller-args]
54// r[impl rpc.channel.allocation]
55/// Bind channels in args on the **caller** side, returning channel IDs.
56///
57/// The caller created `(tx, rx)` pairs via `channel()`. Only one handle from
58/// each pair is in the args; the other was kept by the caller. This function
59/// stores bindings in the shared core so the kept handle can use them.
60///
61/// # Safety
62///
63/// `args_ptr` must point to valid, initialized memory for a value whose
64/// shape matches `plan.shape`.
65#[allow(unsafe_code)]
66pub unsafe fn bind_channels_caller_args(
67    args_ptr: *mut u8,
68    plan: &RpcPlan,
69    binder: &dyn ChannelBinder,
70) -> Vec<ChannelId> {
71    let shape = plan.shape;
72    let mut channel_ids = Vec::new();
73
74    for loc in plan.channel_locations {
75        // SAFETY: caller guarantees args_ptr is valid and initialized for this shape
76        let poke = unsafe { facet::Poke::from_raw_parts(PtrMut::new(args_ptr), shape) };
77
78        match poke.at_path_mut(&loc.path) {
79            Ok(channel_poke) => match loc.kind {
80                // r[impl rpc.channel.binding.caller-args.rx]
81                // Rx in args: handler receives, caller sends.
82                // Create a sink and store it in the shared core so the caller's
83                // paired Tx can send through it.
84                ChannelKind::Rx => {
85                    let (channel_id, sink) = binder.create_tx(loc.initial_credit);
86                    channel_ids.push(channel_id);
87                    let liveness = binder.channel_liveness();
88                    if let Ok(mut ps) = channel_poke.into_struct()
89                        && let Ok(mut core_field) = ps.field_by_name("core")
90                        && let Ok(slot) = core_field.get_mut::<CoreSlot>()
91                        && let Some(core) = &slot.inner
92                    {
93                        core.set_binding(ChannelBinding::Sink(BoundChannelSink { sink, liveness }));
94                    }
95                }
96                // r[impl rpc.channel.binding.caller-args.tx]
97                // Tx in args: handler sends, caller receives.
98                // Create a receiver and store it in the shared core so the caller's
99                // paired Rx can receive from it.
100                ChannelKind::Tx => {
101                    let (channel_id, bound) = binder.create_rx(loc.initial_credit);
102                    channel_ids.push(channel_id);
103                    if let Ok(mut ps) = channel_poke.into_struct()
104                        && let Ok(mut core_field) = ps.field_by_name("core")
105                        && let Ok(slot) = core_field.get_mut::<CoreSlot>()
106                        && let Some(core) = &slot.inner
107                    {
108                        core.set_binding(ChannelBinding::Receiver(bound));
109                    }
110                }
111            },
112            Err(PathAccessError::OptionIsNone { .. }) => {
113                // Option<Tx/Rx> is None — skip
114            }
115            Err(_) => {}
116        }
117    }
118
119    channel_ids
120}
121
122// r[impl rpc.channel.binding]
123// r[impl rpc.channel.binding.callee-args]
124/// Bind channels in deserialized args on the **callee** side.
125///
126/// Handles are standalone (not part of a pair). Bind directly into the
127/// handle's local slot using channel IDs from `Request.channels`.
128///
129/// # Safety
130///
131/// `args_ptr` must point to valid, initialized memory for a value whose
132/// shape matches `plan.shape`.
133#[allow(unsafe_code)]
134pub unsafe fn bind_channels_callee_args(
135    args_ptr: *mut u8,
136    plan: &RpcPlan,
137    channel_ids: &[ChannelId],
138    binder: &dyn ChannelBinder,
139) {
140    let shape = plan.shape;
141    let mut id_idx = 0;
142
143    for loc in plan.channel_locations {
144        // SAFETY: caller guarantees args_ptr is valid and initialized for this shape
145        let poke = unsafe { facet::Poke::from_raw_parts(PtrMut::new(args_ptr), shape) };
146
147        match poke.at_path_mut(&loc.path) {
148            Ok(channel_poke) => {
149                if id_idx >= channel_ids.len() {
150                    break;
151                }
152                let channel_id = channel_ids[id_idx];
153                id_idx += 1;
154
155                match loc.kind {
156                    // r[impl rpc.channel.binding.callee-args.tx]
157                    // Tx in args: handler sends. Bind a sink directly.
158                    ChannelKind::Tx => {
159                        let sink = binder.bind_tx(channel_id, loc.initial_credit);
160                        let liveness = binder.channel_liveness();
161                        if let Ok(mut ps) = channel_poke.into_struct() {
162                            if let Ok(mut sink_field) = ps.field_by_name("sink")
163                                && let Ok(slot) = sink_field.get_mut::<SinkSlot>()
164                            {
165                                slot.inner = Some(sink);
166                            }
167                            if let Ok(mut liveness_field) = ps.field_by_name("liveness")
168                                && let Ok(slot) =
169                                    liveness_field.get_mut::<crate::channel::LivenessSlot>()
170                            {
171                                slot.inner = liveness;
172                            }
173                        }
174                    }
175                    // r[impl rpc.channel.binding.callee-args.rx]
176                    // Rx in args: handler receives. Register and bind a receiver directly.
177                    ChannelKind::Rx => {
178                        let bound = binder.register_rx(channel_id, loc.initial_credit);
179                        if let Ok(mut ps) = channel_poke.into_struct() {
180                            if let Ok(mut receiver_field) = ps.field_by_name("receiver")
181                                && let Ok(slot) = receiver_field.get_mut::<ReceiverSlot>()
182                            {
183                                slot.inner = Some(bound.receiver);
184                            }
185                            if let Ok(mut liveness_field) = ps.field_by_name("liveness")
186                                && let Ok(slot) =
187                                    liveness_field.get_mut::<crate::channel::LivenessSlot>()
188                            {
189                                slot.inner = bound.liveness;
190                            }
191                            if let Ok(mut replenisher_field) = ps.field_by_name("replenisher")
192                                && let Ok(slot) = replenisher_field.get_mut::<ReplenisherSlot>()
193                            {
194                                slot.inner = bound.replenisher;
195                            }
196                        }
197                    }
198                }
199            }
200            Err(PathAccessError::OptionIsNone { .. }) => {
201                // Option<Tx/Rx> is None — skip this channel location
202            }
203            Err(_) => {}
204        }
205    }
206}
207
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::future::Future;
    use std::pin::Pin;
    use std::sync::{Arc, Mutex};

    use facet::Facet;
    use tokio::sync::mpsc;

    use crate::channel::{
        BoundChannelReceiver, ChannelSink, IncomingChannelMessage, RxError, TxError, channel,
    };
    use crate::{Backing, ChannelClose, ChannelId, Metadata, Payload, RpcPlan, Rx, SelfRef, Tx};

    use super::{ChannelBinder, bind_channels_callee_args, bind_channels_caller_args};

    /// Sink that accepts every payload and close request and does nothing else.
    #[derive(Default)]
    struct TestSink;

    impl ChannelSink for TestSink {
        fn send_payload<'payload>(
            &self,
            _payload: Payload<'payload>,
        ) -> Pin<Box<dyn Future<Output = Result<(), TxError>> + Send + 'payload>> {
            Box::pin(async { Ok(()) })
        }

        fn close_channel(
            &self,
            _metadata: Metadata,
        ) -> Pin<Box<dyn Future<Output = Result<(), TxError>> + Send + 'static>> {
            Box::pin(async { Ok(()) })
        }
    }

    /// Binder that allocates IDs 100, 102, 104, ... and records every call, so
    /// tests can check which binder methods ran and with which credit.
    ///
    /// Every receiver it hands out is backed by an mpsc channel whose sender is
    /// kept in `rx_senders`, letting tests inject messages for a channel ID.
    #[derive(Default)]
    struct TestBinder {
        next_id: Mutex<u64>,
        create_tx_credits: Mutex<Vec<u32>>,
        create_rx_credits: Mutex<Vec<u32>>,
        bind_tx_calls: Mutex<Vec<(ChannelId, u32)>>,
        register_rx_calls: Mutex<Vec<(ChannelId, u32)>>,
        rx_senders: Mutex<HashMap<u64, mpsc::Sender<IncomingChannelMessage>>>,
    }

    impl TestBinder {
        fn new() -> Self {
            Self {
                next_id: Mutex::new(100),
                ..Self::default()
            }
        }

        fn alloc_id(&self) -> ChannelId {
            let mut guard = self.next_id.lock().expect("next-id mutex poisoned");
            let id = *guard;
            *guard += 2;
            ChannelId(id)
        }

        /// Create an mpsc-backed receiver and remember its sender under `channel_id`.
        fn make_receiver(&self, channel_id: ChannelId) -> BoundChannelReceiver {
            let (tx, rx) = mpsc::channel(8);
            self.rx_senders
                .lock()
                .expect("sender map mutex poisoned")
                .insert(channel_id.0, tx);
            BoundChannelReceiver {
                receiver: rx,
                liveness: None,
                replenisher: None,
            }
        }

        fn sender_for(&self, channel_id: ChannelId) -> mpsc::Sender<IncomingChannelMessage> {
            self.rx_senders
                .lock()
                .expect("sender map mutex poisoned")
                .get(&channel_id.0)
                .cloned()
                .expect("missing sender for channel id")
        }
    }

    impl ChannelBinder for TestBinder {
        fn create_tx(&self, initial_credit: u32) -> (ChannelId, Arc<dyn ChannelSink>) {
            self.create_tx_credits
                .lock()
                .expect("create-tx mutex poisoned")
                .push(initial_credit);
            (self.alloc_id(), Arc::new(TestSink))
        }

        fn create_rx(&self, initial_credit: u32) -> (ChannelId, BoundChannelReceiver) {
            self.create_rx_credits
                .lock()
                .expect("create-rx mutex poisoned")
                .push(initial_credit);
            let channel_id = self.alloc_id();
            (channel_id, self.make_receiver(channel_id))
        }

        fn bind_tx(&self, channel_id: ChannelId, initial_credit: u32) -> Arc<dyn ChannelSink> {
            self.bind_tx_calls
                .lock()
                .expect("bind-tx mutex poisoned")
                .push((channel_id, initial_credit));
            Arc::new(TestSink)
        }

        fn register_rx(&self, channel_id: ChannelId, initial_credit: u32) -> BoundChannelReceiver {
            self.register_rx_calls
                .lock()
                .expect("register-rx mutex poisoned")
                .push((channel_id, initial_credit));
            self.make_receiver(channel_id)
        }
    }

    /// Build an inbound `Close` message carrying default metadata.
    fn close_message() -> IncomingChannelMessage {
        let close_ref = SelfRef::owning(
            Backing::Boxed(Box::<[u8]>::default()),
            ChannelClose {
                metadata: Metadata::default(),
            },
        );
        IncomingChannelMessage::Close(close_ref)
    }

    #[derive(Facet)]
    struct CallerArgs {
        tx: Tx<u32, 16>,
        rx: Rx<u32, 16>,
        maybe_tx: Option<Tx<u32, 16>>,
        maybe_rx: Option<Rx<u32, 16>>,
    }

    #[derive(Facet)]
    struct CalleeArgs {
        tx: Tx<u32, 16>,
        rx: Rx<u32, 16>,
    }

    #[tokio::test]
    async fn bind_channels_caller_args_binds_paired_handles_and_skips_none_options() {
        let (tx_arg, mut rx_peer) = channel::<u32>();
        let (tx_peer, rx_arg) = channel::<u32>();
        let mut args = CallerArgs {
            tx: tx_arg,
            rx: rx_arg,
            maybe_tx: None,
            maybe_rx: None,
        };

        let plan = RpcPlan::for_type::<CallerArgs>();
        let binder = TestBinder::new();

        // SAFETY: `args` is a live, initialized `CallerArgs`, and `plan` was built for it.
        let channel_ids = unsafe {
            bind_channels_caller_args((&mut args as *mut CallerArgs).cast::<u8>(), plan, &binder)
        };

        assert_eq!(
            channel_ids.len(),
            2,
            "only present channels should be bound"
        );
        assert_eq!(
            binder
                .create_tx_credits
                .lock()
                .expect("create-tx mutex poisoned")
                .as_slice(),
            &[16],
            "Rx<T, N> in caller args should allocate sink with declared N credit"
        );
        assert_eq!(
            binder
                .create_rx_credits
                .lock()
                .expect("create-rx mutex poisoned")
                .as_slice(),
            &[16],
            "Tx<T, N> in caller args should allocate receiver with declared N credit"
        );

        tx_peer
            .send(77)
            .await
            .expect("paired Tx should become bound via create_tx");

        binder
            .sender_for(channel_ids[0])
            .send(close_message())
            .await
            .expect("send close to paired Rx");
        assert!(
            rx_peer.recv().await.expect("recv close").is_none(),
            "paired Rx should become bound via create_rx"
        );
    }

    #[tokio::test]
    async fn bind_channels_callee_args_binds_tx_and_rx_with_supplied_ids() {
        let mut args = CalleeArgs {
            tx: Tx::unbound(),
            rx: Rx::unbound(),
        };
        let plan = RpcPlan::for_type::<CalleeArgs>();
        let binder = TestBinder::new();
        let channel_ids = [ChannelId(41), ChannelId(43)];

        // SAFETY: `args` is a live, initialized `CalleeArgs`, and `plan` was built for it.
        unsafe {
            bind_channels_callee_args(
                (&mut args as *mut CalleeArgs).cast::<u8>(),
                plan,
                &channel_ids,
                &binder,
            )
        };

        args.tx
            .send(5)
            .await
            .expect("callee-side Tx should be bound via bind_tx");

        binder
            .sender_for(ChannelId(43))
            .send(close_message())
            .await
            .expect("send close to bound callee Rx");
        assert!(args.rx.recv().await.expect("recv close").is_none());

        assert_eq!(
            binder
                .bind_tx_calls
                .lock()
                .expect("bind-tx mutex poisoned")
                .as_slice(),
            &[(ChannelId(41), 16)]
        );
        assert_eq!(
            binder
                .register_rx_calls
                .lock()
                .expect("register-rx mutex poisoned")
                .as_slice(),
            &[(ChannelId(43), 16)]
        );
    }

    #[tokio::test]
    async fn bind_channels_callee_args_stops_when_channel_ids_are_exhausted() {
        let mut args = CalleeArgs {
            tx: Tx::unbound(),
            rx: Rx::unbound(),
        };
        let plan = RpcPlan::for_type::<CalleeArgs>();
        let binder = TestBinder::new();
        let only_one_id = [ChannelId(51)];

        // SAFETY: `args` is a live, initialized `CalleeArgs`, and `plan` was built for it.
        unsafe {
            bind_channels_callee_args(
                (&mut args as *mut CalleeArgs).cast::<u8>(),
                plan,
                &only_one_id,
                &binder,
            )
        };

        args.tx
            .send(1)
            .await
            .expect("first channel location should bind");
        let recv = args.rx.recv().await;
        assert!(
            matches!(recv, Err(RxError::Unbound)),
            "second channel location should remain unbound when IDs are exhausted"
        );
    }
}