bolic-network 0.0.1

Modern network abstraction and tooling for building distributed systems
Documentation
use async_trait::async_trait;
use bytes::Bytes;
use mio::net::TcpStream as MIOTcpStream;
use parking_lot::Mutex;
use std::io;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::net::{TcpListener, TcpStream as TokioTcpStream};
use tokio::sync::mpsc;
use tokio::task::JoinHandle;

use super::pool::{AddressedStreamFactory, NonBlockingStream};
use super::TransportError;
use crate::hub::event::IOSource;
use crate::hub::utils::error;

use tungstenite::handshake::{
    client::ClientHandshake,
    server::{NoCallback, ServerHandshake},
    HandshakeError, MidHandshake,
};
use tungstenite::protocol::WebSocket;

/// State of a single WebSocket connection layered over a mio TCP stream.
///
/// A connection starts in one of the two handshake states (or directly in `Done` when the
/// handshake completed on the first attempt) and is advanced by
/// `MessageStream::try_handshake`. `Empty` is what remains after the state has been moved
/// out, e.g. once a handshake has failed.
enum MessageStream {
    /// Server side: the opening handshake is still in progress.
    ServerHandshake(MidHandshake<ServerHandshake<MIOTcpStream, NoCallback>>),
    /// Client side: the opening handshake is still in progress.
    ClientHandshake(MidHandshake<ClientHandshake<MIOTcpStream>>),
    /// Handshake finished; messages can be exchanged.
    Done(WebSocket<MIOTcpStream>),
    /// No underlying socket (state was taken out).
    Empty,
}

impl MessageStream {
    /// Moves the current state out of `self`, leaving `MessageStream::Empty` in its place.
    fn take(&mut self) -> Self {
        std::mem::replace(self, Self::Empty)
    }

    /// Drives a pending opening handshake one step forward.
    ///
    /// - A completed handshake moves the state to `Done`.
    /// - An interrupted handshake (the socket would block) keeps the in-progress state so it
    ///   can be resumed on the next call.
    /// - Any other handshake error leaves the state `Empty` and reports
    ///   `TransportError::BothTerminated`.
    ///
    /// States that are not mid-handshake (`Done`, `Empty`) are left untouched.
    fn try_handshake(&mut self) -> Result<(), TransportError> {
        let next = match self.take() {
            Self::ServerHandshake(pending) => match pending.handshake() {
                Ok(ws) => Self::Done(ws),
                Err(HandshakeError::Interrupted(pending)) => Self::ServerHandshake(pending),
                // `self` already holds `Empty` thanks to `take()`.
                Err(_) => return Err(TransportError::BothTerminated),
            },
            Self::ClientHandshake(pending) => match pending.handshake() {
                // The HTTP response of the upgrade is not needed.
                Ok((ws, _response)) => Self::Done(ws),
                Err(HandshakeError::Interrupted(pending)) => Self::ClientHandshake(pending),
                Err(_) => return Err(TransportError::BothTerminated),
            },
            settled => settled,
        };
        *self = next;
        Ok(())
    }
}

impl NonBlockingStream for MessageStream {
    fn try_recv(&mut self) -> Result<Bytes, TransportError> {
        self.try_handshake()?;
        if let MessageStream::Done(ws) = self {
            use tungstenite::error::Error::*;
            loop {
                match ws.read() {
                    Ok(msg) => match msg {
                        tungstenite::protocol::Message::Binary(blob) => return Ok(blob.into()),
                        _ => (),
                    },
                    Err(Capacity(_)) => return Err(TransportError::NotReady),
                    Err(Io(e)) => {
                        return match e.kind() {
                            io::ErrorKind::WouldBlock => Err(TransportError::NotReady),
                            _ => Err(TransportError::BothTerminated),
                        }
                    }
                    Err(ConnectionClosed) => return Err(TransportError::HalfTerminated),
                    _ => return Err(TransportError::BothTerminated),
                }
            }
        } else {
            Err(TransportError::NotReady)
        }
    }

