use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use crate::error::{Error, Result};
use crate::model::{incoming::IncomingMessage, outgoing::OutgoingMessage};
#[cfg(feature = "async-tungstenite09")]
use async_tungstenite09 as async_tungstenite;
#[cfg(feature = "async-std-runtime")]
use async_tungstenite::async_std::{connect_async, ConnectStream};
#[cfg(any(feature = "tokio-runtime", feature = "tokio02-runtime"))]
use async_tungstenite::tokio::{connect_async, ConnectStream};
use async_tungstenite::tungstenite::{
error::{Error as WsError, Result as WsResult},
Message as WsMessage,
};
use async_tungstenite::WebSocketStream;
use futures::{
sink::{Sink, SinkExt},
stream::{SplitSink, SplitStream, Stream, StreamExt, TryStreamExt},
};
#[cfg(feature = "inspect-contents")]
use log::debug;
use url::Url;
/// Receiving half of a websocket connection opened by [`connect_websocket`].
///
/// Wraps the read side of a split [`PingPongWebSocketStream`]; messages are
/// obtained through its [`Stream`] implementation or [`WebSocketReceiver::recv`].
pub struct WebSocketReceiver(
    SplitStream<PingPongWebSocketStream<WebSocketStream<ConnectStream>>>,
);

impl fmt::Debug for WebSocketReceiver {
    // Only the type name is printed; the wrapped stream is not shown.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("WebSocketReceiver")
    }
}
/// Decodes text frames from the connection as JSON [`IncomingMessage`]s.
///
/// * Ping and pong frames are skipped and the next frame is read.
/// * A close frame, or the end of the underlying stream, ends this stream.
/// * Any other frame kind is yielded as `Error::UnexpectedMessage`.
/// * Transport errors and JSON decode failures are yielded as `Err` items.
impl Stream for WebSocketReceiver {
    type Item = Result<IncomingMessage>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        // Loop instead of recursing: a burst of already-buffered control
        // frames must not grow the call stack one frame per message.
        let text = loop {
            match futures::ready!(self.0.poll_next_unpin(cx)?) {
                Some(WsMessage::Text(t)) => break t,
                Some(WsMessage::Ping(_)) | Some(WsMessage::Pong(_)) => continue,
                None | Some(WsMessage::Close(_)) => return Poll::Ready(None),
                Some(m) => return Poll::Ready(Some(Err(Error::UnexpectedMessage(m)))),
            }
        };
        #[cfg(feature = "inspect-contents")]
        debug!("received message: {}", text);
        Poll::Ready(Some(Ok(serde_json::from_str(&text)?)))
    }
}
/// Future returned by [`WebSocketReceiver::recv`]; resolves to the next
/// decoded message from the receiver.
///
/// Marked `#[must_use]` because the `#[must_use]` on the `Future` trait does
/// not apply to concrete future types: without it, `rx.recv();` would compile
/// silently and do nothing.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Recv<'a> {
    stream: &'a mut WebSocketReceiver,
}

impl Future for Recv<'_> {
    type Output = Result<IncomingMessage>;

    /// Polls the underlying receiver once. When the receiver has ended, the
    /// future resolves to `WsError::ConnectionClosed` converted into the
    /// crate error type instead of yielding nothing.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        self.stream
            .poll_next_unpin(cx)
            .map(|opt| opt.unwrap_or_else(|| Err(WsError::ConnectionClosed.into())))
    }
}

impl WebSocketReceiver {
    /// Returns a future resolving to the next message from this receiver.
    ///
    /// # Errors
    /// The future resolves to an error if the next item from the stream is an
    /// error, or to `WsError::ConnectionClosed` (converted) if the stream ended.
    pub fn recv(&mut self) -> Recv<'_> {
        Recv { stream: self }
    }
}
/// Sending half of a websocket connection opened by [`connect_websocket`].
///
/// Accepts `&OutgoingMessage` values through its [`Sink`] implementation.
pub struct WebSocketSender(
    SplitSink<PingPongWebSocketStream<WebSocketStream<ConnectStream>>, WsMessage>,
);

/// Failure returned by [`WebSocketSender::try_send`], handing the unsent
/// message back to the caller together with the cause.
#[derive(Debug, Clone)]
pub struct TrySendError {
    /// The message that could not be sent.
    pub message: OutgoingMessage,
    /// The error reported while sending it.
    pub error: Error,
}

impl WebSocketSender {
    /// Sends `item` and flushes it. On failure, returns the message to the
    /// caller instead of dropping it.
    ///
    /// # Errors
    /// Returns [`TrySendError`] containing `item` and the error produced by
    /// the sink (for example, serialization or websocket failures).
    pub async fn try_send(
        &mut self,
        item: OutgoingMessage,
    ) -> std::result::Result<(), TrySendError> {
        let outcome = self.send(&item).await;
        if let Err(error) = outcome {
            return Err(TrySendError {
                message: item,
                error,
            });
        }
        Ok(())
    }
}

