barter_integration/
channel.rs1use crate::Unrecoverable;
2use derive_more::{Constructor, Display};
3use futures::{Sink, Stream};
4use serde::{Deserialize, Serialize};
5use std::{
6 fmt::Debug,
7 pin::Pin,
8 task::{Context, Poll},
9};
10use tracing::warn;
11
12pub trait Tx
13where
14 Self: Debug + Clone + Send,
15{
16 type Item;
17 type Error: Unrecoverable + Debug;
18 fn send<Item: Into<Self::Item>>(&self, item: Item) -> Result<(), Self::Error>;
19}
20
21#[derive(Debug)]
23pub struct Channel<T> {
24 pub tx: UnboundedTx<T>,
25 pub rx: UnboundedRx<T>,
26}
27
28impl<T> Channel<T> {
29 pub fn new() -> Self {
31 let (tx, rx) = mpsc_unbounded();
32 Self { tx, rx }
33 }
34}
35
36impl<T> Default for Channel<T> {
37 fn default() -> Self {
38 Self::new()
39 }
40}
41
42#[derive(Debug, Clone)]
43pub struct UnboundedTx<T> {
44 pub tx: tokio::sync::mpsc::UnboundedSender<T>,
45}
46
47impl<T> UnboundedTx<T> {
48 pub fn new(tx: tokio::sync::mpsc::UnboundedSender<T>) -> Self {
49 Self { tx }
50 }
51}
52
53impl<T> Tx for UnboundedTx<T>
54where
55 T: Debug + Clone + Send,
56{
57 type Item = T;
58 type Error = tokio::sync::mpsc::error::SendError<T>;
59
60 fn send<Item: Into<Self::Item>>(&self, item: Item) -> Result<(), Self::Error> {
61 self.tx.send(item.into())
62 }
63}
64
65impl<T> Unrecoverable for tokio::sync::mpsc::error::SendError<T> {
66 fn is_unrecoverable(&self) -> bool {
67 true
68 }
69}
70
71impl<T> Sink<T> for UnboundedTx<T> {
72 type Error = tokio::sync::mpsc::error::SendError<T>;
73
74 fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
75 Poll::Ready(Ok(()))
77 }
78
79 fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
80 self.tx.send(item)
81 }
82
83 fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
84 Poll::Ready(Ok(()))
86 }
87
88 fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
89 Poll::Ready(Ok(()))
91 }
92}
93
94#[derive(Debug, Constructor)]
95pub struct UnboundedRx<T> {
96 pub rx: tokio::sync::mpsc::UnboundedReceiver<T>,
97}
98
99impl<T> Iterator for UnboundedRx<T> {
100 type Item = T;
101
102 fn next(&mut self) -> Option<Self::Item> {
103 loop {
104 match self.rx.try_recv() {
105 Ok(event) => break Some(event),
106 Err(tokio::sync::mpsc::error::TryRecvError::Empty) => continue,
107 Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => break None,
108 }
109 }
110 }
111}
112
113impl<T> UnboundedRx<T> {
114 pub fn into_stream(self) -> tokio_stream::wrappers::UnboundedReceiverStream<T> {
115 tokio_stream::wrappers::UnboundedReceiverStream::new(self.rx)
116 }
117}
118
119impl<T> Stream for UnboundedRx<T> {
120 type Item = T;
121
122 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
123 self.rx.poll_recv(cx)
124 }
125}
126
127#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Deserialize, Serialize)]
128pub struct ChannelTxDroppable<ChannelTx> {
129 pub state: ChannelState<ChannelTx>,
130}
131
132impl<ChannelTx> ChannelTxDroppable<ChannelTx> {
133 pub fn new(tx: ChannelTx) -> Self {
134 Self {
135 state: ChannelState::Active(tx),
136 }
137 }
138
139 pub fn new_disabled() -> Self {
140 Self {
141 state: ChannelState::Disabled,
142 }
143 }
144
145 pub fn disable(&mut self) {
146 self.state = ChannelState::Disabled
147 }
148}
149
150#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Deserialize, Serialize, Display)]
151pub enum ChannelState<Tx> {
152 Active(Tx),
153 Disabled,
154}
155
156impl<ChannelTx> ChannelTxDroppable<ChannelTx>
157where
158 ChannelTx: Tx,
159{
160 pub fn send(&mut self, item: ChannelTx::Item) {
161 let ChannelState::Active(tx) = &self.state else {
162 return;
163 };
164
165 if tx.send(item).is_err() {
166 let name = std::any::type_name::<ChannelTx::Item>();
167 warn!(
168 name,
169 "ChannelTxDroppable receiver dropped - items will no longer be sent"
170 );
171 self.state = ChannelState::Disabled
172 }
173 }
174}
175pub fn mpsc_unbounded<T>() -> (UnboundedTx<T>, UnboundedRx<T>) {
176 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
177 (UnboundedTx::new(tx), UnboundedRx::new(rx))
178}