barter_integration/
channel.rs

1use crate::Unrecoverable;
2use derive_more::{Constructor, Display};
3use futures::{Sink, Stream};
4use serde::{Deserialize, Serialize};
5use std::{
6    fmt::Debug,
7    pin::Pin,
8    task::{Context, Poll},
9};
10use tracing::warn;
11
12pub trait Tx
13where
14    Self: Debug + Clone + Send,
15{
16    type Item;
17    type Error: Unrecoverable + Debug;
18    fn send<Item: Into<Self::Item>>(&self, item: Item) -> Result<(), Self::Error>;
19}
20
21/// Convenience type that holds the [`UnboundedTx`] and [`UnboundedRx`].
22#[derive(Debug)]
23pub struct Channel<T> {
24    pub tx: UnboundedTx<T>,
25    pub rx: UnboundedRx<T>,
26}
27
28impl<T> Channel<T> {
29    /// Construct a new unbounded [`Channel`].
30    pub fn new() -> Self {
31        let (tx, rx) = mpsc_unbounded();
32        Self { tx, rx }
33    }
34}
35
36impl<T> Default for Channel<T> {
37    fn default() -> Self {
38        Self::new()
39    }
40}
41
42#[derive(Debug, Clone)]
43pub struct UnboundedTx<T> {
44    pub tx: tokio::sync::mpsc::UnboundedSender<T>,
45}
46
47impl<T> UnboundedTx<T> {
48    pub fn new(tx: tokio::sync::mpsc::UnboundedSender<T>) -> Self {
49        Self { tx }
50    }
51}
52
53impl<T> Tx for UnboundedTx<T>
54where
55    T: Debug + Clone + Send,
56{
57    type Item = T;
58    type Error = tokio::sync::mpsc::error::SendError<T>;
59
60    fn send<Item: Into<Self::Item>>(&self, item: Item) -> Result<(), Self::Error> {
61        self.tx.send(item.into())
62    }
63}
64
65impl<T> Unrecoverable for tokio::sync::mpsc::error::SendError<T> {
66    fn is_unrecoverable(&self) -> bool {
67        true
68    }
69}
70
71impl<T> Sink<T> for UnboundedTx<T> {
72    type Error = tokio::sync::mpsc::error::SendError<T>;
73
74    fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
75        // UnboundedTx is always ready
76        Poll::Ready(Ok(()))
77    }
78
79    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
80        self.tx.send(item)
81    }
82
83    fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
84        // UnboundedTx does not buffer, so no flushing is required
85        Poll::Ready(Ok(()))
86    }
87
88    fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
89        // UnboundedTx requires no closing logic
90        Poll::Ready(Ok(()))
91    }
92}
93
94#[derive(Debug, Constructor)]
95pub struct UnboundedRx<T> {
96    pub rx: tokio::sync::mpsc::UnboundedReceiver<T>,
97}
98
99impl<T> Iterator for UnboundedRx<T> {
100    type Item = T;
101
102    fn next(&mut self) -> Option<Self::Item> {
103        loop {
104            match self.rx.try_recv() {
105                Ok(event) => break Some(event),
106                Err(tokio::sync::mpsc::error::TryRecvError::Empty) => continue,
107                Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => break None,
108            }
109        }
110    }
111}
112
113impl<T> UnboundedRx<T> {
114    pub fn into_stream(self) -> tokio_stream::wrappers::UnboundedReceiverStream<T> {
115        tokio_stream::wrappers::UnboundedReceiverStream::new(self.rx)
116    }
117}
118
119impl<T> Stream for UnboundedRx<T> {
120    type Item = T;
121
122    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
123        self.rx.poll_recv(cx)
124    }
125}
126
127#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Deserialize, Serialize)]
128pub struct ChannelTxDroppable<ChannelTx> {
129    pub state: ChannelState<ChannelTx>,
130}
131
132impl<ChannelTx> ChannelTxDroppable<ChannelTx> {
133    pub fn new(tx: ChannelTx) -> Self {
134        Self {
135            state: ChannelState::Active(tx),
136        }
137    }
138
139    pub fn new_disabled() -> Self {
140        Self {
141            state: ChannelState::Disabled,
142        }
143    }
144
145    pub fn disable(&mut self) {
146        self.state = ChannelState::Disabled
147    }
148}
149
150#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Deserialize, Serialize, Display)]
151pub enum ChannelState<Tx> {
152    Active(Tx),
153    Disabled,
154}
155
156impl<ChannelTx> ChannelTxDroppable<ChannelTx>
157where
158    ChannelTx: Tx,
159{
160    pub fn send(&mut self, item: ChannelTx::Item) {
161        let ChannelState::Active(tx) = &self.state else {
162            return;
163        };
164
165        if tx.send(item).is_err() {
166            let name = std::any::type_name::<ChannelTx::Item>();
167            warn!(
168                name,
169                "ChannelTxDroppable receiver dropped - items will no longer be sent"
170            );
171            self.state = ChannelState::Disabled
172        }
173    }
174}
175pub fn mpsc_unbounded<T>() -> (UnboundedTx<T>, UnboundedRx<T>) {
176    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
177    (UnboundedTx::new(tx), UnboundedRx::new(rx))
178}