Skip to main content

roam_types/
channel_binding.rs

1#![cfg(not(target_arch = "wasm32"))]
2//! Channel binding infrastructure for connecting Tx/Rx handles to the driver.
3//!
//! The functions below attach the channel handles found in request args to the driver:
5//!
6//! - [`bind_channels_caller_args`]: Caller-side, arg position. Allocates IDs,
7//!   stores bindings in the shared core so the paired handle can use them.
8//! - [`bind_channels_callee_args`]: Callee-side, arg position. Binds deserialized
9//!   standalone handles directly using IDs from `Request.channels`.
10
11use std::sync::Arc;
12
13use facet_core::PtrMut;
14use facet_path::PathAccessError;
15use tokio::sync::mpsc;
16
17use crate::ChannelId;
18use crate::channel::{
19    ChannelBinding, ChannelSink, CoreSlot, IncomingChannelMessage, ReceiverSlot, SinkSlot,
20};
21use crate::rpc_plan::{ChannelKind, RpcPlan};
22
23/// Trait for channel operations, implemented by the session driver.
24///
25/// This abstraction lets the binding functions and macro-generated code bind
26/// channels without depending on concrete driver types.
27pub trait ChannelBinder: Send + Sync {
28    /// Allocate a channel ID and create a sink for sending items.
29    ///
30    /// `initial_credit` is the const generic `N` from `Tx<T, N>` or `Rx<T, N>`.
31    fn create_tx(&self, initial_credit: u32) -> (ChannelId, Arc<dyn ChannelSink>);
32
33    /// Allocate a channel ID, register it for routing, and return a receiver.
34    fn create_rx(&self) -> (ChannelId, mpsc::Receiver<IncomingChannelMessage>);
35
36    /// Create a sink for a known channel ID (callee side).
37    ///
38    /// The channel ID comes from `Request.channels`.
39    /// `initial_credit` is the const generic `N` from `Tx<T, N>`.
40    fn bind_tx(&self, channel_id: ChannelId, initial_credit: u32) -> Arc<dyn ChannelSink>;
41
42    /// Register an inbound channel by ID and return the receiver (callee side).
43    ///
44    /// The channel ID comes from `Request.channels`.
45    fn register_rx(&self, channel_id: ChannelId) -> mpsc::Receiver<IncomingChannelMessage>;
46}
47
48// r[impl rpc.channel.binding.caller-args]
49// r[impl rpc.channel.allocation]
50/// Bind channels in args on the **caller** side, returning channel IDs.
51///
52/// The caller created `(tx, rx)` pairs via `channel()`. Only one handle from
53/// each pair is in the args; the other was kept by the caller. This function
54/// stores bindings in the shared core so the kept handle can use them.
55///
56/// # Safety
57///
58/// `args_ptr` must point to valid, initialized memory for a value whose
59/// shape matches `plan.shape`.
60#[allow(unsafe_code)]
61pub unsafe fn bind_channels_caller_args(
62    args_ptr: *mut u8,
63    plan: &RpcPlan,
64    binder: &dyn ChannelBinder,
65) -> Vec<ChannelId> {
66    let shape = plan.shape;
67    let mut channel_ids = Vec::new();
68
69    for loc in plan.channel_locations {
70        // SAFETY: caller guarantees args_ptr is valid and initialized for this shape
71        let poke = unsafe { facet::Poke::from_raw_parts(PtrMut::new(args_ptr), shape) };
72
73        match poke.at_path_mut(&loc.path) {
74            Ok(channel_poke) => match loc.kind {
75                // r[impl rpc.channel.binding.caller-args.rx]
76                // Rx in args: handler receives, caller sends.
77                // Create a sink and store it in the shared core so the caller's
78                // paired Tx can send through it.
79                ChannelKind::Rx => {
80                    let (channel_id, sink) = binder.create_tx(loc.initial_credit);
81                    channel_ids.push(channel_id);
82                    if let Ok(mut ps) = channel_poke.into_struct()
83                        && let Ok(mut core_field) = ps.field_by_name("core")
84                        && let Ok(slot) = core_field.get_mut::<CoreSlot>()
85                        && let Some(core) = &slot.inner
86                    {
87                        core.set_binding(ChannelBinding::Sink(sink));
88                    }
89                }
90                // r[impl rpc.channel.binding.caller-args.tx]
91                // Tx in args: handler sends, caller receives.
92                // Create a receiver and store it in the shared core so the caller's
93                // paired Rx can receive from it.
94                ChannelKind::Tx => {
95                    let (channel_id, receiver) = binder.create_rx();
96                    channel_ids.push(channel_id);
97                    if let Ok(mut ps) = channel_poke.into_struct()
98                        && let Ok(mut core_field) = ps.field_by_name("core")
99                        && let Ok(slot) = core_field.get_mut::<CoreSlot>()
100                        && let Some(core) = &slot.inner
101                    {
102                        core.set_binding(ChannelBinding::Receiver(receiver));
103                    }
104                }
105            },
106            Err(PathAccessError::OptionIsNone { .. }) => {
107                // Option<Tx/Rx> is None — skip
108            }
109            Err(_) => {}
110        }
111    }
112
113    channel_ids
114}
115
116// r[impl rpc.channel.binding]
117// r[impl rpc.channel.binding.callee-args]
118/// Bind channels in deserialized args on the **callee** side.
119///
120/// Handles are standalone (not part of a pair). Bind directly into the
121/// handle's local slot using channel IDs from `Request.channels`.
122///
123/// # Safety
124///
125/// `args_ptr` must point to valid, initialized memory for a value whose
126/// shape matches `plan.shape`.
127#[allow(unsafe_code)]
128pub unsafe fn bind_channels_callee_args(
129    args_ptr: *mut u8,
130    plan: &RpcPlan,
131    channel_ids: &[ChannelId],
132    binder: &dyn ChannelBinder,
133) {
134    let shape = plan.shape;
135    let mut id_idx = 0;
136
137    for loc in plan.channel_locations {
138        // SAFETY: caller guarantees args_ptr is valid and initialized for this shape
139        let poke = unsafe { facet::Poke::from_raw_parts(PtrMut::new(args_ptr), shape) };
140
141        match poke.at_path_mut(&loc.path) {
142            Ok(channel_poke) => {
143                if id_idx >= channel_ids.len() {
144                    break;
145                }
146                let channel_id = channel_ids[id_idx];
147                id_idx += 1;
148
149                match loc.kind {
150                    // r[impl rpc.channel.binding.callee-args.tx]
151                    // Tx in args: handler sends. Bind a sink directly.
152                    ChannelKind::Tx => {
153                        let sink = binder.bind_tx(channel_id, loc.initial_credit);
154                        if let Ok(mut ps) = channel_poke.into_struct()
155                            && let Ok(mut sink_field) = ps.field_by_name("sink")
156                            && let Ok(slot) = sink_field.get_mut::<SinkSlot>()
157                        {
158                            slot.inner = Some(sink);
159                        }
160                    }
161                    // r[impl rpc.channel.binding.callee-args.rx]
162                    // Rx in args: handler receives. Register and bind a receiver directly.
163                    ChannelKind::Rx => {
164                        let receiver = binder.register_rx(channel_id);
165                        if let Ok(mut ps) = channel_poke.into_struct()
166                            && let Ok(mut receiver_field) = ps.field_by_name("receiver")
167                            && let Ok(slot) = receiver_field.get_mut::<ReceiverSlot>()
168                        {
169                            slot.inner = Some(receiver);
170                        }
171                    }
172                }
173            }
174            Err(PathAccessError::OptionIsNone { .. }) => {
175                // Option<Tx/Rx> is None — skip this channel location
176            }
177            Err(_) => {}
178        }
179    }
180}
181
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::future::Future;
    use std::pin::Pin;
    use std::sync::{Arc, Mutex};

    use facet::Facet;
    use tokio::sync::mpsc;

    use crate::channel::{ChannelSink, IncomingChannelMessage, RxError, TxError, channel};
    use crate::{Backing, ChannelClose, ChannelId, Metadata, Payload, RpcPlan, Rx, SelfRef, Tx};

    use super::{ChannelBinder, bind_channels_callee_args, bind_channels_caller_args};

    /// Sink that accepts every payload and close request and does nothing.
    #[derive(Default)]
    struct TestSink;

    impl ChannelSink for TestSink {
        fn send_payload<'payload>(
            &self,
            _payload: Payload<'payload>,
        ) -> Pin<Box<dyn Future<Output = Result<(), TxError>> + Send + 'payload>> {
            Box::pin(async { Ok(()) })
        }

        fn close_channel(
            &self,
            _metadata: Metadata,
        ) -> Pin<Box<dyn Future<Output = Result<(), TxError>> + Send + 'static>> {
            Box::pin(async { Ok(()) })
        }
    }

    /// Recording binder. It hands out IDs 100, 102, 104, ... (via `new`) and
    /// keeps every call so tests can check which driver hooks were used.
    #[derive(Default)]
    struct TestBinder {
        next_id: Mutex<u64>,
        create_tx_credits: Mutex<Vec<u32>>,
        bind_tx_calls: Mutex<Vec<(ChannelId, u32)>>,
        register_rx_calls: Mutex<Vec<ChannelId>>,
        // Sender halves of every receiver handed out, keyed by channel ID, so a
        // test can inject inbound messages.
        rx_senders: Mutex<HashMap<u64, mpsc::Sender<IncomingChannelMessage>>>,
    }

    impl TestBinder {
        fn new() -> Self {
            Self {
                next_id: Mutex::new(100),
                ..Self::default()
            }
        }

        fn alloc_id(&self) -> ChannelId {
            let mut guard = self.next_id.lock().expect("next-id mutex poisoned");
            let id = *guard;
            *guard += 2;
            ChannelId(id)
        }

        fn sender_for(&self, channel_id: ChannelId) -> mpsc::Sender<IncomingChannelMessage> {
            self.rx_senders
                .lock()
                .expect("sender map mutex poisoned")
                .get(&channel_id.0)
                .cloned()
                .expect("missing sender for channel id")
        }
    }

    impl ChannelBinder for TestBinder {
        fn create_tx(&self, initial_credit: u32) -> (ChannelId, Arc<dyn ChannelSink>) {
            self.create_tx_credits
                .lock()
                .expect("create-tx mutex poisoned")
                .push(initial_credit);
            (self.alloc_id(), Arc::new(TestSink))
        }

        fn create_rx(&self) -> (ChannelId, mpsc::Receiver<IncomingChannelMessage>) {
            let channel_id = self.alloc_id();
            let (tx, rx) = mpsc::channel(8);
            self.rx_senders
                .lock()
                .expect("sender map mutex poisoned")
                .insert(channel_id.0, tx);
            (channel_id, rx)
        }

        fn bind_tx(&self, channel_id: ChannelId, initial_credit: u32) -> Arc<dyn ChannelSink> {
            self.bind_tx_calls
                .lock()
                .expect("bind-tx mutex poisoned")
                .push((channel_id, initial_credit));
            Arc::new(TestSink)
        }

        fn register_rx(&self, channel_id: ChannelId) -> mpsc::Receiver<IncomingChannelMessage> {
            self.register_rx_calls
                .lock()
                .expect("register-rx mutex poisoned")
                .push(channel_id);
            let (tx, rx) = mpsc::channel(8);
            self.rx_senders
                .lock()
                .expect("sender map mutex poisoned")
                .insert(channel_id.0, tx);
            rx
        }
    }

    /// Build an inbound `Close` message with empty metadata.
    fn close_message() -> IncomingChannelMessage {
        let close_ref = SelfRef::owning(
            Backing::Boxed(Box::<[u8]>::default()),
            ChannelClose {
                metadata: Metadata::default(),
            },
        );
        IncomingChannelMessage::Close(close_ref)
    }

    #[derive(Facet)]
    struct CallerArgs {
        tx: Tx<u32, 16>,
        rx: Rx<u32, 16>,
        maybe_tx: Option<Tx<u32, 16>>,
        maybe_rx: Option<Rx<u32, 16>>,
    }

    #[derive(Facet)]
    struct CalleeArgs {
        tx: Tx<u32, 16>,
        rx: Rx<u32, 16>,
    }

    #[tokio::test]
    async fn bind_channels_caller_args_binds_paired_handles_and_skips_none_options() {
        let (tx_arg, mut rx_peer) = channel::<u32>();
        let (tx_peer, rx_arg) = channel::<u32>();
        let mut args = CallerArgs {
            tx: tx_arg,
            rx: rx_arg,
            maybe_tx: None,
            maybe_rx: None,
        };

        let plan = RpcPlan::for_type::<CallerArgs>();
        let binder = TestBinder::new();

        let channel_ids = unsafe {
            bind_channels_caller_args((&mut args as *mut CallerArgs).cast::<u8>(), plan, &binder)
        };

        assert_eq!(
            channel_ids.len(),
            2,
            "only present channels should be bound"
        );
        assert_eq!(
            binder
                .create_tx_credits
                .lock()
                .expect("create-tx mutex poisoned")
                .as_slice(),
            &[16],
            "Rx<T, N> in caller args should allocate sink with declared N credit"
        );
        assert_eq!(
            binder
                .rx_senders
                .lock()
                .expect("sender map mutex poisoned")
                .len(),
            1,
            "only the present Tx in caller args should allocate a receiver via create_rx"
        );
        assert!(
            binder
                .bind_tx_calls
                .lock()
                .expect("bind-tx mutex poisoned")
                .is_empty(),
            "caller-side binding must not use callee-side bind_tx"
        );
        assert!(
            binder
                .register_rx_calls
                .lock()
                .expect("register-rx mutex poisoned")
                .is_empty(),
            "caller-side binding must not use callee-side register_rx"
        );

        tx_peer
            .send(77)
            .await
            .expect("paired Tx should become bound via create_tx");

        binder
            .sender_for(channel_ids[0])
            .send(close_message())
            .await
            .expect("send close to paired Rx");
        assert!(
            rx_peer.recv().await.expect("recv close").is_none(),
            "paired Rx should become bound via create_rx"
        );
    }

    #[tokio::test]
    async fn bind_channels_callee_args_binds_tx_and_rx_with_supplied_ids() {
        let mut args = CalleeArgs {
            tx: Tx::unbound(),
            rx: Rx::unbound(),
        };
        let plan = RpcPlan::for_type::<CalleeArgs>();
        let binder = TestBinder::new();
        let channel_ids = [ChannelId(41), ChannelId(43)];

        unsafe {
            bind_channels_callee_args(
                (&mut args as *mut CalleeArgs).cast::<u8>(),
                plan,
                &channel_ids,
                &binder,
            )
        };

        args.tx
            .send(5)
            .await
            .expect("callee-side Tx should be bound via bind_tx");

        binder
            .sender_for(ChannelId(43))
            .send(close_message())
            .await
            .expect("send close to bound callee Rx");
        assert!(args.rx.recv().await.expect("recv close").is_none());

        assert_eq!(
            binder
                .bind_tx_calls
                .lock()
                .expect("bind-tx mutex poisoned")
                .as_slice(),
            &[(ChannelId(41), 16)]
        );
        assert_eq!(
            binder
                .register_rx_calls
                .lock()
                .expect("register-rx mutex poisoned")
                .as_slice(),
            &[ChannelId(43)]
        );
    }

    #[tokio::test]
    async fn bind_channels_callee_args_stops_when_channel_ids_are_exhausted() {
        let mut args = CalleeArgs {
            tx: Tx::unbound(),
            rx: Rx::unbound(),
        };
        let plan = RpcPlan::for_type::<CalleeArgs>();
        let binder = TestBinder::new();
        let only_one_id = [ChannelId(51)];

        unsafe {
            bind_channels_callee_args(
                (&mut args as *mut CalleeArgs).cast::<u8>(),
                plan,
                &only_one_id,
                &binder,
            )
        };

        args.tx
            .send(1)
            .await
            .expect("first channel location should bind");
        let recv = args.rx.recv().await;
        assert!(
            matches!(recv, Err(RxError::Unbound)),
            "second channel location should remain unbound when IDs are exhausted"
        );

        assert_eq!(
            binder
                .bind_tx_calls
                .lock()
                .expect("bind-tx mutex poisoned")
                .as_slice(),
            &[(ChannelId(51), 16)],
            "only the first channel location should consume the single ID"
        );
        assert!(
            binder
                .register_rx_calls
                .lock()
                .expect("register-rx mutex poisoned")
                .is_empty(),
            "no receiver should be registered once IDs are exhausted"
        );
    }
}