1use std::convert::Infallible;
2use std::future::Future;
3use std::marker::PhantomData;
4use std::pin::Pin;
5use std::sync::Arc;
6#[cfg(not(target_arch = "wasm32"))]
7use std::sync::Mutex;
8#[cfg(not(target_arch = "wasm32"))]
9use std::sync::atomic::{AtomicBool, Ordering};
10
11use facet::Facet;
12use facet_core::PtrConst;
13#[cfg(not(target_arch = "wasm32"))]
14use tokio::sync::{Semaphore, mpsc};
15
16#[cfg(not(target_arch = "wasm32"))]
17use crate::{ChannelClose, ChannelItem, ChannelReset, Metadata, Payload, SelfRef};
18
#[cfg(not(target_arch = "wasm32"))]
/// What a `ChannelCore` has been bound to.
///
/// Recorded once via `ChannelCore::set_binding`. A paired `Tx` reads the `Sink` variant through
/// `ChannelCore::get_sink`; a paired `Rx` removes the `Receiver` variant through
/// `ChannelCore::take_receiver`.
pub enum ChannelBinding {
    /// Outgoing side: the paired `Tx` sends its items and close requests to this sink.
    Sink(BoundChannelSink),
    /// Incoming side: the paired `Rx` adopts this queue on `recv` if it has none of its own.
    Receiver(BoundChannelReceiver),
}
26
#[cfg(not(target_arch = "wasm32"))]
/// Marker for opaque values that a bound channel end keeps alive.
///
/// Blanket-implemented for every `Send + Sync + 'static` type, so any such value can be erased
/// into a [`ChannelLivenessHandle`].
pub trait ChannelLiveness: Send + Sync + 'static {}

#[cfg(not(target_arch = "wasm32"))]
impl<T: Send + Sync + 'static> ChannelLiveness for T {}

#[cfg(not(target_arch = "wasm32"))]
/// Type-erased, shared liveness token stored by a bound `Tx`/`Rx`.
///
/// NOTE(review): this file only stores the handle and never reads it back; presumably dropping
/// the last clone is what tells the owner that the channel end is gone — confirm with the binder.
pub type ChannelLivenessHandle = Arc<dyn ChannelLiveness>;
35
#[cfg(not(target_arch = "wasm32"))]
/// A sink recorded on a `ChannelCore`, plus an optional liveness token stored alongside it.
#[derive(Clone)]
pub struct BoundChannelSink {
    /// Destination for outgoing payloads and close requests.
    pub sink: Arc<dyn ChannelSink>,
    /// Optional token kept together with the sink in the core.
    pub liveness: Option<ChannelLivenessHandle>,
}

#[cfg(not(target_arch = "wasm32"))]
/// An incoming-message queue recorded on a `ChannelCore`, plus the liveness token that the
/// adopting `Rx` takes over together with the queue.
pub struct BoundChannelReceiver {
    /// Queue of messages destined for the paired `Rx`.
    pub receiver: mpsc::Receiver<IncomingChannelMessage>,
    /// Optional token moved into the `Rx` when it adopts `receiver`.
    pub liveness: Option<ChannelLivenessHandle>,
}
48
49#[cfg(not(target_arch = "wasm32"))]
56pub struct ChannelCore {
57 binding: Mutex<Option<ChannelBinding>>,
58}
59
60#[cfg(not(target_arch = "wasm32"))]
61impl ChannelCore {
62 fn new() -> Self {
63 Self {
64 binding: Mutex::new(None),
65 }
66 }
67
68 pub fn set_binding(&self, binding: ChannelBinding) {
70 let mut guard = self.binding.lock().expect("channel core mutex poisoned");
71 assert!(guard.is_none(), "channel binding already set");
72 *guard = Some(binding);
73 }
74
75 pub fn get_sink(&self) -> Option<Arc<dyn ChannelSink>> {
78 let guard = self.binding.lock().expect("channel core mutex poisoned");
79 match guard.as_ref() {
80 Some(ChannelBinding::Sink(bound)) => Some(bound.sink.clone()),
81 _ => None,
82 }
83 }
84
85 pub fn take_receiver(&self) -> Option<BoundChannelReceiver> {
88 let mut guard = self.binding.lock().expect("channel core mutex poisoned");
89 match guard.take() {
90 Some(ChannelBinding::Receiver(bound)) => Some(bound),
91 other => {
92 *guard = other;
94 None
95 }
96 }
97 }
98}
99
/// Optional handle to the `ChannelCore` shared by a `Tx`/`Rx` pair created by [`channel`].
///
/// `#[facet(opaque)]` presumably keeps reflection from looking inside (the contents are not
/// reflectable data). On wasm32 the slot has no fields.
#[derive(Facet)]
#[facet(opaque)]
pub(crate) struct CoreSlot {
    #[cfg(not(target_arch = "wasm32"))]
    pub(crate) inner: Option<Arc<ChannelCore>>,
}

