Skip to main content

roam_types/
channel_binding.rs

1#![cfg(not(target_arch = "wasm32"))]
2//! Channel binding infrastructure for connecting Tx/Rx handles to the driver.
3//!
4//! Binding functions handle channel binding for request args:
5//!
6//! - [`bind_channels_caller_args`]: Caller-side, arg position. Allocates IDs,
7//!   stores bindings in the shared core so the paired handle can use them.
8//! - [`bind_channels_callee_args`]: Callee-side, arg position. Binds deserialized
9//!   standalone handles directly using IDs from `Request.channels`.
10
11use std::sync::Arc;
12
13use facet_core::PtrMut;
14use facet_path::PathAccessError;
15use tokio::sync::mpsc;
16
17use crate::ChannelId;
18use crate::channel::{
19    BoundChannelReceiver, BoundChannelSink, ChannelBinding, ChannelLivenessHandle, ChannelSink,
20    CoreSlot, IncomingChannelMessage, ReceiverSlot, SinkSlot,
21};
22use crate::rpc_plan::{ChannelKind, RpcPlan};
23
24/// Trait for channel operations, implemented by the session driver.
25///
26/// This abstraction lets the binding functions and macro-generated code bind
27/// channels without depending on concrete driver types.
28pub trait ChannelBinder: Send + Sync {
29    /// Allocate a channel ID and create a sink for sending items.
30    ///
31    /// `initial_credit` is the const generic `N` from `Tx<T, N>` or `Rx<T, N>`.
32    fn create_tx(&self, initial_credit: u32) -> (ChannelId, Arc<dyn ChannelSink>);
33
34    /// Allocate a channel ID, register it for routing, and return a receiver.
35    fn create_rx(&self) -> (ChannelId, mpsc::Receiver<IncomingChannelMessage>);
36
37    /// Create a sink for a known channel ID (callee side).
38    ///
39    /// The channel ID comes from `Request.channels`.
40    /// `initial_credit` is the const generic `N` from `Tx<T, N>`.
41    fn bind_tx(&self, channel_id: ChannelId, initial_credit: u32) -> Arc<dyn ChannelSink>;
42
43    /// Register an inbound channel by ID and return the receiver (callee side).
44    ///
45    /// The channel ID comes from `Request.channels`.
46    fn register_rx(&self, channel_id: ChannelId) -> mpsc::Receiver<IncomingChannelMessage>;
47
48    /// Optional opaque handle that keeps the underlying session/connection alive
49    /// for the lifetime of any bound channel handle.
50    fn channel_liveness(&self) -> Option<ChannelLivenessHandle> {
51        None
52    }
53}
54
55// r[impl rpc.channel.binding.caller-args]
56// r[impl rpc.channel.allocation]
57/// Bind channels in args on the **caller** side, returning channel IDs.
58///
59/// The caller created `(tx, rx)` pairs via `channel()`. Only one handle from
60/// each pair is in the args; the other was kept by the caller. This function
61/// stores bindings in the shared core so the kept handle can use them.
62///
63/// # Safety
64///
65/// `args_ptr` must point to valid, initialized memory for a value whose
66/// shape matches `plan.shape`.
67#[allow(unsafe_code)]
68pub unsafe fn bind_channels_caller_args(
69    args_ptr: *mut u8,
70    plan: &RpcPlan,
71    binder: &dyn ChannelBinder,
72) -> Vec<ChannelId> {
73    let shape = plan.shape;
74    let mut channel_ids = Vec::new();
75
76    for loc in plan.channel_locations {
77        // SAFETY: caller guarantees args_ptr is valid and initialized for this shape
78        let poke = unsafe { facet::Poke::from_raw_parts(PtrMut::new(args_ptr), shape) };
79
80        match poke.at_path_mut(&loc.path) {
81            Ok(channel_poke) => match loc.kind {
82                // r[impl rpc.channel.binding.caller-args.rx]
83                // Rx in args: handler receives, caller sends.
84                // Create a sink and store it in the shared core so the caller's
85                // paired Tx can send through it.
86                ChannelKind::Rx => {
87                    let (channel_id, sink) = binder.create_tx(loc.initial_credit);
88                    channel_ids.push(channel_id);
89                    let liveness = binder.channel_liveness();
90                    if let Ok(mut ps) = channel_poke.into_struct()
91                        && let Ok(mut core_field) = ps.field_by_name("core")
92                        && let Ok(slot) = core_field.get_mut::<CoreSlot>()
93                        && let Some(core) = &slot.inner
94                    {
95                        core.set_binding(ChannelBinding::Sink(BoundChannelSink { sink, liveness }));
96                    }
97                }
98                // r[impl rpc.channel.binding.caller-args.tx]
99                // Tx in args: handler sends, caller receives.
100                // Create a receiver and store it in the shared core so the caller's
101                // paired Rx can receive from it.
102                ChannelKind::Tx => {
103                    let (channel_id, receiver) = binder.create_rx();
104                    channel_ids.push(channel_id);
105                    let liveness = binder.channel_liveness();
106                    if let Ok(mut ps) = channel_poke.into_struct()
107                        && let Ok(mut core_field) = ps.field_by_name("core")
108                        && let Ok(slot) = core_field.get_mut::<CoreSlot>()
109                        && let Some(core) = &slot.inner
110                    {
111                        core.set_binding(ChannelBinding::Receiver(BoundChannelReceiver {
112                            receiver,
113                            liveness,
114                        }));
115                    }
116                }
117            },
118            Err(PathAccessError::OptionIsNone { .. }) => {
119                // Option<Tx/Rx> is None — skip
120            }
121            Err(_) => {}
122        }
123    }
124
125    channel_ids
126}
127
128// r[impl rpc.channel.binding]
129// r[impl rpc.channel.binding.callee-args]
130/// Bind channels in deserialized args on the **callee** side.
131///
132/// Handles are standalone (not part of a pair). Bind directly into the
133/// handle's local slot using channel IDs from `Request.channels`.
134///
135/// # Safety
136///
137/// `args_ptr` must point to valid, initialized memory for a value whose
138/// shape matches `plan.shape`.
139#[allow(unsafe_code)]
140pub unsafe fn bind_channels_callee_args(
141    args_ptr: *mut u8,
142    plan: &RpcPlan,
143    channel_ids: &[ChannelId],
144    binder: &dyn ChannelBinder,
145) {
146    let shape = plan.shape;
147    let mut id_idx = 0;
148
149    for loc in plan.channel_locations {
150        // SAFETY: caller guarantees args_ptr is valid and initialized for this shape
151        let poke = unsafe { facet::Poke::from_raw_parts(PtrMut::new(args_ptr), shape) };
152
153        match poke.at_path_mut(&loc.path) {
154            Ok(channel_poke) => {
155                if id_idx >= channel_ids.len() {
156                    break;
157                }
158                let channel_id = channel_ids[id_idx];
159                id_idx += 1;
160
161                match loc.kind {
162                    // r[impl rpc.channel.binding.callee-args.tx]
163                    // Tx in args: handler sends. Bind a sink directly.
164                    ChannelKind::Tx => {
165                        let sink = binder.bind_tx(channel_id, loc.initial_credit);
166                        let liveness = binder.channel_liveness();
167                        if let Ok(mut ps) = channel_poke.into_struct() {
168                            if let Ok(mut sink_field) = ps.field_by_name("sink")
169                                && let Ok(slot) = sink_field.get_mut::<SinkSlot>()
170                            {
171                                slot.inner = Some(sink);
172                            }
173                            if let Ok(mut liveness_field) = ps.field_by_name("liveness")
174                                && let Ok(slot) =
175                                    liveness_field.get_mut::<crate::channel::LivenessSlot>()
176                            {
177                                slot.inner = liveness;
178                            }
179                        }
180                    }
181                    // r[impl rpc.channel.binding.callee-args.rx]
182                    // Rx in args: handler receives. Register and bind a receiver directly.
183                    ChannelKind::Rx => {
184                        let receiver = binder.register_rx(channel_id);
185                        let liveness = binder.channel_liveness();
186                        if let Ok(mut ps) = channel_poke.into_struct() {
187                            if let Ok(mut receiver_field) = ps.field_by_name("receiver")
188                                && let Ok(slot) = receiver_field.get_mut::<ReceiverSlot>()
189                            {
190                                slot.inner = Some(receiver);
191                            }
192                            if let Ok(mut liveness_field) = ps.field_by_name("liveness")
193                                && let Ok(slot) =
194                                    liveness_field.get_mut::<crate::channel::LivenessSlot>()
195                            {
196                                slot.inner = liveness;
197                            }
198                        }
199                    }
200                }
201            }
202            Err(PathAccessError::OptionIsNone { .. }) => {
203                // Option<Tx/Rx> is None — skip this channel location
204            }
205            Err(_) => {}
206        }
207    }
208}
209
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::future::Future;
    use std::pin::Pin;
    use std::sync::{Arc, Mutex};

    use facet::Facet;
    use tokio::sync::mpsc;

    use crate::channel::{ChannelSink, IncomingChannelMessage, RxError, TxError, channel};
    use crate::{Backing, ChannelClose, ChannelId, Metadata, Payload, RpcPlan, Rx, SelfRef, Tx};

    use super::{ChannelBinder, bind_channels_callee_args, bind_channels_caller_args};

    /// A `ChannelSink` whose sends and closes always succeed and do nothing.
    #[derive(Default)]
    struct TestSink;

    impl ChannelSink for TestSink {
        fn send_payload<'payload>(
            &self,
            _payload: Payload<'payload>,
        ) -> Pin<Box<dyn Future<Output = Result<(), TxError>> + Send + 'payload>> {
            Box::pin(async { Ok(()) })
        }

        fn close_channel(
            &self,
            _metadata: Metadata,
        ) -> Pin<Box<dyn Future<Output = Result<(), TxError>> + Send + 'static>> {
            Box::pin(async { Ok(()) })
        }
    }

    /// A `ChannelBinder` that records every call it receives.
    ///
    /// IDs from `TestBinder::new` start at 100 and step by 2. For every
    /// receiver it hands out (`create_rx` / `register_rx`), it keeps the matching
    /// sender in `rx_senders`, so tests can push messages into bound `Rx` handles.
    #[derive(Default)]
    struct TestBinder {
        /// Next ID `alloc_id` will return.
        next_id: Mutex<u64>,
        /// `initial_credit` of each `create_tx` call, in call order.
        create_tx_credits: Mutex<Vec<u32>>,
        /// `(channel_id, initial_credit)` of each `bind_tx` call, in call order.
        bind_tx_calls: Mutex<Vec<(ChannelId, u32)>>,
        /// `channel_id` of each `register_rx` call, in call order.
        register_rx_calls: Mutex<Vec<ChannelId>>,
        /// Senders feeding the receivers returned by `create_rx` / `register_rx`,
        /// keyed by raw channel ID.
        rx_senders: Mutex<HashMap<u64, mpsc::Sender<IncomingChannelMessage>>>,
    }

    impl TestBinder {
        fn new() -> Self {
            Self {
                next_id: Mutex::new(100),
                ..Self::default()
            }
        }

        fn alloc_id(&self) -> ChannelId {
            let mut guard = self.next_id.lock().expect("next-id mutex poisoned");
            let id = *guard;
            *guard += 2;
            ChannelId(id)
        }

        /// Sender for a receiver previously handed out for `channel_id`.
        fn sender_for(&self, channel_id: ChannelId) -> mpsc::Sender<IncomingChannelMessage> {
            self.rx_senders
                .lock()
                .expect("sender map mutex poisoned")
                .get(&channel_id.0)
                .cloned()
                .expect("missing sender for channel id")
        }
    }

    impl ChannelBinder for TestBinder {
        fn create_tx(&self, initial_credit: u32) -> (ChannelId, Arc<dyn ChannelSink>) {
            self.create_tx_credits
                .lock()
                .expect("create-tx mutex poisoned")
                .push(initial_credit);
            (self.alloc_id(), Arc::new(TestSink))
        }

        fn create_rx(&self) -> (ChannelId, mpsc::Receiver<IncomingChannelMessage>) {
            let channel_id = self.alloc_id();
            let (tx, rx) = mpsc::channel(8);
            self.rx_senders
                .lock()
                .expect("sender map mutex poisoned")
                .insert(channel_id.0, tx);
            (channel_id, rx)
        }

        fn bind_tx(&self, channel_id: ChannelId, initial_credit: u32) -> Arc<dyn ChannelSink> {
            self.bind_tx_calls
                .lock()
                .expect("bind-tx mutex poisoned")
                .push((channel_id, initial_credit));
            Arc::new(TestSink)
        }

        fn register_rx(&self, channel_id: ChannelId) -> mpsc::Receiver<IncomingChannelMessage> {
            self.register_rx_calls
                .lock()
                .expect("register-rx mutex poisoned")
                .push(channel_id);
            let (tx, rx) = mpsc::channel(8);
            self.rx_senders
                .lock()
                .expect("sender map mutex poisoned")
                .insert(channel_id.0, tx);
            rx
        }
    }

    /// Build a `Close` message with default metadata and an empty backing buffer.
    fn close_message() -> IncomingChannelMessage {
        IncomingChannelMessage::Close(SelfRef::owning(
            Backing::Boxed(Box::<[u8]>::default()),
            ChannelClose {
                metadata: Metadata::default(),
            },
        ))
    }

    #[derive(Facet)]
    struct CallerArgs {
        tx: Tx<u32, 16>,
        rx: Rx<u32, 16>,
        maybe_tx: Option<Tx<u32, 16>>,
        maybe_rx: Option<Rx<u32, 16>>,
    }

    #[derive(Facet)]
    struct CalleeArgs {
        tx: Tx<u32, 16>,
        rx: Rx<u32, 16>,
    }

    #[tokio::test]
    async fn bind_channels_caller_args_binds_paired_handles_and_skips_none_options() {
        let (tx_arg, mut rx_peer) = channel::<u32>();
        let (tx_peer, rx_arg) = channel::<u32>();
        let mut args = CallerArgs {
            tx: tx_arg,
            rx: rx_arg,
            maybe_tx: None,
            maybe_rx: None,
        };

        let plan = RpcPlan::for_type::<CallerArgs>();
        let binder = TestBinder::new();

        // SAFETY: `args` is a live, initialized `CallerArgs`, whose shape is `plan.shape`.
        let channel_ids = unsafe {
            bind_channels_caller_args((&mut args as *mut CallerArgs).cast::<u8>(), plan, &binder)
        };

        assert_eq!(
            channel_ids.len(),
            2,
            "only present channels should be bound"
        );
        assert_eq!(
            binder
                .create_tx_credits
                .lock()
                .expect("create-tx mutex poisoned")
                .as_slice(),
            &[16],
            "Rx<T, N> in caller args should allocate sink with declared N credit"
        );

        tx_peer
            .send(77)
            .await
            .expect("paired Tx should become bound via create_tx");

        // `channel_ids[0]` belongs to the `tx` field, which was bound via `create_rx`.
        binder
            .sender_for(channel_ids[0])
            .send(close_message())
            .await
            .expect("send close to paired Rx");
        assert!(
            rx_peer.recv().await.expect("recv close").is_none(),
            "paired Rx should become bound via create_rx"
        );
    }

    #[tokio::test]
    async fn bind_channels_callee_args_binds_tx_and_rx_with_supplied_ids() {
        let mut args = CalleeArgs {
            tx: Tx::unbound(),
            rx: Rx::unbound(),
        };
        let plan = RpcPlan::for_type::<CalleeArgs>();
        let binder = TestBinder::new();
        let channel_ids = [ChannelId(41), ChannelId(43)];

        // SAFETY: `args` is a live, initialized `CalleeArgs`, whose shape is `plan.shape`.
        unsafe {
            bind_channels_callee_args(
                (&mut args as *mut CalleeArgs).cast::<u8>(),
                plan,
                &channel_ids,
                &binder,
            )
        };

        args.tx
            .send(5)
            .await
            .expect("callee-side Tx should be bound via bind_tx");

        binder
            .sender_for(ChannelId(43))
            .send(close_message())
            .await
            .expect("send close to bound callee Rx");
        assert!(args.rx.recv().await.expect("recv close").is_none());

        assert_eq!(
            binder
                .bind_tx_calls
                .lock()
                .expect("bind-tx mutex poisoned")
                .as_slice(),
            &[(ChannelId(41), 16)]
        );
        assert_eq!(
            binder
                .register_rx_calls
                .lock()
                .expect("register-rx mutex poisoned")
                .as_slice(),
            &[ChannelId(43)]
        );
    }

    #[tokio::test]
    async fn bind_channels_callee_args_stops_when_channel_ids_are_exhausted() {
        let mut args = CalleeArgs {
            tx: Tx::unbound(),
            rx: Rx::unbound(),
        };
        let plan = RpcPlan::for_type::<CalleeArgs>();
        let binder = TestBinder::new();
        let only_one_id = [ChannelId(51)];

        // SAFETY: `args` is a live, initialized `CalleeArgs`, whose shape is `plan.shape`.
        unsafe {
            bind_channels_callee_args(
                (&mut args as *mut CalleeArgs).cast::<u8>(),
                plan,
                &only_one_id,
                &binder,
            )
        };

        args.tx
            .send(1)
            .await
            .expect("first channel location should bind");
        let recv = args.rx.recv().await;
        assert!(
            matches!(recv, Err(RxError::Unbound)),
            "second channel location should remain unbound when IDs are exhausted"
        );
    }
}