1use std::convert::Infallible;
2use std::future::Future;
3use std::marker::PhantomData;
4use std::pin::Pin;
5use std::sync::Arc;
6#[cfg(not(target_arch = "wasm32"))]
7use std::sync::Mutex;
8#[cfg(not(target_arch = "wasm32"))]
9use std::sync::atomic::{AtomicBool, Ordering};
10
11use facet::Facet;
12use facet_core::PtrConst;
13#[cfg(not(target_arch = "wasm32"))]
14use tokio::sync::{Semaphore, mpsc};
15
16#[cfg(not(target_arch = "wasm32"))]
17use crate::{ChannelClose, ChannelItem, ChannelReset, Metadata, Payload, SelfRef};
18
19#[cfg(not(target_arch = "wasm32"))]
22pub enum ChannelBinding {
23 Sink(Arc<dyn ChannelSink>),
24 Receiver(mpsc::Receiver<IncomingChannelMessage>),
25}
26
27#[cfg(not(target_arch = "wasm32"))]
34pub struct ChannelCore {
35 binding: Mutex<Option<ChannelBinding>>,
36}
37
38#[cfg(not(target_arch = "wasm32"))]
39impl ChannelCore {
40 fn new() -> Self {
41 Self {
42 binding: Mutex::new(None),
43 }
44 }
45
46 pub fn set_binding(&self, binding: ChannelBinding) {
48 let mut guard = self.binding.lock().expect("channel core mutex poisoned");
49 assert!(guard.is_none(), "channel binding already set");
50 *guard = Some(binding);
51 }
52
53 pub fn get_sink(&self) -> Option<Arc<dyn ChannelSink>> {
56 let guard = self.binding.lock().expect("channel core mutex poisoned");
57 match guard.as_ref() {
58 Some(ChannelBinding::Sink(sink)) => Some(sink.clone()),
59 _ => None,
60 }
61 }
62
63 pub fn take_receiver(&self) -> Option<mpsc::Receiver<IncomingChannelMessage>> {
66 let mut guard = self.binding.lock().expect("channel core mutex poisoned");
67 match guard.take() {
68 Some(ChannelBinding::Receiver(rx)) => Some(rx),
69 other => {
70 *guard = other;
72 None
73 }
74 }
75 }
76}
77
78#[derive(Facet)]
80#[facet(opaque)]
81pub(crate) struct CoreSlot {
82 #[cfg(not(target_arch = "wasm32"))]
83 pub(crate) inner: Option<Arc<ChannelCore>>,
84}
85
86impl CoreSlot {
87 pub(crate) fn empty() -> Self {
88 Self {
89 #[cfg(not(target_arch = "wasm32"))]
90 inner: None,
91 }
92 }
93}
94
95pub fn channel<T>() -> (Tx<T>, Rx<T>) {
102 #[cfg(not(target_arch = "wasm32"))]
103 {
104 let core = Arc::new(ChannelCore::new());
105 (Tx::paired(core.clone()), Rx::paired(core))
106 }
107 #[cfg(target_arch = "wasm32")]
108 {
109 (Tx::unbound(), Rx::unbound())
110 }
111}
112
/// Transport-side destination for outgoing channel traffic.
///
/// Sinks are shared as `Arc<dyn ChannelSink>` (see `ChannelBinding::Sink` and
/// `SinkSlot`), so implementors must be `Send + Sync + 'static`.
#[cfg(not(target_arch = "wasm32"))]
pub trait ChannelSink: Send + Sync + 'static {
    /// Sends one item payload.
    ///
    /// The returned future is bounded by `'payload`, so it may borrow whatever
    /// `payload` points at. Callers must keep that data alive until the
    /// future completes.
    fn send_payload<'payload>(
        &self,
        payload: Payload<'payload>,
    ) -> Pin<Box<dyn Future<Output = Result<(), TxError>> + Send + 'payload>>;

    /// Requests that the channel be closed, passing `metadata` through to the
    /// implementation.
    ///
    /// The returned future is `'static`: it cannot borrow from `self`.
    fn close_channel(
        &self,
        metadata: Metadata,
    ) -> Pin<Box<dyn Future<Output = Result<(), TxError>> + Send + 'static>>;

    /// Synchronous hook that `Tx`'s `Drop` impl calls when the sender is
    /// dropped before being marked closed. The default does nothing.
    fn close_channel_on_drop(&self) {}
}
136
137#[cfg(not(target_arch = "wasm32"))]
145pub struct CreditSink<S: ChannelSink> {
146 inner: S,
147 credit: Arc<Semaphore>,
148}
149
150#[cfg(not(target_arch = "wasm32"))]
151impl<S: ChannelSink> CreditSink<S> {
152 pub fn new(inner: S, initial_credit: u32) -> Self {
156 Self {
157 inner,
158 credit: Arc::new(Semaphore::new(initial_credit as usize)),
159 }
160 }
161
162 pub fn credit(&self) -> &Arc<Semaphore> {
165 &self.credit
166 }
167}
168
169#[cfg(not(target_arch = "wasm32"))]
170impl<S: ChannelSink> ChannelSink for CreditSink<S> {
171 fn send_payload<'payload>(
172 &self,
173 payload: Payload<'payload>,
174 ) -> Pin<Box<dyn Future<Output = Result<(), TxError>> + Send + 'payload>> {
175 let credit = self.credit.clone();
176 let fut = self.inner.send_payload(payload);
177 Box::pin(async move {
178 let permit = credit
179 .acquire()
180 .await
181 .map_err(|_| TxError::Transport("channel credit semaphore closed".into()))?;
182 permit.forget();
183 fut.await
184 })
185 }
186
187 fn close_channel(
188 &self,
189 metadata: Metadata,
190 ) -> Pin<Box<dyn Future<Output = Result<(), TxError>> + Send + 'static>> {
191 self.inner.close_channel(metadata)
193 }
194
195 fn close_channel_on_drop(&self) {
196 self.inner.close_channel_on_drop();
197 }
198}
199
/// A message delivered to the receiving side of a channel and consumed by
/// `Rx::recv`.
///
/// Each variant is wrapped in a [`SelfRef`]. The tests build these with
/// `SelfRef::owning(backing, value)`, so the wrapper presumably keeps the value
/// together with the buffer it may borrow from (confirm against `SelfRef`'s
/// definition).
#[cfg(not(target_arch = "wasm32"))]
pub enum IncomingChannelMessage {
    /// A data item; `Rx::recv` deserializes its payload.
    Item(SelfRef<ChannelItem<'static>>),
    /// Graceful close; `Rx::recv` reports it as `Ok(None)`.
    Close(SelfRef<ChannelClose<'static>>),
    /// Reset; `Rx::recv` reports it as `Err(RxError::Reset)`.
    Reset(SelfRef<ChannelReset<'static>>),
}
207
208#[derive(Facet)]
210#[facet(opaque)]
211pub(crate) struct SinkSlot {
212 #[cfg(not(target_arch = "wasm32"))]
213 pub(crate) inner: Option<Arc<dyn ChannelSink>>,
214}
215
216impl SinkSlot {
217 pub(crate) fn empty() -> Self {
218 Self {
219 #[cfg(not(target_arch = "wasm32"))]
220 inner: None,
221 }
222 }
223}
224
225#[derive(Facet)]
227#[facet(opaque)]
228pub(crate) struct ReceiverSlot {
229 #[cfg(not(target_arch = "wasm32"))]
230 pub(crate) inner: Option<mpsc::Receiver<IncomingChannelMessage>>,
231}
232
233impl ReceiverSlot {
234 pub(crate) fn empty() -> Self {
235 Self {
236 #[cfg(not(target_arch = "wasm32"))]
237 inner: None,
238 }
239 }
240}
241
/// Sending half of a channel that carries values of type `T`.
///
/// A `Tx` finds its transport in one of two places. The first is `sink`, set
/// by `Tx::bind`. The second is the shared `core` created by [`channel`]; the
/// direct sink takes precedence.
///
/// The type is reflected with `#[facet(proxy = ())]`, presumably so that it
/// round-trips as `()` through the `TryFrom` impls in this file; confirm
/// against facet's proxy docs.
///
/// NOTE(review): the const parameter `N` (default 16) is not read anywhere in
/// this file. It may be a buffer or credit size hint used elsewhere; confirm.
#[derive(Facet)]
#[facet(proxy = ())]
pub struct Tx<T, const N: usize = 16> {
    pub(crate) sink: SinkSlot,
    pub(crate) core: CoreSlot,
    // Marked `true` by `close` and by `Drop`. `Drop` uses `swap` on it so the
    // drop-time close hook fires at most once and not after an explicit close.
    #[cfg(not(target_arch = "wasm32"))]
    #[facet(opaque)]
    closed: AtomicBool,
    #[facet(opaque)]
    _marker: PhantomData<T>,
}
262
263impl<T, const N: usize> Tx<T, N> {
264 pub fn unbound() -> Self {
266 Self {
267 sink: SinkSlot::empty(),
268 core: CoreSlot::empty(),
269 #[cfg(not(target_arch = "wasm32"))]
270 closed: AtomicBool::new(false),
271 _marker: PhantomData,
272 }
273 }
274
275 #[cfg(not(target_arch = "wasm32"))]
277 fn paired(core: Arc<ChannelCore>) -> Self {
278 Self {
279 sink: SinkSlot::empty(),
280 core: CoreSlot { inner: Some(core) },
281 closed: AtomicBool::new(false),
282 _marker: PhantomData,
283 }
284 }
285
286 pub fn is_bound(&self) -> bool {
287 #[cfg(not(target_arch = "wasm32"))]
288 {
289 if self.sink.inner.is_some() {
290 return true;
291 }
292 if let Some(core) = &self.core.inner {
293 return core.get_sink().is_some();
294 }
295 false
296 }
297 #[cfg(target_arch = "wasm32")]
298 false
299 }
300
301 pub fn has_core(&self) -> bool {
303 #[cfg(not(target_arch = "wasm32"))]
304 return self.core.inner.is_some();
305 #[cfg(target_arch = "wasm32")]
306 return false;
307 }
308
309 #[cfg(not(target_arch = "wasm32"))]
311 fn resolve_sink(&self) -> Result<Arc<dyn ChannelSink>, TxError> {
312 if let Some(sink) = &self.sink.inner {
314 return Ok(sink.clone());
315 }
316 if let Some(core) = &self.core.inner
318 && let Some(sink) = core.get_sink()
319 {
320 return Ok(sink);
321 }
322 Err(TxError::Unbound)
323 }
324
325 #[cfg(not(target_arch = "wasm32"))]
326 pub async fn send<'value>(&self, value: T) -> Result<(), TxError>
327 where
328 T: Facet<'value>,
329 {
330 let sink = self.resolve_sink()?;
331 let ptr = PtrConst::new((&value as *const T).cast::<u8>());
332 let payload = unsafe { Payload::outgoing_unchecked(ptr, T::SHAPE) };
335 let result = sink.send_payload(payload).await;
336 drop(value);
337 result
338 }
339
340 #[cfg(not(target_arch = "wasm32"))]
342 pub async fn close<'value>(&self, metadata: Metadata<'value>) -> Result<(), TxError> {
343 self.closed.store(true, Ordering::Release);
344 let sink = self.resolve_sink()?;
345 sink.close_channel(metadata).await
346 }
347
348 #[doc(hidden)]
349 #[cfg(not(target_arch = "wasm32"))]
350 pub fn bind(&mut self, sink: Arc<dyn ChannelSink>) {
351 self.sink.inner = Some(sink);
352 }
353}
354
355#[cfg(not(target_arch = "wasm32"))]
356impl<T, const N: usize> Drop for Tx<T, N> {
357 fn drop(&mut self) {
358 if self.closed.swap(true, Ordering::AcqRel) {
359 return;
360 }
361
362 let sink = if let Some(sink) = &self.sink.inner {
363 Some(sink.clone())
364 } else if let Some(core) = &self.core.inner {
365 core.get_sink()
366 } else {
367 None
368 };
369
370 let Some(sink) = sink else {
371 return;
372 };
373
374 sink.close_channel_on_drop();
376 }
377}
378
379#[allow(clippy::infallible_try_from)]
380impl<T, const N: usize> TryFrom<&Tx<T, N>> for () {
381 type Error = Infallible;
382
383 fn try_from(_value: &Tx<T, N>) -> Result<Self, Self::Error> {
384 Ok(())
385 }
386}
387
388#[allow(clippy::infallible_try_from)]
389impl<T, const N: usize> TryFrom<()> for Tx<T, N> {
390 type Error = Infallible;
391
392 fn try_from(_value: ()) -> Result<Self, Self::Error> {
393 Ok(Self::unbound())
394 }
395}
396
/// Errors returned by `Tx` operations.
#[derive(Debug)]
pub enum TxError {
    /// No sink could be resolved, either directly or through a shared core.
    Unbound,
    /// The sink (or a wrapper around it) reported a failure described by the
    /// contained text.
    Transport(String),
}

impl std::fmt::Display for TxError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TxError::Unbound => f.write_str("channel is not bound"),
            TxError::Transport(detail) => write!(f, "transport error: {detail}"),
        }
    }
}