impl CoreSlot {
    /// A slot with no core.
    pub(crate) fn empty() -> Self {
        Self {
            #[cfg(not(target_arch = "wasm32"))]
            inner: None,
        }
    }
}
116
117pub fn channel<T>() -> (Tx<T>, Rx<T>) {
124 #[cfg(not(target_arch = "wasm32"))]
125 {
126 let core = Arc::new(ChannelCore::new());
127 (Tx::paired(core.clone()), Rx::paired(core))
128 }
129 #[cfg(target_arch = "wasm32")]
130 {
131 (Tx::unbound(), Rx::unbound())
132 }
133}
134
#[cfg(not(target_arch = "wasm32"))]
/// Outgoing transport for a single channel: where a `Tx` delivers items and close requests.
pub trait ChannelSink: Send + Sync + 'static {
    /// Sends one item payload.
    ///
    /// The returned future is bounded by `'payload`, so it may borrow the payload. `Tx::send`
    /// keeps the source value alive until this future completes.
    fn send_payload<'payload>(
        &self,
        payload: Payload<'payload>,
    ) -> Pin<Box<dyn Future<Output = Result<(), TxError>> + Send + 'payload>>;

    /// Sends an explicit close carrying `metadata`.
    ///
    /// The returned future is `'static`, so it cannot keep borrowing `metadata`.
    fn close_channel(
        &self,
        metadata: Metadata,
    ) -> Pin<Box<dyn Future<Output = Result<(), TxError>> + Send + 'static>>;

    /// Synchronous hook called from `Tx`'s `Drop` when the sender was not closed explicitly.
    ///
    /// The default does nothing.
    fn close_channel_on_drop(&self) {}
}
158
159#[cfg(not(target_arch = "wasm32"))]
167pub struct CreditSink<S: ChannelSink> {
168 inner: S,
169 credit: Arc<Semaphore>,
170}
171
172#[cfg(not(target_arch = "wasm32"))]
173impl<S: ChannelSink> CreditSink<S> {
174 pub fn new(inner: S, initial_credit: u32) -> Self {
178 Self {
179 inner,
180 credit: Arc::new(Semaphore::new(initial_credit as usize)),
181 }
182 }
183
184 pub fn credit(&self) -> &Arc<Semaphore> {
187 &self.credit
188 }
189}
190
191#[cfg(not(target_arch = "wasm32"))]
192impl<S: ChannelSink> ChannelSink for CreditSink<S> {
193 fn send_payload<'payload>(
194 &self,
195 payload: Payload<'payload>,
196 ) -> Pin<Box<dyn Future<Output = Result<(), TxError>> + Send + 'payload>> {
197 let credit = self.credit.clone();
198 let fut = self.inner.send_payload(payload);
199 Box::pin(async move {
200 let permit = credit
201 .acquire()
202 .await
203 .map_err(|_| TxError::Transport("channel credit semaphore closed".into()))?;
204 permit.forget();
205 fut.await
206 })
207 }
208
209 fn close_channel(
210 &self,
211 metadata: Metadata,
212 ) -> Pin<Box<dyn Future<Output = Result<(), TxError>> + Send + 'static>> {
213 self.inner.close_channel(metadata)
215 }
216
217 fn close_channel_on_drop(&self) {
218 self.inner.close_channel_on_drop();
219 }
220}
221
#[cfg(not(target_arch = "wasm32"))]
/// A message delivered to a bound `Rx` through its `mpsc` queue.
///
/// `Rx::recv` turns `Item` into a decoded value, `Close` into `Ok(None)` and `Reset` into
/// `RxError::Reset`.
pub enum IncomingChannelMessage {
    /// One item. `Rx::recv` expects its payload to be `Payload::Incoming` bytes.
    Item(SelfRef<ChannelItem<'static>>),
    /// The peer closed the channel.
    Close(SelfRef<ChannelClose<'static>>),
    /// The peer reset the channel.
    Reset(SelfRef<ChannelReset<'static>>),
}
229
/// Optional sink bound directly on a `Tx`.
///
/// `#[facet(opaque)]` presumably keeps reflection from looking at the trait object inside. On
/// wasm32 the slot has no fields.
#[derive(Facet)]
#[facet(opaque)]
pub(crate) struct SinkSlot {
    #[cfg(not(target_arch = "wasm32"))]
    pub(crate) inner: Option<Arc<dyn ChannelSink>>,
}

impl SinkSlot {
    /// A slot with no sink.
    pub(crate) fn empty() -> Self {
        Self {
            #[cfg(not(target_arch = "wasm32"))]
            inner: None,
        }
    }
}
246
/// Optional liveness token held by a bound `Tx` or `Rx`.
///
/// `#[facet(opaque)]` presumably keeps reflection from looking at the trait object inside. On
/// wasm32 the slot has no fields.
#[derive(Facet)]
#[facet(opaque)]
pub(crate) struct LivenessSlot {
    #[cfg(not(target_arch = "wasm32"))]
    pub(crate) inner: Option<ChannelLivenessHandle>,
}

impl LivenessSlot {
    /// A slot with no liveness token.
    pub(crate) fn empty() -> Self {
        Self {
            #[cfg(not(target_arch = "wasm32"))]
            inner: None,
        }
    }
}
263
/// Optional incoming-message queue held by an `Rx`.
///
/// `#[facet(opaque)]` presumably keeps reflection from looking at the `mpsc` receiver inside.
/// On wasm32 the slot has no fields.
#[derive(Facet)]
#[facet(opaque)]
pub(crate) struct ReceiverSlot {
    #[cfg(not(target_arch = "wasm32"))]
    pub(crate) inner: Option<mpsc::Receiver<IncomingChannelMessage>>,
}

impl ReceiverSlot {
    /// A slot with no queue.
    pub(crate) fn empty() -> Self {
        Self {
            #[cfg(not(target_arch = "wasm32"))]
            inner: None,
        }
    }
}
280
/// Sending half of a typed channel carrying `T` values.
///
/// A `Tx` reaches its transport in one of two ways:
/// - through a sink bound directly on it (`bind` / `bind_with_liveness`);
/// - when created by `channel`, through the `ChannelCore` it shares with its `Rx`.
///
/// Dropping a `Tx` whose `closed` flag is still unset calls the sink's
/// `close_channel_on_drop`.
///
/// `N` (default 16) is not read anywhere in this file. NOTE(review): presumably a buffer/credit
/// size hint consumed elsewhere — confirm.
///
/// `#[facet(proxy = ())]` presumably makes reflection treat a `Tx` as `()`. The `TryFrom`
/// conversions between `Tx` and `()` never fail, and converting from `()` yields an unbound
/// sender.
#[derive(Facet)]
#[facet(proxy = ())]
pub struct Tx<T, const N: usize = 16> {
    /// Sink bound directly on this sender; it takes precedence over a sink found via `core`.
    pub(crate) sink: SinkSlot,
    /// Core shared with the paired `Rx` when created by `channel`; empty otherwise.
    pub(crate) core: CoreSlot,
    /// Liveness token from `bind_with_liveness`, stored until rebound or dropped.
    pub(crate) liveness: LivenessSlot,
    /// Set by `close` and by `Drop`; `Drop` skips the close hook when it is already set.
    #[cfg(not(target_arch = "wasm32"))]
    #[facet(opaque)]
    closed: AtomicBool,
    #[facet(opaque)]
    _marker: PhantomData<T>,
}
302
303impl<T, const N: usize> Tx<T, N> {
304 pub fn unbound() -> Self {
306 Self {
307 sink: SinkSlot::empty(),
308 core: CoreSlot::empty(),
309 liveness: LivenessSlot::empty(),
310 #[cfg(not(target_arch = "wasm32"))]
311 closed: AtomicBool::new(false),
312 _marker: PhantomData,
313 }
314 }
315
316 #[cfg(not(target_arch = "wasm32"))]
318 fn paired(core: Arc<ChannelCore>) -> Self {
319 Self {
320 sink: SinkSlot::empty(),
321 core: CoreSlot { inner: Some(core) },
322 liveness: LivenessSlot::empty(),
323 closed: AtomicBool::new(false),
324 _marker: PhantomData,
325 }
326 }
327
328 pub fn is_bound(&self) -> bool {
329 #[cfg(not(target_arch = "wasm32"))]
330 {
331 if self.sink.inner.is_some() {
332 return true;
333 }
334 if let Some(core) = &self.core.inner {
335 return core.get_sink().is_some();
336 }
337 false
338 }
339 #[cfg(target_arch = "wasm32")]
340 false
341 }
342
343 pub fn has_core(&self) -> bool {
345 #[cfg(not(target_arch = "wasm32"))]
346 return self.core.inner.is_some();
347 #[cfg(target_arch = "wasm32")]
348 return false;
349 }
350
351 #[cfg(not(target_arch = "wasm32"))]
353 fn resolve_sink(&self) -> Result<Arc<dyn ChannelSink>, TxError> {
354 if let Some(sink) = &self.sink.inner {
356 return Ok(sink.clone());
357 }
358 if let Some(core) = &self.core.inner
360 && let Some(sink) = core.get_sink()
361 {
362 return Ok(sink);
363 }
364 Err(TxError::Unbound)
365 }
366
367 #[cfg(not(target_arch = "wasm32"))]
368 pub async fn send<'value>(&self, value: T) -> Result<(), TxError>
369 where
370 T: Facet<'value>,
371 {
372 let sink = self.resolve_sink()?;
373 let ptr = PtrConst::new((&value as *const T).cast::<u8>());
374 let payload = unsafe { Payload::outgoing_unchecked(ptr, T::SHAPE) };
377 let result = sink.send_payload(payload).await;
378 drop(value);
379 result
380 }
381
382 #[cfg(not(target_arch = "wasm32"))]
384 pub async fn close<'value>(&self, metadata: Metadata<'value>) -> Result<(), TxError> {
385 self.closed.store(true, Ordering::Release);
386 let sink = self.resolve_sink()?;
387 sink.close_channel(metadata).await
388 }
389
390 #[doc(hidden)]
391 #[cfg(not(target_arch = "wasm32"))]
392 pub fn bind(&mut self, sink: Arc<dyn ChannelSink>) {
393 self.bind_with_liveness(sink, None);
394 }
395
396 #[doc(hidden)]
397 #[cfg(not(target_arch = "wasm32"))]
398 pub fn bind_with_liveness(
399 &mut self,
400 sink: Arc<dyn ChannelSink>,
401 liveness: Option<ChannelLivenessHandle>,
402 ) {
403 self.sink.inner = Some(sink);
404 self.liveness.inner = liveness;
405 }
406}
407
408#[cfg(not(target_arch = "wasm32"))]
409impl<T, const N: usize> Drop for Tx<T, N> {
410 fn drop(&mut self) {
411 if self.closed.swap(true, Ordering::AcqRel) {
412 return;
413 }
414
415 let sink = if let Some(sink) = &self.sink.inner {
416 Some(sink.clone())
417 } else if let Some(core) = &self.core.inner {
418 core.get_sink()
419 } else {
420 None
421 };
422
423 let Some(sink) = sink else {
424 return;
425 };
426
427 sink.close_channel_on_drop();
429 }
430}
431
432#[allow(clippy::infallible_try_from)]
433impl<T, const N: usize> TryFrom<&Tx<T, N>> for () {
434 type Error = Infallible;
435
436 fn try_from(_value: &Tx<T, N>) -> Result<Self, Self::Error> {
437 Ok(())
438 }
439}
440
441#[allow(clippy::infallible_try_from)]
442impl<T, const N: usize> TryFrom<()> for Tx<T, N> {
443 type Error = Infallible;
444
445 fn try_from(_value: ()) -> Result<Self, Self::Error> {
446 Ok(Self::unbound())
447 }
448}
449
/// Errors returned by `Tx` operations.
#[derive(Debug)]
pub enum TxError {
    /// No sink is reachable: the sender was never bound and its shared core (if any) holds no
    /// sink.
    Unbound,
    /// The sink or its flow control reported a failure, described by the message.
    Transport(String),
}

impl std::fmt::Display for TxError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TxError::Unbound => f.write_str("channel is not bound"),
            TxError::Transport(detail) => write!(f, "transport error: {detail}"),
        }
    }
}

