use async_trait::async_trait;
use bytes::Bytes;
use mio::net::TcpStream as MIOTcpStream;
use parking_lot::Mutex;
use std::io;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::net::{TcpListener, TcpStream as TokioTcpStream};
use tokio::sync::mpsc;
use tokio::task::JoinHandle;
use super::pool::{AddressedStreamFactory, NonBlockingStream};
use super::TransportError;
use crate::hub::event::IOSource;
use crate::hub::utils::error;
use tungstenite::handshake::{
client::ClientHandshake,
server::{NoCallback, ServerHandshake},
HandshakeError, MidHandshake,
};
use tungstenite::protocol::WebSocket;
/// A WebSocket running over a mio TCP stream, tagged with the lifecycle stage it is in.
enum MessageStream {
    /// Server-side handshake that has started but not finished yet.
    ServerHandshake(MidHandshake<ServerHandshake<MIOTcpStream, NoCallback>>),
    /// Client-side handshake that has started but not finished yet.
    ClientHandshake(MidHandshake<ClientHandshake<MIOTcpStream>>),
    /// Handshake completed; messages can be exchanged.
    Done(WebSocket<MIOTcpStream>),
    /// Placeholder left while the state is moved out, and what remains after a failed
    /// handshake.
    Empty,
}
impl MessageStream {
    /// Moves the current state out of `self`, leaving [`MessageStream::Empty`] in its place.
    fn take(&mut self) -> Self {
        std::mem::replace(self, MessageStream::Empty)
    }

    /// Drives a pending server or client handshake one step forward.
    ///
    /// A completed handshake turns the state into [`MessageStream::Done`]. A handshake that
    /// would block is stored back unchanged so it can be resumed later. `Done` and `Empty`
    /// states are left as they are.
    ///
    /// # Errors
    ///
    /// Returns [`TransportError::BothTerminated`] if the handshake fails; `self` is then left
    /// as [`MessageStream::Empty`].
    fn try_handshake(&mut self) -> Result<(), TransportError> {
        let next_state = match self.take() {
            MessageStream::ServerHandshake(pending) => match pending.handshake() {
                Ok(ws) => MessageStream::Done(ws),
                Err(HandshakeError::Interrupted(pending)) => {
                    MessageStream::ServerHandshake(pending)
                }
                Err(HandshakeError::Failure(_)) => return Err(TransportError::BothTerminated),
            },
            MessageStream::ClientHandshake(pending) => match pending.handshake() {
                Ok((ws, _response)) => MessageStream::Done(ws),
                Err(HandshakeError::Interrupted(pending)) => {
                    MessageStream::ClientHandshake(pending)
                }
                Err(HandshakeError::Failure(_)) => return Err(TransportError::BothTerminated),
            },
            settled => settled,
        };
        *self = next_state;
        Ok(())
    }
}
impl NonBlockingStream for MessageStream {
fn try_recv(&mut self) -> Result<Bytes, TransportError> {
self.try_handshake()?;
if let MessageStream::Done(ws) = self {
use tungstenite::error::Error::*;
loop {
match ws.read() {
Ok(msg) => match msg {
tungstenite::protocol::Message::Binary(blob) => return Ok(blob.into()),
_ => (),
},
Err(Capacity(_)) => return Err(TransportError::NotReady),
Err(Io(e)) => {
return match e.kind() {
io::ErrorKind::WouldBlock => Err(TransportError::NotReady),
_ => Err(TransportError::BothTerminated),
}
}
Err(ConnectionClosed) => return Err(TransportError::HalfTerminated),
_ => return Err(TransportError::BothTerminated),
}
}
} else {
Err(TransportError::NotReady)
}
}
fn try_send(&mut self, data: Option<Bytes>) -> Result<bool, TransportError> {
let data = match data {
Some(d) => d,
None => return Ok(false),
};
self.try_handshake()?;
if let MessageStream::Done(ws) = self {
use tungstenite::error::Error::*;
match ws.send(tungstenite::protocol::Message::Binary(data.to_vec())) {
Ok(()) => Ok(false),
Err(WriteBufferFull(_)) => Err(TransportError::NotReady),
Err(Io(e)) => {
return match e.kind() {
io::ErrorKind::WouldBlock => Err(TransportError::NotReady),
_ => Err(TransportError::BothTerminated),
}
}
Err(ConnectionClosed) => Err(TransportError::HalfTerminated),
_ => Err(TransportError::BothTerminated),
}
} else {
Err(TransportError::NotReady)
}
}
fn source(&mut self) -> IOSource {
match self {
MessageStream::ServerHandshake(mid) => IOSource::MIO(mid.get_mut().get_mut()),
MessageStream::ClientHandshake(mid) => IOSource::MIO(mid.get_mut().get_mut()),
MessageStream::Done(stream) => IOSource::MIO(stream.get_mut()),
MessageStream::Empty => IOSource::Empty,
}
}
fn shutdown(&mut self, how: std::net::Shutdown) -> io::Result<()> {
match self {
MessageStream::ServerHandshake(mid) => mid.get_ref().get_ref(),
MessageStream::ClientHandshake(mid) => mid.get_ref().get_ref(),
MessageStream::Done(stream) => stream.get_ref(),
MessageStream::Empty => return Ok(()),
}
.shutdown(how)
}
}
/// State shared by every clone of [`Factory`].
struct FactoryInner {
    /// Receiving end of the queue the listener task pushes accepted streams into.
    accepted_stream: Mutex<mpsc::Receiver<MessageStream>>,
    /// Listener task; present only when the factory was created with a listen address.
    listen_handle: Option<JoinHandle<()>>,
}

