1#![cfg(not(target_arch = "wasm32"))]
2use std::sync::Arc;
12
13use facet_core::PtrMut;
14use facet_path::PathAccessError;
15use tokio::sync::mpsc;
16
17use crate::ChannelId;
18use crate::channel::{
19 ChannelBinding, ChannelSink, CoreSlot, IncomingChannelMessage, ReceiverSlot, SinkSlot,
20};
21use crate::rpc_plan::{ChannelKind, RpcPlan};
22
23pub trait ChannelBinder: Send + Sync {
28 fn create_tx(&self, initial_credit: u32) -> (ChannelId, Arc<dyn ChannelSink>);
32
33 fn create_rx(&self) -> (ChannelId, mpsc::Receiver<IncomingChannelMessage>);
35
36 fn bind_tx(&self, channel_id: ChannelId, initial_credit: u32) -> Arc<dyn ChannelSink>;
41
42 fn register_rx(&self, channel_id: ChannelId) -> mpsc::Receiver<IncomingChannelMessage>;
46}
47
48#[allow(unsafe_code)]
61pub unsafe fn bind_channels_caller_args(
62 args_ptr: *mut u8,
63 plan: &RpcPlan,
64 binder: &dyn ChannelBinder,
65) -> Vec<ChannelId> {
66 let shape = plan.shape;
67 let mut channel_ids = Vec::new();
68
69 for loc in plan.channel_locations {
70 let poke = unsafe { facet::Poke::from_raw_parts(PtrMut::new(args_ptr), shape) };
72
73 match poke.at_path_mut(&loc.path) {
74 Ok(channel_poke) => match loc.kind {
75 ChannelKind::Rx => {
80 let (channel_id, sink) = binder.create_tx(loc.initial_credit);
81 channel_ids.push(channel_id);
82 if let Ok(mut ps) = channel_poke.into_struct()
83 && let Ok(mut core_field) = ps.field_by_name("core")
84 && let Ok(slot) = core_field.get_mut::<CoreSlot>()
85 && let Some(core) = &slot.inner
86 {
87 core.set_binding(ChannelBinding::Sink(sink));
88 }
89 }
90 ChannelKind::Tx => {
95 let (channel_id, receiver) = binder.create_rx();
96 channel_ids.push(channel_id);
97 if let Ok(mut ps) = channel_poke.into_struct()
98 && let Ok(mut core_field) = ps.field_by_name("core")
99 && let Ok(slot) = core_field.get_mut::<CoreSlot>()
100 && let Some(core) = &slot.inner
101 {
102 core.set_binding(ChannelBinding::Receiver(receiver));
103 }
104 }
105 },
106 Err(PathAccessError::OptionIsNone { .. }) => {
107 }
109 Err(_) => {}
110 }
111 }
112
113 channel_ids
114}
115
116#[allow(unsafe_code)]
128pub unsafe fn bind_channels_callee_args(
129 args_ptr: *mut u8,
130 plan: &RpcPlan,
131 channel_ids: &[ChannelId],
132 binder: &dyn ChannelBinder,
133) {
134 let shape = plan.shape;
135 let mut id_idx = 0;
136
137 for loc in plan.channel_locations {
138 let poke = unsafe { facet::Poke::from_raw_parts(PtrMut::new(args_ptr), shape) };
140
141 match poke.at_path_mut(&loc.path) {
142 Ok(channel_poke) => {
143 if id_idx >= channel_ids.len() {
144 break;
145 }
146 let channel_id = channel_ids[id_idx];
147 id_idx += 1;
148
149 match loc.kind {
150 ChannelKind::Tx => {
153 let sink = binder.bind_tx(channel_id, loc.initial_credit);
154 if let Ok(mut ps) = channel_poke.into_struct()
155 && let Ok(mut sink_field) = ps.field_by_name("sink")
156 && let Ok(slot) = sink_field.get_mut::<SinkSlot>()
157 {
158 slot.inner = Some(sink);
159 }
160 }
161 ChannelKind::Rx => {
164 let receiver = binder.register_rx(channel_id);
165 if let Ok(mut ps) = channel_poke.into_struct()
166 && let Ok(mut receiver_field) = ps.field_by_name("receiver")
167 && let Ok(slot) = receiver_field.get_mut::<ReceiverSlot>()
168 {
169 slot.inner = Some(receiver);
170 }
171 }
172 }
173 }
174 Err(PathAccessError::OptionIsNone { .. }) => {
175 }
177 Err(_) => {}
178 }
179 }
180}
181
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::future::Future;
    use std::pin::Pin;
    use std::sync::{Arc, Mutex};

    use facet::Facet;
    use tokio::sync::mpsc;

    use crate::channel::{ChannelSink, IncomingChannelMessage, RxError, TxError, channel};
    use crate::{Backing, ChannelClose, ChannelId, Metadata, Payload, RpcPlan, Rx, SelfRef, Tx};

    use super::{ChannelBinder, bind_channels_callee_args, bind_channels_caller_args};

    /// Sink that accepts every payload and close request and does nothing.
    #[derive(Default)]
    struct TestSink;

    impl ChannelSink for TestSink {
        fn send_payload<'payload>(
            &self,
            _payload: Payload<'payload>,
        ) -> Pin<Box<dyn Future<Output = Result<(), TxError>> + Send + 'payload>> {
            Box::pin(async { Ok(()) })
        }

        fn close_channel(
            &self,
            _metadata: Metadata,
        ) -> Pin<Box<dyn Future<Output = Result<(), TxError>> + Send + 'static>> {
            Box::pin(async { Ok(()) })
        }
    }

    /// [`ChannelBinder`] that records every call it receives.
    ///
    /// It also keeps the sending half of every receiver it hands out, so tests
    /// can inject messages into a bound `Rx`.
    #[derive(Default)]
    struct TestBinder {
        next_id: Mutex<u64>,
        create_tx_credits: Mutex<Vec<u32>>,
        bind_tx_calls: Mutex<Vec<(ChannelId, u32)>>,
        register_rx_calls: Mutex<Vec<ChannelId>>,
        rx_senders: Mutex<HashMap<u64, mpsc::Sender<IncomingChannelMessage>>>,
    }

    impl TestBinder {
        /// Creates a binder whose allocated IDs start at 100 and step by 2.
        fn new() -> Self {
            Self {
                next_id: Mutex::new(100),
                ..Self::default()
            }
        }

        fn alloc_id(&self) -> ChannelId {
            let mut guard = self.next_id.lock().expect("next-id mutex poisoned");
            let id = *guard;
            *guard += 2;
            ChannelId(id)
        }

        /// Returns the sender feeding the receiver handed out for `channel_id`.
        fn sender_for(&self, channel_id: ChannelId) -> mpsc::Sender<IncomingChannelMessage> {
            self.rx_senders
                .lock()
                .expect("sender map mutex poisoned")
                .get(&channel_id.0)
                .cloned()
                .expect("missing sender for channel id")
        }

        /// Number of receivers handed out so far (via `create_rx` or `register_rx`).
        fn receiver_count(&self) -> usize {
            self.rx_senders
                .lock()
                .expect("sender map mutex poisoned")
                .len()
        }
    }

    impl ChannelBinder for TestBinder {
        fn create_tx(&self, initial_credit: u32) -> (ChannelId, Arc<dyn ChannelSink>) {
            self.create_tx_credits
                .lock()
                .expect("create-tx mutex poisoned")
                .push(initial_credit);
            (self.alloc_id(), Arc::new(TestSink))
        }

        fn create_rx(&self) -> (ChannelId, mpsc::Receiver<IncomingChannelMessage>) {
            let channel_id = self.alloc_id();
            let (tx, rx) = mpsc::channel(8);
            self.rx_senders
                .lock()
                .expect("sender map mutex poisoned")
                .insert(channel_id.0, tx);
            (channel_id, rx)
        }

        fn bind_tx(&self, channel_id: ChannelId, initial_credit: u32) -> Arc<dyn ChannelSink> {
            self.bind_tx_calls
                .lock()
                .expect("bind-tx mutex poisoned")
                .push((channel_id, initial_credit));
            Arc::new(TestSink)
        }

        fn register_rx(&self, channel_id: ChannelId) -> mpsc::Receiver<IncomingChannelMessage> {
            self.register_rx_calls
                .lock()
                .expect("register-rx mutex poisoned")
                .push(channel_id);
            let (tx, rx) = mpsc::channel(8);
            self.rx_senders
                .lock()
                .expect("sender map mutex poisoned")
                .insert(channel_id.0, tx);
            rx
        }
    }

    /// Returns a copy of one of [`TestBinder`]'s call logs.
    fn snapshot<T: Clone>(log: &Mutex<Vec<T>>) -> Vec<T> {
        log.lock().expect("call-log mutex poisoned").clone()
    }

    /// Builds a `Close` message with default metadata. Delivering it makes the
    /// bound `Rx` yield `Ok(None)`.
    fn close_message() -> IncomingChannelMessage {
        IncomingChannelMessage::Close(SelfRef::owning(
            Backing::Boxed(Box::<[u8]>::default()),
            ChannelClose {
                metadata: Metadata::default(),
            },
        ))
    }

    #[derive(Facet)]
    struct CallerArgs {
        tx: Tx<u32, 16>,
        rx: Rx<u32, 16>,
        maybe_tx: Option<Tx<u32, 16>>,
        maybe_rx: Option<Rx<u32, 16>>,
    }

    #[derive(Facet)]
    struct CalleeArgs {
        tx: Tx<u32, 16>,
        rx: Rx<u32, 16>,
    }

    #[tokio::test]
    async fn bind_channels_caller_args_binds_paired_handles_and_skips_none_options() {
        let (tx_arg, mut rx_peer) = channel::<u32>();
        let (tx_peer, rx_arg) = channel::<u32>();
        let mut args = CallerArgs {
            tx: tx_arg,
            rx: rx_arg,
            maybe_tx: None,
            maybe_rx: None,
        };

        let plan = RpcPlan::for_type::<CallerArgs>();
        let binder = TestBinder::new();

        // SAFETY: `args` is a live, exclusively borrowed `CallerArgs`, and `plan`
        // was built for exactly that type.
        let channel_ids = unsafe {
            bind_channels_caller_args((&mut args as *mut CallerArgs).cast::<u8>(), plan, &binder)
        };

        // Locations are visited in field order: `tx` (create_rx -> 100), then
        // `rx` (create_tx -> 102); the two `None` options allocate nothing.
        assert_eq!(
            channel_ids,
            [ChannelId(100), ChannelId(102)],
            "only present channels should be bound, in location order"
        );
        assert_eq!(
            snapshot(&binder.create_tx_credits),
            [16],
            "Rx<T, N> in caller args should allocate sink with declared N credit"
        );
        assert_eq!(
            binder.receiver_count(),
            1,
            "exactly one receiver should be created, for the present Tx argument"
        );
        assert!(
            snapshot(&binder.bind_tx_calls).is_empty(),
            "caller-side binding must not call bind_tx"
        );
        assert!(
            snapshot(&binder.register_rx_calls).is_empty(),
            "caller-side binding must not call register_rx"
        );

        tx_peer
            .send(77)
            .await
            .expect("paired Tx should become bound via create_tx");

        binder
            .sender_for(channel_ids[0])
            .send(close_message())
            .await
            .expect("send close to paired Rx");
        assert!(
            rx_peer.recv().await.expect("recv close").is_none(),
            "paired Rx should become bound via create_rx"
        );
    }

    #[tokio::test]
    async fn bind_channels_callee_args_binds_tx_and_rx_with_supplied_ids() {
        let mut args = CalleeArgs {
            tx: Tx::unbound(),
            rx: Rx::unbound(),
        };
        let plan = RpcPlan::for_type::<CalleeArgs>();
        let binder = TestBinder::new();
        let channel_ids = [ChannelId(41), ChannelId(43)];

        // SAFETY: `args` is a live, exclusively borrowed `CalleeArgs`, and `plan`
        // was built for exactly that type.
        unsafe {
            bind_channels_callee_args(
                (&mut args as *mut CalleeArgs).cast::<u8>(),
                plan,
                &channel_ids,
                &binder,
            )
        };

        args.tx
            .send(5)
            .await
            .expect("callee-side Tx should be bound via bind_tx");

        binder
            .sender_for(ChannelId(43))
            .send(close_message())
            .await
            .expect("send close to bound callee Rx");
        assert!(args.rx.recv().await.expect("recv close").is_none());

        assert_eq!(snapshot(&binder.bind_tx_calls), [(ChannelId(41), 16)]);
        assert_eq!(snapshot(&binder.register_rx_calls), [ChannelId(43)]);
    }

    #[tokio::test]
    async fn bind_channels_callee_args_stops_when_channel_ids_are_exhausted() {
        let mut args = CalleeArgs {
            tx: Tx::unbound(),
            rx: Rx::unbound(),
        };
        let plan = RpcPlan::for_type::<CalleeArgs>();
        let binder = TestBinder::new();
        let only_one_id = [ChannelId(51)];

        // SAFETY: `args` is a live, exclusively borrowed `CalleeArgs`, and `plan`
        // was built for exactly that type.
        unsafe {
            bind_channels_callee_args(
                (&mut args as *mut CalleeArgs).cast::<u8>(),
                plan,
                &only_one_id,
                &binder,
            )
        };

        assert_eq!(
            snapshot(&binder.bind_tx_calls),
            [(ChannelId(51), 16)],
            "the single supplied ID should bind the first location"
        );
        assert!(
            snapshot(&binder.register_rx_calls).is_empty(),
            "no binder call should be made once IDs are exhausted"
        );

        args.tx
            .send(1)
            .await
            .expect("first channel location should bind");
        let recv = args.rx.recv().await;
        assert!(
            matches!(recv, Err(RxError::Unbound)),
            "second channel location should remain unbound when IDs are exhausted"
        );
    }
}