misskey_websocket/
channel.rs1use std::fmt;
2use std::future::Future;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use crate::error::{Error, Result};
7use crate::model::{incoming::IncomingMessage, outgoing::OutgoingMessage};
8
9#[cfg(feature = "async-tungstenite09")]
10use async_tungstenite09 as async_tungstenite;
11
12#[cfg(feature = "async-std-runtime")]
13use async_tungstenite::async_std::{connect_async, ConnectStream};
14#[cfg(any(feature = "tokio-runtime", feature = "tokio02-runtime"))]
15use async_tungstenite::tokio::{connect_async, ConnectStream};
16use async_tungstenite::tungstenite::{
17 error::{Error as WsError, Result as WsResult},
18 Message as WsMessage,
19};
20use async_tungstenite::WebSocketStream;
21use futures::{
22 sink::{Sink, SinkExt},
23 stream::{SplitSink, SplitStream, Stream, StreamExt, TryStreamExt},
24};
25#[cfg(feature = "inspect-contents")]
26use log::debug;
27use url::Url;
28
29pub struct WebSocketReceiver(SplitStream<PingPongWebSocketStream<WebSocketStream<ConnectStream>>>);
31
32impl fmt::Debug for WebSocketReceiver {
33 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
34 f.debug_struct("WebSocketReceiver").finish()
35 }
36}
37
38impl Stream for WebSocketReceiver {
39 type Item = Result<IncomingMessage>;
40
41 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
42 let text = match futures::ready!(self.0.poll_next_unpin(cx)?) {
43 Some(WsMessage::Text(t)) => t,
44 Some(WsMessage::Ping(_)) | Some(WsMessage::Pong(_)) => return self.poll_next(cx),
45 None | Some(WsMessage::Close(_)) => return Poll::Ready(None),
46 Some(m) => return Poll::Ready(Some(Err(Error::UnexpectedMessage(m)))),
47 };
48
49 #[cfg(feature = "inspect-contents")]
50 debug!("received message: {}", text);
51
52 Poll::Ready(Some(Ok(serde_json::from_str(&text)?)))
53 }
54}
55
56pub struct Recv<'a> {
57 stream: &'a mut WebSocketReceiver,
58}
59
60impl Future for Recv<'_> {
61 type Output = Result<IncomingMessage>;
62
63 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
64 self.stream
65 .poll_next_unpin(cx)
66 .map(|opt| opt.unwrap_or_else(|| Err(WsError::ConnectionClosed.into())))
67 }
68}
69
70impl WebSocketReceiver {
71 pub fn recv(&mut self) -> Recv<'_> {
74 Recv { stream: self }
75 }
76}
77
78pub struct WebSocketSender(
80 SplitSink<PingPongWebSocketStream<WebSocketStream<ConnectStream>>, WsMessage>,
81);
82
83#[derive(Debug, Clone)]
84pub struct TrySendError {
85 pub message: OutgoingMessage,
86 pub error: Error,
87}
88
89impl WebSocketSender {
90 pub async fn try_send(
92 &mut self,
93 item: OutgoingMessage,
94 ) -> std::result::Result<(), TrySendError> {
95 self.send(&item).await.map_err(|error| TrySendError {
96 message: item,
97 error,
98 })
99 }
100}
101
102impl fmt::Debug for WebSocketSender {
103 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
104 f.debug_struct("WebSocketSender").finish()
105 }
106}
107
108impl Sink<&'_ OutgoingMessage> for WebSocketSender {
109 type Error = Error;
110
111 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<()>> {
112 self.0.poll_ready_unpin(cx).map_err(Into::into)
113 }
114
115 fn start_send(mut self: Pin<&mut Self>, item: &OutgoingMessage) -> Result<()> {
116 let msg = WsMessage::Text(serde_json::to_string(item)?);
117
118 #[cfg(feature = "inspect-contents")]
119 debug!("send message: {:?}", msg);
120
121 self.0.start_send_unpin(msg).map_err(Into::into)
122 }
123
124 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<()>> {
125 self.0.poll_flush_unpin(cx).map_err(Into::into)
126 }
127
128 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<()>> {
129 self.0.poll_close_unpin(cx).map_err(Into::into)
130 }
131}
132
/// Progress of a pong reply being written by [`PingPongWebSocketStream`].
#[derive(Debug)]
pub enum SendPongState {
    /// A ping with this payload was received; waiting for the sink to accept the pong.
    WaitSink(Vec<u8>),
    /// The pong has been queued on the sink; waiting for it to be flushed.
    WaitFlush,
}
137
138pub struct PingPongWebSocketStream<S> {
139 stream: S,
140 state: Option<SendPongState>,
141}
142
143impl<S> PingPongWebSocketStream<S> {
144 pub fn new(stream: S) -> Self {
145 PingPongWebSocketStream {
146 stream,
147 state: None,
148 }
149 }
150}
151
152impl<S: Unpin> Stream for PingPongWebSocketStream<S>
153where
154 S: Sink<WsMessage, Error = WsError> + Stream<Item = WsResult<WsMessage>>,
155{
156 type Item = WsResult<WsMessage>;
157 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
158 match self.state.take() {
159 None => {
160 let data = match futures::ready!(self.stream.try_poll_next_unpin(cx)) {
161 Some(Ok(WsMessage::Ping(data))) => data,
162 opt => return Poll::Ready(opt),
163 };
164
165 self.state.replace(SendPongState::WaitSink(data));
166 self.poll_next(cx)
167 }
168 Some(SendPongState::WaitSink(data)) => {
169 match self.stream.poll_ready_unpin(cx) {
170 Poll::Pending => {
171 self.state.replace(SendPongState::WaitSink(data));
172 return Poll::Pending;
173 }
174 Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e))),
175 Poll::Ready(Ok(())) => {}
176 }
177
178 self.stream.start_send_unpin(WsMessage::Pong(data))?;
179 self.state.replace(SendPongState::WaitFlush);
180 self.poll_next(cx)
181 }
182 Some(SendPongState::WaitFlush) => match self.stream.poll_flush_unpin(cx) {
183 Poll::Pending => {
184 self.state.replace(SendPongState::WaitFlush);
185 Poll::Pending
186 }
187 Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
188 Poll::Ready(Ok(())) => self.poll_next(cx),
189 },
190 }
191 }
192}
193
194impl<S: Unpin> Sink<WsMessage> for PingPongWebSocketStream<S>
195where
196 S: Sink<WsMessage, Error = WsError>,
197{
198 type Error = WsError;
199
200 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<WsResult<()>> {
201 self.stream.poll_ready_unpin(cx)
202 }
203
204 fn start_send(mut self: Pin<&mut Self>, item: WsMessage) -> WsResult<()> {
205 self.stream.start_send_unpin(item)
206 }
207
208 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<WsResult<()>> {
209 self.stream.poll_flush_unpin(cx)
210 }
211
212 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<WsResult<()>> {
213 self.stream.poll_close_unpin(cx)
214 }
215}
216
217pub async fn connect_websocket(url: Url) -> Result<(WebSocketSender, WebSocketReceiver)> {
218 let (ws, _) = connect_async(url).await?;
219 let (sink, stream) = PingPongWebSocketStream::new(ws).split();
220 Ok((WebSocketSender(sink), WebSocketReceiver(stream)))
221}