misskey_websocket/
channel.rs

1use std::fmt;
2use std::future::Future;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use crate::error::{Error, Result};
7use crate::model::{incoming::IncomingMessage, outgoing::OutgoingMessage};
8
9#[cfg(feature = "async-tungstenite09")]
10use async_tungstenite09 as async_tungstenite;
11
12#[cfg(feature = "async-std-runtime")]
13use async_tungstenite::async_std::{connect_async, ConnectStream};
14#[cfg(any(feature = "tokio-runtime", feature = "tokio02-runtime"))]
15use async_tungstenite::tokio::{connect_async, ConnectStream};
16use async_tungstenite::tungstenite::{
17    error::{Error as WsError, Result as WsResult},
18    Message as WsMessage,
19};
20use async_tungstenite::WebSocketStream;
21use futures::{
22    sink::{Sink, SinkExt},
23    stream::{SplitSink, SplitStream, Stream, StreamExt, TryStreamExt},
24};
25#[cfg(feature = "inspect-contents")]
26use log::debug;
27use url::Url;
28
29/// Receiver channel that communicates with Misskey
30pub struct WebSocketReceiver(SplitStream<PingPongWebSocketStream<WebSocketStream<ConnectStream>>>);
31
32impl fmt::Debug for WebSocketReceiver {
33    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
34        f.debug_struct("WebSocketReceiver").finish()
35    }
36}
37
38impl Stream for WebSocketReceiver {
39    type Item = Result<IncomingMessage>;
40
41    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
42        let text = match futures::ready!(self.0.poll_next_unpin(cx)?) {
43            Some(WsMessage::Text(t)) => t,
44            Some(WsMessage::Ping(_)) | Some(WsMessage::Pong(_)) => return self.poll_next(cx),
45            None | Some(WsMessage::Close(_)) => return Poll::Ready(None),
46            Some(m) => return Poll::Ready(Some(Err(Error::UnexpectedMessage(m)))),
47        };
48
49        #[cfg(feature = "inspect-contents")]
50        debug!("received message: {}", text);
51
52        Poll::Ready(Some(Ok(serde_json::from_str(&text)?)))
53    }
54}
55
56pub struct Recv<'a> {
57    stream: &'a mut WebSocketReceiver,
58}
59
60impl Future for Recv<'_> {
61    type Output = Result<IncomingMessage>;
62
63    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
64        self.stream
65            .poll_next_unpin(cx)
66            .map(|opt| opt.unwrap_or_else(|| Err(WsError::ConnectionClosed.into())))
67    }
68}
69
70impl WebSocketReceiver {
71    /// Receive one message from the stream using `StreamExt::next`,
72    /// while folding `None` to an error that represents closed connection.
73    pub fn recv(&mut self) -> Recv<'_> {
74        Recv { stream: self }
75    }
76}
77
78/// Sender channel that communicates with Misskey
79pub struct WebSocketSender(
80    SplitSink<PingPongWebSocketStream<WebSocketStream<ConnectStream>>, WsMessage>,
81);
82
83#[derive(Debug, Clone)]
84pub struct TrySendError {
85    pub message: OutgoingMessage,
86    pub error: Error,
87}
88
89impl WebSocketSender {
90    /// convenient method that retains the message in the error
91    pub async fn try_send(
92        &mut self,
93        item: OutgoingMessage,
94    ) -> std::result::Result<(), TrySendError> {
95        self.send(&item).await.map_err(|error| TrySendError {
96            message: item,
97            error,
98        })
99    }
100}
101
102impl fmt::Debug for WebSocketSender {
103    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
104        f.debug_struct("WebSocketSender").finish()
105    }
106}
107
108impl Sink<&'_ OutgoingMessage> for WebSocketSender {
109    type Error = Error;
110
111    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<()>> {
112        self.0.poll_ready_unpin(cx).map_err(Into::into)
113    }
114
115    fn start_send(mut self: Pin<&mut Self>, item: &OutgoingMessage) -> Result<()> {
116        let msg = WsMessage::Text(serde_json::to_string(item)?);
117
118        #[cfg(feature = "inspect-contents")]
119        debug!("send message: {:?}", msg);
120
121        self.0.start_send_unpin(msg).map_err(Into::into)
122    }
123
124    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<()>> {
125        self.0.poll_flush_unpin(cx).map_err(Into::into)
126    }
127
128    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<()>> {
129        self.0.poll_close_unpin(cx).map_err(Into::into)
130    }
131}
132
/// Progress of replying to a received Ping frame with a Pong.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SendPongState {
    /// Waiting for the sink to become ready. Holds the Ping payload, which is
    /// echoed back in the Pong.
    WaitSink(Vec<u8>),
    /// The Pong has been queued and is waiting for the sink to be flushed.
    WaitFlush,
}

