use std::future::Future;
use crate::prelude::*;
use crate::traits::WebSocketReceiverTrait;
use crate::{Message, Result};
use tokio_tungstenite::{WebSocketStream, MaybeTlsStream};
use tokio::net::TcpStream;
use futures_util::stream::SplitStream;
use futures_util::StreamExt;
use futures_util::SinkExt;
use std::convert::TryFrom;
use std::sync::{Arc, Mutex};
use futures_util::stream::SplitSink;
/// Read half of a split WebSocket stream running over a (possibly TLS) TCP socket.
type InboundHalf = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;

/// Write half of a split WebSocket stream, wrapped in `Arc<Mutex<..>>` so the
/// handle can be cloned and locked by more than one owner.
type SharedOutboundHalf = Arc<
    Mutex<
        SplitSink<
            WebSocketStream<MaybeTlsStream<TcpStream>>,
            tokio_tungstenite::tungstenite::Message,
        >,
    >,
>;

/// Receiving side of a WebSocket connection.
///
/// Holds the stream half that incoming frames are read from together with a
/// handle to the sink half; in this file the sink is used to answer `Ping`
/// frames with `Pong` frames. Both fields are left out of the derived `Debug`
/// output.
#[derive(Derivative)]
#[derivative(Debug)]
pub struct WebSocketReceiver {
    /// Stream of incoming tungstenite frames.
    #[derivative(Debug = "ignore")]
    receiver: InboundHalf,
    /// Mutex-guarded sink used for outgoing frames.
    #[derivative(Debug = "ignore")]
    sender: SharedOutboundHalf,
}
/// Builds a [`WebSocketReceiver`] from a `(receiver, sender)` pair: the read
/// half of a split WebSocket stream and an `Arc<Mutex<..>>` handle to a write
/// half.
impl
    From<(
        SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
        Arc<
            Mutex<
                SplitSink<
                    WebSocketStream<MaybeTlsStream<TcpStream>>,
                    tokio_tungstenite::tungstenite::Message,
                >,
            >,
        >,
    )> for WebSocketReceiver
{
    /// Moves both halves into a new receiver without touching them.
    fn from(
        halves: (
            SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
            Arc<
                Mutex<
                    SplitSink<
                        WebSocketStream<MaybeTlsStream<TcpStream>>,
                        tokio_tungstenite::tungstenite::Message,
                    >,
                >,
            >,
        ),
    ) -> Self {
        let (receiver, sender) = halves;
        Self { receiver, sender }
    }
}
impl WebSocketReceiverTrait for WebSocketReceiver {
fn next(&mut self) -> impl Future<Output = Result<Option<Message>>> {
async move {
loop {
match self.receiver.next().await {
Some(Ok(tungstenite_msg)) => {
if let tokio_tungstenite::tungstenite::Message::Ping(payload) = &tungstenite_msg {
let mut sender = self.sender.lock()
.map_err(|err| crate::error::Error::LockError(format!("Failed to lock sender for pong: {:?}", err)))?;
let pong_msg = tokio_tungstenite::tungstenite::Message::Pong(payload.clone());
if let Err(error) = sender.send(pong_msg).await {
return Err(crate::error::Error::SendError(error));
}
continue;
}
if matches!(tungstenite_msg, tokio_tungstenite::tungstenite::Message::Pong(_)) {
continue;
}
return Message::try_from(tungstenite_msg)
.map_err(|err| {
match err {
crate::error::Error::UnsupportedMessageType(msg) => {
use tokio_tungstenite::tungstenite::error::Error as TungsteniteError;
use std::io::{Error as IoError, ErrorKind};
let io_err = IoError::new(ErrorKind::InvalidData, msg);
crate::error::Error::ReceiveError(
TungsteniteError::Io(io_err)
)
},
other => other
}
})
.map(Some);
},
Some(Err(error)) => {
return Err(crate::error::Error::ReceiveError(error));
},
None => {
return Ok(None);
}
}
}
}
}
}