impl std::error::Error for TxError {}
414
/// Receiving half of a channel that carries values of type `T`.
///
/// The message queue is either bound directly into `receiver` (via `Rx::bind`)
/// or claimed lazily from the shared `core` on the first `recv`.
///
/// The type is reflected with `#[facet(proxy = ())]`, presumably so that it
/// round-trips as `()` through the `TryFrom` impls in this file; confirm
/// against facet's proxy docs.
///
/// NOTE(review): the const parameter `N` (default 16) is not read anywhere in
/// this file; confirm its purpose.
#[derive(Facet)]
#[facet(proxy = ())]
pub struct Rx<T, const N: usize = 16> {
    pub(crate) receiver: ReceiverSlot,
    pub(crate) core: CoreSlot,
    #[facet(opaque)]
    _marker: PhantomData<T>,
}
429
430impl<T, const N: usize> Rx<T, N> {
431 pub fn unbound() -> Self {
433 Self {
434 receiver: ReceiverSlot::empty(),
435 core: CoreSlot::empty(),
436 _marker: PhantomData,
437 }
438 }
439
440 #[cfg(not(target_arch = "wasm32"))]
442 fn paired(core: Arc<ChannelCore>) -> Self {
443 Self {
444 receiver: ReceiverSlot::empty(),
445 core: CoreSlot { inner: Some(core) },
446 _marker: PhantomData,
447 }
448 }
449
450 pub fn is_bound(&self) -> bool {
451 #[cfg(not(target_arch = "wasm32"))]
452 {
453 if self.receiver.inner.is_some() {
454 return true;
455 }
456 false
457 }
458 #[cfg(target_arch = "wasm32")]
459 false
460 }
461
462 pub fn has_core(&self) -> bool {
464 #[cfg(not(target_arch = "wasm32"))]
465 return self.core.inner.is_some();
466 #[cfg(target_arch = "wasm32")]
467 return false;
468 }
469
470 #[cfg(not(target_arch = "wasm32"))]
472 pub async fn recv(&mut self) -> Result<Option<SelfRef<T>>, RxError>
473 where
474 T: Facet<'static>,
475 {
476 if self.receiver.inner.is_none()
478 && let Some(core) = &self.core.inner
479 && let Some(rx) = core.take_receiver()
480 {
481 self.receiver.inner = Some(rx);
482 }
483
484 let receiver = self.receiver.inner.as_mut().ok_or(RxError::Unbound)?;
485 match receiver.recv().await {
486 Some(IncomingChannelMessage::Close(_)) | None => Ok(None),
487 Some(IncomingChannelMessage::Reset(_)) => Err(RxError::Reset),
488 Some(IncomingChannelMessage::Item(msg)) => msg
489 .try_repack(|item, _backing_bytes| {
490 let Payload::Incoming(bytes) = item.item else {
491 return Err(RxError::Protocol(
492 "incoming channel item payload was not Incoming".into(),
493 ));
494 };
495 facet_postcard::from_slice_borrowed(bytes).map_err(RxError::Deserialize)
496 })
497 .map(Some),
498 }
499 }
500
501 #[doc(hidden)]
502 #[cfg(not(target_arch = "wasm32"))]
503 pub fn bind(&mut self, receiver: mpsc::Receiver<IncomingChannelMessage>) {
504 self.receiver.inner = Some(receiver);
505 }
506}
507
508#[allow(clippy::infallible_try_from)]
509impl<T, const N: usize> TryFrom<&Rx<T, N>> for () {
510 type Error = Infallible;
511
512 fn try_from(_value: &Rx<T, N>) -> Result<Self, Self::Error> {
513 Ok(())
514 }
515}
516
517#[allow(clippy::infallible_try_from)]
518impl<T, const N: usize> TryFrom<()> for Rx<T, N> {
519 type Error = Infallible;
520
521 fn try_from(_value: ()) -> Result<Self, Self::Error> {
522 Ok(Self::unbound())
523 }
524}
525
526#[derive(Debug)]
528pub enum RxError {
529 Unbound,
530 Reset,
531 Deserialize(facet_postcard::DeserializeError),
532 Protocol(String),
533}
534
535impl std::fmt::Display for RxError {
536 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
537 match self {
538 Self::Unbound => write!(f, "channel is not bound"),
539 Self::Reset => write!(f, "channel reset by peer"),
540 Self::Deserialize(e) => write!(f, "deserialize error: {e}"),
541 Self::Protocol(msg) => write!(f, "protocol error: {msg}"),
542 }
543 }
544}
545
546impl std::error::Error for RxError {}
547
548pub fn is_tx(shape: &facet_core::Shape) -> bool {
550 shape.decl_id == Tx::<()>::SHAPE.decl_id
551}
552
553pub fn is_rx(shape: &facet_core::Shape) -> bool {
555 shape.decl_id == Rx::<()>::SHAPE.decl_id
556}
557
558pub fn is_channel(shape: &facet_core::Shape) -> bool {
560 is_tx(shape) || is_rx(shape)
561}
562
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Backing, ChannelClose, ChannelItem, ChannelReset, Metadata, SelfRef};
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test sink that records how many times each `ChannelSink` method is
    /// called. Its futures resolve to `Ok(())` immediately.
    struct CountingSink {
        send_calls: AtomicUsize,
        close_calls: AtomicUsize,
        close_on_drop_calls: AtomicUsize,
    }

    impl CountingSink {
        fn new() -> Self {
            Self {
                send_calls: AtomicUsize::new(0),
                close_calls: AtomicUsize::new(0),
                close_on_drop_calls: AtomicUsize::new(0),
            }
        }
    }

    impl ChannelSink for CountingSink {
        // Counts at call time, not when the returned future is polled.
        fn send_payload<'payload>(
            &self,
            _payload: Payload<'payload>,
        ) -> Pin<Box<dyn Future<Output = Result<(), TxError>> + Send + 'payload>> {
            self.send_calls.fetch_add(1, Ordering::AcqRel);
            Box::pin(async { Ok(()) })
        }

        fn close_channel(
            &self,
            _metadata: Metadata,
        ) -> Pin<Box<dyn Future<Output = Result<(), TxError>> + Send + 'static>> {
            self.close_calls.fetch_add(1, Ordering::AcqRel);
            Box::pin(async { Ok(()) })
        }

        fn close_channel_on_drop(&self) {
            self.close_on_drop_calls.fetch_add(1, Ordering::AcqRel);
        }
    }

    /// An explicit `close` marks the `Tx` closed, so dropping it afterwards
    /// must not call `close_channel_on_drop` as well.
    #[tokio::test]
    async fn tx_close_does_not_emit_drop_close_after_explicit_close() {
        let sink_impl = Arc::new(CountingSink::new());
        let sink: Arc<dyn ChannelSink> = sink_impl.clone();

        let mut tx = Tx::<u32>::unbound();
        tx.bind(sink);
        tx.close(Metadata::default())
            .await
            .expect("close should succeed");
        drop(tx);

        assert_eq!(sink_impl.close_calls.load(Ordering::Acquire), 1);
        assert_eq!(sink_impl.close_on_drop_calls.load(Ordering::Acquire), 0);
    }

    /// Dropping a directly bound, never-closed `Tx` calls the hook exactly once.
    #[test]
    fn tx_drop_emits_close_on_drop_for_bound_sink() {
        let sink_impl = Arc::new(CountingSink::new());
        let sink: Arc<dyn ChannelSink> = sink_impl.clone();

        let mut tx = Tx::<u32>::unbound();
        tx.bind(sink);
        drop(tx);

        assert_eq!(sink_impl.close_on_drop_calls.load(Ordering::Acquire), 1);
    }

    /// The drop hook also fires when the sink was installed on the shared core
    /// rather than bound directly. `_rx` (not `_`) keeps the receiving half
    /// alive until the end of the test.
    #[test]
    fn tx_drop_emits_close_on_drop_for_paired_core_binding() {
        let sink_impl = Arc::new(CountingSink::new());
        let sink: Arc<dyn ChannelSink> = sink_impl.clone();

        let (tx, _rx) = channel::<u32>();
        let core = tx.core.inner.as_ref().expect("paired tx should have core");
        core.set_binding(ChannelBinding::Sink(sink));
        drop(tx);

        assert_eq!(sink_impl.close_on_drop_calls.load(Ordering::Acquire), 1);
    }

    /// `recv` on an `Rx` with no receiver and no core reports `Unbound`.
    #[tokio::test]
    async fn rx_recv_returns_unbound_when_not_bound() {
        let mut rx = Rx::<u32>::unbound();
        let err = match rx.recv().await {
            Ok(_) => panic!("unbound rx should fail"),
            Err(err) => err,
        };
        assert!(matches!(err, RxError::Unbound));
    }

    /// A `Close` message is surfaced as end-of-stream (`Ok(None)`).
    #[tokio::test]
    async fn rx_recv_returns_none_on_close() {
        let (tx, rx_inner) = mpsc::channel(1);
        let mut rx = Rx::<u32>::unbound();
        rx.bind(rx_inner);

        let close = SelfRef::owning(
            Backing::Boxed(Box::<[u8]>::default()),
            ChannelClose {
                metadata: Metadata::default(),
            },
        );
        tx.send(IncomingChannelMessage::Close(close))
            .await
            .expect("send close");

        assert!(rx.recv().await.expect("recv should succeed").is_none());
    }

    /// A `Reset` message is surfaced as `Err(RxError::Reset)`.
    #[tokio::test]
    async fn rx_recv_returns_reset_error() {
        let (tx, rx_inner) = mpsc::channel(1);
        let mut rx = Rx::<u32>::unbound();
        rx.bind(rx_inner);

        let reset = SelfRef::owning(
            Backing::Boxed(Box::<[u8]>::default()),
            ChannelReset {
                metadata: Metadata::default(),
            },
        );
        tx.send(IncomingChannelMessage::Reset(reset))
            .await
            .expect("send reset");

        let err = match rx.recv().await {
            Ok(_) => panic!("reset should be surfaced as error"),
            Err(err) => err,
        };
        assert!(matches!(err, RxError::Reset));
    }

    /// An item whose payload is the outgoing variant, not `Incoming`, is
    /// rejected as a protocol error. `VALUE` is `static` so the payload can
    /// borrow it for the `'static` lifetime that `ChannelItem<'static>` needs.
    #[tokio::test]
    async fn rx_recv_rejects_outgoing_payload_variant_as_protocol_error() {
        static VALUE: u32 = 42;

        let (tx, rx_inner) = mpsc::channel(1);
        let mut rx = Rx::<u32>::unbound();
        rx.bind(rx_inner);

        let item = SelfRef::owning(
            Backing::Boxed(Box::<[u8]>::default()),
            ChannelItem {
                item: Payload::outgoing(&VALUE),
            },
        );
        tx.send(IncomingChannelMessage::Item(item))
            .await
            .expect("send item");

        let err = match rx.recv().await {
            Ok(_) => panic!("outgoing payload should be protocol error"),
            Err(err) => err,
        };
        assert!(matches!(err, RxError::Protocol(_)));
    }
}