Skip to main content

roam_types/
channel.rs

1use std::convert::Infallible;
2use std::future::Future;
3use std::marker::PhantomData;
4use std::pin::Pin;
5use std::sync::Arc;
6#[cfg(not(target_arch = "wasm32"))]
7use std::sync::Mutex;
8#[cfg(not(target_arch = "wasm32"))]
9use std::sync::atomic::{AtomicBool, Ordering};
10
11use facet::Facet;
12use facet_core::PtrConst;
13#[cfg(not(target_arch = "wasm32"))]
14use tokio::sync::{Semaphore, mpsc};
15
16#[cfg(not(target_arch = "wasm32"))]
17use crate::{ChannelClose, ChannelItem, ChannelReset, Metadata, Payload, SelfRef};
18
19// r[impl rpc.channel.pair]
20/// The binding stored in a channel core — either a sink or a receiver, never both.
21#[cfg(not(target_arch = "wasm32"))]
22pub enum ChannelBinding {
23    Sink(BoundChannelSink),
24    Receiver(BoundChannelReceiver),
25}
26
27#[cfg(not(target_arch = "wasm32"))]
28pub trait ChannelLiveness: Send + Sync + 'static {}
29
30#[cfg(not(target_arch = "wasm32"))]
31impl<T: Send + Sync + 'static> ChannelLiveness for T {}
32
33#[cfg(not(target_arch = "wasm32"))]
34pub type ChannelLivenessHandle = Arc<dyn ChannelLiveness>;
35
36#[cfg(not(target_arch = "wasm32"))]
37pub trait ChannelCreditReplenisher: Send + Sync + 'static {
38    fn on_item_consumed(&self);
39}
40
41#[cfg(not(target_arch = "wasm32"))]
42pub type ChannelCreditReplenisherHandle = Arc<dyn ChannelCreditReplenisher>;
43
44#[cfg(not(target_arch = "wasm32"))]
45#[derive(Clone)]
46pub struct BoundChannelSink {
47    pub sink: Arc<dyn ChannelSink>,
48    pub liveness: Option<ChannelLivenessHandle>,
49}
50
51#[cfg(not(target_arch = "wasm32"))]
52pub struct BoundChannelReceiver {
53    pub receiver: mpsc::Receiver<IncomingChannelMessage>,
54    pub liveness: Option<ChannelLivenessHandle>,
55    pub replenisher: Option<ChannelCreditReplenisherHandle>,
56}
57
58// r[impl rpc.channel.pair]
59/// Shared state between a `Tx`/`Rx` pair created by `channel()`.
60///
61/// Contains a `Mutex<Option<ChannelBinding>>` that is written once during
62/// binding and read/taken by the paired handle. The mutex is only locked
63/// during binding (once) and on first use by the paired handle (once).
64#[cfg(not(target_arch = "wasm32"))]
65pub struct ChannelCore {
66    binding: Mutex<Option<ChannelBinding>>,
67}
68
69#[cfg(not(target_arch = "wasm32"))]
70impl ChannelCore {
71    fn new() -> Self {
72        Self {
73            binding: Mutex::new(None),
74        }
75    }
76
77    /// Store a binding in the core. Panics if already set.
78    pub fn set_binding(&self, binding: ChannelBinding) {
79        let mut guard = self.binding.lock().expect("channel core mutex poisoned");
80        assert!(guard.is_none(), "channel binding already set");
81        *guard = Some(binding);
82    }
83
84    /// Clone the sink from the core (for Tx reading the sink).
85    /// Returns None if no sink has been set or if the binding is a Receiver.
86    pub fn get_sink(&self) -> Option<Arc<dyn ChannelSink>> {
87        let guard = self.binding.lock().expect("channel core mutex poisoned");
88        match guard.as_ref() {
89            Some(ChannelBinding::Sink(bound)) => Some(bound.sink.clone()),
90            _ => None,
91        }
92    }
93
94    /// Take the receiver out of the core (for Rx on first recv).
95    /// Returns None if no receiver has been set or if it was already taken.
96    pub fn take_receiver(&self) -> Option<BoundChannelReceiver> {
97        let mut guard = self.binding.lock().expect("channel core mutex poisoned");
98        match guard.take() {
99            Some(ChannelBinding::Receiver(bound)) => Some(bound),
100            other => {
101                // Put it back if it wasn't a receiver
102                *guard = other;
103                None
104            }
105        }
106    }
107}
108
109/// Slot for the shared channel core, accessible via facet reflection.
110#[derive(Facet)]
111#[facet(opaque)]
112pub(crate) struct CoreSlot {
113    #[cfg(not(target_arch = "wasm32"))]
114    pub(crate) inner: Option<Arc<ChannelCore>>,
115}
116
117impl CoreSlot {
118    pub(crate) fn empty() -> Self {
119        Self {
120            #[cfg(not(target_arch = "wasm32"))]
121            inner: None,
122        }
123    }
124}
125
126// r[impl rpc.channel.pair]
127/// Create a channel pair with shared state.
128///
129/// Both ends hold an `Arc` reference to the same `ChannelCore`. The framework
130/// binds the handle that appears in args or return values, and the paired
131/// handle reads or takes the binding from the shared core.
132pub fn channel<T>() -> (Tx<T>, Rx<T>) {
133    #[cfg(not(target_arch = "wasm32"))]
134    {
135        let core = Arc::new(ChannelCore::new());
136        (Tx::paired(core.clone()), Rx::paired(core))
137    }
138    #[cfg(target_arch = "wasm32")]
139    {
140        (Tx::unbound(), Rx::unbound())
141    }
142}
143
144/// Runtime sink implemented by the session driver.
145///
146/// The contract is strict: successful completion means the item has gone
147/// through the conduit to the link commit boundary.
148#[cfg(not(target_arch = "wasm32"))]
149pub trait ChannelSink: Send + Sync + 'static {
150    fn send_payload<'payload>(
151        &self,
152        payload: Payload<'payload>,
153    ) -> Pin<Box<dyn Future<Output = Result<(), TxError>> + Send + 'payload>>;
154
155    fn close_channel(
156        &self,
157        metadata: Metadata,
158    ) -> Pin<Box<dyn Future<Output = Result<(), TxError>> + Send + 'static>>;
159
160    /// Synchronous drop-time close signal.
161    ///
162    /// This is used by `Tx::drop` to notify the runtime immediately without
163    /// spawning detached tasks. Implementations should enqueue a close intent
164    /// to their runtime/driver if possible.
165    fn close_channel_on_drop(&self) {}
166}
167
168// r[impl rpc.flow-control.credit]
169// r[impl rpc.flow-control.credit.exhaustion]
170/// A [`ChannelSink`] wrapper that enforces credit-based flow control.
171///
172/// Each `send_payload` acquires one permit from the semaphore, blocking if
173/// credit is zero. The semaphore is shared with the driver so that incoming
174/// `GrantCredit` messages can add permits via [`CreditSink::credit`].
175#[cfg(not(target_arch = "wasm32"))]
176pub struct CreditSink<S: ChannelSink> {
177    inner: S,
178    credit: Arc<Semaphore>,
179}
180
181#[cfg(not(target_arch = "wasm32"))]
182impl<S: ChannelSink> CreditSink<S> {
183    // r[impl rpc.flow-control.credit.initial]
184    // r[impl rpc.flow-control.credit.initial.zero]
185    /// Wrap `inner` with `initial_credit` permits (the const generic `N`).
186    pub fn new(inner: S, initial_credit: u32) -> Self {
187        Self {
188            inner,
189            credit: Arc::new(Semaphore::new(initial_credit as usize)),
190        }
191    }
192
193    /// Returns the credit semaphore. The driver holds a clone so
194    /// `GrantCredit` messages can call `add_permits`.
195    pub fn credit(&self) -> &Arc<Semaphore> {
196        &self.credit
197    }
198}
199
200#[cfg(not(target_arch = "wasm32"))]
201impl<S: ChannelSink> ChannelSink for CreditSink<S> {
202    fn send_payload<'payload>(
203        &self,
204        payload: Payload<'payload>,
205    ) -> Pin<Box<dyn Future<Output = Result<(), TxError>> + Send + 'payload>> {
206        let credit = self.credit.clone();
207        let fut = self.inner.send_payload(payload);
208        Box::pin(async move {
209            let permit = credit
210                .acquire()
211                .await
212                .map_err(|_| TxError::Transport("channel credit semaphore closed".into()))?;
213            permit.forget();
214            fut.await
215        })
216    }
217
218    fn close_channel(
219        &self,
220        metadata: Metadata,
221    ) -> Pin<Box<dyn Future<Output = Result<(), TxError>> + Send + 'static>> {
222        // Close does not consume credit — it's a control message.
223        self.inner.close_channel(metadata)
224    }
225
226    fn close_channel_on_drop(&self) {
227        self.inner.close_channel_on_drop();
228    }
229}
230
231/// Message delivered to an `Rx` by the driver.
232#[cfg(not(target_arch = "wasm32"))]
233pub enum IncomingChannelMessage {
234    Item(SelfRef<ChannelItem<'static>>),
235    Close(SelfRef<ChannelClose<'static>>),
236    Reset(SelfRef<ChannelReset<'static>>),
237}
238
239/// Sender-side runtime slot.
240#[derive(Facet)]
241#[facet(opaque)]
242pub(crate) struct SinkSlot {
243    #[cfg(not(target_arch = "wasm32"))]
244    pub(crate) inner: Option<Arc<dyn ChannelSink>>,
245}
246
247impl SinkSlot {
248    pub(crate) fn empty() -> Self {
249        Self {
250            #[cfg(not(target_arch = "wasm32"))]
251            inner: None,
252        }
253    }
254}
255
256/// Opaque liveness retention slot for bound channel handles.
257#[derive(Facet)]
258#[facet(opaque)]
259pub(crate) struct LivenessSlot {
260    #[cfg(not(target_arch = "wasm32"))]
261    pub(crate) inner: Option<ChannelLivenessHandle>,
262}
263
264impl LivenessSlot {
265    pub(crate) fn empty() -> Self {
266        Self {
267            #[cfg(not(target_arch = "wasm32"))]
268            inner: None,
269        }
270    }
271}
272
273/// Receiver-side runtime slot.
274#[derive(Facet)]
275#[facet(opaque)]
276pub(crate) struct ReceiverSlot {
277    #[cfg(not(target_arch = "wasm32"))]
278    pub(crate) inner: Option<mpsc::Receiver<IncomingChannelMessage>>,
279}
280
281impl ReceiverSlot {
282    pub(crate) fn empty() -> Self {
283        Self {
284            #[cfg(not(target_arch = "wasm32"))]
285            inner: None,
286        }
287    }
288}
289
290/// Receiver-side credit replenishment slot.
291#[derive(Facet)]
292#[facet(opaque)]
293pub(crate) struct ReplenisherSlot {
294    #[cfg(not(target_arch = "wasm32"))]
295    pub(crate) inner: Option<ChannelCreditReplenisherHandle>,
296}
297
298impl ReplenisherSlot {
299    pub(crate) fn empty() -> Self {
300        Self {
301            #[cfg(not(target_arch = "wasm32"))]
302            inner: None,
303        }
304    }
305}
306
307/// Sender handle: "I send". The holder of a `Tx<T>` sends items of type `T`.
308///
309/// In method args, the handler holds it (handler sends → caller).
310///
311/// Wire encoding is always unit (`()`), with channel IDs carried exclusively
312/// in `Message::Request.channels`.
313// r[impl rpc.channel]
314// r[impl rpc.channel.direction]
315// r[impl rpc.channel.payload-encoding]
316#[derive(Facet)]
317#[facet(proxy = ())]
318pub struct Tx<T, const N: usize = 16> {
319    pub(crate) sink: SinkSlot,
320    pub(crate) core: CoreSlot,
321    pub(crate) liveness: LivenessSlot,
322    #[cfg(not(target_arch = "wasm32"))]
323    #[facet(opaque)]
324    closed: AtomicBool,
325    #[facet(opaque)]
326    _marker: PhantomData<T>,
327}
328
329impl<T, const N: usize> Tx<T, N> {
330    /// Create a standalone unbound Tx (used by deserialization).
331    pub fn unbound() -> Self {
332        Self {
333            sink: SinkSlot::empty(),
334            core: CoreSlot::empty(),
335            liveness: LivenessSlot::empty(),
336            #[cfg(not(target_arch = "wasm32"))]
337            closed: AtomicBool::new(false),
338            _marker: PhantomData,
339        }
340    }
341
342    /// Create a Tx that is part of a `channel()` pair.
343    #[cfg(not(target_arch = "wasm32"))]
344    fn paired(core: Arc<ChannelCore>) -> Self {
345        Self {
346            sink: SinkSlot::empty(),
347            core: CoreSlot { inner: Some(core) },
348            liveness: LivenessSlot::empty(),
349            closed: AtomicBool::new(false),
350            _marker: PhantomData,
351        }
352    }
353
354    pub fn is_bound(&self) -> bool {
355        #[cfg(not(target_arch = "wasm32"))]
356        {
357            if self.sink.inner.is_some() {
358                return true;
359            }
360            if let Some(core) = &self.core.inner {
361                return core.get_sink().is_some();
362            }
363            false
364        }
365        #[cfg(target_arch = "wasm32")]
366        false
367    }
368
369    /// Check if this Tx is part of a channel() pair (has a shared core).
370    pub fn has_core(&self) -> bool {
371        #[cfg(not(target_arch = "wasm32"))]
372        return self.core.inner.is_some();
373        #[cfg(target_arch = "wasm32")]
374        return false;
375    }
376
377    // r[impl rpc.channel.pair.tx-read]
378    #[cfg(not(target_arch = "wasm32"))]
379    fn resolve_sink(&self) -> Result<Arc<dyn ChannelSink>, TxError> {
380        // Fast path: local slot (standalone/callee-side handle)
381        if let Some(sink) = &self.sink.inner {
382            return Ok(sink.clone());
383        }
384        // Slow path: read from shared core (paired handle)
385        if let Some(core) = &self.core.inner
386            && let Some(sink) = core.get_sink()
387        {
388            return Ok(sink);
389        }
390        Err(TxError::Unbound)
391    }
392
393    #[cfg(not(target_arch = "wasm32"))]
394    pub async fn send<'value>(&self, value: T) -> Result<(), TxError>
395    where
396        T: Facet<'value>,
397    {
398        let sink = self.resolve_sink()?;
399        let ptr = PtrConst::new((&value as *const T).cast::<u8>());
400        // SAFETY: `value` is explicitly dropped only after `await`, so the pointer
401        // remains valid for the whole send operation.
402        let payload = unsafe { Payload::outgoing_unchecked(ptr, T::SHAPE) };
403        let result = sink.send_payload(payload).await;
404        drop(value);
405        result
406    }
407
408    // r[impl rpc.channel.lifecycle]
409    #[cfg(not(target_arch = "wasm32"))]
410    pub async fn close<'value>(&self, metadata: Metadata<'value>) -> Result<(), TxError> {
411        self.closed.store(true, Ordering::Release);
412        let sink = self.resolve_sink()?;
413        sink.close_channel(metadata).await
414    }
415
416    #[doc(hidden)]
417    #[cfg(not(target_arch = "wasm32"))]
418    pub fn bind(&mut self, sink: Arc<dyn ChannelSink>) {
419        self.bind_with_liveness(sink, None);
420    }
421
422    #[doc(hidden)]
423    #[cfg(not(target_arch = "wasm32"))]
424    pub fn bind_with_liveness(
425        &mut self,
426        sink: Arc<dyn ChannelSink>,
427        liveness: Option<ChannelLivenessHandle>,
428    ) {
429        self.sink.inner = Some(sink);
430        self.liveness.inner = liveness;
431    }
432}
433
434#[cfg(not(target_arch = "wasm32"))]
435impl<T, const N: usize> Drop for Tx<T, N> {
436    fn drop(&mut self) {
437        if self.closed.swap(true, Ordering::AcqRel) {
438            return;
439        }
440
441        let sink = if let Some(sink) = &self.sink.inner {
442            Some(sink.clone())
443        } else if let Some(core) = &self.core.inner {
444            core.get_sink()
445        } else {
446            None
447        };
448
449        let Some(sink) = sink else {
450            return;
451        };
452
453        // Synchronous signal into the runtime/driver; no detached async work here.
454        sink.close_channel_on_drop();
455    }
456}
457
458#[allow(clippy::infallible_try_from)]
459impl<T, const N: usize> TryFrom<&Tx<T, N>> for () {
460    type Error = Infallible;
461
462    fn try_from(_value: &Tx<T, N>) -> Result<Self, Self::Error> {
463        Ok(())
464    }
465}
466
467#[allow(clippy::infallible_try_from)]
468impl<T, const N: usize> TryFrom<()> for Tx<T, N> {
469    type Error = Infallible;
470
471    fn try_from(_value: ()) -> Result<Self, Self::Error> {
472        Ok(Self::unbound())
473    }
474}
475
476/// Error when sending on a `Tx`.
477#[derive(Debug)]
478pub enum TxError {
479    Unbound,
480    Transport(String),
481}
482
483impl std::fmt::Display for TxError {
484    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
485        match self {
486            Self::Unbound => write!(f, "channel is not bound"),
487            Self::Transport(msg) => write!(f, "transport error: {msg}"),
488        }
489    }
490}
491
492impl std::error::Error for TxError {}
493
494/// Receiver handle: "I receive". The holder of an `Rx<T>` receives items of type `T`.
495///
496/// In method args, the handler holds it (handler receives ← caller).
497///
498/// Wire encoding is always unit (`()`), with channel IDs carried exclusively
499/// in `Message::Request.channels`.
500#[derive(Facet)]
501#[facet(proxy = ())]
502pub struct Rx<T, const N: usize = 16> {
503    pub(crate) receiver: ReceiverSlot,
504    pub(crate) core: CoreSlot,
505    pub(crate) liveness: LivenessSlot,
506    pub(crate) replenisher: ReplenisherSlot,
507    #[facet(opaque)]
508    _marker: PhantomData<T>,
509}
510
511impl<T, const N: usize> Rx<T, N> {
512    /// Create a standalone unbound Rx (used by deserialization).
513    pub fn unbound() -> Self {
514        Self {
515            receiver: ReceiverSlot::empty(),
516            core: CoreSlot::empty(),
517            liveness: LivenessSlot::empty(),
518            replenisher: ReplenisherSlot::empty(),
519            _marker: PhantomData,
520        }
521    }
522
523    /// Create an Rx that is part of a `channel()` pair.
524    #[cfg(not(target_arch = "wasm32"))]
525    fn paired(core: Arc<ChannelCore>) -> Self {
526        Self {
527            receiver: ReceiverSlot::empty(),
528            core: CoreSlot { inner: Some(core) },
529            liveness: LivenessSlot::empty(),
530            replenisher: ReplenisherSlot::empty(),
531            _marker: PhantomData,
532        }
533    }
534
535    pub fn is_bound(&self) -> bool {
536        #[cfg(not(target_arch = "wasm32"))]
537        {
538            if self.receiver.inner.is_some() {
539                return true;
540            }
541            false
542        }
543        #[cfg(target_arch = "wasm32")]
544        false
545    }
546
547    /// Check if this Rx is part of a channel() pair (has a shared core).
548    pub fn has_core(&self) -> bool {
549        #[cfg(not(target_arch = "wasm32"))]
550        return self.core.inner.is_some();
551        #[cfg(target_arch = "wasm32")]
552        return false;
553    }
554
555    // r[impl rpc.channel.pair.rx-take]
556    #[cfg(not(target_arch = "wasm32"))]
557    pub async fn recv(&mut self) -> Result<Option<SelfRef<T>>, RxError>
558    where
559        T: Facet<'static>,
560    {
561        // On first call, take receiver from shared core into local slot
562        if self.receiver.inner.is_none()
563            && let Some(core) = &self.core.inner
564            && let Some(bound) = core.take_receiver()
565        {
566            self.receiver.inner = Some(bound.receiver);
567            self.liveness.inner = bound.liveness;
568            self.replenisher.inner = bound.replenisher;
569        }
570
571        let receiver = self.receiver.inner.as_mut().ok_or(RxError::Unbound)?;
572        match receiver.recv().await {
573            Some(IncomingChannelMessage::Close(_)) | None => Ok(None),
574            Some(IncomingChannelMessage::Reset(_)) => Err(RxError::Reset),
575            Some(IncomingChannelMessage::Item(msg)) => {
576                let value = msg
577                    .try_repack(|item, _backing_bytes| {
578                        let Payload::Incoming(bytes) = item.item else {
579                            return Err(RxError::Protocol(
580                                "incoming channel item payload was not Incoming".into(),
581                            ));
582                        };
583                        facet_postcard::from_slice_borrowed(bytes).map_err(RxError::Deserialize)
584                    })
585                    .map(Some);
586                if value.is_ok()
587                    && let Some(replenisher) = &self.replenisher.inner
588                {
589                    replenisher.on_item_consumed();
590                }
591                value
592            }
593        }
594    }
595
596    #[doc(hidden)]
597    #[cfg(not(target_arch = "wasm32"))]
598    pub fn bind(&mut self, receiver: mpsc::Receiver<IncomingChannelMessage>) {
599        self.bind_with_liveness(receiver, None);
600    }
601
602    #[doc(hidden)]
603    #[cfg(not(target_arch = "wasm32"))]
604    pub fn bind_with_liveness(
605        &mut self,
606        receiver: mpsc::Receiver<IncomingChannelMessage>,
607        liveness: Option<ChannelLivenessHandle>,
608    ) {
609        self.receiver.inner = Some(receiver);
610        self.liveness.inner = liveness;
611        self.replenisher.inner = None;
612    }
613}
614
615#[allow(clippy::infallible_try_from)]
616impl<T, const N: usize> TryFrom<&Rx<T, N>> for () {
617    type Error = Infallible;
618
619    fn try_from(_value: &Rx<T, N>) -> Result<Self, Self::Error> {
620        Ok(())
621    }
622}
623
624#[allow(clippy::infallible_try_from)]
625impl<T, const N: usize> TryFrom<()> for Rx<T, N> {
626    type Error = Infallible;
627
628    fn try_from(_value: ()) -> Result<Self, Self::Error> {
629        Ok(Self::unbound())
630    }
631}
632
633/// Error when receiving from an `Rx`.
634#[derive(Debug)]
635pub enum RxError {
636    Unbound,
637    Reset,
638    Deserialize(facet_postcard::DeserializeError),
639    Protocol(String),
640}
641
642impl std::fmt::Display for RxError {
643    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
644        match self {
645            Self::Unbound => write!(f, "channel is not bound"),
646            Self::Reset => write!(f, "channel reset by peer"),
647            Self::Deserialize(e) => write!(f, "deserialize error: {e}"),
648            Self::Protocol(msg) => write!(f, "protocol error: {msg}"),
649        }
650    }
651}
652
653impl std::error::Error for RxError {}
654
655/// Check if a shape represents a `Tx` channel.
656pub fn is_tx(shape: &facet_core::Shape) -> bool {
657    shape.decl_id == Tx::<()>::SHAPE.decl_id
658}
659
660/// Check if a shape represents an `Rx` channel.
661pub fn is_rx(shape: &facet_core::Shape) -> bool {
662    shape.decl_id == Rx::<()>::SHAPE.decl_id
663}
664
665/// Check if a shape represents any channel type (`Tx` or `Rx`).
666pub fn is_channel(shape: &facet_core::Shape) -> bool {
667    is_tx(shape) || is_rx(shape)
668}
669
670#[cfg(test)]
671mod tests {
672    use super::*;
673    use crate::{Backing, ChannelClose, ChannelItem, ChannelReset, Metadata, SelfRef};
674    use std::sync::atomic::{AtomicUsize, Ordering};
675
676    struct CountingSink {
677        send_calls: AtomicUsize,
678        close_calls: AtomicUsize,
679        close_on_drop_calls: AtomicUsize,
680    }
681
682    impl CountingSink {
683        fn new() -> Self {
684            Self {
685                send_calls: AtomicUsize::new(0),
686                close_calls: AtomicUsize::new(0),
687                close_on_drop_calls: AtomicUsize::new(0),
688            }
689        }
690    }
691
692    impl ChannelSink for CountingSink {
693        fn send_payload<'payload>(
694            &self,
695            _payload: Payload<'payload>,
696        ) -> Pin<Box<dyn Future<Output = Result<(), TxError>> + Send + 'payload>> {
697            self.send_calls.fetch_add(1, Ordering::AcqRel);
698            Box::pin(async { Ok(()) })
699        }
700
701        fn close_channel(
702            &self,
703            _metadata: Metadata,
704        ) -> Pin<Box<dyn Future<Output = Result<(), TxError>> + Send + 'static>> {
705            self.close_calls.fetch_add(1, Ordering::AcqRel);
706            Box::pin(async { Ok(()) })
707        }
708
709        fn close_channel_on_drop(&self) {
710            self.close_on_drop_calls.fetch_add(1, Ordering::AcqRel);
711        }
712    }
713
714    struct CountingReplenisher {
715        calls: AtomicUsize,
716    }
717
718    impl CountingReplenisher {
719        fn new() -> Self {
720            Self {
721                calls: AtomicUsize::new(0),
722            }
723        }
724    }
725
726    impl ChannelCreditReplenisher for CountingReplenisher {
727        fn on_item_consumed(&self) {
728            self.calls.fetch_add(1, Ordering::AcqRel);
729        }
730    }
731
732    #[tokio::test]
733    async fn tx_close_does_not_emit_drop_close_after_explicit_close() {
734        let sink_impl = Arc::new(CountingSink::new());
735        let sink: Arc<dyn ChannelSink> = sink_impl.clone();
736
737        let mut tx = Tx::<u32>::unbound();
738        tx.bind(sink);
739        tx.close(Metadata::default())
740            .await
741            .expect("close should succeed");
742        drop(tx);
743
744        assert_eq!(sink_impl.close_calls.load(Ordering::Acquire), 1);
745        assert_eq!(sink_impl.close_on_drop_calls.load(Ordering::Acquire), 0);
746    }
747
748    #[test]
749    fn tx_drop_emits_close_on_drop_for_bound_sink() {
750        let sink_impl = Arc::new(CountingSink::new());
751        let sink: Arc<dyn ChannelSink> = sink_impl.clone();
752
753        let mut tx = Tx::<u32>::unbound();
754        tx.bind(sink);
755        drop(tx);
756
757        assert_eq!(sink_impl.close_on_drop_calls.load(Ordering::Acquire), 1);
758    }
759
760    #[test]
761    fn tx_drop_emits_close_on_drop_for_paired_core_binding() {
762        let sink_impl = Arc::new(CountingSink::new());
763        let sink: Arc<dyn ChannelSink> = sink_impl.clone();
764
765        let (tx, _rx) = channel::<u32>();
766        let core = tx.core.inner.as_ref().expect("paired tx should have core");
767        core.set_binding(ChannelBinding::Sink(BoundChannelSink {
768            sink,
769            liveness: None,
770        }));
771        drop(tx);
772
773        assert_eq!(sink_impl.close_on_drop_calls.load(Ordering::Acquire), 1);
774    }
775
776    #[tokio::test]
777    async fn rx_recv_returns_unbound_when_not_bound() {
778        let mut rx = Rx::<u32>::unbound();
779        let err = match rx.recv().await {
780            Ok(_) => panic!("unbound rx should fail"),
781            Err(err) => err,
782        };
783        assert!(matches!(err, RxError::Unbound));
784    }
785
786    #[tokio::test]
787    async fn rx_recv_returns_none_on_close() {
788        let (tx, rx_inner) = mpsc::channel(1);
789        let mut rx = Rx::<u32>::unbound();
790        rx.bind(rx_inner);
791
792        let close = SelfRef::owning(
793            Backing::Boxed(Box::<[u8]>::default()),
794            ChannelClose {
795                metadata: Metadata::default(),
796            },
797        );
798        tx.send(IncomingChannelMessage::Close(close))
799            .await
800            .expect("send close");
801
802        assert!(rx.recv().await.expect("recv should succeed").is_none());
803    }
804
805    #[tokio::test]
806    async fn rx_recv_returns_reset_error() {
807        let (tx, rx_inner) = mpsc::channel(1);
808        let mut rx = Rx::<u32>::unbound();
809        rx.bind(rx_inner);
810
811        let reset = SelfRef::owning(
812            Backing::Boxed(Box::<[u8]>::default()),
813            ChannelReset {
814                metadata: Metadata::default(),
815            },
816        );
817        tx.send(IncomingChannelMessage::Reset(reset))
818            .await
819            .expect("send reset");
820
821        let err = match rx.recv().await {
822            Ok(_) => panic!("reset should be surfaced as error"),
823            Err(err) => err,
824        };
825        assert!(matches!(err, RxError::Reset));
826    }
827
828    #[tokio::test]
829    async fn rx_recv_rejects_outgoing_payload_variant_as_protocol_error() {
830        static VALUE: u32 = 42;
831
832        let (tx, rx_inner) = mpsc::channel(1);
833        let mut rx = Rx::<u32>::unbound();
834        rx.bind(rx_inner);
835
836        let item = SelfRef::owning(
837            Backing::Boxed(Box::<[u8]>::default()),
838            ChannelItem {
839                item: Payload::outgoing(&VALUE),
840            },
841        );
842        tx.send(IncomingChannelMessage::Item(item))
843            .await
844            .expect("send item");
845
846        let err = match rx.recv().await {
847            Ok(_) => panic!("outgoing payload should be protocol error"),
848            Err(err) => err,
849        };
850        assert!(matches!(err, RxError::Protocol(_)));
851    }
852
853    #[tokio::test]
854    async fn rx_recv_notifies_replenisher_after_consuming_an_item() {
855        let (tx, rx_inner) = mpsc::channel(1);
856        let replenisher = Arc::new(CountingReplenisher::new());
857        let mut rx = Rx::<u32>::unbound();
858        rx.bind(rx_inner);
859        rx.replenisher.inner = Some(replenisher.clone());
860
861        let encoded = facet_postcard::to_vec(&123_u32).expect("serialize test item");
862        let item = SelfRef::owning(
863            Backing::Boxed(Box::<[u8]>::default()),
864            ChannelItem {
865                item: Payload::Incoming(Box::leak(encoded.into_boxed_slice())),
866            },
867        );
868        tx.send(IncomingChannelMessage::Item(item))
869            .await
870            .expect("send item");
871
872        let value = rx
873            .recv()
874            .await
875            .expect("recv should succeed")
876            .expect("expected item");
877        assert_eq!(*value, 123_u32);
878        assert_eq!(replenisher.calls.load(Ordering::Acquire), 1);
879    }
880}