//! airio-ws 0.1.0
//!
//! WebSocket transport for airio.
mod framed;
use std::{
    io,
    net::SocketAddr,
    pin::Pin,
    task::{Context, Poll},
};

use airio_core::{DialOpts, ListenerId, TransportError, TransportEvent, utils::RwStreamSink};
pub use async_tungstenite::tungstenite::protocol::WebSocketConfig;
use async_tungstenite::{
    accept_async_with_config, client_async_with_config,
    tungstenite::{self, http::Uri},
};
pub use framed::BytesWebSocketStream;
use futures::{FutureExt, TryFutureExt};

/// WebSocket transport layered on top of an [`airio_tcp::Transport`].
///
/// Outgoing dials run a client-side WebSocket handshake over the TCP stream, and
/// inbound connections run a server-side handshake; both use `config`.
pub struct Transport {
    /// Optional settings forwarded unchanged to the `async_tungstenite` handshake calls.
    config: Option<WebSocketConfig>,
    /// The TCP transport that actually opens and accepts connections.
    inner: airio_tcp::Transport,
}

impl Transport {
    /// Wraps `inner` so that its connections are upgraded to WebSocket streams.
    ///
    /// `config` is passed as-is to every client and server handshake.
    pub fn new(config: Option<WebSocketConfig>, inner: airio_tcp::Transport) -> Self {
        Transport { inner, config }
    }
}

impl airio_core::Transport for Transport {
    type Output = RwStreamSink<BytesWebSocketStream<airio_tcp::TcpStream>>;
    type Error = tungstenite::Error;
    type Dial = Pin<Box<dyn Future<Output = Result<Self::Output, Self::Error>> + Send>>;
    type ListenerUpgrade = Pin<Box<dyn Future<Output = Result<Self::Output, Self::Error>> + Send>>;

    /// Starts listening on `addr` through the inner TCP transport.
    ///
    /// Errors from the inner transport are converted into [`tungstenite::Error`] via `Into`.
    fn listen_on(
        &mut self,
        id: ListenerId,
        addr: SocketAddr,
    ) -> Result<(), TransportError<Self::Error>> {
        self.inner
            .listen_on(id, addr)
            .map_err(|e| e.map(|r| r.into()))
    }

    /// Removes listener `id` from the inner TCP transport and returns its result.
    fn remove_listener(&mut self, id: ListenerId) -> bool {
        self.inner.remove_listener(id)
    }

    /// Dials `addr` over TCP and performs a client WebSocket handshake for `ws://{addr}/`.
    ///
    /// # Errors
    ///
    /// - Returns `TransportError::Other` if the handshake URI cannot be built.
    /// - Propagates the inner transport's dial error, converted via `Into`.
    /// - The returned future fails with a [`tungstenite::Error`] if the TCP connection or
    ///   the handshake fails.
    fn dial(
        &mut self,
        addr: SocketAddr,
        opts: DialOpts,
    ) -> Result<Self::Dial, TransportError<Self::Error>> {
        // Build the handshake request before asking the inner transport to dial, so a
        // URI failure does not start a TCP dial that would immediately be abandoned.
        let request = Uri::builder()
            .scheme("ws")
            .authority(addr.to_string())
            .path_and_query("/")
            .build()
            .map_err(tungstenite::Error::from)
            .map_err(TransportError::Other)?;
        // The boxed future must own its configuration, so take a copy here.
        let config = self.config.clone();

        let dial_fut = self
            .inner
            .dial(addr, opts)
            .map_err(|e| e.map(|r| r.into()))?;

        Ok(dial_fut
            .map_err(tungstenite::Error::from)
            .and_then(move |stream| client_async_with_config(request, stream, config))
            // The handshake response is not needed once the stream is established.
            .map_ok(|(s, _)| BytesWebSocketStream::new(s))
            .map_ok(RwStreamSink::new)
            .boxed())
    }

    /// Polls the inner TCP transport for events.
    ///
    /// Each listener upgrade is wrapped so that it runs a server-side WebSocket
    /// handshake with the configured settings. Event errors are converted into
    /// [`tungstenite::Error`].
    fn poll(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<TransportEvent<Self::ListenerUpgrade, Self::Error>> {
        // Cloned up front so the boxed upgrade future owns its copy of the config.
        let config = self.config.clone();

        Pin::new(&mut self.inner).poll(cx).map(|r| {
            r.map_upgrade(|u| {
                u.map_err(tungstenite::Error::from)
                    .and_then(move |s| accept_async_with_config(s, config))
                    .map_ok(BytesWebSocketStream::new)
                    .map_ok(RwStreamSink::new)
                    .boxed()
            })
            .map_err(tungstenite::Error::from)
        })
    }
}

/// Converts a [`tungstenite::Error`] into an [`io::Error`].
///
/// A `tungstenite::Error::Io` is unwrapped and returned unchanged, preserving its
/// original error kind. Any other variant is wrapped in an `io::Error` of kind
/// [`io::ErrorKind::Other`].
fn into_io_error(error: tungstenite::Error) -> io::Error {
    match error {
        tungstenite::Error::Io(e) => e,
        // `io::Error::other` is the idiomatic shorthand for `new(ErrorKind::Other, _)`.
        e => io::Error::other(e),
    }
}