1#![cfg(not(target_arch = "wasm32"))]
2use std::sync::Arc;
12
13use crate::ChannelId;
14use crate::channel::{
15 BoundChannelReceiver, BoundChannelSink, ChannelBinding, ChannelLivenessHandle, ChannelSink,
16 CoreSlot, ReceiverSlot, ReplenisherSlot, SinkSlot,
17};
18use crate::rpc_plan::{ChannelKind, RpcPlan};
19use facet_core::PtrMut;
20use facet_path::PathAccessError;
21
/// Source of transport endpoints for the channel handles embedded in RPC arguments.
///
/// [`bind_channels_caller_args`] uses the `create_*` methods, which hand back a
/// channel ID along with the endpoint. [`bind_channels_callee_args`] uses
/// [`bind_tx`](ChannelBinder::bind_tx) and [`register_rx`](ChannelBinder::register_rx),
/// which take a channel ID that was supplied to that function.
pub trait ChannelBinder: Send + Sync {
    /// Returns a channel ID together with a sink for that channel.
    ///
    /// Invoked for `Rx` locations in caller arguments; `initial_credit` is taken
    /// from the channel location being bound.
    fn create_tx(&self, initial_credit: u32) -> (ChannelId, Arc<dyn ChannelSink>);

    /// Returns a channel ID together with a bound receiver for that channel.
    ///
    /// Invoked for `Tx` locations in caller arguments; `initial_credit` is taken
    /// from the channel location being bound.
    fn create_rx(&self, initial_credit: u32) -> (ChannelId, BoundChannelReceiver);

    /// Returns a sink for the already-known `channel_id`.
    ///
    /// Invoked for `Tx` locations in callee arguments.
    fn bind_tx(&self, channel_id: ChannelId, initial_credit: u32) -> Arc<dyn ChannelSink>;

    /// Returns a bound receiver for the already-known `channel_id`.
    ///
    /// Invoked for `Rx` locations in callee arguments.
    fn register_rx(&self, channel_id: ChannelId, initial_credit: u32) -> BoundChannelReceiver;

