1#![cfg(not(target_arch = "wasm32"))]
2use std::sync::Arc;
12
13use facet_core::PtrMut;
14use facet_path::PathAccessError;
15use tokio::sync::mpsc;
16
17use crate::ChannelId;
18use crate::channel::{
19 BoundChannelReceiver, BoundChannelSink, ChannelBinding, ChannelLivenessHandle, ChannelSink,
20 CoreSlot, IncomingChannelMessage, ReceiverSlot, SinkSlot,
21};
22use crate::rpc_plan::{ChannelKind, RpcPlan};
23
24pub trait ChannelBinder: Send + Sync {
29 fn create_tx(&self, initial_credit: u32) -> (ChannelId, Arc<dyn ChannelSink>);
33
34 fn create_rx(&self) -> (ChannelId, mpsc::Receiver<IncomingChannelMessage>);
36
37 fn bind_tx(&self, channel_id: ChannelId, initial_credit: u32) -> Arc<dyn ChannelSink>;
42
43 fn register_rx(&self, channel_id: ChannelId) -> mpsc::Receiver<IncomingChannelMessage>;
47
48 fn channel_liveness(&self) -> Option<ChannelLivenessHandle> {
51 None
52 }
53}
54
55#[allow(unsafe_code)]
68pub unsafe fn bind_channels_caller_args(
69 args_ptr: *mut u8,
70 plan: &RpcPlan,
71 binder: &dyn ChannelBinder,
72) -> Vec<ChannelId> {
73 let shape = plan.shape;
74 let mut channel_ids = Vec::new();
75
76 for loc in plan.channel_locations {
77 let poke = unsafe { facet::Poke::from_raw_parts(PtrMut::new(args_ptr), shape) };
79
80 match poke.at_path_mut(&loc.path) {
81 Ok(channel_poke) => match loc.kind {
82 ChannelKind::Rx => {
87 let (channel_id, sink) = binder.create_tx(loc.initial_credit);
88 channel_ids.push(channel_id);
89 let liveness = binder.channel_liveness();
90 if let Ok(mut ps) = channel_poke.into_struct()
91 && let Ok(mut core_field) = ps.field_by_name("core")
92 && let Ok(slot) = core_field.get_mut::<CoreSlot>()
93 && let Some(core) = &slot.inner
94 {
95 core.set_binding(ChannelBinding::Sink(BoundChannelSink { sink, liveness }));
96 }
97 }
98 ChannelKind::Tx => {
103 let (channel_id, receiver) = binder.create_rx();
104 channel_ids.push(channel_id);
105 let liveness = binder.channel_liveness();
106 if let Ok(mut ps) = channel_poke.into_struct()
107 && let Ok(mut core_field) = ps.field_by_name("core")
108 && let Ok(slot) = core_field.get_mut::<CoreSlot>()
109 && let Some(core) = &slot.inner
110 {
111 core.set_binding(ChannelBinding::Receiver(BoundChannelReceiver {
112 receiver,
113 liveness,
114 }));
115 }
116 }
117 },
118 Err(PathAccessError::OptionIsNone { .. }) => {
119 }
121 Err(_) => {}
122 }
123 }
124
125 channel_ids
126}
127
128#[allow(unsafe_code)]
140pub unsafe fn bind_channels_callee_args(
141 args_ptr: *mut u8,
142 plan: &RpcPlan,
143 channel_ids: &[ChannelId],
144 binder: &dyn ChannelBinder,
145) {
146 let shape = plan.shape;
147 let mut id_idx = 0;
148
149 for loc in plan.channel_locations {
150 let poke = unsafe { facet::Poke::from_raw_parts(PtrMut::new(args_ptr), shape) };
152
153 match poke.at_path_mut(&loc.path) {
154 Ok(channel_poke) => {
155 if id_idx >= channel_ids.len() {
156 break;
157 }
158 let channel_id = channel_ids[id_idx];
159 id_idx += 1;
160
161 match loc.kind {
162 ChannelKind::Tx => {
165 let sink = binder.bind_tx(channel_id, loc.initial_credit);
166 let liveness = binder.channel_liveness();
167 if let Ok(mut ps) = channel_poke.into_struct() {
168 if let Ok(mut sink_field) = ps.field_by_name("sink")
169 && let Ok(slot) = sink_field.get_mut::<SinkSlot>()
170 {
171 slot.inner = Some(sink);
172 }
173 if let Ok(mut liveness_field) = ps.field_by_name("liveness")
174 && let Ok(slot) =
175 liveness_field.get_mut::<crate::channel::LivenessSlot>()
176 {
177 slot.inner = liveness;
178 }
179 }
180 }
181 ChannelKind::Rx => {
184 let receiver = binder.register_rx(channel_id);
185 let liveness = binder.channel_liveness();
186 if let Ok(mut ps) = channel_poke.into_struct() {
187 if let Ok(mut receiver_field) = ps.field_by_name("receiver")
188 && let Ok(slot) = receiver_field.get_mut::<ReceiverSlot>()
189 {
190 slot.inner = Some(receiver);
191 }
192 if let Ok(mut liveness_field) = ps.field_by_name("liveness")
193 && let Ok(slot) =
194 liveness_field.get_mut::<crate::channel::LivenessSlot>()
195 {
196 slot.inner = liveness;
197 }
198 }
199 }
200 }
201 }
202 Err(PathAccessError::OptionIsNone { .. }) => {
203 }
205 Err(_) => {}
206 }
207 }
208}
209
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::future::Future;
    use std::pin::Pin;
    use std::sync::{Arc, Mutex};

    use facet::Facet;
    use tokio::sync::mpsc;

    use crate::channel::{ChannelSink, IncomingChannelMessage, RxError, TxError, channel};
    use crate::{Backing, ChannelClose, ChannelId, Metadata, Payload, RpcPlan, SelfRef, Tx};

    use super::{ChannelBinder, bind_channels_callee_args, bind_channels_caller_args};

    /// Boxed future returned by the `ChannelSink` methods.
    type SinkFuture<'a> = Pin<Box<dyn Future<Output = Result<(), TxError>> + Send + 'a>>;

    /// Sink that accepts every payload and every close request.
    #[derive(Default)]
    struct TestSink;

    impl ChannelSink for TestSink {
        fn send_payload<'payload>(&self, _payload: Payload<'payload>) -> SinkFuture<'payload> {
            Box::pin(async { Ok(()) })
        }

        fn close_channel(&self, _metadata: Metadata) -> SinkFuture<'static> {
            Box::pin(async { Ok(()) })
        }
    }

    /// Binder that records every call and keeps the sending half of each
    /// receiver it hands out, so tests can inject incoming messages.
    #[derive(Default)]
    struct TestBinder {
        next_id: Mutex<u64>,
        create_tx_credits: Mutex<Vec<u32>>,
        bind_tx_calls: Mutex<Vec<(ChannelId, u32)>>,
        register_rx_calls: Mutex<Vec<ChannelId>>,
        rx_senders: Mutex<HashMap<u64, mpsc::Sender<IncomingChannelMessage>>>,
    }

    impl TestBinder {
        /// Creates a binder whose first allocated channel ID is 100.
        fn new() -> Self {
            Self {
                next_id: Mutex::new(100),
                create_tx_credits: Mutex::default(),
                bind_tx_calls: Mutex::default(),
                register_rx_calls: Mutex::default(),
                rx_senders: Mutex::default(),
            }
        }

        /// Hands out the current ID and advances the counter by two.
        fn alloc_id(&self) -> ChannelId {
            let mut next = self.next_id.lock().expect("next-id mutex poisoned");
            let allocated = ChannelId(*next);
            *next += 2;
            allocated
        }

        /// Stores the sending half that feeds the receiver of `channel_id`.
        fn remember_sender(
            &self,
            channel_id: ChannelId,
            sender: mpsc::Sender<IncomingChannelMessage>,
        ) {
            self.rx_senders
                .lock()
                .expect("sender map mutex poisoned")
                .insert(channel_id.0, sender);
        }

        /// Returns a clone of the sending half stored for `channel_id`.
        fn sender_for(&self, channel_id: ChannelId) -> mpsc::Sender<IncomingChannelMessage> {
            let senders = self.rx_senders.lock().expect("sender map mutex poisoned");
            senders
                .get(&channel_id.0)
                .cloned()
                .expect("missing sender for channel id")
        }
    }

    impl ChannelBinder for TestBinder {
        fn create_tx(&self, initial_credit: u32) -> (ChannelId, Arc<dyn ChannelSink>) {
            self.create_tx_credits
                .lock()
                .expect("create-tx mutex poisoned")
                .push(initial_credit);
            let channel_id = self.alloc_id();
            (channel_id, Arc::new(TestSink))
        }

        fn create_rx(&self) -> (ChannelId, mpsc::Receiver<IncomingChannelMessage>) {
            let channel_id = self.alloc_id();
            let (sender, receiver) = mpsc::channel(8);
            self.remember_sender(channel_id, sender);
            (channel_id, receiver)
        }

        fn bind_tx(&self, channel_id: ChannelId, initial_credit: u32) -> Arc<dyn ChannelSink> {
            self.bind_tx_calls
                .lock()
                .expect("bind-tx mutex poisoned")
                .push((channel_id, initial_credit));
            Arc::new(TestSink)
        }

        fn register_rx(&self, channel_id: ChannelId) -> mpsc::Receiver<IncomingChannelMessage> {
            self.register_rx_calls
                .lock()
                .expect("register-rx mutex poisoned")
                .push(channel_id);
            let (sender, receiver) = mpsc::channel(8);
            self.remember_sender(channel_id, sender);
            receiver
        }
    }

    /// Caller-side argument shape: two mandatory and two optional channels.
    #[derive(Facet)]
    struct CallerArgs {
        tx: crate::Tx<u32, 16>,
        rx: crate::Rx<u32, 16>,
        maybe_tx: Option<crate::Tx<u32, 16>>,
        maybe_rx: Option<crate::Rx<u32, 16>>,
    }

    /// Callee-side argument shape: one channel per direction.
    #[derive(Facet)]
    struct CalleeArgs {
        tx: crate::Tx<u32, 16>,
        rx: crate::Rx<u32, 16>,
    }

    /// Builds an incoming close message with default metadata.
    fn close_message() -> IncomingChannelMessage {
        IncomingChannelMessage::Close(SelfRef::owning(
            Backing::Boxed(Box::<[u8]>::default()),
            ChannelClose {
                metadata: Metadata::default(),
            },
        ))
    }

    #[tokio::test]
    async fn bind_channels_caller_args_binds_paired_handles_and_skips_none_options() {
        let (tx_arg, mut rx_peer) = channel::<u32>();
        let (tx_peer, rx_arg) = channel::<u32>();
        let mut args = CallerArgs {
            tx: tx_arg,
            rx: rx_arg,
            maybe_tx: None,
            maybe_rx: None,
        };

        let binder = TestBinder::new();
        let plan = RpcPlan::for_type::<CallerArgs>();
        let args_ptr = std::ptr::from_mut(&mut args).cast::<u8>();

        let channel_ids = unsafe { bind_channels_caller_args(args_ptr, plan, &binder) };

        assert_eq!(
            channel_ids.len(),
            2,
            "only present channels should be bound"
        );

        let recorded_credits = binder
            .create_tx_credits
            .lock()
            .expect("create-tx mutex poisoned")
            .clone();
        assert_eq!(
            recorded_credits,
            [16],
            "Rx<T, N> in caller args should allocate sink with declared N credit"
        );

        // The peer of the `rx` argument must now be able to send.
        tx_peer
            .send(77)
            .await
            .expect("paired Tx should become bound via create_tx");

        // The peer of the `tx` argument must observe an injected close.
        let close = close_message();
        binder
            .sender_for(channel_ids[0])
            .send(close)
            .await
            .expect("send close to paired Rx");
        let after_close = rx_peer.recv().await.expect("recv close");
        assert!(
            after_close.is_none(),
            "paired Rx should become bound via create_rx"
        );
    }

    #[tokio::test]
    async fn bind_channels_callee_args_binds_tx_and_rx_with_supplied_ids() {
        let mut args = CalleeArgs {
            tx: Tx::unbound(),
            rx: crate::Rx::unbound(),
        };
        let binder = TestBinder::new();
        let plan = RpcPlan::for_type::<CalleeArgs>();
        let supplied_ids = [ChannelId(41), ChannelId(43)];
        let args_ptr = std::ptr::from_mut(&mut args).cast::<u8>();

        unsafe { bind_channels_callee_args(args_ptr, plan, &supplied_ids, &binder) };

        args.tx
            .send(5)
            .await
            .expect("callee-side Tx should be bound via bind_tx");

        let close = close_message();
        binder
            .sender_for(ChannelId(43))
            .send(close)
            .await
            .expect("send close to bound callee Rx");
        let after_close = args.rx.recv().await.expect("recv close");
        assert!(after_close.is_none());

        let bind_tx_calls = binder
            .bind_tx_calls
            .lock()
            .expect("bind-tx mutex poisoned")
            .clone();
        assert_eq!(bind_tx_calls, [(ChannelId(41), 16)]);

        let register_rx_calls = binder
            .register_rx_calls
            .lock()
            .expect("register-rx mutex poisoned")
            .clone();
        assert_eq!(register_rx_calls, [ChannelId(43)]);
    }

    #[tokio::test]
    async fn bind_channels_callee_args_stops_when_channel_ids_are_exhausted() {
        let mut args = CalleeArgs {
            tx: Tx::unbound(),
            rx: crate::Rx::unbound(),
        };
        let binder = TestBinder::new();
        let plan = RpcPlan::for_type::<CalleeArgs>();
        let only_one_id = [ChannelId(51)];
        let args_ptr = std::ptr::from_mut(&mut args).cast::<u8>();

        unsafe { bind_channels_callee_args(args_ptr, plan, &only_one_id, &binder) };

        args.tx
            .send(1)
            .await
            .expect("first channel location should bind");

        let outcome = args.rx.recv().await;
        assert!(
            matches!(outcome, Err(RxError::Unbound)),
            "second channel location should remain unbound when IDs are exhausted"
        );
    }
}