1use crate::{Message, WebSocketError, WsHeartbeatConfig};
4use futures_util::{
5 stream::{SplitSink, SplitStream},
6 SinkExt, Stream, StreamExt,
7};
8use hyper::upgrade::Upgraded;
9use hyper_util::rt::TokioIo;
10use std::pin::Pin;
11use std::task::{Context, Poll};
12use tokio::sync::mpsc;
13use tokio_tungstenite::WebSocketStream as TungsteniteStream;
14
15type UpgradedConnection = TungsteniteStream<TokioIo<Upgraded>>;
17
18#[allow(clippy::large_enum_variant)]
20enum StreamImpl {
21 Direct(UpgradedConnection),
23 Managed {
25 tx: mpsc::Sender<Message>,
26 rx: mpsc::Receiver<Result<Message, WebSocketError>>,
27 },
28}
29
30pub struct WebSocketStream {
32 inner: StreamImpl,
33}
34
35impl WebSocketStream {
36 pub(crate) fn new(inner: UpgradedConnection) -> Self {
38 Self {
39 inner: StreamImpl::Direct(inner),
40 }
41 }
42
43 pub(crate) fn new_managed(inner: UpgradedConnection, config: WsHeartbeatConfig) -> Self {
45 let (mut sender, mut receiver) = inner.split();
46 let (user_tx, mut internal_rx) = mpsc::channel::<Message>(32);
47 let (internal_tx, user_rx) = mpsc::channel::<Result<Message, WebSocketError>>(32);
48
49 tokio::spawn(async move {
51 let mut heartbeat_interval = tokio::time::interval(config.interval);
52 heartbeat_interval.tick().await;
54
55 let mut last_heartbeat = tokio::time::Instant::now();
67 let mut timeout_check = tokio::time::interval(config.timeout);
68
69 loop {
70 tokio::select! {
71 msg = receiver.next() => {
73 match msg {
74 Some(Ok(msg)) => {
75 last_heartbeat = tokio::time::Instant::now();
76 if msg.is_pong() {
77 continue;
79 }
80 if msg.is_ping() {
81 let _ = sender.send(Message::Pong(msg.into_data()).into()).await;
95 continue;
96 }
97
98 if internal_tx.send(Ok(Message::from(msg))).await.is_err() {
100 break; }
102 }
103 Some(Err(e)) => {
104 let _ = internal_tx.send(Err(WebSocketError::from(e))).await;
105 break;
106 }
107 None => break, }
109 }
110
111 msg = internal_rx.recv() => {
113 match msg {
114 Some(msg) => {
115 if sender.send(msg.into()).await.is_err() {
116 break; }
118 }
119 None => break, }
121 }
122
123 _ = heartbeat_interval.tick() => {
125 if sender.send(Message::Ping(vec![]).into()).await.is_err() {
126 break;
127 }
128 }
129
130 _ = timeout_check.tick() => {
132 if last_heartbeat.elapsed() > config.interval + config.timeout {
133 break;
135 }
137 }
138 }
139 }
140 });
142
143 Self {
144 inner: StreamImpl::Managed {
145 tx: user_tx,
146 rx: user_rx,
147 },
148 }
149 }
150
151 pub fn split(self) -> (WebSocketSender, WebSocketReceiver) {
153 match self.inner {
154 StreamImpl::Direct(inner) => {
155 let (sink, stream) = inner.split();
156 (
157 WebSocketSender {
158 inner: SenderImpl::Direct(sink),
159 },
160 WebSocketReceiver {
161 inner: ReceiverImpl::Direct(stream),
162 },
163 )
164 }
165 StreamImpl::Managed { tx, rx } => (
166 WebSocketSender {
167 inner: SenderImpl::Managed(tx),
168 },
169 WebSocketReceiver {
170 inner: ReceiverImpl::Managed(rx),
171 },
172 ),
173 }
174 }
175}
176
177impl WebSocketStream {
179 pub async fn send(&mut self, msg: Message) -> Result<(), WebSocketError> {
181 match &mut self.inner {
182 StreamImpl::Direct(s) => s.send(msg.into()).await.map_err(WebSocketError::from),
183 StreamImpl::Managed { tx, .. } => tx
184 .send(msg)
185 .await
186 .map_err(|_| WebSocketError::ConnectionClosed),
187 }
188 }
189
190 pub async fn recv(&mut self) -> Option<Result<Message, WebSocketError>> {
192 match &mut self.inner {
193 StreamImpl::Direct(s) => s
194 .next()
195 .await
196 .map(|r| r.map(Message::from).map_err(WebSocketError::from)),
197 StreamImpl::Managed { rx, .. } => rx.recv().await,
198 }
199 }
200
201 pub async fn send_text(&mut self, text: impl Into<String>) -> Result<(), WebSocketError> {
203 self.send(Message::text(text)).await
204 }
205
206 pub async fn send_binary(&mut self, data: impl Into<Vec<u8>>) -> Result<(), WebSocketError> {
208 self.send(Message::binary(data)).await
209 }
210
211 pub async fn send_json<T: serde::Serialize>(
213 &mut self,
214 value: &T,
215 ) -> Result<(), WebSocketError> {
216 self.send(Message::json(value)?).await
217 }
218}
219
220enum SenderImpl {
223 Direct(SplitSink<UpgradedConnection, tungstenite::Message>),
224 Managed(mpsc::Sender<Message>),
225}
226
227pub struct WebSocketSender {
229 inner: SenderImpl,
230}
231
232impl WebSocketSender {
233 pub async fn send(&mut self, msg: Message) -> Result<(), WebSocketError> {
235 match &mut self.inner {
236 SenderImpl::Direct(s) => s.send(msg.into()).await.map_err(WebSocketError::from),
237 SenderImpl::Managed(s) => s
238 .send(msg)
239 .await
240 .map_err(|_| WebSocketError::ConnectionClosed),
241 }
242 }
243
244 pub async fn send_text(&mut self, text: impl Into<String>) -> Result<(), WebSocketError> {
246 self.send(Message::text(text)).await
247 }
248
249 pub async fn send_binary(&mut self, data: impl Into<Vec<u8>>) -> Result<(), WebSocketError> {
251 self.send(Message::binary(data)).await
252 }
253
254 pub async fn send_json<T: serde::Serialize>(
256 &mut self,
257 value: &T,
258 ) -> Result<(), WebSocketError> {
259 self.send(Message::json(value)?).await
260 }
261
262 pub async fn close(mut self) -> Result<(), WebSocketError> {
264 match &mut self.inner {
265 SenderImpl::Direct(s) => s.close().await.map_err(WebSocketError::from),
266 SenderImpl::Managed(_) => {
267 Ok(())
269 }
270 }
271 }
272}
273
274enum ReceiverImpl {
275 Direct(SplitStream<UpgradedConnection>),
276 Managed(mpsc::Receiver<Result<Message, WebSocketError>>),
277}
278
279pub struct WebSocketReceiver {
281 inner: ReceiverImpl,
282}
283
284impl WebSocketReceiver {
285 pub async fn recv(&mut self) -> Option<Result<Message, WebSocketError>> {
287 match &mut self.inner {
288 ReceiverImpl::Direct(s) => s
289 .next()
290 .await
291 .map(|r| r.map(Message::from).map_err(WebSocketError::from)),
292 ReceiverImpl::Managed(s) => s.recv().await,
293 }
294 }
295}
296
297impl Stream for WebSocketReceiver {
298 type Item = Result<Message, WebSocketError>;
299
300 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
301 match &mut self.inner {
302 ReceiverImpl::Direct(s) => match Pin::new(s).poll_next(cx) {
303 Poll::Ready(Some(Ok(msg))) => Poll::Ready(Some(Ok(Message::from(msg)))),
304 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(WebSocketError::from(e)))),
305 Poll::Ready(None) => Poll::Ready(None),
306 Poll::Pending => Poll::Pending,
307 },
308 ReceiverImpl::Managed(s) => s.poll_recv(cx),
309 }
310 }
311}