    /// Liveness handle stored next to every sink this binder hands out.
    ///
    /// The default implementation returns `None`.
    fn channel_liveness(&self) -> Option<ChannelLivenessHandle> {
        None
    }
}
52
53#[allow(unsafe_code)]
66pub unsafe fn bind_channels_caller_args(
67 args_ptr: *mut u8,
68 plan: &RpcPlan,
69 binder: &dyn ChannelBinder,
70) -> Vec<ChannelId> {
71 let shape = plan.shape;
72 let mut channel_ids = Vec::new();
73
74 for loc in plan.channel_locations {
75 let poke = unsafe { facet::Poke::from_raw_parts(PtrMut::new(args_ptr), shape) };
77
78 match poke.at_path_mut(&loc.path) {
79 Ok(channel_poke) => match loc.kind {
80 ChannelKind::Rx => {
85 let (channel_id, sink) = binder.create_tx(loc.initial_credit);
86 channel_ids.push(channel_id);
87 let liveness = binder.channel_liveness();
88 if let Ok(mut ps) = channel_poke.into_struct()
89 && let Ok(mut core_field) = ps.field_by_name("core")
90 && let Ok(slot) = core_field.get_mut::<CoreSlot>()
91 && let Some(core) = &slot.inner
92 {
93 core.set_binding(ChannelBinding::Sink(BoundChannelSink { sink, liveness }));
94 }
95 }
96 ChannelKind::Tx => {
101 let (channel_id, bound) = binder.create_rx(loc.initial_credit);
102 channel_ids.push(channel_id);
103 if let Ok(mut ps) = channel_poke.into_struct()
104 && let Ok(mut core_field) = ps.field_by_name("core")
105 && let Ok(slot) = core_field.get_mut::<CoreSlot>()
106 && let Some(core) = &slot.inner
107 {
108 core.set_binding(ChannelBinding::Receiver(bound));
109 }
110 }
111 },
112 Err(PathAccessError::OptionIsNone { .. }) => {
113 }
115 Err(_) => {}
116 }
117 }
118
119 channel_ids
120}
121
122#[allow(unsafe_code)]
134pub unsafe fn bind_channels_callee_args(
135 args_ptr: *mut u8,
136 plan: &RpcPlan,
137 channel_ids: &[ChannelId],
138 binder: &dyn ChannelBinder,
139) {
140 let shape = plan.shape;
141 let mut id_idx = 0;
142
143 for loc in plan.channel_locations {
144 let poke = unsafe { facet::Poke::from_raw_parts(PtrMut::new(args_ptr), shape) };
146
147 match poke.at_path_mut(&loc.path) {
148 Ok(channel_poke) => {
149 if id_idx >= channel_ids.len() {
150 break;
151 }
152 let channel_id = channel_ids[id_idx];
153 id_idx += 1;
154
155 match loc.kind {
156 ChannelKind::Tx => {
159 let sink = binder.bind_tx(channel_id, loc.initial_credit);
160 let liveness = binder.channel_liveness();
161 if let Ok(mut ps) = channel_poke.into_struct() {
162 if let Ok(mut sink_field) = ps.field_by_name("sink")
163 && let Ok(slot) = sink_field.get_mut::<SinkSlot>()
164 {
165 slot.inner = Some(sink);
166 }
167 if let Ok(mut liveness_field) = ps.field_by_name("liveness")
168 && let Ok(slot) =
169 liveness_field.get_mut::<crate::channel::LivenessSlot>()
170 {
171 slot.inner = liveness;
172 }
173 }
174 }
175 ChannelKind::Rx => {
178 let bound = binder.register_rx(channel_id, loc.initial_credit);
179 if let Ok(mut ps) = channel_poke.into_struct() {
180 if let Ok(mut receiver_field) = ps.field_by_name("receiver")
181 && let Ok(slot) = receiver_field.get_mut::<ReceiverSlot>()
182 {
183 slot.inner = Some(bound.receiver);
184 }
185 if let Ok(mut liveness_field) = ps.field_by_name("liveness")
186 && let Ok(slot) =
187 liveness_field.get_mut::<crate::channel::LivenessSlot>()
188 {
189 slot.inner = bound.liveness;
190 }
191 if let Ok(mut replenisher_field) = ps.field_by_name("replenisher")
192 && let Ok(slot) = replenisher_field.get_mut::<ReplenisherSlot>()
193 {
194 slot.inner = bound.replenisher;
195 }
196 }
197 }
198 }
199 }
200 Err(PathAccessError::OptionIsNone { .. }) => {
201 }
203 Err(_) => {}
204 }
205 }
206}
207
208#[cfg(test)]
209mod tests {
210 use std::collections::HashMap;
211 use std::future::Future;
212 use std::pin::Pin;
213 use std::sync::{Arc, Mutex};
214
215 use facet::Facet;
216 use tokio::sync::mpsc;
217
218 use crate::channel::{
219 BoundChannelReceiver, ChannelSink, IncomingChannelMessage, RxError, TxError, channel,
220 };
221 use crate::{Backing, ChannelClose, ChannelId, Metadata, Payload, RpcPlan, SelfRef, Tx};
222
223 use super::{ChannelBinder, bind_channels_callee_args, bind_channels_caller_args};
224
225 #[derive(Default)]
226 struct TestSink;
227
228 impl ChannelSink for TestSink {
229 fn send_payload<'payload>(
230 &self,
231 _payload: Payload<'payload>,
232 ) -> Pin<Box<dyn Future<Output = Result<(), TxError>> + Send + 'payload>> {
233 Box::pin(async { Ok(()) })
234 }
235
236 fn close_channel(
237 &self,
238 _metadata: Metadata,
239 ) -> Pin<Box<dyn Future<Output = Result<(), TxError>> + Send + 'static>> {
240 Box::pin(async { Ok(()) })
241 }
242 }
243
244 #[derive(Default)]
245 struct TestBinder {
246 next_id: Mutex<u64>,
247 create_tx_credits: Mutex<Vec<u32>>,
248 bind_tx_calls: Mutex<Vec<(ChannelId, u32)>>,
249 register_rx_calls: Mutex<Vec<(ChannelId, u32)>>,
250 rx_senders: Mutex<HashMap<u64, mpsc::Sender<IncomingChannelMessage>>>,
251 }
252
253 impl TestBinder {
254 fn new() -> Self {
255 Self {
256 next_id: Mutex::new(100),
257 ..Self::default()
258 }
259 }
260
261 fn alloc_id(&self) -> ChannelId {
262 let mut guard = self.next_id.lock().expect("next-id mutex poisoned");
263 let id = *guard;
264 *guard += 2;
265 ChannelId(id)
266 }
267
268 fn sender_for(&self, channel_id: ChannelId) -> mpsc::Sender<IncomingChannelMessage> {
269 self.rx_senders
270 .lock()
271 .expect("sender map mutex poisoned")
272 .get(&channel_id.0)
273 .cloned()
274 .expect("missing sender for channel id")
275 }
276 }
277
278 impl ChannelBinder for TestBinder {
279 fn create_tx(&self, initial_credit: u32) -> (ChannelId, Arc<dyn ChannelSink>) {
280 self.create_tx_credits
281 .lock()
282 .expect("create-tx mutex poisoned")
283 .push(initial_credit);
284 (self.alloc_id(), Arc::new(TestSink))
285 }
286
287 fn create_rx(&self, _initial_credit: u32) -> (ChannelId, BoundChannelReceiver) {
288 let channel_id = self.alloc_id();
289 let (tx, rx) = mpsc::channel(8);
290 self.rx_senders
291 .lock()
292 .expect("sender map mutex poisoned")
293 .insert(channel_id.0, tx);
294 (
295 channel_id,
296 BoundChannelReceiver {
297 receiver: rx,
298 liveness: None,
299 replenisher: None,
300 },
301 )
302 }
303
304 fn bind_tx(&self, channel_id: ChannelId, initial_credit: u32) -> Arc<dyn ChannelSink> {
305 self.bind_tx_calls
306 .lock()
307 .expect("bind-tx mutex poisoned")
308 .push((channel_id, initial_credit));
309 Arc::new(TestSink)
310 }
311
312 fn register_rx(&self, channel_id: ChannelId, initial_credit: u32) -> BoundChannelReceiver {
313 self.register_rx_calls
314 .lock()
315 .expect("register-rx mutex poisoned")
316 .push((channel_id, initial_credit));
317 let (tx, rx) = mpsc::channel(8);
318 self.rx_senders
319 .lock()
320 .expect("sender map mutex poisoned")
321 .insert(channel_id.0, tx);
322 BoundChannelReceiver {
323 receiver: rx,
324 liveness: None,
325 replenisher: None,
326 }
327 }
328 }
329
330 #[derive(Facet)]
331 struct CallerArgs {
332 tx: crate::Tx<u32, 16>,
333 rx: crate::Rx<u32, 16>,
334 maybe_tx: Option<crate::Tx<u32, 16>>,
335 maybe_rx: Option<crate::Rx<u32, 16>>,
336 }
337
338 #[derive(Facet)]
339 struct CalleeArgs {
340 tx: crate::Tx<u32, 16>,
341 rx: crate::Rx<u32, 16>,
342 }
343
344 #[tokio::test]
345 async fn bind_channels_caller_args_binds_paired_handles_and_skips_none_options() {
346 let (tx_arg, mut rx_peer) = channel::<u32>();
347 let (tx_peer, rx_arg) = channel::<u32>();
348 let mut args = CallerArgs {
349 tx: tx_arg,
350 rx: rx_arg,
351 maybe_tx: None,
352 maybe_rx: None,
353 };
354
355 let plan = RpcPlan::for_type::<CallerArgs>();
356 let binder = TestBinder::new();
357
358 let channel_ids = unsafe {
359 bind_channels_caller_args((&mut args as *mut CallerArgs).cast::<u8>(), plan, &binder)
360 };
361
362 assert_eq!(
363 channel_ids.len(),
364 2,
365 "only present channels should be bound"
366 );
367 assert_eq!(
368 binder
369 .create_tx_credits
370 .lock()
371 .expect("create-tx mutex poisoned")
372 .as_slice(),
373 &[16],
374 "Rx<T, N> in caller args should allocate sink with declared N credit"
375 );
376
377 tx_peer
378 .send(77)
379 .await
380 .expect("paired Tx should become bound via create_tx");
381
382 let close_ref = SelfRef::owning(
383 Backing::Boxed(Box::<[u8]>::default()),
384 ChannelClose {
385 metadata: Metadata::default(),
386 },
387 );
388 binder
389 .sender_for(channel_ids[0])
390 .send(IncomingChannelMessage::Close(close_ref))
391 .await
392 .expect("send close to paired Rx");
393 assert!(
394 rx_peer.recv().await.expect("recv close").is_none(),
395 "paired Rx should become bound via create_rx"
396 );
397 }
398
399 #[tokio::test]
400 async fn bind_channels_callee_args_binds_tx_and_rx_with_supplied_ids() {
401 let mut args = CalleeArgs {
402 tx: Tx::unbound(),
403 rx: crate::Rx::unbound(),
404 };
405 let plan = RpcPlan::for_type::<CalleeArgs>();
406 let binder = TestBinder::new();
407 let channel_ids = [ChannelId(41), ChannelId(43)];
408
409 unsafe {
410 bind_channels_callee_args(
411 (&mut args as *mut CalleeArgs).cast::<u8>(),
412 plan,
413 &channel_ids,
414 &binder,
415 )
416 };
417
418 args.tx
419 .send(5)
420 .await
421 .expect("callee-side Tx should be bound via bind_tx");
422
423 let close_ref = SelfRef::owning(
424 Backing::Boxed(Box::<[u8]>::default()),
425 ChannelClose {
426 metadata: Metadata::default(),
427 },
428 );
429 binder
430 .sender_for(ChannelId(43))
431 .send(IncomingChannelMessage::Close(close_ref))
432 .await
433 .expect("send close to bound callee Rx");
434 assert!(args.rx.recv().await.expect("recv close").is_none());
435
436 assert_eq!(
437 binder
438 .bind_tx_calls
439 .lock()
440 .expect("bind-tx mutex poisoned")
441 .as_slice(),
442 &[(ChannelId(41), 16)]
443 );
444 assert_eq!(
445 binder
446 .register_rx_calls
447 .lock()
448 .expect("register-rx mutex poisoned")
449 .as_slice(),
450 &[(ChannelId(43), 16)]
451 );
452 }
453
454 #[tokio::test]
455 async fn bind_channels_callee_args_stops_when_channel_ids_are_exhausted() {
456 let mut args = CalleeArgs {
457 tx: Tx::unbound(),
458 rx: crate::Rx::unbound(),
459 };
460 let plan = RpcPlan::for_type::<CalleeArgs>();
461 let binder = TestBinder::new();
462 let only_one_id = [ChannelId(51)];
463
464 unsafe {
465 bind_channels_callee_args(
466 (&mut args as *mut CalleeArgs).cast::<u8>(),
467 plan,
468 &only_one_id,
469 &binder,
470 )
471 };
472
473 args.tx
474 .send(1)
475 .await
476 .expect("first channel location should bind");
477 let recv = args.rx.recv().await;
478 assert!(
479 matches!(recv, Err(RxError::Unbound)),
480 "second channel location should remain unbound when IDs are exhausted"
481 );
482 }
483}