impl fmt::Debug for WebSocketSender {
    // Only the type name is printed; the wrapped sink is not shown.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("WebSocketSender")
    }
}
impl Sink<&'_ OutgoingMessage> for WebSocketSender {
type Error = Error;
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<()>> {
self.0.poll_ready_unpin(cx).map_err(Into::into)
}
fn start_send(mut self: Pin<&mut Self>, item: &OutgoingMessage) -> Result<()> {
let msg = WsMessage::Text(serde_json::to_string(item)?);
#[cfg(feature = "inspect-contents")]
debug!("send message: {:?}", msg);
self.0.start_send_unpin(msg).map_err(Into::into)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<()>> {
self.0.poll_flush_unpin(cx).map_err(Into::into)
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<()>> {
self.0.poll_close_unpin(cx).map_err(Into::into)
}
}
/// Progress of an in-flight reply to a received ping, as tracked by
/// [`PingPongWebSocketStream`].
#[derive(Debug)]
pub enum SendPongState {
    /// A ping was received; its payload waits for the sink to become ready so
    /// it can be echoed back in a pong.
    WaitSink(Vec<u8>),
    /// The pong has been handed to the sink and is waiting to be flushed.
    WaitFlush,
}

/// Wrapper around a websocket connection that replies to incoming pings with
/// pongs itself, so pings are never yielded to the consumer of the stream.
///
/// `Debug` is derived (requiring `S: Debug`) so the public type follows the
/// "all public types implement `Debug`" guideline.
#[derive(Debug)]
pub struct PingPongWebSocketStream<S> {
    /// The wrapped connection, used for both reading and writing.
    stream: S,
    /// `Some` while a pong reply is still being sent; `None` otherwise.
    state: Option<SendPongState>,
}

impl<S> PingPongWebSocketStream<S> {
    /// Wraps `stream` with no pong reply pending.
    pub fn new(stream: S) -> Self {
        PingPongWebSocketStream {
            stream,
            state: None,
        }
    }
}
/// Yields messages from the wrapped stream and replies to pings transparently.
///
/// When a `Ping` arrives, it is not yielded. Instead, a `Pong` with the same
/// payload is queued on the wrapped sink and flushed before the next message
/// is read. Errors from sending or flushing the pong are yielded as stream
/// items; the pending pong is then abandoned.
impl<S: Unpin> Stream for PingPongWebSocketStream<S>
where
    S: Sink<WsMessage, Error = WsError> + Stream<Item = WsResult<WsMessage>>,
{
    type Item = WsResult<WsMessage>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        // Each iteration advances the state machine by one step. A loop
        // (rather than recursive `poll_next` calls) keeps stack usage
        // constant even when many pings are already buffered.
        loop {
            match self.state.take() {
                None => match futures::ready!(self.stream.try_poll_next_unpin(cx)) {
                    Some(Ok(WsMessage::Ping(data))) => {
                        self.state = Some(SendPongState::WaitSink(data));
                    }
                    other => return Poll::Ready(other),
                },
                Some(SendPongState::WaitSink(data)) => {
                    match self.stream.poll_ready_unpin(cx) {
                        Poll::Pending => {
                            // Keep the payload so the pong is retried on the next wake-up.
                            self.state = Some(SendPongState::WaitSink(data));
                            return Poll::Pending;
                        }
                        Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e))),
                        Poll::Ready(Ok(())) => {}
                    }
                    self.stream.start_send_unpin(WsMessage::Pong(data))?;
                    self.state = Some(SendPongState::WaitFlush);
                }
                Some(SendPongState::WaitFlush) => match self.stream.poll_flush_unpin(cx) {
                    Poll::Pending => {
                        self.state = Some(SendPongState::WaitFlush);
                        return Poll::Pending;
                    }
                    Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e))),
                    // Pong delivered; fall through to reading the next message.
                    Poll::Ready(Ok(())) => {}
                },
            }
        }
    }
}
/// Writing through the wrapper forwards every call unchanged to the wrapped
/// sink; no extra buffering or framing is added.
impl<S: Unpin> Sink<WsMessage> for PingPongWebSocketStream<S>
where
    S: Sink<WsMessage, Error = WsError>,
{
    type Error = WsError;

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<WsResult<()>> {
        let inner = &mut self.get_mut().stream;
        Pin::new(inner).poll_ready(cx)
    }

    fn start_send(self: Pin<&mut Self>, item: WsMessage) -> WsResult<()> {
        let inner = &mut self.get_mut().stream;
        Pin::new(inner).start_send(item)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<WsResult<()>> {
        let inner = &mut self.get_mut().stream;
        Pin::new(inner).poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<WsResult<()>> {
        let inner = &mut self.get_mut().stream;
        Pin::new(inner).poll_close(cx)
    }
}
/// Opens a websocket connection to `url` and splits it into a sender and a
/// receiver.
///
/// The connection is wrapped in [`PingPongWebSocketStream`], so incoming pings
/// are answered while the receiver is being polled.
///
/// # Errors
/// Returns an error if establishing the websocket connection fails.
pub async fn connect_websocket(url: Url) -> Result<(WebSocketSender, WebSocketReceiver)> {
    let (socket, _) = connect_async(url).await?;
    let wrapped = PingPongWebSocketStream::new(socket);
    let (write_half, read_half) = wrapped.split();
    let sender = WebSocketSender(write_half);
    let receiver = WebSocketReceiver(read_half);
    Ok((sender, receiver))
}