impl Drop for FactoryInner {
    /// Aborts the listener task, if any, when the shared state is dropped.
    fn drop(&mut self) {
        if let Some(listener_task) = self.listen_handle.take() {
            listener_task.abort();
        }
    }
}
use super::tcp::tokio_to_mio_stream;
#[derive(Clone)]
pub struct Factory(Arc<FactoryInner>);
impl Factory {
pub fn new(listen_addr: Option<SocketAddr>) -> Self {
let (tx, accepted_stream) = mpsc::channel(1);
let listen_handle = listen_addr.map(|listen_addr| {
tokio::spawn(async move {
let listener = match TcpListener::bind(listen_addr).await {
Ok(l) => l,
Err(e) => {
error!("[WebSocket] failed to bind to address {}: {}", listen_addr, e);
return
}
};
loop {
if let Ok((stream, _)) = listener.accept().await {
let stream = match tungstenite::accept(tokio_to_mio_stream(stream)) {
Ok(stream) => MessageStream::Done(stream),
Err(HandshakeError::Interrupted(mid)) => MessageStream::ServerHandshake(mid),
Err(e) => panic!("{:?}", e),
};
tx.send(stream).await.ok();
}
}
})
});
Self(Arc::new(FactoryInner {
accepted_stream: Mutex::new(accepted_stream),
listen_handle,
}))
}
}
#[async_trait]
impl AddressedStreamFactory for Factory {
    /// Opens an outbound WebSocket connection to `url`.
    ///
    /// Returns `None` in any of these cases:
    /// - `url` does not parse, or has no host;
    /// - no port can be determined (neither explicit nor the scheme's default);
    /// - no resolved address accepts the TCP connection;
    /// - the client handshake fails outright (the error is logged).
    ///
    /// A handshake that is still in progress is returned as a stream; it is completed by later
    /// `try_recv`/`try_send` calls.
    async fn create_stream(&self, url: &str) -> Option<Box<dyn NonBlockingStream>> {
        let target: url::Url = url.parse().ok()?;
        // Explicit port first, then the scheme's known default, then the ws/wss fallback.
        let port = target.port_or_known_default().or_else(|| match target.scheme() {
            "wss" => Some(443),
            "ws" => Some(80),
            _ => None,
        })?;
        // Resolve through tokio so DNS lookups do not block the executor thread. `connect`
        // tries every resolved address in turn, so an empty or partly unreachable address
        // list is handled without panicking.
        let tcp_stream = match target.host()? {
            url::Host::Domain(domain) => TokioTcpStream::connect((domain, port)).await,
            url::Host::Ipv4(ip) => TokioTcpStream::connect((ip, port)).await,
            url::Host::Ipv6(ip) => TokioTcpStream::connect((ip, port)).await,
        }
        .ok()?;
        // NOTE(review): the handshake runs over plain TCP, so a `wss` URL gets no TLS layer
        // here — confirm whether that is intended.
        match tungstenite::client::client(target, tokio_to_mio_stream(tcp_stream)) {
            Ok((stream, _)) => Some(Box::new(MessageStream::Done(stream))),
            Err(HandshakeError::Interrupted(mid)) => Some(Box::new(MessageStream::ClientHandshake(mid))),
            Err(e) => {
                error!("[WebSocket] {:?}", e);
                None
            }
        }
    }

    /// Waits for the next inbound stream produced by the listener task.
    ///
    /// Never returns if the factory has no listener, or if the listener task has ended
    /// (for example after failing to bind), because the queue's sender is then gone.
    async fn discover_stream(&self) -> Box<dyn NonBlockingStream> {
        // `recv` finishes in its own statement so the lock guard is released before the
        // `pending()` branch. Inside a single `match`, the guard would be held for as long
        // as that branch waits, i.e. forever.
        // NOTE(review): the `parking_lot` guard is still held across `recv().await`, so a
        // concurrent caller on another worker thread blocks that thread until a stream
        // arrives; consider `tokio::sync::Mutex` for `accepted_stream`.
        let received = self.0.accepted_stream.lock().recv().await;
        match received {
            None => futures::future::pending().await,
            Some(s) => Box::new(s),
        }
    }
}