/// Wrapper around a WebSocket stream/sink that answers Ping frames with Pongs
/// while it is being read.
///
/// Writes through its `Sink` implementation are forwarded directly to the
/// wrapped value.
#[derive(Debug)]
pub struct PingPongWebSocketStream<S> {
    stream: S,
    /// Pong reply in progress, if any (`None` while reading normally).
    state: Option<SendPongState>,
}

impl<S> PingPongWebSocketStream<S> {
    /// Wraps `stream` with no Pong reply pending.
    pub fn new(stream: S) -> Self {
        PingPongWebSocketStream {
            stream,
            state: None,
        }
    }
}
151
152impl<S: Unpin> Stream for PingPongWebSocketStream<S>
153where
154    S: Sink<WsMessage, Error = WsError> + Stream<Item = WsResult<WsMessage>>,
155{
156    type Item = WsResult<WsMessage>;
157    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
158        match self.state.take() {
159            None => {
160                let data = match futures::ready!(self.stream.try_poll_next_unpin(cx)) {
161                    Some(Ok(WsMessage::Ping(data))) => data,
162                    opt => return Poll::Ready(opt),
163                };
164
165                self.state.replace(SendPongState::WaitSink(data));
166                self.poll_next(cx)
167            }
168            Some(SendPongState::WaitSink(data)) => {
169                match self.stream.poll_ready_unpin(cx) {
170                    Poll::Pending => {
171                        self.state.replace(SendPongState::WaitSink(data));
172                        return Poll::Pending;
173                    }
174                    Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e))),
175                    Poll::Ready(Ok(())) => {}
176                }
177
178                self.stream.start_send_unpin(WsMessage::Pong(data))?;
179                self.state.replace(SendPongState::WaitFlush);
180                self.poll_next(cx)
181            }
182            Some(SendPongState::WaitFlush) => match self.stream.poll_flush_unpin(cx) {
183                Poll::Pending => {
184                    self.state.replace(SendPongState::WaitFlush);
185                    Poll::Pending
186                }
187                Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
188                Poll::Ready(Ok(())) => self.poll_next(cx),
189            },
190        }
191    }
192}
193
194impl<S: Unpin> Sink<WsMessage> for PingPongWebSocketStream<S>
195where
196    S: Sink<WsMessage, Error = WsError>,
197{
198    type Error = WsError;
199
200    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<WsResult<()>> {
201        self.stream.poll_ready_unpin(cx)
202    }
203
204    fn start_send(mut self: Pin<&mut Self>, item: WsMessage) -> WsResult<()> {
205        self.stream.start_send_unpin(item)
206    }
207
208    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<WsResult<()>> {
209        self.stream.poll_flush_unpin(cx)
210    }
211
212    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<WsResult<()>> {
213        self.stream.poll_close_unpin(cx)
214    }
215}
216
217pub async fn connect_websocket(url: Url) -> Result<(WebSocketSender, WebSocketReceiver)> {
218    let (ws, _) = connect_async(url).await?;
219    let (sink, stream) = PingPongWebSocketStream::new(ws).split();
220    Ok((WebSocketSender(sink), WebSocketReceiver(stream)))
221}