impl std::error::Error for TxError {}
467
/// Receiving half of a typed channel carrying `T` values.
///
/// An `Rx` reads `IncomingChannelMessage`s from an `mpsc` queue. The queue is bound directly on
/// it (`bind` / `bind_with_liveness`) or, when the `Rx` was created by `channel`, adopted from
/// the shared `ChannelCore` on `recv`.
///
/// `N` (default 16) is not read anywhere in this file. NOTE(review): confirm its purpose.
///
/// `#[facet(proxy = ())]` presumably makes reflection treat an `Rx` as `()`. Converting from `()`
/// yields an unbound receiver.
#[derive(Facet)]
#[facet(proxy = ())]
pub struct Rx<T, const N: usize = 16> {
    /// Incoming-message queue, once bound directly or adopted from `core`.
    pub(crate) receiver: ReceiverSlot,
    /// Core shared with the paired `Tx` when created by `channel`; empty otherwise.
    pub(crate) core: CoreSlot,
    /// Liveness token that came with the bound or adopted queue.
    pub(crate) liveness: LivenessSlot,
    #[facet(opaque)]
    _marker: PhantomData<T>,
}
483
484impl<T, const N: usize> Rx<T, N> {
485 pub fn unbound() -> Self {
487 Self {
488 receiver: ReceiverSlot::empty(),
489 core: CoreSlot::empty(),
490 liveness: LivenessSlot::empty(),
491 _marker: PhantomData,
492 }
493 }
494
495 #[cfg(not(target_arch = "wasm32"))]
497 fn paired(core: Arc<ChannelCore>) -> Self {
498 Self {
499 receiver: ReceiverSlot::empty(),
500 core: CoreSlot { inner: Some(core) },
501 liveness: LivenessSlot::empty(),
502 _marker: PhantomData,
503 }
504 }
505
506 pub fn is_bound(&self) -> bool {
507 #[cfg(not(target_arch = "wasm32"))]
508 {
509 if self.receiver.inner.is_some() {
510 return true;
511 }
512 false
513 }
514 #[cfg(target_arch = "wasm32")]
515 false
516 }
517
518 pub fn has_core(&self) -> bool {
520 #[cfg(not(target_arch = "wasm32"))]
521 return self.core.inner.is_some();
522 #[cfg(target_arch = "wasm32")]
523 return false;
524 }
525
526 #[cfg(not(target_arch = "wasm32"))]
528 pub async fn recv(&mut self) -> Result<Option<SelfRef<T>>, RxError>
529 where
530 T: Facet<'static>,
531 {
532 if self.receiver.inner.is_none()
534 && let Some(core) = &self.core.inner
535 && let Some(bound) = core.take_receiver()
536 {
537 self.receiver.inner = Some(bound.receiver);
538 self.liveness.inner = bound.liveness;
539 }
540
541 let receiver = self.receiver.inner.as_mut().ok_or(RxError::Unbound)?;
542 match receiver.recv().await {
543 Some(IncomingChannelMessage::Close(_)) | None => Ok(None),
544 Some(IncomingChannelMessage::Reset(_)) => Err(RxError::Reset),
545 Some(IncomingChannelMessage::Item(msg)) => msg
546 .try_repack(|item, _backing_bytes| {
547 let Payload::Incoming(bytes) = item.item else {
548 return Err(RxError::Protocol(
549 "incoming channel item payload was not Incoming".into(),
550 ));
551 };
552 facet_postcard::from_slice_borrowed(bytes).map_err(RxError::Deserialize)
553 })
554 .map(Some),
555 }
556 }
557
558 #[doc(hidden)]
559 #[cfg(not(target_arch = "wasm32"))]
560 pub fn bind(&mut self, receiver: mpsc::Receiver<IncomingChannelMessage>) {
561 self.bind_with_liveness(receiver, None);
562 }
563
564 #[doc(hidden)]
565 #[cfg(not(target_arch = "wasm32"))]
566 pub fn bind_with_liveness(
567 &mut self,
568 receiver: mpsc::Receiver<IncomingChannelMessage>,
569 liveness: Option<ChannelLivenessHandle>,
570 ) {
571 self.receiver.inner = Some(receiver);
572 self.liveness.inner = liveness;
573 }
574}
575
576#[allow(clippy::infallible_try_from)]
577impl<T, const N: usize> TryFrom<&Rx<T, N>> for () {
578 type Error = Infallible;
579
580 fn try_from(_value: &Rx<T, N>) -> Result<Self, Self::Error> {
581 Ok(())
582 }
583}
584
585#[allow(clippy::infallible_try_from)]
586impl<T, const N: usize> TryFrom<()> for Rx<T, N> {
587 type Error = Infallible;
588
589 fn try_from(_value: ()) -> Result<Self, Self::Error> {
590 Ok(Self::unbound())
591 }
592}
593
594#[derive(Debug)]
596pub enum RxError {
597 Unbound,
598 Reset,
599 Deserialize(facet_postcard::DeserializeError),
600 Protocol(String),
601}
602
603impl std::fmt::Display for RxError {
604 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
605 match self {
606 Self::Unbound => write!(f, "channel is not bound"),
607 Self::Reset => write!(f, "channel reset by peer"),
608 Self::Deserialize(e) => write!(f, "deserialize error: {e}"),
609 Self::Protocol(msg) => write!(f, "protocol error: {msg}"),
610 }
611 }
612}
613
614impl std::error::Error for RxError {}
615
616pub fn is_tx(shape: &facet_core::Shape) -> bool {
618 shape.decl_id == Tx::<()>::SHAPE.decl_id
619}
620
621pub fn is_rx(shape: &facet_core::Shape) -> bool {
623 shape.decl_id == Rx::<()>::SHAPE.decl_id
624}
625
626pub fn is_channel(shape: &facet_core::Shape) -> bool {
628 is_tx(shape) || is_rx(shape)
629}
630
#[cfg(test)]
mod tests {
    // Unit tests for `Tx` close/drop behaviour and `Rx::recv` message handling.
    use super::*;
    use crate::{Backing, ChannelClose, ChannelItem, ChannelReset, Metadata, SelfRef};
    use std::sync::atomic::{AtomicUsize, Ordering};