    fn try_send(&mut self, data: Option<Bytes>) -> Result<bool, TransportError> {
        let data = match data {
            Some(d) => d,
            None => return Ok(false),
        };
        self.try_handshake()?;
        if let MessageStream::Done(ws) = self {
            use tungstenite::error::Error::*;
            // TODO: improve to make it zero-copy
            match ws.send(tungstenite::protocol::Message::Binary(data.to_vec())) {
                Ok(()) => Ok(false),
                Err(WriteBufferFull(_)) => Err(TransportError::NotReady),
                Err(Io(e)) => {
                    return match e.kind() {
                        io::ErrorKind::WouldBlock => Err(TransportError::NotReady),
                        _ => Err(TransportError::BothTerminated),
                    }
                }
                Err(ConnectionClosed) => Err(TransportError::HalfTerminated),
                _ => Err(TransportError::BothTerminated),
            }
        } else {
            Err(TransportError::NotReady)
        }
    }

    fn source(&mut self) -> IOSource {
        match self {
            MessageStream::ServerHandshake(mid) => IOSource::MIO(mid.get_mut().get_mut()),
            MessageStream::ClientHandshake(mid) => IOSource::MIO(mid.get_mut().get_mut()),
            MessageStream::Done(stream) => IOSource::MIO(stream.get_mut()),
            MessageStream::Empty => IOSource::Empty,
        }
    }

    fn shutdown(&mut self, how: std::net::Shutdown) -> io::Result<()> {
        match self {
            MessageStream::ServerHandshake(mid) => mid.get_ref().get_ref(),
            MessageStream::ClientHandshake(mid) => mid.get_ref().get_ref(),
            MessageStream::Done(stream) => stream.get_ref(),
            MessageStream::Empty => return Ok(()),
        }
        .shutdown(how)
    }
}

struct FactoryInner {
    accepted_stream: Mutex<mpsc::Receiver<MessageStream>>,
    listen_handle: Option<JoinHandle<()>>,
}

impl Drop for FactoryInner {
    fn drop(&mut self) {
        if let Some(h) = self.listen_handle.take() {
            h.abort()
        }
    }
}

use super::tcp::tokio_to_mio_stream;

#[derive(Clone)]
pub struct Factory(Arc<FactoryInner>);

impl Factory {
    pub fn new(listen_addr: Option<SocketAddr>) -> Self {
        let (tx, accepted_stream) = mpsc::channel(1);
        let listen_handle = listen_addr.map(|listen_addr| {
            tokio::spawn(async move {
                let listener = match TcpListener::bind(listen_addr).await {
                    Ok(l) => l,
                    Err(e) => {
                        error!("[WebSocket] failed to bind to address {}: {}", listen_addr, e);
                        return
                    }
                };
                loop {
                    if let Ok((stream, _)) = listener.accept().await {
                        let stream = match tungstenite::accept(tokio_to_mio_stream(stream)) {
                            Ok(stream) => MessageStream::Done(stream),
                            Err(HandshakeError::Interrupted(mid)) => MessageStream::ServerHandshake(mid),
                            Err(e) => panic!("{:?}", e),
                        };
                        tx.send(stream).await.ok();
                    }
                }
            })
        });
        Self(Arc::new(FactoryInner {
            accepted_stream: Mutex::new(accepted_stream),
            listen_handle,
        }))
    }
}

#[async_trait]
impl AddressedStreamFactory for Factory {
    async fn create_stream(&self, url: &str) -> Option<Box<dyn NonBlockingStream>> {
        let url: url::Url = url.parse().ok()?;
        let addrs = url
            .socket_addrs(|| match url.scheme() {
                "wss" => Some(443),
                "ws" => Some(80),
                _ => None,
            })
            .ok()?;
        let tcp_stream = TokioTcpStream::connect(addrs[0]).await.ok()?;
        match tungstenite::client::client(url, tokio_to_mio_stream(tcp_stream)) {
            Ok((stream, _)) => Some(Box::new(MessageStream::Done(stream))),
            Err(HandshakeError::Interrupted(mid)) => Some(Box::new(MessageStream::ClientHandshake(mid))),
            Err(e) => {
                error!("[WebSocket] {:?}", e);
                None
            }
        }
    }

    async fn discover_stream(&self) -> Box<dyn NonBlockingStream> {
        match self.0.accepted_stream.lock().recv().await {
            None => futures::future::pending().await,
            Some(s) => Box::new(s),
        }
    }
}