    // Test sink that counts how often each `ChannelSink` method is called. Counters are bumped
    // when the method is invoked, not when its future is polled. Every future succeeds.
    struct CountingSink {
        send_calls: AtomicUsize,
        close_calls: AtomicUsize,
        close_on_drop_calls: AtomicUsize,
    }

    impl CountingSink {
        fn new() -> Self {
            Self {
                send_calls: AtomicUsize::new(0),
                close_calls: AtomicUsize::new(0),
                close_on_drop_calls: AtomicUsize::new(0),
            }
        }
    }

    impl ChannelSink for CountingSink {
        fn send_payload<'payload>(
            &self,
            _payload: Payload<'payload>,
        ) -> Pin<Box<dyn Future<Output = Result<(), TxError>> + Send + 'payload>> {
            self.send_calls.fetch_add(1, Ordering::AcqRel);
            Box::pin(async { Ok(()) })
        }

        fn close_channel(
            &self,
            _metadata: Metadata,
        ) -> Pin<Box<dyn Future<Output = Result<(), TxError>> + Send + 'static>> {
            self.close_calls.fetch_add(1, Ordering::AcqRel);
            Box::pin(async { Ok(()) })
        }

        fn close_channel_on_drop(&self) {
            self.close_on_drop_calls.fetch_add(1, Ordering::AcqRel);
        }
    }

    // An explicit `close` must suppress the drop-time close hook.
    #[tokio::test]
    async fn tx_close_does_not_emit_drop_close_after_explicit_close() {
        let sink_impl = Arc::new(CountingSink::new());
        let sink: Arc<dyn ChannelSink> = sink_impl.clone();

        let mut tx = Tx::<u32>::unbound();
        tx.bind(sink);
        tx.close(Metadata::default())
            .await
            .expect("close should succeed");
        drop(tx);

        assert_eq!(sink_impl.close_calls.load(Ordering::Acquire), 1);
        assert_eq!(sink_impl.close_on_drop_calls.load(Ordering::Acquire), 0);
    }

    // Dropping a directly bound, never-closed sender fires the hook exactly once.
    #[test]
    fn tx_drop_emits_close_on_drop_for_bound_sink() {
        let sink_impl = Arc::new(CountingSink::new());
        let sink: Arc<dyn ChannelSink> = sink_impl.clone();

        let mut tx = Tx::<u32>::unbound();
        tx.bind(sink);
        drop(tx);

        assert_eq!(sink_impl.close_on_drop_calls.load(Ordering::Acquire), 1);
    }

    // The same holds when the sink is only reachable through the paired core.
    #[test]
    fn tx_drop_emits_close_on_drop_for_paired_core_binding() {
        let sink_impl = Arc::new(CountingSink::new());
        let sink: Arc<dyn ChannelSink> = sink_impl.clone();

        let (tx, _rx) = channel::<u32>();
        let core = tx.core.inner.as_ref().expect("paired tx should have core");
        core.set_binding(ChannelBinding::Sink(BoundChannelSink {
            sink,
            liveness: None,
        }));
        drop(tx);

        assert_eq!(sink_impl.close_on_drop_calls.load(Ordering::Acquire), 1);
    }

    #[tokio::test]
    async fn rx_recv_returns_unbound_when_not_bound() {
        let mut rx = Rx::<u32>::unbound();
        let err = match rx.recv().await {
            Ok(_) => panic!("unbound rx should fail"),
            Err(err) => err,
        };
        assert!(matches!(err, RxError::Unbound));
    }

    // A close message ends the stream with `Ok(None)`.
    #[tokio::test]
    async fn rx_recv_returns_none_on_close() {
        let (tx, rx_inner) = mpsc::channel(1);
        let mut rx = Rx::<u32>::unbound();
        rx.bind(rx_inner);

        let close = SelfRef::owning(
            Backing::Boxed(Box::<[u8]>::default()),
            ChannelClose {
                metadata: Metadata::default(),
            },
        );
        tx.send(IncomingChannelMessage::Close(close))
            .await
            .expect("send close");

        assert!(rx.recv().await.expect("recv should succeed").is_none());
    }

    // A reset message surfaces as `RxError::Reset`.
    #[tokio::test]
    async fn rx_recv_returns_reset_error() {
        let (tx, rx_inner) = mpsc::channel(1);
        let mut rx = Rx::<u32>::unbound();
        rx.bind(rx_inner);

        let reset = SelfRef::owning(
            Backing::Boxed(Box::<[u8]>::default()),
            ChannelReset {
                metadata: Metadata::default(),
            },
        );
        tx.send(IncomingChannelMessage::Reset(reset))
            .await
            .expect("send reset");

        let err = match rx.recv().await {
            Ok(_) => panic!("reset should be surfaced as error"),
            Err(err) => err,
        };
        assert!(matches!(err, RxError::Reset));
    }

    // An item whose payload is the `Outgoing` variant is rejected as a protocol error.
    #[tokio::test]
    async fn rx_recv_rejects_outgoing_payload_variant_as_protocol_error() {
        // `static` so the outgoing payload can satisfy `ChannelItem<'static>`.
        static VALUE: u32 = 42;

        let (tx, rx_inner) = mpsc::channel(1);
        let mut rx = Rx::<u32>::unbound();
        rx.bind(rx_inner);

        let item = SelfRef::owning(
            Backing::Boxed(Box::<[u8]>::default()),
            ChannelItem {
                item: Payload::outgoing(&VALUE),
            },
        );
        tx.send(IncomingChannelMessage::Item(item))
            .await
            .expect("send item");

        let err = match rx.recv().await {
            Ok(_) => panic!("outgoing payload should be protocol error"),
            Err(err) => err,
        };
        assert!(matches!(err, RxError::Protocol(_)));
    }
}